├── .gitignore ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── LICENSE.md ├── README.md ├── UI ├── .gitignore ├── README.md ├── app.py ├── assets │ └── screenshot.png ├── data │ ├── beach.jpg │ ├── cat_beach.jpg │ ├── dessert.jpg │ ├── dog_beach.jpg │ ├── lion_2.png │ ├── lizards.jpg │ ├── lizards_cropped.jpg │ ├── my_snow_ball.jpg │ ├── oranges.jpg │ ├── photo-1632085221727-e54b3302caf2.webp │ └── savana.jpg ├── pdm.lock ├── pyproject.toml └── tests │ └── __init__.py ├── configs ├── collage_composite_train.yaml ├── collage_flow_train.yaml └── collage_mix_train.yaml ├── data_processing ├── example_videos │ ├── getty-soccer-ball-jordan-video-id473239807_26.mp4 │ ├── getty-video-of-american-flags-being-sewn-together-at-flagsource-in-batavia-video-id804937470_87.mp4 │ ├── giphy-fgiT2cbsTxl8k_0.mp4 │ ├── giphy-gkvCpHRX9IqkM_3.mp4 │ ├── yt--4Fx5XUD-9Y_345.mp4 │ └── yt-mNdvtOO7UqY_15.mp4 ├── moments_dataset.py ├── moments_processing.py └── processing_utils.py ├── environment.yaml ├── examples ├── dog_beach__edit__003.png ├── dog_beach_og.png ├── fox_drinking__edit__01.png ├── fox_drinking__edit__02.png ├── fox_drinking_og.png ├── kingfisher__edit__001.png ├── kingfisher_og.png ├── log.csv ├── palm_tree__edit__01.png ├── palm_tree_og.png ├── pipes__edit__01.png └── pipes_og.png ├── ldm ├── __pycache__ │ └── util.cpython-38.pyc ├── data │ ├── __init__.py │ └── collage_dataset.py ├── lr_scheduler.py ├── models │ ├── __pycache__ │ │ └── autoencoder.cpython-38.pyc │ ├── autoencoder.py │ └── diffusion │ │ ├── __init__.py │ │ ├── __pycache__ │ │ ├── __init__.cpython-38.pyc │ │ ├── ddim.cpython-38.pyc │ │ └── ddpm.cpython-38.pyc │ │ ├── classifier.py │ │ ├── ddim.py │ │ ├── ddpm.py │ │ └── plms.py ├── modules │ ├── __pycache__ │ │ ├── attention.cpython-38.pyc │ │ ├── ema.cpython-38.pyc │ │ └── x_transformer.cpython-38.pyc │ ├── attention.py │ ├── diffusionmodules │ │ ├── __init__.py │ │ ├── __pycache__ │ │ │ ├── __init__.cpython-38.pyc │ │ │ ├── model.cpython-38.pyc │ │ │ ├── openaimodel.cpython-38.pyc │ │ │ └── util.cpython-38.pyc │ │ ├── model.py │ │ ├── openaimodel.py │ │ └── util.py │ ├── distributions │ │ ├── __init__.py │ │ ├── __pycache__ │ │ │ ├── __init__.cpython-38.pyc │ │ │ └── distributions.cpython-38.pyc │ │ └── distributions.py │ ├── ema.py │ ├── encoders │ │ ├── __init__.py │ │ ├── __pycache__ │ │ │ ├── __init__.cpython-38.pyc │ │ │ ├── modules.cpython-38.pyc │ │ │ └── xf.cpython-38.pyc │ │ ├── modules.py │ │ └── xf.py │ ├── losses │ │ ├── __init__.py │ │ ├── contperceptual.py │ │ └── vqperceptual.py │ └── x_transformer.py └── util.py ├── magicfu_gradio.py ├── main.py ├── run_magicfu.py ├── scripts ├── combine_model_params.py ├── inference.py └── modify_checkpoints.py ├── setup.py └── train.sh /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | *.pth 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | wheels/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | cover/ 54 | 55 | # Translations 56 | *.mo 57 | *.pot 58 | 59 | # Django stuff: 60 | *.log 61 | local_settings.py 62 | db.sqlite3 63 | db.sqlite3-journal 64 | 65 | # Flask stuff: 66 | instance/ 67 | .webassets-cache 68 | 69 | # Scrapy stuff: 70 | .scrapy 71 | 72 | # Sphinx documentation 73 | docs/_build/ 74 | 75 | # PyBuilder 76 | .pybuilder/ 77 | target/ 78 | 79 | # Jupyter Notebook 80 | .ipynb_checkpoints 81 | 82 | # IPython 83 | profile_default/ 84 | ipython_config.py 85 | 86 | # pyenv 87 | # For a library or package, you might want to ignore these files since the code is 88 | # intended to run in multiple environments; otherwise, check them in: 89 | # .python-version 90 | 91 | # pipenv 92 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 93 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 94 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 95 | # install all needed dependencies. 96 | #Pipfile.lock 97 | 98 | # poetry 99 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 100 | # This is especially recommended for binary packages to ensure reproducibility, and is more 101 | # commonly ignored for libraries. 102 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 103 | #poetry.lock 104 | 105 | # pdm 106 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 107 | #pdm.lock 108 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 109 | # in version control. 110 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 111 | .pdm.toml 112 | .pdm-python 113 | .pdm-build/ 114 | 115 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 116 | __pypackages__/ 117 | 118 | # Celery stuff 119 | celerybeat-schedule 120 | celerybeat.pid 121 | 122 | # SageMath parsed files 123 | *.sage.py 124 | 125 | # Environments 126 | .env 127 | .venv 128 | env/ 129 | venv/ 130 | ENV/ 131 | env.bak/ 132 | venv.bak/ 133 | 134 | # Spyder project settings 135 | .spyderproject 136 | .spyproject 137 | 138 | # Rope project settings 139 | .ropeproject 140 | 141 | # mkdocs documentation 142 | /site 143 | 144 | # mypy 145 | .mypy_cache/ 146 | .dmypy.json 147 | dmypy.json 148 | 149 | # Pyre type checker 150 | .pyre/ 151 | 152 | # pytype static type analyzer 153 | .pytype/ 154 | 155 | # Cython debug symbols 156 | cython_debug/ 157 | 158 | # PyCharm 159 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 160 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 161 | # and can be added to the global gitignore or merged into this file. For a more nuclear 162 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
163 | #.idea/ 164 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Adobe Code of Conduct 2 | 3 | ## Our Pledge 4 | 5 | We as members, contributors, and leaders pledge to make participation in our project and community a harassment-free experience for everyone, regardless of age, body size, visible or invisible disability, ethnicity, gender identity and expression, level of experience, education, socio-economic status, nationality, personal appearance, race, caste, color, religion, or sexual identity and orientation. 6 | 7 | We pledge to act and interact in ways that contribute to an open, welcoming, diverse, inclusive, and healthy community. 8 | 9 | ## Our Standards 10 | 11 | Examples of behavior that contribute to a positive environment for our project and community include: 12 | 13 | * Demonstrating empathy and kindness toward other people 14 | * Being respectful of differing opinions, viewpoints, and experiences 15 | * Giving and gracefully accepting constructive feedback 16 | * Accepting responsibility and apologizing to those affected by our mistakes, and learning from the experience 17 | * Focusing on what is best, not just for us as individuals but for the overall community 18 | 19 | Examples of unacceptable behavior include: 20 | 21 | * The use of sexualized language or imagery, and sexual attention or advances of any kind 22 | * Trolling, insulting or derogatory comments, and personal or political attacks 23 | * Public or private harassment 24 | * Publishing others’ private information, such as a physical or email address, without their explicit permission 25 | * Other conduct which could reasonably be considered inappropriate in a professional setting 26 | 27 | ## Our Responsibilities 28 | 29 | Project maintainers are responsible for clarifying and enforcing our standards of acceptable behavior and will take appropriate and fair corrective action in response to any instances of unacceptable behavior. 30 | 31 | Project maintainers have the right and responsibility to remove, edit, or reject comments, commits, code, wiki edits, issues, and other contributions that are not aligned to this Code of Conduct, or to ban temporarily or permanently any contributor for behaviors that they deem inappropriate, threatening, offensive, or harmful. 32 | 33 | ## Scope 34 | 35 | This Code of Conduct applies when an individual is representing the project or its community both within project spaces and in public spaces. Examples of representing a project or community include using an official e-mail address, posting via an official social media account, or acting as an appointed representative at an online or offline event. Representation of a project may be further defined and clarified by project maintainers. 36 | 37 | ## Enforcement 38 | 39 | Instances of abusive, harassing, or otherwise unacceptable behavior may be reported by first contacting the project team. Oversight of Adobe projects is handled by the Adobe Open Source Office, which has final say in any violations and enforcement of this Code of Conduct and can be reached at Grp-opensourceoffice@adobe.com. All complaints will be reviewed and investigated promptly and fairly. 40 | 41 | The project team must respect the privacy and security of the reporter of any incident. 
42 | 43 | Project maintainers who do not follow or enforce the Code of Conduct may face temporary or permanent repercussions as determined by other members of the project's leadership or the Adobe Open Source Office. 44 | 45 | ## Enforcement Guidelines 46 | 47 | Project maintainers will follow these Community Impact Guidelines in determining the consequences for any action they deem to be in violation of this Code of Conduct: 48 | 49 | **1. Correction** 50 | 51 | Community Impact: Use of inappropriate language or other behavior deemed unprofessional or unwelcome in the community. 52 | 53 | Consequence: A private, written warning from project maintainers describing the violation and why the behavior was unacceptable. A public apology may be requested from the violator before any further involvement in the project by violator. 54 | 55 | **2. Warning** 56 | 57 | Community Impact: A relatively minor violation through a single incident or series of actions. 58 | 59 | Consequence: A written warning from project maintainers that includes stated consequences for continued unacceptable behavior. Violator must refrain from interacting with the people involved for a specified period of time as determined by the project maintainers, including, but not limited to, unsolicited interaction with those enforcing the Code of Conduct through channels such as community spaces and social media. Continued violations may lead to a temporary or permanent ban. 60 | 61 | **3. Temporary Ban** 62 | 63 | Community Impact: A more serious violation of community standards, including sustained unacceptable behavior. 64 | 65 | Consequence: A temporary ban from any interaction or public communication with the community for a specified period of time. No public or private interaction with the people involved, including unsolicited interaction with those enforcing the Code of Conduct, is allowed during this period. Failure to comply with the temporary ban may lead to a permanent ban. 66 | 67 | **4. Permanent Ban** 68 | 69 | Community Impact: Demonstrating a consistent pattern of violation of community standards or an egregious violation of community standards, including, but not limited to, sustained inappropriate behavior, harassment of an individual, or aggression toward or disparagement of classes of individuals. 70 | 71 | Consequence: A permanent ban from any interaction with the community. 72 | 73 | ## Attribution 74 | 75 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 2.1, 76 | available at [https://contributor-covenant.org/version/2/1][version] 77 | 78 | [homepage]: https://contributor-covenant.org 79 | [version]: https://contributor-covenant.org/version/2/1 80 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing 2 | 3 | Thanks for choosing to contribute! 4 | 5 | The following are a set of guidelines to follow when contributing to this project. 6 | 7 | ## Code Of Conduct 8 | 9 | This project adheres to the Adobe [code of conduct](./CODE_OF_CONDUCT.md). By participating, 10 | you are expected to uphold this code. Please report unacceptable behavior to 11 | [Grp-opensourceoffice@adobe.com](mailto:Grp-opensourceoffice@adobe.com). 12 | 13 | ## Have A Question? 14 | 15 | Start by filing an issue. 
The existing committers on this project work to reach 16 | consensus around project direction and issue solutions within issue threads 17 | (when appropriate). 18 | 19 | ## Contributor License Agreement 20 | 21 | All third-party contributions to this project must be accompanied by a signed contributor 22 | license agreement. This gives Adobe permission to redistribute your contributions 23 | as part of the project. [Sign our CLA](https://opensource.adobe.com/cla.html). You 24 | only need to submit an Adobe CLA one time, so if you have submitted one previously, 25 | you are good to go! 26 | 27 | ## Code Reviews 28 | 29 | All submissions should come in the form of pull requests and need to be reviewed 30 | by project committers. Read [GitHub's pull request documentation](https://help.github.com/articles/about-pull-requests/) 31 | for more information on sending pull requests. 32 | 33 | Lastly, please follow the [pull request template](PULL_REQUEST_TEMPLATE.md) when 34 | submitting a pull request! 35 | 36 | ## From Contributor To Committer 37 | 38 | We love contributions from our community! If you'd like to go a step beyond contributor 39 | and become a committer with full write access and a say in the project, you must 40 | be invited to the project. The existing committers employ an internal nomination 41 | process that must reach lazy consensus (silence is approval) before invitations 42 | are issued. If you feel you are qualified and want to get more deeply involved, 43 | feel free to reach out to existing committers to have a conversation about that. 44 | 45 | ## Security Issues 46 | 47 | Security issues shouldn't be reported on this issue tracker. Instead, [file an issue to our security experts](https://helpx.adobe.com/security/alertus.html). 48 | -------------------------------------------------------------------------------- /LICENSE.md: -------------------------------------------------------------------------------- 1 | Copyright 2024, Adobe Inc. and its licensors. All rights reserved. 2 | 3 | ADOBE RESEARCH LICENSE 4 | 5 | Adobe grants any person or entity ("you" or "your") obtaining a copy of these certain research materials that are owned by Adobe ("Licensed Materials") a nonexclusive, worldwide, royalty-free, revocable, fully paid license to (A) reproduce, use, modify, and publicly display the Licensed Materials; and (B) redistribute the Licensed Materials, and modifications or derivative works thereof, provided the following conditions are met: 6 | 7 | The rights granted herein may be exercised for noncommercial research purposes (i.e., academic research and teaching) only. Noncommercial research purposes do not include commercial licensing or distribution, development of commercial products, or any other activity that results in commercial gain. 8 | You may add your own copyright statement to your modifications and/or provide additional or different license terms for use, reproduction, modification, public display, and redistribution of your modifications and derivative works, provided that such license terms limit the use, reproduction, modification, public display, and redistribution of such modifications and derivative works to noncommercial research purposes only. 9 | You acknowledge that Adobe and its licensors own all right, title, and interest in the Licensed Materials. 10 | All copies of the Licensed Materials must include the above copyright notice, this list of conditions, and the disclaimer below. 
11 | Failure to meet any of the above conditions will automatically terminate the rights granted herein. 12 | 13 | THE LICENSED MATERIALS ARE PROVIDED "AS IS" WITHOUT WARRANTY OF ANY KIND. THE ENTIRE RISK AS TO THE USE, RESULTS, AND PERFORMANCE OF THE LICENSED MATERIALS IS ASSUMED BY YOU. ADOBE DISCLAIMS ALL WARRANTIES, EXPRESS, IMPLIED OR STATUTORY, WITH REGARD TO YOUR USE OF THE LICENSED MATERIALS, INCLUDING, BUT NOT LIMITED TO, NONINFRINGEMENT OF THIRD-PARTY RIGHTS. IN NO EVENT WILL ADOBE BE LIABLE FOR ANY ACTUAL, INCIDENTAL, SPECIAL OR CONSEQUENTIAL DAMAGES, INCLUDING WITHOUT LIMITATION, LOSS OF PROFITS OR OTHER COMMERCIAL LOSS, ARISING OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THE LICENSED MATERIALS, EVEN IF ADOBE HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGES. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # MagicFixup 2 | This is the repo for the paper [Magic Fixup: Streamlining Photo Editing by Watching Dynamic Videos](https://magic-fixup.github.io) 3 | 4 | **NEW: Released the User Interface!** 5 | 6 | ## Installation 7 | We provide an `environment.yaml` file to assist with installation. All you need to do for setup is run the following script 8 | ``` 9 | conda env create -f environment.yaml -v 10 | ``` 11 | and this will create a conda environment that you can activate using `conda activate MagicFixup` 12 | 13 | ## Inference 14 | 15 | #### Downloading Magic Fixup checkpoint 16 | You can download the model trained on the Moments in Time dataset using this [Google Drive Link](https://drive.google.com/file/d/1zOcDcJzCijbGr9I9adC0Cv6yzW60U9TQ/view?usp=share_link) or from [HuggingFace](https://huggingface.co/HadiZayer/MagicFixup) 17 | 18 | 19 | ### Inference script 20 | The inference script is `run_magicfu.py`. It takes the path of the reference image (the original image) and the edited image. Note that it assumes that the alpha channel is set appropriately in the edited image PNG, as we use the alpha channel to set the disocclusion mask. You can run the inference script with 21 | 22 | ``` 23 | python run_magicfu.py --checkpoint <path to Magic Fixup checkpoint> --reference <path to original image> --edit <path to edited image> 24 | ``` 25 | 26 | ### gradio demo 27 | We have a gradio demo that allows you to test out your inputs with a friendly user interface. Simply start the demo with 28 | ``` 29 | python magicfu_gradio.py --checkpoint <path to Magic Fixup checkpoint> 30 | ``` 31 | 32 | 33 | ## Training your own model 34 | To train your own model, you first need to process a video dataset, then train the model using the processed pairs from your videos. We used the Moments in Time dataset to train the weights provided above. 35 | 36 | #### Pretrained SD1.4 diffusion model 37 | We start training from the official SD1.4 model (with the first layer modified to take our 9 channel input). You can download the official SD1.4 model, modify the first layer using `scripts/modify_checkpoints.py`, and place it under the `pretrained_models` folder. 38 | 39 | ### Data Processing 40 | The data processing code can be found under the `data_processing` folder. You can simply put all the videos in a directory, and pass that directory as the folder name in `data_processing/moments_processing.py`. If your videos are long (e.g., more than 5 seconds) and contain cut scenes, you will want to use PySceneDetect to detect the cut scenes and split the videos accordingly.
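For example, with PySceneDetect installed (`pip install scenedetect[opencv]`, v0.6 or later) and `ffmpeg` available on your PATH, a command along the following lines (the file name is illustrative) detects content cuts and writes one clip per scene, which you can then place in the folder passed to `moments_processing.py`:
```
scenedetect -i my_long_video.mp4 detect-content split-video
```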
41 | For data processing, you also need to download the checkpoint for Segment Anything and install softmax-splatting. You can set up softmax-splatting and SAM by running 42 | ``` 43 | cd data_processing 44 | git clone https://github.com/sniklaus/softmax-splatting.git 45 | pip install segment_anything 46 | mkdir -p sam_model && cd sam_model 47 | wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth 48 | ``` 49 | For softmax-splatting to run, you need to install cupy with `pip install cupy` (you might need `pip install cupy-cuda11x` or `pip install cupy-cuda12x` depending on your CUDA version, and to load the appropriate CUDA module) 50 | 51 | Then run `python moments_processing.py` to start processing frames from the provided example videos (included under `data_processing/example_videos`). For the version provided, we used the [Moments in Time Dataset](http://moments.csail.mit.edu). 52 | 53 | ### Running the training script 54 | Make sure that you have downloaded the pretrained SD1.4 model above. Once you have downloaded the training dataset and pretrained model, you can simply start training the model with 55 | ``` 56 | ./train.sh 57 | ``` 58 | The training code is in `main.py`, and relies mainly on PyTorch Lightning for training. 59 | 60 | Note that you need to modify the train and val paths in the chosen config file to the location where you have the processed data. 61 | 62 | Note: we use DeepSpeed to lower the memory requirements, so the saved model weights will be sharded. The script to reconstruct the model weights, named `zero_to_fp32.py`, will be created in the checkpoint directory. One bug in that file is that it does not recognize checkpoints saved with DeepSpeed stage 1 (which is the one we use), so simply find and replace the string `== 2` with the string `<= 2` and it will work. 63 | 64 | ### Saving the Full Model Weights 65 | To reduce storage requirements, we only checkpoint the learnable parameters in training (i.e., the frozen autoencoder params are not saved). To create a checkpoint that contains all the parameters, you can combine the frozen pretrained weights and learned parameters by running 66 | 67 | ``` 68 | python combine_model_params.py --pretrained_sd <path to pretrained SD1.4> --learned_params <path to learned params checkpoint> --save_path <output path> 69 | ``` 70 | 71 | ## Editing UI 72 | To make your edits easier, we have released our segmentation-based UI. See the [UI folder](https://github.com/adobe-research/MagicFixup/tree/main/UI) for instructions on how to use it and set it up.
73 | 74 | ## Bibtex 75 | If you find our work useful, please consider citing it in your work 76 | 77 | ``` 78 | @misc{alzayer2024magicfixup, 79 | title={Magic Fixup: Streamlining Photo Editing by Watching Dynamic Videos}, 80 | author={Hadi Alzayer and Zhihao Xia and Xuaner Zhang and Eli Shechtman and Jia-Bin Huang and Michael Gharbi}, 81 | year={2024}, 82 | eprint={2403.13044}, 83 | archivePrefix={arXiv}, 84 | primaryClass={cs.CV}, 85 | url={https://arxiv.org/abs/2403.13044}, 86 | } 87 | ``` 88 | 89 | ##### Acknowledgement 90 | The diffusion code was built on top of the codebase from [PaintByExample](https://github.com/Fantasy-Studio/Paint-by-Example). -------------------------------------------------------------------------------- /UI/.gitignore: -------------------------------------------------------------------------------- 1 | .vscode 2 | .DS_Store 3 | .pdm-python 4 | output 5 | data/sam_vit_l_0b3195.pth 6 | __pycache__ 7 | data/sam_vit_h_4b8939.pth 8 | -------------------------------------------------------------------------------- /UI/README.md: -------------------------------------------------------------------------------- 1 | # Interactive UI for collage2photo 2 | 3 | ![Screenshot](assets/screenshot.png) 4 | 5 | How-to: 6 | 7 | - Load an image of your choice. Note that MagicFixup is designed to take square images as input. 8 | - Click an object to segment it. 9 | - Drag the segment around to perform your edit. 10 | - Click the "Process..." button to run the model and produce the cleaned-up output. 11 | 12 | We use [PDM](https://pdm-project.org/latest/) as the dependency manager. 13 | 14 | Install with: 15 | 16 | ```shell 17 | pdm install 18 | ``` 19 | 20 | Alternatively, you can use your package manager of choice and install pyglet, opencv-python, and [segment anything](https://github.com/facebookresearch/segment-anything). 21 | 22 | Download the Segment Anything [checkpoint](https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth) and place it under the `data` directory. If needed, you can use a lighter-weight checkpoint for faster inference.
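For example, from inside the `UI` folder you can fetch the ViT-H checkpoint linked above directly into `data/` (the `data/sam_vit_h_4b8939.pth` path also listed in `UI/.gitignore`):
```shell
wget -P data/ https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth
```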
23 | 24 | Then run locally using: 25 | 26 | ```shell 27 | python app.py 28 | ``` 29 | -------------------------------------------------------------------------------- /UI/assets/screenshot.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/adobe-research/MagicFixup/47508c136b0f2427256b5fe7f13a9bbbabc07361/UI/assets/screenshot.png -------------------------------------------------------------------------------- /UI/data/beach.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/adobe-research/MagicFixup/47508c136b0f2427256b5fe7f13a9bbbabc07361/UI/data/beach.jpg -------------------------------------------------------------------------------- /UI/data/cat_beach.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/adobe-research/MagicFixup/47508c136b0f2427256b5fe7f13a9bbbabc07361/UI/data/cat_beach.jpg -------------------------------------------------------------------------------- /UI/data/dessert.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/adobe-research/MagicFixup/47508c136b0f2427256b5fe7f13a9bbbabc07361/UI/data/dessert.jpg -------------------------------------------------------------------------------- /UI/data/dog_beach.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/adobe-research/MagicFixup/47508c136b0f2427256b5fe7f13a9bbbabc07361/UI/data/dog_beach.jpg -------------------------------------------------------------------------------- /UI/data/lion_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/adobe-research/MagicFixup/47508c136b0f2427256b5fe7f13a9bbbabc07361/UI/data/lion_2.png -------------------------------------------------------------------------------- /UI/data/lizards.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/adobe-research/MagicFixup/47508c136b0f2427256b5fe7f13a9bbbabc07361/UI/data/lizards.jpg -------------------------------------------------------------------------------- /UI/data/lizards_cropped.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/adobe-research/MagicFixup/47508c136b0f2427256b5fe7f13a9bbbabc07361/UI/data/lizards_cropped.jpg -------------------------------------------------------------------------------- /UI/data/my_snow_ball.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/adobe-research/MagicFixup/47508c136b0f2427256b5fe7f13a9bbbabc07361/UI/data/my_snow_ball.jpg -------------------------------------------------------------------------------- /UI/data/oranges.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/adobe-research/MagicFixup/47508c136b0f2427256b5fe7f13a9bbbabc07361/UI/data/oranges.jpg -------------------------------------------------------------------------------- /UI/data/photo-1632085221727-e54b3302caf2.webp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/adobe-research/MagicFixup/47508c136b0f2427256b5fe7f13a9bbbabc07361/UI/data/photo-1632085221727-e54b3302caf2.webp 
-------------------------------------------------------------------------------- /UI/data/savana.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/adobe-research/MagicFixup/47508c136b0f2427256b5fe7f13a9bbbabc07361/UI/data/savana.jpg -------------------------------------------------------------------------------- /UI/pyproject.toml: -------------------------------------------------------------------------------- 1 | 2 | [build-system] 3 | build-backend = "pdm.backend" 4 | requires = ["pdm-backend"] 5 | 6 | [[tool.mypy.overrides]] 7 | ignore_missing_imports = true 8 | module = [ 9 | 'PIL.*', 10 | 'pyglet.gl.*', 11 | ] 12 | 13 | [tool.pdm.dev-dependencies] 14 | dev = [ 15 | "mypy>=1.6.1", 16 | ] 17 | [project] 18 | authors = [ 19 | {name = "Michael Gharbi", email = "mgharbi@adobe.com"}, 20 | ] 21 | dependencies = [ 22 | "pillow<11.0.0,>=10.1.0", 23 | "pysimplegui<5.0.0,>=4.60.5", 24 | "tkmacosx<2.0.0,>=1.0.5", 25 | "segment-anything @ git+https://github.com/facebookresearch/segment-anything.git", 26 | "torch>=2.1.0", 27 | "numpy>=1.25.2", 28 | "torchvision>=0.16.0", 29 | "setuptools>=68.2.2", 30 | "nicegui>=1.4.1", 31 | "scipy>=1.9.3", 32 | "imageio>=2.31.5", 33 | "pygame>=2.5.2", 34 | "pygame-gui>=0.6.9", 35 | "pyglet>=2.0.9", 36 | ] 37 | description = "" 38 | license = {text = "MIT"} 39 | name = "hadi-app" 40 | readme = "README.md" 41 | requires-python = ">=3.10,<4.0" 42 | version = "0.1.0" 43 | 44 | [project.group.dev.dependencies] 45 | mypy = "^1.6.1" 46 | -------------------------------------------------------------------------------- /UI/tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/adobe-research/MagicFixup/47508c136b0f2427256b5fe7f13a9bbbabc07361/UI/tests/__init__.py -------------------------------------------------------------------------------- /configs/collage_composite_train.yaml: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Adobe. All rights reserved. 2 | model: 3 | base_learning_rate: 1.0e-05 4 | target: ldm.models.diffusion.ddpm.LatentDiffusion 5 | params: 6 | linear_start: 0.00085 7 | linear_end: 0.0120 8 | num_timesteps_cond: 1 9 | log_every_t: 200 10 | timesteps: 1000 11 | first_stage_key: "inpaint" 12 | cond_stage_key: "image" 13 | image_size: 64 14 | channels: 4 15 | cond_stage_trainable: true # Note: different from the one we trained before 16 | conditioning_key: "rewarp" 17 | monitor: val/loss_simple_ema 18 | u_cond_percent: 0.2 19 | scale_factor: 0.18215 20 | use_ema: False 21 | context_embedding_dim: 768 # TODO embedding # 1024 clip, DINO: 'small': 384,'big': 768,'large': 1024,'huge': 1536 22 | 23 | 24 | scheduler_config: # 10000 warmup steps 25 | target: ldm.lr_scheduler.LambdaLinearScheduler 26 | params: 27 | warm_up_steps: [ 10000 ] 28 | cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases 29 | f_start: [ 1.e-6 ] 30 | f_max: [ 1. ] 31 | f_min: [ 1. 
] 32 | 33 | unet_config: 34 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 35 | params: 36 | image_size: 32 # unused 37 | in_channels: 9 38 | out_channels: 4 39 | model_channels: 320 40 | attention_resolutions: [ 4, 2, 1 ] 41 | num_res_blocks: 2 42 | channel_mult: [ 1, 2, 4, 4 ] 43 | num_heads: 8 44 | use_spatial_transformer: True 45 | transformer_depth: 1 46 | context_dim: 768 47 | use_checkpoint: True 48 | legacy: False 49 | add_conv_in_front_of_unet: False 50 | 51 | first_stage_config: 52 | target: ldm.models.autoencoder.AutoencoderKL 53 | params: 54 | embed_dim: 4 55 | monitor: val/rec_loss 56 | ddconfig: 57 | double_z: true 58 | z_channels: 4 59 | resolution: 256 60 | in_channels: 3 61 | out_ch: 3 62 | ch: 128 63 | ch_mult: 64 | - 1 65 | - 2 66 | - 4 67 | - 4 68 | num_res_blocks: 2 69 | attn_resolutions: [] 70 | dropout: 0.0 71 | lossconfig: 72 | target: torch.nn.Identity 73 | 74 | cond_stage_config: 75 | target: ldm.modules.encoders.modules.DINOEmbedder # TODO embedding 76 | params: 77 | dino_version: "big" # [small, big, large, huge] 78 | 79 | data: 80 | target: main.DataModuleFromConfig 81 | params: 82 | batch_size: 2 83 | num_workers: 8 84 | use_worker_init_fn: False 85 | wrap: False 86 | train: 87 | target: ldm.data.collage_dataset.CollageDataset 88 | params: 89 | split_files: "" 90 | image_size: 512 91 | embedding_type: 'dino' # TODO embedding 92 | warping_type: 'collage' 93 | validation: 94 | target: ldm.data.collage_dataset.CollageDataset 95 | params: 96 | split_files: "" 97 | image_size: 512 98 | embedding_type: 'dino' # TODO embedding 99 | warping_type: 'mix' 100 | test: 101 | target: ldm.data.collage_dataset.CollageDataset 102 | params: 103 | split_files: "" 104 | image_size: 512 105 | embedding_type: 'dino' # TODO embedding 106 | warping_type: 'mix' 107 | 108 | lightning: 109 | trainer: 110 | max_epochs: 500 111 | num_nodes: 1 112 | num_sanity_val_steps: 0 113 | accelerator: 'gpu' 114 | gpus: "0,1,2,3,4,5,6,7" 115 | -------------------------------------------------------------------------------- /configs/collage_flow_train.yaml: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Adobe. All rights reserved. 2 | model: 3 | base_learning_rate: 1.0e-05 4 | target: ldm.models.diffusion.ddpm.LatentDiffusion 5 | params: 6 | linear_start: 0.00085 7 | linear_end: 0.0120 8 | num_timesteps_cond: 1 9 | log_every_t: 200 10 | timesteps: 1000 11 | first_stage_key: "inpaint" 12 | cond_stage_key: "image" 13 | image_size: 64 14 | channels: 4 15 | cond_stage_trainable: true # Note: different from the one we trained before 16 | conditioning_key: "rewarp" 17 | monitor: val/loss_simple_ema 18 | u_cond_percent: 0.2 19 | scale_factor: 0.18215 20 | use_ema: False 21 | context_embedding_dim: 768 # TODO embedding # 1024 clip, DINO: 'small': 384,'big': 768,'large': 1024,'huge': 1536 22 | 23 | 24 | scheduler_config: # 10000 warmup steps 25 | target: ldm.lr_scheduler.LambdaLinearScheduler 26 | params: 27 | warm_up_steps: [ 10000 ] 28 | cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases 29 | f_start: [ 1.e-6 ] 30 | f_max: [ 1. ] 31 | f_min: [ 1. 
] 32 | 33 | unet_config: 34 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 35 | params: 36 | image_size: 32 # unused 37 | in_channels: 9 38 | out_channels: 4 39 | model_channels: 320 40 | attention_resolutions: [ 4, 2, 1 ] 41 | num_res_blocks: 2 42 | channel_mult: [ 1, 2, 4, 4 ] 43 | num_heads: 8 44 | use_spatial_transformer: True 45 | transformer_depth: 1 46 | context_dim: 768 47 | use_checkpoint: True 48 | legacy: False 49 | add_conv_in_front_of_unet: False 50 | 51 | first_stage_config: 52 | target: ldm.models.autoencoder.AutoencoderKL 53 | params: 54 | embed_dim: 4 55 | monitor: val/rec_loss 56 | ddconfig: 57 | double_z: true 58 | z_channels: 4 59 | resolution: 256 60 | in_channels: 3 61 | out_ch: 3 62 | ch: 128 63 | ch_mult: 64 | - 1 65 | - 2 66 | - 4 67 | - 4 68 | num_res_blocks: 2 69 | attn_resolutions: [] 70 | dropout: 0.0 71 | lossconfig: 72 | target: torch.nn.Identity 73 | 74 | cond_stage_config: 75 | target: ldm.modules.encoders.modules.DINOEmbedder # TODO embedding 76 | params: 77 | dino_version: "big" # [small, big, large, huge] 78 | 79 | data: 80 | target: main.DataModuleFromConfig 81 | params: 82 | batch_size: 2 83 | num_workers: 8 84 | use_worker_init_fn: False 85 | wrap: False 86 | train: 87 | target: ldm.data.collage_dataset.CollageDataset 88 | params: 89 | split_files: /mnt/localssd/new_train 90 | image_size: 512 91 | embedding_type: 'dino' # TODO embedding 92 | warping_type: 'flow' 93 | validation: 94 | target: ldm.data.collage_dataset.CollageDataset 95 | params: 96 | split_files: /mnt/localssd/new_val 97 | image_size: 512 98 | embedding_type: 'dino' # TODO embedding 99 | warping_type: 'mix' 100 | test: 101 | target: ldm.data.collage_dataset.CollageDataset 102 | params: 103 | split_files: /mnt/localssd/new_val 104 | image_size: 512 105 | embedding_type: 'dino' # TODO embedding 106 | warping_type: 'mix' 107 | 108 | lightning: 109 | trainer: 110 | max_epochs: 500 111 | num_nodes: 1 112 | num_sanity_val_steps: 0 113 | accelerator: 'gpu' 114 | gpus: "0,1,2,3,4,5,6,7" 115 | -------------------------------------------------------------------------------- /configs/collage_mix_train.yaml: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Adobe. All rights reserved. 2 | model: 3 | base_learning_rate: 1.0e-05 4 | target: ldm.models.diffusion.ddpm.LatentDiffusion 5 | params: 6 | linear_start: 0.00085 7 | linear_end: 0.0120 8 | num_timesteps_cond: 1 9 | log_every_t: 200 10 | timesteps: 1000 11 | first_stage_key: "inpaint" 12 | cond_stage_key: "image" 13 | image_size: 64 14 | channels: 4 15 | cond_stage_trainable: true # Note: different from the one we trained before 16 | conditioning_key: "rewarp" 17 | monitor: val/loss_simple_ema 18 | u_cond_percent: 0.2 19 | scale_factor: 0.18215 20 | use_ema: False 21 | context_embedding_dim: 384 # TODO embedding # 1024 clip, DINO: 'small': 384,'big': 768,'large': 1024,'huge': 1536 22 | dropping_warped_latent_prob: 0.2 23 | 24 | 25 | scheduler_config: # 10000 warmup steps 26 | target: ldm.lr_scheduler.LambdaLinearScheduler 27 | params: 28 | warm_up_steps: [ 10000 ] 29 | cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases 30 | f_start: [ 1.e-6 ] 31 | f_max: [ 1. ] 32 | f_min: [ 1. 
] 33 | 34 | unet_config: 35 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 36 | params: 37 | image_size: 32 # unused 38 | in_channels: 9 39 | out_channels: 4 40 | model_channels: 320 41 | attention_resolutions: [ 4, 2, 1 ] 42 | num_res_blocks: 2 43 | channel_mult: [ 1, 2, 4, 4 ] 44 | num_heads: 8 45 | use_spatial_transformer: True 46 | transformer_depth: 1 47 | context_dim: 768 48 | use_checkpoint: True 49 | legacy: False 50 | add_conv_in_front_of_unet: False 51 | 52 | first_stage_config: 53 | target: ldm.models.autoencoder.AutoencoderKL 54 | params: 55 | embed_dim: 4 56 | monitor: val/rec_loss 57 | ddconfig: 58 | double_z: true 59 | z_channels: 4 60 | resolution: 256 61 | in_channels: 3 62 | out_ch: 3 63 | ch: 128 64 | ch_mult: 65 | - 1 66 | - 2 67 | - 4 68 | - 4 69 | num_res_blocks: 2 70 | attn_resolutions: [] 71 | dropout: 0.0 72 | lossconfig: 73 | target: torch.nn.Identity 74 | 75 | cond_stage_config: 76 | target: ldm.modules.encoders.modules.DINOEmbedder # TODO embedding 77 | params: 78 | dino_version: "small" # [small, big, large, huge] 79 | 80 | data: 81 | target: main.DataModuleFromConfig 82 | params: 83 | batch_size: 4 84 | num_workers: 8 85 | use_worker_init_fn: False 86 | wrap: False 87 | train: 88 | target: ldm.data.collage_dataset.CollageDataset 89 | params: 90 | split_files: /mnt/localssd/new_train 91 | image_size: 512 92 | embedding_type: 'dino' # TODO embedding 93 | warping_type: 'mix' 94 | validation: 95 | target: ldm.data.collage_dataset.CollageDataset 96 | params: 97 | split_files: /mnt/localssd/new_val 98 | image_size: 512 99 | embedding_type: 'dino' # TODO embedding 100 | warping_type: 'mix' 101 | test: 102 | target: ldm.data.collage_dataset.CollageDataset 103 | params: 104 | split_files: /mnt/localssd/new_val 105 | image_size: 512 106 | embedding_type: 'dino' # TODO embedding 107 | warping_type: 'mix' 108 | 109 | lightning: 110 | trainer: 111 | max_epochs: 500 112 | num_nodes: 1 113 | num_sanity_val_steps: 0 114 | accelerator: 'gpu' 115 | gpus: "0,1,2,3,4,5,6,7" 116 | -------------------------------------------------------------------------------- /data_processing/example_videos/getty-soccer-ball-jordan-video-id473239807_26.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/adobe-research/MagicFixup/47508c136b0f2427256b5fe7f13a9bbbabc07361/data_processing/example_videos/getty-soccer-ball-jordan-video-id473239807_26.mp4 -------------------------------------------------------------------------------- /data_processing/example_videos/getty-video-of-american-flags-being-sewn-together-at-flagsource-in-batavia-video-id804937470_87.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/adobe-research/MagicFixup/47508c136b0f2427256b5fe7f13a9bbbabc07361/data_processing/example_videos/getty-video-of-american-flags-being-sewn-together-at-flagsource-in-batavia-video-id804937470_87.mp4 -------------------------------------------------------------------------------- /data_processing/example_videos/giphy-fgiT2cbsTxl8k_0.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/adobe-research/MagicFixup/47508c136b0f2427256b5fe7f13a9bbbabc07361/data_processing/example_videos/giphy-fgiT2cbsTxl8k_0.mp4 -------------------------------------------------------------------------------- /data_processing/example_videos/giphy-gkvCpHRX9IqkM_3.mp4: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/adobe-research/MagicFixup/47508c136b0f2427256b5fe7f13a9bbbabc07361/data_processing/example_videos/giphy-gkvCpHRX9IqkM_3.mp4 -------------------------------------------------------------------------------- /data_processing/example_videos/yt--4Fx5XUD-9Y_345.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/adobe-research/MagicFixup/47508c136b0f2427256b5fe7f13a9bbbabc07361/data_processing/example_videos/yt--4Fx5XUD-9Y_345.mp4 -------------------------------------------------------------------------------- /data_processing/example_videos/yt-mNdvtOO7UqY_15.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/adobe-research/MagicFixup/47508c136b0f2427256b5fe7f13a9bbbabc07361/data_processing/example_videos/yt-mNdvtOO7UqY_15.mp4 -------------------------------------------------------------------------------- /data_processing/moments_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Adobe. All rights reserved. 2 | 3 | #%% 4 | import glob 5 | import torch 6 | import torchvision 7 | import matplotlib.pyplot as plt 8 | from torch.utils.data import Dataset 9 | import numpy as np 10 | 11 | 12 | # %% 13 | class MomentsDataset(Dataset): 14 | def __init__(self, videos_folder, num_frames, samples_per_video, frame_size=512) -> None: 15 | super().__init__() 16 | 17 | self.videos_paths = glob.glob(f'{videos_folder}/*mp4') 18 | self.resize = torchvision.transforms.Resize(size=frame_size) 19 | self.center_crop = torchvision.transforms.CenterCrop(size=frame_size) 20 | self.num_samples_per_video = samples_per_video 21 | self.num_frames = num_frames 22 | 23 | def __len__(self): 24 | return len(self.videos_paths) * self.num_samples_per_video 25 | 26 | def __getitem__(self, idx): 27 | video_idx = idx // self.num_samples_per_video 28 | video_path = self.videos_paths[video_idx] 29 | 30 | try: 31 | start_idx = np.random.randint(0, 20) 32 | 33 | unsampled_video_frames, audio_frames, info = torchvision.io.read_video(video_path,output_format="TCHW") 34 | sampled_indices = torch.tensor(np.linspace(start_idx, len(unsampled_video_frames)-1, self.num_frames).astype(int)) 35 | sampled_frames = unsampled_video_frames[sampled_indices] 36 | processed_frames = [] 37 | 38 | for frame in sampled_frames: 39 | resized_cropped_frame = self.center_crop(self.resize(frame)) 40 | processed_frames.append(resized_cropped_frame) 41 | frames = torch.stack(processed_frames, dim=0) 42 | frames = frames.float() / 255.0 43 | except Exception as e: 44 | print('oops', e) 45 | rand_idx = np.random.randint(0, len(self)) 46 | return self.__getitem__(rand_idx) 47 | 48 | out_dict = {'frames': frames, 49 | 'caption': 'none', 50 | 'keywords': 'none'} 51 | 52 | return out_dict 53 | 54 | 55 | -------------------------------------------------------------------------------- /data_processing/moments_processing.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Adobe. All rights reserved. 
2 | 3 | #%% 4 | from torchvision.transforms import ToPILImage 5 | import torch 6 | import pandas as pd 7 | import matplotlib.pyplot as plt 8 | import numpy as np 9 | import torchvision 10 | import cv2 11 | import tqdm 12 | import matplotlib.pyplot as plt 13 | import torchvision.transforms.functional as F 14 | from PIL import Image 15 | from torchvision.utils import save_image 16 | import time 17 | import os 18 | import sys 19 | import pathlib 20 | from torchvision.utils import flow_to_image 21 | from torch.utils.data import DataLoader 22 | from einops import rearrange 23 | # %matplotlib inline 24 | from kornia.filters.median import MedianBlur 25 | median_filter = MedianBlur(kernel_size=(15,15)) 26 | from moments_dataset import MomentsDataset 27 | 28 | try: 29 | from processing_utils import aggregate_frames 30 | import processing_utils 31 | except Exception as e: 32 | print(e) 33 | print('process failed') 34 | exit() 35 | 36 | 37 | 38 | 39 | import pytorch_lightning as pl 40 | import torch 41 | from omegaconf import OmegaConf 42 | 43 | # %% 44 | 45 | def load_image(img_path, resize_size=None,crop_size=None): 46 | 47 | img1_pil = Image.open(img_path) 48 | img1_frames = torchvision.transforms.functional.pil_to_tensor(img1_pil) 49 | 50 | if resize_size: 51 | img1_frames = torchvision.transforms.functional.resize(img1_frames, resize_size) 52 | 53 | if crop_size: 54 | img1_frames = torchvision.transforms.functional.center_crop(img1_frames, crop_size) 55 | 56 | img1_batch = torch.unsqueeze(img1_frames, dim=0) 57 | 58 | return img1_batch 59 | 60 | def get_grid(size): 61 | y = np.repeat(np.arange(size)[None, ...], size) 62 | y = y.reshape(size, size) 63 | x = y.transpose() 64 | out = np.stack([y,x], -1) 65 | return out 66 | 67 | def collage_from_frames(frames_t): 68 | # decide forward or backward 69 | if np.random.randint(0, 2) == 0: 70 | # flip 71 | frames_t = frames_t.flip(0) 72 | 73 | # decide how deep you would go 74 | tgt_idx_guess = np.random.randint(1, min(len(frames_t), 20)) 75 | tgt_idx = 1 76 | pairwise_flows = [] 77 | flow = None 78 | init_time = time.time() 79 | unsmoothed_agg = None 80 | for cur_idx in range(1, tgt_idx_guess+1): 81 | # cur_idx = i+1 82 | cur_flow, pairwise_flows = aggregate_frames(frames_t[:cur_idx+1] , pairwise_flows, unsmoothed_agg) # passing pairwise flows for efficiency 83 | unsmoothed_agg = cur_flow.clone() 84 | agg_cur_flow = median_filter(cur_flow) 85 | 86 | flow_norm = torch.norm(agg_cur_flow.squeeze(), dim=0).flatten() 87 | # flow_10 = np.percentile(flow_norm.cpu().numpy(), 10) 88 | flow_90 = np.percentile(flow_norm.cpu().numpy(), 90) 89 | 90 | # flow_10 = np.percentile(flow_norm.cpu().numpy(), 10) 91 | flow_90 = np.percentile(flow_norm.cpu().numpy(), 90) 92 | flow_95 = np.percentile(flow_norm.cpu().numpy(), 95) 93 | 94 | if cur_idx == 5: # if still small flow then drop 95 | if flow_95 < 20.0: 96 | # no motion in the frame. skip 97 | print('flow is tiny :(') 98 | return None 99 | 100 | if cur_idx == tgt_idx_guess-1: # if still small flow then drop 101 | if flow_95 < 50.0: 102 | # no motion in the frame. skip 103 | print('flow is tiny :(') 104 | return None 105 | 106 | if flow is None: # means first iter 107 | if flow_90 < 1.0: 108 | # no motion in the frame. 
skip 109 | return None 110 | flow = agg_cur_flow 111 | 112 | if flow_90 <= 300: # maybe should increase this part 113 | # update idx 114 | tgt_idx = cur_idx 115 | flow = agg_cur_flow 116 | else: 117 | break 118 | final_time = time.time() 119 | print('time guessing idx', final_time - init_time) 120 | 121 | _, flow_warping_mask = processing_utils.forward_warp(frames_t[0], frames_t[tgt_idx], flow, grid=None, alpha_mask=None) 122 | flow_warping_mask = flow_warping_mask.squeeze().numpy() > 0.5 123 | 124 | if np.mean(flow_warping_mask) < 0.6: 125 | return 126 | 127 | 128 | src_array = frames_t[0].moveaxis(0, -1).cpu().numpy() * 1.0 129 | init_time = time.time() 130 | depth = get_depth_from_array(frames_t[0]) 131 | finish_time = time.time() 132 | print('time getting depth', finish_time - init_time) 133 | # flow, pairwise_flows = aggregate_frames(frames_t) 134 | # agg_flow = median_filter(flow) 135 | 136 | src_array_uint = src_array * 255.0 137 | src_array_uint = src_array_uint.astype(np.uint8) 138 | segments = processing_utils.mask_generator.generate(src_array_uint) 139 | 140 | size = src_array.shape[1] 141 | grid_np = get_grid(size).astype(np.float16) / size # 512 x 512 x 2get 142 | grid_t = torch.tensor(grid_np).moveaxis(-1, 0) # 512 x 512 x 2 143 | 144 | 145 | collage, canvas_alpha, lost_alpha = collage_warp(src_array, flow.squeeze(), depth, segments, grid_array=grid_np) 146 | lost_alpha_t = torch.tensor(lost_alpha).squeeze().unsqueeze(0) 147 | warping_alpha = (lost_alpha_t < 0.5).float() 148 | 149 | rgb_grid_splatted, actual_warped_mask = processing_utils.forward_warp(frames_t[0], frames_t[tgt_idx], flow, grid=grid_t, alpha_mask=warping_alpha) 150 | 151 | 152 | # basic blending now 153 | # print('rgb grid splatted', rgb_grid_splatted.shape) 154 | warped_src = (rgb_grid_splatted * actual_warped_mask).moveaxis(0, -1).cpu().numpy() 155 | canvas_alpha_mask = canvas_alpha == 0.0 156 | collage_mask = canvas_alpha.squeeze() + actual_warped_mask.squeeze().cpu().numpy() 157 | collage_mask = collage_mask > 0.5 158 | 159 | composite_grid = warped_src * canvas_alpha_mask + collage 160 | rgb_grid_splatted_np = rgb_grid_splatted.moveaxis(0, -1).cpu().numpy() 161 | 162 | return frames_t[0], frames_t[tgt_idx], rgb_grid_splatted_np, composite_grid, flow_warping_mask, collage_mask 163 | 164 | def collage_warp(rgb_array, flow, depth, segments, grid_array): 165 | avg_depths = [] 166 | avg_flows = [] 167 | 168 | # src_array = src_array.moveaxis(-1, 0).cpu().numpy() #np.array(Image.open(src_path).convert('RGB')) / 255.0 169 | src_array = np.concatenate([rgb_array, grid_array], axis=-1) 170 | canvas = np.zeros_like(src_array) 171 | canvas_alpha = np.zeros_like(canvas[...,-1:]).astype(float) 172 | lost_regions = np.zeros_like(canvas[...,-1:]).astype(float) 173 | z_buffer = np.ones_like(depth)[..., None] * -1.0 174 | unsqueezed_depth = depth[..., None] 175 | 176 | affine_transforms = [] 177 | 178 | filtered_segments = [] 179 | for segment in segments: 180 | if segment['area'] > 300: 181 | filtered_segments.append(segment) 182 | 183 | for segment in filtered_segments: 184 | seg_mask = segment['segmentation'] 185 | avg_flow = torch.mean(flow[:, seg_mask],dim=1) 186 | avg_flows.append(avg_flow) 187 | # median depth (conversion from disparity) 188 | avg_depth = torch.median(1.0 / (depth[seg_mask] + 1e-6)) 189 | avg_depths.append(avg_depth) 190 | 191 | all_y, all_x = np.nonzero(segment['segmentation']) 192 | rand_indices = np.random.randint(0, len(all_y), size=50) 193 | rand_x = [all_x[i] for i in rand_indices] 194 | 
rand_y = [all_y[i] for i in rand_indices] 195 | 196 | src_pairs = [(x, y) for x, y in zip(rand_x, rand_y)] 197 | # tgt_pairs = [(x + w, y) for x, y in src_pairs] 198 | tgt_pairs = [] 199 | # print('estimating affine') # TODO this can be faster 200 | for i in range(len(src_pairs)): 201 | x, y = src_pairs[i] 202 | dx, dy = flow[:, y, x] 203 | tgt_pairs.append((x+dx, y+dy)) 204 | 205 | # affine_trans, inliers = cv2.estimateAffine2D(np.array(src_pairs).astype(np.float32), np.array(tgt_pairs).astype(np.float32)) 206 | affine_trans, inliers = cv2.estimateAffinePartial2D(np.array(src_pairs).astype(np.float32), np.array(tgt_pairs).astype(np.float32)) 207 | # print('num inliers', np.sum(inliers)) 208 | # # print('num inliers', np.sum(inliers)) 209 | affine_transforms.append(affine_trans) 210 | 211 | depth_sorted_indices = np.arange(len(avg_depths)) 212 | depth_sorted_indices = sorted(depth_sorted_indices, key=lambda x: avg_depths[x]) 213 | # sorted_masks = [] 214 | # print('warping stuff') 215 | for idx in depth_sorted_indices: 216 | # sorted_masks.append(mask[idx]) 217 | alpha_mask = filtered_segments[idx]['segmentation'][..., None] * (lost_regions < 0.5).astype(float) 218 | src_rgba = np.concatenate([src_array, alpha_mask, unsqueezed_depth], axis=-1) 219 | warp_dst = cv2.warpAffine(src_rgba, affine_transforms[idx], (src_array.shape[1], src_array.shape[0])) 220 | warped_mask = warp_dst[..., -2:-1] # this is warped alpha 221 | warped_depth = warp_dst[..., -1:] 222 | warped_rgb = warp_dst[...,:-2] 223 | 224 | good_z_region = warped_depth > z_buffer 225 | 226 | warped_mask = np.logical_and(warped_mask > 0.5, good_z_region).astype(float) 227 | 228 | kernel = np.ones((3,3), float) 229 | # print('og masked shape', warped_mask.shape) 230 | # warped_mask = cv2.erode(warped_mask,(5,5))[..., None] 231 | # print('eroded masked shape', warped_mask.shape) 232 | canvas_alpha += cv2.erode(warped_mask,kernel)[..., None] 233 | 234 | lost_regions += alpha_mask 235 | canvas = canvas * (1.0 - warped_mask) + warped_mask * warped_rgb # TODO check if need to dialate here 236 | z_buffer = z_buffer * (1.0 - warped_mask) + warped_mask * warped_depth # TODO check if need to dialate here # print('max lost region', np.max(lost_regions)) 237 | return canvas, canvas_alpha, lost_regions 238 | 239 | def get_depth_from_array(img_t): 240 | img_arr = img_t.moveaxis(0, -1).cpu().numpy() * 1.0 241 | # print(img_arr.shape) 242 | img_arr *= 255.0 243 | img_arr = img_arr.astype(np.uint8) 244 | input_batch = processing_utils.depth_transform(img_arr).cuda() 245 | 246 | with torch.no_grad(): 247 | prediction = processing_utils.midas(input_batch) 248 | 249 | prediction = torch.nn.functional.interpolate( 250 | prediction.unsqueeze(1), 251 | size=img_arr.shape[:2], 252 | mode="bicubic", 253 | align_corners=False, 254 | ).squeeze() 255 | 256 | output = prediction.cpu() 257 | return output 258 | 259 | 260 | # %% 261 | 262 | def main(): 263 | print('starting main') 264 | video_folder = './example_videos' 265 | save_dir = pathlib.Path('./processed_data') 266 | process_video_folder(video_folder, save_dir) 267 | 268 | def process_video_folder(video_folder, save_dir): 269 | all_counter = 0 270 | success_counter = 0 271 | 272 | # save_folder = pathlib.Path('/dev/shm/processed') 273 | # save_dir = save_folder / foldername #pathlib.Path('/sensei-fs/users/halzayer/collage2photo/testing_partitioning_dilate_extreme') 274 | os.makedirs(save_dir, exist_ok=True) 275 | 276 | dataset = MomentsDataset(videos_folder=video_folder, num_frames=20, 
samples_per_video=5) 277 | batch_size = 4 278 | dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True) 279 | 280 | with torch.no_grad(): 281 | for i, batch in tqdm.tqdm(enumerate(dataloader), total=len(dataset)//batch_size): 282 | frames_to_visualize = batch["frames"] 283 | bs = frames_to_visualize.shape[0] 284 | 285 | for j in range(bs): 286 | frames = frames_to_visualize[j] 287 | caption = batch["caption"][j] 288 | 289 | collage_init_time = time.time() 290 | out = collage_from_frames(frames) 291 | collage_finish_time = time.time() 292 | print('collage processing time', collage_finish_time - collage_init_time) 293 | all_counter += 1 294 | if out is not None: 295 | src_image, tgt_image, splatted, collage, flow_mask, collage_mask = out 296 | 297 | splatted_rgb = splatted[...,:3] 298 | splatted_grid = splatted[...,3:].astype(np.float16) 299 | 300 | collage_rgb = collage[...,:3] 301 | collage_grid = collage[...,3:].astype(np.float16) 302 | success_counter += 1 303 | else: 304 | continue 305 | 306 | id_str = f'{success_counter:08d}' 307 | 308 | src_path = str(save_dir / f'src_{id_str}.png') 309 | tgt_path = str(save_dir / f'tgt_{id_str}.png') 310 | flow_warped_path = str(save_dir / f'flow_warped_{id_str}.png') 311 | composite_path = str(save_dir / f'composite_{id_str}.png') 312 | flow_mask_path = str(save_dir / f'flow_mask_{id_str}.png') 313 | composite_mask_path = str(save_dir / f'composite_mask_{id_str}.png') 314 | 315 | flow_grid_path = str(save_dir / f'flow_warped_grid_{id_str}.npy') 316 | composite_grid_path = str(save_dir / f'composite_grid_{id_str}.npy') 317 | 318 | save_image(src_image, src_path) 319 | save_image(tgt_image, tgt_path) 320 | 321 | collage_pil = Image.fromarray((collage_rgb * 255).astype(np.uint8)) 322 | collage_pil.save(composite_path) 323 | 324 | splatted_pil = Image.fromarray((splatted_rgb * 255).astype(np.uint8)) 325 | splatted_pil.save(flow_warped_path) 326 | 327 | flow_mask_pil = Image.fromarray((flow_mask.astype(float) * 255).astype(np.uint8)) 328 | flow_mask_pil.save(flow_mask_path) 329 | 330 | composite_mask_pil = Image.fromarray((collage_mask.astype(float) * 255).astype(np.uint8)) 331 | composite_mask_pil.save(composite_mask_path) 332 | 333 | splatted_grid_t = torch.tensor(splatted_grid).moveaxis(-1, 0) 334 | splatted_grid_resized = torchvision.transforms.functional.resize(splatted_grid_t, (64,64)) 335 | 336 | collage_grid_t = torch.tensor(collage_grid).moveaxis(-1, 0) 337 | collage_grid_resized = torchvision.transforms.functional.resize(collage_grid_t, (64,64)) 338 | np.save(flow_grid_path, splatted_grid_resized.cpu().numpy()) 339 | np.save(composite_grid_path, collage_grid_resized.cpu().numpy()) 340 | 341 | 342 | del out 343 | del splatted_grid 344 | del collage_grid 345 | del frames 346 | 347 | del frames_to_visualize 348 | 349 | 350 | 351 | #%% 352 | 353 | if __name__ == '__main__': 354 | try: 355 | main() 356 | except Exception as e: 357 | print(e) 358 | print('process failed') 359 | 360 | -------------------------------------------------------------------------------- /data_processing/processing_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import cv2 3 | import numpy as np 4 | import sys 5 | import torchvision 6 | from PIL import Image 7 | from torchvision.models.optical_flow import Raft_Large_Weights 8 | from torchvision.models.optical_flow import raft_large 9 | from segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor 10 | import 
matplotlib.pyplot as plt 11 | import torchvision.transforms.functional as F 12 | sys.path.append('./softmax-splatting') 13 | import softsplat 14 | 15 | 16 | sam_checkpoint = "./sam_model/sam_vit_h_4b8939.pth" 17 | model_type = "vit_h" 18 | 19 | device = "cuda" 20 | 21 | sam = sam_model_registry[model_type](checkpoint=sam_checkpoint) 22 | sam.to(device=device) 23 | # mask_generator = SamAutomaticMaskGenerator(sam, 24 | # crop_overlap_ratio=0.05, 25 | # box_nms_thresh=0.2, 26 | # points_per_side=32, 27 | # pred_iou_thresh=0.86, 28 | # stability_score_thresh=0.8, 29 | 30 | # min_mask_region_area=100,) 31 | # mask_generator = SamAutomaticMaskGenerator(sam) 32 | mask_generator = SamAutomaticMaskGenerator(sam, 33 | # box_nms_thresh=0.5, 34 | # crop_overlap_ratio=0.75, 35 | # min_mask_region_area=200, 36 | ) 37 | 38 | def get_mask(img_path): 39 | image = cv2.imread(img_path) 40 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) 41 | masks = mask_generator.generate(image) 42 | return masks 43 | 44 | def get_mask_from_array(arr): 45 | return mask_generator.generate(arr) 46 | 47 | # depth model 48 | 49 | import cv2 50 | import torch 51 | import urllib.request 52 | 53 | import matplotlib.pyplot as plt 54 | 55 | # potentially downgrade this. just need rough depths. benchmark this 56 | model_type = "DPT_Large" # MiDaS v3 - Large (highest accuracy, slowest inference speed) 57 | #model_type = "DPT_Hybrid" # MiDaS v3 - Hybrid (medium accuracy, medium inference speed) 58 | #model_type = "MiDaS_small" # MiDaS v2.1 - Small (lowest accuracy, highest inference speed) 59 | 60 | # midas = torch.hub.load("intel-isl/MiDaS", model_type) 61 | midas = torch.hub.load("/sensei-fs/users/halzayer/collage2photo/model_cache/intel-isl_MiDaS_master", model_type, source='local') 62 | 63 | device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") 64 | midas.to(device) 65 | midas.eval() 66 | 67 | # midas_transforms = torch.hub.load("intel-isl/MiDaS", "transforms") 68 | midas_transforms = torch.hub.load("/sensei-fs/users/halzayer/collage2photo/model_cache/intel-isl_MiDaS_master", "transforms", source='local') 69 | 70 | if model_type == "DPT_Large" or model_type == "DPT_Hybrid": 71 | depth_transform = midas_transforms.dpt_transform 72 | else: 73 | depth_transform = midas_transforms.small_transform 74 | 75 | # img_path = '/sensei-fs/users/halzayer/valid/JPEGImages/45597680/00005.jpg' 76 | def get_depth(img_path): 77 | img = cv2.imread(img_path) 78 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) 79 | 80 | input_batch = depth_transform(img).to(device) 81 | 82 | with torch.no_grad(): 83 | prediction = midas(input_batch) 84 | 85 | prediction = torch.nn.functional.interpolate( 86 | prediction.unsqueeze(1), 87 | size=img.shape[:2], 88 | mode="bicubic", 89 | align_corners=False, 90 | ).squeeze() 91 | 92 | output = prediction.cpu() 93 | return output 94 | 95 | def get_depth_from_array(img): 96 | input_batch = depth_transform(img).to(device) 97 | 98 | with torch.no_grad(): 99 | prediction = midas(input_batch) 100 | 101 | prediction = torch.nn.functional.interpolate( 102 | prediction.unsqueeze(1), 103 | size=img.shape[:2], 104 | mode="bicubic", 105 | align_corners=False, 106 | ).squeeze() 107 | 108 | output = prediction.cpu() 109 | return output 110 | 111 | 112 | def load_image(img_path): 113 | img1_names = [img_path] 114 | 115 | img1_pil = [Image.open(fn) for fn in img1_names] 116 | img1_frames = [torchvision.transforms.functional.pil_to_tensor(fn) for fn in img1_pil] 117 | 118 | img1_batch = torch.stack(img1_frames) 
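# Descriptive note: pil_to_tensor keeps each image as a uint8 tensor of shape (C, H, W),
# so the stacked batch returned here has shape (1, C, H, W). The RAFT weights'
# transforms(), applied later in compute_flow, handle the conversion to float and the
# normalization the model expects.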
119 | 120 | return img1_batch 121 | 122 | weights = Raft_Large_Weights.DEFAULT 123 | transforms = weights.transforms() 124 | 125 | device = "cuda" if torch.cuda.is_available() else "cpu" 126 | 127 | model = raft_large(weights=Raft_Large_Weights.DEFAULT, progress=False).to(device) 128 | model = model.eval() 129 | 130 | print('created model') 131 | 132 | def preprocess(img1_batch, img2_batch, size=[520,960], transform_batch=True): 133 | img1_batch = F.resize(img1_batch, size=size, antialias=False) 134 | img2_batch = F.resize(img2_batch, size=size, antialias=False) 135 | if transform_batch: 136 | return transforms(img1_batch, img2_batch) 137 | else: 138 | return img1_batch, img2_batch 139 | 140 | def compute_flow(img_path_1, img_path_2): 141 | img1_batch_og, img2_batch_og = load_image(img_path_1), load_image(img_path_2) 142 | B, C, H, W = img1_batch_og.shape 143 | 144 | img1_batch, img2_batch = preprocess(img1_batch_og, img2_batch_og, transform_batch=False) 145 | img1_batch_t, img2_batch_t = transforms(img1_batch, img2_batch) 146 | 147 | # If you can, run this example on a GPU, it will be a lot faster. 148 | with torch.no_grad(): 149 | list_of_flows = model(img1_batch_t.to(device), img2_batch_t.to(device)) 150 | predicted_flows = list_of_flows[-1] 151 | # flows.append(predicted_flows) 152 | 153 | resized_flow = F.resize(predicted_flows, size=(H, W), antialias=False) 154 | 155 | _, _, flow_H, flow_W = predicted_flows.shape 156 | 157 | resized_flow[:,0] *= (W / flow_W) 158 | resized_flow[:,1] *= (H / flow_H) 159 | 160 | return resized_flow.detach().cpu().squeeze() 161 | 162 | def compute_flow_from_tensors(img1_batch_og, img2_batch_og): 163 | if len(img1_batch_og.shape) < 4: 164 | img1_batch_og = img1_batch_og.unsqueeze(0) 165 | if len(img2_batch_og.shape) < 4: 166 | img2_batch_og = img2_batch_og.unsqueeze(0) 167 | 168 | B, C, H, W = img1_batch_og.shape 169 | img1_batch, img2_batch = preprocess(img1_batch_og, img2_batch_og, transform_batch=False) 170 | img1_batch_t, img2_batch_t = transforms(img1_batch, img2_batch) 171 | 172 | # If you can, run this example on a GPU, it will be a lot faster. 
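# RAFT returns a list of iteratively refined flow fields; the last entry is the final
# prediction. Because the network runs on inputs resized by preprocess() (default
# (520, 960)), the flow is resized back to (H, W) below and its u/v channels are
# rescaled by W/flow_W and H/flow_H so displacements remain in original-resolution pixels.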
173 | with torch.no_grad(): 174 | list_of_flows = model(img1_batch_t.to(device), img2_batch_t.to(device)) 175 | predicted_flows = list_of_flows[-1] 176 | # flows.append(predicted_flows) 177 | 178 | resized_flow = F.resize(predicted_flows, size=(H, W), antialias=False) 179 | 180 | _, _, flow_H, flow_W = predicted_flows.shape 181 | 182 | resized_flow[:,0] *= (W / flow_W) 183 | resized_flow[:,1] *= (H / flow_H) 184 | 185 | return resized_flow.detach().cpu().squeeze() 186 | 187 | 188 | 189 | # import run 190 | backwarp_tenGrid = {} 191 | 192 | def backwarp(tenIn, tenFlow): 193 | if str(tenFlow.shape) not in backwarp_tenGrid: 194 | tenHor = torch.linspace(start=-1.0, end=1.0, steps=tenFlow.shape[3], dtype=tenFlow.dtype, device=tenFlow.device).view(1, 1, 1, -1).repeat(1, 1, tenFlow.shape[2], 1) 195 | tenVer = torch.linspace(start=-1.0, end=1.0, steps=tenFlow.shape[2], dtype=tenFlow.dtype, device=tenFlow.device).view(1, 1, -1, 1).repeat(1, 1, 1, tenFlow.shape[3]) 196 | 197 | backwarp_tenGrid[str(tenFlow.shape)] = torch.cat([tenHor, tenVer], 1).cuda() 198 | # end 199 | 200 | tenFlow = torch.cat([tenFlow[:, 0:1, :, :] / ((tenIn.shape[3] - 1.0) / 2.0), tenFlow[:, 1:2, :, :] / ((tenIn.shape[2] - 1.0) / 2.0)], 1) 201 | 202 | return torch.nn.functional.grid_sample(input=tenIn, grid=(backwarp_tenGrid[str(tenFlow.shape)] + tenFlow).permute(0, 2, 3, 1), mode='bilinear', padding_mode='zeros', align_corners=True) 203 | 204 | torch.backends.cudnn.enabled = True # make sure to use cudnn for computational performance 205 | 206 | ########################################################## 207 | def forward_splt(src, tgt, flow, partial=False): 208 | tenTwo = tgt.unsqueeze(0).cuda() #torch.FloatTensor(numpy.ascontiguousarray(cv2.imread(filename='./images/one.png', flags=-1).transpose(2, 0, 1)[None, :, :, :].astype(numpy.float32) * (1.0 / 255.0))).cuda() 209 | tenOne = src.unsqueeze(0).cuda() #torch.FloatTensor(numpy.ascontiguousarray(cv2.imread(filename='./images/two.png', flags=-1).transpose(2, 0, 1)[None, :, :, :].astype(numpy.float32) * (1.0 / 255.0))).cuda() 210 | tenFlow = flow.unsqueeze(0).cuda() #torch.FloatTensor(numpy.ascontiguousarray(run.read_flo('./images/flow.flo').transpose(2, 0, 1)[None, :, :, :])).cuda() 211 | 212 | if not partial: 213 | tenMetric = torch.nn.functional.l1_loss(input=tenOne, target=backwarp(tenIn=tenTwo, tenFlow=tenFlow), reduction='none').mean([1], True) 214 | else: 215 | tenMetric = torch.nn.functional.l1_loss(input=tenOne[:,:3], target=backwarp(tenIn=tenTwo[:,:3], tenFlow=tenFlow[:,:3]), reduction='none').mean([1], True) 216 | # for intTime, fltTime in enumerate(np.linspace(0.0, 1.0, 11).tolist()): 217 | tenSoftmax = softsplat.softsplat(tenIn=tenOne, tenFlow=tenFlow , tenMetric=(-20.0 * tenMetric).clip(-20.0, 20.0), strMode='soft') # -20.0 is a hyperparameter, called 'alpha' in the paper, that could be learned using a torch.Parameter 218 | 219 | return tenSoftmax.cpu() 220 | 221 | 222 | def aggregate_frames(frames, pairwise_flows=None, agg_flow=None): 223 | if pairwise_flows is None: 224 | # store pairwise flows 225 | pairwise_flows = [] 226 | 227 | if agg_flow is None: 228 | start_idx = 0 229 | else: 230 | start_idx = len(pairwise_flows) 231 | 232 | og_image = frames[start_idx] 233 | prev_frame = og_image 234 | 235 | for i in range(start_idx, len(frames)-1): 236 | tgt_frame = frames[i+1] 237 | 238 | if i < len(pairwise_flows): 239 | flow = pairwise_flows[i] 240 | else: 241 | flow = compute_flow_from_tensors(prev_frame, tgt_frame) 242 | pairwise_flows.append(flow.clone()) 243 | 
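# To chain the new pairwise flow onto the flow aggregated so far: build a pixel grid,
# offset it by agg_flow, normalize it to [-1, 1], and use grid_sample to look up the
# pairwise flow at the locations each source pixel has already been displaced to.
# Adding that resampled flow to agg_flow gives the composed flow from the start frame
# to frame i+1.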
244 | _, H, W = flow.shape 245 | B=1 246 | 247 | xx = torch.arange(0, W).view(1,-1).repeat(H,1) 248 | 249 | yy = torch.arange(0, H).view(-1,1).repeat(1,W) 250 | 251 | xx = xx.view(1,1,H,W).repeat(B,1,1,1) 252 | 253 | yy = yy.view(1,1,H,W).repeat(B,1,1,1) 254 | 255 | grid = torch.cat((xx,yy),1).float() 256 | 257 | flow = flow.unsqueeze(0) 258 | if agg_flow is None: 259 | agg_flow = torch.zeros_like(flow) 260 | 261 | vgrid = grid + agg_flow 262 | vgrid[:,0,:,:] = 2.0*vgrid[:,0,:,:].clone() / max(W-1,1) - 1 263 | 264 | vgrid[:,1,:,:] = 2.0*vgrid[:,1,:,:].clone() / max(H-1,1) - 1 265 | 266 | flow_out = torch.nn.functional.grid_sample(flow, vgrid.permute(0,2,3,1), 'nearest') 267 | 268 | agg_flow += flow_out 269 | 270 | 271 | # mask = forward_splt(torch.ones_like(og_image), torch.ones_like(og_image), agg_flow.squeeze()).squeeze() 272 | # blur_t = torchvision.transforms.GaussianBlur(kernel_size=(25,25), sigma=5.0) 273 | # warping_mask = (blur_t(mask)[0:1] > 0.8) 274 | # masks.append(warping_mask) 275 | prev_frame = tgt_frame 276 | 277 | return agg_flow, pairwise_flows #og_splatted_img, agg_flow, actual_warped_mask 278 | 279 | 280 | def forward_warp(src_frame, tgt_frame, flow, grid=None, alpha_mask=None): 281 | if alpha_mask is None: 282 | alpha_mask = torch.ones_like(src_frame[:1]) 283 | 284 | if grid is not None: 285 | src_list = [src_frame, grid, alpha_mask] 286 | tgt_list = [tgt_frame, grid, alpha_mask] 287 | else: 288 | src_list = [src_frame, alpha_mask] 289 | tgt_list = [tgt_frame, alpha_mask] 290 | 291 | og_image_padded = torch.concat(src_list, dim=0) 292 | tgt_frame_padded = torch.concat(tgt_list, dim=0) 293 | 294 | og_splatted_img = forward_splt(og_image_padded, tgt_frame_padded, flow.squeeze(), partial=True).squeeze() 295 | # print('og splatted image shape') 296 | # grid_transformed = og_splatted_img[3:-1] 297 | # print('grid transformed shape', grid_transformed) 298 | 299 | # grid *= grid_size 300 | # grid_transformed *= grid_size 301 | actual_warped_mask = og_splatted_img[-1:] 302 | splatted_rgb_grid = og_splatted_img[:-1] 303 | 304 | return splatted_rgb_grid, actual_warped_mask -------------------------------------------------------------------------------- /environment.yaml: -------------------------------------------------------------------------------- 1 | name: MagicFixup 2 | channels: 3 | - pytorch 4 | - defaults 5 | dependencies: 6 | - python=3.8.5 7 | - pip=20.3 8 | - cudatoolkit=11.3 9 | - pytorch=1.11.0 10 | - torchvision=0.12.0 11 | - numpy=1.19.2 12 | - pip: 13 | - albumentations==0.4.3 14 | - diffusers 15 | - bezier 16 | - gradio 17 | - opencv-python==4.1.2.30 18 | - pudb==2019.2 19 | - invisible-watermark 20 | - imageio==2.9.0 21 | - imageio-ffmpeg==0.4.2 22 | - pytorch-lightning==2.0.0 23 | - omegaconf==2.1.1 24 | - test-tube>=0.7.5 25 | - streamlit>=0.73.1 26 | - einops==0.3.0 27 | - torch-fidelity==0.3.0 28 | - transformers==4.19.2 29 | - torchmetrics==0.7.0 30 | - kornia==0.6 31 | - -e git+https://github.com/CompVis/taming-transformers.git@master#egg=taming-transformers 32 | - -e git+https://github.com/openai/CLIP.git@main#egg=clip 33 | - -e . 
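# Typical setup for this environment (assuming conda is available):
#   conda env create -f environment.yaml
#   conda activate MagicFixup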
34 | -------------------------------------------------------------------------------- /examples/dog_beach__edit__003.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/adobe-research/MagicFixup/47508c136b0f2427256b5fe7f13a9bbbabc07361/examples/dog_beach__edit__003.png -------------------------------------------------------------------------------- /examples/dog_beach_og.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/adobe-research/MagicFixup/47508c136b0f2427256b5fe7f13a9bbbabc07361/examples/dog_beach_og.png -------------------------------------------------------------------------------- /examples/fox_drinking__edit__01.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/adobe-research/MagicFixup/47508c136b0f2427256b5fe7f13a9bbbabc07361/examples/fox_drinking__edit__01.png -------------------------------------------------------------------------------- /examples/fox_drinking__edit__02.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/adobe-research/MagicFixup/47508c136b0f2427256b5fe7f13a9bbbabc07361/examples/fox_drinking__edit__02.png -------------------------------------------------------------------------------- /examples/fox_drinking_og.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/adobe-research/MagicFixup/47508c136b0f2427256b5fe7f13a9bbbabc07361/examples/fox_drinking_og.png -------------------------------------------------------------------------------- /examples/kingfisher__edit__001.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/adobe-research/MagicFixup/47508c136b0f2427256b5fe7f13a9bbbabc07361/examples/kingfisher__edit__001.png -------------------------------------------------------------------------------- /examples/kingfisher_og.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/adobe-research/MagicFixup/47508c136b0f2427256b5fe7f13a9bbbabc07361/examples/kingfisher_og.png -------------------------------------------------------------------------------- /examples/log.csv: -------------------------------------------------------------------------------- 1 | fox_drinking_og.png,fox_drinking__edit__01.png 2 | palm_tree_og.png,palm_tree__edit__01.png 3 | kingfisher_og.png,kingfisher__edit__001.png 4 | pipes_og.png,pipes__edit__01.png 5 | dog_beach_og.png,dog_beach__edit__003.png 6 | fox_drinking_og.png,fox_drinking__edit__02.png -------------------------------------------------------------------------------- /examples/palm_tree__edit__01.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/adobe-research/MagicFixup/47508c136b0f2427256b5fe7f13a9bbbabc07361/examples/palm_tree__edit__01.png -------------------------------------------------------------------------------- /examples/palm_tree_og.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/adobe-research/MagicFixup/47508c136b0f2427256b5fe7f13a9bbbabc07361/examples/palm_tree_og.png -------------------------------------------------------------------------------- /examples/pipes__edit__01.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/adobe-research/MagicFixup/47508c136b0f2427256b5fe7f13a9bbbabc07361/examples/pipes__edit__01.png -------------------------------------------------------------------------------- /examples/pipes_og.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/adobe-research/MagicFixup/47508c136b0f2427256b5fe7f13a9bbbabc07361/examples/pipes_og.png -------------------------------------------------------------------------------- /ldm/__pycache__/util.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/adobe-research/MagicFixup/47508c136b0f2427256b5fe7f13a9bbbabc07361/ldm/__pycache__/util.cpython-38.pyc -------------------------------------------------------------------------------- /ldm/data/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/adobe-research/MagicFixup/47508c136b0f2427256b5fe7f13a9bbbabc07361/ldm/data/__init__.py -------------------------------------------------------------------------------- /ldm/data/collage_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Adobe. All rights reserved. 2 | 3 | import numpy as np 4 | import torch 5 | import matplotlib.pyplot as plt 6 | import torchvision.transforms.functional as F 7 | import glob 8 | import torchvision 9 | from PIL import Image 10 | import time 11 | import os 12 | import tqdm 13 | from torch.utils.data import Dataset 14 | import pathlib 15 | import cv2 16 | from PIL import Image 17 | import os 18 | import json 19 | import albumentations as A 20 | 21 | def get_tensor(normalize=True, toTensor=True): 22 | transform_list = [] 23 | if toTensor: 24 | transform_list += [torchvision.transforms.ToTensor()] 25 | 26 | if normalize: 27 | # transform_list += [torchvision.transforms.Normalize((0.0, 0.0, 0.0), 28 | # (10.0, 10.0, 10.0))] 29 | transform_list += [torchvision.transforms.Normalize((0.5, 0.5, 0.5), 30 | (0.5, 0.5, 0.5))] 31 | return torchvision.transforms.Compose(transform_list) 32 | 33 | def get_tensor_clip(normalize=True, toTensor=True): 34 | transform_list = [torchvision.transforms.Resize((224,224))] 35 | if toTensor: 36 | transform_list += [torchvision.transforms.ToTensor()] 37 | 38 | if normalize: 39 | transform_list += [torchvision.transforms.Normalize((0.48145466, 0.4578275, 0.40821073), 40 | (0.26862954, 0.26130258, 0.27577711))] 41 | return torchvision.transforms.Compose(transform_list) 42 | 43 | def get_tensor_dino(normalize=True, toTensor=True): 44 | transform_list = [torchvision.transforms.Resize((224,224))] 45 | if toTensor: 46 | transform_list += [torchvision.transforms.ToTensor()] 47 | 48 | if normalize: 49 | transform_list += [lambda x: 255.0 * x[:3], 50 | torchvision.transforms.Normalize( 51 | mean=(123.675, 116.28, 103.53), 52 | std=(58.395, 57.12, 57.375), 53 | )] 54 | return torchvision.transforms.Compose(transform_list) 55 | 56 | def crawl_folders(folder_path): 57 | # glob crawl 58 | all_files = [] 59 | folders = glob.glob(f'{folder_path}/*') 60 | 61 | for folder in folders: 62 | src_paths = glob.glob(f'{folder}/src_*png') 63 | all_files.extend(src_paths) 64 | return all_files 65 | 66 | def get_grid(size): 67 | y = np.repeat(np.arange(size)[None, ...], size) 68 | y = y.reshape(size, size) 69 | x = y.transpose() 70 | out = 
np.stack([y,x], -1) 71 | return out 72 | 73 | 74 | class CollageDataset(Dataset): 75 | def __init__(self, split_files, image_size, embedding_type, warping_type, blur_warped=False): 76 | self.size = image_size 77 | # depends on the embedding type 78 | if embedding_type == 'clip': 79 | self.get_embedding_vector = get_tensor_clip() 80 | elif embedding_type == 'dino': 81 | self.get_embedding_vector = get_tensor_dino() 82 | self.get_tensor = get_tensor() 83 | self.resize = torchvision.transforms.Resize(size=(image_size, image_size)) 84 | self.to_mask_tensor = get_tensor(normalize=False) 85 | 86 | self.src_paths = crawl_folders(split_files) 87 | print('current split size', len(self.src_paths)) 88 | print('for dir', split_files) 89 | 90 | assert warping_type in ['collage', 'flow', 'mix'] 91 | self.warping_type = warping_type 92 | 93 | self.mask_threshold = 0.85 94 | 95 | self.blur_t = torchvision.transforms.GaussianBlur(kernel_size=51, sigma=20.0) 96 | self.blur_warped = blur_warped 97 | 98 | # self.save_folder = '/mnt/localssd/collage_out' 99 | # os.makedirs(self.save_folder, exist_ok=True) 100 | self.save_counter = 0 101 | self.save_subfolder = None 102 | 103 | def __len__(self): 104 | return len(self.src_paths) 105 | 106 | 107 | def __getitem__(self, idx, depth=0): 108 | 109 | if self.warping_type == 'mix': 110 | # randomly sample 111 | warping_type = np.random.choice(['collage', 'flow']) 112 | else: 113 | warping_type = self.warping_type 114 | 115 | src_path = self.src_paths[idx] 116 | tgt_path = src_path.replace('src_', 'tgt_') 117 | 118 | if warping_type == 'collage': 119 | warped_path = src_path.replace('src_', 'composite_') 120 | mask_path = src_path.replace('src_', 'composite_mask_') 121 | corresp_path = src_path.replace('src_', 'composite_grid_') 122 | corresp_path = corresp_path.split('.')[0] 123 | corresp_path += '.npy' 124 | elif warping_type == 'flow': 125 | warped_path = src_path.replace('src_', 'flow_warped_') 126 | mask_path = src_path.replace('src_', 'flow_mask_') 127 | corresp_path = src_path.replace('src_', 'flow_warped_grid_') 128 | corresp_path = corresp_path.split('.')[0] 129 | corresp_path += '.npy' 130 | else: 131 | raise ValueError 132 | 133 | # load reference image, warped image, and target GT image 134 | reference_img = Image.open(src_path).convert('RGB') 135 | gt_img = Image.open(tgt_path).convert('RGB') 136 | warped_img = Image.open(warped_path).convert('RGB') 137 | warping_mask = Image.open(mask_path).convert('RGB') 138 | 139 | # resize all 140 | reference_img = self.resize(reference_img) 141 | gt_img = self.resize(gt_img) 142 | warped_img = self.resize(warped_img) 143 | warping_mask = self.resize(warping_mask) 144 | 145 | 146 | # NO CROPPING PLEASE. 
ALL INPUTS ARE 512X512 147 | # Random crop 148 | # i, j, h, w = torchvision.transforms.RandomCrop.get_params( 149 | # reference_img, output_size=(512, 512)) 150 | 151 | # reference_img = torchvision.transforms.functional.crop(reference_img, i, j, h, w) 152 | # gt_img = torchvision.transforms.functional.crop(gt_img, i, j, h, w) 153 | # warped_img = torchvision.transforms.functional.crop(warped_img, i, j, h, w) 154 | # # TODO start using the warping mask 155 | # warping_mask = torchvision.transforms.functional.crop(warping_mask, i, j, h, w) 156 | 157 | grid_transformed = torch.tensor(np.load(corresp_path)) 158 | # grid_transformed = torchvision.transforms.functional.crop(grid_transformed, i, j, h, w) 159 | 160 | 161 | 162 | # reference_t = to_tensor(reference_img) 163 | gt_t = self.get_tensor(gt_img) 164 | warped_t = self.get_tensor(warped_img) 165 | warping_mask_t = self.to_mask_tensor(warping_mask) 166 | clean_reference_t = self.get_tensor(reference_img) 167 | # compute error to generate mask 168 | blur_t = torchvision.transforms.GaussianBlur(kernel_size=(11,11), sigma=5.0) 169 | 170 | reference_clip_img = self.get_embedding_vector(reference_img) 171 | 172 | mask = torch.ones_like(gt_t)[:1] 173 | warping_mask_t = warping_mask_t[:1] 174 | 175 | good_region = torch.mean(warping_mask_t) 176 | # print('good region', good_region) 177 | # print('good region frac', good_region) 178 | if good_region < 0.4 and depth < 3: 179 | # example too hard, sample something else 180 | # print('bad image, resampling..') 181 | rand_idx = np.random.randint(len(self.src_paths)) 182 | return self.__getitem__(rand_idx, depth+1) 183 | 184 | # if mask is too large then ignore 185 | 186 | # #gaussian inpainting now 187 | missing_mask = warping_mask_t[0] < 0.5 188 | 189 | 190 | reference = (warped_t.clone() + 1) / 2.0 191 | ref_cv = torch.moveaxis(reference, 0, -1).cpu().numpy() 192 | ref_cv = (ref_cv * 255).astype(np.uint8) 193 | cv_mask = missing_mask.int().squeeze().cpu().numpy().astype(np.uint8) 194 | kernel = np.ones((7,7)) 195 | dilated_mask = cv2.dilate(cv_mask, kernel) 196 | # cv_mask = np.stack([cv_mask]*3, axis=-1) 197 | dst = cv2.inpaint(ref_cv,dilated_mask,5,cv2.INPAINT_NS) 198 | 199 | mask_resized = torchvision.transforms.functional.resize(warping_mask_t, (64,64)) 200 | # print(mask_resized) 201 | size=512 202 | grid_np = (get_grid(size) / size).astype(np.float16)# 512 x 512 x 2 203 | grid_t = torch.tensor(grid_np).moveaxis(-1, 0) # 512 x 512 x 2 204 | grid_resized = torchvision.transforms.functional.resize(grid_t, (64,64)).to(torch.float16) 205 | changed_pixels = torch.logical_or((torch.abs(grid_resized - grid_transformed)[0] > 0.04) , (torch.abs(grid_resized - grid_transformed)[1] > 0.04)) 206 | changed_pixels = changed_pixels.unsqueeze(0) 207 | # changed_pixels = torch.logical_and(changed_pixels, (mask_resized >= 0.3)) 208 | changed_pixels = changed_pixels.float() 209 | 210 | inpainted_warped = (torch.tensor(dst).moveaxis(-1, 0).float() / 255.0) * 2.0 - 1.0 211 | 212 | if self.blur_warped: 213 | inpainted_warped= self.blur_t(inpainted_warped) 214 | 215 | out = {"GT": gt_t,"inpaint_image": inpainted_warped,"inpaint_mask": warping_mask_t, "ref_imgs": reference_clip_img, "clean_reference": clean_reference_t, 'grid_transformed': grid_transformed, "changed_pixels": changed_pixels} 216 | # out = {"GT": gt_t,"inpaint_image": inpainted_warped * 0.0,"inpaint_mask": torch.ones_like(warping_mask_t), "ref_imgs": reference_clip_img * 0.0, "clean_reference": gt_t, 'grid_transformed': grid_transformed, 
"changed_pixels": changed_pixels} 217 | # out = {"GT": gt_t,"inpaint_image": inpainted_warped * 0.0,"inpaint_mask": warping_mask_t, "ref_imgs": reference_clip_img * 0.0, "clean_reference": clean_reference_t, 'grid_transformed': grid_transformed, "changed_pixels": changed_pixels} 218 | 219 | # out = {"GT": gt_t,"inpaint_image": warped_t,"inpaint_mask": warping_mask_t, "ref_imgs": reference_clip_img, "clean_reference": clean_reference_t, 'grid_transformed': grid_transformed, 'inpainted': inpainted_warped} 220 | # out_half = {key: out[key].half() for key in out} 221 | # if self.save_counter < 50: 222 | # save_path = f'{self.save_folder}/output_{time.time()}.pt' 223 | # torch.save(out, save_path) 224 | # self.save_counter += 1 225 | 226 | return out 227 | 228 | 229 | 230 | 231 | -------------------------------------------------------------------------------- /ldm/lr_scheduler.py: -------------------------------------------------------------------------------- 1 | # This code is built from the Stable Diffusion repository: https://github.com/CompVis/stable-diffusion, and 2 | # Paint-by-Example repo https://github.com/Fantasy-Studio/Paint-by-Example 3 | # Copyright (c) 2022 Robin Rombach and Patrick Esser and contributors. 4 | # CreativeML Open RAIL-M 5 | # 6 | # ========================================================================================== 7 | # 8 | # Adobe’s modifications are Copyright 2024 Adobe Research. All rights reserved. 9 | # Adobe’s modifications are licensed under the Adobe Research License. To view a copy of the license, visit 10 | # LICENSE.md. 11 | # 12 | # ========================================================================================== 13 | 14 | import numpy as np 15 | 16 | 17 | class LambdaWarmUpCosineScheduler: 18 | """ 19 | note: use with a base_lr of 1.0 20 | """ 21 | def __init__(self, warm_up_steps, lr_min, lr_max, lr_start, max_decay_steps, verbosity_interval=0): 22 | self.lr_warm_up_steps = warm_up_steps 23 | self.lr_start = lr_start 24 | self.lr_min = lr_min 25 | self.lr_max = lr_max 26 | self.lr_max_decay_steps = max_decay_steps 27 | self.last_lr = 0. 28 | self.verbosity_interval = verbosity_interval 29 | 30 | def schedule(self, n, **kwargs): 31 | if self.verbosity_interval > 0: 32 | if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_lr}") 33 | if n < self.lr_warm_up_steps: 34 | lr = (self.lr_max - self.lr_start) / self.lr_warm_up_steps * n + self.lr_start 35 | self.last_lr = lr 36 | return lr 37 | else: 38 | t = (n - self.lr_warm_up_steps) / (self.lr_max_decay_steps - self.lr_warm_up_steps) 39 | t = min(t, 1.0) 40 | lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * ( 41 | 1 + np.cos(t * np.pi)) 42 | self.last_lr = lr 43 | return lr 44 | 45 | def __call__(self, n, **kwargs): 46 | return self.schedule(n,**kwargs) 47 | 48 | 49 | class LambdaWarmUpCosineScheduler2: 50 | """ 51 | supports repeated iterations, configurable via lists 52 | note: use with a base_lr of 1.0. 53 | """ 54 | def __init__(self, warm_up_steps, f_min, f_max, f_start, cycle_lengths, verbosity_interval=0): 55 | assert len(warm_up_steps) == len(f_min) == len(f_max) == len(f_start) == len(cycle_lengths) 56 | self.lr_warm_up_steps = warm_up_steps 57 | self.f_start = f_start 58 | self.f_min = f_min 59 | self.f_max = f_max 60 | self.cycle_lengths = cycle_lengths 61 | self.cum_cycles = np.cumsum([0] + list(self.cycle_lengths)) 62 | self.last_f = 0. 
63 | self.verbosity_interval = verbosity_interval 64 | 65 | def find_in_interval(self, n): 66 | interval = 0 67 | for cl in self.cum_cycles[1:]: 68 | if n <= cl: 69 | return interval 70 | interval += 1 71 | 72 | def schedule(self, n, **kwargs): 73 | cycle = self.find_in_interval(n) 74 | n = n - self.cum_cycles[cycle] 75 | if self.verbosity_interval > 0: 76 | if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, " 77 | f"current cycle {cycle}") 78 | if n < self.lr_warm_up_steps[cycle]: 79 | f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle] 80 | self.last_f = f 81 | return f 82 | else: 83 | t = (n - self.lr_warm_up_steps[cycle]) / (self.cycle_lengths[cycle] - self.lr_warm_up_steps[cycle]) 84 | t = min(t, 1.0) 85 | f = self.f_min[cycle] + 0.5 * (self.f_max[cycle] - self.f_min[cycle]) * ( 86 | 1 + np.cos(t * np.pi)) 87 | self.last_f = f 88 | return f 89 | 90 | def __call__(self, n, **kwargs): 91 | return self.schedule(n, **kwargs) 92 | 93 | 94 | class LambdaLinearScheduler(LambdaWarmUpCosineScheduler2): 95 | 96 | def schedule(self, n, **kwargs): 97 | cycle = self.find_in_interval(n) 98 | n = n - self.cum_cycles[cycle] 99 | if self.verbosity_interval > 0: 100 | if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, " 101 | f"current cycle {cycle}") 102 | 103 | if n < self.lr_warm_up_steps[cycle]: 104 | f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle] 105 | self.last_f = f 106 | return f 107 | else: 108 | f = self.f_min[cycle] + (self.f_max[cycle] - self.f_min[cycle]) * (self.cycle_lengths[cycle] - n) / (self.cycle_lengths[cycle]) 109 | self.last_f = f 110 | return f 111 | 112 | -------------------------------------------------------------------------------- /ldm/models/__pycache__/autoencoder.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/adobe-research/MagicFixup/47508c136b0f2427256b5fe7f13a9bbbabc07361/ldm/models/__pycache__/autoencoder.cpython-38.pyc -------------------------------------------------------------------------------- /ldm/models/diffusion/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/adobe-research/MagicFixup/47508c136b0f2427256b5fe7f13a9bbbabc07361/ldm/models/diffusion/__init__.py -------------------------------------------------------------------------------- /ldm/models/diffusion/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/adobe-research/MagicFixup/47508c136b0f2427256b5fe7f13a9bbbabc07361/ldm/models/diffusion/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /ldm/models/diffusion/__pycache__/ddim.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/adobe-research/MagicFixup/47508c136b0f2427256b5fe7f13a9bbbabc07361/ldm/models/diffusion/__pycache__/ddim.cpython-38.pyc -------------------------------------------------------------------------------- /ldm/models/diffusion/__pycache__/ddpm.cpython-38.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/adobe-research/MagicFixup/47508c136b0f2427256b5fe7f13a9bbbabc07361/ldm/models/diffusion/__pycache__/ddpm.cpython-38.pyc -------------------------------------------------------------------------------- /ldm/models/diffusion/classifier.py: -------------------------------------------------------------------------------- 1 | # This code is built from the Stable Diffusion repository: https://github.com/CompVis/stable-diffusion, and 2 | # Paint-by-Example repo https://github.com/Fantasy-Studio/Paint-by-Example 3 | # Copyright (c) 2022 Robin Rombach and Patrick Esser and contributors. 4 | # CreativeML Open RAIL-M 5 | # 6 | # ========================================================================================== 7 | # 8 | # Adobe’s modifications are Copyright 2024 Adobe Research. All rights reserved. 9 | # Adobe’s modifications are licensed under the Adobe Research License. To view a copy of the license, visit 10 | # LICENSE.md. 11 | # 12 | # ========================================================================================== 13 | 14 | import os 15 | import torch 16 | import pytorch_lightning as pl 17 | from omegaconf import OmegaConf 18 | from torch.nn import functional as F 19 | from torch.optim import AdamW 20 | from torch.optim.lr_scheduler import LambdaLR 21 | from copy import deepcopy 22 | from einops import rearrange 23 | from glob import glob 24 | from natsort import natsorted 25 | 26 | from ldm.modules.diffusionmodules.openaimodel import EncoderUNetModel, UNetModel 27 | from ldm.util import log_txt_as_img, default, ismap, instantiate_from_config 28 | 29 | __models__ = { 30 | 'class_label': EncoderUNetModel, 31 | 'segmentation': UNetModel 32 | } 33 | 34 | 35 | def disabled_train(self, mode=True): 36 | """Overwrite model.train with this function to make sure train/eval mode 37 | does not change anymore.""" 38 | return self 39 | 40 | 41 | class NoisyLatentImageClassifier(pl.LightningModule): 42 | 43 | def __init__(self, 44 | diffusion_path, 45 | num_classes, 46 | ckpt_path=None, 47 | pool='attention', 48 | label_key=None, 49 | diffusion_ckpt_path=None, 50 | scheduler_config=None, 51 | weight_decay=1.e-2, 52 | log_steps=10, 53 | monitor='val/loss', 54 | *args, 55 | **kwargs): 56 | super().__init__(*args, **kwargs) 57 | self.num_classes = num_classes 58 | # get latest config of diffusion model 59 | diffusion_config = natsorted(glob(os.path.join(diffusion_path, 'configs', '*-project.yaml')))[-1] 60 | self.diffusion_config = OmegaConf.load(diffusion_config).model 61 | self.diffusion_config.params.ckpt_path = diffusion_ckpt_path 62 | self.load_diffusion() 63 | 64 | self.monitor = monitor 65 | self.numd = self.diffusion_model.first_stage_model.encoder.num_resolutions - 1 66 | self.log_time_interval = self.diffusion_model.num_timesteps // log_steps 67 | self.log_steps = log_steps 68 | 69 | self.label_key = label_key if not hasattr(self.diffusion_model, 'cond_stage_key') \ 70 | else self.diffusion_model.cond_stage_key 71 | 72 | assert self.label_key is not None, 'label_key neither in diffusion model nor in model.params' 73 | 74 | if self.label_key not in __models__: 75 | raise NotImplementedError() 76 | 77 | self.load_classifier(ckpt_path, pool) 78 | 79 | self.scheduler_config = scheduler_config 80 | self.use_scheduler = self.scheduler_config is not None 81 | self.weight_decay = weight_decay 82 | 83 | def init_from_ckpt(self, path, ignore_keys=list(), only_model=False): 84 | sd = torch.load(path, map_location="cpu") 85 | if "state_dict" in list(sd.keys()): 
86 | sd = sd["state_dict"] 87 | keys = list(sd.keys()) 88 | for k in keys: 89 | for ik in ignore_keys: 90 | if k.startswith(ik): 91 | print("Deleting key {} from state_dict.".format(k)) 92 | del sd[k] 93 | missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict( 94 | sd, strict=False) 95 | print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys") 96 | if len(missing) > 0: 97 | print(f"Missing Keys: {missing}") 98 | if len(unexpected) > 0: 99 | print(f"Unexpected Keys: {unexpected}") 100 | 101 | def load_diffusion(self): 102 | model = instantiate_from_config(self.diffusion_config) 103 | self.diffusion_model = model.eval() 104 | self.diffusion_model.train = disabled_train 105 | for param in self.diffusion_model.parameters(): 106 | param.requires_grad = False 107 | 108 | def load_classifier(self, ckpt_path, pool): 109 | model_config = deepcopy(self.diffusion_config.params.unet_config.params) 110 | model_config.in_channels = self.diffusion_config.params.unet_config.params.out_channels 111 | model_config.out_channels = self.num_classes 112 | if self.label_key == 'class_label': 113 | model_config.pool = pool 114 | 115 | self.model = __models__[self.label_key](**model_config) 116 | if ckpt_path is not None: 117 | print('#####################################################################') 118 | print(f'load from ckpt "{ckpt_path}"') 119 | print('#####################################################################') 120 | self.init_from_ckpt(ckpt_path) 121 | 122 | @torch.no_grad() 123 | def get_x_noisy(self, x, t, noise=None): 124 | noise = default(noise, lambda: torch.randn_like(x)) 125 | continuous_sqrt_alpha_cumprod = None 126 | if self.diffusion_model.use_continuous_noise: 127 | continuous_sqrt_alpha_cumprod = self.diffusion_model.sample_continuous_noise_level(x.shape[0], t + 1) 128 | # todo: make sure t+1 is correct here 129 | 130 | return self.diffusion_model.q_sample(x_start=x, t=t, noise=noise, 131 | continuous_sqrt_alpha_cumprod=continuous_sqrt_alpha_cumprod) 132 | 133 | def forward(self, x_noisy, t, *args, **kwargs): 134 | return self.model(x_noisy, t) 135 | 136 | @torch.no_grad() 137 | def get_input(self, batch, k): 138 | x = batch[k] 139 | if len(x.shape) == 3: 140 | x = x[..., None] 141 | x = rearrange(x, 'b h w c -> b c h w') 142 | x = x.to(memory_format=torch.contiguous_format).float() 143 | return x 144 | 145 | @torch.no_grad() 146 | def get_conditioning(self, batch, k=None): 147 | if k is None: 148 | k = self.label_key 149 | assert k is not None, 'Needs to provide label key' 150 | 151 | targets = batch[k].to(self.device) 152 | 153 | if self.label_key == 'segmentation': 154 | targets = rearrange(targets, 'b h w c -> b c h w') 155 | for down in range(self.numd): 156 | h, w = targets.shape[-2:] 157 | targets = F.interpolate(targets, size=(h // 2, w // 2), mode='nearest') 158 | 159 | # targets = rearrange(targets,'b c h w -> b h w c') 160 | 161 | return targets 162 | 163 | def compute_top_k(self, logits, labels, k, reduction="mean"): 164 | _, top_ks = torch.topk(logits, k, dim=1) 165 | if reduction == "mean": 166 | return (top_ks == labels[:, None]).float().sum(dim=-1).mean().item() 167 | elif reduction == "none": 168 | return (top_ks == labels[:, None]).float().sum(dim=-1) 169 | 170 | def on_train_epoch_start(self): 171 | # save some memory 172 | self.diffusion_model.model.to('cpu') 173 | 174 | @torch.no_grad() 175 | def write_logs(self, loss, logits, targets): 176 | log_prefix = 'train' 
if self.training else 'val' 177 | log = {} 178 | log[f"{log_prefix}/loss"] = loss.mean() 179 | log[f"{log_prefix}/acc@1"] = self.compute_top_k( 180 | logits, targets, k=1, reduction="mean" 181 | ) 182 | log[f"{log_prefix}/acc@5"] = self.compute_top_k( 183 | logits, targets, k=5, reduction="mean" 184 | ) 185 | 186 | self.log_dict(log, prog_bar=False, logger=True, on_step=self.training, on_epoch=True) 187 | self.log('loss', log[f"{log_prefix}/loss"], prog_bar=True, logger=False) 188 | self.log('global_step', self.global_step, logger=False, on_epoch=False, prog_bar=True) 189 | lr = self.optimizers().param_groups[0]['lr'] 190 | self.log('lr_abs', lr, on_step=True, logger=True, on_epoch=False, prog_bar=True) 191 | 192 | def shared_step(self, batch, t=None): 193 | x, *_ = self.diffusion_model.get_input(batch, k=self.diffusion_model.first_stage_key) 194 | targets = self.get_conditioning(batch) 195 | if targets.dim() == 4: 196 | targets = targets.argmax(dim=1) 197 | if t is None: 198 | t = torch.randint(0, self.diffusion_model.num_timesteps, (x.shape[0],), device=self.device).long() 199 | else: 200 | t = torch.full(size=(x.shape[0],), fill_value=t, device=self.device).long() 201 | x_noisy = self.get_x_noisy(x, t) 202 | logits = self(x_noisy, t) 203 | 204 | loss = F.cross_entropy(logits, targets, reduction='none') 205 | 206 | self.write_logs(loss.detach(), logits.detach(), targets.detach()) 207 | 208 | loss = loss.mean() 209 | return loss, logits, x_noisy, targets 210 | 211 | def training_step(self, batch, batch_idx): 212 | loss, *_ = self.shared_step(batch) 213 | return loss 214 | 215 | def reset_noise_accs(self): 216 | self.noisy_acc = {t: {'acc@1': [], 'acc@5': []} for t in 217 | range(0, self.diffusion_model.num_timesteps, self.diffusion_model.log_every_t)} 218 | 219 | def on_validation_start(self): 220 | self.reset_noise_accs() 221 | 222 | @torch.no_grad() 223 | def validation_step(self, batch, batch_idx): 224 | loss, *_ = self.shared_step(batch) 225 | 226 | for t in self.noisy_acc: 227 | _, logits, _, targets = self.shared_step(batch, t) 228 | self.noisy_acc[t]['acc@1'].append(self.compute_top_k(logits, targets, k=1, reduction='mean')) 229 | self.noisy_acc[t]['acc@5'].append(self.compute_top_k(logits, targets, k=5, reduction='mean')) 230 | 231 | return loss 232 | 233 | def configure_optimizers(self): 234 | optimizer = AdamW(self.model.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay) 235 | 236 | if self.use_scheduler: 237 | scheduler = instantiate_from_config(self.scheduler_config) 238 | 239 | print("Setting up LambdaLR scheduler...") 240 | scheduler = [ 241 | { 242 | 'scheduler': LambdaLR(optimizer, lr_lambda=scheduler.schedule), 243 | 'interval': 'step', 244 | 'frequency': 1 245 | }] 246 | return [optimizer], scheduler 247 | 248 | return optimizer 249 | 250 | @torch.no_grad() 251 | def log_images(self, batch, N=8, *args, **kwargs): 252 | log = dict() 253 | x = self.get_input(batch, self.diffusion_model.first_stage_key) 254 | log['inputs'] = x 255 | 256 | y = self.get_conditioning(batch) 257 | 258 | if self.label_key == 'class_label': 259 | y = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"]) 260 | log['labels'] = y 261 | 262 | if ismap(y): 263 | log['labels'] = self.diffusion_model.to_rgb(y) 264 | 265 | for step in range(self.log_steps): 266 | current_time = step * self.log_time_interval 267 | 268 | _, logits, x_noisy, _ = self.shared_step(batch, t=current_time) 269 | 270 | log[f'inputs@t{current_time}'] = x_noisy 271 | 272 | pred = 
F.one_hot(logits.argmax(dim=1), num_classes=self.num_classes) 273 | pred = rearrange(pred, 'b h w c -> b c h w') 274 | 275 | log[f'pred@t{current_time}'] = self.diffusion_model.to_rgb(pred) 276 | 277 | for key in log: 278 | log[key] = log[key][:N] 279 | 280 | return log 281 | -------------------------------------------------------------------------------- /ldm/models/diffusion/plms.py: -------------------------------------------------------------------------------- 1 | # This code is built from the Stable Diffusion repository: https://github.com/CompVis/stable-diffusion, and 2 | # Paint-by-Example repo https://github.com/Fantasy-Studio/Paint-by-Example 3 | # Copyright (c) 2022 Robin Rombach and Patrick Esser and contributors. 4 | # CreativeML Open RAIL-M 5 | # 6 | # ========================================================================================== 7 | # 8 | # Adobe’s modifications are Copyright 2024 Adobe Research. All rights reserved. 9 | # Adobe’s modifications are licensed under the Adobe Research License. To view a copy of the license, visit 10 | # LICENSE.md. 11 | # 12 | # ========================================================================================== 13 | 14 | """SAMPLING ONLY.""" 15 | 16 | import torch 17 | import numpy as np 18 | from tqdm import tqdm 19 | from functools import partial 20 | 21 | from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like 22 | 23 | 24 | class PLMSSampler(object): 25 | def __init__(self, model, schedule="linear", **kwargs): 26 | super().__init__() 27 | self.model = model 28 | self.ddpm_num_timesteps = model.num_timesteps 29 | self.schedule = schedule 30 | 31 | def register_buffer(self, name, attr): 32 | if type(attr) == torch.Tensor: 33 | if attr.device != torch.device("cuda"): 34 | attr = attr.to(torch.device("cuda")) 35 | setattr(self, name, attr) 36 | 37 | def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True): 38 | if ddim_eta != 0: 39 | raise ValueError('ddim_eta must be 0 for PLMS') 40 | self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps, 41 | num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose) 42 | alphas_cumprod = self.model.alphas_cumprod 43 | assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep' 44 | to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device) 45 | 46 | self.register_buffer('betas', to_torch(self.model.betas)) 47 | self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod)) 48 | self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev)) 49 | 50 | # calculations for diffusion q(x_t | x_{t-1}) and others 51 | self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu()))) 52 | self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu()))) 53 | self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu()))) 54 | self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu()))) 55 | self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. 
/ alphas_cumprod.cpu() - 1))) 56 | 57 | # ddim sampling parameters 58 | ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(), 59 | ddim_timesteps=self.ddim_timesteps, 60 | eta=ddim_eta,verbose=verbose) 61 | self.register_buffer('ddim_sigmas', ddim_sigmas) 62 | self.register_buffer('ddim_alphas', ddim_alphas) 63 | self.register_buffer('ddim_alphas_prev', ddim_alphas_prev) 64 | self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas)) 65 | sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt( 66 | (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * ( 67 | 1 - self.alphas_cumprod / self.alphas_cumprod_prev)) 68 | self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps) 69 | 70 | @torch.no_grad() 71 | def sample(self, 72 | S, 73 | batch_size, 74 | shape, 75 | conditioning=None, 76 | callback=None, 77 | normals_sequence=None, 78 | img_callback=None, 79 | quantize_x0=False, 80 | eta=0., 81 | mask=None, 82 | x0=None, 83 | temperature=1., 84 | noise_dropout=0., 85 | score_corrector=None, 86 | corrector_kwargs=None, 87 | verbose=True, 88 | x_T=None, 89 | log_every_t=100, 90 | unconditional_guidance_scale=1., 91 | unconditional_conditioning=None, 92 | # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... 93 | **kwargs 94 | ): 95 | if conditioning is not None: 96 | if isinstance(conditioning, dict): 97 | cbs = conditioning[list(conditioning.keys())[0]].shape[0] 98 | if cbs != batch_size: 99 | print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}") 100 | else: 101 | if conditioning.shape[0] != batch_size: 102 | print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}") 103 | 104 | self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose) 105 | # sampling 106 | C, H, W = shape 107 | size = (batch_size, C, H, W) 108 | print(f'Data shape for PLMS sampling is {size}') 109 | 110 | samples, intermediates = self.plms_sampling(conditioning, size, 111 | callback=callback, 112 | img_callback=img_callback, 113 | quantize_denoised=quantize_x0, 114 | mask=mask, x0=x0, 115 | ddim_use_original_steps=False, 116 | noise_dropout=noise_dropout, 117 | temperature=temperature, 118 | score_corrector=score_corrector, 119 | corrector_kwargs=corrector_kwargs, 120 | x_T=x_T, 121 | log_every_t=log_every_t, 122 | unconditional_guidance_scale=unconditional_guidance_scale, 123 | unconditional_conditioning=unconditional_conditioning, 124 | **kwargs 125 | ) 126 | return samples, intermediates 127 | 128 | @torch.no_grad() 129 | def plms_sampling(self, cond, shape, 130 | x_T=None, ddim_use_original_steps=False, 131 | callback=None, timesteps=None, quantize_denoised=False, 132 | mask=None, x0=None, img_callback=None, log_every_t=100, 133 | temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, 134 | unconditional_guidance_scale=1., unconditional_conditioning=None,**kwargs): 135 | device = self.model.betas.device 136 | b = shape[0] 137 | if x_T is None: 138 | img = torch.randn(shape, device=device) 139 | else: 140 | img = x_T 141 | 142 | if timesteps is None: 143 | timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps 144 | elif timesteps is not None and not ddim_use_original_steps: 145 | subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1 146 | timesteps = self.ddim_timesteps[:subset_end] 147 | 148 | 
intermediates = {'x_inter': [img], 'pred_x0': [img]} 149 | time_range = list(reversed(range(0,timesteps))) if ddim_use_original_steps else np.flip(timesteps) 150 | total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0] 151 | print(f"Running PLMS Sampling with {total_steps} timesteps") 152 | 153 | iterator = tqdm(time_range, desc='PLMS Sampler', total=total_steps) 154 | old_eps = [] 155 | 156 | for i, step in enumerate(iterator): 157 | index = total_steps - i - 1 158 | ts = torch.full((b,), step, device=device, dtype=torch.long) 159 | ts_next = torch.full((b,), time_range[min(i + 1, len(time_range) - 1)], device=device, dtype=torch.long) 160 | 161 | if mask is not None: 162 | assert x0 is not None 163 | img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass? 164 | img = img_orig * mask + (1. - mask) * img 165 | 166 | outs = self.p_sample_plms(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps, 167 | quantize_denoised=quantize_denoised, temperature=temperature, 168 | noise_dropout=noise_dropout, score_corrector=score_corrector, 169 | corrector_kwargs=corrector_kwargs, 170 | unconditional_guidance_scale=unconditional_guidance_scale, 171 | unconditional_conditioning=unconditional_conditioning, 172 | old_eps=old_eps, t_next=ts_next,**kwargs) 173 | img, pred_x0, e_t = outs 174 | old_eps.append(e_t) 175 | if len(old_eps) >= 4: 176 | old_eps.pop(0) 177 | if callback: callback(i) 178 | if img_callback: img_callback(pred_x0, i) 179 | 180 | if index % log_every_t == 0 or index == total_steps - 1: 181 | intermediates['x_inter'].append(img) 182 | intermediates['pred_x0'].append(pred_x0) 183 | 184 | return img, intermediates 185 | 186 | @torch.no_grad() 187 | def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False, 188 | temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, 189 | unconditional_guidance_scale=1., unconditional_conditioning=None, old_eps=None, t_next=None,**kwargs): 190 | b, *_, device = *x.shape, x.device 191 | def get_model_output(x, t): 192 | if unconditional_conditioning is None or unconditional_guidance_scale == 1.: 193 | e_t = self.model.apply_model(x, t, c) 194 | else: 195 | x_in = torch.cat([x] * 2) 196 | t_in = torch.cat([t] * 2) 197 | c_in = torch.cat([unconditional_conditioning, c]) 198 | e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2) 199 | e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond) 200 | 201 | if score_corrector is not None: 202 | assert self.model.parameterization == "eps" 203 | e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs) 204 | 205 | return e_t 206 | 207 | alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas 208 | alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev 209 | sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas 210 | sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas 211 | 212 | def get_x_prev_and_pred_x0(e_t, index): 213 | # select parameters corresponding to the currently considered timestep 214 | a_t = torch.full((b, 1, 1, 1), alphas[index], device=device) 215 | a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device) 216 | sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device) 217 | sqrt_one_minus_at = torch.full((b, 1, 1, 1), 
sqrt_one_minus_alphas[index],device=device) 218 | 219 | # current prediction for x_0 220 | pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt() 221 | if quantize_denoised: 222 | pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0) 223 | # direction pointing to x_t 224 | dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t 225 | noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature 226 | if noise_dropout > 0.: 227 | noise = torch.nn.functional.dropout(noise, p=noise_dropout) 228 | x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise 229 | return x_prev, pred_x0 230 | kwargs=kwargs['test_model_kwargs'] 231 | x_new=torch.cat([x,kwargs['inpaint_image'],kwargs['inpaint_mask']],dim=1) 232 | e_t = get_model_output(x_new, t) 233 | if len(old_eps) == 0: 234 | # Pseudo Improved Euler (2nd order) 235 | x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t, index) 236 | x_prev_new=torch.cat([x_prev,kwargs['inpaint_image'],kwargs['inpaint_mask']],dim=1) 237 | e_t_next = get_model_output(x_prev_new, t_next) 238 | e_t_prime = (e_t + e_t_next) / 2 239 | elif len(old_eps) == 1: 240 | # 2nd order Pseudo Linear Multistep (Adams-Bashforth) 241 | e_t_prime = (3 * e_t - old_eps[-1]) / 2 242 | elif len(old_eps) == 2: 243 | # 3nd order Pseudo Linear Multistep (Adams-Bashforth) 244 | e_t_prime = (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12 245 | elif len(old_eps) >= 3: 246 | # 4nd order Pseudo Linear Multistep (Adams-Bashforth) 247 | e_t_prime = (55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3]) / 24 248 | 249 | x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index) 250 | 251 | return x_prev, pred_x0, e_t 252 | -------------------------------------------------------------------------------- /ldm/modules/__pycache__/attention.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/adobe-research/MagicFixup/47508c136b0f2427256b5fe7f13a9bbbabc07361/ldm/modules/__pycache__/attention.cpython-38.pyc -------------------------------------------------------------------------------- /ldm/modules/__pycache__/ema.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/adobe-research/MagicFixup/47508c136b0f2427256b5fe7f13a9bbbabc07361/ldm/modules/__pycache__/ema.cpython-38.pyc -------------------------------------------------------------------------------- /ldm/modules/__pycache__/x_transformer.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/adobe-research/MagicFixup/47508c136b0f2427256b5fe7f13a9bbbabc07361/ldm/modules/__pycache__/x_transformer.cpython-38.pyc -------------------------------------------------------------------------------- /ldm/modules/attention.py: -------------------------------------------------------------------------------- 1 | # This code is built from the Stable Diffusion repository: https://github.com/CompVis/stable-diffusion, and 2 | # Paint-by-Example repo https://github.com/Fantasy-Studio/Paint-by-Example 3 | # Copyright (c) 2022 Robin Rombach and Patrick Esser and contributors. 4 | # CreativeML Open RAIL-M 5 | # 6 | # ========================================================================================== 7 | # 8 | # Adobe’s modifications are Copyright 2024 Adobe Research. All rights reserved. 9 | # Adobe’s modifications are licensed under the Adobe Research License. To view a copy of the license, visit 10 | # LICENSE.md. 
11 | # 12 | # ========================================================================================== 13 | 14 | from inspect import isfunction 15 | import math 16 | import numpy as np 17 | import torch 18 | import torch.nn.functional as F 19 | from torch import nn, einsum 20 | from einops import rearrange, repeat 21 | import glob 22 | 23 | from ldm.modules.diffusionmodules.util import checkpoint 24 | 25 | 26 | def exists(val): 27 | return val is not None 28 | 29 | 30 | def uniq(arr): 31 | return{el: True for el in arr}.keys() 32 | 33 | 34 | def default(val, d): 35 | if exists(val): 36 | return val 37 | return d() if isfunction(d) else d 38 | 39 | 40 | def max_neg_value(t): 41 | return -torch.finfo(t.dtype).max 42 | 43 | 44 | def init_(tensor): 45 | dim = tensor.shape[-1] 46 | std = 1 / math.sqrt(dim) 47 | tensor.uniform_(-std, std) 48 | return tensor 49 | 50 | 51 | # feedforward 52 | class GEGLU(nn.Module): 53 | def __init__(self, dim_in, dim_out): 54 | super().__init__() 55 | self.proj = nn.Linear(dim_in, dim_out * 2) 56 | 57 | def forward(self, x): 58 | x, gate = self.proj(x).chunk(2, dim=-1) 59 | return x * F.gelu(gate) 60 | 61 | 62 | class FeedForward(nn.Module): 63 | def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.): 64 | super().__init__() 65 | inner_dim = int(dim * mult) 66 | dim_out = default(dim_out, dim) 67 | project_in = nn.Sequential( 68 | nn.Linear(dim, inner_dim), 69 | nn.GELU() 70 | ) if not glu else GEGLU(dim, inner_dim) 71 | 72 | self.net = nn.Sequential( 73 | project_in, 74 | nn.Dropout(dropout), 75 | nn.Linear(inner_dim, dim_out) 76 | ) 77 | 78 | def forward(self, x): 79 | return self.net(x) 80 | 81 | 82 | def zero_module(module): 83 | """ 84 | Zero out the parameters of a module and return it. 85 | """ 86 | for p in module.parameters(): 87 | p.detach().zero_() 88 | return module 89 | 90 | 91 | def Normalize(in_channels): 92 | return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) 93 | 94 | 95 | class LinearAttention(nn.Module): 96 | def __init__(self, dim, heads=4, dim_head=32): 97 | super().__init__() 98 | self.heads = heads 99 | hidden_dim = dim_head * heads 100 | self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False) 101 | self.to_out = nn.Conv2d(hidden_dim, dim, 1) 102 | 103 | def forward(self, x): 104 | b, c, h, w = x.shape 105 | qkv = self.to_qkv(x) 106 | q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)', heads = self.heads, qkv=3) 107 | k = k.softmax(dim=-1) 108 | context = torch.einsum('bhdn,bhen->bhde', k, v) 109 | out = torch.einsum('bhde,bhdn->bhen', context, q) 110 | out = rearrange(out, 'b heads c (h w) -> b (heads c) h w', heads=self.heads, h=h, w=w) 111 | return self.to_out(out) 112 | 113 | 114 | class SpatialSelfAttention(nn.Module): 115 | def __init__(self, in_channels): 116 | super().__init__() 117 | self.in_channels = in_channels 118 | 119 | self.norm = Normalize(in_channels) 120 | self.q = torch.nn.Conv2d(in_channels, 121 | in_channels, 122 | kernel_size=1, 123 | stride=1, 124 | padding=0) 125 | self.k = torch.nn.Conv2d(in_channels, 126 | in_channels, 127 | kernel_size=1, 128 | stride=1, 129 | padding=0) 130 | self.v = torch.nn.Conv2d(in_channels, 131 | in_channels, 132 | kernel_size=1, 133 | stride=1, 134 | padding=0) 135 | self.proj_out = torch.nn.Conv2d(in_channels, 136 | in_channels, 137 | kernel_size=1, 138 | stride=1, 139 | padding=0) 140 | 141 | def forward(self, x): 142 | h_ = x 143 | h_ = self.norm(h_) 144 | q = self.q(h_) 145 | k = self.k(h_) 146 | v = 
self.v(h_) 147 | 148 | # compute attention 149 | b,c,h,w = q.shape 150 | q = rearrange(q, 'b c h w -> b (h w) c') 151 | k = rearrange(k, 'b c h w -> b c (h w)') 152 | w_ = torch.einsum('bij,bjk->bik', q, k) 153 | 154 | w_ = w_ * (int(c)**(-0.5)) 155 | w_ = torch.nn.functional.softmax(w_, dim=2) 156 | 157 | # attend to values 158 | v = rearrange(v, 'b c h w -> b c (h w)') 159 | w_ = rearrange(w_, 'b i j -> b j i') 160 | h_ = torch.einsum('bij,bjk->bik', v, w_) 161 | h_ = rearrange(h_, 'b c (h w) -> b c h w', h=h) 162 | h_ = self.proj_out(h_) 163 | 164 | return x+h_ 165 | 166 | 167 | class CrossAttention(nn.Module): 168 | def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., only_crossref=False): 169 | super().__init__() 170 | inner_dim = dim_head * heads 171 | # forcing attention to only attend on vectors of same size 172 | # breaking the image2text attention 173 | context_dim = default(context_dim, query_dim) 174 | 175 | # print('creating cross attention. Query dim', query_dim, ' context dim', context_dim) 176 | 177 | self.scale = dim_head ** -0.5 178 | self.heads = heads 179 | 180 | self.to_q = nn.Linear(query_dim, inner_dim, bias=False) 181 | self.to_k = nn.Linear(context_dim, inner_dim, bias=False) 182 | self.to_v = nn.Linear(context_dim, inner_dim, bias=False) 183 | 184 | self.to_out = nn.Sequential( 185 | nn.Linear(inner_dim, query_dim), 186 | nn.Dropout(dropout) 187 | ) 188 | 189 | self.only_crossref = only_crossref 190 | if only_crossref: 191 | self.merge_attentions = zero_module(nn.Conv2d(self.heads * 2, 192 | self.heads, 193 | kernel_size=1, 194 | stride=1, 195 | padding=0)) 196 | else: 197 | self.merge_attentions = zero_module(nn.Conv2d(self.heads * 3, 198 | self.heads, 199 | kernel_size=1, 200 | stride=1, 201 | padding=0)) 202 | 203 | 204 | self.merge_attentions_missing = zero_module(nn.Conv2d(self.heads * 2, 205 | self.heads, 206 | kernel_size=1, 207 | stride=1, 208 | padding=0)) 209 | 210 | 211 | def forward(self, x, context=None, mask=None, passed_qkv=None, masks=None, corresp=None, missing_region=None): 212 | is_self_attention = context is None 213 | 214 | # if masks is not None: 215 | # print(is_self_attention, masks.keys()) 216 | 217 | h = self.heads 218 | 219 | # if passed_qkv is not None: 220 | # assert context is None 221 | 222 | # _,_,_,_, x_features = passed_qkv 223 | # assert x_features is not None 224 | 225 | # # print('x shape', x.shape, 'x features', x_features.shape) 226 | # # breakpoint() 227 | # x = torch.concat([x, x_features], dim=1) 228 | 229 | q = self.to_q(x) 230 | context = default(context, x) 231 | k = self.to_k(context) 232 | v = self.to_v(context) 233 | 234 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v)) 235 | 236 | sim = einsum('b i d, b j d -> b i j', q, k) * self.scale 237 | 238 | if exists(mask): 239 | assert False 240 | mask = rearrange(mask, 'b ... 
-> b (...)') 241 | max_neg_value = -torch.finfo(sim.dtype).max 242 | mask = repeat(mask, 'b j -> (b h) () j', h=h) 243 | sim.masked_fill_(~mask, max_neg_value) 244 | 245 | # attention, what we cannot get enough of 246 | attn = sim.softmax(dim=-1) 247 | out = einsum('b i j, b j d -> b i d', attn, v) 248 | inter_out = rearrange(out, '(b h) n d -> b h n d', h=h) 249 | 250 | combined_attention = inter_out 251 | out = rearrange(combined_attention, 'b h n d -> b n (h d)', h=h) 252 | 253 | final_out = self.to_out(out) 254 | 255 | if is_self_attention: 256 | return final_out, q, k, v, inter_out #TODO add attn out 257 | else: 258 | return final_out 259 | 260 | 261 | class BasicTransformerBlock(nn.Module): 262 | def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True): 263 | super().__init__() 264 | self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout) # is a self-attention 265 | self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff) 266 | self.attn2 = CrossAttention(query_dim=dim, context_dim=context_dim, 267 | heads=n_heads, dim_head=d_head, dropout=dropout) # is self-attn if context is none 268 | self.attn3 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout) 269 | self.norm1 = nn.LayerNorm(dim) 270 | self.norm2 = nn.LayerNorm(dim) 271 | self.norm3 = nn.LayerNorm(dim) 272 | self.checkpoint = checkpoint 273 | 274 | # TODO add attn in 275 | def forward(self, x, context=None, passed_qkv=None, masks=None, corresp=None): 276 | if passed_qkv is None: 277 | return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint) 278 | else: 279 | q, k, v, attn, x_features = passed_qkv 280 | d = int(np.sqrt(q.shape[1])) 281 | current_mask = masks[d] 282 | if corresp: 283 | current_corresp, missing_region = corresp[d] 284 | current_corresp = current_corresp.float() 285 | missing_region = missing_region.float() 286 | else: 287 | raise ValueError('cannot have empty corresp') 288 | current_corresp = None 289 | missing_region = current_mask.float() 290 | # breakpoint() 291 | stuff = [q, k, v, attn, x_features, current_mask, current_corresp, missing_region] 292 | for element in stuff: 293 | assert element is not None 294 | return checkpoint(self._forward, (x, context, q, k, v, attn, x_features, current_mask, current_corresp, missing_region), self.parameters(), self.checkpoint) 295 | 296 | # TODO add attn in 297 | def _forward(self, x, context=None, q=None, k=None, v=None, attn=None, passed_x=None, masks=None, corresp=None, missing_region=None): 298 | if q is not None: 299 | passed_qkv = (q, k, v, attn, passed_x) 300 | else: 301 | passed_qkv = None 302 | x_features = self.norm1(x) 303 | attended_x, q, k, v, attn = self.attn1(x_features, passed_qkv=passed_qkv, masks=masks, corresp=corresp, missing_region=missing_region) 304 | x = attended_x + x 305 | # killing CLIP features 306 | 307 | if passed_x is not None: 308 | normed_x = self.norm2(x) 309 | attn_out = self.attn3(normed_x, context=passed_x) 310 | x = attn_out + x 311 | # then use y + x 312 | # print('y shape', y.shape, ' x shape', x.shape) 313 | 314 | x = self.ff(self.norm3(x)) + x 315 | return x, q, k, v, attn, x_features 316 | 317 | 318 | class SpatialTransformer(nn.Module): 319 | """ 320 | Transformer block for image-like data. 321 | First, project the input (aka embedding) 322 | and reshape to b, t, d. 323 | Then apply standard transformer action. 
324 | Finally, reshape to image 325 | """ 326 | def __init__(self, in_channels, n_heads, d_head, 327 | depth=1, dropout=0., context_dim=None): 328 | super().__init__() 329 | self.in_channels = in_channels 330 | inner_dim = n_heads * d_head 331 | self.norm = Normalize(in_channels) 332 | 333 | # print('creating spatial transformer') 334 | # print('in channels', in_channels, 'inner dim', inner_dim) 335 | 336 | self.proj_in = nn.Conv2d(in_channels, 337 | inner_dim, 338 | kernel_size=1, 339 | stride=1, 340 | padding=0) 341 | 342 | self.transformer_blocks = nn.ModuleList( 343 | [BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim) 344 | for d in range(depth)] 345 | ) 346 | 347 | self.proj_out = zero_module(nn.Conv2d(inner_dim, 348 | in_channels, 349 | kernel_size=1, 350 | stride=1, 351 | padding=0)) 352 | 353 | # TODO add attn in and corresp 354 | def forward(self, x, context=None, passed_qkv=None, masks=None, corresp=None): 355 | # note: if no context is given, cross-attention defaults to self-attention 356 | b, c, h, w = x.shape 357 | # print('spatial transformer x shape given', x.shape) 358 | # if context is not None: 359 | # print('also context was provided with shape ', context.shape) 360 | x_in = x 361 | x = self.norm(x) 362 | x = self.proj_in(x) 363 | x = rearrange(x, 'b c h w -> b (h w) c') 364 | 365 | qkvs = [] 366 | for block in self.transformer_blocks: 367 | x, q, k, v, attn, x_features = block(x, context=context, passed_qkv=passed_qkv, masks=masks, corresp=corresp) 368 | qkv = (q,k,v,attn, x_features) 369 | qkvs.append(qkv) 370 | x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w) 371 | x = self.proj_out(x) 372 | return x + x_in, qkvs -------------------------------------------------------------------------------- /ldm/modules/diffusionmodules/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/adobe-research/MagicFixup/47508c136b0f2427256b5fe7f13a9bbbabc07361/ldm/modules/diffusionmodules/__init__.py -------------------------------------------------------------------------------- /ldm/modules/diffusionmodules/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/adobe-research/MagicFixup/47508c136b0f2427256b5fe7f13a9bbbabc07361/ldm/modules/diffusionmodules/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /ldm/modules/diffusionmodules/__pycache__/model.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/adobe-research/MagicFixup/47508c136b0f2427256b5fe7f13a9bbbabc07361/ldm/modules/diffusionmodules/__pycache__/model.cpython-38.pyc -------------------------------------------------------------------------------- /ldm/modules/diffusionmodules/__pycache__/openaimodel.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/adobe-research/MagicFixup/47508c136b0f2427256b5fe7f13a9bbbabc07361/ldm/modules/diffusionmodules/__pycache__/openaimodel.cpython-38.pyc -------------------------------------------------------------------------------- /ldm/modules/diffusionmodules/__pycache__/util.cpython-38.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/adobe-research/MagicFixup/47508c136b0f2427256b5fe7f13a9bbbabc07361/ldm/modules/diffusionmodules/__pycache__/util.cpython-38.pyc -------------------------------------------------------------------------------- /ldm/modules/diffusionmodules/util.py: -------------------------------------------------------------------------------- 1 | # This code is built from the Stable Diffusion repository: https://github.com/CompVis/stable-diffusion, and 2 | # Paint-by-Example repo https://github.com/Fantasy-Studio/Paint-by-Example 3 | # Copyright (c) 2022 Robin Rombach and Patrick Esser and contributors. 4 | # CreativeML Open RAIL-M 5 | # 6 | # ========================================================================================== 7 | # 8 | # Adobe’s modifications are Copyright 2024 Adobe Research. All rights reserved. 9 | # Adobe’s modifications are licensed under the Adobe Research License. To view a copy of the license, visit 10 | # LICENSE.md. 11 | # 12 | # ========================================================================================== 13 | 14 | # adopted from 15 | # https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py 16 | # and 17 | # https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py 18 | # and 19 | # https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py 20 | # 21 | # thanks! 22 | 23 | 24 | import os 25 | import math 26 | import torch 27 | import torch.nn as nn 28 | import numpy as np 29 | from einops import repeat 30 | 31 | from ldm.util import instantiate_from_config 32 | 33 | 34 | def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3): 35 | if schedule == "linear": 36 | betas = ( 37 | torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2 38 | ) 39 | 40 | elif schedule == "cosine": 41 | timesteps = ( 42 | torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s 43 | ) 44 | alphas = timesteps / (1 + cosine_s) * np.pi / 2 45 | alphas = torch.cos(alphas).pow(2) 46 | alphas = alphas / alphas[0] 47 | betas = 1 - alphas[1:] / alphas[:-1] 48 | betas = np.clip(betas, a_min=0, a_max=0.999) 49 | 50 | elif schedule == "sqrt_linear": 51 | betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) 52 | elif schedule == "sqrt": 53 | betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) ** 0.5 54 | else: 55 | raise ValueError(f"schedule '{schedule}' unknown.") 56 | return betas.numpy() 57 | 58 | 59 | def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True, steps=None): 60 | if ddim_discr_method == 'uniform': 61 | c = num_ddpm_timesteps // num_ddim_timesteps 62 | ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c))) 63 | elif ddim_discr_method == 'quad': 64 | ddim_timesteps = ((np.linspace(0, np.sqrt(num_ddpm_timesteps * .8), num_ddim_timesteps)) ** 2).astype(int) 65 | elif ddim_discr_method == 'manual': 66 | assert steps is not None 67 | ddim_timesteps = np.asarray(steps) 68 | else: 69 | raise NotImplementedError(f'There is no ddim discretization method called "{ddim_discr_method}"') 70 | 71 | # assert ddim_timesteps.shape[0] == num_ddim_timesteps 72 | # add one to get the final alpha values right (the ones from first scale to data during sampling) 73 | 
steps_out = ddim_timesteps + 1 74 | if verbose: 75 | print(f'Selected timesteps for ddim sampler: {steps_out}') 76 | return steps_out 77 | 78 | 79 | def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True): 80 | # select alphas for computing the variance schedule 81 | alphas = alphacums[ddim_timesteps] 82 | alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist()) 83 | 84 | # according the the formula provided in https://arxiv.org/abs/2010.02502 85 | sigmas = eta * np.sqrt((1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev)) 86 | if verbose: 87 | print(f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}') 88 | print(f'For the chosen value of eta, which is {eta}, ' 89 | f'this results in the following sigma_t schedule for ddim sampler {sigmas}') 90 | return sigmas, alphas, alphas_prev 91 | 92 | 93 | def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999): 94 | """ 95 | Create a beta schedule that discretizes the given alpha_t_bar function, 96 | which defines the cumulative product of (1-beta) over time from t = [0,1]. 97 | :param num_diffusion_timesteps: the number of betas to produce. 98 | :param alpha_bar: a lambda that takes an argument t from 0 to 1 and 99 | produces the cumulative product of (1-beta) up to that 100 | part of the diffusion process. 101 | :param max_beta: the maximum beta to use; use values lower than 1 to 102 | prevent singularities. 103 | """ 104 | betas = [] 105 | for i in range(num_diffusion_timesteps): 106 | t1 = i / num_diffusion_timesteps 107 | t2 = (i + 1) / num_diffusion_timesteps 108 | betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) 109 | return np.array(betas) 110 | 111 | 112 | def extract_into_tensor(a, t, x_shape): 113 | b, *_ = t.shape 114 | out = a.gather(-1, t) 115 | return out.reshape(b, *((1,) * (len(x_shape) - 1))) 116 | 117 | 118 | def checkpoint(func, inputs, params, flag): 119 | """ 120 | Evaluate a function without caching intermediate activations, allowing for 121 | reduced memory at the expense of extra compute in the backward pass. 122 | :param func: the function to evaluate. 123 | :param inputs: the argument sequence to pass to `func`. 124 | :param params: a sequence of parameters `func` depends on but does not 125 | explicitly take as arguments. 126 | :param flag: if False, disable gradient checkpointing. 127 | """ 128 | if flag: 129 | args = tuple(inputs) + tuple(params) 130 | return CheckpointFunction.apply(func, len(inputs), *args) 131 | else: 132 | return func(*inputs) 133 | 134 | 135 | class CheckpointFunction(torch.autograd.Function): 136 | @staticmethod 137 | # @torch.cuda.amp.custom_fwd(cast_inputs=torch.float32) # added this for map 138 | def forward(ctx, run_function, length, *args): 139 | ctx.run_function = run_function 140 | ctx.input_tensors = list(args[:length]) 141 | ctx.input_params = list(args[length:]) 142 | 143 | with torch.no_grad(): 144 | output_tensors = ctx.run_function(*ctx.input_tensors) 145 | return output_tensors 146 | 147 | @staticmethod 148 | # @torch.cuda.amp.custom_bwd # added this for map 149 | def backward(ctx, *output_grads): 150 | ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors] 151 | with torch.enable_grad(): 152 | # Fixes a bug where the first op in run_function modifies the 153 | # Tensor storage in place, which is not allowed for detach()'d 154 | # Tensors. 
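# Note: the forward pass above ran under torch.no_grad(), so no autograd graph
# was recorded; the function is re-run here with gradients enabled to rebuild it.
# The detached inputs are autograd leaves, and in-place ops on leaves that
# require grad are disallowed, which is why view_as() is used below to make
# non-leaf shallow copies that share storage while still letting gradients
# reach the original inputs.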
155 | shallow_copies = [x.view_as(x) for x in ctx.input_tensors] 156 | output_tensors = ctx.run_function(*shallow_copies) 157 | input_grads = torch.autograd.grad( 158 | output_tensors, 159 | ctx.input_tensors + ctx.input_params, 160 | output_grads, 161 | allow_unused=True, 162 | ) 163 | del ctx.input_tensors 164 | del ctx.input_params 165 | del output_tensors 166 | return (None, None) + input_grads 167 | 168 | 169 | def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False): 170 | """ 171 | Create sinusoidal timestep embeddings. 172 | :param timesteps: a 1-D Tensor of N indices, one per batch element. 173 | These may be fractional. 174 | :param dim: the dimension of the output. 175 | :param max_period: controls the minimum frequency of the embeddings. 176 | :return: an [N x dim] Tensor of positional embeddings. 177 | """ 178 | if not repeat_only: 179 | half = dim // 2 180 | freqs = torch.exp( 181 | -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half 182 | ).to(device=timesteps.device) 183 | args = timesteps[:, None].float() * freqs[None] 184 | embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) 185 | if dim % 2: 186 | embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) 187 | else: 188 | embedding = repeat(timesteps, 'b -> b d', d=dim) 189 | return embedding 190 | 191 | 192 | def zero_module(module): 193 | """ 194 | Zero out the parameters of a module and return it. 195 | """ 196 | for p in module.parameters(): 197 | p.detach().zero_() 198 | return module 199 | 200 | 201 | def scale_module(module, scale): 202 | """ 203 | Scale the parameters of a module and return it. 204 | """ 205 | for p in module.parameters(): 206 | p.detach().mul_(scale) 207 | return module 208 | 209 | 210 | def mean_flat(tensor): 211 | """ 212 | Take the mean over all non-batch dimensions. 213 | """ 214 | return tensor.mean(dim=list(range(1, len(tensor.shape)))) 215 | 216 | 217 | def normalization(channels): 218 | """ 219 | Make a standard normalization layer. 220 | :param channels: number of input channels. 221 | :return: an nn.Module for normalization. 222 | """ 223 | return GroupNorm32(32, channels) 224 | 225 | 226 | # PyTorch 1.7 has SiLU, but we support PyTorch 1.5. 227 | class SiLU(nn.Module): 228 | def forward(self, x): 229 | return x * torch.sigmoid(x) 230 | 231 | 232 | class GroupNorm32(nn.GroupNorm): 233 | def forward(self, x): 234 | return super().forward(x.float()).type(x.dtype) 235 | 236 | def conv_nd(dims, *args, **kwargs): 237 | """ 238 | Create a 1D, 2D, or 3D convolution module. 239 | """ 240 | if dims == 1: 241 | return nn.Conv1d(*args, **kwargs) 242 | elif dims == 2: 243 | return nn.Conv2d(*args, **kwargs) 244 | elif dims == 3: 245 | return nn.Conv3d(*args, **kwargs) 246 | raise ValueError(f"unsupported dimensions: {dims}") 247 | 248 | 249 | def linear(*args, **kwargs): 250 | """ 251 | Create a linear module. 252 | """ 253 | return nn.Linear(*args, **kwargs) 254 | 255 | 256 | def avg_pool_nd(dims, *args, **kwargs): 257 | """ 258 | Create a 1D, 2D, or 3D average pooling module. 
259 | """ 260 | if dims == 1: 261 | return nn.AvgPool1d(*args, **kwargs) 262 | elif dims == 2: 263 | return nn.AvgPool2d(*args, **kwargs) 264 | elif dims == 3: 265 | return nn.AvgPool3d(*args, **kwargs) 266 | raise ValueError(f"unsupported dimensions: {dims}") 267 | 268 | 269 | class HybridConditioner(nn.Module): 270 | 271 | def __init__(self, c_concat_config, c_crossattn_config): 272 | super().__init__() 273 | self.concat_conditioner = instantiate_from_config(c_concat_config) 274 | self.crossattn_conditioner = instantiate_from_config(c_crossattn_config) 275 | 276 | def forward(self, c_concat, c_crossattn): 277 | c_concat = self.concat_conditioner(c_concat) 278 | c_crossattn = self.crossattn_conditioner(c_crossattn) 279 | return {'c_concat': [c_concat], 'c_crossattn': [c_crossattn]} 280 | 281 | 282 | def noise_like(shape, device, repeat=False): 283 | repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1))) 284 | noise = lambda: torch.randn(shape, device=device) 285 | return repeat_noise() if repeat else noise() -------------------------------------------------------------------------------- /ldm/modules/distributions/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/adobe-research/MagicFixup/47508c136b0f2427256b5fe7f13a9bbbabc07361/ldm/modules/distributions/__init__.py -------------------------------------------------------------------------------- /ldm/modules/distributions/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/adobe-research/MagicFixup/47508c136b0f2427256b5fe7f13a9bbbabc07361/ldm/modules/distributions/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /ldm/modules/distributions/__pycache__/distributions.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/adobe-research/MagicFixup/47508c136b0f2427256b5fe7f13a9bbbabc07361/ldm/modules/distributions/__pycache__/distributions.cpython-38.pyc -------------------------------------------------------------------------------- /ldm/modules/distributions/distributions.py: -------------------------------------------------------------------------------- 1 | # This code is built from the Stable Diffusion repository: https://github.com/CompVis/stable-diffusion, and 2 | # Paint-by-Example repo https://github.com/Fantasy-Studio/Paint-by-Example 3 | # Copyright (c) 2022 Robin Rombach and Patrick Esser and contributors. 4 | # CreativeML Open RAIL-M 5 | # 6 | # ========================================================================================== 7 | # 8 | # Adobe’s modifications are Copyright 2024 Adobe Research. All rights reserved. 9 | # Adobe’s modifications are licensed under the Adobe Research License. To view a copy of the license, visit 10 | # LICENSE.md. 
11 | # 12 | # ========================================================================================== 13 | 14 | import torch 15 | import numpy as np 16 | 17 | 18 | class AbstractDistribution: 19 | def sample(self): 20 | raise NotImplementedError() 21 | 22 | def mode(self): 23 | raise NotImplementedError() 24 | 25 | 26 | class DiracDistribution(AbstractDistribution): 27 | def __init__(self, value): 28 | self.value = value 29 | 30 | def sample(self): 31 | return self.value 32 | 33 | def mode(self): 34 | return self.value 35 | 36 | 37 | class DiagonalGaussianDistribution(object): 38 | def __init__(self, parameters, deterministic=False): 39 | self.parameters = parameters 40 | self.mean, self.logvar = torch.chunk(parameters, 2, dim=1) 41 | self.logvar = torch.clamp(self.logvar, -30.0, 20.0) 42 | self.deterministic = deterministic 43 | self.std = torch.exp(0.5 * self.logvar) 44 | self.var = torch.exp(self.logvar) 45 | if self.deterministic: 46 | self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device) 47 | 48 | def sample(self): 49 | x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device) 50 | return x 51 | 52 | def kl(self, other=None): 53 | if self.deterministic: 54 | return torch.Tensor([0.]) 55 | else: 56 | if other is None: 57 | return 0.5 * torch.sum(torch.pow(self.mean, 2) 58 | + self.var - 1.0 - self.logvar, 59 | dim=[1, 2, 3]) 60 | else: 61 | return 0.5 * torch.sum( 62 | torch.pow(self.mean - other.mean, 2) / other.var 63 | + self.var / other.var - 1.0 - self.logvar + other.logvar, 64 | dim=[1, 2, 3]) 65 | 66 | def nll(self, sample, dims=[1,2,3]): 67 | if self.deterministic: 68 | return torch.Tensor([0.]) 69 | logtwopi = np.log(2.0 * np.pi) 70 | return 0.5 * torch.sum( 71 | logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, 72 | dim=dims) 73 | 74 | def mode(self): 75 | return self.mean 76 | 77 | 78 | def normal_kl(mean1, logvar1, mean2, logvar2): 79 | """ 80 | source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12 81 | Compute the KL divergence between two gaussians. 82 | Shapes are automatically broadcasted, so batches can be compared to 83 | scalars, among other use cases. 84 | """ 85 | tensor = None 86 | for obj in (mean1, logvar1, mean2, logvar2): 87 | if isinstance(obj, torch.Tensor): 88 | tensor = obj 89 | break 90 | assert tensor is not None, "at least one argument must be a Tensor" 91 | 92 | # Force variances to be Tensors. Broadcasting helps convert scalars to 93 | # Tensors, but it does not work for torch.exp(). 94 | logvar1, logvar2 = [ 95 | x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor) 96 | for x in (logvar1, logvar2) 97 | ] 98 | 99 | return 0.5 * ( 100 | -1.0 101 | + logvar2 102 | - logvar1 103 | + torch.exp(logvar1 - logvar2) 104 | + ((mean1 - mean2) ** 2) * torch.exp(-logvar2) 105 | ) 106 | -------------------------------------------------------------------------------- /ldm/modules/ema.py: -------------------------------------------------------------------------------- 1 | # This code is built from the Stable Diffusion repository: https://github.com/CompVis/stable-diffusion, and 2 | # Paint-by-Example repo https://github.com/Fantasy-Studio/Paint-by-Example 3 | # Copyright (c) 2022 Robin Rombach and Patrick Esser and contributors. 
4 | # CreativeML Open RAIL-M 5 | # 6 | # ========================================================================================== 7 | # 8 | # Adobe’s modifications are Copyright 2024 Adobe Research. All rights reserved. 9 | # Adobe’s modifications are licensed under the Adobe Research License. To view a copy of the license, visit 10 | # LICENSE.md. 11 | # 12 | # ========================================================================================== 13 | 14 | import torch 15 | from torch import nn 16 | 17 | 18 | class LitEma(nn.Module): 19 | def __init__(self, model, decay=0.9999, use_num_upates=True): 20 | super().__init__() 21 | if decay < 0.0 or decay > 1.0: 22 | raise ValueError('Decay must be between 0 and 1') 23 | 24 | self.m_name2s_name = {} 25 | self.register_buffer('decay', torch.tensor(decay, dtype=torch.float32)) 26 | self.register_buffer('num_updates', torch.tensor(0,dtype=torch.int) if use_num_upates 27 | else torch.tensor(-1,dtype=torch.int)) 28 | 29 | for name, p in model.named_parameters(): 30 | if p.requires_grad: 31 | #remove as '.'-character is not allowed in buffers 32 | s_name = name.replace('.','') 33 | self.m_name2s_name.update({name:s_name}) 34 | self.register_buffer(s_name,p.clone().detach().data) 35 | 36 | self.collected_params = [] 37 | 38 | def forward(self,model): 39 | decay = self.decay 40 | 41 | if self.num_updates >= 0: 42 | self.num_updates += 1 43 | decay = min(self.decay,(1 + self.num_updates) / (10 + self.num_updates)) 44 | 45 | one_minus_decay = 1.0 - decay 46 | 47 | with torch.no_grad(): 48 | m_param = dict(model.named_parameters()) 49 | shadow_params = dict(self.named_buffers()) 50 | 51 | for key in m_param: 52 | if m_param[key].requires_grad: 53 | sname = self.m_name2s_name[key] 54 | shadow_params[sname] = shadow_params[sname].type_as(m_param[key]) 55 | shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key])) 56 | else: 57 | assert not key in self.m_name2s_name 58 | 59 | def copy_to(self, model): 60 | m_param = dict(model.named_parameters()) 61 | shadow_params = dict(self.named_buffers()) 62 | for key in m_param: 63 | if m_param[key].requires_grad: 64 | m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data) 65 | else: 66 | assert not key in self.m_name2s_name 67 | 68 | def store(self, parameters): 69 | """ 70 | Save the current parameters for restoring later. 71 | Args: 72 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 73 | temporarily stored. 74 | """ 75 | self.collected_params = [param.clone() for param in parameters] 76 | 77 | def restore(self, parameters): 78 | """ 79 | Restore the parameters stored with the `store` method. 80 | Useful to validate the model with EMA parameters without affecting the 81 | original optimization process. Store the parameters before the 82 | `copy_to` method. After validation (or model saving), use this to 83 | restore the former parameters. 84 | Args: 85 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 86 | updated with the stored parameters. 
87 | """ 88 | for c_param, param in zip(self.collected_params, parameters): 89 | param.data.copy_(c_param.data) 90 | -------------------------------------------------------------------------------- /ldm/modules/encoders/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/adobe-research/MagicFixup/47508c136b0f2427256b5fe7f13a9bbbabc07361/ldm/modules/encoders/__init__.py -------------------------------------------------------------------------------- /ldm/modules/encoders/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/adobe-research/MagicFixup/47508c136b0f2427256b5fe7f13a9bbbabc07361/ldm/modules/encoders/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /ldm/modules/encoders/__pycache__/modules.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/adobe-research/MagicFixup/47508c136b0f2427256b5fe7f13a9bbbabc07361/ldm/modules/encoders/__pycache__/modules.cpython-38.pyc -------------------------------------------------------------------------------- /ldm/modules/encoders/__pycache__/xf.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/adobe-research/MagicFixup/47508c136b0f2427256b5fe7f13a9bbbabc07361/ldm/modules/encoders/__pycache__/xf.cpython-38.pyc -------------------------------------------------------------------------------- /ldm/modules/encoders/modules.py: -------------------------------------------------------------------------------- 1 | # This code is built from the Stable Diffusion repository: https://github.com/CompVis/stable-diffusion, and 2 | # Paint-by-Example repo https://github.com/Fantasy-Studio/Paint-by-Example 3 | # Copyright (c) 2022 Robin Rombach and Patrick Esser and contributors. 4 | # CreativeML Open RAIL-M 5 | # 6 | # ========================================================================================== 7 | # 8 | # Adobe’s modifications are Copyright 2024 Adobe Research. All rights reserved. 9 | # Adobe’s modifications are licensed under the Adobe Research License. To view a copy of the license, visit 10 | # LICENSE.md. 11 | # 12 | # ========================================================================================== 13 | 14 | import torch 15 | import torch.nn as nn 16 | from functools import partial 17 | import clip 18 | from einops import rearrange, repeat 19 | from transformers import CLIPTokenizer, CLIPTextModel,CLIPVisionModel,CLIPModel 20 | import kornia 21 | from ldm.modules.x_transformer import Encoder, TransformerWrapper # TODO: can we directly rely on lucidrains code and simply add this as a reuirement? 
--> test 22 | from .xf import LayerNorm, Transformer 23 | import math 24 | 25 | class AbstractEncoder(nn.Module): 26 | def __init__(self): 27 | super().__init__() 28 | 29 | def encode(self, *args, **kwargs): 30 | raise NotImplementedError 31 | 32 | 33 | 34 | class ClassEmbedder(nn.Module): 35 | def __init__(self, embed_dim, n_classes=1000, key='class'): 36 | super().__init__() 37 | self.key = key 38 | self.embedding = nn.Embedding(n_classes, embed_dim) 39 | 40 | def forward(self, batch, key=None): 41 | if key is None: 42 | key = self.key 43 | # this is for use in crossattn 44 | c = batch[key][:, None] 45 | c = self.embedding(c) 46 | return c 47 | 48 | 49 | class TransformerEmbedder(AbstractEncoder): 50 | """Some transformer encoder layers""" 51 | def __init__(self, n_embed, n_layer, vocab_size, max_seq_len=77, device="cuda"): 52 | super().__init__() 53 | self.device = device 54 | self.transformer = TransformerWrapper(num_tokens=vocab_size, max_seq_len=max_seq_len, 55 | attn_layers=Encoder(dim=n_embed, depth=n_layer)) 56 | 57 | def forward(self, tokens): 58 | tokens = tokens.to(self.device) # meh 59 | z = self.transformer(tokens, return_embeddings=True) 60 | return z 61 | 62 | def encode(self, x): 63 | return self(x) 64 | 65 | 66 | class BERTTokenizer(AbstractEncoder): 67 | """ Uses a pretrained BERT tokenizer by huggingface. Vocab size: 30522 (?)""" 68 | def __init__(self, device="cuda", vq_interface=True, max_length=77): 69 | super().__init__() 70 | from transformers import BertTokenizerFast # TODO: add to reuquirements 71 | self.tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased") 72 | self.device = device 73 | self.vq_interface = vq_interface 74 | self.max_length = max_length 75 | 76 | def forward(self, text): 77 | batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True, 78 | return_overflowing_tokens=False, padding="max_length", return_tensors="pt") 79 | tokens = batch_encoding["input_ids"].to(self.device) 80 | return tokens 81 | 82 | @torch.no_grad() 83 | def encode(self, text): 84 | tokens = self(text) 85 | if not self.vq_interface: 86 | return tokens 87 | return None, None, [None, None, tokens] 88 | 89 | def decode(self, text): 90 | return text 91 | 92 | 93 | class BERTEmbedder(AbstractEncoder): 94 | """Uses the BERT tokenizr model and add some transformer encoder layers""" 95 | def __init__(self, n_embed, n_layer, vocab_size=30522, max_seq_len=77, 96 | device="cuda",use_tokenizer=True, embedding_dropout=0.0): 97 | super().__init__() 98 | self.use_tknz_fn = use_tokenizer 99 | if self.use_tknz_fn: 100 | self.tknz_fn = BERTTokenizer(vq_interface=False, max_length=max_seq_len) 101 | self.device = device 102 | self.transformer = TransformerWrapper(num_tokens=vocab_size, max_seq_len=max_seq_len, 103 | attn_layers=Encoder(dim=n_embed, depth=n_layer), 104 | emb_dropout=embedding_dropout) 105 | 106 | def forward(self, text): 107 | if self.use_tknz_fn: 108 | tokens = self.tknz_fn(text)#.to(self.device) 109 | else: 110 | tokens = text 111 | z = self.transformer(tokens, return_embeddings=True) 112 | return z 113 | 114 | def encode(self, text): 115 | # output of length 77 116 | return self(text) 117 | 118 | class FrozenCLIPEmbedder(AbstractEncoder): 119 | """Uses the CLIP transformer encoder for text (from Hugging Face)""" 120 | def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77): 121 | super().__init__() 122 | self.tokenizer = CLIPTokenizer.from_pretrained(version) 123 | self.transformer = 
CLIPTextModel.from_pretrained(version) 124 | self.device = device 125 | self.max_length = max_length 126 | self.freeze() 127 | 128 | def freeze(self): 129 | self.transformer = self.transformer.eval() 130 | for param in self.parameters(): 131 | param.requires_grad = False 132 | 133 | def forward(self, text): 134 | batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True, 135 | return_overflowing_tokens=False, padding="max_length", return_tensors="pt") 136 | tokens = batch_encoding["input_ids"].to(self.device) 137 | outputs = self.transformer(input_ids=tokens) 138 | 139 | z = outputs.last_hidden_state 140 | return z 141 | 142 | def encode(self, text): 143 | return self(text) 144 | 145 | 146 | class SpatialRescaler(nn.Module): 147 | def __init__(self, 148 | n_stages=1, 149 | method='bilinear', 150 | multiplier=0.5, 151 | in_channels=3, 152 | out_channels=None, 153 | bias=False): 154 | super().__init__() 155 | self.n_stages = n_stages 156 | assert self.n_stages >= 0 157 | assert method in ['nearest','linear','bilinear','trilinear','bicubic','area'] 158 | self.multiplier = multiplier 159 | self.interpolator = partial(torch.nn.functional.interpolate, mode=method) 160 | self.remap_output = out_channels is not None 161 | if self.remap_output: 162 | print(f'Spatial Rescaler mapping from {in_channels} to {out_channels} channels after resizing.') 163 | self.channel_mapper = nn.Conv2d(in_channels,out_channels,1,bias=bias) 164 | 165 | def forward(self,x): 166 | for stage in range(self.n_stages): 167 | x = self.interpolator(x, scale_factor=self.multiplier) 168 | 169 | 170 | if self.remap_output: 171 | x = self.channel_mapper(x) 172 | return x 173 | 174 | def encode(self, x): 175 | return self(x) 176 | 177 | class FrozenCLIPTextEmbedder(nn.Module): 178 | """ 179 | Uses the CLIP transformer encoder for text. 
180 | """ 181 | def __init__(self, version='ViT-L/14', device="cuda", max_length=77, n_repeat=1, normalize=True): 182 | super().__init__() 183 | self.model, _ = clip.load(version, jit=False, device="cpu") 184 | self.device = device 185 | self.max_length = max_length 186 | self.n_repeat = n_repeat 187 | self.normalize = normalize 188 | 189 | def freeze(self): 190 | self.model = self.model.eval() 191 | for param in self.parameters(): 192 | param.requires_grad = False 193 | 194 | def forward(self, text): 195 | tokens = clip.tokenize(text).to(self.device) 196 | z = self.model.encode_text(tokens) 197 | if self.normalize: 198 | z = z / torch.linalg.norm(z, dim=1, keepdim=True) 199 | return z 200 | 201 | def encode(self, text): 202 | z = self(text) 203 | if z.ndim==2: 204 | z = z[:, None, :] 205 | z = repeat(z, 'b 1 d -> b k d', k=self.n_repeat) 206 | return z 207 | 208 | class FrozenCLIPImageEmbedder(AbstractEncoder): 209 | """Uses the CLIP transformer encoder for text (from Hugging Face)""" 210 | def __init__(self, version="openai/clip-vit-large-patch14"): 211 | super().__init__() 212 | self.transformer = CLIPVisionModel.from_pretrained(version) 213 | self.final_ln = LayerNorm(1024) 214 | self.mapper = Transformer( 215 | 1, 216 | 1024, 217 | 5, 218 | 1, 219 | ) 220 | 221 | self.freeze() 222 | 223 | def freeze(self): 224 | self.transformer = self.transformer.eval() 225 | for param in self.parameters(): 226 | param.requires_grad = False 227 | for param in self.mapper.parameters(): 228 | param.requires_grad = True 229 | for param in self.final_ln.parameters(): 230 | param.requires_grad = True 231 | 232 | def forward(self, image): 233 | outputs = self.transformer(pixel_values=image) 234 | z = outputs.pooler_output 235 | z = z.unsqueeze(1) 236 | z = self.mapper(z) 237 | z = self.final_ln(z) 238 | return z 239 | 240 | def encode(self, image): 241 | return self(image) 242 | 243 | 244 | 245 | class DINOEmbedder(AbstractEncoder): 246 | """Uses the CLIP transformer encoder for text (from Hugging Face)""" 247 | def __init__(self, dino_version): # small, large, huge, gigantic 248 | super().__init__() 249 | assert dino_version in ['small', 'big', 'large', 'huge'] 250 | letter_map = { 251 | 'small': 's', 252 | 'big': 'b', 253 | 'large': 'l', 254 | 'huge': 'g' 255 | } 256 | 257 | self.final_ln = LayerNorm(32) # unused -- remove later 258 | self.mapper = LayerNorm(32) # unused -- remove later 259 | # embedding_sizes = { 260 | # 'small': 384, 261 | # 'big': 768, 262 | # 'large': 1024, 263 | # 'huge': 1536 264 | # } 265 | 266 | # embedding_size = embedding_sizes[dino_version] 267 | letter = letter_map[dino_version] 268 | # self.transformer = CLIPVisionModel.from_pretrained(version) 269 | self.dino_model = torch.hub.load('facebookresearch/dinov2', f'dinov2_vit{letter}14_reg').cuda() 270 | 271 | 272 | self.freeze() 273 | 274 | def freeze(self): 275 | for param in self.parameters(): 276 | param.requires_grad = False 277 | 278 | def forward(self, image): 279 | with torch.no_grad(): 280 | outputs = self.dino_model.forward_features(image) 281 | patch_tokens = outputs['x_norm_patchtokens'] 282 | global_token = outputs['x_norm_clstoken'].unsqueeze(1) 283 | features = torch.concat([patch_tokens, global_token], dim=1) 284 | return torch.zeros_like(features) 285 | 286 | def encode(self, image): 287 | return self(image) 288 | 289 | 290 | class FixedVector(AbstractEncoder): 291 | """Uses the CLIP transformer encoder for text (from Hugging Face)""" 292 | def __init__(self): # small, large, huge, gigantic 293 | 
super().__init__() 294 | self.final_ln = LayerNorm(32) 295 | self.mapper = LayerNorm(32) 296 | self.fixed_vector = nn.Parameter(torch.randn((1,1,768)), requires_grad=True).cuda() 297 | def forward(self, image): 298 | return self.fixed_vector.repeat(image.shape[0],1,1).to(image.device) * 0.0 299 | 300 | def encode(self, image): 301 | return self(image) 302 | 303 | 304 | 305 | 306 | if __name__ == "__main__": 307 | from ldm.util import count_params 308 | model = FrozenCLIPEmbedder() 309 | count_params(model, verbose=True) -------------------------------------------------------------------------------- /ldm/modules/encoders/xf.py: -------------------------------------------------------------------------------- 1 | # This code is built from the Stable Diffusion repository: https://github.com/CompVis/stable-diffusion, and 2 | # Paint-by-Example repo https://github.com/Fantasy-Studio/Paint-by-Example 3 | # Copyright (c) 2022 Robin Rombach and Patrick Esser and contributors. 4 | # CreativeML Open RAIL-M 5 | # 6 | # ========================================================================================== 7 | # 8 | # Adobe’s modifications are Copyright 2024 Adobe Research. All rights reserved. 9 | # Adobe’s modifications are licensed under the Adobe Research License. To view a copy of the license, visit 10 | # LICENSE.md. 11 | # 12 | # ========================================================================================== 13 | 14 | """ 15 | Transformer implementation adapted from CLIP ViT: 16 | https://github.com/openai/CLIP/blob/4c0275784d6d9da97ca1f47eaaee31de1867da91/clip/model.py 17 | """ 18 | 19 | import math 20 | 21 | import torch as th 22 | import torch.nn as nn 23 | 24 | 25 | def convert_module_to_f16(l): 26 | """ 27 | Convert primitive modules to float16. 28 | """ 29 | if isinstance(l, (nn.Linear, nn.Conv2d, nn.ConvTranspose2d)): 30 | l.weight.data = l.weight.data.half() 31 | if l.bias is not None: 32 | l.bias.data = l.bias.data.half() 33 | 34 | 35 | class LayerNorm(nn.LayerNorm): 36 | """ 37 | Implementation that supports fp16 inputs but fp32 gains/biases. 
38 | """ 39 | 40 | def forward(self, x: th.Tensor): 41 | return super().forward(x.float()).to(x.dtype) 42 | 43 | 44 | class MultiheadAttention(nn.Module): 45 | def __init__(self, n_ctx, width, heads): 46 | super().__init__() 47 | self.n_ctx = n_ctx 48 | self.width = width 49 | self.heads = heads 50 | self.c_qkv = nn.Linear(width, width * 3) 51 | self.c_proj = nn.Linear(width, width) 52 | self.attention = QKVMultiheadAttention(heads, n_ctx) 53 | 54 | def forward(self, x): 55 | x = self.c_qkv(x) 56 | x = self.attention(x) 57 | x = self.c_proj(x) 58 | return x 59 | 60 | 61 | class MLP(nn.Module): 62 | def __init__(self, width): 63 | super().__init__() 64 | self.width = width 65 | self.c_fc = nn.Linear(width, width * 4) 66 | self.c_proj = nn.Linear(width * 4, width) 67 | self.gelu = nn.GELU() 68 | 69 | def forward(self, x): 70 | return self.c_proj(self.gelu(self.c_fc(x))) 71 | 72 | 73 | class QKVMultiheadAttention(nn.Module): 74 | def __init__(self, n_heads: int, n_ctx: int): 75 | super().__init__() 76 | self.n_heads = n_heads 77 | self.n_ctx = n_ctx 78 | 79 | def forward(self, qkv): 80 | bs, n_ctx, width = qkv.shape 81 | attn_ch = width // self.n_heads // 3 82 | scale = 1 / math.sqrt(math.sqrt(attn_ch)) 83 | qkv = qkv.view(bs, n_ctx, self.n_heads, -1) 84 | q, k, v = th.split(qkv, attn_ch, dim=-1) 85 | weight = th.einsum( 86 | "bthc,bshc->bhts", q * scale, k * scale 87 | ) # More stable with f16 than dividing afterwards 88 | wdtype = weight.dtype 89 | weight = th.softmax(weight.float(), dim=-1).type(wdtype) 90 | return th.einsum("bhts,bshc->bthc", weight, v).reshape(bs, n_ctx, -1) 91 | 92 | 93 | class ResidualAttentionBlock(nn.Module): 94 | def __init__( 95 | self, 96 | n_ctx: int, 97 | width: int, 98 | heads: int, 99 | ): 100 | super().__init__() 101 | 102 | self.attn = MultiheadAttention( 103 | n_ctx, 104 | width, 105 | heads, 106 | ) 107 | self.ln_1 = LayerNorm(width) 108 | self.mlp = MLP(width) 109 | self.ln_2 = LayerNorm(width) 110 | 111 | def forward(self, x: th.Tensor): 112 | x = x + self.attn(self.ln_1(x)) 113 | x = x + self.mlp(self.ln_2(x)) 114 | return x 115 | 116 | 117 | class Transformer(nn.Module): 118 | def __init__( 119 | self, 120 | n_ctx: int, 121 | width: int, 122 | layers: int, 123 | heads: int, 124 | ): 125 | super().__init__() 126 | self.n_ctx = n_ctx 127 | self.width = width 128 | self.layers = layers 129 | self.resblocks = nn.ModuleList( 130 | [ 131 | ResidualAttentionBlock( 132 | n_ctx, 133 | width, 134 | heads, 135 | ) 136 | for _ in range(layers) 137 | ] 138 | ) 139 | 140 | def forward(self, x: th.Tensor): 141 | for block in self.resblocks: 142 | x = block(x) 143 | return x 144 | -------------------------------------------------------------------------------- /ldm/modules/losses/__init__.py: -------------------------------------------------------------------------------- 1 | from ldm.modules.losses.contperceptual import LPIPSWithDiscriminator -------------------------------------------------------------------------------- /ldm/modules/losses/contperceptual.py: -------------------------------------------------------------------------------- 1 | # This code is built from the Stable Diffusion repository: https://github.com/CompVis/stable-diffusion, and 2 | # Paint-by-Example repo https://github.com/Fantasy-Studio/Paint-by-Example 3 | # Copyright (c) 2022 Robin Rombach and Patrick Esser and contributors. 
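
The xf.py module above supplies the small CLIP-style transformer that FrozenCLIPImageEmbedder in ldm/modules/encoders/modules.py wraps around the pooled CLIP image embedding (it instantiates Transformer(1, 1024, 5, 1) followed by a LayerNorm). A minimal sketch of that usage with random data, for illustration only:

import torch
from ldm.modules.encoders.xf import LayerNorm, Transformer

mapper = Transformer(n_ctx=1, width=1024, layers=5, heads=1)
final_ln = LayerNorm(1024)

z = torch.randn(4, 1, 1024)   # batch of 4 pooled image embeddings, one token each
z = final_ln(mapper(z))       # residual attention blocks keep the (4, 1, 1024) shape
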
4 | # CreativeML Open RAIL-M 5 | # 6 | # ========================================================================================== 7 | # 8 | # Adobe’s modifications are Copyright 2024 Adobe Research. All rights reserved. 9 | # Adobe’s modifications are licensed under the Adobe Research License. To view a copy of the license, visit 10 | # LICENSE.md. 11 | # 12 | # ========================================================================================== 13 | 14 | import torch 15 | import torch.nn as nn 16 | 17 | from taming.modules.losses.vqperceptual import * # TODO: taming dependency yes/no? 18 | 19 | 20 | class LPIPSWithDiscriminator(nn.Module): 21 | def __init__(self, disc_start, logvar_init=0.0, kl_weight=1.0, pixelloss_weight=1.0, 22 | disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0, 23 | perceptual_weight=1.0, use_actnorm=False, disc_conditional=False, 24 | disc_loss="hinge"): 25 | 26 | super().__init__() 27 | assert disc_loss in ["hinge", "vanilla"] 28 | self.kl_weight = kl_weight 29 | self.pixel_weight = pixelloss_weight 30 | self.perceptual_loss = LPIPS().eval() 31 | self.perceptual_weight = perceptual_weight 32 | # output log variance 33 | self.logvar = nn.Parameter(torch.ones(size=()) * logvar_init) 34 | 35 | self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels, 36 | n_layers=disc_num_layers, 37 | use_actnorm=use_actnorm 38 | ).apply(weights_init) 39 | self.discriminator_iter_start = disc_start 40 | self.disc_loss = hinge_d_loss if disc_loss == "hinge" else vanilla_d_loss 41 | self.disc_factor = disc_factor 42 | self.discriminator_weight = disc_weight 43 | self.disc_conditional = disc_conditional 44 | 45 | def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None): 46 | if last_layer is not None: 47 | nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0] 48 | g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0] 49 | else: 50 | nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0] 51 | g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0] 52 | 53 | d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4) 54 | d_weight = torch.clamp(d_weight, 0.0, 1e4).detach() 55 | d_weight = d_weight * self.discriminator_weight 56 | return d_weight 57 | 58 | def forward(self, inputs, reconstructions, posteriors, optimizer_idx, 59 | global_step, last_layer=None, cond=None, split="train", 60 | weights=None): 61 | rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous()) 62 | if self.perceptual_weight > 0: 63 | p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous()) 64 | rec_loss = rec_loss + self.perceptual_weight * p_loss 65 | 66 | nll_loss = rec_loss / torch.exp(self.logvar) + self.logvar 67 | weighted_nll_loss = nll_loss 68 | if weights is not None: 69 | weighted_nll_loss = weights*nll_loss 70 | weighted_nll_loss = torch.sum(weighted_nll_loss) / weighted_nll_loss.shape[0] 71 | nll_loss = torch.sum(nll_loss) / nll_loss.shape[0] 72 | kl_loss = posteriors.kl() 73 | kl_loss = torch.sum(kl_loss) / kl_loss.shape[0] 74 | 75 | # now the GAN part 76 | if optimizer_idx == 0: 77 | # generator update 78 | if cond is None: 79 | assert not self.disc_conditional 80 | logits_fake = self.discriminator(reconstructions.contiguous()) 81 | else: 82 | assert self.disc_conditional 83 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1)) 84 | g_loss = -torch.mean(logits_fake) 85 | 86 | if 
self.disc_factor > 0.0: 87 | try: 88 | d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer) 89 | except RuntimeError: 90 | assert not self.training 91 | d_weight = torch.tensor(0.0) 92 | else: 93 | d_weight = torch.tensor(0.0) 94 | 95 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) 96 | loss = weighted_nll_loss + self.kl_weight * kl_loss + d_weight * disc_factor * g_loss 97 | 98 | log = {"{}/total_loss".format(split): loss.clone().detach().mean(), "{}/logvar".format(split): self.logvar.detach(), 99 | "{}/kl_loss".format(split): kl_loss.detach().mean(), "{}/nll_loss".format(split): nll_loss.detach().mean(), 100 | "{}/rec_loss".format(split): rec_loss.detach().mean(), 101 | "{}/d_weight".format(split): d_weight.detach(), 102 | "{}/disc_factor".format(split): torch.tensor(disc_factor), 103 | "{}/g_loss".format(split): g_loss.detach().mean(), 104 | } 105 | return loss, log 106 | 107 | if optimizer_idx == 1: 108 | # second pass for discriminator update 109 | if cond is None: 110 | logits_real = self.discriminator(inputs.contiguous().detach()) 111 | logits_fake = self.discriminator(reconstructions.contiguous().detach()) 112 | else: 113 | logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1)) 114 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1)) 115 | 116 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) 117 | d_loss = disc_factor * self.disc_loss(logits_real, logits_fake) 118 | 119 | log = {"{}/disc_loss".format(split): d_loss.clone().detach().mean(), 120 | "{}/logits_real".format(split): logits_real.detach().mean(), 121 | "{}/logits_fake".format(split): logits_fake.detach().mean() 122 | } 123 | return d_loss, log 124 | 125 | -------------------------------------------------------------------------------- /ldm/modules/losses/vqperceptual.py: -------------------------------------------------------------------------------- 1 | # This code is built from the Stable Diffusion repository: https://github.com/CompVis/stable-diffusion, and 2 | # Paint-by-Example repo https://github.com/Fantasy-Studio/Paint-by-Example 3 | # Copyright (c) 2022 Robin Rombach and Patrick Esser and contributors. 4 | # CreativeML Open RAIL-M 5 | # 6 | # ========================================================================================== 7 | # 8 | # Adobe’s modifications are Copyright 2024 Adobe Research. All rights reserved. 9 | # Adobe’s modifications are licensed under the Adobe Research License. To view a copy of the license, visit 10 | # LICENSE.md. 11 | # 12 | # ========================================================================================== 13 | 14 | import torch 15 | from torch import nn 16 | import torch.nn.functional as F 17 | from einops import repeat 18 | 19 | from taming.modules.discriminator.model import NLayerDiscriminator, weights_init 20 | from taming.modules.losses.lpips import LPIPS 21 | from taming.modules.losses.vqperceptual import hinge_d_loss, vanilla_d_loss 22 | 23 | 24 | def hinge_d_loss_with_exemplar_weights(logits_real, logits_fake, weights): 25 | assert weights.shape[0] == logits_real.shape[0] == logits_fake.shape[0] 26 | loss_real = torch.mean(F.relu(1. - logits_real), dim=[1,2,3]) 27 | loss_fake = torch.mean(F.relu(1. 
+ logits_fake), dim=[1,2,3]) 28 | loss_real = (weights * loss_real).sum() / weights.sum() 29 | loss_fake = (weights * loss_fake).sum() / weights.sum() 30 | d_loss = 0.5 * (loss_real + loss_fake) 31 | return d_loss 32 | 33 | def adopt_weight(weight, global_step, threshold=0, value=0.): 34 | if global_step < threshold: 35 | weight = value 36 | return weight 37 | 38 | 39 | def measure_perplexity(predicted_indices, n_embed): 40 | # src: https://github.com/karpathy/deep-vector-quantization/blob/main/model.py 41 | # eval cluster perplexity. when perplexity == num_embeddings then all clusters are used exactly equally 42 | encodings = F.one_hot(predicted_indices, n_embed).float().reshape(-1, n_embed) 43 | avg_probs = encodings.mean(0) 44 | perplexity = (-(avg_probs * torch.log(avg_probs + 1e-10)).sum()).exp() 45 | cluster_use = torch.sum(avg_probs > 0) 46 | return perplexity, cluster_use 47 | 48 | def l1(x, y): 49 | return torch.abs(x-y) 50 | 51 | 52 | def l2(x, y): 53 | return torch.pow((x-y), 2) 54 | 55 | 56 | class VQLPIPSWithDiscriminator(nn.Module): 57 | def __init__(self, disc_start, codebook_weight=1.0, pixelloss_weight=1.0, 58 | disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0, 59 | perceptual_weight=1.0, use_actnorm=False, disc_conditional=False, 60 | disc_ndf=64, disc_loss="hinge", n_classes=None, perceptual_loss="lpips", 61 | pixel_loss="l1"): 62 | super().__init__() 63 | assert disc_loss in ["hinge", "vanilla"] 64 | assert perceptual_loss in ["lpips", "clips", "dists"] 65 | assert pixel_loss in ["l1", "l2"] 66 | self.codebook_weight = codebook_weight 67 | self.pixel_weight = pixelloss_weight 68 | if perceptual_loss == "lpips": 69 | print(f"{self.__class__.__name__}: Running with LPIPS.") 70 | self.perceptual_loss = LPIPS().eval() 71 | else: 72 | raise ValueError(f"Unknown perceptual loss: >> {perceptual_loss} <<") 73 | self.perceptual_weight = perceptual_weight 74 | 75 | if pixel_loss == "l1": 76 | self.pixel_loss = l1 77 | else: 78 | self.pixel_loss = l2 79 | 80 | self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels, 81 | n_layers=disc_num_layers, 82 | use_actnorm=use_actnorm, 83 | ndf=disc_ndf 84 | ).apply(weights_init) 85 | self.discriminator_iter_start = disc_start 86 | if disc_loss == "hinge": 87 | self.disc_loss = hinge_d_loss 88 | elif disc_loss == "vanilla": 89 | self.disc_loss = vanilla_d_loss 90 | else: 91 | raise ValueError(f"Unknown GAN loss '{disc_loss}'.") 92 | print(f"VQLPIPSWithDiscriminator running with {disc_loss} loss.") 93 | self.disc_factor = disc_factor 94 | self.discriminator_weight = disc_weight 95 | self.disc_conditional = disc_conditional 96 | self.n_classes = n_classes 97 | 98 | def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None): 99 | if last_layer is not None: 100 | nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0] 101 | g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0] 102 | else: 103 | nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0] 104 | g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0] 105 | 106 | d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4) 107 | d_weight = torch.clamp(d_weight, 0.0, 1e4).detach() 108 | d_weight = d_weight * self.discriminator_weight 109 | return d_weight 110 | 111 | def forward(self, codebook_loss, inputs, reconstructions, optimizer_idx, 112 | global_step, last_layer=None, cond=None, split="train", predicted_indices=None): 113 | if not 
exists(codebook_loss): 114 | codebook_loss = torch.tensor([0.]).to(inputs.device) 115 | #rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous()) 116 | rec_loss = self.pixel_loss(inputs.contiguous(), reconstructions.contiguous()) 117 | if self.perceptual_weight > 0: 118 | p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous()) 119 | rec_loss = rec_loss + self.perceptual_weight * p_loss 120 | else: 121 | p_loss = torch.tensor([0.0]) 122 | 123 | nll_loss = rec_loss 124 | #nll_loss = torch.sum(nll_loss) / nll_loss.shape[0] 125 | nll_loss = torch.mean(nll_loss) 126 | 127 | # now the GAN part 128 | if optimizer_idx == 0: 129 | # generator update 130 | if cond is None: 131 | assert not self.disc_conditional 132 | logits_fake = self.discriminator(reconstructions.contiguous()) 133 | else: 134 | assert self.disc_conditional 135 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1)) 136 | g_loss = -torch.mean(logits_fake) 137 | 138 | try: 139 | d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer) 140 | except RuntimeError: 141 | assert not self.training 142 | d_weight = torch.tensor(0.0) 143 | 144 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) 145 | loss = nll_loss + d_weight * disc_factor * g_loss + self.codebook_weight * codebook_loss.mean() 146 | 147 | log = {"{}/total_loss".format(split): loss.clone().detach().mean(), 148 | "{}/quant_loss".format(split): codebook_loss.detach().mean(), 149 | "{}/nll_loss".format(split): nll_loss.detach().mean(), 150 | "{}/rec_loss".format(split): rec_loss.detach().mean(), 151 | "{}/p_loss".format(split): p_loss.detach().mean(), 152 | "{}/d_weight".format(split): d_weight.detach(), 153 | "{}/disc_factor".format(split): torch.tensor(disc_factor), 154 | "{}/g_loss".format(split): g_loss.detach().mean(), 155 | } 156 | if predicted_indices is not None: 157 | assert self.n_classes is not None 158 | with torch.no_grad(): 159 | perplexity, cluster_usage = measure_perplexity(predicted_indices, self.n_classes) 160 | log[f"{split}/perplexity"] = perplexity 161 | log[f"{split}/cluster_usage"] = cluster_usage 162 | return loss, log 163 | 164 | if optimizer_idx == 1: 165 | # second pass for discriminator update 166 | if cond is None: 167 | logits_real = self.discriminator(inputs.contiguous().detach()) 168 | logits_fake = self.discriminator(reconstructions.contiguous().detach()) 169 | else: 170 | logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1)) 171 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1)) 172 | 173 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) 174 | d_loss = disc_factor * self.disc_loss(logits_real, logits_fake) 175 | 176 | log = {"{}/disc_loss".format(split): d_loss.clone().detach().mean(), 177 | "{}/logits_real".format(split): logits_real.detach().mean(), 178 | "{}/logits_fake".format(split): logits_fake.detach().mean() 179 | } 180 | return d_loss, log 181 | -------------------------------------------------------------------------------- /ldm/util.py: -------------------------------------------------------------------------------- 1 | # This code is built from the Stable Diffusion repository: https://github.com/CompVis/stable-diffusion, and 2 | # Paint-by-Example repo https://github.com/Fantasy-Studio/Paint-by-Example 3 | # Copyright (c) 2022 Robin Rombach and 
Patrick Esser and contributors. 4 | # CreativeML Open RAIL-M 5 | # 6 | # ========================================================================================== 7 | # 8 | # Adobe’s modifications are Copyright 2024 Adobe Research. All rights reserved. 9 | # Adobe’s modifications are licensed under the Adobe Research License. To view a copy of the license, visit 10 | # LICENSE.md. 11 | # 12 | # ========================================================================================== 13 | 14 | import importlib 15 | 16 | import torch 17 | import numpy as np 18 | from collections import abc 19 | from einops import rearrange 20 | from functools import partial 21 | 22 | import multiprocessing as mp 23 | from threading import Thread 24 | from queue import Queue 25 | 26 | from inspect import isfunction 27 | from PIL import Image, ImageDraw, ImageFont 28 | 29 | 30 | def log_txt_as_img(wh, xc, size=10): 31 | # wh a tuple of (width, height) 32 | # xc a list of captions to plot 33 | b = len(xc) 34 | txts = list() 35 | for bi in range(b): 36 | txt = Image.new("RGB", wh, color="white") 37 | draw = ImageDraw.Draw(txt) 38 | font = ImageFont.truetype('data/DejaVuSans.ttf', size=size) 39 | nc = int(40 * (wh[0] / 256)) 40 | lines = "\n".join(xc[bi][start:start + nc] for start in range(0, len(xc[bi]), nc)) 41 | 42 | try: 43 | draw.text((0, 0), lines, fill="black", font=font) 44 | except UnicodeEncodeError: 45 | print("Cant encode string for logging. Skipping.") 46 | 47 | txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0 48 | txts.append(txt) 49 | txts = np.stack(txts) 50 | txts = torch.tensor(txts) 51 | return txts 52 | 53 | 54 | def ismap(x): 55 | if not isinstance(x, torch.Tensor): 56 | return False 57 | return (len(x.shape) == 4) and (x.shape[1] > 3) 58 | 59 | 60 | def isimage(x): 61 | if not isinstance(x, torch.Tensor): 62 | return False 63 | return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1) 64 | 65 | 66 | def exists(x): 67 | return x is not None 68 | 69 | 70 | def default(val, d): 71 | if exists(val): 72 | return val 73 | return d() if isfunction(d) else d 74 | 75 | 76 | def mean_flat(tensor): 77 | """ 78 | https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/nn.py#L86 79 | Take the mean over all non-batch dimensions. 
80 | """ 81 | return tensor.mean(dim=list(range(1, len(tensor.shape)))) 82 | 83 | 84 | def count_params(model, verbose=False): 85 | total_params = sum(p.numel() for p in model.parameters()) 86 | if verbose: 87 | print(f"{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params.") 88 | return total_params 89 | 90 | 91 | def instantiate_from_config(config): 92 | if not "target" in config: 93 | if config == '__is_first_stage__': 94 | return None 95 | elif config == "__is_unconditional__": 96 | return None 97 | raise KeyError("Expected key `target` to instantiate.") 98 | return get_obj_from_str(config["target"])(**config.get("params", dict())) 99 | 100 | 101 | def get_obj_from_str(string, reload=False): 102 | module, cls = string.rsplit(".", 1) 103 | if reload: 104 | module_imp = importlib.import_module(module) 105 | importlib.reload(module_imp) 106 | return getattr(importlib.import_module(module, package=None), cls) 107 | 108 | 109 | def _do_parallel_data_prefetch(func, Q, data, idx, idx_to_fn=False): 110 | # create dummy dataset instance 111 | 112 | # run prefetching 113 | if idx_to_fn: 114 | res = func(data, worker_id=idx) 115 | else: 116 | res = func(data) 117 | Q.put([idx, res]) 118 | Q.put("Done") 119 | 120 | 121 | def parallel_data_prefetch( 122 | func: callable, data, n_proc, target_data_type="ndarray", cpu_intensive=True, use_worker_id=False 123 | ): 124 | # if target_data_type not in ["ndarray", "list"]: 125 | # raise ValueError( 126 | # "Data, which is passed to parallel_data_prefetch has to be either of type list or ndarray." 127 | # ) 128 | if isinstance(data, np.ndarray) and target_data_type == "list": 129 | raise ValueError("list expected but function got ndarray.") 130 | elif isinstance(data, abc.Iterable): 131 | if isinstance(data, dict): 132 | print( 133 | f'WARNING:"data" argument passed to parallel_data_prefetch is a dict: Using only its values and disregarding keys.' 134 | ) 135 | data = list(data.values()) 136 | if target_data_type == "ndarray": 137 | data = np.asarray(data) 138 | else: 139 | data = list(data) 140 | else: 141 | raise TypeError( 142 | f"The data, that shall be processed parallel has to be either an np.ndarray or an Iterable, but is actually {type(data)}." 
143 | ) 144 | 145 | if cpu_intensive: 146 | Q = mp.Queue(1000) 147 | proc = mp.Process 148 | else: 149 | Q = Queue(1000) 150 | proc = Thread 151 | # spawn processes 152 | if target_data_type == "ndarray": 153 | arguments = [ 154 | [func, Q, part, i, use_worker_id] 155 | for i, part in enumerate(np.array_split(data, n_proc)) 156 | ] 157 | else: 158 | step = ( 159 | int(len(data) / n_proc + 1) 160 | if len(data) % n_proc != 0 161 | else int(len(data) / n_proc) 162 | ) 163 | arguments = [ 164 | [func, Q, part, i, use_worker_id] 165 | for i, part in enumerate( 166 | [data[i: i + step] for i in range(0, len(data), step)] 167 | ) 168 | ] 169 | processes = [] 170 | for i in range(n_proc): 171 | p = proc(target=_do_parallel_data_prefetch, args=arguments[i]) 172 | processes += [p] 173 | 174 | # start processes 175 | print(f"Start prefetching...") 176 | import time 177 | 178 | start = time.time() 179 | gather_res = [[] for _ in range(n_proc)] 180 | try: 181 | for p in processes: 182 | p.start() 183 | 184 | k = 0 185 | while k < n_proc: 186 | # get result 187 | res = Q.get() 188 | if res == "Done": 189 | k += 1 190 | else: 191 | gather_res[res[0]] = res[1] 192 | 193 | except Exception as e: 194 | print("Exception: ", e) 195 | for p in processes: 196 | p.terminate() 197 | 198 | raise e 199 | finally: 200 | for p in processes: 201 | p.join() 202 | print(f"Prefetching complete. [{time.time() - start} sec.]") 203 | 204 | if target_data_type == 'ndarray': 205 | if not isinstance(gather_res[0], np.ndarray): 206 | return np.concatenate([np.asarray(r) for r in gather_res], axis=0) 207 | 208 | # order outputs 209 | return np.concatenate(gather_res, axis=0) 210 | elif target_data_type == 'list': 211 | out = [] 212 | for r in gather_res: 213 | out.extend(r) 214 | return out 215 | else: 216 | return gather_res 217 | -------------------------------------------------------------------------------- /magicfu_gradio.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Adobe. All rights reserved. 2 | 3 | from run_magicfu import MagicFixup 4 | import os 5 | import pathlib 6 | import torchvision 7 | from torch import autocast 8 | from PIL import Image 9 | import gradio as gr 10 | import numpy as np 11 | import argparse 12 | 13 | 14 | def sample(original_image, coarse_edit): 15 | to_tensor = torchvision.transforms.ToTensor() 16 | with autocast("cuda"): 17 | w, h = coarse_edit.size 18 | ref_image_t = to_tensor(original_image.resize((512,512))).half().cuda() 19 | coarse_edit_t = to_tensor(coarse_edit.resize((512,512))).half().cuda() 20 | # get mask from coarse 21 | coarse_edit_mask_t = to_tensor(coarse_edit.resize((512,512))).half().cuda() 22 | mask_t = (coarse_edit_mask_t[-1][None, None,...]).half() # do center crop 23 | coarse_edit_t_rgb = coarse_edit_t[:-1] 24 | 25 | out_rgb = magic_fixup.edit_image(ref_image_t, coarse_edit_t_rgb, mask_t, start_step=1.0, steps=50) 26 | output = out_rgb.squeeze().cpu().detach().moveaxis(0, -1).float().numpy() 27 | output = (output * 255.0).astype(np.uint8) 28 | output_pil = Image.fromarray(output) 29 | output_pil = output_pil.resize((w, h)) 30 | return output_pil 31 | 32 | def file_exists(path): 33 | """ Check if a file exists and is not a directory. """ 34 | if not os.path.isfile(path): 35 | raise argparse.ArgumentTypeError(f"{path} is not a valid file.") 36 | return path 37 | 38 | def parse_arguments(): 39 | """ Parses command-line arguments. 
""" 40 | parser = argparse.ArgumentParser(description="Process images based on provided paths.") 41 | parser.add_argument("--checkpoint", type=file_exists, required=True, help="Path to the MagicFixup checkpoint file.") 42 | 43 | return parser.parse_args() 44 | 45 | demo = gr.Interface(fn=sample, inputs=[gr.Image(type="pil", image_mode='RGB'), gr.Image(type="pil", image_mode='RGBA')], outputs=gr.Image(), 46 | examples='examples') 47 | 48 | if __name__ == "__main__": 49 | args = parse_arguments() 50 | 51 | # create magic fixup model 52 | magic_fixup = MagicFixup(model_path=args.checkpoint) 53 | demo.launch(share=True) 54 | -------------------------------------------------------------------------------- /run_magicfu.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Adobe. All rights reserved. 2 | 3 | #%% 4 | import cv2 5 | import torch 6 | import numpy as np 7 | from omegaconf import OmegaConf 8 | from PIL import Image 9 | from itertools import islice 10 | from torch import autocast 11 | import torchvision 12 | from ldm.util import instantiate_from_config 13 | from ldm.models.diffusion.ddim import DDIMSampler 14 | from torchvision.transforms import Resize 15 | import argparse 16 | import os 17 | import pathlib 18 | import glob 19 | 20 | 21 | device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") 22 | 23 | def fix_img(test_img): 24 | width, height = test_img.size 25 | if width != height: 26 | left = 0 27 | right = height 28 | bottom = height 29 | top = 0 30 | return test_img.crop((left, top, right, bottom)) 31 | else: 32 | return test_img 33 | # util funcs 34 | def chunk(it, size): 35 | it = iter(it) 36 | return iter(lambda: tuple(islice(it, size)), ()) 37 | 38 | def get_tensor_clip(normalize=True, toTensor=True): 39 | transform_list = [] 40 | if toTensor: 41 | transform_list += [torchvision.transforms.ToTensor()] 42 | 43 | if normalize: 44 | transform_list += [torchvision.transforms.Normalize((0.48145466, 0.4578275, 0.40821073), 45 | (0.26862954, 0.26130258, 0.27577711))] 46 | return torchvision.transforms.Compose(transform_list) 47 | 48 | def get_tensor_dino(normalize=True, toTensor=True): 49 | transform_list = [torchvision.transforms.Resize((224,224))] 50 | if toTensor: 51 | transform_list += [torchvision.transforms.ToTensor()] 52 | 53 | if normalize: 54 | transform_list += [lambda x: 255.0 * x[:3], 55 | torchvision.transforms.Normalize( 56 | mean=(123.675, 116.28, 103.53), 57 | std=(58.395, 57.12, 57.375), 58 | )] 59 | return torchvision.transforms.Compose(transform_list) 60 | 61 | def get_tensor(normalize=True, toTensor=True): 62 | transform_list = [] 63 | if toTensor: 64 | transform_list += [torchvision.transforms.ToTensor()] 65 | 66 | if normalize: 67 | transform_list += [torchvision.transforms.Normalize((0.5, 0.5, 0.5), 68 | (0.5, 0.5, 0.5))] 69 | transform_list += [ 70 | torchvision.transforms.Resize(512), 71 | torchvision.transforms.CenterCrop(512) 72 | ] 73 | return torchvision.transforms.Compose(transform_list) 74 | 75 | 76 | def numpy_to_pil(images): 77 | """ 78 | Convert a numpy image or a batch of images to a PIL image. 79 | """ 80 | if images.ndim == 3: 81 | images = images[None, ...] 
82 |     images = (images * 255).round().astype("uint8") 83 |     pil_images = [Image.fromarray(image) for image in images] 84 | 85 |     return pil_images 86 | 87 | 88 | 89 | def load_model_from_config(config, ckpt, verbose=False): 90 |     model = instantiate_from_config(config.model) 91 |     # print('NOTE: NO CHECKPOINT IS LOADED') 92 | 93 |     if ckpt is not None: 94 |         print(f"Loading model from {ckpt}") 95 |         pl_sd = torch.load(ckpt, map_location="cpu") 96 |         if "global_step" in pl_sd: 97 |             print(f"Global Step: {pl_sd['global_step']}") 98 |         sd = pl_sd["state_dict"] 99 | 100 |         m, u = model.load_state_dict(sd, strict=False) 101 |         if len(m) > 0 and verbose: 102 |             print("missing keys:") 103 |             print(m) 104 |         if len(u) > 0 and verbose: 105 |             print("unexpected keys:") 106 |             print(u) 107 | 108 |     model.cuda() 109 |     model.eval() 110 |     return model 111 | 112 | 113 | def get_model(config_path, ckpt_path): 114 |     config = OmegaConf.load(f"{config_path}") 115 |     model = load_model_from_config(config, None) 116 |     pl_sd = torch.load(ckpt_path, map_location="cpu") 117 | 118 |     m, u = model.load_state_dict(pl_sd, strict=True) 119 |     if len(m) > 0: 120 |         print("WARNING: missing keys:") 121 |         print(m) 122 |     if len(u) > 0: 123 |         print("unexpected keys:") 124 |         print(u) 125 | 126 | 127 |     model = model.to(device) 128 |     return model 129 | 130 | def get_grid(size): 131 |     y = np.repeat(np.arange(size)[None, ...], size) 132 |     y = y.reshape(size, size) 133 |     x = y.transpose() 134 |     out = np.stack([y,x], -1) 135 |     return out 136 | 137 | def un_norm(x): 138 |     return (x+1.0)/2.0 139 | 140 | class MagicFixup: 141 |     def __init__(self, model_path='/sensei-fs/users/halzayer/collage2photo/Paint-by-Example/official_checkpoint_image_attn_200k.pt'): 142 |         self.model = get_model('configs/collage_mix_train.yaml',model_path) 143 | 144 | 145 |     def edit_image(self, ref_image, coarse_edit, mask_tensor, start_step, steps): 146 |         # essentially sample 147 |         sampler = DDIMSampler(self.model) 148 | 149 |         start_code = None 150 | 151 |         transformed_grid = torch.zeros((2, 64, 64)) 152 | 153 |         self.model.model.og_grid = None 154 |         self.model.model.transformed_grid = transformed_grid.unsqueeze(0).to(self.model.device) 155 | 156 |         scale = 1.0 157 |         C, f, H, W= 4, 8, 512, 512 158 |         n_samples = 1 159 |         ddim_steps = steps 160 |         ddim_eta = 1.0 161 |         step = start_step 162 | 163 |         with torch.no_grad(): 164 |             with autocast("cuda"): 165 |                 with self.model.ema_scope(): 166 |                     image_tensor = get_tensor(toTensor=False)(coarse_edit) 167 | 168 |                     clean_ref_tensor = get_tensor(toTensor=False)(ref_image) 169 |                     clean_ref_tensor = clean_ref_tensor.unsqueeze(0) 170 | 171 |                     ref_tensor=get_tensor_dino(toTensor=False)(ref_image).unsqueeze(0) 172 | 173 |                     b_mask = mask_tensor.cpu() < 0.5 174 | 175 |                     # inpainting 176 |                     reference = un_norm(image_tensor) 177 |                     reference = reference.squeeze() 178 |                     ref_cv = torch.moveaxis(reference, 0, -1).cpu().numpy() 179 |                     ref_cv = (ref_cv * 255).astype(np.uint8) 180 | 181 |                     cv_mask = b_mask.int().squeeze().cpu().numpy().astype(np.uint8) 182 |                     kernel = np.ones((7,7)) 183 |                     dilated_mask = cv2.dilate(cv_mask, kernel) 184 | 185 |                     dst = cv2.inpaint(ref_cv,dilated_mask,3,cv2.INPAINT_NS) 186 |                     # dst = inpaint.inpaint_biharmonic(ref_cv, dilated_mask, channel_axis=-1) 187 |                     dst_tensor = torch.tensor(dst).moveaxis(-1, 0) / 255.0 188 |                     image_tensor = (dst_tensor * 2.0) - 1.0 189 |                     image_tensor = image_tensor.unsqueeze(0) 190 | 191 |                     ref_tensor = ref_tensor 192 | 193 |                     inpaint_image = image_tensor#*mask_tensor 194 | 195 |                     test_model_kwargs={} 196 | 
test_model_kwargs['inpaint_mask']=mask_tensor.to(device) 197 | test_model_kwargs['inpaint_image']=inpaint_image.to(device) 198 | clean_ref_tensor = clean_ref_tensor.to(device) 199 | ref_tensor=ref_tensor.to(device) 200 | uc = None 201 | if scale != 1.0: 202 | uc = self.model.learnable_vector 203 | c = self.model.get_learned_conditioning(ref_tensor.to(torch.float16)) 204 | c = self.model.proj_out(c) 205 | 206 | z_inpaint = self.model.encode_first_stage(test_model_kwargs['inpaint_image']) 207 | z_inpaint = self.model.get_first_stage_encoding(z_inpaint).detach() 208 | 209 | 210 | z_ref = self.model.encode_first_stage(clean_ref_tensor) 211 | z_ref = self.model.get_first_stage_encoding(z_ref).detach() 212 | 213 | test_model_kwargs['inpaint_image']=z_inpaint 214 | test_model_kwargs['inpaint_mask']=Resize([z_inpaint.shape[-2],z_inpaint.shape[-1]])(test_model_kwargs['inpaint_mask']) 215 | 216 | 217 | shape = [C, H // f, W // f] 218 | 219 | samples_ddim, _ = sampler.sample(S=ddim_steps, 220 | conditioning=c, 221 | z_ref=z_ref, 222 | batch_size=n_samples, 223 | shape=shape, 224 | verbose=False, 225 | unconditional_guidance_scale=scale, 226 | unconditional_conditioning=uc, 227 | eta=ddim_eta, 228 | x_T=start_code, 229 | test_model_kwargs=test_model_kwargs, 230 | x0=z_inpaint, 231 | x0_step=step, 232 | ddim_discretize='uniform', 233 | drop_latent_guidance=1.0 234 | ) 235 | 236 | 237 | x_samples_ddim = self.model.decode_first_stage(samples_ddim) 238 | x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0) 239 | x_samples_ddim = x_samples_ddim.cpu().permute(0, 2, 3, 1).numpy() 240 | 241 | x_checked_image=x_samples_ddim 242 | x_checked_image_torch = torch.from_numpy(x_checked_image).permute(0, 3, 1, 2) 243 | 244 | 245 | return x_checked_image_torch 246 | #%% 247 | 248 | 249 | #%% 250 | import time 251 | 252 | 253 | 254 | # %% 255 | def file_exists(path): 256 | """ Check if a file exists and is not a directory. """ 257 | if not os.path.isfile(path): 258 | raise argparse.ArgumentTypeError(f"{path} is not a valid file.") 259 | return path 260 | 261 | def parse_arguments(): 262 | """ Parses command-line arguments. """ 263 | parser = argparse.ArgumentParser(description="Process images based on provided paths.") 264 | parser.add_argument("--checkpoint", type=file_exists, required=True, help="Path to the MagicFixup checkpoint file.") 265 | parser.add_argument("--reference", type=file_exists, default='examples/fox_drinking_og.png', help="Path to the reference original image.") 266 | parser.add_argument("--edit", type=file_exists, default='examples/fox_drinking__edit__01.png', help="Path to the image edit. 
Make sure the alpha channel is set properly") 267 | parser.add_argument("--output-dir", type=str, default='./outputs', help="Path to the folder where to save the outputs") 268 | parser.add_argument("--samples", type=int, default=5, help="number of samples to output") 269 | 270 | return parser.parse_args() 271 | 272 | 273 | def main(): 274 | # Parse arguments 275 | args = parse_arguments() 276 | 277 | # create magic fixup model 278 | magic_fixup = MagicFixup(model_path=args.checkpoint) 279 | output_dir = args.output_dir 280 | 281 | os.makedirs(output_dir, exist_ok=True) 282 | 283 | # run it here 284 | 285 | to_tensor = torchvision.transforms.ToTensor() 286 | 287 | 288 | 289 | ref_path = args.reference 290 | coarse_edit_path = args.edit 291 | mask_edit_path = coarse_edit_path 292 | 293 | edit_file_name = pathlib.Path(coarse_edit_path).stem 294 | save_pattern = f'{output_dir}/{edit_file_name}__sample__*.png' 295 | save_counter = len(glob.glob(save_pattern)) 296 | 297 | all_rgbs = [] 298 | for i in range(args.samples): 299 | with autocast("cuda"): 300 | ref_image_t = to_tensor(Image.open(ref_path).convert('RGB').resize((512,512))).half().cuda() 301 | coarse_edit_t = to_tensor(Image.open(coarse_edit_path).resize((512,512))).half().cuda() 302 | # get mask from coarse 303 | # mask_t = torch.ones_like(coarse_edit_t[-1][None, None,...]) 304 | coarse_edit_mask_t = to_tensor(Image.open(mask_edit_path).resize((512,512))).half().cuda() 305 | # get mask from coarse 306 | mask_t = (coarse_edit_mask_t[-1][None, None,...]).half() # do center crop 307 | coarse_edit_t_rgb = coarse_edit_t[:-1] 308 | 309 | out_rgb = magic_fixup.edit_image(ref_image_t, coarse_edit_t_rgb, mask_t, start_step=1.0, steps=50) 310 | all_rgbs.append(out_rgb.squeeze().cpu().detach().float()) 311 | 312 | save_path = f'{output_dir}/{edit_file_name}__sample__{save_counter:03d}.png' 313 | torchvision.utils.save_image(all_rgbs[i], save_path) 314 | save_counter += 1 315 | 316 | 317 | 318 | if __name__ == "__main__": 319 | main() -------------------------------------------------------------------------------- /scripts/combine_model_params.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Adobe. All rights reserved. 
2 | #%% 3 | import matplotlib.pyplot as plt 4 | import cv2 5 | import torch 6 | import numpy as np 7 | from omegaconf import OmegaConf 8 | from PIL import Image 9 | from tqdm import tqdm, trange 10 | from imwatermark import WatermarkEncoder 11 | from itertools import islice 12 | import time 13 | from pytorch_lightning import seed_everything 14 | from torch import autocast 15 | import torchvision 16 | from ldm.util import instantiate_from_config 17 | from ldm.models.diffusion.ddim import DDIMSampler 18 | from torchvision.transforms import Resize 19 | import argparse 20 | import os 21 | import pathlib 22 | import glob 23 | import tqdm 24 | 25 | 26 | device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") 27 | 28 | def load_model_from_config(config): 29 | model = instantiate_from_config(config.model) 30 | 31 | model.cuda() 32 | model.eval() 33 | return model 34 | 35 | 36 | def get_model(config_path, ckpt_path, pretrained_sd_path): 37 | config = OmegaConf.load(f"{config_path}") 38 | model = load_model_from_config(config) 39 | model.load_state_dict(torch.load(pretrained_sd_path,map_location='cpu')['state_dict'],strict=False) 40 | 41 | 42 | pl_sd = torch.load(ckpt_path, map_location="cpu") 43 | wrapped_state_dict = pl_sd #self.lightning_module.trainer.model.state_dict() 44 | new_sd = {k.replace("_forward_module.", ""): wrapped_state_dict[k] for k in wrapped_state_dict} 45 | 46 | m, u = model.load_state_dict(new_sd, strict=False) 47 | if len(m) > 0: 48 | print("missing keys:") 49 | print(m) 50 | if len(u) > 0: 51 | print("unexpected keys:") 52 | print(u) 53 | 54 | 55 | model = model.to(device) 56 | return model 57 | 58 | def file_exists(path): 59 | """ Check if a file exists and is not a directory. """ 60 | if not os.path.isfile(path): 61 | raise argparse.ArgumentTypeError(f"{path} is not a valid file.") 62 | return path 63 | 64 | def parse_arguments(): 65 | """ Parses command-line arguments. """ 66 | parser = argparse.ArgumentParser(description="Process images based on provided paths.") 67 | parser.add_argument("--pretrained_sd", type=file_exists, required=True, help="Path to the SD1.4 pretrained checkpoint") 68 | parser.add_argument("--learned_params", type=file_exists, required=True, help="Path to the MagicFixup learned parameters.") 69 | parser.add_argument("--save_path", type=str, required=True, help="Path to save the full model state dict") 70 | 71 | return parser.parse_args() 72 | 73 | def main(): 74 | args = parse_arguments() 75 | model = get_model('configs/collage_mix_train.yaml',args.learned_params, args.pretrained_sd) 76 | torch.save(model.state_dict(), args.save_path) 77 | 78 | if __name__ == '__main__': 79 | main() -------------------------------------------------------------------------------- /scripts/modify_checkpoints.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 Adobe. All rights reserved. 
2 | import torch 3 | pretrained_model_path='pretrained_models/sd-v1-4.ckpt'  # stock SD-1.4 checkpoint: the UNet input conv takes 4 latent channels 4 | ckpt_file=torch.load(pretrained_model_path,map_location='cpu') 5 | zero_data=torch.zeros(320,5,3,3)  # 5 extra zero-initialized input channels for each of the 320 first-layer filters 6 | new_weight=torch.cat((ckpt_file['state_dict']['model.diffusion_model.input_blocks.0.0.weight'],zero_data),dim=1)  # 4 original + 5 new = 9 input channels 7 | ckpt_file['state_dict']['model.diffusion_model.input_blocks.0.0.weight']=new_weight 8 | torch.save(ckpt_file,"pretrained_models/sd-v1-4-modified-9channel.ckpt")  # expanded checkpoint consumed by train.sh -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | setup( 4 | name='Magic-FU', 5 | version='0.0.1', 6 | description='', 7 | packages=find_packages(), 8 | install_requires=[ 9 | 'torch', 10 | 'numpy', 11 | 'tqdm', 12 | ], 13 | ) 14 | -------------------------------------------------------------------------------- /train.sh: -------------------------------------------------------------------------------- 1 | python -u main.py \ 2 | --logdir models/Paint-by-Example \ 3 | --pretrained_model pretrained_models/sd-v1-4-modified-9channel.ckpt \ 4 | --base configs/collage_mix_train.yaml \ 5 | --scale_lr False \ 6 | --name collage_mix_magic_fixup 7 | --------------------------------------------------------------------------------
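
For reference, a minimal end-to-end usage sketch assembled from the scripts above; the trained-parameter and merged-checkpoint filenames (<learned_params.pt>, magic_fixup_full.pt) are placeholders, and pretrained_models/sd-v1-4.ckpt must be obtained separately:

# 1. Expand the SD-1.4 input conv to the 9-channel layout used by training
python scripts/modify_checkpoints.py

# 2. Fine-tune (train.sh wraps main.py with the collage_mix config)
sh train.sh

# 3. Merge the learned parameters with the SD-1.4 weights into one inference checkpoint
python scripts/combine_model_params.py --pretrained_sd pretrained_models/sd-v1-4.ckpt --learned_params <learned_params.pt> --save_path magic_fixup_full.pt

# 4. Sample edits from a reference image and a coarse RGBA edit
python run_magicfu.py --checkpoint magic_fixup_full.pt --reference examples/fox_drinking_og.png --edit examples/fox_drinking__edit__01.png --output-dir ./outputs --samples 5

# 5. Or launch the Gradio demo
python magicfu_gradio.py --checkpoint magic_fixup_full.pt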