├── AI_ETHICS.md ├── CODEOWNERS ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── LICENSE.txt ├── README.md ├── SECURITY.md ├── configs ├── generate.yaml ├── generate_v21_base.yaml └── train_v21_base.yaml ├── edit_cli.py ├── edit_cli_batch.py ├── edit_cli_batch_rw_label.py ├── edit_cli_rw_label.py ├── edit_dataset.py ├── environment.yaml ├── imgs ├── example1.jpg ├── example2.jpg ├── example3.jpg └── results.png ├── main.py ├── metrics ├── clip_similarity.py └── compute_metrics.py ├── scripts ├── download_checkpoints.sh ├── download_hive_data.sh └── download_instructpix2pix_data.sh └── stable_diffusion ├── LICENSE ├── README.md ├── Stable_Diffusion_v1_Model_Card.md ├── assets ├── a-painting-of-a-fire.png ├── a-photograph-of-a-fire.png ├── a-shirt-with-a-fire-printed-on-it.png ├── a-shirt-with-the-inscription-'fire'.png ├── a-watercolor-painting-of-a-fire.png ├── birdhouse.png ├── fire.png ├── inpainting.png ├── modelfigure.png ├── rdm-preview.jpg ├── reconstruction1.png ├── reconstruction2.png ├── results.gif.REMOVED.git-id ├── rick.jpeg ├── stable-samples │ ├── img2img │ │ ├── mountains-1.png │ │ ├── mountains-2.png │ │ ├── mountains-3.png │ │ ├── sketch-mountains-input.jpg │ │ ├── upscaling-in.png.REMOVED.git-id │ │ └── upscaling-out.png.REMOVED.git-id │ └── txt2img │ │ ├── 000002025.png │ │ ├── 000002035.png │ │ ├── merged-0005.png.REMOVED.git-id │ │ ├── merged-0006.png.REMOVED.git-id │ │ └── merged-0007.png.REMOVED.git-id ├── the-earth-is-on-fire,-oil-on-canvas.png ├── txt2img-convsample.png ├── txt2img-preview.png.REMOVED.git-id └── v1-variants-scores.jpg ├── configs ├── autoencoder │ ├── autoencoder_kl_16x16x16.yaml │ ├── autoencoder_kl_32x32x4.yaml │ ├── autoencoder_kl_64x64x3.yaml │ └── autoencoder_kl_8x8x64.yaml ├── latent-diffusion │ ├── celebahq-ldm-vq-4.yaml │ ├── cin-ldm-vq-f8.yaml │ ├── cin256-v2.yaml │ ├── ffhq-ldm-vq-4.yaml │ ├── lsun_bedrooms-ldm-vq-4.yaml │ ├── lsun_churches-ldm-kl-8.yaml │ └── txt2img-1p4B-eval.yaml ├── retrieval-augmented-diffusion │ └── 768x768.yaml └── stable-diffusion │ └── v1-inference.yaml ├── data ├── DejaVuSans.ttf ├── example_conditioning │ ├── superresolution │ │ └── sample_0.jpg │ └── text_conditional │ │ └── sample_0.txt ├── imagenet_clsidx_to_label.txt ├── imagenet_train_hr_indices.p.REMOVED.git-id ├── imagenet_val_hr_indices.p ├── index_synset.yaml └── inpainting_examples │ ├── 6458524847_2f4c361183_k.png │ ├── 6458524847_2f4c361183_k_mask.png │ ├── 8399166846_f6fb4e4b8e_k.png │ ├── 8399166846_f6fb4e4b8e_k_mask.png │ ├── alex-iby-G_Pk4D9rMLs.png │ ├── alex-iby-G_Pk4D9rMLs_mask.png │ ├── bench2.png │ ├── bench2_mask.png │ ├── bertrand-gabioud-CpuFzIsHYJ0.png │ ├── bertrand-gabioud-CpuFzIsHYJ0_mask.png │ ├── billow926-12-Wc-Zgx6Y.png │ ├── billow926-12-Wc-Zgx6Y_mask.png │ ├── overture-creations-5sI6fQgYIuo.png │ ├── overture-creations-5sI6fQgYIuo_mask.png │ ├── photo-1583445095369-9c651e7e5d34.png │ └── photo-1583445095369-9c651e7e5d34_mask.png ├── environment.yaml ├── ldm ├── data │ ├── __init__.py │ ├── base.py │ ├── imagenet.py │ └── lsun.py ├── lr_scheduler.py ├── models │ ├── autoencoder.py │ └── diffusion │ │ ├── __init__.py │ │ ├── classifier.py │ │ ├── ddim.py │ │ ├── ddpm.py │ │ ├── ddpm_edit.py │ │ ├── ddpm_edit_rw.py │ │ ├── ddpm_edit_v21.py │ │ ├── dpm_solver │ │ ├── __init__.py │ │ ├── dpm_solver.py │ │ └── sampler.py │ │ └── plms.py ├── modules │ ├── attention.py │ ├── attention_v21.py │ ├── diffusionmodules │ │ ├── __init__.py │ │ ├── model.py │ │ ├── openaimodel.py │ │ ├── openaimodel_v21.py │ │ └── util.py │ ├── distributions │ 
│ ├── __init__.py │ │ └── distributions.py │ ├── ema.py │ ├── encoders │ │ ├── __init__.py │ │ └── modules.py │ ├── image_degradation │ │ ├── __init__.py │ │ ├── bsrgan.py │ │ ├── bsrgan_light.py │ │ ├── utils │ │ │ └── test.png │ │ └── utils_image.py │ ├── losses │ │ ├── __init__.py │ │ ├── contperceptual.py │ │ └── vqperceptual.py │ └── x_transformer.py └── util.py ├── main.py ├── models ├── first_stage_models │ ├── kl-f16 │ │ └── config.yaml │ ├── kl-f32 │ │ └── config.yaml │ ├── kl-f4 │ │ └── config.yaml │ ├── kl-f8 │ │ └── config.yaml │ ├── vq-f16 │ │ └── config.yaml │ ├── vq-f4-noattn │ │ └── config.yaml │ ├── vq-f4 │ │ └── config.yaml │ ├── vq-f8-n256 │ │ └── config.yaml │ └── vq-f8 │ │ └── config.yaml └── ldm │ ├── bsr_sr │ └── config.yaml │ ├── celeba256 │ └── config.yaml │ ├── cin256 │ └── config.yaml │ ├── ffhq256 │ └── config.yaml │ ├── inpainting_big │ └── config.yaml │ ├── layout2img-openimages256 │ └── config.yaml │ ├── lsun_beds256 │ └── config.yaml │ ├── lsun_churches256 │ └── config.yaml │ ├── semantic_synthesis256 │ └── config.yaml │ ├── semantic_synthesis512 │ └── config.yaml │ └── text2img256 │ └── config.yaml ├── notebook_helpers.py ├── scripts ├── download_first_stages.sh ├── download_models.sh ├── img2img.py ├── inpaint.py ├── knn2img.py ├── latent_imagenet_diffusion.ipynb.REMOVED.git-id ├── sample_diffusion.py ├── tests │ └── test_watermark.py ├── train_searcher.py └── txt2img.py └── setup.py /AI_ETHICS.md: -------------------------------------------------------------------------------- 1 | ## Ethics disclaimer for Salesforce AI models, data, code 2 | 3 | This release is for research purposes only in support of an academic 4 | paper. Our models, datasets, and code are not specifically designed or 5 | evaluated for all downstream purposes. We strongly recommend users 6 | evaluate and address potential concerns related to accuracy, safety, and 7 | fairness before deploying this model. We encourage users to consider the 8 | common limitations of AI, comply with applicable laws, and leverage best 9 | practices when selecting use cases, particularly for high-risk scenarios 10 | where errors or misuse could significantly impact people’s lives, rights, 11 | or safety. For further guidance on use cases, refer to our standard 12 | [AUP](https://www.salesforce.com/content/dam/web/en_us/www/documents/legal/Agreements/policies/ExternalFacing_Services_Policy.pdf) 13 | and [AI AUP](https://www.salesforce.com/content/dam/web/en_us/www/documents/legal/Agreements/policies/ai-acceptable-use-policy.pdf). 14 | -------------------------------------------------------------------------------- /CODEOWNERS: -------------------------------------------------------------------------------- 1 | # Comment line immediately above ownership line is reserved for related other information. Please be careful while editing. 2 | #ECCN:Open Source 3 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Salesforce Open Source Community Code of Conduct 2 | 3 | ## About the Code of Conduct 4 | 5 | Equality is a core value at Salesforce. We believe a diverse and inclusive 6 | community fosters innovation and creativity, and are committed to building a 7 | culture where everyone feels included. 
8 | 9 | Salesforce open-source projects are committed to providing a friendly, safe, and 10 | welcoming environment for all, regardless of gender identity and expression, 11 | sexual orientation, disability, physical appearance, body size, ethnicity, nationality, 12 | race, age, religion, level of experience, education, socioeconomic status, or 13 | other similar personal characteristics. 14 | 15 | The goal of this code of conduct is to specify a baseline standard of behavior so 16 | that people with different social values and communication styles can work 17 | together effectively, productively, and respectfully in our open source community. 18 | It also establishes a mechanism for reporting issues and resolving conflicts. 19 | 20 | All questions and reports of abusive, harassing, or otherwise unacceptable behavior 21 | in a Salesforce open-source project may be reported by contacting the Salesforce 22 | Open Source Conduct Committee at ossconduct@salesforce.com. 23 | 24 | ## Our Pledge 25 | 26 | In the interest of fostering an open and welcoming environment, we as 27 | contributors and maintainers pledge to making participation in our project and 28 | our community a harassment-free experience for everyone, regardless of gender 29 | identity and expression, sexual orientation, disability, physical appearance, 30 | body size, ethnicity, nationality, race, age, religion, level of experience, education, 31 | socioeconomic status, or other similar personal characteristics. 32 | 33 | ## Our Standards 34 | 35 | Examples of behavior that contributes to creating a positive environment 36 | include: 37 | 38 | * Using welcoming and inclusive language 39 | * Being respectful of differing viewpoints and experiences 40 | * Gracefully accepting constructive criticism 41 | * Focusing on what is best for the community 42 | * Showing empathy toward other community members 43 | 44 | Examples of unacceptable behavior by participants include: 45 | 46 | * The use of sexualized language or imagery and unwelcome sexual attention or 47 | advances 48 | * Personal attacks, insulting/derogatory comments, or trolling 49 | * Public or private harassment 50 | * Publishing, or threatening to publish, others' private information—such as 51 | a physical or electronic address—without explicit permission 52 | * Other conduct which could reasonably be considered inappropriate in a 53 | professional setting 54 | * Advocating for or encouraging any of the above behaviors 55 | 56 | ## Our Responsibilities 57 | 58 | Project maintainers are responsible for clarifying the standards of acceptable 59 | behavior and are expected to take appropriate and fair corrective action in 60 | response to any instances of unacceptable behavior. 61 | 62 | Project maintainers have the right and responsibility to remove, edit, or 63 | reject comments, commits, code, wiki edits, issues, and other contributions 64 | that are not aligned with this Code of Conduct, or to ban temporarily or 65 | permanently any contributor for other behaviors that they deem inappropriate, 66 | threatening, offensive, or harmful. 67 | 68 | ## Scope 69 | 70 | This Code of Conduct applies both within project spaces and in public spaces 71 | when an individual is representing the project or its community. Examples of 72 | representing a project or community include using an official project email 73 | address, posting via an official social media account, or acting as an appointed 74 | representative at an online or offline event. 
Representation of a project may be
75 | further defined and clarified by project maintainers.
76 | 
77 | ## Enforcement
78 | 
79 | Instances of abusive, harassing, or otherwise unacceptable behavior may be
80 | reported by contacting the Salesforce Open Source Conduct Committee
81 | at ossconduct@salesforce.com. All complaints will be reviewed and investigated
82 | and will result in a response that is deemed necessary and appropriate to the
83 | circumstances. The committee is obligated to maintain confidentiality with
84 | regard to the reporter of an incident. Further details of specific enforcement
85 | policies may be posted separately.
86 | 
87 | Project maintainers who do not follow or enforce the Code of Conduct in good
88 | faith may face temporary or permanent repercussions as determined by other
89 | members of the project's leadership and the Salesforce Open Source Conduct
90 | Committee.
91 | 
92 | ## Attribution
93 | 
94 | This Code of Conduct is adapted from the [Contributor Covenant][contributor-covenant-home],
95 | version 1.4, available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html.
96 | It includes adaptations and additions from the [Go Community Code of Conduct][golang-coc],
97 | [CNCF Code of Conduct][cncf-coc], and [Microsoft Open Source Code of Conduct][microsoft-coc].
98 | 
99 | This Code of Conduct is licensed under the [Creative Commons Attribution 3.0 License][cc-by-3-us].
100 | 
101 | [contributor-covenant-home]: https://www.contributor-covenant.org (https://www.contributor-covenant.org/)
102 | [golang-coc]: https://golang.org/conduct
103 | [cncf-coc]: https://github.com/cncf/foundation/blob/master/code-of-conduct.md
104 | [microsoft-coc]: https://opensource.microsoft.com/codeofconduct/
105 | [cc-by-3-us]: https://creativecommons.org/licenses/by/3.0/us/
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # Contributing Guide
2 | 
3 | This page describes the operational governance model of this project, as well as the recommendations and requirements for how to best contribute to it. We strive to follow these guidelines as closely as possible. As always, thanks for contributing – we hope these guidelines make it easier and shed some light on our approach and processes.
4 | 
5 | # Governance Model
6 | 
7 | ## Published but not supported
8 | 
9 | We are open sourcing this project because it may contain useful or interesting code/concepts that we wish to share with the larger open source community. Although occasional work may be done on it, we will not be looking for or soliciting contributions.
10 | 
11 | # Getting started
12 | 
13 | Please join the community. Also please make sure to take a look at the project [roadmap](ROADMAP.md), if it exists, to see where we are headed.
14 | 
15 | # Issues, requests & ideas
16 | 
17 | Use the GitHub Issues page to submit issues and enhancement requests, and to discuss ideas.
18 | 
19 | ### Bug Reports and Fixes
20 | - If you find a bug, please search for it in the Issues, and if it isn't already tracked,
21 | create a new issue. Fill out the "Bug Report" section of the issue template. Even if an Issue is closed, feel free to comment and add details; it will still
22 | be reviewed.
23 | - Issues that have already been identified as a bug (note: able to reproduce) will be labelled `bug`.
24 | - If you'd like to submit a fix for a bug, [send a Pull Request](#creating_a_pull_request) and mention the Issue number.
25 | - Include tests that isolate the bug and verify that it was fixed.
26 | 
27 | ### New Features
28 | - If you'd like to add new functionality to this project, describe the problem you want to solve in a new Issue.
29 | - Issues that have been identified as a feature request will be labelled `enhancement`.
30 | - If you'd like to implement the new feature, please wait for feedback from the project
31 | maintainers before spending too much time writing the code. In some cases, `enhancement`s may
32 | not align well with the project objectives at the time.
33 | 
34 | ### Tests, Documentation, Miscellaneous
35 | - If you'd like to improve the tests, make the documentation clearer, propose an
36 | alternative implementation of something that may have advantages over the way it is currently
37 | done, or make any other change, we would be happy to hear about it!
38 | - If it is a trivial change, go ahead and [send a Pull Request](#creating_a_pull_request) with the changes you have in mind.
39 | - If not, open an Issue to discuss the idea first.
40 | 
41 | If you're new to our project and looking for some way to make your first contribution, look for
42 | Issues labelled `good first contribution`.
43 | 
44 | # Contribution Checklist
45 | 
46 | - [x] Clean, simple, well-styled code
47 | - [x] Commits should be atomic and messages must be descriptive. Related issues should be mentioned by Issue number.
48 | - [x] Comments
49 |   - Module-level & function-level comments.
50 |   - Comments on complex blocks of code or algorithms (include references to sources).
51 | - [x] Tests
52 |   - The test suite, if provided, must be complete and pass.
53 |   - Increase code coverage, not the other way around.
54 |   - Use any of our test kits, which contain the testing facilities you will need (for example, `import com.salesforce.op.test._`), and borrow inspiration from existing tests.
55 | - [x] Dependencies
56 |   - Minimize the number of dependencies.
57 |   - Prefer Apache 2.0, BSD3, MIT, ISC and MPL licenses.
58 | - [x] Reviews
59 |   - Changes must be approved via peer code review.
60 | 
61 | # Creating a Pull Request
62 | 
63 | 1. **Ensure the bug/feature was not already reported** by searching on GitHub under Issues. If none exists, create a new issue so that other contributors can keep track of what you are trying to add/fix and offer suggestions (or let you know if there is already an effort in progress).
64 | 2. **Clone** the forked repo to your machine.
65 | 3. **Create** a new branch to contain your work (e.g. `git checkout -b fix-issue-11`).
66 | 4. **Commit** changes to your own branch.
67 | 5. **Push** your work back up to your fork (e.g. `git push origin fix-issue-11`).
68 | 6. **Submit** a Pull Request against the `main` branch and refer to the issue(s) you are fixing. Try not to pollute your pull request with unintended changes. Keep it simple and small.
69 | 7. **Sign** the Salesforce CLA (you will be prompted to do so when submitting the Pull Request).
70 | 
71 | > **NOTE**: Be sure to [sync your fork](https://help.github.com/articles/syncing-a-fork/) before making a pull request.
72 | 
73 | 
74 | # Code of Conduct
75 | Please follow our [Code of Conduct](CODE_OF_CONDUCT.md).
76 | 
77 | # License
78 | By contributing your code, you agree to license your contribution under the terms of our project [LICENSE](LICENSE.txt) and to sign the [Salesforce CLA](https://cla.salesforce.com/sign-cla).
79 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # HIVE
2 | 
3 | ### [HIVE: Harnessing Human Feedback for Instructional Visual Editing](https://arxiv.org/pdf/2303.09618.pdf)
4 | Shu Zhang<sup>\*1</sup>, Xinyi Yang<sup>\*1</sup>, Yihao Feng<sup>\*1</sup>, Can Qin<sup>3</sup>, Chia-Chih Chen<sup>1</sup>, Ning Yu<sup>1</sup>, Zeyuan Chen<sup>1</sup>, Huan Wang<sup>1</sup>, Silvio Savarese<sup>1,2</sup>, Stefano Ermon<sup>2</sup>, Caiming Xiong<sup>1</sup>, and Ran Xu<sup>1</sup>
5 | <sup>1</sup>Salesforce AI, <sup>2</sup>Stanford University, <sup>3</sup>Northeastern University
6 | \*denotes equal contribution
7 | arXiv 2023
8 | 
9 | ### [paper](https://arxiv.org/pdf/2303.09618.pdf) | [project page](https://shugerdou.github.io/hive/)
10 | 
11 | 
12 | 
13 | This is a PyTorch implementation of [HIVE: Harnessing Human Feedback for Instructional Visual Editing](https://arxiv.org/pdf/2303.09618.pdf). The major part of the code follows [InstructPix2Pix](https://github.com/timothybrooks/instruct-pix2pix). In this repo, we have implemented both [stable diffusion v1.5-base](https://huggingface.co/runwayml/stable-diffusion-v1-5) and [stable diffusion v2.1-base](https://huggingface.co/stabilityai/stable-diffusion-2-1-base) as the backbone.
14 | 
15 | 
16 | ## Updates
17 | * **07/08/23**: ***Training code and training data are public.***:blush:
18 | * **03/26/24**: ***HIVE will appear in CVPR 2024.***:blush:
19 | 
20 | ## Usage
21 | 
22 | ### Preparation
23 | First, set up the ```hive``` environment and download the pretrained models as shown below. This has only been verified with CUDA 11.0 and CUDA 11.3 on an NVIDIA A100 GPU.
24 | 
25 | ```
26 | conda env create -f environment.yaml
27 | conda activate hive
28 | bash scripts/download_checkpoints.sh
29 | ```
30 | 
31 | To fine-tune a stable diffusion model, you need to obtain the pre-trained stable diffusion models following their [instructions](https://github.com/runwayml/stable-diffusion). If you use SD-V1.5, you can download the HuggingFace weights from [HuggingFace SD 1.5](https://huggingface.co/runwayml/stable-diffusion-v1-5/resolve/main/v1-5-pruned-emaonly.ckpt). If you use SD-V2.1, the weights can be downloaded from [HuggingFace SD 2.1](https://huggingface.co/stabilityai/stable-diffusion-2-1-base). You can decide which version of the checkpoint to use; we use ```v2-1_512-ema-pruned.ckpt```. Download the model to ```checkpoints/```.
32 | 
33 | 
34 | ### Data
35 | We suggest installing the Gcloud CLI following [Gcloud download](https://cloud.google.com/sdk/docs/install). To obtain both the training and evaluation data, run
36 | ```
37 | bash scripts/download_hive_data.sh
38 | ```
39 | 
40 | Alternatively, you can download the data directly through [Evaluation data](https://storage.cloud.google.com/sfr-hive-data-research/data/evaluation.zip) and [Evaluation instructions](https://storage.cloud.google.com/sfr-hive-data-research/data/test.jsonl).
41 | 
42 | 
43 | ### Step-1 Training
44 | For SD v2.1, we run
45 | 
46 | ```
47 | python main.py --name step1 --base configs/train_v21_base.yaml --train --gpus 0,1,2,3,4,5,6,7
48 | ```
49 | 
50 | ### Inference
51 | Samples can be obtained by running the commands below.
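At inference time, the conditional-reward scripts (`edit_cli_rw_label.py`, `edit_cli_batch_rw_label.py`) differ from the weighted-reward scripts (`edit_cli.py`, `edit_cli_batch.py`) mainly in that they append a reward label to the edit instruction before encoding it. A minimal sketch of this prompt construction, mirroring `edit_cli_rw_label.py`:

```python
# Minimal illustration of how the conditional-reward scripts build the text prompt:
# the instruction is suffixed with the desired quality label before it is passed
# to model.get_learned_conditioning (see edit_cli_rw_label.py in this repo).
instruction = "move it to Mars"
reward_conditioned_edit = instruction + ", image quality is five out of five"
print(reward_conditioned_edit)  # "move it to Mars, image quality is five out of five"
```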
52 | 53 | For SD v2.1, if we use the conditional reward, we run 54 | 55 | ``` 56 | python edit_cli_rw_label.py --steps 100 --resolution 512 --seed 100 --cfg-text 7.5 --cfg-image 1.5 \ 57 | --input imgs/example1.jpg --output imgs/output.jpg --edit "move it to Mars" --ckpt checkpoints/hive_v2_rw_condition.ckpt \ 58 | --config configs/generate_v21_base.yaml 59 | ``` 60 | 61 | 62 | or run batch inference on our inference data: 63 | 64 | ``` 65 | python edit_cli_batch_rw_label.py --steps 100 --resolution 512 --seed 100 --cfg-text 7.5 --cfg-image 1.5 \ 66 | --jsonl_file data/test.jsonl --output_dir imgs/sdv21_rw_label/ --ckpt checkpoints/hive_v2_rw_condition.ckpt \ 67 | --config configs/generate_v21_base.yaml --image_dir data/evaluation/ 68 | ``` 69 | 70 | For SD v2.1, if we use the weighted reward, we can run 71 | 72 | 73 | ``` 74 | python edit_cli.py --steps 100 --resolution 512 --seed 100 --cfg-text 7.5 --cfg-image 1.5 \ 75 | --input imgs/example1.jpg --output imgs/output.jpg --edit "move it to Mars" \ 76 | --ckpt checkpoints/hive_v2_rw.ckpt --config configs/generate_v21_base.yaml 77 | ``` 78 | 79 | or run batch inference on our inference data: 80 | 81 | ``` 82 | python edit_cli_batch.py --steps 100 --resolution 512 --seed 100 --cfg-text 7.5 --cfg-image 1.5 \ 83 | --jsonl_file data/test.jsonl --output_dir imgs/sdv21/ --ckpt checkpoints/hive_v2_rw.ckpt \ 84 | --config configs/generate_v21_base.yaml --image_dir data/evaluation/ 85 | ``` 86 | 87 | For SD v1.5, if we use the conditional reward, we can run 88 | 89 | ``` 90 | python edit_cli_rw_label.py --steps 100 --resolution 512 --seed 100 --cfg-text 7.5 --cfg-image 1.5 \ 91 | --input imgs/example1.jpg --output imgs/output.jpg --edit "move it to Mars" \ 92 | --ckpt checkpoints/hive_rw_condition.ckpt --config configs/generate.yaml 93 | ``` 94 | 95 | or run batch inference on our inference data: 96 | 97 | ``` 98 | python edit_cli_batch_rw_label.py --steps 100 --resolution 512 --seed 100 --cfg-text 7.5 --cfg-image 1.5 \ 99 | --jsonl_file data/test.jsonl --output_dir imgs/sdv15_rw_label/ \ 100 | --ckpt checkpoints/hive_rw_condition.ckpt --config configs/generate.yaml \ 101 | --image_dir data/evaluation/ 102 | ``` 103 | 104 | For SD v1.5, if we use the weighted reward, we run 105 | 106 | 107 | ``` 108 | python edit_cli.py --steps 100 --resolution 512 --seed 100 --cfg-text 7.5 --cfg-image 1.5 --input imgs/example1.jpg \ 109 | --output imgs/output.jpg --edit "move it to Mars" \ 110 | --ckpt checkpoints/hive_rw.ckpt --config configs/generate.yaml 111 | ``` 112 | 113 | or run batch inference on our inference data: 114 | 115 | ``` 116 | python edit_cli_batch.py --steps 100 --resolution 512 --seed 100 --cfg-text 7.5 --cfg-image 1.5 \ 117 | --jsonl_file data/test.jsonl --output_dir imgs/sdv15/ \ 118 | --ckpt checkpoints/hive_rw.ckpt --config configs/generate.yaml \ 119 | --image_dir data/evaluation/ 120 | ``` 121 | 122 | ## Citation 123 | ``` 124 | @article{zhang2023hive, 125 | title={HIVE: Harnessing Human Feedback for Instructional Visual Editing}, 126 | author={Zhang, Shu and Yang, Xinyi and Feng, Yihao and Qin, Can and Chen, Chia-Chih and Yu, Ning and Chen, Zeyuan and Wang, Huan and Savarese, Silvio and Ermon, Stefano and Xiong, Caiming and Xu, Ran}, 127 | journal={arXiv preprint arXiv:2303.09618}, 128 | year={2023} 129 | } 130 | ``` 131 | 132 | 133 | -------------------------------------------------------------------------------- /SECURITY.md: -------------------------------------------------------------------------------- 1 | ## Security 2 | 3 | 
Please report any security issue to [security@salesforce.com](mailto:security@salesforce.com) 4 | as soon as it is discovered. This library limits its runtime dependencies in 5 | order to reduce the total cost of ownership as much as can be, but all consumers 6 | should remain vigilant and have their security stakeholders review all third-party 7 | products (3PP) like this one and their dependencies. -------------------------------------------------------------------------------- /configs/generate.yaml: -------------------------------------------------------------------------------- 1 | # File modified by authors of InstructPix2Pix from original (https://github.com/CompVis/stable-diffusion). 2 | # See more details in LICENSE. 3 | 4 | model: 5 | base_learning_rate: 1.0e-04 6 | target: ldm.models.diffusion.ddpm_edit.LatentDiffusion 7 | params: 8 | linear_start: 0.00085 9 | linear_end: 0.0120 10 | num_timesteps_cond: 1 11 | log_every_t: 200 12 | timesteps: 1000 13 | first_stage_key: edited 14 | cond_stage_key: edit 15 | # image_size: 64 16 | # image_size: 32 17 | image_size: 16 18 | channels: 4 19 | cond_stage_trainable: false # Note: different from the one we trained before 20 | conditioning_key: hybrid 21 | monitor: val/loss_simple_ema 22 | scale_factor: 0.18215 23 | use_ema: true 24 | load_ema: true 25 | 26 | scheduler_config: # 10000 warmup steps 27 | target: ldm.lr_scheduler.LambdaLinearScheduler 28 | params: 29 | warm_up_steps: [ 0 ] 30 | cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases 31 | f_start: [ 1.e-6 ] 32 | f_max: [ 1. ] 33 | f_min: [ 1. ] 34 | 35 | unet_config: 36 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 37 | params: 38 | image_size: 32 # unused 39 | in_channels: 8 40 | out_channels: 4 41 | model_channels: 320 42 | attention_resolutions: [ 4, 2, 1 ] 43 | num_res_blocks: 2 44 | channel_mult: [ 1, 2, 4, 4 ] 45 | num_heads: 8 46 | use_spatial_transformer: True 47 | transformer_depth: 1 48 | context_dim: 768 49 | use_checkpoint: True 50 | legacy: False 51 | 52 | first_stage_config: 53 | target: ldm.models.autoencoder.AutoencoderKL 54 | params: 55 | embed_dim: 4 56 | monitor: val/rec_loss 57 | ddconfig: 58 | double_z: true 59 | z_channels: 4 60 | resolution: 256 61 | in_channels: 3 62 | out_ch: 3 63 | ch: 128 64 | ch_mult: 65 | - 1 66 | - 2 67 | - 4 68 | - 4 69 | num_res_blocks: 2 70 | attn_resolutions: [] 71 | dropout: 0.0 72 | lossconfig: 73 | target: torch.nn.Identity 74 | 75 | cond_stage_config: 76 | target: ldm.modules.encoders.modules.FrozenCLIPEmbedder 77 | 78 | data: 79 | target: main.DataModuleFromConfig 80 | params: 81 | batch_size: 128 82 | num_workers: 1 83 | wrap: false 84 | validation: 85 | target: edit_dataset.EditDataset 86 | params: 87 | path: ./data/training/instructpix2pix 88 | cache_dir: data/ 89 | cache_name: data_10k 90 | split: val 91 | min_text_sim: 0.2 92 | min_image_sim: 0.75 93 | min_direction_sim: 0.2 94 | max_samples_per_prompt: 1 95 | min_resize_res: 512 96 | max_resize_res: 512 97 | crop_res: 512 98 | output_as_edit: False 99 | real_input: True 100 | -------------------------------------------------------------------------------- /configs/generate_v21_base.yaml: -------------------------------------------------------------------------------- 1 | # File modified by authors of InstructPix2Pix from original (https://github.com/CompVis/stable-diffusion). 2 | # See more details in LICENSE. 
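# Note: compared to configs/generate.yaml (the SD v1.5 generation config above), this
# v2.1-base config loads ldm.models.diffusion.ddpm_edit_v21.LatentDiffusion with
# ckpt_path pointing at checkpoints/v2-1_512-ema-pruned.ckpt, swaps the UNet to
# ldm.modules.diffusionmodules.openaimodel_v21.UNetModel (num_head_channels: 64,
# use_linear_in_transformer: true), and conditions on OpenCLIP text embeddings
# (FrozenOpenCLIPEmbedder, penultimate layer) with context_dim 1024 instead of the
# CLIP embeddings (FrozenCLIPEmbedder) with context_dim 768 used for v1.5.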
3 | 4 | model: 5 | base_learning_rate: 1.0e-04 6 | target: ldm.models.diffusion.ddpm_edit_v21.LatentDiffusion 7 | params: 8 | ckpt_path: checkpoints/v2-1_512-ema-pruned.ckpt 9 | linear_start: 0.00085 10 | linear_end: 0.0120 11 | num_timesteps_cond: 1 12 | log_every_t: 200 13 | timesteps: 1000 14 | first_stage_key: edited 15 | cond_stage_key: edit 16 | image_size: 32 17 | channels: 4 18 | cond_stage_trainable: false # Note: different from the one we trained before 19 | conditioning_key: hybrid 20 | monitor: val/loss_simple_ema 21 | scale_factor: 0.18215 22 | use_ema: true 23 | load_ema: true 24 | 25 | scheduler_config: # 10000 warmup steps 26 | target: ldm.lr_scheduler.LambdaLinearScheduler 27 | params: 28 | warm_up_steps: [ 0 ] 29 | cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases 30 | f_start: [ 1.e-6 ] 31 | f_max: [ 1. ] 32 | f_min: [ 1. ] 33 | 34 | unet_config: 35 | target: ldm.modules.diffusionmodules.openaimodel_v21.UNetModel 36 | params: 37 | image_size: 32 # unused 38 | in_channels: 8 39 | out_channels: 4 40 | model_channels: 320 41 | attention_resolutions: [ 4, 2, 1 ] 42 | num_res_blocks: 2 43 | channel_mult: [ 1, 2, 4, 4 ] 44 | num_heads: 8 45 | num_head_channels: 64 # need to fix for flash-attn 46 | use_spatial_transformer: True 47 | use_linear_in_transformer: True 48 | transformer_depth: 1 49 | context_dim: 1024 50 | use_checkpoint: True 51 | legacy: False 52 | 53 | first_stage_config: 54 | target: ldm.models.autoencoder.AutoencoderKL 55 | params: 56 | embed_dim: 4 57 | monitor: val/rec_loss 58 | ddconfig: 59 | double_z: true 60 | z_channels: 4 61 | resolution: 256 62 | in_channels: 3 63 | out_ch: 3 64 | ch: 128 65 | ch_mult: 66 | - 1 67 | - 2 68 | - 4 69 | - 4 70 | num_res_blocks: 2 71 | attn_resolutions: [] 72 | dropout: 0.0 73 | lossconfig: 74 | target: torch.nn.Identity 75 | 76 | cond_stage_config: 77 | target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder 78 | params: 79 | freeze: True 80 | layer: "penultimate" 81 | 82 | data: 83 | target: main.DataModuleFromConfig 84 | params: 85 | batch_size: 128 86 | num_workers: 1 87 | wrap: false 88 | validation: 89 | target: edit_dataset.EditDataset 90 | params: 91 | path: data/clip-filtered-dataset 92 | cache_dir: data/ 93 | cache_name: data_10k 94 | split: val 95 | min_text_sim: 0.2 96 | min_image_sim: 0.75 97 | min_direction_sim: 0.2 98 | max_samples_per_prompt: 1 99 | min_resize_res: 512 100 | max_resize_res: 512 101 | crop_res: 512 102 | output_as_edit: False 103 | real_input: True 104 | -------------------------------------------------------------------------------- /configs/train_v21_base.yaml: -------------------------------------------------------------------------------- 1 | # File modified by authors of InstructPix2Pix from original (https://github.com/CompVis/stable-diffusion). 2 | # See more details in LICENSE. 
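# Note: this training config mirrors configs/generate_v21_base.yaml but differs in a
# few places: load_ema is false (use_ema stays true), training uses 256x256 crops
# with flip_prob 0.5, the data section combines the InstructPix2Pix data with the
# HIVE training parts under ./data/training, and batching uses batch_size: 32 with
# accumulate_grad_batches: 4.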
3 | 4 | model: 5 | base_learning_rate: 1.0e-04 6 | target: ldm.models.diffusion.ddpm_edit_v21.LatentDiffusion 7 | params: 8 | ckpt_path: ./checkpoints/v2-1_512-ema-pruned.ckpt 9 | linear_start: 0.00085 10 | linear_end: 0.0120 11 | num_timesteps_cond: 1 12 | log_every_t: 200 13 | timesteps: 1000 14 | first_stage_key: edited 15 | cond_stage_key: edit 16 | image_size: 32 17 | channels: 4 18 | cond_stage_trainable: false # Note: different from the one we trained before 19 | conditioning_key: hybrid 20 | monitor: val/loss_simple_ema 21 | scale_factor: 0.18215 22 | use_ema: true 23 | load_ema: false 24 | 25 | scheduler_config: # 10000 warmup steps 26 | target: ldm.lr_scheduler.LambdaLinearScheduler 27 | params: 28 | warm_up_steps: [ 0 ] 29 | cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases 30 | f_start: [ 1.e-6 ] 31 | f_max: [ 1. ] 32 | f_min: [ 1. ] 33 | 34 | unet_config: 35 | target: ldm.modules.diffusionmodules.openaimodel_v21.UNetModel 36 | params: 37 | image_size: 32 # unused 38 | in_channels: 8 39 | out_channels: 4 40 | model_channels: 320 41 | attention_resolutions: [ 4, 2, 1 ] 42 | num_res_blocks: 2 43 | channel_mult: [ 1, 2, 4, 4 ] 44 | num_heads: 8 45 | num_head_channels: 64 # need to fix for flash-attn 46 | use_spatial_transformer: True 47 | use_linear_in_transformer: True 48 | transformer_depth: 1 49 | context_dim: 1024 50 | use_checkpoint: True 51 | legacy: False 52 | 53 | first_stage_config: 54 | target: ldm.models.autoencoder.AutoencoderKL 55 | params: 56 | embed_dim: 4 57 | monitor: val/rec_loss 58 | ddconfig: 59 | double_z: true 60 | z_channels: 4 61 | resolution: 256 62 | in_channels: 3 63 | out_ch: 3 64 | ch: 128 65 | ch_mult: 66 | - 1 67 | - 2 68 | - 4 69 | - 4 70 | num_res_blocks: 2 71 | attn_resolutions: [] 72 | dropout: 0.0 73 | lossconfig: 74 | target: torch.nn.Identity 75 | 76 | cond_stage_config: 77 | target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder 78 | params: 79 | freeze: True 80 | layer: "penultimate" 81 | 82 | 83 | data: 84 | target: main.DataModuleFromConfig 85 | params: 86 | batch_size: 32 87 | num_workers: 2 88 | train: 89 | target: edit_dataset.EditDataset 90 | params: 91 | path_instructpix2pix: ./data/training/instructpix2pix 92 | path_hive_0: ./data/training 93 | path_hive_1: ./data/training/part_0_blip_prompt_new 94 | path_hive_2: ./data/training/part_1_blip_prompt_new 95 | split: train 96 | min_resize_res: 256 97 | max_resize_res: 256 98 | crop_res: 256 99 | flip_prob: 0.5 100 | validation: 101 | target: edit_dataset.EditDataset 102 | params: 103 | path_instructpix2pix: ./data/training/instructpix2pix 104 | path_hive_0: ./data/training 105 | path_hive_1: ./data/training/part_0_blip_prompt_new 106 | path_hive_2: ./data/training/part_1_blip_prompt_new 107 | split: val 108 | min_resize_res: 256 109 | max_resize_res: 256 110 | crop_res: 256 111 | 112 | lightning: 113 | callbacks: 114 | image_logger: 115 | target: main.ImageLogger 116 | params: 117 | batch_frequency: 2000 118 | max_images: 2 119 | increase_log_steps: False 120 | 121 | trainer: 122 | max_epochs: 3000 123 | benchmark: True 124 | accumulate_grad_batches: 4 125 | check_val_every_n_epoch: 4 126 | -------------------------------------------------------------------------------- /edit_cli.py: -------------------------------------------------------------------------------- 1 | ''' 2 | * Copyright (c) 2023 Salesforce, Inc. 3 | * All rights reserved. 
4 | * SPDX-License-Identifier: Apache License 2.0 5 | * For full license text, see LICENSE.txt file in the repo root or http://www.apache.org/licenses/ 6 | * By Shu Zhang 7 | * Modified from InstructPix2Pix repo: https://github.com/timothybrooks/instruct-pix2pix 8 | * Copyright (c) 2023 Timothy Brooks, Aleksander Holynski, Alexei A. Efros. All rights reserved. 9 | ''' 10 | 11 | from __future__ import annotations 12 | 13 | import math 14 | import random 15 | import sys 16 | from argparse import ArgumentParser 17 | 18 | import einops 19 | import k_diffusion as K 20 | import numpy as np 21 | import torch 22 | import torch.nn as nn 23 | from einops import rearrange 24 | from omegaconf import OmegaConf 25 | from PIL import Image, ImageOps 26 | from torch import autocast 27 | 28 | sys.path.append("./stable_diffusion") 29 | 30 | from stable_diffusion.ldm.util import instantiate_from_config 31 | 32 | 33 | class CFGDenoiser(nn.Module): 34 | def __init__(self, model): 35 | super().__init__() 36 | self.inner_model = model 37 | 38 | def forward(self, z, sigma, cond, uncond, text_cfg_scale, image_cfg_scale): 39 | cfg_z = einops.repeat(z, "1 ... -> n ...", n=3) 40 | cfg_sigma = einops.repeat(sigma, "1 ... -> n ...", n=3) 41 | cfg_cond = { 42 | "c_crossattn": [torch.cat([cond["c_crossattn"][0], uncond["c_crossattn"][0], uncond["c_crossattn"][0]])], 43 | "c_concat": [torch.cat([cond["c_concat"][0], cond["c_concat"][0], uncond["c_concat"][0]])], 44 | } 45 | out_cond, out_img_cond, out_uncond = self.inner_model(cfg_z, cfg_sigma, cond=cfg_cond).chunk(3) 46 | return out_uncond + text_cfg_scale * (out_cond - out_img_cond) + image_cfg_scale * (out_img_cond - out_uncond) 47 | 48 | 49 | def load_model_from_config(config, ckpt, vae_ckpt=None, verbose=False): 50 | print(f"Loading model from {ckpt}") 51 | pl_sd = torch.load(ckpt, map_location="cpu") 52 | if "global_step" in pl_sd: 53 | print(f"Global Step: {pl_sd['global_step']}") 54 | sd = pl_sd["state_dict"] 55 | if vae_ckpt is not None: 56 | print(f"Loading VAE from {vae_ckpt}") 57 | vae_sd = torch.load(vae_ckpt, map_location="cpu")["state_dict"] 58 | sd = { 59 | k: vae_sd[k[len("first_stage_model.") :]] if k.startswith("first_stage_model.") else v 60 | for k, v in sd.items() 61 | } 62 | model = instantiate_from_config(config.model) 63 | m, u = model.load_state_dict(sd, strict=False) 64 | if len(m) > 0 and verbose: 65 | print("missing keys:") 66 | print(m) 67 | if len(u) > 0 and verbose: 68 | print("unexpected keys:") 69 | print(u) 70 | return model 71 | 72 | 73 | def main(): 74 | parser = ArgumentParser() 75 | parser.add_argument("--resolution", default=512, type=int) 76 | parser.add_argument("--steps", default=100, type=int) 77 | parser.add_argument("--config", default="configs/generate.yaml", type=str) 78 | parser.add_argument("--ckpt", default="checkpoints/hive.ckpt", type=str) 79 | parser.add_argument("--vae-ckpt", default=None, type=str) 80 | parser.add_argument("--input", required=True, type=str) 81 | parser.add_argument("--output", required=True, type=str) 82 | parser.add_argument("--edit", required=True, type=str) 83 | parser.add_argument("--cfg-text", default=7.5, type=float) 84 | parser.add_argument("--cfg-image", default=1.5, type=float) 85 | parser.add_argument("--seed", type=int) 86 | args = parser.parse_args() 87 | 88 | config = OmegaConf.load(args.config) 89 | model = load_model_from_config(config, args.ckpt, args.vae_ckpt) 90 | model.eval().cuda() 91 | model_wrap = K.external.CompVisDenoiser(model) 92 | model_wrap_cfg = CFGDenoiser(model_wrap) 93 
| null_token = model.get_learned_conditioning([""]) 94 | 95 | seed = random.randint(0, 100000) if args.seed is None else args.seed 96 | input_image = Image.open(args.input).convert("RGB") 97 | width, height = input_image.size 98 | factor = args.resolution / max(width, height) 99 | factor = math.ceil(min(width, height) * factor / 64) * 64 / min(width, height) 100 | width = int((width * factor) // 64) * 64 101 | height = int((height * factor) // 64) * 64 102 | input_image = ImageOps.fit(input_image, (width, height), method=Image.Resampling.LANCZOS) 103 | 104 | if args.edit == "": 105 | input_image.save(args.output) 106 | return 107 | 108 | with torch.no_grad(), autocast("cuda"), model.ema_scope(): 109 | cond = {} 110 | cond["c_crossattn"] = [model.get_learned_conditioning([args.edit])] 111 | input_image = 2 * torch.tensor(np.array(input_image)).float() / 255 - 1 112 | input_image = rearrange(input_image, "h w c -> 1 c h w").to(model.device) 113 | cond["c_concat"] = [model.encode_first_stage(input_image).mode()] 114 | 115 | uncond = {} 116 | uncond["c_crossattn"] = [null_token] 117 | uncond["c_concat"] = [torch.zeros_like(cond["c_concat"][0])] 118 | 119 | sigmas = model_wrap.get_sigmas(args.steps) 120 | 121 | extra_args = { 122 | "cond": cond, 123 | "uncond": uncond, 124 | "text_cfg_scale": args.cfg_text, 125 | "image_cfg_scale": args.cfg_image, 126 | } 127 | torch.manual_seed(seed) 128 | z = torch.randn_like(cond["c_concat"][0]) * sigmas[0] 129 | z = K.sampling.sample_euler_ancestral(model_wrap_cfg, z, sigmas, extra_args=extra_args) 130 | x = model.decode_first_stage(z) 131 | x = torch.clamp((x + 1.0) / 2.0, min=0.0, max=1.0) 132 | x = 255.0 * rearrange(x, "1 c h w -> h w c") 133 | edited_image = Image.fromarray(x.type(torch.uint8).cpu().numpy()) 134 | edited_image.save(args.output) 135 | 136 | 137 | if __name__ == "__main__": 138 | main() 139 | -------------------------------------------------------------------------------- /edit_cli_batch.py: -------------------------------------------------------------------------------- 1 | ''' 2 | * Copyright (c) 2023 Salesforce, Inc. 3 | * All rights reserved. 4 | * SPDX-License-Identifier: Apache License 2.0 5 | * For full license text, see LICENSE.txt file in the repo root or http://www.apache.org/licenses/ 6 | * By Shu Zhang 7 | * Modified from InstructPix2Pix repo: https://github.com/timothybrooks/instruct-pix2pix 8 | * Copyright (c) 2023 Timothy Brooks, Aleksander Holynski, Alexei A. Efros. All rights reserved. 9 | ''' 10 | 11 | from __future__ import annotations 12 | 13 | import math 14 | import os 15 | import random 16 | import sys 17 | from argparse import ArgumentParser 18 | 19 | import einops 20 | import k_diffusion as K 21 | import numpy as np 22 | import torch 23 | import torch.nn as nn 24 | import glob 25 | import re 26 | import jsonlines 27 | from einops import rearrange 28 | from omegaconf import OmegaConf 29 | from PIL import Image, ImageOps 30 | from torch import autocast 31 | 32 | sys.path.append("./stable_diffusion") 33 | 34 | from stable_diffusion.ldm.util import instantiate_from_config 35 | 36 | 37 | class CFGDenoiser(nn.Module): 38 | def __init__(self, model): 39 | super().__init__() 40 | self.inner_model = model 41 | 42 | def forward(self, z, sigma, cond, uncond, text_cfg_scale, image_cfg_scale): 43 | cfg_z = einops.repeat(z, "1 ... -> n ...", n=3) 44 | cfg_sigma = einops.repeat(sigma, "1 ... 
-> n ...", n=3) 45 | cfg_cond = { 46 | "c_crossattn": [torch.cat([cond["c_crossattn"][0], uncond["c_crossattn"][0], uncond["c_crossattn"][0]])], 47 | "c_concat": [torch.cat([cond["c_concat"][0], cond["c_concat"][0], uncond["c_concat"][0]])], 48 | } 49 | out_cond, out_img_cond, out_uncond = self.inner_model(cfg_z, cfg_sigma, cond=cfg_cond).chunk(3) 50 | return out_uncond + text_cfg_scale * (out_cond - out_img_cond) + image_cfg_scale * (out_img_cond - out_uncond) 51 | 52 | 53 | def load_model_from_config(config, ckpt, vae_ckpt=None, verbose=False): 54 | print(f"Loading model from {ckpt}") 55 | pl_sd = torch.load(ckpt, map_location="cpu") 56 | if "global_step" in pl_sd: 57 | print(f"Global Step: {pl_sd['global_step']}") 58 | sd = pl_sd["state_dict"] 59 | if vae_ckpt is not None: 60 | print(f"Loading VAE from {vae_ckpt}") 61 | vae_sd = torch.load(vae_ckpt, map_location="cpu")["state_dict"] 62 | sd = { 63 | k: vae_sd[k[len("first_stage_model.") :]] if k.startswith("first_stage_model.") else v 64 | for k, v in sd.items() 65 | } 66 | model = instantiate_from_config(config.model) 67 | m, u = model.load_state_dict(sd, strict=False) 68 | if len(m) > 0 and verbose: 69 | print("missing keys:") 70 | print(m) 71 | if len(u) > 0 and verbose: 72 | print("unexpected keys:") 73 | print(u) 74 | return model 75 | 76 | 77 | def main(): 78 | parser = ArgumentParser() 79 | parser.add_argument("--resolution", default=512, type=int) 80 | parser.add_argument("--steps", default=100, type=int) 81 | parser.add_argument("--config", default="configs/generate.yaml", type=str) 82 | parser.add_argument("--ckpt", default="checkpoints/hive.ckpt", type=str) 83 | parser.add_argument("--vae-ckpt", default=None, type=str) 84 | parser.add_argument("--output_dir", required=True, type=str) 85 | parser.add_argument("--jsonl_file", required=True, type=str) 86 | parser.add_argument("--image_dir", required=True, default="data/evaluation/", type=str) 87 | parser.add_argument("--cfg-text", default=7.5, type=float) 88 | parser.add_argument("--cfg-image", default=1.5, type=float) 89 | parser.add_argument("--seed", default=100, type=int) 90 | args = parser.parse_args() 91 | 92 | config = OmegaConf.load(args.config) 93 | model = load_model_from_config(config, args.ckpt, args.vae_ckpt) 94 | model.eval().cuda() 95 | model_wrap = K.external.CompVisDenoiser(model) 96 | model_wrap_cfg = CFGDenoiser(model_wrap) 97 | null_token = model.get_learned_conditioning([""]) 98 | 99 | seed = random.randint(0, 100000) if args.seed is None else args.seed 100 | output_dir = args.output_dir 101 | if not os.path.exists(output_dir): 102 | os.makedirs(output_dir) 103 | image_dir = args.image_dir 104 | instructions = [] 105 | image_paths = [] 106 | with jsonlines.open(args.jsonl_file) as reader: 107 | for ll in reader: 108 | instructions.append(ll["instruction"]) 109 | image_paths.append(os.path.join(image_dir, ll['source_img'])) 110 | 111 | for i, instruction in enumerate(instructions): 112 | output_image = os.path.join(output_dir, f'instruct_{i}.png') 113 | input_image = Image.open(image_paths[i]).convert("RGB") 114 | width, height = input_image.size 115 | factor = args.resolution / max(width, height) 116 | factor = math.ceil(min(width, height) * factor / 64) * 64 / min(width, height) 117 | width = int((width * factor) // 64) * 64 118 | height = int((height * factor) // 64) * 64 119 | input_image = ImageOps.fit(input_image, (width, height), method=Image.Resampling.LANCZOS) 120 | edit = instructions[i] 121 | 122 | with torch.no_grad(), autocast("cuda"), 
model.ema_scope(): 123 | cond = {} 124 | cond["c_crossattn"] = [model.get_learned_conditioning([edit])] 125 | input_image = 2 * torch.tensor(np.array(input_image)).float() / 255 - 1 126 | input_image = rearrange(input_image, "h w c -> 1 c h w").to(model.device) 127 | cond["c_concat"] = [model.encode_first_stage(input_image).mode()] 128 | 129 | uncond = {} 130 | uncond["c_crossattn"] = [null_token] 131 | uncond["c_concat"] = [torch.zeros_like(cond["c_concat"][0])] 132 | 133 | sigmas = model_wrap.get_sigmas(args.steps) 134 | 135 | extra_args = { 136 | "cond": cond, 137 | "uncond": uncond, 138 | "text_cfg_scale": args.cfg_text, 139 | "image_cfg_scale": args.cfg_image, 140 | } 141 | torch.manual_seed(seed) 142 | z = torch.randn_like(cond["c_concat"][0]) * sigmas[0] 143 | z = K.sampling.sample_euler_ancestral(model_wrap_cfg, z, sigmas, extra_args=extra_args) 144 | x = model.decode_first_stage(z) 145 | x = torch.clamp((x + 1.0) / 2.0, min=0.0, max=1.0) 146 | x = 255.0 * rearrange(x, "1 c h w -> h w c") 147 | edited_image = Image.fromarray(x.type(torch.uint8).cpu().numpy()) 148 | edited_image.save(output_image) 149 | 150 | 151 | if __name__ == "__main__": 152 | main() 153 | -------------------------------------------------------------------------------- /edit_cli_batch_rw_label.py: -------------------------------------------------------------------------------- 1 | ''' 2 | * Copyright (c) 2023 Salesforce, Inc. 3 | * All rights reserved. 4 | * SPDX-License-Identifier: Apache License 2.0 5 | * For full license text, see LICENSE.txt file in the repo root or http://www.apache.org/licenses/ 6 | * By Shu Zhang 7 | * Modified from InstructPix2Pix repo: https://github.com/timothybrooks/instruct-pix2pix 8 | * Copyright (c) 2023 Timothy Brooks, Aleksander Holynski, Alexei A. Efros. All rights reserved. 9 | ''' 10 | 11 | from __future__ import annotations 12 | 13 | import math 14 | import os 15 | import random 16 | import sys 17 | from argparse import ArgumentParser 18 | 19 | import einops 20 | import k_diffusion as K 21 | import numpy as np 22 | import torch 23 | import torch.nn as nn 24 | import glob 25 | import re 26 | import jsonlines 27 | from einops import rearrange 28 | from omegaconf import OmegaConf 29 | from PIL import Image, ImageOps 30 | from torch import autocast 31 | 32 | sys.path.append("./stable_diffusion") 33 | 34 | from stable_diffusion.ldm.util import instantiate_from_config 35 | 36 | 37 | class CFGDenoiser(nn.Module): 38 | def __init__(self, model): 39 | super().__init__() 40 | self.inner_model = model 41 | 42 | def forward(self, z, sigma, cond, uncond, text_cfg_scale, image_cfg_scale): 43 | cfg_z = einops.repeat(z, "1 ... -> n ...", n=3) 44 | cfg_sigma = einops.repeat(sigma, "1 ... 
-> n ...", n=3) 45 | cfg_cond = { 46 | "c_crossattn": [torch.cat([cond["c_crossattn"][0], uncond["c_crossattn"][0], uncond["c_crossattn"][0]])], 47 | "c_concat": [torch.cat([cond["c_concat"][0], cond["c_concat"][0], uncond["c_concat"][0]])], 48 | } 49 | out_cond, out_img_cond, out_uncond = self.inner_model(cfg_z, cfg_sigma, cond=cfg_cond).chunk(3) 50 | return out_uncond + text_cfg_scale * (out_cond - out_img_cond) + image_cfg_scale * (out_img_cond - out_uncond) 51 | 52 | 53 | def load_model_from_config(config, ckpt, vae_ckpt=None, verbose=False): 54 | print(f"Loading model from {ckpt}") 55 | pl_sd = torch.load(ckpt, map_location="cpu") 56 | if "global_step" in pl_sd: 57 | print(f"Global Step: {pl_sd['global_step']}") 58 | sd = pl_sd["state_dict"] 59 | if vae_ckpt is not None: 60 | print(f"Loading VAE from {vae_ckpt}") 61 | vae_sd = torch.load(vae_ckpt, map_location="cpu")["state_dict"] 62 | sd = { 63 | k: vae_sd[k[len("first_stage_model.") :]] if k.startswith("first_stage_model.") else v 64 | for k, v in sd.items() 65 | } 66 | model = instantiate_from_config(config.model) 67 | m, u = model.load_state_dict(sd, strict=False) 68 | if len(m) > 0 and verbose: 69 | print("missing keys:") 70 | print(m) 71 | if len(u) > 0 and verbose: 72 | print("unexpected keys:") 73 | print(u) 74 | return model 75 | 76 | 77 | def main(): 78 | parser = ArgumentParser() 79 | parser.add_argument("--resolution", default=512, type=int) 80 | parser.add_argument("--steps", default=100, type=int) 81 | parser.add_argument("--config", default="configs/generate.yaml", type=str) 82 | parser.add_argument("--ckpt", default="checkpoints/hive_rw_label.ckpt", type=str) 83 | parser.add_argument("--vae-ckpt", default=None, type=str) 84 | parser.add_argument("--output_dir", required=True, type=str) 85 | parser.add_argument("--jsonl_file", required=True, type=str) 86 | parser.add_argument("--image_dir", required=True, default="data/evaluation/", type=str) 87 | parser.add_argument("--cfg-text", default=7.5, type=float) 88 | parser.add_argument("--cfg-image", default=1.5, type=float) 89 | parser.add_argument("--seed", default=100, type=int) 90 | args = parser.parse_args() 91 | 92 | config = OmegaConf.load(args.config) 93 | model = load_model_from_config(config, args.ckpt, args.vae_ckpt) 94 | model.eval().cuda() 95 | model_wrap = K.external.CompVisDenoiser(model) 96 | model_wrap_cfg = CFGDenoiser(model_wrap) 97 | null_token = model.get_learned_conditioning([""]) 98 | 99 | seed = random.randint(0, 100000) if args.seed is None else args.seed 100 | output_dir = args.output_dir 101 | if not os.path.exists(output_dir): 102 | os.makedirs(output_dir) 103 | image_dir = args.image_dir 104 | instructions = [] 105 | image_paths = [] 106 | with jsonlines.open(args.jsonl_file) as reader: 107 | for ll in reader: 108 | instructions.append(ll["instruction"]) 109 | image_paths.append(os.path.join(image_dir, ll['source_img'])) 110 | 111 | for i, instruction in enumerate(instructions): 112 | output_image = os.path.join(output_dir, f'instruct_{i}.png') 113 | input_image = Image.open(image_paths[i]).convert("RGB") 114 | width, height = input_image.size 115 | factor = args.resolution / max(width, height) 116 | factor = math.ceil(min(width, height) * factor / 64) * 64 / min(width, height) 117 | width = int((width * factor) // 64) * 64 118 | height = int((height * factor) // 64) * 64 119 | input_image = ImageOps.fit(input_image, (width, height), method=Image.Resampling.LANCZOS) 120 | edit = instructions[i] 121 | edit = edit + ', ' + f'image quality is five 
out of five' 122 | 123 | with torch.no_grad(), autocast("cuda"), model.ema_scope(): 124 | cond = {} 125 | cond["c_crossattn"] = [model.get_learned_conditioning([edit])] 126 | input_image = 2 * torch.tensor(np.array(input_image)).float() / 255 - 1 127 | input_image = rearrange(input_image, "h w c -> 1 c h w").to(model.device) 128 | cond["c_concat"] = [model.encode_first_stage(input_image).mode()] 129 | 130 | uncond = {} 131 | uncond["c_crossattn"] = [null_token] 132 | uncond["c_concat"] = [torch.zeros_like(cond["c_concat"][0])] 133 | 134 | sigmas = model_wrap.get_sigmas(args.steps) 135 | 136 | extra_args = { 137 | "cond": cond, 138 | "uncond": uncond, 139 | "text_cfg_scale": args.cfg_text, 140 | "image_cfg_scale": args.cfg_image, 141 | } 142 | torch.manual_seed(seed) 143 | z = torch.randn_like(cond["c_concat"][0]) * sigmas[0] 144 | z = K.sampling.sample_euler_ancestral(model_wrap_cfg, z, sigmas, extra_args=extra_args) 145 | x = model.decode_first_stage(z) 146 | x = torch.clamp((x + 1.0) / 2.0, min=0.0, max=1.0) 147 | x = 255.0 * rearrange(x, "1 c h w -> h w c") 148 | edited_image = Image.fromarray(x.type(torch.uint8).cpu().numpy()) 149 | edited_image.save(output_image) 150 | 151 | 152 | if __name__ == "__main__": 153 | main() 154 | -------------------------------------------------------------------------------- /edit_cli_rw_label.py: -------------------------------------------------------------------------------- 1 | ''' 2 | * Copyright (c) 2023 Salesforce, Inc. 3 | * All rights reserved. 4 | * SPDX-License-Identifier: Apache License 2.0 5 | * For full license text, see LICENSE.txt file in the repo root or http://www.apache.org/licenses/ 6 | * By Shu Zhang 7 | * Modified from InstructPix2Pix repo: https://github.com/timothybrooks/instruct-pix2pix 8 | * Copyright (c) 2023 Timothy Brooks, Aleksander Holynski, Alexei A. Efros. All rights reserved. 9 | ''' 10 | 11 | from __future__ import annotations 12 | 13 | import math 14 | import random 15 | import sys 16 | from argparse import ArgumentParser 17 | 18 | import einops 19 | import k_diffusion as K 20 | import numpy as np 21 | import torch 22 | import torch.nn as nn 23 | from einops import rearrange 24 | from omegaconf import OmegaConf 25 | from PIL import Image, ImageOps 26 | from torch import autocast 27 | 28 | sys.path.append("./stable_diffusion") 29 | 30 | from stable_diffusion.ldm.util import instantiate_from_config 31 | 32 | 33 | class CFGDenoiser(nn.Module): 34 | def __init__(self, model): 35 | super().__init__() 36 | self.inner_model = model 37 | 38 | def forward(self, z, sigma, cond, uncond, text_cfg_scale, image_cfg_scale): 39 | cfg_z = einops.repeat(z, "1 ... -> n ...", n=3) 40 | cfg_sigma = einops.repeat(sigma, "1 ... 
-> n ...", n=3) 41 | cfg_cond = { 42 | "c_crossattn": [torch.cat([cond["c_crossattn"][0], uncond["c_crossattn"][0], uncond["c_crossattn"][0]])], 43 | "c_concat": [torch.cat([cond["c_concat"][0], cond["c_concat"][0], uncond["c_concat"][0]])], 44 | } 45 | out_cond, out_img_cond, out_uncond = self.inner_model(cfg_z, cfg_sigma, cond=cfg_cond).chunk(3) 46 | return out_uncond + text_cfg_scale * (out_cond - out_img_cond) + image_cfg_scale * (out_img_cond - out_uncond) 47 | 48 | 49 | def load_model_from_config(config, ckpt, vae_ckpt=None, verbose=False): 50 | print(f"Loading model from {ckpt}") 51 | pl_sd = torch.load(ckpt, map_location="cpu") 52 | if "global_step" in pl_sd: 53 | print(f"Global Step: {pl_sd['global_step']}") 54 | sd = pl_sd["state_dict"] 55 | if vae_ckpt is not None: 56 | print(f"Loading VAE from {vae_ckpt}") 57 | vae_sd = torch.load(vae_ckpt, map_location="cpu")["state_dict"] 58 | sd = { 59 | k: vae_sd[k[len("first_stage_model.") :]] if k.startswith("first_stage_model.") else v 60 | for k, v in sd.items() 61 | } 62 | model = instantiate_from_config(config.model) 63 | m, u = model.load_state_dict(sd, strict=False) 64 | if len(m) > 0 and verbose: 65 | print("missing keys:") 66 | print(m) 67 | if len(u) > 0 and verbose: 68 | print("unexpected keys:") 69 | print(u) 70 | return model 71 | 72 | 73 | def main(): 74 | parser = ArgumentParser() 75 | parser.add_argument("--resolution", default=512, type=int) 76 | parser.add_argument("--steps", default=100, type=int) 77 | parser.add_argument("--config", default="configs/generate.yaml", type=str) 78 | parser.add_argument("--ckpt", default="checkpoints/hive_rw_label.ckpt", type=str) 79 | parser.add_argument("--vae-ckpt", default=None, type=str) 80 | parser.add_argument("--input", required=True, type=str) 81 | parser.add_argument("--output", required=True, type=str) 82 | parser.add_argument("--edit", required=True, type=str) 83 | parser.add_argument("--cfg-text", default=7.5, type=float) 84 | parser.add_argument("--cfg-image", default=1.5, type=float) 85 | parser.add_argument("--seed", type=int) 86 | args = parser.parse_args() 87 | 88 | config = OmegaConf.load(args.config) 89 | model = load_model_from_config(config, args.ckpt, args.vae_ckpt) 90 | model.eval().cuda() 91 | model_wrap = K.external.CompVisDenoiser(model) 92 | model_wrap_cfg = CFGDenoiser(model_wrap) 93 | null_token = model.get_learned_conditioning([""]) 94 | 95 | seed = random.randint(0, 100000) if args.seed is None else args.seed 96 | input_image = Image.open(args.input).convert("RGB") 97 | width, height = input_image.size 98 | factor = args.resolution / max(width, height) 99 | factor = math.ceil(min(width, height) * factor / 64) * 64 / min(width, height) 100 | width = int((width * factor) // 64) * 64 101 | height = int((height * factor) // 64) * 64 102 | input_image = ImageOps.fit(input_image, (width, height), method=Image.Resampling.LANCZOS) 103 | 104 | if args.edit == "": 105 | input_image.save(args.output) 106 | return 107 | 108 | with torch.no_grad(), autocast("cuda"), model.ema_scope(): 109 | cond = {} 110 | edit = args.edit + ', ' + f'image quality is five out of five' 111 | cond["c_crossattn"] = [model.get_learned_conditioning([edit])] 112 | input_image = 2 * torch.tensor(np.array(input_image)).float() / 255 - 1 113 | input_image = rearrange(input_image, "h w c -> 1 c h w").to(model.device) 114 | cond["c_concat"] = [model.encode_first_stage(input_image).mode()] 115 | 116 | uncond = {} 117 | uncond["c_crossattn"] = [null_token] 118 | uncond["c_concat"] = 
[torch.zeros_like(cond["c_concat"][0])] 119 | 120 | sigmas = model_wrap.get_sigmas(args.steps) 121 | 122 | extra_args = { 123 | "cond": cond, 124 | "uncond": uncond, 125 | "text_cfg_scale": args.cfg_text, 126 | "image_cfg_scale": args.cfg_image, 127 | } 128 | torch.manual_seed(seed) 129 | z = torch.randn_like(cond["c_concat"][0]) * sigmas[0] 130 | z = K.sampling.sample_euler_ancestral(model_wrap_cfg, z, sigmas, extra_args=extra_args) 131 | x = model.decode_first_stage(z) 132 | x = torch.clamp((x + 1.0) / 2.0, min=0.0, max=1.0) 133 | x = 255.0 * rearrange(x, "1 c h w -> h w c") 134 | edited_image = Image.fromarray(x.type(torch.uint8).cpu().numpy()) 135 | edited_image.save(args.output) 136 | 137 | 138 | if __name__ == "__main__": 139 | main() 140 | -------------------------------------------------------------------------------- /edit_dataset.py: -------------------------------------------------------------------------------- 1 | ''' 2 | * Copyright (c) 2023 Salesforce, Inc. 3 | * All rights reserved. 4 | * SPDX-License-Identifier: Apache License 2.0 5 | * For full license text, see LICENSE.txt file in the repo root or http://www.apache.org/licenses/ 6 | * By Shu Zhang 7 | * Modified from InstructPix2Pix repo: https://github.com/timothybrooks/instruct-pix2pix 8 | * Copyright (c) 2023 Timothy Brooks, Aleksander Holynski, Alexei A. Efros. All rights reserved. 9 | ''' 10 | 11 | from __future__ import annotations 12 | import json 13 | import math 14 | from pathlib import Path 15 | from typing import Any 16 | 17 | import numpy as np 18 | import torch 19 | import torchvision 20 | from einops import rearrange 21 | from PIL import Image 22 | from torch.utils.data import Dataset 23 | import jsonlines 24 | from collections import deque 25 | 26 | 27 | class EditDataset(Dataset): 28 | def __init__( 29 | self, 30 | path_instructpix2pix: str, 31 | path_hive_0: str, 32 | path_hive_1: str, 33 | path_hive_2: str, 34 | split: str = "train", 35 | splits: tuple[float, float, float] = (0.9, 0.05, 0.05), 36 | min_resize_res: int = 256, 37 | max_resize_res: int = 256, 38 | crop_res: int = 256, 39 | flip_prob: float = 0.0, 40 | ): 41 | assert split in ("train", "val", "test") 42 | assert sum(splits) == 1 43 | self.path_instructpix2pix = path_instructpix2pix 44 | self.path_hive_0 = path_hive_0 45 | self.path_hive_1 = path_hive_1 46 | self.path_hive_2 = path_hive_2 47 | self.min_resize_res = min_resize_res 48 | self.max_resize_res = max_resize_res 49 | self.crop_res = crop_res 50 | self.flip_prob = flip_prob 51 | self.seeds = [] 52 | self.instructions = [] 53 | self.source_imgs = [] 54 | self.edited_imgs = [] 55 | # load instructpix2pix dataset 56 | with open(Path(self.path_instructpix2pix, "seeds.json")) as f: 57 | seeds = json.load(f) 58 | split_0, split_1 = { 59 | "train": (0.0, splits[0]), 60 | "val": (splits[0], splits[0] + splits[1]), 61 | "test": (splits[0] + splits[1], 1.0), 62 | }[split] 63 | 64 | idx_0 = math.floor(split_0 * len(seeds)) 65 | idx_1 = math.floor(split_1 * len(seeds)) 66 | seeds = seeds[idx_0:idx_1] 67 | 68 | for seed in seeds: 69 | seed = deque(seed) 70 | seed.appendleft('') 71 | seed.appendleft('instructpix2pix') 72 | self.seeds.append(list(seed)) 73 | 74 | 75 | # load HIVE dataset first part 76 | 77 | cnt = 0 78 | with jsonlines.open(Path(self.path_hive_0, "training_cycle.jsonl")) as reader: 79 | for ll in reader: 80 | self.instructions.append(ll['instruction']) 81 | self.source_imgs.append(ll['source_img']) 82 | self.edited_imgs.append(ll['edited_img']) 83 | self.seeds.append(['hive_0', '', 
'', [cnt]]) 84 | cnt += 1 85 | 86 | # load HIVE dataset second part 87 | with open(Path(self.path_hive_1, "seeds.json")) as f: 88 | seeds = json.load(f) 89 | for seed in seeds: 90 | seed = deque(seed) 91 | seed.appendleft('hive_1') 92 | self.seeds.append(list(seed)) 93 | # load HIVE dataset third part 94 | with open(Path(self.path_hive_2, "seeds.json")) as f: 95 | seeds = json.load(f) 96 | for seed in seeds: 97 | seed = deque(seed) 98 | seed.appendleft('hive_2') 99 | self.seeds.append(list(seed)) 100 | 101 | def __len__(self) -> int: 102 | return len(self.seeds) 103 | 104 | def __getitem__(self, i: int) -> dict[str, Any]: 105 | 106 | name_0, name_1, name_2, seeds = self.seeds[i] 107 | if name_0 == 'instructpix2pix': 108 | propt_dir = Path(self.path_instructpix2pix, name_2) 109 | seed = seeds[torch.randint(0, len(seeds), ()).item()] 110 | with open(propt_dir.joinpath("prompt.json")) as fp: 111 | prompt = json.load(fp)["edit"] 112 | image_0 = Image.open(propt_dir.joinpath(f"{seed}_0.jpg")) 113 | image_1 = Image.open(propt_dir.joinpath(f"{seed}_1.jpg")) 114 | elif name_0 == 'hive_1': 115 | propt_dir = Path(self.path_hive_1, name_1, name_2) 116 | seed = seeds[torch.randint(0, len(seeds), ()).item()] 117 | with open(propt_dir.joinpath("prompt.json")) as fp: 118 | prompt = json.load(fp)["instruction"] 119 | image_0 = Image.open(propt_dir.joinpath(f"{seed}_0.jpg")) 120 | image_1 = Image.open(propt_dir.joinpath(f"{seed}_1.jpg")) 121 | elif name_0 == 'hive_2': 122 | propt_dir = Path(self.path_hive_2, name_1, name_2) 123 | seed = seeds[torch.randint(0, len(seeds), ()).item()] 124 | with open(propt_dir.joinpath("prompt.json")) as fp: 125 | prompt = json.load(fp)["instruction"] 126 | image_0 = Image.open(propt_dir.joinpath(f"{seed}_0.jpg")) 127 | image_1 = Image.open(propt_dir.joinpath(f"{seed}_1.jpg")) 128 | else: 129 | j = seeds[0] 130 | image_0 = Image.open(self.source_imgs[j]) 131 | image_1 = Image.open(self.edited_imgs[j]) 132 | prompt = self.instructions[j] 133 | 134 | reize_res = torch.randint(self.min_resize_res, self.max_resize_res + 1, ()).item() 135 | image_0 = image_0.resize((reize_res, reize_res), Image.Resampling.LANCZOS) 136 | image_1 = image_1.resize((reize_res, reize_res), Image.Resampling.LANCZOS) 137 | 138 | image_0 = rearrange(2 * torch.tensor(np.array(image_0)).float() / 255 - 1, "h w c -> c h w") 139 | image_1 = rearrange(2 * torch.tensor(np.array(image_1)).float() / 255 - 1, "h w c -> c h w") 140 | 141 | crop = torchvision.transforms.RandomCrop(self.crop_res) 142 | flip = torchvision.transforms.RandomHorizontalFlip(float(self.flip_prob)) 143 | image_0, image_1 = flip(crop(torch.cat((image_0, image_1)))).chunk(2) 144 | 145 | return dict(edited=image_1, edit=dict(c_concat=image_0, c_crossattn=prompt)) 146 | 147 | -------------------------------------------------------------------------------- /environment.yaml: -------------------------------------------------------------------------------- 1 | # File modified by authors of InstructPix2Pix from original (https://github.com/CompVis/stable-diffusion). 2 | # See more details in LICENSE. 
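# A typical setup flow, assuming a working conda installation (the environment
# name "hive" is taken from the `name:` field below):
#   conda env create -f environment.yaml
#   conda activate hive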
3 | 4 | name: hive 5 | channels: 6 | - pytorch 7 | - defaults 8 | dependencies: 9 | - python=3.8.5 10 | - pip=20.3 11 | - cudatoolkit=11.3 12 | - pytorch=1.11.0 13 | - torchvision=0.12.0 14 | - numpy=1.19.2 15 | - pip: 16 | - albumentations==0.4.3 17 | - datasets==2.8.0 18 | - diffusers 19 | - opencv-python==4.1.2.30 20 | - pudb==2019.2 21 | - invisible-watermark 22 | - imageio==2.9.0 23 | - imageio-ffmpeg==0.4.2 24 | - pytorch-lightning==1.4.2 25 | - omegaconf==2.1.1 26 | - test-tube>=0.7.5 27 | - streamlit>=0.73.1 28 | - einops==0.3.0 29 | - torch-fidelity==0.3.0 30 | - transformers==4.19.2 31 | - torchmetrics==0.6.0 32 | - kornia==0.6 33 | - open_clip_torch==2.0.2 34 | - -e git+https://github.com/CompVis/taming-transformers.git@master#egg=taming-transformers 35 | - -e git+https://github.com/openai/CLIP.git@main#egg=clip 36 | - openai 37 | - gradio 38 | - seaborn 39 | - jsonlines 40 | - git+https://github.com/crowsonkb/k-diffusion.git 41 | -------------------------------------------------------------------------------- /imgs/example1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/salesforce/HIVE/a14ba7f26779e7d3dd506a4f5ab203589187e789/imgs/example1.jpg -------------------------------------------------------------------------------- /imgs/example2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/salesforce/HIVE/a14ba7f26779e7d3dd506a4f5ab203589187e789/imgs/example2.jpg -------------------------------------------------------------------------------- /imgs/example3.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/salesforce/HIVE/a14ba7f26779e7d3dd506a4f5ab203589187e789/imgs/example3.jpg -------------------------------------------------------------------------------- /imgs/results.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/salesforce/HIVE/a14ba7f26779e7d3dd506a4f5ab203589187e789/imgs/results.png -------------------------------------------------------------------------------- /metrics/clip_similarity.py: -------------------------------------------------------------------------------- 1 | ''' 2 | * Copyright (c) 2023 Salesforce, Inc. 3 | * All rights reserved. 4 | * SPDX-License-Identifier: Apache License 2.0 5 | * For full license text, see LICENSE.txt file in the repo root or http://www.apache.org/licenses/ 6 | * By Shu Zhang 7 | * Redistributed from InstructPix2Pix repo: https://github.com/timothybrooks/instruct-pix2pix 8 | * Copyright (c) 2023 Timothy Brooks, Aleksander Holynski, Alexei A. Efros. All rights reserved. 
9 | ''' 10 | 11 | from __future__ import annotations 12 | 13 | import clip 14 | import torch 15 | import torch.nn as nn 16 | import torch.nn.functional as F 17 | from einops import rearrange 18 | 19 | 20 | class ClipSimilarity(nn.Module): 21 | def __init__(self, name: str = "ViT-L/14"): 22 | super().__init__() 23 | assert name in ("RN50", "RN101", "RN50x4", "RN50x16", "RN50x64", "ViT-B/32", "ViT-B/16", "ViT-L/14", "ViT-L/14@336px") # fmt: skip 24 | self.size = {"RN50x4": 288, "RN50x16": 384, "RN50x64": 448, "ViT-L/14@336px": 336}.get(name, 224) 25 | 26 | self.model, _ = clip.load(name, device="cpu", download_root="./") 27 | self.model.eval().requires_grad_(False) 28 | 29 | self.register_buffer("mean", torch.tensor((0.48145466, 0.4578275, 0.40821073))) 30 | self.register_buffer("std", torch.tensor((0.26862954, 0.26130258, 0.27577711))) 31 | 32 | def encode_text(self, text: list[str]) -> torch.Tensor: 33 | text = clip.tokenize(text, truncate=True).to(next(self.parameters()).device) 34 | text_features = self.model.encode_text(text) 35 | text_features = text_features / text_features.norm(dim=1, keepdim=True) 36 | return text_features 37 | 38 | def encode_image(self, image: torch.Tensor) -> torch.Tensor: # Input images in range [0, 1]. 39 | image = F.interpolate(image.float(), size=self.size, mode="bicubic", align_corners=False) 40 | image = image - rearrange(self.mean, "c -> 1 c 1 1") 41 | image = image / rearrange(self.std, "c -> 1 c 1 1") 42 | image_features = self.model.encode_image(image) 43 | image_features = image_features / image_features.norm(dim=1, keepdim=True) 44 | return image_features 45 | 46 | def forward( 47 | self, image_0: torch.Tensor, image_1: torch.Tensor, text_0: list[str], text_1: list[str] 48 | ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: 49 | image_features_0 = self.encode_image(image_0) 50 | image_features_1 = self.encode_image(image_1) 51 | text_features_0 = self.encode_text(text_0) 52 | text_features_1 = self.encode_text(text_1) 53 | sim_0 = F.cosine_similarity(image_features_0, text_features_0) 54 | sim_1 = F.cosine_similarity(image_features_1, text_features_1) 55 | sim_direction = F.cosine_similarity(image_features_1 - image_features_0, text_features_1 - text_features_0) 56 | sim_image = F.cosine_similarity(image_features_0, image_features_1) 57 | return sim_0, sim_1, sim_direction, sim_image 58 | -------------------------------------------------------------------------------- /metrics/compute_metrics.py: -------------------------------------------------------------------------------- 1 | ''' 2 | * Copyright (c) 2023 Salesforce, Inc. 3 | * All rights reserved. 4 | * SPDX-License-Identifier: Apache License 2.0 5 | * For full license text, see LICENSE.txt file in the repo root or http://www.apache.org/licenses/ 6 | * By Shu Zhang 7 | * Redistributed from InstructPix2Pix repo: https://github.com/timothybrooks/instruct-pix2pix 8 | * Copyright (c) 2023 Timothy Brooks, Aleksander Holynski, Alexei A. Efros. All rights reserved. 
9 | ''' 10 | 11 | from __future__ import annotations 12 | 13 | import math 14 | import random 15 | import sys 16 | from argparse import ArgumentParser 17 | 18 | import einops 19 | import k_diffusion as K 20 | import numpy as np 21 | import torch 22 | import torch.nn as nn 23 | from tqdm.auto import tqdm 24 | from einops import rearrange 25 | from omegaconf import OmegaConf 26 | from PIL import Image, ImageOps 27 | from torch import autocast 28 | 29 | import json 30 | import matplotlib.pyplot as plt 31 | import seaborn 32 | from pathlib import Path 33 | 34 | sys.path.append("./") 35 | 36 | from clip_similarity import ClipSimilarity 37 | from edit_dataset import EditDatasetEval 38 | 39 | sys.path.append("./stable_diffusion") 40 | 41 | from ldm.util import instantiate_from_config 42 | 43 | 44 | class CFGDenoiser(nn.Module): 45 | def __init__(self, model): 46 | super().__init__() 47 | self.inner_model = model 48 | 49 | def forward(self, z, sigma, cond, uncond, text_cfg_scale, image_cfg_scale): 50 | cfg_z = einops.repeat(z, "1 ... -> n ...", n=3) 51 | cfg_sigma = einops.repeat(sigma, "1 ... -> n ...", n=3) 52 | cfg_cond = { 53 | "c_crossattn": [torch.cat([cond["c_crossattn"][0], uncond["c_crossattn"][0], uncond["c_crossattn"][0]])], 54 | "c_concat": [torch.cat([cond["c_concat"][0], cond["c_concat"][0], uncond["c_concat"][0]])], 55 | } 56 | out_cond, out_img_cond, out_uncond = self.inner_model(cfg_z, cfg_sigma, cond=cfg_cond).chunk(3) 57 | return out_uncond + text_cfg_scale * (out_cond - out_img_cond) + image_cfg_scale * (out_img_cond - out_uncond) 58 | 59 | 60 | def load_model_from_config(config, ckpt, vae_ckpt=None, verbose=False): 61 | print(f"Loading model from {ckpt}") 62 | pl_sd = torch.load(ckpt, map_location="cpu") 63 | if "global_step" in pl_sd: 64 | print(f"Global Step: {pl_sd['global_step']}") 65 | sd = pl_sd["state_dict"] 66 | if vae_ckpt is not None: 67 | print(f"Loading VAE from {vae_ckpt}") 68 | vae_sd = torch.load(vae_ckpt, map_location="cpu")["state_dict"] 69 | sd = { 70 | k: vae_sd[k[len("first_stage_model.") :]] if k.startswith("first_stage_model.") else v 71 | for k, v in sd.items() 72 | } 73 | model = instantiate_from_config(config.model) 74 | m, u = model.load_state_dict(sd, strict=False) 75 | if len(m) > 0 and verbose: 76 | print("missing keys:") 77 | print(m) 78 | if len(u) > 0 and verbose: 79 | print("unexpected keys:") 80 | print(u) 81 | return model 82 | 83 | class ImageEditor(nn.Module): 84 | def __init__(self, config, ckpt, vae_ckpt=None): 85 | super().__init__() 86 | 87 | config = OmegaConf.load(config) 88 | self.model = load_model_from_config(config, ckpt, vae_ckpt) 89 | self.model.eval().cuda() 90 | self.model_wrap = K.external.CompVisDenoiser(self.model) 91 | self.model_wrap_cfg = CFGDenoiser(self.model_wrap) 92 | self.null_token = self.model.get_learned_conditioning([""]) 93 | 94 | def forward( 95 | self, 96 | image: torch.Tensor, 97 | edit: str, 98 | scale_txt: float = 7.5, 99 | scale_img: float = 1.0, 100 | steps: int = 100, 101 | ) -> torch.Tensor: 102 | assert image.dim() == 3 103 | assert image.size(1) % 64 == 0 104 | assert image.size(2) % 64 == 0 105 | with torch.no_grad(), autocast("cuda"), self.model.ema_scope(): 106 | cond = { 107 | "c_crossattn": [self.model.get_learned_conditioning([edit])], 108 | "c_concat": [self.model.encode_first_stage(image[None]).mode()], 109 | } 110 | uncond = { 111 | "c_crossattn": [self.model.get_learned_conditioning([""])], 112 | "c_concat": [torch.zeros_like(cond["c_concat"][0])], 113 | } 114 | extra_args = { 115 | 
"uncond": uncond, 116 | "cond": cond, 117 | "image_cfg_scale": scale_img, 118 | "text_cfg_scale": scale_txt, 119 | } 120 | sigmas = self.model_wrap.get_sigmas(steps) 121 | x = torch.randn_like(cond["c_concat"][0]) * sigmas[0] 122 | x = K.sampling.sample_euler_ancestral(self.model_wrap_cfg, x, sigmas, extra_args=extra_args) 123 | x = self.model.decode_first_stage(x)[0] 124 | return x 125 | 126 | 127 | def compute_metrics(config, 128 | model_path, 129 | vae_ckpt, 130 | data_path, 131 | output_path, 132 | scales_img, 133 | scales_txt, 134 | num_samples = 5000, 135 | split = "test", 136 | steps = 50, 137 | res = 512, 138 | seed = 0): 139 | editor = ImageEditor(config, model_path, vae_ckpt).cuda() 140 | clip_similarity = ClipSimilarity().cuda() 141 | 142 | 143 | 144 | outpath = Path(output_path, f"n={num_samples}_p={split}_s={steps}_r={res}_e={seed}.jsonl") 145 | Path(output_path).mkdir(parents=True, exist_ok=True) 146 | 147 | for scale_txt in scales_txt: 148 | for scale_img in scales_img: 149 | dataset = EditDatasetEval( 150 | path=data_path, 151 | split=split, 152 | res=res 153 | ) 154 | assert num_samples <= len(dataset) 155 | print(f'Processing t={scale_txt}, i={scale_img}') 156 | torch.manual_seed(seed) 157 | perm = torch.randperm(len(dataset)) 158 | count = 0 159 | i = 0 160 | 161 | sim_0_avg = 0 162 | sim_1_avg = 0 163 | sim_direction_avg = 0 164 | sim_image_avg = 0 165 | count = 0 166 | 167 | pbar = tqdm(total=num_samples) 168 | while count < num_samples: 169 | 170 | idx = perm[i].item() 171 | sample = dataset[idx] 172 | i += 1 173 | 174 | gen = editor(sample["image_0"].cuda(), sample["edit"], scale_txt=scale_txt, scale_img=scale_img, steps=steps) 175 | 176 | sim_0, sim_1, sim_direction, sim_image = clip_similarity( 177 | sample["image_0"][None].cuda(), gen[None].cuda(), [sample["input_prompt"]], [sample["output_prompt"]] 178 | ) 179 | sim_0_avg += sim_0.item() 180 | sim_1_avg += sim_1.item() 181 | sim_direction_avg += sim_direction.item() 182 | sim_image_avg += sim_image.item() 183 | count += 1 184 | pbar.update(count) 185 | pbar.close() 186 | 187 | sim_0_avg /= count 188 | sim_1_avg /= count 189 | sim_direction_avg /= count 190 | sim_image_avg /= count 191 | 192 | with open(outpath, "a") as f: 193 | f.write(f"{json.dumps(dict(sim_0=sim_0_avg, sim_1=sim_1_avg, sim_direction=sim_direction_avg, sim_image=sim_image_avg, num_samples=num_samples, split=split, scale_txt=scale_txt, scale_img=scale_img, steps=steps, res=res, seed=seed))}\n") 194 | return outpath 195 | 196 | def plot_metrics(metrics_file, output_path): 197 | 198 | with open(metrics_file, 'r') as f: 199 | data = [json.loads(line) for line in f] 200 | 201 | plt.rcParams.update({'font.size': 11.5}) 202 | seaborn.set_style("darkgrid") 203 | plt.figure(figsize=(20.5* 0.7, 10.8* 0.7), dpi=200) 204 | 205 | x = [d["sim_direction"] for d in data] 206 | y = [d["sim_image"] for d in data] 207 | 208 | plt.plot(x, y, marker='o', linewidth=2, markersize=4) 209 | 210 | plt.xlabel("CLIP Text-Image Direction Similarity", labelpad=10) 211 | plt.ylabel("CLIP Image Similarity", labelpad=10) 212 | 213 | plt.savefig(Path(output_path) / Path("plot.pdf"), bbox_inches="tight") 214 | 215 | def main(): 216 | parser = ArgumentParser() 217 | parser.add_argument("--resolution", default=512, type=int) 218 | parser.add_argument("--steps", default=100, type=int) 219 | parser.add_argument("--config", default="configs/generate.yaml", type=str) 220 | parser.add_argument("--output_path", default="analysis/", type=str) 221 | parser.add_argument("--ckpt", 
default="checkpoints/hive_v2_rw_condition.ckpt", type=str) 222 | parser.add_argument("--dataset", default="data/clip-filtered-dataset/", type=str) 223 | parser.add_argument("--vae-ckpt", default=None, type=str) 224 | args = parser.parse_args() 225 | 226 | scales_img = [1.0, 1.2, 1.4, 1.6, 1.8, 2.0, 2.2] 227 | scales_txt = [7.5] 228 | 229 | metrics_file = compute_metrics( 230 | args.config, 231 | args.ckpt, 232 | args.vae_ckpt, 233 | args.dataset, 234 | args.output_path, 235 | scales_img, 236 | scales_txt 237 | steps = args.steps 238 | ) 239 | 240 | plot_metrics(metrics_file, args.output_path) 241 | 242 | 243 | 244 | if __name__ == "__main__": 245 | main() 246 | -------------------------------------------------------------------------------- /scripts/download_checkpoints.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) 4 | 5 | mkdir -p $SCRIPT_DIR/../checkpoints 6 | 7 | wget https://huggingface.co/stabilityai/stable-diffusion-2-1-base/resolve/main/v2-1_512-ema-pruned.ckpt $SCRIPT_DIR/../checkpoints/v2-1_512-ema-pruned.ckpt 8 | gcloud storage cp gs://sfr-hive-data-research/checkpoints/hive_rw_condition.ckpt $SCRIPT_DIR/../checkpoints/hive_rw_condition.ckpt 9 | gcloud storage cp gs://sfr-hive-data-research/checkpoints/hive_v2_rw_condition.ckpt $SCRIPT_DIR/../checkpoints/hive_v2_rw_condition.ckpt 10 | gcloud storage cp gs://sfr-hive-data-research/checkpoints/hive_rw.ckpt $SCRIPT_DIR/../checkpoints/hive_rw.ckpt 11 | gcloud storage cp gs://sfr-hive-data-research/checkpoints/hive_v2_rw.ckpt $SCRIPT_DIR/../checkpoints/hive_v2_rw.ckpt 12 | -------------------------------------------------------------------------------- /scripts/download_hive_data.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Make data folder relative to script location 4 | SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) 5 | 6 | mkdir -p $SCRIPT_DIR/../data 7 | 8 | # Download HIVE training data 9 | gsutil cp -r gs://sfr-hive-data-research/data/training $SCRIPT_DIR/../data/training 10 | 11 | # Download HIVE evaluation data 12 | 13 | gcloud storage cp gs://sfr-hive-data-research/data/test.jsonl $SCRIPT_DIR/../data/test.jsonl 14 | gsutil cp -r gs://sfr-hive-data-research/data/evaluation $SCRIPT_DIR/../data/ -------------------------------------------------------------------------------- /scripts/download_instructpix2pix_data.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Make data folder relative to script location 4 | SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) 5 | 6 | mkdir -p $SCRIPT_DIR/../data_instructpix2pix 7 | 8 | # Download InstructPix2Pix data (https://arxiv.org/pdf/2211.09800.pdf) 9 | 10 | # Copy text datasets 11 | wget -q --show-progress http://instruct-pix2pix.eecs.berkeley.edu/gpt-generated-prompts.jsonl -O $SCRIPT_DIR/../data_instructpix2pix/gpt-generated-prompts.jsonl 12 | wget -q --show-progress http://instruct-pix2pix.eecs.berkeley.edu/human-written-prompts.jsonl -O $SCRIPT_DIR/../data_instructpix2pix/human-written-prompts.jsonl 13 | 14 | # If dataset name isn't provided, exit. 
15 | if [ -z $1 ] 16 | then 17 | exit 0 18 | fi 19 | 20 | # Copy dataset files 21 | mkdir $SCRIPT_DIR/../data_instructpix2pix/$1 22 | wget -A zip,json -R "index.html*" -q --show-progress -r --no-parent http://instruct-pix2pix.eecs.berkeley.edu/$1/ -nd -P $SCRIPT_DIR/../data_instructpix2pix/$1/ 23 | 24 | # Unzip to folders 25 | unzip $SCRIPT_DIR/../data_instructpix2pix/$1/\*.zip -d $SCRIPT_DIR/../data_instructpix2pix/$1/ 26 | 27 | # Cleanup 28 | rm -f $SCRIPT_DIR/../data_instructpix2pix/$1/*.zip 29 | rm -f $SCRIPT_DIR/../data_instructpix2pix/$1/*.html 30 | -------------------------------------------------------------------------------- /stable_diffusion/Stable_Diffusion_v1_Model_Card.md: -------------------------------------------------------------------------------- 1 | # Stable Diffusion v1 Model Card 2 | This model card focuses on the model associated with the Stable Diffusion model, available [here](https://github.com/CompVis/stable-diffusion). 3 | 4 | ## Model Details 5 | - **Developed by:** Robin Rombach, Patrick Esser 6 | - **Model type:** Diffusion-based text-to-image generation model 7 | - **Language(s):** English 8 | - **License:** [Proprietary](LICENSE) 9 | - **Model Description:** This is a model that can be used to generate and modify images based on text prompts. It is a [Latent Diffusion Model](https://arxiv.org/abs/2112.10752) that uses a fixed, pretrained text encoder ([CLIP ViT-L/14](https://arxiv.org/abs/2103.00020)) as suggested in the [Imagen paper](https://arxiv.org/abs/2205.11487). 10 | - **Resources for more information:** [GitHub Repository](https://github.com/CompVis/stable-diffusion), [Paper](https://arxiv.org/abs/2112.10752). 11 | - **Cite as:** 12 | 13 | @InProceedings{Rombach_2022_CVPR, 14 | author = {Rombach, Robin and Blattmann, Andreas and Lorenz, Dominik and Esser, Patrick and Ommer, Bj\"orn}, 15 | title = {High-Resolution Image Synthesis With Latent Diffusion Models}, 16 | booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)}, 17 | month = {June}, 18 | year = {2022}, 19 | pages = {10684-10695} 20 | } 21 | 22 | # Uses 23 | 24 | ## Direct Use 25 | The model is intended for research purposes only. Possible research areas and 26 | tasks include 27 | 28 | - Safe deployment of models which have the potential to generate harmful content. 29 | - Probing and understanding the limitations and biases of generative models. 30 | - Generation of artworks and use in design and other artistic processes. 31 | - Applications in educational or creative tools. 32 | - Research on generative models. 33 | 34 | Excluded uses are described below. 35 | 36 | ### Misuse, Malicious Use, and Out-of-Scope Use 37 | _Note: This section is taken from the [DALLE-MINI model card](https://huggingface.co/dalle-mini/dalle-mini), but applies in the same way to Stable Diffusion v1_. 38 | 39 | The model should not be used to intentionally create or disseminate images that create hostile or alienating environments for people. This includes generating images that people would foreseeably find disturbing, distressing, or offensive; or content that propagates historical or current stereotypes. 40 | 41 | #### Out-of-Scope Use 42 | The model was not trained to be factual or true representations of people or events, and therefore using the model to generate such content is out-of-scope for the abilities of this model. 43 | 44 | #### Misuse and Malicious Use 45 | Using the model to generate content that is cruel to individuals is a misuse of this model. 
This includes, but is not limited to: 46 | 47 | - Generating demeaning, dehumanizing, or otherwise harmful representations of people or their environments, cultures, religions, etc. 48 | - Intentionally promoting or propagating discriminatory content or harmful stereotypes. 49 | - Impersonating individuals without their consent. 50 | - Sexual content without consent of the people who might see it. 51 | - Mis- and disinformation 52 | - Representations of egregious violence and gore 53 | - Sharing of copyrighted or licensed material in violation of its terms of use. 54 | - Sharing content that is an alteration of copyrighted or licensed material in violation of its terms of use. 55 | 56 | ## Limitations and Bias 57 | 58 | ### Limitations 59 | 60 | - The model does not achieve perfect photorealism 61 | - The model cannot render legible text 62 | - The model does not perform well on more difficult tasks which involve compositionality, such as rendering an image corresponding to “A red cube on top of a blue sphere” 63 | - Faces and people in general may not be generated properly. 64 | - The model was trained mainly with English captions and will not work as well in other languages. 65 | - The autoencoding part of the model is lossy 66 | - The model was trained on a large-scale dataset 67 | [LAION-5B](https://laion.ai/blog/laion-5b/) which contains adult material 68 | and is not fit for product use without additional safety mechanisms and 69 | considerations. 70 | - No additional measures were used to deduplicate the dataset. As a result, we observe some degree of memorization for images that are duplicated in the training data. 71 | The training data can be searched at [https://rom1504.github.io/clip-retrieval/](https://rom1504.github.io/clip-retrieval/) to possibly assist in the detection of memorized images. 72 | 73 | ### Bias 74 | While the capabilities of image generation models are impressive, they can also reinforce or exacerbate social biases. 75 | Stable Diffusion v1 was primarily trained on subsets of [LAION-2B(en)](https://laion.ai/blog/laion-5b/), 76 | which consists of images that are limited to English descriptions. 77 | Texts and images from communities and cultures that use other languages are likely to be insufficiently accounted for. 78 | This affects the overall output of the model, as white and western cultures are often set as the default. Further, the 79 | ability of the model to generate content with non-English prompts is significantly worse than with English-language prompts. 80 | Stable Diffusion v1 mirrors and exacerbates biases to such a degree that viewer discretion must be advised irrespective of the input or its intent. 81 | 82 | 83 | ## Training 84 | 85 | **Training Data** 86 | The model developers used the following dataset for training the model: 87 | 88 | - LAION-5B and subsets thereof (see next section) 89 | 90 | **Training Procedure** 91 | Stable Diffusion v1 is a latent diffusion model which combines an autoencoder with a diffusion model that is trained in the latent space of the autoencoder. During training, 92 | 93 | - Images are encoded through an encoder, which turns images into latent representations. The autoencoder uses a relative downsampling factor of 8 and maps images of shape H x W x 3 to latents of shape H/f x W/f x 4 94 | - Text prompts are encoded through a ViT-L/14 text-encoder. 95 | - The non-pooled output of the text encoder is fed into the UNet backbone of the latent diffusion model via cross-attention. 
96 | - The loss is a reconstruction objective between the noise that was added to the latent and the prediction made by the UNet. 97 | 98 | We currently provide the following checkpoints: 99 | 100 | - `sd-v1-1.ckpt`: 237k steps at resolution `256x256` on [laion2B-en](https://huggingface.co/datasets/laion/laion2B-en). 101 | 194k steps at resolution `512x512` on [laion-high-resolution](https://huggingface.co/datasets/laion/laion-high-resolution) (170M examples from LAION-5B with resolution `>= 1024x1024`). 102 | - `sd-v1-2.ckpt`: Resumed from `sd-v1-1.ckpt`. 103 | 515k steps at resolution `512x512` on [laion-aesthetics v2 5+](https://laion.ai/blog/laion-aesthetics/) (a subset of laion2B-en with estimated aesthetics score `> 5.0`, and additionally 104 | filtered to images with an original size `>= 512x512`, and an estimated watermark probability `< 0.5`. The watermark estimate is from the [LAION-5B](https://laion.ai/blog/laion-5b/) metadata, the aesthetics score is estimated using the [LAION-Aesthetics Predictor V2](https://github.com/christophschuhmann/improved-aesthetic-predictor)). 105 | - `sd-v1-3.ckpt`: Resumed from `sd-v1-2.ckpt`. 195k steps at resolution `512x512` on "laion-aesthetics v2 5+" and 10\% dropping of the text-conditioning to improve [classifier-free guidance sampling](https://arxiv.org/abs/2207.12598). 106 | - `sd-v1-4.ckpt`: Resumed from `sd-v1-2.ckpt`. 225k steps at resolution `512x512` on "laion-aesthetics v2 5+" and 10\% dropping of the text-conditioning to improve [classifier-free guidance sampling](https://arxiv.org/abs/2207.12598). 107 | 108 | - **Hardware:** 32 x 8 x A100 GPUs 109 | - **Optimizer:** AdamW 110 | - **Gradient Accumulations**: 2 111 | - **Batch:** 32 x 8 x 2 x 4 = 2048 112 | - **Learning rate:** warmup to 0.0001 for 10,000 steps and then kept constant 113 | 114 | ## Evaluation Results 115 | Evaluations with different classifier-free guidance scales (1.5, 2.0, 3.0, 4.0, 116 | 5.0, 6.0, 7.0, 8.0) and 50 PLMS sampling 117 | steps show the relative improvements of the checkpoints: 118 | 119 | ![pareto](assets/v1-variants-scores.jpg) 120 | 121 | Evaluated using 50 PLMS steps and 10000 random prompts from the COCO2017 validation set, evaluated at 512x512 resolution. Not optimized for FID scores. 122 | 123 | ## Environmental Impact 124 | 125 | **Stable Diffusion v1** **Estimated Emissions** 126 | Based on that information, we estimate the following CO2 emissions using the [Machine Learning Impact calculator](https://mlco2.github.io/impact#compute) presented in [Lacoste et al. (2019)](https://arxiv.org/abs/1910.09700). The hardware, runtime, cloud provider, and compute region were utilized to estimate the carbon impact. 127 | 128 | - **Hardware Type:** A100 PCIe 40GB 129 | - **Hours used:** 150000 130 | - **Cloud Provider:** AWS 131 | - **Compute Region:** US-east 132 | - **Carbon Emitted (Power consumption x Time x Carbon produced based on location of power grid):** 11250 kg CO2 eq. 
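As a rough sketch of the reconstruction objective described under Training Procedure above (the notation is illustrative and not taken from the configs in this repository), the UNet is trained to predict the noise added to the latent:

    \mathcal{L}_{LDM} = \mathbb{E}_{z \sim \mathcal{E}(x),\; y,\; \epsilon \sim \mathcal{N}(0, 1),\; t} \Big[ \lVert \epsilon - \epsilon_\theta(z_t, t, \tau_\theta(y)) \rVert_2^2 \Big]

where z_t is the noised latent at timestep t, \epsilon_\theta is the UNet backbone, and \tau_\theta is the CLIP ViT-L/14 text encoder whose non-pooled output conditions the UNet via cross-attention.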
133 | 134 | ## Citation 135 | @InProceedings{Rombach_2022_CVPR, 136 | author = {Rombach, Robin and Blattmann, Andreas and Lorenz, Dominik and Esser, Patrick and Ommer, Bj\"orn}, 137 | title = {High-Resolution Image Synthesis With Latent Diffusion Models}, 138 | booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)}, 139 | month = {June}, 140 | year = {2022}, 141 | pages = {10684-10695} 142 | } 143 | 144 | *This model card was written by: Robin Rombach and Patrick Esser and is based on the [DALL-E Mini model card](https://huggingface.co/dalle-mini/dalle-mini).* 145 | -------------------------------------------------------------------------------- /stable_diffusion/assets/a-painting-of-a-fire.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/salesforce/HIVE/a14ba7f26779e7d3dd506a4f5ab203589187e789/stable_diffusion/assets/a-painting-of-a-fire.png -------------------------------------------------------------------------------- /stable_diffusion/assets/a-photograph-of-a-fire.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/salesforce/HIVE/a14ba7f26779e7d3dd506a4f5ab203589187e789/stable_diffusion/assets/a-photograph-of-a-fire.png -------------------------------------------------------------------------------- /stable_diffusion/assets/a-shirt-with-a-fire-printed-on-it.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/salesforce/HIVE/a14ba7f26779e7d3dd506a4f5ab203589187e789/stable_diffusion/assets/a-shirt-with-a-fire-printed-on-it.png -------------------------------------------------------------------------------- /stable_diffusion/assets/a-shirt-with-the-inscription-'fire'.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/salesforce/HIVE/a14ba7f26779e7d3dd506a4f5ab203589187e789/stable_diffusion/assets/a-shirt-with-the-inscription-'fire'.png -------------------------------------------------------------------------------- /stable_diffusion/assets/a-watercolor-painting-of-a-fire.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/salesforce/HIVE/a14ba7f26779e7d3dd506a4f5ab203589187e789/stable_diffusion/assets/a-watercolor-painting-of-a-fire.png -------------------------------------------------------------------------------- /stable_diffusion/assets/birdhouse.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/salesforce/HIVE/a14ba7f26779e7d3dd506a4f5ab203589187e789/stable_diffusion/assets/birdhouse.png -------------------------------------------------------------------------------- /stable_diffusion/assets/fire.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/salesforce/HIVE/a14ba7f26779e7d3dd506a4f5ab203589187e789/stable_diffusion/assets/fire.png -------------------------------------------------------------------------------- /stable_diffusion/assets/inpainting.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/salesforce/HIVE/a14ba7f26779e7d3dd506a4f5ab203589187e789/stable_diffusion/assets/inpainting.png -------------------------------------------------------------------------------- 
/stable_diffusion/assets/modelfigure.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/salesforce/HIVE/a14ba7f26779e7d3dd506a4f5ab203589187e789/stable_diffusion/assets/modelfigure.png -------------------------------------------------------------------------------- /stable_diffusion/assets/rdm-preview.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/salesforce/HIVE/a14ba7f26779e7d3dd506a4f5ab203589187e789/stable_diffusion/assets/rdm-preview.jpg -------------------------------------------------------------------------------- /stable_diffusion/assets/reconstruction1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/salesforce/HIVE/a14ba7f26779e7d3dd506a4f5ab203589187e789/stable_diffusion/assets/reconstruction1.png -------------------------------------------------------------------------------- /stable_diffusion/assets/reconstruction2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/salesforce/HIVE/a14ba7f26779e7d3dd506a4f5ab203589187e789/stable_diffusion/assets/reconstruction2.png -------------------------------------------------------------------------------- /stable_diffusion/assets/results.gif.REMOVED.git-id: -------------------------------------------------------------------------------- 1 | 82b6590e670a32196093cc6333ea19e6547d07de -------------------------------------------------------------------------------- /stable_diffusion/assets/rick.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/salesforce/HIVE/a14ba7f26779e7d3dd506a4f5ab203589187e789/stable_diffusion/assets/rick.jpeg -------------------------------------------------------------------------------- /stable_diffusion/assets/stable-samples/img2img/mountains-1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/salesforce/HIVE/a14ba7f26779e7d3dd506a4f5ab203589187e789/stable_diffusion/assets/stable-samples/img2img/mountains-1.png -------------------------------------------------------------------------------- /stable_diffusion/assets/stable-samples/img2img/mountains-2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/salesforce/HIVE/a14ba7f26779e7d3dd506a4f5ab203589187e789/stable_diffusion/assets/stable-samples/img2img/mountains-2.png -------------------------------------------------------------------------------- /stable_diffusion/assets/stable-samples/img2img/mountains-3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/salesforce/HIVE/a14ba7f26779e7d3dd506a4f5ab203589187e789/stable_diffusion/assets/stable-samples/img2img/mountains-3.png -------------------------------------------------------------------------------- /stable_diffusion/assets/stable-samples/img2img/sketch-mountains-input.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/salesforce/HIVE/a14ba7f26779e7d3dd506a4f5ab203589187e789/stable_diffusion/assets/stable-samples/img2img/sketch-mountains-input.jpg -------------------------------------------------------------------------------- 
/stable_diffusion/assets/stable-samples/img2img/upscaling-in.png.REMOVED.git-id: -------------------------------------------------------------------------------- 1 | 501c31c21751664957e69ce52cad1818b6d2f4ce -------------------------------------------------------------------------------- /stable_diffusion/assets/stable-samples/img2img/upscaling-out.png.REMOVED.git-id: -------------------------------------------------------------------------------- 1 | 1c4bb25a779f34d86b2d90e584ac67af91bb1303 -------------------------------------------------------------------------------- /stable_diffusion/assets/stable-samples/txt2img/000002025.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/salesforce/HIVE/a14ba7f26779e7d3dd506a4f5ab203589187e789/stable_diffusion/assets/stable-samples/txt2img/000002025.png -------------------------------------------------------------------------------- /stable_diffusion/assets/stable-samples/txt2img/000002035.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/salesforce/HIVE/a14ba7f26779e7d3dd506a4f5ab203589187e789/stable_diffusion/assets/stable-samples/txt2img/000002035.png -------------------------------------------------------------------------------- /stable_diffusion/assets/stable-samples/txt2img/merged-0005.png.REMOVED.git-id: -------------------------------------------------------------------------------- 1 | ca0a1af206555f0f208a1ab879e95efedc1b1c5b -------------------------------------------------------------------------------- /stable_diffusion/assets/stable-samples/txt2img/merged-0006.png.REMOVED.git-id: -------------------------------------------------------------------------------- 1 | 999f3703230580e8c89e9081abd6a1f8f50896d4 -------------------------------------------------------------------------------- /stable_diffusion/assets/stable-samples/txt2img/merged-0007.png.REMOVED.git-id: -------------------------------------------------------------------------------- 1 | af390acaf601283782d6f479d4cade4d78e30b26 -------------------------------------------------------------------------------- /stable_diffusion/assets/the-earth-is-on-fire,-oil-on-canvas.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/salesforce/HIVE/a14ba7f26779e7d3dd506a4f5ab203589187e789/stable_diffusion/assets/the-earth-is-on-fire,-oil-on-canvas.png -------------------------------------------------------------------------------- /stable_diffusion/assets/txt2img-convsample.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/salesforce/HIVE/a14ba7f26779e7d3dd506a4f5ab203589187e789/stable_diffusion/assets/txt2img-convsample.png -------------------------------------------------------------------------------- /stable_diffusion/assets/txt2img-preview.png.REMOVED.git-id: -------------------------------------------------------------------------------- 1 | 51ee1c235dfdc63d4c41de7d303d03730e43c33c -------------------------------------------------------------------------------- /stable_diffusion/assets/v1-variants-scores.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/salesforce/HIVE/a14ba7f26779e7d3dd506a4f5ab203589187e789/stable_diffusion/assets/v1-variants-scores.jpg -------------------------------------------------------------------------------- 
/stable_diffusion/configs/autoencoder/autoencoder_kl_16x16x16.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 4.5e-6 3 | target: ldm.models.autoencoder.AutoencoderKL 4 | params: 5 | monitor: "val/rec_loss" 6 | embed_dim: 16 7 | lossconfig: 8 | target: ldm.modules.losses.LPIPSWithDiscriminator 9 | params: 10 | disc_start: 50001 11 | kl_weight: 0.000001 12 | disc_weight: 0.5 13 | 14 | ddconfig: 15 | double_z: True 16 | z_channels: 16 17 | resolution: 256 18 | in_channels: 3 19 | out_ch: 3 20 | ch: 128 21 | ch_mult: [ 1,1,2,2,4] # num_down = len(ch_mult)-1 22 | num_res_blocks: 2 23 | attn_resolutions: [16] 24 | dropout: 0.0 25 | 26 | 27 | data: 28 | target: main.DataModuleFromConfig 29 | params: 30 | batch_size: 12 31 | wrap: True 32 | train: 33 | target: ldm.data.imagenet.ImageNetSRTrain 34 | params: 35 | size: 256 36 | degradation: pil_nearest 37 | validation: 38 | target: ldm.data.imagenet.ImageNetSRValidation 39 | params: 40 | size: 256 41 | degradation: pil_nearest 42 | 43 | lightning: 44 | callbacks: 45 | image_logger: 46 | target: main.ImageLogger 47 | params: 48 | batch_frequency: 1000 49 | max_images: 8 50 | increase_log_steps: True 51 | 52 | trainer: 53 | benchmark: True 54 | accumulate_grad_batches: 2 55 | -------------------------------------------------------------------------------- /stable_diffusion/configs/autoencoder/autoencoder_kl_32x32x4.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 4.5e-6 3 | target: ldm.models.autoencoder.AutoencoderKL 4 | params: 5 | monitor: "val/rec_loss" 6 | embed_dim: 4 7 | lossconfig: 8 | target: ldm.modules.losses.LPIPSWithDiscriminator 9 | params: 10 | disc_start: 50001 11 | kl_weight: 0.000001 12 | disc_weight: 0.5 13 | 14 | ddconfig: 15 | double_z: True 16 | z_channels: 4 17 | resolution: 256 18 | in_channels: 3 19 | out_ch: 3 20 | ch: 128 21 | ch_mult: [ 1,2,4,4 ] # num_down = len(ch_mult)-1 22 | num_res_blocks: 2 23 | attn_resolutions: [ ] 24 | dropout: 0.0 25 | 26 | data: 27 | target: main.DataModuleFromConfig 28 | params: 29 | batch_size: 12 30 | wrap: True 31 | train: 32 | target: ldm.data.imagenet.ImageNetSRTrain 33 | params: 34 | size: 256 35 | degradation: pil_nearest 36 | validation: 37 | target: ldm.data.imagenet.ImageNetSRValidation 38 | params: 39 | size: 256 40 | degradation: pil_nearest 41 | 42 | lightning: 43 | callbacks: 44 | image_logger: 45 | target: main.ImageLogger 46 | params: 47 | batch_frequency: 1000 48 | max_images: 8 49 | increase_log_steps: True 50 | 51 | trainer: 52 | benchmark: True 53 | accumulate_grad_batches: 2 54 | -------------------------------------------------------------------------------- /stable_diffusion/configs/autoencoder/autoencoder_kl_64x64x3.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 4.5e-6 3 | target: ldm.models.autoencoder.AutoencoderKL 4 | params: 5 | monitor: "val/rec_loss" 6 | embed_dim: 3 7 | lossconfig: 8 | target: ldm.modules.losses.LPIPSWithDiscriminator 9 | params: 10 | disc_start: 50001 11 | kl_weight: 0.000001 12 | disc_weight: 0.5 13 | 14 | ddconfig: 15 | double_z: True 16 | z_channels: 3 17 | resolution: 256 18 | in_channels: 3 19 | out_ch: 3 20 | ch: 128 21 | ch_mult: [ 1,2,4 ] # num_down = len(ch_mult)-1 22 | num_res_blocks: 2 23 | attn_resolutions: [ ] 24 | dropout: 0.0 25 | 26 | 27 | data: 28 | target: main.DataModuleFromConfig 29 | params: 30 | batch_size: 
12 31 | wrap: True 32 | train: 33 | target: ldm.data.imagenet.ImageNetSRTrain 34 | params: 35 | size: 256 36 | degradation: pil_nearest 37 | validation: 38 | target: ldm.data.imagenet.ImageNetSRValidation 39 | params: 40 | size: 256 41 | degradation: pil_nearest 42 | 43 | lightning: 44 | callbacks: 45 | image_logger: 46 | target: main.ImageLogger 47 | params: 48 | batch_frequency: 1000 49 | max_images: 8 50 | increase_log_steps: True 51 | 52 | trainer: 53 | benchmark: True 54 | accumulate_grad_batches: 2 55 | -------------------------------------------------------------------------------- /stable_diffusion/configs/autoencoder/autoencoder_kl_8x8x64.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 4.5e-6 3 | target: ldm.models.autoencoder.AutoencoderKL 4 | params: 5 | monitor: "val/rec_loss" 6 | embed_dim: 64 7 | lossconfig: 8 | target: ldm.modules.losses.LPIPSWithDiscriminator 9 | params: 10 | disc_start: 50001 11 | kl_weight: 0.000001 12 | disc_weight: 0.5 13 | 14 | ddconfig: 15 | double_z: True 16 | z_channels: 64 17 | resolution: 256 18 | in_channels: 3 19 | out_ch: 3 20 | ch: 128 21 | ch_mult: [ 1,1,2,2,4,4] # num_down = len(ch_mult)-1 22 | num_res_blocks: 2 23 | attn_resolutions: [16,8] 24 | dropout: 0.0 25 | 26 | data: 27 | target: main.DataModuleFromConfig 28 | params: 29 | batch_size: 12 30 | wrap: True 31 | train: 32 | target: ldm.data.imagenet.ImageNetSRTrain 33 | params: 34 | size: 256 35 | degradation: pil_nearest 36 | validation: 37 | target: ldm.data.imagenet.ImageNetSRValidation 38 | params: 39 | size: 256 40 | degradation: pil_nearest 41 | 42 | lightning: 43 | callbacks: 44 | image_logger: 45 | target: main.ImageLogger 46 | params: 47 | batch_frequency: 1000 48 | max_images: 8 49 | increase_log_steps: True 50 | 51 | trainer: 52 | benchmark: True 53 | accumulate_grad_batches: 2 54 | -------------------------------------------------------------------------------- /stable_diffusion/configs/latent-diffusion/celebahq-ldm-vq-4.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 2.0e-06 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.0015 6 | linear_end: 0.0195 7 | num_timesteps_cond: 1 8 | log_every_t: 200 9 | timesteps: 1000 10 | first_stage_key: image 11 | image_size: 64 12 | channels: 3 13 | monitor: val/loss_simple_ema 14 | 15 | unet_config: 16 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 17 | params: 18 | image_size: 64 19 | in_channels: 3 20 | out_channels: 3 21 | model_channels: 224 22 | attention_resolutions: 23 | # note: this isn\t actually the resolution but 24 | # the downsampling factor, i.e. 
this corresnponds to 25 | # attention on spatial resolution 8,16,32, as the 26 | # spatial reolution of the latents is 64 for f4 27 | - 8 28 | - 4 29 | - 2 30 | num_res_blocks: 2 31 | channel_mult: 32 | - 1 33 | - 2 34 | - 3 35 | - 4 36 | num_head_channels: 32 37 | first_stage_config: 38 | target: ldm.models.autoencoder.VQModelInterface 39 | params: 40 | embed_dim: 3 41 | n_embed: 8192 42 | ckpt_path: models/first_stage_models/vq-f4/model.ckpt 43 | ddconfig: 44 | double_z: false 45 | z_channels: 3 46 | resolution: 256 47 | in_channels: 3 48 | out_ch: 3 49 | ch: 128 50 | ch_mult: 51 | - 1 52 | - 2 53 | - 4 54 | num_res_blocks: 2 55 | attn_resolutions: [] 56 | dropout: 0.0 57 | lossconfig: 58 | target: torch.nn.Identity 59 | cond_stage_config: __is_unconditional__ 60 | data: 61 | target: main.DataModuleFromConfig 62 | params: 63 | batch_size: 48 64 | num_workers: 5 65 | wrap: false 66 | train: 67 | target: taming.data.faceshq.CelebAHQTrain 68 | params: 69 | size: 256 70 | validation: 71 | target: taming.data.faceshq.CelebAHQValidation 72 | params: 73 | size: 256 74 | 75 | 76 | lightning: 77 | callbacks: 78 | image_logger: 79 | target: main.ImageLogger 80 | params: 81 | batch_frequency: 5000 82 | max_images: 8 83 | increase_log_steps: False 84 | 85 | trainer: 86 | benchmark: True -------------------------------------------------------------------------------- /stable_diffusion/configs/latent-diffusion/cin-ldm-vq-f8.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 1.0e-06 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.0015 6 | linear_end: 0.0195 7 | num_timesteps_cond: 1 8 | log_every_t: 200 9 | timesteps: 1000 10 | first_stage_key: image 11 | cond_stage_key: class_label 12 | image_size: 32 13 | channels: 4 14 | cond_stage_trainable: true 15 | conditioning_key: crossattn 16 | monitor: val/loss_simple_ema 17 | unet_config: 18 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 19 | params: 20 | image_size: 32 21 | in_channels: 4 22 | out_channels: 4 23 | model_channels: 256 24 | attention_resolutions: 25 | #note: this isn\t actually the resolution but 26 | # the downsampling factor, i.e. 
this corresnponds to 27 | # attention on spatial resolution 8,16,32, as the 28 | # spatial reolution of the latents is 32 for f8 29 | - 4 30 | - 2 31 | - 1 32 | num_res_blocks: 2 33 | channel_mult: 34 | - 1 35 | - 2 36 | - 4 37 | num_head_channels: 32 38 | use_spatial_transformer: true 39 | transformer_depth: 1 40 | context_dim: 512 41 | first_stage_config: 42 | target: ldm.models.autoencoder.VQModelInterface 43 | params: 44 | embed_dim: 4 45 | n_embed: 16384 46 | ckpt_path: configs/first_stage_models/vq-f8/model.yaml 47 | ddconfig: 48 | double_z: false 49 | z_channels: 4 50 | resolution: 256 51 | in_channels: 3 52 | out_ch: 3 53 | ch: 128 54 | ch_mult: 55 | - 1 56 | - 2 57 | - 2 58 | - 4 59 | num_res_blocks: 2 60 | attn_resolutions: 61 | - 32 62 | dropout: 0.0 63 | lossconfig: 64 | target: torch.nn.Identity 65 | cond_stage_config: 66 | target: ldm.modules.encoders.modules.ClassEmbedder 67 | params: 68 | embed_dim: 512 69 | key: class_label 70 | data: 71 | target: main.DataModuleFromConfig 72 | params: 73 | batch_size: 64 74 | num_workers: 12 75 | wrap: false 76 | train: 77 | target: ldm.data.imagenet.ImageNetTrain 78 | params: 79 | config: 80 | size: 256 81 | validation: 82 | target: ldm.data.imagenet.ImageNetValidation 83 | params: 84 | config: 85 | size: 256 86 | 87 | 88 | lightning: 89 | callbacks: 90 | image_logger: 91 | target: main.ImageLogger 92 | params: 93 | batch_frequency: 5000 94 | max_images: 8 95 | increase_log_steps: False 96 | 97 | trainer: 98 | benchmark: True -------------------------------------------------------------------------------- /stable_diffusion/configs/latent-diffusion/cin256-v2.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 0.0001 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.0015 6 | linear_end: 0.0195 7 | num_timesteps_cond: 1 8 | log_every_t: 200 9 | timesteps: 1000 10 | first_stage_key: image 11 | cond_stage_key: class_label 12 | image_size: 64 13 | channels: 3 14 | cond_stage_trainable: true 15 | conditioning_key: crossattn 16 | monitor: val/loss 17 | use_ema: False 18 | 19 | unet_config: 20 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 21 | params: 22 | image_size: 64 23 | in_channels: 3 24 | out_channels: 3 25 | model_channels: 192 26 | attention_resolutions: 27 | - 8 28 | - 4 29 | - 2 30 | num_res_blocks: 2 31 | channel_mult: 32 | - 1 33 | - 2 34 | - 3 35 | - 5 36 | num_heads: 1 37 | use_spatial_transformer: true 38 | transformer_depth: 1 39 | context_dim: 512 40 | 41 | first_stage_config: 42 | target: ldm.models.autoencoder.VQModelInterface 43 | params: 44 | embed_dim: 3 45 | n_embed: 8192 46 | ddconfig: 47 | double_z: false 48 | z_channels: 3 49 | resolution: 256 50 | in_channels: 3 51 | out_ch: 3 52 | ch: 128 53 | ch_mult: 54 | - 1 55 | - 2 56 | - 4 57 | num_res_blocks: 2 58 | attn_resolutions: [] 59 | dropout: 0.0 60 | lossconfig: 61 | target: torch.nn.Identity 62 | 63 | cond_stage_config: 64 | target: ldm.modules.encoders.modules.ClassEmbedder 65 | params: 66 | n_classes: 1001 67 | embed_dim: 512 68 | key: class_label 69 | -------------------------------------------------------------------------------- /stable_diffusion/configs/latent-diffusion/ffhq-ldm-vq-4.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 2.0e-06 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.0015 6 | linear_end: 0.0195 7 | 
num_timesteps_cond: 1 8 | log_every_t: 200 9 | timesteps: 1000 10 | first_stage_key: image 11 | image_size: 64 12 | channels: 3 13 | monitor: val/loss_simple_ema 14 | unet_config: 15 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 16 | params: 17 | image_size: 64 18 | in_channels: 3 19 | out_channels: 3 20 | model_channels: 224 21 | attention_resolutions: 22 | # note: this isn\t actually the resolution but 23 | # the downsampling factor, i.e. this corresnponds to 24 | # attention on spatial resolution 8,16,32, as the 25 | # spatial reolution of the latents is 64 for f4 26 | - 8 27 | - 4 28 | - 2 29 | num_res_blocks: 2 30 | channel_mult: 31 | - 1 32 | - 2 33 | - 3 34 | - 4 35 | num_head_channels: 32 36 | first_stage_config: 37 | target: ldm.models.autoencoder.VQModelInterface 38 | params: 39 | embed_dim: 3 40 | n_embed: 8192 41 | ckpt_path: configs/first_stage_models/vq-f4/model.yaml 42 | ddconfig: 43 | double_z: false 44 | z_channels: 3 45 | resolution: 256 46 | in_channels: 3 47 | out_ch: 3 48 | ch: 128 49 | ch_mult: 50 | - 1 51 | - 2 52 | - 4 53 | num_res_blocks: 2 54 | attn_resolutions: [] 55 | dropout: 0.0 56 | lossconfig: 57 | target: torch.nn.Identity 58 | cond_stage_config: __is_unconditional__ 59 | data: 60 | target: main.DataModuleFromConfig 61 | params: 62 | batch_size: 42 63 | num_workers: 5 64 | wrap: false 65 | train: 66 | target: taming.data.faceshq.FFHQTrain 67 | params: 68 | size: 256 69 | validation: 70 | target: taming.data.faceshq.FFHQValidation 71 | params: 72 | size: 256 73 | 74 | 75 | lightning: 76 | callbacks: 77 | image_logger: 78 | target: main.ImageLogger 79 | params: 80 | batch_frequency: 5000 81 | max_images: 8 82 | increase_log_steps: False 83 | 84 | trainer: 85 | benchmark: True -------------------------------------------------------------------------------- /stable_diffusion/configs/latent-diffusion/lsun_bedrooms-ldm-vq-4.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 2.0e-06 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.0015 6 | linear_end: 0.0195 7 | num_timesteps_cond: 1 8 | log_every_t: 200 9 | timesteps: 1000 10 | first_stage_key: image 11 | image_size: 64 12 | channels: 3 13 | monitor: val/loss_simple_ema 14 | unet_config: 15 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 16 | params: 17 | image_size: 64 18 | in_channels: 3 19 | out_channels: 3 20 | model_channels: 224 21 | attention_resolutions: 22 | # note: this isn\t actually the resolution but 23 | # the downsampling factor, i.e. 
this corresponds to 24 | # attention on spatial resolution 8,16,32, as the 25 | # spatial resolution of the latents is 64 for f4 26 | - 8 27 | - 4 28 | - 2 29 | num_res_blocks: 2 30 | channel_mult: 31 | - 1 32 | - 2 33 | - 3 34 | - 4 35 | num_head_channels: 32 36 | first_stage_config: 37 | target: ldm.models.autoencoder.VQModelInterface 38 | params: 39 | ckpt_path: configs/first_stage_models/vq-f4/model.yaml 40 | embed_dim: 3 41 | n_embed: 8192 42 | ddconfig: 43 | double_z: false 44 | z_channels: 3 45 | resolution: 256 46 | in_channels: 3 47 | out_ch: 3 48 | ch: 128 49 | ch_mult: 50 | - 1 51 | - 2 52 | - 4 53 | num_res_blocks: 2 54 | attn_resolutions: [] 55 | dropout: 0.0 56 | lossconfig: 57 | target: torch.nn.Identity 58 | cond_stage_config: __is_unconditional__ 59 | data: 60 | target: main.DataModuleFromConfig 61 | params: 62 | batch_size: 48 63 | num_workers: 5 64 | wrap: false 65 | train: 66 | target: ldm.data.lsun.LSUNBedroomsTrain 67 | params: 68 | size: 256 69 | validation: 70 | target: ldm.data.lsun.LSUNBedroomsValidation 71 | params: 72 | size: 256 73 | 74 | 75 | lightning: 76 | callbacks: 77 | image_logger: 78 | target: main.ImageLogger 79 | params: 80 | batch_frequency: 5000 81 | max_images: 8 82 | increase_log_steps: False 83 | 84 | trainer: 85 | benchmark: True -------------------------------------------------------------------------------- /stable_diffusion/configs/latent-diffusion/lsun_churches-ldm-kl-8.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 5.0e-5 # set to target_lr by starting main.py with '--scale_lr False' 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.0015 6 | linear_end: 0.0155 7 | num_timesteps_cond: 1 8 | log_every_t: 200 9 | timesteps: 1000 10 | loss_type: l1 11 | first_stage_key: "image" 12 | cond_stage_key: "image" 13 | image_size: 32 14 | channels: 4 15 | cond_stage_trainable: False 16 | concat_mode: False 17 | scale_by_std: True 18 | monitor: 'val/loss_simple_ema' 19 | 20 | scheduler_config: # 10000 warmup steps 21 | target: ldm.lr_scheduler.LambdaLinearScheduler 22 | params: 23 | warm_up_steps: [10000] 24 | cycle_lengths: [10000000000000] 25 | f_start: [1.e-6] 26 | f_max: [1.] 27 | f_min: [ 1.] 
28 | 29 | unet_config: 30 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 31 | params: 32 | image_size: 32 33 | in_channels: 4 34 | out_channels: 4 35 | model_channels: 192 36 | attention_resolutions: [ 1, 2, 4, 8 ] # 32, 16, 8, 4 37 | num_res_blocks: 2 38 | channel_mult: [ 1,2,2,4,4 ] # 32, 16, 8, 4, 2 39 | num_heads: 8 40 | use_scale_shift_norm: True 41 | resblock_updown: True 42 | 43 | first_stage_config: 44 | target: ldm.models.autoencoder.AutoencoderKL 45 | params: 46 | embed_dim: 4 47 | monitor: "val/rec_loss" 48 | ckpt_path: "models/first_stage_models/kl-f8/model.ckpt" 49 | ddconfig: 50 | double_z: True 51 | z_channels: 4 52 | resolution: 256 53 | in_channels: 3 54 | out_ch: 3 55 | ch: 128 56 | ch_mult: [ 1,2,4,4 ] # num_down = len(ch_mult)-1 57 | num_res_blocks: 2 58 | attn_resolutions: [ ] 59 | dropout: 0.0 60 | lossconfig: 61 | target: torch.nn.Identity 62 | 63 | cond_stage_config: "__is_unconditional__" 64 | 65 | data: 66 | target: main.DataModuleFromConfig 67 | params: 68 | batch_size: 96 69 | num_workers: 5 70 | wrap: False 71 | train: 72 | target: ldm.data.lsun.LSUNChurchesTrain 73 | params: 74 | size: 256 75 | validation: 76 | target: ldm.data.lsun.LSUNChurchesValidation 77 | params: 78 | size: 256 79 | 80 | lightning: 81 | callbacks: 82 | image_logger: 83 | target: main.ImageLogger 84 | params: 85 | batch_frequency: 5000 86 | max_images: 8 87 | increase_log_steps: False 88 | 89 | 90 | trainer: 91 | benchmark: True -------------------------------------------------------------------------------- /stable_diffusion/configs/latent-diffusion/txt2img-1p4B-eval.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 5.0e-05 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.00085 6 | linear_end: 0.012 7 | num_timesteps_cond: 1 8 | log_every_t: 200 9 | timesteps: 1000 10 | first_stage_key: image 11 | cond_stage_key: caption 12 | image_size: 32 13 | channels: 4 14 | cond_stage_trainable: true 15 | conditioning_key: crossattn 16 | monitor: val/loss_simple_ema 17 | scale_factor: 0.18215 18 | use_ema: False 19 | 20 | unet_config: 21 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 22 | params: 23 | image_size: 32 24 | in_channels: 4 25 | out_channels: 4 26 | model_channels: 320 27 | attention_resolutions: 28 | - 4 29 | - 2 30 | - 1 31 | num_res_blocks: 2 32 | channel_mult: 33 | - 1 34 | - 2 35 | - 4 36 | - 4 37 | num_heads: 8 38 | use_spatial_transformer: true 39 | transformer_depth: 1 40 | context_dim: 1280 41 | use_checkpoint: true 42 | legacy: False 43 | 44 | first_stage_config: 45 | target: ldm.models.autoencoder.AutoencoderKL 46 | params: 47 | embed_dim: 4 48 | monitor: val/rec_loss 49 | ddconfig: 50 | double_z: true 51 | z_channels: 4 52 | resolution: 256 53 | in_channels: 3 54 | out_ch: 3 55 | ch: 128 56 | ch_mult: 57 | - 1 58 | - 2 59 | - 4 60 | - 4 61 | num_res_blocks: 2 62 | attn_resolutions: [] 63 | dropout: 0.0 64 | lossconfig: 65 | target: torch.nn.Identity 66 | 67 | cond_stage_config: 68 | target: ldm.modules.encoders.modules.BERTEmbedder 69 | params: 70 | n_embed: 1280 71 | n_layer: 32 72 | -------------------------------------------------------------------------------- /stable_diffusion/configs/retrieval-augmented-diffusion/768x768.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 0.0001 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.0015 
6 | linear_end: 0.015 7 | num_timesteps_cond: 1 8 | log_every_t: 200 9 | timesteps: 1000 10 | first_stage_key: jpg 11 | cond_stage_key: nix 12 | image_size: 48 13 | channels: 16 14 | cond_stage_trainable: false 15 | conditioning_key: crossattn 16 | monitor: val/loss_simple_ema 17 | scale_by_std: false 18 | scale_factor: 0.22765929 19 | unet_config: 20 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 21 | params: 22 | image_size: 48 23 | in_channels: 16 24 | out_channels: 16 25 | model_channels: 448 26 | attention_resolutions: 27 | - 4 28 | - 2 29 | - 1 30 | num_res_blocks: 2 31 | channel_mult: 32 | - 1 33 | - 2 34 | - 3 35 | - 4 36 | use_scale_shift_norm: false 37 | resblock_updown: false 38 | num_head_channels: 32 39 | use_spatial_transformer: true 40 | transformer_depth: 1 41 | context_dim: 768 42 | use_checkpoint: true 43 | first_stage_config: 44 | target: ldm.models.autoencoder.AutoencoderKL 45 | params: 46 | monitor: val/rec_loss 47 | embed_dim: 16 48 | ddconfig: 49 | double_z: true 50 | z_channels: 16 51 | resolution: 256 52 | in_channels: 3 53 | out_ch: 3 54 | ch: 128 55 | ch_mult: 56 | - 1 57 | - 1 58 | - 2 59 | - 2 60 | - 4 61 | num_res_blocks: 2 62 | attn_resolutions: 63 | - 16 64 | dropout: 0.0 65 | lossconfig: 66 | target: torch.nn.Identity 67 | cond_stage_config: 68 | target: torch.nn.Identity -------------------------------------------------------------------------------- /stable_diffusion/configs/stable-diffusion/v1-inference.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 1.0e-04 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.00085 6 | linear_end: 0.0120 7 | num_timesteps_cond: 1 8 | log_every_t: 200 9 | timesteps: 1000 10 | first_stage_key: "jpg" 11 | cond_stage_key: "txt" 12 | image_size: 64 13 | channels: 4 14 | cond_stage_trainable: false # Note: different from the one we trained before 15 | conditioning_key: crossattn 16 | monitor: val/loss_simple_ema 17 | scale_factor: 0.18215 18 | use_ema: False 19 | 20 | scheduler_config: # 10000 warmup steps 21 | target: ldm.lr_scheduler.LambdaLinearScheduler 22 | params: 23 | warm_up_steps: [ 10000 ] 24 | cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases 25 | f_start: [ 1.e-6 ] 26 | f_max: [ 1. ] 27 | f_min: [ 1. 
] 28 | 29 | unet_config: 30 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 31 | params: 32 | image_size: 32 # unused 33 | in_channels: 4 34 | out_channels: 4 35 | model_channels: 320 36 | attention_resolutions: [ 4, 2, 1 ] 37 | num_res_blocks: 2 38 | channel_mult: [ 1, 2, 4, 4 ] 39 | num_heads: 8 40 | use_spatial_transformer: True 41 | transformer_depth: 1 42 | context_dim: 768 43 | use_checkpoint: True 44 | legacy: False 45 | 46 | first_stage_config: 47 | target: ldm.models.autoencoder.AutoencoderKL 48 | params: 49 | embed_dim: 4 50 | monitor: val/rec_loss 51 | ddconfig: 52 | double_z: true 53 | z_channels: 4 54 | resolution: 256 55 | in_channels: 3 56 | out_ch: 3 57 | ch: 128 58 | ch_mult: 59 | - 1 60 | - 2 61 | - 4 62 | - 4 63 | num_res_blocks: 2 64 | attn_resolutions: [] 65 | dropout: 0.0 66 | lossconfig: 67 | target: torch.nn.Identity 68 | 69 | cond_stage_config: 70 | target: ldm.modules.encoders.modules.FrozenCLIPEmbedder 71 | -------------------------------------------------------------------------------- /stable_diffusion/data/DejaVuSans.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/salesforce/HIVE/a14ba7f26779e7d3dd506a4f5ab203589187e789/stable_diffusion/data/DejaVuSans.ttf -------------------------------------------------------------------------------- /stable_diffusion/data/example_conditioning/superresolution/sample_0.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/salesforce/HIVE/a14ba7f26779e7d3dd506a4f5ab203589187e789/stable_diffusion/data/example_conditioning/superresolution/sample_0.jpg -------------------------------------------------------------------------------- /stable_diffusion/data/example_conditioning/text_conditional/sample_0.txt: -------------------------------------------------------------------------------- 1 | A basket of cerries 2 | -------------------------------------------------------------------------------- /stable_diffusion/data/imagenet_train_hr_indices.p.REMOVED.git-id: -------------------------------------------------------------------------------- 1 | b8d6d4689d2ecf32147e9cc2f5e6c50e072df26f -------------------------------------------------------------------------------- /stable_diffusion/data/imagenet_val_hr_indices.p: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/salesforce/HIVE/a14ba7f26779e7d3dd506a4f5ab203589187e789/stable_diffusion/data/imagenet_val_hr_indices.p -------------------------------------------------------------------------------- /stable_diffusion/data/inpainting_examples/6458524847_2f4c361183_k.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/salesforce/HIVE/a14ba7f26779e7d3dd506a4f5ab203589187e789/stable_diffusion/data/inpainting_examples/6458524847_2f4c361183_k.png -------------------------------------------------------------------------------- /stable_diffusion/data/inpainting_examples/6458524847_2f4c361183_k_mask.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/salesforce/HIVE/a14ba7f26779e7d3dd506a4f5ab203589187e789/stable_diffusion/data/inpainting_examples/6458524847_2f4c361183_k_mask.png -------------------------------------------------------------------------------- /stable_diffusion/data/inpainting_examples/8399166846_f6fb4e4b8e_k.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/salesforce/HIVE/a14ba7f26779e7d3dd506a4f5ab203589187e789/stable_diffusion/data/inpainting_examples/8399166846_f6fb4e4b8e_k.png -------------------------------------------------------------------------------- /stable_diffusion/data/inpainting_examples/8399166846_f6fb4e4b8e_k_mask.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/salesforce/HIVE/a14ba7f26779e7d3dd506a4f5ab203589187e789/stable_diffusion/data/inpainting_examples/8399166846_f6fb4e4b8e_k_mask.png -------------------------------------------------------------------------------- /stable_diffusion/data/inpainting_examples/alex-iby-G_Pk4D9rMLs.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/salesforce/HIVE/a14ba7f26779e7d3dd506a4f5ab203589187e789/stable_diffusion/data/inpainting_examples/alex-iby-G_Pk4D9rMLs.png -------------------------------------------------------------------------------- /stable_diffusion/data/inpainting_examples/alex-iby-G_Pk4D9rMLs_mask.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/salesforce/HIVE/a14ba7f26779e7d3dd506a4f5ab203589187e789/stable_diffusion/data/inpainting_examples/alex-iby-G_Pk4D9rMLs_mask.png -------------------------------------------------------------------------------- /stable_diffusion/data/inpainting_examples/bench2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/salesforce/HIVE/a14ba7f26779e7d3dd506a4f5ab203589187e789/stable_diffusion/data/inpainting_examples/bench2.png -------------------------------------------------------------------------------- /stable_diffusion/data/inpainting_examples/bench2_mask.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/salesforce/HIVE/a14ba7f26779e7d3dd506a4f5ab203589187e789/stable_diffusion/data/inpainting_examples/bench2_mask.png -------------------------------------------------------------------------------- /stable_diffusion/data/inpainting_examples/bertrand-gabioud-CpuFzIsHYJ0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/salesforce/HIVE/a14ba7f26779e7d3dd506a4f5ab203589187e789/stable_diffusion/data/inpainting_examples/bertrand-gabioud-CpuFzIsHYJ0.png -------------------------------------------------------------------------------- /stable_diffusion/data/inpainting_examples/bertrand-gabioud-CpuFzIsHYJ0_mask.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/salesforce/HIVE/a14ba7f26779e7d3dd506a4f5ab203589187e789/stable_diffusion/data/inpainting_examples/bertrand-gabioud-CpuFzIsHYJ0_mask.png -------------------------------------------------------------------------------- /stable_diffusion/data/inpainting_examples/billow926-12-Wc-Zgx6Y.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/salesforce/HIVE/a14ba7f26779e7d3dd506a4f5ab203589187e789/stable_diffusion/data/inpainting_examples/billow926-12-Wc-Zgx6Y.png -------------------------------------------------------------------------------- /stable_diffusion/data/inpainting_examples/billow926-12-Wc-Zgx6Y_mask.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/salesforce/HIVE/a14ba7f26779e7d3dd506a4f5ab203589187e789/stable_diffusion/data/inpainting_examples/billow926-12-Wc-Zgx6Y_mask.png -------------------------------------------------------------------------------- /stable_diffusion/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/salesforce/HIVE/a14ba7f26779e7d3dd506a4f5ab203589187e789/stable_diffusion/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png -------------------------------------------------------------------------------- /stable_diffusion/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/salesforce/HIVE/a14ba7f26779e7d3dd506a4f5ab203589187e789/stable_diffusion/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png -------------------------------------------------------------------------------- /stable_diffusion/data/inpainting_examples/photo-1583445095369-9c651e7e5d34.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/salesforce/HIVE/a14ba7f26779e7d3dd506a4f5ab203589187e789/stable_diffusion/data/inpainting_examples/photo-1583445095369-9c651e7e5d34.png -------------------------------------------------------------------------------- /stable_diffusion/data/inpainting_examples/photo-1583445095369-9c651e7e5d34_mask.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/salesforce/HIVE/a14ba7f26779e7d3dd506a4f5ab203589187e789/stable_diffusion/data/inpainting_examples/photo-1583445095369-9c651e7e5d34_mask.png -------------------------------------------------------------------------------- /stable_diffusion/environment.yaml: -------------------------------------------------------------------------------- 1 | name: ldm 2 | channels: 3 | - pytorch 4 | - defaults 5 | dependencies: 6 | - python=3.8.5 7 | - pip=20.3 8 | - cudatoolkit=11.3 9 | - pytorch=1.11.0 10 | - torchvision=0.12.0 11 | - numpy=1.19.2 12 | - pip: 13 | - albumentations==0.4.3 14 | - diffusers 15 | - opencv-python==4.1.2.30 16 | - pudb==2019.2 17 | - invisible-watermark 18 | - imageio==2.9.0 19 | - imageio-ffmpeg==0.4.2 20 | - pytorch-lightning==1.4.2 21 | - omegaconf==2.1.1 22 | - test-tube>=0.7.5 23 | - streamlit>=0.73.1 24 | - einops==0.3.0 25 | - torch-fidelity==0.3.0 26 | - transformers==4.19.2 27 | - torchmetrics==0.6.0 28 | - kornia==0.6 29 | - -e git+https://github.com/CompVis/taming-transformers.git@master#egg=taming-transformers 30 | - -e git+https://github.com/openai/CLIP.git@main#egg=clip 31 | - -e . 
32 | -------------------------------------------------------------------------------- /stable_diffusion/ldm/data/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/salesforce/HIVE/a14ba7f26779e7d3dd506a4f5ab203589187e789/stable_diffusion/ldm/data/__init__.py -------------------------------------------------------------------------------- /stable_diffusion/ldm/data/base.py: -------------------------------------------------------------------------------- 1 | from abc import abstractmethod 2 | from torch.utils.data import Dataset, ConcatDataset, ChainDataset, IterableDataset 3 | 4 | 5 | class Txt2ImgIterableBaseDataset(IterableDataset): 6 | ''' 7 | Define an interface to make the IterableDatasets for text2img data chainable 8 | ''' 9 | def __init__(self, num_records=0, valid_ids=None, size=256): 10 | super().__init__() 11 | self.num_records = num_records 12 | self.valid_ids = valid_ids 13 | self.sample_ids = valid_ids 14 | self.size = size 15 | 16 | print(f'{self.__class__.__name__} dataset contains {self.__len__()} examples.') 17 | 18 | def __len__(self): 19 | return self.num_records 20 | 21 | @abstractmethod 22 | def __iter__(self): 23 | pass -------------------------------------------------------------------------------- /stable_diffusion/ldm/data/lsun.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import PIL 4 | from PIL import Image 5 | from torch.utils.data import Dataset 6 | from torchvision import transforms 7 | 8 | 9 | class LSUNBase(Dataset): 10 | def __init__(self, 11 | txt_file, 12 | data_root, 13 | size=None, 14 | interpolation="bicubic", 15 | flip_p=0.5 16 | ): 17 | self.data_paths = txt_file 18 | self.data_root = data_root 19 | with open(self.data_paths, "r") as f: 20 | self.image_paths = f.read().splitlines() 21 | self._length = len(self.image_paths) 22 | self.labels = { 23 | "relative_file_path_": [l for l in self.image_paths], 24 | "file_path_": [os.path.join(self.data_root, l) 25 | for l in self.image_paths], 26 | } 27 | 28 | self.size = size 29 | self.interpolation = {"linear": PIL.Image.LINEAR, 30 | "bilinear": PIL.Image.BILINEAR, 31 | "bicubic": PIL.Image.BICUBIC, 32 | "lanczos": PIL.Image.LANCZOS, 33 | }[interpolation] 34 | self.flip = transforms.RandomHorizontalFlip(p=flip_p) 35 | 36 | def __len__(self): 37 | return self._length 38 | 39 | def __getitem__(self, i): 40 | example = dict((k, self.labels[k][i]) for k in self.labels) 41 | image = Image.open(example["file_path_"]) 42 | if not image.mode == "RGB": 43 | image = image.convert("RGB") 44 | 45 | # default to score-sde preprocessing 46 | img = np.array(image).astype(np.uint8) 47 | crop = min(img.shape[0], img.shape[1]) 48 | h, w, = img.shape[0], img.shape[1] 49 | img = img[(h - crop) // 2:(h + crop) // 2, 50 | (w - crop) // 2:(w + crop) // 2] 51 | 52 | image = Image.fromarray(img) 53 | if self.size is not None: 54 | image = image.resize((self.size, self.size), resample=self.interpolation) 55 | 56 | image = self.flip(image) 57 | image = np.array(image).astype(np.uint8) 58 | example["image"] = (image / 127.5 - 1.0).astype(np.float32) 59 | return example 60 | 61 | 62 | class LSUNChurchesTrain(LSUNBase): 63 | def __init__(self, **kwargs): 64 | super().__init__(txt_file="data/lsun/church_outdoor_train.txt", data_root="data/lsun/churches", **kwargs) 65 | 66 | 67 | class LSUNChurchesValidation(LSUNBase): 68 | def __init__(self, flip_p=0., **kwargs): 69 | 
super().__init__(txt_file="data/lsun/church_outdoor_val.txt", data_root="data/lsun/churches", 70 | flip_p=flip_p, **kwargs) 71 | 72 | 73 | class LSUNBedroomsTrain(LSUNBase): 74 | def __init__(self, **kwargs): 75 | super().__init__(txt_file="data/lsun/bedrooms_train.txt", data_root="data/lsun/bedrooms", **kwargs) 76 | 77 | 78 | class LSUNBedroomsValidation(LSUNBase): 79 | def __init__(self, flip_p=0.0, **kwargs): 80 | super().__init__(txt_file="data/lsun/bedrooms_val.txt", data_root="data/lsun/bedrooms", 81 | flip_p=flip_p, **kwargs) 82 | 83 | 84 | class LSUNCatsTrain(LSUNBase): 85 | def __init__(self, **kwargs): 86 | super().__init__(txt_file="data/lsun/cat_train.txt", data_root="data/lsun/cats", **kwargs) 87 | 88 | 89 | class LSUNCatsValidation(LSUNBase): 90 | def __init__(self, flip_p=0., **kwargs): 91 | super().__init__(txt_file="data/lsun/cat_val.txt", data_root="data/lsun/cats", 92 | flip_p=flip_p, **kwargs) 93 | -------------------------------------------------------------------------------- /stable_diffusion/ldm/lr_scheduler.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | class LambdaWarmUpCosineScheduler: 5 | """ 6 | note: use with a base_lr of 1.0 7 | """ 8 | def __init__(self, warm_up_steps, lr_min, lr_max, lr_start, max_decay_steps, verbosity_interval=0): 9 | self.lr_warm_up_steps = warm_up_steps 10 | self.lr_start = lr_start 11 | self.lr_min = lr_min 12 | self.lr_max = lr_max 13 | self.lr_max_decay_steps = max_decay_steps 14 | self.last_lr = 0. 15 | self.verbosity_interval = verbosity_interval 16 | 17 | def schedule(self, n, **kwargs): 18 | if self.verbosity_interval > 0: 19 | if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_lr}") 20 | if n < self.lr_warm_up_steps: 21 | lr = (self.lr_max - self.lr_start) / self.lr_warm_up_steps * n + self.lr_start 22 | self.last_lr = lr 23 | return lr 24 | else: 25 | t = (n - self.lr_warm_up_steps) / (self.lr_max_decay_steps - self.lr_warm_up_steps) 26 | t = min(t, 1.0) 27 | lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * ( 28 | 1 + np.cos(t * np.pi)) 29 | self.last_lr = lr 30 | return lr 31 | 32 | def __call__(self, n, **kwargs): 33 | return self.schedule(n,**kwargs) 34 | 35 | 36 | class LambdaWarmUpCosineScheduler2: 37 | """ 38 | supports repeated iterations, configurable via lists 39 | note: use with a base_lr of 1.0. 40 | """ 41 | def __init__(self, warm_up_steps, f_min, f_max, f_start, cycle_lengths, verbosity_interval=0): 42 | assert len(warm_up_steps) == len(f_min) == len(f_max) == len(f_start) == len(cycle_lengths) 43 | self.lr_warm_up_steps = warm_up_steps 44 | self.f_start = f_start 45 | self.f_min = f_min 46 | self.f_max = f_max 47 | self.cycle_lengths = cycle_lengths 48 | self.cum_cycles = np.cumsum([0] + list(self.cycle_lengths)) 49 | self.last_f = 0. 
50 | self.verbosity_interval = verbosity_interval 51 | 52 | def find_in_interval(self, n): 53 | interval = 0 54 | for cl in self.cum_cycles[1:]: 55 | if n <= cl: 56 | return interval 57 | interval += 1 58 | 59 | def schedule(self, n, **kwargs): 60 | cycle = self.find_in_interval(n) 61 | n = n - self.cum_cycles[cycle] 62 | if self.verbosity_interval > 0: 63 | if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, " 64 | f"current cycle {cycle}") 65 | if n < self.lr_warm_up_steps[cycle]: 66 | f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle] 67 | self.last_f = f 68 | return f 69 | else: 70 | t = (n - self.lr_warm_up_steps[cycle]) / (self.cycle_lengths[cycle] - self.lr_warm_up_steps[cycle]) 71 | t = min(t, 1.0) 72 | f = self.f_min[cycle] + 0.5 * (self.f_max[cycle] - self.f_min[cycle]) * ( 73 | 1 + np.cos(t * np.pi)) 74 | self.last_f = f 75 | return f 76 | 77 | def __call__(self, n, **kwargs): 78 | return self.schedule(n, **kwargs) 79 | 80 | 81 | class LambdaLinearScheduler(LambdaWarmUpCosineScheduler2): 82 | 83 | def schedule(self, n, **kwargs): 84 | cycle = self.find_in_interval(n) 85 | n = n - self.cum_cycles[cycle] 86 | if self.verbosity_interval > 0: 87 | if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, " 88 | f"current cycle {cycle}") 89 | 90 | if n < self.lr_warm_up_steps[cycle]: 91 | f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle] 92 | self.last_f = f 93 | return f 94 | else: 95 | f = self.f_min[cycle] + (self.f_max[cycle] - self.f_min[cycle]) * (self.cycle_lengths[cycle] - n) / (self.cycle_lengths[cycle]) 96 | self.last_f = f 97 | return f 98 | 99 | -------------------------------------------------------------------------------- /stable_diffusion/ldm/models/diffusion/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/salesforce/HIVE/a14ba7f26779e7d3dd506a4f5ab203589187e789/stable_diffusion/ldm/models/diffusion/__init__.py -------------------------------------------------------------------------------- /stable_diffusion/ldm/models/diffusion/dpm_solver/__init__.py: -------------------------------------------------------------------------------- 1 | from .sampler import DPMSolverSampler -------------------------------------------------------------------------------- /stable_diffusion/ldm/models/diffusion/dpm_solver/sampler.py: -------------------------------------------------------------------------------- 1 | """SAMPLING ONLY.""" 2 | 3 | import torch 4 | 5 | from .dpm_solver import NoiseScheduleVP, model_wrapper, DPM_Solver 6 | 7 | 8 | class DPMSolverSampler(object): 9 | def __init__(self, model, **kwargs): 10 | super().__init__() 11 | self.model = model 12 | to_torch = lambda x: x.clone().detach().to(torch.float32).to(model.device) 13 | self.register_buffer('alphas_cumprod', to_torch(model.alphas_cumprod)) 14 | 15 | def register_buffer(self, name, attr): 16 | if type(attr) == torch.Tensor: 17 | if attr.device != torch.device("cuda"): 18 | attr = attr.to(torch.device("cuda")) 19 | setattr(self, name, attr) 20 | 21 | @torch.no_grad() 22 | def sample(self, 23 | S, 24 | batch_size, 25 | shape, 26 | conditioning=None, 27 | callback=None, 28 | normals_sequence=None, 29 | img_callback=None, 30 | quantize_x0=False, 31 | eta=0., 32 | mask=None, 33 | x0=None, 34 | temperature=1., 35 | noise_dropout=0., 36 | 
score_corrector=None, 37 | corrector_kwargs=None, 38 | verbose=True, 39 | x_T=None, 40 | log_every_t=100, 41 | unconditional_guidance_scale=1., 42 | unconditional_conditioning=None, 43 | # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... 44 | **kwargs 45 | ): 46 | if conditioning is not None: 47 | if isinstance(conditioning, dict): 48 | cbs = conditioning[list(conditioning.keys())[0]].shape[0] 49 | if cbs != batch_size: 50 | print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}") 51 | else: 52 | if conditioning.shape[0] != batch_size: 53 | print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}") 54 | 55 | # sampling 56 | C, H, W = shape 57 | size = (batch_size, C, H, W) 58 | 59 | # print(f'Data shape for DPM-Solver sampling is {size}, sampling steps {S}') 60 | 61 | device = self.model.betas.device 62 | if x_T is None: 63 | img = torch.randn(size, device=device) 64 | else: 65 | img = x_T 66 | 67 | ns = NoiseScheduleVP('discrete', alphas_cumprod=self.alphas_cumprod) 68 | 69 | model_fn = model_wrapper( 70 | lambda x, t, c: self.model.apply_model(x, t, c), 71 | ns, 72 | model_type="noise", 73 | guidance_type="classifier-free", 74 | condition=conditioning, 75 | unconditional_condition=unconditional_conditioning, 76 | guidance_scale=unconditional_guidance_scale, 77 | ) 78 | 79 | dpm_solver = DPM_Solver(model_fn, ns, predict_x0=True, thresholding=False) 80 | x = dpm_solver.sample(img, steps=S, skip_type="time_uniform", method="multistep", order=2, lower_order_final=True) 81 | 82 | return x.to(device), None 83 | -------------------------------------------------------------------------------- /stable_diffusion/ldm/modules/diffusionmodules/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/salesforce/HIVE/a14ba7f26779e7d3dd506a4f5ab203589187e789/stable_diffusion/ldm/modules/diffusionmodules/__init__.py -------------------------------------------------------------------------------- /stable_diffusion/ldm/modules/distributions/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/salesforce/HIVE/a14ba7f26779e7d3dd506a4f5ab203589187e789/stable_diffusion/ldm/modules/distributions/__init__.py -------------------------------------------------------------------------------- /stable_diffusion/ldm/modules/distributions/distributions.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | 5 | class AbstractDistribution: 6 | def sample(self): 7 | raise NotImplementedError() 8 | 9 | def mode(self): 10 | raise NotImplementedError() 11 | 12 | 13 | class DiracDistribution(AbstractDistribution): 14 | def __init__(self, value): 15 | self.value = value 16 | 17 | def sample(self): 18 | return self.value 19 | 20 | def mode(self): 21 | return self.value 22 | 23 | 24 | class DiagonalGaussianDistribution(object): 25 | def __init__(self, parameters, deterministic=False): 26 | self.parameters = parameters 27 | self.mean, self.logvar = torch.chunk(parameters, 2, dim=1) 28 | self.logvar = torch.clamp(self.logvar, -30.0, 20.0) 29 | self.deterministic = deterministic 30 | self.std = torch.exp(0.5 * self.logvar) 31 | self.var = torch.exp(self.logvar) 32 | if self.deterministic: 33 | self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device) 34 | 35 | def sample(self): 36 | x = self.mean + 
self.std * torch.randn(self.mean.shape).to(device=self.parameters.device) 37 | return x 38 | 39 | def kl(self, other=None): 40 | if self.deterministic: 41 | return torch.Tensor([0.]) 42 | else: 43 | if other is None: 44 | return 0.5 * torch.sum(torch.pow(self.mean, 2) 45 | + self.var - 1.0 - self.logvar, 46 | dim=[1, 2, 3]) 47 | else: 48 | return 0.5 * torch.sum( 49 | torch.pow(self.mean - other.mean, 2) / other.var 50 | + self.var / other.var - 1.0 - self.logvar + other.logvar, 51 | dim=[1, 2, 3]) 52 | 53 | def nll(self, sample, dims=[1,2,3]): 54 | if self.deterministic: 55 | return torch.Tensor([0.]) 56 | logtwopi = np.log(2.0 * np.pi) 57 | return 0.5 * torch.sum( 58 | logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, 59 | dim=dims) 60 | 61 | def mode(self): 62 | return self.mean 63 | 64 | 65 | def normal_kl(mean1, logvar1, mean2, logvar2): 66 | """ 67 | source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12 68 | Compute the KL divergence between two gaussians. 69 | Shapes are automatically broadcasted, so batches can be compared to 70 | scalars, among other use cases. 71 | """ 72 | tensor = None 73 | for obj in (mean1, logvar1, mean2, logvar2): 74 | if isinstance(obj, torch.Tensor): 75 | tensor = obj 76 | break 77 | assert tensor is not None, "at least one argument must be a Tensor" 78 | 79 | # Force variances to be Tensors. Broadcasting helps convert scalars to 80 | # Tensors, but it does not work for torch.exp(). 81 | logvar1, logvar2 = [ 82 | x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor) 83 | for x in (logvar1, logvar2) 84 | ] 85 | 86 | return 0.5 * ( 87 | -1.0 88 | + logvar2 89 | - logvar1 90 | + torch.exp(logvar1 - logvar2) 91 | + ((mean1 - mean2) ** 2) * torch.exp(-logvar2) 92 | ) 93 | -------------------------------------------------------------------------------- /stable_diffusion/ldm/modules/ema.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | class LitEma(nn.Module): 6 | def __init__(self, model, decay=0.9999, use_num_upates=True): 7 | super().__init__() 8 | if decay < 0.0 or decay > 1.0: 9 | raise ValueError('Decay must be between 0 and 1') 10 | 11 | self.m_name2s_name = {} 12 | self.register_buffer('decay', torch.tensor(decay, dtype=torch.float32)) 13 | self.register_buffer('num_updates', torch.tensor(0,dtype=torch.int) if use_num_upates 14 | else torch.tensor(-1,dtype=torch.int)) 15 | 16 | for name, p in model.named_parameters(): 17 | if p.requires_grad: 18 | #remove as '.'-character is not allowed in buffers 19 | s_name = name.replace('.','') 20 | self.m_name2s_name.update({name:s_name}) 21 | self.register_buffer(s_name,p.clone().detach().data) 22 | 23 | self.collected_params = [] 24 | 25 | def forward(self,model): 26 | decay = self.decay 27 | 28 | if self.num_updates >= 0: 29 | self.num_updates += 1 30 | decay = min(self.decay,(1 + self.num_updates) / (10 + self.num_updates)) 31 | 32 | one_minus_decay = 1.0 - decay 33 | 34 | with torch.no_grad(): 35 | m_param = dict(model.named_parameters()) 36 | shadow_params = dict(self.named_buffers()) 37 | 38 | for key in m_param: 39 | if m_param[key].requires_grad: 40 | sname = self.m_name2s_name[key] 41 | shadow_params[sname] = shadow_params[sname].type_as(m_param[key]) 42 | shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key])) 43 | else: 44 | assert not key in self.m_name2s_name 45 | 46 | def 
copy_to(self, model): 47 | m_param = dict(model.named_parameters()) 48 | shadow_params = dict(self.named_buffers()) 49 | for key in m_param: 50 | if m_param[key].requires_grad: 51 | m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data) 52 | else: 53 | assert not key in self.m_name2s_name 54 | 55 | def store(self, parameters): 56 | """ 57 | Save the current parameters for restoring later. 58 | Args: 59 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 60 | temporarily stored. 61 | """ 62 | self.collected_params = [param.clone() for param in parameters] 63 | 64 | def restore(self, parameters): 65 | """ 66 | Restore the parameters stored with the `store` method. 67 | Useful to validate the model with EMA parameters without affecting the 68 | original optimization process. Store the parameters before the 69 | `copy_to` method. After validation (or model saving), use this to 70 | restore the former parameters. 71 | Args: 72 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 73 | updated with the stored parameters. 74 | """ 75 | for c_param, param in zip(self.collected_params, parameters): 76 | param.data.copy_(c_param.data) 77 | -------------------------------------------------------------------------------- /stable_diffusion/ldm/modules/encoders/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/salesforce/HIVE/a14ba7f26779e7d3dd506a4f5ab203589187e789/stable_diffusion/ldm/modules/encoders/__init__.py -------------------------------------------------------------------------------- /stable_diffusion/ldm/modules/image_degradation/__init__.py: -------------------------------------------------------------------------------- 1 | from ldm.modules.image_degradation.bsrgan import degradation_bsrgan_variant as degradation_fn_bsr 2 | from ldm.modules.image_degradation.bsrgan_light import degradation_bsrgan_variant as degradation_fn_bsr_light 3 | -------------------------------------------------------------------------------- /stable_diffusion/ldm/modules/image_degradation/utils/test.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/salesforce/HIVE/a14ba7f26779e7d3dd506a4f5ab203589187e789/stable_diffusion/ldm/modules/image_degradation/utils/test.png -------------------------------------------------------------------------------- /stable_diffusion/ldm/modules/losses/__init__.py: -------------------------------------------------------------------------------- 1 | from ldm.modules.losses.contperceptual import LPIPSWithDiscriminator -------------------------------------------------------------------------------- /stable_diffusion/ldm/modules/losses/contperceptual.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from taming.modules.losses.vqperceptual import * # TODO: taming dependency yes/no? 
5 | 6 | 7 | class LPIPSWithDiscriminator(nn.Module): 8 | def __init__(self, disc_start, logvar_init=0.0, kl_weight=1.0, pixelloss_weight=1.0, 9 | disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0, 10 | perceptual_weight=1.0, use_actnorm=False, disc_conditional=False, 11 | disc_loss="hinge"): 12 | 13 | super().__init__() 14 | assert disc_loss in ["hinge", "vanilla"] 15 | self.kl_weight = kl_weight 16 | self.pixel_weight = pixelloss_weight 17 | self.perceptual_loss = LPIPS().eval() 18 | self.perceptual_weight = perceptual_weight 19 | # output log variance 20 | self.logvar = nn.Parameter(torch.ones(size=()) * logvar_init) 21 | 22 | self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels, 23 | n_layers=disc_num_layers, 24 | use_actnorm=use_actnorm 25 | ).apply(weights_init) 26 | self.discriminator_iter_start = disc_start 27 | self.disc_loss = hinge_d_loss if disc_loss == "hinge" else vanilla_d_loss 28 | self.disc_factor = disc_factor 29 | self.discriminator_weight = disc_weight 30 | self.disc_conditional = disc_conditional 31 | 32 | def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None): 33 | if last_layer is not None: 34 | nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0] 35 | g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0] 36 | else: 37 | nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0] 38 | g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0] 39 | 40 | d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4) 41 | d_weight = torch.clamp(d_weight, 0.0, 1e4).detach() 42 | d_weight = d_weight * self.discriminator_weight 43 | return d_weight 44 | 45 | def forward(self, inputs, reconstructions, posteriors, optimizer_idx, 46 | global_step, last_layer=None, cond=None, split="train", 47 | weights=None): 48 | rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous()) 49 | if self.perceptual_weight > 0: 50 | p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous()) 51 | rec_loss = rec_loss + self.perceptual_weight * p_loss 52 | 53 | nll_loss = rec_loss / torch.exp(self.logvar) + self.logvar 54 | weighted_nll_loss = nll_loss 55 | if weights is not None: 56 | weighted_nll_loss = weights*nll_loss 57 | weighted_nll_loss = torch.sum(weighted_nll_loss) / weighted_nll_loss.shape[0] 58 | nll_loss = torch.sum(nll_loss) / nll_loss.shape[0] 59 | kl_loss = posteriors.kl() 60 | kl_loss = torch.sum(kl_loss) / kl_loss.shape[0] 61 | 62 | # now the GAN part 63 | if optimizer_idx == 0: 64 | # generator update 65 | if cond is None: 66 | assert not self.disc_conditional 67 | logits_fake = self.discriminator(reconstructions.contiguous()) 68 | else: 69 | assert self.disc_conditional 70 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1)) 71 | g_loss = -torch.mean(logits_fake) 72 | 73 | if self.disc_factor > 0.0: 74 | try: 75 | d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer) 76 | except RuntimeError: 77 | assert not self.training 78 | d_weight = torch.tensor(0.0) 79 | else: 80 | d_weight = torch.tensor(0.0) 81 | 82 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) 83 | loss = weighted_nll_loss + self.kl_weight * kl_loss + d_weight * disc_factor * g_loss 84 | 85 | log = {"{}/total_loss".format(split): loss.clone().detach().mean(), "{}/logvar".format(split): self.logvar.detach(), 86 | 
"{}/kl_loss".format(split): kl_loss.detach().mean(), "{}/nll_loss".format(split): nll_loss.detach().mean(), 87 | "{}/rec_loss".format(split): rec_loss.detach().mean(), 88 | "{}/d_weight".format(split): d_weight.detach(), 89 | "{}/disc_factor".format(split): torch.tensor(disc_factor), 90 | "{}/g_loss".format(split): g_loss.detach().mean(), 91 | } 92 | return loss, log 93 | 94 | if optimizer_idx == 1: 95 | # second pass for discriminator update 96 | if cond is None: 97 | logits_real = self.discriminator(inputs.contiguous().detach()) 98 | logits_fake = self.discriminator(reconstructions.contiguous().detach()) 99 | else: 100 | logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1)) 101 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1)) 102 | 103 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) 104 | d_loss = disc_factor * self.disc_loss(logits_real, logits_fake) 105 | 106 | log = {"{}/disc_loss".format(split): d_loss.clone().detach().mean(), 107 | "{}/logits_real".format(split): logits_real.detach().mean(), 108 | "{}/logits_fake".format(split): logits_fake.detach().mean() 109 | } 110 | return d_loss, log 111 | 112 | -------------------------------------------------------------------------------- /stable_diffusion/ldm/modules/losses/vqperceptual.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | from einops import repeat 5 | 6 | from taming.modules.discriminator.model import NLayerDiscriminator, weights_init 7 | from taming.modules.losses.lpips import LPIPS 8 | from taming.modules.losses.vqperceptual import hinge_d_loss, vanilla_d_loss 9 | 10 | 11 | def hinge_d_loss_with_exemplar_weights(logits_real, logits_fake, weights): 12 | assert weights.shape[0] == logits_real.shape[0] == logits_fake.shape[0] 13 | loss_real = torch.mean(F.relu(1. - logits_real), dim=[1,2,3]) 14 | loss_fake = torch.mean(F.relu(1. + logits_fake), dim=[1,2,3]) 15 | loss_real = (weights * loss_real).sum() / weights.sum() 16 | loss_fake = (weights * loss_fake).sum() / weights.sum() 17 | d_loss = 0.5 * (loss_real + loss_fake) 18 | return d_loss 19 | 20 | def adopt_weight(weight, global_step, threshold=0, value=0.): 21 | if global_step < threshold: 22 | weight = value 23 | return weight 24 | 25 | 26 | def measure_perplexity(predicted_indices, n_embed): 27 | # src: https://github.com/karpathy/deep-vector-quantization/blob/main/model.py 28 | # eval cluster perplexity. 
when perplexity == num_embeddings then all clusters are used exactly equally 29 | encodings = F.one_hot(predicted_indices, n_embed).float().reshape(-1, n_embed) 30 | avg_probs = encodings.mean(0) 31 | perplexity = (-(avg_probs * torch.log(avg_probs + 1e-10)).sum()).exp() 32 | cluster_use = torch.sum(avg_probs > 0) 33 | return perplexity, cluster_use 34 | 35 | def l1(x, y): 36 | return torch.abs(x-y) 37 | 38 | 39 | def l2(x, y): 40 | return torch.pow((x-y), 2) 41 | 42 | 43 | class VQLPIPSWithDiscriminator(nn.Module): 44 | def __init__(self, disc_start, codebook_weight=1.0, pixelloss_weight=1.0, 45 | disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0, 46 | perceptual_weight=1.0, use_actnorm=False, disc_conditional=False, 47 | disc_ndf=64, disc_loss="hinge", n_classes=None, perceptual_loss="lpips", 48 | pixel_loss="l1"): 49 | super().__init__() 50 | assert disc_loss in ["hinge", "vanilla"] 51 | assert perceptual_loss in ["lpips", "clips", "dists"] 52 | assert pixel_loss in ["l1", "l2"] 53 | self.codebook_weight = codebook_weight 54 | self.pixel_weight = pixelloss_weight 55 | if perceptual_loss == "lpips": 56 | print(f"{self.__class__.__name__}: Running with LPIPS.") 57 | self.perceptual_loss = LPIPS().eval() 58 | else: 59 | raise ValueError(f"Unknown perceptual loss: >> {perceptual_loss} <<") 60 | self.perceptual_weight = perceptual_weight 61 | 62 | if pixel_loss == "l1": 63 | self.pixel_loss = l1 64 | else: 65 | self.pixel_loss = l2 66 | 67 | self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels, 68 | n_layers=disc_num_layers, 69 | use_actnorm=use_actnorm, 70 | ndf=disc_ndf 71 | ).apply(weights_init) 72 | self.discriminator_iter_start = disc_start 73 | if disc_loss == "hinge": 74 | self.disc_loss = hinge_d_loss 75 | elif disc_loss == "vanilla": 76 | self.disc_loss = vanilla_d_loss 77 | else: 78 | raise ValueError(f"Unknown GAN loss '{disc_loss}'.") 79 | print(f"VQLPIPSWithDiscriminator running with {disc_loss} loss.") 80 | self.disc_factor = disc_factor 81 | self.discriminator_weight = disc_weight 82 | self.disc_conditional = disc_conditional 83 | self.n_classes = n_classes 84 | 85 | def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None): 86 | if last_layer is not None: 87 | nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0] 88 | g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0] 89 | else: 90 | nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0] 91 | g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0] 92 | 93 | d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4) 94 | d_weight = torch.clamp(d_weight, 0.0, 1e4).detach() 95 | d_weight = d_weight * self.discriminator_weight 96 | return d_weight 97 | 98 | def forward(self, codebook_loss, inputs, reconstructions, optimizer_idx, 99 | global_step, last_layer=None, cond=None, split="train", predicted_indices=None): 100 | if not exists(codebook_loss): 101 | codebook_loss = torch.tensor([0.]).to(inputs.device) 102 | #rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous()) 103 | rec_loss = self.pixel_loss(inputs.contiguous(), reconstructions.contiguous()) 104 | if self.perceptual_weight > 0: 105 | p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous()) 106 | rec_loss = rec_loss + self.perceptual_weight * p_loss 107 | else: 108 | p_loss = torch.tensor([0.0]) 109 | 110 | nll_loss = rec_loss 111 | #nll_loss = torch.sum(nll_loss) / nll_loss.shape[0] 
112 | nll_loss = torch.mean(nll_loss) 113 | 114 | # now the GAN part 115 | if optimizer_idx == 0: 116 | # generator update 117 | if cond is None: 118 | assert not self.disc_conditional 119 | logits_fake = self.discriminator(reconstructions.contiguous()) 120 | else: 121 | assert self.disc_conditional 122 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1)) 123 | g_loss = -torch.mean(logits_fake) 124 | 125 | try: 126 | d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer) 127 | except RuntimeError: 128 | assert not self.training 129 | d_weight = torch.tensor(0.0) 130 | 131 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) 132 | loss = nll_loss + d_weight * disc_factor * g_loss + self.codebook_weight * codebook_loss.mean() 133 | 134 | log = {"{}/total_loss".format(split): loss.clone().detach().mean(), 135 | "{}/quant_loss".format(split): codebook_loss.detach().mean(), 136 | "{}/nll_loss".format(split): nll_loss.detach().mean(), 137 | "{}/rec_loss".format(split): rec_loss.detach().mean(), 138 | "{}/p_loss".format(split): p_loss.detach().mean(), 139 | "{}/d_weight".format(split): d_weight.detach(), 140 | "{}/disc_factor".format(split): torch.tensor(disc_factor), 141 | "{}/g_loss".format(split): g_loss.detach().mean(), 142 | } 143 | if predicted_indices is not None: 144 | assert self.n_classes is not None 145 | with torch.no_grad(): 146 | perplexity, cluster_usage = measure_perplexity(predicted_indices, self.n_classes) 147 | log[f"{split}/perplexity"] = perplexity 148 | log[f"{split}/cluster_usage"] = cluster_usage 149 | return loss, log 150 | 151 | if optimizer_idx == 1: 152 | # second pass for discriminator update 153 | if cond is None: 154 | logits_real = self.discriminator(inputs.contiguous().detach()) 155 | logits_fake = self.discriminator(reconstructions.contiguous().detach()) 156 | else: 157 | logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1)) 158 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1)) 159 | 160 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) 161 | d_loss = disc_factor * self.disc_loss(logits_real, logits_fake) 162 | 163 | log = {"{}/disc_loss".format(split): d_loss.clone().detach().mean(), 164 | "{}/logits_real".format(split): logits_real.detach().mean(), 165 | "{}/logits_fake".format(split): logits_fake.detach().mean() 166 | } 167 | return d_loss, log 168 | -------------------------------------------------------------------------------- /stable_diffusion/ldm/util.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | 3 | import torch 4 | import numpy as np 5 | from collections import abc 6 | from einops import rearrange 7 | from functools import partial 8 | 9 | import multiprocessing as mp 10 | from threading import Thread 11 | from queue import Queue 12 | 13 | from inspect import isfunction 14 | from PIL import Image, ImageDraw, ImageFont 15 | 16 | 17 | def log_txt_as_img(wh, xc, size=10): 18 | # wh a tuple of (width, height) 19 | # xc a list of captions to plot 20 | b = len(xc) 21 | txts = list() 22 | for bi in range(b): 23 | txt = Image.new("RGB", wh, color="white") 24 | draw = ImageDraw.Draw(txt) 25 | font = ImageFont.truetype('data/DejaVuSans.ttf', size=size) 26 | nc = int(40 * (wh[0] / 256)) 27 | lines = "\n".join(xc[bi][start:start + nc] for start in range(0, 
len(xc[bi]), nc)) 28 | 29 | try: 30 | draw.text((0, 0), lines, fill="black", font=font) 31 | except UnicodeEncodeError: 32 | print("Cant encode string for logging. Skipping.") 33 | 34 | txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0 35 | txts.append(txt) 36 | txts = np.stack(txts) 37 | txts = torch.tensor(txts) 38 | return txts 39 | 40 | 41 | def ismap(x): 42 | if not isinstance(x, torch.Tensor): 43 | return False 44 | return (len(x.shape) == 4) and (x.shape[1] > 3) 45 | 46 | 47 | def isimage(x): 48 | if not isinstance(x, torch.Tensor): 49 | return False 50 | return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1) 51 | 52 | 53 | def exists(x): 54 | return x is not None 55 | 56 | 57 | def default(val, d): 58 | if exists(val): 59 | return val 60 | return d() if isfunction(d) else d 61 | 62 | 63 | def mean_flat(tensor): 64 | """ 65 | https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/nn.py#L86 66 | Take the mean over all non-batch dimensions. 67 | """ 68 | return tensor.mean(dim=list(range(1, len(tensor.shape)))) 69 | 70 | 71 | def count_params(model, verbose=False): 72 | total_params = sum(p.numel() for p in model.parameters()) 73 | if verbose: 74 | print(f"{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params.") 75 | return total_params 76 | 77 | 78 | def instantiate_from_config(config): 79 | if not "target" in config: 80 | if config == '__is_first_stage__': 81 | return None 82 | elif config == "__is_unconditional__": 83 | return None 84 | raise KeyError("Expected key `target` to instantiate.") 85 | return get_obj_from_str(config["target"])(**config.get("params", dict())) 86 | 87 | 88 | def get_obj_from_str(string, reload=False): 89 | module, cls = string.rsplit(".", 1) 90 | if reload: 91 | module_imp = importlib.import_module(module) 92 | importlib.reload(module_imp) 93 | return getattr(importlib.import_module(module, package=None), cls) 94 | 95 | 96 | def _do_parallel_data_prefetch(func, Q, data, idx, idx_to_fn=False): 97 | # create dummy dataset instance 98 | 99 | # run prefetching 100 | if idx_to_fn: 101 | res = func(data, worker_id=idx) 102 | else: 103 | res = func(data) 104 | Q.put([idx, res]) 105 | Q.put("Done") 106 | 107 | 108 | def parallel_data_prefetch( 109 | func: callable, data, n_proc, target_data_type="ndarray", cpu_intensive=True, use_worker_id=False 110 | ): 111 | # if target_data_type not in ["ndarray", "list"]: 112 | # raise ValueError( 113 | # "Data, which is passed to parallel_data_prefetch has to be either of type list or ndarray." 114 | # ) 115 | if isinstance(data, np.ndarray) and target_data_type == "list": 116 | raise ValueError("list expected but function got ndarray.") 117 | elif isinstance(data, abc.Iterable): 118 | if isinstance(data, dict): 119 | print( 120 | f'WARNING:"data" argument passed to parallel_data_prefetch is a dict: Using only its values and disregarding keys.' 121 | ) 122 | data = list(data.values()) 123 | if target_data_type == "ndarray": 124 | data = np.asarray(data) 125 | else: 126 | data = list(data) 127 | else: 128 | raise TypeError( 129 | f"The data, that shall be processed parallel has to be either an np.ndarray or an Iterable, but is actually {type(data)}." 
130 | ) 131 | 132 | if cpu_intensive: 133 | Q = mp.Queue(1000) 134 | proc = mp.Process 135 | else: 136 | Q = Queue(1000) 137 | proc = Thread 138 | # spawn processes 139 | if target_data_type == "ndarray": 140 | arguments = [ 141 | [func, Q, part, i, use_worker_id] 142 | for i, part in enumerate(np.array_split(data, n_proc)) 143 | ] 144 | else: 145 | step = ( 146 | int(len(data) / n_proc + 1) 147 | if len(data) % n_proc != 0 148 | else int(len(data) / n_proc) 149 | ) 150 | arguments = [ 151 | [func, Q, part, i, use_worker_id] 152 | for i, part in enumerate( 153 | [data[i: i + step] for i in range(0, len(data), step)] 154 | ) 155 | ] 156 | processes = [] 157 | for i in range(n_proc): 158 | p = proc(target=_do_parallel_data_prefetch, args=arguments[i]) 159 | processes += [p] 160 | 161 | # start processes 162 | print(f"Start prefetching...") 163 | import time 164 | 165 | start = time.time() 166 | gather_res = [[] for _ in range(n_proc)] 167 | try: 168 | for p in processes: 169 | p.start() 170 | 171 | k = 0 172 | while k < n_proc: 173 | # get result 174 | res = Q.get() 175 | if res == "Done": 176 | k += 1 177 | else: 178 | gather_res[res[0]] = res[1] 179 | 180 | except Exception as e: 181 | print("Exception: ", e) 182 | for p in processes: 183 | p.terminate() 184 | 185 | raise e 186 | finally: 187 | for p in processes: 188 | p.join() 189 | print(f"Prefetching complete. [{time.time() - start} sec.]") 190 | 191 | if target_data_type == 'ndarray': 192 | if not isinstance(gather_res[0], np.ndarray): 193 | return np.concatenate([np.asarray(r) for r in gather_res], axis=0) 194 | 195 | # order outputs 196 | return np.concatenate(gather_res, axis=0) 197 | elif target_data_type == 'list': 198 | out = [] 199 | for r in gather_res: 200 | out.extend(r) 201 | return out 202 | else: 203 | return gather_res 204 | -------------------------------------------------------------------------------- /stable_diffusion/models/first_stage_models/kl-f16/config.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 4.5e-06 3 | target: ldm.models.autoencoder.AutoencoderKL 4 | params: 5 | monitor: val/rec_loss 6 | embed_dim: 16 7 | lossconfig: 8 | target: ldm.modules.losses.LPIPSWithDiscriminator 9 | params: 10 | disc_start: 50001 11 | kl_weight: 1.0e-06 12 | disc_weight: 0.5 13 | ddconfig: 14 | double_z: true 15 | z_channels: 16 16 | resolution: 256 17 | in_channels: 3 18 | out_ch: 3 19 | ch: 128 20 | ch_mult: 21 | - 1 22 | - 1 23 | - 2 24 | - 2 25 | - 4 26 | num_res_blocks: 2 27 | attn_resolutions: 28 | - 16 29 | dropout: 0.0 30 | data: 31 | target: main.DataModuleFromConfig 32 | params: 33 | batch_size: 6 34 | wrap: true 35 | train: 36 | target: ldm.data.openimages.FullOpenImagesTrain 37 | params: 38 | size: 384 39 | crop_size: 256 40 | validation: 41 | target: ldm.data.openimages.FullOpenImagesValidation 42 | params: 43 | size: 384 44 | crop_size: 256 45 | -------------------------------------------------------------------------------- /stable_diffusion/models/first_stage_models/kl-f32/config.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 4.5e-06 3 | target: ldm.models.autoencoder.AutoencoderKL 4 | params: 5 | monitor: val/rec_loss 6 | embed_dim: 64 7 | lossconfig: 8 | target: ldm.modules.losses.LPIPSWithDiscriminator 9 | params: 10 | disc_start: 50001 11 | kl_weight: 1.0e-06 12 | disc_weight: 0.5 13 | ddconfig: 14 | double_z: true 15 | z_channels: 64 16 | resolution: 256 17 | 
in_channels: 3 18 | out_ch: 3 19 | ch: 128 20 | ch_mult: 21 | - 1 22 | - 1 23 | - 2 24 | - 2 25 | - 4 26 | - 4 27 | num_res_blocks: 2 28 | attn_resolutions: 29 | - 16 30 | - 8 31 | dropout: 0.0 32 | data: 33 | target: main.DataModuleFromConfig 34 | params: 35 | batch_size: 6 36 | wrap: true 37 | train: 38 | target: ldm.data.openimages.FullOpenImagesTrain 39 | params: 40 | size: 384 41 | crop_size: 256 42 | validation: 43 | target: ldm.data.openimages.FullOpenImagesValidation 44 | params: 45 | size: 384 46 | crop_size: 256 47 | -------------------------------------------------------------------------------- /stable_diffusion/models/first_stage_models/kl-f4/config.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 4.5e-06 3 | target: ldm.models.autoencoder.AutoencoderKL 4 | params: 5 | monitor: val/rec_loss 6 | embed_dim: 3 7 | lossconfig: 8 | target: ldm.modules.losses.LPIPSWithDiscriminator 9 | params: 10 | disc_start: 50001 11 | kl_weight: 1.0e-06 12 | disc_weight: 0.5 13 | ddconfig: 14 | double_z: true 15 | z_channels: 3 16 | resolution: 256 17 | in_channels: 3 18 | out_ch: 3 19 | ch: 128 20 | ch_mult: 21 | - 1 22 | - 2 23 | - 4 24 | num_res_blocks: 2 25 | attn_resolutions: [] 26 | dropout: 0.0 27 | data: 28 | target: main.DataModuleFromConfig 29 | params: 30 | batch_size: 10 31 | wrap: true 32 | train: 33 | target: ldm.data.openimages.FullOpenImagesTrain 34 | params: 35 | size: 384 36 | crop_size: 256 37 | validation: 38 | target: ldm.data.openimages.FullOpenImagesValidation 39 | params: 40 | size: 384 41 | crop_size: 256 42 | -------------------------------------------------------------------------------- /stable_diffusion/models/first_stage_models/kl-f8/config.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 4.5e-06 3 | target: ldm.models.autoencoder.AutoencoderKL 4 | params: 5 | monitor: val/rec_loss 6 | embed_dim: 4 7 | lossconfig: 8 | target: ldm.modules.losses.LPIPSWithDiscriminator 9 | params: 10 | disc_start: 50001 11 | kl_weight: 1.0e-06 12 | disc_weight: 0.5 13 | ddconfig: 14 | double_z: true 15 | z_channels: 4 16 | resolution: 256 17 | in_channels: 3 18 | out_ch: 3 19 | ch: 128 20 | ch_mult: 21 | - 1 22 | - 2 23 | - 4 24 | - 4 25 | num_res_blocks: 2 26 | attn_resolutions: [] 27 | dropout: 0.0 28 | data: 29 | target: main.DataModuleFromConfig 30 | params: 31 | batch_size: 4 32 | wrap: true 33 | train: 34 | target: ldm.data.openimages.FullOpenImagesTrain 35 | params: 36 | size: 384 37 | crop_size: 256 38 | validation: 39 | target: ldm.data.openimages.FullOpenImagesValidation 40 | params: 41 | size: 384 42 | crop_size: 256 43 | -------------------------------------------------------------------------------- /stable_diffusion/models/first_stage_models/vq-f16/config.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 4.5e-06 3 | target: ldm.models.autoencoder.VQModel 4 | params: 5 | embed_dim: 8 6 | n_embed: 16384 7 | ddconfig: 8 | double_z: false 9 | z_channels: 8 10 | resolution: 256 11 | in_channels: 3 12 | out_ch: 3 13 | ch: 128 14 | ch_mult: 15 | - 1 16 | - 1 17 | - 2 18 | - 2 19 | - 4 20 | num_res_blocks: 2 21 | attn_resolutions: 22 | - 16 23 | dropout: 0.0 24 | lossconfig: 25 | target: taming.modules.losses.vqperceptual.VQLPIPSWithDiscriminator 26 | params: 27 | disc_conditional: false 28 | disc_in_channels: 3 29 | disc_start: 250001 30 | disc_weight: 0.75 31 | 
disc_num_layers: 2 32 | codebook_weight: 1.0 33 | 34 | data: 35 | target: main.DataModuleFromConfig 36 | params: 37 | batch_size: 14 38 | num_workers: 20 39 | wrap: true 40 | train: 41 | target: ldm.data.openimages.FullOpenImagesTrain 42 | params: 43 | size: 384 44 | crop_size: 256 45 | validation: 46 | target: ldm.data.openimages.FullOpenImagesValidation 47 | params: 48 | size: 384 49 | crop_size: 256 50 | -------------------------------------------------------------------------------- /stable_diffusion/models/first_stage_models/vq-f4-noattn/config.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 4.5e-06 3 | target: ldm.models.autoencoder.VQModel 4 | params: 5 | embed_dim: 3 6 | n_embed: 8192 7 | monitor: val/rec_loss 8 | 9 | ddconfig: 10 | attn_type: none 11 | double_z: false 12 | z_channels: 3 13 | resolution: 256 14 | in_channels: 3 15 | out_ch: 3 16 | ch: 128 17 | ch_mult: 18 | - 1 19 | - 2 20 | - 4 21 | num_res_blocks: 2 22 | attn_resolutions: [] 23 | dropout: 0.0 24 | lossconfig: 25 | target: taming.modules.losses.vqperceptual.VQLPIPSWithDiscriminator 26 | params: 27 | disc_conditional: false 28 | disc_in_channels: 3 29 | disc_start: 11 30 | disc_weight: 0.75 31 | codebook_weight: 1.0 32 | 33 | data: 34 | target: main.DataModuleFromConfig 35 | params: 36 | batch_size: 8 37 | num_workers: 12 38 | wrap: true 39 | train: 40 | target: ldm.data.openimages.FullOpenImagesTrain 41 | params: 42 | crop_size: 256 43 | validation: 44 | target: ldm.data.openimages.FullOpenImagesValidation 45 | params: 46 | crop_size: 256 47 | -------------------------------------------------------------------------------- /stable_diffusion/models/first_stage_models/vq-f4/config.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 4.5e-06 3 | target: ldm.models.autoencoder.VQModel 4 | params: 5 | embed_dim: 3 6 | n_embed: 8192 7 | monitor: val/rec_loss 8 | 9 | ddconfig: 10 | double_z: false 11 | z_channels: 3 12 | resolution: 256 13 | in_channels: 3 14 | out_ch: 3 15 | ch: 128 16 | ch_mult: 17 | - 1 18 | - 2 19 | - 4 20 | num_res_blocks: 2 21 | attn_resolutions: [] 22 | dropout: 0.0 23 | lossconfig: 24 | target: taming.modules.losses.vqperceptual.VQLPIPSWithDiscriminator 25 | params: 26 | disc_conditional: false 27 | disc_in_channels: 3 28 | disc_start: 0 29 | disc_weight: 0.75 30 | codebook_weight: 1.0 31 | 32 | data: 33 | target: main.DataModuleFromConfig 34 | params: 35 | batch_size: 8 36 | num_workers: 16 37 | wrap: true 38 | train: 39 | target: ldm.data.openimages.FullOpenImagesTrain 40 | params: 41 | crop_size: 256 42 | validation: 43 | target: ldm.data.openimages.FullOpenImagesValidation 44 | params: 45 | crop_size: 256 46 | -------------------------------------------------------------------------------- /stable_diffusion/models/first_stage_models/vq-f8-n256/config.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 4.5e-06 3 | target: ldm.models.autoencoder.VQModel 4 | params: 5 | embed_dim: 4 6 | n_embed: 256 7 | monitor: val/rec_loss 8 | ddconfig: 9 | double_z: false 10 | z_channels: 4 11 | resolution: 256 12 | in_channels: 3 13 | out_ch: 3 14 | ch: 128 15 | ch_mult: 16 | - 1 17 | - 2 18 | - 2 19 | - 4 20 | num_res_blocks: 2 21 | attn_resolutions: 22 | - 32 23 | dropout: 0.0 24 | lossconfig: 25 | target: taming.modules.losses.vqperceptual.VQLPIPSWithDiscriminator 26 | params: 27 | 
disc_conditional: false 28 | disc_in_channels: 3 29 | disc_start: 250001 30 | disc_weight: 0.75 31 | codebook_weight: 1.0 32 | 33 | data: 34 | target: main.DataModuleFromConfig 35 | params: 36 | batch_size: 10 37 | num_workers: 20 38 | wrap: true 39 | train: 40 | target: ldm.data.openimages.FullOpenImagesTrain 41 | params: 42 | size: 384 43 | crop_size: 256 44 | validation: 45 | target: ldm.data.openimages.FullOpenImagesValidation 46 | params: 47 | size: 384 48 | crop_size: 256 49 | -------------------------------------------------------------------------------- /stable_diffusion/models/first_stage_models/vq-f8/config.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 4.5e-06 3 | target: ldm.models.autoencoder.VQModel 4 | params: 5 | embed_dim: 4 6 | n_embed: 16384 7 | monitor: val/rec_loss 8 | ddconfig: 9 | double_z: false 10 | z_channels: 4 11 | resolution: 256 12 | in_channels: 3 13 | out_ch: 3 14 | ch: 128 15 | ch_mult: 16 | - 1 17 | - 2 18 | - 2 19 | - 4 20 | num_res_blocks: 2 21 | attn_resolutions: 22 | - 32 23 | dropout: 0.0 24 | lossconfig: 25 | target: taming.modules.losses.vqperceptual.VQLPIPSWithDiscriminator 26 | params: 27 | disc_conditional: false 28 | disc_in_channels: 3 29 | disc_num_layers: 2 30 | disc_start: 1 31 | disc_weight: 0.6 32 | codebook_weight: 1.0 33 | data: 34 | target: main.DataModuleFromConfig 35 | params: 36 | batch_size: 10 37 | num_workers: 20 38 | wrap: true 39 | train: 40 | target: ldm.data.openimages.FullOpenImagesTrain 41 | params: 42 | size: 384 43 | crop_size: 256 44 | validation: 45 | target: ldm.data.openimages.FullOpenImagesValidation 46 | params: 47 | size: 384 48 | crop_size: 256 49 | -------------------------------------------------------------------------------- /stable_diffusion/models/ldm/bsr_sr/config.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 1.0e-06 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.0015 6 | linear_end: 0.0155 7 | log_every_t: 100 8 | timesteps: 1000 9 | loss_type: l2 10 | first_stage_key: image 11 | cond_stage_key: LR_image 12 | image_size: 64 13 | channels: 3 14 | concat_mode: true 15 | cond_stage_trainable: false 16 | unet_config: 17 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 18 | params: 19 | image_size: 64 20 | in_channels: 6 21 | out_channels: 3 22 | model_channels: 160 23 | attention_resolutions: 24 | - 16 25 | - 8 26 | num_res_blocks: 2 27 | channel_mult: 28 | - 1 29 | - 2 30 | - 2 31 | - 4 32 | num_head_channels: 32 33 | first_stage_config: 34 | target: ldm.models.autoencoder.VQModelInterface 35 | params: 36 | embed_dim: 3 37 | n_embed: 8192 38 | monitor: val/rec_loss 39 | ddconfig: 40 | double_z: false 41 | z_channels: 3 42 | resolution: 256 43 | in_channels: 3 44 | out_ch: 3 45 | ch: 128 46 | ch_mult: 47 | - 1 48 | - 2 49 | - 4 50 | num_res_blocks: 2 51 | attn_resolutions: [] 52 | dropout: 0.0 53 | lossconfig: 54 | target: torch.nn.Identity 55 | cond_stage_config: 56 | target: torch.nn.Identity 57 | data: 58 | target: main.DataModuleFromConfig 59 | params: 60 | batch_size: 64 61 | wrap: false 62 | num_workers: 12 63 | train: 64 | target: ldm.data.openimages.SuperresOpenImagesAdvancedTrain 65 | params: 66 | size: 256 67 | degradation: bsrgan_light 68 | downscale_f: 4 69 | min_crop_f: 0.5 70 | max_crop_f: 1.0 71 | random_crop: true 72 | validation: 73 | target: 
ldm.data.openimages.SuperresOpenImagesAdvancedValidation 74 | params: 75 | size: 256 76 | degradation: bsrgan_light 77 | downscale_f: 4 78 | min_crop_f: 0.5 79 | max_crop_f: 1.0 80 | random_crop: true 81 | -------------------------------------------------------------------------------- /stable_diffusion/models/ldm/celeba256/config.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 2.0e-06 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.0015 6 | linear_end: 0.0195 7 | num_timesteps_cond: 1 8 | log_every_t: 200 9 | timesteps: 1000 10 | first_stage_key: image 11 | cond_stage_key: class_label 12 | image_size: 64 13 | channels: 3 14 | cond_stage_trainable: false 15 | concat_mode: false 16 | monitor: val/loss 17 | unet_config: 18 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 19 | params: 20 | image_size: 64 21 | in_channels: 3 22 | out_channels: 3 23 | model_channels: 224 24 | attention_resolutions: 25 | - 8 26 | - 4 27 | - 2 28 | num_res_blocks: 2 29 | channel_mult: 30 | - 1 31 | - 2 32 | - 3 33 | - 4 34 | num_head_channels: 32 35 | first_stage_config: 36 | target: ldm.models.autoencoder.VQModelInterface 37 | params: 38 | embed_dim: 3 39 | n_embed: 8192 40 | ddconfig: 41 | double_z: false 42 | z_channels: 3 43 | resolution: 256 44 | in_channels: 3 45 | out_ch: 3 46 | ch: 128 47 | ch_mult: 48 | - 1 49 | - 2 50 | - 4 51 | num_res_blocks: 2 52 | attn_resolutions: [] 53 | dropout: 0.0 54 | lossconfig: 55 | target: torch.nn.Identity 56 | cond_stage_config: __is_unconditional__ 57 | data: 58 | target: main.DataModuleFromConfig 59 | params: 60 | batch_size: 48 61 | num_workers: 5 62 | wrap: false 63 | train: 64 | target: ldm.data.faceshq.CelebAHQTrain 65 | params: 66 | size: 256 67 | validation: 68 | target: ldm.data.faceshq.CelebAHQValidation 69 | params: 70 | size: 256 71 | -------------------------------------------------------------------------------- /stable_diffusion/models/ldm/cin256/config.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 1.0e-06 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.0015 6 | linear_end: 0.0195 7 | num_timesteps_cond: 1 8 | log_every_t: 200 9 | timesteps: 1000 10 | first_stage_key: image 11 | cond_stage_key: class_label 12 | image_size: 32 13 | channels: 4 14 | cond_stage_trainable: true 15 | conditioning_key: crossattn 16 | monitor: val/loss_simple_ema 17 | unet_config: 18 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 19 | params: 20 | image_size: 32 21 | in_channels: 4 22 | out_channels: 4 23 | model_channels: 256 24 | attention_resolutions: 25 | - 4 26 | - 2 27 | - 1 28 | num_res_blocks: 2 29 | channel_mult: 30 | - 1 31 | - 2 32 | - 4 33 | num_head_channels: 32 34 | use_spatial_transformer: true 35 | transformer_depth: 1 36 | context_dim: 512 37 | first_stage_config: 38 | target: ldm.models.autoencoder.VQModelInterface 39 | params: 40 | embed_dim: 4 41 | n_embed: 16384 42 | ddconfig: 43 | double_z: false 44 | z_channels: 4 45 | resolution: 256 46 | in_channels: 3 47 | out_ch: 3 48 | ch: 128 49 | ch_mult: 50 | - 1 51 | - 2 52 | - 2 53 | - 4 54 | num_res_blocks: 2 55 | attn_resolutions: 56 | - 32 57 | dropout: 0.0 58 | lossconfig: 59 | target: torch.nn.Identity 60 | cond_stage_config: 61 | target: ldm.modules.encoders.modules.ClassEmbedder 62 | params: 63 | embed_dim: 512 64 | key: class_label 65 | data: 66 | target: 
main.DataModuleFromConfig 67 | params: 68 | batch_size: 64 69 | num_workers: 12 70 | wrap: false 71 | train: 72 | target: ldm.data.imagenet.ImageNetTrain 73 | params: 74 | config: 75 | size: 256 76 | validation: 77 | target: ldm.data.imagenet.ImageNetValidation 78 | params: 79 | config: 80 | size: 256 81 | -------------------------------------------------------------------------------- /stable_diffusion/models/ldm/ffhq256/config.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 2.0e-06 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.0015 6 | linear_end: 0.0195 7 | num_timesteps_cond: 1 8 | log_every_t: 200 9 | timesteps: 1000 10 | first_stage_key: image 11 | cond_stage_key: class_label 12 | image_size: 64 13 | channels: 3 14 | cond_stage_trainable: false 15 | concat_mode: false 16 | monitor: val/loss 17 | unet_config: 18 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 19 | params: 20 | image_size: 64 21 | in_channels: 3 22 | out_channels: 3 23 | model_channels: 224 24 | attention_resolutions: 25 | - 8 26 | - 4 27 | - 2 28 | num_res_blocks: 2 29 | channel_mult: 30 | - 1 31 | - 2 32 | - 3 33 | - 4 34 | num_head_channels: 32 35 | first_stage_config: 36 | target: ldm.models.autoencoder.VQModelInterface 37 | params: 38 | embed_dim: 3 39 | n_embed: 8192 40 | ddconfig: 41 | double_z: false 42 | z_channels: 3 43 | resolution: 256 44 | in_channels: 3 45 | out_ch: 3 46 | ch: 128 47 | ch_mult: 48 | - 1 49 | - 2 50 | - 4 51 | num_res_blocks: 2 52 | attn_resolutions: [] 53 | dropout: 0.0 54 | lossconfig: 55 | target: torch.nn.Identity 56 | cond_stage_config: __is_unconditional__ 57 | data: 58 | target: main.DataModuleFromConfig 59 | params: 60 | batch_size: 42 61 | num_workers: 5 62 | wrap: false 63 | train: 64 | target: ldm.data.faceshq.FFHQTrain 65 | params: 66 | size: 256 67 | validation: 68 | target: ldm.data.faceshq.FFHQValidation 69 | params: 70 | size: 256 71 | -------------------------------------------------------------------------------- /stable_diffusion/models/ldm/inpainting_big/config.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 1.0e-06 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.0015 6 | linear_end: 0.0205 7 | log_every_t: 100 8 | timesteps: 1000 9 | loss_type: l1 10 | first_stage_key: image 11 | cond_stage_key: masked_image 12 | image_size: 64 13 | channels: 3 14 | concat_mode: true 15 | monitor: val/loss 16 | scheduler_config: 17 | target: ldm.lr_scheduler.LambdaWarmUpCosineScheduler 18 | params: 19 | verbosity_interval: 0 20 | warm_up_steps: 1000 21 | max_decay_steps: 50000 22 | lr_start: 0.001 23 | lr_max: 0.1 24 | lr_min: 0.0001 25 | unet_config: 26 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 27 | params: 28 | image_size: 64 29 | in_channels: 7 30 | out_channels: 3 31 | model_channels: 256 32 | attention_resolutions: 33 | - 8 34 | - 4 35 | - 2 36 | num_res_blocks: 2 37 | channel_mult: 38 | - 1 39 | - 2 40 | - 3 41 | - 4 42 | num_heads: 8 43 | resblock_updown: true 44 | first_stage_config: 45 | target: ldm.models.autoencoder.VQModelInterface 46 | params: 47 | embed_dim: 3 48 | n_embed: 8192 49 | monitor: val/rec_loss 50 | ddconfig: 51 | attn_type: none 52 | double_z: false 53 | z_channels: 3 54 | resolution: 256 55 | in_channels: 3 56 | out_ch: 3 57 | ch: 128 58 | ch_mult: 59 | - 1 60 | - 2 61 | - 4 62 | num_res_blocks: 2 63 | 
attn_resolutions: [] 64 | dropout: 0.0 65 | lossconfig: 66 | target: ldm.modules.losses.contperceptual.DummyLoss 67 | cond_stage_config: __is_first_stage__ 68 | -------------------------------------------------------------------------------- /stable_diffusion/models/ldm/layout2img-openimages256/config.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 2.0e-06 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.0015 6 | linear_end: 0.0205 7 | log_every_t: 100 8 | timesteps: 1000 9 | loss_type: l1 10 | first_stage_key: image 11 | cond_stage_key: coordinates_bbox 12 | image_size: 64 13 | channels: 3 14 | conditioning_key: crossattn 15 | cond_stage_trainable: true 16 | unet_config: 17 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 18 | params: 19 | image_size: 64 20 | in_channels: 3 21 | out_channels: 3 22 | model_channels: 128 23 | attention_resolutions: 24 | - 8 25 | - 4 26 | - 2 27 | num_res_blocks: 2 28 | channel_mult: 29 | - 1 30 | - 2 31 | - 3 32 | - 4 33 | num_head_channels: 32 34 | use_spatial_transformer: true 35 | transformer_depth: 3 36 | context_dim: 512 37 | first_stage_config: 38 | target: ldm.models.autoencoder.VQModelInterface 39 | params: 40 | embed_dim: 3 41 | n_embed: 8192 42 | monitor: val/rec_loss 43 | ddconfig: 44 | double_z: false 45 | z_channels: 3 46 | resolution: 256 47 | in_channels: 3 48 | out_ch: 3 49 | ch: 128 50 | ch_mult: 51 | - 1 52 | - 2 53 | - 4 54 | num_res_blocks: 2 55 | attn_resolutions: [] 56 | dropout: 0.0 57 | lossconfig: 58 | target: torch.nn.Identity 59 | cond_stage_config: 60 | target: ldm.modules.encoders.modules.BERTEmbedder 61 | params: 62 | n_embed: 512 63 | n_layer: 16 64 | vocab_size: 8192 65 | max_seq_len: 92 66 | use_tokenizer: false 67 | monitor: val/loss_simple_ema 68 | data: 69 | target: main.DataModuleFromConfig 70 | params: 71 | batch_size: 24 72 | wrap: false 73 | num_workers: 10 74 | train: 75 | target: ldm.data.openimages.OpenImagesBBoxTrain 76 | params: 77 | size: 256 78 | validation: 79 | target: ldm.data.openimages.OpenImagesBBoxValidation 80 | params: 81 | size: 256 82 | -------------------------------------------------------------------------------- /stable_diffusion/models/ldm/lsun_beds256/config.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 2.0e-06 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.0015 6 | linear_end: 0.0195 7 | num_timesteps_cond: 1 8 | log_every_t: 200 9 | timesteps: 1000 10 | first_stage_key: image 11 | cond_stage_key: class_label 12 | image_size: 64 13 | channels: 3 14 | cond_stage_trainable: false 15 | concat_mode: false 16 | monitor: val/loss 17 | unet_config: 18 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 19 | params: 20 | image_size: 64 21 | in_channels: 3 22 | out_channels: 3 23 | model_channels: 224 24 | attention_resolutions: 25 | - 8 26 | - 4 27 | - 2 28 | num_res_blocks: 2 29 | channel_mult: 30 | - 1 31 | - 2 32 | - 3 33 | - 4 34 | num_head_channels: 32 35 | first_stage_config: 36 | target: ldm.models.autoencoder.VQModelInterface 37 | params: 38 | embed_dim: 3 39 | n_embed: 8192 40 | ddconfig: 41 | double_z: false 42 | z_channels: 3 43 | resolution: 256 44 | in_channels: 3 45 | out_ch: 3 46 | ch: 128 47 | ch_mult: 48 | - 1 49 | - 2 50 | - 4 51 | num_res_blocks: 2 52 | attn_resolutions: [] 53 | dropout: 0.0 54 | lossconfig: 55 | target: torch.nn.Identity 
56 | cond_stage_config: __is_unconditional__ 57 | data: 58 | target: main.DataModuleFromConfig 59 | params: 60 | batch_size: 48 61 | num_workers: 5 62 | wrap: false 63 | train: 64 | target: ldm.data.lsun.LSUNBedroomsTrain 65 | params: 66 | size: 256 67 | validation: 68 | target: ldm.data.lsun.LSUNBedroomsValidation 69 | params: 70 | size: 256 71 | -------------------------------------------------------------------------------- /stable_diffusion/models/ldm/lsun_churches256/config.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 5.0e-05 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.0015 6 | linear_end: 0.0155 7 | num_timesteps_cond: 1 8 | log_every_t: 200 9 | timesteps: 1000 10 | loss_type: l1 11 | first_stage_key: image 12 | cond_stage_key: image 13 | image_size: 32 14 | channels: 4 15 | cond_stage_trainable: false 16 | concat_mode: false 17 | scale_by_std: true 18 | monitor: val/loss_simple_ema 19 | scheduler_config: 20 | target: ldm.lr_scheduler.LambdaLinearScheduler 21 | params: 22 | warm_up_steps: 23 | - 10000 24 | cycle_lengths: 25 | - 10000000000000 26 | f_start: 27 | - 1.0e-06 28 | f_max: 29 | - 1.0 30 | f_min: 31 | - 1.0 32 | unet_config: 33 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 34 | params: 35 | image_size: 32 36 | in_channels: 4 37 | out_channels: 4 38 | model_channels: 192 39 | attention_resolutions: 40 | - 1 41 | - 2 42 | - 4 43 | - 8 44 | num_res_blocks: 2 45 | channel_mult: 46 | - 1 47 | - 2 48 | - 2 49 | - 4 50 | - 4 51 | num_heads: 8 52 | use_scale_shift_norm: true 53 | resblock_updown: true 54 | first_stage_config: 55 | target: ldm.models.autoencoder.AutoencoderKL 56 | params: 57 | embed_dim: 4 58 | monitor: val/rec_loss 59 | ddconfig: 60 | double_z: true 61 | z_channels: 4 62 | resolution: 256 63 | in_channels: 3 64 | out_ch: 3 65 | ch: 128 66 | ch_mult: 67 | - 1 68 | - 2 69 | - 4 70 | - 4 71 | num_res_blocks: 2 72 | attn_resolutions: [] 73 | dropout: 0.0 74 | lossconfig: 75 | target: torch.nn.Identity 76 | 77 | cond_stage_config: '__is_unconditional__' 78 | 79 | data: 80 | target: main.DataModuleFromConfig 81 | params: 82 | batch_size: 96 83 | num_workers: 5 84 | wrap: false 85 | train: 86 | target: ldm.data.lsun.LSUNChurchesTrain 87 | params: 88 | size: 256 89 | validation: 90 | target: ldm.data.lsun.LSUNChurchesValidation 91 | params: 92 | size: 256 93 | -------------------------------------------------------------------------------- /stable_diffusion/models/ldm/semantic_synthesis256/config.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 1.0e-06 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.0015 6 | linear_end: 0.0205 7 | log_every_t: 100 8 | timesteps: 1000 9 | loss_type: l1 10 | first_stage_key: image 11 | cond_stage_key: segmentation 12 | image_size: 64 13 | channels: 3 14 | concat_mode: true 15 | cond_stage_trainable: true 16 | unet_config: 17 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 18 | params: 19 | image_size: 64 20 | in_channels: 6 21 | out_channels: 3 22 | model_channels: 128 23 | attention_resolutions: 24 | - 32 25 | - 16 26 | - 8 27 | num_res_blocks: 2 28 | channel_mult: 29 | - 1 30 | - 4 31 | - 8 32 | num_heads: 8 33 | first_stage_config: 34 | target: ldm.models.autoencoder.VQModelInterface 35 | params: 36 | embed_dim: 3 37 | n_embed: 8192 38 | ddconfig: 39 | double_z: false 40 | z_channels: 3 
41 | resolution: 256 42 | in_channels: 3 43 | out_ch: 3 44 | ch: 128 45 | ch_mult: 46 | - 1 47 | - 2 48 | - 4 49 | num_res_blocks: 2 50 | attn_resolutions: [] 51 | dropout: 0.0 52 | lossconfig: 53 | target: torch.nn.Identity 54 | cond_stage_config: 55 | target: ldm.modules.encoders.modules.SpatialRescaler 56 | params: 57 | n_stages: 2 58 | in_channels: 182 59 | out_channels: 3 60 | -------------------------------------------------------------------------------- /stable_diffusion/models/ldm/semantic_synthesis512/config.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 1.0e-06 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.0015 6 | linear_end: 0.0205 7 | log_every_t: 100 8 | timesteps: 1000 9 | loss_type: l1 10 | first_stage_key: image 11 | cond_stage_key: segmentation 12 | image_size: 128 13 | channels: 3 14 | concat_mode: true 15 | cond_stage_trainable: true 16 | unet_config: 17 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 18 | params: 19 | image_size: 128 20 | in_channels: 6 21 | out_channels: 3 22 | model_channels: 128 23 | attention_resolutions: 24 | - 32 25 | - 16 26 | - 8 27 | num_res_blocks: 2 28 | channel_mult: 29 | - 1 30 | - 4 31 | - 8 32 | num_heads: 8 33 | first_stage_config: 34 | target: ldm.models.autoencoder.VQModelInterface 35 | params: 36 | embed_dim: 3 37 | n_embed: 8192 38 | monitor: val/rec_loss 39 | ddconfig: 40 | double_z: false 41 | z_channels: 3 42 | resolution: 256 43 | in_channels: 3 44 | out_ch: 3 45 | ch: 128 46 | ch_mult: 47 | - 1 48 | - 2 49 | - 4 50 | num_res_blocks: 2 51 | attn_resolutions: [] 52 | dropout: 0.0 53 | lossconfig: 54 | target: torch.nn.Identity 55 | cond_stage_config: 56 | target: ldm.modules.encoders.modules.SpatialRescaler 57 | params: 58 | n_stages: 2 59 | in_channels: 182 60 | out_channels: 3 61 | data: 62 | target: main.DataModuleFromConfig 63 | params: 64 | batch_size: 8 65 | wrap: false 66 | num_workers: 10 67 | train: 68 | target: ldm.data.landscapes.RFWTrain 69 | params: 70 | size: 768 71 | crop_size: 512 72 | segmentation_to_float32: true 73 | validation: 74 | target: ldm.data.landscapes.RFWValidation 75 | params: 76 | size: 768 77 | crop_size: 512 78 | segmentation_to_float32: true 79 | -------------------------------------------------------------------------------- /stable_diffusion/models/ldm/text2img256/config.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 2.0e-06 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.0015 6 | linear_end: 0.0195 7 | num_timesteps_cond: 1 8 | log_every_t: 200 9 | timesteps: 1000 10 | first_stage_key: image 11 | cond_stage_key: caption 12 | image_size: 64 13 | channels: 3 14 | cond_stage_trainable: true 15 | conditioning_key: crossattn 16 | monitor: val/loss_simple_ema 17 | unet_config: 18 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 19 | params: 20 | image_size: 64 21 | in_channels: 3 22 | out_channels: 3 23 | model_channels: 192 24 | attention_resolutions: 25 | - 8 26 | - 4 27 | - 2 28 | num_res_blocks: 2 29 | channel_mult: 30 | - 1 31 | - 2 32 | - 3 33 | - 5 34 | num_head_channels: 32 35 | use_spatial_transformer: true 36 | transformer_depth: 1 37 | context_dim: 640 38 | first_stage_config: 39 | target: ldm.models.autoencoder.VQModelInterface 40 | params: 41 | embed_dim: 3 42 | n_embed: 8192 43 | ddconfig: 44 | double_z: false 45 | z_channels: 3 46 | 
resolution: 256 47 | in_channels: 3 48 | out_ch: 3 49 | ch: 128 50 | ch_mult: 51 | - 1 52 | - 2 53 | - 4 54 | num_res_blocks: 2 55 | attn_resolutions: [] 56 | dropout: 0.0 57 | lossconfig: 58 | target: torch.nn.Identity 59 | cond_stage_config: 60 | target: ldm.modules.encoders.modules.BERTEmbedder 61 | params: 62 | n_embed: 640 63 | n_layer: 32 64 | data: 65 | target: main.DataModuleFromConfig 66 | params: 67 | batch_size: 28 68 | num_workers: 10 69 | wrap: false 70 | train: 71 | target: ldm.data.previews.pytorch_dataset.PreviewsTrain 72 | params: 73 | size: 256 74 | validation: 75 | target: ldm.data.previews.pytorch_dataset.PreviewsValidation 76 | params: 77 | size: 256 78 | -------------------------------------------------------------------------------- /stable_diffusion/scripts/download_first_stages.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | wget -O models/first_stage_models/kl-f4/model.zip https://ommer-lab.com/files/latent-diffusion/kl-f4.zip 3 | wget -O models/first_stage_models/kl-f8/model.zip https://ommer-lab.com/files/latent-diffusion/kl-f8.zip 4 | wget -O models/first_stage_models/kl-f16/model.zip https://ommer-lab.com/files/latent-diffusion/kl-f16.zip 5 | wget -O models/first_stage_models/kl-f32/model.zip https://ommer-lab.com/files/latent-diffusion/kl-f32.zip 6 | wget -O models/first_stage_models/vq-f4/model.zip https://ommer-lab.com/files/latent-diffusion/vq-f4.zip 7 | wget -O models/first_stage_models/vq-f4-noattn/model.zip https://ommer-lab.com/files/latent-diffusion/vq-f4-noattn.zip 8 | wget -O models/first_stage_models/vq-f8/model.zip https://ommer-lab.com/files/latent-diffusion/vq-f8.zip 9 | wget -O models/first_stage_models/vq-f8-n256/model.zip https://ommer-lab.com/files/latent-diffusion/vq-f8-n256.zip 10 | wget -O models/first_stage_models/vq-f16/model.zip https://ommer-lab.com/files/latent-diffusion/vq-f16.zip 11 | 12 | 13 | 14 | cd models/first_stage_models/kl-f4 15 | unzip -o model.zip 16 | 17 | cd ../kl-f8 18 | unzip -o model.zip 19 | 20 | cd ../kl-f16 21 | unzip -o model.zip 22 | 23 | cd ../kl-f32 24 | unzip -o model.zip 25 | 26 | cd ../vq-f4 27 | unzip -o model.zip 28 | 29 | cd ../vq-f4-noattn 30 | unzip -o model.zip 31 | 32 | cd ../vq-f8 33 | unzip -o model.zip 34 | 35 | cd ../vq-f8-n256 36 | unzip -o model.zip 37 | 38 | cd ../vq-f16 39 | unzip -o model.zip 40 | 41 | cd ../.. 
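Editorial note (not a file in the repository): the first-stage configs above are consumed by loading the YAML with OmegaConf and building the model via ldm.util.instantiate_from_config (shown earlier in this section); a minimal sketch follows, assuming the zip fetched by the script above unpacks to a model.ckpt next to each config.yaml — that checkpoint filename is an assumption here.

from omegaconf import OmegaConf
import torch
from ldm.util import instantiate_from_config

# load the kl-f8 autoencoder config shown above and build the model it targets
config = OmegaConf.load("models/first_stage_models/kl-f8/config.yaml")
model = instantiate_from_config(config.model)  # -> ldm.models.autoencoder.AutoencoderKL

# "model.ckpt" is assumed to be what model.zip unpacks to (see the download script above)
state = torch.load("models/first_stage_models/kl-f8/model.ckpt", map_location="cpu")["state_dict"]
model.load_state_dict(state, strict=False)
model.eval()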
-------------------------------------------------------------------------------- /stable_diffusion/scripts/download_models.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | wget -O models/ldm/celeba256/celeba-256.zip https://ommer-lab.com/files/latent-diffusion/celeba.zip 3 | wget -O models/ldm/ffhq256/ffhq-256.zip https://ommer-lab.com/files/latent-diffusion/ffhq.zip 4 | wget -O models/ldm/lsun_churches256/lsun_churches-256.zip https://ommer-lab.com/files/latent-diffusion/lsun_churches.zip 5 | wget -O models/ldm/lsun_beds256/lsun_beds-256.zip https://ommer-lab.com/files/latent-diffusion/lsun_bedrooms.zip 6 | wget -O models/ldm/text2img256/model.zip https://ommer-lab.com/files/latent-diffusion/text2img.zip 7 | wget -O models/ldm/cin256/model.zip https://ommer-lab.com/files/latent-diffusion/cin.zip 8 | wget -O models/ldm/semantic_synthesis512/model.zip https://ommer-lab.com/files/latent-diffusion/semantic_synthesis.zip 9 | wget -O models/ldm/semantic_synthesis256/model.zip https://ommer-lab.com/files/latent-diffusion/semantic_synthesis256.zip 10 | wget -O models/ldm/bsr_sr/model.zip https://ommer-lab.com/files/latent-diffusion/sr_bsr.zip 11 | wget -O models/ldm/layout2img-openimages256/model.zip https://ommer-lab.com/files/latent-diffusion/layout2img_model.zip 12 | wget -O models/ldm/inpainting_big/model.zip https://ommer-lab.com/files/latent-diffusion/inpainting_big.zip 13 | 14 | 15 | 16 | cd models/ldm/celeba256 17 | unzip -o celeba-256.zip 18 | 19 | cd ../ffhq256 20 | unzip -o ffhq-256.zip 21 | 22 | cd ../lsun_churches256 23 | unzip -o lsun_churches-256.zip 24 | 25 | cd ../lsun_beds256 26 | unzip -o lsun_beds-256.zip 27 | 28 | cd ../text2img256 29 | unzip -o model.zip 30 | 31 | cd ../cin256 32 | unzip -o model.zip 33 | 34 | cd ../semantic_synthesis512 35 | unzip -o model.zip 36 | 37 | cd ../semantic_synthesis256 38 | unzip -o model.zip 39 | 40 | cd ../bsr_sr 41 | unzip -o model.zip 42 | 43 | cd ../layout2img-openimages256 44 | unzip -o model.zip 45 | 46 | cd ../inpainting_big 47 | unzip -o model.zip 48 | 49 | cd ../.. 
50 | -------------------------------------------------------------------------------- /stable_diffusion/scripts/inpaint.py: -------------------------------------------------------------------------------- 1 | import argparse, os, sys, glob 2 | from omegaconf import OmegaConf 3 | from PIL import Image 4 | from tqdm import tqdm 5 | import numpy as np 6 | import torch 7 | from main import instantiate_from_config 8 | from ldm.models.diffusion.ddim import DDIMSampler 9 | 10 | 11 | def make_batch(image, mask, device): 12 | image = np.array(Image.open(image).convert("RGB")) 13 | image = image.astype(np.float32)/255.0 14 | image = image[None].transpose(0,3,1,2) 15 | image = torch.from_numpy(image) 16 | 17 | mask = np.array(Image.open(mask).convert("L")) 18 | mask = mask.astype(np.float32)/255.0 19 | mask = mask[None,None] 20 | mask[mask < 0.5] = 0 21 | mask[mask >= 0.5] = 1 22 | mask = torch.from_numpy(mask) 23 | 24 | masked_image = (1-mask)*image 25 | 26 | batch = {"image": image, "mask": mask, "masked_image": masked_image} 27 | for k in batch: 28 | batch[k] = batch[k].to(device=device) 29 | batch[k] = batch[k]*2.0-1.0 30 | return batch 31 | 32 | 33 | if __name__ == "__main__": 34 | parser = argparse.ArgumentParser() 35 | parser.add_argument( 36 | "--indir", 37 | type=str, 38 | nargs="?", 39 | help="dir containing image-mask pairs (`example.png` and `example_mask.png`)", 40 | ) 41 | parser.add_argument( 42 | "--outdir", 43 | type=str, 44 | nargs="?", 45 | help="dir to write results to", 46 | ) 47 | parser.add_argument( 48 | "--steps", 49 | type=int, 50 | default=50, 51 | help="number of ddim sampling steps", 52 | ) 53 | opt = parser.parse_args() 54 | 55 | masks = sorted(glob.glob(os.path.join(opt.indir, "*_mask.png"))) 56 | images = [x.replace("_mask.png", ".png") for x in masks] 57 | print(f"Found {len(masks)} inputs.") 58 | 59 | config = OmegaConf.load("models/ldm/inpainting_big/config.yaml") 60 | model = instantiate_from_config(config.model) 61 | model.load_state_dict(torch.load("models/ldm/inpainting_big/last.ckpt")["state_dict"], 62 | strict=False) 63 | 64 | device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") 65 | model = model.to(device) 66 | sampler = DDIMSampler(model) 67 | 68 | os.makedirs(opt.outdir, exist_ok=True) 69 | with torch.no_grad(): 70 | with model.ema_scope(): 71 | for image, mask in tqdm(zip(images, masks)): 72 | outpath = os.path.join(opt.outdir, os.path.split(image)[1]) 73 | batch = make_batch(image, mask, device=device) 74 | 75 | # encode masked image and concat downsampled mask 76 | c = model.cond_stage_model.encode(batch["masked_image"]) 77 | cc = torch.nn.functional.interpolate(batch["mask"], 78 | size=c.shape[-2:]) 79 | c = torch.cat((c, cc), dim=1) 80 | 81 | shape = (c.shape[1]-1,)+c.shape[2:] 82 | samples_ddim, _ = sampler.sample(S=opt.steps, 83 | conditioning=c, 84 | batch_size=c.shape[0], 85 | shape=shape, 86 | verbose=False) 87 | x_samples_ddim = model.decode_first_stage(samples_ddim) 88 | 89 | image = torch.clamp((batch["image"]+1.0)/2.0, 90 | min=0.0, max=1.0) 91 | mask = torch.clamp((batch["mask"]+1.0)/2.0, 92 | min=0.0, max=1.0) 93 | predicted_image = torch.clamp((x_samples_ddim+1.0)/2.0, 94 | min=0.0, max=1.0) 95 | 96 | inpainted = (1-mask)*image+mask*predicted_image 97 | inpainted = inpainted.cpu().numpy().transpose(0,2,3,1)[0]*255 98 | Image.fromarray(inpainted.astype(np.uint8)).save(outpath) 99 | -------------------------------------------------------------------------------- 
/stable_diffusion/scripts/latent_imagenet_diffusion.ipynb.REMOVED.git-id: -------------------------------------------------------------------------------- 1 | 607f94fc7d3ef6d8d1627017215476d9dfc7ddc4 -------------------------------------------------------------------------------- /stable_diffusion/scripts/tests/test_watermark.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import fire 3 | from imwatermark import WatermarkDecoder 4 | 5 | 6 | def testit(img_path): 7 | bgr = cv2.imread(img_path) 8 | decoder = WatermarkDecoder('bytes', 136) 9 | watermark = decoder.decode(bgr, 'dwtDct') 10 | try: 11 | dec = watermark.decode('utf-8') 12 | except: 13 | dec = "null" 14 | print(dec) 15 | 16 | 17 | if __name__ == "__main__": 18 | fire.Fire(testit) -------------------------------------------------------------------------------- /stable_diffusion/scripts/train_searcher.py: -------------------------------------------------------------------------------- 1 | import os, sys 2 | import numpy as np 3 | import scann 4 | import argparse 5 | import glob 6 | from multiprocessing import cpu_count 7 | from tqdm import tqdm 8 | 9 | from ldm.util import parallel_data_prefetch 10 | 11 | 12 | def search_bruteforce(searcher): 13 | return searcher.score_brute_force().build() 14 | 15 | 16 | def search_partioned_ah(searcher, dims_per_block, aiq_threshold, reorder_k, 17 | partioning_trainsize, num_leaves, num_leaves_to_search): 18 | return searcher.tree(num_leaves=num_leaves, 19 | num_leaves_to_search=num_leaves_to_search, 20 | training_sample_size=partioning_trainsize). \ 21 | score_ah(dims_per_block, anisotropic_quantization_threshold=aiq_threshold).reorder(reorder_k).build() 22 | 23 | 24 | def search_ah(searcher, dims_per_block, aiq_threshold, reorder_k): 25 | return searcher.score_ah(dims_per_block, anisotropic_quantization_threshold=aiq_threshold).reorder( 26 | reorder_k).build() 27 | 28 | def load_datapool(dpath): 29 | 30 | 31 | def load_single_file(saved_embeddings): 32 | compressed = np.load(saved_embeddings) 33 | database = {key: compressed[key] for key in compressed.files} 34 | return database 35 | 36 | def load_multi_files(data_archive): 37 | database = {key: [] for key in data_archive[0].files} 38 | for d in tqdm(data_archive, desc=f'Loading datapool from {len(data_archive)} individual files.'): 39 | for key in d.files: 40 | database[key].append(d[key]) 41 | 42 | return database 43 | 44 | print(f'Load saved patch embedding from "{dpath}"') 45 | file_content = glob.glob(os.path.join(dpath, '*.npz')) 46 | 47 | if len(file_content) == 1: 48 | data_pool = load_single_file(file_content[0]) 49 | elif len(file_content) > 1: 50 | data = [np.load(f) for f in file_content] 51 | prefetched_data = parallel_data_prefetch(load_multi_files, data, 52 | n_proc=min(len(data), cpu_count()), target_data_type='dict') 53 | 54 | data_pool = {key: np.concatenate([od[key] for od in prefetched_data], axis=1)[0] for key in prefetched_data[0].keys()} 55 | else: 56 | raise ValueError(f'No npz-files in specified path "{dpath}" is this directory existing?') 57 | 58 | print(f'Finished loading of retrieval database of length {data_pool["embedding"].shape[0]}.') 59 | return data_pool 60 | 61 | 62 | def train_searcher(opt, 63 | metric='dot_product', 64 | partioning_trainsize=None, 65 | reorder_k=None, 66 | # todo tune 67 | aiq_thld=0.2, 68 | dims_per_block=2, 69 | num_leaves=None, 70 | num_leaves_to_search=None,): 71 | 72 | data_pool = load_datapool(opt.database) 73 | k = opt.knn 74 | 75 
| if not reorder_k: 76 | reorder_k = 2 * k 77 | 78 | # normalize 79 | # embeddings = 80 | searcher = scann.scann_ops_pybind.builder(data_pool['embedding'] / np.linalg.norm(data_pool['embedding'], axis=1)[:, np.newaxis], k, metric) 81 | pool_size = data_pool['embedding'].shape[0] 82 | 83 | print(*(['#'] * 100)) 84 | print('Initializing scaNN searcher with the following values:') 85 | print(f'k: {k}') 86 | print(f'metric: {metric}') 87 | print(f'reorder_k: {reorder_k}') 88 | print(f'anisotropic_quantization_threshold: {aiq_thld}') 89 | print(f'dims_per_block: {dims_per_block}') 90 | print(*(['#'] * 100)) 91 | print('Start training searcher....') 92 | print(f'N samples in pool is {pool_size}') 93 | 94 | # this reflects the recommended design choices proposed at 95 | # https://github.com/google-research/google-research/blob/aca5f2e44e301af172590bb8e65711f0c9ee0cfd/scann/docs/algorithms.md 96 | if pool_size < 2e4: 97 | print('Using brute force search.') 98 | searcher = search_bruteforce(searcher) 99 | elif 2e4 <= pool_size and pool_size < 1e5: 100 | print('Using asymmetric hashing search and reordering.') 101 | searcher = search_ah(searcher, dims_per_block, aiq_thld, reorder_k) 102 | else: 103 | print('Using using partioning, asymmetric hashing search and reordering.') 104 | 105 | if not partioning_trainsize: 106 | partioning_trainsize = data_pool['embedding'].shape[0] // 10 107 | if not num_leaves: 108 | num_leaves = int(np.sqrt(pool_size)) 109 | 110 | if not num_leaves_to_search: 111 | num_leaves_to_search = max(num_leaves // 20, 1) 112 | 113 | print('Partitioning params:') 114 | print(f'num_leaves: {num_leaves}') 115 | print(f'num_leaves_to_search: {num_leaves_to_search}') 116 | # self.searcher = self.search_ah(searcher, dims_per_block, aiq_thld, reorder_k) 117 | searcher = search_partioned_ah(searcher, dims_per_block, aiq_thld, reorder_k, 118 | partioning_trainsize, num_leaves, num_leaves_to_search) 119 | 120 | print('Finish training searcher') 121 | searcher_savedir = opt.target_path 122 | os.makedirs(searcher_savedir, exist_ok=True) 123 | searcher.serialize(searcher_savedir) 124 | print(f'Saved trained searcher under "{searcher_savedir}"') 125 | 126 | if __name__ == '__main__': 127 | sys.path.append(os.getcwd()) 128 | parser = argparse.ArgumentParser() 129 | parser.add_argument('--database', 130 | '-d', 131 | default='data/rdm/retrieval_databases/openimages', 132 | type=str, 133 | help='path to folder containing the clip feature of the database') 134 | parser.add_argument('--target_path', 135 | '-t', 136 | default='data/rdm/searchers/openimages', 137 | type=str, 138 | help='path to the target folder where the searcher shall be stored.') 139 | parser.add_argument('--knn', 140 | '-k', 141 | default=20, 142 | type=int, 143 | help='number of nearest neighbors, for which the searcher shall be optimized') 144 | 145 | opt, _ = parser.parse_known_args() 146 | 147 | train_searcher(opt,) -------------------------------------------------------------------------------- /stable_diffusion/setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | setup( 4 | name='latent-diffusion', 5 | version='0.0.1', 6 | description='', 7 | packages=find_packages(), 8 | install_requires=[ 9 | 'torch', 10 | 'numpy', 11 | 'tqdm', 12 | ], 13 | ) --------------------------------------------------------------------------------
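Editorial note (a minimal usage sketch, not a file in the repository): train_searcher.py above builds its data pool with ldm.util.parallel_data_prefetch, defined near the top of this section. The toy doubling function and array below are illustrative stand-ins for real preprocessing work.

import numpy as np
from ldm.util import parallel_data_prefetch

def double(chunk):
    # each worker receives one chunk of the split input
    return np.asarray(chunk) * 2

if __name__ == "__main__":
    data = np.arange(1000)
    # splits `data` across 4 worker processes, gathers each worker's result,
    # and concatenates everything back into one ndarray in worker order
    out = parallel_data_prefetch(double, data, n_proc=4, target_data_type="ndarray")
    assert out.shape == (1000,) and out[1] == 2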