├── README.md
├── models.yaml
├── globals.py
├── LICENSE
├── install.sh
├── invoke_ai_gui_colab.ipynb
└── cross_attention_control.py

/README.md:
--------------------------------------------------------------------------------
1 | # invoke-ai-gui-colab
2 | 
3 | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/peaashmeter/invoke-ai-gui-colab/blob/main/invoke_ai_gui_colab.ipynb)
4 | 
5 | :ru: Этот код позволяет запускать веб-интерфейс [Invoke ai](https://github.com/invoke-ai/InvokeAI) в Google Colab. В процессе установки автоматически подгружается аниме-модель [Anything](https://huggingface.co/andite/anything-v4.0).
6 | 
7 | :us: It's a Colab adaptation for easy & automatic setup of the [Invoke ai](https://github.com/invoke-ai/InvokeAI) web GUI. It uses [Anything](https://huggingface.co/andite/anything-v4.0) – a model for weebs – because I personally love it. Enjoy ^_^
8 | 
--------------------------------------------------------------------------------
/models.yaml:
--------------------------------------------------------------------------------
1 | # This file describes the alternative machine learning models
2 | # available to the InvokeAI script.
3 | #
4 | # To add a new model, follow the examples below. Each
5 | # model requires a model config file, a weights file,
6 | # and the width and height of the images it
7 | # was trained on.
8 | Anything_v5_vaefixed:
9 |   weights: models/ldm/stable-diffusion-v1/Anything-V5.ckpt
10 |   description: Cool anime model
11 |   config: configs/stable-diffusion/v1-inference.yaml
12 |   width: 512
13 |   height: 512
14 |   vae: models/ldm/stable-diffusion-v1/Anything-V5.ckpt
15 |   default: true
16 | Anything_v5_default:
17 |   weights: models/ldm/stable-diffusion-v1/Anything-V5.ckpt
18 |   description: Cool anime model
19 |   config: configs/stable-diffusion/v1-inference.yaml
20 |   width: 512
21 |   height: 512
22 |   vae: models/ldm/stable-diffusion-v1/vae-ft-mse-840000-ema-pruned.ckpt
--------------------------------------------------------------------------------
/globals.py:
--------------------------------------------------------------------------------
1 | '''
2 | ldm.invoke.globals defines a small number of global variables that would
3 | otherwise have to be passed through long and complex call chains.
4 | 
5 | It defines a Namespace object named "Globals" that contains
6 | the attributes:
7 | 
8 |   - root - the root directory under which "models" and "outputs" can be found
9 |   - initfile - path to the initialization file
10 |   - try_patchmatch - option to globally disable loading of the 'patchmatch' module
11 | '''
12 | 
13 | import os
14 | import os.path as osp
15 | from argparse import Namespace
16 | 
17 | Globals = Namespace()
18 | 
19 | # This is usually overwritten by the command line and/or environment variables
20 | Globals.root = osp.abspath('/root/invokeai')
21 | 
22 | # Where to look for the initialization file
23 | Globals.initfile = 'invokeai.init'
24 | 
25 | # Awkward workaround to disable attempted loading of pypatchmatch,
26 | # which is causing CI tests to error out.
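# (True keeps the default behaviour of attempting to load patchmatch; set this to False to skip the import)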
27 | Globals.try_patchmatch = True 28 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Ivan Yuriev 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /install.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # This script will install git and conda (if not found on the PATH variable) 4 | # using micromamba (an 8mb static-linked single-file binary, conda replacement). 5 | # For users who already have git and conda, this step will be skipped. 6 | 7 | # Next, it'll checkout the project's git repo, if necessary. 8 | # Finally, it'll create the conda environment and configure InvokeAI. 9 | 10 | # This enables a user to install this project without manually installing conda and git. 11 | 12 | cd "$(dirname "${BASH_SOURCE[0]}")" 13 | 14 | echo "Заходит как-то улитка в бар..." 15 | 16 | # config 17 | export MAMBA_ROOT_PREFIX="$(pwd)/installer_files/mamba" 18 | INSTALL_ENV_DIR="$(pwd)/installer_files/env" 19 | MICROMAMBA_DOWNLOAD_URL="https://micro.mamba.pm/api/micromamba/linux-64/latest" 20 | REPO_URL="https://github.com/invoke-ai/InvokeAI.git" 21 | umamba_exists="F" 22 | 23 | # figure out whether git and conda needs to be installed 24 | if [ -e "$INSTALL_ENV_DIR" ]; then export PATH="$INSTALL_ENV_DIR/bin:$PATH"; fi 25 | 26 | PACKAGES_TO_INSTALL="" 27 | if ! $(which conda) -V &>/dev/null; then PACKAGES_TO_INSTALL="$PACKAGES_TO_INSTALL conda"; fi 28 | if ! 
which git &>/dev/null; then PACKAGES_TO_INSTALL="$PACKAGES_TO_INSTALL git"; fi
29 | 
30 | if "$MAMBA_ROOT_PREFIX/micromamba" --version &>/dev/null; then umamba_exists="T"; fi
31 | 
32 | # (if necessary) install git and conda into a contained environment
33 | if [ "$PACKAGES_TO_INSTALL" != "" ]; then
34 |     # download micromamba
35 |     if [ "$umamba_exists" == "F" ]; then
36 |         echo "Downloading micromamba from $MICROMAMBA_DOWNLOAD_URL to $MAMBA_ROOT_PREFIX/micromamba"
37 | 
38 |         mkdir -p "$MAMBA_ROOT_PREFIX"
39 |         curl -L "$MICROMAMBA_DOWNLOAD_URL" | tar -xvjO bin/micromamba > "$MAMBA_ROOT_PREFIX/micromamba"
40 | 
41 |         chmod u+x "$MAMBA_ROOT_PREFIX/micromamba"
42 | 
43 |         # test the mamba binary
44 |         echo "Micromamba version:"
45 |         "$MAMBA_ROOT_PREFIX/micromamba" --version
46 |     fi
47 | 
48 |     # create the installer env
49 |     if [ ! -e "$INSTALL_ENV_DIR" ]; then
50 |         "$MAMBA_ROOT_PREFIX/micromamba" create -y --prefix "$INSTALL_ENV_DIR"
51 |     fi
52 | 
53 |     echo "Packages to install:$PACKAGES_TO_INSTALL"
54 | 
55 |     "$MAMBA_ROOT_PREFIX/micromamba" install -y --prefix "$INSTALL_ENV_DIR" -c conda-forge $PACKAGES_TO_INSTALL
56 | 
57 |     if [ ! -e "$INSTALL_ENV_DIR" ]; then
58 |         echo "There was a problem while initializing micromamba. Cannot continue."
59 |         exit
60 |     fi
61 | fi
62 | 
63 | if [ -e "$INSTALL_ENV_DIR" ]; then export PATH="$INSTALL_ENV_DIR/bin:$PATH"; fi
64 | 
65 | # get the repo (and load into the current directory)
66 | 
67 | git init
68 | git config --local init.defaultBranch main
69 | git remote add origin "$REPO_URL"
70 | git fetch
71 | git checkout origin/main -ft
72 | 
73 | 
74 | # create the environment
75 | CONDA_BASEPATH=$(conda info --base)
76 | source "$CONDA_BASEPATH/etc/profile.d/conda.sh" # otherwise conda complains about 'shell not initialized' (needed when running in a script)
77 | 
78 | conda activate
79 | echo "Linux system detected. Installing CUDA and CPU support."
80 | ln -sf environments-and-requirements/environment-lin-cuda.yml environment.yml
81 | conda env update
82 | 
83 | status=$?
84 | 
85 | if test $status -ne 0
86 | then
87 |     echo "Something went wrong while installing the Python libraries; the installer cannot continue."
88 |     echo "See https://invoke-ai.github.io/InvokeAI/INSTALL_SOURCE#troubleshooting for troubleshooting"
89 |     echo "tips, or visit https://invoke-ai.github.io/InvokeAI/#installation for alternative"
90 |     echo "installation methods"
91 | else
92 |     ln -sf ./source_installer/invoke.sh.in ./invoke.sh
93 |     ln -sf ./source_installer/update.sh.in ./update.sh
94 |     chmod a+rx ./source_installer/invoke.sh.in
95 |     chmod a+rx ./source_installer/update.sh.in
96 | 
97 |     conda activate invokeai
98 |     # configure
99 |     echo "Calling the configure_invokeai script"
100 |     python scripts/configure_invokeai.py
101 |     status=$?
102 |     if test $status -ne 0
103 |     then
104 |         echo "The configure_invokeai.py script crashed or was cancelled."
105 |         echo "InvokeAI is not ready to run. Try again by running"
106 |         echo "update.sh in this directory."
107 | else 108 | # tell the user their next steps 109 | echo "You can now start generating images by running invoke.sh (inside this folder), using ./invoke.sh" 110 | fi 111 | fi 112 | 113 | conda activate invokeai 114 | -------------------------------------------------------------------------------- /invoke_ai_gui_colab.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": { 6 | "id": "V78TAa3W-9WN" 7 | }, 8 | "source": [ 9 | "# Для того, чтобы установить Invoke Ai, запустите код в каждой из форм.\n", 10 | "Для запуска сервера потребуется аккаунт на [ngrok](https://dashboard.ngrok.com/login)\n", 11 | "\n", 12 | "# To use Invoke Ai, launch the code in every form below.\n", 13 | "You will need an account on [ngrok](https://dashboard.ngrok.com/login)\n", 14 | "\n", 15 | "**[Repo](https://github.com/peaashmeter/invoke-ai-gui-colab)**" 16 | ] 17 | }, 18 | { 19 | "cell_type": "code", 20 | "execution_count": null, 21 | "metadata": { 22 | "id": "hqe0rPfs6wJ3" 23 | }, 24 | "outputs": [], 25 | "source": [ 26 | "from google.colab import drive\n", 27 | "drive.mount('/content/drive')" 28 | ] 29 | }, 30 | { 31 | "cell_type": "code", 32 | "execution_count": null, 33 | "metadata": { 34 | "id": "G5LfQaWe0bHK" 35 | }, 36 | "outputs": [], 37 | "source": [ 38 | "#@title Во время этого этапа Colab крашится, это нормально\n", 39 | "!pip install -q condacolab\n", 40 | "import condacolab\n", 41 | "condacolab.install()" 42 | ] 43 | }, 44 | { 45 | "cell_type": "code", 46 | "execution_count": null, 47 | "metadata": { 48 | "id": "U-4vMQGrGXx8" 49 | }, 50 | "outputs": [], 51 | "source": [ 52 | "#@title Подготовка репозитория Invoke Ai { display-mode: \"form\" }\n", 53 | "%cd /home\n", 54 | "\n", 55 | "!git clone -n https://github.com/invoke-ai/InvokeAI.git\n", 56 | "%cd InvokeAI\n", 57 | "!git checkout 69b15024a93ecfbbb555c956f745e9748c12f044\n", 58 | "\n", 59 | "#dnspython fix\n", 60 | "!echo ' - dnspython==2.2.1' >> environments-and-requirements/environment-lin-cuda.yml\n", 61 | "\n", 62 | "#huggingface fix\n", 63 | "!echo ' - huggingface-hub==0.11.1' >> environments-and-requirements/environment-lin-cuda.yml\n", 64 | "\n", 65 | "#Werkzeug fix\n", 66 | "!echo ' - Werkzeug==2.2.2' >> environments-and-requirements/environment-lin-cuda.yml\n", 67 | "\n", 68 | "#opencv fix\n", 69 | "!sed -i 's/opencv-python==4.5.5.64/opencv-python==4.8.0.74/g' environments-and-requirements/environment-lin-cuda.yml\n", 70 | "\n", 71 | "!ln -sf environments-and-requirements/environment-lin-cuda.yml environment.yml\n", 72 | "!ls -la\n", 73 | "\n", 74 | "%cd ..\n", 75 | "\n", 76 | "!git clone https://github.com/peaashmeter/invoke-ai-gui-colab.git\n", 77 | "\n", 78 | "\n", 79 | "%cd invoke-ai-gui-colab\n", 80 | "\n", 81 | "%cp cross_attention_control.py ../InvokeAI/ldm/models/diffusion/cross_attention_control.py\n", 82 | "%cp globals.py ../InvokeAI/ldm/invoke/globals.py\n", 83 | "%cd /home/invoke-ai-gui-colab" 84 | ] 85 | }, 86 | { 87 | "cell_type": "code", 88 | "execution_count": null, 89 | "metadata": { 90 | "id": "0z1VLnSGQyrx" 91 | }, 92 | "outputs": [], 93 | "source": [ 94 | "#@title Установка зависимостей, в норме занимает ~10 минут { display-mode: \"form\" }\n", 95 | "%cd ../InvokeAI\n", 96 | "!pip install pyngrok --quiet\n", 97 | "!conda env update\n", 98 | "!source activate invokeai ; python scripts/configure_invokeai.py --skip-sd-weights --yes\n", 99 | "\n", 100 | "%cp ../invoke-ai-gui-colab/models.yaml 
/root/invokeai/configs\n", 101 | "\n" 102 | ] 103 | }, 104 | { 105 | "cell_type": "code", 106 | "execution_count": null, 107 | "metadata": { 108 | "id": "OzFl5cmndBrd" 109 | }, 110 | "outputs": [], 111 | "source": [ 112 | "#@title Установка модели Anything { display-mode: \"form\" }\n", 113 | "import os\n", 114 | "\n", 115 | "#@markdown Установка модели Anything V5.\n", 116 | "#@markdown\n", 117 | "#@markdown Anything V5 - 4 ГБ\n", 118 | "\n", 119 | "os.system('cd /')\n", 120 | "os.system('mkdir -p /root/invokeai/models/ldm/stable-diffusion-v1')\n", 121 | "\n", 122 | "if not os.path.exists('/content/drive/MyDrive/models/Anything-V5.ckpt'):\n", 123 | " print('Производится скачивание модели Anything-V5')\n", 124 | " os.system('mkdir -p /content/drive/MyDrive/models/')\n", 125 | " os.system('wget -O /content/drive/MyDrive/models/Anything-V5.ckpt https://civitai.com/api/download/models/33672')\n", 126 | "\n", 127 | "\n", 128 | "if os.path.exists('/content/drive/MyDrive/models/Anything-V5-vae.safetensors'):\n", 129 | " None\n", 130 | "else:\n", 131 | " print('Производится скачивание Anything-VAE')\n", 132 | " os.system('wget -O /content/drive/MyDrive/models/Anything-V5-vae.safetensors https://huggingface.co/stablediffusionapi/anything-v5/resolve/main/vae/diffusion_pytorch_model.safetensors')\n", 133 | "\n", 134 | "if os.path.exists('/content/drive/MyDrive/models/vae-ft-mse-840000-ema-pruned.ckpt'):\n", 135 | " None\n", 136 | "else:\n", 137 | " print('Производится скачивание SD-VAE')\n", 138 | " os.system('wget -O /content/drive/MyDrive/models/vae-ft-mse-840000-ema-pruned.ckpt https://huggingface.co/stabilityai/sd-vae-ft-mse-original/resolve/main/vae-ft-mse-840000-ema-pruned.ckpt')\n", 139 | "\n", 140 | "if not os.path.exists('/content/drive/MyDrive/models/vae-ft-mse-840000-ema-pruned.ckpt') or not os.path.exists('/content/drive/MyDrive/models/Anything-V5-vae.safetensors') or not os.path.exists('/content/drive/MyDrive/models/Anything-V5.ckpt'):\n", 141 | " print('Произошла ошибка, попробуйте еще раз\\nUnexpected error')\n", 142 | "else:\n", 143 | " print('Make sure you accept the terms at https://huggingface.co/stablediffusionapi/anything-v5 and https://huggingface.co/stabilityai/sd-vae-ft-mse-original')\n", 144 | "\n", 145 | "\n", 146 | "%mkdir -p /root/invokeai/models/ldm/stable-diffusion-v1/\n", 147 | "!ln -s /content/drive/MyDrive/models/Anything-V5.ckpt /root/invokeai/models/ldm/stable-diffusion-v1/Anything-V5.ckpt\n", 148 | "!ls -l /root/invokeai/models/ldm/stable-diffusion-v1/Anything-V5.ckpt\n", 149 | "\n", 150 | "\n", 151 | "!ln -s /content/drive/MyDrive/models/vae-ft-mse-840000-ema-pruned.ckpt /root/invokeai/models/ldm/stable-diffusion-v1/vae-ft-mse-840000-ema-pruned.ckpt\n", 152 | "!ls -l /root/invokeai/models/ldm/stable-diffusion-v1/vae-ft-mse-840000-ema-pruned.ckpt\n", 153 | "\n", 154 | "!ln -s /content/drive/MyDrive/models/Anything-V5-vae.safetensors /root/invokeai/models/ldm/stable-diffusion-v1/Anything-V5-vae.safetensors\n", 155 | "!ls -l /root/invokeai/models/ldm/stable-diffusion-v1/Anything-V5-vae.safetensors" 156 | ] 157 | }, 158 | { 159 | "cell_type": "code", 160 | "execution_count": null, 161 | "metadata": { 162 | "id": "C826ebQkb_OQ" 163 | }, 164 | "outputs": [], 165 | "source": [ 166 | "#@title Запуск сервера { display-mode: \"both\" }\n", 167 | "#@markdown Вставьте токен с [ngrok](https://dashboard.ngrok.com/get-started/your-authtoken) в поле перед выполнением кода.\n", 168 | "\n", 169 | "ngrok_token = \"\" #@param {type:\"string\"}\n", 170 | "nsfw_checker = 0 #@param 
{type:\"slider\", min:0, max:1, step:1}\n", 171 | "used_vae = \"Anything-V5.vae\" #@param [\"Anything-V5.vae\", \"vae-ft-mse-840000-ema-pruned\"]\n", 172 | "\n", 173 | "import os\n", 174 | "from pyngrok import ngrok\n", 175 | "\n", 176 | "ngrok.kill()\n", 177 | "ngrok.set_auth_token(ngrok_token)\n", 178 | "public_url = ngrok.connect(9090).public_url\n", 179 | "print(f'Invoke Ai public url: {public_url}')\n", 180 | "\n", 181 | "%cd /home/InvokeAI\n", 182 | "\n", 183 | "model_name = \"Anything_v5_vaefixed\"\n", 184 | "if used_vae == \"vae-ft-mse-840000-ema-pruned\":\n", 185 | " if model_name == \"Anything_v5_vaefixed\":\n", 186 | " model_name = \"Anything_v5_default\"\n", 187 | "\n", 188 | "\n", 189 | "if nsfw_checker:\n", 190 | " !source activate invokeai ; python scripts/invoke.py --web --model $model_name\n", 191 | "else:\n", 192 | " !source activate invokeai ; python scripts/invoke.py --web --no-nsfw_checker --model $model_name" 193 | ] 194 | }, 195 | { 196 | "cell_type": "markdown", 197 | "metadata": { 198 | "id": "ckMayfYdkOmF" 199 | }, 200 | "source": [ 201 | "# После запуска сервера нужно перейти по первой ссылке из вывода (Invoke Ai public url)\n", 202 | "Если интерфейс не прогрузился, надо перезагрузить страницу\n", 203 | "---\n", 204 | "\n" 205 | ] 206 | } 207 | ], 208 | "metadata": { 209 | "accelerator": "GPU", 210 | "colab": { 211 | "provenance": [] 212 | }, 213 | "gpuClass": "standard", 214 | "kernelspec": { 215 | "display_name": "Python 3", 216 | "language": "python", 217 | "name": "python3" 218 | }, 219 | "language_info": { 220 | "codemirror_mode": { 221 | "name": "ipython", 222 | "version": 3 223 | }, 224 | "file_extension": ".py", 225 | "mimetype": "text/x-python", 226 | "name": "python", 227 | "nbconvert_exporter": "python", 228 | "pygments_lexer": "ipython3", 229 | "version": "3.9.0 (tags/v3.9.0:9cf6752, Oct 5 2020, 15:34:40) [MSC v.1927 64 bit (AMD64)]" 230 | }, 231 | "orig_nbformat": 4, 232 | "vscode": { 233 | "interpreter": { 234 | "hash": "07f22fa44ccf32757507f6b49dc71cbc8d838934a5626a47cf7fbf3cda75c8a2" 235 | } 236 | } 237 | }, 238 | "nbformat": 4, 239 | "nbformat_minor": 0 240 | } 241 | -------------------------------------------------------------------------------- /cross_attention_control.py: -------------------------------------------------------------------------------- 1 | import enum 2 | import math 3 | from typing import Optional, Callable 4 | 5 | import psutil 6 | import torch 7 | from torch import nn 8 | 9 | # adapted from bloc97's CrossAttentionControl colab 10 | # https://github.com/bloc97/CrossAttentionControl 11 | 12 | class Arguments: 13 | def __init__(self, edited_conditioning: torch.Tensor, edit_opcodes: list[tuple], edit_options: dict): 14 | """ 15 | :param edited_conditioning: if doing cross-attention control, the edited conditioning [1 x 77 x 768] 16 | :param edit_opcodes: if doing cross-attention control, a list of difflib.SequenceMatcher-like opcodes describing how to map original conditioning tokens to edited conditioning tokens (only the 'equal' opcode is required) 17 | :param edit_options: if doing cross-attention control, per-edit options. there should be 1 item in edit_options for each item in edit_opcodes. 
18 | """ 19 | # todo: rewrite this to take embedding fragments rather than a single edited_conditioning vector 20 | self.edited_conditioning = edited_conditioning 21 | self.edit_opcodes = edit_opcodes 22 | 23 | if edited_conditioning is not None: 24 | assert len(edit_opcodes) == len(edit_options), \ 25 | "there must be 1 edit_options dict for each edit_opcodes tuple" 26 | non_none_edit_options = [x for x in edit_options if x is not None] 27 | assert len(non_none_edit_options)>0, "missing edit_options" 28 | if len(non_none_edit_options)>1: 29 | print('warning: cross-attention control options are not working properly for >1 edit') 30 | self.edit_options = non_none_edit_options[0] 31 | 32 | 33 | class CrossAttentionType(enum.Enum): 34 | SELF = 1 35 | TOKENS = 2 36 | 37 | 38 | class Context: 39 | 40 | cross_attention_mask: Optional[torch.Tensor] 41 | cross_attention_index_map: Optional[torch.Tensor] 42 | 43 | class Action(enum.Enum): 44 | NONE = 0 45 | SAVE = 1, 46 | APPLY = 2 47 | 48 | def __init__(self, arguments: Arguments, step_count: int): 49 | """ 50 | :param arguments: Arguments for the cross-attention control process 51 | :param step_count: The absolute total number of steps of diffusion (for img2img this is likely larger than the number of steps that will actually run) 52 | """ 53 | self.cross_attention_mask = None 54 | self.cross_attention_index_map = None 55 | self.self_cross_attention_action = Context.Action.NONE 56 | self.tokens_cross_attention_action = Context.Action.NONE 57 | self.arguments = arguments 58 | self.step_count = step_count 59 | 60 | self.self_cross_attention_module_identifiers = [] 61 | self.tokens_cross_attention_module_identifiers = [] 62 | 63 | self.saved_cross_attention_maps = {} 64 | 65 | self.clear_requests(cleanup=True) 66 | 67 | def register_cross_attention_modules(self, model): 68 | for name,module in get_cross_attention_modules(model, CrossAttentionType.SELF): 69 | if name in self.self_cross_attention_module_identifiers: 70 | assert False, f"name {name} cannot appear more than once" 71 | self.self_cross_attention_module_identifiers.append(name) 72 | for name,module in get_cross_attention_modules(model, CrossAttentionType.TOKENS): 73 | if name in self.tokens_cross_attention_module_identifiers: 74 | assert False, f"name {name} cannot appear more than once" 75 | self.tokens_cross_attention_module_identifiers.append(name) 76 | 77 | def request_save_attention_maps(self, cross_attention_type: CrossAttentionType): 78 | if cross_attention_type == CrossAttentionType.SELF: 79 | self.self_cross_attention_action = Context.Action.SAVE 80 | else: 81 | self.tokens_cross_attention_action = Context.Action.SAVE 82 | 83 | def request_apply_saved_attention_maps(self, cross_attention_type: CrossAttentionType): 84 | if cross_attention_type == CrossAttentionType.SELF: 85 | self.self_cross_attention_action = Context.Action.APPLY 86 | else: 87 | self.tokens_cross_attention_action = Context.Action.APPLY 88 | 89 | def is_tokens_cross_attention(self, module_identifier) -> bool: 90 | return module_identifier in self.tokens_cross_attention_module_identifiers 91 | 92 | def get_should_save_maps(self, module_identifier: str) -> bool: 93 | if module_identifier in self.self_cross_attention_module_identifiers: 94 | return self.self_cross_attention_action == Context.Action.SAVE 95 | elif module_identifier in self.tokens_cross_attention_module_identifiers: 96 | return self.tokens_cross_attention_action == Context.Action.SAVE 97 | return False 98 | 99 | def get_should_apply_saved_maps(self, 
module_identifier: str) -> bool: 100 | if module_identifier in self.self_cross_attention_module_identifiers: 101 | return self.self_cross_attention_action == Context.Action.APPLY 102 | elif module_identifier in self.tokens_cross_attention_module_identifiers: 103 | return self.tokens_cross_attention_action == Context.Action.APPLY 104 | return False 105 | 106 | def get_active_cross_attention_control_types_for_step(self, percent_through:float=None)\ 107 | -> list[CrossAttentionType]: 108 | """ 109 | Should cross-attention control be applied on the given step? 110 | :param percent_through: How far through the step sequence are we (0.0=pure noise, 1.0=completely denoised image). Expected range 0.0..<1.0. 111 | :return: A list of attention types that cross-attention control should be performed for on the given step. May be []. 112 | """ 113 | if percent_through is None: 114 | return [CrossAttentionType.SELF, CrossAttentionType.TOKENS] 115 | 116 | opts = self.arguments.edit_options 117 | to_control = [] 118 | if opts['s_start'] <= percent_through < opts['s_end']: 119 | to_control.append(CrossAttentionType.SELF) 120 | if opts['t_start'] <= percent_through < opts['t_end']: 121 | to_control.append(CrossAttentionType.TOKENS) 122 | return to_control 123 | 124 | def save_slice(self, identifier: str, slice: torch.Tensor, dim: Optional[int], offset: int, 125 | slice_size: Optional[int]): 126 | if identifier not in self.saved_cross_attention_maps: 127 | self.saved_cross_attention_maps[identifier] = { 128 | 'dim': dim, 129 | 'slice_size': slice_size, 130 | 'slices': {offset or 0: slice} 131 | } 132 | else: 133 | self.saved_cross_attention_maps[identifier]['slices'][offset or 0] = slice 134 | 135 | def get_slice(self, identifier: str, requested_dim: Optional[int], requested_offset: int, slice_size: int): 136 | saved_attention_dict = self.saved_cross_attention_maps[identifier] 137 | if requested_dim is None: 138 | if saved_attention_dict['dim'] is not None: 139 | raise RuntimeError(f"dim mismatch: expected dim=None, have {saved_attention_dict['dim']}") 140 | return saved_attention_dict['slices'][0] 141 | 142 | if saved_attention_dict['dim'] == requested_dim: 143 | if slice_size != saved_attention_dict['slice_size']: 144 | raise RuntimeError( 145 | f"slice_size mismatch: expected slice_size={slice_size}, have {saved_attention_dict['slice_size']}") 146 | return saved_attention_dict['slices'][requested_offset] 147 | 148 | if saved_attention_dict['dim'] is None: 149 | whole_saved_attention = saved_attention_dict['slices'][0] 150 | if requested_dim == 0: 151 | return whole_saved_attention[requested_offset:requested_offset + slice_size] 152 | elif requested_dim == 1: 153 | return whole_saved_attention[:, requested_offset:requested_offset + slice_size] 154 | 155 | raise RuntimeError(f"Cannot convert dim {saved_attention_dict['dim']} to requested dim {requested_dim}") 156 | 157 | def get_slicing_strategy(self, identifier: str) -> tuple[Optional[int], Optional[int]]: 158 | saved_attention = self.saved_cross_attention_maps.get(identifier, None) 159 | if saved_attention is None: 160 | return None, None 161 | return saved_attention['dim'], saved_attention['slice_size'] 162 | 163 | def clear_requests(self, cleanup=True): 164 | self.tokens_cross_attention_action = Context.Action.NONE 165 | self.self_cross_attention_action = Context.Action.NONE 166 | if cleanup: 167 | self.saved_cross_attention_maps = {} 168 | 169 | def offload_saved_attention_slices_to_cpu(self): 170 | for key, map_dict in 
self.saved_cross_attention_maps.items():
171 |             for offset, slice in map_dict['slices'].items():
172 |                 map_dict['slices'][offset] = slice.to('cpu')
173 | 
174 | 
175 | 
176 | class InvokeAICrossAttentionMixin:
177 |     """
178 |     Enable InvokeAI-flavoured CrossAttention calculation, which does aggressive low-memory slicing and calls
179 |     through both to an attention_slice_wrangler and a slicing_strategy_getter for custom attention map wrangling
180 |     and dynamic slicing strategy selection.
181 |     """
182 |     def __init__(self):
183 |         self.mem_total_gb = psutil.virtual_memory().total // (1 << 30)
184 |         self.attention_slice_wrangler = None
185 |         self.slicing_strategy_getter = None
186 |         self.attention_slice_calculated_callback = None
187 | 
188 |     def set_attention_slice_wrangler(self, wrangler: Optional[Callable[[nn.Module, torch.Tensor, int, int, int], torch.Tensor]]):
189 |         '''
190 |         Set custom attention calculator to be called when attention is calculated
191 |         :param wrangler: Callback, with args (module, suggested_attention_slice, dim, offset, slice_size),
192 |             which returns either the suggested_attention_slice or an adjusted equivalent.
193 |             `module` is the current CrossAttention module for which the callback is being invoked.
194 |             `suggested_attention_slice` is the default-calculated attention slice
195 |             `dim` is None if the attention map has not been sliced, or 0 or 1 for dimension-0 or dimension-1 slicing.
196 |             If `dim` is not None, `offset` and `slice_size` specify the slice start and length.
197 | 
198 |         Pass None to use the default attention calculation.
199 |         :return:
200 |         '''
201 |         self.attention_slice_wrangler = wrangler
202 | 
203 |     def set_slicing_strategy_getter(self, getter: Optional[Callable[[nn.Module], tuple[int,int]]]):
204 |         self.slicing_strategy_getter = getter
205 | 
206 |     def set_attention_slice_calculated_callback(self, callback: Optional[Callable[[torch.Tensor, int, int, int], None]]):
207 |         self.attention_slice_calculated_callback = callback
208 | 
209 |     def einsum_lowest_level(self, query, key, value, dim, offset, slice_size):
210 |         # calculate attention scores
211 |         #attention_scores = torch.einsum('b i d, b j d -> b i j', q, k)
212 |         attention_scores = torch.baddbmm(
213 |             torch.empty(query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device),
214 |             query,
215 |             key.transpose(-1, -2),
216 |             beta=0,
217 |             alpha=self.scale,
218 |         )
219 | 
220 |         # calculate attention slice by taking the best scores for each latent pixel
221 |         default_attention_slice = attention_scores.softmax(dim=-1, dtype=attention_scores.dtype)
222 |         attention_slice_wrangler = self.attention_slice_wrangler
223 |         if attention_slice_wrangler is not None:
224 |             attention_slice = attention_slice_wrangler(self, default_attention_slice, dim, offset, slice_size)
225 |         else:
226 |             attention_slice = default_attention_slice
227 | 
228 |         if self.attention_slice_calculated_callback is not None:
229 |             self.attention_slice_calculated_callback(attention_slice, dim, offset, slice_size)
230 | 
231 |         hidden_states = torch.bmm(attention_slice, value)
232 |         return hidden_states
233 | 
234 |     def einsum_op_slice_dim0(self, q, k, v, slice_size):
235 |         r = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
236 |         for i in range(0, q.shape[0], slice_size):
237 |             end = i + slice_size
238 |             r[i:end] = self.einsum_lowest_level(q[i:end], k[i:end], v[i:end], dim=0, offset=i, slice_size=slice_size)
239 |         return r
240 | 
241 |     def einsum_op_slice_dim1(self, q, k, v, slice_size):
242 |         r = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
243 |         for i in range(0, q.shape[1], slice_size):
244 |             end = i + slice_size
245 |             r[:, i:end] = self.einsum_lowest_level(q[:, i:end], k, v, dim=1, offset=i, slice_size=slice_size)
246 |         return r
247 | 
248 |     def einsum_op_mps_v1(self, q, k, v):
249 |         if q.shape[1] <= 4096: # (512x512) max q.shape[1]: 4096
250 |             return self.einsum_lowest_level(q, k, v, None, None, None)
251 |         else:
252 |             slice_size = math.floor(2**30 / (q.shape[0] * q.shape[1]))
253 |             return self.einsum_op_slice_dim1(q, k, v, slice_size)
254 | 
255 |     def einsum_op_mps_v2(self, q, k, v):
256 |         if self.mem_total_gb > 8 and q.shape[1] <= 4096:
257 |             return self.einsum_lowest_level(q, k, v, None, None, None)
258 |         else:
259 |             return self.einsum_op_slice_dim0(q, k, v, 1)
260 | 
261 |     def einsum_op_tensor_mem(self, q, k, v, max_tensor_mb):
262 |         size_mb = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size() // (1 << 20)
263 |         if size_mb <= max_tensor_mb:
264 |             return self.einsum_lowest_level(q, k, v, None, None, None)
265 |         div = 1 << int((size_mb - 1) / max_tensor_mb).bit_length()
266 |         if div <= q.shape[0]:
267 |             return self.einsum_op_slice_dim0(q, k, v, q.shape[0] // div)
268 |         return self.einsum_op_slice_dim1(q, k, v, max(q.shape[1] // div, 1))
269 | 
270 |     def einsum_op_cuda(self, q, k, v):
271 |         # check if we already have a slicing strategy (this should only happen during cross-attention controlled generation)
272 |         slicing_strategy_getter = self.slicing_strategy_getter
273 |         if slicing_strategy_getter is not None:
274 |             (dim, slice_size) = slicing_strategy_getter(self)
275 |             if dim is not None:
276 |                 # print("using saved slicing strategy with dim", dim, "slice size", slice_size)
277 |                 if dim == 0:
278 |                     return self.einsum_op_slice_dim0(q, k, v, slice_size)
279 |                 elif dim == 1:
280 |                     return self.einsum_op_slice_dim1(q, k, v, slice_size)
281 | 
282 |         # fallback for when there is no saved strategy, or saved strategy does not slice
283 |         mem_free_total = get_mem_free_total(q.device)
284 |         # Divide by a safety factor as there's copying and fragmentation
285 |         return self.einsum_op_tensor_mem(q, k, v, mem_free_total / 3.3 / (1 << 20))
286 | 
287 | 
288 |     def get_invokeai_attention_mem_efficient(self, q, k, v):
289 |         if q.device.type == 'cuda':
290 |             #print("in get_attention_mem_efficient with q shape", q.shape, ", k shape", k.shape, ", free memory is", get_mem_free_total(q.device))
291 |             return self.einsum_op_cuda(q, k, v)
292 | 
293 |         if q.device.type == 'mps' or q.device.type == 'cpu':
294 |             if self.mem_total_gb >= 32:
295 |                 return self.einsum_op_mps_v1(q, k, v)
296 |             return self.einsum_op_mps_v2(q, k, v)
297 | 
298 |         # Smaller slices are faster due to L2/L3/SLC caches.
299 |         # Tested on i7 with 8MB L3 cache.
300 |         return self.einsum_op_tensor_mem(q, k, v, 32)
301 | 
302 | 
303 | 
304 | def remove_cross_attention_control(model):
305 |     remove_attention_function(model)
306 | 
307 | 
308 | def setup_cross_attention_control(model, context: Context):
309 |     """
310 |     Inject attention parameters and functions into the passed-in model to enable cross-attention editing.
311 | 
312 |     :param model: The unet model to inject into.
313 |     :param context: Arguments passed to the CrossAttentionControl implementations
314 |     :return: None
315 |     """
316 | 
317 |     # adapted from init_attention_edit
318 |     device = context.arguments.edited_conditioning.device
319 | 
320 |     # urgh. should this be hardcoded?
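    # 77 is the fixed token sequence length of the SD 1.x CLIP text encoder, matching the [1 x 77 x 768] conditioning shape described in Arguments above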
321 | max_length = 77 322 | # mask=1 means use base prompt attention, mask=0 means use edited prompt attention 323 | mask = torch.zeros(max_length) 324 | indices_target = torch.arange(max_length, dtype=torch.long) 325 | indices = torch.arange(max_length, dtype=torch.long) 326 | for name, a0, a1, b0, b1 in context.arguments.edit_opcodes: 327 | if b0 < max_length: 328 | if name == "equal":# or (name == "replace" and a1 - a0 == b1 - b0): 329 | # these tokens have not been edited 330 | indices[b0:b1] = indices_target[a0:a1] 331 | mask[b0:b1] = 1 332 | 333 | context.register_cross_attention_modules(model) 334 | context.cross_attention_mask = mask.to(device) 335 | context.cross_attention_index_map = indices.to(device) 336 | inject_attention_function(model, context) 337 | 338 | 339 | def get_cross_attention_modules(model, which: CrossAttentionType) -> list[tuple[str, InvokeAICrossAttentionMixin]]: 340 | cross_attention_class: type = InvokeAICrossAttentionMixin 341 | # cross_attention_class: type = InvokeAIDiffusersCrossAttention 342 | which_attn = "attn1" if which is CrossAttentionType.SELF else "attn2" 343 | attention_module_tuples = [(name,module) for name, module in model.named_modules() if 344 | isinstance(module, cross_attention_class) and which_attn in name] 345 | cross_attention_modules_in_model_count = len(attention_module_tuples) 346 | expected_count = 16 347 | if cross_attention_modules_in_model_count != expected_count: 348 | # non-fatal error but .swap() won't work. 349 | print(f"Error! CrossAttentionControl found an unexpected number of {cross_attention_class} modules in the model " + 350 | f"(expected {expected_count}, found {cross_attention_modules_in_model_count}). Either monkey-patching failed " + 351 | f"or some assumption has changed about the structure of the model itself. Please fix the monkey-patching, " + 352 | f"and/or update the {expected_count} above to an appropriate number, and/or find and inform someone who knows " + 353 | f"what it means. 
This error is non-fatal, but it is likely that .swap() and attention map display will not " +
354 |               f"work properly until it is fixed.")
355 |     return attention_module_tuples
356 | 
357 | 
358 | def inject_attention_function(unet, context: Context):
359 |     # ORIGINAL SOURCE CODE: https://github.com/huggingface/diffusers/blob/91ddd2a25b848df0fa1262d4f1cd98c7ccb87750/src/diffusers/models/attention.py#L276
360 | 
361 |     def attention_slice_wrangler(module, suggested_attention_slice:torch.Tensor, dim, offset, slice_size):
362 | 
363 |         #memory_usage = suggested_attention_slice.element_size() * suggested_attention_slice.nelement()
364 | 
365 |         attention_slice = suggested_attention_slice
366 | 
367 |         if context.get_should_save_maps(module.identifier):
368 |             #print(module.identifier, "saving suggested_attention_slice of shape",
369 |             #      suggested_attention_slice.shape, "dim", dim, "offset", offset)
370 |             slice_to_save = attention_slice.to('cpu') if dim is not None else attention_slice
371 |             context.save_slice(module.identifier, slice_to_save, dim=dim, offset=offset, slice_size=slice_size)
372 |         elif context.get_should_apply_saved_maps(module.identifier):
373 |             #print(module.identifier, "applying saved attention slice for dim", dim, "offset", offset)
374 |             saved_attention_slice = context.get_slice(module.identifier, dim, offset, slice_size)
375 | 
376 |             # slice may have been offloaded to CPU
377 |             saved_attention_slice = saved_attention_slice.to(suggested_attention_slice.device)
378 | 
379 |             if context.is_tokens_cross_attention(module.identifier):
380 |                 index_map = context.cross_attention_index_map
381 |                 remapped_saved_attention_slice = torch.index_select(saved_attention_slice, -1, index_map)
382 |                 this_attention_slice = suggested_attention_slice
383 | 
384 |                 mask = context.cross_attention_mask
385 |                 saved_mask = mask
386 |                 this_mask = 1 - mask
387 |                 attention_slice = remapped_saved_attention_slice * saved_mask + \
388 |                                   this_attention_slice * this_mask
389 |             else:
390 |                 # just use everything
391 |                 attention_slice = saved_attention_slice
392 | 
393 |         return attention_slice
394 | 
395 |     cross_attention_modules = get_cross_attention_modules(unet, CrossAttentionType.TOKENS) + get_cross_attention_modules(unet, CrossAttentionType.SELF)
396 |     for identifier, module in cross_attention_modules:
397 |         module.identifier = identifier
398 |         try:
399 |             module.set_attention_slice_wrangler(attention_slice_wrangler)
400 |             module.set_slicing_strategy_getter(
401 |                 lambda module, identifier=identifier: context.get_slicing_strategy(identifier)  # bind identifier per module; a plain closure would capture the loop variable
402 |             )
403 |         except AttributeError as e:
404 |             if is_attribute_error_about(e, 'set_attention_slice_wrangler'):
405 |                 print(f"TODO: implement set_attention_slice_wrangler for {type(module)}") # TODO
406 |             else:
407 |                 raise
408 | 
409 | 
410 | def remove_attention_function(unet):
411 |     cross_attention_modules = get_cross_attention_modules(unet, CrossAttentionType.TOKENS) + get_cross_attention_modules(unet, CrossAttentionType.SELF)
412 |     for identifier, module in cross_attention_modules:
413 |         try:
414 |             # clear wrangler callback
415 |             module.set_attention_slice_wrangler(None)
416 |             module.set_slicing_strategy_getter(None)
417 |         except AttributeError as e:
418 |             if is_attribute_error_about(e, 'set_attention_slice_wrangler'):
419 |                 print(f"TODO: implement set_attention_slice_wrangler for {type(module)}")
420 |             else:
421 |                 raise
422 | 
423 | 
424 | def is_attribute_error_about(error: AttributeError, attribute: str):
425 |     if hasattr(error, 'name'):  # Python 3.10
426 |         return error.name == attribute
427 |     else:  # Python 3.9
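        # AttributeError only gained a .name attribute in Python 3.10, so on 3.9 fall back to matching the attribute name against the error message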
428 | return attribute in str(error) 429 | 430 | 431 | 432 | def get_mem_free_total(device): 433 | #only on cuda 434 | if not torch.cuda.is_available(): 435 | return None 436 | stats = torch.cuda.memory_stats(device) 437 | mem_active = stats['active_bytes.all.current'] 438 | mem_reserved = stats['reserved_bytes.all.current'] 439 | mem_free_cuda, _ = torch.cuda.mem_get_info(device) 440 | mem_free_torch = mem_reserved - mem_active 441 | mem_free_total = mem_free_cuda + mem_free_torch 442 | return mem_free_total 443 | 444 | --------------------------------------------------------------------------------
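A minimal usage sketch (not part of the repository) for the attention_slice_wrangler hook documented in set_attention_slice_wrangler above: it registers a pass-through callback that only prints the shape of each attention slice, leaving generation unchanged. The import path assumes the notebook's layout, where this file is copied to ldm/models/diffusion/; the `unet` variable is assumed to be an already-loaded model whose attention modules mix in InvokeAICrossAttentionMixin.

from ldm.models.diffusion.cross_attention_control import CrossAttentionType, get_cross_attention_modules

def logging_wrangler(module, suggested_attention_slice, dim, offset, slice_size):
    # dim is None when the map was not sliced, otherwise 0 or 1, with offset/slice_size giving the slice position
    print(f"{getattr(module, 'identifier', type(module).__name__)}: "
          f"shape={tuple(suggested_attention_slice.shape)}, dim={dim}, offset={offset}")
    return suggested_attention_slice  # returned unchanged, so generation results are unaffected

for _name, module in get_cross_attention_modules(unet, CrossAttentionType.TOKENS):  # 'unet' assumed loaded
    module.set_attention_slice_wrangler(logging_wrangler)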