├── AI_ETHICS.md
├── CODEOWNERS
├── CODE_OF_CONDUCT.md
├── CONTRIBUTING.md
├── LICENSE.txt
├── README.md
├── SECURITY.md
├── configs
│   ├── generate.yaml
│   ├── generate_v21_base.yaml
│   └── train_v21_base.yaml
├── edit_cli.py
├── edit_cli_batch.py
├── edit_cli_batch_rw_label.py
├── edit_cli_rw_label.py
├── edit_dataset.py
├── environment.yaml
├── imgs
│   ├── example1.jpg
│   ├── example2.jpg
│   ├── example3.jpg
│   └── results.png
├── main.py
├── metrics
│   ├── clip_similarity.py
│   └── compute_metrics.py
├── scripts
│   ├── download_checkpoints.sh
│   ├── download_hive_data.sh
│   └── download_instructpix2pix_data.sh
└── stable_diffusion
    ├── LICENSE
    ├── README.md
    ├── Stable_Diffusion_v1_Model_Card.md
    ├── assets
    │   ├── a-painting-of-a-fire.png
    │   ├── a-photograph-of-a-fire.png
    │   ├── a-shirt-with-a-fire-printed-on-it.png
    │   ├── a-shirt-with-the-inscription-'fire'.png
    │   ├── a-watercolor-painting-of-a-fire.png
    │   ├── birdhouse.png
    │   ├── fire.png
    │   ├── inpainting.png
    │   ├── modelfigure.png
    │   ├── rdm-preview.jpg
    │   ├── reconstruction1.png
    │   ├── reconstruction2.png
    │   ├── results.gif.REMOVED.git-id
    │   ├── rick.jpeg
    │   ├── stable-samples
    │   │   ├── img2img
    │   │   │   ├── mountains-1.png
    │   │   │   ├── mountains-2.png
    │   │   │   ├── mountains-3.png
    │   │   │   ├── sketch-mountains-input.jpg
    │   │   │   ├── upscaling-in.png.REMOVED.git-id
    │   │   │   └── upscaling-out.png.REMOVED.git-id
    │   │   └── txt2img
    │   │       ├── 000002025.png
    │   │       ├── 000002035.png
    │   │       ├── merged-0005.png.REMOVED.git-id
    │   │       ├── merged-0006.png.REMOVED.git-id
    │   │       └── merged-0007.png.REMOVED.git-id
    │   ├── the-earth-is-on-fire,-oil-on-canvas.png
    │   ├── txt2img-convsample.png
    │   ├── txt2img-preview.png.REMOVED.git-id
    │   └── v1-variants-scores.jpg
    ├── configs
    │   ├── autoencoder
    │   │   ├── autoencoder_kl_16x16x16.yaml
    │   │   ├── autoencoder_kl_32x32x4.yaml
    │   │   ├── autoencoder_kl_64x64x3.yaml
    │   │   └── autoencoder_kl_8x8x64.yaml
    │   ├── latent-diffusion
    │   │   ├── celebahq-ldm-vq-4.yaml
    │   │   ├── cin-ldm-vq-f8.yaml
    │   │   ├── cin256-v2.yaml
    │   │   ├── ffhq-ldm-vq-4.yaml
    │   │   ├── lsun_bedrooms-ldm-vq-4.yaml
    │   │   ├── lsun_churches-ldm-kl-8.yaml
    │   │   └── txt2img-1p4B-eval.yaml
    │   ├── retrieval-augmented-diffusion
    │   │   └── 768x768.yaml
    │   └── stable-diffusion
    │       └── v1-inference.yaml
    ├── data
    │   ├── DejaVuSans.ttf
    │   ├── example_conditioning
    │   │   ├── superresolution
    │   │   │   └── sample_0.jpg
    │   │   └── text_conditional
    │   │       └── sample_0.txt
    │   ├── imagenet_clsidx_to_label.txt
    │   ├── imagenet_train_hr_indices.p.REMOVED.git-id
    │   ├── imagenet_val_hr_indices.p
    │   ├── index_synset.yaml
    │   └── inpainting_examples
    │       ├── 6458524847_2f4c361183_k.png
    │       ├── 6458524847_2f4c361183_k_mask.png
    │       ├── 8399166846_f6fb4e4b8e_k.png
    │       ├── 8399166846_f6fb4e4b8e_k_mask.png
    │       ├── alex-iby-G_Pk4D9rMLs.png
    │       ├── alex-iby-G_Pk4D9rMLs_mask.png
    │       ├── bench2.png
    │       ├── bench2_mask.png
    │       ├── bertrand-gabioud-CpuFzIsHYJ0.png
    │       ├── bertrand-gabioud-CpuFzIsHYJ0_mask.png
    │       ├── billow926-12-Wc-Zgx6Y.png
    │       ├── billow926-12-Wc-Zgx6Y_mask.png
    │       ├── overture-creations-5sI6fQgYIuo.png
    │       ├── overture-creations-5sI6fQgYIuo_mask.png
    │       ├── photo-1583445095369-9c651e7e5d34.png
    │       └── photo-1583445095369-9c651e7e5d34_mask.png
    ├── environment.yaml
    ├── ldm
    │   ├── data
    │   │   ├── __init__.py
    │   │   ├── base.py
    │   │   ├── imagenet.py
    │   │   └── lsun.py
    │   ├── lr_scheduler.py
    │   ├── models
    │   │   ├── autoencoder.py
    │   │   └── diffusion
    │   │       ├── __init__.py
    │   │       ├── classifier.py
    │   │       ├── ddim.py
    │   │       ├── ddpm.py
    │   │       ├── ddpm_edit.py
    │   │       ├── ddpm_edit_rw.py
    │   │       ├── ddpm_edit_v21.py
    │   │       ├── dpm_solver
    │   │       │   ├── __init__.py
    │   │       │   ├── dpm_solver.py
    │   │       │   └── sampler.py
    │   │       └── plms.py
    │   ├── modules
    │   │   ├── attention.py
    │   │   ├── attention_v21.py
    │   │   ├── diffusionmodules
    │   │   │   ├── __init__.py
    │   │   │   ├── model.py
    │   │   │   ├── openaimodel.py
    │   │   │   ├── openaimodel_v21.py
    │   │   │   └── util.py
    │   │   ├── distributions
    │   │   │   ├── __init__.py
    │   │   │   └── distributions.py
    │   │   ├── ema.py
    │   │   ├── encoders
    │   │   │   ├── __init__.py
    │   │   │   └── modules.py
    │   │   ├── image_degradation
    │   │   │   ├── __init__.py
    │   │   │   ├── bsrgan.py
    │   │   │   ├── bsrgan_light.py
    │   │   │   ├── utils
    │   │   │   │   └── test.png
    │   │   │   └── utils_image.py
    │   │   ├── losses
    │   │   │   ├── __init__.py
    │   │   │   ├── contperceptual.py
    │   │   │   └── vqperceptual.py
    │   │   └── x_transformer.py
    │   └── util.py
    ├── main.py
    ├── models
    │   ├── first_stage_models
    │   │   ├── kl-f16
    │   │   │   └── config.yaml
    │   │   ├── kl-f32
    │   │   │   └── config.yaml
    │   │   ├── kl-f4
    │   │   │   └── config.yaml
    │   │   ├── kl-f8
    │   │   │   └── config.yaml
    │   │   ├── vq-f16
    │   │   │   └── config.yaml
    │   │   ├── vq-f4-noattn
    │   │   │   └── config.yaml
    │   │   ├── vq-f4
    │   │   │   └── config.yaml
    │   │   ├── vq-f8-n256
    │   │   │   └── config.yaml
    │   │   └── vq-f8
    │   │       └── config.yaml
    │   └── ldm
    │       ├── bsr_sr
    │       │   └── config.yaml
    │       ├── celeba256
    │       │   └── config.yaml
    │       ├── cin256
    │       │   └── config.yaml
    │       ├── ffhq256
    │       │   └── config.yaml
    │       ├── inpainting_big
    │       │   └── config.yaml
    │       ├── layout2img-openimages256
    │       │   └── config.yaml
    │       ├── lsun_beds256
    │       │   └── config.yaml
    │       ├── lsun_churches256
    │       │   └── config.yaml
    │       ├── semantic_synthesis256
    │       │   └── config.yaml
    │       ├── semantic_synthesis512
    │       │   └── config.yaml
    │       └── text2img256
    │           └── config.yaml
    ├── notebook_helpers.py
    ├── scripts
    │   ├── download_first_stages.sh
    │   ├── download_models.sh
    │   ├── img2img.py
    │   ├── inpaint.py
    │   ├── knn2img.py
    │   ├── latent_imagenet_diffusion.ipynb.REMOVED.git-id
    │   ├── sample_diffusion.py
    │   ├── tests
    │   │   └── test_watermark.py
    │   ├── train_searcher.py
    │   └── txt2img.py
    └── setup.py
/AI_ETHICS.md:
--------------------------------------------------------------------------------
1 | ## Ethics disclaimer for Salesforce AI models, data, code
2 |
3 | This release is for research purposes only in support of an academic
4 | paper. Our models, datasets, and code are not specifically designed or
5 | evaluated for all downstream purposes. We strongly recommend users
6 | evaluate and address potential concerns related to accuracy, safety, and
7 | fairness before deploying this model. We encourage users to consider the
8 | common limitations of AI, comply with applicable laws, and leverage best
9 | practices when selecting use cases, particularly for high-risk scenarios
10 | where errors or misuse could significantly impact people’s lives, rights,
11 | or safety. For further guidance on use cases, refer to our standard
12 | [AUP](https://www.salesforce.com/content/dam/web/en_us/www/documents/legal/Agreements/policies/ExternalFacing_Services_Policy.pdf)
13 | and [AI AUP](https://www.salesforce.com/content/dam/web/en_us/www/documents/legal/Agreements/policies/ai-acceptable-use-policy.pdf).
14 |
--------------------------------------------------------------------------------
/CODEOWNERS:
--------------------------------------------------------------------------------
1 | # The comment line immediately above an ownership line is reserved for other related information. Please be careful while editing.
2 | #ECCN:Open Source
3 |
--------------------------------------------------------------------------------
/CODE_OF_CONDUCT.md:
--------------------------------------------------------------------------------
1 | # Salesforce Open Source Community Code of Conduct
2 |
3 | ## About the Code of Conduct
4 |
5 | Equality is a core value at Salesforce. We believe a diverse and inclusive
6 | community fosters innovation and creativity, and are committed to building a
7 | culture where everyone feels included.
8 |
9 | Salesforce open-source projects are committed to providing a friendly, safe, and
10 | welcoming environment for all, regardless of gender identity and expression,
11 | sexual orientation, disability, physical appearance, body size, ethnicity, nationality,
12 | race, age, religion, level of experience, education, socioeconomic status, or
13 | other similar personal characteristics.
14 |
15 | The goal of this code of conduct is to specify a baseline standard of behavior so
16 | that people with different social values and communication styles can work
17 | together effectively, productively, and respectfully in our open source community.
18 | It also establishes a mechanism for reporting issues and resolving conflicts.
19 |
20 | All questions and reports of abusive, harassing, or otherwise unacceptable behavior
21 | in a Salesforce open-source project may be reported by contacting the Salesforce
22 | Open Source Conduct Committee at ossconduct@salesforce.com.
23 |
24 | ## Our Pledge
25 |
26 | In the interest of fostering an open and welcoming environment, we as
27 | contributors and maintainers pledge to making participation in our project and
28 | our community a harassment-free experience for everyone, regardless of gender
29 | identity and expression, sexual orientation, disability, physical appearance,
30 | body size, ethnicity, nationality, race, age, religion, level of experience, education,
31 | socioeconomic status, or other similar personal characteristics.
32 |
33 | ## Our Standards
34 |
35 | Examples of behavior that contributes to creating a positive environment
36 | include:
37 |
38 | * Using welcoming and inclusive language
39 | * Being respectful of differing viewpoints and experiences
40 | * Gracefully accepting constructive criticism
41 | * Focusing on what is best for the community
42 | * Showing empathy toward other community members
43 |
44 | Examples of unacceptable behavior by participants include:
45 |
46 | * The use of sexualized language or imagery and unwelcome sexual attention or
47 | advances
48 | * Personal attacks, insulting/derogatory comments, or trolling
49 | * Public or private harassment
50 | * Publishing, or threatening to publish, others' private information—such as
51 | a physical or electronic address—without explicit permission
52 | * Other conduct which could reasonably be considered inappropriate in a
53 | professional setting
54 | * Advocating for or encouraging any of the above behaviors
55 |
56 | ## Our Responsibilities
57 |
58 | Project maintainers are responsible for clarifying the standards of acceptable
59 | behavior and are expected to take appropriate and fair corrective action in
60 | response to any instances of unacceptable behavior.
61 |
62 | Project maintainers have the right and responsibility to remove, edit, or
63 | reject comments, commits, code, wiki edits, issues, and other contributions
64 | that are not aligned with this Code of Conduct, or to ban temporarily or
65 | permanently any contributor for other behaviors that they deem inappropriate,
66 | threatening, offensive, or harmful.
67 |
68 | ## Scope
69 |
70 | This Code of Conduct applies both within project spaces and in public spaces
71 | when an individual is representing the project or its community. Examples of
72 | representing a project or community include using an official project email
73 | address, posting via an official social media account, or acting as an appointed
74 | representative at an online or offline event. Representation of a project may be
75 | further defined and clarified by project maintainers.
76 |
77 | ## Enforcement
78 |
79 | Instances of abusive, harassing, or otherwise unacceptable behavior may be
80 | reported by contacting the Salesforce Open Source Conduct Committee
81 | at ossconduct@salesforce.com. All complaints will be reviewed and investigated
82 | and will result in a response that is deemed necessary and appropriate to the
83 | circumstances. The committee is obligated to maintain confidentiality with
84 | regard to the reporter of an incident. Further details of specific enforcement
85 | policies may be posted separately.
86 |
87 | Project maintainers who do not follow or enforce the Code of Conduct in good
88 | faith may face temporary or permanent repercussions as determined by other
89 | members of the project's leadership and the Salesforce Open Source Conduct
90 | Committee.
91 |
92 | ## Attribution
93 |
94 | This Code of Conduct is adapted from the [Contributor Covenant][contributor-covenant-home],
95 | version 1.4, available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html.
96 | It includes adaptations and additions from [Go Community Code of Conduct][golang-coc],
97 | [CNCF Code of Conduct][cncf-coc], and [Microsoft Open Source Code of Conduct][microsoft-coc].
98 |
99 | This Code of Conduct is licensed under the [Creative Commons Attribution 3.0 License][cc-by-3-us].
100 |
101 | [contributor-covenant-home]: https://www.contributor-covenant.org (https://www.contributor-covenant.org/)
102 | [golang-coc]: https://golang.org/conduct
103 | [cncf-coc]: https://github.com/cncf/foundation/blob/master/code-of-conduct.md
104 | [microsoft-coc]: https://opensource.microsoft.com/codeofconduct/
105 | [cc-by-3-us]: https://creativecommons.org/licenses/by/3.0/us/
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # Contributing Guide
2 |
3 | This page lists the operational governance model of this project, as well as the recommendations and requirements for how to best contribute to it. We strive to obey these as best as possible. As always, thanks for contributing – we hope these guidelines make it easier and shed some light on our approach and processes.
4 |
5 | # Governance Model
6 |
7 | ## Published but not supported
8 |
9 | We are open sourcing this project because it may contain useful or interesting code/concepts that we wish to share with the larger open source community. Although occasional work may be done on it, we will not be looking for or soliciting contributions.
10 |
11 | # Getting started
12 |
13 | Please join the community. Also please make sure to take a look at the project [roadmap](ROADMAP.md), if it exists, to see where we are headed.
14 |
15 | # Issues, requests & ideas
16 |
17 | Use the GitHub Issues page to submit issues and enhancement requests, and to discuss ideas.
18 |
19 | ### Bug Reports and Fixes
20 | - If you find a bug, please search for it in the Issues, and if it isn't already tracked,
21 |   create a new issue. Fill out the "Bug Report" section of the issue template. Even if an Issue is closed, feel free to comment and add details; it will still
22 |   be reviewed.
23 | - Issues that have already been identified as a bug (note: able to reproduce) will be labelled `bug`.
24 | - If you'd like to submit a fix for a bug, [send a Pull Request](#creating-a-pull-request) and mention the Issue number.
25 | - Include tests that isolate the bug and verify that it was fixed.
26 |
27 | ### New Features
28 | - If you'd like to add new functionality to this project, describe the problem you want to solve in a new Issue.
29 | - Issues that have been identified as a feature request will be labelled `enhancement`.
30 | - If you'd like to implement the new feature, please wait for feedback from the project
31 | maintainers before spending too much time writing the code. In some cases, `enhancement`s may
32 | not align well with the project objectives at the time.
33 |
34 | ### Tests, Documentation, Miscellaneous
35 | - If you'd like to improve the tests, you want to make the documentation clearer, you have an
36 |   alternative implementation of something that may have advantages over the way it's currently
37 |   done, or you have any other change, we would be happy to hear about it!
38 | - If it's a trivial change, go ahead and [send a Pull Request](#creating-a-pull-request) with the changes you have in mind.
39 | - If not, open an Issue to discuss the idea first.
40 |
41 | If you're new to our project and looking for some way to make your first contribution, look for
42 | Issues labelled `good first contribution`.
43 |
44 | # Contribution Checklist
45 |
46 | - [x] Clean, simple, well styled code
47 | - [x] Commits should be atomic and messages must be descriptive. Related issues should be mentioned by Issue number.
48 | - [x] Comments
49 | - Module-level & function-level comments.
50 | - Comments on complex blocks of code or algorithms (include references to sources).
51 | - [x] Tests
52 | - The test suite, if provided, must be complete and pass
53 | - Increase code coverage; do not decrease it.
54 | - Use any of our testkits, which contain the testing facilities you would need. For example: `import com.salesforce.op.test._` and borrow inspiration from existing tests.
55 | - [x] Dependencies
56 | - Minimize number of dependencies.
57 | - Prefer Apache 2.0, BSD3, MIT, ISC and MPL licenses.
58 | - [x] Reviews
59 | - Changes must be approved via peer code review
60 |
61 | # Creating a Pull Request
62 |
63 | 1. **Ensure the bug/feature was not already reported** by searching on GitHub under Issues. If none exists, create a new issue so that other contributors can keep track of what you are trying to add/fix and offer suggestions (or let you know if there is already an effort in progress).
64 | 2. **Fork** the repository and **clone** your fork to your machine.
65 | 3. **Create** a new branch to contain your work (e.g. `git checkout -b fix-issue-11`).
66 | 4. **Commit** changes to your own branch.
67 | 5. **Push** your work back up to your fork (e.g. `git push origin fix-issue-11`).
68 | 6. **Submit** a Pull Request against the `main` branch and refer to the issue(s) you are fixing. Try not to pollute your pull request with unintended changes. Keep it simple and small.
69 | 7. **Sign** the Salesforce CLA (you will be prompted to do so when submitting the Pull Request).
70 |
71 | > **NOTE**: Be sure to [sync your fork](https://help.github.com/articles/syncing-a-fork/) before making a pull request.
72 |
73 |
74 | # Code of Conduct
75 | Please follow our [Code of Conduct](CODE_OF_CONDUCT.md).
76 |
77 | # License
78 | By contributing your code, you agree to license your contribution under the terms of our project [LICENSE](LICENSE.txt) and to sign the [Salesforce CLA](https://cla.salesforce.com/sign-cla)
79 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # HIVE
2 |
3 | ### [HIVE: Harnessing Human Feedback for Instructional Visual Editing](https://arxiv.org/pdf/2303.09618.pdf)
4 | Shu Zhang<sup>\*1</sup>, Xinyi Yang<sup>\*1</sup>, Yihao Feng<sup>\*1</sup>, Can Qin<sup>3</sup>, Chia-Chih Chen<sup>1</sup>, Ning Yu<sup>1</sup>, Zeyuan Chen<sup>1</sup>, Huan Wang<sup>1</sup>, Silvio Savarese<sup>1,2</sup>, Stefano Ermon<sup>2</sup>, Caiming Xiong<sup>1</sup>, and Ran Xu<sup>1</sup>
5 | <sup>1</sup>Salesforce AI, <sup>2</sup>Stanford University, <sup>3</sup>Northeastern University
6 | \*denotes equal contribution
7 | arXiv 2023
8 |
9 | ### [paper](https://arxiv.org/pdf/2303.09618.pdf) | [project page](https://shugerdou.github.io/hive/)
10 |
11 |
12 |
13 | This is a PyTorch implementation of [HIVE: Harnessing Human Feedback for Instructional Visual Editing](https://arxiv.org/pdf/2303.09618.pdf). Most of the code follows [InstructPix2Pix](https://github.com/timothybrooks/instruct-pix2pix). In this repo, we support both [stable diffusion v1.5-base](https://huggingface.co/runwayml/stable-diffusion-v1-5) and [stable diffusion v2.1-base](https://huggingface.co/stabilityai/stable-diffusion-2-1-base) as the backbone.
14 |
15 |
16 | ## Updates
17 | * **07/08/23**: ***Training code and training data are public.***:blush:
18 | * **03/26/24**: ***HIVE will appear at CVPR 2024.***:blush:
19 |
20 | ## Usage
21 |
22 | ### Preparation
23 | First, set up the ```hive``` environment and download the pretrained models as below. This has only been verified with CUDA 11.0 and CUDA 11.3 on NVIDIA A100 GPUs.
24 |
25 | ```
26 | conda env create -f environment.yaml
27 | conda activate hive
28 | bash scripts/download_checkpoints.sh
29 | ```
30 |
31 | To fine-tune a stable diffusion model, you need to obtain the pre-trained stable diffusion weights following their [instructions](https://github.com/runwayml/stable-diffusion). If you use SD-V1.5, you can download the weights from [HuggingFace SD 1.5](https://huggingface.co/runwayml/stable-diffusion-v1-5/resolve/main/v1-5-pruned-emaonly.ckpt). If you use SD-V2.1, the weights can be downloaded from [HuggingFace SD 2.1](https://huggingface.co/stabilityai/stable-diffusion-2-1-base). You can decide which version of checkpoint to use; we use ```v2-1_512-ema-pruned.ckpt```. Download the model to ```checkpoints/```.
32 |
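If you prefer to script this download, the snippet below is a minimal sketch using `huggingface_hub` (an extra dependency, not necessarily listed in ```environment.yaml```); the repo and filename match the SD-V2.1 checkpoint mentioned above:

```
# Sketch: fetch the SD v2.1-base weights into checkpoints/ (requires `pip install huggingface_hub`).
from huggingface_hub import hf_hub_download

hf_hub_download(
    repo_id="stabilityai/stable-diffusion-2-1-base",
    filename="v2-1_512-ema-pruned.ckpt",
    local_dir="checkpoints",
)
```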
33 |
34 | ### Data
35 | We suggest installing the gcloud CLI following [Gcloud download](https://cloud.google.com/sdk/docs/install). To obtain both the training and evaluation data, run
36 | ```
37 | bash scripts/download_hive_data.sh
38 | ```
39 |
40 | An alternative is to directly download the evaluation data and instructions from [Evaluation data](https://storage.cloud.google.com/sfr-hive-data-research/data/evaluation.zip) and [Evaluation instructions](https://storage.cloud.google.com/sfr-hive-data-research/data/test.jsonl).
41 |
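Each line of ```test.jsonl``` is a JSON object whose ```instruction``` and ```source_img``` fields drive the batch inference scripts below. A minimal sketch for previewing a few entries (using the same ```jsonlines``` package as the scripts):

```
# Sketch: preview evaluation entries, mirroring how edit_cli_batch*.py reads them.
import jsonlines

with jsonlines.open("data/test.jsonl") as reader:
    for i, entry in enumerate(reader):
        print(entry["instruction"], "->", entry["source_img"])
        if i >= 4:
            break
```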
42 |
43 | ### Step-1 Training
44 | For SD v2.1, we run
45 |
46 | ```
47 | python main.py --name step1 --base configs/train_v21_base.yaml --train --gpus 0,1,2,3,4,5,6,7
48 | ```
49 |
50 | ### Inference
51 | Samples can be obtained by running the commands below.
52 |
53 | For SD v2.1, if we use the conditional reward, we run
54 |
55 | ```
56 | python edit_cli_rw_label.py --steps 100 --resolution 512 --seed 100 --cfg-text 7.5 --cfg-image 1.5 \
57 | --input imgs/example1.jpg --output imgs/output.jpg --edit "move it to Mars" --ckpt checkpoints/hive_v2_rw_condition.ckpt \
58 | --config configs/generate_v21_base.yaml
59 | ```
60 |
61 |
62 | or run batch inference on our inference data:
63 |
64 | ```
65 | python edit_cli_batch_rw_label.py --steps 100 --resolution 512 --seed 100 --cfg-text 7.5 --cfg-image 1.5 \
66 | --jsonl_file data/test.jsonl --output_dir imgs/sdv21_rw_label/ --ckpt checkpoints/hive_v2_rw_condition.ckpt \
67 | --config configs/generate_v21_base.yaml --image_dir data/evaluation/
68 | ```
69 |
70 | For SD v2.1, if we use the weighted reward, we can run
71 |
72 |
73 | ```
74 | python edit_cli.py --steps 100 --resolution 512 --seed 100 --cfg-text 7.5 --cfg-image 1.5 \
75 | --input imgs/example1.jpg --output imgs/output.jpg --edit "move it to Mars" \
76 | --ckpt checkpoints/hive_v2_rw.ckpt --config configs/generate_v21_base.yaml
77 | ```
78 |
79 | or run batch inference on our inference data:
80 |
81 | ```
82 | python edit_cli_batch.py --steps 100 --resolution 512 --seed 100 --cfg-text 7.5 --cfg-image 1.5 \
83 | --jsonl_file data/test.jsonl --output_dir imgs/sdv21/ --ckpt checkpoints/hive_v2_rw.ckpt \
84 | --config configs/generate_v21_base.yaml --image_dir data/evaluation/
85 | ```
86 |
87 | For SD v1.5, if we use the conditional reward, we can run
88 |
89 | ```
90 | python edit_cli_rw_label.py --steps 100 --resolution 512 --seed 100 --cfg-text 7.5 --cfg-image 1.5 \
91 | --input imgs/example1.jpg --output imgs/output.jpg --edit "move it to Mars" \
92 | --ckpt checkpoints/hive_rw_condition.ckpt --config configs/generate.yaml
93 | ```
94 |
95 | or run batch inference on our inference data:
96 |
97 | ```
98 | python edit_cli_batch_rw_label.py --steps 100 --resolution 512 --seed 100 --cfg-text 7.5 --cfg-image 1.5 \
99 | --jsonl_file data/test.jsonl --output_dir imgs/sdv15_rw_label/ \
100 | --ckpt checkpoints/hive_rw_condition.ckpt --config configs/generate.yaml \
101 | --image_dir data/evaluation/
102 | ```
103 |
104 | For SD v1.5, if we use the weighted reward, we run
105 |
106 |
107 | ```
108 | python edit_cli.py --steps 100 --resolution 512 --seed 100 --cfg-text 7.5 --cfg-image 1.5 --input imgs/example1.jpg \
109 | --output imgs/output.jpg --edit "move it to Mars" \
110 | --ckpt checkpoints/hive_rw.ckpt --config configs/generate.yaml
111 | ```
112 |
113 | or run batch inference on our inference data:
114 |
115 | ```
116 | python edit_cli_batch.py --steps 100 --resolution 512 --seed 100 --cfg-text 7.5 --cfg-image 1.5 \
117 | --jsonl_file data/test.jsonl --output_dir imgs/sdv15/ \
118 | --ckpt checkpoints/hive_rw.ckpt --config configs/generate.yaml \
119 | --image_dir data/evaluation/
120 | ```
121 |
122 | ## Citation
123 | ```
124 | @article{zhang2023hive,
125 | title={HIVE: Harnessing Human Feedback for Instructional Visual Editing},
126 | author={Zhang, Shu and Yang, Xinyi and Feng, Yihao and Qin, Can and Chen, Chia-Chih and Yu, Ning and Chen, Zeyuan and Wang, Huan and Savarese, Silvio and Ermon, Stefano and Xiong, Caiming and Xu, Ran},
127 | journal={arXiv preprint arXiv:2303.09618},
128 | year={2023}
129 | }
130 | ```
131 |
132 |
133 |
--------------------------------------------------------------------------------
/SECURITY.md:
--------------------------------------------------------------------------------
1 | ## Security
2 |
3 | Please report any security issue to [security@salesforce.com](mailto:security@salesforce.com)
4 | as soon as it is discovered. This library limits its runtime dependencies in
5 | order to reduce the total cost of ownership as much as possible, but all consumers
6 | should remain vigilant and have their security stakeholders review all third-party
7 | products (3PP) like this one and their dependencies.
--------------------------------------------------------------------------------
/configs/generate.yaml:
--------------------------------------------------------------------------------
1 | # File modified by authors of InstructPix2Pix from original (https://github.com/CompVis/stable-diffusion).
2 | # See more details in LICENSE.
3 |
4 | model:
5 | base_learning_rate: 1.0e-04
6 | target: ldm.models.diffusion.ddpm_edit.LatentDiffusion
7 | params:
8 | linear_start: 0.00085
9 | linear_end: 0.0120
10 | num_timesteps_cond: 1
11 | log_every_t: 200
12 | timesteps: 1000
13 | first_stage_key: edited
14 | cond_stage_key: edit
15 | # image_size: 64
16 | # image_size: 32
17 | image_size: 16
18 | channels: 4
19 | cond_stage_trainable: false # Note: different from the one we trained before
20 | conditioning_key: hybrid
21 | monitor: val/loss_simple_ema
22 | scale_factor: 0.18215
23 | use_ema: true
24 | load_ema: true
25 |
26 | scheduler_config: # 10000 warmup steps
27 | target: ldm.lr_scheduler.LambdaLinearScheduler
28 | params:
29 | warm_up_steps: [ 0 ]
30 | cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
31 | f_start: [ 1.e-6 ]
32 | f_max: [ 1. ]
33 | f_min: [ 1. ]
34 |
35 | unet_config:
36 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel
37 | params:
38 | image_size: 32 # unused
39 | in_channels: 8
40 | out_channels: 4
41 | model_channels: 320
42 | attention_resolutions: [ 4, 2, 1 ]
43 | num_res_blocks: 2
44 | channel_mult: [ 1, 2, 4, 4 ]
45 | num_heads: 8
46 | use_spatial_transformer: True
47 | transformer_depth: 1
48 | context_dim: 768
49 | use_checkpoint: True
50 | legacy: False
51 |
52 | first_stage_config:
53 | target: ldm.models.autoencoder.AutoencoderKL
54 | params:
55 | embed_dim: 4
56 | monitor: val/rec_loss
57 | ddconfig:
58 | double_z: true
59 | z_channels: 4
60 | resolution: 256
61 | in_channels: 3
62 | out_ch: 3
63 | ch: 128
64 | ch_mult:
65 | - 1
66 | - 2
67 | - 4
68 | - 4
69 | num_res_blocks: 2
70 | attn_resolutions: []
71 | dropout: 0.0
72 | lossconfig:
73 | target: torch.nn.Identity
74 |
75 | cond_stage_config:
76 | target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
77 |
78 | data:
79 | target: main.DataModuleFromConfig
80 | params:
81 | batch_size: 128
82 | num_workers: 1
83 | wrap: false
84 | validation:
85 | target: edit_dataset.EditDataset
86 | params:
87 | path: ./data/training/instructpix2pix
88 | cache_dir: data/
89 | cache_name: data_10k
90 | split: val
91 | min_text_sim: 0.2
92 | min_image_sim: 0.75
93 | min_direction_sim: 0.2
94 | max_samples_per_prompt: 1
95 | min_resize_res: 512
96 | max_resize_res: 512
97 | crop_res: 512
98 | output_as_edit: False
99 | real_input: True
100 |
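How this file is consumed: the inference scripts load it with OmegaConf and build the model via ```instantiate_from_config```. Below is a minimal sketch of that flow, mirroring ```load_model_from_config``` in ```edit_cli.py``` (the checkpoint path is just an example):

```
# Sketch: instantiate the model described by configs/generate.yaml
# (mirrors load_model_from_config in edit_cli.py; the checkpoint path is an example).
import sys
import torch
from omegaconf import OmegaConf

sys.path.append("./stable_diffusion")
from stable_diffusion.ldm.util import instantiate_from_config

config = OmegaConf.load("configs/generate.yaml")
model = instantiate_from_config(config.model)  # builds ldm.models.diffusion.ddpm_edit.LatentDiffusion
state_dict = torch.load("checkpoints/hive_rw.ckpt", map_location="cpu")["state_dict"]
model.load_state_dict(state_dict, strict=False)
model.eval().cuda()
```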
--------------------------------------------------------------------------------
/configs/generate_v21_base.yaml:
--------------------------------------------------------------------------------
1 | # File modified by authors of InstructPix2Pix from original (https://github.com/CompVis/stable-diffusion).
2 | # See more details in LICENSE.
3 |
4 | model:
5 | base_learning_rate: 1.0e-04
6 | target: ldm.models.diffusion.ddpm_edit_v21.LatentDiffusion
7 | params:
8 | ckpt_path: checkpoints/v2-1_512-ema-pruned.ckpt
9 | linear_start: 0.00085
10 | linear_end: 0.0120
11 | num_timesteps_cond: 1
12 | log_every_t: 200
13 | timesteps: 1000
14 | first_stage_key: edited
15 | cond_stage_key: edit
16 | image_size: 32
17 | channels: 4
18 | cond_stage_trainable: false # Note: different from the one we trained before
19 | conditioning_key: hybrid
20 | monitor: val/loss_simple_ema
21 | scale_factor: 0.18215
22 | use_ema: true
23 | load_ema: true
24 |
25 | scheduler_config: # 10000 warmup steps
26 | target: ldm.lr_scheduler.LambdaLinearScheduler
27 | params:
28 | warm_up_steps: [ 0 ]
29 | cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
30 | f_start: [ 1.e-6 ]
31 | f_max: [ 1. ]
32 | f_min: [ 1. ]
33 |
34 | unet_config:
35 | target: ldm.modules.diffusionmodules.openaimodel_v21.UNetModel
36 | params:
37 | image_size: 32 # unused
38 | in_channels: 8
39 | out_channels: 4
40 | model_channels: 320
41 | attention_resolutions: [ 4, 2, 1 ]
42 | num_res_blocks: 2
43 | channel_mult: [ 1, 2, 4, 4 ]
44 | num_heads: 8
45 | num_head_channels: 64 # need to fix for flash-attn
46 | use_spatial_transformer: True
47 | use_linear_in_transformer: True
48 | transformer_depth: 1
49 | context_dim: 1024
50 | use_checkpoint: True
51 | legacy: False
52 |
53 | first_stage_config:
54 | target: ldm.models.autoencoder.AutoencoderKL
55 | params:
56 | embed_dim: 4
57 | monitor: val/rec_loss
58 | ddconfig:
59 | double_z: true
60 | z_channels: 4
61 | resolution: 256
62 | in_channels: 3
63 | out_ch: 3
64 | ch: 128
65 | ch_mult:
66 | - 1
67 | - 2
68 | - 4
69 | - 4
70 | num_res_blocks: 2
71 | attn_resolutions: []
72 | dropout: 0.0
73 | lossconfig:
74 | target: torch.nn.Identity
75 |
76 | cond_stage_config:
77 | target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder
78 | params:
79 | freeze: True
80 | layer: "penultimate"
81 |
82 | data:
83 | target: main.DataModuleFromConfig
84 | params:
85 | batch_size: 128
86 | num_workers: 1
87 | wrap: false
88 | validation:
89 | target: edit_dataset.EditDataset
90 | params:
91 | path: data/clip-filtered-dataset
92 | cache_dir: data/
93 | cache_name: data_10k
94 | split: val
95 | min_text_sim: 0.2
96 | min_image_sim: 0.75
97 | min_direction_sim: 0.2
98 | max_samples_per_prompt: 1
99 | min_resize_res: 512
100 | max_resize_res: 512
101 | crop_res: 512
102 | output_as_edit: False
103 | real_input: True
104 |
--------------------------------------------------------------------------------
/configs/train_v21_base.yaml:
--------------------------------------------------------------------------------
1 | # File modified by authors of InstructPix2Pix from original (https://github.com/CompVis/stable-diffusion).
2 | # See more details in LICENSE.
3 |
4 | model:
5 | base_learning_rate: 1.0e-04
6 | target: ldm.models.diffusion.ddpm_edit_v21.LatentDiffusion
7 | params:
8 | ckpt_path: ./checkpoints/v2-1_512-ema-pruned.ckpt
9 | linear_start: 0.00085
10 | linear_end: 0.0120
11 | num_timesteps_cond: 1
12 | log_every_t: 200
13 | timesteps: 1000
14 | first_stage_key: edited
15 | cond_stage_key: edit
16 | image_size: 32
17 | channels: 4
18 | cond_stage_trainable: false # Note: different from the one we trained before
19 | conditioning_key: hybrid
20 | monitor: val/loss_simple_ema
21 | scale_factor: 0.18215
22 | use_ema: true
23 | load_ema: false
24 |
25 | scheduler_config: # 10000 warmup steps
26 | target: ldm.lr_scheduler.LambdaLinearScheduler
27 | params:
28 | warm_up_steps: [ 0 ]
29 | cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
30 | f_start: [ 1.e-6 ]
31 | f_max: [ 1. ]
32 | f_min: [ 1. ]
33 |
34 | unet_config:
35 | target: ldm.modules.diffusionmodules.openaimodel_v21.UNetModel
36 | params:
37 | image_size: 32 # unused
38 | in_channels: 8
39 | out_channels: 4
40 | model_channels: 320
41 | attention_resolutions: [ 4, 2, 1 ]
42 | num_res_blocks: 2
43 | channel_mult: [ 1, 2, 4, 4 ]
44 | num_heads: 8
45 | num_head_channels: 64 # need to fix for flash-attn
46 | use_spatial_transformer: True
47 | use_linear_in_transformer: True
48 | transformer_depth: 1
49 | context_dim: 1024
50 | use_checkpoint: True
51 | legacy: False
52 |
53 | first_stage_config:
54 | target: ldm.models.autoencoder.AutoencoderKL
55 | params:
56 | embed_dim: 4
57 | monitor: val/rec_loss
58 | ddconfig:
59 | double_z: true
60 | z_channels: 4
61 | resolution: 256
62 | in_channels: 3
63 | out_ch: 3
64 | ch: 128
65 | ch_mult:
66 | - 1
67 | - 2
68 | - 4
69 | - 4
70 | num_res_blocks: 2
71 | attn_resolutions: []
72 | dropout: 0.0
73 | lossconfig:
74 | target: torch.nn.Identity
75 |
76 | cond_stage_config:
77 | target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder
78 | params:
79 | freeze: True
80 | layer: "penultimate"
81 |
82 |
83 | data:
84 | target: main.DataModuleFromConfig
85 | params:
86 | batch_size: 32
87 | num_workers: 2
88 | train:
89 | target: edit_dataset.EditDataset
90 | params:
91 | path_instructpix2pix: ./data/training/instructpix2pix
92 | path_hive_0: ./data/training
93 | path_hive_1: ./data/training/part_0_blip_prompt_new
94 | path_hive_2: ./data/training/part_1_blip_prompt_new
95 | split: train
96 | min_resize_res: 256
97 | max_resize_res: 256
98 | crop_res: 256
99 | flip_prob: 0.5
100 | validation:
101 | target: edit_dataset.EditDataset
102 | params:
103 | path_instructpix2pix: ./data/training/instructpix2pix
104 | path_hive_0: ./data/training
105 | path_hive_1: ./data/training/part_0_blip_prompt_new
106 | path_hive_2: ./data/training/part_1_blip_prompt_new
107 | split: val
108 | min_resize_res: 256
109 | max_resize_res: 256
110 | crop_res: 256
111 |
112 | lightning:
113 | callbacks:
114 | image_logger:
115 | target: main.ImageLogger
116 | params:
117 | batch_frequency: 2000
118 | max_images: 2
119 | increase_log_steps: False
120 |
121 | trainer:
122 | max_epochs: 3000
123 | benchmark: True
124 | accumulate_grad_batches: 4
125 | check_val_every_n_epoch: 4
126 |
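For reference, the effective global batch size is the per-GPU ```batch_size``` times ```accumulate_grad_batches``` times the number of GPUs; a quick check, assuming the 8-GPU training command from the README:

```
# Sketch: effective global batch size implied by this config (8 GPUs assumed from the README command).
per_gpu_batch = 32   # data.params.batch_size
grad_accum = 4       # lightning.trainer.accumulate_grad_batches
num_gpus = 8         # --gpus 0,1,2,3,4,5,6,7
print(per_gpu_batch * grad_accum * num_gpus)  # 1024
```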
--------------------------------------------------------------------------------
/edit_cli.py:
--------------------------------------------------------------------------------
1 | '''
2 | * Copyright (c) 2023 Salesforce, Inc.
3 | * All rights reserved.
4 | * SPDX-License-Identifier: Apache License 2.0
5 | * For full license text, see LICENSE.txt file in the repo root or http://www.apache.org/licenses/
6 | * By Shu Zhang
7 | * Modified from InstructPix2Pix repo: https://github.com/timothybrooks/instruct-pix2pix
8 | * Copyright (c) 2023 Timothy Brooks, Aleksander Holynski, Alexei A. Efros. All rights reserved.
9 | '''
10 |
11 | from __future__ import annotations
12 |
13 | import math
14 | import random
15 | import sys
16 | from argparse import ArgumentParser
17 |
18 | import einops
19 | import k_diffusion as K
20 | import numpy as np
21 | import torch
22 | import torch.nn as nn
23 | from einops import rearrange
24 | from omegaconf import OmegaConf
25 | from PIL import Image, ImageOps
26 | from torch import autocast
27 |
28 | sys.path.append("./stable_diffusion")
29 |
30 | from stable_diffusion.ldm.util import instantiate_from_config
31 |
32 |
33 | class CFGDenoiser(nn.Module):
34 | def __init__(self, model):
35 | super().__init__()
36 | self.inner_model = model
37 |
38 | def forward(self, z, sigma, cond, uncond, text_cfg_scale, image_cfg_scale):
39 | cfg_z = einops.repeat(z, "1 ... -> n ...", n=3)
40 | cfg_sigma = einops.repeat(sigma, "1 ... -> n ...", n=3)
41 | cfg_cond = {
42 | "c_crossattn": [torch.cat([cond["c_crossattn"][0], uncond["c_crossattn"][0], uncond["c_crossattn"][0]])],
43 | "c_concat": [torch.cat([cond["c_concat"][0], cond["c_concat"][0], uncond["c_concat"][0]])],
44 | }
45 | out_cond, out_img_cond, out_uncond = self.inner_model(cfg_z, cfg_sigma, cond=cfg_cond).chunk(3)
46 | return out_uncond + text_cfg_scale * (out_cond - out_img_cond) + image_cfg_scale * (out_img_cond - out_uncond)
47 |
48 |
49 | def load_model_from_config(config, ckpt, vae_ckpt=None, verbose=False):
50 | print(f"Loading model from {ckpt}")
51 | pl_sd = torch.load(ckpt, map_location="cpu")
52 | if "global_step" in pl_sd:
53 | print(f"Global Step: {pl_sd['global_step']}")
54 | sd = pl_sd["state_dict"]
55 | if vae_ckpt is not None:
56 | print(f"Loading VAE from {vae_ckpt}")
57 | vae_sd = torch.load(vae_ckpt, map_location="cpu")["state_dict"]
58 | sd = {
59 | k: vae_sd[k[len("first_stage_model.") :]] if k.startswith("first_stage_model.") else v
60 | for k, v in sd.items()
61 | }
62 | model = instantiate_from_config(config.model)
63 | m, u = model.load_state_dict(sd, strict=False)
64 | if len(m) > 0 and verbose:
65 | print("missing keys:")
66 | print(m)
67 | if len(u) > 0 and verbose:
68 | print("unexpected keys:")
69 | print(u)
70 | return model
71 |
72 |
73 | def main():
74 | parser = ArgumentParser()
75 | parser.add_argument("--resolution", default=512, type=int)
76 | parser.add_argument("--steps", default=100, type=int)
77 | parser.add_argument("--config", default="configs/generate.yaml", type=str)
78 | parser.add_argument("--ckpt", default="checkpoints/hive.ckpt", type=str)
79 | parser.add_argument("--vae-ckpt", default=None, type=str)
80 | parser.add_argument("--input", required=True, type=str)
81 | parser.add_argument("--output", required=True, type=str)
82 | parser.add_argument("--edit", required=True, type=str)
83 | parser.add_argument("--cfg-text", default=7.5, type=float)
84 | parser.add_argument("--cfg-image", default=1.5, type=float)
85 | parser.add_argument("--seed", type=int)
86 | args = parser.parse_args()
87 |
88 | config = OmegaConf.load(args.config)
89 | model = load_model_from_config(config, args.ckpt, args.vae_ckpt)
90 | model.eval().cuda()
91 | model_wrap = K.external.CompVisDenoiser(model)
92 | model_wrap_cfg = CFGDenoiser(model_wrap)
93 | null_token = model.get_learned_conditioning([""])
94 |
95 | seed = random.randint(0, 100000) if args.seed is None else args.seed
96 | input_image = Image.open(args.input).convert("RGB")
97 | width, height = input_image.size
98 | factor = args.resolution / max(width, height)
99 | factor = math.ceil(min(width, height) * factor / 64) * 64 / min(width, height)
100 | width = int((width * factor) // 64) * 64
101 | height = int((height * factor) // 64) * 64
102 | input_image = ImageOps.fit(input_image, (width, height), method=Image.Resampling.LANCZOS)
103 |
104 | if args.edit == "":
105 | input_image.save(args.output)
106 | return
107 |
108 | with torch.no_grad(), autocast("cuda"), model.ema_scope():
109 | cond = {}
110 | cond["c_crossattn"] = [model.get_learned_conditioning([args.edit])]
111 | input_image = 2 * torch.tensor(np.array(input_image)).float() / 255 - 1
112 | input_image = rearrange(input_image, "h w c -> 1 c h w").to(model.device)
113 | cond["c_concat"] = [model.encode_first_stage(input_image).mode()]
114 |
115 | uncond = {}
116 | uncond["c_crossattn"] = [null_token]
117 | uncond["c_concat"] = [torch.zeros_like(cond["c_concat"][0])]
118 |
119 | sigmas = model_wrap.get_sigmas(args.steps)
120 |
121 | extra_args = {
122 | "cond": cond,
123 | "uncond": uncond,
124 | "text_cfg_scale": args.cfg_text,
125 | "image_cfg_scale": args.cfg_image,
126 | }
127 | torch.manual_seed(seed)
128 | z = torch.randn_like(cond["c_concat"][0]) * sigmas[0]
129 | z = K.sampling.sample_euler_ancestral(model_wrap_cfg, z, sigmas, extra_args=extra_args)
130 | x = model.decode_first_stage(z)
131 | x = torch.clamp((x + 1.0) / 2.0, min=0.0, max=1.0)
132 | x = 255.0 * rearrange(x, "1 c h w -> h w c")
133 | edited_image = Image.fromarray(x.type(torch.uint8).cpu().numpy())
134 | edited_image.save(args.output)
135 |
136 |
137 | if __name__ == "__main__":
138 | main()
139 |
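A note on ```CFGDenoiser.forward``` above: the latent batch is tripled so that a single U-Net call returns the text-and-image conditioned, image-only conditioned, and fully unconditional predictions, which are then mixed with the two guidance scales. A standalone sketch of that combination (the same formula as lines 45-46):

```
# Sketch: the two-scale classifier-free guidance mix used in CFGDenoiser.forward.
def combine_guidance(out_cond, out_img_cond, out_uncond, text_cfg_scale, image_cfg_scale):
    # out_cond:     prediction conditioned on (edit text, input image)
    # out_img_cond: prediction conditioned on (null text, input image)
    # out_uncond:   prediction conditioned on (null text, zeroed image latent)
    return (
        out_uncond
        + text_cfg_scale * (out_cond - out_img_cond)
        + image_cfg_scale * (out_img_cond - out_uncond)
    )
```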
--------------------------------------------------------------------------------
/edit_cli_batch.py:
--------------------------------------------------------------------------------
1 | '''
2 | * Copyright (c) 2023 Salesforce, Inc.
3 | * All rights reserved.
4 | * SPDX-License-Identifier: Apache License 2.0
5 | * For full license text, see LICENSE.txt file in the repo root or http://www.apache.org/licenses/
6 | * By Shu Zhang
7 | * Modified from InstructPix2Pix repo: https://github.com/timothybrooks/instruct-pix2pix
8 | * Copyright (c) 2023 Timothy Brooks, Aleksander Holynski, Alexei A. Efros. All rights reserved.
9 | '''
10 |
11 | from __future__ import annotations
12 |
13 | import math
14 | import os
15 | import random
16 | import sys
17 | from argparse import ArgumentParser
18 |
19 | import einops
20 | import k_diffusion as K
21 | import numpy as np
22 | import torch
23 | import torch.nn as nn
24 | import glob
25 | import re
26 | import jsonlines
27 | from einops import rearrange
28 | from omegaconf import OmegaConf
29 | from PIL import Image, ImageOps
30 | from torch import autocast
31 |
32 | sys.path.append("./stable_diffusion")
33 |
34 | from stable_diffusion.ldm.util import instantiate_from_config
35 |
36 |
37 | class CFGDenoiser(nn.Module):
38 | def __init__(self, model):
39 | super().__init__()
40 | self.inner_model = model
41 |
42 | def forward(self, z, sigma, cond, uncond, text_cfg_scale, image_cfg_scale):
43 | cfg_z = einops.repeat(z, "1 ... -> n ...", n=3)
44 | cfg_sigma = einops.repeat(sigma, "1 ... -> n ...", n=3)
45 | cfg_cond = {
46 | "c_crossattn": [torch.cat([cond["c_crossattn"][0], uncond["c_crossattn"][0], uncond["c_crossattn"][0]])],
47 | "c_concat": [torch.cat([cond["c_concat"][0], cond["c_concat"][0], uncond["c_concat"][0]])],
48 | }
49 | out_cond, out_img_cond, out_uncond = self.inner_model(cfg_z, cfg_sigma, cond=cfg_cond).chunk(3)
50 | return out_uncond + text_cfg_scale * (out_cond - out_img_cond) + image_cfg_scale * (out_img_cond - out_uncond)
51 |
52 |
53 | def load_model_from_config(config, ckpt, vae_ckpt=None, verbose=False):
54 | print(f"Loading model from {ckpt}")
55 | pl_sd = torch.load(ckpt, map_location="cpu")
56 | if "global_step" in pl_sd:
57 | print(f"Global Step: {pl_sd['global_step']}")
58 | sd = pl_sd["state_dict"]
59 | if vae_ckpt is not None:
60 | print(f"Loading VAE from {vae_ckpt}")
61 | vae_sd = torch.load(vae_ckpt, map_location="cpu")["state_dict"]
62 | sd = {
63 | k: vae_sd[k[len("first_stage_model.") :]] if k.startswith("first_stage_model.") else v
64 | for k, v in sd.items()
65 | }
66 | model = instantiate_from_config(config.model)
67 | m, u = model.load_state_dict(sd, strict=False)
68 | if len(m) > 0 and verbose:
69 | print("missing keys:")
70 | print(m)
71 | if len(u) > 0 and verbose:
72 | print("unexpected keys:")
73 | print(u)
74 | return model
75 |
76 |
77 | def main():
78 | parser = ArgumentParser()
79 | parser.add_argument("--resolution", default=512, type=int)
80 | parser.add_argument("--steps", default=100, type=int)
81 | parser.add_argument("--config", default="configs/generate.yaml", type=str)
82 | parser.add_argument("--ckpt", default="checkpoints/hive.ckpt", type=str)
83 | parser.add_argument("--vae-ckpt", default=None, type=str)
84 | parser.add_argument("--output_dir", required=True, type=str)
85 | parser.add_argument("--jsonl_file", required=True, type=str)
86 | parser.add_argument("--image_dir", required=True, default="data/evaluation/", type=str)
87 | parser.add_argument("--cfg-text", default=7.5, type=float)
88 | parser.add_argument("--cfg-image", default=1.5, type=float)
89 | parser.add_argument("--seed", default=100, type=int)
90 | args = parser.parse_args()
91 |
92 | config = OmegaConf.load(args.config)
93 | model = load_model_from_config(config, args.ckpt, args.vae_ckpt)
94 | model.eval().cuda()
95 | model_wrap = K.external.CompVisDenoiser(model)
96 | model_wrap_cfg = CFGDenoiser(model_wrap)
97 | null_token = model.get_learned_conditioning([""])
98 |
99 | seed = random.randint(0, 100000) if args.seed is None else args.seed
100 | output_dir = args.output_dir
101 | if not os.path.exists(output_dir):
102 | os.makedirs(output_dir)
103 | image_dir = args.image_dir
104 | instructions = []
105 | image_paths = []
106 | with jsonlines.open(args.jsonl_file) as reader:
107 | for ll in reader:
108 | instructions.append(ll["instruction"])
109 | image_paths.append(os.path.join(image_dir, ll['source_img']))
110 |
111 | for i, instruction in enumerate(instructions):
112 | output_image = os.path.join(output_dir, f'instruct_{i}.png')
113 | input_image = Image.open(image_paths[i]).convert("RGB")
114 | width, height = input_image.size
115 | factor = args.resolution / max(width, height)
116 | factor = math.ceil(min(width, height) * factor / 64) * 64 / min(width, height)
117 | width = int((width * factor) // 64) * 64
118 | height = int((height * factor) // 64) * 64
119 | input_image = ImageOps.fit(input_image, (width, height), method=Image.Resampling.LANCZOS)
120 | edit = instructions[i]
121 |
122 | with torch.no_grad(), autocast("cuda"), model.ema_scope():
123 | cond = {}
124 | cond["c_crossattn"] = [model.get_learned_conditioning([edit])]
125 | input_image = 2 * torch.tensor(np.array(input_image)).float() / 255 - 1
126 | input_image = rearrange(input_image, "h w c -> 1 c h w").to(model.device)
127 | cond["c_concat"] = [model.encode_first_stage(input_image).mode()]
128 |
129 | uncond = {}
130 | uncond["c_crossattn"] = [null_token]
131 | uncond["c_concat"] = [torch.zeros_like(cond["c_concat"][0])]
132 |
133 | sigmas = model_wrap.get_sigmas(args.steps)
134 |
135 | extra_args = {
136 | "cond": cond,
137 | "uncond": uncond,
138 | "text_cfg_scale": args.cfg_text,
139 | "image_cfg_scale": args.cfg_image,
140 | }
141 | torch.manual_seed(seed)
142 | z = torch.randn_like(cond["c_concat"][0]) * sigmas[0]
143 | z = K.sampling.sample_euler_ancestral(model_wrap_cfg, z, sigmas, extra_args=extra_args)
144 | x = model.decode_first_stage(z)
145 | x = torch.clamp((x + 1.0) / 2.0, min=0.0, max=1.0)
146 | x = 255.0 * rearrange(x, "1 c h w -> h w c")
147 | edited_image = Image.fromarray(x.type(torch.uint8).cpu().numpy())
148 | edited_image.save(output_image)
149 |
150 |
151 | if __name__ == "__main__":
152 | main()
153 |
--------------------------------------------------------------------------------
/edit_cli_batch_rw_label.py:
--------------------------------------------------------------------------------
1 | '''
2 | * Copyright (c) 2023 Salesforce, Inc.
3 | * All rights reserved.
4 | * SPDX-License-Identifier: Apache License 2.0
5 | * For full license text, see LICENSE.txt file in the repo root or http://www.apache.org/licenses/
6 | * By Shu Zhang
7 | * Modified from InstructPix2Pix repo: https://github.com/timothybrooks/instruct-pix2pix
8 | * Copyright (c) 2023 Timothy Brooks, Aleksander Holynski, Alexei A. Efros. All rights reserved.
9 | '''
10 |
11 | from __future__ import annotations
12 |
13 | import math
14 | import os
15 | import random
16 | import sys
17 | from argparse import ArgumentParser
18 |
19 | import einops
20 | import k_diffusion as K
21 | import numpy as np
22 | import torch
23 | import torch.nn as nn
24 | import glob
25 | import re
26 | import jsonlines
27 | from einops import rearrange
28 | from omegaconf import OmegaConf
29 | from PIL import Image, ImageOps
30 | from torch import autocast
31 |
32 | sys.path.append("./stable_diffusion")
33 |
34 | from stable_diffusion.ldm.util import instantiate_from_config
35 |
36 |
37 | class CFGDenoiser(nn.Module):
38 | def __init__(self, model):
39 | super().__init__()
40 | self.inner_model = model
41 |
42 | def forward(self, z, sigma, cond, uncond, text_cfg_scale, image_cfg_scale):
43 | cfg_z = einops.repeat(z, "1 ... -> n ...", n=3)
44 | cfg_sigma = einops.repeat(sigma, "1 ... -> n ...", n=3)
45 | cfg_cond = {
46 | "c_crossattn": [torch.cat([cond["c_crossattn"][0], uncond["c_crossattn"][0], uncond["c_crossattn"][0]])],
47 | "c_concat": [torch.cat([cond["c_concat"][0], cond["c_concat"][0], uncond["c_concat"][0]])],
48 | }
49 | out_cond, out_img_cond, out_uncond = self.inner_model(cfg_z, cfg_sigma, cond=cfg_cond).chunk(3)
50 | return out_uncond + text_cfg_scale * (out_cond - out_img_cond) + image_cfg_scale * (out_img_cond - out_uncond)
51 |
52 |
53 | def load_model_from_config(config, ckpt, vae_ckpt=None, verbose=False):
54 | print(f"Loading model from {ckpt}")
55 | pl_sd = torch.load(ckpt, map_location="cpu")
56 | if "global_step" in pl_sd:
57 | print(f"Global Step: {pl_sd['global_step']}")
58 | sd = pl_sd["state_dict"]
59 | if vae_ckpt is not None:
60 | print(f"Loading VAE from {vae_ckpt}")
61 | vae_sd = torch.load(vae_ckpt, map_location="cpu")["state_dict"]
62 | sd = {
63 | k: vae_sd[k[len("first_stage_model.") :]] if k.startswith("first_stage_model.") else v
64 | for k, v in sd.items()
65 | }
66 | model = instantiate_from_config(config.model)
67 | m, u = model.load_state_dict(sd, strict=False)
68 | if len(m) > 0 and verbose:
69 | print("missing keys:")
70 | print(m)
71 | if len(u) > 0 and verbose:
72 | print("unexpected keys:")
73 | print(u)
74 | return model
75 |
76 |
77 | def main():
78 | parser = ArgumentParser()
79 | parser.add_argument("--resolution", default=512, type=int)
80 | parser.add_argument("--steps", default=100, type=int)
81 | parser.add_argument("--config", default="configs/generate.yaml", type=str)
82 | parser.add_argument("--ckpt", default="checkpoints/hive_rw_label.ckpt", type=str)
83 | parser.add_argument("--vae-ckpt", default=None, type=str)
84 | parser.add_argument("--output_dir", required=True, type=str)
85 | parser.add_argument("--jsonl_file", required=True, type=str)
86 | parser.add_argument("--image_dir", required=True, default="data/evaluation/", type=str)
87 | parser.add_argument("--cfg-text", default=7.5, type=float)
88 | parser.add_argument("--cfg-image", default=1.5, type=float)
89 | parser.add_argument("--seed", default=100, type=int)
90 | args = parser.parse_args()
91 |
92 | config = OmegaConf.load(args.config)
93 | model = load_model_from_config(config, args.ckpt, args.vae_ckpt)
94 | model.eval().cuda()
95 | model_wrap = K.external.CompVisDenoiser(model)
96 | model_wrap_cfg = CFGDenoiser(model_wrap)
97 | null_token = model.get_learned_conditioning([""])
98 |
99 | seed = random.randint(0, 100000) if args.seed is None else args.seed
100 | output_dir = args.output_dir
101 | if not os.path.exists(output_dir):
102 | os.makedirs(output_dir)
103 | image_dir = args.image_dir
104 | instructions = []
105 | image_paths = []
106 | with jsonlines.open(args.jsonl_file) as reader:
107 | for ll in reader:
108 | instructions.append(ll["instruction"])
109 | image_paths.append(os.path.join(image_dir, ll['source_img']))
110 |
111 | for i, instruction in enumerate(instructions):
112 | output_image = os.path.join(output_dir, f'instruct_{i}.png')
113 | input_image = Image.open(image_paths[i]).convert("RGB")
114 | width, height = input_image.size
115 | factor = args.resolution / max(width, height)
116 | factor = math.ceil(min(width, height) * factor / 64) * 64 / min(width, height)
117 | width = int((width * factor) // 64) * 64
118 | height = int((height * factor) // 64) * 64
119 | input_image = ImageOps.fit(input_image, (width, height), method=Image.Resampling.LANCZOS)
120 | edit = instructions[i]
121 | edit = edit + ', ' + f'image quality is five out of five'
122 |
123 | with torch.no_grad(), autocast("cuda"), model.ema_scope():
124 | cond = {}
125 | cond["c_crossattn"] = [model.get_learned_conditioning([edit])]
126 | input_image = 2 * torch.tensor(np.array(input_image)).float() / 255 - 1
127 | input_image = rearrange(input_image, "h w c -> 1 c h w").to(model.device)
128 | cond["c_concat"] = [model.encode_first_stage(input_image).mode()]
129 |
130 | uncond = {}
131 | uncond["c_crossattn"] = [null_token]
132 | uncond["c_concat"] = [torch.zeros_like(cond["c_concat"][0])]
133 |
134 | sigmas = model_wrap.get_sigmas(args.steps)
135 |
136 | extra_args = {
137 | "cond": cond,
138 | "uncond": uncond,
139 | "text_cfg_scale": args.cfg_text,
140 | "image_cfg_scale": args.cfg_image,
141 | }
142 | torch.manual_seed(seed)
143 | z = torch.randn_like(cond["c_concat"][0]) * sigmas[0]
144 | z = K.sampling.sample_euler_ancestral(model_wrap_cfg, z, sigmas, extra_args=extra_args)
145 | x = model.decode_first_stage(z)
146 | x = torch.clamp((x + 1.0) / 2.0, min=0.0, max=1.0)
147 | x = 255.0 * rearrange(x, "1 c h w -> h w c")
148 | edited_image = Image.fromarray(x.type(torch.uint8).cpu().numpy())
149 | edited_image.save(output_image)
150 |
151 |
152 | if __name__ == "__main__":
153 | main()
154 |
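The only functional difference between this script and ```edit_cli_batch.py``` is the conditional-reward prompt: the instruction is suffixed with the top quality label before text conditioning (line 121 above). A minimal sketch of that prompt construction:

```
# Sketch: the conditional-reward prompt format used by the *_rw_label scripts.
def add_reward_label(instruction: str) -> str:
    # Both edit_cli_rw_label.py and edit_cli_batch_rw_label.py always request the best label.
    return instruction + ", image quality is five out of five"
```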
--------------------------------------------------------------------------------
/edit_cli_rw_label.py:
--------------------------------------------------------------------------------
1 | '''
2 | * Copyright (c) 2023 Salesforce, Inc.
3 | * All rights reserved.
4 | * SPDX-License-Identifier: Apache License 2.0
5 | * For full license text, see LICENSE.txt file in the repo root or http://www.apache.org/licenses/
6 | * By Shu Zhang
7 | * Modified from InstructPix2Pix repo: https://github.com/timothybrooks/instruct-pix2pix
8 | * Copyright (c) 2023 Timothy Brooks, Aleksander Holynski, Alexei A. Efros. All rights reserved.
9 | '''
10 |
11 | from __future__ import annotations
12 |
13 | import math
14 | import random
15 | import sys
16 | from argparse import ArgumentParser
17 |
18 | import einops
19 | import k_diffusion as K
20 | import numpy as np
21 | import torch
22 | import torch.nn as nn
23 | from einops import rearrange
24 | from omegaconf import OmegaConf
25 | from PIL import Image, ImageOps
26 | from torch import autocast
27 |
28 | sys.path.append("./stable_diffusion")
29 |
30 | from stable_diffusion.ldm.util import instantiate_from_config
31 |
32 |
33 | class CFGDenoiser(nn.Module):
34 | def __init__(self, model):
35 | super().__init__()
36 | self.inner_model = model
37 |
38 | def forward(self, z, sigma, cond, uncond, text_cfg_scale, image_cfg_scale):
39 | cfg_z = einops.repeat(z, "1 ... -> n ...", n=3)
40 | cfg_sigma = einops.repeat(sigma, "1 ... -> n ...", n=3)
41 | cfg_cond = {
42 | "c_crossattn": [torch.cat([cond["c_crossattn"][0], uncond["c_crossattn"][0], uncond["c_crossattn"][0]])],
43 | "c_concat": [torch.cat([cond["c_concat"][0], cond["c_concat"][0], uncond["c_concat"][0]])],
44 | }
45 | out_cond, out_img_cond, out_uncond = self.inner_model(cfg_z, cfg_sigma, cond=cfg_cond).chunk(3)
46 | return out_uncond + text_cfg_scale * (out_cond - out_img_cond) + image_cfg_scale * (out_img_cond - out_uncond)
47 |
48 |
49 | def load_model_from_config(config, ckpt, vae_ckpt=None, verbose=False):
50 | print(f"Loading model from {ckpt}")
51 | pl_sd = torch.load(ckpt, map_location="cpu")
52 | if "global_step" in pl_sd:
53 | print(f"Global Step: {pl_sd['global_step']}")
54 | sd = pl_sd["state_dict"]
55 | if vae_ckpt is not None:
56 | print(f"Loading VAE from {vae_ckpt}")
57 | vae_sd = torch.load(vae_ckpt, map_location="cpu")["state_dict"]
58 | sd = {
59 | k: vae_sd[k[len("first_stage_model.") :]] if k.startswith("first_stage_model.") else v
60 | for k, v in sd.items()
61 | }
62 | model = instantiate_from_config(config.model)
63 | m, u = model.load_state_dict(sd, strict=False)
64 | if len(m) > 0 and verbose:
65 | print("missing keys:")
66 | print(m)
67 | if len(u) > 0 and verbose:
68 | print("unexpected keys:")
69 | print(u)
70 | return model
71 |
72 |
73 | def main():
74 | parser = ArgumentParser()
75 | parser.add_argument("--resolution", default=512, type=int)
76 | parser.add_argument("--steps", default=100, type=int)
77 | parser.add_argument("--config", default="configs/generate.yaml", type=str)
78 | parser.add_argument("--ckpt", default="checkpoints/hive_rw_label.ckpt", type=str)
79 | parser.add_argument("--vae-ckpt", default=None, type=str)
80 | parser.add_argument("--input", required=True, type=str)
81 | parser.add_argument("--output", required=True, type=str)
82 | parser.add_argument("--edit", required=True, type=str)
83 | parser.add_argument("--cfg-text", default=7.5, type=float)
84 | parser.add_argument("--cfg-image", default=1.5, type=float)
85 | parser.add_argument("--seed", type=int)
86 | args = parser.parse_args()
87 |
88 | config = OmegaConf.load(args.config)
89 | model = load_model_from_config(config, args.ckpt, args.vae_ckpt)
90 | model.eval().cuda()
91 | model_wrap = K.external.CompVisDenoiser(model)
92 | model_wrap_cfg = CFGDenoiser(model_wrap)
93 | null_token = model.get_learned_conditioning([""])
94 |
95 | seed = random.randint(0, 100000) if args.seed is None else args.seed
96 | input_image = Image.open(args.input).convert("RGB")
97 | width, height = input_image.size
98 | factor = args.resolution / max(width, height)
99 | factor = math.ceil(min(width, height) * factor / 64) * 64 / min(width, height)
100 | width = int((width * factor) // 64) * 64
101 | height = int((height * factor) // 64) * 64
102 | input_image = ImageOps.fit(input_image, (width, height), method=Image.Resampling.LANCZOS)
103 |
104 | if args.edit == "":
105 | input_image.save(args.output)
106 | return
107 |
108 | with torch.no_grad(), autocast("cuda"), model.ema_scope():
109 | cond = {}
110 | edit = args.edit + ', ' + f'image quality is five out of five'
111 | cond["c_crossattn"] = [model.get_learned_conditioning([edit])]
112 | input_image = 2 * torch.tensor(np.array(input_image)).float() / 255 - 1
113 | input_image = rearrange(input_image, "h w c -> 1 c h w").to(model.device)
114 | cond["c_concat"] = [model.encode_first_stage(input_image).mode()]
115 |
116 | uncond = {}
117 | uncond["c_crossattn"] = [null_token]
118 | uncond["c_concat"] = [torch.zeros_like(cond["c_concat"][0])]
119 |
120 | sigmas = model_wrap.get_sigmas(args.steps)
121 |
122 | extra_args = {
123 | "cond": cond,
124 | "uncond": uncond,
125 | "text_cfg_scale": args.cfg_text,
126 | "image_cfg_scale": args.cfg_image,
127 | }
128 | torch.manual_seed(seed)
129 | z = torch.randn_like(cond["c_concat"][0]) * sigmas[0]
130 | z = K.sampling.sample_euler_ancestral(model_wrap_cfg, z, sigmas, extra_args=extra_args)
131 | x = model.decode_first_stage(z)
132 | x = torch.clamp((x + 1.0) / 2.0, min=0.0, max=1.0)
133 | x = 255.0 * rearrange(x, "1 c h w -> h w c")
134 | edited_image = Image.fromarray(x.type(torch.uint8).cpu().numpy())
135 | edited_image.save(args.output)
136 |
137 |
138 | if __name__ == "__main__":
139 | main()
140 |
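
For reference, `CFGDenoiser.forward` above evaluates the denoiser on three stacked conditionings (edit text plus input image, input image only, fully unconditional) and blends them with the two guidance scales. Written out, with $c_T$ the edit instruction, $c_I$ the input-image conditioning, and $s_T$, $s_I$ the text and image scales:

$$
\tilde{\epsilon}(z, c_I, c_T) = \epsilon(z, \varnothing, \varnothing)
+ s_T \bigl(\epsilon(z, c_I, c_T) - \epsilon(z, c_I, \varnothing)\bigr)
+ s_I \bigl(\epsilon(z, c_I, \varnothing) - \epsilon(z, \varnothing, \varnothing)\bigr)
$$
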
--------------------------------------------------------------------------------
/edit_dataset.py:
--------------------------------------------------------------------------------
1 | '''
2 | * Copyright (c) 2023 Salesforce, Inc.
3 | * All rights reserved.
4 | * SPDX-License-Identifier: Apache License 2.0
5 | * For full license text, see LICENSE.txt file in the repo root or http://www.apache.org/licenses/
6 | * By Shu Zhang
7 | * Modified from InstructPix2Pix repo: https://github.com/timothybrooks/instruct-pix2pix
8 | * Copyright (c) 2023 Timothy Brooks, Aleksander Holynski, Alexei A. Efros. All rights reserved.
9 | '''
10 |
11 | from __future__ import annotations
12 | import json
13 | import math
14 | from pathlib import Path
15 | from typing import Any
16 |
17 | import numpy as np
18 | import torch
19 | import torchvision
20 | from einops import rearrange
21 | from PIL import Image
22 | from torch.utils.data import Dataset
23 | import jsonlines
24 | from collections import deque
25 |
26 |
27 | class EditDataset(Dataset):
28 | def __init__(
29 | self,
30 | path_instructpix2pix: str,
31 | path_hive_0: str,
32 | path_hive_1: str,
33 | path_hive_2: str,
34 | split: str = "train",
35 | splits: tuple[float, float, float] = (0.9, 0.05, 0.05),
36 | min_resize_res: int = 256,
37 | max_resize_res: int = 256,
38 | crop_res: int = 256,
39 | flip_prob: float = 0.0,
40 | ):
41 | assert split in ("train", "val", "test")
42 | assert sum(splits) == 1
43 | self.path_instructpix2pix = path_instructpix2pix
44 | self.path_hive_0 = path_hive_0
45 | self.path_hive_1 = path_hive_1
46 | self.path_hive_2 = path_hive_2
47 | self.min_resize_res = min_resize_res
48 | self.max_resize_res = max_resize_res
49 | self.crop_res = crop_res
50 | self.flip_prob = flip_prob
51 | self.seeds = []
52 | self.instructions = []
53 | self.source_imgs = []
54 | self.edited_imgs = []
55 | # load instructpix2pix dataset
56 | with open(Path(self.path_instructpix2pix, "seeds.json")) as f:
57 | seeds = json.load(f)
58 | split_0, split_1 = {
59 | "train": (0.0, splits[0]),
60 | "val": (splits[0], splits[0] + splits[1]),
61 | "test": (splits[0] + splits[1], 1.0),
62 | }[split]
63 |
64 | idx_0 = math.floor(split_0 * len(seeds))
65 | idx_1 = math.floor(split_1 * len(seeds))
66 | seeds = seeds[idx_0:idx_1]
67 |
68 | for seed in seeds:
69 | seed = deque(seed)
70 | seed.appendleft('')
71 | seed.appendleft('instructpix2pix')
72 | self.seeds.append(list(seed))
73 |
74 |
75 | # load HIVE dataset first part
76 |
77 | cnt = 0
78 | with jsonlines.open(Path(self.path_hive_0, "training_cycle.jsonl")) as reader:
79 | for ll in reader:
80 | self.instructions.append(ll['instruction'])
81 | self.source_imgs.append(ll['source_img'])
82 | self.edited_imgs.append(ll['edited_img'])
83 | self.seeds.append(['hive_0', '', '', [cnt]])
84 | cnt += 1
85 |
86 | # load HIVE dataset second part
87 | with open(Path(self.path_hive_1, "seeds.json")) as f:
88 | seeds = json.load(f)
89 | for seed in seeds:
90 | seed = deque(seed)
91 | seed.appendleft('hive_1')
92 | self.seeds.append(list(seed))
93 | # load HIVE dataset third part
94 | with open(Path(self.path_hive_2, "seeds.json")) as f:
95 | seeds = json.load(f)
96 | for seed in seeds:
97 | seed = deque(seed)
98 | seed.appendleft('hive_2')
99 | self.seeds.append(list(seed))
100 |
101 | def __len__(self) -> int:
102 | return len(self.seeds)
103 |
104 | def __getitem__(self, i: int) -> dict[str, Any]:
105 |
106 | name_0, name_1, name_2, seeds = self.seeds[i]
107 | if name_0 == 'instructpix2pix':
108 | propt_dir = Path(self.path_instructpix2pix, name_2)
109 | seed = seeds[torch.randint(0, len(seeds), ()).item()]
110 | with open(propt_dir.joinpath("prompt.json")) as fp:
111 | prompt = json.load(fp)["edit"]
112 | image_0 = Image.open(propt_dir.joinpath(f"{seed}_0.jpg"))
113 | image_1 = Image.open(propt_dir.joinpath(f"{seed}_1.jpg"))
114 | elif name_0 == 'hive_1':
115 | propt_dir = Path(self.path_hive_1, name_1, name_2)
116 | seed = seeds[torch.randint(0, len(seeds), ()).item()]
117 | with open(propt_dir.joinpath("prompt.json")) as fp:
118 | prompt = json.load(fp)["instruction"]
119 | image_0 = Image.open(propt_dir.joinpath(f"{seed}_0.jpg"))
120 | image_1 = Image.open(propt_dir.joinpath(f"{seed}_1.jpg"))
121 | elif name_0 == 'hive_2':
122 | propt_dir = Path(self.path_hive_2, name_1, name_2)
123 | seed = seeds[torch.randint(0, len(seeds), ()).item()]
124 | with open(propt_dir.joinpath("prompt.json")) as fp:
125 | prompt = json.load(fp)["instruction"]
126 | image_0 = Image.open(propt_dir.joinpath(f"{seed}_0.jpg"))
127 | image_1 = Image.open(propt_dir.joinpath(f"{seed}_1.jpg"))
128 | else:
129 | j = seeds[0]
130 | image_0 = Image.open(self.source_imgs[j])
131 | image_1 = Image.open(self.edited_imgs[j])
132 | prompt = self.instructions[j]
133 |
134 |         resize_res = torch.randint(self.min_resize_res, self.max_resize_res + 1, ()).item()
135 |         image_0 = image_0.resize((resize_res, resize_res), Image.Resampling.LANCZOS)
136 |         image_1 = image_1.resize((resize_res, resize_res), Image.Resampling.LANCZOS)
137 |
138 | image_0 = rearrange(2 * torch.tensor(np.array(image_0)).float() / 255 - 1, "h w c -> c h w")
139 | image_1 = rearrange(2 * torch.tensor(np.array(image_1)).float() / 255 - 1, "h w c -> c h w")
140 |
141 | crop = torchvision.transforms.RandomCrop(self.crop_res)
142 | flip = torchvision.transforms.RandomHorizontalFlip(float(self.flip_prob))
143 | image_0, image_1 = flip(crop(torch.cat((image_0, image_1)))).chunk(2)
144 |
145 | return dict(edited=image_1, edit=dict(c_concat=image_0, c_crossattn=prompt))
146 |
147 |
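
A minimal usage sketch for `EditDataset`; the directory paths below are placeholders standing in for wherever the download scripts put the data, not paths this repo guarantees:

```python
# Sketch only: the four dataset paths are placeholders.
from torch.utils.data import DataLoader

from edit_dataset import EditDataset

dataset = EditDataset(
    path_instructpix2pix="data_instructpix2pix/clip-filtered-dataset",  # placeholder path
    path_hive_0="data/training/hive_part0",                             # placeholder path
    path_hive_1="data/training/hive_part1",                             # placeholder path
    path_hive_2="data/training/hive_part2",                             # placeholder path
    split="train",
)
loader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=2)

batch = next(iter(loader))
print(batch["edited"].shape)             # edited images: (4, 3, crop_res, crop_res)
print(batch["edit"]["c_concat"].shape)   # source images, same shape
print(batch["edit"]["c_crossattn"][:2])  # edit instructions, collated as a list of strings
```
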
--------------------------------------------------------------------------------
/environment.yaml:
--------------------------------------------------------------------------------
1 | # File modified by authors of InstructPix2Pix from original (https://github.com/CompVis/stable-diffusion).
2 | # See more details in LICENSE.
3 |
4 | name: hive
5 | channels:
6 | - pytorch
7 | - defaults
8 | dependencies:
9 | - python=3.8.5
10 | - pip=20.3
11 | - cudatoolkit=11.3
12 | - pytorch=1.11.0
13 | - torchvision=0.12.0
14 | - numpy=1.19.2
15 | - pip:
16 | - albumentations==0.4.3
17 | - datasets==2.8.0
18 | - diffusers
19 | - opencv-python==4.1.2.30
20 | - pudb==2019.2
21 | - invisible-watermark
22 | - imageio==2.9.0
23 | - imageio-ffmpeg==0.4.2
24 | - pytorch-lightning==1.4.2
25 | - omegaconf==2.1.1
26 | - test-tube>=0.7.5
27 | - streamlit>=0.73.1
28 | - einops==0.3.0
29 | - torch-fidelity==0.3.0
30 | - transformers==4.19.2
31 | - torchmetrics==0.6.0
32 | - kornia==0.6
33 | - open_clip_torch==2.0.2
34 | - -e git+https://github.com/CompVis/taming-transformers.git@master#egg=taming-transformers
35 | - -e git+https://github.com/openai/CLIP.git@main#egg=clip
36 | - openai
37 | - gradio
38 | - seaborn
39 | - jsonlines
40 | - git+https://github.com/crowsonkb/k-diffusion.git
41 |
--------------------------------------------------------------------------------
/imgs/example1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/salesforce/HIVE/a14ba7f26779e7d3dd506a4f5ab203589187e789/imgs/example1.jpg
--------------------------------------------------------------------------------
/imgs/example2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/salesforce/HIVE/a14ba7f26779e7d3dd506a4f5ab203589187e789/imgs/example2.jpg
--------------------------------------------------------------------------------
/imgs/example3.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/salesforce/HIVE/a14ba7f26779e7d3dd506a4f5ab203589187e789/imgs/example3.jpg
--------------------------------------------------------------------------------
/imgs/results.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/salesforce/HIVE/a14ba7f26779e7d3dd506a4f5ab203589187e789/imgs/results.png
--------------------------------------------------------------------------------
/metrics/clip_similarity.py:
--------------------------------------------------------------------------------
1 | '''
2 | * Copyright (c) 2023 Salesforce, Inc.
3 | * All rights reserved.
4 | * SPDX-License-Identifier: Apache License 2.0
5 | * For full license text, see LICENSE.txt file in the repo root or http://www.apache.org/licenses/
6 | * By Shu Zhang
7 | * Redistributed from InstructPix2Pix repo: https://github.com/timothybrooks/instruct-pix2pix
8 | * Copyright (c) 2023 Timothy Brooks, Aleksander Holynski, Alexei A. Efros. All rights reserved.
9 | '''
10 |
11 | from __future__ import annotations
12 |
13 | import clip
14 | import torch
15 | import torch.nn as nn
16 | import torch.nn.functional as F
17 | from einops import rearrange
18 |
19 |
20 | class ClipSimilarity(nn.Module):
21 | def __init__(self, name: str = "ViT-L/14"):
22 | super().__init__()
23 | assert name in ("RN50", "RN101", "RN50x4", "RN50x16", "RN50x64", "ViT-B/32", "ViT-B/16", "ViT-L/14", "ViT-L/14@336px") # fmt: skip
24 | self.size = {"RN50x4": 288, "RN50x16": 384, "RN50x64": 448, "ViT-L/14@336px": 336}.get(name, 224)
25 |
26 | self.model, _ = clip.load(name, device="cpu", download_root="./")
27 | self.model.eval().requires_grad_(False)
28 |
29 | self.register_buffer("mean", torch.tensor((0.48145466, 0.4578275, 0.40821073)))
30 | self.register_buffer("std", torch.tensor((0.26862954, 0.26130258, 0.27577711)))
31 |
32 | def encode_text(self, text: list[str]) -> torch.Tensor:
33 | text = clip.tokenize(text, truncate=True).to(next(self.parameters()).device)
34 | text_features = self.model.encode_text(text)
35 | text_features = text_features / text_features.norm(dim=1, keepdim=True)
36 | return text_features
37 |
38 | def encode_image(self, image: torch.Tensor) -> torch.Tensor: # Input images in range [0, 1].
39 | image = F.interpolate(image.float(), size=self.size, mode="bicubic", align_corners=False)
40 | image = image - rearrange(self.mean, "c -> 1 c 1 1")
41 | image = image / rearrange(self.std, "c -> 1 c 1 1")
42 | image_features = self.model.encode_image(image)
43 | image_features = image_features / image_features.norm(dim=1, keepdim=True)
44 | return image_features
45 |
46 | def forward(
47 | self, image_0: torch.Tensor, image_1: torch.Tensor, text_0: list[str], text_1: list[str]
48 | ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
49 | image_features_0 = self.encode_image(image_0)
50 | image_features_1 = self.encode_image(image_1)
51 | text_features_0 = self.encode_text(text_0)
52 | text_features_1 = self.encode_text(text_1)
53 | sim_0 = F.cosine_similarity(image_features_0, text_features_0)
54 | sim_1 = F.cosine_similarity(image_features_1, text_features_1)
55 | sim_direction = F.cosine_similarity(image_features_1 - image_features_0, text_features_1 - text_features_0)
56 | sim_image = F.cosine_similarity(image_features_0, image_features_1)
57 | return sim_0, sim_1, sim_direction, sim_image
58 |
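
A short usage sketch for `ClipSimilarity`; the image tensors and captions are dummies and only illustrate the expected call pattern ((N, 3, H, W) images with values in [0, 1]):

```python
# Sketch: dummy inputs, real call pattern.
import torch

from clip_similarity import ClipSimilarity

clip_sim = ClipSimilarity().cuda()
image_before = torch.rand(1, 3, 512, 512).cuda()  # input image in [0, 1]
image_after = torch.rand(1, 3, 512, 512).cuda()   # edited image in [0, 1]

sim_0, sim_1, sim_direction, sim_image = clip_sim(
    image_before,
    image_after,
    ["a photo of a house"],            # caption describing the input
    ["a photo of a house in winter"],  # caption describing the desired result
)
# sim_direction: how well the image change follows the caption change.
# sim_image: how much of the original image content is preserved.
```
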
--------------------------------------------------------------------------------
/metrics/compute_metrics.py:
--------------------------------------------------------------------------------
1 | '''
2 | * Copyright (c) 2023 Salesforce, Inc.
3 | * All rights reserved.
4 | * SPDX-License-Identifier: Apache License 2.0
5 | * For full license text, see LICENSE.txt file in the repo root or http://www.apache.org/licenses/
6 | * By Shu Zhang
7 | * Redistributed from InstructPix2Pix repo: https://github.com/timothybrooks/instruct-pix2pix
8 | * Copyright (c) 2023 Timothy Brooks, Aleksander Holynski, Alexei A. Efros. All rights reserved.
9 | '''
10 |
11 | from __future__ import annotations
12 |
13 | import math
14 | import random
15 | import sys
16 | from argparse import ArgumentParser
17 |
18 | import einops
19 | import k_diffusion as K
20 | import numpy as np
21 | import torch
22 | import torch.nn as nn
23 | from tqdm.auto import tqdm
24 | from einops import rearrange
25 | from omegaconf import OmegaConf
26 | from PIL import Image, ImageOps
27 | from torch import autocast
28 |
29 | import json
30 | import matplotlib.pyplot as plt
31 | import seaborn
32 | from pathlib import Path
33 |
34 | sys.path.append("./")
35 |
36 | from clip_similarity import ClipSimilarity
37 | from edit_dataset import EditDatasetEval
38 |
39 | sys.path.append("./stable_diffusion")
40 |
41 | from ldm.util import instantiate_from_config
42 |
43 |
44 | class CFGDenoiser(nn.Module):
45 | def __init__(self, model):
46 | super().__init__()
47 | self.inner_model = model
48 |
49 | def forward(self, z, sigma, cond, uncond, text_cfg_scale, image_cfg_scale):
50 | cfg_z = einops.repeat(z, "1 ... -> n ...", n=3)
51 | cfg_sigma = einops.repeat(sigma, "1 ... -> n ...", n=3)
52 | cfg_cond = {
53 | "c_crossattn": [torch.cat([cond["c_crossattn"][0], uncond["c_crossattn"][0], uncond["c_crossattn"][0]])],
54 | "c_concat": [torch.cat([cond["c_concat"][0], cond["c_concat"][0], uncond["c_concat"][0]])],
55 | }
56 | out_cond, out_img_cond, out_uncond = self.inner_model(cfg_z, cfg_sigma, cond=cfg_cond).chunk(3)
57 | return out_uncond + text_cfg_scale * (out_cond - out_img_cond) + image_cfg_scale * (out_img_cond - out_uncond)
58 |
59 |
60 | def load_model_from_config(config, ckpt, vae_ckpt=None, verbose=False):
61 | print(f"Loading model from {ckpt}")
62 | pl_sd = torch.load(ckpt, map_location="cpu")
63 | if "global_step" in pl_sd:
64 | print(f"Global Step: {pl_sd['global_step']}")
65 | sd = pl_sd["state_dict"]
66 | if vae_ckpt is not None:
67 | print(f"Loading VAE from {vae_ckpt}")
68 | vae_sd = torch.load(vae_ckpt, map_location="cpu")["state_dict"]
69 | sd = {
70 | k: vae_sd[k[len("first_stage_model.") :]] if k.startswith("first_stage_model.") else v
71 | for k, v in sd.items()
72 | }
73 | model = instantiate_from_config(config.model)
74 | m, u = model.load_state_dict(sd, strict=False)
75 | if len(m) > 0 and verbose:
76 | print("missing keys:")
77 | print(m)
78 | if len(u) > 0 and verbose:
79 | print("unexpected keys:")
80 | print(u)
81 | return model
82 |
83 | class ImageEditor(nn.Module):
84 | def __init__(self, config, ckpt, vae_ckpt=None):
85 | super().__init__()
86 |
87 | config = OmegaConf.load(config)
88 | self.model = load_model_from_config(config, ckpt, vae_ckpt)
89 | self.model.eval().cuda()
90 | self.model_wrap = K.external.CompVisDenoiser(self.model)
91 | self.model_wrap_cfg = CFGDenoiser(self.model_wrap)
92 | self.null_token = self.model.get_learned_conditioning([""])
93 |
94 | def forward(
95 | self,
96 | image: torch.Tensor,
97 | edit: str,
98 | scale_txt: float = 7.5,
99 | scale_img: float = 1.0,
100 | steps: int = 100,
101 | ) -> torch.Tensor:
102 | assert image.dim() == 3
103 | assert image.size(1) % 64 == 0
104 | assert image.size(2) % 64 == 0
105 | with torch.no_grad(), autocast("cuda"), self.model.ema_scope():
106 | cond = {
107 | "c_crossattn": [self.model.get_learned_conditioning([edit])],
108 | "c_concat": [self.model.encode_first_stage(image[None]).mode()],
109 | }
110 | uncond = {
111 | "c_crossattn": [self.model.get_learned_conditioning([""])],
112 | "c_concat": [torch.zeros_like(cond["c_concat"][0])],
113 | }
114 | extra_args = {
115 | "uncond": uncond,
116 | "cond": cond,
117 | "image_cfg_scale": scale_img,
118 | "text_cfg_scale": scale_txt,
119 | }
120 | sigmas = self.model_wrap.get_sigmas(steps)
121 | x = torch.randn_like(cond["c_concat"][0]) * sigmas[0]
122 | x = K.sampling.sample_euler_ancestral(self.model_wrap_cfg, x, sigmas, extra_args=extra_args)
123 | x = self.model.decode_first_stage(x)[0]
124 | return x
125 |
126 |
127 | def compute_metrics(config,
128 | model_path,
129 | vae_ckpt,
130 | data_path,
131 | output_path,
132 | scales_img,
133 | scales_txt,
134 | num_samples = 5000,
135 | split = "test",
136 | steps = 50,
137 | res = 512,
138 | seed = 0):
139 | editor = ImageEditor(config, model_path, vae_ckpt).cuda()
140 | clip_similarity = ClipSimilarity().cuda()
141 |
142 |
143 |
144 | outpath = Path(output_path, f"n={num_samples}_p={split}_s={steps}_r={res}_e={seed}.jsonl")
145 | Path(output_path).mkdir(parents=True, exist_ok=True)
146 |
147 | for scale_txt in scales_txt:
148 | for scale_img in scales_img:
149 | dataset = EditDatasetEval(
150 | path=data_path,
151 | split=split,
152 | res=res
153 | )
154 | assert num_samples <= len(dataset)
155 | print(f'Processing t={scale_txt}, i={scale_img}')
156 | torch.manual_seed(seed)
157 | perm = torch.randperm(len(dataset))
158 | count = 0
159 | i = 0
160 |
161 | sim_0_avg = 0
162 | sim_1_avg = 0
163 | sim_direction_avg = 0
164 | sim_image_avg = 0
165 | count = 0
166 |
167 | pbar = tqdm(total=num_samples)
168 | while count < num_samples:
169 |
170 | idx = perm[i].item()
171 | sample = dataset[idx]
172 | i += 1
173 |
174 | gen = editor(sample["image_0"].cuda(), sample["edit"], scale_txt=scale_txt, scale_img=scale_img, steps=steps)
175 |
176 | sim_0, sim_1, sim_direction, sim_image = clip_similarity(
177 | sample["image_0"][None].cuda(), gen[None].cuda(), [sample["input_prompt"]], [sample["output_prompt"]]
178 | )
179 | sim_0_avg += sim_0.item()
180 | sim_1_avg += sim_1.item()
181 | sim_direction_avg += sim_direction.item()
182 | sim_image_avg += sim_image.item()
183 | count += 1
184 |                 pbar.update(1)
185 | pbar.close()
186 |
187 | sim_0_avg /= count
188 | sim_1_avg /= count
189 | sim_direction_avg /= count
190 | sim_image_avg /= count
191 |
192 | with open(outpath, "a") as f:
193 | f.write(f"{json.dumps(dict(sim_0=sim_0_avg, sim_1=sim_1_avg, sim_direction=sim_direction_avg, sim_image=sim_image_avg, num_samples=num_samples, split=split, scale_txt=scale_txt, scale_img=scale_img, steps=steps, res=res, seed=seed))}\n")
194 | return outpath
195 |
196 | def plot_metrics(metrics_file, output_path):
197 |
198 | with open(metrics_file, 'r') as f:
199 | data = [json.loads(line) for line in f]
200 |
201 | plt.rcParams.update({'font.size': 11.5})
202 | seaborn.set_style("darkgrid")
203 | plt.figure(figsize=(20.5* 0.7, 10.8* 0.7), dpi=200)
204 |
205 | x = [d["sim_direction"] for d in data]
206 | y = [d["sim_image"] for d in data]
207 |
208 | plt.plot(x, y, marker='o', linewidth=2, markersize=4)
209 |
210 | plt.xlabel("CLIP Text-Image Direction Similarity", labelpad=10)
211 | plt.ylabel("CLIP Image Similarity", labelpad=10)
212 |
213 | plt.savefig(Path(output_path) / Path("plot.pdf"), bbox_inches="tight")
214 |
215 | def main():
216 | parser = ArgumentParser()
217 | parser.add_argument("--resolution", default=512, type=int)
218 | parser.add_argument("--steps", default=100, type=int)
219 | parser.add_argument("--config", default="configs/generate.yaml", type=str)
220 | parser.add_argument("--output_path", default="analysis/", type=str)
221 | parser.add_argument("--ckpt", default="checkpoints/hive_v2_rw_condition.ckpt", type=str)
222 | parser.add_argument("--dataset", default="data/clip-filtered-dataset/", type=str)
223 | parser.add_argument("--vae-ckpt", default=None, type=str)
224 | args = parser.parse_args()
225 |
226 | scales_img = [1.0, 1.2, 1.4, 1.6, 1.8, 2.0, 2.2]
227 | scales_txt = [7.5]
228 |
229 | metrics_file = compute_metrics(
230 | args.config,
231 | args.ckpt,
232 | args.vae_ckpt,
233 | args.dataset,
234 | args.output_path,
235 | scales_img,
236 |         scales_txt,
237 |         steps=args.steps,
238 | )
239 |
240 | plot_metrics(metrics_file, args.output_path)
241 |
242 |
243 |
244 | if __name__ == "__main__":
245 | main()
246 |
--------------------------------------------------------------------------------
/scripts/download_checkpoints.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )
4 |
5 | mkdir -p $SCRIPT_DIR/../checkpoints
6 |
7 | wget https://huggingface.co/stabilityai/stable-diffusion-2-1-base/resolve/main/v2-1_512-ema-pruned.ckpt -O $SCRIPT_DIR/../checkpoints/v2-1_512-ema-pruned.ckpt
8 | gcloud storage cp gs://sfr-hive-data-research/checkpoints/hive_rw_condition.ckpt $SCRIPT_DIR/../checkpoints/hive_rw_condition.ckpt
9 | gcloud storage cp gs://sfr-hive-data-research/checkpoints/hive_v2_rw_condition.ckpt $SCRIPT_DIR/../checkpoints/hive_v2_rw_condition.ckpt
10 | gcloud storage cp gs://sfr-hive-data-research/checkpoints/hive_rw.ckpt $SCRIPT_DIR/../checkpoints/hive_rw.ckpt
11 | gcloud storage cp gs://sfr-hive-data-research/checkpoints/hive_v2_rw.ckpt $SCRIPT_DIR/../checkpoints/hive_v2_rw.ckpt
12 |
--------------------------------------------------------------------------------
/scripts/download_hive_data.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # Make data folder relative to script location
4 | SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )
5 |
6 | mkdir -p $SCRIPT_DIR/../data
7 |
8 | # Download HIVE training data
9 | gsutil cp -r gs://sfr-hive-data-research/data/training $SCRIPT_DIR/../data/training
10 |
11 | # Download HIVE evaluation data
12 |
13 | gcloud storage cp gs://sfr-hive-data-research/data/test.jsonl $SCRIPT_DIR/../data/test.jsonl
14 | gsutil cp -r gs://sfr-hive-data-research/data/evaluation $SCRIPT_DIR/../data/
--------------------------------------------------------------------------------
/scripts/download_instructpix2pix_data.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # Make data folder relative to script location
4 | SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )
5 |
6 | mkdir -p $SCRIPT_DIR/../data_instructpix2pix
7 |
8 | # Download InstructPix2Pix data (https://arxiv.org/pdf/2211.09800.pdf)
9 |
10 | # Copy text datasets
11 | wget -q --show-progress http://instruct-pix2pix.eecs.berkeley.edu/gpt-generated-prompts.jsonl -O $SCRIPT_DIR/../data_instructpix2pix/gpt-generated-prompts.jsonl
12 | wget -q --show-progress http://instruct-pix2pix.eecs.berkeley.edu/human-written-prompts.jsonl -O $SCRIPT_DIR/../data_instructpix2pix/human-written-prompts.jsonl
13 |
14 | # If dataset name isn't provided, exit.
15 | if [ -z "$1" ]
16 | then
17 | exit 0
18 | fi
19 |
20 | # Copy dataset files
21 | mkdir $SCRIPT_DIR/../data_instructpix2pix/$1
22 | wget -A zip,json -R "index.html*" -q --show-progress -r --no-parent http://instruct-pix2pix.eecs.berkeley.edu/$1/ -nd -P $SCRIPT_DIR/../data_instructpix2pix/$1/
23 |
24 | # Unzip to folders
25 | unzip $SCRIPT_DIR/../data_instructpix2pix/$1/\*.zip -d $SCRIPT_DIR/../data_instructpix2pix/$1/
26 |
27 | # Cleanup
28 | rm -f $SCRIPT_DIR/../data_instructpix2pix/$1/*.zip
29 | rm -f $SCRIPT_DIR/../data_instructpix2pix/$1/*.html
30 |
--------------------------------------------------------------------------------
/stable_diffusion/Stable_Diffusion_v1_Model_Card.md:
--------------------------------------------------------------------------------
1 | # Stable Diffusion v1 Model Card
2 | This model card focuses on the model associated with the Stable Diffusion model, available [here](https://github.com/CompVis/stable-diffusion).
3 |
4 | ## Model Details
5 | - **Developed by:** Robin Rombach, Patrick Esser
6 | - **Model type:** Diffusion-based text-to-image generation model
7 | - **Language(s):** English
8 | - **License:** [Proprietary](LICENSE)
9 | - **Model Description:** This is a model that can be used to generate and modify images based on text prompts. It is a [Latent Diffusion Model](https://arxiv.org/abs/2112.10752) that uses a fixed, pretrained text encoder ([CLIP ViT-L/14](https://arxiv.org/abs/2103.00020)) as suggested in the [Imagen paper](https://arxiv.org/abs/2205.11487).
10 | - **Resources for more information:** [GitHub Repository](https://github.com/CompVis/stable-diffusion), [Paper](https://arxiv.org/abs/2112.10752).
11 | - **Cite as:**
12 |
13 | @InProceedings{Rombach_2022_CVPR,
14 | author = {Rombach, Robin and Blattmann, Andreas and Lorenz, Dominik and Esser, Patrick and Ommer, Bj\"orn},
15 | title = {High-Resolution Image Synthesis With Latent Diffusion Models},
16 | booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
17 | month = {June},
18 | year = {2022},
19 | pages = {10684-10695}
20 | }
21 |
22 | # Uses
23 |
24 | ## Direct Use
25 | The model is intended for research purposes only. Possible research areas and
26 | tasks include
27 |
28 | - Safe deployment of models which have the potential to generate harmful content.
29 | - Probing and understanding the limitations and biases of generative models.
30 | - Generation of artworks and use in design and other artistic processes.
31 | - Applications in educational or creative tools.
32 | - Research on generative models.
33 |
34 | Excluded uses are described below.
35 |
36 | ### Misuse, Malicious Use, and Out-of-Scope Use
37 | _Note: This section is taken from the [DALLE-MINI model card](https://huggingface.co/dalle-mini/dalle-mini), but applies in the same way to Stable Diffusion v1_.
38 |
39 | The model should not be used to intentionally create or disseminate images that create hostile or alienating environments for people. This includes generating images that people would foreseeably find disturbing, distressing, or offensive; or content that propagates historical or current stereotypes.
40 |
41 | #### Out-of-Scope Use
42 | The model was not trained to produce factual or true representations of people or events; using it to generate such content is therefore out of scope for its abilities.
43 |
44 | #### Misuse and Malicious Use
45 | Using the model to generate content that is cruel to individuals is a misuse of this model. This includes, but is not limited to:
46 |
47 | - Generating demeaning, dehumanizing, or otherwise harmful representations of people or their environments, cultures, religions, etc.
48 | - Intentionally promoting or propagating discriminatory content or harmful stereotypes.
49 | - Impersonating individuals without their consent.
50 | - Sexual content without consent of the people who might see it.
51 | - Mis- and disinformation
52 | - Representations of egregious violence and gore
53 | - Sharing of copyrighted or licensed material in violation of its terms of use.
54 | - Sharing content that is an alteration of copyrighted or licensed material in violation of its terms of use.
55 |
56 | ## Limitations and Bias
57 |
58 | ### Limitations
59 |
60 | - The model does not achieve perfect photorealism
61 | - The model cannot render legible text
62 | - The model does not perform well on more difficult tasks which involve compositionality, such as rendering an image corresponding to “A red cube on top of a blue sphere”
63 | - Faces and people in general may not be generated properly.
64 | - The model was trained mainly with English captions and will not work as well in other languages.
65 | - The autoencoding part of the model is lossy
66 | - The model was trained on a large-scale dataset
67 | [LAION-5B](https://laion.ai/blog/laion-5b/) which contains adult material
68 | and is not fit for product use without additional safety mechanisms and
69 | considerations.
70 | - No additional measures were used to deduplicate the dataset. As a result, we observe some degree of memorization for images that are duplicated in the training data.
71 | The training data can be searched at [https://rom1504.github.io/clip-retrieval/](https://rom1504.github.io/clip-retrieval/) to possibly assist in the detection of memorized images.
72 |
73 | ### Bias
74 | While the capabilities of image generation models are impressive, they can also reinforce or exacerbate social biases.
75 | Stable Diffusion v1 was primarily trained on subsets of [LAION-2B(en)](https://laion.ai/blog/laion-5b/),
76 | which consists of images that are limited to English descriptions.
77 | Texts and images from communities and cultures that use other languages are likely to be insufficiently accounted for.
78 | This affects the overall output of the model, as white and western cultures are often set as the default. Further, the
79 | ability of the model to generate content with non-English prompts is significantly worse than with English-language prompts.
80 | Stable Diffusion v1 mirrors and exacerbates biases to such a degree that viewer discretion must be advised irrespective of the input or its intent.
81 |
82 |
83 | ## Training
84 |
85 | **Training Data**
86 | The model developers used the following dataset for training the model:
87 |
88 | - LAION-5B and subsets thereof (see next section)
89 |
90 | **Training Procedure**
91 | Stable Diffusion v1 is a latent diffusion model which combines an autoencoder with a diffusion model that is trained in the latent space of the autoencoder. During training,
92 |
93 | - Images are encoded through an encoder, which turns images into latent representations. The autoencoder uses a relative downsampling factor of f = 8 and maps images of shape H x W x 3 to latents of shape H/f x W/f x 4.
94 | - Text prompts are encoded through a ViT-L/14 text-encoder.
95 | - The non-pooled output of the text encoder is fed into the UNet backbone of the latent diffusion model via cross-attention.
96 | - The loss is a reconstruction objective between the noise that was added to the latent and the prediction made by the UNet.
97 |
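
As a rough, self-contained sketch of the shapes and objective listed above (not the training code; the noise-schedule values and the zero placeholder for the UNet output are illustrative only):

```python
import torch
import torch.nn.functional as F

# For f = 8, a 512x512 RGB image corresponds to a 4-channel 64x64 latent.
latent = torch.randn(1, 4, 512 // 8, 512 // 8)   # stand-in for the encoded image
context = torch.randn(1, 77, 768)                # stand-in for non-pooled CLIP ViT-L/14 features

# Noise the latent at a random timestep and regress the added noise.
alphas_cumprod = torch.cumprod(1.0 - torch.linspace(1e-4, 2e-2, 1000), dim=0)
t = torch.randint(0, 1000, (1,))
noise = torch.randn_like(latent)
a_bar = alphas_cumprod[t].view(-1, 1, 1, 1)
noisy_latent = a_bar.sqrt() * latent + (1.0 - a_bar).sqrt() * noise

predicted_noise = torch.zeros_like(noise)        # placeholder for unet(noisy_latent, t, context)
loss = F.mse_loss(predicted_noise, noise)        # reconstruction objective on the noise
```
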
98 | We currently provide the following checkpoints:
99 |
100 | - `sd-v1-1.ckpt`: 237k steps at resolution `256x256` on [laion2B-en](https://huggingface.co/datasets/laion/laion2B-en).
101 | 194k steps at resolution `512x512` on [laion-high-resolution](https://huggingface.co/datasets/laion/laion-high-resolution) (170M examples from LAION-5B with resolution `>= 1024x1024`).
102 | - `sd-v1-2.ckpt`: Resumed from `sd-v1-1.ckpt`.
103 | 515k steps at resolution `512x512` on [laion-aesthetics v2 5+](https://laion.ai/blog/laion-aesthetics/) (a subset of laion2B-en with estimated aesthetics score `> 5.0`, and additionally
104 | filtered to images with an original size `>= 512x512`, and an estimated watermark probability `< 0.5`. The watermark estimate is from the [LAION-5B](https://laion.ai/blog/laion-5b/) metadata, the aesthetics score is estimated using the [LAION-Aesthetics Predictor V2](https://github.com/christophschuhmann/improved-aesthetic-predictor)).
105 | - `sd-v1-3.ckpt`: Resumed from `sd-v1-2.ckpt`. 195k steps at resolution `512x512` on "laion-aesthetics v2 5+" and 10\% dropping of the text-conditioning to improve [classifier-free guidance sampling](https://arxiv.org/abs/2207.12598).
106 | - `sd-v1-4.ckpt`: Resumed from `sd-v1-2.ckpt`. 225k steps at resolution `512x512` on "laion-aesthetics v2 5+" and 10\% dropping of the text-conditioning to improve [classifier-free guidance sampling](https://arxiv.org/abs/2207.12598).
107 |
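
The "10% dropping of the text-conditioning" used for `sd-v1-3.ckpt` and `sd-v1-4.ckpt` amounts to replacing the caption with an empty string for a random 10% of training examples, so the same network also learns the unconditional prediction needed for classifier-free guidance. A minimal sketch (the helper name and rate handling are illustrative, not the actual training code):

```python
import random

def maybe_drop_caption(captions, p_drop=0.1):
    # With probability p_drop, train on the empty caption instead of the real one.
    return ["" if random.random() < p_drop else caption for caption in captions]

# At sampling time the two predictions are recombined with a guidance scale s:
#   eps = eps_uncond + s * (eps_cond - eps_uncond)
```
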
108 | - **Hardware:** 32 x 8 x A100 GPUs
109 | - **Optimizer:** AdamW
110 | - **Gradient Accumulations**: 2
111 | - **Batch:** 32 x 8 x 2 x 4 = 2048
112 | - **Learning rate:** warmup to 0.0001 for 10,000 steps and then kept constant
113 |
114 | ## Evaluation Results
115 | Evaluations with different classifier-free guidance scales (1.5, 2.0, 3.0, 4.0,
116 | 5.0, 6.0, 7.0, 8.0) and 50 PLMS sampling
117 | steps show the relative improvements of the checkpoints:
118 |
119 | 
120 |
121 | Evaluated using 50 PLMS steps and 10000 random prompts from the COCO2017 validation set at 512x512 resolution. Not optimized for FID scores.
122 |
123 | ## Environmental Impact
124 |
125 | **Stable Diffusion v1: Estimated Emissions**
126 | Based on the hardware, runtime, cloud provider, and compute region listed below, we estimate the following CO2 emissions using the [Machine Learning Impact calculator](https://mlco2.github.io/impact#compute) presented in [Lacoste et al. (2019)](https://arxiv.org/abs/1910.09700).
127 |
128 | - **Hardware Type:** A100 PCIe 40GB
129 | - **Hours used:** 150000
130 | - **Cloud Provider:** AWS
131 | - **Compute Region:** US-east
132 | - **Carbon Emitted (Power consumption x Time x Carbon produced based on location of power grid):** 11250 kg CO2 eq.
133 |
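
As a back-of-the-envelope check of the figures above, the implied average emission rate is

$$
\frac{11250~\text{kg CO}_2\text{eq}}{150000~\text{A100-hours}} \approx 0.075~\text{kg CO}_2\text{eq per GPU-hour}.
$$
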
134 | ## Citation
135 | @InProceedings{Rombach_2022_CVPR,
136 | author = {Rombach, Robin and Blattmann, Andreas and Lorenz, Dominik and Esser, Patrick and Ommer, Bj\"orn},
137 | title = {High-Resolution Image Synthesis With Latent Diffusion Models},
138 | booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
139 | month = {June},
140 | year = {2022},
141 | pages = {10684-10695}
142 | }
143 |
144 | *This model card was written by: Robin Rombach and Patrick Esser and is based on the [DALL-E Mini model card](https://huggingface.co/dalle-mini/dalle-mini).*
145 |
--------------------------------------------------------------------------------
/stable_diffusion/assets/a-painting-of-a-fire.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/salesforce/HIVE/a14ba7f26779e7d3dd506a4f5ab203589187e789/stable_diffusion/assets/a-painting-of-a-fire.png
--------------------------------------------------------------------------------
/stable_diffusion/assets/a-photograph-of-a-fire.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/salesforce/HIVE/a14ba7f26779e7d3dd506a4f5ab203589187e789/stable_diffusion/assets/a-photograph-of-a-fire.png
--------------------------------------------------------------------------------
/stable_diffusion/assets/a-shirt-with-a-fire-printed-on-it.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/salesforce/HIVE/a14ba7f26779e7d3dd506a4f5ab203589187e789/stable_diffusion/assets/a-shirt-with-a-fire-printed-on-it.png
--------------------------------------------------------------------------------
/stable_diffusion/assets/a-shirt-with-the-inscription-'fire'.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/salesforce/HIVE/a14ba7f26779e7d3dd506a4f5ab203589187e789/stable_diffusion/assets/a-shirt-with-the-inscription-'fire'.png
--------------------------------------------------------------------------------
/stable_diffusion/assets/a-watercolor-painting-of-a-fire.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/salesforce/HIVE/a14ba7f26779e7d3dd506a4f5ab203589187e789/stable_diffusion/assets/a-watercolor-painting-of-a-fire.png
--------------------------------------------------------------------------------
/stable_diffusion/assets/birdhouse.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/salesforce/HIVE/a14ba7f26779e7d3dd506a4f5ab203589187e789/stable_diffusion/assets/birdhouse.png
--------------------------------------------------------------------------------
/stable_diffusion/assets/fire.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/salesforce/HIVE/a14ba7f26779e7d3dd506a4f5ab203589187e789/stable_diffusion/assets/fire.png
--------------------------------------------------------------------------------
/stable_diffusion/assets/inpainting.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/salesforce/HIVE/a14ba7f26779e7d3dd506a4f5ab203589187e789/stable_diffusion/assets/inpainting.png
--------------------------------------------------------------------------------
/stable_diffusion/assets/modelfigure.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/salesforce/HIVE/a14ba7f26779e7d3dd506a4f5ab203589187e789/stable_diffusion/assets/modelfigure.png
--------------------------------------------------------------------------------
/stable_diffusion/assets/rdm-preview.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/salesforce/HIVE/a14ba7f26779e7d3dd506a4f5ab203589187e789/stable_diffusion/assets/rdm-preview.jpg
--------------------------------------------------------------------------------
/stable_diffusion/assets/reconstruction1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/salesforce/HIVE/a14ba7f26779e7d3dd506a4f5ab203589187e789/stable_diffusion/assets/reconstruction1.png
--------------------------------------------------------------------------------
/stable_diffusion/assets/reconstruction2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/salesforce/HIVE/a14ba7f26779e7d3dd506a4f5ab203589187e789/stable_diffusion/assets/reconstruction2.png
--------------------------------------------------------------------------------
/stable_diffusion/assets/results.gif.REMOVED.git-id:
--------------------------------------------------------------------------------
1 | 82b6590e670a32196093cc6333ea19e6547d07de
--------------------------------------------------------------------------------
/stable_diffusion/assets/rick.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/salesforce/HIVE/a14ba7f26779e7d3dd506a4f5ab203589187e789/stable_diffusion/assets/rick.jpeg
--------------------------------------------------------------------------------
/stable_diffusion/assets/stable-samples/img2img/mountains-1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/salesforce/HIVE/a14ba7f26779e7d3dd506a4f5ab203589187e789/stable_diffusion/assets/stable-samples/img2img/mountains-1.png
--------------------------------------------------------------------------------
/stable_diffusion/assets/stable-samples/img2img/mountains-2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/salesforce/HIVE/a14ba7f26779e7d3dd506a4f5ab203589187e789/stable_diffusion/assets/stable-samples/img2img/mountains-2.png
--------------------------------------------------------------------------------
/stable_diffusion/assets/stable-samples/img2img/mountains-3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/salesforce/HIVE/a14ba7f26779e7d3dd506a4f5ab203589187e789/stable_diffusion/assets/stable-samples/img2img/mountains-3.png
--------------------------------------------------------------------------------
/stable_diffusion/assets/stable-samples/img2img/sketch-mountains-input.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/salesforce/HIVE/a14ba7f26779e7d3dd506a4f5ab203589187e789/stable_diffusion/assets/stable-samples/img2img/sketch-mountains-input.jpg
--------------------------------------------------------------------------------
/stable_diffusion/assets/stable-samples/img2img/upscaling-in.png.REMOVED.git-id:
--------------------------------------------------------------------------------
1 | 501c31c21751664957e69ce52cad1818b6d2f4ce
--------------------------------------------------------------------------------
/stable_diffusion/assets/stable-samples/img2img/upscaling-out.png.REMOVED.git-id:
--------------------------------------------------------------------------------
1 | 1c4bb25a779f34d86b2d90e584ac67af91bb1303
--------------------------------------------------------------------------------
/stable_diffusion/assets/stable-samples/txt2img/000002025.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/salesforce/HIVE/a14ba7f26779e7d3dd506a4f5ab203589187e789/stable_diffusion/assets/stable-samples/txt2img/000002025.png
--------------------------------------------------------------------------------
/stable_diffusion/assets/stable-samples/txt2img/000002035.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/salesforce/HIVE/a14ba7f26779e7d3dd506a4f5ab203589187e789/stable_diffusion/assets/stable-samples/txt2img/000002035.png
--------------------------------------------------------------------------------
/stable_diffusion/assets/stable-samples/txt2img/merged-0005.png.REMOVED.git-id:
--------------------------------------------------------------------------------
1 | ca0a1af206555f0f208a1ab879e95efedc1b1c5b
--------------------------------------------------------------------------------
/stable_diffusion/assets/stable-samples/txt2img/merged-0006.png.REMOVED.git-id:
--------------------------------------------------------------------------------
1 | 999f3703230580e8c89e9081abd6a1f8f50896d4
--------------------------------------------------------------------------------
/stable_diffusion/assets/stable-samples/txt2img/merged-0007.png.REMOVED.git-id:
--------------------------------------------------------------------------------
1 | af390acaf601283782d6f479d4cade4d78e30b26
--------------------------------------------------------------------------------
/stable_diffusion/assets/the-earth-is-on-fire,-oil-on-canvas.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/salesforce/HIVE/a14ba7f26779e7d3dd506a4f5ab203589187e789/stable_diffusion/assets/the-earth-is-on-fire,-oil-on-canvas.png
--------------------------------------------------------------------------------
/stable_diffusion/assets/txt2img-convsample.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/salesforce/HIVE/a14ba7f26779e7d3dd506a4f5ab203589187e789/stable_diffusion/assets/txt2img-convsample.png
--------------------------------------------------------------------------------
/stable_diffusion/assets/txt2img-preview.png.REMOVED.git-id:
--------------------------------------------------------------------------------
1 | 51ee1c235dfdc63d4c41de7d303d03730e43c33c
--------------------------------------------------------------------------------
/stable_diffusion/assets/v1-variants-scores.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/salesforce/HIVE/a14ba7f26779e7d3dd506a4f5ab203589187e789/stable_diffusion/assets/v1-variants-scores.jpg
--------------------------------------------------------------------------------
/stable_diffusion/configs/autoencoder/autoencoder_kl_16x16x16.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 4.5e-6
3 | target: ldm.models.autoencoder.AutoencoderKL
4 | params:
5 | monitor: "val/rec_loss"
6 | embed_dim: 16
7 | lossconfig:
8 | target: ldm.modules.losses.LPIPSWithDiscriminator
9 | params:
10 | disc_start: 50001
11 | kl_weight: 0.000001
12 | disc_weight: 0.5
13 |
14 | ddconfig:
15 | double_z: True
16 | z_channels: 16
17 | resolution: 256
18 | in_channels: 3
19 | out_ch: 3
20 | ch: 128
21 | ch_mult: [ 1,1,2,2,4] # num_down = len(ch_mult)-1
22 | num_res_blocks: 2
23 | attn_resolutions: [16]
24 | dropout: 0.0
25 |
26 |
27 | data:
28 | target: main.DataModuleFromConfig
29 | params:
30 | batch_size: 12
31 | wrap: True
32 | train:
33 | target: ldm.data.imagenet.ImageNetSRTrain
34 | params:
35 | size: 256
36 | degradation: pil_nearest
37 | validation:
38 | target: ldm.data.imagenet.ImageNetSRValidation
39 | params:
40 | size: 256
41 | degradation: pil_nearest
42 |
43 | lightning:
44 | callbacks:
45 | image_logger:
46 | target: main.ImageLogger
47 | params:
48 | batch_frequency: 1000
49 | max_images: 8
50 | increase_log_steps: True
51 |
52 | trainer:
53 | benchmark: True
54 | accumulate_grad_batches: 2
55 |
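
Configs like this are consumed through `instantiate_from_config`, the same pattern `metrics/compute_metrics.py` uses to build the full model. A minimal sketch, assuming it is run from the repo root (building the model may also pull in the LPIPS loss weights):

```python
import sys

from omegaconf import OmegaConf

sys.path.append("./stable_diffusion")
from ldm.util import instantiate_from_config

config = OmegaConf.load("stable_diffusion/configs/autoencoder/autoencoder_kl_16x16x16.yaml")
autoencoder = instantiate_from_config(config.model)  # builds AutoencoderKL with the params above
autoencoder.eval()
```

The `data` and `lightning` sections are read by `main.py` (`DataModuleFromConfig`, `ImageLogger`) rather than by the model itself.
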
--------------------------------------------------------------------------------
/stable_diffusion/configs/autoencoder/autoencoder_kl_32x32x4.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 4.5e-6
3 | target: ldm.models.autoencoder.AutoencoderKL
4 | params:
5 | monitor: "val/rec_loss"
6 | embed_dim: 4
7 | lossconfig:
8 | target: ldm.modules.losses.LPIPSWithDiscriminator
9 | params:
10 | disc_start: 50001
11 | kl_weight: 0.000001
12 | disc_weight: 0.5
13 |
14 | ddconfig:
15 | double_z: True
16 | z_channels: 4
17 | resolution: 256
18 | in_channels: 3
19 | out_ch: 3
20 | ch: 128
21 | ch_mult: [ 1,2,4,4 ] # num_down = len(ch_mult)-1
22 | num_res_blocks: 2
23 | attn_resolutions: [ ]
24 | dropout: 0.0
25 |
26 | data:
27 | target: main.DataModuleFromConfig
28 | params:
29 | batch_size: 12
30 | wrap: True
31 | train:
32 | target: ldm.data.imagenet.ImageNetSRTrain
33 | params:
34 | size: 256
35 | degradation: pil_nearest
36 | validation:
37 | target: ldm.data.imagenet.ImageNetSRValidation
38 | params:
39 | size: 256
40 | degradation: pil_nearest
41 |
42 | lightning:
43 | callbacks:
44 | image_logger:
45 | target: main.ImageLogger
46 | params:
47 | batch_frequency: 1000
48 | max_images: 8
49 | increase_log_steps: True
50 |
51 | trainer:
52 | benchmark: True
53 | accumulate_grad_batches: 2
54 |
--------------------------------------------------------------------------------
/stable_diffusion/configs/autoencoder/autoencoder_kl_64x64x3.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 4.5e-6
3 | target: ldm.models.autoencoder.AutoencoderKL
4 | params:
5 | monitor: "val/rec_loss"
6 | embed_dim: 3
7 | lossconfig:
8 | target: ldm.modules.losses.LPIPSWithDiscriminator
9 | params:
10 | disc_start: 50001
11 | kl_weight: 0.000001
12 | disc_weight: 0.5
13 |
14 | ddconfig:
15 | double_z: True
16 | z_channels: 3
17 | resolution: 256
18 | in_channels: 3
19 | out_ch: 3
20 | ch: 128
21 | ch_mult: [ 1,2,4 ] # num_down = len(ch_mult)-1
22 | num_res_blocks: 2
23 | attn_resolutions: [ ]
24 | dropout: 0.0
25 |
26 |
27 | data:
28 | target: main.DataModuleFromConfig
29 | params:
30 | batch_size: 12
31 | wrap: True
32 | train:
33 | target: ldm.data.imagenet.ImageNetSRTrain
34 | params:
35 | size: 256
36 | degradation: pil_nearest
37 | validation:
38 | target: ldm.data.imagenet.ImageNetSRValidation
39 | params:
40 | size: 256
41 | degradation: pil_nearest
42 |
43 | lightning:
44 | callbacks:
45 | image_logger:
46 | target: main.ImageLogger
47 | params:
48 | batch_frequency: 1000
49 | max_images: 8
50 | increase_log_steps: True
51 |
52 | trainer:
53 | benchmark: True
54 | accumulate_grad_batches: 2
55 |
--------------------------------------------------------------------------------
/stable_diffusion/configs/autoencoder/autoencoder_kl_8x8x64.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 4.5e-6
3 | target: ldm.models.autoencoder.AutoencoderKL
4 | params:
5 | monitor: "val/rec_loss"
6 | embed_dim: 64
7 | lossconfig:
8 | target: ldm.modules.losses.LPIPSWithDiscriminator
9 | params:
10 | disc_start: 50001
11 | kl_weight: 0.000001
12 | disc_weight: 0.5
13 |
14 | ddconfig:
15 | double_z: True
16 | z_channels: 64
17 | resolution: 256
18 | in_channels: 3
19 | out_ch: 3
20 | ch: 128
21 | ch_mult: [ 1,1,2,2,4,4] # num_down = len(ch_mult)-1
22 | num_res_blocks: 2
23 | attn_resolutions: [16,8]
24 | dropout: 0.0
25 |
26 | data:
27 | target: main.DataModuleFromConfig
28 | params:
29 | batch_size: 12
30 | wrap: True
31 | train:
32 | target: ldm.data.imagenet.ImageNetSRTrain
33 | params:
34 | size: 256
35 | degradation: pil_nearest
36 | validation:
37 | target: ldm.data.imagenet.ImageNetSRValidation
38 | params:
39 | size: 256
40 | degradation: pil_nearest
41 |
42 | lightning:
43 | callbacks:
44 | image_logger:
45 | target: main.ImageLogger
46 | params:
47 | batch_frequency: 1000
48 | max_images: 8
49 | increase_log_steps: True
50 |
51 | trainer:
52 | benchmark: True
53 | accumulate_grad_batches: 2
54 |
--------------------------------------------------------------------------------
/stable_diffusion/configs/latent-diffusion/celebahq-ldm-vq-4.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 2.0e-06
3 | target: ldm.models.diffusion.ddpm.LatentDiffusion
4 | params:
5 | linear_start: 0.0015
6 | linear_end: 0.0195
7 | num_timesteps_cond: 1
8 | log_every_t: 200
9 | timesteps: 1000
10 | first_stage_key: image
11 | image_size: 64
12 | channels: 3
13 | monitor: val/loss_simple_ema
14 |
15 | unet_config:
16 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel
17 | params:
18 | image_size: 64
19 | in_channels: 3
20 | out_channels: 3
21 | model_channels: 224
22 | attention_resolutions:
23 |       # note: this isn't actually the resolution but
24 |       # the downsampling factor, i.e. this corresponds to
25 |       # attention on spatial resolutions 8, 16, 32, as the
26 |       # spatial resolution of the latents is 64 for f4
27 | - 8
28 | - 4
29 | - 2
30 | num_res_blocks: 2
31 | channel_mult:
32 | - 1
33 | - 2
34 | - 3
35 | - 4
36 | num_head_channels: 32
37 | first_stage_config:
38 | target: ldm.models.autoencoder.VQModelInterface
39 | params:
40 | embed_dim: 3
41 | n_embed: 8192
42 | ckpt_path: models/first_stage_models/vq-f4/model.ckpt
43 | ddconfig:
44 | double_z: false
45 | z_channels: 3
46 | resolution: 256
47 | in_channels: 3
48 | out_ch: 3
49 | ch: 128
50 | ch_mult:
51 | - 1
52 | - 2
53 | - 4
54 | num_res_blocks: 2
55 | attn_resolutions: []
56 | dropout: 0.0
57 | lossconfig:
58 | target: torch.nn.Identity
59 | cond_stage_config: __is_unconditional__
60 | data:
61 | target: main.DataModuleFromConfig
62 | params:
63 | batch_size: 48
64 | num_workers: 5
65 | wrap: false
66 | train:
67 | target: taming.data.faceshq.CelebAHQTrain
68 | params:
69 | size: 256
70 | validation:
71 | target: taming.data.faceshq.CelebAHQValidation
72 | params:
73 | size: 256
74 |
75 |
76 | lightning:
77 | callbacks:
78 | image_logger:
79 | target: main.ImageLogger
80 | params:
81 | batch_frequency: 5000
82 | max_images: 8
83 | increase_log_steps: False
84 |
85 | trainer:
86 | benchmark: True
--------------------------------------------------------------------------------
/stable_diffusion/configs/latent-diffusion/cin-ldm-vq-f8.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 1.0e-06
3 | target: ldm.models.diffusion.ddpm.LatentDiffusion
4 | params:
5 | linear_start: 0.0015
6 | linear_end: 0.0195
7 | num_timesteps_cond: 1
8 | log_every_t: 200
9 | timesteps: 1000
10 | first_stage_key: image
11 | cond_stage_key: class_label
12 | image_size: 32
13 | channels: 4
14 | cond_stage_trainable: true
15 | conditioning_key: crossattn
16 | monitor: val/loss_simple_ema
17 | unet_config:
18 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel
19 | params:
20 | image_size: 32
21 | in_channels: 4
22 | out_channels: 4
23 | model_channels: 256
24 | attention_resolutions:
25 |       # note: this isn't actually the resolution but
26 |       # the downsampling factor, i.e. this corresponds to
27 |       # attention on spatial resolutions 8, 16, 32, as the
28 |       # spatial resolution of the latents is 32 for f8
29 | - 4
30 | - 2
31 | - 1
32 | num_res_blocks: 2
33 | channel_mult:
34 | - 1
35 | - 2
36 | - 4
37 | num_head_channels: 32
38 | use_spatial_transformer: true
39 | transformer_depth: 1
40 | context_dim: 512
41 | first_stage_config:
42 | target: ldm.models.autoencoder.VQModelInterface
43 | params:
44 | embed_dim: 4
45 | n_embed: 16384
46 | ckpt_path: configs/first_stage_models/vq-f8/model.yaml
47 | ddconfig:
48 | double_z: false
49 | z_channels: 4
50 | resolution: 256
51 | in_channels: 3
52 | out_ch: 3
53 | ch: 128
54 | ch_mult:
55 | - 1
56 | - 2
57 | - 2
58 | - 4
59 | num_res_blocks: 2
60 | attn_resolutions:
61 | - 32
62 | dropout: 0.0
63 | lossconfig:
64 | target: torch.nn.Identity
65 | cond_stage_config:
66 | target: ldm.modules.encoders.modules.ClassEmbedder
67 | params:
68 | embed_dim: 512
69 | key: class_label
70 | data:
71 | target: main.DataModuleFromConfig
72 | params:
73 | batch_size: 64
74 | num_workers: 12
75 | wrap: false
76 | train:
77 | target: ldm.data.imagenet.ImageNetTrain
78 | params:
79 | config:
80 | size: 256
81 | validation:
82 | target: ldm.data.imagenet.ImageNetValidation
83 | params:
84 | config:
85 | size: 256
86 |
87 |
88 | lightning:
89 | callbacks:
90 | image_logger:
91 | target: main.ImageLogger
92 | params:
93 | batch_frequency: 5000
94 | max_images: 8
95 | increase_log_steps: False
96 |
97 | trainer:
98 | benchmark: True
--------------------------------------------------------------------------------
/stable_diffusion/configs/latent-diffusion/cin256-v2.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 0.0001
3 | target: ldm.models.diffusion.ddpm.LatentDiffusion
4 | params:
5 | linear_start: 0.0015
6 | linear_end: 0.0195
7 | num_timesteps_cond: 1
8 | log_every_t: 200
9 | timesteps: 1000
10 | first_stage_key: image
11 | cond_stage_key: class_label
12 | image_size: 64
13 | channels: 3
14 | cond_stage_trainable: true
15 | conditioning_key: crossattn
16 | monitor: val/loss
17 | use_ema: False
18 |
19 | unet_config:
20 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel
21 | params:
22 | image_size: 64
23 | in_channels: 3
24 | out_channels: 3
25 | model_channels: 192
26 | attention_resolutions:
27 | - 8
28 | - 4
29 | - 2
30 | num_res_blocks: 2
31 | channel_mult:
32 | - 1
33 | - 2
34 | - 3
35 | - 5
36 | num_heads: 1
37 | use_spatial_transformer: true
38 | transformer_depth: 1
39 | context_dim: 512
40 |
41 | first_stage_config:
42 | target: ldm.models.autoencoder.VQModelInterface
43 | params:
44 | embed_dim: 3
45 | n_embed: 8192
46 | ddconfig:
47 | double_z: false
48 | z_channels: 3
49 | resolution: 256
50 | in_channels: 3
51 | out_ch: 3
52 | ch: 128
53 | ch_mult:
54 | - 1
55 | - 2
56 | - 4
57 | num_res_blocks: 2
58 | attn_resolutions: []
59 | dropout: 0.0
60 | lossconfig:
61 | target: torch.nn.Identity
62 |
63 | cond_stage_config:
64 | target: ldm.modules.encoders.modules.ClassEmbedder
65 | params:
66 | n_classes: 1001
67 | embed_dim: 512
68 | key: class_label
69 |
--------------------------------------------------------------------------------
/stable_diffusion/configs/latent-diffusion/ffhq-ldm-vq-4.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 2.0e-06
3 | target: ldm.models.diffusion.ddpm.LatentDiffusion
4 | params:
5 | linear_start: 0.0015
6 | linear_end: 0.0195
7 | num_timesteps_cond: 1
8 | log_every_t: 200
9 | timesteps: 1000
10 | first_stage_key: image
11 | image_size: 64
12 | channels: 3
13 | monitor: val/loss_simple_ema
14 | unet_config:
15 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel
16 | params:
17 | image_size: 64
18 | in_channels: 3
19 | out_channels: 3
20 | model_channels: 224
21 | attention_resolutions:
22 |       # note: this isn't actually the resolution but
23 |       # the downsampling factor, i.e. this corresponds to
24 |       # attention on spatial resolutions 8, 16, 32, as the
25 |       # spatial resolution of the latents is 64 for f4
26 | - 8
27 | - 4
28 | - 2
29 | num_res_blocks: 2
30 | channel_mult:
31 | - 1
32 | - 2
33 | - 3
34 | - 4
35 | num_head_channels: 32
36 | first_stage_config:
37 | target: ldm.models.autoencoder.VQModelInterface
38 | params:
39 | embed_dim: 3
40 | n_embed: 8192
41 | ckpt_path: configs/first_stage_models/vq-f4/model.yaml
42 | ddconfig:
43 | double_z: false
44 | z_channels: 3
45 | resolution: 256
46 | in_channels: 3
47 | out_ch: 3
48 | ch: 128
49 | ch_mult:
50 | - 1
51 | - 2
52 | - 4
53 | num_res_blocks: 2
54 | attn_resolutions: []
55 | dropout: 0.0
56 | lossconfig:
57 | target: torch.nn.Identity
58 | cond_stage_config: __is_unconditional__
59 | data:
60 | target: main.DataModuleFromConfig
61 | params:
62 | batch_size: 42
63 | num_workers: 5
64 | wrap: false
65 | train:
66 | target: taming.data.faceshq.FFHQTrain
67 | params:
68 | size: 256
69 | validation:
70 | target: taming.data.faceshq.FFHQValidation
71 | params:
72 | size: 256
73 |
74 |
75 | lightning:
76 | callbacks:
77 | image_logger:
78 | target: main.ImageLogger
79 | params:
80 | batch_frequency: 5000
81 | max_images: 8
82 | increase_log_steps: False
83 |
84 | trainer:
85 | benchmark: True
--------------------------------------------------------------------------------
/stable_diffusion/configs/latent-diffusion/lsun_bedrooms-ldm-vq-4.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 2.0e-06
3 | target: ldm.models.diffusion.ddpm.LatentDiffusion
4 | params:
5 | linear_start: 0.0015
6 | linear_end: 0.0195
7 | num_timesteps_cond: 1
8 | log_every_t: 200
9 | timesteps: 1000
10 | first_stage_key: image
11 | image_size: 64
12 | channels: 3
13 | monitor: val/loss_simple_ema
14 | unet_config:
15 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel
16 | params:
17 | image_size: 64
18 | in_channels: 3
19 | out_channels: 3
20 | model_channels: 224
21 | attention_resolutions:
22 | # note: this isn't actually the resolution but
23 | # the downsampling factor, i.e. this corresponds to
24 | # attention on spatial resolution 8,16,32, as the
25 | # spatial resolution of the latents is 64 for f4
26 | - 8
27 | - 4
28 | - 2
29 | num_res_blocks: 2
30 | channel_mult:
31 | - 1
32 | - 2
33 | - 3
34 | - 4
35 | num_head_channels: 32
36 | first_stage_config:
37 | target: ldm.models.autoencoder.VQModelInterface
38 | params:
39 | ckpt_path: configs/first_stage_models/vq-f4/model.yaml
40 | embed_dim: 3
41 | n_embed: 8192
42 | ddconfig:
43 | double_z: false
44 | z_channels: 3
45 | resolution: 256
46 | in_channels: 3
47 | out_ch: 3
48 | ch: 128
49 | ch_mult:
50 | - 1
51 | - 2
52 | - 4
53 | num_res_blocks: 2
54 | attn_resolutions: []
55 | dropout: 0.0
56 | lossconfig:
57 | target: torch.nn.Identity
58 | cond_stage_config: __is_unconditional__
59 | data:
60 | target: main.DataModuleFromConfig
61 | params:
62 | batch_size: 48
63 | num_workers: 5
64 | wrap: false
65 | train:
66 | target: ldm.data.lsun.LSUNBedroomsTrain
67 | params:
68 | size: 256
69 | validation:
70 | target: ldm.data.lsun.LSUNBedroomsValidation
71 | params:
72 | size: 256
73 |
74 |
75 | lightning:
76 | callbacks:
77 | image_logger:
78 | target: main.ImageLogger
79 | params:
80 | batch_frequency: 5000
81 | max_images: 8
82 | increase_log_steps: False
83 |
84 | trainer:
85 | benchmark: True
--------------------------------------------------------------------------------
/stable_diffusion/configs/latent-diffusion/lsun_churches-ldm-kl-8.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 5.0e-5 # set to target_lr by starting main.py with '--scale_lr False'
3 | target: ldm.models.diffusion.ddpm.LatentDiffusion
4 | params:
5 | linear_start: 0.0015
6 | linear_end: 0.0155
7 | num_timesteps_cond: 1
8 | log_every_t: 200
9 | timesteps: 1000
10 | loss_type: l1
11 | first_stage_key: "image"
12 | cond_stage_key: "image"
13 | image_size: 32
14 | channels: 4
15 | cond_stage_trainable: False
16 | concat_mode: False
17 | scale_by_std: True
18 | monitor: 'val/loss_simple_ema'
19 |
20 | scheduler_config: # 10000 warmup steps
21 | target: ldm.lr_scheduler.LambdaLinearScheduler
22 | params:
23 | warm_up_steps: [10000]
24 | cycle_lengths: [10000000000000]
25 | f_start: [1.e-6]
26 | f_max: [1.]
27 | f_min: [ 1.]
28 |
29 | unet_config:
30 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel
31 | params:
32 | image_size: 32
33 | in_channels: 4
34 | out_channels: 4
35 | model_channels: 192
36 | attention_resolutions: [ 1, 2, 4, 8 ] # 32, 16, 8, 4
37 | num_res_blocks: 2
38 | channel_mult: [ 1,2,2,4,4 ] # 32, 16, 8, 4, 2
39 | num_heads: 8
40 | use_scale_shift_norm: True
41 | resblock_updown: True
42 |
43 | first_stage_config:
44 | target: ldm.models.autoencoder.AutoencoderKL
45 | params:
46 | embed_dim: 4
47 | monitor: "val/rec_loss"
48 | ckpt_path: "models/first_stage_models/kl-f8/model.ckpt"
49 | ddconfig:
50 | double_z: True
51 | z_channels: 4
52 | resolution: 256
53 | in_channels: 3
54 | out_ch: 3
55 | ch: 128
56 | ch_mult: [ 1,2,4,4 ] # num_down = len(ch_mult)-1
57 | num_res_blocks: 2
58 | attn_resolutions: [ ]
59 | dropout: 0.0
60 | lossconfig:
61 | target: torch.nn.Identity
62 |
63 | cond_stage_config: "__is_unconditional__"
64 |
65 | data:
66 | target: main.DataModuleFromConfig
67 | params:
68 | batch_size: 96
69 | num_workers: 5
70 | wrap: False
71 | train:
72 | target: ldm.data.lsun.LSUNChurchesTrain
73 | params:
74 | size: 256
75 | validation:
76 | target: ldm.data.lsun.LSUNChurchesValidation
77 | params:
78 | size: 256
79 |
80 | lightning:
81 | callbacks:
82 | image_logger:
83 | target: main.ImageLogger
84 | params:
85 | batch_frequency: 5000
86 | max_images: 8
87 | increase_log_steps: False
88 |
89 |
90 | trainer:
91 | benchmark: True
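The comment on base_learning_rate above points at the learning-rate scaling performed by main.py. A small arithmetic sketch of that convention, assuming the usual CompVis behaviour (effective lr = accumulate_grad_batches * n_gpus * batch_size * base_lr unless the run is started with '--scale_lr False'); the GPU count below is hypothetical.

    base_lr = 5.0e-5
    batch_size = 96
    n_gpus = 8                      # hypothetical hardware setup
    accumulate_grad_batches = 1

    scaled_lr = accumulate_grad_batches * n_gpus * batch_size * base_lr
    print(f"with scaling:          {scaled_lr:.2e}")   # 3.84e-02
    print(f"with --scale_lr False: {base_lr:.2e}")     # 5.00e-05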
--------------------------------------------------------------------------------
/stable_diffusion/configs/latent-diffusion/txt2img-1p4B-eval.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 5.0e-05
3 | target: ldm.models.diffusion.ddpm.LatentDiffusion
4 | params:
5 | linear_start: 0.00085
6 | linear_end: 0.012
7 | num_timesteps_cond: 1
8 | log_every_t: 200
9 | timesteps: 1000
10 | first_stage_key: image
11 | cond_stage_key: caption
12 | image_size: 32
13 | channels: 4
14 | cond_stage_trainable: true
15 | conditioning_key: crossattn
16 | monitor: val/loss_simple_ema
17 | scale_factor: 0.18215
18 | use_ema: False
19 |
20 | unet_config:
21 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel
22 | params:
23 | image_size: 32
24 | in_channels: 4
25 | out_channels: 4
26 | model_channels: 320
27 | attention_resolutions:
28 | - 4
29 | - 2
30 | - 1
31 | num_res_blocks: 2
32 | channel_mult:
33 | - 1
34 | - 2
35 | - 4
36 | - 4
37 | num_heads: 8
38 | use_spatial_transformer: true
39 | transformer_depth: 1
40 | context_dim: 1280
41 | use_checkpoint: true
42 | legacy: False
43 |
44 | first_stage_config:
45 | target: ldm.models.autoencoder.AutoencoderKL
46 | params:
47 | embed_dim: 4
48 | monitor: val/rec_loss
49 | ddconfig:
50 | double_z: true
51 | z_channels: 4
52 | resolution: 256
53 | in_channels: 3
54 | out_ch: 3
55 | ch: 128
56 | ch_mult:
57 | - 1
58 | - 2
59 | - 4
60 | - 4
61 | num_res_blocks: 2
62 | attn_resolutions: []
63 | dropout: 0.0
64 | lossconfig:
65 | target: torch.nn.Identity
66 |
67 | cond_stage_config:
68 | target: ldm.modules.encoders.modules.BERTEmbedder
69 | params:
70 | n_embed: 1280
71 | n_layer: 32
72 |
--------------------------------------------------------------------------------
/stable_diffusion/configs/retrieval-augmented-diffusion/768x768.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 0.0001
3 | target: ldm.models.diffusion.ddpm.LatentDiffusion
4 | params:
5 | linear_start: 0.0015
6 | linear_end: 0.015
7 | num_timesteps_cond: 1
8 | log_every_t: 200
9 | timesteps: 1000
10 | first_stage_key: jpg
11 | cond_stage_key: nix
12 | image_size: 48
13 | channels: 16
14 | cond_stage_trainable: false
15 | conditioning_key: crossattn
16 | monitor: val/loss_simple_ema
17 | scale_by_std: false
18 | scale_factor: 0.22765929
19 | unet_config:
20 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel
21 | params:
22 | image_size: 48
23 | in_channels: 16
24 | out_channels: 16
25 | model_channels: 448
26 | attention_resolutions:
27 | - 4
28 | - 2
29 | - 1
30 | num_res_blocks: 2
31 | channel_mult:
32 | - 1
33 | - 2
34 | - 3
35 | - 4
36 | use_scale_shift_norm: false
37 | resblock_updown: false
38 | num_head_channels: 32
39 | use_spatial_transformer: true
40 | transformer_depth: 1
41 | context_dim: 768
42 | use_checkpoint: true
43 | first_stage_config:
44 | target: ldm.models.autoencoder.AutoencoderKL
45 | params:
46 | monitor: val/rec_loss
47 | embed_dim: 16
48 | ddconfig:
49 | double_z: true
50 | z_channels: 16
51 | resolution: 256
52 | in_channels: 3
53 | out_ch: 3
54 | ch: 128
55 | ch_mult:
56 | - 1
57 | - 1
58 | - 2
59 | - 2
60 | - 4
61 | num_res_blocks: 2
62 | attn_resolutions:
63 | - 16
64 | dropout: 0.0
65 | lossconfig:
66 | target: torch.nn.Identity
67 | cond_stage_config:
68 | target: torch.nn.Identity
--------------------------------------------------------------------------------
/stable_diffusion/configs/stable-diffusion/v1-inference.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 1.0e-04
3 | target: ldm.models.diffusion.ddpm.LatentDiffusion
4 | params:
5 | linear_start: 0.00085
6 | linear_end: 0.0120
7 | num_timesteps_cond: 1
8 | log_every_t: 200
9 | timesteps: 1000
10 | first_stage_key: "jpg"
11 | cond_stage_key: "txt"
12 | image_size: 64
13 | channels: 4
14 | cond_stage_trainable: false # Note: different from the one we trained before
15 | conditioning_key: crossattn
16 | monitor: val/loss_simple_ema
17 | scale_factor: 0.18215
18 | use_ema: False
19 |
20 | scheduler_config: # 10000 warmup steps
21 | target: ldm.lr_scheduler.LambdaLinearScheduler
22 | params:
23 | warm_up_steps: [ 10000 ]
24 | cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases
25 | f_start: [ 1.e-6 ]
26 | f_max: [ 1. ]
27 | f_min: [ 1. ]
28 |
29 | unet_config:
30 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel
31 | params:
32 | image_size: 32 # unused
33 | in_channels: 4
34 | out_channels: 4
35 | model_channels: 320
36 | attention_resolutions: [ 4, 2, 1 ]
37 | num_res_blocks: 2
38 | channel_mult: [ 1, 2, 4, 4 ]
39 | num_heads: 8
40 | use_spatial_transformer: True
41 | transformer_depth: 1
42 | context_dim: 768
43 | use_checkpoint: True
44 | legacy: False
45 |
46 | first_stage_config:
47 | target: ldm.models.autoencoder.AutoencoderKL
48 | params:
49 | embed_dim: 4
50 | monitor: val/rec_loss
51 | ddconfig:
52 | double_z: true
53 | z_channels: 4
54 | resolution: 256
55 | in_channels: 3
56 | out_ch: 3
57 | ch: 128
58 | ch_mult:
59 | - 1
60 | - 2
61 | - 4
62 | - 4
63 | num_res_blocks: 2
64 | attn_resolutions: []
65 | dropout: 0.0
66 | lossconfig:
67 | target: torch.nn.Identity
68 |
69 | cond_stage_config:
70 | target: ldm.modules.encoders.modules.FrozenCLIPEmbedder
71 |
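A quick sanity check on this config's latent geometry: the KL-f8 first stage above uses a ch_mult of length four, i.e. three 2x downsamplings (num_down = len(ch_mult) - 1, as noted in the churches config), so the 64x64x4 latents correspond to 512x512 RGB images.

    ch_mult = [1, 2, 4, 4]
    num_down = len(ch_mult) - 1         # three 2x downsamplings => f8
    latent_hw = 64                      # image_size in the model params above
    print(latent_hw * 2 ** num_down)    # 512 (pixel resolution)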
--------------------------------------------------------------------------------
/stable_diffusion/data/DejaVuSans.ttf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/salesforce/HIVE/a14ba7f26779e7d3dd506a4f5ab203589187e789/stable_diffusion/data/DejaVuSans.ttf
--------------------------------------------------------------------------------
/stable_diffusion/data/example_conditioning/superresolution/sample_0.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/salesforce/HIVE/a14ba7f26779e7d3dd506a4f5ab203589187e789/stable_diffusion/data/example_conditioning/superresolution/sample_0.jpg
--------------------------------------------------------------------------------
/stable_diffusion/data/example_conditioning/text_conditional/sample_0.txt:
--------------------------------------------------------------------------------
1 | A basket of cherries
2 |
--------------------------------------------------------------------------------
/stable_diffusion/data/imagenet_train_hr_indices.p.REMOVED.git-id:
--------------------------------------------------------------------------------
1 | b8d6d4689d2ecf32147e9cc2f5e6c50e072df26f
--------------------------------------------------------------------------------
/stable_diffusion/data/imagenet_val_hr_indices.p:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/salesforce/HIVE/a14ba7f26779e7d3dd506a4f5ab203589187e789/stable_diffusion/data/imagenet_val_hr_indices.p
--------------------------------------------------------------------------------
/stable_diffusion/data/inpainting_examples/6458524847_2f4c361183_k.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/salesforce/HIVE/a14ba7f26779e7d3dd506a4f5ab203589187e789/stable_diffusion/data/inpainting_examples/6458524847_2f4c361183_k.png
--------------------------------------------------------------------------------
/stable_diffusion/data/inpainting_examples/6458524847_2f4c361183_k_mask.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/salesforce/HIVE/a14ba7f26779e7d3dd506a4f5ab203589187e789/stable_diffusion/data/inpainting_examples/6458524847_2f4c361183_k_mask.png
--------------------------------------------------------------------------------
/stable_diffusion/data/inpainting_examples/8399166846_f6fb4e4b8e_k.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/salesforce/HIVE/a14ba7f26779e7d3dd506a4f5ab203589187e789/stable_diffusion/data/inpainting_examples/8399166846_f6fb4e4b8e_k.png
--------------------------------------------------------------------------------
/stable_diffusion/data/inpainting_examples/8399166846_f6fb4e4b8e_k_mask.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/salesforce/HIVE/a14ba7f26779e7d3dd506a4f5ab203589187e789/stable_diffusion/data/inpainting_examples/8399166846_f6fb4e4b8e_k_mask.png
--------------------------------------------------------------------------------
/stable_diffusion/data/inpainting_examples/alex-iby-G_Pk4D9rMLs.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/salesforce/HIVE/a14ba7f26779e7d3dd506a4f5ab203589187e789/stable_diffusion/data/inpainting_examples/alex-iby-G_Pk4D9rMLs.png
--------------------------------------------------------------------------------
/stable_diffusion/data/inpainting_examples/alex-iby-G_Pk4D9rMLs_mask.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/salesforce/HIVE/a14ba7f26779e7d3dd506a4f5ab203589187e789/stable_diffusion/data/inpainting_examples/alex-iby-G_Pk4D9rMLs_mask.png
--------------------------------------------------------------------------------
/stable_diffusion/data/inpainting_examples/bench2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/salesforce/HIVE/a14ba7f26779e7d3dd506a4f5ab203589187e789/stable_diffusion/data/inpainting_examples/bench2.png
--------------------------------------------------------------------------------
/stable_diffusion/data/inpainting_examples/bench2_mask.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/salesforce/HIVE/a14ba7f26779e7d3dd506a4f5ab203589187e789/stable_diffusion/data/inpainting_examples/bench2_mask.png
--------------------------------------------------------------------------------
/stable_diffusion/data/inpainting_examples/bertrand-gabioud-CpuFzIsHYJ0.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/salesforce/HIVE/a14ba7f26779e7d3dd506a4f5ab203589187e789/stable_diffusion/data/inpainting_examples/bertrand-gabioud-CpuFzIsHYJ0.png
--------------------------------------------------------------------------------
/stable_diffusion/data/inpainting_examples/bertrand-gabioud-CpuFzIsHYJ0_mask.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/salesforce/HIVE/a14ba7f26779e7d3dd506a4f5ab203589187e789/stable_diffusion/data/inpainting_examples/bertrand-gabioud-CpuFzIsHYJ0_mask.png
--------------------------------------------------------------------------------
/stable_diffusion/data/inpainting_examples/billow926-12-Wc-Zgx6Y.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/salesforce/HIVE/a14ba7f26779e7d3dd506a4f5ab203589187e789/stable_diffusion/data/inpainting_examples/billow926-12-Wc-Zgx6Y.png
--------------------------------------------------------------------------------
/stable_diffusion/data/inpainting_examples/billow926-12-Wc-Zgx6Y_mask.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/salesforce/HIVE/a14ba7f26779e7d3dd506a4f5ab203589187e789/stable_diffusion/data/inpainting_examples/billow926-12-Wc-Zgx6Y_mask.png
--------------------------------------------------------------------------------
/stable_diffusion/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/salesforce/HIVE/a14ba7f26779e7d3dd506a4f5ab203589187e789/stable_diffusion/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png
--------------------------------------------------------------------------------
/stable_diffusion/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/salesforce/HIVE/a14ba7f26779e7d3dd506a4f5ab203589187e789/stable_diffusion/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png
--------------------------------------------------------------------------------
/stable_diffusion/data/inpainting_examples/photo-1583445095369-9c651e7e5d34.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/salesforce/HIVE/a14ba7f26779e7d3dd506a4f5ab203589187e789/stable_diffusion/data/inpainting_examples/photo-1583445095369-9c651e7e5d34.png
--------------------------------------------------------------------------------
/stable_diffusion/data/inpainting_examples/photo-1583445095369-9c651e7e5d34_mask.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/salesforce/HIVE/a14ba7f26779e7d3dd506a4f5ab203589187e789/stable_diffusion/data/inpainting_examples/photo-1583445095369-9c651e7e5d34_mask.png
--------------------------------------------------------------------------------
/stable_diffusion/environment.yaml:
--------------------------------------------------------------------------------
1 | name: ldm
2 | channels:
3 | - pytorch
4 | - defaults
5 | dependencies:
6 | - python=3.8.5
7 | - pip=20.3
8 | - cudatoolkit=11.3
9 | - pytorch=1.11.0
10 | - torchvision=0.12.0
11 | - numpy=1.19.2
12 | - pip:
13 | - albumentations==0.4.3
14 | - diffusers
15 | - opencv-python==4.1.2.30
16 | - pudb==2019.2
17 | - invisible-watermark
18 | - imageio==2.9.0
19 | - imageio-ffmpeg==0.4.2
20 | - pytorch-lightning==1.4.2
21 | - omegaconf==2.1.1
22 | - test-tube>=0.7.5
23 | - streamlit>=0.73.1
24 | - einops==0.3.0
25 | - torch-fidelity==0.3.0
26 | - transformers==4.19.2
27 | - torchmetrics==0.6.0
28 | - kornia==0.6
29 | - -e git+https://github.com/CompVis/taming-transformers.git@master#egg=taming-transformers
30 | - -e git+https://github.com/openai/CLIP.git@main#egg=clip
31 | - -e .
32 |
--------------------------------------------------------------------------------
/stable_diffusion/ldm/data/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/salesforce/HIVE/a14ba7f26779e7d3dd506a4f5ab203589187e789/stable_diffusion/ldm/data/__init__.py
--------------------------------------------------------------------------------
/stable_diffusion/ldm/data/base.py:
--------------------------------------------------------------------------------
1 | from abc import abstractmethod
2 | from torch.utils.data import Dataset, ConcatDataset, ChainDataset, IterableDataset
3 |
4 |
5 | class Txt2ImgIterableBaseDataset(IterableDataset):
6 | '''
7 | Define an interface to make the IterableDatasets for text2img data chainable
8 | '''
9 | def __init__(self, num_records=0, valid_ids=None, size=256):
10 | super().__init__()
11 | self.num_records = num_records
12 | self.valid_ids = valid_ids
13 | self.sample_ids = valid_ids
14 | self.size = size
15 |
16 | print(f'{self.__class__.__name__} dataset contains {self.__len__()} examples.')
17 |
18 | def __len__(self):
19 | return self.num_records
20 |
21 | @abstractmethod
22 | def __iter__(self):
23 | pass
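A minimal smoke-test sketch showing how the abstract __iter__ might be filled in; the subclass below is hypothetical and only exists to illustrate the interface.

    import numpy as np
    from torch.utils.data import DataLoader

    from ldm.data.base import Txt2ImgIterableBaseDataset


    class RandomNoiseText2Img(Txt2ImgIterableBaseDataset):
        """Yields random images with dummy captions; for smoke tests only."""

        def __iter__(self):
            for sample_id in range(self.num_records):
                yield {
                    "image": np.random.rand(self.size, self.size, 3).astype(np.float32) * 2 - 1,
                    "caption": f"random sample {sample_id}",
                }


    loader = DataLoader(RandomNoiseText2Img(num_records=4, size=64), batch_size=2)
    print(next(iter(loader))["image"].shape)   # torch.Size([2, 64, 64, 3])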
--------------------------------------------------------------------------------
/stable_diffusion/ldm/data/lsun.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import PIL
4 | from PIL import Image
5 | from torch.utils.data import Dataset
6 | from torchvision import transforms
7 |
8 |
9 | class LSUNBase(Dataset):
10 | def __init__(self,
11 | txt_file,
12 | data_root,
13 | size=None,
14 | interpolation="bicubic",
15 | flip_p=0.5
16 | ):
17 | self.data_paths = txt_file
18 | self.data_root = data_root
19 | with open(self.data_paths, "r") as f:
20 | self.image_paths = f.read().splitlines()
21 | self._length = len(self.image_paths)
22 | self.labels = {
23 | "relative_file_path_": [l for l in self.image_paths],
24 | "file_path_": [os.path.join(self.data_root, l)
25 | for l in self.image_paths],
26 | }
27 |
28 | self.size = size
29 | self.interpolation = {"linear": PIL.Image.LINEAR,
30 | "bilinear": PIL.Image.BILINEAR,
31 | "bicubic": PIL.Image.BICUBIC,
32 | "lanczos": PIL.Image.LANCZOS,
33 | }[interpolation]
34 | self.flip = transforms.RandomHorizontalFlip(p=flip_p)
35 |
36 | def __len__(self):
37 | return self._length
38 |
39 | def __getitem__(self, i):
40 | example = dict((k, self.labels[k][i]) for k in self.labels)
41 | image = Image.open(example["file_path_"])
42 | if not image.mode == "RGB":
43 | image = image.convert("RGB")
44 |
45 | # default to score-sde preprocessing
46 | img = np.array(image).astype(np.uint8)
47 | crop = min(img.shape[0], img.shape[1])
48 | h, w, = img.shape[0], img.shape[1]
49 | img = img[(h - crop) // 2:(h + crop) // 2,
50 | (w - crop) // 2:(w + crop) // 2]
51 |
52 | image = Image.fromarray(img)
53 | if self.size is not None:
54 | image = image.resize((self.size, self.size), resample=self.interpolation)
55 |
56 | image = self.flip(image)
57 | image = np.array(image).astype(np.uint8)
58 | example["image"] = (image / 127.5 - 1.0).astype(np.float32)
59 | return example
60 |
61 |
62 | class LSUNChurchesTrain(LSUNBase):
63 | def __init__(self, **kwargs):
64 | super().__init__(txt_file="data/lsun/church_outdoor_train.txt", data_root="data/lsun/churches", **kwargs)
65 |
66 |
67 | class LSUNChurchesValidation(LSUNBase):
68 | def __init__(self, flip_p=0., **kwargs):
69 | super().__init__(txt_file="data/lsun/church_outdoor_val.txt", data_root="data/lsun/churches",
70 | flip_p=flip_p, **kwargs)
71 |
72 |
73 | class LSUNBedroomsTrain(LSUNBase):
74 | def __init__(self, **kwargs):
75 | super().__init__(txt_file="data/lsun/bedrooms_train.txt", data_root="data/lsun/bedrooms", **kwargs)
76 |
77 |
78 | class LSUNBedroomsValidation(LSUNBase):
79 | def __init__(self, flip_p=0.0, **kwargs):
80 | super().__init__(txt_file="data/lsun/bedrooms_val.txt", data_root="data/lsun/bedrooms",
81 | flip_p=flip_p, **kwargs)
82 |
83 |
84 | class LSUNCatsTrain(LSUNBase):
85 | def __init__(self, **kwargs):
86 | super().__init__(txt_file="data/lsun/cat_train.txt", data_root="data/lsun/cats", **kwargs)
87 |
88 |
89 | class LSUNCatsValidation(LSUNBase):
90 | def __init__(self, flip_p=0., **kwargs):
91 | super().__init__(txt_file="data/lsun/cat_val.txt", data_root="data/lsun/cats",
92 | flip_p=flip_p, **kwargs)
93 |
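For reference, the preprocessing in LSUNBase.__getitem__ boils down to a center square crop, a resize, and a rescale to [-1, 1]; the standalone helper below is hypothetical and simply mirrors those steps on a single PIL image.

    import numpy as np
    from PIL import Image


    def preprocess_like_lsunbase(image: Image.Image, size: int = 256) -> np.ndarray:
        image = image.convert("RGB")
        img = np.array(image).astype(np.uint8)
        h, w = img.shape[0], img.shape[1]
        crop = min(h, w)                                  # center square crop
        img = img[(h - crop) // 2:(h + crop) // 2, (w - crop) // 2:(w + crop) // 2]
        image = Image.fromarray(img).resize((size, size), resample=Image.BICUBIC)
        return (np.array(image).astype(np.uint8) / 127.5 - 1.0).astype(np.float32)


    out = preprocess_like_lsunbase(Image.new("RGB", (640, 480)), size=256)
    print(out.shape, out.dtype)    # (256, 256, 3) float32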
--------------------------------------------------------------------------------
/stable_diffusion/ldm/lr_scheduler.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 |
4 | class LambdaWarmUpCosineScheduler:
5 | """
6 | note: use with a base_lr of 1.0
7 | """
8 | def __init__(self, warm_up_steps, lr_min, lr_max, lr_start, max_decay_steps, verbosity_interval=0):
9 | self.lr_warm_up_steps = warm_up_steps
10 | self.lr_start = lr_start
11 | self.lr_min = lr_min
12 | self.lr_max = lr_max
13 | self.lr_max_decay_steps = max_decay_steps
14 | self.last_lr = 0.
15 | self.verbosity_interval = verbosity_interval
16 |
17 | def schedule(self, n, **kwargs):
18 | if self.verbosity_interval > 0:
19 | if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_lr}")
20 | if n < self.lr_warm_up_steps:
21 | lr = (self.lr_max - self.lr_start) / self.lr_warm_up_steps * n + self.lr_start
22 | self.last_lr = lr
23 | return lr
24 | else:
25 | t = (n - self.lr_warm_up_steps) / (self.lr_max_decay_steps - self.lr_warm_up_steps)
26 | t = min(t, 1.0)
27 | lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * (
28 | 1 + np.cos(t * np.pi))
29 | self.last_lr = lr
30 | return lr
31 |
32 | def __call__(self, n, **kwargs):
33 | return self.schedule(n,**kwargs)
34 |
35 |
36 | class LambdaWarmUpCosineScheduler2:
37 | """
38 | supports repeated iterations, configurable via lists
39 | note: use with a base_lr of 1.0.
40 | """
41 | def __init__(self, warm_up_steps, f_min, f_max, f_start, cycle_lengths, verbosity_interval=0):
42 | assert len(warm_up_steps) == len(f_min) == len(f_max) == len(f_start) == len(cycle_lengths)
43 | self.lr_warm_up_steps = warm_up_steps
44 | self.f_start = f_start
45 | self.f_min = f_min
46 | self.f_max = f_max
47 | self.cycle_lengths = cycle_lengths
48 | self.cum_cycles = np.cumsum([0] + list(self.cycle_lengths))
49 | self.last_f = 0.
50 | self.verbosity_interval = verbosity_interval
51 |
52 | def find_in_interval(self, n):
53 | interval = 0
54 | for cl in self.cum_cycles[1:]:
55 | if n <= cl:
56 | return interval
57 | interval += 1
58 |
59 | def schedule(self, n, **kwargs):
60 | cycle = self.find_in_interval(n)
61 | n = n - self.cum_cycles[cycle]
62 | if self.verbosity_interval > 0:
63 | if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, "
64 | f"current cycle {cycle}")
65 | if n < self.lr_warm_up_steps[cycle]:
66 | f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle]
67 | self.last_f = f
68 | return f
69 | else:
70 | t = (n - self.lr_warm_up_steps[cycle]) / (self.cycle_lengths[cycle] - self.lr_warm_up_steps[cycle])
71 | t = min(t, 1.0)
72 | f = self.f_min[cycle] + 0.5 * (self.f_max[cycle] - self.f_min[cycle]) * (
73 | 1 + np.cos(t * np.pi))
74 | self.last_f = f
75 | return f
76 |
77 | def __call__(self, n, **kwargs):
78 | return self.schedule(n, **kwargs)
79 |
80 |
81 | class LambdaLinearScheduler(LambdaWarmUpCosineScheduler2):
82 |
83 | def schedule(self, n, **kwargs):
84 | cycle = self.find_in_interval(n)
85 | n = n - self.cum_cycles[cycle]
86 | if self.verbosity_interval > 0:
87 | if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, "
88 | f"current cycle {cycle}")
89 |
90 | if n < self.lr_warm_up_steps[cycle]:
91 | f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle]
92 | self.last_f = f
93 | return f
94 | else:
95 | f = self.f_min[cycle] + (self.f_max[cycle] - self.f_min[cycle]) * (self.cycle_lengths[cycle] - n) / (self.cycle_lengths[cycle])
96 | self.last_f = f
97 | return f
98 |
99 |
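The scheduler_config entries in the configs above plug these classes into PyTorch via LambdaLR; the schedules return learning-rate multipliers, which is why the docstrings say to use them with a base_lr of 1.0 (equivalently, the multiplier is applied on top of the optimizer's lr). A small wiring sketch with the same values as the 10000-step warmup configs; the optimizer and parameter are placeholders.

    import torch

    from ldm.lr_scheduler import LambdaLinearScheduler

    sched_fn = LambdaLinearScheduler(
        warm_up_steps=[10000],
        f_min=[1.0],
        f_max=[1.0],
        f_start=[1.0e-6],
        cycle_lengths=[10000000000000],
    )

    params = [torch.nn.Parameter(torch.zeros(1))]
    optimizer = torch.optim.AdamW(params, lr=1.0e-4)
    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=sched_fn)

    for step in (0, 5000, 10000):
        print(step, sched_fn(step))   # ~1e-6, ~0.5, 1.0 (linear warmup, then flat)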
--------------------------------------------------------------------------------
/stable_diffusion/ldm/models/diffusion/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/salesforce/HIVE/a14ba7f26779e7d3dd506a4f5ab203589187e789/stable_diffusion/ldm/models/diffusion/__init__.py
--------------------------------------------------------------------------------
/stable_diffusion/ldm/models/diffusion/dpm_solver/__init__.py:
--------------------------------------------------------------------------------
1 | from .sampler import DPMSolverSampler
--------------------------------------------------------------------------------
/stable_diffusion/ldm/models/diffusion/dpm_solver/sampler.py:
--------------------------------------------------------------------------------
1 | """SAMPLING ONLY."""
2 |
3 | import torch
4 |
5 | from .dpm_solver import NoiseScheduleVP, model_wrapper, DPM_Solver
6 |
7 |
8 | class DPMSolverSampler(object):
9 | def __init__(self, model, **kwargs):
10 | super().__init__()
11 | self.model = model
12 | to_torch = lambda x: x.clone().detach().to(torch.float32).to(model.device)
13 | self.register_buffer('alphas_cumprod', to_torch(model.alphas_cumprod))
14 |
15 | def register_buffer(self, name, attr):
16 | if type(attr) == torch.Tensor:
17 | if attr.device != torch.device("cuda"):
18 | attr = attr.to(torch.device("cuda"))
19 | setattr(self, name, attr)
20 |
21 | @torch.no_grad()
22 | def sample(self,
23 | S,
24 | batch_size,
25 | shape,
26 | conditioning=None,
27 | callback=None,
28 | normals_sequence=None,
29 | img_callback=None,
30 | quantize_x0=False,
31 | eta=0.,
32 | mask=None,
33 | x0=None,
34 | temperature=1.,
35 | noise_dropout=0.,
36 | score_corrector=None,
37 | corrector_kwargs=None,
38 | verbose=True,
39 | x_T=None,
40 | log_every_t=100,
41 | unconditional_guidance_scale=1.,
42 | unconditional_conditioning=None,
43 | # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
44 | **kwargs
45 | ):
46 | if conditioning is not None:
47 | if isinstance(conditioning, dict):
48 | cbs = conditioning[list(conditioning.keys())[0]].shape[0]
49 | if cbs != batch_size:
50 | print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
51 | else:
52 | if conditioning.shape[0] != batch_size:
53 | print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
54 |
55 | # sampling
56 | C, H, W = shape
57 | size = (batch_size, C, H, W)
58 |
59 | # print(f'Data shape for DPM-Solver sampling is {size}, sampling steps {S}')
60 |
61 | device = self.model.betas.device
62 | if x_T is None:
63 | img = torch.randn(size, device=device)
64 | else:
65 | img = x_T
66 |
67 | ns = NoiseScheduleVP('discrete', alphas_cumprod=self.alphas_cumprod)
68 |
69 | model_fn = model_wrapper(
70 | lambda x, t, c: self.model.apply_model(x, t, c),
71 | ns,
72 | model_type="noise",
73 | guidance_type="classifier-free",
74 | condition=conditioning,
75 | unconditional_condition=unconditional_conditioning,
76 | guidance_scale=unconditional_guidance_scale,
77 | )
78 |
79 | dpm_solver = DPM_Solver(model_fn, ns, predict_x0=True, thresholding=False)
80 | x = dpm_solver.sample(img, steps=S, skip_type="time_uniform", method="multistep", order=2, lower_order_final=True)
81 |
82 | return x.to(device), None
83 |
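A usage sketch for the sampler; it assumes a LatentDiffusion instance `model` has already been loaded onto a CUDA device (e.g. as in the config-loading sketch earlier), and the prompt, guidance scale, and latent shape are placeholders.

    from ldm.models.diffusion.dpm_solver import DPMSolverSampler

    sampler = DPMSolverSampler(model)                 # `model` assumed to exist

    cond = model.get_learned_conditioning(["a photograph of a fire"])
    uncond = model.get_learned_conditioning([""])

    samples, _ = sampler.sample(
        S=25,                                         # DPM-Solver steps
        batch_size=1,
        shape=(4, 64, 64),                            # (C, H, W) in latent space
        conditioning=cond,
        unconditional_guidance_scale=7.5,
        unconditional_conditioning=uncond,
    )
    images = model.decode_first_stage(samples)        # back to pixel space, ~[-1, 1]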
--------------------------------------------------------------------------------
/stable_diffusion/ldm/modules/diffusionmodules/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/salesforce/HIVE/a14ba7f26779e7d3dd506a4f5ab203589187e789/stable_diffusion/ldm/modules/diffusionmodules/__init__.py
--------------------------------------------------------------------------------
/stable_diffusion/ldm/modules/distributions/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/salesforce/HIVE/a14ba7f26779e7d3dd506a4f5ab203589187e789/stable_diffusion/ldm/modules/distributions/__init__.py
--------------------------------------------------------------------------------
/stable_diffusion/ldm/modules/distributions/distributions.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 |
4 |
5 | class AbstractDistribution:
6 | def sample(self):
7 | raise NotImplementedError()
8 |
9 | def mode(self):
10 | raise NotImplementedError()
11 |
12 |
13 | class DiracDistribution(AbstractDistribution):
14 | def __init__(self, value):
15 | self.value = value
16 |
17 | def sample(self):
18 | return self.value
19 |
20 | def mode(self):
21 | return self.value
22 |
23 |
24 | class DiagonalGaussianDistribution(object):
25 | def __init__(self, parameters, deterministic=False):
26 | self.parameters = parameters
27 | self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
28 | self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
29 | self.deterministic = deterministic
30 | self.std = torch.exp(0.5 * self.logvar)
31 | self.var = torch.exp(self.logvar)
32 | if self.deterministic:
33 | self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device)
34 |
35 | def sample(self):
36 | x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device)
37 | return x
38 |
39 | def kl(self, other=None):
40 | if self.deterministic:
41 | return torch.Tensor([0.])
42 | else:
43 | if other is None:
44 | return 0.5 * torch.sum(torch.pow(self.mean, 2)
45 | + self.var - 1.0 - self.logvar,
46 | dim=[1, 2, 3])
47 | else:
48 | return 0.5 * torch.sum(
49 | torch.pow(self.mean - other.mean, 2) / other.var
50 | + self.var / other.var - 1.0 - self.logvar + other.logvar,
51 | dim=[1, 2, 3])
52 |
53 | def nll(self, sample, dims=[1,2,3]):
54 | if self.deterministic:
55 | return torch.Tensor([0.])
56 | logtwopi = np.log(2.0 * np.pi)
57 | return 0.5 * torch.sum(
58 | logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
59 | dim=dims)
60 |
61 | def mode(self):
62 | return self.mean
63 |
64 |
65 | def normal_kl(mean1, logvar1, mean2, logvar2):
66 | """
67 | source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12
68 | Compute the KL divergence between two gaussians.
69 | Shapes are automatically broadcasted, so batches can be compared to
70 | scalars, among other use cases.
71 | """
72 | tensor = None
73 | for obj in (mean1, logvar1, mean2, logvar2):
74 | if isinstance(obj, torch.Tensor):
75 | tensor = obj
76 | break
77 | assert tensor is not None, "at least one argument must be a Tensor"
78 |
79 | # Force variances to be Tensors. Broadcasting helps convert scalars to
80 | # Tensors, but it does not work for torch.exp().
81 | logvar1, logvar2 = [
82 | x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor)
83 | for x in (logvar1, logvar2)
84 | ]
85 |
86 | return 0.5 * (
87 | -1.0
88 | + logvar2
89 | - logvar1
90 | + torch.exp(logvar1 - logvar2)
91 | + ((mean1 - mean2) ** 2) * torch.exp(-logvar2)
92 | )
93 |
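DiagonalGaussianDistribution expects mean and log-variance stacked along the channel dimension, which is what the KL autoencoder's encoder produces when double_z is true (2 * z_channels output channels). A small self-contained example with random moments:

    import torch

    from ldm.modules.distributions.distributions import DiagonalGaussianDistribution

    z_channels = 4
    moments = torch.randn(2, 2 * z_channels, 32, 32)     # (B, 2*C, H, W): mean ++ logvar

    posterior = DiagonalGaussianDistribution(moments)
    z = posterior.sample()                               # reparameterized sample
    print(z.shape, posterior.kl().shape)                 # [2, 4, 32, 32] and [2]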
--------------------------------------------------------------------------------
/stable_diffusion/ldm/modules/ema.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 |
4 |
5 | class LitEma(nn.Module):
6 | def __init__(self, model, decay=0.9999, use_num_upates=True):
7 | super().__init__()
8 | if decay < 0.0 or decay > 1.0:
9 | raise ValueError('Decay must be between 0 and 1')
10 |
11 | self.m_name2s_name = {}
12 | self.register_buffer('decay', torch.tensor(decay, dtype=torch.float32))
13 | self.register_buffer('num_updates', torch.tensor(0,dtype=torch.int) if use_num_upates
14 | else torch.tensor(-1,dtype=torch.int))
15 |
16 | for name, p in model.named_parameters():
17 | if p.requires_grad:
18 | #remove as '.'-character is not allowed in buffers
19 | s_name = name.replace('.','')
20 | self.m_name2s_name.update({name:s_name})
21 | self.register_buffer(s_name,p.clone().detach().data)
22 |
23 | self.collected_params = []
24 |
25 | def forward(self,model):
26 | decay = self.decay
27 |
28 | if self.num_updates >= 0:
29 | self.num_updates += 1
30 | decay = min(self.decay,(1 + self.num_updates) / (10 + self.num_updates))
31 |
32 | one_minus_decay = 1.0 - decay
33 |
34 | with torch.no_grad():
35 | m_param = dict(model.named_parameters())
36 | shadow_params = dict(self.named_buffers())
37 |
38 | for key in m_param:
39 | if m_param[key].requires_grad:
40 | sname = self.m_name2s_name[key]
41 | shadow_params[sname] = shadow_params[sname].type_as(m_param[key])
42 | shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key]))
43 | else:
44 | assert not key in self.m_name2s_name
45 |
46 | def copy_to(self, model):
47 | m_param = dict(model.named_parameters())
48 | shadow_params = dict(self.named_buffers())
49 | for key in m_param:
50 | if m_param[key].requires_grad:
51 | m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data)
52 | else:
53 | assert not key in self.m_name2s_name
54 |
55 | def store(self, parameters):
56 | """
57 | Save the current parameters for restoring later.
58 | Args:
59 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be
60 | temporarily stored.
61 | """
62 | self.collected_params = [param.clone() for param in parameters]
63 |
64 | def restore(self, parameters):
65 | """
66 | Restore the parameters stored with the `store` method.
67 | Useful to validate the model with EMA parameters without affecting the
68 | original optimization process. Store the parameters before the
69 | `copy_to` method. After validation (or model saving), use this to
70 | restore the former parameters.
71 | Args:
72 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be
73 | updated with the stored parameters.
74 | """
75 | for c_param, param in zip(self.collected_params, parameters):
76 | param.data.copy_(c_param.data)
77 |
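The store/copy_to/restore docstrings above describe the intended life-cycle; a compact sketch of it with a toy model (the linear layer and training loop are placeholders):

    import torch
    from torch import nn

    from ldm.modules.ema import LitEma

    model = nn.Linear(4, 4)
    ema = LitEma(model, decay=0.9999)
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)

    for _ in range(10):
        loss = model(torch.randn(8, 4)).pow(2).mean()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        ema(model)                           # update the EMA shadow after each step

    ema.store(model.parameters())            # stash the raw weights
    ema.copy_to(model)                       # evaluate with EMA weights
    with torch.no_grad():
        _ = model(torch.randn(1, 4))
    ema.restore(model.parameters())          # put the raw weights back for training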
--------------------------------------------------------------------------------
/stable_diffusion/ldm/modules/encoders/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/salesforce/HIVE/a14ba7f26779e7d3dd506a4f5ab203589187e789/stable_diffusion/ldm/modules/encoders/__init__.py
--------------------------------------------------------------------------------
/stable_diffusion/ldm/modules/image_degradation/__init__.py:
--------------------------------------------------------------------------------
1 | from ldm.modules.image_degradation.bsrgan import degradation_bsrgan_variant as degradation_fn_bsr
2 | from ldm.modules.image_degradation.bsrgan_light import degradation_bsrgan_variant as degradation_fn_bsr_light
3 |
--------------------------------------------------------------------------------
/stable_diffusion/ldm/modules/image_degradation/utils/test.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/salesforce/HIVE/a14ba7f26779e7d3dd506a4f5ab203589187e789/stable_diffusion/ldm/modules/image_degradation/utils/test.png
--------------------------------------------------------------------------------
/stable_diffusion/ldm/modules/losses/__init__.py:
--------------------------------------------------------------------------------
1 | from ldm.modules.losses.contperceptual import LPIPSWithDiscriminator
--------------------------------------------------------------------------------
/stable_diffusion/ldm/modules/losses/contperceptual.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | from taming.modules.losses.vqperceptual import * # TODO: taming dependency yes/no?
5 |
6 |
7 | class LPIPSWithDiscriminator(nn.Module):
8 | def __init__(self, disc_start, logvar_init=0.0, kl_weight=1.0, pixelloss_weight=1.0,
9 | disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0,
10 | perceptual_weight=1.0, use_actnorm=False, disc_conditional=False,
11 | disc_loss="hinge"):
12 |
13 | super().__init__()
14 | assert disc_loss in ["hinge", "vanilla"]
15 | self.kl_weight = kl_weight
16 | self.pixel_weight = pixelloss_weight
17 | self.perceptual_loss = LPIPS().eval()
18 | self.perceptual_weight = perceptual_weight
19 | # output log variance
20 | self.logvar = nn.Parameter(torch.ones(size=()) * logvar_init)
21 |
22 | self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels,
23 | n_layers=disc_num_layers,
24 | use_actnorm=use_actnorm
25 | ).apply(weights_init)
26 | self.discriminator_iter_start = disc_start
27 | self.disc_loss = hinge_d_loss if disc_loss == "hinge" else vanilla_d_loss
28 | self.disc_factor = disc_factor
29 | self.discriminator_weight = disc_weight
30 | self.disc_conditional = disc_conditional
31 |
32 | def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None):
33 | if last_layer is not None:
34 | nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
35 | g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
36 | else:
37 | nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0]
38 | g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0]
39 |
40 | d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
41 | d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
42 | d_weight = d_weight * self.discriminator_weight
43 | return d_weight
44 |
45 | def forward(self, inputs, reconstructions, posteriors, optimizer_idx,
46 | global_step, last_layer=None, cond=None, split="train",
47 | weights=None):
48 | rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous())
49 | if self.perceptual_weight > 0:
50 | p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous())
51 | rec_loss = rec_loss + self.perceptual_weight * p_loss
52 |
53 | nll_loss = rec_loss / torch.exp(self.logvar) + self.logvar
54 | weighted_nll_loss = nll_loss
55 | if weights is not None:
56 | weighted_nll_loss = weights*nll_loss
57 | weighted_nll_loss = torch.sum(weighted_nll_loss) / weighted_nll_loss.shape[0]
58 | nll_loss = torch.sum(nll_loss) / nll_loss.shape[0]
59 | kl_loss = posteriors.kl()
60 | kl_loss = torch.sum(kl_loss) / kl_loss.shape[0]
61 |
62 | # now the GAN part
63 | if optimizer_idx == 0:
64 | # generator update
65 | if cond is None:
66 | assert not self.disc_conditional
67 | logits_fake = self.discriminator(reconstructions.contiguous())
68 | else:
69 | assert self.disc_conditional
70 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1))
71 | g_loss = -torch.mean(logits_fake)
72 |
73 | if self.disc_factor > 0.0:
74 | try:
75 | d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer)
76 | except RuntimeError:
77 | assert not self.training
78 | d_weight = torch.tensor(0.0)
79 | else:
80 | d_weight = torch.tensor(0.0)
81 |
82 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
83 | loss = weighted_nll_loss + self.kl_weight * kl_loss + d_weight * disc_factor * g_loss
84 |
85 | log = {"{}/total_loss".format(split): loss.clone().detach().mean(), "{}/logvar".format(split): self.logvar.detach(),
86 | "{}/kl_loss".format(split): kl_loss.detach().mean(), "{}/nll_loss".format(split): nll_loss.detach().mean(),
87 | "{}/rec_loss".format(split): rec_loss.detach().mean(),
88 | "{}/d_weight".format(split): d_weight.detach(),
89 | "{}/disc_factor".format(split): torch.tensor(disc_factor),
90 | "{}/g_loss".format(split): g_loss.detach().mean(),
91 | }
92 | return loss, log
93 |
94 | if optimizer_idx == 1:
95 | # second pass for discriminator update
96 | if cond is None:
97 | logits_real = self.discriminator(inputs.contiguous().detach())
98 | logits_fake = self.discriminator(reconstructions.contiguous().detach())
99 | else:
100 | logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1))
101 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1))
102 |
103 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
104 | d_loss = disc_factor * self.disc_loss(logits_real, logits_fake)
105 |
106 | log = {"{}/disc_loss".format(split): d_loss.clone().detach().mean(),
107 | "{}/logits_real".format(split): logits_real.detach().mean(),
108 | "{}/logits_fake".format(split): logits_fake.detach().mean()
109 | }
110 | return d_loss, log
111 |
112 |
--------------------------------------------------------------------------------
/stable_diffusion/ldm/modules/losses/vqperceptual.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | import torch.nn.functional as F
4 | from einops import repeat
5 |
6 | from taming.modules.discriminator.model import NLayerDiscriminator, weights_init
7 | from taming.modules.losses.lpips import LPIPS
8 | from taming.modules.losses.vqperceptual import hinge_d_loss, vanilla_d_loss
9 |
10 |
11 | def hinge_d_loss_with_exemplar_weights(logits_real, logits_fake, weights):
12 | assert weights.shape[0] == logits_real.shape[0] == logits_fake.shape[0]
13 | loss_real = torch.mean(F.relu(1. - logits_real), dim=[1,2,3])
14 | loss_fake = torch.mean(F.relu(1. + logits_fake), dim=[1,2,3])
15 | loss_real = (weights * loss_real).sum() / weights.sum()
16 | loss_fake = (weights * loss_fake).sum() / weights.sum()
17 | d_loss = 0.5 * (loss_real + loss_fake)
18 | return d_loss
19 |
20 | def adopt_weight(weight, global_step, threshold=0, value=0.):
21 | if global_step < threshold:
22 | weight = value
23 | return weight
24 |
25 |
26 | def measure_perplexity(predicted_indices, n_embed):
27 | # src: https://github.com/karpathy/deep-vector-quantization/blob/main/model.py
28 | # eval cluster perplexity. when perplexity == num_embeddings then all clusters are used exactly equally
29 | encodings = F.one_hot(predicted_indices, n_embed).float().reshape(-1, n_embed)
30 | avg_probs = encodings.mean(0)
31 | perplexity = (-(avg_probs * torch.log(avg_probs + 1e-10)).sum()).exp()
32 | cluster_use = torch.sum(avg_probs > 0)
33 | return perplexity, cluster_use
34 |
35 | def l1(x, y):
36 | return torch.abs(x-y)
37 |
38 |
39 | def l2(x, y):
40 | return torch.pow((x-y), 2)
41 |
42 |
43 | class VQLPIPSWithDiscriminator(nn.Module):
44 | def __init__(self, disc_start, codebook_weight=1.0, pixelloss_weight=1.0,
45 | disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0,
46 | perceptual_weight=1.0, use_actnorm=False, disc_conditional=False,
47 | disc_ndf=64, disc_loss="hinge", n_classes=None, perceptual_loss="lpips",
48 | pixel_loss="l1"):
49 | super().__init__()
50 | assert disc_loss in ["hinge", "vanilla"]
51 | assert perceptual_loss in ["lpips", "clips", "dists"]
52 | assert pixel_loss in ["l1", "l2"]
53 | self.codebook_weight = codebook_weight
54 | self.pixel_weight = pixelloss_weight
55 | if perceptual_loss == "lpips":
56 | print(f"{self.__class__.__name__}: Running with LPIPS.")
57 | self.perceptual_loss = LPIPS().eval()
58 | else:
59 | raise ValueError(f"Unknown perceptual loss: >> {perceptual_loss} <<")
60 | self.perceptual_weight = perceptual_weight
61 |
62 | if pixel_loss == "l1":
63 | self.pixel_loss = l1
64 | else:
65 | self.pixel_loss = l2
66 |
67 | self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels,
68 | n_layers=disc_num_layers,
69 | use_actnorm=use_actnorm,
70 | ndf=disc_ndf
71 | ).apply(weights_init)
72 | self.discriminator_iter_start = disc_start
73 | if disc_loss == "hinge":
74 | self.disc_loss = hinge_d_loss
75 | elif disc_loss == "vanilla":
76 | self.disc_loss = vanilla_d_loss
77 | else:
78 | raise ValueError(f"Unknown GAN loss '{disc_loss}'.")
79 | print(f"VQLPIPSWithDiscriminator running with {disc_loss} loss.")
80 | self.disc_factor = disc_factor
81 | self.discriminator_weight = disc_weight
82 | self.disc_conditional = disc_conditional
83 | self.n_classes = n_classes
84 |
85 | def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None):
86 | if last_layer is not None:
87 | nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
88 | g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
89 | else:
90 | nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0]
91 | g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0]
92 |
93 | d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
94 | d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
95 | d_weight = d_weight * self.discriminator_weight
96 | return d_weight
97 |
98 | def forward(self, codebook_loss, inputs, reconstructions, optimizer_idx,
99 | global_step, last_layer=None, cond=None, split="train", predicted_indices=None):
100 |         if codebook_loss is None:  # 'exists' is not defined in this module; check for None directly
101 | codebook_loss = torch.tensor([0.]).to(inputs.device)
102 | #rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous())
103 | rec_loss = self.pixel_loss(inputs.contiguous(), reconstructions.contiguous())
104 | if self.perceptual_weight > 0:
105 | p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous())
106 | rec_loss = rec_loss + self.perceptual_weight * p_loss
107 | else:
108 | p_loss = torch.tensor([0.0])
109 |
110 | nll_loss = rec_loss
111 | #nll_loss = torch.sum(nll_loss) / nll_loss.shape[0]
112 | nll_loss = torch.mean(nll_loss)
113 |
114 | # now the GAN part
115 | if optimizer_idx == 0:
116 | # generator update
117 | if cond is None:
118 | assert not self.disc_conditional
119 | logits_fake = self.discriminator(reconstructions.contiguous())
120 | else:
121 | assert self.disc_conditional
122 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1))
123 | g_loss = -torch.mean(logits_fake)
124 |
125 | try:
126 | d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer)
127 | except RuntimeError:
128 | assert not self.training
129 | d_weight = torch.tensor(0.0)
130 |
131 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
132 | loss = nll_loss + d_weight * disc_factor * g_loss + self.codebook_weight * codebook_loss.mean()
133 |
134 | log = {"{}/total_loss".format(split): loss.clone().detach().mean(),
135 | "{}/quant_loss".format(split): codebook_loss.detach().mean(),
136 | "{}/nll_loss".format(split): nll_loss.detach().mean(),
137 | "{}/rec_loss".format(split): rec_loss.detach().mean(),
138 | "{}/p_loss".format(split): p_loss.detach().mean(),
139 | "{}/d_weight".format(split): d_weight.detach(),
140 | "{}/disc_factor".format(split): torch.tensor(disc_factor),
141 | "{}/g_loss".format(split): g_loss.detach().mean(),
142 | }
143 | if predicted_indices is not None:
144 | assert self.n_classes is not None
145 | with torch.no_grad():
146 | perplexity, cluster_usage = measure_perplexity(predicted_indices, self.n_classes)
147 | log[f"{split}/perplexity"] = perplexity
148 | log[f"{split}/cluster_usage"] = cluster_usage
149 | return loss, log
150 |
151 | if optimizer_idx == 1:
152 | # second pass for discriminator update
153 | if cond is None:
154 | logits_real = self.discriminator(inputs.contiguous().detach())
155 | logits_fake = self.discriminator(reconstructions.contiguous().detach())
156 | else:
157 | logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1))
158 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1))
159 |
160 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
161 | d_loss = disc_factor * self.disc_loss(logits_real, logits_fake)
162 |
163 | log = {"{}/disc_loss".format(split): d_loss.clone().detach().mean(),
164 | "{}/logits_real".format(split): logits_real.detach().mean(),
165 | "{}/logits_fake".format(split): logits_fake.detach().mean()
166 | }
167 | return d_loss, log
168 |
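A small worked example for the two standalone helpers in this file: adopt_weight keeps the GAN term switched off until discriminator_iter_start (e.g. disc_start: 50001 in the autoencoder configs below), and measure_perplexity reports how evenly a batch of code indices uses the codebook. Importing this module requires taming-transformers to be installed, as in environment.yaml.

    import torch

    from ldm.modules.losses.vqperceptual import adopt_weight, measure_perplexity

    print(adopt_weight(1.0, global_step=10_000, threshold=50_001))   # 0.0 -> GAN loss off
    print(adopt_weight(1.0, global_step=60_000, threshold=50_001))   # 1.0 -> GAN loss on

    indices = torch.randint(0, 8192, (4096,))                        # fake codebook assignments
    perplexity, cluster_use = measure_perplexity(indices, n_embed=8192)
    print(perplexity.item(), cluster_use.item())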
--------------------------------------------------------------------------------
/stable_diffusion/ldm/util.py:
--------------------------------------------------------------------------------
1 | import importlib
2 |
3 | import torch
4 | import numpy as np
5 | from collections import abc
6 | from einops import rearrange
7 | from functools import partial
8 |
9 | import multiprocessing as mp
10 | from threading import Thread
11 | from queue import Queue
12 |
13 | from inspect import isfunction
14 | from PIL import Image, ImageDraw, ImageFont
15 |
16 |
17 | def log_txt_as_img(wh, xc, size=10):
18 | # wh a tuple of (width, height)
19 | # xc a list of captions to plot
20 | b = len(xc)
21 | txts = list()
22 | for bi in range(b):
23 | txt = Image.new("RGB", wh, color="white")
24 | draw = ImageDraw.Draw(txt)
25 | font = ImageFont.truetype('data/DejaVuSans.ttf', size=size)
26 | nc = int(40 * (wh[0] / 256))
27 | lines = "\n".join(xc[bi][start:start + nc] for start in range(0, len(xc[bi]), nc))
28 |
29 | try:
30 | draw.text((0, 0), lines, fill="black", font=font)
31 | except UnicodeEncodeError:
32 | print("Cant encode string for logging. Skipping.")
33 |
34 | txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0
35 | txts.append(txt)
36 | txts = np.stack(txts)
37 | txts = torch.tensor(txts)
38 | return txts
39 |
40 |
41 | def ismap(x):
42 | if not isinstance(x, torch.Tensor):
43 | return False
44 | return (len(x.shape) == 4) and (x.shape[1] > 3)
45 |
46 |
47 | def isimage(x):
48 | if not isinstance(x, torch.Tensor):
49 | return False
50 | return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1)
51 |
52 |
53 | def exists(x):
54 | return x is not None
55 |
56 |
57 | def default(val, d):
58 | if exists(val):
59 | return val
60 | return d() if isfunction(d) else d
61 |
62 |
63 | def mean_flat(tensor):
64 | """
65 | https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/nn.py#L86
66 | Take the mean over all non-batch dimensions.
67 | """
68 | return tensor.mean(dim=list(range(1, len(tensor.shape))))
69 |
70 |
71 | def count_params(model, verbose=False):
72 | total_params = sum(p.numel() for p in model.parameters())
73 | if verbose:
74 | print(f"{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params.")
75 | return total_params
76 |
77 |
78 | def instantiate_from_config(config):
79 | if not "target" in config:
80 | if config == '__is_first_stage__':
81 | return None
82 | elif config == "__is_unconditional__":
83 | return None
84 | raise KeyError("Expected key `target` to instantiate.")
85 | return get_obj_from_str(config["target"])(**config.get("params", dict()))
86 |
87 |
88 | def get_obj_from_str(string, reload=False):
89 | module, cls = string.rsplit(".", 1)
90 | if reload:
91 | module_imp = importlib.import_module(module)
92 | importlib.reload(module_imp)
93 | return getattr(importlib.import_module(module, package=None), cls)
94 |
95 |
96 | def _do_parallel_data_prefetch(func, Q, data, idx, idx_to_fn=False):
97 | # create dummy dataset instance
98 |
99 | # run prefetching
100 | if idx_to_fn:
101 | res = func(data, worker_id=idx)
102 | else:
103 | res = func(data)
104 | Q.put([idx, res])
105 | Q.put("Done")
106 |
107 |
108 | def parallel_data_prefetch(
109 | func: callable, data, n_proc, target_data_type="ndarray", cpu_intensive=True, use_worker_id=False
110 | ):
111 | # if target_data_type not in ["ndarray", "list"]:
112 | # raise ValueError(
113 | # "Data, which is passed to parallel_data_prefetch has to be either of type list or ndarray."
114 | # )
115 | if isinstance(data, np.ndarray) and target_data_type == "list":
116 | raise ValueError("list expected but function got ndarray.")
117 | elif isinstance(data, abc.Iterable):
118 | if isinstance(data, dict):
119 | print(
120 | f'WARNING:"data" argument passed to parallel_data_prefetch is a dict: Using only its values and disregarding keys.'
121 | )
122 | data = list(data.values())
123 | if target_data_type == "ndarray":
124 | data = np.asarray(data)
125 | else:
126 | data = list(data)
127 | else:
128 | raise TypeError(
129 | f"The data, that shall be processed parallel has to be either an np.ndarray or an Iterable, but is actually {type(data)}."
130 | )
131 |
132 | if cpu_intensive:
133 | Q = mp.Queue(1000)
134 | proc = mp.Process
135 | else:
136 | Q = Queue(1000)
137 | proc = Thread
138 | # spawn processes
139 | if target_data_type == "ndarray":
140 | arguments = [
141 | [func, Q, part, i, use_worker_id]
142 | for i, part in enumerate(np.array_split(data, n_proc))
143 | ]
144 | else:
145 | step = (
146 | int(len(data) / n_proc + 1)
147 | if len(data) % n_proc != 0
148 | else int(len(data) / n_proc)
149 | )
150 | arguments = [
151 | [func, Q, part, i, use_worker_id]
152 | for i, part in enumerate(
153 | [data[i: i + step] for i in range(0, len(data), step)]
154 | )
155 | ]
156 | processes = []
157 | for i in range(n_proc):
158 | p = proc(target=_do_parallel_data_prefetch, args=arguments[i])
159 | processes += [p]
160 |
161 | # start processes
162 | print(f"Start prefetching...")
163 | import time
164 |
165 | start = time.time()
166 | gather_res = [[] for _ in range(n_proc)]
167 | try:
168 | for p in processes:
169 | p.start()
170 |
171 | k = 0
172 | while k < len(processes):
173 | # get result
174 | res = Q.get()
175 | if res == "Done":
176 | k += 1
177 | else:
178 | gather_res[res[0]] = res[1]
179 |
180 | except Exception as e:
181 | print("Exception: ", e)
182 | for p in processes:
183 | p.terminate()
184 |
185 | raise e
186 | finally:
187 | for p in processes:
188 | p.join()
189 | print(f"Prefetching complete. [{time.time() - start} sec.]")
190 |
191 | if target_data_type == 'ndarray':
192 | if not isinstance(gather_res[0], np.ndarray):
193 | return np.concatenate([np.asarray(r) for r in gather_res], axis=0)
194 |
195 | # order outputs
196 | return np.concatenate(gather_res, axis=0)
197 | elif target_data_type == 'list':
198 | out = []
199 | for r in gather_res:
200 | out.extend(r)
201 | return out
202 | else:
203 | return gather_res
204 |
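A minimal usage sketch for the helpers above (not part of util.py): it assumes `ldm` is on the Python path, copies the `LambdaWarmUpCosineScheduler` parameters from the inpainting_big scheduler_config further down, and uses `square`, a toy worker invented purely for illustration.

import numpy as np
from ldm.util import instantiate_from_config, parallel_data_prefetch

def square(chunk):
    # toy worker, invented purely for this sketch
    return np.asarray(chunk) ** 2

if __name__ == "__main__":  # guard needed for multiprocessing on spawn-based platforms
    # build an object from a {"target": ..., "params": ...} dict, the same pattern the YAML configs below use;
    # the parameter values are copied from the inpainting_big scheduler_config in this dump
    scheduler = instantiate_from_config({
        "target": "ldm.lr_scheduler.LambdaWarmUpCosineScheduler",
        "params": {"warm_up_steps": 1000, "lr_start": 0.001, "lr_max": 0.1,
                   "lr_min": 0.0001, "max_decay_steps": 50000},
    })
    print(type(scheduler).__name__)

    # split the array across 4 worker processes, square each chunk, and concatenate the results
    out = parallel_data_prefetch(square, np.arange(16), n_proc=4, target_data_type="ndarray")
    print(out)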
--------------------------------------------------------------------------------
/stable_diffusion/models/first_stage_models/kl-f16/config.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 4.5e-06
3 | target: ldm.models.autoencoder.AutoencoderKL
4 | params:
5 | monitor: val/rec_loss
6 | embed_dim: 16
7 | lossconfig:
8 | target: ldm.modules.losses.LPIPSWithDiscriminator
9 | params:
10 | disc_start: 50001
11 | kl_weight: 1.0e-06
12 | disc_weight: 0.5
13 | ddconfig:
14 | double_z: true
15 | z_channels: 16
16 | resolution: 256
17 | in_channels: 3
18 | out_ch: 3
19 | ch: 128
20 | ch_mult:
21 | - 1
22 | - 1
23 | - 2
24 | - 2
25 | - 4
26 | num_res_blocks: 2
27 | attn_resolutions:
28 | - 16
29 | dropout: 0.0
30 | data:
31 | target: main.DataModuleFromConfig
32 | params:
33 | batch_size: 6
34 | wrap: true
35 | train:
36 | target: ldm.data.openimages.FullOpenImagesTrain
37 | params:
38 | size: 384
39 | crop_size: 256
40 | validation:
41 | target: ldm.data.openimages.FullOpenImagesValidation
42 | params:
43 | size: 384
44 | crop_size: 256
45 |
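A minimal consumption sketch for a first-stage config like the one above (not part of the config file). It assumes the working directory is stable_diffusion/, that the dependencies pulled in by LPIPSWithDiscriminator (e.g. taming-transformers) are installed, and that the archive unpacked by scripts/download_first_stages.sh yields a checkpoint named model.ckpt with the same ["state_dict"] layout that scripts/inpaint.py relies on; the checkpoint filename is an assumption.

import torch
from omegaconf import OmegaConf
from ldm.util import instantiate_from_config

config = OmegaConf.load("models/first_stage_models/kl-f16/config.yaml")
model = instantiate_from_config(config.model)   # builds ldm.models.autoencoder.AutoencoderKL
# assumed checkpoint name/location next to the config
state = torch.load("models/first_stage_models/kl-f16/model.ckpt", map_location="cpu")["state_dict"]
model.load_state_dict(state, strict=False)
model.eval()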
--------------------------------------------------------------------------------
/stable_diffusion/models/first_stage_models/kl-f32/config.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 4.5e-06
3 | target: ldm.models.autoencoder.AutoencoderKL
4 | params:
5 | monitor: val/rec_loss
6 | embed_dim: 64
7 | lossconfig:
8 | target: ldm.modules.losses.LPIPSWithDiscriminator
9 | params:
10 | disc_start: 50001
11 | kl_weight: 1.0e-06
12 | disc_weight: 0.5
13 | ddconfig:
14 | double_z: true
15 | z_channels: 64
16 | resolution: 256
17 | in_channels: 3
18 | out_ch: 3
19 | ch: 128
20 | ch_mult:
21 | - 1
22 | - 1
23 | - 2
24 | - 2
25 | - 4
26 | - 4
27 | num_res_blocks: 2
28 | attn_resolutions:
29 | - 16
30 | - 8
31 | dropout: 0.0
32 | data:
33 | target: main.DataModuleFromConfig
34 | params:
35 | batch_size: 6
36 | wrap: true
37 | train:
38 | target: ldm.data.openimages.FullOpenImagesTrain
39 | params:
40 | size: 384
41 | crop_size: 256
42 | validation:
43 | target: ldm.data.openimages.FullOpenImagesValidation
44 | params:
45 | size: 384
46 | crop_size: 256
47 |
--------------------------------------------------------------------------------
/stable_diffusion/models/first_stage_models/kl-f4/config.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 4.5e-06
3 | target: ldm.models.autoencoder.AutoencoderKL
4 | params:
5 | monitor: val/rec_loss
6 | embed_dim: 3
7 | lossconfig:
8 | target: ldm.modules.losses.LPIPSWithDiscriminator
9 | params:
10 | disc_start: 50001
11 | kl_weight: 1.0e-06
12 | disc_weight: 0.5
13 | ddconfig:
14 | double_z: true
15 | z_channels: 3
16 | resolution: 256
17 | in_channels: 3
18 | out_ch: 3
19 | ch: 128
20 | ch_mult:
21 | - 1
22 | - 2
23 | - 4
24 | num_res_blocks: 2
25 | attn_resolutions: []
26 | dropout: 0.0
27 | data:
28 | target: main.DataModuleFromConfig
29 | params:
30 | batch_size: 10
31 | wrap: true
32 | train:
33 | target: ldm.data.openimages.FullOpenImagesTrain
34 | params:
35 | size: 384
36 | crop_size: 256
37 | validation:
38 | target: ldm.data.openimages.FullOpenImagesValidation
39 | params:
40 | size: 384
41 | crop_size: 256
42 |
--------------------------------------------------------------------------------
/stable_diffusion/models/first_stage_models/kl-f8/config.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 4.5e-06
3 | target: ldm.models.autoencoder.AutoencoderKL
4 | params:
5 | monitor: val/rec_loss
6 | embed_dim: 4
7 | lossconfig:
8 | target: ldm.modules.losses.LPIPSWithDiscriminator
9 | params:
10 | disc_start: 50001
11 | kl_weight: 1.0e-06
12 | disc_weight: 0.5
13 | ddconfig:
14 | double_z: true
15 | z_channels: 4
16 | resolution: 256
17 | in_channels: 3
18 | out_ch: 3
19 | ch: 128
20 | ch_mult:
21 | - 1
22 | - 2
23 | - 4
24 | - 4
25 | num_res_blocks: 2
26 | attn_resolutions: []
27 | dropout: 0.0
28 | data:
29 | target: main.DataModuleFromConfig
30 | params:
31 | batch_size: 4
32 | wrap: true
33 | train:
34 | target: ldm.data.openimages.FullOpenImagesTrain
35 | params:
36 | size: 384
37 | crop_size: 256
38 | validation:
39 | target: ldm.data.openimages.FullOpenImagesValidation
40 | params:
41 | size: 384
42 | crop_size: 256
43 |
--------------------------------------------------------------------------------
/stable_diffusion/models/first_stage_models/vq-f16/config.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 4.5e-06
3 | target: ldm.models.autoencoder.VQModel
4 | params:
5 | embed_dim: 8
6 | n_embed: 16384
7 | ddconfig:
8 | double_z: false
9 | z_channels: 8
10 | resolution: 256
11 | in_channels: 3
12 | out_ch: 3
13 | ch: 128
14 | ch_mult:
15 | - 1
16 | - 1
17 | - 2
18 | - 2
19 | - 4
20 | num_res_blocks: 2
21 | attn_resolutions:
22 | - 16
23 | dropout: 0.0
24 | lossconfig:
25 | target: taming.modules.losses.vqperceptual.VQLPIPSWithDiscriminator
26 | params:
27 | disc_conditional: false
28 | disc_in_channels: 3
29 | disc_start: 250001
30 | disc_weight: 0.75
31 | disc_num_layers: 2
32 | codebook_weight: 1.0
33 |
34 | data:
35 | target: main.DataModuleFromConfig
36 | params:
37 | batch_size: 14
38 | num_workers: 20
39 | wrap: true
40 | train:
41 | target: ldm.data.openimages.FullOpenImagesTrain
42 | params:
43 | size: 384
44 | crop_size: 256
45 | validation:
46 | target: ldm.data.openimages.FullOpenImagesValidation
47 | params:
48 | size: 384
49 | crop_size: 256
50 |
--------------------------------------------------------------------------------
/stable_diffusion/models/first_stage_models/vq-f4-noattn/config.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 4.5e-06
3 | target: ldm.models.autoencoder.VQModel
4 | params:
5 | embed_dim: 3
6 | n_embed: 8192
7 | monitor: val/rec_loss
8 |
9 | ddconfig:
10 | attn_type: none
11 | double_z: false
12 | z_channels: 3
13 | resolution: 256
14 | in_channels: 3
15 | out_ch: 3
16 | ch: 128
17 | ch_mult:
18 | - 1
19 | - 2
20 | - 4
21 | num_res_blocks: 2
22 | attn_resolutions: []
23 | dropout: 0.0
24 | lossconfig:
25 | target: taming.modules.losses.vqperceptual.VQLPIPSWithDiscriminator
26 | params:
27 | disc_conditional: false
28 | disc_in_channels: 3
29 | disc_start: 11
30 | disc_weight: 0.75
31 | codebook_weight: 1.0
32 |
33 | data:
34 | target: main.DataModuleFromConfig
35 | params:
36 | batch_size: 8
37 | num_workers: 12
38 | wrap: true
39 | train:
40 | target: ldm.data.openimages.FullOpenImagesTrain
41 | params:
42 | crop_size: 256
43 | validation:
44 | target: ldm.data.openimages.FullOpenImagesValidation
45 | params:
46 | crop_size: 256
47 |
--------------------------------------------------------------------------------
/stable_diffusion/models/first_stage_models/vq-f4/config.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 4.5e-06
3 | target: ldm.models.autoencoder.VQModel
4 | params:
5 | embed_dim: 3
6 | n_embed: 8192
7 | monitor: val/rec_loss
8 |
9 | ddconfig:
10 | double_z: false
11 | z_channels: 3
12 | resolution: 256
13 | in_channels: 3
14 | out_ch: 3
15 | ch: 128
16 | ch_mult:
17 | - 1
18 | - 2
19 | - 4
20 | num_res_blocks: 2
21 | attn_resolutions: []
22 | dropout: 0.0
23 | lossconfig:
24 | target: taming.modules.losses.vqperceptual.VQLPIPSWithDiscriminator
25 | params:
26 | disc_conditional: false
27 | disc_in_channels: 3
28 | disc_start: 0
29 | disc_weight: 0.75
30 | codebook_weight: 1.0
31 |
32 | data:
33 | target: main.DataModuleFromConfig
34 | params:
35 | batch_size: 8
36 | num_workers: 16
37 | wrap: true
38 | train:
39 | target: ldm.data.openimages.FullOpenImagesTrain
40 | params:
41 | crop_size: 256
42 | validation:
43 | target: ldm.data.openimages.FullOpenImagesValidation
44 | params:
45 | crop_size: 256
46 |
--------------------------------------------------------------------------------
/stable_diffusion/models/first_stage_models/vq-f8-n256/config.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 4.5e-06
3 | target: ldm.models.autoencoder.VQModel
4 | params:
5 | embed_dim: 4
6 | n_embed: 256
7 | monitor: val/rec_loss
8 | ddconfig:
9 | double_z: false
10 | z_channels: 4
11 | resolution: 256
12 | in_channels: 3
13 | out_ch: 3
14 | ch: 128
15 | ch_mult:
16 | - 1
17 | - 2
18 | - 2
19 | - 4
20 | num_res_blocks: 2
21 | attn_resolutions:
22 | - 32
23 | dropout: 0.0
24 | lossconfig:
25 | target: taming.modules.losses.vqperceptual.VQLPIPSWithDiscriminator
26 | params:
27 | disc_conditional: false
28 | disc_in_channels: 3
29 | disc_start: 250001
30 | disc_weight: 0.75
31 | codebook_weight: 1.0
32 |
33 | data:
34 | target: main.DataModuleFromConfig
35 | params:
36 | batch_size: 10
37 | num_workers: 20
38 | wrap: true
39 | train:
40 | target: ldm.data.openimages.FullOpenImagesTrain
41 | params:
42 | size: 384
43 | crop_size: 256
44 | validation:
45 | target: ldm.data.openimages.FullOpenImagesValidation
46 | params:
47 | size: 384
48 | crop_size: 256
49 |
--------------------------------------------------------------------------------
/stable_diffusion/models/first_stage_models/vq-f8/config.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 4.5e-06
3 | target: ldm.models.autoencoder.VQModel
4 | params:
5 | embed_dim: 4
6 | n_embed: 16384
7 | monitor: val/rec_loss
8 | ddconfig:
9 | double_z: false
10 | z_channels: 4
11 | resolution: 256
12 | in_channels: 3
13 | out_ch: 3
14 | ch: 128
15 | ch_mult:
16 | - 1
17 | - 2
18 | - 2
19 | - 4
20 | num_res_blocks: 2
21 | attn_resolutions:
22 | - 32
23 | dropout: 0.0
24 | lossconfig:
25 | target: taming.modules.losses.vqperceptual.VQLPIPSWithDiscriminator
26 | params:
27 | disc_conditional: false
28 | disc_in_channels: 3
29 | disc_num_layers: 2
30 | disc_start: 1
31 | disc_weight: 0.6
32 | codebook_weight: 1.0
33 | data:
34 | target: main.DataModuleFromConfig
35 | params:
36 | batch_size: 10
37 | num_workers: 20
38 | wrap: true
39 | train:
40 | target: ldm.data.openimages.FullOpenImagesTrain
41 | params:
42 | size: 384
43 | crop_size: 256
44 | validation:
45 | target: ldm.data.openimages.FullOpenImagesValidation
46 | params:
47 | size: 384
48 | crop_size: 256
49 |
--------------------------------------------------------------------------------
/stable_diffusion/models/ldm/bsr_sr/config.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 1.0e-06
3 | target: ldm.models.diffusion.ddpm.LatentDiffusion
4 | params:
5 | linear_start: 0.0015
6 | linear_end: 0.0155
7 | log_every_t: 100
8 | timesteps: 1000
9 | loss_type: l2
10 | first_stage_key: image
11 | cond_stage_key: LR_image
12 | image_size: 64
13 | channels: 3
14 | concat_mode: true
15 | cond_stage_trainable: false
16 | unet_config:
17 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel
18 | params:
19 | image_size: 64
20 | in_channels: 6
21 | out_channels: 3
22 | model_channels: 160
23 | attention_resolutions:
24 | - 16
25 | - 8
26 | num_res_blocks: 2
27 | channel_mult:
28 | - 1
29 | - 2
30 | - 2
31 | - 4
32 | num_head_channels: 32
33 | first_stage_config:
34 | target: ldm.models.autoencoder.VQModelInterface
35 | params:
36 | embed_dim: 3
37 | n_embed: 8192
38 | monitor: val/rec_loss
39 | ddconfig:
40 | double_z: false
41 | z_channels: 3
42 | resolution: 256
43 | in_channels: 3
44 | out_ch: 3
45 | ch: 128
46 | ch_mult:
47 | - 1
48 | - 2
49 | - 4
50 | num_res_blocks: 2
51 | attn_resolutions: []
52 | dropout: 0.0
53 | lossconfig:
54 | target: torch.nn.Identity
55 | cond_stage_config:
56 | target: torch.nn.Identity
57 | data:
58 | target: main.DataModuleFromConfig
59 | params:
60 | batch_size: 64
61 | wrap: false
62 | num_workers: 12
63 | train:
64 | target: ldm.data.openimages.SuperresOpenImagesAdvancedTrain
65 | params:
66 | size: 256
67 | degradation: bsrgan_light
68 | downscale_f: 4
69 | min_crop_f: 0.5
70 | max_crop_f: 1.0
71 | random_crop: true
72 | validation:
73 | target: ldm.data.openimages.SuperresOpenImagesAdvancedValidation
74 | params:
75 | size: 256
76 | degradation: bsrgan_light
77 | downscale_f: 4
78 | min_crop_f: 0.5
79 | max_crop_f: 1.0
80 | random_crop: true
81 |
--------------------------------------------------------------------------------
/stable_diffusion/models/ldm/celeba256/config.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 2.0e-06
3 | target: ldm.models.diffusion.ddpm.LatentDiffusion
4 | params:
5 | linear_start: 0.0015
6 | linear_end: 0.0195
7 | num_timesteps_cond: 1
8 | log_every_t: 200
9 | timesteps: 1000
10 | first_stage_key: image
11 | cond_stage_key: class_label
12 | image_size: 64
13 | channels: 3
14 | cond_stage_trainable: false
15 | concat_mode: false
16 | monitor: val/loss
17 | unet_config:
18 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel
19 | params:
20 | image_size: 64
21 | in_channels: 3
22 | out_channels: 3
23 | model_channels: 224
24 | attention_resolutions:
25 | - 8
26 | - 4
27 | - 2
28 | num_res_blocks: 2
29 | channel_mult:
30 | - 1
31 | - 2
32 | - 3
33 | - 4
34 | num_head_channels: 32
35 | first_stage_config:
36 | target: ldm.models.autoencoder.VQModelInterface
37 | params:
38 | embed_dim: 3
39 | n_embed: 8192
40 | ddconfig:
41 | double_z: false
42 | z_channels: 3
43 | resolution: 256
44 | in_channels: 3
45 | out_ch: 3
46 | ch: 128
47 | ch_mult:
48 | - 1
49 | - 2
50 | - 4
51 | num_res_blocks: 2
52 | attn_resolutions: []
53 | dropout: 0.0
54 | lossconfig:
55 | target: torch.nn.Identity
56 | cond_stage_config: __is_unconditional__
57 | data:
58 | target: main.DataModuleFromConfig
59 | params:
60 | batch_size: 48
61 | num_workers: 5
62 | wrap: false
63 | train:
64 | target: ldm.data.faceshq.CelebAHQTrain
65 | params:
66 | size: 256
67 | validation:
68 | target: ldm.data.faceshq.CelebAHQValidation
69 | params:
70 | size: 256
71 |
--------------------------------------------------------------------------------
/stable_diffusion/models/ldm/cin256/config.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 1.0e-06
3 | target: ldm.models.diffusion.ddpm.LatentDiffusion
4 | params:
5 | linear_start: 0.0015
6 | linear_end: 0.0195
7 | num_timesteps_cond: 1
8 | log_every_t: 200
9 | timesteps: 1000
10 | first_stage_key: image
11 | cond_stage_key: class_label
12 | image_size: 32
13 | channels: 4
14 | cond_stage_trainable: true
15 | conditioning_key: crossattn
16 | monitor: val/loss_simple_ema
17 | unet_config:
18 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel
19 | params:
20 | image_size: 32
21 | in_channels: 4
22 | out_channels: 4
23 | model_channels: 256
24 | attention_resolutions:
25 | - 4
26 | - 2
27 | - 1
28 | num_res_blocks: 2
29 | channel_mult:
30 | - 1
31 | - 2
32 | - 4
33 | num_head_channels: 32
34 | use_spatial_transformer: true
35 | transformer_depth: 1
36 | context_dim: 512
37 | first_stage_config:
38 | target: ldm.models.autoencoder.VQModelInterface
39 | params:
40 | embed_dim: 4
41 | n_embed: 16384
42 | ddconfig:
43 | double_z: false
44 | z_channels: 4
45 | resolution: 256
46 | in_channels: 3
47 | out_ch: 3
48 | ch: 128
49 | ch_mult:
50 | - 1
51 | - 2
52 | - 2
53 | - 4
54 | num_res_blocks: 2
55 | attn_resolutions:
56 | - 32
57 | dropout: 0.0
58 | lossconfig:
59 | target: torch.nn.Identity
60 | cond_stage_config:
61 | target: ldm.modules.encoders.modules.ClassEmbedder
62 | params:
63 | embed_dim: 512
64 | key: class_label
65 | data:
66 | target: main.DataModuleFromConfig
67 | params:
68 | batch_size: 64
69 | num_workers: 12
70 | wrap: false
71 | train:
72 | target: ldm.data.imagenet.ImageNetTrain
73 | params:
74 | config:
75 | size: 256
76 | validation:
77 | target: ldm.data.imagenet.ImageNetValidation
78 | params:
79 | config:
80 | size: 256
81 |
--------------------------------------------------------------------------------
/stable_diffusion/models/ldm/ffhq256/config.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 2.0e-06
3 | target: ldm.models.diffusion.ddpm.LatentDiffusion
4 | params:
5 | linear_start: 0.0015
6 | linear_end: 0.0195
7 | num_timesteps_cond: 1
8 | log_every_t: 200
9 | timesteps: 1000
10 | first_stage_key: image
11 | cond_stage_key: class_label
12 | image_size: 64
13 | channels: 3
14 | cond_stage_trainable: false
15 | concat_mode: false
16 | monitor: val/loss
17 | unet_config:
18 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel
19 | params:
20 | image_size: 64
21 | in_channels: 3
22 | out_channels: 3
23 | model_channels: 224
24 | attention_resolutions:
25 | - 8
26 | - 4
27 | - 2
28 | num_res_blocks: 2
29 | channel_mult:
30 | - 1
31 | - 2
32 | - 3
33 | - 4
34 | num_head_channels: 32
35 | first_stage_config:
36 | target: ldm.models.autoencoder.VQModelInterface
37 | params:
38 | embed_dim: 3
39 | n_embed: 8192
40 | ddconfig:
41 | double_z: false
42 | z_channels: 3
43 | resolution: 256
44 | in_channels: 3
45 | out_ch: 3
46 | ch: 128
47 | ch_mult:
48 | - 1
49 | - 2
50 | - 4
51 | num_res_blocks: 2
52 | attn_resolutions: []
53 | dropout: 0.0
54 | lossconfig:
55 | target: torch.nn.Identity
56 | cond_stage_config: __is_unconditional__
57 | data:
58 | target: main.DataModuleFromConfig
59 | params:
60 | batch_size: 42
61 | num_workers: 5
62 | wrap: false
63 | train:
64 | target: ldm.data.faceshq.FFHQTrain
65 | params:
66 | size: 256
67 | validation:
68 | target: ldm.data.faceshq.FFHQValidation
69 | params:
70 | size: 256
71 |
--------------------------------------------------------------------------------
/stable_diffusion/models/ldm/inpainting_big/config.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 1.0e-06
3 | target: ldm.models.diffusion.ddpm.LatentDiffusion
4 | params:
5 | linear_start: 0.0015
6 | linear_end: 0.0205
7 | log_every_t: 100
8 | timesteps: 1000
9 | loss_type: l1
10 | first_stage_key: image
11 | cond_stage_key: masked_image
12 | image_size: 64
13 | channels: 3
14 | concat_mode: true
15 | monitor: val/loss
16 | scheduler_config:
17 | target: ldm.lr_scheduler.LambdaWarmUpCosineScheduler
18 | params:
19 | verbosity_interval: 0
20 | warm_up_steps: 1000
21 | max_decay_steps: 50000
22 | lr_start: 0.001
23 | lr_max: 0.1
24 | lr_min: 0.0001
25 | unet_config:
26 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel
27 | params:
28 | image_size: 64
29 | in_channels: 7
30 | out_channels: 3
31 | model_channels: 256
32 | attention_resolutions:
33 | - 8
34 | - 4
35 | - 2
36 | num_res_blocks: 2
37 | channel_mult:
38 | - 1
39 | - 2
40 | - 3
41 | - 4
42 | num_heads: 8
43 | resblock_updown: true
44 | first_stage_config:
45 | target: ldm.models.autoencoder.VQModelInterface
46 | params:
47 | embed_dim: 3
48 | n_embed: 8192
49 | monitor: val/rec_loss
50 | ddconfig:
51 | attn_type: none
52 | double_z: false
53 | z_channels: 3
54 | resolution: 256
55 | in_channels: 3
56 | out_ch: 3
57 | ch: 128
58 | ch_mult:
59 | - 1
60 | - 2
61 | - 4
62 | num_res_blocks: 2
63 | attn_resolutions: []
64 | dropout: 0.0
65 | lossconfig:
66 | target: ldm.modules.losses.contperceptual.DummyLoss
67 | cond_stage_config: __is_first_stage__
68 |
--------------------------------------------------------------------------------
/stable_diffusion/models/ldm/layout2img-openimages256/config.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 2.0e-06
3 | target: ldm.models.diffusion.ddpm.LatentDiffusion
4 | params:
5 | linear_start: 0.0015
6 | linear_end: 0.0205
7 | log_every_t: 100
8 | timesteps: 1000
9 | loss_type: l1
10 | first_stage_key: image
11 | cond_stage_key: coordinates_bbox
12 | image_size: 64
13 | channels: 3
14 | conditioning_key: crossattn
15 | cond_stage_trainable: true
16 | unet_config:
17 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel
18 | params:
19 | image_size: 64
20 | in_channels: 3
21 | out_channels: 3
22 | model_channels: 128
23 | attention_resolutions:
24 | - 8
25 | - 4
26 | - 2
27 | num_res_blocks: 2
28 | channel_mult:
29 | - 1
30 | - 2
31 | - 3
32 | - 4
33 | num_head_channels: 32
34 | use_spatial_transformer: true
35 | transformer_depth: 3
36 | context_dim: 512
37 | first_stage_config:
38 | target: ldm.models.autoencoder.VQModelInterface
39 | params:
40 | embed_dim: 3
41 | n_embed: 8192
42 | monitor: val/rec_loss
43 | ddconfig:
44 | double_z: false
45 | z_channels: 3
46 | resolution: 256
47 | in_channels: 3
48 | out_ch: 3
49 | ch: 128
50 | ch_mult:
51 | - 1
52 | - 2
53 | - 4
54 | num_res_blocks: 2
55 | attn_resolutions: []
56 | dropout: 0.0
57 | lossconfig:
58 | target: torch.nn.Identity
59 | cond_stage_config:
60 | target: ldm.modules.encoders.modules.BERTEmbedder
61 | params:
62 | n_embed: 512
63 | n_layer: 16
64 | vocab_size: 8192
65 | max_seq_len: 92
66 | use_tokenizer: false
67 | monitor: val/loss_simple_ema
68 | data:
69 | target: main.DataModuleFromConfig
70 | params:
71 | batch_size: 24
72 | wrap: false
73 | num_workers: 10
74 | train:
75 | target: ldm.data.openimages.OpenImagesBBoxTrain
76 | params:
77 | size: 256
78 | validation:
79 | target: ldm.data.openimages.OpenImagesBBoxValidation
80 | params:
81 | size: 256
82 |
--------------------------------------------------------------------------------
/stable_diffusion/models/ldm/lsun_beds256/config.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 2.0e-06
3 | target: ldm.models.diffusion.ddpm.LatentDiffusion
4 | params:
5 | linear_start: 0.0015
6 | linear_end: 0.0195
7 | num_timesteps_cond: 1
8 | log_every_t: 200
9 | timesteps: 1000
10 | first_stage_key: image
11 | cond_stage_key: class_label
12 | image_size: 64
13 | channels: 3
14 | cond_stage_trainable: false
15 | concat_mode: false
16 | monitor: val/loss
17 | unet_config:
18 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel
19 | params:
20 | image_size: 64
21 | in_channels: 3
22 | out_channels: 3
23 | model_channels: 224
24 | attention_resolutions:
25 | - 8
26 | - 4
27 | - 2
28 | num_res_blocks: 2
29 | channel_mult:
30 | - 1
31 | - 2
32 | - 3
33 | - 4
34 | num_head_channels: 32
35 | first_stage_config:
36 | target: ldm.models.autoencoder.VQModelInterface
37 | params:
38 | embed_dim: 3
39 | n_embed: 8192
40 | ddconfig:
41 | double_z: false
42 | z_channels: 3
43 | resolution: 256
44 | in_channels: 3
45 | out_ch: 3
46 | ch: 128
47 | ch_mult:
48 | - 1
49 | - 2
50 | - 4
51 | num_res_blocks: 2
52 | attn_resolutions: []
53 | dropout: 0.0
54 | lossconfig:
55 | target: torch.nn.Identity
56 | cond_stage_config: __is_unconditional__
57 | data:
58 | target: main.DataModuleFromConfig
59 | params:
60 | batch_size: 48
61 | num_workers: 5
62 | wrap: false
63 | train:
64 | target: ldm.data.lsun.LSUNBedroomsTrain
65 | params:
66 | size: 256
67 | validation:
68 | target: ldm.data.lsun.LSUNBedroomsValidation
69 | params:
70 | size: 256
71 |
--------------------------------------------------------------------------------
/stable_diffusion/models/ldm/lsun_churches256/config.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 5.0e-05
3 | target: ldm.models.diffusion.ddpm.LatentDiffusion
4 | params:
5 | linear_start: 0.0015
6 | linear_end: 0.0155
7 | num_timesteps_cond: 1
8 | log_every_t: 200
9 | timesteps: 1000
10 | loss_type: l1
11 | first_stage_key: image
12 | cond_stage_key: image
13 | image_size: 32
14 | channels: 4
15 | cond_stage_trainable: false
16 | concat_mode: false
17 | scale_by_std: true
18 | monitor: val/loss_simple_ema
19 | scheduler_config:
20 | target: ldm.lr_scheduler.LambdaLinearScheduler
21 | params:
22 | warm_up_steps:
23 | - 10000
24 | cycle_lengths:
25 | - 10000000000000
26 | f_start:
27 | - 1.0e-06
28 | f_max:
29 | - 1.0
30 | f_min:
31 | - 1.0
32 | unet_config:
33 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel
34 | params:
35 | image_size: 32
36 | in_channels: 4
37 | out_channels: 4
38 | model_channels: 192
39 | attention_resolutions:
40 | - 1
41 | - 2
42 | - 4
43 | - 8
44 | num_res_blocks: 2
45 | channel_mult:
46 | - 1
47 | - 2
48 | - 2
49 | - 4
50 | - 4
51 | num_heads: 8
52 | use_scale_shift_norm: true
53 | resblock_updown: true
54 | first_stage_config:
55 | target: ldm.models.autoencoder.AutoencoderKL
56 | params:
57 | embed_dim: 4
58 | monitor: val/rec_loss
59 | ddconfig:
60 | double_z: true
61 | z_channels: 4
62 | resolution: 256
63 | in_channels: 3
64 | out_ch: 3
65 | ch: 128
66 | ch_mult:
67 | - 1
68 | - 2
69 | - 4
70 | - 4
71 | num_res_blocks: 2
72 | attn_resolutions: []
73 | dropout: 0.0
74 | lossconfig:
75 | target: torch.nn.Identity
76 |
77 | cond_stage_config: '__is_unconditional__'
78 |
79 | data:
80 | target: main.DataModuleFromConfig
81 | params:
82 | batch_size: 96
83 | num_workers: 5
84 | wrap: false
85 | train:
86 | target: ldm.data.lsun.LSUNChurchesTrain
87 | params:
88 | size: 256
89 | validation:
90 | target: ldm.data.lsun.LSUNChurchesValidation
91 | params:
92 | size: 256
93 |
--------------------------------------------------------------------------------
/stable_diffusion/models/ldm/semantic_synthesis256/config.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 1.0e-06
3 | target: ldm.models.diffusion.ddpm.LatentDiffusion
4 | params:
5 | linear_start: 0.0015
6 | linear_end: 0.0205
7 | log_every_t: 100
8 | timesteps: 1000
9 | loss_type: l1
10 | first_stage_key: image
11 | cond_stage_key: segmentation
12 | image_size: 64
13 | channels: 3
14 | concat_mode: true
15 | cond_stage_trainable: true
16 | unet_config:
17 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel
18 | params:
19 | image_size: 64
20 | in_channels: 6
21 | out_channels: 3
22 | model_channels: 128
23 | attention_resolutions:
24 | - 32
25 | - 16
26 | - 8
27 | num_res_blocks: 2
28 | channel_mult:
29 | - 1
30 | - 4
31 | - 8
32 | num_heads: 8
33 | first_stage_config:
34 | target: ldm.models.autoencoder.VQModelInterface
35 | params:
36 | embed_dim: 3
37 | n_embed: 8192
38 | ddconfig:
39 | double_z: false
40 | z_channels: 3
41 | resolution: 256
42 | in_channels: 3
43 | out_ch: 3
44 | ch: 128
45 | ch_mult:
46 | - 1
47 | - 2
48 | - 4
49 | num_res_blocks: 2
50 | attn_resolutions: []
51 | dropout: 0.0
52 | lossconfig:
53 | target: torch.nn.Identity
54 | cond_stage_config:
55 | target: ldm.modules.encoders.modules.SpatialRescaler
56 | params:
57 | n_stages: 2
58 | in_channels: 182
59 | out_channels: 3
60 |
--------------------------------------------------------------------------------
/stable_diffusion/models/ldm/semantic_synthesis512/config.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 1.0e-06
3 | target: ldm.models.diffusion.ddpm.LatentDiffusion
4 | params:
5 | linear_start: 0.0015
6 | linear_end: 0.0205
7 | log_every_t: 100
8 | timesteps: 1000
9 | loss_type: l1
10 | first_stage_key: image
11 | cond_stage_key: segmentation
12 | image_size: 128
13 | channels: 3
14 | concat_mode: true
15 | cond_stage_trainable: true
16 | unet_config:
17 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel
18 | params:
19 | image_size: 128
20 | in_channels: 6
21 | out_channels: 3
22 | model_channels: 128
23 | attention_resolutions:
24 | - 32
25 | - 16
26 | - 8
27 | num_res_blocks: 2
28 | channel_mult:
29 | - 1
30 | - 4
31 | - 8
32 | num_heads: 8
33 | first_stage_config:
34 | target: ldm.models.autoencoder.VQModelInterface
35 | params:
36 | embed_dim: 3
37 | n_embed: 8192
38 | monitor: val/rec_loss
39 | ddconfig:
40 | double_z: false
41 | z_channels: 3
42 | resolution: 256
43 | in_channels: 3
44 | out_ch: 3
45 | ch: 128
46 | ch_mult:
47 | - 1
48 | - 2
49 | - 4
50 | num_res_blocks: 2
51 | attn_resolutions: []
52 | dropout: 0.0
53 | lossconfig:
54 | target: torch.nn.Identity
55 | cond_stage_config:
56 | target: ldm.modules.encoders.modules.SpatialRescaler
57 | params:
58 | n_stages: 2
59 | in_channels: 182
60 | out_channels: 3
61 | data:
62 | target: main.DataModuleFromConfig
63 | params:
64 | batch_size: 8
65 | wrap: false
66 | num_workers: 10
67 | train:
68 | target: ldm.data.landscapes.RFWTrain
69 | params:
70 | size: 768
71 | crop_size: 512
72 | segmentation_to_float32: true
73 | validation:
74 | target: ldm.data.landscapes.RFWValidation
75 | params:
76 | size: 768
77 | crop_size: 512
78 | segmentation_to_float32: true
79 |
--------------------------------------------------------------------------------
/stable_diffusion/models/ldm/text2img256/config.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 2.0e-06
3 | target: ldm.models.diffusion.ddpm.LatentDiffusion
4 | params:
5 | linear_start: 0.0015
6 | linear_end: 0.0195
7 | num_timesteps_cond: 1
8 | log_every_t: 200
9 | timesteps: 1000
10 | first_stage_key: image
11 | cond_stage_key: caption
12 | image_size: 64
13 | channels: 3
14 | cond_stage_trainable: true
15 | conditioning_key: crossattn
16 | monitor: val/loss_simple_ema
17 | unet_config:
18 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel
19 | params:
20 | image_size: 64
21 | in_channels: 3
22 | out_channels: 3
23 | model_channels: 192
24 | attention_resolutions:
25 | - 8
26 | - 4
27 | - 2
28 | num_res_blocks: 2
29 | channel_mult:
30 | - 1
31 | - 2
32 | - 3
33 | - 5
34 | num_head_channels: 32
35 | use_spatial_transformer: true
36 | transformer_depth: 1
37 | context_dim: 640
38 | first_stage_config:
39 | target: ldm.models.autoencoder.VQModelInterface
40 | params:
41 | embed_dim: 3
42 | n_embed: 8192
43 | ddconfig:
44 | double_z: false
45 | z_channels: 3
46 | resolution: 256
47 | in_channels: 3
48 | out_ch: 3
49 | ch: 128
50 | ch_mult:
51 | - 1
52 | - 2
53 | - 4
54 | num_res_blocks: 2
55 | attn_resolutions: []
56 | dropout: 0.0
57 | lossconfig:
58 | target: torch.nn.Identity
59 | cond_stage_config:
60 | target: ldm.modules.encoders.modules.BERTEmbedder
61 | params:
62 | n_embed: 640
63 | n_layer: 32
64 | data:
65 | target: main.DataModuleFromConfig
66 | params:
67 | batch_size: 28
68 | num_workers: 10
69 | wrap: false
70 | train:
71 | target: ldm.data.previews.pytorch_dataset.PreviewsTrain
72 | params:
73 | size: 256
74 | validation:
75 | target: ldm.data.previews.pytorch_dataset.PreviewsValidation
76 | params:
77 | size: 256
78 |
--------------------------------------------------------------------------------
/stable_diffusion/scripts/download_first_stages.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | wget -O models/first_stage_models/kl-f4/model.zip https://ommer-lab.com/files/latent-diffusion/kl-f4.zip
3 | wget -O models/first_stage_models/kl-f8/model.zip https://ommer-lab.com/files/latent-diffusion/kl-f8.zip
4 | wget -O models/first_stage_models/kl-f16/model.zip https://ommer-lab.com/files/latent-diffusion/kl-f16.zip
5 | wget -O models/first_stage_models/kl-f32/model.zip https://ommer-lab.com/files/latent-diffusion/kl-f32.zip
6 | wget -O models/first_stage_models/vq-f4/model.zip https://ommer-lab.com/files/latent-diffusion/vq-f4.zip
7 | wget -O models/first_stage_models/vq-f4-noattn/model.zip https://ommer-lab.com/files/latent-diffusion/vq-f4-noattn.zip
8 | wget -O models/first_stage_models/vq-f8/model.zip https://ommer-lab.com/files/latent-diffusion/vq-f8.zip
9 | wget -O models/first_stage_models/vq-f8-n256/model.zip https://ommer-lab.com/files/latent-diffusion/vq-f8-n256.zip
10 | wget -O models/first_stage_models/vq-f16/model.zip https://ommer-lab.com/files/latent-diffusion/vq-f16.zip
11 |
12 |
13 |
14 | cd models/first_stage_models/kl-f4
15 | unzip -o model.zip
16 |
17 | cd ../kl-f8
18 | unzip -o model.zip
19 |
20 | cd ../kl-f16
21 | unzip -o model.zip
22 |
23 | cd ../kl-f32
24 | unzip -o model.zip
25 |
26 | cd ../vq-f4
27 | unzip -o model.zip
28 |
29 | cd ../vq-f4-noattn
30 | unzip -o model.zip
31 |
32 | cd ../vq-f8
33 | unzip -o model.zip
34 |
35 | cd ../vq-f8-n256
36 | unzip -o model.zip
37 |
38 | cd ../vq-f16
39 | unzip -o model.zip
40 |
41 | cd ../..
--------------------------------------------------------------------------------
/stable_diffusion/scripts/download_models.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | wget -O models/ldm/celeba256/celeba-256.zip https://ommer-lab.com/files/latent-diffusion/celeba.zip
3 | wget -O models/ldm/ffhq256/ffhq-256.zip https://ommer-lab.com/files/latent-diffusion/ffhq.zip
4 | wget -O models/ldm/lsun_churches256/lsun_churches-256.zip https://ommer-lab.com/files/latent-diffusion/lsun_churches.zip
5 | wget -O models/ldm/lsun_beds256/lsun_beds-256.zip https://ommer-lab.com/files/latent-diffusion/lsun_bedrooms.zip
6 | wget -O models/ldm/text2img256/model.zip https://ommer-lab.com/files/latent-diffusion/text2img.zip
7 | wget -O models/ldm/cin256/model.zip https://ommer-lab.com/files/latent-diffusion/cin.zip
8 | wget -O models/ldm/semantic_synthesis512/model.zip https://ommer-lab.com/files/latent-diffusion/semantic_synthesis.zip
9 | wget -O models/ldm/semantic_synthesis256/model.zip https://ommer-lab.com/files/latent-diffusion/semantic_synthesis256.zip
10 | wget -O models/ldm/bsr_sr/model.zip https://ommer-lab.com/files/latent-diffusion/sr_bsr.zip
11 | wget -O models/ldm/layout2img-openimages256/model.zip https://ommer-lab.com/files/latent-diffusion/layout2img_model.zip
12 | wget -O models/ldm/inpainting_big/model.zip https://ommer-lab.com/files/latent-diffusion/inpainting_big.zip
13 |
14 |
15 |
16 | cd models/ldm/celeba256
17 | unzip -o celeba-256.zip
18 |
19 | cd ../ffhq256
20 | unzip -o ffhq-256.zip
21 |
22 | cd ../lsun_churches256
23 | unzip -o lsun_churches-256.zip
24 |
25 | cd ../lsun_beds256
26 | unzip -o lsun_beds-256.zip
27 |
28 | cd ../text2img256
29 | unzip -o model.zip
30 |
31 | cd ../cin256
32 | unzip -o model.zip
33 |
34 | cd ../semantic_synthesis512
35 | unzip -o model.zip
36 |
37 | cd ../semantic_synthesis256
38 | unzip -o model.zip
39 |
40 | cd ../bsr_sr
41 | unzip -o model.zip
42 |
43 | cd ../layout2img-openimages256
44 | unzip -o model.zip
45 |
46 | cd ../inpainting_big
47 | unzip -o model.zip
48 |
49 | cd ../..
50 |
--------------------------------------------------------------------------------
/stable_diffusion/scripts/inpaint.py:
--------------------------------------------------------------------------------
1 | import argparse, os, sys, glob
2 | from omegaconf import OmegaConf
3 | from PIL import Image
4 | from tqdm import tqdm
5 | import numpy as np
6 | import torch
7 | from main import instantiate_from_config
8 | from ldm.models.diffusion.ddim import DDIMSampler
9 |
10 |
11 | def make_batch(image, mask, device):
12 | image = np.array(Image.open(image).convert("RGB"))
13 | image = image.astype(np.float32)/255.0
14 | image = image[None].transpose(0,3,1,2)
15 | image = torch.from_numpy(image)
16 |
17 | mask = np.array(Image.open(mask).convert("L"))
18 | mask = mask.astype(np.float32)/255.0
19 | mask = mask[None,None]
20 | mask[mask < 0.5] = 0
21 | mask[mask >= 0.5] = 1
22 | mask = torch.from_numpy(mask)
23 |
24 | masked_image = (1-mask)*image
25 |
26 | batch = {"image": image, "mask": mask, "masked_image": masked_image}
27 | for k in batch:
28 | batch[k] = batch[k].to(device=device)
29 | batch[k] = batch[k]*2.0-1.0
30 | return batch
31 |
32 |
33 | if __name__ == "__main__":
34 | parser = argparse.ArgumentParser()
35 | parser.add_argument(
36 | "--indir",
37 | type=str,
38 | nargs="?",
39 | help="dir containing image-mask pairs (`example.png` and `example_mask.png`)",
40 | )
41 | parser.add_argument(
42 | "--outdir",
43 | type=str,
44 | nargs="?",
45 | help="dir to write results to",
46 | )
47 | parser.add_argument(
48 | "--steps",
49 | type=int,
50 | default=50,
51 | help="number of ddim sampling steps",
52 | )
53 | opt = parser.parse_args()
54 |
55 | masks = sorted(glob.glob(os.path.join(opt.indir, "*_mask.png")))
56 | images = [x.replace("_mask.png", ".png") for x in masks]
57 | print(f"Found {len(masks)} inputs.")
58 |
59 | config = OmegaConf.load("models/ldm/inpainting_big/config.yaml")
60 | model = instantiate_from_config(config.model)
61 | model.load_state_dict(torch.load("models/ldm/inpainting_big/last.ckpt")["state_dict"],
62 | strict=False)
63 |
64 | device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
65 | model = model.to(device)
66 | sampler = DDIMSampler(model)
67 |
68 | os.makedirs(opt.outdir, exist_ok=True)
69 | with torch.no_grad():
70 | with model.ema_scope():
71 | for image, mask in tqdm(zip(images, masks)):
72 | outpath = os.path.join(opt.outdir, os.path.split(image)[1])
73 | batch = make_batch(image, mask, device=device)
74 |
75 | # encode masked image and concat downsampled mask
76 | c = model.cond_stage_model.encode(batch["masked_image"])
77 | cc = torch.nn.functional.interpolate(batch["mask"],
78 | size=c.shape[-2:])
79 | c = torch.cat((c, cc), dim=1)
80 |
81 | shape = (c.shape[1]-1,)+c.shape[2:]
82 | samples_ddim, _ = sampler.sample(S=opt.steps,
83 | conditioning=c,
84 | batch_size=c.shape[0],
85 | shape=shape,
86 | verbose=False)
87 | x_samples_ddim = model.decode_first_stage(samples_ddim)
88 |
89 | image = torch.clamp((batch["image"]+1.0)/2.0,
90 | min=0.0, max=1.0)
91 | mask = torch.clamp((batch["mask"]+1.0)/2.0,
92 | min=0.0, max=1.0)
93 | predicted_image = torch.clamp((x_samples_ddim+1.0)/2.0,
94 | min=0.0, max=1.0)
95 |
96 | inpainted = (1-mask)*image+mask*predicted_image
97 | inpainted = inpainted.cpu().numpy().transpose(0,2,3,1)[0]*255
98 | Image.fromarray(inpainted.astype(np.uint8)).save(outpath)
99 |
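A minimal sketch (not part of the script) of the --indir layout described by the argument parser above: each input is an image plus a companion *_mask.png, where white mask pixels mark the region that make_batch() removes and the sampler fills in. The solid grey test image and the inpaint_inputs folder name are purely illustrative.

import os
import numpy as np
from PIL import Image

os.makedirs("inpaint_inputs", exist_ok=True)
# placeholder RGB image; in practice this is the photo you want to inpaint
Image.fromarray(np.full((512, 512, 3), 127, dtype=np.uint8)).save("inpaint_inputs/example.png")
# binary mask: white (255) pixels are zeroed out in the masked image and re-generated by the sampler
mask = np.zeros((512, 512), dtype=np.uint8)
mask[128:384, 128:384] = 255
Image.fromarray(mask).save("inpaint_inputs/example_mask.png")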
--------------------------------------------------------------------------------
/stable_diffusion/scripts/latent_imagenet_diffusion.ipynb.REMOVED.git-id:
--------------------------------------------------------------------------------
1 | 607f94fc7d3ef6d8d1627017215476d9dfc7ddc4
--------------------------------------------------------------------------------
/stable_diffusion/scripts/tests/test_watermark.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import fire
3 | from imwatermark import WatermarkDecoder
4 |
5 |
6 | def testit(img_path):
7 | bgr = cv2.imread(img_path)
8 | decoder = WatermarkDecoder('bytes', 136)
9 | watermark = decoder.decode(bgr, 'dwtDct')
10 | try:
11 | dec = watermark.decode('utf-8')
12 | except Exception:
13 | dec = "null"
14 | print(dec)
15 |
16 |
17 | if __name__ == "__main__":
18 | fire.Fire(testit)
--------------------------------------------------------------------------------
/stable_diffusion/scripts/train_searcher.py:
--------------------------------------------------------------------------------
1 | import os, sys
2 | import numpy as np
3 | import scann
4 | import argparse
5 | import glob
6 | from multiprocessing import cpu_count
7 | from tqdm import tqdm
8 |
9 | from ldm.util import parallel_data_prefetch
10 |
11 |
12 | def search_bruteforce(searcher):
13 | return searcher.score_brute_force().build()
14 |
15 |
16 | def search_partioned_ah(searcher, dims_per_block, aiq_threshold, reorder_k,
17 | partioning_trainsize, num_leaves, num_leaves_to_search):
18 | return searcher.tree(num_leaves=num_leaves,
19 | num_leaves_to_search=num_leaves_to_search,
20 | training_sample_size=partioning_trainsize). \
21 | score_ah(dims_per_block, anisotropic_quantization_threshold=aiq_threshold).reorder(reorder_k).build()
22 |
23 |
24 | def search_ah(searcher, dims_per_block, aiq_threshold, reorder_k):
25 | return searcher.score_ah(dims_per_block, anisotropic_quantization_threshold=aiq_threshold).reorder(
26 | reorder_k).build()
27 |
28 | def load_datapool(dpath):
29 |
30 |
31 | def load_single_file(saved_embeddings):
32 | compressed = np.load(saved_embeddings)
33 | database = {key: compressed[key] for key in compressed.files}
34 | return database
35 |
36 | def load_multi_files(data_archive):
37 | database = {key: [] for key in data_archive[0].files}
38 | for d in tqdm(data_archive, desc=f'Loading datapool from {len(data_archive)} individual files.'):
39 | for key in d.files:
40 | database[key].append(d[key])
41 |
42 | return database
43 |
44 | print(f'Load saved patch embedding from "{dpath}"')
45 | file_content = glob.glob(os.path.join(dpath, '*.npz'))
46 |
47 | if len(file_content) == 1:
48 | data_pool = load_single_file(file_content[0])
49 | elif len(file_content) > 1:
50 | data = [np.load(f) for f in file_content]
51 | prefetched_data = parallel_data_prefetch(load_multi_files, data,
52 | n_proc=min(len(data), cpu_count()), target_data_type='dict')
53 |
54 | data_pool = {key: np.concatenate([od[key] for od in prefetched_data], axis=1)[0] for key in prefetched_data[0].keys()}
55 | else:
56 | raise ValueError(f'No npz-files found in the specified path "{dpath}". Does this directory exist?')
57 |
58 | print(f'Finished loading of retrieval database of length {data_pool["embedding"].shape[0]}.')
59 | return data_pool
60 |
61 |
62 | def train_searcher(opt,
63 | metric='dot_product',
64 | partioning_trainsize=None,
65 | reorder_k=None,
66 | # todo tune
67 | aiq_thld=0.2,
68 | dims_per_block=2,
69 | num_leaves=None,
70 | num_leaves_to_search=None,):
71 |
72 | data_pool = load_datapool(opt.database)
73 | k = opt.knn
74 |
75 | if not reorder_k:
76 | reorder_k = 2 * k
77 |
78 | # normalize
79 | # embeddings =
80 | searcher = scann.scann_ops_pybind.builder(data_pool['embedding'] / np.linalg.norm(data_pool['embedding'], axis=1)[:, np.newaxis], k, metric)
81 | pool_size = data_pool['embedding'].shape[0]
82 |
83 | print(*(['#'] * 100))
84 | print('Initializing scaNN searcher with the following values:')
85 | print(f'k: {k}')
86 | print(f'metric: {metric}')
87 | print(f'reorder_k: {reorder_k}')
88 | print(f'anisotropic_quantization_threshold: {aiq_thld}')
89 | print(f'dims_per_block: {dims_per_block}')
90 | print(*(['#'] * 100))
91 | print('Start training searcher...')
92 | print(f'N samples in pool is {pool_size}')
93 |
94 | # this reflects the recommended design choices proposed at
95 | # https://github.com/google-research/google-research/blob/aca5f2e44e301af172590bb8e65711f0c9ee0cfd/scann/docs/algorithms.md
96 | if pool_size < 2e4:
97 | print('Using brute force search.')
98 | searcher = search_bruteforce(searcher)
99 | elif 2e4 <= pool_size < 1e5:
100 | print('Using asymmetric hashing search and reordering.')
101 | searcher = search_ah(searcher, dims_per_block, aiq_thld, reorder_k)
102 | else:
103 | print('Using partitioning, asymmetric hashing search and reordering.')
104 |
105 | if not partioning_trainsize:
106 | partioning_trainsize = data_pool['embedding'].shape[0] // 10
107 | if not num_leaves:
108 | num_leaves = int(np.sqrt(pool_size))
109 |
110 | if not num_leaves_to_search:
111 | num_leaves_to_search = max(num_leaves // 20, 1)
112 |
113 | print('Partitioning params:')
114 | print(f'num_leaves: {num_leaves}')
115 | print(f'num_leaves_to_search: {num_leaves_to_search}')
116 | # self.searcher = self.search_ah(searcher, dims_per_block, aiq_thld, reorder_k)
117 | searcher = search_partioned_ah(searcher, dims_per_block, aiq_thld, reorder_k,
118 | partioning_trainsize, num_leaves, num_leaves_to_search)
119 |
120 | print('Finished training searcher.')
121 | searcher_savedir = opt.target_path
122 | os.makedirs(searcher_savedir, exist_ok=True)
123 | searcher.serialize(searcher_savedir)
124 | print(f'Saved trained searcher under "{searcher_savedir}"')
125 |
126 | if __name__ == '__main__':
127 | sys.path.append(os.getcwd())
128 | parser = argparse.ArgumentParser()
129 | parser.add_argument('--database',
130 | '-d',
131 | default='data/rdm/retrieval_databases/openimages',
132 | type=str,
133 | help='path to the folder containing the clip features of the database')
134 | parser.add_argument('--target_path',
135 | '-t',
136 | default='data/rdm/searchers/openimages',
137 | type=str,
138 | help='path to the target folder where the searcher shall be stored.')
139 | parser.add_argument('--knn',
140 | '-k',
141 | default=20,
142 | type=int,
143 | help='number of nearest neighbors for which the searcher shall be optimized')
144 |
145 | opt, _ = parser.parse_known_args()
146 |
147 | train_searcher(opt,)
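A minimal sketch (not part of this file) of driving train_searcher programmatically instead of through the CLI. The import path is an assumption that depends on how the scripts directory ends up on sys.path, and the three values simply restate the argparse defaults above; scann must be installed and the database folder must contain the *.npz patch embeddings.

from argparse import Namespace
from train_searcher import train_searcher  # assumed import; run from stable_diffusion/scripts or adjust sys.path

opt = Namespace(
    database="data/rdm/retrieval_databases/openimages",  # folder with the *.npz clip features
    target_path="data/rdm/searchers/openimages",         # where the serialized searcher is written
    knn=20,                                              # number of nearest neighbors to optimize for
)
train_searcher(opt)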
--------------------------------------------------------------------------------
/stable_diffusion/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup, find_packages
2 |
3 | setup(
4 | name='latent-diffusion',
5 | version='0.0.1',
6 | description='',
7 | packages=find_packages(),
8 | install_requires=[
9 | 'torch',
10 | 'numpy',
11 | 'tqdm',
12 | ],
13 | )
--------------------------------------------------------------------------------