├── .gitignore ├── LICENSE ├── README.md ├── app.py ├── dreamo ├── dreamo_pipeline.py ├── transformer.py └── utils.py ├── example_inputs ├── cat.png ├── dog1.png ├── dog2.png ├── dress.png ├── hinton.jpeg ├── man1.png ├── man2.jpeg ├── mickey.png ├── mountain.png ├── perfume.png ├── shirt.png ├── skirt.jpeg ├── toy1.png ├── woman1.png ├── woman2.png ├── woman3.png └── woman4.jpeg ├── models └── .gitkeep ├── pyproject.toml ├── requirements.txt └── tools └── BEN2.py /.gitignore: -------------------------------------------------------------------------------- 1 | datasets/* 2 | experiments/* 3 | results/* 4 | tb_logger/* 5 | wandb/* 6 | tmp/* 7 | weights/* 8 | inputs/* 9 | models/* 10 | 11 | *.DS_Store 12 | 13 | # Byte-compiled / optimized / DLL files 14 | __pycache__/ 15 | *.py[cod] 16 | *$py.class 17 | 18 | # C extensions 19 | *.so 20 | 21 | # Distribution / packaging 22 | .Python 23 | build/ 24 | develop-eggs/ 25 | dist/ 26 | downloads/ 27 | eggs/ 28 | .eggs/ 29 | lib/ 30 | lib64/ 31 | parts/ 32 | sdist/ 33 | var/ 34 | wheels/ 35 | pip-wheel-metadata/ 36 | share/python-wheels/ 37 | *.egg-info/ 38 | .installed.cfg 39 | *.egg 40 | MANIFEST 41 | 42 | # PyInstaller 43 | # Usually these files are written by a python script from a template 44 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 45 | *.manifest 46 | *.spec 47 | 48 | # Installer logs 49 | pip-log.txt 50 | pip-delete-this-directory.txt 51 | 52 | # Unit test / coverage reports 53 | htmlcov/ 54 | .tox/ 55 | .nox/ 56 | .coverage 57 | .coverage.* 58 | .cache 59 | nosetests.xml 60 | coverage.xml 61 | *.cover 62 | *.py,cover 63 | .hypothesis/ 64 | .pytest_cache/ 65 | 66 | # Translations 67 | *.mo 68 | *.pot 69 | 70 | # Django stuff: 71 | *.log 72 | local_settings.py 73 | db.sqlite3 74 | db.sqlite3-journal 75 | 76 | # Flask stuff: 77 | instance/ 78 | .webassets-cache 79 | 80 | # Scrapy stuff: 81 | .scrapy 82 | 83 | # Sphinx documentation 84 | docs/_build/ 85 | 86 | # PyBuilder 87 | target/ 88 | 89 | # Jupyter Notebook 90 | .ipynb_checkpoints 91 | 92 | # IPython 93 | profile_default/ 94 | ipython_config.py 95 | 96 | # pyenv 97 | .python-version 98 | 99 | # pipenv 100 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 101 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 102 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 103 | # install all needed dependencies. 104 | #Pipfile.lock 105 | 106 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 107 | __pypackages__/ 108 | 109 | # Celery stuff 110 | celerybeat-schedule 111 | celerybeat.pid 112 | 113 | # SageMath parsed files 114 | *.sage.py 115 | 116 | # Environments 117 | .env 118 | .venv 119 | env/ 120 | venv/ 121 | ENV/ 122 | env.bak/ 123 | venv.bak/ 124 | 125 | # Spyder project settings 126 | .spyderproject 127 | .spyproject 128 | 129 | # Rope project settings 130 | .ropeproject 131 | 132 | # mkdocs documentation 133 | /site 134 | 135 | # mypy 136 | .mypy_cache/ 137 | .dmypy.json 138 | dmypy.json 139 | 140 | # Pyre type checker 141 | .pyre/ 142 | 143 | .idea/ -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 
8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. 
Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 
179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # DreamO 2 | 3 | Official implementation of **[DreamO: A Unified Framework for Image Customization](https://arxiv.org/abs/2504.16915)** 4 | 5 | [![arXiv](https://img.shields.io/badge/arXiv-Paper-.svg)](https://arxiv.org/abs/2504.16915) [![demo](https://img.shields.io/badge/🤗-HuggingFace_Demo-orange)](https://huggingface.co/spaces/ByteDance/DreamO)
6 | 7 | ### :triangular_flag_on_post: Updates 8 | * **2025.05.30**: 🔥🔥 Native [ComfyUI implementation](https://github.com/ToTheBeginning/ComfyUI-DreamO) is now available! 9 | * **2025.05.12**: 🔥 Support consumer-grade GPUs (16GB or 24GB) now, see [here](#for-consumer-grade-gpus) for instruction 10 | * **2025.05.11**: 🔥 **We have updated the model to mitigate over-saturation and plastic-face issue**. The new version shows consistent improvements over the previous release. Please check it out! 11 | * **2025.05.08**: release codes and models 12 | * 2025.04.24: release DreamO tech report. 13 | 14 | https://github.com/user-attachments/assets/385ba166-79df-40d3-bcd7-5472940fa24a 15 | 16 | ## :wrench: Dependencies and Installation 17 | ```bash 18 | # clone DreamO repo 19 | git clone https://github.com/bytedance/DreamO.git 20 | cd DreamO 21 | # create conda env 22 | conda create --name dreamo python=3.10 23 | # activate env 24 | conda activate dreamo 25 | # install dependent packages 26 | pip install -r requirements.txt 27 | ``` 28 | 29 | 30 | ## :zap: Quick Inference 31 | ### Local Gradio Demo 32 | ```bash 33 | python app.py 34 | ``` 35 | We observe strong compatibility between DreamO and the accelerated FLUX LoRA variant 36 | ([FLUX-turbo](https://huggingface.co/alimama-creative/FLUX.1-Turbo-Alpha)), and thus enable Turbo LoRA by default, 37 | reducing inference to 12 steps (vs. 25+ by default). Turbo can be disabled via `--no_turbo`, though our evaluation shows mixed results; 38 | we therefore recommend keeping Turbo enabled. 39 | 40 | **tips**: If you observe limb distortion or poor text generation, try increasing the guidance scale; if the image appears overly glossy or over-saturated, consider lowering the guidance scale. 41 | 42 | #### For consumer-grade GPUs 43 | We have added support for 8-bit quantization and CPU offload to enable execution on consumer-grade GPUs. This requires the `optimum-quanto` library, and thus the PyTorch version in `requirements.txt` has been upgraded to 2.6.0. If you are using an older version of PyTorch, you may need to reconfigure your environment. 44 | 45 | - **For users with 24GB GPUs**, run `python app.py --int8` to enable the int8-quantized model. 46 | 47 | - **For users with 16GB GPUs**, run `python app.py --int8 --offload` to enable CPU offloading alongside int8 quantization. Note that CPU offload significantly reduces inference speed and should only be enabled when necessary. 48 | 49 | #### For macOS Apple Silicon (M1/M2/M3/M4) 50 | DreamO now supports macOS with Apple Silicon chips using Metal Performance Shaders (MPS). The app automatically detects and uses MPS when available. 51 | 52 | - **For macOS users**, simply run `python app.py` and the app will automatically use MPS acceleration. 53 | - **Manual device selection**: You can explicitly specify the device using `python app.py --device mps` (or `--device cpu` if needed). 54 | - **Memory optimization**: For devices with limited memory, you can combine MPS with quantization: `python app.py --device mps --int8` 55 | 56 | **Note**: Make sure you have PyTorch with MPS support installed. The current requirements.txt includes PyTorch 2.6.0+ which has full MPS support. 57 | 58 | ### Supported Tasks 59 | #### IP 60 | This task is similar to IP-Adapter and supports a wide range of inputs including characters, objects, and animals. 61 | By leveraging VAE-based feature encoding, DreamO achieves higher fidelity than previous adapter methods, with a distinct advantage in preserving character identity. 
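If you prefer to call DreamO from a script rather than through the Gradio demo, the snippet below is a minimal sketch of IP-task inference distilled from `app.py`. It is illustrative only and makes a few assumptions: it skips the BEN2 background-removal step the demo applies to IP references, uses one of the bundled `example_inputs` images, and hard-codes the device and seed.

```python
import numpy as np
import torch
from PIL import Image

from dreamo.dreamo_pipeline import DreamOPipeline
from dreamo.utils import img2tensor, resize_numpy_image_area

device = torch.device('cuda')  # or 'mps' / 'cpu'

# load FLUX.1-dev and apply the DreamO LoRAs (turbo enabled, as in the demo)
pipe = DreamOPipeline.from_pretrained('black-forest-labs/FLUX.1-dev', torch_dtype=torch.bfloat16)
pipe.load_dreamo_model(device, use_turbo=True)
pipe = pipe.to(device)
pipe.offload = False

# prepare one IP reference; app.py additionally removes the background with BEN2 first
ref = np.array(Image.open('example_inputs/perfume.png').convert('RGB'))
ref = resize_numpy_image_area(ref, 512 * 512)             # demo default reference resolution
ref = img2tensor(ref, bgr2rgb=False).unsqueeze(0) / 255.0
ref = 2 * ref - 1.0                                        # normalize to [-1, 1]

image = pipe(
    prompt='a perfume under spotlight',
    width=1024,
    height=1024,
    num_inference_steps=12,                                # 12 steps with the turbo LoRA
    guidance_scale=3.5,
    ref_conds=[{'img': ref, 'task': 'ip', 'idx': 1}],
    generator=torch.Generator(device='cpu').manual_seed(42),
).images[0]
image.save('result.png')
```

For ID references, `app.py` first crops and aligns the face with facexlib instead of removing the background; for Style references the background is kept, and the prompt must begin with 'generate a same style image.'.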
62 | 63 | ![IP_example](https://github.com/user-attachments/assets/086ceabd-338b-4fef-ad1f-bab6b30a1160) 64 | 65 | #### ID 66 | Here, ID specifically refers to facial identity. Unlike the IP task, which considers both face and clothing, 67 | the ID task focuses solely on facial features. This task is similar to InstantID and PuLID. 68 | Compared to previous methods, DreamO achieves higher facial fidelity, but introduces more model contamination than the SOTA approach PuLID. 69 | 70 | ![ID_example](https://github.com/user-attachments/assets/392dd325-d4f4-4abb-9718-4b16fe7844c6) 71 | 72 | tips: If you notice the face appears overly glossy, try lowering the guidance scale. 73 | 74 | #### Try-On 75 | This task supports inputs such as tops, bottoms, glasses, and hats, and enables virtual try-on with multiple garments. 76 | Notably, our training set does not include multi-garment or ID+garment data, yet the model generalizes well to these unseen combinations. 77 | 78 | ![tryon_example](https://github.com/user-attachments/assets/fefec673-110a-44f2-83a9-5b779728a734) 79 | 80 | #### Style 81 | This task is similar to Style-Adapter and InstantStyle. Please note that style consistency is currently less stable compared to other tasks, 82 | and in the current version, style cannot be combined with other conditions. We are working on improvements in future releases—stay tuned. 83 | 84 | ![style_example](https://github.com/user-attachments/assets/0a31674a-c3c2-451f-91e4-c521659d40f3) 85 | 86 | #### Multi Condition 87 | You can use multiple conditions (ID, IP, Try-On) to generate more creative images. 88 | Thanks to the feature routing constraint proposed in the paper, DreamO effectively mitigates conflicts and entanglement among multiple entities. 89 | 90 | ![multi_cond_example](https://github.com/user-attachments/assets/e43e6ebb-a028-4b29-b76d-3eaa1e69b9c9) 91 | 92 | ### ComfyUI 93 | - native ComfyUI support: [ComfyUI-DreamO](https://github.com/ToTheBeginning/ComfyUI-DreamO) 94 | 95 | 96 | ### Online HuggingFace Demo 97 | You can try DreamO demo on [HuggingFace](https://huggingface.co/spaces/ByteDance/DreamO). 98 | 99 | 100 | ## Disclaimer 101 | 102 | This project strives to impact the domain of AI-driven image generation positively. Users are granted the freedom to 103 | create images using this tool, but they are expected to comply with local laws and utilize it responsibly. 104 | The developers do not assume any responsibility for potential misuse by users. 105 | 106 | 107 | ## Citation 108 | 109 | If DreamO is helpful, please help to ⭐ the repo. 110 | 111 | If you find this project useful for your research, please consider citing our [paper](https://arxiv.org/abs/2504.16915). 112 | 113 | ## :e-mail: Contact 114 | If you have any comments or questions, please [open a new issue](https://github.com/xxx/xxx/issues/new/choose) or contact [Yanze Wu](https://tothebeginning.github.io/) and [Chong Mou](mailto:eechongm@gmail.com). -------------------------------------------------------------------------------- /app.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | # import os 16 | # os.environ["TORCH_CUDA_ARCH_LIST"] = "8.0" 17 | import argparse 18 | 19 | import cv2 20 | import gradio as gr 21 | import numpy as np 22 | import torch 23 | from facexlib.utils.face_restoration_helper import FaceRestoreHelper 24 | from huggingface_hub import hf_hub_download 25 | from optimum.quanto import freeze, qint8, quantize 26 | from PIL import Image 27 | from torchvision.transforms.functional import normalize 28 | 29 | from dreamo.dreamo_pipeline import DreamOPipeline 30 | from dreamo.utils import ( 31 | img2tensor, 32 | resize_numpy_image_area, 33 | resize_numpy_image_long, 34 | tensor2img, 35 | ) 36 | from tools import BEN2 37 | 38 | parser = argparse.ArgumentParser() 39 | parser.add_argument('--port', type=int, default=8080) 40 | parser.add_argument('--no_turbo', action='store_true') 41 | parser.add_argument('--int8', action='store_true') 42 | parser.add_argument('--offload', action='store_true') 43 | parser.add_argument('--device', type=str, default='auto', help='Device to use: auto, cuda, mps, or cpu') 44 | args = parser.parse_args() 45 | 46 | 47 | def get_device(): 48 | """Automatically detect the best available device""" 49 | if args.device != 'auto': 50 | return torch.device(args.device) 51 | 52 | if torch.cuda.is_available(): 53 | return torch.device('cuda') 54 | elif torch.backends.mps.is_available(): 55 | return torch.device('mps') 56 | else: 57 | return torch.device('cpu') 58 | 59 | 60 | class Generator: 61 | def __init__(self): 62 | self.device = get_device() 63 | print(f"Using device: {self.device}") 64 | 65 | # preprocessing models 66 | # background remove model: BEN2 67 | self.bg_rm_model = BEN2.BEN_Base().to(self.device).eval() 68 | hf_hub_download(repo_id='PramaLLC/BEN2', filename='BEN2_Base.pth', local_dir='models') 69 | self.bg_rm_model.loadcheckpoints('models/BEN2_Base.pth') 70 | # face crop and align tool: facexlib 71 | self.face_helper = FaceRestoreHelper( 72 | upscale_factor=1, 73 | face_size=512, 74 | crop_ratio=(1, 1), 75 | det_model='retinaface_resnet50', 76 | save_ext='png', 77 | device=self.device, 78 | ) 79 | if args.offload: 80 | self.ben_to_device(torch.device('cpu')) 81 | self.facexlib_to_device(torch.device('cpu')) 82 | 83 | # load dreamo 84 | model_root = 'black-forest-labs/FLUX.1-dev' 85 | dreamo_pipeline = DreamOPipeline.from_pretrained(model_root, torch_dtype=torch.bfloat16) 86 | dreamo_pipeline.load_dreamo_model(self.device, use_turbo=not args.no_turbo) 87 | if args.int8: 88 | print('start quantize') 89 | quantize(dreamo_pipeline.transformer, qint8) 90 | freeze(dreamo_pipeline.transformer) 91 | quantize(dreamo_pipeline.text_encoder_2, qint8) 92 | freeze(dreamo_pipeline.text_encoder_2) 93 | print('done quantize') 94 | self.dreamo_pipeline = dreamo_pipeline.to(self.device) 95 | if args.offload: 96 | self.dreamo_pipeline.enable_model_cpu_offload() 97 | self.dreamo_pipeline.offload = True 98 | else: 99 | self.dreamo_pipeline.offload = False 100 | 101 | def ben_to_device(self, device): 102 | self.bg_rm_model.to(device) 103 | 104 | def facexlib_to_device(self, device): 105 | 
self.face_helper.face_det.to(device) 106 | self.face_helper.face_parse.to(device) 107 | 108 | @torch.no_grad() 109 | def get_align_face(self, img): 110 | # the face preprocessing code is same as PuLID 111 | self.face_helper.clean_all() 112 | image_bgr = cv2.cvtColor(img, cv2.COLOR_RGB2BGR) 113 | self.face_helper.read_image(image_bgr) 114 | self.face_helper.get_face_landmarks_5(only_center_face=True) 115 | self.face_helper.align_warp_face() 116 | if len(self.face_helper.cropped_faces) == 0: 117 | return None 118 | align_face = self.face_helper.cropped_faces[0] 119 | 120 | input = img2tensor(align_face, bgr2rgb=True).unsqueeze(0) / 255.0 121 | input = input.to(self.device) 122 | parsing_out = self.face_helper.face_parse(normalize(input, [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]))[0] 123 | parsing_out = parsing_out.argmax(dim=1, keepdim=True) 124 | bg_label = [0, 16, 18, 7, 8, 9, 14, 15] 125 | bg = sum(parsing_out == i for i in bg_label).bool() 126 | white_image = torch.ones_like(input) 127 | # only keep the face features 128 | face_features_image = torch.where(bg, white_image, input) 129 | face_features_image = tensor2img(face_features_image, rgb2bgr=False) 130 | 131 | return face_features_image 132 | 133 | 134 | generator = Generator() 135 | 136 | 137 | @torch.inference_mode() 138 | def generate_image( 139 | ref_image1, 140 | ref_image2, 141 | ref_task1, 142 | ref_task2, 143 | prompt, 144 | width, 145 | height, 146 | ref_res, 147 | num_steps, 148 | guidance, 149 | seed, 150 | true_cfg, 151 | cfg_start_step, 152 | cfg_end_step, 153 | neg_prompt, 154 | neg_guidance, 155 | first_step_guidance, 156 | ): 157 | print(prompt) 158 | ref_conds = [] 159 | debug_images = [] 160 | 161 | ref_images = [ref_image1, ref_image2] 162 | ref_tasks = [ref_task1, ref_task2] 163 | 164 | for idx, (ref_image, ref_task) in enumerate(zip(ref_images, ref_tasks)): 165 | if ref_image is not None: 166 | if ref_task == "id": 167 | if args.offload: 168 | generator.facexlib_to_device(generator.device) 169 | ref_image = resize_numpy_image_long(ref_image, 1024) 170 | ref_image = generator.get_align_face(ref_image) 171 | if args.offload: 172 | generator.facexlib_to_device(torch.device('cpu')) 173 | elif ref_task != "style": 174 | if args.offload: 175 | generator.ben_to_device(generator.device) 176 | ref_image = generator.bg_rm_model.inference(Image.fromarray(ref_image)) 177 | if args.offload: 178 | generator.ben_to_device(torch.device('cpu')) 179 | if ref_task != "id": 180 | ref_image = resize_numpy_image_area(np.array(ref_image), ref_res * ref_res) 181 | debug_images.append(ref_image) 182 | ref_image = img2tensor(ref_image, bgr2rgb=False).unsqueeze(0) / 255.0 183 | ref_image = 2 * ref_image - 1.0 184 | ref_conds.append( 185 | { 186 | 'img': ref_image, 187 | 'task': ref_task, 188 | 'idx': idx + 1, 189 | } 190 | ) 191 | 192 | seed = int(seed) 193 | if seed == -1: 194 | seed = torch.Generator(device="cpu").seed() 195 | 196 | image = generator.dreamo_pipeline( 197 | prompt=prompt, 198 | width=width, 199 | height=height, 200 | num_inference_steps=num_steps, 201 | guidance_scale=guidance, 202 | ref_conds=ref_conds, 203 | generator=torch.Generator(device="cpu").manual_seed(seed), 204 | true_cfg_scale=true_cfg, 205 | true_cfg_start_step=cfg_start_step, 206 | true_cfg_end_step=cfg_end_step, 207 | negative_prompt=neg_prompt, 208 | neg_guidance_scale=neg_guidance, 209 | first_step_guidance_scale=first_step_guidance if first_step_guidance > 0 else guidance, 210 | ).images[0] 211 | 212 | return image, debug_images, seed 213 | 214 | 215 
| _HEADER_ = ''' 216 | <div style="text-align: center;"> 217 | <h1>DreamO</h1> 218 | <p>Paper: <a href='https://arxiv.org/abs/2504.16915'>DreamO: A Unified Framework for Image Customization</a> | Codes: <a href='https://github.com/bytedance/DreamO'>GitHub</a></p> 219 | </div>
220 | 221 | 🚩 Update Notes: 222 | - 2025.05.11: We have updated the model to mitigate over-saturation and plastic-face issues. The new version shows consistent improvements over the previous release. 223 | 224 | ❗️❗️❗️**User Guide:** 225 | - The most important thing to do first is to try the examples provided below the demo, which will help you better understand the capabilities of the DreamO model and the types of tasks it currently supports. 226 | - For each input, please select the appropriate task type. For general objects, characters, or clothing, choose IP — we will remove the background from the input image. If you select ID, we will extract the face region from the input image (similar to PuLID). If you select Style, the background will be preserved, and you must prepend the prompt with the instruction: 'generate a same style image.' to activate the style task. 227 | - The most important hyperparameter in this demo is the guidance scale, which is set to 3.5 by default. If you notice that faces appear overly glossy or unrealistic—especially in ID tasks—you can lower the guidance scale (e.g., to 3). Conversely, if text rendering is poor or limb distortion occurs, increasing the guidance scale (e.g., to 4) may help. 228 | - To accelerate inference, we adopt FLUX-turbo LoRA, which reduces the sampling steps from 25 to 12 compared to FLUX-dev. Additionally, we distill a CFG LoRA, achieving nearly a twofold reduction in steps by eliminating the need for true CFG. 229 | 230 | ''' # noqa E501 231 | 232 | _CITE_ = r""" 233 | If DreamO is helpful, please help to ⭐ the GitHub repo. Thanks! 234 | --- 235 | 236 | 📧 **Contact** 237 | If you have any questions or feedback, feel free to open a discussion or contact wuyanze123@gmail.com and eechongm@gmail.com 238 | """ # noqa E501 239 | 240 | 241 | def create_demo(): 242 | 243 | with gr.Blocks() as demo: 244 | gr.Markdown(_HEADER_) 245 | 246 | with gr.Row(): 247 | with gr.Column(): 248 | with gr.Row(): 249 | ref_image1 = gr.Image(label="ref image 1", type="numpy", height=256) 250 | ref_image2 = gr.Image(label="ref image 2", type="numpy", height=256) 251 | with gr.Row(): 252 | ref_task1 = gr.Dropdown(choices=["ip", "id", "style"], value="ip", label="task for ref image 1") 253 | ref_task2 = gr.Dropdown(choices=["ip", "id", "style"], value="ip", label="task for ref image 2") 254 | prompt = gr.Textbox(label="Prompt", value="a person playing guitar in the street") 255 | width = gr.Slider(768, 1024, 1024, step=16, label="Width") 256 | height = gr.Slider(768, 1024, 1024, step=16, label="Height") 257 | num_steps = gr.Slider(8, 30, 12, step=1, label="Number of steps") 258 | guidance = gr.Slider(1.0, 10.0, 3.5, step=0.1, label="Guidance") 259 | seed = gr.Textbox(label="Seed (-1 for random)", value="-1") 260 | with gr.Accordion("Advanced Options", open=False, visible=False): 261 | ref_res = gr.Slider(512, 1024, 512, step=16, label="resolution for ref image") 262 | neg_prompt = gr.Textbox(label="Neg Prompt", value="") 263 | neg_guidance = gr.Slider(1.0, 10.0, 3.5, step=0.1, label="Neg Guidance") 264 | true_cfg = gr.Slider(1, 5, 1, step=0.1, label="true cfg") 265 | cfg_start_step = gr.Slider(0, 30, 0, step=1, label="cfg start step") 266 | cfg_end_step = gr.Slider(0, 30, 0, step=1, label="cfg end step") 267 | first_step_guidance = gr.Slider(0, 10, 0, step=0.1, label="first step guidance") 268 | generate_btn = gr.Button("Generate") 269 | gr.Markdown(_CITE_) 270 | 271 | with gr.Column(): 272 | output_image = gr.Image(label="Generated Image", format='png') 273 |
debug_image = gr.Gallery( 274 | label="Preprocessing output (including possible face crop and background remove)", 275 | elem_id="gallery", 276 | ) 277 | seed_output = gr.Textbox(label="Used Seed") 278 | 279 | with gr.Row(), gr.Column(): 280 | gr.Markdown("## Examples") 281 | example_inps = [ 282 | [ 283 | 'example_inputs/woman1.png', 284 | 'ip', 285 | 'profile shot dark photo of a 25-year-old female with smoke escaping from her mouth, the backlit smoke gives the image an ephemeral quality, natural face, natural eyebrows, natural skin texture, award winning photo, highly detailed face, atmospheric lighting, film grain, monochrome', # noqa E501 286 | 9180879731249039735, 287 | ], 288 | [ 289 | 'example_inputs/man1.png', 290 | 'ip', 291 | 'a man sitting on the cloud, playing guitar', 292 | 1206523688721442817, 293 | ], 294 | [ 295 | 'example_inputs/toy1.png', 296 | 'ip', 297 | 'a purple toy holding a sign saying "DreamO", on the mountain', 298 | 10441727852953907380, 299 | ], 300 | [ 301 | 'example_inputs/perfume.png', 302 | 'ip', 303 | 'a perfume under spotlight', 304 | 116150031980664704, 305 | ], 306 | ] 307 | gr.Examples(examples=example_inps, inputs=[ref_image1, ref_task1, prompt, seed], label='IP task') 308 | 309 | example_inps = [ 310 | [ 311 | 'example_inputs/hinton.jpeg', 312 | 'id', 313 | 'portrait, Chibi', 314 | 5443415087540486371, 315 | ], 316 | ] 317 | gr.Examples( 318 | examples=example_inps, 319 | inputs=[ref_image1, ref_task1, prompt, seed], 320 | label='ID task (similar to PuLID, will only refer to the face)', 321 | ) 322 | 323 | example_inps = [ 324 | [ 325 | 'example_inputs/mickey.png', 326 | 'style', 327 | 'generate a same style image. A rooster wearing overalls.', 328 | 6245580464677124951, 329 | ], 330 | [ 331 | 'example_inputs/mountain.png', 332 | 'style', 333 | 'generate a same style image. A pavilion by the river, and the distant mountains are endless', 334 | 5248066378927500767, 335 | ], 336 | ] 337 | gr.Examples(examples=example_inps, inputs=[ref_image1, ref_task1, prompt, seed], label='Style task') 338 | 339 | example_inps = [ 340 | [ 341 | 'example_inputs/shirt.png', 342 | 'example_inputs/skirt.jpeg', 343 | 'ip', 344 | 'ip', 345 | 'A girl is wearing a short-sleeved shirt and a short skirt on the beach.', 346 | 9514069256241143615, 347 | ], 348 | [ 349 | 'example_inputs/woman2.png', 350 | 'example_inputs/dress.png', 351 | 'id', 352 | 'ip', 353 | 'the woman wearing a dress, In the banquet hall', 354 | 7698454872441022867, 355 | ], 356 | ] 357 | gr.Examples( 358 | examples=example_inps, 359 | inputs=[ref_image1, ref_image2, ref_task1, ref_task2, prompt, seed], 360 | label='Try-On task', 361 | ) 362 | 363 | example_inps = [ 364 | [ 365 | 'example_inputs/dog1.png', 366 | 'example_inputs/dog2.png', 367 | 'ip', 368 | 'ip', 369 | 'two dogs in the jungle', 370 | 6187006025405083344, 371 | ], 372 | [ 373 | 'example_inputs/woman3.png', 374 | 'example_inputs/cat.png', 375 | 'ip', 376 | 'ip', 377 | 'A girl rides a giant cat, walking in the noisy modern city. High definition, realistic, non-cartoonish. 
Excellent photography work, 8k high definition.', # noqa E501 378 | 11980469406460273604, 379 | ], 380 | [ 381 | 'example_inputs/man2.jpeg', 382 | 'example_inputs/woman4.jpeg', 383 | 'ip', 384 | 'ip', 385 | 'a man is dancing with a woman in the room', 386 | 8303780338601106219, 387 | ], 388 | ] 389 | gr.Examples( 390 | examples=example_inps, 391 | inputs=[ref_image1, ref_image2, ref_task1, ref_task2, prompt, seed], 392 | label='Multi IP', 393 | ) 394 | 395 | generate_btn.click( 396 | fn=generate_image, 397 | inputs=[ 398 | ref_image1, 399 | ref_image2, 400 | ref_task1, 401 | ref_task2, 402 | prompt, 403 | width, 404 | height, 405 | ref_res, 406 | num_steps, 407 | guidance, 408 | seed, 409 | true_cfg, 410 | cfg_start_step, 411 | cfg_end_step, 412 | neg_prompt, 413 | neg_guidance, 414 | first_step_guidance, 415 | ], 416 | outputs=[output_image, debug_image, seed_output], 417 | ) 418 | 419 | return demo 420 | 421 | 422 | if __name__ == '__main__': 423 | demo = create_demo() 424 | demo.queue().launch(server_name='0.0.0.0', server_port=args.port) 425 | -------------------------------------------------------------------------------- /dreamo/dreamo_pipeline.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025 Bytedance Ltd. and/or its affiliates 2 | # Copyright 2024 Black Forest Labs and The HuggingFace Team. All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | from typing import Any, Callable, Dict, List, Optional, Union 17 | 18 | import diffusers 19 | import numpy as np 20 | import torch 21 | import torch.nn as nn 22 | from diffusers import FluxPipeline 23 | from diffusers.pipelines.flux.pipeline_flux import calculate_shift, retrieve_timesteps 24 | from diffusers.pipelines.flux.pipeline_output import FluxPipelineOutput 25 | from einops import repeat 26 | from huggingface_hub import hf_hub_download 27 | from safetensors.torch import load_file 28 | 29 | from dreamo.transformer import flux_transformer_forward 30 | from dreamo.utils import convert_flux_lora_to_diffusers 31 | 32 | diffusers.models.transformers.transformer_flux.FluxTransformer2DModel.forward = flux_transformer_forward 33 | 34 | 35 | def get_task_embedding_idx(task): 36 | return 0 37 | 38 | 39 | class DreamOPipeline(FluxPipeline): 40 | def __init__(self, scheduler, vae, text_encoder, tokenizer, text_encoder_2, tokenizer_2, transformer): 41 | super().__init__(scheduler, vae, text_encoder, tokenizer, text_encoder_2, tokenizer_2, transformer) 42 | self.t5_embedding = nn.Embedding(10, 4096) 43 | self.task_embedding = nn.Embedding(2, 3072) 44 | self.idx_embedding = nn.Embedding(10, 3072) 45 | 46 | def load_dreamo_model(self, device, use_turbo=True): 47 | # download models and load file 48 | hf_hub_download(repo_id='ByteDance/DreamO', filename='dreamo.safetensors', local_dir='models') 49 | hf_hub_download(repo_id='ByteDance/DreamO', filename='dreamo_cfg_distill.safetensors', local_dir='models') 50 | hf_hub_download(repo_id='ByteDance/DreamO', filename='dreamo_quality_lora_pos.safetensors', local_dir='models') 51 | hf_hub_download(repo_id='ByteDance/DreamO', filename='dreamo_quality_lora_neg.safetensors', local_dir='models') 52 | dreamo_lora = load_file('models/dreamo.safetensors') 53 | cfg_distill_lora = load_file('models/dreamo_cfg_distill.safetensors') 54 | quality_lora_pos = load_file('models/dreamo_quality_lora_pos.safetensors') 55 | quality_lora_neg = load_file('models/dreamo_quality_lora_neg.safetensors') 56 | 57 | # load embedding 58 | self.t5_embedding.weight.data = dreamo_lora.pop('dreamo_t5_embedding.weight')[-10:] 59 | self.task_embedding.weight.data = dreamo_lora.pop('dreamo_task_embedding.weight') 60 | self.idx_embedding.weight.data = dreamo_lora.pop('dreamo_idx_embedding.weight') 61 | self._prepare_t5() 62 | 63 | # main lora 64 | dreamo_diffuser_lora = convert_flux_lora_to_diffusers(dreamo_lora) 65 | adapter_names = ['dreamo'] 66 | adapter_weights = [1] 67 | self.load_lora_weights(dreamo_diffuser_lora, adapter_name='dreamo') 68 | 69 | # cfg lora to avoid true image cfg 70 | cfg_diffuser_lora = convert_flux_lora_to_diffusers(cfg_distill_lora) 71 | self.load_lora_weights(cfg_diffuser_lora, adapter_name='cfg') 72 | adapter_names.append('cfg') 73 | adapter_weights.append(1) 74 | 75 | # turbo lora to speed up (from 25+ step to 12 step) 76 | if use_turbo: 77 | self.load_lora_weights( 78 | hf_hub_download( 79 | "alimama-creative/FLUX.1-Turbo-Alpha", "diffusion_pytorch_model.safetensors", local_dir='models' 80 | ), 81 | adapter_name='turbo', 82 | ) 83 | adapter_names.append('turbo') 84 | adapter_weights.append(1) 85 | 86 | # quality loras, one pos, one neg 87 | quality_lora_pos = convert_flux_lora_to_diffusers(quality_lora_pos) 88 | self.load_lora_weights(quality_lora_pos, adapter_name='quality_pos') 89 | adapter_names.append('quality_pos') 90 | adapter_weights.append(0.15) 91 | quality_lora_neg = convert_flux_lora_to_diffusers(quality_lora_neg) 92 | 
self.load_lora_weights(quality_lora_neg, adapter_name='quality_neg') 93 | adapter_names.append('quality_neg') 94 | adapter_weights.append(-0.8) 95 | 96 | self.set_adapters(adapter_names, adapter_weights) 97 | self.fuse_lora(adapter_names=adapter_names, lora_scale=1) 98 | self.unload_lora_weights() 99 | 100 | self.t5_embedding = self.t5_embedding.to(device) 101 | self.task_embedding = self.task_embedding.to(device) 102 | self.idx_embedding = self.idx_embedding.to(device) 103 | 104 | def _prepare_t5(self): 105 | self.text_encoder_2.resize_token_embeddings(len(self.tokenizer_2)) 106 | num_new_token = 10 107 | new_token_list = [f"[ref#{i}]" for i in range(1, 10)] + ["[res]"] 108 | self.tokenizer_2.add_tokens(new_token_list, special_tokens=False) 109 | self.text_encoder_2.resize_token_embeddings(len(self.tokenizer_2)) 110 | input_embedding = self.text_encoder_2.get_input_embeddings().weight.data 111 | input_embedding[-num_new_token:] = self.t5_embedding.weight.data 112 | 113 | @staticmethod 114 | def _prepare_latent_image_ids(batch_size, height, width, device, dtype, start_height=0, start_width=0): 115 | latent_image_ids = torch.zeros(height // 2, width // 2, 3) 116 | latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height // 2)[:, None] + start_height 117 | latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width // 2)[None, :] + start_width 118 | 119 | latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape 120 | 121 | latent_image_ids = latent_image_ids[None, :].repeat(batch_size, 1, 1, 1) 122 | latent_image_ids = latent_image_ids.reshape( 123 | batch_size, latent_image_id_height * latent_image_id_width, latent_image_id_channels 124 | ) 125 | 126 | return latent_image_ids.to(device=device, dtype=dtype) 127 | 128 | @staticmethod 129 | def _prepare_style_latent_image_ids(batch_size, height, width, device, dtype, start_height=0, start_width=0): 130 | latent_image_ids = torch.zeros(height // 2, width // 2, 3) 131 | latent_image_ids[..., 1] = latent_image_ids[..., 1] + start_height 132 | latent_image_ids[..., 2] = latent_image_ids[..., 2] + start_width 133 | 134 | latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape 135 | 136 | latent_image_ids = latent_image_ids[None, :].repeat(batch_size, 1, 1, 1) 137 | latent_image_ids = latent_image_ids.reshape( 138 | batch_size, latent_image_id_height * latent_image_id_width, latent_image_id_channels 139 | ) 140 | 141 | return latent_image_ids.to(device=device, dtype=dtype) 142 | 143 | @torch.no_grad() 144 | def __call__( 145 | self, 146 | prompt: Union[str, List[str]] = None, 147 | prompt_2: Optional[Union[str, List[str]]] = None, 148 | negative_prompt: Union[str, List[str]] = None, 149 | negative_prompt_2: Optional[Union[str, List[str]]] = None, 150 | true_cfg_scale: float = 1.0, 151 | true_cfg_start_step: int = 1, 152 | true_cfg_end_step: int = 1, 153 | height: Optional[int] = None, 154 | width: Optional[int] = None, 155 | num_inference_steps: int = 28, 156 | sigmas: Optional[List[float]] = None, 157 | guidance_scale: float = 3.5, 158 | neg_guidance_scale: float = 3.5, 159 | num_images_per_prompt: Optional[int] = 1, 160 | generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, 161 | latents: Optional[torch.FloatTensor] = None, 162 | prompt_embeds: Optional[torch.FloatTensor] = None, 163 | pooled_prompt_embeds: Optional[torch.FloatTensor] = None, 164 | negative_prompt_embeds: Optional[torch.FloatTensor] = None, 165 | 
negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None, 166 | output_type: Optional[str] = "pil", 167 | return_dict: bool = True, 168 | joint_attention_kwargs: Optional[Dict[str, Any]] = None, 169 | callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, 170 | callback_on_step_end_tensor_inputs: List[str] = ["latents"], 171 | max_sequence_length: int = 512, 172 | ref_conds=None, 173 | first_step_guidance_scale=3.5, 174 | ): 175 | r""" 176 | Function invoked when calling the pipeline for generation. 177 | 178 | Args: 179 | prompt (`str` or `List[str]`, *optional*): 180 | The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. 181 | instead. 182 | prompt_2 (`str` or `List[str]`, *optional*): 183 | The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is 184 | will be used instead. 185 | negative_prompt (`str` or `List[str]`, *optional*): 186 | The prompt or prompts not to guide the image generation. If not defined, one has to pass 187 | `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `true_cfg_scale` is 188 | not greater than `1`). 189 | negative_prompt_2 (`str` or `List[str]`, *optional*): 190 | The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and 191 | `text_encoder_2`. If not defined, `negative_prompt` is used in all the text-encoders. 192 | true_cfg_scale (`float`, *optional*, defaults to 1.0): 193 | When > 1.0 and a provided `negative_prompt`, enables true classifier-free guidance. 194 | height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): 195 | The height in pixels of the generated image. This is set to 1024 by default for the best results. 196 | width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): 197 | The width in pixels of the generated image. This is set to 1024 by default for the best results. 198 | num_inference_steps (`int`, *optional*, defaults to 50): 199 | The number of denoising steps. More denoising steps usually lead to a higher quality image at the 200 | expense of slower inference. 201 | sigmas (`List[float]`, *optional*): 202 | Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in 203 | their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed 204 | will be used. 205 | guidance_scale (`float`, *optional*, defaults to 3.5): 206 | Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). 207 | `guidance_scale` is defined as `w` of equation 2. of [Imagen 208 | Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > 209 | 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, 210 | usually at the expense of lower image quality. 211 | num_images_per_prompt (`int`, *optional*, defaults to 1): 212 | The number of images to generate per prompt. 213 | generator (`torch.Generator` or `List[torch.Generator]`, *optional*): 214 | One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) 215 | to make generation deterministic. 216 | latents (`torch.FloatTensor`, *optional*): 217 | Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image 218 | generation. 
Can be used to tweak the same generation with different prompts. If not provided, a latents 219 | tensor will ge generated by sampling using the supplied random `generator`. 220 | prompt_embeds (`torch.FloatTensor`, *optional*): 221 | Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not 222 | provided, text embeddings will be generated from `prompt` input argument. 223 | pooled_prompt_embeds (`torch.FloatTensor`, *optional*): 224 | Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. 225 | If not provided, pooled text embeddings will be generated from `prompt` input argument. 226 | negative_prompt_embeds (`torch.FloatTensor`, *optional*): 227 | Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt 228 | weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input 229 | argument. 230 | negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*): 231 | Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt 232 | weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` 233 | input argument. 234 | output_type (`str`, *optional*, defaults to `"pil"`): 235 | The output format of the generate image. Choose between 236 | [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. 237 | return_dict (`bool`, *optional*, defaults to `True`): 238 | Whether or not to return a [`~pipelines.flux.FluxPipelineOutput`] instead of a plain tuple. 239 | joint_attention_kwargs (`dict`, *optional*): 240 | A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under 241 | `self.processor` in 242 | [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). 243 | callback_on_step_end (`Callable`, *optional*): 244 | A function that calls at the end of each denoising steps during the inference. The function is called 245 | with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, 246 | callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by 247 | `callback_on_step_end_tensor_inputs`. 248 | callback_on_step_end_tensor_inputs (`List`, *optional*): 249 | The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list 250 | will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the 251 | `._callback_tensor_inputs` attribute of your pipeline class. 252 | max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`. 253 | 254 | Examples: 255 | 256 | Returns: 257 | [`~pipelines.flux.FluxPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict` 258 | is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated 259 | images. 260 | """ 261 | 262 | height = height or self.default_sample_size * self.vae_scale_factor 263 | width = width or self.default_sample_size * self.vae_scale_factor 264 | 265 | # 1. Check inputs. 
Raise error if not correct 266 | self.check_inputs( 267 | prompt, 268 | prompt_2, 269 | height, 270 | width, 271 | prompt_embeds=prompt_embeds, 272 | pooled_prompt_embeds=pooled_prompt_embeds, 273 | callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, 274 | max_sequence_length=max_sequence_length, 275 | ) 276 | 277 | self._guidance_scale = guidance_scale 278 | self._joint_attention_kwargs = joint_attention_kwargs 279 | self._current_timestep = None 280 | self._interrupt = False 281 | 282 | # 2. Define call parameters 283 | if prompt is not None and isinstance(prompt, str): 284 | batch_size = 1 285 | elif prompt is not None and isinstance(prompt, list): 286 | batch_size = len(prompt) 287 | else: 288 | batch_size = prompt_embeds.shape[0] 289 | 290 | device = self._execution_device 291 | 292 | lora_scale = ( 293 | self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None 294 | ) 295 | has_neg_prompt = negative_prompt is not None or ( 296 | negative_prompt_embeds is not None and negative_pooled_prompt_embeds is not None 297 | ) 298 | do_true_cfg = true_cfg_scale > 1 and has_neg_prompt 299 | ( 300 | prompt_embeds, 301 | pooled_prompt_embeds, 302 | text_ids, 303 | ) = self.encode_prompt( 304 | prompt=prompt, 305 | prompt_2=prompt_2, 306 | prompt_embeds=prompt_embeds, 307 | pooled_prompt_embeds=pooled_prompt_embeds, 308 | device=device, 309 | num_images_per_prompt=num_images_per_prompt, 310 | max_sequence_length=max_sequence_length, 311 | lora_scale=lora_scale, 312 | ) 313 | if do_true_cfg: 314 | ( 315 | negative_prompt_embeds, 316 | negative_pooled_prompt_embeds, 317 | _, 318 | ) = self.encode_prompt( 319 | prompt=negative_prompt, 320 | prompt_2=negative_prompt_2, 321 | prompt_embeds=negative_prompt_embeds, 322 | pooled_prompt_embeds=negative_pooled_prompt_embeds, 323 | device=device, 324 | num_images_per_prompt=num_images_per_prompt, 325 | max_sequence_length=max_sequence_length, 326 | lora_scale=lora_scale, 327 | ) 328 | 329 | # 4. 
Prepare latent variables 330 | num_channels_latents = self.transformer.config.in_channels // 4 331 | latents, latent_image_ids = self.prepare_latents( 332 | batch_size * num_images_per_prompt, 333 | num_channels_latents, 334 | height, 335 | width, 336 | prompt_embeds.dtype, 337 | device, 338 | generator, 339 | latents, 340 | ) 341 | 342 | # 4.1 concat ref tokens to latent 343 | origin_img_len = latents.shape[1] 344 | embeddings = repeat(self.task_embedding.weight[1], "c -> n l c", n=batch_size, l=origin_img_len) 345 | ref_latents = [] 346 | ref_latent_image_idss = [] 347 | start_height = height // 16 348 | start_width = width // 16 349 | for ref_cond in ref_conds: 350 | img = ref_cond['img'] # [b, 3, h, w], range [-1, 1] 351 | task = ref_cond['task'] 352 | idx = ref_cond['idx'] 353 | 354 | # encode ref with VAE 355 | img = img.to(latents) 356 | ref_latent = self.vae.encode(img).latent_dist.sample() 357 | ref_latent = (ref_latent - self.vae.config.shift_factor) * self.vae.config.scaling_factor 358 | cur_height = ref_latent.shape[2] 359 | cur_width = ref_latent.shape[3] 360 | ref_latent = self._pack_latents(ref_latent, batch_size, num_channels_latents, cur_height, cur_width) 361 | ref_latent_image_ids = self._prepare_latent_image_ids( 362 | batch_size, cur_height, cur_width, device, prompt_embeds.dtype, start_height, start_width 363 | ) 364 | start_height += cur_height // 2 365 | start_width += cur_width // 2 366 | 367 | # prepare task_idx_embedding 368 | task_idx = get_task_embedding_idx(task) 369 | cur_task_embedding = repeat( 370 | self.task_embedding.weight[task_idx], "c -> n l c", n=batch_size, l=ref_latent.shape[1] 371 | ) 372 | cur_idx_embedding = repeat( 373 | self.idx_embedding.weight[idx], "c -> n l c", n=batch_size, l=ref_latent.shape[1] 374 | ) 375 | cur_embedding = cur_task_embedding + cur_idx_embedding 376 | 377 | # concat ref to latent 378 | embeddings = torch.cat([embeddings, cur_embedding], dim=1) 379 | ref_latents.append(ref_latent) 380 | ref_latent_image_idss.append(ref_latent_image_ids) 381 | 382 | # 5. Prepare timesteps 383 | sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas 384 | image_seq_len = latents.shape[1] 385 | mu = calculate_shift( 386 | image_seq_len, 387 | self.scheduler.config.get("base_image_seq_len", 256), 388 | self.scheduler.config.get("max_image_seq_len", 4096), 389 | self.scheduler.config.get("base_shift", 0.5), 390 | self.scheduler.config.get("max_shift", 1.15), 391 | ) 392 | timesteps, num_inference_steps = retrieve_timesteps( 393 | self.scheduler, 394 | num_inference_steps, 395 | device, 396 | sigmas=sigmas, 397 | mu=mu, 398 | ) 399 | num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) 400 | self._num_timesteps = len(timesteps) 401 | 402 | # handle guidance 403 | if self.transformer.config.guidance_embeds: 404 | guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32) 405 | guidance = guidance.expand(latents.shape[0]) 406 | else: 407 | guidance = None 408 | neg_guidance = torch.full([1], neg_guidance_scale, device=device, dtype=torch.float32) 409 | neg_guidance = neg_guidance.expand(latents.shape[0]) 410 | first_step_guidance = torch.full([1], first_step_guidance_scale, device=device, dtype=torch.float32) 411 | 412 | if self.joint_attention_kwargs is None: 413 | self._joint_attention_kwargs = {} 414 | 415 | # 6. 
Denoising loop 416 | with self.progress_bar(total=num_inference_steps) as progress_bar: 417 | for i, t in enumerate(timesteps): 418 | if self.interrupt: 419 | continue 420 | 421 | self._current_timestep = t 422 | # broadcast to batch dimension in a way that's compatible with ONNX/Core ML 423 | timestep = t.expand(latents.shape[0]).to(latents.dtype) 424 | 425 | noise_pred = self.transformer( 426 | hidden_states=torch.cat((latents, *ref_latents), dim=1), 427 | timestep=timestep / 1000, 428 | guidance=guidance if i > 0 else first_step_guidance, 429 | pooled_projections=pooled_prompt_embeds, 430 | encoder_hidden_states=prompt_embeds, 431 | txt_ids=text_ids, 432 | img_ids=torch.cat((latent_image_ids, *ref_latent_image_idss), dim=1), 433 | joint_attention_kwargs=self.joint_attention_kwargs, 434 | return_dict=False, 435 | embeddings=embeddings, 436 | )[0][:, :origin_img_len] 437 | 438 | if do_true_cfg and i >= true_cfg_start_step and i < true_cfg_end_step: 439 | neg_noise_pred = self.transformer( 440 | hidden_states=latents, 441 | timestep=timestep / 1000, 442 | guidance=neg_guidance, 443 | pooled_projections=negative_pooled_prompt_embeds, 444 | encoder_hidden_states=negative_prompt_embeds, 445 | txt_ids=text_ids, 446 | img_ids=latent_image_ids, 447 | joint_attention_kwargs=self.joint_attention_kwargs, 448 | return_dict=False, 449 | )[0] 450 | noise_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred) 451 | 452 | # compute the previous noisy sample x_t -> x_t-1 453 | latents_dtype = latents.dtype 454 | latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] 455 | 456 | if latents.dtype != latents_dtype and torch.backends.mps.is_available(): 457 | # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272 458 | latents = latents.to(latents_dtype) 459 | 460 | if callback_on_step_end is not None: 461 | callback_kwargs = {} 462 | for k in callback_on_step_end_tensor_inputs: 463 | callback_kwargs[k] = locals()[k] 464 | callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) 465 | 466 | latents = callback_outputs.pop("latents", latents) 467 | prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) 468 | 469 | # call the callback, if provided 470 | if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): 471 | progress_bar.update() 472 | 473 | self._current_timestep = None 474 | 475 | if self.offload: 476 | self.transformer.cpu() 477 | torch.cuda.empty_cache() 478 | 479 | if output_type == "latent": 480 | image = latents 481 | else: 482 | latents = self._unpack_latents(latents, height, width, self.vae_scale_factor) 483 | latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor 484 | image = self.vae.decode(latents, return_dict=False)[0] 485 | image = self.image_processor.postprocess(image, output_type=output_type) 486 | 487 | # Offload all models 488 | self.maybe_free_model_hooks() 489 | 490 | if not return_dict: 491 | return (image,) 492 | 493 | return FluxPipelineOutput(images=image) 494 | -------------------------------------------------------------------------------- /dreamo/transformer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025 Bytedance Ltd. and/or its affiliates 2 | # Copyright 2024 Black Forest Labs and The HuggingFace Team. All rights reserved. 
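#
# This module defines `flux_transformer_forward`, a variant of the diffusers
# `FluxTransformer2DModel.forward` that additionally accepts an `embeddings` tensor
# (per-token task/index embeddings for the target-image and reference tokens) and adds
# it to the packed image tokens right after `x_embedder`.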
3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | from typing import Any, Dict, Optional, Union 17 | 18 | import numpy as np 19 | import torch 20 | from diffusers.models.modeling_outputs import Transformer2DModelOutput 21 | from diffusers.utils import ( 22 | USE_PEFT_BACKEND, 23 | logging, 24 | scale_lora_layers, 25 | unscale_lora_layers, 26 | ) 27 | 28 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name 29 | 30 | 31 | def flux_transformer_forward( 32 | self, 33 | hidden_states: torch.Tensor, 34 | encoder_hidden_states: torch.Tensor = None, 35 | pooled_projections: torch.Tensor = None, 36 | timestep: torch.LongTensor = None, 37 | img_ids: torch.Tensor = None, 38 | txt_ids: torch.Tensor = None, 39 | guidance: torch.Tensor = None, 40 | joint_attention_kwargs: Optional[Dict[str, Any]] = None, 41 | controlnet_block_samples=None, 42 | controlnet_single_block_samples=None, 43 | return_dict: bool = True, 44 | controlnet_blocks_repeat: bool = False, 45 | embeddings: torch.Tensor = None, 46 | ) -> Union[torch.Tensor, Transformer2DModelOutput]: 47 | """ 48 | The [`FluxTransformer2DModel`] forward method. 49 | 50 | Args: 51 | hidden_states (`torch.Tensor` of shape `(batch_size, image_sequence_length, in_channels)`): 52 | Input `hidden_states`. 53 | encoder_hidden_states (`torch.Tensor` of shape `(batch_size, text_sequence_length, joint_attention_dim)`): 54 | Conditional embeddings (embeddings computed from the input conditions such as prompts) to use. 55 | pooled_projections (`torch.Tensor` of shape `(batch_size, projection_dim)`): Embeddings projected 56 | from the embeddings of input conditions. 57 | timestep ( `torch.LongTensor`): 58 | Used to indicate denoising step. 59 | block_controlnet_hidden_states: (`list` of `torch.Tensor`): 60 | A list of tensors that if specified are added to the residuals of transformer blocks. 61 | joint_attention_kwargs (`dict`, *optional*): 62 | A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under 63 | `self.processor` in 64 | [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). 65 | return_dict (`bool`, *optional*, defaults to `True`): 66 | Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain 67 | tuple. 68 | 69 | Returns: 70 | If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a 71 | `tuple` where the first element is the sample tensor. 
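
    Note:
        This DreamO variant also takes `embeddings` (`torch.Tensor`, *optional*): per-token
        task/index embeddings for the target-image and reference tokens, which are added to
        the packed image tokens right after `x_embedder`.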
72 | """ 73 | if joint_attention_kwargs is not None: 74 | joint_attention_kwargs = joint_attention_kwargs.copy() 75 | lora_scale = joint_attention_kwargs.pop("scale", 1.0) 76 | else: 77 | lora_scale = 1.0 78 | 79 | if USE_PEFT_BACKEND: 80 | # weight the lora layers by setting `lora_scale` for each PEFT layer 81 | scale_lora_layers(self, lora_scale) 82 | else: 83 | if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None: 84 | logger.warning( 85 | "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective." 86 | ) 87 | 88 | hidden_states = self.x_embedder(hidden_states) 89 | # add task and idx embedding 90 | if embeddings is not None: 91 | hidden_states = hidden_states + embeddings 92 | 93 | timestep = timestep.to(hidden_states.dtype) * 1000 94 | guidance = guidance.to(hidden_states.dtype) * 1000 if guidance is not None else None 95 | 96 | temb = ( 97 | self.time_text_embed(timestep, pooled_projections) 98 | if guidance is None 99 | else self.time_text_embed(timestep, guidance, pooled_projections) 100 | ) 101 | encoder_hidden_states = self.context_embedder(encoder_hidden_states) 102 | 103 | if txt_ids.ndim == 3: 104 | # logger.warning( 105 | # "Passing `txt_ids` 3d torch.Tensor is deprecated." 106 | # "Please remove the batch dimension and pass it as a 2d torch Tensor" 107 | # ) 108 | txt_ids = txt_ids[0] 109 | if img_ids.ndim == 3: 110 | # logger.warning( 111 | # "Passing `img_ids` 3d torch.Tensor is deprecated." 112 | # "Please remove the batch dimension and pass it as a 2d torch Tensor" 113 | # ) 114 | img_ids = img_ids[0] 115 | 116 | ids = torch.cat((txt_ids, img_ids), dim=0) 117 | image_rotary_emb = self.pos_embed(ids) 118 | 119 | for index_block, block in enumerate(self.transformer_blocks): 120 | if torch.is_grad_enabled() and self.gradient_checkpointing: 121 | encoder_hidden_states, hidden_states = self._gradient_checkpointing_func( 122 | block, 123 | hidden_states, 124 | encoder_hidden_states, 125 | temb, 126 | image_rotary_emb, 127 | ) 128 | 129 | else: 130 | encoder_hidden_states, hidden_states = block( 131 | hidden_states=hidden_states, 132 | encoder_hidden_states=encoder_hidden_states, 133 | temb=temb, 134 | image_rotary_emb=image_rotary_emb, 135 | joint_attention_kwargs=joint_attention_kwargs, 136 | ) 137 | 138 | # controlnet residual 139 | if controlnet_block_samples is not None: 140 | interval_control = len(self.transformer_blocks) / len(controlnet_block_samples) 141 | interval_control = int(np.ceil(interval_control)) 142 | # For Xlabs ControlNet. 
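            # (Xlabs-style ControlNets supply one residual per controlnet block and are cycled
            # with `index_block % len(controlnet_block_samples)`; other ControlNets map groups of
            # `interval_control` transformer blocks onto each controlnet sample.)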
143 | if controlnet_blocks_repeat: 144 | hidden_states = hidden_states + controlnet_block_samples[index_block % len(controlnet_block_samples)] 145 | else: 146 | hidden_states = hidden_states + controlnet_block_samples[index_block // interval_control] 147 | hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1) 148 | 149 | for index_block, block in enumerate(self.single_transformer_blocks): 150 | if torch.is_grad_enabled() and self.gradient_checkpointing: 151 | hidden_states = self._gradient_checkpointing_func( 152 | block, 153 | hidden_states, 154 | temb, 155 | image_rotary_emb, 156 | ) 157 | 158 | else: 159 | hidden_states = block( 160 | hidden_states=hidden_states, 161 | temb=temb, 162 | image_rotary_emb=image_rotary_emb, 163 | joint_attention_kwargs=joint_attention_kwargs, 164 | ) 165 | 166 | # controlnet residual 167 | if controlnet_single_block_samples is not None: 168 | interval_control = len(self.single_transformer_blocks) / len(controlnet_single_block_samples) 169 | interval_control = int(np.ceil(interval_control)) 170 | hidden_states[:, encoder_hidden_states.shape[1] :, ...] = ( 171 | hidden_states[:, encoder_hidden_states.shape[1] :, ...] 172 | + controlnet_single_block_samples[index_block // interval_control] 173 | ) 174 | 175 | hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...] 176 | 177 | hidden_states = self.norm_out(hidden_states, temb) 178 | output = self.proj_out(hidden_states) 179 | 180 | if USE_PEFT_BACKEND: 181 | # remove `lora_scale` from each PEFT layer 182 | unscale_lora_layers(self, lora_scale) 183 | 184 | if not return_dict: 185 | return (output,) 186 | 187 | return Transformer2DModelOutput(sample=output) 188 | -------------------------------------------------------------------------------- /dreamo/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025 Bytedance Ltd. and/or its affiliates 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import math 16 | import re 17 | 18 | import cv2 19 | import numpy as np 20 | import torch 21 | from torchvision.utils import make_grid 22 | 23 | 24 | # from basicsr 25 | def img2tensor(imgs, bgr2rgb=True, float32=True): 26 | """Numpy array to tensor. 27 | 28 | Args: 29 | imgs (list[ndarray] | ndarray): Input images. 30 | bgr2rgb (bool): Whether to change bgr to rgb. 31 | float32 (bool): Whether to change to float32. 32 | 33 | Returns: 34 | list[tensor] | tensor: Tensor images. If returned results only have 35 | one element, just return tensor. 
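
    Example:
        >>> t = img2tensor(np.zeros((64, 64, 3), dtype=np.float32))
        >>> t.shape
        torch.Size([3, 64, 64])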
36 | """ 37 | 38 | def _totensor(img, bgr2rgb, float32): 39 | if img.shape[2] == 3 and bgr2rgb: 40 | if img.dtype == 'float64': 41 | img = img.astype('float32') 42 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) 43 | img = torch.from_numpy(img.transpose(2, 0, 1)) 44 | if float32: 45 | img = img.float() 46 | return img 47 | 48 | if isinstance(imgs, list): 49 | return [_totensor(img, bgr2rgb, float32) for img in imgs] 50 | return _totensor(imgs, bgr2rgb, float32) 51 | 52 | 53 | def tensor2img(tensor, rgb2bgr=True, out_type=np.uint8, min_max=(0, 1)): 54 | """Convert torch Tensors into image numpy arrays. 55 | 56 | After clamping to [min, max], values will be normalized to [0, 1]. 57 | 58 | Args: 59 | tensor (Tensor or list[Tensor]): Accept shapes: 60 | 1) 4D mini-batch Tensor of shape (B x 3/1 x H x W); 61 | 2) 3D Tensor of shape (3/1 x H x W); 62 | 3) 2D Tensor of shape (H x W). 63 | Tensor channel should be in RGB order. 64 | rgb2bgr (bool): Whether to change rgb to bgr. 65 | out_type (numpy type): output types. If ``np.uint8``, transform outputs 66 | to uint8 type with range [0, 255]; otherwise, float type with 67 | range [0, 1]. Default: ``np.uint8``. 68 | min_max (tuple[int]): min and max values for clamp. 69 | 70 | Returns: 71 | (Tensor or list): 3D ndarray of shape (H x W x C) OR 2D ndarray of 72 | shape (H x W). The channel order is BGR. 73 | """ 74 | if not (torch.is_tensor(tensor) or (isinstance(tensor, list) and all(torch.is_tensor(t) for t in tensor))): 75 | raise TypeError(f'tensor or list of tensors expected, got {type(tensor)}') 76 | 77 | if torch.is_tensor(tensor): 78 | tensor = [tensor] 79 | result = [] 80 | for _tensor in tensor: 81 | _tensor = _tensor.squeeze(0).float().detach().cpu().clamp_(*min_max) 82 | _tensor = (_tensor - min_max[0]) / (min_max[1] - min_max[0]) 83 | 84 | n_dim = _tensor.dim() 85 | if n_dim == 4: 86 | img_np = make_grid(_tensor, nrow=int(math.sqrt(_tensor.size(0))), normalize=False).numpy() 87 | img_np = img_np.transpose(1, 2, 0) 88 | if rgb2bgr: 89 | img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR) 90 | elif n_dim == 3: 91 | img_np = _tensor.numpy() 92 | img_np = img_np.transpose(1, 2, 0) 93 | if img_np.shape[2] == 1: # gray image 94 | img_np = np.squeeze(img_np, axis=2) 95 | else: 96 | if rgb2bgr: 97 | img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR) 98 | elif n_dim == 2: 99 | img_np = _tensor.numpy() 100 | else: 101 | raise TypeError(f'Only support 4D, 3D or 2D tensor. But received with dimension: {n_dim}') 102 | if out_type == np.uint8: 103 | # Unlike MATLAB, numpy.unit8() WILL NOT round by default. 
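            # np.uint8 casting truncates rather than rounds, so scale to [0, 255] and round
            # explicitly before the cast below.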
104 | img_np = (img_np * 255.0).round() 105 | img_np = img_np.astype(out_type) 106 | result.append(img_np) 107 | if len(result) == 1: 108 | result = result[0] 109 | return result 110 | 111 | 112 | def resize_numpy_image_area(image, area=512 * 512): 113 | h, w = image.shape[:2] 114 | k = math.sqrt(area / (h * w)) 115 | h = int(h * k) - (int(h * k) % 16) 116 | w = int(w * k) - (int(w * k) % 16) 117 | image = cv2.resize(image, (w, h), interpolation=cv2.INTER_AREA) 118 | return image 119 | 120 | def resize_numpy_image_long(image, long_edge=768): 121 | h, w = image.shape[:2] 122 | if max(h, w) <= long_edge: 123 | return image 124 | k = long_edge / max(h, w) 125 | h = int(h * k) 126 | w = int(w * k) 127 | image = cv2.resize(image, (w, h), interpolation=cv2.INTER_AREA) 128 | return image 129 | 130 | 131 | # reference: https://github.com/huggingface/diffusers/pull/9295/files 132 | def convert_flux_lora_to_diffusers(old_state_dict): 133 | new_state_dict = {} 134 | orig_keys = list(old_state_dict.keys()) 135 | 136 | def handle_qkv(sds_sd, ait_sd, sds_key, ait_keys, dims=None): 137 | down_weight = sds_sd.pop(sds_key) 138 | up_weight = sds_sd.pop(sds_key.replace(".down.weight", ".up.weight")) 139 | 140 | # calculate dims if not provided 141 | num_splits = len(ait_keys) 142 | if dims is None: 143 | dims = [up_weight.shape[0] // num_splits] * num_splits 144 | else: 145 | assert sum(dims) == up_weight.shape[0] 146 | 147 | # make ai-toolkit weight 148 | ait_down_keys = [k + ".lora_A.weight" for k in ait_keys] 149 | ait_up_keys = [k + ".lora_B.weight" for k in ait_keys] 150 | 151 | # down_weight is copied to each split 152 | ait_sd.update({k: down_weight for k in ait_down_keys}) 153 | 154 | # up_weight is split to each split 155 | ait_sd.update({k: v for k, v in zip(ait_up_keys, torch.split(up_weight, dims, dim=0))}) # noqa: C416 156 | 157 | for old_key in orig_keys: 158 | # Handle double_blocks 159 | if 'double_blocks' in old_key: 160 | block_num = re.search(r"double_blocks_(\d+)", old_key).group(1) 161 | new_key = f"transformer.transformer_blocks.{block_num}" 162 | 163 | if "proj_lora1" in old_key: 164 | new_key += ".attn.to_out.0" 165 | elif "proj_lora2" in old_key: 166 | new_key += ".attn.to_add_out" 167 | elif "qkv_lora2" in old_key and "up" not in old_key: 168 | handle_qkv( 169 | old_state_dict, 170 | new_state_dict, 171 | old_key, 172 | [ 173 | f"transformer.transformer_blocks.{block_num}.attn.add_q_proj", 174 | f"transformer.transformer_blocks.{block_num}.attn.add_k_proj", 175 | f"transformer.transformer_blocks.{block_num}.attn.add_v_proj", 176 | ], 177 | ) 178 | # continue 179 | elif "qkv_lora1" in old_key and "up" not in old_key: 180 | handle_qkv( 181 | old_state_dict, 182 | new_state_dict, 183 | old_key, 184 | [ 185 | f"transformer.transformer_blocks.{block_num}.attn.to_q", 186 | f"transformer.transformer_blocks.{block_num}.attn.to_k", 187 | f"transformer.transformer_blocks.{block_num}.attn.to_v", 188 | ], 189 | ) 190 | # continue 191 | 192 | if "down" in old_key: 193 | new_key += ".lora_A.weight" 194 | elif "up" in old_key: 195 | new_key += ".lora_B.weight" 196 | 197 | # Handle single_blocks 198 | elif 'single_blocks' in old_key: 199 | block_num = re.search(r"single_blocks_(\d+)", old_key).group(1) 200 | new_key = f"transformer.single_transformer_blocks.{block_num}" 201 | 202 | if "proj_lora" in old_key: 203 | new_key += ".proj_out" 204 | elif "qkv_lora" in old_key and "up" not in old_key: 205 | handle_qkv( 206 | old_state_dict, 207 | new_state_dict, 208 | old_key, 209 | [ 210 | 
f"transformer.single_transformer_blocks.{block_num}.attn.to_q", 211 | f"transformer.single_transformer_blocks.{block_num}.attn.to_k", 212 | f"transformer.single_transformer_blocks.{block_num}.attn.to_v", 213 | ], 214 | ) 215 | 216 | if "down" in old_key: 217 | new_key += ".lora_A.weight" 218 | elif "up" in old_key: 219 | new_key += ".lora_B.weight" 220 | 221 | else: 222 | # Handle other potential key patterns here 223 | new_key = old_key 224 | 225 | # Since we already handle qkv above. 226 | if "qkv" not in old_key and 'embedding' not in old_key: 227 | new_state_dict[new_key] = old_state_dict.pop(old_key) 228 | 229 | # if len(old_state_dict) > 0: 230 | # raise ValueError(f"`old_state_dict` should be at this point but has: {list(old_state_dict.keys())}.") 231 | 232 | return new_state_dict 233 | -------------------------------------------------------------------------------- /example_inputs/cat.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/DreamO/22df501d41b164f2d3918bfc59f4a6ac41b95384/example_inputs/cat.png -------------------------------------------------------------------------------- /example_inputs/dog1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/DreamO/22df501d41b164f2d3918bfc59f4a6ac41b95384/example_inputs/dog1.png -------------------------------------------------------------------------------- /example_inputs/dog2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/DreamO/22df501d41b164f2d3918bfc59f4a6ac41b95384/example_inputs/dog2.png -------------------------------------------------------------------------------- /example_inputs/dress.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/DreamO/22df501d41b164f2d3918bfc59f4a6ac41b95384/example_inputs/dress.png -------------------------------------------------------------------------------- /example_inputs/hinton.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/DreamO/22df501d41b164f2d3918bfc59f4a6ac41b95384/example_inputs/hinton.jpeg -------------------------------------------------------------------------------- /example_inputs/man1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/DreamO/22df501d41b164f2d3918bfc59f4a6ac41b95384/example_inputs/man1.png -------------------------------------------------------------------------------- /example_inputs/man2.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/DreamO/22df501d41b164f2d3918bfc59f4a6ac41b95384/example_inputs/man2.jpeg -------------------------------------------------------------------------------- /example_inputs/mickey.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/DreamO/22df501d41b164f2d3918bfc59f4a6ac41b95384/example_inputs/mickey.png -------------------------------------------------------------------------------- /example_inputs/mountain.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/bytedance/DreamO/22df501d41b164f2d3918bfc59f4a6ac41b95384/example_inputs/mountain.png -------------------------------------------------------------------------------- /example_inputs/perfume.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/DreamO/22df501d41b164f2d3918bfc59f4a6ac41b95384/example_inputs/perfume.png -------------------------------------------------------------------------------- /example_inputs/shirt.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/DreamO/22df501d41b164f2d3918bfc59f4a6ac41b95384/example_inputs/shirt.png -------------------------------------------------------------------------------- /example_inputs/skirt.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/DreamO/22df501d41b164f2d3918bfc59f4a6ac41b95384/example_inputs/skirt.jpeg -------------------------------------------------------------------------------- /example_inputs/toy1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/DreamO/22df501d41b164f2d3918bfc59f4a6ac41b95384/example_inputs/toy1.png -------------------------------------------------------------------------------- /example_inputs/woman1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/DreamO/22df501d41b164f2d3918bfc59f4a6ac41b95384/example_inputs/woman1.png -------------------------------------------------------------------------------- /example_inputs/woman2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/DreamO/22df501d41b164f2d3918bfc59f4a6ac41b95384/example_inputs/woman2.png -------------------------------------------------------------------------------- /example_inputs/woman3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/DreamO/22df501d41b164f2d3918bfc59f4a6ac41b95384/example_inputs/woman3.png -------------------------------------------------------------------------------- /example_inputs/woman4.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/DreamO/22df501d41b164f2d3918bfc59f4a6ac41b95384/example_inputs/woman4.jpeg -------------------------------------------------------------------------------- /models/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bytedance/DreamO/22df501d41b164f2d3918bfc59f4a6ac41b95384/models/.gitkeep -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.ruff] 2 | line-length = 120 3 | exclude = ['tools'] 4 | # A list of file patterns to omit from linting, in addition to those specified by exclude. 
5 | extend-exclude = ["__pycache__", "*.pyc", "*.egg-info", ".cache"] 6 | 7 | select = ["E", "F", "W", "C90", "I", "UP", "B", "C4", "RET", "RUF", "SIM"] 8 | 9 | 10 | ignore = [ 11 | "UP006", # UP006: Use list instead of typing.List for type annotations 12 | "UP007", # UP007: Use X | Y for type annotations 13 | "UP009", 14 | "UP035", 15 | "UP038", 16 | "E402", 17 | "RET504", 18 | "C901", 19 | "RUF013", 20 | "B006", 21 | ] 22 | 23 | [tool.isort] 24 | profile = "black" 25 | 26 | [tool.black] 27 | line-length = 119 28 | skip-string-normalization = 1 29 | exclude = 'tools' -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch==2.6.0 2 | torchvision==0.21.0 3 | protobuf 4 | optimum-quanto==0.2.7 5 | einops 6 | timm 7 | diffusers==0.31.0 8 | transformers==4.45.2 9 | sentencepiece 10 | gradio 11 | spaces 12 | huggingface_hub 13 | accelerate==0.32.0 14 | peft 15 | git+https://github.com/ToTheBeginning/facexlib.git 16 | -------------------------------------------------------------------------------- /tools/BEN2.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2025 Prama LLC 2 | # SPDX-License-Identifier: MIT 3 | 4 | import math 5 | import os 6 | import random 7 | import subprocess 8 | import tempfile 9 | import time 10 | 11 | import cv2 12 | import numpy as np 13 | import torch 14 | import torch.nn as nn 15 | import torch.nn.functional as F 16 | import torch.utils.checkpoint as checkpoint 17 | from einops import rearrange 18 | from PIL import Image, ImageOps 19 | from timm.models.layers import DropPath, to_2tuple, trunc_normal_ 20 | from torchvision import transforms 21 | 22 | 23 | def set_random_seed(seed): 24 | random.seed(seed) 25 | np.random.seed(seed) 26 | torch.manual_seed(seed) 27 | torch.cuda.manual_seed(seed) 28 | torch.cuda.manual_seed_all(seed) 29 | torch.backends.cudnn.deterministic = True 30 | torch.backends.cudnn.benchmark = False 31 | 32 | 33 | # set_random_seed(9) 34 | 35 | torch.set_float32_matmul_precision('highest') 36 | 37 | 38 | class Mlp(nn.Module): 39 | """ Multilayer perceptron.""" 40 | 41 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): 42 | super().__init__() 43 | out_features = out_features or in_features 44 | hidden_features = hidden_features or in_features 45 | self.fc1 = nn.Linear(in_features, hidden_features) 46 | self.act = act_layer() 47 | self.fc2 = nn.Linear(hidden_features, out_features) 48 | self.drop = nn.Dropout(drop) 49 | 50 | def forward(self, x): 51 | x = self.fc1(x) 52 | x = self.act(x) 53 | x = self.drop(x) 54 | x = self.fc2(x) 55 | x = self.drop(x) 56 | return x 57 | 58 | 59 | def window_partition(x, window_size): 60 | """ 61 | Args: 62 | x: (B, H, W, C) 63 | window_size (int): window size 64 | Returns: 65 | windows: (num_windows*B, window_size, window_size, C) 66 | """ 67 | B, H, W, C = x.shape 68 | x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) 69 | windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) 70 | return windows 71 | 72 | 73 | def window_reverse(windows, window_size, H, W): 74 | """ 75 | Args: 76 | windows: (num_windows*B, window_size, window_size, C) 77 | window_size (int): Window size 78 | H (int): Height of image 79 | W (int): Width of image 80 | Returns: 81 | x: (B, H, W, C) 82 | """ 83 | B = int(windows.shape[0] / (H * W / 
window_size / window_size)) 84 | x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1) 85 | x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) 86 | return x 87 | 88 | 89 | class WindowAttention(nn.Module): 90 | """ Window based multi-head self attention (W-MSA) module with relative position bias. 91 | It supports both of shifted and non-shifted window. 92 | Args: 93 | dim (int): Number of input channels. 94 | window_size (tuple[int]): The height and width of the window. 95 | num_heads (int): Number of attention heads. 96 | qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True 97 | qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set 98 | attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 99 | proj_drop (float, optional): Dropout ratio of output. Default: 0.0 100 | """ 101 | 102 | def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.): 103 | 104 | super().__init__() 105 | self.dim = dim 106 | self.window_size = window_size # Wh, Ww 107 | self.num_heads = num_heads 108 | head_dim = dim // num_heads 109 | self.scale = qk_scale or head_dim ** -0.5 110 | 111 | # define a parameter table of relative position bias 112 | self.relative_position_bias_table = nn.Parameter( 113 | torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH 114 | 115 | # get pair-wise relative position index for each token inside the window 116 | coords_h = torch.arange(self.window_size[0]) 117 | coords_w = torch.arange(self.window_size[1]) 118 | coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww 119 | coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww 120 | relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww 121 | relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 122 | relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 123 | relative_coords[:, :, 1] += self.window_size[1] - 1 124 | relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 125 | relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww 126 | self.register_buffer("relative_position_index", relative_position_index) 127 | 128 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 129 | self.attn_drop = nn.Dropout(attn_drop) 130 | self.proj = nn.Linear(dim, dim) 131 | self.proj_drop = nn.Dropout(proj_drop) 132 | 133 | trunc_normal_(self.relative_position_bias_table, std=.02) 134 | self.softmax = nn.Softmax(dim=-1) 135 | 136 | def forward(self, x, mask=None): 137 | """ Forward function. 
138 | Args: 139 | x: input features with shape of (num_windows*B, N, C) 140 | mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None 141 | """ 142 | B_, N, C = x.shape 143 | qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 144 | q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) 145 | 146 | q = q * self.scale 147 | attn = (q @ k.transpose(-2, -1)) 148 | 149 | relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view( 150 | self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH 151 | relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww 152 | attn = attn + relative_position_bias.unsqueeze(0) 153 | 154 | if mask is not None: 155 | nW = mask.shape[0] 156 | attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) 157 | attn = attn.view(-1, self.num_heads, N, N) 158 | attn = self.softmax(attn) 159 | else: 160 | attn = self.softmax(attn) 161 | 162 | attn = self.attn_drop(attn) 163 | 164 | x = (attn @ v).transpose(1, 2).reshape(B_, N, C) 165 | x = self.proj(x) 166 | x = self.proj_drop(x) 167 | return x 168 | 169 | 170 | class SwinTransformerBlock(nn.Module): 171 | """ Swin Transformer Block. 172 | Args: 173 | dim (int): Number of input channels. 174 | num_heads (int): Number of attention heads. 175 | window_size (int): Window size. 176 | shift_size (int): Shift size for SW-MSA. 177 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. 178 | qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True 179 | qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. 180 | drop (float, optional): Dropout rate. Default: 0.0 181 | attn_drop (float, optional): Attention dropout rate. Default: 0.0 182 | drop_path (float, optional): Stochastic depth rate. Default: 0.0 183 | act_layer (nn.Module, optional): Activation layer. Default: nn.GELU 184 | norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm 185 | """ 186 | 187 | def __init__(self, dim, num_heads, window_size=7, shift_size=0, 188 | mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0., 189 | act_layer=nn.GELU, norm_layer=nn.LayerNorm): 190 | super().__init__() 191 | self.dim = dim 192 | self.num_heads = num_heads 193 | self.window_size = window_size 194 | self.shift_size = shift_size 195 | self.mlp_ratio = mlp_ratio 196 | assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size" 197 | 198 | self.norm1 = norm_layer(dim) 199 | self.attn = WindowAttention( 200 | dim, window_size=to_2tuple(self.window_size), num_heads=num_heads, 201 | qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) 202 | 203 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() 204 | self.norm2 = norm_layer(dim) 205 | mlp_hidden_dim = int(dim * mlp_ratio) 206 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) 207 | 208 | self.H = None 209 | self.W = None 210 | 211 | def forward(self, x, mask_matrix): 212 | """ Forward function. 213 | Args: 214 | x: Input feature, tensor size (B, H*W, C). 215 | H, W: Spatial resolution of the input feature. 216 | mask_matrix: Attention mask for cyclic shift. 
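        Returns:
            x: Output feature, tensor size (B, H*W, C).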
217 | """ 218 | B, L, C = x.shape 219 | H, W = self.H, self.W 220 | assert L == H * W, "input feature has wrong size" 221 | 222 | shortcut = x 223 | x = self.norm1(x) 224 | x = x.view(B, H, W, C) 225 | 226 | # pad feature maps to multiples of window size 227 | pad_l = pad_t = 0 228 | pad_r = (self.window_size - W % self.window_size) % self.window_size 229 | pad_b = (self.window_size - H % self.window_size) % self.window_size 230 | x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b)) 231 | _, Hp, Wp, _ = x.shape 232 | 233 | # cyclic shift 234 | if self.shift_size > 0: 235 | shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) 236 | attn_mask = mask_matrix 237 | else: 238 | shifted_x = x 239 | attn_mask = None 240 | 241 | # partition windows 242 | x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C 243 | x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C 244 | 245 | # W-MSA/SW-MSA 246 | attn_windows = self.attn(x_windows, mask=attn_mask) # nW*B, window_size*window_size, C 247 | 248 | # merge windows 249 | attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C) 250 | shifted_x = window_reverse(attn_windows, self.window_size, Hp, Wp) # B H' W' C 251 | 252 | # reverse cyclic shift 253 | if self.shift_size > 0: 254 | x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) 255 | else: 256 | x = shifted_x 257 | 258 | if pad_r > 0 or pad_b > 0: 259 | x = x[:, :H, :W, :].contiguous() 260 | 261 | x = x.view(B, H * W, C) 262 | 263 | # FFN 264 | x = shortcut + self.drop_path(x) 265 | x = x + self.drop_path(self.mlp(self.norm2(x))) 266 | 267 | return x 268 | 269 | 270 | class PatchMerging(nn.Module): 271 | """ Patch Merging Layer 272 | Args: 273 | dim (int): Number of input channels. 274 | norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm 275 | """ 276 | 277 | def __init__(self, dim, norm_layer=nn.LayerNorm): 278 | super().__init__() 279 | self.dim = dim 280 | self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) 281 | self.norm = norm_layer(4 * dim) 282 | 283 | def forward(self, x, H, W): 284 | """ Forward function. 285 | Args: 286 | x: Input feature, tensor size (B, H*W, C). 287 | H, W: Spatial resolution of the input feature. 288 | """ 289 | B, L, C = x.shape 290 | assert L == H * W, "input feature has wrong size" 291 | 292 | x = x.view(B, H, W, C) 293 | 294 | # padding 295 | pad_input = (H % 2 == 1) or (W % 2 == 1) 296 | if pad_input: 297 | x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2)) 298 | 299 | x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C 300 | x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C 301 | x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C 302 | x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C 303 | x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C 304 | x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C 305 | 306 | x = self.norm(x) 307 | x = self.reduction(x) 308 | 309 | return x 310 | 311 | 312 | class BasicLayer(nn.Module): 313 | """ A basic Swin Transformer layer for one stage. 314 | Args: 315 | dim (int): Number of feature channels 316 | depth (int): Depths of this stage. 317 | num_heads (int): Number of attention head. 318 | window_size (int): Local window size. Default: 7. 319 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4. 320 | qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. 
Default: True 321 | qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. 322 | drop (float, optional): Dropout rate. Default: 0.0 323 | attn_drop (float, optional): Attention dropout rate. Default: 0.0 324 | drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 325 | norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm 326 | downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None 327 | use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. 328 | """ 329 | 330 | def __init__(self, 331 | dim, 332 | depth, 333 | num_heads, 334 | window_size=7, 335 | mlp_ratio=4., 336 | qkv_bias=True, 337 | qk_scale=None, 338 | drop=0., 339 | attn_drop=0., 340 | drop_path=0., 341 | norm_layer=nn.LayerNorm, 342 | downsample=None, 343 | use_checkpoint=False): 344 | super().__init__() 345 | self.window_size = window_size 346 | self.shift_size = window_size // 2 347 | self.depth = depth 348 | self.use_checkpoint = use_checkpoint 349 | 350 | # build blocks 351 | self.blocks = nn.ModuleList([ 352 | SwinTransformerBlock( 353 | dim=dim, 354 | num_heads=num_heads, 355 | window_size=window_size, 356 | shift_size=0 if (i % 2 == 0) else window_size // 2, 357 | mlp_ratio=mlp_ratio, 358 | qkv_bias=qkv_bias, 359 | qk_scale=qk_scale, 360 | drop=drop, 361 | attn_drop=attn_drop, 362 | drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, 363 | norm_layer=norm_layer) 364 | for i in range(depth)]) 365 | 366 | # patch merging layer 367 | if downsample is not None: 368 | self.downsample = downsample(dim=dim, norm_layer=norm_layer) 369 | else: 370 | self.downsample = None 371 | 372 | def forward(self, x, H, W): 373 | """ Forward function. 374 | Args: 375 | x: Input feature, tensor size (B, H*W, C). 376 | H, W: Spatial resolution of the input feature. 377 | """ 378 | 379 | # calculate attention mask for SW-MSA 380 | Hp = int(np.ceil(H / self.window_size)) * self.window_size 381 | Wp = int(np.ceil(W / self.window_size)) * self.window_size 382 | img_mask = torch.zeros((1, Hp, Wp, 1), device=x.device) # 1 Hp Wp 1 383 | h_slices = (slice(0, -self.window_size), 384 | slice(-self.window_size, -self.shift_size), 385 | slice(-self.shift_size, None)) 386 | w_slices = (slice(0, -self.window_size), 387 | slice(-self.window_size, -self.shift_size), 388 | slice(-self.shift_size, None)) 389 | cnt = 0 390 | for h in h_slices: 391 | for w in w_slices: 392 | img_mask[:, h, w, :] = cnt 393 | cnt += 1 394 | 395 | mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1 396 | mask_windows = mask_windows.view(-1, self.window_size * self.window_size) 397 | attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) 398 | attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) 399 | 400 | for blk in self.blocks: 401 | blk.H, blk.W = H, W 402 | if self.use_checkpoint: 403 | x = checkpoint.checkpoint(blk, x, attn_mask) 404 | else: 405 | x = blk(x, attn_mask) 406 | if self.downsample is not None: 407 | x_down = self.downsample(x, H, W) 408 | Wh, Ww = (H + 1) // 2, (W + 1) // 2 409 | return x, H, W, x_down, Wh, Ww 410 | else: 411 | return x, H, W, x, H, W 412 | 413 | 414 | class PatchEmbed(nn.Module): 415 | """ Image to Patch Embedding 416 | Args: 417 | patch_size (int): Patch token size. Default: 4. 418 | in_chans (int): Number of input image channels. Default: 3. 
419 | embed_dim (int): Number of linear projection output channels. Default: 96. 420 | norm_layer (nn.Module, optional): Normalization layer. Default: None 421 | """ 422 | 423 | def __init__(self, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None): 424 | super().__init__() 425 | patch_size = to_2tuple(patch_size) 426 | self.patch_size = patch_size 427 | 428 | self.in_chans = in_chans 429 | self.embed_dim = embed_dim 430 | 431 | self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) 432 | if norm_layer is not None: 433 | self.norm = norm_layer(embed_dim) 434 | else: 435 | self.norm = None 436 | 437 | def forward(self, x): 438 | """Forward function.""" 439 | # padding 440 | _, _, H, W = x.size() 441 | if W % self.patch_size[1] != 0: 442 | x = F.pad(x, (0, self.patch_size[1] - W % self.patch_size[1])) 443 | if H % self.patch_size[0] != 0: 444 | x = F.pad(x, (0, 0, 0, self.patch_size[0] - H % self.patch_size[0])) 445 | 446 | x = self.proj(x) # B C Wh Ww 447 | if self.norm is not None: 448 | Wh, Ww = x.size(2), x.size(3) 449 | x = x.flatten(2).transpose(1, 2) 450 | x = self.norm(x) 451 | x = x.transpose(1, 2).view(-1, self.embed_dim, Wh, Ww) 452 | 453 | return x 454 | 455 | 456 | class SwinTransformer(nn.Module): 457 | """ Swin Transformer backbone. 458 | A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` - 459 | https://arxiv.org/pdf/2103.14030 460 | Args: 461 | pretrain_img_size (int): Input image size for training the pretrained model, 462 | used in absolute postion embedding. Default 224. 463 | patch_size (int | tuple(int)): Patch size. Default: 4. 464 | in_chans (int): Number of input image channels. Default: 3. 465 | embed_dim (int): Number of linear projection output channels. Default: 96. 466 | depths (tuple[int]): Depths of each Swin Transformer stage. 467 | num_heads (tuple[int]): Number of attention head of each stage. 468 | window_size (int): Window size. Default: 7. 469 | mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4. 470 | qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True 471 | qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. 472 | drop_rate (float): Dropout rate. 473 | attn_drop_rate (float): Attention dropout rate. Default: 0. 474 | drop_path_rate (float): Stochastic depth rate. Default: 0.2. 475 | norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. 476 | ape (bool): If True, add absolute position embedding to the patch embedding. Default: False. 477 | patch_norm (bool): If True, add normalization after patch embedding. Default: True. 478 | out_indices (Sequence[int]): Output from which stages. 479 | frozen_stages (int): Stages to be frozen (stop grad and set eval mode). 480 | -1 means not freezing any parameters. 481 | use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. 
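
    Note:
        The forward pass returns a tuple of feature maps: the raw patch embedding of shape
        (B, embed_dim, Wh, Ww) followed by one (B, C_i, H_i, W_i) map for every stage index
        in `out_indices`.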
482 | """ 483 | 484 | def __init__(self, 485 | pretrain_img_size=224, 486 | patch_size=4, 487 | in_chans=3, 488 | embed_dim=96, 489 | depths=[2, 2, 6, 2], 490 | num_heads=[3, 6, 12, 24], 491 | window_size=7, 492 | mlp_ratio=4., 493 | qkv_bias=True, 494 | qk_scale=None, 495 | drop_rate=0., 496 | attn_drop_rate=0., 497 | drop_path_rate=0.2, 498 | norm_layer=nn.LayerNorm, 499 | ape=False, 500 | patch_norm=True, 501 | out_indices=(0, 1, 2, 3), 502 | frozen_stages=-1, 503 | use_checkpoint=False): 504 | super().__init__() 505 | 506 | self.pretrain_img_size = pretrain_img_size 507 | self.num_layers = len(depths) 508 | self.embed_dim = embed_dim 509 | self.ape = ape 510 | self.patch_norm = patch_norm 511 | self.out_indices = out_indices 512 | self.frozen_stages = frozen_stages 513 | 514 | # split image into non-overlapping patches 515 | self.patch_embed = PatchEmbed( 516 | patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, 517 | norm_layer=norm_layer if self.patch_norm else None) 518 | 519 | # absolute position embedding 520 | if self.ape: 521 | pretrain_img_size = to_2tuple(pretrain_img_size) 522 | patch_size = to_2tuple(patch_size) 523 | patches_resolution = [pretrain_img_size[0] // patch_size[0], pretrain_img_size[1] // patch_size[1]] 524 | 525 | self.absolute_pos_embed = nn.Parameter( 526 | torch.zeros(1, embed_dim, patches_resolution[0], patches_resolution[1])) 527 | trunc_normal_(self.absolute_pos_embed, std=.02) 528 | 529 | self.pos_drop = nn.Dropout(p=drop_rate) 530 | 531 | # stochastic depth 532 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule 533 | 534 | # build layers 535 | self.layers = nn.ModuleList() 536 | for i_layer in range(self.num_layers): 537 | layer = BasicLayer( 538 | dim=int(embed_dim * 2 ** i_layer), 539 | depth=depths[i_layer], 540 | num_heads=num_heads[i_layer], 541 | window_size=window_size, 542 | mlp_ratio=mlp_ratio, 543 | qkv_bias=qkv_bias, 544 | qk_scale=qk_scale, 545 | drop=drop_rate, 546 | attn_drop=attn_drop_rate, 547 | drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], 548 | norm_layer=norm_layer, 549 | downsample=PatchMerging if (i_layer < self.num_layers - 1) else None, 550 | use_checkpoint=use_checkpoint) 551 | self.layers.append(layer) 552 | 553 | num_features = [int(embed_dim * 2 ** i) for i in range(self.num_layers)] 554 | self.num_features = num_features 555 | 556 | # add a norm layer for each output 557 | for i_layer in out_indices: 558 | layer = norm_layer(num_features[i_layer]) 559 | layer_name = f'norm{i_layer}' 560 | self.add_module(layer_name, layer) 561 | 562 | self._freeze_stages() 563 | 564 | def _freeze_stages(self): 565 | if self.frozen_stages >= 0: 566 | self.patch_embed.eval() 567 | for param in self.patch_embed.parameters(): 568 | param.requires_grad = False 569 | 570 | if self.frozen_stages >= 1 and self.ape: 571 | self.absolute_pos_embed.requires_grad = False 572 | 573 | if self.frozen_stages >= 2: 574 | self.pos_drop.eval() 575 | for i in range(0, self.frozen_stages - 1): 576 | m = self.layers[i] 577 | m.eval() 578 | for param in m.parameters(): 579 | param.requires_grad = False 580 | 581 | def forward(self, x): 582 | 583 | x = self.patch_embed(x) 584 | 585 | Wh, Ww = x.size(2), x.size(3) 586 | if self.ape: 587 | # interpolate the position embedding to the corresponding size 588 | absolute_pos_embed = F.interpolate(self.absolute_pos_embed, size=(Wh, Ww), mode='bicubic') 589 | x = (x + absolute_pos_embed) # B Wh*Ww C 590 | 591 | outs = [x.contiguous()] 592 | x = 
x.flatten(2).transpose(1, 2) 593 | x = self.pos_drop(x) 594 | 595 | for i in range(self.num_layers): 596 | layer = self.layers[i] 597 | x_out, H, W, x, Wh, Ww = layer(x, Wh, Ww) 598 | 599 | if i in self.out_indices: 600 | norm_layer = getattr(self, f'norm{i}') 601 | x_out = norm_layer(x_out) 602 | 603 | out = x_out.view(-1, H, W, self.num_features[i]).permute(0, 3, 1, 2).contiguous() 604 | outs.append(out) 605 | 606 | return tuple(outs) 607 | 608 | 609 | def get_activation_fn(activation): 610 | """Return an activation function given a string""" 611 | if activation == "gelu": 612 | return F.gelu 613 | 614 | raise RuntimeError(F"activation should be gelu, not {activation}.") 615 | 616 | 617 | def make_cbr(in_dim, out_dim): 618 | return nn.Sequential(nn.Conv2d(in_dim, out_dim, kernel_size=3, padding=1), nn.InstanceNorm2d(out_dim), nn.GELU()) 619 | 620 | 621 | def make_cbg(in_dim, out_dim): 622 | return nn.Sequential(nn.Conv2d(in_dim, out_dim, kernel_size=3, padding=1), nn.InstanceNorm2d(out_dim), nn.GELU()) 623 | 624 | 625 | def rescale_to(x, scale_factor: float = 2, interpolation='nearest'): 626 | return F.interpolate(x, scale_factor=scale_factor, mode=interpolation) 627 | 628 | 629 | def resize_as(x, y, interpolation='bilinear'): 630 | return F.interpolate(x, size=y.shape[-2:], mode=interpolation) 631 | 632 | 633 | def image2patches(x): 634 | """b c (hg h) (wg w) -> (hg wg b) c h w""" 635 | x = rearrange(x, 'b c (hg h) (wg w) -> (hg wg b) c h w', hg=2, wg=2) 636 | return x 637 | 638 | 639 | def patches2image(x): 640 | """(hg wg b) c h w -> b c (hg h) (wg w)""" 641 | x = rearrange(x, '(hg wg b) c h w -> b c (hg h) (wg w)', hg=2, wg=2) 642 | return x 643 | 644 | 645 | class PositionEmbeddingSine: 646 | def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None): 647 | super().__init__() 648 | self.num_pos_feats = num_pos_feats 649 | self.temperature = temperature 650 | self.normalize = normalize 651 | if scale is not None and normalize is False: 652 | raise ValueError("normalize should be True if scale is passed") 653 | if scale is None: 654 | scale = 2 * math.pi 655 | self.scale = scale 656 | self.dim_t = torch.arange(0, self.num_pos_feats, dtype=torch.float32) 657 | 658 | def __call__(self, b, h, w): 659 | device = self.dim_t.device 660 | mask = torch.zeros([b, h, w], dtype=torch.bool, device=device) 661 | assert mask is not None 662 | not_mask = ~mask 663 | y_embed = not_mask.cumsum(dim=1, dtype=torch.float32) 664 | x_embed = not_mask.cumsum(dim=2, dtype=torch.float32) 665 | if self.normalize: 666 | eps = 1e-6 667 | y_embed = (y_embed - 0.5) / (y_embed[:, -1:, :] + eps) * self.scale 668 | x_embed = (x_embed - 0.5) / (x_embed[:, :, -1:] + eps) * self.scale 669 | 670 | dim_t = self.temperature ** (2 * (self.dim_t.to(device) // 2) / self.num_pos_feats) 671 | pos_x = x_embed[:, :, :, None] / dim_t 672 | pos_y = y_embed[:, :, :, None] / dim_t 673 | 674 | pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3) 675 | pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3) 676 | 677 | return torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2) 678 | 679 | 680 | class PositionEmbeddingSine: 681 | def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None): 682 | super().__init__() 683 | self.num_pos_feats = num_pos_feats 684 | self.temperature = temperature 685 | self.normalize = normalize 686 | if scale is not None and normalize is False: 687 | raise 
ValueError("normalize should be True if scale is passed") 688 | if scale is None: 689 | scale = 2 * math.pi 690 | self.scale = scale 691 | self.dim_t = torch.arange(0, self.num_pos_feats, dtype=torch.float32) 692 | 693 | def __call__(self, b, h, w): 694 | device = self.dim_t.device 695 | mask = torch.zeros([b, h, w], dtype=torch.bool, device=device) 696 | assert mask is not None 697 | not_mask = ~mask 698 | y_embed = not_mask.cumsum(dim=1, dtype=torch.float32) 699 | x_embed = not_mask.cumsum(dim=2, dtype=torch.float32) 700 | if self.normalize: 701 | eps = 1e-6 702 | y_embed = (y_embed - 0.5) / (y_embed[:, -1:, :] + eps) * self.scale 703 | x_embed = (x_embed - 0.5) / (x_embed[:, :, -1:] + eps) * self.scale 704 | 705 | dim_t = self.temperature ** (2 * (self.dim_t.to(device) // 2) / self.num_pos_feats) 706 | pos_x = x_embed[:, :, :, None] / dim_t 707 | pos_y = y_embed[:, :, :, None] / dim_t 708 | 709 | pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3) 710 | pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3) 711 | 712 | return torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2) 713 | 714 | 715 | class MCLM(nn.Module): 716 | def __init__(self, d_model, num_heads, pool_ratios=[1, 4, 8]): 717 | super(MCLM, self).__init__() 718 | self.attention = nn.ModuleList([ 719 | nn.MultiheadAttention(d_model, num_heads, dropout=0.1), 720 | nn.MultiheadAttention(d_model, num_heads, dropout=0.1), 721 | nn.MultiheadAttention(d_model, num_heads, dropout=0.1), 722 | nn.MultiheadAttention(d_model, num_heads, dropout=0.1), 723 | nn.MultiheadAttention(d_model, num_heads, dropout=0.1) 724 | ]) 725 | 726 | self.linear1 = nn.Linear(d_model, d_model * 2) 727 | self.linear2 = nn.Linear(d_model * 2, d_model) 728 | self.linear3 = nn.Linear(d_model, d_model * 2) 729 | self.linear4 = nn.Linear(d_model * 2, d_model) 730 | self.norm1 = nn.LayerNorm(d_model) 731 | self.norm2 = nn.LayerNorm(d_model) 732 | self.dropout = nn.Dropout(0.1) 733 | self.dropout1 = nn.Dropout(0.1) 734 | self.dropout2 = nn.Dropout(0.1) 735 | self.activation = get_activation_fn('gelu') 736 | self.pool_ratios = pool_ratios 737 | self.p_poses = [] 738 | self.g_pos = None 739 | self.positional_encoding = PositionEmbeddingSine(num_pos_feats=d_model // 2, normalize=True) 740 | 741 | def forward(self, l, g): 742 | """ 743 | l: 4,c,h,w 744 | g: 1,c,h,w 745 | """ 746 | self.p_poses = [] 747 | self.g_pos = None 748 | b, c, h, w = l.size() 749 | # 4,c,h,w -> 1,c,2h,2w 750 | concated_locs = rearrange(l, '(hg wg b) c h w -> b c (hg h) (wg w)', hg=2, wg=2) 751 | 752 | pools = [] 753 | for pool_ratio in self.pool_ratios: 754 | # b,c,h,w 755 | tgt_hw = (round(h / pool_ratio), round(w / pool_ratio)) 756 | pool = F.adaptive_avg_pool2d(concated_locs, tgt_hw) 757 | pools.append(rearrange(pool, 'b c h w -> (h w) b c')) 758 | if self.g_pos is None: 759 | pos_emb = self.positional_encoding(pool.shape[0], pool.shape[2], pool.shape[3]) 760 | pos_emb = rearrange(pos_emb, 'b c h w -> (h w) b c') 761 | self.p_poses.append(pos_emb) 762 | pools = torch.cat(pools, 0) 763 | if self.g_pos is None: 764 | self.p_poses = torch.cat(self.p_poses, dim=0) 765 | pos_emb = self.positional_encoding(g.shape[0], g.shape[2], g.shape[3]) 766 | self.g_pos = rearrange(pos_emb, 'b c h w -> (h w) b c') 767 | 768 | device = pools.device 769 | self.p_poses = self.p_poses.to(device) 770 | self.g_pos = self.g_pos.to(device) 771 | 772 | # attention between glb (q) & multisensory concated-locs (k,v) 773 | g_hw_b_c = 
rearrange(g, 'b c h w -> (h w) b c') 774 | 775 | g_hw_b_c = g_hw_b_c + self.dropout1(self.attention[0](g_hw_b_c + self.g_pos, pools + self.p_poses, pools)[0]) 776 | g_hw_b_c = self.norm1(g_hw_b_c) 777 | g_hw_b_c = g_hw_b_c + self.dropout2(self.linear2(self.dropout(self.activation(self.linear1(g_hw_b_c)).clone()))) 778 | g_hw_b_c = self.norm2(g_hw_b_c) 779 | 780 | # attention between origin locs (q) & freashed glb (k,v) 781 | l_hw_b_c = rearrange(l, "b c h w -> (h w) b c") 782 | _g_hw_b_c = rearrange(g_hw_b_c, '(h w) b c -> h w b c', h=h, w=w) 783 | _g_hw_b_c = rearrange(_g_hw_b_c, "(ng h) (nw w) b c -> (h w) (ng nw b) c", ng=2, nw=2) 784 | outputs_re = [] 785 | for i, (_l, _g) in enumerate(zip(l_hw_b_c.chunk(4, dim=1), _g_hw_b_c.chunk(4, dim=1))): 786 | outputs_re.append(self.attention[i + 1](_l, _g, _g)[0]) # (h w) 1 c 787 | outputs_re = torch.cat(outputs_re, 1) # (h w) 4 c 788 | 789 | l_hw_b_c = l_hw_b_c + self.dropout1(outputs_re) 790 | l_hw_b_c = self.norm1(l_hw_b_c) 791 | l_hw_b_c = l_hw_b_c + self.dropout2(self.linear4(self.dropout(self.activation(self.linear3(l_hw_b_c)).clone()))) 792 | l_hw_b_c = self.norm2(l_hw_b_c) 793 | 794 | l = torch.cat((l_hw_b_c, g_hw_b_c), 1) # hw,b(5),c 795 | return rearrange(l, "(h w) b c -> b c h w", h=h, w=w) ## (5,c,h*w) 796 | 797 | 798 | class MCRM(nn.Module): 799 | def __init__(self, d_model, num_heads, pool_ratios=[4, 8, 16], h=None): 800 | super(MCRM, self).__init__() 801 | self.attention = nn.ModuleList([ 802 | nn.MultiheadAttention(d_model, num_heads, dropout=0.1), 803 | nn.MultiheadAttention(d_model, num_heads, dropout=0.1), 804 | nn.MultiheadAttention(d_model, num_heads, dropout=0.1), 805 | nn.MultiheadAttention(d_model, num_heads, dropout=0.1) 806 | ]) 807 | self.linear3 = nn.Linear(d_model, d_model * 2) 808 | self.linear4 = nn.Linear(d_model * 2, d_model) 809 | self.norm1 = nn.LayerNorm(d_model) 810 | self.norm2 = nn.LayerNorm(d_model) 811 | self.dropout = nn.Dropout(0.1) 812 | self.dropout1 = nn.Dropout(0.1) 813 | self.dropout2 = nn.Dropout(0.1) 814 | self.sigmoid = nn.Sigmoid() 815 | self.activation = get_activation_fn('gelu') 816 | self.sal_conv = nn.Conv2d(d_model, 1, 1) 817 | self.pool_ratios = pool_ratios 818 | 819 | def forward(self, x): 820 | device = x.device 821 | b, c, h, w = x.size() 822 | loc, glb = x.split([4, 1], dim=0) # 4,c,h,w; 1,c,h,w 823 | 824 | patched_glb = rearrange(glb, 'b c (hg h) (wg w) -> (hg wg b) c h w', hg=2, wg=2) 825 | 826 | token_attention_map = self.sigmoid(self.sal_conv(glb)) 827 | token_attention_map = F.interpolate(token_attention_map, size=patches2image(loc).shape[-2:], mode='nearest') 828 | loc = loc * rearrange(token_attention_map, 'b c (hg h) (wg w) -> (hg wg b) c h w', hg=2, wg=2) 829 | 830 | pools = [] 831 | for pool_ratio in self.pool_ratios: 832 | tgt_hw = (round(h / pool_ratio), round(w / pool_ratio)) 833 | pool = F.adaptive_avg_pool2d(patched_glb, tgt_hw) 834 | pools.append(rearrange(pool, 'nl c h w -> nl c (h w)')) # nl(4),c,hw 835 | 836 | pools = rearrange(torch.cat(pools, 2), "nl c nphw -> nl nphw 1 c") 837 | loc_ = rearrange(loc, 'nl c h w -> nl (h w) 1 c') 838 | 839 | outputs = [] 840 | for i, q in enumerate(loc_.unbind(dim=0)): # traverse all local patches 841 | v = pools[i] 842 | k = v 843 | outputs.append(self.attention[i](q, k, v)[0]) 844 | 845 | outputs = torch.cat(outputs, 1) 846 | src = loc.view(4, c, -1).permute(2, 0, 1) + self.dropout1(outputs) 847 | src = self.norm1(src) 848 | src = src + self.dropout2(self.linear4(self.dropout(self.activation(self.linear3(src)).clone()))) 849 | 
src = self.norm2(src) 850 | src = src.permute(1, 2, 0).reshape(4, c, h, w) # freshed loc 851 | glb = glb + F.interpolate(patches2image(src), size=glb.shape[-2:], mode='nearest') # freshed glb 852 | 853 | return torch.cat((src, glb), 0), token_attention_map 854 | 855 | 856 | class BEN_Base(nn.Module): 857 | def __init__(self): 858 | super().__init__() 859 | 860 | self.backbone = SwinTransformer(embed_dim=128, depths=[2, 2, 18, 2], num_heads=[4, 8, 16, 32], window_size=12) 861 | emb_dim = 128 862 | self.sideout5 = nn.Sequential(nn.Conv2d(emb_dim, 1, kernel_size=3, padding=1)) 863 | self.sideout4 = nn.Sequential(nn.Conv2d(emb_dim, 1, kernel_size=3, padding=1)) 864 | self.sideout3 = nn.Sequential(nn.Conv2d(emb_dim, 1, kernel_size=3, padding=1)) 865 | self.sideout2 = nn.Sequential(nn.Conv2d(emb_dim, 1, kernel_size=3, padding=1)) 866 | self.sideout1 = nn.Sequential(nn.Conv2d(emb_dim, 1, kernel_size=3, padding=1)) 867 | 868 | self.output5 = make_cbr(1024, emb_dim) 869 | self.output4 = make_cbr(512, emb_dim) 870 | self.output3 = make_cbr(256, emb_dim) 871 | self.output2 = make_cbr(128, emb_dim) 872 | self.output1 = make_cbr(128, emb_dim) 873 | 874 | self.multifieldcrossatt = MCLM(emb_dim, 1, [1, 4, 8]) 875 | self.conv1 = make_cbr(emb_dim, emb_dim) 876 | self.conv2 = make_cbr(emb_dim, emb_dim) 877 | self.conv3 = make_cbr(emb_dim, emb_dim) 878 | self.conv4 = make_cbr(emb_dim, emb_dim) 879 | self.dec_blk1 = MCRM(emb_dim, 1, [2, 4, 8]) 880 | self.dec_blk2 = MCRM(emb_dim, 1, [2, 4, 8]) 881 | self.dec_blk3 = MCRM(emb_dim, 1, [2, 4, 8]) 882 | self.dec_blk4 = MCRM(emb_dim, 1, [2, 4, 8]) 883 | 884 | self.insmask_head = nn.Sequential( 885 | nn.Conv2d(emb_dim, 384, kernel_size=3, padding=1), 886 | nn.InstanceNorm2d(384), 887 | nn.GELU(), 888 | nn.Conv2d(384, 384, kernel_size=3, padding=1), 889 | nn.InstanceNorm2d(384), 890 | nn.GELU(), 891 | nn.Conv2d(384, emb_dim, kernel_size=3, padding=1) 892 | ) 893 | 894 | self.shallow = nn.Sequential(nn.Conv2d(3, emb_dim, kernel_size=3, padding=1)) 895 | self.upsample1 = make_cbg(emb_dim, emb_dim) 896 | self.upsample2 = make_cbg(emb_dim, emb_dim) 897 | self.output = nn.Sequential(nn.Conv2d(emb_dim, 1, kernel_size=3, padding=1)) 898 | 899 | for m in self.modules(): 900 | if isinstance(m, nn.GELU) or isinstance(m, nn.Dropout): 901 | m.inplace = True 902 | 903 | @torch.inference_mode() 904 | @torch.autocast(device_type="cuda", dtype=torch.float16) 905 | def forward(self, x): 906 | real_batch = x.size(0) 907 | 908 | shallow_batch = self.shallow(x) 909 | glb_batch = rescale_to(x, scale_factor=0.5, interpolation='bilinear') 910 | 911 | final_input = None 912 | for i in range(real_batch): 913 | start = i * 4 914 | end = (i + 1) * 4 915 | loc_batch = image2patches(x[i, :, :, :].unsqueeze(dim=0)) 916 | input_ = torch.cat((loc_batch, glb_batch[i, :, :, :].unsqueeze(dim=0)), dim=0) 917 | 918 | if final_input == None: 919 | final_input = input_ 920 | else: 921 | final_input = torch.cat((final_input, input_), dim=0) 922 | 923 | features = self.backbone(final_input) 924 | outputs = [] 925 | 926 | for i in range(real_batch): 927 | start = i * 5 928 | end = (i + 1) * 5 929 | 930 | f4 = features[4][start:end, :, :, :] # shape: [5, C, H, W] 931 | f3 = features[3][start:end, :, :, :] 932 | f2 = features[2][start:end, :, :, :] 933 | f1 = features[1][start:end, :, :, :] 934 | f0 = features[0][start:end, :, :, :] 935 | e5 = self.output5(f4) 936 | e4 = self.output4(f3) 937 | e3 = self.output3(f2) 938 | e2 = self.output2(f1) 939 | e1 = self.output1(f0) 940 | loc_e5, glb_e5 = e5.split([4, 1], 
dim=0) 941 | e5 = self.multifieldcrossatt(loc_e5, glb_e5) # (4,128,16,16) 942 | 943 | e4, tokenattmap4 = self.dec_blk4(e4 + resize_as(e5, e4)) 944 | e4 = self.conv4(e4) 945 | e3, tokenattmap3 = self.dec_blk3(e3 + resize_as(e4, e3)) 946 | e3 = self.conv3(e3) 947 | e2, tokenattmap2 = self.dec_blk2(e2 + resize_as(e3, e2)) 948 | e2 = self.conv2(e2) 949 | e1, tokenattmap1 = self.dec_blk1(e1 + resize_as(e2, e1)) 950 | e1 = self.conv1(e1) 951 | 952 | loc_e1, glb_e1 = e1.split([4, 1], dim=0) 953 | 954 | output1_cat = patches2image(loc_e1) # (1,128,256,256) 955 | 956 | # add glb feat in 957 | output1_cat = output1_cat + resize_as(glb_e1, output1_cat) 958 | # merge 959 | final_output = self.insmask_head(output1_cat) # (1,128,256,256) 960 | # shallow feature merge 961 | shallow = shallow_batch[i, :, :, :].unsqueeze(dim=0) 962 | final_output = final_output + resize_as(shallow, final_output) 963 | final_output = self.upsample1(rescale_to(final_output)) 964 | final_output = rescale_to(final_output + resize_as(shallow, final_output)) 965 | final_output = self.upsample2(final_output) 966 | final_output = self.output(final_output) 967 | mask = final_output.sigmoid() 968 | outputs.append(mask) 969 | 970 | return torch.cat(outputs, dim=0) 971 | 972 | def loadcheckpoints(self, model_path): 973 | model_dict = torch.load(model_path, map_location="cpu", weights_only=True) 974 | self.load_state_dict(model_dict['model_state_dict'], strict=True) 975 | del model_path 976 | 977 | def inference(self, image, refine_foreground=False): 978 | 979 | # set_random_seed(9) 980 | # image = ImageOps.exif_transpose(image) 981 | if isinstance(image, Image.Image): 982 | image, h, w, original_image = rgb_loader_refiner(image) 983 | if torch.cuda.is_available(): 984 | 985 | img_tensor = img_transform(image).unsqueeze(0).to(next(self.parameters()).device) 986 | else: 987 | img_tensor = img_transform32(image).unsqueeze(0).to(next(self.parameters()).device) 988 | 989 | with torch.no_grad(): 990 | res = self.forward(img_tensor) 991 | 992 | # Show Results 993 | if refine_foreground: 994 | 995 | pred_pil = transforms.ToPILImage()(res.squeeze()) 996 | image_masked = refine_foreground_process(original_image, pred_pil) 997 | 998 | image_masked.putalpha(pred_pil.resize(original_image.size)) 999 | return image_masked 1000 | 1001 | else: 1002 | alpha = postprocess_image(res, im_size=[w, h]) 1003 | pred_pil = transforms.ToPILImage()(alpha) 1004 | mask = pred_pil.resize(original_image.size) 1005 | original_image.putalpha(mask) 1006 | # mask = Image.fromarray(alpha) 1007 | 1008 | # Composite the cutout onto a white background 1009 | white_background = Image.new('RGB', original_image.size, (255, 255, 255)) 1010 | white_background.paste(original_image, mask=original_image.split()[3]) 1011 | original_image = white_background 1012 | 1013 | return original_image 1014 | 1015 | 1016 | else: 1017 | foregrounds = [] 1018 | for batch in image: 1019 | image, h, w, original_image = rgb_loader_refiner(batch) 1020 | if torch.cuda.is_available(): 1021 | 1022 | img_tensor = img_transform(image).unsqueeze(0).to(next(self.parameters()).device) 1023 | else: 1024 | img_tensor = img_transform32(image).unsqueeze(0).to(next(self.parameters()).device) 1025 | 1026 | with torch.no_grad(): 1027 | res = self.forward(img_tensor) 1028 | 1029 | if refine_foreground: 1030 | 1031 | pred_pil = transforms.ToPILImage()(res.squeeze()) 1032 | image_masked = refine_foreground_process(original_image, pred_pil) 1033 | 1034 | image_masked.putalpha(pred_pil.resize(original_image.size)) 1035 | 1036 | 
foregrounds.append(image_masked) 1037 | else: 1038 | alpha = postprocess_image(res, im_size=[w, h]) 1039 | pred_pil = transforms.ToPILImage()(alpha) 1040 | mask = pred_pil.resize(original_image.size) 1041 | original_image.putalpha(mask) 1042 | # mask = Image.fromarray(alpha) 1043 | foregrounds.append(original_image) 1044 | 1045 | return foregrounds 1046 | 1047 | def segment_video(self, video_path, output_path="./", fps=0, refine_foreground=False, batch=1, 1048 | print_frames_processed=True, webm=False, rgb_value=(0, 255, 0)): 1049 | 1050 | """ 1051 | Segments the given video to extract the foreground (with alpha) from each frame 1052 | and saves the result as either a WebM video (with alpha channel) or MP4 (with a 1053 | color background). 1054 | 1055 | Args: 1056 | video_path (str): 1057 | Path to the input video file. 1058 | 1059 | output_path (str, optional): 1060 | Directory (or full path) where the output video and/or files will be saved. 1061 | Defaults to "./". 1062 | 1063 | fps (int, optional): 1064 | The frames per second (FPS) to use for the output video. If 0 (default), the 1065 | original FPS of the input video is used. Otherwise, overrides it. 1066 | 1067 | refine_foreground (bool, optional): 1068 | Whether to run an additional “refine foreground” process on each frame. 1069 | Defaults to False. 1070 | 1071 | batch (int, optional): 1072 | Number of frames to process at once (inference batch size). Large batch sizes 1073 | may require more GPU memory. Defaults to 1. 1074 | 1075 | print_frames_processed (bool, optional): 1076 | If True (default), prints progress (how many frames have been processed) to 1077 | the console. 1078 | 1079 | webm (bool, optional): 1080 | If True, exports a WebM video with alpha channel (VP9 / yuva420p). 1081 | If False (default), exports an MP4 video composited over a solid color background. 1082 | 1083 | rgb_value (tuple, optional): 1084 | The RGB background color (e.g., green screen) used to composite frames when 1085 | saving to MP4. Defaults to (0, 255, 0). 1086 | 1087 | Returns: 1088 | None. Writes the output video(s) to disk in the specified format. 
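Example (illustrative sketch, not part of the original docstring; assumes a
BEN_Base instance whose weights were already loaded with loadcheckpoints(),
and an input file named "clip.mp4"):

    model = BEN_Base().to("cuda").eval()
    model.loadcheckpoints("./BEN2_Base.pth")  # assumed checkpoint path
    # MP4 composited over a green background:
    model.segment_video("clip.mp4", output_path="./", webm=False, rgb_value=(0, 255, 0))
    # WebM that keeps the alpha channel (requires ffmpeg on PATH):
    model.segment_video("clip.mp4", output_path="./", webm=True)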
1089 | """ 1090 | 1091 | cap = cv2.VideoCapture(video_path) 1092 | if not cap.isOpened(): 1093 | raise IOError(f"Cannot open video: {video_path}") 1094 | 1095 | original_fps = cap.get(cv2.CAP_PROP_FPS) 1096 | original_fps = 30 if original_fps == 0 else original_fps 1097 | fps = original_fps if fps == 0 else fps 1098 | 1099 | ret, first_frame = cap.read() 1100 | if not ret: 1101 | raise ValueError("No frames found in the video.") 1102 | height, width = first_frame.shape[:2] 1103 | cap.set(cv2.CAP_PROP_POS_FRAMES, 0) 1104 | 1105 | foregrounds = [] 1106 | frame_idx = 0 1107 | processed_count = 0 1108 | batch_frames = [] 1109 | total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) 1110 | 1111 | while True: 1112 | ret, frame = cap.read() 1113 | if not ret: 1114 | if batch_frames: 1115 | batch_results = self.inference(batch_frames, refine_foreground) 1116 | if isinstance(batch_results, Image.Image): 1117 | foregrounds.append(batch_results) 1118 | else: 1119 | foregrounds.extend(batch_results) 1120 | if print_frames_processed: 1121 | print(f"Processed frames {frame_idx - len(batch_frames) + 1} to {frame_idx} of {total_frames}") 1122 | break 1123 | 1124 | # Process every frame instead of using intervals 1125 | frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) 1126 | pil_frame = Image.fromarray(frame_rgb) 1127 | batch_frames.append(pil_frame) 1128 | 1129 | if len(batch_frames) == batch: 1130 | batch_results = self.inference(batch_frames, refine_foreground) 1131 | if isinstance(batch_results, Image.Image): 1132 | foregrounds.append(batch_results) 1133 | else: 1134 | foregrounds.extend(batch_results) 1135 | if print_frames_processed: 1136 | print(f"Processed frames {frame_idx - batch + 1} to {frame_idx} of {total_frames}") 1137 | batch_frames = [] 1138 | processed_count += batch 1139 | 1140 | frame_idx += 1 1141 | 1142 | if webm: 1143 | alpha_webm_path = os.path.join(output_path, "foreground.webm") 1144 | pil_images_to_webm_alpha(foregrounds, alpha_webm_path, fps=fps) 1145 | 1146 | else: 1147 | cap.release() 1148 | fg_output = os.path.join(output_path, 'foreground.mp4') 1149 | 1150 | pil_images_to_mp4(foregrounds, fg_output, fps=fps, rgb_value=rgb_value) 1151 | cv2.destroyAllWindows() 1152 | 1153 | try: 1154 | fg_audio_output = os.path.join(output_path, 'foreground_output_with_audio.mp4') 1155 | add_audio_to_video(fg_output, video_path, fg_audio_output) 1156 | except Exception as e: 1157 | print("No audio found in the original video") 1158 | print(e) 1159 | 1160 | 1161 | def rgb_loader_refiner(original_image): 1162 | h, w = original_image.size 1163 | 1164 | image = original_image 1165 | # Convert to RGB if necessary 1166 | if image.mode != 'RGB': 1167 | image = image.convert('RGB') 1168 | 1169 | # Resize the image 1170 | image = image.resize((1024, 1024), resample=Image.LANCZOS) 1171 | 1172 | return image.convert('RGB'), h, w, original_image 1173 | 1174 | 1175 | # Define the image transformation 1176 | img_transform = transforms.Compose([ 1177 | transforms.ToTensor(), 1178 | transforms.ConvertImageDtype(torch.float16), 1179 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 1180 | ]) 1181 | 1182 | img_transform32 = transforms.Compose([ 1183 | transforms.ToTensor(), 1184 | transforms.ConvertImageDtype(torch.float32), 1185 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 1186 | ]) 1187 | 1188 | 1189 | def pil_images_to_mp4(images, output_path, fps=24, rgb_value=(0, 255, 0)): 1190 | """ 1191 | Converts an array of PIL images to 
an MP4 video. 1192 | 1193 | Args: 1194 | images: List of PIL images 1195 | output_path: Path to save the MP4 file 1196 | fps: Frames per second (default: 24) 1197 | rgb_value: Background RGB color tuple (default: green (0, 255, 0)) 1198 | """ 1199 | if not images: 1200 | raise ValueError("No images provided to convert to MP4.") 1201 | 1202 | width, height = images[0].size 1203 | fourcc = cv2.VideoWriter_fourcc(*'mp4v') 1204 | video_writer = cv2.VideoWriter(output_path, fourcc, fps, (width, height)) 1205 | 1206 | for image in images: 1207 | # If image has alpha channel, composite onto the specified background color 1208 | if image.mode == 'RGBA': 1209 | # Create background image with specified RGB color 1210 | background = Image.new('RGB', image.size, rgb_value) 1211 | background = background.convert('RGBA') 1212 | # Composite the image onto the background 1213 | image = Image.alpha_composite(background, image) 1214 | image = image.convert('RGB') 1215 | else: 1216 | # Ensure RGB format for non-alpha images 1217 | image = image.convert('RGB') 1218 | 1219 | # Convert to OpenCV format and write 1220 | open_cv_image = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR) 1221 | video_writer.write(open_cv_image) 1222 | 1223 | video_writer.release() 1224 | 1225 | 1226 | def pil_images_to_webm_alpha(images, output_path, fps=30): 1227 | """ 1228 | Converts a list of PIL RGBA images to a VP9 .webm video with alpha channel. 1229 | 1230 | NOTE: Not all players will display alpha in WebM. 1231 | Browsers like Chrome/Firefox typically do support VP9 alpha. 1232 | """ 1233 | if not images: 1234 | raise ValueError("No images provided for WebM with alpha.") 1235 | 1236 | # Ensure output directory exists 1237 | os.makedirs(os.path.dirname(output_path), exist_ok=True) 1238 | 1239 | with tempfile.TemporaryDirectory() as tmpdir: 1240 | # Save frames as PNG (with alpha) 1241 | for idx, img in enumerate(images): 1242 | if img.mode != "RGBA": 1243 | img = img.convert("RGBA") 1244 | out_path = os.path.join(tmpdir, f"{idx:06d}.png") 1245 | img.save(out_path, "PNG") 1246 | 1247 | # Construct ffmpeg command 1248 | # -c:v libvpx-vp9 => VP9 encoder 1249 | # -pix_fmt yuva420p => alpha-enabled pixel format 1250 | # -auto-alt-ref 0 => helps preserve alpha frames (libvpx quirk) 1251 | ffmpeg_cmd = [ 1252 | "ffmpeg", "-y", 1253 | "-framerate", str(fps), 1254 | "-i", os.path.join(tmpdir, "%06d.png"), 1255 | "-c:v", "libvpx-vp9", 1256 | "-pix_fmt", "yuva420p", 1257 | "-auto-alt-ref", "0", 1258 | output_path 1259 | ] 1260 | 1261 | subprocess.run(ffmpeg_cmd, check=True) 1262 | 1263 | print(f"WebM with alpha saved to {output_path}") 1264 | 1265 | 1266 | def add_audio_to_video(video_without_audio_path, original_video_path, output_path): 1267 | """ 1268 | Check if the original video has an audio stream. If yes, add it. If not, skip. 
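Both steps shell out to the ffprobe / ffmpeg binaries, which therefore need to
be available on PATH.

Example (illustrative; the file names below are placeholders):

    add_audio_to_video("foreground.mp4", "input.mp4", "foreground_output_with_audio.mp4")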
1269 | """ 1270 | # 1) Probe original video for audio streams 1271 | probe_command = [ 1272 | 'ffprobe', '-v', 'error', 1273 | '-select_streams', 'a:0', 1274 | '-show_entries', 'stream=index', 1275 | '-of', 'csv=p=0', 1276 | original_video_path 1277 | ] 1278 | result = subprocess.run(probe_command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True) 1279 | 1280 | # result.stdout is empty if no audio stream found 1281 | if not result.stdout.strip(): 1282 | print("No audio track found in original video, skipping audio addition.") 1283 | return 1284 | 1285 | print("Audio track detected; proceeding to mux audio.") 1286 | # 2) If audio found, run ffmpeg to add it 1287 | command = [ 1288 | 'ffmpeg', '-y', 1289 | '-i', video_without_audio_path, 1290 | '-i', original_video_path, 1291 | '-c', 'copy', 1292 | '-map', '0:v:0', 1293 | '-map', '1:a:0', # we know there's an audio track now 1294 | output_path 1295 | ] 1296 | subprocess.run(command, check=True) 1297 | print(f"Audio added successfully => {output_path}") 1298 | 1299 | 1300 | ### Thanks to the source: https://huggingface.co/ZhengPeng7/BiRefNet/blob/main/handler.py 1301 | def refine_foreground_process(image, mask, r=90): 1302 | if mask.size != image.size: 1303 | mask = mask.resize(image.size) 1304 | image = np.array(image) / 255.0 1305 | mask = np.array(mask) / 255.0 1306 | estimated_foreground = FB_blur_fusion_foreground_estimator_2(image, mask, r=r) 1307 | image_masked = Image.fromarray((estimated_foreground * 255.0).astype(np.uint8)) 1308 | return image_masked 1309 | 1310 | 1311 | def FB_blur_fusion_foreground_estimator_2(image, alpha, r=90): 1312 | # Thanks to the source: https://github.com/Photoroom/fast-foreground-estimation 1313 | alpha = alpha[:, :, None] 1314 | F, blur_B = FB_blur_fusion_foreground_estimator(image, image, image, alpha, r) 1315 | return FB_blur_fusion_foreground_estimator(image, F, blur_B, alpha, r=6)[0] 1316 | 1317 | 1318 | def FB_blur_fusion_foreground_estimator(image, F, B, alpha, r=90): 1319 | if isinstance(image, Image.Image): 1320 | image = np.array(image) / 255.0 1321 | blurred_alpha = cv2.blur(alpha, (r, r))[:, :, None] 1322 | 1323 | blurred_FA = cv2.blur(F * alpha, (r, r)) 1324 | blurred_F = blurred_FA / (blurred_alpha + 1e-5) 1325 | 1326 | blurred_B1A = cv2.blur(B * (1 - alpha), (r, r)) 1327 | blurred_B = blurred_B1A / ((1 - blurred_alpha) + 1e-5) 1328 | F = blurred_F + alpha * \ 1329 | (image - alpha * blurred_F - (1 - alpha) * blurred_B) 1330 | F = np.clip(F, 0, 1) 1331 | return F, blurred_B 1332 | 1333 | 1334 | def postprocess_image(result: torch.Tensor, im_size: list) -> np.ndarray: 1335 | result = torch.squeeze(F.interpolate(result, size=im_size, mode='bilinear'), 0) 1336 | ma = torch.max(result) 1337 | mi = torch.min(result) 1338 | result = (result - mi) / (ma - mi) 1339 | im_array = (result * 255).permute(1, 2, 0).cpu().data.numpy().astype(np.uint8) 1340 | im_array = np.squeeze(im_array) 1341 | return im_array 1342 | 1343 | 1344 | def rgb_loader_refiner(original_image): 1345 | h, w = original_image.size 1346 | # Apply EXIF orientation 1347 | 1348 | original_image = ImageOps.exif_transpose(original_image) 1349 | 1350 | if original_image.mode != 'RGB': 1351 | original_image = original_image.convert('RGB') 1352 | 1353 | image = original_image 1354 | # Convert to RGB if necessary 1355 | 1356 | # Resize the image 1357 | image = image.resize((1024, 1024), resample=Image.LANCZOS) 1358 | 1359 | return image, h, w, original_image --------------------------------------------------------------------------------
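Usage sketch (added for illustration; not a file in this repository): a minimal example of driving the BEN_Base API defined in tools/BEN2.py above. The checkpoint path "./BEN2_Base.pth" is an assumption, while example_inputs/cat.png ships with the repo.

import torch
from PIL import Image

from tools.BEN2 import BEN_Base  # assumes the repository root is on sys.path

# BEN_Base.forward is decorated with CUDA float16 autocast, so a GPU is the
# intended device; inference() falls back to a float32 transform on CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

model = BEN_Base().to(device).eval()
model.loadcheckpoints("./BEN2_Base.pth")  # assumed checkpoint location (not bundled)

image = Image.open("example_inputs/cat.png")

# refine_foreground=False composites the foreground onto a white background;
# refine_foreground=True returns an RGBA cutout with refined foreground colors.
foreground = model.inference(image, refine_foreground=True)
foreground.save("cat_foreground.png")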