├── .github └── workflow │ └── auto_assign.yml ├── LICENSE ├── README.md ├── environment.yml ├── experiments-results-analysis └── experiments │ ├── add_dog8_config_01.sh │ ├── add_dog8_config_01.yaml │ ├── replace_dog8_config_01.sh │ └── replace_dog8_config_01.yaml ├── requirements.txt ├── setup.cfg ├── setup.py └── src ├── __init__.py ├── auto_evaluation.py ├── clip_img_ret.py ├── configs ├── autoencoder │ ├── autoencoder_kl_16x16x16.yaml │ ├── autoencoder_kl_32x32x4.yaml │ ├── autoencoder_kl_64x64x3.yaml │ └── autoencoder_kl_8x8x64.yaml ├── dream-edit │ ├── add_dog_default.yaml │ └── edit_backpack_default.yaml ├── latent-diffusion │ ├── celebahq-ldm-vq-4.yaml │ ├── cin-ldm-vq-f8.yaml │ ├── cin256-v2.yaml │ ├── ffhq-ldm-vq-4.yaml │ ├── lsun_bedrooms-ldm-vq-4.yaml │ ├── lsun_churches-ldm-kl-8.yaml │ └── txt2img-1p4B-eval.yaml ├── retrieval-augmented-diffusion │ └── 768x768.yaml └── stable-diffusion │ └── v1-inference.yaml ├── generate_new.py ├── iterate_generate.py ├── ldm ├── __init__.py ├── data │ ├── __init__.py │ ├── base.py │ ├── imagenet.py │ └── lsun.py ├── lr_scheduler.py ├── models │ ├── __init__.py │ ├── autoencoder.py │ └── diffusion │ │ ├── __init__.py │ │ ├── classifier.py │ │ ├── ddim.py │ │ ├── ddpm.py │ │ ├── dpm_solver │ │ ├── __init__.py │ │ ├── dpm_solver.py │ │ └── sampler.py │ │ └── plms.py ├── modules │ ├── __init__.py │ ├── attention.py │ ├── diffusionmodules │ │ ├── __init__.py │ │ ├── model.py │ │ ├── openaimodel.py │ │ └── util.py │ ├── distributions │ │ ├── __init__.py │ │ └── distributions.py │ ├── ema.py │ ├── encoders │ │ ├── __init__.py │ │ └── modules.py │ ├── image_degradation │ │ ├── __init__.py │ │ ├── bsrgan.py │ │ ├── bsrgan_light.py │ │ ├── utils │ │ │ └── test.png │ │ └── utils_image.py │ ├── losses │ │ ├── __init__.py │ │ ├── contperceptual.py │ │ └── vqperceptual.py │ └── x_transformer.py └── util.py ├── metrics ├── __init__.py ├── clip_vit.py ├── dino_vit.py ├── distances.py └── evaluate_dino.py ├── pipelines ├── __init__.py ├── extract_object_pipeline.py ├── imagecaption_pipelines.py └── inpainting_pipelines.py └── utils ├── losses.py ├── mask_helper.py ├── path_finder.py └── visual_helper.py /.github/workflow/auto_assign.yml: -------------------------------------------------------------------------------- 1 | name: Auto Assign 2 | on: 3 | issues: 4 | types: [opened] 5 | pull_request: 6 | types: [opened] 7 | jobs: 8 | run: 9 | runs-on: ubuntu-latest 10 | permissions: 11 | issues: write 12 | pull-requests: write 13 | steps: 14 | - name: 'Auto-assign issue' 15 | uses: pozil/auto-assign-issue@v1 16 | with: 17 | repo-token: ${{ secrets.GITHUB_TOKEN }} 18 | assignees: 19 | - DreamEditTeam 20 | - vinesmsuic 21 | - lim142857 22 | numOfAssignee: 1 23 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Tianle Li, Max Ku, Cong Wei and Wenhu Chen 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | 
copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # DreamEdit: Subject-driven Image Editing 2 | [![arXiv](https://img.shields.io/badge/arXiv-2306.12624-b31b1b.svg)](https://arxiv.org/abs/2306.12624) 3 | 4 | Replace the subject in a given image with a customized one, or add your customized subject to any provided background! 5 | 6 | ![image](https://github.com/DreamEditBenchTeam/DreamEdit/assets/34955859/b66e3809-967d-46d5-a3ba-87879550106b) 7 | 8 | Models, code, and dataset for [DreamEdit: Subject-driven Image Editing](https://arxiv.org/abs/2306.12624). 9 | 10 | Check the [project website](https://dreameditbenchteam.github.io/) for demos and data examples. 11 | 12 | ## News 13 | * [26/10/2023] Paper accepted to [TMLR 2023](https://jmlr.org/tmlr/). 14 | * [20/07/2023] We released our DreamBooth weights on [UWaterloo Vault](https://vault.cs.uwaterloo.ca/s/EiNjg9yTAKEFgF2) and [HuggingFace Space](https://huggingface.co/ImagenHub/DreamEdit-DreamBooth-Models). 15 | 16 | ## Requirements 17 | A suitable conda environment named `dream_edit` can be created and activated with: 18 | 19 | ```bash 20 | conda env create -f environment.yml 21 | conda activate dream_edit 22 | 23 | # To update env 24 | conda env update dream_edit --file environment.yml --prune 25 | ``` 26 | > ^There is currently a problem with the environment file setup; we will fix it soon. 27 | 28 | > For now, to get the code running: 29 | > Our repo requires dependencies from several other repos. Please follow the official installation instructions for: 30 | > * [GroundingDINO](https://github.com/IDEA-Research/GroundingDINO) 31 | > * [LangSAM](https://github.com/luca-medeiros/lang-segment-anything/tree/main) 32 | > * [Stable Diffusion](https://github.com/CompVis/stable-diffusion) 33 | > * [GLIGEN's fork of diffusers](https://github.com/gligen/diffusers) 34 | 35 | For example, in addition to the environment installed above, we also install dependencies with: 36 | ```bash 37 | pip install torch==1.12.1+cu116 torchvision==0.13.1+cu116 torchaudio==0.12.1 --extra-index-url https://download.pytorch.org/whl/cu116 38 | cd ..
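# (assumed intent, not stated in the repo) step out of the DreamEdit repo so the dependency repos below are cloned next to it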
39 | 40 | # Install SAM 41 | pip install git+https://github.com/luca-medeiros/lang-segment-anything.git  # already included in the yml file 42 | git clone https://github.com/IDEA-Research/Grounded-Segment-Anything.git 43 | 44 | # To enable GPU in GroundingDINO: 45 | conda install -c conda-forge cudatoolkit-dev -y 46 | export BUILD_WITH_CUDA=True 47 | export CUDA_HOME=$CONDA_PREFIX 48 | export AM_I_DOCKER=False 49 | 50 | # This might be optional: 51 | cd ~/dreamedit_env_dependency/Grounded-Segment-Anything/ 52 | python -m pip install -e segment_anything 53 | python -m pip install -e GroundingDINO 54 | git submodule update --init --recursive 55 | cd grounded-sam-osx && bash install.sh 56 | 57 | pip install opencv-python pycocotools matplotlib onnxruntime onnx ipykernel 58 | pip install accelerate 59 | 60 | # Install diffusers from the GLIGEN fork to enable the GLIGEN pipeline: 61 | git clone https://github.com/gligen/diffusers.git 62 | cd diffusers 63 | pip install -e . 64 | ``` 65 | 66 | 67 | 68 | ## HuggingFace Dataset 69 | Our dataset is now on HuggingFace: https://huggingface.co/datasets/tianleliphoebe/DreamEditBench. A more self-contained version is available at https://huggingface.co/datasets/tianleliphoebe/DreamEditBench_SelfContained. 70 | ```python 71 | from datasets import load_dataset 72 | dataset = load_dataset("tianleliphoebe/DreamEditBench") 73 | ``` 74 | 75 | ## How to run 76 | Go to the `experiments-results-analysis` folder: 77 | ```bash 78 | cd experiments-results-analysis/experiments 79 | ``` 80 | 81 | Run the script: 82 | ```bash 83 | sh replace_dog8_config_01.sh 84 | ``` 85 | You can change the input paths for the data, the model, and other parameter settings in the corresponding config file (e.g. replace_dog8_config_01.yaml). The ```data: src_img_data_folder_path``` should be set to the path of our DreamEditBench dataset, which can be downloaded as described above. The ```db_dataset_path``` should point to the original DreamBooth dataset, which can be downloaded here: https://github.com/google/dreambooth/tree/main/dataset. 86 | An example fine-tuned DreamBooth model checkpoint for dog8 can be downloaded [here](https://drive.google.com/file/d/1aSyA6CsCchYC1l9DxJiy0CrJsht0K0sj/view?usp=sharing). 87 | 88 | All the other subject fine-tuned model weights can be downloaded at [this link](https://huggingface.co/vinesmsuic/DreamEdit-DreamBooth-Models/tree/main/dreamedit_official_ckpt). You can also fine-tune your own DreamBooth models following the [implementation](https://github.com/XavierXiao/Dreambooth-Stable-Diffusion).
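If you prefer to set these paths programmatically instead of editing the YAML by hand, the snippet below is a minimal sketch assuming the config is a plain OmegaConf file (omegaconf is listed in `requirements.txt`); the local paths are placeholders you would replace with your own.

```python
from omegaconf import OmegaConf

# Load the experiment config and point it at local copies of the data and weights.
cfg = OmegaConf.load("replace_dog8_config_01.yaml")
cfg.data.src_img_data_folder_path = "/path/to/DreamEditBench/"       # placeholder path
cfg.data.db_dataset_path = "/path/to/dreambooth/dataset/"            # placeholder path
cfg.model.sd.ckpt_prefix = "/path/to/dreambooth_checkpoints/"        # placeholder path

# Interpolations such as ${ckpt_base_folder} are resolved lazily on access.
print(cfg.model.sd.ckpt_path)

# Save a local copy and pass it to iterate_generate.py via --config.
OmegaConf.save(cfg, "replace_dog8_config_01_local.yaml")
```

This mirrors what the provided shell scripts do: they simply `cd src` and call `python3 iterate_generate.py --config <path-to-yaml>`.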
89 | 90 | 91 | ## BibTeX 92 | 93 | If you find this paper or repo useful for your research, please consider citing our paper: 94 | ```bibtex 95 | @misc{li2023dreamedit, 96 | title={DreamEdit: Subject-driven Image Editing}, 97 | author={Tianle Li and Max Ku and Cong Wei and Wenhu Chen}, 98 | year={2023}, 99 | eprint={2306.12624}, 100 | archivePrefix={arXiv}, 101 | primaryClass={cs.CV} 102 | } 103 | ``` 104 | 105 | ## Star History 106 | 107 | [![Star History Chart](https://api.star-history.com/svg?repos=DreamEditBenchTeam/DreamEdit&type=Date)](https://star-history.com/#DreamEditBenchTeam/DreamEdit&Date) 108 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: dream_edit 2 | channels: 3 | - pytorch 4 | - conda-forge 5 | - defaults 6 | dependencies: 7 | - python=3.8.16 8 | - pip=20.3 9 | - cudatoolkit=11.0 10 | - pytorch=1.7.0 11 | - torchvision=0.8.1 12 | - numpy=1.19.2 13 | - pip: 14 | - -r requirements.txt -------------------------------------------------------------------------------- /experiments-results-analysis /experiments/add_dog8_config_01.sh: -------------------------------------------------------------------------------- 1 | cd .. # cd to experiments-results-analysis 2 | cd .. # cd to DreamEdit project root 3 | cd src 4 | python3 iterate_generate.py \ 5 | --config "../experiments-results-analysis/experiments/add_dog8_config_01.yaml" -------------------------------------------------------------------------------- /experiments-results-analysis /experiments/add_dog8_config_01.yaml: -------------------------------------------------------------------------------- 1 | # Purpose: 2 | # Validate whether higher scale is the reason for the background noise issue 3 | # Conclusion: 4 | # No 5 | 6 | base_path: "${oc.env:HOME}" 7 | experiment_name: "add_dog8" 8 | 9 | class_obj_name: "dog" 10 | class_folder_name: "dog8" 11 | benchmark_folder_name: "dog" # for our labeling 12 | token_of_class: "wie" 13 | ckpt_base_folder: "dog82023-04-17T00-58-03_dog8_april/" 14 | dream_edit_prompt: "a grey and white border collie dog" 15 | 16 | config_name: "add_${class_folder_name}_config_01" 17 | experiment_result_path: "/home/data/dream_edit_project/results/${experiment_name}/${config_name}/" 18 | 19 | data: 20 | src_img_data_folder_path: "/home/data/dream_edit_project/benchmark/background_images_refine/" 21 | class_name: "${class_obj_name}" 22 | bbox_file_name: "bbox.json" 23 | src_img_file_name: "found0.jpg" 24 | 25 | db_dataset_path: "/home/data/dream_edit_project/benchmark/cvpr_dataset/" 26 | db_folder_name: "${class_folder_name}" 27 | obj_img_file_name: "00.jpg" 28 | 29 | model: 30 | gligen: # GLIGEN: Open-Set Grounded Text-to-Image Generation 31 | gligen_scheduled_sampling_beta: 1 # TODO: What is this? 
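# (assumption, not documented in this repo: in the diffusers GLIGEN pipeline, gligen_scheduled_sampling_beta is the fraction of denoising steps that apply the grounded attention layers; 1 keeps grounding active for every step)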
32 | num_inference_steps: 100 33 | 34 | lang_sam: # Segment Anything 35 | segment_confidence: 0.1 # segmentation confidence in segment-anything 36 | 37 | sd: # Stable Diffusion 38 | conf_path: "configs/stable-diffusion/v1-inference.yaml" 39 | ckpt_prefix: "/home/data/dream_edit_project/model_weights/" 40 | ckpt: "${ckpt_base_folder}" 41 | ckpt_suffix: "checkpoints/last.ckpt" 42 | ckpt_path: "${model.sd.ckpt_prefix}${model.sd.ckpt}${model.sd.ckpt_suffix}" 43 | 44 | de: # DreamEdit 45 | task_type: "add" 46 | special_token: "${token_of_class}" 47 | bounding_box: "bbox.json" 48 | inpaint_after_last_iteration: False # whether to inpaint after the last iteration 49 | postprocessing_type: "sd_inpaint" 50 | use_diffedit: False # baseline of diffedit 51 | 52 | addition_config: 53 | use_copy_paste: False 54 | inpaint_type: "gligen" 55 | automate_prompt: False # whether to generate prompt from BLIP image caption model 56 | inpaint_prompt: "photo of ${dream_edit_prompt}" 57 | inpaint_phrase: "${dream_edit_prompt}" 58 | 59 | mask_config: 60 | mask_dilate_kernel: 20 61 | mask_type: "dilation" 62 | use_bbox_mask_for_first_iteration: True # whether to use bbox as the mask for the first iteration 63 | use_bbox_mask_for_all_iterations: False # whether to use bbox as the mask for all iterations 64 | 65 | ddim: 66 | seed: 42 # the seed (for reproducible sampling) 67 | scale: 5.5 68 | ddim_steps: 40 69 | noise_step: 0 70 | iteration_number: 7 71 | encode_ratio_schedule: 72 | decay_type: "manual" # "linear" or "exponential" or "constant" or "manual" 73 | start_ratio: 0.8 74 | end_ratio: 0.3 75 | manual_ratio_list: [0.5, 0.4, 0.4, 0.4, 0.3, 0.3, 0.3] # only used when decay_type is "manual" 76 | 77 | background_correction_enabled: True # set to false when doing diffedit baseline 78 | background_correction: 79 | iteration_number: 7 # how many iterations to correct the background 80 | use_latents_record: False # reuse the latents from the first iteration 81 | use_background_from_original_image: True 82 | use_obj_mask_from_first_iteration: False # whether we always use the object mask from the first iteration -------------------------------------------------------------------------------- /experiments-results-analysis /experiments/replace_dog8_config_01.sh: -------------------------------------------------------------------------------- 1 | cd .. # cd to experiments-results-analysis 2 | cd .. 
# cd to DreamEdit project root 3 | cd src 4 | python3 iterate_generate.py \ 5 | --config "../experiments-results-analysis/experiments/replace_dog8_config_01.yaml" -------------------------------------------------------------------------------- /experiments-results-analysis /experiments/replace_dog8_config_01.yaml: -------------------------------------------------------------------------------- 1 | # Purpose: 2 | # Validate whether higher scale is the reason for the background noise issue 3 | # Conclusion: 4 | # No 5 | 6 | base_path: "${oc.env:HOME}" 7 | experiment_name: "replace_dog8" 8 | 9 | class_obj_name: "dog" 10 | class_folder_name: "dog8" 11 | benchmark_folder_name: "dog" # for our labeling 12 | token_of_class: "wie" 13 | ckpt_base_folder: "dog82023-04-17T00-58-03_dog8_april/" 14 | dream_edit_prompt: "a grey and white border collie dog" 15 | 16 | config_name: "replace_${class_folder_name}_config_01" 17 | experiment_result_path: "/home/data/dream_edit_project/results/${experiment_name}/${config_name}/" 18 | 19 | data: 20 | src_img_data_folder_path: "/home/data/dream_edit_project/benchmark/ref_images/" 21 | class_name: "${class_obj_name}" 22 | bbox_file_name: "bbox.json" 23 | src_img_file_name: "found0.jpg" 24 | 25 | db_dataset_path: "/home/data/dream_edit_project/benchmark/cvpr_dataset/" 26 | db_folder_name: "${class_folder_name}" 27 | obj_img_file_name: "00.jpg" 28 | 29 | model: 30 | gligen: # GLIGEN: Open-Set Grounded Text-to-Image Generation 31 | gligen_scheduled_sampling_beta: 1 # TODO: What is this? 32 | num_inference_steps: 100 33 | 34 | lang_sam: # Segment Anything 35 | segment_confidence: 0.1 # segmentation confidence in segment-anything 36 | 37 | sd: # Stable Diffusion 38 | conf_path: "configs/stable-diffusion/v1-inference.yaml" 39 | ckpt_prefix: "/home/data/dream_edit_project/model_weights/" 40 | ckpt: "${ckpt_base_folder}" 41 | ckpt_suffix: "checkpoints/last.ckpt" 42 | ckpt_path: "${model.sd.ckpt_prefix}${model.sd.ckpt}${model.sd.ckpt_suffix}" 43 | 44 | de: # DreamEdit 45 | task_type: "replace" 46 | special_token: "${token_of_class}" 47 | bounding_box: "bbox.json" 48 | inpaint_after_last_iteration: False # whether to inpaint after the last iteration 49 | postprocessing_type: "sd_inpaint" 50 | use_diffedit: False # baseline of diffedit 51 | 52 | addition_config: 53 | use_copy_paste: False 54 | inpaint_type: "gligen" 55 | automate_prompt: False # whether to generate prompt from BLIP image caption model 56 | inpaint_prompt: "photo of ${dream_edit_prompt}" 57 | inpaint_phrase: "${dream_edit_prompt}" 58 | 59 | mask_config: 60 | mask_dilate_kernel: 20 61 | mask_type: "dilation" 62 | use_bbox_mask_for_first_iteration: False # whether to use bbox as the mask for the first iteration 63 | use_bbox_mask_for_all_iterations: False # whether to use bbox as the mask for all iterations 64 | 65 | ddim: 66 | seed: 42 # the seed (for reproducible sampling) 67 | scale: 5.5 68 | ddim_steps: 40 69 | noise_step: 0 70 | iteration_number: 5 71 | encode_ratio_schedule: 72 | decay_type: "linear" # "linear" or "exponential" or "constant" or "manual" 73 | start_ratio: 0.6 74 | end_ratio: 0.3 75 | manual_ratio_list: [0.5, 0.4, 0.4, 0.4, 0.3] # only used when decay_type is "manual" 76 | 77 | background_correction_enabled: True # set to false when doing diffedit baseline 78 | background_correction: 79 | iteration_number: 3 # how many iterations to correct the background 80 | use_latents_record: False # reuse the latents from the first iteration 81 | use_background_from_original_image: True 82 | 
use_obj_mask_from_first_iteration: True # whether we always use the object mask from the first iteration -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | albumentations~=1.1.0 2 | opencv-python~=4.2.0.34 3 | pudb~=2019.2 4 | imageio~=2.14.1 5 | imageio-ffmpeg~=0.4.7 6 | pytorch-lightning~=1.5.9 7 | omegaconf~=2.1.1 8 | test-tube>=0.7.5 9 | streamlit>=0.73.1 10 | pillow~=9.0.1 11 | einops~=0.4.1 12 | torch-fidelity~=0.3.0 13 | setuptools~=59.5.0 14 | transformers~=4.18.0 15 | torchmetrics~=0.6.0 16 | kornia~=0.6 17 | diffusers~=0.15.1 18 | tqdm~=4.64.1 19 | matplotlib~=3.7.1 20 | taming-transformers-rom1504~=0.0.6 21 | -e git+https://github.com/openai/CLIP.git@main#egg=clip 22 | git+https://github.com/luca-medeiros/lang-segment-anything.git 23 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [pycodestyle] 2 | max-line-length = 88 3 | 4 | [flake8] 5 | max-line-length = 88 6 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import os 2 | import setuptools 3 | 4 | with open(os.path.join("README.md"), "r", encoding="utf-8") as fh: 5 | long_description = fh.read() 6 | 7 | setuptools.setup( 8 | name="dream_edit", 9 | version="0.0.1", 10 | author="Tiger Lab", 11 | description="Dream Edit", 12 | long_description=long_description, 13 | long_description_content_type="text/markdown", 14 | url="https://github.com/ltl3A87/DreamEdit", 15 | packages=setuptools.find_packages(), 16 | classifiers=[ 17 | "Programming Language :: Python :: 3", 18 | "License :: OSI Approved :: MIT License", 19 | "Operating System :: OS Independent", 20 | ], 21 | python_requires=">=3.8", 22 | install_requires=[ 23 | "torch", 24 | "torchvision", 25 | ], 26 | ) -------------------------------------------------------------------------------- /src/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DreamEditBenchTeam/DreamEdit/14d21b0a3eb6305c1378080ccd8361db0a8adcc0/src/__init__.py -------------------------------------------------------------------------------- /src/auto_evaluation.py: -------------------------------------------------------------------------------- 1 | from metrics.dino_vit import VITs16 2 | from metrics.clip_vit import CLIP 3 | from metrics.evaluate_dino import * 4 | import csv 5 | import os 6 | from utils.mask_helper import * 7 | from utils.visual_helper import * 8 | import glob 9 | 10 | def geo_mean(iterable): 11 | iterable_new = np.where(iterable<0, 0.001, iterable) 12 | a = np.array(iterable_new) 13 | return a.prod()**(1.0/len(a)) 14 | 15 | 16 | subject_folder_names = ["dog", "dog2", "dog3", "dog5", "dog6", "dog7", "dog8", "cat", "cat2", 17 | "bear_plushie", "backpack", "backpack_dog", "berry_bowl", "can", "candle", 18 | "clock", "colorful_sneaker", "duck_toy", "fancy_boot", "grey_sloth_plushie", 19 | "monster_toy", "pink_sunglasses", "poop_emoji", "rc_car", "red_cartoon", "robot_toy", 20 | "shiny_sneaker", "teapot", "vase", "wolf_plushie"] 21 | # subject_folder_names = ["poop_emoji"] 22 | task = "add" 23 | method = "dreambooth" 24 | # result_folder_name = "/home/data/dream_edit_project/results/2023-05-31-replacement" 25 | cvpr_folder_name = 
"/home/data/dream_edit_project/benchmark/cvpr_dataset" 26 | if task == "replace": 27 | if method == "dreamedit": 28 | replacement_index = [3, 6, 9, 12, 15] 29 | result_folder_name = "/home/data/dream_edit_project/results/2023-05-31-replacement" 30 | device = torch.device("cuda:3") if torch.cuda.is_available() else torch.device("cpu") 31 | elif method == "diffedit": 32 | replacement_index = [4] 33 | result_folder_name = "/home/data/dream_edit_project/results/2023-05-31-replacement-diffedit" 34 | device = torch.device("cuda:7") if torch.cuda.is_available() else torch.device("cpu") 35 | elif method == "copygen": 36 | replacement_index = [4, 7, 10, 13, 16] 37 | result_folder_name = "/home/data/dream_edit_project/results/2023-06-13-replacement-copy-paste" 38 | device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu") 39 | elif method == "copypaste": 40 | replacement_index = [1] 41 | result_folder_name = "/home/data/dream_edit_project/results/2023-06-13-replacement-copy-paste" 42 | device = torch.device("cuda:2") if torch.cuda.is_available() else torch.device("cpu") 43 | elif method == "dreambooth": 44 | replacement_index = [0] 45 | result_folder_name = "/home/data/dream_edit_project/results/2023-06-13-replacement-dreambooth" 46 | device = torch.device("cuda:2") if torch.cuda.is_available() else torch.device("cpu") 47 | 48 | else: 49 | if method == "dreamedit": 50 | replacement_index = [5, 8, 11, 14, 17] 51 | result_folder_name = "/home/data/dream_edit_project/results/2023-06-12-addition-combined" 52 | device = torch.device("cuda:1") if torch.cuda.is_available() else torch.device("cpu") 53 | elif method == "copygen": 54 | replacement_index = [4, 7, 10, 13, 16] 55 | result_folder_name = "/home/data/dream_edit_project/results/2023-05-31-addition-copy-paste" 56 | device = torch.device("cuda:6") if torch.cuda.is_available() else torch.device("cpu") 57 | elif method == "copypaste": 58 | replacement_index = [1] 59 | result_folder_name = "/home/data/dream_edit_project/results/2023-05-31-addition-copy-paste" 60 | device = torch.device("cuda:5") if torch.cuda.is_available() else torch.device("cpu") 61 | elif method == "diffedit": 62 | replacement_index = [6] 63 | result_folder_name = "/home/data/dream_edit_project/results/2023-05-31-addition-diffedit" 64 | device = torch.device("cuda:4") if torch.cuda.is_available() else torch.device("cpu") 65 | elif method == "dreambooth": 66 | replacement_index = [0] 67 | result_folder_name = "/home/data/dream_edit_project/results/2023-06-13-addition-dreambooth" 68 | device = torch.device("cuda:2") if torch.cuda.is_available() else torch.device("cpu") 69 | 70 | iterations = len(replacement_index) 71 | dino_model = VITs16(device) 72 | clip_model = CLIP(device) 73 | number_sub_images = 10 74 | image_size = 512 75 | 76 | 77 | if task == "replace": 78 | subject_folder_to_bench_dict = {"dog2": "dog", "dog3": "dog", "dog5": "dog", "dog6": "dog", 79 | "dog7": "dog", "dog8": "dog", "cat2": "cat", "backpack_dog": "backpack"} 80 | else: 81 | subject_folder_to_bench_dict = {"dog2": "dog", "dog3": "dog", "dog5": "dog", "dog6": "dog", 82 | "dog7": "dog", "dog8": "dog", "cat2": "cat", "backpack_dog": "backpack", "shiny_sneaker": "sneaker", 83 | "colorful_sneaker": "sneaker"} 84 | 85 | dino_sub_score_list = [] 86 | clipi_sub_score_list = [] 87 | dino_back_score_list = [] 88 | clipi_back_score_list = [] 89 | overall_list = [] 90 | 91 | dino_sub_score_list_iteration = [[] for i in range(iterations)] 92 | dino_back_score_list_iteration = [[] for i in 
range(iterations)] 93 | clipi_sub_score_list_iteration = [[] for i in range(iterations)] 94 | clipi_back_score_list_iteration = [[] for i in range(iterations)] 95 | overall_list_iteration = [[] for i in range(iterations)] 96 | 97 | 98 | for subject_folder in subject_folder_names: 99 | print(subject_folder, flush=True) 100 | if task == "replace": 101 | config_folder_name = "replace_" + subject_folder + "_config_01" 102 | image_path = os.path.join(result_folder_name, config_folder_name, subject_folder, "replace", "result_all.jpg") 103 | else: 104 | config_folder_name = "add_" + subject_folder + "_config_01" 105 | image_path = os.path.join(result_folder_name, config_folder_name, subject_folder, "add", "result_all.jpg") 106 | if method == "dreambooth": 107 | image_path = os.path.join(result_folder_name, subject_folder+".jpg") 108 | src_img = Image.open(image_path) 109 | obj_image_path_list = glob.glob("/home/data/dream_edit_project/benchmark/cvpr_dataset/"+subject_folder+"/*.jpg") 110 | print("obj_image_path_list: ", obj_image_path_list, flush=True) 111 | print(src_img.size, flush=True) 112 | for i in range(number_sub_images): 113 | print(i, flush=True) 114 | if subject_folder in subject_folder_to_bench_dict: 115 | bench_folder = subject_folder_to_bench_dict[subject_folder] 116 | else: 117 | bench_folder = subject_folder 118 | if task == "replace": 119 | background_img_path = "/home/data/dream_edit_project/benchmark/ref_images/" + bench_folder + "/found" + str(i) + ".jpg" 120 | else: 121 | background_img_path = "/home/data/dream_edit_project/benchmark/background_images_refine/" + bench_folder + "/found" + str( 122 | i) + ".jpg" 123 | iterations_image = [] 124 | iterations_dino_sub = [] 125 | iterations_clip_sub = [] 126 | if method != "dreambooth": 127 | for index in replacement_index: 128 | array_src = np.asarray(src_img) 129 | extract_img = Image.fromarray(np.uint8(array_src[i * image_size:(i + 1) * image_size, index * image_size:(index+1) * image_size, :])) 130 | iterations_image.append(extract_img) 131 | else: 132 | array_src = np.asarray(src_img) 133 | extract_img = Image.fromarray(np.uint8( 134 | array_src[0:image_size, i * image_size:(i + 1) * image_size, :])) 135 | iterations_image.append(extract_img ) 136 | obj_image_list = [Image.open(obj_path).resize((512, 512)) for obj_path in obj_image_path_list] 137 | background_img = Image.open(background_img_path).resize((512, 512)) 138 | for obj_img in obj_image_list: 139 | dino_score_list_subject = evaluate_dino_score_list(obj_img, iterations_image, device, dino_model) 140 | clip_score_list_subject = evaluate_clipi_score_list(obj_img, iterations_image, device, clip_model) 141 | iterations_dino_sub.append(dino_score_list_subject) 142 | iterations_clip_sub.append(clip_score_list_subject) 143 | iterations_dino_sub_avg = np.array(iterations_dino_sub).mean(axis=0).tolist() 144 | iterations_clip_sub_avg = np.array(iterations_clip_sub).mean(axis=0).tolist() 145 | iterations_dino_back = evaluate_dino_score_list(background_img, iterations_image, device, dino_model) 146 | iterations_clip_back = evaluate_clipi_score_list(background_img, iterations_image, device, clip_model) 147 | iterations_result = [iterations_dino_sub_avg, iterations_clip_sub_avg, iterations_dino_back, iterations_dino_back] 148 | iterations_result = np.array(iterations_result).T 149 | iterations_geo_avg = [] 150 | for it in range(iterations): 151 | dino_sub_score_list_iteration[it].append(iterations_dino_sub_avg[it]) 152 | 
dino_back_score_list_iteration[it].append(iterations_dino_back[it]) 153 | clipi_sub_score_list_iteration[it].append(iterations_clip_sub_avg[it]) 154 | clipi_back_score_list_iteration[it].append(iterations_clip_back[it]) 155 | overall_list_iteration[it].append(geo_mean(iterations_result[it])) 156 | iterations_geo_avg.append(geo_mean(iterations_result[it])) 157 | # print("iterations_result: ", iterations_result) 158 | # print(overall_list_iteration[0][-1]) 159 | best_overall = max(iterations_geo_avg) 160 | best_iter = iterations_geo_avg.index(best_overall) 161 | dino_sub_score_list.append(iterations_dino_sub_avg[best_iter]) 162 | dino_back_score_list.append(iterations_dino_back[best_iter]) 163 | clipi_sub_score_list.append(iterations_clip_sub_avg[best_iter]) 164 | clipi_back_score_list.append(iterations_clip_back[best_iter]) 165 | overall_list.append(best_overall) 166 | 167 | 168 | print("dino sub: ", sum(dino_sub_score_list)/len(dino_sub_score_list), flush=True) 169 | print("dino back: ", sum(dino_back_score_list)/len(dino_back_score_list), flush=True) 170 | print("clip sub: ", sum(clipi_sub_score_list)/len(clipi_sub_score_list), flush=True) 171 | print("clip back: ", sum(clipi_back_score_list)/len(clipi_back_score_list), flush=True) 172 | print("overall: ", sum(overall_list)/len(overall_list)) 173 | 174 | for i in range(iterations): 175 | print("=============================", flush=True) 176 | print("dino sub in iteration {}: {}".format(i+1, sum(dino_sub_score_list_iteration[i])/len(dino_sub_score_list_iteration[i])), flush=True) 177 | print("dino back in iteration {}: {}".format(i+1, sum(dino_back_score_list_iteration[i]) / len( 178 | dino_back_score_list_iteration[i])), flush=True) 179 | print("clip sub in iteration {}: {}".format(i+1, sum(clipi_sub_score_list_iteration[i]) / len( 180 | clipi_sub_score_list_iteration[i])), flush=True) 181 | print("clip back in iteration {}: {}".format(i+1, sum(clipi_back_score_list_iteration[i]) / len( 182 | clipi_back_score_list_iteration[i])), flush=True) 183 | print("overall in iteration {}: {}".format(i+1, sum(overall_list_iteration[i]) / len( 184 | overall_list_iteration[i])), flush=True) 185 | print("=============================", flush=True) 186 | 187 | 188 | -------------------------------------------------------------------------------- /src/clip_img_ret.py: -------------------------------------------------------------------------------- 1 | from IPython.display import Image, display 2 | from clip_retrieval.clip_client import ClipClient, Modality 3 | import urllib.request 4 | import logging 5 | import argparse 6 | from pathlib import Path 7 | import os 8 | import socket 9 | 10 | 11 | def parse_args(input_args=None): 12 | parser = argparse.ArgumentParser(description="Script for clip retrieval.") 13 | parser.add_argument( 14 | "--text_prompt", 15 | type=str, 16 | default=None, 17 | required=True, 18 | help="The text prompt of image want to be retrieved.", 19 | ) 20 | parser.add_argument( 21 | "--output_path", 22 | type=str, 23 | default=None, 24 | help="Path to the retrieved image.", 25 | ) 26 | parser.add_argument( 27 | "--num_images", 28 | type=int, 29 | default=10, 30 | required=False, 31 | help="Number of images to be retrived.", 32 | ) 33 | parser.add_argument( 34 | "--output_dir", 35 | type=str, 36 | default="./", 37 | help="The output directory where the model predictions and checkpoints will be written.", 38 | ) 39 | if input_args is not None: 40 | args = parser.parse_args(input_args) 41 | else: 42 | args = parser.parse_args() 43 | 
return args 44 | 45 | 46 | def log_result(result): 47 | id, caption, url, similarity = result["id"], result["caption"], result["url"], result["similarity"] 48 | print(f"id: {id}") 49 | print(f"caption: {caption}") 50 | print(f"url: {url}") 51 | print(f"similarity: {similarity}") 52 | display(Image(url=url, unconfined=True)) 53 | 54 | 55 | def main(args): 56 | # logging_dir = Path(args.output_dir, "0", args.logging_dir) 57 | 58 | logging.basicConfig( 59 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 60 | datefmt="%m/%d/%Y %H:%M:%S", 61 | level=logging.INFO, 62 | ) 63 | logger = logging.getLogger('clip_image_retrieval') 64 | 65 | client = ClipClient( 66 | url="https://knn.laion.ai/knn-service", 67 | indice_name="laion5B-H-14", 68 | aesthetic_score=9, 69 | aesthetic_weight=0.5, 70 | modality=Modality.IMAGE, 71 | num_images=args.num_images) 72 | 73 | results = client.query(text=args.text_prompt) 74 | logger.info("{} of {} are retrieved".format(len(results), args.text_prompt)) 75 | for i, result in enumerate(results): 76 | if not result["url"].endswith(".jpg"): 77 | continue 78 | if not os.path.exists(args.output_path): 79 | os.makedirs(args.output_path) 80 | try: 81 | urllib.request.urlretrieve(result["url"], args.output_path + "found" + str(i) + ".jpeg") 82 | except: 83 | logger.info("failure") 84 | continue 85 | 86 | 87 | if __name__ == "__main__": 88 | args = parse_args() 89 | main(args) 90 | -------------------------------------------------------------------------------- /src/configs/autoencoder/autoencoder_kl_16x16x16.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 4.5e-6 3 | target: ldm.models.autoencoder.AutoencoderKL 4 | params: 5 | monitor: "val/rec_loss" 6 | embed_dim: 16 7 | lossconfig: 8 | target: ldm.modules.losses.LPIPSWithDiscriminator 9 | params: 10 | disc_start: 50001 11 | kl_weight: 0.000001 12 | disc_weight: 0.5 13 | 14 | ddconfig: 15 | double_z: True 16 | z_channels: 16 17 | resolution: 256 18 | in_channels: 3 19 | out_ch: 3 20 | ch: 128 21 | ch_mult: [ 1,1,2,2,4] # num_down = len(ch_mult)-1 22 | num_res_blocks: 2 23 | attn_resolutions: [16] 24 | dropout: 0.0 25 | 26 | 27 | data: 28 | target: main.DataModuleFromConfig 29 | params: 30 | batch_size: 12 31 | wrap: True 32 | train: 33 | target: ldm.data.imagenet.ImageNetSRTrain 34 | params: 35 | size: 256 36 | degradation: pil_nearest 37 | validation: 38 | target: ldm.data.imagenet.ImageNetSRValidation 39 | params: 40 | size: 256 41 | degradation: pil_nearest 42 | 43 | lightning: 44 | callbacks: 45 | image_logger: 46 | target: main.ImageLogger 47 | params: 48 | batch_frequency: 1000 49 | max_images: 8 50 | increase_log_steps: True 51 | 52 | trainer: 53 | benchmark: True 54 | accumulate_grad_batches: 2 55 | -------------------------------------------------------------------------------- /src/configs/autoencoder/autoencoder_kl_32x32x4.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 4.5e-6 3 | target: ldm.models.autoencoder.AutoencoderKL 4 | params: 5 | monitor: "val/rec_loss" 6 | embed_dim: 4 7 | lossconfig: 8 | target: ldm.modules.losses.LPIPSWithDiscriminator 9 | params: 10 | disc_start: 50001 11 | kl_weight: 0.000001 12 | disc_weight: 0.5 13 | 14 | ddconfig: 15 | double_z: True 16 | z_channels: 4 17 | resolution: 256 18 | in_channels: 3 19 | out_ch: 3 20 | ch: 128 21 | ch_mult: [ 1,2,4,4 ] # num_down = len(ch_mult)-1 22 | num_res_blocks: 2 23 | attn_resolutions: [ 
] 24 | dropout: 0.0 25 | 26 | data: 27 | target: main.DataModuleFromConfig 28 | params: 29 | batch_size: 12 30 | wrap: True 31 | train: 32 | target: ldm.data.imagenet.ImageNetSRTrain 33 | params: 34 | size: 256 35 | degradation: pil_nearest 36 | validation: 37 | target: ldm.data.imagenet.ImageNetSRValidation 38 | params: 39 | size: 256 40 | degradation: pil_nearest 41 | 42 | lightning: 43 | callbacks: 44 | image_logger: 45 | target: main.ImageLogger 46 | params: 47 | batch_frequency: 1000 48 | max_images: 8 49 | increase_log_steps: True 50 | 51 | trainer: 52 | benchmark: True 53 | accumulate_grad_batches: 2 54 | -------------------------------------------------------------------------------- /src/configs/autoencoder/autoencoder_kl_64x64x3.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 4.5e-6 3 | target: ldm.models.autoencoder.AutoencoderKL 4 | params: 5 | monitor: "val/rec_loss" 6 | embed_dim: 3 7 | lossconfig: 8 | target: ldm.modules.losses.LPIPSWithDiscriminator 9 | params: 10 | disc_start: 50001 11 | kl_weight: 0.000001 12 | disc_weight: 0.5 13 | 14 | ddconfig: 15 | double_z: True 16 | z_channels: 3 17 | resolution: 256 18 | in_channels: 3 19 | out_ch: 3 20 | ch: 128 21 | ch_mult: [ 1,2,4 ] # num_down = len(ch_mult)-1 22 | num_res_blocks: 2 23 | attn_resolutions: [ ] 24 | dropout: 0.0 25 | 26 | 27 | data: 28 | target: main.DataModuleFromConfig 29 | params: 30 | batch_size: 12 31 | wrap: True 32 | train: 33 | target: ldm.data.imagenet.ImageNetSRTrain 34 | params: 35 | size: 256 36 | degradation: pil_nearest 37 | validation: 38 | target: ldm.data.imagenet.ImageNetSRValidation 39 | params: 40 | size: 256 41 | degradation: pil_nearest 42 | 43 | lightning: 44 | callbacks: 45 | image_logger: 46 | target: main.ImageLogger 47 | params: 48 | batch_frequency: 1000 49 | max_images: 8 50 | increase_log_steps: True 51 | 52 | trainer: 53 | benchmark: True 54 | accumulate_grad_batches: 2 55 | -------------------------------------------------------------------------------- /src/configs/autoencoder/autoencoder_kl_8x8x64.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 4.5e-6 3 | target: ldm.models.autoencoder.AutoencoderKL 4 | params: 5 | monitor: "val/rec_loss" 6 | embed_dim: 64 7 | lossconfig: 8 | target: ldm.modules.losses.LPIPSWithDiscriminator 9 | params: 10 | disc_start: 50001 11 | kl_weight: 0.000001 12 | disc_weight: 0.5 13 | 14 | ddconfig: 15 | double_z: True 16 | z_channels: 64 17 | resolution: 256 18 | in_channels: 3 19 | out_ch: 3 20 | ch: 128 21 | ch_mult: [ 1,1,2,2,4,4] # num_down = len(ch_mult)-1 22 | num_res_blocks: 2 23 | attn_resolutions: [16,8] 24 | dropout: 0.0 25 | 26 | data: 27 | target: main.DataModuleFromConfig 28 | params: 29 | batch_size: 12 30 | wrap: True 31 | train: 32 | target: ldm.data.imagenet.ImageNetSRTrain 33 | params: 34 | size: 256 35 | degradation: pil_nearest 36 | validation: 37 | target: ldm.data.imagenet.ImageNetSRValidation 38 | params: 39 | size: 256 40 | degradation: pil_nearest 41 | 42 | lightning: 43 | callbacks: 44 | image_logger: 45 | target: main.ImageLogger 46 | params: 47 | batch_frequency: 1000 48 | max_images: 8 49 | increase_log_steps: True 50 | 51 | trainer: 52 | benchmark: True 53 | accumulate_grad_batches: 2 54 | -------------------------------------------------------------------------------- /src/configs/dream-edit/add_dog_default.yaml: 
-------------------------------------------------------------------------------- 1 | base_path: "${oc.env:HOME}" 2 | experiment_name: "default" 3 | config_name: "add_dog_default" 4 | experiment_result_path: "${base_path}/DreamEdit/experiments-results-analysis/${experiment_name}/results/${config_name}/" 5 | 6 | data: 7 | src_img_data_folder_path: "${base_path}/DreamEdit/data/background_images_refine/" 8 | class_name: "dog" 9 | bbox_file_name: "bbox.json" 10 | src_img_file_name: "found0.jpg" 11 | 12 | db_dataset_path: "${base_path}/dream_booth2.0/cvpr_dataset/" 13 | db_folder_name: "dog" 14 | obj_img_file_name: "00.jpg" 15 | 16 | model: 17 | gligen: # GLIGEN: Open-Set Grounded Text-to-Image Generation 18 | gligen_scheduled_sampling_beta: 1 # TODO: What is this? 19 | num_inference_steps: 100 20 | 21 | lang_sam: # Segment Anything 22 | segment_confidence: 0.1 # segmentation confidence in segment-anything 23 | 24 | sd: # Stable Diffusion 25 | conf_path: "configs/stable-diffusion/v1-inference.yaml" 26 | ckpt_prefix: "${base_path}/Dreambooth-Stable-Diffusion_org/logs/" 27 | ckpt: "dog2023-04-17T00-32-48_dog_april/" 28 | ckpt_suffix: "checkpoints/last.ckpt" 29 | ckpt_path: "${model.sd.ckpt_prefix}${model.sd.ckpt}${model.sd.ckpt_suffix}" 30 | 31 | de: # DreamEdit 32 | task_type: "add" 33 | special_token: "zwx" 34 | bounding_box: "bbox.json" 35 | inpaint_after_last_iteration: False # whether to inpaint after the last iteration 36 | postprocessing_type: "sd_inpaint" 37 | 38 | addition_config: 39 | use_copy_paste: False 40 | inpaint_type: "gligen" 41 | automate_prompt: False # whether to generate prompt from BLIP image caption model 42 | inpaint_prompt: "photo of a yellow and white corgi dog" 43 | inpaint_phrase: "a yellow and white corgi dog" 44 | 45 | mask_config: 46 | mask_dilate_kernel: 22 47 | mask_type: "dilation" 48 | use_bbox_mask_for_first_iteration: True # whether to use bbox as the mask for the first iteration 49 | use_bbox_mask_for_all_iterations: False # whether to use bbox as the mask for all iterations 50 | 51 | ddim: 52 | seed: 42 # the seed (for reproducible sampling) 53 | scale: 7.5 54 | ddim_steps: 40 55 | noise_step: 0 56 | iteration_number: 10 57 | encode_ratio_schedule: 58 | decay_type: "exponential" # "linear" or "exponential" or "constant" or "manual" 59 | start_ratio: 0.8 60 | end_ratio: 0.3 61 | manual_ratio_list: [0.8, 0.7, 0.6, 0.5, 0.4, 0.3] # only used when decay_type is "manual" 62 | 63 | background_correction_enabled: True 64 | background_correction: 65 | iteration_number: 4 # how many iterations to correct the background 66 | use_latents_record: False # reuse the latents from the first iteration 67 | use_background_from_original_image: True 68 | use_obj_mask_from_first_iteration: False # whether always use the object mask from the first iteration -------------------------------------------------------------------------------- /src/configs/dream-edit/edit_backpack_default.yaml: -------------------------------------------------------------------------------- 1 | base_path: "${oc.env:HOME}" 2 | experiment_name: "default" 3 | config_name: "edit_backpack_default" 4 | experiment_result_path: "${base_path}/DreamEdit/experiments-results-analysis/${experiment_name}/results/${config_name}/" 5 | 6 | data: 7 | src_img_data_folder_path: "${base_path}/DreamEdit/data/ref_images/" 8 | class_name: "backpack" 9 | bbox_file_name: "bbox.json" 10 | src_img_file_name: "found0.jpg" 11 | 12 | db_dataset_path: "${base_path}/dream_booth2.0/cvpr_dataset/" 13 | db_folder_name: "backpack" 14 | 
obj_img_file_name: "00.jpg" 15 | 16 | model: 17 | gligen: # GLIGEN: Open-Set Grounded Text-to-Image Generation 18 | gligen_scheduled_sampling_beta: 1 # TODO: What is this? 19 | num_inference_steps: 150 20 | 21 | lang_sam: # Segment Anything 22 | segment_confidence: 0.3 # segmentation confidence in segment-anything 23 | 24 | sd: # Stable Diffusion 25 | conf_path: "configs/stable-diffusion/v1-inference.yaml" 26 | ckpt_prefix: "${base_path}/Dreambooth-Stable-Diffusion_org/logs/" 27 | ckpt: "backpack2023-04-17T01-18-19_backpack_april/" 28 | ckpt_suffix: "checkpoints/last.ckpt" 29 | ckpt_path: "${model.sd.ckpt_prefix}${model.sd.ckpt}${model.sd.ckpt_suffix}" 30 | 31 | de: # DreamEdit 32 | task_type: "replace" 33 | special_token: "wcj" 34 | bounding_box: "bbox.json" 35 | inpaint_after_last_iteration: False # whether to inpaint after the last iteration 36 | postprocessing_type: "sd_inpaint" 37 | 38 | addition_config: 39 | use_copy_paste: False 40 | inpaint_type: "gligen" 41 | automate_prompt: False # whether to generate prompt from BLIP image caption model 42 | inpaint_prompt: "" 43 | inpaint_phrase: "" 44 | 45 | mask_config: 46 | mask_dilate_kernel: 5 47 | mask_type: "dilation" 48 | use_bbox_mask_for_first_iteration: False # whether to use bbox as the mask for the first iteration 49 | use_bbox_mask_for_all_iterations: False # whether to use bbox as the mask for all iterations 50 | 51 | ddim: 52 | seed: 42 # the seed (for reproducible sampling) 53 | scale: 7.5 54 | ddim_steps: 40 55 | noise_step: 0 56 | iteration_number: 15 57 | encode_ratio_schedule: 58 | decay_type: "exponential" # "linear" or "exponential" or "constant" or "manual" 59 | start_ratio: 0.8 60 | end_ratio: 0.3 61 | manual_ratio_list: [0.8, 0.7, 0.6, 0.5, 0.4, 0.3] # only used when decay_type is "manual" 62 | 63 | background_correction_enabled: True 64 | background_correction: 65 | iteration_number: 4 # how many iterations to correct the background 66 | use_latents_record: False # reuse the latents from the first iteration 67 | use_background_from_original_image: True 68 | use_obj_mask_from_first_iteration: False # whether always use the object mask from the first iteration -------------------------------------------------------------------------------- /src/configs/latent-diffusion/celebahq-ldm-vq-4.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 2.0e-06 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.0015 6 | linear_end: 0.0195 7 | num_timesteps_cond: 1 8 | log_every_t: 200 9 | timesteps: 1000 10 | first_stage_key: image 11 | image_size: 64 12 | channels: 3 13 | monitor: val/loss_simple_ema 14 | 15 | unet_config: 16 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 17 | params: 18 | image_size: 64 19 | in_channels: 3 20 | out_channels: 3 21 | model_channels: 224 22 | attention_resolutions: 23 | # note: this isn\t actually the resolution but 24 | # the downsampling factor, i.e. 
this corresnponds to 25 | # attention on spatial resolution 8,16,32, as the 26 | # spatial reolution of the latents is 64 for f4 27 | - 8 28 | - 4 29 | - 2 30 | num_res_blocks: 2 31 | channel_mult: 32 | - 1 33 | - 2 34 | - 3 35 | - 4 36 | num_head_channels: 32 37 | first_stage_config: 38 | target: ldm.models.autoencoder.VQModelInterface 39 | params: 40 | embed_dim: 3 41 | n_embed: 8192 42 | ckpt_path: models/first_stage_models/vq-f4/model.ckpt 43 | ddconfig: 44 | double_z: false 45 | z_channels: 3 46 | resolution: 256 47 | in_channels: 3 48 | out_ch: 3 49 | ch: 128 50 | ch_mult: 51 | - 1 52 | - 2 53 | - 4 54 | num_res_blocks: 2 55 | attn_resolutions: [] 56 | dropout: 0.0 57 | lossconfig: 58 | target: torch.nn.Identity 59 | cond_stage_config: __is_unconditional__ 60 | data: 61 | target: main.DataModuleFromConfig 62 | params: 63 | batch_size: 48 64 | num_workers: 5 65 | wrap: false 66 | train: 67 | target: taming.data.faceshq.CelebAHQTrain 68 | params: 69 | size: 256 70 | validation: 71 | target: taming.data.faceshq.CelebAHQValidation 72 | params: 73 | size: 256 74 | 75 | 76 | lightning: 77 | callbacks: 78 | image_logger: 79 | target: main.ImageLogger 80 | params: 81 | batch_frequency: 5000 82 | max_images: 8 83 | increase_log_steps: False 84 | 85 | trainer: 86 | benchmark: True -------------------------------------------------------------------------------- /src/configs/latent-diffusion/cin-ldm-vq-f8.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 1.0e-06 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.0015 6 | linear_end: 0.0195 7 | num_timesteps_cond: 1 8 | log_every_t: 200 9 | timesteps: 1000 10 | first_stage_key: image 11 | cond_stage_key: class_label 12 | image_size: 32 13 | channels: 4 14 | cond_stage_trainable: true 15 | conditioning_key: crossattn 16 | monitor: val/loss_simple_ema 17 | unet_config: 18 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 19 | params: 20 | image_size: 32 21 | in_channels: 4 22 | out_channels: 4 23 | model_channels: 256 24 | attention_resolutions: 25 | #note: this isn\t actually the resolution but 26 | # the downsampling factor, i.e. 
this corresnponds to 27 | # attention on spatial resolution 8,16,32, as the 28 | # spatial reolution of the latents is 32 for f8 29 | - 4 30 | - 2 31 | - 1 32 | num_res_blocks: 2 33 | channel_mult: 34 | - 1 35 | - 2 36 | - 4 37 | num_head_channels: 32 38 | use_spatial_transformer: true 39 | transformer_depth: 1 40 | context_dim: 512 41 | first_stage_config: 42 | target: ldm.models.autoencoder.VQModelInterface 43 | params: 44 | embed_dim: 4 45 | n_embed: 16384 46 | ckpt_path: configs/first_stage_models/vq-f8/model.yaml 47 | ddconfig: 48 | double_z: false 49 | z_channels: 4 50 | resolution: 256 51 | in_channels: 3 52 | out_ch: 3 53 | ch: 128 54 | ch_mult: 55 | - 1 56 | - 2 57 | - 2 58 | - 4 59 | num_res_blocks: 2 60 | attn_resolutions: 61 | - 32 62 | dropout: 0.0 63 | lossconfig: 64 | target: torch.nn.Identity 65 | cond_stage_config: 66 | target: ldm.modules.encoders.modules.ClassEmbedder 67 | params: 68 | embed_dim: 512 69 | key: class_label 70 | data: 71 | target: main.DataModuleFromConfig 72 | params: 73 | batch_size: 64 74 | num_workers: 12 75 | wrap: false 76 | train: 77 | target: ldm.data.imagenet.ImageNetTrain 78 | params: 79 | config: 80 | size: 256 81 | validation: 82 | target: ldm.data.imagenet.ImageNetValidation 83 | params: 84 | config: 85 | size: 256 86 | 87 | 88 | lightning: 89 | callbacks: 90 | image_logger: 91 | target: main.ImageLogger 92 | params: 93 | batch_frequency: 5000 94 | max_images: 8 95 | increase_log_steps: False 96 | 97 | trainer: 98 | benchmark: True -------------------------------------------------------------------------------- /src/configs/latent-diffusion/cin256-v2.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 0.0001 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.0015 6 | linear_end: 0.0195 7 | num_timesteps_cond: 1 8 | log_every_t: 200 9 | timesteps: 1000 10 | first_stage_key: image 11 | cond_stage_key: class_label 12 | image_size: 64 13 | channels: 3 14 | cond_stage_trainable: true 15 | conditioning_key: crossattn 16 | monitor: val/loss 17 | use_ema: False 18 | 19 | unet_config: 20 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 21 | params: 22 | image_size: 64 23 | in_channels: 3 24 | out_channels: 3 25 | model_channels: 192 26 | attention_resolutions: 27 | - 8 28 | - 4 29 | - 2 30 | num_res_blocks: 2 31 | channel_mult: 32 | - 1 33 | - 2 34 | - 3 35 | - 5 36 | num_heads: 1 37 | use_spatial_transformer: true 38 | transformer_depth: 1 39 | context_dim: 512 40 | 41 | first_stage_config: 42 | target: ldm.models.autoencoder.VQModelInterface 43 | params: 44 | embed_dim: 3 45 | n_embed: 8192 46 | ddconfig: 47 | double_z: false 48 | z_channels: 3 49 | resolution: 256 50 | in_channels: 3 51 | out_ch: 3 52 | ch: 128 53 | ch_mult: 54 | - 1 55 | - 2 56 | - 4 57 | num_res_blocks: 2 58 | attn_resolutions: [] 59 | dropout: 0.0 60 | lossconfig: 61 | target: torch.nn.Identity 62 | 63 | cond_stage_config: 64 | target: ldm.modules.encoders.modules.ClassEmbedder 65 | params: 66 | n_classes: 1001 67 | embed_dim: 512 68 | key: class_label 69 | -------------------------------------------------------------------------------- /src/configs/latent-diffusion/ffhq-ldm-vq-4.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 2.0e-06 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.0015 6 | linear_end: 0.0195 7 | num_timesteps_cond: 1 8 | log_every_t: 200 
9 | timesteps: 1000 10 | first_stage_key: image 11 | image_size: 64 12 | channels: 3 13 | monitor: val/loss_simple_ema 14 | unet_config: 15 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 16 | params: 17 | image_size: 64 18 | in_channels: 3 19 | out_channels: 3 20 | model_channels: 224 21 | attention_resolutions: 22 | # note: this isn\t actually the resolution but 23 | # the downsampling factor, i.e. this corresnponds to 24 | # attention on spatial resolution 8,16,32, as the 25 | # spatial reolution of the latents is 64 for f4 26 | - 8 27 | - 4 28 | - 2 29 | num_res_blocks: 2 30 | channel_mult: 31 | - 1 32 | - 2 33 | - 3 34 | - 4 35 | num_head_channels: 32 36 | first_stage_config: 37 | target: ldm.models.autoencoder.VQModelInterface 38 | params: 39 | embed_dim: 3 40 | n_embed: 8192 41 | ckpt_path: configs/first_stage_models/vq-f4/model.yaml 42 | ddconfig: 43 | double_z: false 44 | z_channels: 3 45 | resolution: 256 46 | in_channels: 3 47 | out_ch: 3 48 | ch: 128 49 | ch_mult: 50 | - 1 51 | - 2 52 | - 4 53 | num_res_blocks: 2 54 | attn_resolutions: [] 55 | dropout: 0.0 56 | lossconfig: 57 | target: torch.nn.Identity 58 | cond_stage_config: __is_unconditional__ 59 | data: 60 | target: main.DataModuleFromConfig 61 | params: 62 | batch_size: 42 63 | num_workers: 5 64 | wrap: false 65 | train: 66 | target: taming.data.faceshq.FFHQTrain 67 | params: 68 | size: 256 69 | validation: 70 | target: taming.data.faceshq.FFHQValidation 71 | params: 72 | size: 256 73 | 74 | 75 | lightning: 76 | callbacks: 77 | image_logger: 78 | target: main.ImageLogger 79 | params: 80 | batch_frequency: 5000 81 | max_images: 8 82 | increase_log_steps: False 83 | 84 | trainer: 85 | benchmark: True -------------------------------------------------------------------------------- /src/configs/latent-diffusion/lsun_bedrooms-ldm-vq-4.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 2.0e-06 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.0015 6 | linear_end: 0.0195 7 | num_timesteps_cond: 1 8 | log_every_t: 200 9 | timesteps: 1000 10 | first_stage_key: image 11 | image_size: 64 12 | channels: 3 13 | monitor: val/loss_simple_ema 14 | unet_config: 15 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 16 | params: 17 | image_size: 64 18 | in_channels: 3 19 | out_channels: 3 20 | model_channels: 224 21 | attention_resolutions: 22 | # note: this isn\t actually the resolution but 23 | # the downsampling factor, i.e. 
this corresnponds to 24 | # attention on spatial resolution 8,16,32, as the 25 | # spatial reolution of the latents is 64 for f4 26 | - 8 27 | - 4 28 | - 2 29 | num_res_blocks: 2 30 | channel_mult: 31 | - 1 32 | - 2 33 | - 3 34 | - 4 35 | num_head_channels: 32 36 | first_stage_config: 37 | target: ldm.models.autoencoder.VQModelInterface 38 | params: 39 | ckpt_path: configs/first_stage_models/vq-f4/model.yaml 40 | embed_dim: 3 41 | n_embed: 8192 42 | ddconfig: 43 | double_z: false 44 | z_channels: 3 45 | resolution: 256 46 | in_channels: 3 47 | out_ch: 3 48 | ch: 128 49 | ch_mult: 50 | - 1 51 | - 2 52 | - 4 53 | num_res_blocks: 2 54 | attn_resolutions: [] 55 | dropout: 0.0 56 | lossconfig: 57 | target: torch.nn.Identity 58 | cond_stage_config: __is_unconditional__ 59 | data: 60 | target: main.DataModuleFromConfig 61 | params: 62 | batch_size: 48 63 | num_workers: 5 64 | wrap: false 65 | train: 66 | target: ldm.data.lsun.LSUNBedroomsTrain 67 | params: 68 | size: 256 69 | validation: 70 | target: ldm.data.lsun.LSUNBedroomsValidation 71 | params: 72 | size: 256 73 | 74 | 75 | lightning: 76 | callbacks: 77 | image_logger: 78 | target: main.ImageLogger 79 | params: 80 | batch_frequency: 5000 81 | max_images: 8 82 | increase_log_steps: False 83 | 84 | trainer: 85 | benchmark: True -------------------------------------------------------------------------------- /src/configs/latent-diffusion/lsun_churches-ldm-kl-8.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 5.0e-5 # set to target_lr by starting main.py with '--scale_lr False' 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.0015 6 | linear_end: 0.0155 7 | num_timesteps_cond: 1 8 | log_every_t: 200 9 | timesteps: 1000 10 | loss_type: l1 11 | first_stage_key: "image" 12 | cond_stage_key: "image" 13 | image_size: 32 14 | channels: 4 15 | cond_stage_trainable: False 16 | concat_mode: False 17 | scale_by_std: True 18 | monitor: 'val/loss_simple_ema' 19 | 20 | scheduler_config: # 10000 warmup steps 21 | target: ldm.lr_scheduler.LambdaLinearScheduler 22 | params: 23 | warm_up_steps: [10000] 24 | cycle_lengths: [10000000000000] 25 | f_start: [1.e-6] 26 | f_max: [1.] 27 | f_min: [ 1.] 
28 | 29 | unet_config: 30 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 31 | params: 32 | image_size: 32 33 | in_channels: 4 34 | out_channels: 4 35 | model_channels: 192 36 | attention_resolutions: [ 1, 2, 4, 8 ] # 32, 16, 8, 4 37 | num_res_blocks: 2 38 | channel_mult: [ 1,2,2,4,4 ] # 32, 16, 8, 4, 2 39 | num_heads: 8 40 | use_scale_shift_norm: True 41 | resblock_updown: True 42 | 43 | first_stage_config: 44 | target: ldm.models.autoencoder.AutoencoderKL 45 | params: 46 | embed_dim: 4 47 | monitor: "val/rec_loss" 48 | ckpt_path: "models/first_stage_models/kl-f8/model.ckpt" 49 | ddconfig: 50 | double_z: True 51 | z_channels: 4 52 | resolution: 256 53 | in_channels: 3 54 | out_ch: 3 55 | ch: 128 56 | ch_mult: [ 1,2,4,4 ] # num_down = len(ch_mult)-1 57 | num_res_blocks: 2 58 | attn_resolutions: [ ] 59 | dropout: 0.0 60 | lossconfig: 61 | target: torch.nn.Identity 62 | 63 | cond_stage_config: "__is_unconditional__" 64 | 65 | data: 66 | target: main.DataModuleFromConfig 67 | params: 68 | batch_size: 96 69 | num_workers: 5 70 | wrap: False 71 | train: 72 | target: ldm.data.lsun.LSUNChurchesTrain 73 | params: 74 | size: 256 75 | validation: 76 | target: ldm.data.lsun.LSUNChurchesValidation 77 | params: 78 | size: 256 79 | 80 | lightning: 81 | callbacks: 82 | image_logger: 83 | target: main.ImageLogger 84 | params: 85 | batch_frequency: 5000 86 | max_images: 8 87 | increase_log_steps: False 88 | 89 | 90 | trainer: 91 | benchmark: True -------------------------------------------------------------------------------- /src/configs/latent-diffusion/txt2img-1p4B-eval.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 5.0e-05 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.00085 6 | linear_end: 0.012 7 | num_timesteps_cond: 1 8 | log_every_t: 200 9 | timesteps: 1000 10 | first_stage_key: image 11 | cond_stage_key: caption 12 | image_size: 32 13 | channels: 4 14 | cond_stage_trainable: true 15 | conditioning_key: crossattn 16 | monitor: val/loss_simple_ema 17 | scale_factor: 0.18215 18 | use_ema: False 19 | 20 | unet_config: 21 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 22 | params: 23 | image_size: 32 24 | in_channels: 4 25 | out_channels: 4 26 | model_channels: 320 27 | attention_resolutions: 28 | - 4 29 | - 2 30 | - 1 31 | num_res_blocks: 2 32 | channel_mult: 33 | - 1 34 | - 2 35 | - 4 36 | - 4 37 | num_heads: 8 38 | use_spatial_transformer: true 39 | transformer_depth: 1 40 | context_dim: 1280 41 | use_checkpoint: true 42 | legacy: False 43 | 44 | first_stage_config: 45 | target: ldm.models.autoencoder.AutoencoderKL 46 | params: 47 | embed_dim: 4 48 | monitor: val/rec_loss 49 | ddconfig: 50 | double_z: true 51 | z_channels: 4 52 | resolution: 256 53 | in_channels: 3 54 | out_ch: 3 55 | ch: 128 56 | ch_mult: 57 | - 1 58 | - 2 59 | - 4 60 | - 4 61 | num_res_blocks: 2 62 | attn_resolutions: [] 63 | dropout: 0.0 64 | lossconfig: 65 | target: torch.nn.Identity 66 | 67 | cond_stage_config: 68 | target: ldm.modules.encoders.modules.BERTEmbedder 69 | params: 70 | n_embed: 1280 71 | n_layer: 32 72 | -------------------------------------------------------------------------------- /src/configs/retrieval-augmented-diffusion/768x768.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 0.0001 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.0015 6 | linear_end: 0.015 7 | 
num_timesteps_cond: 1 8 | log_every_t: 200 9 | timesteps: 1000 10 | first_stage_key: jpg 11 | cond_stage_key: nix 12 | image_size: 48 13 | channels: 16 14 | cond_stage_trainable: false 15 | conditioning_key: crossattn 16 | monitor: val/loss_simple_ema 17 | scale_by_std: false 18 | scale_factor: 0.22765929 19 | unet_config: 20 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 21 | params: 22 | image_size: 48 23 | in_channels: 16 24 | out_channels: 16 25 | model_channels: 448 26 | attention_resolutions: 27 | - 4 28 | - 2 29 | - 1 30 | num_res_blocks: 2 31 | channel_mult: 32 | - 1 33 | - 2 34 | - 3 35 | - 4 36 | use_scale_shift_norm: false 37 | resblock_updown: false 38 | num_head_channels: 32 39 | use_spatial_transformer: true 40 | transformer_depth: 1 41 | context_dim: 768 42 | use_checkpoint: true 43 | first_stage_config: 44 | target: ldm.models.autoencoder.AutoencoderKL 45 | params: 46 | monitor: val/rec_loss 47 | embed_dim: 16 48 | ddconfig: 49 | double_z: true 50 | z_channels: 16 51 | resolution: 256 52 | in_channels: 3 53 | out_ch: 3 54 | ch: 128 55 | ch_mult: 56 | - 1 57 | - 1 58 | - 2 59 | - 2 60 | - 4 61 | num_res_blocks: 2 62 | attn_resolutions: 63 | - 16 64 | dropout: 0.0 65 | lossconfig: 66 | target: torch.nn.Identity 67 | cond_stage_config: 68 | target: torch.nn.Identity -------------------------------------------------------------------------------- /src/configs/stable-diffusion/v1-inference.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 1.0e-04 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.00085 6 | linear_end: 0.0120 7 | num_timesteps_cond: 1 8 | log_every_t: 200 9 | timesteps: 1000 10 | first_stage_key: "jpg" 11 | cond_stage_key: "txt" 12 | image_size: 64 13 | channels: 4 14 | cond_stage_trainable: false # Note: different from the one we trained before 15 | conditioning_key: crossattn 16 | monitor: val/loss_simple_ema 17 | scale_factor: 0.18215 18 | use_ema: False 19 | 20 | scheduler_config: # 10000 warmup steps 21 | target: ldm.lr_scheduler.LambdaLinearScheduler 22 | params: 23 | warm_up_steps: [ 10000 ] 24 | cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases 25 | f_start: [ 1.e-6 ] 26 | f_max: [ 1. ] 27 | f_min: [ 1. 
] 28 | 29 | unet_config: 30 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 31 | params: 32 | image_size: 32 # unused 33 | in_channels: 4 34 | out_channels: 4 35 | model_channels: 320 36 | attention_resolutions: [ 4, 2, 1 ] 37 | num_res_blocks: 2 38 | channel_mult: [ 1, 2, 4, 4 ] 39 | num_heads: 8 40 | use_spatial_transformer: True 41 | transformer_depth: 1 42 | context_dim: 768 43 | use_checkpoint: True 44 | legacy: False 45 | 46 | first_stage_config: 47 | target: ldm.models.autoencoder.AutoencoderKL 48 | params: 49 | embed_dim: 4 50 | monitor: val/rec_loss 51 | ddconfig: 52 | double_z: true 53 | z_channels: 4 54 | resolution: 256 55 | in_channels: 3 56 | out_ch: 3 57 | ch: 128 58 | ch_mult: 59 | - 1 60 | - 2 61 | - 4 62 | - 4 63 | num_res_blocks: 2 64 | attn_resolutions: [] 65 | dropout: 0.0 66 | lossconfig: 67 | target: torch.nn.Identity 68 | 69 | cond_stage_config: 70 | target: ldm.modules.encoders.modules.FrozenCLIPEmbedder 71 | -------------------------------------------------------------------------------- /src/ldm/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DreamEditBenchTeam/DreamEdit/14d21b0a3eb6305c1378080ccd8361db0a8adcc0/src/ldm/__init__.py -------------------------------------------------------------------------------- /src/ldm/data/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DreamEditBenchTeam/DreamEdit/14d21b0a3eb6305c1378080ccd8361db0a8adcc0/src/ldm/data/__init__.py -------------------------------------------------------------------------------- /src/ldm/data/base.py: -------------------------------------------------------------------------------- 1 | from abc import abstractmethod 2 | from torch.utils.data import Dataset, ConcatDataset, ChainDataset, IterableDataset 3 | 4 | 5 | class Txt2ImgIterableBaseDataset(IterableDataset): 6 | ''' 7 | Define an interface to make the IterableDatasets for text2img data chainable 8 | ''' 9 | def __init__(self, num_records=0, valid_ids=None, size=256): 10 | super().__init__() 11 | self.num_records = num_records 12 | self.valid_ids = valid_ids 13 | self.sample_ids = valid_ids 14 | self.size = size 15 | 16 | print(f'{self.__class__.__name__} dataset contains {self.__len__()} examples.') 17 | 18 | def __len__(self): 19 | return self.num_records 20 | 21 | @abstractmethod 22 | def __iter__(self): 23 | pass -------------------------------------------------------------------------------- /src/ldm/data/imagenet.py: -------------------------------------------------------------------------------- 1 | import os, yaml, pickle, shutil, tarfile, glob 2 | import cv2 3 | import albumentations 4 | import PIL 5 | import numpy as np 6 | import torchvision.transforms.functional as TF 7 | from omegaconf import OmegaConf 8 | from functools import partial 9 | from PIL import Image 10 | from tqdm import tqdm 11 | from torch.utils.data import Dataset, Subset 12 | 13 | import taming.data.utils as tdu 14 | from taming.data.imagenet import str_to_indices, give_synsets_from_indices, download, retrieve 15 | from taming.data.imagenet import ImagePaths 16 | 17 | from ldm.modules.image_degradation import degradation_fn_bsr, degradation_fn_bsr_light 18 | 19 | 20 | def synset2idx(path_to_yaml="data/index_synset.yaml"): 21 | with open(path_to_yaml) as f: 22 | di2s = yaml.load(f) 23 | return dict((v,k) for k,v in di2s.items()) 24 | 25 | 26 | class ImageNetBase(Dataset): 27 | def __init__(self, 
config=None): 28 | self.config = config or OmegaConf.create() 29 | if not type(self.config)==dict: 30 | self.config = OmegaConf.to_container(self.config) 31 | self.keep_orig_class_label = self.config.get("keep_orig_class_label", False) 32 | self.process_images = True # if False we skip loading & processing images and self.data contains filepaths 33 | self._prepare() 34 | self._prepare_synset_to_human() 35 | self._prepare_idx_to_synset() 36 | self._prepare_human_to_integer_label() 37 | self._load() 38 | 39 | def __len__(self): 40 | return len(self.data) 41 | 42 | def __getitem__(self, i): 43 | return self.data[i] 44 | 45 | def _prepare(self): 46 | raise NotImplementedError() 47 | 48 | def _filter_relpaths(self, relpaths): 49 | ignore = set([ 50 | "n06596364_9591.JPEG", 51 | ]) 52 | relpaths = [rpath for rpath in relpaths if not rpath.split("/")[-1] in ignore] 53 | if "sub_indices" in self.config: 54 | indices = str_to_indices(self.config["sub_indices"]) 55 | synsets = give_synsets_from_indices(indices, path_to_yaml=self.idx2syn) # returns a list of strings 56 | self.synset2idx = synset2idx(path_to_yaml=self.idx2syn) 57 | files = [] 58 | for rpath in relpaths: 59 | syn = rpath.split("/")[0] 60 | if syn in synsets: 61 | files.append(rpath) 62 | return files 63 | else: 64 | return relpaths 65 | 66 | def _prepare_synset_to_human(self): 67 | SIZE = 2655750 68 | URL = "https://heibox.uni-heidelberg.de/f/9f28e956cd304264bb82/?dl=1" 69 | self.human_dict = os.path.join(self.root, "synset_human.txt") 70 | if (not os.path.exists(self.human_dict) or 71 | not os.path.getsize(self.human_dict)==SIZE): 72 | download(URL, self.human_dict) 73 | 74 | def _prepare_idx_to_synset(self): 75 | URL = "https://heibox.uni-heidelberg.de/f/d835d5b6ceda4d3aa910/?dl=1" 76 | self.idx2syn = os.path.join(self.root, "index_synset.yaml") 77 | if (not os.path.exists(self.idx2syn)): 78 | download(URL, self.idx2syn) 79 | 80 | def _prepare_human_to_integer_label(self): 81 | URL = "https://heibox.uni-heidelberg.de/f/2362b797d5be43b883f6/?dl=1" 82 | self.human2integer = os.path.join(self.root, "imagenet1000_clsidx_to_labels.txt") 83 | if (not os.path.exists(self.human2integer)): 84 | download(URL, self.human2integer) 85 | with open(self.human2integer, "r") as f: 86 | lines = f.read().splitlines() 87 | assert len(lines) == 1000 88 | self.human2integer_dict = dict() 89 | for line in lines: 90 | value, key = line.split(":") 91 | self.human2integer_dict[key] = int(value) 92 | 93 | def _load(self): 94 | with open(self.txt_filelist, "r") as f: 95 | self.relpaths = f.read().splitlines() 96 | l1 = len(self.relpaths) 97 | self.relpaths = self._filter_relpaths(self.relpaths) 98 | print("Removed {} files from filelist during filtering.".format(l1 - len(self.relpaths))) 99 | 100 | self.synsets = [p.split("/")[0] for p in self.relpaths] 101 | self.abspaths = [os.path.join(self.datadir, p) for p in self.relpaths] 102 | 103 | unique_synsets = np.unique(self.synsets) 104 | class_dict = dict((synset, i) for i, synset in enumerate(unique_synsets)) 105 | if not self.keep_orig_class_label: 106 | self.class_labels = [class_dict[s] for s in self.synsets] 107 | else: 108 | self.class_labels = [self.synset2idx[s] for s in self.synsets] 109 | 110 | with open(self.human_dict, "r") as f: 111 | human_dict = f.read().splitlines() 112 | human_dict = dict(line.split(maxsplit=1) for line in human_dict) 113 | 114 | self.human_labels = [human_dict[s] for s in self.synsets] 115 | 116 | labels = { 117 | "relpath": np.array(self.relpaths), 118 | "synsets": 
np.array(self.synsets), 119 | "class_label": np.array(self.class_labels), 120 | "human_label": np.array(self.human_labels), 121 | } 122 | 123 | if self.process_images: 124 | self.size = retrieve(self.config, "size", default=256) 125 | self.data = ImagePaths(self.abspaths, 126 | labels=labels, 127 | size=self.size, 128 | random_crop=self.random_crop, 129 | ) 130 | else: 131 | self.data = self.abspaths 132 | 133 | 134 | class ImageNetTrain(ImageNetBase): 135 | NAME = "ILSVRC2012_train" 136 | URL = "http://www.image-net.org/challenges/LSVRC/2012/" 137 | AT_HASH = "a306397ccf9c2ead27155983c254227c0fd938e2" 138 | FILES = [ 139 | "ILSVRC2012_img_train.tar", 140 | ] 141 | SIZES = [ 142 | 147897477120, 143 | ] 144 | 145 | def __init__(self, process_images=True, data_root=None, **kwargs): 146 | self.process_images = process_images 147 | self.data_root = data_root 148 | super().__init__(**kwargs) 149 | 150 | def _prepare(self): 151 | if self.data_root: 152 | self.root = os.path.join(self.data_root, self.NAME) 153 | else: 154 | cachedir = os.environ.get("XDG_CACHE_HOME", os.path.expanduser("~/.cache")) 155 | self.root = os.path.join(cachedir, "autoencoders/data", self.NAME) 156 | 157 | self.datadir = os.path.join(self.root, "data") 158 | self.txt_filelist = os.path.join(self.root, "filelist.txt") 159 | self.expected_length = 1281167 160 | self.random_crop = retrieve(self.config, "ImageNetTrain/random_crop", 161 | default=True) 162 | if not tdu.is_prepared(self.root): 163 | # prep 164 | print("Preparing dataset {} in {}".format(self.NAME, self.root)) 165 | 166 | datadir = self.datadir 167 | if not os.path.exists(datadir): 168 | path = os.path.join(self.root, self.FILES[0]) 169 | if not os.path.exists(path) or not os.path.getsize(path)==self.SIZES[0]: 170 | import academictorrents as at 171 | atpath = at.get(self.AT_HASH, datastore=self.root) 172 | assert atpath == path 173 | 174 | print("Extracting {} to {}".format(path, datadir)) 175 | os.makedirs(datadir, exist_ok=True) 176 | with tarfile.open(path, "r:") as tar: 177 | tar.extractall(path=datadir) 178 | 179 | print("Extracting sub-tars.") 180 | subpaths = sorted(glob.glob(os.path.join(datadir, "*.tar"))) 181 | for subpath in tqdm(subpaths): 182 | subdir = subpath[:-len(".tar")] 183 | os.makedirs(subdir, exist_ok=True) 184 | with tarfile.open(subpath, "r:") as tar: 185 | tar.extractall(path=subdir) 186 | 187 | filelist = glob.glob(os.path.join(datadir, "**", "*.JPEG")) 188 | filelist = [os.path.relpath(p, start=datadir) for p in filelist] 189 | filelist = sorted(filelist) 190 | filelist = "\n".join(filelist)+"\n" 191 | with open(self.txt_filelist, "w") as f: 192 | f.write(filelist) 193 | 194 | tdu.mark_prepared(self.root) 195 | 196 | 197 | class ImageNetValidation(ImageNetBase): 198 | NAME = "ILSVRC2012_validation" 199 | URL = "http://www.image-net.org/challenges/LSVRC/2012/" 200 | AT_HASH = "5d6d0df7ed81efd49ca99ea4737e0ae5e3a5f2e5" 201 | VS_URL = "https://heibox.uni-heidelberg.de/f/3e0f6e9c624e45f2bd73/?dl=1" 202 | FILES = [ 203 | "ILSVRC2012_img_val.tar", 204 | "validation_synset.txt", 205 | ] 206 | SIZES = [ 207 | 6744924160, 208 | 1950000, 209 | ] 210 | 211 | def __init__(self, process_images=True, data_root=None, **kwargs): 212 | self.data_root = data_root 213 | self.process_images = process_images 214 | super().__init__(**kwargs) 215 | 216 | def _prepare(self): 217 | if self.data_root: 218 | self.root = os.path.join(self.data_root, self.NAME) 219 | else: 220 | cachedir = os.environ.get("XDG_CACHE_HOME", os.path.expanduser("~/.cache")) 221 | 
self.root = os.path.join(cachedir, "autoencoders/data", self.NAME) 222 | self.datadir = os.path.join(self.root, "data") 223 | self.txt_filelist = os.path.join(self.root, "filelist.txt") 224 | self.expected_length = 50000 225 | self.random_crop = retrieve(self.config, "ImageNetValidation/random_crop", 226 | default=False) 227 | if not tdu.is_prepared(self.root): 228 | # prep 229 | print("Preparing dataset {} in {}".format(self.NAME, self.root)) 230 | 231 | datadir = self.datadir 232 | if not os.path.exists(datadir): 233 | path = os.path.join(self.root, self.FILES[0]) 234 | if not os.path.exists(path) or not os.path.getsize(path)==self.SIZES[0]: 235 | import academictorrents as at 236 | atpath = at.get(self.AT_HASH, datastore=self.root) 237 | assert atpath == path 238 | 239 | print("Extracting {} to {}".format(path, datadir)) 240 | os.makedirs(datadir, exist_ok=True) 241 | with tarfile.open(path, "r:") as tar: 242 | tar.extractall(path=datadir) 243 | 244 | vspath = os.path.join(self.root, self.FILES[1]) 245 | if not os.path.exists(vspath) or not os.path.getsize(vspath)==self.SIZES[1]: 246 | download(self.VS_URL, vspath) 247 | 248 | with open(vspath, "r") as f: 249 | synset_dict = f.read().splitlines() 250 | synset_dict = dict(line.split() for line in synset_dict) 251 | 252 | print("Reorganizing into synset folders") 253 | synsets = np.unique(list(synset_dict.values())) 254 | for s in synsets: 255 | os.makedirs(os.path.join(datadir, s), exist_ok=True) 256 | for k, v in synset_dict.items(): 257 | src = os.path.join(datadir, k) 258 | dst = os.path.join(datadir, v) 259 | shutil.move(src, dst) 260 | 261 | filelist = glob.glob(os.path.join(datadir, "**", "*.JPEG")) 262 | filelist = [os.path.relpath(p, start=datadir) for p in filelist] 263 | filelist = sorted(filelist) 264 | filelist = "\n".join(filelist)+"\n" 265 | with open(self.txt_filelist, "w") as f: 266 | f.write(filelist) 267 | 268 | tdu.mark_prepared(self.root) 269 | 270 | 271 | 272 | class ImageNetSR(Dataset): 273 | def __init__(self, size=None, 274 | degradation=None, downscale_f=4, min_crop_f=0.5, max_crop_f=1., 275 | random_crop=True): 276 | """ 277 | Imagenet Superresolution Dataloader 278 | Performs following ops in order: 279 | 1. crops a crop of size s from image either as random or center crop 280 | 2. resizes crop to size with cv2.area_interpolation 281 | 3. degrades resized crop with degradation_fn 282 | 283 | :param size: resizing to size after cropping 284 | :param degradation: degradation_fn, e.g. cv_bicubic or bsrgan_light 285 | :param downscale_f: Low Resolution Downsample factor 286 | :param min_crop_f: determines crop size s, 287 | where s = c * min_img_side_len with c sampled from interval (min_crop_f, max_crop_f) 288 | :param max_crop_f: "" 289 | :param data_root: 290 | :param random_crop: 291 | """ 292 | self.base = self.get_base() 293 | assert size 294 | assert (size / downscale_f).is_integer() 295 | self.size = size 296 | self.LR_size = int(size / downscale_f) 297 | self.min_crop_f = min_crop_f 298 | self.max_crop_f = max_crop_f 299 | assert(max_crop_f <= 1.) 
300 | self.center_crop = not random_crop 301 | 302 | self.image_rescaler = albumentations.SmallestMaxSize(max_size=size, interpolation=cv2.INTER_AREA) 303 | 304 | self.pil_interpolation = False # gets reset later if incase interp_op is from pillow 305 | 306 | if degradation == "bsrgan": 307 | self.degradation_process = partial(degradation_fn_bsr, sf=downscale_f) 308 | 309 | elif degradation == "bsrgan_light": 310 | self.degradation_process = partial(degradation_fn_bsr_light, sf=downscale_f) 311 | 312 | else: 313 | interpolation_fn = { 314 | "cv_nearest": cv2.INTER_NEAREST, 315 | "cv_bilinear": cv2.INTER_LINEAR, 316 | "cv_bicubic": cv2.INTER_CUBIC, 317 | "cv_area": cv2.INTER_AREA, 318 | "cv_lanczos": cv2.INTER_LANCZOS4, 319 | "pil_nearest": PIL.Image.NEAREST, 320 | "pil_bilinear": PIL.Image.BILINEAR, 321 | "pil_bicubic": PIL.Image.BICUBIC, 322 | "pil_box": PIL.Image.BOX, 323 | "pil_hamming": PIL.Image.HAMMING, 324 | "pil_lanczos": PIL.Image.LANCZOS, 325 | }[degradation] 326 | 327 | self.pil_interpolation = degradation.startswith("pil_") 328 | 329 | if self.pil_interpolation: 330 | self.degradation_process = partial(TF.resize, size=self.LR_size, interpolation=interpolation_fn) 331 | 332 | else: 333 | self.degradation_process = albumentations.SmallestMaxSize(max_size=self.LR_size, 334 | interpolation=interpolation_fn) 335 | 336 | def __len__(self): 337 | return len(self.base) 338 | 339 | def __getitem__(self, i): 340 | example = self.base[i] 341 | image = Image.open(example["file_path_"]) 342 | 343 | if not image.mode == "RGB": 344 | image = image.convert("RGB") 345 | 346 | image = np.array(image).astype(np.uint8) 347 | 348 | min_side_len = min(image.shape[:2]) 349 | crop_side_len = min_side_len * np.random.uniform(self.min_crop_f, self.max_crop_f, size=None) 350 | crop_side_len = int(crop_side_len) 351 | 352 | if self.center_crop: 353 | self.cropper = albumentations.CenterCrop(height=crop_side_len, width=crop_side_len) 354 | 355 | else: 356 | self.cropper = albumentations.RandomCrop(height=crop_side_len, width=crop_side_len) 357 | 358 | image = self.cropper(image=image)["image"] 359 | image = self.image_rescaler(image=image)["image"] 360 | 361 | if self.pil_interpolation: 362 | image_pil = PIL.Image.fromarray(image) 363 | LR_image = self.degradation_process(image_pil) 364 | LR_image = np.array(LR_image).astype(np.uint8) 365 | 366 | else: 367 | LR_image = self.degradation_process(image=image)["image"] 368 | 369 | example["image"] = (image/127.5 - 1.0).astype(np.float32) 370 | example["LR_image"] = (LR_image/127.5 - 1.0).astype(np.float32) 371 | 372 | return example 373 | 374 | 375 | class ImageNetSRTrain(ImageNetSR): 376 | def __init__(self, **kwargs): 377 | super().__init__(**kwargs) 378 | 379 | def get_base(self): 380 | with open("data/imagenet_train_hr_indices.p", "rb") as f: 381 | indices = pickle.load(f) 382 | dset = ImageNetTrain(process_images=False,) 383 | return Subset(dset, indices) 384 | 385 | 386 | class ImageNetSRValidation(ImageNetSR): 387 | def __init__(self, **kwargs): 388 | super().__init__(**kwargs) 389 | 390 | def get_base(self): 391 | with open("data/imagenet_val_hr_indices.p", "rb") as f: 392 | indices = pickle.load(f) 393 | dset = ImageNetValidation(process_images=False,) 394 | return Subset(dset, indices) 395 | -------------------------------------------------------------------------------- /src/ldm/data/lsun.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import PIL 4 | from PIL import Image 
5 | from torch.utils.data import Dataset 6 | from torchvision import transforms 7 | 8 | 9 | class LSUNBase(Dataset): 10 | def __init__(self, 11 | txt_file, 12 | data_root, 13 | size=None, 14 | interpolation="bicubic", 15 | flip_p=0.5 16 | ): 17 | self.data_paths = txt_file 18 | self.data_root = data_root 19 | with open(self.data_paths, "r") as f: 20 | self.image_paths = f.read().splitlines() 21 | self._length = len(self.image_paths) 22 | self.labels = { 23 | "relative_file_path_": [l for l in self.image_paths], 24 | "file_path_": [os.path.join(self.data_root, l) 25 | for l in self.image_paths], 26 | } 27 | 28 | self.size = size 29 | self.interpolation = {"linear": PIL.Image.LINEAR, 30 | "bilinear": PIL.Image.BILINEAR, 31 | "bicubic": PIL.Image.BICUBIC, 32 | "lanczos": PIL.Image.LANCZOS, 33 | }[interpolation] 34 | self.flip = transforms.RandomHorizontalFlip(p=flip_p) 35 | 36 | def __len__(self): 37 | return self._length 38 | 39 | def __getitem__(self, i): 40 | example = dict((k, self.labels[k][i]) for k in self.labels) 41 | image = Image.open(example["file_path_"]) 42 | if not image.mode == "RGB": 43 | image = image.convert("RGB") 44 | 45 | # default to score-sde preprocessing 46 | img = np.array(image).astype(np.uint8) 47 | crop = min(img.shape[0], img.shape[1]) 48 | h, w, = img.shape[0], img.shape[1] 49 | img = img[(h - crop) // 2:(h + crop) // 2, 50 | (w - crop) // 2:(w + crop) // 2] 51 | 52 | image = Image.fromarray(img) 53 | if self.size is not None: 54 | image = image.resize((self.size, self.size), resample=self.interpolation) 55 | 56 | image = self.flip(image) 57 | image = np.array(image).astype(np.uint8) 58 | example["image"] = (image / 127.5 - 1.0).astype(np.float32) 59 | return example 60 | 61 | 62 | class LSUNChurchesTrain(LSUNBase): 63 | def __init__(self, **kwargs): 64 | super().__init__(txt_file="data/lsun/church_outdoor_train.txt", data_root="data/lsun/churches", **kwargs) 65 | 66 | 67 | class LSUNChurchesValidation(LSUNBase): 68 | def __init__(self, flip_p=0., **kwargs): 69 | super().__init__(txt_file="data/lsun/church_outdoor_val.txt", data_root="data/lsun/churches", 70 | flip_p=flip_p, **kwargs) 71 | 72 | 73 | class LSUNBedroomsTrain(LSUNBase): 74 | def __init__(self, **kwargs): 75 | super().__init__(txt_file="data/lsun/bedrooms_train.txt", data_root="data/lsun/bedrooms", **kwargs) 76 | 77 | 78 | class LSUNBedroomsValidation(LSUNBase): 79 | def __init__(self, flip_p=0.0, **kwargs): 80 | super().__init__(txt_file="data/lsun/bedrooms_val.txt", data_root="data/lsun/bedrooms", 81 | flip_p=flip_p, **kwargs) 82 | 83 | 84 | class LSUNCatsTrain(LSUNBase): 85 | def __init__(self, **kwargs): 86 | super().__init__(txt_file="data/lsun/cat_train.txt", data_root="data/lsun/cats", **kwargs) 87 | 88 | 89 | class LSUNCatsValidation(LSUNBase): 90 | def __init__(self, flip_p=0., **kwargs): 91 | super().__init__(txt_file="data/lsun/cat_val.txt", data_root="data/lsun/cats", 92 | flip_p=flip_p, **kwargs) 93 | -------------------------------------------------------------------------------- /src/ldm/lr_scheduler.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | class LambdaWarmUpCosineScheduler: 5 | """ 6 | note: use with a base_lr of 1.0 7 | """ 8 | def __init__(self, warm_up_steps, lr_min, lr_max, lr_start, max_decay_steps, verbosity_interval=0): 9 | self.lr_warm_up_steps = warm_up_steps 10 | self.lr_start = lr_start 11 | self.lr_min = lr_min 12 | self.lr_max = lr_max 13 | self.lr_max_decay_steps = max_decay_steps 14 | 
self.last_lr = 0. 15 | self.verbosity_interval = verbosity_interval 16 | 17 | def schedule(self, n, **kwargs): 18 | if self.verbosity_interval > 0: 19 | if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_lr}") 20 | if n < self.lr_warm_up_steps: 21 | lr = (self.lr_max - self.lr_start) / self.lr_warm_up_steps * n + self.lr_start 22 | self.last_lr = lr 23 | return lr 24 | else: 25 | t = (n - self.lr_warm_up_steps) / (self.lr_max_decay_steps - self.lr_warm_up_steps) 26 | t = min(t, 1.0) 27 | lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * ( 28 | 1 + np.cos(t * np.pi)) 29 | self.last_lr = lr 30 | return lr 31 | 32 | def __call__(self, n, **kwargs): 33 | return self.schedule(n,**kwargs) 34 | 35 | 36 | class LambdaWarmUpCosineScheduler2: 37 | """ 38 | supports repeated iterations, configurable via lists 39 | note: use with a base_lr of 1.0. 40 | """ 41 | def __init__(self, warm_up_steps, f_min, f_max, f_start, cycle_lengths, verbosity_interval=0): 42 | assert len(warm_up_steps) == len(f_min) == len(f_max) == len(f_start) == len(cycle_lengths) 43 | self.lr_warm_up_steps = warm_up_steps 44 | self.f_start = f_start 45 | self.f_min = f_min 46 | self.f_max = f_max 47 | self.cycle_lengths = cycle_lengths 48 | self.cum_cycles = np.cumsum([0] + list(self.cycle_lengths)) 49 | self.last_f = 0. 50 | self.verbosity_interval = verbosity_interval 51 | 52 | def find_in_interval(self, n): 53 | interval = 0 54 | for cl in self.cum_cycles[1:]: 55 | if n <= cl: 56 | return interval 57 | interval += 1 58 | 59 | def schedule(self, n, **kwargs): 60 | cycle = self.find_in_interval(n) 61 | n = n - self.cum_cycles[cycle] 62 | if self.verbosity_interval > 0: 63 | if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, " 64 | f"current cycle {cycle}") 65 | if n < self.lr_warm_up_steps[cycle]: 66 | f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle] 67 | self.last_f = f 68 | return f 69 | else: 70 | t = (n - self.lr_warm_up_steps[cycle]) / (self.cycle_lengths[cycle] - self.lr_warm_up_steps[cycle]) 71 | t = min(t, 1.0) 72 | f = self.f_min[cycle] + 0.5 * (self.f_max[cycle] - self.f_min[cycle]) * ( 73 | 1 + np.cos(t * np.pi)) 74 | self.last_f = f 75 | return f 76 | 77 | def __call__(self, n, **kwargs): 78 | return self.schedule(n, **kwargs) 79 | 80 | 81 | class LambdaLinearScheduler(LambdaWarmUpCosineScheduler2): 82 | 83 | def schedule(self, n, **kwargs): 84 | cycle = self.find_in_interval(n) 85 | n = n - self.cum_cycles[cycle] 86 | if self.verbosity_interval > 0: 87 | if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, " 88 | f"current cycle {cycle}") 89 | 90 | if n < self.lr_warm_up_steps[cycle]: 91 | f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle] 92 | self.last_f = f 93 | return f 94 | else: 95 | f = self.f_min[cycle] + (self.f_max[cycle] - self.f_min[cycle]) * (self.cycle_lengths[cycle] - n) / (self.cycle_lengths[cycle]) 96 | self.last_f = f 97 | return f 98 | 99 | -------------------------------------------------------------------------------- /src/ldm/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DreamEditBenchTeam/DreamEdit/14d21b0a3eb6305c1378080ccd8361db0a8adcc0/src/ldm/models/__init__.py -------------------------------------------------------------------------------- 
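
The `LambdaLinearScheduler` defined in `lr_scheduler.py` above is exactly the class referenced by the `scheduler_config` block of `src/configs/stable-diffusion/v1-inference.yaml`. The snippet below is a minimal standalone sketch (not the repository's training entry point): the model and optimizer are placeholders, the schedule values are copied from that config, and the import assumes the `src` directory is on `PYTHONPATH` so that `ldm.lr_scheduler` resolves as it does elsewhere in the repo. It mirrors the `LambdaLR(optimizer, lr_lambda=scheduler.schedule)` pattern used in `configure_optimizers` of `classifier.py` further below.

```python
# Minimal sketch only: placeholder model/optimizer, schedule values taken from
# the scheduler_config in src/configs/stable-diffusion/v1-inference.yaml.
import torch
from torch.optim.lr_scheduler import LambdaLR

from ldm.lr_scheduler import LambdaLinearScheduler  # assumes src/ is on PYTHONPATH

model = torch.nn.Linear(4, 4)          # placeholder parameters
base_lr = 1.0e-4                       # base_learning_rate in the config
optimizer = torch.optim.AdamW(model.parameters(), lr=base_lr)

# LambdaLR multiplies base_lr by schedule(step): roughly 1e-6 at step 0,
# ramping linearly to 1.0 over the 10k warm-up steps, then staying at
# f_min = f_max = 1.0 for the (deliberately huge) remaining cycle length.
schedule = LambdaLinearScheduler(
    warm_up_steps=[10000],
    f_start=[1.0e-6],
    f_max=[1.0],
    f_min=[1.0],
    cycle_lengths=[10000000000000],
)
lr_scheduler = LambdaLR(optimizer, lr_lambda=schedule.schedule)

for step in range(5):
    optimizer.step()
    lr_scheduler.step()
    print(step, lr_scheduler.get_last_lr())
```
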
/src/ldm/models/diffusion/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DreamEditBenchTeam/DreamEdit/14d21b0a3eb6305c1378080ccd8361db0a8adcc0/src/ldm/models/diffusion/__init__.py -------------------------------------------------------------------------------- /src/ldm/models/diffusion/classifier.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import pytorch_lightning as pl 4 | from omegaconf import OmegaConf 5 | from torch.nn import functional as F 6 | from torch.optim import AdamW 7 | from torch.optim.lr_scheduler import LambdaLR 8 | from copy import deepcopy 9 | from einops import rearrange 10 | from glob import glob 11 | from natsort import natsorted 12 | 13 | from ldm.modules.diffusionmodules.openaimodel import EncoderUNetModel, UNetModel 14 | from ldm.util import log_txt_as_img, default, ismap, instantiate_from_config 15 | 16 | __models__ = { 17 | 'class_label': EncoderUNetModel, 18 | 'segmentation': UNetModel 19 | } 20 | 21 | 22 | def disabled_train(self, mode=True): 23 | """Overwrite model.train with this function to make sure train/eval mode 24 | does not change anymore.""" 25 | return self 26 | 27 | 28 | class NoisyLatentImageClassifier(pl.LightningModule): 29 | 30 | def __init__(self, 31 | diffusion_path, 32 | num_classes, 33 | ckpt_path=None, 34 | pool='attention', 35 | label_key=None, 36 | diffusion_ckpt_path=None, 37 | scheduler_config=None, 38 | weight_decay=1.e-2, 39 | log_steps=10, 40 | monitor='val/loss', 41 | *args, 42 | **kwargs): 43 | super().__init__(*args, **kwargs) 44 | self.num_classes = num_classes 45 | # get latest config of diffusion model 46 | diffusion_config = natsorted(glob(os.path.join(diffusion_path, 'configs', '*-project.yaml')))[-1] 47 | self.diffusion_config = OmegaConf.load(diffusion_config).model 48 | self.diffusion_config.params.ckpt_path = diffusion_ckpt_path 49 | self.load_diffusion() 50 | 51 | self.monitor = monitor 52 | self.numd = self.diffusion_model.first_stage_model.encoder.num_resolutions - 1 53 | self.log_time_interval = self.diffusion_model.num_timesteps // log_steps 54 | self.log_steps = log_steps 55 | 56 | self.label_key = label_key if not hasattr(self.diffusion_model, 'cond_stage_key') \ 57 | else self.diffusion_model.cond_stage_key 58 | 59 | assert self.label_key is not None, 'label_key neither in diffusion model nor in model.params' 60 | 61 | if self.label_key not in __models__: 62 | raise NotImplementedError() 63 | 64 | self.load_classifier(ckpt_path, pool) 65 | 66 | self.scheduler_config = scheduler_config 67 | self.use_scheduler = self.scheduler_config is not None 68 | self.weight_decay = weight_decay 69 | 70 | def init_from_ckpt(self, path, ignore_keys=list(), only_model=False): 71 | sd = torch.load(path, map_location="cpu") 72 | if "state_dict" in list(sd.keys()): 73 | sd = sd["state_dict"] 74 | keys = list(sd.keys()) 75 | for k in keys: 76 | for ik in ignore_keys: 77 | if k.startswith(ik): 78 | print("Deleting key {} from state_dict.".format(k)) 79 | del sd[k] 80 | missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict( 81 | sd, strict=False) 82 | print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys") 83 | if len(missing) > 0: 84 | print(f"Missing Keys: {missing}") 85 | if len(unexpected) > 0: 86 | print(f"Unexpected Keys: {unexpected}") 87 | 88 | def load_diffusion(self): 89 | 
model = instantiate_from_config(self.diffusion_config) 90 | self.diffusion_model = model.eval() 91 | self.diffusion_model.train = disabled_train 92 | for param in self.diffusion_model.parameters(): 93 | param.requires_grad = False 94 | 95 | def load_classifier(self, ckpt_path, pool): 96 | model_config = deepcopy(self.diffusion_config.params.unet_config.params) 97 | model_config.in_channels = self.diffusion_config.params.unet_config.params.out_channels 98 | model_config.out_channels = self.num_classes 99 | if self.label_key == 'class_label': 100 | model_config.pool = pool 101 | 102 | self.model = __models__[self.label_key](**model_config) 103 | if ckpt_path is not None: 104 | print('#####################################################################') 105 | print(f'load from ckpt "{ckpt_path}"') 106 | print('#####################################################################') 107 | self.init_from_ckpt(ckpt_path) 108 | 109 | @torch.no_grad() 110 | def get_x_noisy(self, x, t, noise=None): 111 | noise = default(noise, lambda: torch.randn_like(x)) 112 | continuous_sqrt_alpha_cumprod = None 113 | if self.diffusion_model.use_continuous_noise: 114 | continuous_sqrt_alpha_cumprod = self.diffusion_model.sample_continuous_noise_level(x.shape[0], t + 1) 115 | # todo: make sure t+1 is correct here 116 | 117 | return self.diffusion_model.q_sample(x_start=x, t=t, noise=noise, 118 | continuous_sqrt_alpha_cumprod=continuous_sqrt_alpha_cumprod) 119 | 120 | def forward(self, x_noisy, t, *args, **kwargs): 121 | return self.model(x_noisy, t) 122 | 123 | @torch.no_grad() 124 | def get_input(self, batch, k): 125 | x = batch[k] 126 | if len(x.shape) == 3: 127 | x = x[..., None] 128 | x = rearrange(x, 'b h w c -> b c h w') 129 | x = x.to(memory_format=torch.contiguous_format).float() 130 | return x 131 | 132 | @torch.no_grad() 133 | def get_conditioning(self, batch, k=None): 134 | if k is None: 135 | k = self.label_key 136 | assert k is not None, 'Needs to provide label key' 137 | 138 | targets = batch[k].to(self.device) 139 | 140 | if self.label_key == 'segmentation': 141 | targets = rearrange(targets, 'b h w c -> b c h w') 142 | for down in range(self.numd): 143 | h, w = targets.shape[-2:] 144 | targets = F.interpolate(targets, size=(h // 2, w // 2), mode='nearest') 145 | 146 | # targets = rearrange(targets,'b c h w -> b h w c') 147 | 148 | return targets 149 | 150 | def compute_top_k(self, logits, labels, k, reduction="mean"): 151 | _, top_ks = torch.topk(logits, k, dim=1) 152 | if reduction == "mean": 153 | return (top_ks == labels[:, None]).float().sum(dim=-1).mean().item() 154 | elif reduction == "none": 155 | return (top_ks == labels[:, None]).float().sum(dim=-1) 156 | 157 | def on_train_epoch_start(self): 158 | # save some memory 159 | self.diffusion_model.model.to('cpu') 160 | 161 | @torch.no_grad() 162 | def write_logs(self, loss, logits, targets): 163 | log_prefix = 'train' if self.training else 'val' 164 | log = {} 165 | log[f"{log_prefix}/loss"] = loss.mean() 166 | log[f"{log_prefix}/acc@1"] = self.compute_top_k( 167 | logits, targets, k=1, reduction="mean" 168 | ) 169 | log[f"{log_prefix}/acc@5"] = self.compute_top_k( 170 | logits, targets, k=5, reduction="mean" 171 | ) 172 | 173 | self.log_dict(log, prog_bar=False, logger=True, on_step=self.training, on_epoch=True) 174 | self.log('loss', log[f"{log_prefix}/loss"], prog_bar=True, logger=False) 175 | self.log('global_step', self.global_step, logger=False, on_epoch=False, prog_bar=True) 176 | lr = self.optimizers().param_groups[0]['lr'] 177 | 
self.log('lr_abs', lr, on_step=True, logger=True, on_epoch=False, prog_bar=True) 178 | 179 | def shared_step(self, batch, t=None): 180 | x, *_ = self.diffusion_model.get_input(batch, k=self.diffusion_model.first_stage_key) 181 | targets = self.get_conditioning(batch) 182 | if targets.dim() == 4: 183 | targets = targets.argmax(dim=1) 184 | if t is None: 185 | t = torch.randint(0, self.diffusion_model.num_timesteps, (x.shape[0],), device=self.device).long() 186 | else: 187 | t = torch.full(size=(x.shape[0],), fill_value=t, device=self.device).long() 188 | x_noisy = self.get_x_noisy(x, t) 189 | logits = self(x_noisy, t) 190 | 191 | loss = F.cross_entropy(logits, targets, reduction='none') 192 | 193 | self.write_logs(loss.detach(), logits.detach(), targets.detach()) 194 | 195 | loss = loss.mean() 196 | return loss, logits, x_noisy, targets 197 | 198 | def training_step(self, batch, batch_idx): 199 | loss, *_ = self.shared_step(batch) 200 | return loss 201 | 202 | def reset_noise_accs(self): 203 | self.noisy_acc = {t: {'acc@1': [], 'acc@5': []} for t in 204 | range(0, self.diffusion_model.num_timesteps, self.diffusion_model.log_every_t)} 205 | 206 | def on_validation_start(self): 207 | self.reset_noise_accs() 208 | 209 | @torch.no_grad() 210 | def validation_step(self, batch, batch_idx): 211 | loss, *_ = self.shared_step(batch) 212 | 213 | for t in self.noisy_acc: 214 | _, logits, _, targets = self.shared_step(batch, t) 215 | self.noisy_acc[t]['acc@1'].append(self.compute_top_k(logits, targets, k=1, reduction='mean')) 216 | self.noisy_acc[t]['acc@5'].append(self.compute_top_k(logits, targets, k=5, reduction='mean')) 217 | 218 | return loss 219 | 220 | def configure_optimizers(self): 221 | optimizer = AdamW(self.model.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay) 222 | 223 | if self.use_scheduler: 224 | scheduler = instantiate_from_config(self.scheduler_config) 225 | 226 | print("Setting up LambdaLR scheduler...") 227 | scheduler = [ 228 | { 229 | 'scheduler': LambdaLR(optimizer, lr_lambda=scheduler.schedule), 230 | 'interval': 'step', 231 | 'frequency': 1 232 | }] 233 | return [optimizer], scheduler 234 | 235 | return optimizer 236 | 237 | @torch.no_grad() 238 | def log_images(self, batch, N=8, *args, **kwargs): 239 | log = dict() 240 | x = self.get_input(batch, self.diffusion_model.first_stage_key) 241 | log['inputs'] = x 242 | 243 | y = self.get_conditioning(batch) 244 | 245 | if self.label_key == 'class_label': 246 | y = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"]) 247 | log['labels'] = y 248 | 249 | if ismap(y): 250 | log['labels'] = self.diffusion_model.to_rgb(y) 251 | 252 | for step in range(self.log_steps): 253 | current_time = step * self.log_time_interval 254 | 255 | _, logits, x_noisy, _ = self.shared_step(batch, t=current_time) 256 | 257 | log[f'inputs@t{current_time}'] = x_noisy 258 | 259 | pred = F.one_hot(logits.argmax(dim=1), num_classes=self.num_classes) 260 | pred = rearrange(pred, 'b h w c -> b c h w') 261 | 262 | log[f'pred@t{current_time}'] = self.diffusion_model.to_rgb(pred) 263 | 264 | for key in log: 265 | log[key] = log[key][:N] 266 | 267 | return log 268 | -------------------------------------------------------------------------------- /src/ldm/models/diffusion/ddim.py: -------------------------------------------------------------------------------- 1 | """SAMPLING ONLY.""" 2 | 3 | import torch 4 | import numpy as np 5 | from tqdm import tqdm 6 | from functools import partial 7 | 8 | from ldm.modules.diffusionmodules.util import 
make_ddim_sampling_parameters, make_ddim_timesteps, noise_like, \ 9 | extract_into_tensor 10 | 11 | 12 | class DDIMSampler(object): 13 | def __init__(self, model, schedule="linear", **kwargs): 14 | super().__init__() 15 | self.model = model 16 | self.ddpm_num_timesteps = model.num_timesteps 17 | self.schedule = schedule 18 | 19 | def register_buffer(self, name, attr): 20 | if type(attr) == torch.Tensor: 21 | if attr.device != torch.device("cuda"): 22 | attr = attr.to(torch.device("cuda")) 23 | setattr(self, name, attr) 24 | 25 | def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True): 26 | self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps, 27 | num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose) 28 | alphas_cumprod = self.model.alphas_cumprod 29 | assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep' 30 | to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device) 31 | 32 | self.register_buffer('betas', to_torch(self.model.betas)) 33 | self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod)) 34 | self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev)) 35 | 36 | # calculations for diffusion q(x_t | x_{t-1}) and others 37 | self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu()))) 38 | self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu()))) 39 | self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu()))) 40 | self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu()))) 41 | self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1))) 42 | 43 | # ddim sampling parameters 44 | ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(), 45 | ddim_timesteps=self.ddim_timesteps, 46 | eta=ddim_eta,verbose=verbose) 47 | self.register_buffer('ddim_sigmas', ddim_sigmas) 48 | self.register_buffer('ddim_alphas', ddim_alphas) 49 | self.register_buffer('ddim_alphas_prev', ddim_alphas_prev) 50 | self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas)) 51 | sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt( 52 | (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * ( 53 | 1 - self.alphas_cumprod / self.alphas_cumprod_prev)) 54 | self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps) 55 | 56 | @torch.no_grad() 57 | def sample(self, 58 | S, 59 | batch_size, 60 | shape, 61 | conditioning=None, 62 | callback=None, 63 | normals_sequence=None, 64 | img_callback=None, 65 | quantize_x0=False, 66 | eta=0., 67 | mask=None, 68 | x0=None, 69 | temperature=1., 70 | noise_dropout=0., 71 | score_corrector=None, 72 | corrector_kwargs=None, 73 | verbose=True, 74 | x_T=None, 75 | log_every_t=100, 76 | unconditional_guidance_scale=1., 77 | unconditional_conditioning=None, 78 | # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... 
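# unconditional_guidance_scale w applies classifier-free guidance in p_sample_ddim:
# e_t = e_t_uncond + w * (e_t_cond - e_t_uncond); with the default w = 1. the
# unconditional branch is skipped and the purely conditional prediction is used.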
79 | **kwargs 80 | ): 81 | if conditioning is not None: 82 | if isinstance(conditioning, dict): 83 | cbs = conditioning[list(conditioning.keys())[0]].shape[0] 84 | if cbs != batch_size: 85 | print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}") 86 | else: 87 | if conditioning.shape[0] != batch_size: 88 | print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}") 89 | 90 | self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose) 91 | # sampling 92 | C, H, W = shape 93 | size = (batch_size, C, H, W) 94 | print(f'Data shape for DDIM sampling is {size}, eta {eta}') 95 | 96 | samples, intermediates = self.ddim_sampling(conditioning, size, 97 | callback=callback, 98 | img_callback=img_callback, 99 | quantize_denoised=quantize_x0, 100 | mask=mask, x0=x0, 101 | ddim_use_original_steps=False, 102 | noise_dropout=noise_dropout, 103 | temperature=temperature, 104 | score_corrector=score_corrector, 105 | corrector_kwargs=corrector_kwargs, 106 | x_T=x_T, 107 | log_every_t=log_every_t, 108 | unconditional_guidance_scale=unconditional_guidance_scale, 109 | unconditional_conditioning=unconditional_conditioning, 110 | ) 111 | return samples, intermediates 112 | 113 | @torch.no_grad() 114 | def ddim_sampling(self, cond, shape, 115 | x_T=None, ddim_use_original_steps=False, 116 | callback=None, timesteps=None, quantize_denoised=False, 117 | mask=None, x0=None, img_callback=None, log_every_t=100, 118 | temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, 119 | unconditional_guidance_scale=1., unconditional_conditioning=None,): 120 | device = self.model.betas.device 121 | b = shape[0] 122 | if x_T is None: 123 | img = torch.randn(shape, device=device) 124 | else: 125 | img = x_T 126 | 127 | if timesteps is None: 128 | timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps 129 | elif timesteps is not None and not ddim_use_original_steps: 130 | subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1 131 | timesteps = self.ddim_timesteps[:subset_end] 132 | 133 | intermediates = {'x_inter': [img], 'pred_x0': [img]} 134 | time_range = reversed(range(0,timesteps)) if ddim_use_original_steps else np.flip(timesteps) 135 | total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0] 136 | print(f"Running DDIM Sampling with {total_steps} timesteps") 137 | 138 | iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps) 139 | 140 | for i, step in enumerate(iterator): 141 | index = total_steps - i - 1 142 | ts = torch.full((b,), step, device=device, dtype=torch.long) 143 | 144 | if mask is not None: 145 | assert x0 is not None 146 | img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass? 147 | img = img_orig * mask + (1. 
- mask) * img 148 | 149 | outs = self.p_sample_ddim(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps, 150 | quantize_denoised=quantize_denoised, temperature=temperature, 151 | noise_dropout=noise_dropout, score_corrector=score_corrector, 152 | corrector_kwargs=corrector_kwargs, 153 | unconditional_guidance_scale=unconditional_guidance_scale, 154 | unconditional_conditioning=unconditional_conditioning) 155 | img, pred_x0 = outs 156 | if callback: callback(i) 157 | if img_callback: img_callback(pred_x0, i) 158 | 159 | if index % log_every_t == 0 or index == total_steps - 1: 160 | intermediates['x_inter'].append(img) 161 | intermediates['pred_x0'].append(pred_x0) 162 | 163 | return img, intermediates 164 | 165 | @torch.no_grad() 166 | def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False, 167 | temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, 168 | unconditional_guidance_scale=1., unconditional_conditioning=None): 169 | b, *_, device = *x.shape, x.device 170 | 171 | if unconditional_conditioning is None or unconditional_guidance_scale == 1.: 172 | e_t = self.model.apply_model(x, t, c) 173 | else: 174 | x_in = torch.cat([x] * 2) 175 | t_in = torch.cat([t] * 2) 176 | c_in = torch.cat([unconditional_conditioning, c]) 177 | e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2) 178 | e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond) 179 | 180 | if score_corrector is not None: 181 | assert self.model.parameterization == "eps" 182 | e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs) 183 | 184 | alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas 185 | alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev 186 | sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas 187 | sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas 188 | # select parameters corresponding to the currently considered timestep 189 | a_t = torch.full((b, 1, 1, 1), alphas[index], device=device) 190 | a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device) 191 | sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device) 192 | sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device) 193 | 194 | # current prediction for x_0 195 | pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt() 196 | if quantize_denoised: 197 | pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0) 198 | # direction pointing to x_t 199 | dir_xt = (1. 
- a_prev - sigma_t**2).sqrt() * e_t 200 | noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature 201 | if noise_dropout > 0.: 202 | noise = torch.nn.functional.dropout(noise, p=noise_dropout) 203 | x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise 204 | return x_prev, pred_x0 205 | 206 | @torch.no_grad() 207 | def stochastic_encode(self, x0, t, use_original_steps=False, noise=None): 208 | # fast, but does not allow for exact reconstruction 209 | # t serves as an index to gather the correct alphas 210 | if use_original_steps: 211 | sqrt_alphas_cumprod = self.sqrt_alphas_cumprod 212 | sqrt_one_minus_alphas_cumprod = self.sqrt_one_minus_alphas_cumprod 213 | else: 214 | sqrt_alphas_cumprod = torch.sqrt(self.ddim_alphas) 215 | sqrt_one_minus_alphas_cumprod = self.ddim_sqrt_one_minus_alphas 216 | 217 | if noise is None: 218 | noise = torch.randn_like(x0) 219 | return (extract_into_tensor(sqrt_alphas_cumprod, t, x0.shape) * x0 + 220 | extract_into_tensor(sqrt_one_minus_alphas_cumprod, t, x0.shape) * noise) 221 | 222 | @torch.no_grad() 223 | def decode(self, x_latent, cond, t_start, unconditional_guidance_scale=1.0, unconditional_conditioning=None, 224 | use_original_steps=False): 225 | 226 | timesteps = np.arange(self.ddpm_num_timesteps) if use_original_steps else self.ddim_timesteps 227 | timesteps = timesteps[:t_start] 228 | 229 | time_range = np.flip(timesteps) 230 | total_steps = timesteps.shape[0] 231 | print(f"Running DDIM Sampling with {total_steps} timesteps") 232 | 233 | iterator = tqdm(time_range, desc='Decoding image', total=total_steps) 234 | x_dec = x_latent 235 | for i, step in enumerate(iterator): 236 | index = total_steps - i - 1 237 | ts = torch.full((x_latent.shape[0],), step, device=x_latent.device, dtype=torch.long) 238 | x_dec, _ = self.p_sample_ddim(x_dec, cond, ts, index=index, use_original_steps=use_original_steps, 239 | unconditional_guidance_scale=unconditional_guidance_scale, 240 | unconditional_conditioning=unconditional_conditioning) 241 | return x_dec -------------------------------------------------------------------------------- /src/ldm/models/diffusion/dpm_solver/__init__.py: -------------------------------------------------------------------------------- 1 | from .sampler import DPMSolverSampler 2 | from .dpm_solver import ( 3 | model_wrapper, 4 | NoiseScheduleVP, 5 | DPM_Solver 6 | ) -------------------------------------------------------------------------------- /src/ldm/models/diffusion/dpm_solver/sampler.py: -------------------------------------------------------------------------------- 1 | """SAMPLING ONLY.""" 2 | 3 | import torch 4 | 5 | from .dpm_solver import NoiseScheduleVP, model_wrapper, DPM_Solver 6 | 7 | 8 | class DPMSolverSampler(object): 9 | def __init__(self, model, **kwargs): 10 | super().__init__() 11 | self.model = model 12 | to_torch = lambda x: x.clone().detach().to(torch.float32).to(model.device) 13 | self.register_buffer('alphas_cumprod', to_torch(model.alphas_cumprod)) 14 | 15 | def register_buffer(self, name, attr): 16 | if type(attr) == torch.Tensor: 17 | if attr.device != torch.device("cuda"): 18 | attr = attr.to(torch.device("cuda")) 19 | setattr(self, name, attr) 20 | 21 | @torch.no_grad() 22 | def sample(self, 23 | S, 24 | batch_size, 25 | shape, 26 | conditioning=None, 27 | callback=None, 28 | normals_sequence=None, 29 | img_callback=None, 30 | quantize_x0=False, 31 | eta=0., 32 | mask=None, 33 | x0=None, 34 | temperature=1., 35 | noise_dropout=0., 36 | score_corrector=None, 37 | corrector_kwargs=None, 38 | 
verbose=True, 39 | x_T=None, 40 | log_every_t=100, 41 | unconditional_guidance_scale=1., 42 | unconditional_conditioning=None, 43 | # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... 44 | **kwargs 45 | ): 46 | if conditioning is not None: 47 | if isinstance(conditioning, dict): 48 | cbs = conditioning[list(conditioning.keys())[0]].shape[0] 49 | if cbs != batch_size: 50 | print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}") 51 | else: 52 | if conditioning.shape[0] != batch_size: 53 | print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}") 54 | 55 | # sampling 56 | C, H, W = shape 57 | size = (batch_size, C, H, W) 58 | 59 | # print(f'Data shape for DPM-Solver sampling is {size}, sampling steps {S}') 60 | 61 | device = self.model.betas.device 62 | if x_T is None: 63 | img = torch.randn(size, device=device) 64 | else: 65 | img = x_T 66 | 67 | ns = NoiseScheduleVP('discrete', alphas_cumprod=self.alphas_cumprod) 68 | 69 | model_fn = model_wrapper( 70 | lambda x, t, c: self.model.apply_model(x, t, c), 71 | ns, 72 | model_type="noise", 73 | guidance_type="classifier-free", 74 | condition=conditioning, 75 | unconditional_condition=unconditional_conditioning, 76 | guidance_scale=unconditional_guidance_scale, 77 | ) 78 | 79 | dpm_solver = DPM_Solver(model_fn, ns, predict_x0=True, thresholding=False) 80 | x = dpm_solver.sample(img, steps=S, skip_type="time_uniform", method="multistep", order=2, lower_order_final=True) 81 | 82 | return x.to(device), None 83 | -------------------------------------------------------------------------------- /src/ldm/models/diffusion/plms.py: -------------------------------------------------------------------------------- 1 | """SAMPLING ONLY.""" 2 | 3 | import torch 4 | import numpy as np 5 | from tqdm import tqdm 6 | from functools import partial 7 | 8 | from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like 9 | 10 | 11 | class PLMSSampler(object): 12 | def __init__(self, model, schedule="linear", **kwargs): 13 | super().__init__() 14 | self.model = model 15 | self.ddpm_num_timesteps = model.num_timesteps 16 | self.schedule = schedule 17 | 18 | def register_buffer(self, name, attr): 19 | if type(attr) == torch.Tensor: 20 | if attr.device != torch.device("cuda"): 21 | attr = attr.to(torch.device("cuda")) 22 | setattr(self, name, attr) 23 | 24 | def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True): 25 | if ddim_eta != 0: 26 | raise ValueError('ddim_eta must be 0 for PLMS') 27 | self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps, 28 | num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose) 29 | alphas_cumprod = self.model.alphas_cumprod 30 | assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep' 31 | to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device) 32 | 33 | self.register_buffer('betas', to_torch(self.model.betas)) 34 | self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod)) 35 | self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev)) 36 | 37 | # calculations for diffusion q(x_t | x_{t-1}) and others 38 | self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu()))) 39 | self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. 
- alphas_cumprod.cpu()))) 40 | self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu()))) 41 | self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu()))) 42 | self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1))) 43 | 44 | # ddim sampling parameters 45 | ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(), 46 | ddim_timesteps=self.ddim_timesteps, 47 | eta=ddim_eta,verbose=verbose) 48 | self.register_buffer('ddim_sigmas', ddim_sigmas) 49 | self.register_buffer('ddim_alphas', ddim_alphas) 50 | self.register_buffer('ddim_alphas_prev', ddim_alphas_prev) 51 | self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas)) 52 | sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt( 53 | (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * ( 54 | 1 - self.alphas_cumprod / self.alphas_cumprod_prev)) 55 | self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps) 56 | 57 | @torch.no_grad() 58 | def sample(self, 59 | S, 60 | batch_size, 61 | shape, 62 | conditioning=None, 63 | callback=None, 64 | normals_sequence=None, 65 | img_callback=None, 66 | quantize_x0=False, 67 | eta=0., 68 | mask=None, 69 | x0=None, 70 | temperature=1., 71 | noise_dropout=0., 72 | score_corrector=None, 73 | corrector_kwargs=None, 74 | verbose=True, 75 | x_T=None, 76 | log_every_t=100, 77 | unconditional_guidance_scale=1., 78 | unconditional_conditioning=None, 79 | # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... 80 | **kwargs 81 | ): 82 | if conditioning is not None: 83 | if isinstance(conditioning, dict): 84 | cbs = conditioning[list(conditioning.keys())[0]].shape[0] 85 | if cbs != batch_size: 86 | print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}") 87 | else: 88 | if conditioning.shape[0] != batch_size: 89 | print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}") 90 | 91 | self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose) 92 | # sampling 93 | C, H, W = shape 94 | size = (batch_size, C, H, W) 95 | print(f'Data shape for PLMS sampling is {size}') 96 | 97 | samples, intermediates = self.plms_sampling(conditioning, size, 98 | callback=callback, 99 | img_callback=img_callback, 100 | quantize_denoised=quantize_x0, 101 | mask=mask, x0=x0, 102 | ddim_use_original_steps=False, 103 | noise_dropout=noise_dropout, 104 | temperature=temperature, 105 | score_corrector=score_corrector, 106 | corrector_kwargs=corrector_kwargs, 107 | x_T=x_T, 108 | log_every_t=log_every_t, 109 | unconditional_guidance_scale=unconditional_guidance_scale, 110 | unconditional_conditioning=unconditional_conditioning, 111 | ) 112 | return samples, intermediates 113 | 114 | @torch.no_grad() 115 | def plms_sampling(self, cond, shape, 116 | x_T=None, ddim_use_original_steps=False, 117 | callback=None, timesteps=None, quantize_denoised=False, 118 | mask=None, x0=None, img_callback=None, log_every_t=100, 119 | temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, 120 | unconditional_guidance_scale=1., unconditional_conditioning=None,): 121 | device = self.model.betas.device 122 | b = shape[0] 123 | if x_T is None: 124 | img = torch.randn(shape, device=device) 125 | else: 126 | img = x_T 127 | 128 | if timesteps is None: 129 | timesteps = self.ddpm_num_timesteps if ddim_use_original_steps 
else self.ddim_timesteps 130 | elif timesteps is not None and not ddim_use_original_steps: 131 | subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1 132 | timesteps = self.ddim_timesteps[:subset_end] 133 | 134 | intermediates = {'x_inter': [img], 'pred_x0': [img]} 135 | time_range = list(reversed(range(0,timesteps))) if ddim_use_original_steps else np.flip(timesteps) 136 | total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0] 137 | print(f"Running PLMS Sampling with {total_steps} timesteps") 138 | 139 | iterator = tqdm(time_range, desc='PLMS Sampler', total=total_steps) 140 | old_eps = [] 141 | 142 | for i, step in enumerate(iterator): 143 | index = total_steps - i - 1 144 | ts = torch.full((b,), step, device=device, dtype=torch.long) 145 | ts_next = torch.full((b,), time_range[min(i + 1, len(time_range) - 1)], device=device, dtype=torch.long) 146 | 147 | if mask is not None: 148 | assert x0 is not None 149 | img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass? 150 | img = img_orig * mask + (1. - mask) * img 151 | 152 | outs = self.p_sample_plms(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps, 153 | quantize_denoised=quantize_denoised, temperature=temperature, 154 | noise_dropout=noise_dropout, score_corrector=score_corrector, 155 | corrector_kwargs=corrector_kwargs, 156 | unconditional_guidance_scale=unconditional_guidance_scale, 157 | unconditional_conditioning=unconditional_conditioning, 158 | old_eps=old_eps, t_next=ts_next) 159 | img, pred_x0, e_t = outs 160 | old_eps.append(e_t) 161 | if len(old_eps) >= 4: 162 | old_eps.pop(0) 163 | if callback: callback(i) 164 | if img_callback: img_callback(pred_x0, i) 165 | 166 | if index % log_every_t == 0 or index == total_steps - 1: 167 | intermediates['x_inter'].append(img) 168 | intermediates['pred_x0'].append(pred_x0) 169 | 170 | return img, intermediates 171 | 172 | @torch.no_grad() 173 | def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False, 174 | temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, 175 | unconditional_guidance_scale=1., unconditional_conditioning=None, old_eps=None, t_next=None): 176 | b, *_, device = *x.shape, x.device 177 | 178 | def get_model_output(x, t): 179 | if unconditional_conditioning is None or unconditional_guidance_scale == 1.: 180 | e_t = self.model.apply_model(x, t, c) 181 | else: 182 | x_in = torch.cat([x] * 2) 183 | t_in = torch.cat([t] * 2) 184 | c_in = torch.cat([unconditional_conditioning, c]) 185 | e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2) 186 | e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond) 187 | 188 | if score_corrector is not None: 189 | assert self.model.parameterization == "eps" 190 | e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs) 191 | 192 | return e_t 193 | 194 | alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas 195 | alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev 196 | sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas 197 | sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas 198 | 199 | def get_x_prev_and_pred_x0(e_t, index): 200 | # select parameters corresponding to the currently considered timestep 201 | a_t = 
torch.full((b, 1, 1, 1), alphas[index], device=device) 202 | a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device) 203 | sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device) 204 | sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device) 205 | 206 | # current prediction for x_0 207 | pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt() 208 | if quantize_denoised: 209 | pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0) 210 | # direction pointing to x_t 211 | dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t 212 | noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature 213 | if noise_dropout > 0.: 214 | noise = torch.nn.functional.dropout(noise, p=noise_dropout) 215 | x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise 216 | return x_prev, pred_x0 217 | 218 | e_t = get_model_output(x, t) 219 | if len(old_eps) == 0: 220 | # Pseudo Improved Euler (2nd order) 221 | x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t, index) 222 | e_t_next = get_model_output(x_prev, t_next) 223 | e_t_prime = (e_t + e_t_next) / 2 224 | elif len(old_eps) == 1: 225 | # 2nd order Pseudo Linear Multistep (Adams-Bashforth) 226 | e_t_prime = (3 * e_t - old_eps[-1]) / 2 227 | elif len(old_eps) == 2: 228 | # 3rd order Pseudo Linear Multistep (Adams-Bashforth) 229 | e_t_prime = (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12 230 | elif len(old_eps) >= 3: 231 | # 4th order Pseudo Linear Multistep (Adams-Bashforth) 232 | e_t_prime = (55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3]) / 24 233 | 234 | x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index) 235 | 236 | return x_prev, pred_x0, e_t 237 | -------------------------------------------------------------------------------- /src/ldm/modules/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DreamEditBenchTeam/DreamEdit/14d21b0a3eb6305c1378080ccd8361db0a8adcc0/src/ldm/modules/__init__.py -------------------------------------------------------------------------------- /src/ldm/modules/attention.py: -------------------------------------------------------------------------------- 1 | from inspect import isfunction 2 | import math 3 | import torch 4 | import torch.nn.functional as F 5 | from torch import nn, einsum 6 | from einops import rearrange, repeat 7 | 8 | from ldm.modules.diffusionmodules.util import checkpoint 9 | 10 | 11 | def exists(val): 12 | return val is not None 13 | 14 | 15 | def uniq(arr): 16 | return {el: True for el in arr}.keys() 17 | 18 | 19 | def default(val, d): 20 | if exists(val): 21 | return val 22 | return d() if isfunction(d) else d 23 | 24 | 25 | def max_neg_value(t): 26 | return -torch.finfo(t.dtype).max 27 | 28 | 29 | def init_(tensor): 30 | dim = tensor.shape[-1] 31 | std = 1 / math.sqrt(dim) 32 | tensor.uniform_(-std, std) 33 | return tensor 34 | 35 | 36 | # feedforward 37 | class GEGLU(nn.Module): 38 | def __init__(self, dim_in, dim_out): 39 | super().__init__() 40 | self.proj = nn.Linear(dim_in, dim_out * 2) 41 | 42 | def forward(self, x): 43 | x, gate = self.proj(x).chunk(2, dim=-1) 44 | return x * F.gelu(gate) 45 | 46 | 47 | class FeedForward(nn.Module): 48 | def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.): 49 | super().__init__() 50 | inner_dim = int(dim * mult) 51 | dim_out = default(dim_out, dim) 52 | project_in = nn.Sequential( 53 | nn.Linear(dim, inner_dim), 54 | nn.GELU() 55 | ) if not glu else GEGLU(dim, inner_dim) 56 | 57 | self.net
= nn.Sequential( 58 | project_in, 59 | nn.Dropout(dropout), 60 | nn.Linear(inner_dim, dim_out) 61 | ) 62 | 63 | def forward(self, x): 64 | return self.net(x) 65 | 66 | 67 | def zero_module(module): 68 | """ 69 | Zero out the parameters of a module and return it. 70 | """ 71 | for p in module.parameters(): 72 | p.detach().zero_() 73 | return module 74 | 75 | 76 | def Normalize(in_channels): 77 | return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) 78 | 79 | 80 | class LinearAttention(nn.Module): 81 | def __init__(self, dim, heads=4, dim_head=32): 82 | super().__init__() 83 | self.heads = heads 84 | hidden_dim = dim_head * heads 85 | self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False) 86 | self.to_out = nn.Conv2d(hidden_dim, dim, 1) 87 | 88 | def forward(self, x): 89 | b, c, h, w = x.shape 90 | qkv = self.to_qkv(x) 91 | q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)', heads = self.heads, qkv=3) 92 | k = k.softmax(dim=-1) 93 | context = torch.einsum('bhdn,bhen->bhde', k, v) 94 | out = torch.einsum('bhde,bhdn->bhen', context, q) 95 | out = rearrange(out, 'b heads c (h w) -> b (heads c) h w', heads=self.heads, h=h, w=w) 96 | return self.to_out(out) 97 | 98 | 99 | class SpatialSelfAttention(nn.Module): 100 | def __init__(self, in_channels): 101 | super().__init__() 102 | self.in_channels = in_channels 103 | 104 | self.norm = Normalize(in_channels) 105 | self.q = torch.nn.Conv2d(in_channels, 106 | in_channels, 107 | kernel_size=1, 108 | stride=1, 109 | padding=0) 110 | self.k = torch.nn.Conv2d(in_channels, 111 | in_channels, 112 | kernel_size=1, 113 | stride=1, 114 | padding=0) 115 | self.v = torch.nn.Conv2d(in_channels, 116 | in_channels, 117 | kernel_size=1, 118 | stride=1, 119 | padding=0) 120 | self.proj_out = torch.nn.Conv2d(in_channels, 121 | in_channels, 122 | kernel_size=1, 123 | stride=1, 124 | padding=0) 125 | 126 | def forward(self, x): 127 | h_ = x 128 | h_ = self.norm(h_) 129 | q = self.q(h_) 130 | k = self.k(h_) 131 | v = self.v(h_) 132 | 133 | # compute attention 134 | b,c,h,w = q.shape 135 | q = rearrange(q, 'b c h w -> b (h w) c') 136 | k = rearrange(k, 'b c h w -> b c (h w)') 137 | w_ = torch.einsum('bij,bjk->bik', q, k) 138 | 139 | w_ = w_ * (int(c)**(-0.5)) 140 | w_ = torch.nn.functional.softmax(w_, dim=2) 141 | 142 | # attend to values 143 | v = rearrange(v, 'b c h w -> b c (h w)') 144 | w_ = rearrange(w_, 'b i j -> b j i') 145 | h_ = torch.einsum('bij,bjk->bik', v, w_) 146 | h_ = rearrange(h_, 'b c (h w) -> b c h w', h=h) 147 | h_ = self.proj_out(h_) 148 | 149 | return x+h_ 150 | 151 | 152 | class CrossAttention(nn.Module): 153 | def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.): 154 | super().__init__() 155 | inner_dim = dim_head * heads 156 | context_dim = default(context_dim, query_dim) 157 | 158 | self.scale = dim_head ** -0.5 159 | self.heads = heads 160 | 161 | self.to_q = nn.Linear(query_dim, inner_dim, bias=False) 162 | self.to_k = nn.Linear(context_dim, inner_dim, bias=False) 163 | self.to_v = nn.Linear(context_dim, inner_dim, bias=False) 164 | 165 | self.to_out = nn.Sequential( 166 | nn.Linear(inner_dim, query_dim), 167 | nn.Dropout(dropout) 168 | ) 169 | 170 | def forward(self, x, context=None, mask=None): 171 | h = self.heads 172 | 173 | q = self.to_q(x) 174 | context = default(context, x) 175 | k = self.to_k(context) 176 | v = self.to_v(context) 177 | 178 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v)) 179 | 180 | sim = einsum('b 
i d, b j d -> b i j', q, k) * self.scale 181 | 182 | if exists(mask): 183 | mask = rearrange(mask, 'b ... -> b (...)') 184 | max_neg_value = -torch.finfo(sim.dtype).max 185 | mask = repeat(mask, 'b j -> (b h) () j', h=h) 186 | sim.masked_fill_(~mask, max_neg_value) 187 | 188 | # attention, what we cannot get enough of 189 | attn = sim.softmax(dim=-1) 190 | 191 | out = einsum('b i j, b j d -> b i d', attn, v) 192 | out = rearrange(out, '(b h) n d -> b n (h d)', h=h) 193 | return self.to_out(out) 194 | 195 | 196 | class BasicTransformerBlock(nn.Module): 197 | def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True): 198 | super().__init__() 199 | self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout) # is a self-attention 200 | self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff) 201 | self.attn2 = CrossAttention(query_dim=dim, context_dim=context_dim, 202 | heads=n_heads, dim_head=d_head, dropout=dropout) # is self-attn if context is none 203 | self.norm1 = nn.LayerNorm(dim) 204 | self.norm2 = nn.LayerNorm(dim) 205 | self.norm3 = nn.LayerNorm(dim) 206 | self.checkpoint = checkpoint 207 | 208 | def forward(self, x, context=None): 209 | return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint) 210 | 211 | def _forward(self, x, context=None): 212 | x = self.attn1(self.norm1(x)) + x 213 | x = self.attn2(self.norm2(x), context=context) + x 214 | x = self.ff(self.norm3(x)) + x 215 | return x 216 | 217 | 218 | class SpatialTransformer(nn.Module): 219 | """ 220 | Transformer block for image-like data. 221 | First, project the input (aka embedding) 222 | and reshape to b, t, d. 223 | Then apply standard transformer action. 224 | Finally, reshape to image 225 | """ 226 | def __init__(self, in_channels, n_heads, d_head, 227 | depth=1, dropout=0., context_dim=None): 228 | super().__init__() 229 | self.in_channels = in_channels 230 | inner_dim = n_heads * d_head 231 | self.norm = Normalize(in_channels) 232 | 233 | self.proj_in = nn.Conv2d(in_channels, 234 | inner_dim, 235 | kernel_size=1, 236 | stride=1, 237 | padding=0) 238 | 239 | self.transformer_blocks = nn.ModuleList( 240 | [BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim) 241 | for d in range(depth)] 242 | ) 243 | 244 | self.proj_out = zero_module(nn.Conv2d(inner_dim, 245 | in_channels, 246 | kernel_size=1, 247 | stride=1, 248 | padding=0)) 249 | 250 | def forward(self, x, context=None): 251 | # note: if no context is given, cross-attention defaults to self-attention 252 | b, c, h, w = x.shape 253 | x_in = x 254 | x = self.norm(x) 255 | x = self.proj_in(x) 256 | x = rearrange(x, 'b c h w -> b (h w) c') 257 | for block in self.transformer_blocks: 258 | x = block(x, context=context) 259 | x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w) 260 | x = self.proj_out(x) 261 | return x + x_in -------------------------------------------------------------------------------- /src/ldm/modules/diffusionmodules/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DreamEditBenchTeam/DreamEdit/14d21b0a3eb6305c1378080ccd8361db0a8adcc0/src/ldm/modules/diffusionmodules/__init__.py -------------------------------------------------------------------------------- /src/ldm/modules/diffusionmodules/util.py: -------------------------------------------------------------------------------- 1 | # adopted from 2 | # 
https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py 3 | # and 4 | # https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py 5 | # and 6 | # https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py 7 | # 8 | # thanks! 9 | 10 | 11 | import os 12 | import math 13 | import torch 14 | import torch.nn as nn 15 | import numpy as np 16 | from einops import repeat 17 | 18 | from ldm.util import instantiate_from_config 19 | 20 | 21 | def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3): 22 | if schedule == "linear": 23 | betas = ( 24 | torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2 25 | ) 26 | 27 | elif schedule == "cosine": 28 | timesteps = ( 29 | torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s 30 | ) 31 | alphas = timesteps / (1 + cosine_s) * np.pi / 2 32 | alphas = torch.cos(alphas).pow(2) 33 | alphas = alphas / alphas[0] 34 | betas = 1 - alphas[1:] / alphas[:-1] 35 | betas = np.clip(betas, a_min=0, a_max=0.999) 36 | 37 | elif schedule == "sqrt_linear": 38 | betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) 39 | elif schedule == "sqrt": 40 | betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) ** 0.5 41 | else: 42 | raise ValueError(f"schedule '{schedule}' unknown.") 43 | return betas.numpy() 44 | 45 | 46 | def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True): 47 | if ddim_discr_method == 'uniform': 48 | c = num_ddpm_timesteps // num_ddim_timesteps 49 | ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c))) 50 | elif ddim_discr_method == 'quad': 51 | ddim_timesteps = ((np.linspace(0, np.sqrt(num_ddpm_timesteps * .8), num_ddim_timesteps)) ** 2).astype(int) 52 | else: 53 | raise NotImplementedError(f'There is no ddim discretization method called "{ddim_discr_method}"') 54 | 55 | # assert ddim_timesteps.shape[0] == num_ddim_timesteps 56 | # add one to get the final alpha values right (the ones from first scale to data during sampling) 57 | steps_out = ddim_timesteps + 1 58 | if verbose: 59 | print(f'Selected timesteps for ddim sampler: {steps_out}') 60 | return steps_out 61 | 62 | 63 | def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True): 64 | # select alphas for computing the variance schedule 65 | alphas = alphacums[ddim_timesteps] 66 | alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist()) 67 | 68 | # according to the formula provided in https://arxiv.org/abs/2010.02502 69 | sigmas = eta * np.sqrt((1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev)) 70 | if verbose: 71 | print(f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}') 72 | print(f'For the chosen value of eta, which is {eta}, ' 73 | f'this results in the following sigma_t schedule for ddim sampler {sigmas}') 74 | return sigmas, alphas, alphas_prev 75 | 76 | 77 | def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999): 78 | """ 79 | Create a beta schedule that discretizes the given alpha_t_bar function, 80 | which defines the cumulative product of (1-beta) over time from t = [0,1]. 81 | :param num_diffusion_timesteps: the number of betas to produce.
82 | :param alpha_bar: a lambda that takes an argument t from 0 to 1 and 83 | produces the cumulative product of (1-beta) up to that 84 | part of the diffusion process. 85 | :param max_beta: the maximum beta to use; use values lower than 1 to 86 | prevent singularities. 87 | """ 88 | betas = [] 89 | for i in range(num_diffusion_timesteps): 90 | t1 = i / num_diffusion_timesteps 91 | t2 = (i + 1) / num_diffusion_timesteps 92 | betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) 93 | return np.array(betas) 94 | 95 | 96 | def extract_into_tensor(a, t, x_shape): 97 | b, *_ = t.shape 98 | out = a.gather(-1, t) 99 | return out.reshape(b, *((1,) * (len(x_shape) - 1))) 100 | 101 | 102 | def checkpoint(func, inputs, params, flag): 103 | """ 104 | Evaluate a function without caching intermediate activations, allowing for 105 | reduced memory at the expense of extra compute in the backward pass. 106 | :param func: the function to evaluate. 107 | :param inputs: the argument sequence to pass to `func`. 108 | :param params: a sequence of parameters `func` depends on but does not 109 | explicitly take as arguments. 110 | :param flag: if False, disable gradient checkpointing. 111 | """ 112 | if flag: 113 | args = tuple(inputs) + tuple(params) 114 | return CheckpointFunction.apply(func, len(inputs), *args) 115 | else: 116 | return func(*inputs) 117 | 118 | 119 | class CheckpointFunction(torch.autograd.Function): 120 | @staticmethod 121 | def forward(ctx, run_function, length, *args): 122 | ctx.run_function = run_function 123 | ctx.input_tensors = list(args[:length]) 124 | ctx.input_params = list(args[length:]) 125 | 126 | with torch.no_grad(): 127 | output_tensors = ctx.run_function(*ctx.input_tensors) 128 | return output_tensors 129 | 130 | @staticmethod 131 | def backward(ctx, *output_grads): 132 | ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors] 133 | with torch.enable_grad(): 134 | # Fixes a bug where the first op in run_function modifies the 135 | # Tensor storage in place, which is not allowed for detach()'d 136 | # Tensors. 137 | shallow_copies = [x.view_as(x) for x in ctx.input_tensors] 138 | output_tensors = ctx.run_function(*shallow_copies) 139 | input_grads = torch.autograd.grad( 140 | output_tensors, 141 | ctx.input_tensors + ctx.input_params, 142 | output_grads, 143 | allow_unused=True, 144 | ) 145 | del ctx.input_tensors 146 | del ctx.input_params 147 | del output_tensors 148 | return (None, None) + input_grads 149 | 150 | 151 | def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False): 152 | """ 153 | Create sinusoidal timestep embeddings. 154 | :param timesteps: a 1-D Tensor of N indices, one per batch element. 155 | These may be fractional. 156 | :param dim: the dimension of the output. 157 | :param max_period: controls the minimum frequency of the embeddings. 158 | :return: an [N x dim] Tensor of positional embeddings. 
159 | """ 160 | if not repeat_only: 161 | half = dim // 2 162 | freqs = torch.exp( 163 | -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half 164 | ).to(device=timesteps.device) 165 | args = timesteps[:, None].float() * freqs[None] 166 | embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) 167 | if dim % 2: 168 | embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) 169 | else: 170 | embedding = repeat(timesteps, 'b -> b d', d=dim) 171 | return embedding 172 | 173 | 174 | def zero_module(module): 175 | """ 176 | Zero out the parameters of a module and return it. 177 | """ 178 | for p in module.parameters(): 179 | p.detach().zero_() 180 | return module 181 | 182 | 183 | def scale_module(module, scale): 184 | """ 185 | Scale the parameters of a module and return it. 186 | """ 187 | for p in module.parameters(): 188 | p.detach().mul_(scale) 189 | return module 190 | 191 | 192 | def mean_flat(tensor): 193 | """ 194 | Take the mean over all non-batch dimensions. 195 | """ 196 | return tensor.mean(dim=list(range(1, len(tensor.shape)))) 197 | 198 | 199 | def normalization(channels): 200 | """ 201 | Make a standard normalization layer. 202 | :param channels: number of input channels. 203 | :return: an nn.Module for normalization. 204 | """ 205 | return GroupNorm32(32, channels) 206 | 207 | 208 | # PyTorch 1.7 has SiLU, but we support PyTorch 1.5. 209 | class SiLU(nn.Module): 210 | def forward(self, x): 211 | return x * torch.sigmoid(x) 212 | 213 | 214 | class GroupNorm32(nn.GroupNorm): 215 | def forward(self, x): 216 | return super().forward(x.float()).type(x.dtype) 217 | 218 | def conv_nd(dims, *args, **kwargs): 219 | """ 220 | Create a 1D, 2D, or 3D convolution module. 221 | """ 222 | if dims == 1: 223 | return nn.Conv1d(*args, **kwargs) 224 | elif dims == 2: 225 | return nn.Conv2d(*args, **kwargs) 226 | elif dims == 3: 227 | return nn.Conv3d(*args, **kwargs) 228 | raise ValueError(f"unsupported dimensions: {dims}") 229 | 230 | 231 | def linear(*args, **kwargs): 232 | """ 233 | Create a linear module. 234 | """ 235 | return nn.Linear(*args, **kwargs) 236 | 237 | 238 | def avg_pool_nd(dims, *args, **kwargs): 239 | """ 240 | Create a 1D, 2D, or 3D average pooling module. 
241 | """ 242 | if dims == 1: 243 | return nn.AvgPool1d(*args, **kwargs) 244 | elif dims == 2: 245 | return nn.AvgPool2d(*args, **kwargs) 246 | elif dims == 3: 247 | return nn.AvgPool3d(*args, **kwargs) 248 | raise ValueError(f"unsupported dimensions: {dims}") 249 | 250 | 251 | class HybridConditioner(nn.Module): 252 | 253 | def __init__(self, c_concat_config, c_crossattn_config): 254 | super().__init__() 255 | self.concat_conditioner = instantiate_from_config(c_concat_config) 256 | self.crossattn_conditioner = instantiate_from_config(c_crossattn_config) 257 | 258 | def forward(self, c_concat, c_crossattn): 259 | c_concat = self.concat_conditioner(c_concat) 260 | c_crossattn = self.crossattn_conditioner(c_crossattn) 261 | return {'c_concat': [c_concat], 'c_crossattn': [c_crossattn]} 262 | 263 | 264 | def noise_like(shape, device, repeat=False): 265 | repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1))) 266 | noise = lambda: torch.randn(shape, device=device) 267 | return repeat_noise() if repeat else noise() -------------------------------------------------------------------------------- /src/ldm/modules/distributions/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DreamEditBenchTeam/DreamEdit/14d21b0a3eb6305c1378080ccd8361db0a8adcc0/src/ldm/modules/distributions/__init__.py -------------------------------------------------------------------------------- /src/ldm/modules/distributions/distributions.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | 5 | class AbstractDistribution: 6 | def sample(self): 7 | raise NotImplementedError() 8 | 9 | def mode(self): 10 | raise NotImplementedError() 11 | 12 | 13 | class DiracDistribution(AbstractDistribution): 14 | def __init__(self, value): 15 | self.value = value 16 | 17 | def sample(self): 18 | return self.value 19 | 20 | def mode(self): 21 | return self.value 22 | 23 | 24 | class DiagonalGaussianDistribution(object): 25 | def __init__(self, parameters, deterministic=False): 26 | self.parameters = parameters 27 | self.mean, self.logvar = torch.chunk(parameters, 2, dim=1) 28 | self.logvar = torch.clamp(self.logvar, -30.0, 20.0) 29 | self.deterministic = deterministic 30 | self.std = torch.exp(0.5 * self.logvar) 31 | self.var = torch.exp(self.logvar) 32 | if self.deterministic: 33 | self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device) 34 | 35 | def sample(self): 36 | x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device) 37 | return x 38 | 39 | def kl(self, other=None): 40 | if self.deterministic: 41 | return torch.Tensor([0.]) 42 | else: 43 | if other is None: 44 | return 0.5 * torch.sum(torch.pow(self.mean, 2) 45 | + self.var - 1.0 - self.logvar, 46 | dim=[1, 2, 3]) 47 | else: 48 | return 0.5 * torch.sum( 49 | torch.pow(self.mean - other.mean, 2) / other.var 50 | + self.var / other.var - 1.0 - self.logvar + other.logvar, 51 | dim=[1, 2, 3]) 52 | 53 | def nll(self, sample, dims=[1,2,3]): 54 | if self.deterministic: 55 | return torch.Tensor([0.]) 56 | logtwopi = np.log(2.0 * np.pi) 57 | return 0.5 * torch.sum( 58 | logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, 59 | dim=dims) 60 | 61 | def mode(self): 62 | return self.mean 63 | 64 | 65 | def normal_kl(mean1, logvar1, mean2, logvar2): 66 | """ 67 | source: 
https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12 68 | Compute the KL divergence between two gaussians. 69 | Shapes are automatically broadcasted, so batches can be compared to 70 | scalars, among other use cases. 71 | """ 72 | tensor = None 73 | for obj in (mean1, logvar1, mean2, logvar2): 74 | if isinstance(obj, torch.Tensor): 75 | tensor = obj 76 | break 77 | assert tensor is not None, "at least one argument must be a Tensor" 78 | 79 | # Force variances to be Tensors. Broadcasting helps convert scalars to 80 | # Tensors, but it does not work for torch.exp(). 81 | logvar1, logvar2 = [ 82 | x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor) 83 | for x in (logvar1, logvar2) 84 | ] 85 | 86 | return 0.5 * ( 87 | -1.0 88 | + logvar2 89 | - logvar1 90 | + torch.exp(logvar1 - logvar2) 91 | + ((mean1 - mean2) ** 2) * torch.exp(-logvar2) 92 | ) 93 | -------------------------------------------------------------------------------- /src/ldm/modules/ema.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | class LitEma(nn.Module): 6 | def __init__(self, model, decay=0.9999, use_num_upates=True): 7 | super().__init__() 8 | if decay < 0.0 or decay > 1.0: 9 | raise ValueError('Decay must be between 0 and 1') 10 | 11 | self.m_name2s_name = {} 12 | self.register_buffer('decay', torch.tensor(decay, dtype=torch.float32)) 13 | self.register_buffer('num_updates', torch.tensor(0,dtype=torch.int) if use_num_upates 14 | else torch.tensor(-1,dtype=torch.int)) 15 | 16 | for name, p in model.named_parameters(): 17 | if p.requires_grad: 18 | #remove as '.'-character is not allowed in buffers 19 | s_name = name.replace('.','') 20 | self.m_name2s_name.update({name:s_name}) 21 | self.register_buffer(s_name,p.clone().detach().data) 22 | 23 | self.collected_params = [] 24 | 25 | def forward(self,model): 26 | decay = self.decay 27 | 28 | if self.num_updates >= 0: 29 | self.num_updates += 1 30 | decay = min(self.decay,(1 + self.num_updates) / (10 + self.num_updates)) 31 | 32 | one_minus_decay = 1.0 - decay 33 | 34 | with torch.no_grad(): 35 | m_param = dict(model.named_parameters()) 36 | shadow_params = dict(self.named_buffers()) 37 | 38 | for key in m_param: 39 | if m_param[key].requires_grad: 40 | sname = self.m_name2s_name[key] 41 | shadow_params[sname] = shadow_params[sname].type_as(m_param[key]) 42 | shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key])) 43 | else: 44 | assert not key in self.m_name2s_name 45 | 46 | def copy_to(self, model): 47 | m_param = dict(model.named_parameters()) 48 | shadow_params = dict(self.named_buffers()) 49 | for key in m_param: 50 | if m_param[key].requires_grad: 51 | m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data) 52 | else: 53 | assert not key in self.m_name2s_name 54 | 55 | def store(self, parameters): 56 | """ 57 | Save the current parameters for restoring later. 58 | Args: 59 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 60 | temporarily stored. 61 | """ 62 | self.collected_params = [param.clone() for param in parameters] 63 | 64 | def restore(self, parameters): 65 | """ 66 | Restore the parameters stored with the `store` method. 67 | Useful to validate the model with EMA parameters without affecting the 68 | original optimization process. Store the parameters before the 69 | `copy_to` method. 
After validation (or model saving), use this to 70 | restore the former parameters. 71 | Args: 72 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 73 | updated with the stored parameters. 74 | """ 75 | for c_param, param in zip(self.collected_params, parameters): 76 | param.data.copy_(c_param.data) 77 | -------------------------------------------------------------------------------- /src/ldm/modules/encoders/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DreamEditBenchTeam/DreamEdit/14d21b0a3eb6305c1378080ccd8361db0a8adcc0/src/ldm/modules/encoders/__init__.py -------------------------------------------------------------------------------- /src/ldm/modules/encoders/modules.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from functools import partial 4 | import clip 5 | from einops import rearrange, repeat 6 | from transformers import CLIPTokenizer, CLIPTextModel 7 | import kornia 8 | 9 | from ldm.modules.x_transformer import Encoder, TransformerWrapper # TODO: can we directly rely on lucidrains code and simply add this as a requirement? --> test 10 | 11 | 12 | class AbstractEncoder(nn.Module): 13 | def __init__(self): 14 | super().__init__() 15 | 16 | def encode(self, *args, **kwargs): 17 | raise NotImplementedError 18 | 19 | 20 | 21 | class ClassEmbedder(nn.Module): 22 | def __init__(self, embed_dim, n_classes=1000, key='class'): 23 | super().__init__() 24 | self.key = key 25 | self.embedding = nn.Embedding(n_classes, embed_dim) 26 | 27 | def forward(self, batch, key=None): 28 | if key is None: 29 | key = self.key 30 | # this is for use in crossattn 31 | c = batch[key][:, None] 32 | c = self.embedding(c) 33 | return c 34 | 35 | 36 | class TransformerEmbedder(AbstractEncoder): 37 | """Some transformer encoder layers""" 38 | def __init__(self, n_embed, n_layer, vocab_size, max_seq_len=77, device="cuda"): 39 | super().__init__() 40 | self.device = device 41 | self.transformer = TransformerWrapper(num_tokens=vocab_size, max_seq_len=max_seq_len, 42 | attn_layers=Encoder(dim=n_embed, depth=n_layer)) 43 | 44 | def forward(self, tokens): 45 | tokens = tokens.to(self.device) # meh 46 | z = self.transformer(tokens, return_embeddings=True) 47 | return z 48 | 49 | def encode(self, x): 50 | return self(x) 51 | 52 | 53 | class BERTTokenizer(AbstractEncoder): 54 | """ Uses a pretrained BERT tokenizer by huggingface.
Vocab size: 30522 (?)""" 55 | def __init__(self, device="cuda", vq_interface=True, max_length=77): 56 | super().__init__() 57 | from transformers import BertTokenizerFast # TODO: add to requirements 58 | self.tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased") 59 | self.device = device 60 | self.vq_interface = vq_interface 61 | self.max_length = max_length 62 | 63 | def forward(self, text): 64 | batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True, 65 | return_overflowing_tokens=False, padding="max_length", return_tensors="pt") 66 | tokens = batch_encoding["input_ids"].to(self.device) 67 | return tokens 68 | 69 | @torch.no_grad() 70 | def encode(self, text): 71 | tokens = self(text) 72 | if not self.vq_interface: 73 | return tokens 74 | return None, None, [None, None, tokens] 75 | 76 | def decode(self, text): 77 | return text 78 | 79 | 80 | class BERTEmbedder(AbstractEncoder): 81 | """Uses the BERT tokenizer model and adds some transformer encoder layers""" 82 | def __init__(self, n_embed, n_layer, vocab_size=30522, max_seq_len=77, 83 | device="cuda",use_tokenizer=True, embedding_dropout=0.0): 84 | super().__init__() 85 | self.use_tknz_fn = use_tokenizer 86 | if self.use_tknz_fn: 87 | self.tknz_fn = BERTTokenizer(vq_interface=False, max_length=max_seq_len) 88 | self.device = device 89 | self.transformer = TransformerWrapper(num_tokens=vocab_size, max_seq_len=max_seq_len, 90 | attn_layers=Encoder(dim=n_embed, depth=n_layer), 91 | emb_dropout=embedding_dropout) 92 | 93 | def forward(self, text): 94 | if self.use_tknz_fn: 95 | tokens = self.tknz_fn(text)#.to(self.device) 96 | else: 97 | tokens = text 98 | z = self.transformer(tokens, return_embeddings=True) 99 | return z 100 | 101 | def encode(self, text): 102 | # output of length 77 103 | return self(text) 104 | 105 | 106 | class SpatialRescaler(nn.Module): 107 | def __init__(self, 108 | n_stages=1, 109 | method='bilinear', 110 | multiplier=0.5, 111 | in_channels=3, 112 | out_channels=None, 113 | bias=False): 114 | super().__init__() 115 | self.n_stages = n_stages 116 | assert self.n_stages >= 0 117 | assert method in ['nearest','linear','bilinear','trilinear','bicubic','area'] 118 | self.multiplier = multiplier 119 | self.interpolator = partial(torch.nn.functional.interpolate, mode=method) 120 | self.remap_output = out_channels is not None 121 | if self.remap_output: 122 | print(f'Spatial Rescaler mapping from {in_channels} to {out_channels} channels after resizing.') 123 | self.channel_mapper = nn.Conv2d(in_channels,out_channels,1,bias=bias) 124 | 125 | def forward(self,x): 126 | for stage in range(self.n_stages): 127 | x = self.interpolator(x, scale_factor=self.multiplier) 128 | 129 | 130 | if self.remap_output: 131 | x = self.channel_mapper(x) 132 | return x 133 | 134 | def encode(self, x): 135 | return self(x) 136 | 137 | class FrozenCLIPEmbedder(AbstractEncoder): 138 | """Uses the CLIP transformer encoder for text (from Hugging Face)""" 139 | def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77): 140 | super().__init__() 141 | self.tokenizer = CLIPTokenizer.from_pretrained(version) 142 | self.transformer = CLIPTextModel.from_pretrained(version) 143 | self.device = device 144 | self.max_length = max_length 145 | self.freeze() 146 | 147 | def freeze(self): 148 | self.transformer = self.transformer.eval() 149 | for param in self.parameters(): 150 | param.requires_grad = False 151 | 152 | def forward(self, text): 153 | batch_encoding =
self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True, 154 | return_overflowing_tokens=False, padding="max_length", return_tensors="pt") 155 | tokens = batch_encoding["input_ids"].to(self.device) 156 | outputs = self.transformer(input_ids=tokens) 157 | 158 | z = outputs.last_hidden_state 159 | return z 160 | 161 | def encode(self, text): 162 | return self(text) 163 | 164 | 165 | class FrozenCLIPTextEmbedder(nn.Module): 166 | """ 167 | Uses the CLIP transformer encoder for text. 168 | """ 169 | def __init__(self, version='ViT-L/14', device="cuda", max_length=77, n_repeat=1, normalize=True): 170 | super().__init__() 171 | self.model, _ = clip.load(version, jit=False, device="cpu") 172 | self.device = device 173 | self.max_length = max_length 174 | self.n_repeat = n_repeat 175 | self.normalize = normalize 176 | 177 | def freeze(self): 178 | self.model = self.model.eval() 179 | for param in self.parameters(): 180 | param.requires_grad = False 181 | 182 | def forward(self, text): 183 | tokens = clip.tokenize(text).to(self.device) 184 | z = self.model.encode_text(tokens) 185 | if self.normalize: 186 | z = z / torch.linalg.norm(z, dim=1, keepdim=True) 187 | return z 188 | 189 | def encode(self, text): 190 | z = self(text) 191 | if z.ndim==2: 192 | z = z[:, None, :] 193 | z = repeat(z, 'b 1 d -> b k d', k=self.n_repeat) 194 | return z 195 | 196 | 197 | class FrozenClipImageEmbedder(nn.Module): 198 | """ 199 | Uses the CLIP image encoder. 200 | """ 201 | def __init__( 202 | self, 203 | model, 204 | jit=False, 205 | device='cuda' if torch.cuda.is_available() else 'cpu', 206 | antialias=False, 207 | ): 208 | super().__init__() 209 | self.model, _ = clip.load(name=model, device=device, jit=jit) 210 | 211 | self.antialias = antialias 212 | 213 | self.register_buffer('mean', torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False) 214 | self.register_buffer('std', torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False) 215 | 216 | def preprocess(self, x): 217 | # normalize to [0,1] 218 | x = kornia.geometry.resize(x, (224, 224), 219 | interpolation='bicubic',align_corners=True, 220 | antialias=self.antialias) 221 | x = (x + 1.) / 2. 
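# note: the (x + 1.) / 2. step above assumes the input is in [-1, 1] (see the forward docstring below) and maps it to [0, 1], e.g. -1 -> 0, 0 -> 0.5, 1 -> 1, before the CLIP mean/std normalization that follows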
222 | # renormalize according to clip 223 | x = kornia.enhance.normalize(x, self.mean, self.std) 224 | return x 225 | 226 | def forward(self, x): 227 | # x is assumed to be in range [-1,1] 228 | return self.model.encode_image(self.preprocess(x)) 229 | 230 | 231 | if __name__ == "__main__": 232 | from ldm.util import count_params 233 | model = FrozenCLIPEmbedder() 234 | count_params(model, verbose=True) -------------------------------------------------------------------------------- /src/ldm/modules/image_degradation/__init__.py: -------------------------------------------------------------------------------- 1 | from ldm.modules.image_degradation.bsrgan import degradation_bsrgan_variant as degradation_fn_bsr 2 | from ldm.modules.image_degradation.bsrgan_light import degradation_bsrgan_variant as degradation_fn_bsr_light 3 | -------------------------------------------------------------------------------- /src/ldm/modules/image_degradation/utils/test.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DreamEditBenchTeam/DreamEdit/14d21b0a3eb6305c1378080ccd8361db0a8adcc0/src/ldm/modules/image_degradation/utils/test.png -------------------------------------------------------------------------------- /src/ldm/modules/losses/__init__.py: -------------------------------------------------------------------------------- 1 | from ldm.modules.losses.contperceptual import LPIPSWithDiscriminator -------------------------------------------------------------------------------- /src/ldm/modules/losses/contperceptual.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from taming.modules.losses.vqperceptual import * # TODO: taming dependency yes/no? 
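# note: the star import above is expected to supply LPIPS, NLayerDiscriminator, weights_init, hinge_d_loss, vanilla_d_loss and adopt_weight used below; vqperceptual.py in this repo imports the same names from taming explicitly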
5 | 6 | 7 | class LPIPSWithDiscriminator(nn.Module): 8 | def __init__(self, disc_start, logvar_init=0.0, kl_weight=1.0, pixelloss_weight=1.0, 9 | disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0, 10 | perceptual_weight=1.0, use_actnorm=False, disc_conditional=False, 11 | disc_loss="hinge"): 12 | 13 | super().__init__() 14 | assert disc_loss in ["hinge", "vanilla"] 15 | self.kl_weight = kl_weight 16 | self.pixel_weight = pixelloss_weight 17 | self.perceptual_loss = LPIPS().eval() 18 | self.perceptual_weight = perceptual_weight 19 | # output log variance 20 | self.logvar = nn.Parameter(torch.ones(size=()) * logvar_init) 21 | 22 | self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels, 23 | n_layers=disc_num_layers, 24 | use_actnorm=use_actnorm 25 | ).apply(weights_init) 26 | self.discriminator_iter_start = disc_start 27 | self.disc_loss = hinge_d_loss if disc_loss == "hinge" else vanilla_d_loss 28 | self.disc_factor = disc_factor 29 | self.discriminator_weight = disc_weight 30 | self.disc_conditional = disc_conditional 31 | 32 | def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None): 33 | if last_layer is not None: 34 | nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0] 35 | g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0] 36 | else: 37 | nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0] 38 | g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0] 39 | 40 | d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4) 41 | d_weight = torch.clamp(d_weight, 0.0, 1e4).detach() 42 | d_weight = d_weight * self.discriminator_weight 43 | return d_weight 44 | 45 | def forward(self, inputs, reconstructions, posteriors, optimizer_idx, 46 | global_step, last_layer=None, cond=None, split="train", 47 | weights=None): 48 | rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous()) 49 | if self.perceptual_weight > 0: 50 | p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous()) 51 | rec_loss = rec_loss + self.perceptual_weight * p_loss 52 | 53 | nll_loss = rec_loss / torch.exp(self.logvar) + self.logvar 54 | weighted_nll_loss = nll_loss 55 | if weights is not None: 56 | weighted_nll_loss = weights*nll_loss 57 | weighted_nll_loss = torch.sum(weighted_nll_loss) / weighted_nll_loss.shape[0] 58 | nll_loss = torch.sum(nll_loss) / nll_loss.shape[0] 59 | kl_loss = posteriors.kl() 60 | kl_loss = torch.sum(kl_loss) / kl_loss.shape[0] 61 | 62 | # now the GAN part 63 | if optimizer_idx == 0: 64 | # generator update 65 | if cond is None: 66 | assert not self.disc_conditional 67 | logits_fake = self.discriminator(reconstructions.contiguous()) 68 | else: 69 | assert self.disc_conditional 70 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1)) 71 | g_loss = -torch.mean(logits_fake) 72 | 73 | if self.disc_factor > 0.0: 74 | try: 75 | d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer) 76 | except RuntimeError: 77 | assert not self.training 78 | d_weight = torch.tensor(0.0) 79 | else: 80 | d_weight = torch.tensor(0.0) 81 | 82 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) 83 | loss = weighted_nll_loss + self.kl_weight * kl_loss + d_weight * disc_factor * g_loss 84 | 85 | log = {"{}/total_loss".format(split): loss.clone().detach().mean(), "{}/logvar".format(split): self.logvar.detach(), 86 | 
"{}/kl_loss".format(split): kl_loss.detach().mean(), "{}/nll_loss".format(split): nll_loss.detach().mean(), 87 | "{}/rec_loss".format(split): rec_loss.detach().mean(), 88 | "{}/d_weight".format(split): d_weight.detach(), 89 | "{}/disc_factor".format(split): torch.tensor(disc_factor), 90 | "{}/g_loss".format(split): g_loss.detach().mean(), 91 | } 92 | return loss, log 93 | 94 | if optimizer_idx == 1: 95 | # second pass for discriminator update 96 | if cond is None: 97 | logits_real = self.discriminator(inputs.contiguous().detach()) 98 | logits_fake = self.discriminator(reconstructions.contiguous().detach()) 99 | else: 100 | logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1)) 101 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1)) 102 | 103 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) 104 | d_loss = disc_factor * self.disc_loss(logits_real, logits_fake) 105 | 106 | log = {"{}/disc_loss".format(split): d_loss.clone().detach().mean(), 107 | "{}/logits_real".format(split): logits_real.detach().mean(), 108 | "{}/logits_fake".format(split): logits_fake.detach().mean() 109 | } 110 | return d_loss, log 111 | 112 | -------------------------------------------------------------------------------- /src/ldm/modules/losses/vqperceptual.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | from einops import repeat 5 | 6 | from taming.modules.discriminator.model import NLayerDiscriminator, weights_init 7 | from taming.modules.losses.lpips import LPIPS 8 | from taming.modules.losses.vqperceptual import hinge_d_loss, vanilla_d_loss 9 | 10 | 11 | def hinge_d_loss_with_exemplar_weights(logits_real, logits_fake, weights): 12 | assert weights.shape[0] == logits_real.shape[0] == logits_fake.shape[0] 13 | loss_real = torch.mean(F.relu(1. - logits_real), dim=[1,2,3]) 14 | loss_fake = torch.mean(F.relu(1. + logits_fake), dim=[1,2,3]) 15 | loss_real = (weights * loss_real).sum() / weights.sum() 16 | loss_fake = (weights * loss_fake).sum() / weights.sum() 17 | d_loss = 0.5 * (loss_real + loss_fake) 18 | return d_loss 19 | 20 | def adopt_weight(weight, global_step, threshold=0, value=0.): 21 | if global_step < threshold: 22 | weight = value 23 | return weight 24 | 25 | 26 | def measure_perplexity(predicted_indices, n_embed): 27 | # src: https://github.com/karpathy/deep-vector-quantization/blob/main/model.py 28 | # eval cluster perplexity. 
when perplexity == num_embeddings then all clusters are used exactly equally 29 | encodings = F.one_hot(predicted_indices, n_embed).float().reshape(-1, n_embed) 30 | avg_probs = encodings.mean(0) 31 | perplexity = (-(avg_probs * torch.log(avg_probs + 1e-10)).sum()).exp() 32 | cluster_use = torch.sum(avg_probs > 0) 33 | return perplexity, cluster_use 34 | 35 | def l1(x, y): 36 | return torch.abs(x-y) 37 | 38 | 39 | def l2(x, y): 40 | return torch.pow((x-y), 2) 41 | 42 | 43 | class VQLPIPSWithDiscriminator(nn.Module): 44 | def __init__(self, disc_start, codebook_weight=1.0, pixelloss_weight=1.0, 45 | disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0, 46 | perceptual_weight=1.0, use_actnorm=False, disc_conditional=False, 47 | disc_ndf=64, disc_loss="hinge", n_classes=None, perceptual_loss="lpips", 48 | pixel_loss="l1"): 49 | super().__init__() 50 | assert disc_loss in ["hinge", "vanilla"] 51 | assert perceptual_loss in ["lpips", "clips", "dists"] 52 | assert pixel_loss in ["l1", "l2"] 53 | self.codebook_weight = codebook_weight 54 | self.pixel_weight = pixelloss_weight 55 | if perceptual_loss == "lpips": 56 | print(f"{self.__class__.__name__}: Running with LPIPS.") 57 | self.perceptual_loss = LPIPS().eval() 58 | else: 59 | raise ValueError(f"Unknown perceptual loss: >> {perceptual_loss} <<") 60 | self.perceptual_weight = perceptual_weight 61 | 62 | if pixel_loss == "l1": 63 | self.pixel_loss = l1 64 | else: 65 | self.pixel_loss = l2 66 | 67 | self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels, 68 | n_layers=disc_num_layers, 69 | use_actnorm=use_actnorm, 70 | ndf=disc_ndf 71 | ).apply(weights_init) 72 | self.discriminator_iter_start = disc_start 73 | if disc_loss == "hinge": 74 | self.disc_loss = hinge_d_loss 75 | elif disc_loss == "vanilla": 76 | self.disc_loss = vanilla_d_loss 77 | else: 78 | raise ValueError(f"Unknown GAN loss '{disc_loss}'.") 79 | print(f"VQLPIPSWithDiscriminator running with {disc_loss} loss.") 80 | self.disc_factor = disc_factor 81 | self.discriminator_weight = disc_weight 82 | self.disc_conditional = disc_conditional 83 | self.n_classes = n_classes 84 | 85 | def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None): 86 | if last_layer is not None: 87 | nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0] 88 | g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0] 89 | else: 90 | nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0] 91 | g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0] 92 | 93 | d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4) 94 | d_weight = torch.clamp(d_weight, 0.0, 1e4).detach() 95 | d_weight = d_weight * self.discriminator_weight 96 | return d_weight 97 | 98 | def forward(self, codebook_loss, inputs, reconstructions, optimizer_idx, 99 | global_step, last_layer=None, cond=None, split="train", predicted_indices=None): 100 | if not exists(codebook_loss): 101 | codebook_loss = torch.tensor([0.]).to(inputs.device) 102 | #rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous()) 103 | rec_loss = self.pixel_loss(inputs.contiguous(), reconstructions.contiguous()) 104 | if self.perceptual_weight > 0: 105 | p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous()) 106 | rec_loss = rec_loss + self.perceptual_weight * p_loss 107 | else: 108 | p_loss = torch.tensor([0.0]) 109 | 110 | nll_loss = rec_loss 111 | #nll_loss = torch.sum(nll_loss) / nll_loss.shape[0] 
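# note: torch.mean below averages over every element of rec_loss, whereas the commented-out variant above sums all elements and divides by the batch size only, so the two differ by a factor equal to the per-sample element count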
112 | nll_loss = torch.mean(nll_loss) 113 | 114 | # now the GAN part 115 | if optimizer_idx == 0: 116 | # generator update 117 | if cond is None: 118 | assert not self.disc_conditional 119 | logits_fake = self.discriminator(reconstructions.contiguous()) 120 | else: 121 | assert self.disc_conditional 122 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1)) 123 | g_loss = -torch.mean(logits_fake) 124 | 125 | try: 126 | d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer) 127 | except RuntimeError: 128 | assert not self.training 129 | d_weight = torch.tensor(0.0) 130 | 131 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) 132 | loss = nll_loss + d_weight * disc_factor * g_loss + self.codebook_weight * codebook_loss.mean() 133 | 134 | log = {"{}/total_loss".format(split): loss.clone().detach().mean(), 135 | "{}/quant_loss".format(split): codebook_loss.detach().mean(), 136 | "{}/nll_loss".format(split): nll_loss.detach().mean(), 137 | "{}/rec_loss".format(split): rec_loss.detach().mean(), 138 | "{}/p_loss".format(split): p_loss.detach().mean(), 139 | "{}/d_weight".format(split): d_weight.detach(), 140 | "{}/disc_factor".format(split): torch.tensor(disc_factor), 141 | "{}/g_loss".format(split): g_loss.detach().mean(), 142 | } 143 | if predicted_indices is not None: 144 | assert self.n_classes is not None 145 | with torch.no_grad(): 146 | perplexity, cluster_usage = measure_perplexity(predicted_indices, self.n_classes) 147 | log[f"{split}/perplexity"] = perplexity 148 | log[f"{split}/cluster_usage"] = cluster_usage 149 | return loss, log 150 | 151 | if optimizer_idx == 1: 152 | # second pass for discriminator update 153 | if cond is None: 154 | logits_real = self.discriminator(inputs.contiguous().detach()) 155 | logits_fake = self.discriminator(reconstructions.contiguous().detach()) 156 | else: 157 | logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1)) 158 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1)) 159 | 160 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) 161 | d_loss = disc_factor * self.disc_loss(logits_real, logits_fake) 162 | 163 | log = {"{}/disc_loss".format(split): d_loss.clone().detach().mean(), 164 | "{}/logits_real".format(split): logits_real.detach().mean(), 165 | "{}/logits_fake".format(split): logits_fake.detach().mean() 166 | } 167 | return d_loss, log 168 | -------------------------------------------------------------------------------- /src/ldm/util.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | 3 | import torch 4 | import numpy as np 5 | from collections import abc 6 | from einops import rearrange 7 | from functools import partial 8 | 9 | import multiprocessing as mp 10 | from threading import Thread 11 | from queue import Queue 12 | 13 | from inspect import isfunction 14 | from PIL import Image, ImageDraw, ImageFont 15 | 16 | 17 | def log_txt_as_img(wh, xc, size=10): 18 | # wh a tuple of (width, height) 19 | # xc a list of captions to plot 20 | b = len(xc) 21 | txts = list() 22 | for bi in range(b): 23 | txt = Image.new("RGB", wh, color="white") 24 | draw = ImageDraw.Draw(txt) 25 | font = ImageFont.truetype('data/DejaVuSans.ttf', size=size) 26 | nc = int(40 * (wh[0] / 256)) 27 | lines = "\n".join(xc[bi][start:start + nc] for start in range(0, len(xc[bi]), 
nc)) 28 | 29 | try: 30 | draw.text((0, 0), lines, fill="black", font=font) 31 | except UnicodeEncodeError: 32 | print("Can't encode string for logging. Skipping.") 33 | 34 | txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0 35 | txts.append(txt) 36 | txts = np.stack(txts) 37 | txts = torch.tensor(txts) 38 | return txts 39 | 40 | 41 | def ismap(x): 42 | if not isinstance(x, torch.Tensor): 43 | return False 44 | return (len(x.shape) == 4) and (x.shape[1] > 3) 45 | 46 | 47 | def isimage(x): 48 | if not isinstance(x, torch.Tensor): 49 | return False 50 | return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1) 51 | 52 | 53 | def exists(x): 54 | return x is not None 55 | 56 | 57 | def default(val, d): 58 | if exists(val): 59 | return val 60 | return d() if isfunction(d) else d 61 | 62 | 63 | def mean_flat(tensor): 64 | """ 65 | https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/nn.py#L86 66 | Take the mean over all non-batch dimensions. 67 | """ 68 | return tensor.mean(dim=list(range(1, len(tensor.shape)))) 69 | 70 | 71 | def count_params(model, verbose=False): 72 | total_params = sum(p.numel() for p in model.parameters()) 73 | if verbose: 74 | print(f"{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params.") 75 | return total_params 76 | 77 | 78 | def instantiate_from_config(config): 79 | if not "target" in config: 80 | if config == '__is_first_stage__': 81 | return None 82 | elif config == "__is_unconditional__": 83 | return None 84 | raise KeyError("Expected key `target` to instantiate.") 85 | return get_obj_from_str(config["target"])(**config.get("params", dict())) 86 | 87 | 88 | def get_obj_from_str(string, reload=False): 89 | module, cls = string.rsplit(".", 1) 90 | if reload: 91 | module_imp = importlib.import_module(module) 92 | importlib.reload(module_imp) 93 | return getattr(importlib.import_module(module, package=None), cls) 94 | 95 | 96 | def _do_parallel_data_prefetch(func, Q, data, idx, idx_to_fn=False): 97 | # create dummy dataset instance 98 | 99 | # run prefetching 100 | if idx_to_fn: 101 | res = func(data, worker_id=idx) 102 | else: 103 | res = func(data) 104 | Q.put([idx, res]) 105 | Q.put("Done") 106 | 107 | 108 | def parallel_data_prefetch( 109 | func: callable, data, n_proc, target_data_type="ndarray", cpu_intensive=True, use_worker_id=False 110 | ): 111 | # if target_data_type not in ["ndarray", "list"]: 112 | # raise ValueError( 113 | # "Data, which is passed to parallel_data_prefetch has to be either of type list or ndarray." 114 | # ) 115 | if isinstance(data, np.ndarray) and target_data_type == "list": 116 | raise ValueError("list expected but function got ndarray.") 117 | elif isinstance(data, abc.Iterable): 118 | if isinstance(data, dict): 119 | print( 120 | f'WARNING:"data" argument passed to parallel_data_prefetch is a dict: Using only its values and disregarding keys.' 121 | ) 122 | data = list(data.values()) 123 | if target_data_type == "ndarray": 124 | data = np.asarray(data) 125 | else: 126 | data = list(data) 127 | else: 128 | raise TypeError( 129 | f"The data to be processed in parallel has to be either an np.ndarray or an Iterable, but is actually {type(data)}."
130 | ) 131 | 132 | if cpu_intensive: 133 | Q = mp.Queue(1000) 134 | proc = mp.Process 135 | else: 136 | Q = Queue(1000) 137 | proc = Thread 138 | # spawn processes 139 | if target_data_type == "ndarray": 140 | arguments = [ 141 | [func, Q, part, i, use_worker_id] 142 | for i, part in enumerate(np.array_split(data, n_proc)) 143 | ] 144 | else: 145 | step = ( 146 | int(len(data) / n_proc + 1) 147 | if len(data) % n_proc != 0 148 | else int(len(data) / n_proc) 149 | ) 150 | arguments = [ 151 | [func, Q, part, i, use_worker_id] 152 | for i, part in enumerate( 153 | [data[i: i + step] for i in range(0, len(data), step)] 154 | ) 155 | ] 156 | processes = [] 157 | for i in range(n_proc): 158 | p = proc(target=_do_parallel_data_prefetch, args=arguments[i]) 159 | processes += [p] 160 | 161 | # start processes 162 | print(f"Start prefetching...") 163 | import time 164 | 165 | start = time.time() 166 | gather_res = [[] for _ in range(n_proc)] 167 | try: 168 | for p in processes: 169 | p.start() 170 | 171 | k = 0 172 | while k < n_proc: 173 | # get result 174 | res = Q.get() 175 | if res == "Done": 176 | k += 1 177 | else: 178 | gather_res[res[0]] = res[1] 179 | 180 | except Exception as e: 181 | print("Exception: ", e) 182 | for p in processes: 183 | p.terminate() 184 | 185 | raise e 186 | finally: 187 | for p in processes: 188 | p.join() 189 | print(f"Prefetching complete. [{time.time() - start} sec.]") 190 | 191 | if target_data_type == 'ndarray': 192 | if not isinstance(gather_res[0], np.ndarray): 193 | return np.concatenate([np.asarray(r) for r in gather_res], axis=0) 194 | 195 | # order outputs 196 | return np.concatenate(gather_res, axis=0) 197 | elif target_data_type == 'list': 198 | out = [] 199 | for r in gather_res: 200 | out.extend(r) 201 | return out 202 | else: 203 | return gather_res 204 | -------------------------------------------------------------------------------- /src/metrics/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DreamEditBenchTeam/DreamEdit/14d21b0a3eb6305c1378080ccd8361db0a8adcc0/src/metrics/__init__.py -------------------------------------------------------------------------------- /src/metrics/clip_vit.py: -------------------------------------------------------------------------------- 1 | import clip 2 | from PIL import Image 3 | import torch 4 | 5 | class CLIP(): 6 | def __init__(self, device="cuda"): 7 | self.model, self.preprocess = clip.load("ViT-B/32", device=device) 8 | self.model.eval() 9 | self.device = device 10 | 11 | def get_transform(self): 12 | return self.preprocess 13 | 14 | def encode_image(self, tensor_image): 15 | """ 16 | Take input in size [B, 3, H, W] 17 | """ 18 | output = self.model.encode_image(tensor_image.to(self.device)) 19 | return output 20 | 21 | def encode_text(self, prompt): 22 | text = clip.tokenize([prompt]).to(self.device) 23 | text_features = self.model.encode_text(text) 24 | return text_features 25 | 26 | if __name__ == "__main__": 27 | device = "cuda" if torch.cuda.is_available() else "cpu" 28 | clip_model = CLIP(device) 29 | 30 | import requests 31 | url = "http://images.cocodataset.org/val2017/000000039769.jpg" 32 | image = Image.open(requests.get(url, stream=True).raw).convert('RGB') 33 | image = image.resize((512, 512)) 34 | preprocess = clip_model.get_transform() 35 | image = preprocess(image) 36 | image = image.unsqueeze(0) 37 | assert image.shape == torch.Size([1, 3, 224, 224]) 38 | result = clip_model.encode_image(image) 39 | assert 
result.shape == torch.Size([1, 512]) 40 | text = "hello world" 41 | result = clip_model.encode_text(text) 42 | assert result.shape == torch.Size([1, 512]) -------------------------------------------------------------------------------- /src/metrics/dino_vit.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torchvision import transforms 3 | from PIL import Image 4 | 5 | class VITs16(): 6 | def __init__(self, device="cuda"): 7 | self.model = torch.hub.load('facebookresearch/dino:main', 'dino_vits16').to(device) 8 | self.model.eval() 9 | self.device = device 10 | 11 | def get_transform(self): 12 | val_transform = transforms.Compose([ 13 | transforms.Resize(256, interpolation=3), 14 | transforms.CenterCrop(224), 15 | transforms.ToTensor(), 16 | transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)), 17 | ]) 18 | return val_transform 19 | 20 | def get_embeddings(self, tensor_image): 21 | output = self.model(tensor_image.to(self.device)) 22 | return output 23 | 24 | def get_embeddings_intermediate(self, tensor_image, n_last_block=4): 25 | """ 26 | We use `n_last_block=4` when evaluating ViT-Small 27 | """ 28 | intermediate_output = self.model.get_intermediate_layers(tensor_image, n=n_last_block) 29 | output = torch.cat([x[:, 0] for x in intermediate_output], dim=-1) 30 | return output 31 | 32 | if __name__ == "__main__": 33 | device = "cuda" if torch.cuda.is_available() else "cpu" 34 | import requests 35 | url = "http://images.cocodataset.org/val2017/000000039769.jpg" 36 | image = Image.open(requests.get(url, stream=True).raw).convert('RGB') 37 | image = image.resize((512, 512)) 38 | fidelity = VITs16(device) 39 | preprocess = fidelity.get_transform() 40 | image = preprocess(image) 41 | print(image.shape) 42 | image = image.unsqueeze(0) 43 | result = fidelity.get_embeddings(image) 44 | assert result.shape == torch.Size([1, 384]) -------------------------------------------------------------------------------- /src/metrics/distances.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def compute_cosine_distance(image_features, image_features2): 5 | # normalized features 6 | image_features = image_features / np.linalg.norm(np.float32(image_features), ord=2) 7 | image_features2 = image_features2 / np.linalg.norm(np.float32(image_features2), ord=2) 8 | return np.dot(image_features, image_features2) 9 | 10 | 11 | def compute_l2_distance(image_features, image_features2): 12 | return np.linalg.norm(np.float32(image_features - image_features2)) 13 | -------------------------------------------------------------------------------- /src/metrics/evaluate_dino.py: -------------------------------------------------------------------------------- 1 | from .distances import compute_cosine_distance 2 | import numpy as np 3 | import torch 4 | 5 | def evaluate_dino_score(real_image, generated_image, device, fidelity): 6 | #tensor_image_1 = torch.from_numpy(np.asarray(real_image)).permute(2, 0, 1).unsqueeze(0) 7 | #tensor_image_2 = torch.from_numpy(np.asarray(generated_image)).permute(2, 0, 1).unsqueeze(0) 8 | preprocess = fidelity.get_transform() 9 | tensor_image_1 = preprocess(real_image).unsqueeze(0) 10 | tensor_image_2 = preprocess(generated_image).unsqueeze(0) 11 | emb_1 = fidelity.get_embeddings(tensor_image_1.float().to(device)) 12 | emb_2 = fidelity.get_embeddings(tensor_image_2.float().to(device)) 13 | assert emb_1.shape == emb_2.shape 14 | score = 
compute_cosine_distance(emb_1.detach().cpu().numpy(), emb_2.detach().permute(1, 0).cpu().numpy()) 15 | return score[0][0] 16 | 17 | def evaluate_dino_score_list(real_image, generated_image_list, device, fidelity): 18 | score_list = [] 19 | total = len(generated_image_list) 20 | for i in range(total): 21 | score = evaluate_dino_score(real_image, generated_image_list[i], device, fidelity) 22 | score_list.append(score) 23 | # max_score = max(score_list) 24 | # max_index = score_list.index(max_score) 25 | # print("The best result is at {} th iteration with dino score {}".format(max_index + 1, max_score)) 26 | # return max_score, max_index 27 | return score_list 28 | 29 | def evaluate_clipi_score(real_image, generated_image, device, clip_model): 30 | preprocess = clip_model.get_transform() 31 | tensor_image_1 = preprocess(real_image).unsqueeze(0) 32 | tensor_image_2 = preprocess(generated_image).unsqueeze(0) 33 | emb_1 = clip_model.encode_image(tensor_image_1.float().to(device)) 34 | emb_2 = clip_model.encode_image(tensor_image_2.float().to(device)) 35 | assert emb_1.shape == emb_2.shape 36 | score = compute_cosine_distance(emb_1.detach().cpu().numpy(), emb_2.detach().permute(1, 0).cpu().numpy()) 37 | return score[0][0] 38 | 39 | def evaluate_clipi_score_list(real_image, generated_image_list, device, clip_model): 40 | score_list = [] 41 | total = len(generated_image_list) 42 | for i in range(total): 43 | score = evaluate_clipi_score(real_image, generated_image_list[i], device, clip_model) 44 | score_list.append(score) 45 | # max_score = max(score_list) 46 | # max_index = score_list.index(max_score) 47 | # print("The best result is at {} th iteration with dino score {}".format(max_index + 1, max_score)) 48 | # return max_score, max_index 49 | return score_list -------------------------------------------------------------------------------- /src/pipelines/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DreamEditBenchTeam/DreamEdit/14d21b0a3eb6305c1378080ccd8361db0a8adcc0/src/pipelines/__init__.py -------------------------------------------------------------------------------- /src/pipelines/extract_object_pipeline.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | from lang_sam import LangSAM 3 | from lang_sam.utils import draw_image 4 | 5 | from src.pipelines.imagecaption_pipelines import * 6 | from src.utils.mask_helper import * 7 | from src.utils.visual_helper import * 8 | 9 | 10 | def get_object_caption(image, class_name, BLIP_Model, langSAM): 11 | image = image.resize((256, 256), Image.LANCZOS) 12 | masks, boxes, phrases, logits = langSAM.predict(image, class_name, box_threshold=0.1, text_threshold=0.1) 13 | mask = merge_masks(masks) 14 | tensor_img = pil_to_tensor(image) 15 | expanded_mask = mask.unsqueeze(0).expand_as(tensor_img) 16 | masked_tensor = tensor_img * expanded_mask 17 | final_image = tensor_to_pil(masked_tensor) 18 | caption = BLIP_Model.predict_one_image(final_image) 19 | return caption 20 | 21 | 22 | class Object_Caption_Extractor(): 23 | 24 | def __init__(self, device="cuda"): 25 | self.BLIP_Model = BLIP_Model(device) 26 | self.langSAM = LangSAM() 27 | 28 | def get_object_caption(self, image, class_name: str): 29 | image = image.resize((256, 256), Image.LANCZOS) 30 | masks, boxes, phrases, logits = self.langSAM.predict(image, class_name, box_threshold=0.1, text_threshold=0.1) 31 | mask = merge_masks(masks) 32 | tensor_img = 
pil_to_tensor(image) 33 | expanded_mask = mask.unsqueeze(0).expand_as(tensor_img) 34 | masked_tensor = tensor_img * expanded_mask 35 | final_image = tensor_to_pil(masked_tensor) 36 | caption = self.BLIP_Model.predict_one_image(final_image) 37 | return caption 38 | 39 | def get_langSAM(self): 40 | return self.langSAM 41 | 42 | def get_BLIP_Model(self): 43 | return self.BLIP_Model 44 | 45 | 46 | if __name__ == "__main__": 47 | image = Image.open("/home/maxku/Research/diffusion_project/data/ref_images/rc_car/found0.jpg") 48 | extractor = Object_Caption_Extractor() 49 | caption = extractor.get_object_caption(image, 'rc_car') 50 | print(caption) 51 | assert caption == 'a pink toy car with a fireman on top' 52 | -------------------------------------------------------------------------------- /src/pipelines/imagecaption_pipelines.py: -------------------------------------------------------------------------------- 1 | from transformers import BlipProcessor, BlipForConditionalGeneration 2 | import torch 3 | from PIL import Image 4 | 5 | class BLIP_Model(): 6 | 7 | def __init__(self, device="cuda", processor_weights = "Salesforce/blip-image-captioning-base", conditional_generation_weights = "Salesforce/blip-image-captioning-base"): 8 | 9 | self.processor = BlipProcessor.from_pretrained(processor_weights) 10 | self.model = BlipForConditionalGeneration.from_pretrained(conditional_generation_weights) 11 | self.device = device 12 | self.model.to(self.device) 13 | 14 | def predict_one_image(self, image): 15 | inputs = self.processor(images=image, return_tensors="pt").to(self.device) 16 | generated_ids = self.model.generate(**inputs) 17 | generated_text = self.processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip() 18 | return generated_text 19 | 20 | if __name__ == "__main__": 21 | 22 | import requests 23 | 24 | BLIP_Model = BLIP_Model() 25 | 26 | url = "http://images.cocodataset.org/val2017/000000039769.jpg" 27 | image = Image.open(requests.get(url, stream=True).raw) 28 | generated_text = BLIP_Model.predict_one_image(image) 29 | #print(generated_text) 30 | assert generated_text == "two cats sleeping on a couch" 31 | -------------------------------------------------------------------------------- /src/pipelines/inpainting_pipelines.py: -------------------------------------------------------------------------------- 1 | from diffusers import StableDiffusionInpaintPipeline 2 | import torch 3 | 4 | import os 5 | import torchvision 6 | from diffusers import StableDiffusionGLIGENPipeline 7 | 8 | from PIL import Image 9 | 10 | 11 | # pipe = StableDiffusionGLIGENPipeline.from_pretrained("gligen/diffusers-inpainting-text-box", torch_dtype=torch.float32) 12 | # pipe.to("cuda") 13 | 14 | # os.makedirs("images", exist_ok=True) 15 | # os.makedirs("images/output", exist_ok=True) 16 | 17 | def inpaint_text_gligen(pipe, prompt, background_path, bounding_box, gligen_phrase, config): 18 | images = pipe( 19 | prompt, 20 | num_images_per_prompt=1, 21 | gligen_phrases=[gligen_phrase], 22 | gligen_inpaint_image=Image.open(background_path).convert('RGB'), 23 | gligen_boxes=[[x / 512 for x in bounding_box]], 24 | gligen_scheduled_sampling_beta=config.gligen_scheduled_sampling_beta, 25 | output_type="numpy", 26 | num_inference_steps=config.num_inference_steps 27 | ).images 28 | return images 29 | 30 | 31 | def select_inpainting_pipeline(name: str, device="cuda"): 32 | if name == "sd_inpaint": 33 | return get_stable_diffusion_inpaint_pipeline(device) 34 | elif name == "gligen_inpaint": 35 | return 
get_gligen_inpaint_pipeline(device) 36 | else: 37 | return None 38 | 39 | 40 | def get_stable_diffusion_inpaint_pipeline(device): 41 | inpainting_pipe = StableDiffusionInpaintPipeline.from_pretrained( 42 | "runwayml/stable-diffusion-inpainting", torch_dtype=torch.float16 43 | ).to(device) 44 | return inpainting_pipe 45 | 46 | 47 | def get_gligen_inpaint_pipeline(device): 48 | inpainting_pipe = StableDiffusionGLIGENPipeline.from_pretrained( 49 | "gligen/diffusers-inpainting-text-box", torch_dtype=torch.float32 50 | ).to(device) 51 | return inpainting_pipe 52 | -------------------------------------------------------------------------------- /src/utils/losses.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | def mse(img1, img2):  # log of the summed, squared per-pixel mean absolute difference 4 | h, w = img1.shape[-2:] 5 | diff = (img1 - img2).abs_().mean(dim=[0, 1]) 6 | err = torch.square(diff).sum() 7 | # err = err/(float(h*w)) 8 | return torch.log(err) -------------------------------------------------------------------------------- /src/utils/mask_helper.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | from kornia.morphology import dilation, closing, erosion 3 | from einops import rearrange 4 | import torch 5 | import numpy as np 6 | import torchvision.transforms as T 7 | 8 | 9 | def merge_masks(masks_tensor): 10 | """ 11 | param: 12 | masks in size [N, H, W] 13 | return: 14 | mask in size [H, W] 15 | """ 16 | merged_mask, _ = torch.max(masks_tensor, dim=0) 17 | return merged_mask 18 | 19 | 20 | def subtract_mask(big_mask, small_mask): 21 | """ 22 | small_mask must be a subset of big_mask 23 | param: 24 | big_mask in size [H, W] 25 | small_mask in size [H, W] 26 | return: 27 | mask in size [H, W] 28 | """ 29 | # Subtract the smaller mask from the bigger mask 30 | result_mask = big_mask - small_mask 31 | return result_mask 32 | 33 | 34 | def get_polished_mask(mask_tensor, k: int, mask_type: str): 35 | """ 36 | param: 37 | mask in size [H, W] 38 | k = kernel size 39 | mask_type name of mask type 40 | return: 41 | mask in size [H, W] 42 | """ 43 | if mask_type == "dilation": 44 | return mask_dilation(mask_tensor, k) 45 | elif mask_type == "closing": 46 | return mask_closing(mask_tensor, k) 47 | elif mask_type == "closing_half": 48 | return mask_closing_half(mask_tensor, k) 49 | elif mask_type == "pixelwise_enlarge": 50 | return mask_pixelwise_enlarge(mask_tensor, k) 51 | else: 52 | return mask_tensor 53 | 54 | 55 | def mask_dilation(mask_tensor, k: int): 56 | """ 57 | param: 58 | mask in size [H, W] 59 | k = kernel size 60 | return: 61 | mask in size [H, W] 62 | """ 63 | mask_tensor = rearrange(mask_tensor, 'w h -> 1 1 w h') 64 | kernel = torch.ones(k, k) 65 | dilated_img = dilation(mask_tensor, kernel, border_type="constant") 66 | dilated_img = rearrange(dilated_img, '1 1 w h -> w h') 67 | return dilated_img 68 | 69 | 70 | def mask_closing(mask_tensor, k: int): 71 | """ 72 | param: 73 | mask in size [H, W] 74 | k = kernel size 75 | return: 76 | mask in size [H, W] 77 | """ 78 | mask_tensor = rearrange(mask_tensor, 'w h -> 1 1 w h') 79 | kernel = torch.ones(k, k) 80 | dilated_img = closing(mask_tensor, kernel, border_type="constant") 81 | dilated_img = rearrange(dilated_img, '1 1 w h -> w h') 82 | return dilated_img 83 | 84 | 85 | def mask_closing_half(mask_tensor, k: int): 86 | """ 87 | param: 88 | mask in size [H, W] 89 | k = kernel size 90 | return: 91 | mask in size [H, W] 92 | """ 93 | mask_tensor = rearrange(mask_tensor, 'w h -> 
1 1 w h') 94 | kernel = torch.ones(k, k) 95 | dilated_img = dilation(mask_tensor, kernel, border_type="constant") 96 | kernel = torch.ones(k // 2, k // 2) 97 | dilated_img = erosion(dilated_img, kernel, border_type="constant") 98 | dilated_img = rearrange(dilated_img, '1 1 w h -> w h') 99 | return dilated_img 100 | 101 | 102 | def transform_box_mask(labeled_box, sam_box, mask): 103 | mask_box = mask[sam_box[1]:sam_box[3], sam_box[0]:sam_box[2]] 104 | reshape_y = labeled_box[3] - labeled_box[1] 105 | reshape_x = labeled_box[2] - labeled_box[0] 106 | transform = T.Resize((reshape_y, reshape_x)) 107 | mask_box = transform(mask_box.unsqueeze(0))[0] 108 | mask_return = torch.zeros(mask.shape) 109 | mask_return[labeled_box[1]:labeled_box[3], labeled_box[0]:labeled_box[2]] = mask_box 110 | return mask_return 111 | 112 | 113 | def transform_box_mask_paste(labeled_box, sam_box, mask, background_image, subject_image): 114 | mask_box = mask[sam_box[1]:sam_box[3], sam_box[0]:sam_box[2]] 115 | subject_box_array = np.asarray(subject_image) 116 | subject_box = torch.from_numpy(subject_box_array[sam_box[1]:sam_box[3], sam_box[0]:sam_box[2], :]).permute(2, 0, 1) 117 | reshape_y = labeled_box[3] - labeled_box[1] 118 | reshape_x = labeled_box[2] - labeled_box[0] 119 | transform = T.Resize((reshape_y, reshape_x)) 120 | mask_box = transform(mask_box.unsqueeze(0))[0] 121 | subject_box = transform(subject_box).permute(1, 2, 0) 122 | mask_return = torch.zeros(mask.shape) 123 | subject_copy = torch.zeros(np.asarray(background_image).shape) 124 | mask_return[labeled_box[1]:labeled_box[3], labeled_box[0]:labeled_box[2]] = mask_box 125 | subject_copy[labeled_box[1]:labeled_box[3], labeled_box[0]:labeled_box[2], :] = subject_box 126 | mask_repeat = mask_return.unsqueeze(-1).repeat(1, 1, 3) 127 | subject_paste = torch.where(mask_repeat > 0, subject_copy, torch.from_numpy(np.asarray(background_image))) 128 | return mask_return, subject_paste.detach().cpu().numpy() 129 | 130 | 131 | def resize_box_from_middle(resize_ratio, sam_box): 132 | y_mid = (sam_box[3] + sam_box[1]) // 2 133 | x_mid = (sam_box[2] + sam_box[0]) // 2 134 | y_max = (sam_box[3] - y_mid) * resize_ratio + y_mid 135 | y_min = y_mid - (sam_box[3] - y_mid) * resize_ratio 136 | x_max = (sam_box[2] - x_mid) * resize_ratio + x_mid 137 | x_min = x_mid - (sam_box[2] - x_mid) * resize_ratio 138 | return [int(x_min), int(y_min), int(x_max), int(y_max)] 139 | 140 | 141 | def resize_box_from_bottom(resize_ratio, sam_box): 142 | y_mid = sam_box[3] 143 | x_mid = (sam_box[2] + sam_box[0]) // 2 144 | y_max = sam_box[3] 145 | y_min = y_mid - (sam_box[3] - sam_box[1]) * resize_ratio 146 | x_max = (sam_box[2] - x_mid) * resize_ratio + x_mid 147 | x_min = x_mid - (sam_box[2] - x_mid) * resize_ratio 148 | return [int(x_min), int(y_min), int(x_max), int(y_max)] 149 | 150 | 151 | def bounding_box_merge(bbox_tensor): 152 | x_min_list = bbox_tensor[:, 0].tolist() 153 | y_min_list = bbox_tensor[:, 1].tolist() 154 | x_max_list = bbox_tensor[:, 2].tolist() 155 | y_max_list = bbox_tensor[:, 3].tolist() 156 | x_min = min(x_min_list) 157 | y_min = min(y_min_list) 158 | x_max = max(x_max_list) 159 | y_max = max(y_max_list) 160 | result = [int(x_min), int(y_min), int(x_max), int(y_max)] 161 | result = [max(0, min(512, x)) for x in result] 162 | return result 163 | 164 | 165 | def get_mask_from_bbox(bbox): 166 | return_mask = torch.zeros(512, 512) 167 | y = bbox[3] - bbox[1] 168 | x = bbox[2] - bbox[0] 169 | mask = torch.ones(y, x) 170 | return_mask[bbox[1]:bbox[3], bbox[0]:bbox[2]] = 
mask 171 | return_mask = return_mask > 0 172 | return return_mask 173 | -------------------------------------------------------------------------------- /src/utils/path_finder.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | def get_file_path(filename): 4 | """ 5 | search file across whole repo and return abspath 6 | """ 7 | for root, dirs, files in os.walk(r'.'): 8 | for name in files: 9 | if name == filename: 10 | return os.path.abspath(os.path.join(root, name)) 11 | raise FileNotFoundError(filename, "not found.") 12 | 13 | def check_format(source_path): 14 | if(os.path.isdir(source_path)): 15 | return 'folder' 16 | IMG_FORMATS = 'bmp', 'dng', 'jpeg', 'jpg', 'mpo', 'png', 'tif', 'tiff', 'webp' # include image suffixes 17 | VID_FORMATS = 'asf', 'avi', 'gif', 'm4v', 'mkv', 'mov', 'mp4', 'mpeg', 'mpg', 'ts', 'wmv' # include video suffixes 18 | if(source_path.split(".")[-1] in IMG_FORMATS): 19 | return 'image' 20 | if(source_path.split(".")[-1] in VID_FORMATS): 21 | return 'video' 22 | 23 | return None 24 | 25 | def check_is_image(source_path): 26 | if(not os.path.isdir(source_path)): 27 | IMG_FORMATS = 'jpeg', 'jpg', 'png' # include image suffixes 28 | if(source_path.split(".")[-1] in IMG_FORMATS): 29 | return True 30 | return False -------------------------------------------------------------------------------- /src/utils/visual_helper.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision.transforms as transforms 3 | from torchvision.utils import save_image 4 | import numpy as np 5 | import os 6 | from PIL import Image 7 | 8 | 9 | def save_tensor_image(tensor_image, dest_folder, filename: str, filetype="jpg"): 10 | """ 11 | param: 12 | image in size [B, 3, H, W] 13 | dest_folder 14 | filename 15 | return: 16 | --- 17 | """ 18 | if not os.path.exists(dest_folder): 19 | os.makedirs(dest_folder) 20 | save_image(tensor_image, os.path.join(dest_folder, f"{filename}.{filetype}")) 21 | 22 | 23 | def save_pil_image(pil_image, dest_folder, filename: str, filetype="jpg"): 24 | """ 25 | param: 26 | image as PIL Image 27 | dest_folder 28 | filename 29 | return: 30 | --- 31 | """ 32 | if not os.path.exists(dest_folder): 33 | os.makedirs(dest_folder) 34 | pil_image.save(os.path.join(dest_folder, f"{filename}.{filetype}")) 35 | 36 | 37 | def get_mask_pil_image(mask_tensor): 38 | """ 39 | param: 40 | mask_tensor in size [H, W] 41 | return: 42 | mask in PIL image 43 | """ 44 | mask = np.array(mask_tensor).astype('uint8') 45 | mask = np.squeeze(mask) 46 | mask_img = Image.fromarray(mask * 255) 47 | return mask_img 48 | 49 | 50 | def pil_to_tensor(pil_img): 51 | """ 52 | param: 53 | pil_img - Image Object 54 | return: 55 | tensor in size [1, C, H, W] 56 | """ 57 | tensor = transforms.ToTensor()(pil_img).unsqueeze_(0) 58 | return tensor 59 | 60 | 61 | def tensor_to_pil(tensor_img): 62 | """ 63 | param: 64 | tensor_img in size [1, C, H, W] 65 | return: 66 | pil - Image Object 67 | """ 68 | pil = transforms.ToPILImage()(tensor_img.squeeze_(0)) 69 | return pil 70 | 71 | 72 | def get_concat_pil_images(images: list, direction: str = 'h'): 73 | """ 74 | param: 75 | images - list of pil images 76 | direction - 'h' for horizontal, anything else for vertical 77 | return: 78 | pil - Image Object 79 | """ 80 | if direction == 'h': 81 | widths, heights = zip(*(i.size for i in images)) 82 | total_width = sum(widths) 83 | max_height = max(heights) 84 | new_im = Image.new('RGB', (total_width, max_height)) 85 | x_offset = 0 86 | 
for im in images: 87 | new_im.paste(im, (x_offset, 0)) 88 | x_offset += im.size[0] 89 | return new_im 90 | else: 91 | widths, heights = zip(*(i.size for i in images)) 92 | max_width = max(widths) 93 | total_height = sum(heights) 94 | new_im = Image.new('RGB', (max_width, total_height)) 95 | y_offset = 0 96 | for im in images: 97 | new_im.paste(im, (0, y_offset)) 98 | y_offset += im.size[1] 99 | return new_im 100 | --------------------------------------------------------------------------------
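For reference, the metric helpers in `src/metrics/` above (`clip_vit.py`, `dino_vit.py`, `distances.py`, `evaluate_dino.py`) can be chained to score an edited image against a subject reference. The snippet below is a minimal usage sketch, not part of the repository: it assumes the `src` package is importable from the repo root (the same assumption the pipeline modules make), that the `clip` package and the DINO torch-hub weights are available, and that the two image paths are placeholders to be replaced.

```python
import torch
from PIL import Image

from src.metrics.clip_vit import CLIP
from src.metrics.dino_vit import VITs16
from src.metrics.evaluate_dino import evaluate_dino_score, evaluate_clipi_score

device = "cuda" if torch.cuda.is_available() else "cpu"

# Placeholder paths: a subject reference image and an edited/generated result.
real_image = Image.open("ref.jpg").convert("RGB")
generated_image = Image.open("edited.jpg").convert("RGB")

dino_model = VITs16(device)  # DINO ViT-S/16 embeddings (subject fidelity)
clip_model = CLIP(device)    # CLIP ViT-B/32 image encoder (CLIP-I score)

dino_score = evaluate_dino_score(real_image, generated_image, device, dino_model)
clipi_score = evaluate_clipi_score(real_image, generated_image, device, clip_model)
print(f"DINO: {dino_score:.4f}  CLIP-I: {clipi_score:.4f}")
```

Both helpers return the cosine similarity between the two embeddings (higher is better); `evaluate_dino_score_list` and `evaluate_clipi_score_list` apply the same computation across a list of candidate images.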