├── .circleci
│   └── config.yml
├── .devcontainer
│   ├── devcontainer.json
│   └── local.example.env
├── .gitignore
├── .vscode
│   ├── settings.json
│   └── tasks.json
├── CHANGELOG.md
├── CONTRIBUTING.md
├── Dockerfile
├── LICENSE
├── README.md
├── __init__.py
├── api
│   ├── app.py
│   ├── convert_to_diffusers.py
│   ├── device.py
│   ├── download.py
│   ├── download_checkpoint.py
│   ├── extras
│   │   ├── __init__.py
│   │   └── upsample
│   │       ├── __init__.py
│   │       ├── models.py
│   │       └── upsample.py
│   ├── getPipeline.py
│   ├── getScheduler.py
│   ├── lib
│   │   ├── __init__.py
│   │   ├── prompts.py
│   │   ├── textual_inversions.py
│   │   ├── textual_inversions_test.py
│   │   └── vars.py
│   ├── loadModel.py
│   ├── precision.py
│   ├── send.py
│   ├── server.py
│   ├── status.py
│   ├── tests.py
│   ├── train_dreambooth.py
│   └── utils
│       ├── __init__.py
│       └── storage
│           ├── BaseStorage.py
│           ├── BaseStorage_test.py
│           ├── HTTPStorage.py
│           ├── S3Storage.py
│           ├── S3Storage_test.py
│           ├── __init__.py
│           └── __init__test.py
├── build
├── docs
│   ├── internal_safetensor_cache_flow.md
│   └── storage.md
├── install.sh
├── package.json
├── prime.sh
├── release.config.js
├── requirements.txt
├── run.sh
├── run_integration_tests_on_lambda.sh
├── scripts
│   ├── devContainerPostCreate.sh
│   ├── devContainerServer.sh
│   ├── patchmatch-setup.sh
│   ├── permutations.yaml
│   └── permute.sh
├── test.py
├── tests
│   ├── __init__.py
│   ├── fixtures
│   │   ├── dreambooth
│   │   │   ├── alvan-nee-9M0tSjb-cpA-unsplash.jpeg
│   │   │   ├── alvan-nee-Id1DBHv4fbg-unsplash.jpeg
│   │   │   ├── alvan-nee-bQaAJCbNq3g-unsplash.jpeg
│   │   │   ├── alvan-nee-brFsZ7qszSY-unsplash.jpeg
│   │   │   └── alvan-nee-eoqnr8ikwFE-unsplash.jpeg
│   │   ├── girl_with_pearl_earing_outpainting_in.png
│   │   ├── overture-creations-5sI6fQgYIuo.png
│   │   ├── overture-creations-5sI6fQgYIuo_mask.png
│   │   └── sketch-mountains-input.jpg
│   └── integration
│       ├── __init__.py
│       ├── conftest.py
│       ├── lib.py
│       ├── requirements.txt
│       ├── test_attn_procs.py
│       ├── test_build_download.py
│       ├── test_cloud_cache.py
│       ├── test_dreambooth.py
│       ├── test_general.py
│       ├── test_loras.py
│       └── test_memory.py
├── touch
├── update.sh
└── yarn.lock
/.circleci/config.yml:
--------------------------------------------------------------------------------
1 | version: 2.1
2 |
3 | jobs:
4 | build:
5 | docker:
6 | - image: cimg/python:3.9-node
7 | resource_class: medium
8 |
9 | # would have been nice, but not for $2,000/month!
10 | # machine:
11 | # image: ubuntu-2004-cuda-11.4:202110-01
12 | # resource_class: gpu.nvidia.small
13 |
14 | steps:
15 | - checkout
16 |
17 | - setup_remote_docker:
18 | docker_layer_caching: true
19 |
20 | - run: docker build -t gadicc/diffusers-api .
21 |
22 | # unit tests
23 | # - run: docker run gadicc/diffusers-api conda run --no-capture -n xformers pytest --cov=. --cov-report=xml --ignore=diffusers
24 | - run: docker run gadicc/diffusers-api pytest --cov=. --cov-report=xml --ignore=diffusers --ignore=Real-ESRGAN
25 |
26 | - run: echo $DOCKER_PASSWORD | docker login --username $DOCKER_USERNAME --password-stdin
27 |
28 | # push for non-semver branches (e.g. dev, feature branches)
29 | # - run:
30 | # name: Push to hub on branches not handled by semantic-release
31 | # command: |
32 | # SEMVER_BRANCHES=$(cat release.config.js | sed 's/module.exports = //' | sed 's/\/\/.*//' | jq .branches[])
33 | #
34 | # if [[ ${SEMVER_BRANCHES[@]} =~ "$CIRCLE_BRANCH" ]] ; then
35 | # echo "Skipping because '\$CIRCLE_BRANCH' == '$CIRCLE_BRANCH'"
36 | # echo "Semantic-release will handle the publishing"
37 | # else
38 | # echo "docker push gadicc/diffusers-api:$CIRCLE_BRANCH"
39 | # docker build -t gadicc/diffusers-api:$CIRCLE_BRANCH .
40 | # docker push gadicc/diffusers-api:$CIRCLE_BRANCH
41 | # echo "Skipping integration tests"
42 | # circleci-agent step halt
43 | # fi
44 |
45 | # needed for later "apt install" steps
46 | - run: sudo apt-get update
47 |
48 | ## TODO. The below was a great first step, but in future, let's build
49 | # the container on the host, run docker remotely on lambda, and
50 | # publish the same built image if tests pass.
51 |
52 | # TODO, only run on main channel for releases (with sem-rel too)
53 | # integration tests
54 | - run: sudo apt install -yqq rsync pv
55 | - run: ./run_integration_tests_on_lambda.sh
56 |
57 | - run:
58 | name: Push to hub on branches not handled by semantic-release
59 | command: |
60 | SEMVER_BRANCHES=$(cat release.config.js | sed 's/module.exports = //' | sed 's/\/\/.*//' | jq .branches[])
61 |
62 | if [[ ${SEMVER_BRANCHES[@]} =~ "$CIRCLE_BRANCH" ]] ; then
63 | echo "Skipping because '\$CIRCLE_BRANCH' == '$CIRCLE_BRANCH'"
64 | echo "Semantic-release will handle the publishing"
65 | else
66 | echo "docker push gadicc/diffusers-api:$CIRCLE_BRANCH"
67 | docker build -t gadicc/diffusers-api:$CIRCLE_BRANCH .
68 | docker push gadicc/diffusers-api:$CIRCLE_BRANCH
69 | # echo "Skipping integration tests"
70 | # circleci-agent step halt
71 | fi
72 |
73 | # deploy the image
74 | # - run: docker push company/app:$CIRCLE_BRANCH
75 | # https://github.com/semantic-release-plus/semantic-release-plus/tree/master/packages/plugins/docker
76 | - run:
77 | name: release
78 | command: |
79 | sudo apt-get install yarn
80 | yarn install
81 | yarn run semantic-release-plus
--------------------------------------------------------------------------------
/.devcontainer/devcontainer.json:
--------------------------------------------------------------------------------
1 | // For format details, see https://aka.ms/devcontainer.json. For config options, see the
2 | // README at: https://github.com/devcontainers/templates/tree/main/src/docker-existing-dockerfile
3 | {
4 | "name": "Existing Dockerfile",
5 | "build": {
6 | // Sets the run context to one level up instead of the .devcontainer folder.
7 | "context": "..",
8 | // Update the 'dockerFile' property if you aren't using the standard 'Dockerfile' filename.
9 | "dockerfile": "../Dockerfile"
10 | },
11 |
12 | // Features to add to the dev container. More info: https://containers.dev/features.
13 | "features": {
14 | "ghcr.io/devcontainers/features/python:1": {
15 | // "version": "3.10"
16 | }
17 | },
18 |
19 | // Use 'forwardPorts' to make a list of ports inside the container available locally.
20 | "forwardPorts": [8000],
21 |
22 | // Uncomment the next line to run commands after the container is created.
23 | "postCreateCommand": "scripts/devContainerPostCreate.sh",
24 |
25 | "customizations": {
26 | "vscode": {
27 | "extensions": [
28 | "ryanluker.vscode-coverage-gutters",
29 | "fsevenm.run-it-on",
30 | "ms-python.black-formatter",
31 | ],
32 | "settings": {
33 | "python.pythonPath": "/opt/conda/bin/python"
34 | }
35 | }
36 | },
37 |
38 | // Uncomment to connect as an existing user other than the container default. More info: https://aka.ms/dev-containers-non-root.
39 | // "remoteUser": "devcontainer"
40 |
41 | "mounts": [
42 | "source=${localEnv:HOME}/root-cache,target=/root/.cache,type=bind,consistency=cached"
43 | ],
44 |
45 | "runArgs": [
46 | "--gpus",
47 | "all",
48 | "--env-file",
49 | ".devcontainer/local.env"
50 | ]
51 | }
52 |
--------------------------------------------------------------------------------
/.devcontainer/local.example.env:
--------------------------------------------------------------------------------
1 | # Useful environment variables:
2 |
3 | # AWS or S3-compatible storage credentials and buckets
4 | AWS_ACCESS_KEY_ID=
5 | AWS_SECRET_ACCESS_KEY=
6 | AWS_DEFAULT_REGION=
7 | AWS_S3_DEFAULT_BUCKET=
8 | # Only fill this in if your (non-AWS) provider has told you what to put here
9 | AWS_S3_ENDPOINT_URL=
10 |
11 | # To use a proxy, e.g.
12 | # https://github.com/kiri-art/docker-diffusers-api/blob/dev/CONTRIBUTING.md#local-https-caching-proxy
13 | # DDA_http_proxy=http://172.17.0.1:3128
14 | # DDA_https_proxy=http://172.17.0.1:3128
15 |
16 | # HuggingFace credentials
17 | HF_AUTH_TOKEN=
18 | HF_USERNAME=
19 |
20 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | /lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 | permutations
131 | tests/output
132 | node_modules
133 | .devcontainer/local.env
134 |
--------------------------------------------------------------------------------
/.vscode/settings.json:
--------------------------------------------------------------------------------
1 | {
2 | "python.testing.pytestArgs": [
3 | "--cov=.",
4 | "--cov-report=xml",
5 | "--ignore=test.py",
6 | "--ignore=tests/integration",
7 | "--ignore=diffusers",
8 | // "unit_tests.py"
9 | // "."
10 | ],
11 | "python.testing.unittestEnabled": false,
12 | "python.testing.pytestEnabled": true,
13 | // "python.defaultInterpreterPath": "/opt/conda/envs/xformers/bin/python",
14 | "python.defaultInterpreterPath": "/opt/conda/bin/python",
15 | "runItOn": {
16 | "commands": [
17 | {
18 | "match": "\\.py$",
19 | "isAsync": true,
20 | "isShellCommand": false,
21 | "cmd": "testing.runAll"
22 | },
23 | ],
24 | },
25 | "[python]": {
26 | "editor.defaultFormatter": "ms-python.black-formatter"
27 | },
28 | "python.formatting.provider": "none"
29 | }
30 |
--------------------------------------------------------------------------------
/.vscode/tasks.json:
--------------------------------------------------------------------------------
1 | {
2 | // See https://go.microsoft.com/fwlink/?LinkId=733558
3 | // for the documentation about the tasks.json format
4 | "version": "2.0.0",
5 | "tasks": [
6 | {
7 | "label": "Watching Server",
8 | "type": "shell",
9 | "command": "scripts/devContainerServer.sh"
10 | }
11 | ]
12 | }
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # CONTRIBUTING
2 |
3 | *Tips for development*
4 |
5 | 1. [General Hints](#general)
6 | 1. [Development / Editor Setup](#editors)
7 | 1. [Visual Studio Code (vscode)](#vscode)
8 | 1. [Testing](#testing)
9 | 1. [Using Buildkit](#buildkit)
10 | 1. [Local HTTP(S) Caching Proxy](#caching)
11 | 1. [Local S3 Server](#local-s3-server)
12 | 1. [Stop on Suspend](#stop-on-suspend)
13 |
14 |
15 | ## General
16 |
17 | 1. Run docker with `-it` to make it easier to stop the container with `Ctrl-C`.
18 | 1. If you get a `CUDA initialization: CUDA unknown error` after suspend,
19 | just stop the container, `rmmod nvidia_uvm`, and restart.
20 |
21 |
22 | ## Editors
23 |
24 |
25 | ### Visual Studio Code (recommended, WIP)
26 |
27 | *We're still writing this guide, let us know of any needed improvements*
28 |
29 | This repo includes VSCode settings that allow for a) editing inside a docker container, b) tests and coverage (on save)
30 |
31 | 1. Install from https://code.visualstudio.com/
32 | 1. Install [Remote - Containers](https://marketplace.visualstudio.com/items?itemName=ms-vscode-remote.remote-containers) extension.
33 | 1. Open your docker-diffusers-api folder; you'll get a popup in the bottom right saying a dev container environment was detected. Click "Reload in Container".
34 | 1. Look for the "( ) Watch" item on the status bar and click it so it changes to "( ) XX Coverage"
35 |
36 | **Live Development**
37 |
38 | 1. **Run Task** (either Ctrl-Shift-P and "Run Task", or in the Terminal panel, open the plus ("+") dropdown selector and choose "Run Task..." at the bottom)
39 | 1. Choose **Watching Server**. Port 8000 will be forwarded. The server will be reloaded
40 | on every file save (make sure to give it enough time to fully load before sending another
41 | request, otherwise that request will hang).
42 |
43 |
44 | ## Testing
45 |
46 | 1. **Unit testing**: exists but is sorely lacking for now. If you use the
47 | recommended editor setup above, it's probably working already.
48 |
49 | 1. **Integration / E2E**: cover most features used in production.
50 | `pytest -s tests/integration`.
51 | The `-s` is optional but streams stdout so you can follow along.
52 | Also add `-k test_name` to run a specific test (see the example below). E2E tests are LONG but you can
53 | greatly reduce subsequent run time by following the steps below for a
54 | [Local HTTP(S) Caching Proxy](#caching) and [Local S3 Server](#local-s3-server).
55 |
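For example, to stream output while running only the dreambooth integration tests (a sketch; any `-k` expression that matches a test name works):

```bash
pytest -s -k dreambooth tests/integration
```
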
56 | Docker-Diffusers-API follows Semantic Versioning. We follow the
57 | [conventional commits](https://www.conventionalcommits.org/en/v1.0.0/)
58 | standard.
59 |
60 | * On a commit to `dev`, if all CI tests pass, a new release is made to `:dev` tag.
61 | * On a commit to `main`, if all CI tests pass, a new release with appropriate
62 | major / minor / patch is made, based on appropriate tags in the commit history.
63 |
64 |
65 | ## Using BuildKit
66 |
67 | Buildkit is a docker extension that can really improve build speeds through
68 | caching and parallelization. You can enable and tweak it by adding:
69 |
70 | `DOCKER_BUILDKIT=1 BUILDKIT_PROGRESS=plain`
71 |
72 | vars before `docker build` (the `PROGRESS` var shows much more detailed
73 | build logs, which can be useful, but are much more verbose). This is
74 | already all set up in the [build](./build) script.
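
For example, a typical invocation (a sketch of what the [build](./build) script does; adjust the image tag to taste):

```bash
# Enable BuildKit and show full, unfolded build logs for this one build
DOCKER_BUILDKIT=1 BUILDKIT_PROGRESS=plain docker build -t gadicc/diffusers-api .
```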
75 |
76 |
77 | ## Local HTTP(S) Caching Proxy
78 |
79 | If you're only editing e.g. `app.py`, there's no need to worry about caching
80 | and the docker layers work amazingly. But, if you're constantly changing
81 | installed packages (apt, `requirements.txt`), `download.py`, etc, it's VERY
82 | helpful to have a local cache:
83 |
84 | ```bash
85 | # See all options at https://hub.docker.com/r/gadicc/squid-ssl-zero
86 | $ docker run -d -p 3128:3128 -p 3129:80 \
87 | --name squid --restart=always \
88 | -v /usr/local/squid:/usr/local/squid \
89 | gadicc/squid-ssl-zero
90 | ```
91 |
92 | and then set the docker build args `proxy=1` and `http_proxy` / `https_proxy`
93 | with their respective values.
94 | This is already all set up in the [build](./build) script.
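
For example (a sketch, assuming the squid container above is reachable on the default docker bridge at `172.17.0.1`):

```bash
DOCKER_BUILDKIT=1 docker build -t gadicc/diffusers-api \
  --build-arg proxy=1 \
  --build-arg http_proxy=http://172.17.0.1:3128 \
  --build-arg https_proxy=http://172.17.0.1:3128 \
  .
```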
95 |
96 | **You probably want to fine-tune /usr/local/squid/etc/squid.conf**.
97 |
98 | It will be created after you first run `gadicc/squid-ssl-zero`. You can then
99 | stop the container (`docker ps`, `docker stop container_id`), edit the file,
100 | and re-start (`docker start container_id`). For now, try something like:
101 |
102 | ```conf
103 | cache_dir ufs /usr/local/squid/cache 50000 16 256 # 50GB
104 | maximum_object_size 20 GB
105 | refresh_pattern . 52034400 50% 52034400 store-stale override-expire ignore-no-cache ignore-no-store ignore-private
106 | ```
107 |
108 | but ideally we can as a community create some rules that don't so
109 | aggressively cache every single request.
110 |
111 |
112 | ## Local S3 server
113 |
114 | If you're doing development around the S3 handling, it can be very useful to
115 | have a local S3 server, especially due to the large size of models. You
116 | can set one up like this:
117 |
118 | ```bash
119 | $ docker run -p 9000:9000 -p 9001:9001 \
120 | -v /usr/local/minio:/data quay.io/minio/minio \
121 | server /data --console-address ":9001"
122 | ```
123 |
124 | Now point a web browser to http://localhost:9001/, login with the default
125 | root credentials `minioadmin:minioadmin` and create a bucket and credentials
126 | for testing. More info at https://hub.docker.com/r/minio/minio/.
127 |
128 | Typical policy:
129 |
130 | ```json
131 | {
132 | "Version": "2012-10-17",
133 | "Statement": [
134 | {
135 | "Sid": "VisualEditor0",
136 | "Effect": "Allow",
137 | "Action": [
138 | "s3:PutObject",
139 | "s3:GetObject"
140 | ],
141 | "Resource": "arn:aws:s3:::BUCKET_NAME/*"
142 | }
143 | ]
144 | }
145 | ```
146 |
147 | Then set the **build-arg** `AWS_S3_ENDPOINT_URL="http://172.17.0.1:9000"`
148 | or as appropriate if you've changed the default docker network.
149 |
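For example (a sketch; substitute the bucket and the credentials you created in the
MinIO console, using the variable names from
[.devcontainer/local.example.env](./.devcontainer/local.example.env)):

```bash
# Bake the endpoint in at build time...
docker build -t gadicc/diffusers-api \
  --build-arg AWS_S3_ENDPOINT_URL="http://172.17.0.1:9000" .

# ...and pass matching credentials at run time
docker run --gpus all -p 8000:8000 \
  -e AWS_ACCESS_KEY_ID=... \
  -e AWS_SECRET_ACCESS_KEY=... \
  -e AWS_S3_DEFAULT_BUCKET=... \
  gadicc/diffusers-api
```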
150 |
151 | ## Stop on Suspend
152 |
153 | Maybe it's just me, but frequently I'll have issues when suspending with
154 | the container running (I guess it's a CUDA issue), either a freeze on resume,
155 | or a stuck-forever defunct process. I found it useful to automatically stop
156 | the container / process on suspend.
157 |
158 | I'm running ArchLinux and set up a `systemd` suspend hook as described
159 | [here](https://wiki.archlinux.org/title/Power_management#Sleep_hooks), to
160 | call a script, which contains:
161 |
162 | ```bash
163 | # Stop a matching docker container
164 | PID=`docker ps -qf ancestor=gadicc/diffusers-api`
165 | if [ ! -z $PID ] ; then
166 | echo "Stopping diffusers-api pid $PID"
167 | docker stop $PID
168 | fi
169 |
170 | # For a VSCode devcontainer, just kill the watchmedo process.
171 | PID=`docker ps -qf volume=/home/dragon/root-cache`
172 | if [ ! -z $PID ] ; then
173 | echo "Stopping watchmedo in container $PID"
174 | docker exec $PID /bin/bash -c 'kill `pidof -sx watchmedo`'
175 | fi
176 | ```
177 |
--------------------------------------------------------------------------------
/Dockerfile:
--------------------------------------------------------------------------------
1 | ARG FROM_IMAGE="pytorch/pytorch:2.0.1-cuda11.7-cudnn8-runtime"
2 | # ARG FROM_IMAGE="gadicc/diffusers-api-base:python3.9-pytorch1.12.1-cuda11.6-xformers"
3 | # You only need the -banana variant if you need banana's optimization
4 | # i.e. not relevant if you're using RUNTIME_DOWNLOADS
5 | # ARG FROM_IMAGE="gadicc/python3.9-pytorch1.12.1-cuda11.6-xformers-banana"
6 | FROM ${FROM_IMAGE} as base
7 | ENV FROM_IMAGE=${FROM_IMAGE}
8 |
9 | # Note, docker uses HTTP_PROXY and HTTPS_PROXY (uppercase)
10 | # We purposefully want those managed independently, as we want docker
11 | # to manage its own cache. This is just for pip, models, etc.
12 | ARG http_proxy
13 | ARG https_proxy
14 | RUN if [ -n "$http_proxy" ] ; then \
15 | echo quit \
16 | | openssl s_client -proxy $(echo ${https_proxy} | cut -b 8-) -servername google.com -connect google.com:443 -showcerts \
17 | | sed 'H;1h;$!d;x; s/^.*\(-----BEGIN CERTIFICATE-----.*-----END CERTIFICATE-----\)\n---\nServer certificate.*$/\1/' \
18 | > /usr/local/share/ca-certificates/squid-self-signed.crt ; \
19 | update-ca-certificates ; \
20 | fi
21 | ARG REQUESTS_CA_BUNDLE=${http_proxy:+/usr/local/share/ca-certificates/squid-self-signed.crt}
22 |
23 | ARG DEBIAN_FRONTEND=noninteractive
24 |
25 | ARG TZ=UTC
26 | RUN ln -snf /usr/share/zoneinfo/$TZ /etc/localtime && echo $TZ > /etc/timezone
27 |
28 | RUN apt-get update
29 | RUN apt-get install -yq apt-utils
30 | RUN apt-get install -yqq git zstd wget curl
31 |
32 | FROM base AS patchmatch
33 | ARG USE_PATCHMATCH=0
34 | WORKDIR /tmp
35 | COPY scripts/patchmatch-setup.sh .
36 | RUN sh patchmatch-setup.sh
37 |
38 | FROM base as output
39 | RUN mkdir /api
40 | WORKDIR /api
41 |
42 | # we use latest pip in base image
43 | # RUN pip3 install --upgrade pip
44 |
45 | ADD requirements.txt requirements.txt
46 | RUN pip install -r requirements.txt
47 |
48 | # [Import] Add missing settings / Correct some dummy imports (#5036) - 2023-09-14
49 | ARG DIFFUSERS_VERSION="3aa641289c995b3a0ce4ea895a76eb1128eff30c"
50 | ENV DIFFUSERS_VERSION=${DIFFUSERS_VERSION}
51 |
52 | RUN git clone https://github.com/huggingface/diffusers && cd diffusers && git checkout ${DIFFUSERS_VERSION}
53 | WORKDIR /api
54 | RUN pip install -e diffusers
55 |
56 | # Set to true to NOT download model at build time, rather at init / usage.
57 | ARG RUNTIME_DOWNLOADS=1
58 | ENV RUNTIME_DOWNLOADS=${RUNTIME_DOWNLOADS}
59 |
60 | # TODO, to dda-bananana
61 | # ARG PIPELINE="StableDiffusionInpaintPipeline"
62 | ARG PIPELINE="ALL"
63 | ENV PIPELINE=${PIPELINE}
64 |
65 | # Deps for RUNNING (not building) earlier options
66 | ARG USE_PATCHMATCH=0
67 | RUN if [ "$USE_PATCHMATCH" = "1" ] ; then apt-get install -yqq python3-opencv ; fi
68 | COPY --from=patchmatch /tmp/PyPatchMatch PyPatchMatch
69 |
70 | # TODO, just include by default, and handle all deps in OUR requirements.txt
71 | ARG USE_DREAMBOOTH=1
72 | ENV USE_DREAMBOOTH=${USE_DREAMBOOTH}
73 |
74 | RUN if [ "$USE_DREAMBOOTH" = "1" ] ; then \
75 | # By specifying the same torch version as conda, it won't download again.
76 | # Without this, it will upgrade torch, break xformers, make bigger image.
77 | # bitsandbytes==0.40.0.post4 had failed cuda detection on dreambooth test.
78 | pip install -r diffusers/examples/dreambooth/requirements.txt ; \
79 | fi
80 | RUN if [ "$USE_DREAMBOOTH" = "1" ] ; then apt-get install -yqq git-lfs ; fi
81 |
82 | ARG USE_REALESRGAN=1
83 | RUN if [ "$USE_REALESRGAN" = "1" ] ; then apt-get install -yqq libgl1-mesa-glx libglib2.0-0 ; fi
84 | RUN if [ "$USE_REALESRGAN" = "1" ] ; then git clone https://github.com/xinntao/Real-ESRGAN.git ; fi
85 | # RUN if [ "$USE_REALESRGAN" = "1" ] ; then pip install numba==0.57.1 chardet ; fi
86 | RUN if [ "$USE_REALESRGAN" = "1" ] ; then pip install basicsr==1.4.2 facexlib==0.2.5 gfpgan==1.3.8 ; fi
87 | RUN if [ "$USE_REALESRGAN" = "1" ] ; then cd Real-ESRGAN && python3 setup.py develop ; fi
88 |
89 | COPY api/ .
90 | EXPOSE 8000
91 |
92 | ARG SAFETENSORS_FAST_GPU=1
93 | ENV SAFETENSORS_FAST_GPU=${SAFETENSORS_FAST_GPU}
94 |
95 | CMD python3 -u server.py
96 |
97 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2022 Banana, Gadi Cohen
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # docker-diffusers-api ("banana-sd-base")
2 |
3 | Diffusers / Stable Diffusion in docker with a REST API, supporting various models, pipelines & schedulers. Used by [kiri.art](https://kiri.art/), perfect for local, server & serverless.
4 |
5 | [](https://hub.docker.com/r/gadicc/diffusers-api/tags) [](https://circleci.com/gh/kiri-art/docker-diffusers-api?branch=split) [](https://github.com/semantic-release/semantic-release) [](./LICENSE) [](https://vscode.dev/redirect?url=vscode://ms-vscode-remote.remote-containers/cloneInVolume?url=https://github.com/kiri-art/docker-diffusers-api)
6 |
7 | Copyright (c) Gadi Cohen, 2022. MIT Licensed.
8 | Please give credit and link back to this repo if you use it in a public project.
9 |
10 | ## Features
11 |
12 | * Models: stable-diffusion, waifu-diffusion, and easy to add others (e.g. jp-sd)
13 | * Pipelines: txt2img, img2img and inpainting in a single container
14 | ([all diffusers official and community pipelines](https://forums.kiri.art/t/all-your-pipelines-are-belong-to-us/83) are wrapped, but untested)
15 | * All model inputs supported, including setting nsfw filter per request
16 | * *Permute* base config to multiple forks based on yaml config with vars
17 | * Optionally send signed event logs / performance data to a REST endpoint / webhook.
18 | * Can automatically download a checkpoint file and convert to diffusers.
19 | * S3 support, dreambooth training.
20 |
21 | Note: This image was created for [kiri.art](https://kiri.art/).
22 | Everything is open source but there may be certain request / response
23 | assumptions. If anything is unclear, please open an issue.
24 |
25 | ## Important Notices
26 |
27 | * [Official `docker-diffusers-api` Forum](https://forums.kiri.art/c/docker-diffusers-api/16):
28 | help, updates, discussion.
29 | * Subscribe ("watch") these forum topics for:
30 | * [notable **`main`** branch updates](https://forums.kiri.art/t/official-releases-main-branch/35)
31 | * [notable **`dev`** branch updates](https://forums.kiri.art/t/development-releases-dev-branch/53)
32 | * Always [check the CHANGELOG](./CHANGELOG.md) for important updates when upgrading.
33 |
34 | **Official help in our dedicated forum https://forums.kiri.art/c/docker-diffusers-api/16.**
35 |
36 | **This README refers to the in-development `dev` branch** and may
37 | reference features and fixes not yet in the published releases.
38 |
39 | **`v1` has not been officially released yet** but has been
40 | running well in production on kiri.art for almost a month. We'd
41 | be grateful for any feedback from early adopters to help make
42 | this official. For more details, see [Upgrading from v0 to
43 | v1](https://forums.kiri.art/t/wip-upgrading-from-v0-to-v1/116).
44 | Previous releases available on the `dev-v0-final` and
45 | `main-v0-final` branches.
46 |
47 | **Currently only NVIDIA / CUDA devices are supported**. Tracking
48 | Apple / M1 support in issue
49 | [#20](https://github.com/kiri-art/docker-diffusers-api/issues/20).
50 |
51 | ## Installation & Setup:
52 |
53 | Setup varies depending on your use case.
54 |
55 | 1. **To run locally or on a *server*, with runtime downloads:**
56 |
57 | `docker run --gpus all -p 8000:8000 -e HF_AUTH_TOKEN=$HF_AUTH_TOKEN gadicc/diffusers-api`.
58 |
59 | See the [guides for various cloud providers](https://forums.kiri.art/t/running-on-other-cloud-providers/89/7).
60 |
61 | 1. **To run *serverless*, include the model at build time:**
62 |
63 | 1. [docker-diffusers-api-build-download](https://github.com/kiri-art/docker-diffusers-api-build-download) (
64 | [banana](https://forums.kiri.art/t/run-diffusers-api-on-banana-dev/103), others)
65 | 1. [docker-diffusers-api-runpod](https://github.com/kiri-art/docker-diffusers-api-runpod),
66 | see the [guide](https://forums.kiri.art/t/run-diffusers-api-on-runpod-io/102)
67 |
68 | 1. **Building from source**.
69 |
70 | 1. Fork / clone this repo.
71 | 1. `docker build -t gadicc/diffusers-api .`
72 | 1. See [CONTRIBUTING.md](./CONTRIBUTING.md) for more helpful hints.
73 |
74 | *Other configurations are possible but these are the most common cases*
75 |
76 | Everything is set via docker build-args or environment variables.
77 |
78 | ## Usage:
79 |
80 | See also [Testing](#testing) below.
81 |
82 | The container expects an `HTTP POST` request to `/`, with a JSON body resembling the following:
83 |
84 | ```json
85 | {
86 | "modelInputs": {
87 | "prompt": "Super dog",
88 | "num_inference_steps": 50,
89 | "guidance_scale": 7.5,
90 | "width": 512,
91 | "height": 512,
92 | "seed": 3239022079
93 | },
94 | "callInputs": {
95 | // You can leave these out to use the default
96 | "MODEL_ID": "runwayml/stable-diffusion-v1-5",
97 | "PIPELINE": "StableDiffusionPipeline",
98 | "SCHEDULER": "LMSDiscreteScheduler",
99 | "safety_checker": true,
100 | },
101 | }
102 | ```
103 |
104 | It's important to remember that `docker-diffusers-api` is primarily a wrapper
105 | around HuggingFace's
106 | [diffusers](https://huggingface.co/docs/diffusers/index) library.
107 | **Basic familiarity with `diffusers` is indispensable for a good experience
108 | with `docker-diffusers-api`.** Explaining some of the options above (a worked request example follows the list):
109 |
110 | * **modelInputs** - for the most part - are passed directly to the selected
111 | diffusers pipeline unchanged. So, for the default `StableDiffusionPipeline`,
112 | you can see all options in the relevant pipeline docs for its
113 | [`__call__`](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/text2img#diffusers.StableDiffusionPipeline.__call__) method. The main exceptions are:
114 |
115 | * Only valid JSON values can be given (strings, numbers, etc)
116 | * **seed**, a number, is transformed into a `generator`.
117 | * **images** are converted to / from base64 encoded strings.
118 |
119 | * **callInputs** affect which model, pipeline, scheduler and other lower
120 | level options are used to construct the final pipeline. Notably:
121 |
122 | * **`SCHEDULER`**: any scheduler included in diffusers should work out of
123 | the box, provided it can be loaded with its default config and without
124 | requiring any other explicit arguments at init time. In any event,
125 | the following schedulers are the most common and most well tested:
126 | `DPMSolverMultistepScheduler` (fast! only needs 20 steps!),
127 | `LMSDiscreteScheduler`, `DDIMScheduler`, `PNDMScheduler`,
128 | `EulerAncestralDiscreteScheduler`, `EulerDiscreteScheduler`.
129 |
130 | * **`PIPELINE`**: the most common are
131 | [`StableDiffusionPipeline`](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/text2img),
132 | [`StableDiffusionImg2ImgPipeline`](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/img2img),
133 | [`StableDiffusionInpaintPipeline`](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/inpaint), and the community
134 | [`lpw_stable_diffusion`](https://forums.kiri.art/t/lpw-stable-diffusion-pipeline-longer-prompts-prompt-weights/82)
135 | which allows for long prompts (more than 77 tokens) and prompt weights
136 | (things like `((big eyes))`, `(red hair:1.2)`, etc), and accepts a
137 | `custom_pipeline_method` callInput with values `text2img` ("text", not "txt"),
138 | `img2img` and `inpaint`. See these links for all the possible `modelInputs`
139 | that can be passed to the pipeline's `__call__` method.
140 |
141 | * **`MODEL_URL`** (optional) can be used to retrieve the model from
142 | locations other than HuggingFace, e.g. an `HTTP` server, S3-compatible
143 | storage, etc. For more info, see the
144 | [storage docs](https://github.com/kiri-art/docker-diffusers-api/blob/dev/docs/storage.md)
145 | and
146 | [this post](https://forums.kiri.art/t/safetensors-our-own-optimization-faster-model-init/98)
147 | for info on how to use and store optimized models from your own cloud.
148 |
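Putting the above together, a minimal request against a locally running container looks
like this (a sketch; see [test.py](./test.py) for a fuller client, including how the
base64-encoded images in the response are decoded and saved):

```bash
curl -X POST http://localhost:8000/ \
  -H "Content-Type: application/json" \
  -d '{
    "modelInputs": { "prompt": "Super dog", "num_inference_steps": 20, "seed": 3239022079 },
    "callInputs": {
      "MODEL_ID": "runwayml/stable-diffusion-v1-5",
      "PIPELINE": "StableDiffusionPipeline",
      "SCHEDULER": "DPMSolverMultistepScheduler"
    }
  }'
```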
149 |
150 | ## Examples and testing
151 |
152 | There are also very basic examples in [test.py](./test.py), which you can view,
153 | and run with `python test.py` if the container is already running on port 8000.
154 | You can also specify a specific test, change some options, and run against a
155 | deployed banana image:
156 |
157 | ```bash
158 | $ python test.py
159 | Usage: python3 test.py [--banana] [--xmfe=1/0] [--scheduler=SomeScheduler] [all / test1] [test2] [etc]
160 |
161 | # Run against http://localhost:8000/ (Nvidia Quadro RTX 5000)
162 | $ python test.py txt2img
163 | Running test: txt2img
164 | Request took 5.9s (init: 3.2s, inference: 5.9s)
165 | Saved /home/dragon/www/banana/banana-sd-base/tests/output/txt2img.png
166 |
167 | # Run against deployed banana image (Nvidia A100)
168 | $ export BANANA_API_KEY=XXX
169 | $ BANANA_MODEL_KEY=XXX python3 test.py --banana txt2img
170 | Running test: txt2img
171 | Request took 19.4s (init: 2.5s, inference: 3.5s)
172 | Saved /home/dragon/www/banana/banana-sd-base/tests/output/txt2img.png
173 |
174 | # Note that 2nd runs are much faster (ignore init, that isn't run again)
175 | Request took 3.0s (init: 2.4s, inference: 2.1s)
176 | ```
177 |
178 | The best example of course is https://kiri.art/ and its
179 | [source code](https://github.com/kiri-art/stable-diffusion-react-nextjs-mui-pwa).
180 |
181 | ## Help on [Official Forums](https://forums.kiri.art/c/docker-diffusers-api/16).
182 |
183 | ## Adding other Models
184 |
185 | You have two options.
186 |
187 | 1. For a diffusers model, simply set `MODEL_ID` build-var / call-arg to the name
188 | of the model hosted on HuggingFace, and it will be downloaded automatically at
189 | build time.
190 |
191 | 1. For a non-diffusers model, simply set the `CHECKPOINT_URL` build-var / call-arg
192 | to the URL of a `.ckpt` file, which will be downloaded and converted to the diffusers
193 | format automatically at build time. `CHECKPOINT_CONFIG_URL` can also be set.
194 |
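For example (a sketch; substitute your own image tag, model name and checkpoint URL):

```bash
# 1. A diffusers model hosted on HuggingFace, baked in at build time
docker build -t diffusers-api-waifu \
  --build-arg MODEL_ID="hakurei/waifu-diffusion" .

# 2. A .ckpt checkpoint, downloaded and converted at build time
docker build -t diffusers-api-custom \
  --build-arg MODEL_ID="my-custom-model" \
  --build-arg CHECKPOINT_URL="https://example.com/path/to/model.ckpt" .
```
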
195 | ## Troubleshooting
196 |
197 | * **403 Client Error: Forbidden for url**
198 |
199 | Make sure you've accepted the license on the model card of the HuggingFace model
200 | specified in `MODEL_ID`, and that you correctly passed `HF_AUTH_TOKEN` to the
201 | container.
202 |
203 | ## Event logs / web hooks / performance data
204 |
205 | Set `SEND_URL` (and optionally `SIGN_KEY`) environment variable(s) to send
206 | event and timing data on `init`, `inference` and other start and end events.
207 | This can either be used to log performance data, or for webhooks on event
208 | start / finish.
209 |
210 | The timing data is now returned in the response payload too, like this:
211 | `{ $timings: { init: timeInMs, inference: timeInMs } }`, along with any other
212 | events (such as `training`, `upload`, etc.).
213 |
214 | You can go to https://webhook.site/ and use the provided "unique URL"
215 | as your `SEND_URL` to see how it works, if you don't have your own
216 | REST endpoint (yet).
217 |
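For example (a sketch, reusing the `docker run` invocation from the setup section):

```bash
docker run --gpus all -p 8000:8000 \
  -e HF_AUTH_TOKEN=$HF_AUTH_TOKEN \
  -e SEND_URL="https://webhook.site/your-unique-url" \
  -e SIGN_KEY="my-signing-secret" \
  gadicc/diffusers-api
```
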
218 | If `SIGN_KEY` is used, you can verify the signature like this (TypeScript):
219 |
220 | ```ts
221 | import crypto from "crypto";
222 |
223 | async function handler(req: NextApiRequest, res: NextApiResponse) {
224 | const data = req.body;
225 |
226 | const containerSig = data.sig as string;
227 | delete data.sig;
228 |
229 | const ourSig = crypto
230 | .createHash("md5")
231 | .update(JSON.stringify(data) + process.env.SIGN_KEY)
232 | .digest("hex");
233 |
234 | const signatureIsValid = containerSig === ourSig;
235 | }
236 | ```
237 |
238 | If you send a callInput called `startRequestId`, it will get sent
239 | back as part of the send payload in most cases.
240 |
241 | You can also set callInputs `SEND_URL` and `SIGN_KEY` to
242 | set or override these values on a per-request basis.
243 |
244 | ## Acknowledgements
245 |
246 | * The container image is originally based on
247 | https://github.com/bananaml/serverless-template-stable-diffusion.
248 |
249 | * [CompVis](https://github.com/CompVis),
250 | [Stability AI](https://stability.ai/),
251 | [LAION](https://laion.ai/)
252 | and [RunwayML](https://runwayml.com/)
253 | for their incredible time, work and efforts in creating Stable Diffusion,
254 | and no less so, their decision to release it publicly with an open source
255 | license.
256 |
257 | * [HuggingFace](https://huggingface.co/) - for their passion and inspiration
258 | for making machine learning more accessible to developers, and in particular,
259 | their [Diffusers](https://github.com/huggingface/diffusers) library.
260 |
--------------------------------------------------------------------------------
/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kiri-art/docker-diffusers-api/5521b2e6d63ef7afa9c3f7f3ef6cf520c1b54d7b/__init__.py
--------------------------------------------------------------------------------
/api/app.py:
--------------------------------------------------------------------------------
1 | import asyncio
2 | from sched import scheduler
3 | import torch
4 |
5 | from torch import autocast
6 | from diffusers import __version__
7 | import base64
8 | from io import BytesIO
9 | import PIL
10 | import json
11 | from loadModel import loadModel
12 | from send import send, getTimings, clearSession
13 | from status import status
14 | import os
15 | import numpy as np
16 | import skimage
17 | import skimage.measure
18 | from getScheduler import getScheduler, SCHEDULERS
19 | from getPipeline import (
20 | getPipelineClass,
21 | getPipelineForModel,
22 | listAvailablePipelines,
23 | clearPipelines,
24 | )
25 | import re
26 | import requests
27 | from download import download_model, normalize_model_id
28 | import traceback
29 | from precision import MODEL_REVISION, MODEL_PRECISION
30 | from device import device, device_id, device_name
31 | from utils import Storage
32 | from hashlib import sha256
33 | from threading import Timer
34 | import extras
35 | import jxlpy
36 | from jxlpy import JXLImagePlugin
37 |
38 |
39 | from diffusers import (
40 | StableDiffusionXLPipeline,
41 | StableDiffusionXLImg2ImgPipeline,
42 | StableDiffusionXLInpaintPipeline,
43 | pipelines as diffusers_pipelines,
44 | AutoencoderTiny,
45 | AutoencoderKL,
46 | )
47 |
48 | from lib.textual_inversions import handle_textual_inversions
49 | from lib.prompts import prepare_prompts
50 | from lib.vars import (
51 | RUNTIME_DOWNLOADS,
52 | USE_DREAMBOOTH,
53 | MODEL_ID,
54 | PIPELINE,
55 | HF_AUTH_TOKEN,
56 | HOME,
57 | MODELS_DIR,
58 | )
59 |
60 | if USE_DREAMBOOTH:
61 | from train_dreambooth import TrainDreamBooth
62 | print(os.environ.get("USE_PATCHMATCH"))
63 | if os.environ.get("USE_PATCHMATCH") == "1":
64 | from PyPatchMatch import patch_match
65 |
66 | torch.set_grad_enabled(False)
67 | always_normalize_model_id = None
68 |
69 | tiny_vae = None
70 |
71 |
72 | # still working on this, not in use yet.
73 | def tinyVae(origVae: AutoencoderKL):
74 | global tiny_vae
75 | if not tiny_vae:
76 | tiny_vae = AutoencoderTiny.from_pretrained(
77 | "madebyollin/taesd",
78 | torch_dtype=torch.float16,
79 | in_channels=origVae.config.in_channels,
80 | out_channels=origVae.config.out_channels,
81 | act_fn=origVae.config.act_fn,
82 | latent_channels=origVae.config.latent_channels,
83 | scaling_factor=origVae.config.scaling_factor,
84 | force_upcast=origVae.config.force_upcast,
85 | )
86 | tiny_vae.to("cuda")
87 |
88 | return tiny_vae
89 |
90 |
91 | # Init is run on server startup
92 | # Load your model to GPU as a global variable here using the variable name "model"
93 | def init():
94 | global model # needed for banana optimizations
95 | global always_normalize_model_id
96 |
97 | asyncio.run(
98 | send(
99 | "init",
100 | "start",
101 | {
102 | "device": device_name,
103 | "hostname": os.getenv("HOSTNAME"),
104 | "model_id": MODEL_ID,
105 | "diffusers": __version__,
106 | },
107 | )
108 | )
109 |
110 | if MODEL_ID == "ALL" or RUNTIME_DOWNLOADS:
111 | global last_model_id
112 | last_model_id = None
113 |
114 | if not RUNTIME_DOWNLOADS:
115 | normalized_model_id = normalize_model_id(MODEL_ID, MODEL_REVISION)
116 | model_dir = os.path.join(MODELS_DIR, normalized_model_id)
117 | if os.path.isdir(model_dir):
118 | always_normalize_model_id = model_dir
119 | else:
120 | normalized_model_id = MODEL_ID
121 |
122 | model = loadModel(
123 | model_id=always_normalize_model_id or MODEL_ID,
124 | load=True,
125 | precision=MODEL_PRECISION,
126 | revision=MODEL_REVISION,
127 | )
128 | else:
129 | model = None
130 |
131 | asyncio.run(send("init", "done"))
132 |
133 |
134 | def decodeBase64Image(imageStr: str, name: str) -> PIL.Image:
135 | image = PIL.Image.open(BytesIO(base64.decodebytes(bytes(imageStr, "utf-8"))))
136 | print(f'Decoded image "{name}": {image.format} {image.width}x{image.height}')
137 | return image
138 |
139 |
140 | def getFromUrl(url: str, name: str) -> PIL.Image:
141 | response = requests.get(url)
142 | image = PIL.Image.open(BytesIO(response.content))
143 | print(f'Decoded image "{name}": {image.format} {image.width}x{image.height}')
144 | return image
145 |
146 |
147 | def truncateInputs(inputs: dict):
148 | clone = inputs.copy()
149 | if "modelInputs" in clone:
150 | modelInputs = clone["modelInputs"] = clone["modelInputs"].copy()
151 | for item in ["init_image", "mask_image", "image", "input_image"]:
152 | if item in modelInputs:
153 | modelInputs[item] = modelInputs[item][0:6] + "..."
154 | if "instance_images" in modelInputs:
155 | modelInputs["instance_images"] = list(
156 | map(lambda str: str[0:6] + "...", modelInputs["instance_images"])
157 | )
158 | return clone
159 |
160 |
161 | # last_xformers_memory_efficient_attention = {}
162 | last_attn_procs = None
163 | last_lora_weights = None
164 | cross_attention_kwargs = None
165 |
166 |
167 | # Inference is run for every server call
168 | # Reference your preloaded global model variable here.
169 | async def inference(all_inputs: dict, response) -> dict:
170 | global model
171 | global pipelines
172 | global last_model_id
173 | global schedulers
174 | # global last_xformers_memory_efficient_attention
175 | global always_normalize_model_id
176 | global last_attn_procs
177 | global last_lora_weights
178 | global cross_attention_kwargs
179 |
180 | clearSession()
181 |
182 | print(json.dumps(truncateInputs(all_inputs), indent=2))
183 | model_inputs = all_inputs.get("modelInputs", None)
184 | call_inputs = all_inputs.get("callInputs", None)
185 | result = {"$meta": {}}
186 |
187 | send_opts = {}
188 | if call_inputs.get("SEND_URL", None):
189 | send_opts.update({"SEND_URL": call_inputs.get("SEND_URL")})
190 | if call_inputs.get("SIGN_KEY", None):
191 | send_opts.update({"SIGN_KEY": call_inputs.get("SIGN_KEY")})
192 | if response:
193 | send_opts.update({"response": response})
194 |
195 | async def sendStatusAsync():
196 | await response.send(json.dumps(status.get()) + "\n")
197 |
198 | def sendStatus():
199 | try:
200 | asyncio.run(sendStatusAsync())
201 | Timer(1.0, sendStatus).start()
202 | except:
203 | pass
204 |
205 | Timer(1.0, sendStatus).start()
206 |
207 | if model_inputs == None or call_inputs == None:
208 | return {
209 | "$error": {
210 | "code": "INVALID_INPUTS",
211 | "message": "Expecting an object like { modelInputs: {}, callInputs: {} } but got "
212 | + json.dumps(all_inputs),
213 | }
214 | }
215 |
216 | startRequestId = call_inputs.get("startRequestId", None)
217 |
218 | use_extra = call_inputs.get("use_extra", None)
219 | if use_extra:
220 | extra = getattr(extras, use_extra, None)
221 | if not extra:
222 | return {
223 | "$error": {
224 | "code": "NO_SUCH_EXTRA",
225 | "message": 'Requested "'
226 | + use_extra
227 | + '", available: "'
228 | + '", "'.join(extras.keys())
229 | + '"',
230 | }
231 | }
232 | return await extra(
233 | model_inputs,
234 | call_inputs,
235 | send_opts=send_opts,
236 | startRequestId=startRequestId,
237 | )
238 |
239 | model_id = call_inputs.get("MODEL_ID", None)
240 | if not model_id:
241 | if not MODEL_ID:
242 | return {
243 | "$error": {
244 | "code": "NO_MODEL_ID",
245 | "message": "No callInputs.MODEL_ID specified, nor was MODEL_ID env var set.",
246 | }
247 | }
248 | model_id = MODEL_ID
249 | result["$meta"].update({"MODEL_ID": MODEL_ID})
250 | normalized_model_id = model_id
251 |
252 | if RUNTIME_DOWNLOADS:
253 | hf_model_id = call_inputs.get("HF_MODEL_ID", None)
254 | model_revision = call_inputs.get("MODEL_REVISION", None)
255 | model_precision = call_inputs.get("MODEL_PRECISION", None)
256 | checkpoint_url = call_inputs.get("CHECKPOINT_URL", None)
257 | checkpoint_config_url = call_inputs.get("CHECKPOINT_CONFIG_URL", None)
258 | normalized_model_id = normalize_model_id(model_id, model_revision)
259 | model_dir = os.path.join(MODELS_DIR, normalized_model_id)
260 | pipeline_name = call_inputs.get("PIPELINE", None)
261 | if pipeline_name:
262 | pipeline_class = getPipelineClass(pipeline_name)
263 | if last_model_id != normalized_model_id:
264 | # if not downloaded_models.get(normalized_model_id, None):
265 | if not os.path.isdir(model_dir):
266 | model_url = call_inputs.get("MODEL_URL", None)
267 | if not model_url:
268 | # return {
269 | # "$error": {
270 | # "code": "NO_MODEL_URL",
271 | # "message": "Currently RUNTIME_DOWNLOADS requires a MODEL_URL callInput",
272 | # }
273 | # }
274 | normalized_model_id = hf_model_id or model_id
275 | await download_model(
276 | model_id=model_id,
277 | model_url=model_url,
278 | model_revision=model_revision,
279 | checkpoint_url=checkpoint_url,
280 | checkpoint_config_url=checkpoint_config_url,
281 | hf_model_id=hf_model_id,
282 | model_precision=model_precision,
283 | send_opts=send_opts,
284 | pipeline_class=pipeline_class if pipeline_name else None,
285 | )
286 | # downloaded_models.update({normalized_model_id: True})
287 | clearPipelines()
288 | cross_attention_kwargs = None
289 | if model:
290 | model.to("cpu") # Necessary to avoid a memory leak
291 | await send(
292 | "loadModel", "start", {"startRequestId": startRequestId}, send_opts
293 | )
294 | model = await asyncio.to_thread(
295 | loadModel,
296 | model_id=normalized_model_id,
297 | load=True,
298 | precision=model_precision,
299 | revision=model_revision,
300 | send_opts=send_opts,
301 | pipeline_class=pipeline_class if pipeline_name else None,
302 | )
303 | await send(
304 | "loadModel", "done", {"startRequestId": startRequestId}, send_opts
305 | )
306 | last_model_id = normalized_model_id
307 | last_attn_procs = None
308 | last_lora_weights = None
309 | else:
310 | if always_normalize_model_id:
311 | normalized_model_id = always_normalize_model_id
312 | print(
313 | {
314 | "always_normalize_model_id": always_normalize_model_id,
315 | "normalized_model_id": normalized_model_id,
316 | }
317 | )
318 |
319 | if MODEL_ID == "ALL":
320 | if last_model_id != normalized_model_id:
321 | clearPipelines()
322 | cross_attention_kwargs = None
323 | model = loadModel(normalized_model_id, send_opts=send_opts)
324 | last_model_id = normalized_model_id
325 | else:
326 | if model_id != MODEL_ID and not RUNTIME_DOWNLOADS:
327 | return {
328 | "$error": {
329 | "code": "MODEL_MISMATCH",
330 | "message": f'Model "{model_id}" not available on this container which hosts "{MODEL_ID}"',
331 | "requested": model_id,
332 | "available": MODEL_ID,
333 | }
334 | }
335 |
336 | if PIPELINE == "ALL":
337 | pipeline_name = call_inputs.get("PIPELINE", None)
338 | if not pipeline_name:
339 | pipeline_name = "AutoPipelineForText2Image"
340 | result["$meta"].update({"PIPELINE": pipeline_name})
341 |
342 | pipeline = getPipelineForModel(
343 | pipeline_name,
344 | model,
345 | normalized_model_id,
346 | model_revision=model_revision if RUNTIME_DOWNLOADS else MODEL_REVISION,
347 | model_precision=model_precision if RUNTIME_DOWNLOADS else MODEL_PRECISION,
348 | )
349 | if not pipeline:
350 | return {
351 | "$error": {
352 | "code": "NO_SUCH_PIPELINE",
353 | "message": f'"{pipeline_name}" is not an official or community Diffusers pipeline',
354 | "requested": pipeline_name,
355 | "available": listAvailablePipelines(),
356 | }
357 | }
358 | else:
359 | pipeline = model
360 |
361 | scheduler_name = call_inputs.get("SCHEDULER", None)
362 | if not scheduler_name:
363 | scheduler_name = "DPMSolverMultistepScheduler"
364 | result["$meta"].update({"SCHEDULER": scheduler_name})
365 |
366 | pipeline.scheduler = getScheduler(normalized_model_id, scheduler_name)
367 | if pipeline.scheduler == None:
368 | return {
369 | "$error": {
370 | "code": "INVALID_SCHEDULER",
371 | "message": "",
372 | "requested": call_inputs.get("SCHEDULER", None),
373 | "available": ", ".join(SCHEDULERS),
374 | }
375 | }
376 |
377 | safety_checker = call_inputs.get("safety_checker", True)
378 | pipeline.safety_checker = (
379 | model.safety_checker
380 | if safety_checker and hasattr(model, "safety_checker")
381 | else None
382 | )
383 | is_url = call_inputs.get("is_url", False)
384 | image_decoder = getFromUrl if is_url else decodeBase64Image
385 |
386 | textual_inversions = call_inputs.get("textual_inversions", [])
387 | await handle_textual_inversions(textual_inversions, model, status=status)
388 |
389 | # Better to use new lora_weights in next section
390 | attn_procs = call_inputs.get("attn_procs", None)
391 | if attn_procs is not last_attn_procs:
392 | if attn_procs:
393 | raise Exception(
394 | "[REMOVED] Using `attn_procs` for LoRAs is no longer supported. "
395 | + "Please use `lora_weights` instead."
396 | )
397 | last_attn_procs = attn_procs
398 | # if attn_procs:
399 | # storage = Storage(attn_procs, no_raise=True)
400 | # if storage:
401 | # hash = sha256(attn_procs.encode("utf-8")).hexdigest()
402 | # attn_procs_from_safetensors = call_inputs.get(
403 | # "attn_procs_from_safetensors", None
404 | # )
405 | # fname = storage.url.split("/").pop()
406 | # if attn_procs_from_safetensors and not re.match(
407 | # r".safetensors", attn_procs
408 | # ):
409 | # fname += ".safetensors"
410 | # if True:
411 | # # TODO, way to specify explicit name
412 | # path = os.path.join(
413 | # MODELS_DIR, "attn_proc--url_" + hash[:7] + "--" + fname
414 | # )
415 | # attn_procs = path
416 | # if not os.path.exists(path):
417 | # storage.download_and_extract(path)
418 | # print("Load attn_procs " + attn_procs)
419 | # # Workaround https://github.com/huggingface/diffusers/pull/2448#issuecomment-1453938119
420 | # if storage and not re.search(r".safetensors", attn_procs):
421 | # attn_procs = torch.load(attn_procs, map_location="cpu")
422 | # pipeline.unet.load_attn_procs(attn_procs)
423 | # else:
424 | # print("Clearing attn procs")
425 | # pipeline.unet.set_attn_processor(CrossAttnProcessor())
426 |
427 | # Currently we only support a single string, but we should allow
428 | # an array too in anticipation of multi-LoRA support in diffusers
429 | # tracked at https://github.com/huggingface/diffusers/issues/2613.
430 | lora_weights = call_inputs.get("lora_weights", None)
431 | lora_weights_joined = json.dumps(lora_weights)
432 | if last_lora_weights != lora_weights_joined:
433 | if last_lora_weights != None and last_lora_weights != "[]":
434 | print("Unloading previous LoRA weights")
435 | pipeline.unload_lora_weights()
436 |
437 | last_lora_weights = lora_weights_joined
438 | cross_attention_kwargs = {}
439 |
440 | if type(lora_weights) is not list:
441 | lora_weights = [lora_weights] if lora_weights else []
442 |
443 | if len(lora_weights) > 0:
444 | for weights in lora_weights:
445 | storage = Storage(weights, no_raise=True, status=status)
446 | if storage:
447 | storage_query_fname = storage.query.get("fname")
448 | storage_query_scale = (
449 | float(storage.query.get("scale")[0])
450 | if storage.query.get("scale")
451 | else 1
452 | )
453 | cross_attention_kwargs.update({"scale": storage_query_scale})
454 | # https://github.com/damian0815/compel/issues/42#issuecomment-1656989385
455 | pipeline._lora_scale = storage_query_scale
456 | if storage_query_fname:
457 | fname = storage_query_fname[0]
458 | else:
459 | hash = sha256(weights.encode("utf-8")).hexdigest()
460 | fname = "url_" + hash[:7] + "--" + storage.url.split("/").pop()
461 | cache_fname = "lora_weights--" + fname
462 | path = os.path.join(MODELS_DIR, cache_fname)
463 | if not os.path.exists(path):
464 | await asyncio.to_thread(storage.download_file, path)
465 | print("Load lora_weights `" + weights + "` from `" + path + "`")
466 | pipeline.load_lora_weights(
467 | MODELS_DIR, weight_name=cache_fname, local_files_only=True
468 | )
469 | else:
470 | print("Loading from huggingface not supported yet: " + weights)
471 | # maybe something like sayakpaul/civitai-light-shadow-lora#lora=l_a_s.s9s?
472 | # lora_model_id = "sayakpaul/civitai-light-shadow-lora"
473 | # lora_filename = "light_and_shadow.safetensors"
474 | # pipeline.load_lora_weights(lora_model_id, weight_name=lora_filename)
475 | else:
476 | print("No changes to LoRAs since last call")
477 |
478 | # TODO, generalize
479 | mi_cross_attention_kwargs = model_inputs.get("cross_attention_kwargs", None)
480 | if mi_cross_attention_kwargs:
481 | model_inputs.pop("cross_attention_kwargs")
482 | if isinstance(mi_cross_attention_kwargs, str):
483 | if not cross_attention_kwargs:
484 | cross_attention_kwargs = {}
485 | cross_attention_kwargs.update(json.loads(mi_cross_attention_kwargs))
486 | elif type(mi_cross_attention_kwargs) == dict:
487 | if not cross_attention_kwargs:
488 | cross_attention_kwargs = {}
489 | cross_attention_kwargs.update(mi_cross_attention_kwargs)
490 | else:
491 | return {
492 | "$error": {
493 | "code": "INVALID_CROSS_ATTENTION_KWARGS",
494 | "message": "`cross_attention_kwargs` should be a dict or json string",
495 | }
496 | }
497 |
498 | print({"cross_attention_kwargs": cross_attention_kwargs})
499 | if cross_attention_kwargs:
500 | model_inputs.update({"cross_attention_kwargs": cross_attention_kwargs})
501 |
502 | # Parse out your arguments
503 | # prompt = model_inputs.get("prompt", None)
504 | # if prompt == None:
505 | # return {"message": "No prompt provided"}
506 | #
507 | # height = model_inputs.get("height", 512)
508 | # width = model_inputs.get("width", 512)
509 | # num_inference_steps = model_inputs.get("num_inference_steps", 50)
510 | # guidance_scale = model_inputs.get("guidance_scale", 7.5)
511 | # seed = model_inputs.get("seed", None)
512 | # strength = model_inputs.get("strength", 0.75)
513 |
514 | if "init_image" in model_inputs:
515 | model_inputs["init_image"] = image_decoder(
516 | model_inputs.get("init_image"), "init_image"
517 | )
518 |
519 | if "image" in model_inputs:
520 | model_inputs["image"] = image_decoder(model_inputs.get("image"), "image")
521 |
522 | if "mask_image" in model_inputs:
523 | model_inputs["mask_image"] = image_decoder(
524 | model_inputs.get("mask_image"), "mask_image"
525 | )
526 |
527 | if "instance_images" in model_inputs:
528 | model_inputs["instance_images"] = list(
529 | map(
530 | lambda str: image_decoder(str, "instance_image"),
531 | model_inputs["instance_images"],
532 | )
533 | )
534 |
535 | await send("inference", "start", {"startRequestId": startRequestId}, send_opts)
536 |
537 | # Run patchmatch for inpainting
538 | if call_inputs.get("FILL_MODE", None) == "patchmatch":
539 | sel_buffer = np.array(model_inputs.get("init_image"))
540 | img = sel_buffer[:, :, 0:3]
541 | mask = sel_buffer[:, :, -1]
542 | img = patch_match.inpaint(img, mask=255 - mask, patch_size=3)
543 | model_inputs["init_image"] = PIL.Image.fromarray(img)
544 | mask = 255 - mask
545 | mask = skimage.measure.block_reduce(mask, (8, 8), np.max)
546 | mask = mask.repeat(8, axis=0).repeat(8, axis=1)
547 | model_inputs["mask_image"] = PIL.Image.fromarray(mask)
548 |
549 | # Turning on takes 3ms and turning off 1ms... don't worry, I've got your back :)
550 | # x_m_e_a = call_inputs.get("xformers_memory_efficient_attention", True)
551 | # last_x_m_e_a = last_xformers_memory_efficient_attention.get(pipeline, None)
552 | # if x_m_e_a != last_x_m_e_a:
553 | # if x_m_e_a == True:
554 | # print("pipeline.enable_xformers_memory_efficient_attention()")
555 | # pipeline.enable_xformers_memory_efficient_attention() # default on
556 | # elif x_m_e_a == False:
557 | # print("pipeline.disable_xformers_memory_efficient_attention()")
558 | # pipeline.disable_xformers_memory_efficient_attention()
559 | # else:
560 | # return {
561 | # "$error": {
562 | # "code": "INVALID_XFORMERS_MEMORY_EFFICIENT_ATTENTION_VALUE",
563 | # "message": f"x_m_e_a expects True or False, not: {x_m_e_a}",
564 | # "requested": x_m_e_a,
565 | # "available": [True, False],
566 | # }
567 | # }
568 | # last_xformers_memory_efficient_attention.update({pipeline: x_m_e_a})
569 |
570 | # Run the model
571 | # with autocast(device_id):
572 | # image = pipeline(**model_inputs).images[0]
573 |
574 | if call_inputs.get("train", None) == "dreambooth":
575 | if not USE_DREAMBOOTH:
576 | return {
577 | "$error": {
578 | "code": "TRAIN_DREAMBOOTH_NOT_AVAILABLE",
579 | "message": 'Called with callInput { train: "dreambooth" } but built with USE_DREAMBOOTH=0',
580 | }
581 | }
582 |
583 | if RUNTIME_DOWNLOADS:
584 | if os.path.isdir(model_dir):
585 | normalized_model_id = model_dir
586 |
587 | torch.set_grad_enabled(True)
588 | result = result | await asyncio.to_thread(
589 | TrainDreamBooth,
590 | normalized_model_id,
591 | pipeline,
592 | model_inputs,
593 | call_inputs,
594 | send_opts=send_opts,
595 | )
596 | torch.set_grad_enabled(False)
597 | await send("inference", "done", {"startRequestId": startRequestId}, send_opts)
598 | result.update({"$timings": getTimings()})
599 | return result
600 |
601 | # Do this after dreambooth as dreambooth accepts a seed int directly.
602 | seed = model_inputs.get("seed", None)
603 | if seed == None:
604 | generator = torch.Generator(device=device)
605 | generator.seed()
606 | else:
607 | generator = torch.Generator(device=device).manual_seed(seed)
608 | del model_inputs["seed"]
609 |
610 | model_inputs.update({"generator": generator})
611 |
612 | callback = None
613 | if model_inputs.get("callback_steps", None):
614 |
615 | def callback(step: int, timestep: int, latents: torch.FloatTensor):
616 | asyncio.run(
617 | send(
618 | "inference",
619 | "progress",
620 | {"startRequestId": startRequestId, "step": step},
621 | send_opts,
622 | )
623 | )
624 |
625 | else:
626 | vae = pipeline.vae
627 | # vae = tinyVae(vae)
628 | scaling_factor = vae.config.scaling_factor
629 | image_processor = pipeline.image_processor
630 |
631 | def callback(step: int, timestep: int, latents: torch.FloatTensor):
632 | status.update(
633 | "inference", step / model_inputs.get("num_inference_steps", 50)
634 | )
635 |
636 | # with torch.no_grad():
637 | # image = vae.decode(latents / scaling_factor, return_dict=False)[0]
638 | # image = image_processor.postprocess(image, output_type="pil")[0]
639 | # image.save(f"step_{step}_img0.png")
640 |
641 | is_sdxl = (
642 | isinstance(model, StableDiffusionXLPipeline)
643 | or isinstance(model, StableDiffusionXLImg2ImgPipeline)
644 | or isinstance(model, StableDiffusionXLInpaintPipeline)
645 | )
646 |
647 | with torch.inference_mode():
648 | custom_pipeline_method = call_inputs.get("custom_pipeline_method", None)
649 | print(
650 | {
651 | "callback": callback,
652 | "**model_inputs": model_inputs,
653 | },
654 | )
655 |
656 | if call_inputs.get("compel_prompts", False):
657 | prepare_prompts(pipeline, model_inputs, is_sdxl)
658 |
659 | try:
660 | async_pipeline = asyncio.to_thread(
661 | getattr(pipeline, custom_pipeline_method)
662 | if custom_pipeline_method
663 | else pipeline,
664 | callback=callback,
665 | **model_inputs,
666 | )
667 | # if call_inputs.get("PIPELINE") != "StableDiffusionPipeline":
668 | # # autocast im2img and inpaint which are broken in 0.4.0, 0.4.1
669 | # # still broken in 0.5.1
670 | # with autocast(device_id):
671 | # images = (await async_pipeline).images
672 | # else:
673 | pipeResult = await async_pipeline
674 | images = pipeResult.images
675 |
676 | except Exception as err:
677 | return {
678 | "$error": {
679 | "code": "PIPELINE_ERROR",
680 | "name": type(err).__name__,
681 | "message": str(err),
682 | "stack": traceback.format_exc(),
683 | }
684 | }
685 |
686 | images_base64 = []
687 | image_format = call_inputs.get("image_format", "PNG")
688 | image_opts = (
689 | {"lossless": True} if image_format == "PNG" or image_format == "WEBP" else {}
690 | )
691 | for image in images:
692 | buffered = BytesIO()
693 | image.save(buffered, format=image_format, **image_opts)
694 | images_base64.append(base64.b64encode(buffered.getvalue()).decode("utf-8"))
695 |
696 | await send("inference", "done", {"startRequestId": startRequestId}, send_opts)
697 |
698 | # Return the results as a dictionary
699 | if len(images_base64) > 1:
700 | result = result | {"images_base64": images_base64}
701 | else:
702 | result = result | {"image_base64": images_base64[0]}
703 |
704 | nsfw_content_detected = pipeResult.get("nsfw_content_detected", None)
705 | if nsfw_content_detected:
706 | result = result | {"nsfw_content_detected": nsfw_content_detected}
707 |
708 | # TODO, move and generalize in device.py
709 | mem_usage = 0
710 | if torch.cuda.is_available():
711 | mem_usage = torch.cuda.memory_allocated() / torch.cuda.max_memory_allocated()
712 |
713 | result = result | {"$timings": getTimings(), "$mem_usage": mem_usage}
714 |
715 | return result
716 |
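717 | # Illustrative sketch (not exercised by the handler above): `cross_attention_kwargs`
718 | # may arrive either as a dict or as a JSON string, and `seed` is turned into a
719 | # torch.Generator before the pipeline call. A request body along these lines would
720 | # exercise that path; server.py reads `callInputs` from the request body, while the
721 | # `modelInputs` key is an assumption used here for illustration:
722 | #
723 | #   payload = {
724 | #       "callInputs": {"MODEL_ID": "runwayml/stable-diffusion-v1-5"},
725 | #       "modelInputs": {
726 | #           "prompt": "a photo of an astronaut riding a horse",
727 | #           "seed": 12345,
728 | #           # equivalent to passing the dict {"scale": 0.8} directly:
729 | #           "cross_attention_kwargs": '{"scale": 0.8}',
730 | #       },
731 | #   }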
--------------------------------------------------------------------------------
/api/convert_to_diffusers.py:
--------------------------------------------------------------------------------
1 | import os
2 | import requests
3 | import subprocess
4 | import torch
5 | import json
6 | from diffusers.pipelines.stable_diffusion.convert_from_ckpt import (
7 | download_from_original_stable_diffusion_ckpt,
8 | )
9 | from diffusers.pipelines.stable_diffusion import (
10 | StableDiffusionInpaintPipeline,
11 | )
12 | from utils import Storage
13 | from device import device_id
14 |
15 | MODEL_ID = os.environ.get("MODEL_ID", None)
16 | CHECKPOINT_DIR = "/root/.cache/checkpoints"
17 | CHECKPOINT_URL = os.environ.get("CHECKPOINT_URL", None)
18 | CHECKPOINT_CONFIG_URL = os.environ.get("CHECKPOINT_CONFIG_URL", None)
19 | CHECKPOINT_ARGS = os.environ.get("CHECKPOINT_ARGS", None)
20 | # _CONVERT_SPECIAL = os.environ.get("_CONVERT_SPECIAL", None)
21 |
22 |
23 | def main(
24 | model_id: str,
25 | checkpoint_url: str,
26 | checkpoint_config_url: str,
27 | checkpoint_args: dict = {},
28 | path=None,
29 | ):
30 | if not path:
31 | fname = checkpoint_url.split("/").pop()
32 | path = os.path.join(CHECKPOINT_DIR, fname)
33 |
34 | if checkpoint_config_url and checkpoint_config_url != "":
35 | storage = Storage(checkpoint_config_url)
36 |         configPath = path + "_config.yaml"  # path already includes CHECKPOINT_DIR
37 | print(f"Downloading {checkpoint_config_url} to {configPath}...")
38 | storage.download_file(configPath)
39 |
40 | # specialSrc = "https://raw.githubusercontent.com/hafriedlander/diffusers/stable_diffusion_2/scripts/convert_original_stable_diffusion_to_diffusers.py"
41 | # specialPath = CHECKPOINT_DIR + "/" + "convert_special.py"
42 | # if _CONVERT_SPECIAL:
43 | # storage = Storage(specialSrc)
44 | # print(f"Downloading {specialSrc} to {specialPath}")
45 | # storage.download_file(specialPath)
46 |
47 | # scriptPath = (
48 | # # specialPath
49 | # # if _CONVERT_SPECIAL
50 | # # else
51 | # "./diffusers/scripts/convert_original_stable_diffusion_to_diffusers.py"
52 | # )
53 |
54 | print("Converting " + path + " to diffusers model " + model_id + "...", flush=True)
55 |
56 | # These are now in main requirements.txt.
57 | # subprocess.run(
58 | # ["pip", "install", "omegaconf", "pytorch_lightning", "tensorboard"], check=True
59 | # )
60 | # Diffusers now uses requests instead, yay!
61 | # subprocess.run(["apt-get", "install", "-y", "wget"], check=True)
62 |
63 | # We can now specify this ourselves and don't need to modify the script.
64 | # if device_id == "cpu":
65 | # subprocess.run(
66 | # [
67 | # "sed",
68 | # "-i",
69 | # # Force loading into CPU
70 | # "s/torch.load(args.checkpoint_path)/torch.load(args.checkpoint_path, map_location=torch.device('cpu'))/",
71 | # scriptPath,
72 | # ]
73 | # )
74 | # # Nice to check but also there seems to be a race condition here which
75 | # # needs further investigation. Python docs are clear that subprocess.run()
76 | # # will "Wait for command to complete, then return a CompletedProcess instance."
77 | # # But it really seems as though without the grep in the middle, the script is
78 | # # run before sed completes, or maybe there's some FS level caching gotchas.
79 | # subprocess.run(
80 | # [
81 | # "grep",
82 | # "torch.load",
83 | # scriptPath,
84 | # ],
85 | # check=True,
86 | # )
87 |
88 | # args = [
89 | # "python3",
90 | # scriptPath,
91 | # "--extract_ema",
92 | # "--checkpoint_path",
93 | # fname,
94 | # "--dump_path",
95 | # model_id,
96 | # ]
97 |
98 | # if checkpoint_config_url:
99 | # args.append("--original_config_file")
100 | # args.append(configPath)
101 |
102 | # subprocess.run(
103 | # args,
104 | # check=True,
105 | # )
106 |
107 | # Oh yay! Diffusers abstracted this now, so much easier to use.
108 | # But less tested. Changed on 2023-02-18. TODO, remove commented
109 | # out code above once this has more usage.
110 |
111 | # diffusers defaults
112 | args = {
113 | "scheduler_type": "pndm",
114 | }
115 |
116 | # our defaults
117 | args.update(
118 | {
119 | "checkpoint_path_or_dict": path,
120 | "original_config_file": configPath if checkpoint_config_url else None,
121 | "device": device_id,
122 | "extract_ema": True,
123 | "from_safetensors": "safetensor" in path.lower(),
124 | }
125 | )
126 |
127 | if "inpaint" in path or "Inpaint" in path:
128 | args.update({"pipeline_class": StableDiffusionInpaintPipeline})
129 |
130 | # user overrides
131 | args.update(checkpoint_args)
132 |
133 | pipe = download_from_original_stable_diffusion_ckpt(**args)
134 | pipe.save_pretrained(model_id, safe_serialization=True)
135 |
136 |
137 | if __name__ == "__main__":
138 | # response = requests.get(
139 | # "https://github.com/huggingface/diffusers/raw/main/scripts/convert_original_stable_diffusion_to_diffusers.py"
140 | # )
141 | # open("convert_original_stable_diffusion_to_diffusers.py", "wb").write(
142 | # response.content
143 | # )
144 |
145 | if CHECKPOINT_URL and CHECKPOINT_URL != "":
146 | checkpoint_args = json.loads(CHECKPOINT_ARGS) if CHECKPOINT_ARGS else {}
147 | main(
148 | MODEL_ID,
149 | CHECKPOINT_URL,
150 | CHECKPOINT_CONFIG_URL,
151 | checkpoint_args=checkpoint_args,
152 | )
153 |
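154 | # Illustrative sketch (URL and values are examples, not defaults): the same conversion
155 | # can be driven in-process the way download.py does it, with checkpoint_args overriding
156 | # any keyword accepted by download_from_original_stable_diffusion_ckpt(), e.g. the
157 | # scheduler type:
158 | #
159 | #   from download_checkpoint import main as download_checkpoint
160 | #   path = download_checkpoint("https://example.com/my-model.safetensors")
161 | #   main(
162 | #       model_id="my-model",
163 | #       checkpoint_url="https://example.com/my-model.safetensors",
164 | #       checkpoint_config_url=None,
165 | #       checkpoint_args={"scheduler_type": "ddim"},
166 | #       path=path,
167 | #   )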
--------------------------------------------------------------------------------
/api/device.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | if torch.cuda.is_available():
4 | print("[device] CUDA (Nvidia) detected")
5 | device_id = "cuda"
6 | device_name = torch.cuda.get_device_name()
7 | elif torch.backends.mps.is_available():
8 | print("[device] MPS (MacOS Metal, Apple M1, etc) detected")
9 | device_id = "mps"
10 | device_name = "MPS"
11 | else:
12 | print("[device] CPU only - no GPU detected")
13 | device_id = "cpu"
14 | device_name = "CPU only"
15 |
16 | if device_id == "cpu":
17 |     if not torch.backends.cuda.is_built():
18 |         print(
19 |             "CUDA not available because the current PyTorch install was not "
20 |             "built with CUDA enabled."
21 |         )
22 |     if torch.backends.mps.is_built():
23 |         print(
24 |             "MPS not available because the current MacOS version is not 12.3+ "
25 |             "and/or you do not have an MPS-enabled device on this machine."
26 |         )
27 |     else:
28 |         print(
29 |             "MPS not available because the current PyTorch install was not built with MPS enabled."
30 |         )
31 |
32 | device = torch.device(device_id)
33 |
--------------------------------------------------------------------------------
/api/download.py:
--------------------------------------------------------------------------------
1 | # In this file, we define download_model
2 | # It runs during container build time to get model weights built into the container
3 |
4 | import os
5 | from loadModel import loadModel, MODEL_IDS
6 | from diffusers import AutoencoderKL, UNet2DConditionModel, DDPMScheduler
7 | from transformers import CLIPTextModel, CLIPTokenizer
8 | from utils import Storage
9 | import subprocess
10 | from pathlib import Path
11 | import shutil
12 | from convert_to_diffusers import main as convert_to_diffusers
13 | from download_checkpoint import main as download_checkpoint
14 | from status import status
15 | import asyncio
16 |
17 | USE_DREAMBOOTH = os.environ.get("USE_DREAMBOOTH")
18 | HF_AUTH_TOKEN = os.environ.get("HF_AUTH_TOKEN")
19 | RUNTIME_DOWNLOADS = os.environ.get("RUNTIME_DOWNLOADS")
20 |
21 | HOME = os.path.expanduser("~")
22 | MODELS_DIR = os.path.join(HOME, ".cache", "diffusers-api")
23 | Path(MODELS_DIR).mkdir(parents=True, exist_ok=True)
24 |
25 |
26 | # i.e. don't run during build
27 | async def send(type: str, status: str, payload: dict = {}, send_opts: dict = {}):
28 | if RUNTIME_DOWNLOADS:
29 | from send import send as _send
30 |
31 | await _send(type, status, payload, send_opts)
32 |
33 |
34 | def normalize_model_id(model_id: str, model_revision):
35 | normalized_model_id = "models--" + model_id.replace("/", "--")
36 | if model_revision:
37 | normalized_model_id += "--" + model_revision
38 | return normalized_model_id
39 |
40 |
41 | async def download_model(
42 | model_url=None,
43 | model_id=None,
44 | model_revision=None,
45 | checkpoint_url=None,
46 | checkpoint_config_url=None,
47 | hf_model_id=None,
48 | model_precision=None,
49 | send_opts={},
50 | pipeline_class=None,
51 | ):
52 | print(
53 | "download_model",
54 | {
55 | "model_url": model_url,
56 | "model_id": model_id,
57 | "model_revision": model_revision,
58 | "hf_model_id": hf_model_id,
59 | "checkpoint_url": checkpoint_url,
60 | "checkpoint_config_url": checkpoint_config_url,
61 | },
62 | )
63 | hf_model_id = hf_model_id or model_id
64 | normalized_model_id = model_id
65 |
66 | # if model_url != "": # throws an error, useful to debug stdout/stderr order
67 | if model_url:
68 | normalized_model_id = normalize_model_id(model_id, model_revision)
69 | print({"normalized_model_id": normalized_model_id})
70 | filename = model_url.split("/").pop()
71 | if not filename:
72 | filename = normalized_model_id + ".tar.zst"
73 | model_file = os.path.join(MODELS_DIR, filename)
74 | storage = Storage(
75 | model_url, default_path=normalized_model_id + ".tar.zst", status=status
76 | )
77 | exists = storage.file_exists()
78 | if exists:
79 | model_dir = os.path.join(MODELS_DIR, normalized_model_id)
80 | print("model_dir", model_dir)
81 | await asyncio.to_thread(storage.download_and_extract, model_file, model_dir)
82 | else:
83 | if checkpoint_url:
84 | path = download_checkpoint(checkpoint_url)
85 | convert_to_diffusers(
86 | model_id=model_id,
87 | checkpoint_url=checkpoint_url,
88 | checkpoint_config_url=checkpoint_config_url,
89 | path=path,
90 | )
91 | else:
92 |             print("Does not exist, let's try to find it on HuggingFace")
93 | print(
94 | {
95 | "model_precision": model_precision,
96 | "model_revision": model_revision,
97 | }
98 | )
99 | # This would be quicker to just model.to(device) afterwards, but
100 | # this conveniently logs all the timings (and doesn't happen often)
101 | print("download")
102 | await send("download", "start", {}, send_opts)
103 | model = loadModel(
104 | hf_model_id,
105 | False,
106 | precision=model_precision,
107 | revision=model_revision,
108 | pipeline_class=pipeline_class,
109 | ) # download
110 | await send("download", "done", {}, send_opts)
111 |
112 | print("load")
113 | model = loadModel(
114 | hf_model_id,
115 | True,
116 | precision=model_precision,
117 | revision=model_revision,
118 | pipeline_class=pipeline_class,
119 | ) # load
120 | # dir = "models--" + model_id.replace("/", "--") + "--dda"
121 | dir = os.path.join(MODELS_DIR, normalized_model_id)
122 | model.save_pretrained(dir, safe_serialization=True)
123 |
124 | # This is all duped from train_dreambooth, need to refactor TODO XXX
125 | await send("compress", "start", {}, send_opts)
126 | subprocess.run(
127 | f"tar cvf - -C {dir} . | zstd -o {model_file}",
128 | shell=True,
129 | check=True, # TODO, rather don't raise and return an error in JSON
130 | )
131 |
132 | await send("compress", "done", {}, send_opts)
133 | subprocess.run(["ls", "-l", model_file])
134 |
135 | await send("upload", "start", {}, send_opts)
136 | upload_result = storage.upload_file(model_file, filename)
137 | await send("upload", "done", {}, send_opts)
138 | print(upload_result)
139 | os.remove(model_file)
140 |
141 | # leave model dir for future loads... make configurable?
142 | # shutil.rmtree(dir)
143 |
144 | # TODO, swap directories, inside HF's cache structure.
145 |
146 | else:
147 | if checkpoint_url:
148 | path = download_checkpoint(checkpoint_url)
149 | convert_to_diffusers(
150 | model_id=model_id,
151 | checkpoint_url=checkpoint_url,
152 | checkpoint_config_url=checkpoint_config_url,
153 | path=path,
154 | )
155 | else:
156 | # do a dry run of loading the huggingface model, which will download weights at build time
157 | loadModel(
158 | model_id=hf_model_id,
159 | load=False,
160 | precision=model_precision,
161 | revision=model_revision,
162 | pipeline_class=pipeline_class,
163 | )
164 |
165 | # if USE_DREAMBOOTH:
166 | # Actually we can re-use these from the above loaded model
167 | # Will remove this soon if no more surprises
168 | # for subfolder, model in [
169 | # ["tokenizer", CLIPTokenizer],
170 | # ["text_encoder", CLIPTextModel],
171 | # ["vae", AutoencoderKL],
172 | # ["unet", UNet2DConditionModel],
173 | # ["scheduler", DDPMScheduler]
174 | # ]:
175 | # print(subfolder, model)
176 | # model.from_pretrained(
177 | # MODEL_ID,
178 | # subfolder=subfolder,
179 | # revision=revision,
180 | # use_auth_token=HF_AUTH_TOKEN,
181 | # )
182 |
183 |
184 | if __name__ == "__main__":
185 | asyncio.run(
186 | download_model(
187 | model_url=os.environ.get("MODEL_URL"),
188 | model_id=os.environ.get("MODEL_ID"),
189 | hf_model_id=os.environ.get("HF_MODEL_ID"),
190 | model_revision=os.environ.get("MODEL_REVISION"),
191 | model_precision=os.environ.get("MODEL_PRECISION"),
192 | checkpoint_url=os.environ.get("CHECKPOINT_URL"),
193 | checkpoint_config_url=os.environ.get("CHECKPOINT_CONFIG_URL"),
194 | )
195 | )
196 |
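197 | # Illustrative sketch of the runtime cache flow above. normalize_model_id() maps a
198 | # HuggingFace model id (plus optional revision) to the directory / archive name used
199 | # under MODELS_DIR and in cloud storage:
200 | #
201 | #   normalize_model_id("runwayml/stable-diffusion-v1-5", "fp16")
202 | #   # -> "models--runwayml--stable-diffusion-v1-5--fp16"
203 | #
204 | # With a MODEL_URL such as "s3://bucket/" (bucket name is an example), the first call
205 | # builds the model from HuggingFace, saves it under MODELS_DIR, compresses it to
206 | # "<normalized id>.tar.zst" and uploads it; later calls find the archive in the bucket
207 | # and simply download_and_extract() it.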
--------------------------------------------------------------------------------
/api/download_checkpoint.py:
--------------------------------------------------------------------------------
1 | import os
2 | from utils import Storage
3 |
4 | CHECKPOINT_URL = os.environ.get("CHECKPOINT_URL", None)
5 | CHECKPOINT_DIR = "/root/.cache/checkpoints"
6 |
7 |
8 | def main(checkpoint_url: str):
9 | if not os.path.isdir(CHECKPOINT_DIR):
10 | os.makedirs(CHECKPOINT_DIR)
11 |
12 | storage = Storage(checkpoint_url)
13 | storage_query_fname = storage.query.get("fname")
14 | if storage_query_fname:
15 | fname = storage_query_fname[0]
16 | else:
17 | fname = checkpoint_url.split("/").pop()
18 | path = os.path.join(CHECKPOINT_DIR, fname)
19 |
20 | if not os.path.isfile(path):
21 | storage.download_file(path)
22 |
23 | return path
24 |
25 |
26 | if __name__ == "__main__":
27 | if CHECKPOINT_URL:
28 | main(CHECKPOINT_URL)
29 |
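30 | # Illustrative sketch (the URL is an example): the storage layer parses the URL
31 | # fragment as a query string, so a `#fname=` fragment overrides the filename that
32 | # would otherwise be derived from the URL path:
33 | #
34 | #   main("https://civitai.com/api/download/models/12345#fname=my-model.safetensors")
35 | #   # -> downloads to /root/.cache/checkpoints/my-model.safetensors and returns that path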
--------------------------------------------------------------------------------
/api/extras/__init__.py:
--------------------------------------------------------------------------------
1 | from .upsample import upsample
2 |
--------------------------------------------------------------------------------
/api/extras/upsample/__init__.py:
--------------------------------------------------------------------------------
1 | from .upsample import upsample
2 |
--------------------------------------------------------------------------------
/api/extras/upsample/models.py:
--------------------------------------------------------------------------------
1 | upsamplers = {
2 | "RealESRGAN_x4plus": {
3 | "name": "General - RealESRGANplus",
4 | "weights": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth",
5 | "filename": "RealESRGAN_x4plus.pth",
6 | "net": "RRDBNet",
7 | "initArgs": {
8 | "num_in_ch": 3,
9 | "num_out_ch": 3,
10 | "num_feat": 64,
11 | "num_block": 23,
12 | "num_grow_ch": 32,
13 | "scale": 4,
14 | },
15 | "netscale": 4,
16 | },
17 | # "RealESRNet_x4plus": {
18 | # "name": "",
19 | # "weights": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.1/RealESRNet_x4plus.pth",
20 | # "path": "weights/RealESRNet_x4plus.pth",
21 | # },
22 | "RealESRGAN_x4plus_anime_6B": {
23 | "name": "Anime - anime6B",
24 | "weights": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth",
25 | "filename": "RealESRGAN_x4plus_anime_6B.pth",
26 | "net": "RRDBNet",
27 | "initArgs": {
28 | "num_in_ch": 3,
29 | "num_out_ch": 3,
30 | "num_feat": 64,
31 | "num_block": 6,
32 | "num_grow_ch": 32,
33 | "scale": 4,
34 | },
35 | "netscale": 4,
36 | },
37 | # "RealESRGAN_x2plus": {
38 | # "name": "",
39 | # "weights": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth",
40 | # "path": "weights/RealESRGAN_x2plus.pth",
41 | # },
42 | # "realesr-animevideov3": {
43 | # "name": "AnimeVideo - v3",
44 | # "weights": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-animevideov3.pth",
45 | # "path": "weights/realesr-animevideov3.pth",
46 | # },
47 | "realesr-general-x4v3": {
48 | "name": "General - v3",
49 | # [, "weights": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-wdn-x4v3.pth" ],
50 | "weights": "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth",
51 | "filename": "realesr-general-x4v3.pth",
52 | "net": "SRVGGNetCompact",
53 | "initArgs": {
54 | "num_in_ch": 3,
55 | "num_out_ch": 3,
56 | "num_feat": 64,
57 | "num_conv": 32,
58 | "upscale": 4,
59 | "act_type": "prelu",
60 | },
61 | "netscale": 4,
62 | },
63 | }
64 |
65 | face_enhancers = {
66 | "GFPGAN": {
67 | "name": "GFPGAN",
68 | "weights": "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth",
69 | "filename": "GFPGANv1.4.pth",
70 | },
71 | }
72 |
73 | models_by_type = {
74 | "upsamplers": upsamplers,
75 | "face_enhancers": face_enhancers,
76 | }
77 |
--------------------------------------------------------------------------------
/api/extras/upsample/upsample.py:
--------------------------------------------------------------------------------
1 | import os
2 | import asyncio
3 | from pathlib import Path
4 |
5 | import base64
6 | from io import BytesIO
7 | import PIL
8 | import json
9 | import cv2
10 | import numpy as np
11 | import torch
12 | import torchvision
13 |
14 | from basicsr.archs.rrdbnet_arch import RRDBNet
15 | from realesrgan import RealESRGANer
16 | from realesrgan.archs.srvgg_arch import SRVGGNetCompact
17 | from gfpgan import GFPGANer
18 |
19 | from .models import models_by_type, upsamplers, face_enhancers
20 | from status import status
21 | from utils import Storage
22 | from send import send
23 |
24 | print(
25 | {
26 | "torch.__version__": torch.__version__,
27 | "torchvision.__version__": torchvision.__version__,
28 | }
29 | )
30 |
31 | HOME = os.path.expanduser("~")
32 | CACHE_DIR = os.path.join(HOME, ".cache", "diffusers-api", "upsample")
33 |
34 |
35 | def cache_path(filename):
36 | return os.path.join(CACHE_DIR, filename)
37 |
38 |
39 | async def assert_model_exists(src, filename, send_opts, opts={}):
40 | dest = cache_path(filename) if not opts.get("absolutePath", None) else filename
41 | if not os.path.exists(dest):
42 | await send("download", "start", {}, send_opts)
43 | storage = Storage(src, status=status)
44 | # await storage.download_file(dest)
45 | await asyncio.to_thread(storage.download_file, dest)
46 | await send("download", "done", {}, send_opts)
47 |
48 |
49 | async def download_models(send_opts={}):
50 | Path(CACHE_DIR).mkdir(parents=True, exist_ok=True)
51 |
52 | for type in models_by_type:
53 | models = models_by_type[type]
54 | for model_key in models:
55 | model = models[model_key]
56 | await assert_model_exists(model["weights"], model["filename"], send_opts)
57 |
58 | Path("gfpgan/weights").mkdir(parents=True, exist_ok=True)
59 |
60 | await assert_model_exists(
61 | "https://github.com/xinntao/facexlib/releases/download/v0.1.0/detection_Resnet50_Final.pth",
62 | "detection_Resnet50_Final.pth",
63 | send_opts,
64 | )
65 | await assert_model_exists(
66 | "https://github.com/xinntao/facexlib/releases/download/v0.2.2/parsing_parsenet.pth",
67 | "parsing_parsenet.pth",
68 | send_opts,
69 | )
70 |
71 | # hardcoded paths in xinntao/facexlib
72 | filenames = ["detection_Resnet50_Final.pth", "parsing_parsenet.pth"]
73 | for file in filenames:
74 | if not os.path.exists(f"gfpgan/weights/{file}"):
75 | os.symlink(cache_path(file), f"gfpgan/weights/{file}")
76 |
77 |
78 | nets = {
79 | "RRDBNet": RRDBNet,
80 | "SRVGGNetCompact": SRVGGNetCompact,
81 | }
82 |
83 | models = {}
84 |
85 |
86 | async def upsample(model_inputs, call_inputs, send_opts={}, startRequestId=None):
87 | global models
88 |
89 | # TODO, only download relevant models for this request
90 | await download_models()
91 |
92 | model_id = call_inputs.get("MODEL_ID", None)
93 |
94 | if not model_id:
95 | return {
96 | "$error": {
97 | "code": "MISSING_MODEL_ID",
98 | "message": "call_inputs.MODEL_ID is required, but not given.",
99 | }
100 | }
101 |
102 | model = models.get(model_id, None)
103 | if not model:
104 | model = models_by_type["upsamplers"].get(model_id, None)
105 | if not model:
106 | return {
107 | "$error": {
108 | "code": "MISSING_MODEL",
109 | "message": f'Model "{model_id}" not available on this container.',
110 | "requested": model_id,
111 | "available": '"' + '", "'.join(models.keys()) + '"',
112 | }
113 | }
114 | else:
115 | modelModel = nets[model["net"]](**model["initArgs"])
116 | await send(
117 | "loadModel",
118 | "start",
119 | {"startRequestId": startRequestId},
120 | send_opts,
121 | )
122 | upsampler = RealESRGANer(
123 | scale=model["netscale"],
124 | model_path=cache_path(model["filename"]),
125 | dni_weight=None,
126 | model=modelModel,
127 | tile=0,
128 | tile_pad=10,
129 | pre_pad=0,
130 | half=True,
131 | )
132 | await send(
133 | "loadModel",
134 | "done",
135 | {"startRequestId": startRequestId},
136 | send_opts,
137 | )
138 | model.update({"model": modelModel, "upsampler": upsampler})
139 | models.update({model_id: model})
140 |
141 | upsampler = model["upsampler"]
142 |
143 | input_image = model_inputs.get("input_image", None)
144 | if not input_image:
145 | return {
146 | "$error": {
147 | "code": "NO_INPUT_IMAGE",
148 | "message": "Missing required parameter `input_image`",
149 | }
150 | }
151 |
152 | if model_id == "realesr-general-x4v3":
153 | denoise_strength = model_inputs.get("denoise_strength", 1)
154 | if denoise_strength != 1:
155 | # wdn_model_path = model_path.replace('realesr-general-x4v3', 'realesr-general-wdn-x4v3')
156 | # model_path = [model_path, wdn_model_path]
157 | # upsampler = models["realesr-general-x4v3-denoise"]
158 | # upsampler.dni_weight = dni_weight
159 | dni_weight = [denoise_strength, 1 - denoise_strength]
160 | return "TODO: denoise_strength"
161 |
162 | face_enhance = model_inputs.get("face_enhance", False)
163 | if face_enhance:
164 | face_enhancer = models.get("GFPGAN", None)
165 | if not face_enhancer:
166 | await send(
167 | "loadModel",
168 | "start",
169 | {"startRequestId": startRequestId},
170 | send_opts,
171 | )
172 | print("1) " + cache_path(face_enhancers["GFPGAN"]["filename"]))
173 | face_enhancer = GFPGANer(
174 | model_path=cache_path(face_enhancers["GFPGAN"]["filename"]),
175 | upscale=4, # args.outscale,
176 | arch="clean",
177 | channel_multiplier=2,
178 | bg_upsampler=upsampler,
179 | )
180 | await send(
181 | "loadModel",
182 | "done",
183 | {"startRequestId": startRequestId},
184 | send_opts,
185 | )
186 | models.update({"GFPGAN": face_enhancer})
187 |
188 | if face_enhance: # Use GFPGAN for face enhancement
189 | face_enhancer.bg_upsampler = upsampler
190 |
191 | # image = decodeBase64Image(model_inputs.get("input_image"))
192 | image_str = base64.b64decode(model_inputs["input_image"])
193 | image_np = np.frombuffer(image_str, dtype=np.uint8)
194 | # bytes = BytesIO(base64.decodebytes(bytes(model_inputs["input_image"], "utf-8")))
195 | img = cv2.imdecode(image_np, cv2.IMREAD_UNCHANGED)
196 |
197 | await send("inference", "start", {"startRequestId": startRequestId}, send_opts)
198 |
199 | # Run the model
200 | # with autocast("cuda"):
201 | # image = pipeline(**model_inputs).images[0]
202 | if face_enhance:
203 | _, _, output = face_enhancer.enhance(
204 | img, has_aligned=False, only_center_face=False, paste_back=True
205 | )
206 | else:
207 | output, _rgb = upsampler.enhance(img, outscale=4) # TODO outscale param
208 |
209 | image_base64 = base64.b64encode(cv2.imencode(".jpg", output)[1]).decode()
210 |
211 | await send("inference", "done", {"startRequestId": startRequestId}, send_opts)
212 |
213 | # Return the results as a dictionary
214 | return {"$meta": {}, "image_base64": image_base64}
215 |
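216 | # Illustrative sketch of the inputs this handler expects (values are examples):
217 | #
218 | #   result = await upsample(
219 | #       model_inputs={
220 | #           "input_image": base64_image_string,   # raw image bytes, base64-encoded
221 | #           "face_enhance": True,                 # optional, routes through GFPGAN
222 | #       },
223 | #       call_inputs={"MODEL_ID": "RealESRGAN_x4plus"},
224 | #   )
225 | #   # -> {"$meta": {}, "image_base64": "<base64-encoded JPEG>"}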
--------------------------------------------------------------------------------
/api/getPipeline.py:
--------------------------------------------------------------------------------
1 | import time
2 | import os, fnmatch
3 | from diffusers import (
4 | DiffusionPipeline,
5 | pipelines as diffusers_pipelines,
6 | )
7 | from precision import torch_dtype_from_precision
8 |
9 | HOME = os.path.expanduser("~")
10 | MODELS_DIR = os.path.join(HOME, ".cache", "diffusers-api")
11 | _pipelines = {}
12 | _availableCommunityPipelines = None
13 |
14 |
15 | def listAvailablePipelines():
16 | return (
17 | list(
18 | filter(
19 | lambda key: key.endswith("Pipeline"),
20 | list(diffusers_pipelines.__dict__.keys()),
21 | )
22 | )
23 | + availableCommunityPipelines()
24 | )
25 |
26 |
27 | def availableCommunityPipelines():
28 | global _availableCommunityPipelines
29 | if not _availableCommunityPipelines:
30 | _availableCommunityPipelines = list(
31 | map(
32 | lambda s: s[0:-3],
33 | fnmatch.filter(os.listdir("diffusers/examples/community"), "*.py"),
34 | )
35 | )
36 |
37 | return _availableCommunityPipelines
38 |
39 |
40 | def clearPipelines():
41 | """
42 | Clears the pipeline cache. Important to call this when changing the
43 | loaded model, as pipelines include references to the model and would
44 | therefore prevent memory being reclaimed after unloading the previous
45 | model.
46 | """
47 | global _pipelines
48 | _pipelines = {}
49 |
50 |
51 | def getPipelineClass(pipeline_name: str):
52 | if hasattr(diffusers_pipelines, pipeline_name):
53 | return getattr(diffusers_pipelines, pipeline_name)
54 | elif pipeline_name in availableCommunityPipelines():
55 | return DiffusionPipeline
56 |
57 |
58 | def getPipelineForModel(
59 | pipeline_name: str, model, model_id, model_revision, model_precision
60 | ):
61 | """
62 | Inits a new pipeline, re-using components from a previously loaded
63 | model. The pipeline is cached and future calls with the same
64 | arguments will return the previously initted instance. Be sure
65 | to call `clearPipelines()` if loading a new model, to allow the
66 | previous model to be garbage collected.
67 | """
68 | pipeline = _pipelines.get(pipeline_name)
69 | if pipeline:
70 | return pipeline
71 |
72 | start = time.time()
73 |
74 | if hasattr(diffusers_pipelines, pipeline_name):
75 | pipeline_class = getattr(diffusers_pipelines, pipeline_name)
76 | if hasattr(pipeline_class, "from_pipe"):
77 | pipeline = pipeline_class.from_pipe(model)
78 | elif hasattr(model, "components"):
79 | pipeline = pipeline_class(**model.components)
80 | else:
81 | pipeline = getattr(diffusers_pipelines, pipeline_name)(
82 | vae=model.vae,
83 | text_encoder=model.text_encoder,
84 | tokenizer=model.tokenizer,
85 | unet=model.unet,
86 | scheduler=model.scheduler,
87 | safety_checker=model.safety_checker,
88 | feature_extractor=model.feature_extractor,
89 | )
90 |
91 | elif pipeline_name in availableCommunityPipelines():
92 | model_dir = os.path.join(MODELS_DIR, model_id)
93 | if not os.path.isdir(model_dir):
94 | model_dir = None
95 |
96 | pipeline = DiffusionPipeline.from_pretrained(
97 | model_dir or model_id,
98 | revision=model_revision,
99 | torch_dtype=torch_dtype_from_precision(model_precision),
100 | custom_pipeline="./diffusers/examples/community/" + pipeline_name + ".py",
101 | local_files_only=True,
102 | **model.components,
103 | )
104 |
105 | if pipeline:
106 | _pipelines.update({pipeline_name: pipeline})
107 | diff = round((time.time() - start) * 1000)
108 | print(f"Initialized {pipeline_name} for {model_id} in {diff}ms")
109 | return pipeline
110 |
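111 | # Illustrative sketch (argument values are examples; the model itself is loaded in
112 | # loadModel.py): pipelines derived from the same loaded model share its components,
113 | # so switching task re-wires the pipeline rather than reloading weights from disk.
114 | #
115 | #   img2img = getPipelineForModel(
116 | #       "StableDiffusionImg2ImgPipeline",
117 | #       model,
118 | #       "runwayml/stable-diffusion-v1-5",
119 | #       model_revision=None,
120 | #       model_precision="fp16",
121 | #   )
122 | #   clearPipelines()  # before loading a different model, so `model` can be freed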
--------------------------------------------------------------------------------
/api/getScheduler.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import os
3 | import time
4 | from diffusers import schedulers as _schedulers
5 |
6 | HF_AUTH_TOKEN = os.getenv("HF_AUTH_TOKEN")
7 | HOME = os.path.expanduser("~")
8 | MODELS_DIR = os.path.join(HOME, ".cache", "diffusers-api")
9 |
10 | SCHEDULERS = [
11 | "DPMSolverMultistepScheduler",
12 | "LMSDiscreteScheduler",
13 | "DDIMScheduler",
14 | "PNDMScheduler",
15 | "EulerAncestralDiscreteScheduler",
16 | "EulerDiscreteScheduler",
17 | ]
18 |
19 | DEFAULT_SCHEDULER = os.getenv("DEFAULT_SCHEDULER", SCHEDULERS[0])
20 |
21 |
22 | """
23 | # This was a nice idea but until we have default init vars for all schedulers
24 | # via from_pretrained(), it's a no go. In any case, loading a scheduler takes time
25 | # so better to init as needed and cache.
26 | isScheduler = re.compile(r".+Scheduler$")
27 | for key, val in _schedulers.__dict__.items():
28 | if isScheduler.match(key):
29 | schedulers.update(
30 | {
31 | key: val.from_pretrained(
32 | MODEL_ID, subfolder="scheduler", use_auth_token=HF_AUTH_TOKEN
33 | )
34 | }
35 | )
36 | """
37 |
38 |
39 | def initScheduler(MODEL_ID: str, scheduler_id: str, download=False):
40 | print(f"Initializing {scheduler_id} for {MODEL_ID}...")
41 | start = time.time()
42 |     scheduler = getattr(_schedulers, scheduler_id, None)
43 | if scheduler == None:
44 | return None
45 |
46 | model_dir = os.path.join(MODELS_DIR, MODEL_ID)
47 | if not os.path.isdir(model_dir):
48 | model_dir = None
49 |
50 | inittedScheduler = scheduler.from_pretrained(
51 | model_dir or MODEL_ID,
52 | subfolder="scheduler",
53 | use_auth_token=HF_AUTH_TOKEN,
54 | local_files_only=not download,
55 | )
56 | diff = round((time.time() - start) * 1000)
57 | print(f"Initialized {scheduler_id} for {MODEL_ID} in {diff}ms")
58 |
59 | return inittedScheduler
60 |
61 |
62 | schedulers = {}
63 |
64 |
65 | def getScheduler(MODEL_ID: str, scheduler_id: str, download=False):
66 | schedulersByModel = schedulers.get(MODEL_ID, None)
67 | if schedulersByModel == None:
68 | schedulersByModel = {}
69 | schedulers.update({MODEL_ID: schedulersByModel})
70 |
71 | # Check for use of old names
72 | deprecated_map = {
73 | "LMS": "LMSDiscreteScheduler",
74 | "DDIM": "DDIMScheduler",
75 | "PNDM": "PNDMScheduler",
76 | }
77 | scheduler_renamed = deprecated_map.get(scheduler_id, None)
78 | if scheduler_renamed != None:
79 | print(
80 | f'[Deprecation Warning]: Scheduler "{scheduler_id}" is now '
81 |             f'called "{scheduler_renamed}". Please rename as this will '
82 | f"stop working in a future release."
83 | )
84 | scheduler_id = scheduler_renamed
85 |
86 | scheduler = schedulersByModel.get(scheduler_id, None)
87 | if scheduler == None:
88 | scheduler = initScheduler(MODEL_ID, scheduler_id, download)
89 | schedulersByModel.update({scheduler_id: scheduler})
90 |
91 | return scheduler
92 |
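93 | # Illustrative sketch (model id is an example): schedulers are initialized once per
94 | # (model, scheduler) pair and then served from the cache; legacy short names are
95 | # remapped with a deprecation warning.
96 | #
97 | #   a = getScheduler("runwayml/stable-diffusion-v1-5", "EulerDiscreteScheduler")
98 | #   b = getScheduler("runwayml/stable-diffusion-v1-5", "EulerDiscreteScheduler")
99 | #   assert a is b                                           # cached instance
100 | #   getScheduler("runwayml/stable-diffusion-v1-5", "LMS")   # -> LMSDiscreteScheduler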
--------------------------------------------------------------------------------
/api/lib/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kiri-art/docker-diffusers-api/5521b2e6d63ef7afa9c3f7f3ef6cf520c1b54d7b/api/lib/__init__.py
--------------------------------------------------------------------------------
/api/lib/prompts.py:
--------------------------------------------------------------------------------
1 | from compel import Compel, DiffusersTextualInversionManager, ReturnedEmbeddingsType
2 |
3 |
4 | def prepare_prompts(pipeline, model_inputs, is_sdxl):
5 | textual_inversion_manager = DiffusersTextualInversionManager(pipeline)
6 | if is_sdxl:
7 | compel = Compel(
8 | tokenizer=[pipeline.tokenizer, pipeline.tokenizer_2],
9 | text_encoder=[pipeline.text_encoder, pipeline.text_encoder_2],
10 | # diffusers has no ti in sdxl yet
11 | # https://github.com/huggingface/diffusers/issues/4376#issuecomment-1659016141
12 | # textual_inversion_manager=textual_inversion_manager,
13 | truncate_long_prompts=False,
14 | returned_embeddings_type=ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED,
15 | requires_pooled=[False, True],
16 | )
17 | conditioning, pooled = compel(model_inputs.get("prompt"))
18 | negative_conditioning, negative_pooled = compel(
19 | model_inputs.get("negative_prompt")
20 | )
21 | [
22 | conditioning,
23 | negative_conditioning,
24 | ] = compel.pad_conditioning_tensors_to_same_length(
25 | [conditioning, negative_conditioning]
26 | )
27 | model_inputs.update(
28 | {
29 | "prompt": None,
30 | "negative_prompt": None,
31 | "prompt_embeds": conditioning,
32 | "negative_prompt_embeds": negative_conditioning,
33 | "pooled_prompt_embeds": pooled,
34 | "negative_pooled_prompt_embeds": negative_pooled,
35 | }
36 | )
37 |
38 | else:
39 | compel = Compel(
40 | tokenizer=pipeline.tokenizer,
41 | text_encoder=pipeline.text_encoder,
42 | textual_inversion_manager=textual_inversion_manager,
43 | truncate_long_prompts=False,
44 | )
45 | conditioning = compel(model_inputs.get("prompt"))
46 | negative_conditioning = compel(model_inputs.get("negative_prompt"))
47 | [
48 | conditioning,
49 | negative_conditioning,
50 | ] = compel.pad_conditioning_tensors_to_same_length(
51 | [conditioning, negative_conditioning]
52 | )
53 | model_inputs.update(
54 | {
55 | "prompt": None,
56 | "negative_prompt": None,
57 | "prompt_embeds": conditioning,
58 | "negative_prompt_embeds": negative_conditioning,
59 | }
60 | )
61 |
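62 | # Illustrative sketch (prompt text is an example): prepare_prompts() swaps the string
63 | # prompts in model_inputs for Compel-computed embeddings, padded to equal length (with
64 | # pooled embeddings added for SDXL), so the pipeline is then called with prompt_embeds
65 | # rather than prompt:
66 | #
67 | #   model_inputs = {"prompt": "a cozy cabin in the woods", "negative_prompt": "blurry"}
68 | #   prepare_prompts(pipeline, model_inputs, is_sdxl=False)
69 | #   assert model_inputs["prompt"] is None and "prompt_embeds" in model_inputs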
--------------------------------------------------------------------------------
/api/lib/textual_inversions.py:
--------------------------------------------------------------------------------
1 | import json
2 | import re
3 | import os
4 | import asyncio
5 | from utils import Storage
6 | from .vars import MODELS_DIR
7 |
8 | last_textual_inversions = None
9 | last_textual_inversion_model = None
10 | loaded_textual_inversion_tokens = []
11 |
12 | tokenRe = re.compile(
13 |     r"[#&]{1}fname=(?P<fname>[^\.]+)\.(?:pt|safetensors)(&token=(?P<token>[^&]+))?$"
14 | )
15 |
16 |
17 | def strMap(str: str):
18 | match = re.search(tokenRe, str)
19 | # print(match)
20 | if match:
21 | return match.group("token") or match.group("fname")
22 |
23 |
24 | def extract_tokens_from_list(textual_inversions: list):
25 | return list(map(strMap, textual_inversions))
26 |
27 |
28 | async def handle_textual_inversions(textual_inversions: list, model, status):
29 | global last_textual_inversions
30 | global last_textual_inversion_model
31 | global loaded_textual_inversion_tokens
32 |
33 | textual_inversions_str = json.dumps(textual_inversions)
34 | if (
35 | textual_inversions_str != last_textual_inversions
36 | or model is not last_textual_inversion_model
37 | ):
38 | if model is not last_textual_inversion_model:
39 | loaded_textual_inversion_tokens = []
40 | last_textual_inversion_model = model
41 | # print({"textual_inversions": textual_inversions})
42 | # tokens_to_load = extract_tokens_from_list(textual_inversions)
43 | # print({"tokens_loaded": loaded_textual_inversion_tokens})
44 | # print({"tokens_to_load": tokens_to_load})
45 | #
46 | # for token in loaded_textual_inversion_tokens:
47 | # if token not in tokens_to_load:
48 | # print("[TextualInversion] Removing uneeded token: " + token)
49 | # del pipeline.tokenizer.get_vocab()[token]
50 | # # del pipeline.text_encoder.get_input_embeddings().weight.data[token]
51 | # pipeline.text_encoder.resize_token_embeddings(len(pipeline.tokenizer))
52 | #
53 | # loaded_textual_inversion_tokens = tokens_to_load
54 |
55 | last_textual_inversions = textual_inversions_str
56 | for textual_inversion in textual_inversions:
57 | storage = Storage(textual_inversion, no_raise=True, status=status)
58 | if storage:
59 | storage_query_fname = storage.query.get("fname")
60 | if storage_query_fname:
61 | fname = storage_query_fname[0]
62 | else:
63 | fname = textual_inversion.split("/").pop()
64 | path = os.path.join(MODELS_DIR, "textual_inversion--" + fname)
65 | if not os.path.exists(path):
66 | await asyncio.to_thread(storage.download_file, path)
67 | print("Load textual inversion " + path)
68 | token = storage.query.get("token", None)
69 | if token not in loaded_textual_inversion_tokens:
70 | model.load_textual_inversion(
71 | path, token=token, local_files_only=True
72 | )
73 | loaded_textual_inversion_tokens.append(token)
74 | else:
75 | print("Load textual inversion " + textual_inversion)
76 | model.load_textual_inversion(textual_inversion)
77 | else:
78 | print("No changes to textual inversions since last call")
79 |
--------------------------------------------------------------------------------
/api/lib/textual_inversions_test.py:
--------------------------------------------------------------------------------
1 | import unittest
2 | from .textual_inversions import extract_tokens_from_list
3 |
4 |
5 | class TextualInversionsTest(unittest.TestCase):
6 | def test_extract_tokens_query_fname(self):
7 | tis = ["https://civitai.com/api/download/models/106132#fname=4nj0lie.pt"]
8 | tokens = extract_tokens_from_list(tis)
9 | self.assertEqual(tokens[0], "4nj0lie")
10 |
11 | def test_extract_tokens_query_token(self):
12 | tis = [
13 | "https://civitai.com/api/download/models/106132#fname=4nj0lie.pt&token=4nj0lie"
14 | ]
15 | tokens = extract_tokens_from_list(tis)
16 | self.assertEqual(tokens[0], "4nj0lie")
17 |
--------------------------------------------------------------------------------
/api/lib/vars.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | RUNTIME_DOWNLOADS = os.getenv("RUNTIME_DOWNLOADS") == "1"
4 | USE_DREAMBOOTH = os.getenv("USE_DREAMBOOTH") == "1"
5 | MODEL_ID = os.environ.get("MODEL_ID")
6 | PIPELINE = os.environ.get("PIPELINE")
7 | HF_AUTH_TOKEN = os.getenv("HF_AUTH_TOKEN")
8 | HOME = os.path.expanduser("~")
9 | MODELS_DIR = os.path.join(HOME, ".cache", "diffusers-api")
10 |
--------------------------------------------------------------------------------
/api/loadModel.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import os
3 | from diffusers import pipelines as _pipelines, AutoPipelineForText2Image
4 | from getScheduler import getScheduler, DEFAULT_SCHEDULER
5 | from precision import torch_dtype_from_precision
6 | from device import device
7 | import time
8 |
9 | HF_AUTH_TOKEN = os.getenv("HF_AUTH_TOKEN")
10 | PIPELINE = os.getenv("PIPELINE")
11 | USE_DREAMBOOTH = os.getenv("USE_DREAMBOOTH") == "1"
12 | HOME = os.path.expanduser("~")
13 | MODELS_DIR = os.path.join(HOME, ".cache", "diffusers-api")
14 |
15 |
16 | MODEL_IDS = [
17 | "CompVis/stable-diffusion-v1-4",
18 | "hakurei/waifu-diffusion",
19 | # "hakurei/waifu-diffusion-v1-3", - not as diffusers yet
20 | "runwayml/stable-diffusion-inpainting",
21 | "runwayml/stable-diffusion-v1-5",
22 |     "stabilityai/stable-diffusion-2",
23 |     "stabilityai/stable-diffusion-2-base",
24 | "stabilityai/stable-diffusion-2-inpainting",
25 | ]
26 |
27 |
28 | def loadModel(
29 | model_id: str,
30 | load=True,
31 | precision=None,
32 | revision=None,
33 | send_opts={},
34 | pipeline_class=None,
35 | ):
36 | torch_dtype = torch_dtype_from_precision(precision)
37 | if revision == "":
38 | revision = None
39 |
40 | print(
41 | "loadModel",
42 | {
43 | "model_id": model_id,
44 | "load": load,
45 | "precision": precision,
46 | "revision": revision,
47 | "pipeline_class": pipeline_class,
48 | },
49 | )
50 |
51 | if not pipeline_class:
52 | pipeline_class = AutoPipelineForText2Image
53 |
54 | pipeline = pipeline_class if PIPELINE == "ALL" else getattr(_pipelines, PIPELINE)
55 | print("pipeline", pipeline_class)
56 |
57 | print(
58 | ("Loading" if load else "Downloading")
59 | + " model: "
60 | + model_id
61 | + (f" ({revision})" if revision else "")
62 | )
63 |
64 | scheduler = getScheduler(model_id, DEFAULT_SCHEDULER, not load)
65 |
66 | model_dir = os.path.join(MODELS_DIR, model_id)
67 | if not os.path.isdir(model_dir):
68 | model_dir = None
69 |
70 | from_pretrained = time.time()
71 | model = pipeline.from_pretrained(
72 | model_dir or model_id,
73 | revision=revision,
74 | torch_dtype=torch_dtype,
75 | use_auth_token=HF_AUTH_TOKEN,
76 | scheduler=scheduler,
77 | local_files_only=load,
78 | # Work around https://github.com/huggingface/diffusers/issues/1246
79 | # low_cpu_mem_usage=False if USE_DREAMBOOTH else True,
80 | )
81 | from_pretrained = round((time.time() - from_pretrained) * 1000)
82 |
83 | if load:
84 | to_gpu = time.time()
85 | model.to(device)
86 | to_gpu = round((time.time() - to_gpu) * 1000)
87 | print(f"Loaded from disk in {from_pretrained} ms, to gpu in {to_gpu} ms")
88 | else:
89 | print(f"Downloaded in {from_pretrained} ms")
90 |
91 | return model if load else None
92 |
--------------------------------------------------------------------------------
/api/precision.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 |
4 | DEPRECATED_PRECISION = os.getenv("PRECISION")
5 | MODEL_PRECISION = os.getenv("MODEL_PRECISION") or DEPRECATED_PRECISION
6 | MODEL_REVISION = os.getenv("MODEL_REVISION")
7 |
8 | if DEPRECATED_PRECISION:
9 | print("Warning: PRECISION variable been deprecated and renamed MODEL_PRECISION")
10 | print("Your setup still works but in a future release, this will throw an error")
11 |
12 | if MODEL_PRECISION and not MODEL_REVISION:
13 | print("Warning: we no longer default to MODEL_REVISION=MODEL_PRECISION, please")
14 | print(f'explicitly set MODEL_REVISION="{MODEL_PRECISION}" if that\'s what you')
15 | print("want.")
16 |
17 |
18 | def revision_from_precision(precision=MODEL_PRECISION):
19 | # return precision if precision else None
20 | raise Exception("revision_from_precision no longer supported")
21 |
22 |
23 | def torch_dtype_from_precision(precision=MODEL_PRECISION):
24 | if precision == "fp16":
25 | return torch.float16
26 | return None
27 |
--------------------------------------------------------------------------------
/api/send.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | import datetime
4 | import time
5 | import requests
6 | import hashlib
7 | from requests_futures.sessions import FuturesSession
8 | from status import status as statusInstance
9 |
10 | print()
11 | environ = os.environ.copy()
12 | for key in ["AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY", "HF_AUTH_TOKEN"]:
13 | if environ.get(key, None):
14 | environ[key] = "XXX"
15 | print(environ)
16 | print()
17 |
18 |
19 | def get_now():
20 | return round(time.time() * 1000)
21 |
22 |
23 | SEND_URL = os.getenv("SEND_URL")
24 | if SEND_URL == "":
25 | SEND_URL = None
26 |
27 | SIGN_KEY = os.getenv("SIGN_KEY", "")
28 | if SIGN_KEY == "":
29 | SIGN_KEY = None
30 |
31 | futureSession = FuturesSession()
32 |
33 | container_id = os.getenv("CONTAINER_ID")
34 | if not container_id:
35 | with open("/proc/self/mountinfo") as file:
36 | line = file.readline().strip()
37 | while line:
38 | if "/containers/" in line:
39 | container_id = line.split("/containers/")[
40 | -1
41 | ] # Take only text to the right
42 | container_id = container_id.split("/")[0] # Take only text to the left
43 | break
44 | line = file.readline().strip()
45 |
46 |
47 | init_used = False
48 |
49 |
50 | def clearSession(force=False):
51 | global session
52 | global init_used
53 |
54 | if init_used or force:
55 | session = {"_ctime": get_now()}
56 | else:
57 | init_used = True
58 |
59 |
60 | def getTimings():
61 | timings = {}
62 | for key in session.keys():
63 | if key == "_ctime":
64 | continue
65 | start = session[key].get("start", None)
66 | done = session[key].get("done", None)
67 | if start and done:
68 | timings.update({key: session[key]["done"] - session[key]["start"]})
69 | else:
70 | timings.update({key: -1})
71 | return timings
72 |
73 |
74 | async def send(type: str, status: str, payload: dict = {}, opts: dict = {}):
75 | now = get_now()
76 | send_url = opts.get("SEND_URL", SEND_URL)
77 | sign_key = opts.get("SIGN_KEY", SIGN_KEY)
78 |
79 | if status == "start":
80 | session.update({type: {"start": now, "last_time": now}})
81 | elif status == "done":
82 | session[type].update({"done": now, "diff": now - session[type]["start"]})
83 | else:
84 | session[type]["last_time"] = now
85 |
86 | data = {
87 | "type": type,
88 | "status": status,
89 | "container_id": container_id,
90 | "time": now,
91 | "t": now - session["_ctime"],
92 | "tsl": now - session[type]["last_time"],
93 | "payload": payload,
94 | }
95 |
96 | if status == "start":
97 | statusInstance.update(type, 0.0)
98 | elif status == "done":
99 | statusInstance.update(type, 1.0)
100 |
101 | if send_url and sign_key:
102 | input = json.dumps(data, separators=(",", ":")) + sign_key
103 | sig = hashlib.md5(input.encode("utf-8")).hexdigest()
104 | data["sig"] = sig
105 |
106 | print(datetime.datetime.now(), data)
107 |
108 | if send_url:
109 | futureSession.post(send_url, json=data)
110 |
111 | response = opts.get("response")
112 | if response:
113 | print("streaming above")
114 | await response.send(json.dumps(data) + "\n")
115 |
116 | # try:
117 | # requests.post(send_url, json=data) # , timeout=0.0000000001)
118 | # except requests.exceptions.ReadTimeout:
119 | # except requests.exceptions.RequestException as error:
120 | # print(error)
121 | # pass
122 |
123 |
124 | clearSession(True)
125 |
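126 | # Illustrative sketch of how a webhook receiver could verify the signature computed
127 | # above: the event must be re-serialized with the same key order and the same compact
128 | # separators for the digest to match.
129 | #
130 | #   import hashlib, json
131 | #
132 | #   def verify(event: dict, sign_key: str) -> bool:
133 | #       sig = event.pop("sig", None)
134 | #       expected = hashlib.md5(
135 | #           (json.dumps(event, separators=(",", ":")) + sign_key).encode("utf-8")
136 | #       ).hexdigest()
137 | #       return sig == expected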
--------------------------------------------------------------------------------
/api/server.py:
--------------------------------------------------------------------------------
1 | # Do not edit if deploying to Banana Serverless
2 | # This file is boilerplate for the http server, and follows a strict interface.
3 |
4 | # Instead, edit the init() and inference() functions in app.py
5 |
6 | from sanic import Sanic, response
7 | from sanic_ext import Extend
8 | import subprocess
9 | import app as user_src
10 | import traceback
11 | import os
12 | import json
13 |
14 | # We do the model load-to-GPU step on server startup
15 | # so the model object is available globally for reuse
16 | user_src.init()
17 |
18 | # Create the http server app
19 | server = Sanic("my_app")
20 | server.config.CORS_ORIGINS = os.getenv("CORS_ORIGINS") or "*"
21 | server.config.RESPONSE_TIMEOUT = 60 * 60 # 1 hour (training can be long)
22 | Extend(server)
23 |
24 |
25 | # Healthchecks verify that the environment is correct on Banana Serverless
26 | @server.route("/healthcheck", methods=["GET"])
27 | def healthcheck(request):
28 | # dependency free way to check if GPU is visible
29 | gpu = False
30 | out = subprocess.run("nvidia-smi", shell=True)
31 | if out.returncode == 0: # success state on shell command
32 | gpu = True
33 |
34 | return response.json({"state": "healthy", "gpu": gpu})
35 |
36 |
37 | # Inference POST handler at '/' is called for every http call from Banana
38 | @server.route("/", methods=["POST"])
39 | async def inference(request):
40 | try:
41 |         all_inputs = json.loads(request.json)
42 |     except Exception:
43 | all_inputs = request.json
44 |
45 | call_inputs = all_inputs.get("callInputs", None)
46 | stream_events = call_inputs and call_inputs.get("streamEvents", 0) != 0
47 |
48 | streaming_response = None
49 | if stream_events:
50 | streaming_response = await request.respond(content_type="application/x-ndjson")
51 |
52 | try:
53 | output = await user_src.inference(all_inputs, streaming_response)
54 | except Exception as err:
55 | print(err)
56 | output = {
57 | "$error": {
58 | "code": "APP_INFERENCE_ERROR",
59 | "name": type(err).__name__,
60 | "message": str(err),
61 | "stack": traceback.format_exc(),
62 | }
63 | }
64 |
65 | if stream_events:
66 | await streaming_response.send(json.dumps(output) + "\n")
67 | else:
68 | return response.json(output)
69 |
70 |
71 | if __name__ == "__main__":
72 | server.run(host="0.0.0.0", port="8000", workers=1)
73 |
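74 | # Illustrative sketch of a streaming client for the handler above (host, port and
75 | # inputs are examples; the "modelInputs" key is assumed). With callInputs.streamEvents
76 | # set, progress events and the final result arrive as newline-delimited JSON:
77 | #
78 | #   import json, requests
79 | #
80 | #   with requests.post(
81 | #       "http://localhost:8000/",
82 | #       json={"callInputs": {"streamEvents": 1}, "modelInputs": {"prompt": "a cat"}},
83 | #       stream=True,
84 | #   ) as resp:
85 | #       for line in resp.iter_lines():
86 | #           if line:
87 | #               print(json.loads(line))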
--------------------------------------------------------------------------------
/api/status.py:
--------------------------------------------------------------------------------
1 | class Status:
2 | def __init__(self):
3 | self.type = "init"
4 | self.progress = 0.0
5 |
6 | def update(self, type, progress):
7 | self.type = type
8 | self.progress = progress
9 |
10 | def get(self):
11 | return {"type": self.type, "progress": self.progress}
12 |
13 |
14 | status = Status()
15 |
--------------------------------------------------------------------------------
/api/tests.py:
--------------------------------------------------------------------------------
1 | from test import runTest
2 |
3 |
4 | def test_memory_free_on_swap_model():
5 | """
6 | Make sure memory is freed when swapping models at runtime.
7 | """
8 | result = runTest(
9 | "txt2img",
10 | {},
11 | {
12 | "MODEL_ID": "stabilityai/stable-diffusion-2-1-base",
13 | "MODEL_PRECISION": "", # full precision
14 | "MODEL_URL": "s3://",
15 | },
16 | {"num_inference_steps": 1},
17 | )
18 | mem_usage = list()
19 | mem_usage.append(result["$mem_usage"])
20 | result = runTest(
21 | "txt2img",
22 | {},
23 | {
24 | "MODEL_ID": "stabilityai/stable-diffusion-2-1-base",
25 | "MODEL_PRECISION": "fp16", # half precision
26 | "MODEL_URL": "s3://",
27 | },
28 | {"num_inference_steps": 1},
29 | )
30 | mem_usage.append(result["$mem_usage"])
31 |
32 | print({"mem_usage": mem_usage})
33 | # Assert that less memory used when unloading fp32 model and
34 | # loading the fp16 variant in its place
35 | assert mem_usage[1] < mem_usage[0]
36 |
--------------------------------------------------------------------------------
/api/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .storage import Storage
2 |
--------------------------------------------------------------------------------
/api/utils/storage/BaseStorage.py:
--------------------------------------------------------------------------------
1 | import os
2 | import re
3 | import subprocess
4 | from abc import ABC, abstractmethod
5 | import xtarfile as tarfile
6 |
7 |
8 | class BaseArchive(ABC):
9 | def __init__(self, path, status=None):
10 | self.path = path
11 | self.status = status
12 |
13 | def updateStatus(self, type, progress):
14 | if self.status:
15 | self.status.update(type, progress)
16 |
17 | def extract(self):
18 | print("TODO")
19 |
20 | def splitext(self):
21 | base, ext = os.path.splitext(self.path)
22 | base, subext = os.path.splitext(base)
23 | return base, ext, subext
24 |
25 |
26 | class TarArchive(BaseArchive):
27 | @staticmethod
28 | def test(path):
29 | return re.search(r"\.tar", path)
30 |
31 | def extract(self, dir, dry_run=False):
32 | self.updateStatus("extract", 0)
33 | if not dir:
34 | base, ext, subext = self.splitext()
35 | parent_dir = os.path.dirname(self.path)
36 | dir = os.path.join(parent_dir, base)
37 |
38 | if not dry_run:
39 | os.mkdir(dir)
40 |
41 | def track_progress(tar):
42 | i = 0
43 | members = tar.getmembers()
44 | for member in members:
45 | i += 1
46 | self.updateStatus("extract", i / len(members))
47 | yield member
48 |
49 | print("Extracting to " + dir)
50 | with tarfile.open(self.path, "r") as tar:
51 | tar.extractall(path=dir, members=track_progress(tar))
52 | tar.close()
53 | subprocess.run(["ls", "-l", dir])
54 | os.remove(self.path)
55 |
56 | self.updateStatus("extract", 1)
57 | return dir # , base, ext, subext
58 |
59 |
60 | archiveClasses = [TarArchive]
61 |
62 |
63 | def Archive(path, **kwargs):
64 | for ArchiveClass in archiveClasses:
65 | if ArchiveClass.test(path):
66 | return ArchiveClass(path, **kwargs)
67 |
68 |
69 | class BaseStorage(ABC):
70 | @staticmethod
71 | @abstractmethod
72 | def test(url):
73 | return re.search(r"^https?://", url)
74 |
75 | def __init__(self, url, **kwargs):
76 | self.url = url
77 | self.status = kwargs.get("status", None)
78 | self.query = {}
79 |
80 | def updateStatus(self, type, progress):
81 | if self.status:
82 | self.status.update(type, progress)
83 |
84 | def splitext(self):
85 | base, ext = os.path.splitext(self.url)
86 | base, subext = os.path.splitext(base)
87 | return base, ext, subext
88 |
89 | def get_filename(self):
90 | return self.url.split("/").pop()
91 |
92 | @abstractmethod
93 | def download_file(self, dest):
94 | """Download the file to `dest`"""
95 | pass
96 |
97 |     def download_and_extract(self, fname=None, dir=None, dry_run=False):
98 |         """
99 |         Download the file and, if it's an archive, extract it too. Returns
100 |         the filename if it wasn't an archive, or the directory name (fname
101 |         without extension) if it was. dry_run skips download and extraction.
102 |         """
103 |         if not fname:
104 |             fname = self.get_filename()
105 |
106 |         archive = Archive(fname, status=self.status)
107 |         if archive:
108 |             # TODO, streaming pipeline
109 |             if not dry_run:
110 |                 self.download_file(fname)
111 |             return archive.extract(dir, dry_run=dry_run)
112 |         if not dry_run:
113 |             self.download_file(fname)
114 |         return fname
115 |
--------------------------------------------------------------------------------
/api/utils/storage/BaseStorage_test.py:
--------------------------------------------------------------------------------
1 | import unittest
2 | from . import Storage, S3Storage, HTTPStorage
3 |
4 |
5 | class BaseStorageTest(unittest.TestCase):
6 | def test_get_filename(self):
7 | storage = Storage("http://host.com/dir/file.tar.zst")
8 | self.assertEqual(storage.get_filename(), "file.tar.zst")
9 |
10 | class Download_and_extract(unittest.TestCase):
11 | def test_file_only(self):
12 | storage = Storage("http://host.com/dir/file.bin")
13 | result = storage.download_and_extract(dry_run=True)
14 | self.assertEqual(result, "file.bin")
15 |
16 |     def test_file_archive(self):
17 |         storage = Storage("http://host.com/dir/file.tar.zst")
18 |         # for archives, download_and_extract returns the extraction
19 |         # directory (the filename stripped of its extensions)
20 |         result = storage.download_and_extract(dry_run=True)
21 |         self.assertEqual(result, "file")
22 |         self.assertEqual(storage.splitext(), ("http://host.com/dir/file", ".zst", ".tar"))
23 |
--------------------------------------------------------------------------------
/api/utils/storage/HTTPStorage.py:
--------------------------------------------------------------------------------
1 | import re
2 | import os
3 | import time
4 | import requests
5 | from tqdm import tqdm
6 | from .BaseStorage import BaseStorage
7 | import urllib.parse
8 |
9 |
10 | def get_now():
11 | return round(time.time() * 1000)
12 |
13 |
14 | class HTTPStorage(BaseStorage):
15 | @staticmethod
16 | def test(url):
17 | return re.search(r"^https?://", url)
18 |
19 | def __init__(self, url, **kwargs):
20 | super().__init__(url, **kwargs)
21 | parts = self.url.split("#", 1)
22 | self.url = parts[0]
23 | if len(parts) > 1:
24 | self.query = urllib.parse.parse_qs(parts[1])
25 |
26 | def upload_file(self, source, dest):
27 | raise RuntimeError("HTTP PUT not implemented yet")
28 |
29 | def download_file(self, fname):
30 | print(f"Downloading {self.url} to {fname}...")
31 | resp = requests.get(self.url, stream=True)
32 | total = int(resp.headers.get("content-length", 0))
33 | content_disposition = resp.headers.get("content-disposition")
34 | if content_disposition:
35 | filename_search = re.search('filename="(.+)"', content_disposition)
36 | if filename_search:
37 | self.filename = filename_search.group(1)
38 | else:
39 | print("Warning: content-disposition header is not found in the response.")
40 | # Can also replace 'file' with a io.BytesIO object
41 | with open(fname, "wb") as file, tqdm(
42 | desc="Downloading",
43 | total=total,
44 | unit="iB",
45 | unit_scale=True,
46 | unit_divisor=1024,
47 | ) as bar:
48 | total_written = 0
49 | for data in resp.iter_content(chunk_size=1024):
50 | size = file.write(data)
51 | bar.update(size)
52 | total_written += size
53 |                 if total:  # content-length may be missing or zero
54 |                     self.updateStatus("download", total_written / total)
55 |
--------------------------------------------------------------------------------
/api/utils/storage/S3Storage.py:
--------------------------------------------------------------------------------
1 | import boto3
2 | import botocore
3 | import re
4 | import os
5 | import time
6 | from tqdm import tqdm
7 | from botocore.client import Config
8 | from .BaseStorage import BaseStorage
9 |
10 | AWS_S3_ENDPOINT_URL = os.environ.get("AWS_S3_ENDPOINT_URL", None)
11 | AWS_S3_DEFAULT_BUCKET = os.environ.get("AWS_S3_DEFAULT_BUCKET", None)
12 | if AWS_S3_ENDPOINT_URL == "":
13 | AWS_S3_ENDPOINT_URL = None
14 | if AWS_S3_DEFAULT_BUCKET == "":
15 | AWS_S3_DEFAULT_BUCKET = None
16 |
17 |
18 | def get_now():
19 | return round(time.time() * 1000)
20 |
21 |
22 | class S3Storage(BaseStorage):
23 | def test(url):
24 | return re.search(r"^(https?\+)?s3://", url)
25 |
26 | def __init__(self, url, **kwargs):
27 | super().__init__(url, **kwargs)
28 |
29 | if url.startswith("s3://"):
30 | url = "https://" + url[5:]
31 | elif url.startswith("http+s3://"):
32 | url = "http" + url[7:]
33 | elif url.startswith("https+s3://"):
34 | url = "https" + url[8:]
35 |
36 | s3_dest = re.match(
37 | r"^(?Phttps?://[^/]*)(/(?P[^/]+))?(/(?P.*))?$",
38 | url,
39 | ).groupdict()
40 |
41 | if not s3_dest["endpoint"] or s3_dest["endpoint"].endswith("//"):
42 | s3_dest["endpoint"] = AWS_S3_ENDPOINT_URL
43 | if not s3_dest["bucket"]:
44 | s3_dest["bucket"] = AWS_S3_DEFAULT_BUCKET
45 | if not s3_dest["path"] or s3_dest["path"] == "":
46 | s3_dest["path"] = kwargs.get("default_path", "")
47 |
48 | self.endpoint_url = s3_dest["endpoint"]
49 | self.bucket_name = s3_dest["bucket"]
50 | self.path = s3_dest["path"]
51 |
52 | self._s3resource = None
53 | self._s3client = None
54 | self._bucket = None
55 | print("self.endpoint_url", self.endpoint_url)
56 |
57 | def s3resource(self):
58 | if self._s3resource:
59 | return self._s3resource
60 |
61 |         self._s3resource = boto3.resource(
62 |             "s3",
63 |             endpoint_url=self.endpoint_url,
64 |             config=Config(signature_version="s3v4"),
65 |         )
66 |         return self._s3resource
67 |
68 | def s3client(self):
69 | if self._s3client:
70 | return self._s3client
71 |
72 | self._s3client = boto3.client(
73 | "s3",
74 | endpoint_url=self.endpoint_url,
75 | config=Config(signature_version="s3v4"),
76 | )
77 | return self._s3client
78 |
79 | def bucket(self):
80 | if self._bucket:
81 | return self._bucket
82 |
83 | self._bucket = self.s3resource().Bucket(self.bucket_name)
84 | return self._bucket
85 |
86 | def upload_file(self, source, dest):
87 | if not dest:
88 | dest = self.path
89 |
90 | upload_start = get_now()
91 | file_size = os.stat(source).st_size
92 | with tqdm(total=file_size, unit="B", unit_scale=True, desc="Uploading") as bar:
93 | total_transferred = 0
94 |
95 | def callback(bytes_transferred):
96 | nonlocal total_transferred
97 |                 bar.update(bytes_transferred)
98 | total_transferred += bytes_transferred
99 | self.updateStatus("upload", total_transferred / file_size)
100 |
101 | result = self.bucket().upload_file(
102 | Filename=source, Key=dest, Callback=callback
103 | )
104 | print(result)
105 | upload_total = get_now() - upload_start
106 |
107 | return {"$time": upload_total}
108 |
109 | def download_file(self, dest):
110 | if not dest:
111 | dest = self.path.split("/").pop()
112 | print(f"Downloading {self.url} to {dest}...")
113 | object = self.s3resource().Object(self.bucket_name, self.path)
114 | object.load()
115 |
116 | with tqdm(
117 | total=object.content_length, unit="B", unit_scale=True, desc="Downloading"
118 | ) as bar:
119 | total_transferred = 0
120 |
121 | def callback(bytes_transferred):
122 | nonlocal total_transferred
123 |                 bar.update(bytes_transferred)
124 | total_transferred += bytes_transferred
125 | self.updateStatus("download", total_transferred / object.content_length)
126 |
127 | object.download_file(Filename=dest, Callback=callback)
128 |
129 | def file_exists(self):
130 | # res = self.s3client().list_objects_v2(
131 | # Bucket=self.bucket_name, Prefix=self.path, MaxKeys=1
132 | # )
133 | # return "Contents" in res
134 | object = self.s3resource().Object(self.bucket_name, self.path)
135 | try:
136 | object.load()
137 | except botocore.exceptions.ClientError as error:
138 | if error.response["Error"]["Code"] == "404":
139 | return False
140 | else:
141 | raise
142 | return True
143 |
--------------------------------------------------------------------------------
/api/utils/storage/S3Storage_test.py:
--------------------------------------------------------------------------------
1 | import unittest
2 | import os
3 | from .S3Storage import S3Storage, AWS_S3_ENDPOINT_URL, AWS_S3_DEFAULT_BUCKET
4 |
5 |
6 | class S3StorageTest(unittest.TestCase):
7 | def test_endpoint_only_s3(self):
8 | storage = S3Storage("s3://hostname:9000")
9 | self.assertEqual(storage.endpoint_url, "https://hostname:9000")
10 | self.assertEqual(storage.bucket_name, AWS_S3_DEFAULT_BUCKET)
11 | self.assertEqual(storage.path, "")
12 |
13 | def test_endpoint_only_http_s3(self):
14 | storage = S3Storage("http+s3://hostname:9000")
15 | self.assertEqual(storage.endpoint_url, "http://hostname:9000")
16 | self.assertEqual(storage.bucket_name, AWS_S3_DEFAULT_BUCKET)
17 | self.assertEqual(storage.path, "")
18 |
19 | def test_endpoint_only_https_s3(self):
20 | storage = S3Storage("https+s3://hostname:9000")
21 | self.assertEqual(storage.endpoint_url, "https://hostname:9000")
22 | self.assertEqual(storage.bucket_name, AWS_S3_DEFAULT_BUCKET)
23 | self.assertEqual(storage.path, "")
24 |
25 | def test_bucket_only(self):
26 | storage = S3Storage("s3:///bucket")
27 | self.assertEqual(storage.endpoint_url, AWS_S3_ENDPOINT_URL)
28 | self.assertEqual(storage.bucket_name, "bucket")
29 | self.assertEqual(storage.path, "")
30 |
31 | def test_url_with_bucket_and_file_only(self):
32 | storage = S3Storage("s3:///bucket/file")
33 | self.assertEqual(storage.endpoint_url, AWS_S3_ENDPOINT_URL)
34 | self.assertEqual(storage.bucket_name, "bucket")
35 | self.assertEqual(storage.path, "file")
36 |
37 | def test_full_url_with_subdirectory(self):
38 | storage = S3Storage("s3://host/bucket/path/file")
39 | self.assertEqual(storage.endpoint_url, "https://host")
40 | self.assertEqual(storage.bucket_name, "bucket")
41 | self.assertEqual(storage.path, "path/file")
42 |
43 |
44 | if __name__ == "__main__":
45 | unittest.main()
46 |
--------------------------------------------------------------------------------
/api/utils/storage/__init__.py:
--------------------------------------------------------------------------------
1 | import os
2 | import re
3 | from .S3Storage import S3Storage
4 | from .HTTPStorage import HTTPStorage
5 |
6 | classes = [S3Storage, HTTPStorage]
7 |
8 |
9 | def Storage(url, no_raise=False, **kwargs):
10 | for StorageClass in classes:
11 | if StorageClass.test(url):
12 | return StorageClass(url, **kwargs)
13 |
14 | if no_raise:
15 | return None
16 | else:
17 | raise RuntimeError("No storage handler for: " + url)
18 |
--------------------------------------------------------------------------------
/api/utils/storage/__init__test.py:
--------------------------------------------------------------------------------
1 | import unittest
2 | from . import Storage, S3Storage, HTTPStorage
3 |
4 |
5 | class StorageTest(unittest.TestCase):
6 | def test_url_s3(self):
7 | storage = Storage("s3://hostname:9000")
8 | self.assertTrue(isinstance(storage, S3Storage))
9 |
10 | def test_url_http(self):
11 | storage = Storage("http://hostname:9000")
12 | self.assertTrue(isinstance(storage, HTTPStorage))
13 |
14 | def test_no_match_raise(self):
15 | with self.assertRaises(RuntimeError):
16 | storage = Storage("not_a_url")
17 |
18 | def test_no_match_no_raise(self):
19 | storage = Storage("not_a_url", no_raise=True)
20 | self.assertIsNone(storage)
21 |
--------------------------------------------------------------------------------
/build:
--------------------------------------------------------------------------------
1 | #!/bin/sh
2 |
3 | # This is my common way of building, but you can build however you like.
4 | # Note: if you're using a proxy, you need to have it running first.
5 |
6 | DOCKER_BUILDKIT=1 BUILDKIT_PROGRESS=plain \
7 | docker build \
8 | -t gadicc/diffusers-api \
9 | -t gadicc/diffusers-api:test \
10 | --build-arg http_proxy="http://172.17.0.1:3128" \
11 | --build-arg https_proxy="http://172.17.0.1:3128" \
12 | "$@" .
13 |
--------------------------------------------------------------------------------
/docs/internal_safetensor_cache_flow.md:
--------------------------------------------------------------------------------
1 | internal document to gather my thoughts
2 |
3 | RUNTIME_DOWNLOADS=1 (must be build arg)
4 | IMAGE_CLOUD_CACHE="s3://" (can be env arg)
5 | CREATE_MISSING=1
6 |
7 | e.g. stabilityai/stable-diffusion-2-1-base
8 |
9 | 1. Try download from IMAGE_CLOUD_CACHE
10 | 1. If found, use.
11 | 2. If not found:
12 | 1. Download from HuggingFace
13 | 2. In a subprocess:
14 |       1. Save with safetensors to a tmp directory
15 | 2. Upload to IMAGE_CLOUD_CACHE
16 | 3. Delete original model dir, mv tmp to model dir (for next load)
17 | 2. Run inference with the HF model. (A rough sketch of this whole flow is at the end of this doc.)
18 |
19 | FileNotFoundError: [Errno 2] No such file or directory: '/root/.cache/huggingface/diffusers/models--stabilityai--stable-diffusion-2-1-base/refs/main'
20 |
21 |
22 | NVIDIA RTX Quadro 5000
23 |
24 | NO SAFETENSORS
25 | Downloaded in 462557 ms
26 | Loading model: stabilityai/stable-diffusion-2-1 (fp32)
27 | Loaded from disk in 3113 ms, to gpu in 1644 ms
28 |
29 | SAFETENSORS_FAST_GPU=0
30 | Loaded from disk in 2741 ms, to gpu in 557 ms
31 |
32 | SAFETENSORS_FAST_GPU=1
33 | Loaded from disk in 1153 ms, to gpu in 1495 ms
34 |
35 |
36 |
37 | NVIDIA RTX Quadro 5000 (fp16)
38 |
39 | NO SAFETENSORS
40 | Downloaded in 462557 ms
41 | Loading model: stabilityai/stable-diffusion-2-1-base (fp16)
42 | Loaded from disk in 2043 ms, to gpu in 1539 ms
43 |
44 | SAFETENSORS_FAST_GPU=0
45 |
46 |
47 | SAFETENSORS_FAST_GPU=1
48 | Loaded from disk in 1134 ms, to gpu in 1184 ms
49 |
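50 | Rough sketch of the flow above, just to make the intent concrete. Every
51 | function name below is an illustrative placeholder, not the real helpers in
52 | download.py or utils/storage:
53 |
54 | ```python
55 | import subprocess
56 |
57 | # hypothetical stubs -- the real logic lives in download.py / utils/storage
58 | def cloud_cache_has(model_id): ...
59 | def download_from_cloud_cache(model_id, model_dir): ...
60 | def download_from_huggingface(model_id, model_dir): ...
61 |
62 | def ensure_model(model_id, model_dir):
63 |     if cloud_cache_has(model_id):
64 |         # 1.1 cache hit: pull the ready-made safetensors archive and use it
65 |         download_from_cloud_cache(model_id, model_dir)
66 |         return model_dir
67 |     # 1.2 cache miss: fall back to HuggingFace for this request
68 |     download_from_huggingface(model_id, model_dir)
69 |     # then repopulate the cloud cache in a subprocess so inference isn't blocked
70 |     # ("convert_and_upload" is a made-up module name, for illustration only)
71 |     subprocess.Popen(["python", "-m", "convert_and_upload", model_id, model_dir])
72 |     return model_dir
73 | ```
74 |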
--------------------------------------------------------------------------------
/docs/storage.md:
--------------------------------------------------------------------------------
1 | # Storage
2 |
3 | Most URL parameters passed as build args or call args accept these special URLs,
4 | both to store and retrieve files.
5 |
6 | **The Storage API is new and may change without notice; please keep a
7 | careful eye on the CHANGELOG when upgrading**.
8 |
9 | * [AWS S3](#s3)
10 |
11 |
12 | ## S3
13 |
14 | ### Build Args
15 |
16 | Set the following **build-args**, as appropriate (through the Banana dashboard,
17 | by modifying the appropriate lines in the `Dockerfile`, or by specifying, e.g.
18 | `--build-arg AWS_ACCESS_KEY_ID="XXX"`, etc.).
19 |
20 | ```Dockerfile
21 | ARG AWS_ACCESS_KEY_ID="XXX"
22 | ARG AWS_SECRET_ACCESS_KEY="XXX"
23 | ARG AWS_DEFAULT_REGION="us-west-1" # best for banana
24 | # Optional. ONLY SET THIS IF YOU KNOW YOU NEED TO.
25 | # Usually only if you're using non-Amazon S3-compatible storage.
26 | # If you need this, your provider will tell you exactly what
27 | # to put here. Otherwise leave it blank to automatically use
28 | # the correct Amazon S3 endpoint.
29 | ARG AWS_S3_ENDPOINT_URL
30 | ```
31 |
32 | ### Usage
33 |
34 | In any URL where Storage is supported (e.g. dreambooth `dest_url`):
35 |
36 | * `s3://endpoint/bucket/path/to/file`
37 | * `s3:///bucket/file` (uses the default endpoint)
38 | * `s3:///bucket` (for `dest_url`, filename will match your output model)
39 | * `http+s3://...` (force http instead of https)
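40 |
41 | ### Example (Python)
42 |
43 | A minimal sketch of how these URLs resolve inside the container, using the
44 | `Storage` helper from `api/utils/storage`. The import path assumes you're
45 | running from the `api/` directory, and the endpoint/bucket names are made-up
46 | values for illustration only:
47 |
48 | ```python
49 | from utils.storage import Storage
50 |
51 | # full form: endpoint, bucket and key all come from the URL itself
52 | storage = Storage("s3://s3.us-west-1.amazonaws.com/my-bucket/models/model.tar.zst")
53 | print(storage.endpoint_url)  # https://s3.us-west-1.amazonaws.com
54 | print(storage.bucket_name)   # my-bucket
55 | print(storage.path)          # models/model.tar.zst
56 |
57 | # short form: "s3:///bucket/file" (three slashes) falls back to the
58 | # AWS_S3_ENDPOINT_URL build arg for the endpoint
59 | storage = Storage("s3:///my-bucket/model.tar.zst")
60 |
61 | # with AWS credentials in the environment, this downloads the archive
62 | # and unpacks it into ./model/
63 | storage.download_and_extract("model.tar.zst")
64 | ```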
--------------------------------------------------------------------------------
/install.sh:
--------------------------------------------------------------------------------
1 | #!/bin/sh
2 |
3 | # This entire file is no longer used but kept around for reference.
4 |
5 | if [ "$FLASH_ATTENTION" == "1" ]; then
6 |
7 | echo "Building with flash attention"
8 | git clone https://github.com/HazyResearch/flash-attention.git
9 | cd flash-attention
10 | git checkout cutlass
11 | git submodule init
12 | git submodule update
13 | python setup.py install
14 |
15 | cd ..
16 | git clone https://github.com/HazyResearch/diffusers.git
17 | pip install -e diffusers
18 |
19 | else
20 |
21 | echo "Building without flash attention"
22 | git clone https://github.com/huggingface/diffusers
23 | cd diffusers
24 | git checkout v0.9.0
25 | # 2022-11-21 [Community Pipelines] K-Diffusion Pipeline
26 | # git checkout 182eb959e5efc8c77fa31394ca55376331c0ed25
27 | # 2022-11-24 v_prediction (for SD 2.0)
28 | # git checkout 30f6f4410487b6c1cf5be2da6c7e8fc844fb9a44
29 | cd ..
30 | pip install -e diffusers
31 |
32 | fi
33 |
34 |
--------------------------------------------------------------------------------
/package.json:
--------------------------------------------------------------------------------
1 | {
2 | "name": "docker-diffusers-api",
3 | "version": "0.0.1",
4 | "main": "index.js",
5 | "repository": "https://github.com/kiri-art/docker-diffusers-api.git",
6 | "author": "Gadi Cohen ",
7 | "license": "MIT",
8 | "private": true,
9 | "devDependencies": {
10 | "@semantic-release-plus/docker": "^3.1.2",
11 | "@semantic-release/changelog": "^6.0.2",
12 | "@semantic-release/git": "^10.0.1",
13 | "semantic-release": "^19.0.5",
14 | "semantic-release-plus": "^20.0.0"
15 | }
16 | }
17 |
--------------------------------------------------------------------------------
/prime.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # need to fix this.
4 | #download_model {'model_url': 's3://', 'model_id': 'Linaqruf/anything-v3.0', 'model_revision': 'fp16', 'hf_model_id': None}
5 | # {'normalized_model_id': 'models--Linaqruf--anything-v3.0--fp16'}
6 | #self.endpoint_url https://6fb830ebb3c8fed82a52524211d9c54e.r2.cloudflarestorage.com/diffusers
7 | #Downloading s3:// to /root/.cache/diffusers-api/models--Linaqruf--anything-v3.0--fp16.tar.zst...
8 |
9 |
10 | MODELS=(
11 |     # MODEL_ID,precision,revision[,HF_MODEL_ID]
12 | # "prompthero/openjourney-v2"
13 | # "wd-1-4-anime_e1,,,hakurei/waifu-diffusion"
14 | # "Linaqruf/anything-v3.0,fp16,diffusers"
15 | # "Linaqruf/anything-v3.0,fp16,fp16"
16 | # "stabilityai/stable-diffusion-2-1,fp16,fp16"
17 | # "stabilityai/stable-diffusion-2-1-base,fp16,fp16"
18 | # "stabilityai/stable-diffusion-2,fp16,fp16"
19 | # "stabilityai/stable-diffusion-2-base,fp16,fp16"
20 | # "CompVis/stable-diffusion-v1-4,fp16,fp16"
21 | # "runwayml/stable-diffusion-v1-5,fp16,fp16"
22 | # "runwayml/stable-diffusion-inpainting,fp16,fp16"
23 | # "hakurei/waifu-diffusion,fp16,fp16"
24 | # "hakurei/waifu-diffusion-v1-3,fp16,fp16" # from checkpoint
25 | # "rinna/japanese-stable-diffusion"
26 | # "OrangeMix/AbyssOrangeMix2,fp16"
27 | # "OrangeMix/ElyOrangeMix,fp16"
28 | # "OrangeMix/EerieOrangeMix,fp16"
29 | # "OrangeMix/BloodOrangeMix,fp16"
30 | "hakurei/wd-1-5-illusion-beta3,fp16,fp16"
31 | "hakurei/wd-1-5-ink-beta3,fp16,fp16"
32 | "hakurei/wd-1-5-mofu-beta3,fp16,fp16"
33 | "hakurei/wd-1-5-radiance-beta3,fp16,fp16",
34 | )
35 |
36 | for MODEL_STR in ${MODELS[@]}; do
37 | IFS="," read -ra DATA <<<$MODEL_STR
38 | MODEL_ID=${DATA[0]}
39 | MODEL_PRECISION=${DATA[1]}
40 | MODEL_REVISION=${DATA[2]}
41 | HF_MODEL_ID=${DATA[3]}
42 | python test.py txt2img \
43 | --call-arg MODEL_ID="$MODEL_ID" \
44 | --call-arg HF_MODEL_ID="$HF_MODEL_ID" \
45 | --call-arg MODEL_PRECISION="$MODEL_PRECISION" \
46 | --call-arg MODEL_REVISION="$MODEL_REVISION" \
47 | --call-arg MODEL_URL="s3://" \
48 | --model-arg num_inference_steps=1
49 | done
50 |
--------------------------------------------------------------------------------
/release.config.js:
--------------------------------------------------------------------------------
1 | // https://semantic-release.gitbook.io/semantic-release/support/faq#can-i-use-semantic-release-to-publish-non-javascript-packages
2 | module.exports = {
3 | "branches": ["main"],
4 | "plugins": [
5 | "@semantic-release/commit-analyzer",
6 | "@semantic-release/release-notes-generator",
7 | [
8 | "@semantic-release/changelog",
9 | {
10 | "changelogFile": "CHANGELOG.md"
11 | }
12 | ],
13 | [
14 | "@semantic-release/git",
15 | {
16 | "assets": ["CHANGELOG.md"]
17 | }
18 | ],
19 | "@semantic-release/github",
20 | ["@semantic-release-plus/docker", {
21 | "name": "gadicc/diffusers-api"
22 | }]
23 | ]
24 | }
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | # we pin sanic==22.6.2 for compatibility with banana
2 | sanic==22.6.2
3 | sanic-ext==22.6.2
4 | # earlier sanics don't pin but require websockets<11.0
5 | websockets<11.0
6 |
7 | # now manually git cloned in a later step
8 | # diffusers==0.4.1
9 | # git+https://github.com/huggingface/diffusers@v0.5.1
10 |
11 | transformers==4.33.1 # was 4.30.2 until 2023-09-08
12 | scipy==1.11.2 # was 1.10.0 until 2023-09-08
13 | requests_futures==1.0.0
14 | numpy==1.25.1 # was 1.24.1 until 2023-09-08
15 | scikit-image==0.21.0 # was 0.19.3 until 2023-09-08
16 | accelerate==0.22.0 # was 0.20.3 until 2023-09-08
17 | triton==2.1.0 # was 2.0.0.post1 until 2023-09-08
18 | ftfy==6.1.1
19 | spacy==3.6.1 # was 3.5.0 until 2023-09-08
20 | k-diffusion==0.0.16 # was 0.0.15 until 2023-09-08
21 | safetensors==0.3.3 # was 0.3.1 until 2023-09-08
22 |
23 | torch==2.0.1 # was 1.12.1 until 2023-07-19
24 | torchvision==0.15.2
25 | pytorch_lightning==2.0.8 # was 1.9.2 until 2023-09-08
26 |
27 | boto3==1.28.43 # was 1.26.57 until 2023-09-08
28 | botocore==1.31.43 # was 1.29.57 until 2023-09-08
29 |
30 | pytest==7.4.2 # was 7.2.1 until 2023-09-08
31 | pytest-cov==4.1.0 # was 4.0.0 until 2023-09-08
32 |
33 | datasets==2.14.5 # was 2.8.0 until 2023-09-08
34 | omegaconf==2.3.0
35 | tensorboard==2.14.0 # was 2.12.0 until 2023-09-08
36 |
37 | xtarfile[zstd]==0.1.0
38 |
39 | bitsandbytes==0.41.1 # was 0.40.2 until 2023-09-08
40 |
41 | invisible-watermark==0.2.0 # released 2023-07-06
42 | compel==2.0.2 # was 2.0.1 until 2023-09-08
43 | jxlpy==0.9.2 # added 2023-09-11
44 |
--------------------------------------------------------------------------------
/run.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | docker run -it --rm \
4 | --gpus all \
5 | -p 8000:8000 \
6 | -e http_proxy="http://172.17.0.1:3128" \
7 | -e https_proxy="http://172.17.0.1:3128" \
8 | -e REQUESTS_CA_BUNDLE="/usr/local/share/ca-certificates/squid-self-signed.crt" \
9 | -e HF_AUTH_TOKEN="$HF_AUTH_TOKEN" \
10 | -e AWS_ACCESS_KEY_ID="$AWS_ACCESS_KEY_ID" \
11 | -e AWS_SECRET_ACCESS_KEY="$AWS_SECRET_ACCESS_KEY" \
12 | -e AWS_DEFAULT_REGION="$AWS_DEFAULT_REGION" \
13 | -e AWS_S3_ENDPOINT_URL="$AWS_S3_ENDPOINT_URL" \
14 | -e AWS_S3_DEFAULT_BUCKET="$AWS_S3_DEFAULT_BUCKET" \
15 | -v ~/root-cache:/root/.cache \
16 | "$@" gadicc/diffusers-api
17 |
--------------------------------------------------------------------------------
/run_integration_tests_on_lambda.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | PAYLOAD_FILE="/tmp/request.json"
4 |
5 | if [ -z "$LAMBDA_API_KEY" ]; then
6 | echo "No LAMBDA_API_KEY set"
7 | exit 1
8 | fi
9 |
10 | SSH_KEY_FILE="$HOME/.ssh/diffusers-api-test.pem"
11 | if [ ! -f "$SSH_KEY_FILE" ]; then
12 | curl -L $DDA_TEST_PEM > $SSH_KEY_FILE
13 | chmod 600 $SSH_KEY_FILE
14 | fi
15 |
16 | #curl -u $LAMBDA_API_KEY: https://cloud.lambdalabs.com/api/v1/instances
17 |
18 | # TODO, find an available instance
19 | # https://cloud.lambdalabs.com/api/v1/instance-types
20 |
21 | lambda_run() {
22 | # $1 = lambda instance-operation
23 | if [ -z "$2" ] ; then
24 | RESULT=$(
25 | curl -su ${LAMBDA_API_KEY}: \
26 | https://cloud.lambdalabs.com/api/v1/$1 \
27 | -H "Content-Type: application/json"
28 | )
29 | else
30 | RESULT=$(
31 | curl -su ${LAMBDA_API_KEY}: \
32 | https://cloud.lambdalabs.com/api/v1/$1 \
33 | -d @$2 -H "Content-Type: application/json"
34 | )
35 | fi
36 |
37 | if [ $? -eq 1 ]; then
38 | echo "curl failed"
39 | exit 1
40 | fi
41 |
42 | if [ "$RESULT" != "" ]; then
43 | echo $RESULT | jq -e .error >& /dev/null
44 | if [ $? -eq 0 ]; then
45 | echo "lambda error"
46 | echo $RESULT
47 | exit 1
48 | fi
49 | fi
50 | }
51 |
52 | instance_create() {
53 | echo -n "Creating instance..."
54 | local RESULT=""
55 | cat > $PAYLOAD_FILE << __END__
56 | {
57 | "region_name": "us-west-1",
58 | "instance_type_name": "gpu_1x_a10",
59 | "ssh_key_names": [
60 | "diffusers-api-test"
61 | ],
62 | "file_system_names": [],
63 | "quantity": 1
64 | }
65 | __END__
66 |
67 | lambda_run "instance-operations/launch" $PAYLOAD_FILE
68 | # echo $RESULT
69 |   INSTANCE_ID=$(echo $RESULT | jq -re '.data.instance_ids[0]')
70 |   if [ $? -ne 0 ]; then
71 |     echo "jq failed"
72 |     exit 1
73 |   fi
74 |   echo "$INSTANCE_ID"
75 | }
76 |
77 | instance_terminate() {
78 | # $1 = INSTANCE_ID
79 | echo "Terminating instance $1"
80 | cat > $PAYLOAD_FILE << __END__
81 | {
82 | "instance_ids": [
83 | "$1"
84 | ]
85 | }
86 | __END__
87 | lambda_run "instance-operations/terminate" $PAYLOAD_FILE
88 | echo $RESULT
89 | }
90 |
91 | declare -A IPS
92 | instance_wait() {
93 | INSTANCE_ID="$1"
94 | echo -n "Waiting for $INSTANCE_ID"
95 | STATUS=""
96 | LAST_STATUS=""
97 | while [ "$STATUS" != "active" ] ; do
98 | echo -n "."
99 | lambda_run "instances/$INSTANCE_ID"
100 | STATUS=$(echo $RESULT | jq -r '.data.status')
101 | if [ "$STATUS" != "$LAST_STATUS" ]; then
102 | # echo $RESULT
103 | # echo STATUS $STATUS
104 | LAST_STATUS=$STATUS
105 | fi
106 | sleep 1
107 | done
108 | echo
109 |
110 | IP=$(echo $RESULT | jq -r '.data.ip')
111 | echo STATUS $STATUS
112 | echo IP $IP
113 | IPS["$INSTANCE_ID"]=$IP
114 | }
115 |
116 | instance_run_script() {
117 | INSTANCE_ID="$1"
118 | SCRIPT="$2"
119 | DIRECTORY="${3:-'.'}"
120 | IP=${IPS["$INSTANCE_ID"]}
121 |
122 | echo "instance_run_script $1 $2 $3"
123 | ssh -i $SSH_KEY_FILE ubuntu@$IP "cd $DIRECTORY && bash -s" < $SCRIPT
124 | return $?
125 | }
126 |
127 | instance_run_command() {
128 | INSTANCE_ID="$1"
129 | CMD="$2"
130 | DIRECTORY="${3:-'.'}"
131 | IP=${IPS["$INSTANCE_ID"]}
132 |
133 | echo "instance_run_command $1 $2"
134 | ssh -i $SSH_KEY_FILE -o StrictHostKeyChecking=accept-new ubuntu@$IP "cd $DIRECTORY && $CMD"
135 | return $?
136 | }
137 |
138 | instance_rsync() {
139 | INSTANCE_ID="$1"
140 | SOURCE="$2"
141 | DEST="$3"
142 | IP=${IPS["$INSTANCE_ID"]}
143 |
144 | echo "instance_rsync $1 $2 $3"
145 | rsync -avzPe "ssh -i $SSH_KEY_FILE -o StrictHostKeyChecking=accept-new" --filter=':- .gitignore' --exclude=".git" $SOURCE ubuntu@$IP:$DEST
146 | return $?
147 | }
148 |
149 | # Image Method 3, preparation (TODO, arg to specify which method)
150 | docker build -t gadicc/diffusers-api:test .
151 | docker push gadicc/diffusers-api:test
152 |
153 | instance_create
154 | # INSTANCE_ID="913e06f669bf4e799c6223801eb82f40"
155 |
156 | instance_wait $INSTANCE_ID
157 |
158 | commands() {
159 | instance_run_command $INSTANCE_ID "echo 'export HF_AUTH_TOKEN=\"$HF_AUTH_TOKEN\"' >> ~/.bashrc"
160 |
161 |   # Whether to build or just for the test scripts, let's transfer this checkout.
162 | instance_rsync $INSTANCE_ID . docker-diffusers-api
163 |
164 | instance_run_command $INSTANCE_ID "sudo apt-get update"
165 | if [ $? -eq 1 ]; then return 1 ; fi
166 | instance_run_command $INSTANCE_ID "sudo apt install -yqq python3.9"
167 | if [ $? -eq 1 ]; then return 1 ; fi
168 | instance_run_command $INSTANCE_ID "python3.9 -m pip install -r docker-diffusers-api/tests/integration/requirements.txt"
169 | if [ $? -eq 1 ]; then return 1 ; fi
170 | instance_run_command $INSTANCE_ID "sudo usermod -aG docker ubuntu"
171 | if [ $? -eq 1 ]; then return 1 ; fi
172 |
173 | # Image Method 1: Transfer entire image
174 | # This turned out to be way too slow, quicker to rebuild on lambda
175 | # Longer term, I guess we need our own container registry.
176 | # echo "Saving and transferring docker image to Lambda..."
177 | # IP=${IPS["$INSTANCE_ID"]}
178 | # docker save gadicc/diffusers-api:latest \
179 | # | xz \
180 | # | pv \
181 | # | ssh -i $SSH_KEY_FILE ubuntu@$IP docker load
182 | # if [ $? -eq 1 ]; then return 1 ; fi
183 |
184 | # Image Method 2: Build on LambdaLabs
185 | #if [ $? -eq 1 ]; then return 1 ; fi
186 | #instance_run_command $INSTANCE_ID "docker build -t gadicc/diffusers-api ." docker-diffusers-api
187 |
188 | # Image Method 3: Just upload new layers; Lambda has fast downloads from registry
189 | # At start of script we have docker build/push. Now let's pull:
190 | instance_run_command $INSTANCE_ID "docker pull gadicc/diffusers-api:test"
191 |
192 | # instance_run_script $INSTANCE_ID run_integration_tests.sh docker-diffusers-api
193 | instance_run_command $INSTANCE_ID "export HF_AUTH_TOKEN=\"$HF_AUTH_TOKEN\" && python3.9 -m pytest -s tests/integration" docker-diffusers-api
194 | }
195 |
196 | commands
197 | RETURN_VALUE=$?
198 |
199 | instance_terminate $INSTANCE_ID
200 |
201 | exit $RETURN_VALUE
--------------------------------------------------------------------------------
/scripts/devContainerPostCreate.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # devcontainer.json postCreateCommand
4 |
5 | echo
6 | echo Initialize conda bindings for bash
7 | conda init bash
8 |
9 | echo Activating
10 | source /opt/conda/bin/activate base
11 |
12 | echo Installing dev dependencies
13 | pip install watchdog
14 |
--------------------------------------------------------------------------------
/scripts/devContainerServer.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | source /opt/conda/bin/activate base
4 |
5 | ln -sf /api/diffusers .
6 |
7 | watchmedo auto-restart --recursive -d api python api/server.py
--------------------------------------------------------------------------------
/scripts/patchmatch-setup.sh:
--------------------------------------------------------------------------------
1 | #!/bin/sh
2 |
3 | if [ "$USE_PATCHMATCH" != "1" ]; then
4 | echo "Skipping PyPatchMatch install because USE_PATCHMATCH=$USE_PATCHMATCH"
5 | mkdir PyPatchMatch
6 | touch PyPatchMatch/patch_match.py
7 | exit
8 | fi
9 |
10 | echo "Installing PyPatchMatch because USE_PATCHMATCH=$USE_PATCHMATCH"
11 | apt-get install -yqq libopencv-dev python3-opencv > /dev/null
12 | git clone https://github.com/lkwq007/PyPatchMatch
13 | cd PyPatchMatch
14 | git checkout 0ae9b8bbdc83f84214405376f13a2056568897fb
15 | sed -i '0,/if os.name!="nt":/s//if False:/' patch_match.py
16 | make
17 |
--------------------------------------------------------------------------------
/scripts/permutations.yaml:
--------------------------------------------------------------------------------
1 | list:
2 |
3 | - name: sd-v1-5
4 | HF_AUTH_TOKEN: $HF_AUTH_TOKEN
5 | MODEL_ID: runwayml/stable-diffusion-v1-5
6 | PIPELINE: ALL
7 |
8 | - name: sd-v1-4
9 | HF_AUTH_TOKEN: $HF_AUTH_TOKEN
10 | MODEL_ID: CompVis/stable-diffusion-v1-4
11 | PIPELINE: ALL
12 |
13 | - name: sd-inpaint
14 | HF_AUTH_TOKEN: $HF_AUTH_TOKEN
15 | MODEL_ID: runwayml/stable-diffusion-inpainting
16 | PIPELINE: StableDiffusionInpaintPipeline
17 |
18 | - name: sd-waifu
19 | HF_AUTH_TOKEN: $HF_AUTH_TOKEN
20 | MODEL_ID: hakurei/waifu-diffusion
21 | PIPELINE: ALL
22 |
23 | - name: sd-waifu-v1-3
24 | HF_AUTH_TOKEN: $HF_AUTH_TOKEN
25 | MODEL_ID: hakurei/waifu-diffusion-v1-3
26 | CHECKPOINT_URL: https://huggingface.co/hakurei/waifu-diffusion-v1-3/resolve/main/wd-v1-3-float16.ckpt
27 | PIPELINE: ALL
28 |
29 | - name: sd-jp
30 | HF_AUTH_TOKEN: $HF_AUTH_TOKEN
31 | MODEL_ID: rinna/japanese-stable-diffusion
32 | PIPELINE: ALL
33 |
--------------------------------------------------------------------------------
/scripts/permute.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | # Run this in banana-sd-base's PARENT directory.
4 | # Modify the below first per your preferences
5 |
6 | # Requires `yq` from https://github.com/mikefarah/yq
7 | # Note, there are two yqs. In Archlinux the package is "go-yq".
8 |
9 | if [ -z "$1" ]; then
10 | echo "Using 'scripts/permutations.yaml' as default INFILE"
11 | echo "You can also run: permutate.sh MY_INFILE"
12 | INFILE='scripts/permutations.yaml'
13 | else
14 | INFILE=$1
15 | fi
16 |
17 | if [ -z "$TARGET_REPO_BASE" ]; then
18 | TARGET_REPO_BASE="git@github.com:kiri-art"
19 |   echo "No TARGET_REPO_BASE found, using \"$TARGET_REPO_BASE\""
20 | fi
21 |
22 | permutations=$(yq e -o=j -I=0 '.list[]' $INFILE)
23 |
24 | # Needed for ! expansion in cp command further down.
25 | shopt -s extglob
26 | # Include dot files in expansion for .git .gitignore
27 | shopt -s dotglob
28 |
29 | COUNTER=0
30 | declare -A vars
31 |
32 | mkdir -p permutations
33 |
34 | while IFS="=" read permutation; do
35 | # e.g. Permutation #1: banana-sd-txt2img
36 | NAME=$(echo "$permutation" | yq e '.name')
37 | COUNTER=$[$COUNTER + 1]
38 | echo
39 | echo "Permutation #$COUNTER: $NAME"
40 |
41 | while IFS="=" read -r key value
42 | do
43 | if [ "$key" != "name" ]; then
44 | if [ "${value:0:1}" == "$" ]; then
45 | # For e.g. "$HF_AUTH_TOKEN", expand from environment
46 | value="${value:1}"
47 | vars[$key]=${!value}
48 | else
49 | vars[$key]=$value;
50 | fi
51 | fi
52 | done < <(echo $permutation | yq e 'to_entries | .[] | (.key + "=" + .value)')
53 |
54 | if [ -d "permutations/$NAME" ]; then
55 | echo "./permutations/$NAME already exists, skipping..."
56 | echo "Run 'rm -rf permutations/$NAME' first to remake this permutation"
57 | echo "In a later release, we'll merge updates in this case."
58 | continue
59 | fi
60 |
61 | # echo "mkdir permutations/$NAME"
62 | mkdir permutations/$NAME
63 | # echo 'cp -a ./!(permutations|scripts|root-cache) permutations/$NAME'
64 | cp -a ./!(permutations|scripts|root-cache) permutations/$NAME
65 | # echo cd permutations/$NAME
66 | cd permutations/$NAME
67 |
68 | echo "Substituting variables in Dockerfile"
69 | for key in "${!vars[@]}"; do
70 | value="${vars[$key]}"
71 | sed -i "s@^ARG $key.*\$@ARG $key=\"$value\"@" Dockerfile
72 | done
73 |
74 | diffusers=${vars[diffusers]}
75 | if [ "$diffusers" ]; then
76 | echo "Replacing diffusers with $diffusers"
77 | echo "!!! NOT DONE YET !!!"
78 | fi
79 |
80 | mkdir root-cache
81 | touch root-cache/non-empty-directory
82 | git add root-cache
83 |
84 | git remote rm origin
85 | git remote add upstream git@github.com:kiri-art/docker-diffusers-api.git
86 | git remote add origin $TARGET_REPO_BASE/$NAME.git
87 |
88 | echo git commit -a -m "$NAME permutation variables"
89 | git commit -a -m "$NAME permutation variables"
90 |
91 | # echo "cd ../.."
92 | cd ../..
93 | echo
94 | done < <(echo "$permutations")
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
220 |                 if item[1] > 60000
221 | else f"{item[1]/1000:.1f}s"
222 | if item[1] > 1000
223 | else str(item[1]) + "ms",
224 | ),
225 | timings.items(),
226 | )
227 | )
228 | ).replace('"', "")[1:-1]
229 | print(f"Request took {finish:.1f}s ({timings_str})")
230 | else:
231 | print(f"Request took {finish:.1f}s")
232 |
233 | if (
234 | result.get("images_base64", None) == None
235 | and result.get("image_base64", None) == None
236 | ):
237 | error = result.get("$error", None)
238 | if error:
239 | code = error.get("code", None)
240 | name = error.get("name", None)
241 | message = error.get("message", None)
242 | stack = error.get("stack", None)
243 | if code and name and message and stack:
244 | print()
245 | title = f"Exception {code} on container:"
246 | print(title)
247 | print("-" * len(title))
248 | # print(f'{name}("{message}")') - stack includes it.
249 | print(stack)
250 | return
251 |
252 | print(json.dumps(result, indent=4))
253 | print()
254 | return result
255 |
256 | images_base64 = result.get("images_base64", None)
257 | if images_base64:
258 | for idx, image_byte_string in enumerate(images_base64):
259 | images_base64[idx] = decode_and_save(image_byte_string, f"{name}_{idx}")
260 | else:
261 | result["image_base64"] = decode_and_save(result["image_base64"], name)
262 |
263 | print()
264 | print(json.dumps(result, indent=4))
265 | print()
266 | return result
267 |
268 |
269 | test(
270 | "txt2img",
271 | {
272 | "modelInputs": {
273 | "prompt": "realistic field of grass",
274 | "num_inference_steps": 20,
275 | },
276 | "callInputs": {
277 | # "MODEL_ID": "", # (default)
278 | # "PIPELINE": "StableDiffusionPipeline", # (default)
279 | # "SCHEDULER": "DPMSolverMultistepScheduler", # (default)
280 | # "xformers_memory_efficient_attention": False, # (default)
281 | },
282 | },
283 | )
284 |
285 | # multiple images
286 | test(
287 | "txt2img-multiple",
288 | {
289 | "modelInputs": {
290 | "prompt": "realistic field of grass",
291 | "num_images_per_prompt": 2,
292 | }
293 | },
294 | )
295 |
296 |
297 | test(
298 | "img2img",
299 | {
300 | "modelInputs": {
301 | "prompt": "A fantasy landscape, trending on artstation",
302 | "image": b64encode_file("sketch-mountains-input.jpg"),
303 | },
304 | "callInputs": {
305 | "PIPELINE": "StableDiffusionImg2ImgPipeline",
306 | },
307 | },
308 | )
309 |
310 | test(
311 | "inpaint-v1-4",
312 | {
313 | "modelInputs": {
314 | "prompt": "a cat sitting on a bench",
315 | "image": b64encode_file("overture-creations-5sI6fQgYIuo.png"),
316 | "mask_image": b64encode_file("overture-creations-5sI6fQgYIuo_mask.png"),
317 | },
318 | "callInputs": {
319 | "MODEL_ID": "CompVis/stable-diffusion-v1-4",
320 | "PIPELINE": "StableDiffusionInpaintPipelineLegacy",
321 | "SCHEDULER": "DDIMScheduler", # Note, as of diffusers 0.3.0, no LMS yet
322 | },
323 | },
324 | )
325 |
326 | test(
327 | "inpaint-sd",
328 | {
329 | "modelInputs": {
330 | "prompt": "a cat sitting on a bench",
331 | "image": b64encode_file("overture-creations-5sI6fQgYIuo.png"),
332 | "mask_image": b64encode_file("overture-creations-5sI6fQgYIuo_mask.png"),
333 | },
334 | "callInputs": {
335 | "MODEL_ID": "runwayml/stable-diffusion-inpainting",
336 | "PIPELINE": "StableDiffusionInpaintPipeline",
337 | "SCHEDULER": "DDIMScheduler", # Note, as of diffusers 0.3.0, no LMS yet
338 | },
339 | },
340 | )
341 |
342 | test(
343 | "checkpoint",
344 | {
345 | "modelInputs": {
346 | "prompt": "1girl",
347 | },
348 | "callInputs": {
349 | "MODEL_ID": "hakurei/waifu-diffusion-v1-3",
350 | "MODEL_URL": "s3://",
351 | "CHECKPOINT_URL": "http://huggingface.co/hakurei/waifu-diffusion-v1-3/resolve/main/wd-v1-3-float16.ckpt",
352 | },
353 | },
354 | )
355 |
356 | if os.getenv("USE_PATCHMATCH"):
357 | test(
358 | "outpaint",
359 | {
360 | "modelInputs": {
361 | "prompt": "girl with a pearl earing standing in a big room",
362 | "image": b64encode_file("girl_with_pearl_earing_outpainting_in.png"),
363 | },
364 | "callInputs": {
365 | "MODEL_ID": "CompVis/stable-diffusion-v1-4",
366 | "PIPELINE": "StableDiffusionInpaintPipelineLegacy",
367 | "SCHEDULER": "DDIMScheduler", # Note, as of diffusers 0.3.0, no LMS yet
368 | "FILL_MODE": "patchmatch",
369 | },
370 | },
371 | )
372 |
373 | # Actually we just want this to be a non-default test?
374 | if True or os.getenv("USE_DREAMBOOTH"):
375 | test(
376 | "dreambooth",
377 |         # If you're calling from the command line, don't forget to
378 |         # specify a destination if you want your fine-tuned model to
379 | # be uploaded somewhere at the end.
380 | {
381 | "modelInputs": {
382 | "instance_prompt": "a photo of sks dog",
383 | "instance_images": list(
384 | map(
385 | b64encode_file,
386 | list(Path("tests/fixtures/dreambooth").iterdir()),
387 | )
388 | ),
389 | # Option 1: upload to HuggingFace (see notes below)
390 | # Make sure your HF API token has read/write access.
391 | # "hub_model_id": "huggingFaceUsername/targetModelName",
392 | # "push_to_hub": True,
393 | },
394 | "callInputs": {
395 | "train": "dreambooth",
396 |             # Option 2: store on S3. Note the three slashes in "s3:///". See notes below.
397 | # "dest_url": "s3:///bucket/filename.tar.zst".
398 | },
399 | },
400 | )
401 |
402 |
403 | def main(tests_to_run, args, extraCallInputs, extraModelInputs):
404 | invalid_tests = []
405 | for test in tests_to_run:
406 | if all_tests.get(test, None) == None:
407 | invalid_tests.append(test)
408 |
409 | if len(invalid_tests) > 0:
410 | print("No such tests: " + ", ".join(invalid_tests))
411 | exit(1)
412 |
413 | for test in tests_to_run:
414 | runTest(test, args, extraCallInputs, extraModelInputs)
415 |
416 |
417 | if __name__ == "__main__":
418 | parser = argparse.ArgumentParser()
419 | parser.add_argument("--banana", required=False, action="store_true")
420 | parser.add_argument("--runpod", required=False, action="store_true")
421 | parser.add_argument(
422 | "--xmfe",
423 | required=False,
424 | default=None,
425 | type=lambda x: bool(distutils.util.strtobool(x)),
426 | )
427 | parser.add_argument("--scheduler", required=False, type=str)
428 | parser.add_argument("--call-arg", action="append", type=str)
429 | parser.add_argument("--model-arg", action="append", type=str)
430 |
431 | args, tests_to_run = parser.parse_known_args()
432 |
433 | call_inputs = {}
434 | model_inputs = {}
435 |
436 | if args.call_arg:
437 | for arg in args.call_arg:
438 | name, value = arg.split("=", 1)
439 | if value.lower() == "true":
440 | value = True
441 | elif value.lower() == "false":
442 | value = False
443 | elif value.isdigit():
444 | value = int(value)
445 | elif value.replace(".", "", 1).isdigit():
446 | value = float(value)
447 | call_inputs.update({name: value})
448 |
449 | if args.model_arg:
450 | for arg in args.model_arg:
451 | name, value = arg.split("=", 1)
452 | if value.lower() == "true":
453 | value = True
454 | elif value.lower() == "false":
455 | value = False
456 | elif value.isdigit():
457 | value = int(value)
458 | elif value.replace(".", "", 1).isdigit():
459 | value = float(value)
460 | model_inputs.update({name: value})
461 |
462 | if args.xmfe != None:
463 | call_inputs.update({"xformers_memory_efficient_attention": args.xmfe})
464 | if args.scheduler:
465 | call_inputs.update({"SCHEDULER": args.scheduler})
466 |
467 | if len(tests_to_run) < 1:
468 | print(
469 | "Usage: python3 test.py [--banana] [--xmfe=1/0] [--scheduler=SomeScheduler] [all / test1] [test2] [etc]"
470 | )
471 | sys.exit()
472 | elif len(tests_to_run) == 1 and (
473 | tests_to_run[0] == "ALL" or tests_to_run[0] == "all"
474 | ):
475 | tests_to_run = list(all_tests.keys())
476 |
477 | main(
478 | tests_to_run,
479 | vars(args),
480 | extraCallInputs=call_inputs,
481 | extraModelInputs=model_inputs,
482 | )
483 |
--------------------------------------------------------------------------------
/tests/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kiri-art/docker-diffusers-api/5521b2e6d63ef7afa9c3f7f3ef6cf520c1b54d7b/tests/__init__.py
--------------------------------------------------------------------------------
/tests/fixtures/dreambooth/alvan-nee-9M0tSjb-cpA-unsplash.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kiri-art/docker-diffusers-api/5521b2e6d63ef7afa9c3f7f3ef6cf520c1b54d7b/tests/fixtures/dreambooth/alvan-nee-9M0tSjb-cpA-unsplash.jpeg
--------------------------------------------------------------------------------
/tests/fixtures/dreambooth/alvan-nee-Id1DBHv4fbg-unsplash.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kiri-art/docker-diffusers-api/5521b2e6d63ef7afa9c3f7f3ef6cf520c1b54d7b/tests/fixtures/dreambooth/alvan-nee-Id1DBHv4fbg-unsplash.jpeg
--------------------------------------------------------------------------------
/tests/fixtures/dreambooth/alvan-nee-bQaAJCbNq3g-unsplash.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kiri-art/docker-diffusers-api/5521b2e6d63ef7afa9c3f7f3ef6cf520c1b54d7b/tests/fixtures/dreambooth/alvan-nee-bQaAJCbNq3g-unsplash.jpeg
--------------------------------------------------------------------------------
/tests/fixtures/dreambooth/alvan-nee-brFsZ7qszSY-unsplash.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kiri-art/docker-diffusers-api/5521b2e6d63ef7afa9c3f7f3ef6cf520c1b54d7b/tests/fixtures/dreambooth/alvan-nee-brFsZ7qszSY-unsplash.jpeg
--------------------------------------------------------------------------------
/tests/fixtures/dreambooth/alvan-nee-eoqnr8ikwFE-unsplash.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kiri-art/docker-diffusers-api/5521b2e6d63ef7afa9c3f7f3ef6cf520c1b54d7b/tests/fixtures/dreambooth/alvan-nee-eoqnr8ikwFE-unsplash.jpeg
--------------------------------------------------------------------------------
/tests/fixtures/girl_with_pearl_earing_outpainting_in.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kiri-art/docker-diffusers-api/5521b2e6d63ef7afa9c3f7f3ef6cf520c1b54d7b/tests/fixtures/girl_with_pearl_earing_outpainting_in.png
--------------------------------------------------------------------------------
/tests/fixtures/overture-creations-5sI6fQgYIuo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kiri-art/docker-diffusers-api/5521b2e6d63ef7afa9c3f7f3ef6cf520c1b54d7b/tests/fixtures/overture-creations-5sI6fQgYIuo.png
--------------------------------------------------------------------------------
/tests/fixtures/overture-creations-5sI6fQgYIuo_mask.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kiri-art/docker-diffusers-api/5521b2e6d63ef7afa9c3f7f3ef6cf520c1b54d7b/tests/fixtures/overture-creations-5sI6fQgYIuo_mask.png
--------------------------------------------------------------------------------
/tests/fixtures/sketch-mountains-input.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kiri-art/docker-diffusers-api/5521b2e6d63ef7afa9c3f7f3ef6cf520c1b54d7b/tests/fixtures/sketch-mountains-input.jpg
--------------------------------------------------------------------------------
/tests/integration/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kiri-art/docker-diffusers-api/5521b2e6d63ef7afa9c3f7f3ef6cf520c1b54d7b/tests/integration/__init__.py
--------------------------------------------------------------------------------
/tests/integration/conftest.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import os
3 | from .lib import startContainer, get_free_port, DOCKER_GW_IP
4 |
5 |
6 | @pytest.fixture(autouse=True, scope="session")
7 | def my_fixture():
8 | # setup_stuff
9 | print("session start")
10 |
11 | # newCache = not os.getenv("DDA_https_proxy")
12 | newCache = False
13 |
14 | if newCache:
15 | squid_port = get_free_port()
16 | http_port = get_free_port()
17 | container, stop = startContainer(
18 | "gadicc/squid-ssl-zero",
19 | ports={3128: squid_port, 3129: http_port},
20 | )
21 | os.environ["DDA_http_proxy"] = f"http://{DOCKER_GW_IP}:{squid_port}"
22 | os.environ["DDA_https_proxy"] = os.environ["DDA_http_proxy"]
23 | # TODO, code in getDDA to download cert
24 |
25 | yield
26 | # teardown_stuff
27 | print("session end")
28 | if newCache:
29 | stop()
30 |
--------------------------------------------------------------------------------
/tests/integration/lib.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import docker
3 | import atexit
4 | import time
5 | import boto3
6 | import os
7 | import requests
8 | import socket
9 | import asyncio
10 | import sys
11 | import subprocess
12 | import selectors
13 | from threading import Thread
14 | from argparse import Namespace
15 |
16 | AWS_S3_DEFAULT_BUCKET = os.environ.get("AWS_S3_DEFAULT_BUCKET", "test")
17 | DOCKER_GW_IP = "172.17.0.1" # will override below if found
18 |
19 | myContainers = list()
20 | dockerClient = docker.DockerClient(
21 | base_url="unix://var/run/docker.sock", version="auto"
22 | )
23 | for network in dockerClient.networks.list():
24 | if network.attrs["Scope"] == "local" and network.attrs["Driver"] == "bridge":
25 | DOCKER_GW_IP = network.attrs["IPAM"]["Config"][0]["Gateway"]
26 | break
27 |
28 | # # https://stackoverflow.com/a/53255955/1839099
29 | # def fire_and_forget(f):
30 | # def wrapped(*args, **kwargs):
31 | # return asyncio.get_event_loop().run_in_executor(None, f, *args, *kwargs)
32 | # return wrapped
33 | #
34 | # @fire_and_forget
35 | # def log_streamer(container):
36 | # for line in container.logs(stream=True):
37 | # print(line.decode(), end="")
38 |
39 |
40 | def log_streamer(container, name=None):
41 | """
42 | Streams logs to stdout/stderr.
43 | Order is not guaranteed (have tried 3 different methods)
44 | """
45 | # Method 1: pipe streams directly -- even this doesn't guarantee order
46 | # Method 2: threads + readline
47 | # Method 3: selectors + read1
48 | method = 1
49 |
50 | if method == 1:
51 | kwargs = {
52 | "stdout": sys.stdout,
53 | "stderr": sys.stderr,
54 | }
55 | elif method == 2:
56 | kwargs = {
57 | "stdout": subprocess.PIPE,
58 | "stderr": subprocess.PIPE,
59 | "bufsize": 1,
60 | "universal_newlines": True,
61 | }
62 | elif method == 3:
63 | kwargs = {
64 | "stdout": subprocess.PIPE,
65 | "stderr": subprocess.PIPE,
66 | "bufsize": 1,
67 | }
68 |
69 | prefix = f"[{name or container.id[:7]}] "
70 | print(prefix + "== Streaming logs (stdout/stderr order not guaranteed): ==")
71 |
72 | sp = subprocess.Popen(["docker", "logs", "-f", container.id], **kwargs)
73 |
74 | if method == 2:
75 |
76 | def reader(pipe):
77 | while True:
78 | read = pipe.readline()
79 | if read == "" and sp.poll() is not None:
80 | break
81 | print(prefix + read, end="")
82 | sys.stdout.flush()
83 | sys.stderr.flush()
84 |
85 | Thread(target=reader, args=[sp.stdout]).start()
86 | Thread(target=reader, args=[sp.stderr]).start()
87 |
88 | elif method == 3:
89 | selector = selectors.DefaultSelector()
90 | selector.register(sp.stdout, selectors.EVENT_READ)
91 | selector.register(sp.stderr, selectors.EVENT_READ)
92 | loop = True
93 |
94 | while loop:
95 | for key, _ in selector.select():
96 | data = key.fileobj.read1().decode()
97 | if not data:
98 | loop = False
99 | break
100 | line = prefix + str(data).rstrip().replace("\n", "\n" + prefix)
101 | if key.fileobj is sp.stdout:
102 | print(line)
103 | sys.stdout.flush()
104 | else:
105 | print(line, file=sys.stderr)
106 | sys.stderr.flush()
107 |
108 |
109 | def get_free_port():
110 | s = socket.socket()
111 | s.bind(("", 0))
112 | port = s.getsockname()[1]
113 | s.close()
114 | return port
115 |
116 |
117 | def startContainer(image, command=None, stream_logs=False, onstop=None, **kwargs):
118 | global myContainers
119 |
120 | container = dockerClient.containers.run(
121 | image,
122 | command,
123 | # auto_remove=True,
124 | detach=True,
125 | **kwargs,
126 | )
127 |
128 | if stream_logs:
129 | log_streamer(container)
130 |
131 | myContainers.append(container)
132 |
133 | def stop():
134 | print("stop", container.id)
135 | container.stop()
136 | container.remove()
137 | myContainers.remove(container)
138 | if onstop:
139 | onstop()
140 |
141 | while container.status != "running" and container.status != "exited":
142 | time.sleep(1)
143 | try:
144 | container.reload()
145 | except Exception as error:
146 | print(container.logs())
147 | raise error
148 | print(container.status)
149 |
150 | # if (container.status == "exited"):
151 | # print(container.logs())
152 | # raise Exception("unexpected exit")
153 |
154 | print("returned", container)
155 | return container, stop
156 |
157 |
158 | _minioCache = {}
159 |
160 |
161 | def getMinio(id="disposable"):
162 | cached = _minioCache.get(id, None)
163 | if cached:
164 | return Namespace(**cached)
165 |
166 | if id == "global":
167 | endpoint_url = os.getenv("AWS_S3_ENDPOINT_URL")
168 | if endpoint_url:
169 | print("Reusing existing global minio")
170 | aws_access_key_id = os.getenv("AWS_ACCESS_KEY_ID")
171 | aws_secret_access_key = os.getenv("AWS_SECRET_ACCESS_KEY")
172 | aws_s3_default_bucket = AWS_S3_DEFAULT_BUCKET
173 | s3 = boto3.client(
174 | "s3",
175 | endpoint_url=endpoint_url,
176 | config=boto3.session.Config(signature_version="s3v4"),
177 | aws_access_key_id=aws_access_key_id,
178 | aws_secret_access_key=aws_secret_access_key,
179 | aws_session_token=None,
180 | # verify=False,
181 | )
182 | result = {
183 | # don't link to actual container, and don't rm it at end
184 | "container": "global",
185 | "stop": lambda: print(),
186 | # "port": port,
187 | "endpoint_url": endpoint_url,
188 | "aws_access_key_id": aws_access_key_id,
189 | "aws_secret_access_key": aws_secret_access_key,
190 | "aws_s3_default_bucket": aws_s3_default_bucket,
191 | "s3": s3,
192 | }
193 | _minioCache.update({id: result})
194 | return Namespace(**result)
195 | else:
196 | print("Creating new global minio")
197 |
198 | port = get_free_port()
199 |
200 | def onstop():
201 | del _minioCache[id]
202 |
203 | container, stop = startContainer(
204 | "minio/minio",
205 | "server /data --console-address :9001",
206 | ports={9000: port},
207 | onstop=onstop,
208 | )
209 |
210 | endpoint_url = f"http://{DOCKER_GW_IP}:{port}"
211 |
212 | while True:
213 | time.sleep(1)
214 | response = None
215 | try:
216 | print(endpoint_url + "/minio/health/live")
217 | response = requests.get(endpoint_url + "/minio/health/live")
218 | except Exception as error:
219 | print(error)
220 |
221 | if response and response.status_code == 200:
222 | break
223 |
224 | aws_access_key_id = "minioadmin"
225 | aws_secret_access_key = "minioadmin"
226 | aws_s3_default_bucket = AWS_S3_DEFAULT_BUCKET
227 | s3 = boto3.client(
228 | "s3",
229 | endpoint_url=endpoint_url,
230 | config=boto3.session.Config(signature_version="s3v4"),
231 | aws_access_key_id=aws_access_key_id,
232 | aws_secret_access_key=aws_secret_access_key,
233 | aws_session_token=None,
234 | # verify=False,
235 | )
236 |
237 | s3.create_bucket(Bucket=AWS_S3_DEFAULT_BUCKET)
238 |
239 | result = {
240 | "container": container,
241 | "stop": stop,
242 | "port": port,
243 | "endpoint_url": endpoint_url,
244 | "aws_access_key_id": aws_access_key_id,
245 | "aws_secret_access_key": aws_secret_access_key,
246 | "aws_s3_default_bucket": aws_s3_default_bucket,
247 | "s3": s3,
248 | }
249 | _minioCache.update({id: result})
250 | return Namespace(**result)
251 |
252 |
253 | _ddaCache = None
254 |
255 |
256 | def getDDA(
257 | minio=None,
258 | command=None,
259 | environment={},
260 | stream_logs=False,
261 | wait=True,
262 | root_cache=True,
263 | **kwargs,
264 | ):
265 | global _ddaCache
266 | if _ddaCache:
267 | print("return _ddaCache")
268 | return Namespace(**_ddaCache)
269 | else:
270 | print("create new _dda")
271 |
272 | port = get_free_port()
273 |
274 | environment.update(
275 | {
276 | "HF_AUTH_TOKEN": os.getenv("HF_AUTH_TOKEN"),
277 | "http_proxy": os.getenv("DDA_http_proxy"),
278 | "https_proxy": os.getenv("DDA_https_proxy"),
279 | "REQUESTS_CA_BUNDLE": os.getenv("DDA_http_proxy")
280 | and "/usr/local/share/ca-certificates/squid-self-signed.crt",
281 | }
282 | )
283 |
284 | if minio:
285 | environment.update(
286 | {
287 | "AWS_ACCESS_KEY_ID": minio.aws_access_key_id,
288 | "AWS_SECRET_ACCESS_KEY": minio.aws_secret_access_key,
289 | "AWS_DEFAULT_REGION": "",
290 | "AWS_S3_DEFAULT_BUCKET": minio.aws_s3_default_bucket,
291 | "AWS_S3_ENDPOINT_URL": minio.endpoint_url,
292 | }
293 | )
294 |
295 | def onstop():
296 | global _ddaCache
297 | _ddaCache = None
298 |
299 | HOME = os.getenv("HOME")
300 |
301 | container, stop = startContainer(
302 | "gadicc/diffusers-api:test",
303 | command,
304 | stream_logs=stream_logs,
305 | ports={8000: port},
306 | device_requests=[docker.types.DeviceRequest(count=-1, capabilities=[["gpu"]])],
307 | environment=environment,
308 | volumes=root_cache and [f"{HOME}/root-cache:/root/.cache"],
309 | onstop=onstop,
310 | **kwargs,
311 | )
312 |
313 | url = f"http://{DOCKER_GW_IP}:{port}/"
314 |
315 | while wait:
316 | time.sleep(1)
317 | container.reload()
318 | if container.status == "exited":
319 | if not stream_logs:
320 | print("--- EARLY EXIT ---")
321 | print(container.logs().decode())
322 | print("--- EARLY EXIT ---")
323 | raise Exception("Early exit before successful healthcheck")
324 |
325 | response = None
326 | try:
327 | # print(url + "healthcheck")
328 | response = requests.get(url + "healthcheck")
329 | except Exception as error:
330 | # print(error)
331 | continue
332 |
333 | if response:
334 | if response.status_code == 200:
335 | result = response.json()
336 | if result["state"] == "healthy" and result["gpu"] == True:
337 | print("Ready")
338 | break
339 | else:
340 | print(response)
341 | print(response.text)
342 | else:
343 | raise Exception("Unexpected status code from dda/healthcheck")
344 |
345 | data = {
346 | "container": container,
347 | "stop": stop,
348 | "minio": minio,
349 | "port": port,
350 | "url": url,
351 | }
352 |
353 | _ddaCache = data
354 | return Namespace(**data)
355 |
356 |
357 | def cleanup():
358 | print("cleanup")
359 | for container in myContainers:
360 | print("Stopping")
361 | print(container)
362 | container.stop()
363 | print("removing")
364 | container.remove()
365 |
366 |
367 | atexit.register(cleanup)
368 |
--------------------------------------------------------------------------------
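Note on /tests/integration/lib.py above: both health-wait loops (the MinIO /minio/health/live poll and the DDA /healthcheck poll) keep retrying until the service answers with HTTP 200, so a container that never becomes healthy can stall the test run rather than fail it. A minimal sketch of a bounded variant, assuming the same requests dependency; wait_for_http and the timeout values are illustrative names, not part of the repository:

import time
import requests

def wait_for_http(url, timeout=300, interval=1):
    """Poll `url` until it returns HTTP 200, or raise after `timeout` seconds."""
    deadline = time.monotonic() + timeout
    while time.monotonic() < deadline:
        try:
            if requests.get(url).status_code == 200:
                return
        except requests.RequestException:
            pass  # service not accepting connections yet; keep retrying
        time.sleep(interval)
    raise TimeoutError(f"{url} did not become healthy within {timeout}s")
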
/tests/integration/requirements.txt:
--------------------------------------------------------------------------------
1 | pytest==7.2.0
2 | docker==6.0.1
3 | boto3==1.26.44
4 | Pillow==9.4.0
5 | # work around breaking changes in urllib3 2.0
6 | # until https://github.com/docker/docker-py/pull/3114/files lands
7 | urllib3<2
8 |
--------------------------------------------------------------------------------
/tests/integration/test_attn_procs.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import os
3 | from .lib import getMinio, getDDA
4 | from test import runTest
5 |
6 | if False:
7 |
8 | class TestAttnProcs:
9 | def setup_class(self):
10 | print("setup_class")
11 | # self.minio = minio = getMinio("global")
12 |
13 | self.dda = dda = getDDA(
14 | # minio=minio
15 | stream_logs=True,
16 | )
17 | print(dda)
18 |
19 | self.TEST_ARGS = {"test_url": dda.url}
20 |
21 | def teardown_class(self):
22 | print("teardown_class")
23 | # self.minio.stop() - leave global up
24 | self.dda.stop()
25 |
26 | def test_lora_hf_download(self):
27 | """
28 | Download user/repo from HuggingFace.
29 | """
30 | # use fp16 here; the fp32 model is obviously bigger
31 | result = runTest(
32 | "txt2img",
33 | self.TEST_ARGS,
34 | {
35 | "MODEL_ID": "runwayml/stable-diffusion-v1-5",
36 | "MODEL_REVISION": "fp16",
37 | "MODEL_PRECISION": "fp16",
38 | "attn_procs": "patrickvonplaten/lora_dreambooth_dog_example",
39 | },
40 | {
41 | "num_inference_steps": 1,
42 | "prompt": "A picture of a sks dog in a bucket",
43 | "seed": 1,
44 | "cross_attention_kwargs": {"scale": 0.5},
45 | },
46 | )
47 |
48 | assert result["image_base64"]
49 |
50 | def test_lora_http_download_pytorch_bin(self):
51 | """
52 | Download pytorch_lora_weights.bin directly.
53 | """
54 | result = runTest(
55 | "txt2img",
56 | self.TEST_ARGS,
57 | {
58 | "MODEL_ID": "runwayml/stable-diffusion-v1-5",
59 | "MODEL_REVISION": "fp16",
60 | "MODEL_PRECISION": "fp16",
61 | "attn_procs": "https://huggingface.co/patrickvonplaten/lora_dreambooth_dog_example/resolve/main/pytorch_lora_weights.bin",
62 | },
63 | {
64 | "num_inference_steps": 1,
65 | "prompt": "A picture of a sks dog in a bucket",
66 | "seed": 1,
67 | "cross_attention_kwargs": {"scale": 0.5},
68 | },
69 | )
70 |
71 | assert result["image_base64"]
72 |
73 | if False:
74 | # These formats are not supported by diffusers yet :(
75 | def test_lora_http_download_civitai_safetensors(self):
76 | result = runTest(
77 | "txt2img",
78 | self.TEST_ARGS,
79 | {
80 | "MODEL_ID": "runwayml/stable-diffusion-v1-5",
81 | "MODEL_REVISION": "fp16",
82 | "MODEL_PRECISION": "fp16",
83 | "attn_procs": "https://civitai.com/api/download/models/11523",
84 | "attn_procs_from_safetensors": True,
85 | },
86 | {
87 | "num_inference_steps": 1,
88 | "prompt": "A picture of a sks dog in a bucket",
89 | "seed": 1,
90 | },
91 | )
92 |
93 | assert result["image_base64"]
94 |
--------------------------------------------------------------------------------
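The `if False:` guards in test_attn_procs.py above (and in test_loras.py below) disable tests by never defining the class or function, so the tests also disappear from pytest's collection report. If the goal is to keep them visible but skipped, pytest's built-in skip marker is one alternative; a rough sketch, with an illustrative reason string:

import pytest

@pytest.mark.skip(reason="attn_procs / LoRA format not supported by diffusers yet")
class TestAttnProcs:
    def test_lora_hf_download(self):
        ...

Skipped tests then still show up as "s" in the run summary instead of vanishing entirely.
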
/tests/integration/test_build_download.py:
--------------------------------------------------------------------------------
1 | import sys
2 | from .lib import getMinio, getDDA
3 | from test import runTest
4 |
5 |
6 | def test_cloudcache_build_download():
7 | """
8 | Download a model from cloud-cache at build time (no HuggingFace)
9 | """
10 | minio = getMinio()
11 | print(minio)
12 | environment = {
13 | "RUNTIME_DOWNLOADS": 0,
14 | "MODEL_ID": "stabilityai/stable-diffusion-2-1-base",
15 | "MODEL_PRECISION": "fp16",
16 | "MODEL_REVISION": "fp16",
17 | "MODEL_URL": "s3://", # <--
18 | }
19 | # conda = "conda run --no-capture-output -n xformers"
20 | conda = ""
21 | dda = getDDA(
22 | minio=minio,
23 | stream_logs=True,
24 | environment=environment,
25 | root_cache=False,
26 | command=[
27 | "sh",
28 | "-c",
29 | f"{conda} python3 -u download.py && ls -l && {conda} python3 -u server.py",
30 | ],
31 | )
32 | print(dda)
33 | assert dda.container.status == "running"
34 |
35 | ## bucket.objects.all().delete()
36 | result = runTest(
37 | "txt2img",
38 | {"test_url": dda.url},
39 | {
40 | "MODEL_ID": "stabilityai/stable-diffusion-2-1-base",
41 | },
42 | {"num_inference_steps": 1},
43 | )
44 |
45 | dda.stop()
46 | minio.stop()
47 | assert result["image_base64"]
48 | print("test successs\n\n")
49 |
50 |
51 | def test_huggingface_build_download():
52 | """
53 | Download a model from HuggingFace at build time (no cloud-cache)
54 | NOTE / TODO: Good starting point, but this still runs with gpu and
55 | uploads if missing.
56 | """
57 | environment = {
58 | "RUNTIME_DOWNLOADS": 0,
59 | "MODEL_ID": "stabilityai/stable-diffusion-2-1-base",
60 | "MODEL_PRECISION": "fp16",
61 | "MODEL_REVISION": "fp16",
62 | }
63 | # conda = "conda run --no-capture-output -n xformers"
64 | conda = ""
65 | dda = getDDA(
66 | stream_logs=True,
67 | environment=environment,
68 | root_cache=False,
69 | command=[
70 | "sh",
71 | "-c",
72 | f"{conda} python3 -u download.py && ls -l && {conda} python3 -u server.py",
73 | ],
74 | )
75 | print(dda)
76 | assert dda.container.status == "running"
77 |
78 | ## bucket.objects.all().delete()
79 | result = runTest(
80 | "txt2img",
81 | {"test_url": dda.url},
82 | {
83 | "MODEL_ID": "stabilityai/stable-diffusion-2-1-base",
84 | # "MODEL_ID": "hf-internal-testing/tiny-stable-diffusion-pipe",
85 | "MODEL_PRECISION": "fp16",
86 | "MODEL_REVISION": "fp16",
87 | "MODEL_URL": "", # <-- no model_url, i.e. no cloud cache
88 | },
89 | {"num_inference_steps": 1},
90 | )
91 | dda.stop()
92 |
93 | assert result["image_base64"]
94 | print("test successs\n\n")
95 |
96 |
97 | def test_checkpoint_url_build_download():
98 | """
99 | Download and convert a .ckpt at build time. No cloud-cache.
100 | """
101 | environment = {
102 | "RUNTIME_DOWNLOADS": 0,
103 | "MODEL_ID": "hakurei/waifu-diffusion-v1-3",
104 | "MODEL_PRECISION": "fp16",
105 | "MODEL_REVISION": "fp16",
106 | "CHECKPOINT_URL": "https://huggingface.co/hakurei/waifu-diffusion-v1-3/resolve/main/wd-v1-3-float16.ckpt",
107 | }
108 | # conda = "conda run --no-capture-output -n xformers"
109 | conda = ""
110 | dda = getDDA(
111 | stream_logs=True,
112 | environment=environment,
113 | root_cache=False,
114 | command=[
115 | "sh",
116 | "-c",
117 | f"{conda} python3 -u download.py && ls -l && {conda} python3 -u server.py",
118 | ],
119 | )
120 | print(dda)
121 | assert dda.container.status == "running"
122 |
123 | ## bucket.objects.all().delete()
124 | result = runTest(
125 | "txt2img",
126 | {"test_url": dda.url},
127 | {
128 | "MODEL_ID": "hakurei/waifu-diffusion-v1-3",
129 | "MODEL_PRECISION": "fp16",
130 | "MODEL_URL": "", # <-- no model_url, i.e. no cloud cache
131 | },
132 | {"num_inference_steps": 1},
133 | )
134 | dda.stop()
135 |
136 | assert result["image_base64"]
137 | print("test successs\n\n")
138 |
--------------------------------------------------------------------------------
/tests/integration/test_cloud_cache.py:
--------------------------------------------------------------------------------
1 | import sys
2 | from .lib import getMinio, getDDA
3 | from test import runTest
4 |
5 |
6 | def test_cloud_cache_create_and_upload():
7 | """
8 | Check whether the model exists in the cloud-cache bucket; if not, download
9 | it, save it with safetensors, and upload model.tar.zst to the bucket.
10 | """
11 | minio = getMinio()
12 | print(minio)
13 | dda = getDDA(minio=minio, stream_logs=True, root_cache=False)
14 | print(dda)
15 |
16 | ## bucket.objects.all().delete()
17 | result = runTest(
18 | "txt2img",
19 | {"test_url": dda.url},
20 | {
21 | "MODEL_ID": "stabilityai/stable-diffusion-2-1-base",
22 | # "MODEL_ID": "hf-internal-testing/tiny-stable-diffusion-pipe",
23 | "MODEL_PRECISION": "fp16",
24 | "MODEL_REVISION": "fp16",
25 | "MODEL_URL": "s3://",
26 | },
27 | {"num_inference_steps": 1},
28 | )
29 |
30 | dda.stop()
31 | minio.stop()
32 | timings = result["$timings"]
33 | assert timings["download"] > 0
34 | assert timings["upload"] > 0
35 |
--------------------------------------------------------------------------------
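test_cloud_cache.py asserts on the reported download and upload timings. To also confirm that the upload actually left an archive in the bucket, the boto3 client returned by getMinio can be queried before minio.stop(); a rough sketch (the helper name is illustrative, and no assumption is made about the exact object key the cloud-cache upload writes):

def assert_bucket_not_empty(minio):
    # List whatever the cloud-cache upload wrote to the default bucket.
    listing = minio.s3.list_objects_v2(Bucket=minio.aws_s3_default_bucket)
    assert listing.get("KeyCount", 0) > 0, "expected at least one uploaded object"

Called between runTest(...) and minio.stop() in the test above, this fails loudly if the upload silently produced nothing.
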
/tests/integration/test_dreambooth.py:
--------------------------------------------------------------------------------
1 | import os
2 | from .lib import getMinio, getDDA
3 | from test import runTest
4 |
5 | HF_USERNAME = os.getenv("HF_USERNAME", "gadicc")
6 |
7 |
8 | class TestDreamBoothS3:
9 | """
10 | Train/Infer via S3 model save.
11 | """
12 |
13 | def setup_class(self):
14 | print("setup_class")
15 | self.minio = getMinio("global")
16 |
17 | def teardown_class(self):
18 | print("teardown_class")
19 | # self.minio.stop() # leave global up.
20 |
21 | def test_training_s3(self):
22 | dda = getDDA(
23 | minio=self.minio,
24 | stream_logs=True,
25 | )
26 | print(dda)
27 |
28 | result = runTest(
29 | "dreambooth",
30 | {"test_url": dda.url},
31 | {
32 | "MODEL_ID": "stabilityai/stable-diffusion-2-1-base",
33 | "MODEL_REVISION": "",
34 | "MODEL_PRECISION": "",
35 | "MODEL_URL": "s3://",
36 | "train": "dreambooth",
37 | "dest_url": f"s3:///{self.minio.aws_s3_default_bucket}/model.tar.zst",
38 | },
39 | {"max_train_steps": 1},
40 | )
41 |
42 | dda.stop()
43 | timings = result["$timings"]
44 | assert timings["training"] > 0
45 | assert timings["upload"] > 0
46 |
47 | # dependent on above, TODO, mark as such.
48 | def test_s3_download_and_inference(self):
49 | dda = getDDA(
50 | minio=self.minio,
51 | stream_logs=True,
52 | root_cache=False,
53 | )
54 | print(dda)
55 |
56 | result = runTest(
57 | "txt2img",
58 | {"test_url": dda.url},
59 | {
60 | "MODEL_ID": "model",
61 | "MODEL_PRECISION": "fp16",
62 | "MODEL_URL": f"s3:///{self.minio.aws_s3_default_bucket}/model.tar.zst",
63 | },
64 | {"num_inference_steps": 1},
65 | )
66 |
67 | dda.stop()
68 | assert result["image_base64"]
69 |
70 |
71 | if os.getenv("TEST_DREAMBOOTH_HF", None):
72 |
73 | class TestDreamBoothHF:
74 | def test_training_hf(self):
75 | dda = getDDA(
76 | stream_logs=True,
77 | )
78 | print(dda)
79 |
80 | result = runTest(
81 | "dreambooth",
82 | {"test_url": dda.url},
83 | {
84 | "MODEL_ID": "stabilityai/stable-diffusion-2-1-base",
85 | "MODEL_REVISION": "",
86 | "MODEL_PRECISION": "",
87 | "MODEL_URL": "s3://",
88 | "train": "dreambooth",
89 | },
90 | {
91 | "hub_model_id": f"{HF_USERNAME}/dreambooth_test",
92 | "push_to_hub": True,
93 | "max_train_steps": 1,
94 | },
95 | )
96 |
97 | dda.stop()
98 | timings = result["$timings"]
99 | assert timings["training"] > 0
100 | assert timings["upload"] > 0
101 |
102 | # dependent on above, TODO, mark as such.
103 | def test_hf_download_and_inference(self):
104 | dda = getDDA(
105 | stream_logs=True,
106 | root_cache=False,
107 | )
108 | print(dda)
109 |
110 | result = runTest(
111 | "txt2img",
112 | {"test_url": dda.url},
113 | {
114 | "MODEL_ID": f"{HF_USERNAME}/dreambooth_test",
115 | "MODEL_PRECISION": "fp16",
116 | },
117 | {"num_inference_steps": 1},
118 | )
119 |
120 | dda.stop()
121 | assert result["image_base64"]
122 |
123 | else:
124 |
125 | print(
126 | "Skipping dreambooth HuggingFace upload/download tests by default\n"
127 | "as they can be flaky. To run, set env var TEST_DREAMBOOTH_HF=1"
128 | )
129 |
--------------------------------------------------------------------------------
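Both test classes in test_dreambooth.py carry a "dependent on above, TODO, mark as such" comment: each inference test only makes sense after the preceding training test has uploaded a model. One way to make that ordering explicit is the third-party pytest-dependency plugin (not currently listed in tests/integration/requirements.txt); a sketch of how the S3 pair might be marked, assuming the plugin's default class-qualified test names:

import pytest

class TestDreamBoothS3:
    @pytest.mark.dependency()
    def test_training_s3(self):
        ...  # train and upload model.tar.zst, as in the file above

    @pytest.mark.dependency(depends=["TestDreamBoothS3::test_training_s3"])
    def test_s3_download_and_inference(self):
        ...  # skipped automatically if the training test failed or was skipped
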
/tests/integration/test_general.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import os
3 | from .lib import getMinio, getDDA
4 | from test import runTest
5 |
6 |
7 | class TestGeneralClass:
8 | """
9 | Typical usage tests, that assume model is already available locally.
10 | txt2img, img2img, inpaint.
11 | """
12 |
13 | CALL_ARGS = {
14 | "MODEL_ID": "stabilityai/stable-diffusion-2-1-base",
15 | "MODEL_PRECISION": "fp16",
16 | "MODEL_REVISION": "fp16",
17 | "MODEL_URL": "s3://",
18 | }
19 |
20 | MODEL_ARGS = {"num_inference_steps": 2}
21 |
22 | def setup_class(self):
23 | print("setup_class")
24 | self.minio = minio = getMinio("global")
25 |
26 | self.dda = dda = getDDA(
27 | minio=minio
28 | # stream_logs=True,
29 | )
30 | print(dda)
31 |
32 | self.TEST_ARGS = {"test_url": dda.url}
33 |
34 | def teardown_class(self):
35 | print("teardown_class")
36 | # self.minio.stop() - leave global up
37 | self.dda.stop()
38 |
39 | def test_txt2img(self):
40 | result = runTest("txt2img", self.TEST_ARGS, self.CALL_ARGS, self.MODEL_ARGS)
41 | assert result["image_base64"]
42 |
43 | def test_img2img(self):
44 | result = runTest("img2img", self.TEST_ARGS, self.CALL_ARGS, self.MODEL_ARGS)
45 | assert result["image_base64"]
46 |
47 | # def test_inpaint(self):
48 | # """
49 | # This is actually calling inpaint with SDv2.1, not the inpainting model,
50 | # so I guess we're testing inpaint-legacy.
51 | # """
52 | # result = runTest("inpaint", self.TEST_ARGS, self.CALL_ARGS, self.MODEL_ARGS)
53 | # assert result["image_base64"]
54 |
--------------------------------------------------------------------------------
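The commented-out inpaint test in test_general.py notes that it exercises inpaint-legacy, since it runs against the SD 2.1 base model rather than a dedicated inpainting checkpoint. If the dedicated model is wanted instead, a variant along these lines could work; a sketch only, untested here, assuming runTest's existing "inpaint" path and that stabilityai/stable-diffusion-2-inpainting can be loaded the same way as the base model:

    def test_inpaint_dedicated_model(self):
        # Same shape as the other tests in this class, but pointed at the
        # dedicated SD 2 inpainting model rather than the base checkpoint.
        call_args = dict(self.CALL_ARGS, MODEL_ID="stabilityai/stable-diffusion-2-inpainting")
        result = runTest("inpaint", self.TEST_ARGS, call_args, self.MODEL_ARGS)
        assert result["image_base64"]
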
/tests/integration/test_loras.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import os
3 | from .lib import getMinio, getDDA
4 | from test import runTest
5 |
6 |
7 | class TestLoRAs:
8 | def setup_class(self):
9 | print("setup_class")
10 | # self.minio = minio = getMinio("global")
11 |
12 | self.dda = dda = getDDA(
13 | # minio=minio
14 | stream_logs=True,
15 | )
16 | print(dda)
17 |
18 | self.TEST_ARGS = {"test_url": dda.url}
19 |
20 | def teardown_class(self):
21 | print("teardown_class")
22 | # self.minio.stop() - leave global up
23 | self.dda.stop()
24 |
25 | if False:
26 |
27 | def test_lora_hf_download(self):
28 | """
29 | Download user/repo from HuggingFace.
30 | """
31 | # use fp16 here; the fp32 model is obviously bigger
32 | result = runTest(
33 | "txt2img",
34 | self.TEST_ARGS,
35 | {
36 | "MODEL_ID": "runwayml/stable-diffusion-v1-5",
37 | "MODEL_REVISION": "fp16",
38 | "MODEL_PRECISION": "fp16",
39 | "attn_procs": "patrickvonplaten/lora_dreambooth_dog_example",
40 | },
41 | {
42 | "num_inference_steps": 1,
43 | "prompt": "A picture of a sks dog in a bucket",
44 | "seed": 1,
45 | "cross_attention_kwargs": {"scale": 0.5},
46 | },
47 | )
48 |
49 | assert result["image_base64"]
50 |
51 | if False:
52 |
53 | def test_lora_http_download_pytorch_bin(self):
54 | """
55 | Download pytorch_lora_weights.bin directly.
56 | """
57 | result = runTest(
58 | "txt2img",
59 | self.TEST_ARGS,
60 | {
61 | "MODEL_ID": "runwayml/stable-diffusion-v1-5",
62 | "MODEL_REVISION": "fp16",
63 | "MODEL_PRECISION": "fp16",
64 | "attn_procs": "https://huggingface.co/patrickvonplaten/lora_dreambooth_dog_example/resolve/main/pytorch_lora_weights.bin",
65 | },
66 | {
67 | "num_inference_steps": 1,
68 | "prompt": "A picture of a sks dog in a bucket",
69 | "seed": 1,
70 | "cross_attention_kwargs": {"scale": 0.5},
71 | },
72 | )
73 |
74 | assert result["image_base64"]
75 |
76 | # These formats are not supported by diffusers yet :(
77 | def test_lora_http_download_civitai_safetensors(self):
78 | quickTest = True
79 |
80 | callInputs = {
81 | "MODEL_ID": "NED-v1-22",
82 | # https://civitai.com/models/10028/neverending-dream-ned?modelVersionId=64094
83 | "CHECKPOINT_URL": "https://civitai.com/api/download/models/64094#fname=neverendingDreamNED_v122BakedVae.safetensors",
84 | "MODEL_PRECISION": "fp16",
85 | # https://civitai.com/models/5373/makima-chainsaw-man-lora
86 | "lora_weights": "https://civitai.com/api/download/models/6244#fname=makima_offset.safetensors",
87 | "safety_checker": False,
88 | "PIPELINE": "lpw_stable_diffusion",
89 | }
90 | modelInputs = {
91 | # https://civitai.com/images/709482
92 | "num_inference_steps": 30,
93 | "prompt": "masterpiece, (photorealistic:1.4), best quality, beautiful lighting, (ulzzang-6500:0.5), makima \(chainsaw man\), (red hair)+(long braided hair)+(bangs), yellow eyes, golden eyes, ((ringed eyes)), (white shirt), (necktie), RAW photo, 8k uhd, film grain",
94 | "negative_prompt": "(painting by bad-artist-anime:0.9), (painting by bad-artist:0.9), watermark, text, error, blurry, jpeg artifacts, cropped, worst quality, low quality, normal quality, jpeg artifacts, signature, watermark, username, artist name, (worst quality, low quality:1.4), bad anatomy",
95 | "width": 864,
96 | "height": 1304,
97 | "seed": 2281759351,
98 | "guidance_scale": 9,
99 | }
100 |
101 | if quickTest:
102 | callInputs.update(
103 | {
104 | # i.e. use a model we already have
105 | "MODEL_ID": "runwayml/stable-diffusion-v1-5",
106 | "MODEL_REVISION": "fp16",
107 | "CHECKPOINT_URL": None,
108 | }
109 | )
110 | modelInputs.update(
111 | {
112 | "num_inference_steps": 1,
113 | "width": 512,
114 | "height": 512,
115 | }
116 | )
117 | result = runTest("txt2img", self.TEST_ARGS, callInputs, modelInputs)
118 |
119 | assert result["image_base64"]
120 |
--------------------------------------------------------------------------------
/tests/integration/test_memory.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import os
3 | from .lib import getMinio, getDDA
4 | from test import runTest
5 |
6 |
7 | def test_memory():
8 | """
9 | Make sure when switching models we release VRAM afterwards.
10 | """
11 | minio = getMinio("global")
12 | dda = getDDA(
13 | minio=minio,
14 | stream_logs=True,
15 | )
16 | print(dda)
17 |
18 | TEST_ARGS = {"test_url": dda.url}
19 | MODEL_ARGS = {"num_inference_steps": 1}
20 |
21 | mem_usage = list()
22 |
23 | # fp32 model is obviously bigger
24 | result = runTest(
25 | "txt2img",
26 | TEST_ARGS,
27 | {
28 | "MODEL_ID": "stabilityai/stable-diffusion-2-1-base",
29 | "MODEL_REVISION": "", # <--
30 | "MODEL_PRECISION": "", # <--
31 | "MODEL_URL": "s3://",
32 | },
33 | MODEL_ARGS,
34 | )
35 | mem_usage.append(result["$mem_usage"])
36 |
37 | # fp16 model is obviously smaller
38 | result = runTest(
39 | "txt2img",
40 | TEST_ARGS,
41 | {
42 | "MODEL_ID": "stabilityai/stable-diffusion-2-1-base",
43 | "MODEL_REVISION": "fp16", # <--
44 | "MODEL_PRECISION": "fp16", # <--
45 | "MODEL_URL": "s3://",
46 | },
47 | MODEL_ARGS,
48 | )
49 | mem_usage.append(result["$mem_usage"])
50 |
51 | print({"mem_usage": mem_usage})
52 | assert mem_usage[1] < mem_usage[0]
53 |
54 | dda.stop()
55 |
--------------------------------------------------------------------------------
/touch:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kiri-art/docker-diffusers-api/5521b2e6d63ef7afa9c3f7f3ef6cf520c1b54d7b/touch
--------------------------------------------------------------------------------
/update.sh:
--------------------------------------------------------------------------------
1 | #!/bin/sh
2 |
3 | rsync -avzPe "ssh -p $1" api/ $2:/api/
4 |
--------------------------------------------------------------------------------