├── .github └── ISSUE_TEMPLATE │ ├── bug_report.yml │ └── config.yml ├── .gitignore ├── .pre-commit-config.yaml ├── COVENANT.md ├── COVENANT_zh-CN.md ├── LICENSE ├── README.md ├── README_zh-CN.md ├── api_test ├── README.md ├── double_blind │ ├── app.py │ ├── default_template.json │ ├── format_data2json.py │ └── style.css ├── post_infer.py ├── post_train.py └── post_video_infer.py ├── images ├── controlnet_num.jpg ├── ding_erweima.jpg ├── double_blindui.jpg ├── dsw.png ├── erweima.jpg ├── infer_ui.jpg ├── install.jpg ├── multi_people_1.jpg ├── multi_people_2.jpg ├── no_found_image.jpg ├── overview.jpg ├── results_1.jpg ├── results_2.jpg ├── results_3.jpg ├── scene_lora │ ├── Christmas_1.jpg │ ├── Cyberpunk_1.jpg │ ├── FairMaidenStyle_1.jpg │ ├── Gentleman_1.jpg │ ├── GuoFeng_1.jpg │ ├── GuoFeng_2.jpg │ ├── GuoFeng_3.jpg │ ├── GuoFeng_4.jpg │ ├── Minimalism_1.jpg │ ├── NaturalWind_1.jpg │ ├── Princess_1.jpg │ ├── Princess_2.jpg │ ├── Princess_3.jpg │ ├── SchoolUniform_1.jpg │ └── SchoolUniform_2.jpg ├── single_people.jpg ├── train_1.jpg ├── train_2.jpg ├── train_3.jpg ├── train_detail.jpg ├── train_detail1.jpg ├── train_ui.jpg ├── tryon │ ├── cloth │ │ ├── demo_black_200.jpg │ │ ├── demo_dress_200.jpg │ │ ├── demo_purple_200.jpg │ │ ├── demo_short_200.jpg │ │ └── demo_white_200.jpg │ └── template │ │ ├── boy.jpg │ │ ├── dress.jpg │ │ ├── girl.jpg │ │ └── short.jpg └── wechat.jpg ├── install.py ├── javascript └── ui.js ├── models ├── infer_templates │ ├── 1.jpg │ ├── 3.jpg │ ├── 4.jpg │ ├── 5.jpg │ ├── 6.jpg │ ├── 7.jpg │ ├── 8.jpg │ ├── 9.jpg │ └── Put templates here ├── stable-diffusion-v1-5 │ ├── model_index.json │ ├── scheduler │ │ └── scheduler_config.json │ └── tokenizer │ │ ├── merges.txt │ │ ├── special_tokens_map.json │ │ ├── tokenizer_config.json │ │ └── vocab.json ├── stable-diffusion-xl │ ├── madebyollin_sdxl_vae_fp16_fix │ │ └── config.json │ ├── models--laion--CLIP-ViT-bigG-14-laion2B-39B-b160k │ │ ├── refs │ │ │ └── main │ │ └── snapshots │ │ │ └── 
8c7a3583335de4dba1b07182dbf81c75137ce67b │ │ │ ├── config.json │ │ │ ├── merges.txt │ │ │ ├── open_clip_config.json │ │ │ ├── preprocessor_config.json │ │ │ ├── pytorch_model.bin.index.json │ │ │ ├── special_tokens_map.json │ │ │ ├── tokenizer.json │ │ │ ├── tokenizer_config.json │ │ │ └── vocab.json │ ├── models--openai--clip-vit-large-patch14 │ │ ├── refs │ │ │ └── main │ │ └── snapshots │ │ │ └── 32bd64288804d66eefd0ccbe215aa642df71cc41 │ │ │ ├── config.json │ │ │ ├── merges.txt │ │ │ ├── preprocessor_config.json │ │ │ ├── special_tokens_map.json │ │ │ ├── tokenizer.json │ │ │ ├── tokenizer_config.json │ │ │ └── vocab.json │ └── stabilityai_stable_diffusion_xl_base_1.0 │ │ └── scheduler │ │ └── scheduler_config.json └── training_templates │ ├── 1.jpg │ ├── 2.jpg │ ├── 3.jpg │ └── 4.jpg └── scripts ├── api.py ├── easyphoto_config.py ├── easyphoto_infer.py ├── easyphoto_train.py ├── easyphoto_tryon_infer.py ├── easyphoto_ui.py ├── easyphoto_utils ├── __init__.py ├── animatediff │ ├── README.MD │ ├── animatediff_cn.py │ ├── animatediff_i2ibatch.py │ ├── animatediff_infotext.py │ ├── animatediff_infv2v.py │ ├── animatediff_latent.py │ ├── animatediff_lcm.py │ ├── animatediff_logger.py │ ├── animatediff_lora.py │ ├── animatediff_mm.py │ ├── animatediff_output.py │ ├── animatediff_prompt.py │ ├── animatediff_ui.py │ └── motion_module.py ├── animatediff_utils.py ├── common_utils.py ├── face_process_utils.py ├── fire_utils.py ├── loractl_utils.py ├── psgan_utils.py └── tryon_utils.py ├── preprocess.py ├── sdwebui.py └── train_kohya ├── ddpo_pytorch ├── diffusers_patch │ ├── ddim_with_logprob.py │ └── pipeline_with_logprob.py ├── prompts.py ├── rewards.py └── stat_tracking.py ├── train_ddpo.py ├── train_lora.py ├── train_lora_sd_XL.py └── utils ├── __init__.py ├── gpu_info.py ├── lora_utils.py ├── lora_utils_diffusers.py ├── model_utils.py ├── original_unet.py └── original_unet_sd_XL.py /.github/ISSUE_TEMPLATE/bug_report.yml: 
-------------------------------------------------------------------------------- 1 | # Borrowed from sd-webui-controlnet. 2 | name: Bug Report 3 | description: Create a report 4 | title: "[Bug]: " 5 | labels: ["bug-report"] 6 | 7 | body: 8 | - type: checkboxes 9 | attributes: 10 | label: Is there an existing issue for this? 11 | description: Please search to see if an issue already exists for the bug you encountered, and that it hasn't been fixed in a recent build/commit. 12 | options: 13 | - label: I have searched the existing issues and checked the recent builds/commits of both this extension and the webui 14 | required: true 15 | - type: checkboxes 16 | attributes: 17 | label: Is EasyPhoto the latest version? 18 | description: Please check for updates in the extensions tab and reproduce the bug with the latest version. 19 | options: 20 | - label: I have updated EasyPhoto to the latest version and the bug still exists. 21 | required: true 22 | - type: markdown 23 | attributes: 24 | value: | 25 | *Please fill this form with as much information as possible, don't forget to fill in "What browsers..." and **provide screenshots if possible*** 26 | - type: textarea 27 | id: what-did 28 | attributes: 29 | label: What happened? 30 | description: Tell us what happened in a very clear and simple way 31 | validations: 32 | required: true 33 | - type: textarea 34 | id: steps 35 | attributes: 36 | label: Steps to reproduce the problem 37 | description: Please provide us with precise step by step information on how to reproduce the bug 38 | value: | 39 | 1. Go to .... 40 | 2. Press .... 41 | 3. ... 42 | validations: 43 | required: true 44 | - type: textarea 45 | id: what-should 46 | attributes: 47 | label: What should have happened?
48 | description: Tell us what you think the normal behavior should be 49 | validations: 50 | required: true 51 | - type: textarea 52 | id: commits 53 | attributes: 54 | label: Commit where the problem happens 55 | description: Which commit of the extension are you running on? Please include the commit of both the extension and the webui (Do not write *Latest version/repo/commit*, as this means nothing and will have changed by the time we read your issue. Rather, copy the **Commit** link at the bottom of the UI, or from the cmd/terminal if you can't launch it.) 56 | value: | 57 | webui: 58 | EasyPhoto: 59 | validations: 60 | required: true 61 | - type: dropdown 62 | id: browsers 63 | attributes: 64 | label: What browsers do you use to access the UI? 65 | multiple: true 66 | options: 67 | - Mozilla Firefox 68 | - Google Chrome 69 | - Brave 70 | - Apple Safari 71 | - Microsoft Edge 72 | - type: textarea 73 | id: cmdargs 74 | attributes: 75 | label: Command Line Arguments 76 | description: Are you using any launching parameters/command line arguments (modified webui-user .bat/.sh)? If yes, please write them below. Write "No" otherwise. 77 | render: Shell 78 | validations: 79 | required: true 80 | - type: textarea 81 | id: extensions 82 | attributes: 83 | label: List of enabled extensions 84 | description: Please provide a full list of enabled extensions or screenshots of your "Extensions" tab. 85 | validations: 86 | required: true 87 | - type: textarea 88 | id: logs 89 | attributes: 90 | label: Console logs 91 | description: Please provide full cmd/terminal logs from the moment you started the UI to the end of it, after your bug happened. If it's very long, provide a link to pastebin or similar service. 92 | render: Shell 93 | validations: 94 | required: true 95 | - type: textarea 96 | id: misc 97 | attributes: 98 | label: Additional information 99 | description: Please provide us with any relevant additional info or context.
-------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/config.yml: -------------------------------------------------------------------------------- 1 | blank_issues_enabled: true -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # 
.python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
160 | #.idea/ 161 | # weight 162 | *.pth 163 | *.onnx 164 | *.safetensors 165 | *.ckpt 166 | *.bin 167 | *.pkl 168 | *.jpg 169 | *.png 170 | 171 | models/stable-diffusion-xl/version.txt 172 | models/pose_templates 173 | scripts/thirdparty 174 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/psf/black 3 | rev: 22.3.0 4 | hooks: 5 | - id: black 6 | exclude: models/|scripts/easyphoto_utils/animatediff/|scripts/easyphoto_utils/animatediff_utils.py 7 | args: ["--line-length=140"] 8 | - repo: https://github.com/PyCQA/flake8 9 | rev: 3.9.2 10 | hooks: 11 | - id: flake8 12 | exclude: models/|scripts/easyphoto_utils/animatediff/|scripts/easyphoto_utils/animatediff_utils.py 13 | args: ["--max-line-length=140", "--ignore=E303,E731,W191,W504,E402,E203,F541,W605,W503,E501,E712, F401", "--exclude=__init__.py"] 14 | - repo: https://github.com/myint/autoflake 15 | rev: v1.4 16 | hooks: 17 | - id: autoflake 18 | exclude: models/|scripts/easyphoto_utils/animatediff/|scripts/easyphoto_utils/animatediff_utils.py 19 | args: 20 | [ 21 | "--recursive", 22 | "--in-place", 23 | "--remove-unused-variable", 24 | "--ignore-init-module-imports", 25 | "--exclude=__init__.py" 26 | ] 27 | - repo: https://github.com/pre-commit/pre-commit-hooks 28 | rev: v4.4.0 29 | hooks: 30 | - id: check-ast 31 | exclude: models/|scripts/easyphoto_utils/animatediff/|scripts/easyphoto_utils/animatediff_utils.py 32 | - id: check-byte-order-marker 33 | exclude: models/|scripts/easyphoto_utils/animatediff/|scripts/easyphoto_utils/animatediff_utils.py 34 | - id: check-case-conflict 35 | exclude: models/|scripts/easyphoto_utils/animatediff/|scripts/easyphoto_utils/animatediff_utils.py 36 | - id: check-docstring-first 37 | exclude: models/|scripts/easyphoto_utils/animatediff/|scripts/easyphoto_utils/animatediff_utils.py 38 | - id: 
check-executables-have-shebangs 39 | exclude: models/|scripts/easyphoto_utils/animatediff/|scripts/easyphoto_utils/animatediff_utils.py 40 | - id: check-json 41 | exclude: models/|scripts/easyphoto_utils/animatediff/|scripts/easyphoto_utils/animatediff_utils.py 42 | - id: check-yaml 43 | exclude: models/|scripts/easyphoto_utils/animatediff/|scripts/easyphoto_utils/animatediff_utils.py 44 | - id: debug-statements 45 | exclude: models/|scripts/easyphoto_utils/animatediff/|scripts/easyphoto_utils/animatediff_utils.py 46 | - id: detect-private-key 47 | exclude: models/|scripts/easyphoto_utils/animatediff/|scripts/easyphoto_utils/animatediff_utils.py 48 | - id: end-of-file-fixer 49 | exclude: models/|scripts/easyphoto_utils/animatediff/|scripts/easyphoto_utils/animatediff_utils.py 50 | - id: trailing-whitespace 51 | exclude: models/|scripts/easyphoto_utils/animatediff/|scripts/easyphoto_utils/animatediff_utils.py 52 | - id: mixed-line-ending 53 | exclude: models/|scripts/easyphoto_utils/animatediff/|scripts/easyphoto_utils/animatediff_utils.py 54 | -------------------------------------------------------------------------------- /COVENANT.md: -------------------------------------------------------------------------------- 1 | # EasyPhoto Developer Covenant 2 | Disclaimer: This covenant serves as a set of recommended guidelines. 3 | 4 | ## Overview: 5 | EasyPhoto is an open-source software built on the SDWebUI plugin ecosystem, focusing on leveraging AIGC technology to create true-to-life, aesthetic, and beautiful AI portraits ("true/like/beautiful"). We are committed to expanding the application scope of this technology, lowering the entry barrier, and facilitating use for a wide audience. 
6 | 7 | ## Covenant Purpose: 8 | Although technology is inherently neutral in value, considering that EasyPhoto already has the capability to produce highly realistic images, particularly facial images, we strongly suggest that everyone involved in development and usage adhere to the following guidelines. 9 | 10 | ## Behavioral Guidelines: 11 | - Comply with laws and regulations of relevant jurisdictions: It is prohibited to use this technology for any unlawful, criminal, or activities against public morals and decency. 12 | - Content Restrictions: It is prohibited to produce or disseminate any images that may involve political figures, pornography, violence, or other activities contrary to regional regulations. 13 | 14 | ## Ongoing Updates: 15 | This covenant will be updated periodically to adapt to technological and social advancements. We encourage community members to follow these guidelines in daily interactions and usage. Non-compliance will result in appropriate community management actions. 16 | 17 | Thank you for your cooperation and support. Together, let's ensure that EasyPhoto remains a responsible and sustainably-developed open-source software. 
18 | -------------------------------------------------------------------------------- /COVENANT_zh-CN.md: -------------------------------------------------------------------------------- 1 | # EasyPhoto 开发者公约 2 | !声明:本公约仅为推荐性准则 3 | 4 | ## 概述: 5 | EasyPhoto 是一个基于SDWebUI插件生态构建的开源软件,专注于利用AIGC技术实现真/像/美的AI-写真。我们致力于拓展该技术的应用范围,降低使用门槛,并为广大用户提供便利。 6 | 7 | ## 公约宗旨: 8 | 尽管技术本身并无价值倾向,但考虑到EasyPhoto目前已具备生成逼真图像(特别是人脸图像)的能力,我们强烈建议所有参与开发和使用的人员遵循以下准则。 9 | 10 | ## 行为准则: 11 | - 遵循相关地区的法律和法规:不得利用本技术从事任何违法、犯罪或有悖于社会公序良俗的活动。 12 | - 内容限制:禁止生成或传播任何可能涉及政治人物、色情、暴力或其他违反相关地区规定的图像。 13 | 14 | ## 持续更新: 15 | 本公约将不定期进行更新以适应技术和社会发展。我们鼓励社群成员在日常交流和使用中遵循这些准则。未遵守本公约的行为将在社群管理中受到相应限制。 16 | 感谢您的配合与支持。我们共同努力,以确保EasyPhoto成为一个负责任和可持续发展的开源软件。 17 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 
25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | the copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. 
For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by the Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributors that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. 
If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. 
You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, the Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assuming any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. 
(Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README_zh-CN.md: -------------------------------------------------------------------------------- 1 | # EasyPhoto | 您的智能 AI 照片生成器。 2 | 🦜 EasyPhoto是一款Webui UI插件,用于生成AI肖像画,该代码可用于训练与您相关的数字分身。 3 | 4 | 🦜 🦜 Welcome! 5 | 6 | [![Hugging Face Spaces](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-yellow)](https://huggingface.co/spaces/alibaba-pai/easyphoto) 7 | 8 | [English](./README.md) | 简体中文 9 | 10 | # 目录 11 | - [简介](#简介) 12 | - [TODO List](#todo-list) 13 | - [快速启动](#快速启动) 14 | - [1. 云使用: AliyunDSW/AutoDL/揽睿星舟/Docker](#1-云使用-aliyundswautodl揽睿星舟docker) 15 | - [2. 本地安装: 环境检查/下载/安装](#2-本地安装-环境检查下载安装) 16 | - [如何使用](#如何使用) 17 | - [1. 模型训练](#1-模型训练) 18 | - [2. 人物生成](#2-人物生成) 19 | - [API测试](./api_test/README.md) 20 | - [算法详细信息](#算法详细信息) 21 | - [1. 架构概述](#1-架构概述) 22 | - [2. 训练细节](#2-训练细节) 23 | - [3.
推理细节](#3-推理细节) 24 | - [参考文献](#参考文献) 25 | - [相关项目](#相关项目) 26 | - [许可证](#许可证) 27 | - [联系我们](#联系我们) 28 | 29 | # 简介 30 | EasyPhoto是一款Webui UI插件,用于生成AI肖像画,该代码可用于训练与您相关的数字分身。建议使用 5 到 20 张肖像图片进行训练,最好是半身照片且不要佩戴眼镜(少量可以接受)。训练完成后,我们可以在推理部分生成图像。我们支持使用预设模板图片与上传自己的图片进行推理。 31 | 32 | 请阅读我们的开发者公约,共建美好社区 [covenant](./COVENANT.md) | [简体中文](./COVENANT_zh-CN.md) 33 | 34 | 如果您在训练中遇到一些问题,请参考 [VQA](https://github.com/aigc-apps/sd-webui-EasyPhoto/wiki)。 35 | 36 | 我们现在支持从不同平台快速启动,请参阅 [快速启动](#快速启动)。 37 | 38 | 新特性: 39 | - 支持基于LCM-Lora的采样加速,现在您只需要进行12个steps(vs 50steps)来生成图像和视频, 并支持了场景化(风格化) Lora的训练和大量内置的模型。[🔥 🔥 🔥 🔥 2023.12.09] 40 | - 支持基于Concepts-Sliders的属性编辑和虚拟试穿,请参考[sliders-wiki](https://github.com/aigc-apps/sd-webui-EasyPhoto/wiki/Attribute-Edit) , [tryon-wiki](https://github.com/aigc-apps/sd-webui-EasyPhoto/wiki/TryOn)获取更多详细信息。[🔥 🔥 🔥 🔥 2023.12.08] 41 | - 感谢[揽睿星舟](https://www.lanrui-ai.com/) 提供了内置EasyPhoto的SDWebUI官方镜像,并承诺每两周更新一次。亲自测试,可以在2分钟内拉起资源,并在5分钟内完成启动。[🔥 🔥 🔥 🔥 2023.11.20] 42 | - ComfyUI 支持 [repo](https://github.com/THtianhao/ComfyUI-Portrait-Maker), 感谢[THtianhao](https://github.com/THtianhao)的精彩工作![🔥 🔥 🔥 2023.10.17] 43 | - EasyPhoto 论文地址 [arxiv](https://arxiv.org/abs/2310.04672)[🔥 🔥 🔥 2023.10.10] 44 | - 支持使用SDXL模型和一定的选项直接生成高清大图,不再需要上传模板,需要16GB显存。具体细节可以前往[这里](https://zhuanlan.zhihu.com/p/658940203)[🔥 🔥 🔥 2023.09.26] 45 | - 我们同样支持[Diffusers版本](https://github.com/aigc-apps/EasyPhoto/)。 [🔥 2023.09.25] 46 | - **支持对背景进行微调,并计算生成的图像与用户之间的相似度得分。** [🔥🔥 2023.09.15] 47 | - **支持不同预测基础模型。** [🔥🔥 2023.09.08] 48 | - **支持多人生成!添加缓存选项以优化推理速度。在UI上添加日志刷新。** [🔥🔥 2023.09.06] 49 | - 创建代码!现在支持 Windows 和 Linux。[🔥 2023.09.02] 50 | 51 | 这些是我们的生成结果: 52 | ![results_1](images/results_1.jpg) 53 | ![results_2](images/results_2.jpg) 54 | ![results_3](images/results_3.jpg) 55 | 56 | 我们的ui界面如下: 57 | **训练部分:** 58 | ![train_ui](images/train_ui.jpg) 59 | **预测部分:** 60 | ![infer_ui](images/infer_ui.jpg) 61 | 62 | # TODO List 63 | - 支持中文界面。 64 | - 支持模板背景部分变化。 65 | - 支持高分辨率。 66 | 67 | # 快速启动 68 | ### 1. 
云使用: AliyunDSW/AutoDL/揽睿星舟/Docker 69 | #### a. 通过阿里云 DSW 70 | DSW 有免费 GPU 时间,用户可申请一次,申请后3个月内有效。 71 | 72 | 阿里云在[Freetier](https://free.aliyun.com/?product=9602825&crowd=enterprise&spm=5176.28055625.J_5831864660.1.e939154aRgha4e&scm=20140722.M_9974135.P_110.MO_1806-ID_9974135-MID_9974135-CID_30683-ST_8512-V_1)提供免费GPU时间,获取并在阿里云PAI-DSW中使用,3分钟内即可启动EasyPhoto 73 | 74 | [![DSW Notebook](images/dsw.png)](https://gallery.pai-ml.com/#/preview/deepLearning/cv/stable_diffusion_easyphoto) 75 | 76 | #### b. 通过揽睿星舟/AutoDL 77 | ##### 揽睿星舟 78 | 揽睿星舟官方全插件版本内置EasyPhoto,并承诺每两周测试与更新,亲测可用,5分钟内拉起,感谢他们的支持和对社区做出的贡献。 79 | 80 | ##### AutoDL 81 | 如果您正在使用 AutoDL,您可以使用我们提供的镜像快速启动 Stable Diffusion WebUI。 82 | 83 | 您可以在社区镜像中填写以下信息来选择所需的镜像。 84 | ``` 85 | aigc-apps/sd-webui-EasyPhoto/sd-webui-EasyPhoto 86 | ``` 87 | #### c. 通过docker 88 | 使用docker的情况下,请保证机器中已经正确安装显卡驱动与CUDA环境,然后依次执行以下命令: 89 | ``` 90 | # 拉取镜像 91 | docker pull registry.cn-beijing.aliyuncs.com/mybigpai/sd-webui-easyphoto:0.0.3 92 | 93 | # 进入镜像 94 | docker run -it -p 7860:7860 --network host --gpus all registry.cn-beijing.aliyuncs.com/mybigpai/sd-webui-easyphoto:0.0.3 95 | 96 | # 启动webui 97 | python3 launch.py --port 7860 98 | ``` 99 | 100 | ### 2. 本地安装: 环境检查/下载/安装 101 | #### a. 环境检查 102 | 我们已验证EasyPhoto可在以下环境中执行: 103 | 如果你遇到内存使用过高而导致WebUI进程自动被kill掉,请参考[ISSUE21](https://github.com/aigc-apps/sd-webui-EasyPhoto/issues/21),设置一些参数,例如num_threads=0,如果你也发现了其他解决的好办法,请及时联系我们。 104 | 105 | Windows 10 的详细信息: 106 | - 操作系统: Windows10 107 | - python: python 3.10 108 | - pytorch: torch2.0.1 109 | - tensorflow-cpu: 2.13.0 110 | - CUDA: 11.7 111 | - CUDNN: 8+ 112 | - GPU: Nvidia-3060 12G 113 | 114 | Linux 的详细信息: 115 | - 操作系统: Ubuntu 20.04, CentOS 116 | - python: python3.10 & python3.11 117 | - pytorch: torch2.0.1 118 | - tensorflow-cpu: 2.13.0 119 | - CUDA: 11.7 120 | - CUDNN: 8+ 121 | - GPU: Nvidia-A10 24G & Nvidia-V100 16G & Nvidia-A100 40G 122 | 123 | 我们需要大约 60GB 的可用磁盘空间(用于保存权重和数据集),请检查! 124 | 125 | #### b. 相关资料库和权重下载 126 | ##### i.
Controlnet 127 | 我们需要使用 Controlnet 进行推理。相关仓库是[Mikubill/sd-webui-controlnet](https://github.com/Mikubill/sd-webui-controlnet)。在使用 EasyPhoto 之前,您需要先安装这个插件。 128 | 129 | 此外,我们推理时至少需要同时使用三个 Controlnet。因此,您需要在设置中将 **Multi ControlNet: Max models amount (requires restart)** 设为不小于 3。 130 | ![controlnet_num](images/controlnet_num.jpg) 131 | 132 | ##### ii. 其他依赖。 133 | 我们与现有的 stable-diffusion-webui 环境相互兼容,启动 stable-diffusion-webui 时会自动安装相关依赖。 134 | 135 | 我们所需的权重会在第一次开始训练时自动下载。 136 | 137 | #### c. 插件安装 138 | 现在我们支持从 git 安装 EasyPhoto。我们的仓库网址是 https://github.com/aigc-apps/sd-webui-EasyPhoto。 139 | 140 | 今后,我们将支持从 **Available** 安装 EasyPhoto。 141 | 142 | ![install](images/install.jpg) 143 | 144 | # 如何使用 145 | ### 1. 模型训练 146 | EasyPhoto训练界面如下: 147 | - 左边是训练图像。只需点击上传照片即可上传图片,点击清除照片即可删除上传的图片; 148 | - 右边是训练参数,第一次训练时无需调整。 149 | 150 | 点击上传照片后,即可开始上传图像,**这里最好上传5到20张图像,包括不同的角度和光照**。最好有一些不戴眼镜的图像。如果所有图片都戴着眼镜,则生成的结果很容易出现眼镜。 151 | ![train_1](images/train_1.jpg) 152 | 153 | 然后我们点击下面的“开始训练”,此时,我们需要填写上面的用户ID,例如用户名,才能开始训练。 154 | ![train_2](images/train_2.jpg) 155 | 156 | 模型开始训练后,webui会自动刷新训练日志。如果没有刷新,请单击“Refresh Log”按钮。 157 | ![train_3](images/train_3.jpg) 158 | 159 | 如果要设置参数,每个参数的解析如下: 160 | | 参数名 | 含义 | 161 | |--|--| 162 | | resolution | 训练时喂入网络的图片大小,默认值为512 | 163 | | validation & save steps| 验证图片与保存中间权重的steps数,默认值为100,代表每100步验证一次图片并保存权重 | 164 | | max train steps | 最大训练步数,默认值为800 | 165 | | max steps per photos | 每张图片的最大训练次数,默认为200 | 166 | | train batch size | 训练的批次大小,默认值为1 | 167 | | gradient accumulation steps | 梯度累积的步数,默认值为4,结合train batch size来看,每个Step相当于喂入四张图片 | 168 | | dataloader num workers | 数据加载的workers数量,windows下不生效,因为设置了会报错,Linux正常设置 | 169 | | learning rate | 训练Lora的学习率,默认为1e-4 | 170 | | rank | Lora权重的特征长度,默认为128 | 171 | | network alpha | Lora训练的正则化参数,一般为rank的二分之一,默认为64 | 172 | 173 | ### 2. 人物生成 174 | #### a.
单人模板 175 | - 步骤1:点击刷新按钮,查询训练后的用户ID对应的模型。 176 | - 步骤2:选择用户ID。 177 | - 步骤3:选择需要生成的模板。 178 | - 步骤4:单击“生成”按钮生成结果。 179 | 180 | ![single_people](images/single_people.jpg) 181 | 182 | #### b. 多人模板 183 | - 步骤1:转到EasyPhoto的设置页面,设置num_of_Faceid大于1。 184 | - 步骤2:应用设置。 185 | - 步骤3:重新启动webui的ui界面。 186 | - 步骤4:返回EasyPhoto并上传多人模板。 187 | - 步骤5:选择两个人的用户ID。 188 | - 步骤6:单击“生成”按钮,执行图像生成。 189 | 190 | ![multi_people_1](images/multi_people_1.jpg) 191 | ![multi_people_2](images/multi_people_2.jpg) 192 | 193 | # 算法详细信息 194 | - 英文论文[arxiv](https://arxiv.org/abs/2310.04672) 195 | - 中文博客[这里](https://blog.csdn.net/weixin_44791964/article/details/132922309) 196 | 197 | ### 1. 架构概述 198 | 199 | ![overview](images/overview.jpg) 200 | 201 | 在人工智能肖像领域,我们希望模型生成的图像逼真且与用户相似,而传统方法会引入不真实的光照(如人脸融合或roop)。为了解决这种不真实的问题,我们引入了稳定扩散模型的图像到图像功能。生成完美的个人肖像需要考虑所需的生成场景和用户的数字分身。我们使用一个预先准备好的模板作为所需的生成场景,并使用一个在线训练的人脸 LoRA 模型作为用户的数字分身,LoRA 是一种流行的稳定扩散微调方法。我们使用少量用户图像来训练用户的稳定数字分身,并在推理过程中根据人脸 LoRA 模型和预期生成场景生成个人肖像图像。 202 | 203 | 204 | ### 2. 训练细节 205 | 206 | ![overview](images/train_detail1.jpg) 207 | 208 | 首先,我们对输入的用户图像进行人脸检测,确定人脸位置后,按照一定比例截取输入图像。然后,我们使用显著性检测模型和皮肤美化模型获得干净的人脸训练图像,该图像基本上只包含人脸。然后,我们为每张图像贴上一个固定标签。这里不需要使用标签器,而且效果很好。最后,我们对稳定扩散模型进行微调,得到用户的数字分身。 209 | 210 | 在训练过程中,我们会利用模板图像进行实时验证,在训练结束后,我们会计算验证图像与用户图像之间的人脸 ID 差距,从而实现 Lora 融合,确保我们的 Lora 是用户的完美数字分身。 211 | 212 | 213 | 此外,我们将选择验证中与用户最相似的图像作为 face_id 图像,用于推理。 214 | 215 | ### 3. 推理细节 216 | #### a. 第一次扩散: 217 | 首先,我们将对接收到的模板图像进行人脸检测,以确定稳定扩散需要重绘的遮罩区域。然后,我们将使用模板图像与最佳用户图像进行人脸融合。人脸融合完成后,我们将使用上述遮罩对融合后的人脸图像(fusion_image)进行局部重绘(inpaint)。此外,我们还将通过仿射变换把训练中获得的最佳 face_id 图像贴到模板图像上,得到替换图像(replace_image)。然后,我们将对其应用 Controlnets,在融合图像中使用带有颜色的 canny 提取特征,在替换图像中使用 openpose 提取特征,以确保图像的相似性和稳定性。然后,我们将使用稳定扩散(Stable Diffusion)结合用户的数字分身进行生成。 218 | 219 | #### b.
第二次扩散: 220 | 在得到第一次扩散的结果后,我们将把该结果与最佳用户图像进行人脸融合,然后再次使用稳定扩散与用户的数字分身进行生成。第二次生成将使用更高的分辨率。 221 | 222 | # 特别感谢 223 | 特别感谢DevelopmentZheng, qiuyanxin, rainlee, jhuang1207, bubbliiiing, wuziheng, yjjinjie, hkunzhe, yunkchen同学们的代码贡献(此排名不分先后)。 224 | 225 | # 参考文献 226 | - insightface:https://github.com/deepinsight/insightface 227 | - cv_resnet50_face:https://www.modelscope.cn/models/damo/cv_resnet50_face-detection_retinaface/summary 228 | - cv_u2net_salient:https://www.modelscope.cn/models/damo/cv_u2net_salient-detection/summary 229 | - cv_unet_skin_retouching_torch:https://www.modelscope.cn/models/damo/cv_unet_skin_retouching_torch/summary 230 | - cv_unet-image-face-fusion:https://www.modelscope.cn/models/damo/cv_unet-image-face-fusion_damo/summary 231 | - kohya:https://github.com/bmaltais/kohya_ss 232 | - controlnet-webui:https://github.com/Mikubill/sd-webui-controlnet 233 | 234 | # 相关项目 235 | 我们还列出了一些很棒的开源项目以及你可能会感兴趣的扩展项目: 236 | - [ModelScope](https://github.com/modelscope/modelscope). 237 | - [FaceChain](https://github.com/modelscope/facechain). 238 | - [sd-webui-controlnet](https://github.com/Mikubill/sd-webui-controlnet). 239 | - [sd-webui-roop](https://github.com/s0md3v/sd-webui-roop). 240 | - [roop](https://github.com/s0md3v/roop). 241 | - [sd-webui-deforum](https://github.com/deforum-art/sd-webui-deforum). 242 | - [sd-webui-additional-networks](https://github.com/kohya-ss/sd-webui-additional-networks). 243 | - [a1111-sd-webui-tagcomplete](https://github.com/DominikDoom/a1111-sd-webui-tagcomplete). 244 | - [sd-webui-segment-anything](https://github.com/continue-revolution/sd-webui-segment-anything). 245 | - [sd-webui-tunnels](https://github.com/Bing-su/sd-webui-tunnels). 246 | - [sd-webui-mov2mov](https://github.com/Scholar01/sd-webui-mov2mov). 247 | 248 | # 许可证 249 | 本项目采用 [Apache License (Version 2.0)](./LICENSE) 开源许可证。 250 | 251 | # 联系我们 252 | 1. 使用[钉钉](https://www.dingtalk.com/)搜索2群54095000124或扫描下列二维码加入群聊 253 | 2.
由于微信群已经满了,需要先扫描右边的图片添加这位同学为好友,然后再加入微信群。 254 |
255 | 256 | 257 |
258 | -------------------------------------------------------------------------------- /api_test/README.md: -------------------------------------------------------------------------------- 1 | # Test & Profile-23/10/09 2 | ## 环境准备 3 | - 请搭建一个待测试的SDWebUI环境 4 | - 保证可以访问到上述SDWebUI,准备python环境,满足依赖 base64/json/numpy/cv2 5 | 6 | ## 训练/推理测试代码 7 | - **post_train.py** 支持公网URL图片/本地图片读取 8 | - **post_infer.py** 支持公网URL图片/本地图片读取 9 | 代码提供了默认URL,可修改,本地图片通过命令行参数输入。 10 | 11 | ## 双盲测试 12 | 基于上述的推理代码,我们可以实现预定模板和预定人物的Lora的批量测试图片生成,形成某个版本的记录。并基于此对两个版本的代码的生成结果进行双盲测试,下面,我们简单的使用一个例子进行双盲测试。打开后的UI 如下图 13 | 14 | ![overview](../images/double_blindui.jpg) 15 | 16 | #### 步骤 17 | - 按照环境准备,在环境准备相关的user_id模型 18 | - 准备预设模板和上面user_id对应的真人图片 19 | - templates 20 | - 1.jpg 21 | - 2.jpg 22 | - ref_image 23 | - id1.jpg 24 | - id2.jpg 25 | - 运行批量推理代码: version1, version2 需要分别推理一次。 26 | ```python 27 | python3 post_infer.py --template_dir templates --output_path test_data/version1 --user_ids your_id 28 | ``` 29 | - 运行数据整理代码 30 | ```python 31 | python3 ./double_blind/format_data2json.py --ref_images ref_image --version1_dir test_data/version1 --version2_dir test_data/version2 --output_json test_v1_v2.json 32 | ``` 33 | - 运行./double_blind/app.py 获取如下双盲测试页面, 请注意,gradio 存在很多bug,我们稳定可以运行的版本是gradio==3.48.0, 否则会出现gradio, BarPlot的奇怪问题。 34 | ```python 35 | python3 ./double_blind/app.py --data-path test_v1_v2.json --result-path ./result.jsonl 36 | ``` 37 | 运行上述代码后,会得到一个上述页面。如果在域名指定的机器,则可分享相关测试域名(待补充),然后获得 version1和version2的 WinningRate,作为PR的记录。 38 | -------------------------------------------------------------------------------- /api_test/double_blind/app.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import os 4 | import random 5 | 6 | import gradio as gr 7 | import pandas as pd 8 | 9 | 10 | def read_json(file_path: str): 11 | return json.load(open(file_path)) 12 | 13 | 14 | def read_jsonl(file_path: str): 15 | data = [] 16 | with 
open(file_path, "r") as file: 17 | for line in file: 18 | data.append(json.loads(line)) 19 | return data 20 | 21 | 22 | def write_jsonl(data: any, file_path: str): 23 | if not os.path.exists(file_path): 24 | with open(file_path, "w") as f: 25 | pass 26 | with open(file_path, "a") as f: 27 | json.dump(data, f, ensure_ascii=False) 28 | f.write("\n") 29 | 30 | 31 | def save_result(id, submit_cnt, ids, ids_list, id2data, results, *eval_results): 32 | 33 | if not all(eval_results): 34 | gr.Warning("请完整填写所有问题的答案。\nPlease complete the answers to all questions.") 35 | return next_item(id) + (submit_cnt,) 36 | 37 | if id is None: 38 | gr.Info("感谢您参与EasyPhoto的评测,本次评测已全部完成~🥰\nThank you for participating in the EasyPhoto review, this review is complete ~🥰") 39 | return None, [], None, None, draw_results(), submit_cnt 40 | 41 | if id in ids: 42 | ids.remove(id) 43 | item = id2data[id] 44 | result = {"id": id, "questions": template["questions"], "answers": []} 45 | for r in eval_results: 46 | if r == "持平/Tie": 47 | result["answers"].append("tie") 48 | elif r == "左边/Left": 49 | result["answers"].append("method1" if item["left"] == "img1" else "method2") 50 | elif r == "右边/Right": 51 | result["answers"].append("method2" if item["left"] == "img1" else "method1") 52 | 53 | results.append(result) 54 | write_jsonl(result, args.result_path) 55 | 56 | return next_item(ids, ids_list, id2data, results) + (submit_cnt + 1,) 57 | 58 | 59 | def next_item(ids, ids_list, id2data, results): 60 | 61 | if len(ids) <= 0: 62 | gr.Info("感谢您参与EasyPhoto的评测,本次评测已全部完成~🥰\nThank you for participating in the EasyPhoto review, this review is complete ~🥰") 63 | return None, [], None, None, draw_results(results, ids_list), ids, ids_list, id2data, results 64 | 65 | id = random.choice(list(ids)) 66 | 67 | if random.random() < 0.5: 68 | id2data[id]["left"] = "img1" 69 | left_img = id2data[id]["img1"] 70 | right_img = id2data[id]["img2"] 71 | else: 72 | id2data[id]["left"] = "img2" 73 | left_img = 
id2data[id]["img2"] 74 | right_img = id2data[id]["img1"] 75 | 76 | item = id2data[id] 77 | 78 | return ( 79 | item["id"], 80 | [(x, "") for x in item["reference_imgs"]], 81 | left_img, 82 | right_img, 83 | draw_results(results, ids_list), 84 | ids, 85 | ids_list, 86 | id2data, 87 | results, 88 | ) 89 | 90 | 91 | def draw_results(results, ids_list): 92 | 93 | if len(results) < len(ids_list): 94 | return None 95 | else: 96 | 97 | questions = template["questions"] 98 | num_questions = len(questions) 99 | 100 | method1_win = [0] * num_questions 101 | tie = [0] * num_questions 102 | method2_win = [0] * num_questions 103 | 104 | for item in results: 105 | assert len(item["answers"]) == num_questions 106 | for i in range(num_questions): 107 | if item["answers"][i] == "method1": 108 | method1_win[i] += 1 109 | elif item["answers"][i] == "tie": 110 | tie[i] += 1 111 | elif item["answers"][i] == "method2": 112 | method2_win[i] += 1 113 | else: 114 | raise NotImplementedError() 115 | results_for_drawing = {} 116 | 117 | method1_win += [sum(method1_win) / len(method1_win)] 118 | tie += [sum(tie) / len(tie)] 119 | method2_win += [sum(method2_win) / len(method2_win)] 120 | 121 | results_for_drawing["Questions"] = (questions + ["Average"]) * 3 122 | results_for_drawing["Win Rate"] = ( 123 | [x / len(results) * 100 for x in method1_win] 124 | + [x / len(results) * 100 for x in tie] 125 | + [x / len(results) * 100 for x in method2_win] 126 | ) 127 | 128 | results_for_drawing["Winner"] = ( 129 | [data[0]["method1"]] * (num_questions + 1) + ["Tie"] * (num_questions + 1) + [data[0]["method2"]] * (num_questions + 1) 130 | ) 131 | results_for_drawing = pd.DataFrame(results_for_drawing) 132 | 133 | return gr.BarPlot( 134 | results_for_drawing, 135 | x="Questions", 136 | y="Win Rate", 137 | color="Winner", 138 | title="Human Evaluation Result", 139 | vertical=False, 140 | width=450, 141 | height=300, 142 | ) 143 | 144 | 145 | def init_start(ids, ids_list, id2data, results): 146 | 
random_elements = random.sample(data, len(data) // 2) 147 | id2data = {} 148 | for item in random_elements: 149 | id2data[item["id"]] = item 150 | ids = set(id2data.keys()) 151 | ids_list = set(id2data.keys()) 152 | results = [] 153 | return next_item(ids, ids_list, id2data, results) 154 | 155 | 156 | parser = argparse.ArgumentParser() 157 | parser.add_argument("--template-file", default="default_template.json") 158 | parser.add_argument("--data-path", default="data/makeup_transfer/data.json") 159 | parser.add_argument("--result-path", default="data/makeup_transfer/result.jsonl") 160 | parser.add_argument("--port", type=int, default=80) 161 | 162 | args = parser.parse_args() 163 | # global data 164 | if not os.path.exists(args.template_file): 165 | args.template_file = "./double_blind/default_template.json" 166 | template = read_json(args.template_file) 167 | data = read_json(args.data_path) 168 | 169 | with gr.Blocks(title="EasyPhoto双盲评测", css="style.css") as app: 170 | 171 | id = gr.State() 172 | id2data = gr.State({}) 173 | ids = gr.State() 174 | ids_list = gr.State() 175 | results = gr.State([]) 176 | 177 | with gr.Column(visible=True, elem_id="start"): 178 | gr.Markdown("### 欢迎您参与EasyPhoto的本次评测。") 179 | gr.Markdown("### Welcome to this review of EasyPhoto.") 180 | with gr.Row(): 181 | start_btn = gr.Button("开始 / Start") 182 | 183 | with gr.Column(visible=False, elem_id="main"): 184 | submit_cnt = gr.State(value=1) 185 | 186 | with gr.Row(): 187 | with gr.Column(scale=3): 188 | reference_imgs = gr.Gallery( 189 | [], columns=3, rows=1, label="人物参考图片", show_label=True, elem_id="reference-imgs", visible=template["show_references"] 190 | ) 191 | with gr.Column(scale=1): 192 | pass 193 | 194 | gr.Markdown("### 根据下面的图片和上面的参考图片(如果有),回答下面的问题。") 195 | with gr.Row(): 196 | with gr.Column(scale=3): 197 | with gr.Row(): 198 | left_img = gr.Image(show_label=False) 199 | right_img = gr.Image(show_label=False) 200 | with gr.Column(scale=1): 201 | pass 202 | 203 | eval_results 
= [] 204 | for question in template["questions"]: 205 | eval_results.append(gr.Radio(["左边/Left", "持平/Tie", "右边/Right"], label=question, elem_classes="question")) 206 | 207 | submit = gr.Button("提交 / Submit") 208 | next_btn = gr.Button("换一个 / Change Another") 209 | 210 | with gr.Accordion("查看结果/View Results", open=False): 211 | with gr.Row(): 212 | with gr.Column(scale=1): 213 | plot = gr.BarPlot() 214 | with gr.Column(scale=1): 215 | pass 216 | 217 | start_btn.click( 218 | init_start, 219 | inputs=[ids, ids_list, id2data, results], 220 | outputs=[id, reference_imgs, left_img, right_img, plot, ids, ids_list, id2data, results], 221 | ).then( 222 | fn=None, 223 | _js="\ 224 | () => {\ 225 | document.querySelector('#start').style.display = 'none';\ 226 | document.querySelector('#main').style.display = 'flex';\ 227 | }\ 228 | ", 229 | inputs=None, 230 | outputs=[], 231 | ) 232 | 233 | submit.click( 234 | save_result, 235 | inputs=[id, submit_cnt, ids, ids_list, id2data, results] + eval_results, 236 | outputs=[id, reference_imgs, left_img, right_img, plot, ids, ids_list, id2data, results, submit_cnt], 237 | ) 238 | next_btn.click( 239 | next_item, 240 | inputs=[ids, ids_list, id2data, results], 241 | outputs=[id, reference_imgs, left_img, right_img, plot, ids, ids_list, id2data, results], 242 | ) 243 | 244 | if __name__ == "__main__": 245 | 246 | # 最高并发15 247 | app.queue(concurrency_count=15).launch(server_name="0.0.0.0", server_port=args.port, show_api=False) 248 | -------------------------------------------------------------------------------- /api_test/double_blind/default_template.json: -------------------------------------------------------------------------------- 1 | { 2 | "show_references": true, 3 | "num_reference_imgs": 2, 4 | "max_once": 15, 5 | "questions": [ 6 | "更美观?/Beauty?", 7 | "更像本人?/Similar?", 8 | "更真实?/Reality?" 
9 | ] 10 | } 11 | -------------------------------------------------------------------------------- /api_test/double_blind/format_data2json.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import os 4 | from glob import glob 5 | 6 | 7 | def format_ref_images(file_paths): 8 | file_names = [os.path.basename(path).split(".")[0] for path in file_paths] 9 | result_dict = {k: [v] for k, v in zip(file_names, file_paths)} 10 | 11 | return result_dict 12 | 13 | 14 | def remove_last_underscore(input_string): 15 | parts = input_string.split("_") 16 | 17 | if len(parts) > 1: 18 | parts = parts[:-1] 19 | 20 | result = "_".join(parts) 21 | 22 | return result 23 | 24 | 25 | def match_prefix(file_name, image_formats): 26 | for img_prefix in image_formats: 27 | if os.path.exists(file_name): 28 | return file_name 29 | else: 30 | current_prefix = file_name.split("/")[-1].split(".")[-1] 31 | file_name = file_name.replace(current_prefix, img_prefix) 32 | 33 | print("Warning: No match file in compared version!") 34 | return file_name 35 | 36 | 37 | def find_value_for_key(file_name, dictionary): 38 | parts = file_name.split("/") 39 | last_part = parts[-1].split(".")[0] 40 | 41 | user_id = remove_last_underscore(last_part) 42 | 43 | if user_id in dictionary.keys(): 44 | return dictionary[user_id] 45 | else: 46 | return None 47 | 48 | 49 | if __name__ == "__main__": 50 | 51 | parser = argparse.ArgumentParser(description="Description of your script") 52 | 53 | parser.add_argument("--ref_images", type=str, default="", help="Path to the user_id reference directory") 54 | parser.add_argument("--version1_dir", type=str, help="Path to version1 output result") 55 | parser.add_argument("--version2_dir", type=str, help="Path to version2 output result") 56 | parser.add_argument("--output_json", type=str, help="Path to output_datajson") 57 | 58 | args = parser.parse_args() 59 | 60 | image_formats = ["*.jpg", "*.jpeg", "*.png", 
"*.webp"] 61 | ref_images = [] 62 | for image_format in image_formats: 63 | ref_images.extend(glob(os.path.join(args.ref_images, image_format))) 64 | 65 | if len(ref_images) == 0: 66 | print(f"Your reference dir contains no reference images. Set --ref_images to your user_id reference directory") 67 | else: 68 | print(f"reference images contains : {ref_images}") 69 | 70 | ref_dicts = format_ref_images(ref_images) 71 | # print(ref_dicts) 72 | 73 | result_data = [] 74 | abs_path = True 75 | 76 | version1_dir = args.version1_dir 77 | version2_dir = args.version2_dir 78 | method_a = version1_dir.strip().split("/")[-1] 79 | method_b = version2_dir.strip().split("/")[-1] 80 | 81 | image_formats = ["jpg", "jpeg", "png", "webp"] 82 | 83 | for root, dirs, files in os.walk(version1_dir): 84 | for filename in files: 85 | if filename.split(".")[-1] in image_formats: 86 | file_path = os.path.join(root, filename) 87 | file_path2 = os.path.join(version2_dir, filename) 88 | file_path2 = match_prefix(file_path2, image_formats) 89 | 90 | reference = find_value_for_key(file_path, ref_dicts) 91 | 92 | if reference: 93 | if abs_path: 94 | file_path = os.path.abspath(file_path) 95 | file_path2 = os.path.abspath(file_path2) 96 | reference = [os.path.abspath(t) for t in reference] 97 | 98 | if os.path.exists(file_path2) and reference is not None: 99 | data_item = { 100 | "id": len(result_data), 101 | "method1": method_a, 102 | "img1": file_path, 103 | "method2": method_b, 104 | "img2": file_path2, 105 | "reference_imgs": reference, 106 | } 107 | 108 | result_data.append(data_item) 109 | else: 110 | pass 111 | else: 112 | user_id = file_path.split("/")[-1].split("_")[0] 113 | print( 114 | f"No matching user_id for {file_path}. Aborting! 
\ 115 | Please rename the ref image as {user_id}.jpg" 116 | ) 117 | 118 | output_json = args.output_json 119 | with open(output_json, "w") as json_file: 120 | json.dump(result_data, json_file, indent=4) 121 | 122 | print(f"Generated JSON file: {output_json}") 123 | -------------------------------------------------------------------------------- /api_test/double_blind/style.css: -------------------------------------------------------------------------------- 1 | .contain { 2 | display: flex; 3 | } 4 | 5 | #component-0 { 6 | flex-grow: 1; 7 | } 8 | 9 | #start { 10 | margin-top: -20vh; 11 | text-align: center; 12 | justify-content: center; 13 | } 14 | 15 | #start h3 { 16 | margin: 0; 17 | } 18 | 19 | #start #component-5 { 20 | display: block; 21 | } 22 | 23 | /* #start button { 24 | width: 40%; 25 | min-width: 280px; 26 | max-width: 500px; 27 | } */ 28 | 29 | #reference-imgs .grid-wrap { 30 | min-height: 0 !important; 31 | max-height: 60vh; 32 | } 33 | 34 | .question span { 35 | font-size: 1.1rem; 36 | } 37 | 38 | .question input { 39 | width: 1.1rem !important; 40 | height: 1.1rem !important; 41 | } 42 | -------------------------------------------------------------------------------- /api_test/post_infer.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import base64 3 | import json 4 | import os 5 | import time 6 | from datetime import datetime 7 | from glob import glob 8 | from io import BytesIO 9 | 10 | import cv2 11 | import numpy as np 12 | import requests 13 | from tqdm import tqdm 14 | 15 | 16 | def decode_image_from_base64jpeg(base64_image): 17 | image_bytes = base64.b64decode(base64_image) 18 | np_arr = np.frombuffer(image_bytes, np.uint8) 19 | image = cv2.imdecode(np_arr, cv2.IMREAD_COLOR) 20 | return image 21 | 22 | 23 | def post(encoded_image, user_id=None, url="http://0.0.0.0:7860"): 24 | if user_id is None: 25 | user_id = "test" 26 | datas = json.dumps( 27 | { 28 | "user_ids": [user_id], 29 | 
"sd_model_checkpoint": "Chilloutmix-Ni-pruned-fp16-fix.safetensors", 30 | "init_image": encoded_image, 31 | "first_diffusion_steps": 50, 32 | "first_denoising_strength": 0.45, 33 | "second_diffusion_steps": 20, 34 | "second_denoising_strength": 0.35, 35 | "seed": 12345, 36 | "crop_face_preprocess": True, 37 | "before_face_fusion_ratio": 0.5, 38 | "after_face_fusion_ratio": 0.5, 39 | "apply_face_fusion_before": True, 40 | "apply_face_fusion_after": True, 41 | "color_shift_middle": True, 42 | "color_shift_last": True, 43 | "super_resolution": True, 44 | "super_resolution_method": "gpen", 45 | "skin_retouching_bool": False, 46 | "background_restore": False, 47 | "background_restore_denoising_strength": 0.35, 48 | "makeup_transfer": False, 49 | "makeup_transfer_ratio": 0.50, 50 | "face_shape_match": False, 51 | "tabs": 1, 52 | "ipa_control": False, 53 | "ipa_weight": 0.50, 54 | "ipa_image": None, 55 | "ref_mode_choose": "Infer with Pretrained Lora", 56 | "ipa_only_weight": 0.60, 57 | "ipa_only_image": None, 58 | "lcm_accelerate": False, 59 | } 60 | ) 61 | r = requests.post(f"{url}/easyphoto/easyphoto_infer_forward", data=datas, timeout=1500) 62 | data = r.content.decode("utf-8") 63 | return data 64 | 65 | 66 | if __name__ == "__main__": 67 | """ 68 | There are two ways to test: 69 | The first: make sure the directory is full of readable images 70 | The second: public link of readable picture 71 | """ 72 | parser = argparse.ArgumentParser(description="Description of your script") 73 | 74 | parser.add_argument("--template_dir", type=str, default="", help="Path to the template directory") 75 | parser.add_argument("--output_path", type=str, default="./", help="Path to the output directory") 76 | parser.add_argument("--user_ids", type=str, default="test", help="Test user ids, split with space") 77 | 78 | args = parser.parse_args() 79 | 80 | template_dir = args.template_dir 81 | output_path = args.output_path 82 | user_ids = args.user_ids.split(" ") 83 | 84 | if output_path 
!= "./": 85 | os.makedirs(output_path, exist_ok=True) 86 | 87 | # initiate time 88 | now_date = datetime.now() 89 | time_start = time.time() 90 | 91 | # -------------------test infer------------------- # 92 | # When there is no parameter input. 93 | if template_dir == "": 94 | encoded_image = "https://pai-aigc-photog.oss-cn-hangzhou.aliyuncs.com/webui/api/template1.jpeg" 95 | encoded_image = requests.get(encoded_image) 96 | encoded_image = base64.b64encode(BytesIO(encoded_image.content).read()).decode("utf-8") 97 | 98 | for user_id in tqdm(user_ids): 99 | outputs = post(encoded_image, user_id) 100 | outputs = json.loads(outputs) 101 | image = decode_image_from_base64jpeg(outputs["outputs"][0]) 102 | toutput_path = os.path.join(os.path.join(output_path), f"{user_id}_tmp.jpg") 103 | cv2.imwrite(toutput_path, image) 104 | print(outputs["message"]) 105 | 106 | # When selecting a local file as a parameter input. 107 | else: 108 | image_formats = ["*.jpg", "*.jpeg", "*.png", "*.webp"] 109 | img_list = [] 110 | for image_format in image_formats: 111 | img_list.extend(glob(os.path.join(template_dir, image_format))) 112 | 113 | if len(img_list) == 0: 114 | print(f" Input template dir {template_dir} contains no images") 115 | else: 116 | print(f" Total {len(img_list)} templates to test for {len(user_ids)} ID") 117 | 118 | # please set your test user ids in args 119 | for user_id in tqdm(user_ids): 120 | for img_path in tqdm(img_list): 121 | print(f" Call generate for ID ({user_id}) and Template ({img_path})") 122 | 123 | with open(img_path, "rb") as f: 124 | encoded_image = base64.b64encode(f.read()).decode("utf-8") 125 | outputs = post(encoded_image, user_id=user_id) 126 | outputs = json.loads(outputs) 127 | 128 | if len(outputs["outputs"]): 129 | image = decode_image_from_base64jpeg(outputs["outputs"][0]) 130 | toutput_path = os.path.join(os.path.join(output_path), f"{user_id}_" + os.path.basename(img_path)) 131 | print(output_path) 132 | cv2.imwrite(toutput_path, image) 
133 | else: 134 | print("Error!", outputs["message"]) 135 | print(outputs["message"]) 136 | 137 | # End of record time 138 | # The calculated time difference is the execution time of the program, expressed in seconds / s 139 | time_end = time.time() 140 | time_sum = time_end - time_start 141 | print("# --------------------------------------------------------- #") 142 | print(f"# Total expenditure: {time_sum}s") 143 | print("# --------------------------------------------------------- #") 144 | -------------------------------------------------------------------------------- /api_test/post_train.py: -------------------------------------------------------------------------------- 1 | import base64 2 | import json 3 | import os 4 | import sys 5 | import time 6 | from glob import glob 7 | from io import BytesIO 8 | 9 | import requests 10 | 11 | 12 | def post_train(encoded_images, url="http://0.0.0.0:7860"): 13 | datas = json.dumps( 14 | { 15 | "user_id": "test", # A custom ID that identifies the trained face model 16 | "sd_model_checkpoint": "Chilloutmix-Ni-pruned-fp16-fix.safetensors", 17 | "train_mode_choose": "Train Human Lora", 18 | "resolution": 512, 19 | "val_and_checkpointing_steps": 100, 20 | "max_train_steps": 800, # Training steps 21 | "steps_per_photos": 200, 22 | "train_batch_size": 1, 23 | "gradient_accumulation_steps": 4, 24 | "dataloader_num_workers": 16, 25 | "learning_rate": 1e-4, 26 | "rank": 128, 27 | "network_alpha": 64, 28 | "instance_images": encoded_images, 29 | "skin_retouching_bool": False, 30 | } 31 | ) 32 | r = requests.post(f"{url}/easyphoto/easyphoto_train_forward", data=datas, timeout=1500) 33 | data = r.content.decode("utf-8") 34 | return data 35 | 36 | 37 | if __name__ == "__main__": 38 | """ 39 | There are two ways to test: 40 | The first: make sure the directory is full of readable images 41 | The second: public link of readable picture 42 | """ 43 | # initiate time 44 | time_start = time.time() 45 | 46 | # -------------------training 
procedure------------------- # 47 | # When there is no parameter input. 48 | if len(sys.argv) == 1: 49 | img_list = [ 50 | "https://pai-aigc-photog.oss-cn-hangzhou.aliyuncs.com/webui/api/t1.jpg", 51 | "https://pai-aigc-photog.oss-cn-hangzhou.aliyuncs.com/webui/api/t2.jpg", 52 | "https://pai-aigc-photog.oss-cn-hangzhou.aliyuncs.com/webui/api/t3.jpg", 53 | "https://pai-aigc-photog.oss-cn-hangzhou.aliyuncs.com/webui/api/t4.jpg", 54 | ] 55 | encoded_images = [] 56 | for idx, img_path in enumerate(img_list): 57 | encoded_image = requests.get(img_path) 58 | encoded_image = base64.b64encode(BytesIO(encoded_image.content).read()).decode("utf-8") 59 | encoded_images.append(encoded_image) 60 | 61 | outputs = post_train(encoded_images) 62 | outputs = json.loads(outputs) 63 | print(outputs["message"]) 64 | 65 | # When selecting a folder as a parameter input. 66 | elif len(sys.argv) == 2: 67 | img_list = glob(os.path.join(sys.argv[1], "*")) 68 | encoded_images = [] 69 | for idx, img_path in enumerate(img_list): 70 | with open(img_path, "rb") as f: 71 | encoded_image = base64.b64encode(f.read()).decode("utf-8") 72 | encoded_images.append(encoded_image) 73 | outputs = post_train(encoded_images) 74 | outputs = json.loads(outputs) 75 | print(outputs["message"]) 76 | 77 | else: 78 | print("other modes except url and local read are not supported") 79 | 80 | # End of record time 81 | # The calculated time difference is the execution time of the program, expressed in minute / m 82 | time_end = time.time() 83 | time_sum = (time_end - time_start) // 60 84 | 85 | print("# --------------------------------------------------------- #") 86 | print(f"# Total expenditure:{time_sum} minutes ") 87 | print("# --------------------------------------------------------- #") 88 | -------------------------------------------------------------------------------- /api_test/post_video_infer.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import 
import base64
import json
import os
import time

import requests
from tqdm import tqdm


def encode_video_to_base64(video_file_path):
    """Read a video file and return its contents Base64-encoded.

    Args:
        video_file_path: Path of the video file to read.

    Returns:
        The Base64-encoded file contents as ``bytes``. Call ``.decode("utf-8")``
        on the result before placing it in a JSON payload.
    """
    with open(video_file_path, "rb") as video_file:
        return base64.b64encode(video_file.read())


def decode_base64_to_video(encoded_video, output_file_path):
    """Decode Base64 data and write the raw bytes to ``output_file_path``.

    Args:
        encoded_video: Base64 data as ``str`` or ``bytes``.
        output_file_path: Destination file; created or overwritten.
    """
    # Decode before opening the file so invalid input does not leave an empty file behind.
    video_data = base64.b64decode(encoded_video)
    with open(output_file_path, "wb") as output_file:
        output_file.write(video_data)


def _read_file_as_base64(file_path):
    """Return the contents of ``file_path`` as a Base64-encoded UTF-8 string."""
    with open(file_path, "rb") as f:
        return base64.b64encode(f.read()).decode("utf-8")


def _save_outputs(outputs, output_path, user_id):
    """Write the returned video (preferred) or GIF for ``user_id`` into ``output_path``.

    Nothing is written when the response carries neither ``output_video`` nor
    ``output_gif`` (e.g. when the server only reports an error message).
    """
    if outputs.get("output_video") is not None:
        decode_base64_to_video(outputs["output_video"], os.path.join(output_path, f"{user_id}_tmp.mp4"))
    elif outputs.get("output_gif") is not None:
        decode_base64_to_video(outputs["output_gif"], os.path.join(output_path, f"{user_id}_tmp.gif"))


def post(t2v_input_prompt="", init_image=None, last_image=None, init_video=None, user_id=None, tabs=0, url="http://0.0.0.0:7860"):
    """Send one request to the EasyPhoto video inference endpoint.

    Args:
        t2v_input_prompt: Prompt sent as ``t2v_input_prompt``.
        init_image: Base64 string of the first-frame image, or ``None``.
        last_image: Base64 string of the last-frame image, or ``None``.
        init_video: Base64 string of the input video, or ``None``.
        user_id: Id placed in ``user_ids``; defaults to ``"test"``.
        tabs: Mode index sent as ``tabs`` (this script uses 0, 1 and 2).
        url: Base URL of the SD WebUI server.

    Returns:
        The response body decoded as UTF-8 (expected to be a JSON document).
    """
    if user_id is None:
        user_id = "test"
    datas = json.dumps(
        {
            "user_ids": [user_id],
            "sd_model_checkpoint": "Chilloutmix-Ni-pruned-fp16-fix.safetensors",
            "sd_model_checkpoint_for_animatediff_text2video": "majicmixRealistic_v7.safetensors",
            "sd_model_checkpoint_for_animatediff_image2video": "majicmixRealistic_v7.safetensors",
            "t2v_input_prompt": t2v_input_prompt,
            "t2v_input_width": 512,
            "t2v_input_height": 768,
            "scene_id": "none",
            "upload_control_video": False,
            "upload_control_video_type": "openpose",
            "openpose_video": None,
            "init_image": init_image,
            "init_image_prompt": "",
            "last_image": last_image,
            "init_video": init_video,
            "additional_prompt": "masterpiece, beauty",
            "max_frames": 16,
            "max_fps": 8,
            "save_as": "gif",
            "first_diffusion_steps": 50,
            "first_denoising_strength": 0.45,
            "seed": -1,
            "crop_face_preprocess": True,
            "before_face_fusion_ratio": 0.5,
            "after_face_fusion_ratio": 0.5,
            "apply_face_fusion_before": True,
            "apply_face_fusion_after": True,
            "color_shift_middle": True,
            "super_resolution": True,
            "super_resolution_method": "gpen",
            "skin_retouching_bool": False,
            "makeup_transfer": False,
            "makeup_transfer_ratio": 0.50,
            "face_shape_match": False,
            "video_interpolation": False,
            "video_interpolation_ext": 1,
            "tabs": tabs,
            "ipa_control": False,
            "ipa_weight": 0.50,
            "ipa_image": None,
            "lcm_accelerate": False,
        }
    )
    r = requests.post(f"{url}/easyphoto/easyphoto_video_infer_forward", data=datas, timeout=1500)
    data = r.content.decode("utf-8")
    return data


if __name__ == "__main__":
    # Supported inputs, selected by which paths are non-empty:
    #   - no paths                                   -> prompt only, tabs=0
    #   - --init_image_path                          -> first image, tabs=1
    #   - --init_image_path and --last_image_path    -> first + last image, tabs=1
    #   - --video_path                               -> input video, tabs=2
    # Any other combination is rejected with a message.
    parser = argparse.ArgumentParser(description="Description of your script")

    parser.add_argument(
        "--t2v_input_prompt",
        type=str,
        default="1girl, (white hair, long hair), blue eyes, hair ornament, blue dress, standing, looking at viewer, shy, upper-body",
        help="Prompt for t2v",
    )
    parser.add_argument("--init_image_path", type=str, default="", help="Path to the init image path")
    parser.add_argument("--last_image_path", type=str, default="", help="Path to the last image path")
    parser.add_argument("--video_path", type=str, default="", help="Path to the video path")
    parser.add_argument("--output_path", type=str, default="./", help="Path to the output directory")
    parser.add_argument("--user_ids", type=str, default="test", help="Test user ids, split with space")

    args = parser.parse_args()

    t2v_input_prompt = args.t2v_input_prompt
    init_image_path = args.init_image_path
    last_image_path = args.last_image_path
    video_path = args.video_path
    output_path = args.output_path
    user_ids = args.user_ids.split(" ")

    if output_path != "./":
        os.makedirs(output_path, exist_ok=True)

    # initiate time
    time_start = time.time()

    # -------------------test infer------------------- #
    has_init = init_image_path != ""
    has_last = last_image_path != ""
    has_video = video_path != ""

    # (init_image, last_image, init_video, tabs) for the chosen mode; files are read once, before the loop.
    if not has_init and not has_last and not has_video:
        request_inputs = (None, None, None, 0)
    elif has_init and not has_last and not has_video:
        request_inputs = (_read_file_as_base64(init_image_path), None, None, 1)
    elif has_init and has_last and not has_video:
        request_inputs = (_read_file_as_base64(init_image_path), _read_file_as_base64(last_image_path), None, 1)
    elif not has_init and not has_last and has_video:
        request_inputs = (None, None, _read_file_as_base64(video_path), 2)
    else:
        request_inputs = None
        print(
            "Unsupported input combination: pass no paths, --init_image_path, "
            "--init_image_path with --last_image_path, or --video_path alone."
        )

    if request_inputs is not None:
        init_image, last_image, init_video, tabs = request_inputs
        for user_id in tqdm(user_ids):
            outputs = json.loads(post(t2v_input_prompt, init_image, last_image, init_video, user_id, tabs=tabs))
            # Error payloads may omit keys, so use .get instead of indexing.
            print(outputs.get("message"))
            _save_outputs(outputs, output_path, user_id)

    # End of record time
    # The calculated time difference is the execution time of the program, expressed in seconds / s
    time_end = time.time()
    time_sum = time_end - time_start
    print("# --------------------------------------------------------- #")
    print(f"# Total expenditure: {time_sum}s")
    print("# --------------------------------------------------------- #")
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/aigc-apps/sd-webui-EasyPhoto/6c1f8959c4024b0f8a1187fae7c50c0c8f86a8df/images/controlnet_num.jpg -------------------------------------------------------------------------------- /images/ding_erweima.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aigc-apps/sd-webui-EasyPhoto/6c1f8959c4024b0f8a1187fae7c50c0c8f86a8df/images/ding_erweima.jpg -------------------------------------------------------------------------------- /images/double_blindui.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aigc-apps/sd-webui-EasyPhoto/6c1f8959c4024b0f8a1187fae7c50c0c8f86a8df/images/double_blindui.jpg -------------------------------------------------------------------------------- /images/dsw.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aigc-apps/sd-webui-EasyPhoto/6c1f8959c4024b0f8a1187fae7c50c0c8f86a8df/images/dsw.png -------------------------------------------------------------------------------- /images/erweima.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aigc-apps/sd-webui-EasyPhoto/6c1f8959c4024b0f8a1187fae7c50c0c8f86a8df/images/erweima.jpg -------------------------------------------------------------------------------- /images/infer_ui.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aigc-apps/sd-webui-EasyPhoto/6c1f8959c4024b0f8a1187fae7c50c0c8f86a8df/images/infer_ui.jpg -------------------------------------------------------------------------------- /images/install.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/aigc-apps/sd-webui-EasyPhoto/6c1f8959c4024b0f8a1187fae7c50c0c8f86a8df/images/install.jpg -------------------------------------------------------------------------------- /images/multi_people_1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aigc-apps/sd-webui-EasyPhoto/6c1f8959c4024b0f8a1187fae7c50c0c8f86a8df/images/multi_people_1.jpg -------------------------------------------------------------------------------- /images/multi_people_2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aigc-apps/sd-webui-EasyPhoto/6c1f8959c4024b0f8a1187fae7c50c0c8f86a8df/images/multi_people_2.jpg -------------------------------------------------------------------------------- /images/no_found_image.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aigc-apps/sd-webui-EasyPhoto/6c1f8959c4024b0f8a1187fae7c50c0c8f86a8df/images/no_found_image.jpg -------------------------------------------------------------------------------- /images/overview.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aigc-apps/sd-webui-EasyPhoto/6c1f8959c4024b0f8a1187fae7c50c0c8f86a8df/images/overview.jpg -------------------------------------------------------------------------------- /images/results_1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aigc-apps/sd-webui-EasyPhoto/6c1f8959c4024b0f8a1187fae7c50c0c8f86a8df/images/results_1.jpg -------------------------------------------------------------------------------- /images/results_2.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/aigc-apps/sd-webui-EasyPhoto/6c1f8959c4024b0f8a1187fae7c50c0c8f86a8df/images/results_2.jpg -------------------------------------------------------------------------------- /images/results_3.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aigc-apps/sd-webui-EasyPhoto/6c1f8959c4024b0f8a1187fae7c50c0c8f86a8df/images/results_3.jpg -------------------------------------------------------------------------------- /images/scene_lora/Christmas_1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aigc-apps/sd-webui-EasyPhoto/6c1f8959c4024b0f8a1187fae7c50c0c8f86a8df/images/scene_lora/Christmas_1.jpg -------------------------------------------------------------------------------- /images/scene_lora/Cyberpunk_1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aigc-apps/sd-webui-EasyPhoto/6c1f8959c4024b0f8a1187fae7c50c0c8f86a8df/images/scene_lora/Cyberpunk_1.jpg -------------------------------------------------------------------------------- /images/scene_lora/FairMaidenStyle_1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aigc-apps/sd-webui-EasyPhoto/6c1f8959c4024b0f8a1187fae7c50c0c8f86a8df/images/scene_lora/FairMaidenStyle_1.jpg -------------------------------------------------------------------------------- /images/scene_lora/Gentleman_1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aigc-apps/sd-webui-EasyPhoto/6c1f8959c4024b0f8a1187fae7c50c0c8f86a8df/images/scene_lora/Gentleman_1.jpg -------------------------------------------------------------------------------- /images/scene_lora/GuoFeng_1.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/aigc-apps/sd-webui-EasyPhoto/6c1f8959c4024b0f8a1187fae7c50c0c8f86a8df/images/scene_lora/GuoFeng_1.jpg -------------------------------------------------------------------------------- /images/scene_lora/GuoFeng_2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aigc-apps/sd-webui-EasyPhoto/6c1f8959c4024b0f8a1187fae7c50c0c8f86a8df/images/scene_lora/GuoFeng_2.jpg -------------------------------------------------------------------------------- /images/scene_lora/GuoFeng_3.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aigc-apps/sd-webui-EasyPhoto/6c1f8959c4024b0f8a1187fae7c50c0c8f86a8df/images/scene_lora/GuoFeng_3.jpg -------------------------------------------------------------------------------- /images/scene_lora/GuoFeng_4.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aigc-apps/sd-webui-EasyPhoto/6c1f8959c4024b0f8a1187fae7c50c0c8f86a8df/images/scene_lora/GuoFeng_4.jpg -------------------------------------------------------------------------------- /images/scene_lora/Minimalism_1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aigc-apps/sd-webui-EasyPhoto/6c1f8959c4024b0f8a1187fae7c50c0c8f86a8df/images/scene_lora/Minimalism_1.jpg -------------------------------------------------------------------------------- /images/scene_lora/NaturalWind_1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aigc-apps/sd-webui-EasyPhoto/6c1f8959c4024b0f8a1187fae7c50c0c8f86a8df/images/scene_lora/NaturalWind_1.jpg -------------------------------------------------------------------------------- /images/scene_lora/Princess_1.jpg: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/aigc-apps/sd-webui-EasyPhoto/6c1f8959c4024b0f8a1187fae7c50c0c8f86a8df/images/scene_lora/Princess_1.jpg -------------------------------------------------------------------------------- /images/scene_lora/Princess_2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aigc-apps/sd-webui-EasyPhoto/6c1f8959c4024b0f8a1187fae7c50c0c8f86a8df/images/scene_lora/Princess_2.jpg -------------------------------------------------------------------------------- /images/scene_lora/Princess_3.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aigc-apps/sd-webui-EasyPhoto/6c1f8959c4024b0f8a1187fae7c50c0c8f86a8df/images/scene_lora/Princess_3.jpg -------------------------------------------------------------------------------- /images/scene_lora/SchoolUniform_1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aigc-apps/sd-webui-EasyPhoto/6c1f8959c4024b0f8a1187fae7c50c0c8f86a8df/images/scene_lora/SchoolUniform_1.jpg -------------------------------------------------------------------------------- /images/scene_lora/SchoolUniform_2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aigc-apps/sd-webui-EasyPhoto/6c1f8959c4024b0f8a1187fae7c50c0c8f86a8df/images/scene_lora/SchoolUniform_2.jpg -------------------------------------------------------------------------------- /images/single_people.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aigc-apps/sd-webui-EasyPhoto/6c1f8959c4024b0f8a1187fae7c50c0c8f86a8df/images/single_people.jpg -------------------------------------------------------------------------------- /images/train_1.jpg: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/aigc-apps/sd-webui-EasyPhoto/6c1f8959c4024b0f8a1187fae7c50c0c8f86a8df/images/train_1.jpg -------------------------------------------------------------------------------- /images/train_2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aigc-apps/sd-webui-EasyPhoto/6c1f8959c4024b0f8a1187fae7c50c0c8f86a8df/images/train_2.jpg -------------------------------------------------------------------------------- /images/train_3.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aigc-apps/sd-webui-EasyPhoto/6c1f8959c4024b0f8a1187fae7c50c0c8f86a8df/images/train_3.jpg -------------------------------------------------------------------------------- /images/train_detail.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aigc-apps/sd-webui-EasyPhoto/6c1f8959c4024b0f8a1187fae7c50c0c8f86a8df/images/train_detail.jpg -------------------------------------------------------------------------------- /images/train_detail1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aigc-apps/sd-webui-EasyPhoto/6c1f8959c4024b0f8a1187fae7c50c0c8f86a8df/images/train_detail1.jpg -------------------------------------------------------------------------------- /images/train_ui.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aigc-apps/sd-webui-EasyPhoto/6c1f8959c4024b0f8a1187fae7c50c0c8f86a8df/images/train_ui.jpg -------------------------------------------------------------------------------- /images/tryon/cloth/demo_black_200.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/aigc-apps/sd-webui-EasyPhoto/6c1f8959c4024b0f8a1187fae7c50c0c8f86a8df/images/tryon/cloth/demo_black_200.jpg -------------------------------------------------------------------------------- /images/tryon/cloth/demo_dress_200.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aigc-apps/sd-webui-EasyPhoto/6c1f8959c4024b0f8a1187fae7c50c0c8f86a8df/images/tryon/cloth/demo_dress_200.jpg -------------------------------------------------------------------------------- /images/tryon/cloth/demo_purple_200.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aigc-apps/sd-webui-EasyPhoto/6c1f8959c4024b0f8a1187fae7c50c0c8f86a8df/images/tryon/cloth/demo_purple_200.jpg -------------------------------------------------------------------------------- /images/tryon/cloth/demo_short_200.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aigc-apps/sd-webui-EasyPhoto/6c1f8959c4024b0f8a1187fae7c50c0c8f86a8df/images/tryon/cloth/demo_short_200.jpg -------------------------------------------------------------------------------- /images/tryon/cloth/demo_white_200.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aigc-apps/sd-webui-EasyPhoto/6c1f8959c4024b0f8a1187fae7c50c0c8f86a8df/images/tryon/cloth/demo_white_200.jpg -------------------------------------------------------------------------------- /images/tryon/template/boy.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aigc-apps/sd-webui-EasyPhoto/6c1f8959c4024b0f8a1187fae7c50c0c8f86a8df/images/tryon/template/boy.jpg -------------------------------------------------------------------------------- /images/tryon/template/dress.jpg: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/aigc-apps/sd-webui-EasyPhoto/6c1f8959c4024b0f8a1187fae7c50c0c8f86a8df/images/tryon/template/dress.jpg -------------------------------------------------------------------------------- /images/tryon/template/girl.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aigc-apps/sd-webui-EasyPhoto/6c1f8959c4024b0f8a1187fae7c50c0c8f86a8df/images/tryon/template/girl.jpg -------------------------------------------------------------------------------- /images/tryon/template/short.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aigc-apps/sd-webui-EasyPhoto/6c1f8959c4024b0f8a1187fae7c50c0c8f86a8df/images/tryon/template/short.jpg -------------------------------------------------------------------------------- /images/wechat.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aigc-apps/sd-webui-EasyPhoto/6c1f8959c4024b0f8a1187fae7c50c0c8f86a8df/images/wechat.jpg -------------------------------------------------------------------------------- /install.py: -------------------------------------------------------------------------------- 1 | # Package check util 2 | # Modified from https://github.com/Bing-su/adetailer/blob/main/install.py 3 | import importlib.util 4 | import platform 5 | from importlib.metadata import version 6 | 7 | import launch 8 | from packaging.version import parse 9 | 10 | 11 | def is_installed(package: str): 12 | min_version = "0.0.0" 13 | max_version = "99999999.99999999.99999999" 14 | pkg_name = package 15 | version_check = True 16 | if "==" in package: 17 | pkg_name, _version = package.split("==") 18 | min_version = max_version = _version 19 | elif "<=" in package: 20 | pkg_name, _version = package.split("<=") 21 | max_version = _version 22 
| elif ">=" in package: 23 | pkg_name, _version = package.split(">=") 24 | min_version = _version 25 | else: 26 | version_check = False 27 | package = pkg_name 28 | try: 29 | spec = importlib.util.find_spec(package) 30 | except ModuleNotFoundError: 31 | message = f"is_installed check for {str(package)} failed as error ModuleNotFoundError" 32 | print(message) 33 | return False 34 | if spec is None: 35 | message = f"is_installed check for {str(package)} failed as 'spec is None'" 36 | print(message) 37 | return False 38 | if not version_check: 39 | return True 40 | if package == "google.protobuf": 41 | package = "protobuf" 42 | try: 43 | pkg_version = version(package) 44 | return parse(min_version) <= parse(pkg_version) <= parse(max_version) 45 | except Exception as e: 46 | message = f"is_installed check for {str(package)} failed as error {str(e)}" 47 | print(message) 48 | return False 49 | 50 | 51 | # End of Package check util 52 | 53 | if not is_installed("cv2"): 54 | print("Installing requirements for easyphoto-webui") 55 | launch.run_pip("install opencv-python", "requirements for opencv") 56 | 57 | if not is_installed("tensorflow-cpu"): 58 | print("Installing requirements for easyphoto-webui") 59 | launch.run_pip("install tensorflow-cpu", "requirements for tensorflow") 60 | 61 | if not is_installed("onnx"): 62 | print("Installing requirements for easyphoto-webui") 63 | launch.run_pip("install onnx", "requirements for onnx") 64 | 65 | if not is_installed("onnxruntime"): 66 | print("Installing requirements for easyphoto-webui") 67 | launch.run_pip("install onnxruntime", "requirements for onnxruntime") 68 | 69 | if not is_installed("modelscope==1.9.3"): 70 | print("Installing requirements for easyphoto-webui") 71 | launch.run_pip("install modelscope==1.9.3", "requirements for modelscope") 72 | 73 | if not is_installed("einops"): 74 | print("Installing requirements for easyphoto-webui") 75 | launch.run_pip("install einops", "requirements for diffusers") 76 | 77 | if 
not is_installed("imageio>=2.29.0"): 78 | print("Installing requirements for easyphoto-webui") 79 | # The '>' will be interpreted as redirection (in linux) since SD WebUI uses `shell=True` in `subprocess.run`. 80 | launch.run_pip("install \"imageio>=2.29.0\"", "requirements for imageio") 81 | 82 | if not is_installed("av"): 83 | print("Installing requirements for easyphoto-webui") 84 | launch.run_pip("install \"imageio[pyav]\"", "requirements for av") 85 | 86 | # Temporarily pin fsspec==2023.9.2. See https://github.com/huggingface/datasets/issues/6330 for details. 87 | if not is_installed("fsspec==2023.9.2"): 88 | print("Installing requirements for easyphoto-webui") 89 | launch.run_pip("install fsspec==2023.9.2", "requirements for fsspec") 90 | 91 | # `StableDiffusionXLPipeline` in diffusers requires the invisible-watermark library. 92 | if not launch.is_installed("invisible-watermark"): 93 | print("Installing requirements for easyphoto-webui") 94 | launch.run_pip("install invisible-watermark", "requirements for invisible-watermark") 95 | 96 | # Tryon requires the shapely and segment-anything library. 97 | if not launch.is_installed("shapely"): 98 | print("Installing requirements for easyphoto-webui") 99 | launch.run_pip("install shapely", "requirements for shapely") 100 | 101 | if not launch.is_installed("segment_anything"): 102 | try: 103 | launch.run_pip("install segment-anything", "requirements for segment_anything") 104 | except Exception: 105 | print("Can't install segment-anything. Please follow the readme to install manually") 106 | 107 | if not is_installed("diffusers>=0.18.2"): 108 | print("Installing requirements for easyphoto-webui") 109 | try: 110 | launch.run_pip("install diffusers==0.23.0", "requirements for diffusers") 111 | except Exception as e: 112 | print(f"Can't install the diffusers==0.23.0. 
Error info {e}") 113 | launch.run_pip("install diffusers==0.18.2", "requirements for diffusers") 114 | 115 | if platform.system() != "Windows": 116 | if not is_installed("nvitop"): 117 | print("Installing requirements for easyphoto-webui") 118 | launch.run_pip("install nvitop==1.3.0", "requirements for tensorflow") 119 | -------------------------------------------------------------------------------- /javascript/ui.js: -------------------------------------------------------------------------------- 1 | function ask_for_style_name(sd_model_checkpoint, dummy_component, _, train_mode_choose, resolution, val_and_checkpointing_steps, max_train_steps, steps_per_photos, train_batch_size, gradient_accumulation_steps, dataloader_num_workers, learning_rate, rank, network_alpha, validation, instance_images, enable_rl, max_rl_time, timestep_fraction, skin_retouching, training_prefix_prompt, crop_ratio) { 2 | var name_ = prompt('User id:'); 3 | return [sd_model_checkpoint, dummy_component, name_, train_mode_choose, resolution, val_and_checkpointing_steps, max_train_steps, steps_per_photos, train_batch_size, gradient_accumulation_steps, dataloader_num_workers, learning_rate, rank, network_alpha, validation, instance_images, enable_rl, max_rl_time, timestep_fraction, skin_retouching, training_prefix_prompt, crop_ratio]; 4 | } 5 | 6 | function switch_to_ep_photoinfer_upload() { 7 | gradioApp().getElementById('mode_easyphoto').querySelectorAll('button')[1].click(); 8 | gradioApp().getElementById('mode_easyphoto_photo_inference').querySelectorAll('button')[0].click(); 9 | 10 | return Array.from(arguments); 11 | } 12 | 13 | function switch_to_ep_tryon() { 14 | gradioApp().getElementById('mode_easyphoto').querySelectorAll('button')[3].click(); 15 | 16 | return Array.from(arguments); 17 | } 18 | -------------------------------------------------------------------------------- /models/infer_templates/1.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/aigc-apps/sd-webui-EasyPhoto/6c1f8959c4024b0f8a1187fae7c50c0c8f86a8df/models/infer_templates/1.jpg -------------------------------------------------------------------------------- /models/infer_templates/3.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aigc-apps/sd-webui-EasyPhoto/6c1f8959c4024b0f8a1187fae7c50c0c8f86a8df/models/infer_templates/3.jpg -------------------------------------------------------------------------------- /models/infer_templates/4.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aigc-apps/sd-webui-EasyPhoto/6c1f8959c4024b0f8a1187fae7c50c0c8f86a8df/models/infer_templates/4.jpg -------------------------------------------------------------------------------- /models/infer_templates/5.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aigc-apps/sd-webui-EasyPhoto/6c1f8959c4024b0f8a1187fae7c50c0c8f86a8df/models/infer_templates/5.jpg -------------------------------------------------------------------------------- /models/infer_templates/6.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aigc-apps/sd-webui-EasyPhoto/6c1f8959c4024b0f8a1187fae7c50c0c8f86a8df/models/infer_templates/6.jpg -------------------------------------------------------------------------------- /models/infer_templates/7.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aigc-apps/sd-webui-EasyPhoto/6c1f8959c4024b0f8a1187fae7c50c0c8f86a8df/models/infer_templates/7.jpg -------------------------------------------------------------------------------- /models/infer_templates/8.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/aigc-apps/sd-webui-EasyPhoto/6c1f8959c4024b0f8a1187fae7c50c0c8f86a8df/models/infer_templates/8.jpg -------------------------------------------------------------------------------- /models/infer_templates/9.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aigc-apps/sd-webui-EasyPhoto/6c1f8959c4024b0f8a1187fae7c50c0c8f86a8df/models/infer_templates/9.jpg -------------------------------------------------------------------------------- /models/infer_templates/Put templates here: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aigc-apps/sd-webui-EasyPhoto/6c1f8959c4024b0f8a1187fae7c50c0c8f86a8df/models/infer_templates/Put templates here -------------------------------------------------------------------------------- /models/stable-diffusion-v1-5/model_index.json: -------------------------------------------------------------------------------- 1 | { 2 | "_class_name": "StableDiffusionPipeline", 3 | "_diffusers_version": "0.6.0", 4 | "feature_extractor": [ 5 | "transformers", 6 | "CLIPImageProcessor" 7 | ], 8 | "safety_checker": [ 9 | "stable_diffusion", 10 | "StableDiffusionSafetyChecker" 11 | ], 12 | "scheduler": [ 13 | "diffusers", 14 | "PNDMScheduler" 15 | ], 16 | "text_encoder": [ 17 | "transformers", 18 | "CLIPTextModel" 19 | ], 20 | "tokenizer": [ 21 | "transformers", 22 | "CLIPTokenizer" 23 | ], 24 | "unet": [ 25 | "diffusers", 26 | "UNet2DConditionModel" 27 | ], 28 | "vae": [ 29 | "diffusers", 30 | "AutoencoderKL" 31 | ] 32 | } 33 | -------------------------------------------------------------------------------- /models/stable-diffusion-v1-5/scheduler/scheduler_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "_class_name": "PNDMScheduler", 3 | "_diffusers_version": "0.6.0", 4 | "beta_end": 0.012, 5 | "beta_schedule": "scaled_linear", 6 | 
"beta_start": 0.00085, 7 | "num_train_timesteps": 1000, 8 | "set_alpha_to_one": false, 9 | "skip_prk_steps": true, 10 | "steps_offset": 1, 11 | "trained_betas": null, 12 | "clip_sample": false 13 | } 14 | -------------------------------------------------------------------------------- /models/stable-diffusion-v1-5/tokenizer/special_tokens_map.json: -------------------------------------------------------------------------------- 1 | { 2 | "bos_token": { 3 | "content": "<|startoftext|>", 4 | "lstrip": false, 5 | "normalized": true, 6 | "rstrip": false, 7 | "single_word": false 8 | }, 9 | "eos_token": { 10 | "content": "<|endoftext|>", 11 | "lstrip": false, 12 | "normalized": true, 13 | "rstrip": false, 14 | "single_word": false 15 | }, 16 | "pad_token": "<|endoftext|>", 17 | "unk_token": { 18 | "content": "<|endoftext|>", 19 | "lstrip": false, 20 | "normalized": true, 21 | "rstrip": false, 22 | "single_word": false 23 | } 24 | } 25 | -------------------------------------------------------------------------------- /models/stable-diffusion-v1-5/tokenizer/tokenizer_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "add_prefix_space": false, 3 | "bos_token": { 4 | "__type": "AddedToken", 5 | "content": "<|startoftext|>", 6 | "lstrip": false, 7 | "normalized": true, 8 | "rstrip": false, 9 | "single_word": false 10 | }, 11 | "do_lower_case": true, 12 | "eos_token": { 13 | "__type": "AddedToken", 14 | "content": "<|endoftext|>", 15 | "lstrip": false, 16 | "normalized": true, 17 | "rstrip": false, 18 | "single_word": false 19 | }, 20 | "errors": "replace", 21 | "model_max_length": 77, 22 | "name_or_path": "openai/clip-vit-large-patch14", 23 | "pad_token": "<|endoftext|>", 24 | "special_tokens_map_file": "./special_tokens_map.json", 25 | "tokenizer_class": "CLIPTokenizer", 26 | "unk_token": { 27 | "__type": "AddedToken", 28 | "content": "<|endoftext|>", 29 | "lstrip": false, 30 | "normalized": true, 31 | "rstrip": false, 
32 | "single_word": false 33 | } 34 | } 35 | -------------------------------------------------------------------------------- /models/stable-diffusion-xl/madebyollin_sdxl_vae_fp16_fix/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "_class_name": "AutoencoderKL", 3 | "_diffusers_version": "0.18.0.dev0", 4 | "_name_or_path": ".", 5 | "act_fn": "silu", 6 | "block_out_channels": [ 7 | 128, 8 | 256, 9 | 512, 10 | 512 11 | ], 12 | "down_block_types": [ 13 | "DownEncoderBlock2D", 14 | "DownEncoderBlock2D", 15 | "DownEncoderBlock2D", 16 | "DownEncoderBlock2D" 17 | ], 18 | "in_channels": 3, 19 | "latent_channels": 4, 20 | "layers_per_block": 2, 21 | "norm_num_groups": 32, 22 | "out_channels": 3, 23 | "sample_size": 512, 24 | "scaling_factor": 0.13025, 25 | "up_block_types": [ 26 | "UpDecoderBlock2D", 27 | "UpDecoderBlock2D", 28 | "UpDecoderBlock2D", 29 | "UpDecoderBlock2D" 30 | ], 31 | "force_upcast": false 32 | } 33 | -------------------------------------------------------------------------------- /models/stable-diffusion-xl/models--laion--CLIP-ViT-bigG-14-laion2B-39B-b160k/refs/main: -------------------------------------------------------------------------------- 1 | 8c7a3583335de4dba1b07182dbf81c75137ce67b -------------------------------------------------------------------------------- /models/stable-diffusion-xl/models--laion--CLIP-ViT-bigG-14-laion2B-39B-b160k/snapshots/8c7a3583335de4dba1b07182dbf81c75137ce67b/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "_commit_hash": null, 3 | "architectures": [ 4 | "CLIPModel" 5 | ], 6 | "initializer_factor": 1.0, 7 | "logit_scale_init_value": 2.6592, 8 | "model_type": "clip", 9 | "projection_dim": 1280, 10 | "text_config": { 11 | "_name_or_path": "", 12 | "add_cross_attention": false, 13 | "architectures": null, 14 | "attention_dropout": 0.0, 15 | "bad_words_ids": null, 16 | "begin_suppress_tokens": null, 17 | 
"bos_token_id": 0, 18 | "chunk_size_feed_forward": 0, 19 | "cross_attention_hidden_size": null, 20 | "decoder_start_token_id": null, 21 | "diversity_penalty": 0.0, 22 | "do_sample": false, 23 | "dropout": 0.0, 24 | "early_stopping": false, 25 | "encoder_no_repeat_ngram_size": 0, 26 | "eos_token_id": 2, 27 | "exponential_decay_length_penalty": null, 28 | "finetuning_task": null, 29 | "forced_bos_token_id": null, 30 | "forced_eos_token_id": null, 31 | "hidden_act": "gelu", 32 | "hidden_size": 1280, 33 | "id2label": { 34 | "0": "LABEL_0", 35 | "1": "LABEL_1" 36 | }, 37 | "initializer_factor": 1.0, 38 | "initializer_range": 0.02, 39 | "intermediate_size": 5120, 40 | "is_decoder": false, 41 | "is_encoder_decoder": false, 42 | "label2id": { 43 | "LABEL_0": 0, 44 | "LABEL_1": 1 45 | }, 46 | "layer_norm_eps": 1e-05, 47 | "length_penalty": 1.0, 48 | "max_length": 20, 49 | "max_position_embeddings": 77, 50 | "min_length": 0, 51 | "model_type": "clip_text_model", 52 | "no_repeat_ngram_size": 0, 53 | "num_attention_heads": 20, 54 | "num_beam_groups": 1, 55 | "num_beams": 1, 56 | "num_hidden_layers": 32, 57 | "num_return_sequences": 1, 58 | "output_attentions": false, 59 | "output_hidden_states": false, 60 | "output_scores": false, 61 | "pad_token_id": 1, 62 | "prefix": null, 63 | "problem_type": null, 64 | "pruned_heads": {}, 65 | "remove_invalid_values": false, 66 | "repetition_penalty": 1.0, 67 | "return_dict": true, 68 | "return_dict_in_generate": false, 69 | "sep_token_id": null, 70 | "suppress_tokens": null, 71 | "task_specific_params": null, 72 | "temperature": 1.0, 73 | "tf_legacy_loss": false, 74 | "tie_encoder_decoder": false, 75 | "tie_word_embeddings": true, 76 | "tokenizer_class": null, 77 | "top_k": 50, 78 | "top_p": 1.0, 79 | "torch_dtype": null, 80 | "torchscript": false, 81 | "transformers_version": "4.24.0", 82 | "typical_p": 1.0, 83 | "use_bfloat16": false, 84 | "vocab_size": 49408 85 | }, 86 | "text_config_dict": { 87 | "hidden_act": "gelu", 88 | 
"hidden_size": 1280, 89 | "intermediate_size": 5120, 90 | "num_attention_heads": 20, 91 | "num_hidden_layers": 32 92 | }, 93 | "torch_dtype": "float32", 94 | "transformers_version": null, 95 | "vision_config": { 96 | "_name_or_path": "", 97 | "add_cross_attention": false, 98 | "architectures": null, 99 | "attention_dropout": 0.0, 100 | "bad_words_ids": null, 101 | "begin_suppress_tokens": null, 102 | "bos_token_id": null, 103 | "chunk_size_feed_forward": 0, 104 | "cross_attention_hidden_size": null, 105 | "decoder_start_token_id": null, 106 | "diversity_penalty": 0.0, 107 | "do_sample": false, 108 | "dropout": 0.0, 109 | "early_stopping": false, 110 | "encoder_no_repeat_ngram_size": 0, 111 | "eos_token_id": null, 112 | "exponential_decay_length_penalty": null, 113 | "finetuning_task": null, 114 | "forced_bos_token_id": null, 115 | "forced_eos_token_id": null, 116 | "hidden_act": "gelu", 117 | "hidden_size": 1664, 118 | "id2label": { 119 | "0": "LABEL_0", 120 | "1": "LABEL_1" 121 | }, 122 | "image_size": 224, 123 | "initializer_factor": 1.0, 124 | "initializer_range": 0.02, 125 | "intermediate_size": 8192, 126 | "is_decoder": false, 127 | "is_encoder_decoder": false, 128 | "label2id": { 129 | "LABEL_0": 0, 130 | "LABEL_1": 1 131 | }, 132 | "layer_norm_eps": 1e-05, 133 | "length_penalty": 1.0, 134 | "max_length": 20, 135 | "min_length": 0, 136 | "model_type": "clip_vision_model", 137 | "no_repeat_ngram_size": 0, 138 | "num_attention_heads": 16, 139 | "num_beam_groups": 1, 140 | "num_beams": 1, 141 | "num_channels": 3, 142 | "num_hidden_layers": 48, 143 | "num_return_sequences": 1, 144 | "output_attentions": false, 145 | "output_hidden_states": false, 146 | "output_scores": false, 147 | "pad_token_id": null, 148 | "patch_size": 14, 149 | "prefix": null, 150 | "problem_type": null, 151 | "pruned_heads": {}, 152 | "remove_invalid_values": false, 153 | "repetition_penalty": 1.0, 154 | "return_dict": true, 155 | "return_dict_in_generate": false, 156 | "sep_token_id": 
null, 157 | "suppress_tokens": null, 158 | "task_specific_params": null, 159 | "temperature": 1.0, 160 | "tf_legacy_loss": false, 161 | "tie_encoder_decoder": false, 162 | "tie_word_embeddings": true, 163 | "tokenizer_class": null, 164 | "top_k": 50, 165 | "top_p": 1.0, 166 | "torch_dtype": null, 167 | "torchscript": false, 168 | "transformers_version": "4.24.0", 169 | "typical_p": 1.0, 170 | "use_bfloat16": false 171 | }, 172 | "vision_config_dict": { 173 | "hidden_act": "gelu", 174 | "hidden_size": 1664, 175 | "intermediate_size": 8192, 176 | "num_attention_heads": 16, 177 | "num_hidden_layers": 48, 178 | "patch_size": 14 179 | } 180 | } 181 | -------------------------------------------------------------------------------- /models/stable-diffusion-xl/models--laion--CLIP-ViT-bigG-14-laion2B-39B-b160k/snapshots/8c7a3583335de4dba1b07182dbf81c75137ce67b/open_clip_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "model_cfg": { 3 | "embed_dim": 1280, 4 | "vision_cfg": { 5 | "image_size": 224, 6 | "layers": 48, 7 | "width": 1664, 8 | "head_width": 104, 9 | "mlp_ratio": 4.9231, 10 | "patch_size": 14 11 | }, 12 | "text_cfg": { 13 | "context_length": 77, 14 | "vocab_size": 49408, 15 | "width": 1280, 16 | "heads": 20, 17 | "layers": 32 18 | } 19 | }, 20 | "preprocess_cfg": { 21 | "mean": [ 22 | 0.48145466, 23 | 0.4578275, 24 | 0.40821073 25 | ], 26 | "std": [ 27 | 0.26862954, 28 | 0.26130258, 29 | 0.27577711 30 | ] 31 | } 32 | } -------------------------------------------------------------------------------- /models/stable-diffusion-xl/models--laion--CLIP-ViT-bigG-14-laion2B-39B-b160k/snapshots/8c7a3583335de4dba1b07182dbf81c75137ce67b/preprocessor_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "crop_size": 224, 3 | "do_center_crop": true, 4 | "do_normalize": true, 5 | "do_resize": true, 6 | "feature_extractor_type": "CLIPFeatureExtractor", 7 | "image_mean": [ 8 | 
0.48145466, 9 | 0.4578275, 10 | 0.40821073 11 | ], 12 | "image_std": [ 13 | 0.26862954, 14 | 0.26130258, 15 | 0.27577711 16 | ], 17 | "resample": 3, 18 | "size": 224 19 | } 20 | -------------------------------------------------------------------------------- /models/stable-diffusion-xl/models--laion--CLIP-ViT-bigG-14-laion2B-39B-b160k/snapshots/8c7a3583335de4dba1b07182dbf81c75137ce67b/special_tokens_map.json: -------------------------------------------------------------------------------- 1 | {"bos_token": {"content": "<|startoftext|>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true}, "eos_token": {"content": "<|endoftext|>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true}, "unk_token": {"content": "<|endoftext|>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true}, "pad_token": "<|endoftext|>"} -------------------------------------------------------------------------------- /models/stable-diffusion-xl/models--laion--CLIP-ViT-bigG-14-laion2B-39B-b160k/snapshots/8c7a3583335de4dba1b07182dbf81c75137ce67b/tokenizer_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "unk_token": { 3 | "content": "<|endoftext|>", 4 | "single_word": false, 5 | "lstrip": false, 6 | "rstrip": false, 7 | "normalized": true, 8 | "__type": "AddedToken" 9 | }, 10 | "bos_token": { 11 | "content": "<|startoftext|>", 12 | "single_word": false, 13 | "lstrip": false, 14 | "rstrip": false, 15 | "normalized": true, 16 | "__type": "AddedToken" 17 | }, 18 | "eos_token": { 19 | "content": "<|endoftext|>", 20 | "single_word": false, 21 | "lstrip": false, 22 | "rstrip": false, 23 | "normalized": true, 24 | "__type": "AddedToken" 25 | }, 26 | "pad_token": "<|endoftext|>", 27 | "add_prefix_space": false, 28 | "errors": "replace", 29 | "do_lower_case": true, 30 | "name_or_path": "openai/clip-vit-base-patch32", 31 | "model_max_length": 77, 32 | "special_tokens_map_file": 
"./special_tokens_map.json", 33 | "tokenizer_class": "CLIPTokenizer" 34 | } -------------------------------------------------------------------------------- /models/stable-diffusion-xl/models--openai--clip-vit-large-patch14/refs/main: -------------------------------------------------------------------------------- 1 | 32bd64288804d66eefd0ccbe215aa642df71cc41 -------------------------------------------------------------------------------- /models/stable-diffusion-xl/models--openai--clip-vit-large-patch14/snapshots/32bd64288804d66eefd0ccbe215aa642df71cc41/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "_name_or_path": "clip-vit-large-patch14/", 3 | "architectures": [ 4 | "CLIPModel" 5 | ], 6 | "initializer_factor": 1.0, 7 | "logit_scale_init_value": 2.6592, 8 | "model_type": "clip", 9 | "projection_dim": 768, 10 | "text_config": { 11 | "_name_or_path": "", 12 | "add_cross_attention": false, 13 | "architectures": null, 14 | "attention_dropout": 0.0, 15 | "bad_words_ids": null, 16 | "bos_token_id": 0, 17 | "chunk_size_feed_forward": 0, 18 | "cross_attention_hidden_size": null, 19 | "decoder_start_token_id": null, 20 | "diversity_penalty": 0.0, 21 | "do_sample": false, 22 | "dropout": 0.0, 23 | "early_stopping": false, 24 | "encoder_no_repeat_ngram_size": 0, 25 | "eos_token_id": 2, 26 | "finetuning_task": null, 27 | "forced_bos_token_id": null, 28 | "forced_eos_token_id": null, 29 | "hidden_act": "quick_gelu", 30 | "hidden_size": 768, 31 | "id2label": { 32 | "0": "LABEL_0", 33 | "1": "LABEL_1" 34 | }, 35 | "initializer_factor": 1.0, 36 | "initializer_range": 0.02, 37 | "intermediate_size": 3072, 38 | "is_decoder": false, 39 | "is_encoder_decoder": false, 40 | "label2id": { 41 | "LABEL_0": 0, 42 | "LABEL_1": 1 43 | }, 44 | "layer_norm_eps": 1e-05, 45 | "length_penalty": 1.0, 46 | "max_length": 20, 47 | "max_position_embeddings": 77, 48 | "min_length": 0, 49 | "model_type": "clip_text_model", 50 | 
"no_repeat_ngram_size": 0, 51 | "num_attention_heads": 12, 52 | "num_beam_groups": 1, 53 | "num_beams": 1, 54 | "num_hidden_layers": 12, 55 | "num_return_sequences": 1, 56 | "output_attentions": false, 57 | "output_hidden_states": false, 58 | "output_scores": false, 59 | "pad_token_id": 1, 60 | "prefix": null, 61 | "problem_type": null, 62 | "projection_dim" : 768, 63 | "pruned_heads": {}, 64 | "remove_invalid_values": false, 65 | "repetition_penalty": 1.0, 66 | "return_dict": true, 67 | "return_dict_in_generate": false, 68 | "sep_token_id": null, 69 | "task_specific_params": null, 70 | "temperature": 1.0, 71 | "tie_encoder_decoder": false, 72 | "tie_word_embeddings": true, 73 | "tokenizer_class": null, 74 | "top_k": 50, 75 | "top_p": 1.0, 76 | "torch_dtype": null, 77 | "torchscript": false, 78 | "transformers_version": "4.16.0.dev0", 79 | "use_bfloat16": false, 80 | "vocab_size": 49408 81 | }, 82 | "text_config_dict": { 83 | "hidden_size": 768, 84 | "intermediate_size": 3072, 85 | "num_attention_heads": 12, 86 | "num_hidden_layers": 12, 87 | "projection_dim": 768 88 | }, 89 | "torch_dtype": "float32", 90 | "transformers_version": null, 91 | "vision_config": { 92 | "_name_or_path": "", 93 | "add_cross_attention": false, 94 | "architectures": null, 95 | "attention_dropout": 0.0, 96 | "bad_words_ids": null, 97 | "bos_token_id": null, 98 | "chunk_size_feed_forward": 0, 99 | "cross_attention_hidden_size": null, 100 | "decoder_start_token_id": null, 101 | "diversity_penalty": 0.0, 102 | "do_sample": false, 103 | "dropout": 0.0, 104 | "early_stopping": false, 105 | "encoder_no_repeat_ngram_size": 0, 106 | "eos_token_id": null, 107 | "finetuning_task": null, 108 | "forced_bos_token_id": null, 109 | "forced_eos_token_id": null, 110 | "hidden_act": "quick_gelu", 111 | "hidden_size": 1024, 112 | "id2label": { 113 | "0": "LABEL_0", 114 | "1": "LABEL_1" 115 | }, 116 | "image_size": 224, 117 | "initializer_factor": 1.0, 118 | "initializer_range": 0.02, 119 | 
"intermediate_size": 4096, 120 | "is_decoder": false, 121 | "is_encoder_decoder": false, 122 | "label2id": { 123 | "LABEL_0": 0, 124 | "LABEL_1": 1 125 | }, 126 | "layer_norm_eps": 1e-05, 127 | "length_penalty": 1.0, 128 | "max_length": 20, 129 | "min_length": 0, 130 | "model_type": "clip_vision_model", 131 | "no_repeat_ngram_size": 0, 132 | "num_attention_heads": 16, 133 | "num_beam_groups": 1, 134 | "num_beams": 1, 135 | "num_hidden_layers": 24, 136 | "num_return_sequences": 1, 137 | "output_attentions": false, 138 | "output_hidden_states": false, 139 | "output_scores": false, 140 | "pad_token_id": null, 141 | "patch_size": 14, 142 | "prefix": null, 143 | "problem_type": null, 144 | "projection_dim" : 768, 145 | "pruned_heads": {}, 146 | "remove_invalid_values": false, 147 | "repetition_penalty": 1.0, 148 | "return_dict": true, 149 | "return_dict_in_generate": false, 150 | "sep_token_id": null, 151 | "task_specific_params": null, 152 | "temperature": 1.0, 153 | "tie_encoder_decoder": false, 154 | "tie_word_embeddings": true, 155 | "tokenizer_class": null, 156 | "top_k": 50, 157 | "top_p": 1.0, 158 | "torch_dtype": null, 159 | "torchscript": false, 160 | "transformers_version": "4.16.0.dev0", 161 | "use_bfloat16": false 162 | }, 163 | "vision_config_dict": { 164 | "hidden_size": 1024, 165 | "intermediate_size": 4096, 166 | "num_attention_heads": 16, 167 | "num_hidden_layers": 24, 168 | "patch_size": 14, 169 | "projection_dim": 768 170 | } 171 | } 172 | -------------------------------------------------------------------------------- /models/stable-diffusion-xl/models--openai--clip-vit-large-patch14/snapshots/32bd64288804d66eefd0ccbe215aa642df71cc41/preprocessor_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "crop_size": 224, 3 | "do_center_crop": true, 4 | "do_normalize": true, 5 | "do_resize": true, 6 | "feature_extractor_type": "CLIPFeatureExtractor", 7 | "image_mean": [ 8 | 0.48145466, 9 | 0.4578275, 10 | 
0.40821073 11 | ], 12 | "image_std": [ 13 | 0.26862954, 14 | 0.26130258, 15 | 0.27577711 16 | ], 17 | "resample": 3, 18 | "size": 224 19 | } 20 | -------------------------------------------------------------------------------- /models/stable-diffusion-xl/models--openai--clip-vit-large-patch14/snapshots/32bd64288804d66eefd0ccbe215aa642df71cc41/special_tokens_map.json: -------------------------------------------------------------------------------- 1 | {"bos_token": {"content": "<|startoftext|>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true}, "eos_token": {"content": "<|endoftext|>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true}, "unk_token": {"content": "<|endoftext|>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true}, "pad_token": "<|endoftext|>"} -------------------------------------------------------------------------------- /models/stable-diffusion-xl/models--openai--clip-vit-large-patch14/snapshots/32bd64288804d66eefd0ccbe215aa642df71cc41/tokenizer_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "unk_token": { 3 | "content": "<|endoftext|>", 4 | "single_word": false, 5 | "lstrip": false, 6 | "rstrip": false, 7 | "normalized": true, 8 | "__type": "AddedToken" 9 | }, 10 | "bos_token": { 11 | "content": "<|startoftext|>", 12 | "single_word": false, 13 | "lstrip": false, 14 | "rstrip": false, 15 | "normalized": true, 16 | "__type": "AddedToken" 17 | }, 18 | "eos_token": { 19 | "content": "<|endoftext|>", 20 | "single_word": false, 21 | "lstrip": false, 22 | "rstrip": false, 23 | "normalized": true, 24 | "__type": "AddedToken" 25 | }, 26 | "pad_token": "<|endoftext|>", 27 | "add_prefix_space": false, 28 | "errors": "replace", 29 | "do_lower_case": true, 30 | "name_or_path": "openai/clip-vit-base-patch32", 31 | "model_max_length": 77, 32 | "special_tokens_map_file": "./special_tokens_map.json", 33 | "tokenizer_class": 
"CLIPTokenizer" 34 | } 35 | -------------------------------------------------------------------------------- /models/stable-diffusion-xl/stabilityai_stable_diffusion_xl_base_1.0/scheduler/scheduler_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "_class_name": "EulerDiscreteScheduler", 3 | "_diffusers_version": "0.19.0.dev0", 4 | "beta_end": 0.012, 5 | "beta_schedule": "scaled_linear", 6 | "beta_start": 0.00085, 7 | "clip_sample": false, 8 | "interpolation_type": "linear", 9 | "num_train_timesteps": 1000, 10 | "prediction_type": "epsilon", 11 | "sample_max_value": 1.0, 12 | "set_alpha_to_one": false, 13 | "skip_prk_steps": true, 14 | "steps_offset": 1, 15 | "timestep_spacing": "leading", 16 | "trained_betas": null, 17 | "use_karras_sigmas": false 18 | } -------------------------------------------------------------------------------- /models/training_templates/1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aigc-apps/sd-webui-EasyPhoto/6c1f8959c4024b0f8a1187fae7c50c0c8f86a8df/models/training_templates/1.jpg -------------------------------------------------------------------------------- /models/training_templates/2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aigc-apps/sd-webui-EasyPhoto/6c1f8959c4024b0f8a1187fae7c50c0c8f86a8df/models/training_templates/2.jpg -------------------------------------------------------------------------------- /models/training_templates/3.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aigc-apps/sd-webui-EasyPhoto/6c1f8959c4024b0f8a1187fae7c50c0c8f86a8df/models/training_templates/3.jpg -------------------------------------------------------------------------------- /models/training_templates/4.jpg: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/aigc-apps/sd-webui-EasyPhoto/6c1f8959c4024b0f8a1187fae7c50c0c8f86a8df/models/training_templates/4.jpg -------------------------------------------------------------------------------- /scripts/easyphoto_config.py: -------------------------------------------------------------------------------- 1 | import os 2 | from modules.paths import data_path, models_path, extensions_builtin_dir, extensions_dir 3 | 4 | # save_dirs 5 | data_dir = data_path 6 | models_path = models_path 7 | extensions_builtin_dir = extensions_builtin_dir 8 | extensions_dir = extensions_dir 9 | easyphoto_models_path = os.path.abspath(os.path.dirname(__file__)).replace("scripts", "models") 10 | easyphoto_img2img_samples = os.path.join(data_dir, "outputs/img2img-images") 11 | easyphoto_txt2img_samples = os.path.join(data_dir, "outputs/txt2img-images") 12 | easyphoto_outpath_samples = os.path.join(data_dir, "outputs/easyphoto-outputs") 13 | easyphoto_video_outpath_samples = os.path.join(data_dir, "outputs/easyphoto-video-outputs") 14 | user_id_outpath_samples = os.path.join(data_dir, "outputs/easyphoto-user-id-infos") 15 | cloth_id_outpath_samples = os.path.join(data_dir, "outputs/easyphoto-cloth-id-infos") 16 | scene_id_outpath_samples = os.path.join(data_dir, "outputs/easyphoto-scene-id-infos") 17 | cache_log_file_path = os.path.join(data_dir, "outputs/easyphoto-tmp/train_kohya_log.txt") 18 | 19 | # gallery_dir 20 | tryon_preview_dir = os.path.join(os.path.abspath(os.path.dirname(__file__)).replace("scripts", "images"), "tryon") 21 | tryon_gallery_dir = os.path.join(cloth_id_outpath_samples, "gallery") 22 | 23 | # prompts 24 | validation_prompt = "easyphoto_face, easyphoto, 1person" 25 | validation_prompt_scene = "special_scene, scene" 26 | validation_tryon_prompt = "easyphoto, 1thing" 27 | DEFAULT_POSITIVE = "(cloth:1.5), (best quality), (realistic, photo-realistic:1.3), (beautiful 
eyes:1.3), (sparkling eyes:1.3), (beautiful mouth:1.3), finely detail, light smile, extremely detailed CG unity 8k wallpaper, huge filesize, best quality, realistic, photo-realistic, ultra high res, raw photo, put on makeup" 28 | DEFAULT_NEGATIVE = "(bags under the eyes:1.5), (bags under eyes:1.5), (earrings:1.3), (glasses:1.2), (naked:1.5), (nsfw:1.5), nude, breasts, penis, cum, (over red lips: 1.3), (bad lips: 1.3), (bad ears:1.3), (bad hair: 1.3), (bad teeth: 1.3), (worst quality:2), (low quality:2), (normal quality:2), lowres, watermark, badhand, lowres, bad anatomy, bad hands, normal quality, mural," 29 | DEFAULT_POSITIVE_AD = "(realistic, photorealistic), (masterpiece, best quality, high quality), (delicate eyes and face), extremely detailed CG unity 8k wallpaper, best quality, realistic, photo-realistic, ultra high res, raw photo" 30 | DEFAULT_NEGATIVE_AD = "(naked:1.2), (nsfw:1.2), nipple slip, nude, breasts, (huge breasts:1.2), penis, cum, (blurry background:1.3), (depth of field:1.7), (holding:2), (worst quality:2), (normal quality:2), lowres, bad anatomy, bad hands" 31 | DEFAULT_POSITIVE_T2I = "(cloth:1.0), (best quality), (realistic, photo-realistic:1.3), film photography, minor acne, (portrait:1.1), (indirect lighting), extremely detailed CG unity 8k wallpaper, huge filesize, best quality, realistic, photo-realistic, ultra high res, raw photo, put on makeup" 32 | DEFAULT_NEGATIVE_T2I = "(nsfw:1.5), (huge breast:1.5), nude, breasts, penis, cum, bokeh, cgi, illustration, cartoon, deformed, distorted, disfigured, poorly drawn, bad anatomy, wrong anatomy, ugly, deformed, blurry, Noisy, log, text (worst quality:2), (low quality:2), (normal quality:2), lowres, watermark, badhand, lowres" 33 | 34 | # scene lora 35 | DEFAULT_SCENE_LORA = [ 36 | "Christmas_1", 37 | "Cyberpunk_1", 38 | "FairMaidenStyle_1", 39 | "Gentleman_1", 40 | "GuoFeng_1", 41 | "GuoFeng_2", 42 | "GuoFeng_3", 43 | "GuoFeng_4", 44 | "Minimalism_1", 45 | "NaturalWind_1", 46 | "Princess_1", 47 | 
"Princess_2", 48 | "Princess_3", 49 | "SchoolUniform_1", 50 | "SchoolUniform_2", 51 | ] 52 | 53 | # tryon template 54 | DEFAULT_TRYON_TEMPLATE = ["boy", "girl", "dress", "short"] 55 | 56 | # cloth lora 57 | DEFAULT_CLOTH_LORA = ["demo_black_200", "demo_white_200", "demo_purple_200", "demo_dress_200", "demo_short_200"] 58 | 59 | # sliders 60 | DEFAULT_SLIDERS = ["age_sd1_sliders", "smiling_sd1_sliders", "age_sdxl_sliders", "smiling_sdxl_sliders"] 61 | 62 | # ModelName 63 | SDXL_MODEL_NAME = "SDXL_1.0_ArienMixXL_v2.0.safetensors" 64 | -------------------------------------------------------------------------------- /scripts/easyphoto_utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .face_process_utils import ( 2 | Face_Skin, 3 | alignment_photo, 4 | call_face_crop, 5 | call_face_crop_templates, 6 | color_transfer, 7 | crop_and_paste, 8 | safe_get_box_mask_keypoints_and_padding_image, 9 | ) 10 | from .fire_utils import FIRE_forward 11 | from .psgan_utils import PSGAN_Inference 12 | from .tryon_utils import ( 13 | align_and_overlay_images, 14 | apply_mask_to_image, 15 | compute_rotation_angle, 16 | copy_white_mask_to_template, 17 | crop_image, 18 | expand_box_by_pad, 19 | expand_roi, 20 | find_best_angle_ratio, 21 | get_background_color, 22 | mask_to_box, 23 | mask_to_polygon, 24 | merge_with_inner_canny, 25 | prepare_tryon_train_data, 26 | resize_and_stretch, 27 | resize_image_with_pad, 28 | seg_by_box, 29 | find_connected_components, 30 | ) 31 | 32 | from .common_utils import ( 33 | check_files_exists_and_download, 34 | check_id_valid, 35 | check_scene_valid, 36 | convert_to_video, 37 | ep_logger, 38 | get_controlnet_version, 39 | get_mov_all_images, 40 | modelscope_models_to_cpu, 41 | modelscope_models_to_gpu, 42 | switch_ms_model_cpu, 43 | unload_models, 44 | seed_everything, 45 | get_attribute_edit_ids, 46 | encode_video_to_base64, 47 | decode_base64_to_video, 48 | cleanup_decorator, 49 | 
auto_to_gpu_model, 50 | ) 51 | from .loractl_utils import check_loractl_conflict, LoraCtlScript 52 | from .animatediff_utils import ( 53 | AnimateDiffControl, 54 | AnimateDiffI2VLatent, 55 | AnimateDiffInfV2V, 56 | AnimateDiffLora, 57 | AnimateDiffMM, 58 | AnimateDiffOutput, 59 | AnimateDiffProcess, 60 | AnimateDiffPromptSchedule, 61 | AnimateDiffUiGroup, 62 | animatediff_i2ibatch, 63 | motion_module, 64 | update_infotext, 65 | video_visible, 66 | ) 67 | -------------------------------------------------------------------------------- /scripts/easyphoto_utils/animatediff/README.MD: -------------------------------------------------------------------------------- 1 | ## 如何修改animatediff 2 | 1. 将根目录下的motion_module.py拷贝进scripts文件夹。 3 | 2. 将scripts下的文件,其中每个py文件里面的绝对导入改为相对路径,例如: 4 | ``` 5 | scripts.animatediff_logger => .animatediff_logger 6 | ``` 7 | 即可。 8 | 9 | ## hack的代码 10 | hack代码不多,主要在animatediff_utils.py中, 11 | 12 | 1. 重写AnimateDiffControl,为了可以获取批处理的图片,从视频中。 13 | 2. 重写AnimateDiffMM,为了加载模型。 14 | 3. 重写AnimateDiffProcess与AnimateDiffI2VLatent,为了image2video的合理的特征保留。 15 | 4. 
重写AnimateDiffScript,为了easyphoto更简单的调用。 -------------------------------------------------------------------------------- /scripts/easyphoto_utils/animatediff/animatediff_infotext.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from modules.paths import data_path 4 | from modules.processing import StableDiffusionProcessing, StableDiffusionProcessingImg2Img 5 | 6 | from .animatediff_ui import AnimateDiffProcess 7 | from .animatediff_logger import logger_animatediff as logger 8 | 9 | 10 | def update_infotext(p: StableDiffusionProcessing, params: AnimateDiffProcess): 11 | if p.extra_generation_params is not None: 12 | p.extra_generation_params["AnimateDiff"] = params.get_dict(isinstance(p, StableDiffusionProcessingImg2Img)) 13 | 14 | 15 | def write_params_txt(info: str): 16 | with open(os.path.join(data_path, "params.txt"), "w", encoding="utf8") as file: 17 | file.write(info) 18 | 19 | 20 | 21 | def infotext_pasted(infotext, results): 22 | for k, v in results.items(): 23 | if not k.startswith("AnimateDiff"): 24 | continue 25 | 26 | assert isinstance(v, str), f"Expect string but got {v}." 
27 | try: 28 | for items in v.split(', '): 29 | field, value = items.split(': ') 30 | results[f"AnimateDiff {field}"] = value 31 | except Exception: 32 | logger.warn( 33 | f"Failed to parse infotext, legacy format infotext is no longer supported:\n{v}" 34 | ) 35 | break 36 | -------------------------------------------------------------------------------- /scripts/easyphoto_utils/animatediff/animatediff_latent.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from modules import images, shared 4 | from modules.devices import device, dtype_vae, torch_gc 5 | from modules.processing import StableDiffusionProcessingImg2Img 6 | from modules.sd_samplers_common import (approximation_indexes, 7 | images_tensor_to_samples) 8 | 9 | from .animatediff_logger import logger_animatediff as logger 10 | from .animatediff_ui import AnimateDiffProcess 11 | 12 | 13 | class AnimateDiffI2VLatent: 14 | def randomize( 15 | self, p: StableDiffusionProcessingImg2Img, params: AnimateDiffProcess 16 | ): 17 | # Get init_alpha 18 | reserve_scale = [ 19 | 0.75 for i in range(params.video_length) 20 | ] 21 | logger.info(f"Randomizing reserve_scale according to {reserve_scale}.") 22 | reserve_scale = torch.tensor(reserve_scale, dtype=torch.float32, device=device)[ 23 | :, None, None, None 24 | ] 25 | reserve_scale[reserve_scale < 0] = 0 26 | 27 | if params.last_frame is not None: 28 | # Get init_alpha 29 | init_alpha = [ 30 | 1 - pow(i, params.latent_power) / params.latent_scale 31 | for i in range(params.video_length) 32 | ] 33 | logger.info(f"Randomizing init_latent according to {init_alpha}.") 34 | init_alpha = torch.tensor(init_alpha, dtype=torch.float32, device=device)[ 35 | :, None, None, None 36 | ] 37 | init_alpha[init_alpha < 0] = 0 38 | 39 | last_frame = params.last_frame 40 | if type(last_frame) == str: 41 | from modules.api.api import decode_base64_to_image 42 | last_frame = decode_base64_to_image(last_frame) 
43 | # Get last_alpha 44 | last_alpha = [ 45 | 1 - pow(i, params.latent_power_last) / params.latent_scale_last 46 | for i in range(params.video_length) 47 | ] 48 | last_alpha.reverse() 49 | logger.info(f"Randomizing last_latent according to {last_alpha}.") 50 | last_alpha = torch.tensor(last_alpha, dtype=torch.float32, device=device)[ 51 | :, None, None, None 52 | ] 53 | last_alpha[last_alpha < 0] = 0 54 | 55 | # Normalize alpha 56 | sum_alpha = init_alpha + last_alpha 57 | mask_alpha = sum_alpha > 1 58 | scaling_factor = 1 / sum_alpha[mask_alpha] 59 | init_alpha[mask_alpha] *= scaling_factor 60 | last_alpha[mask_alpha] *= scaling_factor 61 | init_alpha[0] = 1 62 | init_alpha[-1] = 0 63 | last_alpha[0] = 0 64 | last_alpha[-1] = 1 65 | 66 | # Calculate last_latent 67 | if p.resize_mode != 3: 68 | last_frame = images.resize_image( 69 | p.resize_mode, last_frame, p.width, p.height 70 | ) 71 | last_frame = np.array(last_frame).astype(np.float32) / 255.0 72 | last_frame = np.moveaxis(last_frame, 2, 0)[None, ...] 
73 | last_frame = torch.from_numpy(last_frame).to(device).to(dtype_vae) 74 | last_latent = images_tensor_to_samples( 75 | last_frame, 76 | approximation_indexes.get(shared.opts.sd_vae_encode_method), 77 | p.sd_model, 78 | ) 79 | torch_gc() 80 | if p.resize_mode == 3: 81 | opt_f = 8 82 | last_latent = torch.nn.functional.interpolate( 83 | last_latent, 84 | size=(p.height // opt_f, p.width // opt_f), 85 | mode="bilinear", 86 | ) 87 | # Modify init_latent 88 | p.init_latent = ( 89 | (p.init_latent * init_alpha 90 | + last_latent * last_alpha 91 | + p.rng.next() * (1 - init_alpha - last_alpha)) * reserve_scale 92 | + p.rng.next() * (1 - reserve_scale) 93 | ) 94 | else: 95 | p.init_latent = p.init_latent * reserve_scale + p.rng.next() * (1 - reserve_scale) 96 | -------------------------------------------------------------------------------- /scripts/easyphoto_utils/animatediff/animatediff_lcm.py: -------------------------------------------------------------------------------- 1 | 2 | # TODO: remove this file when LCM is merged to A1111 3 | import torch 4 | 5 | from k_diffusion import utils, sampling 6 | from k_diffusion.external import DiscreteEpsDDPMDenoiser 7 | from k_diffusion.sampling import default_noise_sampler, trange 8 | 9 | from modules import shared, sd_samplers_cfg_denoiser, sd_samplers_kdiffusion 10 | from .animatediff_logger import logger_animatediff as logger 11 | 12 | 13 | class LCMCompVisDenoiser(DiscreteEpsDDPMDenoiser): 14 | def __init__(self, model): 15 | timesteps = 1000 16 | beta_start = 0.00085 17 | beta_end = 0.012 18 | 19 | betas = torch.linspace(beta_start**0.5, beta_end**0.5, timesteps, dtype=torch.float32) ** 2 20 | alphas = 1.0 - betas 21 | alphas_cumprod = torch.cumprod(alphas, dim=0) 22 | 23 | original_timesteps = 50 # LCM Original Timesteps (default=50, for current version of LCM) 24 | self.skip_steps = timesteps // original_timesteps 25 | 26 | 27 | alphas_cumprod_valid = torch.zeros((original_timesteps), dtype=torch.float32, 
device=model.device) 28 | for x in range(original_timesteps): 29 | alphas_cumprod_valid[original_timesteps - 1 - x] = alphas_cumprod[timesteps - 1 - x * self.skip_steps] 30 | 31 | super().__init__(model, alphas_cumprod_valid, quantize=None) 32 | 33 | 34 | def get_sigmas(self, n=None, sgm=False): 35 | if n is None: 36 | return sampling.append_zero(self.sigmas.flip(0)) 37 | 38 | start = self.sigma_to_t(self.sigma_max) 39 | end = self.sigma_to_t(self.sigma_min) 40 | 41 | if sgm: 42 | t = torch.linspace(start, end, n + 1, device=shared.sd_model.device)[:-1] 43 | else: 44 | t = torch.linspace(start, end, n, device=shared.sd_model.device) 45 | 46 | return sampling.append_zero(self.t_to_sigma(t)) 47 | 48 | 49 | def sigma_to_t(self, sigma, quantize=None): 50 | log_sigma = sigma.log() 51 | dists = log_sigma - self.log_sigmas[:, None] 52 | return dists.abs().argmin(dim=0).view(sigma.shape) * self.skip_steps + (self.skip_steps - 1) 53 | 54 | 55 | def t_to_sigma(self, timestep): 56 | t = torch.clamp(((timestep - (self.skip_steps - 1)) / self.skip_steps).float(), min=0, max=(len(self.sigmas) - 1)) 57 | return super().t_to_sigma(t) 58 | 59 | 60 | def get_eps(self, *args, **kwargs): 61 | return self.inner_model.apply_model(*args, **kwargs) 62 | 63 | 64 | def get_scaled_out(self, sigma, output, input): 65 | sigma_data = 0.5 66 | scaled_timestep = utils.append_dims(self.sigma_to_t(sigma), output.ndim) * 10.0 67 | 68 | c_skip = sigma_data**2 / (scaled_timestep**2 + sigma_data**2) 69 | c_out = scaled_timestep / (scaled_timestep**2 + sigma_data**2) ** 0.5 70 | 71 | return c_out * output + c_skip * input 72 | 73 | 74 | def forward(self, input, sigma, **kwargs): 75 | c_out, c_in = [utils.append_dims(x, input.ndim) for x in self.get_scalings(sigma)] 76 | eps = self.get_eps(input * c_in, self.sigma_to_t(sigma), **kwargs) 77 | return self.get_scaled_out(sigma, input + eps * c_out, input) 78 | 79 | 80 | def sample_lcm(model, x, sigmas, extra_args=None, callback=None, disable=None, 
noise_sampler=None): 81 | extra_args = {} if extra_args is None else extra_args 82 | noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler 83 | s_in = x.new_ones([x.shape[0]]) 84 | 85 | for i in trange(len(sigmas) - 1, disable=disable): 86 | denoised = model(x, sigmas[i] * s_in, **extra_args) 87 | 88 | if callback is not None: 89 | callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised}) 90 | 91 | x = denoised 92 | if sigmas[i + 1] > 0: 93 | x += sigmas[i + 1] * noise_sampler(sigmas[i], sigmas[i + 1]) 94 | return x 95 | 96 | 97 | class CFGDenoiserLCM(sd_samplers_cfg_denoiser.CFGDenoiser): 98 | @property 99 | def inner_model(self): 100 | if self.model_wrap is None: 101 | denoiser = LCMCompVisDenoiser 102 | self.model_wrap = denoiser(shared.sd_model) 103 | 104 | return self.model_wrap 105 | 106 | 107 | class LCMSampler(sd_samplers_kdiffusion.KDiffusionSampler): 108 | def __init__(self, funcname, sd_model, options=None): 109 | super().__init__(funcname, sd_model, options) 110 | self.model_wrap_cfg = CFGDenoiserLCM(self) 111 | self.model_wrap = self.model_wrap_cfg.inner_model 112 | 113 | 114 | class AnimateDiffLCM: 115 | lcm_ui_injected = False 116 | 117 | 118 | @staticmethod 119 | def hack_kdiff_ui(): 120 | if AnimateDiffLCM.lcm_ui_injected: 121 | logger.info(f"LCM UI already injected.") 122 | return 123 | 124 | logger.info(f"Injecting LCM to UI.") 125 | from modules import sd_samplers, sd_samplers_common 126 | samplers_lcm = [('LCM', sample_lcm, ['k_lcm'], {})] 127 | samplers_data_lcm = [ 128 | sd_samplers_common.SamplerData(label, lambda model, funcname=funcname: LCMSampler(funcname, model), aliases, options) 129 | for label, funcname, aliases, options in samplers_lcm 130 | ] 131 | sd_samplers.all_samplers.extend(samplers_data_lcm) 132 | sd_samplers.all_samplers_map = {x.name: x for x in sd_samplers.all_samplers} 133 | sd_samplers.set_samplers() 134 | AnimateDiffLCM.lcm_ui_injected = True 135 | 
-------------------------------------------------------------------------------- /scripts/easyphoto_utils/animatediff/animatediff_logger.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import logging 3 | import sys 4 | 5 | from modules import shared 6 | 7 | 8 | class ColoredFormatter(logging.Formatter): 9 | COLORS = { 10 | "DEBUG": "\033[0;36m", # CYAN 11 | "INFO": "\033[0;32m", # GREEN 12 | "WARNING": "\033[0;33m", # YELLOW 13 | "ERROR": "\033[0;31m", # RED 14 | "CRITICAL": "\033[0;37;41m", # WHITE ON RED 15 | "RESET": "\033[0m", # RESET COLOR 16 | } 17 | 18 | def format(self, record): 19 | colored_record = copy.copy(record) 20 | levelname = colored_record.levelname 21 | seq = self.COLORS.get(levelname, self.COLORS["RESET"]) 22 | colored_record.levelname = f"{seq}{levelname}{self.COLORS['RESET']}" 23 | return super().format(colored_record) 24 | 25 | 26 | # Create a new logger 27 | logger_animatediff = logging.getLogger("AnimateDiff") 28 | logger_animatediff.propagate = False 29 | 30 | # Add handler if we don't have one. 
31 | if not logger_animatediff.handlers: 32 | handler = logging.StreamHandler(sys.stdout) 33 | handler.setFormatter( 34 | ColoredFormatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s") 35 | ) 36 | logger_animatediff.addHandler(handler) 37 | 38 | # Configure logger 39 | loglevel_string = getattr(shared.cmd_opts, "animatediff_loglevel", "INFO") 40 | loglevel = getattr(logging, loglevel_string.upper(), None) 41 | logger_animatediff.setLevel(loglevel) 42 | -------------------------------------------------------------------------------- /scripts/easyphoto_utils/animatediff/animatediff_lora.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | import sys 4 | 5 | from modules import sd_models, shared 6 | from modules.paths import extensions_builtin_dir 7 | 8 | from .animatediff_logger import logger_animatediff as logger 9 | 10 | sys.path.append(f"{extensions_builtin_dir}/Lora") 11 | 12 | class AnimateDiffLora: 13 | original_load_network = None 14 | 15 | def __init__(self, v2: bool): 16 | self.v2 = v2 17 | 18 | def hack(self): 19 | if not self.v2: 20 | return 21 | 22 | if AnimateDiffLora.original_load_network is not None: 23 | logger.info("AnimateDiff LoRA already hacked") 24 | return 25 | 26 | logger.info("Hacking LoRA module to support motion LoRA") 27 | import network 28 | import networks 29 | AnimateDiffLora.original_load_network = networks.load_network 30 | original_load_network = AnimateDiffLora.original_load_network 31 | 32 | def mm_load_network(name, network_on_disk): 33 | 34 | def convert_mm_name_to_compvis(key): 35 | sd_module_key, _, network_part = re.split(r'(_lora\.)', key) 36 | sd_module_key = sd_module_key.replace("processor.", "").replace("to_out", "to_out.0") 37 | return sd_module_key, 'lora_' + network_part 38 | 39 | net = network.Network(name, network_on_disk) 40 | net.mtime = os.path.getmtime(network_on_disk.filename) 41 | 42 | sd = 
sd_models.read_state_dict(network_on_disk.filename) 43 | 44 | if 'motion_modules' in list(sd.keys())[0]: 45 | logger.info(f"Loading motion LoRA {name} from {network_on_disk.filename}") 46 | matched_networks = {} 47 | 48 | for key_network, weight in sd.items(): 49 | key, network_part = convert_mm_name_to_compvis(key_network) 50 | sd_module = shared.sd_model.network_layer_mapping.get(key, None) 51 | 52 | assert sd_module is not None, f"Failed to find sd module for key {key}." 53 | 54 | if key not in matched_networks: 55 | matched_networks[key] = network.NetworkWeights( 56 | network_key=key_network, sd_key=key, w={}, sd_module=sd_module) 57 | 58 | matched_networks[key].w[network_part] = weight 59 | 60 | for key, weights in matched_networks.items(): 61 | net_module = networks.module_types[0].create_module(net, weights) 62 | assert net_module is not None, "Failed to create motion module LoRA" 63 | net.modules[key] = net_module 64 | 65 | return net 66 | else: 67 | del sd 68 | return original_load_network(name, network_on_disk) 69 | 70 | networks.load_network = mm_load_network 71 | 72 | 73 | def restore(self): 74 | if not self.v2: 75 | return 76 | 77 | if AnimateDiffLora.original_load_network is None: 78 | logger.info("AnimateDiff LoRA already restored") 79 | return 80 | 81 | logger.info("Restoring hacked LoRA") 82 | import networks 83 | networks.load_network = AnimateDiffLora.original_load_network 84 | AnimateDiffLora.original_load_network = None 85 | -------------------------------------------------------------------------------- /scripts/easyphoto_utils/animatediff/animatediff_mm.py: -------------------------------------------------------------------------------- 1 | import gc 2 | import os 3 | 4 | import torch 5 | from einops import rearrange 6 | from modules import hashes, shared, sd_models, devices 7 | from modules.devices import cpu, device, torch_gc 8 | 9 | from .motion_module import MotionWrapper, MotionModuleType 10 | from .animatediff_logger import 
logger_animatediff as logger 11 | 12 | 13 | class AnimateDiffMM: 14 | mm_injected = False 15 | 16 | def __init__(self): 17 | self.mm: MotionWrapper = None 18 | self.script_dir = None 19 | self.prev_alpha_cumprod = None 20 | self.gn32_original_forward = None 21 | 22 | 23 | def set_script_dir(self, script_dir): 24 | self.script_dir = script_dir 25 | 26 | 27 | def get_model_dir(self): 28 | model_dir = shared.opts.data.get("animatediff_model_path", os.path.join(self.script_dir, "model")) 29 | if not model_dir: 30 | model_dir = os.path.join(self.script_dir, "model") 31 | return model_dir 32 | 33 | 34 | def _load(self, model_name): 35 | model_path = os.path.join(self.get_model_dir(), model_name) 36 | if not os.path.isfile(model_path): 37 | raise RuntimeError("Please download models manually.") 38 | if self.mm is None or self.mm.mm_name != model_name: 39 | logger.info(f"Loading motion module {model_name} from {model_path}") 40 | model_hash = hashes.sha256(model_path, f"AnimateDiff/{model_name}") 41 | mm_state_dict = sd_models.read_state_dict(model_path) 42 | model_type = MotionModuleType.get_mm_type(mm_state_dict) 43 | logger.info(f"Guessed {model_name} architecture: {model_type}") 44 | self.mm = MotionWrapper(model_name, model_hash, model_type) 45 | missed_keys = self.mm.load_state_dict(mm_state_dict) 46 | logger.warn(f"Missing keys {missed_keys}") 47 | self.mm.to(device).eval() 48 | if not shared.cmd_opts.no_half: 49 | self.mm.half() 50 | if getattr(devices, "fp8", False): 51 | for module in self.mm.modules(): 52 | if isinstance(module, (torch.nn.Conv2d, torch.nn.Linear)): 53 | module.to(torch.float8_e4m3fn) 54 | 55 | 56 | def inject(self, sd_model, model_name="mm_sd_v15.ckpt"): 57 | if AnimateDiffMM.mm_injected: 58 | logger.info("Motion module already injected. 
Trying to restore.") 59 | self.restore(sd_model) 60 | 61 | unet = sd_model.model.diffusion_model 62 | self._load(model_name) 63 | inject_sdxl = sd_model.is_sdxl or self.mm.is_xl 64 | sd_ver = "SDXL" if sd_model.is_sdxl else "SD1.5" 65 | assert sd_model.is_sdxl == self.mm.is_xl, f"Motion module incompatible with SD. You are using {sd_ver} with {self.mm.mm_type}." 66 | 67 | if self.mm.is_v2: 68 | logger.info(f"Injecting motion module {model_name} into {sd_ver} UNet middle block.") 69 | unet.middle_block.insert(-1, self.mm.mid_block.motion_modules[0]) 70 | elif not self.mm.is_adxl: 71 | logger.info(f"Hacking {sd_ver} GroupNorm32 forward function.") 72 | if self.mm.is_hotshot: 73 | from sgm.modules.diffusionmodules.util import GroupNorm32 74 | else: 75 | from ldm.modules.diffusionmodules.util import GroupNorm32 76 | self.gn32_original_forward = GroupNorm32.forward 77 | gn32_original_forward = self.gn32_original_forward 78 | 79 | def groupnorm32_mm_forward(self, x): 80 | x = rearrange(x, "(b f) c h w -> b c f h w", b=2) 81 | x = gn32_original_forward(self, x) 82 | x = rearrange(x, "b c f h w -> (b f) c h w", b=2) 83 | return x 84 | 85 | GroupNorm32.forward = groupnorm32_mm_forward 86 | 87 | logger.info(f"Injecting motion module {model_name} into {sd_ver} UNet input blocks.") 88 | for mm_idx, unet_idx in enumerate([1, 2, 4, 5, 7, 8, 10, 11]): 89 | if inject_sdxl and mm_idx >= 6: 90 | break 91 | mm_idx0, mm_idx1 = mm_idx // 2, mm_idx % 2 92 | mm_inject = getattr(self.mm.down_blocks[mm_idx0], "temporal_attentions" if self.mm.is_hotshot else "motion_modules")[mm_idx1] 93 | unet.input_blocks[unet_idx].append(mm_inject) 94 | 95 | logger.info(f"Injecting motion module {model_name} into {sd_ver} UNet output blocks.") 96 | for unet_idx in range(12): 97 | if inject_sdxl and unet_idx >= 9: 98 | break 99 | mm_idx0, mm_idx1 = unet_idx // 3, unet_idx % 3 100 | mm_inject = getattr(self.mm.up_blocks[mm_idx0], "temporal_attentions" if self.mm.is_hotshot else "motion_modules")[mm_idx1] 
101 | if unet_idx % 3 == 2 and unet_idx != (8 if self.mm.is_xl else 11): 102 | unet.output_blocks[unet_idx].insert(-1, mm_inject) 103 | else: 104 | unet.output_blocks[unet_idx].append(mm_inject) 105 | 106 | self._set_ddim_alpha(sd_model) 107 | self._set_layer_mapping(sd_model) 108 | AnimateDiffMM.mm_injected = True 109 | logger.info(f"Injection finished.") 110 | 111 | 112 | def restore(self, sd_model): 113 | if not AnimateDiffMM.mm_injected: 114 | logger.info("Motion module already removed.") 115 | return 116 | 117 | inject_sdxl = sd_model.is_sdxl or self.mm.is_xl 118 | sd_ver = "SDXL" if sd_model.is_sdxl else "SD1.5" 119 | self._restore_ddim_alpha(sd_model) 120 | unet = sd_model.model.diffusion_model 121 | 122 | logger.info(f"Removing motion module from {sd_ver} UNet input blocks.") 123 | for unet_idx in [1, 2, 4, 5, 7, 8, 10, 11]: 124 | if inject_sdxl and unet_idx >= 9: 125 | break 126 | unet.input_blocks[unet_idx].pop(-1) 127 | 128 | logger.info(f"Removing motion module from {sd_ver} UNet output blocks.") 129 | for unet_idx in range(12): 130 | if inject_sdxl and unet_idx >= 9: 131 | break 132 | if unet_idx % 3 == 2 and unet_idx != (8 if self.mm.is_xl else 11): 133 | unet.output_blocks[unet_idx].pop(-2) 134 | else: 135 | unet.output_blocks[unet_idx].pop(-1) 136 | 137 | if self.mm.is_v2: 138 | logger.info(f"Removing motion module from {sd_ver} UNet middle block.") 139 | unet.middle_block.pop(-2) 140 | elif not self.mm.is_adxl: 141 | logger.info(f"Restoring {sd_ver} GroupNorm32 forward function.") 142 | if self.mm.is_hotshot: 143 | from sgm.modules.diffusionmodules.util import GroupNorm32 144 | else: 145 | from ldm.modules.diffusionmodules.util import GroupNorm32 146 | GroupNorm32.forward = self.gn32_original_forward 147 | self.gn32_original_forward = None 148 | 149 | AnimateDiffMM.mm_injected = False 150 | logger.info(f"Removal finished.") 151 | if shared.cmd_opts.lowvram: 152 | self.unload() 153 | 154 | 155 | def _set_ddim_alpha(self, sd_model): 156 | 
logger.info(f"Setting DDIM alpha.") 157 | beta_start = 0.00085 158 | beta_end = 0.020 if self.mm.is_adxl else 0.012 159 | if self.mm.is_adxl: 160 | betas = torch.linspace(beta_start**0.5, beta_end**0.5, 1000, dtype=torch.float32, device=device) ** 2 161 | else: 162 | betas = torch.linspace( 163 | beta_start, 164 | beta_end, 165 | 1000 if sd_model.is_sdxl else sd_model.num_timesteps, 166 | dtype=torch.float32, 167 | device=device, 168 | ) 169 | alphas = 1.0 - betas 170 | alphas_cumprod = torch.cumprod(alphas, dim=0) 171 | self.prev_alpha_cumprod = sd_model.alphas_cumprod 172 | sd_model.alphas_cumprod = alphas_cumprod 173 | 174 | 175 | def _set_layer_mapping(self, sd_model): 176 | if hasattr(sd_model, 'network_layer_mapping'): 177 | for name, module in self.mm.named_modules(): 178 | sd_model.network_layer_mapping[name] = module 179 | module.network_layer_name = name 180 | 181 | 182 | def _restore_ddim_alpha(self, sd_model): 183 | logger.info(f"Restoring DDIM alpha.") 184 | sd_model.alphas_cumprod = self.prev_alpha_cumprod 185 | self.prev_alpha_cumprod = None 186 | 187 | 188 | def unload(self): 189 | logger.info("Moving motion module to CPU") 190 | if self.mm is not None: 191 | self.mm.to(cpu) 192 | torch_gc() 193 | gc.collect() 194 | 195 | 196 | def remove(self): 197 | logger.info("Removing motion module from any memory") 198 | del self.mm 199 | self.mm = None 200 | torch_gc() 201 | gc.collect() 202 | 203 | 204 | mm_animatediff = AnimateDiffMM() 205 | -------------------------------------------------------------------------------- /scripts/easyphoto_utils/animatediff/animatediff_prompt.py: -------------------------------------------------------------------------------- 1 | import re 2 | import torch 3 | 4 | from modules.processing import StableDiffusionProcessing, Processed 5 | 6 | from .animatediff_logger import logger_animatediff as logger 7 | from .animatediff_infotext import write_params_txt 8 | 9 | 10 | class AnimateDiffPromptSchedule: 11 | 12 | def 
__init__(self): 13 | self.prompt_map = None 14 | self.original_prompt = None 15 | 16 | 17 | def save_infotext_img(self, p: StableDiffusionProcessing): 18 | if self.prompt_map is not None: 19 | p.prompts = [self.original_prompt for _ in range(p.batch_size)] 20 | 21 | 22 | def save_infotext_txt(self, res: Processed): 23 | if self.prompt_map is not None: 24 | parts = res.info.split('\nNegative prompt: ', 1) 25 | if len(parts) > 1: 26 | res.info = f"{self.original_prompt}\nNegative prompt: {parts[1]}" 27 | for i in range(len(res.infotexts)): 28 | parts = res.infotexts[i].split('\nNegative prompt: ', 1) 29 | if len(parts) > 1: 30 | res.infotexts[i] = f"{self.original_prompt}\nNegative prompt: {parts[1]}" 31 | write_params_txt(res.info) 32 | 33 | 34 | def parse_prompt(self, p: StableDiffusionProcessing): 35 | if type(p.prompt) is not str: 36 | logger.warn("prompt is not str, cannot support prompt map") 37 | return 38 | 39 | lines = p.prompt.strip().split('\n') 40 | data = { 41 | 'head_prompts': [], 42 | 'mapp_prompts': {}, 43 | 'tail_prompts': [] 44 | } 45 | 46 | mode = 'head' 47 | for line in lines: 48 | if mode == 'head': 49 | if re.match(r'^\d+:', line): 50 | mode = 'mapp' 51 | else: 52 | data['head_prompts'].append(line) 53 | 54 | if mode == 'mapp': 55 | match = re.match(r'^(\d+): (.+)$', line) 56 | if match: 57 | frame, prompt = match.groups() 58 | data['mapp_prompts'][int(frame)] = prompt 59 | else: 60 | mode = 'tail' 61 | 62 | if mode == 'tail': 63 | data['tail_prompts'].append(line) 64 | 65 | if data['mapp_prompts']: 66 | logger.info("You are using prompt travel.") 67 | self.prompt_map = {} 68 | prompt_list = [] 69 | last_frame = 0 70 | current_prompt = '' 71 | for frame, prompt in data['mapp_prompts'].items(): 72 | prompt_list += [current_prompt for _ in range(last_frame, frame)] 73 | last_frame = frame 74 | current_prompt = f"{', '.join(data['head_prompts'])}, {prompt}, {', '.join(data['tail_prompts'])}" 75 | self.prompt_map[frame] = current_prompt 76 | 
prompt_list += [current_prompt for _ in range(last_frame, p.batch_size)] 77 | assert len(prompt_list) == p.batch_size, f"prompt_list length {len(prompt_list)} != batch_size {p.batch_size}" 78 | self.original_prompt = p.prompt 79 | p.prompt = prompt_list * p.n_iter 80 | 81 | 82 | def single_cond(self, center_frame, video_length: int, cond: torch.Tensor, closed_loop = False): 83 | if closed_loop: 84 | key_prev = list(self.prompt_map.keys())[-1] 85 | key_next = list(self.prompt_map.keys())[0] 86 | else: 87 | key_prev = list(self.prompt_map.keys())[0] 88 | key_next = list(self.prompt_map.keys())[-1] 89 | 90 | for p in self.prompt_map.keys(): 91 | if p > center_frame: 92 | key_next = p 93 | break 94 | key_prev = p 95 | 96 | dist_prev = center_frame - key_prev 97 | if dist_prev < 0: 98 | dist_prev += video_length 99 | dist_next = key_next - center_frame 100 | if dist_next < 0: 101 | dist_next += video_length 102 | 103 | if key_prev == key_next or dist_prev + dist_next == 0: 104 | return cond[key_prev] if isinstance(cond, torch.Tensor) else {k: v[key_prev] for k, v in cond.items()} 105 | 106 | rate = dist_prev / (dist_prev + dist_next) 107 | if isinstance(cond, torch.Tensor): 108 | return AnimateDiffPromptSchedule.slerp(cond[key_prev], cond[key_next], rate) 109 | else: # isinstance(cond, dict) 110 | return { 111 | k: AnimateDiffPromptSchedule.slerp(v[key_prev], v[key_next], rate) 112 | for k, v in cond.items() 113 | } 114 | 115 | 116 | def multi_cond(self, cond: torch.Tensor, closed_loop = False): 117 | if self.prompt_map is None: 118 | return cond 119 | cond_list = [] if isinstance(cond, torch.Tensor) else {k: [] for k in cond.keys()} 120 | for i in range(cond.shape[0]): 121 | single_cond = self.single_cond(i, cond.shape[0], cond, closed_loop) 122 | if isinstance(cond, torch.Tensor): 123 | cond_list.append(single_cond) 124 | else: 125 | for k, v in single_cond.items(): 126 | cond_list[k].append(v) 127 | if isinstance(cond, torch.Tensor): 128 | return 
torch.stack(cond_list).to(cond.dtype).to(cond.device) 129 | else: 130 | return {k: torch.stack(v).to(cond[k].dtype).to(cond[k].device) for k, v in cond_list.items()} 131 | 132 | 133 | @staticmethod 134 | def slerp( 135 | v0: torch.Tensor, v1: torch.Tensor, t: float, DOT_THRESHOLD: float = 0.9995 136 | ) -> torch.Tensor: 137 | u0 = v0 / v0.norm() 138 | u1 = v1 / v1.norm() 139 | dot = (u0 * u1).sum() 140 | if dot.abs() > DOT_THRESHOLD: 141 | return (1.0 - t) * v0 + t * v1 142 | omega = dot.acos() 143 | return (((1.0 - t) * omega).sin() * v0 + (t * omega).sin() * v1) / omega.sin() 144 | -------------------------------------------------------------------------------- /scripts/easyphoto_utils/animatediff/animatediff_ui.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import cv2 4 | import gradio as gr 5 | 6 | from modules import shared 7 | from modules.processing import StableDiffusionProcessing 8 | 9 | from .animatediff_mm import mm_animatediff as motion_module 10 | from .animatediff_i2ibatch import animatediff_i2ibatch 11 | from .animatediff_lcm import AnimateDiffLCM 12 | 13 | 14 | class ToolButton(gr.Button, gr.components.FormComponent): 15 | """Small button with single emoji as text, fits inside gradio forms""" 16 | 17 | def __init__(self, **kwargs): 18 | super().__init__(variant="tool", **kwargs) 19 | 20 | 21 | def get_block_name(self): 22 | return "button" 23 | 24 | 25 | class AnimateDiffProcess: 26 | 27 | def __init__( 28 | self, 29 | model="mm_sd_v15_v2.ckpt", 30 | enable=False, 31 | video_length=0, 32 | fps=8, 33 | loop_number=0, 34 | closed_loop='R-P', 35 | batch_size=16, 36 | stride=1, 37 | overlap=-1, 38 | format=["GIF", "PNG"], 39 | interp='Off', 40 | interp_x=10, 41 | video_source=None, 42 | video_path='', 43 | latent_power=1, 44 | latent_scale=32, 45 | last_frame=None, 46 | latent_power_last=1, 47 | latent_scale_last=32, 48 | request_id = '', 49 | ): 50 | self.model = model 51 | self.enable = 
enable 52 | self.video_length = video_length 53 | self.fps = fps 54 | self.loop_number = loop_number 55 | self.closed_loop = closed_loop 56 | self.batch_size = batch_size 57 | self.stride = stride 58 | self.overlap = overlap 59 | self.format = format 60 | self.interp = interp 61 | self.interp_x = interp_x 62 | self.video_source = video_source 63 | self.video_path = video_path 64 | self.latent_power = latent_power 65 | self.latent_scale = latent_scale 66 | self.last_frame = last_frame 67 | self.latent_power_last = latent_power_last 68 | self.latent_scale_last = latent_scale_last 69 | self.request_id = request_id 70 | 71 | 72 | def get_list(self, is_img2img: bool): 73 | list_var = list(vars(self).values())[:-1] 74 | if is_img2img: 75 | animatediff_i2ibatch.hack() 76 | else: 77 | list_var = list_var[:-5] 78 | return list_var 79 | 80 | 81 | def get_dict(self, is_img2img: bool): 82 | infotext = { 83 | "enable": self.enable, 84 | "model": self.model, 85 | "video_length": self.video_length, 86 | "fps": self.fps, 87 | "loop_number": self.loop_number, 88 | "closed_loop": self.closed_loop, 89 | "batch_size": self.batch_size, 90 | "stride": self.stride, 91 | "overlap": self.overlap, 92 | "interp": self.interp, 93 | "interp_x": self.interp_x, 94 | } 95 | if self.request_id: 96 | infotext['request_id'] = self.request_id 97 | if motion_module.mm is not None and motion_module.mm.mm_hash is not None: 98 | infotext['mm_hash'] = motion_module.mm.mm_hash[:8] 99 | if is_img2img: 100 | infotext.update({ 101 | "latent_power": self.latent_power, 102 | "latent_scale": self.latent_scale, 103 | "latent_power_last": self.latent_power_last, 104 | "latent_scale_last": self.latent_scale_last, 105 | }) 106 | infotext_str = ', '.join(f"{k}: {v}" for k, v in infotext.items()) 107 | return infotext_str 108 | 109 | 110 | def _check(self): 111 | assert ( 112 | self.video_length >= 0 and self.fps > 0 113 | ), "Video length and FPS should be positive." 
114 | assert not set(["GIF", "MP4", "PNG", "WEBP", "WEBM"]).isdisjoint( 115 | self.format 116 | ), "At least one saving format should be selected." 117 | 118 | 119 | def set_p(self, p: StableDiffusionProcessing): 120 | self._check() 121 | if self.video_length < self.batch_size: 122 | p.batch_size = self.batch_size 123 | else: 124 | p.batch_size = self.video_length 125 | if self.video_length == 0: 126 | self.video_length = p.batch_size 127 | self.video_default = True 128 | else: 129 | self.video_default = False 130 | if self.overlap == -1: 131 | self.overlap = self.batch_size // 4 132 | if "PNG" not in self.format or shared.opts.data.get("animatediff_save_to_custom", False): 133 | p.do_not_save_samples = True 134 | 135 | 136 | class AnimateDiffUiGroup: 137 | txt2img_submit_button = None 138 | img2img_submit_button = None 139 | 140 | def __init__(self): 141 | self.params = AnimateDiffProcess() 142 | 143 | 144 | def render(self, is_img2img: bool, model_dir: str): 145 | if not os.path.isdir(model_dir): 146 | os.mkdir(model_dir) 147 | elemid_prefix = "img2img-ad-" if is_img2img else "txt2img-ad-" 148 | model_list = [f for f in os.listdir(model_dir) if f != ".gitkeep"] 149 | with gr.Accordion("AnimateDiff", open=False): 150 | gr.Markdown(value="Please click [this link](https://github.com/continue-revolution/sd-webui-animatediff#webui-parameters) to read the documentation of each parameter.") 151 | with gr.Row(): 152 | 153 | def refresh_models(*inputs): 154 | new_model_list = [ 155 | f for f in os.listdir(model_dir) if f != ".gitkeep" 156 | ] 157 | dd = inputs[0] 158 | if dd in new_model_list: 159 | selected = dd 160 | elif len(new_model_list) > 0: 161 | selected = new_model_list[0] 162 | else: 163 | selected = None 164 | return gr.Dropdown.update(choices=new_model_list, value=selected) 165 | 166 | with gr.Row(): 167 | self.params.model = gr.Dropdown( 168 | choices=model_list, 169 | value=(self.params.model if self.params.model in model_list else None), 170 | 
label="Motion module", 171 | type="value", 172 | elem_id=f"{elemid_prefix}motion-module", 173 | ) 174 | refresh_model = ToolButton(value="\U0001f504") 175 | refresh_model.click(refresh_models, self.params.model, self.params.model) 176 | 177 | self.params.format = gr.CheckboxGroup( 178 | choices=["GIF", "MP4", "WEBP", "WEBM", "PNG", "TXT"], 179 | label="Save format", 180 | type="value", 181 | elem_id=f"{elemid_prefix}save-format", 182 | value=self.params.format, 183 | ) 184 | with gr.Row(): 185 | self.params.enable = gr.Checkbox( 186 | value=self.params.enable, label="Enable AnimateDiff", 187 | elem_id=f"{elemid_prefix}enable" 188 | ) 189 | self.params.video_length = gr.Number( 190 | minimum=0, 191 | value=self.params.video_length, 192 | label="Number of frames", 193 | precision=0, 194 | elem_id=f"{elemid_prefix}video-length", 195 | ) 196 | self.params.fps = gr.Number( 197 | value=self.params.fps, label="FPS", precision=0, 198 | elem_id=f"{elemid_prefix}fps" 199 | ) 200 | self.params.loop_number = gr.Number( 201 | minimum=0, 202 | value=self.params.loop_number, 203 | label="Display loop number", 204 | precision=0, 205 | elem_id=f"{elemid_prefix}loop-number", 206 | ) 207 | with gr.Row(): 208 | self.params.closed_loop = gr.Radio( 209 | choices=["N", "R-P", "R+P", "A"], 210 | value=self.params.closed_loop, 211 | label="Closed loop", 212 | elem_id=f"{elemid_prefix}closed-loop", 213 | ) 214 | self.params.batch_size = gr.Slider( 215 | minimum=1, 216 | maximum=32, 217 | value=self.params.batch_size, 218 | label="Context batch size", 219 | step=1, 220 | precision=0, 221 | elem_id=f"{elemid_prefix}batch-size", 222 | ) 223 | self.params.stride = gr.Number( 224 | minimum=1, 225 | value=self.params.stride, 226 | label="Stride", 227 | precision=0, 228 | elem_id=f"{elemid_prefix}stride", 229 | ) 230 | self.params.overlap = gr.Number( 231 | minimum=-1, 232 | value=self.params.overlap, 233 | label="Overlap", 234 | precision=0, 235 | elem_id=f"{elemid_prefix}overlap", 236 | ) 237 | 
with gr.Row(): 238 | self.params.interp = gr.Radio( 239 | choices=["Off", "FILM"], 240 | label="Frame Interpolation", 241 | elem_id=f"{elemid_prefix}interp-choice", 242 | value=self.params.interp 243 | ) 244 | self.params.interp_x = gr.Number( 245 | value=self.params.interp_x, label="Interp X", precision=0, 246 | elem_id=f"{elemid_prefix}interp-x" 247 | ) 248 | self.params.video_source = gr.Video( 249 | value=self.params.video_source, 250 | label="Video source", 251 | ) 252 | def update_fps(video_source): 253 | if video_source is not None and video_source != '': 254 | cap = cv2.VideoCapture(video_source) 255 | fps = int(cap.get(cv2.CAP_PROP_FPS)) 256 | cap.release() 257 | return fps 258 | else: 259 | return int(self.params.fps.value) 260 | self.params.video_source.change(update_fps, inputs=self.params.video_source, outputs=self.params.fps) 261 | def update_frames(video_source): 262 | if video_source is not None and video_source != '': 263 | cap = cv2.VideoCapture(video_source) 264 | frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) 265 | cap.release() 266 | return frames 267 | else: 268 | return int(self.params.video_length.value) 269 | self.params.video_source.change(update_frames, inputs=self.params.video_source, outputs=self.params.video_length) 270 | self.params.video_path = gr.Textbox( 271 | value=self.params.video_path, 272 | label="Video path", 273 | elem_id=f"{elemid_prefix}video-path" 274 | ) 275 | if is_img2img: 276 | with gr.Row(): 277 | self.params.latent_power = gr.Slider( 278 | minimum=0.1, 279 | maximum=10, 280 | value=self.params.latent_power, 281 | step=0.1, 282 | label="Latent power", 283 | elem_id=f"{elemid_prefix}latent-power", 284 | ) 285 | self.params.latent_scale = gr.Slider( 286 | minimum=1, 287 | maximum=128, 288 | value=self.params.latent_scale, 289 | label="Latent scale", 290 | elem_id=f"{elemid_prefix}latent-scale" 291 | ) 292 | self.params.latent_power_last = gr.Slider( 293 | minimum=0.1, 294 | maximum=10, 295 | 
value=self.params.latent_power_last, 296 | step=0.1, 297 | label="Optional latent power for last frame", 298 | elem_id=f"{elemid_prefix}latent-power-last", 299 | ) 300 | self.params.latent_scale_last = gr.Slider( 301 | minimum=1, 302 | maximum=128, 303 | value=self.params.latent_scale_last, 304 | label="Optional latent scale for last frame", 305 | elem_id=f"{elemid_prefix}latent-scale-last" 306 | ) 307 | self.params.last_frame = gr.Image( 308 | label="Optional last frame. Leave it blank if you do not need one.", 309 | type="pil", 310 | ) 311 | with gr.Row(): 312 | unload = gr.Button(value="Move motion module to CPU (default if lowvram)") 313 | remove = gr.Button(value="Remove motion module from any memory") 314 | unload.click(fn=motion_module.unload) 315 | remove.click(fn=motion_module.remove) 316 | return self.register_unit(is_img2img) 317 | 318 | 319 | def register_unit(self, is_img2img: bool): 320 | unit = gr.State(value=AnimateDiffProcess) 321 | ( 322 | AnimateDiffUiGroup.img2img_submit_button 323 | if is_img2img 324 | else AnimateDiffUiGroup.txt2img_submit_button 325 | ).click( 326 | fn=AnimateDiffProcess, 327 | inputs=self.params.get_list(is_img2img), 328 | outputs=unit, 329 | queue=False, 330 | ) 331 | return unit 332 | 333 | 334 | @staticmethod 335 | def on_after_component(component, **_kwargs): 336 | elem_id = getattr(component, "elem_id", None) 337 | 338 | if elem_id == "txt2img_generate": 339 | AnimateDiffUiGroup.txt2img_submit_button = component 340 | return 341 | 342 | if elem_id == "img2img_generate": 343 | AnimateDiffUiGroup.img2img_submit_button = component 344 | return 345 | 346 | 347 | @staticmethod 348 | def on_before_ui(): 349 | AnimateDiffLCM.hack_kdiff_ui() 350 | -------------------------------------------------------------------------------- /scripts/easyphoto_utils/loractl_utils.py: -------------------------------------------------------------------------------- 1 | """Borrowed from https://github.com/cheald/sd-webui-loractl. 
2 | """ 3 | import io 4 | import os 5 | import re 6 | import sys 7 | 8 | import gradio as gr 9 | import matplotlib 10 | import numpy as np 11 | import pandas as pd 12 | from modules import extra_networks, shared, script_callbacks 13 | from modules.processing import StableDiffusionProcessing 14 | import modules.scripts as scripts 15 | from PIL import Image 16 | from scripts.easyphoto_config import extensions_builtin_dir, extensions_dir 17 | 18 | 19 | # TODO: refactor the plugin dependency. 20 | lora_extensions_path = os.path.join(extensions_dir, "Lora") 21 | lora_extensions_builtin_path = os.path.join(extensions_builtin_dir, "Lora") 22 | 23 | if os.path.exists(lora_extensions_path): 24 | lora_path = lora_extensions_path 25 | elif os.path.exists(lora_extensions_builtin_path): 26 | lora_path = lora_extensions_builtin_path 27 | else: 28 | raise ImportError("Lora extension is not found.") 29 | sys.path.insert(0, lora_path) 30 | import extra_networks_lora 31 | import network 32 | import networks 33 | 34 | sys.path.remove(lora_path) 35 | 36 | 37 | def check_loractl_conflict(): 38 | loractl_extensions_path = os.path.join(extensions_dir, "sd-webui-loractl") 39 | if os.path.exists(loractl_extensions_path): 40 | disabled_extensions = shared.opts.data.get("disabled_extensions", []) 41 | if "sd-webui-loractl" not in disabled_extensions: 42 | return True 43 | return False 44 | 45 | 46 | # Borrowed from https://github.com/cheald/sd-webui-loractl/blob/master/loractl/lib/lora_ctl_network.py. 
# Cache of parsed weight schedules, keyed by LoRA network name. It is filled by
# LoraCtlNetwork.activate and read by the dynamic multiplier getters.
lora_weights = {}


def reset_weights():
    """Drop every cached weight schedule (the dict object itself is kept)."""
    lora_weights.clear()


class LoraCtlNetwork(extra_networks_lora.ExtraNetworkLora):
    """LoRA extra-network handler that accepts loractl's extended weight syntax.

    The stock handler cannot parse weight schedules such as ``1@0,0@1``. While dynamic
    weights are active, each entry's schedule is parsed once into ``lora_weights``. The
    entry is then rewritten to a plain ``[name, 1]`` before it reaches the stock handler.
    The placeholder weight of 1 is harmless because the real value is looked up in
    ``lora_weights`` by name.
    """

    def activate(self, p, params_list):
        if not is_active():
            return super().activate(p, params_list)

        for entry in params_list:
            assert entry.items
            lora_name = entry.positional[0]
            # Parse only the first occurrence; later ones reuse the cached schedule.
            if lora_weights.get(lora_name) is None:
                lora_weights[lora_name] = params_to_weights(entry)
            entry.positional = [lora_name, 1]
            entry.named = {}
        return super().activate(p, params_list)


# Borrowed from https://github.com/cheald/sd-webui-loractl/blob/master/loractl/lib/utils.py.
# Given a string like x@y,z@a, returns [[x, z], [y, a]] sorted for consumption by np.interp.
def sorted_positions(raw_steps):
    """Parse a loractl weight spec into a constant weight or an interpolation table.

    ``raw_steps`` is either a bare number (``"0.8"``) or a list of ``weight@step`` entries
    separated by ``,`` or ``;``, with ``~`` accepted in place of ``@`` (e.g. ``"1@0,0.2@0.5"``).
    An entry that gives only a weight is assigned an implicit step of 1.

    Returns:
        The weight as a float when the spec is a single bare number. Otherwise
        ``[[weights...], [steps...]]`` sorted by step, ready for ``np.interp``.
    """
    steps = [[float(s.strip()) for s in re.split("[@~]", x)] for x in re.split("[,;]", str(raw_steps))]
    # Only a spec made of exactly one bare number is a constant. Checking the first entry
    # alone would silently discard the rest of a list such as "1,0@0.5".
    if len(steps) == 1 and len(steps[0]) == 1:
        return steps[0][0]

    # Give an implicit step of 1 to any entry which only specifies a weight
    steps = [[s[0], s[1] if len(s) == 2 else 1] for s in steps]

    # np.interp needs increasing x-coordinates, so sort the entries by step
    steps.sort(key=lambda k: k[1])

    return [list(v) for v in zip(*steps)]


def calculate_weight(m, step, max_steps, step_offset=2):
    """Evaluate a weight spec produced by sorted_positions at one sampling step.

    Args:
        m: a constant weight, or ``[[weights...], [steps...]]`` as returned by sorted_positions.
        step: the current sampling step.
        max_steps: the total number of sampling steps.
        step_offset: subtracted from ``max_steps`` when normalising ``step`` (default 2).

    Returns:
        ``m`` unchanged when it is not a list. Otherwise the weight linearly interpolated at
        ``step``. When every step in ``m`` is <= 1, the schedule is relative, and ``step`` is
        first normalised to ``step / (max_steps - step_offset)``. If that denominator is not
        positive, ``step`` is taken as 1.0, the end of the schedule.
    """
    if not isinstance(m, list):
        return m
    if m[1][-1] <= 1.0:
        denominator = max_steps - step_offset
        # A zero denominator (max_steps == step_offset) used to raise ZeroDivisionError and a
        # negative one produced a negative "progress"; fall back like max_steps <= 0 does.
        step = step / denominator if denominator > 0 else 1.0
    # Interpolate between the recorded (step -> weight) points; np.interp clamps outside them.
    return np.interp(step, m[1], m[0])


def params_to_weights(params):
    """Build the per-component weight specs for one LoRA extra-network entry.

    Positional arguments are ``name[, te[, unet]]``. The named arguments ``te``, ``unet``,
    ``hr`` (sets both hires components), ``hrunet`` and ``hrte`` override them. Every value
    is parsed with sorted_positions.

    Returns:
        dict with keys ``"te"``, ``"unet"``, ``"hrte"`` and ``"hrunet"``. ``te`` defaults to
        1.0. Unset components fall back as unet <- te, hrunet <- unet, hrte <- te.
    """
    weights = {"unet": None, "te": 1.0, "hrunet": None, "hrte": None}

    # Positional form: name, te[, unet]
    if len(params.positional) > 1:
        weights["te"] = sorted_positions(params.positional[1])
    if len(params.positional) > 2:
        weights["unet"] = sorted_positions(params.positional[2])

    # Named arguments override positional ones. "hr" sets both hires components and is
    # itself overridden by the more specific "hrunet" / "hrte".
    named = params.named
    if named.get("te"):
        weights["te"] = sorted_positions(named.get("te"))
    if named.get("unet"):
        weights["unet"] = sorted_positions(named.get("unet"))
    if named.get("hr"):
        weights["hrunet"] = sorted_positions(named.get("hr"))
        weights["hrte"] = sorted_positions(named.get("hr"))
    if named.get("hrunet"):
        weights["hrunet"] = sorted_positions(named.get("hrunet"))
    if named.get("hrte"):
        weights["hrte"] = sorted_positions(named.get("hrte"))

    # Fill components that are still unset
    if weights["unet"] is None:
        weights["unet"] = weights["te"]
    if weights["hrunet"] is None:
        weights["hrunet"] = weights["unet"]
    if weights["hrte"] is None:
        weights["hrte"] = weights["te"]

    return weights


# True during the hires-fix pass; makes the dynamic getters use the "hrte"/"hrunet" specs.
hires = False
# Whether dynamic LoRA weights are enabled for the current generation.
loractl_active = True


def is_hires():
    return hires


def set_hires(value):
    global hires
    hires = value


def is_active():
    return loractl_active


def set_active(value):
    global loractl_active
    loractl_active = value


# Borrowed from https://github.com/cheald/sd-webui-loractl/blob/master/loractl/lib/network_patch.py.
# Patch network.Network so it reapplies properly for dynamic weights
# By default, network application is cached, with (name, te, unet, dim) as a key
# By replacing the bare properties with getters, we can ensure that we cause SD
# to reapply the network each time we change its weights, while still taking advantage
# of caching when weights are not updated.
def get_weight(m):
    """Evaluate weight spec ``m`` at webui's current sampling step (see calculate_weight)."""
    progress = shared.state
    return calculate_weight(m, progress.sampling_step, progress.sampling_steps, step_offset=2)


def _scheduled_multiplier(net, key, static_attr):
    # Resolve the multiplier for ``net``. Use the recorded schedule for ``key``, or its
    # "hr"-prefixed variant during the hires pass, when the network has one; otherwise
    # use the value stored in ``static_attr``. Either way, evaluate it at the current step.
    if net.name in lora_weights:
        schedule_key = "hr" + key if is_hires() else key
        spec = lora_weights[net.name].get(schedule_key, getattr(net, static_attr))
    else:
        spec = getattr(net, static_attr)
    return get_weight(spec)


def get_dynamic_te(self):
    """Getter of the patched ``Network.te_multiplier`` property."""
    return _scheduled_multiplier(self, "te", "_te_multiplier")


def get_dynamic_unet(self):
    """Getter of the patched ``Network.unet_multiplier`` property."""
    return _scheduled_multiplier(self, "unet", "_unet_multiplier")


def set_dynamic_te(self, value):
    """Setter of ``Network.te_multiplier``: store the static ``te`` multiplier."""
    self._te_multiplier = value


def set_dynamic_unet(self, value):
    """Setter of ``Network.unet_multiplier``: store the static ``unet`` multiplier."""
    self._unet_multiplier = value


def apply():
    """Turn ``te_multiplier``/``unet_multiplier`` of ``network.Network`` into dynamic properties.

    The patch runs only once. After the first call the class attribute is a property
    object, so the ``getattr`` probe no longer returns None.
    """
    network_cls = network.Network
    if getattr(network_cls, "te_multiplier", None) is not None:
        return
    network_cls.te_multiplier = property(get_dynamic_te, set_dynamic_te)
    network_cls.unet_multiplier = property(get_dynamic_unet, set_dynamic_unet)


# Borrowed from https://github.com/cheald/sd-webui-loractl/blob/master/scripts/loractl.py.
class LoraCtlScript(scripts.Script):
    """Always-visible webui script that drives EasyPhoto's built-in dynamic LoRA weights.

    For each generation it swaps the registered "lora" extra network for LoraCtlNetwork
    (once) and patches ``network.Network`` via apply(). It also resets the per-run state
    and, if requested, appends a plot of the LoRA weights over all steps to the outputs.
    """

    def __init__(self):
        # Handler that was registered for "lora" before LoraCtlNetwork replaced it.
        self.original_network = None
        super().__init__()

    def title(self):
        return "Dynamic Lora Weights (EasyPhoto built-in)"

    def show(self, is_img2img):
        return scripts.AlwaysVisible

    def ui(self, is_img2img):
        with gr.Group():
            with gr.Accordion("Dynamic Lora Weights (EasyPhoto builtin)", open=False):
                opt_enable = gr.Checkbox(value=True, label="Enable Dynamic Lora Weights")
                opt_plot_lora_weight = gr.Checkbox(value=False, label="Plot the LoRA weight in all steps")
        return [opt_enable, opt_plot_lora_weight]

    def process(self, p: StableDiffusionProcessing, opt_enable=True, opt_plot_lora_weight=False, **kwargs):
        if opt_enable:
            registered = extra_networks.extra_network_registry["lora"]
            if type(registered) != LoraCtlNetwork:
                self.original_network = registered
                # Not named ``network`` to avoid shadowing the imported module of that name.
                ctl_network = LoraCtlNetwork()
                extra_networks.register_extra_network(ctl_network)
                extra_networks.register_extra_network_alias(ctl_network, "loractl")
        # The handler is deliberately not swapped back when disabled: LoraCtlNetwork.activate
        # defers to the stock implementation whenever set_active(False) is in effect.

        apply()
        set_hires(False)
        set_active(opt_enable)
        reset_weights()
        reset_plot()

    def before_hr(self, p, *args):
        # Make the dynamic multipliers use the "hrte"/"hrunet" specs for the hires pass.
        set_hires(True)

    def postprocess(self, p, processed, opt_enable=True, opt_plot_lora_weight=False, **kwargs):
        # log_names stays empty when no LoRA network was loaded during sampling. pandas cannot
        # plot a frame without numeric columns, so skip the chart instead of raising.
        if opt_plot_lora_weight and opt_enable and log_names:
            processed.images.append(make_plot())
255 | log_weights = [] 256 | log_names = [] 257 | last_plotted_step = -1 258 | 259 | 260 | # Copied from composable_lora 261 | def plot_lora_weight(lora_weights, lora_names): 262 | data = pd.DataFrame(lora_weights, columns=lora_names) 263 | ax = data.plot() 264 | ax.set_xlabel("Steps") 265 | ax.set_ylabel("LoRA weight") 266 | ax.set_title("LoRA weight in all steps") 267 | ax.legend(loc=0) 268 | result_image = fig2img(ax) 269 | matplotlib.pyplot.close(ax.figure) 270 | del ax 271 | return result_image 272 | 273 | 274 | # Copied from composable_lora 275 | def fig2img(fig): 276 | buf = io.BytesIO() 277 | fig.figure.savefig(buf) 278 | buf.seek(0) 279 | img = Image.open(buf) 280 | return img 281 | 282 | 283 | def reset_plot(): 284 | global last_plotted_step 285 | log_weights.clear() 286 | log_names.clear() 287 | 288 | 289 | def make_plot(): 290 | return plot_lora_weight(log_weights, log_names) 291 | 292 | 293 | # On each step, capture our lora weights for plotting 294 | def on_step(params): 295 | global last_plotted_step 296 | if last_plotted_step == params.sampling_step and len(log_weights) > 0: 297 | log_weights.pop() 298 | last_plotted_step = params.sampling_step 299 | if len(log_names) == 0: 300 | for net in networks.loaded_networks: 301 | log_names.append(net.name + "_te") 302 | log_names.append(net.name + "_unet") 303 | frame = [] 304 | for net in networks.loaded_networks: 305 | frame.append(net.te_multiplier) 306 | frame.append(net.unet_multiplier) 307 | log_weights.append(frame) 308 | 309 | 310 | script_callbacks.on_cfg_after_cfg(on_step) 311 | -------------------------------------------------------------------------------- /scripts/train_kohya/ddpo_pytorch/diffusers_patch/ddim_with_logprob.py: -------------------------------------------------------------------------------- 1 | # Copied from https://github.com/huggingface/diffusers/blob/fc6acb6b97e93d58cb22b5fee52d884d77ce84d8/src/diffusers/schedulers/scheduling_ddim.py 2 | # with the following modifications: 3 | 
# - It computes and returns the log prob of `prev_sample` given the UNet prediction.
# - Instead of `variance_noise`, it takes `prev_sample` as an optional argument. If `prev_sample` is provided,
#   it uses it to compute the log prob.
# - Timesteps can be a batched torch.Tensor.
# Copied from https://github.com/kvablack/ddpo-pytorch/blob/main/ddpo_pytorch/diffusers_patch/ddim_with_logprob.py.

import math
from packaging import version
from typing import Optional, Tuple, Union

import diffusers
import torch
from diffusers.schedulers.scheduling_ddim import DDIMScheduler, DDIMSchedulerOutput

# See https://github.com/huggingface/diffusers/issues/5025 for details.
# diffusers releases newer than 0.20.2 expose randn_tensor from diffusers.utils.torch_utils;
# older ones from diffusers.utils.
if version.parse(diffusers.__version__) > version.parse("0.20.2"):
    from diffusers.utils.torch_utils import randn_tensor
else:
    from diffusers.utils import randn_tensor


def _left_broadcast(t, shape):
    """Broadcast tensor ``t`` to ``shape``, aligning ``t``'s dimensions on the LEFT.

    Trailing singleton dimensions are appended to ``t`` before ``broadcast_to``. For example,
    a per-sample tensor of shape (B,) broadcasts against a (B, ...) sample, unlike
    NumPy-style broadcasting, which aligns on the right.
    """
    assert t.ndim <= len(shape)
    return t.reshape(t.shape + (1,) * (len(shape) - t.ndim)).broadcast_to(shape)


def _get_variance(self, timestep, prev_timestep):
    """Return the DDIM variance for the step ``timestep`` -> ``prev_timestep``.

    Computes ``(1 - a_prev) / (1 - a_t) * (1 - a_t / a_prev)``, where ``a`` is
    ``self.alphas_cumprod`` gathered at the given (integer tensor) timesteps. The result
    is placed on ``timestep.device``.
    """
    alpha_prod_t = torch.gather(self.alphas_cumprod, 0, timestep.cpu()).to(timestep.device)
    # NOTE(review): the only caller in this file clamps prev_timestep to >= 0 beforehand, so
    # this condition is always true and final_alpha_cumprod is never selected here.
    alpha_prod_t_prev = torch.where(
        prev_timestep.cpu() >= 0, self.alphas_cumprod.gather(0, prev_timestep.cpu()), self.final_alpha_cumprod
    ).to(timestep.device)
    beta_prod_t = 1 - alpha_prod_t
    beta_prod_t_prev = 1 - alpha_prod_t_prev

    variance = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev)

    return variance


def ddim_step_with_logprob(
    self: DDIMScheduler,
    model_output: torch.FloatTensor,
    timestep: torch.Tensor,
    sample: torch.FloatTensor,
    eta: float = 0.0,
    use_clipped_model_output: bool = False,
    generator=None,
    prev_sample: Optional[torch.FloatTensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Predict the sample at the previous timestep by reversing the SDE, and return the log
    probability of that sample under the Gaussian defined by this DDIM step.

    Args:
        self (`DDIMScheduler`): the scheduler; `set_timesteps` must have been called on it.
        model_output (`torch.FloatTensor`): direct output from learned diffusion model.
        timestep (`torch.Tensor`): current discrete timestep(s) as an integer tensor. It is used
            as a `gather` index into `alphas_cumprod`, so a plain Python `int` does not work.
            Per the file header, it may be batched.
        sample (`torch.FloatTensor`):
            current instance of sample being created by diffusion process.
        eta (`float`): weight of noise for added noise in diffusion step. The noise std is
            `eta * sqrt(variance)`, so with `eta == 0` the returned log prob is not finite.
        use_clipped_model_output (`bool`): if `True`, re-derive `pred_epsilon` from the predicted
            original sample after it has been clipped or thresholded (step 4).
        generator: random number generator used to draw the noise when `prev_sample` is not given.
        prev_sample (`torch.FloatTensor`, *optional*): if given, no noise is drawn. The log prob of
            this sample under the step's distribution is computed instead. Must not be combined
            with `generator`.

    Returns:
        `tuple`: `(prev_sample, log_prob)`. `prev_sample` is cast to `sample.dtype`. `log_prob` is
        the element-wise Gaussian log density averaged over every dimension except the first
        (batch) one.

    Raises:
        ValueError: if `set_timesteps` has not been called, if `prediction_type` is not one of
            `epsilon`, `sample` or `v_prediction`, or if both `generator` and `prev_sample` are given.
    """
    assert isinstance(self, DDIMScheduler)
    if self.num_inference_steps is None:
        raise ValueError("Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler")

    # See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf
    # Ideally, read DDIM paper in-detail understanding

    # Notation (<variable name> -> <name in paper>)
    # - pred_noise_t -> e_theta(x_t, t)
    # - pred_original_sample -> f_theta(x_t, t) or x_0
    # - std_dev_t -> sigma_t
    # - eta -> η
    # - pred_sample_direction -> "direction pointing to x_t"
    # - pred_prev_sample -> "x_t-1"

    # 1. get previous step value (=t-1)
    prev_timestep = timestep - self.config.num_train_timesteps // self.num_inference_steps
    # to prevent OOB on gather
    # NOTE(review): after this clamp prev_timestep is never negative, so the torch.where below
    # always picks alphas_cumprod[prev_timestep] and final_alpha_cumprod is unused.
    prev_timestep = torch.clamp(prev_timestep, 0, self.config.num_train_timesteps - 1)

    # 2. compute alphas, betas
    alpha_prod_t = self.alphas_cumprod.gather(0, timestep.cpu())
    alpha_prod_t_prev = torch.where(prev_timestep.cpu() >= 0, self.alphas_cumprod.gather(0, prev_timestep.cpu()), self.final_alpha_cumprod)
    # Expand the per-timestep scalars so they multiply element-wise with `sample`.
    alpha_prod_t = _left_broadcast(alpha_prod_t, sample.shape).to(sample.device)
    alpha_prod_t_prev = _left_broadcast(alpha_prod_t_prev, sample.shape).to(sample.device)

    beta_prod_t = 1 - alpha_prod_t

    # 3. compute predicted original sample from predicted noise also called
    # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
    if self.config.prediction_type == "epsilon":
        pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
        pred_epsilon = model_output
    elif self.config.prediction_type == "sample":
        pred_original_sample = model_output
        pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5)
    elif self.config.prediction_type == "v_prediction":
        pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output
        pred_epsilon = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample
    else:
        raise ValueError(f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or" " `v_prediction`")

    # 4. Clip or threshold "predicted x_0"
    if self.config.thresholding:
        pred_original_sample = self._threshold_sample(pred_original_sample)
    elif self.config.clip_sample:
        pred_original_sample = pred_original_sample.clamp(-self.config.clip_sample_range, self.config.clip_sample_range)

    # 5. compute variance: "sigma_t(η)" -> see formula (16)
    # σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1)
    variance = _get_variance(self, timestep, prev_timestep)
    std_dev_t = eta * variance ** (0.5)
    std_dev_t = _left_broadcast(std_dev_t, sample.shape).to(sample.device)

    if use_clipped_model_output:
        # the pred_epsilon is always re-derived from the clipped x_0 in Glide
        pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5)

    # 6. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
    pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * pred_epsilon

    # 7. compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
    prev_sample_mean = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction

    if prev_sample is not None and generator is not None:
        raise ValueError(
            "Cannot pass both generator and prev_sample. Please make sure that either `generator` or" " `prev_sample` stays `None`."
        )

    if prev_sample is None:
        # Sample x_{t-1} ~ N(prev_sample_mean, std_dev_t^2).
        variance_noise = randn_tensor(model_output.shape, generator=generator, device=model_output.device, dtype=model_output.dtype)
        prev_sample = prev_sample_mean + std_dev_t * variance_noise

    # log prob of prev_sample given prev_sample_mean and std_dev_t (Gaussian log density).
    # prev_sample is detached, so no gradient flows through the sample itself.
    log_prob = (
        -((prev_sample.detach() - prev_sample_mean) ** 2) / (2 * (std_dev_t**2))
        - torch.log(std_dev_t)
        - torch.log(torch.sqrt(2 * torch.as_tensor(math.pi)))
    )
    # mean along all but batch dimension
    log_prob = log_prob.mean(dim=tuple(range(1, log_prob.ndim)))

    return prev_sample.type(sample.dtype), log_prob

-------------------------------------------------------------------------------- /scripts/train_kohya/ddpo_pytorch/diffusers_patch/pipeline_with_logprob.py: -------------------------------------------------------------------------------- 1 | # Copied from https://github.com/huggingface/diffusers/blob/fc6acb6b97e93d58cb22b5fee52d884d77ce84d8/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py 2 | # with the following modifications: 3 | # - It uses the patched version of `ddim_step_with_logprob` from `ddim_with_logprob.py`. As such, it only supports the 4 | # `ddim` scheduler. 5 | # - It returns all the intermediate latents of the denoising process as well as the log probs of each denoising step. 6 | # Copied from https://github.com/kvablack/ddpo-pytorch/blob/main/ddpo_pytorch/diffusers_patch/pipeline_with_logprob.py.
from typing import Any, Callable, Dict, List, Optional, Union

import torch
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipeline, rescale_noise_cfg

from .ddim_with_logprob import ddim_step_with_logprob


@torch.no_grad()
def pipeline_with_logprob(
    self: StableDiffusionPipeline,
    prompt: Optional[Union[str, List[str]]] = None,
    height: Optional[int] = None,
    width: Optional[int] = None,
    num_inference_steps: int = 50,
    guidance_scale: float = 7.5,
    negative_prompt: Optional[Union[str, List[str]]] = None,
    num_images_per_prompt: Optional[int] = 1,
    eta: float = 0.0,
    generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
    latents: Optional[torch.FloatTensor] = None,
    prompt_embeds: Optional[torch.FloatTensor] = None,
    negative_prompt_embeds: Optional[torch.FloatTensor] = None,
    output_type: Optional[str] = "pil",
    return_dict: bool = True,
    callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
    callback_steps: int = 1,
    cross_attention_kwargs: Optional[Dict[str, Any]] = None,
    guidance_rescale: float = 0.0,
):
    r"""
    Run the Stable Diffusion denoising loop through `ddim_step_with_logprob` and record the trajectory.

    This follows `StableDiffusionPipeline.__call__` from the diffusers revision named in this
    file's header, with two differences. Every step goes through `ddim_step_with_logprob`, so
    only the DDIM scheduler is supported. The intermediate latents and per-step log
    probabilities are returned as well. The whole function runs under `torch.no_grad()`, so
    none of the returned tensors carry gradients.

    Args:
        self (`StableDiffusionPipeline`):
            The pipeline to sample from; `self.scheduler` is expected to be a `DDIMScheduler`.
        prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts to guide the image generation. If not defined, one has to pass
            `prompt_embeds` instead.
        height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
            The height in pixels of the generated image.
        width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
            The width in pixels of the generated image.
        num_inference_steps (`int`, *optional*, defaults to 50):
            The number of denoising steps. More denoising steps usually lead to a higher quality image at the
            expense of slower inference.
        guidance_scale (`float`, *optional*, defaults to 7.5):
            Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
            `guidance_scale` is defined as `w` of equation 2. of [Imagen
            Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance is enabled when `guidance_scale > 1`.
            Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
            usually at the expense of lower image quality.
        negative_prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts not to guide the image generation. If not defined, one has to pass
            `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., when `guidance_scale`
            is not greater than `1`).
        num_images_per_prompt (`int`, *optional*, defaults to 1):
            The number of images to generate per prompt.
        eta (`float`, *optional*, defaults to 0.0):
            Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Handed to the
            scheduler step through `prepare_extra_step_kwargs`. NOTE: `ddim_step_with_logprob` uses
            `eta * sqrt(variance)` as the noise std, so with `eta == 0` the returned log probs are not finite.
            Callers presumably pass `eta > 0`; confirm at the call site.
        generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
            One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
            to make generation deterministic.
        latents (`torch.FloatTensor`, *optional*):
            Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
            generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
            tensor will be generated by sampling using the supplied random `generator`.
        prompt_embeds (`torch.FloatTensor`, *optional*):
            Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
            provided, text embeddings will be generated from `prompt` input argument.
        negative_prompt_embeds (`torch.FloatTensor`, *optional*):
            Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
            weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
            argument.
        output_type (`str`, *optional*, defaults to `"pil"`):
            The output format of the generated image. Choose between
            [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. With `"latent"`,
            the VAE decode and the safety checker are skipped.
        return_dict (`bool`, *optional*, defaults to `True`):
            Unused. Kept for signature compatibility with `StableDiffusionPipeline.__call__`; the result is
            always the 4-tuple described under Returns.
        callback (`Callable`, *optional*):
            Called as `callback(i, t, latents)` after step `i`. It runs only on steps where the progress bar
            advances and `i % callback_steps == 0`.
        callback_steps (`int`, *optional*, defaults to 1):
            The frequency at which the `callback` function will be called. If not specified, the callback will be
            called at every step.
        cross_attention_kwargs (`dict`, *optional*):
            A kwargs dictionary that, if specified, is passed along to the `AttentionProcessor` as defined under
            `self.processor` in
            [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
            Its `"scale"` entry, if present, is also used as the LoRA scale when encoding the prompt.
        guidance_rescale (`float`, *optional*, defaults to 0.0):
            Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are
            Flawed](https://arxiv.org/pdf/2305.08891.pdf); it is `φ` in equation 16 of that paper.
            Guidance rescale factor should fix overexposure when using zero terminal SNR. Only applied when
            guidance is enabled and the factor is > 0.

    Returns:
        `tuple`: `(image, has_nsfw_concept, all_latents, all_log_probs)` where
            - `image` is the postprocessed output in `output_type` format (the final latents when
              `output_type == "latent"`);
            - `has_nsfw_concept` is whatever `run_safety_checker` returns, or `None` for latent output;
            - `all_latents` holds the initial latents followed by the latents after every step
              (`len(timesteps) + 1` entries);
            - `all_log_probs` holds the log prob returned by each step (`len(timesteps)` entries).
    """
    # 0. Default height and width to unet
    height = height or self.unet.config.sample_size * self.vae_scale_factor
    width = width or self.unet.config.sample_size * self.vae_scale_factor

    # 1. Check inputs. Raise error if not correct
    self.check_inputs(prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds)

    # 2. Define call parameters
    if prompt is not None and isinstance(prompt, str):
        batch_size = 1
    elif prompt is not None and isinstance(prompt, list):
        batch_size = len(prompt)
    else:
        batch_size = prompt_embeds.shape[0]

    device = self._execution_device
    # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
    # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
    # corresponds to doing no classifier free guidance.
    do_classifier_free_guidance = guidance_scale > 1.0

    # 3. Encode input prompt
    text_encoder_lora_scale = cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
    prompt_embeds = self._encode_prompt(
        prompt,
        device,
        num_images_per_prompt,
        do_classifier_free_guidance,
        negative_prompt,
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=negative_prompt_embeds,
        lora_scale=text_encoder_lora_scale,
    )

    # 4. Prepare timesteps
    self.scheduler.set_timesteps(num_inference_steps, device=device)
    timesteps = self.scheduler.timesteps

    # 5. Prepare latent variables
    num_channels_latents = self.unet.config.in_channels
    latents = self.prepare_latents(
        batch_size * num_images_per_prompt,
        num_channels_latents,
        height,
        width,
        prompt_embeds.dtype,
        device,
        generator,
        latents,
    )

    # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
    extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

    # 7. Denoising loop
    num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
    # Trajectory: the starting latents plus the latents after each step, and one log prob per step.
    all_latents = [latents]
    all_log_probs = []
    with self.progress_bar(total=num_inference_steps) as progress_bar:
        for i, t in enumerate(timesteps):
            # expand the latents if we are doing classifier free guidance
            latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

            # predict the noise residual
            noise_pred = self.unet(
                latent_model_input,
                t,
                encoder_hidden_states=prompt_embeds,
                cross_attention_kwargs=cross_attention_kwargs,
                return_dict=False,
            )[0]

            # perform guidance
            if do_classifier_free_guidance:
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

            if do_classifier_free_guidance and guidance_rescale > 0.0:
                # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
                noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)

            # compute the previous noisy sample x_t -> x_t-1 (and its log prob under this step)
            latents, log_prob = ddim_step_with_logprob(self.scheduler, noise_pred, t, latents, **extra_step_kwargs)

            all_latents.append(latents)
            all_log_probs.append(log_prob)

            # call the callback, if provided
            if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                progress_bar.update()
                if callback is not None and i % callback_steps == 0:
                    callback(i, t, latents)

    if not output_type == "latent":
        image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
        image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
    else:
        image = latents
        has_nsfw_concept = None

    # Skip denormalisation for images the safety checker flagged.
    if has_nsfw_concept is None:
        do_denormalize = [True] * image.shape[0]
    else:
        do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]

    image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)

    # Offload last model to CPU
    if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
        self.final_offload_hook.offload()

    return image, has_nsfw_concept, all_latents, all_log_probs

-------------------------------------------------------------------------------- /scripts/train_kohya/ddpo_pytorch/prompts.py: -------------------------------------------------------------------------------- 1 | """This file defines customize prompt funtions. The prompt function which takes no arguments 2 | generates a random prompt from the given prompt distribution each time it is called.
3 | """ 4 | 5 | 6 | def easyphoto() -> str: 7 | return "easyphoto_face, easyphoto, 1person" 8 | -------------------------------------------------------------------------------- /scripts/train_kohya/ddpo_pytorch/rewards.py: -------------------------------------------------------------------------------- 1 | """This file defines customize reward funtions. The reward function which takes in a batch of images 2 | and corresponding prompts returns a batch of rewards each time it is called. 3 | """ 4 | 5 | from pathlib import Path 6 | from typing import Callable, List, Tuple, Union 7 | 8 | import numpy as np 9 | import torch 10 | from modelscope.outputs import OutputKeys 11 | from modelscope.pipelines import pipeline 12 | from modelscope.utils.constant import Tasks 13 | from PIL import Image 14 | 15 | 16 | def _convert_images(images: Union[List[Image.Image], np.array, torch.Tensor]) -> List[Image.Image]: 17 | if isinstance(images, List) and isinstance(images[0], Image.Image): 18 | return images 19 | if isinstance(images, torch.Tensor): 20 | images = (images * 255).round().clamp(0, 255).to(torch.uint8).cpu().numpy() 21 | images = images.transpose(0, 2, 3, 1) # NCHW -> NHWC 22 | images = [Image.fromarray(image) for image in images] 23 | 24 | return images 25 | 26 | 27 | def faceid_SCRFD(target_image_dir: str) -> Callable: 28 | """The closure returns the Face ID reward function given a user ID. It uses SCRFD to detect the face and then 29 | use CurricularFace to extract the face feature. SCRFD requires the extra package mmcv-full. 30 | 31 | Args: 32 | target_image_dir (str): The directory of processed face image files (.jpg) given a user ID. 
33 | """ 34 | face_recognition = pipeline(Tasks.face_recognition, model="damo/cv_ir101_facerecognition_cfglint") 35 | target_image_files = Path(target_image_dir).glob("*.jpg") 36 | target_images = [Image.open(f).convert("RGB") for f in target_image_files] 37 | target_embs = [face_recognition(t)[OutputKeys.IMG_EMBEDDING][0] for t in target_images] # (M, 512) 38 | target_mean_emb = np.mean(np.vstack(target_embs), axis=0) # (512,) 39 | 40 | # Redundant parameters are for backward compatibility. 41 | def __call__(src_images: Union[List[Image.Image], np.array, torch.Tensor], prompts: List[str]) -> Tuple[np.array, dict]: 42 | src_images = _convert_images(src_images) 43 | src_embs = [] # (N, 512) 44 | for s in src_images: 45 | try: 46 | emb = face_recognition(s)[OutputKeys.IMG_EMBEDDING][0] 47 | except Exception as e: # TypeError 48 | print("Catch Exception in the reward function faceid_retina: {}".format(e)) 49 | emb = np.array([0] * 512) 50 | print("No face is detected or the size of the detected face size is not enough. Set the embedding to zero.") 51 | finally: 52 | src_embs.append(emb) 53 | faceid_list = np.dot(src_embs, target_mean_emb) 54 | 55 | return faceid_list 56 | 57 | return __call__ 58 | 59 | 60 | def faceid_retina(target_image_dir: str) -> Callable: 61 | """The closure returns the Face ID reward function given a user ID. It uses RetinaFace to detect the face and then 62 | use CurricularFace to extract the face feature. As the detection capability of RetinaFace is weaker than SCRFD, 63 | many generated side faces cannot be detected. 64 | 65 | Args: 66 | target_image_dir (str): The directory of processed face image files (.jpg) given a user ID. 67 | """ 68 | # The retinaface detection is built into the face recognition pipeline. 
69 | face_recognition = pipeline("face_recognition", model="bubbliiiing/cv_retinafce_recognition", model_revision="v1.0.3") 70 | target_image_files = Path(target_image_dir).glob("*.jpg") 71 | target_images = [Image.open(f).convert("RGB") for f in target_image_files] 72 | 73 | target_embs = [face_recognition(dict(user=f))[OutputKeys.IMG_EMBEDDING] for f in target_images] # (M, 512) 74 | target_mean_emb = np.mean(np.vstack(target_embs), axis=0) # (512,) 75 | 76 | # Redundant parameters are for backward compatibility. 77 | def __call__(src_images: Union[List[Image.Image], np.array, torch.Tensor], prompts: List[str]) -> Tuple[np.array, dict]: 78 | src_images = _convert_images(src_images) 79 | src_embs = [] # (N, 512) 80 | for s in src_images: 81 | try: 82 | emb = face_recognition(dict(user=s))[OutputKeys.IMG_EMBEDDING][0] 83 | except Exception as e: # cv2.error; TypeError. 84 | print("Catch Exception in the reward function faceid_retina: {}".format(e)) 85 | emb = np.array([0] * 512) 86 | print("No face is detected or the size of the detected face size is not enough. Set the embedding to zero.") 87 | finally: 88 | src_embs.append(emb) 89 | faceid_list = np.dot(src_embs, target_mean_emb) 90 | 91 | return faceid_list 92 | 93 | return __call__ 94 | -------------------------------------------------------------------------------- /scripts/train_kohya/ddpo_pytorch/stat_tracking.py: -------------------------------------------------------------------------------- 1 | """Borrowed from https://github.com/kvablack/ddpo-pytorch/blob/main/ddpo_pytorch/stat_tracking.py. 2 | """ 3 | 4 | from collections import deque 5 | from typing import List 6 | 7 | import numpy as np 8 | 9 | 10 | class PerPromptStatTracker: 11 | """Track the mean and std of reward on a per-prompt basis and use that to compute advantages. 12 | 13 | Args: 14 | buffer_size (int): The number of reward values to store in the buffer for each prompt. 15 | The buffer persists across epochs. 
16 | min_count (int): The minimum number of reward values to store in the buffer before using the 17 | per-prompt mean and std. If the buffer contains fewer than `min_count` values, the mean and 18 | std of the entire batch will be used instead. 19 | """ 20 | 21 | def __init__(self, buffer_size: int, min_count: int): 22 | self.buffer_size = buffer_size 23 | self.min_count = min_count 24 | self.stats = {} 25 | 26 | def update(self, prompts: List[str], rewards: List[float]): 27 | prompts = np.array(prompts) 28 | rewards = np.array(rewards) 29 | unique = np.unique(prompts) 30 | advantages = np.empty_like(rewards) 31 | for prompt in unique: 32 | prompt_rewards = rewards[prompts == prompt] 33 | if prompt not in self.stats: 34 | self.stats[prompt] = deque(maxlen=self.buffer_size) 35 | self.stats[prompt].extend(prompt_rewards) 36 | 37 | if len(self.stats[prompt]) < self.min_count: 38 | mean = np.mean(rewards) 39 | std = np.std(rewards) + 1e-6 40 | else: 41 | mean = np.mean(self.stats[prompt]) 42 | std = np.std(self.stats[prompt]) + 1e-6 43 | advantages[prompts == prompt] = (prompt_rewards - mean) / std 44 | 45 | return advantages 46 | 47 | def get_stats(self): 48 | return {k: {"mean": np.mean(v), "std": np.std(v), "count": len(v)} for k, v in self.stats.items()} 49 | -------------------------------------------------------------------------------- /scripts/train_kohya/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aigc-apps/sd-webui-EasyPhoto/6c1f8959c4024b0f8a1187fae7c50c0c8f86a8df/scripts/train_kohya/utils/__init__.py -------------------------------------------------------------------------------- /scripts/train_kohya/utils/gpu_info.py: -------------------------------------------------------------------------------- 1 | import csv 2 | import platform 3 | import time 4 | from datetime import datetime 5 | from multiprocessing import Process, Value 6 | from os import makedirs, path 7 | 8 
| import matplotlib.pyplot as plt 9 | import matplotlib.ticker as ticker 10 | 11 | if platform.system() != "Windows": 12 | try: 13 | from nvitop import Device 14 | except Exception: 15 | Device = None 16 | 17 | # Constants 18 | BYTES_PER_GB = 1024 * 1024 * 1024 19 | 20 | 21 | def bytes_to_gb(bytes_value: int) -> float: 22 | """Convert bytes to gigabytes.""" 23 | return bytes_value / BYTES_PER_GB 24 | 25 | 26 | def log_device_info(device, prefix: str, csvwriter, display_log): 27 | """ 28 | Logs device information. 29 | 30 | Parameters: 31 | device: The device object with GPU information. 32 | prefix: The prefix string for log identification. 33 | csvwriter: The CSV writer object for writing logs to a file. 34 | display_log: print out on shell 35 | """ 36 | total_memory_gb = float(device.memory_total_human()[:-3]) 37 | used_memory_bytes = device.memory_used() 38 | gpu_utilization = device.gpu_utilization() 39 | 40 | if display_log: 41 | print(f"Device: {device.name}") 42 | print(f" - Used memory : {bytes_to_gb(used_memory_bytes):.2f} GB") 43 | print(f" - Used memory% : {bytes_to_gb(used_memory_bytes)/total_memory_gb * 100:.2f}%") 44 | print(f" - GPU utilization: {gpu_utilization}%") 45 | print("-" * 40) 46 | 47 | current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S") 48 | (bytes_to_gb(used_memory_bytes) / total_memory_gb) * 100 49 | 50 | csvwriter.writerow([current_time, bytes_to_gb(used_memory_bytes), gpu_utilization]) 51 | 52 | 53 | def monitor_and_plot(prefix="result/tmp", display_log=False, stop_flag: Value = None): 54 | """Monitor and plot GPU usage. 55 | Args: 56 | prefix: The prefix of the output file. 57 | stop_flag: A multiprocessing.Value to indicate if monitoring should stop. 
58 | """ 59 | devices = Device.all() 60 | initial_pids = set() 61 | monitored_pids = set() 62 | 63 | with open(f"{prefix}.csv", "w", newline="") as csvfile: 64 | csvwriter = csv.writer(csvfile) 65 | csvwriter.writerow(["Time", "Used Memory%", "GPU Utilization"]) 66 | 67 | try: 68 | while True: 69 | if stop_flag and stop_flag.value: 70 | break 71 | 72 | for device in devices: 73 | current_pids = set(device.processes().keys()) 74 | if not initial_pids: 75 | initial_pids = current_pids 76 | 77 | new_pids = current_pids - initial_pids 78 | if new_pids: 79 | monitored_pids.update(new_pids) 80 | 81 | for pid in monitored_pids.copy(): 82 | if pid not in current_pids: 83 | monitored_pids.remove(pid) 84 | if not monitored_pids: 85 | raise StopIteration 86 | log_device_info(device, prefix, csvwriter, display_log) 87 | time.sleep(1) 88 | except StopIteration: 89 | pass 90 | 91 | plot_data(prefix) 92 | return 93 | 94 | 95 | def plot_data(prefix): 96 | """Plot the data from the CSV file. 97 | Args: 98 | prefix: The prefix of the CSV file. 
99 | """ 100 | data = list(csv.reader(open(f"{prefix}.csv"))) 101 | if len(data) < 2: 102 | print("Insufficient data for plotting.") 103 | return 104 | 105 | time_stamps, used_memory, gpu_utilization = zip(*data[1:]) 106 | used_memory = [float(x) for x in used_memory] 107 | gpu_utilization = [float(x) for x in gpu_utilization] 108 | if len(used_memory) >= 10: 109 | tick_spacing = len(used_memory) // 10 110 | else: 111 | tick_spacing = 1 112 | 113 | try: 114 | plot_graph( 115 | time_stamps, 116 | used_memory, 117 | "Used Memory (GB)", 118 | "Time", 119 | "Used Memory (GB)", 120 | "Used Memory Over Time", 121 | tick_spacing, 122 | f"{prefix}_memory.png", 123 | ) 124 | except Exception as e: 125 | message = f"plot_graph of Memory error, error info:{str(e)}" 126 | print(message) 127 | 128 | try: 129 | plot_graph( 130 | time_stamps, 131 | gpu_utilization, 132 | "GPU Utilization (%)", 133 | "Time", 134 | "GPU Utilization (%)", 135 | "GPU Utilization Over Time", 136 | tick_spacing, 137 | f"{prefix}_utilization.png", 138 | ) 139 | except Exception as e: 140 | message = f"plot_graph of Utilization error, error info:{str(e)}" 141 | print(message) 142 | 143 | 144 | def plot_graph(x, y, label, xlabel, ylabel, title, tick_spacing, filename): 145 | """Generate and save a plot. 146 | Args: 147 | x: X-axis data. 148 | y: Y-axis data. 149 | label: The label for the plot. 150 | xlabel: Label for X-axis. 151 | ylabel: Label for Y-axis. 152 | title: The title of the plot. 153 | tick_spacing: Interval for tick marks on the x-axis. 154 | filename: The filename to save the plot. 
155 | """ 156 | plt.figure(figsize=(10, 6)) 157 | plt.plot(x, y, label=label) 158 | plt.xlabel(xlabel) 159 | plt.ylabel(ylabel) 160 | plt.title(title) 161 | plt.legend() 162 | ax = plt.gca() 163 | ax.xaxis.set_major_locator(ticker.MultipleLocator(tick_spacing)) 164 | plt.xticks(rotation=45) 165 | plt.savefig(filename) 166 | 167 | 168 | def gpu_monitor_decorator(prefix="result/gpu_info", display_log=False): 169 | def actual_decorator(func): 170 | def wrapper(*args, **kwargs): 171 | if platform.system() != "Windows" and Device is not None: 172 | timestamp = datetime.now().strftime("%Y%m%d%H%M%S") 173 | dynamic_prefix = f"{prefix}/{func.__name__}_{timestamp}" 174 | 175 | directory = path.dirname(dynamic_prefix) 176 | if not path.exists(directory): 177 | try: 178 | makedirs(directory) 179 | except Exception as e: 180 | comment = f"GPU Info record need a result/gpu_info dir in your SDWebUI, now failed with {str(e)}" 181 | print(comment) 182 | dynamic_prefix = f"{func.__name__}_{timestamp}" 183 | 184 | stop_flag = Value("b", False) 185 | 186 | monitor_proc = Process(target=monitor_and_plot, args=(dynamic_prefix, display_log, stop_flag)) 187 | monitor_proc.start() 188 | 189 | try: 190 | result = func(*args, **kwargs) 191 | finally: 192 | stop_flag.value = True 193 | monitor_proc.join() 194 | else: 195 | result = func(*args, **kwargs) 196 | return result 197 | 198 | return wrapper 199 | 200 | return actual_decorator 201 | 202 | 203 | if __name__ == "__main__": 204 | pass 205 | 206 | # Display how to define a GPU infer function and wrap with gpu_monitor_decorator 207 | @gpu_monitor_decorator() 208 | def execute_process(repeat=5): 209 | from modelscope.pipelines import pipeline 210 | from modelscope.utils.constant import Tasks 211 | 212 | retina_face_detection = pipeline(Tasks.face_detection, "damo/cv_resnet50_face-detection_retinaface") 213 | img_path = "https://modelscope.oss-cn-beijing.aliyuncs.com/test/images/retina_face_detection.jpg" 214 | 215 | for i in range(repeat): 
216 | retina_face_detection([img_path] * 10) 217 | return 218 | 219 | if 1: 220 | execute_process(repeat=5) 221 | --------------------------------------------------------------------------------