├── .eslintignore ├── .eslintrc.js ├── .git-blame-ignore-revs ├── .github ├── ISSUE_TEMPLATE │ ├── bug_report.yml │ ├── config.yml │ └── feature_request.yml ├── pull_request_template.md └── workflows │ ├── on_pull_request.yaml │ ├── run_tests.yaml │ └── warns_merge_master.yml ├── .gitignore ├── .pylintrc ├── CHANGELOG.md ├── CITATION.cff ├── CODEOWNERS ├── LICENSE.txt ├── README.md ├── _typos.toml ├── configs ├── alt-diffusion-inference.yaml ├── alt-diffusion-m18-inference.yaml ├── instruct-pix2pix.yaml ├── sd3-inference.yaml ├── sd_xl_inpaint.yaml ├── v1-inference.yaml └── v1-inpainting-inference.yaml ├── embeddings └── Place Textual Inversion embeddings here.txt ├── environment-wsl2.yaml ├── extensions-builtin ├── LDSR │ ├── ldsr_model_arch.py │ ├── preload.py │ ├── scripts │ │ └── ldsr_model.py │ ├── sd_hijack_autoencoder.py │ ├── sd_hijack_ddpm_v1.py │ └── vqvae_quantize.py ├── Lora │ ├── extra_networks_lora.py │ ├── lora.py │ ├── lora_logger.py │ ├── lora_patches.py │ ├── lyco_helpers.py │ ├── network.py │ ├── network_full.py │ ├── network_glora.py │ ├── network_hada.py │ ├── network_ia3.py │ ├── network_lokr.py │ ├── network_lora.py │ ├── network_norm.py │ ├── network_oft.py │ ├── networks.py │ ├── preload.py │ ├── scripts │ │ └── lora_script.py │ ├── ui_edit_user_metadata.py │ └── ui_extra_networks_lora.py ├── ScuNET │ ├── preload.py │ └── scripts │ │ └── scunet_model.py ├── SwinIR │ ├── preload.py │ └── scripts │ │ └── swinir_model.py ├── canvas-zoom-and-pan │ ├── javascript │ │ └── zoom.js │ ├── scripts │ │ └── hotkey_config.py │ └── style.css ├── extra-options-section │ └── scripts │ │ └── extra_options_section.py ├── hypertile │ ├── hypertile.py │ └── scripts │ │ └── hypertile_script.py ├── mobile │ └── javascript │ │ └── mobile.js ├── postprocessing-for-training │ └── scripts │ │ ├── postprocessing_autosized_crop.py │ │ ├── postprocessing_caption.py │ │ ├── postprocessing_create_flipped_copies.py │ │ ├── postprocessing_focal_crop.py │ │ └── postprocessing_split_oversized.py ├── prompt-bracket-checker │ └── javascript │ │ └── prompt-bracket-checker.js └── soft-inpainting │ └── scripts │ └── soft_inpainting.py ├── extensions └── put extensions here.txt ├── html ├── card-no-preview.png ├── extra-networks-card.html ├── extra-networks-copy-path-button.html ├── extra-networks-edit-item-button.html ├── extra-networks-metadata-button.html ├── extra-networks-no-cards.html ├── extra-networks-pane-dirs.html ├── extra-networks-pane-tree.html ├── extra-networks-pane.html ├── extra-networks-tree-button.html ├── footer.html └── licenses.html ├── javascript ├── aspectRatioOverlay.js ├── contextMenus.js ├── dragdrop.js ├── edit-attention.js ├── edit-order.js ├── extensions.js ├── extraNetworks.js ├── generationParams.js ├── hints.js ├── hires_fix.js ├── imageMaskFix.js ├── imageviewer.js ├── imageviewerGamepad.js ├── inputAccordion.js ├── localStorage.js ├── localization.js ├── notification.js ├── profilerVisualization.js ├── progressbar.js ├── resizeHandle.js ├── settings.js ├── textualInversion.js ├── token-counters.js ├── ui.js └── ui_settings_hints.js ├── launch.py ├── localizations └── Put localization files here.txt ├── models ├── Stable-diffusion │ └── Put Stable Diffusion checkpoints here.txt ├── VAE-approx │ └── model.pt ├── VAE │ └── Put VAE here.txt ├── deepbooru │ └── Put your deepbooru release project folder here.txt └── karlo │ └── ViT-L-14_stats.th ├── modules ├── Roboto-Regular.ttf ├── api │ ├── api.py │ └── models.py ├── cache.py ├── call_queue.py ├── cmd_args.py ├── 
codeformer_model.py ├── config_states.py ├── dat_model.py ├── deepbooru.py ├── deepbooru_model.py ├── devices.py ├── errors.py ├── esrgan_model.py ├── extensions.py ├── extra_networks.py ├── extra_networks_hypernet.py ├── extras.py ├── face_restoration.py ├── face_restoration_utils.py ├── fifo_lock.py ├── gfpgan_model.py ├── gitpython_hack.py ├── gradio_extensons.py ├── hashes.py ├── hat_model.py ├── hypernetworks │ ├── hypernetwork.py │ └── ui.py ├── images.py ├── img2img.py ├── import_hook.py ├── infotext_utils.py ├── infotext_versions.py ├── initialize.py ├── initialize_util.py ├── interrogate.py ├── launch_utils.py ├── localization.py ├── logging_config.py ├── lowvram.py ├── mac_specific.py ├── masking.py ├── memmon.py ├── modelloader.py ├── models │ ├── diffusion │ │ ├── ddpm_edit.py │ │ └── uni_pc │ │ │ ├── __init__.py │ │ │ ├── sampler.py │ │ │ └── uni_pc.py │ └── sd3 │ │ ├── mmdit.py │ │ ├── other_impls.py │ │ ├── sd3_cond.py │ │ ├── sd3_impls.py │ │ └── sd3_model.py ├── ngrok.py ├── npu_specific.py ├── options.py ├── patches.py ├── paths.py ├── paths_internal.py ├── postprocessing.py ├── processing.py ├── processing_scripts │ ├── comments.py │ ├── refiner.py │ ├── sampler.py │ └── seed.py ├── profiling.py ├── progress.py ├── prompt_parser.py ├── realesrgan_model.py ├── restart.py ├── rng.py ├── rng_philox.py ├── safe.py ├── script_callbacks.py ├── script_loading.py ├── scripts.py ├── scripts_auto_postprocessing.py ├── scripts_postprocessing.py ├── sd_disable_initialization.py ├── sd_emphasis.py ├── sd_hijack.py ├── sd_hijack_checkpoint.py ├── sd_hijack_clip.py ├── sd_hijack_clip_old.py ├── sd_hijack_ip2p.py ├── sd_hijack_open_clip.py ├── sd_hijack_optimizations.py ├── sd_hijack_unet.py ├── sd_hijack_utils.py ├── sd_hijack_xlmr.py ├── sd_models.py ├── sd_models_config.py ├── sd_models_types.py ├── sd_models_xl.py ├── sd_samplers.py ├── sd_samplers_cfg_denoiser.py ├── sd_samplers_common.py ├── sd_samplers_compvis.py ├── sd_samplers_extra.py ├── sd_samplers_kdiffusion.py ├── sd_samplers_lcm.py ├── sd_samplers_timesteps.py ├── sd_samplers_timesteps_impl.py ├── sd_schedulers.py ├── sd_unet.py ├── sd_vae.py ├── sd_vae_approx.py ├── sd_vae_taesd.py ├── shared.py ├── shared_cmd_options.py ├── shared_gradio_themes.py ├── shared_init.py ├── shared_items.py ├── shared_options.py ├── shared_state.py ├── shared_total_tqdm.py ├── styles.py ├── sub_quadratic_attention.py ├── sysinfo.py ├── textual_inversion │ ├── autocrop.py │ ├── dataset.py │ ├── image_embedding.py │ ├── learn_schedule.py │ ├── saving_settings.py │ ├── test_embedding.png │ ├── textual_inversion.py │ └── ui.py ├── timer.py ├── torch_utils.py ├── txt2img.py ├── ui.py ├── ui_checkpoint_merger.py ├── ui_common.py ├── ui_components.py ├── ui_extensions.py ├── ui_extra_networks.py ├── ui_extra_networks_checkpoints.py ├── ui_extra_networks_checkpoints_user_metadata.py ├── ui_extra_networks_hypernets.py ├── ui_extra_networks_textual_inversion.py ├── ui_extra_networks_user_metadata.py ├── ui_gradio_extensions.py ├── ui_loadsave.py ├── ui_postprocessing.py ├── ui_prompt_styles.py ├── ui_settings.py ├── ui_tempdir.py ├── ui_toprow.py ├── upscaler.py ├── upscaler_utils.py ├── util.py ├── xlmr.py ├── xlmr_m18.py └── xpu_specific.py ├── package.json ├── pyproject.toml ├── requirements-test.txt ├── requirements.txt ├── requirements_npu.txt ├── requirements_versions.txt ├── screenshot.png ├── script.js ├── scripts ├── custom_code.py ├── img2imgalt.py ├── loopback.py ├── outpainting_mk_2.py ├── poor_mans_outpainting.py ├── 
postprocessing_codeformer.py ├── postprocessing_gfpgan.py ├── postprocessing_upscale.py ├── prompt_matrix.py ├── prompts_from_file.py ├── sd_upscale.py └── xyz_grid.py ├── style.css ├── test ├── __init__.py ├── conftest.py ├── test_extras.py ├── test_face_restorers.py ├── test_files │ ├── empty.pt │ ├── img2img_basic.png │ ├── mask_basic.png │ └── two-faces.jpg ├── test_img2img.py ├── test_outputs │ └── .gitkeep ├── test_torch_utils.py ├── test_txt2img.py └── test_utils.py ├── textual_inversion_templates ├── hypernetwork.txt ├── none.txt ├── style.txt ├── style_filewords.txt ├── subject.txt └── subject_filewords.txt ├── webui-macos-env.sh ├── webui-user.bat ├── webui-user.sh ├── webui.bat ├── webui.py └── webui.sh /.eslintignore: -------------------------------------------------------------------------------- 1 | extensions 2 | extensions-disabled 3 | repositories 4 | venv -------------------------------------------------------------------------------- /.git-blame-ignore-revs: -------------------------------------------------------------------------------- 1 | # Apply ESLint 2 | 9c54b78d9dde5601e916f308d9a9d6953ec39430 -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/config.yml: -------------------------------------------------------------------------------- 1 | blank_issues_enabled: false 2 | contact_links: 3 | - name: WebUI Community Support 4 | url: https://github.com/AUTOMATIC1111/stable-diffusion-webui/discussions 5 | about: Please ask and answer questions here. 6 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature_request.yml: -------------------------------------------------------------------------------- 1 | name: Feature request 2 | description: Suggest an idea for this project 3 | title: "[Feature Request]: " 4 | labels: ["enhancement"] 5 | 6 | body: 7 | - type: checkboxes 8 | attributes: 9 | label: Is there an existing issue for this? 10 | description: Please search to see if an issue already exists for the feature you want, and that it's not implemented in a recent build/commit. 11 | options: 12 | - label: I have searched the existing issues and checked the recent builds/commits 13 | required: true 14 | - type: markdown 15 | attributes: 16 | value: | 17 | *Please fill this form with as much information as possible, and provide screenshots and/or illustrations of the feature if possible* 18 | - type: textarea 19 | id: feature 20 | attributes: 21 | label: What would your feature do? 22 | description: Tell us about your feature in a very clear and simple way, and what problem it would solve 23 | validations: 24 | required: true 25 | - type: textarea 26 | id: workflow 27 | attributes: 28 | label: Proposed workflow 29 | description: Please provide us with step-by-step information on how you'd like the feature to be accessed and used 30 | value: | 31 | 1. Go to .... 32 | 2. Press .... 33 | 3. ... 34 | validations: 35 | required: true 36 | - type: textarea 37 | id: misc 38 | attributes: 39 | label: Additional information 40 | description: Add any other context or screenshots about the feature request here.
41 | -------------------------------------------------------------------------------- /.github/pull_request_template.md: -------------------------------------------------------------------------------- 1 | ## Description 2 | 3 | * a simple description of what you're trying to accomplish 4 | * a summary of changes in code 5 | * which issues it fixes, if any 6 | 7 | ## Screenshots/videos: 8 | 9 | 10 | ## Checklist: 11 | 12 | - [ ] I have read the [contributing wiki page](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Contributing) 13 | - [ ] I have performed a self-review of my own code 14 | - [ ] My code follows the [style guidelines](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Contributing#code-style) 15 | - [ ] My code passes [tests](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Tests) 16 | -------------------------------------------------------------------------------- /.github/workflows/on_pull_request.yaml: -------------------------------------------------------------------------------- 1 | name: Linter 2 | 3 | on: 4 | - push 5 | - pull_request 6 | 7 | jobs: 8 | lint-python: 9 | name: ruff 10 | runs-on: ubuntu-latest 11 | if: github.event_name != 'pull_request' || github.event.pull_request.head.repo.full_name != github.event.pull_request.base.repo.full_name 12 | steps: 13 | - name: Checkout Code 14 | uses: actions/checkout@v4 15 | - uses: actions/setup-python@v5 16 | with: 17 | python-version: 3.11 18 | # NB: there's no cache: pip here since we're not installing anything 19 | # from the requirements.txt file(s) in the repository; it's faster 20 | # not to have GHA download an (at the time of writing) 4 GB cache 21 | # of PyTorch and other dependencies. 22 | - name: Install Ruff 23 | run: pip install ruff==0.3.3 24 | - name: Run Ruff 25 | run: ruff .
26 | lint-js: 27 | name: eslint 28 | runs-on: ubuntu-latest 29 | if: github.event_name != 'pull_request' || github.event.pull_request.head.repo.full_name != github.event.pull_request.base.repo.full_name 30 | steps: 31 | - name: Checkout Code 32 | uses: actions/checkout@v4 33 | - name: Install Node.js 34 | uses: actions/setup-node@v4 35 | with: 36 | node-version: 18 37 | - run: npm i --ci 38 | - run: npm run lint 39 | -------------------------------------------------------------------------------- /.github/workflows/run_tests.yaml: -------------------------------------------------------------------------------- 1 | name: Tests 2 | 3 | on: 4 | - push 5 | - pull_request 6 | 7 | jobs: 8 | test: 9 | name: tests on CPU with empty model 10 | runs-on: ubuntu-latest 11 | if: github.event_name != 'pull_request' || github.event.pull_request.head.repo.full_name != github.event.pull_request.base.repo.full_name 12 | steps: 13 | - name: Checkout Code 14 | uses: actions/checkout@v4 15 | - name: Set up Python 3.10 16 | uses: actions/setup-python@v5 17 | with: 18 | python-version: 3.10.6 19 | cache: pip 20 | cache-dependency-path: | 21 | **/requirements*txt 22 | launch.py 23 | - name: Cache models 24 | id: cache-models 25 | uses: actions/cache@v4 26 | with: 27 | path: models 28 | key: "2023-12-30" 29 | - name: Install test dependencies 30 | run: pip install wait-for-it -r requirements-test.txt 31 | env: 32 | PIP_DISABLE_PIP_VERSION_CHECK: "1" 33 | PIP_PROGRESS_BAR: "off" 34 | - name: Setup environment 35 | run: python launch.py --skip-torch-cuda-test --exit 36 | env: 37 | PIP_DISABLE_PIP_VERSION_CHECK: "1" 38 | PIP_PROGRESS_BAR: "off" 39 | TORCH_INDEX_URL: https://download.pytorch.org/whl/cpu 40 | WEBUI_LAUNCH_LIVE_OUTPUT: "1" 41 | PYTHONUNBUFFERED: "1" 42 | - name: Print installed packages 43 | run: pip freeze 44 | - name: Start test server 45 | run: > 46 | python -m coverage run 47 | --data-file=.coverage.server 48 | launch.py 49 | --skip-prepare-environment 50 | --skip-torch-cuda-test 51 | --test-server 52 | --do-not-download-clip 53 | --no-half 54 | --disable-opt-split-attention 55 | --use-cpu all 56 | --api-server-stop 57 | 2>&1 | tee output.txt & 58 | - name: Run tests 59 | run: | 60 | wait-for-it --service 127.0.0.1:7860 -t 20 61 | python -m pytest -vv --junitxml=test/results.xml --cov . --cov-report=xml --verify-base-url test 62 | - name: Kill test server 63 | if: always() 64 | run: curl -vv -XPOST http://127.0.0.1:7860/sdapi/v1/server-stop && sleep 10 65 | - name: Show coverage 66 | run: | 67 | python -m coverage combine .coverage* 68 | python -m coverage report -i 69 | python -m coverage html -i 70 | - name: Upload main app output 71 | uses: actions/upload-artifact@v4 72 | if: always() 73 | with: 74 | name: output 75 | path: output.txt 76 | - name: Upload coverage HTML 77 | uses: actions/upload-artifact@v4 78 | if: always() 79 | with: 80 | name: htmlcov 81 | path: htmlcov 82 | -------------------------------------------------------------------------------- /.github/workflows/warns_merge_master.yml: -------------------------------------------------------------------------------- 1 | name: Pull requests can't target master branch 2 | 3 | "on": 4 | pull_request: 5 | types: 6 | - opened 7 | - synchronize 8 | - reopened 9 | branches: 10 | - master 11 | 12 | jobs: 13 | check: 14 | runs-on: ubuntu-latest 15 | steps: 16 | - name: Warn about merge into master 17 | run: | 18 | echo -e "::warning::This pull request merges directly into the \"master\" branch; normally, development happens on the \"dev\" branch."
19 | exit 1 20 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | *.ckpt 3 | *.safetensors 4 | *.pth 5 | .DS_Store 6 | /ESRGAN/* 7 | /SwinIR/* 8 | /repositories 9 | /venv 10 | /tmp 11 | /model.ckpt 12 | /models/**/* 13 | /GFPGANv1.3.pth 14 | /gfpgan/weights/*.pth 15 | /ui-config.json 16 | /outputs 17 | /config.json 18 | /log 19 | /webui.settings.bat 20 | /embeddings 21 | /styles.csv 22 | /params.txt 23 | /styles.csv.bak 24 | /webui-user.bat 25 | /webui-user.sh 26 | /interrogate 27 | /user.css 28 | /.idea 29 | notification.mp3 30 | /SwinIR 31 | /textual_inversion 32 | .vscode 33 | /extensions 34 | /test/stdout.txt 35 | /test/stderr.txt 36 | /cache.json* 37 | /config_states/ 38 | /node_modules 39 | /package-lock.json 40 | /.coverage* 41 | /test/test_outputs 42 | /cache 43 | trace.json 44 | /sysinfo-????-??-??-??-??.json 45 | -------------------------------------------------------------------------------- /.pylintrc: -------------------------------------------------------------------------------- 1 | # See https://pylint.pycqa.org/en/latest/user_guide/messages/message_control.html 2 | [MESSAGES CONTROL] 3 | disable=C,R,W,E,I 4 | -------------------------------------------------------------------------------- /CITATION.cff: -------------------------------------------------------------------------------- 1 | cff-version: 1.2.0 2 | message: "If you use this software, please cite it as below." 3 | authors: 4 | - given-names: AUTOMATIC1111 5 | title: "Stable Diffusion Web UI" 6 | date-released: 2022-08-22 7 | url: "https://github.com/AUTOMATIC1111/stable-diffusion-webui" 8 | -------------------------------------------------------------------------------- /CODEOWNERS: -------------------------------------------------------------------------------- 1 | * @AUTOMATIC1111 2 | 3 | # if you were managing a localization and were removed from this file, this is because 4 | # the intended way to do localizations now is via extensions. See: 5 | # https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Developing-extensions 6 | # Make a repo with your localization and since you are still listed as a collaborator 7 | # you can add it to the wiki page yourself. This change is because some people complained 8 | # the git commit log is cluttered with things unrelated to almost everyone and 9 | # because I believe this is the best overall for the project to handle localizations almost 10 | # entirely without my oversight. 
11 | 12 | 13 | -------------------------------------------------------------------------------- /_typos.toml: -------------------------------------------------------------------------------- 1 | [default.extend-words] 2 | # Part of "RGBa" (Pillow's pre-multiplied alpha RGB mode) 3 | Ba = "Ba" 4 | # HSA is something AMD uses for their GPUs 5 | HSA = "HSA" 6 | -------------------------------------------------------------------------------- /configs/alt-diffusion-inference.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 1.0e-04 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.00085 6 | linear_end: 0.0120 7 | num_timesteps_cond: 1 8 | log_every_t: 200 9 | timesteps: 1000 10 | first_stage_key: "jpg" 11 | cond_stage_key: "txt" 12 | image_size: 64 13 | channels: 4 14 | cond_stage_trainable: false # Note: different from the one we trained before 15 | conditioning_key: crossattn 16 | monitor: val/loss_simple_ema 17 | scale_factor: 0.18215 18 | use_ema: False 19 | 20 | scheduler_config: # 10000 warmup steps 21 | target: ldm.lr_scheduler.LambdaLinearScheduler 22 | params: 23 | warm_up_steps: [ 10000 ] 24 | cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases 25 | f_start: [ 1.e-6 ] 26 | f_max: [ 1. ] 27 | f_min: [ 1. ] 28 | 29 | unet_config: 30 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 31 | params: 32 | image_size: 32 # unused 33 | in_channels: 4 34 | out_channels: 4 35 | model_channels: 320 36 | attention_resolutions: [ 4, 2, 1 ] 37 | num_res_blocks: 2 38 | channel_mult: [ 1, 2, 4, 4 ] 39 | num_heads: 8 40 | use_spatial_transformer: True 41 | transformer_depth: 1 42 | context_dim: 768 43 | use_checkpoint: False 44 | legacy: False 45 | 46 | first_stage_config: 47 | target: ldm.models.autoencoder.AutoencoderKL 48 | params: 49 | embed_dim: 4 50 | monitor: val/rec_loss 51 | ddconfig: 52 | double_z: true 53 | z_channels: 4 54 | resolution: 256 55 | in_channels: 3 56 | out_ch: 3 57 | ch: 128 58 | ch_mult: 59 | - 1 60 | - 2 61 | - 4 62 | - 4 63 | num_res_blocks: 2 64 | attn_resolutions: [] 65 | dropout: 0.0 66 | lossconfig: 67 | target: torch.nn.Identity 68 | 69 | cond_stage_config: 70 | target: modules.xlmr.BertSeriesModelWithTransformation 71 | params: 72 | name: "XLMR-Large" -------------------------------------------------------------------------------- /configs/alt-diffusion-m18-inference.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 1.0e-04 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.00085 6 | linear_end: 0.0120 7 | num_timesteps_cond: 1 8 | log_every_t: 200 9 | timesteps: 1000 10 | first_stage_key: "jpg" 11 | cond_stage_key: "txt" 12 | image_size: 64 13 | channels: 4 14 | cond_stage_trainable: false # Note: different from the one we trained before 15 | conditioning_key: crossattn 16 | monitor: val/loss_simple_ema 17 | scale_factor: 0.18215 18 | use_ema: False 19 | 20 | scheduler_config: # 10000 warmup steps 21 | target: ldm.lr_scheduler.LambdaLinearScheduler 22 | params: 23 | warm_up_steps: [ 10000 ] 24 | cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases 25 | f_start: [ 1.e-6 ] 26 | f_max: [ 1. ] 27 | f_min: [ 1. 
] 28 | 29 | unet_config: 30 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 31 | params: 32 | image_size: 32 # unused 33 | in_channels: 4 34 | out_channels: 4 35 | model_channels: 320 36 | attention_resolutions: [ 4, 2, 1 ] 37 | num_res_blocks: 2 38 | channel_mult: [ 1, 2, 4, 4 ] 39 | num_head_channels: 64 40 | use_spatial_transformer: True 41 | use_linear_in_transformer: True 42 | transformer_depth: 1 43 | context_dim: 1024 44 | use_checkpoint: False 45 | legacy: False 46 | 47 | first_stage_config: 48 | target: ldm.models.autoencoder.AutoencoderKL 49 | params: 50 | embed_dim: 4 51 | monitor: val/rec_loss 52 | ddconfig: 53 | double_z: true 54 | z_channels: 4 55 | resolution: 256 56 | in_channels: 3 57 | out_ch: 3 58 | ch: 128 59 | ch_mult: 60 | - 1 61 | - 2 62 | - 4 63 | - 4 64 | num_res_blocks: 2 65 | attn_resolutions: [] 66 | dropout: 0.0 67 | lossconfig: 68 | target: torch.nn.Identity 69 | 70 | cond_stage_config: 71 | target: modules.xlmr_m18.BertSeriesModelWithTransformation 72 | params: 73 | name: "XLMR-Large" 74 | -------------------------------------------------------------------------------- /configs/instruct-pix2pix.yaml: -------------------------------------------------------------------------------- 1 | # File modified by authors of InstructPix2Pix from original (https://github.com/CompVis/stable-diffusion). 2 | # See more details in LICENSE. 3 | 4 | model: 5 | base_learning_rate: 1.0e-04 6 | target: modules.models.diffusion.ddpm_edit.LatentDiffusion 7 | params: 8 | linear_start: 0.00085 9 | linear_end: 0.0120 10 | num_timesteps_cond: 1 11 | log_every_t: 200 12 | timesteps: 1000 13 | first_stage_key: edited 14 | cond_stage_key: edit 15 | # image_size: 64 16 | # image_size: 32 17 | image_size: 16 18 | channels: 4 19 | cond_stage_trainable: false # Note: different from the one we trained before 20 | conditioning_key: hybrid 21 | monitor: val/loss_simple_ema 22 | scale_factor: 0.18215 23 | use_ema: false 24 | 25 | scheduler_config: # 10000 warmup steps 26 | target: ldm.lr_scheduler.LambdaLinearScheduler 27 | params: 28 | warm_up_steps: [ 0 ] 29 | cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases 30 | f_start: [ 1.e-6 ] 31 | f_max: [ 1. ] 32 | f_min: [ 1. 
] 33 | 34 | unet_config: 35 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 36 | params: 37 | image_size: 32 # unused 38 | in_channels: 8 39 | out_channels: 4 40 | model_channels: 320 41 | attention_resolutions: [ 4, 2, 1 ] 42 | num_res_blocks: 2 43 | channel_mult: [ 1, 2, 4, 4 ] 44 | num_heads: 8 45 | use_spatial_transformer: True 46 | transformer_depth: 1 47 | context_dim: 768 48 | use_checkpoint: False 49 | legacy: False 50 | 51 | first_stage_config: 52 | target: ldm.models.autoencoder.AutoencoderKL 53 | params: 54 | embed_dim: 4 55 | monitor: val/rec_loss 56 | ddconfig: 57 | double_z: true 58 | z_channels: 4 59 | resolution: 256 60 | in_channels: 3 61 | out_ch: 3 62 | ch: 128 63 | ch_mult: 64 | - 1 65 | - 2 66 | - 4 67 | - 4 68 | num_res_blocks: 2 69 | attn_resolutions: [] 70 | dropout: 0.0 71 | lossconfig: 72 | target: torch.nn.Identity 73 | 74 | cond_stage_config: 75 | target: ldm.modules.encoders.modules.FrozenCLIPEmbedder 76 | 77 | data: 78 | target: main.DataModuleFromConfig 79 | params: 80 | batch_size: 128 81 | num_workers: 1 82 | wrap: false 83 | validation: 84 | target: edit_dataset.EditDataset 85 | params: 86 | path: data/clip-filtered-dataset 87 | cache_dir: data/ 88 | cache_name: data_10k 89 | split: val 90 | min_text_sim: 0.2 91 | min_image_sim: 0.75 92 | min_direction_sim: 0.2 93 | max_samples_per_prompt: 1 94 | min_resize_res: 512 95 | max_resize_res: 512 96 | crop_res: 512 97 | output_as_edit: False 98 | real_input: True 99 | -------------------------------------------------------------------------------- /configs/sd3-inference.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | target: modules.models.sd3.sd3_model.SD3Inferencer 3 | params: 4 | shift: 3 5 | state_dict: null 6 | -------------------------------------------------------------------------------- /configs/v1-inference.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 1.0e-04 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.00085 6 | linear_end: 0.0120 7 | num_timesteps_cond: 1 8 | log_every_t: 200 9 | timesteps: 1000 10 | first_stage_key: "jpg" 11 | cond_stage_key: "txt" 12 | image_size: 64 13 | channels: 4 14 | cond_stage_trainable: false # Note: different from the one we trained before 15 | conditioning_key: crossattn 16 | monitor: val/loss_simple_ema 17 | scale_factor: 0.18215 18 | use_ema: False 19 | 20 | scheduler_config: # 10000 warmup steps 21 | target: ldm.lr_scheduler.LambdaLinearScheduler 22 | params: 23 | warm_up_steps: [ 10000 ] 24 | cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases 25 | f_start: [ 1.e-6 ] 26 | f_max: [ 1. ] 27 | f_min: [ 1. 
] 28 | 29 | unet_config: 30 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 31 | params: 32 | image_size: 32 # unused 33 | in_channels: 4 34 | out_channels: 4 35 | model_channels: 320 36 | attention_resolutions: [ 4, 2, 1 ] 37 | num_res_blocks: 2 38 | channel_mult: [ 1, 2, 4, 4 ] 39 | num_heads: 8 40 | use_spatial_transformer: True 41 | transformer_depth: 1 42 | context_dim: 768 43 | use_checkpoint: False 44 | legacy: False 45 | 46 | first_stage_config: 47 | target: ldm.models.autoencoder.AutoencoderKL 48 | params: 49 | embed_dim: 4 50 | monitor: val/rec_loss 51 | ddconfig: 52 | double_z: true 53 | z_channels: 4 54 | resolution: 256 55 | in_channels: 3 56 | out_ch: 3 57 | ch: 128 58 | ch_mult: 59 | - 1 60 | - 2 61 | - 4 62 | - 4 63 | num_res_blocks: 2 64 | attn_resolutions: [] 65 | dropout: 0.0 66 | lossconfig: 67 | target: torch.nn.Identity 68 | 69 | cond_stage_config: 70 | target: ldm.modules.encoders.modules.FrozenCLIPEmbedder 71 | -------------------------------------------------------------------------------- /configs/v1-inpainting-inference.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 7.5e-05 3 | target: ldm.models.diffusion.ddpm.LatentInpaintDiffusion 4 | params: 5 | linear_start: 0.00085 6 | linear_end: 0.0120 7 | num_timesteps_cond: 1 8 | log_every_t: 200 9 | timesteps: 1000 10 | first_stage_key: "jpg" 11 | cond_stage_key: "txt" 12 | image_size: 64 13 | channels: 4 14 | cond_stage_trainable: false # Note: different from the one we trained before 15 | conditioning_key: hybrid # important 16 | monitor: val/loss_simple_ema 17 | scale_factor: 0.18215 18 | finetune_keys: null 19 | 20 | scheduler_config: # 10000 warmup steps 21 | target: ldm.lr_scheduler.LambdaLinearScheduler 22 | params: 23 | warm_up_steps: [ 2500 ] # NOTE for resuming. use 10000 if starting from scratch 24 | cycle_lengths: [ 10000000000000 ] # incredibly large number to prevent corner cases 25 | f_start: [ 1.e-6 ] 26 | f_max: [ 1. ] 27 | f_min: [ 1. 
] 28 | 29 | unet_config: 30 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 31 | params: 32 | image_size: 32 # unused 33 | in_channels: 9 # 4 data + 4 downscaled image + 1 mask 34 | out_channels: 4 35 | model_channels: 320 36 | attention_resolutions: [ 4, 2, 1 ] 37 | num_res_blocks: 2 38 | channel_mult: [ 1, 2, 4, 4 ] 39 | num_heads: 8 40 | use_spatial_transformer: True 41 | transformer_depth: 1 42 | context_dim: 768 43 | use_checkpoint: False 44 | legacy: False 45 | 46 | first_stage_config: 47 | target: ldm.models.autoencoder.AutoencoderKL 48 | params: 49 | embed_dim: 4 50 | monitor: val/rec_loss 51 | ddconfig: 52 | double_z: true 53 | z_channels: 4 54 | resolution: 256 55 | in_channels: 3 56 | out_ch: 3 57 | ch: 128 58 | ch_mult: 59 | - 1 60 | - 2 61 | - 4 62 | - 4 63 | num_res_blocks: 2 64 | attn_resolutions: [] 65 | dropout: 0.0 66 | lossconfig: 67 | target: torch.nn.Identity 68 | 69 | cond_stage_config: 70 | target: ldm.modules.encoders.modules.FrozenCLIPEmbedder 71 | -------------------------------------------------------------------------------- /embeddings/Place Textual Inversion embeddings here.txt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AUTOMATIC1111/stable-diffusion-webui/82a973c04367123ae98bd9abdf80d9eda9b910e2/embeddings/Place Textual Inversion embeddings here.txt -------------------------------------------------------------------------------- /environment-wsl2.yaml: -------------------------------------------------------------------------------- 1 | name: automatic 2 | channels: 3 | - pytorch 4 | - defaults 5 | dependencies: 6 | - python=3.10 7 | - pip=23.0 8 | - cudatoolkit=11.8 9 | - pytorch=2.0 10 | - torchvision=0.15 11 | - numpy=1.23 12 | -------------------------------------------------------------------------------- /extensions-builtin/LDSR/preload.py: -------------------------------------------------------------------------------- 1 | import os 2 | from modules import paths 3 | 4 | 5 | def preload(parser): 6 | parser.add_argument("--ldsr-models-path", type=str, help="Path to directory with LDSR model file(s).", default=os.path.join(paths.models_path, 'LDSR')) 7 | -------------------------------------------------------------------------------- /extensions-builtin/Lora/extra_networks_lora.py: -------------------------------------------------------------------------------- 1 | from modules import extra_networks, shared 2 | import networks 3 | 4 | 5 | class ExtraNetworkLora(extra_networks.ExtraNetwork): 6 | def __init__(self): 7 | super().__init__('lora') 8 | 9 | self.errors = {} 10 | """mapping of network names to the number of errors the network had during operation""" 11 | 12 | remove_symbols = str.maketrans('', '', ":,") 13 | 14 | def activate(self, p, params_list): 15 | additional = shared.opts.sd_lora 16 | 17 | self.errors.clear() 18 | 19 | if additional != "None" and additional in networks.available_networks and not any(x for x in params_list if x.items[0] == additional): 20 | p.all_prompts = [x + f"<lora:{additional}:{shared.opts.extra_networks_default_multiplier}>" for x in p.all_prompts] 21 | params_list.append(extra_networks.ExtraNetworkParams(items=[additional, shared.opts.extra_networks_default_multiplier])) 22 | 23 | names = [] 24 | te_multipliers = [] 25 | unet_multipliers = [] 26 | dyn_dims = [] 27 | for params in params_list: 28 | assert params.items 29 | 30 | names.append(params.positional[0]) 31 | 32 | te_multiplier = float(params.positional[1]) if len(params.positional) > 1 else 1.0 33 | te_multiplier =
float(params.named.get("te", te_multiplier)) 34 | 35 | unet_multiplier = float(params.positional[2]) if len(params.positional) > 2 else te_multiplier 36 | unet_multiplier = float(params.named.get("unet", unet_multiplier)) 37 | 38 | dyn_dim = int(params.positional[3]) if len(params.positional) > 3 else None 39 | dyn_dim = int(params.named["dyn"]) if "dyn" in params.named else dyn_dim 40 | 41 | te_multipliers.append(te_multiplier) 42 | unet_multipliers.append(unet_multiplier) 43 | dyn_dims.append(dyn_dim) 44 | 45 | networks.load_networks(names, te_multipliers, unet_multipliers, dyn_dims) 46 | 47 | if shared.opts.lora_add_hashes_to_infotext: 48 | if not getattr(p, "is_hr_pass", False) or not hasattr(p, "lora_hashes"): 49 | p.lora_hashes = {} 50 | 51 | for item in networks.loaded_networks: 52 | if item.network_on_disk.shorthash and item.mentioned_name: 53 | p.lora_hashes[item.mentioned_name.translate(self.remove_symbols)] = item.network_on_disk.shorthash 54 | 55 | if p.lora_hashes: 56 | p.extra_generation_params["Lora hashes"] = ', '.join(f'{k}: {v}' for k, v in p.lora_hashes.items()) 57 | 58 | def deactivate(self, p): 59 | if self.errors: 60 | p.comment("Networks with errors: " + ", ".join(f"{k} ({v})" for k, v in self.errors.items())) 61 | 62 | self.errors.clear() 63 | -------------------------------------------------------------------------------- /extensions-builtin/Lora/lora.py: -------------------------------------------------------------------------------- 1 | import networks 2 | 3 | list_available_loras = networks.list_available_networks 4 | 5 | available_loras = networks.available_networks 6 | available_lora_aliases = networks.available_network_aliases 7 | available_lora_hash_lookup = networks.available_network_hash_lookup 8 | forbidden_lora_aliases = networks.forbidden_network_aliases 9 | loaded_loras = networks.loaded_networks 10 | -------------------------------------------------------------------------------- /extensions-builtin/Lora/lora_logger.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import copy 3 | import logging 4 | 5 | 6 | class ColoredFormatter(logging.Formatter): 7 | COLORS = { 8 | "DEBUG": "\033[0;36m", # CYAN 9 | "INFO": "\033[0;32m", # GREEN 10 | "WARNING": "\033[0;33m", # YELLOW 11 | "ERROR": "\033[0;31m", # RED 12 | "CRITICAL": "\033[0;37;41m", # WHITE ON RED 13 | "RESET": "\033[0m", # RESET COLOR 14 | } 15 | 16 | def format(self, record): 17 | colored_record = copy.copy(record) 18 | levelname = colored_record.levelname 19 | seq = self.COLORS.get(levelname, self.COLORS["RESET"]) 20 | colored_record.levelname = f"{seq}{levelname}{self.COLORS['RESET']}" 21 | return super().format(colored_record) 22 | 23 | 24 | logger = logging.getLogger("lora") 25 | logger.propagate = False 26 | 27 | 28 | if not logger.handlers: 29 | handler = logging.StreamHandler(sys.stdout) 30 | handler.setFormatter( 31 | ColoredFormatter("[%(name)s]-%(levelname)s: %(message)s") 32 | ) 33 | logger.addHandler(handler) 34 | -------------------------------------------------------------------------------- /extensions-builtin/Lora/lora_patches.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | import networks 4 | from modules import patches 5 | 6 | 7 | class LoraPatches: 8 | def __init__(self): 9 | self.Linear_forward = patches.patch(__name__, torch.nn.Linear, 'forward', networks.network_Linear_forward) 10 | self.Linear_load_state_dict = patches.patch(__name__, torch.nn.Linear, 
'_load_from_state_dict', networks.network_Linear_load_state_dict) 11 | self.Conv2d_forward = patches.patch(__name__, torch.nn.Conv2d, 'forward', networks.network_Conv2d_forward) 12 | self.Conv2d_load_state_dict = patches.patch(__name__, torch.nn.Conv2d, '_load_from_state_dict', networks.network_Conv2d_load_state_dict) 13 | self.GroupNorm_forward = patches.patch(__name__, torch.nn.GroupNorm, 'forward', networks.network_GroupNorm_forward) 14 | self.GroupNorm_load_state_dict = patches.patch(__name__, torch.nn.GroupNorm, '_load_from_state_dict', networks.network_GroupNorm_load_state_dict) 15 | self.LayerNorm_forward = patches.patch(__name__, torch.nn.LayerNorm, 'forward', networks.network_LayerNorm_forward) 16 | self.LayerNorm_load_state_dict = patches.patch(__name__, torch.nn.LayerNorm, '_load_from_state_dict', networks.network_LayerNorm_load_state_dict) 17 | self.MultiheadAttention_forward = patches.patch(__name__, torch.nn.MultiheadAttention, 'forward', networks.network_MultiheadAttention_forward) 18 | self.MultiheadAttention_load_state_dict = patches.patch(__name__, torch.nn.MultiheadAttention, '_load_from_state_dict', networks.network_MultiheadAttention_load_state_dict) 19 | 20 | def undo(self): 21 | self.Linear_forward = patches.undo(__name__, torch.nn.Linear, 'forward') 22 | self.Linear_load_state_dict = patches.undo(__name__, torch.nn.Linear, '_load_from_state_dict') 23 | self.Conv2d_forward = patches.undo(__name__, torch.nn.Conv2d, 'forward') 24 | self.Conv2d_load_state_dict = patches.undo(__name__, torch.nn.Conv2d, '_load_from_state_dict') 25 | self.GroupNorm_forward = patches.undo(__name__, torch.nn.GroupNorm, 'forward') 26 | self.GroupNorm_load_state_dict = patches.undo(__name__, torch.nn.GroupNorm, '_load_from_state_dict') 27 | self.LayerNorm_forward = patches.undo(__name__, torch.nn.LayerNorm, 'forward') 28 | self.LayerNorm_load_state_dict = patches.undo(__name__, torch.nn.LayerNorm, '_load_from_state_dict') 29 | self.MultiheadAttention_forward = patches.undo(__name__, torch.nn.MultiheadAttention, 'forward') 30 | self.MultiheadAttention_load_state_dict = patches.undo(__name__, torch.nn.MultiheadAttention, '_load_from_state_dict') 31 | 32 | -------------------------------------------------------------------------------- /extensions-builtin/Lora/lyco_helpers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def make_weight_cp(t, wa, wb): 5 | temp = torch.einsum('i j k l, j r -> i r k l', t, wb) 6 | return torch.einsum('i j k l, i r -> r j k l', temp, wa) 7 | 8 | 9 | def rebuild_conventional(up, down, shape, dyn_dim=None): 10 | up = up.reshape(up.size(0), -1) 11 | down = down.reshape(down.size(0), -1) 12 | if dyn_dim is not None: 13 | up = up[:, :dyn_dim] 14 | down = down[:dyn_dim, :] 15 | return (up @ down).reshape(shape) 16 | 17 | 18 | def rebuild_cp_decomposition(up, down, mid): 19 | up = up.reshape(up.size(0), -1) 20 | down = down.reshape(down.size(0), -1) 21 | return torch.einsum('n m k l, i n, m j -> i j k l', mid, up, down) 22 | 23 | 24 | # copied from https://github.com/KohakuBlueleaf/LyCORIS/blob/dev/lycoris/modules/lokr.py 25 | def factorization(dimension: int, factor:int=-1) -> tuple[int, int]: 26 | ''' 27 | Return a tuple of two factors of the input dimension, decomposed by the number closest to factor. 28 | The second value is greater than or equal to the first. 29 | 30 | In LoRA with Kronecker product, the first value is used for the weight scale. 31 | The second value is used for the weight.
32 | 33 | Because the Kronecker product is non-commutative (A⊗B ≠ B⊗A), the two matrices play slightly different roles. 34 | 35 | Examples: 36 | factor 37 | -1 2 4 8 16 ... 38 | 127 -> 1, 127 127 -> 1, 127 127 -> 1, 127 127 -> 1, 127 127 -> 1, 127 39 | 128 -> 8, 16 128 -> 2, 64 128 -> 4, 32 128 -> 8, 16 128 -> 8, 16 40 | 250 -> 10, 25 250 -> 2, 125 250 -> 2, 125 250 -> 5, 50 250 -> 10, 25 41 | 360 -> 8, 45 360 -> 2, 180 360 -> 4, 90 360 -> 8, 45 360 -> 12, 30 42 | 512 -> 16, 32 512 -> 2, 256 512 -> 4, 128 512 -> 8, 64 512 -> 16, 32 43 | 1024 -> 32, 32 1024 -> 2, 512 1024 -> 4, 256 1024 -> 8, 128 1024 -> 16, 64 44 | ''' 45 | 46 | if factor > 0 and (dimension % factor) == 0: 47 | m = factor 48 | n = dimension // factor 49 | if m > n: 50 | n, m = m, n 51 | return m, n 52 | if factor < 0: 53 | factor = dimension 54 | m, n = 1, dimension 55 | length = m + n 56 | while m < n: 57 | new_m = m + 1 58 | while dimension % new_m != 0: 59 | new_m += 1 60 | new_n = dimension // new_m 61 | if new_m + new_n > length or new_m > factor: 62 | break 63 | else: 64 | m, n = new_m, new_n 65 | if m > n: 66 | n, m = m, n 67 | return m, n 68 | 69 | -------------------------------------------------------------------------------- /extensions-builtin/Lora/network_full.py: -------------------------------------------------------------------------------- 1 | import network 2 | 3 | 4 | class ModuleTypeFull(network.ModuleType): 5 | def create_module(self, net: network.Network, weights: network.NetworkWeights): 6 | if all(x in weights.w for x in ["diff"]): 7 | return NetworkModuleFull(net, weights) 8 | 9 | return None 10 | 11 | 12 | class NetworkModuleFull(network.NetworkModule): 13 | def __init__(self, net: network.Network, weights: network.NetworkWeights): 14 | super().__init__(net, weights) 15 | 16 | self.weight = weights.w.get("diff") 17 | self.ex_bias = weights.w.get("diff_b") 18 | 19 | def calc_updown(self, orig_weight): 20 | output_shape = self.weight.shape 21 | updown = self.weight.to(orig_weight.device) 22 | if self.ex_bias is not None: 23 | ex_bias = self.ex_bias.to(orig_weight.device) 24 | else: 25 | ex_bias = None 26 | 27 | return self.finalize_updown(updown, orig_weight, output_shape, ex_bias) 28 | -------------------------------------------------------------------------------- /extensions-builtin/Lora/network_glora.py: -------------------------------------------------------------------------------- 1 | 2 | import network 3 | 4 | class ModuleTypeGLora(network.ModuleType): 5 | def create_module(self, net: network.Network, weights: network.NetworkWeights): 6 | if all(x in weights.w for x in ["a1.weight", "a2.weight", "alpha", "b1.weight", "b2.weight"]): 7 | return NetworkModuleGLora(net, weights) 8 | 9 | return None 10 | 11 | # adapted from https://github.com/KohakuBlueleaf/LyCORIS 12 | class NetworkModuleGLora(network.NetworkModule): 13 | def __init__(self, net: network.Network, weights: network.NetworkWeights): 14 | super().__init__(net, weights) 15 | 16 | if hasattr(self.sd_module, 'weight'): 17 | self.shape = self.sd_module.weight.shape 18 | 19 | self.w1a = weights.w["a1.weight"] 20 | self.w1b = weights.w["b1.weight"] 21 | self.w2a = weights.w["a2.weight"] 22 | self.w2b = weights.w["b2.weight"] 23 | 24 | def calc_updown(self, orig_weight): 25 | w1a = self.w1a.to(orig_weight.device) 26 | w1b = self.w1b.to(orig_weight.device) 27 | w2a = self.w2a.to(orig_weight.device) 28 | w2b = self.w2b.to(orig_weight.device) 29 | 30 | output_shape = [w1a.size(0), w1b.size(1)] 31 | updown = ((w2b @ w1b) + ((orig_weight.to(dtype = w1a.dtype) @ w2a) @ w1a)) 32 | 33 | return self.finalize_updown(updown, orig_weight, output_shape) 34 |
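For reference, here is a quick usage sketch of the factorization() helper from lyco_helpers.py above. It is hypothetical driver code, not a repository file; it assumes it is run from extensions-builtin/Lora so that lyco_helpers is importable, and that torch is installed (lyco_helpers imports it).

import lyco_helpers

# factorization() always returns two divisors whose product is the original
# dimension, with the smaller factor first; LoKr uses them as the shapes of
# the two Kronecker factors.
for dim in (127, 128, 512, 1024):
    m, n = lyco_helpers.factorization(dim, factor=-1)
    assert m * n == dim and m <= n
    print(f"{dim} -> {m}, {n}")  # e.g. 128 -> 8, 16 and 1024 -> 32, 32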
-------------------------------------------------------------------------------- /extensions-builtin/Lora/network_hada.py: -------------------------------------------------------------------------------- 1 | import lyco_helpers 2 | import network 3 | 4 | 5 | class ModuleTypeHada(network.ModuleType): 6 | def create_module(self, net: network.Network, weights: network.NetworkWeights): 7 | if all(x in weights.w for x in ["hada_w1_a", "hada_w1_b", "hada_w2_a", "hada_w2_b"]): 8 | return NetworkModuleHada(net, weights) 9 | 10 | return None 11 | 12 | 13 | class NetworkModuleHada(network.NetworkModule): 14 | def __init__(self, net: network.Network, weights: network.NetworkWeights): 15 | super().__init__(net, weights) 16 | 17 | if hasattr(self.sd_module, 'weight'): 18 | self.shape = self.sd_module.weight.shape 19 | 20 | self.w1a = weights.w["hada_w1_a"] 21 | self.w1b = weights.w["hada_w1_b"] 22 | self.dim = self.w1b.shape[0] 23 | self.w2a = weights.w["hada_w2_a"] 24 | self.w2b = weights.w["hada_w2_b"] 25 | 26 | self.t1 = weights.w.get("hada_t1") 27 | self.t2 = weights.w.get("hada_t2") 28 | 29 | def calc_updown(self, orig_weight): 30 | w1a = self.w1a.to(orig_weight.device) 31 | w1b = self.w1b.to(orig_weight.device) 32 | w2a = self.w2a.to(orig_weight.device) 33 | w2b = self.w2b.to(orig_weight.device) 34 | 35 | output_shape = [w1a.size(0), w1b.size(1)] 36 | 37 | if self.t1 is not None: 38 | output_shape = [w1a.size(1), w1b.size(1)] 39 | t1 = self.t1.to(orig_weight.device) 40 | updown1 = lyco_helpers.make_weight_cp(t1, w1a, w1b) 41 | output_shape += t1.shape[2:] 42 | else: 43 | if len(w1b.shape) == 4: 44 | output_shape += w1b.shape[2:] 45 | updown1 = lyco_helpers.rebuild_conventional(w1a, w1b, output_shape) 46 | 47 | if self.t2 is not None: 48 | t2 = self.t2.to(orig_weight.device) 49 | updown2 = lyco_helpers.make_weight_cp(t2, w2a, w2b) 50 | else: 51 | updown2 = lyco_helpers.rebuild_conventional(w2a, w2b, output_shape) 52 | 53 | updown = updown1 * updown2 54 | 55 | return self.finalize_updown(updown, orig_weight, output_shape) 56 | -------------------------------------------------------------------------------- /extensions-builtin/Lora/network_ia3.py: -------------------------------------------------------------------------------- 1 | import network 2 | 3 | 4 | class ModuleTypeIa3(network.ModuleType): 5 | def create_module(self, net: network.Network, weights: network.NetworkWeights): 6 | if all(x in weights.w for x in ["weight"]): 7 | return NetworkModuleIa3(net, weights) 8 | 9 | return None 10 | 11 | 12 | class NetworkModuleIa3(network.NetworkModule): 13 | def __init__(self, net: network.Network, weights: network.NetworkWeights): 14 | super().__init__(net, weights) 15 | 16 | self.w = weights.w["weight"] 17 | self.on_input = weights.w["on_input"].item() 18 | 19 | def calc_updown(self, orig_weight): 20 | w = self.w.to(orig_weight.device) 21 | 22 | output_shape = [w.size(0), orig_weight.size(1)] 23 | if self.on_input: 24 | output_shape.reverse() 25 | else: 26 | w = w.reshape(-1, 1) 27 | 28 | updown = orig_weight * w 29 | 30 | return self.finalize_updown(updown, orig_weight, output_shape) 31 | -------------------------------------------------------------------------------- /extensions-builtin/Lora/network_lokr.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | import lyco_helpers 4 | import network 5 | 6 | 7 | class ModuleTypeLokr(network.ModuleType): 8 | def create_module(self, net: network.Network, weights: network.NetworkWeights): 9 | has_1 = 
"lokr_w1" in weights.w or ("lokr_w1_a" in weights.w and "lokr_w1_b" in weights.w) 10 | has_2 = "lokr_w2" in weights.w or ("lokr_w2_a" in weights.w and "lokr_w2_b" in weights.w) 11 | if has_1 and has_2: 12 | return NetworkModuleLokr(net, weights) 13 | 14 | return None 15 | 16 | 17 | def make_kron(orig_shape, w1, w2): 18 | if len(w2.shape) == 4: 19 | w1 = w1.unsqueeze(2).unsqueeze(2) 20 | w2 = w2.contiguous() 21 | return torch.kron(w1, w2).reshape(orig_shape) 22 | 23 | 24 | class NetworkModuleLokr(network.NetworkModule): 25 | def __init__(self, net: network.Network, weights: network.NetworkWeights): 26 | super().__init__(net, weights) 27 | 28 | self.w1 = weights.w.get("lokr_w1") 29 | self.w1a = weights.w.get("lokr_w1_a") 30 | self.w1b = weights.w.get("lokr_w1_b") 31 | self.dim = self.w1b.shape[0] if self.w1b is not None else self.dim 32 | self.w2 = weights.w.get("lokr_w2") 33 | self.w2a = weights.w.get("lokr_w2_a") 34 | self.w2b = weights.w.get("lokr_w2_b") 35 | self.dim = self.w2b.shape[0] if self.w2b is not None else self.dim 36 | self.t2 = weights.w.get("lokr_t2") 37 | 38 | def calc_updown(self, orig_weight): 39 | if self.w1 is not None: 40 | w1 = self.w1.to(orig_weight.device) 41 | else: 42 | w1a = self.w1a.to(orig_weight.device) 43 | w1b = self.w1b.to(orig_weight.device) 44 | w1 = w1a @ w1b 45 | 46 | if self.w2 is not None: 47 | w2 = self.w2.to(orig_weight.device) 48 | elif self.t2 is None: 49 | w2a = self.w2a.to(orig_weight.device) 50 | w2b = self.w2b.to(orig_weight.device) 51 | w2 = w2a @ w2b 52 | else: 53 | t2 = self.t2.to(orig_weight.device) 54 | w2a = self.w2a.to(orig_weight.device) 55 | w2b = self.w2b.to(orig_weight.device) 56 | w2 = lyco_helpers.make_weight_cp(t2, w2a, w2b) 57 | 58 | output_shape = [w1.size(0) * w2.size(0), w1.size(1) * w2.size(1)] 59 | if len(orig_weight.shape) == 4: 60 | output_shape = orig_weight.shape 61 | 62 | updown = make_kron(output_shape, w1, w2) 63 | 64 | return self.finalize_updown(updown, orig_weight, output_shape) 65 | -------------------------------------------------------------------------------- /extensions-builtin/Lora/network_norm.py: -------------------------------------------------------------------------------- 1 | import network 2 | 3 | 4 | class ModuleTypeNorm(network.ModuleType): 5 | def create_module(self, net: network.Network, weights: network.NetworkWeights): 6 | if all(x in weights.w for x in ["w_norm", "b_norm"]): 7 | return NetworkModuleNorm(net, weights) 8 | 9 | return None 10 | 11 | 12 | class NetworkModuleNorm(network.NetworkModule): 13 | def __init__(self, net: network.Network, weights: network.NetworkWeights): 14 | super().__init__(net, weights) 15 | 16 | self.w_norm = weights.w.get("w_norm") 17 | self.b_norm = weights.w.get("b_norm") 18 | 19 | def calc_updown(self, orig_weight): 20 | output_shape = self.w_norm.shape 21 | updown = self.w_norm.to(orig_weight.device) 22 | 23 | if self.b_norm is not None: 24 | ex_bias = self.b_norm.to(orig_weight.device) 25 | else: 26 | ex_bias = None 27 | 28 | return self.finalize_updown(updown, orig_weight, output_shape, ex_bias) 29 | -------------------------------------------------------------------------------- /extensions-builtin/Lora/preload.py: -------------------------------------------------------------------------------- 1 | import os 2 | from modules import paths 3 | from modules.paths_internal import normalized_filepath 4 | 5 | 6 | def preload(parser): 7 | parser.add_argument("--lora-dir", type=normalized_filepath, help="Path to directory with Lora networks.", 
default=os.path.join(paths.models_path, 'Lora')) 8 | parser.add_argument("--lyco-dir-backcompat", type=normalized_filepath, help="Path to directory with LyCORIS networks (for backwards compatibility; can also use --lyco-dir).", default=os.path.join(paths.models_path, 'LyCORIS')) 9 | -------------------------------------------------------------------------------- /extensions-builtin/ScuNET/preload.py: -------------------------------------------------------------------------------- 1 | import os 2 | from modules import paths 3 | 4 | 5 | def preload(parser): 6 | parser.add_argument("--scunet-models-path", type=str, help="Path to directory with ScuNET model file(s).", default=os.path.join(paths.models_path, 'ScuNET')) 7 | -------------------------------------------------------------------------------- /extensions-builtin/SwinIR/preload.py: -------------------------------------------------------------------------------- 1 | import os 2 | from modules import paths 3 | 4 | 5 | def preload(parser): 6 | parser.add_argument("--swinir-models-path", type=str, help="Path to directory with SwinIR model file(s).", default=os.path.join(paths.models_path, 'SwinIR')) 7 | -------------------------------------------------------------------------------- /extensions-builtin/canvas-zoom-and-pan/scripts/hotkey_config.py: -------------------------------------------------------------------------------- 1 | import gradio as gr 2 | from modules import shared 3 | 4 | shared.options_templates.update(shared.options_section(('canvas_hotkey', "Canvas Hotkeys"), { 5 | "canvas_hotkey_zoom": shared.OptionInfo("Alt", "Zoom canvas", gr.Radio, {"choices": ["Shift","Ctrl", "Alt"]}).info("If you choose 'Shift' you cannot scroll horizontally; 'Alt' can cause a little trouble in Firefox"), 6 | "canvas_hotkey_adjust": shared.OptionInfo("Ctrl", "Adjust brush size", gr.Radio, {"choices": ["Shift","Ctrl", "Alt"]}).info("If you choose 'Shift' you cannot scroll horizontally; 'Alt' can cause a little trouble in Firefox"), 7 | "canvas_hotkey_shrink_brush": shared.OptionInfo("Q", "Shrink the brush size"), 8 | "canvas_hotkey_grow_brush": shared.OptionInfo("W", "Enlarge the brush size"), 9 | "canvas_hotkey_move": shared.OptionInfo("F", "Moving the canvas").info("To work correctly in Firefox, turn off 'Automatically search the page text when typing' in the browser settings"), 10 | "canvas_hotkey_fullscreen": shared.OptionInfo("S", "Fullscreen mode: maximizes the picture so that it fits into the screen and stretches it to its full width"), 11 | "canvas_hotkey_reset": shared.OptionInfo("R", "Reset zoom and canvas position"), 12 | "canvas_hotkey_overlap": shared.OptionInfo("O", "Toggle overlap").info("Technical button, needed for testing"), 13 | "canvas_show_tooltip": shared.OptionInfo(True, "Enable tooltip on the canvas"), 14 | "canvas_auto_expand": shared.OptionInfo(True, "Automatically expands an image that does not fit completely in the canvas area, similar to manually pressing the S and R buttons"), 15 | "canvas_blur_prompt": shared.OptionInfo(False, "Take the focus off the prompt when working with a canvas"), 16 | "canvas_disabled_functions": shared.OptionInfo(["Overlap"], "Disable functions that you don't use", gr.CheckboxGroup, {"choices": ["Zoom","Adjust brush size","Hotkey enlarge brush","Hotkey shrink brush","Moving canvas","Fullscreen","Reset Zoom","Overlap"]}), 17 | })) 18 | -------------------------------------------------------------------------------- /extensions-builtin/canvas-zoom-and-pan/style.css:
-------------------------------------------------------------------------------- 1 | .canvas-tooltip-info { 2 | position: absolute; 3 | top: 10px; 4 | left: 10px; 5 | cursor: help; 6 | background-color: rgba(0, 0, 0, 0.3); 7 | width: 20px; 8 | height: 20px; 9 | border-radius: 50%; 10 | display: flex; 11 | align-items: center; 12 | justify-content: center; 13 | flex-direction: column; 14 | 15 | z-index: 100; 16 | } 17 | 18 | .canvas-tooltip-info::after { 19 | content: ''; 20 | display: block; 21 | width: 2px; 22 | height: 7px; 23 | background-color: white; 24 | margin-top: 2px; 25 | } 26 | 27 | .canvas-tooltip-info::before { 28 | content: ''; 29 | display: block; 30 | width: 2px; 31 | height: 2px; 32 | background-color: white; 33 | } 34 | 35 | .canvas-tooltip-content { 36 | display: none; 37 | background-color: #f9f9f9; 38 | color: #333; 39 | border: 1px solid #ddd; 40 | padding: 15px; 41 | position: absolute; 42 | top: 40px; 43 | left: 10px; 44 | width: 250px; 45 | font-size: 16px; 46 | opacity: 0; 47 | border-radius: 8px; 48 | box-shadow: 0px 8px 16px 0px rgba(0,0,0,0.2); 49 | 50 | z-index: 100; 51 | } 52 | 53 | .canvas-tooltip:hover .canvas-tooltip-content { 54 | display: block; 55 | animation: fadeIn 0.5s; 56 | opacity: 1; 57 | } 58 | 59 | @keyframes fadeIn { 60 | from {opacity: 0;} 61 | to {opacity: 1;} 62 | } 63 | 64 | .styler { 65 | overflow:inherit !important; 66 | } -------------------------------------------------------------------------------- /extensions-builtin/mobile/javascript/mobile.js: -------------------------------------------------------------------------------- 1 | var isSetupForMobile = false; 2 | 3 | function isMobile() { 4 | for (var tab of ["txt2img", "img2img"]) { 5 | var imageTab = gradioApp().getElementById(tab + '_results'); 6 | if (imageTab && imageTab.offsetParent && imageTab.offsetLeft == 0) { 7 | return true; 8 | } 9 | } 10 | 11 | return false; 12 | } 13 | 14 | function reportWindowSize() { 15 | if (gradioApp().querySelector('.toprow-compact-tools')) return; // not applicable for compact prompt layout 16 | 17 | var currentlyMobile = isMobile(); 18 | if (currentlyMobile == isSetupForMobile) return; 19 | isSetupForMobile = currentlyMobile; 20 | 21 | for (var tab of ["txt2img", "img2img"]) { 22 | var button = gradioApp().getElementById(tab + '_generate_box'); 23 | var target = gradioApp().getElementById(currentlyMobile ? 
tab + '_results' : tab + '_actions_column'); 24 | target.insertBefore(button, target.firstElementChild); 25 | 26 | gradioApp().getElementById(tab + '_results').classList.toggle('mobile', currentlyMobile); 27 | } 28 | } 29 | 30 | window.addEventListener("resize", reportWindowSize); 31 | 32 | onUiLoaded(function() { 33 | reportWindowSize(); 34 | }); 35 | -------------------------------------------------------------------------------- /extensions-builtin/postprocessing-for-training/scripts/postprocessing_caption.py: -------------------------------------------------------------------------------- 1 | from modules import scripts_postprocessing, ui_components, deepbooru, shared 2 | import gradio as gr 3 | 4 | 5 | class ScriptPostprocessingCaption(scripts_postprocessing.ScriptPostprocessing): 6 | name = "Caption" 7 | order = 4040 8 | 9 | def ui(self): 10 | with ui_components.InputAccordion(False, label="Caption") as enable: 11 | option = gr.CheckboxGroup(value=["Deepbooru"], choices=["Deepbooru", "BLIP"], show_label=False) 12 | 13 | return { 14 | "enable": enable, 15 | "option": option, 16 | } 17 | 18 | def process(self, pp: scripts_postprocessing.PostprocessedImage, enable, option): 19 | if not enable: 20 | return 21 | 22 | captions = [pp.caption] 23 | 24 | if "Deepbooru" in option: 25 | captions.append(deepbooru.model.tag(pp.image)) 26 | 27 | if "BLIP" in option: 28 | captions.append(shared.interrogator.interrogate(pp.image.convert("RGB"))) 29 | 30 | pp.caption = ", ".join([x for x in captions if x]) 31 | -------------------------------------------------------------------------------- /extensions-builtin/postprocessing-for-training/scripts/postprocessing_create_flipped_copies.py: -------------------------------------------------------------------------------- 1 | from PIL import ImageOps, Image 2 | 3 | from modules import scripts_postprocessing, ui_components 4 | import gradio as gr 5 | 6 | 7 | class ScriptPostprocessingCreateFlippedCopies(scripts_postprocessing.ScriptPostprocessing): 8 | name = "Create flipped copies" 9 | order = 4030 10 | 11 | def ui(self): 12 | with ui_components.InputAccordion(False, label="Create flipped copies") as enable: 13 | with gr.Row(): 14 | option = gr.CheckboxGroup(value=["Horizontal"], choices=["Horizontal", "Vertical", "Both"], show_label=False) 15 | 16 | return { 17 | "enable": enable, 18 | "option": option, 19 | } 20 | 21 | def process(self, pp: scripts_postprocessing.PostprocessedImage, enable, option): 22 | if not enable: 23 | return 24 | 25 | if "Horizontal" in option: 26 | pp.extra_images.append(ImageOps.mirror(pp.image)) 27 | 28 | if "Vertical" in option: 29 | pp.extra_images.append(pp.image.transpose(Image.Transpose.FLIP_TOP_BOTTOM)) 30 | 31 | if "Both" in option: 32 | pp.extra_images.append(pp.image.transpose(Image.Transpose.FLIP_TOP_BOTTOM).transpose(Image.Transpose.FLIP_LEFT_RIGHT)) 33 | -------------------------------------------------------------------------------- /extensions-builtin/postprocessing-for-training/scripts/postprocessing_focal_crop.py: -------------------------------------------------------------------------------- 1 | 2 | from modules import scripts_postprocessing, ui_components, errors 3 | import gradio as gr 4 | 5 | from modules.textual_inversion import autocrop 6 | 7 | 8 | class ScriptPostprocessingFocalCrop(scripts_postprocessing.ScriptPostprocessing): 9 | name = "Auto focal point crop" 10 | order = 4010 11 | 12 | def ui(self): 13 | with ui_components.InputAccordion(False, label="Auto focal point crop") as enable: 14 | face_weight
= gr.Slider(label='Focal point face weight', value=0.9, minimum=0.0, maximum=1.0, step=0.05, elem_id="postprocess_focal_crop_face_weight") 15 | entropy_weight = gr.Slider(label='Focal point entropy weight', value=0.15, minimum=0.0, maximum=1.0, step=0.05, elem_id="postprocess_focal_crop_entropy_weight") 16 | edges_weight = gr.Slider(label='Focal point edges weight', value=0.5, minimum=0.0, maximum=1.0, step=0.05, elem_id="postprocess_focal_crop_edges_weight") 17 | debug = gr.Checkbox(label='Create debug image', elem_id="train_process_focal_crop_debug") 18 | 19 | return { 20 | "enable": enable, 21 | "face_weight": face_weight, 22 | "entropy_weight": entropy_weight, 23 | "edges_weight": edges_weight, 24 | "debug": debug, 25 | } 26 | 27 | def process(self, pp: scripts_postprocessing.PostprocessedImage, enable, face_weight, entropy_weight, edges_weight, debug): 28 | if not enable: 29 | return 30 | 31 | if not pp.shared.target_width or not pp.shared.target_height: 32 | return 33 | 34 | dnn_model_path = None 35 | try: 36 | dnn_model_path = autocrop.download_and_cache_models() 37 | except Exception: 38 | errors.report("Unable to load face detection model for auto crop selection. Falling back to lower quality haar method.", exc_info=True) 39 | 40 | autocrop_settings = autocrop.Settings( 41 | crop_width=pp.shared.target_width, 42 | crop_height=pp.shared.target_height, 43 | face_points_weight=face_weight, 44 | entropy_points_weight=entropy_weight, 45 | corner_points_weight=edges_weight, 46 | annotate_image=debug, 47 | dnn_model_path=dnn_model_path, 48 | ) 49 | 50 | result, *others = autocrop.crop_image(pp.image, autocrop_settings) 51 | 52 | pp.image = result 53 | pp.extra_images = [pp.create_copy(x, nametags=["focal-crop-debug"], disable_processing=True) for x in others] 54 | 55 | -------------------------------------------------------------------------------- /extensions-builtin/postprocessing-for-training/scripts/postprocessing_split_oversized.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | from modules import scripts_postprocessing, ui_components 4 | import gradio as gr 5 | 6 | 7 | def split_pic(image, inverse_xy, width, height, overlap_ratio): 8 | if inverse_xy: 9 | from_w, from_h = image.height, image.width 10 | to_w, to_h = height, width 11 | else: 12 | from_w, from_h = image.width, image.height 13 | to_w, to_h = width, height 14 | h = from_h * to_w // from_w 15 | if inverse_xy: 16 | image = image.resize((h, to_w)) 17 | else: 18 | image = image.resize((to_w, h)) 19 | 20 | split_count = math.ceil((h - to_h * overlap_ratio) / (to_h * (1.0 - overlap_ratio))) 21 | y_step = (h - to_h) / (split_count - 1) 22 | for i in range(split_count): 23 | y = int(y_step * i) 24 | if inverse_xy: 25 | splitted = image.crop((y, 0, y + to_h, to_w)) 26 | else: 27 | splitted = image.crop((0, y, to_w, y + to_h)) 28 | yield splitted 29 | 30 | 31 | class ScriptPostprocessingSplitOversized(scripts_postprocessing.ScriptPostprocessing): 32 | name = "Split oversized images" 33 | order = 4000 34 | 35 | def ui(self): 36 | with ui_components.InputAccordion(False, label="Split oversized images") as enable: 37 | with gr.Row(): 38 | split_threshold = gr.Slider(label='Threshold', value=0.5, minimum=0.0, maximum=1.0, step=0.05, elem_id="postprocess_split_threshold") 39 | overlap_ratio = gr.Slider(label='Overlap ratio', value=0.2, minimum=0.0, maximum=0.9, step=0.05, elem_id="postprocess_overlap_ratio") 40 | 41 | return { 42 | "enable": enable, 43 | "split_threshold": 
split_threshold, 44 | "overlap_ratio": overlap_ratio, 45 | } 46 | 47 | def process(self, pp: scripts_postprocessing.PostprocessedImage, enable, split_threshold, overlap_ratio): 48 | if not enable: 49 | return 50 | 51 | width = pp.shared.target_width 52 | height = pp.shared.target_height 53 | 54 | if not width or not height: 55 | return 56 | 57 | if pp.image.height > pp.image.width: 58 | ratio = (pp.image.width * height) / (pp.image.height * width) 59 | inverse_xy = False 60 | else: 61 | ratio = (pp.image.height * width) / (pp.image.width * height) 62 | inverse_xy = True 63 | 64 | if ratio >= 1.0 or ratio > split_threshold: 65 | return 66 | 67 | result, *others = split_pic(pp.image, inverse_xy, width, height, overlap_ratio) 68 | 69 | pp.image = result 70 | pp.extra_images = [pp.create_copy(x) for x in others] 71 | 72 | -------------------------------------------------------------------------------- /extensions-builtin/prompt-bracket-checker/javascript/prompt-bracket-checker.js: -------------------------------------------------------------------------------- 1 | // Stable Diffusion WebUI - Bracket checker 2 | // By Hingashi no Florin/Bwin4L & @akx 3 | // Counts open and closed brackets (round, square, curly) in the prompt and negative prompt text boxes in the txt2img and img2img tabs. 4 | // If there's a mismatch, the keyword counter turns red and if you hover on it, a tooltip tells you what's wrong. 5 | 6 | function checkBrackets(textArea, counterElt) { 7 | var counts = {}; 8 | (textArea.value.match(/[(){}[\]]/g) || []).forEach(bracket => { 9 | counts[bracket] = (counts[bracket] || 0) + 1; 10 | }); 11 | var errors = []; 12 | 13 | function checkPair(open, close, kind) { 14 | if (counts[open] !== counts[close]) { 15 | errors.push( 16 | `${open}...${close} - Detected ${counts[open] || 0} opening and ${counts[close] || 0} closing ${kind}.` 17 | ); 18 | } 19 | } 20 | 21 | checkPair('(', ')', 'round brackets'); 22 | checkPair('[', ']', 'square brackets'); 23 | checkPair('{', '}', 'curly brackets'); 24 | counterElt.title = errors.join('\n'); 25 | counterElt.classList.toggle('error', errors.length !== 0); 26 | } 27 | 28 | function setupBracketChecking(id_prompt, id_counter) { 29 | var textarea = gradioApp().querySelector("#" + id_prompt + " > label > textarea"); 30 | var counter = gradioApp().getElementById(id_counter); 31 | 32 | if (textarea && counter) { 33 | textarea.addEventListener("input", () => checkBrackets(textarea, counter)); 34 | } 35 | } 36 | 37 | onUiLoaded(function() { 38 | setupBracketChecking('txt2img_prompt', 'txt2img_token_counter'); 39 | setupBracketChecking('txt2img_neg_prompt', 'txt2img_negative_token_counter'); 40 | setupBracketChecking('img2img_prompt', 'img2img_token_counter'); 41 | setupBracketChecking('img2img_neg_prompt', 'img2img_negative_token_counter'); 42 | }); 43 | -------------------------------------------------------------------------------- /extensions/put extensions here.txt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AUTOMATIC1111/stable-diffusion-webui/82a973c04367123ae98bd9abdf80d9eda9b910e2/extensions/put extensions here.txt -------------------------------------------------------------------------------- /html/card-no-preview.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AUTOMATIC1111/stable-diffusion-webui/82a973c04367123ae98bd9abdf80d9eda9b910e2/html/card-no-preview.png 
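# ---------------------------------------------------------------------------
# Editor's sketch (not a file from the repository): the crop arithmetic used by
# split_pic() in postprocessing_split_oversized.py above, pulled out so it can
# be run standalone. It assumes h is enough larger than to_h for split_count to
# reach 2 (the caller's split_threshold check guarantees this in practice),
# since a single crop would make the y_step division degenerate. plan_splits is
# a made-up name.
import math

def plan_splits(h, to_h, overlap_ratio):
    # number of crops so that consecutive crops share ~overlap_ratio of to_h
    split_count = math.ceil((h - to_h * overlap_ratio) / (to_h * (1.0 - overlap_ratio)))
    # spread the crops evenly: the first starts at 0, the last ends flush with h
    y_step = (h - to_h) / (split_count - 1)
    return [int(y_step * i) for i in range(split_count)]

# e.g. a 1280px-tall strip cut into 512px crops with 20% overlap:
print(plan_splits(1280, 512, 0.2))  # -> [0, 384, 768]
# ---------------------------------------------------------------------------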
-------------------------------------------------------------------------------- /html/extra-networks-card.html: -------------------------------------------------------------------------------- 1 |
2 | {background_image} 3 |
{copy_path_button}{metadata_button}{edit_button}
4 |
5 |
{search_terms}
6 | {name} 7 | {description} 8 |
9 |
10 | -------------------------------------------------------------------------------- /html/extra-networks-copy-path-button.html: -------------------------------------------------------------------------------- 1 |
5 |
-------------------------------------------------------------------------------- /html/extra-networks-edit-item-button.html: -------------------------------------------------------------------------------- 1 |
4 |
-------------------------------------------------------------------------------- /html/extra-networks-metadata-button.html: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /html/extra-networks-no-cards.html: -------------------------------------------------------------------------------- 1 |
2 |

Nothing here. Add some content to the following directories:

3 | 4 |
    5 | {dirs} 6 |
7 |
8 | 9 | -------------------------------------------------------------------------------- /html/extra-networks-pane-dirs.html: -------------------------------------------------------------------------------- 1 |
2 |
3 | {dirs_html} 4 |
5 |
6 | {items_html} 7 |
8 |
9 | -------------------------------------------------------------------------------- /html/extra-networks-pane-tree.html: -------------------------------------------------------------------------------- 1 |
2 |
3 | {tree_html} 4 |
5 |
6 | {items_html} 7 |
8 |
-------------------------------------------------------------------------------- /html/extra-networks-tree-button.html: -------------------------------------------------------------------------------- 1 | 2 |
8 | 9 | {action_list_item_action_leading} 10 | 11 | 12 | {action_list_item_visual_leading} 13 | 14 | 15 | {action_list_item_label} 16 | 17 | 18 | {action_list_item_visual_trailing} 19 | 20 | 21 | {action_list_item_action_trailing} 22 | 23 |
-------------------------------------------------------------------------------- /html/footer.html: -------------------------------------------------------------------------------- 1 |
2 | API 3 |  •  4 | Github 5 |  •  6 | Gradio 7 |  •  8 | Startup profile 9 |  •  10 | Reload UI 11 |
12 |
13 |
14 | {versions} 15 |
16 | -------------------------------------------------------------------------------- /javascript/edit-order.js: -------------------------------------------------------------------------------- 1 | /* alt+left/right moves text in prompt */ 2 | 3 | function keyupEditOrder(event) { 4 | if (!opts.keyedit_move) return; 5 | 6 | let target = event.originalTarget || event.composedPath()[0]; 7 | if (!target.matches("*:is([id*='_toprow'] [id*='_prompt'], .prompt) textarea")) return; 8 | if (!event.altKey) return; 9 | 10 | let isLeft = event.key == "ArrowLeft"; 11 | let isRight = event.key == "ArrowRight"; 12 | if (!isLeft && !isRight) return; 13 | event.preventDefault(); 14 | 15 | let selectionStart = target.selectionStart; 16 | let selectionEnd = target.selectionEnd; 17 | let text = target.value; 18 | let items = text.split(","); 19 | let indexStart = (text.slice(0, selectionStart).match(/,/g) || []).length; 20 | let indexEnd = (text.slice(0, selectionEnd).match(/,/g) || []).length; 21 | let range = indexEnd - indexStart + 1; 22 | 23 | if (isLeft && indexStart > 0) { 24 | items.splice(indexStart - 1, 0, ...items.splice(indexStart, range)); 25 | target.value = items.join(); 26 | target.selectionStart = items.slice(0, indexStart - 1).join().length + (indexStart == 1 ? 0 : 1); 27 | target.selectionEnd = items.slice(0, indexEnd).join().length; 28 | } else if (isRight && indexEnd < items.length - 1) { 29 | items.splice(indexStart + 1, 0, ...items.splice(indexStart, range)); 30 | target.value = items.join(); 31 | target.selectionStart = items.slice(0, indexStart + 1).join().length + 1; 32 | target.selectionEnd = items.slice(0, indexEnd + 2).join().length; 33 | } 34 | 35 | event.preventDefault(); 36 | updateInput(target); 37 | } 38 | 39 | addEventListener('keydown', (event) => { 40 | keyupEditOrder(event); 41 | }); 42 | -------------------------------------------------------------------------------- /javascript/generationParams.js: -------------------------------------------------------------------------------- 1 | // attaches listeners to the txt2img and img2img galleries to update displayed generation param text when the image changes 2 | 3 | let txt2img_gallery, img2img_gallery, modal = undefined; 4 | onAfterUiUpdate(function() { 5 | if (!txt2img_gallery) { 6 | txt2img_gallery = attachGalleryListeners("txt2img"); 7 | } 8 | if (!img2img_gallery) { 9 | img2img_gallery = attachGalleryListeners("img2img"); 10 | } 11 | if (!modal) { 12 | modal = gradioApp().getElementById('lightboxModal'); 13 | modalObserver.observe(modal, {attributes: true, attributeFilter: ['style']}); 14 | } 15 | }); 16 | 17 | let modalObserver = new MutationObserver(function(mutations) { 18 | mutations.forEach(function(mutationRecord) { 19 | let selectedTab = gradioApp().querySelector('#tabs div button.selected')?.innerText; 20 | if (mutationRecord.target.style.display === 'none' && (selectedTab === 'txt2img' || selectedTab === 'img2img')) { 21 | gradioApp().getElementById(selectedTab + "_generation_info_button")?.click(); 22 | } 23 | }); 24 | }); 25 | 26 | function attachGalleryListeners(tab_name) { 27 | var gallery = gradioApp().querySelector('#' + tab_name + '_gallery'); 28 | gallery?.addEventListener('click', () => gradioApp().getElementById(tab_name + "_generation_info_button").click()); 29 | gallery?.addEventListener('keydown', (e) => { 30 | if (e.keyCode == 37 || e.keyCode == 39) { // left or right arrow 31 | gradioApp().getElementById(tab_name + "_generation_info_button").click(); 32 | } 33 | }); 34 | return gallery; 35 | } 36 
| -------------------------------------------------------------------------------- /javascript/hires_fix.js: -------------------------------------------------------------------------------- 1 | 2 | function onCalcResolutionHires(enable, width, height, hr_scale, hr_resize_x, hr_resize_y) { 3 | function setInactive(elem, inactive) { 4 | elem.classList.toggle('inactive', !!inactive); 5 | } 6 | 7 | var hrUpscaleBy = gradioApp().getElementById('txt2img_hr_scale'); 8 | var hrResizeX = gradioApp().getElementById('txt2img_hr_resize_x'); 9 | var hrResizeY = gradioApp().getElementById('txt2img_hr_resize_y'); 10 | 11 | gradioApp().getElementById('txt2img_hires_fix_row2').style.display = opts.use_old_hires_fix_width_height ? "none" : ""; 12 | 13 | setInactive(hrUpscaleBy, opts.use_old_hires_fix_width_height || hr_resize_x > 0 || hr_resize_y > 0); 14 | setInactive(hrResizeX, opts.use_old_hires_fix_width_height || hr_resize_x == 0); 15 | setInactive(hrResizeY, opts.use_old_hires_fix_width_height || hr_resize_y == 0); 16 | 17 | return [enable, width, height, hr_scale, hr_resize_x, hr_resize_y]; 18 | } 19 | -------------------------------------------------------------------------------- /javascript/imageMaskFix.js: -------------------------------------------------------------------------------- 1 | /** 2 | * temporary fix for https://github.com/AUTOMATIC1111/stable-diffusion-webui/issues/668 3 | * @see https://github.com/gradio-app/gradio/issues/1721 4 | */ 5 | function imageMaskResize() { 6 | const canvases = gradioApp().querySelectorAll('#img2maskimg .touch-none canvas'); 7 | if (!canvases.length) { 8 | window.removeEventListener('resize', imageMaskResize); 9 | return; 10 | } 11 | 12 | const wrapper = canvases[0].closest('.touch-none'); 13 | const previewImage = wrapper.previousElementSibling; 14 | 15 | if (!previewImage.complete) { 16 | previewImage.addEventListener('load', imageMaskResize); 17 | return; 18 | } 19 | 20 | const w = previewImage.width; 21 | const h = previewImage.height; 22 | const nw = previewImage.naturalWidth; 23 | const nh = previewImage.naturalHeight; 24 | const portrait = nh > nw; 25 | 26 | const wW = Math.min(w, portrait ? h / nh * nw : w / nw * nw); 27 | const wH = Math.min(h, portrait ? 
h / nh * nh : w / nw * nh); 28 | 29 | wrapper.style.width = `${wW}px`; 30 | wrapper.style.height = `${wH}px`; 31 | wrapper.style.left = `0px`; 32 | wrapper.style.top = `0px`; 33 | 34 | canvases.forEach(c => { 35 | c.style.width = c.style.height = ''; 36 | c.style.maxWidth = '100%'; 37 | c.style.maxHeight = '100%'; 38 | c.style.objectFit = 'contain'; 39 | }); 40 | } 41 | 42 | onAfterUiUpdate(imageMaskResize); 43 | window.addEventListener('resize', imageMaskResize); 44 | -------------------------------------------------------------------------------- /javascript/imageviewerGamepad.js: -------------------------------------------------------------------------------- 1 | let gamepads = []; 2 | 3 | window.addEventListener('gamepadconnected', (e) => { 4 | const index = e.gamepad.index; 5 | let isWaiting = false; 6 | gamepads[index] = setInterval(async() => { 7 | if (!opts.js_modal_lightbox_gamepad || isWaiting) return; 8 | const gamepad = navigator.getGamepads()[index]; 9 | const xValue = gamepad.axes[0]; 10 | if (xValue <= -0.3) { 11 | modalPrevImage(e); 12 | isWaiting = true; 13 | } else if (xValue >= 0.3) { 14 | modalNextImage(e); 15 | isWaiting = true; 16 | } 17 | if (isWaiting) { 18 | await sleepUntil(() => { 19 | const xValue = navigator.getGamepads()[index].axes[0]; 20 | if (xValue < 0.3 && xValue > -0.3) { 21 | return true; 22 | } 23 | }, opts.js_modal_lightbox_gamepad_repeat); 24 | isWaiting = false; 25 | } 26 | }, 10); 27 | }); 28 | 29 | window.addEventListener('gamepaddisconnected', (e) => { 30 | clearInterval(gamepads[e.gamepad.index]); 31 | }); 32 | 33 | /* 34 | Primarily for vr controller type pointer devices. 35 | I use the wheel event because there's currently no way to do it properly with web xr. 36 | */ 37 | let isScrolling = false; 38 | window.addEventListener('wheel', (e) => { 39 | if (!opts.js_modal_lightbox_gamepad || isScrolling) return; 40 | isScrolling = true; 41 | 42 | if (e.deltaX <= -0.6) { 43 | modalPrevImage(e); 44 | } else if (e.deltaX >= 0.6) { 45 | modalNextImage(e); 46 | } 47 | 48 | setTimeout(() => { 49 | isScrolling = false; 50 | }, opts.js_modal_lightbox_gamepad_repeat); 51 | }); 52 | 53 | function sleepUntil(f, timeout) { 54 | return new Promise((resolve) => { 55 | const timeStart = new Date(); 56 | const wait = setInterval(function() { 57 | if (f() || new Date() - timeStart > timeout) { 58 | clearInterval(wait); 59 | resolve(); 60 | } 61 | }, 20); 62 | }); 63 | } 64 | -------------------------------------------------------------------------------- /javascript/inputAccordion.js: -------------------------------------------------------------------------------- 1 | function inputAccordionChecked(id, checked) { 2 | var accordion = gradioApp().getElementById(id); 3 | accordion.visibleCheckbox.checked = checked; 4 | accordion.onVisibleCheckboxChange(); 5 | } 6 | 7 | function setupAccordion(accordion) { 8 | var labelWrap = accordion.querySelector('.label-wrap'); 9 | var gradioCheckbox = gradioApp().querySelector('#' + accordion.id + "-checkbox input"); 10 | var extra = gradioApp().querySelector('#' + accordion.id + "-extra"); 11 | var span = labelWrap.querySelector('span'); 12 | var linked = true; 13 | 14 | var isOpen = function() { 15 | return labelWrap.classList.contains('open'); 16 | }; 17 | 18 | var observerAccordionOpen = new MutationObserver(function(mutations) { 19 | mutations.forEach(function(mutationRecord) { 20 | accordion.classList.toggle('input-accordion-open', isOpen()); 21 | 22 | if (linked) { 23 | accordion.visibleCheckbox.checked = isOpen(); 24 | 
accordion.onVisibleCheckboxChange(); 25 | } 26 | }); 27 | }); 28 | observerAccordionOpen.observe(labelWrap, {attributes: true, attributeFilter: ['class']}); 29 | 30 | if (extra) { 31 | labelWrap.insertBefore(extra, labelWrap.lastElementChild); 32 | } 33 | 34 | accordion.onChecked = function(checked) { 35 | if (isOpen() != checked) { 36 | labelWrap.click(); 37 | } 38 | }; 39 | 40 | var visibleCheckbox = document.createElement('INPUT'); 41 | visibleCheckbox.type = 'checkbox'; 42 | visibleCheckbox.checked = isOpen(); 43 | visibleCheckbox.id = accordion.id + "-visible-checkbox"; 44 | visibleCheckbox.className = gradioCheckbox.className + " input-accordion-checkbox"; 45 | span.insertBefore(visibleCheckbox, span.firstChild); 46 | 47 | accordion.visibleCheckbox = visibleCheckbox; 48 | accordion.onVisibleCheckboxChange = function() { 49 | if (linked && isOpen() != visibleCheckbox.checked) { 50 | labelWrap.click(); 51 | } 52 | 53 | gradioCheckbox.checked = visibleCheckbox.checked; 54 | updateInput(gradioCheckbox); 55 | }; 56 | 57 | visibleCheckbox.addEventListener('click', function(event) { 58 | linked = false; 59 | event.stopPropagation(); 60 | }); 61 | visibleCheckbox.addEventListener('input', accordion.onVisibleCheckboxChange); 62 | } 63 | 64 | onUiLoaded(function() { 65 | for (var accordion of gradioApp().querySelectorAll('.input-accordion')) { 66 | setupAccordion(accordion); 67 | } 68 | }); 69 | -------------------------------------------------------------------------------- /javascript/localStorage.js: -------------------------------------------------------------------------------- 1 | 2 | function localSet(k, v) { 3 | try { 4 | localStorage.setItem(k, v); 5 | } catch (e) { 6 | console.warn(`Failed to save ${k} to localStorage: ${e}`); 7 | } 8 | } 9 | 10 | function localGet(k, def) { 11 | try { 12 | return localStorage.getItem(k); 13 | } catch (e) { 14 | console.warn(`Failed to load ${k} from localStorage: ${e}`); 15 | } 16 | 17 | return def; 18 | } 19 | 20 | function localRemove(k) { 21 | try { 22 | return localStorage.removeItem(k); 23 | } catch (e) { 24 | console.warn(`Failed to remove ${k} from localStorage: ${e}`); 25 | } 26 | } 27 | -------------------------------------------------------------------------------- /javascript/notification.js: -------------------------------------------------------------------------------- 1 | // Monitors the gallery and sends a browser notification when the leading image is new. 
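// Editor's summary of the flow below: the "request notifications" button asks the
// browser for permission once; afterwards, whenever the first gallery thumbnail
// changes, the script plays the optional notification sound and, if the tab is
// not focused, raises a desktop Notification. Thumbnail URLs are deduplicated
// through a Set because the selected image appears more than once in the DOM.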
2 | 3 | let lastHeadImg = null; 4 | 5 | let notificationButton = null; 6 | 7 | onAfterUiUpdate(function() { 8 | if (notificationButton == null) { 9 | notificationButton = gradioApp().getElementById('request_notifications'); 10 | 11 | if (notificationButton != null) { 12 | notificationButton.addEventListener('click', () => { 13 | void Notification.requestPermission(); 14 | }, true); 15 | } 16 | } 17 | 18 | const galleryPreviews = gradioApp().querySelectorAll('div[id^="tab_"] div[id$="_results"] .thumbnail-item > img'); 19 | 20 | if (galleryPreviews == null) return; 21 | 22 | const headImg = galleryPreviews[0]?.src; 23 | 24 | if (headImg == null || headImg == lastHeadImg) return; 25 | 26 | lastHeadImg = headImg; 27 | 28 | // play notification sound if available 29 | const notificationAudio = gradioApp().querySelector('#audio_notification audio'); 30 | if (notificationAudio) { 31 | notificationAudio.volume = opts.notification_volume / 100.0 || 1.0; 32 | notificationAudio.play(); 33 | } 34 | 35 | if (document.hasFocus()) return; 36 | 37 | // Multiple copies of the images are in the DOM when one is selected. Dedup with a Set to get the real number generated. 38 | const imgs = new Set(Array.from(galleryPreviews).map(img => img.src)); 39 | 40 | const notification = new Notification( 41 | 'Stable Diffusion', 42 | { 43 | body: `Generated ${imgs.size > 1 ? imgs.size - opts.return_grid : 1} image${imgs.size > 1 ? 's' : ''}`, 44 | icon: headImg, 45 | image: headImg, 46 | } 47 | ); 48 | 49 | notification.onclick = function(_) { 50 | parent.focus(); 51 | this.close(); 52 | }; 53 | }); 54 | -------------------------------------------------------------------------------- /javascript/settings.js: -------------------------------------------------------------------------------- 1 | let settingsExcludeTabsFromShowAll = { 2 | settings_tab_defaults: 1, 3 | settings_tab_sysinfo: 1, 4 | settings_tab_actions: 1, 5 | settings_tab_licenses: 1, 6 | }; 7 | 8 | function settingsShowAllTabs() { 9 | gradioApp().querySelectorAll('#settings > div').forEach(function(elem) { 10 | if (settingsExcludeTabsFromShowAll[elem.id]) return; 11 | 12 | elem.style.display = "block"; 13 | }); 14 | } 15 | 16 | function settingsShowOneTab() { 17 | gradioApp().querySelector('#settings_show_one_page').click(); 18 | } 19 | 20 | onUiLoaded(function() { 21 | var edit = gradioApp().querySelector('#settings_search'); 22 | var editTextarea = gradioApp().querySelector('#settings_search > label > input'); 23 | var buttonShowAllPages = gradioApp().getElementById('settings_show_all_pages'); 24 | var settings_tabs = gradioApp().querySelector('#settings div'); 25 | 26 | onEdit('settingsSearch', editTextarea, 250, function() { 27 | var searchText = (editTextarea.value || "").trim().toLowerCase(); 28 | 29 | gradioApp().querySelectorAll('#settings > div[id^=settings_] div[id^=column_settings_] > *').forEach(function(elem) { 30 | var visible = elem.textContent.trim().toLowerCase().indexOf(searchText) != -1; 31 | elem.style.display = visible ? 
"" : "none"; 32 | }); 33 | 34 | if (searchText != "") { 35 | settingsShowAllTabs(); 36 | } else { 37 | settingsShowOneTab(); 38 | } 39 | }); 40 | 41 | settings_tabs.insertBefore(edit, settings_tabs.firstChild); 42 | settings_tabs.appendChild(buttonShowAllPages); 43 | 44 | 45 | buttonShowAllPages.addEventListener("click", settingsShowAllTabs); 46 | }); 47 | 48 | 49 | onOptionsChanged(function() { 50 | if (gradioApp().querySelector('#settings .settings-category')) return; 51 | 52 | var sectionMap = {}; 53 | gradioApp().querySelectorAll('#settings > div > button').forEach(function(x) { 54 | sectionMap[x.textContent.trim()] = x; 55 | }); 56 | 57 | opts._categories.forEach(function(x) { 58 | var section = localization[x[0]] ?? x[0]; 59 | var category = localization[x[1]] ?? x[1]; 60 | 61 | var span = document.createElement('SPAN'); 62 | span.textContent = category; 63 | span.className = 'settings-category'; 64 | 65 | var sectionElem = sectionMap[section]; 66 | if (!sectionElem) return; 67 | 68 | sectionElem.parentElement.insertBefore(span, sectionElem); 69 | }); 70 | }); 71 | 72 | -------------------------------------------------------------------------------- /javascript/textualInversion.js: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | function start_training_textual_inversion() { 5 | gradioApp().querySelector('#ti_error').innerHTML = ''; 6 | 7 | var id = randomId(); 8 | requestProgress(id, gradioApp().getElementById('ti_output'), gradioApp().getElementById('ti_gallery'), function() {}, function(progress) { 9 | gradioApp().getElementById('ti_progress').innerHTML = progress.textinfo; 10 | }); 11 | 12 | var res = Array.from(arguments); 13 | 14 | res[0] = id; 15 | 16 | return res; 17 | } 18 | -------------------------------------------------------------------------------- /javascript/token-counters.js: -------------------------------------------------------------------------------- 1 | let promptTokenCountUpdateFunctions = {}; 2 | 3 | function update_txt2img_tokens(...args) { 4 | // Called from Gradio 5 | update_token_counter("txt2img_token_button"); 6 | update_token_counter("txt2img_negative_token_button"); 7 | if (args.length == 2) { 8 | return args[0]; 9 | } 10 | return args; 11 | } 12 | 13 | function update_img2img_tokens(...args) { 14 | // Called from Gradio 15 | update_token_counter("img2img_token_button"); 16 | update_token_counter("img2img_negative_token_button"); 17 | if (args.length == 2) { 18 | return args[0]; 19 | } 20 | return args; 21 | } 22 | 23 | function update_token_counter(button_id) { 24 | promptTokenCountUpdateFunctions[button_id]?.(); 25 | } 26 | 27 | 28 | function recalculatePromptTokens(name) { 29 | promptTokenCountUpdateFunctions[name]?.(); 30 | } 31 | 32 | function recalculate_prompts_txt2img() { 33 | // Called from Gradio 34 | recalculatePromptTokens('txt2img_prompt'); 35 | recalculatePromptTokens('txt2img_neg_prompt'); 36 | return Array.from(arguments); 37 | } 38 | 39 | function recalculate_prompts_img2img() { 40 | // Called from Gradio 41 | recalculatePromptTokens('img2img_prompt'); 42 | recalculatePromptTokens('img2img_neg_prompt'); 43 | return Array.from(arguments); 44 | } 45 | 46 | function setupTokenCounting(id, id_counter, id_button) { 47 | var prompt = gradioApp().getElementById(id); 48 | var counter = gradioApp().getElementById(id_counter); 49 | var textarea = gradioApp().querySelector(`#${id} > label > textarea`); 50 | 51 | if (counter.parentElement == prompt.parentElement) { 52 | return; 53 | } 54 | 55 | 
prompt.parentElement.insertBefore(counter, prompt); 56 | prompt.parentElement.style.position = "relative"; 57 | 58 | var func = onEdit(id, textarea, 800, function() { 59 | if (counter.classList.contains("token-counter-visible")) { 60 | gradioApp().getElementById(id_button)?.click(); 61 | } 62 | }); 63 | promptTokenCountUpdateFunctions[id] = func; 64 | promptTokenCountUpdateFunctions[id_button] = func; 65 | } 66 | 67 | function toggleTokenCountingVisibility(id, id_counter, id_button) { 68 | var counter = gradioApp().getElementById(id_counter); 69 | 70 | counter.style.display = opts.disable_token_counters ? "none" : "block"; 71 | counter.classList.toggle("token-counter-visible", !opts.disable_token_counters); 72 | } 73 | 74 | function runCodeForTokenCounters(fun) { 75 | fun('txt2img_prompt', 'txt2img_token_counter', 'txt2img_token_button'); 76 | fun('txt2img_neg_prompt', 'txt2img_negative_token_counter', 'txt2img_negative_token_button'); 77 | fun('img2img_prompt', 'img2img_token_counter', 'img2img_token_button'); 78 | fun('img2img_neg_prompt', 'img2img_negative_token_counter', 'img2img_negative_token_button'); 79 | } 80 | 81 | onUiLoaded(function() { 82 | runCodeForTokenCounters(setupTokenCounting); 83 | }); 84 | 85 | onOptionsChanged(function() { 86 | runCodeForTokenCounters(toggleTokenCountingVisibility); 87 | }); 88 | -------------------------------------------------------------------------------- /javascript/ui_settings_hints.js: -------------------------------------------------------------------------------- 1 | // various hints and extra info for the settings tab 2 | 3 | var settingsHintsSetup = false; 4 | 5 | onOptionsChanged(function() { 6 | if (settingsHintsSetup) return; 7 | settingsHintsSetup = true; 8 | 9 | gradioApp().querySelectorAll('#settings [id^=setting_]').forEach(function(div) { 10 | var name = div.id.substr(8); 11 | var commentBefore = opts._comments_before[name]; 12 | var commentAfter = opts._comments_after[name]; 13 | 14 | if (!commentBefore && !commentAfter) return; 15 | 16 | var span = null; 17 | if (div.classList.contains('gradio-checkbox')) span = div.querySelector('label span'); 18 | else if (div.classList.contains('gradio-checkboxgroup')) span = div.querySelector('span').firstChild; 19 | else if (div.classList.contains('gradio-radio')) span = div.querySelector('span').firstChild; 20 | else span = div.querySelector('label span').firstChild; 21 | 22 | if (!span) return; 23 | 24 | if (commentBefore) { 25 | var comment = document.createElement('DIV'); 26 | comment.className = 'settings-comment'; 27 | comment.innerHTML = commentBefore; 28 | span.parentElement.insertBefore(document.createTextNode('\xa0'), span); 29 | span.parentElement.insertBefore(comment, span); 30 | span.parentElement.insertBefore(document.createTextNode('\xa0'), span); 31 | } 32 | if (commentAfter) { 33 | comment = document.createElement('DIV'); 34 | comment.className = 'settings-comment'; 35 | comment.innerHTML = commentAfter; 36 | span.parentElement.insertBefore(comment, span.nextSibling); 37 | span.parentElement.insertBefore(document.createTextNode('\xa0'), span.nextSibling); 38 | } 39 | }); 40 | }); 41 | 42 | function settingsHintsShowQuicksettings() { 43 | requestGet("./internal/quicksettings-hint", {}, function(data) { 44 | var table = document.createElement('table'); 45 | table.className = 'popup-table'; 46 | 47 | data.forEach(function(obj) { 48 | var tr = document.createElement('tr'); 49 | var td = document.createElement('td'); 50 | td.textContent = obj.name; 51 | tr.appendChild(td); 52 | 
53 | td = document.createElement('td'); 54 | td.textContent = obj.label; 55 | tr.appendChild(td); 56 | 57 | table.appendChild(tr); 58 | }); 59 | 60 | popup(table); 61 | }); 62 | } 63 | -------------------------------------------------------------------------------- /launch.py: -------------------------------------------------------------------------------- 1 | from modules import launch_utils 2 | 3 | args = launch_utils.args 4 | python = launch_utils.python 5 | git = launch_utils.git 6 | index_url = launch_utils.index_url 7 | dir_repos = launch_utils.dir_repos 8 | 9 | commit_hash = launch_utils.commit_hash 10 | git_tag = launch_utils.git_tag 11 | 12 | run = launch_utils.run 13 | is_installed = launch_utils.is_installed 14 | repo_dir = launch_utils.repo_dir 15 | 16 | run_pip = launch_utils.run_pip 17 | check_run_python = launch_utils.check_run_python 18 | git_clone = launch_utils.git_clone 19 | git_pull_recursive = launch_utils.git_pull_recursive 20 | list_extensions = launch_utils.list_extensions 21 | run_extension_installer = launch_utils.run_extension_installer 22 | prepare_environment = launch_utils.prepare_environment 23 | configure_for_tests = launch_utils.configure_for_tests 24 | start = launch_utils.start 25 | 26 | 27 | def main(): 28 | if args.dump_sysinfo: 29 | filename = launch_utils.dump_sysinfo() 30 | 31 | print(f"Sysinfo saved as {filename}. Exiting...") 32 | 33 | exit(0) 34 | 35 | launch_utils.startup_timer.record("initial startup") 36 | 37 | with launch_utils.startup_timer.subcategory("prepare environment"): 38 | if not args.skip_prepare_environment: 39 | prepare_environment() 40 | 41 | if args.test_server: 42 | configure_for_tests() 43 | 44 | start() 45 | 46 | 47 | if __name__ == "__main__": 48 | main() 49 | -------------------------------------------------------------------------------- /localizations/Put localization files here.txt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AUTOMATIC1111/stable-diffusion-webui/82a973c04367123ae98bd9abdf80d9eda9b910e2/localizations/Put localization files here.txt -------------------------------------------------------------------------------- /models/Stable-diffusion/Put Stable Diffusion checkpoints here.txt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AUTOMATIC1111/stable-diffusion-webui/82a973c04367123ae98bd9abdf80d9eda9b910e2/models/Stable-diffusion/Put Stable Diffusion checkpoints here.txt -------------------------------------------------------------------------------- /models/VAE-approx/model.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AUTOMATIC1111/stable-diffusion-webui/82a973c04367123ae98bd9abdf80d9eda9b910e2/models/VAE-approx/model.pt -------------------------------------------------------------------------------- /models/VAE/Put VAE here.txt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AUTOMATIC1111/stable-diffusion-webui/82a973c04367123ae98bd9abdf80d9eda9b910e2/models/VAE/Put VAE here.txt -------------------------------------------------------------------------------- /models/deepbooru/Put your deepbooru release project folder here.txt: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/AUTOMATIC1111/stable-diffusion-webui/82a973c04367123ae98bd9abdf80d9eda9b910e2/models/deepbooru/Put your deepbooru release project folder here.txt -------------------------------------------------------------------------------- /models/karlo/ViT-L-14_stats.th: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AUTOMATIC1111/stable-diffusion-webui/82a973c04367123ae98bd9abdf80d9eda9b910e2/models/karlo/ViT-L-14_stats.th -------------------------------------------------------------------------------- /modules/Roboto-Regular.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AUTOMATIC1111/stable-diffusion-webui/82a973c04367123ae98bd9abdf80d9eda9b910e2/modules/Roboto-Regular.ttf -------------------------------------------------------------------------------- /modules/codeformer_model.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import logging 4 | 5 | import torch 6 | 7 | from modules import ( 8 | devices, 9 | errors, 10 | face_restoration, 11 | face_restoration_utils, 12 | modelloader, 13 | shared, 14 | ) 15 | 16 | logger = logging.getLogger(__name__) 17 | 18 | model_url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth' 19 | model_download_name = 'codeformer-v0.1.0.pth' 20 | 21 | # used by e.g. postprocessing_codeformer.py 22 | codeformer: face_restoration.FaceRestoration | None = None 23 | 24 | 25 | class FaceRestorerCodeFormer(face_restoration_utils.CommonFaceRestoration): 26 | def name(self): 27 | return "CodeFormer" 28 | 29 | def load_net(self) -> torch.nn.Module: 30 | for model_path in modelloader.load_models( 31 | model_path=self.model_path, 32 | model_url=model_url, 33 | command_path=self.model_path, 34 | download_name=model_download_name, 35 | ext_filter=['.pth'], 36 | ): 37 | return modelloader.load_spandrel_model( 38 | model_path, 39 | device=devices.device_codeformer, 40 | expected_architecture='CodeFormer', 41 | ).model 42 | raise ValueError("No codeformer model found") 43 | 44 | def get_device(self): 45 | return devices.device_codeformer 46 | 47 | def restore(self, np_image, w: float | None = None): 48 | if w is None: 49 | w = getattr(shared.opts, "code_former_weight", 0.5) 50 | 51 | def restore_face(cropped_face_t): 52 | assert self.net is not None 53 | return self.net(cropped_face_t, weight=w, adain=True)[0] 54 | 55 | return self.restore_with_helper(np_image, restore_face) 56 | 57 | 58 | def setup_model(dirname: str) -> None: 59 | global codeformer 60 | try: 61 | codeformer = FaceRestorerCodeFormer(dirname) 62 | shared.face_restorers.append(codeformer) 63 | except Exception: 64 | errors.report("Error setting up CodeFormer", exc_info=True) 65 | -------------------------------------------------------------------------------- /modules/dat_model.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from modules import modelloader, errors 4 | from modules.shared import cmd_opts, opts 5 | from modules.upscaler import Upscaler, UpscalerData 6 | from modules.upscaler_utils import upscale_with_model 7 | 8 | 9 | class UpscalerDAT(Upscaler): 10 | def __init__(self, user_path): 11 | self.name = "DAT" 12 | self.user_path = user_path 13 | self.scalers = [] 14 | super().__init__() 15 | 16 | for file in self.find_models(ext_filter=[".pt", ".pth"]): 17 | name = 
modelloader.friendly_name(file) 18 | scaler_data = UpscalerData(name, file, upscaler=self, scale=None) 19 | self.scalers.append(scaler_data) 20 | 21 | for model in get_dat_models(self): 22 | if model.name in opts.dat_enabled_models: 23 | self.scalers.append(model) 24 | 25 | def do_upscale(self, img, path): 26 | try: 27 | info = self.load_model(path) 28 | except Exception: 29 | errors.report(f"Unable to load DAT model {path}", exc_info=True) 30 | return img 31 | 32 | model_descriptor = modelloader.load_spandrel_model( 33 | info.local_data_path, 34 | device=self.device, 35 | prefer_half=(not cmd_opts.no_half and not cmd_opts.upcast_sampling), 36 | expected_architecture="DAT", 37 | ) 38 | return upscale_with_model( 39 | model_descriptor, 40 | img, 41 | tile_size=opts.DAT_tile, 42 | tile_overlap=opts.DAT_tile_overlap, 43 | ) 44 | 45 | def load_model(self, path): 46 | for scaler in self.scalers: 47 | if scaler.data_path == path: 48 | if scaler.local_data_path.startswith("http"): 49 | scaler.local_data_path = modelloader.load_file_from_url( 50 | scaler.data_path, 51 | model_dir=self.model_download_path, 52 | ) 53 | if not os.path.exists(scaler.local_data_path): 54 | raise FileNotFoundError(f"DAT data missing: {scaler.local_data_path}") 55 | return scaler 56 | raise ValueError(f"Unable to find model info: {path}") 57 | 58 | 59 | def get_dat_models(scaler): 60 | return [ 61 | UpscalerData( 62 | name="DAT x2", 63 | path="https://github.com/n0kovo/dat_upscaler_models/raw/main/DAT/DAT_x2.pth", 64 | scale=2, 65 | upscaler=scaler, 66 | ), 67 | UpscalerData( 68 | name="DAT x3", 69 | path="https://github.com/n0kovo/dat_upscaler_models/raw/main/DAT/DAT_x3.pth", 70 | scale=3, 71 | upscaler=scaler, 72 | ), 73 | UpscalerData( 74 | name="DAT x4", 75 | path="https://github.com/n0kovo/dat_upscaler_models/raw/main/DAT/DAT_x4.pth", 76 | scale=4, 77 | upscaler=scaler, 78 | ), 79 | ] 80 | -------------------------------------------------------------------------------- /modules/esrgan_model.py: -------------------------------------------------------------------------------- 1 | from modules import modelloader, devices, errors 2 | from modules.shared import opts 3 | from modules.upscaler import Upscaler, UpscalerData 4 | from modules.upscaler_utils import upscale_with_model 5 | 6 | 7 | class UpscalerESRGAN(Upscaler): 8 | def __init__(self, dirname): 9 | self.name = "ESRGAN" 10 | self.model_url = "https://github.com/cszn/KAIR/releases/download/v1.0/ESRGAN.pth" 11 | self.model_name = "ESRGAN_4x" 12 | self.scalers = [] 13 | self.user_path = dirname 14 | super().__init__() 15 | model_paths = self.find_models(ext_filter=[".pt", ".pth"]) 16 | 17 | if len(model_paths) == 0: 18 | scaler_data = UpscalerData(self.model_name, self.model_url, self, 4) 19 | self.scalers.append(scaler_data) 20 | for file in model_paths: 21 | if file.startswith("http"): 22 | name = self.model_name 23 | else: 24 | name = modelloader.friendly_name(file) 25 | 26 | scaler_data = UpscalerData(name, file, self, 4) 27 | self.scalers.append(scaler_data) 28 | 29 | def do_upscale(self, img, selected_model): 30 | try: 31 | model = self.load_model(selected_model) 32 | except Exception: 33 | errors.report(f"Unable to load ESRGAN model {selected_model}", exc_info=True) 34 | return img 35 | model.to(devices.device_esrgan) 36 | return esrgan_upscale(model, img) 37 | 38 | def load_model(self, path: str): 39 | if path.startswith("http"): 40 | # TODO: this doesn't use `path` at all? 
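# (editor's note) as the TODO above says: in this branch the incoming `path` is
# ignored and the stock self.model_url / self.model_name pair is downloaded, so
# every http path resolves to the default ESRGAN_4x checkpoint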
41 | filename = modelloader.load_file_from_url( 42 | url=self.model_url, 43 | model_dir=self.model_download_path, 44 | file_name=f"{self.model_name}.pth", 45 | ) 46 | else: 47 | filename = path 48 | 49 | return modelloader.load_spandrel_model( 50 | filename, 51 | device=('cpu' if devices.device_esrgan.type == 'mps' else None), 52 | expected_architecture='ESRGAN', 53 | ) 54 | 55 | 56 | def esrgan_upscale(model, img): 57 | return upscale_with_model( 58 | model, 59 | img, 60 | tile_size=opts.ESRGAN_tile, 61 | tile_overlap=opts.ESRGAN_tile_overlap, 62 | ) 63 | -------------------------------------------------------------------------------- /modules/extra_networks_hypernet.py: -------------------------------------------------------------------------------- 1 | from modules import extra_networks, shared 2 | from modules.hypernetworks import hypernetwork 3 | 4 | 5 | class ExtraNetworkHypernet(extra_networks.ExtraNetwork): 6 | def __init__(self): 7 | super().__init__('hypernet') 8 | 9 | def activate(self, p, params_list): 10 | additional = shared.opts.sd_hypernetwork 11 | 12 | if additional != "None" and additional in shared.hypernetworks and not any(x for x in params_list if x.items[0] == additional): 13 | hypernet_prompt_text = f"<hypernet:{additional}:{shared.opts.extra_networks_default_multiplier}>" 14 | p.all_prompts = [f"{prompt}{hypernet_prompt_text}" for prompt in p.all_prompts] 15 | params_list.append(extra_networks.ExtraNetworkParams(items=[additional, shared.opts.extra_networks_default_multiplier])) 16 | 17 | names = [] 18 | multipliers = [] 19 | for params in params_list: 20 | assert params.items 21 | 22 | names.append(params.items[0]) 23 | multipliers.append(float(params.items[1]) if len(params.items) > 1 else 1.0) 24 | 25 | hypernetwork.load_hypernetworks(names, multipliers) 26 | 27 | def deactivate(self, p): 28 | pass 29 | -------------------------------------------------------------------------------- /modules/face_restoration.py: -------------------------------------------------------------------------------- 1 | from modules import shared 2 | 3 | 4 | class FaceRestoration: 5 | def name(self): 6 | return "None" 7 | 8 | def restore(self, np_image): 9 | return np_image 10 | 11 | 12 | def restore_faces(np_image): 13 | face_restorers = [x for x in shared.face_restorers if x.name() == shared.opts.face_restoration_model or shared.opts.face_restoration_model is None] 14 | if len(face_restorers) == 0: 15 | return np_image 16 | 17 | face_restorer = face_restorers[0] 18 | 19 | return face_restorer.restore(np_image) 20 | -------------------------------------------------------------------------------- /modules/fifo_lock.py: -------------------------------------------------------------------------------- 1 | import threading 2 | import collections 3 | 4 | 5 | # reference: https://gist.github.com/vitaliyp/6d54dd76ca2c3cdfc1149d33007dc34a 6 | class FIFOLock(object): 7 | def __init__(self): 8 | self._lock = threading.Lock() 9 | self._inner_lock = threading.Lock() 10 | self._pending_threads = collections.deque() 11 | 12 | def acquire(self, blocking=True): 13 | with self._inner_lock: 14 | lock_acquired = self._lock.acquire(False) 15 | if lock_acquired: 16 | return True 17 | elif not blocking: 18 | return False 19 | 20 | release_event = threading.Event() 21 | self._pending_threads.append(release_event) 22 | 23 | release_event.wait() 24 | return self._lock.acquire() 25 | 26 | def release(self): 27 | with self._inner_lock: 28 | if self._pending_threads: 29 | release_event = self._pending_threads.popleft() 30 | release_event.set() 31 | 32 | self._lock.release() 33 | 34 | 
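# (editor's note) aliasing __enter__ to acquire below, together with __exit__,
# lets the lock be used as a context manager -- inside "with lock:" blocks,
# waiting threads are woken first-come, first-served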
__enter__ = acquire 35 | 36 | def __exit__(self, t, v, tb): 37 | self.release() 38 | -------------------------------------------------------------------------------- /modules/gfpgan_model.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import logging 4 | import os 5 | 6 | import torch 7 | 8 | from modules import ( 9 | devices, 10 | errors, 11 | face_restoration, 12 | face_restoration_utils, 13 | modelloader, 14 | shared, 15 | ) 16 | 17 | logger = logging.getLogger(__name__) 18 | model_url = "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth" 19 | model_download_name = "GFPGANv1.4.pth" 20 | gfpgan_face_restorer: face_restoration.FaceRestoration | None = None 21 | 22 | 23 | class FaceRestorerGFPGAN(face_restoration_utils.CommonFaceRestoration): 24 | def name(self): 25 | return "GFPGAN" 26 | 27 | def get_device(self): 28 | return devices.device_gfpgan 29 | 30 | def load_net(self) -> torch.nn.Module: 31 | for model_path in modelloader.load_models( 32 | model_path=self.model_path, 33 | model_url=model_url, 34 | command_path=self.model_path, 35 | download_name=model_download_name, 36 | ext_filter=['.pth'], 37 | ): 38 | if 'GFPGAN' in os.path.basename(model_path): 39 | return modelloader.load_spandrel_model( 40 | model_path, 41 | device=self.get_device(), 42 | expected_architecture='GFPGAN', 43 | ).model 44 | raise ValueError("No GFPGAN model found") 45 | 46 | def restore(self, np_image): 47 | def restore_face(cropped_face_t): 48 | assert self.net is not None 49 | return self.net(cropped_face_t, return_rgb=False)[0] 50 | 51 | return self.restore_with_helper(np_image, restore_face) 52 | 53 | 54 | def gfpgan_fix_faces(np_image): 55 | if gfpgan_face_restorer: 56 | return gfpgan_face_restorer.restore(np_image) 57 | logger.warning("GFPGAN face restorer not set up") 58 | return np_image 59 | 60 | 61 | def setup_model(dirname: str) -> None: 62 | global gfpgan_face_restorer 63 | 64 | try: 65 | face_restoration_utils.patch_facexlib(dirname) 66 | gfpgan_face_restorer = FaceRestorerGFPGAN(model_path=dirname) 67 | shared.face_restorers.append(gfpgan_face_restorer) 68 | except Exception: 69 | errors.report("Error setting up GFPGAN", exc_info=True) 70 | -------------------------------------------------------------------------------- /modules/gitpython_hack.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import io 4 | import subprocess 5 | 6 | import git 7 | 8 | 9 | class Git(git.Git): 10 | """ 11 | Git subclassed to never use persistent processes. 12 | """ 13 | 14 | def _get_persistent_cmd(self, attr_name, cmd_name, *args, **kwargs): 15 | raise NotImplementedError(f"Refusing to use persistent process: {attr_name} ({cmd_name} {args} {kwargs})") 16 | 17 | def get_object_header(self, ref: str | bytes) -> tuple[str, str, int]: 18 | ret = subprocess.check_output( 19 | [self.GIT_PYTHON_GIT_EXECUTABLE, "cat-file", "--batch-check"], 20 | input=self._prepare_ref(ref), 21 | cwd=self._working_dir, 22 | timeout=2, 23 | ) 24 | return self._parse_object_header(ret) 25 | 26 | def stream_object_data(self, ref: str) -> tuple[str, str, int, Git.CatFileContentStream]: 27 | # Not really streaming, per se; this buffers the entire object in memory. 28 | # Shouldn't be a problem for our use case, since we're only using this for 29 | # object headers (commit objects). 
30 | ret = subprocess.check_output( 31 | [self.GIT_PYTHON_GIT_EXECUTABLE, "cat-file", "--batch"], 32 | input=self._prepare_ref(ref), 33 | cwd=self._working_dir, 34 | timeout=30, 35 | ) 36 | bio = io.BytesIO(ret) 37 | hexsha, typename, size = self._parse_object_header(bio.readline()) 38 | return (hexsha, typename, size, self.CatFileContentStream(size, bio)) 39 | 40 | 41 | class Repo(git.Repo): 42 | GitCommandWrapperType = Git 43 | -------------------------------------------------------------------------------- /modules/gradio_extensons.py: -------------------------------------------------------------------------------- 1 | import gradio as gr 2 | 3 | from modules import scripts, ui_tempdir, patches 4 | 5 | 6 | def add_classes_to_gradio_component(comp): 7 | """ 8 | this adds gradio-* to the component for css styling (ie gradio-button to gr.Button), as well as some others 9 | """ 10 | 11 | comp.elem_classes = [f"gradio-{comp.get_block_name()}", *(comp.elem_classes or [])] 12 | 13 | if getattr(comp, 'multiselect', False): 14 | comp.elem_classes.append('multiselect') 15 | 16 | 17 | def IOComponent_init(self, *args, **kwargs): 18 | self.webui_tooltip = kwargs.pop('tooltip', None) 19 | 20 | if scripts.scripts_current is not None: 21 | scripts.scripts_current.before_component(self, **kwargs) 22 | 23 | scripts.script_callbacks.before_component_callback(self, **kwargs) 24 | 25 | res = original_IOComponent_init(self, *args, **kwargs) 26 | 27 | add_classes_to_gradio_component(self) 28 | 29 | scripts.script_callbacks.after_component_callback(self, **kwargs) 30 | 31 | if scripts.scripts_current is not None: 32 | scripts.scripts_current.after_component(self, **kwargs) 33 | 34 | return res 35 | 36 | 37 | def Block_get_config(self): 38 | config = original_Block_get_config(self) 39 | 40 | webui_tooltip = getattr(self, 'webui_tooltip', None) 41 | if webui_tooltip: 42 | config["webui_tooltip"] = webui_tooltip 43 | 44 | config.pop('example_inputs', None) 45 | 46 | return config 47 | 48 | 49 | def BlockContext_init(self, *args, **kwargs): 50 | if scripts.scripts_current is not None: 51 | scripts.scripts_current.before_component(self, **kwargs) 52 | 53 | scripts.script_callbacks.before_component_callback(self, **kwargs) 54 | 55 | res = original_BlockContext_init(self, *args, **kwargs) 56 | 57 | add_classes_to_gradio_component(self) 58 | 59 | scripts.script_callbacks.after_component_callback(self, **kwargs) 60 | 61 | if scripts.scripts_current is not None: 62 | scripts.scripts_current.after_component(self, **kwargs) 63 | 64 | return res 65 | 66 | 67 | def Blocks_get_config_file(self, *args, **kwargs): 68 | config = original_Blocks_get_config_file(self, *args, **kwargs) 69 | 70 | for comp_config in config["components"]: 71 | if "example_inputs" in comp_config: 72 | comp_config["example_inputs"] = {"serialized": []} 73 | 74 | return config 75 | 76 | 77 | original_IOComponent_init = patches.patch(__name__, obj=gr.components.IOComponent, field="__init__", replacement=IOComponent_init) 78 | original_Block_get_config = patches.patch(__name__, obj=gr.blocks.Block, field="get_config", replacement=Block_get_config) 79 | original_BlockContext_init = patches.patch(__name__, obj=gr.blocks.BlockContext, field="__init__", replacement=BlockContext_init) 80 | original_Blocks_get_config_file = patches.patch(__name__, obj=gr.blocks.Blocks, field="get_config_file", replacement=Blocks_get_config_file) 81 | 82 | 83 | ui_tempdir.install_ui_tempdir_override() 84 | 
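# ---------------------------------------------------------------------------
# Editor's sketch (not a file from the repository): the save-wrap-delegate
# pattern that gradio_extensons.py builds on. patches.patch above returns the
# original attribute so the replacement can call through to it; this standalone
# miniature does the same thing with plain setattr. All names here are made up.

class Greeter:
    def hello(self, name):
        return f"hello, {name}"

def patch_method(cls, field, replacement):
    original = getattr(cls, field)   # keep the original so the wrapper can delegate
    setattr(cls, field, replacement)
    return original

original_hello = patch_method(Greeter, "hello", lambda self, name: original_hello(self, name).upper())

print(Greeter().hello("world"))  # -> HELLO, WORLD
# ---------------------------------------------------------------------------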
-------------------------------------------------------------------------------- /modules/hashes.py: -------------------------------------------------------------------------------- 1 | import hashlib 2 | import os.path 3 | 4 | from modules import shared 5 | import modules.cache 6 | 7 | dump_cache = modules.cache.dump_cache 8 | cache = modules.cache.cache 9 | 10 | 11 | def calculate_sha256(filename): 12 | hash_sha256 = hashlib.sha256() 13 | blksize = 1024 * 1024 14 | 15 | with open(filename, "rb") as f: 16 | for chunk in iter(lambda: f.read(blksize), b""): 17 | hash_sha256.update(chunk) 18 | 19 | return hash_sha256.hexdigest() 20 | 21 | 22 | def sha256_from_cache(filename, title, use_addnet_hash=False): 23 | hashes = cache("hashes-addnet") if use_addnet_hash else cache("hashes") 24 | try: 25 | ondisk_mtime = os.path.getmtime(filename) 26 | except FileNotFoundError: 27 | return None 28 | 29 | if title not in hashes: 30 | return None 31 | 32 | cached_sha256 = hashes[title].get("sha256", None) 33 | cached_mtime = hashes[title].get("mtime", 0) 34 | 35 | if ondisk_mtime > cached_mtime or cached_sha256 is None: 36 | return None 37 | 38 | return cached_sha256 39 | 40 | 41 | def sha256(filename, title, use_addnet_hash=False): 42 | hashes = cache("hashes-addnet") if use_addnet_hash else cache("hashes") 43 | 44 | sha256_value = sha256_from_cache(filename, title, use_addnet_hash) 45 | if sha256_value is not None: 46 | return sha256_value 47 | 48 | if shared.cmd_opts.no_hashing: 49 | return None 50 | 51 | print(f"Calculating sha256 for {filename}: ", end='') 52 | if use_addnet_hash: 53 | with open(filename, "rb") as file: 54 | sha256_value = addnet_hash_safetensors(file) 55 | else: 56 | sha256_value = calculate_sha256(filename) 57 | print(f"{sha256_value}") 58 | 59 | hashes[title] = { 60 | "mtime": os.path.getmtime(filename), 61 | "sha256": sha256_value, 62 | } 63 | 64 | dump_cache() 65 | 66 | return sha256_value 67 | 68 | 69 | def addnet_hash_safetensors(b): 70 | """kohya-ss hash for safetensors from https://github.com/kohya-ss/sd-scripts/blob/main/library/train_util.py""" 71 | hash_sha256 = hashlib.sha256() 72 | blksize = 1024 * 1024 73 | 74 | b.seek(0) 75 | header = b.read(8) 76 | n = int.from_bytes(header, "little") 77 | 78 | offset = n + 8 79 | b.seek(offset) 80 | for chunk in iter(lambda: b.read(blksize), b""): 81 | hash_sha256.update(chunk) 82 | 83 | return hash_sha256.hexdigest() 84 | 85 | -------------------------------------------------------------------------------- /modules/hat_model.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | from modules import modelloader, devices 5 | from modules.shared import opts 6 | from modules.upscaler import Upscaler, UpscalerData 7 | from modules.upscaler_utils import upscale_with_model 8 | 9 | 10 | class UpscalerHAT(Upscaler): 11 | def __init__(self, dirname): 12 | self.name = "HAT" 13 | self.scalers = [] 14 | self.user_path = dirname 15 | super().__init__() 16 | for file in self.find_models(ext_filter=[".pt", ".pth"]): 17 | name = modelloader.friendly_name(file) 18 | scale = 4 # TODO: scale might not be 4, but we can't know without loading the model 19 | scaler_data = UpscalerData(name, file, upscaler=self, scale=scale) 20 | self.scalers.append(scaler_data) 21 | 22 | def do_upscale(self, img, selected_model): 23 | try: 24 | model = self.load_model(selected_model) 25 | except Exception as e: 26 | print(f"Unable to load HAT model {selected_model}: {e}", file=sys.stderr) 27 | return img 
28 | model.to(devices.device_esrgan) # TODO: should probably be device_hat 29 | return upscale_with_model( 30 | model, 31 | img, 32 | tile_size=opts.ESRGAN_tile, # TODO: should probably be HAT_tile 33 | tile_overlap=opts.ESRGAN_tile_overlap, # TODO: should probably be HAT_tile_overlap 34 | ) 35 | 36 | def load_model(self, path: str): 37 | if not os.path.isfile(path): 38 | raise FileNotFoundError(f"Model file {path} not found") 39 | return modelloader.load_spandrel_model( 40 | path, 41 | device=devices.device_esrgan, # TODO: should probably be device_hat 42 | expected_architecture='HAT', 43 | ) 44 | -------------------------------------------------------------------------------- /modules/hypernetworks/ui.py: -------------------------------------------------------------------------------- 1 | import html 2 | 3 | import gradio as gr 4 | import modules.hypernetworks.hypernetwork 5 | from modules import devices, sd_hijack, shared 6 | 7 | not_available = ["hardswish", "multiheadattention"] 8 | keys = [x for x in modules.hypernetworks.hypernetwork.HypernetworkModule.activation_dict if x not in not_available] 9 | 10 | 11 | def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False, dropout_structure=None): 12 | filename = modules.hypernetworks.hypernetwork.create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure, activation_func, weight_init, add_layer_norm, use_dropout, dropout_structure) 13 | 14 | return gr.Dropdown.update(choices=sorted(shared.hypernetworks)), f"Created: {filename}", "" 15 | 16 | 17 | def train_hypernetwork(*args): 18 | shared.loaded_hypernetworks = [] 19 | 20 | assert not shared.cmd_opts.lowvram, 'Training models with lowvram is not possible' 21 | 22 | try: 23 | sd_hijack.undo_optimizations() 24 | 25 | hypernetwork, filename = modules.hypernetworks.hypernetwork.train_hypernetwork(*args) 26 | 27 | res = f""" 28 | Training {'interrupted' if shared.state.interrupted else 'finished'} at {hypernetwork.step} steps. 29 | Hypernetwork saved to <b>{html.escape(filename)}</b> 30 | """ 31 | return res, "" 32 | except Exception: 33 | raise 34 | finally: 35 | shared.sd_model.cond_stage_model.to(devices.device) 36 | shared.sd_model.first_stage_model.to(devices.device) 37 | sd_hijack.apply_optimizations() 38 | 39 | -------------------------------------------------------------------------------- /modules/import_hook.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | # this will break any attempt to import xformers, which will prevent the stable diffusion repo from trying to use it 4 | if "--xformers" not in "".join(sys.argv): 5 | sys.modules["xformers"] = None 6 | 7 | # Hack to fix a changed import in torchvision 0.17+, which otherwise breaks 8 | # basicsr; see https://github.com/AUTOMATIC1111/stable-diffusion-webui/issues/13985 9 | try: 10 | import torchvision.transforms.functional_tensor # noqa: F401 11 | except ImportError: 12 | try: 13 | import torchvision.transforms.functional as functional 14 | sys.modules["torchvision.transforms.functional_tensor"] = functional 15 | except ImportError: 16 | pass # shrug... 
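# ---------------------------------------------------------------------------
# Editor's sketch (not a file from the repository): what setting a sys.modules
# entry to None does. Any later import of that name raises ImportError, which is
# how xformers is kept unavailable above when --xformers was not passed. The
# module name below is made up.
import sys

sys.modules["some_optional_accelerator"] = None
try:
    import some_optional_accelerator  # noqa: F401
except ImportError:
    print("import blocked, as intended")
# ---------------------------------------------------------------------------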
17 | -------------------------------------------------------------------------------- /modules/infotext_versions.py: -------------------------------------------------------------------------------- 1 | from modules import shared 2 | from packaging import version 3 | import re 4 | 5 | 6 | v160 = version.parse("1.6.0") 7 | v170_tsnr = version.parse("v1.7.0-225") 8 | v180 = version.parse("1.8.0") 9 | v180_hr_styles = version.parse("1.8.0-139") 10 | 11 | 12 | def parse_version(text): 13 | if text is None: 14 | return None 15 | 16 | m = re.match(r'([^-]+-[^-]+)-.*', text) 17 | if m: 18 | text = m.group(1) 19 | 20 | try: 21 | return version.parse(text) 22 | except Exception: 23 | return None 24 | 25 | 26 | def backcompat(d): 27 | """Checks infotext Version field, and enables backwards compatibility options according to it.""" 28 | 29 | if not shared.opts.auto_backcompat: 30 | return 31 | 32 | ver = parse_version(d.get("Version")) 33 | if ver is None: 34 | return 35 | 36 | if ver < v160 and '[' in d.get('Prompt', ''): 37 | d["Old prompt editing timelines"] = True 38 | 39 | if ver < v160 and d.get('Sampler', '') in ('DDIM', 'PLMS'): 40 | d["Pad conds v0"] = True 41 | 42 | if ver < v170_tsnr: 43 | d["Downcast alphas_cumprod"] = True 44 | 45 | if ver < v180 and d.get('Refiner'): 46 | d["Refiner switch by sampling steps"] = True 47 | -------------------------------------------------------------------------------- /modules/localization.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | 4 | from modules import errors, scripts 5 | 6 | localizations = {} 7 | 8 | 9 | def list_localizations(dirname): 10 | localizations.clear() 11 | 12 | for file in os.listdir(dirname): 13 | fn, ext = os.path.splitext(file) 14 | if ext.lower() != ".json": 15 | continue 16 | 17 | localizations[fn] = [os.path.join(dirname, file)] 18 | 19 | for file in scripts.list_scripts("localizations", ".json"): 20 | fn, ext = os.path.splitext(file.filename) 21 | if fn not in localizations: 22 | localizations[fn] = [] 23 | localizations[fn].append(file.path) 24 | 25 | 26 | def localization_js(current_localization_name: str) -> str: 27 | fns = localizations.get(current_localization_name, None) 28 | data = {} 29 | if fns is not None: 30 | for fn in fns: 31 | try: 32 | with open(fn, "r", encoding="utf8") as file: 33 | data.update(json.load(file)) 34 | except Exception: 35 | errors.report(f"Error loading localization from {fn}", exc_info=True) 36 | 37 | return f"window.localization = {json.dumps(data)}" 38 | -------------------------------------------------------------------------------- /modules/logging_config.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | 4 | try: 5 | from tqdm import tqdm 6 | 7 | 8 | class TqdmLoggingHandler(logging.Handler): 9 | def __init__(self, fallback_handler: logging.Handler): 10 | super().__init__() 11 | self.fallback_handler = fallback_handler 12 | 13 | def emit(self, record): 14 | try: 15 | # If there are active tqdm progress bars, 16 | # attempt to not interfere with them. 
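# tqdm keeps the set of live progress bars in the private _instances attribute;
# routing records through tqdm.write() prints them above active bars instead of
# corrupting their output.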
17 | if tqdm._instances: 18 | tqdm.write(self.format(record)) 19 | else: 20 | self.fallback_handler.emit(record) 21 | except Exception: 22 | self.fallback_handler.emit(record) 23 | 24 | except ImportError: 25 | TqdmLoggingHandler = None 26 | 27 | 28 | def setup_logging(loglevel): 29 | if loglevel is None: 30 | loglevel = os.environ.get("SD_WEBUI_LOG_LEVEL") 31 | 32 | if not loglevel: 33 | return 34 | 35 | if logging.root.handlers: 36 | # Already configured, do not interfere 37 | return 38 | 39 | formatter = logging.Formatter( 40 | '%(asctime)s %(levelname)s [%(name)s] %(message)s', 41 | '%Y-%m-%d %H:%M:%S', 42 | ) 43 | 44 | if os.environ.get("SD_WEBUI_RICH_LOG"): 45 | from rich.logging import RichHandler 46 | handler = RichHandler() 47 | else: 48 | handler = logging.StreamHandler() 49 | handler.setFormatter(formatter) 50 | 51 | if TqdmLoggingHandler: 52 | handler = TqdmLoggingHandler(handler) 53 | 54 | handler.setFormatter(formatter) 55 | 56 | log_level = getattr(logging, loglevel.upper(), None) or logging.INFO 57 | logging.root.setLevel(log_level) 58 | logging.root.addHandler(handler) 59 | -------------------------------------------------------------------------------- /modules/memmon.py: -------------------------------------------------------------------------------- 1 | import threading 2 | import time 3 | from collections import defaultdict 4 | 5 | import torch 6 | 7 | 8 | class MemUsageMonitor(threading.Thread): 9 | run_flag = None 10 | device = None 11 | disabled = False 12 | opts = None 13 | data = None 14 | 15 | def __init__(self, name, device, opts): 16 | threading.Thread.__init__(self) 17 | self.name = name 18 | self.device = device 19 | self.opts = opts 20 | 21 | self.daemon = True 22 | self.run_flag = threading.Event() 23 | self.data = defaultdict(int) 24 | 25 | try: 26 | self.cuda_mem_get_info() 27 | torch.cuda.memory_stats(self.device) 28 | except Exception as e: # AMD or whatever 29 | print(f"Warning: caught exception '{e}', memory monitor disabled") 30 | self.disabled = True 31 | 32 | def cuda_mem_get_info(self): 33 | index = self.device.index if self.device.index is not None else torch.cuda.current_device() 34 | return torch.cuda.mem_get_info(index) 35 | 36 | def run(self): 37 | if self.disabled: 38 | return 39 | 40 | while True: 41 | self.run_flag.wait() 42 | 43 | torch.cuda.reset_peak_memory_stats() 44 | self.data.clear() 45 | 46 | if self.opts.memmon_poll_rate <= 0: 47 | self.run_flag.clear() 48 | continue 49 | 50 | self.data["min_free"] = self.cuda_mem_get_info()[0] 51 | 52 | while self.run_flag.is_set(): 53 | free, total = self.cuda_mem_get_info() 54 | self.data["min_free"] = min(self.data["min_free"], free) 55 | 56 | time.sleep(1 / self.opts.memmon_poll_rate) 57 | 58 | def dump_debug(self): 59 | print(self, 'recorded data:') 60 | for k, v in self.read().items(): 61 | print(k, -(v // -(1024 ** 2))) 62 | 63 | print(self, 'raw torch memory stats:') 64 | tm = torch.cuda.memory_stats(self.device) 65 | for k, v in tm.items(): 66 | if 'bytes' not in k: 67 | continue 68 | print('\t' if 'peak' in k else '', k, -(v // -(1024 ** 2))) 69 | 70 | print(torch.cuda.memory_summary()) 71 | 72 | def monitor(self): 73 | self.run_flag.set() 74 | 75 | def read(self): 76 | if not self.disabled: 77 | free, total = self.cuda_mem_get_info() 78 | self.data["free"] = free 79 | self.data["total"] = total 80 | 81 | torch_stats = torch.cuda.memory_stats(self.device) 82 | self.data["active"] = torch_stats["active.all.current"] 83 | self.data["active_peak"] = torch_stats["active_bytes.all.peak"] 84 
| self.data["reserved"] = torch_stats["reserved_bytes.all.current"] 85 | self.data["reserved_peak"] = torch_stats["reserved_bytes.all.peak"] 86 | self.data["system_peak"] = total - self.data["min_free"] 87 | 88 | return self.data 89 | 90 | def stop(self): 91 | self.run_flag.clear() 92 | return self.read() 93 | -------------------------------------------------------------------------------- /modules/models/diffusion/uni_pc/__init__.py: -------------------------------------------------------------------------------- 1 | from .sampler import UniPCSampler # noqa: F401 2 | -------------------------------------------------------------------------------- /modules/ngrok.py: -------------------------------------------------------------------------------- 1 | import ngrok 2 | 3 | # Connect to ngrok for ingress 4 | def connect(token, port, options): 5 | account = None 6 | if token is None: 7 | token = 'None' 8 | else: 9 | if ':' in token: 10 | # token = authtoken:username:password 11 | token, username, password = token.split(':', 2) 12 | account = f"{username}:{password}" 13 | 14 | # For all options see: https://github.com/ngrok/ngrok-py/blob/main/examples/ngrok-connect-full.py 15 | if not options.get('authtoken_from_env'): 16 | options['authtoken'] = token 17 | if account: 18 | options['basic_auth'] = account 19 | if not options.get('session_metadata'): 20 | options['session_metadata'] = 'stable-diffusion-webui' 21 | 22 | 23 | try: 24 | public_url = ngrok.connect(f"127.0.0.1:{port}", **options).url() 25 | except Exception as e: 26 | print(f'Invalid ngrok authtoken? ngrok connection aborted due to: {e}\n' 27 | f'Your token: {token}, get the right one on https://dashboard.ngrok.com/get-started/your-authtoken') 28 | else: 29 | print(f'ngrok connected to localhost:{port}! URL: {public_url}\n' 30 | 'You can use this link after the launch is complete.') 31 | -------------------------------------------------------------------------------- /modules/npu_specific.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | import torch 3 | 4 | from modules import shared 5 | 6 | 7 | def check_for_npu(): 8 | if importlib.util.find_spec("torch_npu") is None: 9 | return False 10 | import torch_npu 11 | 12 | try: 13 | # Will raise a RuntimeError if no NPU is found 14 | _ = torch_npu.npu.device_count() 15 | return torch.npu.is_available() 16 | except RuntimeError: 17 | return False 18 | 19 | 20 | def get_npu_device_string(): 21 | if shared.cmd_opts.device_id is not None: 22 | return f"npu:{shared.cmd_opts.device_id}" 23 | return "npu:0" 24 | 25 | 26 | def torch_npu_gc(): 27 | with torch.npu.device(get_npu_device_string()): 28 | torch.npu.empty_cache() 29 | 30 | 31 | has_npu = check_for_npu() 32 | -------------------------------------------------------------------------------- /modules/patches.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | 3 | 4 | def patch(key, obj, field, replacement): 5 | """Replaces a function in a module or a class. 6 | 7 | Also stores the original function in this module, possible to be retrieved via original(key, obj, field). 8 | If the function is already replaced by this caller (key), an exception is raised -- use undo() before that. 9 | 10 | Arguments: 11 | key: identifying information for who is doing the replacement. You can use __name__. 
12 | obj: the module or the class 13 | field: name of the function as a string 14 | replacement: the new function 15 | 16 | Returns: 17 | the original function 18 | """ 19 | 20 | patch_key = (obj, field) 21 | if patch_key in originals[key]: 22 | raise RuntimeError(f"patch for {field} is already applied") 23 | 24 | original_func = getattr(obj, field) 25 | originals[key][patch_key] = original_func 26 | 27 | setattr(obj, field, replacement) 28 | 29 | return original_func 30 | 31 | 32 | def undo(key, obj, field): 33 | """Undoes the replacement by the patch(). 34 | 35 | If the function is not replaced, raises an exception. 36 | 37 | Arguments: 38 | key: identifying information for who is doing the replacement. You can use __name__. 39 | obj: the module or the class 40 | field: name of the function as a string 41 | 42 | Returns: 43 | Always None 44 | """ 45 | 46 | patch_key = (obj, field) 47 | 48 | if patch_key not in originals[key]: 49 | raise RuntimeError(f"there is no patch for {field} to undo") 50 | 51 | original_func = originals[key].pop(patch_key) 52 | setattr(obj, field, original_func) 53 | 54 | return None 55 | 56 | 57 | def original(key, obj, field): 58 | """Returns the original function for the patch created by the patch() function""" 59 | patch_key = (obj, field) 60 | 61 | return originals[key].get(patch_key, None) 62 | 63 | 64 | originals = defaultdict(dict) 65 | -------------------------------------------------------------------------------- /modules/paths.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | from modules.paths_internal import models_path, script_path, data_path, extensions_dir, extensions_builtin_dir, cwd # noqa: F401 4 | 5 | import modules.safe # noqa: F401 6 | 7 | 8 | def mute_sdxl_imports(): 9 | """create fake modules that SDXL wants to import but doesn't actually use for our purposes""" 10 | 11 | class Dummy: 12 | pass 13 | 14 | module = Dummy() 15 | module.LPIPS = None 16 | sys.modules['taming.modules.losses.lpips'] = module 17 | 18 | module = Dummy() 19 | module.StableDataModuleFromConfig = None 20 | sys.modules['sgm.data'] = module 21 | 22 | 23 | # data_path = cmd_opts_pre.data 24 | sys.path.insert(0, script_path) 25 | 26 | # search for directory of stable diffusion in the following places 27 | sd_path = None 28 | possible_sd_paths = [os.path.join(script_path, 'repositories/stable-diffusion-stability-ai'), '.', os.path.dirname(script_path)] 29 | for possible_sd_path in possible_sd_paths: 30 | if os.path.exists(os.path.join(possible_sd_path, 'ldm/models/diffusion/ddpm.py')): 31 | sd_path = os.path.abspath(possible_sd_path) 32 | break 33 | 34 | assert sd_path is not None, f"Couldn't find Stable Diffusion in any of: {possible_sd_paths}" 35 | 36 | mute_sdxl_imports() 37 | 38 | path_dirs = [ 39 | (sd_path, 'ldm', 'Stable Diffusion', []), 40 | (os.path.join(sd_path, '../generative-models'), 'sgm', 'Stable Diffusion XL', ["sgm"]), 41 | (os.path.join(sd_path, '../BLIP'), 'models/blip.py', 'BLIP', []), 42 | (os.path.join(sd_path, '../k-diffusion'), 'k_diffusion/sampling.py', 'k_diffusion', ["atstart"]), 43 | ] 44 | 45 | paths = {} 46 | 47 | for d, must_exist, what, options in path_dirs: 48 | must_exist_path = os.path.abspath(os.path.join(script_path, d, must_exist)) 49 | if not os.path.exists(must_exist_path): 50 | print(f"Warning: {what} not found at path {must_exist_path}", file=sys.stderr) 51 | else: 52 | d = os.path.abspath(d) 53 | if "atstart" in options: 54 | sys.path.insert(0, d) 55 | elif "sgm" in
options: 56 | # Stable Diffusion XL repo has scripts dir with __init__.py in it which ruins every extension's scripts dir, so we 57 | # import sgm and remove it from sys.path so that when a script imports scripts.something, it doesn't use sgm's scripts dir. 58 | 59 | sys.path.insert(0, d) 60 | import sgm # noqa: F401 61 | sys.path.pop(0) 62 | else: 63 | sys.path.append(d) 64 | paths[what] = d 65 | -------------------------------------------------------------------------------- /modules/paths_internal.py: -------------------------------------------------------------------------------- 1 | """this module defines internal paths used by the program and is safe to import before dependencies are installed in launch.py""" 2 | 3 | import argparse 4 | import os 5 | import sys 6 | import shlex 7 | from pathlib import Path 8 | 9 | 10 | normalized_filepath = lambda filepath: str(Path(filepath).absolute()) 11 | 12 | commandline_args = os.environ.get('COMMANDLINE_ARGS', "") 13 | sys.argv += shlex.split(commandline_args) 14 | 15 | cwd = os.getcwd() 16 | modules_path = os.path.dirname(os.path.realpath(__file__)) 17 | script_path = os.path.dirname(modules_path) 18 | 19 | sd_configs_path = os.path.join(script_path, "configs") 20 | sd_default_config = os.path.join(sd_configs_path, "v1-inference.yaml") 21 | sd_model_file = os.path.join(script_path, 'model.ckpt') 22 | default_sd_model_file = sd_model_file 23 | 24 | # Parse the --data-dir flag first so we can use it as a base for our other argument default values 25 | parser_pre = argparse.ArgumentParser(add_help=False) 26 | parser_pre.add_argument("--data-dir", type=str, default=os.path.dirname(modules_path), help="base path where all user data is stored", ) 27 | parser_pre.add_argument("--models-dir", type=str, default=None, help="base path where models are stored; overrides --data-dir", ) 28 | cmd_opts_pre = parser_pre.parse_known_args()[0] 29 | 30 | data_path = cmd_opts_pre.data_dir 31 | 32 | models_path = cmd_opts_pre.models_dir if cmd_opts_pre.models_dir else os.path.join(data_path, "models") 33 | extensions_dir = os.path.join(data_path, "extensions") 34 | extensions_builtin_dir = os.path.join(script_path, "extensions-builtin") 35 | config_states_dir = os.path.join(script_path, "config_states") 36 | default_output_dir = os.path.join(data_path, "outputs") 37 | 38 | roboto_ttf_file = os.path.join(modules_path, 'Roboto-Regular.ttf') 39 | -------------------------------------------------------------------------------- /modules/processing_scripts/comments.py: -------------------------------------------------------------------------------- 1 | from modules import scripts, shared, script_callbacks 2 | import re 3 | 4 | 5 | def strip_comments(text): 6 | text = re.sub('(^|\n)#[^\n]*(\n|$)', '\n', text) # whole-line comment 7 | text = re.sub('#[^\n]*(\n|$)', '\n', text) # comment in the middle of the line 8 | 9 | return text 10 | 11 | 12 | class ScriptStripComments(scripts.Script): 13 | def title(self): 14 | return "Comments" 15 | 16 | def show(self, is_img2img): 17 | return scripts.AlwaysVisible 18 | 19 | def process(self, p, *args): 20 | if not shared.opts.enable_prompt_comments: 21 | return 22 | 23 | p.all_prompts = [strip_comments(x) for x in p.all_prompts] 24 | p.all_negative_prompts = [strip_comments(x) for x in p.all_negative_prompts] 25 | 26 | p.main_prompt = strip_comments(p.main_prompt) 27 | p.main_negative_prompt = strip_comments(p.main_negative_prompt) 28 | 29 | if getattr(p, 'enable_hr', False): 30 | p.all_hr_prompts = [strip_comments(x) for x in
p.all_hr_prompts] 31 | p.all_hr_negative_prompts = [strip_comments(x) for x in p.all_hr_negative_prompts] 32 | 33 | p.hr_prompt = strip_comments(p.hr_prompt) 34 | p.hr_negative_prompt = strip_comments(p.hr_negative_prompt) 35 | 36 | 37 | def before_token_counter(params: script_callbacks.BeforeTokenCounterParams): 38 | if not shared.opts.enable_prompt_comments: 39 | return 40 | 41 | params.prompt = strip_comments(params.prompt) 42 | 43 | 44 | script_callbacks.on_before_token_counter(before_token_counter) 45 | 46 | 47 | shared.options_templates.update(shared.options_section(('sd', "Stable Diffusion", "sd"), { 48 | "enable_prompt_comments": shared.OptionInfo(True, "Enable comments").info("Use # anywhere in the prompt to hide the text between # and the end of the line from the generation."), 49 | })) 50 | -------------------------------------------------------------------------------- /modules/processing_scripts/refiner.py: -------------------------------------------------------------------------------- 1 | import gradio as gr 2 | 3 | from modules import scripts, sd_models 4 | from modules.infotext_utils import PasteField 5 | from modules.ui_common import create_refresh_button 6 | from modules.ui_components import InputAccordion 7 | 8 | 9 | class ScriptRefiner(scripts.ScriptBuiltinUI): 10 | section = "accordions" 11 | create_group = False 12 | 13 | def __init__(self): 14 | pass 15 | 16 | def title(self): 17 | return "Refiner" 18 | 19 | def show(self, is_img2img): 20 | return scripts.AlwaysVisible 21 | 22 | def ui(self, is_img2img): 23 | with InputAccordion(False, label="Refiner", elem_id=self.elem_id("enable")) as enable_refiner: 24 | with gr.Row(): 25 | refiner_checkpoint = gr.Dropdown(label='Checkpoint', elem_id=self.elem_id("checkpoint"), choices=sd_models.checkpoint_tiles(), value='', tooltip="switch to another model in the middle of generation") 26 | create_refresh_button(refiner_checkpoint, sd_models.list_models, lambda: {"choices": sd_models.checkpoint_tiles()}, self.elem_id("checkpoint_refresh")) 27 | 28 | refiner_switch_at = gr.Slider(value=0.8, label="Switch at", minimum=0.01, maximum=1.0, step=0.01, elem_id=self.elem_id("switch_at"), tooltip="fraction of sampling steps when the switch to refiner model should happen; 1=never, 0.5=switch in the middle of generation") 29 | 30 | def lookup_checkpoint(title): 31 | info = sd_models.get_closet_checkpoint_match(title) 32 | return None if info is None else info.title 33 | 34 | self.infotext_fields = [ 35 | PasteField(enable_refiner, lambda d: 'Refiner' in d), 36 | PasteField(refiner_checkpoint, lambda d: lookup_checkpoint(d.get('Refiner')), api="refiner_checkpoint"), 37 | PasteField(refiner_switch_at, 'Refiner switch at', api="refiner_switch_at"), 38 | ] 39 | 40 | return enable_refiner, refiner_checkpoint, refiner_switch_at 41 | 42 | def setup(self, p, enable_refiner, refiner_checkpoint, refiner_switch_at): 43 | # the actual implementation is in sd_samplers_common.py, apply_refiner 44 | 45 | if not enable_refiner or refiner_checkpoint in (None, "", "None"): 46 | p.refiner_checkpoint = None 47 | p.refiner_switch_at = None 48 | else: 49 | p.refiner_checkpoint = refiner_checkpoint 50 | p.refiner_switch_at = refiner_switch_at 51 | -------------------------------------------------------------------------------- /modules/processing_scripts/sampler.py: -------------------------------------------------------------------------------- 1 | import gradio as gr 2 | 3 | from modules import scripts, sd_samplers, sd_schedulers, shared 4 | from 
modules.infotext_utils import PasteField 5 | from modules.ui_components import FormRow, FormGroup 6 | 7 | 8 | class ScriptSampler(scripts.ScriptBuiltinUI): 9 | section = "sampler" 10 | 11 | def __init__(self): 12 | self.steps = None 13 | self.sampler_name = None 14 | self.scheduler = None 15 | 16 | def title(self): 17 | return "Sampler" 18 | 19 | def ui(self, is_img2img): 20 | sampler_names = [x.name for x in sd_samplers.visible_samplers()] 21 | scheduler_names = [x.label for x in sd_schedulers.schedulers] 22 | 23 | if shared.opts.samplers_in_dropdown: 24 | with FormRow(elem_id=f"sampler_selection_{self.tabname}"): 25 | self.sampler_name = gr.Dropdown(label='Sampling method', elem_id=f"{self.tabname}_sampling", choices=sampler_names, value=sampler_names[0]) 26 | self.scheduler = gr.Dropdown(label='Schedule type', elem_id=f"{self.tabname}_scheduler", choices=scheduler_names, value=scheduler_names[0]) 27 | self.steps = gr.Slider(minimum=1, maximum=150, step=1, elem_id=f"{self.tabname}_steps", label="Sampling steps", value=20) 28 | else: 29 | with FormGroup(elem_id=f"sampler_selection_{self.tabname}"): 30 | self.steps = gr.Slider(minimum=1, maximum=150, step=1, elem_id=f"{self.tabname}_steps", label="Sampling steps", value=20) 31 | self.sampler_name = gr.Radio(label='Sampling method', elem_id=f"{self.tabname}_sampling", choices=sampler_names, value=sampler_names[0]) 32 | self.scheduler = gr.Dropdown(label='Schedule type', elem_id=f"{self.tabname}_scheduler", choices=scheduler_names, value=scheduler_names[0]) 33 | 34 | self.infotext_fields = [ 35 | PasteField(self.steps, "Steps", api="steps"), 36 | PasteField(self.sampler_name, sd_samplers.get_sampler_from_infotext, api="sampler_name"), 37 | PasteField(self.scheduler, sd_samplers.get_scheduler_from_infotext, api="scheduler"), 38 | ] 39 | 40 | return self.steps, self.sampler_name, self.scheduler 41 | 42 | def setup(self, p, steps, sampler_name, scheduler): 43 | p.steps = steps 44 | p.sampler_name = sampler_name 45 | p.scheduler = scheduler 46 | -------------------------------------------------------------------------------- /modules/profiling.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from modules import shared, ui_gradio_extensions 4 | 5 | 6 | class Profiler: 7 | def __init__(self): 8 | if not shared.opts.profiling_enable: 9 | self.profiler = None 10 | return 11 | 12 | activities = [] 13 | if "CPU" in shared.opts.profiling_activities: 14 | activities.append(torch.profiler.ProfilerActivity.CPU) 15 | if "CUDA" in shared.opts.profiling_activities: 16 | activities.append(torch.profiler.ProfilerActivity.CUDA) 17 | 18 | if not activities: 19 | self.profiler = None 20 | return 21 | 22 | self.profiler = torch.profiler.profile( 23 | activities=activities, 24 | record_shapes=shared.opts.profiling_record_shapes, 25 | profile_memory=shared.opts.profiling_profile_memory, 26 | with_stack=shared.opts.profiling_with_stack 27 | ) 28 | 29 | def __enter__(self): 30 | if self.profiler: 31 | self.profiler.__enter__() 32 | 33 | return self 34 | 35 | def __exit__(self, exc_type, exc, exc_tb): 36 | if self.profiler: 37 | shared.state.textinfo = "Finishing profile..." 
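# Leaving the torch.profiler context below finalizes trace collection, which can
# take a while for long runs; the exported Chrome trace can be inspected in
# chrome://tracing or https://ui.perfetto.dev.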
38 | 39 | self.profiler.__exit__(exc_type, exc, exc_tb) 40 | 41 | self.profiler.export_chrome_trace(shared.opts.profiling_filename) 42 | 43 | 44 | def webpath(): 45 | return ui_gradio_extensions.webpath(shared.opts.profiling_filename) 46 | 47 | -------------------------------------------------------------------------------- /modules/restart.py: -------------------------------------------------------------------------------- 1 | import os 2 | from pathlib import Path 3 | 4 | from modules.paths_internal import script_path 5 | 6 | 7 | def is_restartable() -> bool: 8 | """ 9 | Return True if the webui is restartable (i.e. there is something watching to restart it with) 10 | """ 11 | return bool(os.environ.get('SD_WEBUI_RESTART')) 12 | 13 | 14 | def restart_program() -> None: 15 | """creates file tmp/restart and immediately stops the process, which webui.bat/webui.sh interpret as a command to start webui again""" 16 | 17 | tmpdir = Path(script_path) / "tmp" 18 | tmpdir.mkdir(parents=True, exist_ok=True) 19 | (tmpdir / "restart").touch() 20 | 21 | stop_program() 22 | 23 | 24 | def stop_program() -> None: 25 | os._exit(0) 26 | -------------------------------------------------------------------------------- /modules/script_loading.py: -------------------------------------------------------------------------------- 1 | import os 2 | import importlib.util 3 | 4 | from modules import errors 5 | 6 | 7 | loaded_scripts = {} 8 | 9 | 10 | def load_module(path): 11 | module_spec = importlib.util.spec_from_file_location(os.path.basename(path), path) 12 | module = importlib.util.module_from_spec(module_spec) 13 | module_spec.loader.exec_module(module) 14 | 15 | loaded_scripts[path] = module 16 | return module 17 | 18 | 19 | def preload_extensions(extensions_dir, parser, extension_list=None): 20 | if not os.path.isdir(extensions_dir): 21 | return 22 | 23 | extensions = extension_list if extension_list is not None else os.listdir(extensions_dir) 24 | for dirname in sorted(extensions): 25 | preload_script = os.path.join(extensions_dir, dirname, "preload.py") 26 | if not os.path.isfile(preload_script): 27 | continue 28 | 29 | try: 30 | module = load_module(preload_script) 31 | if hasattr(module, 'preload'): 32 | module.preload(parser) 33 | 34 | except Exception: 35 | errors.report(f"Error running preload() for {preload_script}", exc_info=True) 36 | -------------------------------------------------------------------------------- /modules/scripts_auto_postprocessing.py: -------------------------------------------------------------------------------- 1 | from modules import scripts, scripts_postprocessing, shared 2 | 3 | 4 | class ScriptPostprocessingForMainUI(scripts.Script): 5 | def __init__(self, script_postproc): 6 | self.script: scripts_postprocessing.ScriptPostprocessing = script_postproc 7 | self.postprocessing_controls = None 8 | 9 | def title(self): 10 | return self.script.name 11 | 12 | def show(self, is_img2img): 13 | return scripts.AlwaysVisible 14 | 15 | def ui(self, is_img2img): 16 | self.postprocessing_controls = self.script.ui() 17 | return self.postprocessing_controls.values() 18 | 19 | def postprocess_image(self, p, script_pp, *args): 20 | args_dict = dict(zip(self.postprocessing_controls, args)) 21 | 22 | pp = scripts_postprocessing.PostprocessedImage(script_pp.image) 23 | pp.info = {} 24 | self.script.process(pp, **args_dict) 25 | p.extra_generation_params.update(pp.info) 26 | script_pp.image = pp.image 27 | 28 | 29 | def create_auto_preprocessing_script_data(): 30 | from modules import scripts 
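# Note: the constructor lambda below binds the loop variable through a default
# argument (s=script) because Python closures are late-binding; without it,
# every constructor would end up wrapping the last script in the list.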
31 | 32 | res = [] 33 | 34 | for name in shared.opts.postprocessing_enable_in_main_ui: 35 | script = next(iter([x for x in scripts.postprocessing_scripts_data if x.script_class.name == name]), None) 36 | if script is None: 37 | continue 38 | 39 | constructor = lambda s=script: ScriptPostprocessingForMainUI(s.script_class()) 40 | res.append(scripts.ScriptClassData(script_class=constructor, path=script.path, basedir=script.basedir, module=script.module)) 41 | 42 | return res 43 | -------------------------------------------------------------------------------- /modules/sd_emphasis.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | import torch 3 | 4 | 5 | class Emphasis: 6 | """Emphasis class decides how to deal with (emphasized:1.1) text in prompts""" 7 | 8 | name: str = "Base" 9 | description: str = "" 10 | 11 | tokens: list[list[int]] 12 | """tokens from the chunk of the prompt""" 13 | 14 | multipliers: torch.Tensor 15 | """tensor with multipliers, one for each token""" 16 | 17 | z: torch.Tensor 18 | """output of cond transformers network (CLIP)""" 19 | 20 | def after_transformers(self): 21 | """Called after cond transformers network has processed the chunk of the prompt; this function should modify self.z to apply the emphasis""" 22 | 23 | pass 24 | 25 | 26 | class EmphasisNone(Emphasis): 27 | name = "None" 28 | description = "disable the mechanism entirely and treat (:1.1) as literal characters" 29 | 30 | 31 | class EmphasisIgnore(Emphasis): 32 | name = "Ignore" 33 | description = "treat all emphasised words as if they have no emphasis" 34 | 35 | 36 | class EmphasisOriginal(Emphasis): 37 | name = "Original" 38 | description = "the original emphasis implementation" 39 | 40 | def after_transformers(self): 41 | original_mean = self.z.mean() 42 | self.z = self.z * self.multipliers.reshape(self.multipliers.shape + (1,)).expand(self.z.shape) 43 | 44 | # restoring original mean is likely not correct, but it seems to work well to prevent artifacts that happen otherwise 45 | new_mean = self.z.mean() 46 | self.z = self.z * (original_mean / new_mean) 47 | 48 | 49 | class EmphasisOriginalNoNorm(EmphasisOriginal): 50 | name = "No norm" 51 | description = "same as original, but without normalization (seems to work better for SDXL)" 52 | 53 | def after_transformers(self): 54 | self.z = self.z * self.multipliers.reshape(self.multipliers.shape + (1,)).expand(self.z.shape) 55 | 56 | 57 | def get_current_option(emphasis_option_name): 58 | return next(iter([x for x in options if x.name == emphasis_option_name]), EmphasisOriginal) 59 | 60 | 61 | def get_options_descriptions(): 62 | return ", ".join(f"{x.name}: {x.description}" for x in options) 63 | 64 | 65 | options = [ 66 | EmphasisNone, 67 | EmphasisIgnore, 68 | EmphasisOriginal, 69 | EmphasisOriginalNoNorm, 70 | ] 71 | -------------------------------------------------------------------------------- /modules/sd_hijack_checkpoint.py: -------------------------------------------------------------------------------- 1 | from torch.utils.checkpoint import checkpoint 2 | 3 | import ldm.modules.attention 4 | import ldm.modules.diffusionmodules.openaimodel 5 | 6 | 7 | def BasicTransformerBlock_forward(self, x, context=None): 8 | return checkpoint(self._forward, x, context) 9 | 10 | 11 | def AttentionBlock_forward(self, x): 12 | return checkpoint(self._forward, x) 13 | 14 | 15 | def ResBlock_forward(self, x, emb): 16 | return checkpoint(self._forward, x, emb) 17 | 18 | 19 | stored = []
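# Usage sketch (an assumption about callers, not code from this module): add()
# swaps the ldm forward methods for gradient-checkpointed versions to trade
# compute for VRAM during training, and remove() restores the originals:
#
#     from modules import sd_hijack_checkpoint
#
#     sd_hijack_checkpoint.add()
#     try:
#         run_training()  # hypothetical training loop
#     finally:
#         sd_hijack_checkpoint.remove()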
20 | 21 | 22 | def add(): 23 | if len(stored) != 0: 24 | return 25 | 26 | stored.extend([ 27 | ldm.modules.attention.BasicTransformerBlock.forward, 28 | ldm.modules.diffusionmodules.openaimodel.ResBlock.forward, 29 | ldm.modules.diffusionmodules.openaimodel.AttentionBlock.forward 30 | ]) 31 | 32 | ldm.modules.attention.BasicTransformerBlock.forward = BasicTransformerBlock_forward 33 | ldm.modules.diffusionmodules.openaimodel.ResBlock.forward = ResBlock_forward 34 | ldm.modules.diffusionmodules.openaimodel.AttentionBlock.forward = AttentionBlock_forward 35 | 36 | 37 | def remove(): 38 | if len(stored) == 0: 39 | return 40 | 41 | ldm.modules.attention.BasicTransformerBlock.forward = stored[0] 42 | ldm.modules.diffusionmodules.openaimodel.ResBlock.forward = stored[1] 43 | ldm.modules.diffusionmodules.openaimodel.AttentionBlock.forward = stored[2] 44 | 45 | stored.clear() 46 | 47 | -------------------------------------------------------------------------------- /modules/sd_hijack_ip2p.py: -------------------------------------------------------------------------------- 1 | import os.path 2 | 3 | 4 | def should_hijack_ip2p(checkpoint_info): 5 | from modules import sd_models_config 6 | 7 | ckpt_basename = os.path.basename(checkpoint_info.filename).lower() 8 | cfg_basename = os.path.basename(sd_models_config.find_checkpoint_config_near_filename(checkpoint_info)).lower() 9 | 10 | return "pix2pix" in ckpt_basename and "pix2pix" not in cfg_basename 11 | -------------------------------------------------------------------------------- /modules/sd_hijack_open_clip.py: -------------------------------------------------------------------------------- 1 | import open_clip.tokenizer 2 | import torch 3 | 4 | from modules import sd_hijack_clip, devices 5 | from modules.shared import opts 6 | 7 | tokenizer = open_clip.tokenizer._tokenizer 8 | 9 | 10 | class FrozenOpenCLIPEmbedderWithCustomWords(sd_hijack_clip.FrozenCLIPEmbedderWithCustomWordsBase): 11 | def __init__(self, wrapped, hijack): 12 | super().__init__(wrapped, hijack) 13 | 14 | self.comma_token = [v for k, v in tokenizer.encoder.items() if k == ',</w>'][0] 15 | self.id_start = tokenizer.encoder["<start_of_text>"] 16 | self.id_end = tokenizer.encoder["<end_of_text>"] 17 | self.id_pad = 0 18 | 19 | def tokenize(self, texts): 20 | assert not opts.use_old_emphasis_implementation, 'Old emphasis implementation not supported for Open Clip' 21 | 22 | tokenized = [tokenizer.encode(text) for text in texts] 23 | 24 | return tokenized 25 | 26 | def encode_with_transformers(self, tokens): 27 | # set self.wrapped.layer_idx here according to opts.CLIP_stop_at_last_layers 28 | z = self.wrapped.encode_with_transformer(tokens) 29 | 30 | return z 31 | 32 | def encode_embedding_init_text(self, init_text, nvpt): 33 | ids = tokenizer.encode(init_text) 34 | ids = torch.asarray([ids], device=devices.device, dtype=torch.int) 35 | embedded = self.wrapped.model.token_embedding.wrapped(ids).squeeze(0) 36 | 37 | return embedded 38 | 39 | 40 | class FrozenOpenCLIPEmbedder2WithCustomWords(sd_hijack_clip.FrozenCLIPEmbedderWithCustomWordsBase): 41 | def __init__(self, wrapped, hijack): 42 | super().__init__(wrapped, hijack) 43 | 44 | self.comma_token = [v for k, v in tokenizer.encoder.items() if k == ',</w>'][0] 45 | self.id_start = tokenizer.encoder["<start_of_text>"] 46 | self.id_end = tokenizer.encoder["<end_of_text>"] 47 | self.id_pad = 0 48 | 49 | def tokenize(self, texts): 50 | assert not opts.use_old_emphasis_implementation, 'Old emphasis implementation not supported for Open Clip' 51 | 52 | tokenized = [tokenizer.encode(text) for text in
texts] 53 | 54 | return tokenized 55 | 56 | def encode_with_transformers(self, tokens): 57 | d = self.wrapped.encode_with_transformer(tokens) 58 | z = d[self.wrapped.layer] 59 | 60 | pooled = d.get("pooled") 61 | if pooled is not None: 62 | z.pooled = pooled 63 | 64 | return z 65 | 66 | def encode_embedding_init_text(self, init_text, nvpt): 67 | ids = tokenizer.encode(init_text) 68 | ids = torch.asarray([ids], device=devices.device, dtype=torch.int) 69 | embedded = self.wrapped.model.token_embedding.wrapped(ids.to(self.wrapped.model.token_embedding.wrapped.weight.device)).squeeze(0) 70 | 71 | return embedded 72 | -------------------------------------------------------------------------------- /modules/sd_hijack_utils.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | 3 | 4 | always_true_func = lambda *args, **kwargs: True 5 | 6 | 7 | class CondFunc: 8 | def __new__(cls, orig_func, sub_func, cond_func=always_true_func): 9 | self = super(CondFunc, cls).__new__(cls) 10 | if isinstance(orig_func, str): 11 | func_path = orig_func.split('.') 12 | for i in range(len(func_path)-1, -1, -1): 13 | try: 14 | resolved_obj = importlib.import_module('.'.join(func_path[:i])) 15 | break 16 | except ImportError: 17 | pass 18 | try: 19 | for attr_name in func_path[i:-1]: 20 | resolved_obj = getattr(resolved_obj, attr_name) 21 | orig_func = getattr(resolved_obj, func_path[-1]) 22 | setattr(resolved_obj, func_path[-1], lambda *args, **kwargs: self(*args, **kwargs)) 23 | except AttributeError: 24 | print(f"Warning: Failed to resolve {orig_func} for CondFunc hijack") 25 | pass 26 | self.__init__(orig_func, sub_func, cond_func) 27 | return lambda *args, **kwargs: self(*args, **kwargs) 28 | def __init__(self, orig_func, sub_func, cond_func): 29 | self.__orig_func = orig_func 30 | self.__sub_func = sub_func 31 | self.__cond_func = cond_func 32 | def __call__(self, *args, **kwargs): 33 | if not self.__cond_func or self.__cond_func(self.__orig_func, *args, **kwargs): 34 | return self.__sub_func(self.__orig_func, *args, **kwargs) 35 | else: 36 | return self.__orig_func(*args, **kwargs) 37 | -------------------------------------------------------------------------------- /modules/sd_hijack_xlmr.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from modules import sd_hijack_clip, devices 4 | 5 | 6 | class FrozenXLMREmbedderWithCustomWords(sd_hijack_clip.FrozenCLIPEmbedderWithCustomWords): 7 | def __init__(self, wrapped, hijack): 8 | super().__init__(wrapped, hijack) 9 | 10 | self.id_start = wrapped.config.bos_token_id 11 | self.id_end = wrapped.config.eos_token_id 12 | self.id_pad = wrapped.config.pad_token_id 13 | 14 | self.comma_token = self.tokenizer.get_vocab().get(',</w>', None) # alt diffusion doesn't have </w> bits for comma 15 | 16 | def encode_with_transformers(self, tokens): 17 | # there's no CLIP Skip here because all hidden layers have size of 1024 and the last one uses a 18 | # trained layer to transform those 1024 into 768 for unet; so you can't choose which transformer 19 | # layer to work with - you have to use the last 20 | 21 | attention_mask = (tokens != self.id_pad).to(device=tokens.device, dtype=torch.int64) 22 | features = self.wrapped(input_ids=tokens, attention_mask=attention_mask) 23 | z = features['projection_state'] 24 | 25 | return z 26 | 27 | def encode_embedding_init_text(self, init_text, nvpt): 28 | embedding_layer = self.wrapped.roberta.embeddings 29 | ids =
self.wrapped.tokenizer(init_text, max_length=nvpt, return_tensors="pt", add_special_tokens=False)["input_ids"] 30 | embedded = embedding_layer.token_embedding.wrapped(ids.to(devices.device)).squeeze(0) 31 | 32 | return embedded 33 | -------------------------------------------------------------------------------- /modules/sd_models_types.py: -------------------------------------------------------------------------------- 1 | from ldm.models.diffusion.ddpm import LatentDiffusion 2 | from typing import TYPE_CHECKING 3 | 4 | 5 | if TYPE_CHECKING: 6 | from modules.sd_models import CheckpointInfo 7 | 8 | 9 | class WebuiSdModel(LatentDiffusion): 10 | """This class is not actually instantiated, but its fields are created and filled by webui""" 11 | 12 | lowvram: bool 13 | """True if lowvram/medvram optimizations are enabled -- see modules.lowvram for more info""" 14 | 15 | sd_model_hash: str 16 | """short hash, 10 first characters of SHA1 hash of the model file; may be None if --no-hashing flag is used""" 17 | 18 | sd_model_checkpoint: str 19 | """path to the file on disk that model weights were obtained from""" 20 | 21 | sd_checkpoint_info: 'CheckpointInfo' 22 | """structure with additional information about the file with model's weights""" 23 | 24 | is_sdxl: bool 25 | """True if the model's architecture is SDXL or SSD""" 26 | 27 | is_ssd: bool 28 | """True if the model is SSD""" 29 | 30 | is_sd2: bool 31 | """True if the model's architecture is SD 2.x""" 32 | 33 | is_sd1: bool 34 | """True if the model's architecture is SD 1.x""" 35 | 36 | is_sd3: bool 37 | """True if the model's architecture is SD 3""" 38 | 39 | latent_channels: int 40 | """number of channels in the latent image representation; will be 16 for SD3 and 4 for other versions""" 41 | -------------------------------------------------------------------------------- /modules/sd_samplers_compvis.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AUTOMATIC1111/stable-diffusion-webui/82a973c04367123ae98bd9abdf80d9eda9b910e2/modules/sd_samplers_compvis.py -------------------------------------------------------------------------------- /modules/sd_unet.py: -------------------------------------------------------------------------------- 1 | import torch.nn 2 | 3 | from modules import script_callbacks, shared, devices 4 | 5 | unet_options = [] 6 | current_unet_option = None 7 | current_unet = None 8 | original_forward = None # not used, only left temporarily for compatibility 9 | 10 | def list_unets(): 11 | new_unets = script_callbacks.list_unets_callback() 12 | 13 | unet_options.clear() 14 | unet_options.extend(new_unets) 15 | 16 | 17 | def get_unet_option(option=None): 18 | option = option or shared.opts.sd_unet 19 | 20 | if option == "None": 21 | return None 22 | 23 | if option == "Automatic": 24 | name = shared.sd_model.sd_checkpoint_info.model_name 25 | 26 | options = [x for x in unet_options if x.model_name == name] 27 | 28 | option = options[0].label if options else "None" 29 | 30 | return next(iter([x for x in unet_options if x.label == option]), None) 31 | 32 | 33 | def apply_unet(option=None): 34 | global current_unet_option 35 | global current_unet 36 | 37 | new_option = get_unet_option(option) 38 | if new_option == current_unet_option: 39 | return 40 | 41 | if current_unet is not None: 42 | print(f"Deactivating unet: {current_unet.option.label}") 43 | current_unet.deactivate() 44 | 45 | current_unet_option = new_option 46 | if current_unet_option is None: 47 | current_unet
= None 48 | 49 | if not shared.sd_model.lowvram: 50 | shared.sd_model.model.diffusion_model.to(devices.device) 51 | 52 | return 53 | 54 | shared.sd_model.model.diffusion_model.to(devices.cpu) 55 | devices.torch_gc() 56 | 57 | current_unet = current_unet_option.create_unet() 58 | current_unet.option = current_unet_option 59 | print(f"Activating unet: {current_unet.option.label}") 60 | current_unet.activate() 61 | 62 | 63 | class SdUnetOption: 64 | model_name = None 65 | """name of related checkpoint - this option will be selected automatically for unet if the name of checkpoint matches this""" 66 | 67 | label = None 68 | """name of the unet in UI""" 69 | 70 | def create_unet(self): 71 | """returns SdUnet object to be used as a Unet instead of built-in unet when making pictures""" 72 | raise NotImplementedError() 73 | 74 | 75 | class SdUnet(torch.nn.Module): 76 | def forward(self, x, timesteps, context, *args, **kwargs): 77 | raise NotImplementedError() 78 | 79 | def activate(self): 80 | pass 81 | 82 | def deactivate(self): 83 | pass 84 | 85 | 86 | def create_unet_forward(original_forward): 87 | def UNetModel_forward(self, x, timesteps=None, context=None, *args, **kwargs): 88 | if current_unet is not None: 89 | return current_unet.forward(x, timesteps, context, *args, **kwargs) 90 | 91 | return original_forward(self, x, timesteps, context, *args, **kwargs) 92 | 93 | return UNetModel_forward 94 | 95 | -------------------------------------------------------------------------------- /modules/shared.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | import gradio as gr 5 | 6 | from modules import shared_cmd_options, shared_gradio_themes, options, shared_items, sd_models_types 7 | from modules.paths_internal import models_path, script_path, data_path, sd_configs_path, sd_default_config, sd_model_file, default_sd_model_file, extensions_dir, extensions_builtin_dir # noqa: F401 8 | from modules import util 9 | from typing import TYPE_CHECKING 10 | 11 | if TYPE_CHECKING: 12 | from modules import shared_state, styles, interrogate, shared_total_tqdm, memmon 13 | 14 | cmd_opts = shared_cmd_options.cmd_opts 15 | parser = shared_cmd_options.parser 16 | 17 | batch_cond_uncond = True # old field, unused now in favor of shared.opts.batch_cond_uncond 18 | parallel_processing_allowed = True 19 | styles_filename = cmd_opts.styles_file = cmd_opts.styles_file if len(cmd_opts.styles_file) > 0 else [os.path.join(data_path, 'styles.csv')] 20 | config_filename = cmd_opts.ui_settings_file 21 | hide_dirs = {"visible": not cmd_opts.hide_ui_dir_config} 22 | 23 | demo: gr.Blocks = None 24 | 25 | device: str = None 26 | 27 | weight_load_location: str = None 28 | 29 | xformers_available = False 30 | 31 | hypernetworks = {} 32 | 33 | loaded_hypernetworks = [] 34 | 35 | state: 'shared_state.State' = None 36 | 37 | prompt_styles: 'styles.StyleDatabase' = None 38 | 39 | interrogator: 'interrogate.InterrogateModels' = None 40 | 41 | face_restorers = [] 42 | 43 | options_templates: dict = None 44 | opts: options.Options = None 45 | restricted_opts: set[str] = None 46 | 47 | sd_model: sd_models_types.WebuiSdModel = None 48 | 49 | settings_components: dict = None 50 | """assigned from ui.py, a mapping on setting names to gradio components responsible for those settings""" 51 | 52 | tab_names = [] 53 | 54 | latent_upscale_default_mode = "Latent" 55 | latent_upscale_modes = { 56 | "Latent": {"mode": "bilinear", "antialias": False}, 57 | "Latent (antialiased)": 
{"mode": "bilinear", "antialias": True}, 58 | "Latent (bicubic)": {"mode": "bicubic", "antialias": False}, 59 | "Latent (bicubic antialiased)": {"mode": "bicubic", "antialias": True}, 60 | "Latent (nearest)": {"mode": "nearest", "antialias": False}, 61 | "Latent (nearest-exact)": {"mode": "nearest-exact", "antialias": False}, 62 | } 63 | 64 | sd_upscalers = [] 65 | 66 | clip_model = None 67 | 68 | progress_print_out = sys.stdout 69 | 70 | gradio_theme = gr.themes.Base() 71 | 72 | total_tqdm: 'shared_total_tqdm.TotalTQDM' = None 73 | 74 | mem_mon: 'memmon.MemUsageMonitor' = None 75 | 76 | options_section = options.options_section 77 | OptionInfo = options.OptionInfo 78 | OptionHTML = options.OptionHTML 79 | 80 | natural_sort_key = util.natural_sort_key 81 | listfiles = util.listfiles 82 | html_path = util.html_path 83 | html = util.html 84 | walk_files = util.walk_files 85 | ldm_print = util.ldm_print 86 | 87 | reload_gradio_theme = shared_gradio_themes.reload_gradio_theme 88 | 89 | list_checkpoint_tiles = shared_items.list_checkpoint_tiles 90 | refresh_checkpoints = shared_items.refresh_checkpoints 91 | list_samplers = shared_items.list_samplers 92 | reload_hypernetworks = shared_items.reload_hypernetworks 93 | 94 | hf_endpoint = os.getenv('HF_ENDPOINT', 'https://huggingface.co') 95 | -------------------------------------------------------------------------------- /modules/shared_cmd_options.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import launch 4 | from modules import cmd_args, script_loading 5 | from modules.paths_internal import models_path, script_path, data_path, sd_configs_path, sd_default_config, sd_model_file, default_sd_model_file, extensions_dir, extensions_builtin_dir # noqa: F401 6 | 7 | parser = cmd_args.parser 8 | 9 | script_loading.preload_extensions(extensions_dir, parser, extension_list=launch.list_extensions(launch.args.ui_settings_file)) 10 | script_loading.preload_extensions(extensions_builtin_dir, parser) 11 | 12 | if os.environ.get('IGNORE_CMD_ARGS_ERRORS', None) is None: 13 | cmd_opts = parser.parse_args() 14 | else: 15 | cmd_opts, _ = parser.parse_known_args() 16 | 17 | cmd_opts.webui_is_non_local = any([cmd_opts.share, cmd_opts.listen, cmd_opts.ngrok, cmd_opts.server_name]) 18 | cmd_opts.disable_extension_access = cmd_opts.webui_is_non_local and not cmd_opts.enable_insecure_extension_access 19 | -------------------------------------------------------------------------------- /modules/shared_init.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | 5 | from modules import shared 6 | from modules.shared import cmd_opts 7 | 8 | 9 | def initialize(): 10 | """Initializes fields inside the shared module in a controlled manner. 11 | 12 | Should be called early because some other modules you can import mingt need these fields to be already set. 
13 | """ 14 | 15 | os.makedirs(cmd_opts.hypernetwork_dir, exist_ok=True) 16 | 17 | from modules import options, shared_options 18 | shared.options_templates = shared_options.options_templates 19 | shared.opts = options.Options(shared_options.options_templates, shared_options.restricted_opts) 20 | shared.restricted_opts = shared_options.restricted_opts 21 | try: 22 | shared.opts.load(shared.config_filename) 23 | except FileNotFoundError: 24 | pass 25 | 26 | from modules import devices 27 | devices.device, devices.device_interrogate, devices.device_gfpgan, devices.device_esrgan, devices.device_codeformer = \ 28 | (devices.cpu if any(y in cmd_opts.use_cpu for y in [x, 'all']) else devices.get_optimal_device() for x in ['sd', 'interrogate', 'gfpgan', 'esrgan', 'codeformer']) 29 | 30 | devices.dtype = torch.float32 if cmd_opts.no_half else torch.float16 31 | devices.dtype_vae = torch.float32 if cmd_opts.no_half or cmd_opts.no_half_vae else torch.float16 32 | devices.dtype_inference = torch.float32 if cmd_opts.precision == 'full' else devices.dtype 33 | 34 | if cmd_opts.precision == "half": 35 | msg = "--no-half and --no-half-vae conflict with --precision half" 36 | assert devices.dtype == torch.float16, msg 37 | assert devices.dtype_vae == torch.float16, msg 38 | assert devices.dtype_inference == torch.float16, msg 39 | devices.force_fp16 = True 40 | devices.force_model_fp16() 41 | 42 | shared.device = devices.device 43 | shared.weight_load_location = None if cmd_opts.lowram else "cpu" 44 | 45 | from modules import shared_state 46 | shared.state = shared_state.State() 47 | 48 | from modules import styles 49 | shared.prompt_styles = styles.StyleDatabase(shared.styles_filename) 50 | 51 | from modules import interrogate 52 | shared.interrogator = interrogate.InterrogateModels("interrogate") 53 | 54 | from modules import shared_total_tqdm 55 | shared.total_tqdm = shared_total_tqdm.TotalTQDM() 56 | 57 | from modules import memmon, devices 58 | shared.mem_mon = memmon.MemUsageMonitor("MemMon", devices.device, shared.opts) 59 | shared.mem_mon.start() 60 | 61 | -------------------------------------------------------------------------------- /modules/shared_total_tqdm.py: -------------------------------------------------------------------------------- 1 | import tqdm 2 | 3 | from modules import shared 4 | 5 | 6 | class TotalTQDM: 7 | def __init__(self): 8 | self._tqdm = None 9 | 10 | def reset(self): 11 | self._tqdm = tqdm.tqdm( 12 | desc="Total progress", 13 | total=shared.state.job_count * shared.state.sampling_steps, 14 | position=1, 15 | file=shared.progress_print_out 16 | ) 17 | 18 | def update(self): 19 | if not shared.opts.multiple_tqdm or shared.cmd_opts.disable_console_progressbars: 20 | return 21 | if self._tqdm is None: 22 | self.reset() 23 | self._tqdm.update() 24 | 25 | def updateTotal(self, new_total): 26 | if not shared.opts.multiple_tqdm or shared.cmd_opts.disable_console_progressbars: 27 | return 28 | if self._tqdm is None: 29 | self.reset() 30 | self._tqdm.total = new_total 31 | 32 | def clear(self): 33 | if self._tqdm is not None: 34 | self._tqdm.refresh() 35 | self._tqdm.close() 36 | self._tqdm = None 37 | 38 | -------------------------------------------------------------------------------- /modules/textual_inversion/learn_schedule.py: -------------------------------------------------------------------------------- 1 | import tqdm 2 | 3 | 4 | class LearnScheduleIterator: 5 | def __init__(self, learn_rate, max_steps, cur_step=0): 6 | """ 7 | specify learn_rate as "0.001:100, 
0.00001:1000, 1e-5:10000" to have lr of 0.001 until step 100, 0.00001 until 1000, and 1e-5 until 10000 8 | """ 9 | 10 | pairs = learn_rate.split(',') 11 | self.rates = [] 12 | self.it = 0 13 | self.maxit = 0 14 | try: 15 | for pair in pairs: 16 | if not pair.strip(): 17 | continue 18 | tmp = pair.split(':') 19 | if len(tmp) == 2: 20 | step = int(tmp[1]) 21 | if step > cur_step: 22 | self.rates.append((float(tmp[0]), min(step, max_steps))) 23 | self.maxit += 1 24 | if step > max_steps: 25 | return 26 | elif step == -1: 27 | self.rates.append((float(tmp[0]), max_steps)) 28 | self.maxit += 1 29 | return 30 | else: 31 | self.rates.append((float(tmp[0]), max_steps)) 32 | self.maxit += 1 33 | return 34 | assert self.rates 35 | except (ValueError, AssertionError) as e: 36 | raise Exception('Invalid learning rate schedule. It should be a number or, for example, like "0.001:100, 0.00001:1000, 1e-5:10000" to have lr of 0.001 until step 100, 0.00001 until 1000, and 1e-5 until 10000.') from e 37 | 38 | 39 | def __iter__(self): 40 | return self 41 | 42 | def __next__(self): 43 | if self.it < self.maxit: 44 | self.it += 1 45 | return self.rates[self.it - 1] 46 | else: 47 | raise StopIteration 48 | 49 | 50 | class LearnRateScheduler: 51 | def __init__(self, learn_rate, max_steps, cur_step=0, verbose=True): 52 | self.schedules = LearnScheduleIterator(learn_rate, max_steps, cur_step) 53 | (self.learn_rate, self.end_step) = next(self.schedules) 54 | self.verbose = verbose 55 | 56 | if self.verbose: 57 | print(f'Training at rate of {self.learn_rate} until step {self.end_step}') 58 | 59 | self.finished = False 60 | 61 | def step(self, step_number): 62 | if step_number < self.end_step: 63 | return False 64 | 65 | try: 66 | (self.learn_rate, self.end_step) = next(self.schedules) 67 | except StopIteration: 68 | self.finished = True 69 | return False 70 | return True 71 | 72 | def apply(self, optimizer, step_number): 73 | if not self.step(step_number): 74 | return 75 | 76 | if self.verbose: 77 | tqdm.tqdm.write(f'Training at rate of {self.learn_rate} until step {self.end_step}') 78 | 79 | for pg in optimizer.param_groups: 80 | pg['lr'] = self.learn_rate 81 | 82 | -------------------------------------------------------------------------------- /modules/textual_inversion/saving_settings.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import json 3 | import os 4 | 5 | saved_params_shared = { 6 | "batch_size", 7 | "clip_grad_mode", 8 | "clip_grad_value", 9 | "create_image_every", 10 | "data_root", 11 | "gradient_step", 12 | "initial_step", 13 | "latent_sampling_method", 14 | "learn_rate", 15 | "log_directory", 16 | "model_hash", 17 | "model_name", 18 | "num_of_dataset_images", 19 | "steps", 20 | "template_file", 21 | "training_height", 22 | "training_width", 23 | } 24 | saved_params_ti = { 25 | "embedding_name", 26 | "num_vectors_per_token", 27 | "save_embedding_every", 28 | "save_image_with_stored_embedding", 29 | } 30 | saved_params_hypernet = { 31 | "activation_func", 32 | "add_layer_norm", 33 | "hypernetwork_name", 34 | "layer_structure", 35 | "save_hypernetwork_every", 36 | "use_dropout", 37 | "weight_init", 38 | } 39 | saved_params_all = saved_params_shared | saved_params_ti | saved_params_hypernet 40 | saved_params_previews = { 41 | "preview_cfg_scale", 42 | "preview_height", 43 | "preview_negative_prompt", 44 | "preview_prompt", 45 | "preview_sampler_index", 46 | "preview_seed", 47 | "preview_steps", 48 | "preview_width", 49 | } 50 | 51 | 52 | def 
save_settings_to_file(log_directory, all_params): 53 | now = datetime.datetime.now() 54 | params = {"datetime": now.strftime("%Y-%m-%d %H:%M:%S")} 55 | 56 | keys = saved_params_all 57 | if all_params.get('preview_from_txt2img'): 58 | keys = keys | saved_params_previews 59 | 60 | params.update({k: v for k, v in all_params.items() if k in keys}) 61 | 62 | filename = f'settings-{now.strftime("%Y-%m-%d-%H-%M-%S")}.json' 63 | with open(os.path.join(log_directory, filename), "w") as file: 64 | json.dump(params, file, indent=4) 65 | -------------------------------------------------------------------------------- /modules/textual_inversion/test_embedding.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AUTOMATIC1111/stable-diffusion-webui/82a973c04367123ae98bd9abdf80d9eda9b910e2/modules/textual_inversion/test_embedding.png -------------------------------------------------------------------------------- /modules/textual_inversion/ui.py: -------------------------------------------------------------------------------- 1 | import html 2 | 3 | import gradio as gr 4 | 5 | import modules.textual_inversion.textual_inversion 6 | from modules import sd_hijack, shared 7 | 8 | 9 | def create_embedding(name, initialization_text, nvpt, overwrite_old): 10 | filename = modules.textual_inversion.textual_inversion.create_embedding(name, nvpt, overwrite_old, init_text=initialization_text) 11 | 12 | sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings() 13 | 14 | return gr.Dropdown.update(choices=sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys())), f"Created: {filename}", "" 15 | 16 | 17 | def train_embedding(*args): 18 | 19 | assert not shared.cmd_opts.lowvram, 'Training models with lowvram not possible' 20 | 21 | apply_optimizations = shared.opts.training_xattention_optimizations 22 | try: 23 | if not apply_optimizations: 24 | sd_hijack.undo_optimizations() 25 | 26 | embedding, filename = modules.textual_inversion.textual_inversion.train_embedding(*args) 27 | 28 | res = f""" 29 | Training {'interrupted' if shared.state.interrupted else 'finished'} at {embedding.step} steps. 
30 | Embedding saved to {html.escape(filename)} 31 | """ 32 | return res, "" 33 | except Exception: 34 | raise 35 | finally: 36 | if not apply_optimizations: 37 | sd_hijack.apply_optimizations() 38 | 39 | -------------------------------------------------------------------------------- /modules/timer.py: -------------------------------------------------------------------------------- 1 | import time 2 | import argparse 3 | 4 | 5 | class TimerSubcategory: 6 | def __init__(self, timer, category): 7 | self.timer = timer 8 | self.category = category 9 | self.start = None 10 | self.original_base_category = timer.base_category 11 | 12 | def __enter__(self): 13 | self.start = time.time() 14 | self.timer.base_category = self.original_base_category + self.category + "/" 15 | self.timer.subcategory_level += 1 16 | 17 | if self.timer.print_log: 18 | print(f"{' ' * self.timer.subcategory_level}{self.category}:") 19 | 20 | def __exit__(self, exc_type, exc_val, exc_tb): 21 | elapsed_for_subcategory = time.time() - self.start 22 | self.timer.base_category = self.original_base_category 23 | self.timer.add_time_to_record(self.original_base_category + self.category, elapsed_for_subcategory) 24 | self.timer.subcategory_level -= 1 25 | self.timer.record(self.category, disable_log=True) 26 | 27 | 28 | class Timer: 29 | def __init__(self, print_log=False): 30 | self.start = time.time() 31 | self.records = {} 32 | self.total = 0 33 | self.base_category = '' 34 | self.print_log = print_log 35 | self.subcategory_level = 0 36 | 37 | def elapsed(self): 38 | end = time.time() 39 | res = end - self.start 40 | self.start = end 41 | return res 42 | 43 | def add_time_to_record(self, category, amount): 44 | if category not in self.records: 45 | self.records[category] = 0 46 | 47 | self.records[category] += amount 48 | 49 | def record(self, category, extra_time=0, disable_log=False): 50 | e = self.elapsed() 51 | 52 | self.add_time_to_record(self.base_category + category, e + extra_time) 53 | 54 | self.total += e + extra_time 55 | 56 | if self.print_log and not disable_log: 57 | print(f"{' ' * self.subcategory_level}{category}: done in {e + extra_time:.3f}s") 58 | 59 | def subcategory(self, name): 60 | self.elapsed() 61 | 62 | subcat = TimerSubcategory(self, name) 63 | return subcat 64 | 65 | def summary(self): 66 | res = f"{self.total:.1f}s" 67 | 68 | additions = [(category, time_taken) for category, time_taken in self.records.items() if time_taken >= 0.1 and '/' not in category] 69 | if not additions: 70 | return res 71 | 72 | res += " (" 73 | res += ", ".join([f"{category}: {time_taken:.1f}s" for category, time_taken in additions]) 74 | res += ")" 75 | 76 | return res 77 | 78 | def dump(self): 79 | return {'total': self.total, 'records': self.records} 80 | 81 | def reset(self): 82 | self.__init__() 83 | 84 | 85 | parser = argparse.ArgumentParser(add_help=False) 86 | parser.add_argument("--log-startup", action='store_true', help="print a detailed log of what's happening at startup") 87 | args = parser.parse_known_args()[0] 88 | 89 | startup_timer = Timer(print_log=args.log_startup) 90 | 91 | startup_record = None 92 | -------------------------------------------------------------------------------- /modules/torch_utils.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import torch.nn 4 | import torch 5 | 6 | 7 | def get_param(model) -> torch.nn.Parameter: 8 | """ 9 | Find the first parameter in a model or module.
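Example (a typical use, mirroring test/test_torch_utils.py further below: read the dtype and device of the returned parameter):

    param = get_param(model)
    dtype, device = param.dtype, param.device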
10 | """ 11 | if hasattr(model, "model") and hasattr(model.model, "parameters"): 12 | # Unpeel a model descriptor to get at the actual Torch module. 13 | model = model.model 14 | 15 | for param in model.parameters(): 16 | return param 17 | 18 | raise ValueError(f"No parameters found in model {model!r}") 19 | 20 | 21 | def float64(t: torch.Tensor): 22 | """return torch.float64 if device is not mps or xpu, else return torch.float32""" 23 | if t.device.type in ['mps', 'xpu']: 24 | return torch.float32 25 | return torch.float64 26 | -------------------------------------------------------------------------------- /modules/ui_extra_networks_checkpoints.py: -------------------------------------------------------------------------------- 1 | import html 2 | import os 3 | 4 | from modules import shared, ui_extra_networks, sd_models 5 | from modules.ui_extra_networks_checkpoints_user_metadata import CheckpointUserMetadataEditor 6 | 7 | 8 | class ExtraNetworksPageCheckpoints(ui_extra_networks.ExtraNetworksPage): 9 | def __init__(self): 10 | super().__init__('Checkpoints') 11 | 12 | self.allow_prompt = False 13 | 14 | def refresh(self): 15 | shared.refresh_checkpoints() 16 | 17 | def create_item(self, name, index=None, enable_filter=True): 18 | checkpoint: sd_models.CheckpointInfo = sd_models.checkpoint_aliases.get(name) 19 | if checkpoint is None: 20 | return 21 | 22 | path, ext = os.path.splitext(checkpoint.filename) 23 | search_terms = [self.search_terms_from_path(checkpoint.filename)] 24 | if checkpoint.sha256: 25 | search_terms.append(checkpoint.sha256) 26 | return { 27 | "name": checkpoint.name_for_extra, 28 | "filename": checkpoint.filename, 29 | "shorthash": checkpoint.shorthash, 30 | "preview": self.find_preview(path), 31 | "description": self.find_description(path), 32 | "search_terms": search_terms, 33 | "onclick": html.escape(f"return selectCheckpoint({ui_extra_networks.quote_js(name)})"), 34 | "local_preview": f"{path}.{shared.opts.samples_format}", 35 | "metadata": checkpoint.metadata, 36 | "sort_keys": {'default': index, **self.get_sort_keys(checkpoint.filename)}, 37 | } 38 | 39 | def list_items(self): 40 | # instantiate a list to protect against concurrent modification 41 | names = list(sd_models.checkpoints_list) 42 | for index, name in enumerate(names): 43 | item = self.create_item(name, index) 44 | if item is not None: 45 | yield item 46 | 47 | def allowed_directories_for_previews(self): 48 | return [v for v in [shared.cmd_opts.ckpt_dir, sd_models.model_path] if v is not None] 49 | 50 | def create_user_metadata_editor(self, ui, tabname): 51 | return CheckpointUserMetadataEditor(ui, tabname, self) 52 | -------------------------------------------------------------------------------- /modules/ui_extra_networks_checkpoints_user_metadata.py: -------------------------------------------------------------------------------- 1 | import gradio as gr 2 | 3 | from modules import ui_extra_networks_user_metadata, sd_vae, shared 4 | from modules.ui_common import create_refresh_button 5 | 6 | 7 | class CheckpointUserMetadataEditor(ui_extra_networks_user_metadata.UserMetadataEditor): 8 | def __init__(self, ui, tabname, page): 9 | super().__init__(ui, tabname, page) 10 | 11 | self.select_vae = None 12 | 13 | def save_user_metadata(self, name, desc, notes, vae): 14 | user_metadata = self.get_user_metadata(name) 15 | user_metadata["description"] = desc 16 | user_metadata["notes"] = notes 17 | user_metadata["vae"] = vae 18 | 19 | self.write_user_metadata(name, user_metadata) 20 | 21 | def 
update_vae(self, name): 22 | if name == shared.sd_model.sd_checkpoint_info.name_for_extra: 23 | sd_vae.reload_vae_weights() 24 | 25 | def put_values_into_components(self, name): 26 | user_metadata = self.get_user_metadata(name) 27 | values = super().put_values_into_components(name) 28 | 29 | return [ 30 | *values[0:5], 31 | user_metadata.get('vae', ''), 32 | ] 33 | 34 | def create_editor(self): 35 | self.create_default_editor_elems() 36 | 37 | with gr.Row(): 38 | self.select_vae = gr.Dropdown(choices=["Automatic", "None"] + list(sd_vae.vae_dict), value="None", label="Preferred VAE", elem_id="checkpoint_edit_user_metadata_preferred_vae") 39 | create_refresh_button(self.select_vae, sd_vae.refresh_vae_list, lambda: {"choices": ["Automatic", "None"] + list(sd_vae.vae_dict)}, "checkpoint_edit_user_metadata_refresh_preferred_vae") 40 | 41 | self.edit_notes = gr.TextArea(label='Notes', lines=4) 42 | 43 | self.create_default_buttons() 44 | 45 | viewed_components = [ 46 | self.edit_name, 47 | self.edit_description, 48 | self.html_filedata, 49 | self.html_preview, 50 | self.edit_notes, 51 | self.select_vae, 52 | ] 53 | 54 | self.button_edit\ 55 | .click(fn=self.put_values_into_components, inputs=[self.edit_name_input], outputs=viewed_components)\ 56 | .then(fn=lambda: gr.update(visible=True), inputs=[], outputs=[self.box]) 57 | 58 | edited_components = [ 59 | self.edit_description, 60 | self.edit_notes, 61 | self.select_vae, 62 | ] 63 | 64 | self.setup_save_handler(self.button_save, self.save_user_metadata, edited_components) 65 | self.button_save.click(fn=self.update_vae, inputs=[self.edit_name_input]) 66 | 67 | -------------------------------------------------------------------------------- /modules/ui_extra_networks_hypernets.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from modules import shared, ui_extra_networks 4 | from modules.ui_extra_networks import quote_js 5 | from modules.hashes import sha256_from_cache 6 | 7 | 8 | class ExtraNetworksPageHypernetworks(ui_extra_networks.ExtraNetworksPage): 9 | def __init__(self): 10 | super().__init__('Hypernetworks') 11 | 12 | def refresh(self): 13 | shared.reload_hypernetworks() 14 | 15 | def create_item(self, name, index=None, enable_filter=True): 16 | full_path = shared.hypernetworks.get(name) 17 | if full_path is None: 18 | return 19 | 20 | path, ext = os.path.splitext(full_path) 21 | sha256 = sha256_from_cache(full_path, f'hypernet/{name}') 22 | shorthash = sha256[0:10] if sha256 else None 23 | search_terms = [self.search_terms_from_path(path)] 24 | if sha256: 25 | search_terms.append(sha256) 26 | return { 27 | "name": name, 28 | "filename": full_path, 29 | "shorthash": shorthash, 30 | "preview": self.find_preview(path), 31 | "description": self.find_description(path), 32 | "search_terms": search_terms, 33 | "prompt": quote_js(f"<hypernet:{name}:") + " + opts.extra_networks_default_multiplier + " + quote_js(">"), 34 | "local_preview": f"{path}.preview.{shared.opts.samples_format}", 35 | "sort_keys": {'default': index, **self.get_sort_keys(path + ext)}, 36 | } 37 | 38 | def list_items(self): 39 | # instantiate a list to protect against concurrent modification 40 | names = list(shared.hypernetworks) 41 | for index, name in enumerate(names): 42 | item = self.create_item(name, index) 43 | if item is not None: 44 | yield item 45 | 46 | def allowed_directories_for_previews(self): 47 | return [shared.cmd_opts.hypernetwork_dir] 48 | 49 | -------------------------------------------------------------------------------- /modules/ui_extra_networks_textual_inversion.py:
-------------------------------------------------------------------------------- 1 | import os 2 | 3 | from modules import ui_extra_networks, sd_hijack, shared 4 | from modules.ui_extra_networks import quote_js 5 | 6 | 7 | class ExtraNetworksPageTextualInversion(ui_extra_networks.ExtraNetworksPage): 8 | def __init__(self): 9 | super().__init__('Textual Inversion') 10 | self.allow_negative_prompt = True 11 | 12 | def refresh(self): 13 | sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings(force_reload=True) 14 | 15 | def create_item(self, name, index=None, enable_filter=True): 16 | embedding = sd_hijack.model_hijack.embedding_db.word_embeddings.get(name) 17 | if embedding is None: 18 | return 19 | 20 | path, ext = os.path.splitext(embedding.filename) 21 | search_terms = [self.search_terms_from_path(embedding.filename)] 22 | if embedding.hash: 23 | search_terms.append(embedding.hash) 24 | return { 25 | "name": name, 26 | "filename": embedding.filename, 27 | "shorthash": embedding.shorthash, 28 | "preview": self.find_preview(path), 29 | "description": self.find_description(path), 30 | "search_terms": search_terms, 31 | "prompt": quote_js(embedding.name), 32 | "local_preview": f"{path}.preview.{shared.opts.samples_format}", 33 | "sort_keys": {'default': index, **self.get_sort_keys(embedding.filename)}, 34 | } 35 | 36 | def list_items(self): 37 | # instantiate a list to protect against concurrent modification 38 | names = list(sd_hijack.model_hijack.embedding_db.word_embeddings) 39 | for index, name in enumerate(names): 40 | item = self.create_item(name, index) 41 | if item is not None: 42 | yield item 43 | 44 | def allowed_directories_for_previews(self): 45 | return list(sd_hijack.model_hijack.embedding_db.embedding_dirs) 46 | -------------------------------------------------------------------------------- /modules/ui_gradio_extensions.py: -------------------------------------------------------------------------------- 1 | import os 2 | import gradio as gr 3 | 4 | from modules import localization, shared, scripts, util 5 | from modules.paths import script_path, data_path 6 | 7 | 8 | def webpath(fn): 9 | return f'file={util.truncate_path(fn)}?{os.path.getmtime(fn)}' 10 | 11 | 12 | def javascript_html(): 13 | # Ensure localization is in `window` before scripts 14 | head = f'<script type="text/javascript">{localization.localization_js(shared.opts.localization)}</script>\n' 15 | 16 | script_js = os.path.join(script_path, "script.js") 17 | head += f'<script type="text/javascript" src="{webpath(script_js)}"></script>\n' 18 | 19 | for script in scripts.list_scripts("javascript", ".js"): 20 | head += f'<script type="text/javascript" src="{webpath(script.path)}"></script>\n' 21 | 22 | for script in scripts.list_scripts("javascript", ".mjs"): 23 | head += f'<script type="module" src="{webpath(script.path)}"></script>\n' 24 | 25 | if shared.cmd_opts.theme: 26 | head += f'<script type="text/javascript">set_theme("{shared.cmd_opts.theme}");</script>\n' 27 | 28 | return head 29 | 30 | 31 | def css_html(): 32 | head = "" 33 | 34 | def stylesheet(fn): 35 | return f'<link rel="stylesheet" property="stylesheet" href="{webpath(fn)}">' 36 | 37 | for cssfile in scripts.list_files_with_name("style.css"): 38 | head += stylesheet(cssfile) 39 | 40 | user_css = os.path.join(data_path, "user.css") 41 | if os.path.exists(user_css): 42 | head += stylesheet(user_css) 43 | 44 | from modules.shared_gradio_themes import resolve_var 45 | light = resolve_var('background_fill_primary') 46 | dark = resolve_var('background_fill_primary_dark') 47 | head += f'<style>html {{ background-color: {light}; }} @media (prefers-color-scheme: dark) {{ html {{ background-color: {dark}; }} }}</style>' 48 | 49 | return head 50 | 51 | 52 | def reload_javascript(): 53 | js = javascript_html() 54 | css = css_html() 55 | 56 | def template_response(*args, **kwargs): 57 | res = shared.GradioTemplateResponseOriginal(*args, **kwargs) 58 | res.body = res.body.replace(b'</head>', f'{js}</head>'.encode("utf8")) 59 | res.body = res.body.replace(b'</body>', f'{css}</body>'.encode("utf8")) 60 | res.init_headers() 61 | return
res 62 | 63 | gr.routes.templates.TemplateResponse = template_response 64 | 65 | 66 | if not hasattr(shared, 'GradioTemplateResponseOriginal'): 67 | shared.GradioTemplateResponseOriginal = gr.routes.templates.TemplateResponse 68 | -------------------------------------------------------------------------------- /package.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "stable-diffusion-webui", 3 | "version": "0.0.0", 4 | "devDependencies": { 5 | "eslint": "^8.40.0" 6 | }, 7 | "scripts": { 8 | "lint": "eslint .", 9 | "fix": "eslint --fix ." 10 | } 11 | } 12 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.ruff] 2 | 3 | target-version = "py39" 4 | 5 | [tool.ruff.lint] 6 | 7 | extend-select = [ 8 | "B", 9 | "C", 10 | "I", 11 | "W", 12 | ] 13 | 14 | exclude = [ 15 | "extensions", 16 | "extensions-disabled", 17 | ] 18 | 19 | ignore = [ 20 | "E501", # Line too long 21 | "E721", # Do not compare types, use `isinstance` 22 | "E731", # Do not assign a `lambda` expression, use a `def` 23 | 24 | "I001", # Import block is un-sorted or un-formatted 25 | "C901", # Function is too complex 26 | "C408", # Rewrite as a literal 27 | "W605", # invalid escape sequence, messes with some docstrings 28 | ] 29 | 30 | [tool.ruff.lint.per-file-ignores] 31 | "webui.py" = ["E402"] # Module level import not at top of file 32 | 33 | [tool.ruff.lint.flake8-bugbear] 34 | # Allow default arguments like, e.g., `data: List[str] = fastapi.Query(None)`. 35 | extend-immutable-calls = ["fastapi.Depends", "fastapi.security.HTTPBasic"] 36 | 37 | [tool.pytest.ini_options] 38 | base_url = "http://127.0.0.1:7860" 39 | -------------------------------------------------------------------------------- /requirements-test.txt: -------------------------------------------------------------------------------- 1 | pytest-base-url~=2.0 2 | pytest-cov~=4.0 3 | pytest~=7.3 4 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | GitPython 2 | Pillow 3 | accelerate 4 | 5 | blendmodes 6 | clean-fid 7 | diskcache 8 | einops 9 | facexlib 10 | fastapi>=0.90.1 11 | gradio==3.41.2 12 | inflection 13 | jsonmerge 14 | kornia 15 | lark 16 | numpy 17 | omegaconf 18 | open-clip-torch 19 | 20 | piexif 21 | protobuf==3.20.0 22 | psutil 23 | pytorch_lightning 24 | requests 25 | resize-right 26 | 27 | safetensors 28 | scikit-image>=0.19 29 | tomesd 30 | torch 31 | torchdiffeq 32 | torchsde 33 | transformers==4.30.2 34 | pillow-avif-plugin==1.4.3 -------------------------------------------------------------------------------- /requirements_npu.txt: -------------------------------------------------------------------------------- 1 | cloudpickle 2 | decorator 3 | synr==0.5.0 4 | tornado 5 | -------------------------------------------------------------------------------- /requirements_versions.txt: -------------------------------------------------------------------------------- 1 | setuptools==69.5.1 # temp fix for compatibility with some old packages 2 | GitPython==3.1.32 3 | Pillow==9.5.0 4 | accelerate==0.21.0 5 | blendmodes==2022 6 | clean-fid==0.1.35 7 | diskcache==5.6.3 8 | einops==0.4.1 9 | facexlib==0.3.0 10 | fastapi==0.94.0 11 | gradio==3.41.2 12 | httpcore==0.15 13 | inflection==0.5.1 14 | jsonmerge==1.8.0 15 | kornia==0.6.7 16 | 
lark==1.1.2 17 | numpy==1.26.2 18 | omegaconf==2.2.3 19 | open-clip-torch==2.20.0 20 | piexif==1.1.3 21 | protobuf==3.20.0 22 | psutil==5.9.5 23 | pytorch_lightning==1.9.4 24 | resize-right==0.0.2 25 | safetensors==0.4.2 26 | scikit-image==0.21.0 27 | spandrel==0.3.4 28 | spandrel-extra-arches==0.1.1 29 | tomesd==0.1.3 30 | torch 31 | torchdiffeq==0.2.3 32 | torchsde==0.2.6 33 | transformers==4.30.2 34 | httpx==0.24.1 35 | pillow-avif-plugin==1.4.3 36 | -------------------------------------------------------------------------------- /screenshot.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AUTOMATIC1111/stable-diffusion-webui/82a973c04367123ae98bd9abdf80d9eda9b910e2/screenshot.png -------------------------------------------------------------------------------- /scripts/custom_code.py: -------------------------------------------------------------------------------- 1 | import modules.scripts as scripts 2 | import gradio as gr 3 | import ast 4 | import copy 5 | 6 | from modules.processing import Processed 7 | from modules.shared import cmd_opts 8 | 9 | 10 | def convertExpr2Expression(expr): 11 | expr.lineno = 0 12 | expr.col_offset = 0 13 | result = ast.Expression(expr.value, lineno=0, col_offset = 0) 14 | 15 | return result 16 | 17 | 18 | def exec_with_return(code, module): 19 | """ 20 | like exec() but can return values 21 | https://stackoverflow.com/a/52361938/5862977 22 | """ 23 | code_ast = ast.parse(code) 24 | 25 | init_ast = copy.deepcopy(code_ast) 26 | init_ast.body = code_ast.body[:-1] 27 | 28 | last_ast = copy.deepcopy(code_ast) 29 | last_ast.body = code_ast.body[-1:] 30 | 31 | exec(compile(init_ast, "<ast>", "exec"), module.__dict__) 32 | if type(last_ast.body[0]) == ast.Expr: 33 | return eval(compile(convertExpr2Expression(last_ast.body[0]), "<ast>", "eval"), module.__dict__) 34 | else: 35 | exec(compile(last_ast, "<ast>", "exec"), module.__dict__) 36 | 37 | 38 | class Script(scripts.Script): 39 | 40 | def title(self): 41 | return "Custom code" 42 | 43 | def show(self, is_img2img): 44 | return cmd_opts.allow_code 45 | 46 | def ui(self, is_img2img): 47 | example = """from modules.processing import process_images 48 | 49 | p.width = 768 50 | p.height = 768 51 | p.batch_size = 2 52 | p.steps = 10 53 | 54 | return process_images(p) 55 | """ 56 | 57 | 58 | code = gr.Code(value=example, language="python", label="Python code", elem_id=self.elem_id("code")) 59 | indent_level = gr.Number(label='Indent level', value=2, precision=0, elem_id=self.elem_id("indent_level")) 60 | 61 | return [code, indent_level] 62 | 63 | def run(self, p, code, indent_level): 64 | assert cmd_opts.allow_code, '--allow-code option must be enabled' 65 | 66 | display_result_data = [[], -1, ""] 67 | 68 | def display(imgs, s=display_result_data[1], i=display_result_data[2]): 69 | display_result_data[0] = imgs 70 | display_result_data[1] = s 71 | display_result_data[2] = i 72 | 73 | from types import ModuleType 74 | module = ModuleType("testmodule") 75 | module.__dict__.update(globals()) 76 | module.p = p 77 | module.display = display 78 | 79 | indent = " " * indent_level 80 | indented = code.replace('\n', f"\n{indent}") 81 | body = f"""def __webuitemp__(): 82 | {indent}{indented} 83 | __webuitemp__()""" 84 | 85 | result = exec_with_return(body, module) 86 | 87 | if isinstance(result, Processed): 88 | return result 89 | 90 | return Processed(p, *display_result_data) 91 | --------------------------------------------------------------------------------
/scripts/postprocessing_codeformer.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | import numpy as np 3 | 4 | from modules import scripts_postprocessing, codeformer_model, ui_components 5 | import gradio as gr 6 | 7 | 8 | class ScriptPostprocessingCodeFormer(scripts_postprocessing.ScriptPostprocessing): 9 | name = "CodeFormer" 10 | order = 3000 11 | 12 | def ui(self): 13 | with ui_components.InputAccordion(False, label="CodeFormer") as enable: 14 | with gr.Row(): 15 | codeformer_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="Visibility", value=1.0, elem_id="extras_codeformer_visibility") 16 | codeformer_weight = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="Weight (0 = maximum effect, 1 = minimum effect)", value=0, elem_id="extras_codeformer_weight") 17 | 18 | return { 19 | "enable": enable, 20 | "codeformer_visibility": codeformer_visibility, 21 | "codeformer_weight": codeformer_weight, 22 | } 23 | 24 | def process(self, pp: scripts_postprocessing.PostprocessedImage, enable, codeformer_visibility, codeformer_weight): 25 | if codeformer_visibility == 0 or not enable: 26 | return 27 | 28 | restored_img = codeformer_model.codeformer.restore(np.array(pp.image.convert("RGB"), dtype=np.uint8), w=codeformer_weight) 29 | res = Image.fromarray(restored_img) 30 | 31 | if codeformer_visibility < 1.0: 32 | res = Image.blend(pp.image, res, codeformer_visibility) 33 | 34 | pp.image = res 35 | pp.info["CodeFormer visibility"] = round(codeformer_visibility, 3) 36 | pp.info["CodeFormer weight"] = round(codeformer_weight, 3) 37 | -------------------------------------------------------------------------------- /scripts/postprocessing_gfpgan.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | import numpy as np 3 | 4 | from modules import scripts_postprocessing, gfpgan_model, ui_components 5 | import gradio as gr 6 | 7 | 8 | class ScriptPostprocessingGfpGan(scripts_postprocessing.ScriptPostprocessing): 9 | name = "GFPGAN" 10 | order = 2000 11 | 12 | def ui(self): 13 | with ui_components.InputAccordion(False, label="GFPGAN") as enable: 14 | gfpgan_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="Visibility", value=1.0, elem_id="extras_gfpgan_visibility") 15 | 16 | return { 17 | "enable": enable, 18 | "gfpgan_visibility": gfpgan_visibility, 19 | } 20 | 21 | def process(self, pp: scripts_postprocessing.PostprocessedImage, enable, gfpgan_visibility): 22 | if gfpgan_visibility == 0 or not enable: 23 | return 24 | 25 | restored_img = gfpgan_model.gfpgan_fix_faces(np.array(pp.image.convert("RGB"), dtype=np.uint8)) 26 | res = Image.fromarray(restored_img) 27 | 28 | if gfpgan_visibility < 1.0: 29 | res = Image.blend(pp.image, res, gfpgan_visibility) 30 | 31 | pp.image = res 32 | pp.info["GFPGAN visibility"] = round(gfpgan_visibility, 3) 33 | -------------------------------------------------------------------------------- /test/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AUTOMATIC1111/stable-diffusion-webui/82a973c04367123ae98bd9abdf80d9eda9b910e2/test/__init__.py -------------------------------------------------------------------------------- /test/conftest.py: -------------------------------------------------------------------------------- 1 | import base64 2 | import os 3 | 4 | import pytest 5 | 6 | test_files_path = os.path.dirname(__file__) + 
"/test_files" 7 | test_outputs_path = os.path.dirname(__file__) + "/test_outputs" 8 | 9 | 10 | def pytest_configure(config): 11 | # We don't want to fail on Py.test command line arguments being 12 | # parsed by webui: 13 | os.environ.setdefault("IGNORE_CMD_ARGS_ERRORS", "1") 14 | 15 | 16 | def file_to_base64(filename): 17 | with open(filename, "rb") as file: 18 | data = file.read() 19 | 20 | base64_str = str(base64.b64encode(data), "utf-8") 21 | return "data:image/png;base64," + base64_str 22 | 23 | 24 | @pytest.fixture(scope="session") # session so we don't read this over and over 25 | def img2img_basic_image_base64() -> str: 26 | return file_to_base64(os.path.join(test_files_path, "img2img_basic.png")) 27 | 28 | 29 | @pytest.fixture(scope="session") # session so we don't read this over and over 30 | def mask_basic_image_base64() -> str: 31 | return file_to_base64(os.path.join(test_files_path, "mask_basic.png")) 32 | 33 | 34 | @pytest.fixture(scope="session") 35 | def initialize() -> None: 36 | import webui # noqa: F401 37 | -------------------------------------------------------------------------------- /test/test_extras.py: -------------------------------------------------------------------------------- 1 | import requests 2 | 3 | 4 | def test_simple_upscaling_performed(base_url, img2img_basic_image_base64): 5 | payload = { 6 | "resize_mode": 0, 7 | "show_extras_results": True, 8 | "gfpgan_visibility": 0, 9 | "codeformer_visibility": 0, 10 | "codeformer_weight": 0, 11 | "upscaling_resize": 2, 12 | "upscaling_resize_w": 128, 13 | "upscaling_resize_h": 128, 14 | "upscaling_crop": True, 15 | "upscaler_1": "Lanczos", 16 | "upscaler_2": "None", 17 | "extras_upscaler_2_visibility": 0, 18 | "image": img2img_basic_image_base64, 19 | } 20 | assert requests.post(f"{base_url}/sdapi/v1/extra-single-image", json=payload).status_code == 200 21 | 22 | 23 | def test_png_info_performed(base_url, img2img_basic_image_base64): 24 | payload = { 25 | "image": img2img_basic_image_base64, 26 | } 27 | assert requests.post(f"{base_url}/sdapi/v1/extra-single-image", json=payload).status_code == 200 28 | 29 | 30 | def test_interrogate_performed(base_url, img2img_basic_image_base64): 31 | payload = { 32 | "image": img2img_basic_image_base64, 33 | "model": "clip", 34 | } 35 | assert requests.post(f"{base_url}/sdapi/v1/extra-single-image", json=payload).status_code == 200 36 | -------------------------------------------------------------------------------- /test/test_face_restorers.py: -------------------------------------------------------------------------------- 1 | import os 2 | from test.conftest import test_files_path, test_outputs_path 3 | 4 | import numpy as np 5 | import pytest 6 | from PIL import Image 7 | 8 | 9 | @pytest.mark.usefixtures("initialize") 10 | @pytest.mark.parametrize("restorer_name", ["gfpgan", "codeformer"]) 11 | def test_face_restorers(restorer_name): 12 | from modules import shared 13 | 14 | if restorer_name == "gfpgan": 15 | from modules import gfpgan_model 16 | gfpgan_model.setup_model(shared.cmd_opts.gfpgan_models_path) 17 | restorer = gfpgan_model.gfpgan_fix_faces 18 | elif restorer_name == "codeformer": 19 | from modules import codeformer_model 20 | codeformer_model.setup_model(shared.cmd_opts.codeformer_models_path) 21 | restorer = codeformer_model.codeformer.restore 22 | else: 23 | raise NotImplementedError("...") 24 | img = Image.open(os.path.join(test_files_path, "two-faces.jpg")) 25 | np_img = np.array(img, dtype=np.uint8) 26 | fixed_image = restorer(np_img) 27 | assert 
fixed_image.shape == np_img.shape 28 | assert not np.allclose(fixed_image, np_img) # should have visibly changed 29 | Image.fromarray(fixed_image).save(os.path.join(test_outputs_path, f"{restorer_name}.png")) 30 | -------------------------------------------------------------------------------- /test/test_files/empty.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AUTOMATIC1111/stable-diffusion-webui/82a973c04367123ae98bd9abdf80d9eda9b910e2/test/test_files/empty.pt -------------------------------------------------------------------------------- /test/test_files/img2img_basic.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AUTOMATIC1111/stable-diffusion-webui/82a973c04367123ae98bd9abdf80d9eda9b910e2/test/test_files/img2img_basic.png -------------------------------------------------------------------------------- /test/test_files/mask_basic.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AUTOMATIC1111/stable-diffusion-webui/82a973c04367123ae98bd9abdf80d9eda9b910e2/test/test_files/mask_basic.png -------------------------------------------------------------------------------- /test/test_files/two-faces.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AUTOMATIC1111/stable-diffusion-webui/82a973c04367123ae98bd9abdf80d9eda9b910e2/test/test_files/two-faces.jpg -------------------------------------------------------------------------------- /test/test_img2img.py: -------------------------------------------------------------------------------- 1 | 2 | import pytest 3 | import requests 4 | 5 | 6 | @pytest.fixture() 7 | def url_img2img(base_url): 8 | return f"{base_url}/sdapi/v1/img2img" 9 | 10 | 11 | @pytest.fixture() 12 | def simple_img2img_request(img2img_basic_image_base64): 13 | return { 14 | "batch_size": 1, 15 | "cfg_scale": 7, 16 | "denoising_strength": 0.75, 17 | "eta": 0, 18 | "height": 64, 19 | "include_init_images": False, 20 | "init_images": [img2img_basic_image_base64], 21 | "inpaint_full_res": False, 22 | "inpaint_full_res_padding": 0, 23 | "inpainting_fill": 0, 24 | "inpainting_mask_invert": False, 25 | "mask": None, 26 | "mask_blur": 4, 27 | "n_iter": 1, 28 | "negative_prompt": "", 29 | "override_settings": {}, 30 | "prompt": "example prompt", 31 | "resize_mode": 0, 32 | "restore_faces": False, 33 | "s_churn": 0, 34 | "s_noise": 1, 35 | "s_tmax": 0, 36 | "s_tmin": 0, 37 | "sampler_index": "Euler a", 38 | "seed": -1, 39 | "seed_resize_from_h": -1, 40 | "seed_resize_from_w": -1, 41 | "steps": 3, 42 | "styles": [], 43 | "subseed": -1, 44 | "subseed_strength": 0, 45 | "tiling": False, 46 | "width": 64, 47 | } 48 | 49 | 50 | def test_img2img_simple_performed(url_img2img, simple_img2img_request): 51 | assert requests.post(url_img2img, json=simple_img2img_request).status_code == 200 52 | 53 | 54 | def test_inpainting_masked_performed(url_img2img, simple_img2img_request, mask_basic_image_base64): 55 | simple_img2img_request["mask"] = mask_basic_image_base64 56 | assert requests.post(url_img2img, json=simple_img2img_request).status_code == 200 57 | 58 | 59 | def test_inpainting_with_inverted_masked_performed(url_img2img, simple_img2img_request, mask_basic_image_base64): 60 | simple_img2img_request["mask"] = mask_basic_image_base64 61 | simple_img2img_request["inpainting_mask_invert"] = True 62 | assert 
requests.post(url_img2img, json=simple_img2img_request).status_code == 200 63 | 64 | 65 | def test_img2img_sd_upscale_performed(url_img2img, simple_img2img_request): 66 | simple_img2img_request["script_name"] = "sd upscale" 67 | simple_img2img_request["script_args"] = ["", 8, "Lanczos", 2.0] 68 | assert requests.post(url_img2img, json=simple_img2img_request).status_code == 200 69 | -------------------------------------------------------------------------------- /test/test_outputs/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AUTOMATIC1111/stable-diffusion-webui/82a973c04367123ae98bd9abdf80d9eda9b910e2/test/test_outputs/.gitkeep -------------------------------------------------------------------------------- /test/test_torch_utils.py: -------------------------------------------------------------------------------- 1 | import types 2 | 3 | import pytest 4 | import torch 5 | 6 | from modules import torch_utils 7 | 8 | 9 | @pytest.mark.parametrize("wrapped", [True, False]) 10 | def test_get_param(wrapped): 11 | mod = torch.nn.Linear(1, 1) 12 | cpu = torch.device("cpu") 13 | mod.to(dtype=torch.float16, device=cpu) 14 | if wrapped: 15 | # more or less how spandrel wraps a thing 16 | mod = types.SimpleNamespace(model=mod) 17 | p = torch_utils.get_param(mod) 18 | assert p.dtype == torch.float16 19 | assert p.device == cpu 20 | -------------------------------------------------------------------------------- /test/test_utils.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import requests 3 | 4 | 5 | def test_options_write(base_url): 6 | url_options = f"{base_url}/sdapi/v1/options" 7 | response = requests.get(url_options) 8 | assert response.status_code == 200 9 | 10 | pre_value = response.json()["send_seed"] 11 | 12 | assert requests.post(url_options, json={'send_seed': (not pre_value)}).status_code == 200 13 | 14 | response = requests.get(url_options) 15 | assert response.status_code == 200 16 | assert response.json()['send_seed'] == (not pre_value) 17 | 18 | requests.post(url_options, json={"send_seed": pre_value}) 19 | 20 | 21 | @pytest.mark.parametrize("url", [ 22 | "sdapi/v1/cmd-flags", 23 | "sdapi/v1/samplers", 24 | "sdapi/v1/upscalers", 25 | "sdapi/v1/sd-models", 26 | "sdapi/v1/hypernetworks", 27 | "sdapi/v1/face-restorers", 28 | "sdapi/v1/realesrgan-models", 29 | "sdapi/v1/prompt-styles", 30 | "sdapi/v1/embeddings", 31 | ]) 32 | def test_get_api_url(base_url, url): 33 | assert requests.get(f"{base_url}/{url}").status_code == 200 34 | -------------------------------------------------------------------------------- /textual_inversion_templates/hypernetwork.txt: -------------------------------------------------------------------------------- 1 | a photo of a [filewords] 2 | a rendering of a [filewords] 3 | a cropped photo of the [filewords] 4 | the photo of a [filewords] 5 | a photo of a clean [filewords] 6 | a photo of a dirty [filewords] 7 | a dark photo of the [filewords] 8 | a photo of my [filewords] 9 | a photo of the cool [filewords] 10 | a close-up photo of a [filewords] 11 | a bright photo of the [filewords] 12 | a cropped photo of a [filewords] 13 | a photo of the [filewords] 14 | a good photo of the [filewords] 15 | a photo of one [filewords] 16 | a close-up photo of the [filewords] 17 | a rendition of the [filewords] 18 | a photo of the clean [filewords] 19 | a rendition of a [filewords] 20 | a photo of a nice [filewords] 21 | a good photo of a 
[filewords] 22 | a photo of the nice [filewords] 23 | a photo of the small [filewords] 24 | a photo of the weird [filewords] 25 | a photo of the large [filewords] 26 | a photo of a cool [filewords] 27 | a photo of a small [filewords] 28 | -------------------------------------------------------------------------------- /textual_inversion_templates/none.txt: -------------------------------------------------------------------------------- 1 | picture 2 | -------------------------------------------------------------------------------- /textual_inversion_templates/style.txt: -------------------------------------------------------------------------------- 1 | a painting, art by [name] 2 | a rendering, art by [name] 3 | a cropped painting, art by [name] 4 | the painting, art by [name] 5 | a clean painting, art by [name] 6 | a dirty painting, art by [name] 7 | a dark painting, art by [name] 8 | a picture, art by [name] 9 | a cool painting, art by [name] 10 | a close-up painting, art by [name] 11 | a bright painting, art by [name] 12 | a cropped painting, art by [name] 13 | a good painting, art by [name] 14 | a close-up painting, art by [name] 15 | a rendition, art by [name] 16 | a nice painting, art by [name] 17 | a small painting, art by [name] 18 | a weird painting, art by [name] 19 | a large painting, art by [name] 20 | -------------------------------------------------------------------------------- /textual_inversion_templates/style_filewords.txt: -------------------------------------------------------------------------------- 1 | a painting of [filewords], art by [name] 2 | a rendering of [filewords], art by [name] 3 | a cropped painting of [filewords], art by [name] 4 | the painting of [filewords], art by [name] 5 | a clean painting of [filewords], art by [name] 6 | a dirty painting of [filewords], art by [name] 7 | a dark painting of [filewords], art by [name] 8 | a picture of [filewords], art by [name] 9 | a cool painting of [filewords], art by [name] 10 | a close-up painting of [filewords], art by [name] 11 | a bright painting of [filewords], art by [name] 12 | a cropped painting of [filewords], art by [name] 13 | a good painting of [filewords], art by [name] 14 | a close-up painting of [filewords], art by [name] 15 | a rendition of [filewords], art by [name] 16 | a nice painting of [filewords], art by [name] 17 | a small painting of [filewords], art by [name] 18 | a weird painting of [filewords], art by [name] 19 | a large painting of [filewords], art by [name] 20 | -------------------------------------------------------------------------------- /textual_inversion_templates/subject.txt: -------------------------------------------------------------------------------- 1 | a photo of a [name] 2 | a rendering of a [name] 3 | a cropped photo of the [name] 4 | the photo of a [name] 5 | a photo of a clean [name] 6 | a photo of a dirty [name] 7 | a dark photo of the [name] 8 | a photo of my [name] 9 | a photo of the cool [name] 10 | a close-up photo of a [name] 11 | a bright photo of the [name] 12 | a cropped photo of a [name] 13 | a photo of the [name] 14 | a good photo of the [name] 15 | a photo of one [name] 16 | a close-up photo of the [name] 17 | a rendition of the [name] 18 | a photo of the clean [name] 19 | a rendition of a [name] 20 | a photo of a nice [name] 21 | a good photo of a [name] 22 | a photo of the nice [name] 23 | a photo of the small [name] 24 | a photo of the weird [name] 25 | a photo of the large [name] 26 | a photo of a cool [name] 27 | a photo of a small [name] 28 | 
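The [name] and [filewords] markers in these template files are placeholders: at training time, [name] is replaced with the embedding or hypernetwork name and [filewords] with the caption text stored alongside the current dataset image (the substitution is done in modules/textual_inversion/dataset.py). A minimal sketch of that replacement, with hypothetical values:

    # Hypothetical values; the real replacement happens per-image during training.
    template = "a photo of a [name], [filewords]"
    prompt = template.replace("[name]", "my-embedding").replace("[filewords]", "a red car")
    print(prompt)  # -> "a photo of a my-embedding, a red car"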
-------------------------------------------------------------------------------- /textual_inversion_templates/subject_filewords.txt: -------------------------------------------------------------------------------- 1 | a photo of a [name], [filewords] 2 | a rendering of a [name], [filewords] 3 | a cropped photo of the [name], [filewords] 4 | the photo of a [name], [filewords] 5 | a photo of a clean [name], [filewords] 6 | a photo of a dirty [name], [filewords] 7 | a dark photo of the [name], [filewords] 8 | a photo of my [name], [filewords] 9 | a photo of the cool [name], [filewords] 10 | a close-up photo of a [name], [filewords] 11 | a bright photo of the [name], [filewords] 12 | a cropped photo of a [name], [filewords] 13 | a photo of the [name], [filewords] 14 | a good photo of the [name], [filewords] 15 | a photo of one [name], [filewords] 16 | a close-up photo of the [name], [filewords] 17 | a rendition of the [name], [filewords] 18 | a photo of the clean [name], [filewords] 19 | a rendition of a [name], [filewords] 20 | a photo of a nice [name], [filewords] 21 | a good photo of a [name], [filewords] 22 | a photo of the nice [name], [filewords] 23 | a photo of the small [name], [filewords] 24 | a photo of the weird [name], [filewords] 25 | a photo of the large [name], [filewords] 26 | a photo of a cool [name], [filewords] 27 | a photo of a small [name], [filewords] 28 | -------------------------------------------------------------------------------- /webui-macos-env.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #################################################################### 3 | # macOS defaults # 4 | # Please modify webui-user.sh to change these instead of this file # 5 | #################################################################### 6 | 7 | export install_dir="$HOME" 8 | export COMMANDLINE_ARGS="--skip-torch-cuda-test --upcast-sampling --no-half-vae --use-cpu interrogate" 9 | export PYTORCH_ENABLE_MPS_FALLBACK=1 10 | 11 | if [[ "$(sysctl -n machdep.cpu.brand_string)" =~ ^.*"Intel".*$ ]]; then 12 | export TORCH_COMMAND="pip install torch==2.1.2 torchvision==0.16.2" 13 | else 14 | export TORCH_COMMAND="pip install torch==2.3.1 torchvision==0.18.1" 15 | fi 16 | 17 | #################################################################### 18 | -------------------------------------------------------------------------------- /webui-user.bat: -------------------------------------------------------------------------------- 1 | @echo off 2 | 3 | set PYTHON= 4 | set GIT= 5 | set VENV_DIR= 6 | set COMMANDLINE_ARGS= 7 | 8 | call webui.bat 9 | -------------------------------------------------------------------------------- /webui-user.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | ######################################################### 3 | # Uncomment and change the variables below to your need:# 4 | ######################################################### 5 | 6 | # Install directory without trailing slash 7 | #install_dir="/home/$(whoami)" 8 | 9 | # Name of the subdirectory 10 | #clone_dir="stable-diffusion-webui" 11 | 12 | # Commandline arguments for webui.py, for example: export COMMANDLINE_ARGS="--medvram --opt-split-attention" 13 | #export COMMANDLINE_ARGS="" 14 | 15 | # python3 executable 16 | #python_cmd="python3" 17 | 18 | # git executable 19 | #export GIT="git" 20 | 21 | # python3 venv without trailing slash (defaults to ${install_dir}/${clone_dir}/venv) 22 | 
#venv_dir="venv" 23 | 24 | # script to launch to start the app 25 | #export LAUNCH_SCRIPT="launch.py" 26 | 27 | # install command for torch 28 | #export TORCH_COMMAND="pip install torch==1.12.1+cu113 --extra-index-url https://download.pytorch.org/whl/cu113" 29 | 30 | # Requirements file to use for stable-diffusion-webui 31 | #export REQS_FILE="requirements_versions.txt" 32 | 33 | # Fixed git repos 34 | #export K_DIFFUSION_PACKAGE="" 35 | #export GFPGAN_PACKAGE="" 36 | 37 | # Fixed git commits 38 | #export STABLE_DIFFUSION_COMMIT_HASH="" 39 | #export CODEFORMER_COMMIT_HASH="" 40 | #export BLIP_COMMIT_HASH="" 41 | 42 | # Uncomment to enable accelerated launch 43 | #export ACCELERATE="True" 44 | 45 | # Uncomment to disable TCMalloc 46 | #export NO_TCMALLOC="True" 47 | 48 | ########################################### 49 | -------------------------------------------------------------------------------- /webui.bat: -------------------------------------------------------------------------------- 1 | @echo off 2 | 3 | if exist webui.settings.bat ( 4 | call webui.settings.bat 5 | ) 6 | 7 | if not defined PYTHON (set PYTHON=python) 8 | if defined GIT (set "GIT_PYTHON_GIT_EXECUTABLE=%GIT%") 9 | if not defined VENV_DIR (set "VENV_DIR=%~dp0%venv") 10 | 11 | set SD_WEBUI_RESTART=tmp/restart 12 | set ERROR_REPORTING=FALSE 13 | 14 | mkdir tmp 2>NUL 15 | 16 | %PYTHON% -c "" >tmp/stdout.txt 2>tmp/stderr.txt 17 | if %ERRORLEVEL% == 0 goto :check_pip 18 | echo Couldn't launch python 19 | goto :show_stdout_stderr 20 | 21 | :check_pip 22 | %PYTHON% -mpip --help >tmp/stdout.txt 2>tmp/stderr.txt 23 | if %ERRORLEVEL% == 0 goto :start_venv 24 | if "%PIP_INSTALLER_LOCATION%" == "" goto :show_stdout_stderr 25 | %PYTHON% "%PIP_INSTALLER_LOCATION%" >tmp/stdout.txt 2>tmp/stderr.txt 26 | if %ERRORLEVEL% == 0 goto :start_venv 27 | echo Couldn't install pip 28 | goto :show_stdout_stderr 29 | 30 | :start_venv 31 | if ["%VENV_DIR%"] == ["-"] goto :skip_venv 32 | if ["%SKIP_VENV%"] == ["1"] goto :skip_venv 33 | 34 | dir "%VENV_DIR%\Scripts\Python.exe" >tmp/stdout.txt 2>tmp/stderr.txt 35 | if %ERRORLEVEL% == 0 goto :activate_venv 36 | 37 | for /f "delims=" %%i in ('CALL %PYTHON% -c "import sys; print(sys.executable)"') do set PYTHON_FULLNAME="%%i" 38 | echo Creating venv in directory %VENV_DIR% using python %PYTHON_FULLNAME% 39 | %PYTHON_FULLNAME% -m venv "%VENV_DIR%" >tmp/stdout.txt 2>tmp/stderr.txt 40 | if %ERRORLEVEL% == 0 goto :upgrade_pip 41 | echo Unable to create venv in directory "%VENV_DIR%" 42 | goto :show_stdout_stderr 43 | 44 | :upgrade_pip 45 | "%VENV_DIR%\Scripts\Python.exe" -m pip install --upgrade pip 46 | if %ERRORLEVEL% == 0 goto :activate_venv 47 | echo Warning: Failed to upgrade PIP version 48 | 49 | :activate_venv 50 | set PYTHON="%VENV_DIR%\Scripts\Python.exe" 51 | call "%VENV_DIR%\Scripts\activate.bat" 52 | echo venv %PYTHON% 53 | 54 | :skip_venv 55 | if [%ACCELERATE%] == ["True"] goto :accelerate 56 | goto :launch 57 | 58 | :accelerate 59 | echo Checking for accelerate 60 | set ACCELERATE="%VENV_DIR%\Scripts\accelerate.exe" 61 | if EXIST %ACCELERATE% goto :accelerate_launch 62 | 63 | :launch 64 | %PYTHON% launch.py %* 65 | if EXIST tmp/restart goto :skip_venv 66 | pause 67 | exit /b 68 | 69 | :accelerate_launch 70 | echo Accelerating 71 | %ACCELERATE% launch --num_cpu_threads_per_process=6 launch.py 72 | if EXIST tmp/restart goto :skip_venv 73 | pause 74 | exit /b 75 | 76 | :show_stdout_stderr 77 | 78 | echo. 
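REM Show whatever the failed step wrote to tmp\stdout.txt and tmp\stderr.txt, skipping empty files.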
79 | echo exit code: %errorlevel% 80 | 81 | for /f %%i in ("tmp\stdout.txt") do set size=%%~zi 82 | if %size% equ 0 goto :show_stderr 83 | echo. 84 | echo stdout: 85 | type tmp\stdout.txt 86 | 87 | :show_stderr 88 | for /f %%i in ("tmp\stderr.txt") do set size=%%~zi 89 | if %size% equ 0 goto :endofscript 90 | echo. 91 | echo stderr: 92 | type tmp\stderr.txt 93 | 94 | :endofscript 95 | 96 | echo. 97 | echo Launch unsuccessful. Exiting. 98 | pause 99 | --------------------------------------------------------------------------------