├── .circleci └── config.yml ├── .devcontainer ├── devcontainer.json └── local.example.env ├── .gitignore ├── .vscode ├── settings.json └── tasks.json ├── CHANGELOG.md ├── CONTRIBUTING.md ├── Dockerfile ├── LICENSE ├── README.md ├── __init__.py ├── api ├── app.py ├── convert_to_diffusers.py ├── device.py ├── download.py ├── download_checkpoint.py ├── extras │ ├── __init__.py │ └── upsample │ │ ├── __init__.py │ │ ├── models.py │ │ └── upsample.py ├── getPipeline.py ├── getScheduler.py ├── lib │ ├── __init__.py │ ├── prompts.py │ ├── textual_inversions.py │ ├── textual_inversions_test.py │ └── vars.py ├── loadModel.py ├── precision.py ├── send.py ├── server.py ├── status.py ├── tests.py ├── train_dreambooth.py └── utils │ ├── __init__.py │ └── storage │ ├── BaseStorage.py │ ├── BaseStorage_test.py │ ├── HTTPStorage.py │ ├── S3Storage.py │ ├── S3Storage_test.py │ ├── __init__.py │ └── __init__test.py ├── build ├── docs ├── internal_safetensor_cache_flow.md └── storage.md ├── install.sh ├── package.json ├── prime.sh ├── release.config.js ├── requirements.txt ├── run.sh ├── run_integration_tests_on_lambda.sh ├── scripts ├── devContainerPostCreate.sh ├── devContainerServer.sh ├── patchmatch-setup.sh ├── permutations.yaml └── permute.sh ├── test.py ├── tests ├── __init__.py ├── fixtures │ ├── dreambooth │ │ ├── alvan-nee-9M0tSjb-cpA-unsplash.jpeg │ │ ├── alvan-nee-Id1DBHv4fbg-unsplash.jpeg │ │ ├── alvan-nee-bQaAJCbNq3g-unsplash.jpeg │ │ ├── alvan-nee-brFsZ7qszSY-unsplash.jpeg │ │ └── alvan-nee-eoqnr8ikwFE-unsplash.jpeg │ ├── girl_with_pearl_earing_outpainting_in.png │ ├── overture-creations-5sI6fQgYIuo.png │ ├── overture-creations-5sI6fQgYIuo_mask.png │ └── sketch-mountains-input.jpg └── integration │ ├── __init__.py │ ├── conftest.py │ ├── lib.py │ ├── requirements.txt │ ├── test_attn_procs.py │ ├── test_build_download.py │ ├── test_cloud_cache.py │ ├── test_dreambooth.py │ ├── test_general.py │ ├── test_loras.py │ └── test_memory.py ├── touch ├── update.sh └── yarn.lock /.circleci/config.yml: -------------------------------------------------------------------------------- 1 | version: 2.1 2 | 3 | jobs: 4 | build: 5 | docker: 6 | - image: cimg/python:3.9-node 7 | resource_class: medium 8 | 9 | # would have been nice, but not for $2,000/month! 10 | # machine: 11 | # image: ubuntu-2004-cuda-11.4:202110-01 12 | # resource_class: gpu.nvidia.small 13 | 14 | steps: 15 | - checkout 16 | 17 | - setup_remote_docker: 18 | docker_layer_caching: true 19 | 20 | - run: docker build -t gadicc/diffusers-api . 21 | 22 | # unit tests 23 | # - run: docker run gadicc/diffusers-api conda run --no-capture -n xformers pytest --cov=. --cov-report=xml --ignore=diffusers 24 | - run: docker run gadicc/diffusers-api pytest --cov=. --cov-report=xml --ignore=diffusers --ignore=Real-ESRGAN 25 | 26 | - run: echo $DOCKER_PASSWORD | docker login --username $DOCKER_USERNAME --password-stdin 27 | 28 | # push for non-semver branches (e.g. dev, feature branches) 29 | # - run: 30 | # name: Push to hub on branches not handled by semantic-release 31 | # command: | 32 | # SEMVER_BRANCHES=$(cat release.config.js | sed 's/module.exports = //' | sed 's/\/\/.*//' | jq .branches[]) 33 | # 34 | # if [[ ${SEMVER_BRANCHES[@]} =~ "$CIRCLE_BRANCH" ]] ; then 35 | # echo "Skipping because '\$CIRCLE_BRANCH' == '$CIRCLE_BRANCH'" 36 | # echo "Semantic-release will handle the publishing" 37 | # else 38 | # echo "docker push gadicc/diffusers-api:$CIRCLE_BRANCH" 39 | # docker build -t gadicc/diffusers-api:$CIRCLE_BRANCH . 
40 | # docker push gadicc/diffusers-api:$CIRCLE_BRANCH 41 | # echo "Skipping integration tests" 42 | # circleci-agent step halt 43 | # fi 44 | 45 | # needed for later "apt install" steps 46 | - run: sudo apt-get update 47 | 48 | ## TODO. The below was a great first step, but in future, let's build 49 | # the container on the host, run docker remotely on lambda, and 50 | # publish the same built image if tests pass. 51 | 52 | # TODO, only run on main channel for releases (with sem-rel too) 53 | # integration tests 54 | - run: sudo apt install -yqq rsync pv 55 | - run: ./run_integration_tests_on_lambda.sh 56 | 57 | - run: 58 | name: Push to hub on branches not handled by semantic-release 59 | command: | 60 | SEMVER_BRANCHES=$(cat release.config.js | sed 's/module.exports = //' | sed 's/\/\/.*//' | jq .branches[]) 61 | 62 | if [[ ${SEMVER_BRANCHES[@]} =~ "$CIRCLE_BRANCH" ]] ; then 63 | echo "Skipping because '\$CIRCLE_BRANCH' == '$CIRCLE_BRANCH'" 64 | echo "Semantic-release will handle the publishing" 65 | else 66 | echo "docker push gadicc/diffusers-api:$CIRCLE_BRANCH" 67 | docker build -t gadicc/diffusers-api:$CIRCLE_BRANCH . 68 | docker push gadicc/diffusers-api:$CIRCLE_BRANCH 69 | # echo "Skipping integration tests" 70 | # circleci-agent step halt 71 | fi 72 | 73 | # deploy the image 74 | # - run: docker push company/app:$CIRCLE_BRANCH 75 | # https://github.com/semantic-release-plus/semantic-release-plus/tree/master/packages/plugins/docker 76 | - run: 77 | name: release 78 | command: | 79 | sudo apt-get install yarn 80 | yarn install 81 | yarn run semantic-release-plus -------------------------------------------------------------------------------- /.devcontainer/devcontainer.json: -------------------------------------------------------------------------------- 1 | // For format details, see https://aka.ms/devcontainer.json. For config options, see the 2 | // README at: https://github.com/devcontainers/templates/tree/main/src/docker-existing-dockerfile 3 | { 4 | "name": "Existing Dockerfile", 5 | "build": { 6 | // Sets the run context to one level up instead of the .devcontainer folder. 7 | "context": "..", 8 | // Update the 'dockerFile' property if you aren't using the standard 'Dockerfile' filename. 9 | "dockerfile": "../Dockerfile" 10 | }, 11 | 12 | // Features to add to the dev container. More info: https://containers.dev/features. 13 | "features": { 14 | "ghcr.io/devcontainers/features/python:1": { 15 | // "version": "3.10" 16 | } 17 | }, 18 | 19 | // Use 'forwardPorts' to make a list of ports inside the container available locally. 20 | "forwardPorts": [8000], 21 | 22 | // Uncomment the next line to run commands after the container is created. 23 | "postCreateCommand": "scripts/devContainerPostCreate.sh", 24 | 25 | "customizations": { 26 | "vscode": { 27 | "extensions": [ 28 | "ryanluker.vscode-coverage-gutters", 29 | "fsevenm.run-it-on", 30 | "ms-python.black-formatter", 31 | ], 32 | "settings": { 33 | "python.pythonPath": "/opt/conda/bin/python" 34 | } 35 | } 36 | }, 37 | 38 | // Uncomment to connect as an existing user other than the container default. More info: https://aka.ms/dev-containers-non-root. 
39 | // "remoteUser": "devcontainer" 40 | 41 | "mounts": [ 42 | "source=${localEnv:HOME}/root-cache,target=/root/.cache,type=bind,consistency=cached" 43 | ], 44 | 45 | "runArgs": [ 46 | "--gpus", 47 | "all", 48 | "--env-file", 49 | ".devcontainer/local.env" 50 | ] 51 | } 52 | -------------------------------------------------------------------------------- /.devcontainer/local.example.env: -------------------------------------------------------------------------------- 1 | # Useful environment variables: 2 | 3 | # AWS or S3-compatible storage credentials and buckets 4 | AWS_ACCESS_KEY_ID= 5 | AWS_SECRET_ACCESS_KEY= 6 | AWS_DEFAULT_REGION= 7 | AWS_S3_DEFAULT_BUCKET= 8 | # Only fill this in if your (non-AWS) provider has told you what to put here 9 | AWS_S3_ENDPOINT_URL= 10 | 11 | # To use a proxy, e.g. 12 | # https://github.com/kiri-art/docker-diffusers-api/blob/dev/CONTRIBUTING.md#local-https-caching-proxy 13 | # DDA_http_proxy=http://172.17.0.1:3128 14 | # DDA_https_proxy=http://172.17.0.1:3128 15 | 16 | # HuggingFace credentials 17 | HF_AUTH_TOKEN= 18 | HF_USERNAME= 19 | 20 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | /lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | permutations 131 | tests/output 132 | node_modules 133 | .devcontainer/local.env 134 | -------------------------------------------------------------------------------- /.vscode/settings.json: -------------------------------------------------------------------------------- 1 | { 2 | "python.testing.pytestArgs": [ 3 | "--cov=.", 4 | "--cov-report=xml", 5 | "--ignore=test.py", 6 | "--ignore=tests/integration", 7 | "--ignore=diffusers", 8 | // "unit_tests.py" 9 | // "." 10 | ], 11 | "python.testing.unittestEnabled": false, 12 | "python.testing.pytestEnabled": true, 13 | // "python.defaultInterpreterPath": "/opt/conda/envs/xformers/bin/python", 14 | "python.defaultInterpreterPath": "/opt/conda/bin/python", 15 | "runItOn": { 16 | "commands": [ 17 | { 18 | "match": "\\.py$", 19 | "isAsync": true, 20 | "isShellCommand": false, 21 | "cmd": "testing.runAll" 22 | }, 23 | ], 24 | }, 25 | "[python]": { 26 | "editor.defaultFormatter": "ms-python.black-formatter" 27 | }, 28 | "python.formatting.provider": "none" 29 | } 30 | -------------------------------------------------------------------------------- /.vscode/tasks.json: -------------------------------------------------------------------------------- 1 | { 2 | // See https://go.microsoft.com/fwlink/?LinkId=733558 3 | // for the documentation about the tasks.json format 4 | "version": "2.0.0", 5 | "tasks": [ 6 | { 7 | "label": "Watching Server", 8 | "type": "shell", 9 | "command": "scripts/devContainerServer.sh" 10 | } 11 | ] 12 | } -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # CONTRIBUTING 2 | 3 | *Tips for development* 4 | 5 | 1. [General Hints](#general) 6 | 1. [Development / Editor Setup](#editors) 7 | 1. [Visual Studio Code (vscode)](#vscode) 8 | 1. [Testing](#testing) 9 | 1. [Using Buildkit](#buildkit) 10 | 1. [Local HTTP(S) Caching Proxy](#caching) 11 | 1. [Local S3 Server](#local-s3-server) 12 | 1. [Stop on Suspend](#stop-on-suspend) 13 | 14 | 15 | ## General 16 | 17 | 1. Run docker with `-it` to make it easier to stop container with `Ctrl-C`. 18 | 1. If you get a `CUDA initialization: CUDA unknown error` after suspend, 19 | just stop the container, `rmmod nvidia_uvm`, and restart. 20 | 21 | 22 | ## Editors 23 | 24 | 25 | ### Visual Studio Code (recommended, WIP) 26 | 27 | *We're still writing this guide, let us know of any needed improvements* 28 | 29 | This repo includes VSCode settings that allow for a) editing inside a docker container, b) tests and coverage (on save) 30 | 31 | 1. Install from https://code.visualstudio.com/ 32 | 1. Install [Remote - Containers](https://marketplace.visualstudio.com/items?itemName=ms-vscode-remote.remote-containers) extension. 33 | 1. 
Open your docker-diffusers-api folder, you'll get a popup in the bottom right that a dev container environment was detected, click "reload in container" 34 | 1. Look for the "( ) Watch" on status bar and click it so it changes to "( ) XX Coverage" 35 | 36 | **Live Development** 37 | 38 | 1. **Run Task** (either Ctrl-Shift-P and "Run Task", or in Terminals, the Plus ("+") DROPDOWN selector and choose, "Run Task..." at the bottom) 39 | 1. Choose **Watching Server**. Port 8000 will be forwarded. The server will be reloaded 40 | on every file safe (make sure to give it enough time to fully load before sending another 41 | request, otherwise that request will hang). 42 | 43 | 44 | ## Testing 45 | 46 | 1. **Unit testing**: exists but is sorely lacking for now. If you use the 47 | recommended editor setup above, it's probably working already. However: 48 | 49 | 1. **Integation / E2E**: cover most features used in production. 50 | `pytest -s tests/integration`. 51 | The `-s` is optional but streams stdout so you can follow along. 52 | Add also `-k test_name` to test a specific test. E2E tests are LONG but you can 53 | greatly reduce subsequent run time by following the steps below for a 54 | [Local HTTP(S) Caching Proxy](#caching) and [Local S3 Server](#local-s3-server). 55 | 56 | Docker-Diffusers-API follows Semantic Versioning. We follow the 57 | [conventional commits](https://www.conventionalcommits.org/en/v1.0.0/) 58 | standard. 59 | 60 | * On a commit to `dev`, if all CI tests pass, a new release is made to `:dev` tag. 61 | * On a commit to `main`, if all CI tests pass, a new release with appropriate 62 | major / minor / patch is made, based on appropriate tags in the commit history. 63 | 64 | 65 | ## Using BuildKit 66 | 67 | Buildkit is a docker extension that can really improve build speeds through 68 | caching and parallelization. You can enable and tweak it by adding: 69 | 70 | `DOCKER_BUILDKIT=1 BUILDKIT_PROGRESS=plain` 71 | 72 | vars before `docker build` (the `PROGRESS` var shows much more detailed 73 | build logs, which can be useful, but are much more verbose). This is 74 | already all setup in the the [build](./build) script. 75 | 76 | 77 | ## Local HTTP(S) Caching Proxy 78 | 79 | If you're only editing e.g. `app.py`, there's no need to worry about caching 80 | and the docker layers work amazingly. But, if you're constantly changing 81 | installed packages (apt, `requirements.txt`), `download.py`, etc, it's VERY 82 | helpful to have a local cache: 83 | 84 | ```bash 85 | # See all options at https://hub.docker.com/r/gadicc/squid-ssl-zero 86 | $ docker run -d -p 3128:3128 -p 3129:80 \ 87 | --name squid --restart=always \ 88 | -v /usr/local/squid:/usr/local/squid \ 89 | gadicc/squid-ssl-zero 90 | ``` 91 | 92 | and then set the docker build args `proxy=1`, and `http_proxy` / `https_proxy` 93 | with their respective values. 94 | This is already all set up in the [build](./build) script. 95 | 96 | **You probably want to fine-tune /usr/local/squid/etc/squid.conf**. 97 | 98 | It will be created after you first run `gadicc/squid-ssl-zero`. You can then 99 | stop the container (`docker ps`, `docker stop container_id`), edit the file, 100 | and re-start (`docker start container_id`). For now, try something like: 101 | 102 | ```conf 103 | cache_dir ufs /usr/local/squid/cache 50000 16 256 # 50GB 104 | maximum_object_size 20 GB 105 | refresh_pattern . 
52034400 50% 52034400 store-stale override-expire ignore-no-cache ignore-no-store ignore-private 106 | ``` 107 | 108 | but ideally we can as a community create some rules that don't so 109 | aggressively catch every single request. 110 | 111 | 112 | ## Local S3 server 113 | 114 | If you're doing development around the S3 handling, it can be very useful to 115 | have a local S3 server, especially due to the large size of models. You 116 | can set one up like this: 117 | 118 | ```bash 119 | $ docker run -p 9000:9000 -p 9001:9001 \ 120 | -v /usr/local/minio:/data quay.io/minio/minio \ 121 | server /data --console-address ":9001" 122 | ``` 123 | 124 | Now point a web browser to http://localhost:9001/, login with the default 125 | root credentials `minioadmin:minioadmin` and create a bucket and credentials 126 | for testing. More info at https://hub.docker.com/r/minio/minio/. 127 | 128 | Typical policy: 129 | 130 | ```json 131 | { 132 | "Version": "2012-10-17", 133 | "Statement": [ 134 | { 135 | "Sid": "VisualEditor0", 136 | "Effect": "Allow", 137 | "Action": [ 138 | "s3:PutObject", 139 | "s3:GetObject" 140 | ], 141 | "Resource": "arn:aws:s3:::BUCKET_NAME/*" 142 | } 143 | ] 144 | } 145 | ``` 146 | 147 | Then set the **build-arg** `AWS_S3_ENDPOINT_URL="http://172.17.0.1:9000"` 148 | or as appropriate if you've changed the default docker network. 149 | 150 | 151 | ## Stop on Suspend 152 | 153 | Maybe it's just me, but frequently I'll have issues when suspending with 154 | the container running (I guess its a CUDA issue), either a freeze on resume, 155 | or a stuck-forever defunct process. I found it useful to automatically stop 156 | the container / process on suspend. 157 | 158 | I'm running ArchLinux and set up a `systemd` suspend hook as described 159 | [here](https://wiki.archlinux.org/title/Power_management#Sleep_hooks), to 160 | call a script, which contains: 161 | 162 | ```bash 163 | # Stop a matching docker container 164 | PID=`docker ps -qf ancestor=gadicc/diffusers-api` 165 | if [ ! -z $PID ] ; then 166 | echo "Stopping diffusers-api pid $PID" 167 | docker stop $PID 168 | fi 169 | 170 | # For a VSCode devcontainer, just kill the watchmedo process. 171 | PID=`docker ps -qf volume=/home/dragon/root-cache` 172 | if [ ! -z $PID ] ; then 173 | echo "Stopping watchmedo in container $PID" 174 | docker exec $PID /bin/bash -c 'kill `pidof -sx watchmedo`' 175 | fi 176 | ``` 177 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | ARG FROM_IMAGE="pytorch/pytorch:2.0.1-cuda11.7-cudnn8-runtime" 2 | # ARG FROM_IMAGE="gadicc/diffusers-api-base:python3.9-pytorch1.12.1-cuda11.6-xformers" 3 | # You only need the -banana variant if you need banana's optimization 4 | # i.e. not relevant if you're using RUNTIME_DOWNLOADS 5 | # ARG FROM_IMAGE="gadicc/python3.9-pytorch1.12.1-cuda11.6-xformers-banana" 6 | FROM ${FROM_IMAGE} as base 7 | ENV FROM_IMAGE=${FROM_IMAGE} 8 | 9 | # Note, docker uses HTTP_PROXY and HTTPS_PROXY (uppercase) 10 | # We purposefully want those managed independently, as we want docker 11 | # to manage its own cache. This is just for pip, models, etc. 
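# The block below pulls the caching proxy's self-signed certificate (when a
# proxy is configured) and installs it into the system CA store, so that pip,
# model downloads and other HTTPS requests made during the build trust it.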
12 | ARG http_proxy 13 | ARG https_proxy 14 | RUN if [ -n "$http_proxy" ] ; then \ 15 | echo quit \ 16 | | openssl s_client -proxy $(echo ${https_proxy} | cut -b 8-) -servername google.com -connect google.com:443 -showcerts \ 17 | | sed 'H;1h;$!d;x; s/^.*\(-----BEGIN CERTIFICATE-----.*-----END CERTIFICATE-----\)\n---\nServer certificate.*$/\1/' \ 18 | > /usr/local/share/ca-certificates/squid-self-signed.crt ; \ 19 | update-ca-certificates ; \ 20 | fi 21 | ARG REQUESTS_CA_BUNDLE=${http_proxy:+/usr/local/share/ca-certificates/squid-self-signed.crt} 22 | 23 | ARG DEBIAN_FRONTEND=noninteractive 24 | 25 | ARG TZ=UTC 26 | RUN ln -snf /usr/share/zoneinfo/$TZ /etc/localtime && echo $TZ > /etc/timezone 27 | 28 | RUN apt-get update 29 | RUN apt-get install -yq apt-utils 30 | RUN apt-get install -yqq git zstd wget curl 31 | 32 | FROM base AS patchmatch 33 | ARG USE_PATCHMATCH=0 34 | WORKDIR /tmp 35 | COPY scripts/patchmatch-setup.sh . 36 | RUN sh patchmatch-setup.sh 37 | 38 | FROM base as output 39 | RUN mkdir /api 40 | WORKDIR /api 41 | 42 | # we use latest pip in base image 43 | # RUN pip3 install --upgrade pip 44 | 45 | ADD requirements.txt requirements.txt 46 | RUN pip install -r requirements.txt 47 | 48 | # [Import] Add missing settings / Correct some dummy imports (#5036) - 2023-09-14 49 | ARG DIFFUSERS_VERSION="3aa641289c995b3a0ce4ea895a76eb1128eff30c" 50 | ENV DIFFUSERS_VERSION=${DIFFUSERS_VERSION} 51 | 52 | RUN git clone https://github.com/huggingface/diffusers && cd diffusers && git checkout ${DIFFUSERS_VERSION} 53 | WORKDIR /api 54 | RUN pip install -e diffusers 55 | 56 | # Set to true to NOT download model at build time, rather at init / usage. 57 | ARG RUNTIME_DOWNLOADS=1 58 | ENV RUNTIME_DOWNLOADS=${RUNTIME_DOWNLOADS} 59 | 60 | # TODO, to dda-bananana 61 | # ARG PIPELINE="StableDiffusionInpaintPipeline" 62 | ARG PIPELINE="ALL" 63 | ENV PIPELINE=${PIPELINE} 64 | 65 | # Deps for RUNNING (not building) earlier options 66 | ARG USE_PATCHMATCH=0 67 | RUN if [ "$USE_PATCHMATCH" = "1" ] ; then apt-get install -yqq python3-opencv ; fi 68 | COPY --from=patchmatch /tmp/PyPatchMatch PyPatchMatch 69 | 70 | # TODO, just include by default, and handle all deps in OUR requirements.txt 71 | ARG USE_DREAMBOOTH=1 72 | ENV USE_DREAMBOOTH=${USE_DREAMBOOTH} 73 | 74 | RUN if [ "$USE_DREAMBOOTH" = "1" ] ; then \ 75 | # By specifying the same torch version as conda, it won't download again. 76 | # Without this, it will upgrade torch, break xformers, make bigger image. 77 | # bitsandbytes==0.40.0.post4 had failed cuda detection on dreambooth test. 78 | pip install -r diffusers/examples/dreambooth/requirements.txt ; \ 79 | fi 80 | RUN if [ "$USE_DREAMBOOTH" = "1" ] ; then apt-get install -yqq git-lfs ; fi 81 | 82 | ARG USE_REALESRGAN=1 83 | RUN if [ "$USE_REALESRGAN" = "1" ] ; then apt-get install -yqq libgl1-mesa-glx libglib2.0-0 ; fi 84 | RUN if [ "$USE_REALESRGAN" = "1" ] ; then git clone https://github.com/xinntao/Real-ESRGAN.git ; fi 85 | # RUN if [ "$USE_REALESRGAN" = "1" ] ; then pip install numba==0.57.1 chardet ; fi 86 | RUN if [ "$USE_REALESRGAN" = "1" ] ; then pip install basicsr==1.4.2 facexlib==0.2.5 gfpgan==1.3.8 ; fi 87 | RUN if [ "$USE_REALESRGAN" = "1" ] ; then cd Real-ESRGAN && python3 setup.py develop ; fi 88 | 89 | COPY api/ . 
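# The API source is copied last so that code-only changes don't invalidate the
# heavier dependency layers built above.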
90 | EXPOSE 8000 91 | 92 | ARG SAFETENSORS_FAST_GPU=1 93 | ENV SAFETENSORS_FAST_GPU=${SAFETENSORS_FAST_GPU} 94 | 95 | CMD python3 -u server.py 96 | 97 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Banana, Gadi Cohen 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # docker-diffusers-api ("banana-sd-base") 2 | 3 | Diffusers / Stable Diffusion in docker with a REST API, supporting various models, pipelines & schedulers. Used by [kiri.art](https://kiri.art/), perfect for local, server & serverless. 4 | 5 | [![Docker](https://img.shields.io/docker/v/gadicc/diffusers-api?sort=semver)](https://hub.docker.com/r/gadicc/diffusers-api/tags) [![CircleCI](https://img.shields.io/circleci/build/github/kiri-art/docker-diffusers-api/split)](https://circleci.com/gh/kiri-art/docker-diffusers-api?branch=split) [![semantic-release](https://img.shields.io/badge/%20%20%F0%9F%93%A6%F0%9F%9A%80-semantic--release-e10079.svg)](https://github.com/semantic-release/semantic-release) [![MIT License](https://img.shields.io/badge/license-MIT-blue.svg)](./LICENSE) [![Open in Dev Containers](https://img.shields.io/static/v1?label=Dev%20Containers&message=Open&color=blue&logo=visualstudiocode)](https://vscode.dev/redirect?url=vscode://ms-vscode-remote.remote-containers/cloneInVolume?url=https://github.com/kiri-art/docker-diffusers-api) 6 | 7 | Copyright (c) Gadi Cohen, 2022. MIT Licensed. 8 | Please give credit and link back to this repo if you use it in a public project. 9 | 10 | ## Features 11 | 12 | * Models: stable-diffusion, waifu-diffusion, and easy to add others (e.g. jp-sd) 13 | * Pipelines: txt2img, img2img and inpainting in a single container 14 | ([all diffusers official and community pipelines](https://forums.kiri.art/t/all-your-pipelines-are-belong-to-us/83) are wrapped, but untested) 15 | * All model inputs supported, including setting nsfw filter per request 16 | * *Permute* base config to multiple forks based on yaml config with vars 17 | * Optionally send signed event logs / performance data to a REST endpoint / webhook. 18 | * Can automatically download a checkpoint file and convert to diffusers. 
19 | * S3 support, dreambooth training. 20 | 21 | Note: This image was created for [kiri.art](https://kiri.art/). 22 | Everything is open source but there may be certain request / response 23 | assumptions. If anything is unclear, please open an issue. 24 | 25 | ## Important Notices 26 | 27 | * [Official `docker-diffusers-api` Forum](https://forums.kiri.art/c/docker-diffusers-api/16): 28 | help, updates, discussion. 29 | * Subscribe ("watch") these forum topics for: 30 | * [notable **`main`** branch updates](https://forums.kiri.art/t/official-releases-main-branch/35) 31 | * [notable **`dev`** branch updates](https://forums.kiri.art/t/development-releases-dev-branch/53) 32 | * Always [check the CHANGELOG](./CHANGELOG.md) for important updates when upgrading. 33 | 34 | **Official help in our dedicated forum https://forums.kiri.art/c/docker-diffusers-api/16.** 35 | 36 | **This README refers to the in-development `dev` branch** and may 37 | reference features and fixes not yet in the published releases. 38 | 39 | **`v1` has not yet been officially released yet** but has been 40 | running well in production on kiri.art for almost a month. We'd 41 | be grateful for any feedback from early adopters to help make 42 | this official. For more details, see [Upgrading from v0 to 43 | v1](https://forums.kiri.art/t/wip-upgrading-from-v0-to-v1/116). 44 | Previous releases available on the `dev-v0-final` and 45 | `main-v0-final` branches. 46 | 47 | **Currently only NVIDIA / CUDA devices are supported**. Tracking 48 | Apple / M1 support in issue 49 | [#20](https://github.com/kiri-art/docker-diffusers-api/issues/20). 50 | 51 | ## Installation & Setup: 52 | 53 | Setup varies depending on your use case. 54 | 55 | 1. **To run locally or on a *server*, with runtime downloads:** 56 | 57 | `docker run --gpus all -p 8000:8000 -e HF_AUTH_TOKEN=$HF_AUTH_TOKEN gadicc/diffusers-api`. 58 | 59 | See the [guides for various cloud providers](https://forums.kiri.art/t/running-on-other-cloud-providers/89/7). 60 | 61 | 1. **To run *serverless*, include the model at build time:** 62 | 63 | 1. [docker-diffusers-api-build-download](https://github.com/kiri-art/docker-diffusers-api-build-download) ( 64 | [banana](https://forums.kiri.art/t/run-diffusers-api-on-banana-dev/103), others) 65 | 1. [docker-diffusers-api-runpod](https://github.com/kiri-art/docker-diffusers-api-runpod), 66 | see the [guide](https://forums.kiri.art/t/run-diffusers-api-on-runpod-io/102) 67 | 68 | 1. **Building from source**. 69 | 70 | 1. Fork / clone this repo. 71 | 1. `docker build -t gadicc/diffusers-api .` 72 | 1. See [CONTRIBUTING.md](./CONTRIBUTING.md) for more helpful hints. 73 | 74 | *Other configurations are possible but these are the most common cases* 75 | 76 | Everything is set via docker build-args or environment variables. 77 | 78 | ## Usage: 79 | 80 | See also [Testing](#testing) below. 
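For a quick start, a request can be sent from Python like this (a minimal sketch assuming the container is listening on `http://localhost:8000/`; the request body is explained below):

```python
import base64
import requests

payload = {
    "modelInputs": {
        "prompt": "Super dog",
        "num_inference_steps": 25,
    },
    "callInputs": {
        "MODEL_ID": "runwayml/stable-diffusion-v1-5",
    },
}

# POST to the container root; the JSON response contains base64-encoded image(s).
result = requests.post("http://localhost:8000/", json=payload).json()

with open("out.png", "wb") as f:
    f.write(base64.b64decode(result["image_base64"]))
```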
81 | 82 | The container expects an `HTTP POST` request to `/`, with a JSON body resembling the following: 83 | 84 | ```json 85 | { 86 | "modelInputs": { 87 | "prompt": "Super dog", 88 | "num_inference_steps": 50, 89 | "guidance_scale": 7.5, 90 | "width": 512, 91 | "height": 512, 92 | "seed": 3239022079 93 | }, 94 | "callInputs": { 95 | // You can leave these out to use the default 96 | "MODEL_ID": "runwayml/stable-diffusion-v1-5", 97 | "PIPELINE": "StableDiffusionPipeline", 98 | "SCHEDULER": "LMSDiscreteScheduler", 99 | "safety_checker": true, 100 | }, 101 | } 102 | ``` 103 | 104 | It's important to remember that `docker-diffusers-api` is primarily a wrapper 105 | around HuggingFace's 106 | [diffusers](https://huggingface.co/docs/diffusers/index) library. 107 | **Basic familiarity with `diffusers` is indespensible for a good experience 108 | with `docker-diffusers-api`.** Explaining some of the options above: 109 | 110 | * **modelInputs** - for the most part - are passed directly to the selected 111 | diffusers pipeline unchanged. So, for the default `StableDiffusionPipeline`, 112 | you can see all options in the relevant pipeline docs for its 113 | [`__call__`](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/text2img#diffusers.StableDiffusionPipeline.__call__) method. The main exceptions are: 114 | 115 | * Only valid JSON values can be given (strings, numbers, etc) 116 | * **seed**, a number, is transformed into a `generator`. 117 | * **images** are converted to / from base64 encoded strings. 118 | 119 | * **callInputs** affect which model, pipeline, scheduler and other lower 120 | level options are used to construct the final pipeline. Notably: 121 | 122 | * **`SCHEDULER`**: any scheduler included in diffusers should work out 123 | the box, provided it can loaded with its default config and without 124 | requiring any other explicit arguments at init time. In any event, 125 | the following schedulers are the most common and most well tested: 126 | `DPMSolverMultistepScheduler` (fast! only needs 20 steps!), 127 | `LMSDiscreteScheduler`, `DDIMScheduler`, `PNDMScheduler`, 128 | `EulerAncestralDiscreteScheduler`, `EulerDiscreteScheduler`. 129 | 130 | * **`PIPELINE`**: the most common are 131 | [`StableDiffusionPipeline`](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/text2img), 132 | [`StableDiffusionImg2ImgPipeline`](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/img2img), 133 | [`StableDiffusionInpaintPipeline`](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/inpaint), and the community 134 | [`lpw_stable_diffusion`](https://forums.kiri.art/t/lpw-stable-diffusion-pipeline-longer-prompts-prompt-weights/82) 135 | which allows for long prompts (more than 77 tokens) and prompt weights 136 | (things like `((big eyes))`, `(red hair:1.2)`, etc), and accepts a 137 | `custom_pipeline_method` callInput with values `text2img` ("text", not "txt"), 138 | `img2img` and `inpaint`. See these links for all the possible `modelInputs`'s 139 | that can be passed to the pipeline's `__call__` method. 140 | 141 | * **`MODEL_URL`** (optional) can be used to retrieve the model from 142 | locations other than HuggingFace, e.g. an `HTTP` server, S3-compatible 143 | storage, etc. 
For more info, see the 144 | [storage docs](https://github.com/kiri-art/docker-diffusers-api/blob/dev/docs/storage.md) 145 | and 146 | [this post](https://forums.kiri.art/t/safetensors-our-own-optimization-faster-model-init/98) 147 | for info on how to use and store optimized models from your own cloud. 148 | 149 | 150 | ## Examples and testing 151 | 152 | There are also very basic examples in [test.py](./test.py), which you can view 153 | and call `python test.py` if the container is already running on port 8000. 154 | You can also specify a specific test, change some options, and run against a 155 | deployed banana image: 156 | 157 | ```bash 158 | $ python test.py 159 | Usage: python3 test.py [--banana] [--xmfe=1/0] [--scheduler=SomeScheduler] [all / test1] [test2] [etc] 160 | 161 | # Run against http://localhost:8000/ (Nvidia Quadro RTX 5000) 162 | $ python test.py txt2img 163 | Running test: txt2img 164 | Request took 5.9s (init: 3.2s, inference: 5.9s) 165 | Saved /home/dragon/www/banana/banana-sd-base/tests/output/txt2img.png 166 | 167 | # Run against deployed banana image (Nvidia A100) 168 | $ export BANANA_API_KEY=XXX 169 | $ BANANA_MODEL_KEY=XXX python3 test.py --banana txt2img 170 | Running test: txt2img 171 | Request took 19.4s (init: 2.5s, inference: 3.5s) 172 | Saved /home/dragon/www/banana/banana-sd-base/tests/output/txt2img.png 173 | 174 | # Note that 2nd runs are much faster (ignore init, that isn't run again) 175 | Request took 3.0s (init: 2.4s, inference: 2.1s) 176 | ``` 177 | 178 | The best example of course is https://kiri.art/ and it's 179 | [source code](https://github.com/kiri-art/stable-diffusion-react-nextjs-mui-pwa). 180 | 181 | ## Help on [Official Forums](https://forums.kiri.art/c/docker-diffusers-api/16). 182 | 183 | ## Adding other Models 184 | 185 | You have two options. 186 | 187 | 1. For a diffusers model, simply set `MODEL_ID` build-var / call-arg to the name 188 | of the model hosted on HuggingFace, and it will be downloaded automatically at 189 | build time. 190 | 191 | 1. For a non-diffusers model, simply set the `CHECKPOINT_URL` build-var / call-arg 192 | to the URL of a `.ckpt` file, which will be downloaded and converted to the diffusers 193 | format automatically at build time. `CHECKPOINT_CONFIG_URL` can also be set. 194 | 195 | ## Troubleshooting 196 | 197 | * **403 Client Error: Forbidden for url** 198 | 199 | Make sure you've accepted the license on the model card of the HuggingFace model 200 | specified in `MODEL_ID`, and that you correctly passed `HF_AUTH_TOKEN` to the 201 | container. 202 | 203 | ## Event logs / web hooks / performance data 204 | 205 | Set `SEND_URL` (and optionally `SIGN_KEY`) environment variable(s) to send 206 | event and timing data on `init`, `inference` and other start and end events. 207 | This can either be used to log performance data, or for webhooks on event 208 | start / finish. 209 | 210 | The timing data is now returned in the response payload too, like this: 211 | `{ $timings: { init: timeInMs, inference: timeInMs } }`, with any other 212 | events (such a `training`, `upload`, etc). 213 | 214 | You can go to https://webhook.site/ and use the provided "unique URL" 215 | as your `SEND_URL` to see how it works, if you don't have your own 216 | REST endpoint (yet). 
217 | 218 | If `SIGN_KEY` is used, you can verify the signature like this (TypeScript): 219 | 220 | ```ts 221 | import crypto from "crypto"; 222 | 223 | async function handler(req: NextApiRequest, res: NextApiResponse) { 224 | const data = req.body; 225 | 226 | const containerSig = data.sig as string; 227 | delete data.sig; 228 | 229 | const ourSig = crypto 230 | .createHash("md5") 231 | .update(JSON.stringify(data) + process.env.SIGN_KEY) 232 | .digest("hex"); 233 | 234 | const signatureIsValid = containerSig === ourSig; 235 | } 236 | ``` 237 | 238 | If you send a callInput called `startRequestId`, it will get sent 239 | back as part of the send payload in most cases. 240 | 241 | You can also set callInputs `SEND_URL` and `SIGN_KEY` to 242 | set or override these values on a per-request basis. 243 | 244 | ## Acknowledgements 245 | 246 | * The container image is originally based on 247 | https://github.com/bananaml/serverless-template-stable-diffusion. 248 | 249 | * [CompVis](https://github.com/CompVis), 250 | [Stability AI](https://stability.ai/), 251 | [LAION](https://laion.ai/) 252 | and [RunwayML](https://runwayml.com/) 253 | for their incredible time, work and efforts in creating Stable Diffusion, 254 | and no less so, their decision to release it publicly with an open source 255 | license. 256 | 257 | * [HuggingFace](https://huggingface.co/) - for their passion and inspiration 258 | for making machine learning more accessibe to developers, and in particular, 259 | their [Diffusers](https://github.com/huggingface/diffusers) library. 260 | -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kiri-art/docker-diffusers-api/5521b2e6d63ef7afa9c3f7f3ef6cf520c1b54d7b/__init__.py -------------------------------------------------------------------------------- /api/app.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | from sched import scheduler 3 | import torch 4 | 5 | from torch import autocast 6 | from diffusers import __version__ 7 | import base64 8 | from io import BytesIO 9 | import PIL 10 | import json 11 | from loadModel import loadModel 12 | from send import send, getTimings, clearSession 13 | from status import status 14 | import os 15 | import numpy as np 16 | import skimage 17 | import skimage.measure 18 | from getScheduler import getScheduler, SCHEDULERS 19 | from getPipeline import ( 20 | getPipelineClass, 21 | getPipelineForModel, 22 | listAvailablePipelines, 23 | clearPipelines, 24 | ) 25 | import re 26 | import requests 27 | from download import download_model, normalize_model_id 28 | import traceback 29 | from precision import MODEL_REVISION, MODEL_PRECISION 30 | from device import device, device_id, device_name 31 | from utils import Storage 32 | from hashlib import sha256 33 | from threading import Timer 34 | import extras 35 | import jxlpy 36 | from jxlpy import JXLImagePlugin 37 | 38 | 39 | from diffusers import ( 40 | StableDiffusionXLPipeline, 41 | StableDiffusionXLImg2ImgPipeline, 42 | StableDiffusionXLInpaintPipeline, 43 | pipelines as diffusers_pipelines, 44 | AutoencoderTiny, 45 | AutoencoderKL, 46 | ) 47 | 48 | from lib.textual_inversions import handle_textual_inversions 49 | from lib.prompts import prepare_prompts 50 | from lib.vars import ( 51 | RUNTIME_DOWNLOADS, 52 | USE_DREAMBOOTH, 53 | MODEL_ID, 54 | PIPELINE, 55 | HF_AUTH_TOKEN, 56 | HOME, 57 | 
MODELS_DIR, 58 | ) 59 | 60 | if USE_DREAMBOOTH: 61 | from train_dreambooth import TrainDreamBooth 62 | print(os.environ.get("USE_PATCHMATCH")) 63 | if os.environ.get("USE_PATCHMATCH") == "1": 64 | from PyPatchMatch import patch_match 65 | 66 | torch.set_grad_enabled(False) 67 | always_normalize_model_id = None 68 | 69 | tiny_vae = None 70 | 71 | 72 | # still working on this, not in use yet. 73 | def tinyVae(origVae: AutoencoderKL): 74 | global tiny_vae 75 | if not tiny_vae: 76 | tiny_vae = AutoencoderTiny.from_pretrained( 77 | "madebyollin/taesd", 78 | torch_dtype=torch.float16, 79 | in_channels=origVae.config.in_channels, 80 | out_channels=origVae.config.out_channels, 81 | act_fn=origVae.config.act_fn, 82 | latent_channels=origVae.config.latent_channels, 83 | scaling_factor=origVae.config.scaling_factor, 84 | force_upcast=origVae.config.force_upcast, 85 | ) 86 | tiny_vae.to("cuda") 87 | 88 | return tiny_vae 89 | 90 | 91 | # Init is ran on server startup 92 | # Load your model to GPU as a global variable here using the variable name "model" 93 | def init(): 94 | global model # needed for bananna optimizations 95 | global always_normalize_model_id 96 | 97 | asyncio.run( 98 | send( 99 | "init", 100 | "start", 101 | { 102 | "device": device_name, 103 | "hostname": os.getenv("HOSTNAME"), 104 | "model_id": MODEL_ID, 105 | "diffusers": __version__, 106 | }, 107 | ) 108 | ) 109 | 110 | if MODEL_ID == "ALL" or RUNTIME_DOWNLOADS: 111 | global last_model_id 112 | last_model_id = None 113 | 114 | if not RUNTIME_DOWNLOADS: 115 | normalized_model_id = normalize_model_id(MODEL_ID, MODEL_REVISION) 116 | model_dir = os.path.join(MODELS_DIR, normalized_model_id) 117 | if os.path.isdir(model_dir): 118 | always_normalize_model_id = model_dir 119 | else: 120 | normalized_model_id = MODEL_ID 121 | 122 | model = loadModel( 123 | model_id=always_normalize_model_id or MODEL_ID, 124 | load=True, 125 | precision=MODEL_PRECISION, 126 | revision=MODEL_REVISION, 127 | ) 128 | else: 129 | model = None 130 | 131 | asyncio.run(send("init", "done")) 132 | 133 | 134 | def decodeBase64Image(imageStr: str, name: str) -> PIL.Image: 135 | image = PIL.Image.open(BytesIO(base64.decodebytes(bytes(imageStr, "utf-8")))) 136 | print(f'Decoded image "{name}": {image.format} {image.width}x{image.height}') 137 | return image 138 | 139 | 140 | def getFromUrl(url: str, name: str) -> PIL.Image: 141 | response = requests.get(url) 142 | image = PIL.Image.open(BytesIO(response.content)) 143 | print(f'Decoded image "{name}": {image.format} {image.width}x{image.height}') 144 | return image 145 | 146 | 147 | def truncateInputs(inputs: dict): 148 | clone = inputs.copy() 149 | if "modelInputs" in clone: 150 | modelInputs = clone["modelInputs"] = clone["modelInputs"].copy() 151 | for item in ["init_image", "mask_image", "image", "input_image"]: 152 | if item in modelInputs: 153 | modelInputs[item] = modelInputs[item][0:6] + "..." 154 | if "instance_images" in modelInputs: 155 | modelInputs["instance_images"] = list( 156 | map(lambda str: str[0:6] + "...", modelInputs["instance_images"]) 157 | ) 158 | return clone 159 | 160 | 161 | # last_xformers_memory_efficient_attention = {} 162 | last_attn_procs = None 163 | last_lora_weights = None 164 | cross_attention_kwargs = None 165 | 166 | 167 | # Inference is ran for every server call 168 | # Reference your preloaded global model variable here. 
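# Broadly, the flow below is:
#   1. Split the request into "modelInputs" and "callInputs".
#   2. With RUNTIME_DOWNLOADS, download and load the requested model if it
#      differs from the one currently held in memory.
#   3. Resolve the pipeline and scheduler, then apply any requested textual
#      inversions and LoRA weights.
#   4. Decode base64 input images, run the pipeline (or dreambooth training),
#      and return base64-encoded images along with timing / memory stats.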
169 | async def inference(all_inputs: dict, response) -> dict: 170 | global model 171 | global pipelines 172 | global last_model_id 173 | global schedulers 174 | # global last_xformers_memory_efficient_attention 175 | global always_normalize_model_id 176 | global last_attn_procs 177 | global last_lora_weights 178 | global cross_attention_kwargs 179 | 180 | clearSession() 181 | 182 | print(json.dumps(truncateInputs(all_inputs), indent=2)) 183 | model_inputs = all_inputs.get("modelInputs", None) 184 | call_inputs = all_inputs.get("callInputs", None) 185 | result = {"$meta": {}} 186 | 187 | send_opts = {} 188 | if call_inputs.get("SEND_URL", None): 189 | send_opts.update({"SEND_URL": call_inputs.get("SEND_URL")}) 190 | if call_inputs.get("SIGN_KEY", None): 191 | send_opts.update({"SIGN_KEY": call_inputs.get("SIGN_KEY")}) 192 | if response: 193 | send_opts.update({"response": response}) 194 | 195 | async def sendStatusAsync(): 196 | await response.send(json.dumps(status.get()) + "\n") 197 | 198 | def sendStatus(): 199 | try: 200 | asyncio.run(sendStatusAsync()) 201 | Timer(1.0, sendStatus).start() 202 | except: 203 | pass 204 | 205 | Timer(1.0, sendStatus).start() 206 | 207 | if model_inputs == None or call_inputs == None: 208 | return { 209 | "$error": { 210 | "code": "INVALID_INPUTS", 211 | "message": "Expecting on object like { modelInputs: {}, callInputs: {} } but got " 212 | + json.dumps(all_inputs), 213 | } 214 | } 215 | 216 | startRequestId = call_inputs.get("startRequestId", None) 217 | 218 | use_extra = call_inputs.get("use_extra", None) 219 | if use_extra: 220 | extra = getattr(extras, use_extra, None) 221 | if not extra: 222 | return { 223 | "$error": { 224 | "code": "NO_SUCH_EXTRA", 225 | "message": 'Requested "' 226 | + use_extra 227 | + '", available: "' 228 | + '", "'.join(extras.keys()) 229 | + '"', 230 | } 231 | } 232 | return await extra( 233 | model_inputs, 234 | call_inputs, 235 | send_opts=send_opts, 236 | startRequestId=startRequestId, 237 | ) 238 | 239 | model_id = call_inputs.get("MODEL_ID", None) 240 | if not model_id: 241 | if not MODEL_ID: 242 | return { 243 | "$error": { 244 | "code": "NO_MODEL_ID", 245 | "message": "No callInputs.MODEL_ID specified, nor was MODEL_ID env var set.", 246 | } 247 | } 248 | model_id = MODEL_ID 249 | result["$meta"].update({"MODEL_ID": MODEL_ID}) 250 | normalized_model_id = model_id 251 | 252 | if RUNTIME_DOWNLOADS: 253 | hf_model_id = call_inputs.get("HF_MODEL_ID", None) 254 | model_revision = call_inputs.get("MODEL_REVISION", None) 255 | model_precision = call_inputs.get("MODEL_PRECISION", None) 256 | checkpoint_url = call_inputs.get("CHECKPOINT_URL", None) 257 | checkpoint_config_url = call_inputs.get("CHECKPOINT_CONFIG_URL", None) 258 | normalized_model_id = normalize_model_id(model_id, model_revision) 259 | model_dir = os.path.join(MODELS_DIR, normalized_model_id) 260 | pipeline_name = call_inputs.get("PIPELINE", None) 261 | if pipeline_name: 262 | pipeline_class = getPipelineClass(pipeline_name) 263 | if last_model_id != normalized_model_id: 264 | # if not downloaded_models.get(normalized_model_id, None): 265 | if not os.path.isdir(model_dir): 266 | model_url = call_inputs.get("MODEL_URL", None) 267 | if not model_url: 268 | # return { 269 | # "$error": { 270 | # "code": "NO_MODEL_URL", 271 | # "message": "Currently RUNTIME_DOWNOADS requires a MODEL_URL callInput", 272 | # } 273 | # } 274 | normalized_model_id = hf_model_id or model_id 275 | await download_model( 276 | model_id=model_id, 277 | model_url=model_url, 278 | 
model_revision=model_revision, 279 | checkpoint_url=checkpoint_url, 280 | checkpoint_config_url=checkpoint_config_url, 281 | hf_model_id=hf_model_id, 282 | model_precision=model_precision, 283 | send_opts=send_opts, 284 | pipeline_class=pipeline_class if pipeline_name else None, 285 | ) 286 | # downloaded_models.update({normalized_model_id: True}) 287 | clearPipelines() 288 | cross_attention_kwargs = None 289 | if model: 290 | model.to("cpu") # Necessary to avoid a memory leak 291 | await send( 292 | "loadModel", "start", {"startRequestId": startRequestId}, send_opts 293 | ) 294 | model = await asyncio.to_thread( 295 | loadModel, 296 | model_id=normalized_model_id, 297 | load=True, 298 | precision=model_precision, 299 | revision=model_revision, 300 | send_opts=send_opts, 301 | pipeline_class=pipeline_class if pipeline_name else None, 302 | ) 303 | await send( 304 | "loadModel", "done", {"startRequestId": startRequestId}, send_opts 305 | ) 306 | last_model_id = normalized_model_id 307 | last_attn_procs = None 308 | last_lora_weights = None 309 | else: 310 | if always_normalize_model_id: 311 | normalized_model_id = always_normalize_model_id 312 | print( 313 | { 314 | "always_normalize_model_id": always_normalize_model_id, 315 | "normalized_model_id": normalized_model_id, 316 | } 317 | ) 318 | 319 | if MODEL_ID == "ALL": 320 | if last_model_id != normalized_model_id: 321 | clearPipelines() 322 | cross_attention_kwargs = None 323 | model = loadModel(normalized_model_id, send_opts=send_opts) 324 | last_model_id = normalized_model_id 325 | else: 326 | if model_id != MODEL_ID and not RUNTIME_DOWNLOADS: 327 | return { 328 | "$error": { 329 | "code": "MODEL_MISMATCH", 330 | "message": f'Model "{model_id}" not available on this container which hosts "{MODEL_ID}"', 331 | "requested": model_id, 332 | "available": MODEL_ID, 333 | } 334 | } 335 | 336 | if PIPELINE == "ALL": 337 | pipeline_name = call_inputs.get("PIPELINE", None) 338 | if not pipeline_name: 339 | pipeline_name = "AutoPipelineForText2Image" 340 | result["$meta"].update({"PIPELINE": pipeline_name}) 341 | 342 | pipeline = getPipelineForModel( 343 | pipeline_name, 344 | model, 345 | normalized_model_id, 346 | model_revision=model_revision if RUNTIME_DOWNLOADS else MODEL_REVISION, 347 | model_precision=model_precision if RUNTIME_DOWNLOADS else MODEL_PRECISION, 348 | ) 349 | if not pipeline: 350 | return { 351 | "$error": { 352 | "code": "NO_SUCH_PIPELINE", 353 | "message": f'"{pipeline_name}" is not an official nor community Diffusers pipelines', 354 | "requested": pipeline_name, 355 | "available": listAvailablePipelines(), 356 | } 357 | } 358 | else: 359 | pipeline = model 360 | 361 | scheduler_name = call_inputs.get("SCHEDULER", None) 362 | if not scheduler_name: 363 | scheduler_name = "DPMSolverMultistepScheduler" 364 | result["$meta"].update({"SCHEDULER": scheduler_name}) 365 | 366 | pipeline.scheduler = getScheduler(normalized_model_id, scheduler_name) 367 | if pipeline.scheduler == None: 368 | return { 369 | "$error": { 370 | "code": "INVALID_SCHEDULER", 371 | "message": "", 372 | "requeted": call_inputs.get("SCHEDULER", None), 373 | "available": ", ".join(SCHEDULERS), 374 | } 375 | } 376 | 377 | safety_checker = call_inputs.get("safety_checker", True) 378 | pipeline.safety_checker = ( 379 | model.safety_checker 380 | if safety_checker and hasattr(model, "safety_checker") 381 | else None 382 | ) 383 | is_url = call_inputs.get("is_url", False) 384 | image_decoder = getFromUrl if is_url else decodeBase64Image 385 | 386 | 
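    # Any textual inversions requested in callInputs (a list) are applied to
    # the model before the pipeline runs.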
textual_inversions = call_inputs.get("textual_inversions", []) 387 | await handle_textual_inversions(textual_inversions, model, status=status) 388 | 389 | # Better to use new lora_weights in next section 390 | attn_procs = call_inputs.get("attn_procs", None) 391 | if attn_procs is not last_attn_procs: 392 | if attn_procs: 393 | raise Exception( 394 | "[REMOVED] Using `attn_procs` for LoRAs is no longer supported. " 395 | + "Please use `lora_weights` instead." 396 | ) 397 | last_attn_procs = attn_procs 398 | # if attn_procs: 399 | # storage = Storage(attn_procs, no_raise=True) 400 | # if storage: 401 | # hash = sha256(attn_procs.encode("utf-8")).hexdigest() 402 | # attn_procs_from_safetensors = call_inputs.get( 403 | # "attn_procs_from_safetensors", None 404 | # ) 405 | # fname = storage.url.split("/").pop() 406 | # if attn_procs_from_safetensors and not re.match( 407 | # r".safetensors", attn_procs 408 | # ): 409 | # fname += ".safetensors" 410 | # if True: 411 | # # TODO, way to specify explicit name 412 | # path = os.path.join( 413 | # MODELS_DIR, "attn_proc--url_" + hash[:7] + "--" + fname 414 | # ) 415 | # attn_procs = path 416 | # if not os.path.exists(path): 417 | # storage.download_and_extract(path) 418 | # print("Load attn_procs " + attn_procs) 419 | # # Workaround https://github.com/huggingface/diffusers/pull/2448#issuecomment-1453938119 420 | # if storage and not re.search(r".safetensors", attn_procs): 421 | # attn_procs = torch.load(attn_procs, map_location="cpu") 422 | # pipeline.unet.load_attn_procs(attn_procs) 423 | # else: 424 | # print("Clearing attn procs") 425 | # pipeline.unet.set_attn_processor(CrossAttnProcessor()) 426 | 427 | # Currently we only support a single string, but we should allow 428 | # and array too in anticipation of multi-LoRA support in diffusers 429 | # tracked at https://github.com/huggingface/diffusers/issues/2613. 
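    # `lora_weights` is expected to be a URL to a LoRA file (typically
    # .safetensors). Two optional query parameters are honoured below:
    # `scale` (LoRA strength, forwarded via cross_attention_kwargs) and
    # `fname` (explicit cache filename). Files are cached under MODELS_DIR
    # with a "lora_weights--" prefix.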
430 | lora_weights = call_inputs.get("lora_weights", None) 431 | lora_weights_joined = json.dumps(lora_weights) 432 | if last_lora_weights != lora_weights_joined: 433 | if last_lora_weights != None and last_lora_weights != "[]": 434 | print("Unloading previous LoRA weights") 435 | pipeline.unload_lora_weights() 436 | 437 | last_lora_weights = lora_weights_joined 438 | cross_attention_kwargs = {} 439 | 440 | if type(lora_weights) is not list: 441 | lora_weights = [lora_weights] if lora_weights else [] 442 | 443 | if len(lora_weights) > 0: 444 | for weights in lora_weights: 445 | storage = Storage(weights, no_raise=True, status=status) 446 | if storage: 447 | storage_query_fname = storage.query.get("fname") 448 | storage_query_scale = ( 449 | float(storage.query.get("scale")[0]) 450 | if storage.query.get("scale") 451 | else 1 452 | ) 453 | cross_attention_kwargs.update({"scale": storage_query_scale}) 454 | # https://github.com/damian0815/compel/issues/42#issuecomment-1656989385 455 | pipeline._lora_scale = storage_query_scale 456 | if storage_query_fname: 457 | fname = storage_query_fname[0] 458 | else: 459 | hash = sha256(weights.encode("utf-8")).hexdigest() 460 | fname = "url_" + hash[:7] + "--" + storage.url.split("/").pop() 461 | cache_fname = "lora_weights--" + fname 462 | path = os.path.join(MODELS_DIR, cache_fname) 463 | if not os.path.exists(path): 464 | await asyncio.to_thread(storage.download_file, path) 465 | print("Load lora_weights `" + weights + "` from `" + path + "`") 466 | pipeline.load_lora_weights( 467 | MODELS_DIR, weight_name=cache_fname, local_files_only=True 468 | ) 469 | else: 470 | print("Loading from huggingface not supported yet: " + weights) 471 | # maybe something like sayakpaul/civitai-light-shadow-lora#lora=l_a_s.s9s? 472 | # lora_model_id = "sayakpaul/civitai-light-shadow-lora" 473 | # lora_filename = "light_and_shadow.safetensors" 474 | # pipeline.load_lora_weights(lora_model_id, weight_name=lora_filename) 475 | else: 476 | print("No changes to LoRAs since last call") 477 | 478 | # TODO, generalize 479 | mi_cross_attention_kwargs = model_inputs.get("cross_attention_kwargs", None) 480 | if mi_cross_attention_kwargs: 481 | model_inputs.pop("cross_attention_kwargs") 482 | if isinstance(mi_cross_attention_kwargs, str): 483 | if not cross_attention_kwargs: 484 | cross_attention_kwargs = {} 485 | cross_attention_kwargs.update(json.loads(mi_cross_attention_kwargs)) 486 | elif type(mi_cross_attention_kwargs) == dict: 487 | if not cross_attention_kwargs: 488 | cross_attention_kwargs = {} 489 | cross_attention_kwargs.update(mi_cross_attention_kwargs) 490 | else: 491 | return { 492 | "$error": { 493 | "code": "INVALID_CROSS_ATTENTION_KWARGS", 494 | "message": "`cross_attention_kwargs` should be a dict or json string", 495 | } 496 | } 497 | 498 | print({"cross_attention_kwargs": cross_attention_kwargs}) 499 | if cross_attention_kwargs: 500 | model_inputs.update({"cross_attention_kwargs": cross_attention_kwargs}) 501 | 502 | # Parse out your arguments 503 | # prompt = model_inputs.get("prompt", None) 504 | # if prompt == None: 505 | # return {"message": "No prompt provided"} 506 | # 507 | # height = model_inputs.get("height", 512) 508 | # width = model_inputs.get("width", 512) 509 | # num_inference_steps = model_inputs.get("num_inference_steps", 50) 510 | # guidance_scale = model_inputs.get("guidance_scale", 7.5) 511 | # seed = model_inputs.get("seed", None) 512 | # strength = model_inputs.get("strength", 0.75) 513 | 514 | if "init_image" in model_inputs: 515 | 
model_inputs["init_image"] = image_decoder( 516 | model_inputs.get("init_image"), "init_image" 517 | ) 518 | 519 | if "image" in model_inputs: 520 | model_inputs["image"] = image_decoder(model_inputs.get("image"), "image") 521 | 522 | if "mask_image" in model_inputs: 523 | model_inputs["mask_image"] = image_decoder( 524 | model_inputs.get("mask_image"), "mask_image" 525 | ) 526 | 527 | if "instance_images" in model_inputs: 528 | model_inputs["instance_images"] = list( 529 | map( 530 | lambda str: image_decoder(str, "instance_image"), 531 | model_inputs["instance_images"], 532 | ) 533 | ) 534 | 535 | await send("inference", "start", {"startRequestId": startRequestId}, send_opts) 536 | 537 | # Run patchmatch for inpainting 538 | if call_inputs.get("FILL_MODE", None) == "patchmatch": 539 | sel_buffer = np.array(model_inputs.get("init_image")) 540 | img = sel_buffer[:, :, 0:3] 541 | mask = sel_buffer[:, :, -1] 542 | img = patch_match.inpaint(img, mask=255 - mask, patch_size=3) 543 | model_inputs["init_image"] = PIL.Image.fromarray(img) 544 | mask = 255 - mask 545 | mask = skimage.measure.block_reduce(mask, (8, 8), np.max) 546 | mask = mask.repeat(8, axis=0).repeat(8, axis=1) 547 | model_inputs["mask_image"] = PIL.Image.fromarray(mask) 548 | 549 | # Turning on takes 3ms and turning off 1ms... don't worry, I've got your back :) 550 | # x_m_e_a = call_inputs.get("xformers_memory_efficient_attention", True) 551 | # last_x_m_e_a = last_xformers_memory_efficient_attention.get(pipeline, None) 552 | # if x_m_e_a != last_x_m_e_a: 553 | # if x_m_e_a == True: 554 | # print("pipeline.enable_xformers_memory_efficient_attention()") 555 | # pipeline.enable_xformers_memory_efficient_attention() # default on 556 | # elif x_m_e_a == False: 557 | # print("pipeline.disable_xformers_memory_efficient_attention()") 558 | # pipeline.disable_xformers_memory_efficient_attention() 559 | # else: 560 | # return { 561 | # "$error": { 562 | # "code": "INVALID_XFORMERS_MEMORY_EFFICIENT_ATTENTION_VALUE", 563 | # "message": f"x_m_e_a expects True or False, not: {x_m_e_a}", 564 | # "requested": x_m_e_a, 565 | # "available": [True, False], 566 | # } 567 | # } 568 | # last_xformers_memory_efficient_attention.update({pipeline: x_m_e_a}) 569 | 570 | # Run the model 571 | # with autocast(device_id): 572 | # image = pipeline(**model_inputs).images[0] 573 | 574 | if call_inputs.get("train", None) == "dreambooth": 575 | if not USE_DREAMBOOTH: 576 | return { 577 | "$error": { 578 | "code": "TRAIN_DREAMBOOTH_NOT_AVAILABLE", 579 | "message": 'Called with callInput { train: "dreambooth" } but built with USE_DREAMBOOTH=0', 580 | } 581 | } 582 | 583 | if RUNTIME_DOWNLOADS: 584 | if os.path.isdir(model_dir): 585 | normalized_model_id = model_dir 586 | 587 | torch.set_grad_enabled(True) 588 | result = result | await asyncio.to_thread( 589 | TrainDreamBooth, 590 | normalized_model_id, 591 | pipeline, 592 | model_inputs, 593 | call_inputs, 594 | send_opts=send_opts, 595 | ) 596 | torch.set_grad_enabled(False) 597 | await send("inference", "done", {"startRequestId": startRequestId}, send_opts) 598 | result.update({"$timings": getTimings()}) 599 | return result 600 | 601 | # Do this after dreambooth as dreambooth accepts a seed int directly. 
602 | seed = model_inputs.get("seed", None) 603 | if seed == None: 604 | generator = torch.Generator(device=device) 605 | generator.seed() 606 | else: 607 | generator = torch.Generator(device=device).manual_seed(seed) 608 | del model_inputs["seed"] 609 | 610 | model_inputs.update({"generator": generator}) 611 | 612 | callback = None 613 | if model_inputs.get("callback_steps", None): 614 | 615 | def callback(step: int, timestep: int, latents: torch.FloatTensor): 616 | asyncio.run( 617 | send( 618 | "inference", 619 | "progress", 620 | {"startRequestId": startRequestId, "step": step}, 621 | send_opts, 622 | ) 623 | ) 624 | 625 | else: 626 | vae = pipeline.vae 627 | # vae = tinyVae(vae) 628 | scaling_factor = vae.config.scaling_factor 629 | image_processor = pipeline.image_processor 630 | 631 | def callback(step: int, timestep: int, latents: torch.FloatTensor): 632 | status.update( 633 | "inference", step / model_inputs.get("num_inference_steps", 50) 634 | ) 635 | 636 | # with torch.no_grad(): 637 | # image = vae.decode(latents / scaling_factor, return_dict=False)[0] 638 | # image = image_processor.postprocess(image, output_type="pil")[0] 639 | # image.save(f"step_{step}_img0.png") 640 | 641 | is_sdxl = ( 642 | isinstance(model, StableDiffusionXLPipeline) 643 | or isinstance(model, StableDiffusionXLImg2ImgPipeline) 644 | or isinstance(model, StableDiffusionXLInpaintPipeline) 645 | ) 646 | 647 | with torch.inference_mode(): 648 | custom_pipeline_method = call_inputs.get("custom_pipeline_method", None) 649 | print( 650 | { 651 | "callback": callback, 652 | "**model_inputs": model_inputs, 653 | }, 654 | ) 655 | 656 | if call_inputs.get("compel_prompts", False): 657 | prepare_prompts(pipeline, model_inputs, is_sdxl) 658 | 659 | try: 660 | async_pipeline = asyncio.to_thread( 661 | getattr(pipeline, custom_pipeline_method) 662 | if custom_pipeline_method 663 | else pipeline, 664 | callback=callback, 665 | **model_inputs, 666 | ) 667 | # if call_inputs.get("PIPELINE") != "StableDiffusionPipeline": 668 | # # autocast im2img and inpaint which are broken in 0.4.0, 0.4.1 669 | # # still broken in 0.5.1 670 | # with autocast(device_id): 671 | # images = (await async_pipeline).images 672 | # else: 673 | pipeResult = await async_pipeline 674 | images = pipeResult.images 675 | 676 | except Exception as err: 677 | return { 678 | "$error": { 679 | "code": "PIPELINE_ERROR", 680 | "name": type(err).__name__, 681 | "message": str(err), 682 | "stack": traceback.format_exc(), 683 | } 684 | } 685 | 686 | images_base64 = [] 687 | image_format = call_inputs.get("image_format", "PNG") 688 | image_opts = ( 689 | {"lossless": True} if image_format == "PNG" or image_format == "WEBP" else {} 690 | ) 691 | for image in images: 692 | buffered = BytesIO() 693 | image.save(buffered, format=image_format, **image_opts) 694 | images_base64.append(base64.b64encode(buffered.getvalue()).decode("utf-8")) 695 | 696 | await send("inference", "done", {"startRequestId": startRequestId}, send_opts) 697 | 698 | # Return the results as a dictionary 699 | if len(images_base64) > 1: 700 | result = result | {"images_base64": images_base64} 701 | else: 702 | result = result | {"image_base64": images_base64[0]} 703 | 704 | nsfw_content_detected = pipeResult.get("nsfw_content_detected", None) 705 | if nsfw_content_detected: 706 | result = result | {"nsfw_content_detected": nsfw_content_detected} 707 | 708 | # TODO, move and generalize in device.py 709 | mem_usage = 0 710 | if torch.cuda.is_available(): 711 | mem_usage = 
torch.cuda.memory_allocated() / torch.cuda.max_memory_allocated() 712 | 713 | result = result | {"$timings": getTimings(), "$mem_usage": mem_usage} 714 | 715 | return result 716 | -------------------------------------------------------------------------------- /api/convert_to_diffusers.py: -------------------------------------------------------------------------------- 1 | import os 2 | import requests 3 | import subprocess 4 | import torch 5 | import json 6 | from diffusers.pipelines.stable_diffusion.convert_from_ckpt import ( 7 | download_from_original_stable_diffusion_ckpt, 8 | ) 9 | from diffusers.pipelines.stable_diffusion import ( 10 | StableDiffusionInpaintPipeline, 11 | ) 12 | from utils import Storage 13 | from device import device_id 14 | 15 | MODEL_ID = os.environ.get("MODEL_ID", None) 16 | CHECKPOINT_DIR = "/root/.cache/checkpoints" 17 | CHECKPOINT_URL = os.environ.get("CHECKPOINT_URL", None) 18 | CHECKPOINT_CONFIG_URL = os.environ.get("CHECKPOINT_CONFIG_URL", None) 19 | CHECKPOINT_ARGS = os.environ.get("CHECKPOINT_ARGS", None) 20 | # _CONVERT_SPECIAL = os.environ.get("_CONVERT_SPECIAL", None) 21 | 22 | 23 | def main( 24 | model_id: str, 25 | checkpoint_url: str, 26 | checkpoint_config_url: str, 27 | checkpoint_args: dict = {}, 28 | path=None, 29 | ): 30 | if not path: 31 | fname = checkpoint_url.split("/").pop() 32 | path = os.path.join(CHECKPOINT_DIR, fname) 33 | 34 | if checkpoint_config_url and checkpoint_config_url != "": 35 | storage = Storage(checkpoint_config_url) 36 | configPath = CHECKPOINT_DIR + "/" + path + "_config.yaml" 37 | print(f"Downloading {checkpoint_config_url} to {configPath}...") 38 | storage.download_file(configPath) 39 | 40 | # specialSrc = "https://raw.githubusercontent.com/hafriedlander/diffusers/stable_diffusion_2/scripts/convert_original_stable_diffusion_to_diffusers.py" 41 | # specialPath = CHECKPOINT_DIR + "/" + "convert_special.py" 42 | # if _CONVERT_SPECIAL: 43 | # storage = Storage(specialSrc) 44 | # print(f"Downloading {specialSrc} to {specialPath}") 45 | # storage.download_file(specialPath) 46 | 47 | # scriptPath = ( 48 | # # specialPath 49 | # # if _CONVERT_SPECIAL 50 | # # else 51 | # "./diffusers/scripts/convert_original_stable_diffusion_to_diffusers.py" 52 | # ) 53 | 54 | print("Converting " + path + " to diffusers model " + model_id + "...", flush=True) 55 | 56 | # These are now in main requirements.txt. 57 | # subprocess.run( 58 | # ["pip", "install", "omegaconf", "pytorch_lightning", "tensorboard"], check=True 59 | # ) 60 | # Diffusers now uses requests instead, yay! 61 | # subprocess.run(["apt-get", "install", "-y", "wget"], check=True) 62 | 63 | # We can now specify this ourselves and don't need to modify the script. 64 | # if device_id == "cpu": 65 | # subprocess.run( 66 | # [ 67 | # "sed", 68 | # "-i", 69 | # # Force loading into CPU 70 | # "s/torch.load(args.checkpoint_path)/torch.load(args.checkpoint_path, map_location=torch.device('cpu'))/", 71 | # scriptPath, 72 | # ] 73 | # ) 74 | # # Nice to check but also there seems to be a race condition here which 75 | # # needs further investigation. Python docs are clear that subprocess.run() 76 | # # will "Wait for command to complete, then return a CompletedProcess instance." 77 | # # But it really seems as though without the grep in the middle, the script is 78 | # # run before sed completes, or maybe there's some FS level caching gotchas. 
79 | # subprocess.run( 80 | # [ 81 | # "grep", 82 | # "torch.load", 83 | # scriptPath, 84 | # ], 85 | # check=True, 86 | # ) 87 | 88 | # args = [ 89 | # "python3", 90 | # scriptPath, 91 | # "--extract_ema", 92 | # "--checkpoint_path", 93 | # fname, 94 | # "--dump_path", 95 | # model_id, 96 | # ] 97 | 98 | # if checkpoint_config_url: 99 | # args.append("--original_config_file") 100 | # args.append(configPath) 101 | 102 | # subprocess.run( 103 | # args, 104 | # check=True, 105 | # ) 106 | 107 | # Oh yay! Diffusers abstracted this now, so much easier to use. 108 | # But less tested. Changed on 2023-02-18. TODO, remove commented 109 | # out code above once this has more usage. 110 | 111 | # diffusers defaults 112 | args = { 113 | "scheduler_type": "pndm", 114 | } 115 | 116 | # our defaults 117 | args.update( 118 | { 119 | "checkpoint_path_or_dict": path, 120 | "original_config_file": configPath if checkpoint_config_url else None, 121 | "device": device_id, 122 | "extract_ema": True, 123 | "from_safetensors": "safetensor" in path.lower(), 124 | } 125 | ) 126 | 127 | if "inpaint" in path or "Inpaint" in path: 128 | args.update({"pipeline_class": StableDiffusionInpaintPipeline}) 129 | 130 | # user overrides 131 | args.update(checkpoint_args) 132 | 133 | pipe = download_from_original_stable_diffusion_ckpt(**args) 134 | pipe.save_pretrained(model_id, safe_serialization=True) 135 | 136 | 137 | if __name__ == "__main__": 138 | # response = requests.get( 139 | # "https://github.com/huggingface/diffusers/raw/main/scripts/convert_original_stable_diffusion_to_diffusers.py" 140 | # ) 141 | # open("convert_original_stable_diffusion_to_diffusers.py", "wb").write( 142 | # response.content 143 | # ) 144 | 145 | if CHECKPOINT_URL and CHECKPOINT_URL != "": 146 | checkpoint_args = json.loads(CHECKPOINT_ARGS) if CHECKPOINT_ARGS else {} 147 | main( 148 | MODEL_ID, 149 | CHECKPOINT_URL, 150 | CHECKPOINT_CONFIG_URL, 151 | checkpoint_args=checkpoint_args, 152 | ) 153 | -------------------------------------------------------------------------------- /api/device.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | if torch.cuda.is_available(): 4 | print("[device] CUDA (Nvidia) detected") 5 | device_id = "cuda" 6 | device_name = torch.cuda.get_device_name() 7 | elif torch.backends.mps.is_available(): 8 | print("[device] MPS (MacOS Metal, Apple M1, etc) detected") 9 | device_id = "mps" 10 | device_name = "MPS" 11 | else: 12 | print("[device] CPU only - no GPU detected") 13 | device_id = "cpu" 14 | device_name = "CPU only" 15 | 16 | if not torch.backends.cuda.is_built(): 17 | print( 18 | "CUDA not available because the current PyTorch install was not " 19 | "built with CUDA enabled." 20 | ) 21 | if torch.backends.mps.is_built(): 22 | print( 23 | "MPS not available because the current MacOS version is not 12.3+ " 24 | "and/or you do not have an MPS-enabled device on this machine." 25 | ) 26 | else: 27 | print( 28 | "MPS not available because the current PyTorch install was not " 29 | "built with MPS enabled." 
30 | ) 31 | 32 | device = torch.device(device_id) 33 | -------------------------------------------------------------------------------- /api/download.py: -------------------------------------------------------------------------------- 1 | # In this file, we define download_model 2 | # It runs during container build time to get model weights built into the container 3 | 4 | import os 5 | from loadModel import loadModel, MODEL_IDS 6 | from diffusers import AutoencoderKL, UNet2DConditionModel, DDPMScheduler 7 | from transformers import CLIPTextModel, CLIPTokenizer 8 | from utils import Storage 9 | import subprocess 10 | from pathlib import Path 11 | import shutil 12 | from convert_to_diffusers import main as convert_to_diffusers 13 | from download_checkpoint import main as download_checkpoint 14 | from status import status 15 | import asyncio 16 | 17 | USE_DREAMBOOTH = os.environ.get("USE_DREAMBOOTH") 18 | HF_AUTH_TOKEN = os.environ.get("HF_AUTH_TOKEN") 19 | RUNTIME_DOWNLOADS = os.environ.get("RUNTIME_DOWNLOADS") 20 | 21 | HOME = os.path.expanduser("~") 22 | MODELS_DIR = os.path.join(HOME, ".cache", "diffusers-api") 23 | Path(MODELS_DIR).mkdir(parents=True, exist_ok=True) 24 | 25 | 26 | # i.e. don't run during build 27 | async def send(type: str, status: str, payload: dict = {}, send_opts: dict = {}): 28 | if RUNTIME_DOWNLOADS: 29 | from send import send as _send 30 | 31 | await _send(type, status, payload, send_opts) 32 | 33 | 34 | def normalize_model_id(model_id: str, model_revision): 35 | normalized_model_id = "models--" + model_id.replace("/", "--") 36 | if model_revision: 37 | normalized_model_id += "--" + model_revision 38 | return normalized_model_id 39 | 40 | 41 | async def download_model( 42 | model_url=None, 43 | model_id=None, 44 | model_revision=None, 45 | checkpoint_url=None, 46 | checkpoint_config_url=None, 47 | hf_model_id=None, 48 | model_precision=None, 49 | send_opts={}, 50 | pipeline_class=None, 51 | ): 52 | print( 53 | "download_model", 54 | { 55 | "model_url": model_url, 56 | "model_id": model_id, 57 | "model_revision": model_revision, 58 | "hf_model_id": hf_model_id, 59 | "checkpoint_url": checkpoint_url, 60 | "checkpoint_config_url": checkpoint_config_url, 61 | }, 62 | ) 63 | hf_model_id = hf_model_id or model_id 64 | normalized_model_id = model_id 65 | 66 | # if model_url != "": # throws an error, useful to debug stdout/stderr order 67 | if model_url: 68 | normalized_model_id = normalize_model_id(model_id, model_revision) 69 | print({"normalized_model_id": normalized_model_id}) 70 | filename = model_url.split("/").pop() 71 | if not filename: 72 | filename = normalized_model_id + ".tar.zst" 73 | model_file = os.path.join(MODELS_DIR, filename) 74 | storage = Storage( 75 | model_url, default_path=normalized_model_id + ".tar.zst", status=status 76 | ) 77 | exists = storage.file_exists() 78 | if exists: 79 | model_dir = os.path.join(MODELS_DIR, normalized_model_id) 80 | print("model_dir", model_dir) 81 | await asyncio.to_thread(storage.download_and_extract, model_file, model_dir) 82 | else: 83 | if checkpoint_url: 84 | path = download_checkpoint(checkpoint_url) 85 | convert_to_diffusers( 86 | model_id=model_id, 87 | checkpoint_url=checkpoint_url, 88 | checkpoint_config_url=checkpoint_config_url, 89 | path=path, 90 | ) 91 | else: 92 | print("Does not exist, let's try find it on huggingface") 93 | print( 94 | { 95 | "model_precision": model_precision, 96 | "model_revision": model_revision, 97 | } 98 | ) 99 | # This would be quicker to just model.to(device) afterwards, but 100 
| # this conveniently logs all the timings (and doesn't happen often) 101 | print("download") 102 | await send("download", "start", {}, send_opts) 103 | model = loadModel( 104 | hf_model_id, 105 | False, 106 | precision=model_precision, 107 | revision=model_revision, 108 | pipeline_class=pipeline_class, 109 | ) # download 110 | await send("download", "done", {}, send_opts) 111 | 112 | print("load") 113 | model = loadModel( 114 | hf_model_id, 115 | True, 116 | precision=model_precision, 117 | revision=model_revision, 118 | pipeline_class=pipeline_class, 119 | ) # load 120 | # dir = "models--" + model_id.replace("/", "--") + "--dda" 121 | dir = os.path.join(MODELS_DIR, normalized_model_id) 122 | model.save_pretrained(dir, safe_serialization=True) 123 | 124 | # This is all duped from train_dreambooth, need to refactor TODO XXX 125 | await send("compress", "start", {}, send_opts) 126 | subprocess.run( 127 | f"tar cvf - -C {dir} . | zstd -o {model_file}", 128 | shell=True, 129 | check=True, # TODO, rather don't raise and return an error in JSON 130 | ) 131 | 132 | await send("compress", "done", {}, send_opts) 133 | subprocess.run(["ls", "-l", model_file]) 134 | 135 | await send("upload", "start", {}, send_opts) 136 | upload_result = storage.upload_file(model_file, filename) 137 | await send("upload", "done", {}, send_opts) 138 | print(upload_result) 139 | os.remove(model_file) 140 | 141 | # leave model dir for future loads... make configurable? 142 | # shutil.rmtree(dir) 143 | 144 | # TODO, swap directories, inside HF's cache structure. 145 | 146 | else: 147 | if checkpoint_url: 148 | path = download_checkpoint(checkpoint_url) 149 | convert_to_diffusers( 150 | model_id=model_id, 151 | checkpoint_url=checkpoint_url, 152 | checkpoint_config_url=checkpoint_config_url, 153 | path=path, 154 | ) 155 | else: 156 | # do a dry run of loading the huggingface model, which will download weights at build time 157 | loadModel( 158 | model_id=hf_model_id, 159 | load=False, 160 | precision=model_precision, 161 | revision=model_revision, 162 | pipeline_class=pipeline_class, 163 | ) 164 | 165 | # if USE_DREAMBOOTH: 166 | # Actually we can re-use these from the above loaded model 167 | # Will remove this soon if no more surprises 168 | # for subfolder, model in [ 169 | # ["tokenizer", CLIPTokenizer], 170 | # ["text_encoder", CLIPTextModel], 171 | # ["vae", AutoencoderKL], 172 | # ["unet", UNet2DConditionModel], 173 | # ["scheduler", DDPMScheduler] 174 | # ]: 175 | # print(subfolder, model) 176 | # model.from_pretrained( 177 | # MODEL_ID, 178 | # subfolder=subfolder, 179 | # revision=revision, 180 | # use_auth_token=HF_AUTH_TOKEN, 181 | # ) 182 | 183 | 184 | if __name__ == "__main__": 185 | asyncio.run( 186 | download_model( 187 | model_url=os.environ.get("MODEL_URL"), 188 | model_id=os.environ.get("MODEL_ID"), 189 | hf_model_id=os.environ.get("HF_MODEL_ID"), 190 | model_revision=os.environ.get("MODEL_REVISION"), 191 | model_precision=os.environ.get("MODEL_PRECISION"), 192 | checkpoint_url=os.environ.get("CHECKPOINT_URL"), 193 | checkpoint_config_url=os.environ.get("CHECKPOINT_CONFIG_URL"), 194 | ) 195 | ) 196 | -------------------------------------------------------------------------------- /api/download_checkpoint.py: -------------------------------------------------------------------------------- 1 | import os 2 | from utils import Storage 3 | 4 | CHECKPOINT_URL = os.environ.get("CHECKPOINT_URL", None) 5 | CHECKPOINT_DIR = "/root/.cache/checkpoints" 6 | 7 | 8 | def main(checkpoint_url: str): 9 | if not 
os.path.isdir(CHECKPOINT_DIR): 10 | os.makedirs(CHECKPOINT_DIR) 11 | 12 | storage = Storage(checkpoint_url) 13 | storage_query_fname = storage.query.get("fname") 14 | if storage_query_fname: 15 | fname = storage_query_fname[0] 16 | else: 17 | fname = checkpoint_url.split("/").pop() 18 | path = os.path.join(CHECKPOINT_DIR, fname) 19 | 20 | if not os.path.isfile(path): 21 | storage.download_file(path) 22 | 23 | return path 24 | 25 | 26 | if __name__ == "__main__": 27 | if CHECKPOINT_URL: 28 | main(CHECKPOINT_URL) 29 | -------------------------------------------------------------------------------- /api/extras/__init__.py: -------------------------------------------------------------------------------- 1 | from .upsample import upsample 2 | -------------------------------------------------------------------------------- /api/extras/upsample/__init__.py: -------------------------------------------------------------------------------- 1 | from .upsample import upsample 2 | -------------------------------------------------------------------------------- /api/extras/upsample/models.py: -------------------------------------------------------------------------------- 1 | upsamplers = { 2 | "RealESRGAN_x4plus": { 3 | "name": "General - RealESRGANplus", 4 | "weights": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth", 5 | "filename": "RealESRGAN_x4plus.pth", 6 | "net": "RRDBNet", 7 | "initArgs": { 8 | "num_in_ch": 3, 9 | "num_out_ch": 3, 10 | "num_feat": 64, 11 | "num_block": 23, 12 | "num_grow_ch": 32, 13 | "scale": 4, 14 | }, 15 | "netscale": 4, 16 | }, 17 | # "RealESRNet_x4plus": { 18 | # "name": "", 19 | # "weights": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.1/RealESRNet_x4plus.pth", 20 | # "path": "weights/RealESRNet_x4plus.pth", 21 | # }, 22 | "RealESRGAN_x4plus_anime_6B": { 23 | "name": "Anime - anime6B", 24 | "weights": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth", 25 | "filename": "RealESRGAN_x4plus_anime_6B.pth", 26 | "net": "RRDBNet", 27 | "initArgs": { 28 | "num_in_ch": 3, 29 | "num_out_ch": 3, 30 | "num_feat": 64, 31 | "num_block": 6, 32 | "num_grow_ch": 32, 33 | "scale": 4, 34 | }, 35 | "netscale": 4, 36 | }, 37 | # "RealESRGAN_x2plus": { 38 | # "name": "", 39 | # "weights": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth", 40 | # "path": "weights/RealESRGAN_x2plus.pth", 41 | # }, 42 | # "realesr-animevideov3": { 43 | # "name": "AnimeVideo - v3", 44 | # "weights": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-animevideov3.pth", 45 | # "path": "weights/realesr-animevideov3.pth", 46 | # }, 47 | "realesr-general-x4v3": { 48 | "name": "General - v3", 49 | # [, "weights": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-wdn-x4v3.pth" ], 50 | "weights": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth", 51 | "filename": "realesr-general-x4v3.pth", 52 | "net": "SRVGGNetCompact", 53 | "initArgs": { 54 | "num_in_ch": 3, 55 | "num_out_ch": 3, 56 | "num_feat": 64, 57 | "num_conv": 32, 58 | "upscale": 4, 59 | "act_type": "prelu", 60 | }, 61 | "netscale": 4, 62 | }, 63 | } 64 | 65 | face_enhancers = { 66 | "GFPGAN": { 67 | "name": "GFPGAN", 68 | "weights": "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth", 69 | "filename": "GFPGANv1.4.pth", 70 | }, 71 | } 72 | 73 | models_by_type = { 74 | "upsamplers": upsamplers, 75 
| "face_enhancers": face_enhancers, 76 | } 77 | -------------------------------------------------------------------------------- /api/extras/upsample/upsample.py: -------------------------------------------------------------------------------- 1 | import os 2 | import asyncio 3 | from pathlib import Path 4 | 5 | import base64 6 | from io import BytesIO 7 | import PIL 8 | import json 9 | import cv2 10 | import numpy as np 11 | import torch 12 | import torchvision 13 | 14 | from basicsr.archs.rrdbnet_arch import RRDBNet 15 | from realesrgan import RealESRGANer 16 | from realesrgan.archs.srvgg_arch import SRVGGNetCompact 17 | from gfpgan import GFPGANer 18 | 19 | from .models import models_by_type, upsamplers, face_enhancers 20 | from status import status 21 | from utils import Storage 22 | from send import send 23 | 24 | print( 25 | { 26 | "torch.__version__": torch.__version__, 27 | "torchvision.__version__": torchvision.__version__, 28 | } 29 | ) 30 | 31 | HOME = os.path.expanduser("~") 32 | CACHE_DIR = os.path.join(HOME, ".cache", "diffusers-api", "upsample") 33 | 34 | 35 | def cache_path(filename): 36 | return os.path.join(CACHE_DIR, filename) 37 | 38 | 39 | async def assert_model_exists(src, filename, send_opts, opts={}): 40 | dest = cache_path(filename) if not opts.get("absolutePath", None) else filename 41 | if not os.path.exists(dest): 42 | await send("download", "start", {}, send_opts) 43 | storage = Storage(src, status=status) 44 | # await storage.download_file(dest) 45 | await asyncio.to_thread(storage.download_file, dest) 46 | await send("download", "done", {}, send_opts) 47 | 48 | 49 | async def download_models(send_opts={}): 50 | Path(CACHE_DIR).mkdir(parents=True, exist_ok=True) 51 | 52 | for type in models_by_type: 53 | models = models_by_type[type] 54 | for model_key in models: 55 | model = models[model_key] 56 | await assert_model_exists(model["weights"], model["filename"], send_opts) 57 | 58 | Path("gfpgan/weights").mkdir(parents=True, exist_ok=True) 59 | 60 | await assert_model_exists( 61 | "https://github.com/xinntao/facexlib/releases/download/v0.1.0/detection_Resnet50_Final.pth", 62 | "detection_Resnet50_Final.pth", 63 | send_opts, 64 | ) 65 | await assert_model_exists( 66 | "https://github.com/xinntao/facexlib/releases/download/v0.2.2/parsing_parsenet.pth", 67 | "parsing_parsenet.pth", 68 | send_opts, 69 | ) 70 | 71 | # hardcoded paths in xinntao/facexlib 72 | filenames = ["detection_Resnet50_Final.pth", "parsing_parsenet.pth"] 73 | for file in filenames: 74 | if not os.path.exists(f"gfpgan/weights/{file}"): 75 | os.symlink(cache_path(file), f"gfpgan/weights/{file}") 76 | 77 | 78 | nets = { 79 | "RRDBNet": RRDBNet, 80 | "SRVGGNetCompact": SRVGGNetCompact, 81 | } 82 | 83 | models = {} 84 | 85 | 86 | async def upsample(model_inputs, call_inputs, send_opts={}, startRequestId=None): 87 | global models 88 | 89 | # TODO, only download relevant models for this request 90 | await download_models() 91 | 92 | model_id = call_inputs.get("MODEL_ID", None) 93 | 94 | if not model_id: 95 | return { 96 | "$error": { 97 | "code": "MISSING_MODEL_ID", 98 | "message": "call_inputs.MODEL_ID is required, but not given.", 99 | } 100 | } 101 | 102 | model = models.get(model_id, None) 103 | if not model: 104 | model = models_by_type["upsamplers"].get(model_id, None) 105 | if not model: 106 | return { 107 | "$error": { 108 | "code": "MISSING_MODEL", 109 | "message": f'Model "{model_id}" not available on this container.', 110 | "requested": model_id, 111 | "available": '"' + '", 
"'.join(models.keys()) + '"', 112 | } 113 | } 114 | else: 115 | modelModel = nets[model["net"]](**model["initArgs"]) 116 | await send( 117 | "loadModel", 118 | "start", 119 | {"startRequestId": startRequestId}, 120 | send_opts, 121 | ) 122 | upsampler = RealESRGANer( 123 | scale=model["netscale"], 124 | model_path=cache_path(model["filename"]), 125 | dni_weight=None, 126 | model=modelModel, 127 | tile=0, 128 | tile_pad=10, 129 | pre_pad=0, 130 | half=True, 131 | ) 132 | await send( 133 | "loadModel", 134 | "done", 135 | {"startRequestId": startRequestId}, 136 | send_opts, 137 | ) 138 | model.update({"model": modelModel, "upsampler": upsampler}) 139 | models.update({model_id: model}) 140 | 141 | upsampler = model["upsampler"] 142 | 143 | input_image = model_inputs.get("input_image", None) 144 | if not input_image: 145 | return { 146 | "$error": { 147 | "code": "NO_INPUT_IMAGE", 148 | "message": "Missing required parameter `input_image`", 149 | } 150 | } 151 | 152 | if model_id == "realesr-general-x4v3": 153 | denoise_strength = model_inputs.get("denoise_strength", 1) 154 | if denoise_strength != 1: 155 | # wdn_model_path = model_path.replace('realesr-general-x4v3', 'realesr-general-wdn-x4v3') 156 | # model_path = [model_path, wdn_model_path] 157 | # upsampler = models["realesr-general-x4v3-denoise"] 158 | # upsampler.dni_weight = dni_weight 159 | dni_weight = [denoise_strength, 1 - denoise_strength] 160 | return "TODO: denoise_strength" 161 | 162 | face_enhance = model_inputs.get("face_enhance", False) 163 | if face_enhance: 164 | face_enhancer = models.get("GFPGAN", None) 165 | if not face_enhancer: 166 | await send( 167 | "loadModel", 168 | "start", 169 | {"startRequestId": startRequestId}, 170 | send_opts, 171 | ) 172 | print("1) " + cache_path(face_enhancers["GFPGAN"]["filename"])) 173 | face_enhancer = GFPGANer( 174 | model_path=cache_path(face_enhancers["GFPGAN"]["filename"]), 175 | upscale=4, # args.outscale, 176 | arch="clean", 177 | channel_multiplier=2, 178 | bg_upsampler=upsampler, 179 | ) 180 | await send( 181 | "loadModel", 182 | "done", 183 | {"startRequestId": startRequestId}, 184 | send_opts, 185 | ) 186 | models.update({"GFPGAN": face_enhancer}) 187 | 188 | if face_enhance: # Use GFPGAN for face enhancement 189 | face_enhancer.bg_upsampler = upsampler 190 | 191 | # image = decodeBase64Image(model_inputs.get("input_image")) 192 | image_str = base64.b64decode(model_inputs["input_image"]) 193 | image_np = np.frombuffer(image_str, dtype=np.uint8) 194 | # bytes = BytesIO(base64.decodebytes(bytes(model_inputs["input_image"], "utf-8"))) 195 | img = cv2.imdecode(image_np, cv2.IMREAD_UNCHANGED) 196 | 197 | await send("inference", "start", {"startRequestId": startRequestId}, send_opts) 198 | 199 | # Run the model 200 | # with autocast("cuda"): 201 | # image = pipeline(**model_inputs).images[0] 202 | if face_enhance: 203 | _, _, output = face_enhancer.enhance( 204 | img, has_aligned=False, only_center_face=False, paste_back=True 205 | ) 206 | else: 207 | output, _rgb = upsampler.enhance(img, outscale=4) # TODO outscale param 208 | 209 | image_base64 = base64.b64encode(cv2.imencode(".jpg", output)[1]).decode() 210 | 211 | await send("inference", "done", {"startRequestId": startRequestId}, send_opts) 212 | 213 | # Return the results as a dictionary 214 | return {"$meta": {}, "image_base64": image_base64} 215 | -------------------------------------------------------------------------------- /api/getPipeline.py: 
-------------------------------------------------------------------------------- 1 | import time 2 | import os, fnmatch 3 | from diffusers import ( 4 | DiffusionPipeline, 5 | pipelines as diffusers_pipelines, 6 | ) 7 | from precision import torch_dtype_from_precision 8 | 9 | HOME = os.path.expanduser("~") 10 | MODELS_DIR = os.path.join(HOME, ".cache", "diffusers-api") 11 | _pipelines = {} 12 | _availableCommunityPipelines = None 13 | 14 | 15 | def listAvailablePipelines(): 16 | return ( 17 | list( 18 | filter( 19 | lambda key: key.endswith("Pipeline"), 20 | list(diffusers_pipelines.__dict__.keys()), 21 | ) 22 | ) 23 | + availableCommunityPipelines() 24 | ) 25 | 26 | 27 | def availableCommunityPipelines(): 28 | global _availableCommunityPipelines 29 | if not _availableCommunityPipelines: 30 | _availableCommunityPipelines = list( 31 | map( 32 | lambda s: s[0:-3], 33 | fnmatch.filter(os.listdir("diffusers/examples/community"), "*.py"), 34 | ) 35 | ) 36 | 37 | return _availableCommunityPipelines 38 | 39 | 40 | def clearPipelines(): 41 | """ 42 | Clears the pipeline cache. Important to call this when changing the 43 | loaded model, as pipelines include references to the model and would 44 | therefore prevent memory being reclaimed after unloading the previous 45 | model. 46 | """ 47 | global _pipelines 48 | _pipelines = {} 49 | 50 | 51 | def getPipelineClass(pipeline_name: str): 52 | if hasattr(diffusers_pipelines, pipeline_name): 53 | return getattr(diffusers_pipelines, pipeline_name) 54 | elif pipeline_name in availableCommunityPipelines(): 55 | return DiffusionPipeline 56 | 57 | 58 | def getPipelineForModel( 59 | pipeline_name: str, model, model_id, model_revision, model_precision 60 | ): 61 | """ 62 | Inits a new pipeline, re-using components from a previously loaded 63 | model. The pipeline is cached and future calls with the same 64 | arguments will return the previously initted instance. Be sure 65 | to call `clearPipelines()` if loading a new model, to allow the 66 | previous model to be garbage collected. 
67 | """ 68 | pipeline = _pipelines.get(pipeline_name) 69 | if pipeline: 70 | return pipeline 71 | 72 | start = time.time() 73 | 74 | if hasattr(diffusers_pipelines, pipeline_name): 75 | pipeline_class = getattr(diffusers_pipelines, pipeline_name) 76 | if hasattr(pipeline_class, "from_pipe"): 77 | pipeline = pipeline_class.from_pipe(model) 78 | elif hasattr(model, "components"): 79 | pipeline = pipeline_class(**model.components) 80 | else: 81 | pipeline = getattr(diffusers_pipelines, pipeline_name)( 82 | vae=model.vae, 83 | text_encoder=model.text_encoder, 84 | tokenizer=model.tokenizer, 85 | unet=model.unet, 86 | scheduler=model.scheduler, 87 | safety_checker=model.safety_checker, 88 | feature_extractor=model.feature_extractor, 89 | ) 90 | 91 | elif pipeline_name in availableCommunityPipelines(): 92 | model_dir = os.path.join(MODELS_DIR, model_id) 93 | if not os.path.isdir(model_dir): 94 | model_dir = None 95 | 96 | pipeline = DiffusionPipeline.from_pretrained( 97 | model_dir or model_id, 98 | revision=model_revision, 99 | torch_dtype=torch_dtype_from_precision(model_precision), 100 | custom_pipeline="./diffusers/examples/community/" + pipeline_name + ".py", 101 | local_files_only=True, 102 | **model.components, 103 | ) 104 | 105 | if pipeline: 106 | _pipelines.update({pipeline_name: pipeline}) 107 | diff = round((time.time() - start) * 1000) 108 | print(f"Initialized {pipeline_name} for {model_id} in {diff}ms") 109 | return pipeline 110 | -------------------------------------------------------------------------------- /api/getScheduler.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | import time 4 | from diffusers import schedulers as _schedulers 5 | 6 | HF_AUTH_TOKEN = os.getenv("HF_AUTH_TOKEN") 7 | HOME = os.path.expanduser("~") 8 | MODELS_DIR = os.path.join(HOME, ".cache", "diffusers-api") 9 | 10 | SCHEDULERS = [ 11 | "DPMSolverMultistepScheduler", 12 | "LMSDiscreteScheduler", 13 | "DDIMScheduler", 14 | "PNDMScheduler", 15 | "EulerAncestralDiscreteScheduler", 16 | "EulerDiscreteScheduler", 17 | ] 18 | 19 | DEFAULT_SCHEDULER = os.getenv("DEFAULT_SCHEDULER", SCHEDULERS[0]) 20 | 21 | 22 | """ 23 | # This was a nice idea but until we have default init vars for all schedulers 24 | # via from_pretrained(), it's a no go. In any case, loading a scheduler takes time 25 | # so better to init as needed and cache. 
26 | isScheduler = re.compile(r".+Scheduler$") 27 | for key, val in _schedulers.__dict__.items(): 28 | if isScheduler.match(key): 29 | schedulers.update( 30 | { 31 | key: val.from_pretrained( 32 | MODEL_ID, subfolder="scheduler", use_auth_token=HF_AUTH_TOKEN 33 | ) 34 | } 35 | ) 36 | """ 37 | 38 | 39 | def initScheduler(MODEL_ID: str, scheduler_id: str, download=False): 40 | print(f"Initializing {scheduler_id} for {MODEL_ID}...") 41 | start = time.time() 42 | scheduler = getattr(_schedulers, scheduler_id) 43 | if scheduler == None: 44 | return None 45 | 46 | model_dir = os.path.join(MODELS_DIR, MODEL_ID) 47 | if not os.path.isdir(model_dir): 48 | model_dir = None 49 | 50 | inittedScheduler = scheduler.from_pretrained( 51 | model_dir or MODEL_ID, 52 | subfolder="scheduler", 53 | use_auth_token=HF_AUTH_TOKEN, 54 | local_files_only=not download, 55 | ) 56 | diff = round((time.time() - start) * 1000) 57 | print(f"Initialized {scheduler_id} for {MODEL_ID} in {diff}ms") 58 | 59 | return inittedScheduler 60 | 61 | 62 | schedulers = {} 63 | 64 | 65 | def getScheduler(MODEL_ID: str, scheduler_id: str, download=False): 66 | schedulersByModel = schedulers.get(MODEL_ID, None) 67 | if schedulersByModel == None: 68 | schedulersByModel = {} 69 | schedulers.update({MODEL_ID: schedulersByModel}) 70 | 71 | # Check for use of old names 72 | deprecated_map = { 73 | "LMS": "LMSDiscreteScheduler", 74 | "DDIM": "DDIMScheduler", 75 | "PNDM": "PNDMScheduler", 76 | } 77 | scheduler_renamed = deprecated_map.get(scheduler_id, None) 78 | if scheduler_renamed != None: 79 | print( 80 | f'[Deprecation Warning]: Scheduler "{scheduler_id}" is now ' 81 | f'called "{scheduler_id}". Please rename as this will ' 82 | f"stop working in a future release." 83 | ) 84 | scheduler_id = scheduler_renamed 85 | 86 | scheduler = schedulersByModel.get(scheduler_id, None) 87 | if scheduler == None: 88 | scheduler = initScheduler(MODEL_ID, scheduler_id, download) 89 | schedulersByModel.update({scheduler_id: scheduler}) 90 | 91 | return scheduler 92 | -------------------------------------------------------------------------------- /api/lib/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kiri-art/docker-diffusers-api/5521b2e6d63ef7afa9c3f7f3ef6cf520c1b54d7b/api/lib/__init__.py -------------------------------------------------------------------------------- /api/lib/prompts.py: -------------------------------------------------------------------------------- 1 | from compel import Compel, DiffusersTextualInversionManager, ReturnedEmbeddingsType 2 | 3 | 4 | def prepare_prompts(pipeline, model_inputs, is_sdxl): 5 | textual_inversion_manager = DiffusersTextualInversionManager(pipeline) 6 | if is_sdxl: 7 | compel = Compel( 8 | tokenizer=[pipeline.tokenizer, pipeline.tokenizer_2], 9 | text_encoder=[pipeline.text_encoder, pipeline.text_encoder_2], 10 | # diffusers has no ti in sdxl yet 11 | # https://github.com/huggingface/diffusers/issues/4376#issuecomment-1659016141 12 | # textual_inversion_manager=textual_inversion_manager, 13 | truncate_long_prompts=False, 14 | returned_embeddings_type=ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED, 15 | requires_pooled=[False, True], 16 | ) 17 | conditioning, pooled = compel(model_inputs.get("prompt")) 18 | negative_conditioning, negative_pooled = compel( 19 | model_inputs.get("negative_prompt") 20 | ) 21 | [ 22 | conditioning, 23 | negative_conditioning, 24 | ] = compel.pad_conditioning_tensors_to_same_length( 25 | 
[conditioning, negative_conditioning] 26 | ) 27 | model_inputs.update( 28 | { 29 | "prompt": None, 30 | "negative_prompt": None, 31 | "prompt_embeds": conditioning, 32 | "negative_prompt_embeds": negative_conditioning, 33 | "pooled_prompt_embeds": pooled, 34 | "negative_pooled_prompt_embeds": negative_pooled, 35 | } 36 | ) 37 | 38 | else: 39 | compel = Compel( 40 | tokenizer=pipeline.tokenizer, 41 | text_encoder=pipeline.text_encoder, 42 | textual_inversion_manager=textual_inversion_manager, 43 | truncate_long_prompts=False, 44 | ) 45 | conditioning = compel(model_inputs.get("prompt")) 46 | negative_conditioning = compel(model_inputs.get("negative_prompt")) 47 | [ 48 | conditioning, 49 | negative_conditioning, 50 | ] = compel.pad_conditioning_tensors_to_same_length( 51 | [conditioning, negative_conditioning] 52 | ) 53 | model_inputs.update( 54 | { 55 | "prompt": None, 56 | "negative_prompt": None, 57 | "prompt_embeds": conditioning, 58 | "negative_prompt_embeds": negative_conditioning, 59 | } 60 | ) 61 | -------------------------------------------------------------------------------- /api/lib/textual_inversions.py: -------------------------------------------------------------------------------- 1 | import json 2 | import re 3 | import os 4 | import asyncio 5 | from utils import Storage 6 | from .vars import MODELS_DIR 7 | 8 | last_textual_inversions = None 9 | last_textual_inversion_model = None 10 | loaded_textual_inversion_tokens = [] 11 | 12 | tokenRe = re.compile( 13 | r"[#&]{1}fname=(?P[^\.]+)\.(?:pt|safetensors)(&token=(?P[^&]+))?$" 14 | ) 15 | 16 | 17 | def strMap(str: str): 18 | match = re.search(tokenRe, str) 19 | # print(match) 20 | if match: 21 | return match.group("token") or match.group("fname") 22 | 23 | 24 | def extract_tokens_from_list(textual_inversions: list): 25 | return list(map(strMap, textual_inversions)) 26 | 27 | 28 | async def handle_textual_inversions(textual_inversions: list, model, status): 29 | global last_textual_inversions 30 | global last_textual_inversion_model 31 | global loaded_textual_inversion_tokens 32 | 33 | textual_inversions_str = json.dumps(textual_inversions) 34 | if ( 35 | textual_inversions_str != last_textual_inversions 36 | or model is not last_textual_inversion_model 37 | ): 38 | if model is not last_textual_inversion_model: 39 | loaded_textual_inversion_tokens = [] 40 | last_textual_inversion_model = model 41 | # print({"textual_inversions": textual_inversions}) 42 | # tokens_to_load = extract_tokens_from_list(textual_inversions) 43 | # print({"tokens_loaded": loaded_textual_inversion_tokens}) 44 | # print({"tokens_to_load": tokens_to_load}) 45 | # 46 | # for token in loaded_textual_inversion_tokens: 47 | # if token not in tokens_to_load: 48 | # print("[TextualInversion] Removing uneeded token: " + token) 49 | # del pipeline.tokenizer.get_vocab()[token] 50 | # # del pipeline.text_encoder.get_input_embeddings().weight.data[token] 51 | # pipeline.text_encoder.resize_token_embeddings(len(pipeline.tokenizer)) 52 | # 53 | # loaded_textual_inversion_tokens = tokens_to_load 54 | 55 | last_textual_inversions = textual_inversions_str 56 | for textual_inversion in textual_inversions: 57 | storage = Storage(textual_inversion, no_raise=True, status=status) 58 | if storage: 59 | storage_query_fname = storage.query.get("fname") 60 | if storage_query_fname: 61 | fname = storage_query_fname[0] 62 | else: 63 | fname = textual_inversion.split("/").pop() 64 | path = os.path.join(MODELS_DIR, "textual_inversion--" + fname) 65 | if not os.path.exists(path): 66 | 
await asyncio.to_thread(storage.download_file, path) 67 | print("Load textual inversion " + path) 68 | token = storage.query.get("token", None) 69 | if token not in loaded_textual_inversion_tokens: 70 | model.load_textual_inversion( 71 | path, token=token, local_files_only=True 72 | ) 73 | loaded_textual_inversion_tokens.append(token) 74 | else: 75 | print("Load textual inversion " + textual_inversion) 76 | model.load_textual_inversion(textual_inversion) 77 | else: 78 | print("No changes to textual inversions since last call") 79 | -------------------------------------------------------------------------------- /api/lib/textual_inversions_test.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | from .textual_inversions import extract_tokens_from_list 3 | 4 | 5 | class TextualInversionsTest(unittest.TestCase): 6 | def test_extract_tokens_query_fname(self): 7 | tis = ["https://civitai.com/api/download/models/106132#fname=4nj0lie.pt"] 8 | tokens = extract_tokens_from_list(tis) 9 | self.assertEqual(tokens[0], "4nj0lie") 10 | 11 | def test_extract_tokens_query_token(self): 12 | tis = [ 13 | "https://civitai.com/api/download/models/106132#fname=4nj0lie.pt&token=4nj0lie" 14 | ] 15 | tokens = extract_tokens_from_list(tis) 16 | self.assertEqual(tokens[0], "4nj0lie") 17 | -------------------------------------------------------------------------------- /api/lib/vars.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | RUNTIME_DOWNLOADS = os.getenv("RUNTIME_DOWNLOADS") == "1" 4 | USE_DREAMBOOTH = os.getenv("USE_DREAMBOOTH") == "1" 5 | MODEL_ID = os.environ.get("MODEL_ID") 6 | PIPELINE = os.environ.get("PIPELINE") 7 | HF_AUTH_TOKEN = os.getenv("HF_AUTH_TOKEN") 8 | HOME = os.path.expanduser("~") 9 | MODELS_DIR = os.path.join(HOME, ".cache", "diffusers-api") 10 | -------------------------------------------------------------------------------- /api/loadModel.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | from diffusers import pipelines as _pipelines, AutoPipelineForText2Image 4 | from getScheduler import getScheduler, DEFAULT_SCHEDULER 5 | from precision import torch_dtype_from_precision 6 | from device import device 7 | import time 8 | 9 | HF_AUTH_TOKEN = os.getenv("HF_AUTH_TOKEN") 10 | PIPELINE = os.getenv("PIPELINE") 11 | USE_DREAMBOOTH = True if os.getenv("USE_DREAMBOOTH") == "1" else False 12 | HOME = os.path.expanduser("~") 13 | MODELS_DIR = os.path.join(HOME, ".cache", "diffusers-api") 14 | 15 | 16 | MODEL_IDS = [ 17 | "CompVis/stable-diffusion-v1-4", 18 | "hakurei/waifu-diffusion", 19 | # "hakurei/waifu-diffusion-v1-3", - not as diffusers yet 20 | "runwayml/stable-diffusion-inpainting", 21 | "runwayml/stable-diffusion-v1-5", 22 | "stabilityai/stable-diffusion-2" 23 | "stabilityai/stable-diffusion-2-base" 24 | "stabilityai/stable-diffusion-2-inpainting", 25 | ] 26 | 27 | 28 | def loadModel( 29 | model_id: str, 30 | load=True, 31 | precision=None, 32 | revision=None, 33 | send_opts={}, 34 | pipeline_class=None, 35 | ): 36 | torch_dtype = torch_dtype_from_precision(precision) 37 | if revision == "": 38 | revision = None 39 | 40 | print( 41 | "loadModel", 42 | { 43 | "model_id": model_id, 44 | "load": load, 45 | "precision": precision, 46 | "revision": revision, 47 | "pipeline_class": pipeline_class, 48 | }, 49 | ) 50 | 51 | if not pipeline_class: 52 | pipeline_class = AutoPipelineForText2Image 53 | 54 | pipeline = 
pipeline_class if PIPELINE == "ALL" else getattr(_pipelines, PIPELINE) 55 | print("pipeline", pipeline_class) 56 | 57 | print( 58 | ("Loading" if load else "Downloading") 59 | + " model: " 60 | + model_id 61 | + (f" ({revision})" if revision else "") 62 | ) 63 | 64 | scheduler = getScheduler(model_id, DEFAULT_SCHEDULER, not load) 65 | 66 | model_dir = os.path.join(MODELS_DIR, model_id) 67 | if not os.path.isdir(model_dir): 68 | model_dir = None 69 | 70 | from_pretrained = time.time() 71 | model = pipeline.from_pretrained( 72 | model_dir or model_id, 73 | revision=revision, 74 | torch_dtype=torch_dtype, 75 | use_auth_token=HF_AUTH_TOKEN, 76 | scheduler=scheduler, 77 | local_files_only=load, 78 | # Work around https://github.com/huggingface/diffusers/issues/1246 79 | # low_cpu_mem_usage=False if USE_DREAMBOOTH else True, 80 | ) 81 | from_pretrained = round((time.time() - from_pretrained) * 1000) 82 | 83 | if load: 84 | to_gpu = time.time() 85 | model.to(device) 86 | to_gpu = round((time.time() - to_gpu) * 1000) 87 | print(f"Loaded from disk in {from_pretrained} ms, to gpu in {to_gpu} ms") 88 | else: 89 | print(f"Downloaded in {from_pretrained} ms") 90 | 91 | return model if load else None 92 | -------------------------------------------------------------------------------- /api/precision.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | 4 | DEPRECATED_PRECISION = os.getenv("PRECISION") 5 | MODEL_PRECISION = os.getenv("MODEL_PRECISION") or DEPRECATED_PRECISION 6 | MODEL_REVISION = os.getenv("MODEL_REVISION") 7 | 8 | if DEPRECATED_PRECISION: 9 | print("Warning: PRECISION variable been deprecated and renamed MODEL_PRECISION") 10 | print("Your setup still works but in a future release, this will throw an error") 11 | 12 | if MODEL_PRECISION and not MODEL_REVISION: 13 | print("Warning: we no longer default to MODEL_REVISION=MODEL_PRECISION, please") 14 | print(f'explicitly set MODEL_REVISION="{MODEL_PRECISION}" if that\'s what you') 15 | print("want.") 16 | 17 | 18 | def revision_from_precision(precision=MODEL_PRECISION): 19 | # return precision if precision else None 20 | raise Exception("revision_from_precision no longer supported") 21 | 22 | 23 | def torch_dtype_from_precision(precision=MODEL_PRECISION): 24 | if precision == "fp16": 25 | return torch.float16 26 | return None 27 | 28 | 29 | def torch_dtype_from_precision(precision=MODEL_PRECISION): 30 | if precision == "fp16": 31 | return torch.float16 32 | return None 33 | -------------------------------------------------------------------------------- /api/send.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import datetime 4 | import time 5 | import requests 6 | import hashlib 7 | from requests_futures.sessions import FuturesSession 8 | from status import status as statusInstance 9 | 10 | print() 11 | environ = os.environ.copy() 12 | for key in ["AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY", "HF_AUTH_TOKEN"]: 13 | if environ.get(key, None): 14 | environ[key] = "XXX" 15 | print(environ) 16 | print() 17 | 18 | 19 | def get_now(): 20 | return round(time.time() * 1000) 21 | 22 | 23 | SEND_URL = os.getenv("SEND_URL") 24 | if SEND_URL == "": 25 | SEND_URL = None 26 | 27 | SIGN_KEY = os.getenv("SIGN_KEY", "") 28 | if SIGN_KEY == "": 29 | SIGN_KEY = None 30 | 31 | futureSession = FuturesSession() 32 | 33 | container_id = os.getenv("CONTAINER_ID") 34 | if not container_id: 35 | with open("/proc/self/mountinfo") 
as file: 36 | line = file.readline().strip() 37 | while line: 38 | if "/containers/" in line: 39 | container_id = line.split("/containers/")[ 40 | -1 41 | ] # Take only text to the right 42 | container_id = container_id.split("/")[0] # Take only text to the left 43 | break 44 | line = file.readline().strip() 45 | 46 | 47 | init_used = False 48 | 49 | 50 | def clearSession(force=False): 51 | global session 52 | global init_used 53 | 54 | if init_used or force: 55 | session = {"_ctime": get_now()} 56 | else: 57 | init_used = True 58 | 59 | 60 | def getTimings(): 61 | timings = {} 62 | for key in session.keys(): 63 | if key == "_ctime": 64 | continue 65 | start = session[key].get("start", None) 66 | done = session[key].get("done", None) 67 | if start and done: 68 | timings.update({key: session[key]["done"] - session[key]["start"]}) 69 | else: 70 | timings.update({key: -1}) 71 | return timings 72 | 73 | 74 | async def send(type: str, status: str, payload: dict = {}, opts: dict = {}): 75 | now = get_now() 76 | send_url = opts.get("SEND_URL", SEND_URL) 77 | sign_key = opts.get("SIGN_KEY", SIGN_KEY) 78 | 79 | if status == "start": 80 | session.update({type: {"start": now, "last_time": now}}) 81 | elif status == "done": 82 | session[type].update({"done": now, "diff": now - session[type]["start"]}) 83 | else: 84 | session[type]["last_time"] = now 85 | 86 | data = { 87 | "type": type, 88 | "status": status, 89 | "container_id": container_id, 90 | "time": now, 91 | "t": now - session["_ctime"], 92 | "tsl": now - session[type]["last_time"], 93 | "payload": payload, 94 | } 95 | 96 | if status == "start": 97 | statusInstance.update(type, 0.0) 98 | elif status == "done": 99 | statusInstance.update(type, 1.0) 100 | 101 | if send_url and sign_key: 102 | input = json.dumps(data, separators=(",", ":")) + sign_key 103 | sig = hashlib.md5(input.encode("utf-8")).hexdigest() 104 | data["sig"] = sig 105 | 106 | print(datetime.datetime.now(), data) 107 | 108 | if send_url: 109 | futureSession.post(send_url, json=data) 110 | 111 | response = opts.get("response") 112 | if response: 113 | print("streaming above") 114 | await response.send(json.dumps(data) + "\n") 115 | 116 | # try: 117 | # requests.post(send_url, json=data) # , timeout=0.0000000001) 118 | # except requests.exceptions.ReadTimeout: 119 | # except requests.exceptions.RequestException as error: 120 | # print(error) 121 | # pass 122 | 123 | 124 | clearSession(True) 125 | -------------------------------------------------------------------------------- /api/server.py: -------------------------------------------------------------------------------- 1 | # Do not edit if deploying to Banana Serverless 2 | # This file is boilerplate for the http server, and follows a strict interface. 
3 | 4 | # Instead, edit the init() and inference() functions in app.py 5 | 6 | from sanic import Sanic, response 7 | from sanic_ext import Extend 8 | import subprocess 9 | import app as user_src 10 | import traceback 11 | import os 12 | import json 13 | 14 | # We do the model load-to-GPU step on server startup 15 | # so the model object is available globally for reuse 16 | user_src.init() 17 | 18 | # Create the http server app 19 | server = Sanic("my_app") 20 | server.config.CORS_ORIGINS = os.getenv("CORS_ORIGINS") or "*" 21 | server.config.RESPONSE_TIMEOUT = 60 * 60 # 1 hour (training can be long) 22 | Extend(server) 23 | 24 | 25 | # Healthchecks verify that the environment is correct on Banana Serverless 26 | @server.route("/healthcheck", methods=["GET"]) 27 | def healthcheck(request): 28 | # dependency free way to check if GPU is visible 29 | gpu = False 30 | out = subprocess.run("nvidia-smi", shell=True) 31 | if out.returncode == 0: # success state on shell command 32 | gpu = True 33 | 34 | return response.json({"state": "healthy", "gpu": gpu}) 35 | 36 | 37 | # Inference POST handler at '/' is called for every http call from Banana 38 | @server.route("/", methods=["POST"]) 39 | async def inference(request): 40 | try: 41 | all_inputs = response.json.loads(request.json) 42 | except: 43 | all_inputs = request.json 44 | 45 | call_inputs = all_inputs.get("callInputs", None) 46 | stream_events = call_inputs and call_inputs.get("streamEvents", 0) != 0 47 | 48 | streaming_response = None 49 | if stream_events: 50 | streaming_response = await request.respond(content_type="application/x-ndjson") 51 | 52 | try: 53 | output = await user_src.inference(all_inputs, streaming_response) 54 | except Exception as err: 55 | print(err) 56 | output = { 57 | "$error": { 58 | "code": "APP_INFERENCE_ERROR", 59 | "name": type(err).__name__, 60 | "message": str(err), 61 | "stack": traceback.format_exc(), 62 | } 63 | } 64 | 65 | if stream_events: 66 | await streaming_response.send(json.dumps(output) + "\n") 67 | else: 68 | return response.json(output) 69 | 70 | 71 | if __name__ == "__main__": 72 | server.run(host="0.0.0.0", port="8000", workers=1) 73 | -------------------------------------------------------------------------------- /api/status.py: -------------------------------------------------------------------------------- 1 | class Status: 2 | def __init__(self): 3 | self.type = "init" 4 | self.progress = 0.0 5 | 6 | def update(self, type, progress): 7 | self.type = type 8 | self.progress = progress 9 | 10 | def get(self): 11 | return {"type": self.type, "progress": self.progress} 12 | 13 | 14 | status = Status() 15 | -------------------------------------------------------------------------------- /api/tests.py: -------------------------------------------------------------------------------- 1 | from test import runTest 2 | 3 | 4 | def test_memory_free_on_swap_model(): 5 | """ 6 | Make sure memory is freed when swapping models at runtime. 
7 | """ 8 | result = runTest( 9 | "txt2img", 10 | {}, 11 | { 12 | "MODEL_ID": "stabilityai/stable-diffusion-2-1-base", 13 | "MODEL_PRECISION": "", # full precision 14 | "MODEL_URL": "s3://", 15 | }, 16 | {"num_inference_steps": 1}, 17 | ) 18 | mem_usage = list() 19 | mem_usage.append(result["$mem_usage"]) 20 | result = runTest( 21 | "txt2img", 22 | {}, 23 | { 24 | "MODEL_ID": "stabilityai/stable-diffusion-2-1-base", 25 | "MODEL_PRECISION": "fp16", # half precision 26 | "MODEL_URL": "s3://", 27 | }, 28 | {"num_inference_steps": 1}, 29 | ) 30 | mem_usage.append(result["$mem_usage"]) 31 | 32 | print({"mem_usage": mem_usage}) 33 | # Assert that less memory used when unloading fp32 model and 34 | # loading the fp16 variant in its place 35 | assert mem_usage[1] < mem_usage[0] 36 | -------------------------------------------------------------------------------- /api/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .storage import Storage 2 | -------------------------------------------------------------------------------- /api/utils/storage/BaseStorage.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | import subprocess 4 | from abc import ABC, abstractmethod 5 | import xtarfile as tarfile 6 | 7 | 8 | class BaseArchive(ABC): 9 | def __init__(self, path, status=None): 10 | self.path = path 11 | self.status = status 12 | 13 | def updateStatus(self, type, progress): 14 | if self.status: 15 | self.status.update(type, progress) 16 | 17 | def extract(self): 18 | print("TODO") 19 | 20 | def splitext(self): 21 | base, ext = os.path.splitext(self.path) 22 | base, subext = os.path.splitext(base) 23 | return base, ext, subext 24 | 25 | 26 | class TarArchive(BaseArchive): 27 | @staticmethod 28 | def test(path): 29 | return re.search(r"\.tar", path) 30 | 31 | def extract(self, dir, dry_run=False): 32 | self.updateStatus("extract", 0) 33 | if not dir: 34 | base, ext, subext = self.splitext() 35 | parent_dir = os.path.dirname(self.path) 36 | dir = os.path.join(parent_dir, base) 37 | 38 | if not dry_run: 39 | os.mkdir(dir) 40 | 41 | def track_progress(tar): 42 | i = 0 43 | members = tar.getmembers() 44 | for member in members: 45 | i += 1 46 | self.updateStatus("extract", i / len(members)) 47 | yield member 48 | 49 | print("Extracting to " + dir) 50 | with tarfile.open(self.path, "r") as tar: 51 | tar.extractall(path=dir, members=track_progress(tar)) 52 | tar.close() 53 | subprocess.run(["ls", "-l", dir]) 54 | os.remove(self.path) 55 | 56 | self.updateStatus("extract", 1) 57 | return dir # , base, ext, subext 58 | 59 | 60 | archiveClasses = [TarArchive] 61 | 62 | 63 | def Archive(path, **kwargs): 64 | for ArchiveClass in archiveClasses: 65 | if ArchiveClass.test(path): 66 | return ArchiveClass(path, **kwargs) 67 | 68 | 69 | class BaseStorage(ABC): 70 | @staticmethod 71 | @abstractmethod 72 | def test(url): 73 | return re.search(r"^https?://", url) 74 | 75 | def __init__(self, url, **kwargs): 76 | self.url = url 77 | self.status = kwargs.get("status", None) 78 | self.query = {} 79 | 80 | def updateStatus(self, type, progress): 81 | if self.status: 82 | self.status.update(type, progress) 83 | 84 | def splitext(self): 85 | base, ext = os.path.splitext(self.url) 86 | base, subext = os.path.splitext(base) 87 | return base, ext, subext 88 | 89 | def get_filename(self): 90 | return self.url.split("/").pop() 91 | 92 | @abstractmethod 93 | def download_file(self, dest): 94 | """Download the file to 
`dest`""" 95 | pass 96 | 97 | def download_and_extract(self, fname, dir=None, dry_run=False): 98 | """ 99 | Downloads the file, and if it's an archive, extract it too. Returns 100 | the filename if not, or directory name (fname without extension) if 101 | it was. 102 | """ 103 | if not fname: 104 | fname = self.get_filename() 105 | 106 | archive = Archive(fname, status=self.status) 107 | if archive: 108 | # TODO, streaming pipeline 109 | self.download_file(fname) 110 | return archive.extract(dir) 111 | else: 112 | self.download_file(fname) 113 | return fname 114 | -------------------------------------------------------------------------------- /api/utils/storage/BaseStorage_test.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | from . import Storage, S3Storage, HTTPStorage 3 | 4 | 5 | class BaseStorageTest(unittest.TestCase): 6 | def test_get_filename(self): 7 | storage = Storage("http://host.com/dir/file.tar.zst") 8 | self.assertEqual(storage.get_filename(), "file.tar.zst") 9 | 10 | class Download_and_extract(unittest.TestCase): 11 | def test_file_only(self): 12 | storage = Storage("http://host.com/dir/file.bin") 13 | result = storage.download_and_extract(dry_run=True) 14 | self.assertEqual(result, "file.bin") 15 | 16 | def test_file_archive(self): 17 | storage = Storage("http://host.com/dir/file.tar.zst") 18 | result, base, ext, subext = storage.download_and_extract(dry_run=True) 19 | self.assertEqual(result, "file") 20 | self.assertEqual(base, "file") 21 | self.assertEqual(ext, "tar") 22 | self.assertEqual(subext, "zst") 23 | -------------------------------------------------------------------------------- /api/utils/storage/HTTPStorage.py: -------------------------------------------------------------------------------- 1 | import re 2 | import os 3 | import time 4 | import requests 5 | from tqdm import tqdm 6 | from .BaseStorage import BaseStorage 7 | import urllib.parse 8 | 9 | 10 | def get_now(): 11 | return round(time.time() * 1000) 12 | 13 | 14 | class HTTPStorage(BaseStorage): 15 | @staticmethod 16 | def test(url): 17 | return re.search(r"^https?://", url) 18 | 19 | def __init__(self, url, **kwargs): 20 | super().__init__(url, **kwargs) 21 | parts = self.url.split("#", 1) 22 | self.url = parts[0] 23 | if len(parts) > 1: 24 | self.query = urllib.parse.parse_qs(parts[1]) 25 | 26 | def upload_file(self, source, dest): 27 | raise RuntimeError("HTTP PUT not implemented yet") 28 | 29 | def download_file(self, fname): 30 | print(f"Downloading {self.url} to {fname}...") 31 | resp = requests.get(self.url, stream=True) 32 | total = int(resp.headers.get("content-length", 0)) 33 | content_disposition = resp.headers.get("content-disposition") 34 | if content_disposition: 35 | filename_search = re.search('filename="(.+)"', content_disposition) 36 | if filename_search: 37 | self.filename = filename_search.group(1) 38 | else: 39 | print("Warning: content-disposition header is not found in the response.") 40 | # Can also replace 'file' with a io.BytesIO object 41 | with open(fname, "wb") as file, tqdm( 42 | desc="Downloading", 43 | total=total, 44 | unit="iB", 45 | unit_scale=True, 46 | unit_divisor=1024, 47 | ) as bar: 48 | total_written = 0 49 | for data in resp.iter_content(chunk_size=1024): 50 | size = file.write(data) 51 | bar.update(size) 52 | total_written += size 53 | self.updateStatus("download", total_written / total) 54 | -------------------------------------------------------------------------------- 
/api/utils/storage/S3Storage.py:
--------------------------------------------------------------------------------
1 | import boto3
2 | import botocore
3 | import re
4 | import os
5 | import time
6 | from tqdm import tqdm
7 | from botocore.client import Config
8 | from .BaseStorage import BaseStorage
9 | 
10 | AWS_S3_ENDPOINT_URL = os.environ.get("AWS_S3_ENDPOINT_URL", None)
11 | AWS_S3_DEFAULT_BUCKET = os.environ.get("AWS_S3_DEFAULT_BUCKET", None)
12 | if AWS_S3_ENDPOINT_URL == "":
13 |     AWS_S3_ENDPOINT_URL = None
14 | if AWS_S3_DEFAULT_BUCKET == "":
15 |     AWS_S3_DEFAULT_BUCKET = None
16 | 
17 | 
18 | def get_now():
19 |     return round(time.time() * 1000)
20 | 
21 | 
22 | class S3Storage(BaseStorage):
23 |     def test(url):
24 |         return re.search(r"^(https?\+)?s3://", url)
25 | 
26 |     def __init__(self, url, **kwargs):
27 |         super().__init__(url, **kwargs)
28 | 
29 |         if url.startswith("s3://"):
30 |             url = "https://" + url[5:]
31 |         elif url.startswith("http+s3://"):
32 |             url = "http" + url[7:]
33 |         elif url.startswith("https+s3://"):
34 |             url = "https" + url[8:]
35 | 
36 |         s3_dest = re.match(
37 |             r"^(?P<endpoint>https?://[^/]*)(/(?P<bucket>[^/]+))?(/(?P<path>.*))?$",
38 |             url,
39 |         ).groupdict()
40 | 
41 |         if not s3_dest["endpoint"] or s3_dest["endpoint"].endswith("//"):
42 |             s3_dest["endpoint"] = AWS_S3_ENDPOINT_URL
43 |         if not s3_dest["bucket"]:
44 |             s3_dest["bucket"] = AWS_S3_DEFAULT_BUCKET
45 |         if not s3_dest["path"] or s3_dest["path"] == "":
46 |             s3_dest["path"] = kwargs.get("default_path", "")
47 | 
48 |         self.endpoint_url = s3_dest["endpoint"]
49 |         self.bucket_name = s3_dest["bucket"]
50 |         self.path = s3_dest["path"]
51 | 
52 |         self._s3resource = None
53 |         self._s3client = None
54 |         self._bucket = None
55 |         print("self.endpoint_url", self.endpoint_url)
56 | 
57 |     def s3resource(self):
58 |         if self._s3resource:
59 |             return self._s3resource
60 | 
61 |         self._s3resource = boto3.resource(
62 |             "s3",
63 |             endpoint_url=self.endpoint_url,
64 |             config=Config(signature_version="s3v4"),
65 |         )
66 |         return self._s3resource
67 | 
68 |     def s3client(self):
69 |         if self._s3client:
70 |             return self._s3client
71 | 
72 |         self._s3client = boto3.client(
73 |             "s3",
74 |             endpoint_url=self.endpoint_url,
75 |             config=Config(signature_version="s3v4"),
76 |         )
77 |         return self._s3client
78 | 
79 |     def bucket(self):
80 |         if self._bucket:
81 |             return self._bucket
82 | 
83 |         self._bucket = self.s3resource().Bucket(self.bucket_name)
84 |         return self._bucket
85 | 
86 |     def upload_file(self, source, dest):
87 |         if not dest:
88 |             dest = self.path
89 | 
90 |         upload_start = get_now()
91 |         file_size = os.stat(source).st_size
92 |         with tqdm(total=file_size, unit="B", unit_scale=True, desc="Uploading") as bar:
93 |             total_transferred = 0
94 | 
95 |             def callback(bytes_transferred):
96 |                 nonlocal total_transferred
97 |                 bar.update(bytes_transferred),
98 |                 total_transferred += bytes_transferred
99 |                 self.updateStatus("upload", total_transferred / file_size)
100 | 
101 |             result = self.bucket().upload_file(
102 |                 Filename=source, Key=dest, Callback=callback
103 |             )
104 |             print(result)
105 |         upload_total = get_now() - upload_start
106 | 
107 |         return {"$time": upload_total}
108 | 
109 |     def download_file(self, dest):
110 |         if not dest:
111 |             dest = self.path.split("/").pop()
112 |         print(f"Downloading {self.url} to {dest}...")
113 |         object = self.s3resource().Object(self.bucket_name, self.path)
114 |         object.load()
115 | 
116 |         with tqdm(
117 |             total=object.content_length, unit="B", unit_scale=True, desc="Downloading"
118 |         ) as bar:
119 |             total_transferred = 0
120 | 
121 |             def callback(bytes_transferred):
122 | nonlocal total_transferred 123 | bar.update(bytes_transferred), 124 | total_transferred += bytes_transferred 125 | self.updateStatus("download", total_transferred / object.content_length) 126 | 127 | object.download_file(Filename=dest, Callback=callback) 128 | 129 | def file_exists(self): 130 | # res = self.s3client().list_objects_v2( 131 | # Bucket=self.bucket_name, Prefix=self.path, MaxKeys=1 132 | # ) 133 | # return "Contents" in res 134 | object = self.s3resource().Object(self.bucket_name, self.path) 135 | try: 136 | object.load() 137 | except botocore.exceptions.ClientError as error: 138 | if error.response["Error"]["Code"] == "404": 139 | return False 140 | else: 141 | raise 142 | return True 143 | -------------------------------------------------------------------------------- /api/utils/storage/S3Storage_test.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import os 3 | from .S3Storage import S3Storage, AWS_S3_ENDPOINT_URL, AWS_S3_DEFAULT_BUCKET 4 | 5 | 6 | class S3StorageTest(unittest.TestCase): 7 | def test_endpoint_only_s3(self): 8 | storage = S3Storage("s3://hostname:9000") 9 | self.assertEqual(storage.endpoint_url, "https://hostname:9000") 10 | self.assertEqual(storage.bucket_name, AWS_S3_DEFAULT_BUCKET) 11 | self.assertEqual(storage.path, "") 12 | 13 | def test_endpoint_only_http_s3(self): 14 | storage = S3Storage("http+s3://hostname:9000") 15 | self.assertEqual(storage.endpoint_url, "http://hostname:9000") 16 | self.assertEqual(storage.bucket_name, AWS_S3_DEFAULT_BUCKET) 17 | self.assertEqual(storage.path, "") 18 | 19 | def test_endpoint_only_https_s3(self): 20 | storage = S3Storage("https+s3://hostname:9000") 21 | self.assertEqual(storage.endpoint_url, "https://hostname:9000") 22 | self.assertEqual(storage.bucket_name, AWS_S3_DEFAULT_BUCKET) 23 | self.assertEqual(storage.path, "") 24 | 25 | def test_bucket_only(self): 26 | storage = S3Storage("s3:///bucket") 27 | self.assertEqual(storage.endpoint_url, AWS_S3_ENDPOINT_URL) 28 | self.assertEqual(storage.bucket_name, "bucket") 29 | self.assertEqual(storage.path, "") 30 | 31 | def test_url_with_bucket_and_file_only(self): 32 | storage = S3Storage("s3:///bucket/file") 33 | self.assertEqual(storage.endpoint_url, AWS_S3_ENDPOINT_URL) 34 | self.assertEqual(storage.bucket_name, "bucket") 35 | self.assertEqual(storage.path, "file") 36 | 37 | def test_full_url_with_subdirectory(self): 38 | storage = S3Storage("s3://host/bucket/path/file") 39 | self.assertEqual(storage.endpoint_url, "https://host") 40 | self.assertEqual(storage.bucket_name, "bucket") 41 | self.assertEqual(storage.path, "path/file") 42 | 43 | 44 | if __name__ == "__main__": 45 | unittest.main() 46 | -------------------------------------------------------------------------------- /api/utils/storage/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | from .S3Storage import S3Storage 4 | from .HTTPStorage import HTTPStorage 5 | 6 | classes = [S3Storage, HTTPStorage] 7 | 8 | 9 | def Storage(url, no_raise=False, **kwargs): 10 | for StorageClass in classes: 11 | if StorageClass.test(url): 12 | return StorageClass(url, **kwargs) 13 | 14 | if no_raise: 15 | return None 16 | else: 17 | raise RuntimeError("No storage handler for: " + url) 18 | -------------------------------------------------------------------------------- /api/utils/storage/__init__test.py: -------------------------------------------------------------------------------- 1 | 
import unittest
2 | from . import Storage, S3Storage, HTTPStorage
3 | 
4 | 
5 | class StorageTest(unittest.TestCase):
6 |     def test_url_s3(self):
7 |         storage = Storage("s3://hostname:9000")
8 |         self.assertTrue(isinstance(storage, S3Storage))
9 | 
10 |     def test_url_http(self):
11 |         storage = Storage("http://hostname:9000")
12 |         self.assertTrue(isinstance(storage, HTTPStorage))
13 | 
14 |     def test_no_match_raise(self):
15 |         with self.assertRaises(RuntimeError):
16 |             storage = Storage("not_a_url")
17 | 
18 |     def test_no_match_no_raise(self):
19 |         storage = Storage("not_a_url", no_raise=True)
20 |         self.assertIsNone(storage)
21 | 
--------------------------------------------------------------------------------
/build:
--------------------------------------------------------------------------------
1 | #!/bin/sh
2 | 
3 | # This is my common way of building, but you can build however you like.
4 | # Note: if you're using a proxy, you need to have it running first.
5 | 
6 | DOCKER_BUILDKIT=1 BUILDKIT_PROGRESS=plain \
7 |   docker build \
8 |   -t gadicc/diffusers-api \
9 |   -t gadicc/diffusers-api:test \
10 |   --build-arg http_proxy="http://172.17.0.1:3128" \
11 |   --build-arg https_proxy="http://172.17.0.1:3128" \
12 |   "$@" .
13 | 
--------------------------------------------------------------------------------
/docs/internal_safetensor_cache_flow.md:
--------------------------------------------------------------------------------
1 | internal document to gather my thoughts
2 | 
3 | RUNTIME_DOWNLOADS=1 (must be build arg)
4 | IMAGE_CLOUD_CACHE="s3://" (can be env arg)
5 | CREATE_MISSING=1
6 | 
7 | e.g. stabilityai/stable-diffusion-2-1-base
8 | 
9 | 1. Try download from IMAGE_CLOUD_CACHE
10 |    1. If found, use.
11 |    2. If not found:
12 |       1. Download from HuggingFace
13 |       2. In a subprocess:
14 |          1. Save with safetensors to tmp directory
15 |          2. Upload to IMAGE_CLOUD_CACHE
16 |          3. Delete original model dir, mv tmp to model dir (for next load)
17 |       1. Run inference with HF model.
18 | 
19 | FileNotFoundError: [Errno 2] No such file or directory: '/root/.cache/huggingface/diffusers/models--stabilityai--stable-diffusion-2-1-base/refs/main'
20 | 
21 | 
22 | NVIDIA RTX Quadro 5000
23 | 
24 | NO SAFETENSORS
25 | Downloaded in 462557 ms
26 | Loading model: stabilityai/stable-diffusion-2-1 (fp32)
27 | Loaded from disk in 3113 ms, to gpu in 1644 ms
28 | 
29 | SAFETENSORS_FAST_GPU=0
30 | Loaded from disk in 2741 ms, to gpu in 557 ms
31 | 
32 | SAFETENSORS_FAST_GPU=1
33 | Loaded from disk in 1153 ms, to gpu in 1495 ms
34 | 
35 | 
36 | 
37 | NVIDIA RTX Quadro 5000 (fp16)
38 | 
39 | NO SAFETENSORS
40 | Downloaded in 462557 ms
41 | Loading model: stabilityai/stable-diffusion-2-1-base (fp16)
42 | Loaded from disk in 2043 ms, to gpu in 1539 ms
43 | 
44 | SAFETENSORS_FAST_GPU=0
45 | 
46 | 
47 | SAFETENSORS_FAST_GPU=1
48 | Loaded from disk in 1134 ms, to gpu in 1184 ms
49 | 
--------------------------------------------------------------------------------
/docs/storage.md:
--------------------------------------------------------------------------------
1 | # Storage
2 | 
3 | Most URLs passed as build args or call args support special URL schemes, both to
4 | store and retrieve files.
5 | 
6 | **The Storage API is new and may change without notice; please keep a
7 | careful eye on the CHANGELOG when upgrading**.
8 | 
9 | * [AWS S3](#s3)
10 | 
11 | 
12 | ## S3
13 | 
14 | ### Build Args
15 | 
16 | Set the following **build-args**, as appropriate (through the Banana dashboard,
17 | by modifying the appropriate lines in the `Dockerfile`, or by specifying, e.g.
18 | `--build-arg AWS_ACCESS_KEY="XXX"` etc.) 19 | 20 | ```Dockerfile 21 | ARG AWS_ACCESS_KEY_ID="XXX" 22 | ARG AWS_SECRET_ACCESS_KEY="XXX" 23 | ARG AWS_DEFAULT_REGION="us-west-1" # best for banana 24 | # Optional. ONLY SET THIS IF YOU KNOW YOU NEED TO. 25 | # Usually only if you're using non-Amazon S3-compatible storage. 26 | # If you need this, your provider will tell you exactly what 27 | # to put here. Otherwise leave it blank to automatically use 28 | # the correct Amazon S3 endpoint. 29 | ARG AWS_S3_ENDPOINT_URL 30 | ``` 31 | 32 | ### Usage 33 | 34 | In any URL where Storage is supported (e.g. dreambooth `dest_url`): 35 | 36 | * `s3://endpoint/bucket/path/to/file` 37 | * `s3:///bucket/file` (uses the default endpoint) 38 | * `s3:///bucket` (for `dest_url`, filename will match your output model) 39 | * `http+s3://...` (force http instead of https) -------------------------------------------------------------------------------- /install.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | # This entire file is no longer used but kept around for reference. 4 | 5 | if [ "$FLASH_ATTENTION" == "1" ]; then 6 | 7 | echo "Building with flash attention" 8 | git clone https://github.com/HazyResearch/flash-attention.git 9 | cd flash-attention 10 | git checkout cutlass 11 | git submodule init 12 | git submodule update 13 | python setup.py install 14 | 15 | cd .. 16 | git clone https://github.com/HazyResearch/diffusers.git 17 | pip install -e diffusers 18 | 19 | else 20 | 21 | echo "Building without flash attention" 22 | git clone https://github.com/huggingface/diffusers 23 | cd diffusers 24 | git checkout v0.9.0 25 | # 2022-11-21 [Community Pipelines] K-Diffusion Pipeline 26 | # git checkout 182eb959e5efc8c77fa31394ca55376331c0ed25 27 | # 2022-11-24 v_prediction (for SD 2.0) 28 | # git checkout 30f6f4410487b6c1cf5be2da6c7e8fc844fb9a44 29 | cd .. 30 | pip install -e diffusers 31 | 32 | fi 33 | 34 | -------------------------------------------------------------------------------- /package.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "docker-diffusers-api", 3 | "version": "0.0.1", 4 | "main": "index.js", 5 | "repository": "https://github.com/kiri-art/docker-diffusers-api.git", 6 | "author": "Gadi Cohen ", 7 | "license": "MIT", 8 | "private": true, 9 | "devDependencies": { 10 | "@semantic-release-plus/docker": "^3.1.2", 11 | "@semantic-release/changelog": "^6.0.2", 12 | "@semantic-release/git": "^10.0.1", 13 | "semantic-release": "^19.0.5", 14 | "semantic-release-plus": "^20.0.0" 15 | } 16 | } 17 | -------------------------------------------------------------------------------- /prime.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | # need to fix this. 4 | #download_model {'model_url': 's3://', 'model_id': 'Linaqruf/anything-v3.0', 'model_revision': 'fp16', 'hf_model_id': None} 5 | # {'normalized_model_id': 'models--Linaqruf--anything-v3.0--fp16'} 6 | #self.endpoint_url https://6fb830ebb3c8fed82a52524211d9c54e.r2.cloudflarestorage.com/diffusers 7 | #Downloading s3:// to /root/.cache/diffusers-api/models--Linaqruf--anything-v3.0--fp16.tar.zst... 
8 | 9 | 10 | MODELS=( 11 | # ID,precision,revision 12 | # "prompthero/openjourney-v2" 13 | # "wd-1-4-anime_e1,,,hakurei/waifu-diffusion" 14 | # "Linaqruf/anything-v3.0,fp16,diffusers" 15 | # "Linaqruf/anything-v3.0,fp16,fp16" 16 | # "stabilityai/stable-diffusion-2-1,fp16,fp16" 17 | # "stabilityai/stable-diffusion-2-1-base,fp16,fp16" 18 | # "stabilityai/stable-diffusion-2,fp16,fp16" 19 | # "stabilityai/stable-diffusion-2-base,fp16,fp16" 20 | # "CompVis/stable-diffusion-v1-4,fp16,fp16" 21 | # "runwayml/stable-diffusion-v1-5,fp16,fp16" 22 | # "runwayml/stable-diffusion-inpainting,fp16,fp16" 23 | # "hakurei/waifu-diffusion,fp16,fp16" 24 | # "hakurei/waifu-diffusion-v1-3,fp16,fp16" # from checkpoint 25 | # "rinna/japanese-stable-diffusion" 26 | # "OrangeMix/AbyssOrangeMix2,fp16" 27 | # "OrangeMix/ElyOrangeMix,fp16" 28 | # "OrangeMix/EerieOrangeMix,fp16" 29 | # "OrangeMix/BloodOrangeMix,fp16" 30 | "hakurei/wd-1-5-illusion-beta3,fp16,fp16" 31 | "hakurei/wd-1-5-ink-beta3,fp16,fp16" 32 | "hakurei/wd-1-5-mofu-beta3,fp16,fp16" 33 | "hakurei/wd-1-5-radiance-beta3,fp16,fp16", 34 | ) 35 | 36 | for MODEL_STR in ${MODELS[@]}; do 37 | IFS="," read -ra DATA <<<$MODEL_STR 38 | MODEL_ID=${DATA[0]} 39 | MODEL_PRECISION=${DATA[1]} 40 | MODEL_REVISION=${DATA[2]} 41 | HF_MODEL_ID=${DATA[3]} 42 | python test.py txt2img \ 43 | --call-arg MODEL_ID="$MODEL_ID" \ 44 | --call-arg HF_MODEL_ID="$HF_MODEL_ID" \ 45 | --call-arg MODEL_PRECISION="$MODEL_PRECISION" \ 46 | --call-arg MODEL_REVISION="$MODEL_REVISION" \ 47 | --call-arg MODEL_URL="s3://" \ 48 | --model-arg num_inference_steps=1 49 | done 50 | -------------------------------------------------------------------------------- /release.config.js: -------------------------------------------------------------------------------- 1 | // https://semantic-release.gitbook.io/semantic-release/support/faq#can-i-use-semantic-release-to-publish-non-javascript-packages 2 | module.exports = { 3 | "branches": ["main"], 4 | "plugins": [ 5 | "@semantic-release/commit-analyzer", 6 | "@semantic-release/release-notes-generator", 7 | [ 8 | "@semantic-release/changelog", 9 | { 10 | "changelogFile": "CHANGELOG.md" 11 | } 12 | ], 13 | [ 14 | "@semantic-release/git", 15 | { 16 | "assets": ["CHANGELOG.md"] 17 | } 18 | ], 19 | "@semantic-release/github", 20 | ["@semantic-release-plus/docker", { 21 | "name": "gadicc/diffusers-api" 22 | }] 23 | ] 24 | } -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | # we pin sanic==22.6.2 for compatibility with banana 2 | sanic==22.6.2 3 | sanic-ext==22.6.2 4 | # earlier sanics don't pin but require websockets<11.0 5 | websockets<11.0 6 | 7 | # now manually git cloned in a later step 8 | # diffusers==0.4.1 9 | # git+https://github.com/huggingface/diffusers@v0.5.1 10 | 11 | transformers==4.33.1 # was 4.30.2 until 2023-09-08 12 | scipy==1.11.2 # was 1.10.0 until 2023-09-08 13 | requests_futures==1.0.0 14 | numpy==1.25.1 # was 1.24.1 until 2023-09-08 15 | scikit-image==0.21.0 # was 0.19.3 until 2023-09-08 16 | accelerate==0.22.0 # was 0.20.3 until 2023-09-08 17 | triton==2.1.0 # was 2.0.0.post1 until 2023-09-08 18 | ftfy==6.1.1 19 | spacy==3.6.1 # was 3.5.0 until 2023-09-08 20 | k-diffusion==0.0.16 # was 0.0.15 until 2023-09-08 21 | safetensors==0.3.3 # was 0.3.1 until 2023-09-08 22 | 23 | torch==2.0.1 # was 1.12.1 until 2023-07-19 24 | torchvision==0.15.2 25 | pytorch_lightning==2.0.8 # was 1.9.2 until 2023-09-08 26 | 27 | 
boto3==1.28.43 # was 1.26.57 until 2023-09-08 28 | botocore==1.31.43 # was 1.29.57 until 2023-09-08 29 | 30 | pytest==7.4.2 # was 7.2.1 until 2023-09-08 31 | pytest-cov==4.1.0 # was 4.0.0 until 2023-09-08 32 | 33 | datasets==2.14.5 # was 2.8.0 until 2023-09-08 34 | omegaconf==2.3.0 35 | tensorboard==2.14.0 # was 2.12.0 until 2023-09-08 36 | 37 | xtarfile[zstd]==0.1.0 38 | 39 | bitsandbytes==0.41.1 # was 0.40.2 until 2023-09-08 40 | 41 | invisible-watermark==0.2.0 # released 2023-07-06 42 | compel==2.0.2 # was 2.0.1 until 2023-09-08 43 | jxlpy==0.9.2 # added 2023-09-11 44 | -------------------------------------------------------------------------------- /run.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | docker run -it --rm \ 4 | --gpus all \ 5 | -p 8000:8000 \ 6 | -e http_proxy="http://172.17.0.1:3128" \ 7 | -e https_proxy="http://172.17.0.1:3128" \ 8 | -e REQUESTS_CA_BUNDLE="/usr/local/share/ca-certificates/squid-self-signed.crt" \ 9 | -e HF_AUTH_TOKEN="$HF_AUTH_TOKEN" \ 10 | -e AWS_ACCESS_KEY_ID="$AWS_ACCESS_KEY_ID" \ 11 | -e AWS_SECRET_ACCESS_KEY="$AWS_SECRET_ACCESS_KEY" \ 12 | -e AWS_DEFAULT_REGION="$AWS_DEFAULT_REGION" \ 13 | -e AWS_S3_ENDPOINT_URL="$AWS_S3_ENDPOINT_URL" \ 14 | -e AWS_S3_DEFAULT_BUCKET="$AWS_S3_DEFAULT_BUCKET" \ 15 | -v ~/root-cache:/root/.cache \ 16 | "$@" gadicc/diffusers-api 17 | -------------------------------------------------------------------------------- /run_integration_tests_on_lambda.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | PAYLOAD_FILE="/tmp/request.json" 4 | 5 | if [ -z "$LAMBDA_API_KEY" ]; then 6 | echo "No LAMBDA_API_KEY set" 7 | exit 1 8 | fi 9 | 10 | SSH_KEY_FILE="$HOME/.ssh/diffusers-api-test.pem" 11 | if [ ! -f "$SSH_KEY_FILE" ]; then 12 | curl -L $DDA_TEST_PEM > $SSH_KEY_FILE 13 | chmod 600 $SSH_KEY_FILE 14 | fi 15 | 16 | #curl -u $LAMBDA_API_KEY: https://cloud.lambdalabs.com/api/v1/instances 17 | 18 | # TODO, find an available instance 19 | # https://cloud.lambdalabs.com/api/v1/instance-types 20 | 21 | lambda_run() { 22 | # $1 = lambda instance-operation 23 | if [ -z "$2" ] ; then 24 | RESULT=$( 25 | curl -su ${LAMBDA_API_KEY}: \ 26 | https://cloud.lambdalabs.com/api/v1/$1 \ 27 | -H "Content-Type: application/json" 28 | ) 29 | else 30 | RESULT=$( 31 | curl -su ${LAMBDA_API_KEY}: \ 32 | https://cloud.lambdalabs.com/api/v1/$1 \ 33 | -d @$2 -H "Content-Type: application/json" 34 | ) 35 | fi 36 | 37 | if [ $? -eq 1 ]; then 38 | echo "curl failed" 39 | exit 1 40 | fi 41 | 42 | if [ "$RESULT" != "" ]; then 43 | echo $RESULT | jq -e .error >& /dev/null 44 | if [ $? -eq 0 ]; then 45 | echo "lambda error" 46 | echo $RESULT 47 | exit 1 48 | fi 49 | fi 50 | } 51 | 52 | instance_create() { 53 | echo -n "Creating instance..." 54 | local RESULT="" 55 | cat > $PAYLOAD_FILE << __END__ 56 | { 57 | "region_name": "us-west-1", 58 | "instance_type_name": "gpu_1x_a10", 59 | "ssh_key_names": [ 60 | "diffusers-api-test" 61 | ], 62 | "file_system_names": [], 63 | "quantity": 1 64 | } 65 | __END__ 66 | 67 | lambda_run "instance-operations/launch" $PAYLOAD_FILE 68 | # echo $RESULT 69 | INSTANCE_ID=$(echo $RESULT | jq -re '.data.instance_ids[0]') 70 | echo "$INSTANCE_ID" 71 | if [ $? 
-eq 1 ]; then 72 | echo "jq failed" 73 | exit 1 74 | fi 75 | } 76 | 77 | instance_terminate() { 78 | # $1 = INSTANCE_ID 79 | echo "Terminating instance $1" 80 | cat > $PAYLOAD_FILE << __END__ 81 | { 82 | "instance_ids": [ 83 | "$1" 84 | ] 85 | } 86 | __END__ 87 | lambda_run "instance-operations/terminate" $PAYLOAD_FILE 88 | echo $RESULT 89 | } 90 | 91 | declare -A IPS 92 | instance_wait() { 93 | INSTANCE_ID="$1" 94 | echo -n "Waiting for $INSTANCE_ID" 95 | STATUS="" 96 | LAST_STATUS="" 97 | while [ "$STATUS" != "active" ] ; do 98 | echo -n "." 99 | lambda_run "instances/$INSTANCE_ID" 100 | STATUS=$(echo $RESULT | jq -r '.data.status') 101 | if [ "$STATUS" != "$LAST_STATUS" ]; then 102 | # echo $RESULT 103 | # echo STATUS $STATUS 104 | LAST_STATUS=$STATUS 105 | fi 106 | sleep 1 107 | done 108 | echo 109 | 110 | IP=$(echo $RESULT | jq -r '.data.ip') 111 | echo STATUS $STATUS 112 | echo IP $IP 113 | IPS["$INSTANCE_ID"]=$IP 114 | } 115 | 116 | instance_run_script() { 117 | INSTANCE_ID="$1" 118 | SCRIPT="$2" 119 | DIRECTORY="${3:-'.'}" 120 | IP=${IPS["$INSTANCE_ID"]} 121 | 122 | echo "instance_run_script $1 $2 $3" 123 | ssh -i $SSH_KEY_FILE ubuntu@$IP "cd $DIRECTORY && bash -s" < $SCRIPT 124 | return $? 125 | } 126 | 127 | instance_run_command() { 128 | INSTANCE_ID="$1" 129 | CMD="$2" 130 | DIRECTORY="${3:-'.'}" 131 | IP=${IPS["$INSTANCE_ID"]} 132 | 133 | echo "instance_run_command $1 $2" 134 | ssh -i $SSH_KEY_FILE -o StrictHostKeyChecking=accept-new ubuntu@$IP "cd $DIRECTORY && $CMD" 135 | return $? 136 | } 137 | 138 | instance_rsync() { 139 | INSTANCE_ID="$1" 140 | SOURCE="$2" 141 | DEST="$3" 142 | IP=${IPS["$INSTANCE_ID"]} 143 | 144 | echo "instance_rsync $1 $2 $3" 145 | rsync -avzPe "ssh -i $SSH_KEY_FILE -o StrictHostKeyChecking=accept-new" --filter=':- .gitignore' --exclude=".git" $SOURCE ubuntu@$IP:$DEST 146 | return $? 147 | } 148 | 149 | # Image Method 3, preparation (TODO, arg to specify which method) 150 | docker build -t gadicc/diffusers-api:test . 151 | docker push gadicc/diffusers-api:test 152 | 153 | instance_create 154 | # INSTANCE_ID="913e06f669bf4e799c6223801eb82f40" 155 | 156 | instance_wait $INSTANCE_ID 157 | 158 | commands() { 159 | instance_run_command $INSTANCE_ID "echo 'export HF_AUTH_TOKEN=\"$HF_AUTH_TOKEN\"' >> ~/.bashrc" 160 | 161 | # Whether to build or just for test scripts, lets transfer this checkout. 162 | instance_rsync $INSTANCE_ID . docker-diffusers-api 163 | 164 | instance_run_command $INSTANCE_ID "sudo apt-get update" 165 | if [ $? -eq 1 ]; then return 1 ; fi 166 | instance_run_command $INSTANCE_ID "sudo apt install -yqq python3.9" 167 | if [ $? -eq 1 ]; then return 1 ; fi 168 | instance_run_command $INSTANCE_ID "python3.9 -m pip install -r docker-diffusers-api/tests/integration/requirements.txt" 169 | if [ $? -eq 1 ]; then return 1 ; fi 170 | instance_run_command $INSTANCE_ID "sudo usermod -aG docker ubuntu" 171 | if [ $? -eq 1 ]; then return 1 ; fi 172 | 173 | # Image Method 1: Transfer entire image 174 | # This turned out to be way too slow, quicker to rebuild on lambda 175 | # Longer term, I guess we need our own container registry. 176 | # echo "Saving and transferring docker image to Lambda..." 177 | # IP=${IPS["$INSTANCE_ID"]} 178 | # docker save gadicc/diffusers-api:latest \ 179 | # | xz \ 180 | # | pv \ 181 | # | ssh -i $SSH_KEY_FILE ubuntu@$IP docker load 182 | # if [ $? -eq 1 ]; then return 1 ; fi 183 | 184 | # Image Method 2: Build on LambdaLabs 185 | #if [ $? 
-eq 1 ]; then return 1 ; fi 186 | #instance_run_command $INSTANCE_ID "docker build -t gadicc/diffusers-api ." docker-diffusers-api 187 | 188 | # Image Method 3: Just upload new layers; Lambda has fast downloads from registry 189 | # At start of script we have docker build/push. Now let's pull: 190 | instance_run_command $INSTANCE_ID "docker pull gadicc/diffusers-api:test" 191 | 192 | # instance_run_script $INSTANCE_ID run_integration_tests.sh docker-diffusers-api 193 | instance_run_command $INSTANCE_ID "export HF_AUTH_TOKEN=\"$HF_AUTH_TOKEN\" && python3.9 -m pytest -s tests/integration" docker-diffusers-api 194 | } 195 | 196 | commands 197 | RETURN_VALUE=$? 198 | 199 | instance_terminate $INSTANCE_ID 200 | 201 | exit $RETURN_VALUE -------------------------------------------------------------------------------- /scripts/devContainerPostCreate.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # devcontainer.json postCreateCommand 4 | 5 | echo 6 | echo Initialize conda bindings for bash 7 | conda init bash 8 | 9 | echo Activating 10 | source /opt/conda/bin/activate base 11 | 12 | echo Installing dev dependencies 13 | pip install watchdog 14 | -------------------------------------------------------------------------------- /scripts/devContainerServer.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | source /opt/conda/bin/activate base 4 | 5 | ln -sf /api/diffusers . 6 | 7 | watchmedo auto-restart --recursive -d api python api/server.py -------------------------------------------------------------------------------- /scripts/patchmatch-setup.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | if [ "$USE_PATCHMATCH" != "1" ]; then 4 | echo "Skipping PyPatchMatch install because USE_PATCHMATCH=$USE_PATCHMATCH" 5 | mkdir PyPatchMatch 6 | touch PyPatchMatch/patch_match.py 7 | exit 8 | fi 9 | 10 | echo "Installing PyPatchMatch because USE_PATCHMATCH=$USE_PATCHMATCH" 11 | apt-get install -yqq libopencv-dev python3-opencv > /dev/null 12 | git clone https://github.com/lkwq007/PyPatchMatch 13 | cd PyPatchMatch 14 | git checkout 0ae9b8bbdc83f84214405376f13a2056568897fb 15 | sed -i '0,/if os.name!="nt":/s//if False:/' patch_match.py 16 | make 17 | -------------------------------------------------------------------------------- /scripts/permutations.yaml: -------------------------------------------------------------------------------- 1 | list: 2 | 3 | - name: sd-v1-5 4 | HF_AUTH_TOKEN: $HF_AUTH_TOKEN 5 | MODEL_ID: runwayml/stable-diffusion-v1-5 6 | PIPELINE: ALL 7 | 8 | - name: sd-v1-4 9 | HF_AUTH_TOKEN: $HF_AUTH_TOKEN 10 | MODEL_ID: CompVis/stable-diffusion-v1-4 11 | PIPELINE: ALL 12 | 13 | - name: sd-inpaint 14 | HF_AUTH_TOKEN: $HF_AUTH_TOKEN 15 | MODEL_ID: runwayml/stable-diffusion-inpainting 16 | PIPELINE: StableDiffusionInpaintPipeline 17 | 18 | - name: sd-waifu 19 | HF_AUTH_TOKEN: $HF_AUTH_TOKEN 20 | MODEL_ID: hakurei/waifu-diffusion 21 | PIPELINE: ALL 22 | 23 | - name: sd-waifu-v1-3 24 | HF_AUTH_TOKEN: $HF_AUTH_TOKEN 25 | MODEL_ID: hakurei/waifu-diffusion-v1-3 26 | CHECKPOINT_URL: https://huggingface.co/hakurei/waifu-diffusion-v1-3/resolve/main/wd-v1-3-float16.ckpt 27 | PIPELINE: ALL 28 | 29 | - name: sd-jp 30 | HF_AUTH_TOKEN: $HF_AUTH_TOKEN 31 | MODEL_ID: rinna/japanese-stable-diffusion 32 | PIPELINE: ALL 33 | -------------------------------------------------------------------------------- /scripts/permute.sh: 
-------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # Run this in banana-sd-base's PARENT directory. 4 | # Modify the below first per your preferences 5 | 6 | # Requires `yq` from https://github.com/mikefarah/yq 7 | # Note, there are two yqs. In Archlinux the package is "go-yq". 8 | 9 | if [ -z "$1" ]; then 10 | echo "Using 'scripts/permutations.yaml' as default INFILE" 11 | echo "You can also run: permutate.sh MY_INFILE" 12 | INFILE='scripts/permutations.yaml' 13 | else 14 | INFILE=$1 15 | fi 16 | 17 | if [ -z "$TARGET_REPO_BASE" ]; then 18 | TARGET_REPO_BASE="git@github.com:kiri-art" 19 | echo 'No TARGET_REPO_BASE found, using "$TARGET_REPO_BASE"' 20 | fi 21 | 22 | permutations=$(yq e -o=j -I=0 '.list[]' $INFILE) 23 | 24 | # Needed for ! expansion in cp command further down. 25 | shopt -s extglob 26 | # Include dot files in expansion for .git .gitignore 27 | shopt -s dotglob 28 | 29 | COUNTER=0 30 | declare -A vars 31 | 32 | mkdir -p permutations 33 | 34 | while IFS="=" read permutation; do 35 | # e.g. Permutation #1: banana-sd-txt2img 36 | NAME=$(echo "$permutation" | yq e '.name') 37 | COUNTER=$[$COUNTER + 1] 38 | echo 39 | echo "Permutation #$COUNTER: $NAME" 40 | 41 | while IFS="=" read -r key value 42 | do 43 | if [ "$key" != "name" ]; then 44 | if [ "${value:0:1}" == "$" ]; then 45 | # For e.g. "$HF_AUTH_TOKEN", expand from environment 46 | value="${value:1}" 47 | vars[$key]=${!value} 48 | else 49 | vars[$key]=$value; 50 | fi 51 | fi 52 | done < <(echo $permutation | yq e 'to_entries | .[] | (.key + "=" + .value)') 53 | 54 | if [ -d "permutations/$NAME" ]; then 55 | echo "./permutations/$NAME already exists, skipping..." 56 | echo "Run 'rm -rf permutations/$NAME' first to remake this permutation" 57 | echo "In a later release, we'll merge updates in this case." 58 | continue 59 | fi 60 | 61 | # echo "mkdir permutations/$NAME" 62 | mkdir permutations/$NAME 63 | # echo 'cp -a ./!(permutations|scripts|root-cache) permutations/$NAME' 64 | cp -a ./!(permutations|scripts|root-cache) permutations/$NAME 65 | # echo cd permutations/$NAME 66 | cd permutations/$NAME 67 | 68 | echo "Substituting variables in Dockerfile" 69 | for key in "${!vars[@]}"; do 70 | value="${vars[$key]}" 71 | sed -i "s@^ARG $key.*\$@ARG $key=\"$value\"@" Dockerfile 72 | done 73 | 74 | diffusers=${vars[diffusers]} 75 | if [ "$diffusers" ]; then 76 | echo "Replacing diffusers with $diffusers" 77 | echo "!!! NOT DONE YET !!!" 78 | fi 79 | 80 | mkdir root-cache 81 | touch root-cache/non-empty-directory 82 | git add root-cache 83 | 84 | git remote rm origin 85 | git remote add upstream git@github.com:kiri-art/docker-diffusers-api.git 86 | git remote add origin $TARGET_REPO_BASE/$NAME.git 87 | 88 | echo git commit -a -m "$NAME permutation variables" 89 | git commit -a -m "$NAME permutation variables" 90 | 91 | # echo "cd ../.." 92 | cd ../.. 
93 | echo 94 | done < 60000 221 | else f"{item[1]/1000:.1f}s" 222 | if item[1] > 1000 223 | else str(item[1]) + "ms", 224 | ), 225 | timings.items(), 226 | ) 227 | ) 228 | ).replace('"', "")[1:-1] 229 | print(f"Request took {finish:.1f}s ({timings_str})") 230 | else: 231 | print(f"Request took {finish:.1f}s") 232 | 233 | if ( 234 | result.get("images_base64", None) == None 235 | and result.get("image_base64", None) == None 236 | ): 237 | error = result.get("$error", None) 238 | if error: 239 | code = error.get("code", None) 240 | name = error.get("name", None) 241 | message = error.get("message", None) 242 | stack = error.get("stack", None) 243 | if code and name and message and stack: 244 | print() 245 | title = f"Exception {code} on container:" 246 | print(title) 247 | print("-" * len(title)) 248 | # print(f'{name}("{message}")') - stack includes it. 249 | print(stack) 250 | return 251 | 252 | print(json.dumps(result, indent=4)) 253 | print() 254 | return result 255 | 256 | images_base64 = result.get("images_base64", None) 257 | if images_base64: 258 | for idx, image_byte_string in enumerate(images_base64): 259 | images_base64[idx] = decode_and_save(image_byte_string, f"{name}_{idx}") 260 | else: 261 | result["image_base64"] = decode_and_save(result["image_base64"], name) 262 | 263 | print() 264 | print(json.dumps(result, indent=4)) 265 | print() 266 | return result 267 | 268 | 269 | test( 270 | "txt2img", 271 | { 272 | "modelInputs": { 273 | "prompt": "realistic field of grass", 274 | "num_inference_steps": 20, 275 | }, 276 | "callInputs": { 277 | # "MODEL_ID": "", # (default) 278 | # "PIPELINE": "StableDiffusionPipeline", # (default) 279 | # "SCHEDULER": "DPMSolverMultistepScheduler", # (default) 280 | # "xformers_memory_efficient_attention": False, # (default) 281 | }, 282 | }, 283 | ) 284 | 285 | # multiple images 286 | test( 287 | "txt2img-multiple", 288 | { 289 | "modelInputs": { 290 | "prompt": "realistic field of grass", 291 | "num_images_per_prompt": 2, 292 | } 293 | }, 294 | ) 295 | 296 | 297 | test( 298 | "img2img", 299 | { 300 | "modelInputs": { 301 | "prompt": "A fantasy landscape, trending on artstation", 302 | "image": b64encode_file("sketch-mountains-input.jpg"), 303 | }, 304 | "callInputs": { 305 | "PIPELINE": "StableDiffusionImg2ImgPipeline", 306 | }, 307 | }, 308 | ) 309 | 310 | test( 311 | "inpaint-v1-4", 312 | { 313 | "modelInputs": { 314 | "prompt": "a cat sitting on a bench", 315 | "image": b64encode_file("overture-creations-5sI6fQgYIuo.png"), 316 | "mask_image": b64encode_file("overture-creations-5sI6fQgYIuo_mask.png"), 317 | }, 318 | "callInputs": { 319 | "MODEL_ID": "CompVis/stable-diffusion-v1-4", 320 | "PIPELINE": "StableDiffusionInpaintPipelineLegacy", 321 | "SCHEDULER": "DDIMScheduler", # Note, as of diffusers 0.3.0, no LMS yet 322 | }, 323 | }, 324 | ) 325 | 326 | test( 327 | "inpaint-sd", 328 | { 329 | "modelInputs": { 330 | "prompt": "a cat sitting on a bench", 331 | "image": b64encode_file("overture-creations-5sI6fQgYIuo.png"), 332 | "mask_image": b64encode_file("overture-creations-5sI6fQgYIuo_mask.png"), 333 | }, 334 | "callInputs": { 335 | "MODEL_ID": "runwayml/stable-diffusion-inpainting", 336 | "PIPELINE": "StableDiffusionInpaintPipeline", 337 | "SCHEDULER": "DDIMScheduler", # Note, as of diffusers 0.3.0, no LMS yet 338 | }, 339 | }, 340 | ) 341 | 342 | test( 343 | "checkpoint", 344 | { 345 | "modelInputs": { 346 | "prompt": "1girl", 347 | }, 348 | "callInputs": { 349 | "MODEL_ID": "hakurei/waifu-diffusion-v1-3", 350 | "MODEL_URL": "s3://", 351 | 
"CHECKPOINT_URL": "http://huggingface.co/hakurei/waifu-diffusion-v1-3/resolve/main/wd-v1-3-float16.ckpt", 352 | }, 353 | }, 354 | ) 355 | 356 | if os.getenv("USE_PATCHMATCH"): 357 | test( 358 | "outpaint", 359 | { 360 | "modelInputs": { 361 | "prompt": "girl with a pearl earing standing in a big room", 362 | "image": b64encode_file("girl_with_pearl_earing_outpainting_in.png"), 363 | }, 364 | "callInputs": { 365 | "MODEL_ID": "CompVis/stable-diffusion-v1-4", 366 | "PIPELINE": "StableDiffusionInpaintPipelineLegacy", 367 | "SCHEDULER": "DDIMScheduler", # Note, as of diffusers 0.3.0, no LMS yet 368 | "FILL_MODE": "patchmatch", 369 | }, 370 | }, 371 | ) 372 | 373 | # Actually we just want this to be a non-default test? 374 | if True or os.getenv("USE_DREAMBOOTH"): 375 | test( 376 | "dreambooth", 377 | # If you're calling from the command line, don't forget to a 378 | # specify a destination if you want your fine-tuned model to 379 | # be uploaded somewhere at the end. 380 | { 381 | "modelInputs": { 382 | "instance_prompt": "a photo of sks dog", 383 | "instance_images": list( 384 | map( 385 | b64encode_file, 386 | list(Path("tests/fixtures/dreambooth").iterdir()), 387 | ) 388 | ), 389 | # Option 1: upload to HuggingFace (see notes below) 390 | # Make sure your HF API token has read/write access. 391 | # "hub_model_id": "huggingFaceUsername/targetModelName", 392 | # "push_to_hub": True, 393 | }, 394 | "callInputs": { 395 | "train": "dreambooth", 396 | # Option 2: store on S3. Note the **s3:///* (x3). See notes below. 397 | # "dest_url": "s3:///bucket/filename.tar.zst". 398 | }, 399 | }, 400 | ) 401 | 402 | 403 | def main(tests_to_run, args, extraCallInputs, extraModelInputs): 404 | invalid_tests = [] 405 | for test in tests_to_run: 406 | if all_tests.get(test, None) == None: 407 | invalid_tests.append(test) 408 | 409 | if len(invalid_tests) > 0: 410 | print("No such tests: " + ", ".join(invalid_tests)) 411 | exit(1) 412 | 413 | for test in tests_to_run: 414 | runTest(test, args, extraCallInputs, extraModelInputs) 415 | 416 | 417 | if __name__ == "__main__": 418 | parser = argparse.ArgumentParser() 419 | parser.add_argument("--banana", required=False, action="store_true") 420 | parser.add_argument("--runpod", required=False, action="store_true") 421 | parser.add_argument( 422 | "--xmfe", 423 | required=False, 424 | default=None, 425 | type=lambda x: bool(distutils.util.strtobool(x)), 426 | ) 427 | parser.add_argument("--scheduler", required=False, type=str) 428 | parser.add_argument("--call-arg", action="append", type=str) 429 | parser.add_argument("--model-arg", action="append", type=str) 430 | 431 | args, tests_to_run = parser.parse_known_args() 432 | 433 | call_inputs = {} 434 | model_inputs = {} 435 | 436 | if args.call_arg: 437 | for arg in args.call_arg: 438 | name, value = arg.split("=", 1) 439 | if value.lower() == "true": 440 | value = True 441 | elif value.lower() == "false": 442 | value = False 443 | elif value.isdigit(): 444 | value = int(value) 445 | elif value.replace(".", "", 1).isdigit(): 446 | value = float(value) 447 | call_inputs.update({name: value}) 448 | 449 | if args.model_arg: 450 | for arg in args.model_arg: 451 | name, value = arg.split("=", 1) 452 | if value.lower() == "true": 453 | value = True 454 | elif value.lower() == "false": 455 | value = False 456 | elif value.isdigit(): 457 | value = int(value) 458 | elif value.replace(".", "", 1).isdigit(): 459 | value = float(value) 460 | model_inputs.update({name: value}) 461 | 462 | if args.xmfe != None: 463 | 
call_inputs.update({"xformers_memory_efficient_attention": args.xmfe}) 464 | if args.scheduler: 465 | call_inputs.update({"SCHEDULER": args.scheduler}) 466 | 467 | if len(tests_to_run) < 1: 468 | print( 469 | "Usage: python3 test.py [--banana] [--xmfe=1/0] [--scheduler=SomeScheduler] [all / test1] [test2] [etc]" 470 | ) 471 | sys.exit() 472 | elif len(tests_to_run) == 1 and ( 473 | tests_to_run[0] == "ALL" or tests_to_run[0] == "all" 474 | ): 475 | tests_to_run = list(all_tests.keys()) 476 | 477 | main( 478 | tests_to_run, 479 | vars(args), 480 | extraCallInputs=call_inputs, 481 | extraModelInputs=model_inputs, 482 | ) 483 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kiri-art/docker-diffusers-api/5521b2e6d63ef7afa9c3f7f3ef6cf520c1b54d7b/tests/__init__.py -------------------------------------------------------------------------------- /tests/fixtures/dreambooth/alvan-nee-9M0tSjb-cpA-unsplash.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kiri-art/docker-diffusers-api/5521b2e6d63ef7afa9c3f7f3ef6cf520c1b54d7b/tests/fixtures/dreambooth/alvan-nee-9M0tSjb-cpA-unsplash.jpeg -------------------------------------------------------------------------------- /tests/fixtures/dreambooth/alvan-nee-Id1DBHv4fbg-unsplash.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kiri-art/docker-diffusers-api/5521b2e6d63ef7afa9c3f7f3ef6cf520c1b54d7b/tests/fixtures/dreambooth/alvan-nee-Id1DBHv4fbg-unsplash.jpeg -------------------------------------------------------------------------------- /tests/fixtures/dreambooth/alvan-nee-bQaAJCbNq3g-unsplash.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kiri-art/docker-diffusers-api/5521b2e6d63ef7afa9c3f7f3ef6cf520c1b54d7b/tests/fixtures/dreambooth/alvan-nee-bQaAJCbNq3g-unsplash.jpeg -------------------------------------------------------------------------------- /tests/fixtures/dreambooth/alvan-nee-brFsZ7qszSY-unsplash.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kiri-art/docker-diffusers-api/5521b2e6d63ef7afa9c3f7f3ef6cf520c1b54d7b/tests/fixtures/dreambooth/alvan-nee-brFsZ7qszSY-unsplash.jpeg -------------------------------------------------------------------------------- /tests/fixtures/dreambooth/alvan-nee-eoqnr8ikwFE-unsplash.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kiri-art/docker-diffusers-api/5521b2e6d63ef7afa9c3f7f3ef6cf520c1b54d7b/tests/fixtures/dreambooth/alvan-nee-eoqnr8ikwFE-unsplash.jpeg -------------------------------------------------------------------------------- /tests/fixtures/girl_with_pearl_earing_outpainting_in.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kiri-art/docker-diffusers-api/5521b2e6d63ef7afa9c3f7f3ef6cf520c1b54d7b/tests/fixtures/girl_with_pearl_earing_outpainting_in.png -------------------------------------------------------------------------------- /tests/fixtures/overture-creations-5sI6fQgYIuo.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/kiri-art/docker-diffusers-api/5521b2e6d63ef7afa9c3f7f3ef6cf520c1b54d7b/tests/fixtures/overture-creations-5sI6fQgYIuo.png -------------------------------------------------------------------------------- /tests/fixtures/overture-creations-5sI6fQgYIuo_mask.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kiri-art/docker-diffusers-api/5521b2e6d63ef7afa9c3f7f3ef6cf520c1b54d7b/tests/fixtures/overture-creations-5sI6fQgYIuo_mask.png -------------------------------------------------------------------------------- /tests/fixtures/sketch-mountains-input.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kiri-art/docker-diffusers-api/5521b2e6d63ef7afa9c3f7f3ef6cf520c1b54d7b/tests/fixtures/sketch-mountains-input.jpg -------------------------------------------------------------------------------- /tests/integration/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kiri-art/docker-diffusers-api/5521b2e6d63ef7afa9c3f7f3ef6cf520c1b54d7b/tests/integration/__init__.py -------------------------------------------------------------------------------- /tests/integration/conftest.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import os 3 | from .lib import startContainer, get_free_port, DOCKER_GW_IP 4 | 5 | 6 | @pytest.fixture(autouse=True, scope="session") 7 | def my_fixture(): 8 | # setup_stuff 9 | print("session start") 10 | 11 | # newCache = not os.getenv("DDA_https_proxy") 12 | newCache = False 13 | 14 | if newCache: 15 | squid_port = get_free_port() 16 | http_port = get_free_port() 17 | container, stop = startContainer( 18 | "gadicc/squid-ssl-zero", 19 | ports={3128: squid_port, 3129: http_port}, 20 | ) 21 | os.environ["DDA_http_proxy"] = f"http://{DOCKER_GW_IP}:{squid_port}" 22 | os.environ["DDA_https_proxy"] = os.environ["DDA_http_proxy"] 23 | # TODO, code in getDDA to download cert 24 | 25 | yield 26 | # teardown_stuff 27 | print("session end") 28 | if newCache: 29 | stop() 30 | -------------------------------------------------------------------------------- /tests/integration/lib.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import docker 3 | import atexit 4 | import time 5 | import boto3 6 | import os 7 | import requests 8 | import socket 9 | import asyncio 10 | import sys 11 | import subprocess 12 | import selectors 13 | from threading import Thread 14 | from argparse import Namespace 15 | 16 | AWS_S3_DEFAULT_BUCKET = os.environ.get("AWS_S3_DEFAULT_BUCKET", "test") 17 | DOCKER_GW_IP = "172.17.0.1" # will override below if found 18 | 19 | myContainers = list() 20 | dockerClient = docker.DockerClient( 21 | base_url="unix://var/run/docker.sock", version="auto" 22 | ) 23 | for network in dockerClient.networks.list(): 24 | if network.attrs["Scope"] == "local" and network.attrs["Driver"] == "bridge": 25 | DOCKER_GW_IP = network.attrs["IPAM"]["Config"][0]["Gateway"] 26 | break 27 | 28 | # # https://stackoverflow.com/a/53255955/1839099 29 | # def fire_and_forget(f): 30 | # def wrapped(*args, **kwargs): 31 | # return asyncio.get_event_loop().run_in_executor(None, f, *args, *kwargs) 32 | # return wrapped 33 | # 34 | # @fire_and_forget 35 | # def log_streamer(container): 36 | # for line in container.logs(stream=True): 37 | # print(line.decode(), 
end="") 38 | 39 | 40 | def log_streamer(container, name=None): 41 | """ 42 | Streams logs to stdout/stderr. 43 | Order is not guaranteed (have tried 3 different methods) 44 | """ 45 | # Method 1: pipe streams directly -- even this doesn't guarantee order 46 | # Method 2: threads + readline 47 | # Method 3: selectors + read1 48 | method = 1 49 | 50 | if method == 1: 51 | kwargs = { 52 | "stdout": sys.stdout, 53 | "stderr": sys.stderr, 54 | } 55 | elif method == 2: 56 | kwargs = { 57 | "stdout": subprocess.PIPE, 58 | "stderr": subprocess.PIPE, 59 | "bufsize": 1, 60 | "universal_newlines": True, 61 | } 62 | elif method == 3: 63 | kwargs = { 64 | "stdout": subprocess.PIPE, 65 | "stderr": subprocess.PIPE, 66 | "bufsize": 1, 67 | } 68 | 69 | prefix = f"[{name or container.id[:7]}] " 70 | print(prefix + "== Streaming logs (stdout/stderr order not guaranteed): ==") 71 | 72 | sp = subprocess.Popen(["docker", "logs", "-f", container.id], **kwargs) 73 | 74 | if method == 2: 75 | 76 | def reader(pipe): 77 | while True: 78 | read = pipe.readline() 79 | if read == "" and sp.poll() is not None: 80 | break 81 | print(prefix + read, end="") 82 | sys.stdout.flush() 83 | sys.stderr.flush() 84 | 85 | Thread(target=reader, args=[sp.stdout]).start() 86 | Thread(target=reader, args=[sp.stderr]).start() 87 | 88 | elif method == 3: 89 | selector = selectors.DefaultSelector() 90 | selector.register(sp.stdout, selectors.EVENT_READ) 91 | selector.register(sp.stderr, selectors.EVENT_READ) 92 | loop = True 93 | 94 | while loop: 95 | for key, _ in selector.select(): 96 | data = key.fileobj.read1().decode() 97 | if not data: 98 | loop = False 99 | break 100 | line = prefix + str(data).rstrip().replace("\n", "\n" + prefix) 101 | if key.fileobj is sp.stdout: 102 | print(line) 103 | sys.stdout.flush() 104 | else: 105 | print(line, file=sys.stderr) 106 | sys.stderr.flush() 107 | 108 | 109 | def get_free_port(): 110 | s = socket.socket() 111 | s.bind(("", 0)) 112 | port = s.getsockname()[1] 113 | s.close() 114 | return port 115 | 116 | 117 | def startContainer(image, command=None, stream_logs=False, onstop=None, **kwargs): 118 | global myContainers 119 | 120 | container = dockerClient.containers.run( 121 | image, 122 | command, 123 | # auto_remove=True, 124 | detach=True, 125 | **kwargs, 126 | ) 127 | 128 | if stream_logs: 129 | log_streamer(container) 130 | 131 | myContainers.append(container) 132 | 133 | def stop(): 134 | print("stop", container.id) 135 | container.stop() 136 | container.remove() 137 | myContainers.remove(container) 138 | if onstop: 139 | onstop() 140 | 141 | while container.status != "running" and container.status != "exited": 142 | time.sleep(1) 143 | try: 144 | container.reload() 145 | except Exception as error: 146 | print(container.logs()) 147 | raise error 148 | print(container.status) 149 | 150 | # if (container.status == "exited"): 151 | # print(container.logs()) 152 | # raise Exception("unexpected exit") 153 | 154 | print("returned", container) 155 | return container, stop 156 | 157 | 158 | _minioCache = {} 159 | 160 | 161 | def getMinio(id="disposable"): 162 | cached = _minioCache.get(id, None) 163 | if cached: 164 | return Namespace(**cached) 165 | 166 | if id == "global": 167 | endpoint_url = os.getenv("AWS_S3_ENDPOINT_URL") 168 | if endpoint_url: 169 | print("Reusing existing global minio") 170 | aws_access_key_id = os.getenv("AWS_ACCESS_KEY_ID") 171 | aws_secret_access_key = os.getenv("AWS_SECRET_ACCESS_KEY") 172 | aws_s3_default_bucket = AWS_S3_DEFAULT_BUCKET 173 | s3 = boto3.client( 174 | 
"s3", 175 | endpoint_url=endpoint_url, 176 | config=boto3.session.Config(signature_version="s3v4"), 177 | aws_access_key_id=aws_access_key_id, 178 | aws_secret_access_key=aws_secret_access_key, 179 | aws_session_token=None, 180 | # verify=False, 181 | ) 182 | result = { 183 | # don't link to actual container, and don't rm it at end 184 | "container": "global", 185 | "stop": lambda: print(), 186 | # "port": port, 187 | "endpoint_url": endpoint_url, 188 | "aws_access_key_id": aws_access_key_id, 189 | "aws_secret_access_key": aws_secret_access_key, 190 | "aws_s3_default_bucket": aws_s3_default_bucket, 191 | "s3": s3, 192 | } 193 | _minioCache.update({id: result}) 194 | return Namespace(**result) 195 | else: 196 | print("Creating new global minio") 197 | 198 | port = get_free_port() 199 | 200 | def onstop(): 201 | del _minioCache[id] 202 | 203 | container, stop = startContainer( 204 | "minio/minio", 205 | "server /data --console-address :9001", 206 | ports={9000: port}, 207 | onstop=onstop, 208 | ) 209 | 210 | endpoint_url = f"http://{DOCKER_GW_IP}:{port}" 211 | 212 | while True: 213 | time.sleep(1) 214 | response = None 215 | try: 216 | print(endpoint_url + "/minio/health/live") 217 | response = requests.get(endpoint_url + "/minio/health/live") 218 | except Exception as error: 219 | print(error) 220 | 221 | if response and response.status_code == 200: 222 | break 223 | 224 | aws_access_key_id = "minioadmin" 225 | aws_secret_access_key = "minioadmin" 226 | aws_s3_default_bucket = AWS_S3_DEFAULT_BUCKET 227 | s3 = boto3.client( 228 | "s3", 229 | endpoint_url=endpoint_url, 230 | config=boto3.session.Config(signature_version="s3v4"), 231 | aws_access_key_id=aws_access_key_id, 232 | aws_secret_access_key=aws_secret_access_key, 233 | aws_session_token=None, 234 | # verify=False, 235 | ) 236 | 237 | s3.create_bucket(Bucket=AWS_S3_DEFAULT_BUCKET) 238 | 239 | result = { 240 | "container": container, 241 | "stop": stop, 242 | "port": port, 243 | "endpoint_url": endpoint_url, 244 | "aws_access_key_id": aws_access_key_id, 245 | "aws_secret_access_key": aws_secret_access_key, 246 | "aws_s3_default_bucket": aws_s3_default_bucket, 247 | "s3": s3, 248 | } 249 | _minioCache.update({id: result}) 250 | return Namespace(**result) 251 | 252 | 253 | _ddaCache = None 254 | 255 | 256 | def getDDA( 257 | minio=None, 258 | command=None, 259 | environment={}, 260 | stream_logs=False, 261 | wait=True, 262 | root_cache=True, 263 | **kwargs, 264 | ): 265 | global _ddaCache 266 | if _ddaCache: 267 | print("return _ddaCache") 268 | return Namespace(**_ddaCache) 269 | else: 270 | print("create new _dda") 271 | 272 | port = get_free_port() 273 | 274 | environment.update( 275 | { 276 | "HF_AUTH_TOKEN": os.getenv("HF_AUTH_TOKEN"), 277 | "http_proxy": os.getenv("DDA_http_proxy"), 278 | "https_proxy": os.getenv("DDA_https_proxy"), 279 | "REQUESTS_CA_BUNDLE": os.getenv("DDA_http_proxy") 280 | and "/usr/local/share/ca-certificates/squid-self-signed.crt", 281 | } 282 | ) 283 | 284 | if minio: 285 | environment.update( 286 | { 287 | "AWS_ACCESS_KEY_ID": minio.aws_access_key_id, 288 | "AWS_SECRET_ACCESS_KEY": minio.aws_secret_access_key, 289 | "AWS_DEFAULT_REGION": "", 290 | "AWS_S3_DEFAULT_BUCKET": minio.aws_s3_default_bucket, 291 | "AWS_S3_ENDPOINT_URL": minio.endpoint_url, 292 | } 293 | ) 294 | 295 | def onstop(): 296 | global _ddaCache 297 | _ddaCache = None 298 | 299 | HOME = os.getenv("HOME") 300 | 301 | container, stop = startContainer( 302 | "gadicc/diffusers-api:test", 303 | command, 304 | stream_logs=stream_logs, 305 | 
ports={8000: port}, 306 | device_requests=[docker.types.DeviceRequest(count=-1, capabilities=[["gpu"]])], 307 | environment=environment, 308 | volumes=root_cache and [f"{HOME}/root-cache:/root/.cache"], 309 | onstop=onstop, 310 | **kwargs, 311 | ) 312 | 313 | url = f"http://{DOCKER_GW_IP}:{port}/" 314 | 315 | while wait: 316 | time.sleep(1) 317 | container.reload() 318 | if container.status == "exited": 319 | if not stream_logs: 320 | print("--- EARLY EXIT ---") 321 | print(container.logs().decode()) 322 | print("--- EARLY EXIT ---") 323 | raise Exception("Early exit before successful healthcheck") 324 | 325 | response = None 326 | try: 327 | # print(url + "healthcheck") 328 | response = requests.get(url + "healthcheck") 329 | except Exception as error: 330 | # print(error) 331 | continue 332 | 333 | if response: 334 | if response.status_code == 200: 335 | result = response.json() 336 | if result["state"] == "healthy" and result["gpu"] == True: 337 | print("Ready") 338 | break 339 | else: 340 | print(response) 341 | print(response.text) 342 | else: 343 | raise Exception("Unexpected status code from dda/healthcheck") 344 | 345 | data = { 346 | "container": container, 347 | "stop": stop, 348 | "minio": minio, 349 | "port": port, 350 | "url": url, 351 | } 352 | 353 | _ddaCache = data 354 | return Namespace(**data) 355 | 356 | 357 | def cleanup(): 358 | print("cleanup") 359 | for container in myContainers: 360 | print("Stopping") 361 | print(container) 362 | container.stop() 363 | print("removing") 364 | container.remove() 365 | 366 | 367 | atexit.register(cleanup) 368 | -------------------------------------------------------------------------------- /tests/integration/requirements.txt: -------------------------------------------------------------------------------- 1 | pytest==7.2.0 2 | docker==6.0.1 3 | boto3==1.26.44 4 | Pillow==9.4.0 5 | # work around breaking changes in urllib3 2.0 6 | # until https://github.com/docker/docker-py/pull/3114/files lands 7 | urllib3<2 8 | -------------------------------------------------------------------------------- /tests/integration/test_attn_procs.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | from .lib import getMinio, getDDA 4 | from test import runTest 5 | 6 | if False: 7 | 8 | class TestAttnProcs: 9 | def setup_class(self): 10 | print("setup_class") 11 | # self.minio = minio = getMinio("global") 12 | 13 | self.dda = dda = getDDA( 14 | # minio=minio 15 | stream_logs=True, 16 | ) 17 | print(dda) 18 | 19 | self.TEST_ARGS = {"test_url": dda.url} 20 | 21 | def teardown_class(self): 22 | print("teardown_class") 23 | # self.minio.stop() - leave global up 24 | self.dda.stop() 25 | 26 | def test_lora_hf_download(self): 27 | """ 28 | Download user/repo from HuggingFace. 29 | """ 30 | # fp32 model is obviously bigger 31 | result = runTest( 32 | "txt2img", 33 | self.TEST_ARGS, 34 | { 35 | "MODEL_ID": "runwayml/stable-diffusion-v1-5", 36 | "MODEL_REVISION": "fp16", 37 | "MODEL_PRECISION": "fp16", 38 | "attn_procs": "patrickvonplaten/lora_dreambooth_dog_example", 39 | }, 40 | { 41 | "num_inference_steps": 1, 42 | "prompt": "A picture of a sks dog in a bucket", 43 | "seed": 1, 44 | "cross_attention_kwargs": {"scale": 0.5}, 45 | }, 46 | ) 47 | 48 | assert result["image_base64"] 49 | 50 | def test_lora_http_download_pytorch_bin(self): 51 | """ 52 | Download pytroch_lora_weights.bin directly. 
53 | """ 54 | result = runTest( 55 | "txt2img", 56 | self.TEST_ARGS, 57 | { 58 | "MODEL_ID": "runwayml/stable-diffusion-v1-5", 59 | "MODEL_REVISION": "fp16", 60 | "MODEL_PRECISION": "fp16", 61 | "attn_procs": "https://huggingface.co/patrickvonplaten/lora_dreambooth_dog_example/resolve/main/pytorch_lora_weights.bin", 62 | }, 63 | { 64 | "num_inference_steps": 1, 65 | "prompt": "A picture of a sks dog in a bucket", 66 | "seed": 1, 67 | "cross_attention_kwargs": {"scale": 0.5}, 68 | }, 69 | ) 70 | 71 | assert result["image_base64"] 72 | 73 | if False: 74 | # These formats are not supported by diffusers yet :( 75 | def test_lora_http_download_civitai_safetensors(self): 76 | result = runTest( 77 | "txt2img", 78 | self.TEST_ARGS, 79 | { 80 | "MODEL_ID": "runwayml/stable-diffusion-v1-5", 81 | "MODEL_REVISION": "fp16", 82 | "MODEL_PRECISION": "fp16", 83 | "attn_procs": "https://civitai.com/api/download/models/11523", 84 | "attn_procs_from_safetensors": True, 85 | }, 86 | { 87 | "num_inference_steps": 1, 88 | "prompt": "A picture of a sks dog in a bucket", 89 | "seed": 1, 90 | }, 91 | ) 92 | 93 | assert result["image_base64"] 94 | -------------------------------------------------------------------------------- /tests/integration/test_build_download.py: -------------------------------------------------------------------------------- 1 | import sys 2 | from .lib import getMinio, getDDA 3 | from test import runTest 4 | 5 | 6 | def test_cloudcache_build_download(): 7 | """ 8 | Download a model from cloud-cache at build time (no HuggingFace) 9 | """ 10 | minio = getMinio() 11 | print(minio) 12 | environment = { 13 | "RUNTIME_DOWNLOADS": 0, 14 | "MODEL_ID": "stabilityai/stable-diffusion-2-1-base", 15 | "MODEL_PRECISION": "fp16", 16 | "MODEL_REVISION": "fp16", 17 | "MODEL_URL": "s3://", # <-- 18 | } 19 | # conda = "conda run --no-capture-output -n xformers" 20 | conda = "" 21 | dda = getDDA( 22 | minio=minio, 23 | stream_logs=True, 24 | environment=environment, 25 | root_cache=False, 26 | command=[ 27 | "sh", 28 | "-c", 29 | f"{conda} python3 -u download.py && ls -l && {conda} python3 -u server.py", 30 | ], 31 | ) 32 | print(dda) 33 | assert dda.container.status == "running" 34 | 35 | ## bucket.objects.all().delete() 36 | result = runTest( 37 | "txt2img", 38 | {"test_url": dda.url}, 39 | { 40 | "MODEL_ID": "stabilityai/stable-diffusion-2-1-base", 41 | }, 42 | {"num_inference_steps": 1}, 43 | ) 44 | 45 | dda.stop() 46 | minio.stop() 47 | assert result["image_base64"] 48 | print("test successs\n\n") 49 | 50 | 51 | def test_huggingface_build_download(): 52 | """ 53 | Download a model from HuggingFace at build time (no cloud-cache) 54 | NOTE / TODO: Good starting point, but this still runs with gpu and 55 | uploads if missing. 
56 | """ 57 | environment = { 58 | "RUNTIME_DOWNLOADS": 0, 59 | "MODEL_ID": "stabilityai/stable-diffusion-2-1-base", 60 | "MODEL_PRECISION": "fp16", 61 | "MODEL_REVISION": "fp16", 62 | } 63 | # conda = "conda run --no-capture-output -n xformers" 64 | conda = "" 65 | dda = getDDA( 66 | stream_logs=True, 67 | environment=environment, 68 | root_cache=False, 69 | command=[ 70 | "sh", 71 | "-c", 72 | f"{conda} python3 -u download.py && ls -l && {conda} python3 -u server.py", 73 | ], 74 | ) 75 | print(dda) 76 | assert dda.container.status == "running" 77 | 78 | ## bucket.objects.all().delete() 79 | result = runTest( 80 | "txt2img", 81 | {"test_url": dda.url}, 82 | { 83 | "MODEL_ID": "stabilityai/stable-diffusion-2-1-base", 84 | # "MODEL_ID": "hf-internal-testing/tiny-stable-diffusion-pipe", 85 | "MODEL_PRECISION": "fp16", 86 | "MODEL_REVISION": "fp16", 87 | "MODEL_URL": "", # <-- no model_url, i.e. no cloud cache 88 | }, 89 | {"num_inference_steps": 1}, 90 | ) 91 | dda.stop() 92 | 93 | assert result["image_base64"] 94 | print("test successs\n\n") 95 | 96 | 97 | def test_checkpoint_url_build_download(): 98 | """ 99 | Download and convert a .ckpt at build time. No cloud-cache. 100 | """ 101 | environment = { 102 | "RUNTIME_DOWNLOADS": 0, 103 | "MODEL_ID": "hakurei/waifu-diffusion-v1-3", 104 | "MODEL_PRECISION": "fp16", 105 | "MODEL_REVISION": "fp16", 106 | "CHECKPOINT_URL": "https://huggingface.co/hakurei/waifu-diffusion-v1-3/resolve/main/wd-v1-3-float16.ckpt", 107 | } 108 | # conda = "conda run --no-capture-output -n xformers" 109 | conda = "" 110 | dda = getDDA( 111 | stream_logs=True, 112 | environment=environment, 113 | root_cache=False, 114 | command=[ 115 | "sh", 116 | "-c", 117 | f"{conda} python3 -u download.py && ls -l && {conda} python3 -u server.py", 118 | ], 119 | ) 120 | print(dda) 121 | assert dda.container.status == "running" 122 | 123 | ## bucket.objects.all().delete() 124 | result = runTest( 125 | "txt2img", 126 | {"test_url": dda.url}, 127 | { 128 | "MODEL_ID": "hakurei/waifu-diffusion-v1-3", 129 | "MODEL_PRECISION": "fp16", 130 | "MODEL_URL": "", # <-- no model_url, i.e. 
no cloud cache 131 | }, 132 | {"num_inference_steps": 1}, 133 | ) 134 | dda.stop() 135 | 136 | assert result["image_base64"] 137 | print("test success\n\n") 138 | -------------------------------------------------------------------------------- /tests/integration/test_cloud_cache.py: -------------------------------------------------------------------------------- 1 | import sys 2 | from .lib import getMinio, getDDA 3 | from test import runTest 4 | 5 | 6 | def test_cloud_cache_create_and_upload(): 7 | """ 8 | Check if the model exists in the cloud-cache bucket, download it otherwise, save 9 | it with safetensors, and upload model.tar.zst to the bucket 10 | """ 11 | minio = getMinio() 12 | print(minio) 13 | dda = getDDA(minio=minio, stream_logs=True, root_cache=False) 14 | print(dda) 15 | 16 | ## bucket.objects.all().delete() 17 | result = runTest( 18 | "txt2img", 19 | {"test_url": dda.url}, 20 | { 21 | "MODEL_ID": "stabilityai/stable-diffusion-2-1-base", 22 | # "MODEL_ID": "hf-internal-testing/tiny-stable-diffusion-pipe", 23 | "MODEL_PRECISION": "fp16", 24 | "MODEL_REVISION": "fp16", 25 | "MODEL_URL": "s3://", 26 | }, 27 | {"num_inference_steps": 1}, 28 | ) 29 | 30 | dda.stop() 31 | minio.stop() 32 | timings = result["$timings"] 33 | assert timings["download"] > 0 34 | assert timings["upload"] > 0 35 | -------------------------------------------------------------------------------- /tests/integration/test_dreambooth.py: -------------------------------------------------------------------------------- 1 | import os 2 | from .lib import getMinio, getDDA 3 | from test import runTest 4 | 5 | HF_USERNAME = os.getenv("HF_USERNAME", "gadicc") 6 | 7 | 8 | class TestDreamBoothS3: 9 | """ 10 | Train/Infer via S3 model save. 11 | """ 12 | 13 | def setup_class(self): 14 | print("setup_class") 15 | self.minio = getMinio("global") 16 | 17 | def teardown_class(self): 18 | print("teardown_class") 19 | # self.minio.stop() # leave global up. 20 | 21 | def test_training_s3(self): 22 | dda = getDDA( 23 | minio=self.minio, 24 | stream_logs=True, 25 | ) 26 | print(dda) 27 | 28 | result = runTest( 29 | "dreambooth", 30 | {"test_url": dda.url}, 31 | { 32 | "MODEL_ID": "stabilityai/stable-diffusion-2-1-base", 33 | "MODEL_REVISION": "", 34 | "MODEL_PRECISION": "", 35 | "MODEL_URL": "s3://", 36 | "train": "dreambooth", 37 | "dest_url": f"s3:///{self.minio.aws_s3_default_bucket}/model.tar.zst", 38 | }, 39 | {"max_train_steps": 1}, 40 | ) 41 | 42 | dda.stop() 43 | timings = result["$timings"] 44 | assert timings["training"] > 0 45 | assert timings["upload"] > 0 46 | 47 | # Depends on the test above; TODO: mark it as such (see the sketch below).
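# A minimal sketch of how that dependency could be marked, assuming the
# third-party pytest-dependency plugin were added to tests/integration/requirements.txt
# (it is not a dependency of this suite today):
#
#   @pytest.mark.dependency()
#   def test_training_s3(self): ...
#
#   @pytest.mark.dependency(depends=["TestDreamBoothS3::test_training_s3"])
#   def test_s3_download_and_inference(self): ...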
48 | def test_s3_download_and_inference(self): 49 | dda = getDDA( 50 | minio=self.minio, 51 | stream_logs=True, 52 | root_cache=False, 53 | ) 54 | print(dda) 55 | 56 | result = runTest( 57 | "txt2img", 58 | {"test_url": dda.url}, 59 | { 60 | "MODEL_ID": "model", 61 | "MODEL_PRECISION": "fp16", 62 | "MODEL_URL": f"s3:///{self.minio.aws_s3_default_bucket}/model.tar.zst", 63 | }, 64 | {"num_inference_steps": 1}, 65 | ) 66 | 67 | dda.stop() 68 | assert result["image_base64"] 69 | 70 | 71 | if os.getenv("TEST_DREAMBOOTH_HF", None): 72 | 73 | class TestDreamBoothHF: 74 | def test_training_hf(self): 75 | dda = getDDA( 76 | stream_logs=True, 77 | ) 78 | print(dda) 79 | 80 | result = runTest( 81 | "dreambooth", 82 | {"test_url": dda.url}, 83 | { 84 | "MODEL_ID": "stabilityai/stable-diffusion-2-1-base", 85 | "MODEL_REVISION": "", 86 | "MODEL_PRECISION": "", 87 | "MODEL_URL": "s3://", 88 | "train": "dreambooth", 89 | }, 90 | { 91 | "hub_model_id": f"{HF_USERNAME}/dreambooth_test", 92 | "push_to_hub": True, 93 | "max_train_steps": 1, 94 | }, 95 | ) 96 | 97 | dda.stop() 98 | timings = result["$timings"] 99 | assert timings["training"] > 0 100 | assert timings["upload"] > 0 101 | 102 | # Depends on the test above; TODO: mark it as such. 103 | def test_hf_download_and_inference(self): 104 | dda = getDDA( 105 | stream_logs=True, 106 | root_cache=False, 107 | ) 108 | print(dda) 109 | 110 | result = runTest( 111 | "txt2img", 112 | {"test_url": dda.url}, 113 | { 114 | "MODEL_ID": f"{HF_USERNAME}/dreambooth_test", 115 | "MODEL_PRECISION": "fp16", 116 | }, 117 | {"num_inference_steps": 1}, 118 | ) 119 | 120 | dda.stop() 121 | assert result["image_base64"] 122 | 123 | else: 124 | 125 | print( 126 | "Skipping dreambooth HuggingFace upload/download tests by default\n" 127 | "as they can be flaky. To run, set env var TEST_DREAMBOOTH_HF=1" 128 | ) 129 | -------------------------------------------------------------------------------- /tests/integration/test_general.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | from .lib import getMinio, getDDA 4 | from test import runTest 5 | 6 | 7 | class TestGeneralClass: 8 | """ 9 | Typical usage tests that assume the model is already available locally. 10 | txt2img, img2img, inpaint. 11 | """ 12 | 13 | CALL_ARGS = { 14 | "MODEL_ID": "stabilityai/stable-diffusion-2-1-base", 15 | "MODEL_PRECISION": "fp16", 16 | "MODEL_REVISION": "fp16", 17 | "MODEL_URL": "s3://", 18 | } 19 | 20 | MODEL_ARGS = {"num_inference_steps": 2} 21 | 22 | def setup_class(self): 23 | print("setup_class") 24 | self.minio = minio = getMinio("global") 25 | 26 | self.dda = dda = getDDA( 27 | minio=minio 28 | # stream_logs=True, 29 | ) 30 | print(dda) 31 | 32 | self.TEST_ARGS = {"test_url": dda.url} 33 | 34 | def teardown_class(self): 35 | print("teardown_class") 36 | # self.minio.stop() - leave global up 37 | self.dda.stop() 38 | 39 | def test_txt2img(self): 40 | result = runTest("txt2img", self.TEST_ARGS, self.CALL_ARGS, self.MODEL_ARGS) 41 | assert result["image_base64"] 42 | 43 | def test_img2img(self): 44 | result = runTest("img2img", self.TEST_ARGS, self.CALL_ARGS, self.MODEL_ARGS) 45 | assert result["image_base64"] 46 | 47 | # def test_inpaint(self): 48 | # """ 49 | # This is actually calling inpaint with SDv2.1, not the inpainting model, 50 | # so I guess we're testing inpaint-legacy.
51 | # """ 52 | # result = runTest("inpaint", self.TEST_ARGS, self.CALL_ARGS, self.MODEL_ARGS) 53 | # assert result["image_base64"] 54 | -------------------------------------------------------------------------------- /tests/integration/test_loras.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | from .lib import getMinio, getDDA 4 | from test import runTest 5 | 6 | 7 | class TestLoRAs: 8 | def setup_class(self): 9 | print("setup_class") 10 | # self.minio = minio = getMinio("global") 11 | 12 | self.dda = dda = getDDA( 13 | # minio=minio 14 | stream_logs=True, 15 | ) 16 | print(dda) 17 | 18 | self.TEST_ARGS = {"test_url": dda.url} 19 | 20 | def teardown_class(self): 21 | print("teardown_class") 22 | # self.minio.stop() - leave global up 23 | self.dda.stop() 24 | 25 | if False: 26 | 27 | def test_lora_hf_download(self): 28 | """ 29 | Download user/repo from HuggingFace. 30 | """ 31 | # fp32 model is obviously bigger 32 | result = runTest( 33 | "txt2img", 34 | self.TEST_ARGS, 35 | { 36 | "MODEL_ID": "runwayml/stable-diffusion-v1-5", 37 | "MODEL_REVISION": "fp16", 38 | "MODEL_PRECISION": "fp16", 39 | "attn_procs": "patrickvonplaten/lora_dreambooth_dog_example", 40 | }, 41 | { 42 | "num_inference_steps": 1, 43 | "prompt": "A picture of a sks dog in a bucket", 44 | "seed": 1, 45 | "cross_attention_kwargs": {"scale": 0.5}, 46 | }, 47 | ) 48 | 49 | assert result["image_base64"] 50 | 51 | if False: 52 | 53 | def test_lora_http_download_pytorch_bin(self): 54 | """ 55 | Download pytroch_lora_weights.bin directly. 56 | """ 57 | result = runTest( 58 | "txt2img", 59 | self.TEST_ARGS, 60 | { 61 | "MODEL_ID": "runwayml/stable-diffusion-v1-5", 62 | "MODEL_REVISION": "fp16", 63 | "MODEL_PRECISION": "fp16", 64 | "attn_procs": "https://huggingface.co/patrickvonplaten/lora_dreambooth_dog_example/resolve/main/pytorch_lora_weights.bin", 65 | }, 66 | { 67 | "num_inference_steps": 1, 68 | "prompt": "A picture of a sks dog in a bucket", 69 | "seed": 1, 70 | "cross_attention_kwargs": {"scale": 0.5}, 71 | }, 72 | ) 73 | 74 | assert result["image_base64"] 75 | 76 | # These formats are not supported by diffusers yet :( 77 | def test_lora_http_download_civitai_safetensors(self): 78 | quickTest = True 79 | 80 | callInputs = { 81 | "MODEL_ID": "NED-v1-22", 82 | # https://civitai.com/models/10028/neverending-dream-ned?modelVersionId=64094 83 | "CHECKPOINT_URL": "https://civitai.com/api/download/models/64094#fname=neverendingDreamNED_v122BakedVae.safetensors", 84 | "MODEL_PRECISION": "fp16", 85 | # https://civitai.com/models/5373/makima-chainsaw-man-lora 86 | "lora_weights": "https://civitai.com/api/download/models/6244#fname=makima_offset.safetensors", 87 | "safety_checker": False, 88 | "PIPELINE": "lpw_stable_diffusion", 89 | } 90 | modelInputs = { 91 | # https://civitai.com/images/709482 92 | "num_inference_steps": 30, 93 | "prompt": "masterpiece, (photorealistic:1.4), best quality, beautiful lighting, (ulzzang-6500:0.5), makima \(chainsaw man\), (red hair)+(long braided hair)+(bangs), yellow eyes, golden eyes, ((ringed eyes)), (white shirt), (necktie), RAW photo, 8k uhd, film grain", 94 | "negative_prompt": "(painting by bad-artist-anime:0.9), (painting by bad-artist:0.9), watermark, text, error, blurry, jpeg artifacts, cropped, worst quality, low quality, normal quality, jpeg artifacts, signature, watermark, username, artist name, (worst quality, low quality:1.4), bad anatomy", 95 | "width": 864, 96 | "height": 1304, 97 | "seed": 2281759351, 98 | 
"guidance_scale": 9, 99 | } 100 | 101 | if quickTest: 102 | callInputs.update( 103 | { 104 | # i.e. use a model we already have 105 | "MODEL_ID": "runwayml/stable-diffusion-v1-5", 106 | "MODEL_REVISION": "fp16", 107 | "CHECKPOINT_URL": None, 108 | } 109 | ) 110 | modelInputs.update( 111 | { 112 | "num_inference_steps": 1, 113 | "width": 512, 114 | "height": 512, 115 | } 116 | ) 117 | result = runTest("txt2img", self.TEST_ARGS, callInputs, modelInputs) 118 | 119 | assert result["image_base64"] 120 | -------------------------------------------------------------------------------- /tests/integration/test_memory.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | from .lib import getMinio, getDDA 4 | from test import runTest 5 | 6 | 7 | def test_memory(): 8 | """ 9 | Make sure when switching models we release VRAM afterwards. 10 | """ 11 | minio = getMinio("global") 12 | dda = getDDA( 13 | minio=minio, 14 | stream_logs=True, 15 | ) 16 | print(dda) 17 | 18 | TEST_ARGS = {"test_url": dda.url} 19 | MODEL_ARGS = {"num_inference_steps": 1} 20 | 21 | mem_usage = list() 22 | 23 | # fp32 model is obviously bigger 24 | result = runTest( 25 | "txt2img", 26 | TEST_ARGS, 27 | { 28 | "MODEL_ID": "stabilityai/stable-diffusion-2-1-base", 29 | "MODEL_REVISION": "", # <-- 30 | "MODEL_PRECISION": "", # <-- 31 | "MODEL_URL": "s3://", 32 | }, 33 | MODEL_ARGS, 34 | ) 35 | mem_usage.append(result["$mem_usage"]) 36 | 37 | # fp32 model is obviously smaller 38 | result = runTest( 39 | "txt2img", 40 | TEST_ARGS, 41 | { 42 | "MODEL_ID": "stabilityai/stable-diffusion-2-1-base", 43 | "MODEL_REVISION": "fp16", # <-- 44 | "MODEL_PRECISION": "fp16", # <-- 45 | "MODEL_URL": "s3://", 46 | }, 47 | MODEL_ARGS, 48 | ) 49 | mem_usage.append(result["$mem_usage"]) 50 | 51 | print({"mem_usage": mem_usage}) 52 | assert mem_usage[1] < mem_usage[0] 53 | 54 | dda.stop() 55 | -------------------------------------------------------------------------------- /touch: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kiri-art/docker-diffusers-api/5521b2e6d63ef7afa9c3f7f3ef6cf520c1b54d7b/touch -------------------------------------------------------------------------------- /update.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | 3 | rsync -avzPe "ssh -p $1" api/ $2:/api/ 4 | --------------------------------------------------------------------------------