├── LICENSE
├── README.EN.md
├── README.md
├── checksave
│   ├── denoiser.pth
│   ├── init.pth
│   ├── seal_denoiser.pth
│   └── seal_init.pth
├── conf.yml
├── data
│   └── docdata.py
├── demo
│   ├── 0000184_blur.png
│   ├── 0000184_orig.png
│   ├── DEGAN_0000184.png
│   ├── DEGAN_0001260.png
│   ├── inference.ipynb
│   └── teaser.png
├── main.py
├── model
│   └── DocDiff.py
├── schedule
│   ├── diffusionSample.py
│   ├── dpm_solver_pytorch.py
│   └── schedule.py
├── src
│   ├── config.py
│   ├── sobel.py
│   ├── train.py
│   └── trainer.py
└── utils
    ├── font
    │   ├── simhei.ttf
    │   ├── times.ttf
    │   ├── timesbd.ttf
    │   ├── timesbi.ttf
    │   ├── timesi.ttf
    │   ├── 方正仿宋_GBK.TTF
    │   ├── 楷体_GB2312.ttf
    │   └── 青鸟华光简琥珀.ttf
    └── marker.py
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 Zongyuan Yang
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.EN.md:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 | [简体中文](README.md) | [English](README.EN.md) | [Paper](https://dl.acm.org/doi/abs/10.1145/3581783.3611730)
10 | # DocDiff
11 | This is the official repository for the paper [DocDiff: Document Enhancement via Residual Diffusion Models](https://dl.acm.org/doi/abs/10.1145/3581783.3611730). DocDiff is a document enhancement model (please refer to the [paper](https://arxiv.org/abs/2305.03892v1)) that can be used for tasks such as document deblurring, denoising, binarization, and watermark and stamp removal. DocDiff is a lightweight, residual-prediction-based diffusion model that can be trained with a batch size of 64 at a resolution of 128*128 using only 12GB of VRAM.
12 |
13 | Beyond document enhancement, DocDiff can also be used for other img2img tasks, such as natural scene deblurring[1](#refer-anchor-1), denoising, rain removal, super-resolution[2](#refer-anchor-2), and image inpainting, as well as high-level tasks such as semantic segmentation[4](#refer-anchor-4).
14 |
15 |
16 | # News
17 |
18 | - **Pinned**: Introducing the versatile, cross-platform [**OCR software**](https://www.aibupt.com/) developed by our laboratory. **It includes DocDiff-based automatic watermark and stamp removal (the automatic watermark removal feature is coming soon)**. It also covers various commonly used OCR functions such as PDF-to-Word conversion, PDF-to-Excel conversion, formula recognition, and table recognition. Feel free to give it a try!
19 | - 2023.09.14: Uploaded the watermark synthesis code `utils/marker.py` and the seal dataset. [Seal dataset Google Drive](https://drive.google.com/file/d/125SgEmHFUIzDexsrj2d3yMJdYMVhovti/view?usp=sharing)
20 | - 2023.08.02: Document binarization results for H-DIBCO 2018 [6](#refer-anchor-6) and DIBCO 2019 [7](#refer-anchor-7) have been uploaded. You can access them on [Google Drive](https://drive.google.com/drive/folders/1gT8PFnfW0qFbFmWX6ReQntfFr9POVtYR?usp=sharing)
21 | - 2023.08.01: **Congratulations! DocDiff has been accepted by ACM Multimedia 2023!**
22 | - 2023.06.13: The inference notebook `demo/inference.ipynb` and the pretrained models in `checksave/` have been uploaded for convenient reproduction.
23 | - 2023.05.08: The initial version of the code has been uploaded. Please check the to-do list for future updates.
24 |
25 | # Guide
26 |
27 | Whether for training or inference, you only need to modify the configuration parameters in `conf.yml` and run `main.py` (e.g. `python main.py --config conf.yml`). MODE=1 is training, MODE=0 is inference. The parameters in `conf.yml` are annotated in detail, so you can adjust them as needed. Pre-trained weights for the document-deblurring Coarse Predictor and Denoiser can be found in `checksave/`.
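For example, a minimal `conf.yml` fragment for running deblurring inference with the released weights could look like the following (the input path is illustrative; every key is from the shipped `conf.yml`):

```yaml
MODE : 0                          # 0 = inference, 1 = training
PRE_ORI : 'True'                  # predict x_0 (default for document tasks)
TIMESTEPS : 100                   # diffusion steps
TEST_PATH_IMG : './demo'          # illustrative input folder
TEST_INITIAL_PREDICTOR_WEIGHT_PATH : './checksave/init.pth'
TEST_DENOISER_WEIGHT_PATH : './checksave/denoiser.pth'
TEST_IMG_SAVE_PATH : './results'  # where outputs are written
```

Then run `python main.py --config conf.yml`.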
28 |
29 | Please note that the default parameters in `conf.yml` work best for document scenarios. If you want to apply DocDiff to natural scenes, please first read [Notes!](#notes) carefully. If you still have issues, feel free to open an issue.
30 |
31 | - Because downsampling is applied three times, the resolution of the input image must be a multiple of 8. If your image is not, adjust it to a multiple of 8 by padding or cropping (a minimal padding sketch follows below). Please do not resize directly, as resizing distorts the image; in the deblurring task in particular, distortion increases the blur and badly degrades results. For example, the document deblurring dataset [5](#refer-anchor-5) used by DocDiff has a resolution of 300\*300, which needs to be padded to 304\*304 before inference.
32 |
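As a minimal sketch (not code from this repository), padding an image tensor to the next multiple of 8 with white borders, then cropping the output back afterwards, could look like this:

```python
import torch.nn.functional as F

def pad_to_multiple_of_8(img):
    """Pad a (C, H, W) tensor on the right/bottom with white (1.0) pixels
    so that H and W become multiples of 8, e.g. 300x300 -> 304x304."""
    _, h, w = img.shape
    pad_h = (8 - h % 8) % 8
    pad_w = (8 - w % 8) % 8
    return F.pad(img, (0, pad_w, 0, pad_h), mode='constant', value=1.0)

# After inference, crop the restored image back to the original size:
# restored = restored[:, :h, :w]
```

Padding with white matches the constant `fill=255` padding used by the training transforms in `data/docdata.py`.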
33 | ## Environment
34 |
35 | - python >= 3.7
36 | - pytorch >= 1.7.0
37 | - torchvision >= 0.8.0
38 |
39 | ## Watermark Synthesis and Seal Dataset
40 |
41 | We provide the watermark synthesis code `utils/marker.py` and a seal dataset. [Seal dataset Google Drive](https://drive.google.com/file/d/125SgEmHFUIzDexsrj2d3yMJdYMVhovti/view?usp=sharing). Because the document background images we used are internal data, the backgrounds are not provided; to use the watermark synthesis code, you need to supply some document background images yourself. The code is implemented with OpenCV, so you need to install OpenCV.
42 |
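The actual implementation lives in `utils/marker.py`; purely to illustrate the general idea (this sketch is not the repository's code), alpha-blending a watermark onto a document background with OpenCV might look like:

```python
import cv2

def blend_watermark(background, watermark, alpha=0.3):
    """Illustrative sketch, not utils/marker.py: overlay a watermark onto a
    document background. Both are uint8 BGR arrays; alpha is the opacity."""
    h, w = background.shape[:2]
    wm = cv2.resize(watermark, (w, h))
    return cv2.addWeighted(background, 1.0 - alpha, wm, alpha, 0)
```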
43 | ### Seal Dataset
44 |
45 | The Seal Dataset is part of the [DocDiff project](https://github.com/Royalvice/DocDiff). It contains 1597 red seals from Chinese-language scenarios, along with their corresponding binary masks. The data can be used for tasks such as seal synthesis and seal removal. Because manpower is limited and extracting seals from document images is extremely difficult, some seal images may contain noise. Most of the original seal
46 | images in the dataset come from the ICDAR 2023 Competition on Reading the Seal Title ([https://rrc.cvc.uab.es/?ch=20](https://rrc.cvc.uab.es/?ch=20)) dataset, and a few from our internal images. If you find this dataset helpful, please give our [project](https://github.com/Royalvice/DocDiff) a free star, thank you!!!
47 |
48 |
49 |
50 | # Notes!
51 |
52 |
53 |
54 | - The default configuration parameters of DocDiff are designed for **document images**, and if you want to achieve better results when using it for **natural scenes**, you need to adjust the parameters. For example, you can scale up the model, add **self-attention**, etc. (because document images have relatively fixed patterns, but natural scenes have more diverse patterns and require more parameters). Additionally, you may need to modify the **training and inference strategies**.
55 | - **Training strategy**: As described in the paper, in document scenarios we do not pursue diverse results and want to minimize inference time, so we set the diffusion step count T to 100 and predict $x_0$ instead of $\epsilon$. Given the channel-wise concatenation conditioning scheme, this strategy can recover a fine $x_0$ in the early steps of reverse diffusion. In natural scenes, to better reconstruct textures and pursue diverse results, the diffusion step count T should be set as large as possible, and $\epsilon$ should be predicted. Simply set **PRE_ORI="False"** in `conf.yml` to use the scheme of predicting $\epsilon$, and **TIMESTEPS=1000** to use a larger diffusion step count (see the `conf.yml` fragment after this list).
56 | - **Inference strategy**: Images generated in document scenarios should not carry randomness (short-step stochastic sampling can distort text edges), so DocDiff performs deterministic sampling as described in DDIM[3](#refer-anchor-3). In natural scenes, stochastic sampling is essential for diverse results; set **PRE_ORI="False"** in `conf.yml` to use stochastic sampling. In other words, the scheme of predicting $\epsilon$ is bound to stochastic sampling, while the scheme of predicting $x_0$ is bound to deterministic sampling. If you want to predict $x_0$ with stochastic sampling, or predict $\epsilon$ with deterministic sampling, you need to modify the code yourself. In DocDiff, deterministic sampling follows DDIM and stochastic sampling follows DDPM; you can implement other sampling strategies by modifying the code.
57 | - **Summary**: For tasks that do not require diverse results, such as semantic segmentation and document enhancement, predicting $x_0$ with a diffusion step count of 100 is enough, and the performance is already good. For tasks that require diverse results, such as natural scene deblurring, super-resolution, and image inpainting, predicting $\epsilon$ with a diffusion step count of 1000 is recommended.
58 |
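In short, the natural-scene settings described above amount to changing two lines of `conf.yml` (everything else can keep its default value):

```yaml
PRE_ORI : 'False'   # predict epsilon instead of x_0 (bound to stochastic sampling)
TIMESTEPS : 1000    # larger diffusion step count for diverse results
```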
59 | # To-do Lists
60 |
61 | - [x] Add training code
62 | - [x] Add inference code
63 | - [x] Upload pre-trained model
64 | - [x] Upload watermark synthesis code and seal dataset.
65 | - [x] Use DPM_solver to reduce the number of inference steps (though the practical gain is modest)
66 | - [x] Upload the inference notebook for convenient reproduction
67 | - [ ] Synthesize document datasets with more noise types, such as salt-and-pepper noise and compression artifacts
68 | - [ ] Train on multiple GPUs
69 | - [ ] Jump-step sampling for DDIM
70 | - [ ] Compress the model with depthwise separable convolutions
71 | - [ ] Train the model on natural scenes and provide results and pre-trained models
72 |
73 | # Stars over time
74 |
75 | [Stars over time](https://starchart.cc/Royalvice/DocDiff)
76 |
77 | # Acknowledgement
78 |
79 | - If you find DocDiff helpful, please give us a star. Thank you! 🤞😘
80 | - If you have any questions, please don't hesitate to open an issue. We will reply as soon as possible.
81 | - If you want to communicate with us, please send an email to **viceyzy@foxmail.com** with the subject "**DocDiff**".
82 | - If you want to use DocDiff as the baseline for your project, please cite our paper.
83 | ```
84 | @inproceedings{yang2023docdiff,
85 | title={DocDiff: Document Enhancement via Residual Diffusion Models},
86 | author={Yang, Zongyuan and Liu, Baolin and Xiong, Yongping and Yi, Lan and Wu, Guibin and Tang, Xiaojun and Liu, Ziqi and Zhou, Junjie and Zhang, Xing},
87 | booktitle={Proceedings of the 31st ACM International Conference on Multimedia},
88 | pages={2795--2806},
89 | year={2023}
90 | }
91 | ```
92 |
93 | # References
94 |
95 |
96 | - [1] Whang J, Delbracio M, Talebi H, et al. Deblurring via stochastic refinement[C]//Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition. 2022: 16293-16303.
97 |
98 |
99 |
100 | - [2] Shang S, Shan Z, Liu G, et al. ResDiff: Combining CNN and Diffusion Model for Image Super-Resolution[J]. arXiv preprint arXiv:2303.08714, 2023.
101 |
102 |
103 |
104 | - [3] Song J, Meng C, Ermon S. Denoising diffusion implicit models[J]. arXiv preprint arXiv:2010.02502, 2020.
105 |
106 |
107 |
108 | - [4] Wu J, Fang H, Zhang Y, et al. MedSegDiff: Medical Image Segmentation with Diffusion Probabilistic Model[J]. arXiv preprint arXiv:2211.00611, 2022.
109 |
110 |
111 |
112 | - [5] Michal Hradiš, Jan Kotera, Pavel Zemčík and Filip Šroubek. Convolutional Neural Networks for Direct Text Deblurring. In Xianghua Xie, Mark W. Jones, and Gary K. L. Tam, editors, Proceedings of the British Machine Vision Conference (BMVC), pages 6.1-6.13. BMVA Press, September 2015.
113 |
114 |
115 |
116 | - [6] I. Pratikakis, K. Zagoris, P. Kaddas and B. Gatos, "ICFHR 2018 Competition on Handwritten Document Image Binarization (H-DIBCO 2018)," 2018 16th International Conference on Frontiers in Handwriting Recognition (ICFHR), Niagara Falls, NY, USA, 2018, pp. 489-493, doi: 10.1109/ICFHR-2018.2018.00091.
117 |
118 |
119 |
120 | - [7] I. Pratikakis, K. Zagoris, X. Karagiannis, L. Tsochatzidis, T. Mondal and I. Marthot-Santaniello, "ICDAR 2019 Competition on Document Image Binarization (DIBCO 2019)," 2019 International Conference on Document Analysis and Recognition (ICDAR), Sydney, NSW, Australia, 2019, pp. 1547-1556, doi: 10.1109/ICDAR.2019.00249.
121 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 | [简体中文](README.md) | [English](README.EN.md) | [Paper](https://dl.acm.org/doi/abs/10.1145/3581783.3611730)
10 |
11 | [Visitors](https://visitorbadge.io/status?path=https%3A%2F%2Fgithub.com%2FRoyalvice%2FDocDiff)
12 |
13 | # DocDiff
14 | This is the official repository for the paper [DocDiff: Document Enhancement via Residual Diffusion Models](https://dl.acm.org/doi/abs/10.1145/3581783.3611730). DocDiff is a document enhancement model (see the [paper](https://arxiv.org/abs/2305.03892v1)) that can be used for tasks such as document deblurring, denoising, binarization, and watermark and seal removal. DocDiff is a lightweight, residual-prediction-based diffusion model that can be trained with a batch size of 64 at a resolution of 128*128 using only 12GB of VRAM.
15 | Beyond document enhancement, DocDiff can also be applied to other img2img tasks: low-level tasks such as natural scene deblurring[1](#refer-anchor-1), denoising, rain removal, super-resolution[2](#refer-anchor-2), and image inpainting, as well as high-level tasks such as semantic segmentation[4](#refer-anchor-4).
16 |
17 |
18 | # News
19 |
20 | - **Pinned**: Introducing the versatile, cross-platform [**OCR software**](https://www.aibupt.com/) developed by our laboratory. **It includes DocDiff-based automatic watermark and seal removal (the automatic watermark removal feature is coming soon)**. It also covers various commonly used OCR functions such as PDF-to-Word conversion, PDF-to-Excel conversion, formula recognition, and table recognition. Feel free to give it a try!
21 | - 2023.09.14: Uploaded the watermark synthesis code `utils/marker.py` and the seal dataset. [Seal dataset Google Drive](https://drive.google.com/file/d/125SgEmHFUIzDexsrj2d3yMJdYMVhovti/view?usp=sharing)
22 | - 2023.08.02: Document binarization results for H-DIBCO 2018 [6](#refer-anchor-6) and DIBCO 2019 [7](#refer-anchor-7) have been uploaded. [Google Drive](https://drive.google.com/drive/folders/1gT8PFnfW0qFbFmWX6ReQntfFr9POVtYR?usp=sharing)
23 | - 2023.08.01: **Congratulations! DocDiff has been accepted by ACM Multimedia 2023!**
24 | - 2023.06.13: The inference notebook `demo/inference.ipynb` and the pretrained models in `checksave/` have been uploaded for convenient reproduction.
25 | - 2023.05.08: The initial version of the code has been uploaded. Please check the to-do list for future updates.
26 |
27 | # Guide
28 |
29 | Whether for training or inference, you only need to modify the configuration parameters in `conf.yml` and run `main.py`. MODE=1 is training, MODE=0 is inference. The parameters in `conf.yml` are annotated in detail, so you can adjust them as needed. Pre-trained document deblurring weights are in `checksave/`.
30 | **Please note** that the default parameters in `conf.yml` work best for document scenarios. If you want to apply DocDiff to natural scenes, please first read [Notes!!!](#notes) carefully. If you still have issues, feel free to open an issue.
31 |
32 | - Because downsampling is applied three times, the resolution of the input image must be a multiple of 8. If your image is not, adjust it to a multiple of 8 by padding or cropping. Please do not resize directly, as resizing distorts the image; in the deblurring task in particular, distortion increases the blur and badly degrades results. For example, the document deblurring dataset [5](#refer-anchor-5) used by DocDiff has a resolution of 300\*300, which needs to be padded to 304\*304 before inference.
33 |
34 | ## Environment
35 |
36 | - python >= 3.7
37 | - pytorch >= 1.7.0
38 | - torchvision >= 0.8.0
39 |
40 | ## Watermark Synthesis and Seal Dataset
41 |
42 | We provide the watermark synthesis code `utils/marker.py` and a seal dataset. [Seal dataset Google Drive](https://drive.google.com/file/d/125SgEmHFUIzDexsrj2d3yMJdYMVhovti/view?usp=sharing). Because the document background images we used are internal data, the backgrounds are not provided; to use the watermark synthesis code, you need to supply some document background images yourself. The code is implemented with OpenCV, so you need to install OpenCV.
43 |
44 | ### Seal Dataset
45 |
46 | The seal dataset is part of the [DocDiff project](https://github.com/Royalvice/DocDiff). It contains 1597 red seals from Chinese-language scenarios, along with their corresponding binary masks. The data can be used for tasks such as seal synthesis and seal removal. Because manpower is limited and extracting seals from document images is extremely difficult, some seal images contain noise. Most of the original seal images in the dataset come from the ICDAR 2023 Competition on Reading the Seal Title ([https://rrc.cvc.uab.es/?ch=20](https://rrc.cvc.uab.es/?ch=20)) dataset, and a few from our internal images. If you find this dataset helpful, please give our [project](https://github.com/Royalvice/DocDiff) a free star, thank you!!!
47 |
48 |
49 |
50 | # Notes!!!
51 |
52 |
53 | - DocDiff's default configuration parameters and its training and inference strategies are designed for **document images**. To get better results on natural scenes, you need to **adjust the parameters**, e.g. scale up the model and add self-attention (document images have relatively fixed patterns, while natural scenes are more diverse and need more parameters), and modify the **training and inference strategies**.
54 | - **Training strategy**: As described in the paper, in document scenarios we do not pursue diverse results and want to minimize inference time, so we set the diffusion step count T to 100 and predict $x_0$ instead of $\epsilon$. Given the channel-wise concatenation conditioning scheme (conditioned on the Coarse Predictor's output), this strategy can recover a fine $x_0$ in the early steps of reverse diffusion. In natural scenes, to better reconstruct textures and pursue diverse results, the diffusion step count T should be set as large as possible, and $\epsilon$ should be predicted. Simply set **PRE_ORI="False"** in **conf.yml** to use the scheme of predicting $\epsilon$, and **TIMESTEPS=1000** to use a larger diffusion step count.
55 | - **Inference strategy**: Images generated in document scenarios should not carry randomness (short-step stochastic sampling can distort text edges), so DocDiff performs deterministic sampling as described in DDIM[3](#refer-anchor-3). In natural scenes, stochastic sampling is essential for diverse results; set **PRE_ORI="False"** in **conf.yml** to use stochastic sampling. In other words, the scheme of predicting $\epsilon$ is bound to stochastic sampling, while the scheme of predicting $x_0$ is bound to deterministic sampling. If you want to predict $x_0$ with stochastic sampling, or predict $\epsilon$ with deterministic sampling, you need to modify the code yourself. In DocDiff, deterministic sampling follows DDIM and stochastic sampling follows DDPM; you can implement other sampling strategies by modifying the code.
56 | - **Summary**: For tasks that do not require diverse results, such as semantic segmentation and document enhancement, predicting $x_0$ with a diffusion step count of 100 is enough, and the performance is already good. For tasks that require diverse results, such as natural scene deblurring, super-resolution, and image inpainting, predicting $\epsilon$ with a diffusion step count of 1000 is recommended.
57 |
58 | # To-do lists
59 |
60 | - [x] Add training code
61 | - [x] Add inference code
62 | - [x] Upload pre-trained models
63 | - [x] Upload watermark synthesis code and seal dataset
64 | - [x] Use DPM_solver to reduce the number of inference steps (though the practical gain is modest)
65 | - [x] Upload the inference notebook for convenient reproduction
66 | - [ ] Synthesize document datasets with more noise types, such as salt-and-pepper noise and compression artifacts
67 | - [ ] Multi-GPU training
68 | - [ ] Jump-step sampling for DDIM
69 | - [ ] Compress the model with depthwise separable convolutions
70 | - [ ] Train the model on natural scenes and provide results and pre-trained models
71 |
72 | # Stars over time
73 |
74 | [Stars over time](https://starchart.cc/Royalvice/DocDiff)
75 |
76 | # Acknowledgement
77 |
78 | - If you find DocDiff helpful, please give us a star. Thank you! 🤞😘
79 | - If you have any questions, feel free to open an issue; we will reply as soon as possible.
80 | - If you want to get in touch, email **viceyzy@foxmail.com** with the subject "**DocDiff**".
81 | - If you use DocDiff as the baseline for your project, please cite our paper.
82 | ```
83 | @inproceedings{yang2023docdiff,
84 | title={DocDiff: Document Enhancement via Residual Diffusion Models},
85 | author={Yang, Zongyuan and Liu, Baolin and Xiong, Yongping and Yi, Lan and Wu, Guibin and Tang, Xiaojun and Liu, Ziqi and Zhou, Junjie and Zhang, Xing},
86 | booktitle={Proceedings of the 31st ACM International Conference on Multimedia},
87 | pages={2795--2806},
88 | year={2023}
89 | }
90 | ```
91 |
92 | # References
93 |
94 |
95 | - [1] Whang J, Delbracio M, Talebi H, et al. Deblurring via stochastic refinement[C]//Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition. 2022: 16293-16303.
96 |
97 |
98 |
99 | - [2] Shang S, Shan Z, Liu G, et al. ResDiff: Combining CNN and Diffusion Model for Image Super-Resolution[J]. arXiv preprint arXiv:2303.08714, 2023.
100 |
101 |
102 |
103 | - [3] Song J, Meng C, Ermon S. Denoising diffusion implicit models[J]. arXiv preprint arXiv:2010.02502, 2020.
104 |
105 |
106 |
107 | - [4] Wu J, Fang H, Zhang Y, et al. MedSegDiff: Medical Image Segmentation with Diffusion Probabilistic Model[J]. arXiv preprint arXiv:2211.00611, 2022.
108 |
109 |
110 |
111 | - [5] Michal Hradiš, Jan Kotera, Pavel Zemčík and Filip Šroubek. Convolutional Neural Networks for Direct Text Deblurring. In Xianghua Xie, Mark W. Jones, and Gary K. L. Tam, editors, Proceedings of the British Machine Vision Conference (BMVC), pages 6.1-6.13. BMVA Press, September 2015.
112 |
113 |
114 |
115 | - [6] I. Pratikakis, K. Zagoris, P. Kaddas and B. Gatos, "ICFHR 2018 Competition on Handwritten Document Image Binarization (H-DIBCO 2018)," 2018 16th International Conference on Frontiers in Handwriting Recognition (ICFHR), Niagara Falls, NY, USA, 2018, pp. 489-493, doi: 10.1109/ICFHR-2018.2018.00091.
116 |
117 |
118 |
119 | - [7] I. Pratikakis, K. Zagoris, X. Karagiannis, L. Tsochatzidis, T. Mondal and I. Marthot-Santaniello, "ICDAR 2019 Competition on Document Image Binarization (DIBCO 2019)," 2019 International Conference on Document Analysis and Recognition (ICDAR), Sydney, NSW, Australia, 2019, pp. 1547-1556, doi: 10.1109/ICDAR.2019.00249.
120 |
--------------------------------------------------------------------------------
/checksave/denoiser.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Royalvice/DocDiff/f422a3af10657eeabeca90f2e862d88c72725845/checksave/denoiser.pth
--------------------------------------------------------------------------------
/checksave/init.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Royalvice/DocDiff/f422a3af10657eeabeca90f2e862d88c72725845/checksave/init.pth
--------------------------------------------------------------------------------
/checksave/seal_denoiser.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Royalvice/DocDiff/f422a3af10657eeabeca90f2e862d88c72725845/checksave/seal_denoiser.pth
--------------------------------------------------------------------------------
/checksave/seal_init.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Royalvice/DocDiff/f422a3af10657eeabeca90f2e862d88c72725845/checksave/seal_init.pth
--------------------------------------------------------------------------------
/conf.yml:
--------------------------------------------------------------------------------
1 | # model
2 | IMAGE_SIZE : [128, 128] # load image size, if it's train mode, it will be randomly cropped to IMAGE_SIZE. If it's test mode, it will be resized to IMAGE_SIZE.
3 | CHANNEL_X : 3 # input channel
4 | CHANNEL_Y : 3 # output channel
5 | TIMESTEPS : 100 # diffusion steps
6 | SCHEDULE : 'linear' # linear or cosine
7 | MODEL_CHANNELS : 32 # basic channels of Unet
8 | NUM_RESBLOCKS : 1 # number of residual blocks
9 | CHANNEL_MULT : [1,2,3,4] # channel multiplier of each layer
10 | NUM_HEADS : 1
11 |
12 | MODE : 1 # 1 Train, 0 Test
13 | PRE_ORI : 'True' # if True, predict $x_0$, else predict $\epsilon$.
14 |
15 |
16 | # train
17 | PATH_GT : '' # path of ground truth
18 | PATH_IMG : '' # path of input
19 | BATCH_SIZE : 32 # training batch size
20 | NUM_WORKERS : 16 # number of workers
21 | ITERATION_MAX : 1000000 # max training iteration
22 | LR : 0.0001 # learning rate
23 | LOSS : 'L2' # L1 or L2
24 | EMA_EVERY : 100 # update EMA every EMA_EVERY iterations
25 | START_EMA : 2000 # start EMA after START_EMA iterations
26 | SAVE_MODEL_EVERY : 10000 # save model every SAVE_MODEL_EVERY iterations
27 | EMA: 'True' # if True, use EMA
28 | CONTINUE_TRAINING : 'False' # if True, continue training
29 | CONTINUE_TRAINING_STEPS : 10000 # continue training from CONTINUE_TRAINING_STEPS
30 | PRETRAINED_PATH_INITIAL_PREDICTOR : '' # path of pretrained initial predictor
31 | PRETRAINED_PATH_DENOISER : '' # path of pretrained denoiser
32 | WEIGHT_SAVE_PATH : './checksave' # path to save model
33 | TRAINING_PATH : './Training' # path of training data
34 | BETA_LOSS : 50 # hyperparameter to balance the pixel loss and the diffusion loss
35 | HIGH_LOW_FREQ : 'True' # if True, training with frequency separation
36 |
37 |
38 | # test
39 | NATIVE_RESOLUTION : 'False' # if True, test with native resolution
40 | DPM_SOLVER : 'False' # if True, test with DPM_solver
41 | DPM_STEP : 20 # DPM_solver step
42 | BATCH_SIZE_VAL : 1 # test batch size
43 | TEST_PATH_GT : '' # path of ground truth
44 | TEST_PATH_IMG : '' # path of input
45 | TEST_INITIAL_PREDICTOR_WEIGHT_PATH : '' # path of initial predictor
46 | TEST_DENOISER_WEIGHT_PATH : '' # path of denoiser
47 | TEST_IMG_SAVE_PATH : './results' # path to save results
48 |
--------------------------------------------------------------------------------
/data/docdata.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | from torch.utils.data import Dataset
4 | from torchvision.transforms import Compose, ToTensor, RandomAffine, RandomHorizontalFlip, RandomCrop
5 | from PIL import Image
6 |
7 |
8 | def ImageTransform(loadSize):
9 | return {"train": Compose([
10 | RandomCrop(loadSize, pad_if_needed=True, padding_mode='constant', fill=255),
11 | RandomAffine(10, fill=255),
12 | RandomHorizontalFlip(p=0.2),
13 | ToTensor(),
14 | ]), "test": Compose([
15 | ToTensor(),
16 | ]), "train_gt": Compose([
17 | RandomCrop(loadSize, pad_if_needed=True, padding_mode='constant', fill=255),
18 | RandomAffine(10, fill=255),
19 | RandomHorizontalFlip(p=0.2),
20 | ToTensor(),
21 | ])}
22 |
23 |
24 | class DocData(Dataset):
25 | def __init__(self, path_img, path_gt, loadSize, mode=1):
26 | super().__init__()
27 | self.path_gt = path_gt
28 | self.path_img = path_img
29 | self.data_gt = os.listdir(path_gt)
30 | self.data_img = os.listdir(path_img)
31 | self.mode = mode
32 | if mode == 1:
33 | self.ImgTrans = (ImageTransform(loadSize)["train"], ImageTransform(loadSize)["train_gt"])
34 | else:
35 | self.ImgTrans = ImageTransform(loadSize)["test"]
36 |
37 | def __len__(self):
38 | return len(self.data_gt)
39 |
40 | def __getitem__(self, idx):
41 |         # Paired images are matched by filename, so both directories are indexed with self.data_img.
42 | gt = Image.open(os.path.join(self.path_gt, self.data_img[idx]))
43 | img = Image.open(os.path.join(self.path_img, self.data_img[idx]))
44 | img = img.convert('RGB')
45 | gt = gt.convert('RGB')
46 | if self.mode == 1:
47 |             seed = torch.random.seed()  # shared seed so img and gt receive identical random transforms
48 | torch.random.manual_seed(seed)
49 | img = self.ImgTrans[0](img)
50 | torch.random.manual_seed(seed)
51 | gt = self.ImgTrans[1](gt)
52 | else:
53 |             img = self.ImgTrans(img)
54 | gt = self.ImgTrans(gt)
55 | name = self.data_img[idx]
56 | return img, gt, name
57 |
--------------------------------------------------------------------------------
/demo/0000184_blur.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Royalvice/DocDiff/f422a3af10657eeabeca90f2e862d88c72725845/demo/0000184_blur.png
--------------------------------------------------------------------------------
/demo/0000184_orig.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Royalvice/DocDiff/f422a3af10657eeabeca90f2e862d88c72725845/demo/0000184_orig.png
--------------------------------------------------------------------------------
/demo/DEGAN_0000184.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Royalvice/DocDiff/f422a3af10657eeabeca90f2e862d88c72725845/demo/DEGAN_0000184.png
--------------------------------------------------------------------------------
/demo/DEGAN_0001260.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Royalvice/DocDiff/f422a3af10657eeabeca90f2e862d88c72725845/demo/DEGAN_0001260.png
--------------------------------------------------------------------------------
/demo/teaser.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Royalvice/DocDiff/f422a3af10657eeabeca90f2e862d88c72725845/demo/teaser.png
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | from src.config import load_config
2 | from src.train import train, test
3 | import argparse
4 |
5 |
6 | def main():
7 | parser = argparse.ArgumentParser()
8 | parser.add_argument('--config', type=str, default='conf.yml', help='path to the config.yaml file')
9 | args = parser.parse_args()
10 | config = load_config(args.config)
11 | print('Config loaded')
12 | mode = config.MODE
13 | if mode == 1:
14 | train(config)
15 | else:
16 | test(config)
17 |
18 |
19 | if __name__ == "__main__":
20 |     main()
--------------------------------------------------------------------------------
/model/DocDiff.py:
--------------------------------------------------------------------------------
1 | import math
2 | from typing import Optional, Tuple, Union, List
3 | import numpy as np
4 | import torch
5 | from torch import nn
6 | from src.sobel import Sobel, Laplacian
7 |
8 |
9 | class Swish(nn.Module):
10 | """
11 | ### Swish activation function
12 | $$x \cdot \sigma(x)$$
13 | """
14 |
15 | def forward(self, x):
16 | return x * torch.sigmoid(x)
17 |
18 |
19 | class TimeEmbedding(nn.Module):
20 | """
21 | ### Embeddings for $t$
22 | """
23 |
24 | def __init__(self, n_channels: int):
25 | """
26 | * `n_channels` is the number of dimensions in the embedding
27 | """
28 | super().__init__()
29 | self.n_channels = n_channels
30 | # First linear layer
31 | self.lin1 = nn.Linear(self.n_channels // 4, self.n_channels)
32 | # Activation
33 | self.act = Swish()
34 | # Second linear layer
35 | self.lin2 = nn.Linear(self.n_channels, self.n_channels)
36 |
37 | def forward(self, t: torch.Tensor):
38 | # Create sinusoidal position embeddings
39 | # [same as those from the transformer](../../transformers/positional_encoding.html)
40 | #
41 | # \begin{align}
42 | # PE^{(1)}_{t,i} &= sin\Bigg(\frac{t}{10000^{\frac{i}{d - 1}}}\Bigg) \\
43 | # PE^{(2)}_{t,i} &= cos\Bigg(\frac{t}{10000^{\frac{i}{d - 1}}}\Bigg)
44 | # \end{align}
45 | #
46 | # where $d$ is `half_dim`
47 | half_dim = self.n_channels // 8
48 | emb = math.log(10_000) / (half_dim - 1)
49 | emb = torch.exp(torch.arange(half_dim, device=t.device) * -emb)
50 | emb = t[:, None] * emb[None, :]
51 | emb = torch.cat((emb.sin(), emb.cos()), dim=1)
52 |
53 | # Transform with the MLP
54 | emb = self.act(self.lin1(emb))
55 | emb = self.lin2(emb)
56 |
57 | #
58 | return emb
59 |
60 |
61 | class ResidualBlock(nn.Module):
62 | """
63 | ### Residual block
64 | A residual block has two convolution layers with group normalization.
65 | Each resolution is processed with two residual blocks.
66 | """
67 |
68 | def __init__(self, in_channels: int, out_channels: int, time_channels: int,
69 | dropout: float = 0.1, is_noise: bool = True):
70 | """
71 | * `in_channels` is the number of input channels
72 |         * `out_channels` is the number of output channels
73 | * `time_channels` is the number channels in the time step ($t$) embeddings
74 | * `n_groups` is the number of groups for [group normalization](../../normalization/group_norm/index.html)
75 | * `dropout` is the dropout rate
76 | """
77 | super().__init__()
78 | # Group normalization and the first convolution layer
79 | self.is_noise = is_noise
80 | self.act1 = Swish()
81 | self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=(3, 3), padding=(1, 1))
82 |
83 | # Group normalization and the second convolution layer
84 |
85 | self.act2 = Swish()
86 | self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=(3, 3), padding=(1, 1))
87 |
88 | # If the number of input channels is not equal to the number of output channels we have to
89 | # project the shortcut connection
90 | if in_channels != out_channels:
91 | self.shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=(1, 1))
92 | else:
93 | self.shortcut = nn.Identity()
94 |
95 | # Linear layer for time embeddings
96 | if self.is_noise:
97 | self.time_emb = nn.Linear(time_channels, out_channels)
98 | self.time_act = Swish()
99 |
100 | self.dropout = nn.Dropout(dropout)
101 |
102 | def forward(self, x: torch.Tensor, t: torch.Tensor):
103 | """
104 | * `x` has shape `[batch_size, in_channels, height, width]`
105 | * `t` has shape `[batch_size, time_channels]`
106 | """
107 | # First convolution layer
108 | h = self.conv1(self.act1(x))
109 | # Add time embeddings
110 | if self.is_noise:
111 | h += self.time_emb(self.time_act(t))[:, :, None, None]
112 | # Second convolution layer
113 | h = self.conv2(self.dropout(self.act2(h)))
114 |
115 | # Add the shortcut connection and return
116 | return h + self.shortcut(x)
117 |
118 |
119 | class DownBlock(nn.Module):
120 | """
121 | ### Down block
122 |     This wraps a `ResidualBlock` (no attention is used in this lightweight variant). These are used in the first half of the U-Net at each resolution.
123 | """
124 |
125 | def __init__(self, in_channels: int, out_channels: int, time_channels: int, is_noise: bool = True):
126 | super().__init__()
127 | self.res = ResidualBlock(in_channels, out_channels, time_channels, is_noise=is_noise)
128 |
129 | def forward(self, x: torch.Tensor, t: torch.Tensor):
130 | x = self.res(x, t)
131 | return x
132 |
133 |
134 | class UpBlock(nn.Module):
135 | """
136 | ### Up block
137 |     This wraps a `ResidualBlock` (no attention is used in this lightweight variant). These are used in the second half of the U-Net at each resolution.
138 | """
139 |
140 | def __init__(self, in_channels: int, out_channels: int, time_channels: int, is_noise: bool = True):
141 | super().__init__()
142 | # The input has `in_channels + out_channels` because we concatenate the output of the same resolution
143 | # from the first half of the U-Net
144 | self.res = ResidualBlock(in_channels + out_channels, out_channels, time_channels, is_noise=is_noise)
145 |
146 | def forward(self, x: torch.Tensor, t: torch.Tensor):
147 | x = self.res(x, t)
148 | return x
149 |
150 |
151 | class MiddleBlock(nn.Module):
152 | """
153 | ### Middle block
154 |     It combines a `ResidualBlock`, four dilated convolutions, and another `ResidualBlock`.
155 | This block is applied at the lowest resolution of the U-Net.
156 | """
157 |
158 | def __init__(self, n_channels: int, time_channels: int, is_noise: bool = True):
159 | super().__init__()
160 | self.res1 = ResidualBlock(n_channels, n_channels, time_channels, is_noise=is_noise)
161 | self.dia1 = nn.Conv2d(n_channels, n_channels, 3, 1, dilation=2, padding=get_pad(16, 3, 1, 2))
162 | self.dia2 = nn.Conv2d(n_channels, n_channels, 3, 1, dilation=4, padding=get_pad(16, 3, 1, 4))
163 | self.dia3 = nn.Conv2d(n_channels, n_channels, 3, 1, dilation=8, padding=get_pad(16, 3, 1, 8))
164 | self.dia4 = nn.Conv2d(n_channels, n_channels, 3, 1, dilation=16, padding=get_pad(16, 3, 1, 16))
165 | self.res2 = ResidualBlock(n_channels, n_channels, time_channels, is_noise=is_noise)
166 |
167 | def forward(self, x: torch.Tensor, t: torch.Tensor):
168 | x = self.res1(x, t)
169 | x = self.dia1(x)
170 | x = self.dia2(x)
171 | x = self.dia3(x)
172 | x = self.dia4(x)
173 | x = self.res2(x, t)
174 | return x
175 |
176 |
177 | class Upsample(nn.Module):
178 | """
179 | ### Scale up the feature map by $2 \times$
180 | """
181 |
182 | def __init__(self, n_channels):
183 | super().__init__()
184 | self.conv = nn.ConvTranspose2d(n_channels, n_channels, (4, 4), (2, 2), (1, 1))
185 |
186 | def forward(self, x: torch.Tensor, t: torch.Tensor):
187 |         # `t` is not used, but it is kept in the arguments so the function
188 |         # signature matches `ResidualBlock`.
189 | _ = t
190 | return self.conv(x)
191 |
192 |
193 | class Downsample(nn.Module):
194 | """
195 | ### Scale down the feature map by $\frac{1}{2} \times$
196 | """
197 |
198 | def __init__(self, n_channels):
199 | super().__init__()
200 | self.conv = nn.Conv2d(n_channels, n_channels, (3, 3), (2, 2), (1, 1))
201 |
202 | def forward(self, x: torch.Tensor, t: torch.Tensor):
203 |         # `t` is not used, but it is kept in the arguments so the function
204 |         # signature matches `ResidualBlock`.
205 | _ = t
206 | return self.conv(x)
207 |
208 |
209 | class UNet(nn.Module):
210 | """
211 | ## U-Net
212 | """
213 |
214 | def __init__(self, input_channels: int = 2, output_channels: int = 1, n_channels: int = 32,
215 | ch_mults: Union[Tuple[int, ...], List[int]] = (1, 2, 2, 4),
216 | n_blocks: int = 2, is_noise: bool = True):
217 | """
218 |         * `input_channels` is the number of channels in the input image/feature map
219 |         * `n_channels` is the number of channels in the initial feature map that we transform the input into
220 |         * `ch_mults` is the list of channel multipliers at each resolution. The number of channels is `ch_mults[i] * n_channels`
221 |         * `is_noise` indicates whether the network is conditioned on the diffusion time step $t$
222 |         * `n_blocks` is the number of residual blocks at each resolution
223 | """
224 | super().__init__()
225 |
226 | # Number of resolutions
227 | n_resolutions = len(ch_mults)
228 |
229 | # Project image into feature map
230 | self.image_proj = nn.Conv2d(input_channels, n_channels, kernel_size=(3, 3), padding=(1, 1))
231 |
232 | # Time embedding layer. Time embedding has `n_channels * 4` channels
233 | self.is_noise = is_noise
234 | if is_noise:
235 | self.time_emb = TimeEmbedding(n_channels * 4)
236 |
237 | # #### First half of U-Net - decreasing resolution
238 | down = []
239 | # Number of channels
240 | out_channels = in_channels = n_channels
241 | # For each resolution
242 | for i in range(n_resolutions):
243 | # Number of output channels at this resolution
244 | out_channels = n_channels * ch_mults[i]
245 | # Add `n_blocks`
246 | for _ in range(n_blocks):
247 | down.append(DownBlock(in_channels, out_channels, n_channels * 4, is_noise=is_noise))
248 | in_channels = out_channels
249 | # Down sample at all resolutions except the last
250 | if i < n_resolutions - 1:
251 | down.append(Downsample(in_channels))
252 |
253 | # Combine the set of modules
254 | self.down = nn.ModuleList(down)
255 |
256 | # Middle block
257 | self.middle = MiddleBlock(out_channels, n_channels * 4, is_noise=False)
258 |
259 | # #### Second half of U-Net - increasing resolution
260 | up = []
261 | # Number of channels
262 | in_channels = out_channels
263 | # For each resolution
264 | for i in reversed(range(n_resolutions)):
265 | # `n_blocks` at the same resolution
266 | out_channels = n_channels * ch_mults[i]
267 | for _ in range(n_blocks):
268 | up.append(UpBlock(in_channels, out_channels, n_channels * 4, is_noise=is_noise))
269 | # Final block to reduce the number of channels
270 | in_channels = n_channels * (ch_mults[i-1] if i >= 1 else 1)
271 | up.append(UpBlock(in_channels, out_channels, n_channels * 4, is_noise=is_noise))
272 | in_channels = out_channels
273 | # Up sample at all resolutions except last
274 | if i > 0:
275 | up.append(Upsample(in_channels))
276 |
277 | # Combine the set of modules
278 | self.up = nn.ModuleList(up)
279 |
280 | self.act = Swish()
281 | self.final = nn.Conv2d(in_channels, output_channels, kernel_size=(3, 3), padding=(1, 1))
282 |
283 |     def forward(self, x: torch.Tensor, t: Optional[torch.Tensor] = None):
284 | """
285 | * `x` has shape `[batch_size, in_channels, height, width]`
286 | * `t` has shape `[batch_size]`
287 | """
288 |
289 |         # Get time-step embeddings; `t` defaults to 0 (a CUDA tensor default argument would run at import time and require a GPU)
290 |         if self.is_noise:
291 |             t = self.time_emb(t if t is not None else torch.zeros(1, dtype=torch.long, device=x.device))
292 |         else:
293 |             t = None
294 | # Get image projection
295 | x = self.image_proj(x)
296 |
297 | # `h` will store outputs at each resolution for skip connection
298 | h = [x]
299 | # First half of U-Net
300 | for m in self.down:
301 | x = m(x, t)
302 | h.append(x)
303 |
304 | # Middle (bottom)
305 | x = self.middle(x, t)
306 |
307 | # Second half of U-Net
308 | for m in self.up:
309 | if isinstance(m, Upsample):
310 | x = m(x, t)
311 | else:
312 | # Get the skip connection from first half of U-Net and concatenate
313 | s = h.pop()
314 | # print(x.shape, s.shape)
315 | x = torch.cat((x, s), dim=1)
316 | #
317 | x = m(x, t)
318 |
319 | # Final normalization and convolution
320 | return self.final(self.act(x))
321 |
322 |
323 | class DocDiff(nn.Module):
324 | def __init__(self, input_channels: int = 2, output_channels: int = 1, n_channels: int = 32,
325 | ch_mults: Union[Tuple[int, ...], List[int]] = (1, 2, 2, 4),
326 | n_blocks: int = 1):
327 | super(DocDiff, self).__init__()
328 | self.denoiser = UNet(input_channels, output_channels, n_channels, ch_mults, n_blocks, is_noise=True)
329 | self.init_predictor = UNet(input_channels//2, output_channels, n_channels, ch_mults, n_blocks, is_noise=False)
330 | # self.init_predictor = UNet(input_channels, output_channels, 2 * n_channels, ch_mults, n_blocks)
331 |
332 | def forward(self, x, condition, t, diffusion):
333 | x_ = self.init_predictor(condition, t)
334 | residual = x - x_
335 | noisy_image, noise_ref = diffusion.noisy_image(t, residual)
336 | x__ = self.denoiser(torch.cat((noisy_image, x_.clone().detach()), dim=1), t)
337 | return x_, x__, noisy_image, noise_ref
338 |
339 |
340 | class EMA():
341 | def __init__(self, beta):
342 | super().__init__()
343 | self.beta = beta
344 |
345 | def update_model_average(self, ma_model, current_model):
346 | for current_params, ma_params in zip(current_model.parameters(), ma_model.parameters()):
347 | old_weight, up_weight = ma_params.data, current_params.data
348 | ma_params.data = self.update_average(old_weight, up_weight)
349 |
350 | def update_average(self, old, new):
351 | if old is None:
352 | return new
353 | return old * self.beta + (1 - self.beta) * new
354 |
355 |
356 | def get_pad(in_, ksize, stride, atrous=1):
357 | out_ = np.ceil(float(in_)/stride)
358 | return int(((out_ - 1) * stride + atrous*(ksize-1) + 1 - in_)/2)
359 |
360 |
361 | if __name__ == '__main__':
362 | from src.config import load_config
363 | import argparse
364 | from schedule.diffusionSample import GaussianDiffusion
365 | from schedule.schedule import Schedule
366 | import torchsummary
367 | parser = argparse.ArgumentParser()
368 | parser.add_argument('--config', type=str, default='../conf.yml', help='path to the config.yaml file')
369 | args = parser.parse_args()
370 | config = load_config(args.config)
371 | print('Config loaded')
372 | model = DocDiff(input_channels=config.CHANNEL_X + config.CHANNEL_Y,
373 | output_channels=config.CHANNEL_Y,
374 | n_channels=config.MODEL_CHANNELS,
375 | ch_mults=config.CHANNEL_MULT,
376 | n_blocks=config.NUM_RESBLOCKS)
377 | schedule = Schedule(config.SCHEDULE, config.TIMESTEPS)
378 | diffusion = GaussianDiffusion(model, config.TIMESTEPS, schedule)
379 | model.eval()
380 | print(torchsummary.summary(model.init_predictor.cuda(), [(3, 128, 128)], batch_size=32))
381 |
382 |
--------------------------------------------------------------------------------
/schedule/diffusionSample.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import numpy as np
5 | import torchvision.utils
6 |
7 |
8 | def extract_(a, t, x_shape):
9 | b, *_ = t.shape
10 | out = a.gather(-1, t)
11 | return out.reshape(b, *((1,) * (len(x_shape) - 1)))
12 |
13 |
14 | def extract(v, t, x_shape):
15 | """
16 | Extract some coefficients at specified timesteps, then reshape to
17 | [batch_size, 1, 1, 1, 1, ...] for broadcasting purposes.
18 | """
19 | device = t.device
20 | out = torch.gather(v, index=t, dim=0).float().to(device)
21 | return out.view([t.shape[0]] + [1] * (len(x_shape) - 1))
22 |
23 |
24 | class GaussianDiffusion(nn.Module):
25 | def __init__(self, model, T, schedule):
26 | super().__init__()
27 | self.visual = False
28 | if self.visual:
29 | self.num = 0
30 | self.model = model
31 | self.T = T
32 | self.schedule = schedule
33 | betas = self.schedule.get_betas()
34 | self.register_buffer('betas', betas.float())
35 | alphas = 1. - self.betas
36 | alphas_bar = torch.cumprod(alphas, dim=0)
37 | alphas_bar_prev = F.pad(alphas_bar, [1, 0], value=1)[:T]
38 | gammas = alphas_bar
39 |
40 | self.register_buffer('coeff1', torch.sqrt(1. / alphas))
41 | self.register_buffer('coeff2', self.coeff1 * (1. - alphas) / torch.sqrt(1. - alphas_bar))
42 | self.register_buffer('posterior_var', self.betas * (1. - alphas_bar_prev) / (1. - alphas_bar))
43 |
44 | # calculation for q(y_t|y_{t-1})
45 | self.register_buffer('gammas', gammas)
46 |         self.register_buffer('sqrt_one_minus_gammas', torch.sqrt(1 - gammas))  # torch.sqrt keeps these as tensors for register_buffer
47 |         self.register_buffer('sqrt_gammas', torch.sqrt(gammas))
48 |
49 | def predict_xt_prev_mean_from_eps(self, x_t, t, eps):
50 | assert x_t.shape == eps.shape
51 | return extract(self.coeff1, t, x_t.shape) * x_t - extract(self.coeff2, t, x_t.shape) * eps
52 |
53 | def p_mean_variance(self, x_t, cond_, t):
54 | # below: only log_variance is used in the KL computations
55 | var = torch.cat([self.posterior_var[1:2], self.betas[1:]])
56 | #var = self.betas
57 | var = extract(var, t, x_t.shape)
58 | eps = self.model(torch.cat((x_t, cond_), dim=1), t)
59 | # nonEps = self.model(x_t, t, torch.zeros_like(labels).to(labels.device))
60 | # eps = (1. + self.w) * eps - self.w * nonEps
61 | xt_prev_mean = self.predict_xt_prev_mean_from_eps(x_t, t, eps=eps)
62 | return xt_prev_mean, var
63 |
64 | def noisy_image(self, t, y):
65 | """ Compute y_noisy according to (6) p15 of [2]"""
66 | noise = torch.randn_like(y)
67 | y_noisy = extract_(self.sqrt_gammas, t, y.shape) * y + extract_(self.sqrt_one_minus_gammas, t, noise.shape) * noise
68 | return y_noisy, noise
69 |
70 | def forward(self, x_T, cond, pre_ori='False'):
71 | """
72 | Algorithm 2.
73 | """
74 | x_t = x_T
75 | cond_ = cond
76 | for time_step in reversed(range(self.T)):
77 | print("time_step: ", time_step)
78 | t = x_t.new_ones([x_T.shape[0], ], dtype=torch.long) * time_step
79 | if pre_ori == 'False':
80 | mean, var = self.p_mean_variance(x_t=x_t, t=t, cond_=cond_)
81 | if time_step > 0:
82 | noise = torch.randn_like(x_t)
83 | else:
84 | noise = 0
85 | x_t = mean + torch.sqrt(var) * noise
86 | assert torch.isnan(x_t).int().sum() == 0, "nan in tensor."
87 | else:
88 | if time_step > 0:
89 | ori = self.model(torch.cat((x_t, cond_), dim=1), t)
90 | eps = x_t - extract_(self.sqrt_gammas, t, ori.shape) * ori
91 | eps = eps / extract_(self.sqrt_one_minus_gammas, t, eps.shape)
92 | x_t = extract_(self.sqrt_gammas, t - 1, ori.shape) * ori + extract_(self.sqrt_one_minus_gammas, t - 1, eps.shape) * eps
93 | else:
94 | x_t = self.model(torch.cat((x_t, cond_), dim=1), t)
95 |
96 | x_0 = x_t
97 | return x_0
98 |
99 |
100 | if __name__ == '__main__':
101 | from schedule import Schedule
102 | test = GaussianDiffusion(None, 100, Schedule('linear', 100))
103 | print(test.gammas)
104 |
--------------------------------------------------------------------------------
/schedule/dpm_solver_pytorch.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | import math
4 |
5 |
6 | class NoiseScheduleVP:
7 | def __init__(
8 | self,
9 | schedule='discrete',
10 | betas=None,
11 | alphas_cumprod=None,
12 | continuous_beta_0=0.1,
13 | continuous_beta_1=20.,
14 | dtype=torch.float32,
15 | ):
16 | """Create a wrapper class for the forward SDE (VP type).
17 |
18 | ***
19 |     Update: We support discrete-time diffusion models by implementing a piecewise linear interpolation for log_alpha_t.
20 |     We recommend using schedule='discrete' for discrete-time diffusion models, especially for high-resolution images.
21 | ***
22 |
23 |     The forward SDE ensures that the conditional distribution q_{t|0}(x_t | x_0) = N ( alpha_t * x_0, sigma_t^2 * I ).
24 | We further define lambda_t = log(alpha_t) - log(sigma_t), which is the half-logSNR (described in the DPM-Solver paper).
25 | Therefore, we implement the functions for computing alpha_t, sigma_t and lambda_t. For t in [0, T], we have:
26 |
27 | log_alpha_t = self.marginal_log_mean_coeff(t)
28 | sigma_t = self.marginal_std(t)
29 | lambda_t = self.marginal_lambda(t)
30 |
31 | Moreover, as lambda(t) is an invertible function, we also support its inverse function:
32 |
33 | t = self.inverse_lambda(lambda_t)
34 |
35 | ===============================================================
36 |
37 | We support both discrete-time DPMs (trained on n = 0, 1, ..., N-1) and continuous-time DPMs (trained on t in [t_0, T]).
38 |
39 | 1. For discrete-time DPMs:
40 |
41 | For discrete-time DPMs trained on n = 0, 1, ..., N-1, we convert the discrete steps to continuous time steps by:
42 | t_i = (i + 1) / N
43 | e.g. for N = 1000, we have t_0 = 1e-3 and T = t_{N-1} = 1.
44 | We solve the corresponding diffusion ODE from time T = 1 to time t_0 = 1e-3.
45 |
46 | Args:
47 | betas: A `torch.Tensor`. The beta array for the discrete-time DPM. (See the original DDPM paper for details)
48 | alphas_cumprod: A `torch.Tensor`. The cumprod alphas for the discrete-time DPM. (See the original DDPM paper for details)
49 |
50 | Note that we always have alphas_cumprod = cumprod(1 - betas). Therefore, we only need to set one of `betas` and `alphas_cumprod`.
51 |
52 | **Important**: Please pay special attention for the args for `alphas_cumprod`:
53 | The `alphas_cumprod` is the \hat{alpha_n} arrays in the notations of DDPM. Specifically, DDPMs assume that
54 | q_{t_n | 0}(x_{t_n} | x_0) = N ( \sqrt{\hat{alpha_n}} * x_0, (1 - \hat{alpha_n}) * I ).
55 | Therefore, the notation \hat{alpha_n} is different from the notation alpha_t in DPM-Solver. In fact, we have
56 | alpha_{t_n} = \sqrt{\hat{alpha_n}},
57 | and
58 | log(alpha_{t_n}) = 0.5 * log(\hat{alpha_n}).
59 |
60 |
61 | 2. For continuous-time DPMs:
62 |
63 | We support two types of VPSDEs: linear (DDPM) and cosine (improved-DDPM). The hyperparameters for the noise
64 | schedule are the default settings in DDPM and improved-DDPM:
65 |
66 | Args:
67 | beta_min: A `float` number. The smallest beta for the linear schedule.
68 | beta_max: A `float` number. The largest beta for the linear schedule.
69 | cosine_s: A `float` number. The hyperparameter in the cosine schedule.
70 | cosine_beta_max: A `float` number. The hyperparameter in the cosine schedule.
71 | T: A `float` number. The ending time of the forward process.
72 |
73 | ===============================================================
74 |
75 | Args:
76 | schedule: A `str`. The noise schedule of the forward SDE. 'discrete' for discrete-time DPMs,
77 | 'linear' or 'cosine' for continuous-time DPMs.
78 | Returns:
79 | A wrapper object of the forward SDE (VP type).
80 |
81 | ===============================================================
82 |
83 | Example:
84 |
85 | # For discrete-time DPMs, given betas (the beta array for n = 0, 1, ..., N - 1):
86 | >>> ns = NoiseScheduleVP('discrete', betas=betas)
87 |
88 | # For discrete-time DPMs, given alphas_cumprod (the \hat{alpha_n} array for n = 0, 1, ..., N - 1):
89 | >>> ns = NoiseScheduleVP('discrete', alphas_cumprod=alphas_cumprod)
90 |
91 | # For continuous-time DPMs (VPSDE), linear schedule:
92 | >>> ns = NoiseScheduleVP('linear', continuous_beta_0=0.1, continuous_beta_1=20.)
93 |
94 | """
95 |
96 | if schedule not in ['discrete', 'linear', 'cosine']:
97 | raise ValueError("Unsupported noise schedule {}. The schedule needs to be 'discrete' or 'linear' or 'cosine'".format(schedule))
98 |
99 | self.schedule = schedule
100 | if schedule == 'discrete':
101 | if betas is not None:
102 | log_alphas = 0.5 * torch.log(1 - betas).cumsum(dim=0)
103 | else:
104 | assert alphas_cumprod is not None
105 | log_alphas = 0.5 * torch.log(alphas_cumprod)
106 | self.total_N = len(log_alphas)
107 | self.T = 1.
108 | self.t_array = torch.linspace(0., 1., self.total_N + 1)[1:].reshape((1, -1)).to(dtype=dtype)
109 | self.log_alpha_array = log_alphas.reshape((1, -1,)).to(dtype=dtype)
110 | else:
111 | self.total_N = 1000
112 | self.beta_0 = continuous_beta_0
113 | self.beta_1 = continuous_beta_1
114 | self.cosine_s = 0.008
115 | self.cosine_beta_max = 999.
116 | self.cosine_t_max = math.atan(self.cosine_beta_max * (1. + self.cosine_s) / math.pi) * 2. * (1. + self.cosine_s) / math.pi - self.cosine_s
117 | self.cosine_log_alpha_0 = math.log(math.cos(self.cosine_s / (1. + self.cosine_s) * math.pi / 2.))
118 | self.schedule = schedule
119 | if schedule == 'cosine':
120 | # For the cosine schedule, T = 1 will have numerical issues. So we manually set the ending time T.
121 | # Note that T = 0.9946 may be not the optimal setting. However, we find it works well.
122 | self.T = 0.9946
123 | else:
124 | self.T = 1.
125 |
126 | def marginal_log_mean_coeff(self, t):
127 | """
128 | Compute log(alpha_t) of a given continuous-time label t in [0, T].
129 | """
130 | if self.schedule == 'discrete':
131 | return interpolate_fn(t.reshape((-1, 1)), self.t_array.to(t.device), self.log_alpha_array.to(t.device)).reshape((-1))
132 | elif self.schedule == 'linear':
133 | return -0.25 * t ** 2 * (self.beta_1 - self.beta_0) - 0.5 * t * self.beta_0
134 | elif self.schedule == 'cosine':
135 | log_alpha_fn = lambda s: torch.log(torch.cos((s + self.cosine_s) / (1. + self.cosine_s) * math.pi / 2.))
136 | log_alpha_t = log_alpha_fn(t) - self.cosine_log_alpha_0
137 | return log_alpha_t
138 |
139 | def marginal_alpha(self, t):
140 | """
141 | Compute alpha_t of a given continuous-time label t in [0, T].
142 | """
143 | return torch.exp(self.marginal_log_mean_coeff(t))
144 |
145 | def marginal_std(self, t):
146 | """
147 | Compute sigma_t of a given continuous-time label t in [0, T].
148 | """
149 | return torch.sqrt(1. - torch.exp(2. * self.marginal_log_mean_coeff(t)))
150 |
151 | def marginal_lambda(self, t):
152 | """
153 | Compute lambda_t = log(alpha_t) - log(sigma_t) of a given continuous-time label t in [0, T].
154 | """
155 | log_mean_coeff = self.marginal_log_mean_coeff(t)
156 | log_std = 0.5 * torch.log(1. - torch.exp(2. * log_mean_coeff))
157 | return log_mean_coeff - log_std
158 |
159 | def inverse_lambda(self, lamb):
160 | """
161 | Compute the continuous-time label t in [0, T] of a given half-logSNR lambda_t.
162 | """
163 | if self.schedule == 'linear':
164 | tmp = 2. * (self.beta_1 - self.beta_0) * torch.logaddexp(-2. * lamb, torch.zeros((1,)).to(lamb))
165 | Delta = self.beta_0**2 + tmp
166 | return tmp / (torch.sqrt(Delta) + self.beta_0) / (self.beta_1 - self.beta_0)
167 | elif self.schedule == 'discrete':
168 | log_alpha = -0.5 * torch.logaddexp(torch.zeros((1,)).to(lamb.device), -2. * lamb)
169 | t = interpolate_fn(log_alpha.reshape((-1, 1)), torch.flip(self.log_alpha_array.to(lamb.device), [1]), torch.flip(self.t_array.to(lamb.device), [1]))
170 | return t.reshape((-1,))
171 | else:
172 | log_alpha = -0.5 * torch.logaddexp(-2. * lamb, torch.zeros((1,)).to(lamb))
173 | t_fn = lambda log_alpha_t: torch.arccos(torch.exp(log_alpha_t + self.cosine_log_alpha_0)) * 2. * (1. + self.cosine_s) / math.pi - self.cosine_s
174 | t = t_fn(log_alpha)
175 | return t
176 |
177 |
178 | def model_wrapper(
179 | model,
180 | noise_schedule,
181 | model_type="noise",
182 | model_kwargs={},
183 | guidance_type="uncond",
184 | condition=None,
185 | unconditional_condition=None,
186 | guidance_scale=1.,
187 | classifier_fn=None,
188 | classifier_kwargs={},
189 | ):
190 | """Create a wrapper function for the noise prediction model.
191 |
192 | DPM-Solver needs to solve the continuous-time diffusion ODEs. For DPMs trained on discrete-time labels, we need to
193 | firstly wrap the model function to a noise prediction model that accepts the continuous time as the input.
194 |
195 | We support four types of the diffusion model by setting `model_type`:
196 |
197 | 1. "noise": noise prediction model. (Trained by predicting noise).
198 |
199 | 2. "x_start": data prediction model. (Trained by predicting the data x_0 at time 0).
200 |
201 | 3. "v": velocity prediction model. (Trained by predicting the velocity).
202 | The "v" prediction is derivation detailed in Appendix D of [1], and is used in Imagen-Video [2].
203 |
204 | [1] Salimans, Tim, and Jonathan Ho. "Progressive distillation for fast sampling of diffusion models."
205 | arXiv preprint arXiv:2202.00512 (2022).
206 | [2] Ho, Jonathan, et al. "Imagen Video: High Definition Video Generation with Diffusion Models."
207 | arXiv preprint arXiv:2210.02303 (2022).
208 |
209 | 4. "score": marginal score function. (Trained by denoising score matching).
210 | Note that the score function and the noise prediction model follows a simple relationship:
211 | ```
212 | noise(x_t, t) = -sigma_t * score(x_t, t)
213 | ```
214 |
215 | We support three types of guided sampling by DPMs by setting `guidance_type`:
216 | 1. "uncond": unconditional sampling by DPMs.
217 | The input `model` has the following format:
218 | ``
219 | model(x, t_input, **model_kwargs) -> noise | x_start | v | score
220 | ``
221 |
222 | 2. "classifier": classifier guidance sampling [3] by DPMs and another classifier.
223 | The input `model` has the following format:
224 | ``
225 | model(x, t_input, **model_kwargs) -> noise | x_start | v | score
226 | ``
227 |
228 | The input `classifier_fn` has the following format:
229 | ``
230 | classifier_fn(x, t_input, cond, **classifier_kwargs) -> logits(x, t_input, cond)
231 | ``
232 |
233 | [3] P. Dhariwal and A. Q. Nichol, "Diffusion models beat GANs on image synthesis,"
234 | in Advances in Neural Information Processing Systems, vol. 34, 2021, pp. 8780-8794.
235 |
236 | 3. "classifier-free": classifier-free guidance sampling by conditional DPMs.
237 | The input `model` has the following format:
238 | ``
239 | model(x, t_input, cond, **model_kwargs) -> noise | x_start | v | score
240 | ``
241 | And if cond == `unconditional_condition`, the model output is the unconditional DPM output.
242 |
243 | [4] Ho, Jonathan, and Tim Salimans. "Classifier-free diffusion guidance."
244 | arXiv preprint arXiv:2207.12598 (2022).
245 |
246 |
247 | The `t_input` is the time label of the model, which may be discrete-time labels (i.e. 0 to 999)
248 | or continuous-time labels (i.e. epsilon to T).
249 |
250 | We wrap the model function to accept only `x` and `t_continuous` as inputs, and outputs the predicted noise:
251 | ``
252 | def model_fn(x, t_continuous) -> noise:
253 | t_input = get_model_input_time(t_continuous)
254 | return noise_pred(model, x, t_input, **model_kwargs)
255 | ``
256 | where `t_continuous` is the continuous time labels (i.e. epsilon to T). And we use `model_fn` for DPM-Solver.
257 |
258 | ===============================================================
259 |
260 | Args:
261 | model: A diffusion model with the corresponding format described above.
262 | noise_schedule: A noise schedule object, such as NoiseScheduleVP.
263 | model_type: A `str`. The parameterization type of the diffusion model.
264 | "noise" or "x_start" or "v" or "score".
265 | model_kwargs: A `dict`. A dict for the other inputs of the model function.
266 | guidance_type: A `str`. The type of the guidance for sampling.
267 | "uncond" or "classifier" or "classifier-free".
268 | condition: A pytorch tensor. The condition for the guided sampling.
269 | Only used for "classifier" or "classifier-free" guidance type.
270 | unconditional_condition: A pytorch tensor. The condition for the unconditional sampling.
271 | Only used for "classifier-free" guidance type.
272 | guidance_scale: A `float`. The scale for the guided sampling.
273 | classifier_fn: A classifier function. Only used for the classifier guidance.
274 | classifier_kwargs: A `dict`. A dict for the other inputs of the classifier function.
275 | Returns:
276 | A noise prediction model that accepts the noised data and the continuous time as the inputs.
277 | """
278 |
279 | def get_model_input_time(t_continuous):
280 | """
281 | Convert the continuous-time `t_continuous` (in [epsilon, T]) to the model input time.
282 | For discrete-time DPMs, we convert `t_continuous` in [1 / N, 1] to `t_input` in [0, 1000 * (N - 1) / N].
283 | For continuous-time DPMs, we just use `t_continuous`.
284 | """
285 | if noise_schedule.schedule == 'discrete':
286 | return (t_continuous - 1. / noise_schedule.total_N) * 1000.
287 | else:
288 | return t_continuous
289 |
290 | def noise_pred_fn(x, t_continuous, cond=None):
291 | t_input = get_model_input_time(t_continuous)
292 | if cond is None:
293 | output = model(x, t_input, **model_kwargs)
294 | else:
295 | output = model(x, t_input, cond, **model_kwargs)
296 | if model_type == "noise":
297 | return output
298 | elif model_type == "x_start":
299 | alpha_t, sigma_t = noise_schedule.marginal_alpha(t_continuous), noise_schedule.marginal_std(t_continuous)
300 | return (x - alpha_t * output) / sigma_t
301 | elif model_type == "v":
302 | alpha_t, sigma_t = noise_schedule.marginal_alpha(t_continuous), noise_schedule.marginal_std(t_continuous)
303 | return alpha_t * output + sigma_t * x
304 | elif model_type == "score":
305 | sigma_t = noise_schedule.marginal_std(t_continuous)
306 | return -sigma_t * output
307 |
308 | def cond_grad_fn(x, t_input):
309 | """
310 | Compute the gradient of the classifier, i.e. nabla_{x} log p_t(cond | x_t).
311 | """
312 | with torch.enable_grad():
313 | x_in = x.detach().requires_grad_(True)
314 | log_prob = classifier_fn(x_in, t_input, condition, **classifier_kwargs)
315 | return torch.autograd.grad(log_prob.sum(), x_in)[0]
316 |
317 | def model_fn(x, t_continuous):
318 | """
319 |         The noise prediction model function that is used for DPM-Solver.
320 | """
321 | if guidance_type == "uncond":
322 | return noise_pred_fn(x, t_continuous)
323 | elif guidance_type == "classifier":
324 | assert classifier_fn is not None
325 | t_input = get_model_input_time(t_continuous)
326 | cond_grad = cond_grad_fn(x, t_input)
327 | sigma_t = noise_schedule.marginal_std(t_continuous)
328 | noise = noise_pred_fn(x, t_continuous)
329 | return noise - guidance_scale * sigma_t * cond_grad
330 | elif guidance_type == "classifier-free":
331 | if guidance_scale == 1. or unconditional_condition is None:
332 | return noise_pred_fn(x, t_continuous, cond=condition)
333 | else:
334 | x_in = torch.cat([x] * 2)
335 | t_in = torch.cat([t_continuous] * 2)
336 | c_in = torch.cat([unconditional_condition, condition])
337 | noise_uncond, noise = noise_pred_fn(x_in, t_in, cond=c_in).chunk(2)
338 | return noise_uncond + guidance_scale * (noise - noise_uncond)
339 |
340 | assert model_type in ["noise", "x_start", "v", "score"]
341 | assert guidance_type in ["uncond", "classifier", "classifier-free"]
342 | return model_fn
343 |
344 |
345 | class DPM_Solver:
346 | def __init__(
347 | self,
348 | model_fn,
349 | noise_schedule,
350 | algorithm_type="dpmsolver++",
351 | correcting_x0_fn=None,
352 | correcting_xt_fn=None,
353 | thresholding_max_val=1.,
354 | dynamic_thresholding_ratio=0.995,
355 | ):
356 | """Construct a DPM-Solver.
357 |
358 | We support both DPM-Solver (`algorithm_type="dpmsolver"`) and DPM-Solver++ (`algorithm_type="dpmsolver++"`).
359 |
360 | We also support the "dynamic thresholding" method in Imagen[1]. For pixel-space diffusion models, you
361 | can set both `algorithm_type="dpmsolver++"` and `correcting_x0_fn="dynamic_thresholding"` to use the
362 | dynamic thresholding. The "dynamic thresholding" can greatly improve the sample quality for pixel-space
363 | DPMs with large guidance scales. Note that the thresholding method is **unsuitable** for latent-space
364 | DPMs (such as stable-diffusion).
365 |
366 | To support advanced algorithms in image-to-image applications, we also support corrector functions for
367 | both x0 and xt.
368 |
369 | Args:
370 | model_fn: A noise prediction model function which accepts the continuous-time input (t in [epsilon, T]):
371 | ``
372 | def model_fn(x, t_continuous):
373 | return noise
374 | ``
375 | The shape of `x` is `(batch_size, **shape)`, and the shape of `t_continuous` is `(batch_size,)`.
376 | noise_schedule: A noise schedule object, such as NoiseScheduleVP.
377 | algorithm_type: A `str`. Either "dpmsolver" or "dpmsolver++".
378 | correcting_x0_fn: A `str` or a function with the following format:
379 | ```
380 | def correcting_x0_fn(x0, t):
381 | x0_new = ...
382 | return x0_new
383 | ```
384 |                 This function corrects the outputs of the data prediction model at each sampling step, e.g.:
385 | ```
386 | x0_pred = data_pred_model(xt, t)
387 | if correcting_x0_fn is not None:
388 | x0_pred = correcting_x0_fn(x0_pred, t)
389 | xt_1 = update(x0_pred, xt, t)
390 | ```
391 | If `correcting_x0_fn="dynamic_thresholding"`, we use the dynamic thresholding proposed in Imagen[1].
392 | correcting_xt_fn: A function with the following format:
393 | ```
394 | def correcting_xt_fn(xt, t, step):
395 | x_new = ...
396 | return x_new
397 | ```
398 |                 This function corrects the intermediate samples xt at each sampling step, e.g.:
399 | ```
400 | xt = ...
401 | xt = correcting_xt_fn(xt, t, step)
402 | ```
403 |             thresholding_max_val: A `float`. The max value for thresholding.
404 |                 Valid only when using `dpmsolver++` and `correcting_x0_fn="dynamic_thresholding"`.
405 |             dynamic_thresholding_ratio: A `float`. The ratio for dynamic thresholding (see Imagen[1] for details).
406 |                 Valid only when using `dpmsolver++` and `correcting_x0_fn="dynamic_thresholding"`.
407 |
408 | [1] Chitwan Saharia, William Chan, Saurabh Saxena, Lala Li, Jay Whang, Emily Denton, Seyed Kamyar Seyed Ghasemipour,
409 | Burcu Karagol Ayan, S Sara Mahdavi, Rapha Gontijo Lopes, et al. Photorealistic text-to-image diffusion models
410 | with deep language understanding. arXiv preprint arXiv:2205.11487, 2022b.
411 | """
412 | self.model = lambda x, t: model_fn(x, t.expand((x.shape[0])))
413 | self.noise_schedule = noise_schedule
414 | assert algorithm_type in ["dpmsolver", "dpmsolver++"]
415 | self.algorithm_type = algorithm_type
416 | if correcting_x0_fn == "dynamic_thresholding":
417 | self.correcting_x0_fn = self.dynamic_thresholding_fn
418 | else:
419 | self.correcting_x0_fn = correcting_x0_fn
420 | self.correcting_xt_fn = correcting_xt_fn
421 | self.dynamic_thresholding_ratio = dynamic_thresholding_ratio
422 | self.thresholding_max_val = thresholding_max_val
423 |
424 | def dynamic_thresholding_fn(self, x0, t):
425 | """
426 | The dynamic thresholding method.
427 | """
428 | dims = x0.dim()
429 | p = self.dynamic_thresholding_ratio
430 | s = torch.quantile(torch.abs(x0).reshape((x0.shape[0], -1)), p, dim=1)
431 | s = expand_dims(torch.maximum(s, self.thresholding_max_val * torch.ones_like(s).to(s.device)), dims)
432 | x0 = torch.clamp(x0, -s, s) / s
433 | return x0
434 |
435 | def noise_prediction_fn(self, x, t):
436 | """
437 |         Return the predicted noise at (`x`, `t`).
438 | """
439 | return self.model(x, t)
440 |
441 | def data_prediction_fn(self, x, t):
442 | """
443 |         Return the predicted data x0 at (`x`, `t`), with the optional corrector applied.
444 | """
445 | noise = self.noise_prediction_fn(x, t)
446 | alpha_t, sigma_t = self.noise_schedule.marginal_alpha(t), self.noise_schedule.marginal_std(t)
447 | x0 = (x - sigma_t * noise) / alpha_t
448 | if self.correcting_x0_fn is not None:
449 | x0 = self.correcting_x0_fn(x0, t)
450 | return x0
451 |
452 | def model_fn(self, x, t):
453 | """
454 | Convert the model to the noise prediction model or the data prediction model.
455 | """
456 | if self.algorithm_type == "dpmsolver++":
457 | return self.data_prediction_fn(x, t)
458 | else:
459 | return self.noise_prediction_fn(x, t)
460 |
461 | def get_time_steps(self, skip_type, t_T, t_0, N, device):
462 | """Compute the intermediate time steps for sampling.
463 |
464 | Args:
465 | skip_type: A `str`. The type for the spacing of the time steps. We support three types:
466 | - 'logSNR': uniform logSNR for the time steps.
467 |                 - 'time_uniform': uniform time for the time steps. (**Recommended for high-resolution data**.)
468 |                 - 'time_quadratic': quadratic time for the time steps. (Used in DDIM for low-resolution data.)
469 | t_T: A `float`. The starting time of the sampling (default is T).
470 | t_0: A `float`. The ending time of the sampling (default is epsilon).
471 |             N: An `int`. The number of intervals between time steps; the returned tensor has N + 1 entries.
472 | device: A torch device.
473 | Returns:
474 | A pytorch tensor of the time steps, with the shape (N + 1,).
475 | """
476 | if skip_type == 'logSNR':
477 | lambda_T = self.noise_schedule.marginal_lambda(torch.tensor(t_T).to(device))
478 | lambda_0 = self.noise_schedule.marginal_lambda(torch.tensor(t_0).to(device))
479 | logSNR_steps = torch.linspace(lambda_T.cpu().item(), lambda_0.cpu().item(), N + 1).to(device)
480 | return self.noise_schedule.inverse_lambda(logSNR_steps)
481 | elif skip_type == 'time_uniform':
482 | return torch.linspace(t_T, t_0, N + 1).to(device)
483 | elif skip_type == 'time_quadratic':
484 | t_order = 2
485 | t = torch.linspace(t_T**(1. / t_order), t_0**(1. / t_order), N + 1).pow(t_order).to(device)
486 | return t
487 | else:
488 | raise ValueError("Unsupported skip_type {}, need to be 'logSNR' or 'time_uniform' or 'time_quadratic'".format(skip_type))
489 |
490 | def get_orders_and_timesteps_for_singlestep_solver(self, steps, order, skip_type, t_T, t_0, device):
491 | """
492 | Get the order of each step for sampling by the singlestep DPM-Solver.
493 |
494 |         We combine DPM-Solver-1, 2 and 3 to use up all the function evaluations, which is named "DPM-Solver-fast".
495 |         Given a fixed number of function evaluations `steps`, the sampling procedure by DPM-Solver-fast is:
496 | - If order == 1:
497 |             We take `steps` steps of DPM-Solver-1 (i.e. DDIM).
498 | - If order == 2:
499 | - Denote K = (steps // 2). We take K or (K + 1) intermediate time steps for sampling.
500 | - If steps % 2 == 0, we use K steps of DPM-Solver-2.
501 | - If steps % 2 == 1, we use K steps of DPM-Solver-2 and 1 step of DPM-Solver-1.
502 | - If order == 3:
503 | - Denote K = (steps // 3 + 1). We take K intermediate time steps for sampling.
504 | - If steps % 3 == 0, we use (K - 2) steps of DPM-Solver-3, and 1 step of DPM-Solver-2 and 1 step of DPM-Solver-1.
505 | - If steps % 3 == 1, we use (K - 1) steps of DPM-Solver-3 and 1 step of DPM-Solver-1.
506 | - If steps % 3 == 2, we use (K - 1) steps of DPM-Solver-3 and 1 step of DPM-Solver-2.
507 |
508 | ============================================
509 | Args:
510 |             steps: An `int`. The total number of function evaluations (NFE).
511 |             order: An `int`. The max order for the solver (2 or 3).
512 | skip_type: A `str`. The type for the spacing of the time steps. We support three types:
513 | - 'logSNR': uniform logSNR for the time steps.
514 |                 - 'time_uniform': uniform time for the time steps. (**Recommended for high-resolution data**.)
515 |                 - 'time_quadratic': quadratic time for the time steps. (Used in DDIM for low-resolution data.)
516 | t_T: A `float`. The starting time of the sampling (default is T).
517 | t_0: A `float`. The ending time of the sampling (default is epsilon).
518 | device: A torch device.
519 | Returns:
520 | orders: A list of the solver order of each step.
521 | """
522 | if order == 3:
523 | K = steps // 3 + 1
524 | if steps % 3 == 0:
525 | orders = [3,] * (K - 2) + [2, 1]
526 | elif steps % 3 == 1:
527 | orders = [3,] * (K - 1) + [1]
528 | else:
529 | orders = [3,] * (K - 1) + [2]
530 | elif order == 2:
531 | if steps % 2 == 0:
532 | K = steps // 2
533 | orders = [2,] * K
534 | else:
535 | K = steps // 2 + 1
536 | orders = [2,] * (K - 1) + [1]
537 |         elif order == 1:
538 |             K = steps  # not 1: the 'logSNR' branch needs `steps` outer intervals to match `orders`
539 |             orders = [1,] * steps
540 | else:
541 | raise ValueError("'order' must be '1' or '2' or '3'.")
542 | if skip_type == 'logSNR':
543 | # To reproduce the results in DPM-Solver paper
544 | timesteps_outer = self.get_time_steps(skip_type, t_T, t_0, K, device)
545 | else:
546 | timesteps_outer = self.get_time_steps(skip_type, t_T, t_0, steps, device)[torch.cumsum(torch.tensor([0,] + orders), 0).to(device)]
547 | return timesteps_outer, orders
548 |
549 | def denoise_to_zero_fn(self, x, s):
550 | """
551 |         Denoise at the final step, which is equivalent to solving the ODE from lambda_s to infty by first-order discretization.
552 | """
553 | return self.data_prediction_fn(x, s)
554 |
555 | def dpm_solver_first_update(self, x, s, t, model_s=None, return_intermediate=False):
556 | """
557 | DPM-Solver-1 (equivalent to DDIM) from time `s` to time `t`.
558 |
559 | Args:
560 | x: A pytorch tensor. The initial value at time `s`.
561 | s: A pytorch tensor. The starting time, with the shape (1,).
562 | t: A pytorch tensor. The ending time, with the shape (1,).
563 | model_s: A pytorch tensor. The model function evaluated at time `s`.
564 |                 If `model_s` is None, we evaluate the model at (`x`, `s`); otherwise we use it directly.
565 | return_intermediate: A `bool`. If true, also return the model value at time `s`.
566 | Returns:
567 | x_t: A pytorch tensor. The approximated solution at time `t`.
568 | """
569 | ns = self.noise_schedule
570 | dims = x.dim()
571 | lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t)
572 | h = lambda_t - lambda_s
573 | log_alpha_s, log_alpha_t = ns.marginal_log_mean_coeff(s), ns.marginal_log_mean_coeff(t)
574 | sigma_s, sigma_t = ns.marginal_std(s), ns.marginal_std(t)
575 | alpha_t = torch.exp(log_alpha_t)
576 |
577 | if self.algorithm_type == "dpmsolver++":
578 | phi_1 = torch.expm1(-h)
579 | if model_s is None:
580 | model_s = self.model_fn(x, s)
581 | x_t = (
582 | sigma_t / sigma_s * x
583 | - alpha_t * phi_1 * model_s
584 | )
585 | if return_intermediate:
586 | return x_t, {'model_s': model_s}
587 | else:
588 | return x_t
589 | else:
590 | phi_1 = torch.expm1(h)
591 | if model_s is None:
592 | model_s = self.model_fn(x, s)
593 | x_t = (
594 | torch.exp(log_alpha_t - log_alpha_s) * x
595 | - (sigma_t * phi_1) * model_s
596 | )
597 | if return_intermediate:
598 | return x_t, {'model_s': model_s}
599 | else:
600 | return x_t
601 |
602 | def singlestep_dpm_solver_second_update(self, x, s, t, r1=0.5, model_s=None, return_intermediate=False, solver_type='dpmsolver'):
603 | """
604 | Singlestep solver DPM-Solver-2 from time `s` to time `t`.
605 |
606 | Args:
607 | x: A pytorch tensor. The initial value at time `s`.
608 | s: A pytorch tensor. The starting time, with the shape (1,).
609 | t: A pytorch tensor. The ending time, with the shape (1,).
610 | r1: A `float`. The hyperparameter of the second-order solver.
611 | model_s: A pytorch tensor. The model function evaluated at time `s`.
612 |                 If `model_s` is None, we evaluate the model at (`x`, `s`); otherwise we use it directly.
613 | return_intermediate: A `bool`. If true, also return the model value at time `s` and `s1` (the intermediate time).
614 | solver_type: either 'dpmsolver' or 'taylor'. The type for the high-order solvers.
615 |                 The type slightly impacts the performance. We recommend using the 'dpmsolver' type.
616 | Returns:
617 | x_t: A pytorch tensor. The approximated solution at time `t`.
618 | """
619 | if solver_type not in ['dpmsolver', 'taylor']:
620 | raise ValueError("'solver_type' must be either 'dpmsolver' or 'taylor', got {}".format(solver_type))
621 | if r1 is None:
622 | r1 = 0.5
623 | ns = self.noise_schedule
624 | lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t)
625 | h = lambda_t - lambda_s
626 | lambda_s1 = lambda_s + r1 * h
627 | s1 = ns.inverse_lambda(lambda_s1)
628 | log_alpha_s, log_alpha_s1, log_alpha_t = ns.marginal_log_mean_coeff(s), ns.marginal_log_mean_coeff(s1), ns.marginal_log_mean_coeff(t)
629 | sigma_s, sigma_s1, sigma_t = ns.marginal_std(s), ns.marginal_std(s1), ns.marginal_std(t)
630 | alpha_s1, alpha_t = torch.exp(log_alpha_s1), torch.exp(log_alpha_t)
631 |
632 | if self.algorithm_type == "dpmsolver++":
633 | phi_11 = torch.expm1(-r1 * h)
634 | phi_1 = torch.expm1(-h)
635 |
636 | if model_s is None:
637 | model_s = self.model_fn(x, s)
638 | x_s1 = (
639 | (sigma_s1 / sigma_s) * x
640 | - (alpha_s1 * phi_11) * model_s
641 | )
642 | model_s1 = self.model_fn(x_s1, s1)
643 | if solver_type == 'dpmsolver':
644 | x_t = (
645 | (sigma_t / sigma_s) * x
646 | - (alpha_t * phi_1) * model_s
647 | - (0.5 / r1) * (alpha_t * phi_1) * (model_s1 - model_s)
648 | )
649 | elif solver_type == 'taylor':
650 | x_t = (
651 | (sigma_t / sigma_s) * x
652 | - (alpha_t * phi_1) * model_s
653 | + (1. / r1) * (alpha_t * (phi_1 / h + 1.)) * (model_s1 - model_s)
654 | )
655 | else:
656 | phi_11 = torch.expm1(r1 * h)
657 | phi_1 = torch.expm1(h)
658 |
659 | if model_s is None:
660 | model_s = self.model_fn(x, s)
661 | x_s1 = (
662 | torch.exp(log_alpha_s1 - log_alpha_s) * x
663 | - (sigma_s1 * phi_11) * model_s
664 | )
665 | model_s1 = self.model_fn(x_s1, s1)
666 | if solver_type == 'dpmsolver':
667 | x_t = (
668 | torch.exp(log_alpha_t - log_alpha_s) * x
669 | - (sigma_t * phi_1) * model_s
670 | - (0.5 / r1) * (sigma_t * phi_1) * (model_s1 - model_s)
671 | )
672 | elif solver_type == 'taylor':
673 | x_t = (
674 | torch.exp(log_alpha_t - log_alpha_s) * x
675 | - (sigma_t * phi_1) * model_s
676 | - (1. / r1) * (sigma_t * (phi_1 / h - 1.)) * (model_s1 - model_s)
677 | )
678 | if return_intermediate:
679 | return x_t, {'model_s': model_s, 'model_s1': model_s1}
680 | else:
681 | return x_t
682 |
683 | def singlestep_dpm_solver_third_update(self, x, s, t, r1=1./3., r2=2./3., model_s=None, model_s1=None, return_intermediate=False, solver_type='dpmsolver'):
684 | """
685 | Singlestep solver DPM-Solver-3 from time `s` to time `t`.
686 |
687 | Args:
688 | x: A pytorch tensor. The initial value at time `s`.
689 | s: A pytorch tensor. The starting time, with the shape (1,).
690 | t: A pytorch tensor. The ending time, with the shape (1,).
691 | r1: A `float`. The hyperparameter of the third-order solver.
692 | r2: A `float`. The hyperparameter of the third-order solver.
693 | model_s: A pytorch tensor. The model function evaluated at time `s`.
694 |                 If `model_s` is None, we evaluate the model at (`x`, `s`); otherwise we use it directly.
695 |             model_s1: A pytorch tensor. The model function evaluated at time `s1` (the intermediate time given by `r1`).
696 |                 If `model_s1` is None, we evaluate the model at `s1`; otherwise we use it directly.
697 | return_intermediate: A `bool`. If true, also return the model value at time `s`, `s1` and `s2` (the intermediate times).
698 | solver_type: either 'dpmsolver' or 'taylor'. The type for the high-order solvers.
699 |                 The type slightly impacts the performance. We recommend using the 'dpmsolver' type.
700 | Returns:
701 | x_t: A pytorch tensor. The approximated solution at time `t`.
702 | """
703 | if solver_type not in ['dpmsolver', 'taylor']:
704 | raise ValueError("'solver_type' must be either 'dpmsolver' or 'taylor', got {}".format(solver_type))
705 | if r1 is None:
706 | r1 = 1. / 3.
707 | if r2 is None:
708 | r2 = 2. / 3.
709 | ns = self.noise_schedule
710 | lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t)
711 | h = lambda_t - lambda_s
712 | lambda_s1 = lambda_s + r1 * h
713 | lambda_s2 = lambda_s + r2 * h
714 | s1 = ns.inverse_lambda(lambda_s1)
715 | s2 = ns.inverse_lambda(lambda_s2)
716 | log_alpha_s, log_alpha_s1, log_alpha_s2, log_alpha_t = ns.marginal_log_mean_coeff(s), ns.marginal_log_mean_coeff(s1), ns.marginal_log_mean_coeff(s2), ns.marginal_log_mean_coeff(t)
717 | sigma_s, sigma_s1, sigma_s2, sigma_t = ns.marginal_std(s), ns.marginal_std(s1), ns.marginal_std(s2), ns.marginal_std(t)
718 | alpha_s1, alpha_s2, alpha_t = torch.exp(log_alpha_s1), torch.exp(log_alpha_s2), torch.exp(log_alpha_t)
719 |
720 | if self.algorithm_type == "dpmsolver++":
721 | phi_11 = torch.expm1(-r1 * h)
722 | phi_12 = torch.expm1(-r2 * h)
723 | phi_1 = torch.expm1(-h)
724 | phi_22 = torch.expm1(-r2 * h) / (r2 * h) + 1.
725 | phi_2 = phi_1 / h + 1.
726 | phi_3 = phi_2 / h - 0.5
727 |
728 | if model_s is None:
729 | model_s = self.model_fn(x, s)
730 | if model_s1 is None:
731 | x_s1 = (
732 | (sigma_s1 / sigma_s) * x
733 | - (alpha_s1 * phi_11) * model_s
734 | )
735 | model_s1 = self.model_fn(x_s1, s1)
736 | x_s2 = (
737 | (sigma_s2 / sigma_s) * x
738 | - (alpha_s2 * phi_12) * model_s
739 | + r2 / r1 * (alpha_s2 * phi_22) * (model_s1 - model_s)
740 | )
741 | model_s2 = self.model_fn(x_s2, s2)
742 | if solver_type == 'dpmsolver':
743 | x_t = (
744 | (sigma_t / sigma_s) * x
745 | - (alpha_t * phi_1) * model_s
746 | + (1. / r2) * (alpha_t * phi_2) * (model_s2 - model_s)
747 | )
748 | elif solver_type == 'taylor':
749 | D1_0 = (1. / r1) * (model_s1 - model_s)
750 | D1_1 = (1. / r2) * (model_s2 - model_s)
751 | D1 = (r2 * D1_0 - r1 * D1_1) / (r2 - r1)
752 | D2 = 2. * (D1_1 - D1_0) / (r2 - r1)
753 | x_t = (
754 | (sigma_t / sigma_s) * x
755 | - (alpha_t * phi_1) * model_s
756 | + (alpha_t * phi_2) * D1
757 | - (alpha_t * phi_3) * D2
758 | )
759 | else:
760 | phi_11 = torch.expm1(r1 * h)
761 | phi_12 = torch.expm1(r2 * h)
762 | phi_1 = torch.expm1(h)
763 | phi_22 = torch.expm1(r2 * h) / (r2 * h) - 1.
764 | phi_2 = phi_1 / h - 1.
765 | phi_3 = phi_2 / h - 0.5
766 |
767 | if model_s is None:
768 | model_s = self.model_fn(x, s)
769 | if model_s1 is None:
770 | x_s1 = (
771 | (torch.exp(log_alpha_s1 - log_alpha_s)) * x
772 | - (sigma_s1 * phi_11) * model_s
773 | )
774 | model_s1 = self.model_fn(x_s1, s1)
775 | x_s2 = (
776 | (torch.exp(log_alpha_s2 - log_alpha_s)) * x
777 | - (sigma_s2 * phi_12) * model_s
778 | - r2 / r1 * (sigma_s2 * phi_22) * (model_s1 - model_s)
779 | )
780 | model_s2 = self.model_fn(x_s2, s2)
781 | if solver_type == 'dpmsolver':
782 | x_t = (
783 | (torch.exp(log_alpha_t - log_alpha_s)) * x
784 | - (sigma_t * phi_1) * model_s
785 | - (1. / r2) * (sigma_t * phi_2) * (model_s2 - model_s)
786 | )
787 | elif solver_type == 'taylor':
788 | D1_0 = (1. / r1) * (model_s1 - model_s)
789 | D1_1 = (1. / r2) * (model_s2 - model_s)
790 | D1 = (r2 * D1_0 - r1 * D1_1) / (r2 - r1)
791 | D2 = 2. * (D1_1 - D1_0) / (r2 - r1)
792 | x_t = (
793 | (torch.exp(log_alpha_t - log_alpha_s)) * x
794 | - (sigma_t * phi_1) * model_s
795 | - (sigma_t * phi_2) * D1
796 | - (sigma_t * phi_3) * D2
797 | )
798 |
799 | if return_intermediate:
800 | return x_t, {'model_s': model_s, 'model_s1': model_s1, 'model_s2': model_s2}
801 | else:
802 | return x_t
803 |
804 | def multistep_dpm_solver_second_update(self, x, model_prev_list, t_prev_list, t, solver_type="dpmsolver"):
805 | """
806 | Multistep solver DPM-Solver-2 from time `t_prev_list[-1]` to time `t`.
807 |
808 | Args:
809 |             x: A pytorch tensor. The initial value at time `t_prev_list[-1]`.
810 |             model_prev_list: A list of pytorch tensors. The previously computed model values.
811 |             t_prev_list: A list of pytorch tensors. The previous times, each with the shape (1,).
812 |             t: A pytorch tensor. The ending time, with the shape (1,).
813 | solver_type: either 'dpmsolver' or 'taylor'. The type for the high-order solvers.
814 |                 The type slightly impacts the performance. We recommend using the 'dpmsolver' type.
815 | Returns:
816 | x_t: A pytorch tensor. The approximated solution at time `t`.
817 | """
818 | if solver_type not in ['dpmsolver', 'taylor']:
819 | raise ValueError("'solver_type' must be either 'dpmsolver' or 'taylor', got {}".format(solver_type))
820 | ns = self.noise_schedule
821 | model_prev_1, model_prev_0 = model_prev_list[-2], model_prev_list[-1]
822 | t_prev_1, t_prev_0 = t_prev_list[-2], t_prev_list[-1]
823 | lambda_prev_1, lambda_prev_0, lambda_t = ns.marginal_lambda(t_prev_1), ns.marginal_lambda(t_prev_0), ns.marginal_lambda(t)
824 | log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(t_prev_0), ns.marginal_log_mean_coeff(t)
825 | sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t)
826 | alpha_t = torch.exp(log_alpha_t)
827 |
828 | h_0 = lambda_prev_0 - lambda_prev_1
829 | h = lambda_t - lambda_prev_0
830 | r0 = h_0 / h
831 | D1_0 = (1. / r0) * (model_prev_0 - model_prev_1)
832 | if self.algorithm_type == "dpmsolver++":
833 | phi_1 = torch.expm1(-h)
834 | if solver_type == 'dpmsolver':
835 | x_t = (
836 | (sigma_t / sigma_prev_0) * x
837 | - (alpha_t * phi_1) * model_prev_0
838 | - 0.5 * (alpha_t * phi_1) * D1_0
839 | )
840 | elif solver_type == 'taylor':
841 | x_t = (
842 | (sigma_t / sigma_prev_0) * x
843 | - (alpha_t * phi_1) * model_prev_0
844 | + (alpha_t * (phi_1 / h + 1.)) * D1_0
845 | )
846 | else:
847 | phi_1 = torch.expm1(h)
848 | if solver_type == 'dpmsolver':
849 | x_t = (
850 | (torch.exp(log_alpha_t - log_alpha_prev_0)) * x
851 | - (sigma_t * phi_1) * model_prev_0
852 | - 0.5 * (sigma_t * phi_1) * D1_0
853 | )
854 | elif solver_type == 'taylor':
855 | x_t = (
856 | (torch.exp(log_alpha_t - log_alpha_prev_0)) * x
857 | - (sigma_t * phi_1) * model_prev_0
858 | - (sigma_t * (phi_1 / h - 1.)) * D1_0
859 | )
860 | return x_t
861 |
862 | def multistep_dpm_solver_third_update(self, x, model_prev_list, t_prev_list, t, solver_type='dpmsolver'):
863 | """
864 | Multistep solver DPM-Solver-3 from time `t_prev_list[-1]` to time `t`.
865 |
866 | Args:
867 |             x: A pytorch tensor. The initial value at time `t_prev_list[-1]`.
868 |             model_prev_list: A list of pytorch tensors. The previously computed model values.
869 |             t_prev_list: A list of pytorch tensors. The previous times, each with the shape (1,).
870 |             t: A pytorch tensor. The ending time, with the shape (1,).
871 | solver_type: either 'dpmsolver' or 'taylor'. The type for the high-order solvers.
872 |                 The type slightly impacts the performance. We recommend using the 'dpmsolver' type.
873 | Returns:
874 | x_t: A pytorch tensor. The approximated solution at time `t`.
875 | """
876 | ns = self.noise_schedule
877 | model_prev_2, model_prev_1, model_prev_0 = model_prev_list
878 | t_prev_2, t_prev_1, t_prev_0 = t_prev_list
879 | lambda_prev_2, lambda_prev_1, lambda_prev_0, lambda_t = ns.marginal_lambda(t_prev_2), ns.marginal_lambda(t_prev_1), ns.marginal_lambda(t_prev_0), ns.marginal_lambda(t)
880 | log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(t_prev_0), ns.marginal_log_mean_coeff(t)
881 | sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t)
882 | alpha_t = torch.exp(log_alpha_t)
883 |
884 | h_1 = lambda_prev_1 - lambda_prev_2
885 | h_0 = lambda_prev_0 - lambda_prev_1
886 | h = lambda_t - lambda_prev_0
887 | r0, r1 = h_0 / h, h_1 / h
888 | D1_0 = (1. / r0) * (model_prev_0 - model_prev_1)
889 | D1_1 = (1. / r1) * (model_prev_1 - model_prev_2)
890 | D1 = D1_0 + (r0 / (r0 + r1)) * (D1_0 - D1_1)
891 | D2 = (1. / (r0 + r1)) * (D1_0 - D1_1)
892 | if self.algorithm_type == "dpmsolver++":
893 | phi_1 = torch.expm1(-h)
894 | phi_2 = phi_1 / h + 1.
895 | phi_3 = phi_2 / h - 0.5
896 | x_t = (
897 | (sigma_t / sigma_prev_0) * x
898 | - (alpha_t * phi_1) * model_prev_0
899 | + (alpha_t * phi_2) * D1
900 | - (alpha_t * phi_3) * D2
901 | )
902 | else:
903 | phi_1 = torch.expm1(h)
904 | phi_2 = phi_1 / h - 1.
905 | phi_3 = phi_2 / h - 0.5
906 | x_t = (
907 | (torch.exp(log_alpha_t - log_alpha_prev_0)) * x
908 | - (sigma_t * phi_1) * model_prev_0
909 | - (sigma_t * phi_2) * D1
910 | - (sigma_t * phi_3) * D2
911 | )
912 | return x_t
913 |
914 | def singlestep_dpm_solver_update(self, x, s, t, order, return_intermediate=False, solver_type='dpmsolver', r1=None, r2=None):
915 | """
916 | Singlestep DPM-Solver with the order `order` from time `s` to time `t`.
917 |
918 | Args:
919 | x: A pytorch tensor. The initial value at time `s`.
920 | s: A pytorch tensor. The starting time, with the shape (1,).
921 | t: A pytorch tensor. The ending time, with the shape (1,).
922 |             order: An `int`. The order of DPM-Solver. We only support order == 1, 2, or 3.
923 | return_intermediate: A `bool`. If true, also return the model value at time `s`, `s1` and `s2` (the intermediate times).
924 | solver_type: either 'dpmsolver' or 'taylor'. The type for the high-order solvers.
925 |                 The type slightly impacts the performance. We recommend using the 'dpmsolver' type.
926 | r1: A `float`. The hyperparameter of the second-order or third-order solver.
927 | r2: A `float`. The hyperparameter of the third-order solver.
928 | Returns:
929 | x_t: A pytorch tensor. The approximated solution at time `t`.
930 | """
931 | if order == 1:
932 | return self.dpm_solver_first_update(x, s, t, return_intermediate=return_intermediate)
933 | elif order == 2:
934 | return self.singlestep_dpm_solver_second_update(x, s, t, return_intermediate=return_intermediate, solver_type=solver_type, r1=r1)
935 | elif order == 3:
936 | return self.singlestep_dpm_solver_third_update(x, s, t, return_intermediate=return_intermediate, solver_type=solver_type, r1=r1, r2=r2)
937 | else:
938 | raise ValueError("Solver order must be 1 or 2 or 3, got {}".format(order))
939 |
940 | def multistep_dpm_solver_update(self, x, model_prev_list, t_prev_list, t, order, solver_type='dpmsolver'):
941 | """
942 | Multistep DPM-Solver with the order `order` from time `t_prev_list[-1]` to time `t`.
943 |
944 | Args:
945 |             x: A pytorch tensor. The initial value at time `t_prev_list[-1]`.
946 |             model_prev_list: A list of pytorch tensors. The previously computed model values.
947 |             t_prev_list: A list of pytorch tensors. The previous times, each with the shape (1,).
948 |             t: A pytorch tensor. The ending time, with the shape (1,).
949 |             order: An `int`. The order of DPM-Solver. We only support order == 1, 2, or 3.
950 | solver_type: either 'dpmsolver' or 'taylor'. The type for the high-order solvers.
951 |                 The type slightly impacts the performance. We recommend using the 'dpmsolver' type.
952 | Returns:
953 | x_t: A pytorch tensor. The approximated solution at time `t`.
954 | """
955 | if order == 1:
956 | return self.dpm_solver_first_update(x, t_prev_list[-1], t, model_s=model_prev_list[-1])
957 | elif order == 2:
958 | return self.multistep_dpm_solver_second_update(x, model_prev_list, t_prev_list, t, solver_type=solver_type)
959 | elif order == 3:
960 | return self.multistep_dpm_solver_third_update(x, model_prev_list, t_prev_list, t, solver_type=solver_type)
961 | else:
962 | raise ValueError("Solver order must be 1 or 2 or 3, got {}".format(order))
963 |
964 | def dpm_solver_adaptive(self, x, order, t_T, t_0, h_init=0.05, atol=0.0078, rtol=0.05, theta=0.9, t_err=1e-5, solver_type='dpmsolver'):
965 | """
966 | The adaptive step size solver based on singlestep DPM-Solver.
967 |
968 | Args:
969 | x: A pytorch tensor. The initial value at time `t_T`.
970 |             order: An `int`. The (higher) order of the solver. We only support order == 2 or 3.
971 | t_T: A `float`. The starting time of the sampling (default is T).
972 | t_0: A `float`. The ending time of the sampling (default is epsilon).
973 | h_init: A `float`. The initial step size (for logSNR).
974 |             atol: A `float`. The absolute tolerance of the solver. For image data, the default setting is 0.0078, following [1].
975 |             rtol: A `float`. The relative tolerance of the solver. The default setting is 0.05.
976 |             theta: A `float`. The safety hyperparameter for adapting the step size. The default setting is 0.9, following [1].
977 | t_err: A `float`. The tolerance for the time. We solve the diffusion ODE until the absolute error between the
978 | current time and `t_0` is less than `t_err`. The default setting is 1e-5.
979 | solver_type: either 'dpmsolver' or 'taylor'. The type for the high-order solvers.
980 |                 The type slightly impacts the performance. We recommend using the 'dpmsolver' type.
981 | Returns:
982 | x_0: A pytorch tensor. The approximated solution at time `t_0`.
983 |
984 | [1] A. Jolicoeur-Martineau, K. Li, R. Piché-Taillefer, T. Kachman, and I. Mitliagkas, "Gotta go fast when generating data with score-based models," arXiv preprint arXiv:2105.14080, 2021.
985 | """
986 | ns = self.noise_schedule
987 | s = t_T * torch.ones((1,)).to(x)
988 | lambda_s = ns.marginal_lambda(s)
989 | lambda_0 = ns.marginal_lambda(t_0 * torch.ones_like(s).to(x))
990 | h = h_init * torch.ones_like(s).to(x)
991 | x_prev = x
992 | nfe = 0
993 | if order == 2:
994 | r1 = 0.5
995 | lower_update = lambda x, s, t: self.dpm_solver_first_update(x, s, t, return_intermediate=True)
996 | higher_update = lambda x, s, t, **kwargs: self.singlestep_dpm_solver_second_update(x, s, t, r1=r1, solver_type=solver_type, **kwargs)
997 | elif order == 3:
998 | r1, r2 = 1. / 3., 2. / 3.
999 | lower_update = lambda x, s, t: self.singlestep_dpm_solver_second_update(x, s, t, r1=r1, return_intermediate=True, solver_type=solver_type)
1000 | higher_update = lambda x, s, t, **kwargs: self.singlestep_dpm_solver_third_update(x, s, t, r1=r1, r2=r2, solver_type=solver_type, **kwargs)
1001 | else:
1002 | raise ValueError("For adaptive step size solver, order must be 2 or 3, got {}".format(order))
1003 | while torch.abs((s - t_0)).mean() > t_err:
1004 | t = ns.inverse_lambda(lambda_s + h)
1005 | x_lower, lower_noise_kwargs = lower_update(x, s, t)
1006 | x_higher = higher_update(x, s, t, **lower_noise_kwargs)
1007 | delta = torch.max(torch.ones_like(x).to(x) * atol, rtol * torch.max(torch.abs(x_lower), torch.abs(x_prev)))
1008 | norm_fn = lambda v: torch.sqrt(torch.square(v.reshape((v.shape[0], -1))).mean(dim=-1, keepdim=True))
1009 | E = norm_fn((x_higher - x_lower) / delta).max()
1010 | if torch.all(E <= 1.):
1011 | x = x_higher
1012 | s = t
1013 | x_prev = x_lower
1014 | lambda_s = ns.marginal_lambda(s)
1015 | h = torch.min(theta * h * torch.float_power(E, -1. / order).float(), lambda_0 - lambda_s)
1016 | nfe += order
1017 | print('adaptive solver nfe', nfe)
1018 | return x
1019 |
1020 | def add_noise(self, x, t, noise=None):
1021 | """
1022 | Compute the noised input xt = alpha_t * x + sigma_t * noise.
1023 |
1024 | Args:
1025 | x: A `torch.Tensor` with shape `(batch_size, *shape)`.
1026 | t: A `torch.Tensor` with shape `(t_size,)`.
1027 | Returns:
1028 | xt with shape `(t_size, batch_size, *shape)`.
1029 | """
1030 | alpha_t, sigma_t = self.noise_schedule.marginal_alpha(t), self.noise_schedule.marginal_std(t)
1031 | if noise is None:
1032 | noise = torch.randn((t.shape[0], *x.shape), device=x.device)
1033 | x = x.reshape((-1, *x.shape))
1034 | xt = expand_dims(alpha_t, x.dim()) * x + expand_dims(sigma_t, x.dim()) * noise
1035 | if t.shape[0] == 1:
1036 | return xt.squeeze(0)
1037 | else:
1038 | return xt
1039 |
1040 | def inverse(self, x, steps=20, t_start=None, t_end=None, order=2, skip_type='time_uniform',
1041 | method='multistep', lower_order_final=True, denoise_to_zero=False, solver_type='dpmsolver',
1042 | atol=0.0078, rtol=0.05, return_intermediate=False,
1043 | ):
1044 | """
1045 |         Invert the sample `x` from time `t_start` to `t_end` by DPM-Solver.
1046 | For discrete-time DPMs, we use `t_start=1/N`, where `N` is the total time steps during training.
1047 | """
1048 | t_0 = 1. / self.noise_schedule.total_N if t_start is None else t_start
1049 | t_T = self.noise_schedule.T if t_end is None else t_end
1050 | assert t_0 > 0 and t_T > 0, "Time range needs to be greater than 0. For discrete-time DPMs, it needs to be in [1 / N, 1], where N is the length of betas array"
1051 | return self.sample(x, steps=steps, t_start=t_0, t_end=t_T, order=order, skip_type=skip_type,
1052 | method=method, lower_order_final=lower_order_final, denoise_to_zero=denoise_to_zero, solver_type=solver_type,
1053 | atol=atol, rtol=rtol, return_intermediate=return_intermediate)
1054 |
1055 | def sample(self, x, steps=20, t_start=None, t_end=None, order=2, skip_type='time_uniform',
1056 | method='multistep', lower_order_final=True, denoise_to_zero=False, solver_type='dpmsolver',
1057 | atol=0.0078, rtol=0.05, return_intermediate=False,
1058 | ):
1059 | """
1060 | Compute the sample at time `t_end` by DPM-Solver, given the initial `x` at time `t_start`.
1061 |
1062 | =====================================================
1063 |
1064 | We support the following algorithms for both noise prediction model and data prediction model:
1065 | - 'singlestep':
1066 | Singlestep DPM-Solver (i.e. "DPM-Solver-fast" in the paper), which combines different orders of singlestep DPM-Solver.
1067 | We combine all the singlestep solvers with order <= `order` to use up all the function evaluations (steps).
1068 | The total number of function evaluations (NFE) == `steps`.
1069 | Given a fixed NFE == `steps`, the sampling procedure is:
1070 | - If `order` == 1:
1071 | - Denote K = steps. We use K steps of DPM-Solver-1 (i.e. DDIM).
1072 | - If `order` == 2:
1073 | - Denote K = (steps // 2) + (steps % 2). We take K intermediate time steps for sampling.
1074 | - If steps % 2 == 0, we use K steps of singlestep DPM-Solver-2.
1075 | - If steps % 2 == 1, we use (K - 1) steps of singlestep DPM-Solver-2 and 1 step of DPM-Solver-1.
1076 | - If `order` == 3:
1077 | - Denote K = (steps // 3 + 1). We take K intermediate time steps for sampling.
1078 | - If steps % 3 == 0, we use (K - 2) steps of singlestep DPM-Solver-3, and 1 step of singlestep DPM-Solver-2 and 1 step of DPM-Solver-1.
1079 | - If steps % 3 == 1, we use (K - 1) steps of singlestep DPM-Solver-3 and 1 step of DPM-Solver-1.
1080 | - If steps % 3 == 2, we use (K - 1) steps of singlestep DPM-Solver-3 and 1 step of singlestep DPM-Solver-2.
1081 | - 'multistep':
1082 | Multistep DPM-Solver with the order of `order`. The total number of function evaluations (NFE) == `steps`.
1083 | We initialize the first `order` values by lower order multistep solvers.
1084 | Given a fixed NFE == `steps`, the sampling procedure is:
1085 | Denote K = steps.
1086 | - If `order` == 1:
1087 | - We use K steps of DPM-Solver-1 (i.e. DDIM).
1088 | - If `order` == 2:
1089 |                     - We first use 1 step of DPM-Solver-1, then (K - 1) steps of multistep DPM-Solver-2.
1090 |                 - If `order` == 3:
1091 |                     - We first use 1 step of DPM-Solver-1, then 1 step of multistep DPM-Solver-2, then (K - 2) steps of multistep DPM-Solver-3.
1092 | - 'singlestep_fixed':
1093 | Fixed order singlestep DPM-Solver (i.e. DPM-Solver-1 or singlestep DPM-Solver-2 or singlestep DPM-Solver-3).
1094 | We use singlestep DPM-Solver-`order` for `order`=1 or 2 or 3, with total [`steps` // `order`] * `order` NFE.
1095 | - 'adaptive':
1096 | Adaptive step size DPM-Solver (i.e. "DPM-Solver-12" and "DPM-Solver-23" in the paper).
1097 | We ignore `steps` and use adaptive step size DPM-Solver with a higher order of `order`.
1098 |                 You can adjust the absolute tolerance `atol` and the relative tolerance `rtol` to balance the computation costs
1099 | (NFE) and the sample quality.
1100 | - If `order` == 2, we use DPM-Solver-12 which combines DPM-Solver-1 and singlestep DPM-Solver-2.
1101 | - If `order` == 3, we use DPM-Solver-23 which combines singlestep DPM-Solver-2 and singlestep DPM-Solver-3.
1102 |
1103 | =====================================================
1104 |
1105 |         Some advice on choosing the algorithm:
1106 | - For **unconditional sampling** or **guided sampling with small guidance scale** by DPMs:
1107 | Use singlestep DPM-Solver or DPM-Solver++ ("DPM-Solver-fast" in the paper) with `order = 3`.
1108 | e.g., DPM-Solver:
1109 | >>> dpm_solver = DPM_Solver(model_fn, noise_schedule, algorithm_type="dpmsolver")
1110 | >>> x_sample = dpm_solver.sample(x, steps=steps, t_start=t_start, t_end=t_end, order=3,
1111 | skip_type='time_uniform', method='singlestep')
1112 | e.g., DPM-Solver++:
1113 | >>> dpm_solver = DPM_Solver(model_fn, noise_schedule, algorithm_type="dpmsolver++")
1114 | >>> x_sample = dpm_solver.sample(x, steps=steps, t_start=t_start, t_end=t_end, order=3,
1115 | skip_type='time_uniform', method='singlestep')
1116 | - For **guided sampling with large guidance scale** by DPMs:
1117 | Use multistep DPM-Solver with `algorithm_type="dpmsolver++"` and `order = 2`.
1118 | e.g.
1119 | >>> dpm_solver = DPM_Solver(model_fn, noise_schedule, algorithm_type="dpmsolver++")
1120 | >>> x_sample = dpm_solver.sample(x, steps=steps, t_start=t_start, t_end=t_end, order=2,
1121 | skip_type='time_uniform', method='multistep')
1122 |
1123 | We support three types of `skip_type`:
1124 |             - 'logSNR': uniform logSNR for the time steps. **Recommended for low-resolution images**.
1125 |             - 'time_uniform': uniform time for the time steps. **Recommended for high-resolution images**.
1126 | - 'time_quadratic': quadratic time for the time steps.
1127 |
1128 | =====================================================
1129 | Args:
1130 |             x: A pytorch tensor. The initial value at time `t_start`,
1131 |                 e.g. if `t_start` == T, then `x` is a sample from the standard normal distribution.
1132 |             steps: An `int`. The total number of function evaluations (NFE).
1133 |             t_start: A `float`. The starting time of the sampling.
1134 |                 If `t_start` is None, we use self.noise_schedule.T (default is 1.0).
1135 | t_end: A `float`. The ending time of the sampling.
1136 | If `t_end` is None, we use 1. / self.noise_schedule.total_N.
1137 | e.g. if total_N == 1000, we have `t_end` == 1e-3.
1138 | For discrete-time DPMs:
1139 | - We recommend `t_end` == 1. / self.noise_schedule.total_N.
1140 | For continuous-time DPMs:
1141 | - We recommend `t_end` == 1e-3 when `steps` <= 15; and `t_end` == 1e-4 when `steps` > 15.
1142 |             order: An `int`. The order of DPM-Solver.
1143 | skip_type: A `str`. The type for the spacing of the time steps. 'time_uniform' or 'logSNR' or 'time_quadratic'.
1144 | method: A `str`. The method for sampling. 'singlestep' or 'multistep' or 'singlestep_fixed' or 'adaptive'.
1145 | denoise_to_zero: A `bool`. Whether to denoise to time 0 at the final step.
1146 | Default is `False`. If `denoise_to_zero` is `True`, the total NFE is (`steps` + 1).
1147 |
1148 |                 This trick is first proposed by DDPM (https://arxiv.org/abs/2006.11239) and
1149 |                 score_sde (https://arxiv.org/abs/2011.13456). It can improve the FID when
1150 |                 sampling from diffusion models by diffusion SDEs for low-resolution images
1151 |                 (such as CIFAR-10). However, we observed that it does not matter for
1152 |                 high-resolution images. As it needs an additional NFE, we do not recommend
1153 |                 it for high-resolution images.
1154 | lower_order_final: A `bool`. Whether to use lower order solvers at the final steps.
1155 | Only valid for `method=multistep` and `steps < 15`. We empirically find that
1156 |                 this trick is key to stabilizing the sampling by DPM-Solver with very few steps
1157 |                 (especially for steps <= 10), so we recommend setting it to `True`.
1158 |             solver_type: A `str`. The Taylor expansion type for the solver: `dpmsolver` or `taylor`. We recommend `dpmsolver`.
1159 | atol: A `float`. The absolute tolerance of the adaptive step size solver. Valid when `method` == 'adaptive'.
1160 | rtol: A `float`. The relative tolerance of the adaptive step size solver. Valid when `method` == 'adaptive'.
1161 |             return_intermediate: A `bool`. Whether to save xt at each step.
1162 |                 When set to `True`, the method returns a tuple (x0, intermediates); when set to `False`, it returns only x0.
1163 | Returns:
1164 | x_end: A pytorch tensor. The approximated solution at time `t_end`.
1165 |
1166 | """
1167 | t_0 = 1. / self.noise_schedule.total_N if t_end is None else t_end
1168 | t_T = self.noise_schedule.T if t_start is None else t_start
1169 | assert t_0 > 0 and t_T > 0, "Time range needs to be greater than 0. For discrete-time DPMs, it needs to be in [1 / N, 1], where N is the length of betas array"
1170 | if return_intermediate:
1171 | assert method in ['multistep', 'singlestep', 'singlestep_fixed'], "Cannot use adaptive solver when saving intermediate values"
1172 | if self.correcting_xt_fn is not None:
1173 | assert method in ['multistep', 'singlestep', 'singlestep_fixed'], "Cannot use adaptive solver when correcting_xt_fn is not None"
1174 | device = x.device
1175 | intermediates = []
1176 | with torch.no_grad():
1177 | if method == 'adaptive':
1178 | x = self.dpm_solver_adaptive(x, order=order, t_T=t_T, t_0=t_0, atol=atol, rtol=rtol, solver_type=solver_type)
1179 | elif method == 'multistep':
1180 | assert steps >= order
1181 | timesteps = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=steps, device=device)
1182 | assert timesteps.shape[0] - 1 == steps
1183 |                 # Initialize the model value and time at the first step t_T.
1184 | step = 0
1185 | t = timesteps[step]
1186 | t_prev_list = [t]
1187 | model_prev_list = [self.model_fn(x, t)]
1188 | if self.correcting_xt_fn is not None:
1189 | x = self.correcting_xt_fn(x, t, step)
1190 | if return_intermediate:
1191 | intermediates.append(x)
1192 | # Init the first `order` values by lower order multistep DPM-Solver.
1193 | for step in range(1, order):
1194 | t = timesteps[step]
1195 | x = self.multistep_dpm_solver_update(x, model_prev_list, t_prev_list, t, step, solver_type=solver_type)
1196 | if self.correcting_xt_fn is not None:
1197 | x = self.correcting_xt_fn(x, t, step)
1198 | if return_intermediate:
1199 | intermediates.append(x)
1200 | t_prev_list.append(t)
1201 | model_prev_list.append(self.model_fn(x, t))
1202 | # Compute the remaining values by `order`-th order multistep DPM-Solver.
1203 | for step in range(order, steps + 1):
1204 | t = timesteps[step]
1205 | # We only use lower order for steps < 10
1206 | if lower_order_final and steps < 10:
1207 | step_order = min(order, steps + 1 - step)
1208 | else:
1209 | step_order = order
1210 | x = self.multistep_dpm_solver_update(x, model_prev_list, t_prev_list, t, step_order, solver_type=solver_type)
1211 | if self.correcting_xt_fn is not None:
1212 | x = self.correcting_xt_fn(x, t, step)
1213 | if return_intermediate:
1214 | intermediates.append(x)
1215 | for i in range(order - 1):
1216 | t_prev_list[i] = t_prev_list[i + 1]
1217 | model_prev_list[i] = model_prev_list[i + 1]
1218 | t_prev_list[-1] = t
1219 | # We do not need to evaluate the final model value.
1220 | if step < steps:
1221 | model_prev_list[-1] = self.model_fn(x, t)
1222 | elif method in ['singlestep', 'singlestep_fixed']:
1223 | if method == 'singlestep':
1224 | timesteps_outer, orders = self.get_orders_and_timesteps_for_singlestep_solver(steps=steps, order=order, skip_type=skip_type, t_T=t_T, t_0=t_0, device=device)
1225 | elif method == 'singlestep_fixed':
1226 | K = steps // order
1227 | orders = [order,] * K
1228 | timesteps_outer = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=K, device=device)
1229 | for step, order in enumerate(orders):
1230 | s, t = timesteps_outer[step], timesteps_outer[step + 1]
1231 | timesteps_inner = self.get_time_steps(skip_type=skip_type, t_T=s.item(), t_0=t.item(), N=order, device=device)
1232 | lambda_inner = self.noise_schedule.marginal_lambda(timesteps_inner)
1233 | h = lambda_inner[-1] - lambda_inner[0]
1234 | r1 = None if order <= 1 else (lambda_inner[1] - lambda_inner[0]) / h
1235 | r2 = None if order <= 2 else (lambda_inner[2] - lambda_inner[0]) / h
1236 | x = self.singlestep_dpm_solver_update(x, s, t, order, solver_type=solver_type, r1=r1, r2=r2)
1237 | if self.correcting_xt_fn is not None:
1238 | x = self.correcting_xt_fn(x, t, step)
1239 | if return_intermediate:
1240 | intermediates.append(x)
1241 | else:
1242 | raise ValueError("Got wrong method {}".format(method))
1243 | if denoise_to_zero:
1244 | t = torch.ones((1,)).to(device) * t_0
1245 | x = self.denoise_to_zero_fn(x, t)
1246 | if self.correcting_xt_fn is not None:
1247 | x = self.correcting_xt_fn(x, t, step + 1)
1248 | if return_intermediate:
1249 | intermediates.append(x)
1250 | if return_intermediate:
1251 | return x, intermediates
1252 | else:
1253 | return x
1254 |
1255 |
1256 |
1257 | #############################################################
1258 | # other utility functions
1259 | #############################################################
1260 |
1261 | def interpolate_fn(x, xp, yp):
1262 | """
1263 | A piecewise linear function y = f(x), using xp and yp as keypoints.
1264 | We implement f(x) in a differentiable way (i.e. applicable for autograd).
1265 |     The function f(x) is well-defined for all x. (For x beyond the bounds of xp, we use the outermost points of xp to define the linear function.)
1266 |
1267 | Args:
1268 | x: PyTorch tensor with shape [N, C], where N is the batch size, C is the number of channels (we use C = 1 for DPM-Solver).
1269 | xp: PyTorch tensor with shape [C, K], where K is the number of keypoints.
1270 | yp: PyTorch tensor with shape [C, K].
1271 | Returns:
1272 | The function values f(x), with shape [N, C].
1273 | """
1274 | N, K = x.shape[0], xp.shape[1]
1275 | all_x = torch.cat([x.unsqueeze(2), xp.unsqueeze(0).repeat((N, 1, 1))], dim=2)
1276 | sorted_all_x, x_indices = torch.sort(all_x, dim=2)
1277 | x_idx = torch.argmin(x_indices, dim=2)
1278 | cand_start_idx = x_idx - 1
1279 | start_idx = torch.where(
1280 | torch.eq(x_idx, 0),
1281 | torch.tensor(1, device=x.device),
1282 | torch.where(
1283 | torch.eq(x_idx, K), torch.tensor(K - 2, device=x.device), cand_start_idx,
1284 | ),
1285 | )
1286 | end_idx = torch.where(torch.eq(start_idx, cand_start_idx), start_idx + 2, start_idx + 1)
1287 | start_x = torch.gather(sorted_all_x, dim=2, index=start_idx.unsqueeze(2)).squeeze(2)
1288 | end_x = torch.gather(sorted_all_x, dim=2, index=end_idx.unsqueeze(2)).squeeze(2)
1289 | start_idx2 = torch.where(
1290 | torch.eq(x_idx, 0),
1291 | torch.tensor(0, device=x.device),
1292 | torch.where(
1293 | torch.eq(x_idx, K), torch.tensor(K - 2, device=x.device), cand_start_idx,
1294 | ),
1295 | )
1296 | y_positions_expanded = yp.unsqueeze(0).expand(N, -1, -1)
1297 | start_y = torch.gather(y_positions_expanded, dim=2, index=start_idx2.unsqueeze(2)).squeeze(2)
1298 | end_y = torch.gather(y_positions_expanded, dim=2, index=(start_idx2 + 1).unsqueeze(2)).squeeze(2)
1299 | cand = start_y + (x - start_x) * (end_y - start_y) / (end_x - start_x)
1300 | return cand
1301 |
1302 |
1303 | def expand_dims(v, dims):
1304 | """
1305 | Expand the tensor `v` to the dim `dims`.
1306 |
1307 | Args:
1308 | `v`: a PyTorch tensor with shape [N].
1309 |         `dims`: an `int`.
1310 | Returns:
1311 | a PyTorch tensor with shape [N, 1, 1, ..., 1] and the total dimension is `dims`.
1312 | """
1313 | return v[(...,) + (None,)*(dims - 1)]
1314 |
--------------------------------------------------------------------------------
/schedule/schedule.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 |
4 |
5 | class Schedule:
6 | def __init__(self, schedule, timesteps):
7 | self.timesteps = timesteps
8 | self.schedule = schedule
9 |
10 | def cosine_beta_schedule(self, s=0.001):
11 | timesteps = self.timesteps
12 | steps = timesteps + 1
13 | x = torch.linspace(0, timesteps, steps)
14 | alphas_cumprod = torch.cos(((x / timesteps) + s) / (1 + s) * np.pi * 0.5) ** 2
15 | alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
16 | betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
17 | return torch.clip(betas, 0.0001, 0.9999)
18 |
19 | def linear_beta_schedule(self):
20 | timesteps = self.timesteps
21 | scale = 1000 / timesteps
22 | beta_start = 1e-6 * scale
23 | beta_end = 0.02 * scale
24 | return torch.linspace(beta_start, beta_end, timesteps)
25 |
26 | def quadratic_beta_schedule(self):
27 | timesteps = self.timesteps
28 | scale = 1000 / timesteps
29 | beta_start = 1e-6 * scale
30 | beta_end = 0.02 * scale
31 | return torch.linspace(beta_start ** 0.5, beta_end ** 0.5, timesteps) ** 2
32 |
33 | def sigmoid_beta_schedule(self):
34 | timesteps = self.timesteps
35 | scale = 1000 / timesteps
36 | beta_start = 1e-6 * scale
37 | beta_end = 0.02 * scale
38 | betas = torch.linspace(-6, 6, timesteps)
39 | return torch.sigmoid(betas) * (beta_end - beta_start) + beta_start
40 |
41 |     def get_betas(self):
42 |         # Dispatch on the configured schedule; all four schedules defined above are selectable.
43 |         schedules = {"linear": self.linear_beta_schedule, "cosine": self.cosine_beta_schedule,
44 |                      "quadratic": self.quadratic_beta_schedule, "sigmoid": self.sigmoid_beta_schedule}
45 |         if self.schedule not in schedules:
46 |             raise NotImplementedError(self.schedule)
47 |         return schedules[self.schedule]()
48 |
49 |
50 | if __name__ == "__main__":
51 | schedule = Schedule(schedule="linear", timesteps=100)
52 | print(schedule.get_betas().shape)
53 | print(schedule.get_betas())
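54 | 
55 |     # Illustrative extension of the demo (a sketch): downstream, the diffusion
56 |     # sampler consumes these betas through the cumulative signal level
57 |     # alpha_bar_t = prod_{s<=t}(1 - beta_s), which should decay towards 0.
58 |     alphas_cumprod = torch.cumprod(1.0 - schedule.get_betas(), dim=0)
59 |     print(alphas_cumprod[0].item(), alphas_cumprod[-1].item())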
--------------------------------------------------------------------------------
/src/config.py:
--------------------------------------------------------------------------------
1 | import yaml
2 | import os
3 |
4 |
5 | class Config(dict):
6 | def __init__(self, config_path):
7 | with open(config_path, 'r') as f:
8 | self._yaml = f.read()
9 | self._dict = yaml.safe_load(self._yaml)
10 | self._dict['PATH'] = os.path.dirname(config_path)
11 |
12 | def __getattr__(self, name):
13 | if self._dict.get(name) is not None:
14 | return self._dict[name]
15 | return None
16 |
17 | def print(self):
18 | print('Model configurations:')
19 | print('---------------------------------')
20 | print(self._yaml)
21 | print('')
22 | print('---------------------------------')
23 | print('')
24 |
25 |
26 | def load_config(path):
27 | config_path = path
28 | config = Config(config_path)
29 | return config
30 |
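31 | 
32 | if __name__ == "__main__":
33 |     # Illustrative usage (a sketch; assumes the repository's conf.yml is in the
34 |     # working directory). Note that __getattr__ above returns None for missing
35 |     # keys instead of raising AttributeError.
36 |     config = load_config('conf.yml')
37 |     config.print()
38 |     print(config.MODE, config.NOT_A_REAL_KEY)  # the second value prints None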
--------------------------------------------------------------------------------
/src/sobel.py:
--------------------------------------------------------------------------------
1 | from torch import nn
2 | import torch
3 |
4 |
5 | class Sobel(nn.Module):
6 | def __init__(self):
7 | super().__init__()
8 | self.filter = nn.Conv2d(in_channels=1, out_channels=2, kernel_size=3, stride=1, padding=1, bias=False)
9 |
10 | Gx = torch.tensor([[1.0, 0.0, -1.0], [2.0, 0.0, -2.0], [1.0, 0.0, -1.0]])
11 | Gy = torch.tensor([[1.0, 2.0, 1.0], [0.0, 0.0, 0.0], [-1.0, -2.0, -1.0]])
12 | G = torch.cat([Gx.unsqueeze(0), Gy.unsqueeze(0)], 0)
13 | G = G.unsqueeze(1)
14 | self.filter.weight = nn.Parameter(G, requires_grad=False)
15 |
16 | def forward(self, img):
17 | if img.shape[1] == 3:
18 | img = torch.mean(img, dim=1, keepdim=True)
19 | x = self.filter(img)
20 | x = torch.mul(x, x)
21 | x = torch.sum(x, dim=1, keepdim=True)
22 | x = torch.sqrt(x)
23 | return x
24 |
25 |
26 | class Laplacian(nn.Module):
27 | def __init__(self):
28 | super().__init__()
29 | self.filter = nn.Conv2d(in_channels=3, out_channels=3, kernel_size=3, stride=1, padding=1, bias=False, groups=3)
30 | G = torch.tensor([[-1, -1, -1], [-1, 8, -1], [-1, -1, -1]]).float()
31 | G = G.unsqueeze(0).unsqueeze(0)
32 | G = torch.cat([G, G, G], 0)
33 | self.filter.weight = nn.Parameter(G, requires_grad=False)
34 |
35 | def forward(self, img):
36 | x = self.filter(img)
37 | return x
38 |
39 | if __name__ == "__main__":
40 | laplacian = Laplacian()
41 | img = torch.randn(1, 3, 256, 256)
42 | y = laplacian(img)
43 | print(y.shape)
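44 | 
45 |     # Also exercise Sobel (an added sketch): a 3-channel input is averaged down
46 |     # to one channel, and the output is the per-pixel gradient magnitude
47 |     # sqrt(Gx^2 + Gy^2), so the spatial size is kept with a single channel.
48 |     sobel = Sobel()
49 |     print(sobel(img).shape)  # torch.Size([1, 1, 256, 256])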
--------------------------------------------------------------------------------
/src/train.py:
--------------------------------------------------------------------------------
1 | from src.trainer import Trainer
2 |
3 |
4 | def train(config):
5 | trainer = Trainer(config)
6 | trainer.train()
7 | print('training complete')
8 |
9 |
10 | def test(config):
11 | trainer = Trainer(config)
12 | trainer.test()
13 | print('testing complete')
14 |
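15 | 
16 | if __name__ == "__main__":
17 |     # Illustrative entry point (a sketch; main.py is the project's actual
18 |     # launcher, and conf.yml is assumed at the repository root). The Trainer
19 |     # switches on config.MODE: 1 trains, otherwise it tests.
20 |     from src.config import load_config
21 |     config = load_config('conf.yml')
22 |     if config.MODE == 1:
23 |         train(config)
24 |     else:
25 |         test(config)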
--------------------------------------------------------------------------------
/src/trainer.py:
--------------------------------------------------------------------------------
1 | import os
2 | from schedule.schedule import Schedule
3 | from model.DocDiff import DocDiff, EMA
4 | from schedule.diffusionSample import GaussianDiffusion
5 | from schedule.dpm_solver_pytorch import NoiseScheduleVP, model_wrapper, DPM_Solver
6 | import torch
7 | import torch.optim as optim
8 | import torch.nn as nn
9 | from torch.utils.data import DataLoader
10 | from torchvision.utils import save_image
11 | from tqdm import tqdm
12 | import copy
13 | from src.sobel import Laplacian
14 |
15 |
16 | def init__result_Dir():
17 | work_dir = os.path.join(os.getcwd(), 'Training')
18 | max_model = 0
19 | for root, j, file in os.walk(work_dir):
20 | for dirs in j:
21 | try:
22 | temp = int(dirs)
23 | if temp > max_model:
24 | max_model = temp
25 |                 except ValueError:  # skip run-directory names that are not integers
26 | continue
27 | break
28 | max_model += 1
29 | path = os.path.join(work_dir, str(max_model))
30 | os.mkdir(path)
31 | return path
32 |
33 |
34 | class Trainer:
35 | def __init__(self, config):
36 | self.mode = config.MODE
37 | self.schedule = Schedule(config.SCHEDULE, config.TIMESTEPS)
38 | self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
39 | in_channels = config.CHANNEL_X + config.CHANNEL_Y
40 | out_channels = config.CHANNEL_Y
41 | self.out_channels = out_channels
42 | self.network = DocDiff(
43 | input_channels=in_channels,
44 | output_channels=out_channels,
45 | n_channels=config.MODEL_CHANNELS,
46 | ch_mults=config.CHANNEL_MULT,
47 | n_blocks=config.NUM_RESBLOCKS
48 | ).to(self.device)
49 | self.diffusion = GaussianDiffusion(self.network.denoiser, config.TIMESTEPS, self.schedule).to(self.device)
50 | self.test_img_save_path = config.TEST_IMG_SAVE_PATH
51 | if not os.path.exists(self.test_img_save_path):
52 | os.makedirs(self.test_img_save_path)
53 | self.pretrained_path_init_predictor = config.PRETRAINED_PATH_INITIAL_PREDICTOR
54 | self.pretrained_path_denoiser = config.PRETRAINED_PATH_DENOISER
55 | self.continue_training = config.CONTINUE_TRAINING
56 | self.continue_training_steps = 0
57 | self.path_train_gt = config.PATH_GT
58 | self.path_train_img = config.PATH_IMG
59 | self.iteration_max = config.ITERATION_MAX
60 | self.LR = config.LR
61 | self.cross_entropy = nn.BCELoss()
62 | self.num_timesteps = config.TIMESTEPS
63 | self.ema_every = config.EMA_EVERY
64 | self.start_ema = config.START_EMA
65 | self.save_model_every = config.SAVE_MODEL_EVERY
66 | self.EMA_or_not = config.EMA
67 | self.weight_save_path = config.WEIGHT_SAVE_PATH
68 | self.TEST_INITIAL_PREDICTOR_WEIGHT_PATH = config.TEST_INITIAL_PREDICTOR_WEIGHT_PATH
69 | self.TEST_DENOISER_WEIGHT_PATH = config.TEST_DENOISER_WEIGHT_PATH
70 | self.DPM_SOLVER = config.DPM_SOLVER
71 | self.DPM_STEP = config.DPM_STEP
72 | self.test_path_img = config.TEST_PATH_IMG
73 | self.test_path_gt = config.TEST_PATH_GT
74 | self.beta_loss = config.BETA_LOSS
75 | self.pre_ori = config.PRE_ORI
76 | self.high_low_freq = config.HIGH_LOW_FREQ
77 | self.image_size = config.IMAGE_SIZE
78 | self.native_resolution = config.NATIVE_RESOLUTION
79 | if self.mode == 1 and self.continue_training == 'True':
80 | print('Continue Training')
81 | self.network.init_predictor.load_state_dict(torch.load(self.pretrained_path_init_predictor))
82 | self.network.denoiser.load_state_dict(torch.load(self.pretrained_path_denoiser))
83 | self.continue_training_steps = config.CONTINUE_TRAINING_STEPS
84 | from data.docdata import DocData
85 | if self.mode == 1:
86 | dataset_train = DocData(self.path_train_img, self.path_train_gt, config.IMAGE_SIZE, self.mode)
87 | self.batch_size = config.BATCH_SIZE
88 | self.dataloader_train = DataLoader(dataset_train, batch_size=self.batch_size, shuffle=True, drop_last=False,
89 | num_workers=config.NUM_WORKERS)
90 | else:
91 | dataset_test = DocData(config.TEST_PATH_IMG, config.TEST_PATH_GT, config.IMAGE_SIZE, self.mode)
92 | self.dataloader_test = DataLoader(dataset_test, batch_size=config.BATCH_SIZE_VAL, shuffle=False,
93 | drop_last=False,
94 | num_workers=config.NUM_WORKERS)
95 | if self.mode == 1 and config.EMA == 'True':
96 | self.EMA = EMA(0.9999)
97 | self.ema_model = copy.deepcopy(self.network).to(self.device)
98 | if config.LOSS == 'L1':
99 | self.loss = nn.L1Loss()
100 | elif config.LOSS == 'L2':
101 | self.loss = nn.MSELoss()
102 | else:
103 | print('Loss not implemented, setting the loss to L2 (default one)')
104 | self.loss = nn.MSELoss()
105 | if self.high_low_freq == 'True':
106 | self.high_filter = Laplacian().to(self.device)
107 |
108 | def test(self):
109 |         def crop_concat(img, size=128):  # pad to a multiple of size, then split into size x size patches
110 |             shape = img.shape
111 |             correct_shape = (size*(shape[2]//size+1), size*(shape[3]//size+1))
112 |             one = torch.ones((shape[0], shape[1], correct_shape[0], correct_shape[1]))  # pad with ones (white)
113 |             one[:, :, :shape[2], :shape[3]] = img
114 | # crop
115 | for i in range(shape[2]//size+1):
116 | for j in range(shape[3]//size+1):
117 | if i == 0 and j == 0:
118 | crop = one[:, :, i*size:(i+1)*size, j*size:(j+1)*size]
119 | else:
120 | crop = torch.cat((crop, one[:, :, i*size:(i+1)*size, j*size:(j+1)*size]), dim=0)
121 | return crop
122 |         def crop_concat_back(img, prediction, size=128):  # reassemble the predicted patches and crop to the original size
123 | shape = img.shape
124 | for i in range(shape[2]//size+1):
125 | for j in range(shape[3]//size+1):
126 | if j == 0:
127 | crop = prediction[(i*(shape[3]//size+1)+j)*shape[0]:(i*(shape[3]//size+1)+j+1)*shape[0], :, :, :]
128 | else:
129 | crop = torch.cat((crop, prediction[(i*(shape[3]//size+1)+j)*shape[0]:(i*(shape[3]//size+1)+j+1)*shape[0], :, :, :]), dim=3)
130 | if i == 0:
131 | crop_concat = crop
132 | else:
133 | crop_concat = torch.cat((crop_concat, crop), dim=2)
134 | return crop_concat[:, :, :shape[2], :shape[3]]
135 |
136 | def min_max(array):
137 | return (array - array.min()) / (array.max() - array.min())
138 | with torch.no_grad():
139 | self.network.init_predictor.load_state_dict(torch.load(self.TEST_INITIAL_PREDICTOR_WEIGHT_PATH))
140 | self.network.denoiser.load_state_dict(torch.load(self.TEST_DENOISER_WEIGHT_PATH))
141 | print('Test Model loaded')
142 | self.network.eval()
143 | tq = tqdm(self.dataloader_test)
144 | sampler = self.diffusion
145 | iteration = 0
146 | for img, gt, name in tq:
147 |                 tq.set_description(f'Batch {iteration} / {len(self.dataloader_test)}')
148 | iteration += 1
149 | if self.native_resolution == 'True':
150 | temp = img
151 | img = crop_concat(img)
152 | noisyImage = torch.randn_like(img).to(self.device)
153 | init_predict = self.network.init_predictor(img.to(self.device), 0)
154 |
155 | if self.DPM_SOLVER == 'True':
156 | sampledImgs = dpm_solver(self.schedule.get_betas(), self.network,
157 | torch.cat((noisyImage, img.to(self.device)), dim=1), self.DPM_STEP)
158 |             else:
159 |                 sampledImgs = sampler(noisyImage, init_predict, self.pre_ori)  # noisyImage is already on self.device
160 | finalImgs = (sampledImgs + init_predict)
161 | if self.native_resolution == 'True':
162 | finalImgs = crop_concat_back(temp, finalImgs)
163 | init_predict = crop_concat_back(temp, init_predict)
164 | sampledImgs = crop_concat_back(temp, sampledImgs)
165 | img = temp
166 | img_save = torch.cat((img, gt, init_predict.cpu(), min_max(sampledImgs.cpu()), finalImgs.cpu()), dim=3)
167 | save_image(img_save, os.path.join(
168 | self.test_img_save_path, f"{name[0]}.png"), nrow=4)
169 |
170 |
171 | def train(self):
172 | optimizer = optim.AdamW(self.network.parameters(), lr=self.LR, weight_decay=1e-4)
173 | iteration = self.continue_training_steps
174 | save_img_path = init__result_Dir()
175 |         print(f'Starting training, diffusion steps = {self.num_timesteps}')
176 |
177 | while iteration < self.iteration_max:
178 |
179 | tq = tqdm(self.dataloader_train)
180 |
181 | for img, gt, name in tq:
182 | tq.set_description(f'Iteration {iteration} / {self.iteration_max}')
183 | self.network.train()
184 | optimizer.zero_grad()
185 |
186 | t = torch.randint(0, self.num_timesteps, (img.shape[0],)).long().to(self.device)
187 | init_predict, noise_pred, noisy_image, noise_ref = self.network(gt.to(self.device), img.to(self.device),
188 | t, self.diffusion)
189 | if self.pre_ori == 'True':
190 | if self.high_low_freq == 'True':
191 | residual_high = self.high_filter(gt.to(self.device) - init_predict)
192 | ddpm_loss = 2*self.loss(self.high_filter(noise_pred), residual_high) + self.loss(noise_pred, gt.to(self.device) - init_predict)
193 | else:
194 | ddpm_loss = self.loss(noise_pred, gt.to(self.device) - init_predict)
195 | else:
196 | ddpm_loss = self.loss(noise_pred, noise_ref.to(self.device))
197 | if self.high_low_freq == 'True':
198 | low_high_loss = self.loss(init_predict, gt.to(self.device))
199 | low_freq_loss = self.loss(init_predict - self.high_filter(init_predict), gt.to(self.device) - self.high_filter(gt.to(self.device)))
200 | pixel_loss = low_high_loss + 2*low_freq_loss
201 | else:
202 | pixel_loss = self.loss(init_predict, gt.to(self.device))
203 |
204 | loss = ddpm_loss + self.beta_loss * (pixel_loss) / self.num_timesteps
205 | loss.backward()
206 | optimizer.step()
207 | if self.high_low_freq == 'True':
208 | tq.set_postfix(loss=loss.item(), high_freq_ddpm_loss=ddpm_loss.item(), low_freq_pixel_loss=low_freq_loss.item(), pixel_loss=low_high_loss.item())
209 | else:
210 | tq.set_postfix(loss=loss.item(), ddpm_loss=ddpm_loss.item(), pixel_loss=pixel_loss.item())
211 | if iteration % 500 == 0:
212 | if not os.path.exists(save_img_path):
213 | os.makedirs(save_img_path)
214 | img_save = torch.cat([img, gt, init_predict.cpu()], dim=3)
215 | if self.pre_ori == 'True':
216 | if self.high_low_freq == 'True':
217 | img_save = torch.cat([img, gt, init_predict.cpu(), noise_pred.cpu() + self.high_filter(init_predict).cpu(), noise_pred.cpu() + init_predict.cpu()], dim=3)
218 | else:
219 | img_save = torch.cat([img, gt, init_predict.cpu(), noise_pred.cpu() + init_predict.cpu()], dim=3)
220 | save_image(img_save, os.path.join(
221 | save_img_path, f"{iteration}.png"), nrow=4)
222 | iteration += 1
223 | if self.EMA_or_not == 'True':
224 | if iteration % self.ema_every == 0 and iteration > self.start_ema:
225 | print('EMA update')
226 | self.EMA.update_model_average(self.ema_model, self.network)
227 |
228 | if iteration % self.save_model_every == 0:
229 | print('Saving models')
230 | if not os.path.exists(self.weight_save_path):
231 | os.makedirs(self.weight_save_path)
232 | torch.save(self.network.init_predictor.state_dict(),
233 | os.path.join(self.weight_save_path, f'model_init_{iteration}.pth'))
234 | torch.save(self.network.denoiser.state_dict(),
235 | os.path.join(self.weight_save_path, f'model_denoiser_{iteration}.pth'))
236 | if self.EMA_or_not == 'True':
237 | torch.save(self.ema_model.init_predictor.state_dict(),
238 | os.path.join(self.weight_save_path, f'model_init_ema_{iteration}.pth'))
239 | torch.save(self.ema_model.denoiser.state_dict(),
240 | os.path.join(self.weight_save_path, f'model_denoiser_ema_{iteration}.pth'))
241 |
242 |
243 |
244 | def dpm_solver(betas, model, x_T, steps, model_kwargs=None):
245 |     # Fast sampling with DPM-Solver (https://github.com/LuChengTHU/dpm-solver).
246 |     # `model` must have the format model(x_t, t_input, **model_kwargs);
247 |     # DocDiff's denoiser takes no extra inputs, so model_kwargs defaults to empty.
248 |     if model_kwargs is None:
249 |         model_kwargs = {}
250 |
251 |     # 1. Define the noise schedule from the discrete-time beta array.
252 |     noise_schedule = NoiseScheduleVP(schedule='discrete', betas=betas)
253 |
254 |     # 2. Convert the discrete-time `model` to a continuous-time noise
255 |     # prediction model ("noise"; other options are "x_start", "v", "score").
256 |     model_fn = model_wrapper(
257 |         model,
258 |         noise_schedule,
259 |         model_type="noise",
260 |         model_kwargs=model_kwargs,
261 |     )
262 |
263 |     # 3. Define DPM-Solver and sample by singlestep DPM-Solver
264 |     # (recommended for unconditional sampling). Adjust `steps` to balance
265 |     # computation cost and sample quality: steps in [10, 20] already give
266 |     # quite good samples, and steps = 20 is close to convergence.
267 |     dpm_solver = DPM_Solver(model_fn, noise_schedule, algorithm_type="dpmsolver++",
268 |                             correcting_x0_fn="dynamic_thresholding")
269 |     # Can also be used without dynamic thresholding:
270 |     # dpm_solver = DPM_Solver(model_fn, noise_schedule, algorithm_type="dpmsolver++")
271 |     x_sample = dpm_solver.sample(
272 |         x_T,
273 |         steps=steps,
274 |         order=1,
275 |         skip_type="time_uniform",
276 |         method="singlestep",
277 |     )
278 |     return x_sample
279 |
--------------------------------------------------------------------------------
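
The trickiest part of `Trainer.test` is the patch-ordering contract between `crop_concat` and `crop_concat_back` (row-major patches, each carrying the full batch on dim 0). Below is a standalone round-trip check, with the two helpers copied out of `Trainer.test` in compact form; it is an illustrative sketch, not repo code:

```python
import torch

def crop_concat(img, size=128):
    # pad to the next multiple of `size` with ones (white), then split into patches
    b, c, h, w = img.shape
    padded = torch.ones(b, c, size * (h // size + 1), size * (w // size + 1))
    padded[:, :, :h, :w] = img
    crops = [padded[:, :, i*size:(i+1)*size, j*size:(j+1)*size]
             for i in range(h // size + 1) for j in range(w // size + 1)]
    return torch.cat(crops, dim=0)  # row-major patch order, batch-first within a patch

def crop_concat_back(img, prediction, size=128):
    # reassemble patches in the same row-major order, then crop to the original size
    b, c, h, w = img.shape
    cols = w // size + 1
    rows = [torch.cat([prediction[(i*cols + j)*b:(i*cols + j + 1)*b]
                       for j in range(cols)], dim=3)
            for i in range(h // size + 1)]
    return torch.cat(rows, dim=2)[:, :, :h, :w]

x = torch.rand(2, 3, 300, 460)   # arbitrary resolution, not a multiple of 128
patches = crop_concat(x)         # (2 * 3 * 4, 3, 128, 128): 3 patch rows, 4 patch columns
assert torch.equal(crop_concat_back(x, patches), x)
```
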
/utils/font/simhei.ttf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Royalvice/DocDiff/f422a3af10657eeabeca90f2e862d88c72725845/utils/font/simhei.ttf
--------------------------------------------------------------------------------
/utils/font/times.ttf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Royalvice/DocDiff/f422a3af10657eeabeca90f2e862d88c72725845/utils/font/times.ttf
--------------------------------------------------------------------------------
/utils/font/timesbd.ttf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Royalvice/DocDiff/f422a3af10657eeabeca90f2e862d88c72725845/utils/font/timesbd.ttf
--------------------------------------------------------------------------------
/utils/font/timesbi.ttf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Royalvice/DocDiff/f422a3af10657eeabeca90f2e862d88c72725845/utils/font/timesbi.ttf
--------------------------------------------------------------------------------
/utils/font/timesi.ttf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Royalvice/DocDiff/f422a3af10657eeabeca90f2e862d88c72725845/utils/font/timesi.ttf
--------------------------------------------------------------------------------
/utils/font/方正仿宋_GBK.TTF:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Royalvice/DocDiff/f422a3af10657eeabeca90f2e862d88c72725845/utils/font/方正仿宋_GBK.TTF
--------------------------------------------------------------------------------
/utils/font/楷体_GB2312.ttf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Royalvice/DocDiff/f422a3af10657eeabeca90f2e862d88c72725845/utils/font/楷体_GB2312.ttf
--------------------------------------------------------------------------------
/utils/font/青鸟华光简琥珀.ttf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Royalvice/DocDiff/f422a3af10657eeabeca90f2e862d88c72725845/utils/font/青鸟华光简琥珀.ttf
--------------------------------------------------------------------------------
/utils/marker.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python
2 | # -*- coding: utf-8 -*-
3 | # This is a watermarking tool for images.
4 | # Heavily based on https://github.com/2Dou/watermarker
5 |
6 |
7 | import argparse
8 | import os
9 | import sys
10 | import math
11 | import textwrap
12 | import random
13 |
14 | import numpy as np
15 | from PIL import Image, ImageFont, ImageDraw, ImageEnhance, ImageChops, ImageOps
16 |
17 |
18 | def randomtext():
19 |     def GBK2312():  # one random GB2312 (simplified Chinese) character
20 |         head = random.randint(0xb0, 0xf7)
21 |         body = random.randint(0xa1, 0xfe)
22 |         val = f'{head:x}{body:x}'
23 |         strr = bytes.fromhex(val).decode('gb2312', errors="ignore")
24 |         return strr
25 |
26 |     def randomABC():  # two random letters followed by three random digits
27 |         A = np.random.randint(65, 91)   # 'A'..'Z'
28 |         a = np.random.randint(97, 123)  # 'a'..'z'
29 |         char = chr(A) + chr(a)
30 |         for i in range(3):
31 |             char = char + str(np.random.randint(0, 10))  # np.random.randint excludes the high bound
32 |         return char
33 |
34 |     char = randomABC()
35 |     for i in range(5):
36 |         char = char + GBK2312()
37 |     return char
38 |
39 |
40 | def randomcolor():
41 | colorArr = ['1', '2', '3', '4', '5', '6', '7', '8', '9', 'A', 'B', 'C', 'D', 'E', 'F']
42 | color = ""
43 | for i in range(6):
44 | color += colorArr[random.randint(0, 14)]
45 | return "#" + color
46 |
47 |
48 | def add_mark(imagePath, mark, args):
49 |     """
50 |     Add the watermark to the image, then save it.
51 |     """
52 | im = Image.open(imagePath)
53 | im = ImageOps.exif_transpose(im)
54 |
55 | image, mask = mark(im)
56 | name = os.path.basename(imagePath)
57 | if image:
58 | if not os.path.exists(args.out):
59 | os.mkdir(args.out)
60 |
61 | new_name = os.path.join(args.out, name)
62 | if os.path.splitext(new_name)[1] != '.png':
63 | image = image.convert('RGB')
64 | image.save(new_name, quality=args.quality)
65 | #mask.save(os.path.join(args.out, os.path.basename(os.path.splitext(new_name)[0]) + '.png'),
66 | # quality=args.quality)
67 | print(name + " Success.")
68 | else:
69 | print(name + " Failed.")
70 |
71 |
72 | def set_opacity(im, opacity):
73 |     """
74 |     Set the watermark opacity.
75 |     """
76 | assert 0 <= opacity <= 1
77 |
78 | alpha = im.split()[3]
79 | alpha = ImageEnhance.Brightness(alpha).enhance(opacity)
80 | im.putalpha(alpha)
81 | return im
82 |
83 |
84 | def crop_image(im):
85 |     """Crop blank borders from the image."""
86 | bg = Image.new(mode='RGBA', size=im.size)
87 | diff = ImageChops.difference(im, bg)
88 | del bg
89 | bbox = diff.getbbox()
90 | if bbox:
91 | return im.crop(bbox)
92 | return im
93 |
94 |
95 | def gen_mark(args):
96 |     """
97 |     Generate the watermark image and return a function that applies it.
98 |     """
99 |     # text width and height
100 |     is_height_crop_float = '.' in args.font_height_crop  # hacky, but works
101 |     width = len(args.mark) * args.size
102 |     if is_height_crop_float:
103 |         height = round(args.size * float(args.font_height_crop))
104 |     else:
105 |         height = int(args.font_height_crop)
106 |
107 |     # create the watermark image (width x height)
108 | mark = Image.new(mode='RGBA', size=(width, height))
109 |
110 |     # draw the watermark text
111 | draw_table = ImageDraw.Draw(im=mark)
112 | draw_table.text(xy=(0, 0),
113 | text=args.mark,
114 | fill=args.color,
115 | font=ImageFont.truetype(args.font_family,
116 | size=args.size))
117 | del draw_table
118 |
119 |     # crop blank borders
120 | mark = crop_image(mark)
121 |
122 |     # apply the opacity (set_opacity modifies mark in place)
123 | set_opacity(mark, args.opacity)
124 |
125 |     def mark_im(im):
126 |         """Apply the watermark to im, the opened source image."""
127 |
128 |         # diagonal length of the source image
129 | c = int(math.sqrt(im.size[0] * im.size[0] + im.size[1] * im.size[1]))
130 |
131 |         # a square canvas with the diagonal as its side, so it still covers the image after rotation
132 | mark2 = Image.new(mode='RGBA', size=(c, c))
133 |
134 |         # tile the watermark text (the mark image built above) across the canvas
135 | y, idx = 0, 0
136 | while y < c:
137 |             # stagger the x offset on alternating rows
138 | x = -int((mark.size[0] + args.space) * 0.5 * idx)
139 | idx = (idx + 1) % 2
140 |
141 | while x < c:
142 |                 # paste the watermark at this position
143 | mark2.paste(mark, (x, y))
144 | x = x + mark.size[0] + args.space
145 | y = y + mark.size[1] + args.space
146 |
147 |         # rotate the canvas
148 | mark2 = mark2.rotate(args.angle)
149 |
150 |         # composite the rotated canvas onto the source image
151 | if im.mode != 'RGBA':
152 | im = im.convert('RGBA')
153 | mask = np.zeros_like(np.array(im))
154 | mask = Image.fromarray(mask)
155 |         im.paste(mark2,  # watermark canvas
156 |                  (int((im.size[0] - c) / 2), int((im.size[1] - c) / 2)),  # paste position
157 |                  mask=mark2.split()[3])
158 |         mask.paste(mark2,  # repeat onto the blank mask to get a ground-truth watermark mask
159 |                  (int((im.size[0] - c) / 2), int((im.size[1] - c) / 2)),  # paste position
160 |                  mask=mark2.split()[3])
161 | #mask.show()
162 | del mark2
163 | return im, mask
164 |
165 | return mark_im
166 |
167 |
168 | def main():
169 | parse = argparse.ArgumentParser(formatter_class=argparse.RawTextHelpFormatter)
170 | parse.add_argument("-f", "--file", type=str, default="./input",
171 | help="image file path or directory")
172 | parse.add_argument("-m", "--mark", type=str, default="测试一句话行不行", help="watermark content")
173 |     parse.add_argument("-o", "--out", default="./output",
174 |                        help="image output directory, default is ./output")
175 | parse.add_argument("-c", "--color", default="#8B8B1B", type=str,
176 | help="text color like '#000000', default is #8B8B1B")
177 | parse.add_argument("-s", "--space", default=75, type=int,
178 | help="space between watermarks, default is 75")
179 |     parse.add_argument("-a", "--angle", default=90, type=int,
180 |                        help="rotate angle of watermarks, default is 90")
181 | parse.add_argument("--font-family", default="./font/青鸟华光简琥珀.ttf", type=str,
182 | help=textwrap.dedent('''\
183 | font family of text, default is './font/青鸟华光简琥珀.ttf'
184 | using font in system just by font file name
185 | for example 'PingFang.ttc', which is default installed on macOS
186 | '''))
187 | parse.add_argument("--font-height-crop", default="1.2", type=str,
188 | help=textwrap.dedent('''\
189 | change watermark font height crop
190 | float will be parsed to factor; int will be parsed to value
191 | default is '1.2', meaning 1.2 times font size
192 | this useful with CJK font, because line height may be higher than size
193 | '''))
194 | parse.add_argument("--size", default=50, type=int,
195 | help="font size of text, default is 50")
196 | parse.add_argument("--opacity", default=0.15, type=float,
197 | help="opacity of watermarks, default is 0.15")
198 |     parse.add_argument("--quality", default=100, type=int,
199 |                        help="quality of output images, default is 100")
200 |
201 | args = parse.parse_args()
202 |
203 |     # randomization ranges: each parameter is drawn uniformly from [A, B]
204 |     space_ = [30, 40]
205 |     angle_ = [0, 180]
206 |     size_ = [20, 60]
207 |     opacity_ = [50, 95]
208 |     # random font (download others from a font site if you need more)
209 | fonts = [r'./font/青鸟华光简琥珀.ttf', r'./font/楷体_GB2312.ttf', './font/方正仿宋_GBK.TTF', './font/times.ttf',
210 | './font/simhei.ttf']
211 |     if isinstance(args.mark, str) and sys.version_info[0] < 3:  # legacy Python 2 branch; never taken on Python 3
212 |         args.mark = args.mark.decode("utf-8")
213 |
214 | if os.path.isdir(args.file):
215 | names = os.listdir(args.file)
216 | for name in names:
217 | image_file = os.path.join(args.file, name)
218 | space = random.randint(space_[0], space_[1])
219 | args.space = space
220 |             angle = random.randint(angle_[0], angle_[1])
221 |             args.angle = angle
222 | size = random.randint(size_[0], size_[1])
223 | args.size = size
224 | opacity = random.randint(opacity_[0], opacity_[1]) / 100
225 | args.opacity = opacity
226 | color = randomcolor()
227 | args.color = color
228 | font = np.random.choice(fonts)
229 | args.font_family = font
230 | mark_text = randomtext()
231 | args.mark = mark_text
232 | mark = gen_mark(args)
233 | add_mark(image_file, mark, args)
234 | else:
235 | mark = gen_mark(args)
236 | add_mark(args.file, mark, args)
237 |
238 |
239 | if __name__ == '__main__':
240 | main()
241 |
--------------------------------------------------------------------------------
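
To synthesize a watermarked dataset, run the script from `utils/` so the relative `./font/...` paths resolve, e.g. `python marker.py -f ./input -o ./output`; watermark text, font, color, angle, size, and opacity are then randomized per image. Equivalently, a programmatic sketch: the `argparse.Namespace` stands in for the CLI args, and the watermark text and input/output paths are placeholders:

```python
import argparse
from marker import gen_mark, add_mark  # run from utils/ so the ./font paths resolve

args = argparse.Namespace(
    mark="DocDiff2023", color="#8B8B1B", space=75, angle=30, size=50,
    opacity=0.15, font_family="./font/times.ttf", font_height_crop="1.2",
    out="./output", quality=100,
)
mark = gen_mark(args)                        # returns a function that watermarks a PIL image
add_mark("./input/example.png", mark, args)  # writes ./output/example.png
```
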