├── LICENSE
├── README.EN.md
├── README.md
├── checksave
│   ├── denoiser.pth
│   ├── init.pth
│   ├── seal_denoiser.pth
│   └── seal_init.pth
├── conf.yml
├── data
│   └── docdata.py
├── demo
│   ├── 0000184_blur.png
│   ├── 0000184_orig.png
│   ├── DEGAN_0000184.png
│   ├── DEGAN_0001260.png
│   ├── inference.ipynb
│   └── teaser.png
├── main.py
├── model
│   └── DocDiff.py
├── schedule
│   ├── diffusionSample.py
│   ├── dpm_solver_pytorch.py
│   └── schedule.py
├── src
│   ├── config.py
│   ├── sobel.py
│   ├── train.py
│   └── trainer.py
└── utils
    ├── font
    │   ├── simhei.ttf
    │   ├── times.ttf
    │   ├── timesbd.ttf
    │   ├── timesbi.ttf
    │   ├── timesi.ttf
    │   ├── 方正仿宋_GBK.TTF
    │   ├── 楷体_GB2312.ttf
    │   └── 青鸟华光简琥珀.ttf
    └── marker.py

/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2023 Zongyuan Yang
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/README.EN.md:
--------------------------------------------------------------------------------
1 | 
2 | 3 | 4 | 5 |
6 | 7 |
8 | 
9 | [简体中文](README.md) | [English](README.EN.md) | [Paper](https://dl.acm.org/doi/abs/10.1145/3581783.3611730)
10 | # DocDiff
11 | This is the official repository for the paper [DocDiff: Document Enhancement via Residual Diffusion Models](https://dl.acm.org/doi/abs/10.1145/3581783.3611730). DocDiff is a document enhancement model (please refer to the [paper](https://arxiv.org/abs/2305.03892v1)) that can be used for tasks such as document deblurring, denoising, binarization, and watermark and stamp removal. DocDiff is a lightweight, residual-prediction-based diffusion model that can be trained with a batch size of 64 at a resolution of 128*128 using only 12GB of VRAM.
12 | 
13 | Beyond document enhancement, DocDiff can also be used for other img2img tasks, such as natural scene deblurring[1](#refer-anchor-1), denoising, rain removal, super-resolution[2](#refer-anchor-2), and image inpainting, as well as high-level tasks such as semantic segmentation[4](#refer-anchor-4).
14 | 
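
The residual formulation at the core of DocDiff (see `model/DocDiff.py`) can be sketched as follows, using the notation of the code, where $\bar{\gamma}_t$ is the cumulative product of $1-\beta_t$ (`gammas` in `schedule/diffusionSample.py`):

$$x_{\mathrm{res}} = x_{\mathrm{gt}} - \mathrm{CP}(x_{\mathrm{cond}}), \qquad q(x_t \mid x_{\mathrm{res}}) = \mathcal{N}\!\left(\sqrt{\bar{\gamma}_t}\,x_{\mathrm{res}},\ (1-\bar{\gamma}_t)\,I\right),$$

where $\mathrm{CP}$ denotes the Coarse Predictor. The Denoiser sees the noisy residual concatenated channel-wise with the coarse estimate and only has to recover the residual; the final output is the coarse estimate plus the denoised residual.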
15 | 
16 | # News
17 | 
18 | - **Pinned**: Introducing our laboratory-developed, versatile, cross-platform [**OCR software**](https://www.aibupt.com/). **It includes DocDiff-based automatic watermark and stamp removal (the automatic watermark removal feature is coming soon)**. It also covers various commonly used OCR functions, such as PDF-to-Word conversion, PDF-to-Excel conversion, formula recognition, and table recognition. Feel free to give it a try!
19 | - 2023.09.14: Uploaded the watermark synthesis code `utils/marker.py` and the seal dataset. [Seal dataset Google Drive](https://drive.google.com/file/d/125SgEmHFUIzDexsrj2d3yMJdYMVhovti/view?usp=sharing)
20 | - 2023.08.02: Document binarization results for H-DIBCO 2018 [6](#refer-anchor-6) and DIBCO 2019 [7](#refer-anchor-7) have been uploaded. You can access them on [Google Drive](https://drive.google.com/drive/folders/1gT8PFnfW0qFbFmWX6ReQntfFr9POVtYR?usp=sharing)
21 | - 2023.08.01: **Congratulations! DocDiff has been accepted by ACM Multimedia 2023!**
22 | - 2023.06.13: Uploaded the inference notebook `demo/inference.ipynb` and the pretrained models in `checksave/` for convenient reproduction.
23 | - 2023.05.08: Uploaded the initial version of the code. Please check the to-do list for future updates.
24 | 
25 | # Guide
26 | 
27 | Whether for training or inference, you only need to modify the configuration parameters in `conf.yml` and run `main.py`: MODE=1 is training, MODE=0 is inference. The parameters in `conf.yml` are annotated in detail, so you can adjust them as needed. Pre-trained weights for the document-deblurring Coarse Predictor and Denoiser can be found in `checksave/`.
28 | 
29 | Please note that the default parameters in `conf.yml` work best for document scenarios. If you want to apply DocDiff to natural scenes, please first read [Notes!](#notes) carefully. If you still have issues, feel free to open an issue.
30 | 
31 | - Because downsampling is applied three times, the resolution of the input image must be a multiple of 8. If your image is not a multiple of 8, adjust it with padding or cropping. Do not resize directly, as resizing distorts the image; in the deblurring task in particular, distortion increases the blur and degrades performance sharply. For example, the document deblurring dataset [5](#refer-anchor-5) used by DocDiff has a resolution of 300\*300, which needs to be padded to 304\*304 before inference (see the padding example below).
32 | 
33 | ## Environment
34 | 
35 | - python >= 3.7
36 | - pytorch >= 1.7.0
37 | - torchvision >= 0.8.0
38 | 
39 | ## Watermark Synthesis and Seal Dataset
40 | 
41 | We provide the watermark synthesis code `utils/marker.py` and a seal dataset. [Seal dataset Google Drive](https://drive.google.com/file/d/125SgEmHFUIzDexsrj2d3yMJdYMVhovti/view?usp=sharing). Because the document background images we used are internal data, they are not included; to use the watermark synthesis code, you need to supply your own document background images. The code is implemented with OpenCV, so OpenCV must be installed.
42 | 
43 | ### Seal Dataset
44 | 
45 | The Seal Dataset belongs to the [DocDiff project](https://github.com/Royalvice/DocDiff). It contains 1597 red seals in Chinese scenarios, along with their corresponding binary masks. These seal data can be used for tasks such as seal synthesis and seal removal. 
Due to limited manpower, it is extremely difficult to extract seals from document images, so some seal images may contain noise. Most of the original seal
46 | images in the dataset are from the ICDAR 2023 Competition on Reading the Seal Title ([https://rrc.cvc.uab.es/?ch=20](https://rrc.cvc.uab.es/?ch=20)) dataset, and a few are from our internal images. If you find this dataset helpful, please give our [project](https://github.com/Royalvice/DocDiff) a free star. Thank you!!!
47 | 
48 | 
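
## Padding Example

As noted in the Guide, inputs should be padded (not resized) so that height and width are multiples of 8. Below is a minimal, self-contained sketch with Pillow; the helper name and file paths are illustrative and not part of this repository:

```python
from PIL import Image, ImageOps

def pad_to_multiple(img: Image.Image, multiple: int = 8):
    """Pad on the right/bottom with white so both sides become multiples of `multiple`."""
    w, h = img.size
    pad_w = (multiple - w % multiple) % multiple
    pad_h = (multiple - h % multiple) % multiple
    padded = ImageOps.expand(img, border=(0, 0, pad_w, pad_h), fill=(255, 255, 255))
    return padded, (w, h)

img, orig_size = pad_to_multiple(Image.open("input.png").convert("RGB"))  # e.g. 300*300 -> 304*304
img.save("input_padded.png")
# After inference, crop the enhanced result back to the original size:
# result = result.crop((0, 0, orig_size[0], orig_size[1]))
```

White padding (value 255) matches the constant-fill padding used by the training transforms in `data/docdata.py`.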
49 | 50 | # Notes! 51 | 52 |
52 | 
53 | 
54 | - The default configuration parameters of DocDiff are designed for **document images**. To get better results on **natural scenes**, you need to adjust the parameters, e.g., scale up the model, add **self-attention**, etc. (document images have relatively fixed patterns, while natural scenes are more diverse and need more parameters), and you may also need to modify the **training and inference strategies**.
55 | - **Training strategy**: As described in the paper, in document scenarios we do not pursue diverse results and we want to minimize inference time, so we set the number of diffusion steps T to 100 and predict $x_0$ instead of $\epsilon$. Combined with the channel-wise concatenation conditioning scheme, this strategy recovers a good $x_0$ within the early steps of reverse diffusion. In natural scenes, to better reconstruct textures and obtain diverse results, T should be as large as possible and $\epsilon$ should be predicted. Simply set **PRE_ORI="False"** in `conf.yml` to use the $\epsilon$-prediction scheme, and set **TIMESTEPS=1000** to use more diffusion steps.
56 | - **Inference strategy**: Images generated in document scenarios should not be random (short-step stochastic sampling may distort text edges), so DocDiff performs deterministic sampling as described in DDIM[3](#refer-anchor-3). In natural scenes, stochastic sampling is essential for diverse results; set **PRE_ORI="False"** in `conf.yml` to use it. In other words, the $\epsilon$-prediction scheme is bound to stochastic sampling, and the $x_0$-prediction scheme is bound to deterministic sampling. If you want to predict $x_0$ with stochastic sampling, or predict $\epsilon$ with deterministic sampling, you need to modify the code yourself. In DocDiff, deterministic sampling follows DDIM and stochastic sampling follows DDPM; you can modify the code to implement other sampling strategies.
57 | - **Summary**: For tasks that do not require diverse results, such as semantic segmentation and document enhancement, predicting $x_0$ with 100 diffusion steps is enough and already performs well. For tasks that require diverse results, such as natural-scene deblurring, super-resolution, and image restoration, predicting $\epsilon$ with 1000 diffusion steps is recommended.
58 | 
59 | # To-do Lists
60 | 
61 | - [x] Add training code
62 | - [x] Add inference code
63 | - [x] Upload pre-trained models
64 | - [x] Upload watermark synthesis code and seal dataset
65 | - [x] Use DPM_solver to reduce the number of inference steps (although the effect is not significant in practice)
66 | - [x] Upload the inference notebook for convenient reproduction
67 | - [ ] Synthesize document datasets with more noise, such as salt-and-pepper noise and compression artifacts
68 | - [ ] Train on multiple GPUs
69 | - [ ] Jump-step sampling for DDIM
70 | - [ ] Use depthwise separable convolutions to compress the model
71 | - [ ] Train the model on natural scenes and provide results and pre-trained models
72 | 
73 | # Stars over time
74 | 
75 | [![Stargazers over time](https://starchart.cc/Royalvice/DocDiff.svg)](https://starchart.cc/Royalvice/DocDiff)
76 | 
77 | # Acknowledgement
78 | 
79 | - If you find DocDiff helpful, please give us a star. Thank you! 🤞😘
80 | - If you have any questions, please don't hesitate to open an issue. We will reply as soon as possible.
81 | - If you want to communicate with us, please send an email to **viceyzy@foxmail.com** with the subject "**DocDiff**".
82 | - If you want to use DocDiff as the baseline for your project, please cite our paper.
83 | ```
84 | @inproceedings{yang2023docdiff,
85 |   title={DocDiff: Document Enhancement via Residual Diffusion Models},
86 |   author={Yang, Zongyuan and Liu, Baolin and Xiong, Yongping and Yi, Lan and Wu, Guibin and Tang, Xiaojun and Liu, Ziqi and Zhou, Junjie and Zhang, Xing},
87 |   booktitle={Proceedings of the 31st ACM International Conference on Multimedia},
88 |   pages={2795--2806},
89 |   year={2023}
90 | }
91 | ```
92 | 
93 | # References
95 | 96 | - [1] Whang J, Delbracio M, Talebi H, et al. Deblurring via stochastic refinement[C]//Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition. 2022: 16293-16303. 97 | 98 |
99 | 100 | - [2] Shang S, Shan Z, Liu G, et al. ResDiff: Combining CNN and Diffusion Model for Image Super-Resolution[J]. arXiv preprint arXiv:2303.08714, 2023. 101 | 102 |
103 | 104 | - [3] Song J, Meng C, Ermon S. Denoising diffusion implicit models[J]. arXiv preprint arXiv:2010.02502, 2020. 105 | 106 |
107 | 108 | - [4] Wu J, Fang H, Zhang Y, et al. MedSegDiff: Medical Image Segmentation with Diffusion Probabilistic Model[J]. arXiv preprint arXiv:2211.00611, 2022. 109 | 110 |
111 | 112 | - [5] Michal Hradiš, Jan Kotera, Pavel Zemčík and Filip Šroubek. Convolutional Neural Networks for Direct Text Deblurring. In Xianghua Xie, Mark W. Jones, and Gary K. L. Tam, editors, Proceedings of the British Machine Vision Conference (BMVC), pages 6.1-6.13. BMVA Press, September 2015. 113 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
2 | 3 | 4 | 5 |
6 | 7 |
8 | 9 | [简体中文](README.md) | [English](README.EN.md) | [Paper](https://dl.acm.org/doi/abs/10.1145/3581783.3611730) 10 | 11 | [![Visitors](https://api.visitorbadge.io/api/combined?path=https%3A%2F%2Fgithub.com%2FRoyalvice%2FDocDiff&countColor=%23d9e3f0)](https://visitorbadge.io/status?path=https%3A%2F%2Fgithub.com%2FRoyalvice%2FDocDiff) 12 | 13 | # DocDiff 14 | 这里是论文[DocDiff: Document Enhancement via Residual Diffusion Models](https://dl.acm.org/doi/abs/10.1145/3581783.3611730)的官方复现仓库。DocDiff是一个文档增强模型(详见[论文](https://arxiv.org/abs/2305.03892v1)),可以用于文档去模糊、文档去噪、文档二值化、文档去水印和印章等任务。DocDiff是一个轻量级的基于残差预测的扩散模型,在128*128分辨率上以Batchsize=64训练只需要12GB显存。 15 | 不仅文档增强,DocDiff还可以应用在其他img2img任务上,比如自然场景去模糊[1](#refer-anchor-1),去噪,去雨,超分[2](#refer-anchor-2),图像修复等low-level任务以及语义分割[4](#refer-anchor-4)等high-level任务。 16 |
17 | 18 | # News 19 | 20 | - **置顶**: 介绍一款我们实验室开发的多功能且多平台的[**OCR软件**](https://www.aibupt.com/)。**其中包含了DocDiff的自动去除水印和印章的功能(自动去除水印功能即将上线)**。同样包含常用的各种OCR功能,例如PDF转word,PDF转excel,公式识别,表格识别。欢迎试用! 21 | - 2023.09.14: 上传了水印合成代码`utils/marker.py`和印章数据集。[印章数据集Google Drive](https://drive.google.com/file/d/125SgEmHFUIzDexsrj2d3yMJdYMVhovti/view?usp=sharing) 22 | - 2023.08.02: H-DIBCO 2018 [6](#refer-anchor-6) 和 DIBCO 2019 [7](#refer-anchor-7) 的文档二值化结果已经上传。[Google Drive](https://drive.google.com/drive/folders/1gT8PFnfW0qFbFmWX6ReQntfFr9POVtYR?usp=sharing) 23 | - 2023.08.01: **祝贺!DocDiff被ACM Multimedia 2023接收!** 24 | - 2023.06.13: 为了方便复现,已上传推理笔记本`demo/inference.ipynb`和预训练模型`checksave/`。 25 | - 2023.05.08: 代码的初始版本已经上传。请查看To-do lists来获取未来的更新。 26 | 27 | # 使用指南 28 | 29 | 无论是训练还是推理,你只需要修改conf.yml中的配置参数,然后运行main.py即可。MODE=1为训练,MODE=0为推理。conf.yml中的参数都有详细注释,你可以根据注释修改参数。文档去模糊预训练权重在`checksave/`。 30 | **请注意**conf.yml中的默认参数在文档场景表现最好。如果你想应用DocDiff在自然场景,请先看一下[注意事项!!!](#注意事项!!!)。如果仍有问题,欢迎提issue。 31 | 32 | - 由于要下采样3次,所以输入图像的分辨率必须是8的倍数。如果你的图像不是8的倍数,可以使用padding或者裁剪的方式将图像调整为8的倍数。请不要直接Resize,因为这样会导致图像失真。尤其在去模糊任务中,图像失真会导致模糊程度增加,效果会变得很差。例如,DocDiff使用的文档去模糊数据集[5](#refer-anchor-5)分辨率为300\*300,需要先padding到304\*304,再送入推理。 33 | 34 | ## 环境配置 35 | 36 | - python >= 3.7 37 | - pytorch >= 1.7.0 38 | - torchvision >= 0.8.0 39 | 40 | ## 水印合成与印章数据集 41 | 42 | 我们提供了水印合成代码`utils/marker.py`和印章数据集。[印章数据集Google Drive](https://drive.google.com/file/d/125SgEmHFUIzDexsrj2d3yMJdYMVhovti/view?usp=sharing)。由于使用的文档背景图像是我们内部的数据,所以我们没有提供背景图片。如果你想使用水印合成代码,你需要自己找一些文档背景图像。水印合成代码是基于OpenCV实现的,所以你需要安装OpenCV。 43 | 44 | ### 印章数据集 45 | 46 | 印章数据集隶属于[DocDiff项目](https://github.com/Royalvice/DocDiff),其中包含1597个中文场景下的红色系印章以及它们对应的二值化的掩膜,这些印章数据可以用于印章合成、印章消除等等任务中。由于人力有限,而从文档图片中抠出来印章是极其困难的事情,所以某些印章图片中包含一些噪声。数据集中的原始印章图片大部分来自于ICDAR 2023 Competition on Reading the Seal Title([https://rrc.cvc.uab.es/?ch=20](https://rrc.cvc.uab.es/?ch=20))数据集,少部分来自于我们自己内部的图片。如果您觉得这份数据集对您有帮助,请给我们的[项目](https://github.com/Royalvice/DocDiff)一个免费的star,谢谢!!! 47 | 48 |
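
下面给出一个极简的水印合成示意（仅作说明：叠加方式、路径与参数均为假设，并非 `utils/marker.py` 的实际实现）：

```python
import cv2
import numpy as np

# 在自备的文档背景图上叠加一层半透明的倾斜文字水印
bg = cv2.imread("background.png")
h, w = bg.shape[:2]
layer = np.zeros_like(bg)
cv2.putText(layer, "SAMPLE", (w // 8, h // 2),
            cv2.FONT_HERSHEY_SIMPLEX, 3.0, (0, 0, 255), 6)
M = cv2.getRotationMatrix2D((w / 2, h / 2), 30, 1.0)  # 旋转 30 度
layer = cv2.warpAffine(layer, M, (w, h))
out = cv2.addWeighted(bg, 1.0, layer, 0.4, 0)         # 0.4 为水印强度
cv2.imwrite("watermarked.png", out)
```

中文水印无法用 OpenCV 自带字体渲染，可以先用 PIL 的 `ImageFont` 加载 `utils/font/` 中的 TTF 字体绘制文字，再与背景融合。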
49 | 50 | # 注意事项!!! 51 |
52 | 53 | - DocDiff的默认配置参数,训练和推理策略是为**文档图像设计**的,如果要用于自然场景,想获得更好的效果,需要**调整参数**,比如扩大模型,添加Self-Attention等(因为文档图像的模式相对固定,但是自然场景的模式比较多样需要更多的参数)并修改**训练和推理策略**。 54 | - **训练策略**:如论文所述,在文档场景中,因为不追求生成多样性,并且希望尽可能缩减推理时间。所以我们将扩散步长T设为100,并预测 $x_0$ 而不是预测 $\epsilon$。在使用基于通道叠加的引入条件(Coarse Predictor的输出)的方案的前提下,这种策略可以使得在逆向扩散的前几步就可以恢复出较好的 $x_0$ 。在自然场景中,为了更好地重建纹理并追求生成多样性,扩散步长T尽可能大,并要预测 $\epsilon$ 。你只需要修改**conf.yml**中的**PRE_ORI="False"**,即可使用预测 $\epsilon$ 的方案; 修改**conf.yml**中的**TIMESTEPS=1000**,即可使用更大的扩散步长。 55 | - **推理策略**:在文档场景中生成的图像不想带有随机性(短步随机采样会导致文本边缘扭曲),DocDiff执行DDIM[3](#refer-anchor-3)中的确定采样。在自然场景中,随机采样是生成多样性的关键,修改**conf.yml**中的**PRE_ORI="False"**,即可使用随机采样。也就是说,预测 $\epsilon$ 的方案与随机采样是绑定的,而预测 $x_0$ 的方案与确定采样是绑定的。如果你想预测 $x_0$ 并随机采样或者 预测 $\epsilon$ 并确定采样,你需要自己修改代码。DocDiff中确定采样是DDIM中的确定采样,随机采样是DDPM中的随机采样,你可以自己修改代码实现其他采样策略。 56 | - **总结**:应用于不需要生成多样性的任务,比如语义分割,文档增强,使用预测 $x_0$ 的方案,扩散步长T设为100就ok,效果已经很好了;应用于需要生成多样性的任务,比如自然场景去模糊,超分,图像修复等,使用预测 $\epsilon$ 的方案,扩散步长T设为1000。 57 | 58 | # To-do lists 59 | 60 | - [x] 添加训练代码 61 | - [x] 添加推理代码 62 | - [x] 上传预训练模型 63 | - [x] 上传水印合成代码和印章数据集 64 | - [x] 使用DPM_solver减少推理步长(实际用起来,效果一般) 65 | - [x] 上传Inference notebook,方便复现 66 | - [ ] 合成包含更多噪声的文档数据集(比如椒盐噪声,压缩产生的噪声) 67 | - [ ] 多GPU训练 68 | - [ ] DDIM的跳步采样 69 | - [ ] 使用深度可分离卷积压缩模型 70 | - [ ] 在自然场景上训练模型并提供结果和预训练模型 71 | 72 | # Stars over time 73 | 74 | [![Stargazers over time](https://starchart.cc/Royalvice/DocDiff.svg)](https://starchart.cc/Royalvice/DocDiff) 75 | 76 | # 感谢 77 | 78 | - 如果你觉得DocDiff对你有帮助,请给个star,谢谢!🤞😘 79 | - 如果你有任何问题,欢迎提issue,我会尽快回复。 80 | - 如果你想交流,欢迎给我发邮件**viceyzy@foxmail.com**,备注:**DocDiff**。 81 | - 如果你愿意将DocDiff作为你的项目的baseline,欢迎引用我们的论文。 82 | ``` 83 | @inproceedings{yang2023docdiff, 84 | title={DocDiff: Document Enhancement via Residual Diffusion Models}, 85 | author={Yang, Zongyuan and Liu, Baolin and Xxiong, Yongping and Yi, Lan and Wu, Guibin and Tang, Xiaojun and Liu, Ziqi and Zhou, Junjie and Zhang, Xing}, 86 | booktitle={Proceedings of the 31st ACM International Conference on Multimedia}, 87 | pages={2795--2806}, 88 | year={2023} 89 | } 90 | ``` 91 | 92 | # References 93 |
94 | 95 | - [1] Whang J, Delbracio M, Talebi H, et al. Deblurring via stochastic refinement[C]//Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition. 2022: 16293-16303. 96 | 97 |
98 | 99 | - [2] Shang S, Shan Z, Liu G, et al. ResDiff: Combining CNN and Diffusion Model for Image Super-Resolution[J]. arXiv preprint arXiv:2303.08714, 2023. 100 | 101 |
102 | 103 | - [3] Song J, Meng C, Ermon S. Denoising diffusion implicit models[J]. arXiv preprint arXiv:2010.02502, 2020. 104 | 105 |
106 | 107 | - [4] Wu J, Fang H, Zhang Y, et al. MedSegDiff: Medical Image Segmentation with Diffusion Probabilistic Model[J]. arXiv preprint arXiv:2211.00611, 2022. 108 | 109 |
110 | 111 | - [5] Michal Hradiš, Jan Kotera, Pavel Zemčík and Filip Šroubek. Convolutional Neural Networks for Direct Text Deblurring. In Xianghua Xie, Mark W. Jones, and Gary K. L. Tam, editors, Proceedings of the British Machine Vision Conference (BMVC), pages 6.1-6.13. BMVA Press, September 2015. 112 | 113 |
114 | 115 | - [6] I. Pratikakis, K. Zagori, P. Kaddas and B. Gatos, "ICFHR 2018 Competition on Handwritten Document Image Binarization (H-DIBCO 2018)," 2018 16th International Conference on Frontiers in Handwriting Recognition (ICFHR), Niagara Falls, NY, USA, 2018, pp. 489-493, doi: 10.1109/ICFHR-2018.2018.00091. 116 | 117 |
118 | 119 | - [7] I. Pratikakis, K. Zagoris, X. Karagiannis, L. Tsochatzidis, T. Mondal and I. Marthot-Santaniello, "ICDAR 2019 Competition on Document Image Binarization (DIBCO 2019)," 2019 International Conference on Document Analysis and Recognition (ICDAR), Sydney, NSW, Australia, 2019, pp. 1547-1556, doi: 10.1109/ICDAR.2019.00249. 120 | -------------------------------------------------------------------------------- /checksave/denoiser.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Royalvice/DocDiff/f422a3af10657eeabeca90f2e862d88c72725845/checksave/denoiser.pth -------------------------------------------------------------------------------- /checksave/init.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Royalvice/DocDiff/f422a3af10657eeabeca90f2e862d88c72725845/checksave/init.pth -------------------------------------------------------------------------------- /checksave/seal_denoiser.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Royalvice/DocDiff/f422a3af10657eeabeca90f2e862d88c72725845/checksave/seal_denoiser.pth -------------------------------------------------------------------------------- /checksave/seal_init.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Royalvice/DocDiff/f422a3af10657eeabeca90f2e862d88c72725845/checksave/seal_init.pth -------------------------------------------------------------------------------- /conf.yml: -------------------------------------------------------------------------------- 1 | # model 2 | IMAGE_SIZE : [128, 128] # load image size, if it's train mode, it will be randomly cropped to IMAGE_SIZE. If it's test mode, it will be resized to IMAGE_SIZE. 3 | CHANNEL_X : 3 # input channel 4 | CHANNEL_Y : 3 # output channel 5 | TIMESTEPS : 100 # diffusion steps 6 | SCHEDULE : 'linear' # linear or cosine 7 | MODEL_CHANNELS : 32 # basic channels of Unet 8 | NUM_RESBLOCKS : 1 # number of residual blocks 9 | CHANNEL_MULT : [1,2,3,4] # channel multiplier of each layer 10 | NUM_HEADS : 1 11 | 12 | MODE : 1 # 1 Train, 0 Test 13 | PRE_ORI : 'True' # if True, predict $x_0$, else predict $\epsilon$. 
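# Note: as described in the README, predicting $x_0$ (PRE_ORI : 'True') is tied to
# deterministic DDIM-style sampling and works well for documents with TIMESTEPS=100.
# To predict $\epsilon$ instead (stochastic DDPM-style sampling, better suited to
# natural scenes), set PRE_ORI : 'False' and use a larger TIMESTEPS, e.g. 1000.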
14 | 15 | 16 | # train 17 | PATH_GT : '' # path of ground truth 18 | PATH_IMG : '' # path of input 19 | BATCH_SIZE : 32 # training batch size 20 | NUM_WORKERS : 16 # number of workers 21 | ITERATION_MAX : 1000000 # max training iteration 22 | LR : 0.0001 # learning rate 23 | LOSS : 'L2' # L1 or L2 24 | EMA_EVERY : 100 # update EMA every EMA_EVERY iterations 25 | START_EMA : 2000 # start EMA after START_EMA iterations 26 | SAVE_MODEL_EVERY : 10000 # save model every SAVE_MODEL_EVERY iterations 27 | EMA: 'True' # if True, use EMA 28 | CONTINUE_TRAINING : 'False' # if True, continue training 29 | CONTINUE_TRAINING_STEPS : 10000 # continue training from CONTINUE_TRAINING_STEPS 30 | PRETRAINED_PATH_INITIAL_PREDICTOR : '' # path of pretrained initial predictor 31 | PRETRAINED_PATH_DENOISER : '' # path of pretrained denoiser 32 | WEIGHT_SAVE_PATH : './checksave' # path to save model 33 | TRAINING_PATH : './Training' # path of training data 34 | BETA_LOSS : 50 # hyperparameter to balance the pixel loss and the diffusion loss 35 | HIGH_LOW_FREQ : 'True' # if True, training with frequency separation 36 | 37 | 38 | # test 39 | NATIVE_RESOLUTION : 'False' # if True, test with native resolution 40 | DPM_SOLVER : 'False' # if True, test with DPM_solver 41 | DPM_STEP : 20 # DPM_solver step 42 | BATCH_SIZE_VAL : 1 # test batch size 43 | TEST_PATH_GT : '' # path of ground truth 44 | TEST_PATH_IMG : '' # path of input 45 | TEST_INITIAL_PREDICTOR_WEIGHT_PATH : '' # path of initial predictor 46 | TEST_DENOISER_WEIGHT_PATH : '' # path of denoiser 47 | TEST_IMG_SAVE_PATH : './results' # path to save results 48 | -------------------------------------------------------------------------------- /data/docdata.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from torch.utils.data import Dataset 4 | from torchvision.transforms import Compose, ToTensor, RandomAffine, RandomHorizontalFlip, RandomCrop 5 | from PIL import Image 6 | 7 | 8 | def ImageTransform(loadSize): 9 | return {"train": Compose([ 10 | RandomCrop(loadSize, pad_if_needed=True, padding_mode='constant', fill=255), 11 | RandomAffine(10, fill=255), 12 | RandomHorizontalFlip(p=0.2), 13 | ToTensor(), 14 | ]), "test": Compose([ 15 | ToTensor(), 16 | ]), "train_gt": Compose([ 17 | RandomCrop(loadSize, pad_if_needed=True, padding_mode='constant', fill=255), 18 | RandomAffine(10, fill=255), 19 | RandomHorizontalFlip(p=0.2), 20 | ToTensor(), 21 | ])} 22 | 23 | 24 | class DocData(Dataset): 25 | def __init__(self, path_img, path_gt, loadSize, mode=1): 26 | super().__init__() 27 | self.path_gt = path_gt 28 | self.path_img = path_img 29 | self.data_gt = os.listdir(path_gt) 30 | self.data_img = os.listdir(path_img) 31 | self.mode = mode 32 | if mode == 1: 33 | self.ImgTrans = (ImageTransform(loadSize)["train"], ImageTransform(loadSize)["train_gt"]) 34 | else: 35 | self.ImgTrans = ImageTransform(loadSize)["test"] 36 | 37 | def __len__(self): 38 | return len(self.data_gt) 39 | 40 | def __getitem__(self, idx): 41 | 42 | gt = Image.open(os.path.join(self.path_gt, self.data_img[idx])) 43 | img = Image.open(os.path.join(self.path_img, self.data_img[idx])) 44 | img = img.convert('RGB') 45 | gt = gt.convert('RGB') 46 | if self.mode == 1: 47 | seed = torch.random.seed() 48 | torch.random.manual_seed(seed) 49 | img = self.ImgTrans[0](img) 50 | torch.random.manual_seed(seed) 51 | gt = self.ImgTrans[1](gt) 52 | else: 53 | img= self.ImgTrans(img) 54 | gt = self.ImgTrans(gt) 55 | name = self.data_img[idx] 56 | 
return img, gt, name 57 | -------------------------------------------------------------------------------- /demo/0000184_blur.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Royalvice/DocDiff/f422a3af10657eeabeca90f2e862d88c72725845/demo/0000184_blur.png -------------------------------------------------------------------------------- /demo/0000184_orig.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Royalvice/DocDiff/f422a3af10657eeabeca90f2e862d88c72725845/demo/0000184_orig.png -------------------------------------------------------------------------------- /demo/DEGAN_0000184.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Royalvice/DocDiff/f422a3af10657eeabeca90f2e862d88c72725845/demo/DEGAN_0000184.png -------------------------------------------------------------------------------- /demo/DEGAN_0001260.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Royalvice/DocDiff/f422a3af10657eeabeca90f2e862d88c72725845/demo/DEGAN_0001260.png -------------------------------------------------------------------------------- /demo/teaser.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Royalvice/DocDiff/f422a3af10657eeabeca90f2e862d88c72725845/demo/teaser.png -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | from src.config import load_config 2 | from src.train import train, test 3 | import argparse 4 | 5 | 6 | def main(): 7 | parser = argparse.ArgumentParser() 8 | parser.add_argument('--config', type=str, default='conf.yml', help='path to the config.yaml file') 9 | args = parser.parse_args() 10 | config = load_config(args.config) 11 | print('Config loaded') 12 | mode = config.MODE 13 | if mode == 1: 14 | train(config) 15 | else: 16 | test(config) 17 | if __name__ == "__main__": 18 | main() -------------------------------------------------------------------------------- /model/DocDiff.py: -------------------------------------------------------------------------------- 1 | import math 2 | from typing import Optional, Tuple, Union, List 3 | import numpy as np 4 | import torch 5 | from torch import nn 6 | from src.sobel import Sobel, Laplacian 7 | 8 | 9 | class Swish(nn.Module): 10 | """ 11 | ### Swish activation function 12 | $$x \cdot \sigma(x)$$ 13 | """ 14 | 15 | def forward(self, x): 16 | return x * torch.sigmoid(x) 17 | 18 | 19 | class TimeEmbedding(nn.Module): 20 | """ 21 | ### Embeddings for $t$ 22 | """ 23 | 24 | def __init__(self, n_channels: int): 25 | """ 26 | * `n_channels` is the number of dimensions in the embedding 27 | """ 28 | super().__init__() 29 | self.n_channels = n_channels 30 | # First linear layer 31 | self.lin1 = nn.Linear(self.n_channels // 4, self.n_channels) 32 | # Activation 33 | self.act = Swish() 34 | # Second linear layer 35 | self.lin2 = nn.Linear(self.n_channels, self.n_channels) 36 | 37 | def forward(self, t: torch.Tensor): 38 | # Create sinusoidal position embeddings 39 | # [same as those from the transformer](../../transformers/positional_encoding.html) 40 | # 41 | # \begin{align} 42 | # PE^{(1)}_{t,i} &= sin\Bigg(\frac{t}{10000^{\frac{i}{d - 1}}}\Bigg) \\ 43 | # PE^{(2)}_{t,i} &= 
cos\Bigg(\frac{t}{10000^{\frac{i}{d - 1}}}\Bigg) 44 | # \end{align} 45 | # 46 | # where $d$ is `half_dim` 47 | half_dim = self.n_channels // 8 48 | emb = math.log(10_000) / (half_dim - 1) 49 | emb = torch.exp(torch.arange(half_dim, device=t.device) * -emb) 50 | emb = t[:, None] * emb[None, :] 51 | emb = torch.cat((emb.sin(), emb.cos()), dim=1) 52 | 53 | # Transform with the MLP 54 | emb = self.act(self.lin1(emb)) 55 | emb = self.lin2(emb) 56 | 57 | # 58 | return emb 59 | 60 | 61 | class ResidualBlock(nn.Module): 62 | """ 63 | ### Residual block 64 | A residual block has two convolution layers with group normalization. 65 | Each resolution is processed with two residual blocks. 66 | """ 67 | 68 | def __init__(self, in_channels: int, out_channels: int, time_channels: int, 69 | dropout: float = 0.1, is_noise: bool = True): 70 | """ 71 | * `in_channels` is the number of input channels 72 | * `out_channels` is the number of input channels 73 | * `time_channels` is the number channels in the time step ($t$) embeddings 74 | * `n_groups` is the number of groups for [group normalization](../../normalization/group_norm/index.html) 75 | * `dropout` is the dropout rate 76 | """ 77 | super().__init__() 78 | # Group normalization and the first convolution layer 79 | self.is_noise = is_noise 80 | self.act1 = Swish() 81 | self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=(3, 3), padding=(1, 1)) 82 | 83 | # Group normalization and the second convolution layer 84 | 85 | self.act2 = Swish() 86 | self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=(3, 3), padding=(1, 1)) 87 | 88 | # If the number of input channels is not equal to the number of output channels we have to 89 | # project the shortcut connection 90 | if in_channels != out_channels: 91 | self.shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=(1, 1)) 92 | else: 93 | self.shortcut = nn.Identity() 94 | 95 | # Linear layer for time embeddings 96 | if self.is_noise: 97 | self.time_emb = nn.Linear(time_channels, out_channels) 98 | self.time_act = Swish() 99 | 100 | self.dropout = nn.Dropout(dropout) 101 | 102 | def forward(self, x: torch.Tensor, t: torch.Tensor): 103 | """ 104 | * `x` has shape `[batch_size, in_channels, height, width]` 105 | * `t` has shape `[batch_size, time_channels]` 106 | """ 107 | # First convolution layer 108 | h = self.conv1(self.act1(x)) 109 | # Add time embeddings 110 | if self.is_noise: 111 | h += self.time_emb(self.time_act(t))[:, :, None, None] 112 | # Second convolution layer 113 | h = self.conv2(self.dropout(self.act2(h))) 114 | 115 | # Add the shortcut connection and return 116 | return h + self.shortcut(x) 117 | 118 | 119 | class DownBlock(nn.Module): 120 | """ 121 | ### Down block 122 | This combines `ResidualBlock` and `AttentionBlock`. These are used in the first half of U-Net at each resolution. 123 | """ 124 | 125 | def __init__(self, in_channels: int, out_channels: int, time_channels: int, is_noise: bool = True): 126 | super().__init__() 127 | self.res = ResidualBlock(in_channels, out_channels, time_channels, is_noise=is_noise) 128 | 129 | def forward(self, x: torch.Tensor, t: torch.Tensor): 130 | x = self.res(x, t) 131 | return x 132 | 133 | 134 | class UpBlock(nn.Module): 135 | """ 136 | ### Up block 137 | This combines `ResidualBlock` and `AttentionBlock`. These are used in the second half of U-Net at each resolution. 
138 | """ 139 | 140 | def __init__(self, in_channels: int, out_channels: int, time_channels: int, is_noise: bool = True): 141 | super().__init__() 142 | # The input has `in_channels + out_channels` because we concatenate the output of the same resolution 143 | # from the first half of the U-Net 144 | self.res = ResidualBlock(in_channels + out_channels, out_channels, time_channels, is_noise=is_noise) 145 | 146 | def forward(self, x: torch.Tensor, t: torch.Tensor): 147 | x = self.res(x, t) 148 | return x 149 | 150 | 151 | class MiddleBlock(nn.Module): 152 | """ 153 | ### Middle block 154 | It combines a `ResidualBlock`, `AttentionBlock`, followed by another `ResidualBlock`. 155 | This block is applied at the lowest resolution of the U-Net. 156 | """ 157 | 158 | def __init__(self, n_channels: int, time_channels: int, is_noise: bool = True): 159 | super().__init__() 160 | self.res1 = ResidualBlock(n_channels, n_channels, time_channels, is_noise=is_noise) 161 | self.dia1 = nn.Conv2d(n_channels, n_channels, 3, 1, dilation=2, padding=get_pad(16, 3, 1, 2)) 162 | self.dia2 = nn.Conv2d(n_channels, n_channels, 3, 1, dilation=4, padding=get_pad(16, 3, 1, 4)) 163 | self.dia3 = nn.Conv2d(n_channels, n_channels, 3, 1, dilation=8, padding=get_pad(16, 3, 1, 8)) 164 | self.dia4 = nn.Conv2d(n_channels, n_channels, 3, 1, dilation=16, padding=get_pad(16, 3, 1, 16)) 165 | self.res2 = ResidualBlock(n_channels, n_channels, time_channels, is_noise=is_noise) 166 | 167 | def forward(self, x: torch.Tensor, t: torch.Tensor): 168 | x = self.res1(x, t) 169 | x = self.dia1(x) 170 | x = self.dia2(x) 171 | x = self.dia3(x) 172 | x = self.dia4(x) 173 | x = self.res2(x, t) 174 | return x 175 | 176 | 177 | class Upsample(nn.Module): 178 | """ 179 | ### Scale up the feature map by $2 \times$ 180 | """ 181 | 182 | def __init__(self, n_channels): 183 | super().__init__() 184 | self.conv = nn.ConvTranspose2d(n_channels, n_channels, (4, 4), (2, 2), (1, 1)) 185 | 186 | def forward(self, x: torch.Tensor, t: torch.Tensor): 187 | # `t` is not used, but it's kept in the arguments because for the attention layer function signature 188 | # to match with `ResidualBlock`. 189 | _ = t 190 | return self.conv(x) 191 | 192 | 193 | class Downsample(nn.Module): 194 | """ 195 | ### Scale down the feature map by $\frac{1}{2} \times$ 196 | """ 197 | 198 | def __init__(self, n_channels): 199 | super().__init__() 200 | self.conv = nn.Conv2d(n_channels, n_channels, (3, 3), (2, 2), (1, 1)) 201 | 202 | def forward(self, x: torch.Tensor, t: torch.Tensor): 203 | # `t` is not used, but it's kept in the arguments because for the attention layer function signature 204 | # to match with `ResidualBlock`. 205 | _ = t 206 | return self.conv(x) 207 | 208 | 209 | class UNet(nn.Module): 210 | """ 211 | ## U-Net 212 | """ 213 | 214 | def __init__(self, input_channels: int = 2, output_channels: int = 1, n_channels: int = 32, 215 | ch_mults: Union[Tuple[int, ...], List[int]] = (1, 2, 2, 4), 216 | n_blocks: int = 2, is_noise: bool = True): 217 | """ 218 | * `image_channels` is the number of channels in the image. $3$ for RGB. 219 | * `n_channels` is number of channels in the initial feature map that we transform the image into 220 | * `ch_mults` is the list of channel numbers at each resolution. 
The number of channels is `ch_mults[i] * n_channels` 221 | * `is_attn` is a list of booleans that indicate whether to use attention at each resolution 222 | * `n_blocks` is the number of `UpDownBlocks` at each resolution 223 | """ 224 | super().__init__() 225 | 226 | # Number of resolutions 227 | n_resolutions = len(ch_mults) 228 | 229 | # Project image into feature map 230 | self.image_proj = nn.Conv2d(input_channels, n_channels, kernel_size=(3, 3), padding=(1, 1)) 231 | 232 | # Time embedding layer. Time embedding has `n_channels * 4` channels 233 | self.is_noise = is_noise 234 | if is_noise: 235 | self.time_emb = TimeEmbedding(n_channels * 4) 236 | 237 | # #### First half of U-Net - decreasing resolution 238 | down = [] 239 | # Number of channels 240 | out_channels = in_channels = n_channels 241 | # For each resolution 242 | for i in range(n_resolutions): 243 | # Number of output channels at this resolution 244 | out_channels = n_channels * ch_mults[i] 245 | # Add `n_blocks` 246 | for _ in range(n_blocks): 247 | down.append(DownBlock(in_channels, out_channels, n_channels * 4, is_noise=is_noise)) 248 | in_channels = out_channels 249 | # Down sample at all resolutions except the last 250 | if i < n_resolutions - 1: 251 | down.append(Downsample(in_channels)) 252 | 253 | # Combine the set of modules 254 | self.down = nn.ModuleList(down) 255 | 256 | # Middle block 257 | self.middle = MiddleBlock(out_channels, n_channels * 4, is_noise=False) 258 | 259 | # #### Second half of U-Net - increasing resolution 260 | up = [] 261 | # Number of channels 262 | in_channels = out_channels 263 | # For each resolution 264 | for i in reversed(range(n_resolutions)): 265 | # `n_blocks` at the same resolution 266 | out_channels = n_channels * ch_mults[i] 267 | for _ in range(n_blocks): 268 | up.append(UpBlock(in_channels, out_channels, n_channels * 4, is_noise=is_noise)) 269 | # Final block to reduce the number of channels 270 | in_channels = n_channels * (ch_mults[i-1] if i >= 1 else 1) 271 | up.append(UpBlock(in_channels, out_channels, n_channels * 4, is_noise=is_noise)) 272 | in_channels = out_channels 273 | # Up sample at all resolutions except last 274 | if i > 0: 275 | up.append(Upsample(in_channels)) 276 | 277 | # Combine the set of modules 278 | self.up = nn.ModuleList(up) 279 | 280 | self.act = Swish() 281 | self.final = nn.Conv2d(in_channels, output_channels, kernel_size=(3, 3), padding=(1, 1)) 282 | 283 | def forward(self, x: torch.Tensor, t: torch.Tensor=torch.tensor([0]).cuda()): 284 | """ 285 | * `x` has shape `[batch_size, in_channels, height, width]` 286 | * `t` has shape `[batch_size]` 287 | """ 288 | 289 | # Get time-step embeddings 290 | if self.is_noise: 291 | t = self.time_emb(t) 292 | else: 293 | t = None 294 | # Get image projection 295 | x = self.image_proj(x) 296 | 297 | # `h` will store outputs at each resolution for skip connection 298 | h = [x] 299 | # First half of U-Net 300 | for m in self.down: 301 | x = m(x, t) 302 | h.append(x) 303 | 304 | # Middle (bottom) 305 | x = self.middle(x, t) 306 | 307 | # Second half of U-Net 308 | for m in self.up: 309 | if isinstance(m, Upsample): 310 | x = m(x, t) 311 | else: 312 | # Get the skip connection from first half of U-Net and concatenate 313 | s = h.pop() 314 | # print(x.shape, s.shape) 315 | x = torch.cat((x, s), dim=1) 316 | # 317 | x = m(x, t) 318 | 319 | # Final normalization and convolution 320 | return self.final(self.act(x)) 321 | 322 | 9 323 | class DocDiff(nn.Module): 324 | def __init__(self, input_channels: int = 2, 
output_channels: int = 1, n_channels: int = 32, 325 | ch_mults: Union[Tuple[int, ...], List[int]] = (1, 2, 2, 4), 326 | n_blocks: int = 1): 327 | super(DocDiff, self).__init__() 328 | self.denoiser = UNet(input_channels, output_channels, n_channels, ch_mults, n_blocks, is_noise=True) 329 | self.init_predictor = UNet(input_channels//2, output_channels, n_channels, ch_mults, n_blocks, is_noise=False) 330 | # self.init_predictor = UNet(input_channels, output_channels, 2 * n_channels, ch_mults, n_blocks) 331 | 332 | def forward(self, x, condition, t, diffusion): 333 | x_ = self.init_predictor(condition, t) 334 | residual = x - x_ 335 | noisy_image, noise_ref = diffusion.noisy_image(t, residual) 336 | x__ = self.denoiser(torch.cat((noisy_image, x_.clone().detach()), dim=1), t) 337 | return x_, x__, noisy_image, noise_ref 338 | 339 | 340 | class EMA(): 341 | def __init__(self, beta): 342 | super().__init__() 343 | self.beta = beta 344 | 345 | def update_model_average(self, ma_model, current_model): 346 | for current_params, ma_params in zip(current_model.parameters(), ma_model.parameters()): 347 | old_weight, up_weight = ma_params.data, current_params.data 348 | ma_params.data = self.update_average(old_weight, up_weight) 349 | 350 | def update_average(self, old, new): 351 | if old is None: 352 | return new 353 | return old * self.beta + (1 - self.beta) * new 354 | 355 | 356 | def get_pad(in_, ksize, stride, atrous=1): 357 | out_ = np.ceil(float(in_)/stride) 358 | return int(((out_ - 1) * stride + atrous*(ksize-1) + 1 - in_)/2) 359 | 360 | 361 | if __name__ == '__main__': 362 | from src.config import load_config 363 | import argparse 364 | from schedule.diffusionSample import GaussianDiffusion 365 | from schedule.schedule import Schedule 366 | import torchsummary 367 | parser = argparse.ArgumentParser() 368 | parser.add_argument('--config', type=str, default='../conf.yml', help='path to the config.yaml file') 369 | args = parser.parse_args() 370 | config = load_config(args.config) 371 | print('Config loaded') 372 | model = DocDiff(input_channels=config.CHANNEL_X + config.CHANNEL_Y, 373 | output_channels=config.CHANNEL_Y, 374 | n_channels=config.MODEL_CHANNELS, 375 | ch_mults=config.CHANNEL_MULT, 376 | n_blocks=config.NUM_RESBLOCKS) 377 | schedule = Schedule(config.SCHEDULE, config.TIMESTEPS) 378 | diffusion = GaussianDiffusion(model, config.TIMESTEPS, schedule) 379 | model.eval() 380 | print(torchsummary.summary(model.init_predictor.cuda(), [(3, 128, 128)], batch_size=32)) 381 | 382 | -------------------------------------------------------------------------------- /schedule/diffusionSample.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | import torchvision.utils 6 | 7 | 8 | def extract_(a, t, x_shape): 9 | b, *_ = t.shape 10 | out = a.gather(-1, t) 11 | return out.reshape(b, *((1,) * (len(x_shape) - 1))) 12 | 13 | 14 | def extract(v, t, x_shape): 15 | """ 16 | Extract some coefficients at specified timesteps, then reshape to 17 | [batch_size, 1, 1, 1, 1, ...] for broadcasting purposes. 
18 | """ 19 | device = t.device 20 | out = torch.gather(v, index=t, dim=0).float().to(device) 21 | return out.view([t.shape[0]] + [1] * (len(x_shape) - 1)) 22 | 23 | 24 | class GaussianDiffusion(nn.Module): 25 | def __init__(self, model, T, schedule): 26 | super().__init__() 27 | self.visual = False 28 | if self.visual: 29 | self.num = 0 30 | self.model = model 31 | self.T = T 32 | self.schedule = schedule 33 | betas = self.schedule.get_betas() 34 | self.register_buffer('betas', betas.float()) 35 | alphas = 1. - self.betas 36 | alphas_bar = torch.cumprod(alphas, dim=0) 37 | alphas_bar_prev = F.pad(alphas_bar, [1, 0], value=1)[:T] 38 | gammas = alphas_bar 39 | 40 | self.register_buffer('coeff1', torch.sqrt(1. / alphas)) 41 | self.register_buffer('coeff2', self.coeff1 * (1. - alphas) / torch.sqrt(1. - alphas_bar)) 42 | self.register_buffer('posterior_var', self.betas * (1. - alphas_bar_prev) / (1. - alphas_bar)) 43 | 44 | # calculation for q(y_t|y_{t-1}) 45 | self.register_buffer('gammas', gammas) 46 | self.register_buffer('sqrt_one_minus_gammas', np.sqrt(1 - gammas)) 47 | self.register_buffer('sqrt_gammas', np.sqrt(gammas)) 48 | 49 | def predict_xt_prev_mean_from_eps(self, x_t, t, eps): 50 | assert x_t.shape == eps.shape 51 | return extract(self.coeff1, t, x_t.shape) * x_t - extract(self.coeff2, t, x_t.shape) * eps 52 | 53 | def p_mean_variance(self, x_t, cond_, t): 54 | # below: only log_variance is used in the KL computations 55 | var = torch.cat([self.posterior_var[1:2], self.betas[1:]]) 56 | #var = self.betas 57 | var = extract(var, t, x_t.shape) 58 | eps = self.model(torch.cat((x_t, cond_), dim=1), t) 59 | # nonEps = self.model(x_t, t, torch.zeros_like(labels).to(labels.device)) 60 | # eps = (1. + self.w) * eps - self.w * nonEps 61 | xt_prev_mean = self.predict_xt_prev_mean_from_eps(x_t, t, eps=eps) 62 | return xt_prev_mean, var 63 | 64 | def noisy_image(self, t, y): 65 | """ Compute y_noisy according to (6) p15 of [2]""" 66 | noise = torch.randn_like(y) 67 | y_noisy = extract_(self.sqrt_gammas, t, y.shape) * y + extract_(self.sqrt_one_minus_gammas, t, noise.shape) * noise 68 | return y_noisy, noise 69 | 70 | def forward(self, x_T, cond, pre_ori='False'): 71 | """ 72 | Algorithm 2. 73 | """ 74 | x_t = x_T 75 | cond_ = cond 76 | for time_step in reversed(range(self.T)): 77 | print("time_step: ", time_step) 78 | t = x_t.new_ones([x_T.shape[0], ], dtype=torch.long) * time_step 79 | if pre_ori == 'False': 80 | mean, var = self.p_mean_variance(x_t=x_t, t=t, cond_=cond_) 81 | if time_step > 0: 82 | noise = torch.randn_like(x_t) 83 | else: 84 | noise = 0 85 | x_t = mean + torch.sqrt(var) * noise 86 | assert torch.isnan(x_t).int().sum() == 0, "nan in tensor." 
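            # Note: the branch above is stochastic DDPM sampling for the
            # predict-epsilon scheme (pre_ori == 'False'); the branch below is the
            # deterministic DDIM-style update used for the predict-x0 scheme:
            #   eps_hat = (x_t - sqrt(gamma_t) * x0_hat) / sqrt(1 - gamma_t)
            #   x_{t-1} = sqrt(gamma_{t-1}) * x0_hat + sqrt(1 - gamma_{t-1}) * eps_hat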
87 | else: 88 | if time_step > 0: 89 | ori = self.model(torch.cat((x_t, cond_), dim=1), t) 90 | eps = x_t - extract_(self.sqrt_gammas, t, ori.shape) * ori 91 | eps = eps / extract_(self.sqrt_one_minus_gammas, t, eps.shape) 92 | x_t = extract_(self.sqrt_gammas, t - 1, ori.shape) * ori + extract_(self.sqrt_one_minus_gammas, t - 1, eps.shape) * eps 93 | else: 94 | x_t = self.model(torch.cat((x_t, cond_), dim=1), t) 95 | 96 | x_0 = x_t 97 | return x_0 98 | 99 | 100 | if __name__ == '__main__': 101 | from schedule import Schedule 102 | test = GaussianDiffusion(None, 100, Schedule('linear', 100)) 103 | print(test.gammas) 104 | -------------------------------------------------------------------------------- /schedule/dpm_solver_pytorch.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import math 4 | 5 | 6 | class NoiseScheduleVP: 7 | def __init__( 8 | self, 9 | schedule='discrete', 10 | betas=None, 11 | alphas_cumprod=None, 12 | continuous_beta_0=0.1, 13 | continuous_beta_1=20., 14 | dtype=torch.float32, 15 | ): 16 | """Create a wrapper class for the forward SDE (VP type). 17 | 18 | *** 19 | Update: We support discrete-time diffusion models by implementing a picewise linear interpolation for log_alpha_t. 20 | We recommend to use schedule='discrete' for the discrete-time diffusion models, especially for high-resolution images. 21 | *** 22 | 23 | The forward SDE ensures that the condition distribution q_{t|0}(x_t | x_0) = N ( alpha_t * x_0, sigma_t^2 * I ). 24 | We further define lambda_t = log(alpha_t) - log(sigma_t), which is the half-logSNR (described in the DPM-Solver paper). 25 | Therefore, we implement the functions for computing alpha_t, sigma_t and lambda_t. For t in [0, T], we have: 26 | 27 | log_alpha_t = self.marginal_log_mean_coeff(t) 28 | sigma_t = self.marginal_std(t) 29 | lambda_t = self.marginal_lambda(t) 30 | 31 | Moreover, as lambda(t) is an invertible function, we also support its inverse function: 32 | 33 | t = self.inverse_lambda(lambda_t) 34 | 35 | =============================================================== 36 | 37 | We support both discrete-time DPMs (trained on n = 0, 1, ..., N-1) and continuous-time DPMs (trained on t in [t_0, T]). 38 | 39 | 1. For discrete-time DPMs: 40 | 41 | For discrete-time DPMs trained on n = 0, 1, ..., N-1, we convert the discrete steps to continuous time steps by: 42 | t_i = (i + 1) / N 43 | e.g. for N = 1000, we have t_0 = 1e-3 and T = t_{N-1} = 1. 44 | We solve the corresponding diffusion ODE from time T = 1 to time t_0 = 1e-3. 45 | 46 | Args: 47 | betas: A `torch.Tensor`. The beta array for the discrete-time DPM. (See the original DDPM paper for details) 48 | alphas_cumprod: A `torch.Tensor`. The cumprod alphas for the discrete-time DPM. (See the original DDPM paper for details) 49 | 50 | Note that we always have alphas_cumprod = cumprod(1 - betas). Therefore, we only need to set one of `betas` and `alphas_cumprod`. 51 | 52 | **Important**: Please pay special attention for the args for `alphas_cumprod`: 53 | The `alphas_cumprod` is the \hat{alpha_n} arrays in the notations of DDPM. Specifically, DDPMs assume that 54 | q_{t_n | 0}(x_{t_n} | x_0) = N ( \sqrt{\hat{alpha_n}} * x_0, (1 - \hat{alpha_n}) * I ). 55 | Therefore, the notation \hat{alpha_n} is different from the notation alpha_t in DPM-Solver. In fact, we have 56 | alpha_{t_n} = \sqrt{\hat{alpha_n}}, 57 | and 58 | log(alpha_{t_n}) = 0.5 * log(\hat{alpha_n}). 59 | 60 | 61 | 2. 
For continuous-time DPMs: 62 | 63 | We support two types of VPSDEs: linear (DDPM) and cosine (improved-DDPM). The hyperparameters for the noise 64 | schedule are the default settings in DDPM and improved-DDPM: 65 | 66 | Args: 67 | beta_min: A `float` number. The smallest beta for the linear schedule. 68 | beta_max: A `float` number. The largest beta for the linear schedule. 69 | cosine_s: A `float` number. The hyperparameter in the cosine schedule. 70 | cosine_beta_max: A `float` number. The hyperparameter in the cosine schedule. 71 | T: A `float` number. The ending time of the forward process. 72 | 73 | =============================================================== 74 | 75 | Args: 76 | schedule: A `str`. The noise schedule of the forward SDE. 'discrete' for discrete-time DPMs, 77 | 'linear' or 'cosine' for continuous-time DPMs. 78 | Returns: 79 | A wrapper object of the forward SDE (VP type). 80 | 81 | =============================================================== 82 | 83 | Example: 84 | 85 | # For discrete-time DPMs, given betas (the beta array for n = 0, 1, ..., N - 1): 86 | >>> ns = NoiseScheduleVP('discrete', betas=betas) 87 | 88 | # For discrete-time DPMs, given alphas_cumprod (the \hat{alpha_n} array for n = 0, 1, ..., N - 1): 89 | >>> ns = NoiseScheduleVP('discrete', alphas_cumprod=alphas_cumprod) 90 | 91 | # For continuous-time DPMs (VPSDE), linear schedule: 92 | >>> ns = NoiseScheduleVP('linear', continuous_beta_0=0.1, continuous_beta_1=20.) 93 | 94 | """ 95 | 96 | if schedule not in ['discrete', 'linear', 'cosine']: 97 | raise ValueError("Unsupported noise schedule {}. The schedule needs to be 'discrete' or 'linear' or 'cosine'".format(schedule)) 98 | 99 | self.schedule = schedule 100 | if schedule == 'discrete': 101 | if betas is not None: 102 | log_alphas = 0.5 * torch.log(1 - betas).cumsum(dim=0) 103 | else: 104 | assert alphas_cumprod is not None 105 | log_alphas = 0.5 * torch.log(alphas_cumprod) 106 | self.total_N = len(log_alphas) 107 | self.T = 1. 108 | self.t_array = torch.linspace(0., 1., self.total_N + 1)[1:].reshape((1, -1)).to(dtype=dtype) 109 | self.log_alpha_array = log_alphas.reshape((1, -1,)).to(dtype=dtype) 110 | else: 111 | self.total_N = 1000 112 | self.beta_0 = continuous_beta_0 113 | self.beta_1 = continuous_beta_1 114 | self.cosine_s = 0.008 115 | self.cosine_beta_max = 999. 116 | self.cosine_t_max = math.atan(self.cosine_beta_max * (1. + self.cosine_s) / math.pi) * 2. * (1. + self.cosine_s) / math.pi - self.cosine_s 117 | self.cosine_log_alpha_0 = math.log(math.cos(self.cosine_s / (1. + self.cosine_s) * math.pi / 2.)) 118 | self.schedule = schedule 119 | if schedule == 'cosine': 120 | # For the cosine schedule, T = 1 will have numerical issues. So we manually set the ending time T. 121 | # Note that T = 0.9946 may be not the optimal setting. However, we find it works well. 122 | self.T = 0.9946 123 | else: 124 | self.T = 1. 125 | 126 | def marginal_log_mean_coeff(self, t): 127 | """ 128 | Compute log(alpha_t) of a given continuous-time label t in [0, T]. 129 | """ 130 | if self.schedule == 'discrete': 131 | return interpolate_fn(t.reshape((-1, 1)), self.t_array.to(t.device), self.log_alpha_array.to(t.device)).reshape((-1)) 132 | elif self.schedule == 'linear': 133 | return -0.25 * t ** 2 * (self.beta_1 - self.beta_0) - 0.5 * t * self.beta_0 134 | elif self.schedule == 'cosine': 135 | log_alpha_fn = lambda s: torch.log(torch.cos((s + self.cosine_s) / (1. 
+ self.cosine_s) * math.pi / 2.)) 136 | log_alpha_t = log_alpha_fn(t) - self.cosine_log_alpha_0 137 | return log_alpha_t 138 | 139 | def marginal_alpha(self, t): 140 | """ 141 | Compute alpha_t of a given continuous-time label t in [0, T]. 142 | """ 143 | return torch.exp(self.marginal_log_mean_coeff(t)) 144 | 145 | def marginal_std(self, t): 146 | """ 147 | Compute sigma_t of a given continuous-time label t in [0, T]. 148 | """ 149 | return torch.sqrt(1. - torch.exp(2. * self.marginal_log_mean_coeff(t))) 150 | 151 | def marginal_lambda(self, t): 152 | """ 153 | Compute lambda_t = log(alpha_t) - log(sigma_t) of a given continuous-time label t in [0, T]. 154 | """ 155 | log_mean_coeff = self.marginal_log_mean_coeff(t) 156 | log_std = 0.5 * torch.log(1. - torch.exp(2. * log_mean_coeff)) 157 | return log_mean_coeff - log_std 158 | 159 | def inverse_lambda(self, lamb): 160 | """ 161 | Compute the continuous-time label t in [0, T] of a given half-logSNR lambda_t. 162 | """ 163 | if self.schedule == 'linear': 164 | tmp = 2. * (self.beta_1 - self.beta_0) * torch.logaddexp(-2. * lamb, torch.zeros((1,)).to(lamb)) 165 | Delta = self.beta_0**2 + tmp 166 | return tmp / (torch.sqrt(Delta) + self.beta_0) / (self.beta_1 - self.beta_0) 167 | elif self.schedule == 'discrete': 168 | log_alpha = -0.5 * torch.logaddexp(torch.zeros((1,)).to(lamb.device), -2. * lamb) 169 | t = interpolate_fn(log_alpha.reshape((-1, 1)), torch.flip(self.log_alpha_array.to(lamb.device), [1]), torch.flip(self.t_array.to(lamb.device), [1])) 170 | return t.reshape((-1,)) 171 | else: 172 | log_alpha = -0.5 * torch.logaddexp(-2. * lamb, torch.zeros((1,)).to(lamb)) 173 | t_fn = lambda log_alpha_t: torch.arccos(torch.exp(log_alpha_t + self.cosine_log_alpha_0)) * 2. * (1. + self.cosine_s) / math.pi - self.cosine_s 174 | t = t_fn(log_alpha) 175 | return t 176 | 177 | 178 | def model_wrapper( 179 | model, 180 | noise_schedule, 181 | model_type="noise", 182 | model_kwargs={}, 183 | guidance_type="uncond", 184 | condition=None, 185 | unconditional_condition=None, 186 | guidance_scale=1., 187 | classifier_fn=None, 188 | classifier_kwargs={}, 189 | ): 190 | """Create a wrapper function for the noise prediction model. 191 | 192 | DPM-Solver needs to solve the continuous-time diffusion ODEs. For DPMs trained on discrete-time labels, we need to 193 | firstly wrap the model function to a noise prediction model that accepts the continuous time as the input. 194 | 195 | We support four types of the diffusion model by setting `model_type`: 196 | 197 | 1. "noise": noise prediction model. (Trained by predicting noise). 198 | 199 | 2. "x_start": data prediction model. (Trained by predicting the data x_0 at time 0). 200 | 201 | 3. "v": velocity prediction model. (Trained by predicting the velocity). 202 | The "v" prediction is derivation detailed in Appendix D of [1], and is used in Imagen-Video [2]. 203 | 204 | [1] Salimans, Tim, and Jonathan Ho. "Progressive distillation for fast sampling of diffusion models." 205 | arXiv preprint arXiv:2202.00512 (2022). 206 | [2] Ho, Jonathan, et al. "Imagen Video: High Definition Video Generation with Diffusion Models." 207 | arXiv preprint arXiv:2210.02303 (2022). 208 | 209 | 4. "score": marginal score function. (Trained by denoising score matching). 210 | Note that the score function and the noise prediction model follows a simple relationship: 211 | ``` 212 | noise(x_t, t) = -sigma_t * score(x_t, t) 213 | ``` 214 | 215 | We support three types of guided sampling by DPMs by setting `guidance_type`: 216 | 1. 
"uncond": unconditional sampling by DPMs. 217 | The input `model` has the following format: 218 | `` 219 | model(x, t_input, **model_kwargs) -> noise | x_start | v | score 220 | `` 221 | 222 | 2. "classifier": classifier guidance sampling [3] by DPMs and another classifier. 223 | The input `model` has the following format: 224 | `` 225 | model(x, t_input, **model_kwargs) -> noise | x_start | v | score 226 | `` 227 | 228 | The input `classifier_fn` has the following format: 229 | `` 230 | classifier_fn(x, t_input, cond, **classifier_kwargs) -> logits(x, t_input, cond) 231 | `` 232 | 233 | [3] P. Dhariwal and A. Q. Nichol, "Diffusion models beat GANs on image synthesis," 234 | in Advances in Neural Information Processing Systems, vol. 34, 2021, pp. 8780-8794. 235 | 236 | 3. "classifier-free": classifier-free guidance sampling by conditional DPMs. 237 | The input `model` has the following format: 238 | `` 239 | model(x, t_input, cond, **model_kwargs) -> noise | x_start | v | score 240 | `` 241 | And if cond == `unconditional_condition`, the model output is the unconditional DPM output. 242 | 243 | [4] Ho, Jonathan, and Tim Salimans. "Classifier-free diffusion guidance." 244 | arXiv preprint arXiv:2207.12598 (2022). 245 | 246 | 247 | The `t_input` is the time label of the model, which may be discrete-time labels (i.e. 0 to 999) 248 | or continuous-time labels (i.e. epsilon to T). 249 | 250 | We wrap the model function to accept only `x` and `t_continuous` as inputs, and outputs the predicted noise: 251 | `` 252 | def model_fn(x, t_continuous) -> noise: 253 | t_input = get_model_input_time(t_continuous) 254 | return noise_pred(model, x, t_input, **model_kwargs) 255 | `` 256 | where `t_continuous` is the continuous time labels (i.e. epsilon to T). And we use `model_fn` for DPM-Solver. 257 | 258 | =============================================================== 259 | 260 | Args: 261 | model: A diffusion model with the corresponding format described above. 262 | noise_schedule: A noise schedule object, such as NoiseScheduleVP. 263 | model_type: A `str`. The parameterization type of the diffusion model. 264 | "noise" or "x_start" or "v" or "score". 265 | model_kwargs: A `dict`. A dict for the other inputs of the model function. 266 | guidance_type: A `str`. The type of the guidance for sampling. 267 | "uncond" or "classifier" or "classifier-free". 268 | condition: A pytorch tensor. The condition for the guided sampling. 269 | Only used for "classifier" or "classifier-free" guidance type. 270 | unconditional_condition: A pytorch tensor. The condition for the unconditional sampling. 271 | Only used for "classifier-free" guidance type. 272 | guidance_scale: A `float`. The scale for the guided sampling. 273 | classifier_fn: A classifier function. Only used for the classifier guidance. 274 | classifier_kwargs: A `dict`. A dict for the other inputs of the classifier function. 275 | Returns: 276 | A noise prediction model that accepts the noised data and the continuous time as the inputs. 277 | """ 278 | 279 | def get_model_input_time(t_continuous): 280 | """ 281 | Convert the continuous-time `t_continuous` (in [epsilon, T]) to the model input time. 282 | For discrete-time DPMs, we convert `t_continuous` in [1 / N, 1] to `t_input` in [0, 1000 * (N - 1) / N]. 283 | For continuous-time DPMs, we just use `t_continuous`. 284 | """ 285 | if noise_schedule.schedule == 'discrete': 286 | return (t_continuous - 1. / noise_schedule.total_N) * 1000. 
287 |         else:
288 |             return t_continuous
289 | 
290 |     def noise_pred_fn(x, t_continuous, cond=None):
291 |         t_input = get_model_input_time(t_continuous)
292 |         if cond is None:
293 |             output = model(x, t_input, **model_kwargs)
294 |         else:
295 |             output = model(x, t_input, cond, **model_kwargs)
296 |         if model_type == "noise":
297 |             return output
298 |         elif model_type == "x_start":
299 |             alpha_t, sigma_t = noise_schedule.marginal_alpha(t_continuous), noise_schedule.marginal_std(t_continuous)
300 |             return (x - alpha_t * output) / sigma_t
301 |         elif model_type == "v":
302 |             alpha_t, sigma_t = noise_schedule.marginal_alpha(t_continuous), noise_schedule.marginal_std(t_continuous)
303 |             return alpha_t * output + sigma_t * x
304 |         elif model_type == "score":
305 |             sigma_t = noise_schedule.marginal_std(t_continuous)
306 |             return -sigma_t * output
307 | 
308 |     def cond_grad_fn(x, t_input):
309 |         """
310 |         Compute the gradient of the classifier, i.e. nabla_{x} log p_t(cond | x_t).
311 |         """
312 |         with torch.enable_grad():
313 |             x_in = x.detach().requires_grad_(True)
314 |             log_prob = classifier_fn(x_in, t_input, condition, **classifier_kwargs)
315 |             return torch.autograd.grad(log_prob.sum(), x_in)[0]
316 | 
317 |     def model_fn(x, t_continuous):
318 |         """
319 |         The noise prediction model function that is used for DPM-Solver.
320 |         """
321 |         if guidance_type == "uncond":
322 |             return noise_pred_fn(x, t_continuous)
323 |         elif guidance_type == "classifier":
324 |             assert classifier_fn is not None
325 |             t_input = get_model_input_time(t_continuous)
326 |             cond_grad = cond_grad_fn(x, t_input)
327 |             sigma_t = noise_schedule.marginal_std(t_continuous)
328 |             noise = noise_pred_fn(x, t_continuous)
329 |             return noise - guidance_scale * sigma_t * cond_grad
330 |         elif guidance_type == "classifier-free":
331 |             if guidance_scale == 1. or unconditional_condition is None:
332 |                 return noise_pred_fn(x, t_continuous, cond=condition)
333 |             else:
334 |                 x_in = torch.cat([x] * 2)
335 |                 t_in = torch.cat([t_continuous] * 2)
336 |                 c_in = torch.cat([unconditional_condition, condition])
337 |                 noise_uncond, noise = noise_pred_fn(x_in, t_in, cond=c_in).chunk(2)
338 |                 return noise_uncond + guidance_scale * (noise - noise_uncond)
339 | 
340 |     assert model_type in ["noise", "x_start", "v", "score"]
341 |     assert guidance_type in ["uncond", "classifier", "classifier-free"]
342 |     return model_fn
343 | 
344 | 
345 | class DPM_Solver:
346 |     def __init__(
347 |         self,
348 |         model_fn,
349 |         noise_schedule,
350 |         algorithm_type="dpmsolver++",
351 |         correcting_x0_fn=None,
352 |         correcting_xt_fn=None,
353 |         thresholding_max_val=1.,
354 |         dynamic_thresholding_ratio=0.995,
355 |     ):
356 |         """Construct a DPM-Solver.
357 | 
358 |         We support both DPM-Solver (`algorithm_type="dpmsolver"`) and DPM-Solver++ (`algorithm_type="dpmsolver++"`).
359 | 
360 |         We also support the "dynamic thresholding" method in Imagen[1]. For pixel-space diffusion models, you
361 |         can set both `algorithm_type="dpmsolver++"` and `correcting_x0_fn="dynamic_thresholding"` to use the
362 |         dynamic thresholding. The "dynamic thresholding" can greatly improve the sample quality for pixel-space
363 |         DPMs with large guidance scales. Note that the thresholding method is **unsuitable** for latent-space
364 |         DPMs (such as stable-diffusion).
365 | 
366 |         To support advanced algorithms in image-to-image applications, we also support corrector functions for
367 |         both x0 and xt.
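        A minimal construction sketch (illustrative; `betas` and `model` are assumed to be
        defined elsewhere, e.g. a beta schedule and a noise prediction network):
        ```
        ns = NoiseScheduleVP('discrete', betas=betas)
        model_fn = model_wrapper(model, ns, model_type="noise")
        solver = DPM_Solver(model_fn, ns, algorithm_type="dpmsolver++")
        ```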
368 | 369 | Args: 370 | model_fn: A noise prediction model function which accepts the continuous-time input (t in [epsilon, T]): 371 | `` 372 | def model_fn(x, t_continuous): 373 | return noise 374 | `` 375 | The shape of `x` is `(batch_size, **shape)`, and the shape of `t_continuous` is `(batch_size,)`. 376 | noise_schedule: A noise schedule object, such as NoiseScheduleVP. 377 | algorithm_type: A `str`. Either "dpmsolver" or "dpmsolver++". 378 | correcting_x0_fn: A `str` or a function with the following format: 379 | ``` 380 | def correcting_x0_fn(x0, t): 381 | x0_new = ... 382 | return x0_new 383 | ``` 384 | This function is to correct the outputs of the data prediction model at each sampling step. e.g., 385 | ``` 386 | x0_pred = data_pred_model(xt, t) 387 | if correcting_x0_fn is not None: 388 | x0_pred = correcting_x0_fn(x0_pred, t) 389 | xt_1 = update(x0_pred, xt, t) 390 | ``` 391 | If `correcting_x0_fn="dynamic_thresholding"`, we use the dynamic thresholding proposed in Imagen[1]. 392 | correcting_xt_fn: A function with the following format: 393 | ``` 394 | def correcting_xt_fn(xt, t, step): 395 | x_new = ... 396 | return x_new 397 | ``` 398 | This function is to correct the intermediate samples xt at each sampling step. e.g., 399 | ``` 400 | xt = ... 401 | xt = correcting_xt_fn(xt, t, step) 402 | ``` 403 | thresholding_max_val: A `float`. The max value for thresholding. 404 | Valid only when use `dpmsolver++` and `correcting_x0_fn="dynamic_thresholding"`. 405 | dynamic_thresholding_ratio: A `float`. The ratio for dynamic thresholding (see Imagen[1] for details). 406 | Valid only when use `dpmsolver++` and `correcting_x0_fn="dynamic_thresholding"`. 407 | 408 | [1] Chitwan Saharia, William Chan, Saurabh Saxena, Lala Li, Jay Whang, Emily Denton, Seyed Kamyar Seyed Ghasemipour, 409 | Burcu Karagol Ayan, S Sara Mahdavi, Rapha Gontijo Lopes, et al. Photorealistic text-to-image diffusion models 410 | with deep language understanding. arXiv preprint arXiv:2205.11487, 2022b. 411 | """ 412 | self.model = lambda x, t: model_fn(x, t.expand((x.shape[0]))) 413 | self.noise_schedule = noise_schedule 414 | assert algorithm_type in ["dpmsolver", "dpmsolver++"] 415 | self.algorithm_type = algorithm_type 416 | if correcting_x0_fn == "dynamic_thresholding": 417 | self.correcting_x0_fn = self.dynamic_thresholding_fn 418 | else: 419 | self.correcting_x0_fn = correcting_x0_fn 420 | self.correcting_xt_fn = correcting_xt_fn 421 | self.dynamic_thresholding_ratio = dynamic_thresholding_ratio 422 | self.thresholding_max_val = thresholding_max_val 423 | 424 | def dynamic_thresholding_fn(self, x0, t): 425 | """ 426 | The dynamic thresholding method. 427 | """ 428 | dims = x0.dim() 429 | p = self.dynamic_thresholding_ratio 430 | s = torch.quantile(torch.abs(x0).reshape((x0.shape[0], -1)), p, dim=1) 431 | s = expand_dims(torch.maximum(s, self.thresholding_max_val * torch.ones_like(s).to(s.device)), dims) 432 | x0 = torch.clamp(x0, -s, s) / s 433 | return x0 434 | 435 | def noise_prediction_fn(self, x, t): 436 | """ 437 | Return the noise prediction model. 438 | """ 439 | return self.model(x, t) 440 | 441 | def data_prediction_fn(self, x, t): 442 | """ 443 | Return the data prediction model (with corrector). 
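        Concretely (restating the code below): x0 = (x - sigma_t * noise(x, t)) / alpha_t,
        optionally followed by `correcting_x0_fn` (e.g. dynamic thresholding).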
444 | """ 445 | noise = self.noise_prediction_fn(x, t) 446 | alpha_t, sigma_t = self.noise_schedule.marginal_alpha(t), self.noise_schedule.marginal_std(t) 447 | x0 = (x - sigma_t * noise) / alpha_t 448 | if self.correcting_x0_fn is not None: 449 | x0 = self.correcting_x0_fn(x0, t) 450 | return x0 451 | 452 | def model_fn(self, x, t): 453 | """ 454 | Convert the model to the noise prediction model or the data prediction model. 455 | """ 456 | if self.algorithm_type == "dpmsolver++": 457 | return self.data_prediction_fn(x, t) 458 | else: 459 | return self.noise_prediction_fn(x, t) 460 | 461 | def get_time_steps(self, skip_type, t_T, t_0, N, device): 462 | """Compute the intermediate time steps for sampling. 463 | 464 | Args: 465 | skip_type: A `str`. The type for the spacing of the time steps. We support three types: 466 | - 'logSNR': uniform logSNR for the time steps. 467 | - 'time_uniform': uniform time for the time steps. (**Recommended for high-resolutional data**.) 468 | - 'time_quadratic': quadratic time for the time steps. (Used in DDIM for low-resolutional data.) 469 | t_T: A `float`. The starting time of the sampling (default is T). 470 | t_0: A `float`. The ending time of the sampling (default is epsilon). 471 | N: A `int`. The total number of the spacing of the time steps. 472 | device: A torch device. 473 | Returns: 474 | A pytorch tensor of the time steps, with the shape (N + 1,). 475 | """ 476 | if skip_type == 'logSNR': 477 | lambda_T = self.noise_schedule.marginal_lambda(torch.tensor(t_T).to(device)) 478 | lambda_0 = self.noise_schedule.marginal_lambda(torch.tensor(t_0).to(device)) 479 | logSNR_steps = torch.linspace(lambda_T.cpu().item(), lambda_0.cpu().item(), N + 1).to(device) 480 | return self.noise_schedule.inverse_lambda(logSNR_steps) 481 | elif skip_type == 'time_uniform': 482 | return torch.linspace(t_T, t_0, N + 1).to(device) 483 | elif skip_type == 'time_quadratic': 484 | t_order = 2 485 | t = torch.linspace(t_T**(1. / t_order), t_0**(1. / t_order), N + 1).pow(t_order).to(device) 486 | return t 487 | else: 488 | raise ValueError("Unsupported skip_type {}, need to be 'logSNR' or 'time_uniform' or 'time_quadratic'".format(skip_type)) 489 | 490 | def get_orders_and_timesteps_for_singlestep_solver(self, steps, order, skip_type, t_T, t_0, device): 491 | """ 492 | Get the order of each step for sampling by the singlestep DPM-Solver. 493 | 494 | We combine both DPM-Solver-1,2,3 to use all the function evaluations, which is named as "DPM-Solver-fast". 495 | Given a fixed number of function evaluations by `steps`, the sampling procedure by DPM-Solver-fast is: 496 | - If order == 1: 497 | We take `steps` of DPM-Solver-1 (i.e. DDIM). 498 | - If order == 2: 499 | - Denote K = (steps // 2). We take K or (K + 1) intermediate time steps for sampling. 500 | - If steps % 2 == 0, we use K steps of DPM-Solver-2. 501 | - If steps % 2 == 1, we use K steps of DPM-Solver-2 and 1 step of DPM-Solver-1. 502 | - If order == 3: 503 | - Denote K = (steps // 3 + 1). We take K intermediate time steps for sampling. 504 | - If steps % 3 == 0, we use (K - 2) steps of DPM-Solver-3, and 1 step of DPM-Solver-2 and 1 step of DPM-Solver-1. 505 | - If steps % 3 == 1, we use (K - 1) steps of DPM-Solver-3 and 1 step of DPM-Solver-1. 506 | - If steps % 3 == 2, we use (K - 1) steps of DPM-Solver-3 and 1 step of DPM-Solver-2. 507 | 508 | ============================================ 509 | Args: 510 | order: A `int`. The max order for the solver (2 or 3). 511 | steps: A `int`. 
The total number of function evaluations (NFE). 512 | skip_type: A `str`. The type for the spacing of the time steps. We support three types: 513 | - 'logSNR': uniform logSNR for the time steps. 514 | - 'time_uniform': uniform time for the time steps. (**Recommended for high-resolutional data**.) 515 | - 'time_quadratic': quadratic time for the time steps. (Used in DDIM for low-resolutional data.) 516 | t_T: A `float`. The starting time of the sampling (default is T). 517 | t_0: A `float`. The ending time of the sampling (default is epsilon). 518 | device: A torch device. 519 | Returns: 520 | orders: A list of the solver order of each step. 521 | """ 522 | if order == 3: 523 | K = steps // 3 + 1 524 | if steps % 3 == 0: 525 | orders = [3,] * (K - 2) + [2, 1] 526 | elif steps % 3 == 1: 527 | orders = [3,] * (K - 1) + [1] 528 | else: 529 | orders = [3,] * (K - 1) + [2] 530 | elif order == 2: 531 | if steps % 2 == 0: 532 | K = steps // 2 533 | orders = [2,] * K 534 | else: 535 | K = steps // 2 + 1 536 | orders = [2,] * (K - 1) + [1] 537 | elif order == 1: 538 | K = 1 539 | orders = [1,] * steps 540 | else: 541 | raise ValueError("'order' must be '1' or '2' or '3'.") 542 | if skip_type == 'logSNR': 543 | # To reproduce the results in DPM-Solver paper 544 | timesteps_outer = self.get_time_steps(skip_type, t_T, t_0, K, device) 545 | else: 546 | timesteps_outer = self.get_time_steps(skip_type, t_T, t_0, steps, device)[torch.cumsum(torch.tensor([0,] + orders), 0).to(device)] 547 | return timesteps_outer, orders 548 | 549 | def denoise_to_zero_fn(self, x, s): 550 | """ 551 | Denoise at the final step, which is equivalent to solve the ODE from lambda_s to infty by first-order discretization. 552 | """ 553 | return self.data_prediction_fn(x, s) 554 | 555 | def dpm_solver_first_update(self, x, s, t, model_s=None, return_intermediate=False): 556 | """ 557 | DPM-Solver-1 (equivalent to DDIM) from time `s` to time `t`. 558 | 559 | Args: 560 | x: A pytorch tensor. The initial value at time `s`. 561 | s: A pytorch tensor. The starting time, with the shape (1,). 562 | t: A pytorch tensor. The ending time, with the shape (1,). 563 | model_s: A pytorch tensor. The model function evaluated at time `s`. 564 | If `model_s` is None, we evaluate the model by `x` and `s`; otherwise we directly use it. 565 | return_intermediate: A `bool`. If true, also return the model value at time `s`. 566 | Returns: 567 | x_t: A pytorch tensor. The approximated solution at time `t`. 
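
        With h = lambda_t - lambda_s, the update below reads (illustrative restatement):
            dpmsolver++: x_t = (sigma_t / sigma_s) * x - alpha_t * (exp(-h) - 1) * x0_theta(x, s)
            dpmsolver:   x_t = (alpha_t / alpha_s) * x - sigma_t * (exp(h) - 1) * eps_theta(x, s)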
568 | """ 569 | ns = self.noise_schedule 570 | dims = x.dim() 571 | lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t) 572 | h = lambda_t - lambda_s 573 | log_alpha_s, log_alpha_t = ns.marginal_log_mean_coeff(s), ns.marginal_log_mean_coeff(t) 574 | sigma_s, sigma_t = ns.marginal_std(s), ns.marginal_std(t) 575 | alpha_t = torch.exp(log_alpha_t) 576 | 577 | if self.algorithm_type == "dpmsolver++": 578 | phi_1 = torch.expm1(-h) 579 | if model_s is None: 580 | model_s = self.model_fn(x, s) 581 | x_t = ( 582 | sigma_t / sigma_s * x 583 | - alpha_t * phi_1 * model_s 584 | ) 585 | if return_intermediate: 586 | return x_t, {'model_s': model_s} 587 | else: 588 | return x_t 589 | else: 590 | phi_1 = torch.expm1(h) 591 | if model_s is None: 592 | model_s = self.model_fn(x, s) 593 | x_t = ( 594 | torch.exp(log_alpha_t - log_alpha_s) * x 595 | - (sigma_t * phi_1) * model_s 596 | ) 597 | if return_intermediate: 598 | return x_t, {'model_s': model_s} 599 | else: 600 | return x_t 601 | 602 | def singlestep_dpm_solver_second_update(self, x, s, t, r1=0.5, model_s=None, return_intermediate=False, solver_type='dpmsolver'): 603 | """ 604 | Singlestep solver DPM-Solver-2 from time `s` to time `t`. 605 | 606 | Args: 607 | x: A pytorch tensor. The initial value at time `s`. 608 | s: A pytorch tensor. The starting time, with the shape (1,). 609 | t: A pytorch tensor. The ending time, with the shape (1,). 610 | r1: A `float`. The hyperparameter of the second-order solver. 611 | model_s: A pytorch tensor. The model function evaluated at time `s`. 612 | If `model_s` is None, we evaluate the model by `x` and `s`; otherwise we directly use it. 613 | return_intermediate: A `bool`. If true, also return the model value at time `s` and `s1` (the intermediate time). 614 | solver_type: either 'dpmsolver' or 'taylor'. The type for the high-order solvers. 615 | The type slightly impacts the performance. We recommend to use 'dpmsolver' type. 616 | Returns: 617 | x_t: A pytorch tensor. The approximated solution at time `t`. 618 | """ 619 | if solver_type not in ['dpmsolver', 'taylor']: 620 | raise ValueError("'solver_type' must be either 'dpmsolver' or 'taylor', got {}".format(solver_type)) 621 | if r1 is None: 622 | r1 = 0.5 623 | ns = self.noise_schedule 624 | lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t) 625 | h = lambda_t - lambda_s 626 | lambda_s1 = lambda_s + r1 * h 627 | s1 = ns.inverse_lambda(lambda_s1) 628 | log_alpha_s, log_alpha_s1, log_alpha_t = ns.marginal_log_mean_coeff(s), ns.marginal_log_mean_coeff(s1), ns.marginal_log_mean_coeff(t) 629 | sigma_s, sigma_s1, sigma_t = ns.marginal_std(s), ns.marginal_std(s1), ns.marginal_std(t) 630 | alpha_s1, alpha_t = torch.exp(log_alpha_s1), torch.exp(log_alpha_t) 631 | 632 | if self.algorithm_type == "dpmsolver++": 633 | phi_11 = torch.expm1(-r1 * h) 634 | phi_1 = torch.expm1(-h) 635 | 636 | if model_s is None: 637 | model_s = self.model_fn(x, s) 638 | x_s1 = ( 639 | (sigma_s1 / sigma_s) * x 640 | - (alpha_s1 * phi_11) * model_s 641 | ) 642 | model_s1 = self.model_fn(x_s1, s1) 643 | if solver_type == 'dpmsolver': 644 | x_t = ( 645 | (sigma_t / sigma_s) * x 646 | - (alpha_t * phi_1) * model_s 647 | - (0.5 / r1) * (alpha_t * phi_1) * (model_s1 - model_s) 648 | ) 649 | elif solver_type == 'taylor': 650 | x_t = ( 651 | (sigma_t / sigma_s) * x 652 | - (alpha_t * phi_1) * model_s 653 | + (1. 
/ r1) * (alpha_t * (phi_1 / h + 1.)) * (model_s1 - model_s) 654 | ) 655 | else: 656 | phi_11 = torch.expm1(r1 * h) 657 | phi_1 = torch.expm1(h) 658 | 659 | if model_s is None: 660 | model_s = self.model_fn(x, s) 661 | x_s1 = ( 662 | torch.exp(log_alpha_s1 - log_alpha_s) * x 663 | - (sigma_s1 * phi_11) * model_s 664 | ) 665 | model_s1 = self.model_fn(x_s1, s1) 666 | if solver_type == 'dpmsolver': 667 | x_t = ( 668 | torch.exp(log_alpha_t - log_alpha_s) * x 669 | - (sigma_t * phi_1) * model_s 670 | - (0.5 / r1) * (sigma_t * phi_1) * (model_s1 - model_s) 671 | ) 672 | elif solver_type == 'taylor': 673 | x_t = ( 674 | torch.exp(log_alpha_t - log_alpha_s) * x 675 | - (sigma_t * phi_1) * model_s 676 | - (1. / r1) * (sigma_t * (phi_1 / h - 1.)) * (model_s1 - model_s) 677 | ) 678 | if return_intermediate: 679 | return x_t, {'model_s': model_s, 'model_s1': model_s1} 680 | else: 681 | return x_t 682 | 683 | def singlestep_dpm_solver_third_update(self, x, s, t, r1=1./3., r2=2./3., model_s=None, model_s1=None, return_intermediate=False, solver_type='dpmsolver'): 684 | """ 685 | Singlestep solver DPM-Solver-3 from time `s` to time `t`. 686 | 687 | Args: 688 | x: A pytorch tensor. The initial value at time `s`. 689 | s: A pytorch tensor. The starting time, with the shape (1,). 690 | t: A pytorch tensor. The ending time, with the shape (1,). 691 | r1: A `float`. The hyperparameter of the third-order solver. 692 | r2: A `float`. The hyperparameter of the third-order solver. 693 | model_s: A pytorch tensor. The model function evaluated at time `s`. 694 | If `model_s` is None, we evaluate the model by `x` and `s`; otherwise we directly use it. 695 | model_s1: A pytorch tensor. The model function evaluated at time `s1` (the intermediate time given by `r1`). 696 | If `model_s1` is None, we evaluate the model at `s1`; otherwise we directly use it. 697 | return_intermediate: A `bool`. If true, also return the model value at time `s`, `s1` and `s2` (the intermediate times). 698 | solver_type: either 'dpmsolver' or 'taylor'. The type for the high-order solvers. 699 | The type slightly impacts the performance. We recommend to use 'dpmsolver' type. 700 | Returns: 701 | x_t: A pytorch tensor. The approximated solution at time `t`. 702 | """ 703 | if solver_type not in ['dpmsolver', 'taylor']: 704 | raise ValueError("'solver_type' must be either 'dpmsolver' or 'taylor', got {}".format(solver_type)) 705 | if r1 is None: 706 | r1 = 1. / 3. 707 | if r2 is None: 708 | r2 = 2. / 3. 709 | ns = self.noise_schedule 710 | lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t) 711 | h = lambda_t - lambda_s 712 | lambda_s1 = lambda_s + r1 * h 713 | lambda_s2 = lambda_s + r2 * h 714 | s1 = ns.inverse_lambda(lambda_s1) 715 | s2 = ns.inverse_lambda(lambda_s2) 716 | log_alpha_s, log_alpha_s1, log_alpha_s2, log_alpha_t = ns.marginal_log_mean_coeff(s), ns.marginal_log_mean_coeff(s1), ns.marginal_log_mean_coeff(s2), ns.marginal_log_mean_coeff(t) 717 | sigma_s, sigma_s1, sigma_s2, sigma_t = ns.marginal_std(s), ns.marginal_std(s1), ns.marginal_std(s2), ns.marginal_std(t) 718 | alpha_s1, alpha_s2, alpha_t = torch.exp(log_alpha_s1), torch.exp(log_alpha_s2), torch.exp(log_alpha_t) 719 | 720 | if self.algorithm_type == "dpmsolver++": 721 | phi_11 = torch.expm1(-r1 * h) 722 | phi_12 = torch.expm1(-r2 * h) 723 | phi_1 = torch.expm1(-h) 724 | phi_22 = torch.expm1(-r2 * h) / (r2 * h) + 1. 725 | phi_2 = phi_1 / h + 1. 
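            # The phi_k are exponential-integrator coefficients for the data-prediction
            # (dpmsolver++) parameterization: phi_1 = e^{-h} - 1, phi_2 = phi_1 / h + 1,
            # and phi_3 = phi_2 / h - 1/2 below.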
726 | phi_3 = phi_2 / h - 0.5 727 | 728 | if model_s is None: 729 | model_s = self.model_fn(x, s) 730 | if model_s1 is None: 731 | x_s1 = ( 732 | (sigma_s1 / sigma_s) * x 733 | - (alpha_s1 * phi_11) * model_s 734 | ) 735 | model_s1 = self.model_fn(x_s1, s1) 736 | x_s2 = ( 737 | (sigma_s2 / sigma_s) * x 738 | - (alpha_s2 * phi_12) * model_s 739 | + r2 / r1 * (alpha_s2 * phi_22) * (model_s1 - model_s) 740 | ) 741 | model_s2 = self.model_fn(x_s2, s2) 742 | if solver_type == 'dpmsolver': 743 | x_t = ( 744 | (sigma_t / sigma_s) * x 745 | - (alpha_t * phi_1) * model_s 746 | + (1. / r2) * (alpha_t * phi_2) * (model_s2 - model_s) 747 | ) 748 | elif solver_type == 'taylor': 749 | D1_0 = (1. / r1) * (model_s1 - model_s) 750 | D1_1 = (1. / r2) * (model_s2 - model_s) 751 | D1 = (r2 * D1_0 - r1 * D1_1) / (r2 - r1) 752 | D2 = 2. * (D1_1 - D1_0) / (r2 - r1) 753 | x_t = ( 754 | (sigma_t / sigma_s) * x 755 | - (alpha_t * phi_1) * model_s 756 | + (alpha_t * phi_2) * D1 757 | - (alpha_t * phi_3) * D2 758 | ) 759 | else: 760 | phi_11 = torch.expm1(r1 * h) 761 | phi_12 = torch.expm1(r2 * h) 762 | phi_1 = torch.expm1(h) 763 | phi_22 = torch.expm1(r2 * h) / (r2 * h) - 1. 764 | phi_2 = phi_1 / h - 1. 765 | phi_3 = phi_2 / h - 0.5 766 | 767 | if model_s is None: 768 | model_s = self.model_fn(x, s) 769 | if model_s1 is None: 770 | x_s1 = ( 771 | (torch.exp(log_alpha_s1 - log_alpha_s)) * x 772 | - (sigma_s1 * phi_11) * model_s 773 | ) 774 | model_s1 = self.model_fn(x_s1, s1) 775 | x_s2 = ( 776 | (torch.exp(log_alpha_s2 - log_alpha_s)) * x 777 | - (sigma_s2 * phi_12) * model_s 778 | - r2 / r1 * (sigma_s2 * phi_22) * (model_s1 - model_s) 779 | ) 780 | model_s2 = self.model_fn(x_s2, s2) 781 | if solver_type == 'dpmsolver': 782 | x_t = ( 783 | (torch.exp(log_alpha_t - log_alpha_s)) * x 784 | - (sigma_t * phi_1) * model_s 785 | - (1. / r2) * (sigma_t * phi_2) * (model_s2 - model_s) 786 | ) 787 | elif solver_type == 'taylor': 788 | D1_0 = (1. / r1) * (model_s1 - model_s) 789 | D1_1 = (1. / r2) * (model_s2 - model_s) 790 | D1 = (r2 * D1_0 - r1 * D1_1) / (r2 - r1) 791 | D2 = 2. * (D1_1 - D1_0) / (r2 - r1) 792 | x_t = ( 793 | (torch.exp(log_alpha_t - log_alpha_s)) * x 794 | - (sigma_t * phi_1) * model_s 795 | - (sigma_t * phi_2) * D1 796 | - (sigma_t * phi_3) * D2 797 | ) 798 | 799 | if return_intermediate: 800 | return x_t, {'model_s': model_s, 'model_s1': model_s1, 'model_s2': model_s2} 801 | else: 802 | return x_t 803 | 804 | def multistep_dpm_solver_second_update(self, x, model_prev_list, t_prev_list, t, solver_type="dpmsolver"): 805 | """ 806 | Multistep solver DPM-Solver-2 from time `t_prev_list[-1]` to time `t`. 807 | 808 | Args: 809 | x: A pytorch tensor. The initial value at time `s`. 810 | model_prev_list: A list of pytorch tensor. The previous computed model values. 811 | t_prev_list: A list of pytorch tensor. The previous times, each time has the shape (1,) 812 | t: A pytorch tensor. The ending time, with the shape (1,). 813 | solver_type: either 'dpmsolver' or 'taylor'. The type for the high-order solvers. 814 | The type slightly impacts the performance. We recommend to use 'dpmsolver' type. 815 | Returns: 816 | x_t: A pytorch tensor. The approximated solution at time `t`. 
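
        Here `x` is the value at time `t_prev_list[-1]`. With h = lambda_t - lambda_prev_0 and
        D1_0 = (h / h_0) * (model_prev_0 - model_prev_1), the 'dpmsolver' update (dpmsolver++ case)
        below is: x_t = (sigma_t / sigma_prev_0) * x - alpha_t * (exp(-h) - 1) * (model_prev_0 + 0.5 * D1_0).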
817 | """ 818 | if solver_type not in ['dpmsolver', 'taylor']: 819 | raise ValueError("'solver_type' must be either 'dpmsolver' or 'taylor', got {}".format(solver_type)) 820 | ns = self.noise_schedule 821 | model_prev_1, model_prev_0 = model_prev_list[-2], model_prev_list[-1] 822 | t_prev_1, t_prev_0 = t_prev_list[-2], t_prev_list[-1] 823 | lambda_prev_1, lambda_prev_0, lambda_t = ns.marginal_lambda(t_prev_1), ns.marginal_lambda(t_prev_0), ns.marginal_lambda(t) 824 | log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(t_prev_0), ns.marginal_log_mean_coeff(t) 825 | sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t) 826 | alpha_t = torch.exp(log_alpha_t) 827 | 828 | h_0 = lambda_prev_0 - lambda_prev_1 829 | h = lambda_t - lambda_prev_0 830 | r0 = h_0 / h 831 | D1_0 = (1. / r0) * (model_prev_0 - model_prev_1) 832 | if self.algorithm_type == "dpmsolver++": 833 | phi_1 = torch.expm1(-h) 834 | if solver_type == 'dpmsolver': 835 | x_t = ( 836 | (sigma_t / sigma_prev_0) * x 837 | - (alpha_t * phi_1) * model_prev_0 838 | - 0.5 * (alpha_t * phi_1) * D1_0 839 | ) 840 | elif solver_type == 'taylor': 841 | x_t = ( 842 | (sigma_t / sigma_prev_0) * x 843 | - (alpha_t * phi_1) * model_prev_0 844 | + (alpha_t * (phi_1 / h + 1.)) * D1_0 845 | ) 846 | else: 847 | phi_1 = torch.expm1(h) 848 | if solver_type == 'dpmsolver': 849 | x_t = ( 850 | (torch.exp(log_alpha_t - log_alpha_prev_0)) * x 851 | - (sigma_t * phi_1) * model_prev_0 852 | - 0.5 * (sigma_t * phi_1) * D1_0 853 | ) 854 | elif solver_type == 'taylor': 855 | x_t = ( 856 | (torch.exp(log_alpha_t - log_alpha_prev_0)) * x 857 | - (sigma_t * phi_1) * model_prev_0 858 | - (sigma_t * (phi_1 / h - 1.)) * D1_0 859 | ) 860 | return x_t 861 | 862 | def multistep_dpm_solver_third_update(self, x, model_prev_list, t_prev_list, t, solver_type='dpmsolver'): 863 | """ 864 | Multistep solver DPM-Solver-3 from time `t_prev_list[-1]` to time `t`. 865 | 866 | Args: 867 | x: A pytorch tensor. The initial value at time `s`. 868 | model_prev_list: A list of pytorch tensor. The previous computed model values. 869 | t_prev_list: A list of pytorch tensor. The previous times, each time has the shape (1,) 870 | t: A pytorch tensor. The ending time, with the shape (1,). 871 | solver_type: either 'dpmsolver' or 'taylor'. The type for the high-order solvers. 872 | The type slightly impacts the performance. We recommend to use 'dpmsolver' type. 873 | Returns: 874 | x_t: A pytorch tensor. The approximated solution at time `t`. 875 | """ 876 | ns = self.noise_schedule 877 | model_prev_2, model_prev_1, model_prev_0 = model_prev_list 878 | t_prev_2, t_prev_1, t_prev_0 = t_prev_list 879 | lambda_prev_2, lambda_prev_1, lambda_prev_0, lambda_t = ns.marginal_lambda(t_prev_2), ns.marginal_lambda(t_prev_1), ns.marginal_lambda(t_prev_0), ns.marginal_lambda(t) 880 | log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(t_prev_0), ns.marginal_log_mean_coeff(t) 881 | sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t) 882 | alpha_t = torch.exp(log_alpha_t) 883 | 884 | h_1 = lambda_prev_1 - lambda_prev_2 885 | h_0 = lambda_prev_0 - lambda_prev_1 886 | h = lambda_t - lambda_prev_0 887 | r0, r1 = h_0 / h, h_1 / h 888 | D1_0 = (1. / r0) * (model_prev_0 - model_prev_1) 889 | D1_1 = (1. / r1) * (model_prev_1 - model_prev_2) 890 | D1 = D1_0 + (r0 / (r0 + r1)) * (D1_0 - D1_1) 891 | D2 = (1. / (r0 + r1)) * (D1_0 - D1_1) 892 | if self.algorithm_type == "dpmsolver++": 893 | phi_1 = torch.expm1(-h) 894 | phi_2 = phi_1 / h + 1. 
895 | phi_3 = phi_2 / h - 0.5 896 | x_t = ( 897 | (sigma_t / sigma_prev_0) * x 898 | - (alpha_t * phi_1) * model_prev_0 899 | + (alpha_t * phi_2) * D1 900 | - (alpha_t * phi_3) * D2 901 | ) 902 | else: 903 | phi_1 = torch.expm1(h) 904 | phi_2 = phi_1 / h - 1. 905 | phi_3 = phi_2 / h - 0.5 906 | x_t = ( 907 | (torch.exp(log_alpha_t - log_alpha_prev_0)) * x 908 | - (sigma_t * phi_1) * model_prev_0 909 | - (sigma_t * phi_2) * D1 910 | - (sigma_t * phi_3) * D2 911 | ) 912 | return x_t 913 | 914 | def singlestep_dpm_solver_update(self, x, s, t, order, return_intermediate=False, solver_type='dpmsolver', r1=None, r2=None): 915 | """ 916 | Singlestep DPM-Solver with the order `order` from time `s` to time `t`. 917 | 918 | Args: 919 | x: A pytorch tensor. The initial value at time `s`. 920 | s: A pytorch tensor. The starting time, with the shape (1,). 921 | t: A pytorch tensor. The ending time, with the shape (1,). 922 | order: A `int`. The order of DPM-Solver. We only support order == 1 or 2 or 3. 923 | return_intermediate: A `bool`. If true, also return the model value at time `s`, `s1` and `s2` (the intermediate times). 924 | solver_type: either 'dpmsolver' or 'taylor'. The type for the high-order solvers. 925 | The type slightly impacts the performance. We recommend to use 'dpmsolver' type. 926 | r1: A `float`. The hyperparameter of the second-order or third-order solver. 927 | r2: A `float`. The hyperparameter of the third-order solver. 928 | Returns: 929 | x_t: A pytorch tensor. The approximated solution at time `t`. 930 | """ 931 | if order == 1: 932 | return self.dpm_solver_first_update(x, s, t, return_intermediate=return_intermediate) 933 | elif order == 2: 934 | return self.singlestep_dpm_solver_second_update(x, s, t, return_intermediate=return_intermediate, solver_type=solver_type, r1=r1) 935 | elif order == 3: 936 | return self.singlestep_dpm_solver_third_update(x, s, t, return_intermediate=return_intermediate, solver_type=solver_type, r1=r1, r2=r2) 937 | else: 938 | raise ValueError("Solver order must be 1 or 2 or 3, got {}".format(order)) 939 | 940 | def multistep_dpm_solver_update(self, x, model_prev_list, t_prev_list, t, order, solver_type='dpmsolver'): 941 | """ 942 | Multistep DPM-Solver with the order `order` from time `t_prev_list[-1]` to time `t`. 943 | 944 | Args: 945 | x: A pytorch tensor. The initial value at time `s`. 946 | model_prev_list: A list of pytorch tensor. The previous computed model values. 947 | t_prev_list: A list of pytorch tensor. The previous times, each time has the shape (1,) 948 | t: A pytorch tensor. The ending time, with the shape (1,). 949 | order: A `int`. The order of DPM-Solver. We only support order == 1 or 2 or 3. 950 | solver_type: either 'dpmsolver' or 'taylor'. The type for the high-order solvers. 951 | The type slightly impacts the performance. We recommend to use 'dpmsolver' type. 952 | Returns: 953 | x_t: A pytorch tensor. The approximated solution at time `t`. 
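
        Note: `x` here is the value at time `t_prev_list[-1]`; the order-1 case reuses
        `model_prev_list[-1]` instead of re-evaluating the model.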
954 | """ 955 | if order == 1: 956 | return self.dpm_solver_first_update(x, t_prev_list[-1], t, model_s=model_prev_list[-1]) 957 | elif order == 2: 958 | return self.multistep_dpm_solver_second_update(x, model_prev_list, t_prev_list, t, solver_type=solver_type) 959 | elif order == 3: 960 | return self.multistep_dpm_solver_third_update(x, model_prev_list, t_prev_list, t, solver_type=solver_type) 961 | else: 962 | raise ValueError("Solver order must be 1 or 2 or 3, got {}".format(order)) 963 | 964 | def dpm_solver_adaptive(self, x, order, t_T, t_0, h_init=0.05, atol=0.0078, rtol=0.05, theta=0.9, t_err=1e-5, solver_type='dpmsolver'): 965 | """ 966 | The adaptive step size solver based on singlestep DPM-Solver. 967 | 968 | Args: 969 | x: A pytorch tensor. The initial value at time `t_T`. 970 | order: A `int`. The (higher) order of the solver. We only support order == 2 or 3. 971 | t_T: A `float`. The starting time of the sampling (default is T). 972 | t_0: A `float`. The ending time of the sampling (default is epsilon). 973 | h_init: A `float`. The initial step size (for logSNR). 974 | atol: A `float`. The absolute tolerance of the solver. For image data, the default setting is 0.0078, followed [1]. 975 | rtol: A `float`. The relative tolerance of the solver. The default setting is 0.05. 976 | theta: A `float`. The safety hyperparameter for adapting the step size. The default setting is 0.9, followed [1]. 977 | t_err: A `float`. The tolerance for the time. We solve the diffusion ODE until the absolute error between the 978 | current time and `t_0` is less than `t_err`. The default setting is 1e-5. 979 | solver_type: either 'dpmsolver' or 'taylor'. The type for the high-order solvers. 980 | The type slightly impacts the performance. We recommend to use 'dpmsolver' type. 981 | Returns: 982 | x_0: A pytorch tensor. The approximated solution at time `t_0`. 983 | 984 | [1] A. Jolicoeur-Martineau, K. Li, R. Piché-Taillefer, T. Kachman, and I. Mitliagkas, "Gotta go fast when generating data with score-based models," arXiv preprint arXiv:2105.14080, 2021. 985 | """ 986 | ns = self.noise_schedule 987 | s = t_T * torch.ones((1,)).to(x) 988 | lambda_s = ns.marginal_lambda(s) 989 | lambda_0 = ns.marginal_lambda(t_0 * torch.ones_like(s).to(x)) 990 | h = h_init * torch.ones_like(s).to(x) 991 | x_prev = x 992 | nfe = 0 993 | if order == 2: 994 | r1 = 0.5 995 | lower_update = lambda x, s, t: self.dpm_solver_first_update(x, s, t, return_intermediate=True) 996 | higher_update = lambda x, s, t, **kwargs: self.singlestep_dpm_solver_second_update(x, s, t, r1=r1, solver_type=solver_type, **kwargs) 997 | elif order == 3: 998 | r1, r2 = 1. / 3., 2. / 3. 
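            # Embedded pair for step-size control: the lower-order update provides the
            # reference solution and the higher-order update the refined one; their
            # difference drives the error estimate E (and thus the step size h) below.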
999 | lower_update = lambda x, s, t: self.singlestep_dpm_solver_second_update(x, s, t, r1=r1, return_intermediate=True, solver_type=solver_type) 1000 | higher_update = lambda x, s, t, **kwargs: self.singlestep_dpm_solver_third_update(x, s, t, r1=r1, r2=r2, solver_type=solver_type, **kwargs) 1001 | else: 1002 | raise ValueError("For adaptive step size solver, order must be 2 or 3, got {}".format(order)) 1003 | while torch.abs((s - t_0)).mean() > t_err: 1004 | t = ns.inverse_lambda(lambda_s + h) 1005 | x_lower, lower_noise_kwargs = lower_update(x, s, t) 1006 | x_higher = higher_update(x, s, t, **lower_noise_kwargs) 1007 | delta = torch.max(torch.ones_like(x).to(x) * atol, rtol * torch.max(torch.abs(x_lower), torch.abs(x_prev))) 1008 | norm_fn = lambda v: torch.sqrt(torch.square(v.reshape((v.shape[0], -1))).mean(dim=-1, keepdim=True)) 1009 | E = norm_fn((x_higher - x_lower) / delta).max() 1010 | if torch.all(E <= 1.): 1011 | x = x_higher 1012 | s = t 1013 | x_prev = x_lower 1014 | lambda_s = ns.marginal_lambda(s) 1015 | h = torch.min(theta * h * torch.float_power(E, -1. / order).float(), lambda_0 - lambda_s) 1016 | nfe += order 1017 | print('adaptive solver nfe', nfe) 1018 | return x 1019 | 1020 | def add_noise(self, x, t, noise=None): 1021 | """ 1022 | Compute the noised input xt = alpha_t * x + sigma_t * noise. 1023 | 1024 | Args: 1025 | x: A `torch.Tensor` with shape `(batch_size, *shape)`. 1026 | t: A `torch.Tensor` with shape `(t_size,)`. 1027 | Returns: 1028 | xt with shape `(t_size, batch_size, *shape)`. 1029 | """ 1030 | alpha_t, sigma_t = self.noise_schedule.marginal_alpha(t), self.noise_schedule.marginal_std(t) 1031 | if noise is None: 1032 | noise = torch.randn((t.shape[0], *x.shape), device=x.device) 1033 | x = x.reshape((-1, *x.shape)) 1034 | xt = expand_dims(alpha_t, x.dim()) * x + expand_dims(sigma_t, x.dim()) * noise 1035 | if t.shape[0] == 1: 1036 | return xt.squeeze(0) 1037 | else: 1038 | return xt 1039 | 1040 | def inverse(self, x, steps=20, t_start=None, t_end=None, order=2, skip_type='time_uniform', 1041 | method='multistep', lower_order_final=True, denoise_to_zero=False, solver_type='dpmsolver', 1042 | atol=0.0078, rtol=0.05, return_intermediate=False, 1043 | ): 1044 | """ 1045 | Inverse the sample `x` from time `t_start` to `t_end` by DPM-Solver. 1046 | For discrete-time DPMs, we use `t_start=1/N`, where `N` is the total time steps during training. 1047 | """ 1048 | t_0 = 1. / self.noise_schedule.total_N if t_start is None else t_start 1049 | t_T = self.noise_schedule.T if t_end is None else t_end 1050 | assert t_0 > 0 and t_T > 0, "Time range needs to be greater than 0. For discrete-time DPMs, it needs to be in [1 / N, 1], where N is the length of betas array" 1051 | return self.sample(x, steps=steps, t_start=t_0, t_end=t_T, order=order, skip_type=skip_type, 1052 | method=method, lower_order_final=lower_order_final, denoise_to_zero=denoise_to_zero, solver_type=solver_type, 1053 | atol=atol, rtol=rtol, return_intermediate=return_intermediate) 1054 | 1055 | def sample(self, x, steps=20, t_start=None, t_end=None, order=2, skip_type='time_uniform', 1056 | method='multistep', lower_order_final=True, denoise_to_zero=False, solver_type='dpmsolver', 1057 | atol=0.0078, rtol=0.05, return_intermediate=False, 1058 | ): 1059 | """ 1060 | Compute the sample at time `t_end` by DPM-Solver, given the initial `x` at time `t_start`. 
1061 | 
1062 |         =====================================================
1063 | 
1064 |         We support the following algorithms for both the noise prediction model and the data prediction model:
1065 |             - 'singlestep':
1066 |                 Singlestep DPM-Solver (i.e. "DPM-Solver-fast" in the paper), which combines different orders of singlestep DPM-Solver.
1067 |                 We combine all the singlestep solvers with order <= `order` to use up all the function evaluations (steps).
1068 |                 The total number of function evaluations (NFE) == `steps`.
1069 |                 Given a fixed NFE == `steps`, the sampling procedure is:
1070 |                     - If `order` == 1:
1071 |                         - Denote K = steps. We use K steps of DPM-Solver-1 (i.e. DDIM).
1072 |                     - If `order` == 2:
1073 |                         - Denote K = (steps // 2) + (steps % 2). We take K intermediate time steps for sampling.
1074 |                         - If steps % 2 == 0, we use K steps of singlestep DPM-Solver-2.
1075 |                         - If steps % 2 == 1, we use (K - 1) steps of singlestep DPM-Solver-2 and 1 step of DPM-Solver-1.
1076 |                     - If `order` == 3:
1077 |                         - Denote K = (steps // 3 + 1). We take K intermediate time steps for sampling.
1078 |                         - If steps % 3 == 0, we use (K - 2) steps of singlestep DPM-Solver-3, and 1 step of singlestep DPM-Solver-2 and 1 step of DPM-Solver-1.
1079 |                         - If steps % 3 == 1, we use (K - 1) steps of singlestep DPM-Solver-3 and 1 step of DPM-Solver-1.
1080 |                         - If steps % 3 == 2, we use (K - 1) steps of singlestep DPM-Solver-3 and 1 step of singlestep DPM-Solver-2.
1081 |             - 'multistep':
1082 |                 Multistep DPM-Solver with the order of `order`. The total number of function evaluations (NFE) == `steps`.
1083 |                 We initialize the first `order` values by lower-order multistep solvers.
1084 |                 Given a fixed NFE == `steps`, the sampling procedure is:
1085 |                     Denote K = steps.
1086 |                     - If `order` == 1:
1087 |                         - We use K steps of DPM-Solver-1 (i.e. DDIM).
1088 |                     - If `order` == 2:
1089 |                         - We first use 1 step of DPM-Solver-1, then (K - 1) steps of multistep DPM-Solver-2.
1090 |                     - If `order` == 3:
1091 |                         - We first use 1 step of DPM-Solver-1, then 1 step of multistep DPM-Solver-2, then (K - 2) steps of multistep DPM-Solver-3.
1092 |             - 'singlestep_fixed':
1093 |                 Fixed-order singlestep DPM-Solver (i.e. DPM-Solver-1 or singlestep DPM-Solver-2 or singlestep DPM-Solver-3).
1094 |                 We use singlestep DPM-Solver-`order` for `order` = 1 or 2 or 3, with total [`steps` // `order`] * `order` NFE.
1095 |             - 'adaptive':
1096 |                 Adaptive step size DPM-Solver (i.e. "DPM-Solver-12" and "DPM-Solver-23" in the paper).
1097 |                 We ignore `steps` and use adaptive step size DPM-Solver with a higher order of `order`.
1098 |                 You can adjust the absolute tolerance `atol` and the relative tolerance `rtol` to balance the computation costs
1099 |                 (NFE) and the sample quality.
1100 |                     - If `order` == 2, we use DPM-Solver-12 which combines DPM-Solver-1 and singlestep DPM-Solver-2.
1101 |                     - If `order` == 3, we use DPM-Solver-23 which combines singlestep DPM-Solver-2 and singlestep DPM-Solver-3.
1102 | 
1103 |         =====================================================
1104 | 
1105 |         Some advice on choosing the algorithm:
1106 |             - For **unconditional sampling** or **guided sampling with small guidance scale** by DPMs:
1107 |                 Use singlestep DPM-Solver or DPM-Solver++ ("DPM-Solver-fast" in the paper) with `order = 3`.
1108 |                 e.g., DPM-Solver:
1109 |                     >>> dpm_solver = DPM_Solver(model_fn, noise_schedule, algorithm_type="dpmsolver")
1110 |                     >>> x_sample = dpm_solver.sample(x, steps=steps, t_start=t_start, t_end=t_end, order=3,
1111 |                             skip_type='time_uniform', method='singlestep')
1112 |                 e.g., DPM-Solver++:
1113 |                     >>> dpm_solver = DPM_Solver(model_fn, noise_schedule, algorithm_type="dpmsolver++")
1114 |                     >>> x_sample = dpm_solver.sample(x, steps=steps, t_start=t_start, t_end=t_end, order=3,
1115 |                             skip_type='time_uniform', method='singlestep')
1116 |             - For **guided sampling with large guidance scale** by DPMs:
1117 |                 Use multistep DPM-Solver with `algorithm_type="dpmsolver++"` and `order = 2`.
1118 |                 e.g.
1119 |                     >>> dpm_solver = DPM_Solver(model_fn, noise_schedule, algorithm_type="dpmsolver++")
1120 |                     >>> x_sample = dpm_solver.sample(x, steps=steps, t_start=t_start, t_end=t_end, order=2,
1121 |                             skip_type='time_uniform', method='multistep')
1122 | 
1123 |         We support three types of `skip_type`:
1124 |             - 'logSNR': uniform logSNR for the time steps. **Recommended for low-resolutional images**.
1125 |             - 'time_uniform': uniform time for the time steps. **Recommended for high-resolutional images**.
1126 |             - 'time_quadratic': quadratic time for the time steps.
1127 | 
1128 |         =====================================================
1129 |         Args:
1130 |             x: A pytorch tensor. The initial value at time `t_start`.
1131 |                 e.g. if `t_start` == T, then `x` is a sample from the standard normal distribution.
1132 |             steps: An `int`. The total number of function evaluations (NFE).
1133 |             t_start: A `float`. The starting time of the sampling.
1134 |                 If `t_start` is None, we use self.noise_schedule.T (default is 1.0).
1135 |             t_end: A `float`. The ending time of the sampling.
1136 |                 If `t_end` is None, we use 1. / self.noise_schedule.total_N.
1137 |                 e.g. if total_N == 1000, we have `t_end` == 1e-3.
1138 |                 For discrete-time DPMs:
1139 |                     - We recommend `t_end` == 1. / self.noise_schedule.total_N.
1140 |                 For continuous-time DPMs:
1141 |                     - We recommend `t_end` == 1e-3 when `steps` <= 15; and `t_end` == 1e-4 when `steps` > 15.
1142 |             order: An `int`. The order of DPM-Solver.
1143 |             skip_type: A `str`. The type for the spacing of the time steps. 'time_uniform' or 'logSNR' or 'time_quadratic'.
1144 |             method: A `str`. The method for sampling. 'singlestep' or 'multistep' or 'singlestep_fixed' or 'adaptive'.
1145 |             denoise_to_zero: A `bool`. Whether to denoise to time 0 at the final step.
1146 |                 Default is `False`. If `denoise_to_zero` is `True`, the total NFE is (`steps` + 1).
1147 | 
1148 |                 This trick was first proposed by DDPM (https://arxiv.org/abs/2006.11239) and
1149 |                 score_sde (https://arxiv.org/abs/2011.13456). It can improve the FID
1150 |                 for diffusion models sampling by diffusion SDEs for low-resolutional images
1151 |                 (such as CIFAR-10). However, we observed that this trick does not matter for
1152 |                 high-resolutional images. As it needs an additional NFE, we do not recommend
1153 |                 it for high-resolutional images.
1154 |             lower_order_final: A `bool`. Whether to use lower order solvers at the final steps.
1155 |                 Only valid for `method=multistep` and `steps < 15`. We empirically find that
1156 |                 this trick is a key to stabilizing the sampling by DPM-Solver with very few steps
1157 |                 (especially for steps <= 10). So we recommend setting it to `True`.
1158 |             solver_type: A `str`. The Taylor expansion type for the solver. `dpmsolver` or `taylor`. We recommend `dpmsolver`.
1159 |             atol: A `float`.
The absolute tolerance of the adaptive step size solver. Valid when `method` == 'adaptive'. 1160 | rtol: A `float`. The relative tolerance of the adaptive step size solver. Valid when `method` == 'adaptive'. 1161 | return_intermediate: A `bool`. Whether to save the xt at each step. 1162 | When set to `True`, method returns a tuple (x0, intermediates); when set to False, method returns only x0. 1163 | Returns: 1164 | x_end: A pytorch tensor. The approximated solution at time `t_end`. 1165 | 1166 | """ 1167 | t_0 = 1. / self.noise_schedule.total_N if t_end is None else t_end 1168 | t_T = self.noise_schedule.T if t_start is None else t_start 1169 | assert t_0 > 0 and t_T > 0, "Time range needs to be greater than 0. For discrete-time DPMs, it needs to be in [1 / N, 1], where N is the length of betas array" 1170 | if return_intermediate: 1171 | assert method in ['multistep', 'singlestep', 'singlestep_fixed'], "Cannot use adaptive solver when saving intermediate values" 1172 | if self.correcting_xt_fn is not None: 1173 | assert method in ['multistep', 'singlestep', 'singlestep_fixed'], "Cannot use adaptive solver when correcting_xt_fn is not None" 1174 | device = x.device 1175 | intermediates = [] 1176 | with torch.no_grad(): 1177 | if method == 'adaptive': 1178 | x = self.dpm_solver_adaptive(x, order=order, t_T=t_T, t_0=t_0, atol=atol, rtol=rtol, solver_type=solver_type) 1179 | elif method == 'multistep': 1180 | assert steps >= order 1181 | timesteps = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=steps, device=device) 1182 | assert timesteps.shape[0] - 1 == steps 1183 | # Init the initial values. 1184 | step = 0 1185 | t = timesteps[step] 1186 | t_prev_list = [t] 1187 | model_prev_list = [self.model_fn(x, t)] 1188 | if self.correcting_xt_fn is not None: 1189 | x = self.correcting_xt_fn(x, t, step) 1190 | if return_intermediate: 1191 | intermediates.append(x) 1192 | # Init the first `order` values by lower order multistep DPM-Solver. 1193 | for step in range(1, order): 1194 | t = timesteps[step] 1195 | x = self.multistep_dpm_solver_update(x, model_prev_list, t_prev_list, t, step, solver_type=solver_type) 1196 | if self.correcting_xt_fn is not None: 1197 | x = self.correcting_xt_fn(x, t, step) 1198 | if return_intermediate: 1199 | intermediates.append(x) 1200 | t_prev_list.append(t) 1201 | model_prev_list.append(self.model_fn(x, t)) 1202 | # Compute the remaining values by `order`-th order multistep DPM-Solver. 1203 | for step in range(order, steps + 1): 1204 | t = timesteps[step] 1205 | # We only use lower order for steps < 10 1206 | if lower_order_final and steps < 10: 1207 | step_order = min(order, steps + 1 - step) 1208 | else: 1209 | step_order = order 1210 | x = self.multistep_dpm_solver_update(x, model_prev_list, t_prev_list, t, step_order, solver_type=solver_type) 1211 | if self.correcting_xt_fn is not None: 1212 | x = self.correcting_xt_fn(x, t, step) 1213 | if return_intermediate: 1214 | intermediates.append(x) 1215 | for i in range(order - 1): 1216 | t_prev_list[i] = t_prev_list[i + 1] 1217 | model_prev_list[i] = model_prev_list[i + 1] 1218 | t_prev_list[-1] = t 1219 | # We do not need to evaluate the final model value. 
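                    # (skipping the network call at the last step keeps the total NFE equal to `steps`)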
1220 | if step < steps: 1221 | model_prev_list[-1] = self.model_fn(x, t) 1222 | elif method in ['singlestep', 'singlestep_fixed']: 1223 | if method == 'singlestep': 1224 | timesteps_outer, orders = self.get_orders_and_timesteps_for_singlestep_solver(steps=steps, order=order, skip_type=skip_type, t_T=t_T, t_0=t_0, device=device) 1225 | elif method == 'singlestep_fixed': 1226 | K = steps // order 1227 | orders = [order,] * K 1228 | timesteps_outer = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=K, device=device) 1229 | for step, order in enumerate(orders): 1230 | s, t = timesteps_outer[step], timesteps_outer[step + 1] 1231 | timesteps_inner = self.get_time_steps(skip_type=skip_type, t_T=s.item(), t_0=t.item(), N=order, device=device) 1232 | lambda_inner = self.noise_schedule.marginal_lambda(timesteps_inner) 1233 | h = lambda_inner[-1] - lambda_inner[0] 1234 | r1 = None if order <= 1 else (lambda_inner[1] - lambda_inner[0]) / h 1235 | r2 = None if order <= 2 else (lambda_inner[2] - lambda_inner[0]) / h 1236 | x = self.singlestep_dpm_solver_update(x, s, t, order, solver_type=solver_type, r1=r1, r2=r2) 1237 | if self.correcting_xt_fn is not None: 1238 | x = self.correcting_xt_fn(x, t, step) 1239 | if return_intermediate: 1240 | intermediates.append(x) 1241 | else: 1242 | raise ValueError("Got wrong method {}".format(method)) 1243 | if denoise_to_zero: 1244 | t = torch.ones((1,)).to(device) * t_0 1245 | x = self.denoise_to_zero_fn(x, t) 1246 | if self.correcting_xt_fn is not None: 1247 | x = self.correcting_xt_fn(x, t, step + 1) 1248 | if return_intermediate: 1249 | intermediates.append(x) 1250 | if return_intermediate: 1251 | return x, intermediates 1252 | else: 1253 | return x 1254 | 1255 | 1256 | 1257 | ############################################################# 1258 | # other utility functions 1259 | ############################################################# 1260 | 1261 | def interpolate_fn(x, xp, yp): 1262 | """ 1263 | A piecewise linear function y = f(x), using xp and yp as keypoints. 1264 | We implement f(x) in a differentiable way (i.e. applicable for autograd). 1265 | The function f(x) is well-defined for all x-axis. (For x beyond the bounds of xp, we use the outmost points of xp to define the linear function.) 1266 | 1267 | Args: 1268 | x: PyTorch tensor with shape [N, C], where N is the batch size, C is the number of channels (we use C = 1 for DPM-Solver). 1269 | xp: PyTorch tensor with shape [C, K], where K is the number of keypoints. 1270 | yp: PyTorch tensor with shape [C, K]. 1271 | Returns: 1272 | The function values f(x), with shape [N, C]. 
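
    Example (illustrative): with keypoints (0, 0) and (1, 2),
        interpolate_fn(torch.tensor([[0.5]]), torch.tensor([[0., 1.]]), torch.tensor([[0., 2.]]))
    returns tensor([[1.0]]), the linear interpolation at x = 0.5.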
1273 |     """
1274 |     N, K = x.shape[0], xp.shape[1]
1275 |     all_x = torch.cat([x.unsqueeze(2), xp.unsqueeze(0).repeat((N, 1, 1))], dim=2)
1276 |     sorted_all_x, x_indices = torch.sort(all_x, dim=2)
1277 |     x_idx = torch.argmin(x_indices, dim=2)
1278 |     cand_start_idx = x_idx - 1
1279 |     start_idx = torch.where(
1280 |         torch.eq(x_idx, 0),
1281 |         torch.tensor(1, device=x.device),
1282 |         torch.where(
1283 |             torch.eq(x_idx, K), torch.tensor(K - 2, device=x.device), cand_start_idx,
1284 |         ),
1285 |     )
1286 |     end_idx = torch.where(torch.eq(start_idx, cand_start_idx), start_idx + 2, start_idx + 1)
1287 |     start_x = torch.gather(sorted_all_x, dim=2, index=start_idx.unsqueeze(2)).squeeze(2)
1288 |     end_x = torch.gather(sorted_all_x, dim=2, index=end_idx.unsqueeze(2)).squeeze(2)
1289 |     start_idx2 = torch.where(
1290 |         torch.eq(x_idx, 0),
1291 |         torch.tensor(0, device=x.device),
1292 |         torch.where(
1293 |             torch.eq(x_idx, K), torch.tensor(K - 2, device=x.device), cand_start_idx,
1294 |         ),
1295 |     )
1296 |     y_positions_expanded = yp.unsqueeze(0).expand(N, -1, -1)
1297 |     start_y = torch.gather(y_positions_expanded, dim=2, index=start_idx2.unsqueeze(2)).squeeze(2)
1298 |     end_y = torch.gather(y_positions_expanded, dim=2, index=(start_idx2 + 1).unsqueeze(2)).squeeze(2)
1299 |     cand = start_y + (x - start_x) * (end_y - start_y) / (end_x - start_x)
1300 |     return cand
1301 | 
1302 | 
1303 | def expand_dims(v, dims):
1304 |     """
1305 |     Expand the tensor `v` to the dimension `dims`.
1306 | 
1307 |     Args:
1308 |         `v`: a PyTorch tensor with shape [N].
1309 |         `dims`: an `int`.
1310 |     Returns:
1311 |         a PyTorch tensor with shape [N, 1, 1, ..., 1] and the total dimension is `dims`.
1312 |     """
1313 |     return v[(...,) + (None,)*(dims - 1)]
1314 | 
--------------------------------------------------------------------------------
/schedule/schedule.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | 
4 | 
5 | class Schedule:
6 |     def __init__(self, schedule, timesteps):
7 |         self.timesteps = timesteps
8 |         self.schedule = schedule
9 | 
10 |     def cosine_beta_schedule(self, s=0.001):
11 |         timesteps = self.timesteps
12 |         steps = timesteps + 1
13 |         x = torch.linspace(0, timesteps, steps)
14 |         alphas_cumprod = torch.cos(((x / timesteps) + s) / (1 + s) * np.pi * 0.5) ** 2
15 |         alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
16 |         betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
17 |         return torch.clip(betas, 0.0001, 0.9999)
18 | 
19 |     def linear_beta_schedule(self):
20 |         timesteps = self.timesteps
21 |         scale = 1000 / timesteps
22 |         beta_start = 1e-6 * scale
23 |         beta_end = 0.02 * scale
24 |         return torch.linspace(beta_start, beta_end, timesteps)
25 | 
26 |     def quadratic_beta_schedule(self):
27 |         timesteps = self.timesteps
28 |         scale = 1000 / timesteps
29 |         beta_start = 1e-6 * scale
30 |         beta_end = 0.02 * scale
31 |         return torch.linspace(beta_start ** 0.5, beta_end ** 0.5, timesteps) ** 2
32 | 
33 |     def sigmoid_beta_schedule(self):
34 |         timesteps = self.timesteps
35 |         scale = 1000 / timesteps
36 |         beta_start = 1e-6 * scale
37 |         beta_end = 0.02 * scale
38 |         betas = torch.linspace(-6, 6, timesteps)
39 |         return torch.sigmoid(betas) * (beta_end - beta_start) + beta_start
40 | 
41 |     def get_betas(self):
42 |         dispatch = {"linear": self.linear_beta_schedule, "cosine": self.cosine_beta_schedule,
43 |                     "quadratic": self.quadratic_beta_schedule, "sigmoid": self.sigmoid_beta_schedule}
44 |         if self.schedule not in dispatch:
45 |             raise NotImplementedError(f"Unknown schedule: {self.schedule}")
46 |         # dispatch to the matching *_beta_schedule method defined above
47 |         return dispatch[self.schedule]()
48 | 
49 | 
50 | if __name__ == "__main__":
51 |     schedule = Schedule(schedule="linear", timesteps=100)
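    # For reference (illustrative): downstream samplers typically consume these betas as
    #   alphas = 1. - betas; alphas_cumprod = torch.cumprod(alphas, dim=0)
    # so that x_t = sqrt(alphas_cumprod[t]) * x_0 + sqrt(1 - alphas_cumprod[t]) * noise.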
52 | print(schedule.get_betas().shape) 53 | print(schedule.get_betas()) -------------------------------------------------------------------------------- /src/config.py: -------------------------------------------------------------------------------- 1 | import yaml 2 | import os 3 | 4 | 5 | class Config(dict): 6 | def __init__(self, config_path): 7 | with open(config_path, 'r') as f: 8 | self._yaml = f.read() 9 | self._dict = yaml.safe_load(self._yaml) 10 | self._dict['PATH'] = os.path.dirname(config_path) 11 | 12 | def __getattr__(self, name): 13 | if self._dict.get(name) is not None: 14 | return self._dict[name] 15 | return None 16 | 17 | def print(self): 18 | print('Model configurations:') 19 | print('---------------------------------') 20 | print(self._yaml) 21 | print('') 22 | print('---------------------------------') 23 | print('') 24 | 25 | 26 | def load_config(path): 27 | config_path = path 28 | config = Config(config_path) 29 | return config 30 | -------------------------------------------------------------------------------- /src/sobel.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | import torch 3 | 4 | 5 | class Sobel(nn.Module): 6 | def __init__(self): 7 | super().__init__() 8 | self.filter = nn.Conv2d(in_channels=1, out_channels=2, kernel_size=3, stride=1, padding=1, bias=False) 9 | 10 | Gx = torch.tensor([[1.0, 0.0, -1.0], [2.0, 0.0, -2.0], [1.0, 0.0, -1.0]]) 11 | Gy = torch.tensor([[1.0, 2.0, 1.0], [0.0, 0.0, 0.0], [-1.0, -2.0, -1.0]]) 12 | G = torch.cat([Gx.unsqueeze(0), Gy.unsqueeze(0)], 0) 13 | G = G.unsqueeze(1) 14 | self.filter.weight = nn.Parameter(G, requires_grad=False) 15 | 16 | def forward(self, img): 17 | if img.shape[1] == 3: 18 | img = torch.mean(img, dim=1, keepdim=True) 19 | x = self.filter(img) 20 | x = torch.mul(x, x) 21 | x = torch.sum(x, dim=1, keepdim=True) 22 | x = torch.sqrt(x) 23 | return x 24 | 25 | 26 | class Laplacian(nn.Module): 27 | def __init__(self): 28 | super().__init__() 29 | self.filter = nn.Conv2d(in_channels=3, out_channels=3, kernel_size=3, stride=1, padding=1, bias=False, groups=3) 30 | G = torch.tensor([[-1, -1, -1], [-1, 8, -1], [-1, -1, -1]]).float() 31 | G = G.unsqueeze(0).unsqueeze(0) 32 | G = torch.cat([G, G, G], 0) 33 | self.filter.weight = nn.Parameter(G, requires_grad=False) 34 | 35 | def forward(self, img): 36 | x = self.filter(img) 37 | return x 38 | 39 | if __name__ == "__main__": 40 | laplacian = Laplacian() 41 | img = torch.randn(1, 3, 256, 256) 42 | y = laplacian(img) 43 | print(y.shape) -------------------------------------------------------------------------------- /src/train.py: -------------------------------------------------------------------------------- 1 | from src.trainer import Trainer 2 | 3 | 4 | def train(config): 5 | trainer = Trainer(config) 6 | trainer.train() 7 | print('training complete') 8 | 9 | 10 | def test(config): 11 | trainer = Trainer(config) 12 | trainer.test() 13 | print('testing complete') 14 | -------------------------------------------------------------------------------- /src/trainer.py: -------------------------------------------------------------------------------- 1 | import os 2 | from schedule.schedule import Schedule 3 | from model.DocDiff import DocDiff, EMA 4 | from schedule.diffusionSample import GaussianDiffusion 5 | from schedule.dpm_solver_pytorch import NoiseScheduleVP, model_wrapper, DPM_Solver 6 | import torch 7 | import torch.optim as optim 8 | import torch.nn as nn 9 | from torch.utils.data import DataLoader 
10 | from torchvision.utils import save_image
11 | from tqdm import tqdm
12 | import copy
13 | from src.sobel import Laplacian
14 | 
15 | 
16 | def init__result_Dir():
17 |     work_dir = os.path.join(os.getcwd(), 'Training')
18 |     max_model = 0
19 |     for root, j, file in os.walk(work_dir):
20 |         for dirs in j:
21 |             try:
22 |                 temp = int(dirs)
23 |                 if temp > max_model:
24 |                     max_model = temp
25 |             except ValueError:
26 |                 continue
27 |         break
28 |     max_model += 1
29 |     path = os.path.join(work_dir, str(max_model))
30 |     os.makedirs(path)
31 |     return path
32 | 
33 | 
34 | class Trainer:
35 |     def __init__(self, config):
36 |         self.mode = config.MODE
37 |         self.schedule = Schedule(config.SCHEDULE, config.TIMESTEPS)
38 |         self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
39 |         in_channels = config.CHANNEL_X + config.CHANNEL_Y
40 |         out_channels = config.CHANNEL_Y
41 |         self.out_channels = out_channels
42 |         self.network = DocDiff(
43 |             input_channels=in_channels,
44 |             output_channels=out_channels,
45 |             n_channels=config.MODEL_CHANNELS,
46 |             ch_mults=config.CHANNEL_MULT,
47 |             n_blocks=config.NUM_RESBLOCKS
48 |         ).to(self.device)
49 |         self.diffusion = GaussianDiffusion(self.network.denoiser, config.TIMESTEPS, self.schedule).to(self.device)
50 |         self.test_img_save_path = config.TEST_IMG_SAVE_PATH
51 |         if not os.path.exists(self.test_img_save_path):
52 |             os.makedirs(self.test_img_save_path)
53 |         self.pretrained_path_init_predictor = config.PRETRAINED_PATH_INITIAL_PREDICTOR
54 |         self.pretrained_path_denoiser = config.PRETRAINED_PATH_DENOISER
55 |         self.continue_training = config.CONTINUE_TRAINING
56 |         self.continue_training_steps = 0
57 |         self.path_train_gt = config.PATH_GT
58 |         self.path_train_img = config.PATH_IMG
59 |         self.iteration_max = config.ITERATION_MAX
60 |         self.LR = config.LR
61 |         self.cross_entropy = nn.BCELoss()
62 |         self.num_timesteps = config.TIMESTEPS
63 |         self.ema_every = config.EMA_EVERY
64 |         self.start_ema = config.START_EMA
65 |         self.save_model_every = config.SAVE_MODEL_EVERY
66 |         self.EMA_or_not = config.EMA
67 |         self.weight_save_path = config.WEIGHT_SAVE_PATH
68 |         self.TEST_INITIAL_PREDICTOR_WEIGHT_PATH = config.TEST_INITIAL_PREDICTOR_WEIGHT_PATH
69 |         self.TEST_DENOISER_WEIGHT_PATH = config.TEST_DENOISER_WEIGHT_PATH
70 |         self.DPM_SOLVER = config.DPM_SOLVER
71 |         self.DPM_STEP = config.DPM_STEP
72 |         self.test_path_img = config.TEST_PATH_IMG
73 |         self.test_path_gt = config.TEST_PATH_GT
74 |         self.beta_loss = config.BETA_LOSS
75 |         self.pre_ori = config.PRE_ORI
76 |         self.high_low_freq = config.HIGH_LOW_FREQ
77 |         self.image_size = config.IMAGE_SIZE
78 |         self.native_resolution = config.NATIVE_RESOLUTION
79 |         if self.mode == 1 and self.continue_training == 'True':
80 |             print('Continue Training')
81 |             self.network.init_predictor.load_state_dict(torch.load(self.pretrained_path_init_predictor))
82 |             self.network.denoiser.load_state_dict(torch.load(self.pretrained_path_denoiser))
83 |             self.continue_training_steps = config.CONTINUE_TRAINING_STEPS
84 |         from data.docdata import DocData
85 |         if self.mode == 1:
86 |             dataset_train = DocData(self.path_train_img, self.path_train_gt, config.IMAGE_SIZE, self.mode)
87 |             self.batch_size = config.BATCH_SIZE
88 |             self.dataloader_train = DataLoader(dataset_train, batch_size=self.batch_size, shuffle=True, drop_last=False,
89 |                                                num_workers=config.NUM_WORKERS)
90 |         else:
91 |             dataset_test = DocData(config.TEST_PATH_IMG, config.TEST_PATH_GT, config.IMAGE_SIZE, self.mode)
92 |             self.dataloader_test = DataLoader(dataset_test, batch_size=config.BATCH_SIZE_VAL, shuffle=False,
93 |                                               drop_last=False,
94 |                                               num_workers=config.NUM_WORKERS)
95 |         if self.mode == 1 and config.EMA == 'True':
96 |             self.EMA = EMA(0.9999)
97 |             self.ema_model = copy.deepcopy(self.network).to(self.device)
98 |         if config.LOSS == 'L1':
99 |             self.loss = nn.L1Loss()
100 |         elif config.LOSS == 'L2':
101 |             self.loss = nn.MSELoss()
102 |         else:
103 |             print('Loss not implemented, setting the loss to L2 (default one)')
104 |             self.loss = nn.MSELoss()
105 |         if self.high_low_freq == 'True':
106 |             self.high_filter = Laplacian().to(self.device)
107 | 
108 |     def test(self):
109 |         def crop_concat(img, size=128):  # tile the padded image into size x size patches stacked along the batch dim
110 |             shape = img.shape
111 |             correct_shape = (size*(shape[2]//size+1), size*(shape[3]//size+1))
112 |             one = torch.ones((shape[0], shape[1], correct_shape[0], correct_shape[1]))  # pad with white
113 |             one[:, :, :shape[2], :shape[3]] = img
114 |             # crop
115 |             for i in range(shape[2]//size+1):
116 |                 for j in range(shape[3]//size+1):
117 |                     if i == 0 and j == 0:
118 |                         crop = one[:, :, i*size:(i+1)*size, j*size:(j+1)*size]
119 |                     else:
120 |                         crop = torch.cat((crop, one[:, :, i*size:(i+1)*size, j*size:(j+1)*size]), dim=0)
121 |             return crop
122 |         def crop_concat_back(img, prediction, size=128):  # stitch the patches back together and crop to the original resolution
123 |             shape = img.shape
124 |             for i in range(shape[2]//size+1):
125 |                 for j in range(shape[3]//size+1):
126 |                     if j == 0:
127 |                         crop = prediction[(i*(shape[3]//size+1)+j)*shape[0]:(i*(shape[3]//size+1)+j+1)*shape[0], :, :, :]
128 |                     else:
129 |                         crop = torch.cat((crop, prediction[(i*(shape[3]//size+1)+j)*shape[0]:(i*(shape[3]//size+1)+j+1)*shape[0], :, :, :]), dim=3)
130 |                 if i == 0:
131 |                     crop_concat = crop
132 |                 else:
133 |                     crop_concat = torch.cat((crop_concat, crop), dim=2)
134 |             return crop_concat[:, :, :shape[2], :shape[3]]
135 | 
136 |         def min_max(array):
137 |             return (array - array.min()) / (array.max() - array.min())
138 |         with torch.no_grad():
139 |             self.network.init_predictor.load_state_dict(torch.load(self.TEST_INITIAL_PREDICTOR_WEIGHT_PATH))
140 |             self.network.denoiser.load_state_dict(torch.load(self.TEST_DENOISER_WEIGHT_PATH))
141 |             print('Test Model loaded')
142 |             self.network.eval()
143 |             tq = tqdm(self.dataloader_test)
144 |             sampler = self.diffusion
145 |             iteration = 0
146 |             for img, gt, name in tq:
147 |                 tq.set_description(f'Iteration {iteration} / {len(self.dataloader_test.dataset)}')
148 |                 iteration += 1
149 |                 if self.native_resolution == 'True':
150 |                     temp = img
151 |                     img = crop_concat(img)
152 |                 noisyImage = torch.randn_like(img).to(self.device)
153 |                 init_predict = self.network.init_predictor(img.to(self.device), 0)
154 | 
155 |                 if self.DPM_SOLVER == 'True':
156 |                     sampledImgs = dpm_solver(self.schedule.get_betas(), self.network,
157 |                                              torch.cat((noisyImage, img.to(self.device)), dim=1), self.DPM_STEP, {})  # no extra model inputs
158 |                 else:
159 |                     sampledImgs = sampler(noisyImage, init_predict, self.pre_ori)  # noisyImage is already on self.device
160 |                 finalImgs = (sampledImgs + init_predict)
161 |                 if self.native_resolution == 'True':
162 |                     finalImgs = crop_concat_back(temp, finalImgs)
163 |                     init_predict = crop_concat_back(temp, init_predict)
164 |                     sampledImgs = crop_concat_back(temp, sampledImgs)
165 |                     img = temp
166 |                 img_save = torch.cat((img, gt, init_predict.cpu(), min_max(sampledImgs.cpu()), finalImgs.cpu()), dim=3)
167 |                 save_image(img_save, os.path.join(
168 |                     self.test_img_save_path, f"{name[0]}.png"), nrow=4)
169 | 
170 | 
171 |     def train(self):
172 |         optimizer = optim.AdamW(self.network.parameters(), lr=self.LR, weight_decay=1e-4)
173 |         iteration = self.continue_training_steps
174 |         save_img_path = init__result_Dir()
175 |         print('Starting Training', f"Diffusion timesteps: 
{self.num_timesteps}") 176 | 177 | while iteration < self.iteration_max: 178 | 179 | tq = tqdm(self.dataloader_train) 180 | 181 | for img, gt, name in tq: 182 | tq.set_description(f'Iteration {iteration} / {self.iteration_max}') 183 | self.network.train() 184 | optimizer.zero_grad() 185 | 186 | t = torch.randint(0, self.num_timesteps, (img.shape[0],)).long().to(self.device) 187 | init_predict, noise_pred, noisy_image, noise_ref = self.network(gt.to(self.device), img.to(self.device), 188 | t, self.diffusion) 189 | if self.pre_ori == 'True': 190 | if self.high_low_freq == 'True': 191 | residual_high = self.high_filter(gt.to(self.device) - init_predict) 192 | ddpm_loss = 2*self.loss(self.high_filter(noise_pred), residual_high) + self.loss(noise_pred, gt.to(self.device) - init_predict) 193 | else: 194 | ddpm_loss = self.loss(noise_pred, gt.to(self.device) - init_predict) 195 | else: 196 | ddpm_loss = self.loss(noise_pred, noise_ref.to(self.device)) 197 | if self.high_low_freq == 'True': 198 | low_high_loss = self.loss(init_predict, gt.to(self.device)) 199 | low_freq_loss = self.loss(init_predict - self.high_filter(init_predict), gt.to(self.device) - self.high_filter(gt.to(self.device))) 200 | pixel_loss = low_high_loss + 2*low_freq_loss 201 | else: 202 | pixel_loss = self.loss(init_predict, gt.to(self.device)) 203 | 204 | loss = ddpm_loss + self.beta_loss * (pixel_loss) / self.num_timesteps 205 | loss.backward() 206 | optimizer.step() 207 | if self.high_low_freq == 'True': 208 | tq.set_postfix(loss=loss.item(), high_freq_ddpm_loss=ddpm_loss.item(), low_freq_pixel_loss=low_freq_loss.item(), pixel_loss=low_high_loss.item()) 209 | else: 210 | tq.set_postfix(loss=loss.item(), ddpm_loss=ddpm_loss.item(), pixel_loss=pixel_loss.item()) 211 | if iteration % 500 == 0: 212 | if not os.path.exists(save_img_path): 213 | os.makedirs(save_img_path) 214 | img_save = torch.cat([img, gt, init_predict.cpu()], dim=3) 215 | if self.pre_ori == 'True': 216 | if self.high_low_freq == 'True': 217 | img_save = torch.cat([img, gt, init_predict.cpu(), noise_pred.cpu() + self.high_filter(init_predict).cpu(), noise_pred.cpu() + init_predict.cpu()], dim=3) 218 | else: 219 | img_save = torch.cat([img, gt, init_predict.cpu(), noise_pred.cpu() + init_predict.cpu()], dim=3) 220 | save_image(img_save, os.path.join( 221 | save_img_path, f"{iteration}.png"), nrow=4) 222 | iteration += 1 223 | if self.EMA_or_not == 'True': 224 | if iteration % self.ema_every == 0 and iteration > self.start_ema: 225 | print('EMA update') 226 | self.EMA.update_model_average(self.ema_model, self.network) 227 | 228 | if iteration % self.save_model_every == 0: 229 | print('Saving models') 230 | if not os.path.exists(self.weight_save_path): 231 | os.makedirs(self.weight_save_path) 232 | torch.save(self.network.init_predictor.state_dict(), 233 | os.path.join(self.weight_save_path, f'model_init_{iteration}.pth')) 234 | torch.save(self.network.denoiser.state_dict(), 235 | os.path.join(self.weight_save_path, f'model_denoiser_{iteration}.pth')) 236 | if self.EMA_or_not == 'True': 237 | torch.save(self.ema_model.init_predictor.state_dict(), 238 | os.path.join(self.weight_save_path, f'model_init_ema_{iteration}.pth')) 239 | torch.save(self.ema_model.denoiser.state_dict(), 240 | os.path.join(self.weight_save_path, f'model_denoiser_ema_{iteration}.pth')) 241 | 242 | 243 | 244 | def dpm_solver(betas, model, x_T, steps, model_kwargs): 245 | # You need to firstly define your model and the extra inputs of your model, 246 | # And initialize an `x_T` from the standard 
normal distribution. 247 | # `model` has the format: model(x_t, t_input, **model_kwargs). 248 | # If your model has no extra inputs, just let model_kwargs = {}. 249 | 250 | # If you use discrete-time DPMs, you need to further define the 251 | # beta arrays for the noise schedule. 252 | 253 | # model = .... 254 | # model_kwargs = {...} 255 | # x_T = ... 256 | # betas = .... 257 | 258 | # 1. Define the noise schedule. 259 | noise_schedule = NoiseScheduleVP(schedule='discrete', betas=betas) 260 | 261 | # 2. Convert your discrete-time `model` to the continuous-time 262 | # noise prediction model. Here is an example for a diffusion model 263 | # `model` with the noise prediction type ("noise") . 264 | model_fn = model_wrapper( 265 | model, 266 | noise_schedule, 267 | model_type="noise", # or "x_start" or "v" or "score" 268 | model_kwargs=model_kwargs, 269 | ) 270 | 271 | # 3. Define dpm-solver and sample by singlestep DPM-Solver. 272 | # (We recommend singlestep DPM-Solver for unconditional sampling) 273 | # You can adjust the `steps` to balance the computation 274 | # costs and the sample quality. 275 | dpm_solver = DPM_Solver(model_fn, noise_schedule, algorithm_type="dpmsolver++", 276 | correcting_x0_fn="dynamic_thresholding") 277 | # Can also try 278 | # dpm_solver = DPM_Solver(model_fn, noise_schedule, algorithm_type="dpmsolver++") 279 | 280 | # You can use steps = 10, 12, 15, 20, 25, 50, 100. 281 | # Empirically, we find that steps in [10, 20] can generate quite good samples. 282 | # And steps = 20 can almost converge. 283 | x_sample = dpm_solver.sample( 284 | x_T, 285 | steps=steps, 286 | order=1, 287 | skip_type="time_uniform", 288 | method="singlestep", 289 | ) 290 | return x_sample 291 | -------------------------------------------------------------------------------- /utils/font/simhei.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Royalvice/DocDiff/f422a3af10657eeabeca90f2e862d88c72725845/utils/font/simhei.ttf -------------------------------------------------------------------------------- /utils/font/times.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Royalvice/DocDiff/f422a3af10657eeabeca90f2e862d88c72725845/utils/font/times.ttf -------------------------------------------------------------------------------- /utils/font/timesbd.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Royalvice/DocDiff/f422a3af10657eeabeca90f2e862d88c72725845/utils/font/timesbd.ttf -------------------------------------------------------------------------------- /utils/font/timesbi.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Royalvice/DocDiff/f422a3af10657eeabeca90f2e862d88c72725845/utils/font/timesbi.ttf -------------------------------------------------------------------------------- /utils/font/timesi.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Royalvice/DocDiff/f422a3af10657eeabeca90f2e862d88c72725845/utils/font/timesi.ttf -------------------------------------------------------------------------------- /utils/font/方正仿宋_GBK.TTF: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Royalvice/DocDiff/f422a3af10657eeabeca90f2e862d88c72725845/utils/font/方正仿宋_GBK.TTF 
--------------------------------------------------------------------------------
/utils/font/楷体_GB2312.ttf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Royalvice/DocDiff/f422a3af10657eeabeca90f2e862d88c72725845/utils/font/楷体_GB2312.ttf
--------------------------------------------------------------------------------
/utils/font/青鸟华光简琥珀.ttf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Royalvice/DocDiff/f422a3af10657eeabeca90f2e862d88c72725845/utils/font/青鸟华光简琥珀.ttf
--------------------------------------------------------------------------------
/utils/marker.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python
2 | # -*- coding: utf-8 -*-
3 | # This is a watermark tool for images
4 | # Heavily based on https://github.com/2Dou/watermarker
5 | 
6 | 
7 | import argparse
8 | import os
9 | import sys
10 | import math
11 | import textwrap
12 | import random
13 | 
14 | import numpy as np
15 | from PIL import Image, ImageFont, ImageDraw, ImageEnhance, ImageChops, ImageOps
16 | 
17 | 
18 | def randomtext():
19 |     def GBK2312():  # return a random simplified-Chinese character from the GB2312 range
20 |         head = random.randint(0xb0, 0xf7)
21 |         body = random.randint(0xa1, 0xfe)
22 |         val = f'{head:x}{body:x}'
23 |         strr = bytes.fromhex(val).decode('gb2312', errors="ignore")
24 |         return strr
25 | 
26 |     def randomABC():  # random upper- and lower-case letter followed by three digits
27 |         A = np.random.randint(65, 91)
28 |         a = np.random.randint(97, 123)
29 |         char = chr(A) + chr(a)
30 |         for i in range(3):
31 |             char = char + str(np.random.randint(0, 10))  # upper bound is exclusive; 10 allows the digit 9
32 |         return char
33 | 
34 |     char = randomABC()
35 |     for i in range(5):
36 |         char = char + GBK2312()
37 |     return char
38 | 
39 | 
40 | def randomcolor():
41 |     colorArr = ['1', '2', '3', '4', '5', '6', '7', '8', '9', 'A', 'B', 'C', 'D', 'E', 'F']
42 |     color = ""
43 |     for i in range(6):
44 |         color += colorArr[random.randint(0, 14)]
45 |     return "#" + color
46 | 
47 | 
48 | def add_mark(imagePath, mark, args):
49 |     """
50 |     Apply the watermark, then save the image.
51 |     """
52 |     im = Image.open(imagePath)
53 |     im = ImageOps.exif_transpose(im)
54 | 
55 |     image, mask = mark(im)
56 |     name = os.path.basename(imagePath)
57 |     if image:
58 |         if not os.path.exists(args.out):
59 |             os.mkdir(args.out)
60 | 
61 |         new_name = os.path.join(args.out, name)
62 |         if os.path.splitext(new_name)[1] != '.png':
63 |             image = image.convert('RGB')
64 |         image.save(new_name, quality=args.quality)
65 |         #mask.save(os.path.join(args.out, os.path.basename(os.path.splitext(new_name)[0]) + '.png'),
66 |         #          quality=args.quality)
67 |         print(name + " Success.")
68 |     else:
69 |         print(name + " Failed.")
70 | 
71 | 
72 | def set_opacity(im, opacity):
73 |     """
74 |     Set the opacity of the watermark.
75 |     """
76 |     assert 0 <= opacity <= 1
77 | 
78 |     alpha = im.split()[3]
79 |     alpha = ImageEnhance.Brightness(alpha).enhance(opacity)
80 |     im.putalpha(alpha)
81 |     return im
82 | 
83 | 
84 | def crop_image(im):
85 |     """Crop blank margins from the image."""
86 |     bg = Image.new(mode='RGBA', size=im.size)
87 |     diff = ImageChops.difference(im, bg)
88 |     del bg
89 |     bbox = diff.getbbox()
90 |     if bbox:
91 |         return im.crop(bbox)
92 |     return im
93 | 
94 | 
95 | def gen_mark(args):
96 |     """
97 |     Generate the watermark tile and return a function that stamps it onto an image.
98 |     """
99 |     # font width and height
100 |     is_height_crop_float = '.' 
in args.font_height_crop  # crude check, but it works
101 |     width = len(args.mark) * args.size
102 |     if is_height_crop_float:
103 |         height = round(args.size * float(args.font_height_crop))
104 |     else:
105 |         height = int(args.font_height_crop)
106 | 
107 |     # create the watermark tile (width, height)
108 |     mark = Image.new(mode='RGBA', size=(width, height))
109 | 
110 |     # draw the text
111 |     draw_table = ImageDraw.Draw(im=mark)
112 |     draw_table.text(xy=(0, 0),
113 |                     text=args.mark,
114 |                     fill=args.color,
115 |                     font=ImageFont.truetype(args.font_family,
116 |                                             size=args.size))
117 |     del draw_table
118 | 
119 |     # crop blank margins
120 |     mark = crop_image(mark)
121 | 
122 |     # set opacity
123 |     set_opacity(mark, args.opacity)
124 | 
125 |     def mark_im(im):
126 |         """Stamp the watermark onto `im`, the opened source image."""
127 | 
128 |         # length of the image diagonal
129 |         c = int(math.sqrt(im.size[0] * im.size[0] + im.size[1] * im.size[1]))
130 | 
131 |         # create a canvas as large as the diagonal, so it still covers the image after rotation
132 |         mark2 = Image.new(mode='RGBA', size=(c, c))
133 | 
134 |         # tile the watermark over the canvas; `mark` is the tile generated above
135 |         y, idx = 0, 0
136 |         while y < c:
137 |             # stagger the x offset on alternating rows
138 |             x = -int((mark.size[0] + args.space) * 0.5 * idx)
139 |             idx = (idx + 1) % 2
140 | 
141 |             while x < c:
142 |                 # paste the watermark tile at this position
143 |                 mark2.paste(mark, (x, y))
144 |                 x = x + mark.size[0] + args.space
145 |             y = y + mark.size[1] + args.space
146 | 
147 |         # rotate the canvas
148 |         mark2 = mark2.rotate(args.angle)
149 | 
150 |         # composite the rotated canvas onto the source image
151 |         if im.mode != 'RGBA':
152 |             im = im.convert('RGBA')
153 |         mask = np.zeros_like(np.array(im))
154 |         mask = Image.fromarray(mask)  # black image that will receive only the watermark (the ground-truth mask)
155 |         im.paste(mark2,  # canvas
156 |                  (int((im.size[0] - c) / 2), int((im.size[1] - c) / 2)),  # position
157 |                  mask=mark2.split()[3])
158 |         mask.paste(mark2,  # canvas
159 |                    (int((im.size[0] - c) / 2), int((im.size[1] - c) / 2)),  # position
160 |                    mask=mark2.split()[3])
161 |         #mask.show()
162 |         del mark2
163 |         return im, mask
164 | 
165 |     return mark_im
166 | 
167 | 
168 | def main():
169 |     parse = argparse.ArgumentParser(formatter_class=argparse.RawTextHelpFormatter)
170 |     parse.add_argument("-f", "--file", type=str, default="./input",
171 |                        help="image file path or directory")
172 |     parse.add_argument("-m", "--mark", type=str, default="测试一句话行不行", help="watermark content")
173 |     parse.add_argument("-o", "--out", default="./output",
174 |                        help="image output directory, default is ./output")
175 |     parse.add_argument("-c", "--color", default="#8B8B1B", type=str,
176 |                        help="text color like '#000000', default is #8B8B1B")
177 |     parse.add_argument("-s", "--space", default=75, type=int,
178 |                        help="space between watermarks, default is 75")
179 |     parse.add_argument("-a", "--angle", default=90, type=int,
180 |                        help="rotate angle of watermarks, default is 90")
181 |     parse.add_argument("--font-family", default="./font/青鸟华光简琥珀.ttf", type=str,
182 |                        help=textwrap.dedent('''\
183 |                            font family of text, default is './font/青鸟华光简琥珀.ttf'
184 |                            to use a system font, pass just the font file name
185 |                            for example 'PingFang.ttc', which is installed by default on macOS
186 |                            '''))
187 |     parse.add_argument("--font-height-crop", default="1.2", type=str,
188 |                        help=textwrap.dedent('''\
189 |                            adjust the crop height of the watermark text
190 |                            a float is parsed as a factor; an int is parsed as an absolute value
191 |                            default is '1.2', meaning 1.2 times the font size
192 |                            useful with CJK fonts, whose line height may exceed the font size
193 |                            '''))
194 |     parse.add_argument("--size", default=50, type=int,
195 |                        help="font size of text, default is 50")
196 |     parse.add_argument("--opacity", default=0.15, type=float,
197 |                        help="opacity of watermarks, default is 0.15")
198 |     parse.add_argument("--quality", default=100, type=int,
help="quality of output images, default is 90") 200 | 201 | args = parse.parse_args() 202 | 203 | # 随机参数,从[A,B]中随机选取一个整数 204 | space_ = [30, 40] 205 | angel_ = [0, 180] 206 | size_ = [20, 60] 207 | opacity_ = [50, 95] 208 | # 字体随机(需要别的字体的话,可以去字体网站下载) 209 | fonts = [r'./font/青鸟华光简琥珀.ttf', r'./font/楷体_GB2312.ttf', './font/方正仿宋_GBK.TTF', './font/times.ttf', 210 | './font/simhei.ttf'] 211 | if isinstance(args.mark, str) and sys.version_info[0] < 3: 212 | args.mark = args.mark.decode("utf-8") 213 | 214 | if os.path.isdir(args.file): 215 | names = os.listdir(args.file) 216 | for name in names: 217 | image_file = os.path.join(args.file, name) 218 | space = random.randint(space_[0], space_[1]) 219 | args.space = space 220 | angel = random.randint(angel_[0], angel_[1]) 221 | args.angle = angel 222 | size = random.randint(size_[0], size_[1]) 223 | args.size = size 224 | opacity = random.randint(opacity_[0], opacity_[1]) / 100 225 | args.opacity = opacity 226 | color = randomcolor() 227 | args.color = color 228 | font = np.random.choice(fonts) 229 | args.font_family = font 230 | mark_text = randomtext() 231 | args.mark = mark_text 232 | mark = gen_mark(args) 233 | add_mark(image_file, mark, args) 234 | else: 235 | mark = gen_mark(args) 236 | add_mark(args.file, mark, args) 237 | 238 | 239 | if __name__ == '__main__': 240 | main() 241 | --------------------------------------------------------------------------------