├── .github
│   └── ISSUE_TEMPLATE
│       ├── bug_report.yml
│       └── config.yml
├── .gitignore
├── .pre-commit-config.yaml
├── COVENANT.md
├── COVENANT_zh-CN.md
├── LICENSE
├── README.md
├── README_zh-CN.md
├── api_test
│   ├── README.md
│   ├── double_blind
│   │   ├── app.py
│   │   ├── default_template.json
│   │   ├── format_data2json.py
│   │   └── style.css
│   ├── post_infer.py
│   ├── post_train.py
│   └── post_video_infer.py
├── images
│   ├── controlnet_num.jpg
│   ├── ding_erweima.jpg
│   ├── double_blindui.jpg
│   ├── dsw.png
│   ├── erweima.jpg
│   ├── infer_ui.jpg
│   ├── install.jpg
│   ├── multi_people_1.jpg
│   ├── multi_people_2.jpg
│   ├── no_found_image.jpg
│   ├── overview.jpg
│   ├── results_1.jpg
│   ├── results_2.jpg
│   ├── results_3.jpg
│   ├── scene_lora
│   │   ├── Christmas_1.jpg
│   │   ├── Cyberpunk_1.jpg
│   │   ├── FairMaidenStyle_1.jpg
│   │   ├── Gentleman_1.jpg
│   │   ├── GuoFeng_1.jpg
│   │   ├── GuoFeng_2.jpg
│   │   ├── GuoFeng_3.jpg
│   │   ├── GuoFeng_4.jpg
│   │   ├── Minimalism_1.jpg
│   │   ├── NaturalWind_1.jpg
│   │   ├── Princess_1.jpg
│   │   ├── Princess_2.jpg
│   │   ├── Princess_3.jpg
│   │   ├── SchoolUniform_1.jpg
│   │   └── SchoolUniform_2.jpg
│   ├── single_people.jpg
│   ├── train_1.jpg
│   ├── train_2.jpg
│   ├── train_3.jpg
│   ├── train_detail.jpg
│   ├── train_detail1.jpg
│   ├── train_ui.jpg
│   ├── tryon
│   │   ├── cloth
│   │   │   ├── demo_black_200.jpg
│   │   │   ├── demo_dress_200.jpg
│   │   │   ├── demo_purple_200.jpg
│   │   │   ├── demo_short_200.jpg
│   │   │   └── demo_white_200.jpg
│   │   └── template
│   │       ├── boy.jpg
│   │       ├── dress.jpg
│   │       ├── girl.jpg
│   │       └── short.jpg
│   └── wechat.jpg
├── install.py
├── javascript
│   └── ui.js
├── models
│   ├── infer_templates
│   │   ├── 1.jpg
│   │   ├── 3.jpg
│   │   ├── 4.jpg
│   │   ├── 5.jpg
│   │   ├── 6.jpg
│   │   ├── 7.jpg
│   │   ├── 8.jpg
│   │   ├── 9.jpg
│   │   └── Put templates here
│   ├── stable-diffusion-v1-5
│   │   ├── model_index.json
│   │   ├── scheduler
│   │   │   └── scheduler_config.json
│   │   └── tokenizer
│   │       ├── merges.txt
│   │       ├── special_tokens_map.json
│   │       ├── tokenizer_config.json
│   │       └── vocab.json
│   ├── stable-diffusion-xl
│   │   ├── madebyollin_sdxl_vae_fp16_fix
│   │   │   └── config.json
│   │   ├── models--laion--CLIP-ViT-bigG-14-laion2B-39B-b160k
│   │   │   ├── refs
│   │   │   │   └── main
│   │   │   └── snapshots
│   │   │       └── 8c7a3583335de4dba1b07182dbf81c75137ce67b
│   │   │           ├── config.json
│   │   │           ├── merges.txt
│   │   │           ├── open_clip_config.json
│   │   │           ├── preprocessor_config.json
│   │   │           ├── pytorch_model.bin.index.json
│   │   │           ├── special_tokens_map.json
│   │   │           ├── tokenizer.json
│   │   │           ├── tokenizer_config.json
│   │   │           └── vocab.json
│   │   ├── models--openai--clip-vit-large-patch14
│   │   │   ├── refs
│   │   │   │   └── main
│   │   │   └── snapshots
│   │   │       └── 32bd64288804d66eefd0ccbe215aa642df71cc41
│   │   │           ├── config.json
│   │   │           ├── merges.txt
│   │   │           ├── preprocessor_config.json
│   │   │           ├── special_tokens_map.json
│   │   │           ├── tokenizer.json
│   │   │           ├── tokenizer_config.json
│   │   │           └── vocab.json
│   │   └── stabilityai_stable_diffusion_xl_base_1.0
│   │       └── scheduler
│   │           └── scheduler_config.json
│   └── training_templates
│       ├── 1.jpg
│       ├── 2.jpg
│       ├── 3.jpg
│       └── 4.jpg
└── scripts
    ├── api.py
    ├── easyphoto_config.py
    ├── easyphoto_infer.py
    ├── easyphoto_train.py
    ├── easyphoto_tryon_infer.py
    ├── easyphoto_ui.py
    ├── easyphoto_utils
    │   ├── __init__.py
    │   ├── animatediff
    │   │   ├── README.MD
    │   │   ├── animatediff_cn.py
    │   │   ├── animatediff_i2ibatch.py
    │   │   ├── animatediff_infotext.py
    │   │   ├── animatediff_infv2v.py
    │   │   ├── animatediff_latent.py
    │   │   ├── animatediff_lcm.py
    │   │   ├── animatediff_logger.py
    │   │   ├── animatediff_lora.py
    │   │   ├── animatediff_mm.py
    │   │   ├── animatediff_output.py
    │   │   ├── animatediff_prompt.py
    │   │   ├── animatediff_ui.py
    │   │   └── motion_module.py
    │   ├── animatediff_utils.py
    │   ├── common_utils.py
    │   ├── face_process_utils.py
    │   ├── fire_utils.py
    │   ├── loractl_utils.py
    │   ├── psgan_utils.py
    │   └── tryon_utils.py
    ├── preprocess.py
    ├── sdwebui.py
    └── train_kohya
        ├── ddpo_pytorch
        │   ├── diffusers_patch
        │   │   ├── ddim_with_logprob.py
        │   │   └── pipeline_with_logprob.py
        │   ├── prompts.py
        │   ├── rewards.py
        │   └── stat_tracking.py
        ├── train_ddpo.py
        ├── train_lora.py
        ├── train_lora_sd_XL.py
        └── utils
            ├── __init__.py
            ├── gpu_info.py
            ├── lora_utils.py
            ├── lora_utils_diffusers.py
            ├── model_utils.py
            ├── original_unet.py
            └── original_unet_sd_XL.py
/.github/ISSUE_TEMPLATE/bug_report.yml:
--------------------------------------------------------------------------------
1 | # Borrowed from sd-webui-controlnet.
2 | name: Bug Report
3 | description: Create a report
4 | title: "[Bug]: "
5 | labels: ["bug-report"]
6 |
7 | body:
8 | - type: checkboxes
9 | attributes:
10 | label: Is there an existing issue for this?
11 | description: Please search to see if an issue already exists for the bug you encountered, and that it hasn't been fixed in a recent build/commit.
12 | options:
13 | - label: I have searched the existing issues and checked the recent builds/commits of both this extension and the webui
14 | required: true
15 | - type: checkboxes
16 | attributes:
17 | label: Is EasyPhoto the latest version?
18 | description: Please check for updates in the extensions tab and reproduce the bug with the latest version.
19 | options:
20 | - label: I have updated EasyPhoto to the latest version and the bug still exists.
21 | required: true
22 | - type: markdown
23 | attributes:
24 | value: |
25 | **Please fill this form with as much information as possible; don't forget to fill in "What OS..." and "What browsers", and provide screenshots if possible.**
26 | - type: textarea
27 | id: what-did
28 | attributes:
29 | label: What happened?
30 | description: Tell us what happened in a very clear and simple way
31 | validations:
32 | required: true
33 | - type: textarea
34 | id: steps
35 | attributes:
36 | label: Steps to reproduce the problem
37 | description: Please provide us with precise step by step information on how to reproduce the bug
38 | value: |
39 | 1. Go to ....
40 | 2. Press ....
41 | 3. ...
42 | validations:
43 | required: true
44 | - type: textarea
45 | id: what-should
46 | attributes:
47 | label: What should have happened?
48 | description: Tell us what you think the normal behavior should be
49 | validations:
50 | required: true
51 | - type: textarea
52 | id: commits
53 | attributes:
54 | label: Commit where the problem happens
55 | description: Which commit of the extension are you running on? Please include the commit of both the extension and the webui (Do not write *Latest version/repo/commit*, as this means nothing and will have changed by the time we read your issue. Rather, copy the **Commit** link at the bottom of the UI, or from the cmd/terminal if you can't launch it.)
56 | value: |
57 | webui:
58 | EasyPhoto:
59 | validations:
60 | required: true
61 | - type: dropdown
62 | id: browsers
63 | attributes:
64 | label: What browsers do you use to access the UI?
65 | multiple: true
66 | options:
67 | - Mozilla Firefox
68 | - Google Chrome
69 | - Brave
70 | - Apple Safari
71 | - Microsoft Edge
72 | - type: textarea
73 | id: cmdargs
74 | attributes:
75 | label: Command Line Arguments
76 | description: Are you using any launching parameters/command line arguments (modified webui-user.bat/.sh)? If yes, please write them below. Write "No" otherwise.
77 | render: Shell
78 | validations:
79 | required: true
80 | - type: textarea
81 | id: extensions
82 | attributes:
83 | label: List of enabled extensions
84 | description: Please provide a full list of enabled extensions or screenshots of your "Extensions" tab.
85 | validations:
86 | required: true
87 | - type: textarea
88 | id: logs
89 | attributes:
90 | label: Console logs
91 | description: Please provide full cmd/terminal logs from the moment you started UI to the end of it, after your bug happened. If it's very long, provide a link to pastebin or similar service.
92 | render: Shell
93 | validations:
94 | required: true
95 | - type: textarea
96 | id: misc
97 | attributes:
98 | label: Additional information
99 | description: Please provide us with any relevant additional info or context.
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/config.yml:
--------------------------------------------------------------------------------
1 | blank_issues_enabled: true
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # poetry
98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more
100 | # commonly ignored for libraries.
101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102 | #poetry.lock
103 |
104 | # pdm
105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106 | #pdm.lock
107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108 | # in version control.
109 | # https://pdm.fming.dev/#use-with-ide
110 | .pdm.toml
111 |
112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
113 | __pypackages__/
114 |
115 | # Celery stuff
116 | celerybeat-schedule
117 | celerybeat.pid
118 |
119 | # SageMath parsed files
120 | *.sage.py
121 |
122 | # Environments
123 | .env
124 | .venv
125 | env/
126 | venv/
127 | ENV/
128 | env.bak/
129 | venv.bak/
130 |
131 | # Spyder project settings
132 | .spyderproject
133 | .spyproject
134 |
135 | # Rope project settings
136 | .ropeproject
137 |
138 | # mkdocs documentation
139 | /site
140 |
141 | # mypy
142 | .mypy_cache/
143 | .dmypy.json
144 | dmypy.json
145 |
146 | # Pyre type checker
147 | .pyre/
148 |
149 | # pytype static type analyzer
150 | .pytype/
151 |
152 | # Cython debug symbols
153 | cython_debug/
154 |
155 | # PyCharm
156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
158 | # and can be added to the global gitignore or merged into this file. For a more nuclear
159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
160 | #.idea/
161 | # weight
162 | *.pth
163 | *.onnx
164 | *.safetensors
165 | *.ckpt
166 | *.bin
167 | *.pkl
168 | *.jpg
169 | *.png
170 |
171 | models/stable-diffusion-xl/version.txt
172 | models/pose_templates
173 | scripts/thirdparty
174 |
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | repos:
2 | - repo: https://github.com/psf/black
3 | rev: 22.3.0
4 | hooks:
5 | - id: black
6 | exclude: models/|scripts/easyphoto_utils/animatediff/|scripts/easyphoto_utils/animatediff_utils.py
7 | args: ["--line-length=140"]
8 | - repo: https://github.com/PyCQA/flake8
9 | rev: 3.9.2
10 | hooks:
11 | - id: flake8
12 | exclude: models/|scripts/easyphoto_utils/animatediff/|scripts/easyphoto_utils/animatediff_utils.py
13 | args: ["--max-line-length=140", "--ignore=E303,E731,W191,W504,E402,E203,F541,W605,W503,E501,E712,F401", "--exclude=__init__.py"]
14 | - repo: https://github.com/myint/autoflake
15 | rev: v1.4
16 | hooks:
17 | - id: autoflake
18 | exclude: models/|scripts/easyphoto_utils/animatediff/|scripts/easyphoto_utils/animatediff_utils.py
19 | args:
20 | [
21 | "--recursive",
22 | "--in-place",
23 | "--remove-unused-variable",
24 | "--ignore-init-module-imports",
25 | "--exclude=__init__.py"
26 | ]
27 | - repo: https://github.com/pre-commit/pre-commit-hooks
28 | rev: v4.4.0
29 | hooks:
30 | - id: check-ast
31 | exclude: models/|scripts/easyphoto_utils/animatediff/|scripts/easyphoto_utils/animatediff_utils.py
32 | - id: check-byte-order-marker
33 | exclude: models/|scripts/easyphoto_utils/animatediff/|scripts/easyphoto_utils/animatediff_utils.py
34 | - id: check-case-conflict
35 | exclude: models/|scripts/easyphoto_utils/animatediff/|scripts/easyphoto_utils/animatediff_utils.py
36 | - id: check-docstring-first
37 | exclude: models/|scripts/easyphoto_utils/animatediff/|scripts/easyphoto_utils/animatediff_utils.py
38 | - id: check-executables-have-shebangs
39 | exclude: models/|scripts/easyphoto_utils/animatediff/|scripts/easyphoto_utils/animatediff_utils.py
40 | - id: check-json
41 | exclude: models/|scripts/easyphoto_utils/animatediff/|scripts/easyphoto_utils/animatediff_utils.py
42 | - id: check-yaml
43 | exclude: models/|scripts/easyphoto_utils/animatediff/|scripts/easyphoto_utils/animatediff_utils.py
44 | - id: debug-statements
45 | exclude: models/|scripts/easyphoto_utils/animatediff/|scripts/easyphoto_utils/animatediff_utils.py
46 | - id: detect-private-key
47 | exclude: models/|scripts/easyphoto_utils/animatediff/|scripts/easyphoto_utils/animatediff_utils.py
48 | - id: end-of-file-fixer
49 | exclude: models/|scripts/easyphoto_utils/animatediff/|scripts/easyphoto_utils/animatediff_utils.py
50 | - id: trailing-whitespace
51 | exclude: models/|scripts/easyphoto_utils/animatediff/|scripts/easyphoto_utils/animatediff_utils.py
52 | - id: mixed-line-ending
53 | exclude: models/|scripts/easyphoto_utils/animatediff/|scripts/easyphoto_utils/animatediff_utils.py
54 |
--------------------------------------------------------------------------------
/COVENANT.md:
--------------------------------------------------------------------------------
1 | # EasyPhoto Developer Covenant
2 | Disclaimer: This covenant serves as a set of recommended guidelines.
3 |
4 | ## Overview:
5 | EasyPhoto is open-source software built on the SDWebUI plugin ecosystem, focused on leveraging AIGC technology to create AI portraits that are true to life, resemble the user, and look beautiful ("true/like/beautiful"). We are committed to expanding the application scope of this technology, lowering the entry barrier, and making it easy for a wide audience to use.
6 |
7 | ## Covenant Purpose:
8 | Although technology is inherently neutral in value, considering that EasyPhoto already has the capability to produce highly realistic images, particularly facial images, we strongly suggest that everyone involved in development and usage adhere to the following guidelines.
9 |
10 | ## Behavioral Guidelines:
11 | - Comply with the laws and regulations of relevant jurisdictions: It is prohibited to use this technology for any unlawful or criminal activities, or for activities against public morals and decency.
12 | - Content Restrictions: It is prohibited to produce or disseminate any images involving political figures, pornography, violence, or other content contrary to regional regulations.
13 |
14 | ## Ongoing Updates:
15 | This covenant will be updated periodically to adapt to technological and social advancements. We encourage community members to follow these guidelines in daily interactions and usage. Non-compliance will result in appropriate community management actions.
16 |
17 | Thank you for your cooperation and support. Together, let's ensure that EasyPhoto remains a responsible and sustainably developed open-source project.
18 |
--------------------------------------------------------------------------------
/COVENANT_zh-CN.md:
--------------------------------------------------------------------------------
1 | # EasyPhoto 开发者公约
2 | !声明:本公约仅为推荐性准则
3 |
4 | ## 概述:
5 | EasyPhoto 是一个基于SDWebUI插件生态构建的开源软件,专注于利用AIGC技术实现真/像/美的AI-写真。我们致力于拓展该技术的应用范围,降低使用门槛,并为广大用户提供便利。
6 |
7 | ## 公约宗旨:
8 | 尽管技术本身并无价值倾向,但考虑到EasyPhoto目前已具备生成逼真图像(特别是人脸图像)的能力,我们强烈建议所有参与开发和使用的人员遵循以下准则。
9 |
10 | ## 行为准则:
11 | - 遵循相关地区的法律和法规:不得利用本技术从事任何违法、犯罪或有悖于社会公序良俗的活动。
12 | - 内容限制:禁止生成或传播任何可能涉及政治人物、色情、暴力或其他违反相关地区规定的图像。
13 |
14 | ## 持续更新:
15 | 本公约将不定期进行更新以适应技术和社会发展。我们鼓励社群成员在日常交流和使用中遵循这些准则。未遵守本公约的行为将在社群管理中受到相应限制。
16 | 感谢您的配合与支持。我们共同努力,以确保EasyPhoto成为一个负责任和可持续发展的开源软件。
17 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by the Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributors that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, the Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assuming any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/README_zh-CN.md:
--------------------------------------------------------------------------------
1 | # EasyPhoto | 您的智能 AI 照片生成器。
2 | 🦜 EasyPhoto 是一款 WebUI 插件,用于生成AI肖像画,该代码可用于训练与您相关的数字分身。
3 |
4 | 🦜 🦜 Welcome!
5 |
6 | [🤗 Hugging Face Spaces](https://huggingface.co/spaces/alibaba-pai/easyphoto)
7 |
8 | [English](./README.md) | 简体中文
9 |
10 | # 目录
11 | - [简介](#简介)
12 | - [TODO List](#todo-list)
13 | - [快速启动](#快速启动)
14 | - [1. 云使用: AliyunDSW/AutoDL/Docker](#1-云使用-aliyundswautodldocker)
15 | - [2. 本地安装: 环境检查/下载/安装](#2-本地安装-环境检查下载安装)
16 | - [如何使用](#如何使用)
17 | - [1. 模型训练](#1-模型训练)
18 | - [2. 人物生成](#2-人物生成)
19 | - [API测试](./api_test/README.md)
20 | - [算法详细信息](#算法详细信息)
21 | - [1. 架构概述](#1-架构概述)
22 | - [2. 训练细节](#2-训练细节)
23 | - [3. 推理细节](#3-推理细节)
24 | - [参考文献](#参考文献)
25 | - [相关项目](#相关项目)
26 | - [许可证](#许可证)
27 | - [联系我们](#联系我们)
28 |
29 | # 简介
30 | EasyPhoto 是一款 WebUI 插件,用于生成AI肖像画,该代码可用于训练与您相关的数字分身。建议使用 5 到 20 张肖像图片进行训练,最好是半身照片且不要佩戴眼镜(少量可以接受)。训练完成后,我们可以在推理部分生成图像。我们支持使用预设模板图片与上传自己的图片进行推理。
31 |
32 | 请阅读我们的开发者公约,共建美好社区 [covenant](./COVENANT.md) | [简体中文](./COVENANT_zh-CN.md)
33 |
34 | 如果您在训练中遇到一些问题,请参考 [VQA](https://github.com/aigc-apps/sd-webui-EasyPhoto/wiki)。
35 |
36 | 我们现在支持从不同平台快速启动,请参阅 [快速启动](#快速启动)。
37 |
38 | 新特性:
39 | - 支持基于LCM-Lora的采样加速,现在您只需要进行12个steps(vs 50steps)来生成图像和视频, 并支持了场景化(风格化) Lora的训练和大量内置的模型。[🔥 🔥 🔥 🔥 2023.12.09]
40 | - 支持基于Concepts-Sliders的属性编辑和虚拟试穿,请参考[sliders-wiki](https://github.com/aigc-apps/sd-webui-EasyPhoto/wiki/Attribute-Edit) , [tryon-wiki](https://github.com/aigc-apps/sd-webui-EasyPhoto/wiki/TryOn)获取更多详细信息。[🔥 🔥 🔥 🔥 2023.12.08]
41 | - 感谢[揽睿星舟](https://www.lanrui-ai.com/) 提供了内置EasyPhoto的SDWebUI官方镜像,并承诺每两周更新一次。亲自测试,可以在2分钟内拉起资源,并在5分钟内完成启动。[🔥 🔥 🔥 🔥 2023.11.20]
42 | - ComfyUI 支持 [repo](https://github.com/THtianhao/ComfyUI-Portrait-Maker), 感谢[THtianhao](https://github.com/THtianhao)的精彩工作![🔥 🔥 🔥 2023.10.17]
43 | - EasyPhoto 论文地址 [arxiv](https://arxiv.org/abs/2310.04672)[🔥 🔥 🔥 2023.10.10]
44 | - 支持使用SDXL模型和一定的选项直接生成高清大图,不再需要上传模板,需要16GB显存。具体细节可以前往[这里](https://zhuanlan.zhihu.com/p/658940203)[🔥 🔥 🔥 2023.09.26]
45 | - 我们同样支持[Diffusers版本](https://github.com/aigc-apps/EasyPhoto/)。 [🔥 2023.09.25]
46 | - **支持对背景进行微调,并计算生成的图像与用户之间的相似度得分。** [🔥🔥 2023.09.15]
47 | - **支持不同预测基础模型。** [🔥🔥 2023.09.08]
48 | - **支持多人生成!添加缓存选项以优化推理速度。在UI上添加日志刷新。** [🔥🔥 2023.09.06]
49 | - 创建代码!现在支持 Windows 和 Linux。[🔥 2023.09.02]
50 |
51 | 这些是我们的生成结果:
52 | ![results_1](images/results_1.jpg)
53 | ![results_2](images/results_2.jpg)
54 | ![results_3](images/results_3.jpg)
55 |
56 | 我们的ui界面如下:
57 | **训练部分:**
58 | ![train_ui](images/train_ui.jpg)
59 | **预测部分:**
60 | ![infer_ui](images/infer_ui.jpg)
61 |
62 | # TODO List
63 | - 支持中文界面。
64 | - 支持模板背景部分变化。
65 | - 支持高分辨率。
66 |
67 | # 快速启动
68 | ### 1. 云使用: AliyunDSW/AutoDL/揽睿星舟/Docker
69 | #### a. 通过阿里云 DSW
70 | DSW 有免费 GPU 时间,用户可申请一次,申请后3个月内有效。
71 |
72 | 阿里云在[Freetier](https://free.aliyun.com/?product=9602825&crowd=enterprise&spm=5176.28055625.J_5831864660.1.e939154aRgha4e&scm=20140722.M_9974135.P_110.MO_1806-ID_9974135-MID_9974135-CID_30683-ST_8512-V_1)提供免费GPU时间,获取并在阿里云PAI-DSW中使用,3分钟内即可启动EasyPhoto
73 |
74 | [![DSW Notebook](images/dsw.png)](https://gallery.pai-ml.com/#/preview/deepLearning/cv/stable_diffusion_easyphoto)
75 |
76 | #### b. 通过揽睿星舟/AutoDL
77 | ##### 揽睿星舟
78 | 揽睿星舟官方全插件版本内置EasyPhoto,并承诺每两周测试与更新,亲测可用,5分钟内拉起,感谢他们的支持和对社区做出的贡献。
79 |
80 | ##### AutoDL
81 | 如果您正在使用 AutoDL,您可以使用我们提供的镜像快速启动 Stable Diffusion WebUI。
82 |
83 | 您可以在社区镜像中填写以下信息来选择所需的镜像。
84 | ```
85 | aigc-apps/sd-webui-EasyPhoto/sd-webui-EasyPhoto
86 | ```
87 | #### c. 通过docker
88 | 使用docker的情况下,请保证机器中已经正确安装显卡驱动与CUDA环境,然后依次执行以下命令:
89 | ```
90 | # 拉取镜像
91 | docker pull registry.cn-beijing.aliyuncs.com/mybigpai/sd-webui-easyphoto:0.0.3
92 |
93 | # 进入镜像
94 | docker run -it -p 7860:7860 --network host --gpus all registry.cn-beijing.aliyuncs.com/mybigpai/sd-webui-easyphoto:0.0.3
95 |
96 | # 启动webui
97 | python3 launch.py --port 7860
98 | ```
99 |
100 | ### 2. 本地安装: 环境检查/下载/安装
101 | #### a. 环境检查
102 | 我们已验证EasyPhoto可在以下环境中执行:
103 | 如果你遇到内存使用过高而导致WebUI进程自动被kill掉,请参考[ISSUE21](https://github.com/aigc-apps/sd-webui-EasyPhoto/issues/21),设置一些参数,例如num_threads=0,如果你也发现了其他解决的好办法,请及时联系我们。
104 |
105 | Windows 10 的详细信息:
106 | - 操作系统: Windows10
107 | - python: python 3.10
108 | - pytorch: torch2.0.1
109 | - tensorflow-cpu: 2.13.0
110 | - CUDA: 11.7
111 | - CUDNN: 8+
112 | - GPU: Nvidia-3060 12G
113 |
114 | Linux 的详细信息:
115 | - 操作系统 Ubuntu 20.04, CentOS
116 | - python: python3.10 & python3.11
117 | - pytorch: torch2.0.1
118 | - tensorflow-cpu: 2.13.0
119 | - CUDA: 11.7
120 | - CUDNN: 8+
121 | - GPU: Nvidia-A10 24G & Nvidia-V100 16G & Nvidia-A100 40G
122 |
123 | 我们需要大约 60GB 的可用磁盘空间(用于保存权重和数据集),请检查!
124 |
125 | #### b. 相关资料库和权重下载
126 | ##### i. Controlnet
127 | 我们需要使用 Controlnet 进行推理。相关软件源是[Mikubill/sd-webui-controlnet](https://github.com/Mikubill/sd-webui-controlnet)。在使用 EasyPhoto 之前,您需要安装这个软件源。
128 |
129 | 此外,我们至少需要三个 Controlnets 用于推理。因此,您需要设置 **Multi ControlNet: Max models amount (requires restart)**。
130 | ![controlnet_num](images/controlnet_num.jpg)
131 |
132 | ##### ii. 其他依赖关系。
133 | 我们与现有的 stable-diffusion-webui 环境相互兼容,启动 stable-diffusion-webui 时会安装相关软件源。
134 |
135 | 我们所需的权重会在第一次开始训练时自动下载。
136 |
137 | #### c. 插件安装
138 | 现在我们支持从 git 安装 EasyPhoto。我们的仓库网址是 https://github.com/aigc-apps/sd-webui-EasyPhoto。
139 |
140 | 今后,我们将支持从 **Available** 安装 EasyPhoto。
141 | ![install](images/install.jpg)
142 | 
143 |
144 | # 如何使用
145 | ### 1. 模型训练
146 | EasyPhoto训练界面如下:
147 | - 左边是训练图像。只需点击上传照片即可上传图片,点击清除照片即可删除上传的图片;
148 | - 右边是训练参数,第一次训练时无需进行调整。
149 |
150 | 点击上传照片后,我们可以开始上传图像。**这里最好上传5到20张图像,包括不同的角度和光照**,最好有一些不包含眼镜的图像。如果所有图片都包含眼镜,则生成的结果也容易带有眼镜。
151 | ![train_1](images/train_1.jpg)
152 |
153 | 然后我们点击下面的"开始训练",此时,我们需要填写上面的用户ID(例如用户名),才能开始训练。
154 | ![train_2](images/train_2.jpg)
155 |
156 | 模型开始训练后,webui会自动刷新训练日志。如果没有刷新,请单击“Refresh Log”按钮。
157 | ![train_3](images/train_3.jpg)
158 |
159 | 如果要设置参数,每个参数的解析如下(表格之后附有与训练接口字段的对应示意):
160 | | 参数名 | 含义 |
161 | |--|--|
162 | | resolution | 训练时喂入网络的图片大小,默认值为512 |
163 | | validation & save steps| 验证图片与保存中间权重的steps数,默认值为100,代表每100步验证一次图片并保存权重 |
164 | | max train steps | 最大训练步数,默认值为800 |
165 | | max steps per photos | 每张图片的最大训练次数,默认为200 |
166 | | train batch size | 训练的批次大小,默认值为1 |
167 | | gradient accumulation steps | 梯度累计的步数,默认值为4,结合train batch size来看,每个Step相当于喂入四张图片 |
168 | | dataloader num workers | 数据加载的workers数量,Windows下不生效(设置后会报错),Linux下可正常设置 |
169 | | learning rate | 训练Lora的学习率,默认为1e-4 |
170 | | rank Lora | 权重的特征长度,默认为128 |
171 | | network alpha | Lora训练的正则化参数,一般为rank的二分之一,默认为64 |
172 |
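上表中的参数与 `api_test/post_train.py` 中训练接口请求体的字段一一对应,下面的字典给出两者的对应关系(取值为该脚本中的默认值,仅作示意):

```python
# 训练参数与 /easyphoto/easyphoto_train_forward 请求体字段的对应关系(取值来自 api_test/post_train.py)
train_params = {
    "resolution": 512,                   # resolution
    "val_and_checkpointing_steps": 100,  # validation & save steps
    "max_train_steps": 800,              # max train steps
    "steps_per_photos": 200,             # max steps per photos
    "train_batch_size": 1,               # train batch size
    "gradient_accumulation_steps": 4,    # gradient accumulation steps
    "dataloader_num_workers": 16,        # dataloader num workers
    "learning_rate": 1e-4,               # learning rate
    "rank": 128,                         # rank Lora
    "network_alpha": 64,                 # network alpha
}
```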
173 | ### 2. 人物生成
174 | #### a. 单人模版
175 | - 步骤1:点击刷新按钮,查询训练后的用户ID对应的模型。
176 | - 步骤2:选择用户ID。
177 | - 步骤3:选择需要生成的模板。
178 | - 步骤4:单击“生成”按钮生成结果。
179 |
180 | ![single_people](images/single_people.jpg)
181 |
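除了在 UI 上操作,也可以直接调用推理接口生成图像。下面是一个最小调用示意(仅列出部分字段;字段能否省略取决于接口默认值,完整参数请以 `api_test/post_infer.py` 为准):

```python
# 最小推理 API 调用示意:完整字段见 api_test/post_infer.py
import base64
import json

import requests

with open("template.jpg", "rb") as f:  # 模板图片路径仅为示例
    encoded_image = base64.b64encode(f.read()).decode("utf-8")

payload = {
    "user_ids": ["your_id"],  # 训练时填写的用户ID
    "sd_model_checkpoint": "Chilloutmix-Ni-pruned-fp16-fix.safetensors",
    "init_image": encoded_image,
}
r = requests.post("http://0.0.0.0:7860/easyphoto/easyphoto_infer_forward", data=json.dumps(payload), timeout=1500)
result = json.loads(r.content.decode("utf-8"))
print(result["message"])  # 生成图像以 base64 形式放在返回值的 outputs 字段中
```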
182 | #### b. 多人模板
183 | - 步骤1:转到EasyPhoto的设置页面,设置num_of_Faceid大于1。
184 | - 步骤2:应用设置。
185 | - 步骤3:重新启动webui的ui界面。
186 | - 步骤4:返回EasyPhoto并上传多人模板。
187 | - 步骤5:选择两个人的用户ID。
188 | - 步骤6:单击“生成”按钮。执行图像生成。
189 |
190 | ![multi_people_1](images/multi_people_1.jpg)
191 | ![multi_people_2](images/multi_people_2.jpg)
192 |
193 | # 算法详细信息
194 | - 英文论文[arxiv](https://arxiv.org/abs/2310.04672)
195 | - 中文博客[这里](https://blog.csdn.net/weixin_44791964/article/details/132922309)
196 |
197 | ### 1. 架构概述
198 |
199 | ![overview](images/overview.jpg)
200 |
201 | 在人工智能肖像领域,我们希望模型生成的图像逼真且与用户相似,而传统方法会引入不真实的光照(如人脸融合或roop)。为了解决这种不真实的问题,我们引入了稳定扩散模型的图像到图像功能。生成完美的个人肖像需要考虑所需的生成场景和用户的数字分身。我们使用一个预先准备好的模板作为所需的生成场景,并使用一个在线训练的人脸 LoRA 模型作为用户的数字分身,这是一种流行的稳定扩散微调模型。我们使用少量用户图像来训练用户的稳定数字分身,并在推理过程中根据人脸 LoRA 模型和预期生成场景生成个人肖像图像。
202 |
203 |
204 | ### 2. 训练细节
205 |
206 | ![train_detail](images/train_detail.jpg)
207 |
208 | 首先,我们对输入的用户图像进行人脸检测,确定人脸位置后,按照一定比例截取输入图像。然后,我们使用显著性检测模型和皮肤美化模型获得干净的人脸训练图像,该图像基本上只包含人脸。然后,我们为每张图像贴上一个固定标签。这里不需要使用标签器,而且效果很好。最后,我们对稳定扩散模型进行微调,得到用户的数字分身。
209 |
210 | 在训练过程中,我们会利用模板图像进行实时验证,在训练结束后,我们会计算验证图像与用户图像之间的人脸 ID 差距,从而实现 Lora 融合,确保我们的 Lora 是用户的完美数字分身。
211 |
212 |
213 | 此外,我们将选择验证中与用户最相似的图像作为 face_id 图像,用于推理。
214 |
215 | ### 3. 推理细节
216 | #### a. 第一次扩散:
217 | 首先,我们将对接收到的模板图像进行人脸检测,以确定为实现稳定扩散而需要涂抹的遮罩。然后,我们将使用模板图像与最佳用户图像进行人脸融合。人脸融合完成后,我们将使用上述遮罩对融合后的人脸图像进行内绘(fusion_image)。此外,我们还将通过仿射变换(replace_image)把训练中获得的最佳 face_id 图像贴到模板图像上。然后,我们将对其应用 Controlnets,在融合图像中使用带有颜色的 canny 提取特征,在替换图像中使用 openpose 提取特征,以确保图像的相似性和稳定性。然后,我们将使用稳定扩散(Stable Diffusion)结合用户的数字分身进行生成。
218 |
219 | #### b. 第二次扩散:
220 | 在得到第一次扩散的结果后,我们将把该结果与最佳用户图像进行人脸融合,然后再次使用稳定扩散与用户的数字分身进行生成。第二次生成将使用更高的分辨率。
221 |
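上述两次扩散的步数与重绘强度,在推理接口中分别由以下字段控制(取值为 `api_test/post_infer.py` 中的默认值,仅作示意):

```python
# 两次扩散对应的推理请求字段(取值来自 api_test/post_infer.py)
two_pass_params = {
    "first_diffusion_steps": 50,        # 第一次扩散的采样步数
    "first_denoising_strength": 0.45,   # 第一次扩散的重绘强度
    "second_diffusion_steps": 20,       # 第二次扩散的采样步数
    "second_denoising_strength": 0.35,  # 第二次扩散的重绘强度(分辨率更高,强度更低)
}
```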
222 | # 特别感谢
223 | 特别感谢DevelopmentZheng, qiuyanxin, rainlee, jhuang1207, bubbliiiing, wuziheng, yjjinjie, hkunzhe, yunkchen同学们的代码贡献(此排名不分先后)。
224 |
225 | # 参考文献
226 | - insightface:https://github.com/deepinsight/insightface
227 | - cv_resnet50_face:https://www.modelscope.cn/models/damo/cv_resnet50_face-detection_retinaface/summary
228 | - cv_u2net_salient:https://www.modelscope.cn/models/damo/cv_u2net_salient-detection/summary
229 | - cv_unet_skin_retouching_torch:https://www.modelscope.cn/models/damo/cv_unet_skin_retouching_torch/summary
230 | - cv_unet-image-face-fusion:https://www.modelscope.cn/models/damo/cv_unet-image-face-fusion_damo/summary
231 | - kohya:https://github.com/bmaltais/kohya_ss
232 | - controlnet-webui:https://github.com/Mikubill/sd-webui-controlnet
233 |
234 | # 相关项目
235 | 我们还列出了一些很棒的开源项目以及任何你可能会感兴趣的扩展项目:
236 | - [ModelScope](https://github.com/modelscope/modelscope).
237 | - [FaceChain](https://github.com/modelscope/facechain).
238 | - [sd-webui-controlnet](https://github.com/Mikubill/sd-webui-controlnet).
239 | - [sd-webui-roop](https://github.com/s0md3v/sd-webui-roop).
240 | - [roop](https://github.com/s0md3v/roop).
241 | - [sd-webui-deforum](https://github.com/deforum-art/sd-webui-deforum).
242 | - [sd-webui-additional-networks](https://github.com/kohya-ss/sd-webui-additional-networks).
243 | - [a1111-sd-webui-tagcomplete](https://github.com/DominikDoom/a1111-sd-webui-tagcomplete).
244 | - [sd-webui-segment-anything](https://github.com/continue-revolution/sd-webui-segment-anything).
245 | - [sd-webui-tunnels](https://github.com/Bing-su/sd-webui-tunnels).
246 | - [sd-webui-mov2mov](https://github.com/Scholar01/sd-webui-mov2mov).
247 |
248 | # 许可证
249 | 本项目采用 [Apache License (Version 2.0)](https://github.com/modelscope/modelscope/blob/master/LICENSE).
250 |
251 | # 联系我们
252 | 1. 使用[钉钉](https://www.dingtalk.com/)搜索2群54095000124或扫描下列二维码加入群聊
253 | 2. 由于微信群已经满了,需要先扫描右边的图片添加这位同学为好友,然后再加入微信群。
254 |
255 | ![钉钉二维码](images/ding_erweima.jpg)
256 | ![微信](images/wechat.jpg)
257 |
258 |
--------------------------------------------------------------------------------
/api_test/README.md:
--------------------------------------------------------------------------------
1 | # Test & Profile-23/10/09
2 | ## 环境准备
3 | - 请搭建一个待测试的SDWebUI环境
4 | - 保证可以访问到上述SDWebUI,准备python环境,满足依赖 base64/json/numpy/cv2
5 |
6 | ## 训练/推理测试代码
7 | - **post_train.py** 支持公网URL图片/本地图片读取
8 | - **post_infer.py** 支持公网URL图片/本地图片读取
9 | 代码提供了默认URL,可修改,本地图片通过命令行参数输入,调用示例见下方。
10 |
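例如,使用本地图片目录发起训练(此处假设 post_train.py 的第一个命令行参数即为本地训练图片目录,实际请以脚本实现为准;post_infer.py 的调用示例见下文双盲测试部分):

```python
python3 post_train.py ./train_images
```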
11 | ## 双盲测试
12 | 基于上述的推理代码,我们可以对预定模板和预定人物的 Lora 进行批量测试图片生成,形成某个版本的记录,并基于此对两个版本代码的生成结果进行双盲测试。下面我们用一个简单的例子说明双盲测试流程,打开后的 UI 如下图:
13 |
14 | ![double_blind_ui](../images/double_blindui.jpg)
15 |
16 | #### 步骤
17 | - 按照环境准备小节搭建环境,并在该环境中准备好相关 user_id 的模型
18 | - 准备预设模板和上面user_id对应的真人图片
19 | - templates
20 | - 1.jpg
21 | - 2.jpg
22 | - ref_image
23 | - id1.jpg
24 | - id2.jpg
25 | - 运行批量推理代码: version1, version2 需要分别推理一次。
26 | ```python
27 | python3 post_infer.py --template_dir templates --output_path test_data/version1 --user_ids your_id
28 | ```
29 | - 运行数据整理代码
30 | ```python
31 | python3 ./double_blind/format_data2json.py --ref_images ref_image --version1_dir test_data/version1 --version2_dir test_data/version2 --output_json test_v1_v2.json
32 | ```
33 | - 运行 ./double_blind/app.py 打开如下双盲测试页面。请注意,不同版本的 gradio 兼容性差异较大,我们验证可稳定运行的版本是 gradio==3.48.0,否则可能出现 gradio BarPlot 等组件的异常。
34 | ```python
35 | python3 ./double_blind/app.py --data-path test_v1_v2.json --result-path ./result.jsonl
36 | ```
37 | 运行上述代码后,即可得到上述测试页面。如果部署在绑定了域名的机器上,则可分享相关测试域名(待补充),并统计 version1 和 version2 的 WinningRate,作为 PR 的记录。
38 |
--------------------------------------------------------------------------------
/api_test/double_blind/app.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import json
3 | import os
4 | import random
5 |
6 | import gradio as gr
7 | import pandas as pd
8 |
9 |
10 | def read_json(file_path: str):
11 | return json.load(open(file_path))
12 |
13 |
14 | def read_jsonl(file_path: str):
15 | data = []
16 | with open(file_path, "r") as file:
17 | for line in file:
18 | data.append(json.loads(line))
19 | return data
20 |
21 |
22 | def write_jsonl(data: any, file_path: str):
23 | if not os.path.exists(file_path):
24 | with open(file_path, "w") as f:
25 | pass
26 | with open(file_path, "a") as f:
27 | json.dump(data, f, ensure_ascii=False)
28 | f.write("\n")
29 |
30 |
31 | def save_result(id, submit_cnt, ids, ids_list, id2data, results, *eval_results):
32 |
33 | if not all(eval_results):
34 | gr.Warning("请完整填写所有问题的答案。\nPlease complete the answers to all questions.")
35 | return next_item(ids, ids_list, id2data, results) + (submit_cnt,)
36 |
37 | if id is None:
38 | gr.Info("感谢您参与EasyPhoto的评测,本次评测已全部完成~🥰\nThank you for participating in the EasyPhoto review, this review is complete ~🥰")
39 | return None, [], None, None, draw_results(results, ids_list), ids, ids_list, id2data, results, submit_cnt
40 |
41 | if id in ids:
42 | ids.remove(id)
43 | item = id2data[id]
44 | result = {"id": id, "questions": template["questions"], "answers": []}
45 | for r in eval_results:
46 | if r == "持平/Tie":
47 | result["answers"].append("tie")
48 | elif r == "左边/Left":
49 | result["answers"].append("method1" if item["left"] == "img1" else "method2")
50 | elif r == "右边/Right":
51 | result["answers"].append("method2" if item["left"] == "img1" else "method1")
52 |
53 | results.append(result)
54 | write_jsonl(result, args.result_path)
55 |
56 | return next_item(ids, ids_list, id2data, results) + (submit_cnt + 1,)
57 |
58 |
59 | def next_item(ids, ids_list, id2data, results):
60 |
61 | if len(ids) <= 0:
62 | gr.Info("感谢您参与EasyPhoto的评测,本次评测已全部完成~🥰\nThank you for participating in the EasyPhoto review, this review is complete ~🥰")
63 | return None, [], None, None, draw_results(results, ids_list), ids, ids_list, id2data, results
64 |
65 | id = random.choice(list(ids))
66 |
67 | if random.random() < 0.5:
68 | id2data[id]["left"] = "img1"
69 | left_img = id2data[id]["img1"]
70 | right_img = id2data[id]["img2"]
71 | else:
72 | id2data[id]["left"] = "img2"
73 | left_img = id2data[id]["img2"]
74 | right_img = id2data[id]["img1"]
75 |
76 | item = id2data[id]
77 |
78 | return (
79 | item["id"],
80 | [(x, "") for x in item["reference_imgs"]],
81 | left_img,
82 | right_img,
83 | draw_results(results, ids_list),
84 | ids,
85 | ids_list,
86 | id2data,
87 | results,
88 | )
89 |
90 |
91 | def draw_results(results, ids_list):
92 |
93 | if len(results) < len(ids_list):
94 | return None
95 | else:
96 |
97 | questions = template["questions"]
98 | num_questions = len(questions)
99 |
100 | method1_win = [0] * num_questions
101 | tie = [0] * num_questions
102 | method2_win = [0] * num_questions
103 |
104 | for item in results:
105 | assert len(item["answers"]) == num_questions
106 | for i in range(num_questions):
107 | if item["answers"][i] == "method1":
108 | method1_win[i] += 1
109 | elif item["answers"][i] == "tie":
110 | tie[i] += 1
111 | elif item["answers"][i] == "method2":
112 | method2_win[i] += 1
113 | else:
114 | raise NotImplementedError()
115 | results_for_drawing = {}
116 |
117 | method1_win += [sum(method1_win) / len(method1_win)]
118 | tie += [sum(tie) / len(tie)]
119 | method2_win += [sum(method2_win) / len(method2_win)]
120 |
121 | results_for_drawing["Questions"] = (questions + ["Average"]) * 3
122 | results_for_drawing["Win Rate"] = (
123 | [x / len(results) * 100 for x in method1_win]
124 | + [x / len(results) * 100 for x in tie]
125 | + [x / len(results) * 100 for x in method2_win]
126 | )
127 |
128 | results_for_drawing["Winner"] = (
129 | [data[0]["method1"]] * (num_questions + 1) + ["Tie"] * (num_questions + 1) + [data[0]["method2"]] * (num_questions + 1)
130 | )
131 | results_for_drawing = pd.DataFrame(results_for_drawing)
132 |
133 | return gr.BarPlot(
134 | results_for_drawing,
135 | x="Questions",
136 | y="Win Rate",
137 | color="Winner",
138 | title="Human Evaluation Result",
139 | vertical=False,
140 | width=450,
141 | height=300,
142 | )
143 |
144 |
145 | def init_start(ids, ids_list, id2data, results):
146 | random_elements = random.sample(data, len(data) // 2)
147 | id2data = {}
148 | for item in random_elements:
149 | id2data[item["id"]] = item
150 | ids = set(id2data.keys())
151 | ids_list = set(id2data.keys())
152 | results = []
153 | return next_item(ids, ids_list, id2data, results)
154 |
155 |
156 | parser = argparse.ArgumentParser()
157 | parser.add_argument("--template-file", default="default_template.json")
158 | parser.add_argument("--data-path", default="data/makeup_transfer/data.json")
159 | parser.add_argument("--result-path", default="data/makeup_transfer/result.jsonl")
160 | parser.add_argument("--port", type=int, default=80)
161 |
162 | args = parser.parse_args()
163 | # global data
164 | if not os.path.exists(args.template_file):
165 | args.template_file = "./double_blind/default_template.json"
166 | template = read_json(args.template_file)
167 | data = read_json(args.data_path)
168 |
169 | with gr.Blocks(title="EasyPhoto双盲评测", css="style.css") as app:
170 |
171 | id = gr.State()
172 | id2data = gr.State({})
173 | ids = gr.State()
174 | ids_list = gr.State()
175 | results = gr.State([])
176 |
177 | with gr.Column(visible=True, elem_id="start"):
178 | gr.Markdown("### 欢迎您参与EasyPhoto的本次评测。")
179 | gr.Markdown("### Welcome to this review of EasyPhoto.")
180 | with gr.Row():
181 | start_btn = gr.Button("开始 / Start")
182 |
183 | with gr.Column(visible=False, elem_id="main"):
184 | submit_cnt = gr.State(value=1)
185 |
186 | with gr.Row():
187 | with gr.Column(scale=3):
188 | reference_imgs = gr.Gallery(
189 | [], columns=3, rows=1, label="人物参考图片", show_label=True, elem_id="reference-imgs", visible=template["show_references"]
190 | )
191 | with gr.Column(scale=1):
192 | pass
193 |
194 | gr.Markdown("### 根据下面的图片和上面的参考图片(如果有),回答下面的问题。")
195 | with gr.Row():
196 | with gr.Column(scale=3):
197 | with gr.Row():
198 | left_img = gr.Image(show_label=False)
199 | right_img = gr.Image(show_label=False)
200 | with gr.Column(scale=1):
201 | pass
202 |
203 | eval_results = []
204 | for question in template["questions"]:
205 | eval_results.append(gr.Radio(["左边/Left", "持平/Tie", "右边/Right"], label=question, elem_classes="question"))
206 |
207 | submit = gr.Button("提交 / Submit")
208 | next_btn = gr.Button("换一个 / Change Another")
209 |
210 | with gr.Accordion("查看结果/View Results", open=False):
211 | with gr.Row():
212 | with gr.Column(scale=1):
213 | plot = gr.BarPlot()
214 | with gr.Column(scale=1):
215 | pass
216 |
217 | start_btn.click(
218 | init_start,
219 | inputs=[ids, ids_list, id2data, results],
220 | outputs=[id, reference_imgs, left_img, right_img, plot, ids, ids_list, id2data, results],
221 | ).then(
222 | fn=None,
223 | _js="\
224 | () => {\
225 | document.querySelector('#start').style.display = 'none';\
226 | document.querySelector('#main').style.display = 'flex';\
227 | }\
228 | ",
229 | inputs=None,
230 | outputs=[],
231 | )
232 |
233 | submit.click(
234 | save_result,
235 | inputs=[id, submit_cnt, ids, ids_list, id2data, results] + eval_results,
236 | outputs=[id, reference_imgs, left_img, right_img, plot, ids, ids_list, id2data, results, submit_cnt],
237 | )
238 | next_btn.click(
239 | next_item,
240 | inputs=[ids, ids_list, id2data, results],
241 | outputs=[id, reference_imgs, left_img, right_img, plot, ids, ids_list, id2data, results],
242 | )
243 |
244 | if __name__ == "__main__":
245 |
246 | # 最高并发15
247 | app.queue(concurrency_count=15).launch(server_name="0.0.0.0", server_port=args.port, show_api=False)
248 |
--------------------------------------------------------------------------------
/api_test/double_blind/default_template.json:
--------------------------------------------------------------------------------
1 | {
2 | "show_references": true,
3 | "num_reference_imgs": 2,
4 | "max_once": 15,
5 | "questions": [
6 | "更美观?/Beauty?",
7 | "更像本人?/Similar?",
8 | "更真实?/Reality?"
9 | ]
10 | }
11 |
--------------------------------------------------------------------------------
/api_test/double_blind/format_data2json.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import json
3 | import os
4 | from glob import glob
5 |
6 |
7 | def format_ref_images(file_paths):
8 | file_names = [os.path.basename(path).split(".")[0] for path in file_paths]
9 | result_dict = {k: [v] for k, v in zip(file_names, file_paths)}
10 |
11 | return result_dict
12 |
13 |
14 | def remove_last_underscore(input_string):
15 | parts = input_string.split("_")
16 |
17 | if len(parts) > 1:
18 | parts = parts[:-1]
19 |
20 | result = "_".join(parts)
21 |
22 | return result
23 |
24 |
25 | def match_prefix(file_name, image_formats):
26 | for img_prefix in image_formats:
27 | if os.path.exists(file_name):
28 | return file_name
29 | else:
30 | current_prefix = file_name.split("/")[-1].split(".")[-1]
31 | file_name = file_name.replace(current_prefix, img_prefix)
32 |
33 | print("Warning: No match file in compared version!")
34 | return file_name
35 |
36 |
37 | def find_value_for_key(file_name, dictionary):
38 | parts = file_name.split("/")
39 | last_part = parts[-1].split(".")[0]
40 |
41 | user_id = remove_last_underscore(last_part)
42 |
43 | if user_id in dictionary.keys():
44 | return dictionary[user_id]
45 | else:
46 | return None
47 |
48 |
49 | if __name__ == "__main__":
50 |
51 | parser = argparse.ArgumentParser(description="Description of your script")
52 |
53 | parser.add_argument("--ref_images", type=str, default="", help="Path to the user_id reference directory")
54 | parser.add_argument("--version1_dir", type=str, help="Path to version1 output result")
55 | parser.add_argument("--version2_dir", type=str, help="Path to version2 output result")
56 | parser.add_argument("--output_json", type=str, help="Path to output_datajson")
57 |
58 | args = parser.parse_args()
59 |
60 | image_formats = ["*.jpg", "*.jpeg", "*.png", "*.webp"]
61 | ref_images = []
62 | for image_format in image_formats:
63 | ref_images.extend(glob(os.path.join(args.ref_images, image_format)))
64 |
65 | if len(ref_images) == 0:
66 | print(f"Your reference dir contains no reference images. Set --ref_images to your user_id reference directory")
67 | else:
68 | print(f"reference images contains : {ref_images}")
69 |
70 | ref_dicts = format_ref_images(ref_images)
71 | # print(ref_dicts)
72 |
73 | result_data = []
74 | abs_path = True
75 |
76 | version1_dir = args.version1_dir
77 | version2_dir = args.version2_dir
78 | method_a = version1_dir.strip().split("/")[-1]
79 | method_b = version2_dir.strip().split("/")[-1]
80 |
81 | image_formats = ["jpg", "jpeg", "png", "webp"]
82 |
83 | for root, dirs, files in os.walk(version1_dir):
84 | for filename in files:
85 | if filename.split(".")[-1] in image_formats:
86 | file_path = os.path.join(root, filename)
87 | file_path2 = os.path.join(version2_dir, filename)
88 | file_path2 = match_prefix(file_path2, image_formats)
89 |
90 | reference = find_value_for_key(file_path, ref_dicts)
91 |
92 | if reference:
93 | if abs_path:
94 | file_path = os.path.abspath(file_path)
95 | file_path2 = os.path.abspath(file_path2)
96 | reference = [os.path.abspath(t) for t in reference]
97 |
98 | if os.path.exists(file_path2) and reference is not None:
99 | data_item = {
100 | "id": len(result_data),
101 | "method1": method_a,
102 | "img1": file_path,
103 | "method2": method_b,
104 | "img2": file_path2,
105 | "reference_imgs": reference,
106 | }
107 |
108 | result_data.append(data_item)
109 | else:
110 | pass
111 | else:
112 | user_id = file_path.split("/")[-1].split("_")[0]
113 | print(
114 | f"No matching user_id for {file_path}. Aborting! \
115 | Please rename the ref image as {user_id}.jpg"
116 | )
117 |
118 | output_json = args.output_json
119 | with open(output_json, "w") as json_file:
120 | json.dump(result_data, json_file, indent=4)
121 |
122 | print(f"Generated JSON file: {output_json}")
123 |
--------------------------------------------------------------------------------
/api_test/double_blind/style.css:
--------------------------------------------------------------------------------
1 | .contain {
2 | display: flex;
3 | }
4 |
5 | #component-0 {
6 | flex-grow: 1;
7 | }
8 |
9 | #start {
10 | margin-top: -20vh;
11 | text-align: center;
12 | justify-content: center;
13 | }
14 |
15 | #start h3 {
16 | margin: 0;
17 | }
18 |
19 | #start #component-5 {
20 | display: block;
21 | }
22 |
23 | /* #start button {
24 | width: 40%;
25 | min-width: 280px;
26 | max-width: 500px;
27 | } */
28 |
29 | #reference-imgs .grid-wrap {
30 | min-height: 0 !important;
31 | max-height: 60vh;
32 | }
33 |
34 | .question span {
35 | font-size: 1.1rem;
36 | }
37 |
38 | .question input {
39 | width: 1.1rem !important;
40 | height: 1.1rem !important;
41 | }
42 |
--------------------------------------------------------------------------------
/api_test/post_infer.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import base64
3 | import json
4 | import os
5 | import time
6 | from datetime import datetime
7 | from glob import glob
8 | from io import BytesIO
9 |
10 | import cv2
11 | import numpy as np
12 | import requests
13 | from tqdm import tqdm
14 |
15 |
16 | def decode_image_from_base64jpeg(base64_image):
17 | image_bytes = base64.b64decode(base64_image)
18 | np_arr = np.frombuffer(image_bytes, np.uint8)
19 | image = cv2.imdecode(np_arr, cv2.IMREAD_COLOR)
20 | return image
21 |
22 |
23 | def post(encoded_image, user_id=None, url="http://0.0.0.0:7860"):
24 | if user_id is None:
25 | user_id = "test"
26 | datas = json.dumps(
27 | {
28 | "user_ids": [user_id],
29 | "sd_model_checkpoint": "Chilloutmix-Ni-pruned-fp16-fix.safetensors",
30 | "init_image": encoded_image,
31 | "first_diffusion_steps": 50,
32 | "first_denoising_strength": 0.45,
33 | "second_diffusion_steps": 20,
34 | "second_denoising_strength": 0.35,
35 | "seed": 12345,
36 | "crop_face_preprocess": True,
37 | "before_face_fusion_ratio": 0.5,
38 | "after_face_fusion_ratio": 0.5,
39 | "apply_face_fusion_before": True,
40 | "apply_face_fusion_after": True,
41 | "color_shift_middle": True,
42 | "color_shift_last": True,
43 | "super_resolution": True,
44 | "super_resolution_method": "gpen",
45 | "skin_retouching_bool": False,
46 | "background_restore": False,
47 | "background_restore_denoising_strength": 0.35,
48 | "makeup_transfer": False,
49 | "makeup_transfer_ratio": 0.50,
50 | "face_shape_match": False,
51 | "tabs": 1,
52 | "ipa_control": False,
53 | "ipa_weight": 0.50,
54 | "ipa_image": None,
55 | "ref_mode_choose": "Infer with Pretrained Lora",
56 | "ipa_only_weight": 0.60,
57 | "ipa_only_image": None,
58 | "lcm_accelerate": False,
59 | }
60 | )
61 | r = requests.post(f"{url}/easyphoto/easyphoto_infer_forward", data=datas, timeout=1500)
62 | data = r.content.decode("utf-8")
63 | return data
64 |
65 |
66 | if __name__ == "__main__":
67 | """
68 | There are two ways to test:
69 | The first: make sure the directory is full of readable images
70 | The second: public link of readable picture
71 | """
72 | parser = argparse.ArgumentParser(description="Description of your script")
73 |
74 | parser.add_argument("--template_dir", type=str, default="", help="Path to the template directory")
75 | parser.add_argument("--output_path", type=str, default="./", help="Path to the output directory")
76 | parser.add_argument("--user_ids", type=str, default="test", help="Test user ids, split with space")
77 |
78 | args = parser.parse_args()
79 |
80 | template_dir = args.template_dir
81 | output_path = args.output_path
82 | user_ids = args.user_ids.split(" ")
83 |
84 | if output_path != "./":
85 | os.makedirs(output_path, exist_ok=True)
86 |
87 | # initiate time
88 | now_date = datetime.now()
89 | time_start = time.time()
90 |
91 | # -------------------test infer------------------- #
92 | # When there is no parameter input.
93 | if template_dir == "":
94 | encoded_image = "https://pai-aigc-photog.oss-cn-hangzhou.aliyuncs.com/webui/api/template1.jpeg"
95 | encoded_image = requests.get(encoded_image)
96 | encoded_image = base64.b64encode(BytesIO(encoded_image.content).read()).decode("utf-8")
97 |
98 | for user_id in tqdm(user_ids):
99 | outputs = post(encoded_image, user_id)
100 | outputs = json.loads(outputs)
101 | image = decode_image_from_base64jpeg(outputs["outputs"][0])
102 | toutput_path = os.path.join(output_path, f"{user_id}_tmp.jpg")
103 | cv2.imwrite(toutput_path, image)
104 | print(outputs["message"])
105 |
106 | # When selecting a local file as a parameter input.
107 | else:
108 | image_formats = ["*.jpg", "*.jpeg", "*.png", "*.webp"]
109 | img_list = []
110 | for image_format in image_formats:
111 | img_list.extend(glob(os.path.join(template_dir, image_format)))
112 |
113 | if len(img_list) == 0:
114 | print(f" Input template dir {template_dir} contains no images")
115 | else:
116 | print(f" Total {len(img_list)} templates to test for {len(user_ids)} ID")
117 |
118 | # please set your test user ids in args
119 | for user_id in tqdm(user_ids):
120 | for img_path in tqdm(img_list):
121 | print(f" Call generate for ID ({user_id}) and Template ({img_path})")
122 |
123 | with open(img_path, "rb") as f:
124 | encoded_image = base64.b64encode(f.read()).decode("utf-8")
125 | outputs = post(encoded_image, user_id=user_id)
126 | outputs = json.loads(outputs)
127 |
128 | if len(outputs["outputs"]):
129 | image = decode_image_from_base64jpeg(outputs["outputs"][0])
130 | toutput_path = os.path.join(output_path, f"{user_id}_" + os.path.basename(img_path))
131 | print(toutput_path)
132 | cv2.imwrite(toutput_path, image)
133 | else:
134 | print("Error!", outputs["message"])
135 | print(outputs["message"])
136 |
137 | # End of timing
138 | # The time difference below is the total execution time of the script, in seconds.
139 | time_end = time.time()
140 | time_sum = time_end - time_start
141 | print("# --------------------------------------------------------- #")
142 | print(f"# Total expenditure: {time_sum}s")
143 | print("# --------------------------------------------------------- #")
144 |
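
Note: the inference endpoint returns a JSON body with "message" (a status string) and "outputs" (a list of base64-encoded JPEG images); the script above saves only the first entry. A minimal sketch, under the same response-shape assumption, that decodes and saves every returned image:

import base64
import json
import os

import cv2
import numpy as np


def save_all_outputs(response_text, output_dir, prefix="result"):
    # Parse the JSON body returned by /easyphoto/easyphoto_infer_forward.
    outputs = json.loads(response_text)
    os.makedirs(output_dir, exist_ok=True)
    saved = []
    for idx, encoded in enumerate(outputs.get("outputs", [])):
        # Each entry is a base64-encoded JPEG, the same format handled by decode_image_from_base64jpeg.
        image = cv2.imdecode(np.frombuffer(base64.b64decode(encoded), np.uint8), cv2.IMREAD_COLOR)
        target = os.path.join(output_dir, f"{prefix}_{idx}.jpg")
        cv2.imwrite(target, image)
        saved.append(target)
    print(outputs.get("message", ""))
    return saved

It could replace the per-image handling above, e.g. save_all_outputs(post(encoded_image, user_id), output_path, prefix=user_id).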
--------------------------------------------------------------------------------
/api_test/post_train.py:
--------------------------------------------------------------------------------
1 | import base64
2 | import json
3 | import os
4 | import sys
5 | import time
6 | from glob import glob
7 | from io import BytesIO
8 |
9 | import requests
10 |
11 |
12 | def post_train(encoded_images, url="http://0.0.0.0:7860"):
13 | datas = json.dumps(
14 | {
15 | "user_id": "test", # A custom ID that identifies the trained face model
16 | "sd_model_checkpoint": "Chilloutmix-Ni-pruned-fp16-fix.safetensors",
17 | "train_mode_choose": "Train Human Lora",
18 | "resolution": 512,
19 | "val_and_checkpointing_steps": 100,
20 | "max_train_steps": 800, # Training steps
21 | "steps_per_photos": 200,
22 | "train_batch_size": 1,
23 | "gradient_accumulation_steps": 4,
24 | "dataloader_num_workers": 16,
25 | "learning_rate": 1e-4,
26 | "rank": 128,
27 | "network_alpha": 64,
28 | "instance_images": encoded_images,
29 | "skin_retouching_bool": False,
30 | }
31 | )
32 | r = requests.post(f"{url}/easyphoto/easyphoto_train_forward", data=datas, timeout=1500)
33 | data = r.content.decode("utf-8")
34 | return data
35 |
36 |
37 | if __name__ == "__main__":
38 | """
39 | There are two ways to test:
40 | 1. Pass a directory of readable training images as the first command-line argument.
41 | 2. Pass no argument to download and use a set of public training image URLs.
42 | """
43 | # start timing
44 | time_start = time.time()
45 |
46 | # -------------------training procedure------------------- #
47 | # When there is no parameter input.
48 | if len(sys.argv) == 1:
49 | img_list = [
50 | "https://pai-aigc-photog.oss-cn-hangzhou.aliyuncs.com/webui/api/t1.jpg",
51 | "https://pai-aigc-photog.oss-cn-hangzhou.aliyuncs.com/webui/api/t2.jpg",
52 | "https://pai-aigc-photog.oss-cn-hangzhou.aliyuncs.com/webui/api/t3.jpg",
53 | "https://pai-aigc-photog.oss-cn-hangzhou.aliyuncs.com/webui/api/t4.jpg",
54 | ]
55 | encoded_images = []
56 | for idx, img_path in enumerate(img_list):
57 | encoded_image = requests.get(img_path)
58 | encoded_image = base64.b64encode(BytesIO(encoded_image.content).read()).decode("utf-8")
59 | encoded_images.append(encoded_image)
60 |
61 | outputs = post_train(encoded_images)
62 | outputs = json.loads(outputs)
63 | print(outputs["message"])
64 |
65 | # When selecting a folder as a parameter input.
66 | elif len(sys.argv) == 2:
67 | img_list = glob(os.path.join(sys.argv[1], "*"))
68 | encoded_images = []
69 | for idx, img_path in enumerate(img_list):
70 | with open(img_path, "rb") as f:
71 | encoded_image = base64.b64encode(f.read()).decode("utf-8")
72 | encoded_images.append(encoded_image)
73 | outputs = post_train(encoded_images)
74 | outputs = json.loads(outputs)
75 | print(outputs["message"])
76 |
77 | else:
78 | print("other modes except url and local read are not supported")
79 |
80 | # End of timing
81 | # The time difference below is the total execution time of the script, in minutes.
82 | time_end = time.time()
83 | time_sum = (time_end - time_start) // 60
84 |
85 | print("# --------------------------------------------------------- #")
86 | print(f"# Total expenditure:{time_sum} minutes ")
87 | print("# --------------------------------------------------------- #")
88 |
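
Note: the local-folder branch above encodes every file matched by glob(os.path.join(sys.argv[1], "*")), so non-image files in the folder would also be sent. A minimal sketch that restricts the glob to common image extensions before encoding (the extension list mirrors the one used in post_infer.py and is otherwise an assumption):

import base64
import os
from glob import glob


def encode_images_in_folder(folder, image_formats=("*.jpg", "*.jpeg", "*.png", "*.webp")):
    # Collect only image files, then base64-encode them for the training request.
    img_list = []
    for image_format in image_formats:
        img_list.extend(glob(os.path.join(folder, image_format)))
    encoded_images = []
    for img_path in sorted(img_list):
        with open(img_path, "rb") as f:
            encoded_images.append(base64.b64encode(f.read()).decode("utf-8"))
    return encoded_images

Usage would then be outputs = post_train(encode_images_in_folder(sys.argv[1])).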
--------------------------------------------------------------------------------
/api_test/post_video_infer.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import base64
3 | import json
4 | import os
5 | import time
6 | from datetime import datetime
7 |
8 | import requests
9 | from tqdm import tqdm
10 |
11 |
12 | # Function to encode a video file to Base64
13 | def encode_video_to_base64(video_file_path):
14 | with open(video_file_path, "rb") as video_file:
15 | # Read the video file as binary data
16 | video_data = video_file.read()
17 | # Encode the data to Base64
18 | video_base64 = base64.b64encode(video_data)
19 | return video_base64
20 |
21 |
22 | # Function to decode Base64 encoded data and save it as a video file
23 | def decode_base64_to_video(encoded_video, output_file_path):
24 | with open(output_file_path, "wb") as output_file:
25 | # Decode the Base64 encoded data
26 | video_data = base64.b64decode(encoded_video)
27 | # Write the decoded binary data to the file
28 | output_file.write(video_data)
29 |
30 |
31 | def post(t2v_input_prompt="", init_image=None, last_image=None, init_video=None, user_id=None, tabs=0, url="http://0.0.0.0:7860"):
32 | if user_id is None:
33 | user_id = "test"
34 | datas = json.dumps(
35 | {
36 | "user_ids": [user_id],
37 | "sd_model_checkpoint": "Chilloutmix-Ni-pruned-fp16-fix.safetensors",
38 | "sd_model_checkpoint_for_animatediff_text2video": "majicmixRealistic_v7.safetensors",
39 | "sd_model_checkpoint_for_animatediff_image2video": "majicmixRealistic_v7.safetensors",
40 | "t2v_input_prompt": t2v_input_prompt,
41 | "t2v_input_width": 512,
42 | "t2v_input_height": 768,
43 | "scene_id": "none",
44 | "upload_control_video": False,
45 | "upload_control_video_type": "openpose",
46 | "openpose_video": None,
47 | "init_image": init_image,
48 | "init_image_prompt": "",
49 | "last_image": last_image,
50 | "init_video": init_video,
51 | "additional_prompt": "masterpiece, beauty",
52 | "max_frames": 16,
53 | "max_fps": 8,
54 | "save_as": "gif",
55 | "first_diffusion_steps": 50,
56 | "first_denoising_strength": 0.45,
57 | "seed": -1,
58 | "crop_face_preprocess": True,
59 | "before_face_fusion_ratio": 0.5,
60 | "after_face_fusion_ratio": 0.5,
61 | "apply_face_fusion_before": True,
62 | "apply_face_fusion_after": True,
63 | "color_shift_middle": True,
64 | "super_resolution": True,
65 | "super_resolution_method": "gpen",
66 | "skin_retouching_bool": False,
67 | "makeup_transfer": False,
68 | "makeup_transfer_ratio": 0.50,
69 | "face_shape_match": False,
70 | "video_interpolation": False,
71 | "video_interpolation_ext": 1,
72 | "tabs": tabs,
73 | "ipa_control": False,
74 | "ipa_weight": 0.50,
75 | "ipa_image": None,
76 | "lcm_accelerate": False,
77 | }
78 | )
79 | r = requests.post(f"{url}/easyphoto/easyphoto_video_infer_forward", data=datas, timeout=1500)
80 | data = r.content.decode("utf-8")
81 | return data
82 |
83 |
84 | if __name__ == "__main__":
85 | """
86 | The test mode depends on which inputs are provided:
87 | text-to-video (no input paths), image-to-video (--init_image_path, optionally with --last_image_path),
88 | and video-to-video (--video_path).
89 | """
90 | parser = argparse.ArgumentParser(description="EasyPhoto video inference API test")
91 |
92 | parser.add_argument(
93 | "--t2v_input_prompt",
94 | type=str,
95 | default="1girl, (white hair, long hair), blue eyes, hair ornament, blue dress, standing, looking at viewer, shy, upper-body",
96 | help="Prompt for t2v",
97 | )
98 | parser.add_argument("--init_image_path", type=str, default="", help="Path to the init image path")
99 | parser.add_argument("--last_image_path", type=str, default="", help="Path to the last image path")
100 | parser.add_argument("--video_path", type=str, default="", help="Path to the video path")
101 | parser.add_argument("--output_path", type=str, default="./", help="Path to the output directory")
102 | parser.add_argument("--user_ids", type=str, default="test", help="Test user ids, split with space")
103 |
104 | args = parser.parse_args()
105 |
106 | t2v_input_prompt = args.t2v_input_prompt
107 | init_image_path = args.init_image_path
108 | last_image_path = args.last_image_path
109 | video_path = args.video_path
110 | output_path = args.output_path
111 | user_ids = args.user_ids.split(" ")
112 |
113 | if output_path != "./":
114 | os.makedirs(output_path, exist_ok=True)
115 |
116 | # start timing
117 | now_date = datetime.now()
118 | time_start = time.time()
119 |
120 | # -------------------test infer------------------- #
121 | # When there is no parameter input.
122 | if init_image_path == "" and last_image_path == "" and video_path == "":
123 | for user_id in tqdm(user_ids):
124 | outputs = post(t2v_input_prompt, None, None, None, user_id, tabs=0)
125 | outputs = json.loads(outputs)
126 | print(outputs["message"])
127 | if outputs["output_video"] is not None:
128 | toutput_path = os.path.join(os.path.join(output_path), f"{user_id}_tmp.mp4")
129 | decode_base64_to_video(outputs["output_video"], toutput_path)
130 | elif outputs["output_gif"] is not None:
131 | toutput_path = os.path.join(os.path.join(output_path), f"{user_id}_tmp.gif")
132 | decode_base64_to_video(outputs["output_gif"], toutput_path)
133 |
134 | elif init_image_path != "" and last_image_path == "" and video_path == "":
135 | with open(init_image_path, "rb") as f:
136 | init_image = base64.b64encode(f.read()).decode("utf-8")
137 |
138 | for user_id in tqdm(user_ids):
139 | outputs = post(t2v_input_prompt, init_image, None, None, user_id, tabs=1)
140 | outputs = json.loads(outputs)
141 | print(outputs["message"])
142 | if outputs["output_video"] is not None:
143 | toutput_path = os.path.join(os.path.join(output_path), f"{user_id}_tmp.mp4")
144 | decode_base64_to_video(outputs["output_video"], toutput_path)
145 | elif outputs["output_gif"] is not None:
146 | toutput_path = os.path.join(os.path.join(output_path), f"{user_id}_tmp.gif")
147 | decode_base64_to_video(outputs["output_gif"], toutput_path)
148 |
149 | elif init_image_path != "" and last_image_path != "" and video_path == "":
150 | with open(init_image_path, "rb") as f:
151 | init_image = base64.b64encode(f.read()).decode("utf-8")
152 | with open(last_image_path, "rb") as f:
153 | last_image = base64.b64encode(f.read()).decode("utf-8")
154 |
155 | for user_id in tqdm(user_ids):
156 | outputs = post(t2v_input_prompt, init_image, last_image, None, user_id, tabs=1)
157 | outputs = json.loads(outputs)
158 | print(outputs["message"])
159 | if outputs["output_video"] is not None:
160 | toutput_path = os.path.join(os.path.join(output_path), f"{user_id}_tmp.mp4")
161 | decode_base64_to_video(outputs["output_video"], toutput_path)
162 | elif outputs["output_gif"] is not None:
163 | toutput_path = os.path.join(os.path.join(output_path), f"{user_id}_tmp.gif")
164 | decode_base64_to_video(outputs["output_gif"], toutput_path)
165 |
166 | elif init_image_path == "" and last_image_path == "" and video_path != "":
167 | with open(video_path, "rb") as f:
168 | init_video = base64.b64encode(f.read()).decode("utf-8")
169 |
170 | for user_id in tqdm(user_ids):
171 | outputs = post(t2v_input_prompt, None, None, init_video, user_id, tabs=2)
172 | outputs = json.loads(outputs)
173 | print(outputs["message"])
174 | if outputs["output_video"] is not None:
175 | toutput_path = os.path.join(os.path.join(output_path), f"{user_id}_tmp.mp4")
176 | decode_base64_to_video(outputs["output_video"], toutput_path)
177 | elif outputs["output_gif"] is not None:
178 | toutput_path = os.path.join(os.path.join(output_path), f"{user_id}_tmp.gif")
179 | decode_base64_to_video(outputs["output_gif"], toutput_path)
180 |
181 | # End of timing
182 | # The time difference below is the total execution time of the script, in seconds.
183 | time_end = time.time()
184 | time_sum = time_end - time_start
185 | print("# --------------------------------------------------------- #")
186 | print(f"# Total expenditure: {time_sum}s")
187 | print("# --------------------------------------------------------- #")
188 |
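
Note: each branch above repeats the same save logic, writing an .mp4 when "output_video" is present and otherwise a .gif from "output_gif". A minimal sketch that factors this into one helper, reusing decode_base64_to_video defined above and assuming the same response keys:

import json
import os


def save_video_response(response_text, output_path, user_id):
    # Decode the JSON response and write either an .mp4 or a .gif, whichever is present.
    outputs = json.loads(response_text)
    print(outputs["message"])
    if outputs.get("output_video") is not None:
        target = os.path.join(output_path, f"{user_id}_tmp.mp4")
        decode_base64_to_video(outputs["output_video"], target)  # defined earlier in this file
        return target
    if outputs.get("output_gif") is not None:
        target = os.path.join(output_path, f"{user_id}_tmp.gif")
        decode_base64_to_video(outputs["output_gif"], target)
        return target
    return None

Each branch could then end with save_video_response(post(...), output_path, user_id).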
--------------------------------------------------------------------------------
/images/controlnet_num.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aigc-apps/sd-webui-EasyPhoto/6c1f8959c4024b0f8a1187fae7c50c0c8f86a8df/images/controlnet_num.jpg
--------------------------------------------------------------------------------
/images/ding_erweima.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aigc-apps/sd-webui-EasyPhoto/6c1f8959c4024b0f8a1187fae7c50c0c8f86a8df/images/ding_erweima.jpg
--------------------------------------------------------------------------------
/images/double_blindui.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aigc-apps/sd-webui-EasyPhoto/6c1f8959c4024b0f8a1187fae7c50c0c8f86a8df/images/double_blindui.jpg
--------------------------------------------------------------------------------
/images/dsw.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aigc-apps/sd-webui-EasyPhoto/6c1f8959c4024b0f8a1187fae7c50c0c8f86a8df/images/dsw.png
--------------------------------------------------------------------------------
/images/erweima.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aigc-apps/sd-webui-EasyPhoto/6c1f8959c4024b0f8a1187fae7c50c0c8f86a8df/images/erweima.jpg
--------------------------------------------------------------------------------
/images/infer_ui.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aigc-apps/sd-webui-EasyPhoto/6c1f8959c4024b0f8a1187fae7c50c0c8f86a8df/images/infer_ui.jpg
--------------------------------------------------------------------------------
/images/install.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aigc-apps/sd-webui-EasyPhoto/6c1f8959c4024b0f8a1187fae7c50c0c8f86a8df/images/install.jpg
--------------------------------------------------------------------------------
/images/multi_people_1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aigc-apps/sd-webui-EasyPhoto/6c1f8959c4024b0f8a1187fae7c50c0c8f86a8df/images/multi_people_1.jpg
--------------------------------------------------------------------------------
/images/multi_people_2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aigc-apps/sd-webui-EasyPhoto/6c1f8959c4024b0f8a1187fae7c50c0c8f86a8df/images/multi_people_2.jpg
--------------------------------------------------------------------------------
/images/no_found_image.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aigc-apps/sd-webui-EasyPhoto/6c1f8959c4024b0f8a1187fae7c50c0c8f86a8df/images/no_found_image.jpg
--------------------------------------------------------------------------------
/images/overview.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aigc-apps/sd-webui-EasyPhoto/6c1f8959c4024b0f8a1187fae7c50c0c8f86a8df/images/overview.jpg
--------------------------------------------------------------------------------
/images/results_1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aigc-apps/sd-webui-EasyPhoto/6c1f8959c4024b0f8a1187fae7c50c0c8f86a8df/images/results_1.jpg
--------------------------------------------------------------------------------
/images/results_2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aigc-apps/sd-webui-EasyPhoto/6c1f8959c4024b0f8a1187fae7c50c0c8f86a8df/images/results_2.jpg
--------------------------------------------------------------------------------
/images/results_3.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aigc-apps/sd-webui-EasyPhoto/6c1f8959c4024b0f8a1187fae7c50c0c8f86a8df/images/results_3.jpg
--------------------------------------------------------------------------------
/images/scene_lora/Christmas_1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aigc-apps/sd-webui-EasyPhoto/6c1f8959c4024b0f8a1187fae7c50c0c8f86a8df/images/scene_lora/Christmas_1.jpg
--------------------------------------------------------------------------------
/images/scene_lora/Cyberpunk_1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aigc-apps/sd-webui-EasyPhoto/6c1f8959c4024b0f8a1187fae7c50c0c8f86a8df/images/scene_lora/Cyberpunk_1.jpg
--------------------------------------------------------------------------------
/images/scene_lora/FairMaidenStyle_1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aigc-apps/sd-webui-EasyPhoto/6c1f8959c4024b0f8a1187fae7c50c0c8f86a8df/images/scene_lora/FairMaidenStyle_1.jpg
--------------------------------------------------------------------------------
/images/scene_lora/Gentleman_1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aigc-apps/sd-webui-EasyPhoto/6c1f8959c4024b0f8a1187fae7c50c0c8f86a8df/images/scene_lora/Gentleman_1.jpg
--------------------------------------------------------------------------------
/images/scene_lora/GuoFeng_1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aigc-apps/sd-webui-EasyPhoto/6c1f8959c4024b0f8a1187fae7c50c0c8f86a8df/images/scene_lora/GuoFeng_1.jpg
--------------------------------------------------------------------------------
/images/scene_lora/GuoFeng_2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aigc-apps/sd-webui-EasyPhoto/6c1f8959c4024b0f8a1187fae7c50c0c8f86a8df/images/scene_lora/GuoFeng_2.jpg
--------------------------------------------------------------------------------
/images/scene_lora/GuoFeng_3.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aigc-apps/sd-webui-EasyPhoto/6c1f8959c4024b0f8a1187fae7c50c0c8f86a8df/images/scene_lora/GuoFeng_3.jpg
--------------------------------------------------------------------------------
/images/scene_lora/GuoFeng_4.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aigc-apps/sd-webui-EasyPhoto/6c1f8959c4024b0f8a1187fae7c50c0c8f86a8df/images/scene_lora/GuoFeng_4.jpg
--------------------------------------------------------------------------------
/images/scene_lora/Minimalism_1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aigc-apps/sd-webui-EasyPhoto/6c1f8959c4024b0f8a1187fae7c50c0c8f86a8df/images/scene_lora/Minimalism_1.jpg
--------------------------------------------------------------------------------
/images/scene_lora/NaturalWind_1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aigc-apps/sd-webui-EasyPhoto/6c1f8959c4024b0f8a1187fae7c50c0c8f86a8df/images/scene_lora/NaturalWind_1.jpg
--------------------------------------------------------------------------------
/images/scene_lora/Princess_1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aigc-apps/sd-webui-EasyPhoto/6c1f8959c4024b0f8a1187fae7c50c0c8f86a8df/images/scene_lora/Princess_1.jpg
--------------------------------------------------------------------------------
/images/scene_lora/Princess_2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aigc-apps/sd-webui-EasyPhoto/6c1f8959c4024b0f8a1187fae7c50c0c8f86a8df/images/scene_lora/Princess_2.jpg
--------------------------------------------------------------------------------
/images/scene_lora/Princess_3.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aigc-apps/sd-webui-EasyPhoto/6c1f8959c4024b0f8a1187fae7c50c0c8f86a8df/images/scene_lora/Princess_3.jpg
--------------------------------------------------------------------------------
/images/scene_lora/SchoolUniform_1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aigc-apps/sd-webui-EasyPhoto/6c1f8959c4024b0f8a1187fae7c50c0c8f86a8df/images/scene_lora/SchoolUniform_1.jpg
--------------------------------------------------------------------------------
/images/scene_lora/SchoolUniform_2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aigc-apps/sd-webui-EasyPhoto/6c1f8959c4024b0f8a1187fae7c50c0c8f86a8df/images/scene_lora/SchoolUniform_2.jpg
--------------------------------------------------------------------------------
/images/single_people.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aigc-apps/sd-webui-EasyPhoto/6c1f8959c4024b0f8a1187fae7c50c0c8f86a8df/images/single_people.jpg
--------------------------------------------------------------------------------
/images/train_1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aigc-apps/sd-webui-EasyPhoto/6c1f8959c4024b0f8a1187fae7c50c0c8f86a8df/images/train_1.jpg
--------------------------------------------------------------------------------
/images/train_2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aigc-apps/sd-webui-EasyPhoto/6c1f8959c4024b0f8a1187fae7c50c0c8f86a8df/images/train_2.jpg
--------------------------------------------------------------------------------
/images/train_3.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aigc-apps/sd-webui-EasyPhoto/6c1f8959c4024b0f8a1187fae7c50c0c8f86a8df/images/train_3.jpg
--------------------------------------------------------------------------------
/images/train_detail.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aigc-apps/sd-webui-EasyPhoto/6c1f8959c4024b0f8a1187fae7c50c0c8f86a8df/images/train_detail.jpg
--------------------------------------------------------------------------------
/images/train_detail1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aigc-apps/sd-webui-EasyPhoto/6c1f8959c4024b0f8a1187fae7c50c0c8f86a8df/images/train_detail1.jpg
--------------------------------------------------------------------------------
/images/train_ui.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aigc-apps/sd-webui-EasyPhoto/6c1f8959c4024b0f8a1187fae7c50c0c8f86a8df/images/train_ui.jpg
--------------------------------------------------------------------------------
/images/tryon/cloth/demo_black_200.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aigc-apps/sd-webui-EasyPhoto/6c1f8959c4024b0f8a1187fae7c50c0c8f86a8df/images/tryon/cloth/demo_black_200.jpg
--------------------------------------------------------------------------------
/images/tryon/cloth/demo_dress_200.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aigc-apps/sd-webui-EasyPhoto/6c1f8959c4024b0f8a1187fae7c50c0c8f86a8df/images/tryon/cloth/demo_dress_200.jpg
--------------------------------------------------------------------------------
/images/tryon/cloth/demo_purple_200.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aigc-apps/sd-webui-EasyPhoto/6c1f8959c4024b0f8a1187fae7c50c0c8f86a8df/images/tryon/cloth/demo_purple_200.jpg
--------------------------------------------------------------------------------
/images/tryon/cloth/demo_short_200.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aigc-apps/sd-webui-EasyPhoto/6c1f8959c4024b0f8a1187fae7c50c0c8f86a8df/images/tryon/cloth/demo_short_200.jpg
--------------------------------------------------------------------------------
/images/tryon/cloth/demo_white_200.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aigc-apps/sd-webui-EasyPhoto/6c1f8959c4024b0f8a1187fae7c50c0c8f86a8df/images/tryon/cloth/demo_white_200.jpg
--------------------------------------------------------------------------------
/images/tryon/template/boy.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aigc-apps/sd-webui-EasyPhoto/6c1f8959c4024b0f8a1187fae7c50c0c8f86a8df/images/tryon/template/boy.jpg
--------------------------------------------------------------------------------
/images/tryon/template/dress.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aigc-apps/sd-webui-EasyPhoto/6c1f8959c4024b0f8a1187fae7c50c0c8f86a8df/images/tryon/template/dress.jpg
--------------------------------------------------------------------------------
/images/tryon/template/girl.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aigc-apps/sd-webui-EasyPhoto/6c1f8959c4024b0f8a1187fae7c50c0c8f86a8df/images/tryon/template/girl.jpg
--------------------------------------------------------------------------------
/images/tryon/template/short.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aigc-apps/sd-webui-EasyPhoto/6c1f8959c4024b0f8a1187fae7c50c0c8f86a8df/images/tryon/template/short.jpg
--------------------------------------------------------------------------------
/images/wechat.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aigc-apps/sd-webui-EasyPhoto/6c1f8959c4024b0f8a1187fae7c50c0c8f86a8df/images/wechat.jpg
--------------------------------------------------------------------------------
/install.py:
--------------------------------------------------------------------------------
1 | # Package check util
2 | # Modified from https://github.com/Bing-su/adetailer/blob/main/install.py
3 | import importlib.util
4 | import platform
5 | from importlib.metadata import version
6 |
7 | import launch
8 | from packaging.version import parse
9 |
10 |
11 | def is_installed(package: str):
12 | min_version = "0.0.0"
13 | max_version = "99999999.99999999.99999999"
14 | pkg_name = package
15 | version_check = True
16 | if "==" in package:
17 | pkg_name, _version = package.split("==")
18 | min_version = max_version = _version
19 | elif "<=" in package:
20 | pkg_name, _version = package.split("<=")
21 | max_version = _version
22 | elif ">=" in package:
23 | pkg_name, _version = package.split(">=")
24 | min_version = _version
25 | else:
26 | version_check = False
27 | package = pkg_name
28 | try:
29 | spec = importlib.util.find_spec(package)
30 | except ModuleNotFoundError:
31 | message = f"is_installed check for {str(package)} failed as error ModuleNotFoundError"
32 | print(message)
33 | return False
34 | if spec is None:
35 | message = f"is_installed check for {str(package)} failed as 'spec is None'"
36 | print(message)
37 | return False
38 | if not version_check:
39 | return True
40 | if package == "google.protobuf":
41 | package = "protobuf"
42 | try:
43 | pkg_version = version(package)
44 | return parse(min_version) <= parse(pkg_version) <= parse(max_version)
45 | except Exception as e:
46 | message = f"is_installed check for {str(package)} failed as error {str(e)}"
47 | print(message)
48 | return False
49 |
50 |
51 | # End of Package check util
52 |
53 | if not is_installed("cv2"):
54 | print("Installing requirements for easyphoto-webui")
55 | launch.run_pip("install opencv-python", "requirements for opencv")
56 |
57 | if not is_installed("tensorflow-cpu"):
58 | print("Installing requirements for easyphoto-webui")
59 | launch.run_pip("install tensorflow-cpu", "requirements for tensorflow")
60 |
61 | if not is_installed("onnx"):
62 | print("Installing requirements for easyphoto-webui")
63 | launch.run_pip("install onnx", "requirements for onnx")
64 |
65 | if not is_installed("onnxruntime"):
66 | print("Installing requirements for easyphoto-webui")
67 | launch.run_pip("install onnxruntime", "requirements for onnxruntime")
68 |
69 | if not is_installed("modelscope==1.9.3"):
70 | print("Installing requirements for easyphoto-webui")
71 | launch.run_pip("install modelscope==1.9.3", "requirements for modelscope")
72 |
73 | if not is_installed("einops"):
74 | print("Installing requirements for easyphoto-webui")
75 | launch.run_pip("install einops", "requirements for diffusers")
76 |
77 | if not is_installed("imageio>=2.29.0"):
78 | print("Installing requirements for easyphoto-webui")
79 | # The '>' will be interpreted as redirection (in linux) since SD WebUI uses `shell=True` in `subprocess.run`.
80 | launch.run_pip("install \"imageio>=2.29.0\"", "requirements for imageio")
81 |
82 | if not is_installed("av"):
83 | print("Installing requirements for easyphoto-webui")
84 | launch.run_pip("install \"imageio[pyav]\"", "requirements for av")
85 |
86 | # Temporarily pin fsspec==2023.9.2. See https://github.com/huggingface/datasets/issues/6330 for details.
87 | if not is_installed("fsspec==2023.9.2"):
88 | print("Installing requirements for easyphoto-webui")
89 | launch.run_pip("install fsspec==2023.9.2", "requirements for fsspec")
90 |
91 | # `StableDiffusionXLPipeline` in diffusers requires the invisible-watermark library.
92 | if not launch.is_installed("invisible-watermark"):
93 | print("Installing requirements for easyphoto-webui")
94 | launch.run_pip("install invisible-watermark", "requirements for invisible-watermark")
95 |
96 | # Tryon requires the shapely and segment-anything library.
97 | if not launch.is_installed("shapely"):
98 | print("Installing requirements for easyphoto-webui")
99 | launch.run_pip("install shapely", "requirements for shapely")
100 |
101 | if not launch.is_installed("segment_anything"):
102 | try:
103 | launch.run_pip("install segment-anything", "requirements for segment_anything")
104 | except Exception:
105 | print("Can't install segment-anything. Please follow the readme to install manually")
106 |
107 | if not is_installed("diffusers>=0.18.2"):
108 | print("Installing requirements for easyphoto-webui")
109 | try:
110 | launch.run_pip("install diffusers==0.23.0", "requirements for diffusers")
111 | except Exception as e:
112 | print(f"Can't install the diffusers==0.23.0. Error info {e}")
113 | launch.run_pip("install diffusers==0.18.2", "requirements for diffusers")
114 |
115 | if platform.system() != "Windows":
116 | if not is_installed("nvitop"):
117 | print("Installing requirements for easyphoto-webui")
118 | launch.run_pip("install nvitop==1.3.0", "requirements for tensorflow")
119 |
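
Note: is_installed accepts an optional version specifier: "==" pins both bounds, "<=" sets only the upper bound, ">=" sets only the lower bound, and a bare name skips the version check. A small illustration of the comparison it performs once the package is importable (the version strings are hypothetical):

from packaging.version import parse

# is_installed("diffusers>=0.18.2") reduces to this check after the import probe succeeds:
installed = "0.23.0"  # hypothetical installed version
min_version, max_version = "0.18.2", "99999999.99999999.99999999"
assert parse(min_version) <= parse(installed) <= parse(max_version)

# is_installed("modelscope==1.9.3") pins both bounds to the same value:
min_version = max_version = "1.9.3"
assert parse(min_version) <= parse("1.9.3") <= parse(max_version)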
--------------------------------------------------------------------------------
/javascript/ui.js:
--------------------------------------------------------------------------------
1 | function ask_for_style_name(sd_model_checkpoint, dummy_component, _, train_mode_choose, resolution, val_and_checkpointing_steps, max_train_steps, steps_per_photos, train_batch_size, gradient_accumulation_steps, dataloader_num_workers, learning_rate, rank, network_alpha, validation, instance_images, enable_rl, max_rl_time, timestep_fraction, skin_retouching, training_prefix_prompt, crop_ratio) {
2 | var name_ = prompt('User id:');
3 | return [sd_model_checkpoint, dummy_component, name_, train_mode_choose, resolution, val_and_checkpointing_steps, max_train_steps, steps_per_photos, train_batch_size, gradient_accumulation_steps, dataloader_num_workers, learning_rate, rank, network_alpha, validation, instance_images, enable_rl, max_rl_time, timestep_fraction, skin_retouching, training_prefix_prompt, crop_ratio];
4 | }
5 |
6 | function switch_to_ep_photoinfer_upload() {
7 | gradioApp().getElementById('mode_easyphoto').querySelectorAll('button')[1].click();
8 | gradioApp().getElementById('mode_easyphoto_photo_inference').querySelectorAll('button')[0].click();
9 |
10 | return Array.from(arguments);
11 | }
12 |
13 | function switch_to_ep_tryon() {
14 | gradioApp().getElementById('mode_easyphoto').querySelectorAll('button')[3].click();
15 |
16 | return Array.from(arguments);
17 | }
18 |
--------------------------------------------------------------------------------
/models/infer_templates/1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aigc-apps/sd-webui-EasyPhoto/6c1f8959c4024b0f8a1187fae7c50c0c8f86a8df/models/infer_templates/1.jpg
--------------------------------------------------------------------------------
/models/infer_templates/3.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aigc-apps/sd-webui-EasyPhoto/6c1f8959c4024b0f8a1187fae7c50c0c8f86a8df/models/infer_templates/3.jpg
--------------------------------------------------------------------------------
/models/infer_templates/4.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aigc-apps/sd-webui-EasyPhoto/6c1f8959c4024b0f8a1187fae7c50c0c8f86a8df/models/infer_templates/4.jpg
--------------------------------------------------------------------------------
/models/infer_templates/5.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aigc-apps/sd-webui-EasyPhoto/6c1f8959c4024b0f8a1187fae7c50c0c8f86a8df/models/infer_templates/5.jpg
--------------------------------------------------------------------------------
/models/infer_templates/6.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aigc-apps/sd-webui-EasyPhoto/6c1f8959c4024b0f8a1187fae7c50c0c8f86a8df/models/infer_templates/6.jpg
--------------------------------------------------------------------------------
/models/infer_templates/7.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aigc-apps/sd-webui-EasyPhoto/6c1f8959c4024b0f8a1187fae7c50c0c8f86a8df/models/infer_templates/7.jpg
--------------------------------------------------------------------------------
/models/infer_templates/8.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aigc-apps/sd-webui-EasyPhoto/6c1f8959c4024b0f8a1187fae7c50c0c8f86a8df/models/infer_templates/8.jpg
--------------------------------------------------------------------------------
/models/infer_templates/9.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aigc-apps/sd-webui-EasyPhoto/6c1f8959c4024b0f8a1187fae7c50c0c8f86a8df/models/infer_templates/9.jpg
--------------------------------------------------------------------------------
/models/infer_templates/Put templates here:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aigc-apps/sd-webui-EasyPhoto/6c1f8959c4024b0f8a1187fae7c50c0c8f86a8df/models/infer_templates/Put templates here
--------------------------------------------------------------------------------
/models/stable-diffusion-v1-5/model_index.json:
--------------------------------------------------------------------------------
1 | {
2 | "_class_name": "StableDiffusionPipeline",
3 | "_diffusers_version": "0.6.0",
4 | "feature_extractor": [
5 | "transformers",
6 | "CLIPImageProcessor"
7 | ],
8 | "safety_checker": [
9 | "stable_diffusion",
10 | "StableDiffusionSafetyChecker"
11 | ],
12 | "scheduler": [
13 | "diffusers",
14 | "PNDMScheduler"
15 | ],
16 | "text_encoder": [
17 | "transformers",
18 | "CLIPTextModel"
19 | ],
20 | "tokenizer": [
21 | "transformers",
22 | "CLIPTokenizer"
23 | ],
24 | "unet": [
25 | "diffusers",
26 | "UNet2DConditionModel"
27 | ],
28 | "vae": [
29 | "diffusers",
30 | "AutoencoderKL"
31 | ]
32 | }
33 |
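
Note: model_index.json tells diffusers which class implements each pipeline component (scheduler, text encoder, tokenizer, UNet, VAE). A minimal loading sketch, assuming the remaining component weights have been downloaded into the same directory (only the configs and tokenizer files ship with this repository):

from diffusers import StableDiffusionPipeline

# from_pretrained reads model_index.json to decide which class to instantiate
# for each sub-component listed in the file (PNDMScheduler, CLIPTextModel, ...).
pipe = StableDiffusionPipeline.from_pretrained("models/stable-diffusion-v1-5")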
--------------------------------------------------------------------------------
/models/stable-diffusion-v1-5/scheduler/scheduler_config.json:
--------------------------------------------------------------------------------
1 | {
2 | "_class_name": "PNDMScheduler",
3 | "_diffusers_version": "0.6.0",
4 | "beta_end": 0.012,
5 | "beta_schedule": "scaled_linear",
6 | "beta_start": 0.00085,
7 | "num_train_timesteps": 1000,
8 | "set_alpha_to_one": false,
9 | "skip_prk_steps": true,
10 | "steps_offset": 1,
11 | "trained_betas": null,
12 | "clip_sample": false
13 | }
14 |
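
Note: the scheduler is defined entirely by this config, so it can be reconstructed without any model weights. A minimal sketch, assuming diffusers is installed:

from diffusers import PNDMScheduler

# subfolder="scheduler" points at the directory holding scheduler_config.json.
scheduler = PNDMScheduler.from_pretrained("models/stable-diffusion-v1-5", subfolder="scheduler")
print(scheduler.config.num_train_timesteps)  # 1000, as specified above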
--------------------------------------------------------------------------------
/models/stable-diffusion-v1-5/tokenizer/special_tokens_map.json:
--------------------------------------------------------------------------------
1 | {
2 | "bos_token": {
3 | "content": "<|startoftext|>",
4 | "lstrip": false,
5 | "normalized": true,
6 | "rstrip": false,
7 | "single_word": false
8 | },
9 | "eos_token": {
10 | "content": "<|endoftext|>",
11 | "lstrip": false,
12 | "normalized": true,
13 | "rstrip": false,
14 | "single_word": false
15 | },
16 | "pad_token": "<|endoftext|>",
17 | "unk_token": {
18 | "content": "<|endoftext|>",
19 | "lstrip": false,
20 | "normalized": true,
21 | "rstrip": false,
22 | "single_word": false
23 | }
24 | }
25 |
--------------------------------------------------------------------------------
/models/stable-diffusion-v1-5/tokenizer/tokenizer_config.json:
--------------------------------------------------------------------------------
1 | {
2 | "add_prefix_space": false,
3 | "bos_token": {
4 | "__type": "AddedToken",
5 | "content": "<|startoftext|>",
6 | "lstrip": false,
7 | "normalized": true,
8 | "rstrip": false,
9 | "single_word": false
10 | },
11 | "do_lower_case": true,
12 | "eos_token": {
13 | "__type": "AddedToken",
14 | "content": "<|endoftext|>",
15 | "lstrip": false,
16 | "normalized": true,
17 | "rstrip": false,
18 | "single_word": false
19 | },
20 | "errors": "replace",
21 | "model_max_length": 77,
22 | "name_or_path": "openai/clip-vit-large-patch14",
23 | "pad_token": "<|endoftext|>",
24 | "special_tokens_map_file": "./special_tokens_map.json",
25 | "tokenizer_class": "CLIPTokenizer",
26 | "unk_token": {
27 | "__type": "AddedToken",
28 | "content": "<|endoftext|>",
29 | "lstrip": false,
30 | "normalized": true,
31 | "rstrip": false,
32 | "single_word": false
33 | }
34 | }
35 |
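
Note: together with merges.txt and vocab.json in the same directory, these two files are everything CLIPTokenizer needs. A minimal sketch, assuming transformers is installed:

from transformers import CLIPTokenizer

# Load the CLIP tokenizer directly from the files shipped in this repository.
tokenizer = CLIPTokenizer.from_pretrained("models/stable-diffusion-v1-5/tokenizer")
ids = tokenizer("a portrait photo of a person", padding="max_length", max_length=77).input_ids
print(len(ids))  # 77, the model_max_length from tokenizer_config.json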
--------------------------------------------------------------------------------
/models/stable-diffusion-xl/madebyollin_sdxl_vae_fp16_fix/config.json:
--------------------------------------------------------------------------------
1 | {
2 | "_class_name": "AutoencoderKL",
3 | "_diffusers_version": "0.18.0.dev0",
4 | "_name_or_path": ".",
5 | "act_fn": "silu",
6 | "block_out_channels": [
7 | 128,
8 | 256,
9 | 512,
10 | 512
11 | ],
12 | "down_block_types": [
13 | "DownEncoderBlock2D",
14 | "DownEncoderBlock2D",
15 | "DownEncoderBlock2D",
16 | "DownEncoderBlock2D"
17 | ],
18 | "in_channels": 3,
19 | "latent_channels": 4,
20 | "layers_per_block": 2,
21 | "norm_num_groups": 32,
22 | "out_channels": 3,
23 | "sample_size": 512,
24 | "scaling_factor": 0.13025,
25 | "up_block_types": [
26 | "UpDecoderBlock2D",
27 | "UpDecoderBlock2D",
28 | "UpDecoderBlock2D",
29 | "UpDecoderBlock2D"
30 | ],
31 | "force_upcast": false
32 | }
33 |
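
Note: only the config of the fp16-fix SDXL VAE is stored here; the weights come from the upstream checkpoint the directory name refers to. A minimal loading sketch with diffusers (the Hub repo id madebyollin/sdxl-vae-fp16-fix is inferred from the directory name and should be treated as an assumption):

import torch
from diffusers import AutoencoderKL

# force_upcast: false in the config means this VAE can run in fp16 without overflowing.
vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16)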
--------------------------------------------------------------------------------
/models/stable-diffusion-xl/models--laion--CLIP-ViT-bigG-14-laion2B-39B-b160k/refs/main:
--------------------------------------------------------------------------------
1 | 8c7a3583335de4dba1b07182dbf81c75137ce67b
--------------------------------------------------------------------------------
/models/stable-diffusion-xl/models--laion--CLIP-ViT-bigG-14-laion2B-39B-b160k/snapshots/8c7a3583335de4dba1b07182dbf81c75137ce67b/config.json:
--------------------------------------------------------------------------------
1 | {
2 | "_commit_hash": null,
3 | "architectures": [
4 | "CLIPModel"
5 | ],
6 | "initializer_factor": 1.0,
7 | "logit_scale_init_value": 2.6592,
8 | "model_type": "clip",
9 | "projection_dim": 1280,
10 | "text_config": {
11 | "_name_or_path": "",
12 | "add_cross_attention": false,
13 | "architectures": null,
14 | "attention_dropout": 0.0,
15 | "bad_words_ids": null,
16 | "begin_suppress_tokens": null,
17 | "bos_token_id": 0,
18 | "chunk_size_feed_forward": 0,
19 | "cross_attention_hidden_size": null,
20 | "decoder_start_token_id": null,
21 | "diversity_penalty": 0.0,
22 | "do_sample": false,
23 | "dropout": 0.0,
24 | "early_stopping": false,
25 | "encoder_no_repeat_ngram_size": 0,
26 | "eos_token_id": 2,
27 | "exponential_decay_length_penalty": null,
28 | "finetuning_task": null,
29 | "forced_bos_token_id": null,
30 | "forced_eos_token_id": null,
31 | "hidden_act": "gelu",
32 | "hidden_size": 1280,
33 | "id2label": {
34 | "0": "LABEL_0",
35 | "1": "LABEL_1"
36 | },
37 | "initializer_factor": 1.0,
38 | "initializer_range": 0.02,
39 | "intermediate_size": 5120,
40 | "is_decoder": false,
41 | "is_encoder_decoder": false,
42 | "label2id": {
43 | "LABEL_0": 0,
44 | "LABEL_1": 1
45 | },
46 | "layer_norm_eps": 1e-05,
47 | "length_penalty": 1.0,
48 | "max_length": 20,
49 | "max_position_embeddings": 77,
50 | "min_length": 0,
51 | "model_type": "clip_text_model",
52 | "no_repeat_ngram_size": 0,
53 | "num_attention_heads": 20,
54 | "num_beam_groups": 1,
55 | "num_beams": 1,
56 | "num_hidden_layers": 32,
57 | "num_return_sequences": 1,
58 | "output_attentions": false,
59 | "output_hidden_states": false,
60 | "output_scores": false,
61 | "pad_token_id": 1,
62 | "prefix": null,
63 | "problem_type": null,
64 | "pruned_heads": {},
65 | "remove_invalid_values": false,
66 | "repetition_penalty": 1.0,
67 | "return_dict": true,
68 | "return_dict_in_generate": false,
69 | "sep_token_id": null,
70 | "suppress_tokens": null,
71 | "task_specific_params": null,
72 | "temperature": 1.0,
73 | "tf_legacy_loss": false,
74 | "tie_encoder_decoder": false,
75 | "tie_word_embeddings": true,
76 | "tokenizer_class": null,
77 | "top_k": 50,
78 | "top_p": 1.0,
79 | "torch_dtype": null,
80 | "torchscript": false,
81 | "transformers_version": "4.24.0",
82 | "typical_p": 1.0,
83 | "use_bfloat16": false,
84 | "vocab_size": 49408
85 | },
86 | "text_config_dict": {
87 | "hidden_act": "gelu",
88 | "hidden_size": 1280,
89 | "intermediate_size": 5120,
90 | "num_attention_heads": 20,
91 | "num_hidden_layers": 32
92 | },
93 | "torch_dtype": "float32",
94 | "transformers_version": null,
95 | "vision_config": {
96 | "_name_or_path": "",
97 | "add_cross_attention": false,
98 | "architectures": null,
99 | "attention_dropout": 0.0,
100 | "bad_words_ids": null,
101 | "begin_suppress_tokens": null,
102 | "bos_token_id": null,
103 | "chunk_size_feed_forward": 0,
104 | "cross_attention_hidden_size": null,
105 | "decoder_start_token_id": null,
106 | "diversity_penalty": 0.0,
107 | "do_sample": false,
108 | "dropout": 0.0,
109 | "early_stopping": false,
110 | "encoder_no_repeat_ngram_size": 0,
111 | "eos_token_id": null,
112 | "exponential_decay_length_penalty": null,
113 | "finetuning_task": null,
114 | "forced_bos_token_id": null,
115 | "forced_eos_token_id": null,
116 | "hidden_act": "gelu",
117 | "hidden_size": 1664,
118 | "id2label": {
119 | "0": "LABEL_0",
120 | "1": "LABEL_1"
121 | },
122 | "image_size": 224,
123 | "initializer_factor": 1.0,
124 | "initializer_range": 0.02,
125 | "intermediate_size": 8192,
126 | "is_decoder": false,
127 | "is_encoder_decoder": false,
128 | "label2id": {
129 | "LABEL_0": 0,
130 | "LABEL_1": 1
131 | },
132 | "layer_norm_eps": 1e-05,
133 | "length_penalty": 1.0,
134 | "max_length": 20,
135 | "min_length": 0,
136 | "model_type": "clip_vision_model",
137 | "no_repeat_ngram_size": 0,
138 | "num_attention_heads": 16,
139 | "num_beam_groups": 1,
140 | "num_beams": 1,
141 | "num_channels": 3,
142 | "num_hidden_layers": 48,
143 | "num_return_sequences": 1,
144 | "output_attentions": false,
145 | "output_hidden_states": false,
146 | "output_scores": false,
147 | "pad_token_id": null,
148 | "patch_size": 14,
149 | "prefix": null,
150 | "problem_type": null,
151 | "pruned_heads": {},
152 | "remove_invalid_values": false,
153 | "repetition_penalty": 1.0,
154 | "return_dict": true,
155 | "return_dict_in_generate": false,
156 | "sep_token_id": null,
157 | "suppress_tokens": null,
158 | "task_specific_params": null,
159 | "temperature": 1.0,
160 | "tf_legacy_loss": false,
161 | "tie_encoder_decoder": false,
162 | "tie_word_embeddings": true,
163 | "tokenizer_class": null,
164 | "top_k": 50,
165 | "top_p": 1.0,
166 | "torch_dtype": null,
167 | "torchscript": false,
168 | "transformers_version": "4.24.0",
169 | "typical_p": 1.0,
170 | "use_bfloat16": false
171 | },
172 | "vision_config_dict": {
173 | "hidden_act": "gelu",
174 | "hidden_size": 1664,
175 | "intermediate_size": 8192,
176 | "num_attention_heads": 16,
177 | "num_hidden_layers": 48,
178 | "patch_size": 14
179 | }
180 | }
181 |
--------------------------------------------------------------------------------
/models/stable-diffusion-xl/models--laion--CLIP-ViT-bigG-14-laion2B-39B-b160k/snapshots/8c7a3583335de4dba1b07182dbf81c75137ce67b/open_clip_config.json:
--------------------------------------------------------------------------------
1 | {
2 | "model_cfg": {
3 | "embed_dim": 1280,
4 | "vision_cfg": {
5 | "image_size": 224,
6 | "layers": 48,
7 | "width": 1664,
8 | "head_width": 104,
9 | "mlp_ratio": 4.9231,
10 | "patch_size": 14
11 | },
12 | "text_cfg": {
13 | "context_length": 77,
14 | "vocab_size": 49408,
15 | "width": 1280,
16 | "heads": 20,
17 | "layers": 32
18 | }
19 | },
20 | "preprocess_cfg": {
21 | "mean": [
22 | 0.48145466,
23 | 0.4578275,
24 | 0.40821073
25 | ],
26 | "std": [
27 | 0.26862954,
28 | 0.26130258,
29 | 0.27577711
30 | ]
31 | }
32 | }
--------------------------------------------------------------------------------
/models/stable-diffusion-xl/models--laion--CLIP-ViT-bigG-14-laion2B-39B-b160k/snapshots/8c7a3583335de4dba1b07182dbf81c75137ce67b/preprocessor_config.json:
--------------------------------------------------------------------------------
1 | {
2 | "crop_size": 224,
3 | "do_center_crop": true,
4 | "do_normalize": true,
5 | "do_resize": true,
6 | "feature_extractor_type": "CLIPFeatureExtractor",
7 | "image_mean": [
8 | 0.48145466,
9 | 0.4578275,
10 | 0.40821073
11 | ],
12 | "image_std": [
13 | 0.26862954,
14 | 0.26130258,
15 | 0.27577711
16 | ],
17 | "resample": 3,
18 | "size": 224
19 | }
20 |
--------------------------------------------------------------------------------
/models/stable-diffusion-xl/models--laion--CLIP-ViT-bigG-14-laion2B-39B-b160k/snapshots/8c7a3583335de4dba1b07182dbf81c75137ce67b/special_tokens_map.json:
--------------------------------------------------------------------------------
1 | {"bos_token": {"content": "<|startoftext|>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true}, "eos_token": {"content": "<|endoftext|>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true}, "unk_token": {"content": "<|endoftext|>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true}, "pad_token": "<|endoftext|>"}
--------------------------------------------------------------------------------
/models/stable-diffusion-xl/models--laion--CLIP-ViT-bigG-14-laion2B-39B-b160k/snapshots/8c7a3583335de4dba1b07182dbf81c75137ce67b/tokenizer_config.json:
--------------------------------------------------------------------------------
1 | {
2 | "unk_token": {
3 | "content": "<|endoftext|>",
4 | "single_word": false,
5 | "lstrip": false,
6 | "rstrip": false,
7 | "normalized": true,
8 | "__type": "AddedToken"
9 | },
10 | "bos_token": {
11 | "content": "<|startoftext|>",
12 | "single_word": false,
13 | "lstrip": false,
14 | "rstrip": false,
15 | "normalized": true,
16 | "__type": "AddedToken"
17 | },
18 | "eos_token": {
19 | "content": "<|endoftext|>",
20 | "single_word": false,
21 | "lstrip": false,
22 | "rstrip": false,
23 | "normalized": true,
24 | "__type": "AddedToken"
25 | },
26 | "pad_token": "<|endoftext|>",
27 | "add_prefix_space": false,
28 | "errors": "replace",
29 | "do_lower_case": true,
30 | "name_or_path": "openai/clip-vit-base-patch32",
31 | "model_max_length": 77,
32 | "special_tokens_map_file": "./special_tokens_map.json",
33 | "tokenizer_class": "CLIPTokenizer"
34 | }
--------------------------------------------------------------------------------
/models/stable-diffusion-xl/models--openai--clip-vit-large-patch14/refs/main:
--------------------------------------------------------------------------------
1 | 32bd64288804d66eefd0ccbe215aa642df71cc41
--------------------------------------------------------------------------------
/models/stable-diffusion-xl/models--openai--clip-vit-large-patch14/snapshots/32bd64288804d66eefd0ccbe215aa642df71cc41/config.json:
--------------------------------------------------------------------------------
1 | {
2 | "_name_or_path": "clip-vit-large-patch14/",
3 | "architectures": [
4 | "CLIPModel"
5 | ],
6 | "initializer_factor": 1.0,
7 | "logit_scale_init_value": 2.6592,
8 | "model_type": "clip",
9 | "projection_dim": 768,
10 | "text_config": {
11 | "_name_or_path": "",
12 | "add_cross_attention": false,
13 | "architectures": null,
14 | "attention_dropout": 0.0,
15 | "bad_words_ids": null,
16 | "bos_token_id": 0,
17 | "chunk_size_feed_forward": 0,
18 | "cross_attention_hidden_size": null,
19 | "decoder_start_token_id": null,
20 | "diversity_penalty": 0.0,
21 | "do_sample": false,
22 | "dropout": 0.0,
23 | "early_stopping": false,
24 | "encoder_no_repeat_ngram_size": 0,
25 | "eos_token_id": 2,
26 | "finetuning_task": null,
27 | "forced_bos_token_id": null,
28 | "forced_eos_token_id": null,
29 | "hidden_act": "quick_gelu",
30 | "hidden_size": 768,
31 | "id2label": {
32 | "0": "LABEL_0",
33 | "1": "LABEL_1"
34 | },
35 | "initializer_factor": 1.0,
36 | "initializer_range": 0.02,
37 | "intermediate_size": 3072,
38 | "is_decoder": false,
39 | "is_encoder_decoder": false,
40 | "label2id": {
41 | "LABEL_0": 0,
42 | "LABEL_1": 1
43 | },
44 | "layer_norm_eps": 1e-05,
45 | "length_penalty": 1.0,
46 | "max_length": 20,
47 | "max_position_embeddings": 77,
48 | "min_length": 0,
49 | "model_type": "clip_text_model",
50 | "no_repeat_ngram_size": 0,
51 | "num_attention_heads": 12,
52 | "num_beam_groups": 1,
53 | "num_beams": 1,
54 | "num_hidden_layers": 12,
55 | "num_return_sequences": 1,
56 | "output_attentions": false,
57 | "output_hidden_states": false,
58 | "output_scores": false,
59 | "pad_token_id": 1,
60 | "prefix": null,
61 | "problem_type": null,
62 | "projection_dim" : 768,
63 | "pruned_heads": {},
64 | "remove_invalid_values": false,
65 | "repetition_penalty": 1.0,
66 | "return_dict": true,
67 | "return_dict_in_generate": false,
68 | "sep_token_id": null,
69 | "task_specific_params": null,
70 | "temperature": 1.0,
71 | "tie_encoder_decoder": false,
72 | "tie_word_embeddings": true,
73 | "tokenizer_class": null,
74 | "top_k": 50,
75 | "top_p": 1.0,
76 | "torch_dtype": null,
77 | "torchscript": false,
78 | "transformers_version": "4.16.0.dev0",
79 | "use_bfloat16": false,
80 | "vocab_size": 49408
81 | },
82 | "text_config_dict": {
83 | "hidden_size": 768,
84 | "intermediate_size": 3072,
85 | "num_attention_heads": 12,
86 | "num_hidden_layers": 12,
87 | "projection_dim": 768
88 | },
89 | "torch_dtype": "float32",
90 | "transformers_version": null,
91 | "vision_config": {
92 | "_name_or_path": "",
93 | "add_cross_attention": false,
94 | "architectures": null,
95 | "attention_dropout": 0.0,
96 | "bad_words_ids": null,
97 | "bos_token_id": null,
98 | "chunk_size_feed_forward": 0,
99 | "cross_attention_hidden_size": null,
100 | "decoder_start_token_id": null,
101 | "diversity_penalty": 0.0,
102 | "do_sample": false,
103 | "dropout": 0.0,
104 | "early_stopping": false,
105 | "encoder_no_repeat_ngram_size": 0,
106 | "eos_token_id": null,
107 | "finetuning_task": null,
108 | "forced_bos_token_id": null,
109 | "forced_eos_token_id": null,
110 | "hidden_act": "quick_gelu",
111 | "hidden_size": 1024,
112 | "id2label": {
113 | "0": "LABEL_0",
114 | "1": "LABEL_1"
115 | },
116 | "image_size": 224,
117 | "initializer_factor": 1.0,
118 | "initializer_range": 0.02,
119 | "intermediate_size": 4096,
120 | "is_decoder": false,
121 | "is_encoder_decoder": false,
122 | "label2id": {
123 | "LABEL_0": 0,
124 | "LABEL_1": 1
125 | },
126 | "layer_norm_eps": 1e-05,
127 | "length_penalty": 1.0,
128 | "max_length": 20,
129 | "min_length": 0,
130 | "model_type": "clip_vision_model",
131 | "no_repeat_ngram_size": 0,
132 | "num_attention_heads": 16,
133 | "num_beam_groups": 1,
134 | "num_beams": 1,
135 | "num_hidden_layers": 24,
136 | "num_return_sequences": 1,
137 | "output_attentions": false,
138 | "output_hidden_states": false,
139 | "output_scores": false,
140 | "pad_token_id": null,
141 | "patch_size": 14,
142 | "prefix": null,
143 | "problem_type": null,
144 | "projection_dim" : 768,
145 | "pruned_heads": {},
146 | "remove_invalid_values": false,
147 | "repetition_penalty": 1.0,
148 | "return_dict": true,
149 | "return_dict_in_generate": false,
150 | "sep_token_id": null,
151 | "task_specific_params": null,
152 | "temperature": 1.0,
153 | "tie_encoder_decoder": false,
154 | "tie_word_embeddings": true,
155 | "tokenizer_class": null,
156 | "top_k": 50,
157 | "top_p": 1.0,
158 | "torch_dtype": null,
159 | "torchscript": false,
160 | "transformers_version": "4.16.0.dev0",
161 | "use_bfloat16": false
162 | },
163 | "vision_config_dict": {
164 | "hidden_size": 1024,
165 | "intermediate_size": 4096,
166 | "num_attention_heads": 16,
167 | "num_hidden_layers": 24,
168 | "patch_size": 14,
169 | "projection_dim": 768
170 | }
171 | }
172 |
--------------------------------------------------------------------------------
/models/stable-diffusion-xl/models--openai--clip-vit-large-patch14/snapshots/32bd64288804d66eefd0ccbe215aa642df71cc41/preprocessor_config.json:
--------------------------------------------------------------------------------
1 | {
2 | "crop_size": 224,
3 | "do_center_crop": true,
4 | "do_normalize": true,
5 | "do_resize": true,
6 | "feature_extractor_type": "CLIPFeatureExtractor",
7 | "image_mean": [
8 | 0.48145466,
9 | 0.4578275,
10 | 0.40821073
11 | ],
12 | "image_std": [
13 | 0.26862954,
14 | 0.26130258,
15 | 0.27577711
16 | ],
17 | "resample": 3,
18 | "size": 224
19 | }
20 |
--------------------------------------------------------------------------------
/models/stable-diffusion-xl/models--openai--clip-vit-large-patch14/snapshots/32bd64288804d66eefd0ccbe215aa642df71cc41/special_tokens_map.json:
--------------------------------------------------------------------------------
1 | {"bos_token": {"content": "<|startoftext|>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true}, "eos_token": {"content": "<|endoftext|>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true}, "unk_token": {"content": "<|endoftext|>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": true}, "pad_token": "<|endoftext|>"}
--------------------------------------------------------------------------------
/models/stable-diffusion-xl/models--openai--clip-vit-large-patch14/snapshots/32bd64288804d66eefd0ccbe215aa642df71cc41/tokenizer_config.json:
--------------------------------------------------------------------------------
1 | {
2 | "unk_token": {
3 | "content": "<|endoftext|>",
4 | "single_word": false,
5 | "lstrip": false,
6 | "rstrip": false,
7 | "normalized": true,
8 | "__type": "AddedToken"
9 | },
10 | "bos_token": {
11 | "content": "<|startoftext|>",
12 | "single_word": false,
13 | "lstrip": false,
14 | "rstrip": false,
15 | "normalized": true,
16 | "__type": "AddedToken"
17 | },
18 | "eos_token": {
19 | "content": "<|endoftext|>",
20 | "single_word": false,
21 | "lstrip": false,
22 | "rstrip": false,
23 | "normalized": true,
24 | "__type": "AddedToken"
25 | },
26 | "pad_token": "<|endoftext|>",
27 | "add_prefix_space": false,
28 | "errors": "replace",
29 | "do_lower_case": true,
30 | "name_or_path": "openai/clip-vit-base-patch32",
31 | "model_max_length": 77,
32 | "special_tokens_map_file": "./special_tokens_map.json",
33 | "tokenizer_class": "CLIPTokenizer"
34 | }
35 |
--------------------------------------------------------------------------------
/models/stable-diffusion-xl/stabilityai_stable_diffusion_xl_base_1.0/scheduler/scheduler_config.json:
--------------------------------------------------------------------------------
1 | {
2 | "_class_name": "EulerDiscreteScheduler",
3 | "_diffusers_version": "0.19.0.dev0",
4 | "beta_end": 0.012,
5 | "beta_schedule": "scaled_linear",
6 | "beta_start": 0.00085,
7 | "clip_sample": false,
8 | "interpolation_type": "linear",
9 | "num_train_timesteps": 1000,
10 | "prediction_type": "epsilon",
11 | "sample_max_value": 1.0,
12 | "set_alpha_to_one": false,
13 | "skip_prk_steps": true,
14 | "steps_offset": 1,
15 | "timestep_spacing": "leading",
16 | "trained_betas": null,
17 | "use_karras_sigmas": false
18 | }
--------------------------------------------------------------------------------
/models/training_templates/1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aigc-apps/sd-webui-EasyPhoto/6c1f8959c4024b0f8a1187fae7c50c0c8f86a8df/models/training_templates/1.jpg
--------------------------------------------------------------------------------
/models/training_templates/2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aigc-apps/sd-webui-EasyPhoto/6c1f8959c4024b0f8a1187fae7c50c0c8f86a8df/models/training_templates/2.jpg
--------------------------------------------------------------------------------
/models/training_templates/3.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aigc-apps/sd-webui-EasyPhoto/6c1f8959c4024b0f8a1187fae7c50c0c8f86a8df/models/training_templates/3.jpg
--------------------------------------------------------------------------------
/models/training_templates/4.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aigc-apps/sd-webui-EasyPhoto/6c1f8959c4024b0f8a1187fae7c50c0c8f86a8df/models/training_templates/4.jpg
--------------------------------------------------------------------------------
/scripts/easyphoto_config.py:
--------------------------------------------------------------------------------
1 | import os
2 | from modules.paths import data_path, models_path, extensions_builtin_dir, extensions_dir
3 |
4 | # save_dirs
5 | data_dir = data_path
6 | models_path = models_path
7 | extensions_builtin_dir = extensions_builtin_dir
8 | extensions_dir = extensions_dir
9 | easyphoto_models_path = os.path.abspath(os.path.dirname(__file__)).replace("scripts", "models")
10 | easyphoto_img2img_samples = os.path.join(data_dir, "outputs/img2img-images")
11 | easyphoto_txt2img_samples = os.path.join(data_dir, "outputs/txt2img-images")
12 | easyphoto_outpath_samples = os.path.join(data_dir, "outputs/easyphoto-outputs")
13 | easyphoto_video_outpath_samples = os.path.join(data_dir, "outputs/easyphoto-video-outputs")
14 | user_id_outpath_samples = os.path.join(data_dir, "outputs/easyphoto-user-id-infos")
15 | cloth_id_outpath_samples = os.path.join(data_dir, "outputs/easyphoto-cloth-id-infos")
16 | scene_id_outpath_samples = os.path.join(data_dir, "outputs/easyphoto-scene-id-infos")
17 | cache_log_file_path = os.path.join(data_dir, "outputs/easyphoto-tmp/train_kohya_log.txt")
18 |
19 | # gallery_dir
20 | tryon_preview_dir = os.path.join(os.path.abspath(os.path.dirname(__file__)).replace("scripts", "images"), "tryon")
21 | tryon_gallery_dir = os.path.join(cloth_id_outpath_samples, "gallery")
22 |
23 | # prompts
24 | validation_prompt = "easyphoto_face, easyphoto, 1person"
25 | validation_prompt_scene = "special_scene, scene"
26 | validation_tryon_prompt = "easyphoto, 1thing"
27 | DEFAULT_POSITIVE = "(cloth:1.5), (best quality), (realistic, photo-realistic:1.3), (beautiful eyes:1.3), (sparkling eyes:1.3), (beautiful mouth:1.3), finely detail, light smile, extremely detailed CG unity 8k wallpaper, huge filesize, best quality, realistic, photo-realistic, ultra high res, raw photo, put on makeup"
28 | DEFAULT_NEGATIVE = "(bags under the eyes:1.5), (bags under eyes:1.5), (earrings:1.3), (glasses:1.2), (naked:1.5), (nsfw:1.5), nude, breasts, penis, cum, (over red lips: 1.3), (bad lips: 1.3), (bad ears:1.3), (bad hair: 1.3), (bad teeth: 1.3), (worst quality:2), (low quality:2), (normal quality:2), lowres, watermark, badhand, lowres, bad anatomy, bad hands, normal quality, mural,"
29 | DEFAULT_POSITIVE_AD = "(realistic, photorealistic), (masterpiece, best quality, high quality), (delicate eyes and face), extremely detailed CG unity 8k wallpaper, best quality, realistic, photo-realistic, ultra high res, raw photo"
30 | DEFAULT_NEGATIVE_AD = "(naked:1.2), (nsfw:1.2), nipple slip, nude, breasts, (huge breasts:1.2), penis, cum, (blurry background:1.3), (depth of field:1.7), (holding:2), (worst quality:2), (normal quality:2), lowres, bad anatomy, bad hands"
31 | DEFAULT_POSITIVE_T2I = "(cloth:1.0), (best quality), (realistic, photo-realistic:1.3), film photography, minor acne, (portrait:1.1), (indirect lighting), extremely detailed CG unity 8k wallpaper, huge filesize, best quality, realistic, photo-realistic, ultra high res, raw photo, put on makeup"
32 | DEFAULT_NEGATIVE_T2I = "(nsfw:1.5), (huge breast:1.5), nude, breasts, penis, cum, bokeh, cgi, illustration, cartoon, deformed, distorted, disfigured, poorly drawn, bad anatomy, wrong anatomy, ugly, deformed, blurry, Noisy, log, text (worst quality:2), (low quality:2), (normal quality:2), lowres, watermark, badhand, lowres"
33 |
34 | # scene lora
35 | DEFAULT_SCENE_LORA = [
36 | "Christmas_1",
37 | "Cyberpunk_1",
38 | "FairMaidenStyle_1",
39 | "Gentleman_1",
40 | "GuoFeng_1",
41 | "GuoFeng_2",
42 | "GuoFeng_3",
43 | "GuoFeng_4",
44 | "Minimalism_1",
45 | "NaturalWind_1",
46 | "Princess_1",
47 | "Princess_2",
48 | "Princess_3",
49 | "SchoolUniform_1",
50 | "SchoolUniform_2",
51 | ]
52 |
53 | # tryon template
54 | DEFAULT_TRYON_TEMPLATE = ["boy", "girl", "dress", "short"]
55 |
56 | # cloth lora
57 | DEFAULT_CLOTH_LORA = ["demo_black_200", "demo_white_200", "demo_purple_200", "demo_dress_200", "demo_short_200"]
58 |
59 | # sliders
60 | DEFAULT_SLIDERS = ["age_sd1_sliders", "smiling_sd1_sliders", "age_sdxl_sliders", "smiling_sdxl_sliders"]
61 |
62 | # ModelName
63 | SDXL_MODEL_NAME = "SDXL_1.0_ArienMixXL_v2.0.safetensors"
64 |
--------------------------------------------------------------------------------
/scripts/easyphoto_utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .face_process_utils import (
2 | Face_Skin,
3 | alignment_photo,
4 | call_face_crop,
5 | call_face_crop_templates,
6 | color_transfer,
7 | crop_and_paste,
8 | safe_get_box_mask_keypoints_and_padding_image,
9 | )
10 | from .fire_utils import FIRE_forward
11 | from .psgan_utils import PSGAN_Inference
12 | from .tryon_utils import (
13 | align_and_overlay_images,
14 | apply_mask_to_image,
15 | compute_rotation_angle,
16 | copy_white_mask_to_template,
17 | crop_image,
18 | expand_box_by_pad,
19 | expand_roi,
20 | find_best_angle_ratio,
21 | get_background_color,
22 | mask_to_box,
23 | mask_to_polygon,
24 | merge_with_inner_canny,
25 | prepare_tryon_train_data,
26 | resize_and_stretch,
27 | resize_image_with_pad,
28 | seg_by_box,
29 | find_connected_components,
30 | )
31 |
32 | from .common_utils import (
33 | check_files_exists_and_download,
34 | check_id_valid,
35 | check_scene_valid,
36 | convert_to_video,
37 | ep_logger,
38 | get_controlnet_version,
39 | get_mov_all_images,
40 | modelscope_models_to_cpu,
41 | modelscope_models_to_gpu,
42 | switch_ms_model_cpu,
43 | unload_models,
44 | seed_everything,
45 | get_attribute_edit_ids,
46 | encode_video_to_base64,
47 | decode_base64_to_video,
48 | cleanup_decorator,
49 | auto_to_gpu_model,
50 | )
51 | from .loractl_utils import check_loractl_conflict, LoraCtlScript
52 | from .animatediff_utils import (
53 | AnimateDiffControl,
54 | AnimateDiffI2VLatent,
55 | AnimateDiffInfV2V,
56 | AnimateDiffLora,
57 | AnimateDiffMM,
58 | AnimateDiffOutput,
59 | AnimateDiffProcess,
60 | AnimateDiffPromptSchedule,
61 | AnimateDiffUiGroup,
62 | animatediff_i2ibatch,
63 | motion_module,
64 | update_infotext,
65 | video_visible,
66 | )
67 |
--------------------------------------------------------------------------------
/scripts/easyphoto_utils/animatediff/README.MD:
--------------------------------------------------------------------------------
1 | ## How to modify animatediff
2 | 1. Copy motion_module.py from the repository root into the scripts folder.
3 | 2. For the files under scripts, change the absolute imports inside each .py file to relative imports, for example:
4 | ```
5 | scripts.animatediff_logger => .animatediff_logger
6 | ```
7 | That is all (a bulk-rewrite sketch is appended at the end of this file).
8 |
9 | ## Hacked code
10 | There is not much hacked code; it mainly lives in animatediff_utils.py:
11 |
12 | 1. AnimateDiffControl is rewritten so that the batch of images to process can be obtained from a video.
13 | 2. AnimateDiffMM is rewritten to load the motion module model.
14 | 3. AnimateDiffProcess and AnimateDiffI2VLatent are rewritten so that image2video preserves a reasonable amount of the source features.
15 | 4. AnimateDiffScript is rewritten so that easyphoto can call it more simply.
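16 |
17 | As a rough illustration (not part of the upstream extension), a small helper like the one below could apply the import rewrite in bulk; the folder path and the motion_module handling are assumptions:
18 | ```
19 | import re
20 | from pathlib import Path
21 |
22 | # Illustrative sketch: turn `scripts.animatediff_*` (and `scripts.motion_module`)
23 | # absolute imports into relative imports for every .py file in this folder.
24 | for py_file in Path("scripts/easyphoto_utils/animatediff").glob("*.py"):
25 |     text = py_file.read_text(encoding="utf-8")
26 |     new_text = re.sub(r"\bscripts\.(animatediff_\w+|motion_module)", r".\1", text)
27 |     if new_text != text:
28 |         py_file.write_text(new_text, encoding="utf-8")
29 | ```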
--------------------------------------------------------------------------------
/scripts/easyphoto_utils/animatediff/animatediff_infotext.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | from modules.paths import data_path
4 | from modules.processing import StableDiffusionProcessing, StableDiffusionProcessingImg2Img
5 |
6 | from .animatediff_ui import AnimateDiffProcess
7 | from .animatediff_logger import logger_animatediff as logger
8 |
9 |
10 | def update_infotext(p: StableDiffusionProcessing, params: AnimateDiffProcess):
11 | if p.extra_generation_params is not None:
12 | p.extra_generation_params["AnimateDiff"] = params.get_dict(isinstance(p, StableDiffusionProcessingImg2Img))
13 |
14 |
15 | def write_params_txt(info: str):
16 | with open(os.path.join(data_path, "params.txt"), "w", encoding="utf8") as file:
17 | file.write(info)
18 |
19 |
20 |
21 | def infotext_pasted(infotext, results):
22 | for k, v in results.items():
23 | if not k.startswith("AnimateDiff"):
24 | continue
25 |
26 | assert isinstance(v, str), f"Expect string but got {v}."
27 | try:
28 | for items in v.split(', '):
29 | field, value = items.split(': ')
30 | results[f"AnimateDiff {field}"] = value
31 | except Exception:
32 | logger.warn(
33 | f"Failed to parse infotext, legacy format infotext is no longer supported:\n{v}"
34 | )
35 | break
36 |
--------------------------------------------------------------------------------
/scripts/easyphoto_utils/animatediff/animatediff_latent.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from modules import images, shared
4 | from modules.devices import device, dtype_vae, torch_gc
5 | from modules.processing import StableDiffusionProcessingImg2Img
6 | from modules.sd_samplers_common import (approximation_indexes,
7 | images_tensor_to_samples)
8 |
9 | from .animatediff_logger import logger_animatediff as logger
10 | from .animatediff_ui import AnimateDiffProcess
11 |
12 |
13 | class AnimateDiffI2VLatent:
14 | def randomize(
15 | self, p: StableDiffusionProcessingImg2Img, params: AnimateDiffProcess
16 | ):
17 | # Get init_alpha
18 | reserve_scale = [
19 | 0.75 for i in range(params.video_length)
20 | ]
21 | logger.info(f"Randomizing reserve_scale according to {reserve_scale}.")
22 | reserve_scale = torch.tensor(reserve_scale, dtype=torch.float32, device=device)[
23 | :, None, None, None
24 | ]
25 | reserve_scale[reserve_scale < 0] = 0
26 |
27 | if params.last_frame is not None:
28 | # Get init_alpha
29 | init_alpha = [
30 | 1 - pow(i, params.latent_power) / params.latent_scale
31 | for i in range(params.video_length)
32 | ]
33 | logger.info(f"Randomizing init_latent according to {init_alpha}.")
34 | init_alpha = torch.tensor(init_alpha, dtype=torch.float32, device=device)[
35 | :, None, None, None
36 | ]
37 | init_alpha[init_alpha < 0] = 0
38 |
39 | last_frame = params.last_frame
40 | if type(last_frame) == str:
41 | from modules.api.api import decode_base64_to_image
42 | last_frame = decode_base64_to_image(last_frame)
43 | # Get last_alpha
44 | last_alpha = [
45 | 1 - pow(i, params.latent_power_last) / params.latent_scale_last
46 | for i in range(params.video_length)
47 | ]
48 | last_alpha.reverse()
49 | logger.info(f"Randomizing last_latent according to {last_alpha}.")
50 | last_alpha = torch.tensor(last_alpha, dtype=torch.float32, device=device)[
51 | :, None, None, None
52 | ]
53 | last_alpha[last_alpha < 0] = 0
54 |
55 | # Normalize alpha
56 | sum_alpha = init_alpha + last_alpha
57 | mask_alpha = sum_alpha > 1
58 | scaling_factor = 1 / sum_alpha[mask_alpha]
59 | init_alpha[mask_alpha] *= scaling_factor
60 | last_alpha[mask_alpha] *= scaling_factor
61 | init_alpha[0] = 1
62 | init_alpha[-1] = 0
63 | last_alpha[0] = 0
64 | last_alpha[-1] = 1
65 |
66 | # Calculate last_latent
67 | if p.resize_mode != 3:
68 | last_frame = images.resize_image(
69 | p.resize_mode, last_frame, p.width, p.height
70 | )
71 | last_frame = np.array(last_frame).astype(np.float32) / 255.0
72 | last_frame = np.moveaxis(last_frame, 2, 0)[None, ...]
73 | last_frame = torch.from_numpy(last_frame).to(device).to(dtype_vae)
74 | last_latent = images_tensor_to_samples(
75 | last_frame,
76 | approximation_indexes.get(shared.opts.sd_vae_encode_method),
77 | p.sd_model,
78 | )
79 | torch_gc()
80 | if p.resize_mode == 3:
81 | opt_f = 8
82 | last_latent = torch.nn.functional.interpolate(
83 | last_latent,
84 | size=(p.height // opt_f, p.width // opt_f),
85 | mode="bilinear",
86 | )
87 | # Modify init_latent
88 | p.init_latent = (
89 | (p.init_latent * init_alpha
90 | + last_latent * last_alpha
91 | + p.rng.next() * (1 - init_alpha - last_alpha)) * reserve_scale
92 | + p.rng.next() * (1 - reserve_scale)
93 | )
94 | else:
95 | p.init_latent = p.init_latent * reserve_scale + p.rng.next() * (1 - reserve_scale)
96 |
--------------------------------------------------------------------------------
/scripts/easyphoto_utils/animatediff/animatediff_lcm.py:
--------------------------------------------------------------------------------
1 |
2 | # TODO: remove this file when LCM is merged to A1111
3 | import torch
4 |
5 | from k_diffusion import utils, sampling
6 | from k_diffusion.external import DiscreteEpsDDPMDenoiser
7 | from k_diffusion.sampling import default_noise_sampler, trange
8 |
9 | from modules import shared, sd_samplers_cfg_denoiser, sd_samplers_kdiffusion
10 | from .animatediff_logger import logger_animatediff as logger
11 |
12 |
13 | class LCMCompVisDenoiser(DiscreteEpsDDPMDenoiser):
14 | def __init__(self, model):
15 | timesteps = 1000
16 | beta_start = 0.00085
17 | beta_end = 0.012
18 |
19 | betas = torch.linspace(beta_start**0.5, beta_end**0.5, timesteps, dtype=torch.float32) ** 2
20 | alphas = 1.0 - betas
21 | alphas_cumprod = torch.cumprod(alphas, dim=0)
22 |
23 | original_timesteps = 50 # LCM Original Timesteps (default=50, for current version of LCM)
24 | self.skip_steps = timesteps // original_timesteps
25 |
26 |
27 | alphas_cumprod_valid = torch.zeros((original_timesteps), dtype=torch.float32, device=model.device)
28 | for x in range(original_timesteps):
29 | alphas_cumprod_valid[original_timesteps - 1 - x] = alphas_cumprod[timesteps - 1 - x * self.skip_steps]
30 |
31 | super().__init__(model, alphas_cumprod_valid, quantize=None)
32 |
33 |
34 | def get_sigmas(self, n=None, sgm=False):
35 | if n is None:
36 | return sampling.append_zero(self.sigmas.flip(0))
37 |
38 | start = self.sigma_to_t(self.sigma_max)
39 | end = self.sigma_to_t(self.sigma_min)
40 |
41 | if sgm:
42 | t = torch.linspace(start, end, n + 1, device=shared.sd_model.device)[:-1]
43 | else:
44 | t = torch.linspace(start, end, n, device=shared.sd_model.device)
45 |
46 | return sampling.append_zero(self.t_to_sigma(t))
47 |
48 |
49 | def sigma_to_t(self, sigma, quantize=None):
50 | log_sigma = sigma.log()
51 | dists = log_sigma - self.log_sigmas[:, None]
52 | return dists.abs().argmin(dim=0).view(sigma.shape) * self.skip_steps + (self.skip_steps - 1)
53 |
54 |
55 | def t_to_sigma(self, timestep):
56 | t = torch.clamp(((timestep - (self.skip_steps - 1)) / self.skip_steps).float(), min=0, max=(len(self.sigmas) - 1))
57 | return super().t_to_sigma(t)
58 |
59 |
60 | def get_eps(self, *args, **kwargs):
61 | return self.inner_model.apply_model(*args, **kwargs)
62 |
63 |
64 | def get_scaled_out(self, sigma, output, input):
65 | sigma_data = 0.5
66 | scaled_timestep = utils.append_dims(self.sigma_to_t(sigma), output.ndim) * 10.0
67 |
68 | c_skip = sigma_data**2 / (scaled_timestep**2 + sigma_data**2)
69 | c_out = scaled_timestep / (scaled_timestep**2 + sigma_data**2) ** 0.5
70 |
71 | return c_out * output + c_skip * input
72 |
73 |
74 | def forward(self, input, sigma, **kwargs):
75 | c_out, c_in = [utils.append_dims(x, input.ndim) for x in self.get_scalings(sigma)]
76 | eps = self.get_eps(input * c_in, self.sigma_to_t(sigma), **kwargs)
77 | return self.get_scaled_out(sigma, input + eps * c_out, input)
78 |
79 |
80 | def sample_lcm(model, x, sigmas, extra_args=None, callback=None, disable=None, noise_sampler=None):
81 | extra_args = {} if extra_args is None else extra_args
82 | noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
83 | s_in = x.new_ones([x.shape[0]])
84 |
85 | for i in trange(len(sigmas) - 1, disable=disable):
86 | denoised = model(x, sigmas[i] * s_in, **extra_args)
87 |
88 | if callback is not None:
89 | callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
90 |
91 | x = denoised
92 | if sigmas[i + 1] > 0:
93 | x += sigmas[i + 1] * noise_sampler(sigmas[i], sigmas[i + 1])
94 | return x
95 |
96 |
97 | class CFGDenoiserLCM(sd_samplers_cfg_denoiser.CFGDenoiser):
98 | @property
99 | def inner_model(self):
100 | if self.model_wrap is None:
101 | denoiser = LCMCompVisDenoiser
102 | self.model_wrap = denoiser(shared.sd_model)
103 |
104 | return self.model_wrap
105 |
106 |
107 | class LCMSampler(sd_samplers_kdiffusion.KDiffusionSampler):
108 | def __init__(self, funcname, sd_model, options=None):
109 | super().__init__(funcname, sd_model, options)
110 | self.model_wrap_cfg = CFGDenoiserLCM(self)
111 | self.model_wrap = self.model_wrap_cfg.inner_model
112 |
113 |
114 | class AnimateDiffLCM:
115 | lcm_ui_injected = False
116 |
117 |
118 | @staticmethod
119 | def hack_kdiff_ui():
120 | if AnimateDiffLCM.lcm_ui_injected:
121 | logger.info(f"LCM UI already injected.")
122 | return
123 |
124 | logger.info(f"Injecting LCM to UI.")
125 | from modules import sd_samplers, sd_samplers_common
126 | samplers_lcm = [('LCM', sample_lcm, ['k_lcm'], {})]
127 | samplers_data_lcm = [
128 | sd_samplers_common.SamplerData(label, lambda model, funcname=funcname: LCMSampler(funcname, model), aliases, options)
129 | for label, funcname, aliases, options in samplers_lcm
130 | ]
131 | sd_samplers.all_samplers.extend(samplers_data_lcm)
132 | sd_samplers.all_samplers_map = {x.name: x for x in sd_samplers.all_samplers}
133 | sd_samplers.set_samplers()
134 | AnimateDiffLCM.lcm_ui_injected = True
135 |
--------------------------------------------------------------------------------
/scripts/easyphoto_utils/animatediff/animatediff_logger.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import logging
3 | import sys
4 |
5 | from modules import shared
6 |
7 |
8 | class ColoredFormatter(logging.Formatter):
9 | COLORS = {
10 | "DEBUG": "\033[0;36m", # CYAN
11 | "INFO": "\033[0;32m", # GREEN
12 | "WARNING": "\033[0;33m", # YELLOW
13 | "ERROR": "\033[0;31m", # RED
14 | "CRITICAL": "\033[0;37;41m", # WHITE ON RED
15 | "RESET": "\033[0m", # RESET COLOR
16 | }
17 |
18 | def format(self, record):
19 | colored_record = copy.copy(record)
20 | levelname = colored_record.levelname
21 | seq = self.COLORS.get(levelname, self.COLORS["RESET"])
22 | colored_record.levelname = f"{seq}{levelname}{self.COLORS['RESET']}"
23 | return super().format(colored_record)
24 |
25 |
26 | # Create a new logger
27 | logger_animatediff = logging.getLogger("AnimateDiff")
28 | logger_animatediff.propagate = False
29 |
30 | # Add handler if we don't have one.
31 | if not logger_animatediff.handlers:
32 | handler = logging.StreamHandler(sys.stdout)
33 | handler.setFormatter(
34 | ColoredFormatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
35 | )
36 | logger_animatediff.addHandler(handler)
37 |
38 | # Configure logger
39 | loglevel_string = getattr(shared.cmd_opts, "animatediff_loglevel", "INFO")
40 | loglevel = getattr(logging, loglevel_string.upper(), None)
41 | logger_animatediff.setLevel(loglevel)
42 |
--------------------------------------------------------------------------------
/scripts/easyphoto_utils/animatediff/animatediff_lora.py:
--------------------------------------------------------------------------------
1 | import os
2 | import re
3 | import sys
4 |
5 | from modules import sd_models, shared
6 | from modules.paths import extensions_builtin_dir
7 |
8 | from .animatediff_logger import logger_animatediff as logger
9 |
10 | sys.path.append(f"{extensions_builtin_dir}/Lora")
11 |
12 | class AnimateDiffLora:
13 | original_load_network = None
14 |
15 | def __init__(self, v2: bool):
16 | self.v2 = v2
17 |
18 | def hack(self):
19 | if not self.v2:
20 | return
21 |
22 | if AnimateDiffLora.original_load_network is not None:
23 | logger.info("AnimateDiff LoRA already hacked")
24 | return
25 |
26 | logger.info("Hacking LoRA module to support motion LoRA")
27 | import network
28 | import networks
29 | AnimateDiffLora.original_load_network = networks.load_network
30 | original_load_network = AnimateDiffLora.original_load_network
31 |
32 | def mm_load_network(name, network_on_disk):
33 |
34 | def convert_mm_name_to_compvis(key):
35 | sd_module_key, _, network_part = re.split(r'(_lora\.)', key)
36 | sd_module_key = sd_module_key.replace("processor.", "").replace("to_out", "to_out.0")
37 | return sd_module_key, 'lora_' + network_part
38 |
39 | net = network.Network(name, network_on_disk)
40 | net.mtime = os.path.getmtime(network_on_disk.filename)
41 |
42 | sd = sd_models.read_state_dict(network_on_disk.filename)
43 |
44 | if 'motion_modules' in list(sd.keys())[0]:
45 | logger.info(f"Loading motion LoRA {name} from {network_on_disk.filename}")
46 | matched_networks = {}
47 |
48 | for key_network, weight in sd.items():
49 | key, network_part = convert_mm_name_to_compvis(key_network)
50 | sd_module = shared.sd_model.network_layer_mapping.get(key, None)
51 |
52 | assert sd_module is not None, f"Failed to find sd module for key {key}."
53 |
54 | if key not in matched_networks:
55 | matched_networks[key] = network.NetworkWeights(
56 | network_key=key_network, sd_key=key, w={}, sd_module=sd_module)
57 |
58 | matched_networks[key].w[network_part] = weight
59 |
60 | for key, weights in matched_networks.items():
61 | net_module = networks.module_types[0].create_module(net, weights)
62 | assert net_module is not None, "Failed to create motion module LoRA"
63 | net.modules[key] = net_module
64 |
65 | return net
66 | else:
67 | del sd
68 | return original_load_network(name, network_on_disk)
69 |
70 | networks.load_network = mm_load_network
71 |
72 |
73 | def restore(self):
74 | if not self.v2:
75 | return
76 |
77 | if AnimateDiffLora.original_load_network is None:
78 | logger.info("AnimateDiff LoRA already restored")
79 | return
80 |
81 | logger.info("Restoring hacked LoRA")
82 | import networks
83 | networks.load_network = AnimateDiffLora.original_load_network
84 | AnimateDiffLora.original_load_network = None
85 |
--------------------------------------------------------------------------------
/scripts/easyphoto_utils/animatediff/animatediff_mm.py:
--------------------------------------------------------------------------------
1 | import gc
2 | import os
3 |
4 | import torch
5 | from einops import rearrange
6 | from modules import hashes, shared, sd_models, devices
7 | from modules.devices import cpu, device, torch_gc
8 |
9 | from .motion_module import MotionWrapper, MotionModuleType
10 | from .animatediff_logger import logger_animatediff as logger
11 |
12 |
13 | class AnimateDiffMM:
14 | mm_injected = False
15 |
16 | def __init__(self):
17 | self.mm: MotionWrapper = None
18 | self.script_dir = None
19 | self.prev_alpha_cumprod = None
20 | self.gn32_original_forward = None
21 |
22 |
23 | def set_script_dir(self, script_dir):
24 | self.script_dir = script_dir
25 |
26 |
27 | def get_model_dir(self):
28 | model_dir = shared.opts.data.get("animatediff_model_path", os.path.join(self.script_dir, "model"))
29 | if not model_dir:
30 | model_dir = os.path.join(self.script_dir, "model")
31 | return model_dir
32 |
33 |
34 | def _load(self, model_name):
35 | model_path = os.path.join(self.get_model_dir(), model_name)
36 | if not os.path.isfile(model_path):
37 | raise RuntimeError("Please download models manually.")
38 | if self.mm is None or self.mm.mm_name != model_name:
39 | logger.info(f"Loading motion module {model_name} from {model_path}")
40 | model_hash = hashes.sha256(model_path, f"AnimateDiff/{model_name}")
41 | mm_state_dict = sd_models.read_state_dict(model_path)
42 | model_type = MotionModuleType.get_mm_type(mm_state_dict)
43 | logger.info(f"Guessed {model_name} architecture: {model_type}")
44 | self.mm = MotionWrapper(model_name, model_hash, model_type)
45 | missed_keys = self.mm.load_state_dict(mm_state_dict)
46 | logger.warn(f"Missing keys {missed_keys}")
47 | self.mm.to(device).eval()
48 | if not shared.cmd_opts.no_half:
49 | self.mm.half()
50 | if getattr(devices, "fp8", False):
51 | for module in self.mm.modules():
52 | if isinstance(module, (torch.nn.Conv2d, torch.nn.Linear)):
53 | module.to(torch.float8_e4m3fn)
54 |
55 |
56 | def inject(self, sd_model, model_name="mm_sd_v15.ckpt"):
57 | if AnimateDiffMM.mm_injected:
58 | logger.info("Motion module already injected. Trying to restore.")
59 | self.restore(sd_model)
60 |
61 | unet = sd_model.model.diffusion_model
62 | self._load(model_name)
63 | inject_sdxl = sd_model.is_sdxl or self.mm.is_xl
64 | sd_ver = "SDXL" if sd_model.is_sdxl else "SD1.5"
65 | assert sd_model.is_sdxl == self.mm.is_xl, f"Motion module incompatible with SD. You are using {sd_ver} with {self.mm.mm_type}."
66 |
67 | if self.mm.is_v2:
68 | logger.info(f"Injecting motion module {model_name} into {sd_ver} UNet middle block.")
69 | unet.middle_block.insert(-1, self.mm.mid_block.motion_modules[0])
70 | elif not self.mm.is_adxl:
71 | logger.info(f"Hacking {sd_ver} GroupNorm32 forward function.")
72 | if self.mm.is_hotshot:
73 | from sgm.modules.diffusionmodules.util import GroupNorm32
74 | else:
75 | from ldm.modules.diffusionmodules.util import GroupNorm32
76 | self.gn32_original_forward = GroupNorm32.forward
77 | gn32_original_forward = self.gn32_original_forward
78 |
79 | def groupnorm32_mm_forward(self, x):
80 | x = rearrange(x, "(b f) c h w -> b c f h w", b=2)
81 | x = gn32_original_forward(self, x)
82 | x = rearrange(x, "b c f h w -> (b f) c h w", b=2)
83 | return x
84 |
85 | GroupNorm32.forward = groupnorm32_mm_forward
86 |
87 | logger.info(f"Injecting motion module {model_name} into {sd_ver} UNet input blocks.")
88 | for mm_idx, unet_idx in enumerate([1, 2, 4, 5, 7, 8, 10, 11]):
89 | if inject_sdxl and mm_idx >= 6:
90 | break
91 | mm_idx0, mm_idx1 = mm_idx // 2, mm_idx % 2
92 | mm_inject = getattr(self.mm.down_blocks[mm_idx0], "temporal_attentions" if self.mm.is_hotshot else "motion_modules")[mm_idx1]
93 | unet.input_blocks[unet_idx].append(mm_inject)
94 |
95 | logger.info(f"Injecting motion module {model_name} into {sd_ver} UNet output blocks.")
96 | for unet_idx in range(12):
97 | if inject_sdxl and unet_idx >= 9:
98 | break
99 | mm_idx0, mm_idx1 = unet_idx // 3, unet_idx % 3
100 | mm_inject = getattr(self.mm.up_blocks[mm_idx0], "temporal_attentions" if self.mm.is_hotshot else "motion_modules")[mm_idx1]
101 | if unet_idx % 3 == 2 and unet_idx != (8 if self.mm.is_xl else 11):
102 | unet.output_blocks[unet_idx].insert(-1, mm_inject)
103 | else:
104 | unet.output_blocks[unet_idx].append(mm_inject)
105 |
106 | self._set_ddim_alpha(sd_model)
107 | self._set_layer_mapping(sd_model)
108 | AnimateDiffMM.mm_injected = True
109 | logger.info(f"Injection finished.")
110 |
111 |
112 | def restore(self, sd_model):
113 | if not AnimateDiffMM.mm_injected:
114 | logger.info("Motion module already removed.")
115 | return
116 |
117 | inject_sdxl = sd_model.is_sdxl or self.mm.is_xl
118 | sd_ver = "SDXL" if sd_model.is_sdxl else "SD1.5"
119 | self._restore_ddim_alpha(sd_model)
120 | unet = sd_model.model.diffusion_model
121 |
122 | logger.info(f"Removing motion module from {sd_ver} UNet input blocks.")
123 | for unet_idx in [1, 2, 4, 5, 7, 8, 10, 11]:
124 | if inject_sdxl and unet_idx >= 9:
125 | break
126 | unet.input_blocks[unet_idx].pop(-1)
127 |
128 | logger.info(f"Removing motion module from {sd_ver} UNet output blocks.")
129 | for unet_idx in range(12):
130 | if inject_sdxl and unet_idx >= 9:
131 | break
132 | if unet_idx % 3 == 2 and unet_idx != (8 if self.mm.is_xl else 11):
133 | unet.output_blocks[unet_idx].pop(-2)
134 | else:
135 | unet.output_blocks[unet_idx].pop(-1)
136 |
137 | if self.mm.is_v2:
138 | logger.info(f"Removing motion module from {sd_ver} UNet middle block.")
139 | unet.middle_block.pop(-2)
140 | elif not self.mm.is_adxl:
141 | logger.info(f"Restoring {sd_ver} GroupNorm32 forward function.")
142 | if self.mm.is_hotshot:
143 | from sgm.modules.diffusionmodules.util import GroupNorm32
144 | else:
145 | from ldm.modules.diffusionmodules.util import GroupNorm32
146 | GroupNorm32.forward = self.gn32_original_forward
147 | self.gn32_original_forward = None
148 |
149 | AnimateDiffMM.mm_injected = False
150 | logger.info(f"Removal finished.")
151 | if shared.cmd_opts.lowvram:
152 | self.unload()
153 |
154 |
155 | def _set_ddim_alpha(self, sd_model):
156 | logger.info(f"Setting DDIM alpha.")
157 | beta_start = 0.00085
158 | beta_end = 0.020 if self.mm.is_adxl else 0.012
159 | if self.mm.is_adxl:
160 | betas = torch.linspace(beta_start**0.5, beta_end**0.5, 1000, dtype=torch.float32, device=device) ** 2
161 | else:
162 | betas = torch.linspace(
163 | beta_start,
164 | beta_end,
165 | 1000 if sd_model.is_sdxl else sd_model.num_timesteps,
166 | dtype=torch.float32,
167 | device=device,
168 | )
169 | alphas = 1.0 - betas
170 | alphas_cumprod = torch.cumprod(alphas, dim=0)
171 | self.prev_alpha_cumprod = sd_model.alphas_cumprod
172 | sd_model.alphas_cumprod = alphas_cumprod
173 |
174 |
175 | def _set_layer_mapping(self, sd_model):
176 | if hasattr(sd_model, 'network_layer_mapping'):
177 | for name, module in self.mm.named_modules():
178 | sd_model.network_layer_mapping[name] = module
179 | module.network_layer_name = name
180 |
181 |
182 | def _restore_ddim_alpha(self, sd_model):
183 | logger.info(f"Restoring DDIM alpha.")
184 | sd_model.alphas_cumprod = self.prev_alpha_cumprod
185 | self.prev_alpha_cumprod = None
186 |
187 |
188 | def unload(self):
189 | logger.info("Moving motion module to CPU")
190 | if self.mm is not None:
191 | self.mm.to(cpu)
192 | torch_gc()
193 | gc.collect()
194 |
195 |
196 | def remove(self):
197 | logger.info("Removing motion module from any memory")
198 | del self.mm
199 | self.mm = None
200 | torch_gc()
201 | gc.collect()
202 |
203 |
204 | mm_animatediff = AnimateDiffMM()
205 |
--------------------------------------------------------------------------------
/scripts/easyphoto_utils/animatediff/animatediff_prompt.py:
--------------------------------------------------------------------------------
1 | import re
2 | import torch
3 |
4 | from modules.processing import StableDiffusionProcessing, Processed
5 |
6 | from .animatediff_logger import logger_animatediff as logger
7 | from .animatediff_infotext import write_params_txt
8 |
9 |
10 | class AnimateDiffPromptSchedule:
11 |
12 | def __init__(self):
13 | self.prompt_map = None
14 | self.original_prompt = None
15 |
16 |
17 | def save_infotext_img(self, p: StableDiffusionProcessing):
18 | if self.prompt_map is not None:
19 | p.prompts = [self.original_prompt for _ in range(p.batch_size)]
20 |
21 |
22 | def save_infotext_txt(self, res: Processed):
23 | if self.prompt_map is not None:
24 | parts = res.info.split('\nNegative prompt: ', 1)
25 | if len(parts) > 1:
26 | res.info = f"{self.original_prompt}\nNegative prompt: {parts[1]}"
27 | for i in range(len(res.infotexts)):
28 | parts = res.infotexts[i].split('\nNegative prompt: ', 1)
29 | if len(parts) > 1:
30 | res.infotexts[i] = f"{self.original_prompt}\nNegative prompt: {parts[1]}"
31 | write_params_txt(res.info)
32 |
33 |
34 | def parse_prompt(self, p: StableDiffusionProcessing):
35 | if type(p.prompt) is not str:
36 | logger.warn("prompt is not str, cannot support prompt map")
37 | return
38 |
39 | lines = p.prompt.strip().split('\n')
40 | data = {
41 | 'head_prompts': [],
42 | 'mapp_prompts': {},
43 | 'tail_prompts': []
44 | }
45 |
46 | mode = 'head'
47 | for line in lines:
48 | if mode == 'head':
49 | if re.match(r'^\d+:', line):
50 | mode = 'mapp'
51 | else:
52 | data['head_prompts'].append(line)
53 |
54 | if mode == 'mapp':
55 | match = re.match(r'^(\d+): (.+)$', line)
56 | if match:
57 | frame, prompt = match.groups()
58 | data['mapp_prompts'][int(frame)] = prompt
59 | else:
60 | mode = 'tail'
61 |
62 | if mode == 'tail':
63 | data['tail_prompts'].append(line)
64 |
65 | if data['mapp_prompts']:
66 | logger.info("You are using prompt travel.")
67 | self.prompt_map = {}
68 | prompt_list = []
69 | last_frame = 0
70 | current_prompt = ''
71 | for frame, prompt in data['mapp_prompts'].items():
72 | prompt_list += [current_prompt for _ in range(last_frame, frame)]
73 | last_frame = frame
74 | current_prompt = f"{', '.join(data['head_prompts'])}, {prompt}, {', '.join(data['tail_prompts'])}"
75 | self.prompt_map[frame] = current_prompt
76 | prompt_list += [current_prompt for _ in range(last_frame, p.batch_size)]
77 | assert len(prompt_list) == p.batch_size, f"prompt_list length {len(prompt_list)} != batch_size {p.batch_size}"
78 | self.original_prompt = p.prompt
79 | p.prompt = prompt_list * p.n_iter
80 |
81 |
82 | def single_cond(self, center_frame, video_length: int, cond: torch.Tensor, closed_loop = False):
83 | if closed_loop:
84 | key_prev = list(self.prompt_map.keys())[-1]
85 | key_next = list(self.prompt_map.keys())[0]
86 | else:
87 | key_prev = list(self.prompt_map.keys())[0]
88 | key_next = list(self.prompt_map.keys())[-1]
89 |
90 | for p in self.prompt_map.keys():
91 | if p > center_frame:
92 | key_next = p
93 | break
94 | key_prev = p
95 |
96 | dist_prev = center_frame - key_prev
97 | if dist_prev < 0:
98 | dist_prev += video_length
99 | dist_next = key_next - center_frame
100 | if dist_next < 0:
101 | dist_next += video_length
102 |
103 | if key_prev == key_next or dist_prev + dist_next == 0:
104 | return cond[key_prev] if isinstance(cond, torch.Tensor) else {k: v[key_prev] for k, v in cond.items()}
105 |
106 | rate = dist_prev / (dist_prev + dist_next)
107 | if isinstance(cond, torch.Tensor):
108 | return AnimateDiffPromptSchedule.slerp(cond[key_prev], cond[key_next], rate)
109 | else: # isinstance(cond, dict)
110 | return {
111 | k: AnimateDiffPromptSchedule.slerp(v[key_prev], v[key_next], rate)
112 | for k, v in cond.items()
113 | }
114 |
115 |
116 | def multi_cond(self, cond: torch.Tensor, closed_loop = False):
117 | if self.prompt_map is None:
118 | return cond
119 | cond_list = [] if isinstance(cond, torch.Tensor) else {k: [] for k in cond.keys()}
120 | for i in range(cond.shape[0]):
121 | single_cond = self.single_cond(i, cond.shape[0], cond, closed_loop)
122 | if isinstance(cond, torch.Tensor):
123 | cond_list.append(single_cond)
124 | else:
125 | for k, v in single_cond.items():
126 | cond_list[k].append(v)
127 | if isinstance(cond, torch.Tensor):
128 | return torch.stack(cond_list).to(cond.dtype).to(cond.device)
129 | else:
130 | return {k: torch.stack(v).to(cond[k].dtype).to(cond[k].device) for k, v in cond_list.items()}
131 |
132 |
133 | @staticmethod
134 | def slerp(
135 | v0: torch.Tensor, v1: torch.Tensor, t: float, DOT_THRESHOLD: float = 0.9995
136 | ) -> torch.Tensor:
137 | u0 = v0 / v0.norm()
138 | u1 = v1 / v1.norm()
139 | dot = (u0 * u1).sum()
140 | if dot.abs() > DOT_THRESHOLD:
141 | return (1.0 - t) * v0 + t * v1
142 | omega = dot.acos()
143 | return (((1.0 - t) * omega).sin() * v0 + (t * omega).sin() * v1) / omega.sin()
144 |
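145 | # Illustrative note (not part of the upstream AnimateDiff extension): a "prompt
146 | # travel" prompt understood by AnimateDiffPromptSchedule.parse_prompt() looks like
147 | #
148 | #   masterpiece, best quality
149 | #   0: closed eyes
150 | #   8: open eyes
151 | #   standing in the snow
152 | #
153 | # which is parsed into head_prompts=["masterpiece, best quality"],
154 | # mapp_prompts={0: "closed eyes", 8: "open eyes"} and
155 | # tail_prompts=["standing in the snow"]; the per-frame conditionings are then
156 | # slerp-interpolated between the nearest keyframes in single_cond().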
--------------------------------------------------------------------------------
/scripts/easyphoto_utils/animatediff/animatediff_ui.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import cv2
4 | import gradio as gr
5 |
6 | from modules import shared
7 | from modules.processing import StableDiffusionProcessing
8 |
9 | from .animatediff_mm import mm_animatediff as motion_module
10 | from .animatediff_i2ibatch import animatediff_i2ibatch
11 | from .animatediff_lcm import AnimateDiffLCM
12 |
13 |
14 | class ToolButton(gr.Button, gr.components.FormComponent):
15 | """Small button with single emoji as text, fits inside gradio forms"""
16 |
17 | def __init__(self, **kwargs):
18 | super().__init__(variant="tool", **kwargs)
19 |
20 |
21 | def get_block_name(self):
22 | return "button"
23 |
24 |
25 | class AnimateDiffProcess:
26 |
27 | def __init__(
28 | self,
29 | model="mm_sd_v15_v2.ckpt",
30 | enable=False,
31 | video_length=0,
32 | fps=8,
33 | loop_number=0,
34 | closed_loop='R-P',
35 | batch_size=16,
36 | stride=1,
37 | overlap=-1,
38 | format=["GIF", "PNG"],
39 | interp='Off',
40 | interp_x=10,
41 | video_source=None,
42 | video_path='',
43 | latent_power=1,
44 | latent_scale=32,
45 | last_frame=None,
46 | latent_power_last=1,
47 | latent_scale_last=32,
48 | request_id = '',
49 | ):
50 | self.model = model
51 | self.enable = enable
52 | self.video_length = video_length
53 | self.fps = fps
54 | self.loop_number = loop_number
55 | self.closed_loop = closed_loop
56 | self.batch_size = batch_size
57 | self.stride = stride
58 | self.overlap = overlap
59 | self.format = format
60 | self.interp = interp
61 | self.interp_x = interp_x
62 | self.video_source = video_source
63 | self.video_path = video_path
64 | self.latent_power = latent_power
65 | self.latent_scale = latent_scale
66 | self.last_frame = last_frame
67 | self.latent_power_last = latent_power_last
68 | self.latent_scale_last = latent_scale_last
69 | self.request_id = request_id
70 |
71 |
72 | def get_list(self, is_img2img: bool):
73 | list_var = list(vars(self).values())[:-1]
74 | if is_img2img:
75 | animatediff_i2ibatch.hack()
76 | else:
77 | list_var = list_var[:-5]
78 | return list_var
79 |
80 |
81 | def get_dict(self, is_img2img: bool):
82 | infotext = {
83 | "enable": self.enable,
84 | "model": self.model,
85 | "video_length": self.video_length,
86 | "fps": self.fps,
87 | "loop_number": self.loop_number,
88 | "closed_loop": self.closed_loop,
89 | "batch_size": self.batch_size,
90 | "stride": self.stride,
91 | "overlap": self.overlap,
92 | "interp": self.interp,
93 | "interp_x": self.interp_x,
94 | }
95 | if self.request_id:
96 | infotext['request_id'] = self.request_id
97 | if motion_module.mm is not None and motion_module.mm.mm_hash is not None:
98 | infotext['mm_hash'] = motion_module.mm.mm_hash[:8]
99 | if is_img2img:
100 | infotext.update({
101 | "latent_power": self.latent_power,
102 | "latent_scale": self.latent_scale,
103 | "latent_power_last": self.latent_power_last,
104 | "latent_scale_last": self.latent_scale_last,
105 | })
106 | infotext_str = ', '.join(f"{k}: {v}" for k, v in infotext.items())
107 | return infotext_str
108 |
109 |
110 | def _check(self):
111 | assert (
112 | self.video_length >= 0 and self.fps > 0
113 | ), "Video length and FPS should be positive."
114 | assert not set(["GIF", "MP4", "PNG", "WEBP", "WEBM"]).isdisjoint(
115 | self.format
116 | ), "At least one saving format should be selected."
117 |
118 |
119 | def set_p(self, p: StableDiffusionProcessing):
120 | self._check()
121 | if self.video_length < self.batch_size:
122 | p.batch_size = self.batch_size
123 | else:
124 | p.batch_size = self.video_length
125 | if self.video_length == 0:
126 | self.video_length = p.batch_size
127 | self.video_default = True
128 | else:
129 | self.video_default = False
130 | if self.overlap == -1:
131 | self.overlap = self.batch_size // 4
132 | if "PNG" not in self.format or shared.opts.data.get("animatediff_save_to_custom", False):
133 | p.do_not_save_samples = True
134 |
135 |
136 | class AnimateDiffUiGroup:
137 | txt2img_submit_button = None
138 | img2img_submit_button = None
139 |
140 | def __init__(self):
141 | self.params = AnimateDiffProcess()
142 |
143 |
144 | def render(self, is_img2img: bool, model_dir: str):
145 | if not os.path.isdir(model_dir):
146 | os.mkdir(model_dir)
147 | elemid_prefix = "img2img-ad-" if is_img2img else "txt2img-ad-"
148 | model_list = [f for f in os.listdir(model_dir) if f != ".gitkeep"]
149 | with gr.Accordion("AnimateDiff", open=False):
150 | gr.Markdown(value="Please click [this link](https://github.com/continue-revolution/sd-webui-animatediff#webui-parameters) to read the documentation of each parameter.")
151 | with gr.Row():
152 |
153 | def refresh_models(*inputs):
154 | new_model_list = [
155 | f for f in os.listdir(model_dir) if f != ".gitkeep"
156 | ]
157 | dd = inputs[0]
158 | if dd in new_model_list:
159 | selected = dd
160 | elif len(new_model_list) > 0:
161 | selected = new_model_list[0]
162 | else:
163 | selected = None
164 | return gr.Dropdown.update(choices=new_model_list, value=selected)
165 |
166 | with gr.Row():
167 | self.params.model = gr.Dropdown(
168 | choices=model_list,
169 | value=(self.params.model if self.params.model in model_list else None),
170 | label="Motion module",
171 | type="value",
172 | elem_id=f"{elemid_prefix}motion-module",
173 | )
174 | refresh_model = ToolButton(value="\U0001f504")
175 | refresh_model.click(refresh_models, self.params.model, self.params.model)
176 |
177 | self.params.format = gr.CheckboxGroup(
178 | choices=["GIF", "MP4", "WEBP", "WEBM", "PNG", "TXT"],
179 | label="Save format",
180 | type="value",
181 | elem_id=f"{elemid_prefix}save-format",
182 | value=self.params.format,
183 | )
184 | with gr.Row():
185 | self.params.enable = gr.Checkbox(
186 | value=self.params.enable, label="Enable AnimateDiff",
187 | elem_id=f"{elemid_prefix}enable"
188 | )
189 | self.params.video_length = gr.Number(
190 | minimum=0,
191 | value=self.params.video_length,
192 | label="Number of frames",
193 | precision=0,
194 | elem_id=f"{elemid_prefix}video-length",
195 | )
196 | self.params.fps = gr.Number(
197 | value=self.params.fps, label="FPS", precision=0,
198 | elem_id=f"{elemid_prefix}fps"
199 | )
200 | self.params.loop_number = gr.Number(
201 | minimum=0,
202 | value=self.params.loop_number,
203 | label="Display loop number",
204 | precision=0,
205 | elem_id=f"{elemid_prefix}loop-number",
206 | )
207 | with gr.Row():
208 | self.params.closed_loop = gr.Radio(
209 | choices=["N", "R-P", "R+P", "A"],
210 | value=self.params.closed_loop,
211 | label="Closed loop",
212 | elem_id=f"{elemid_prefix}closed-loop",
213 | )
214 | self.params.batch_size = gr.Slider(
215 | minimum=1,
216 | maximum=32,
217 | value=self.params.batch_size,
218 | label="Context batch size",
219 | step=1,
220 | precision=0,
221 | elem_id=f"{elemid_prefix}batch-size",
222 | )
223 | self.params.stride = gr.Number(
224 | minimum=1,
225 | value=self.params.stride,
226 | label="Stride",
227 | precision=0,
228 | elem_id=f"{elemid_prefix}stride",
229 | )
230 | self.params.overlap = gr.Number(
231 | minimum=-1,
232 | value=self.params.overlap,
233 | label="Overlap",
234 | precision=0,
235 | elem_id=f"{elemid_prefix}overlap",
236 | )
237 | with gr.Row():
238 | self.params.interp = gr.Radio(
239 | choices=["Off", "FILM"],
240 | label="Frame Interpolation",
241 | elem_id=f"{elemid_prefix}interp-choice",
242 | value=self.params.interp
243 | )
244 | self.params.interp_x = gr.Number(
245 | value=self.params.interp_x, label="Interp X", precision=0,
246 | elem_id=f"{elemid_prefix}interp-x"
247 | )
248 | self.params.video_source = gr.Video(
249 | value=self.params.video_source,
250 | label="Video source",
251 | )
252 | def update_fps(video_source):
253 | if video_source is not None and video_source != '':
254 | cap = cv2.VideoCapture(video_source)
255 | fps = int(cap.get(cv2.CAP_PROP_FPS))
256 | cap.release()
257 | return fps
258 | else:
259 | return int(self.params.fps.value)
260 | self.params.video_source.change(update_fps, inputs=self.params.video_source, outputs=self.params.fps)
261 | def update_frames(video_source):
262 | if video_source is not None and video_source != '':
263 | cap = cv2.VideoCapture(video_source)
264 | frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
265 | cap.release()
266 | return frames
267 | else:
268 | return int(self.params.video_length.value)
269 | self.params.video_source.change(update_frames, inputs=self.params.video_source, outputs=self.params.video_length)
270 | self.params.video_path = gr.Textbox(
271 | value=self.params.video_path,
272 | label="Video path",
273 | elem_id=f"{elemid_prefix}video-path"
274 | )
275 | if is_img2img:
276 | with gr.Row():
277 | self.params.latent_power = gr.Slider(
278 | minimum=0.1,
279 | maximum=10,
280 | value=self.params.latent_power,
281 | step=0.1,
282 | label="Latent power",
283 | elem_id=f"{elemid_prefix}latent-power",
284 | )
285 | self.params.latent_scale = gr.Slider(
286 | minimum=1,
287 | maximum=128,
288 | value=self.params.latent_scale,
289 | label="Latent scale",
290 | elem_id=f"{elemid_prefix}latent-scale"
291 | )
292 | self.params.latent_power_last = gr.Slider(
293 | minimum=0.1,
294 | maximum=10,
295 | value=self.params.latent_power_last,
296 | step=0.1,
297 | label="Optional latent power for last frame",
298 | elem_id=f"{elemid_prefix}latent-power-last",
299 | )
300 | self.params.latent_scale_last = gr.Slider(
301 | minimum=1,
302 | maximum=128,
303 | value=self.params.latent_scale_last,
304 | label="Optional latent scale for last frame",
305 | elem_id=f"{elemid_prefix}latent-scale-last"
306 | )
307 | self.params.last_frame = gr.Image(
308 | label="Optional last frame. Leave it blank if you do not need one.",
309 | type="pil",
310 | )
311 | with gr.Row():
312 | unload = gr.Button(value="Move motion module to CPU (default if lowvram)")
313 | remove = gr.Button(value="Remove motion module from any memory")
314 | unload.click(fn=motion_module.unload)
315 | remove.click(fn=motion_module.remove)
316 | return self.register_unit(is_img2img)
317 |
318 |
319 | def register_unit(self, is_img2img: bool):
320 | unit = gr.State(value=AnimateDiffProcess)
321 | (
322 | AnimateDiffUiGroup.img2img_submit_button
323 | if is_img2img
324 | else AnimateDiffUiGroup.txt2img_submit_button
325 | ).click(
326 | fn=AnimateDiffProcess,
327 | inputs=self.params.get_list(is_img2img),
328 | outputs=unit,
329 | queue=False,
330 | )
331 | return unit
332 |
333 |
334 | @staticmethod
335 | def on_after_component(component, **_kwargs):
336 | elem_id = getattr(component, "elem_id", None)
337 |
338 | if elem_id == "txt2img_generate":
339 | AnimateDiffUiGroup.txt2img_submit_button = component
340 | return
341 |
342 | if elem_id == "img2img_generate":
343 | AnimateDiffUiGroup.img2img_submit_button = component
344 | return
345 |
346 |
347 | @staticmethod
348 | def on_before_ui():
349 | AnimateDiffLCM.hack_kdiff_ui()
350 |
--------------------------------------------------------------------------------
/scripts/easyphoto_utils/loractl_utils.py:
--------------------------------------------------------------------------------
1 | """Borrowed from https://github.com/cheald/sd-webui-loractl.
2 | """
3 | import io
4 | import os
5 | import re
6 | import sys
7 |
8 | import gradio as gr
9 | import matplotlib
10 | import numpy as np
11 | import pandas as pd
12 | from modules import extra_networks, shared, script_callbacks
13 | from modules.processing import StableDiffusionProcessing
14 | import modules.scripts as scripts
15 | from PIL import Image
16 | from scripts.easyphoto_config import extensions_builtin_dir, extensions_dir
17 |
18 |
19 | # TODO: refactor the plugin dependency.
20 | lora_extensions_path = os.path.join(extensions_dir, "Lora")
21 | lora_extensions_builtin_path = os.path.join(extensions_builtin_dir, "Lora")
22 |
23 | if os.path.exists(lora_extensions_path):
24 | lora_path = lora_extensions_path
25 | elif os.path.exists(lora_extensions_builtin_path):
26 | lora_path = lora_extensions_builtin_path
27 | else:
28 | raise ImportError("Lora extension is not found.")
29 | sys.path.insert(0, lora_path)
30 | import extra_networks_lora
31 | import network
32 | import networks
33 |
34 | sys.path.remove(lora_path)
35 |
36 |
37 | def check_loractl_conflict():
38 | loractl_extensions_path = os.path.join(extensions_dir, "sd-webui-loractl")
39 | if os.path.exists(loractl_extensions_path):
40 | disabled_extensions = shared.opts.data.get("disabled_extensions", [])
41 | if "sd-webui-loractl" not in disabled_extensions:
42 | return True
43 | return False
44 |
45 |
46 | # Borrowed from https://github.com/cheald/sd-webui-loractl/blob/master/loractl/lib/lora_ctl_network.py.
47 | lora_weights = {}
48 |
49 |
50 | def reset_weights():
51 | global lora_weights
52 | lora_weights.clear()
53 |
54 |
55 | class LoraCtlNetwork(extra_networks_lora.ExtraNetworkLora):
56 | # Hijack the params parser and feed it dummy weights instead so it doesn't choke trying to
57 | # parse our extended syntax
58 | def activate(self, p, params_list):
59 | if not is_active():
60 | return super().activate(p, params_list)
61 |
62 | for params in params_list:
63 | assert params.items
64 | name = params.positional[0]
65 | if lora_weights.get(name, None) is None:
66 | lora_weights[name] = params_to_weights(params)
67 | # The hardcoded 1 weight is fine here, since our actual patch looks up the weights from
68 | # our lora_weights dict
69 | params.positional = [name, 1]
70 | params.named = {}
71 | return super().activate(p, params_list)
72 |
73 |
74 | # Borrowed from https://github.com/cheald/sd-webui-loractl/blob/master/loractl/lib/utils.py.
75 | # Given a string like x@y,z@a, returns [[x, z], [y, a]] sorted for consumption by np.interp.
76 | def sorted_positions(raw_steps):
77 | steps = [[float(s.strip()) for s in re.split("[@~]", x)] for x in re.split("[,;]", str(raw_steps))]
78 | # If we just got a single number, just return it
79 | if len(steps[0]) == 1:
80 | return steps[0][0]
81 |
82 | # Add implicit 1s to any steps which don't have a weight
83 | steps = [[s[0], s[1] if len(s) == 2 else 1] for s in steps]
84 |
85 | # Sort by index
86 | steps.sort(key=lambda k: k[1])
87 |
88 | steps = [list(v) for v in zip(*steps)]
89 | return steps
90 |
91 |
92 | def calculate_weight(m, step, max_steps, step_offset=2):
93 | if isinstance(m, list):
94 | # normalize the step to 0~1
95 | if m[1][-1] <= 1.0:
96 | if max_steps > 0:
97 | step = (step) / (max_steps - step_offset)
98 | else:
99 | step = 1.0
100 | else:
101 | step = step
102 | # get value from interp between m[1]~m[0]
103 | v = np.interp(step, m[1], m[0])
104 | return v
105 | else:
106 | return m
107 |
108 |
109 | def params_to_weights(params):
110 | weights = {"unet": None, "te": 1.0, "hrunet": None, "hrte": None}
111 |
112 | if len(params.positional) > 1:
113 | weights["te"] = sorted_positions(params.positional[1])
114 |
115 | if len(params.positional) > 2:
116 | weights["unet"] = sorted_positions(params.positional[2])
117 |
118 | if params.named.get("te"):
119 | weights["te"] = sorted_positions(params.named.get("te"))
120 |
121 | if params.named.get("unet"):
122 | weights["unet"] = sorted_positions(params.named.get("unet"))
123 |
124 | if params.named.get("hr"):
125 | weights["hrunet"] = sorted_positions(params.named.get("hr"))
126 | weights["hrte"] = sorted_positions(params.named.get("hr"))
127 |
128 | if params.named.get("hrunet"):
129 | weights["hrunet"] = sorted_positions(params.named.get("hrunet"))
130 |
131 | if params.named.get("hrte"):
132 | weights["hrte"] = sorted_positions(params.named.get("hrte"))
133 |
134 | # If unet ended up unset, then use the te value
135 | weights["unet"] = weights["unet"] if weights["unet"] is not None else weights["te"]
136 | # If hrunet ended up unset, use unet value
137 | weights["hrunet"] = weights["hrunet"] if weights["hrunet"] is not None else weights["unet"]
138 | # If hrte ended up unset, use te value
139 | weights["hrte"] = weights["hrte"] if weights["hrte"] is not None else weights["te"]
140 |
141 | return weights
142 |
143 |
144 | hires = False
145 | loractl_active = True
146 |
147 |
148 | def is_hires():
149 | return hires
150 |
151 |
152 | def set_hires(value):
153 | global hires
154 | hires = value
155 |
156 |
157 | def is_active():
158 | global loractl_active
159 | return loractl_active
160 |
161 |
162 | def set_active(value):
163 | global loractl_active
164 | loractl_active = value
165 |
166 |
167 | # Borrowed from https://github.com/cheald/sd-webui-loractl/blob/master/loractl/lib/network_patch.py.
168 | # Patch network.Network so it reapplies properly for dynamic weights
169 | # By default, network application is cached, with (name, te, unet, dim) as a key
170 | # By replacing the bare properties with getters, we can ensure that we cause SD
171 | # to reapply the network each time we change its weights, while still taking advantage
172 | # of caching when weights are not updated.
173 |
174 |
175 | def get_weight(m):
176 | return calculate_weight(m, shared.state.sampling_step, shared.state.sampling_steps, step_offset=2)
177 |
178 |
179 | def get_dynamic_te(self):
180 | if self.name in lora_weights:
181 | key = "te" if not is_hires() else "hrte"
182 | w = lora_weights[self.name]
183 | return get_weight(w.get(key, self._te_multiplier))
184 |
185 | return get_weight(self._te_multiplier)
186 |
187 |
188 | def get_dynamic_unet(self):
189 | if self.name in lora_weights:
190 | key = "unet" if not is_hires() else "hrunet"
191 | w = lora_weights[self.name]
192 | return get_weight(w.get(key, self._unet_multiplier))
193 |
194 | return get_weight(self._unet_multiplier)
195 |
196 |
197 | def set_dynamic_te(self, value):
198 | self._te_multiplier = value
199 |
200 |
201 | def set_dynamic_unet(self, value):
202 | self._unet_multiplier = value
203 |
204 |
205 | def apply():
206 | if getattr(network.Network, "te_multiplier", None) is None:
207 | network.Network.te_multiplier = property(get_dynamic_te, set_dynamic_te)
208 | network.Network.unet_multiplier = property(get_dynamic_unet, set_dynamic_unet)
209 |
210 |
211 | # Borrowed from https://github.com/cheald/sd-webui-loractl/blob/master/scripts/loractl.py.
212 | class LoraCtlScript(scripts.Script):
213 | def __init__(self):
214 | self.original_network = None
215 | super().__init__()
216 |
217 | def title(self):
218 | return "Dynamic Lora Weights (EasyPhoto built-in)"
219 |
220 | def show(self, is_img2img):
221 | return scripts.AlwaysVisible
222 |
223 | def ui(self, is_img2img):
224 | with gr.Group():
225 |             with gr.Accordion("Dynamic Lora Weights (EasyPhoto built-in)", open=False):
226 | opt_enable = gr.Checkbox(value=True, label="Enable Dynamic Lora Weights")
227 | opt_plot_lora_weight = gr.Checkbox(value=False, label="Plot the LoRA weight in all steps")
228 | return [opt_enable, opt_plot_lora_weight]
229 |
230 | def process(self, p: StableDiffusionProcessing, opt_enable=True, opt_plot_lora_weight=False, **kwargs):
231 | if opt_enable and type(extra_networks.extra_network_registry["lora"]) != LoraCtlNetwork:
232 | self.original_network = extra_networks.extra_network_registry["lora"]
233 | network = LoraCtlNetwork()
234 | extra_networks.register_extra_network(network)
235 | extra_networks.register_extra_network_alias(network, "loractl")
236 | # elif not opt_enable and type(extra_networks.extra_network_registry["lora"]) != LoraCtlNetwork.__bases__[0]:
237 | # extra_networks.register_extra_network(self.original_network)
238 | # self.original_network = None
239 |
240 | apply()
241 | set_hires(False)
242 | set_active(opt_enable)
243 | reset_weights()
244 | reset_plot()
245 |
246 | def before_hr(self, p, *args):
247 | set_hires(True)
248 |
249 | def postprocess(self, p, processed, opt_enable=True, opt_plot_lora_weight=False, **kwargs):
250 | if opt_plot_lora_weight and opt_enable:
251 | processed.images.extend([make_plot()])
252 |
253 |
254 | # Borrowed from https://github.com/cheald/sd-webui-loractl/blob/master/scripts/loractl.py.
255 | log_weights = []
256 | log_names = []
257 | last_plotted_step = -1
258 |
259 |
260 | # Copied from composable_lora
261 | def plot_lora_weight(lora_weights, lora_names):
262 | data = pd.DataFrame(lora_weights, columns=lora_names)
263 | ax = data.plot()
264 | ax.set_xlabel("Steps")
265 | ax.set_ylabel("LoRA weight")
266 | ax.set_title("LoRA weight in all steps")
267 | ax.legend(loc=0)
268 | result_image = fig2img(ax)
269 | matplotlib.pyplot.close(ax.figure)
270 | del ax
271 | return result_image
272 |
273 |
274 | # Copied from composable_lora
275 | def fig2img(fig):
276 | buf = io.BytesIO()
277 | fig.figure.savefig(buf)
278 | buf.seek(0)
279 | img = Image.open(buf)
280 | return img
281 |
282 |
283 | def reset_plot():
284 | global last_plotted_step
285 | log_weights.clear()
286 | log_names.clear()
287 |
288 |
289 | def make_plot():
290 | return plot_lora_weight(log_weights, log_names)
291 |
292 |
293 | # On each step, capture our lora weights for plotting
294 | def on_step(params):
295 | global last_plotted_step
296 | if last_plotted_step == params.sampling_step and len(log_weights) > 0:
297 | log_weights.pop()
298 | last_plotted_step = params.sampling_step
299 | if len(log_names) == 0:
300 | for net in networks.loaded_networks:
301 | log_names.append(net.name + "_te")
302 | log_names.append(net.name + "_unet")
303 | frame = []
304 | for net in networks.loaded_networks:
305 | frame.append(net.te_multiplier)
306 | frame.append(net.unet_multiplier)
307 | log_weights.append(frame)
308 |
309 |
310 | script_callbacks.on_cfg_after_cfg(on_step)
311 |
--------------------------------------------------------------------------------
/scripts/train_kohya/ddpo_pytorch/diffusers_patch/ddim_with_logprob.py:
--------------------------------------------------------------------------------
1 | # Copied from https://github.com/huggingface/diffusers/blob/fc6acb6b97e93d58cb22b5fee52d884d77ce84d8/src/diffusers/schedulers/scheduling_ddim.py
2 | # with the following modifications:
3 | # - It computes and returns the log prob of `prev_sample` given the UNet prediction.
4 | # - Instead of `variance_noise`, it takes `prev_sample` as an optional argument. If `prev_sample` is provided,
5 | # it uses it to compute the log prob.
6 | # - Timesteps can be a batched torch.Tensor.
7 | # Copied from https://github.com/kvablack/ddpo-pytorch/blob/main/ddpo_pytorch/diffusers_patch/ddim_with_logprob.py.
8 |
9 | import math
10 | from packaging import version
11 | from typing import Optional, Tuple, Union
12 |
13 | import diffusers
14 | import torch
15 | from diffusers.schedulers.scheduling_ddim import DDIMScheduler, DDIMSchedulerOutput
16 |
17 | # See https://github.com/huggingface/diffusers/issues/5025 for details.
18 | if version.parse(diffusers.__version__) > version.parse("0.20.2"):
19 | from diffusers.utils.torch_utils import randn_tensor
20 | else:
21 | from diffusers.utils import randn_tensor
22 |
23 |
24 | def _left_broadcast(t, shape):
25 | assert t.ndim <= len(shape)
26 | return t.reshape(t.shape + (1,) * (len(shape) - t.ndim)).broadcast_to(shape)
27 |
28 |
29 | def _get_variance(self, timestep, prev_timestep):
30 | alpha_prod_t = torch.gather(self.alphas_cumprod, 0, timestep.cpu()).to(timestep.device)
31 | alpha_prod_t_prev = torch.where(
32 | prev_timestep.cpu() >= 0, self.alphas_cumprod.gather(0, prev_timestep.cpu()), self.final_alpha_cumprod
33 | ).to(timestep.device)
34 | beta_prod_t = 1 - alpha_prod_t
35 | beta_prod_t_prev = 1 - alpha_prod_t_prev
36 |
37 | variance = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev)
38 |
39 | return variance
40 |
41 |
42 | def ddim_step_with_logprob(
43 | self: DDIMScheduler,
44 | model_output: torch.FloatTensor,
45 | timestep: int,
46 | sample: torch.FloatTensor,
47 | eta: float = 0.0,
48 | use_clipped_model_output: bool = False,
49 | generator=None,
50 | prev_sample: Optional[torch.FloatTensor] = None,
51 | ) -> Union[DDIMSchedulerOutput, Tuple]:
52 | """
53 | Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
54 | process from the learned model outputs (most often the predicted noise).
55 |
56 | Args:
57 | model_output (`torch.FloatTensor`): direct output from learned diffusion model.
58 | timestep (`int`): current discrete timestep in the diffusion chain.
59 | sample (`torch.FloatTensor`):
60 | current instance of sample being created by diffusion process.
61 | eta (`float`): weight of noise for added noise in diffusion step.
62 | use_clipped_model_output (`bool`): if `True`, compute "corrected" `model_output` from the clipped
63 | predicted original sample. Necessary because predicted original sample is clipped to [-1, 1] when
64 | `self.config.clip_sample` is `True`. If no clipping has happened, "corrected" `model_output` would
 65 |             coincide with the one provided as input and `use_clipped_model_output` will have no effect.
66 | generator: random number generator.
 67 |         prev_sample (`torch.FloatTensor`, *optional*): instead of sampling the previous latent with `generator`,
 68 |             an already computed `prev_sample` can be provided; in that case no new noise is drawn and only the
 69 |             log probability of `prev_sample` under the predicted Gaussian is returned alongside it.
 70 |             Passing both `generator` and `prev_sample` raises a `ValueError`.
71 |
72 | Returns:
 73 |         `tuple`:
 74 |             A two-element tuple `(prev_sample, log_prob)`: the predicted previous sample, cast to the dtype of
 75 |             `sample`, and the per-sample log probability of `prev_sample` under the predicted distribution.
76 |
77 | """
78 | assert isinstance(self, DDIMScheduler)
79 | if self.num_inference_steps is None:
80 | raise ValueError("Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler")
81 |
82 | # See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf
 83 |     # Ideally, read the DDIM paper in detail for a thorough understanding
84 |
 85 |     # Notation (<variable name> -> <name in paper>)
86 | # - pred_noise_t -> e_theta(x_t, t)
87 | # - pred_original_sample -> f_theta(x_t, t) or x_0
88 | # - std_dev_t -> sigma_t
89 | # - eta -> η
90 | # - pred_sample_direction -> "direction pointing to x_t"
91 | # - pred_prev_sample -> "x_t-1"
92 |
93 | # 1. get previous step value (=t-1)
94 | prev_timestep = timestep - self.config.num_train_timesteps // self.num_inference_steps
95 | # to prevent OOB on gather
96 | prev_timestep = torch.clamp(prev_timestep, 0, self.config.num_train_timesteps - 1)
97 |
98 | # 2. compute alphas, betas
99 | alpha_prod_t = self.alphas_cumprod.gather(0, timestep.cpu())
100 | alpha_prod_t_prev = torch.where(prev_timestep.cpu() >= 0, self.alphas_cumprod.gather(0, prev_timestep.cpu()), self.final_alpha_cumprod)
101 | alpha_prod_t = _left_broadcast(alpha_prod_t, sample.shape).to(sample.device)
102 | alpha_prod_t_prev = _left_broadcast(alpha_prod_t_prev, sample.shape).to(sample.device)
103 |
104 | beta_prod_t = 1 - alpha_prod_t
105 |
106 | # 3. compute predicted original sample from predicted noise also called
107 | # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
108 | if self.config.prediction_type == "epsilon":
109 | pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
110 | pred_epsilon = model_output
111 | elif self.config.prediction_type == "sample":
112 | pred_original_sample = model_output
113 | pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5)
114 | elif self.config.prediction_type == "v_prediction":
115 | pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output
116 | pred_epsilon = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample
117 | else:
118 | raise ValueError(f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or" " `v_prediction`")
119 |
120 | # 4. Clip or threshold "predicted x_0"
121 | if self.config.thresholding:
122 | pred_original_sample = self._threshold_sample(pred_original_sample)
123 | elif self.config.clip_sample:
124 | pred_original_sample = pred_original_sample.clamp(-self.config.clip_sample_range, self.config.clip_sample_range)
125 |
126 | # 5. compute variance: "sigma_t(η)" -> see formula (16)
127 | # σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1)
128 | variance = _get_variance(self, timestep, prev_timestep)
129 | std_dev_t = eta * variance ** (0.5)
130 | std_dev_t = _left_broadcast(std_dev_t, sample.shape).to(sample.device)
131 |
132 | if use_clipped_model_output:
133 | # the pred_epsilon is always re-derived from the clipped x_0 in Glide
134 | pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5)
135 |
136 | # 6. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
137 | pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * pred_epsilon
138 |
139 | # 7. compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
140 | prev_sample_mean = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction
141 |
142 | if prev_sample is not None and generator is not None:
143 | raise ValueError(
144 | "Cannot pass both generator and prev_sample. Please make sure that either `generator` or" " `prev_sample` stays `None`."
145 | )
146 |
147 | if prev_sample is None:
148 | variance_noise = randn_tensor(model_output.shape, generator=generator, device=model_output.device, dtype=model_output.dtype)
149 | prev_sample = prev_sample_mean + std_dev_t * variance_noise
150 |
151 | # log prob of prev_sample given prev_sample_mean and std_dev_t
152 | log_prob = (
153 | -((prev_sample.detach() - prev_sample_mean) ** 2) / (2 * (std_dev_t**2))
154 | - torch.log(std_dev_t)
155 | - torch.log(torch.sqrt(2 * torch.as_tensor(math.pi)))
156 | )
157 | # mean along all but batch dimension
158 | log_prob = log_prob.mean(dim=tuple(range(1, log_prob.ndim)))
159 |
160 | return prev_sample.type(sample.dtype), log_prob
161 |
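As a quick sanity check (not part of the file), the closed-form log probability assembled at the end of `ddim_step_with_logprob` is the elementwise Gaussian log-density and can be compared against `torch.distributions.Normal`; the shapes and the 0.3 standard deviation below are arbitrary illustrative values.

import math

import torch

mu = torch.randn(2, 4, 8, 8)           # stands in for prev_sample_mean
sigma = torch.full_like(mu, 0.3)        # stands in for std_dev_t (requires eta > 0)
x = mu + sigma * torch.randn_like(mu)   # stands in for prev_sample

manual = -((x - mu) ** 2) / (2 * sigma**2) - torch.log(sigma) - torch.log(torch.sqrt(2 * torch.as_tensor(math.pi)))
reference = torch.distributions.Normal(mu, sigma).log_prob(x)
assert torch.allclose(manual, reference, atol=1e-5)

# As in the function above, reduce to one log prob per batch element.
per_sample = manual.mean(dim=tuple(range(1, manual.ndim)))
print(per_sample.shape)  # torch.Size([2])

With `eta = 0` the DDIM step is deterministic (`std_dev_t = 0`), so `torch.log(std_dev_t)` diverges; callers therefore need a strictly positive `eta` to obtain finite log probabilities.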
--------------------------------------------------------------------------------
/scripts/train_kohya/ddpo_pytorch/diffusers_patch/pipeline_with_logprob.py:
--------------------------------------------------------------------------------
1 | # Copied from https://github.com/huggingface/diffusers/blob/fc6acb6b97e93d58cb22b5fee52d884d77ce84d8/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
2 | # with the following modifications:
3 | # - It uses the patched version of `ddim_step_with_logprob` from `ddim_with_logprob.py`. As such, it only supports the
4 | # `ddim` scheduler.
5 | # - It returns all the intermediate latents of the denoising process as well as the log probs of each denoising step.
6 | # Copied from https://github.com/kvablack/ddpo-pytorch/blob/main/ddpo_pytorch/diffusers_patch/pipeline_with_logprob.py.
7 |
8 | from typing import Any, Callable, Dict, List, Optional, Union
9 |
10 | import torch
11 | from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipeline, rescale_noise_cfg
12 |
13 | from .ddim_with_logprob import ddim_step_with_logprob
14 |
15 |
16 | @torch.no_grad()
17 | def pipeline_with_logprob(
18 | self: StableDiffusionPipeline,
19 | prompt: Union[str, List[str]] = None,
20 | height: Optional[int] = None,
21 | width: Optional[int] = None,
22 | num_inference_steps: int = 50,
23 | guidance_scale: float = 7.5,
24 | negative_prompt: Optional[Union[str, List[str]]] = None,
25 | num_images_per_prompt: Optional[int] = 1,
26 | eta: float = 0.0,
27 | generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
28 | latents: Optional[torch.FloatTensor] = None,
29 | prompt_embeds: Optional[torch.FloatTensor] = None,
30 | negative_prompt_embeds: Optional[torch.FloatTensor] = None,
31 | output_type: Optional[str] = "pil",
32 | return_dict: bool = True,
33 | callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
34 | callback_steps: int = 1,
35 | cross_attention_kwargs: Optional[Dict[str, Any]] = None,
36 | guidance_rescale: float = 0.0,
37 | ):
38 | r"""
39 | Function invoked when calling the pipeline for generation.
40 |
41 | Args:
42 | prompt (`str` or `List[str]`, *optional*):
 43 |             The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
 44 |             instead.
45 | height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
46 | The height in pixels of the generated image.
47 | width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
48 | The width in pixels of the generated image.
49 | num_inference_steps (`int`, *optional*, defaults to 50):
50 | The number of denoising steps. More denoising steps usually lead to a higher quality image at the
51 | expense of slower inference.
52 | guidance_scale (`float`, *optional*, defaults to 7.5):
53 | Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
54 | `guidance_scale` is defined as `w` of equation 2. of [Imagen
55 | Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
 56 |             1`. Higher guidance scale encourages the model to generate images that are closely linked to the text `prompt`,
57 | usually at the expense of lower image quality.
58 | negative_prompt (`str` or `List[str]`, *optional*):
59 | The prompt or prompts not to guide the image generation. If not defined, one has to pass
60 | `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
61 | less than `1`).
62 | num_images_per_prompt (`int`, *optional*, defaults to 1):
63 | The number of images to generate per prompt.
64 | eta (`float`, *optional*, defaults to 0.0):
65 | Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
66 | [`schedulers.DDIMScheduler`], will be ignored for others.
67 | generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
68 | One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
69 | to make generation deterministic.
70 | latents (`torch.FloatTensor`, *optional*):
71 | Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
72 | generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
 73 |             tensor will be generated by sampling using the supplied random `generator`.
74 | prompt_embeds (`torch.FloatTensor`, *optional*):
75 | Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
76 | provided, text embeddings will be generated from `prompt` input argument.
77 | negative_prompt_embeds (`torch.FloatTensor`, *optional*):
78 | Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
79 | weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
80 | argument.
81 | output_type (`str`, *optional*, defaults to `"pil"`):
 82 |             The output format of the generated image. Choose between
83 | [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
84 | return_dict (`bool`, *optional*, defaults to `True`):
85 | Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
86 | plain tuple.
87 | callback (`Callable`, *optional*):
88 | A function that will be called every `callback_steps` steps during inference. The function will be
89 | called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
90 | callback_steps (`int`, *optional*, defaults to 1):
91 | The frequency at which the `callback` function will be called. If not specified, the callback will be
92 | called at every step.
93 | cross_attention_kwargs (`dict`, *optional*):
94 | A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
95 | `self.processor` in
96 | [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
 97 |         guidance_rescale (`float`, *optional*, defaults to 0.0):
 98 |             Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are
 99 |             Flawed](https://arxiv.org/pdf/2305.08891.pdf). `guidance_rescale` is defined as `φ` in equation 16. of
100 | [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf).
101 | Guidance rescale factor should fix overexposure when using zero terminal SNR.
102 |
103 | Examples:
104 |
105 | Returns:
106 | [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
107 |         [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
108 | When returning a tuple, the first element is a list with the generated images, and the second element is a
109 | list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
110 | (nsfw) content, according to the `safety_checker`.
111 | """
112 | # 0. Default height and width to unet
113 | height = height or self.unet.config.sample_size * self.vae_scale_factor
114 | width = width or self.unet.config.sample_size * self.vae_scale_factor
115 |
116 | # 1. Check inputs. Raise error if not correct
117 | self.check_inputs(prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds)
118 |
119 | # 2. Define call parameters
120 | if prompt is not None and isinstance(prompt, str):
121 | batch_size = 1
122 | elif prompt is not None and isinstance(prompt, list):
123 | batch_size = len(prompt)
124 | else:
125 | batch_size = prompt_embeds.shape[0]
126 |
127 | device = self._execution_device
128 | # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
129 | # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
130 | # corresponds to doing no classifier free guidance.
131 | do_classifier_free_guidance = guidance_scale > 1.0
132 |
133 | # 3. Encode input prompt
134 | text_encoder_lora_scale = cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
135 | prompt_embeds = self._encode_prompt(
136 | prompt,
137 | device,
138 | num_images_per_prompt,
139 | do_classifier_free_guidance,
140 | negative_prompt,
141 | prompt_embeds=prompt_embeds,
142 | negative_prompt_embeds=negative_prompt_embeds,
143 | lora_scale=text_encoder_lora_scale,
144 | )
145 |
146 | # 4. Prepare timesteps
147 | self.scheduler.set_timesteps(num_inference_steps, device=device)
148 | timesteps = self.scheduler.timesteps
149 |
150 | # 5. Prepare latent variables
151 | num_channels_latents = self.unet.config.in_channels
152 | latents = self.prepare_latents(
153 | batch_size * num_images_per_prompt,
154 | num_channels_latents,
155 | height,
156 | width,
157 | prompt_embeds.dtype,
158 | device,
159 | generator,
160 | latents,
161 | )
162 |
163 | # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
164 | extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
165 |
166 | # 7. Denoising loop
167 | num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
168 | all_latents = [latents]
169 | all_log_probs = []
170 | with self.progress_bar(total=num_inference_steps) as progress_bar:
171 | for i, t in enumerate(timesteps):
172 | # expand the latents if we are doing classifier free guidance
173 | latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
174 | latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
175 |
176 | # predict the noise residual
177 | noise_pred = self.unet(
178 | latent_model_input,
179 | t,
180 | encoder_hidden_states=prompt_embeds,
181 | cross_attention_kwargs=cross_attention_kwargs,
182 | return_dict=False,
183 | )[0]
184 |
185 | # perform guidance
186 | if do_classifier_free_guidance:
187 | noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
188 | noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
189 |
190 | if do_classifier_free_guidance and guidance_rescale > 0.0:
191 | # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
192 | noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
193 |
194 | # compute the previous noisy sample x_t -> x_t-1
195 | latents, log_prob = ddim_step_with_logprob(self.scheduler, noise_pred, t, latents, **extra_step_kwargs)
196 |
197 | all_latents.append(latents)
198 | all_log_probs.append(log_prob)
199 |
200 | # call the callback, if provided
201 | if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
202 | progress_bar.update()
203 | if callback is not None and i % callback_steps == 0:
204 | callback(i, t, latents)
205 |
206 | if not output_type == "latent":
207 | image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
208 | image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
209 | else:
210 | image = latents
211 | has_nsfw_concept = None
212 |
213 | if has_nsfw_concept is None:
214 | do_denormalize = [True] * image.shape[0]
215 | else:
216 | do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
217 |
218 | image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
219 |
220 | # Offload last model to CPU
221 | if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
222 | self.final_offload_hook.offload()
223 |
224 | return image, has_nsfw_concept, all_latents, all_log_probs
225 |
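A hedged usage sketch (not part of the file): since `pipeline_with_logprob` takes the pipeline as its explicit first argument, it can be called on a regular `StableDiffusionPipeline` configured with a DDIM scheduler. The checkpoint name, device, dtype, and prompt below are assumptions for illustration, and the import path is inferred from this repository's layout.

import torch
from diffusers import DDIMScheduler, StableDiffusionPipeline

# Assumed import path, mirroring scripts/train_kohya/ddpo_pytorch/diffusers_patch/.
from ddpo_pytorch.diffusers_patch.pipeline_with_logprob import pipeline_with_logprob

pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16).to("cuda")
pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)  # only DDIM is supported by this patch

images, nsfw, all_latents, all_log_probs = pipeline_with_logprob(
    pipe,
    prompt=["easyphoto_face, easyphoto, 1person"],
    num_inference_steps=50,
    guidance_scale=7.5,
    eta=1.0,          # eta > 0 so each step has nonzero variance and a finite log prob
    output_type="pt",
)
# all_latents holds num_inference_steps + 1 tensors (x_T ... x_0); all_log_probs holds one entry per step.

In DDPO-style training, the stored latent pairs can later be replayed through `ddim_step_with_logprob(..., prev_sample=next_latent)` to recompute the per-step log probabilities under updated model weights.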
--------------------------------------------------------------------------------
/scripts/train_kohya/ddpo_pytorch/prompts.py:
--------------------------------------------------------------------------------
1 | """This file defines customize prompt funtions. The prompt function which takes no arguments
2 | generates a random prompt from the given prompt distribution each time it is called.
3 | """
4 |
5 |
6 | def easyphoto() -> str:
7 | return "easyphoto_face, easyphoto, 1person"
8 |
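If a prompt distribution rather than a fixed string is wanted, a prompt function can sample from a list each time it is called, as the module docstring describes. The candidate suffixes below are made-up illustrations, not prompts shipped with EasyPhoto.

import random

def easyphoto_random_scene() -> str:
    # Hypothetical scene suffixes, purely for illustration.
    scenes = ["wearing a suit", "in a garden", "at the beach"]
    return "easyphoto_face, easyphoto, 1person, " + random.choice(scenes)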
--------------------------------------------------------------------------------
/scripts/train_kohya/ddpo_pytorch/rewards.py:
--------------------------------------------------------------------------------
1 | """This file defines customize reward funtions. The reward function which takes in a batch of images
2 | and corresponding prompts returns a batch of rewards each time it is called.
3 | """
4 |
5 | from pathlib import Path
6 | from typing import Callable, List, Tuple, Union
7 |
8 | import numpy as np
9 | import torch
10 | from modelscope.outputs import OutputKeys
11 | from modelscope.pipelines import pipeline
12 | from modelscope.utils.constant import Tasks
13 | from PIL import Image
14 |
15 |
16 | def _convert_images(images: Union[List[Image.Image], np.array, torch.Tensor]) -> List[Image.Image]:
17 | if isinstance(images, List) and isinstance(images[0], Image.Image):
18 | return images
19 | if isinstance(images, torch.Tensor):
20 | images = (images * 255).round().clamp(0, 255).to(torch.uint8).cpu().numpy()
21 | images = images.transpose(0, 2, 3, 1) # NCHW -> NHWC
22 | images = [Image.fromarray(image) for image in images]
23 |
24 | return images
25 |
26 |
27 | def faceid_SCRFD(target_image_dir: str) -> Callable:
 28 |     """Returns a Face ID reward function for the given user ID. It uses SCRFD to detect the face and then
 29 |     CurricularFace to extract the face feature. SCRFD requires the extra package mmcv-full.
30 |
31 | Args:
32 | target_image_dir (str): The directory of processed face image files (.jpg) given a user ID.
33 | """
34 | face_recognition = pipeline(Tasks.face_recognition, model="damo/cv_ir101_facerecognition_cfglint")
35 | target_image_files = Path(target_image_dir).glob("*.jpg")
36 | target_images = [Image.open(f).convert("RGB") for f in target_image_files]
37 | target_embs = [face_recognition(t)[OutputKeys.IMG_EMBEDDING][0] for t in target_images] # (M, 512)
38 | target_mean_emb = np.mean(np.vstack(target_embs), axis=0) # (512,)
39 |
40 | # Redundant parameters are for backward compatibility.
41 | def __call__(src_images: Union[List[Image.Image], np.array, torch.Tensor], prompts: List[str]) -> Tuple[np.array, dict]:
42 | src_images = _convert_images(src_images)
43 | src_embs = [] # (N, 512)
44 | for s in src_images:
45 | try:
46 | emb = face_recognition(s)[OutputKeys.IMG_EMBEDDING][0]
47 | except Exception as e: # TypeError
 48 |                 print("Caught exception in the reward function faceid_SCRFD: {}".format(e))
49 | emb = np.array([0] * 512)
 50 |                 print("No face is detected or the detected face is too small. Set the embedding to zero.")
51 | finally:
52 | src_embs.append(emb)
53 | faceid_list = np.dot(src_embs, target_mean_emb)
54 |
55 | return faceid_list
56 |
57 | return __call__
58 |
59 |
60 | def faceid_retina(target_image_dir: str) -> Callable:
 61 |     """Returns a Face ID reward function for the given user ID. It uses RetinaFace to detect the face and then
 62 |     CurricularFace to extract the face feature. As the detection capability of RetinaFace is weaker than that of
 63 |     SCRFD, many generated side faces cannot be detected.
64 |
65 | Args:
66 | target_image_dir (str): The directory of processed face image files (.jpg) given a user ID.
67 | """
68 | # The retinaface detection is built into the face recognition pipeline.
69 | face_recognition = pipeline("face_recognition", model="bubbliiiing/cv_retinafce_recognition", model_revision="v1.0.3")
70 | target_image_files = Path(target_image_dir).glob("*.jpg")
71 | target_images = [Image.open(f).convert("RGB") for f in target_image_files]
72 |
73 | target_embs = [face_recognition(dict(user=f))[OutputKeys.IMG_EMBEDDING] for f in target_images] # (M, 512)
74 | target_mean_emb = np.mean(np.vstack(target_embs), axis=0) # (512,)
75 |
76 | # Redundant parameters are for backward compatibility.
77 | def __call__(src_images: Union[List[Image.Image], np.array, torch.Tensor], prompts: List[str]) -> Tuple[np.array, dict]:
78 | src_images = _convert_images(src_images)
79 | src_embs = [] # (N, 512)
80 | for s in src_images:
81 | try:
82 | emb = face_recognition(dict(user=s))[OutputKeys.IMG_EMBEDDING][0]
83 | except Exception as e: # cv2.error; TypeError.
 84 |                 print("Caught exception in the reward function faceid_retina: {}".format(e))
85 | emb = np.array([0] * 512)
 86 |                 print("No face is detected or the detected face is too small. Set the embedding to zero.")
87 | finally:
88 | src_embs.append(emb)
89 | faceid_list = np.dot(src_embs, target_mean_emb)
90 |
91 | return faceid_list
92 |
93 | return __call__
94 |
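Any custom reward can follow the same closure pattern as the two functions above: the outer function does one-time setup and the inner `__call__` maps a batch of images plus prompts to one score per image. The brightness criterion below is a made-up placeholder that reuses this module's imports and the `_convert_images` helper; it only sketches the interface and is not a reward used by EasyPhoto.

def mean_brightness() -> Callable:
    """Toy reward: the mean grayscale brightness of each image, scaled to [0, 1]."""

    def __call__(src_images: Union[List[Image.Image], np.ndarray, torch.Tensor], prompts: List[str]) -> np.ndarray:
        src_images = _convert_images(src_images)
        rewards = [np.asarray(img.convert("L"), dtype=np.float32).mean() / 255.0 for img in src_images]
        return np.array(rewards)

    return __call__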
--------------------------------------------------------------------------------
/scripts/train_kohya/ddpo_pytorch/stat_tracking.py:
--------------------------------------------------------------------------------
1 | """Borrowed from https://github.com/kvablack/ddpo-pytorch/blob/main/ddpo_pytorch/stat_tracking.py.
2 | """
3 |
4 | from collections import deque
5 | from typing import List
6 |
7 | import numpy as np
8 |
9 |
10 | class PerPromptStatTracker:
11 | """Track the mean and std of reward on a per-prompt basis and use that to compute advantages.
12 |
13 | Args:
14 | buffer_size (int): The number of reward values to store in the buffer for each prompt.
15 | The buffer persists across epochs.
16 | min_count (int): The minimum number of reward values to store in the buffer before using the
17 | per-prompt mean and std. If the buffer contains fewer than `min_count` values, the mean and
18 | std of the entire batch will be used instead.
19 | """
20 |
21 | def __init__(self, buffer_size: int, min_count: int):
22 | self.buffer_size = buffer_size
23 | self.min_count = min_count
24 | self.stats = {}
25 |
26 | def update(self, prompts: List[str], rewards: List[float]):
27 | prompts = np.array(prompts)
28 | rewards = np.array(rewards)
29 | unique = np.unique(prompts)
30 | advantages = np.empty_like(rewards)
31 | for prompt in unique:
32 | prompt_rewards = rewards[prompts == prompt]
33 | if prompt not in self.stats:
34 | self.stats[prompt] = deque(maxlen=self.buffer_size)
35 | self.stats[prompt].extend(prompt_rewards)
36 |
37 | if len(self.stats[prompt]) < self.min_count:
38 | mean = np.mean(rewards)
39 | std = np.std(rewards) + 1e-6
40 | else:
41 | mean = np.mean(self.stats[prompt])
42 | std = np.std(self.stats[prompt]) + 1e-6
43 | advantages[prompts == prompt] = (prompt_rewards - mean) / std
44 |
45 | return advantages
46 |
47 | def get_stats(self):
48 | return {k: {"mean": np.mean(v), "std": np.std(v), "count": len(v)} for k, v in self.stats.items()}
49 |
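A small usage sketch with made-up rewards: once a prompt has accumulated at least `min_count` values, its advantages are normalized with the per-prompt statistics, otherwise with the batch mean and std. The import path is inferred from this repository's layout.

# Assumed import path, mirroring scripts/train_kohya/ddpo_pytorch/.
from ddpo_pytorch.stat_tracking import PerPromptStatTracker

tracker = PerPromptStatTracker(buffer_size=16, min_count=4)
prompts = ["easyphoto_face, easyphoto, 1person"] * 4 + ["a photo of a cat"] * 4
rewards = [0.2, 0.4, 0.6, 0.8, 1.0, 1.2, 1.4, 1.6]

advantages = tracker.update(prompts, rewards)
print(advantages.shape)     # (8,) -- one advantage per sample, normalized per prompt
print(tracker.get_stats())  # per-prompt mean / std / count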
--------------------------------------------------------------------------------
/scripts/train_kohya/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aigc-apps/sd-webui-EasyPhoto/6c1f8959c4024b0f8a1187fae7c50c0c8f86a8df/scripts/train_kohya/utils/__init__.py
--------------------------------------------------------------------------------
/scripts/train_kohya/utils/gpu_info.py:
--------------------------------------------------------------------------------
1 | import csv
2 | import platform
3 | import time
4 | from datetime import datetime
5 | from multiprocessing import Process, Value
6 | from os import makedirs, path
7 |
8 | import matplotlib.pyplot as plt
9 | import matplotlib.ticker as ticker
10 |
11 | if platform.system() != "Windows":
12 | try:
13 | from nvitop import Device
14 | except Exception:
15 | Device = None
16 |
17 | # Constants
18 | BYTES_PER_GB = 1024 * 1024 * 1024
19 |
20 |
21 | def bytes_to_gb(bytes_value: int) -> float:
22 | """Convert bytes to gigabytes."""
23 | return bytes_value / BYTES_PER_GB
24 |
25 |
26 | def log_device_info(device, prefix: str, csvwriter, display_log):
27 | """
28 | Logs device information.
29 |
 30 |     Args:
 31 |         device: The device object with GPU information.
 32 |         prefix: The prefix string for log identification.
 33 |         csvwriter: The CSV writer object for writing logs to a file.
 34 |         display_log: Whether to also print the information to the console.
35 | """
36 | total_memory_gb = float(device.memory_total_human()[:-3])
37 | used_memory_bytes = device.memory_used()
38 | gpu_utilization = device.gpu_utilization()
39 |
40 | if display_log:
41 | print(f"Device: {device.name}")
42 | print(f" - Used memory : {bytes_to_gb(used_memory_bytes):.2f} GB")
43 | print(f" - Used memory% : {bytes_to_gb(used_memory_bytes)/total_memory_gb * 100:.2f}%")
44 | print(f" - GPU utilization: {gpu_utilization}%")
45 | print("-" * 40)
46 |
 47 |     current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
 48 |     # Memory is logged in GB, matching the "Used Memory (GB)" CSV header and plot label.
 49 |
 50 |     csvwriter.writerow([current_time, bytes_to_gb(used_memory_bytes), gpu_utilization])
51 |
52 |
53 | def monitor_and_plot(prefix="result/tmp", display_log=False, stop_flag: Value = None):
54 | """Monitor and plot GPU usage.
55 | Args:
56 | prefix: The prefix of the output file.
57 | stop_flag: A multiprocessing.Value to indicate if monitoring should stop.
58 | """
59 | devices = Device.all()
60 | initial_pids = set()
61 | monitored_pids = set()
62 |
63 | with open(f"{prefix}.csv", "w", newline="") as csvfile:
64 | csvwriter = csv.writer(csvfile)
 65 |         csvwriter.writerow(["Time", "Used Memory (GB)", "GPU Utilization"])
66 |
67 | try:
68 | while True:
69 | if stop_flag and stop_flag.value:
70 | break
71 |
72 | for device in devices:
73 | current_pids = set(device.processes().keys())
74 | if not initial_pids:
75 | initial_pids = current_pids
76 |
77 | new_pids = current_pids - initial_pids
78 | if new_pids:
79 | monitored_pids.update(new_pids)
80 |
81 | for pid in monitored_pids.copy():
82 | if pid not in current_pids:
83 | monitored_pids.remove(pid)
84 | if not monitored_pids:
85 | raise StopIteration
86 | log_device_info(device, prefix, csvwriter, display_log)
87 | time.sleep(1)
88 | except StopIteration:
89 | pass
90 |
91 | plot_data(prefix)
92 | return
93 |
94 |
95 | def plot_data(prefix):
96 | """Plot the data from the CSV file.
97 | Args:
98 | prefix: The prefix of the CSV file.
99 | """
100 | data = list(csv.reader(open(f"{prefix}.csv")))
101 | if len(data) < 2:
102 | print("Insufficient data for plotting.")
103 | return
104 |
105 | time_stamps, used_memory, gpu_utilization = zip(*data[1:])
106 | used_memory = [float(x) for x in used_memory]
107 | gpu_utilization = [float(x) for x in gpu_utilization]
108 | if len(used_memory) >= 10:
109 | tick_spacing = len(used_memory) // 10
110 | else:
111 | tick_spacing = 1
112 |
113 | try:
114 | plot_graph(
115 | time_stamps,
116 | used_memory,
117 | "Used Memory (GB)",
118 | "Time",
119 | "Used Memory (GB)",
120 | "Used Memory Over Time",
121 | tick_spacing,
122 | f"{prefix}_memory.png",
123 | )
124 | except Exception as e:
125 |         message = f"plot_graph for Used Memory failed, error info: {str(e)}"
126 | print(message)
127 |
128 | try:
129 | plot_graph(
130 | time_stamps,
131 | gpu_utilization,
132 | "GPU Utilization (%)",
133 | "Time",
134 | "GPU Utilization (%)",
135 | "GPU Utilization Over Time",
136 | tick_spacing,
137 | f"{prefix}_utilization.png",
138 | )
139 | except Exception as e:
140 |         message = f"plot_graph for GPU Utilization failed, error info: {str(e)}"
141 | print(message)
142 |
143 |
144 | def plot_graph(x, y, label, xlabel, ylabel, title, tick_spacing, filename):
145 | """Generate and save a plot.
146 | Args:
147 | x: X-axis data.
148 | y: Y-axis data.
149 | label: The label for the plot.
150 | xlabel: Label for X-axis.
151 | ylabel: Label for Y-axis.
152 | title: The title of the plot.
153 | tick_spacing: Interval for tick marks on the x-axis.
154 | filename: The filename to save the plot.
155 | """
156 | plt.figure(figsize=(10, 6))
157 | plt.plot(x, y, label=label)
158 | plt.xlabel(xlabel)
159 | plt.ylabel(ylabel)
160 | plt.title(title)
161 | plt.legend()
162 | ax = plt.gca()
163 | ax.xaxis.set_major_locator(ticker.MultipleLocator(tick_spacing))
164 | plt.xticks(rotation=45)
165 | plt.savefig(filename)
166 |
167 |
168 | def gpu_monitor_decorator(prefix="result/gpu_info", display_log=False):
169 | def actual_decorator(func):
170 | def wrapper(*args, **kwargs):
171 | if platform.system() != "Windows" and Device is not None:
172 | timestamp = datetime.now().strftime("%Y%m%d%H%M%S")
173 | dynamic_prefix = f"{prefix}/{func.__name__}_{timestamp}"
174 |
175 | directory = path.dirname(dynamic_prefix)
176 | if not path.exists(directory):
177 | try:
178 | makedirs(directory)
179 | except Exception as e:
180 |                         comment = f"GPU info recording needs a result/gpu_info directory in your SDWebUI; creating it failed with {str(e)}"
181 | print(comment)
182 | dynamic_prefix = f"{func.__name__}_{timestamp}"
183 |
184 | stop_flag = Value("b", False)
185 |
186 | monitor_proc = Process(target=monitor_and_plot, args=(dynamic_prefix, display_log, stop_flag))
187 | monitor_proc.start()
188 |
189 | try:
190 | result = func(*args, **kwargs)
191 | finally:
192 | stop_flag.value = True
193 | monitor_proc.join()
194 | else:
195 | result = func(*args, **kwargs)
196 | return result
197 |
198 | return wrapper
199 |
200 | return actual_decorator
201 |
202 |
203 | if __name__ == "__main__":
204 | pass
205 |
206 |     # Example: define a GPU inference function and wrap it with gpu_monitor_decorator.
207 | @gpu_monitor_decorator()
208 | def execute_process(repeat=5):
209 | from modelscope.pipelines import pipeline
210 | from modelscope.utils.constant import Tasks
211 |
212 | retina_face_detection = pipeline(Tasks.face_detection, "damo/cv_resnet50_face-detection_retinaface")
213 | img_path = "https://modelscope.oss-cn-beijing.aliyuncs.com/test/images/retina_face_detection.jpg"
214 |
215 | for i in range(repeat):
216 | retina_face_detection([img_path] * 10)
217 | return
218 |
219 |     # Run the example when this module is executed directly.
220 |     execute_process(repeat=5)
221 |
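For completeness, a hedged sketch of driving the monitor directly rather than through `gpu_monitor_decorator`: start `monitor_and_plot` in a subprocess with a shared stop flag, run the workload, then flip the flag and join. It assumes a non-Windows host with `nvitop` installed and an NVIDIA GPU visible; the output prefix, the `time.sleep` workload, and the import path are placeholders.

import time
from multiprocessing import Process, Value
from os import makedirs

from gpu_info import monitor_and_plot  # assumed import path for this module

if __name__ == "__main__":
    makedirs("result", exist_ok=True)   # monitor_and_plot writes {prefix}.csv into an existing directory
    stop_flag = Value("b", False)
    monitor = Process(target=monitor_and_plot, args=("result/manual_run", False, stop_flag))
    monitor.start()
    try:
        time.sleep(10)                  # placeholder for the real GPU workload
    finally:
        stop_flag.value = True          # ask the monitoring loop to exit
        monitor.join()                  # result/manual_run.csv and the two PNG plots are written on exit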
--------------------------------------------------------------------------------