├── .github ├── ISSUE_TEMPLATE │ ├── bug_report.md │ └── feature_request.md ├── dependabot.yml ├── solutionid_validator.sh └── workflows │ ├── maintainer_workflows.yml │ └── pr-triage.yml ├── .gitignore ├── .gitmodules ├── .npmignore ├── CODEOWNERS ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── LICENSE ├── NOTICE ├── README.md ├── bin └── stable-difussion-on-eks.ts ├── cdk.json ├── config.schema.yaml ├── config.yaml ├── deploy ├── README.md ├── config.yaml.template ├── deploy.sh ├── install-tools.sh ├── update.sh └── upload-model.sh ├── docs ├── Gemfile ├── Gemfile.lock ├── README.md ├── api │ ├── v1alpha1.yaml │ └── v1alpha2.yaml ├── en │ ├── _config.yml │ ├── _includes │ │ └── image.html │ ├── async_img_sd_IG.md │ ├── async_img_sd_images │ │ ├── .gitkeep │ │ └── asynchronous-image-generation-with-stable-diffusion-on-aws.png │ └── async_img_sd_sidebar.yml ├── stable-diffusion-reference-architecture-updated.png └── zh │ ├── _config.yml │ ├── _includes │ └── image.html │ ├── async_img_sd_zh_IG.md │ ├── async_img_sd_zh_images │ ├── .gitkeep │ └── stable-diffusion-reference-architecture-updated.png │ └── async_img_sd_zh_sidebar.yml ├── lib ├── addons │ ├── dcgmExporter.ts │ ├── ebsThroughputTuner.ts │ ├── s3CSIDriver.ts │ └── sharedComponent.ts ├── dataPlane.ts ├── resourceProvider │ ├── s3GWEndpoint.ts │ ├── sns.ts │ └── vpc.ts ├── runtime │ └── sdRuntime.ts └── utils │ ├── namespace.ts │ └── validateConfig.ts ├── load_test └── locustfile.py ├── package-lock.json ├── package.json ├── src ├── backend │ └── queue_agent │ │ ├── .gitignore │ │ ├── Dockerfile │ │ ├── requirements.txt │ │ └── src │ │ ├── __init__.py │ │ ├── main.py │ │ ├── modules │ │ ├── __init__.py │ │ ├── http_action.py │ │ ├── misc.py │ │ ├── s3_action.py │ │ ├── sns_action.py │ │ ├── sqs_action.py │ │ └── time_utils.py │ │ └── runtimes │ │ ├── __init__.py │ │ ├── comfyui.py │ │ └── sdwebui.py ├── charts │ └── sd_on_eks │ │ ├── Chart.yaml │ │ ├── _helpers.tpl │ │ ├── templates │ │ ├── _helpers.tpl │ │ ├── aws-sqs-queue-scaledobject.yaml │ │ ├── configmap-comfyui.yml │ │ ├── configmap-queue-agent.yml │ │ ├── configmap-sdwebui.yml │ │ ├── deployment-comfyui.yaml │ │ ├── deployment-sdwebui.yaml │ │ ├── keda-trigger-auth-aws-credentials.yaml │ │ ├── nodeclass.yaml │ │ ├── nodepool.yaml │ │ ├── persistentvolume-s3.yaml │ │ └── persistentvolumeclaim.yaml │ │ └── values.yaml ├── frontend │ └── input_function │ │ ├── v1alpha1 │ │ └── app.py │ │ └── v1alpha2 │ │ └── app.py └── tools │ └── ebs_throughput_tuner │ └── app.py ├── test ├── run.sh ├── v1alpha1 │ ├── i2i.json │ ├── t2i.json │ └── t2v.json └── v1alpha2 │ ├── extra-single-image.json │ ├── i2i.json │ ├── pipeline.json │ └── t2i.json └── tsconfig.json /.github/ISSUE_TEMPLATE/bug_report.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Bug report 3 | about: Create a report to help us improve 4 | title: "[BUG]" 5 | labels: bug 6 | assignees: '' 7 | 8 | --- 9 | 10 | **What happened**: 11 | 14 | 15 | **To Reproduce** 16 | Steps to reproduce the behavior: 17 | 18 | **Attach logs** 19 | 20 | **What you expected to happen**: 21 | 22 | **Anything else we need to know?**: 23 | 24 | **Environment**: 25 | 28 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature_request.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Feature request 3 | about: Suggest an idea for this project 4 | title: "[Feature]" 5 | labels: enhancement 6 | 
assignees: '' 7 | 8 | --- 9 | 10 | **Is your feature request related to a problem? Please describe.** 11 | A clear and concise description of what the problem is. Ex. I'm always frustrated when [...] 12 | 13 | **Describe the solution you'd like** 14 | A clear and concise description of what you want to happen. 15 | 16 | **Describe alternatives you've considered** 17 | A clear and concise description of any alternative solutions or features you've considered. 18 | 19 | **Additional context** 20 | Add any other context or screenshots about the feature request here. 21 | -------------------------------------------------------------------------------- /.github/dependabot.yml: -------------------------------------------------------------------------------- 1 | version: 2 2 | updates: 3 | - package-ecosystem: "npm" 4 | directory: "/" 5 | schedule: 6 | interval: "monthly" 7 | ignore: 8 | - dependency-name: "*" 9 | update-types: ["version-update:semver-patch", "version-update:semver-minor"] 10 | - dependency-name: "aws-cdk-lib" 11 | labels: 12 | - "kind/dependencies" 13 | - "component/cdk" 14 | 15 | - package-ecosystem: "pip" 16 | directory: "/src/backend/queue_agent" 17 | schedule: 18 | interval: "monthly" 19 | ignore: 20 | - dependency-name: "*" 21 | update-types: ["version-update:semver-patch", "version-update:semver-minor"] 22 | labels: 23 | - "kind/dependencies" 24 | - "component/queue-agent" 25 | -------------------------------------------------------------------------------- /.github/solutionid_validator.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | #set -e 3 | 4 | echo "checking solution id $1" 5 | echo "grep -nr --exclude-dir='.github' "$1" ./.." 6 | result=$(grep -nr --exclude-dir='.github' "$1" ./..) 7 | if [ $? 
-eq 0 ] 8 | then 9 | echo "Solution ID $1 found\n" 10 | echo "$result" 11 | exit 0 12 | else 13 | echo "Solution ID $1 not found" 14 | exit 1 15 | fi 16 | 17 | export result 18 | -------------------------------------------------------------------------------- /.github/workflows/maintainer_workflows.yml: -------------------------------------------------------------------------------- 1 | # Workflows managed by aws-solutions-library-samples maintainers 2 | name: Maintainer Workflows 3 | on: 4 | # Triggers the workflow on push or pull request events but only for the "main" branch 5 | push: 6 | branches: [ "main" ] 7 | pull_request: 8 | branches: [ "main" ] 9 | types: [opened, reopened, edited] 10 | 11 | jobs: 12 | CheckSolutionId: 13 | runs-on: ubuntu-latest 14 | steps: 15 | - uses: actions/checkout@v4 16 | - name: Run solutionid validator 17 | run: | 18 | chmod u+x ./.github/solutionid_validator.sh 19 | ./.github/solutionid_validator.sh ${{ vars.SOLUTIONID }} -------------------------------------------------------------------------------- /.github/workflows/pr-triage.yml: -------------------------------------------------------------------------------- 1 | name: Auto Approve document-related PR 2 | 3 | on: 4 | pull_request_target: 5 | types: [opened, synchronize] 6 | 7 | jobs: 8 | auto-approve: 9 | runs-on: ubuntu-latest 10 | permissions: 11 | pull-requests: write 12 | steps: 13 | - uses: actions/checkout@v4 14 | - name: Get PR author 15 | id: get-author 16 | run: | 17 | PR_AUTHOR=$(jq -r '.pull_request.user.login' "$GITHUB_EVENT_PATH") 18 | echo "PR author: $PR_AUTHOR" 19 | echo "pr-author=$PR_AUTHOR" >> $GITHUB_OUTPUT 20 | - name: Check author permission 21 | id: author-permission 22 | uses: actions-cool/check-user-permission@v2 23 | with: 24 | username: ${{ steps.get-author.outputs.pr-author }} 25 | require: write 26 | - name: Auto-approve for break glass PR 27 | id: break-glass 28 | if: ${{ (steps.author-permission.outputs.require-result == 'true') && (startsWith(github.event.pull_request.title, '[Break Glass]')) }} 29 | uses: hmarr/auto-approve-action@v4 30 | - name: Auto-approve if author is able to write and contains only doc change 31 | id: doc-change 32 | if: steps.author-permission.outputs.require-result == 'true' 33 | uses: actions/github-script@v7 34 | with: 35 | github-token: ${{ secrets.GITHUB_TOKEN }} 36 | script: | 37 | const { data: files } = await github.rest.pulls.listFiles({ 38 | owner: context.repo.owner, 39 | repo: context.repo.repo, 40 | pull_number: context.issue.number 41 | }) 42 | 43 | const onlyDocChanges = files.every(file => file.filename.startsWith('docs/')) 44 | 45 | if (onlyDocChanges) { 46 | await github.rest.pulls.createReview({ 47 | owner: context.repo.owner, 48 | repo: context.repo.repo, 49 | pull_number: context.issue.number, 50 | event: 'APPROVE' 51 | }) 52 | console.log('Auto-approved PR by author') 53 | } else { 54 | console.log('PR does not meet auto-approval criteria') 55 | } -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | ### Helm ### 2 | # Chart dependencies 3 | **/charts/*.tgz 4 | 5 | ### Node ### 6 | # Logs 7 | logs 8 | *.log 9 | npm-debug.log* 10 | yarn-debug.log* 11 | yarn-error.log* 12 | lerna-debug.log* 13 | .pnpm-debug.log* 14 | 15 | # Diagnostic reports (https://nodejs.org/api/report.html) 16 | report.[0-9]*.[0-9]*.[0-9]*.[0-9]*.json 17 | 18 | # Runtime data 19 | pids 20 | *.pid 21 | *.seed 22 | *.pid.lock 23 | 
24 | # Directory for instrumented libs generated by jscoverage/JSCover 25 | lib-cov 26 | 27 | # Coverage directory used by tools like istanbul 28 | coverage 29 | *.lcov 30 | 31 | # nyc test coverage 32 | .nyc_output 33 | 34 | # Grunt intermediate storage (https://gruntjs.com/creating-plugins#storing-task-files) 35 | .grunt 36 | 37 | # Bower dependency directory (https://bower.io/) 38 | bower_components 39 | 40 | # node-waf configuration 41 | .lock-wscript 42 | 43 | # Compiled binary addons (https://nodejs.org/api/addons.html) 44 | build/Release 45 | 46 | # Dependency directories 47 | node_modules/ 48 | jspm_packages/ 49 | 50 | # Snowpack dependency directory (https://snowpack.dev/) 51 | web_modules/ 52 | 53 | # TypeScript cache 54 | *.tsbuildinfo 55 | 56 | # Optional npm cache directory 57 | .npm 58 | 59 | # Optional eslint cache 60 | .eslintcache 61 | 62 | # Optional stylelint cache 63 | .stylelintcache 64 | 65 | # Microbundle cache 66 | .rpt2_cache/ 67 | .rts2_cache_cjs/ 68 | .rts2_cache_es/ 69 | .rts2_cache_umd/ 70 | 71 | # Optional REPL history 72 | .node_repl_history 73 | 74 | # Output of 'npm pack' 75 | *.tgz 76 | 77 | # Yarn Integrity file 78 | .yarn-integrity 79 | 80 | # dotenv environment variable files 81 | .env 82 | .env.development.local 83 | .env.test.local 84 | .env.production.local 85 | .env.local 86 | 87 | # parcel-bundler cache (https://parceljs.org/) 88 | .cache 89 | .parcel-cache 90 | 91 | # Next.js build output 92 | .next 93 | out 94 | 95 | # Nuxt.js build / generate output 96 | .nuxt 97 | dist 98 | 99 | # Gatsby files 100 | .cache/ 101 | # Comment in the public line in if your project uses Gatsby and not Next.js 102 | # https://nextjs.org/blog/next-9-1#public-directory-support 103 | # public 104 | 105 | # vuepress build output 106 | .vuepress/dist 107 | 108 | # vuepress v2.x temp and cache directory 109 | .temp 110 | 111 | # Docusaurus cache and generated files 112 | .docusaurus 113 | 114 | # Serverless directories 115 | .serverless/ 116 | 117 | # FuseBox cache 118 | .fusebox/ 119 | 120 | # DynamoDB Local files 121 | .dynamodb/ 122 | 123 | # TernJS port file 124 | .tern-port 125 | 126 | # Stores VSCode versions used for testing VSCode extensions 127 | .vscode-test 128 | .vscode 129 | 130 | # yarn v2 131 | .yarn/cache 132 | .yarn/unplugged 133 | .yarn/build-state.yml 134 | .yarn/install-state.gz 135 | .pnp.* 136 | 137 | ### Node Patch ### 138 | # Serverless Webpack directories 139 | .webpack/ 140 | 141 | # Optional stylelint cache 142 | 143 | # SvelteKit build / generate output 144 | .svelte-kit 145 | 146 | 147 | # CDK asset staging directory 148 | .cdk.staging 149 | cdk.out 150 | cdk.context.json 151 | 152 | *.js 153 | !jest.config.js 154 | *.d.ts 155 | 156 | ### macOS ### 157 | # General 158 | .DS_Store 159 | .AppleDouble 160 | .LSOverride 161 | 162 | # Icon must end with two \r 163 | Icon 164 | 165 | 166 | # Thumbnails 167 | ._* 168 | 169 | # Files that might appear in the root of a volume 170 | .DocumentRevisions-V100 171 | .fseventsd 172 | .Spotlight-V100 173 | .TemporaryItems 174 | .Trashes 175 | .VolumeIcon.icns 176 | .com.apple.timemachine.donotpresent 177 | 178 | # Directories potentially created on remote AFP share 179 | .AppleDB 180 | .AppleDesktop 181 | Network Trash Folder 182 | Temporary Items 183 | .apdisk 184 | 185 | ### macOS Patch ### 186 | # iCloud generated files 187 | *.icloud 188 | 189 | ### JupyterNotebooks ### 190 | # gitignore template for Jupyter Notebooks 191 | # website: http://jupyter.org/ 192 | 193 | .ipynb_checkpoints 194 | 
*/.ipynb_checkpoints/* 195 | 196 | # IPython 197 | profile_default/ 198 | ipython_config.py 199 | 200 | # mkdocs 201 | site/ 202 | 203 | ### Python ### 204 | # Byte-compiled / optimized / DLL files 205 | __pycache__/ 206 | *.py[cod] 207 | *$py.class 208 | 209 | # C extensions 210 | *.so 211 | 212 | # Distribution / packaging 213 | .Python 214 | build/ 215 | develop-eggs/ 216 | dist/ 217 | downloads/ 218 | eggs/ 219 | .eggs/ 220 | lib64/ 221 | parts/ 222 | sdist/ 223 | var/ 224 | wheels/ 225 | share/python-wheels/ 226 | *.egg-info/ 227 | .installed.cfg 228 | *.egg 229 | MANIFEST 230 | 231 | # PyInstaller 232 | # Usually these files are written by a python script from a template 233 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 234 | *.manifest 235 | *.spec 236 | 237 | # Installer logs 238 | pip-log.txt 239 | pip-delete-this-directory.txt 240 | 241 | # Unit test / coverage reports 242 | htmlcov/ 243 | .tox/ 244 | .nox/ 245 | .coverage 246 | .coverage.* 247 | .cache 248 | nosetests.xml 249 | coverage.xml 250 | *.cover 251 | *.py,cover 252 | .hypothesis/ 253 | .pytest_cache/ 254 | cover/ 255 | 256 | # Translations 257 | *.mo 258 | *.pot 259 | 260 | # Django stuff: 261 | *.log 262 | local_settings.py 263 | db.sqlite3 264 | db.sqlite3-journal 265 | 266 | # Flask stuff: 267 | instance/ 268 | .webassets-cache 269 | 270 | # Scrapy stuff: 271 | .scrapy 272 | 273 | # Sphinx documentation 274 | docs/_build/ 275 | 276 | # PyBuilder 277 | .pybuilder/ 278 | target/ 279 | 280 | # pyenv 281 | # For a library or package, you might want to ignore these files since the code is 282 | # intended to run in multiple environments; otherwise, check them in: 283 | # .python-version 284 | 285 | # pipenv 286 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 287 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 288 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 289 | # install all needed dependencies. 290 | #Pipfile.lock 291 | 292 | # poetry 293 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 294 | # This is especially recommended for binary packages to ensure reproducibility, and is more 295 | # commonly ignored for libraries. 296 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 297 | #poetry.lock 298 | 299 | # pdm 300 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 301 | #pdm.lock 302 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 303 | # in version control. 304 | # https://pdm.fming.dev/#use-with-ide 305 | .pdm.toml 306 | 307 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 308 | __pypackages__/ 309 | 310 | # Celery stuff 311 | celerybeat-schedule 312 | celerybeat.pid 313 | 314 | # SageMath parsed files 315 | *.sage.py 316 | 317 | # Environments 318 | .env 319 | .venv 320 | env/ 321 | venv/ 322 | ENV/ 323 | env.bak/ 324 | venv.bak/ 325 | 326 | # Spyder project settings 327 | .spyderproject 328 | .spyproject 329 | 330 | # Rope project settings 331 | .ropeproject 332 | 333 | # mypy 334 | .mypy_cache/ 335 | .dmypy.json 336 | dmypy.json 337 | 338 | # Pyre type checker 339 | .pyre/ 340 | 341 | # pytype static type analyzer 342 | .pytype/ 343 | 344 | # Cython debug symbols 345 | cython_debug/ 346 | 347 | # PyCharm 348 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 349 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 350 | # and can be added to the global gitignore or merged into this file. For a more nuclear 351 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 352 | #.idea/ 353 | 354 | ### Python Patch ### 355 | # Poetry local configuration file - https://python-poetry.org/docs/configuration/#local-configuration 356 | poetry.toml 357 | 358 | # ruff 359 | .ruff_cache/ 360 | 361 | # LSP config files 362 | pyrightconfig.json 363 | 364 | manual.md 365 | config.yaml 366 | 367 | ### Jekyll ### 368 | _site/ 369 | .sass-cache/ 370 | .jekyll-cache/ 371 | .jekyll-metadata 372 | # Ignore folders generated by Bundler 373 | .bundle/ 374 | vendor/ -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "utils/bottlerocket-images-cache"] 2 | path = utils/bottlerocket-images-cache 3 | url = https://github.com/aws-samples/bottlerocket-images-cache 4 | -------------------------------------------------------------------------------- /.npmignore: -------------------------------------------------------------------------------- 1 | *.ts 2 | !*.d.ts 3 | 4 | # CDK asset staging directory 5 | .cdk.staging 6 | cdk.out 7 | -------------------------------------------------------------------------------- /CODEOWNERS: -------------------------------------------------------------------------------- 1 | CODEOWNERS @aws-solutions-library-samples/maintainers @yubingjiaocn 2 | /.github/workflows/maintainer_workflows.yml @aws-solutions-library-samples/maintainers 3 | /.github/solutionid_validator.sh @aws-solutions-library-samples/maintainers 4 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | ## Code of Conduct 2 | This project has adopted the [Amazon Open Source Code of Conduct](https://aws.github.io/code-of-conduct). 3 | For more information see the [Code of Conduct FAQ](https://aws.github.io/code-of-conduct-faq) or contact 4 | opensource-codeofconduct@amazon.com with any additional questions or comments. -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing Guidelines 2 | 3 | Thank you for your interest in contributing to our project. Whether it's a bug report, new feature, correction, or additional 4 | documentation, we greatly value feedback and contributions from our community. 
5 | 6 | Please read through this document before submitting any issues or pull requests to ensure we have all the necessary 7 | information to effectively respond to your bug report or contribution. 8 | 9 | 10 | ## Reporting Bugs/Feature Requests 11 | 12 | We welcome you to use the GitHub issue tracker to report bugs or suggest features. 13 | 14 | When filing an issue, please check existing open, or recently closed, issues to make sure somebody else hasn't already 15 | reported the issue. Please try to include as much information as you can. Details like these are incredibly useful: 16 | 17 | * A reproducible test case or series of steps 18 | * The version of our code being used 19 | * Any modifications you've made relevant to the bug 20 | * Anything unusual about your environment or deployment 21 | 22 | 23 | ## Contributing via Pull Requests 24 | Contributions via pull requests are much appreciated. Before sending us a pull request, please ensure that: 25 | 26 | 1. You are working against the latest source on the *main* branch. 27 | 2. You check existing open, and recently merged, pull requests to make sure someone else hasn't addressed the problem already. 28 | 3. You open an issue to discuss any significant work - we would hate for your time to be wasted. 29 | 30 | To send us a pull request, please: 31 | 32 | 1. Fork the repository. 33 | 2. Modify the source; please focus on the specific change you are contributing. If you also reformat all the code, it will be hard for us to focus on your change. 34 | 3. Ensure local tests pass. 35 | 4. Commit to your fork using clear commit messages. 36 | 5. Send us a pull request, answering any default questions in the pull request interface. 37 | 6. Pay attention to any automated CI failures reported in the pull request, and stay involved in the conversation. 38 | 39 | GitHub provides additional documentation on [forking a repository](https://help.github.com/articles/fork-a-repo/) and 40 | [creating a pull request](https://help.github.com/articles/creating-a-pull-request/). 41 | 42 | 43 | ## Finding contributions to work on 44 | Looking at the existing issues is a great way to find something to contribute to. As our projects, by default, use the default GitHub issue labels (enhancement/bug/duplicate/help wanted/invalid/question/wontfix), looking at any 'help wanted' issues is a great place to start. 45 | 46 | 47 | ## Code of Conduct 48 | This project has adopted the [Amazon Open Source Code of Conduct](https://aws.github.io/code-of-conduct). 49 | For more information see the [Code of Conduct FAQ](https://aws.github.io/code-of-conduct-faq) or contact 50 | opensource-codeofconduct@amazon.com with any additional questions or comments. 51 | 52 | 53 | ## Security issue notifications 54 | If you discover a potential security issue in this project, we ask that you notify AWS/Amazon Security via our [vulnerability reporting page](http://aws.amazon.com/security/vulnerability-reporting/). Please do **not** create a public GitHub issue. 55 | 56 | 57 | ## Licensing 58 | 59 | See the [LICENSE](LICENSE) file for our project's licensing. We will ask you to confirm the licensing of your contribution. -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT No Attribution 2 | 3 | Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy of 6 | this software and associated documentation files (the "Software"), to deal in 7 | the Software without restriction, including without limitation the rights to 8 | use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of 9 | the Software, and to permit persons to whom the Software is furnished to do so. 10 | 11 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 12 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS 13 | FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR 14 | COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER 15 | IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN 16 | CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 17 | -------------------------------------------------------------------------------- /NOTICE: -------------------------------------------------------------------------------- 1 | Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | 3 | ********************** 4 | THIRD PARTY COMPONENTS 5 | ********************** 6 | 7 | This software includes third party software subject to the following licenses. 8 | 9 | KEDA 10 | Apache-2.0 license 11 | https://github.com/kedacore/keda/blob/main/LICENSE 12 | 13 | Karpenter 14 | Apache-2.0 license 15 | https://github.com/aws/karpenter/blob/main/LICENSE 16 | 17 | AWS CDK 18 | Apache-2.0 license 19 | https://github.com/aws/aws-cdk/blob/main/LICENSE 20 | 21 | Amazon EKS Blueprints for CDK 22 | Apache-2.0 license 23 | https://github.com/aws-quickstart/cdk-eks-blueprints/blob/main/LICENSE 24 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # [Guidance for Asynchronous Inference with Stable Diffusion on AWS](https://aws.amazon.com/solutions/guidance/asynchronous-image-generation-with-stable-diffusion-on-aws/) 2 | 3 | Implementing a fast-scaling and low-cost Stable Diffusion inference solution with serverless and containers on AWS 4 | 5 | [Stable Diffusion](https://aws.amazon.com/what-is/stable-diffusion/) is a popular open source project for generating images using Generative AI. Building a scalable and cost-efficient Machine Learning (ML) inference solution is a common challenge that many AWS customers are facing. 6 | This project shows how to use serverless architecture and container services to build an end-to-end, low-cost, rapidly scaling asynchronous image generation architecture. This repo contains the sample code and CDK deployment scripts that will help you deploy this solution in a few steps. 
7 | 8 | ## Features 9 | 10 | - Asynchronous API and [Serverless Event-Driven Architecture](https://docs.aws.amazon.com/wellarchitected/latest/serverless-applications-lens/event-driven-architectures.html) 11 | - Image generation with open source Stable Diffusion runtimes running on [Amazon EKS](https://aws.amazon.com/eks) 12 | - Automatic [Amazon SQS](https://aws.amazon.com/sqs/) queue-length-based scaling with [KEDA](https://keda.sh/) 13 | - Automatic provisioning of EC2 instances for Amazon EKS compute nodes with [Karpenter](https://karpenter.sh/) 14 | - Scaling up new Amazon EKS nodes within 2 minutes to run inference tasks 15 | - Saving up to 70% with [AWS GPU](https://aws.amazon.com/ec2/instance-types/g5/) Spot EC2 instances 16 | 17 | ## Architecture diagram 18 | 19 | 20 |
21 | 22 |
23 | Figure 1: Asynchronous Image Generation with Stable Diffusion on AWS reference architecture 24 |
25 | 26 | ### Architecture steps 27 | 28 | 1. An application sends the prompt to [Amazon API Gateway](https://aws.amazon.com/api-gateway/), which acts as the endpoint for the overall Guidance, including authentication. An [AWS Lambda](https://aws.amazon.com/lambda/) function validates the requests, publishes them to the designated [Amazon Simple Notification Service](https://aws.amazon.com/sns/) (Amazon SNS) topic, and immediately returns a response. (A sample request is shown after this list.) 29 | 2. Amazon SNS publishes the message to [Amazon Simple Queue Service](https://aws.amazon.com/sqs/) (Amazon SQS) queues. Each message contains a Stable Diffusion (SD) runtime name attribute and is delivered only to the queue with the matching SD runtime. 30 | 3. In the [Amazon Elastic Kubernetes Service](https://aws.amazon.com/eks/) (Amazon EKS) cluster, the previously deployed open source [Kubernetes Event Driven Auto-Scaler (KEDA)](https://keda.sh) scales up new pods to process the incoming messages from the SQS model processing queues. 31 | 4. In the Amazon EKS cluster, the previously deployed open source Kubernetes auto-scaler, [Karpenter](https://karpenter.sh), launches new compute nodes based on GPU [Amazon Elastic Compute Cloud](https://aws.amazon.com/ec2/) (Amazon EC2) instances (such as g4, g5, and p4) to schedule pending pods. The instances use pre-cached SD Runtime images and are based on [Bottlerocket OS](https://aws.amazon.com/bottlerocket/) for fast boot. The instances can be launched with the On-Demand or [Spot](https://aws.amazon.com/ec2/spot) pricing model. 32 | 5. Stable Diffusion Runtimes load ML model files from [Amazon Simple Storage Service](https://aws.amazon.com/s3/) (Amazon S3) via the [Mountpoint for Amazon S3 CSI Driver](https://github.com/awslabs/mountpoint-s3-csi-driver) on runtime initialization or on demand. 33 | 6. Queue agents (a software component created for this Guidance) receive messages from the SQS model processing queues and convert them into inputs for SD Runtime API calls. 34 | 7. Queue agents call the SD Runtime APIs, receive and decode the responses, and save the generated images to the designated Amazon S3 buckets. 35 | 8. Queue agents send notifications to the designated SNS topic from the pods; users receive notifications from SNS and can access the images in the S3 buckets. 
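The end-to-end flow above can be exercised with a single HTTP call. The sketch below submits a text-to-image task in the `v1alpha2` request format defined in `docs/api/v1alpha2.yaml`; the endpoint URL, API key, runtime name, and the SD Web UI parameters in `content` are illustrative placeholders that depend on your deployment.

```bash
# Hypothetical values -- replace with the API endpoint and key created by your deployment.
API_ENDPOINT="https://<api-id>.execute-api.<region>.amazonaws.com/dev"
API_KEY="<your-api-key>"

# Submit a text-to-image task (v1alpha2 format). The call returns immediately;
# the generated image is written to the output S3 bucket and an SNS
# notification is sent when the queue agent finishes the task.
curl -X POST "${API_ENDPOINT}/task" \
  -H "Content-Type: application/json" \
  -H "x-api-key: ${API_KEY}" \
  -d '{
    "task": {
      "metadata": {
        "id": "demo-task-001",
        "runtime": "sdruntime",
        "tasktype": "text-to-image",
        "prefix": "output",
        "context": {}
      },
      "content": {
        "prompt": "a dog",
        "steps": 16,
        "width": 512,
        "height": 512
      }
    }
  }'
```

A successful response echoes the task `id`, the selected `runtime`, and the `output_location` prefix in S3 where the result will appear.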
36 | 37 | 38 | ### AWS services in this Guidance 39 | 40 | | **AWS service** | Description | 41 | |-----------|------------| 42 | |[Amazon Elastic Kubernetes Service - EKS](https://aws.amazon.com/eks/)|Core service - application platform that hosts the containerized SD workloads| 43 | |[Amazon Virtual Private Cloud - VPC](https://aws.amazon.com/vpc/)| Core service - network security layer | 44 | |[Amazon Elastic Compute Cloud - EC2](https://aws.amazon.com/ec2/)| Core service - EC2 instances power the On-Demand and Spot-based EKS compute node groups for running container workloads| 45 | |[Amazon Elastic Container Registry - ECR](https://aws.amazon.com/ecr/)|Core service - ECR registry is used to host the container images and Helm charts| 46 | |[Amazon Simple Storage Service - S3](https://aws.amazon.com/s3/)|Core service - object storage for model files and generated images| 47 | |[Amazon API Gateway](https://aws.amazon.com/api-gateway/)| Core service - endpoint for all user requests| 48 | |[AWS Lambda](https://aws.amazon.com/lambda/)| Core service - validates the requests and publishes them to the designated SNS topic | 49 | |[Amazon Simple Queue Service](https://aws.amazon.com/sqs/)| Core service - provides asynchronous event handling | 50 | |[Amazon Simple Notification Service](https://aws.amazon.com/sns/)| Core service - provides model-specific event routing | 51 | |[Amazon CloudWatch](https://aws.amazon.com/cloudwatch/)|Auxiliary service - provides observability for core services | 52 | |[AWS CDK](https://aws.amazon.com/cdk/) | Core service - used for deploying and updating this solution| 53 | 54 | ### Cost 55 | 56 | You are responsible for the cost of the AWS services used while running this Guidance. As of April 2024, running this 57 | Guidance with the default settings in the US West (Oregon) Region for one month while generating one million images costs approximately **$436.42**, as illustrated in the two sample tables below (excluding free tiers). 58 | 59 | We recommend creating a [budget](https://docs.aws.amazon.com/cost-management/latest/userguide/budgets-create.html) through [AWS Cost Explorer](http://aws.amazon.com/aws-cost-management/aws-cost-explorer/) to help monitor and manage costs. Prices are subject to change. For full details, refer to the pricing webpage for each AWS service used in this Guidance. 
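As a quick sanity check on the variable-cost table that follows, the dominant GPU line item can be reproduced from the footnoted assumptions (an average request duration of 1.5 seconds and the quoted average Spot price of \$0.4968 per instance-hour):

$$
\frac{10^{6}\ \text{images} \times 1.5\ \text{s/image}}{3600\ \text{s/hour}} \approx 416.67\ \text{instance-hours}, \qquad 416.67 \times \$0.4968 \approx \$207
$$

Adding the per-million-images total (\$226.18) to the fixed monthly total (\$210.24) yields the approximately \$436.42 figure quoted above.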
60 | 61 | The main services and their pricing for usage related to the number of images are listed below (per one million images): 62 | 63 | | **AWS Service** | **Billing Dimension** | **Quantity per 1M Images** | **Unit Price \[USD\]** | **Total \[USD\]** | 64 | |-----------|------------|------------|------------|------------| 65 | | Amazon EC2 | g5.2xlarge instance, Spot instance per hour | 416.67 | \$ 0.4968\* | \$ 207 | 66 | | Amazon API Gateway | Per 1M REST API requests | 1 | \$ 3.50 | \$ 3.50 | 67 | | AWS Lambda | Per GB-second | 12,500 | \$ 0.0000166667 | \$ 0.21 | 68 | | AWS Lambda | Per 1M requests | 1 | \$ 0.20 | \$ 0.20 | 69 | | Amazon SNS | Per 1M requests | 2 | \$ 0.50 | \$ 0.50 | 70 | | Amazon SNS | Data transfer per GB | 7.62\** | \$ 0.09 | \$ 0.68 | 71 | | Amazon SQS | Per 1M requests | 2 | \$ 0.40 | \$ 0.80 | 72 | | Amazon S3 | Per 1K PUT requests | 2,000 | \$ 0.005 | \$ 10.00 | 73 | | Amazon S3 | Per GB per month | 143.05\*** | \$ 0.023 | \$ 3.29 | 74 | | **Total, 1M images** |   |   |   | **\$226.18** | 75 | 76 | The fixed costs, which are unrelated to the number of images, are listed below with the main services and their pricing (per month): 77 | 78 | | **AWS Service** | **Billing Dimension** | **Quantity per Month** | **Unit Price \[USD\]** | **Total \[USD\]** | 79 | |-----------|------------|------------|------------|------------| 80 | | Amazon EKS | Cluster | 1 | \$ 72.00 | \$ 72.00 | 81 | | Amazon EC2 | m5.large instance, On-Demand instance per hour | 1440 | \$ 0.0960 | \$ 138.24 | 82 | | **Total, month** |   |   |   | **\$210.24** | 83 | 84 | - \* Calculated based on an average request duration of 1.5 seconds and the average Spot instance pricing across all Availability Zones in the `us-west-2` (Oregon) Region from January 29, 2024, to April 28, 2024. 85 | - \*\* Calculated based on an average request size of 16 KB. 86 | - \*\*\* Calculated based on an average image size of 150 KB, stored for 1 month. 87 | 88 | Please note that these are estimated costs for reference only. The actual cost may vary depending on the model you use, task parameters, current Spot instance pricing, and other factors. 89 | 90 | ## Deployment Documentation 91 | 92 | Please see the detailed Implementation Guides here: 93 | - [English](https://aws-solutions-library-samples.github.io/ai-ml/asynchronous-image-generation-with-stable-diffusion-on-aws.html) 94 | - [Chinese 简体中文](https://aws-solutions-library-samples.github.io/ai-ml/asynchronous-image-generation-with-stable-diffusion-on-aws-zh.html) 95 | 96 | ## Security 97 | 98 | When you build systems on AWS infrastructure, security responsibilities are shared between you and AWS. This [shared responsibility model](https://aws.amazon.com/compliance/shared-responsibility-model/) reduces your operational burden because AWS operates, manages, and 99 | controls the components, including host operating systems, the virtualization layer, and the physical security of the facilities in 100 | which the services operate. For more information about AWS security, visit [AWS Cloud Security](http://aws.amazon.com/security/). 101 | 102 | For potential security issues, see [CONTRIBUTING](CONTRIBUTING.md#security-issue-notifications) for more information. 103 | 104 | ## License 105 | 106 | This library is licensed under the MIT-0 License. See the [LICENSE](LICENSE) file. 
107 | -------------------------------------------------------------------------------- /bin/stable-difussion-on-eks.ts: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env node 2 | import { App, Aspects } from 'aws-cdk-lib'; 3 | import DataPlaneStack from "../lib/dataPlane"; 4 | import { parse } from 'yaml' 5 | import * as fs from 'fs' 6 | import { validateConfig } from '../lib/utils/validateConfig'; 7 | 8 | const app = new App(); 9 | 10 | const env = { 11 | account: process.env.CDK_DEFAULT_ACCOUNT, 12 | region: process.env.CDK_DEFAULT_REGION, 13 | } 14 | 15 | let filename: string 16 | 17 | if ("CDK_CONFIG_PATH" in process.env) { 18 | filename = process.env.CDK_CONFIG_PATH as string 19 | } else { 20 | filename = 'config.yaml' 21 | } 22 | 23 | const file = fs.readFileSync(filename, 'utf8') 24 | const props = parse(file) 25 | 26 | if (validateConfig(props)) { 27 | const dataPlaneStack = new DataPlaneStack(app, props.stackName, props, { 28 | env: env, 29 | description: "Guidance for Asynchronous Image Generation with Stable Diffusion on AWS (SO9306)" 30 | }); 31 | } else { 32 | console.log("Deployment failed due to failed validation. Please check and try again.") 33 | } -------------------------------------------------------------------------------- /cdk.json: -------------------------------------------------------------------------------- 1 | { 2 | "app": "npx ts-node --prefer-ts-exts bin/stable-difussion-on-eks.ts", 3 | "watch": { 4 | "include": [ 5 | "**" 6 | ], 7 | "exclude": [ 8 | "README.md", 9 | "cdk*.json", 10 | "**/*.d.ts", 11 | "**/*.js", 12 | "tsconfig.json", 13 | "package*.json", 14 | "yarn.lock", 15 | "node_modules", 16 | "test" 17 | ] 18 | }, 19 | "context": { 20 | "@aws-cdk/aws-lambda:recognizeLayerVersion": true, 21 | "@aws-cdk/core:checkSecretUsage": true, 22 | "@aws-cdk/core:target-partitions": [ 23 | "aws", 24 | "aws-cn" 25 | ], 26 | "@aws-cdk-containers/ecs-service-extensions:enableDefaultLogDriver": true, 27 | "@aws-cdk/aws-ec2:uniqueImdsv2TemplateName": true, 28 | "@aws-cdk/aws-ecs:arnFormatIncludesClusterName": true, 29 | "@aws-cdk/aws-iam:minimizePolicies": true, 30 | "@aws-cdk/core:validateSnapshotRemovalPolicy": true, 31 | "@aws-cdk/aws-codepipeline:crossAccountKeyAliasStackSafeResourceName": true, 32 | "@aws-cdk/aws-s3:createDefaultLoggingPolicy": true, 33 | "@aws-cdk/aws-sns-subscriptions:restrictSqsDescryption": true, 34 | "@aws-cdk/aws-apigateway:disableCloudWatchRole": true, 35 | "@aws-cdk/core:enablePartitionLiterals": true, 36 | "@aws-cdk/aws-events:eventsTargetQueueSameAccount": true, 37 | "@aws-cdk/aws-iam:standardizedServicePrincipals": true, 38 | "@aws-cdk/aws-ecs:disableExplicitDeploymentControllerForCircuitBreaker": true, 39 | "@aws-cdk/aws-iam:importedRoleStackSafeDefaultPolicyName": true, 40 | "@aws-cdk/aws-s3:serverAccessLogsUseBucketPolicy": true, 41 | "@aws-cdk/aws-route53-patters:useCertificate": true, 42 | "@aws-cdk/customresources:installLatestAwsSdkDefault": false, 43 | "@aws-cdk/aws-rds:databaseProxyUniqueResourceName": true, 44 | "@aws-cdk/aws-codedeploy:removeAlarmsFromDeploymentGroup": true, 45 | "@aws-cdk/aws-apigateway:authorizerChangeDeploymentLogicalId": true, 46 | "@aws-cdk/aws-ec2:launchTemplateDefaultUserData": true, 47 | "@aws-cdk/aws-secretsmanager:useAttachedSecretResourcePolicyForSecretTargetAttachments": true, 48 | "@aws-cdk/aws-redshift:columnId": true, 49 | "@aws-cdk/aws-stepfunctions-tasks:enableEmrServicePolicyV2": true, 50 | "@aws-cdk/aws-ec2:restrictDefaultSecurityGroup": 
true, 51 | "@aws-cdk/aws-apigateway:requestValidatorUniqueId": true, 52 | "@aws-cdk/aws-kms:aliasNameRef": true, 53 | "@aws-cdk/aws-autoscaling:generateLaunchTemplateInsteadOfLaunchConfig": true, 54 | "@aws-cdk/core:includePrefixInUniqueNameGeneration": true, 55 | "@aws-cdk/aws-opensearchservice:enableOpensearchMultiAzWithStandby": true 56 | }, 57 | "requireApproval": "never", 58 | "rollback": false 59 | } 60 | -------------------------------------------------------------------------------- /config.schema.yaml: -------------------------------------------------------------------------------- 1 | stackName: str() 2 | modelBucketArn: str(starts_with="arn") 3 | APIGW: 4 | stageName: str() 5 | throttle: 6 | rateLimit: int() 7 | burstLimit: int() 8 | modelsRuntime: list(include('runtime'), min=1) 9 | --- 10 | runtime: 11 | name: str() 12 | namespace: str() 13 | modelFilename: str(required=False) 14 | dynamicModel: bool(required=False) 15 | type: enum('sdwebui', 'comfyui') 16 | extraValues: include('extraValues') 17 | --- 18 | ebs: 19 | volumeSize: str(required=False) 20 | volumeType: str(required=False) 21 | deleteOnTermination: bool(required=False) 22 | iops: int(min=3000, max=16000, required=False) 23 | throughput: int(min=125, max=1000, required=False) 24 | snapshotID: str(starts_with="snap", required=False) 25 | --- 26 | extraValues: 27 | karpenter: include('karpenter', required=False) 28 | runtime: include('runtimeValues', required=False) 29 | --- 30 | karpenter: 31 | provisioner: include('provisioner', required=False) 32 | nodeTemplate: include('nodeTemplate', required=False) 33 | --- 34 | provisioner: 35 | labels: map(str(), str(), required=False) 36 | capacityType: 37 | onDemand: bool(required=False) 38 | spot: bool(required=False) 39 | instanceType: list(str(), required=False) 40 | extraRequirements: list(any(), required=False) 41 | extraTaints: list(any(), required=False) 42 | resourceLimits: include('resources', required=False) 43 | consolidation: bool(required=False) 44 | disruption: include('disruption', required=False) 45 | --- 46 | disruption: 47 | consolidateAfter: str(required=False) 48 | expireAfter: str(required=False) 49 | --- 50 | nodeTemplate: 51 | securityGroupSelector: map(str(), str(), required=False) 52 | subnetSelector: map(str(), str(), required=False) 53 | tags: map(str(), str(), required=False) 54 | amiFamily: enum('Bottlerocket', required=False) 55 | osVolume: include('ebs', required=False) 56 | dataVolume: include('ebs', required=False) 57 | userData: str(required=False) 58 | --- 59 | runtimeValues: 60 | labels: map(str(), str(), required=False) 61 | annotations: map(str(), str(), required=False) 62 | scaling: include('scaling', required=False) 63 | inferenceApi: include('inferenceApi', required=False) 64 | queueAgent: include('inferenceApi', required=False) 65 | --- 66 | scaling: 67 | enabled: bool(required=False) 68 | queueLength: int(min=1, required=False) 69 | cooldownPeriod: int(min=1, required=False) 70 | maxReplicaCount: int(min=0, required=False) 71 | minReplicaCount: int(min=0, required=False) 72 | pollingInterval: int(required=False) 73 | scaleOnInFlight: bool(required=False) 74 | extraHPAConfig: any(required=False) 75 | --- 76 | image: 77 | repository: str() 78 | tag: str() 79 | --- 80 | resources: 81 | nvidia.com/gpu: str(required=False) 82 | cpu: str(required=False) 83 | memory: str(required=False) 84 | --- 85 | inferenceApi: 86 | image: include('image', required=False) 87 | modelMountPath: str(required=False) 88 | commandArguments: 
str(required=False) 89 | extraEnv: map(str(), str(), required=False) 90 | imagePullPolicy: enum('Always', 'IfNotPresent', 'Never') 91 | resources: 92 | limits: include('resources', required=False) 93 | requests: include('resources', required=False) 94 | --- 95 | queueAgent: 96 | image: include('image', required=False) 97 | extraEnv: map(str(), str()) 98 | imagePullPolicy: enum('Always', 'IfNotPresent', 'Never') 99 | resources: 100 | limits: include('resources', required=False) 101 | requests: include('resources', required=False) 102 | XRay: 103 | enabled: bool(required=False) 104 | -------------------------------------------------------------------------------- /config.yaml: -------------------------------------------------------------------------------- 1 | stackName: sdoneks 2 | modelBucketArn: arn:aws:s3:::dummy-bucket 3 | APIGW: 4 | stageName: dev 5 | throttle: 6 | rateLimit: 30 7 | burstLimit: 50 8 | modelsRuntime: 9 | - name: sdruntime 10 | namespace: default 11 | modelFilename: v1-5-pruned-emaonly.safetensors 12 | dynamicModel: false 13 | type: sdwebui 14 | extraValues: 15 | karpenter: 16 | nodeTemplate: 17 | amiFamily: Bottlerocket 18 | # dataVolume: 19 | # snapshotID: snap-0123456789 20 | provisioner: 21 | instanceType: 22 | - g6.xlarge 23 | - g6.2xlarge 24 | capacityType: 25 | onDemand: true 26 | spot: true 27 | runtime: 28 | scaling: 29 | queueLength: 10 30 | minReplicaCount: 0 31 | cooldownPeriod: 300 32 | -------------------------------------------------------------------------------- /deploy/README.md: -------------------------------------------------------------------------------- 1 | # One-key deployment script 2 | 3 | This script works as a quick start for this solution. It will: 4 | 5 | * Install required tools 6 | * Download the Stable Diffusion 1.5 model from Hugging Face and upload it to an S3 bucket 7 | * Generate an EBS snapshot with prefetched container images 8 | * Generate a sample config file 9 | * Deploy the SD on EKS solution 10 | 11 | ## Usage 12 | 13 | ```bash 14 | cd deploy 15 | ./deploy.sh 16 | ``` 17 | 18 | ## Test after deploy 19 | 20 | This script sends text-to-image and image-to-image requests to the SD on EKS endpoints. 
21 | 22 | ```bash 23 | cd ../test 24 | ./run.sh 25 | ``` -------------------------------------------------------------------------------- /deploy/config.yaml.template: -------------------------------------------------------------------------------- 1 | # Auto generated config file 2 | # Run cdk deploy to deploy SD on EKS stack 3 | 4 | stackName: ${STACK_NAME} 5 | modelBucketArn: arn:aws:s3:::${MODEL_BUCKET} 6 | APIGW: 7 | stageName: dev 8 | throttle: 9 | rateLimit: 30 10 | burstLimit: 50 11 | modelsRuntime: 12 | - name: ${RUNTIME_NAME} 13 | namespace: "default" 14 | modelFilename: "sd_xl_turbo_1.0.safetensors" 15 | dynamicModel: false 16 | type: ${RUNTIME_TYPE} 17 | extraValues: 18 | karpenter: 19 | nodeTemplate: 20 | amiFamily: Bottlerocket 21 | dataVolume: 22 | snapshotID: ${SNAPSHOT_ID} 23 | provisioner: 24 | instanceType: 25 | - "g6.2xlarge" 26 | - "g5.2xlarge" 27 | capacityType: 28 | onDemand: true 29 | spot: true 30 | runtime: 31 | scaling: 32 | queueLength: 10 33 | minReplicaCount: 0 34 | maxReplicaCount: 5 35 | cooldownPeriod: 300 -------------------------------------------------------------------------------- /deploy/deploy.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | set -e 4 | 5 | SHORTOPTS="h,n:,R:,d,b:,s:,r:,t:,T" 6 | LONGOPTS="help,stack-name:,region:,dry-run,bucket:,snapshot:,runtime-name:,runtime-type:,skip-tools" 7 | ARGS=$(getopt --options $SHORTOPTS --longoptions $LONGOPTS -- "$@" ) 8 | 9 | eval set -- "$ARGS" 10 | while true; 11 | do 12 | case $1 in 13 | -h|--help) 14 | printf "Usage: deploy.sh [options] \n" 15 | printf "Options: \n" 16 | printf " -h, --help Print this help message \n" 17 | printf " -T, --skip-tools Skip tools installation \n" 18 | printf " -n, --stack-name Name of the stack to be created (Default: sdoneks) \n" 19 | printf " -R, --region AWS region to be used \n" 20 | printf " -d, --dry-run Don't deploy the stack. You can manually deploy the stack using generated config file.\n" 21 | printf " -b, --bucket S3 bucket name to store the model. If not specified, S3 bucket will be created and SD 1.5 model will be placed. \n" 22 | printf " -s, --snapshot EBS snapshot ID with container image. If not specified, EBS snapshot will be automatically generated. \n" 23 | printf " -r, --runtime-name Runtime name. (Default: sdruntime) \n" 24 | printf " -t, --runtime-type Runtime type. Only 'sdwebui' and 'comfyui' is accepted. (Default: sdwebui) \n" 25 | exit 0 26 | ;; 27 | -n|--stack-name) 28 | STACK_NAME=$2 29 | shift 2 30 | ;; 31 | -T|--skip-tools) 32 | INSTALL_TOOLS=false 33 | shift 34 | ;; 35 | -R|--region) 36 | AWS_DEFAULT_REGION=$2 37 | shift 2 38 | ;; 39 | -d|--dry-run) 40 | DEPLOY=false 41 | shift 42 | ;; 43 | -b|--bucket) 44 | MODEL_BUCKET=$2 45 | shift 2 46 | ;; 47 | -s|--snapshot) 48 | SNAPSHOT_ID=$2 49 | shift 2 50 | ;; 51 | -r|--runtime-name) 52 | RUNTIME_NAME=$2 53 | shift 2 54 | ;; 55 | -t|--runtime-type) 56 | RUNTIME_TYPE=$2 57 | if [[ "$RUNTIME_TYPE" != "sdwebui" && "$RUNTIME_TYPE" != "comfyui" ]] ; then 58 | printf "Runtime type should be 'sdwebui' or 'comfyui'. \n" 59 | exit 1 60 | fi 61 | shift 2 62 | ;; 63 | --) 64 | shift 65 | break 66 | ;; 67 | ?) 
68 | shift 69 | printf "invalid parameter" 70 | exit 1 71 | ;; 72 | esac 73 | done 74 | 75 | 76 | SCRIPTPATH=$(realpath $(dirname "$0")) 77 | export AWS_DEFAULT_REGION=${AWS_DEFAULT_REGION:-$(aws ec2 describe-availability-zones --output text --query 'AvailabilityZones[0].[RegionName]')} 78 | declare -l STACK_NAME=${STACK_NAME:-"sdoneks"} 79 | RUNTIME_NAME=${RUNTIME_NAME:-"sdruntime"} 80 | declare -l RUNTIME_TYPE=${RUNTIME_TYPE:-"sdwebui"} 81 | INSTALL_TOOLS=${INSTALL_TOOLS:-true} 82 | DEPLOY=${DEPLOY:-true} 83 | SDWEBUI_IMAGE=public.ecr.aws/bingjiao/sd-on-eks/sdwebui:latest 84 | COMFYUI_IMAGE=public.ecr.aws/bingjiao/sd-on-eks/comfyui:latest 85 | QUEUE_AGENT_IMAGE=public.ecr.aws/bingjiao/sd-on-eks/queue-agent:latest 86 | 87 | # Step 1: Install tools 88 | 89 | printf "Step 1: Install tools... \n" 90 | if [ ${INSTALL_TOOLS} = true ] ; then 91 | "${SCRIPTPATH}"/install-tools.sh 92 | fi 93 | 94 | # Step 2: Create S3 bucket and upload model 95 | 96 | printf "Step 2: Create S3 bucket and upload SD 1.5 model... \n" 97 | if [ -z "${MODEL_BUCKET}" ] ; then 98 | MODEL_BUCKET="${STACK_NAME}"-model-bucket-$(echo ${RANDOM} | md5sum | head -c 4) 99 | aws s3 mb "s3://${MODEL_BUCKET}" --region "${AWS_DEFAULT_REGION}" 100 | "${SCRIPTPATH}"/upload-model.sh "${MODEL_BUCKET}" 101 | else 102 | printf "Existing bucket detected, skipping... \n" 103 | fi 104 | 105 | # Step 3: Create EBS Snapshot 106 | 107 | printf "Step 3: Creating EBS snapshot for faster launching...(This step will last for 15-30 min) \n" 108 | if [ -z "$SNAPSHOT_ID" ]; then 109 | cd "${SCRIPTPATH}"/.. 110 | git submodule update --init --recursive 111 | SNAPSHOT_ID=$(utils/bottlerocket-images-cache/snapshot.sh -s 100 -r "${AWS_DEFAULT_REGION}" -q ${SDWEBUI_IMAGE},${COMFYUI_IMAGE},${QUEUE_AGENT_IMAGE}) 112 | else 113 | printf "Existing snapshot ID detected, skipping... \n" 114 | fi 115 | 116 | # Step 4: Deploy 117 | 118 | printf "Step 4: Start deploy... \n" 119 | aws iam create-service-linked-role --aws-service-name spot.amazonaws.com >/dev/null 2>&1 || true 120 | cd "${SCRIPTPATH}"/.. 121 | npm install 122 | 123 | template="$(cat deploy/config.yaml.template)" 124 | eval "echo \"${template}\"" > config.yaml 125 | cdk bootstrap 126 | if [ ${DEPLOY} = true ] ; then 127 | CDK_DEFAULT_REGION=${AWS_DEFAULT_REGION} cdk deploy --no-rollback --require-approval never 128 | printf "Deploy complete. \n" 129 | else 130 | printf "Please revise config.yaml and run 'cdk deploy --no-rollback --require-approval never' to deploy. 
\n" 131 | fi 132 | -------------------------------------------------------------------------------- /deploy/install-tools.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | set -e 4 | 5 | kubectl_version='1.29.0' 6 | helm_version='3.10.1' 7 | yq_version='4.30.4' 8 | s5cmd_version='2.2.2' 9 | node_version='22.12.0' 10 | cdk_version='2.162.1' 11 | 12 | download () { 13 | url=$1 14 | out_file=$2 15 | curl --location --show-error --silent --output "$out_file" "$url" 16 | } 17 | cur_dir=$(pwd) 18 | tmp_dir=$(mktemp -d) 19 | 20 | cd "$tmp_dir" 21 | 22 | if which pacapt > /dev/null 23 | then 24 | printf "\n" 25 | else 26 | sudo wget -O /usr/bin/pacapt https://github.com/icy/pacapt/raw/ng/pacapt 27 | sudo chmod 755 /usr/bin/pacapt 28 | fi 29 | 30 | sudo pacapt install --noconfirm jq git wget unzip openssl bash-completion python3 python3-pip 31 | 32 | pip3 install -q yamale 33 | 34 | # kubectl 35 | if which kubectl > /dev/null 36 | then 37 | printf "kubectl is installed, skipping...\n" 38 | else 39 | printf "Installing kubectl...\n" 40 | download "https://dl.k8s.io/release/v$kubectl_version/bin/linux/amd64/kubectl" "kubectl" 41 | chmod +x ./kubectl 42 | sudo mv ./kubectl /usr/bin 43 | fi 44 | 45 | # helm 46 | if which helm > /dev/null 47 | then 48 | printf "helm is installed, skipping...\n" 49 | else 50 | printf "Installing helm...\n" 51 | download "https://get.helm.sh/helm-v$helm_version-linux-amd64.tar.gz" "helm.tar.gz" 52 | tar zxf helm.tar.gz 53 | chmod +x linux-amd64/helm 54 | sudo mv ./linux-amd64/helm /usr/bin 55 | rm -rf linux-amd64/ helm.tar.gz 56 | fi 57 | 58 | # aws cli v2 59 | printf "Installing/Upgrading AWS CLI v2...\n" 60 | curl --location --show-error --silent "https://awscli.amazonaws.com/awscli-exe-linux-x86_64.zip" -o "awscliv2.zip" 61 | unzip -o -q awscliv2.zip -d /tmp 62 | sudo /tmp/aws/install --update 63 | rm -rf /tmp/aws awscliv2.zip 64 | 65 | # yq 66 | if which yq > /dev/null 67 | then 68 | printf "yq is installed, skipping...\n" 69 | else 70 | printf "Installing yq...\n" 71 | download "https://github.com/mikefarah/yq/releases/download/v${yq_version}/yq_linux_amd64" "yq" 72 | chmod +x ./yq 73 | sudo mv ./yq /usr/bin 74 | fi 75 | 76 | # s5cmd 77 | if which s5cmd > /dev/null 78 | then 79 | printf "s5cmd is installed, skipping...\n" 80 | else 81 | printf "Installing s5cmd...\n" 82 | download "https://github.com/peak/s5cmd/releases/download/v${s5cmd_version}/s5cmd_${s5cmd_version}_Linux-64bit.tar.gz" "s5cmd.tar.gz" 83 | tar zxf s5cmd.tar.gz 84 | chmod +x ./s5cmd 85 | sudo mv ./s5cmd /usr/bin 86 | rm -rf s5cmd.tar.gz 87 | fi 88 | 89 | # Node.js 90 | if which node > /dev/null 91 | then 92 | printf "Node.js is installed, skipping...\n" 93 | else 94 | printf "Installing Node.js...\n" 95 | download "https://nodejs.org/dist/v{$node_version}/node-v{$node_version}-linux-x64.tar.xz" "node.tar.xz" 96 | sudo mkdir -p /usr/local/lib/nodejs 97 | sudo tar -xJf node.tar.xz -C /usr/local/lib/nodejs 98 | export PATH="/usr/local/lib/nodejs/node-v{$node_version}-linux-x64/bin:$PATH" 99 | printf "export PATH=\"/usr/local/lib/nodejs/node-v{$node_version}-linux-x64/bin:\$PATH\"" >> ~/.bash_profile 100 | source ~/.bash_profile 101 | fi 102 | 103 | # CDK CLI 104 | if which cdk > /dev/null 105 | then 106 | printf "CDK CLI is installed, skipping...\n" 107 | else 108 | printf "Installing AWS CDK CLI and bootstraping CDK environment...\n" 109 | sudo npm install -g aws-cdk@$cdk_version 110 | fi 111 | 112 | printf "Tools install complete. 
\n" 113 | 114 | cd $cur_dir 115 | rm -rf $tmp_dir 116 | -------------------------------------------------------------------------------- /deploy/update.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | set -e 4 | 5 | printf "Updating stack using current configuration... \n" 6 | cd "${SCRIPTPATH}"/.. 7 | npm install 8 | cdk deploy --no-rollback --require-approval never 9 | printf "Deploy complete. \n" -------------------------------------------------------------------------------- /deploy/upload-model.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | set -e 4 | 5 | MODEL_BUCKET="$1" 6 | 7 | AWS_DEFAULT_REGION=${AWS_DEFAULT_REGION:-$(aws ec2 describe-availability-zones --output text --query 'AvailabilityZones[0].[RegionName]')} 8 | 9 | MODEL_URL="https://huggingface.co/stabilityai/sdxl-turbo/resolve/main/sd_xl_turbo_1.0.safetensors" 10 | MODEL_NAME="sd_xl_turbo_1.0.safetensors" 11 | 12 | printf "Transport SDXL-Turbo model from hugging face to S3 bucket...\n" 13 | curl -L "$MODEL_URL" | aws s3 cp - s3://${MODEL_BUCKET}/Stable-diffusion/${MODEL_NAME} 14 | 15 | printf "Model uploaded to s3://${MODEL_BUCKET}/Stable-diffusion/${MODEL_NAME}\n" 16 | -------------------------------------------------------------------------------- /docs/Gemfile: -------------------------------------------------------------------------------- 1 | source 'https://rubygems.org' 2 | 3 | gem "jekyll", "~> 4.3.3" # installed by `gem jekyll` 4 | # gem "webrick" # required when using Ruby >= 3 and Jekyll <= 4.2.2 5 | 6 | gem "just-the-docs", "0.8.2" # pinned to the current release 7 | # gem "just-the-docs" # always download the latest release -------------------------------------------------------------------------------- /docs/Gemfile.lock: -------------------------------------------------------------------------------- 1 | GEM 2 | remote: https://rubygems.org/ 3 | specs: 4 | addressable (2.8.7) 5 | public_suffix (>= 2.0.2, < 7.0) 6 | colorator (1.1.0) 7 | concurrent-ruby (1.3.4) 8 | em-websocket (0.5.3) 9 | eventmachine (>= 0.12.9) 10 | http_parser.rb (~> 0) 11 | eventmachine (1.2.7) 12 | ffi (1.17.0) 13 | forwardable-extended (2.6.0) 14 | google-protobuf (3.25.5) 15 | google-protobuf (3.25.5-arm64-darwin) 16 | google-protobuf (3.25.5-x86_64-darwin) 17 | google-protobuf (3.25.5-x86_64-linux) 18 | http_parser.rb (0.8.0) 19 | i18n (1.14.6) 20 | concurrent-ruby (~> 1.0) 21 | jekyll (4.3.4) 22 | addressable (~> 2.4) 23 | colorator (~> 1.0) 24 | em-websocket (~> 0.5) 25 | i18n (~> 1.0) 26 | jekyll-sass-converter (>= 2.0, < 4.0) 27 | jekyll-watch (~> 2.0) 28 | kramdown (~> 2.3, >= 2.3.1) 29 | kramdown-parser-gfm (~> 1.0) 30 | liquid (~> 4.0) 31 | mercenary (>= 0.3.6, < 0.5) 32 | pathutil (~> 0.9) 33 | rouge (>= 3.0, < 5.0) 34 | safe_yaml (~> 1.0) 35 | terminal-table (>= 1.8, < 4.0) 36 | webrick (~> 1.7) 37 | jekyll-include-cache (0.2.1) 38 | jekyll (>= 3.7, < 5.0) 39 | jekyll-sass-converter (3.0.0) 40 | sass-embedded (~> 1.54) 41 | jekyll-seo-tag (2.8.0) 42 | jekyll (>= 3.8, < 5.0) 43 | jekyll-watch (2.2.1) 44 | listen (~> 3.0) 45 | just-the-docs (0.8.2) 46 | jekyll (>= 3.8.5) 47 | jekyll-include-cache 48 | jekyll-seo-tag (>= 2.0) 49 | rake (>= 12.3.1) 50 | kramdown (2.4.0) 51 | rexml 52 | kramdown-parser-gfm (1.1.0) 53 | kramdown (~> 2.0) 54 | liquid (4.0.4) 55 | listen (3.9.0) 56 | rb-fsevent (~> 0.10, >= 0.10.3) 57 | rb-inotify (~> 0.9, >= 0.9.10) 58 | mercenary (0.4.0) 59 | pathutil (0.16.2) 60 | 
forwardable-extended (~> 2.6) 61 | public_suffix (6.0.1) 62 | rake (13.2.1) 63 | rb-fsevent (0.11.2) 64 | rb-inotify (0.11.1) 65 | ffi (~> 1.0) 66 | rexml (3.3.9) 67 | rouge (4.4.0) 68 | safe_yaml (1.0.5) 69 | sass-embedded (1.69.5) 70 | google-protobuf (~> 3.23) 71 | rake (>= 13.0.0) 72 | terminal-table (3.0.2) 73 | unicode-display_width (>= 1.1.1, < 3) 74 | unicode-display_width (2.6.0) 75 | webrick (1.9.0) 76 | 77 | PLATFORMS 78 | arm64-darwin 79 | ruby 80 | x86_64-darwin 81 | x86_64-linux 82 | 83 | DEPENDENCIES 84 | jekyll (~> 4.3.3) 85 | just-the-docs (= 0.8.2) 86 | 87 | BUNDLED WITH 88 | 2.3.5 89 | -------------------------------------------------------------------------------- /docs/README.md: -------------------------------------------------------------------------------- 1 | # Docs for Guidance for Asynchronous Inference with Stable Diffusion on AWS 2 | 3 | This folder contains the documentation code for the guidance. 4 | 5 | ## Preview in your environment 6 | 7 | The doc uses [Jekyll](https://jekyllrb.com/) with [Just the Docs](https://just-the-docs.com/) theme. 8 | 9 | Assuming [Jekyll](https://jekyllrb.com/) and [Bundler](https://bundler.io/) are installed on your computer, you can preview the doc by the following: 10 | 11 | 1. Run `bundle install`. 12 | 2. Change your working directory to language variant (`zh` for Simplified Chinese, and `en` for English) 13 | 3. Run `bundle exec jekyll serve` to build and preview the doc at `localhost:4000`. -------------------------------------------------------------------------------- /docs/api/v1alpha1.yaml: -------------------------------------------------------------------------------- 1 | openapi: 3.0.3 2 | info: 3 | title: SD on EKS 4 | version: v1alpha1 5 | x-logo: 6 | url: "" 7 | servers: 8 | - url: http://localhost:8080/v1alpha1 9 | description: "" 10 | paths: 11 | /: 12 | summary: Submit task to SD on EKS runtimes 13 | post: 14 | requestBody: 15 | description: "Content of image generating task" 16 | content: 17 | application/json: 18 | schema: 19 | $ref: '#/components/schemas/Task' 20 | required: true 21 | operationId: SubmitTask 22 | responses: 23 | '200': 24 | description: Successful response, task is accepted 25 | content: 26 | application/json: 27 | schema: 28 | title: Sample 29 | type: object 30 | properties: 31 | id_task: 32 | type: string 33 | description: Task ID 34 | example: "abc" 35 | task: 36 | description: Type of task. 37 | example: "text-to-image" 38 | type: string 39 | sd_model_checkpoint: 40 | type: string 41 | description: Selected model 42 | example: "sd_xl_turbo_1.0.safetensors" 43 | output_location: 44 | type: string 45 | description: Location of output file 46 | example: "s3://sdoneksstack-outputs3bucket/abc" 47 | security: 48 | - apikey: [] 49 | externalDocs: 50 | description: Usage introduction 51 | url: >- 52 | https://aws-samples.github.io/stable-diffusion-on-eks/zh/implementation-guide/usage/ 53 | tags: [] 54 | security: [] 55 | components: 56 | schemas: 57 | Task: 58 | type: object 59 | properties: 60 | task: 61 | type: object 62 | properties: 63 | alwayson_scripts: 64 | type: object 65 | properties: 66 | task: 67 | description: Type of task. 
Should be "text-to-image" or "image-to-image" for SD Web UI 68 | enum: 69 | - text-to-image 70 | - image-to-image 71 | example: "text-to-image" 72 | type: string 73 | sd_model_checkpoint: 74 | description: Checkpoint used in the task 75 | example: "sd_xl_turbo_1.0.safetensors" 76 | type: string 77 | id_task: 78 | description: Task ID 79 | example: "sd_xl_turbo_1.0.safetensors" 80 | type: string 81 | uid: 82 | type: string 83 | description: User ID 84 | example: "abc" 85 | save_dir: 86 | type: string 87 | description: Directory stored in s3 bucket, should be "output" only 88 | example: output 89 | prompt: 90 | type: string 91 | example: "a dog" 92 | steps: 93 | type: integer 94 | example: 16 95 | height: 96 | type: integer 97 | example: 512 98 | width: 99 | type: integer 100 | example: 512 101 | securitySchemes: 102 | apikey: 103 | type: apiKey 104 | description: API Key generated by AWS API Gateway 105 | name: x-api-key 106 | in: header -------------------------------------------------------------------------------- /docs/api/v1alpha2.yaml: -------------------------------------------------------------------------------- 1 | openapi: 3.0.3 2 | info: 3 | title: SD on EKS 4 | version: v1alpha2 5 | x-logo: 6 | url: "" 7 | servers: 8 | - url: http://localhost:8080/ 9 | description: "" 10 | paths: 11 | /task: 12 | summary: Submit task to SD on EKS runtimes 13 | post: 14 | requestBody: 15 | description: "Content of image generating task" 16 | content: 17 | application/json: 18 | schema: 19 | $ref: '#/components/schemas/Task' 20 | required: true 21 | operationId: SubmitTask 22 | responses: 23 | '200': 24 | description: Successful response, task is accepted 25 | content: 26 | application/json: 27 | schema: 28 | title: Sample 29 | type: object 30 | properties: 31 | id: 32 | type: string 33 | description: Task ID 34 | example: "abc" 35 | runtime: 36 | type: string 37 | description: Selected runtime 38 | example: "sdruntime1" 39 | output_location: 40 | type: string 41 | description: Location of output file 42 | example: "s3://sdoneksstack-outputs3bucket/abc" 43 | security: 44 | - apikey: [] 45 | externalDocs: 46 | description: Usage introduction 47 | url: >- 48 | https://aws-samples.github.io/stable-diffusion-on-eks/zh/implementation-guide/usage/ 49 | tags: [] 50 | security: [] 51 | components: 52 | schemas: 53 | Task: 54 | type: object 55 | properties: 56 | task: 57 | type: object 58 | properties: 59 | metadata: 60 | type: object 61 | properties: 62 | id: 63 | type: string 64 | description: Task ID 65 | example: "abc" 66 | runtime: 67 | type: string 68 | description: Runtime used in this task, should match with installed runtime. 69 | example: sdruntime1 70 | tasktype: 71 | description: Type of task. 
Should be "text-to-image" or "image-to-image" for SD Web UI, and "pipeline" for ComfyUI 72 | enum: 73 | - text-to-image 74 | - image-to-image 75 | - pipeline 76 | example: "text-to-image" 77 | type: string 78 | prefix: 79 | type: string 80 | description: Prefix for output file store 81 | example: "/output" 82 | context: 83 | type: object 84 | description: All content in context will be brought to callback 85 | example: "" 86 | content: 87 | type: object 88 | example: "" 89 | securitySchemes: 90 | apikey: 91 | type: apiKey 92 | description: API Key generated by AWS API Gateway 93 | name: x-api-key 94 | in: header 95 | -------------------------------------------------------------------------------- /docs/en/_config.yml: -------------------------------------------------------------------------------- 1 | theme: just-the-docs 2 | 3 | callouts: 4 | highlight: 5 | color: yellow 6 | note: 7 | color: purple 8 | new: 9 | color: green 10 | warning: 11 | color: red -------------------------------------------------------------------------------- /docs/en/_includes/image.html: -------------------------------------------------------------------------------- 1 |
-------------------------------------------------------------------------------- /docs/en/_config.yml: --------------------------------------------------------------------------------
1 | theme: just-the-docs
2 |
3 | callouts:
4 | highlight:
5 | color: yellow
6 | note:
7 | color: purple
8 | new:
9 | color: green
10 | warning:
11 | color: red -------------------------------------------------------------------------------- /docs/en/_includes/image.html: --------------------------------------------------------------------------------
1 |
2 | {{ include.alt }}
3 |
-------------------------------------------------------------------------------- /docs/en/async_img_sd_images/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aws-solutions-library-samples/guidance-for-asynchronous-inference-with-stable-diffusion-on-aws/efc4a0aef8367cb49944b3fbca1d72e70d27d457/docs/en/async_img_sd_images/.gitkeep -------------------------------------------------------------------------------- /docs/en/async_img_sd_images/asynchronous-image-generation-with-stable-diffusion-on-aws.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aws-solutions-library-samples/guidance-for-asynchronous-inference-with-stable-diffusion-on-aws/efc4a0aef8367cb49944b3fbca1d72e70d27d457/docs/en/async_img_sd_images/asynchronous-image-generation-with-stable-diffusion-on-aws.png -------------------------------------------------------------------------------- /docs/en/async_img_sd_sidebar.yml: -------------------------------------------------------------------------------- 1 | entries: 2 | - title: Guidance for Asynchronous Image Generation with Stable Diffusion on AWS 3 | folders: 4 | - title: Overview 5 | url: '#overview' 6 | output: web 7 | 8 | - title: Architecture Overview 9 | url: '#architecture-overview' 10 | output: web 11 | 12 | - title: Cost 13 | url: '#cost' 14 | output: web 15 | 16 | - title: Security 17 | url: '#security' 18 | output: web 19 | 20 | - title: Deploy the Guidance 21 | url: '#deploy-the-guidance' 22 | output: web 23 | 24 | - title: Usage Guide 25 | url: '#usage-guide' 26 | output: web 27 | 28 | - title: Uninstall the guidance 29 | url: '#uninstall-the-guidance' 30 | output: web 31 | 32 | - title: Related resources 33 | url: '#related-resources' 34 | output: web 35 | 36 | - title: Contributors 37 | url: '#contributors' 38 | output: web 39 | 40 | - title: Notices 41 | url: '#notices' 42 | output: web 43 | -------------------------------------------------------------------------------- /docs/stable-diffusion-reference-architecture-updated.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aws-solutions-library-samples/guidance-for-asynchronous-inference-with-stable-diffusion-on-aws/efc4a0aef8367cb49944b3fbca1d72e70d27d457/docs/stable-diffusion-reference-architecture-updated.png -------------------------------------------------------------------------------- /docs/zh/_config.yml: -------------------------------------------------------------------------------- 1 | theme: just-the-docs 2 | 3 | callouts: 4 | highlight: 5 | color: yellow 6 | note: 7 | color: purple 8 | new: 9 | color: green 10 | warning: 11 | color: red -------------------------------------------------------------------------------- /docs/zh/_includes/image.html: -------------------------------------------------------------------------------- 1 |
2 | {{ include.alt }} 3 |
-------------------------------------------------------------------------------- /docs/zh/async_img_sd_zh_images/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aws-solutions-library-samples/guidance-for-asynchronous-inference-with-stable-diffusion-on-aws/efc4a0aef8367cb49944b3fbca1d72e70d27d457/docs/zh/async_img_sd_zh_images/.gitkeep -------------------------------------------------------------------------------- /docs/zh/async_img_sd_zh_images/stable-diffusion-reference-architecture-updated.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aws-solutions-library-samples/guidance-for-asynchronous-inference-with-stable-diffusion-on-aws/efc4a0aef8367cb49944b3fbca1d72e70d27d457/docs/zh/async_img_sd_zh_images/stable-diffusion-reference-architecture-updated.png -------------------------------------------------------------------------------- /docs/zh/async_img_sd_zh_sidebar.yml: -------------------------------------------------------------------------------- 1 | entries: 2 | - title: Guidance for Asynchronous Image Generation with Stable Diffusion on AWS 3 | folders: 4 | - title: 简介 5 | url: '#简介' 6 | output: web 7 | 8 | - title: 架构概览 9 | url: '#架构概览' 10 | output: web 11 | 12 | - title: 费用预估 13 | url: '#费用预估' 14 | output: web 15 | 16 | - title: 部署解决方案 17 | url: '#部署解决方案' 18 | output: web 19 | 20 | - title: 使用指南 21 | url: '#使用指南' 22 | output: web 23 | 24 | - title: 删除解决方案 25 | url: '#删除解决方案' 26 | output: web 27 | -------------------------------------------------------------------------------- /lib/addons/dcgmExporter.ts: -------------------------------------------------------------------------------- 1 | import * as blueprints from '@aws-quickstart/eks-blueprints'; 2 | import { Construct } from "constructs"; 3 | 4 | export interface dcgmExporterAddOnProps extends blueprints.addons.HelmAddOnUserProps { 5 | } 6 | 7 | export const defaultProps: blueprints.addons.HelmAddOnProps & dcgmExporterAddOnProps = { 8 | chart: 'dcgm-exporter', 9 | name: 'dcgmExporterAddOn', 10 | namespace: 'kube-system', 11 | release: 'dcgm', 12 | version: '3.4.0', 13 | repository: 'https://nvidia.github.io/dcgm-exporter/helm-charts', 14 | values: { 15 | serviceMonitor: { 16 | enabled: false 17 | }, 18 | affinity: { 19 | nodeAffinity: { 20 | requiredDuringSchedulingIgnoredDuringExecution: { 21 | nodeSelectorTerms: [{ 22 | matchExpressions: [{ 23 | key: "karpenter.k8s.aws/instance-gpu-count", 24 | operator: "Exists" 25 | }] 26 | }] 27 | } 28 | } 29 | }, 30 | tolerations: [{ 31 | key: "nvidia.com/gpu", 32 | operator: "Exists", 33 | effect: "NoSchedule" 34 | }, { 35 | key: "runtime", 36 | operator: "Exists", 37 | effect: "NoSchedule" 38 | }] 39 | } 40 | } 41 | 42 | export class dcgmExporterAddOn extends blueprints.addons.HelmAddOn { 43 | 44 | readonly options: dcgmExporterAddOnProps; 45 | 46 | constructor(props: dcgmExporterAddOnProps) { 47 | super({ ...defaultProps, ...props }); 48 | this.options = this.props as dcgmExporterAddOnProps; 49 | } 50 | 51 | deploy(clusterInfo: blueprints.ClusterInfo): Promise { 52 | const cluster = clusterInfo.cluster; 53 | 54 | const chart = this.addHelmChart(clusterInfo, this.options.values, true, true); 55 | return Promise.resolve(chart); 56 | } 57 | } -------------------------------------------------------------------------------- /lib/addons/ebsThroughputTuner.ts: -------------------------------------------------------------------------------- 1 | 
import { ClusterAddOn, ClusterInfo } from '@aws-quickstart/eks-blueprints';
2 | import { Construct } from "constructs";
3 | import * as cdk from 'aws-cdk-lib';
4 | import * as lambda from 'aws-cdk-lib/aws-lambda';
5 | import * as iam from "aws-cdk-lib/aws-iam"
6 | import * as path from 'path';
7 | import * as sfn from 'aws-cdk-lib/aws-stepfunctions';
8 | import * as tasks from 'aws-cdk-lib/aws-stepfunctions-tasks';
9 | import * as events from 'aws-cdk-lib/aws-events';
10 | import * as targets from 'aws-cdk-lib/aws-events-targets';
11 |
12 | export interface EbsThroughputTunerAddOnProps {
13 | duration: number;
14 | throughput: number;
15 | iops: number;
16 | }
17 |
18 | export class EbsThroughputTunerAddOn implements ClusterAddOn {
19 |
20 | readonly options: EbsThroughputTunerAddOnProps;
21 |
22 | constructor(props: EbsThroughputTunerAddOnProps) {
23 | this.options = props
24 | }
25 |
26 | deploy(clusterInfo: ClusterInfo): Promise<Construct> {
27 | const cluster = clusterInfo.cluster;
28 |
29 | const lambdaTimeout: number = 300
30 |
31 | // Lambda function that modifies EBS volume throughput and IOPS
32 | const lambdaFunction = new lambda.Function(cluster.stack, 'EbsThroughputTunerLambda', {
33 | code: lambda.Code.fromAsset(path.join(__dirname, '../../src/tools/ebs_throughput_tuner')),
34 | handler: 'app.lambda_handler',
35 | runtime: lambda.Runtime.PYTHON_3_11,
36 | timeout: cdk.Duration.seconds(lambdaTimeout),
37 | environment: {
38 | "TARGET_EC2_TAG_KEY": "stack",
39 | "TARGET_EC2_TAG_VALUE": cdk.Aws.STACK_NAME,
40 | "THROUGHPUT_VALUE": this.options.throughput.toString(),
41 | "IOPS_VALUE": this.options.iops.toString()
42 | },
43 | });
44 |
45 | lambdaFunction.role!.addManagedPolicy(
46 | iam.ManagedPolicy.fromAwsManagedPolicyName(
47 | 'AmazonEC2FullAccess',
48 | ))
49 |
50 | // Step Functions definition
51 | const waitTask = new sfn.Wait(cluster.stack, 'Wait time', {
52 | time: sfn.WaitTime.duration(cdk.Duration.seconds(this.options.duration)),
53 | });
54 |
55 | const triggerTask = new tasks.LambdaInvoke(cluster.stack, 'Change throughput', {
56 | lambdaFunction: lambdaFunction
57 | }).addRetry({
58 | backoffRate: 2,
59 | maxAttempts: 3,
60 | interval: cdk.Duration.seconds(5)
61 | })
62 |
63 | const stateDefinition = waitTask
64 | .next(triggerTask)
65 |
66 | const stateMachine = new sfn.StateMachine(cluster.stack, 'EbsThroughputTunerStateMachine', {
67 | definitionBody: sfn.DefinitionBody.fromChainable(stateDefinition),
68 | timeout: cdk.Duration.seconds(this.options.duration + lambdaTimeout + 30),
69 | });
70 |
71 | lambdaFunction.grantInvoke(stateMachine)
72 |
73 | const rule = new events.Rule(cluster.stack, 'EbsThroughputTunerRule', {
74 | eventPattern: {
75 | detail: {
76 | 'state': events.Match.equalsIgnoreCase("running")
77 | },
78 | detailType: events.Match.equalsIgnoreCase('EC2 Instance State-change Notification'),
79 | source: ['aws.ec2'],
80 | }
81 | });
82 |
83 | rule.addTarget(new targets.SfnStateMachine(stateMachine))
84 |
85 | return Promise.resolve(rule);
86 | }
87 | }
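The Lambda body referenced above lives in src/tools/ebs_throughput_tuner/app.py and is not reproduced in this section. As a rough sketch of its effect (hypothetical commands, for illustration only; the tag key/value and numbers come from the environment variables above and the add-on defaults wired up in lib/dataPlane.ts):

    # Find instances created by this stack, via the TARGET_EC2_TAG_KEY/VALUE tag.
    aws ec2 describe-instances \
        --filters "Name=tag:stack,Values=<stack-name>" \
        --query 'Reservations[].Instances[].InstanceId' --output text

    # Reset an attached gp3 volume to the configured THROUGHPUT_VALUE/IOPS_VALUE baseline.
    aws ec2 modify-volume --volume-id <volume-id> --throughput 125 --iops 3000

The EventBridge rule fires when an instance enters the running state, and the Step Functions wait delays the change, so nodes that presumably launch with boosted volume settings are dialed back to the cheaper baseline a few minutes later.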
-------------------------------------------------------------------------------- /lib/addons/s3CSIDriver.ts: --------------------------------------------------------------------------------
1 | import * as blueprints from '@aws-quickstart/eks-blueprints';
2 | import * as eks from 'aws-cdk-lib/aws-eks';
3 | import * as iam from 'aws-cdk-lib/aws-iam';
4 | import { Construct } from "constructs";
5 |
6 | export interface s3CSIDriverAddOnProps extends blueprints.addons.HelmAddOnUserProps {
7 | s3BucketArn: string;
8 | }
9 |
10 | export const defaultProps: blueprints.addons.HelmAddOnProps & s3CSIDriverAddOnProps = {
11 | chart: 'aws-mountpoint-s3-csi-driver',
12 | name: 's3CSIDriverAddOn',
13 | namespace: 'kube-system',
14 | release: 's3-csi-driver-release',
15 | version: 'v1.10.0',
16 | repository: 'https://awslabs.github.io/mountpoint-s3-csi-driver',
17 | s3BucketArn: ''
18 | }
19 |
20 | export class s3CSIDriverAddOn extends blueprints.addons.HelmAddOn {
21 |
22 | readonly options: s3CSIDriverAddOnProps;
23 |
24 | constructor(props: s3CSIDriverAddOnProps) {
25 | super({ ...defaultProps, ...props });
26 | this.options = this.props as s3CSIDriverAddOnProps;
27 | }
28 |
29 | deploy(clusterInfo: blueprints.ClusterInfo): Promise<Construct> {
30 | const cluster = clusterInfo.cluster;
31 | const serviceAccount = cluster.addServiceAccount('s3-csi-driver-sa', {
32 | name: 's3-csi-driver-sa',
33 | namespace: this.options.namespace
34 | });
35 |
36 | // New IAM policy to grant access to the S3 bucket
37 | // https://github.com/awslabs/mountpoint-s3/blob/main/doc/CONFIGURATION.md#iam-permissions
38 | const s3BucketPolicy = new iam.Policy(cluster, 's3-csi-driver-policy', {
39 | statements: [
40 | new iam.PolicyStatement({
41 | sid: 'MountpointFullBucketAccess',
42 | actions: [
43 | "s3:ListBucket"
44 | ],
45 | resources: [this.options.s3BucketArn]
46 | }),
47 | new iam.PolicyStatement({
48 | sid: 'MountpointFullObjectAccess',
49 | actions: [
50 | "s3:GetObject",
51 | "s3:PutObject",
52 | "s3:AbortMultipartUpload",
53 | "s3:DeleteObject"
54 | ],
55 | resources: [`${this.options.s3BucketArn}/*`]
56 | })]
57 | });
58 | serviceAccount.role.attachInlinePolicy(s3BucketPolicy);
59 |
60 | const chart = this.addHelmChart(clusterInfo, {
61 | node: {
62 | serviceAccount: {
63 | create: false
64 | },
65 | tolerateAllTaints: true
66 | }
67 | }, true, true);
68 | return Promise.resolve(chart);
69 | }
70 | } -------------------------------------------------------------------------------- /lib/addons/sharedComponent.ts: --------------------------------------------------------------------------------
1 | import { ClusterAddOn, ClusterInfo } from '@aws-quickstart/eks-blueprints';
2 | import { Construct } from "constructs";
3 | import * as cdk from 'aws-cdk-lib';
4 | import * as sns from 'aws-cdk-lib/aws-sns';
5 | import * as s3 from 'aws-cdk-lib/aws-s3';
6 | import * as lambda from 'aws-cdk-lib/aws-lambda';
7 | import * as apigw from "aws-cdk-lib/aws-apigateway";
8 | import * as xray from "aws-cdk-lib/aws-xray"
9 | import * as path from 'path';
10 |
11 | export interface SharedComponentAddOnProps {
12 | inputSns: sns.ITopic;
13 | outputSns: sns.ITopic;
14 | outputBucket: s3.IBucket;
15 | apiGWProps?: {
16 | stageName?: string,
17 | throttle?: {
18 | rateLimit?: number,
19 | burstLimit?: number
20 | }
21 | }
22 | }
23 |
24 | export const defaultProps: SharedComponentAddOnProps = {
25 | inputSns: undefined!,
26 | outputSns: undefined!,
27 | outputBucket: undefined!,
28 | apiGWProps: {
29 | stageName: "prod",
30 | throttle: {
31 | rateLimit: 10,
32 | burstLimit: 2
33 | }
34 | }
35 | }
36 |
37 | export class SharedComponentAddOn implements ClusterAddOn {
38 | readonly options: SharedComponentAddOnProps;
39 |
40 | constructor(props: SharedComponentAddOnProps) {
41 | this.options = { ...defaultProps, ...props };
42 | }
43 |
44 | deploy(clusterInfo: ClusterInfo): Promise<Construct> {
45 | const cluster = clusterInfo.cluster;
46 |
47 | const v1Alpha1Parser = new lambda.Function(cluster.stack, 'v1Alpha1ParserFunction', {
48 | code:
lambda.Code.fromAsset(path.join(__dirname, '../../src/frontend/input_function/v1alpha1')), 49 | handler: 'app.lambda_handler', 50 | runtime: lambda.Runtime.PYTHON_3_11, 51 | environment: { 52 | "SNS_TOPIC_ARN": this.options.inputSns.topicArn, 53 | "S3_OUTPUT_BUCKET": this.options.outputBucket.bucketName 54 | }, 55 | tracing: lambda.Tracing.ACTIVE 56 | }); 57 | 58 | const v1Alpha2Parser = new lambda.Function(cluster.stack, 'v1Alpha2ParserFunction', { 59 | code: lambda.Code.fromAsset(path.join(__dirname, '../../src/frontend/input_function/v1alpha2')), 60 | handler: 'app.lambda_handler', 61 | runtime: lambda.Runtime.PYTHON_3_11, 62 | environment: { 63 | "SNS_TOPIC_ARN": this.options.inputSns.topicArn, 64 | "S3_OUTPUT_BUCKET": this.options.outputBucket.bucketName 65 | }, 66 | tracing: lambda.Tracing.ACTIVE 67 | }); 68 | 69 | this.options.inputSns.grantPublish(v1Alpha1Parser); 70 | this.options.inputSns.grantPublish(v1Alpha2Parser); 71 | 72 | const api = new apigw.RestApi(cluster.stack, 'FrontAPI', { 73 | restApiName: 'FrontAPI', 74 | deploy: true, 75 | cloudWatchRole: true, 76 | endpointConfiguration: { 77 | types: [ apigw.EndpointType.REGIONAL ] 78 | }, 79 | defaultMethodOptions: { 80 | apiKeyRequired: true 81 | }, 82 | deployOptions: { 83 | stageName: this.options.apiGWProps!.stageName, 84 | tracingEnabled: true, 85 | metricsEnabled: true, 86 | } 87 | }); 88 | 89 | const v1alpha1Resource = api.root.addResource('v1alpha1'); 90 | v1alpha1Resource.addMethod('POST', new apigw.LambdaIntegration(v1Alpha1Parser), { apiKeyRequired: true }); 91 | 92 | const v1alpha2Resource = api.root.addResource('v1alpha2'); 93 | v1alpha2Resource.addMethod('POST', new apigw.LambdaIntegration(v1Alpha2Parser), { apiKeyRequired: true }); 94 | 95 | api.node.addDependency(v1Alpha1Parser); 96 | api.node.addDependency(v1Alpha2Parser); 97 | 98 | //Force override name of generated output to provide a static name 99 | const urlCfnOutput = api.node.findChild('Endpoint') as cdk.CfnOutput; 100 | urlCfnOutput.overrideLogicalId('FrontApiEndpoint') 101 | 102 | const apiKey = new apigw.ApiKey(cluster.stack, `defaultAPIKey`, { 103 | description: `Default API Key`, 104 | enabled: true 105 | }) 106 | 107 | const plan = api.addUsagePlan('UsagePlan', { 108 | name: cluster.stack.stackId + '-Default', 109 | apiStages: [{ 110 | stage: api.deploymentStage 111 | }], 112 | throttle: { 113 | rateLimit: this.options.apiGWProps!.throttle!.rateLimit, 114 | burstLimit: this.options.apiGWProps!.throttle!.burstLimit 115 | } 116 | }); 117 | 118 | plan.addApiKey(apiKey) 119 | 120 | new cdk.CfnOutput(cluster.stack, 'GetAPIKeyCommand', { 121 | value: "aws apigateway get-api-keys --query 'items[?id==`" + apiKey.keyId + "`].value' --include-values --output text", 122 | description: 'Command to get API Key' 123 | }); 124 | 125 | //Xray Access Policy 126 | new xray.CfnResourcePolicy(cluster.stack, cluster.stack.stackName+'XRayAccessPolicyForSNS', { 127 | policyName: cluster.stack.stackName+'XRayAccessPolicyForSNS', 128 | policyDocument: '{"Version":"2012-10-17","Statement":[{"Sid":"SNSAccess","Effect":"Allow","Principal":{"Service":"sns.amazonaws.com"},"Action":["xray:PutTraceSegments","xray:GetSamplingRules","xray:GetSamplingTargets"],"Resource":"*","Condition":{"StringEquals":{"aws:SourceAccount":"' + cluster.stack.account + '"},"StringLike":{"aws:SourceArn":"' + cluster.stack.formatArn({ service: "sns", resource: '*' }) + '"}}}]}' 129 | }) 130 | 131 | // Output S3 bucket ARN 132 | 133 | new cdk.CfnOutput(cluster.stack, 'OutputS3Bucket', { 134 | value: 
this.options.outputBucket.bucketArn, 135 | description: 'S3 bucket for generated images' 136 | }); 137 | 138 | return Promise.resolve(plan); 139 | } 140 | } 141 | -------------------------------------------------------------------------------- /lib/dataPlane.ts: -------------------------------------------------------------------------------- 1 | import * as cdk from 'aws-cdk-lib'; 2 | import * as blueprints from '@aws-quickstart/eks-blueprints'; 3 | import { Construct } from "constructs"; 4 | import * as eks from "aws-cdk-lib/aws-eks"; 5 | import * as sns from 'aws-cdk-lib/aws-sns'; 6 | import * as ec2 from 'aws-cdk-lib/aws-ec2'; 7 | import * as s3 from 'aws-cdk-lib/aws-s3'; 8 | import * as iam from 'aws-cdk-lib/aws-iam'; 9 | import SDRuntimeAddon, { SDRuntimeAddOnProps } from './runtime/sdRuntime'; 10 | import { EbsThroughputTunerAddOn, EbsThroughputTunerAddOnProps } from './addons/ebsThroughputTuner' 11 | import { s3CSIDriverAddOn, s3CSIDriverAddOnProps } from './addons/s3CSIDriver' 12 | import { SharedComponentAddOn, SharedComponentAddOnProps } from './addons/sharedComponent'; 13 | import { SNSResourceProvider } from './resourceProvider/sns' 14 | import { s3GWEndpointProvider } from './resourceProvider/s3GWEndpoint' 15 | import { dcgmExporterAddOn } from './addons/dcgmExporter'; 16 | 17 | export interface dataPlaneProps { 18 | stackName: string, 19 | modelBucketArn: string; 20 | APIGW?: { 21 | stageName?: string, 22 | throttle?: { 23 | rateLimit?: number, 24 | burstLimit?: number 25 | } 26 | } 27 | modelsRuntime: { 28 | name: string, 29 | namespace: string, 30 | type: string, 31 | modelFilename?: string, 32 | dynamicModel?: boolean, 33 | chartRepository?: string, 34 | chartVersion?: string, 35 | extraValues?: {} 36 | }[] 37 | } 38 | 39 | export default class DataPlaneStack { 40 | constructor(scope: Construct, id: string, 41 | dataplaneProps: dataPlaneProps, 42 | props: cdk.StackProps) { 43 | 44 | const kedaParams: blueprints.KedaAddOnProps = { 45 | podSecurityContextFsGroup: 1001, 46 | securityContextRunAsGroup: 1001, 47 | securityContextRunAsUser: 1001, 48 | irsaRoles: ["CloudWatchFullAccess", "AmazonSQSFullAccess"] 49 | }; 50 | 51 | const cloudWatchInsightsParams: blueprints.CloudWatchInsightsAddOnProps = { 52 | configurationValues: { 53 | tolerations: [ 54 | { 55 | key: "runtime", 56 | operator: "Exists", 57 | effect: "NoSchedule" 58 | }, 59 | { 60 | key: "nvidia.com/gpu", 61 | operator: "Exists", 62 | effect: "NoSchedule" 63 | } 64 | ], 65 | containerLogs: { 66 | enabled: true, 67 | fluentBit: { 68 | config: { 69 | service: "[SERVICE]\n Flush 5\n Grace 30\n Log_Level info", 70 | extraFiles: { 71 | "application-log.conf": "[INPUT]\n Name tail\n Tag kube.*\n Path /var/log/containers/*.log\n Parser docker\n DB /var/log/flb_kube.db\n Mem_Buf_Limit 5MB\n Skip_Long_Lines On\n Refresh_Interval 10\n\n[FILTER]\n Name kubernetes\n Match kube.*\n Kube_URL https://kubernetes.default.svc:443\n Kube_CA_File /var/run/secrets/kubernetes.io/serviceaccount/ca.crt\n Kube_Token_File /var/run/secrets/kubernetes.io/serviceaccount/token\n Kube_Tag_Prefix kube.var.log.containers.\n Merge_Log On\n Merge_Log_Key log_processed\n K8S-Logging.Parser On\n K8S-Logging.Exclude On\n\n[FILTER]\n Name grep\n Match kube.*\n Exclude $kubernetes['namespace_name'] kube-system\n\n[OUTPUT]\n Name cloudwatch\n Match kube.*\n region ${AWS_REGION}\n log_group_name /aws/containerinsights/${CLUSTER_NAME}/application\n log_stream_prefix ${HOST_NAME}-\n auto_create_group true\n retention_in_days 7" 72 | } 73 | } 74 | } 75 | } 76 
| } 77 | }; 78 | 79 | const SharedComponentAddOnParams: SharedComponentAddOnProps = { 80 | inputSns: blueprints.getNamedResource("inputSNSTopic"), 81 | outputSns: blueprints.getNamedResource("outputSNSTopic"), 82 | outputBucket: blueprints.getNamedResource("outputS3Bucket"), 83 | apiGWProps: dataplaneProps.APIGW 84 | }; 85 | 86 | const EbsThroughputModifyAddOnParams: EbsThroughputTunerAddOnProps = { 87 | duration: 300, 88 | throughput: 125, 89 | iops: 3000 90 | }; 91 | 92 | const s3CSIDriverAddOnParams: s3CSIDriverAddOnProps = { 93 | s3BucketArn: dataplaneProps.modelBucketArn 94 | }; 95 | 96 | const addOns: Array = [ 97 | new blueprints.addons.VpcCniAddOn(), 98 | new blueprints.addons.CoreDnsAddOn(), 99 | new blueprints.addons.KubeProxyAddOn(), 100 | new blueprints.addons.AwsLoadBalancerControllerAddOn(), 101 | new blueprints.addons.KarpenterAddOn({ interruptionHandling: true }), 102 | new blueprints.addons.KedaAddOn(kedaParams), 103 | new blueprints.addons.CloudWatchInsights(cloudWatchInsightsParams), 104 | new s3CSIDriverAddOn(s3CSIDriverAddOnParams), 105 | new SharedComponentAddOn(SharedComponentAddOnParams), 106 | new EbsThroughputTunerAddOn(EbsThroughputModifyAddOnParams), 107 | ]; 108 | 109 | // Generate SD Runtime Addon for runtime 110 | dataplaneProps.modelsRuntime.forEach((val, idx, array) => { 111 | const sdRuntimeParams: SDRuntimeAddOnProps = { 112 | modelBucketArn: dataplaneProps.modelBucketArn, 113 | outputSns: blueprints.getNamedResource("outputSNSTopic") as sns.ITopic, 114 | inputSns: blueprints.getNamedResource("inputSNSTopic") as sns.ITopic, 115 | outputBucket: blueprints.getNamedResource("outputS3Bucket") as s3.IBucket, 116 | type: val.type.toLowerCase(), 117 | chartRepository: val.chartRepository, 118 | chartVersion: val.chartVersion, 119 | extraValues: val.extraValues, 120 | targetNamespace: val.namespace, 121 | }; 122 | 123 | //Parameters for SD Web UI 124 | if (val.type.toLowerCase() == "sdwebui") { 125 | if (val.modelFilename) { 126 | sdRuntimeParams.sdModelCheckpoint = val.modelFilename 127 | } 128 | if (val.dynamicModel == true) { 129 | sdRuntimeParams.dynamicModel = true 130 | } else { 131 | sdRuntimeParams.dynamicModel = false 132 | } 133 | } 134 | 135 | if (val.type.toLowerCase() == "comfyui") {} 136 | 137 | addOns.push(new SDRuntimeAddon(sdRuntimeParams, val.name)) 138 | }); 139 | 140 | // Define initial managed node group for cluster components 141 | const MngProps: blueprints.MngClusterProviderProps = { 142 | minSize: 2, 143 | maxSize: 2, 144 | desiredSize: 2, 145 | version: eks.KubernetesVersion.V1_31, 146 | instanceTypes: [new ec2.InstanceType('m7g.large')], 147 | amiType: eks.NodegroupAmiType.AL2023_ARM_64_STANDARD, 148 | enableSsmPermissions: true, 149 | nodeGroupTags: { 150 | "Name": cdk.Aws.STACK_NAME + "-ClusterComponents", 151 | "stack": cdk.Aws.STACK_NAME 152 | } 153 | } 154 | 155 | // Deploy EKS cluster with all add-ons 156 | const blueprint = blueprints.EksBlueprint.builder() 157 | .version(eks.KubernetesVersion.V1_31) 158 | .addOns(...addOns) 159 | .resourceProvider( 160 | blueprints.GlobalResources.Vpc, 161 | new blueprints.VpcProvider()) 162 | .resourceProvider("inputSNSTopic", new SNSResourceProvider("sdNotificationLambda")) 163 | .resourceProvider("outputSNSTopic", new SNSResourceProvider("sdNotificationOutput")) 164 | .resourceProvider("outputS3Bucket", new blueprints.CreateS3BucketProvider({ 165 | id: 'outputS3Bucket' 166 | })) 167 | .resourceProvider("s3GWEndpoint", new s3GWEndpointProvider("s3GWEndpoint")) 168 | .clusterProvider(new 
blueprints.MngClusterProvider(MngProps)) 169 | .build(scope, id + 'Stack', props); 170 | /* 171 | // Workaround for permission denied when creating cluster 172 | const handler = blueprint.node.tryFindChild('@aws-cdk--aws-eks.KubectlProvider')! 173 | .node.tryFindChild('Handler')! as cdk.aws_lambda.Function 174 | 175 | ( 176 | blueprint.node.tryFindChild('@aws-cdk--aws-eks.KubectlProvider')! 177 | .node.tryFindChild('Provider')! 178 | .node.tryFindChild('framework-onEvent')! 179 | .node.tryFindChild('ServiceRole')! 180 | .node.tryFindChild('DefaultPolicy') as cdk.aws_iam.Policy 181 | ) 182 | .addStatements(new cdk.aws_iam.PolicyStatement({ 183 | effect: cdk.aws_iam.Effect.ALLOW, 184 | actions: ["lambda:GetFunctionConfiguration"], 185 | resources: [handler.functionArn] 186 | })) 187 | */ 188 | // Provide static output name for cluster 189 | const cluster = blueprint.getClusterInfo().cluster 190 | const clusterNameCfnOutput = cluster.node.findChild('ClusterName') as cdk.CfnOutput; 191 | clusterNameCfnOutput.overrideLogicalId('ClusterName') 192 | 193 | const configCommandCfnOutput = cluster.node.findChild('ConfigCommand') as cdk.CfnOutput; 194 | configCommandCfnOutput.overrideLogicalId('ConfigCommand') 195 | 196 | const getTokenCommandCfnOutput = cluster.node.findChild('GetTokenCommand') as cdk.CfnOutput; 197 | getTokenCommandCfnOutput.overrideLogicalId('GetTokenCommand') 198 | } 199 | } -------------------------------------------------------------------------------- /lib/resourceProvider/s3GWEndpoint.ts: -------------------------------------------------------------------------------- 1 | import * as blueprints from '@aws-quickstart/eks-blueprints'; 2 | import * as ec2 from 'aws-cdk-lib/aws-ec2'; 3 | 4 | export class s3GWEndpointProvider implements blueprints.ResourceProvider { 5 | constructor(readonly name: string) { } 6 | 7 | provide(context: blueprints.ResourceContext): ec2.IGatewayVpcEndpoint { 8 | const vpc = context.get('vpc') as ec2.IVpc 9 | const vpce = new ec2.GatewayVpcEndpoint(context.scope, this.name, { 10 | service: ec2.GatewayVpcEndpointAwsService.S3, 11 | vpc: vpc 12 | }); 13 | 14 | return vpce 15 | } 16 | } -------------------------------------------------------------------------------- /lib/resourceProvider/sns.ts: -------------------------------------------------------------------------------- 1 | import * as cdk from 'aws-cdk-lib'; 2 | import * as blueprints from '@aws-quickstart/eks-blueprints'; 3 | import * as sns from 'aws-cdk-lib/aws-sns'; 4 | 5 | export class SNSResourceProvider implements blueprints.ResourceProvider { 6 | constructor(readonly topicName: string, readonly displayName?: string) { } 7 | 8 | provide(context: blueprints.ResourceContext): sns.ITopic { 9 | 10 | const cfnTopic = new cdk.aws_sns.CfnTopic(context.scope, this.topicName + 'Cfn', { 11 | displayName: this.displayName, 12 | tracingConfig: 'Active' 13 | }); 14 | 15 | new cdk.CfnOutput(context.scope, this.topicName + 'ARN', { 16 | value: cfnTopic.attrTopicArn 17 | }) 18 | 19 | return sns.Topic.fromTopicArn(context.scope, this.topicName, cfnTopic.attrTopicArn) 20 | } 21 | } -------------------------------------------------------------------------------- /lib/resourceProvider/vpc.ts: -------------------------------------------------------------------------------- 1 | import { Tags } from 'aws-cdk-lib'; 2 | import * as ec2 from 'aws-cdk-lib/aws-ec2'; 3 | import { ISubnet, PrivateSubnet } from 'aws-cdk-lib/aws-ec2'; 4 | import * as blueprints from '@aws-quickstart/eks-blueprints'; 5 | 6 | /** 7 | * 
Interface describing optional VPC settings: primary CIDR, secondary CIDR, and secondary subnet CIDRs.
8 | */
9 | interface VpcProps {
10 | primaryCidr?: string,
11 | secondaryCidr?: string,
12 | secondarySubnetCidrs?: string[]
13 | }
14 |
15 | /**
16 | * VPC resource provider
17 | */
18 | export class VpcProvider implements blueprints.ResourceProvider<ec2.IVpc> {
19 | readonly vpcId?: string;
20 | readonly primaryCidr?: string;
21 | readonly secondaryCidr?: string;
22 | readonly secondarySubnetCidrs?: string[];
23 |
24 | constructor(vpcId?: string, private vpcProps?: VpcProps) {
25 | this.vpcId = vpcId;
26 | this.primaryCidr = vpcProps?.primaryCidr;
27 | this.secondaryCidr = vpcProps?.secondaryCidr;
28 | this.secondarySubnetCidrs = vpcProps?.secondarySubnetCidrs;
29 | }
30 |
31 | provide(context: blueprints.ResourceContext): ec2.IVpc {
32 | const id = context.scope.node.id;
33 |
34 | let vpc = getVPCFromId(context, id, this.vpcId);
35 | if (vpc == null) {
36 | // It will automatically divide the provided VPC CIDR range, and create public and private subnets per Availability Zone.
37 | // If VPC CIDR range is not provided, uses `10.0.0.0/16` as the range and creates public and private subnets per Availability Zone.
38 | // Network routing for the public subnets will be configured to allow outbound access directly via an Internet Gateway.
39 | // Network routing for the private subnets will be configured to allow outbound access via a set of resilient NAT Gateways (one per AZ).
40 | // Creates Secondary CIDR and Secondary subnets if passed.
41 | if (this.primaryCidr) {
42 | vpc = new ec2.Vpc(context.scope, id + "-vpc",{
43 | ipAddresses: ec2.IpAddresses.cidr(this.primaryCidr)
44 | });
45 | }
46 | else {
47 | vpc = new ec2.Vpc(context.scope, id + "-vpc");
48 | }
49 | }
50 |
51 |
52 | if (this.secondaryCidr) {
53 | this.createSecondarySubnets(context, id, vpc);
54 | }
55 |
56 | return vpc;
57 | }
58 |
59 | protected createSecondarySubnets(context: blueprints.ResourceContext, id: string, vpc: ec2.IVpc) {
60 | const secondarySubnets: Array<PrivateSubnet> = [];
61 | const secondaryCidr = new ec2.CfnVPCCidrBlock(context.scope, id + "-secondaryCidr", {
62 | vpcId: vpc.vpcId,
63 | cidrBlock: this.secondaryCidr
64 | });
65 | secondaryCidr.node.addDependency(vpc);
66 | if (this.secondarySubnetCidrs) {
67 | for (let i = 0; i < vpc.availabilityZones.length; i++) {
68 | if (this.secondarySubnetCidrs[i]) {
69 | secondarySubnets[i] = new ec2.PrivateSubnet(context.scope, id + "private-subnet-" + i, {
70 | availabilityZone: vpc.availabilityZones[i],
71 | cidrBlock: this.secondarySubnetCidrs[i],
72 | vpcId: vpc.vpcId
73 | });
74 | secondarySubnets[i].node.addDependency(secondaryCidr);
75 | context.add("secondary-cidr-subnet-" + i, {
76 | provide(_context): ISubnet { return secondarySubnets[i]; }
77 | });
78 | }
79 | }
80 | for (let secondarySubnet of secondarySubnets) {
81 | Tags.of(secondarySubnet).add("kubernetes.io/role/internal-elb", "1", { applyToLaunchedInstances: true });
82 | Tags.of(secondarySubnet).add("Name", `blueprint-construct-dev-PrivateSubnet-${secondarySubnet}`, { applyToLaunchedInstances: true });
83 | }
84 | }
85 | }
86 | }
87 |
88 |
89 |
90 | /*
91 | ** Returns a VPC based on the ResourceContext and the vpcId passed to the cluster.
92 | */ 93 | export function getVPCFromId(context: blueprints.ResourceContext, nodeId: string, vpcId?: string) { 94 | let vpc = undefined; 95 | if (vpcId) { 96 | if (vpcId === "default") { 97 | console.log(`looking up completely default VPC`); 98 | vpc = ec2.Vpc.fromLookup(context.scope, nodeId + "-vpc", { isDefault: true }); 99 | } else { 100 | console.log(`looking up non-default ${vpcId} VPC`); 101 | vpc = ec2.Vpc.fromLookup(context.scope, nodeId + "-vpc", { vpcId: vpcId }); 102 | } 103 | } 104 | return vpc; 105 | } 106 | -------------------------------------------------------------------------------- /lib/runtime/sdRuntime.ts: -------------------------------------------------------------------------------- 1 | import * as cdk from 'aws-cdk-lib'; 2 | import { Construct } from 'constructs'; 3 | import * as sns from 'aws-cdk-lib/aws-sns'; 4 | import * as blueprints from '@aws-quickstart/eks-blueprints'; 5 | import * as iam from 'aws-cdk-lib/aws-iam'; 6 | import * as sqs from 'aws-cdk-lib/aws-sqs'; 7 | import * as s3 from 'aws-cdk-lib/aws-s3'; 8 | import { aws_sns_subscriptions } from "aws-cdk-lib"; 9 | import * as lodash from "lodash"; 10 | import { createNamespace } from "../utils/namespace" 11 | 12 | export interface SDRuntimeAddOnProps extends blueprints.addons.HelmAddOnUserProps { 13 | type: string, 14 | targetNamespace?: string, 15 | modelBucketArn?: string, 16 | outputSns?: sns.ITopic, 17 | inputSns?: sns.ITopic, 18 | outputBucket?: s3.IBucket 19 | sdModelCheckpoint?: string, 20 | dynamicModel?: boolean, 21 | chartRepository?: string, 22 | chartVersion?: string, 23 | extraValues?: object 24 | } 25 | 26 | export const defaultProps: blueprints.addons.HelmAddOnProps & SDRuntimeAddOnProps = { 27 | chart: 'sd-on-eks', 28 | name: 'sdRuntimeAddOn', 29 | namespace: 'sdruntime', 30 | release: 'sdruntime', 31 | version: '1.1.3', 32 | repository: 'oci://public.ecr.aws/bingjiao/charts/sd-on-eks', 33 | values: {}, 34 | type: "sdwebui" 35 | } 36 | 37 | export default class SDRuntimeAddon extends blueprints.addons.HelmAddOn { 38 | 39 | readonly options: SDRuntimeAddOnProps; 40 | readonly id: string; 41 | 42 | constructor(props: SDRuntimeAddOnProps, id?: string) { 43 | super({ ...defaultProps, ...props }); 44 | this.options = this.props as SDRuntimeAddOnProps; 45 | if (id) { 46 | this.id = id!.toLowerCase() 47 | } else { 48 | this.id = 'sdruntime' 49 | } 50 | } 51 | 52 | @blueprints.utils.dependable(blueprints.KarpenterAddOn.name) 53 | @blueprints.utils.dependable("SharedComponentAddOn") 54 | @blueprints.utils.dependable("s3CSIDriverAddOn") 55 | 56 | deploy(clusterInfo: blueprints.ClusterInfo): Promise { 57 | const cluster = clusterInfo.cluster 58 | 59 | this.props.name = this.id + 'Addon' 60 | this.props.release = this.id 61 | 62 | if (this.options.targetNamespace) { 63 | this.props.namespace = this.options.targetNamespace.toLowerCase() 64 | } else { 65 | this.props.namespace = "default" 66 | } 67 | 68 | const ns = createNamespace(this.id+"-"+this.props.namespace+"-namespace-struct", this.props.namespace, cluster, true) 69 | 70 | const runtimeSA = cluster.addServiceAccount('runtimeSA' + this.id, { namespace: this.props.namespace }); 71 | runtimeSA.node.addDependency(ns) 72 | 73 | if (this.options.chartRepository) { 74 | this.props.repository = this.options.chartRepository 75 | } 76 | 77 | if (this.options.chartVersion) { 78 | this.props.version = this.options.chartVersion 79 | } 80 | 81 | const modelBucket = s3.Bucket.fromBucketAttributes(cluster.stack, 'ModelBucket' + this.id, { 82 | bucketArn: 
this.options.modelBucketArn! 83 | }); 84 | 85 | modelBucket.grantRead(runtimeSA); 86 | 87 | const inputQueue = new sqs.Queue(cluster.stack, 'InputQueue' + this.id); 88 | inputQueue.grantConsumeMessages(runtimeSA); 89 | 90 | 91 | this.options.outputBucket!.grantWrite(runtimeSA); 92 | this.options.outputBucket!.grantPutAcl(runtimeSA); 93 | this.options.outputSns!.grantPublish(runtimeSA); 94 | 95 | runtimeSA.role.addManagedPolicy( 96 | iam.ManagedPolicy.fromAwsManagedPolicyName( 97 | 'AWSXRayDaemonWriteAccess', 98 | )) 99 | 100 | const nodeRole = clusterInfo.cluster.node.findChild('karpenter-node-role') as iam.IRole 101 | 102 | var generatedValues = { 103 | global: { 104 | awsRegion: cdk.Stack.of(cluster).region, 105 | stackName: cdk.Stack.of(cluster).stackName, 106 | runtime: this.id 107 | }, 108 | runtime: { 109 | type: this.options.type, 110 | serviceAccountName: runtimeSA.serviceAccountName, 111 | queueAgent: { 112 | s3Bucket: this.options.outputBucket!.bucketName, 113 | snsTopicArn: this.options.outputSns!.topicArn, 114 | sqsQueueUrl: inputQueue.queueUrl, 115 | }, 116 | persistence: { 117 | enabled: true, 118 | storageClass: "-", 119 | s3: { 120 | enabled: true, 121 | modelBucket: modelBucket.bucketName 122 | } 123 | } 124 | }, 125 | karpenter: { 126 | nodeTemplate: { 127 | iamRole: nodeRole.roleName 128 | } 129 | } 130 | } 131 | // Temp change: set image repo to ECR Public 132 | if (this.options.type == "sdwebui") { 133 | var imagerepo: string 134 | if (!(lodash.get(this.options, "extraValues.runtime.inferenceApi.image.repository"))) { 135 | imagerepo = "public.ecr.aws/bingjiao/sd-on-eks/sdwebui" 136 | } else { 137 | imagerepo = lodash.get(this.options, "extraValues.runtime.inferenceApi.image.repository")! 138 | } 139 | var sdWebUIgeneratedValues = { 140 | runtime: { 141 | inferenceApi: { 142 | image: { 143 | repository: imagerepo 144 | }, 145 | modelFilename: this.options.sdModelCheckpoint 146 | }, 147 | queueAgent: { 148 | dynamicModel: this.options.dynamicModel 149 | } 150 | } 151 | } 152 | 153 | generatedValues = lodash.merge(generatedValues, sdWebUIgeneratedValues) 154 | } 155 | 156 | if (this.options.type == "comfyui") { 157 | var imagerepo: string 158 | if (!(lodash.get(this.options, "extraValues.runtime.inferenceApi.image.repository"))) { 159 | imagerepo = "public.ecr.aws/bingjiao/sd-on-eks/comfyui" 160 | } else { 161 | imagerepo = lodash.get(this.options, "extraValues.runtime.inferenceApi.image.repository")! 162 | } 163 | 164 | var comfyUIgeneratedValues = { 165 | runtime: { 166 | inferenceApi: { 167 | image: { 168 | repository: imagerepo 169 | }, 170 | } 171 | } 172 | } 173 | 174 | generatedValues = lodash.merge(generatedValues, comfyUIgeneratedValues) 175 | } 176 | 177 | if (this.options.type == "sdwebui" && this.options.sdModelCheckpoint) { 178 | // Legacy and new routing, use CFN as a workaround since L2 construct doesn't support OR 179 | const cfnSubscription = new sns.CfnSubscription(cluster.stack, this.id+'CfnSubscription', { 180 | protocol: 'sqs', 181 | endpoint: inputQueue.queueArn, 182 | topicArn: this.options.inputSns!.topicArn, 183 | filterPolicy: { 184 | "$or": [ 185 | { 186 | "sd_model_checkpoint": [ 187 | this.options.sdModelCheckpoint! 
188 | ]
189 | }, {
190 | "runtime": [
191 | this.id
192 | ]
193 | }]
194 | },
195 | filterPolicyScope: "MessageAttributes"
196 | })
197 |
198 | inputQueue.addToResourcePolicy(new iam.PolicyStatement({
199 | effect: iam.Effect.ALLOW,
200 | principals: [new iam.ServicePrincipal('sns.amazonaws.com')],
201 | actions: ['sqs:SendMessage'],
202 | resources: [inputQueue.queueArn],
203 | conditions: {
204 | 'ArnEquals': {
205 | 'aws:SourceArn': this.options.inputSns!.topicArn
206 | }
207 | }
208 | }))
209 |
210 |
211 | /* It should look like this once the L2 construct supports OR filter policies:
212 | this.options.inputSns!.addSubscription(new aws_sns_subscriptions.SqsSubscription(inputQueue, {
213 | filterPolicy: {
214 | sd_model_checkpoint: sns.SubscriptionFilter.stringFilter({
215 | allowlist: [this.options.sdModelCheckpoint!]
216 | }),
217 | runtime: sns.SubscriptionFilter.stringFilter({
218 | allowlist: [this.id]
219 | })
220 | }
221 | })) */
222 | } else {
223 | // New version routing only
224 | this.options.inputSns!.addSubscription(new aws_sns_subscriptions.SqsSubscription(inputQueue, {
225 | filterPolicy: {
226 | runtime:
227 | sns.SubscriptionFilter.stringFilter({
228 | allowlist: [this.id]
229 | })
230 | }
231 | }))
232 | }
233 |
234 | const values = lodash.merge(this.props.values, this.options.extraValues, generatedValues)
235 |
236 | const chart = this.addHelmChart(clusterInfo, values, true);
237 |
238 | return Promise.resolve(chart);
239 | }
240 | }
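To make the routing concrete: a task reaches this runtime's queue when either message attribute matches the subscription's "$or" filter policy above. A hypothetical publish for illustration (topic ARN, payload file, and values are placeholders; in normal operation the frontend parser Lambdas presumably set these attributes rather than a manual publish):

    # Either attribute alone satisfies the "$or" filter policy, so the task
    # is delivered to this runtime's SQS input queue.
    aws sns publish \
        --topic-arn <input-sns-topic-arn> \
        --message file://task.json \
        --message-attributes '{
          "runtime": {"DataType": "String", "StringValue": "sdruntime1"},
          "sd_model_checkpoint": {"DataType": "String", "StringValue": "sd_xl_turbo_1.0.safetensors"}
        }'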
-------------------------------------------------------------------------------- /lib/utils/namespace.ts: --------------------------------------------------------------------------------
1 | import { KubernetesManifest } from "aws-cdk-lib/aws-eks";
2 | import * as eks from "aws-cdk-lib/aws-eks";
3 |
4 | export type Values = {
5 | [key: string]: any;
6 | };
7 |
8 | export function createNamespace(id: string, name: string, cluster: eks.ICluster, overwrite?: boolean, prune?: boolean, annotations?: Values, labels?: Values) {
9 | return new KubernetesManifest(cluster.stack, id, {
10 | cluster: cluster,
11 | manifest: [{
12 | apiVersion: 'v1',
13 | kind: 'Namespace',
14 | metadata: {
15 | name: name,
16 | annotations,
17 | labels
18 | }
19 | }],
20 | overwrite: overwrite ?? true,
21 | prune: prune ?? true
22 | });
23 | } -------------------------------------------------------------------------------- /lib/utils/validateConfig.ts: --------------------------------------------------------------------------------
1 | // Placeholder validation: always passes. Real checks against config.schema.yaml are not implemented yet.
2 | export function validateConfig (props: any) {
3 | var checkResult = true
4 |
5 | return checkResult
6 | } -------------------------------------------------------------------------------- /load_test/locustfile.py: --------------------------------------------------------------------------------
1 | import csv
2 | import datetime
3 | import json
4 | import logging
5 | import random
6 | import os
7 |
8 | import boto3
9 | import gevent
10 | from botocore.exceptions import ClientError
11 | from dateutil import parser
12 | from flask import send_file
13 | from locust import HttpUser, events, run_single_user, task
14 | from locust.runners import LocalRunner, MasterRunner
15 | from locust.user.wait_time import between
16 |
17 | logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(process)s - %(levelname)s - %(message)s')
18 |
19 | logger = logging.getLogger("processor")
20 | logger.setLevel(logging.INFO)
21 |
22 | API_ENDPOINT=os.getenv("API_ENDPOINT")
23 | API_KEY=os.getenv("API_KEY")
24 | OUTPUT_SQS_NAME=os.getenv("OUTPUT_SQS_NAME")
25 |
26 | TEMPLATE=json.loads("""{
27 | "task": {
28 | "metadata": {
29 | "id": "test-t2i",
30 | "runtime": "sdruntime",
31 | "tasktype": "text-to-image",
32 | "prefix": "output",
33 | "context": ""
34 | },
35 | "content": {
36 | "alwayson_scripts": {},
37 | "prompt": "A dog",
38 | "steps": 16,
39 | "width": 512,
40 | "height": 512
41 | }
42 | }
43 | }""")
44 |
45 |
46 |
47 | stats = {}
48 |
49 | sqs = boto3.resource('sqs')
50 |
51 | @events.init.add_listener
52 | def locust_init(environment, **kwargs):
53 | global stats
54 | stats = {}
55 |
56 | if isinstance(environment.runner, MasterRunner) or isinstance(environment.runner, LocalRunner):
57 | gevent.spawn(checker, environment)
58 |
59 | if environment.web_ui:
60 | @environment.web_ui.app.route("/result")
61 | def get_result():
62 | task_count = len(stats)
63 | completed_count = sum((i["complete"] == True) for i in stats.values())
64 | failed_count = sum((i["error"] == True) for i in stats.values())
65 | failure_rate = failed_count / max(task_count, 1)  # guard against division by zero before any task is sent
66 | avg_time_usage = sum(i["time_usage"] for i in stats.values()) / max(task_count, 1)
67 | response = {
68 | "task_count": task_count,
69 | "completed_count": completed_count,
70 | "failed_count": failed_count,
71 | "failure_rate": failure_rate,
72 | "avg_time_usage": avg_time_usage
73 | }
74 | return json.dumps(response)
75 |
76 | @environment.web_ui.app.route("/dump_failed")
77 | def dump_failed():
78 | failed = dict(filter(lambda x: x[1]["error"], stats.items()))
79 | return json.dumps(failed)
80 |
81 | @environment.web_ui.app.route("/dump_all")
82 | def dump_all():
83 | data_file = open('/tmp/result.csv', 'w')
84 | csv_writer = csv.writer(data_file)
85 | count = 0
86 | for item in stats.values():
87 | if count == 0:
88 | header = item.keys()
89 | csv_writer.writerow(header)
90 | count += 1
91 | item["start_time"] = int(item["start_time"])
92 | item["complete_time"] = int(item["complete_time"])
93 | csv_writer.writerow(item.values())
94 | data_file.close()
95 | return send_file('/tmp/result.csv', as_attachment=True)
96 |
97 | def receive_messages(queue, max_number, wait_time):
98 | try:
99 | messages = queue.receive_messages(
100 | MaxNumberOfMessages=max_number,
101 | WaitTimeSeconds=wait_time,
102 | AttributeNames=['All'],
103 | MessageAttributeNames=['All']
104 | )
105 | except Exception as error:
106 | logger.error(f"Error receiving messages: {error}")
107 | return []  # treat a receive failure as an empty batch so the checker loop keeps running
108 | return messages
109 |
110 | def delete_message(message):
111 | try:
112 | message.delete()
113 | except ClientError as error:
114 | raise error
115 |
116 | def checker(environment):
117 | global stats
118 | logger.info('Checker launched')
119 | queue = sqs.get_queue_by_name(QueueName=OUTPUT_SQS_NAME)
120 | # while not environment.runner.state in [STATE_STOPPED]:
121 | while True:
122 | received_messages = receive_messages(queue, 10, 20)
123 | if len(received_messages) == 0:
124 | logger.debug('No message received')
125 | else:
126 | for message in received_messages:
127 | payload = json.loads(message.body)
128 | msg = json.loads(payload['Message'])
129 | succeed = msg['result']
130 | task_id = msg['id']
131 | logger.debug(f'Received task with {task_id}')
132 | if task_id in stats.keys():
133 | logger.debug(f'Processing {task_id}')
134 | time = parser.parse(payload["Timestamp"]).timestamp()  # completion time taken from the SNS message timestamp
135 | time_usage = int((time - stats[task_id]['start_time']) * 1000)
136 | if succeed:
137 | stats[task_id]['complete'] = True
138 | stats[task_id]['error'] = False
139 | stats[task_id]['complete_time'] = time
140 | stats[task_id]['time_usage'] = time_usage
141 | else:
142 | stats[task_id]['complete'] = True
143 | stats[task_id]['error'] = True
144 | stats[task_id]['complete_time'] = time
145 | stats[task_id]['time_usage'] = time_usage
146 | else:
147 | logger.debug(f'Ignored {task_id}')
148 | delete_message(message)
149 |
150 | @events.test_start.add_listener
151 | def on_test_start(**kwargs):
152 | global stats
153 | stats = {}
154 |
155 | class MyUser(HttpUser):
156 | host = "http://0.0.0.0:8089"
157 | wait_time = between(0.5, 2.0)
158 |
159 | @task
160 | def txt_to_img(self):
161 | random_number = str(random.randint(1, 999999999)).zfill(9)
162 | body = TEMPLATE.copy()
163 | body["task"]["metadata"]["id"] = str(random_number)
164 | logger.debug(f'Send request with {random_number}')
165 | self.client.post(API_ENDPOINT+"v1alpha2",
166 | data = json.dumps(body),
167 | headers={"x-api-key": API_KEY, "Content-Type": "application/json"},
168 | context={"task-id": random_number})
169 |
170 | @events.request.add_listener
171 | def on_request(context, **kwargs):
172 | """
173 | Event handler that gets triggered on every request.
174 | """
175 | global stats
176 | stats[context["task-id"]] = {
177 | "task-id": context["task-id"],
178 | "start_time": datetime.datetime.now(datetime.timezone.utc).timestamp(),
179 | "complete_time": 0.0,
180 | "complete": False,
181 | "error": False,
182 | "time_usage": 0
183 | }
184 |
185 | # if launched directly, e.g.
"python3 debugging.py", not "locust -f debugging.py" 186 | if __name__ == "__main__": 187 | run_single_user(MyUser) -------------------------------------------------------------------------------- /package.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "stable-difussion-on-eks", 3 | "version": "1.1.2", 4 | "bin": { 5 | "stable-difussion-on-eks": "bin/stable-difussion-on-eks.ts" 6 | }, 7 | "scripts": { 8 | "synth": "npx cdk synth -q", 9 | "deploy": "npx cdk deploy --no-rollback --require-approval never" 10 | }, 11 | "devDependencies": { 12 | "@types/node": "22.8.6", 13 | "aws-cdk": "2.173.4", 14 | "ts-node": "^10.9.1", 15 | "typescript": "~5.5.2" 16 | }, 17 | "dependencies": { 18 | "@aws-quickstart/eks-blueprints": "1.16.3", 19 | "@types/lodash": "^4.14.197", 20 | "aws-cdk-lib": "2.173.4", 21 | "constructs": "^10.0.0", 22 | "source-map-support": "^0.5.21", 23 | "yaml": "^2.3.2" 24 | } 25 | } 26 | -------------------------------------------------------------------------------- /src/backend/queue_agent/.gitignore: -------------------------------------------------------------------------------- 1 | test/* -------------------------------------------------------------------------------- /src/backend/queue_agent/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM public.ecr.aws/docker/library/python:3.12-slim 2 | 3 | RUN apt-get update; apt-get install libmagic1 -y; rm -rf /var/lib/apt/lists/* 4 | 5 | RUN adduser --disabled-password --gecos '' app && \ 6 | mkdir /app || true; \ 7 | chown -R app /app 8 | WORKDIR /app 9 | USER app 10 | 11 | COPY requirements.txt . 12 | RUN pip3 install --no-cache-dir --upgrade pip && \ 13 | pip3 install --no-cache-dir -r requirements.txt 14 | 15 | COPY src/ /app/ 16 | 17 | CMD python3 -u /app/main.py 18 | -------------------------------------------------------------------------------- /src/backend/queue_agent/requirements.txt: -------------------------------------------------------------------------------- 1 | aioboto3>=12.3.0 2 | aiohttp_client_cache>=0.11.0 3 | aws_xray_sdk==2.12.0 4 | boto3==1.33.2 5 | botocore==1.33.2 6 | python_magic>=0.4.24 7 | Requests>=2.31.0 8 | requests_cache>=1.2.0 9 | websocket_client>=1.7.0 10 | -------------------------------------------------------------------------------- /src/backend/queue_agent/src/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # SPDX-License-Identifier: MIT-0 3 | -------------------------------------------------------------------------------- /src/backend/queue_agent/src/main.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
2 | # SPDX-License-Identifier: MIT-0 3 | 4 | import json 5 | import logging 6 | import os 7 | import signal 8 | import sys 9 | import uuid 10 | import time 11 | import functools 12 | 13 | import boto3 14 | from botocore.exceptions import EndpointConnectionError 15 | from aws_xray_sdk.core import patch_all, xray_recorder 16 | from aws_xray_sdk.core.models.trace_header import TraceHeader 17 | from modules import s3_action, sns_action, sqs_action 18 | from runtimes import comfyui, sdwebui 19 | 20 | # Initialize logging first so we can log X-Ray initialization attempts 21 | logging.basicConfig() 22 | logging.getLogger().setLevel(logging.ERROR) 23 | 24 | # Configure the queue-agent logger only once 25 | logger = logging.getLogger("queue-agent") 26 | logger.propagate = False 27 | logger.setLevel(os.environ.get('LOGLEVEL', 'INFO').upper()) 28 | 29 | # Remove any existing handlers to prevent duplicate logs 30 | if logger.handlers: 31 | logger.handlers.clear() 32 | 33 | # Add a single handler 34 | handler = logging.StreamHandler(sys.stdout) 35 | handler.setFormatter(logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')) 36 | logger.addHandler(handler) 37 | 38 | # Check if X-Ray is manually disabled via environment variable 39 | DISABLE_XRAY = os.environ.get('DISABLE_XRAY', 'false').lower() == 'true' 40 | if DISABLE_XRAY: 41 | logger.info("X-Ray tracing manually disabled via DISABLE_XRAY environment variable") 42 | xray_enabled = False 43 | else: 44 | # Try to initialize X-Ray SDK with retries, as the daemon might be starting up 45 | MAX_XRAY_INIT_ATTEMPTS = 5 46 | XRAY_RETRY_DELAY = 3 # seconds 47 | xray_enabled = False 48 | 49 | for attempt in range(MAX_XRAY_INIT_ATTEMPTS): 50 | try: 51 | logger.info(f"Attempting to initialize X-Ray SDK (attempt {attempt+1}/{MAX_XRAY_INIT_ATTEMPTS})") 52 | patch_all() 53 | xray_enabled = True 54 | logger.info("X-Ray SDK initialized successfully") 55 | break 56 | except EndpointConnectionError: 57 | logger.warning(f"Could not connect to X-Ray daemon (attempt {attempt+1}/{MAX_XRAY_INIT_ATTEMPTS})") 58 | if attempt < MAX_XRAY_INIT_ATTEMPTS - 1: 59 | logger.info(f"Retrying in {XRAY_RETRY_DELAY} seconds...") 60 | time.sleep(XRAY_RETRY_DELAY) 61 | except Exception as e: 62 | logger.warning(f"Error initializing X-Ray: {str(e)} (attempt {attempt+1}/{MAX_XRAY_INIT_ATTEMPTS})") 63 | if attempt < MAX_XRAY_INIT_ATTEMPTS - 1: 64 | logger.info(f"Retrying in {XRAY_RETRY_DELAY} seconds...") 65 | time.sleep(XRAY_RETRY_DELAY) 66 | 67 | if not xray_enabled: 68 | logger.warning("X-Ray initialization failed after all attempts. 
Tracing will be disabled.") 69 | 70 | # Create a decorator for safe X-Ray instrumentation 71 | def safe_xray_capture(name): 72 | """Decorator that safely applies X-Ray instrumentation if available""" 73 | def decorator(func): 74 | @functools.wraps(func) 75 | def wrapper(*args, **kwargs): 76 | if xray_enabled: 77 | try: 78 | # Try to use X-Ray instrumentation 79 | with xray_recorder.in_segment(name): 80 | return func(*args, **kwargs) 81 | except Exception as e: 82 | logger.warning(f"X-Ray instrumentation failed for {name}: {str(e)}") 83 | # Fall back to non-instrumented execution 84 | return func(*args, **kwargs) 85 | else: 86 | # X-Ray is disabled, just call the function directly 87 | return func(*args, **kwargs) 88 | return wrapper 89 | return decorator 90 | 91 | # Get base environment variable 92 | aws_default_region = os.getenv("AWS_DEFAULT_REGION") 93 | sqs_queue_url = os.getenv("SQS_QUEUE_URL") 94 | sns_topic_arn = os.getenv("SNS_TOPIC_ARN") 95 | s3_bucket = os.getenv("S3_BUCKET") 96 | runtime_name = os.getenv("RUNTIME_NAME", "") 97 | api_base_url = "" 98 | 99 | exp_callback_when_running = os.getenv("EXP_CALLBACK_WHEN_RUNNING", "") 100 | 101 | # Check current runtime type 102 | runtime_type = os.getenv("RUNTIME_TYPE", "").lower() 103 | 104 | # Runtime type should be specified 105 | if runtime_type == "": 106 | logger.error(f'Runtime type not specified') 107 | raise RuntimeError 108 | 109 | # Init for SD Web UI 110 | if runtime_type == "sdwebui": 111 | api_base_url = os.getenv("API_BASE_URL", "http://localhost:8080/sdapi/v1/") 112 | dynamic_sd_model_str = os.getenv("DYNAMIC_SD_MODEL", "false") 113 | if dynamic_sd_model_str.lower() == "false": 114 | dynamic_sd_model = False 115 | else: 116 | dynamic_sd_model = True 117 | 118 | # Init for ComfyUI 119 | if runtime_type == "comfyui": 120 | api_base_url = os.getenv("API_BASE_URL", "localhost:8080") 121 | client_id = str(uuid.uuid4()) 122 | # Change here to ComfyUI's base URL 123 | # You can specify any required environment variable here 124 | 125 | sqsRes = boto3.resource('sqs') 126 | snsRes = boto3.resource('sns') 127 | 128 | SQS_WAIT_TIME_SECONDS = 20 129 | 130 | # For graceful shutdown 131 | shutdown = False 132 | 133 | def main(): 134 | # Initialization: 135 | # 1. Environment parameters; 136 | # 2. AWS services resources(sqs/sns/s3); 137 | # 3. SD API readiness check, current checkpoint cached; 138 | print_env() 139 | 140 | queue = sqsRes.Queue(sqs_queue_url) 141 | topic = snsRes.Topic(sns_topic_arn) 142 | 143 | if runtime_type == "sdwebui": 144 | sdwebui.check_readiness(api_base_url, dynamic_sd_model) 145 | 146 | if runtime_type == "comfyui": 147 | comfyui.check_readiness(api_base_url) 148 | 149 | # main loop 150 | # 1. Pull msg from sqs; 151 | # 2. Translate parameteres; 152 | # 3. (opt)Switch model; 153 | # 4. (opt)Prepare inputs for image downloading and encoding; 154 | # 5. Call SD API; 155 | # 6. Prepare outputs for decoding, uploading and notifying; 156 | # 7. 
157 |     while True:
158 |         if shutdown:
159 |             logger.info('Received SIGTERM, shutting down...')
160 |             break
161 | 
162 |         received_messages = sqs_action.receive_messages(queue, 1, SQS_WAIT_TIME_SECONDS)
163 | 
164 |         for message in received_messages:
165 |             # Process with X-Ray if enabled, otherwise just process the message directly
166 |             if xray_enabled:
167 |                 try:
168 |                     with xray_recorder.in_segment(runtime_name+"-queue-agent") as segment:
169 |                         # Retrieve X-Ray trace header from the SQS message
170 |                         if "AWSTraceHeader" in message.attributes.keys():
171 |                             traceHeaderStr = message.attributes['AWSTraceHeader']
172 |                             sqsTraceHeader = TraceHeader.from_header_str(traceHeaderStr)
173 |                             # Update current segment to link with SQS
174 |                             segment.trace_id = sqsTraceHeader.root
175 |                             segment.parent_id = sqsTraceHeader.parent
176 |                             segment.sampled = sqsTraceHeader.sampled
177 | 
178 |                         # Process the message within the X-Ray segment
179 |                         process_message(message, topic, s3_bucket, runtime_type, runtime_name, api_base_url, dynamic_sd_model if runtime_type == "sdwebui" else None)
180 |                 except Exception as e:
181 |                     logger.error(f"Error with X-Ray tracing: {str(e)}. Processing message without tracing.")
182 |                     process_message(message, topic, s3_bucket, runtime_type, runtime_name, api_base_url, dynamic_sd_model if runtime_type == "sdwebui" else None)
183 |             else:
184 |                 # Process without X-Ray tracing
185 |                 process_message(message, topic, s3_bucket, runtime_type, runtime_name, api_base_url, dynamic_sd_model if runtime_type == "sdwebui" else None)
186 | 
187 | def process_message(message, topic, s3_bucket, runtime_type, runtime_name, api_base_url, dynamic_sd_model=None):
188 |     """Process a single SQS message"""
189 |     # Parse the received message
190 |     try:
191 |         payload = json.loads(json.loads(message.body)['Message'])
192 |         metadata = payload["metadata"]
193 |         task_id = metadata["id"]
194 | 
195 |         logger.info(f"Received task {task_id}, processing")
196 | 
197 |         if "prefix" in metadata.keys():
198 |             if metadata["prefix"][-1] == '/':
199 |                 prefix = metadata["prefix"] + str(task_id)
200 |             else:
201 |                 prefix = metadata["prefix"] + "/" + str(task_id)
202 |         else:
203 |             prefix = str(task_id)
204 | 
205 |         # Default to an empty task type so later references are never unbound
206 |         tasktype = metadata.get("tasktype", "")
207 | 
208 |         if "context" in metadata.keys():
209 |             context = metadata["context"]
210 |         else:
211 |             context = {}
212 | 
213 |         body = payload["content"]
214 |         logger.debug(body)
215 |     except Exception as e:
216 |         logger.error(f"Error parsing message: {e}, skipping")
217 |         logger.debug(message.body)  # payload may be unbound here, so log the raw body
218 |         sqs_action.delete_message(message)
219 |         return
220 | 
221 |     if exp_callback_when_running.lower() == "true":
222 |         sns_response = {"runtime": runtime_name,
223 |                         'id': task_id,
224 |                         'status': "running",
225 |                         'context': context}
226 | 
227 |         sns_action.publish_message(topic, json.dumps(sns_response))
228 | 
229 |     # Start handling the message
230 |     response = {}
231 | 
232 |     try:
233 |         if runtime_type == "sdwebui":
234 |             response = sdwebui.handler(api_base_url, tasktype, task_id, body, dynamic_sd_model)
235 | 
236 |         if runtime_type == "comfyui":
237 |             response = comfyui.handler(api_base_url, task_id, body)
238 |     except Exception as e:
239 |         logger.error(f"Error calling handler for task {task_id}: {str(e)}")
240 |         response = {
241 |             "success": False,
242 |             "image": [],
243 |             "content": '{"code": 500, "error": "Runtime handler failed"}'
244 |         }
245 | 
246 |     result = []
247 |     rand = str(uuid.uuid4())[0:4]
248 | 
249 |     if response["success"]:
250 |         idx = 0
251 |         if 
len(response["image"]) > 0: 252 | for i in response["image"]: 253 | idx += 1 254 | result.append(s3_action.upload_file(i, s3_bucket, prefix, str(task_id)+"-"+rand+"-"+str(idx))) 255 | 256 | output_url = s3_action.upload_file(response["content"], s3_bucket, prefix, str(task_id)+"-"+rand, ".out") 257 | 258 | if response["success"]: 259 | status = "completed" 260 | else: 261 | status = "failed" 262 | 263 | sns_response = {"runtime": runtime_name, 264 | 'id': task_id, 265 | 'result': response["success"], 266 | 'status': status, 267 | 'image_url': result, 268 | 'output_url': output_url, 269 | 'context': context} 270 | 271 | # Put response handler to SNS and delete message 272 | sns_action.publish_message(topic, json.dumps(sns_response)) 273 | sqs_action.delete_message(message) 274 | 275 | def print_env() -> None: 276 | logger.info(f'AWS_DEFAULT_REGION={aws_default_region}') 277 | logger.info(f'SQS_QUEUE_URL={sqs_queue_url}') 278 | logger.info(f'SNS_TOPIC_ARN={sns_topic_arn}') 279 | logger.info(f'S3_BUCKET={s3_bucket}') 280 | logger.info(f'RUNTIME_TYPE={runtime_type}') 281 | logger.info(f'RUNTIME_NAME={runtime_name}') 282 | logger.info(f'X-Ray Tracing: {"Disabled" if DISABLE_XRAY else "Enabled"}') 283 | logger.info(f'X-Ray Status: {"Active" if xray_enabled else "Inactive"}') 284 | 285 | def signalHandler(signum, frame): 286 | global shutdown 287 | shutdown = True 288 | 289 | if __name__ == '__main__': 290 | for sig in [signal.SIGINT, signal.SIGHUP, signal.SIGTERM]: 291 | signal.signal(sig, signalHandler) 292 | main() 293 | -------------------------------------------------------------------------------- /src/backend/queue_agent/src/modules/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # SPDX-License-Identifier: MIT-0 3 | -------------------------------------------------------------------------------- /src/backend/queue_agent/src/modules/http_action.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # SPDX-License-Identifier: MIT-0 3 | 4 | import logging 5 | 6 | import aioboto3 7 | import boto3 8 | import requests 9 | import requests_cache 10 | from aiohttp_client_cache import CacheBackend, CachedSession 11 | from requests.adapters import HTTPAdapter, Retry 12 | 13 | from . 
import s3_action, time_utils
14 | 
15 | logger = logging.getLogger("queue-agent")
16 | 
17 | 
18 | ab3_session = aioboto3.Session()
19 | 
20 | apiClient = requests.Session()
21 | retries = Retry(
22 |     total=3,
23 |     connect=100,
24 |     backoff_factor=0.1,
25 |     allowed_methods=["GET", "POST"])
26 | apiClient.mount('http://', HTTPAdapter(max_retries=retries))
27 | 
28 | REQUESTS_TIMEOUT_SECONDS = 300
29 | 
30 | cache = CacheBackend(
31 |     cache_name='memory-cache',
32 |     expire_after=600
33 | )
34 | 
35 | @time_utils.get_time
36 | def do_invocations(url: str, body: dict = None) -> dict:
37 |     if body is None:
38 |         logger.debug(f"Invoking {url}")
39 |         response = apiClient.get(
40 |             url=url, timeout=(1, REQUESTS_TIMEOUT_SECONDS))
41 |     else:
42 |         logger.debug(f"Invoking {url} with body: {body}")
43 |         response = apiClient.post(
44 |             url=url, json=body, timeout=(1, REQUESTS_TIMEOUT_SECONDS))
45 |     response.raise_for_status()
46 |     logger.debug(response.text)
47 |     return response.json()
48 | 
49 | def get(url: str) -> bytes:
50 |     logger.debug(f"Downloading {url}")
51 |     try:
52 |         if url.lower().startswith("http://") or url.lower().startswith("https://"):
53 |             with requests_cache.CachedSession('demo_cache') as session:
54 |                 with session.get(url) as res:
55 |                     res.raise_for_status()
56 |                     return res.content
57 |         elif url.lower().startswith("s3://"):
58 |             bucket_name, key = s3_action.get_bucket_and_key(url)
59 |             # boto3 clients are not context managers, so create the client directly
60 |             s3 = boto3.client('s3')
61 |             obj = s3.get_object(Bucket=bucket_name, Key=key)
62 |             return obj['Body'].read()
63 |     except Exception as e:
64 |         raise e
65 | 
66 | async def async_get(url: str) -> bytes:
67 |     try:
68 |         if url.lower().startswith("http://") or url.lower().startswith("https://"):
69 |             async with CachedSession(cache=cache) as session:
70 |                 async with session.get(url) as res:
71 |                     res.raise_for_status()
72 |                     return await res.read()
73 |         elif url.lower().startswith("s3://"):
74 |             bucket_name, key = s3_action.get_bucket_and_key(url)
75 |             async with ab3_session.resource("s3") as s3:
76 |                 obj = await s3.Object(bucket_name, key)
77 |                 res = await obj.get()
78 |                 return await res['Body'].read()
79 |     except Exception as e:
80 |         raise e
--------------------------------------------------------------------------------
/src/backend/queue_agent/src/modules/misc.py:
--------------------------------------------------------------------------------
1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2 | # SPDX-License-Identifier: MIT-0
3 | 
4 | import base64
5 | import difflib
6 | 
7 | 
8 | def exclude_keys(dictionary, keys):
9 |     key_set = set(dictionary.keys()) - set(keys)
10 |     return {key: dictionary[key] for key in key_set}
11 | 
12 | def str_simularity(a, b):
13 |     return difflib.SequenceMatcher(None, a, b).ratio()
14 | 
15 | def encode_to_base64(buffer):
16 |     return base64.b64encode(buffer).decode('utf-8')
--------------------------------------------------------------------------------
/src/backend/queue_agent/src/modules/s3_action.py:
--------------------------------------------------------------------------------
1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2 | # SPDX-License-Identifier: MIT-0
3 | 
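# Illustrative usage sketch (not part of the module; bucket name and prefix
# are placeholders): calling upload_file() without an extension lets `magic`
# sniff the bytes and mimetypes pick the suffix:
#
#   url = upload_file(png_bytes, "my-bucket", "outputs/task-1", "result-1")
#   # -> "s3://my-bucket/outputs/task-1/result-1.png"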
4 | import datetime
5 | import logging
6 | import mimetypes
7 | import uuid
8 | 
9 | import aioboto3
10 | import boto3
11 | import magic
12 | 
13 | logger = logging.getLogger("queue-agent")
14 | s3Res = boto3.resource('s3')
15 | 
16 | ab3_session = aioboto3.Session()
17 | 
18 | def upload_file(object_bytes: bytes, bucket_name: str, prefix: str, file_name: str=None, extension: str=None) -> str:
19 |     if file_name is None:
20 |         file_name = datetime.datetime.now().strftime(f"%Y%m%d%H%M%S-{str(uuid.uuid4())[0:5]}")
21 | 
22 |     content_type = 'application/octet-stream'  # default, so content_type is always bound
23 |     if extension is None:
24 |         # Auto-determine file type and extension using magic
25 |         content_type = magic.from_buffer(object_bytes, mime=True)
26 |         extension = mimetypes.guess_extension(content_type, True)
27 |     if extension == '.out':
28 |         content_type = 'application/json'
29 | 
30 |     try:
31 |         bucket = s3Res.Bucket(bucket_name)
32 |         logger.info(f"Uploading s3://{bucket_name}/{prefix}/{file_name}{extension}")
33 |         bucket.put_object(Body=object_bytes, Key=f'{prefix}/{file_name}{extension}', ContentType=content_type)
34 |         return f's3://{bucket_name}/{prefix}/{file_name}{extension}'
35 |     except Exception as error:
36 |         logger.error('Failed to upload content to S3', exc_info=True)
37 |         raise error
38 | 
39 | 
40 | async def async_upload(object_bytes: bytes, bucket_name: str, prefix: str, file_name: str=None, extension: str=None) -> str:
41 |     if file_name is None:
42 |         file_name = datetime.datetime.now().strftime(f"%Y%m%d%H%M%S-{str(uuid.uuid4())[0:5]}")
43 | 
44 |     content_type = 'application/octet-stream'  # default, so content_type is always bound
45 |     if extension is None:
46 |         # Auto-determine file type and extension using magic
47 |         content_type = magic.from_buffer(object_bytes, mime=True)
48 |         extension = mimetypes.guess_extension(content_type, True)
49 |     if extension == '.out':
50 |         content_type = 'application/json'
51 | 
52 |     try:
53 |         async with ab3_session.resource("s3") as s3:
54 |             bucket = await s3.Bucket(bucket_name)
55 |             await bucket.put_object(Body=object_bytes, Key=f'{prefix}/{file_name}{extension}', ContentType=content_type)
56 |             return f's3://{bucket_name}/{prefix}/{file_name}{extension}'
57 |     except Exception as e:
58 |         raise e
59 | 
60 | def get_bucket_and_key(s3uri):
61 |     pos = s3uri.find('/', 5)
62 |     bucket = s3uri[5: pos]
63 |     key = s3uri[pos + 1:]
64 |     return bucket, key
65 | 
66 | def get_prefix(path):
67 |     pos = path.find('/')
68 |     return path[pos + 1:]
--------------------------------------------------------------------------------
/src/backend/queue_agent/src/modules/sns_action.py:
--------------------------------------------------------------------------------
1 | import logging
2 | 
3 | import aioboto3
4 | from botocore.exceptions import ClientError
5 | 
6 | logger = logging.getLogger("queue-agent")
7 | 
8 | ab3_session = aioboto3.Session()
9 | 
10 | def publish_message(topic, message: str) -> str:
11 |     try:
12 |         response = topic.publish(Message=message)
13 |         message_id = response['MessageId']
14 |     except ClientError as error:
15 |         logger.error('Failed to send message to SNS', exc_info=True)
16 |         raise error
17 |     else:
18 |         return message_id
19 | 
20 | async def async_publish_message(topic, content: str):
21 |     try:
22 |         async with ab3_session.resource("sns") as sns:
23 |             topic = await sns.Topic(topic)
24 |             response = await topic.publish(Message=content)
25 |             return response['MessageId']
26 |     except Exception as e:
27 |         raise e
--------------------------------------------------------------------------------
/src/backend/queue_agent/src/modules/sqs_action.py:
-------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # SPDX-License-Identifier: MIT-0 3 | 4 | import logging 5 | 6 | from botocore.exceptions import ClientError 7 | 8 | logger = logging.getLogger("queue-agent") 9 | 10 | def receive_messages(queue, max_number, wait_time): 11 | try: 12 | messages = queue.receive_messages( 13 | MaxNumberOfMessages=max_number, 14 | WaitTimeSeconds=wait_time, 15 | AttributeNames=['All'], 16 | MessageAttributeNames=['All'] 17 | ) 18 | except ClientError as error: 19 | logger.error('Failed to get message from SQS', exc_info=True) 20 | raise error 21 | else: 22 | return messages 23 | 24 | def delete_message(message): 25 | try: 26 | message.delete() 27 | except ClientError as error: 28 | logger.error('Failed to delete message from SQS', exc_info=True) 29 | raise error -------------------------------------------------------------------------------- /src/backend/queue_agent/src/modules/time_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # SPDX-License-Identifier: MIT-0 3 | 4 | import logging 5 | from time import perf_counter 6 | 7 | logger = logging.getLogger("queue-agent") 8 | 9 | def get_time(f): 10 | def inner(*arg, **kwarg): 11 | s_time = perf_counter() 12 | res = f(*arg, **kwarg) 13 | e_time = perf_counter() 14 | logger.info('Used: {:.4f} seconds on api: {}.'.format(e_time - s_time, arg[0])) 15 | return res 16 | return inner -------------------------------------------------------------------------------- /src/backend/queue_agent/src/runtimes/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | # SPDX-License-Identifier: MIT-0 -------------------------------------------------------------------------------- /src/backend/queue_agent/src/runtimes/comfyui.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
2 | # SPDX-License-Identifier: MIT-0 3 | 4 | import json 5 | import logging 6 | import time 7 | import traceback 8 | import urllib.parse 9 | import uuid 10 | from typing import Optional, Dict, List, Any, Union 11 | 12 | import websocket # NOTE: websocket-client (https://github.com/websocket-client/websocket-client) 13 | from modules import http_action 14 | 15 | logger = logging.getLogger("queue-agent") 16 | 17 | # Import the safe_xray_capture decorator from main module 18 | try: 19 | from src.main import safe_xray_capture, xray_enabled 20 | except ImportError: 21 | try: 22 | # Try alternative import path 23 | from ..main import safe_xray_capture, xray_enabled 24 | except ImportError: 25 | # Fallback if import fails - create a simple pass-through decorator 26 | logger.warning("Failed to import safe_xray_capture from main, using fallback") 27 | def safe_xray_capture(name): 28 | def decorator(func): 29 | return func 30 | return decorator 31 | xray_enabled = False 32 | 33 | # Constants for websocket reconnection 34 | MAX_RECONNECT_ATTEMPTS = 5 35 | RECONNECT_DELAY = 2 # seconds 36 | 37 | def singleton(cls): 38 | _instance = {} 39 | 40 | def inner(): 41 | if cls not in _instance: 42 | _instance[cls] = cls() 43 | return _instance[cls] 44 | return inner 45 | 46 | @singleton 47 | class comfyuiCaller(object): 48 | 49 | def __init__(self): 50 | self.wss = websocket.WebSocket() 51 | self.client_id = str(uuid.uuid4()) 52 | self.api_base_url = None 53 | self.connected = False 54 | 55 | def setUrl(self, api_base_url:str): 56 | self.api_base_url = api_base_url 57 | 58 | def wss_connect(self): 59 | """Connect to websocket with reconnection logic""" 60 | if self.connected: 61 | return True 62 | 63 | attempts = 0 64 | while attempts < MAX_RECONNECT_ATTEMPTS: 65 | try: 66 | logger.info(f"Connecting to websocket (attempt {attempts+1}/{MAX_RECONNECT_ATTEMPTS})") 67 | self.wss.connect("ws://{}/ws?clientId={}".format(self.api_base_url, self.client_id)) 68 | self.connected = True 69 | logger.info("Successfully connected to websocket") 70 | return True 71 | except Exception as e: 72 | attempts += 1 73 | logger.warning(f"Failed to connect to websocket: {str(e)}") 74 | if attempts < MAX_RECONNECT_ATTEMPTS: 75 | logger.info(f"Retrying in {RECONNECT_DELAY} seconds...") 76 | time.sleep(RECONNECT_DELAY) 77 | else: 78 | logger.error("Max reconnection attempts reached") 79 | raise ConnectionError(f"Failed to connect to ComfyUI websocket after {MAX_RECONNECT_ATTEMPTS} attempts") from e 80 | 81 | return False 82 | 83 | def wss_recv(self) -> Optional[str]: 84 | """Receive data from websocket with reconnection logic""" 85 | attempts = 0 86 | while attempts < MAX_RECONNECT_ATTEMPTS: 87 | try: 88 | return self.wss.recv() 89 | except websocket.WebSocketConnectionClosedException: 90 | attempts += 1 91 | logger.warning(f"WebSocket connection closed, attempting to reconnect (attempt {attempts}/{MAX_RECONNECT_ATTEMPTS})...") 92 | self.connected = False 93 | 94 | if attempts < MAX_RECONNECT_ATTEMPTS: 95 | if self.wss_connect(): 96 | logger.info("Reconnected successfully, retrying receive operation") 97 | continue 98 | else: 99 | logger.warning(f"Failed to reconnect, waiting {RECONNECT_DELAY} seconds before retry...") 100 | time.sleep(RECONNECT_DELAY) 101 | else: 102 | logger.error("Max reconnection attempts reached in wss_recv") 103 | return None 104 | except Exception as e: 105 | attempts += 1 106 | logger.error(f"Error receiving data from websocket: {str(e)}") 107 | self.connected = False 108 | 109 | if attempts < 
MAX_RECONNECT_ATTEMPTS: 110 | logger.info(f"Waiting {RECONNECT_DELAY} seconds before retry...") 111 | time.sleep(RECONNECT_DELAY) 112 | if self.wss_connect(): 113 | logger.info("Reconnected successfully, retrying receive operation") 114 | continue 115 | else: 116 | logger.error("Max reconnection attempts reached in wss_recv") 117 | return None 118 | 119 | return None 120 | 121 | def get_history(self, prompt_id): 122 | try: 123 | url = f"http://{self.api_base_url}/history/{prompt_id}" 124 | # Use the http_action module with built-in retry logic 125 | return http_action.do_invocations(url) 126 | except Exception as e: 127 | logger.error(f"Error in get_history: {str(e)}") 128 | return {} 129 | 130 | def queue_prompt(self, prompt): 131 | try: 132 | p = {"prompt": prompt, "client_id": self.client_id} 133 | url = f"http://{self.api_base_url}/prompt" 134 | 135 | # Use the http_action module with built-in retry logic 136 | response = http_action.do_invocations(url, p) 137 | return response 138 | except Exception as e: 139 | logger.error(f"Error in queue_prompt: {str(e)}") 140 | return None 141 | 142 | def get_image(self, filename, subfolder, folder_type): 143 | try: 144 | data = {"filename": filename, "subfolder": subfolder, "type": folder_type} 145 | url_values = urllib.parse.urlencode(data) 146 | url = f"http://{self.api_base_url}/view?{url_values}" 147 | 148 | # Use http_action.get which returns bytes directly 149 | return http_action.get(url) 150 | except Exception as e: 151 | logger.error(f"Error getting image {filename}: {str(e)}") 152 | return b'' # Return empty bytes on error 153 | 154 | def track_progress(self, prompt, prompt_id): 155 | logger.info("Task received, prompt ID:" + prompt_id) 156 | node_ids = list(prompt.keys()) 157 | finished_nodes = [] 158 | max_errors = 5 159 | error_count = 0 160 | 161 | while True: 162 | try: 163 | out = self.wss_recv() # Using our new method with reconnection logic 164 | if out is None: 165 | error_count += 1 166 | logger.warning(f"Failed to receive data from websocket (error {error_count}/{max_errors})") 167 | if error_count >= max_errors: 168 | logger.error("Too many errors receiving websocket data, aborting track_progress") 169 | return False 170 | time.sleep(1) 171 | continue 172 | 173 | error_count = 0 # Reset error count on successful receive 174 | 175 | if isinstance(out, str): 176 | try: 177 | message = json.loads(out) 178 | logger.debug(out) 179 | if message['type'] == 'progress': 180 | data = message['data'] 181 | current_step = data['value'] 182 | logger.info(f"In K-Sampler -> Step: {current_step} of: {data['max']}") 183 | if message['type'] == 'execution_cached': 184 | data = message['data'] 185 | for itm in data['nodes']: 186 | if itm not in finished_nodes: 187 | finished_nodes.append(itm) 188 | logger.info(f"Progress: {len(finished_nodes)} / {len(node_ids)} tasks done") 189 | if message['type'] == 'executing': 190 | data = message['data'] 191 | if data['node'] not in finished_nodes: 192 | finished_nodes.append(data['node']) 193 | logger.info(f"Progress: {len(finished_nodes)} / {len(node_ids)} tasks done") 194 | 195 | if data['node'] is None and data['prompt_id'] == prompt_id: 196 | return True # Execution is done successfully 197 | except json.JSONDecodeError as e: 198 | logger.warning(f"Error parsing websocket message: {str(e)}, skipping message") 199 | continue 200 | except KeyError as e: 201 | logger.warning(f"Missing key in websocket message: {str(e)}, skipping message") 202 | continue 203 | else: 204 | continue 205 | except Exception 
as e:
206 |                 error_count += 1
207 |                 logger.warning(f"Unexpected error in track_progress: {str(e)} (error {error_count}/{max_errors})")
208 |                 if error_count >= max_errors:
209 |                     logger.error("Too many errors in track_progress, aborting")
210 |                     return False
211 |                 time.sleep(1)
212 | 
213 |         return True
214 | 
215 |     def get_images(self, prompt):
216 |         max_retries = 3
217 |         retry_count = 0
218 | 
219 |         while retry_count < max_retries:
220 |             try:
221 |                 output = self.queue_prompt(prompt)
222 |                 if output is None:
223 |                     raise RuntimeError("Failed to queue prompt - internal error")
224 | 
225 |                 prompt_id = output['prompt_id']
226 |                 output_images = {}
227 | 
228 |                 self.track_progress(prompt, prompt_id)
229 | 
230 |                 history = self.get_history(prompt_id)[prompt_id]
231 |                 # Iterate each output node once (the original double loop repeated the same work)
232 |                 for node_id in history['outputs']:
233 |                     node_output = history['outputs'][node_id]
234 |                     # image branch
235 |                     if 'images' in node_output:
236 |                         images_output = []
237 |                         for image in node_output['images']:
238 |                             image_data = self.get_image(image['filename'], image['subfolder'], image['type'])
239 |                             images_output.append(image_data)
240 |                         output_images[node_id] = images_output
241 |                     # video branch
242 |                     if 'videos' in node_output:
243 |                         videos_output = []
244 |                         for video in node_output['videos']:
245 |                             video_data = self.get_image(video['filename'], video['subfolder'], video['type'])
246 |                             videos_output.append(video_data)
247 |                         output_images[node_id] = videos_output
248 | 
249 |                 # If we got here, everything worked
250 |                 return output_images
251 | 
252 |             except websocket.WebSocketConnectionClosedException as e:
253 |                 retry_count += 1
254 |                 logger.warning(f"WebSocket connection closed during processing (attempt {retry_count}/{max_retries})")
255 | 
256 |                 # Try to reconnect before retrying
257 |                 self.connected = False
258 |                 if retry_count < max_retries:
259 |                     logger.info("Attempting to reconnect websocket...")
260 |                     if self.wss_connect():
261 |                         logger.info("Reconnected successfully, retrying operation")
262 |                         time.sleep(1)  # Small delay before retry
263 |                     else:
264 |                         logger.error("Failed to reconnect websocket")
265 |                 else:
266 |                     logger.error(f"Failed after {max_retries} attempts")
267 |                     raise RuntimeError(f"Failed to process images after {max_retries} attempts") from e
268 | 
269 |             except Exception as e:
270 |                 logger.error(f"Error processing images: {str(e)}")
271 |                 retry_count += 1
272 | 
273 |                 # For non-websocket errors, we might still want to try reconnecting the websocket
274 |                 if not self.connected and retry_count < max_retries:
275 |                     logger.info("Attempting to reconnect websocket...")
276 |                     self.wss_connect()
277 |                     time.sleep(1)  # Small delay before retry
278 |                 else:
279 |                     # If it's not a connection issue or we've tried enough times, re-raise
280 |                     if retry_count >= max_retries:
281 |                         raise
282 | 
283 |         # This should not be reached, but just in case
284 |         raise RuntimeError(f"Failed to process images after {max_retries} attempts")
285 | 
286 |     def parse_worflow(self, prompt_data):
287 |         logger.debug(prompt_data)
288 |         return self.get_images(prompt_data)
289 | 
290 | 
291 | def check_readiness(api_base_url: str) -> bool:
292 |     cf = comfyuiCaller()
293 |     cf.setUrl(api_base_url)
294 |     logger.info("Init health check... ")
295 |     try:
296 |         logger.info(f"Try to connect to ComfyUI backend {api_base_url} ... 
") 297 | if cf.wss_connect(): 298 | logger.info(f"ComfyUI backend {api_base_url} connected.") 299 | return True 300 | else: 301 | logger.error(f"Failed to connect to ComfyUI backend {api_base_url}") 302 | return False 303 | except Exception as e: 304 | logger.error(f"Error during health check: {str(e)}") 305 | return False 306 | 307 | 308 | def handler(api_base_url: str, task_id: str, payload: dict) -> dict: 309 | response = { 310 | "success": False, 311 | "image": [], 312 | "content": '{"code": 500}' 313 | } 314 | 315 | try: 316 | logger.info(f"Processing pipeline task with ID: {task_id}") 317 | 318 | # Attempt to invoke the pipeline 319 | try: 320 | images = invoke_pipeline(api_base_url, payload) 321 | 322 | # Process images if available 323 | imgOutputs = post_invocations(images) 324 | logger.info(f"Received {len(imgOutputs)} images") 325 | 326 | # Set success response 327 | response["success"] = True 328 | response["image"] = imgOutputs 329 | response["content"] = '{"code": 200}' 330 | logger.info(f"End process pipeline task with ID: {task_id}") 331 | except Exception as e: 332 | logger.error(f"Error processing pipeline: {str(e)}") 333 | # Keep default failure response 334 | except Exception as e: 335 | # This is a catch-all for any unexpected errors 336 | logger.error(f"Unexpected error in handler for task ID {task_id}: {str(e)}") 337 | traceback.print_exc() 338 | 339 | return response 340 | 341 | @safe_xray_capture('comfyui-pipeline') 342 | def invoke_pipeline(api_base_url: str, body) -> str: 343 | cf = comfyuiCaller() 344 | cf.setUrl(api_base_url) 345 | 346 | # Ensure websocket connection is established before proceeding 347 | if not cf.wss_connect(): 348 | raise ConnectionError(f"Failed to establish websocket connection to {api_base_url}") 349 | 350 | return cf.parse_worflow(body) 351 | 352 | def post_invocations(image): 353 | img_bytes = [] 354 | 355 | if len(image) > 0: 356 | for node_id in image: 357 | for image_data in image[node_id]: 358 | img_bytes.append(image_data) 359 | 360 | return img_bytes -------------------------------------------------------------------------------- /src/backend/queue_agent/src/runtimes/sdwebui.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
4 | import base64
5 | import json
6 | import logging
7 | import time
8 | import traceback
9 | 
10 | from requests.exceptions import ReadTimeout, HTTPError
11 | from modules import http_action, misc
12 | 
13 | logger = logging.getLogger("queue-agent")
14 | 
15 | ALWAYSON_SCRIPTS_EXCLUDE_KEYS = ['task', 'id_task', 'uid',
16 |                                  'sd_model_checkpoint', 'image_link', 'save_dir', 'sd_vae', 'override_settings']
17 | 
18 | # Import the safe_xray_capture decorator from main module
19 | try:
20 |     from src.main import safe_xray_capture, xray_enabled
21 | except ImportError:
22 |     try:
23 |         # Try alternative import path
24 |         from ..main import safe_xray_capture, xray_enabled
25 |     except ImportError:
26 |         # Fallback if import fails - create a simple pass-through decorator
27 |         logger.warning("Failed to import safe_xray_capture from main, using fallback")
28 |         def safe_xray_capture(name):
29 |             def decorator(func):
30 |                 return func
31 |             return decorator
32 |         xray_enabled = False
33 | 
34 | def check_readiness(api_base_url: str, dynamic_sd_model: bool) -> bool:
35 |     """Check if SD Web UI is ready by invoking the /options endpoint"""
36 |     while True:
37 |         try:
38 |             logger.info('Checking service readiness...')
39 |             # Checking the "sd_model_checkpoint" option also caches the current model
40 |             opts = invoke_get_options(api_base_url)
41 |             logger.info('Service is ready.')
42 |             if "sd_model_checkpoint" in opts:
43 |                 if opts['sd_model_checkpoint'] is not None:
44 |                     current_model_name = opts['sd_model_checkpoint']
45 |                     logger.info(f'Init model is: {current_model_name}.')
46 |                 else:
47 |                     if dynamic_sd_model:
48 |                         logger.info('Dynamic SD model is enabled, init model is not loaded.')
49 |                     else:
50 |                         logger.error('No init model is loaded and dynamic model switching is disabled.')
51 |             break
52 |         except Exception as e:
53 |             logger.debug(repr(e))
54 |             time.sleep(1)
55 |     return True
56 | 
57 | def handler(api_base_url: str, task_type: str, task_id: str, payload: dict, dynamic_sd_model: bool) -> dict:
58 |     """Main handler for SD Web UI requests"""
59 |     response = {}
60 |     try:
61 |         logger.info(f"Start process {task_type} task with ID: {task_id}")
62 |         match task_type:
63 |             case 'text-to-image':
64 |                 # Compatibility for v1alpha1: ensure there is an alwayson_scripts key
65 |                 if 'alwayson_scripts' in payload:
66 |                     # Switch model if necessary
67 |                     if dynamic_sd_model and payload['alwayson_scripts']['sd_model_checkpoint']:
68 |                         new_model = payload['alwayson_scripts']['sd_model_checkpoint']
69 |                         logger.info(f'Trying to switch model to: {new_model}.')
70 |                         current_model_name = switch_model(api_base_url, new_model)
71 |                         if current_model_name is None:
72 |                             raise Exception(f'Failed to switch model to {new_model}')
73 |                         logger.info(f'Current model is: {current_model_name}.')
74 |                 else:
75 |                     payload.update({'alwayson_scripts': {}})
76 | 
77 |                 task_response = invoke_txt2img(api_base_url, payload)
78 | 
79 |             case 'image-to-image':
80 |                 # Compatibility for v1alpha1: ensure there is an alwayson_scripts key
81 |                 if 'alwayson_scripts' in payload:
82 |                     # Switch model if necessary
83 |                     if dynamic_sd_model and payload['alwayson_scripts']['sd_model_checkpoint']:
84 |                         new_model = payload['alwayson_scripts']['sd_model_checkpoint']
85 |                         logger.info(f'Trying to switch model to: {new_model}.')
86 |                         current_model_name = switch_model(api_base_url, new_model)
87 |                         if current_model_name is None:
88 |                             raise Exception(f'Failed to switch model to {new_model}')
89 |                         logger.info(f'Current model is: {current_model_name}.')
90 |                 else:
91 |                     payload.update({'alwayson_scripts': {}})
92 | 
93 |                 task_response = invoke_img2img(api_base_url, payload)
94 |             case 'extra-single-image':
95 |                 # There is no alwayson_scripts in the API spec for this endpoint
96 |                 task_response = invoke_extra_single_image(api_base_url, payload)
97 |             case 'extra-batch-image':
98 |                 task_response = invoke_extra_batch_images(api_base_url, payload)
99 |             case _:
100 |                 # Catch-all: raise so the task is reported as failed instead of hitting an unbound task_response
101 |                 raise NotImplementedError(f'Unsupported task type: {task_type}')
102 | 
103 |         imgOutputs = post_invocations(task_response)
104 |         logger.info(f"Received {len(imgOutputs)} images")
105 |         content = json.dumps(succeed(task_id, task_response))
106 |         response["success"] = True
107 |         response["image"] = imgOutputs
108 |         response["content"] = content
109 |         logger.info(f"End process {task_type} task with ID: {task_id}")
110 |     except ReadTimeout as e:
111 |         invoke_interrupt(api_base_url)
112 |         content = json.dumps(failed(task_id, e))
113 |         logger.error(f"{task_type} task with ID: {task_id} timed out")
114 |         traceback.print_exc()
115 |         response["success"] = False
116 |         response["content"] = content
117 |     except Exception as e:
118 |         content = json.dumps(failed(task_id, e))
119 |         logger.error(f"{task_type} task with ID: {task_id} finished with error")
120 |         traceback.print_exc()
121 |         response["success"] = False
122 |         response["content"] = content
123 |     return response
124 | 
125 | @safe_xray_capture('text-to-image')
126 | def invoke_txt2img(api_base_url: str, body: dict) -> dict:
127 |     # Compatibility for v1alpha1: move override_settings from header to body
128 |     override_settings = {}
129 |     if 'override_settings' in body['alwayson_scripts']:
130 |         override_settings.update(body['alwayson_scripts']['override_settings'])
131 |     if override_settings:
132 |         if 'override_settings' in body:
133 |             body['override_settings'].update(override_settings)
134 |         else:
135 |             body.update({'override_settings': override_settings})
136 | 
137 |     # Compatibility for v1alpha1: remove headers used for routing in v1alpha1 API requests
138 |     body.update({'alwayson_scripts': misc.exclude_keys(body['alwayson_scripts'], ALWAYSON_SCRIPTS_EXCLUDE_KEYS)})
139 | 
140 |     # Replace image links elsewhere in the body with their content
141 |     body = download_image(body)
142 | 
143 |     response = http_action.do_invocations(api_base_url+"txt2img", body)
144 |     return response
145 | 
146 | @safe_xray_capture('image-to-image')
147 | def invoke_img2img(api_base_url: str, body: dict) -> dict:
148 |     """Image-to-Image request"""
149 |     # Process image links
150 |     body = download_image(body)
151 | 
152 |     # Compatibility for v1alpha1: move override_settings from header to body
153 |     override_settings = {}
154 |     if 'override_settings' in body['alwayson_scripts']:
155 |         override_settings.update(body['alwayson_scripts']['override_settings'])
156 |     if override_settings:
157 |         if 'override_settings' in body:
158 |             body['override_settings'].update(override_settings)
159 |         else:
160 |             body.update({'override_settings': override_settings})
161 | 
162 |     # Compatibility for v1alpha2: process image link in "alwayson_scripts"
163 |     # Planned for removal in the next release
164 |     if 'image_link' in body['alwayson_scripts']:
165 |         body.update({"init_images": [body['alwayson_scripts']['image_link']]})
166 | 
167 |     # Compatibility for v1alpha1: remove headers used for routing in v1alpha1 API requests
168 |     body.update({'alwayson_scripts': misc.exclude_keys(body['alwayson_scripts'], ALWAYSON_SCRIPTS_EXCLUDE_KEYS)})
169 | 
170 |     response = http_action.do_invocations(api_base_url+"img2img", body)
171 |     return response
172 | 
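# Illustrative note: download_image() (defined later in this module) walks the
# request body and replaces any http(s):// or s3:// string with base64-encoded
# content, so a body like {"init_images": ["s3://bucket/in.png"]} (placeholder
# URI) reaches the Web UI as {"init_images": ["<base64 image bytes>"]}.
173 | 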
@safe_xray_capture('extra-single-image')
174 | def invoke_extra_single_image(api_base_url: str, body) -> dict:
175 |     body = download_image(body)
176 |     response = http_action.do_invocations(api_base_url+"extra-single-image", body)
177 |     return response
178 | 
179 | @safe_xray_capture('extra-batch-images')
180 | def invoke_extra_batch_images(api_base_url: str, body) -> dict:
181 |     body = download_image(body)
182 |     response = http_action.do_invocations(api_base_url+"extra-batch-images", body)
183 |     return response
184 | 
185 | def invoke_set_options(api_base_url: str, options: dict) -> dict:
186 |     return http_action.do_invocations(api_base_url+"options", options)
187 | 
188 | def invoke_get_options(api_base_url: str) -> dict:
189 |     return http_action.do_invocations(api_base_url+"options")
190 | 
191 | def invoke_get_model_names(api_base_url: str) -> list:
192 |     return sorted([x["title"] for x in http_action.do_invocations(api_base_url+"sd-models")])
193 | 
194 | def invoke_refresh_checkpoints(api_base_url: str) -> dict:
195 |     return http_action.do_invocations(api_base_url+"refresh-checkpoints", {})
196 | 
197 | def invoke_unload_checkpoints(api_base_url: str) -> dict:
198 |     return http_action.do_invocations(api_base_url+"unload-checkpoint", {})
199 | 
200 | def invoke_interrupt(api_base_url: str) -> dict:
201 |     return http_action.do_invocations(api_base_url+"interrupt", {})
202 | 
203 | def switch_model(api_base_url: str, name: str) -> str:
204 |     opts = invoke_get_options(api_base_url)
205 |     current_model_name = opts['sd_model_checkpoint']
206 | 
207 |     if current_model_name == name:
208 |         logger.info(f"Model {current_model_name} is currently loaded, skipping switch.")
209 |     else:
210 |         # Refresh, then check against the model list
211 |         invoke_refresh_checkpoints(api_base_url)
212 |         models = invoke_get_model_names(api_base_url)
213 |         if name in models:
214 |             if current_model_name is not None:
215 |                 logger.info(f"Model {current_model_name} is currently loaded, unloading... ")
216 |                 try:
217 |                     invoke_unload_checkpoints(api_base_url)
218 |                 except HTTPError:
219 |                     logger.info("No model is currently loaded. Loading new model... 
") 220 | options = {} 221 | options["sd_model_checkpoint"] = name 222 | invoke_set_options(api_base_url, options) 223 | current_model_name = name 224 | else: 225 | logger.error(f"Model {name} not found, keeping current model.") 226 | return None 227 | 228 | return current_model_name 229 | 230 | # Customizable for success responses 231 | def succeed(task_id, response): 232 | parameters = {} 233 | if 'parameters' in response: # text-to-image and image-to-image 234 | parameters = response['parameters'] 235 | parameters['id_task'] = task_id 236 | parameters['image_seed'] = ','.join( 237 | str(x) for x in json.loads(response['info'])['all_seeds']) 238 | parameters['error_msg'] = '' 239 | elif 'html_info' in response: # extra-single-image and extra-batch-images 240 | parameters['html_info'] = response['html_info'] 241 | parameters['id_task'] = task_id 242 | parameters['error_msg'] = '' 243 | return { 244 | 'images': [''], 245 | 'parameters': parameters, 246 | 'info': '' 247 | } 248 | 249 | 250 | # Customizable for failure responses 251 | def failed(task_id, exception): 252 | parameters = {} 253 | parameters['id_task'] = task_id 254 | parameters['status'] = 0 255 | parameters['error_msg'] = repr(exception) 256 | parameters['reason'] = exception.response.json() if hasattr(exception, "response") else None 257 | return { 258 | 'images': [''], 259 | 'parameters': parameters, 260 | 'info': '' 261 | } 262 | 263 | def download_image(obj, path=""): 264 | """Search URL in object, and replace all URL with content of URL""" 265 | if isinstance(obj, dict): 266 | for key, value in obj.items(): 267 | new_path = f"{path}.{key}" if path else key 268 | obj[key] = download_image(value, new_path) 269 | elif isinstance(obj, list): 270 | for index, item in enumerate(obj): 271 | new_path = f"{path}[{index}]" 272 | obj[index] = download_image(item, new_path) 273 | elif isinstance(obj, str): 274 | if (obj.startswith('http') or obj.startswith('s3://')): 275 | logger.info(f"Found URL {obj} in {path}, replacing... ") 276 | try: 277 | image_byte = misc.encode_to_base64(http_action.get(obj)) 278 | logger.info(f"Replaced {path} with content") 279 | except Exception as e: 280 | logger.error(f"Error fetching URL: {obj}") 281 | logger.error(f"Error: {str(e)}") 282 | return image_byte 283 | return obj 284 | 285 | def post_invocations(response): 286 | img_bytes = [] 287 | 288 | if "images" in response.keys(): 289 | for i in response["images"]: 290 | img_bytes.append(base64.b64decode(i)) 291 | 292 | elif "image" in response.keys(): 293 | img_bytes.append(base64.b64decode(response["image"])) 294 | 295 | return img_bytes -------------------------------------------------------------------------------- /src/charts/sd_on_eks/Chart.yaml: -------------------------------------------------------------------------------- 1 | apiVersion: v2 2 | name: sd-on-eks 3 | description: Stable Diffusion on EKS 4 | # A chart can be either an 'application' or a 'library' chart. 5 | # 6 | # Application charts are a collection of templates that can be packaged into versioned archives 7 | # to be deployed. 8 | # 9 | # Library charts provide useful utilities or functions for the chart developer. They're included as 10 | # a dependency of application charts to inject those utilities and functions into the rendering 11 | # pipeline. Library charts do not define any templates and therefore cannot be deployed. 12 | type: application 13 | # This is the chart version. 
This version number should be incremented each time you make changes 14 | # to the chart and its templates, including the app version. 15 | # Versions are expected to follow Semantic Versioning (https://semver.org/) 16 | version: 1.1.3 17 | # This is the version number of the application being deployed. This version number should be 18 | # incremented each time you make changes to the application. Versions are not expected to 19 | # follow Semantic Versioning. They should reflect the version the application is using. 20 | # It is recommended to use it with quotes. 21 | appVersion: "1.1.3" 22 | -------------------------------------------------------------------------------- /src/charts/sd_on_eks/_helpers.tpl: -------------------------------------------------------------------------------- 1 | {{/* 2 | Expand the name of the chart. 3 | */}} 4 | {{- define "sdchart.name" -}} 5 | {{- default .Chart.Name .Values.nameOverride | trunc 63 | trimSuffix "-" }} 6 | {{- end }} 7 | 8 | {{/* 9 | Create a default fully qualified app name. 10 | We truncate at 63 chars because some Kubernetes name fields are limited to this (by the DNS naming spec). 11 | If release name contains chart name it will be used as a full name. 12 | */}} 13 | {{- define "sdchart.fullname" -}} 14 | {{- if .Values.fullnameOverride }} 15 | {{- .Values.fullnameOverride | trunc 63 | trimSuffix "-" }} 16 | {{- else }} 17 | {{- $name := default .Chart.Name .Values.nameOverride }} 18 | {{- if contains $name .Release.Name }} 19 | {{- .Release.Name | trunc 63 | trimSuffix "-" }} 20 | {{- else }} 21 | {{- printf "%s-%s" .Release.Name $name | trunc 63 | trimSuffix "-" }} 22 | {{- end }} 23 | {{- end }} 24 | {{- end }} 25 | 26 | {{/* 27 | Create chart name and version as used by the chart label. 28 | */}} 29 | {{- define "sdchart.chart" -}} 30 | {{- printf "%s-%s" .Chart.Name .Chart.Version | replace "+" "_" | trunc 63 | trimSuffix "-" }} 31 | {{- end }} 32 | 33 | {{/* 34 | Common labels 35 | */}} 36 | {{- define "sdchart.labels" -}} 37 | helm.sh/chart: {{ include "sdchart.chart" . }} 38 | {{ include "sdchart.selectorLabels" . }} 39 | {{- if .Chart.AppVersion }} 40 | app.kubernetes.io/version: {{ .Chart.AppVersion | quote }} 41 | {{- end }} 42 | app.kubernetes.io/managed-by: {{ .Release.Service }} 43 | {{- end }} 44 | 45 | {{/* 46 | Selector labels 47 | */}} 48 | {{- define "sdchart.selectorLabels" -}} 49 | app.kubernetes.io/name: {{ include "sdchart.name" . }} 50 | app.kubernetes.io/instance: {{ .Release.Name }} 51 | {{- end }} 52 | 53 | {{/* 54 | Create the name of the service account to use 55 | */}} 56 | {{- define "sdchart.serviceAccountName" -}} 57 | {{- if .Values.serviceAccount.create }} 58 | {{- default (include "sdchart.fullname" .) .Values.serviceAccount.name }} 59 | {{- else }} 60 | {{- default "default" .Values.serviceAccount.name }} 61 | {{- end }} 62 | {{- end }} 63 | -------------------------------------------------------------------------------- /src/charts/sd_on_eks/templates/_helpers.tpl: -------------------------------------------------------------------------------- 1 | {{/* 2 | Expand the name of the chart. 3 | */}} 4 | {{- define "sdchart.name" -}} 5 | {{- default .Chart.Name .Values.nameOverride | trunc 63 | trimSuffix "-" }} 6 | {{- end }} 7 | 8 | {{/* 9 | Create a default fully qualified app name. 10 | We truncate at 63 chars because some Kubernetes name fields are limited to this (by the DNS naming spec). 11 | If release name contains chart name it will be used as a full name. 
12 | */}} 13 | {{- define "sdchart.fullname" -}} 14 | {{- if .Values.fullnameOverride }} 15 | {{- .Values.fullnameOverride | trunc 63 | trimSuffix "-" }} 16 | {{- else }} 17 | {{- $name := default .Chart.Name .Values.nameOverride }} 18 | {{- if contains $name .Release.Name }} 19 | {{- .Release.Name | trunc 63 | trimSuffix "-" }} 20 | {{- else }} 21 | {{- printf "%s-%s" .Release.Name $name | trunc 63 | trimSuffix "-" }} 22 | {{- end }} 23 | {{- end }} 24 | {{- end }} 25 | 26 | {{/* 27 | Create chart name and version as used by the chart label. 28 | */}} 29 | {{- define "sdchart.chart" -}} 30 | {{- printf "%s-%s" .Chart.Name .Chart.Version | replace "+" "_" | trunc 63 | trimSuffix "-" }} 31 | {{- end }} 32 | 33 | {{/* 34 | Common labels 35 | */}} 36 | {{- define "sdchart.labels" -}} 37 | helm.sh/chart: {{ include "sdchart.chart" . }} 38 | {{ include "sdchart.selectorLabels" . }} 39 | {{- if .Chart.AppVersion }} 40 | app.kubernetes.io/version: {{ .Chart.AppVersion | quote }} 41 | {{- end }} 42 | app.kubernetes.io/managed-by: {{ .Release.Service }} 43 | {{- end }} 44 | 45 | {{/* 46 | Selector labels 47 | */}} 48 | {{- define "sdchart.selectorLabels" -}} 49 | app.kubernetes.io/name: {{ include "sdchart.name" . }} 50 | app.kubernetes.io/instance: {{ .Release.Name }} 51 | {{- end }} 52 | 53 | {{/* 54 | Create the name of the service account to use 55 | */}} 56 | {{- define "sdchart.serviceAccountName" -}} 57 | {{- if .Values.serviceAccount.create }} 58 | {{- default (include "sdchart.fullname" .) .Values.serviceAccount.name }} 59 | {{- else }} 60 | {{- default "default" .Values.serviceAccount.name }} 61 | {{- end }} 62 | {{- end }} 63 | -------------------------------------------------------------------------------- /src/charts/sd_on_eks/templates/aws-sqs-queue-scaledobject.yaml: -------------------------------------------------------------------------------- 1 | {{- if .Values.runtime.scaling.enabled }} 2 | apiVersion: keda.sh/v1alpha1 3 | kind: ScaledObject 4 | metadata: 5 | name: {{ include "sdchart.fullname" . }}-aws-sqs-queue-scaledobject 6 | namespace: {{ .Release.Namespace }} 7 | labels: 8 | {{- include "sdchart.labels" . | nindent 4 }} 9 | {{- if .Values.runtime.labels }} 10 | {{- toYaml .Values.runtime.labels | nindent 4 }} 11 | {{- end }} 12 | spec: 13 | cooldownPeriod: {{ .Values.runtime.scaling.cooldownPeriod }} 14 | maxReplicaCount: {{ .Values.runtime.scaling.maxReplicaCount }} 15 | minReplicaCount: {{ .Values.runtime.scaling.minReplicaCount }} 16 | pollingInterval: {{ .Values.runtime.scaling.pollingInterval }} 17 | scaleOnInFlight: {{ .Values.runtime.scaling.scaleOnInFlight }} 18 | {{- if .Values.runtime.scaling.extraHPAConfig }} 19 | advanced: 20 | horizontalPodAutoscalerConfig: 21 | behavior: 22 | {{- toYaml .Values.runtime.scaling.extraHPAConfig | nindent 8 }} 23 | {{- end }} 24 | scaleTargetRef: 25 | name: {{ include "sdchart.fullname" . }}-inference-api 26 | triggers: 27 | - authenticationRef: 28 | name: {{ include "sdchart.fullname" . 
}}-keda-trigger-auth-aws-credentials 29 | metadata: 30 | awsRegion: {{ .Values.global.awsRegion }} 31 | identityOwner: operator 32 | queueLength: {{ quote .Values.runtime.scaling.queueLength }} 33 | queueURL: {{ .Values.runtime.queueAgent.sqsQueueUrl }} 34 | type: aws-sqs-queue 35 | {{- end }} -------------------------------------------------------------------------------- /src/charts/sd_on_eks/templates/configmap-comfyui.yml: -------------------------------------------------------------------------------- 1 | {{- if (eq "comfyui" .Values.runtime.type) }} 2 | apiVersion: v1 3 | kind: ConfigMap 4 | metadata: 5 | name: {{ include "sdchart.fullname" . }}-comfyui-config 6 | namespace: {{ .Release.Namespace }} 7 | labels: 8 | {{- include "sdchart.labels" . | nindent 4 }} 9 | {{- if .Values.runtime.labels }} 10 | {{- toYaml .Values.runtime.labels | nindent 4 }} 11 | {{- end }} 12 | 13 | {{- if .Values.runtime.annotations }} 14 | annotations: 15 | {{ toYaml .Values.runtime.annotations | nindent 4 }} 16 | {{- end }} 17 | data: {} 18 | {{- end }} -------------------------------------------------------------------------------- /src/charts/sd_on_eks/templates/configmap-queue-agent.yml: -------------------------------------------------------------------------------- 1 | apiVersion: v1 2 | kind: ConfigMap 3 | metadata: 4 | name: {{ include "sdchart.fullname" . }}-queue-agent-config 5 | namespace: {{ .Release.Namespace }} 6 | labels: 7 | {{- include "sdchart.labels" . | nindent 4 }} 8 | {{- if .Values.runtime.labels }} 9 | {{- toYaml .Values.runtime.labels | nindent 4 }} 10 | {{- end }} 11 | 12 | {{- if .Values.runtime.annotations }} 13 | annotations: 14 | {{ toYaml .Values.runtime.annotations | nindent 4 }} 15 | {{- end }} 16 | data: 17 | AWS_DEFAULT_REGION: {{ quote .Values.global.awsRegion }} 18 | RUNTIME_NAME: {{ quote .Values.global.runtime }} 19 | SQS_QUEUE_URL: {{ quote .Values.runtime.queueAgent.sqsQueueUrl }} 20 | S3_BUCKET: {{ quote .Values.runtime.queueAgent.s3Bucket }} 21 | SNS_TOPIC_ARN: {{ quote .Values.runtime.queueAgent.snsTopicArn }} 22 | RUNTIME_TYPE: {{ quote .Values.runtime.type }} 23 | {{- if .Values.runtime.queueAgent.dynamicModel }} 24 | DYNAMIC_SD_MODEL: "true" 25 | {{- end }} 26 | -------------------------------------------------------------------------------- /src/charts/sd_on_eks/templates/configmap-sdwebui.yml: -------------------------------------------------------------------------------- 1 | {{- if (eq "sdwebui" .Values.runtime.type) }} 2 | apiVersion: v1 3 | kind: ConfigMap 4 | metadata: 5 | name: {{ include "sdchart.fullname" . }}-sdwebui-config 6 | namespace: {{ .Release.Namespace }} 7 | labels: 8 | {{- include "sdchart.labels" . | nindent 4 }} 9 | {{- if .Values.runtime.labels }} 10 | {{- toYaml .Values.runtime.labels | nindent 4 }} 11 | {{- end }} 12 | 13 | {{- if .Values.runtime.annotations }} 14 | annotations: 15 | {{ toYaml .Values.runtime.annotations | nindent 4 }} 16 | {{- end }} 17 | data: 18 | config.json: | 19 | {"disable_mmap_load_safetensors": true} 20 | {{- end }} -------------------------------------------------------------------------------- /src/charts/sd_on_eks/templates/deployment-comfyui.yaml: -------------------------------------------------------------------------------- 1 | {{- if (eq "comfyui" .Values.runtime.type) }} 2 | apiVersion: apps/v1 3 | kind: Deployment 4 | metadata: 5 | name: {{ include "sdchart.fullname" . }}-inference-api 6 | namespace: {{ .Release.Namespace }} 7 | labels: 8 | {{- include "sdchart.labels" . 
| nindent 4 }} 9 | {{- if .Values.runtime.labels }} 10 | {{- toYaml .Values.runtime.labels | nindent 4 }} 11 | {{- end }} 12 | runtime-type: comfyui 13 | 14 | {{- if .Values.runtime.annotations }} 15 | annotations: 16 | {{ toYaml .Values.runtime.annotations | nindent 4 }} 17 | {{- end }} 18 | 19 | spec: 20 | replicas: {{ .Values.runtime.replicas }} 21 | selector: 22 | matchLabels: 23 | app: inference-api 24 | {{- include "sdchart.selectorLabels" . | nindent 6 }} 25 | strategy: 26 | rollingUpdate: 27 | maxSurge: 100% 28 | maxUnavailable: 0 29 | type: RollingUpdate 30 | template: 31 | metadata: 32 | labels: 33 | app: inference-api 34 | {{- include "sdchart.selectorLabels" . | nindent 8 }} 35 | spec: 36 | containers: 37 | - name: inference-api 38 | image: {{ .Values.runtime.inferenceApi.image.repository }}:{{ .Values.runtime.inferenceApi.image.tag }} 39 | env: 40 | {{- if .Values.runtime.queueAgent.commandArguments }} 41 | - name: EXTRA_CMD_ARG 42 | value: {{ .Values.runtime.inferenceApi.commandArguments }} 43 | {{- end }} 44 | {{- if .Values.runtime.inferenceApi.extraEnv }} 45 | {{- toYaml .Values.runtime.inferenceApi.extraEnv | nindent 8 }} 46 | {{- end }} 47 | resources: 48 | {{- toYaml .Values.runtime.inferenceApi.resources | nindent 10 }} 49 | volumeMounts: 50 | - mountPath: {{ .Values.runtime.inferenceApi.modelMountPath }} 51 | name: models 52 | imagePullPolicy: {{ .Values.runtime.inferenceApi.imagePullPolicy }} 53 | startupProbe: 54 | httpGet: 55 | path: /system_stats 56 | port: 8080 57 | failureThreshold: 120 58 | periodSeconds: 1 59 | - name: queue-agent 60 | envFrom: 61 | - configMapRef: 62 | name: {{ include "sdchart.fullname" . }}-queue-agent-config 63 | env: 64 | {{- if .Values.runtime.queueAgent.extraEnv }} 65 | {{- toYaml .Values.runtime.queueAgent.extraEnv | nindent 8 }} 66 | {{- end }} 67 | {{- if .Values.runtime.queueAgent.xray.enabled }} 68 | - name: AWS_XRAY_DAEMON_ADDRESS 69 | value: localhost:2000 70 | - name: AWS_XRAY_CONTEXT_MISSING 71 | value: IGNORE_ERROR 72 | {{- else }} 73 | - name: DISABLE_XRAY 74 | value: "true" 75 | {{- end }} 76 | image: {{ .Values.runtime.queueAgent.image.repository }}:{{ .Values.runtime.queueAgent.image.tag }} 77 | imagePullPolicy: {{ .Values.runtime.queueAgent.imagePullPolicy }} 78 | volumeMounts: 79 | - mountPath: /tmp/models 80 | name: models 81 | resources: 82 | {{- toYaml .Values.runtime.queueAgent.resources | nindent 10 }} 83 | {{- if .Values.runtime.queueAgent.xray.enabled }} 84 | - name: xray-daemon 85 | image: {{ .Values.runtime.queueAgent.xray.daemon.image.repository }}:{{ .Values.runtime.queueAgent.xray.daemon.image.tag }} 86 | ports: 87 | - containerPort: 2000 88 | protocol: UDP 89 | {{- end }} 90 | serviceAccountName: {{ .Values.runtime.serviceAccountName }} 91 | terminationGracePeriodSeconds: 60 92 | tolerations: 93 | - effect: NoSchedule 94 | key: nvidia.com/gpu 95 | operator: Exists 96 | - effect: NoSchedule 97 | key: runtime 98 | value: {{ include "sdchart.fullname" . }} 99 | volumes: 100 | - name: models 101 | {{- if .Values.runtime.persistence.enabled }} 102 | persistentVolumeClaim: 103 | {{- if .Values.runtime.persistence.existingClaim }} 104 | claimName: {{ .Values.runtime.persistence.existingClaim }} 105 | {{- else }} 106 | claimName: {{ include "sdchart.fullname" . 
}}-model-claim 107 | {{- end }} 108 | {{- else }} 109 | emptyDir: {} 110 | {{- end }} 111 | {{- end }} -------------------------------------------------------------------------------- /src/charts/sd_on_eks/templates/deployment-sdwebui.yaml: -------------------------------------------------------------------------------- 1 | {{- if (eq "sdwebui" .Values.runtime.type) }} 2 | apiVersion: apps/v1 3 | kind: Deployment 4 | metadata: 5 | name: {{ include "sdchart.fullname" . }}-inference-api 6 | namespace: {{ .Release.Namespace }} 7 | labels: 8 | {{- include "sdchart.labels" . | nindent 4 }} 9 | {{- if .Values.runtime.labels }} 10 | {{- toYaml .Values.runtime.labels | nindent 4 }} 11 | {{- end }} 12 | runtime-type: sdwebui 13 | {{- if .Values.runtime.annotations }} 14 | annotations: 15 | {{- toYaml .Values.runtime.annotations | nindent 4 }} 16 | {{- end }} 17 | 18 | spec: 19 | replicas: {{ .Values.runtime.replicas }} 20 | selector: 21 | matchLabels: 22 | app: inference-api 23 | {{- include "sdchart.selectorLabels" . | nindent 6 }} 24 | strategy: 25 | rollingUpdate: 26 | maxSurge: 100% 27 | maxUnavailable: 0 28 | type: RollingUpdate 29 | template: 30 | metadata: 31 | labels: 32 | app: inference-api 33 | {{- include "sdchart.selectorLabels" . | nindent 8 }} 34 | spec: 35 | containers: 36 | - name: inference-api 37 | image: {{ .Values.runtime.inferenceApi.image.repository }}:{{ .Values.runtime.inferenceApi.image.tag }} 38 | env: 39 | {{- if .Values.runtime.queueAgent.dynamicModel }} 40 | - name: DYNAMIC_SD_MODEL 41 | value: "true" 42 | {{- else }} 43 | - name: SD_MODEL_CHECKPOINT 44 | value: {{ quote .Values.runtime.inferenceApi.modelFilename }} 45 | {{- end }} 46 | - name: API_ONLY 47 | value: "true" 48 | - name: CONFIG_FILE 49 | value: "/tmp/config.json" 50 | {{- if .Values.runtime.inferenceApi.commandArguments }} 51 | - name: EXTRA_CMD_ARG 52 | value: {{ .Values.runtime.inferenceApi.commandArguments }} 53 | {{- end }} 54 | {{- if .Values.runtime.inferenceApi.extraEnv }} 55 | {{- toYaml .Values.runtime.inferenceApi.extraEnv | nindent 8 }} 56 | {{- end }} 57 | resources: 58 | {{- toYaml .Values.runtime.inferenceApi.resources | nindent 10 }} 59 | volumeMounts: 60 | - mountPath: {{ .Values.runtime.inferenceApi.modelMountPath }} 61 | name: models 62 | - mountPath: "/tmp/config.json" 63 | name: config 64 | subPath: config.json 65 | imagePullPolicy: {{ .Values.runtime.inferenceApi.imagePullPolicy }} 66 | startupProbe: 67 | httpGet: 68 | path: /sdapi/v1/memory 69 | port: 8080 70 | failureThreshold: 120 71 | periodSeconds: 1 72 | - name: queue-agent 73 | envFrom: 74 | - configMapRef: 75 | name: {{ include "sdchart.fullname" . 
}}-queue-agent-config 76 | env: 77 | {{- if .Values.runtime.queueAgent.extraEnv }} 78 | {{- toYaml .Values.runtime.queueAgent.extraEnv | nindent 8 }} 79 | {{- end }} 80 | {{- if .Values.runtime.queueAgent.xray.enabled }} 81 | - name: AWS_XRAY_DAEMON_ADDRESS 82 | value: localhost:2000 83 | - name: AWS_XRAY_CONTEXT_MISSING 84 | value: IGNORE_ERROR 85 | {{- else }} 86 | - name: DISABLE_XRAY 87 | value: "true" 88 | {{- end }} 89 | image: {{ .Values.runtime.queueAgent.image.repository }}:{{ .Values.runtime.queueAgent.image.tag }} 90 | imagePullPolicy: {{ .Values.runtime.queueAgent.imagePullPolicy }} 91 | resources: 92 | {{- toYaml .Values.runtime.queueAgent.resources | nindent 10 }} 93 | {{- if .Values.runtime.queueAgent.xray.enabled }} 94 | - name: xray-daemon 95 | image: {{ .Values.runtime.queueAgent.xray.daemon.image.repository }}:{{ .Values.runtime.queueAgent.xray.daemon.image.tag }} 96 | ports: 97 | - containerPort: 2000 98 | protocol: UDP 99 | {{- end }} 100 | serviceAccountName: {{ .Values.runtime.serviceAccountName }} 101 | terminationGracePeriodSeconds: 60 102 | tolerations: 103 | - effect: NoSchedule 104 | key: nvidia.com/gpu 105 | operator: Exists 106 | - effect: NoSchedule 107 | key: runtime 108 | value: {{ include "sdchart.fullname" . }} 109 | volumes: 110 | - name: models 111 | {{- if .Values.runtime.persistence.enabled }} 112 | persistentVolumeClaim: 113 | {{- if .Values.runtime.persistence.existingClaim }} 114 | claimName: {{ .Values.runtime.persistence.existingClaim }} 115 | {{- else }} 116 | claimName: {{ include "sdchart.fullname" . }}-model-claim 117 | {{- end }} 118 | {{- else }} 119 | emptyDir: {} 120 | {{- end }} 121 | - name: config 122 | configMap: 123 | name: {{ include "sdchart.fullname" . }}-sdwebui-config 124 | {{- end }} -------------------------------------------------------------------------------- /src/charts/sd_on_eks/templates/keda-trigger-auth-aws-credentials.yaml: -------------------------------------------------------------------------------- 1 | {{- if .Values.runtime.scaling.enabled }} 2 | apiVersion: keda.sh/v1alpha1 3 | kind: TriggerAuthentication 4 | metadata: 5 | name: {{ include "sdchart.fullname" . }}-keda-trigger-auth-aws-credentials 6 | namespace: {{ .Release.Namespace }} 7 | labels: 8 | {{- include "sdchart.labels" . | nindent 4 }} 9 | {{- if .Values.runtime.labels }} 10 | {{- toYaml .Values.runtime.labels | nindent 4 }} 11 | {{- end }} 12 | spec: 13 | podIdentity: 14 | provider: aws-eks 15 | {{- end }} -------------------------------------------------------------------------------- /src/charts/sd_on_eks/templates/nodeclass.yaml: -------------------------------------------------------------------------------- 1 | apiVersion: karpenter.k8s.aws/v1 2 | kind: EC2NodeClass 3 | metadata: 4 | name: {{ include "sdchart.fullname" . }}-nodeclass-gpu 5 | labels: 6 | {{- include "sdchart.labels" . 
| nindent 4 }} 7 | {{- if .Values.runtime.labels }} 8 | {{- toYaml .Values.runtime.labels | nindent 4 }} 9 | {{- end }} 10 | spec: 11 | amiSelectorTerms: 12 | - alias: {{ lower .Values.karpenter.nodeTemplate.amiFamily }}@latest 13 | 14 | subnetSelectorTerms: 15 | - tags: 16 | "aws-cdk:subnet-type": "Private" 17 | "aws:cloudformation:stack-name": {{ .Values.global.stackName }} 18 | 19 | securityGroupSelectorTerms: 20 | - tags: 21 | "aws:eks:cluster-name": {{ quote .Values.global.stackName }} 22 | 23 | role: {{ .Values.karpenter.nodeTemplate.iamRole }} 24 | 25 | tags: 26 | stack: {{ .Values.global.stackName }} 27 | runtime: {{ .Release.Name }} 28 | {{- if .Values.karpenter.nodeTemplate.tags }} 29 | {{- toYaml .Values.karpenter.nodeTemplate.tags | nindent 4 }} 30 | {{- end }} 31 | 32 | metadataOptions: 33 | httpEndpoint: enabled 34 | httpProtocolIPv6: disabled 35 | httpPutResponseHopLimit: 2 36 | httpTokens: optional 37 | 38 | blockDeviceMappings: 39 | - deviceName: /dev/xvda 40 | ebs: 41 | volumeSize: {{ .Values.karpenter.nodeTemplate.osVolume.volumeSize }} 42 | volumeType: {{ .Values.karpenter.nodeTemplate.osVolume.volumeType }} 43 | deleteOnTermination: {{ .Values.karpenter.nodeTemplate.osVolume.deleteOnTermination }} 44 | iops: {{ .Values.karpenter.nodeTemplate.osVolume.iops }} 45 | throughput: {{ .Values.karpenter.nodeTemplate.osVolume.throughput }} 46 | {{- if .Values.karpenter.nodeTemplate.dataVolume }} 47 | - deviceName: /dev/xvdb 48 | ebs: 49 | volumeSize: {{ .Values.karpenter.nodeTemplate.dataVolume.volumeSize }} 50 | volumeType: {{ .Values.karpenter.nodeTemplate.dataVolume.volumeType }} 51 | deleteOnTermination: {{ .Values.karpenter.nodeTemplate.dataVolume.deleteOnTermination }} 52 | iops: {{ .Values.karpenter.nodeTemplate.dataVolume.iops }} 53 | throughput: {{ .Values.karpenter.nodeTemplate.dataVolume.throughput }} 54 | {{- if .Values.karpenter.nodeTemplate.dataVolume.snapshotID }} 55 | snapshotID: {{ .Values.karpenter.nodeTemplate.dataVolume.snapshotID }} 56 | {{- end }} 57 | {{- end }} 58 | 59 | {{- if .Values.karpenter.nodeTemplate.userData }} 60 | userData: |- 61 | {{- tpl .Values.karpenter.nodeTemplate.userData . | nindent 4 }} 62 | {{- end }} -------------------------------------------------------------------------------- /src/charts/sd_on_eks/templates/nodepool.yaml: -------------------------------------------------------------------------------- 1 | apiVersion: karpenter.sh/v1 2 | kind: NodePool 3 | metadata: 4 | name: {{ include "sdchart.fullname" . }}-nodepool-gpu 5 | labels: 6 | {{- include "sdchart.labels" . | nindent 4 }} 7 | {{- if .Values.runtime.labels }} 8 | {{- toYaml .Values.runtime.labels | nindent 4 }} 9 | {{- end }} 10 | spec: 11 | template: 12 | metadata: 13 | {{- if .Values.karpenter.provisioner.labels }} 14 | labels: 15 | {{- toYaml .Values.karpenter.provisioner.labels | nindent 8 }} 16 | {{- end }} 17 | {{- if .Values.karpenter.provisioner.annotations }} 18 | annotations: 19 | {{- toYaml .Values.karpenter.provisioner.annotations | nindent 8 }} 20 | {{- end }} 21 | spec: 22 | nodeClassRef: 23 | group: karpenter.k8s.aws 24 | kind: EC2NodeClass 25 | name: {{ include "sdchart.fullname" . }}-nodeclass-gpu 26 | taints: 27 | - effect: NoSchedule 28 | key: nvidia.com/gpu 29 | - effect: NoSchedule 30 | key: runtime 31 | value: {{ include "sdchart.fullname" . 
}} 32 | {{- if .Values.karpenter.provisioner.extraTaints }} 33 | {{- toYaml .Values.karpenter.provisioner.extraTaints | nindent 6 }} 34 | {{- end }} 35 | {{- if .Values.karpenter.provisioner.disruption.expireAfter }} 36 | expireAfter: {{ .Values.karpenter.provisioner.disruption.expireAfter }} 37 | {{- end }} 38 | requirements: 39 | - key: karpenter.sh/capacity-type 40 | operator: In 41 | values: 42 | {{- if .Values.karpenter.provisioner.capacityType.spot }} 43 | - spot 44 | {{- end }} 45 | {{- if .Values.karpenter.provisioner.capacityType.onDemand }} 46 | - on-demand 47 | {{- end }} 48 | {{- if .Values.karpenter.provisioner.instanceType }} 49 | - key: node.kubernetes.io/instance-type 50 | operator: In 51 | values: 52 | {{- toYaml .Values.karpenter.provisioner.instanceType | nindent 8 }} 53 | {{- end }} 54 | {{- if .Values.karpenter.provisioner.extraRequirements }} 55 | {{- toYaml .Values.karpenter.provisioner.extraRequirements | nindent 6 }} 56 | {{- end }} 57 | disruption: 58 | {{- if .Values.karpenter.provisioner.consolidation }} 59 | consolidationPolicy: WhenEmptyOrUnderutilized 60 | {{- else }} 61 | consolidationPolicy: WhenEmpty 62 | {{- end }} 63 | consolidateAfter: {{ .Values.karpenter.provisioner.disruption.consolidateAfter }} 64 | {{- if .Values.karpenter.provisioner.resourceLimits }} 65 | limits: 66 | {{- toYaml .Values.karpenter.provisioner.resourceLimits | nindent 4 }} 67 | {{- end }} 68 | -------------------------------------------------------------------------------- /src/charts/sd_on_eks/templates/persistentvolume-s3.yaml: -------------------------------------------------------------------------------- 1 | {{- if and (.Values.runtime.persistence.enabled) (.Values.runtime.persistence.s3.enabled) }} 2 | apiVersion: v1 3 | kind: PersistentVolume 4 | metadata: 5 | name: {{ include "sdchart.fullname" . }}-s3-model-volume 6 | {{- if .Values.runtime.persistence.annotations }} 7 | annotations: 8 | {{ toYaml .Values.runtime.persistence.annotations | nindent 4 }} 9 | {{- end }} 10 | labels: 11 | {{- include "sdchart.labels" . | nindent 4 }} 12 | {{- if .Values.runtime.labels }} 13 | {{- toYaml .Values.runtime.labels | nindent 4 }} 14 | {{- end }} 15 | {{- if .Values.runtime.persistence.labels }} 16 | {{- toYaml .Values.runtime.persistence.labels | nindent 4 }} 17 | {{- end }} 18 | spec: 19 | capacity: 20 | storage: {{ .Values.runtime.persistence.size }} 21 | accessModes: 22 | {{- toYaml .Values.runtime.persistence.accessModes | nindent 2 }} 23 | mountOptions: 24 | - allow-delete 25 | - allow-other 26 | - file-mode=777 27 | - dir-mode=777 28 | csi: 29 | driver: s3.csi.aws.com 30 | volumeHandle: s3-csi-driver-volume 31 | volumeAttributes: 32 | bucketName: {{ .Values.runtime.persistence.s3.modelBucket }} 33 | {{- end }} -------------------------------------------------------------------------------- /src/charts/sd_on_eks/templates/persistentvolumeclaim.yaml: -------------------------------------------------------------------------------- 1 | {{- if and (.Values.runtime.persistence.enabled) (not (.Values.runtime.persistence.existingClaim)) }} 2 | apiVersion: v1 3 | kind: PersistentVolumeClaim 4 | metadata: 5 | name: {{ include "sdchart.fullname" . }}-model-claim 6 | {{- if .Values.runtime.persistence.annotations }} 7 | annotations: 8 | {{- toYaml .Values.runtime.persistence.annotations | nindent 4 }} 9 | {{- end }} 10 | labels: 11 | {{- include "sdchart.labels" . 
| nindent 4 }} 12 | {{- if .Values.runtime.labels }} 13 | {{- toYaml .Values.runtime.labels | nindent 4 }} 14 | {{- end }} 15 | {{- if .Values.runtime.persistence.labels }} 16 | {{- toYaml .Values.runtime.persistence.labels | nindent 4 }} 17 | {{- end }} 18 | spec: 19 | accessModes: 20 | {{- toYaml .Values.runtime.persistence.accessModes | nindent 2 }} 21 | resources: 22 | requests: 23 | storage: "{{ .Values.runtime.persistence.size }}" 24 | {{- if .Values.runtime.persistence.storageClass }} 25 | {{- if (eq "-" .Values.runtime.persistence.storageClass) }} 26 | storageClassName: "" 27 | {{- else }} 28 | storageClassName: "{{ .Values.runtime.persistence.storageClass }}" 29 | {{- end }} 30 | {{- end }} 31 | {{- if .Values.runtime.persistence.existingVolume }} 32 | volumeName: "{{ .Values.runtime.persistence.existingVolume }}" 33 | {{- end }} 34 | {{- end }} -------------------------------------------------------------------------------- /src/charts/sd_on_eks/values.yaml: -------------------------------------------------------------------------------- 1 | global: 2 | awsRegion: us-west-2 3 | stackName: "" 4 | runtime: "" 5 | 6 | karpenter: 7 | provisioner: 8 | labels: {} 9 | capacityType: 10 | onDemand: true 11 | spot: true 12 | instanceType: 13 | - "g5.xlarge" 14 | - "g5.2xlarge" 15 | extraRequirements: [] 16 | extraTaints: [] 17 | resourceLimits: 18 | nvidia.com/gpu: 100 19 | consolidation: true 20 | disruption: 21 | consolidateAfter: 30s 22 | expireAfter: Never 23 | nodeTemplate: 24 | iamRole: "" 25 | securityGroupSelector: {} 26 | subnetSelector: {} 27 | tags: {} 28 | amiFamily: Bottlerocket 29 | osVolume: 30 | volumeSize: 10Gi 31 | volumeType: gp3 32 | deleteOnTermination: true 33 | iops: 3000 34 | throughput: 125 35 | dataVolume: 36 | volumeSize: 150Gi 37 | volumeType: gp3 38 | deleteOnTermination: true 39 | iops: 4000 40 | throughput: 1000 41 | userData: "" 42 | runtime: 43 | type: "sdwebui" 44 | labels: {} 45 | annotations: {} 46 | serviceAccountName: runtime-sa 47 | replicas: 1 48 | scaling: 49 | enabled: true 50 | queueLength: 10 51 | cooldownPeriod: 60 52 | maxReplicaCount: 20 53 | minReplicaCount: 0 54 | pollingInterval: 1 55 | scaleOnInFlight: false 56 | extraHPAConfig: {} 57 | inferenceApi: 58 | image: 59 | repository: public.ecr.aws/bingjiao/sd-on-eks/sdwebui 60 | tag: latest 61 | modelFilename: "" 62 | modelMountPath: /opt/ml/code/models 63 | commandArguments: "" 64 | extraEnv: {} 65 | imagePullPolicy: IfNotPresent 66 | resources: 67 | limits: 68 | nvidia.com/gpu: "1" 69 | requests: 70 | nvidia.com/gpu: "1" 71 | cpu: 2500m 72 | memory: 6Gi 73 | queueAgent: 74 | image: 75 | repository: public.ecr.aws/bingjiao/sd-on-eks/queue-agent 76 | tag: latest 77 | extraEnv: {} 78 | dynamicModel: false 79 | imagePullPolicy: IfNotPresent 80 | s3Bucket: "" 81 | snsTopicArn: "" 82 | sqsQueueUrl: "" 83 | resources: 84 | requests: 85 | cpu: 500m 86 | memory: 512Mi 87 | xray: 88 | enabled: true 89 | daemon: 90 | image: 91 | repository: public.ecr.aws/xray/aws-xray-daemon 92 | tag: 3.3.14 93 | persistence: 94 | enabled: true 95 | existingClaim: "" 96 | existingVolume: "" 97 | labels: {} 98 | annotations: {} 99 | storageClass: "" 100 | size: 2Ti 101 | accessModes: 102 | - ReadWriteMany 103 | s3: 104 | enabled: true 105 | modelBucket: "" 106 | -------------------------------------------------------------------------------- /src/frontend/input_function/v1alpha1/app.py: -------------------------------------------------------------------------------- 1 | # Copyright Amazon.com, Inc. or its affiliates. 
All Rights Reserved.
2 | # SPDX-License-Identifier: MIT-0
3 | 
4 | import os
5 | import boto3
6 | import json
7 | import traceback
8 | 
9 | sns_client = boto3.client('sns')
10 | 
11 | 
12 | def lambda_handler(event, context):
13 |     if event['httpMethod'] == 'POST':
14 |         try:
15 |             payload = json.loads(event['body'])
16 |             val = validate(payload)
17 |             if val != "success":
18 |                 return {
19 |                     'statusCode': 400,
20 |                     'body': f"Incorrect payload structure, {val}"
21 |                 }
22 | 
23 |             task = payload['alwayson_scripts']['task']
24 |             id_task = payload['alwayson_scripts']['id_task']
25 |             sd_model_checkpoint = payload['alwayson_scripts']['sd_model_checkpoint']
26 |             prefix = payload['alwayson_scripts']['save_dir']
27 |             s3_output_path = f"{os.environ['S3_OUTPUT_BUCKET']}/{prefix}/{id_task}"
28 | 
29 |             payload['s3_output_path'] = s3_output_path
30 | 
31 |             print(event['headers'])  # log request headers for debugging
32 |             print(event['queryStringParameters'])
33 | 
34 |             msg = {"metadata": {
35 |                 "id": id_task,
36 |                 "runtime": "legacy",
37 |                 "tasktype": task,
38 |                 "prefix": prefix,
39 |                 "context": {}
40 |             }, "content": payload}
41 | 
42 |             sns_client.publish(
43 |                 TargetArn=os.environ['SNS_TOPIC_ARN'],
44 |                 Message=json.dumps(msg),
45 |                 MessageAttributes={
46 |                     'sd_model_checkpoint': {
47 |                         'DataType': 'String',
48 |                         'StringValue': sd_model_checkpoint
49 |                     }
50 |                 }
51 |             )
52 | 
53 |             return {
54 |                 'statusCode': 200,
55 |                 'body': json.dumps({
56 |                     "id_task": id_task,
57 |                     "sd_model_checkpoint": sd_model_checkpoint,
58 |                     "output_location": f"s3://{s3_output_path}"
59 |                 })
60 |             }
61 | 
62 |         except Exception as e:
63 |             traceback.print_exc()
64 |             return {
65 |                 'statusCode': 400,
66 |                 'body': str(e)
67 |             }
68 |     else:
69 |         return {
70 |             'statusCode': 400,
71 |             'body': "Unsupported HTTP method"
72 |         }
73 | 
74 | def validate(body: dict) -> str:
75 |     result = "success"  # later checks overwrite earlier ones, so only the last missing field is reported
76 |     if 'alwayson_scripts' not in body.keys():
77 |         result = "alwayson_scripts is missing"
78 |     else:
79 |         if "task" not in body["alwayson_scripts"].keys():
80 |             result = "task is missing"
81 |         if "sd_model_checkpoint" not in body["alwayson_scripts"].keys():
82 |             result = "sd_model_checkpoint is missing"
83 |         if "id_task" not in body["alwayson_scripts"].keys():
84 |             result = "id_task is missing"
85 |     return result
-------------------------------------------------------------------------------- /src/frontend/input_function/v1alpha2/app.py: --------------------------------------------------------------------------------
1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2 | # SPDX-License-Identifier: MIT-0
3 | 
4 | import json
5 | import os
6 | import traceback
7 | 
8 | import boto3
9 | 
10 | sns_client = boto3.client('sns')
11 | 
12 | def lambda_handler(event, context):
13 |     if event['httpMethod'] == 'POST':
14 |         try:
15 |             payload = json.loads(event['body'])["task"]
16 |             val = validate(payload)
17 |             if val != "success":
18 |                 return {
19 |                     'statusCode': 400,
20 |                     'body': f"Incorrect payload structure, {val}"
21 |                 }
22 | 
23 |             task_id = payload["metadata"]["id"]  # "task_id" avoids shadowing Python's built-in id()
24 |             runtime = payload["metadata"]["runtime"]
25 |             prefix = payload["metadata"]["prefix"]
26 |             s3_output_path = f"{os.environ['S3_OUTPUT_BUCKET']}/{prefix}/{task_id}"
27 | 
28 |             print(event['headers'])  # log request headers for debugging
29 |             print(event['queryStringParameters'])
30 | 
31 |             sns_client.publish(
32 |                 TargetArn=os.environ['SNS_TOPIC_ARN'],
33 |                 Message=json.dumps(payload),
34 |                 MessageAttributes={
35 |                     'runtime': {
36 |                         'DataType': 'String',
37 |                         'StringValue': runtime
38 |                     }
39 |                 }
40 |             )
41 | 
42 |             return {
43 |                 'statusCode': 200,
44 |                 'body': json.dumps({
45 |                     "id": task_id,
46 |                     "runtime": runtime,
47 |                     "output_location": f"s3://{s3_output_path}"
48 |                 })
49 |             }
50 | 
51 |         except Exception as e:
52 |             traceback.print_exc()
53 |             return {
54 |                 'statusCode': 400,
55 |                 'body': str(e)
56 |             }
57 |     else:
58 |         return {
59 |             'statusCode': 400,
60 |             'body': "Unsupported HTTP method"
61 |         }
62 | 
63 | 
64 | def validate(body: dict) -> str:
65 |     result = "success"  # later checks overwrite earlier ones, so only the last missing field is reported
66 |     if "metadata" not in body.keys():
67 |         result = "metadata is missing"
68 |     else:
69 |         if "id" not in body["metadata"].keys():
70 |             result = "id is missing"
71 |         if "runtime" not in body["metadata"].keys():
72 |             result = "runtime is missing"
73 |         if "tasktype" not in body["metadata"].keys():
74 |             result = "tasktype is missing"
75 |         if "content" not in body.keys():
76 |             result = "content is missing"
77 |     return result
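
The `validate()` helper above is pure, so it can be smoke-tested locally without touching AWS. A minimal sketch, assuming the file is saved as `app.py` in the working directory; the region variable is only needed because `boto3.client('sns')` runs at import time and is never actually invoked here:

```python
import os

os.environ.setdefault("AWS_DEFAULT_REGION", "us-west-2")  # lets boto3 build the SNS client at import time

import app  # the v1alpha2 input function above

ok = {
    "metadata": {"id": "t1", "runtime": "sdruntime", "tasktype": "text-to-image", "prefix": "output"},
    "content": {},
}
print(app.validate(ok))               # -> "success"
print(app.validate({"content": {}}))  # -> "metadata is missing"
```
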
-------------------------------------------------------------------------------- /src/tools/ebs_throughput_tuner/app.py: --------------------------------------------------------------------------------
1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2 | # SPDX-License-Identifier: MIT-0
3 | 
4 | import boto3, os
5 | 
6 | def get_ebs_volume_id(instance_id):
7 |     ec2 = boto3.client('ec2')
8 |     try:
9 |         response = ec2.describe_instances(InstanceIds=[instance_id])
10 |         volume_id = response['Reservations'][0]['Instances'][0]['BlockDeviceMappings'][-1]['Ebs']['VolumeId']  # last mapping is the data volume
11 |         return volume_id
12 |     except Exception as e:
13 |         print("Error:", e)
14 |         return None
15 | 
16 | def get_instance_tag(instance, key):
17 |     for tag in instance.get('Tags', []):
18 |         if tag['Key'] == key:
19 |             return tag['Value']
20 |     return None
21 | 
22 | def modify_ebs_throughput_and_iops(volume_id, throughput, iops):
23 |     ec2 = boto3.client('ec2')
24 |     try:
25 |         ec2.modify_volume(VolumeId=volume_id, Throughput=throughput, Iops=iops)
26 |         print(f"Successfully modified EBS throughput and IOPS of {volume_id}")
27 |     except Exception as e:
28 |         print("Error:", e)
29 | 
30 | # Entrypoint
31 | def lambda_handler(event, context):
32 |     ec2 = boto3.client('ec2')
33 |     instance_id = event['detail']['instance-id']
34 |     print(f"Processing instance {instance_id}")
35 |     target_ec2_tag_key = os.environ['TARGET_EC2_TAG_KEY']
36 |     target_ec2_tag_value = os.environ['TARGET_EC2_TAG_VALUE']
37 |     # Get the instance tag
38 |     response = ec2.describe_instances(InstanceIds=[instance_id])
39 |     instance = response['Reservations'][0]['Instances'][0]
40 |     instance_tag_value = get_instance_tag(instance, target_ec2_tag_key)
41 |     volume_id = get_ebs_volume_id(instance_id)
42 |     if instance_tag_value:
43 |         # Modify EBS throughput only when the instance tag matches the target value
44 |         if target_ec2_tag_value in instance_tag_value:
45 |             print(f"Got matching tag from {instance_id}, processing...")
46 |             throughput_value = int(os.environ['THROUGHPUT_VALUE'])  # target throughput (MiB/s)
47 |             iops_value = int(os.environ['IOPS_VALUE'])  # target IOPS
48 |             modify_ebs_throughput_and_iops(volume_id, throughput_value, iops_value)
49 |         else:
50 |             print(f"Skipped {instance_id}")
51 | 
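The handler only reads `detail.instance-id`; an abridged EventBridge EC2 state-change event of the kind that would trigger it looks like the following sketch (all identifiers are made up):

```json
{
  "detail-type": "EC2 Instance State-change Notification",
  "source": "aws.ec2",
  "detail": {
    "instance-id": "i-0123456789abcdef0",
    "state": "running"
  }
}
```

-------------------------------------------------------------------------------- /test/run.sh: --------------------------------------------------------------------------------
1 | #!/bin/bash
2 | 
3 | set -e
4 | 
5 | SCRIPTPATH=$(realpath $(dirname "$0"))
6 | STACK_NAME=${STACK_NAME:-"sdoneksStack"}
7 | AWS_DEFAULT_REGION=${AWS_DEFAULT_REGION:-$(aws ec2 describe-availability-zones --output text --query 'AvailabilityZones[0].[RegionName]')}
8 | declare -l RUNTIME_TYPE=${RUNTIME_TYPE:-"sdwebui"}
9 | API_VERSION=${API_VERSION:-"v1alpha2"}
10 | 
11 | API_ENDPOINT=$(aws cloudformation describe-stacks --stack-name ${STACK_NAME} --output text --query 'Stacks[0].Outputs[?OutputKey==`FrontApiEndpoint`].OutputValue')
12 | 
13 | printf "API Endpoint is ${API_ENDPOINT}\n"
14 | 
15 | API_KEY_COMMAND=$(aws cloudformation describe-stacks --stack-name ${STACK_NAME} --output text --query 'Stacks[0].Outputs[?OutputKey==`GetAPIKeyCommand`].OutputValue')
16 | 
17 | API_KEY=$(echo $API_KEY_COMMAND | bash)
18 | 
19 | printf "API Key is ${API_KEY}\n"
20 | 
21 | if [[ ${RUNTIME_TYPE} == "sdwebui" ]]
22 | then
23 |     printf "Generating test text-to-image request... \n"
24 | 
25 |     curl -X POST ${API_ENDPOINT}/${API_VERSION}/ \
26 |         -H "Content-Type: application/json" \
27 |         -H "x-api-key: ${API_KEY}" \
28 |         -d @${SCRIPTPATH}/${API_VERSION}/t2i.json
29 | 
30 |     printf "\nGenerating test image-to-image request... 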
\n" 31 | 32 | curl -X POST ${API_ENDPOINT}/${API_VERSION}/ \ 33 | -H "Content-Type: application/json" \ 34 | -H "x-api-key: ${API_KEY}" \ 35 | -d @${SCRIPTPATH}/${API_VERSION}/i2i.json 36 | 37 | printf "\nGenerating image upscaling request... \n" 38 | 39 | curl -X POST ${API_ENDPOINT}/${API_VERSION}/ \ 40 | -H "Content-Type: application/json" \ 41 | -H "x-api-key: ${API_KEY}" \ 42 | -d @${SCRIPTPATH}/${API_VERSION}/extra-single-image.json 43 | fi 44 | 45 | if [[ ${RUNTIME_TYPE} == "comfyui" ]] 46 | then 47 | printf "Generating test pipeline request... \n" 48 | 49 | curl -X POST ${API_ENDPOINT}/${API_VERSION} \ 50 | -H "Content-Type: application/json" \ 51 | -H "x-api-key: ${API_KEY}" \ 52 | -d @${SCRIPTPATH}/${API_VERSION}/pipeline.json 53 | fi -------------------------------------------------------------------------------- /test/v1alpha1/i2i.json: -------------------------------------------------------------------------------- 1 | { 2 | "alwayson_scripts": { 3 | "task": "image-to-image", 4 | "image_link": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/cat.png", 5 | "id_task": "test-i2i", 6 | "sd_model_checkpoint": "sd_xl_turbo_1.0.safetensors", 7 | "save_dir": "outputs" 8 | }, 9 | "prompt": "cat wizard, gandalf, lord of the rings, detailed, fantasy, cute, adorable, Pixar, Disney, 8k", 10 | "steps": 16, 11 | "width": 512, 12 | "height": 512 13 | } -------------------------------------------------------------------------------- /test/v1alpha1/t2i.json: -------------------------------------------------------------------------------- 1 | { 2 | "alwayson_scripts": { 3 | "task": "text-to-image", 4 | "sd_model_checkpoint": "sd_xl_turbo_1.0.safetensors", 5 | "id_task": "test-t2i", 6 | "save_dir": "outputs" 7 | }, 8 | "prompt": "A dog", 9 | "steps": 16, 10 | "width": 512, 11 | "height": 512 12 | } -------------------------------------------------------------------------------- /test/v1alpha1/t2v.json: -------------------------------------------------------------------------------- 1 | { 2 | "alwayson_scripts": { 3 | "task": "text-to-image", 4 | "sd_model_checkpoint": "sd_xl_turbo_1.0.safetensors", 5 | "id_task": "gif1", 6 | "uid": "gif1", 7 | "save_dir": "outputs", 8 | "AnimateDiff": { 9 | "args": [ 10 | { 11 | "batch_size": 16, 12 | "closed_loop": "N", 13 | "enable": true, 14 | "format": [ 15 | "GIF" 16 | ], 17 | "fps": 8, 18 | "interp": "Off", 19 | "interp_x": 10, 20 | "last_frame": null, 21 | "latent_power": 1, 22 | "latent_scale": 32, 23 | "loop_number": 0, 24 | "model": "mm_sd15_v3.safetensors", 25 | "overlap": -1, 26 | "stride": 1, 27 | "video_length": 32 28 | } 29 | ] 30 | } 31 | }, 32 | "batch_size": 32, 33 | "cfg_scale": 7, 34 | "height": 512, 35 | "negative_prompt": "(worst quality:1.2), (low quality:1.2), (lowres:1.1), (monochrome:1.1), (greyscale), multiple views, comic, sketch, (((bad anatomy))), (((deformed))), (((disfigured))), watermark, multiple_views, mutation hands, mutation fingers, extra fingers, missing fingers, watermark, nude, nsfw", 36 | "prompt": "masterpiece, 30 year old women, cleavage, red hair, bun, ponytail, medium breast, desert, cactus vibe, sensual pose, (looking in the camera:1.2), (front view:1.2), facing the camera, close up, upper body", 37 | "steps": 20, 38 | "width": 512 39 | } -------------------------------------------------------------------------------- /test/v1alpha2/extra-single-image.json: -------------------------------------------------------------------------------- 1 | { 2 | "task": { 3 | "metadata": { 4 | 
"id": "test-extra", 5 | "runtime": "sdruntime", 6 | "tasktype": "extra-single-image", 7 | "prefix": "output", 8 | "context": "" 9 | }, 10 | "content": { 11 | "resize_mode":0, 12 | "show_extras_results":false, 13 | "gfpgan_visibility":0, 14 | "codeformer_visibility":0, 15 | "codeformer_weight":0, 16 | "upscaling_resize":4, 17 | "upscaling_resize_w":512, 18 | "upscaling_resize_h":512, 19 | "upscaling_crop":false, 20 | "upscaler_1":"R-ESRGAN 4x+", 21 | "upscaler_2":"None", 22 | "extras_upscaler_2_visibility":0, 23 | "upscale_first":false, 24 | "image":"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/cat.png" 25 | } 26 | } 27 | } 28 | -------------------------------------------------------------------------------- /test/v1alpha2/i2i.json: -------------------------------------------------------------------------------- 1 | { 2 | "task": { 3 | "metadata": { 4 | "id": "test-i2i", 5 | "runtime": "sdruntime", 6 | "tasktype": "image-to-image", 7 | "prefix": "output", 8 | "context": "" 9 | }, 10 | "content": { 11 | "alwayson_scripts": {}, 12 | "init_images": ["https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/cat.png"], 13 | "prompt": "cat wizard, gandalf, lord of the rings, detailed, fantasy, cute, adorable, Pixar, Disney, 8k", 14 | "steps": 16, 15 | "width": 512, 16 | "height": 512 17 | } 18 | } 19 | } -------------------------------------------------------------------------------- /test/v1alpha2/pipeline.json: -------------------------------------------------------------------------------- 1 | { 2 | "task": { 3 | "metadata": { 4 | "id": "test-pipeline", 5 | "runtime": "sdruntime", 6 | "tasktype": "pipeline", 7 | "prefix": "output", 8 | "context": "" 9 | }, 10 | "content": { 11 | "3": { 12 | "inputs": { 13 | "seed": 89465186141685, 14 | "steps": 20, 15 | "cfg": 8, 16 | "sampler_name": "euler", 17 | "scheduler": "normal", 18 | "denoise": 1, 19 | "model": [ 20 | "4", 21 | 0 22 | ], 23 | "positive": [ 24 | "6", 25 | 0 26 | ], 27 | "negative": [ 28 | "7", 29 | 0 30 | ], 31 | "latent_image": [ 32 | "5", 33 | 0 34 | ] 35 | }, 36 | "class_type": "KSampler", 37 | "_meta": { 38 | "title": "KSampler" 39 | } 40 | }, 41 | "4": { 42 | "inputs": { 43 | "ckpt_name": "sd_xl_turbo_1.0.safetensors" 44 | }, 45 | "class_type": "CheckpointLoaderSimple", 46 | "_meta": { 47 | "title": "Load Checkpoint" 48 | } 49 | }, 50 | "5": { 51 | "inputs": { 52 | "width": 512, 53 | "height": 512, 54 | "batch_size": 1 55 | }, 56 | "class_type": "EmptyLatentImage", 57 | "_meta": { 58 | "title": "Empty Latent Image" 59 | } 60 | }, 61 | "6": { 62 | "inputs": { 63 | "text": "beautiful scenery nature glass bottle landscape, , purple galaxy bottle,", 64 | "clip": [ 65 | "4", 66 | 1 67 | ] 68 | }, 69 | "class_type": "CLIPTextEncode", 70 | "_meta": { 71 | "title": "CLIP Text Encode (Prompt)" 72 | } 73 | }, 74 | "7": { 75 | "inputs": { 76 | "text": "text, watermark", 77 | "clip": [ 78 | "4", 79 | 1 80 | ] 81 | }, 82 | "class_type": "CLIPTextEncode", 83 | "_meta": { 84 | "title": "CLIP Text Encode (Prompt)" 85 | } 86 | }, 87 | "8": { 88 | "inputs": { 89 | "samples": [ 90 | "3", 91 | 0 92 | ], 93 | "vae": [ 94 | "4", 95 | 2 96 | ] 97 | }, 98 | "class_type": "VAEDecode", 99 | "_meta": { 100 | "title": "VAE Decode" 101 | } 102 | }, 103 | "9": { 104 | "inputs": { 105 | "filename_prefix": "ComfyUI", 106 | "images": [ 107 | "8", 108 | 0 109 | ] 110 | }, 111 | "class_type": "SaveImage", 112 | "_meta": { 113 | "title": "Save Image" 114 | } 115 | } 116 | } 117 | } 118 | } 
-------------------------------------------------------------------------------- /test/v1alpha2/t2i.json: -------------------------------------------------------------------------------- 1 | { 2 | "task": { 3 | "metadata": { 4 | "id": "test-t2i", 5 | "runtime": "sdruntime", 6 | "tasktype": "text-to-image", 7 | "prefix": "output", 8 | "context": "" 9 | }, 10 | "content": { 11 | "alwayson_scripts": {}, 12 | "prompt": "A dog", 13 | "steps": 16, 14 | "width": 512, 15 | "height": 512 16 | } 17 | } 18 | } -------------------------------------------------------------------------------- /tsconfig.json: -------------------------------------------------------------------------------- 1 | { 2 | "compilerOptions": { 3 | "target": "ES2020", 4 | "module": "commonjs", 5 | "lib": [ 6 | "es2020", 7 | "dom" 8 | ], 9 | "declaration": true, 10 | "strict": true, 11 | "noImplicitAny": true, 12 | "strictNullChecks": true, 13 | "noImplicitThis": true, 14 | "alwaysStrict": true, 15 | "noUnusedLocals": false, 16 | "noUnusedParameters": false, 17 | "noImplicitReturns": true, 18 | "noFallthroughCasesInSwitch": false, 19 | "inlineSourceMap": true, 20 | "inlineSources": true, 21 | "experimentalDecorators": true, 22 | "strictPropertyInitialization": false, 23 | "typeRoots": [ 24 | "./node_modules/@types" 25 | ] 26 | }, 27 | "exclude": [ 28 | "node_modules", 29 | "cdk.out" 30 | ] 31 | } 32 | --------------------------------------------------------------------------------
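
For reference, test/run.sh above is driven entirely by environment variables; a plausible ComfyUI smoke-test invocation, assuming the default sdoneksStack stack name, is:

```bash
STACK_NAME=sdoneksStack \
RUNTIME_TYPE=comfyui \
API_VERSION=v1alpha2 \
AWS_DEFAULT_REGION=us-west-2 \
./test/run.sh
```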