├── LICENSE ├── README.md ├── configs ├── sd-objaverse-insert.yaml ├── sd-objaverse-multitask.yaml ├── sd-objaverse-remove.yaml ├── sd-objaverse-rotate.yaml └── sd-objaverse-translate.yaml ├── demo_images ├── insert_yard.png ├── move_cube.jpg ├── remove_cup.jpg └── rotate_mug.jpg ├── generate_scripts ├── insert.sh ├── remove.sh ├── rotate.sh └── translate.sh ├── ldm ├── __pycache__ │ ├── lr_scheduler.cpython-39.pyc │ └── util.cpython-39.pyc ├── data │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-39.pyc │ │ ├── base.cpython-39.pyc │ │ └── simple.cpython-39.pyc │ ├── base.py │ ├── coco.py │ ├── dummy.py │ ├── imagenet.py │ ├── inpainting │ │ ├── __init__.py │ │ └── synthetic_mask.py │ ├── laion.py │ ├── lsun.py │ ├── nerf_like.py │ └── simple.py ├── extras.py ├── guidance.py ├── lr_scheduler.py ├── models │ ├── __pycache__ │ │ └── autoencoder.cpython-39.pyc │ ├── autoencoder.py │ └── diffusion │ │ ├── __init__.py │ │ ├── __pycache__ │ │ ├── __init__.cpython-39.pyc │ │ ├── ddim.cpython-39.pyc │ │ ├── ddpm.cpython-39.pyc │ │ └── sampling_util.cpython-39.pyc │ │ ├── classifier.py │ │ ├── ddim.py │ │ ├── ddpm.py │ │ ├── plms.py │ │ └── sampling_util.py ├── modules │ ├── __pycache__ │ │ ├── attention.cpython-39.pyc │ │ ├── ema.cpython-39.pyc │ │ └── x_transformer.cpython-39.pyc │ ├── attention.py │ ├── diffusionmodules │ │ ├── __init__.py │ │ ├── __pycache__ │ │ │ ├── __init__.cpython-39.pyc │ │ │ ├── model.cpython-39.pyc │ │ │ ├── openaimodel.cpython-39.pyc │ │ │ └── util.cpython-39.pyc │ │ ├── model.py │ │ ├── openaimodel.py │ │ └── util.py │ ├── distributions │ │ ├── __init__.py │ │ ├── __pycache__ │ │ │ ├── __init__.cpython-39.pyc │ │ │ └── distributions.cpython-39.pyc │ │ └── distributions.py │ ├── ema.py │ ├── encoders │ │ ├── __init__.py │ │ ├── __pycache__ │ │ │ ├── __init__.cpython-39.pyc │ │ │ └── modules.cpython-39.pyc │ │ └── modules.py │ ├── evaluate │ │ ├── adm_evaluator.py │ │ ├── evaluate_perceptualsim.py │ │ ├── frechet_video_distance.py │ │ ├── ssim.py │ │ └── torch_frechet_video_distance.py │ ├── image_degradation │ │ ├── __init__.py │ │ ├── bsrgan.py │ │ ├── bsrgan_light.py │ │ ├── utils │ │ │ └── test.png │ │ └── utils_image.py │ ├── lora.py │ ├── losses │ │ ├── __init__.py │ │ ├── contperceptual.py │ │ └── vqperceptual.py │ └── x_transformer.py ├── thirdp │ └── psp │ │ ├── __pycache__ │ │ ├── helpers.cpython-39.pyc │ │ ├── id_loss.cpython-39.pyc │ │ └── model_irse.cpython-39.pyc │ │ ├── helpers.py │ │ ├── id_loss.py │ │ └── model_irse.py └── util.py ├── main.py ├── metrics.py ├── objaverse_cat_descriptions_64k.json ├── req.txt ├── run_eval.py ├── run_generation.py ├── setup_reqs.sh ├── train.sh └── uses.md /LICENSE: -------------------------------------------------------------------------------- 1 | CreativeML Open RAIL-M 2 | 3 | Section I: PREAMBLE 4 | 5 | Multimodal generative models are being widely adopted and used, and have the potential to transform the way artists, among other individuals, conceive and benefit from AI or ML technologies as a tool for content creation. 6 | 7 | Notwithstanding the current and potential benefits that these artifacts can bring to society at large, there are also concerns about potential misuses of them, either due to their technical limitations or ethical considerations. 8 | 9 | In short, this license strives for both the open and responsible downstream use of the accompanying model. When it comes to the open character, we took inspiration from open source permissive licenses regarding the grant of IP rights. 
Referring to the downstream responsible use, we added use-based restrictions not permitting the use of the Model in very specific scenarios, in order for the licensor to be able to enforce the license in case potential misuses of the Model may occur. At the same time, we strive to promote open and responsible research on generative models for art and content generation. 10 | 11 | Even though downstream derivative versions of the model could be released under different licensing terms, the latter will always have to include - at minimum - the same use-based restrictions as the ones in the original license (this license). We believe in the intersection between open and responsible AI development; thus, this License aims to strike a balance between both in order to enable responsible open-science in the field of AI. 12 | 13 | This License governs the use of the model (and its derivatives) and is informed by the model card associated with the model. 14 | 15 | NOW THEREFORE, You and Licensor agree as follows: 16 | 17 | 1. Definitions 18 | 19 | - "License" means the terms and conditions for use, reproduction, and Distribution as defined in this document. 20 | - "Data" means a collection of information and/or content extracted from the dataset used with the Model, including to train, pretrain, or otherwise evaluate the Model. The Data is not licensed under this License. 21 | - "Output" means the results of operating a Model as embodied in informational content resulting therefrom. 22 | - "Model" means any accompanying machine-learning based assemblies (including checkpoints), consisting of learnt weights, parameters (including optimizer states), corresponding to the model architecture as embodied in the Complementary Material, that have been trained or tuned, in whole or in part on the Data, using the Complementary Material. 23 | - "Derivatives of the Model" means all modifications to the Model, works based on the Model, or any other model which is created or initialized by transfer of patterns of the weights, parameters, activations or output of the Model, to the other model, in order to cause the other model to perform similarly to the Model, including - but not limited to - distillation methods entailing the use of intermediate data representations or methods based on the generation of synthetic data by the Model for training the other model. 24 | - "Complementary Material" means the accompanying source code and scripts used to define, run, load, benchmark or evaluate the Model, and used to prepare data for training or evaluation, if any. This includes any accompanying documentation, tutorials, examples, etc, if any. 25 | - "Distribution" means any transmission, reproduction, publication or other sharing of the Model or Derivatives of the Model to a third party, including providing the Model as a hosted service made available by electronic or other remote means - e.g. API-based or web access. 26 | - "Licensor" means the copyright owner or entity authorized by the copyright owner that is granting the License, including the persons or entities that may have rights in the Model and/or distributing the Model. 27 | - "You" (or "Your") means an individual or Legal Entity exercising permissions granted by this License and/or making use of the Model for whichever purpose and in any field of use, including usage of the Model in an end-use application - e.g. chatbot, translator, image generator. 
28 | - "Third Parties" means individuals or legal entities that are not under common control with Licensor or You. 29 | - "Contribution" means any work of authorship, including the original version of the Model and any modifications or additions to that Model or Derivatives of the Model thereof, that is intentionally submitted to Licensor for inclusion in the Model by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Model, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution." 30 | - "Contributor" means Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Model. 31 | 32 | Section II: INTELLECTUAL PROPERTY RIGHTS 33 | 34 | Both copyright and patent grants apply to the Model, Derivatives of the Model and Complementary Material. The Model and Derivatives of the Model are subject to additional terms as described in Section III. 35 | 36 | 2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare, publicly display, publicly perform, sublicense, and distribute the Complementary Material, the Model, and Derivatives of the Model. 37 | 3. Grant of Patent License. Subject to the terms and conditions of this License and where and as applicable, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this paragraph) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Model and the Complementary Material, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Model to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Model and/or Complementary Material or a Contribution incorporated within the Model and/or Complementary Material constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for the Model and/or Work shall terminate as of the date such litigation is asserted or filed. 38 | 39 | Section III: CONDITIONS OF USAGE, DISTRIBUTION AND REDISTRIBUTION 40 | 41 | 4. Distribution and Redistribution. You may host for Third Party remote access purposes (e.g. software-as-a-service), reproduce and distribute copies of the Model or Derivatives of the Model thereof in any medium, with or without modifications, provided that You meet the following conditions: 42 | Use-based restrictions as referenced in paragraph 5 MUST be included as an enforceable provision by You in any type of legal agreement (e.g. 
a license) governing the use and/or distribution of the Model or Derivatives of the Model, and You shall give notice to subsequent users You Distribute to, that the Model or Derivatives of the Model are subject to paragraph 5. This provision does not apply to the use of Complementary Material. 43 | You must give any Third Party recipients of the Model or Derivatives of the Model a copy of this License; 44 | You must cause any modified files to carry prominent notices stating that You changed the files; 45 | You must retain all copyright, patent, trademark, and attribution notices excluding those notices that do not pertain to any part of the Model, Derivatives of the Model. 46 | You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions - respecting paragraph 4.a. - for use, reproduction, or Distribution of Your modifications, or for any such Derivatives of the Model as a whole, provided Your use, reproduction, and Distribution of the Model otherwise complies with the conditions stated in this License. 47 | 5. Use-based restrictions. The restrictions set forth in Attachment A are considered Use-based restrictions. Therefore You cannot use the Model and the Derivatives of the Model for the specified restricted uses. You may use the Model subject to this License, including only for lawful purposes and in accordance with the License. Use may include creating any content with, finetuning, updating, running, training, evaluating and/or reparametrizing the Model. You shall require all of Your users who use the Model or a Derivative of the Model to comply with the terms of this paragraph (paragraph 5). 48 | 6. The Output You Generate. Except as set forth herein, Licensor claims no rights in the Output You generate using the Model. You are accountable for the Output you generate and its subsequent uses. No use of the output can contravene any provision as stated in the License. 49 | 50 | Section IV: OTHER PROVISIONS 51 | 52 | 7. Updates and Runtime Restrictions. To the maximum extent permitted by law, Licensor reserves the right to restrict (remotely or otherwise) usage of the Model in violation of this License, update the Model through electronic means, or modify the Output of the Model based on updates. You shall undertake reasonable efforts to use the latest version of the Model. 53 | 8. Trademarks and related. Nothing in this License permits You to make use of Licensors’ trademarks, trade names, logos or to otherwise suggest endorsement or misrepresent the relationship between the parties; and any rights not expressly granted herein are reserved by the Licensors. 54 | 9. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Model and the Complementary Material (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Model, Derivatives of the Model, and the Complementary Material and assume any risks associated with Your exercise of permissions under this License. 55 | 10. Limitation of Liability. 
In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Model and the Complementary Material (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages. 56 | 11. Accepting Warranty or Additional Liability. While redistributing the Model, Derivatives of the Model and the Complementary Material thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability. 57 | 12. If any provision of this License is held to be invalid, illegal or unenforceable, the remaining provisions shall be unaffected thereby and remain valid as if such provision had not been set forth herein. 58 | 59 | END OF TERMS AND CONDITIONS 60 | 61 | 62 | 63 | 64 | Attachment A 65 | 66 | Use Restrictions 67 | 68 | You agree not to use the Model or Derivatives of the Model: 69 | - In any way that violates any applicable national, federal, state, local or international law or regulation; 70 | - For the purpose of exploiting, harming or attempting to exploit or harm minors in any way; 71 | - To generate or disseminate verifiably false information and/or content with the purpose of harming others; 72 | - To generate or disseminate personal identifiable information that can be used to harm an individual; 73 | - To defame, disparage or otherwise harass others; 74 | - For fully automated decision making that adversely impacts an individual’s legal rights or otherwise creates or modifies a binding, enforceable obligation; 75 | - For any use intended to or which has the effect of discriminating against or harming individuals or groups based on online or offline social behavior or known or predicted personal or personality characteristics; 76 | - To exploit any of the vulnerabilities of a specific group of persons based on their age, social, physical or mental characteristics, in order to materially distort the behavior of a person pertaining to that group in a manner that causes or is likely to cause that person or another person physical or psychological harm; 77 | - For any use intended to or which has the effect of discriminating against individuals or groups based on legally protected characteristics or categories; 78 | - To provide medical advice and medical results interpretation; 79 | - To generate or disseminate information for the purpose to be used for administration of justice, law enforcement, immigration or asylum processes, such as predicting an individual will commit fraud/crime commitment (e.g. 
by text profiling, drawing causal relationships between assertions made in documents, indiscriminate and arbitrarily-targeted use). 80 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # OBJect 3DIT: Language-guided 3D-aware Image Editing 2 | 3 | This is the official codebase for the paper [OBJect 3DIT: Language-guided 3D-aware Image Editing](https://arxiv.org/abs/2307.11073). 4 | 5 | ## Set-up 6 | 7 | First, clone this repository. 8 | 9 | ```bash 10 | git clone https://github.com/allenai/object-edit.git 11 | cd object-edit 12 | mkdir DATASET 13 | ``` 14 | 15 | Optionally, create a conda environment. 16 | 17 | ```bash 18 | conda create --name object python=3.9 19 | conda activate object 20 | ``` 21 | Install all of the requirements. 22 | ```bash 23 | . setup_reqs.sh 24 | ``` 25 | Download the dataset from HuggingFace [here](https://huggingface.co/datasets/allenai/object-edit/tree/main). By default the dataset is expected to be in the `DATASET` directory, so download the archives there. You can also change the default dataset path in the config files. 26 | Unzip them with the following commands: 27 | ``` 28 | cd DATASET 29 | tar -xzvf TASK.tar.gz 30 | rm TASK.tar.gz 31 | ``` 32 | where `TASK` is either "remove", "rotate" or "translate". No extra data is needed for the insertion task, since it uses the same data as removal. 33 | 34 | The dataset has the following structure. 35 | ``` 36 | DATASET/ 37 | └── rotate/ 38 | └── train/ 39 | ├── uid/ 40 | │ ├── 1.png 41 | │ ├── 2.png 42 | │ ├── 1_mask.png 43 | │ ├── 2_mask.png 44 | │ └── metadata.json 45 | ``` 46 | The checkpoints for the trained editing models and the Zero123 initialization can also be found on the HuggingFace page for this project. 47 | ## Training 48 | 49 | If you would like to finetune from a [Zero123](https://github.com/cvlab-columbia/zero123) or [Image-Conditioned StableDiffusion](https://huggingface.co/lambdalabs/stable-diffusion-image-conditioned) checkpoint, please download it and modify the path in `train.sh`. If you would like to train from scratch, delete the `--finetune_from` argument from `train.sh`. You may also change the devices used in the `--gpus` argument. To train, run the following, replacing `TASK` with "rotate", "remove", "insert", "translate" or "multitask": 50 | ```bash 51 | . train.sh TASK 52 | ``` 53 | 54 | ## Inference demo 55 | You can run the scripts in `generate_scripts` to see inference for each of the editing tasks. 56 | ``` 57 | . generate_scripts/rotate.sh 58 | . generate_scripts/remove.sh 59 | . generate_scripts/insert.sh 60 | . generate_scripts/translate.sh 61 | ``` 62 | Each of these runs the `run_generation.py` script. You can modify the arguments in these shell scripts to perform different edits. Note that the object prompt should not contain the editing instruction; it will be filled in automatically. You only need to provide a description of the target object.
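As a rough illustration, here is a hypothetical variation of `generate_scripts/rotate.sh` (assuming the same `run_generation.py` arguments used in the provided scripts) that rotates the demo mug by a different angle; swap in your own checkpoint, image, and prompt as needed.

```bash
# Hypothetical variation of generate_scripts/rotate.sh: rotate the demo mug by 45 degrees.
# All flags mirror the provided scripts; only the angle and device values differ.
python run_generation.py \
    --task rotate \
    --checkpoint_path rotate.ckpt \
    --image_path demo_images/rotate_mug.jpg \
    --object_prompt "white mug" \
    --rotation_angle 45 \
    --device 0 \
    --cfg_scale 3.0
```

For the insertion and translation tasks, the demo scripts pass `--position` with two comma-separated values in place of `--rotation_angle`.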
63 | ## Evaluation 64 | 65 | If you would like to evaluate your generated images on the benchmark, you can run: 66 | 67 | ``` 68 | python run_eval.py \ 69 | --generation_dir YOUR_GENERATED_IMAGES_PATH \ 70 | --data_dir PATH_OF_OBJECT_DATASET \ 71 | --task [rotate|remove|insert|translate] \ 72 | --split [train|val|test] \ 73 | --seen [seen|unseen] \ 74 | --save_dir PATH_TO_SAVE_STATISTICS_SUMMARY 75 | ``` 76 | 77 | This script assumes your generated images are saved in the directory with the same UID as the corresponding sample in the dataset. 78 | ``` 79 | base_path/ 80 | │ 81 | ├── uid/ 82 | │ ├── 0.png 83 | │ ├── 1.png 84 | │ ├── 2.png 85 | │ └── 3.png 86 | -------------------------------------------------------------------------------- /configs/sd-objaverse-insert.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 1.0e-04 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.00085 6 | linear_end: 0.0120 7 | num_timesteps_cond: 1 8 | log_every_t: 200 9 | timesteps: 1000 10 | first_stage_key: "image_target" 11 | cond_stage_key: "cond" 12 | image_size: 32 13 | channels: 4 14 | cond_stage_trainable: false # Note: different from the one we trained before 15 | conditioning_key: hybrid 16 | monitor: val/loss/dataloader_idx_0 17 | scale_factor: 0.18215 18 | use_lora: False 19 | 20 | scheduler_config: # 10000 warmup steps 21 | target: ldm.lr_scheduler.LambdaLinearScheduler 22 | params: 23 | warm_up_steps: [ 100 ] 24 | cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases 25 | f_start: [ 1.e-6 ] 26 | f_max: [ 1. ] 27 | f_min: [ 1. ] 28 | 29 | unet_config: 30 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 31 | params: 32 | image_size: 32 # unused 33 | in_channels: 8 34 | out_channels: 4 35 | model_channels: 320 36 | attention_resolutions: [ 4, 2, 1 ] 37 | num_res_blocks: 2 38 | channel_mult: [ 1, 2, 4, 4 ] 39 | num_heads: 8 40 | use_spatial_transformer: True 41 | transformer_depth: 1 42 | context_dim: 768 43 | use_checkpoint: True 44 | legacy: False 45 | 46 | first_stage_config: 47 | target: ldm.models.autoencoder.AutoencoderKL 48 | params: 49 | embed_dim: 4 50 | monitor: val/rec_loss 51 | ddconfig: 52 | double_z: true 53 | z_channels: 4 54 | resolution: 256 55 | in_channels: 3 56 | out_ch: 3 57 | ch: 128 58 | ch_mult: 59 | - 1 60 | - 2 61 | - 4 62 | - 4 63 | num_res_blocks: 2 64 | attn_resolutions: [] 65 | dropout: 0.0 66 | lossconfig: 67 | target: torch.nn.Identity 68 | 69 | cond_stage_config: 70 | target: ldm.modules.encoders.modules.FrozenCLIPImageEmbedder 71 | 72 | 73 | data: 74 | target: ldm.data.simple.ObjaverseTaskDataModuleFromConfig 75 | params: 76 | root_dir: 'DATASET' 77 | task: "insert" 78 | batch_size: 64 79 | num_workers: 16 80 | total_view: 4 81 | train: 82 | validation: False 83 | image_transforms: 84 | size: 256 85 | 86 | validation: 87 | validation: True 88 | image_transforms: 89 | size: 256 90 | 91 | 92 | lightning: 93 | find_unused_parameters: false 94 | modelcheckpoint: 95 | params: 96 | every_n_epochs: 5 97 | callbacks: 98 | image_logger: 99 | target: main.ImageLogger 100 | params: 101 | batch_frequency: 500 102 | max_images: 32 103 | increase_log_steps: False 104 | log_first_step: True 105 | log_all_val: True 106 | log_images_kwargs: 107 | use_ema_scope: False 108 | inpaint: False 109 | plot_progressive_rows: False 110 | plot_diffusion_rows: False 111 | N: 32 112 | unconditional_guidance_scale: 1.0 113 | 
unconditional_guidance_label: [""] 114 | 115 | # trainer: 116 | # benchmark: True 117 | # num_sanity_val_steps: 0 118 | # accumulate_grad_batches: 2 119 | # check_val_every_n_epoch: 1 120 | trainer: 121 | benchmark: True 122 | check_val_every_n_epoch: 5 123 | num_sanity_val_steps: 0 124 | accumulate_grad_batches: 2 125 | log_every_n_steps: 20 -------------------------------------------------------------------------------- /configs/sd-objaverse-multitask.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 1.0e-04 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.00085 6 | linear_end: 0.0120 7 | num_timesteps_cond: 1 8 | log_every_t: 200 9 | timesteps: 1000 10 | first_stage_key: "image_target" 11 | cond_stage_key: "cond" 12 | image_size: 32 13 | channels: 4 14 | cond_stage_trainable: false # Note: different from the one we trained before 15 | conditioning_key: hybrid 16 | monitor: val/loss/dataloader_idx_0 17 | scale_factor: 0.18215 18 | use_lora: False 19 | 20 | scheduler_config: # 10000 warmup steps 21 | target: ldm.lr_scheduler.LambdaLinearScheduler 22 | params: 23 | warm_up_steps: [ 100 ] 24 | cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases 25 | f_start: [ 1.e-6 ] 26 | f_max: [ 1. ] 27 | f_min: [ 1. ] 28 | 29 | unet_config: 30 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 31 | params: 32 | image_size: 32 # unused 33 | in_channels: 8 34 | out_channels: 4 35 | model_channels: 320 36 | attention_resolutions: [ 4, 2, 1 ] 37 | num_res_blocks: 2 38 | channel_mult: [ 1, 2, 4, 4 ] 39 | num_heads: 8 40 | use_spatial_transformer: True 41 | transformer_depth: 1 42 | context_dim: 768 43 | use_checkpoint: True 44 | legacy: False 45 | 46 | first_stage_config: 47 | target: ldm.models.autoencoder.AutoencoderKL 48 | params: 49 | embed_dim: 4 50 | monitor: val/rec_loss 51 | ddconfig: 52 | double_z: true 53 | z_channels: 4 54 | resolution: 256 55 | in_channels: 3 56 | out_ch: 3 57 | ch: 128 58 | ch_mult: 59 | - 1 60 | - 2 61 | - 4 62 | - 4 63 | num_res_blocks: 2 64 | attn_resolutions: [] 65 | dropout: 0.0 66 | lossconfig: 67 | target: torch.nn.Identity 68 | 69 | cond_stage_config: 70 | target: ldm.modules.encoders.modules.FrozenCLIPImageEmbedder 71 | 72 | 73 | data: 74 | target: ldm.data.simple.ObjaverseTaskDataModuleFromConfig 75 | params: 76 | root_dir: 'DATASET' 77 | task: "multitask" 78 | batch_size: 64 79 | num_workers: 16 80 | train: 81 | validation: False 82 | image_transforms: 83 | size: 256 84 | num_samples: 100000 85 | 86 | validation: 87 | validation: True 88 | image_transforms: 89 | size: 256 90 | num_samples: 512 91 | 92 | 93 | lightning: 94 | find_unused_parameters: false 95 | modelcheckpoint: 96 | params: 97 | every_n_epochs: 5 98 | callbacks: 99 | image_logger: 100 | target: main.ImageLogger 101 | params: 102 | batch_frequency: 500 103 | max_images: 32 104 | increase_log_steps: False 105 | log_first_step: True 106 | log_all_val: True 107 | log_images_kwargs: 108 | use_ema_scope: False 109 | inpaint: False 110 | plot_progressive_rows: False 111 | plot_diffusion_rows: False 112 | N: 64 113 | unconditional_guidance_scale: 1.0 114 | unconditional_guidance_label: [""] 115 | 116 | # trainer: 117 | # benchmark: True 118 | # num_sanity_val_steps: 0 119 | # accumulate_grad_batches: 2 120 | # check_val_every_n_epoch: 1 121 | trainer: 122 | benchmark: True 123 | check_val_every_n_epoch: 5 124 | num_sanity_val_steps: 0 125 | 
accumulate_grad_batches: 2 126 | log_every_n_steps: 20 -------------------------------------------------------------------------------- /configs/sd-objaverse-remove.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 1.0e-04 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.00085 6 | linear_end: 0.0120 7 | num_timesteps_cond: 1 8 | log_every_t: 200 9 | timesteps: 1000 10 | first_stage_key: "image_target" 11 | cond_stage_key: "cond" 12 | image_size: 32 13 | channels: 4 14 | cond_stage_trainable: false # Note: different from the one we trained before 15 | conditioning_key: hybrid 16 | monitor: val/loss/dataloader_idx_0 17 | scale_factor: 0.18215 18 | use_lora: False 19 | 20 | scheduler_config: # 10000 warmup steps 21 | target: ldm.lr_scheduler.LambdaLinearScheduler 22 | params: 23 | warm_up_steps: [ 100 ] 24 | cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases 25 | f_start: [ 1.e-6 ] 26 | f_max: [ 1. ] 27 | f_min: [ 1. ] 28 | 29 | unet_config: 30 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 31 | params: 32 | image_size: 32 # unused 33 | in_channels: 8 34 | out_channels: 4 35 | model_channels: 320 36 | attention_resolutions: [ 4, 2, 1 ] 37 | num_res_blocks: 2 38 | channel_mult: [ 1, 2, 4, 4 ] 39 | num_heads: 8 40 | use_spatial_transformer: True 41 | transformer_depth: 1 42 | context_dim: 768 43 | use_checkpoint: True 44 | legacy: False 45 | 46 | first_stage_config: 47 | target: ldm.models.autoencoder.AutoencoderKL 48 | params: 49 | embed_dim: 4 50 | monitor: val/rec_loss 51 | ddconfig: 52 | double_z: true 53 | z_channels: 4 54 | resolution: 256 55 | in_channels: 3 56 | out_ch: 3 57 | ch: 128 58 | ch_mult: 59 | - 1 60 | - 2 61 | - 4 62 | - 4 63 | num_res_blocks: 2 64 | attn_resolutions: [] 65 | dropout: 0.0 66 | lossconfig: 67 | target: torch.nn.Identity 68 | 69 | cond_stage_config: 70 | target: ldm.modules.encoders.modules.FrozenCLIPImageEmbedder 71 | 72 | 73 | data: 74 | target: ldm.data.simple.ObjaverseTaskDataModuleFromConfig 75 | params: 76 | root_dir: 'DATASET' 77 | task: "remove" 78 | batch_size: 64 79 | num_workers: 0 80 | train: 81 | validation: False 82 | image_transforms: 83 | size: 256 84 | 85 | validation: 86 | validation: True 87 | image_transforms: 88 | size: 256 89 | 90 | 91 | lightning: 92 | find_unused_parameters: false 93 | modelcheckpoint: 94 | params: 95 | every_n_epochs: 5 96 | callbacks: 97 | image_logger: 98 | target: main.ImageLogger 99 | params: 100 | batch_frequency: 500 101 | max_images: 32 102 | increase_log_steps: False 103 | log_first_step: True 104 | log_all_val: True 105 | log_images_kwargs: 106 | use_ema_scope: False 107 | inpaint: False 108 | plot_progressive_rows: False 109 | plot_diffusion_rows: False 110 | N: 32 111 | unconditional_guidance_scale: 1.0 112 | unconditional_guidance_label: [""] 113 | 114 | # trainer: 115 | # benchmark: True 116 | # num_sanity_val_steps: 0 117 | # accumulate_grad_batches: 2 118 | # check_val_every_n_epoch: 1 119 | trainer: 120 | benchmark: True 121 | check_val_every_n_epoch: 5 122 | num_sanity_val_steps: 0 123 | accumulate_grad_batches: 2 124 | log_every_n_steps: 20 -------------------------------------------------------------------------------- /configs/sd-objaverse-rotate.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 1.0e-04 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | 
linear_start: 0.00085 6 | linear_end: 0.0120 7 | num_timesteps_cond: 1 8 | log_every_t: 200 9 | timesteps: 1000 10 | first_stage_key: "image_target" 11 | cond_stage_key: "cond" 12 | image_size: 32 13 | channels: 4 14 | cond_stage_trainable: false # Note: different from the one we trained before 15 | conditioning_key: hybrid 16 | monitor: val/loss/dataloader_idx_0 17 | scale_factor: 0.18215 18 | use_lora: False 19 | 20 | scheduler_config: # 10000 warmup steps 21 | target: ldm.lr_scheduler.LambdaLinearScheduler 22 | params: 23 | warm_up_steps: [ 100 ] 24 | cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases 25 | f_start: [ 1.e-6 ] 26 | f_max: [ 1. ] 27 | f_min: [ 1. ] 28 | 29 | unet_config: 30 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 31 | params: 32 | image_size: 32 # unused 33 | in_channels: 8 34 | out_channels: 4 35 | model_channels: 320 36 | attention_resolutions: [ 4, 2, 1 ] 37 | num_res_blocks: 2 38 | channel_mult: [ 1, 2, 4, 4 ] 39 | num_heads: 8 40 | use_spatial_transformer: True 41 | transformer_depth: 1 42 | context_dim: 768 43 | use_checkpoint: True 44 | legacy: False 45 | 46 | first_stage_config: 47 | target: ldm.models.autoencoder.AutoencoderKL 48 | params: 49 | embed_dim: 4 50 | monitor: val/rec_loss 51 | ddconfig: 52 | double_z: true 53 | z_channels: 4 54 | resolution: 256 55 | in_channels: 3 56 | out_ch: 3 57 | ch: 128 58 | ch_mult: 59 | - 1 60 | - 2 61 | - 4 62 | - 4 63 | num_res_blocks: 2 64 | attn_resolutions: [] 65 | dropout: 0.0 66 | lossconfig: 67 | target: torch.nn.Identity 68 | 69 | cond_stage_config: 70 | target: ldm.modules.encoders.modules.FrozenCLIPImageEmbedder 71 | 72 | 73 | data: 74 | target: ldm.data.simple.ObjaverseTaskDataModuleFromConfig 75 | params: 76 | root_dir: 'DATASET' 77 | task: "rotate" 78 | batch_size: 64 79 | num_workers: 16 80 | train: 81 | validation: False 82 | image_transforms: 83 | size: 256 84 | 85 | validation: 86 | validation: True 87 | image_transforms: 88 | size: 256 89 | 90 | 91 | lightning: 92 | find_unused_parameters: false 93 | modelcheckpoint: 94 | params: 95 | every_n_epochs: 5 96 | callbacks: 97 | image_logger: 98 | target: main.ImageLogger 99 | params: 100 | batch_frequency: 500 101 | max_images: 32 102 | increase_log_steps: False 103 | log_first_step: True 104 | log_all_val: True 105 | log_images_kwargs: 106 | use_ema_scope: False 107 | inpaint: False 108 | plot_progressive_rows: False 109 | plot_diffusion_rows: False 110 | N: 32 111 | unconditional_guidance_scale: 1.0 112 | unconditional_guidance_label: [""] 113 | 114 | # trainer: 115 | # benchmark: True 116 | # num_sanity_val_steps: 0 117 | # accumulate_grad_batches: 2 118 | # check_val_every_n_epoch: 1 119 | trainer: 120 | benchmark: True 121 | check_val_every_n_epoch: 5 122 | num_sanity_val_steps: 1 123 | accumulate_grad_batches: 2 124 | log_every_n_steps: 20 -------------------------------------------------------------------------------- /configs/sd-objaverse-translate.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 1.0e-04 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.00085 6 | linear_end: 0.0120 7 | num_timesteps_cond: 1 8 | log_every_t: 200 9 | timesteps: 1000 10 | first_stage_key: "image_target" 11 | cond_stage_key: "cond" 12 | image_size: 32 13 | channels: 4 14 | cond_stage_trainable: false # Note: different from the one we trained before 15 | conditioning_key: hybrid 16 | monitor: val/loss/dataloader_idx_0 
17 | scale_factor: 0.18215 18 | use_lora: False 19 | 20 | scheduler_config: # 10000 warmup steps 21 | target: ldm.lr_scheduler.LambdaLinearScheduler 22 | params: 23 | warm_up_steps: [ 100 ] 24 | cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases 25 | f_start: [ 1.e-6 ] 26 | f_max: [ 1. ] 27 | f_min: [ 1. ] 28 | 29 | unet_config: 30 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 31 | params: 32 | image_size: 32 # unused 33 | in_channels: 8 34 | out_channels: 4 35 | model_channels: 320 36 | attention_resolutions: [ 4, 2, 1 ] 37 | num_res_blocks: 2 38 | channel_mult: [ 1, 2, 4, 4 ] 39 | num_heads: 8 40 | use_spatial_transformer: True 41 | transformer_depth: 1 42 | context_dim: 768 43 | use_checkpoint: True 44 | legacy: False 45 | 46 | first_stage_config: 47 | target: ldm.models.autoencoder.AutoencoderKL 48 | params: 49 | embed_dim: 4 50 | monitor: val/rec_loss 51 | ddconfig: 52 | double_z: true 53 | z_channels: 4 54 | resolution: 256 55 | in_channels: 3 56 | out_ch: 3 57 | ch: 128 58 | ch_mult: 59 | - 1 60 | - 2 61 | - 4 62 | - 4 63 | num_res_blocks: 2 64 | attn_resolutions: [] 65 | dropout: 0.0 66 | lossconfig: 67 | target: torch.nn.Identity 68 | 69 | cond_stage_config: 70 | target: ldm.modules.encoders.modules.FrozenCLIPImageEmbedder 71 | 72 | 73 | data: 74 | target: ldm.data.simple.ObjaverseTaskDataModuleFromConfig 75 | params: 76 | root_dir: 'DATASET' 77 | task: "translate" 78 | batch_size: 64 79 | num_workers: 16 80 | train: 81 | validation: False 82 | image_transforms: 83 | size: 256 84 | 85 | validation: 86 | validation: True 87 | image_transforms: 88 | size: 256 89 | 90 | 91 | lightning: 92 | find_unused_parameters: false 93 | modelcheckpoint: 94 | params: 95 | every_n_epochs: 5 96 | callbacks: 97 | image_logger: 98 | target: main.ImageLogger 99 | params: 100 | batch_frequency: 500 101 | max_images: 32 102 | increase_log_steps: False 103 | log_first_step: True 104 | log_all_val: True 105 | log_images_kwargs: 106 | use_ema_scope: False 107 | inpaint: False 108 | plot_progressive_rows: False 109 | plot_diffusion_rows: False 110 | N: 32 111 | unconditional_guidance_scale: 1.0 112 | unconditional_guidance_label: [""] 113 | 114 | # trainer: 115 | # benchmark: True 116 | # num_sanity_val_steps: 0 117 | # accumulate_grad_batches: 2 118 | # check_val_every_n_epoch: 1 119 | trainer: 120 | benchmark: True 121 | check_val_every_n_epoch: 5 122 | num_sanity_val_steps: 0 123 | accumulate_grad_batches: 2 124 | log_every_n_steps: 20 -------------------------------------------------------------------------------- /demo_images/insert_yard.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/object-edit/054bb09b8989754169a49bdce73bb767d745750b/demo_images/insert_yard.png -------------------------------------------------------------------------------- /demo_images/move_cube.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/object-edit/054bb09b8989754169a49bdce73bb767d745750b/demo_images/move_cube.jpg -------------------------------------------------------------------------------- /demo_images/remove_cup.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/object-edit/054bb09b8989754169a49bdce73bb767d745750b/demo_images/remove_cup.jpg -------------------------------------------------------------------------------- 
/demo_images/rotate_mug.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/object-edit/054bb09b8989754169a49bdce73bb767d745750b/demo_images/rotate_mug.jpg -------------------------------------------------------------------------------- /generate_scripts/insert.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | python run_generation.py \ 4 | --task insert \ 5 | --checkpoint_path insert.ckpt \ 6 | --image_path demo_images/insert_yard.png \ 7 | --object_prompt "a blue bookcase" \ 8 | --position 0.25,0.6 \ 9 | --device 1 \ 10 | --cfg_scale 3.0 -------------------------------------------------------------------------------- /generate_scripts/remove.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | python run_generation.py \ 4 | --task remove \ 5 | --checkpoint_path remove.ckpt \ 6 | --image_path demo_images/remove_cup.jpg \ 7 | --object_prompt "the purple cup" \ 8 | --device 1 \ 9 | --cfg_scale 3.0 -------------------------------------------------------------------------------- /generate_scripts/rotate.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | python run_generation.py \ 4 | --task rotate \ 5 | --checkpoint_path rotate.ckpt \ 6 | --image_path demo_images/rotate_mug.jpg \ 7 | --object_prompt "white mug" \ 8 | --rotation_angle 90 \ 9 | --device 1 \ 10 | --cfg_scale 3.0 -------------------------------------------------------------------------------- /generate_scripts/translate.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | python run_generation.py \ 4 | --task translate \ 5 | --checkpoint_path translate.ckpt \ 6 | --image_path demo_images/move_cube.jpg \ 7 | --object_prompt "the blue cube" \ 8 | --position 0.8,0.2 \ 9 | --device 1 \ 10 | --cfg_scale 3.0 -------------------------------------------------------------------------------- /ldm/__pycache__/lr_scheduler.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/object-edit/054bb09b8989754169a49bdce73bb767d745750b/ldm/__pycache__/lr_scheduler.cpython-39.pyc -------------------------------------------------------------------------------- /ldm/__pycache__/util.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/object-edit/054bb09b8989754169a49bdce73bb767d745750b/ldm/__pycache__/util.cpython-39.pyc -------------------------------------------------------------------------------- /ldm/data/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/object-edit/054bb09b8989754169a49bdce73bb767d745750b/ldm/data/__init__.py -------------------------------------------------------------------------------- /ldm/data/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/object-edit/054bb09b8989754169a49bdce73bb767d745750b/ldm/data/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /ldm/data/__pycache__/base.cpython-39.pyc: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/allenai/object-edit/054bb09b8989754169a49bdce73bb767d745750b/ldm/data/__pycache__/base.cpython-39.pyc -------------------------------------------------------------------------------- /ldm/data/__pycache__/simple.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/object-edit/054bb09b8989754169a49bdce73bb767d745750b/ldm/data/__pycache__/simple.cpython-39.pyc -------------------------------------------------------------------------------- /ldm/data/base.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | from abc import abstractmethod 4 | from torch.utils.data import Dataset, ConcatDataset, ChainDataset, IterableDataset 5 | 6 | 7 | class Txt2ImgIterableBaseDataset(IterableDataset): 8 | ''' 9 | Define an interface to make the IterableDatasets for text2img data chainable 10 | ''' 11 | def __init__(self, num_records=0, valid_ids=None, size=256): 12 | super().__init__() 13 | self.num_records = num_records 14 | self.valid_ids = valid_ids 15 | self.sample_ids = valid_ids 16 | self.size = size 17 | 18 | print(f'{self.__class__.__name__} dataset contains {self.__len__()} examples.') 19 | 20 | def __len__(self): 21 | return self.num_records 22 | 23 | @abstractmethod 24 | def __iter__(self): 25 | pass 26 | 27 | 28 | class PRNGMixin(object): 29 | """ 30 | Adds a prng property which is a numpy RandomState which gets 31 | reinitialized whenever the pid changes to avoid synchronized sampling 32 | behavior when used in conjunction with multiprocessing. 33 | """ 34 | @property 35 | def prng(self): 36 | currentpid = os.getpid() 37 | if getattr(self, "_initpid", None) != currentpid: 38 | self._initpid = currentpid 39 | self._prng = np.random.RandomState() 40 | return self._prng 41 | -------------------------------------------------------------------------------- /ldm/data/coco.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import albumentations 4 | import numpy as np 5 | from PIL import Image 6 | from tqdm import tqdm 7 | from torch.utils.data import Dataset 8 | from abc import abstractmethod 9 | 10 | 11 | class CocoBase(Dataset): 12 | """needed for (image, caption, segmentation) pairs""" 13 | def __init__(self, size=None, dataroot="", datajson="", onehot_segmentation=False, use_stuffthing=False, 14 | crop_size=None, force_no_crop=False, given_files=None, use_segmentation=True,crop_type=None): 15 | self.split = self.get_split() 16 | self.size = size 17 | if crop_size is None: 18 | self.crop_size = size 19 | else: 20 | self.crop_size = crop_size 21 | 22 | assert crop_type in [None, 'random', 'center'] 23 | self.crop_type = crop_type 24 | self.use_segmenation = use_segmentation 25 | self.onehot = onehot_segmentation # return segmentation as rgb or one hot 26 | self.stuffthing = use_stuffthing # include thing in segmentation 27 | if self.onehot and not self.stuffthing: 28 | raise NotImplemented("One hot mode is only supported for the " 29 | "stuffthings version because labels are stored " 30 | "a bit different.") 31 | 32 | data_json = datajson 33 | with open(data_json) as json_file: 34 | self.json_data = json.load(json_file) 35 | self.img_id_to_captions = dict() 36 | self.img_id_to_filepath = dict() 37 | self.img_id_to_segmentation_filepath = dict() 38 | 39 | assert data_json.split("/")[-1] in [f"captions_train{self.year()}.json", 40 | 
f"captions_val{self.year()}.json"] 41 | # TODO currently hardcoded paths, would be better to follow logic in 42 | # cocstuff pixelmaps 43 | if self.use_segmenation: 44 | if self.stuffthing: 45 | self.segmentation_prefix = ( 46 | f"data/cocostuffthings/val{self.year()}" if 47 | data_json.endswith(f"captions_val{self.year()}.json") else 48 | f"data/cocostuffthings/train{self.year()}") 49 | else: 50 | self.segmentation_prefix = ( 51 | f"data/coco/annotations/stuff_val{self.year()}_pixelmaps" if 52 | data_json.endswith(f"captions_val{self.year()}.json") else 53 | f"data/coco/annotations/stuff_train{self.year()}_pixelmaps") 54 | 55 | imagedirs = self.json_data["images"] 56 | self.labels = {"image_ids": list()} 57 | for imgdir in tqdm(imagedirs, desc="ImgToPath"): 58 | self.img_id_to_filepath[imgdir["id"]] = os.path.join(dataroot, imgdir["file_name"]) 59 | self.img_id_to_captions[imgdir["id"]] = list() 60 | pngfilename = imgdir["file_name"].replace("jpg", "png") 61 | if self.use_segmenation: 62 | self.img_id_to_segmentation_filepath[imgdir["id"]] = os.path.join( 63 | self.segmentation_prefix, pngfilename) 64 | if given_files is not None: 65 | if pngfilename in given_files: 66 | self.labels["image_ids"].append(imgdir["id"]) 67 | else: 68 | self.labels["image_ids"].append(imgdir["id"]) 69 | 70 | capdirs = self.json_data["annotations"] 71 | for capdir in tqdm(capdirs, desc="ImgToCaptions"): 72 | # there are in average 5 captions per image 73 | #self.img_id_to_captions[capdir["image_id"]].append(np.array([capdir["caption"]])) 74 | self.img_id_to_captions[capdir["image_id"]].append(capdir["caption"]) 75 | 76 | self.rescaler = albumentations.SmallestMaxSize(max_size=self.size) 77 | if self.split=="validation": 78 | self.cropper = albumentations.CenterCrop(height=self.crop_size, width=self.crop_size) 79 | else: 80 | # default option for train is random crop 81 | if self.crop_type in [None, 'random']: 82 | self.cropper = albumentations.RandomCrop(height=self.crop_size, width=self.crop_size) 83 | else: 84 | self.cropper = albumentations.CenterCrop(height=self.crop_size, width=self.crop_size) 85 | self.preprocessor = albumentations.Compose( 86 | [self.rescaler, self.cropper], 87 | additional_targets={"segmentation": "image"}) 88 | if force_no_crop: 89 | self.rescaler = albumentations.Resize(height=self.size, width=self.size) 90 | self.preprocessor = albumentations.Compose( 91 | [self.rescaler], 92 | additional_targets={"segmentation": "image"}) 93 | 94 | @abstractmethod 95 | def year(self): 96 | raise NotImplementedError() 97 | 98 | def __len__(self): 99 | return len(self.labels["image_ids"]) 100 | 101 | def preprocess_image(self, image_path, segmentation_path=None): 102 | image = Image.open(image_path) 103 | if not image.mode == "RGB": 104 | image = image.convert("RGB") 105 | image = np.array(image).astype(np.uint8) 106 | if segmentation_path: 107 | segmentation = Image.open(segmentation_path) 108 | if not self.onehot and not segmentation.mode == "RGB": 109 | segmentation = segmentation.convert("RGB") 110 | segmentation = np.array(segmentation).astype(np.uint8) 111 | if self.onehot: 112 | assert self.stuffthing 113 | # stored in caffe format: unlabeled==255. stuff and thing from 114 | # 0-181. 
to be compatible with the labels in 115 | # https://github.com/nightrome/cocostuff/blob/master/labels.txt 116 | # we shift stuffthing one to the right and put unlabeled in zero 117 | # as long as segmentation is uint8 shifting to right handles the 118 | # latter too 119 | assert segmentation.dtype == np.uint8 120 | segmentation = segmentation + 1 121 | 122 | processed = self.preprocessor(image=image, segmentation=segmentation) 123 | 124 | image, segmentation = processed["image"], processed["segmentation"] 125 | else: 126 | image = self.preprocessor(image=image,)['image'] 127 | 128 | image = (image / 127.5 - 1.0).astype(np.float32) 129 | if segmentation_path: 130 | if self.onehot: 131 | assert segmentation.dtype == np.uint8 132 | # make it one hot 133 | n_labels = 183 134 | flatseg = np.ravel(segmentation) 135 | onehot = np.zeros((flatseg.size, n_labels), dtype=np.bool) 136 | onehot[np.arange(flatseg.size), flatseg] = True 137 | onehot = onehot.reshape(segmentation.shape + (n_labels,)).astype(int) 138 | segmentation = onehot 139 | else: 140 | segmentation = (segmentation / 127.5 - 1.0).astype(np.float32) 141 | return image, segmentation 142 | else: 143 | return image 144 | 145 | def __getitem__(self, i): 146 | img_path = self.img_id_to_filepath[self.labels["image_ids"][i]] 147 | if self.use_segmenation: 148 | seg_path = self.img_id_to_segmentation_filepath[self.labels["image_ids"][i]] 149 | image, segmentation = self.preprocess_image(img_path, seg_path) 150 | else: 151 | image = self.preprocess_image(img_path) 152 | captions = self.img_id_to_captions[self.labels["image_ids"][i]] 153 | # randomly draw one of all available captions per image 154 | caption = captions[np.random.randint(0, len(captions))] 155 | example = {"image": image, 156 | #"caption": [str(caption[0])], 157 | "caption": caption, 158 | "img_path": img_path, 159 | "filename_": img_path.split(os.sep)[-1] 160 | } 161 | if self.use_segmenation: 162 | example.update({"seg_path": seg_path, 'segmentation': segmentation}) 163 | return example 164 | 165 | 166 | class CocoImagesAndCaptionsTrain2017(CocoBase): 167 | """returns a pair of (image, caption)""" 168 | def __init__(self, size, onehot_segmentation=False, use_stuffthing=False, crop_size=None, force_no_crop=False,): 169 | super().__init__(size=size, 170 | dataroot="data/coco/train2017", 171 | datajson="data/coco/annotations/captions_train2017.json", 172 | onehot_segmentation=onehot_segmentation, 173 | use_stuffthing=use_stuffthing, crop_size=crop_size, force_no_crop=force_no_crop) 174 | 175 | def get_split(self): 176 | return "train" 177 | 178 | def year(self): 179 | return '2017' 180 | 181 | 182 | class CocoImagesAndCaptionsValidation2017(CocoBase): 183 | """returns a pair of (image, caption)""" 184 | def __init__(self, size, onehot_segmentation=False, use_stuffthing=False, crop_size=None, force_no_crop=False, 185 | given_files=None): 186 | super().__init__(size=size, 187 | dataroot="data/coco/val2017", 188 | datajson="data/coco/annotations/captions_val2017.json", 189 | onehot_segmentation=onehot_segmentation, 190 | use_stuffthing=use_stuffthing, crop_size=crop_size, force_no_crop=force_no_crop, 191 | given_files=given_files) 192 | 193 | def get_split(self): 194 | return "validation" 195 | 196 | def year(self): 197 | return '2017' 198 | 199 | 200 | 201 | class CocoImagesAndCaptionsTrain2014(CocoBase): 202 | """returns a pair of (image, caption)""" 203 | def __init__(self, size, onehot_segmentation=False, use_stuffthing=False, crop_size=None, 
force_no_crop=False,crop_type='random'): 204 | super().__init__(size=size, 205 | dataroot="data/coco/train2014", 206 | datajson="data/coco/annotations2014/annotations/captions_train2014.json", 207 | onehot_segmentation=onehot_segmentation, 208 | use_stuffthing=use_stuffthing, crop_size=crop_size, force_no_crop=force_no_crop, 209 | use_segmentation=False, 210 | crop_type=crop_type) 211 | 212 | def get_split(self): 213 | return "train" 214 | 215 | def year(self): 216 | return '2014' 217 | 218 | class CocoImagesAndCaptionsValidation2014(CocoBase): 219 | """returns a pair of (image, caption)""" 220 | def __init__(self, size, onehot_segmentation=False, use_stuffthing=False, crop_size=None, force_no_crop=False, 221 | given_files=None,crop_type='center',**kwargs): 222 | super().__init__(size=size, 223 | dataroot="data/coco/val2014", 224 | datajson="data/coco/annotations2014/annotations/captions_val2014.json", 225 | onehot_segmentation=onehot_segmentation, 226 | use_stuffthing=use_stuffthing, crop_size=crop_size, force_no_crop=force_no_crop, 227 | given_files=given_files, 228 | use_segmentation=False, 229 | crop_type=crop_type) 230 | 231 | def get_split(self): 232 | return "validation" 233 | 234 | def year(self): 235 | return '2014' 236 | 237 | if __name__ == '__main__': 238 | with open("data/coco/annotations2014/annotations/captions_val2014.json", "r") as json_file: 239 | json_data = json.load(json_file) 240 | capdirs = json_data["annotations"] 241 | import pudb; pudb.set_trace() 242 | #d2 = CocoImagesAndCaptionsTrain2014(size=256) 243 | d2 = CocoImagesAndCaptionsValidation2014(size=256) 244 | print("constructed dataset.") 245 | print(f"length of {d2.__class__.__name__}: {len(d2)}") 246 | 247 | ex2 = d2[0] 248 | # ex3 = d3[0] 249 | # print(ex1["image"].shape) 250 | print(ex2["image"].shape) 251 | # print(ex3["image"].shape) 252 | # print(ex1["segmentation"].shape) 253 | print(ex2["caption"].__class__.__name__) 254 | -------------------------------------------------------------------------------- /ldm/data/dummy.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import random 3 | import string 4 | from torch.utils.data import Dataset, Subset 5 | 6 | class DummyData(Dataset): 7 | def __init__(self, length, size): 8 | self.length = length 9 | self.size = size 10 | 11 | def __len__(self): 12 | return self.length 13 | 14 | def __getitem__(self, i): 15 | x = np.random.randn(*self.size) 16 | letters = string.ascii_lowercase 17 | y = ''.join(random.choice(string.ascii_lowercase) for i in range(10)) 18 | return {"jpg": x, "txt": y} 19 | 20 | 21 | class DummyDataWithEmbeddings(Dataset): 22 | def __init__(self, length, size, emb_size): 23 | self.length = length 24 | self.size = size 25 | self.emb_size = emb_size 26 | 27 | def __len__(self): 28 | return self.length 29 | 30 | def __getitem__(self, i): 31 | x = np.random.randn(*self.size) 32 | y = np.random.randn(*self.emb_size).astype(np.float32) 33 | return {"jpg": x, "txt": y} 34 | 35 | -------------------------------------------------------------------------------- /ldm/data/inpainting/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/object-edit/054bb09b8989754169a49bdce73bb767d745750b/ldm/data/inpainting/__init__.py -------------------------------------------------------------------------------- /ldm/data/inpainting/synthetic_mask.py: 
-------------------------------------------------------------------------------- 1 | from PIL import Image, ImageDraw 2 | import numpy as np 3 | 4 | settings = { 5 | "256narrow": { 6 | "p_irr": 1, 7 | "min_n_irr": 4, 8 | "max_n_irr": 50, 9 | "max_l_irr": 40, 10 | "max_w_irr": 10, 11 | "min_n_box": None, 12 | "max_n_box": None, 13 | "min_s_box": None, 14 | "max_s_box": None, 15 | "marg": None, 16 | }, 17 | "256train": { 18 | "p_irr": 0.5, 19 | "min_n_irr": 1, 20 | "max_n_irr": 5, 21 | "max_l_irr": 200, 22 | "max_w_irr": 100, 23 | "min_n_box": 1, 24 | "max_n_box": 4, 25 | "min_s_box": 30, 26 | "max_s_box": 150, 27 | "marg": 10, 28 | }, 29 | "512train": { # TODO: experimental 30 | "p_irr": 0.5, 31 | "min_n_irr": 1, 32 | "max_n_irr": 5, 33 | "max_l_irr": 450, 34 | "max_w_irr": 250, 35 | "min_n_box": 1, 36 | "max_n_box": 4, 37 | "min_s_box": 30, 38 | "max_s_box": 300, 39 | "marg": 10, 40 | }, 41 | "512train-large": { # TODO: experimental 42 | "p_irr": 0.5, 43 | "min_n_irr": 1, 44 | "max_n_irr": 5, 45 | "max_l_irr": 450, 46 | "max_w_irr": 400, 47 | "min_n_box": 1, 48 | "max_n_box": 4, 49 | "min_s_box": 75, 50 | "max_s_box": 450, 51 | "marg": 10, 52 | }, 53 | } 54 | 55 | 56 | def gen_segment_mask(mask, start, end, brush_width): 57 | mask = mask > 0 58 | mask = (255 * mask).astype(np.uint8) 59 | mask = Image.fromarray(mask) 60 | draw = ImageDraw.Draw(mask) 61 | draw.line([start, end], fill=255, width=brush_width, joint="curve") 62 | mask = np.array(mask) / 255 63 | return mask 64 | 65 | 66 | def gen_box_mask(mask, masked): 67 | x_0, y_0, w, h = masked 68 | mask[y_0:y_0 + h, x_0:x_0 + w] = 1 69 | return mask 70 | 71 | 72 | def gen_round_mask(mask, masked, radius): 73 | x_0, y_0, w, h = masked 74 | xy = [(x_0, y_0), (x_0 + w, y_0 + w)] 75 | 76 | mask = mask > 0 77 | mask = (255 * mask).astype(np.uint8) 78 | mask = Image.fromarray(mask) 79 | draw = ImageDraw.Draw(mask) 80 | draw.rounded_rectangle(xy, radius=radius, fill=255) 81 | mask = np.array(mask) / 255 82 | return mask 83 | 84 | 85 | def gen_large_mask(prng, img_h, img_w, 86 | marg, p_irr, min_n_irr, max_n_irr, max_l_irr, max_w_irr, 87 | min_n_box, max_n_box, min_s_box, max_s_box): 88 | """ 89 | img_h: int, an image height 90 | img_w: int, an image width 91 | marg: int, a margin for a box starting coordinate 92 | p_irr: float, 0 <= p_irr <= 1, a probability of a polygonal chain mask 93 | 94 | min_n_irr: int, min number of segments 95 | max_n_irr: int, max number of segments 96 | max_l_irr: max length of a segment in polygonal chain 97 | max_w_irr: max width of a segment in polygonal chain 98 | 99 | min_n_box: int, min bound for the number of box primitives 100 | max_n_box: int, max bound for the number of box primitives 101 | min_s_box: int, min length of a box side 102 | max_s_box: int, max length of a box side 103 | """ 104 | 105 | mask = np.zeros((img_h, img_w)) 106 | uniform = prng.randint 107 | 108 | if np.random.uniform(0, 1) < p_irr: # generate polygonal chain 109 | n = uniform(min_n_irr, max_n_irr) # sample number of segments 110 | 111 | for _ in range(n): 112 | y = uniform(0, img_h) # sample a starting point 113 | x = uniform(0, img_w) 114 | 115 | a = uniform(0, 360) # sample angle 116 | l = uniform(10, max_l_irr) # sample segment length 117 | w = uniform(5, max_w_irr) # sample a segment width 118 | 119 | # draw segment starting from (x,y) to (x_,y_) using brush of width w 120 | x_ = x + l * np.sin(a) 121 | y_ = y + l * np.cos(a) 122 | 123 | mask = gen_segment_mask(mask, start=(x, y), end=(x_, y_), brush_width=w) 124 | x, y = x_, y_ 125 
| else: # generate Box masks 126 | n = uniform(min_n_box, max_n_box) # sample number of rectangles 127 | 128 | for _ in range(n): 129 | h = uniform(min_s_box, max_s_box) # sample box shape 130 | w = uniform(min_s_box, max_s_box) 131 | 132 | x_0 = uniform(marg, img_w - marg - w) # sample upper-left coordinates of box 133 | y_0 = uniform(marg, img_h - marg - h) 134 | 135 | if np.random.uniform(0, 1) < 0.5: 136 | mask = gen_box_mask(mask, masked=(x_0, y_0, w, h)) 137 | else: 138 | r = uniform(0, 60) # sample radius 139 | mask = gen_round_mask(mask, masked=(x_0, y_0, w, h), radius=r) 140 | return mask 141 | 142 | 143 | make_lama_mask = lambda prng, h, w: gen_large_mask(prng, h, w, **settings["256train"]) 144 | make_narrow_lama_mask = lambda prng, h, w: gen_large_mask(prng, h, w, **settings["256narrow"]) 145 | make_512_lama_mask = lambda prng, h, w: gen_large_mask(prng, h, w, **settings["512train"]) 146 | make_512_lama_mask_large = lambda prng, h, w: gen_large_mask(prng, h, w, **settings["512train-large"]) 147 | 148 | 149 | MASK_MODES = { 150 | "256train": make_lama_mask, 151 | "256narrow": make_narrow_lama_mask, 152 | "512train": make_512_lama_mask, 153 | "512train-large": make_512_lama_mask_large 154 | } 155 | 156 | if __name__ == "__main__": 157 | import sys 158 | 159 | out = sys.argv[1] 160 | 161 | prng = np.random.RandomState(1) 162 | kwargs = settings["256train"] 163 | mask = gen_large_mask(prng, 256, 256, **kwargs) 164 | mask = (255 * mask).astype(np.uint8) 165 | mask = Image.fromarray(mask) 166 | mask.save(out) 167 | -------------------------------------------------------------------------------- /ldm/data/lsun.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import PIL 4 | from PIL import Image 5 | from torch.utils.data import Dataset 6 | from torchvision import transforms 7 | 8 | 9 | class LSUNBase(Dataset): 10 | def __init__(self, 11 | txt_file, 12 | data_root, 13 | size=None, 14 | interpolation="bicubic", 15 | flip_p=0.5 16 | ): 17 | self.data_paths = txt_file 18 | self.data_root = data_root 19 | with open(self.data_paths, "r") as f: 20 | self.image_paths = f.read().splitlines() 21 | self._length = len(self.image_paths) 22 | self.labels = { 23 | "relative_file_path_": [l for l in self.image_paths], 24 | "file_path_": [os.path.join(self.data_root, l) 25 | for l in self.image_paths], 26 | } 27 | 28 | self.size = size 29 | self.interpolation = {"linear": PIL.Image.LINEAR, 30 | "bilinear": PIL.Image.BILINEAR, 31 | "bicubic": PIL.Image.BICUBIC, 32 | "lanczos": PIL.Image.LANCZOS, 33 | }[interpolation] 34 | self.flip = transforms.RandomHorizontalFlip(p=flip_p) 35 | 36 | def __len__(self): 37 | return self._length 38 | 39 | def __getitem__(self, i): 40 | example = dict((k, self.labels[k][i]) for k in self.labels) 41 | image = Image.open(example["file_path_"]) 42 | if not image.mode == "RGB": 43 | image = image.convert("RGB") 44 | 45 | # default to score-sde preprocessing 46 | img = np.array(image).astype(np.uint8) 47 | crop = min(img.shape[0], img.shape[1]) 48 | h, w, = img.shape[0], img.shape[1] 49 | img = img[(h - crop) // 2:(h + crop) // 2, 50 | (w - crop) // 2:(w + crop) // 2] 51 | 52 | image = Image.fromarray(img) 53 | if self.size is not None: 54 | image = image.resize((self.size, self.size), resample=self.interpolation) 55 | 56 | image = self.flip(image) 57 | image = np.array(image).astype(np.uint8) 58 | example["image"] = (image / 127.5 - 1.0).astype(np.float32) 59 | return example 60 | 61 | 62 | class 
LSUNChurchesTrain(LSUNBase): 63 | def __init__(self, **kwargs): 64 | super().__init__(txt_file="data/lsun/church_outdoor_train.txt", data_root="data/lsun/churches", **kwargs) 65 | 66 | 67 | class LSUNChurchesValidation(LSUNBase): 68 | def __init__(self, flip_p=0., **kwargs): 69 | super().__init__(txt_file="data/lsun/church_outdoor_val.txt", data_root="data/lsun/churches", 70 | flip_p=flip_p, **kwargs) 71 | 72 | 73 | class LSUNBedroomsTrain(LSUNBase): 74 | def __init__(self, **kwargs): 75 | super().__init__(txt_file="data/lsun/bedrooms_train.txt", data_root="data/lsun/bedrooms", **kwargs) 76 | 77 | 78 | class LSUNBedroomsValidation(LSUNBase): 79 | def __init__(self, flip_p=0.0, **kwargs): 80 | super().__init__(txt_file="data/lsun/bedrooms_val.txt", data_root="data/lsun/bedrooms", 81 | flip_p=flip_p, **kwargs) 82 | 83 | 84 | class LSUNCatsTrain(LSUNBase): 85 | def __init__(self, **kwargs): 86 | super().__init__(txt_file="data/lsun/cat_train.txt", data_root="data/lsun/cats", **kwargs) 87 | 88 | 89 | class LSUNCatsValidation(LSUNBase): 90 | def __init__(self, flip_p=0., **kwargs): 91 | super().__init__(txt_file="data/lsun/cat_val.txt", data_root="data/lsun/cats", 92 | flip_p=flip_p, **kwargs) 93 | -------------------------------------------------------------------------------- /ldm/data/nerf_like.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import Dataset 2 | import os 3 | import json 4 | import numpy as np 5 | import torch 6 | import imageio 7 | import math 8 | import cv2 9 | from torchvision import transforms 10 | 11 | def cartesian_to_spherical(xyz): 12 | ptsnew = np.hstack((xyz, np.zeros(xyz.shape))) 13 | xy = xyz[:,0]**2 + xyz[:,1]**2 14 | z = np.sqrt(xy + xyz[:,2]**2) 15 | theta = np.arctan2(np.sqrt(xy), xyz[:,2]) # for elevation angle defined from Z-axis down 16 | #ptsnew[:,4] = np.arctan2(xyz[:,2], np.sqrt(xy)) # for elevation angle defined from XY-plane up 17 | azimuth = np.arctan2(xyz[:,1], xyz[:,0]) 18 | return np.array([theta, azimuth, z]) 19 | 20 | 21 | def get_T(T_target, T_cond): 22 | theta_cond, azimuth_cond, z_cond = cartesian_to_spherical(T_cond[None, :]) 23 | theta_target, azimuth_target, z_target = cartesian_to_spherical(T_target[None, :]) 24 | 25 | d_theta = theta_target - theta_cond 26 | d_azimuth = (azimuth_target - azimuth_cond) % (2 * math.pi) 27 | d_z = z_target - z_cond 28 | 29 | d_T = torch.tensor([d_theta.item(), math.sin(d_azimuth.item()), math.cos(d_azimuth.item()), d_z.item()]) 30 | return d_T 31 | 32 | def get_spherical(T_target, T_cond): 33 | theta_cond, azimuth_cond, z_cond = cartesian_to_spherical(T_cond[None, :]) 34 | theta_target, azimuth_target, z_target = cartesian_to_spherical(T_target[None, :]) 35 | 36 | d_theta = theta_target - theta_cond 37 | d_azimuth = (azimuth_target - azimuth_cond) % (2 * math.pi) 38 | d_z = z_target - z_cond 39 | 40 | d_T = torch.tensor([math.degrees(d_theta.item()), math.degrees(d_azimuth.item()), d_z.item()]) 41 | return d_T 42 | 43 | class RTMV(Dataset): 44 | def __init__(self, root_dir='datasets/RTMV/google_scanned',\ 45 | first_K=64, resolution=256, load_target=False): 46 | self.root_dir = root_dir 47 | self.scene_list = sorted(next(os.walk(root_dir))[1]) 48 | self.resolution = resolution 49 | self.first_K = first_K 50 | self.load_target = load_target 51 | 52 | def __len__(self): 53 | return len(self.scene_list) 54 | 55 | def __getitem__(self, idx): 56 | scene_dir = os.path.join(self.root_dir, self.scene_list[idx]) 57 | with open(os.path.join(scene_dir, 
'transforms.json'), "r") as f: 58 | meta = json.load(f) 59 | imgs = [] 60 | poses = [] 61 | for i_img in range(self.first_K): 62 | meta_img = meta['frames'][i_img] 63 | 64 | if i_img == 0 or self.load_target: 65 | img_path = os.path.join(scene_dir, meta_img['file_path']) 66 | img = imageio.imread(img_path) 67 | img = cv2.resize(img, (self.resolution, self.resolution), interpolation = cv2.INTER_LINEAR) 68 | imgs.append(img) 69 | 70 | c2w = meta_img['transform_matrix'] 71 | poses.append(c2w) 72 | 73 | imgs = (np.array(imgs) / 255.).astype(np.float32) # (RGBA) imgs 74 | imgs = torch.tensor(self.blend_rgba(imgs)).permute(0, 3, 1, 2) 75 | imgs = imgs * 2 - 1. # convert to stable diffusion range 76 | poses = torch.tensor(np.array(poses).astype(np.float32)) 77 | return imgs, poses 78 | 79 | def blend_rgba(self, img): 80 | img = img[..., :3] * img[..., -1:] + (1. - img[..., -1:]) # blend A to RGB 81 | return img 82 | 83 | 84 | class GSO(Dataset): 85 | def __init__(self, root_dir='datasets/GoogleScannedObjects',\ 86 | split='val', first_K=5, resolution=256, load_target=False, name='render_mvs'): 87 | self.root_dir = root_dir 88 | with open(os.path.join(root_dir, '%s.json' % split), "r") as f: 89 | self.scene_list = json.load(f) 90 | self.resolution = resolution 91 | self.first_K = first_K 92 | self.load_target = load_target 93 | self.name = name 94 | 95 | def __len__(self): 96 | return len(self.scene_list) 97 | 98 | def __getitem__(self, idx): 99 | scene_dir = os.path.join(self.root_dir, self.scene_list[idx]) 100 | with open(os.path.join(scene_dir, 'transforms_%s.json' % self.name), "r") as f: 101 | meta = json.load(f) 102 | imgs = [] 103 | poses = [] 104 | for i_img in range(self.first_K): 105 | meta_img = meta['frames'][i_img] 106 | 107 | if i_img == 0 or self.load_target: 108 | img_path = os.path.join(scene_dir, meta_img['file_path']) 109 | img = imageio.imread(img_path) 110 | img = cv2.resize(img, (self.resolution, self.resolution), interpolation = cv2.INTER_LINEAR) 111 | imgs.append(img) 112 | 113 | c2w = meta_img['transform_matrix'] 114 | poses.append(c2w) 115 | 116 | imgs = (np.array(imgs) / 255.).astype(np.float32) # (RGBA) imgs 117 | mask = imgs[:, :, :, -1] 118 | imgs = torch.tensor(self.blend_rgba(imgs)).permute(0, 3, 1, 2) 119 | imgs = imgs * 2 - 1. # convert to stable diffusion range 120 | poses = torch.tensor(np.array(poses).astype(np.float32)) 121 | return imgs, poses 122 | 123 | def blend_rgba(self, img): 124 | img = img[..., :3] * img[..., -1:] + (1. 
- img[..., -1:]) # blend A to RGB 125 | return img 126 | 127 | class WILD(Dataset): 128 | def __init__(self, root_dir='data/nerf_wild',\ 129 | first_K=33, resolution=256, load_target=False): 130 | self.root_dir = root_dir 131 | self.scene_list = sorted(next(os.walk(root_dir))[1]) 132 | self.resolution = resolution 133 | self.first_K = first_K 134 | self.load_target = load_target 135 | 136 | def __len__(self): 137 | return len(self.scene_list) 138 | 139 | def __getitem__(self, idx): 140 | scene_dir = os.path.join(self.root_dir, self.scene_list[idx]) 141 | with open(os.path.join(scene_dir, 'transforms_train.json'), "r") as f: 142 | meta = json.load(f) 143 | imgs = [] 144 | poses = [] 145 | for i_img in range(self.first_K): 146 | meta_img = meta['frames'][i_img] 147 | 148 | if i_img == 0 or self.load_target: 149 | img_path = os.path.join(scene_dir, meta_img['file_path']) 150 | img = imageio.imread(img_path + '.png') 151 | img = cv2.resize(img, (self.resolution, self.resolution), interpolation = cv2.INTER_LINEAR) 152 | imgs.append(img) 153 | 154 | c2w = meta_img['transform_matrix'] 155 | poses.append(c2w) 156 | 157 | imgs = (np.array(imgs) / 255.).astype(np.float32) # (RGBA) imgs 158 | imgs = torch.tensor(self.blend_rgba(imgs)).permute(0, 3, 1, 2) 159 | imgs = imgs * 2 - 1. # convert to stable diffusion range 160 | poses = torch.tensor(np.array(poses).astype(np.float32)) 161 | return imgs, poses 162 | 163 | def blend_rgba(self, img): 164 | img = img[..., :3] * img[..., -1:] + (1. - img[..., -1:]) # blend A to RGB 165 | return img -------------------------------------------------------------------------------- /ldm/extras.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from omegaconf import OmegaConf 3 | import torch 4 | from ldm.util import instantiate_from_config 5 | import logging 6 | from contextlib import contextmanager 7 | 8 | from contextlib import contextmanager 9 | import logging 10 | 11 | @contextmanager 12 | def all_logging_disabled(highest_level=logging.CRITICAL): 13 | """ 14 | A context manager that will prevent any logging messages 15 | triggered during the body from being processed. 16 | 17 | :param highest_level: the maximum logging level in use. 18 | This would only need to be changed if a custom level greater than CRITICAL 19 | is defined. 20 | 21 | https://gist.github.com/simon-weber/7853144 22 | """ 23 | # two kind-of hacks here: 24 | # * can't get the highest logging level in effect => delegate to the user 25 | # * can't get the current module-level override => use an undocumented 26 | # (but non-private!) 
interface 27 | 28 | previous_level = logging.root.manager.disable 29 | 30 | logging.disable(highest_level) 31 | 32 | try: 33 | yield 34 | finally: 35 | logging.disable(previous_level) 36 | 37 | def load_training_dir(train_dir, device, epoch="last"): 38 | """Load a checkpoint and config from training directory""" 39 | train_dir = Path(train_dir) 40 | ckpt = list(train_dir.rglob(f"*{epoch}.ckpt")) 41 | assert len(ckpt) == 1, f"found {len(ckpt)} matching ckpt files" 42 | config = list(train_dir.rglob(f"*-project.yaml")) 43 | assert len(ckpt) > 0, f"didn't find any config in {train_dir}" 44 | if len(config) > 1: 45 | print(f"found {len(config)} matching config files") 46 | config = sorted(config)[-1] 47 | print(f"selecting {config}") 48 | else: 49 | config = config[0] 50 | 51 | 52 | config = OmegaConf.load(config) 53 | return load_model_from_config(config, ckpt[0], device) 54 | 55 | def load_model_from_config(config, ckpt, device="cpu", verbose=False): 56 | """Loads a model from config and a ckpt 57 | if config is a path will use omegaconf to load 58 | """ 59 | if isinstance(config, (str, Path)): 60 | config = OmegaConf.load(config) 61 | 62 | with all_logging_disabled(): 63 | print(f"Loading model from {ckpt}") 64 | pl_sd = torch.load(ckpt, map_location="cpu") 65 | global_step = pl_sd["global_step"] 66 | sd = pl_sd["state_dict"] 67 | model = instantiate_from_config(config.model) 68 | m, u = model.load_state_dict(sd, strict=False) 69 | if len(m) > 0 and verbose: 70 | print("missing keys:") 71 | print(m) 72 | if len(u) > 0 and verbose: 73 | print("unexpected keys:") 74 | model.to(device) 75 | model.eval() 76 | model.cond_stage_model.device = device 77 | return model -------------------------------------------------------------------------------- /ldm/guidance.py: -------------------------------------------------------------------------------- 1 | from typing import List, Tuple 2 | from scipy import interpolate 3 | import numpy as np 4 | import torch 5 | import matplotlib.pyplot as plt 6 | from IPython.display import clear_output 7 | import abc 8 | 9 | 10 | class GuideModel(torch.nn.Module, abc.ABC): 11 | def __init__(self) -> None: 12 | super().__init__() 13 | 14 | @abc.abstractmethod 15 | def preprocess(self, x_img): 16 | pass 17 | 18 | @abc.abstractmethod 19 | def compute_loss(self, inp): 20 | pass 21 | 22 | 23 | class Guider(torch.nn.Module): 24 | def __init__(self, sampler, guide_model, scale=1.0, verbose=False): 25 | """Apply classifier guidance 26 | 27 | Specify a guidance scale as either a scalar 28 | Or a schedule as a list of tuples t = 0->1 and scale, e.g. 
29 | [(0, 10), (0.5, 20), (1, 50)] 30 | """ 31 | super().__init__() 32 | self.sampler = sampler 33 | self.index = 0 34 | self.show = verbose 35 | self.guide_model = guide_model 36 | self.history = [] 37 | 38 | if isinstance(scale, (Tuple, List)): 39 | times = np.array([x[0] for x in scale]) 40 | values = np.array([x[1] for x in scale]) 41 | self.scale_schedule = {"times": times, "values": values} 42 | else: 43 | self.scale_schedule = float(scale) 44 | 45 | self.ddim_timesteps = sampler.ddim_timesteps 46 | self.ddpm_num_timesteps = sampler.ddpm_num_timesteps 47 | 48 | 49 | def get_scales(self): 50 | if isinstance(self.scale_schedule, float): 51 | return len(self.ddim_timesteps)*[self.scale_schedule] 52 | 53 | interpolater = interpolate.interp1d(self.scale_schedule["times"], self.scale_schedule["values"]) 54 | fractional_steps = np.array(self.ddim_timesteps)/self.ddpm_num_timesteps 55 | return interpolater(fractional_steps) 56 | 57 | def modify_score(self, model, e_t, x, t, c): 58 | 59 | # TODO look up index by t 60 | scale = self.get_scales()[self.index] 61 | 62 | if (scale == 0): 63 | return e_t 64 | 65 | sqrt_1ma = self.sampler.ddim_sqrt_one_minus_alphas[self.index].to(x.device) 66 | with torch.enable_grad(): 67 | x_in = x.detach().requires_grad_(True) 68 | pred_x0 = model.predict_start_from_noise(x_in, t=t, noise=e_t) 69 | x_img = model.first_stage_model.decode((1/0.18215)*pred_x0) 70 | 71 | inp = self.guide_model.preprocess(x_img) 72 | loss = self.guide_model.compute_loss(inp) 73 | grads = torch.autograd.grad(loss.sum(), x_in)[0] 74 | correction = grads * scale 75 | 76 | if self.show: 77 | clear_output(wait=True) 78 | print(loss.item(), scale, correction.abs().max().item(), e_t.abs().max().item()) 79 | self.history.append([loss.item(), scale, correction.min().item(), correction.max().item()]) 80 | plt.imshow((inp[0].detach().permute(1,2,0).clamp(-1,1).cpu()+1)/2) 81 | plt.axis('off') 82 | plt.show() 83 | plt.imshow(correction[0][0].detach().cpu()) 84 | plt.axis('off') 85 | plt.show() 86 | 87 | 88 | e_t_mod = e_t - sqrt_1ma*correction 89 | if self.show: 90 | fig, axs = plt.subplots(1, 3) 91 | axs[0].imshow(e_t[0][0].detach().cpu(), vmin=-2, vmax=+2) 92 | axs[1].imshow(e_t_mod[0][0].detach().cpu(), vmin=-2, vmax=+2) 93 | axs[2].imshow(correction[0][0].detach().cpu(), vmin=-2, vmax=+2) 94 | plt.show() 95 | self.index += 1 96 | return e_t_mod -------------------------------------------------------------------------------- /ldm/lr_scheduler.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | class LambdaWarmUpCosineScheduler: 5 | """ 6 | note: use with a base_lr of 1.0 7 | """ 8 | def __init__(self, warm_up_steps, lr_min, lr_max, lr_start, max_decay_steps, verbosity_interval=0): 9 | self.lr_warm_up_steps = warm_up_steps 10 | self.lr_start = lr_start 11 | self.lr_min = lr_min 12 | self.lr_max = lr_max 13 | self.lr_max_decay_steps = max_decay_steps 14 | self.last_lr = 0. 
15 | self.verbosity_interval = verbosity_interval 16 | 17 | def schedule(self, n, **kwargs): 18 | if self.verbosity_interval > 0: 19 | if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_lr}") 20 | if n < self.lr_warm_up_steps: 21 | lr = (self.lr_max - self.lr_start) / self.lr_warm_up_steps * n + self.lr_start 22 | self.last_lr = lr 23 | return lr 24 | else: 25 | t = (n - self.lr_warm_up_steps) / (self.lr_max_decay_steps - self.lr_warm_up_steps) 26 | t = min(t, 1.0) 27 | lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * ( 28 | 1 + np.cos(t * np.pi)) 29 | self.last_lr = lr 30 | return lr 31 | 32 | def __call__(self, n, **kwargs): 33 | return self.schedule(n,**kwargs) 34 | 35 | 36 | class LambdaWarmUpCosineScheduler2: 37 | """ 38 | supports repeated iterations, configurable via lists 39 | note: use with a base_lr of 1.0. 40 | """ 41 | def __init__(self, warm_up_steps, f_min, f_max, f_start, cycle_lengths, verbosity_interval=0): 42 | assert len(warm_up_steps) == len(f_min) == len(f_max) == len(f_start) == len(cycle_lengths) 43 | self.lr_warm_up_steps = warm_up_steps 44 | self.f_start = f_start 45 | self.f_min = f_min 46 | self.f_max = f_max 47 | self.cycle_lengths = cycle_lengths 48 | self.cum_cycles = np.cumsum([0] + list(self.cycle_lengths)) 49 | self.last_f = 0. 50 | self.verbosity_interval = verbosity_interval 51 | 52 | def find_in_interval(self, n): 53 | interval = 0 54 | for cl in self.cum_cycles[1:]: 55 | if n <= cl: 56 | return interval 57 | interval += 1 58 | 59 | def schedule(self, n, **kwargs): 60 | cycle = self.find_in_interval(n) 61 | n = n - self.cum_cycles[cycle] 62 | if self.verbosity_interval > 0: 63 | if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, " 64 | f"current cycle {cycle}") 65 | if n < self.lr_warm_up_steps[cycle]: 66 | f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle] 67 | self.last_f = f 68 | return f 69 | else: 70 | t = (n - self.lr_warm_up_steps[cycle]) / (self.cycle_lengths[cycle] - self.lr_warm_up_steps[cycle]) 71 | t = min(t, 1.0) 72 | f = self.f_min[cycle] + 0.5 * (self.f_max[cycle] - self.f_min[cycle]) * ( 73 | 1 + np.cos(t * np.pi)) 74 | self.last_f = f 75 | return f 76 | 77 | def __call__(self, n, **kwargs): 78 | return self.schedule(n, **kwargs) 79 | 80 | 81 | class LambdaLinearScheduler(LambdaWarmUpCosineScheduler2): 82 | 83 | def schedule(self, n, **kwargs): 84 | cycle = self.find_in_interval(n) 85 | n = n - self.cum_cycles[cycle] 86 | if self.verbosity_interval > 0: 87 | if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, " 88 | f"current cycle {cycle}") 89 | 90 | if n < self.lr_warm_up_steps[cycle]: 91 | f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle] 92 | self.last_f = f 93 | return f 94 | else: 95 | f = self.f_min[cycle] + (self.f_max[cycle] - self.f_min[cycle]) * (self.cycle_lengths[cycle] - n) / (self.cycle_lengths[cycle]) 96 | self.last_f = f 97 | return f 98 | 99 | -------------------------------------------------------------------------------- /ldm/models/__pycache__/autoencoder.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/object-edit/054bb09b8989754169a49bdce73bb767d745750b/ldm/models/__pycache__/autoencoder.cpython-39.pyc 
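The scheduler classes in ldm/lr_scheduler.py above (LambdaWarmUpCosineScheduler2, LambdaLinearScheduler) return a learning-rate multiplier rather than an absolute learning rate, which is why their docstrings note they should be used with a base_lr of 1.0. Below is a minimal usage sketch, illustrative only and not part of this repository: the model, optimizer, and parameter values are placeholders (chosen in the style commonly seen in latent-diffusion training configs), and the only assumption taken from the repo is the LambdaLinearScheduler class itself.

import torch
from ldm.lr_scheduler import LambdaLinearScheduler

model = torch.nn.Linear(4, 4)                        # hypothetical stand-in for the diffusion model
optimizer = torch.optim.AdamW(model.parameters(), lr=1.0e-4)

# Linear warm-up from f_start to f_max over warm_up_steps, then a slow linear
# decay towards f_min across one (effectively endless) cycle.
multiplier = LambdaLinearScheduler(
    warm_up_steps=[10000],
    f_start=[1.0e-6],
    f_max=[1.0],
    f_min=[1.0],
    cycle_lengths=[10000000000000],
)
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=multiplier.schedule)

for step in range(3):                                # toy training loop
    optimizer.zero_grad()
    model(torch.randn(2, 4)).sum().backward()
    optimizer.step()
    scheduler.step()                                 # effective lr = optimizer lr * multiplier.schedule(step)

Because LambdaLR multiplies the schedule value onto the optimizer's configured learning rate, the schedule output is a scale factor in [f_start, f_max]; setting the optimizer's lr to the desired peak learning rate and the schedule's f_max to 1.0 reproduces the "base_lr of 1.0" convention from the docstrings.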
-------------------------------------------------------------------------------- /ldm/models/diffusion/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/object-edit/054bb09b8989754169a49bdce73bb767d745750b/ldm/models/diffusion/__init__.py -------------------------------------------------------------------------------- /ldm/models/diffusion/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/object-edit/054bb09b8989754169a49bdce73bb767d745750b/ldm/models/diffusion/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /ldm/models/diffusion/__pycache__/ddim.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/object-edit/054bb09b8989754169a49bdce73bb767d745750b/ldm/models/diffusion/__pycache__/ddim.cpython-39.pyc -------------------------------------------------------------------------------- /ldm/models/diffusion/__pycache__/ddpm.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/object-edit/054bb09b8989754169a49bdce73bb767d745750b/ldm/models/diffusion/__pycache__/ddpm.cpython-39.pyc -------------------------------------------------------------------------------- /ldm/models/diffusion/__pycache__/sampling_util.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/object-edit/054bb09b8989754169a49bdce73bb767d745750b/ldm/models/diffusion/__pycache__/sampling_util.cpython-39.pyc -------------------------------------------------------------------------------- /ldm/models/diffusion/classifier.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import pytorch_lightning as pl 4 | from omegaconf import OmegaConf 5 | from torch.nn import functional as F 6 | from torch.optim import AdamW 7 | from torch.optim.lr_scheduler import LambdaLR 8 | from copy import deepcopy 9 | from einops import rearrange 10 | from glob import glob 11 | from natsort import natsorted 12 | 13 | from ldm.modules.diffusionmodules.openaimodel import EncoderUNetModel, UNetModel 14 | from ldm.util import log_txt_as_img, default, ismap, instantiate_from_config 15 | 16 | __models__ = { 17 | 'class_label': EncoderUNetModel, 18 | 'segmentation': UNetModel 19 | } 20 | 21 | 22 | def disabled_train(self, mode=True): 23 | """Overwrite model.train with this function to make sure train/eval mode 24 | does not change anymore.""" 25 | return self 26 | 27 | 28 | class NoisyLatentImageClassifier(pl.LightningModule): 29 | 30 | def __init__(self, 31 | diffusion_path, 32 | num_classes, 33 | ckpt_path=None, 34 | pool='attention', 35 | label_key=None, 36 | diffusion_ckpt_path=None, 37 | scheduler_config=None, 38 | weight_decay=1.e-2, 39 | log_steps=10, 40 | monitor='val/loss', 41 | *args, 42 | **kwargs): 43 | super().__init__(*args, **kwargs) 44 | self.num_classes = num_classes 45 | # get latest config of diffusion model 46 | diffusion_config = natsorted(glob(os.path.join(diffusion_path, 'configs', '*-project.yaml')))[-1] 47 | self.diffusion_config = OmegaConf.load(diffusion_config).model 48 | self.diffusion_config.params.ckpt_path = diffusion_ckpt_path 49 | 
self.load_diffusion() 50 | 51 | self.monitor = monitor 52 | self.numd = self.diffusion_model.first_stage_model.encoder.num_resolutions - 1 53 | self.log_time_interval = self.diffusion_model.num_timesteps // log_steps 54 | self.log_steps = log_steps 55 | 56 | self.label_key = label_key if not hasattr(self.diffusion_model, 'cond_stage_key') \ 57 | else self.diffusion_model.cond_stage_key 58 | 59 | assert self.label_key is not None, 'label_key neither in diffusion model nor in model.params' 60 | 61 | if self.label_key not in __models__: 62 | raise NotImplementedError() 63 | 64 | self.load_classifier(ckpt_path, pool) 65 | 66 | self.scheduler_config = scheduler_config 67 | self.use_scheduler = self.scheduler_config is not None 68 | self.weight_decay = weight_decay 69 | 70 | def init_from_ckpt(self, path, ignore_keys=list(), only_model=False): 71 | sd = torch.load(path, map_location="cpu") 72 | if "state_dict" in list(sd.keys()): 73 | sd = sd["state_dict"] 74 | keys = list(sd.keys()) 75 | for k in keys: 76 | for ik in ignore_keys: 77 | if k.startswith(ik): 78 | print("Deleting key {} from state_dict.".format(k)) 79 | del sd[k] 80 | missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict( 81 | sd, strict=False) 82 | print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys") 83 | if len(missing) > 0: 84 | print(f"Missing Keys: {missing}") 85 | if len(unexpected) > 0: 86 | print(f"Unexpected Keys: {unexpected}") 87 | 88 | def load_diffusion(self): 89 | model = instantiate_from_config(self.diffusion_config) 90 | self.diffusion_model = model.eval() 91 | self.diffusion_model.train = disabled_train 92 | for param in self.diffusion_model.parameters(): 93 | param.requires_grad = False 94 | 95 | def load_classifier(self, ckpt_path, pool): 96 | model_config = deepcopy(self.diffusion_config.params.unet_config.params) 97 | model_config.in_channels = self.diffusion_config.params.unet_config.params.out_channels 98 | model_config.out_channels = self.num_classes 99 | if self.label_key == 'class_label': 100 | model_config.pool = pool 101 | 102 | self.model = __models__[self.label_key](**model_config) 103 | if ckpt_path is not None: 104 | print('#####################################################################') 105 | print(f'load from ckpt "{ckpt_path}"') 106 | print('#####################################################################') 107 | self.init_from_ckpt(ckpt_path) 108 | 109 | @torch.no_grad() 110 | def get_x_noisy(self, x, t, noise=None): 111 | noise = default(noise, lambda: torch.randn_like(x)) 112 | continuous_sqrt_alpha_cumprod = None 113 | if self.diffusion_model.use_continuous_noise: 114 | continuous_sqrt_alpha_cumprod = self.diffusion_model.sample_continuous_noise_level(x.shape[0], t + 1) 115 | # todo: make sure t+1 is correct here 116 | 117 | return self.diffusion_model.q_sample(x_start=x, t=t, noise=noise, 118 | continuous_sqrt_alpha_cumprod=continuous_sqrt_alpha_cumprod) 119 | 120 | def forward(self, x_noisy, t, *args, **kwargs): 121 | return self.model(x_noisy, t) 122 | 123 | @torch.no_grad() 124 | def get_input(self, batch, k): 125 | x = batch[k] 126 | if len(x.shape) == 3: 127 | x = x[..., None] 128 | x = rearrange(x, 'b h w c -> b c h w') 129 | x = x.to(memory_format=torch.contiguous_format).float() 130 | return x 131 | 132 | @torch.no_grad() 133 | def get_conditioning(self, batch, k=None): 134 | if k is None: 135 | k = self.label_key 136 | assert k is not None, 'Needs to provide label key' 137 
| 138 | targets = batch[k].to(self.device) 139 | 140 | if self.label_key == 'segmentation': 141 | targets = rearrange(targets, 'b h w c -> b c h w') 142 | for down in range(self.numd): 143 | h, w = targets.shape[-2:] 144 | targets = F.interpolate(targets, size=(h // 2, w // 2), mode='nearest') 145 | 146 | # targets = rearrange(targets,'b c h w -> b h w c') 147 | 148 | return targets 149 | 150 | def compute_top_k(self, logits, labels, k, reduction="mean"): 151 | _, top_ks = torch.topk(logits, k, dim=1) 152 | if reduction == "mean": 153 | return (top_ks == labels[:, None]).float().sum(dim=-1).mean().item() 154 | elif reduction == "none": 155 | return (top_ks == labels[:, None]).float().sum(dim=-1) 156 | 157 | def on_train_epoch_start(self): 158 | # save some memory 159 | self.diffusion_model.model.to('cpu') 160 | 161 | @torch.no_grad() 162 | def write_logs(self, loss, logits, targets): 163 | log_prefix = 'train' if self.training else 'val' 164 | log = {} 165 | log[f"{log_prefix}/loss"] = loss.mean() 166 | log[f"{log_prefix}/acc@1"] = self.compute_top_k( 167 | logits, targets, k=1, reduction="mean" 168 | ) 169 | log[f"{log_prefix}/acc@5"] = self.compute_top_k( 170 | logits, targets, k=5, reduction="mean" 171 | ) 172 | 173 | self.log_dict(log, prog_bar=False, logger=True, on_step=self.training, on_epoch=True) 174 | self.log('loss', log[f"{log_prefix}/loss"], prog_bar=True, logger=False) 175 | self.log('global_step', self.global_step, logger=False, on_epoch=False, prog_bar=True) 176 | lr = self.optimizers().param_groups[0]['lr'] 177 | self.log('lr_abs', lr, on_step=True, logger=True, on_epoch=False, prog_bar=True) 178 | 179 | def shared_step(self, batch, t=None): 180 | x, *_ = self.diffusion_model.get_input(batch, k=self.diffusion_model.first_stage_key) 181 | targets = self.get_conditioning(batch) 182 | if targets.dim() == 4: 183 | targets = targets.argmax(dim=1) 184 | if t is None: 185 | t = torch.randint(0, self.diffusion_model.num_timesteps, (x.shape[0],), device=self.device).long() 186 | else: 187 | t = torch.full(size=(x.shape[0],), fill_value=t, device=self.device).long() 188 | x_noisy = self.get_x_noisy(x, t) 189 | logits = self(x_noisy, t) 190 | 191 | loss = F.cross_entropy(logits, targets, reduction='none') 192 | 193 | self.write_logs(loss.detach(), logits.detach(), targets.detach()) 194 | 195 | loss = loss.mean() 196 | return loss, logits, x_noisy, targets 197 | 198 | def training_step(self, batch, batch_idx): 199 | loss, *_ = self.shared_step(batch) 200 | return loss 201 | 202 | def reset_noise_accs(self): 203 | self.noisy_acc = {t: {'acc@1': [], 'acc@5': []} for t in 204 | range(0, self.diffusion_model.num_timesteps, self.diffusion_model.log_every_t)} 205 | 206 | def on_validation_start(self): 207 | self.reset_noise_accs() 208 | 209 | @torch.no_grad() 210 | def validation_step(self, batch, batch_idx): 211 | loss, *_ = self.shared_step(batch) 212 | 213 | for t in self.noisy_acc: 214 | _, logits, _, targets = self.shared_step(batch, t) 215 | self.noisy_acc[t]['acc@1'].append(self.compute_top_k(logits, targets, k=1, reduction='mean')) 216 | self.noisy_acc[t]['acc@5'].append(self.compute_top_k(logits, targets, k=5, reduction='mean')) 217 | 218 | return loss 219 | 220 | def configure_optimizers(self): 221 | optimizer = AdamW(self.model.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay) 222 | 223 | if self.use_scheduler: 224 | scheduler = instantiate_from_config(self.scheduler_config) 225 | 226 | print("Setting up LambdaLR scheduler...") 227 | scheduler = [ 228 | { 229 | 
'scheduler': LambdaLR(optimizer, lr_lambda=scheduler.schedule), 230 | 'interval': 'step', 231 | 'frequency': 1 232 | }] 233 | return [optimizer], scheduler 234 | 235 | return optimizer 236 | 237 | @torch.no_grad() 238 | def log_images(self, batch, N=8, *args, **kwargs): 239 | log = dict() 240 | x = self.get_input(batch, self.diffusion_model.first_stage_key) 241 | log['inputs'] = x 242 | 243 | y = self.get_conditioning(batch) 244 | 245 | if self.label_key == 'class_label': 246 | y = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"]) 247 | log['labels'] = y 248 | 249 | if ismap(y): 250 | log['labels'] = self.diffusion_model.to_rgb(y) 251 | 252 | for step in range(self.log_steps): 253 | current_time = step * self.log_time_interval 254 | 255 | _, logits, x_noisy, _ = self.shared_step(batch, t=current_time) 256 | 257 | log[f'inputs@t{current_time}'] = x_noisy 258 | 259 | pred = F.one_hot(logits.argmax(dim=1), num_classes=self.num_classes) 260 | pred = rearrange(pred, 'b h w c -> b c h w') 261 | 262 | log[f'pred@t{current_time}'] = self.diffusion_model.to_rgb(pred) 263 | 264 | for key in log: 265 | log[key] = log[key][:N] 266 | 267 | return log 268 | -------------------------------------------------------------------------------- /ldm/models/diffusion/plms.py: -------------------------------------------------------------------------------- 1 | """SAMPLING ONLY.""" 2 | 3 | import torch 4 | import numpy as np 5 | from tqdm import tqdm 6 | from functools import partial 7 | 8 | from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like 9 | from ldm.models.diffusion.sampling_util import norm_thresholding 10 | 11 | 12 | class PLMSSampler(object): 13 | def __init__(self, model, schedule="linear", **kwargs): 14 | super().__init__() 15 | self.model = model 16 | self.ddpm_num_timesteps = model.num_timesteps 17 | self.schedule = schedule 18 | 19 | def register_buffer(self, name, attr): 20 | if type(attr) == torch.Tensor: 21 | if attr.device != torch.device("cuda"): 22 | attr = attr.to(torch.device("cuda")) 23 | setattr(self, name, attr) 24 | 25 | def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True): 26 | if ddim_eta != 0: 27 | raise ValueError('ddim_eta must be 0 for PLMS') 28 | self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps, 29 | num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose) 30 | alphas_cumprod = self.model.alphas_cumprod 31 | assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep' 32 | to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device) 33 | 34 | self.register_buffer('betas', to_torch(self.model.betas)) 35 | self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod)) 36 | self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev)) 37 | 38 | # calculations for diffusion q(x_t | x_{t-1}) and others 39 | self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu()))) 40 | self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu()))) 41 | self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu()))) 42 | self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu()))) 43 | self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. 
/ alphas_cumprod.cpu() - 1))) 44 | 45 | # ddim sampling parameters 46 | ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(), 47 | ddim_timesteps=self.ddim_timesteps, 48 | eta=ddim_eta,verbose=verbose) 49 | self.register_buffer('ddim_sigmas', ddim_sigmas) 50 | self.register_buffer('ddim_alphas', ddim_alphas) 51 | self.register_buffer('ddim_alphas_prev', ddim_alphas_prev) 52 | self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas)) 53 | sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt( 54 | (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * ( 55 | 1 - self.alphas_cumprod / self.alphas_cumprod_prev)) 56 | self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps) 57 | 58 | @torch.no_grad() 59 | def sample(self, 60 | S, 61 | batch_size, 62 | shape, 63 | conditioning=None, 64 | callback=None, 65 | normals_sequence=None, 66 | img_callback=None, 67 | quantize_x0=False, 68 | eta=0., 69 | mask=None, 70 | x0=None, 71 | temperature=1., 72 | noise_dropout=0., 73 | score_corrector=None, 74 | corrector_kwargs=None, 75 | verbose=True, 76 | x_T=None, 77 | log_every_t=100, 78 | unconditional_guidance_scale=1., 79 | unconditional_conditioning=None, 80 | # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... 81 | dynamic_threshold=None, 82 | **kwargs 83 | ): 84 | if conditioning is not None: 85 | if isinstance(conditioning, dict): 86 | ctmp = conditioning[list(conditioning.keys())[0]] 87 | while isinstance(ctmp, list): ctmp = ctmp[0] 88 | cbs = ctmp.shape[0] 89 | if cbs != batch_size: 90 | print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}") 91 | else: 92 | if conditioning.shape[0] != batch_size: 93 | print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}") 94 | 95 | self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose) 96 | # sampling 97 | C, H, W = shape 98 | size = (batch_size, C, H, W) 99 | print(f'Data shape for PLMS sampling is {size}') 100 | 101 | samples, intermediates = self.plms_sampling(conditioning, size, 102 | callback=callback, 103 | img_callback=img_callback, 104 | quantize_denoised=quantize_x0, 105 | mask=mask, x0=x0, 106 | ddim_use_original_steps=False, 107 | noise_dropout=noise_dropout, 108 | temperature=temperature, 109 | score_corrector=score_corrector, 110 | corrector_kwargs=corrector_kwargs, 111 | x_T=x_T, 112 | log_every_t=log_every_t, 113 | unconditional_guidance_scale=unconditional_guidance_scale, 114 | unconditional_conditioning=unconditional_conditioning, 115 | dynamic_threshold=dynamic_threshold, 116 | ) 117 | return samples, intermediates 118 | 119 | @torch.no_grad() 120 | def plms_sampling(self, cond, shape, 121 | x_T=None, ddim_use_original_steps=False, 122 | callback=None, timesteps=None, quantize_denoised=False, 123 | mask=None, x0=None, img_callback=None, log_every_t=100, 124 | temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, 125 | unconditional_guidance_scale=1., unconditional_conditioning=None, 126 | dynamic_threshold=None): 127 | device = self.model.betas.device 128 | b = shape[0] 129 | if x_T is None: 130 | img = torch.randn(shape, device=device) 131 | else: 132 | img = x_T 133 | 134 | if timesteps is None: 135 | timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps 136 | elif timesteps is not None and not ddim_use_original_steps: 137 | subset_end = int(min(timesteps / 
self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1 138 | timesteps = self.ddim_timesteps[:subset_end] 139 | 140 | intermediates = {'x_inter': [img], 'pred_x0': [img]} 141 | time_range = list(reversed(range(0,timesteps))) if ddim_use_original_steps else np.flip(timesteps) 142 | total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0] 143 | print(f"Running PLMS Sampling with {total_steps} timesteps") 144 | 145 | iterator = tqdm(time_range, desc='PLMS Sampler', total=total_steps) 146 | old_eps = [] 147 | 148 | for i, step in enumerate(iterator): 149 | index = total_steps - i - 1 150 | ts = torch.full((b,), step, device=device, dtype=torch.long) 151 | ts_next = torch.full((b,), time_range[min(i + 1, len(time_range) - 1)], device=device, dtype=torch.long) 152 | 153 | if mask is not None: 154 | assert x0 is not None 155 | img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass? 156 | img = img_orig * mask + (1. - mask) * img 157 | 158 | outs = self.p_sample_plms(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps, 159 | quantize_denoised=quantize_denoised, temperature=temperature, 160 | noise_dropout=noise_dropout, score_corrector=score_corrector, 161 | corrector_kwargs=corrector_kwargs, 162 | unconditional_guidance_scale=unconditional_guidance_scale, 163 | unconditional_conditioning=unconditional_conditioning, 164 | old_eps=old_eps, t_next=ts_next, 165 | dynamic_threshold=dynamic_threshold) 166 | img, pred_x0, e_t = outs 167 | old_eps.append(e_t) 168 | if len(old_eps) >= 4: 169 | old_eps.pop(0) 170 | if callback: callback(i) 171 | if img_callback: img_callback(pred_x0, i) 172 | 173 | if index % log_every_t == 0 or index == total_steps - 1: 174 | intermediates['x_inter'].append(img) 175 | intermediates['pred_x0'].append(pred_x0) 176 | 177 | return img, intermediates 178 | 179 | @torch.no_grad() 180 | def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False, 181 | temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, 182 | unconditional_guidance_scale=1., unconditional_conditioning=None, old_eps=None, t_next=None, 183 | dynamic_threshold=None): 184 | b, *_, device = *x.shape, x.device 185 | 186 | def get_model_output(x, t): 187 | if unconditional_conditioning is None or unconditional_guidance_scale == 1.: 188 | e_t = self.model.apply_model(x, t, c) 189 | else: 190 | x_in = torch.cat([x] * 2) 191 | t_in = torch.cat([t] * 2) 192 | if isinstance(c, dict): 193 | assert isinstance(unconditional_conditioning, dict) 194 | c_in = dict() 195 | for k in c: 196 | if isinstance(c[k], list): 197 | c_in[k] = [torch.cat([ 198 | unconditional_conditioning[k][i], 199 | c[k][i]]) for i in range(len(c[k]))] 200 | else: 201 | c_in[k] = torch.cat([ 202 | unconditional_conditioning[k], 203 | c[k]]) 204 | else: 205 | c_in = torch.cat([unconditional_conditioning, c]) 206 | e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2) 207 | e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond) 208 | 209 | if score_corrector is not None: 210 | assert self.model.parameterization == "eps" 211 | e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs) 212 | 213 | return e_t 214 | 215 | alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas 216 | alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev 217 | sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod 
if use_original_steps else self.ddim_sqrt_one_minus_alphas 218 | sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas 219 | 220 | def get_x_prev_and_pred_x0(e_t, index): 221 | # select parameters corresponding to the currently considered timestep 222 | a_t = torch.full((b, 1, 1, 1), alphas[index], device=device) 223 | a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device) 224 | sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device) 225 | sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device) 226 | 227 | # current prediction for x_0 228 | pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt() 229 | if quantize_denoised: 230 | pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0) 231 | if dynamic_threshold is not None: 232 | pred_x0 = norm_thresholding(pred_x0, dynamic_threshold) 233 | # direction pointing to x_t 234 | dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t 235 | noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature 236 | if noise_dropout > 0.: 237 | noise = torch.nn.functional.dropout(noise, p=noise_dropout) 238 | x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise 239 | return x_prev, pred_x0 240 | 241 | e_t = get_model_output(x, t) 242 | if len(old_eps) == 0: 243 | # Pseudo Improved Euler (2nd order) 244 | x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t, index) 245 | e_t_next = get_model_output(x_prev, t_next) 246 | e_t_prime = (e_t + e_t_next) / 2 247 | elif len(old_eps) == 1: 248 | # 2nd order Pseudo Linear Multistep (Adams-Bashforth) 249 | e_t_prime = (3 * e_t - old_eps[-1]) / 2 250 | elif len(old_eps) == 2: 251 | # 3nd order Pseudo Linear Multistep (Adams-Bashforth) 252 | e_t_prime = (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12 253 | elif len(old_eps) >= 3: 254 | # 4nd order Pseudo Linear Multistep (Adams-Bashforth) 255 | e_t_prime = (55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3]) / 24 256 | 257 | x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index) 258 | 259 | return x_prev, pred_x0, e_t 260 | -------------------------------------------------------------------------------- /ldm/models/diffusion/sampling_util.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | 5 | def append_dims(x, target_dims): 6 | """Appends dimensions to the end of a tensor until it has target_dims dimensions. 7 | From https://github.com/crowsonkb/k-diffusion/blob/master/k_diffusion/utils.py""" 8 | dims_to_append = target_dims - x.ndim 9 | if dims_to_append < 0: 10 | raise ValueError(f'input has {x.ndim} dims but target_dims is {target_dims}, which is less') 11 | return x[(...,) + (None,) * dims_to_append] 12 | 13 | 14 | def renorm_thresholding(x0, value): 15 | # renorm 16 | pred_max = x0.max() 17 | pred_min = x0.min() 18 | pred_x0 = (x0 - pred_min) / (pred_max - pred_min) # 0 ... 1 19 | pred_x0 = 2 * pred_x0 - 1. # -1 ... 1 20 | 21 | s = torch.quantile( 22 | rearrange(pred_x0, 'b ... 
-> b (...)').abs(), 23 | value, 24 | dim=-1 25 | ) 26 | s.clamp_(min=1.0) 27 | s = s.view(-1, *((1,) * (pred_x0.ndim - 1))) 28 | 29 | # clip by threshold 30 | # pred_x0 = pred_x0.clamp(-s, s) / s # needs newer pytorch # TODO bring back to pure-gpu with min/max 31 | 32 | # temporary hack: numpy on cpu 33 | pred_x0 = np.clip(pred_x0.cpu().numpy(), -s.cpu().numpy(), s.cpu().numpy()) / s.cpu().numpy() 34 | pred_x0 = torch.tensor(pred_x0).to(self.model.device) 35 | 36 | # re.renorm 37 | pred_x0 = (pred_x0 + 1.) / 2. # 0 ... 1 38 | pred_x0 = (pred_max - pred_min) * pred_x0 + pred_min # orig range 39 | return pred_x0 40 | 41 | 42 | def norm_thresholding(x0, value): 43 | s = append_dims(x0.pow(2).flatten(1).mean(1).sqrt().clamp(min=value), x0.ndim) 44 | return x0 * (value / s) 45 | 46 | 47 | def spatial_norm_thresholding(x0, value): 48 | # b c h w 49 | s = x0.pow(2).mean(1, keepdim=True).sqrt().clamp(min=value) 50 | return x0 * (value / s) -------------------------------------------------------------------------------- /ldm/modules/__pycache__/attention.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/object-edit/054bb09b8989754169a49bdce73bb767d745750b/ldm/modules/__pycache__/attention.cpython-39.pyc -------------------------------------------------------------------------------- /ldm/modules/__pycache__/ema.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/object-edit/054bb09b8989754169a49bdce73bb767d745750b/ldm/modules/__pycache__/ema.cpython-39.pyc -------------------------------------------------------------------------------- /ldm/modules/__pycache__/x_transformer.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/object-edit/054bb09b8989754169a49bdce73bb767d745750b/ldm/modules/__pycache__/x_transformer.cpython-39.pyc -------------------------------------------------------------------------------- /ldm/modules/attention.py: -------------------------------------------------------------------------------- 1 | from inspect import isfunction 2 | import math 3 | import torch 4 | import torch.nn.functional as F 5 | from torch import nn, einsum 6 | from einops import rearrange, repeat 7 | 8 | from ldm.modules.diffusionmodules.util import checkpoint 9 | 10 | 11 | def exists(val): 12 | return val is not None 13 | 14 | 15 | def uniq(arr): 16 | return{el: True for el in arr}.keys() 17 | 18 | 19 | def default(val, d): 20 | if exists(val): 21 | return val 22 | return d() if isfunction(d) else d 23 | 24 | 25 | def max_neg_value(t): 26 | return -torch.finfo(t.dtype).max 27 | 28 | 29 | def init_(tensor): 30 | dim = tensor.shape[-1] 31 | std = 1 / math.sqrt(dim) 32 | tensor.uniform_(-std, std) 33 | return tensor 34 | 35 | 36 | # feedforward 37 | class GEGLU(nn.Module): 38 | def __init__(self, dim_in, dim_out): 39 | super().__init__() 40 | self.proj = nn.Linear(dim_in, dim_out * 2) 41 | 42 | def forward(self, x): 43 | x, gate = self.proj(x).chunk(2, dim=-1) 44 | return x * F.gelu(gate) 45 | 46 | 47 | class FeedForward(nn.Module): 48 | def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.): 49 | super().__init__() 50 | inner_dim = int(dim * mult) 51 | dim_out = default(dim_out, dim) 52 | project_in = nn.Sequential( 53 | nn.Linear(dim, inner_dim), 54 | nn.GELU() 55 | ) if not glu else GEGLU(dim, inner_dim) 56 | 57 | self.net = 
nn.Sequential( 58 | project_in, 59 | nn.Dropout(dropout), 60 | nn.Linear(inner_dim, dim_out) 61 | ) 62 | 63 | def forward(self, x): 64 | return self.net(x) 65 | 66 | 67 | def zero_module(module): 68 | """ 69 | Zero out the parameters of a module and return it. 70 | """ 71 | for p in module.parameters(): 72 | p.detach().zero_() 73 | return module 74 | 75 | 76 | def Normalize(in_channels): 77 | return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) 78 | 79 | 80 | class LinearAttention(nn.Module): 81 | def __init__(self, dim, heads=4, dim_head=32): 82 | super().__init__() 83 | self.heads = heads 84 | hidden_dim = dim_head * heads 85 | self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False) 86 | self.to_out = nn.Conv2d(hidden_dim, dim, 1) 87 | 88 | def forward(self, x): 89 | b, c, h, w = x.shape 90 | qkv = self.to_qkv(x) 91 | q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)', heads = self.heads, qkv=3) 92 | k = k.softmax(dim=-1) 93 | context = torch.einsum('bhdn,bhen->bhde', k, v) 94 | out = torch.einsum('bhde,bhdn->bhen', context, q) 95 | out = rearrange(out, 'b heads c (h w) -> b (heads c) h w', heads=self.heads, h=h, w=w) 96 | return self.to_out(out) 97 | 98 | 99 | class SpatialSelfAttention(nn.Module): 100 | def __init__(self, in_channels): 101 | super().__init__() 102 | self.in_channels = in_channels 103 | 104 | self.norm = Normalize(in_channels) 105 | self.q = torch.nn.Conv2d(in_channels, 106 | in_channels, 107 | kernel_size=1, 108 | stride=1, 109 | padding=0) 110 | self.k = torch.nn.Conv2d(in_channels, 111 | in_channels, 112 | kernel_size=1, 113 | stride=1, 114 | padding=0) 115 | self.v = torch.nn.Conv2d(in_channels, 116 | in_channels, 117 | kernel_size=1, 118 | stride=1, 119 | padding=0) 120 | self.proj_out = torch.nn.Conv2d(in_channels, 121 | in_channels, 122 | kernel_size=1, 123 | stride=1, 124 | padding=0) 125 | 126 | def forward(self, x): 127 | h_ = x 128 | h_ = self.norm(h_) 129 | q = self.q(h_) 130 | k = self.k(h_) 131 | v = self.v(h_) 132 | 133 | # compute attention 134 | b,c,h,w = q.shape 135 | q = rearrange(q, 'b c h w -> b (h w) c') 136 | k = rearrange(k, 'b c h w -> b c (h w)') 137 | w_ = torch.einsum('bij,bjk->bik', q, k) 138 | 139 | w_ = w_ * (int(c)**(-0.5)) 140 | w_ = torch.nn.functional.softmax(w_, dim=2) 141 | 142 | # attend to values 143 | v = rearrange(v, 'b c h w -> b c (h w)') 144 | w_ = rearrange(w_, 'b i j -> b j i') 145 | h_ = torch.einsum('bij,bjk->bik', v, w_) 146 | h_ = rearrange(h_, 'b c (h w) -> b c h w', h=h) 147 | h_ = self.proj_out(h_) 148 | 149 | return x+h_ 150 | 151 | 152 | class CrossAttention(nn.Module): 153 | def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.): 154 | super().__init__() 155 | inner_dim = dim_head * heads 156 | context_dim = default(context_dim, query_dim) 157 | 158 | self.scale = dim_head ** -0.5 159 | self.heads = heads 160 | 161 | self.to_q = nn.Linear(query_dim, inner_dim, bias=False) 162 | self.to_k = nn.Linear(context_dim, inner_dim, bias=False) 163 | self.to_v = nn.Linear(context_dim, inner_dim, bias=False) 164 | 165 | self.to_out = nn.Sequential( 166 | nn.Linear(inner_dim, query_dim), 167 | nn.Dropout(dropout) 168 | ) 169 | 170 | def forward(self, x, context=None, mask=None): 171 | h = self.heads 172 | 173 | q = self.to_q(x) 174 | context = default(context, x) 175 | k = self.to_k(context) 176 | v = self.to_v(context) 177 | 178 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v)) 179 | 180 | sim = einsum('b i 
d, b j d -> b i j', q, k) * self.scale 181 | 182 | if exists(mask): 183 | mask = rearrange(mask, 'b ... -> b (...)') 184 | max_neg_value = -torch.finfo(sim.dtype).max 185 | mask = repeat(mask, 'b j -> (b h) () j', h=h) 186 | sim.masked_fill_(~mask, max_neg_value) 187 | 188 | # attention, what we cannot get enough of 189 | attn = sim.softmax(dim=-1) 190 | 191 | out = einsum('b i j, b j d -> b i d', attn, v) 192 | out = rearrange(out, '(b h) n d -> b n (h d)', h=h) 193 | return self.to_out(out) 194 | 195 | 196 | class BasicTransformerBlock(nn.Module): 197 | def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True, 198 | disable_self_attn=False): 199 | super().__init__() 200 | self.disable_self_attn = disable_self_attn 201 | self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout, 202 | context_dim=context_dim if self.disable_self_attn else None) # is a self-attention if not self.disable_self_attn 203 | self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff) 204 | self.attn2 = CrossAttention(query_dim=dim, context_dim=context_dim, 205 | heads=n_heads, dim_head=d_head, dropout=dropout) # is self-attn if context is none 206 | self.norm1 = nn.LayerNorm(dim) 207 | self.norm2 = nn.LayerNorm(dim) 208 | self.norm3 = nn.LayerNorm(dim) 209 | self.checkpoint = checkpoint 210 | 211 | def forward(self, x, context=None): 212 | return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint) 213 | 214 | def _forward(self, x, context=None): 215 | x = self.attn1(self.norm1(x), context=context if self.disable_self_attn else None) + x 216 | x = self.attn2(self.norm2(x), context=context) + x 217 | x = self.ff(self.norm3(x)) + x 218 | return x 219 | 220 | 221 | class SpatialTransformer(nn.Module): 222 | """ 223 | Transformer block for image-like data. 224 | First, project the input (aka embedding) 225 | and reshape to b, t, d. 226 | Then apply standard transformer action. 
227 | Finally, reshape to image 228 | """ 229 | def __init__(self, in_channels, n_heads, d_head, 230 | depth=1, dropout=0., context_dim=None, 231 | disable_self_attn=False): 232 | super().__init__() 233 | self.in_channels = in_channels 234 | inner_dim = n_heads * d_head 235 | self.norm = Normalize(in_channels) 236 | 237 | self.proj_in = nn.Conv2d(in_channels, 238 | inner_dim, 239 | kernel_size=1, 240 | stride=1, 241 | padding=0) 242 | 243 | self.transformer_blocks = nn.ModuleList( 244 | [BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim, 245 | disable_self_attn=disable_self_attn) 246 | for d in range(depth)] 247 | ) 248 | 249 | self.proj_out = zero_module(nn.Conv2d(inner_dim, 250 | in_channels, 251 | kernel_size=1, 252 | stride=1, 253 | padding=0)) 254 | 255 | def forward(self, x, context=None): 256 | # note: if no context is given, cross-attention defaults to self-attention 257 | b, c, h, w = x.shape 258 | x_in = x 259 | x = self.norm(x) 260 | x = self.proj_in(x) 261 | x = rearrange(x, 'b c h w -> b (h w) c').contiguous() 262 | for block in self.transformer_blocks: 263 | x = block(x, context=context) 264 | x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w).contiguous() 265 | x = self.proj_out(x) 266 | return x + x_in 267 | -------------------------------------------------------------------------------- /ldm/modules/diffusionmodules/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/object-edit/054bb09b8989754169a49bdce73bb767d745750b/ldm/modules/diffusionmodules/__init__.py -------------------------------------------------------------------------------- /ldm/modules/diffusionmodules/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/object-edit/054bb09b8989754169a49bdce73bb767d745750b/ldm/modules/diffusionmodules/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /ldm/modules/diffusionmodules/__pycache__/model.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/object-edit/054bb09b8989754169a49bdce73bb767d745750b/ldm/modules/diffusionmodules/__pycache__/model.cpython-39.pyc -------------------------------------------------------------------------------- /ldm/modules/diffusionmodules/__pycache__/openaimodel.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/object-edit/054bb09b8989754169a49bdce73bb767d745750b/ldm/modules/diffusionmodules/__pycache__/openaimodel.cpython-39.pyc -------------------------------------------------------------------------------- /ldm/modules/diffusionmodules/__pycache__/util.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/object-edit/054bb09b8989754169a49bdce73bb767d745750b/ldm/modules/diffusionmodules/__pycache__/util.cpython-39.pyc -------------------------------------------------------------------------------- /ldm/modules/diffusionmodules/util.py: -------------------------------------------------------------------------------- 1 | # adopted from 2 | # https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py 3 | # and 4 | # 
https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py 5 | # and 6 | # https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py 7 | # 8 | # thanks! 9 | 10 | 11 | import os 12 | import math 13 | import torch 14 | import torch.nn as nn 15 | import numpy as np 16 | from einops import repeat 17 | 18 | from ldm.util import instantiate_from_config 19 | 20 | 21 | def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3): 22 | if schedule == "linear": 23 | betas = ( 24 | torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2 25 | ) 26 | 27 | elif schedule == "cosine": 28 | timesteps = ( 29 | torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s 30 | ) 31 | alphas = timesteps / (1 + cosine_s) * np.pi / 2 32 | alphas = torch.cos(alphas).pow(2) 33 | alphas = alphas / alphas[0] 34 | betas = 1 - alphas[1:] / alphas[:-1] 35 | betas = np.clip(betas, a_min=0, a_max=0.999) 36 | 37 | elif schedule == "sqrt_linear": 38 | betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) 39 | elif schedule == "sqrt": 40 | betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) ** 0.5 41 | else: 42 | raise ValueError(f"schedule '{schedule}' unknown.") 43 | return betas.numpy() 44 | 45 | 46 | def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True): 47 | if ddim_discr_method == 'uniform': 48 | c = num_ddpm_timesteps // num_ddim_timesteps 49 | ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c))) 50 | elif ddim_discr_method == 'quad': 51 | ddim_timesteps = ((np.linspace(0, np.sqrt(num_ddpm_timesteps * .8), num_ddim_timesteps)) ** 2).astype(int) 52 | else: 53 | raise NotImplementedError(f'There is no ddim discretization method called "{ddim_discr_method}"') 54 | 55 | # assert ddim_timesteps.shape[0] == num_ddim_timesteps 56 | # add one to get the final alpha values right (the ones from first scale to data during sampling) 57 | steps_out = ddim_timesteps + 1 58 | if verbose: 59 | print(f'Selected timesteps for ddim sampler: {steps_out}') 60 | return steps_out 61 | 62 | 63 | def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True): 64 | # select alphas for computing the variance schedule 65 | alphas = alphacums[ddim_timesteps] 66 | alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist()) 67 | 68 | # according the the formula provided in https://arxiv.org/abs/2010.02502 69 | sigmas = eta * np.sqrt((1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev)) 70 | if verbose: 71 | print(f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}') 72 | print(f'For the chosen value of eta, which is {eta}, ' 73 | f'this results in the following sigma_t schedule for ddim sampler {sigmas}') 74 | return sigmas, alphas, alphas_prev 75 | 76 | 77 | def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999): 78 | """ 79 | Create a beta schedule that discretizes the given alpha_t_bar function, 80 | which defines the cumulative product of (1-beta) over time from t = [0,1]. 81 | :param num_diffusion_timesteps: the number of betas to produce. 
82 | :param alpha_bar: a lambda that takes an argument t from 0 to 1 and 83 | produces the cumulative product of (1-beta) up to that 84 | part of the diffusion process. 85 | :param max_beta: the maximum beta to use; use values lower than 1 to 86 | prevent singularities. 87 | """ 88 | betas = [] 89 | for i in range(num_diffusion_timesteps): 90 | t1 = i / num_diffusion_timesteps 91 | t2 = (i + 1) / num_diffusion_timesteps 92 | betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) 93 | return np.array(betas) 94 | 95 | 96 | def extract_into_tensor(a, t, x_shape): 97 | b, *_ = t.shape 98 | out = a.gather(-1, t) 99 | return out.reshape(b, *((1,) * (len(x_shape) - 1))) 100 | 101 | 102 | def checkpoint(func, inputs, params, flag): 103 | """ 104 | Evaluate a function without caching intermediate activations, allowing for 105 | reduced memory at the expense of extra compute in the backward pass. 106 | :param func: the function to evaluate. 107 | :param inputs: the argument sequence to pass to `func`. 108 | :param params: a sequence of parameters `func` depends on but does not 109 | explicitly take as arguments. 110 | :param flag: if False, disable gradient checkpointing. 111 | """ 112 | if flag: 113 | args = tuple(inputs) + tuple(params) 114 | return CheckpointFunction.apply(func, len(inputs), *args) 115 | else: 116 | return func(*inputs) 117 | 118 | 119 | class CheckpointFunction(torch.autograd.Function): 120 | @staticmethod 121 | def forward(ctx, run_function, length, *args): 122 | ctx.run_function = run_function 123 | ctx.input_tensors = list(args[:length]) 124 | ctx.input_params = list(args[length:]) 125 | 126 | with torch.no_grad(): 127 | output_tensors = ctx.run_function(*ctx.input_tensors) 128 | return output_tensors 129 | 130 | @staticmethod 131 | def backward(ctx, *output_grads): 132 | ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors] 133 | with torch.enable_grad(): 134 | # Fixes a bug where the first op in run_function modifies the 135 | # Tensor storage in place, which is not allowed for detach()'d 136 | # Tensors. 137 | shallow_copies = [x.view_as(x) for x in ctx.input_tensors] 138 | output_tensors = ctx.run_function(*shallow_copies) 139 | input_grads = torch.autograd.grad( 140 | output_tensors, 141 | ctx.input_tensors + ctx.input_params, 142 | output_grads, 143 | allow_unused=True, 144 | ) 145 | del ctx.input_tensors 146 | del ctx.input_params 147 | del output_tensors 148 | return (None, None) + input_grads 149 | 150 | 151 | def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False): 152 | """ 153 | Create sinusoidal timestep embeddings. 154 | :param timesteps: a 1-D Tensor of N indices, one per batch element. 155 | These may be fractional. 156 | :param dim: the dimension of the output. 157 | :param max_period: controls the minimum frequency of the embeddings. 158 | :return: an [N x dim] Tensor of positional embeddings. 
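    Note: with repeat_only=True the raw timestep value is simply repeated across
    `dim` instead of being sinusoidally encoded.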
159 | """ 160 | if not repeat_only: 161 | half = dim // 2 162 | freqs = torch.exp( 163 | -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half 164 | ).to(device=timesteps.device) 165 | args = timesteps[:, None].float() * freqs[None] 166 | embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) 167 | if dim % 2: 168 | embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) 169 | else: 170 | embedding = repeat(timesteps, 'b -> b d', d=dim) 171 | return embedding 172 | 173 | 174 | def zero_module(module): 175 | """ 176 | Zero out the parameters of a module and return it. 177 | """ 178 | for p in module.parameters(): 179 | p.detach().zero_() 180 | return module 181 | 182 | 183 | def scale_module(module, scale): 184 | """ 185 | Scale the parameters of a module and return it. 186 | """ 187 | for p in module.parameters(): 188 | p.detach().mul_(scale) 189 | return module 190 | 191 | 192 | def mean_flat(tensor): 193 | """ 194 | Take the mean over all non-batch dimensions. 195 | """ 196 | return tensor.mean(dim=list(range(1, len(tensor.shape)))) 197 | 198 | 199 | def normalization(channels): 200 | """ 201 | Make a standard normalization layer. 202 | :param channels: number of input channels. 203 | :return: an nn.Module for normalization. 204 | """ 205 | return GroupNorm32(32, channels) 206 | 207 | 208 | # PyTorch 1.7 has SiLU, but we support PyTorch 1.5. 209 | class SiLU(nn.Module): 210 | def forward(self, x): 211 | return x * torch.sigmoid(x) 212 | 213 | 214 | class GroupNorm32(nn.GroupNorm): 215 | def forward(self, x): 216 | return super().forward(x.float()).type(x.dtype) 217 | 218 | def conv_nd(dims, *args, **kwargs): 219 | """ 220 | Create a 1D, 2D, or 3D convolution module. 221 | """ 222 | if dims == 1: 223 | return nn.Conv1d(*args, **kwargs) 224 | elif dims == 2: 225 | return nn.Conv2d(*args, **kwargs) 226 | elif dims == 3: 227 | return nn.Conv3d(*args, **kwargs) 228 | raise ValueError(f"unsupported dimensions: {dims}") 229 | 230 | 231 | def linear(*args, **kwargs): 232 | """ 233 | Create a linear module. 234 | """ 235 | return nn.Linear(*args, **kwargs) 236 | 237 | 238 | def avg_pool_nd(dims, *args, **kwargs): 239 | """ 240 | Create a 1D, 2D, or 3D average pooling module. 
241 | """ 242 | if dims == 1: 243 | return nn.AvgPool1d(*args, **kwargs) 244 | elif dims == 2: 245 | return nn.AvgPool2d(*args, **kwargs) 246 | elif dims == 3: 247 | return nn.AvgPool3d(*args, **kwargs) 248 | raise ValueError(f"unsupported dimensions: {dims}") 249 | 250 | 251 | class HybridConditioner(nn.Module): 252 | 253 | def __init__(self, c_concat_config, c_crossattn_config): 254 | super().__init__() 255 | self.concat_conditioner = instantiate_from_config(c_concat_config) 256 | self.crossattn_conditioner = instantiate_from_config(c_crossattn_config) 257 | 258 | def forward(self, c_concat, c_crossattn): 259 | c_concat = self.concat_conditioner(c_concat) 260 | c_crossattn = self.crossattn_conditioner(c_crossattn) 261 | return {'c_concat': [c_concat], 'c_crossattn': [c_crossattn]} 262 | 263 | 264 | def noise_like(shape, device, repeat=False): 265 | repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1))) 266 | noise = lambda: torch.randn(shape, device=device) 267 | return repeat_noise() if repeat else noise() -------------------------------------------------------------------------------- /ldm/modules/distributions/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/object-edit/054bb09b8989754169a49bdce73bb767d745750b/ldm/modules/distributions/__init__.py -------------------------------------------------------------------------------- /ldm/modules/distributions/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/object-edit/054bb09b8989754169a49bdce73bb767d745750b/ldm/modules/distributions/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /ldm/modules/distributions/__pycache__/distributions.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/object-edit/054bb09b8989754169a49bdce73bb767d745750b/ldm/modules/distributions/__pycache__/distributions.cpython-39.pyc -------------------------------------------------------------------------------- /ldm/modules/distributions/distributions.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | 5 | class AbstractDistribution: 6 | def sample(self): 7 | raise NotImplementedError() 8 | 9 | def mode(self): 10 | raise NotImplementedError() 11 | 12 | 13 | class DiracDistribution(AbstractDistribution): 14 | def __init__(self, value): 15 | self.value = value 16 | 17 | def sample(self): 18 | return self.value 19 | 20 | def mode(self): 21 | return self.value 22 | 23 | 24 | class DiagonalGaussianDistribution(object): 25 | def __init__(self, parameters, deterministic=False): 26 | self.parameters = parameters 27 | self.mean, self.logvar = torch.chunk(parameters, 2, dim=1) 28 | self.logvar = torch.clamp(self.logvar, -30.0, 20.0) 29 | self.deterministic = deterministic 30 | self.std = torch.exp(0.5 * self.logvar) 31 | self.var = torch.exp(self.logvar) 32 | if self.deterministic: 33 | self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device) 34 | 35 | def sample(self): 36 | x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device) 37 | return x 38 | 39 | def kl(self, other=None): 40 | if self.deterministic: 41 | return torch.Tensor([0.]) 
42 | else: 43 | if other is None: 44 | return 0.5 * torch.sum(torch.pow(self.mean, 2) 45 | + self.var - 1.0 - self.logvar, 46 | dim=[1, 2, 3]) 47 | else: 48 | return 0.5 * torch.sum( 49 | torch.pow(self.mean - other.mean, 2) / other.var 50 | + self.var / other.var - 1.0 - self.logvar + other.logvar, 51 | dim=[1, 2, 3]) 52 | 53 | def nll(self, sample, dims=[1,2,3]): 54 | if self.deterministic: 55 | return torch.Tensor([0.]) 56 | logtwopi = np.log(2.0 * np.pi) 57 | return 0.5 * torch.sum( 58 | logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, 59 | dim=dims) 60 | 61 | def mode(self): 62 | return self.mean 63 | 64 | 65 | def normal_kl(mean1, logvar1, mean2, logvar2): 66 | """ 67 | source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12 68 | Compute the KL divergence between two gaussians. 69 | Shapes are automatically broadcasted, so batches can be compared to 70 | scalars, among other use cases. 71 | """ 72 | tensor = None 73 | for obj in (mean1, logvar1, mean2, logvar2): 74 | if isinstance(obj, torch.Tensor): 75 | tensor = obj 76 | break 77 | assert tensor is not None, "at least one argument must be a Tensor" 78 | 79 | # Force variances to be Tensors. Broadcasting helps convert scalars to 80 | # Tensors, but it does not work for torch.exp(). 81 | logvar1, logvar2 = [ 82 | x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor) 83 | for x in (logvar1, logvar2) 84 | ] 85 | 86 | return 0.5 * ( 87 | -1.0 88 | + logvar2 89 | - logvar1 90 | + torch.exp(logvar1 - logvar2) 91 | + ((mean1 - mean2) ** 2) * torch.exp(-logvar2) 92 | ) 93 | -------------------------------------------------------------------------------- /ldm/modules/ema.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | class LitEma(nn.Module): 6 | def __init__(self, model, decay=0.9999, use_num_upates=True): 7 | super().__init__() 8 | if decay < 0.0 or decay > 1.0: 9 | raise ValueError('Decay must be between 0 and 1') 10 | 11 | self.m_name2s_name = {} 12 | self.register_buffer('decay', torch.tensor(decay, dtype=torch.float32)) 13 | self.register_buffer('num_updates', torch.tensor(0,dtype=torch.int) if use_num_upates 14 | else torch.tensor(-1,dtype=torch.int)) 15 | 16 | for name, p in model.named_parameters(): 17 | if p.requires_grad: 18 | #remove as '.'-character is not allowed in buffers 19 | s_name = name.replace('.','') 20 | self.m_name2s_name.update({name:s_name}) 21 | self.register_buffer(s_name,p.clone().detach().data) 22 | 23 | self.collected_params = [] 24 | 25 | def forward(self,model): 26 | decay = self.decay 27 | 28 | if self.num_updates >= 0: 29 | self.num_updates += 1 30 | decay = min(self.decay,(1 + self.num_updates) / (10 + self.num_updates)) 31 | 32 | one_minus_decay = 1.0 - decay 33 | 34 | with torch.no_grad(): 35 | m_param = dict(model.named_parameters()) 36 | shadow_params = dict(self.named_buffers()) 37 | 38 | for key in m_param: 39 | if m_param[key].requires_grad: 40 | sname = self.m_name2s_name[key] 41 | shadow_params[sname] = shadow_params[sname].type_as(m_param[key]) 42 | shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key])) 43 | else: 44 | assert not key in self.m_name2s_name 45 | 46 | def copy_to(self, model): 47 | m_param = dict(model.named_parameters()) 48 | shadow_params = dict(self.named_buffers()) 49 | for key in m_param: 50 | if m_param[key].requires_grad: 51 | 
m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data) 52 | else: 53 | assert not key in self.m_name2s_name 54 | 55 | def store(self, parameters): 56 | """ 57 | Save the current parameters for restoring later. 58 | Args: 59 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 60 | temporarily stored. 61 | """ 62 | self.collected_params = [param.clone() for param in parameters] 63 | 64 | def restore(self, parameters): 65 | """ 66 | Restore the parameters stored with the `store` method. 67 | Useful to validate the model with EMA parameters without affecting the 68 | original optimization process. Store the parameters before the 69 | `copy_to` method. After validation (or model saving), use this to 70 | restore the former parameters. 71 | Args: 72 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 73 | updated with the stored parameters. 74 | """ 75 | for c_param, param in zip(self.collected_params, parameters): 76 | param.data.copy_(c_param.data) 77 | -------------------------------------------------------------------------------- /ldm/modules/encoders/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/object-edit/054bb09b8989754169a49bdce73bb767d745750b/ldm/modules/encoders/__init__.py -------------------------------------------------------------------------------- /ldm/modules/encoders/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/object-edit/054bb09b8989754169a49bdce73bb767d745750b/ldm/modules/encoders/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /ldm/modules/encoders/__pycache__/modules.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/object-edit/054bb09b8989754169a49bdce73bb767d745750b/ldm/modules/encoders/__pycache__/modules.cpython-39.pyc -------------------------------------------------------------------------------- /ldm/modules/evaluate/frechet_video_distance.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | # Lint as: python2, python3 17 | """Minimal Reference implementation for the Frechet Video Distance (FVD). 18 | 19 | FVD is a metric for the quality of video generation models. It is inspired by 20 | the FID (Frechet Inception Distance) used for images, but uses a different 21 | embedding to be better suitable for videos. 
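In this implementation the embedding is computed with an I3D network
pre-trained on Kinetics-400, loaded from TF-Hub.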
22 | """ 23 | 24 | from __future__ import absolute_import 25 | from __future__ import division 26 | from __future__ import print_function 27 | 28 | 29 | import six 30 | import tensorflow.compat.v1 as tf 31 | import tensorflow_gan as tfgan 32 | import tensorflow_hub as hub 33 | 34 | 35 | def preprocess(videos, target_resolution): 36 | """Runs some preprocessing on the videos for I3D model. 37 | 38 | Args: 39 | videos: [batch_size, num_frames, height, width, depth] The videos to be 40 | preprocessed. We don't care about the specific dtype of the videos, it can 41 | be anything that tf.image.resize_bilinear accepts. Values are expected to 42 | be in the range 0-255. 43 | target_resolution: (width, height): target video resolution 44 | 45 | Returns: 46 | videos: [batch_size, num_frames, height, width, depth] 47 | """ 48 | videos_shape = list(videos.shape) 49 | all_frames = tf.reshape(videos, [-1] + videos_shape[-3:]) 50 | resized_videos = tf.image.resize_bilinear(all_frames, size=target_resolution) 51 | target_shape = [videos_shape[0], -1] + list(target_resolution) + [3] 52 | output_videos = tf.reshape(resized_videos, target_shape) 53 | scaled_videos = 2. * tf.cast(output_videos, tf.float32) / 255. - 1 54 | return scaled_videos 55 | 56 | 57 | def _is_in_graph(tensor_name): 58 | """Checks whether a given tensor does exists in the graph.""" 59 | try: 60 | tf.get_default_graph().get_tensor_by_name(tensor_name) 61 | except KeyError: 62 | return False 63 | return True 64 | 65 | 66 | def create_id3_embedding(videos,warmup=False,batch_size=16): 67 | """Embeds the given videos using the Inflated 3D Convolution ne twork. 68 | 69 | Downloads the graph of the I3D from tf.hub and adds it to the graph on the 70 | first call. 71 | 72 | Args: 73 | videos: [batch_size, num_frames, height=224, width=224, depth=3]. 74 | Expected range is [-1, 1]. 75 | 76 | Returns: 77 | embedding: [batch_size, embedding_size]. embedding_size depends 78 | on the model used. 79 | 80 | Raises: 81 | ValueError: when a provided embedding_layer is not supported. 82 | """ 83 | 84 | # batch_size = 16 85 | module_spec = "https://tfhub.dev/deepmind/i3d-kinetics-400/1" 86 | 87 | 88 | # Making sure that we import the graph separately for 89 | # each different input video tensor. 90 | module_name = "fvd_kinetics-400_id3_module_" + six.ensure_str( 91 | videos.name).replace(":", "_") 92 | 93 | 94 | 95 | assert_ops = [ 96 | tf.Assert( 97 | tf.reduce_max(videos) <= 1.001, 98 | ["max value in frame is > 1", videos]), 99 | tf.Assert( 100 | tf.reduce_min(videos) >= -1.001, 101 | ["min value in frame is < -1", videos]), 102 | tf.assert_equal( 103 | tf.shape(videos)[0], 104 | batch_size, ["invalid frame batch size: ", 105 | tf.shape(videos)], 106 | summarize=6), 107 | ] 108 | with tf.control_dependencies(assert_ops): 109 | videos = tf.identity(videos) 110 | 111 | module_scope = "%s_apply_default/" % module_name 112 | 113 | # To check whether the module has already been loaded into the graph, we look 114 | # for a given tensor name. If this tensor name exists, we assume the function 115 | # has been called before and the graph was imported. Otherwise we import it. 116 | # Note: in theory, the tensor could exist, but have wrong shapes. 117 | # This will happen if create_id3_embedding is called with a frames_placehoder 118 | # of wrong size/batch size, because even though that will throw a tf.Assert 119 | # on graph-execution time, it will insert the tensor (with wrong shape) into 120 | # the graph. This is why we need the following assert. 
121 | if warmup: 122 | video_batch_size = int(videos.shape[0]) 123 | assert video_batch_size in [batch_size, -1, None], f"Invalid batch size {video_batch_size}" 124 | tensor_name = module_scope + "RGB/inception_i3d/Mean:0" 125 | if not _is_in_graph(tensor_name): 126 | i3d_model = hub.Module(module_spec, name=module_name) 127 | i3d_model(videos) 128 | 129 | # gets the kinetics-i3d-400-logits layer 130 | tensor_name = module_scope + "RGB/inception_i3d/Mean:0" 131 | tensor = tf.get_default_graph().get_tensor_by_name(tensor_name) 132 | return tensor 133 | 134 | 135 | def calculate_fvd(real_activations, 136 | generated_activations): 137 | """Returns a list of ops that compute metrics as funcs of activations. 138 | 139 | Args: 140 | real_activations: [num_samples, embedding_size] 141 | generated_activations: [num_samples, embedding_size] 142 | 143 | Returns: 144 | A scalar that contains the requested FVD. 145 | """ 146 | return tfgan.eval.frechet_classifier_distance_from_activations( 147 | real_activations, generated_activations) 148 | -------------------------------------------------------------------------------- /ldm/modules/evaluate/ssim.py: -------------------------------------------------------------------------------- 1 | # MIT Licence 2 | 3 | # Methods to predict the SSIM, taken from 4 | # https://github.com/Po-Hsun-Su/pytorch-ssim/blob/master/pytorch_ssim/__init__.py 5 | 6 | from math import exp 7 | 8 | import torch 9 | import torch.nn.functional as F 10 | from torch.autograd import Variable 11 | 12 | def gaussian(window_size, sigma): 13 | gauss = torch.Tensor( 14 | [ 15 | exp(-((x - window_size // 2) ** 2) / float(2 * sigma ** 2)) 16 | for x in range(window_size) 17 | ] 18 | ) 19 | return gauss / gauss.sum() 20 | 21 | 22 | def create_window(window_size, channel): 23 | _1D_window = gaussian(window_size, 1.5).unsqueeze(1) 24 | _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0) 25 | window = Variable( 26 | _2D_window.expand(channel, 1, window_size, window_size).contiguous() 27 | ) 28 | return window 29 | 30 | 31 | def _ssim( 32 | img1, img2, window, window_size, channel, mask=None, size_average=True 33 | ): 34 | mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel) 35 | mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel) 36 | 37 | mu1_sq = mu1.pow(2) 38 | mu2_sq = mu2.pow(2) 39 | mu1_mu2 = mu1 * mu2 40 | 41 | sigma1_sq = ( 42 | F.conv2d(img1 * img1, window, padding=window_size // 2, groups=channel) 43 | - mu1_sq 44 | ) 45 | sigma2_sq = ( 46 | F.conv2d(img2 * img2, window, padding=window_size // 2, groups=channel) 47 | - mu2_sq 48 | ) 49 | sigma12 = ( 50 | F.conv2d(img1 * img2, window, padding=window_size // 2, groups=channel) 51 | - mu1_mu2 52 | ) 53 | 54 | C1 = (0.01) ** 2 55 | C2 = (0.03) ** 2 56 | 57 | ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ( 58 | (mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2) 59 | ) 60 | 61 | if not (mask is None): 62 | b = mask.size(0) 63 | ssim_map = ssim_map.mean(dim=1, keepdim=True) * mask 64 | ssim_map = ssim_map.view(b, -1).sum(dim=1) / mask.view(b, -1).sum( 65 | dim=1 66 | ).clamp(min=1) 67 | return ssim_map 68 | 69 | import pdb 70 | 71 | pdb.set_trace 72 | 73 | if size_average: 74 | return ssim_map.mean() 75 | else: 76 | return ssim_map.mean(1).mean(1).mean(1) 77 | 78 | 79 | class SSIM(torch.nn.Module): 80 | def __init__(self, window_size=11, size_average=True): 81 | super(SSIM, self).__init__() 82 | self.window_size = window_size 83 | self.size_average = size_average 84 | 
self.channel = 1 85 | self.window = create_window(window_size, self.channel) 86 | 87 | def forward(self, img1, img2, mask=None): 88 | (_, channel, _, _) = img1.size() 89 | 90 | if ( 91 | channel == self.channel 92 | and self.window.data.type() == img1.data.type() 93 | ): 94 | window = self.window 95 | else: 96 | window = create_window(self.window_size, channel) 97 | 98 | if img1.is_cuda: 99 | window = window.cuda(img1.get_device()) 100 | window = window.type_as(img1) 101 | 102 | self.window = window 103 | self.channel = channel 104 | 105 | return _ssim( 106 | img1, 107 | img2, 108 | window, 109 | self.window_size, 110 | channel, 111 | mask, 112 | self.size_average, 113 | ) 114 | 115 | 116 | def ssim(img1, img2, window_size=11, mask=None, size_average=True): 117 | (_, channel, _, _) = img1.size() 118 | window = create_window(window_size, channel) 119 | 120 | if img1.is_cuda: 121 | window = window.cuda(img1.get_device()) 122 | window = window.type_as(img1) 123 | 124 | return _ssim(img1, img2, window, window_size, channel, mask, size_average) 125 | -------------------------------------------------------------------------------- /ldm/modules/evaluate/torch_frechet_video_distance.py: -------------------------------------------------------------------------------- 1 | # based on https://github.com/universome/fvd-comparison/blob/master/compare_models.py; huge thanks! 2 | import os 3 | import numpy as np 4 | import io 5 | import re 6 | import requests 7 | import html 8 | import hashlib 9 | import urllib 10 | import urllib.request 11 | import scipy.linalg 12 | import multiprocessing as mp 13 | import glob 14 | 15 | 16 | from tqdm import tqdm 17 | from typing import Any, List, Tuple, Union, Dict, Callable 18 | 19 | from torchvision.io import read_video 20 | import torch; torch.set_grad_enabled(False) 21 | from einops import rearrange 22 | 23 | from nitro.util import isvideo 24 | 25 | def compute_frechet_distance(mu_sample,sigma_sample,mu_ref,sigma_ref) -> float: 26 | print('Calculate frechet distance...') 27 | m = np.square(mu_sample - mu_ref).sum() 28 | s, _ = scipy.linalg.sqrtm(np.dot(sigma_sample, sigma_ref), disp=False) # pylint: disable=no-member 29 | fid = np.real(m + np.trace(sigma_sample + sigma_ref - s * 2)) 30 | 31 | return float(fid) 32 | 33 | 34 | def compute_stats(feats: np.ndarray) -> Tuple[np.ndarray, np.ndarray]: 35 | mu = feats.mean(axis=0) # [d] 36 | sigma = np.cov(feats, rowvar=False) # [d, d] 37 | 38 | return mu, sigma 39 | 40 | 41 | def open_url(url: str, num_attempts: int = 10, verbose: bool = True, return_filename: bool = False) -> Any: 42 | """Download the given URL and return a binary-mode file object to access the data.""" 43 | assert num_attempts >= 1 44 | 45 | # Doesn't look like an URL scheme so interpret it as a local filename. 46 | if not re.match('^[a-z]+://', url): 47 | return url if return_filename else open(url, "rb") 48 | 49 | # Handle file URLs. This code handles unusual file:// patterns that 50 | # arise on Windows: 51 | # 52 | # file:///c:/foo.txt 53 | # 54 | # which would translate to a local '/c:/foo.txt' filename that's 55 | # invalid. Drop the forward slash for such pathnames. 56 | # 57 | # If you touch this code path, you should test it on both Linux and 58 | # Windows. 59 | # 60 | # Some internet resources suggest using urllib.request.url2pathname() but 61 | # but that converts forward slashes to backslashes and this causes 62 | # its own set of problems. 
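    # Usage elsewhere in this module: the I3D detector checkpoint is fetched and
    # loaded once via
    #   with open_url(detector_url, verbose=False) as f:
    #       detector = torch.jit.load(f).eval().to(device)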
63 | if url.startswith('file://'): 64 | filename = urllib.parse.urlparse(url).path 65 | if re.match(r'^/[a-zA-Z]:', filename): 66 | filename = filename[1:] 67 | return filename if return_filename else open(filename, "rb") 68 | 69 | url_md5 = hashlib.md5(url.encode("utf-8")).hexdigest() 70 | 71 | # Download. 72 | url_name = None 73 | url_data = None 74 | with requests.Session() as session: 75 | if verbose: 76 | print("Downloading %s ..." % url, end="", flush=True) 77 | for attempts_left in reversed(range(num_attempts)): 78 | try: 79 | with session.get(url) as res: 80 | res.raise_for_status() 81 | if len(res.content) == 0: 82 | raise IOError("No data received") 83 | 84 | if len(res.content) < 8192: 85 | content_str = res.content.decode("utf-8") 86 | if "download_warning" in res.headers.get("Set-Cookie", ""): 87 | links = [html.unescape(link) for link in content_str.split('"') if "export=download" in link] 88 | if len(links) == 1: 89 | url = requests.compat.urljoin(url, links[0]) 90 | raise IOError("Google Drive virus checker nag") 91 | if "Google Drive - Quota exceeded" in content_str: 92 | raise IOError("Google Drive download quota exceeded -- please try again later") 93 | 94 | match = re.search(r'filename="([^"]*)"', res.headers.get("Content-Disposition", "")) 95 | url_name = match[1] if match else url 96 | url_data = res.content 97 | if verbose: 98 | print(" done") 99 | break 100 | except KeyboardInterrupt: 101 | raise 102 | except: 103 | if not attempts_left: 104 | if verbose: 105 | print(" failed") 106 | raise 107 | if verbose: 108 | print(".", end="", flush=True) 109 | 110 | # Return data as file object. 111 | assert not return_filename 112 | return io.BytesIO(url_data) 113 | 114 | def load_video(ip): 115 | vid, *_ = read_video(ip) 116 | vid = rearrange(vid, 't h w c -> t c h w').to(torch.uint8) 117 | return vid 118 | 119 | def get_data_from_str(input_str,nprc = None): 120 | assert os.path.isdir(input_str), f'Specified input folder "{input_str}" is not a directory' 121 | vid_filelist = glob.glob(os.path.join(input_str,'*.mp4')) 122 | print(f'Found {len(vid_filelist)} videos in dir {input_str}') 123 | 124 | if nprc is None: 125 | try: 126 | nprc = mp.cpu_count() 127 | except NotImplementedError: 128 | print('WARNING: cpu_count() not avlailable, using only 1 cpu for video loading') 129 | nprc = 1 130 | 131 | pool = mp.Pool(processes=nprc) 132 | 133 | vids = [] 134 | for v in tqdm(pool.imap_unordered(load_video,vid_filelist),total=len(vid_filelist),desc='Loading videos...'): 135 | vids.append(v) 136 | 137 | 138 | vids = torch.stack(vids,dim=0).float() 139 | 140 | return vids 141 | 142 | def get_stats(stats): 143 | assert os.path.isfile(stats) and stats.endswith('.npz'), f'no stats found under {stats}' 144 | 145 | print(f'Using precomputed statistics under {stats}') 146 | stats = np.load(stats) 147 | stats = {key: stats[key] for key in stats.files} 148 | 149 | return stats 150 | 151 | 152 | 153 | 154 | @torch.no_grad() 155 | def compute_fvd(ref_input, sample_input, bs=32, 156 | ref_stats=None, 157 | sample_stats=None, 158 | nprc_load=None): 159 | 160 | 161 | 162 | calc_stats = ref_stats is None or sample_stats is None 163 | 164 | if calc_stats: 165 | 166 | only_ref = sample_stats is not None 167 | only_sample = ref_stats is not None 168 | 169 | 170 | if isinstance(ref_input,str) and not only_sample: 171 | ref_input = get_data_from_str(ref_input,nprc_load) 172 | 173 | if isinstance(sample_input, str) and not only_ref: 174 | sample_input = get_data_from_str(sample_input, nprc_load) 175 | 176 
| stats = compute_statistics(sample_input,ref_input, 177 | device='cuda' if torch.cuda.is_available() else 'cpu', 178 | bs=bs, 179 | only_ref=only_ref, 180 | only_sample=only_sample) 181 | 182 | if only_ref: 183 | stats.update(get_stats(sample_stats)) 184 | elif only_sample: 185 | stats.update(get_stats(ref_stats)) 186 | 187 | 188 | 189 | else: 190 | stats = get_stats(sample_stats) 191 | stats.update(get_stats(ref_stats)) 192 | 193 | fvd = compute_frechet_distance(**stats) 194 | 195 | return {'FVD' : fvd,} 196 | 197 | 198 | @torch.no_grad() 199 | def compute_statistics(videos_fake, videos_real, device: str='cuda', bs=32, only_ref=False,only_sample=False) -> Dict: 200 | detector_url = 'https://www.dropbox.com/s/ge9e5ujwgetktms/i3d_torchscript.pt?dl=1' 201 | detector_kwargs = dict(rescale=True, resize=True, return_features=True) # Return raw features before the softmax layer. 202 | 203 | with open_url(detector_url, verbose=False) as f: 204 | detector = torch.jit.load(f).eval().to(device) 205 | 206 | 207 | 208 | assert not (only_sample and only_ref), 'only_ref and only_sample arguments are mutually exclusive' 209 | 210 | ref_embed, sample_embed = [], [] 211 | 212 | info = f'Computing I3D activations for FVD score with batch size {bs}' 213 | 214 | if only_ref: 215 | 216 | if not isvideo(videos_real): 217 | # if not is video we assume to have numpy arrays pf shape (n_vids, t, h, w, c) in range [0,255] 218 | videos_real = torch.from_numpy(videos_real).permute(0, 4, 1, 2, 3).float() 219 | print(videos_real.shape) 220 | 221 | if videos_real.shape[0] % bs == 0: 222 | n_secs = videos_real.shape[0] // bs 223 | else: 224 | n_secs = videos_real.shape[0] // bs + 1 225 | 226 | videos_real = torch.tensor_split(videos_real, n_secs, dim=0) 227 | 228 | for ref_v in tqdm(videos_real, total=len(videos_real),desc=info): 229 | 230 | feats_ref = detector(ref_v.to(device).contiguous(), **detector_kwargs).cpu().numpy() 231 | ref_embed.append(feats_ref) 232 | 233 | elif only_sample: 234 | 235 | if not isvideo(videos_fake): 236 | # if not is video we assume to have numpy arrays pf shape (n_vids, t, h, w, c) in range [0,255] 237 | videos_fake = torch.from_numpy(videos_fake).permute(0, 4, 1, 2, 3).float() 238 | print(videos_fake.shape) 239 | 240 | if videos_fake.shape[0] % bs == 0: 241 | n_secs = videos_fake.shape[0] // bs 242 | else: 243 | n_secs = videos_fake.shape[0] // bs + 1 244 | 245 | videos_real = torch.tensor_split(videos_real, n_secs, dim=0) 246 | 247 | for sample_v in tqdm(videos_fake, total=len(videos_real),desc=info): 248 | feats_sample = detector(sample_v.to(device).contiguous(), **detector_kwargs).cpu().numpy() 249 | sample_embed.append(feats_sample) 250 | 251 | 252 | else: 253 | 254 | if not isvideo(videos_real): 255 | # if not is video we assume to have numpy arrays pf shape (n_vids, t, h, w, c) in range [0,255] 256 | videos_real = torch.from_numpy(videos_real).permute(0, 4, 1, 2, 3).float() 257 | 258 | if not isvideo(videos_fake): 259 | videos_fake = torch.from_numpy(videos_fake).permute(0, 4, 1, 2, 3).float() 260 | 261 | if videos_fake.shape[0] % bs == 0: 262 | n_secs = videos_fake.shape[0] // bs 263 | else: 264 | n_secs = videos_fake.shape[0] // bs + 1 265 | 266 | videos_real = torch.tensor_split(videos_real, n_secs, dim=0) 267 | videos_fake = torch.tensor_split(videos_fake, n_secs, dim=0) 268 | 269 | for ref_v, sample_v in tqdm(zip(videos_real,videos_fake),total=len(videos_fake),desc=info): 270 | # print(ref_v.shape) 271 | # ref_v = torch.nn.functional.interpolate(ref_v, size=(sample_v.shape[2], 
256, 256), mode='trilinear', align_corners=False) 272 | # sample_v = torch.nn.functional.interpolate(sample_v, size=(sample_v.shape[2], 256, 256), mode='trilinear', align_corners=False) 273 | 274 | 275 | feats_sample = detector(sample_v.to(device).contiguous(), **detector_kwargs).cpu().numpy() 276 | feats_ref = detector(ref_v.to(device).contiguous(), **detector_kwargs).cpu().numpy() 277 | sample_embed.append(feats_sample) 278 | ref_embed.append(feats_ref) 279 | 280 | out = dict() 281 | if len(sample_embed) > 0: 282 | sample_embed = np.concatenate(sample_embed,axis=0) 283 | mu_sample, sigma_sample = compute_stats(sample_embed) 284 | out.update({'mu_sample': mu_sample, 285 | 'sigma_sample': sigma_sample}) 286 | 287 | if len(ref_embed) > 0: 288 | ref_embed = np.concatenate(ref_embed,axis=0) 289 | mu_ref, sigma_ref = compute_stats(ref_embed) 290 | out.update({'mu_ref': mu_ref, 291 | 'sigma_ref': sigma_ref}) 292 | 293 | 294 | return out 295 | -------------------------------------------------------------------------------- /ldm/modules/image_degradation/__init__.py: -------------------------------------------------------------------------------- 1 | from ldm.modules.image_degradation.bsrgan import degradation_bsrgan_variant as degradation_fn_bsr 2 | from ldm.modules.image_degradation.bsrgan_light import degradation_bsrgan_variant as degradation_fn_bsr_light 3 | -------------------------------------------------------------------------------- /ldm/modules/image_degradation/utils/test.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/object-edit/054bb09b8989754169a49bdce73bb767d745750b/ldm/modules/image_degradation/utils/test.png -------------------------------------------------------------------------------- /ldm/modules/losses/__init__.py: -------------------------------------------------------------------------------- 1 | from ldm.modules.losses.contperceptual import LPIPSWithDiscriminator -------------------------------------------------------------------------------- /ldm/modules/losses/contperceptual.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from taming.modules.losses.vqperceptual import * # TODO: taming dependency yes/no? 
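# Hedged usage sketch (config values are assumptions, not read from this repo's
# configs): the autoencoder's loss is typically instantiated roughly as
#   loss = LPIPSWithDiscriminator(disc_start=50001, kl_weight=1.0e-6, disc_weight=0.5)
# and invoked twice per training step, once with optimizer_idx=0 (reconstruction /
# KL / generator terms) and once with optimizer_idx=1 (discriminator term).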
5 | 6 | 7 | class LPIPSWithDiscriminator(nn.Module): 8 | def __init__(self, disc_start, logvar_init=0.0, kl_weight=1.0, pixelloss_weight=1.0, 9 | disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0, 10 | perceptual_weight=1.0, use_actnorm=False, disc_conditional=False, 11 | disc_loss="hinge"): 12 | 13 | super().__init__() 14 | assert disc_loss in ["hinge", "vanilla"] 15 | self.kl_weight = kl_weight 16 | self.pixel_weight = pixelloss_weight 17 | self.perceptual_loss = LPIPS().eval() 18 | self.perceptual_weight = perceptual_weight 19 | # output log variance 20 | self.logvar = nn.Parameter(torch.ones(size=()) * logvar_init) 21 | 22 | self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels, 23 | n_layers=disc_num_layers, 24 | use_actnorm=use_actnorm 25 | ).apply(weights_init) 26 | self.discriminator_iter_start = disc_start 27 | self.disc_loss = hinge_d_loss if disc_loss == "hinge" else vanilla_d_loss 28 | self.disc_factor = disc_factor 29 | self.discriminator_weight = disc_weight 30 | self.disc_conditional = disc_conditional 31 | 32 | def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None): 33 | if last_layer is not None: 34 | nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0] 35 | g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0] 36 | else: 37 | nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0] 38 | g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0] 39 | 40 | d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4) 41 | d_weight = torch.clamp(d_weight, 0.0, 1e4).detach() 42 | d_weight = d_weight * self.discriminator_weight 43 | return d_weight 44 | 45 | def forward(self, inputs, reconstructions, posteriors, optimizer_idx, 46 | global_step, last_layer=None, cond=None, split="train", 47 | weights=None): 48 | rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous()) 49 | if self.perceptual_weight > 0: 50 | p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous()) 51 | rec_loss = rec_loss + self.perceptual_weight * p_loss 52 | 53 | nll_loss = rec_loss / torch.exp(self.logvar) + self.logvar 54 | weighted_nll_loss = nll_loss 55 | if weights is not None: 56 | weighted_nll_loss = weights*nll_loss 57 | weighted_nll_loss = torch.sum(weighted_nll_loss) / weighted_nll_loss.shape[0] 58 | nll_loss = torch.sum(nll_loss) / nll_loss.shape[0] 59 | kl_loss = posteriors.kl() 60 | kl_loss = torch.sum(kl_loss) / kl_loss.shape[0] 61 | 62 | # now the GAN part 63 | if optimizer_idx == 0: 64 | # generator update 65 | if cond is None: 66 | assert not self.disc_conditional 67 | logits_fake = self.discriminator(reconstructions.contiguous()) 68 | else: 69 | assert self.disc_conditional 70 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1)) 71 | g_loss = -torch.mean(logits_fake) 72 | 73 | if self.disc_factor > 0.0: 74 | try: 75 | d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer) 76 | except RuntimeError: 77 | assert not self.training 78 | d_weight = torch.tensor(0.0) 79 | else: 80 | d_weight = torch.tensor(0.0) 81 | 82 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) 83 | loss = weighted_nll_loss + self.kl_weight * kl_loss + d_weight * disc_factor * g_loss 84 | 85 | log = {"{}/total_loss".format(split): loss.clone().detach().mean(), "{}/logvar".format(split): self.logvar.detach(), 86 | 
"{}/kl_loss".format(split): kl_loss.detach().mean(), "{}/nll_loss".format(split): nll_loss.detach().mean(), 87 | "{}/rec_loss".format(split): rec_loss.detach().mean(), 88 | "{}/d_weight".format(split): d_weight.detach(), 89 | "{}/disc_factor".format(split): torch.tensor(disc_factor), 90 | "{}/g_loss".format(split): g_loss.detach().mean(), 91 | } 92 | return loss, log 93 | 94 | if optimizer_idx == 1: 95 | # second pass for discriminator update 96 | if cond is None: 97 | logits_real = self.discriminator(inputs.contiguous().detach()) 98 | logits_fake = self.discriminator(reconstructions.contiguous().detach()) 99 | else: 100 | logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1)) 101 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1)) 102 | 103 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) 104 | d_loss = disc_factor * self.disc_loss(logits_real, logits_fake) 105 | 106 | log = {"{}/disc_loss".format(split): d_loss.clone().detach().mean(), 107 | "{}/logits_real".format(split): logits_real.detach().mean(), 108 | "{}/logits_fake".format(split): logits_fake.detach().mean() 109 | } 110 | return d_loss, log 111 | 112 | -------------------------------------------------------------------------------- /ldm/modules/losses/vqperceptual.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | from einops import repeat 5 | 6 | from taming.modules.discriminator.model import NLayerDiscriminator, weights_init 7 | from taming.modules.losses.lpips import LPIPS 8 | from taming.modules.losses.vqperceptual import hinge_d_loss, vanilla_d_loss 9 | 10 | 11 | def hinge_d_loss_with_exemplar_weights(logits_real, logits_fake, weights): 12 | assert weights.shape[0] == logits_real.shape[0] == logits_fake.shape[0] 13 | loss_real = torch.mean(F.relu(1. - logits_real), dim=[1,2,3]) 14 | loss_fake = torch.mean(F.relu(1. + logits_fake), dim=[1,2,3]) 15 | loss_real = (weights * loss_real).sum() / weights.sum() 16 | loss_fake = (weights * loss_fake).sum() / weights.sum() 17 | d_loss = 0.5 * (loss_real + loss_fake) 18 | return d_loss 19 | 20 | def adopt_weight(weight, global_step, threshold=0, value=0.): 21 | if global_step < threshold: 22 | weight = value 23 | return weight 24 | 25 | 26 | def measure_perplexity(predicted_indices, n_embed): 27 | # src: https://github.com/karpathy/deep-vector-quantization/blob/main/model.py 28 | # eval cluster perplexity. 
when perplexity == num_embeddings then all clusters are used exactly equally 29 | encodings = F.one_hot(predicted_indices, n_embed).float().reshape(-1, n_embed) 30 | avg_probs = encodings.mean(0) 31 | perplexity = (-(avg_probs * torch.log(avg_probs + 1e-10)).sum()).exp() 32 | cluster_use = torch.sum(avg_probs > 0) 33 | return perplexity, cluster_use 34 | 35 | def l1(x, y): 36 | return torch.abs(x-y) 37 | 38 | 39 | def l2(x, y): 40 | return torch.pow((x-y), 2) 41 | 42 | 43 | class VQLPIPSWithDiscriminator(nn.Module): 44 | def __init__(self, disc_start, codebook_weight=1.0, pixelloss_weight=1.0, 45 | disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0, 46 | perceptual_weight=1.0, use_actnorm=False, disc_conditional=False, 47 | disc_ndf=64, disc_loss="hinge", n_classes=None, perceptual_loss="lpips", 48 | pixel_loss="l1"): 49 | super().__init__() 50 | assert disc_loss in ["hinge", "vanilla"] 51 | assert perceptual_loss in ["lpips", "clips", "dists"] 52 | assert pixel_loss in ["l1", "l2"] 53 | self.codebook_weight = codebook_weight 54 | self.pixel_weight = pixelloss_weight 55 | if perceptual_loss == "lpips": 56 | print(f"{self.__class__.__name__}: Running with LPIPS.") 57 | self.perceptual_loss = LPIPS().eval() 58 | else: 59 | raise ValueError(f"Unknown perceptual loss: >> {perceptual_loss} <<") 60 | self.perceptual_weight = perceptual_weight 61 | 62 | if pixel_loss == "l1": 63 | self.pixel_loss = l1 64 | else: 65 | self.pixel_loss = l2 66 | 67 | self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels, 68 | n_layers=disc_num_layers, 69 | use_actnorm=use_actnorm, 70 | ndf=disc_ndf 71 | ).apply(weights_init) 72 | self.discriminator_iter_start = disc_start 73 | if disc_loss == "hinge": 74 | self.disc_loss = hinge_d_loss 75 | elif disc_loss == "vanilla": 76 | self.disc_loss = vanilla_d_loss 77 | else: 78 | raise ValueError(f"Unknown GAN loss '{disc_loss}'.") 79 | print(f"VQLPIPSWithDiscriminator running with {disc_loss} loss.") 80 | self.disc_factor = disc_factor 81 | self.discriminator_weight = disc_weight 82 | self.disc_conditional = disc_conditional 83 | self.n_classes = n_classes 84 | 85 | def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None): 86 | if last_layer is not None: 87 | nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0] 88 | g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0] 89 | else: 90 | nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0] 91 | g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0] 92 | 93 | d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4) 94 | d_weight = torch.clamp(d_weight, 0.0, 1e4).detach() 95 | d_weight = d_weight * self.discriminator_weight 96 | return d_weight 97 | 98 | def forward(self, codebook_loss, inputs, reconstructions, optimizer_idx, 99 | global_step, last_layer=None, cond=None, split="train", predicted_indices=None): 100 | if not exists(codebook_loss): 101 | codebook_loss = torch.tensor([0.]).to(inputs.device) 102 | #rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous()) 103 | rec_loss = self.pixel_loss(inputs.contiguous(), reconstructions.contiguous()) 104 | if self.perceptual_weight > 0: 105 | p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous()) 106 | rec_loss = rec_loss + self.perceptual_weight * p_loss 107 | else: 108 | p_loss = torch.tensor([0.0]) 109 | 110 | nll_loss = rec_loss 111 | #nll_loss = torch.sum(nll_loss) / nll_loss.shape[0] 
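        # note: unlike LPIPSWithDiscriminator in contperceptual.py (per-sample sum
        # divided by batch size), the NLL here is a plain mean over all elements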
112 | nll_loss = torch.mean(nll_loss) 113 | 114 | # now the GAN part 115 | if optimizer_idx == 0: 116 | # generator update 117 | if cond is None: 118 | assert not self.disc_conditional 119 | logits_fake = self.discriminator(reconstructions.contiguous()) 120 | else: 121 | assert self.disc_conditional 122 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1)) 123 | g_loss = -torch.mean(logits_fake) 124 | 125 | try: 126 | d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer) 127 | except RuntimeError: 128 | assert not self.training 129 | d_weight = torch.tensor(0.0) 130 | 131 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) 132 | loss = nll_loss + d_weight * disc_factor * g_loss + self.codebook_weight * codebook_loss.mean() 133 | 134 | log = {"{}/total_loss".format(split): loss.clone().detach().mean(), 135 | "{}/quant_loss".format(split): codebook_loss.detach().mean(), 136 | "{}/nll_loss".format(split): nll_loss.detach().mean(), 137 | "{}/rec_loss".format(split): rec_loss.detach().mean(), 138 | "{}/p_loss".format(split): p_loss.detach().mean(), 139 | "{}/d_weight".format(split): d_weight.detach(), 140 | "{}/disc_factor".format(split): torch.tensor(disc_factor), 141 | "{}/g_loss".format(split): g_loss.detach().mean(), 142 | } 143 | if predicted_indices is not None: 144 | assert self.n_classes is not None 145 | with torch.no_grad(): 146 | perplexity, cluster_usage = measure_perplexity(predicted_indices, self.n_classes) 147 | log[f"{split}/perplexity"] = perplexity 148 | log[f"{split}/cluster_usage"] = cluster_usage 149 | return loss, log 150 | 151 | if optimizer_idx == 1: 152 | # second pass for discriminator update 153 | if cond is None: 154 | logits_real = self.discriminator(inputs.contiguous().detach()) 155 | logits_fake = self.discriminator(reconstructions.contiguous().detach()) 156 | else: 157 | logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1)) 158 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1)) 159 | 160 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) 161 | d_loss = disc_factor * self.disc_loss(logits_real, logits_fake) 162 | 163 | log = {"{}/disc_loss".format(split): d_loss.clone().detach().mean(), 164 | "{}/logits_real".format(split): logits_real.detach().mean(), 165 | "{}/logits_fake".format(split): logits_fake.detach().mean() 166 | } 167 | return d_loss, log 168 | -------------------------------------------------------------------------------- /ldm/thirdp/psp/__pycache__/helpers.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/object-edit/054bb09b8989754169a49bdce73bb767d745750b/ldm/thirdp/psp/__pycache__/helpers.cpython-39.pyc -------------------------------------------------------------------------------- /ldm/thirdp/psp/__pycache__/id_loss.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/allenai/object-edit/054bb09b8989754169a49bdce73bb767d745750b/ldm/thirdp/psp/__pycache__/id_loss.cpython-39.pyc -------------------------------------------------------------------------------- /ldm/thirdp/psp/__pycache__/model_irse.cpython-39.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/allenai/object-edit/054bb09b8989754169a49bdce73bb767d745750b/ldm/thirdp/psp/__pycache__/model_irse.cpython-39.pyc -------------------------------------------------------------------------------- /ldm/thirdp/psp/helpers.py: -------------------------------------------------------------------------------- 1 | # https://github.com/eladrich/pixel2style2pixel 2 | 3 | from collections import namedtuple 4 | import torch 5 | from torch.nn import Conv2d, BatchNorm2d, PReLU, ReLU, Sigmoid, MaxPool2d, AdaptiveAvgPool2d, Sequential, Module 6 | 7 | """ 8 | ArcFace implementation from [TreB1eN](https://github.com/TreB1eN/InsightFace_Pytorch) 9 | """ 10 | 11 | 12 | class Flatten(Module): 13 | def forward(self, input): 14 | return input.view(input.size(0), -1) 15 | 16 | 17 | def l2_norm(input, axis=1): 18 | norm = torch.norm(input, 2, axis, True) 19 | output = torch.div(input, norm) 20 | return output 21 | 22 | 23 | class Bottleneck(namedtuple('Block', ['in_channel', 'depth', 'stride'])): 24 | """ A named tuple describing a ResNet block. """ 25 | 26 | 27 | def get_block(in_channel, depth, num_units, stride=2): 28 | return [Bottleneck(in_channel, depth, stride)] + [Bottleneck(depth, depth, 1) for i in range(num_units - 1)] 29 | 30 | 31 | def get_blocks(num_layers): 32 | if num_layers == 50: 33 | blocks = [ 34 | get_block(in_channel=64, depth=64, num_units=3), 35 | get_block(in_channel=64, depth=128, num_units=4), 36 | get_block(in_channel=128, depth=256, num_units=14), 37 | get_block(in_channel=256, depth=512, num_units=3) 38 | ] 39 | elif num_layers == 100: 40 | blocks = [ 41 | get_block(in_channel=64, depth=64, num_units=3), 42 | get_block(in_channel=64, depth=128, num_units=13), 43 | get_block(in_channel=128, depth=256, num_units=30), 44 | get_block(in_channel=256, depth=512, num_units=3) 45 | ] 46 | elif num_layers == 152: 47 | blocks = [ 48 | get_block(in_channel=64, depth=64, num_units=3), 49 | get_block(in_channel=64, depth=128, num_units=8), 50 | get_block(in_channel=128, depth=256, num_units=36), 51 | get_block(in_channel=256, depth=512, num_units=3) 52 | ] 53 | else: 54 | raise ValueError("Invalid number of layers: {}. 
Must be one of [50, 100, 152]".format(num_layers)) 55 | return blocks 56 | 57 | 58 | class SEModule(Module): 59 | def __init__(self, channels, reduction): 60 | super(SEModule, self).__init__() 61 | self.avg_pool = AdaptiveAvgPool2d(1) 62 | self.fc1 = Conv2d(channels, channels // reduction, kernel_size=1, padding=0, bias=False) 63 | self.relu = ReLU(inplace=True) 64 | self.fc2 = Conv2d(channels // reduction, channels, kernel_size=1, padding=0, bias=False) 65 | self.sigmoid = Sigmoid() 66 | 67 | def forward(self, x): 68 | module_input = x 69 | x = self.avg_pool(x) 70 | x = self.fc1(x) 71 | x = self.relu(x) 72 | x = self.fc2(x) 73 | x = self.sigmoid(x) 74 | return module_input * x 75 | 76 | 77 | class bottleneck_IR(Module): 78 | def __init__(self, in_channel, depth, stride): 79 | super(bottleneck_IR, self).__init__() 80 | if in_channel == depth: 81 | self.shortcut_layer = MaxPool2d(1, stride) 82 | else: 83 | self.shortcut_layer = Sequential( 84 | Conv2d(in_channel, depth, (1, 1), stride, bias=False), 85 | BatchNorm2d(depth) 86 | ) 87 | self.res_layer = Sequential( 88 | BatchNorm2d(in_channel), 89 | Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False), PReLU(depth), 90 | Conv2d(depth, depth, (3, 3), stride, 1, bias=False), BatchNorm2d(depth) 91 | ) 92 | 93 | def forward(self, x): 94 | shortcut = self.shortcut_layer(x) 95 | res = self.res_layer(x) 96 | return res + shortcut 97 | 98 | 99 | class bottleneck_IR_SE(Module): 100 | def __init__(self, in_channel, depth, stride): 101 | super(bottleneck_IR_SE, self).__init__() 102 | if in_channel == depth: 103 | self.shortcut_layer = MaxPool2d(1, stride) 104 | else: 105 | self.shortcut_layer = Sequential( 106 | Conv2d(in_channel, depth, (1, 1), stride, bias=False), 107 | BatchNorm2d(depth) 108 | ) 109 | self.res_layer = Sequential( 110 | BatchNorm2d(in_channel), 111 | Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False), 112 | PReLU(depth), 113 | Conv2d(depth, depth, (3, 3), stride, 1, bias=False), 114 | BatchNorm2d(depth), 115 | SEModule(depth, 16) 116 | ) 117 | 118 | def forward(self, x): 119 | shortcut = self.shortcut_layer(x) 120 | res = self.res_layer(x) 121 | return res + shortcut -------------------------------------------------------------------------------- /ldm/thirdp/psp/id_loss.py: -------------------------------------------------------------------------------- 1 | # https://github.com/eladrich/pixel2style2pixel 2 | import torch 3 | from torch import nn 4 | from ldm.thirdp.psp.model_irse import Backbone 5 | 6 | 7 | class IDFeatures(nn.Module): 8 | def __init__(self, model_path): 9 | super(IDFeatures, self).__init__() 10 | print('Loading ResNet ArcFace') 11 | self.facenet = Backbone(input_size=112, num_layers=50, drop_ratio=0.6, mode='ir_se') 12 | self.facenet.load_state_dict(torch.load(model_path, map_location="cpu")) 13 | self.face_pool = torch.nn.AdaptiveAvgPool2d((112, 112)) 14 | self.facenet.eval() 15 | 16 | def forward(self, x, crop=False): 17 | # Not sure of the image range here 18 | if crop: 19 | x = torch.nn.functional.interpolate(x, (256, 256), mode="area") 20 | x = x[:, :, 35:223, 32:220] 21 | x = self.face_pool(x) 22 | x_feats = self.facenet(x) 23 | return x_feats 24 | -------------------------------------------------------------------------------- /ldm/thirdp/psp/model_irse.py: -------------------------------------------------------------------------------- 1 | # https://github.com/eladrich/pixel2style2pixel 2 | 3 | from torch.nn import Linear, Conv2d, BatchNorm1d, BatchNorm2d, PReLU, Dropout, Sequential, Module 4 | from 
ldm.thirdp.psp.helpers import get_blocks, Flatten, bottleneck_IR, bottleneck_IR_SE, l2_norm 5 | 6 | """ 7 | Modified Backbone implementation from [TreB1eN](https://github.com/TreB1eN/InsightFace_Pytorch) 8 | """ 9 | 10 | 11 | class Backbone(Module): 12 | def __init__(self, input_size, num_layers, mode='ir', drop_ratio=0.4, affine=True): 13 | super(Backbone, self).__init__() 14 | assert input_size in [112, 224], "input_size should be 112 or 224" 15 | assert num_layers in [50, 100, 152], "num_layers should be 50, 100 or 152" 16 | assert mode in ['ir', 'ir_se'], "mode should be ir or ir_se" 17 | blocks = get_blocks(num_layers) 18 | if mode == 'ir': 19 | unit_module = bottleneck_IR 20 | elif mode == 'ir_se': 21 | unit_module = bottleneck_IR_SE 22 | self.input_layer = Sequential(Conv2d(3, 64, (3, 3), 1, 1, bias=False), 23 | BatchNorm2d(64), 24 | PReLU(64)) 25 | if input_size == 112: 26 | self.output_layer = Sequential(BatchNorm2d(512), 27 | Dropout(drop_ratio), 28 | Flatten(), 29 | Linear(512 * 7 * 7, 512), 30 | BatchNorm1d(512, affine=affine)) 31 | else: 32 | self.output_layer = Sequential(BatchNorm2d(512), 33 | Dropout(drop_ratio), 34 | Flatten(), 35 | Linear(512 * 14 * 14, 512), 36 | BatchNorm1d(512, affine=affine)) 37 | 38 | modules = [] 39 | for block in blocks: 40 | for bottleneck in block: 41 | modules.append(unit_module(bottleneck.in_channel, 42 | bottleneck.depth, 43 | bottleneck.stride)) 44 | self.body = Sequential(*modules) 45 | 46 | def forward(self, x): 47 | x = self.input_layer(x) 48 | x = self.body(x) 49 | x = self.output_layer(x) 50 | return l2_norm(x) 51 | 52 | 53 | def IR_50(input_size): 54 | """Constructs a ir-50 model.""" 55 | model = Backbone(input_size, num_layers=50, mode='ir', drop_ratio=0.4, affine=False) 56 | return model 57 | 58 | 59 | def IR_101(input_size): 60 | """Constructs a ir-101 model.""" 61 | model = Backbone(input_size, num_layers=100, mode='ir', drop_ratio=0.4, affine=False) 62 | return model 63 | 64 | 65 | def IR_152(input_size): 66 | """Constructs a ir-152 model.""" 67 | model = Backbone(input_size, num_layers=152, mode='ir', drop_ratio=0.4, affine=False) 68 | return model 69 | 70 | 71 | def IR_SE_50(input_size): 72 | """Constructs a ir_se-50 model.""" 73 | model = Backbone(input_size, num_layers=50, mode='ir_se', drop_ratio=0.4, affine=False) 74 | return model 75 | 76 | 77 | def IR_SE_101(input_size): 78 | """Constructs a ir_se-101 model.""" 79 | model = Backbone(input_size, num_layers=100, mode='ir_se', drop_ratio=0.4, affine=False) 80 | return model 81 | 82 | 83 | def IR_SE_152(input_size): 84 | """Constructs a ir_se-152 model.""" 85 | model = Backbone(input_size, num_layers=152, mode='ir_se', drop_ratio=0.4, affine=False) 86 | return model -------------------------------------------------------------------------------- /ldm/util.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | 3 | import torchvision 4 | import torch 5 | from torch import optim 6 | import numpy as np 7 | 8 | from inspect import isfunction 9 | from PIL import Image, ImageDraw, ImageFont 10 | 11 | import os 12 | import numpy as np 13 | import matplotlib.pyplot as plt 14 | from PIL import Image 15 | import torch 16 | import time 17 | import cv2 18 | # from carvekit.api.high import HiInterface 19 | import PIL 20 | 21 | def pil_rectangle_crop(im): 22 | width, height = im.size # Get dimensions 23 | 24 | if width <= height: 25 | left = 0 26 | right = width 27 | top = (height - width)/2 28 | bottom = (height + width)/2 29 | else: 30 
| # width > height: crop a horizontally centered square of side `height`
31 | top = 0
32 | bottom = height
33 | left = (width - height) / 2
34 | right = (width + height) / 2
35 | 
36 | # Crop the center of the image
37 | im = im.crop((left, top, right, bottom))
38 | return im
39 | 
40 | def add_margin(pil_img, color, size=256):
41 | width, height = pil_img.size
42 | result = Image.new(pil_img.mode, (size, size), color)
43 | result.paste(pil_img, ((size - width) // 2, (size - height) // 2))
44 | return result
45 | 
46 | 
47 | def create_carvekit_interface():
48 | from carvekit.api.high import HiInterface  # local import keeps carvekit optional; check its doc strings for more information
49 | interface = HiInterface(object_type="object", # Can be "object" or "hairs-like".
50 | batch_size_seg=5,
51 | batch_size_matting=1,
52 | device='cuda' if torch.cuda.is_available() else 'cpu',
53 | seg_mask_size=640, # Use 640 for Tracer B7 and 320 for U2Net
54 | matting_mask_size=2048,
55 | trimap_prob_threshold=231,
56 | trimap_dilation=30,
57 | trimap_erosion_iters=5,
58 | fp16=False)
59 | 
60 | return interface
61 | 
62 | 
63 | def load_and_preprocess(interface, input_im):
64 | '''
65 | :param input_im (PIL Image).
66 | :return image (H, W, 3) uint8 array in [0, 255] on a white background.
67 | '''
68 | # See https://github.com/Ir1d/image-background-remove-tool
69 | image = input_im.convert('RGB')
70 | 
71 | image_without_background = interface([image])[0]
72 | image_without_background = np.array(image_without_background)
73 | est_seg = image_without_background > 127
74 | image = np.array(image)
75 | foreground = est_seg[:, : , -1].astype(np.bool_)
76 | image[~foreground] = [255., 255., 255.]
77 | x, y, w, h = cv2.boundingRect(foreground.astype(np.uint8))
78 | image = image[y:y+h, x:x+w, :]
79 | image = PIL.Image.fromarray(np.array(image))
80 | 
81 | # shrink so the long edge fits within 200 px, then pad with white to a 256x256 canvas
82 | image.thumbnail([200, 200], Image.Resampling.LANCZOS)
83 | image = add_margin(image, (255, 255, 255), size=256)
84 | image = np.array(image)
85 | 
86 | return image
87 | 
88 | 
89 | def log_txt_as_img(wh, xc, size=10):
90 | # wh a tuple of (width, height)
91 | # xc a list of captions to plot
92 | b = len(xc)
93 | txts = list()
94 | for bi in range(b):
95 | txt = Image.new("RGB", wh, color="white")
96 | draw = ImageDraw.Draw(txt)
97 | font = ImageFont.truetype('data/DejaVuSans.ttf', size=size)
98 | nc = int(40 * (wh[0] / 256))
99 | lines = "\n".join(xc[bi][start:start + nc] for start in range(0, len(xc[bi]), nc))
100 | 
101 | try:
102 | draw.text((0, 0), lines, fill="black", font=font)
103 | except UnicodeEncodeError:
104 | print("Can't encode string for logging. Skipping.")
105 | 
106 | txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0
107 | txts.append(txt)
108 | txts = np.stack(txts)
109 | txts = torch.tensor(txts)
110 | return txts
111 | 
112 | 
113 | def ismap(x):
114 | if not isinstance(x, torch.Tensor):
115 | return False
116 | return (len(x.shape) == 4) and (x.shape[1] > 3)
117 | 
118 | 
119 | def isimage(x):
120 | if not isinstance(x,torch.Tensor):
121 | return False
122 | return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1)
123 | 
124 | 
125 | def exists(x):
126 | return x is not None
127 | 
128 | 
129 | def default(val, d):
130 | if exists(val):
131 | return val
132 | return d() if isfunction(d) else d
133 | 
134 | 
135 | def mean_flat(tensor):
136 | """
137 | https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/nn.py#L86
138 | Take the mean over all non-batch dimensions.
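For example, a tensor of shape (B, C, H, W) is reduced to a tensor of shape (B,) holding per-sample means.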
139 | """ 140 | return tensor.mean(dim=list(range(1, len(tensor.shape)))) 141 | 142 | 143 | def count_params(model, verbose=False): 144 | total_params = sum(p.numel() for p in model.parameters()) 145 | if verbose: 146 | print(f"{model.__class__.__name__} has {total_params*1.e-6:.2f} M params.") 147 | return total_params 148 | 149 | 150 | def instantiate_from_config(config): 151 | if not "target" in config: 152 | if config == '__is_first_stage__': 153 | return None 154 | elif config == "__is_unconditional__": 155 | return None 156 | raise KeyError("Expected key `target` to instantiate.") 157 | return get_obj_from_str(config["target"])(**config.get("params", dict())) 158 | 159 | 160 | def get_obj_from_str(string, reload=False): 161 | module, cls = string.rsplit(".", 1) 162 | if reload: 163 | module_imp = importlib.import_module(module) 164 | importlib.reload(module_imp) 165 | return getattr(importlib.import_module(module, package=None), cls) 166 | 167 | 168 | class AdamWwithEMAandWings(optim.Optimizer): 169 | # credit to https://gist.github.com/crowsonkb/65f7265353f403714fce3b2595e0b298 170 | def __init__(self, params, lr=1.e-3, betas=(0.9, 0.999), eps=1.e-8, # TODO: check hyperparameters before using 171 | weight_decay=1.e-2, amsgrad=False, ema_decay=0.9999, # ema decay to match previous code 172 | ema_power=1., param_names=()): 173 | """AdamW that saves EMA versions of the parameters.""" 174 | if not 0.0 <= lr: 175 | raise ValueError("Invalid learning rate: {}".format(lr)) 176 | if not 0.0 <= eps: 177 | raise ValueError("Invalid epsilon value: {}".format(eps)) 178 | if not 0.0 <= betas[0] < 1.0: 179 | raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) 180 | if not 0.0 <= betas[1] < 1.0: 181 | raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) 182 | if not 0.0 <= weight_decay: 183 | raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) 184 | if not 0.0 <= ema_decay <= 1.0: 185 | raise ValueError("Invalid ema_decay value: {}".format(ema_decay)) 186 | defaults = dict(lr=lr, betas=betas, eps=eps, 187 | weight_decay=weight_decay, amsgrad=amsgrad, ema_decay=ema_decay, 188 | ema_power=ema_power, param_names=param_names) 189 | super().__init__(params, defaults) 190 | 191 | def __setstate__(self, state): 192 | super().__setstate__(state) 193 | for group in self.param_groups: 194 | group.setdefault('amsgrad', False) 195 | 196 | @torch.no_grad() 197 | def step(self, closure=None): 198 | """Performs a single optimization step. 199 | Args: 200 | closure (callable, optional): A closure that reevaluates the model 201 | and returns the loss. 
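In addition to the AdamW update, this refreshes the EMA copy of each parameter
(state['param_exp_avg']) as ema = d * ema + (1 - d) * param, where
d = min(ema_decay, 1 - step ** -ema_power).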
202 | """
203 | loss = None
204 | if closure is not None:
205 | with torch.enable_grad():
206 | loss = closure()
207 | 
208 | for group in self.param_groups:
209 | params_with_grad = []
210 | grads = []
211 | exp_avgs = []
212 | exp_avg_sqs = []
213 | ema_params_with_grad = []
214 | state_sums = []
215 | max_exp_avg_sqs = []
216 | state_steps = []
217 | amsgrad = group['amsgrad']
218 | beta1, beta2 = group['betas']
219 | ema_decay = group['ema_decay']
220 | ema_power = group['ema_power']
221 | 
222 | for p in group['params']:
223 | if p.grad is None:
224 | continue
225 | params_with_grad.append(p)
226 | if p.grad.is_sparse:
227 | raise RuntimeError('AdamW does not support sparse gradients')
228 | grads.append(p.grad)
229 | 
230 | state = self.state[p]
231 | 
232 | # State initialization
233 | if len(state) == 0:
234 | state['step'] = 0
235 | # Exponential moving average of gradient values
236 | state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format)
237 | # Exponential moving average of squared gradient values
238 | state['exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)
239 | if amsgrad:
240 | # Maintains max of all exp. moving avg. of sq. grad. values
241 | state['max_exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)
242 | # Exponential moving average of parameter values
243 | state['param_exp_avg'] = p.detach().float().clone()
244 | 
245 | exp_avgs.append(state['exp_avg'])
246 | exp_avg_sqs.append(state['exp_avg_sq'])
247 | ema_params_with_grad.append(state['param_exp_avg'])
248 | 
249 | if amsgrad:
250 | max_exp_avg_sqs.append(state['max_exp_avg_sq'])
251 | 
252 | # update the steps for each param group update
253 | state['step'] += 1
254 | # record the step after step update
255 | state_steps.append(state['step'])
256 | 
257 | optim._functional.adamw(params_with_grad,
258 | grads,
259 | exp_avgs,
260 | exp_avg_sqs,
261 | max_exp_avg_sqs,
262 | state_steps,
263 | amsgrad=amsgrad,
264 | beta1=beta1,
265 | beta2=beta2,
266 | lr=group['lr'],
267 | weight_decay=group['weight_decay'],
268 | eps=group['eps'],
269 | maximize=False)
270 | 
271 | cur_ema_decay = min(ema_decay, 1 - state['step'] ** -ema_power)
272 | for param, ema_param in zip(params_with_grad, ema_params_with_grad):
273 | ema_param.mul_(cur_ema_decay).add_(param.float(), alpha=1 - cur_ema_decay)
274 | 
275 | return loss
-------------------------------------------------------------------------------- /metrics.py: --------------------------------------------------------------------------------
1 | import torchmetrics
2 | from torchmetrics.functional import peak_signal_noise_ratio, structural_similarity_index_measure
3 | from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity
4 | from torchmetrics.image.fid import FrechetInceptionDistance
5 | import torch
6 | import torchvision
7 | 
8 | lpips = LearnedPerceptualImagePatchSimilarity(net_type='vgg')
9 | 
10 | def fix_bounds(xmin,xmax,ymin,ymax,size,min_size=32):
11 | 
12 | if xmax - xmin < min_size: # if the masked region is narrower than min_size, expand it to roughly 2*min_size around its centre
13 | center = (xmax + xmin) // 2
14 | xmin = max(center - min_size,0)
15 | xmax = min(center + min_size,size-1)
16 | if xmin == 0:
17 | xmax += min_size - (xmax - xmin)
18 | elif xmax == size-1:
19 | xmin -= min_size - (xmax - xmin)
20 | 
21 | if ymax - ymin < min_size:
22 | center = (ymax + ymin) // 2
23 | ymin = max(center - min_size,0)
24 | ymax = min(center + min_size,size-1)
25 | if ymin == 0:
26 | ymax += min_size - (ymax - ymin)
27 | elif ymax == size-1:
28
| ymin -= min_size - (ymax - ymin) 29 | 30 | return xmin,xmax,ymin,ymax 31 | 32 | 33 | def psnr(x,y,*args): 34 | 35 | return peak_signal_noise_ratio(x,y).item() 36 | 37 | def psnr_mask(x,y,mask,*args): 38 | 39 | rows, cols = torch.where(mask) 40 | if len(rows) > 0: 41 | xmin, ymin = rows.min().item(), cols.min().item() 42 | xmax, ymax = rows.max().item(), cols.max().item() 43 | 44 | xmin, xmax, ymin, ymax = fix_bounds(xmin,xmax,ymin,ymax,256) 45 | 46 | x_region = x[:,:, xmin:xmax, ymin:ymax] 47 | y_region = y[:,:, xmin:xmax, ymin:ymax] 48 | else: # object fully occluded 49 | x_region, y_region = x ,y 50 | 51 | return psnr(x_region,y_region) 52 | 53 | 54 | 55 | def ssim(x,y,*args): 56 | 57 | return structural_similarity_index_measure(x,y).item() 58 | 59 | def ssim_mask(x,y,mask,*args): 60 | 61 | rows, cols = torch.where(mask) 62 | if len(rows) > 0: 63 | xmin, ymin = rows.min().item(), cols.min().item() 64 | xmax, ymax = rows.max().item(), cols.max().item() 65 | 66 | xmin, xmax, ymin, ymax = fix_bounds(xmin,xmax,ymin,ymax,256) 67 | 68 | x_region = x[:,:, xmin:xmax, ymin:ymax] 69 | y_region = y[:,:, xmin:xmax, ymin:ymax] 70 | else: # object fully occluded 71 | x_region, y_region = x,y 72 | 73 | return ssim(x_region,y_region) 74 | 75 | 76 | @torch.no_grad() 77 | def lpip(x,y,*args): 78 | 79 | x = 2*x - 1 80 | y = 2*y - 1 81 | return -lpips(x,y).item() 82 | 83 | @torch.no_grad() 84 | def lpip_mask(x,y,mask,*args): 85 | 86 | 87 | rows, cols = torch.where(mask) 88 | if len(rows) > 0: 89 | xmin, ymin = rows.min().item(), cols.min().item() 90 | xmax, ymax = rows.max().item(), cols.max().item() 91 | 92 | xmin, xmax, ymin, ymax = fix_bounds(xmin,xmax,ymin,ymax,256) 93 | 94 | x_region = x[:,:, xmin:xmax, ymin:ymax] 95 | y_region = y[:,:, xmin:xmax, ymin:ymax] 96 | else: # object fully occluded 97 | x_region, y_region = x, y 98 | 99 | return lpip(x_region,y_region) 100 | 101 | def fid(xs,ys,*args): 102 | 103 | fid_metric = FrechetInceptionDistance(feature=64) 104 | fid_metric.update(xs,real=False) 105 | fid_metric.update(ys,real=True) 106 | 107 | return fid_metric.compute().item() 108 | 109 | 110 | -------------------------------------------------------------------------------- /req.txt: -------------------------------------------------------------------------------- 1 | --extra-index-url https://download.pytorch.org/whl/cu113 2 | torch==1.12.1 3 | torchvision==0.13.1 4 | opencv-python==4.5.5.64 5 | pudb==2019.2 6 | imageio==2.9.0 7 | imageio-ffmpeg==0.4.2 8 | pytorch_lightning==1.5.0 9 | omegaconf==2.1.1 10 | streamlit>=0.73.1 11 | einops==0.3.0 12 | transformers==4.22.2 13 | kornia==0.6 14 | webdataset==0.2.5 15 | torchmetrics==0.6.0 16 | diffusers==0.12.1 17 | datasets[vision]==2.4.0 18 | plotly==5.13.1 19 | ipdb 20 | ftfy 21 | regex 22 | tqdm 23 | git+https://github.com/openai/CLIP.git 24 | taming-transformers-rom1504 25 | matplotlib 26 | wandb 27 | torchmetrics[image] -------------------------------------------------------------------------------- /run_eval.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import torch 4 | import torchvision 5 | from metrics import psnr_mask, ssim_mask, lpip_mask, fid 6 | from ldm.util import instantiate_from_config 7 | from ldm.data.simple import ObjaverseDataRotation, ObjaverseDataRemove, ObjaverseDataInsert, ObjaverseDataTranslation 8 | from einops import rearrange 9 | from tqdm import tqdm 10 | import numpy as np 11 | import json 12 | import ipdb 13 | 14 | task_classes = { 15 | "rotate": 
ObjaverseDataRotation, 16 | "translate": ObjaverseDataTranslation, 17 | "remove": ObjaverseDataRemove, 18 | "insert": ObjaverseDataInsert 19 | } 20 | 21 | def run(args): 22 | 23 | image_transforms = torchvision.transforms.Compose([ 24 | torchvision.transforms.Resize(256), 25 | torchvision.transforms.ToTensor(), 26 | ]) 27 | dataset = task_classes[args.task]( 28 | root_dir=args.data_dir, 29 | image_transforms=image_transforms, 30 | task=args.task, 31 | split=args.split, 32 | seen_or_unseen=args.seen, 33 | ) 34 | sample_metrics = { 35 | "psnr_mask": psnr_mask, 36 | "ssim_mask": ssim_mask, 37 | "lpip_mask": lpip_mask, 38 | } 39 | sample_statistics = [] 40 | best_samples = [] 41 | fid_samples, fid_targets = [], [] 42 | for i in tqdm(range(len(dataset))): 43 | sample = dataset[i] 44 | uid = sample["uid"] 45 | generated_samples = [torchvision.io.read_image(os.path.join(args.generation_dir,uid,f"{i}.png")) for i in range(4)] 46 | generated_samples = [s.unsqueeze(0) / 255. for s in generated_samples] 47 | mask_target = sample["mask_target"] 48 | image_target = sample["image_target"].unsqueeze(0) 49 | metrics_per_sample = [ 50 | {k:f(s,image_target,mask_target,i,uid) for k,f in sample_metrics.items()} 51 | for s in generated_samples 52 | ] 53 | sample_statistics.append(metrics_per_sample) 54 | best_sample_index = max(range(4),key=lambda x: metrics_per_sample[x]["psnr_mask"]) 55 | best_samples.append(metrics_per_sample[best_sample_index]) 56 | fid_samples.append((generated_samples[best_sample_index]*255).to(torch.uint8)) 57 | fid_targets.append((image_target*255).to(torch.uint8)) 58 | 59 | sample_averages = {} 60 | for k in sample_metrics.keys(): 61 | sample_averages["mean_" + k] = np.mean([x[k] for x in best_samples]) 62 | fid_samples, fid_targets = torch.cat(fid_samples,dim=0), torch.cat(fid_targets,dim=0) 63 | sample_averages["fid"] = fid(fid_samples,fid_targets) 64 | os.makedirs(args.save_dir,exist_ok=True) 65 | with open(os.path.join(args.save_dir,f"{args.task}_{args.split}_{args.seen}_summary.json"),"w") as f: 66 | json.dump(sample_averages,f,indent=2) 67 | 68 | 69 | if __name__ == "__main__": 70 | 71 | parser = argparse.ArgumentParser() 72 | parser.add_argument( 73 | "--generation_dir", 74 | type=str, 75 | required=True 76 | ) 77 | parser.add_argument( 78 | "--data_dir", 79 | type=str, 80 | required=True 81 | ) 82 | parser.add_argument( 83 | "--task", 84 | type=str, 85 | choices=["rotate","remove","insert","translate"] 86 | ) 87 | parser.add_argument( 88 | "--split", 89 | type=str, 90 | choices=["train","val","test"] 91 | ) 92 | parser.add_argument( 93 | "--seen", 94 | type=str, 95 | default="seen", 96 | choices=["seen","unseen"] 97 | ) 98 | parser.add_argument( 99 | "--save_dir", 100 | type=str, 101 | default="statistics" 102 | ) 103 | args = parser.parse_args() 104 | assert not (args.split == "train" and args.seen == "unseen") 105 | run(args) 106 | -------------------------------------------------------------------------------- /run_generation.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import torch 4 | import torchvision 5 | from omegaconf import OmegaConf 6 | from ldm.util import instantiate_from_config 7 | from ldm.models.diffusion.ddim import DDIMSampler 8 | import matplotlib.pyplot as plt 9 | from einops import rearrange 10 | import math 11 | import numpy as np 12 | from PIL import Image 13 | 14 | base_path = "/net/nfs/prior/oscarm/best_checkpoints_2.0" 15 | 16 | def load_checkpoint(model,checkpoint): 17 | 
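# Accepts either a raw state dict or one nested under 'state_dict' (the PyTorch Lightning layout).
# If the current UNet's first conv expects more input channels than the checkpoint provides
# (e.g. 8 for the image-concat conditioned model vs. 4 in the original zero123.ckpt),
# the extra input channels are zero-initialised and the pretrained 4 are copied over
# before loading with strict=False.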
18 | print(f"Attempting to load state from {checkpoint}") 19 | old_state = torch.load(checkpoint, map_location="cpu") 20 | 21 | if "state_dict" in old_state: 22 | print(f"Found nested key 'state_dict' in checkpoint, loading this instead") 23 | old_state = old_state["state_dict"] 24 | 25 | # Check if we need to port weights from 4ch input to 8ch 26 | in_filters_load = old_state["model.diffusion_model.input_blocks.0.0.weight"] 27 | new_state = model.state_dict() 28 | in_filters_current = new_state["model.diffusion_model.input_blocks.0.0.weight"] 29 | in_shape = in_filters_current.shape 30 | if in_shape != in_filters_load.shape: 31 | input_keys = [ 32 | "model.diffusion_model.input_blocks.0.0.weight", 33 | "model_ema.diffusion_modelinput_blocks00weight", 34 | ] 35 | 36 | for input_key in input_keys: 37 | if input_key not in old_state or input_key not in new_state: 38 | continue 39 | input_weight = new_state[input_key] 40 | if input_weight.size() != old_state[input_key].size(): 41 | print(f"Manual init: {input_key}") 42 | input_weight.zero_() 43 | input_weight[:, :4, :, :].copy_(old_state[input_key]) 44 | old_state[input_key] = torch.nn.parameter.Parameter(input_weight) 45 | 46 | m, u = model.load_state_dict(old_state, strict=False) 47 | 48 | if len(m) > 0: 49 | print("missing keys:") 50 | print(m) 51 | if len(u) > 0: 52 | print("unexpected keys:") 53 | print(u) 54 | 55 | def preprocess_image(img): 56 | 57 | img = img.convert("RGB") 58 | image_transforms = torchvision.transforms.Compose([ 59 | torchvision.transforms.Resize((256,256)), 60 | torchvision.transforms.ToTensor(), 61 | # torchvision.transforms.Lambda(lambda x: rearrange(x * 2. - 1., 'c h w -> h w c')) 62 | ]) 63 | img = image_transforms(img) 64 | img = img*2. - 1. 65 | img = img.unsqueeze(0) 66 | return img 67 | 68 | def sample_model( 69 | model, 70 | input_im, 71 | prompt, 72 | T, 73 | sampler, 74 | ddim_steps, 75 | n_samples, 76 | scale=1.0, 77 | ddim_eta=1.0 78 | ): 79 | 80 | print(prompt) 81 | c = model.get_learned_conditioning({"image":input_im,"text":prompt}).tile(n_samples, 1, 1) 82 | null_prompt = model.get_learned_conditioning([""]) 83 | uc = null_prompt.repeat(1,c.shape[1],1).tile(n_samples,1,1) 84 | T = T[None, None, :].repeat(n_samples, 1, 1).to(c.device) 85 | T = T.repeat(1,c.shape[1],1) 86 | c = torch.cat([c, T], dim=-1) 87 | c = model.cc_projection(c) 88 | cond = {} 89 | cond['c_crossattn'] = [c] 90 | cond['c_concat'] = [model.encode_first_stage((input_im.to(c.device))).mode().detach() 91 | .repeat(n_samples, 1, 1, 1)] 92 | uncond = {} 93 | uncond['c_crossattn'] = [uc] 94 | uncond['c_concat'] = [model.encode_first_stage((input_im.to(c.device))).mode().detach() 95 | .repeat(n_samples, 1, 1, 1)] 96 | h, w = 256, 256 97 | shape = [4, h // 8, w // 8] 98 | samples_ddim, _ = sampler.sample(S=ddim_steps, 99 | conditioning=cond, 100 | batch_size=n_samples, 101 | shape=shape, 102 | verbose=False, 103 | unconditional_guidance_scale=scale, 104 | unconditional_conditioning=uncond, 105 | eta=ddim_eta, 106 | x_T=None) 107 | 108 | x_samples_ddim = model.decode_first_stage(samples_ddim) 109 | x_samples_ddim = torch.clamp(x_samples_ddim, -1. ,1.) 
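# decoded samples lie in [-1, 1]; rescale to [0, 1] before they are converted to uint8 images below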
110 | x_samples_ddim = ((x_samples_ddim + 1.0) / 2.0).cpu() 111 | 112 | return x_samples_ddim 113 | 114 | def run(args): 115 | 116 | print("LOADING MODEL!") 117 | config = OmegaConf.load(f"configs/sd-objaverse-{args.task}.yaml") 118 | OmegaConf.update(config,"model.params.cond_stage_config.params.device",args.device) 119 | model = instantiate_from_config(config.model) 120 | model.cpu() 121 | load_checkpoint(model,args.checkpoint_path) 122 | model.to(args.device) 123 | model.eval() 124 | print("FINISHED LOADING!") 125 | 126 | image = Image.open(args.image_path) 127 | input_im = preprocess_image(image).to(args.device) 128 | x, y = map(float, args.position.split(',')) 129 | if args.task == "rotate": 130 | prompt = f"rotate the {args.object_prompt}" 131 | azimuth = math.radians(args.rotation_angle) 132 | T = torch.tensor([np.pi / 2, math.sin(azimuth), math.cos(azimuth),0]) 133 | elif args.task == "remove": 134 | prompt = f"remove the {args.object_prompt}" 135 | T = torch.tensor([0.,0.,0.,0.]) 136 | elif args.task == "insert": 137 | prompt = f"insert the {args.object_prompt}" 138 | T = torch.tensor([0,x,y,0]) 139 | elif args.task == "translate": 140 | prompt = f"move the {args.object_prompt}" 141 | T = torch.tensor([0,x,y,0]) 142 | 143 | 144 | sampler = DDIMSampler(model) 145 | 146 | x_samples_ddim = sample_model( 147 | model, 148 | input_im, 149 | prompt, 150 | T, 151 | sampler, 152 | args.ddim_steps, 153 | args.num_samples, 154 | scale=args.cfg_scale 155 | ) 156 | output_ims = [] 157 | for x_sample in x_samples_ddim: 158 | x_sample = 255.0 * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c') 159 | output_ims.append(Image.fromarray(x_sample.astype(np.uint8))) 160 | 161 | input_im = ((input_im + 1.0) / 2.0).cpu()[0] 162 | input_im = 255.0 * rearrange(input_im.numpy(), 'c h w -> h w c') 163 | input_im = Image.fromarray(input_im.astype(np.uint8)) 164 | 165 | os.makedirs(args.save_dir,exist_ok=True) 166 | 167 | input_im.save(os.path.join(args.save_dir,"input_im.png")) 168 | for i,img in enumerate(output_ims): 169 | img.save(os.path.join(args.save_dir,f"{i}.png")) 170 | 171 | 172 | 173 | 174 | if __name__ == "__main__": 175 | 176 | parser = argparse.ArgumentParser() 177 | parser.add_argument( 178 | "--task", 179 | type=str, 180 | required=True, 181 | choices=["rotate","remove","insert","translate"] 182 | ) 183 | parser.add_argument( 184 | "--checkpoint_path", 185 | type=str, 186 | required=True 187 | ) 188 | parser.add_argument( 189 | "--image_path", 190 | type=str, 191 | required=True 192 | ) 193 | parser.add_argument( 194 | "--save_dir", 195 | type=str, 196 | default="generated_images" 197 | ) 198 | parser.add_argument( 199 | "--object_prompt", 200 | type=str, 201 | required=True 202 | ) 203 | parser.add_argument( 204 | "--rotation_angle", 205 | type=float, 206 | default=0.0 207 | ) 208 | parser.add_argument( 209 | "--position", 210 | type=str, 211 | default="0.5,0.5", 212 | help="Coordinates in x,y form where 0 <= x,y <= 1" 213 | ) 214 | parser.add_argument( 215 | "--ddim_steps", 216 | type=int, 217 | default=50 218 | ) 219 | parser.add_argument( 220 | "--num_samples", 221 | type=int, 222 | default=8 223 | ) 224 | parser.add_argument( 225 | "--device", 226 | type=int, 227 | default=0 228 | ) 229 | parser.add_argument( 230 | "--cfg_scale", 231 | type=float, 232 | default=1.0 233 | ) 234 | args = parser.parse_args() 235 | run(args) 236 | 237 | 238 | -------------------------------------------------------------------------------- /setup_reqs.sh: 
-------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | pip install -r req.txt 3 | pip install pytorch_lightning==1.5.0 4 | pip install omegaconf 5 | pip install opencv-python 6 | pip install carvekit-colab==4.1.0 7 | pip install einops 8 | pip install taming-transformers-rom1504 9 | pip install kornia 10 | pip install git+https://github.com/openai/CLIP.git 11 | pip install transformers 12 | pip install wandb 13 | pip install webdataset==0.2.5 14 | pip install ipdb 15 | pip install matplotlib 16 | sudo apt-get update && sudo apt-get install ffmpeg libsm6 libxext6 -y 17 | pip install torchmetrics[image] -------------------------------------------------------------------------------- /train.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | python main.py \ 4 | -t \ 5 | --base configs/sd-objaverse-$1.yaml \ 6 | --gpus 0,1,2,3,4,5,6,7 \ 7 | --scale_lr False \ 8 | --num_nodes 1 \ 9 | --seed 42 \ 10 | --finetune_from zero123.ckpt \ -------------------------------------------------------------------------------- /uses.md: -------------------------------------------------------------------------------- 1 | 2 | # Uses 3 | _Note: This section is originally taken from the [Stable Diffusion v2 model card](https://huggingface.co/stabilityai/stable-diffusion-2), but applies in the same way to Zero-1-to-3._ 4 | 5 | ## Direct Use 6 | The model is intended for research purposes only. Possible research areas and tasks include: 7 | 8 | - Safe deployment of large-scale models. 9 | - Probing and understanding the limitations and biases of generative models. 10 | - Generation of artworks and use in design and other artistic processes. 11 | - Applications in educational or creative tools. 12 | - Research on generative models. 13 | 14 | Excluded uses are described below. 15 | 16 | ### Misuse, Malicious Use, and Out-of-Scope Use 17 | The model should not be used to intentionally create or disseminate images that create hostile or alienating environments for people. This includes generating images that people would foreseeably find disturbing, distressing, or offensive; or content that propagates historical or current stereotypes. 18 | 19 | #### Out-of-Scope Use 20 | The model was not trained to be factual or true representations of people or events, and therefore using the model to generate such content is out-of-scope for the abilities of this model. 21 | 22 | #### Misuse and Malicious Use 23 | Using the model to generate content that is cruel to individuals is a misuse of this model. This includes, but is not limited to: 24 | 25 | - Generating demeaning, dehumanizing, or otherwise harmful representations of people or their environments, cultures, religions, etc. 26 | - Intentionally promoting or propagating discriminatory content or harmful stereotypes. 27 | - Impersonating individuals without their consent. 28 | - Sexual content without consent of the people who might see it. 29 | - Mis- and disinformation 30 | - Representations of egregious violence and gore 31 | - Sharing of copyrighted or licensed material in violation of its terms of use. 32 | - Sharing content that is an alteration of copyrighted or licensed material in violation of its terms of use. 33 | 34 | ## Limitations and Bias 35 | 36 | ### Limitations 37 | 38 | - The model does not achieve perfect photorealism. 39 | - The model cannot render legible text. 40 | - Faces and people in general may not be parsed or generated properly. 
41 | - The autoencoding part of the model is lossy. 42 | - Stable Diffusion was trained on a subset of the large-scale dataset [LAION-5B](https://laion.ai/blog/laion-5b/), which contains adult, violent and sexual content. To partially mitigate this, Stability AI has filtered the dataset using LAION's NSFW detector. 43 | - Zero-1-to-3 was subsequently finetuned on a subset of the large-scale dataset [Objaverse](https://objaverse.allenai.org/), which might also potentially contain inappropriate content. To partially mitigate this, our demo applies a safety check to every uploaded image. 44 | 45 | ### Bias 46 | While the capabilities of image generation models are impressive, they can also reinforce or exacerbate social biases. 47 | Stable Diffusion was primarily trained on subsets of [LAION-2B(en)](https://laion.ai/blog/laion-5b/), which consists of images that are limited to English descriptions. 48 | Images and concepts from communities and cultures that use other languages are likely to be insufficiently accounted for. 49 | This affects the overall output of the model, as Western cultures are often overrepresented. 50 | Stable Diffusion mirrors and exacerbates biases to such a degree that viewer discretion must be advised irrespective of the input or its intent. 51 | 52 | 53 | ### Safety Module 54 | The intended use of this model is with the [Safety Checker](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/safety_checker.py) in Diffusers. 55 | This checker works by checking model inputs against known hard-coded NSFW concepts. 56 | Specifically, the checker compares the class probability of harmful concepts in the embedding space of the uploaded input images. 57 | The concepts are passed into the model with the image and compared to a hand-engineered weight for each NSFW concept. 58 | 59 | ## Citation 60 | ``` 61 | @misc{liu2023zero1to3, 62 | title={Zero-1-to-3: Zero-shot One Image to 3D Object}, 63 | author={Ruoshi Liu and Rundi Wu and Basile Van Hoorick and Pavel Tokmakov and Sergey Zakharov and Carl Vondrick}, 64 | year={2023}, 65 | eprint={2303.11328}, 66 | archivePrefix={arXiv}, 67 | primaryClass={cs.CV} 68 | } 69 | ``` 70 | --------------------------------------------------------------------------------