├── .github
│   ├── ISSUE_TEMPLATE
│   │   ├── bug_report.md
│   │   └── feature_request.md
│   ├── dependabot.yml
│   ├── solutionid_validator.sh
│   └── workflows
│       ├── maintainer_workflows.yml
│       └── pr-triage.yml
├── .gitignore
├── .gitmodules
├── .npmignore
├── CODEOWNERS
├── CODE_OF_CONDUCT.md
├── CONTRIBUTING.md
├── LICENSE
├── NOTICE
├── README.md
├── bin
│   └── stable-difussion-on-eks.ts
├── cdk.json
├── config.schema.yaml
├── config.yaml
├── deploy
│   ├── README.md
│   ├── config.yaml.template
│   ├── deploy.sh
│   ├── install-tools.sh
│   ├── update.sh
│   └── upload-model.sh
├── docs
│   ├── Gemfile
│   ├── Gemfile.lock
│   ├── README.md
│   ├── api
│   │   ├── v1alpha1.yaml
│   │   └── v1alpha2.yaml
│   ├── en
│   │   ├── _config.yml
│   │   ├── _includes
│   │   │   └── image.html
│   │   ├── async_img_sd_IG.md
│   │   ├── async_img_sd_images
│   │   │   ├── .gitkeep
│   │   │   └── asynchronous-image-generation-with-stable-diffusion-on-aws.png
│   │   └── async_img_sd_sidebar.yml
│   ├── stable-diffusion-reference-architecture-updated.png
│   └── zh
│       ├── _config.yml
│       ├── _includes
│       │   └── image.html
│       ├── async_img_sd_zh_IG.md
│       ├── async_img_sd_zh_images
│       │   ├── .gitkeep
│       │   └── stable-diffusion-reference-architecture-updated.png
│       └── async_img_sd_zh_sidebar.yml
├── lib
│   ├── addons
│   │   ├── dcgmExporter.ts
│   │   ├── ebsThroughputTuner.ts
│   │   ├── s3CSIDriver.ts
│   │   └── sharedComponent.ts
│   ├── dataPlane.ts
│   ├── resourceProvider
│   │   ├── s3GWEndpoint.ts
│   │   ├── sns.ts
│   │   └── vpc.ts
│   ├── runtime
│   │   └── sdRuntime.ts
│   └── utils
│       ├── namespace.ts
│       └── validateConfig.ts
├── load_test
│   └── locustfile.py
├── package-lock.json
├── package.json
├── src
│   ├── backend
│   │   └── queue_agent
│   │       ├── .gitignore
│   │       ├── Dockerfile
│   │       ├── requirements.txt
│   │       └── src
│   │           ├── __init__.py
│   │           ├── main.py
│   │           ├── modules
│   │           │   ├── __init__.py
│   │           │   ├── http_action.py
│   │           │   ├── misc.py
│   │           │   ├── s3_action.py
│   │           │   ├── sns_action.py
│   │           │   ├── sqs_action.py
│   │           │   └── time_utils.py
│   │           └── runtimes
│   │               ├── __init__.py
│   │               ├── comfyui.py
│   │               └── sdwebui.py
│   ├── charts
│   │   └── sd_on_eks
│   │       ├── Chart.yaml
│   │       ├── _helpers.tpl
│   │       ├── templates
│   │       │   ├── _helpers.tpl
│   │       │   ├── aws-sqs-queue-scaledobject.yaml
│   │       │   ├── configmap-comfyui.yml
│   │       │   ├── configmap-queue-agent.yml
│   │       │   ├── configmap-sdwebui.yml
│   │       │   ├── deployment-comfyui.yaml
│   │       │   ├── deployment-sdwebui.yaml
│   │       │   ├── keda-trigger-auth-aws-credentials.yaml
│   │       │   ├── nodeclass.yaml
│   │       │   ├── nodepool.yaml
│   │       │   ├── persistentvolume-s3.yaml
│   │       │   └── persistentvolumeclaim.yaml
│   │       └── values.yaml
│   ├── frontend
│   │   └── input_function
│   │       ├── v1alpha1
│   │       │   └── app.py
│   │       └── v1alpha2
│   │           └── app.py
│   └── tools
│       └── ebs_throughput_tuner
│           └── app.py
├── test
│   ├── run.sh
│   ├── v1alpha1
│   │   ├── i2i.json
│   │   ├── t2i.json
│   │   └── t2v.json
│   └── v1alpha2
│       ├── extra-single-image.json
│       ├── i2i.json
│       ├── pipeline.json
│       └── t2i.json
└── tsconfig.json
/.github/ISSUE_TEMPLATE/bug_report.md:
--------------------------------------------------------------------------------
1 | ---
2 | name: Bug report
3 | about: Create a report to help us improve
4 | title: "[BUG]"
5 | labels: bug
6 | assignees: ''
7 |
8 | ---
9 |
10 | **What happened**:
11 |
14 |
15 | **To Reproduce**
16 | Steps to reproduce the behavior:
17 |
18 | **Attach logs**
19 |
20 | **What you expected to happen**:
21 |
22 | **Anything else we need to know?**:
23 |
24 | **Environment**:
25 |
28 |
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/feature_request.md:
--------------------------------------------------------------------------------
1 | ---
2 | name: Feature request
3 | about: Suggest an idea for this project
4 | title: "[Feature]"
5 | labels: enhancement
6 | assignees: ''
7 |
8 | ---
9 |
10 | **Is your feature request related to a problem? Please describe.**
11 | A clear and concise description of what the problem is. Ex. I'm always frustrated when [...]
12 |
13 | **Describe the solution you'd like**
14 | A clear and concise description of what you want to happen.
15 |
16 | **Describe alternatives you've considered**
17 | A clear and concise description of any alternative solutions or features you've considered.
18 |
19 | **Additional context**
20 | Add any other context or screenshots about the feature request here.
21 |
--------------------------------------------------------------------------------
/.github/dependabot.yml:
--------------------------------------------------------------------------------
1 | version: 2
2 | updates:
3 | - package-ecosystem: "npm"
4 | directory: "/"
5 | schedule:
6 | interval: "monthly"
7 | ignore:
8 | - dependency-name: "*"
9 | update-types: ["version-update:semver-patch", "version-update:semver-minor"]
10 | - dependency-name: "aws-cdk-lib"
11 | labels:
12 | - "kind/dependencies"
13 | - "component/cdk"
14 |
15 | - package-ecosystem: "pip"
16 | directory: "/src/backend/queue_agent"
17 | schedule:
18 | interval: "monthly"
19 | ignore:
20 | - dependency-name: "*"
21 | update-types: ["version-update:semver-patch", "version-update:semver-minor"]
22 | labels:
23 | - "kind/dependencies"
24 | - "component/queue-agent"
25 |
--------------------------------------------------------------------------------
/.github/solutionid_validator.sh:
--------------------------------------------------------------------------------
1 | #!/bin/sh
2 | #set -e
3 |
4 | echo "checking solution id $1"
5 | echo "grep -nr --exclude-dir='.github' "$1" ./.."
6 | result=$(grep -nr --exclude-dir='.github' "$1" ./..)
7 | if [ $? -eq 0 ]
8 | then
9 | echo "Solution ID $1 found\n"
10 | echo "$result"
11 | exit 0
12 | else
13 | echo "Solution ID $1 not found"
14 | exit 1
15 | fi
16 |
--------------------------------------------------------------------------------
/.github/workflows/maintainer_workflows.yml:
--------------------------------------------------------------------------------
1 | # Workflows managed by aws-solutions-library-samples maintainers
2 | name: Maintainer Workflows
3 | on:
4 | # Triggers the workflow on push or pull request events but only for the "main" branch
5 | push:
6 | branches: [ "main" ]
7 | pull_request:
8 | branches: [ "main" ]
9 | types: [opened, reopened, edited]
10 |
11 | jobs:
12 | CheckSolutionId:
13 | runs-on: ubuntu-latest
14 | steps:
15 | - uses: actions/checkout@v4
16 | - name: Run solutionid validator
17 | run: |
18 | chmod u+x ./.github/solutionid_validator.sh
19 | ./.github/solutionid_validator.sh ${{ vars.SOLUTIONID }}
--------------------------------------------------------------------------------
/.github/workflows/pr-triage.yml:
--------------------------------------------------------------------------------
1 | name: Auto Approve document-related PR
2 |
3 | on:
4 | pull_request_target:
5 | types: [opened, synchronize]
6 |
7 | jobs:
8 | auto-approve:
9 | runs-on: ubuntu-latest
10 | permissions:
11 | pull-requests: write
12 | steps:
13 | - uses: actions/checkout@v4
14 | - name: Get PR author
15 | id: get-author
16 | run: |
17 | PR_AUTHOR=$(jq -r '.pull_request.user.login' "$GITHUB_EVENT_PATH")
18 | echo "PR author: $PR_AUTHOR"
19 | echo "pr-author=$PR_AUTHOR" >> $GITHUB_OUTPUT
20 | - name: Check author permission
21 | id: author-permission
22 | uses: actions-cool/check-user-permission@v2
23 | with:
24 | username: ${{ steps.get-author.outputs.pr-author }}
25 | require: write
26 | - name: Auto-approve for break glass PR
27 | id: break-glass
28 | if: ${{ (steps.author-permission.outputs.require-result == 'true') && (startsWith(github.event.pull_request.title, '[Break Glass]')) }}
29 | uses: hmarr/auto-approve-action@v4
30 | - name: Auto-approve if author is able to write and contains only doc change
31 | id: doc-change
32 | if: steps.author-permission.outputs.require-result == 'true'
33 | uses: actions/github-script@v7
34 | with:
35 | github-token: ${{ secrets.GITHUB_TOKEN }}
36 | script: |
37 | const { data: files } = await github.rest.pulls.listFiles({
38 | owner: context.repo.owner,
39 | repo: context.repo.repo,
40 | pull_number: context.issue.number
41 | })
42 |
43 | const onlyDocChanges = files.every(file => file.filename.startsWith('docs/'))
44 |
45 | if (onlyDocChanges) {
46 | await github.rest.pulls.createReview({
47 | owner: context.repo.owner,
48 | repo: context.repo.repo,
49 | pull_number: context.issue.number,
50 | event: 'APPROVE'
51 | })
52 | console.log('Auto-approved PR by author')
53 | } else {
54 | console.log('PR does not meet auto-approval criteria')
55 | }
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | ### Helm ###
2 | # Chart dependencies
3 | **/charts/*.tgz
4 |
5 | ### Node ###
6 | # Logs
7 | logs
8 | *.log
9 | npm-debug.log*
10 | yarn-debug.log*
11 | yarn-error.log*
12 | lerna-debug.log*
13 | .pnpm-debug.log*
14 |
15 | # Diagnostic reports (https://nodejs.org/api/report.html)
16 | report.[0-9]*.[0-9]*.[0-9]*.[0-9]*.json
17 |
18 | # Runtime data
19 | pids
20 | *.pid
21 | *.seed
22 | *.pid.lock
23 |
24 | # Directory for instrumented libs generated by jscoverage/JSCover
25 | lib-cov
26 |
27 | # Coverage directory used by tools like istanbul
28 | coverage
29 | *.lcov
30 |
31 | # nyc test coverage
32 | .nyc_output
33 |
34 | # Grunt intermediate storage (https://gruntjs.com/creating-plugins#storing-task-files)
35 | .grunt
36 |
37 | # Bower dependency directory (https://bower.io/)
38 | bower_components
39 |
40 | # node-waf configuration
41 | .lock-wscript
42 |
43 | # Compiled binary addons (https://nodejs.org/api/addons.html)
44 | build/Release
45 |
46 | # Dependency directories
47 | node_modules/
48 | jspm_packages/
49 |
50 | # Snowpack dependency directory (https://snowpack.dev/)
51 | web_modules/
52 |
53 | # TypeScript cache
54 | *.tsbuildinfo
55 |
56 | # Optional npm cache directory
57 | .npm
58 |
59 | # Optional eslint cache
60 | .eslintcache
61 |
62 | # Optional stylelint cache
63 | .stylelintcache
64 |
65 | # Microbundle cache
66 | .rpt2_cache/
67 | .rts2_cache_cjs/
68 | .rts2_cache_es/
69 | .rts2_cache_umd/
70 |
71 | # Optional REPL history
72 | .node_repl_history
73 |
74 | # Output of 'npm pack'
75 | *.tgz
76 |
77 | # Yarn Integrity file
78 | .yarn-integrity
79 |
80 | # dotenv environment variable files
81 | .env
82 | .env.development.local
83 | .env.test.local
84 | .env.production.local
85 | .env.local
86 |
87 | # parcel-bundler cache (https://parceljs.org/)
88 | .cache
89 | .parcel-cache
90 |
91 | # Next.js build output
92 | .next
93 | out
94 |
95 | # Nuxt.js build / generate output
96 | .nuxt
97 | dist
98 |
99 | # Gatsby files
100 | .cache/
101 | # Comment in the public line if your project uses Gatsby and not Next.js
102 | # https://nextjs.org/blog/next-9-1#public-directory-support
103 | # public
104 |
105 | # vuepress build output
106 | .vuepress/dist
107 |
108 | # vuepress v2.x temp and cache directory
109 | .temp
110 |
111 | # Docusaurus cache and generated files
112 | .docusaurus
113 |
114 | # Serverless directories
115 | .serverless/
116 |
117 | # FuseBox cache
118 | .fusebox/
119 |
120 | # DynamoDB Local files
121 | .dynamodb/
122 |
123 | # TernJS port file
124 | .tern-port
125 |
126 | # Stores VSCode versions used for testing VSCode extensions
127 | .vscode-test
128 | .vscode
129 |
130 | # yarn v2
131 | .yarn/cache
132 | .yarn/unplugged
133 | .yarn/build-state.yml
134 | .yarn/install-state.gz
135 | .pnp.*
136 |
137 | ### Node Patch ###
138 | # Serverless Webpack directories
139 | .webpack/
140 |
141 | # Optional stylelint cache
142 |
143 | # SvelteKit build / generate output
144 | .svelte-kit
145 |
146 |
147 | # CDK asset staging directory
148 | .cdk.staging
149 | cdk.out
150 | cdk.context.json
151 |
152 | *.js
153 | !jest.config.js
154 | *.d.ts
155 |
156 | ### macOS ###
157 | # General
158 | .DS_Store
159 | .AppleDouble
160 | .LSOverride
161 |
162 | # Icon must end with two \r
163 | Icon
164 |
165 |
166 | # Thumbnails
167 | ._*
168 |
169 | # Files that might appear in the root of a volume
170 | .DocumentRevisions-V100
171 | .fseventsd
172 | .Spotlight-V100
173 | .TemporaryItems
174 | .Trashes
175 | .VolumeIcon.icns
176 | .com.apple.timemachine.donotpresent
177 |
178 | # Directories potentially created on remote AFP share
179 | .AppleDB
180 | .AppleDesktop
181 | Network Trash Folder
182 | Temporary Items
183 | .apdisk
184 |
185 | ### macOS Patch ###
186 | # iCloud generated files
187 | *.icloud
188 |
189 | ### JupyterNotebooks ###
190 | # gitignore template for Jupyter Notebooks
191 | # website: http://jupyter.org/
192 |
193 | .ipynb_checkpoints
194 | */.ipynb_checkpoints/*
195 |
196 | # IPython
197 | profile_default/
198 | ipython_config.py
199 |
200 | # mkdocs
201 | site/
202 |
203 | ### Python ###
204 | # Byte-compiled / optimized / DLL files
205 | __pycache__/
206 | *.py[cod]
207 | *$py.class
208 |
209 | # C extensions
210 | *.so
211 |
212 | # Distribution / packaging
213 | .Python
214 | build/
215 | develop-eggs/
216 | dist/
217 | downloads/
218 | eggs/
219 | .eggs/
220 | lib64/
221 | parts/
222 | sdist/
223 | var/
224 | wheels/
225 | share/python-wheels/
226 | *.egg-info/
227 | .installed.cfg
228 | *.egg
229 | MANIFEST
230 |
231 | # PyInstaller
232 | # Usually these files are written by a python script from a template
233 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
234 | *.manifest
235 | *.spec
236 |
237 | # Installer logs
238 | pip-log.txt
239 | pip-delete-this-directory.txt
240 |
241 | # Unit test / coverage reports
242 | htmlcov/
243 | .tox/
244 | .nox/
245 | .coverage
246 | .coverage.*
247 | .cache
248 | nosetests.xml
249 | coverage.xml
250 | *.cover
251 | *.py,cover
252 | .hypothesis/
253 | .pytest_cache/
254 | cover/
255 |
256 | # Translations
257 | *.mo
258 | *.pot
259 |
260 | # Django stuff:
261 | *.log
262 | local_settings.py
263 | db.sqlite3
264 | db.sqlite3-journal
265 |
266 | # Flask stuff:
267 | instance/
268 | .webassets-cache
269 |
270 | # Scrapy stuff:
271 | .scrapy
272 |
273 | # Sphinx documentation
274 | docs/_build/
275 |
276 | # PyBuilder
277 | .pybuilder/
278 | target/
279 |
280 | # pyenv
281 | # For a library or package, you might want to ignore these files since the code is
282 | # intended to run in multiple environments; otherwise, check them in:
283 | # .python-version
284 |
285 | # pipenv
286 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
287 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
288 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
289 | # install all needed dependencies.
290 | #Pipfile.lock
291 |
292 | # poetry
293 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
294 | # This is especially recommended for binary packages to ensure reproducibility, and is more
295 | # commonly ignored for libraries.
296 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
297 | #poetry.lock
298 |
299 | # pdm
300 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
301 | #pdm.lock
302 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
303 | # in version control.
304 | # https://pdm.fming.dev/#use-with-ide
305 | .pdm.toml
306 |
307 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
308 | __pypackages__/
309 |
310 | # Celery stuff
311 | celerybeat-schedule
312 | celerybeat.pid
313 |
314 | # SageMath parsed files
315 | *.sage.py
316 |
317 | # Environments
318 | .env
319 | .venv
320 | env/
321 | venv/
322 | ENV/
323 | env.bak/
324 | venv.bak/
325 |
326 | # Spyder project settings
327 | .spyderproject
328 | .spyproject
329 |
330 | # Rope project settings
331 | .ropeproject
332 |
333 | # mypy
334 | .mypy_cache/
335 | .dmypy.json
336 | dmypy.json
337 |
338 | # Pyre type checker
339 | .pyre/
340 |
341 | # pytype static type analyzer
342 | .pytype/
343 |
344 | # Cython debug symbols
345 | cython_debug/
346 |
347 | # PyCharm
348 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
349 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
350 | # and can be added to the global gitignore or merged into this file. For a more nuclear
351 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
352 | #.idea/
353 |
354 | ### Python Patch ###
355 | # Poetry local configuration file - https://python-poetry.org/docs/configuration/#local-configuration
356 | poetry.toml
357 |
358 | # ruff
359 | .ruff_cache/
360 |
361 | # LSP config files
362 | pyrightconfig.json
363 |
364 | manual.md
365 | config.yaml
366 |
367 | ### Jekyll ###
368 | _site/
369 | .sass-cache/
370 | .jekyll-cache/
371 | .jekyll-metadata
372 | # Ignore folders generated by Bundler
373 | .bundle/
374 | vendor/
--------------------------------------------------------------------------------
/.gitmodules:
--------------------------------------------------------------------------------
1 | [submodule "utils/bottlerocket-images-cache"]
2 | path = utils/bottlerocket-images-cache
3 | url = https://github.com/aws-samples/bottlerocket-images-cache
4 |
--------------------------------------------------------------------------------
/.npmignore:
--------------------------------------------------------------------------------
1 | *.ts
2 | !*.d.ts
3 |
4 | # CDK asset staging directory
5 | .cdk.staging
6 | cdk.out
7 |
--------------------------------------------------------------------------------
/CODEOWNERS:
--------------------------------------------------------------------------------
1 | CODEOWNERS @aws-solutions-library-samples/maintainers @yubingjiaocn
2 | /.github/workflows/maintainer_workflows.yml @aws-solutions-library-samples/maintainers
3 | /.github/solutionid_validator.sh @aws-solutions-library-samples/maintainers
4 |
--------------------------------------------------------------------------------
/CODE_OF_CONDUCT.md:
--------------------------------------------------------------------------------
1 | ## Code of Conduct
2 | This project has adopted the [Amazon Open Source Code of Conduct](https://aws.github.io/code-of-conduct).
3 | For more information see the [Code of Conduct FAQ](https://aws.github.io/code-of-conduct-faq) or contact
4 | opensource-codeofconduct@amazon.com with any additional questions or comments.
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # Contributing Guidelines
2 |
3 | Thank you for your interest in contributing to our project. Whether it's a bug report, new feature, correction, or additional
4 | documentation, we greatly value feedback and contributions from our community.
5 |
6 | Please read through this document before submitting any issues or pull requests to ensure we have all the necessary
7 | information to effectively respond to your bug report or contribution.
8 |
9 |
10 | ## Reporting Bugs/Feature Requests
11 |
12 | We welcome you to use the GitHub issue tracker to report bugs or suggest features.
13 |
14 | When filing an issue, please check existing open, or recently closed, issues to make sure somebody else hasn't already
15 | reported the issue. Please try to include as much information as you can. Details like these are incredibly useful:
16 |
17 | * A reproducible test case or series of steps
18 | * The version of our code being used
19 | * Any modifications you've made relevant to the bug
20 | * Anything unusual about your environment or deployment
21 |
22 |
23 | ## Contributing via Pull Requests
24 | Contributions via pull requests are much appreciated. Before sending us a pull request, please ensure that:
25 |
26 | 1. You are working against the latest source on the *main* branch.
27 | 2. You check existing open, and recently merged, pull requests to make sure someone else hasn't addressed the problem already.
28 | 3. You open an issue to discuss any significant work - we would hate for your time to be wasted.
29 |
30 | To send us a pull request, please:
31 |
32 | 1. Fork the repository.
33 | 2. Modify the source; please focus on the specific change you are contributing. If you also reformat all the code, it will be hard for us to focus on your change.
34 | 3. Ensure local tests pass.
35 | 4. Commit to your fork using clear commit messages.
36 | 5. Send us a pull request, answering any default questions in the pull request interface.
37 | 6. Pay attention to any automated CI failures reported in the pull request, and stay involved in the conversation.
38 |
39 | GitHub provides additional documentation on [forking a repository](https://help.github.com/articles/fork-a-repo/) and
40 | [creating a pull request](https://help.github.com/articles/creating-a-pull-request/).
41 |
42 |
43 | ## Finding contributions to work on
44 | Looking at the existing issues is a great way to find something to work on. As our projects, by default, use the default GitHub issue labels (enhancement/bug/duplicate/help wanted/invalid/question/wontfix), looking at any 'help wanted' issues is a great place to start.
45 |
46 |
47 | ## Code of Conduct
48 | This project has adopted the [Amazon Open Source Code of Conduct](https://aws.github.io/code-of-conduct).
49 | For more information see the [Code of Conduct FAQ](https://aws.github.io/code-of-conduct-faq) or contact
50 | opensource-codeofconduct@amazon.com with any additional questions or comments.
51 |
52 |
53 | ## Security issue notifications
54 | If you discover a potential security issue in this project, we ask that you notify AWS/Amazon Security via our [vulnerability reporting page](http://aws.amazon.com/security/vulnerability-reporting/). Please do **not** create a public GitHub issue.
55 |
56 |
57 | ## Licensing
58 |
59 | See the [LICENSE](LICENSE) file for our project's licensing. We will ask you to confirm the licensing of your contribution.
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT No Attribution
2 |
3 | Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy of
6 | this software and associated documentation files (the "Software"), to deal in
7 | the Software without restriction, including without limitation the rights to
8 | use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
9 | the Software, and to permit persons to whom the Software is furnished to do so.
10 |
11 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
12 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
13 | FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
14 | COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
15 | IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
16 | CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
17 |
--------------------------------------------------------------------------------
/NOTICE:
--------------------------------------------------------------------------------
1 | Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2 |
3 | **********************
4 | THIRD PARTY COMPONENTS
5 | **********************
6 |
7 | This software includes third party software subject to the following licenses.
8 |
9 | KEDA
10 | Apache-2.0 license
11 | https://github.com/kedacore/keda/blob/main/LICENSE
12 |
13 | Karpenter
14 | Apache-2.0 license
15 | https://github.com/aws/karpenter/blob/main/LICENSE
16 |
17 | AWS CDK
18 | Apache-2.0 license
19 | https://github.com/aws/aws-cdk/blob/main/LICENSE
20 |
21 | Amazon EKS Blueprints for CDK
22 | Apache-2.0 license
23 | https://github.com/aws-quickstart/cdk-eks-blueprints/blob/main/LICENSE
24 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # [Guidance for Asynchronous Inference with Stable Diffusion on AWS](https://aws.amazon.com/solutions/guidance/asynchronous-image-generation-with-stable-diffusion-on-aws/)
2 |
3 | Implement a fast-scaling, low-cost Stable Diffusion inference solution with serverless and container services on AWS
4 |
5 | [Stable Diffusion](https://aws.amazon.com/what-is/stable-diffusion/) is a popular open source project for generating images using generative AI. Building a scalable and cost-efficient Machine Learning (ML) inference solution is a common challenge many AWS customers face.
6 | This project shows how to use serverless architecture and container services to build an end-to-end, low-cost, rapidly scaling asynchronous image generation architecture. This repo contains the sample code and CDK deployment scripts that help you deploy the solution in a few steps.
7 |
8 | ## Features
9 |
10 | - Asynchronous API and [Serverless Event-Driven Architecture](https://docs.aws.amazon.com/wellarchitected/latest/serverless-applications-lens/event-driven-architectures.html)
11 | - Image Generation with open source Stable Diffusion runtimes running on [Amazon EKS](https://aws.amazon.com/eks)
12 | - Automatic [Amazon SQS](https://aws.amazon.com/sqs/) queue length based scaling with [KEDA](https://keda.sh/)
13 | - Automatic provisioning of EC2 instances for Amazon EKS compute nodes with [Karpenter](https://karpenter.sh/)
14 | - Scaling up new Amazon EKS nodes within 2 minutes to run Inference tasks
15 | - Saving up to 70% with [AWS GPU](https://aws.amazon.com/ec2/instance-types/g5/) spot EC2 instances
16 |
17 | ## Architecture diagram
18 |
19 |
20 | ![Asynchronous Image Generation with Stable Diffusion on AWS reference architecture](docs/stable-diffusion-reference-architecture-updated.png)
21 |
22 |
23 | Figure 1: Asynchronous Image Generation with Stable Diffusion on AWS reference architecture
24 |
25 |
26 | ### Architecture steps
27 |
28 | 1. An application sends the prompt to [Amazon API Gateway](https://aws.amazon.com/api-gateway/), which acts as the endpoint for the overall Guidance, including authentication. An [AWS Lambda](https://aws.amazon.com/lambda/) function validates the requests, publishes them to the designated [Amazon Simple Notification Service](https://aws.amazon.com/sns/) (Amazon SNS) topic, and immediately returns a response.
29 | 2. Amazon SNS publishes the message to [Amazon Simple Queue Service](https://aws.amazon.com/sqs/) (Amazon SQS) queues. Each message contains a Stable Diffusion (SD) runtime name attribute and will be delivered to the queues with matching SD runtime.
30 | 3. In the [Amazon Elastic Kubernetes Service](https://aws.amazon.com/eks/) (Amazon EKS) cluster, the previously deployed open source [Kubernetes Event Driven Auto-Scaler (KEDA)](https://keda.sh) scales up new pods to process the incoming messages from SQS model processing queues.
31 | 4. In the Amazon EKS cluster, the previously deployed open source Kubernetes auto-scaler, [Karpenter](https://karpenter.sh), launches new compute nodes based on GPU [Amazon Elastic Compute Cloud](https://aws.amazon.com/ec2/) (Amazon EC2) instances (such as g4, g5, and p4) to schedule pending pods. The instances use pre-cached SD Runtime images and are based on [Bottlerocket OS](https://aws.amazon.com/bottlerocket/) for fast boot. The instance can be launched with on-demand or [spot](https://aws.amazon.com/ec2/spot) pricing model.
32 | 5. Stable Diffusion Runtimes load ML model files from [Amazon Simple Storage Service](https://aws.amazon.com/s3/) (Amazon S3) via [Mountpoint for Amazon S3 CSI Driver](https://github.com/awslabs/mountpoint-s3-csi-driver) on runtime initialization or on demand.
33 | 6. Queue agents (a software component created for this Guidance) receive messages from the SQS model processing queues and convert them into inputs for SD Runtime API calls.
34 | 7. Queue agents call SD Runtime APIs, receive and decode responses, and save the generated images to designated Amazon S3 buckets.
35 | 8. Queue agents send completion notifications to the designated SNS topic from the pods; users receive notifications from SNS and can access the generated images in the S3 buckets.
36 |
37 |
38 | ### AWS services in this Guidance
39 |
40 | | **AWS service** | Description |
41 | |-----------|------------|
42 | |[Amazon Elastic Kubernetes Service - EKS](https://aws.amazon.com/eks/)|Core service - application platform that hosts the containerized SD workloads|
43 | |[Amazon Virtual Private Cloud - VPC](https://aws.amazon.com/vpc/)| Core service - network security layer |
44 | |[Amazon Elastic Compute Cloud - EC2](https://aws.amazon.com/ec2/)| Core service - EC2 instances power the On-Demand and Spot based EKS compute node groups that run the container workloads|
45 | |[Amazon Elastic Container Registry - ECR](https://aws.amazon.com/ecr/)|Core service - ECR hosts the container images and Helm charts|
46 | |[Amazon Simple Storage Service S3](https://aws.amazon.com/s3/)|Core service - object storage for model files and generated images|
47 | |[Amazon API Gateway](https://aws.amazon.com/api-gateway/)| Core service - endpoint for all user requests|
48 | |[AWS Lambda](https://aws.amazon.com/lambda/)| Core service - validates the requests, publishes them to the designated queues |
49 | |[Amazon Simple Queue Service](https://aws.amazon.com/sqs/)| Core service - provides asynchronous event handling |
50 | |[Amazon Simple Notification Service](https://aws.amazon.com/sns/)| Core service - provides model specific event processing |
51 | |[Amazon CloudWatch](https://aws.amazon.com/cloudwatch/)|Auxiliary service - provides observability for core services |
52 | |[AWS CDK](https://aws.amazon.com/cdk/) | Core service - Used for deploying and updating this solution|
53 |
54 | ### Cost
55 |
56 | You are responsible for the cost of the AWS services used while running this Guidance. As of April 2024, the cost of running this
57 | Guidance with the default settings in the US West (Oregon) Region for one month, generating one million images, is approximately **$436.42**, as illustrated in the two sample tables below (excluding free tiers).
58 |
59 | We recommend creating a [budget](https://docs.aws.amazon.com/cost-management/latest/userguide/budgets-create.html) through [AWS Cost Explorer](http://aws.amazon.com/aws-cost-management/aws-cost-explorer/) to help monitor and manage costs. Prices are subject to change. For full details, refer to the pricing webpage for each AWS service used in this Guidance.
60 |
61 | The main services and their pricing for usage related to the number of images are listed below (per one million images):
62 |
63 | | **AWS Service** | **Billing Dimension** | **Quantity per 1M Images** | **Unit Price \[USD\]** | **Total \[USD\]** |
64 | |-----------|------------|------------|------------|------------|
65 | | Amazon EC2 | g5.2xlarge instance, Spot instance per hour | 416.67\* | \$ 0.4968 | \$ 207.00 |
66 | | Amazon API Gateway | Per 1M REST API requests | 1 | \$ 3.50 | \$ 3.50 |
67 | | AWS Lambda | Per GB-second | 12,500 | \$ 0.0000166667 | \$ 0.21 |
68 | | AWS Lambda | Per 1M requests | 1 | \$ 0.20 | \$ 0.20 |
69 | | Amazon SNS | Per 1M requests | 2 | \$ 0.50 | \$ 0.50 |
70 | | Amazon SNS | Data transfer per GB | 7.62\** | \$ 0.09 | \$ 0.68 |
71 | | Amazon SQS | Per 1M requests | 2 | \$ 0.40 | \$ 0.80 |
72 | | Amazon S3 | Per 1K PUT requests | 2,000 | \$ 0.005 | \$ 10.00 |
73 | | Amazon S3 | Per GB per month | 143.05\*** | \$ 0.023 | \$ 3.29 |
74 | | **Total, 1M images** | | | | **\$226.18** |
75 |
76 | The fixed costs unrelated to the number of images, with the main services and their pricing listed below (per month):
77 |
78 | | **AWS Service** | **Billing Dimension** | **Quantity per Month** | **Unit Price \[USD\]** | **Total \[USD\]** |
79 | |-----------|------------|------------|------------|------------|
80 | | Amazon EKS | Cluster | 1 | \$ 72.00 | \$ 72.00 |
81 | | Amazon EC2 | m5.large instance, On-Demand instance per hour | 1440 | \$ 0.0960 | \$ 138.24 |
82 | | **Total, month** | | | | **\$210.24** |
83 |
84 | - \* Calculated based on an average request duration of 1.5 seconds (1,000,000 images × 1.5 s ÷ 3,600 ≈ 416.67 instance-hours) and the average Spot instance pricing across all Availability Zones in the `us-west-2` (Oregon) Region from January 29, 2024, to April 28, 2024.
85 | - \*\* Calculated based on an average request size of 16 KB
86 | - \*\*\* Calculated based on an average image size of 150 KB, stored for 1 month.
87 |
88 | Please note that these are estimated costs for reference only. Actual costs may vary depending on the model you use, task parameters, current Spot instance pricing, and other factors.
89 |
90 | ## Deployment Documentation
91 |
92 | Please see detailed Implementation Guides here:
93 | - [English](https://aws-solutions-library-samples.github.io/ai-ml/asynchronous-image-generation-with-stable-diffusion-on-aws.html)
94 | - [Chinese 简体中文 ](https://aws-solutions-library-samples.github.io/ai-ml/asynchronous-image-generation-with-stable-diffusion-on-aws-zh.html)
95 |
96 | ## Security
97 |
98 | When you build systems on AWS infrastructure, security responsibilities are shared between you and AWS. This [shared responsibility model](https://aws.amazon.com/compliance/shared-responsibility-model/) reduces your operational burden because AWS operates, manages, and
99 | controls the components, including host operating systems, the virtualization layer, and the physical security of the facilities in
100 | which the services operate. For more information about AWS security, visit [AWS Cloud Security](http://aws.amazon.com/security/).
101 |
102 | For potential security issues, see [CONTRIBUTING](CONTRIBUTING.md#security-issue-notifications) for more information.
103 |
104 | ## License
105 |
106 | This library is licensed under the MIT-0 License. See the [LICENSE](LICENSE) file.
107 |
--------------------------------------------------------------------------------
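A note on steps 6-8 of the architecture above: the queue agent's receive → infer → upload → notify loop is implemented in `src/backend/queue_agent/src/main.py`. The following is a minimal, hypothetical sketch of that loop in Python with boto3 — the queue URL, bucket, topic ARN, runtime endpoint, and message field names are placeholders, not the names the stack actually provisions.

```python
# Hypothetical sketch of the queue-agent loop (architecture steps 6-8).
# All names below are placeholders; the real implementation lives in
# src/backend/queue_agent/src/main.py.
import base64
import json

import boto3
import requests

sqs = boto3.client("sqs")
s3 = boto3.client("s3")
sns = boto3.client("sns")

QUEUE_URL = "https://sqs.us-west-2.amazonaws.com/111122223333/sd-processing-queue"  # placeholder
OUTPUT_BUCKET = "sd-output-bucket"                                                  # placeholder
TOPIC_ARN = "arn:aws:sns:us-west-2:111122223333:sd-notification"                    # placeholder
RUNTIME_API = "http://localhost:8080/sdapi/v1/txt2img"                              # placeholder

while True:
    # Step 6: long-poll the model processing queue.
    resp = sqs.receive_message(QueueUrl=QUEUE_URL, MaxNumberOfMessages=1, WaitTimeSeconds=20)
    for msg in resp.get("Messages", []):
        task = json.loads(msg["Body"])
        # Step 7: call the SD runtime API, decode the image, save it to S3.
        result = requests.post(RUNTIME_API, json=task["payload"]).json()
        image = base64.b64decode(result["images"][0])
        key = f"output/{task['id_task']}.png"
        s3.put_object(Bucket=OUTPUT_BUCKET, Key=key, Body=image)
        # Step 8: notify the designated SNS topic, then delete the message.
        sns.publish(TopicArn=TOPIC_ARN, Message=json.dumps(
            {"id_task": task["id_task"], "output_location": f"s3://{OUTPUT_BUCKET}/{key}"}))
        sqs.delete_message(QueueUrl=QUEUE_URL, ReceiptHandle=msg["ReceiptHandle"])
```

The real agent additionally handles errors, retries, and runtime-specific request formats (see `runtimes/sdwebui.py` and `runtimes/comfyui.py`).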
/bin/stable-difussion-on-eks.ts:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env node
2 | import { App, Aspects } from 'aws-cdk-lib';
3 | import DataPlaneStack from "../lib/dataPlane";
4 | import { parse } from 'yaml'
5 | import * as fs from 'fs'
6 | import { validateConfig } from '../lib/utils/validateConfig';
7 |
8 | const app = new App();
9 |
10 | const env = {
11 | account: process.env.CDK_DEFAULT_ACCOUNT,
12 | region: process.env.CDK_DEFAULT_REGION,
13 | }
14 |
15 | let filename: string
16 |
17 | if ("CDK_CONFIG_PATH" in process.env) {
18 | filename = process.env.CDK_CONFIG_PATH as string
19 | } else {
20 | filename = 'config.yaml'
21 | }
22 |
23 | const file = fs.readFileSync(filename, 'utf8')
24 | const props = parse(file)
25 |
26 | if (validateConfig(props)) {
27 | const dataPlaneStack = new DataPlaneStack(app, props.stackName, props, {
28 | env: env,
29 | description: "Guidance for Asynchronous Image Generation with Stable Diffusion on AWS (SO9306)"
30 | });
31 | } else {
32 | console.error("Deployment failed due to failed config validation. Please check your config file and try again.")
33 | }
--------------------------------------------------------------------------------
/cdk.json:
--------------------------------------------------------------------------------
1 | {
2 | "app": "npx ts-node --prefer-ts-exts bin/stable-difussion-on-eks.ts",
3 | "watch": {
4 | "include": [
5 | "**"
6 | ],
7 | "exclude": [
8 | "README.md",
9 | "cdk*.json",
10 | "**/*.d.ts",
11 | "**/*.js",
12 | "tsconfig.json",
13 | "package*.json",
14 | "yarn.lock",
15 | "node_modules",
16 | "test"
17 | ]
18 | },
19 | "context": {
20 | "@aws-cdk/aws-lambda:recognizeLayerVersion": true,
21 | "@aws-cdk/core:checkSecretUsage": true,
22 | "@aws-cdk/core:target-partitions": [
23 | "aws",
24 | "aws-cn"
25 | ],
26 | "@aws-cdk-containers/ecs-service-extensions:enableDefaultLogDriver": true,
27 | "@aws-cdk/aws-ec2:uniqueImdsv2TemplateName": true,
28 | "@aws-cdk/aws-ecs:arnFormatIncludesClusterName": true,
29 | "@aws-cdk/aws-iam:minimizePolicies": true,
30 | "@aws-cdk/core:validateSnapshotRemovalPolicy": true,
31 | "@aws-cdk/aws-codepipeline:crossAccountKeyAliasStackSafeResourceName": true,
32 | "@aws-cdk/aws-s3:createDefaultLoggingPolicy": true,
33 | "@aws-cdk/aws-sns-subscriptions:restrictSqsDescryption": true,
34 | "@aws-cdk/aws-apigateway:disableCloudWatchRole": true,
35 | "@aws-cdk/core:enablePartitionLiterals": true,
36 | "@aws-cdk/aws-events:eventsTargetQueueSameAccount": true,
37 | "@aws-cdk/aws-iam:standardizedServicePrincipals": true,
38 | "@aws-cdk/aws-ecs:disableExplicitDeploymentControllerForCircuitBreaker": true,
39 | "@aws-cdk/aws-iam:importedRoleStackSafeDefaultPolicyName": true,
40 | "@aws-cdk/aws-s3:serverAccessLogsUseBucketPolicy": true,
41 | "@aws-cdk/aws-route53-patters:useCertificate": true,
42 | "@aws-cdk/customresources:installLatestAwsSdkDefault": false,
43 | "@aws-cdk/aws-rds:databaseProxyUniqueResourceName": true,
44 | "@aws-cdk/aws-codedeploy:removeAlarmsFromDeploymentGroup": true,
45 | "@aws-cdk/aws-apigateway:authorizerChangeDeploymentLogicalId": true,
46 | "@aws-cdk/aws-ec2:launchTemplateDefaultUserData": true,
47 | "@aws-cdk/aws-secretsmanager:useAttachedSecretResourcePolicyForSecretTargetAttachments": true,
48 | "@aws-cdk/aws-redshift:columnId": true,
49 | "@aws-cdk/aws-stepfunctions-tasks:enableEmrServicePolicyV2": true,
50 | "@aws-cdk/aws-ec2:restrictDefaultSecurityGroup": true,
51 | "@aws-cdk/aws-apigateway:requestValidatorUniqueId": true,
52 | "@aws-cdk/aws-kms:aliasNameRef": true,
53 | "@aws-cdk/aws-autoscaling:generateLaunchTemplateInsteadOfLaunchConfig": true,
54 | "@aws-cdk/core:includePrefixInUniqueNameGeneration": true,
55 | "@aws-cdk/aws-opensearchservice:enableOpensearchMultiAzWithStandby": true
56 | },
57 | "requireApproval": "never",
58 | "rollback": false
59 | }
60 |
--------------------------------------------------------------------------------
/config.schema.yaml:
--------------------------------------------------------------------------------
1 | stackName: str()
2 | modelBucketArn: str(starts_with="arn")
3 | APIGW:
4 | stageName: str()
5 | throttle:
6 | rateLimit: int()
7 | burstLimit: int()
8 | modelsRuntime: list(include('runtime'), min=1)
9 | ---
10 | runtime:
11 | name: str()
12 | namespace: str()
13 | modelFilename: str(required=False)
14 | dynamicModel: bool(required=False)
15 | type: enum('sdwebui', 'comfyui')
16 | extraValues: include('extraValues')
17 | ---
18 | ebs:
19 | volumeSize: str(required=False)
20 | volumeType: str(required=False)
21 | deleteOnTermination: bool(required=False)
22 | iops: int(min=3000, max=16000, required=False)
23 | throughput: int(min=125, max=1000, required=False)
24 | snapshotID: str(starts_with="snap", required=False)
25 | ---
26 | extraValues:
27 | karpenter: include('karpenter', required=False)
28 | runtime: include('runtimeValues', required=False)
29 | ---
30 | karpenter:
31 | provisioner: include('provisioner', required=False)
32 | nodeTemplate: include('nodeTemplate', required=False)
33 | ---
34 | provisioner:
35 | labels: map(str(), str(), required=False)
36 | capacityType:
37 | onDemand: bool(required=False)
38 | spot: bool(required=False)
39 | instanceType: list(str(), required=False)
40 | extraRequirements: list(any(), required=False)
41 | extraTaints: list(any(), required=False)
42 | resourceLimits: include('resources', required=False)
43 | consolidation: bool(required=False)
44 | disruption: include('disruption', required=False)
45 | ---
46 | disruption:
47 | consolidateAfter: str(required=False)
48 | expireAfter: str(required=False)
49 | ---
50 | nodeTemplate:
51 | securityGroupSelector: map(str(), str(), required=False)
52 | subnetSelector: map(str(), str(), required=False)
53 | tags: map(str(), str(), required=False)
54 | amiFamily: enum('Bottlerocket', required=False)
55 | osVolume: include('ebs', required=False)
56 | dataVolume: include('ebs', required=False)
57 | userData: str(required=False)
58 | ---
59 | runtimeValues:
60 | labels: map(str(), str(), required=False)
61 | annotations: map(str(), str(), required=False)
62 | scaling: include('scaling', required=False)
63 | inferenceApi: include('inferenceApi', required=False)
64 | queueAgent: include('queueAgent', required=False)
65 | ---
66 | scaling:
67 | enabled: bool(required=False)
68 | queueLength: int(min=1, required=False)
69 | cooldownPeriod: int(min=1, required=False)
70 | maxReplicaCount: int(min=0, required=False)
71 | minReplicaCount: int(min=0, required=False)
72 | pollingInterval: int(required=False)
73 | scaleOnInFlight: bool(required=False)
74 | extraHPAConfig: any(required=False)
75 | ---
76 | image:
77 | repository: str()
78 | tag: str()
79 | ---
80 | resources:
81 | nvidia.com/gpu: str(required=False)
82 | cpu: str(required=False)
83 | memory: str(required=False)
84 | ---
85 | inferenceApi:
86 | image: include('image', required=False)
87 | modelMountPath: str(required=False)
88 | commandArguments: str(required=False)
89 | extraEnv: map(str(), str(), required=False)
90 | imagePullPolicy: enum('Always', 'IfNotPresent', 'Never')
91 | resources:
92 | limits: include('resources', required=False)
93 | requests: include('resources', required=False)
94 | ---
95 | queueAgent:
96 | image: include('image', required=False)
97 | extraEnv: map(str(), str())
98 | imagePullPolicy: enum('Always', 'IfNotPresent', 'Never')
99 | resources:
100 | limits: include('resources', required=False)
101 | requests: include('resources', required=False)
102 | XRay:
103 | enabled: bool(required=False)
104 |
--------------------------------------------------------------------------------
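`config.schema.yaml` above is written in [Yamale](https://github.com/23andMe/Yamale) schema syntax (`install-tools.sh` pip-installs `yamale`). The CDK app performs the validation through `lib/utils/validateConfig.ts`; purely as a standalone illustration, a minimal Python sketch of the same check:

```python
# Minimal sketch: validate config.yaml against config.schema.yaml with yamale.
# Assumes the repository root as the working directory.
import yamale

schema = yamale.make_schema("config.schema.yaml")
data = yamale.make_data("config.yaml")

try:
    yamale.validate(schema, data)
    print("config.yaml is valid")
except yamale.YamaleError as e:
    for result in e.results:
        for error in result.errors:
            print(f"{result.data}: {error}")
```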
/config.yaml:
--------------------------------------------------------------------------------
1 | stackName: sdoneks
2 | modelBucketArn: arn:aws:s3:::dummy-bucket
3 | APIGW:
4 | stageName: dev
5 | throttle:
6 | rateLimit: 30
7 | burstLimit: 50
8 | modelsRuntime:
9 | - name: sdruntime
10 | namespace: default
11 | modelFilename: v1-5-pruned-emaonly.safetensors
12 | dynamicModel: false
13 | type: sdwebui
14 | extraValues:
15 | karpenter:
16 | nodeTemplate:
17 | amiFamily: Bottlerocket
18 | # dataVolume:
19 | # snapshotID: snap-0123456789
20 | provisioner:
21 | instanceType:
22 | - g6.xlarge
23 | - g6.2xlarge
24 | capacityType:
25 | onDemand: true
26 | spot: true
27 | runtime:
28 | scaling:
29 | queueLength: 10
30 | minReplicaCount: 0
31 | cooldownPeriod: 300
32 |
--------------------------------------------------------------------------------
/deploy/README.md:
--------------------------------------------------------------------------------
1 | # One-click deployment script
2 |
3 | This script works as a quick start for this solution. It will:
4 |
5 | * Install required tools
6 | * Download the SDXL-Turbo model from HuggingFace and upload it to an S3 bucket
7 | * Generate an EBS snapshot with prefetched container images
8 | * Generate a sample config file
9 | * Deploy the SD on EKS solution
10 |
11 | ## Usage
12 |
13 | ```bash
14 | cd deploy
15 | ./deploy.sh
16 | ```
17 |
18 | ## Test after deploy
19 |
20 | This script sends text-to-image and image-to-image requests to the SD on EKS endpoint.
21 |
22 | ```bash
23 | cd ../test
24 | ./run.sh
25 | ```
--------------------------------------------------------------------------------
/deploy/config.yaml.template:
--------------------------------------------------------------------------------
1 | # Auto generated config file
2 | # Run cdk deploy to deploy SD on EKS stack
3 |
4 | stackName: ${STACK_NAME}
5 | modelBucketArn: arn:aws:s3:::${MODEL_BUCKET}
6 | APIGW:
7 | stageName: dev
8 | throttle:
9 | rateLimit: 30
10 | burstLimit: 50
11 | modelsRuntime:
12 | - name: ${RUNTIME_NAME}
13 | namespace: "default"
14 | modelFilename: "sd_xl_turbo_1.0.safetensors"
15 | dynamicModel: false
16 | type: ${RUNTIME_TYPE}
17 | extraValues:
18 | karpenter:
19 | nodeTemplate:
20 | amiFamily: Bottlerocket
21 | dataVolume:
22 | snapshotID: ${SNAPSHOT_ID}
23 | provisioner:
24 | instanceType:
25 | - "g6.2xlarge"
26 | - "g5.2xlarge"
27 | capacityType:
28 | onDemand: true
29 | spot: true
30 | runtime:
31 | scaling:
32 | queueLength: 10
33 | minReplicaCount: 0
34 | maxReplicaCount: 5
35 | cooldownPeriod: 300
--------------------------------------------------------------------------------
/deploy/deploy.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | set -e
4 |
5 | SHORTOPTS="h,n:,R:,d,b:,s:,r:,t:,T"
6 | LONGOPTS="help,stack-name:,region:,dry-run,bucket:,snapshot:,runtime-name:,runtime-type:,skip-tools"
7 | ARGS=$(getopt --options $SHORTOPTS --longoptions $LONGOPTS -- "$@" )
8 |
9 | eval set -- "$ARGS"
10 | while true;
11 | do
12 | case $1 in
13 | -h|--help)
14 | printf "Usage: deploy.sh [options] \n"
15 | printf "Options: \n"
16 | printf " -h, --help Print this help message \n"
17 | printf " -T, --skip-tools Skip tools installation \n"
18 | printf " -n, --stack-name Name of the stack to be created (Default: sdoneks) \n"
19 | printf " -R, --region AWS region to be used \n"
20 | printf " -d, --dry-run Don't deploy the stack. You can manually deploy the stack using generated config file.\n"
21 | printf " -b, --bucket S3 bucket name to store the model. If not specified, S3 bucket will be created and SD 1.5 model will be placed. \n"
22 | printf " -s, --snapshot EBS snapshot ID with container image. If not specified, EBS snapshot will be automatically generated. \n"
23 | printf " -r, --runtime-name Runtime name. (Default: sdruntime) \n"
24 | printf " -t, --runtime-type Runtime type. Only 'sdwebui' and 'comfyui' is accepted. (Default: sdwebui) \n"
25 | exit 0
26 | ;;
27 | -n|--stack-name)
28 | STACK_NAME=$2
29 | shift 2
30 | ;;
31 | -T|--skip-tools)
32 | INSTALL_TOOLS=false
33 | shift
34 | ;;
35 | -R|--region)
36 | AWS_DEFAULT_REGION=$2
37 | shift 2
38 | ;;
39 | -d|--dry-run)
40 | DEPLOY=false
41 | shift
42 | ;;
43 | -b|--bucket)
44 | MODEL_BUCKET=$2
45 | shift 2
46 | ;;
47 | -s|--snapshot)
48 | SNAPSHOT_ID=$2
49 | shift 2
50 | ;;
51 | -r|--runtime-name)
52 | RUNTIME_NAME=$2
53 | shift 2
54 | ;;
55 | -t|--runtime-type)
56 | RUNTIME_TYPE=$2
57 | if [[ "$RUNTIME_TYPE" != "sdwebui" && "$RUNTIME_TYPE" != "comfyui" ]] ; then
58 | printf "Runtime type should be 'sdwebui' or 'comfyui'. \n"
59 | exit 1
60 | fi
61 | shift 2
62 | ;;
63 | --)
64 | shift
65 | break
66 | ;;
67 | ?)
68 | shift
69 | printf "invalid parameter"
70 | exit 1
71 | ;;
72 | esac
73 | done
74 |
75 |
76 | SCRIPTPATH=$(realpath $(dirname "$0"))
77 | export AWS_DEFAULT_REGION=${AWS_DEFAULT_REGION:-$(aws ec2 describe-availability-zones --output text --query 'AvailabilityZones[0].[RegionName]')}
78 | declare -l STACK_NAME=${STACK_NAME:-"sdoneks"}
79 | RUNTIME_NAME=${RUNTIME_NAME:-"sdruntime"}
80 | declare -l RUNTIME_TYPE=${RUNTIME_TYPE:-"sdwebui"}
81 | INSTALL_TOOLS=${INSTALL_TOOLS:-true}
82 | DEPLOY=${DEPLOY:-true}
83 | SDWEBUI_IMAGE=public.ecr.aws/bingjiao/sd-on-eks/sdwebui:latest
84 | COMFYUI_IMAGE=public.ecr.aws/bingjiao/sd-on-eks/comfyui:latest
85 | QUEUE_AGENT_IMAGE=public.ecr.aws/bingjiao/sd-on-eks/queue-agent:latest
86 |
87 | # Step 1: Install tools
88 |
89 | printf "Step 1: Install tools... \n"
90 | if [ ${INSTALL_TOOLS} = true ] ; then
91 | "${SCRIPTPATH}"/install-tools.sh
92 | fi
93 |
94 | # Step 2: Create S3 bucket and upload model
95 |
96 | printf "Step 2: Create S3 bucket and upload SD 1.5 model... \n"
97 | if [ -z "${MODEL_BUCKET}" ] ; then
98 | MODEL_BUCKET="${STACK_NAME}"-model-bucket-$(echo ${RANDOM} | md5sum | head -c 4)
99 | aws s3 mb "s3://${MODEL_BUCKET}" --region "${AWS_DEFAULT_REGION}"
100 | "${SCRIPTPATH}"/upload-model.sh "${MODEL_BUCKET}"
101 | else
102 | printf "Existing bucket detected, skipping... \n"
103 | fi
104 |
105 | # Step 3: Create EBS Snapshot
106 |
107 | printf "Step 3: Creating EBS snapshot for faster launching...(This step will last for 15-30 min) \n"
108 | if [ -z "$SNAPSHOT_ID" ]; then
109 | cd "${SCRIPTPATH}"/..
110 | git submodule update --init --recursive
111 | SNAPSHOT_ID=$(utils/bottlerocket-images-cache/snapshot.sh -s 100 -r "${AWS_DEFAULT_REGION}" -q ${SDWEBUI_IMAGE},${COMFYUI_IMAGE},${QUEUE_AGENT_IMAGE})
112 | else
113 | printf "Existing snapshot ID detected, skipping... \n"
114 | fi
115 |
116 | # Step 4: Deploy
117 |
118 | printf "Step 4: Start deploy... \n"
119 | aws iam create-service-linked-role --aws-service-name spot.amazonaws.com >/dev/null 2>&1 || true
120 | cd "${SCRIPTPATH}"/..
121 | npm install
122 |
123 | template="$(cat deploy/config.yaml.template)"
124 | eval "echo \"${template}\"" > config.yaml
125 | cdk bootstrap
126 | if [ ${DEPLOY} = true ] ; then
127 | CDK_DEFAULT_REGION=${AWS_DEFAULT_REGION} cdk deploy --no-rollback --require-approval never
128 | printf "Deploy complete. \n"
129 | else
130 | printf "Please revise config.yaml and run 'cdk deploy --no-rollback --require-approval never' to deploy. \n"
131 | fi
132 |
--------------------------------------------------------------------------------
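The `eval "echo \"${template}\"" > config.yaml` line in step 4 renders `deploy/config.yaml.template` by letting the shell expand the `${STACK_NAME}`, `${MODEL_BUCKET}`, `${SNAPSHOT_ID}`, `${RUNTIME_NAME}`, and `${RUNTIME_TYPE}` placeholders collected earlier in the script. A Python sketch of the equivalent substitution step (illustration only, not part of the repo):

```python
# Sketch of the config rendering done by `eval "echo \"${template}\"" > config.yaml`:
# substitute ${VAR} placeholders in the template with environment variables.
import os
from string import Template

with open("deploy/config.yaml.template") as f:
    template = Template(f.read())

# safe_substitute leaves unknown placeholders intact instead of raising.
rendered = template.safe_substitute(os.environ)

with open("config.yaml", "w") as f:
    f.write(rendered)
```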
/deploy/install-tools.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | set -e
4 |
5 | kubectl_version='1.29.0'
6 | helm_version='3.10.1'
7 | yq_version='4.30.4'
8 | s5cmd_version='2.2.2'
9 | node_version='22.12.0'
10 | cdk_version='2.162.1'
11 |
12 | download () {
13 | url=$1
14 | out_file=$2
15 | curl --location --show-error --silent --output "$out_file" "$url"
16 | }
17 | cur_dir=$(pwd)
18 | tmp_dir=$(mktemp -d)
19 |
20 | cd "$tmp_dir"
21 |
22 | if which pacapt > /dev/null
23 | then
24 | printf "\n"
25 | else
26 | sudo wget -O /usr/bin/pacapt https://github.com/icy/pacapt/raw/ng/pacapt
27 | sudo chmod 755 /usr/bin/pacapt
28 | fi
29 |
30 | sudo pacapt install --noconfirm jq git wget unzip openssl bash-completion python3 python3-pip
31 |
32 | pip3 install -q yamale
33 |
34 | # kubectl
35 | if which kubectl > /dev/null
36 | then
37 | printf "kubectl is installed, skipping...\n"
38 | else
39 | printf "Installing kubectl...\n"
40 | download "https://dl.k8s.io/release/v$kubectl_version/bin/linux/amd64/kubectl" "kubectl"
41 | chmod +x ./kubectl
42 | sudo mv ./kubectl /usr/bin
43 | fi
44 |
45 | # helm
46 | if which helm > /dev/null
47 | then
48 | printf "helm is installed, skipping...\n"
49 | else
50 | printf "Installing helm...\n"
51 | download "https://get.helm.sh/helm-v$helm_version-linux-amd64.tar.gz" "helm.tar.gz"
52 | tar zxf helm.tar.gz
53 | chmod +x linux-amd64/helm
54 | sudo mv ./linux-amd64/helm /usr/bin
55 | rm -rf linux-amd64/ helm.tar.gz
56 | fi
57 |
58 | # aws cli v2
59 | printf "Installing/Upgrading AWS CLI v2...\n"
60 | curl --location --show-error --silent "https://awscli.amazonaws.com/awscli-exe-linux-x86_64.zip" -o "awscliv2.zip"
61 | unzip -o -q awscliv2.zip -d /tmp
62 | sudo /tmp/aws/install --update
63 | rm -rf /tmp/aws awscliv2.zip
64 |
65 | # yq
66 | if which yq > /dev/null
67 | then
68 | printf "yq is installed, skipping...\n"
69 | else
70 | printf "Installing yq...\n"
71 | download "https://github.com/mikefarah/yq/releases/download/v${yq_version}/yq_linux_amd64" "yq"
72 | chmod +x ./yq
73 | sudo mv ./yq /usr/bin
74 | fi
75 |
76 | # s5cmd
77 | if which s5cmd > /dev/null
78 | then
79 | printf "s5cmd is installed, skipping...\n"
80 | else
81 | printf "Installing s5cmd...\n"
82 | download "https://github.com/peak/s5cmd/releases/download/v${s5cmd_version}/s5cmd_${s5cmd_version}_Linux-64bit.tar.gz" "s5cmd.tar.gz"
83 | tar zxf s5cmd.tar.gz
84 | chmod +x ./s5cmd
85 | sudo mv ./s5cmd /usr/bin
86 | rm -rf s5cmd.tar.gz
87 | fi
88 |
89 | # Node.js
90 | if which node > /dev/null
91 | then
92 | printf "Node.js is installed, skipping...\n"
93 | else
94 | printf "Installing Node.js...\n"
95 | download "https://nodejs.org/dist/v{$node_version}/node-v{$node_version}-linux-x64.tar.xz" "node.tar.xz"
96 | sudo mkdir -p /usr/local/lib/nodejs
97 | sudo tar -xJf node.tar.xz -C /usr/local/lib/nodejs
98 | export PATH="/usr/local/lib/nodejs/node-v{$node_version}-linux-x64/bin:$PATH"
99 | printf "export PATH=\"/usr/local/lib/nodejs/node-v{$node_version}-linux-x64/bin:\$PATH\"" >> ~/.bash_profile
100 | source ~/.bash_profile
101 | fi
102 |
103 | # CDK CLI
104 | if which cdk > /dev/null
105 | then
106 | printf "CDK CLI is installed, skipping...\n"
107 | else
108 | printf "Installing AWS CDK CLI and bootstraping CDK environment...\n"
109 | sudo npm install -g aws-cdk@$cdk_version
110 | fi
111 |
112 | printf "Tools install complete. \n"
113 |
114 | cd $cur_dir
115 | rm -rf $tmp_dir
116 |
--------------------------------------------------------------------------------
/deploy/update.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | set -e
4 |
5 | SCRIPTPATH=$(realpath "$(dirname "$0")")
6 |
7 | printf "Updating stack using current configuration... \n"
8 | cd "${SCRIPTPATH}"/..
9 | npm install
10 | cdk deploy --no-rollback --require-approval never
11 | printf "Deploy complete. \n"
--------------------------------------------------------------------------------
/deploy/upload-model.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | set -e
4 |
5 | MODEL_BUCKET="$1"
6 |
7 | AWS_DEFAULT_REGION=${AWS_DEFAULT_REGION:-$(aws ec2 describe-availability-zones --output text --query 'AvailabilityZones[0].[RegionName]')}
8 |
9 | MODEL_URL="https://huggingface.co/stabilityai/sdxl-turbo/resolve/main/sd_xl_turbo_1.0.safetensors"
10 | MODEL_NAME="sd_xl_turbo_1.0.safetensors"
11 |
12 | printf "Transport SDXL-Turbo model from hugging face to S3 bucket...\n"
13 | curl -L "$MODEL_URL" | aws s3 cp - s3://${MODEL_BUCKET}/Stable-diffusion/${MODEL_NAME}
14 |
15 | printf "Model uploaded to s3://${MODEL_BUCKET}/Stable-diffusion/${MODEL_NAME}\n"
16 |
--------------------------------------------------------------------------------
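`upload-model.sh` streams the model straight from Hugging Face into S3 (curl piped into `aws s3 cp -`), so the multi-GB file is never staged on local disk. A boto3 sketch of the same streaming pattern — the bucket name is a placeholder:

```python
# Sketch of the streaming upload done by upload-model.sh: pipe the model from
# Hugging Face into S3 without writing it to local disk. Bucket is a placeholder.
import boto3
import requests

MODEL_URL = "https://huggingface.co/stabilityai/sdxl-turbo/resolve/main/sd_xl_turbo_1.0.safetensors"
MODEL_BUCKET = "my-model-bucket"  # placeholder

s3 = boto3.client("s3")
with requests.get(MODEL_URL, stream=True) as resp:
    resp.raise_for_status()
    # upload_fileobj reads the response body in chunks and performs a
    # multipart upload, so the file never has to fit in memory.
    s3.upload_fileobj(resp.raw, MODEL_BUCKET, f"Stable-diffusion/{MODEL_URL.rsplit('/', 1)[-1]}")
```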
/docs/Gemfile:
--------------------------------------------------------------------------------
1 | source 'https://rubygems.org'
2 |
3 | gem "jekyll", "~> 4.3.3" # installed by `gem jekyll`
4 | # gem "webrick" # required when using Ruby >= 3 and Jekyll <= 4.2.2
5 |
6 | gem "just-the-docs", "0.8.2" # pinned to the current release
7 | # gem "just-the-docs" # always download the latest release
--------------------------------------------------------------------------------
/docs/Gemfile.lock:
--------------------------------------------------------------------------------
1 | GEM
2 | remote: https://rubygems.org/
3 | specs:
4 | addressable (2.8.7)
5 | public_suffix (>= 2.0.2, < 7.0)
6 | colorator (1.1.0)
7 | concurrent-ruby (1.3.4)
8 | em-websocket (0.5.3)
9 | eventmachine (>= 0.12.9)
10 | http_parser.rb (~> 0)
11 | eventmachine (1.2.7)
12 | ffi (1.17.0)
13 | forwardable-extended (2.6.0)
14 | google-protobuf (3.25.5)
15 | google-protobuf (3.25.5-arm64-darwin)
16 | google-protobuf (3.25.5-x86_64-darwin)
17 | google-protobuf (3.25.5-x86_64-linux)
18 | http_parser.rb (0.8.0)
19 | i18n (1.14.6)
20 | concurrent-ruby (~> 1.0)
21 | jekyll (4.3.4)
22 | addressable (~> 2.4)
23 | colorator (~> 1.0)
24 | em-websocket (~> 0.5)
25 | i18n (~> 1.0)
26 | jekyll-sass-converter (>= 2.0, < 4.0)
27 | jekyll-watch (~> 2.0)
28 | kramdown (~> 2.3, >= 2.3.1)
29 | kramdown-parser-gfm (~> 1.0)
30 | liquid (~> 4.0)
31 | mercenary (>= 0.3.6, < 0.5)
32 | pathutil (~> 0.9)
33 | rouge (>= 3.0, < 5.0)
34 | safe_yaml (~> 1.0)
35 | terminal-table (>= 1.8, < 4.0)
36 | webrick (~> 1.7)
37 | jekyll-include-cache (0.2.1)
38 | jekyll (>= 3.7, < 5.0)
39 | jekyll-sass-converter (3.0.0)
40 | sass-embedded (~> 1.54)
41 | jekyll-seo-tag (2.8.0)
42 | jekyll (>= 3.8, < 5.0)
43 | jekyll-watch (2.2.1)
44 | listen (~> 3.0)
45 | just-the-docs (0.8.2)
46 | jekyll (>= 3.8.5)
47 | jekyll-include-cache
48 | jekyll-seo-tag (>= 2.0)
49 | rake (>= 12.3.1)
50 | kramdown (2.4.0)
51 | rexml
52 | kramdown-parser-gfm (1.1.0)
53 | kramdown (~> 2.0)
54 | liquid (4.0.4)
55 | listen (3.9.0)
56 | rb-fsevent (~> 0.10, >= 0.10.3)
57 | rb-inotify (~> 0.9, >= 0.9.10)
58 | mercenary (0.4.0)
59 | pathutil (0.16.2)
60 | forwardable-extended (~> 2.6)
61 | public_suffix (6.0.1)
62 | rake (13.2.1)
63 | rb-fsevent (0.11.2)
64 | rb-inotify (0.11.1)
65 | ffi (~> 1.0)
66 | rexml (3.3.9)
67 | rouge (4.4.0)
68 | safe_yaml (1.0.5)
69 | sass-embedded (1.69.5)
70 | google-protobuf (~> 3.23)
71 | rake (>= 13.0.0)
72 | terminal-table (3.0.2)
73 | unicode-display_width (>= 1.1.1, < 3)
74 | unicode-display_width (2.6.0)
75 | webrick (1.9.0)
76 |
77 | PLATFORMS
78 | arm64-darwin
79 | ruby
80 | x86_64-darwin
81 | x86_64-linux
82 |
83 | DEPENDENCIES
84 | jekyll (~> 4.3.3)
85 | just-the-docs (= 0.8.2)
86 |
87 | BUNDLED WITH
88 | 2.3.5
89 |
--------------------------------------------------------------------------------
/docs/README.md:
--------------------------------------------------------------------------------
1 | # Docs for Guidance for Asynchronous Inference with Stable Diffusion on AWS
2 |
3 | This folder contains the documentation code for the guidance.
4 |
5 | ## Preview in your environment
6 |
7 | The docs use [Jekyll](https://jekyllrb.com/) with the [Just the Docs](https://just-the-docs.com/) theme.
8 | 
9 | Assuming [Jekyll](https://jekyllrb.com/) and [Bundler](https://bundler.io/) are installed on your computer, you can preview the docs as follows:
10 | 
11 | 1. Run `bundle install`.
12 | 2. Change your working directory to the language variant (`zh` for Simplified Chinese, `en` for English).
13 | 3. Run `bundle exec jekyll serve` to build and preview the docs at `localhost:4000`, as shown below.
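14 | 
15 | For example, from the `docs` folder:
16 | 
17 | ```sh
18 | bundle install
19 | cd en   # or `zh` for the Simplified Chinese variant
20 | bundle exec jekyll serve   # serves the site at http://localhost:4000
21 | ```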
--------------------------------------------------------------------------------
/docs/api/v1alpha1.yaml:
--------------------------------------------------------------------------------
1 | openapi: 3.0.3
2 | info:
3 | title: SD on EKS
4 | version: v1alpha1
5 | x-logo:
6 | url: ""
7 | servers:
8 | - url: http://localhost:8080/v1alpha1
9 | description: ""
10 | paths:
11 | /:
12 | summary: Submit task to SD on EKS runtimes
13 | post:
14 | requestBody:
15 | description: "Content of image generating task"
16 | content:
17 | application/json:
18 | schema:
19 | $ref: '#/components/schemas/Task'
20 | required: true
21 | operationId: SubmitTask
22 | responses:
23 | '200':
24 | description: Successful response, task is accepted
25 | content:
26 | application/json:
27 | schema:
28 | title: Sample
29 | type: object
30 | properties:
31 | id_task:
32 | type: string
33 | description: Task ID
34 | example: "abc"
35 | task:
36 | description: Type of task.
37 | example: "text-to-image"
38 | type: string
39 | sd_model_checkpoint:
40 | type: string
41 | description: Selected model
42 | example: "sd_xl_turbo_1.0.safetensors"
43 | output_location:
44 | type: string
45 | description: Location of output file
46 | example: "s3://sdoneksstack-outputs3bucket/abc"
47 | security:
48 | - apikey: []
49 | externalDocs:
50 | description: Usage introduction
51 | url: >-
52 | https://aws-samples.github.io/stable-diffusion-on-eks/zh/implementation-guide/usage/
53 | tags: []
54 | security: []
55 | components:
56 | schemas:
57 | Task:
58 | type: object
59 | properties:
60 | task:
61 | type: object
62 | properties:
63 | alwayson_scripts:
64 | type: object
65 | properties:
66 | task:
67 | description: Type of task. Should be "text-to-image" or "image-to-image" for SD Web UI
68 | enum:
69 | - text-to-image
70 | - image-to-image
71 | example: "text-to-image"
72 | type: string
73 | sd_model_checkpoint:
74 | description: Checkpoint used in the task
75 | example: "sd_xl_turbo_1.0.safetensors"
76 | type: string
77 | id_task:
78 | description: Task ID
79 |                   example: "abc"
80 | type: string
81 | uid:
82 | type: string
83 | description: User ID
84 | example: "abc"
85 | save_dir:
86 | type: string
87 |                   description: Directory in the S3 bucket; should be "output" only
88 | example: output
89 | prompt:
90 | type: string
91 | example: "a dog"
92 | steps:
93 | type: integer
94 | example: 16
95 | height:
96 | type: integer
97 | example: 512
98 | width:
99 | type: integer
100 | example: 512
101 | securitySchemes:
102 | apikey:
103 | type: apiKey
104 | description: API Key generated by AWS API Gateway
105 | name: x-api-key
106 | in: header
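107 | 
108 | # Example invocation (hypothetical endpoint and API key; the real values come from the
109 | # stack outputs, and a full request body ships under test/v1alpha1/ in this repository):
110 | #   curl -X POST "https://<api-id>.execute-api.<region>.amazonaws.com/prod/v1alpha1" \
111 | #     -H "x-api-key: <api-key>" -H "Content-Type: application/json" \
112 | #     -d @test/v1alpha1/t2i.json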
--------------------------------------------------------------------------------
/docs/api/v1alpha2.yaml:
--------------------------------------------------------------------------------
1 | openapi: 3.0.3
2 | info:
3 | title: SD on EKS
4 | version: v1alpha2
5 | x-logo:
6 | url: ""
7 | servers:
8 | - url: http://localhost:8080/
9 | description: ""
10 | paths:
11 | /task:
12 | summary: Submit task to SD on EKS runtimes
13 | post:
14 | requestBody:
15 | description: "Content of image generating task"
16 | content:
17 | application/json:
18 | schema:
19 | $ref: '#/components/schemas/Task'
20 | required: true
21 | operationId: SubmitTask
22 | responses:
23 | '200':
24 | description: Successful response, task is accepted
25 | content:
26 | application/json:
27 | schema:
28 | title: Sample
29 | type: object
30 | properties:
31 | id:
32 | type: string
33 | description: Task ID
34 | example: "abc"
35 | runtime:
36 | type: string
37 | description: Selected runtime
38 | example: "sdruntime1"
39 | output_location:
40 | type: string
41 | description: Location of output file
42 | example: "s3://sdoneksstack-outputs3bucket/abc"
43 | security:
44 | - apikey: []
45 | externalDocs:
46 | description: Usage introduction
47 | url: >-
48 | https://aws-samples.github.io/stable-diffusion-on-eks/zh/implementation-guide/usage/
49 | tags: []
50 | security: []
51 | components:
52 | schemas:
53 | Task:
54 | type: object
55 | properties:
56 | task:
57 | type: object
58 | properties:
59 | metadata:
60 | type: object
61 | properties:
62 | id:
63 | type: string
64 | description: Task ID
65 | example: "abc"
66 | runtime:
67 | type: string
68 |                   description: Runtime used in this task; must match the name of an installed runtime.
69 | example: sdruntime1
70 | tasktype:
71 | description: Type of task. Should be "text-to-image" or "image-to-image" for SD Web UI, and "pipeline" for ComfyUI
72 | enum:
73 | - text-to-image
74 | - image-to-image
75 | - pipeline
76 | example: "text-to-image"
77 | type: string
78 | prefix:
79 | type: string
80 |                   description: Prefix under which output files are stored
81 | example: "/output"
82 | context:
83 | type: object
84 |                   description: All content in context is passed through to the callback
85 | example: ""
86 | content:
87 | type: object
88 | example: ""
89 | securitySchemes:
90 | apikey:
91 | type: apiKey
92 | description: API Key generated by AWS API Gateway
93 | name: x-api-key
94 | in: header
95 |
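96 | # Example invocation (hypothetical endpoint and API key; the deployed API Gateway exposes
97 | # this schema under /v1alpha2, and sample bodies ship under test/v1alpha2/ in this repository):
98 | #   curl -X POST "https://<api-id>.execute-api.<region>.amazonaws.com/prod/v1alpha2" \
99 | #     -H "x-api-key: <api-key>" -H "Content-Type: application/json" \
100 | #     -d @test/v1alpha2/t2i.json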
--------------------------------------------------------------------------------
/docs/en/_config.yml:
--------------------------------------------------------------------------------
1 | theme: just-the-docs
2 |
3 | callouts:
4 | highlight:
5 | color: yellow
6 | note:
7 | color: purple
8 | new:
9 | color: green
10 | warning:
11 | color: red
--------------------------------------------------------------------------------
/docs/en/_includes/image.html:
--------------------------------------------------------------------------------
1 |
2 |
3 |
--------------------------------------------------------------------------------
/docs/en/async_img_sd_images/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aws-solutions-library-samples/guidance-for-asynchronous-inference-with-stable-diffusion-on-aws/efc4a0aef8367cb49944b3fbca1d72e70d27d457/docs/en/async_img_sd_images/.gitkeep
--------------------------------------------------------------------------------
/docs/en/async_img_sd_images/asynchronous-image-generation-with-stable-diffusion-on-aws.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aws-solutions-library-samples/guidance-for-asynchronous-inference-with-stable-diffusion-on-aws/efc4a0aef8367cb49944b3fbca1d72e70d27d457/docs/en/async_img_sd_images/asynchronous-image-generation-with-stable-diffusion-on-aws.png
--------------------------------------------------------------------------------
/docs/en/async_img_sd_sidebar.yml:
--------------------------------------------------------------------------------
1 | entries:
2 | - title: Guidance for Asynchronous Image Generation with Stable Diffusion on AWS
3 | folders:
4 | - title: Overview
5 | url: '#overview'
6 | output: web
7 |
8 | - title: Architecture Overview
9 | url: '#architecture-overview'
10 | output: web
11 |
12 | - title: Cost
13 | url: '#cost'
14 | output: web
15 |
16 | - title: Security
17 | url: '#security'
18 | output: web
19 |
20 | - title: Deploy the Guidance
21 | url: '#deploy-the-guidance'
22 | output: web
23 |
24 | - title: Usage Guide
25 | url: '#usage-guide'
26 | output: web
27 |
28 |     - title: Uninstall the Guidance
29 | url: '#uninstall-the-guidance'
30 | output: web
31 |
32 | - title: Related resources
33 | url: '#related-resources'
34 | output: web
35 |
36 | - title: Contributors
37 | url: '#contributors'
38 | output: web
39 |
40 | - title: Notices
41 | url: '#notices'
42 | output: web
43 |
--------------------------------------------------------------------------------
/docs/stable-diffusion-reference-architecture-updated.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aws-solutions-library-samples/guidance-for-asynchronous-inference-with-stable-diffusion-on-aws/efc4a0aef8367cb49944b3fbca1d72e70d27d457/docs/stable-diffusion-reference-architecture-updated.png
--------------------------------------------------------------------------------
/docs/zh/_config.yml:
--------------------------------------------------------------------------------
1 | theme: just-the-docs
2 |
3 | callouts:
4 | highlight:
5 | color: yellow
6 | note:
7 | color: purple
8 | new:
9 | color: green
10 | warning:
11 | color: red
--------------------------------------------------------------------------------
/docs/zh/_includes/image.html:
--------------------------------------------------------------------------------
1 |
2 |
3 |
--------------------------------------------------------------------------------
/docs/zh/async_img_sd_zh_images/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aws-solutions-library-samples/guidance-for-asynchronous-inference-with-stable-diffusion-on-aws/efc4a0aef8367cb49944b3fbca1d72e70d27d457/docs/zh/async_img_sd_zh_images/.gitkeep
--------------------------------------------------------------------------------
/docs/zh/async_img_sd_zh_images/stable-diffusion-reference-architecture-updated.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/aws-solutions-library-samples/guidance-for-asynchronous-inference-with-stable-diffusion-on-aws/efc4a0aef8367cb49944b3fbca1d72e70d27d457/docs/zh/async_img_sd_zh_images/stable-diffusion-reference-architecture-updated.png
--------------------------------------------------------------------------------
/docs/zh/async_img_sd_zh_sidebar.yml:
--------------------------------------------------------------------------------
1 | entries:
2 | - title: Guidance for Asynchronous Image Generation with Stable Diffusion on AWS
3 | folders:
4 | - title: 简介
5 | url: '#简介'
6 | output: web
7 |
8 | - title: 架构概览
9 | url: '#架构概览'
10 | output: web
11 |
12 | - title: 费用预估
13 | url: '#费用预估'
14 | output: web
15 |
16 | - title: 部署解决方案
17 | url: '#部署解决方案'
18 | output: web
19 |
20 | - title: 使用指南
21 | url: '#使用指南'
22 | output: web
23 |
24 | - title: 删除解决方案
25 | url: '#删除解决方案'
26 | output: web
27 |
--------------------------------------------------------------------------------
/lib/addons/dcgmExporter.ts:
--------------------------------------------------------------------------------
1 | import * as blueprints from '@aws-quickstart/eks-blueprints';
2 | import { Construct } from "constructs";
3 |
4 | export interface dcgmExporterAddOnProps extends blueprints.addons.HelmAddOnUserProps {
5 | }
6 |
7 | export const defaultProps: blueprints.addons.HelmAddOnProps & dcgmExporterAddOnProps = {
8 | chart: 'dcgm-exporter',
9 | name: 'dcgmExporterAddOn',
10 | namespace: 'kube-system',
11 | release: 'dcgm',
12 | version: '3.4.0',
13 | repository: 'https://nvidia.github.io/dcgm-exporter/helm-charts',
14 | values: {
15 | serviceMonitor: {
16 | enabled: false
17 | },
18 | affinity: {
19 | nodeAffinity: {
20 | requiredDuringSchedulingIgnoredDuringExecution: {
21 | nodeSelectorTerms: [{
22 | matchExpressions: [{
23 | key: "karpenter.k8s.aws/instance-gpu-count",
24 | operator: "Exists"
25 | }]
26 | }]
27 | }
28 | }
29 | },
30 | tolerations: [{
31 | key: "nvidia.com/gpu",
32 | operator: "Exists",
33 | effect: "NoSchedule"
34 | }, {
35 | key: "runtime",
36 | operator: "Exists",
37 | effect: "NoSchedule"
38 | }]
39 | }
40 | }
41 |
42 | export class dcgmExporterAddOn extends blueprints.addons.HelmAddOn {
43 |
44 | readonly options: dcgmExporterAddOnProps;
45 |
46 | constructor(props: dcgmExporterAddOnProps) {
47 | super({ ...defaultProps, ...props });
48 | this.options = this.props as dcgmExporterAddOnProps;
49 | }
50 |
51 |     deploy(clusterInfo: blueprints.ClusterInfo): Promise<Construct> {
52 | const cluster = clusterInfo.cluster;
53 |
54 | const chart = this.addHelmChart(clusterInfo, this.options.values, true, true);
55 | return Promise.resolve(chart);
56 | }
57 | }
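58 | 
59 | // Usage sketch (assumption: registered like the other add-ons in lib/dataPlane.ts; the
60 | // default props above need no overrides):
61 | //   blueprints.EksBlueprint.builder().addOns(new dcgmExporterAddOn({})).build(scope, id, props)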
--------------------------------------------------------------------------------
/lib/addons/ebsThroughputTuner.ts:
--------------------------------------------------------------------------------
1 | import { ClusterAddOn, ClusterInfo } from '@aws-quickstart/eks-blueprints';
2 | import { Construct } from "constructs";
3 | import * as cdk from 'aws-cdk-lib';
4 | import * as lambda from 'aws-cdk-lib/aws-lambda';
5 | import * as iam from "aws-cdk-lib/aws-iam"
6 | import * as path from 'path';
7 | import * as sfn from 'aws-cdk-lib/aws-stepfunctions';
8 | import * as tasks from 'aws-cdk-lib/aws-stepfunctions-tasks';
9 | import * as events from 'aws-cdk-lib/aws-events';
10 | import * as targets from 'aws-cdk-lib/aws-events-targets';
11 |
12 | export interface EbsThroughputTunerAddOnProps {
13 | duration: number;
14 | throughput: number;
15 | iops: number;
16 | }
17 |
18 | export class EbsThroughputTunerAddOn implements ClusterAddOn {
19 |
20 | readonly options: EbsThroughputTunerAddOnProps;
21 |
22 | constructor(props: EbsThroughputTunerAddOnProps) {
23 | this.options = props
24 | }
25 |
26 |     deploy(clusterInfo: ClusterInfo): Promise<Construct> {
27 | const cluster = clusterInfo.cluster;
28 |
29 | const lambdaTimeout: number = 300
30 |
31 | //EBS Throughput Modify lambda function
32 | const lambdaFunction = new lambda.Function(cluster.stack, 'EbsThroughputTunerLambda', {
33 | code: lambda.Code.fromAsset(path.join(__dirname, '../../src/tools/ebs_throughput_tuner')),
34 | handler: 'app.lambda_handler',
35 | runtime: lambda.Runtime.PYTHON_3_11,
36 | timeout: cdk.Duration.seconds(lambdaTimeout),
37 | environment: {
38 | "TARGET_EC2_TAG_KEY": "stack",
39 | "TARGET_EC2_TAG_VALUE": cdk.Aws.STACK_NAME,
40 | "THROUGHPUT_VALUE": this.options.throughput.toString(),
41 | "IOPS_VALUE": this.options.iops.toString()
42 | },
43 | });
44 |
45 |         lambdaFunction.role!.addManagedPolicy(
46 |             iam.ManagedPolicy.fromAwsManagedPolicyName(
47 |                 'AmazonEC2FullAccess',
48 |             ))
49 |
50 | // Step Functions definition
51 | const waitTask = new sfn.Wait(cluster.stack, 'Wait time', {
52 | time: sfn.WaitTime.duration(cdk.Duration.seconds(this.options.duration)),
53 | });
54 |
55 | const triggerTask = new tasks.LambdaInvoke(cluster.stack, 'Change throughput', {
56 | lambdaFunction: lambdaFunction
57 | }).addRetry({
58 | backoffRate: 2,
59 | maxAttempts: 3,
60 | interval: cdk.Duration.seconds(5)
61 | })
62 |
63 | const stateDefinition = waitTask
64 | .next(triggerTask)
65 |
66 | const stateMachine = new sfn.StateMachine(cluster.stack, 'EbsThroughputTunerStateMachine', {
67 | definitionBody: sfn.DefinitionBody.fromChainable(stateDefinition),
68 | timeout: cdk.Duration.seconds(this.options.duration + lambdaTimeout + 30),
69 | });
70 |
71 | lambdaFunction.grantInvoke(stateMachine)
72 |
73 | const rule = new events.Rule(cluster.stack, 'EbsThroughputTunerRule', {
74 | eventPattern: {
75 | detail: {
76 | 'state': events.Match.equalsIgnoreCase("running")
77 | },
78 | detailType: events.Match.equalsIgnoreCase('EC2 Instance State-change Notification'),
79 | source: ['aws.ec2'],
80 | }
81 | });
82 |
83 | rule.addTarget(new targets.SfnStateMachine(stateMachine))
84 |
85 | return Promise.resolve(rule);
86 | }
87 | }
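88 | 
89 | // Example (the values wired up in lib/dataPlane.ts): 300 s after a tagged instance enters
90 | // "running", set its volumes to 125 MiB/s and 3000 IOPS (the gp3 baseline):
91 | //   new EbsThroughputTunerAddOn({ duration: 300, throughput: 125, iops: 3000 })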
--------------------------------------------------------------------------------
/lib/addons/s3CSIDriver.ts:
--------------------------------------------------------------------------------
1 | import * as blueprints from '@aws-quickstart/eks-blueprints';
2 | import * as eks from 'aws-cdk-lib/aws-eks';
3 | import * as iam from 'aws-cdk-lib/aws-iam';
4 | import { Construct } from "constructs";
5 |
6 | export interface s3CSIDriverAddOnProps extends blueprints.addons.HelmAddOnUserProps {
7 | s3BucketArn: string;
8 | }
9 |
10 | export const defaultProps: blueprints.addons.HelmAddOnProps & s3CSIDriverAddOnProps = {
11 | chart: 'aws-mountpoint-s3-csi-driver',
12 | name: 's3CSIDriverAddOn',
13 | namespace: 'kube-system',
14 | release: 's3-csi-driver-release',
15 | version: 'v1.10.0',
16 | repository: 'https://awslabs.github.io/mountpoint-s3-csi-driver',
17 | s3BucketArn: ''
18 | }
19 |
20 | export class s3CSIDriverAddOn extends blueprints.addons.HelmAddOn {
21 |
22 | readonly options: s3CSIDriverAddOnProps;
23 |
24 | constructor(props: s3CSIDriverAddOnProps) {
25 | super({ ...defaultProps, ...props });
26 | this.options = this.props as s3CSIDriverAddOnProps;
27 | }
28 |
29 |     deploy(clusterInfo: blueprints.ClusterInfo): Promise<Construct> {
30 | const cluster = clusterInfo.cluster;
31 | const serviceAccount = cluster.addServiceAccount('s3-csi-driver-sa', {
32 | name: 's3-csi-driver-sa',
33 | namespace: this.options.namespace
34 | });
35 |
36 |         // new IAM policy to grant access to the S3 bucket
37 | // https://github.com/awslabs/mountpoint-s3/blob/main/doc/CONFIGURATION.md#iam-permissions
38 | const s3BucketPolicy = new iam.Policy(cluster, 's3-csi-driver-policy', {
39 | statements: [
40 | new iam.PolicyStatement({
41 | sid: 'MountpointFullBucketAccess',
42 | actions: [
43 | "s3:ListBucket"
44 | ],
45 | resources: [this.options.s3BucketArn]
46 | }),
47 | new iam.PolicyStatement({
48 | sid: 'MountpointFullObjectAccess',
49 | actions: [
50 | "s3:GetObject",
51 | "s3:PutObject",
52 | "s3:AbortMultipartUpload",
53 | "s3:DeleteObject"
54 | ],
55 | resources: [`${this.options.s3BucketArn}/*`]
56 | })]
57 | });
58 | serviceAccount.role.attachInlinePolicy(s3BucketPolicy);
59 |
60 | const chart = this.addHelmChart(clusterInfo, {
61 | node: {
62 | serviceAccount: {
63 | create: false
64 | },
65 | tolerateAllTaints: true
66 | }
67 | }, true, true);
68 | return Promise.resolve(chart);
69 | }
70 | }
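71 | 
72 | // Example (mirrors lib/dataPlane.ts, which passes the model bucket ARN from the deployment config):
73 | //   new s3CSIDriverAddOn({ s3BucketArn: "arn:aws:s3:::my-model-bucket" })  // hypothetical ARN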
--------------------------------------------------------------------------------
/lib/addons/sharedComponent.ts:
--------------------------------------------------------------------------------
1 | import { ClusterAddOn, ClusterInfo } from '@aws-quickstart/eks-blueprints';
2 | import { Construct } from "constructs";
3 | import * as cdk from 'aws-cdk-lib';
4 | import * as sns from 'aws-cdk-lib/aws-sns';
5 | import * as s3 from 'aws-cdk-lib/aws-s3';
6 | import * as lambda from 'aws-cdk-lib/aws-lambda';
7 | import * as apigw from "aws-cdk-lib/aws-apigateway";
8 | import * as xray from "aws-cdk-lib/aws-xray"
9 | import * as path from 'path';
10 |
11 | export interface SharedComponentAddOnProps {
12 | inputSns: sns.ITopic;
13 | outputSns: sns.ITopic;
14 | outputBucket: s3.IBucket;
15 | apiGWProps?: {
16 | stageName?: string,
17 | throttle?: {
18 | rateLimit?: number,
19 | burstLimit?: number
20 | }
21 | }
22 | }
23 |
24 | export const defaultProps: SharedComponentAddOnProps = {
25 | inputSns: undefined!,
26 | outputSns: undefined!,
27 | outputBucket: undefined!,
28 | apiGWProps: {
29 | stageName: "prod",
30 | throttle: {
31 | rateLimit: 10,
32 | burstLimit: 2
33 | }
34 | }
35 | }
36 |
37 | export class SharedComponentAddOn implements ClusterAddOn {
38 | readonly options: SharedComponentAddOnProps;
39 |
40 | constructor(props: SharedComponentAddOnProps) {
41 | this.options = { ...defaultProps, ...props };
42 | }
43 |
44 |     deploy(clusterInfo: ClusterInfo): Promise<Construct> {
45 | const cluster = clusterInfo.cluster;
46 |
47 | const v1Alpha1Parser = new lambda.Function(cluster.stack, 'v1Alpha1ParserFunction', {
48 | code: lambda.Code.fromAsset(path.join(__dirname, '../../src/frontend/input_function/v1alpha1')),
49 | handler: 'app.lambda_handler',
50 | runtime: lambda.Runtime.PYTHON_3_11,
51 | environment: {
52 | "SNS_TOPIC_ARN": this.options.inputSns.topicArn,
53 | "S3_OUTPUT_BUCKET": this.options.outputBucket.bucketName
54 | },
55 | tracing: lambda.Tracing.ACTIVE
56 | });
57 |
58 | const v1Alpha2Parser = new lambda.Function(cluster.stack, 'v1Alpha2ParserFunction', {
59 | code: lambda.Code.fromAsset(path.join(__dirname, '../../src/frontend/input_function/v1alpha2')),
60 | handler: 'app.lambda_handler',
61 | runtime: lambda.Runtime.PYTHON_3_11,
62 | environment: {
63 | "SNS_TOPIC_ARN": this.options.inputSns.topicArn,
64 | "S3_OUTPUT_BUCKET": this.options.outputBucket.bucketName
65 | },
66 | tracing: lambda.Tracing.ACTIVE
67 | });
68 |
69 | this.options.inputSns.grantPublish(v1Alpha1Parser);
70 | this.options.inputSns.grantPublish(v1Alpha2Parser);
71 |
72 | const api = new apigw.RestApi(cluster.stack, 'FrontAPI', {
73 | restApiName: 'FrontAPI',
74 | deploy: true,
75 | cloudWatchRole: true,
76 | endpointConfiguration: {
77 | types: [ apigw.EndpointType.REGIONAL ]
78 | },
79 | defaultMethodOptions: {
80 | apiKeyRequired: true
81 | },
82 | deployOptions: {
83 | stageName: this.options.apiGWProps!.stageName,
84 | tracingEnabled: true,
85 | metricsEnabled: true,
86 | }
87 | });
88 |
89 | const v1alpha1Resource = api.root.addResource('v1alpha1');
90 | v1alpha1Resource.addMethod('POST', new apigw.LambdaIntegration(v1Alpha1Parser), { apiKeyRequired: true });
91 |
92 | const v1alpha2Resource = api.root.addResource('v1alpha2');
93 | v1alpha2Resource.addMethod('POST', new apigw.LambdaIntegration(v1Alpha2Parser), { apiKeyRequired: true });
94 |
95 | api.node.addDependency(v1Alpha1Parser);
96 | api.node.addDependency(v1Alpha2Parser);
97 |
98 | //Force override name of generated output to provide a static name
99 | const urlCfnOutput = api.node.findChild('Endpoint') as cdk.CfnOutput;
100 | urlCfnOutput.overrideLogicalId('FrontApiEndpoint')
101 |
102 | const apiKey = new apigw.ApiKey(cluster.stack, `defaultAPIKey`, {
103 | description: `Default API Key`,
104 | enabled: true
105 | })
106 |
107 | const plan = api.addUsagePlan('UsagePlan', {
108 | name: cluster.stack.stackId + '-Default',
109 | apiStages: [{
110 | stage: api.deploymentStage
111 | }],
112 | throttle: {
113 | rateLimit: this.options.apiGWProps!.throttle!.rateLimit,
114 | burstLimit: this.options.apiGWProps!.throttle!.burstLimit
115 | }
116 | });
117 |
118 | plan.addApiKey(apiKey)
119 |
120 | new cdk.CfnOutput(cluster.stack, 'GetAPIKeyCommand', {
121 | value: "aws apigateway get-api-keys --query 'items[?id==`" + apiKey.keyId + "`].value' --include-values --output text",
122 | description: 'Command to get API Key'
123 | });
124 |
125 | //Xray Access Policy
126 | new xray.CfnResourcePolicy(cluster.stack, cluster.stack.stackName+'XRayAccessPolicyForSNS', {
127 | policyName: cluster.stack.stackName+'XRayAccessPolicyForSNS',
128 | policyDocument: '{"Version":"2012-10-17","Statement":[{"Sid":"SNSAccess","Effect":"Allow","Principal":{"Service":"sns.amazonaws.com"},"Action":["xray:PutTraceSegments","xray:GetSamplingRules","xray:GetSamplingTargets"],"Resource":"*","Condition":{"StringEquals":{"aws:SourceAccount":"' + cluster.stack.account + '"},"StringLike":{"aws:SourceArn":"' + cluster.stack.formatArn({ service: "sns", resource: '*' }) + '"}}}]}'
129 | })
130 |
131 | // Output S3 bucket ARN
132 |
133 | new cdk.CfnOutput(cluster.stack, 'OutputS3Bucket', {
134 | value: this.options.outputBucket.bucketArn,
135 | description: 'S3 bucket for generated images'
136 | });
137 |
138 | return Promise.resolve(plan);
139 | }
140 | }
141 |
--------------------------------------------------------------------------------
/lib/dataPlane.ts:
--------------------------------------------------------------------------------
1 | import * as cdk from 'aws-cdk-lib';
2 | import * as blueprints from '@aws-quickstart/eks-blueprints';
3 | import { Construct } from "constructs";
4 | import * as eks from "aws-cdk-lib/aws-eks";
5 | import * as sns from 'aws-cdk-lib/aws-sns';
6 | import * as ec2 from 'aws-cdk-lib/aws-ec2';
7 | import * as s3 from 'aws-cdk-lib/aws-s3';
8 | import * as iam from 'aws-cdk-lib/aws-iam';
9 | import SDRuntimeAddon, { SDRuntimeAddOnProps } from './runtime/sdRuntime';
10 | import { EbsThroughputTunerAddOn, EbsThroughputTunerAddOnProps } from './addons/ebsThroughputTuner'
11 | import { s3CSIDriverAddOn, s3CSIDriverAddOnProps } from './addons/s3CSIDriver'
12 | import { SharedComponentAddOn, SharedComponentAddOnProps } from './addons/sharedComponent';
13 | import { SNSResourceProvider } from './resourceProvider/sns'
14 | import { s3GWEndpointProvider } from './resourceProvider/s3GWEndpoint'
15 | import { dcgmExporterAddOn } from './addons/dcgmExporter';
16 |
17 | export interface dataPlaneProps {
18 | stackName: string,
19 | modelBucketArn: string;
20 | APIGW?: {
21 | stageName?: string,
22 | throttle?: {
23 | rateLimit?: number,
24 | burstLimit?: number
25 | }
26 | }
27 | modelsRuntime: {
28 | name: string,
29 | namespace: string,
30 | type: string,
31 | modelFilename?: string,
32 | dynamicModel?: boolean,
33 | chartRepository?: string,
34 | chartVersion?: string,
35 | extraValues?: {}
36 | }[]
37 | }
38 |
39 | export default class DataPlaneStack {
40 | constructor(scope: Construct, id: string,
41 | dataplaneProps: dataPlaneProps,
42 | props: cdk.StackProps) {
43 |
44 | const kedaParams: blueprints.KedaAddOnProps = {
45 | podSecurityContextFsGroup: 1001,
46 | securityContextRunAsGroup: 1001,
47 | securityContextRunAsUser: 1001,
48 | irsaRoles: ["CloudWatchFullAccess", "AmazonSQSFullAccess"]
49 | };
50 |
51 | const cloudWatchInsightsParams: blueprints.CloudWatchInsightsAddOnProps = {
52 | configurationValues: {
53 | tolerations: [
54 | {
55 | key: "runtime",
56 | operator: "Exists",
57 | effect: "NoSchedule"
58 | },
59 | {
60 | key: "nvidia.com/gpu",
61 | operator: "Exists",
62 | effect: "NoSchedule"
63 | }
64 | ],
65 | containerLogs: {
66 | enabled: true,
67 | fluentBit: {
68 | config: {
69 | service: "[SERVICE]\n Flush 5\n Grace 30\n Log_Level info",
70 | extraFiles: {
71 | "application-log.conf": "[INPUT]\n Name tail\n Tag kube.*\n Path /var/log/containers/*.log\n Parser docker\n DB /var/log/flb_kube.db\n Mem_Buf_Limit 5MB\n Skip_Long_Lines On\n Refresh_Interval 10\n\n[FILTER]\n Name kubernetes\n Match kube.*\n Kube_URL https://kubernetes.default.svc:443\n Kube_CA_File /var/run/secrets/kubernetes.io/serviceaccount/ca.crt\n Kube_Token_File /var/run/secrets/kubernetes.io/serviceaccount/token\n Kube_Tag_Prefix kube.var.log.containers.\n Merge_Log On\n Merge_Log_Key log_processed\n K8S-Logging.Parser On\n K8S-Logging.Exclude On\n\n[FILTER]\n Name grep\n Match kube.*\n Exclude $kubernetes['namespace_name'] kube-system\n\n[OUTPUT]\n Name cloudwatch\n Match kube.*\n region ${AWS_REGION}\n log_group_name /aws/containerinsights/${CLUSTER_NAME}/application\n log_stream_prefix ${HOST_NAME}-\n auto_create_group true\n retention_in_days 7"
72 | }
73 | }
74 | }
75 | }
76 | }
77 | };
78 |
79 | const SharedComponentAddOnParams: SharedComponentAddOnProps = {
80 | inputSns: blueprints.getNamedResource("inputSNSTopic"),
81 | outputSns: blueprints.getNamedResource("outputSNSTopic"),
82 | outputBucket: blueprints.getNamedResource("outputS3Bucket"),
83 | apiGWProps: dataplaneProps.APIGW
84 | };
85 |
86 | const EbsThroughputModifyAddOnParams: EbsThroughputTunerAddOnProps = {
87 | duration: 300,
88 | throughput: 125,
89 | iops: 3000
90 | };
91 |
92 | const s3CSIDriverAddOnParams: s3CSIDriverAddOnProps = {
93 | s3BucketArn: dataplaneProps.modelBucketArn
94 | };
95 |
96 |         const addOns: Array<blueprints.ClusterAddOn> = [
97 | new blueprints.addons.VpcCniAddOn(),
98 | new blueprints.addons.CoreDnsAddOn(),
99 | new blueprints.addons.KubeProxyAddOn(),
100 | new blueprints.addons.AwsLoadBalancerControllerAddOn(),
101 | new blueprints.addons.KarpenterAddOn({ interruptionHandling: true }),
102 | new blueprints.addons.KedaAddOn(kedaParams),
103 | new blueprints.addons.CloudWatchInsights(cloudWatchInsightsParams),
104 | new s3CSIDriverAddOn(s3CSIDriverAddOnParams),
105 | new SharedComponentAddOn(SharedComponentAddOnParams),
106 | new EbsThroughputTunerAddOn(EbsThroughputModifyAddOnParams),
107 | ];
108 |
109 | // Generate SD Runtime Addon for runtime
110 | dataplaneProps.modelsRuntime.forEach((val, idx, array) => {
111 | const sdRuntimeParams: SDRuntimeAddOnProps = {
112 | modelBucketArn: dataplaneProps.modelBucketArn,
113 | outputSns: blueprints.getNamedResource("outputSNSTopic") as sns.ITopic,
114 | inputSns: blueprints.getNamedResource("inputSNSTopic") as sns.ITopic,
115 | outputBucket: blueprints.getNamedResource("outputS3Bucket") as s3.IBucket,
116 | type: val.type.toLowerCase(),
117 | chartRepository: val.chartRepository,
118 | chartVersion: val.chartVersion,
119 | extraValues: val.extraValues,
120 | targetNamespace: val.namespace,
121 | };
122 |
123 | //Parameters for SD Web UI
124 | if (val.type.toLowerCase() == "sdwebui") {
125 | if (val.modelFilename) {
126 | sdRuntimeParams.sdModelCheckpoint = val.modelFilename
127 | }
128 | if (val.dynamicModel == true) {
129 | sdRuntimeParams.dynamicModel = true
130 | } else {
131 | sdRuntimeParams.dynamicModel = false
132 | }
133 | }
134 |
135 |             if (val.type.toLowerCase() == "comfyui") {}  // No ComfyUI-specific parameters to set
136 |
137 | addOns.push(new SDRuntimeAddon(sdRuntimeParams, val.name))
138 | });
139 |
140 | // Define initial managed node group for cluster components
141 | const MngProps: blueprints.MngClusterProviderProps = {
142 | minSize: 2,
143 | maxSize: 2,
144 | desiredSize: 2,
145 | version: eks.KubernetesVersion.V1_31,
146 | instanceTypes: [new ec2.InstanceType('m7g.large')],
147 | amiType: eks.NodegroupAmiType.AL2023_ARM_64_STANDARD,
148 | enableSsmPermissions: true,
149 | nodeGroupTags: {
150 | "Name": cdk.Aws.STACK_NAME + "-ClusterComponents",
151 | "stack": cdk.Aws.STACK_NAME
152 | }
153 | }
154 |
155 | // Deploy EKS cluster with all add-ons
156 | const blueprint = blueprints.EksBlueprint.builder()
157 | .version(eks.KubernetesVersion.V1_31)
158 | .addOns(...addOns)
159 | .resourceProvider(
160 | blueprints.GlobalResources.Vpc,
161 | new blueprints.VpcProvider())
162 | .resourceProvider("inputSNSTopic", new SNSResourceProvider("sdNotificationLambda"))
163 | .resourceProvider("outputSNSTopic", new SNSResourceProvider("sdNotificationOutput"))
164 | .resourceProvider("outputS3Bucket", new blueprints.CreateS3BucketProvider({
165 | id: 'outputS3Bucket'
166 | }))
167 | .resourceProvider("s3GWEndpoint", new s3GWEndpointProvider("s3GWEndpoint"))
168 | .clusterProvider(new blueprints.MngClusterProvider(MngProps))
169 | .build(scope, id + 'Stack', props);
170 | /*
171 | // Workaround for permission denied when creating cluster
172 | const handler = blueprint.node.tryFindChild('@aws-cdk--aws-eks.KubectlProvider')!
173 | .node.tryFindChild('Handler')! as cdk.aws_lambda.Function
174 |
175 | (
176 | blueprint.node.tryFindChild('@aws-cdk--aws-eks.KubectlProvider')!
177 | .node.tryFindChild('Provider')!
178 | .node.tryFindChild('framework-onEvent')!
179 | .node.tryFindChild('ServiceRole')!
180 | .node.tryFindChild('DefaultPolicy') as cdk.aws_iam.Policy
181 | )
182 | .addStatements(new cdk.aws_iam.PolicyStatement({
183 | effect: cdk.aws_iam.Effect.ALLOW,
184 | actions: ["lambda:GetFunctionConfiguration"],
185 | resources: [handler.functionArn]
186 | }))
187 | */
188 | // Provide static output name for cluster
189 | const cluster = blueprint.getClusterInfo().cluster
190 | const clusterNameCfnOutput = cluster.node.findChild('ClusterName') as cdk.CfnOutput;
191 | clusterNameCfnOutput.overrideLogicalId('ClusterName')
192 |
193 | const configCommandCfnOutput = cluster.node.findChild('ConfigCommand') as cdk.CfnOutput;
194 | configCommandCfnOutput.overrideLogicalId('ConfigCommand')
195 |
196 | const getTokenCommandCfnOutput = cluster.node.findChild('GetTokenCommand') as cdk.CfnOutput;
197 | getTokenCommandCfnOutput.overrideLogicalId('GetTokenCommand')
198 | }
199 | }
--------------------------------------------------------------------------------
/lib/resourceProvider/s3GWEndpoint.ts:
--------------------------------------------------------------------------------
1 | import * as blueprints from '@aws-quickstart/eks-blueprints';
2 | import * as ec2 from 'aws-cdk-lib/aws-ec2';
3 |
4 | export class s3GWEndpointProvider implements blueprints.ResourceProvider<ec2.IGatewayVpcEndpoint> {
5 | constructor(readonly name: string) { }
6 |
7 | provide(context: blueprints.ResourceContext): ec2.IGatewayVpcEndpoint {
8 | const vpc = context.get('vpc') as ec2.IVpc
9 | const vpce = new ec2.GatewayVpcEndpoint(context.scope, this.name, {
10 | service: ec2.GatewayVpcEndpointAwsService.S3,
11 | vpc: vpc
12 | });
13 |
14 | return vpce
15 | }
16 | }
--------------------------------------------------------------------------------
/lib/resourceProvider/sns.ts:
--------------------------------------------------------------------------------
1 | import * as cdk from 'aws-cdk-lib';
2 | import * as blueprints from '@aws-quickstart/eks-blueprints';
3 | import * as sns from 'aws-cdk-lib/aws-sns';
4 |
5 | export class SNSResourceProvider implements blueprints.ResourceProvider<sns.ITopic> {
6 | constructor(readonly topicName: string, readonly displayName?: string) { }
7 |
8 | provide(context: blueprints.ResourceContext): sns.ITopic {
9 |
10 | const cfnTopic = new cdk.aws_sns.CfnTopic(context.scope, this.topicName + 'Cfn', {
11 | displayName: this.displayName,
12 | tracingConfig: 'Active'
13 | });
14 |
15 | new cdk.CfnOutput(context.scope, this.topicName + 'ARN', {
16 | value: cfnTopic.attrTopicArn
17 | })
18 |
19 | return sns.Topic.fromTopicArn(context.scope, this.topicName, cfnTopic.attrTopicArn)
20 | }
21 | }
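22 | 
23 | // Example registration (as used in lib/dataPlane.ts):
24 | //   .resourceProvider("inputSNSTopic", new SNSResourceProvider("sdNotificationLambda"))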
--------------------------------------------------------------------------------
/lib/resourceProvider/vpc.ts:
--------------------------------------------------------------------------------
1 | import { Tags } from 'aws-cdk-lib';
2 | import * as ec2 from 'aws-cdk-lib/aws-ec2';
3 | import { ISubnet, PrivateSubnet } from 'aws-cdk-lib/aws-ec2';
4 | import * as blueprints from '@aws-quickstart/eks-blueprints';
5 |
6 | /**
7 | * Interface for Mapping for fields such as Primary CIDR, Secondary CIDR, Secondary Subnet CIDR.
8 | */
9 | interface VpcProps {
10 | primaryCidr?: string,
11 | secondaryCidr?: string,
12 | secondarySubnetCidrs?: string[]
13 | }
14 |
15 | /**
16 | * VPC resource provider
17 | */
18 | export class VpcProvider implements blueprints.ResourceProvider<ec2.IVpc> {
19 | readonly vpcId?: string;
20 | readonly primaryCidr?: string;
21 | readonly secondaryCidr?: string;
22 | readonly secondarySubnetCidrs?: string[];
23 |
24 | constructor(vpcId?: string, private vpcProps?: VpcProps) {
25 | this.vpcId = vpcId;
26 | this.primaryCidr = vpcProps?.primaryCidr;
27 | this.secondaryCidr = vpcProps?.secondaryCidr;
28 | this.secondarySubnetCidrs = vpcProps?.secondarySubnetCidrs;
29 | }
30 |
31 | provide(context: blueprints.ResourceContext): ec2.IVpc {
32 | const id = context.scope.node.id;
33 |
34 | let vpc = getVPCFromId(context, id, this.vpcId);
35 | if (vpc == null) {
36 | // It will automatically divide the provided VPC CIDR range, and create public and private subnets per Availability Zone.
37 | // If VPC CIDR range is not provided, uses `10.0.0.0/16` as the range and creates public and private subnets per Availability Zone.
38 | // Network routing for the public subnets will be configured to allow outbound access directly via an Internet Gateway.
39 | // Network routing for the private subnets will be configured to allow outbound access via a set of resilient NAT Gateways (one per AZ).
40 | // Creates Secondary CIDR and Secondary subnets if passed.
41 | if (this.primaryCidr) {
42 | vpc = new ec2.Vpc(context.scope, id + "-vpc",{
43 | ipAddresses: ec2.IpAddresses.cidr(this.primaryCidr)
44 | });
45 | }
46 | else {
47 | vpc = new ec2.Vpc(context.scope, id + "-vpc");
48 | }
49 | }
50 |
51 |
52 | if (this.secondaryCidr) {
53 | this.createSecondarySubnets(context, id, vpc);
54 | }
55 |
56 | return vpc;
57 | }
58 |
59 | protected createSecondarySubnets(context: blueprints.ResourceContext, id: string, vpc: ec2.IVpc) {
60 | const secondarySubnets: Array = [];
61 | const secondaryCidr = new ec2.CfnVPCCidrBlock(context.scope, id + "-secondaryCidr", {
62 | vpcId: vpc.vpcId,
63 | cidrBlock: this.secondaryCidr
64 | });
65 | secondaryCidr.node.addDependency(vpc);
66 | if (this.secondarySubnetCidrs) {
67 | for (let i = 0; i < vpc.availabilityZones.length; i++) {
68 | if (this.secondarySubnetCidrs[i]) {
69 | secondarySubnets[i] = new ec2.PrivateSubnet(context.scope, id + "private-subnet-" + i, {
70 | availabilityZone: vpc.availabilityZones[i],
71 | cidrBlock: this.secondarySubnetCidrs[i],
72 | vpcId: vpc.vpcId
73 | });
74 | secondarySubnets[i].node.addDependency(secondaryCidr);
75 | context.add("secondary-cidr-subnet-" + i, {
76 | provide(_context): ISubnet { return secondarySubnets[i]; }
77 | });
78 | }
79 | }
80 | for (let secondarySubnet of secondarySubnets) {
81 | Tags.of(secondarySubnet).add("kubernetes.io/role/internal-elb", "1", { applyToLaunchedInstances: true });
82 | Tags.of(secondarySubnet).add("Name", `blueprint-construct-dev-PrivateSubnet-${secondarySubnet}`, { applyToLaunchedInstances: true });
83 | }
84 | }
85 | }
86 | }
87 |
88 |
89 |
90 | /*
91 | ** This function returns a VPC based on the ResourceContext and the vpcId passed to the cluster.
92 | */
93 | export function getVPCFromId(context: blueprints.ResourceContext, nodeId: string, vpcId?: string) {
94 | let vpc = undefined;
95 | if (vpcId) {
96 | if (vpcId === "default") {
97 | console.log(`looking up completely default VPC`);
98 | vpc = ec2.Vpc.fromLookup(context.scope, nodeId + "-vpc", { isDefault: true });
99 | } else {
100 | console.log(`looking up non-default ${vpcId} VPC`);
101 | vpc = ec2.Vpc.fromLookup(context.scope, nodeId + "-vpc", { vpcId: vpcId });
102 | }
103 | }
104 | return vpc;
105 | }
106 |
--------------------------------------------------------------------------------
/lib/runtime/sdRuntime.ts:
--------------------------------------------------------------------------------
1 | import * as cdk from 'aws-cdk-lib';
2 | import { Construct } from 'constructs';
3 | import * as sns from 'aws-cdk-lib/aws-sns';
4 | import * as blueprints from '@aws-quickstart/eks-blueprints';
5 | import * as iam from 'aws-cdk-lib/aws-iam';
6 | import * as sqs from 'aws-cdk-lib/aws-sqs';
7 | import * as s3 from 'aws-cdk-lib/aws-s3';
8 | import { aws_sns_subscriptions } from "aws-cdk-lib";
9 | import * as lodash from "lodash";
10 | import { createNamespace } from "../utils/namespace"
11 |
12 | export interface SDRuntimeAddOnProps extends blueprints.addons.HelmAddOnUserProps {
13 | type: string,
14 | targetNamespace?: string,
15 | modelBucketArn?: string,
16 | outputSns?: sns.ITopic,
17 | inputSns?: sns.ITopic,
18 | outputBucket?: s3.IBucket
19 | sdModelCheckpoint?: string,
20 | dynamicModel?: boolean,
21 | chartRepository?: string,
22 | chartVersion?: string,
23 | extraValues?: object
24 | }
25 |
26 | export const defaultProps: blueprints.addons.HelmAddOnProps & SDRuntimeAddOnProps = {
27 | chart: 'sd-on-eks',
28 | name: 'sdRuntimeAddOn',
29 | namespace: 'sdruntime',
30 | release: 'sdruntime',
31 | version: '1.1.3',
32 | repository: 'oci://public.ecr.aws/bingjiao/charts/sd-on-eks',
33 | values: {},
34 | type: "sdwebui"
35 | }
36 |
37 | export default class SDRuntimeAddon extends blueprints.addons.HelmAddOn {
38 |
39 | readonly options: SDRuntimeAddOnProps;
40 | readonly id: string;
41 |
42 | constructor(props: SDRuntimeAddOnProps, id?: string) {
43 | super({ ...defaultProps, ...props });
44 | this.options = this.props as SDRuntimeAddOnProps;
45 | if (id) {
46 | this.id = id!.toLowerCase()
47 | } else {
48 | this.id = 'sdruntime'
49 | }
50 | }
51 |
52 | @blueprints.utils.dependable(blueprints.KarpenterAddOn.name)
53 | @blueprints.utils.dependable("SharedComponentAddOn")
54 | @blueprints.utils.dependable("s3CSIDriverAddOn")
55 |
56 |     deploy(clusterInfo: blueprints.ClusterInfo): Promise<Construct> {
57 | const cluster = clusterInfo.cluster
58 |
59 | this.props.name = this.id + 'Addon'
60 | this.props.release = this.id
61 |
62 | if (this.options.targetNamespace) {
63 | this.props.namespace = this.options.targetNamespace.toLowerCase()
64 | } else {
65 | this.props.namespace = "default"
66 | }
67 |
68 | const ns = createNamespace(this.id+"-"+this.props.namespace+"-namespace-struct", this.props.namespace, cluster, true)
69 |
70 | const runtimeSA = cluster.addServiceAccount('runtimeSA' + this.id, { namespace: this.props.namespace });
71 | runtimeSA.node.addDependency(ns)
72 |
73 | if (this.options.chartRepository) {
74 | this.props.repository = this.options.chartRepository
75 | }
76 |
77 | if (this.options.chartVersion) {
78 | this.props.version = this.options.chartVersion
79 | }
80 |
81 | const modelBucket = s3.Bucket.fromBucketAttributes(cluster.stack, 'ModelBucket' + this.id, {
82 | bucketArn: this.options.modelBucketArn!
83 | });
84 |
85 | modelBucket.grantRead(runtimeSA);
86 |
87 | const inputQueue = new sqs.Queue(cluster.stack, 'InputQueue' + this.id);
88 | inputQueue.grantConsumeMessages(runtimeSA);
89 |
90 |
91 | this.options.outputBucket!.grantWrite(runtimeSA);
92 | this.options.outputBucket!.grantPutAcl(runtimeSA);
93 | this.options.outputSns!.grantPublish(runtimeSA);
94 |
95 | runtimeSA.role.addManagedPolicy(
96 | iam.ManagedPolicy.fromAwsManagedPolicyName(
97 | 'AWSXRayDaemonWriteAccess',
98 | ))
99 |
100 | const nodeRole = clusterInfo.cluster.node.findChild('karpenter-node-role') as iam.IRole
101 |
102 | var generatedValues = {
103 | global: {
104 | awsRegion: cdk.Stack.of(cluster).region,
105 | stackName: cdk.Stack.of(cluster).stackName,
106 | runtime: this.id
107 | },
108 | runtime: {
109 | type: this.options.type,
110 | serviceAccountName: runtimeSA.serviceAccountName,
111 | queueAgent: {
112 | s3Bucket: this.options.outputBucket!.bucketName,
113 | snsTopicArn: this.options.outputSns!.topicArn,
114 | sqsQueueUrl: inputQueue.queueUrl,
115 | },
116 | persistence: {
117 | enabled: true,
118 | storageClass: "-",
119 | s3: {
120 | enabled: true,
121 | modelBucket: modelBucket.bucketName
122 | }
123 | }
124 | },
125 | karpenter: {
126 | nodeTemplate: {
127 | iamRole: nodeRole.roleName
128 | }
129 | }
130 | }
131 | // Temp change: set image repo to ECR Public
132 | if (this.options.type == "sdwebui") {
133 | var imagerepo: string
134 | if (!(lodash.get(this.options, "extraValues.runtime.inferenceApi.image.repository"))) {
135 | imagerepo = "public.ecr.aws/bingjiao/sd-on-eks/sdwebui"
136 | } else {
137 | imagerepo = lodash.get(this.options, "extraValues.runtime.inferenceApi.image.repository")!
138 | }
139 | var sdWebUIgeneratedValues = {
140 | runtime: {
141 | inferenceApi: {
142 | image: {
143 | repository: imagerepo
144 | },
145 | modelFilename: this.options.sdModelCheckpoint
146 | },
147 | queueAgent: {
148 | dynamicModel: this.options.dynamicModel
149 | }
150 | }
151 | }
152 |
153 | generatedValues = lodash.merge(generatedValues, sdWebUIgeneratedValues)
154 | }
155 |
156 | if (this.options.type == "comfyui") {
157 | var imagerepo: string
158 | if (!(lodash.get(this.options, "extraValues.runtime.inferenceApi.image.repository"))) {
159 | imagerepo = "public.ecr.aws/bingjiao/sd-on-eks/comfyui"
160 | } else {
161 | imagerepo = lodash.get(this.options, "extraValues.runtime.inferenceApi.image.repository")!
162 | }
163 |
164 | var comfyUIgeneratedValues = {
165 | runtime: {
166 | inferenceApi: {
167 | image: {
168 | repository: imagerepo
169 | },
170 | }
171 | }
172 | }
173 |
174 | generatedValues = lodash.merge(generatedValues, comfyUIgeneratedValues)
175 | }
176 |
177 | if (this.options.type == "sdwebui" && this.options.sdModelCheckpoint) {
178 | // Legacy and new routing, use CFN as a workaround since L2 construct doesn't support OR
179 | const cfnSubscription = new sns.CfnSubscription(cluster.stack, this.id+'CfnSubscription', {
180 | protocol: 'sqs',
181 | endpoint: inputQueue.queueArn,
182 | topicArn: this.options.inputSns!.topicArn,
183 | filterPolicy: {
184 | "$or": [
185 | {
186 | "sd_model_checkpoint": [
187 | this.options.sdModelCheckpoint!
188 | ]
189 | }, {
190 | "runtime": [
191 | this.id
192 | ]
193 | }]
194 | },
195 | filterPolicyScope: "MessageAttributes"
196 | })
197 |
198 | inputQueue.addToResourcePolicy(new iam.PolicyStatement({
199 | effect: iam.Effect.ALLOW,
200 | principals: [new iam.ServicePrincipal('sns.amazonaws.com')],
201 | actions: ['sqs:SendMessage'],
202 | resources: [inputQueue.queueArn],
203 | conditions: {
204 | 'ArnEquals': {
205 | 'aws:SourceArn': this.options.inputSns!.topicArn
206 | }
207 | }
208 | }))
209 |
210 |
211 |         /* It should look like this (kept for reference; the L2 construct cannot express the $or filter used above):
212 |         this.options.inputSns!.addSubscription(new aws_sns_subscriptions.SqsSubscription(inputQueue, {
213 |           filterPolicy: {
214 |             sd_model_checkpoint: sns.SubscriptionFilter.stringFilter({
215 |               allowlist: [this.options.sdModelCheckpoint!]
216 |             }),
217 |             runtime: sns.SubscriptionFilter.stringFilter({
218 |               allowlist: [this.id]
219 |             })
220 |           }
221 |         })) */
222 | } else {
223 | // New version routing only
224 | this.options.inputSns!.addSubscription(new aws_sns_subscriptions.SqsSubscription(inputQueue, {
225 | filterPolicy: {
226 | runtime:
227 | sns.SubscriptionFilter.stringFilter({
228 | allowlist: [this.id]
229 | })
230 | }
231 | }))
232 | }
233 |
234 | const values = lodash.merge(this.props.values, this.options.extraValues, generatedValues)
235 |
236 | const chart = this.addHelmChart(clusterInfo, values, true);
237 |
238 | return Promise.resolve(chart);
239 | }
240 | }
--------------------------------------------------------------------------------
/lib/utils/namespace.ts:
--------------------------------------------------------------------------------
1 | import { KubernetesManifest } from "aws-cdk-lib/aws-eks";
2 | import * as eks from "aws-cdk-lib/aws-eks";
3 |
4 | export type Values = {
5 | [key: string]: any;
6 | };
7 |
8 | export function createNamespace(id: string, name: string, cluster: eks.ICluster, overwrite?: boolean, prune?: boolean, annotations?: Values, labels?: Values) {
9 | return new KubernetesManifest(cluster.stack, id, {
10 | cluster: cluster,
11 | manifest: [{
12 | apiVersion: 'v1',
13 | kind: 'Namespace',
14 | metadata: {
15 | name: name,
16 | annotations,
17 | labels
18 | }
19 | }],
20 | overwrite: overwrite ?? true,
21 | prune: prune ?? true
22 | });
23 | }
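24 | 
25 | // Example (illustrative values; lib/runtime/sdRuntime.ts calls it the same way):
26 | //   const ns = createNamespace("sdruntime-namespace-struct", "sdruntime", cluster, true)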
--------------------------------------------------------------------------------
/lib/utils/validateConfig.ts:
--------------------------------------------------------------------------------
1 | // Placeholder validation: currently accepts any configuration.
2 | // (See config.schema.yaml at the repository root for the expected config shape.)
3 | export function validateConfig (props: any) {
4 |     var checkResult = true
5 | 
6 |     return checkResult
7 | }
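8 | 
9 | // A stricter sketch (assumption: props carries the modelsRuntime entries that
10 | // lib/dataPlane.ts requires, each with name/namespace/type):
11 | // export function validateConfig(props: any): boolean {
12 | //     return Array.isArray(props?.modelsRuntime) &&
13 | //         props.modelsRuntime.every((r: any) => r.name && r.namespace && r.type)
14 | // }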
--------------------------------------------------------------------------------
/load_test/locustfile.py:
--------------------------------------------------------------------------------
1 | import csv
2 | import datetime
3 | import json
4 | import logging
5 | import random
6 | import os
7 |
8 | import boto3
9 | import gevent
10 | from botocore.exceptions import ClientError
11 | from dateutil import parser
12 | from flask import send_file
13 | from locust import HttpUser, events, run_single_user, task
14 | from locust.runners import LocalRunner, MasterRunner
15 | from locust.user.wait_time import between
16 |
17 | logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(process)s - %(levelname)s - %(message)s')
18 |
19 | logger = logging.getLogger("processor")
20 | logger.setLevel(logging.INFO)
21 |
22 | API_ENDPOINT=os.getenv("API_ENDPOINT")
23 | API_KEY=os.getenv("API_KEY")
24 | OUTPUT_SQS_NAME=os.getenv("OUTPUT_SQS_NAME")
25 |
26 | TEMPLATE=json.loads("""{
27 | "task": {
28 | "metadata": {
29 | "id": "test-t2i",
30 | "runtime": "sdruntime",
31 | "tasktype": "text-to-image",
32 | "prefix": "output",
33 | "context": ""
34 | },
35 | "content": {
36 | "alwayson_scripts": {},
37 | "prompt": "A dog",
38 | "steps": 16,
39 | "width": 512,
40 | "height": 512
41 | }
42 | }
43 | }""")
44 |
45 |
46 |
47 | stats = {}
48 |
49 | sqs = boto3.resource('sqs')
50 |
51 | @events.init.add_listener
52 | def locust_init(environment, **kwargs):
53 | global stats
54 | stats = {}
55 |
56 | if isinstance(environment.runner, MasterRunner) or isinstance(environment.runner, LocalRunner):
57 | gevent.spawn(checker, environment)
58 |
59 | if environment.web_ui:
60 | @environment.web_ui.app.route("/result")
61 | def get_result():
62 |             task_count = len(stats)
63 |             completed_count = sum(1 for i in stats.values() if i["complete"])
64 |             failed_count = sum(1 for i in stats.values() if i["error"])
65 |             failure_rate = failed_count / task_count if task_count else 0.0
66 |             avg_time_usage = sum(i["time_usage"] for i in stats.values()) / task_count if task_count else 0.0
67 | response = {
68 | "task_count": task_count,
69 | "completed_count": completed_count,
70 | "failed_count": failed_count,
71 | "failure_rate": failure_rate,
72 | "avg_time_usage": avg_time_usage
73 | }
74 | return json.dumps(response)
75 |
76 | @environment.web_ui.app.route("/dump_failed")
77 | def dump_failed():
78 | failed = dict(filter(lambda x: x[1]["error"], stats.items()))
79 | return json.dumps(failed)
80 |
81 | @environment.web_ui.app.route("/dump_all")
82 | def dump_all():
83 | data_file = open('/tmp/result.csv', 'w')
84 | csv_writer = csv.writer(data_file)
85 | count = 0
86 | for item in stats.values():
87 | if count == 0:
88 | header = item.keys()
89 | csv_writer.writerow(header)
90 | count += 1
91 | item["start_time"] = int(item["start_time"])
92 | item["complete_time"] = int(item["complete_time"])
93 | csv_writer.writerow(item.values())
94 | data_file.close()
95 | return send_file('/tmp/result.csv', as_attachment=True)
96 |
97 | def receive_messages(queue, max_number, wait_time):
98 | try:
99 | messages = queue.receive_messages(
100 | MaxNumberOfMessages=max_number,
101 | WaitTimeSeconds=wait_time,
102 | AttributeNames=['All'],
103 | MessageAttributeNames=['All']
104 | )
105 | except Exception as error:
106 | logger.error(f"Error receiving messages: {error}")
107 | else:
108 | return messages
109 |
110 | def delete_message(message):
111 | try:
112 | message.delete()
113 | except ClientError as error:
114 | raise error
115 |
116 | def checker(environment):
117 | global stats
118 | logger.info(f'Checker launched')
119 | queue = sqs.get_queue_by_name(QueueName=OUTPUT_SQS_NAME)
120 | # while not environment.runner.state in [STATE_STOPPED]:
121 | while True:
122 | received_messages = receive_messages(queue, 10, 20)
123 |         if not received_messages:  # also covers None, returned when receive_messages hit an error
124 | logger.debug('No message received')
125 | else:
126 | for message in received_messages:
127 | payload = json.loads(message.body)
128 | msg = json.loads(payload['Message'])
129 | succeed = msg['result']
130 | task_id = msg['id']
131 | logger.debug(f'Received task with {task_id}')
132 | if task_id in stats.keys():
133 | logger.debug(f'Processing {task_id}')
134 |                     time = parser.parse(payload["Timestamp"]).timestamp()  # use the SNS timestamp itself, not the current time
135 | time_usage = int((time - stats[task_id]['start_time']) * 1000)
136 | if succeed:
137 | stats[task_id]['complete'] = True
138 | stats[task_id]['error'] = False
139 | stats[task_id]['complete_time'] = time
140 | stats[task_id]['time_usage'] = time_usage
141 | else:
142 | stats[task_id]['complete'] = True
143 | stats[task_id]['error'] = True
144 | stats[task_id]['complete_time'] = time
145 | stats[task_id]['time_usage'] = time_usage
146 | else:
147 | logger.debug(f'Ignored {task_id}')
148 | delete_message(message)
149 |
150 | @events.test_start.add_listener
151 | def on_test_start(**kwargs):
152 | global stats
153 | stats = {}
154 |
155 | class MyUser(HttpUser):
156 | host = "http://0.0.0.0:8089"
157 | wait_time = between(0.5, 2.0)
158 |
159 | @task
160 | def txt_to_img(self):
161 | random_number = str(random.randint(1, 999999999)).zfill(9)
162 | body = TEMPLATE.copy()
163 | body["task"]["metadata"]["id"] = str(random_number)
164 | logger.debug(f'Send request with {random_number}')
165 | self.client.post(API_ENDPOINT+"v1alpha2",
166 | data = json.dumps(body),
167 | headers={"x-api-key": API_KEY, "Content-Type": "application/json"},
168 | context={"task-id": random_number})
169 |
170 | @events.request.add_listener
171 | def on_request(context, **kwargs):
172 | """
173 |     Event handler that gets triggered on every request.
174 | """
175 | global stats
176 | stats[context["task-id"]] = {
177 | "task-id": context["task-id"],
178 | "start_time": datetime.datetime.now(datetime.timezone.utc).timestamp(),
179 | "complete_time": 0.0,
180 | "complete": False,
181 | "error": False,
182 | "time_usage": 0
183 | }
184 |
185 | # if launched directly, e.g. "python3 locustfile.py", not "locust -f locustfile.py"
186 | if __name__ == "__main__":
187 | run_single_user(MyUser)
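188 | 
189 | # Example run (assumed placeholder values; the endpoint URL and API key come from the
190 | # stack outputs, and OUTPUT_SQS_NAME is a queue subscribed to the output SNS topic):
191 | #   API_ENDPOINT=https://<api-id>.execute-api.<region>.amazonaws.com/prod/ \
192 | #   API_KEY=<api-key> OUTPUT_SQS_NAME=<queue-name> locust -f locustfile.py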
--------------------------------------------------------------------------------
/package.json:
--------------------------------------------------------------------------------
1 | {
2 | "name": "stable-difussion-on-eks",
3 | "version": "1.1.2",
4 | "bin": {
5 | "stable-difussion-on-eks": "bin/stable-difussion-on-eks.ts"
6 | },
7 | "scripts": {
8 | "synth": "npx cdk synth -q",
9 | "deploy": "npx cdk deploy --no-rollback --require-approval never"
10 | },
11 | "devDependencies": {
12 | "@types/node": "22.8.6",
13 | "aws-cdk": "2.173.4",
14 | "ts-node": "^10.9.1",
15 | "typescript": "~5.5.2"
16 | },
17 | "dependencies": {
18 | "@aws-quickstart/eks-blueprints": "1.16.3",
19 | "@types/lodash": "^4.14.197",
20 | "aws-cdk-lib": "2.173.4",
21 | "constructs": "^10.0.0",
22 | "source-map-support": "^0.5.21",
23 | "yaml": "^2.3.2"
24 | }
25 | }
26 |
--------------------------------------------------------------------------------
/src/backend/queue_agent/.gitignore:
--------------------------------------------------------------------------------
1 | test/*
--------------------------------------------------------------------------------
/src/backend/queue_agent/Dockerfile:
--------------------------------------------------------------------------------
1 | FROM public.ecr.aws/docker/library/python:3.12-slim
2 |
3 | RUN apt-get update; apt-get install libmagic1 -y; rm -rf /var/lib/apt/lists/*
4 |
5 | RUN adduser --disabled-password --gecos '' app && \
6 | mkdir /app || true; \
7 | chown -R app /app
8 | WORKDIR /app
9 | USER app
10 |
11 | COPY requirements.txt .
12 | RUN pip3 install --no-cache-dir --upgrade pip && \
13 | pip3 install --no-cache-dir -r requirements.txt
14 |
15 | COPY src/ /app/
16 |
17 | CMD python3 -u /app/main.py
18 |
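19 | # Example build/run (hypothetical tag and values; the env vars are read by src/main.py):
20 | #   docker build -t queue-agent .
21 | #   docker run -e RUNTIME_TYPE=sdwebui -e SQS_QUEUE_URL=<queue-url> \
22 | #     -e SNS_TOPIC_ARN=<topic-arn> -e S3_BUCKET=<bucket> queue-agent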
--------------------------------------------------------------------------------
/src/backend/queue_agent/requirements.txt:
--------------------------------------------------------------------------------
1 | aioboto3>=12.3.0
2 | aiohttp_client_cache>=0.11.0
3 | aws_xray_sdk==2.12.0
4 | boto3==1.33.2
5 | botocore==1.33.2
6 | python_magic>=0.4.24
7 | Requests>=2.31.0
8 | requests_cache>=1.2.0
9 | websocket_client>=1.7.0
10 |
--------------------------------------------------------------------------------
/src/backend/queue_agent/src/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2 | # SPDX-License-Identifier: MIT-0
3 |
--------------------------------------------------------------------------------
/src/backend/queue_agent/src/main.py:
--------------------------------------------------------------------------------
1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2 | # SPDX-License-Identifier: MIT-0
3 |
4 | import json
5 | import logging
6 | import os
7 | import signal
8 | import sys
9 | import uuid
10 | import time
11 | import functools
12 |
13 | import boto3
14 | from botocore.exceptions import EndpointConnectionError
15 | from aws_xray_sdk.core import patch_all, xray_recorder
16 | from aws_xray_sdk.core.models.trace_header import TraceHeader
17 | from modules import s3_action, sns_action, sqs_action
18 | from runtimes import comfyui, sdwebui
19 |
20 | # Initialize logging first so we can log X-Ray initialization attempts
21 | logging.basicConfig()
22 | logging.getLogger().setLevel(logging.ERROR)
23 |
24 | # Configure the queue-agent logger only once
25 | logger = logging.getLogger("queue-agent")
26 | logger.propagate = False
27 | logger.setLevel(os.environ.get('LOGLEVEL', 'INFO').upper())
28 |
29 | # Remove any existing handlers to prevent duplicate logs
30 | if logger.handlers:
31 | logger.handlers.clear()
32 |
33 | # Add a single handler
34 | handler = logging.StreamHandler(sys.stdout)
35 | handler.setFormatter(logging.Formatter('%(asctime)s - %(levelname)s - %(message)s'))
36 | logger.addHandler(handler)
37 |
38 | # Check if X-Ray is manually disabled via environment variable
39 | DISABLE_XRAY = os.environ.get('DISABLE_XRAY', 'false').lower() == 'true'
40 | if DISABLE_XRAY:
41 | logger.info("X-Ray tracing manually disabled via DISABLE_XRAY environment variable")
42 | xray_enabled = False
43 | else:
44 | # Try to initialize X-Ray SDK with retries, as the daemon might be starting up
45 | MAX_XRAY_INIT_ATTEMPTS = 5
46 | XRAY_RETRY_DELAY = 3 # seconds
47 | xray_enabled = False
48 |
49 | for attempt in range(MAX_XRAY_INIT_ATTEMPTS):
50 | try:
51 | logger.info(f"Attempting to initialize X-Ray SDK (attempt {attempt+1}/{MAX_XRAY_INIT_ATTEMPTS})")
52 | patch_all()
53 | xray_enabled = True
54 | logger.info("X-Ray SDK initialized successfully")
55 | break
56 | except EndpointConnectionError:
57 | logger.warning(f"Could not connect to X-Ray daemon (attempt {attempt+1}/{MAX_XRAY_INIT_ATTEMPTS})")
58 | if attempt < MAX_XRAY_INIT_ATTEMPTS - 1:
59 | logger.info(f"Retrying in {XRAY_RETRY_DELAY} seconds...")
60 | time.sleep(XRAY_RETRY_DELAY)
61 | except Exception as e:
62 | logger.warning(f"Error initializing X-Ray: {str(e)} (attempt {attempt+1}/{MAX_XRAY_INIT_ATTEMPTS})")
63 | if attempt < MAX_XRAY_INIT_ATTEMPTS - 1:
64 | logger.info(f"Retrying in {XRAY_RETRY_DELAY} seconds...")
65 | time.sleep(XRAY_RETRY_DELAY)
66 |
67 | if not xray_enabled:
68 | logger.warning("X-Ray initialization failed after all attempts. Tracing will be disabled.")
69 |
70 | # Create a decorator for safe X-Ray instrumentation
71 | def safe_xray_capture(name):
72 | """Decorator that safely applies X-Ray instrumentation if available"""
73 | def decorator(func):
74 | @functools.wraps(func)
75 | def wrapper(*args, **kwargs):
76 | if xray_enabled:
77 | try:
78 | # Try to use X-Ray instrumentation
79 | with xray_recorder.in_segment(name):
80 | return func(*args, **kwargs)
81 | except Exception as e:
82 | logger.warning(f"X-Ray instrumentation failed for {name}: {str(e)}")
83 | # Fall back to non-instrumented execution
84 | return func(*args, **kwargs)
85 | else:
86 | # X-Ray is disabled, just call the function directly
87 | return func(*args, **kwargs)
88 | return wrapper
89 | return decorator
90 |
91 | # Get base environment variables
92 | aws_default_region = os.getenv("AWS_DEFAULT_REGION")
93 | sqs_queue_url = os.getenv("SQS_QUEUE_URL")
94 | sns_topic_arn = os.getenv("SNS_TOPIC_ARN")
95 | s3_bucket = os.getenv("S3_BUCKET")
96 | runtime_name = os.getenv("RUNTIME_NAME", "")
97 | api_base_url = ""
98 |
99 | exp_callback_when_running = os.getenv("EXP_CALLBACK_WHEN_RUNNING", "")
100 |
101 | # Check current runtime type
102 | runtime_type = os.getenv("RUNTIME_TYPE", "").lower()
103 |
104 | # Runtime type should be specified
105 | if runtime_type == "":
106 |     logger.error('Runtime type not specified')
107 |     raise RuntimeError('Runtime type not specified')
108 |
109 | # Init for SD Web UI
110 | if runtime_type == "sdwebui":
111 | api_base_url = os.getenv("API_BASE_URL", "http://localhost:8080/sdapi/v1/")
112 |     # Any value other than "false" (case-insensitive) enables dynamic model switching
113 |     dynamic_sd_model = os.getenv("DYNAMIC_SD_MODEL", "false").lower() != "false"
117 |
118 | # Init for ComfyUI
119 | if runtime_type == "comfyui":
120 | api_base_url = os.getenv("API_BASE_URL", "localhost:8080")
121 | client_id = str(uuid.uuid4())
122 |     # ComfyUI identifies websocket clients by this randomly generated client ID
123 |     # Specify any other required environment variables here
124 |
125 | sqsRes = boto3.resource('sqs')
126 | snsRes = boto3.resource('sns')
127 |
128 | SQS_WAIT_TIME_SECONDS = 20
129 |
130 | # For graceful shutdown
131 | shutdown = False
132 |
133 | def main():
134 | # Initialization:
135 | # 1. Environment parameters;
136 |     # 2. AWS service resources (SQS/SNS/S3);
137 | # 3. SD API readiness check, current checkpoint cached;
138 | print_env()
139 |
140 | queue = sqsRes.Queue(sqs_queue_url)
141 | topic = snsRes.Topic(sns_topic_arn)
142 |
143 | if runtime_type == "sdwebui":
144 | sdwebui.check_readiness(api_base_url, dynamic_sd_model)
145 |
146 | if runtime_type == "comfyui":
147 | comfyui.check_readiness(api_base_url)
148 |
149 | # main loop
150 |     # 1. Pull a message from SQS;
151 |     # 2. Translate parameters;
152 |     # 3. (Optional) Switch the model;
153 |     # 4. (Optional) Prepare inputs for image downloading and encoding;
154 |     # 5. Call the SD API;
155 |     # 6. Prepare outputs for decoding, uploading and notifying;
156 |     # 7. Delete the message;
157 | while True:
158 | if shutdown:
159 | logger.info('Received SIGTERM, shutting down...')
160 | break
161 |
162 | received_messages = sqs_action.receive_messages(queue, 1, SQS_WAIT_TIME_SECONDS)
163 |
164 | for message in received_messages:
165 | # Process with X-Ray if enabled, otherwise just process the message directly
166 | if xray_enabled:
167 | try:
168 | with xray_recorder.in_segment(runtime_name+"-queue-agent") as segment:
169 | # Retrieve x-ray trace header from SQS message
170 | if "AWSTraceHeader" in message.attributes.keys():
171 | traceHeaderStr = message.attributes['AWSTraceHeader']
172 | sqsTraceHeader = TraceHeader.from_header_str(traceHeaderStr)
173 | # Update current segment to link with SQS
174 | segment.trace_id = sqsTraceHeader.root
175 | segment.parent_id = sqsTraceHeader.parent
176 | segment.sampled = sqsTraceHeader.sampled
177 |
178 | # Process the message within the X-Ray segment
179 | process_message(message, topic, s3_bucket, runtime_type, runtime_name, api_base_url, dynamic_sd_model if runtime_type == "sdwebui" else None)
180 | except Exception as e:
181 | logger.error(f"Error with X-Ray tracing: {str(e)}. Processing message without tracing.")
182 | process_message(message, topic, s3_bucket, runtime_type, runtime_name, api_base_url, dynamic_sd_model if runtime_type == "sdwebui" else None)
183 | else:
184 | # Process without X-Ray tracing
185 | process_message(message, topic, s3_bucket, runtime_type, runtime_name, api_base_url, dynamic_sd_model if runtime_type == "sdwebui" else None)
186 |
187 | def process_message(message, topic, s3_bucket, runtime_type, runtime_name, api_base_url, dynamic_sd_model=None):
188 | """Process a single SQS message"""
189 | # Process received message
190 | try:
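191 |         # The SQS body is an SNS notification envelope; the actual task payload
192 |         # is the JSON string in its 'Message' field, hence the double json.loads.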
191 | payload = json.loads(json.loads(message.body)['Message'])
192 | metadata = payload["metadata"]
193 | task_id = metadata["id"]
194 |
195 | logger.info(f"Received task {task_id}, processing")
196 |
197 | if "prefix" in metadata.keys():
198 | if metadata["prefix"][-1] == '/':
199 | prefix = metadata["prefix"] + str(task_id)
200 | else:
201 | prefix = metadata["prefix"] + "/" + str(task_id)
202 | else:
203 | prefix = str(task_id)
204 |
205 |         tasktype = metadata.get("tasktype", "")
206 |         context = metadata.get("context", {})
212 |
213 | body = payload["content"]
214 | logger.debug(body)
215 |     except Exception as e:
216 |         logger.error(f"Error parsing message: {e}, skipping")
217 |         logger.debug(message.body)
218 | sqs_action.delete_message(message)
219 | return
220 |
221 |     if exp_callback_when_running.lower() == "true":
222 | sns_response = {"runtime": runtime_name,
223 | 'id': task_id,
224 | 'status': "running",
225 | 'context': context}
226 |
227 | sns_action.publish_message(topic, json.dumps(sns_response))
228 |
229 | # Start handling message
230 | response = {}
231 |
232 | try:
233 | if runtime_type == "sdwebui":
234 | response = sdwebui.handler(api_base_url, tasktype, task_id, body, dynamic_sd_model)
235 |
236 | if runtime_type == "comfyui":
237 | response = comfyui.handler(api_base_url, task_id, body)
238 | except Exception as e:
239 | logger.error(f"Error calling handler for task {task_id}: {str(e)}")
240 | response = {
241 | "success": False,
242 | "image": [],
243 | "content": '{"code": 500, "error": "Runtime handler failed"}'
244 | }
245 |
246 | result = []
247 | rand = str(uuid.uuid4())[0:4]
248 |
249 | if response["success"]:
250 | idx = 0
251 | if len(response["image"]) > 0:
252 | for i in response["image"]:
253 | idx += 1
254 | result.append(s3_action.upload_file(i, s3_bucket, prefix, str(task_id)+"-"+rand+"-"+str(idx)))
255 |
256 | output_url = s3_action.upload_file(response["content"], s3_bucket, prefix, str(task_id)+"-"+rand, ".out")
257 |
258 | if response["success"]:
259 | status = "completed"
260 | else:
261 | status = "failed"
262 |
263 | sns_response = {"runtime": runtime_name,
264 | 'id': task_id,
265 | 'result': response["success"],
266 | 'status': status,
267 | 'image_url': result,
268 | 'output_url': output_url,
269 | 'context': context}
270 |
271 | # Put response handler to SNS and delete message
272 | sns_action.publish_message(topic, json.dumps(sns_response))
273 | sqs_action.delete_message(message)
274 |
275 | def print_env() -> None:
276 | logger.info(f'AWS_DEFAULT_REGION={aws_default_region}')
277 | logger.info(f'SQS_QUEUE_URL={sqs_queue_url}')
278 | logger.info(f'SNS_TOPIC_ARN={sns_topic_arn}')
279 | logger.info(f'S3_BUCKET={s3_bucket}')
280 | logger.info(f'RUNTIME_TYPE={runtime_type}')
281 | logger.info(f'RUNTIME_NAME={runtime_name}')
282 | logger.info(f'X-Ray Tracing: {"Disabled" if DISABLE_XRAY else "Enabled"}')
283 | logger.info(f'X-Ray Status: {"Active" if xray_enabled else "Inactive"}')
284 |
285 | def signalHandler(signum, frame):
286 | global shutdown
287 | shutdown = True
288 |
289 | if __name__ == '__main__':
290 | for sig in [signal.SIGINT, signal.SIGHUP, signal.SIGTERM]:
291 | signal.signal(sig, signalHandler)
292 | main()
293 |
--------------------------------------------------------------------------------
/src/backend/queue_agent/src/modules/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2 | # SPDX-License-Identifier: MIT-0
3 |
--------------------------------------------------------------------------------
/src/backend/queue_agent/src/modules/http_action.py:
--------------------------------------------------------------------------------
1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2 | # SPDX-License-Identifier: MIT-0
3 |
4 | import logging
5 |
6 | import aioboto3
7 | import boto3
8 | import requests
9 | import requests_cache
10 | from aiohttp_client_cache import CacheBackend, CachedSession
11 | from requests.adapters import HTTPAdapter, Retry
12 |
13 | from . import s3_action, time_utils
14 |
15 | logger = logging.getLogger("queue-agent")
16 |
17 |
18 | ab3_session = aioboto3.Session()
19 |
20 | apiClient = requests.Session()
21 | retries = Retry(
22 | total=3,
23 | connect=100,
24 | backoff_factor=0.1,
25 | allowed_methods=["GET", "POST"])
26 | apiClient.mount('http://', HTTPAdapter(max_retries=retries))
27 |
28 | REQUESTS_TIMEOUT_SECONDS = 300
29 |
30 | cache = CacheBackend(
31 | cache_name='memory-cache',
32 | expire_after=600
33 | )
34 |
35 | @time_utils.get_time
36 | def do_invocations(url: str, body: dict = None) -> dict:
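37 |     """Invoke `url` via GET, or POST with JSON `body`, returning the decoded JSON response."""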
37 | if body is None:
38 | logger.debug(f"Invoking {url}")
39 | response = apiClient.get(
40 | url=url, timeout=(1, REQUESTS_TIMEOUT_SECONDS))
41 | else:
42 | logger.debug(f"Invoking {url} with body: {body}")
43 | response = apiClient.post(
44 | url=url, json=body, timeout=(1, REQUESTS_TIMEOUT_SECONDS))
45 | response.raise_for_status()
46 | logger.debug(response.text)
47 | return response.json()
48 |
49 | def get(url: str) -> bytes:
50 |     logger.debug(f"Downloading {url}")
51 |     if url.lower().startswith("http://") or url.lower().startswith("https://"):
52 |         with requests_cache.CachedSession('demo_cache') as session:
53 |             with session.get(url) as res:
54 |                 res.raise_for_status()
55 |                 return res.content
56 |     elif url.lower().startswith("s3://"):
57 |         bucket_name, key = s3_action.get_bucket_and_key(url)
58 |         # boto3 clients are not context managers, so create the client directly
59 |         s3 = boto3.client('s3')
60 |         obj = s3.get_object(Bucket=bucket_name, Key=key)
61 |         return obj['Body'].read()
62 |     raise ValueError(f"Unsupported URL scheme: {url}")
65 |
66 | async def async_get(url: str) -> bytes:
67 | try:
68 | if url.lower().startswith("http://") or url.lower().startswith("https://"):
69 | async with CachedSession(cache=cache) as session:
70 | async with session.get(url) as res:
71 | res.raise_for_status()
72 | return await res.read()
73 | elif url.lower().startswith("s3://"):
74 | bucket_name, key = s3_action.get_bucket_and_key(url)
75 | async with ab3_session.resource("s3") as s3:
76 | obj = await s3.Object(bucket_name, key)
77 | res = await obj.get()
78 | return await res['Body'].read()
79 | except Exception as e:
80 | raise e
--------------------------------------------------------------------------------
/src/backend/queue_agent/src/modules/misc.py:
--------------------------------------------------------------------------------
1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2 | # SPDX-License-Identifier: MIT-0
3 |
4 | import base64
5 | import difflib
6 |
7 |
8 | def exclude_keys(dictionary, keys):
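9 |     """Return a copy of `dictionary` without the given keys."""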
9 | key_set = set(dictionary.keys()) - set(keys)
10 | return {key: dictionary[key] for key in key_set}
11 |
12 | def str_simularity(a, b):
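13 |     """Return a similarity ratio in [0, 1] between strings a and b."""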
13 | return difflib.SequenceMatcher(None, a, b).ratio()
14 |
15 | def encode_to_base64(buffer: bytes) -> str:
16 |     return base64.b64encode(buffer).decode('utf-8')
--------------------------------------------------------------------------------
/src/backend/queue_agent/src/modules/s3_action.py:
--------------------------------------------------------------------------------
1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2 | # SPDX-License-Identifier: MIT-0
3 |
4 | import datetime
5 | import logging
6 | import mimetypes
7 | import uuid
8 |
9 | import aioboto3
10 | import boto3
11 | import magic
12 |
13 | logger = logging.getLogger("queue-agent")
14 | s3Res = boto3.resource('s3')
15 |
16 | ab3_session = aioboto3.Session()
17 |
18 | def upload_file(object_bytes: bytes, bucket_name: str, prefix: str, file_name: str=None, extension: str=None) -> str:
19 |     if file_name is None:
20 |         file_name = datetime.datetime.now().strftime("%Y%m%d%H%M%S") + "-" + str(uuid.uuid4())[0:5]
21 |
22 |     # Auto-detect the file type and extension with libmagic
23 |     content_type = None
24 |     if extension is None:
25 |         content_type = magic.from_buffer(object_bytes, mime=True)
26 |         extension = mimetypes.guess_extension(content_type, True)
27 |
28 |     if extension == '.out':
29 |         content_type = 'application/json'
30 |     elif content_type is None:
31 |         # A caller-supplied extension: infer the MIME type from it
32 |         content_type = mimetypes.types_map.get(extension, 'application/octet-stream')
29 |
30 | try:
31 | bucket = s3Res.Bucket(bucket_name)
32 | logger.info(f"Uploading s3://{bucket_name}/{prefix}/{file_name}{extension}")
33 | bucket.put_object(Body=object_bytes, Key=f'{prefix}/{file_name}{extension}', ContentType=content_type)
34 | return f's3://{bucket_name}/{prefix}/{file_name}{extension}'
35 | except Exception as error:
36 | logger.error('Failed to upload content to S3', exc_info=True)
37 | raise error
38 |
39 |
40 | async def async_upload(object_bytes: bytes, bucket_name: str, prefix: str, file_name: str=None, extension: str=None) -> str:
41 |     if file_name is None:
42 |         file_name = datetime.datetime.now().strftime("%Y%m%d%H%M%S") + "-" + str(uuid.uuid4())[0:5]
43 |
44 |     # Auto-detect the file type and extension with libmagic
45 |     content_type = None
46 |     if extension is None:
47 |         content_type = magic.from_buffer(object_bytes, mime=True)
48 |         extension = mimetypes.guess_extension(content_type, True)
49 |
50 |     if extension == '.out':
51 |         content_type = 'application/json'
52 |     elif content_type is None:
53 |         content_type = mimetypes.types_map.get(extension, 'application/octet-stream')
51 |
52 | try:
53 | async with ab3_session.resource("s3") as s3:
54 | bucket = await s3.Bucket(bucket_name)
55 | await bucket.put_object(Body=object_bytes, Key=f'{prefix}/{file_name}{extension}', ContentType=content_type)
56 | return f's3://{bucket_name}/{prefix}/{file_name}{extension}'
57 | except Exception as e:
58 | raise e
59 |
60 | def get_bucket_and_key(s3uri):
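61 |     """Split an s3://bucket/key URI into (bucket, key)."""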
61 | pos = s3uri.find('/', 5)
62 | bucket = s3uri[5: pos]
63 | key = s3uri[pos + 1:]
64 | return bucket, key
65 |
66 | def get_prefix(path):
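67 |     """Return the portion of `path` after its first '/'."""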
67 | pos = path.find('/')
68 | return path[pos + 1:]
--------------------------------------------------------------------------------
/src/backend/queue_agent/src/modules/sns_action.py:
--------------------------------------------------------------------------------
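1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2 | # SPDX-License-Identifier: MIT-0
3 |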
1 | import logging
2 |
3 | import aioboto3
4 | from botocore.exceptions import ClientError
5 |
6 | logger = logging.getLogger("queue-agent")
7 |
8 | ab3_session = aioboto3.Session()
9 |
10 | def publish_message(topic, message: str) -> str:
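11 |     """Publish `message` to the given SNS topic resource and return the message ID."""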
11 | try:
12 | response = topic.publish(Message=message)
13 | message_id = response['MessageId']
14 | except ClientError as error:
15 | logger.error('Failed to send message to SNS', exc_info=True)
16 | raise error
17 | else:
18 | return message_id
19 |
20 | async def async_publish_message(topic, content: str):
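21 |     """Publish `content` to the SNS topic ARN `topic` via aioboto3 and return the message ID."""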
21 | try:
22 | async with ab3_session.resource("sns") as sns:
23 | topic = await sns.Topic(topic)
24 | response = await topic.publish(Message=content)
25 | return response['MessageId']
26 | except Exception as e:
27 | raise e
--------------------------------------------------------------------------------
/src/backend/queue_agent/src/modules/sqs_action.py:
--------------------------------------------------------------------------------
1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2 | # SPDX-License-Identifier: MIT-0
3 |
4 | import logging
5 |
6 | from botocore.exceptions import ClientError
7 |
8 | logger = logging.getLogger("queue-agent")
9 |
10 | def receive_messages(queue, max_number, wait_time):
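11 |     """Long-poll up to `max_number` messages from `queue`, waiting at most `wait_time` seconds."""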
11 | try:
12 | messages = queue.receive_messages(
13 | MaxNumberOfMessages=max_number,
14 | WaitTimeSeconds=wait_time,
15 | AttributeNames=['All'],
16 | MessageAttributeNames=['All']
17 | )
18 | except ClientError as error:
19 | logger.error('Failed to get message from SQS', exc_info=True)
20 | raise error
21 | else:
22 | return messages
23 |
24 | def delete_message(message):
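25 |     """Delete a processed message from its SQS queue."""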
25 | try:
26 | message.delete()
27 | except ClientError as error:
28 | logger.error('Failed to delete message from SQS', exc_info=True)
29 | raise error
--------------------------------------------------------------------------------
/src/backend/queue_agent/src/modules/time_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2 | # SPDX-License-Identifier: MIT-0
3 |
4 | import functools
5 | import logging
5 | from time import perf_counter
6 |
7 | logger = logging.getLogger("queue-agent")
8 |
9 | def get_time(f):
10 |     """Decorator that logs the wall-clock time spent in the wrapped call."""
11 |     @functools.wraps(f)
12 |     def inner(*arg, **kwarg):
11 | s_time = perf_counter()
12 | res = f(*arg, **kwarg)
13 | e_time = perf_counter()
14 | logger.info('Used: {:.4f} seconds on api: {}.'.format(e_time - s_time, arg[0]))
15 | return res
16 | return inner
--------------------------------------------------------------------------------
/src/backend/queue_agent/src/runtimes/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2 | # SPDX-License-Identifier: MIT-0
--------------------------------------------------------------------------------
/src/backend/queue_agent/src/runtimes/comfyui.py:
--------------------------------------------------------------------------------
1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2 | # SPDX-License-Identifier: MIT-0
3 |
4 | import json
5 | import logging
6 | import time
7 | import traceback
8 | import urllib.parse
9 | import uuid
10 | from typing import Optional, Dict, List, Any, Union
11 |
12 | import websocket # NOTE: websocket-client (https://github.com/websocket-client/websocket-client)
13 | from modules import http_action
14 |
15 | logger = logging.getLogger("queue-agent")
16 |
17 | # Import the safe_xray_capture decorator from main module
18 | try:
19 | from src.main import safe_xray_capture, xray_enabled
20 | except ImportError:
21 | try:
22 | # Try alternative import path
23 | from ..main import safe_xray_capture, xray_enabled
24 | except ImportError:
25 | # Fallback if import fails - create a simple pass-through decorator
26 | logger.warning("Failed to import safe_xray_capture from main, using fallback")
27 | def safe_xray_capture(name):
28 | def decorator(func):
29 | return func
30 | return decorator
31 | xray_enabled = False
32 |
33 | # Constants for websocket reconnection
34 | MAX_RECONNECT_ATTEMPTS = 5
35 | RECONNECT_DELAY = 2 # seconds
36 |
37 | def singleton(cls):
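38 |     """Class decorator that caches and returns a single shared instance of `cls`."""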
38 | _instance = {}
39 |
40 | def inner():
41 | if cls not in _instance:
42 | _instance[cls] = cls()
43 | return _instance[cls]
44 | return inner
45 |
46 | @singleton
47 | class comfyuiCaller(object):
48 |
49 | def __init__(self):
50 | self.wss = websocket.WebSocket()
51 | self.client_id = str(uuid.uuid4())
52 | self.api_base_url = None
53 | self.connected = False
54 |
55 | def setUrl(self, api_base_url:str):
56 | self.api_base_url = api_base_url
57 |
58 | def wss_connect(self):
59 | """Connect to websocket with reconnection logic"""
60 | if self.connected:
61 | return True
62 |
63 | attempts = 0
64 | while attempts < MAX_RECONNECT_ATTEMPTS:
65 | try:
66 | logger.info(f"Connecting to websocket (attempt {attempts+1}/{MAX_RECONNECT_ATTEMPTS})")
67 | self.wss.connect("ws://{}/ws?clientId={}".format(self.api_base_url, self.client_id))
68 | self.connected = True
69 | logger.info("Successfully connected to websocket")
70 | return True
71 | except Exception as e:
72 | attempts += 1
73 | logger.warning(f"Failed to connect to websocket: {str(e)}")
74 | if attempts < MAX_RECONNECT_ATTEMPTS:
75 | logger.info(f"Retrying in {RECONNECT_DELAY} seconds...")
76 | time.sleep(RECONNECT_DELAY)
77 | else:
78 | logger.error("Max reconnection attempts reached")
79 | raise ConnectionError(f"Failed to connect to ComfyUI websocket after {MAX_RECONNECT_ATTEMPTS} attempts") from e
80 |
81 | return False
82 |
83 | def wss_recv(self) -> Optional[str]:
84 | """Receive data from websocket with reconnection logic"""
85 | attempts = 0
86 | while attempts < MAX_RECONNECT_ATTEMPTS:
87 | try:
88 | return self.wss.recv()
89 | except websocket.WebSocketConnectionClosedException:
90 | attempts += 1
91 | logger.warning(f"WebSocket connection closed, attempting to reconnect (attempt {attempts}/{MAX_RECONNECT_ATTEMPTS})...")
92 | self.connected = False
93 |
94 | if attempts < MAX_RECONNECT_ATTEMPTS:
95 | if self.wss_connect():
96 | logger.info("Reconnected successfully, retrying receive operation")
97 | continue
98 | else:
99 | logger.warning(f"Failed to reconnect, waiting {RECONNECT_DELAY} seconds before retry...")
100 | time.sleep(RECONNECT_DELAY)
101 | else:
102 | logger.error("Max reconnection attempts reached in wss_recv")
103 | return None
104 | except Exception as e:
105 | attempts += 1
106 | logger.error(f"Error receiving data from websocket: {str(e)}")
107 | self.connected = False
108 |
109 | if attempts < MAX_RECONNECT_ATTEMPTS:
110 | logger.info(f"Waiting {RECONNECT_DELAY} seconds before retry...")
111 | time.sleep(RECONNECT_DELAY)
112 | if self.wss_connect():
113 | logger.info("Reconnected successfully, retrying receive operation")
114 | continue
115 | else:
116 | logger.error("Max reconnection attempts reached in wss_recv")
117 | return None
118 |
119 | return None
120 |
121 | def get_history(self, prompt_id):
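122 |         """Fetch the execution history for `prompt_id` from ComfyUI's /history endpoint."""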
122 | try:
123 | url = f"http://{self.api_base_url}/history/{prompt_id}"
124 | # Use the http_action module with built-in retry logic
125 | return http_action.do_invocations(url)
126 | except Exception as e:
127 | logger.error(f"Error in get_history: {str(e)}")
128 | return {}
129 |
130 | def queue_prompt(self, prompt):
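131 |         """Submit a workflow to ComfyUI's /prompt endpoint; returns the queue response, or None on error."""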
131 | try:
132 | p = {"prompt": prompt, "client_id": self.client_id}
133 | url = f"http://{self.api_base_url}/prompt"
134 |
135 | # Use the http_action module with built-in retry logic
136 | response = http_action.do_invocations(url, p)
137 | return response
138 | except Exception as e:
139 | logger.error(f"Error in queue_prompt: {str(e)}")
140 | return None
141 |
142 | def get_image(self, filename, subfolder, folder_type):
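143 |         """Download one output file (image or video) from ComfyUI's /view endpoint as bytes."""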
143 | try:
144 | data = {"filename": filename, "subfolder": subfolder, "type": folder_type}
145 | url_values = urllib.parse.urlencode(data)
146 | url = f"http://{self.api_base_url}/view?{url_values}"
147 |
148 | # Use http_action.get which returns bytes directly
149 | return http_action.get(url)
150 | except Exception as e:
151 | logger.error(f"Error getting image {filename}: {str(e)}")
152 | return b'' # Return empty bytes on error
153 |
154 | def track_progress(self, prompt, prompt_id):
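155 |         """Follow websocket progress events until `prompt_id` finishes; returns True on success."""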
155 |         logger.info("Task received, prompt ID: " + prompt_id)
156 | node_ids = list(prompt.keys())
157 | finished_nodes = []
158 | max_errors = 5
159 | error_count = 0
160 |
161 | while True:
162 | try:
163 | out = self.wss_recv() # Using our new method with reconnection logic
164 | if out is None:
165 | error_count += 1
166 | logger.warning(f"Failed to receive data from websocket (error {error_count}/{max_errors})")
167 | if error_count >= max_errors:
168 | logger.error("Too many errors receiving websocket data, aborting track_progress")
169 | return False
170 | time.sleep(1)
171 | continue
172 |
173 | error_count = 0 # Reset error count on successful receive
174 |
175 | if isinstance(out, str):
176 | try:
177 | message = json.loads(out)
178 | logger.debug(out)
179 | if message['type'] == 'progress':
180 | data = message['data']
181 | current_step = data['value']
182 | logger.info(f"In K-Sampler -> Step: {current_step} of: {data['max']}")
183 | if message['type'] == 'execution_cached':
184 | data = message['data']
185 | for itm in data['nodes']:
186 | if itm not in finished_nodes:
187 | finished_nodes.append(itm)
188 | logger.info(f"Progress: {len(finished_nodes)} / {len(node_ids)} tasks done")
189 | if message['type'] == 'executing':
190 | data = message['data']
191 | if data['node'] not in finished_nodes:
192 | finished_nodes.append(data['node'])
193 | logger.info(f"Progress: {len(finished_nodes)} / {len(node_ids)} tasks done")
194 |
195 | if data['node'] is None and data['prompt_id'] == prompt_id:
196 | return True # Execution is done successfully
197 | except json.JSONDecodeError as e:
198 | logger.warning(f"Error parsing websocket message: {str(e)}, skipping message")
199 | continue
200 | except KeyError as e:
201 | logger.warning(f"Missing key in websocket message: {str(e)}, skipping message")
202 | continue
203 | else:
204 | continue
205 | except Exception as e:
206 | error_count += 1
207 | logger.warning(f"Unexpected error in track_progress: {str(e)} (error {error_count}/{max_errors})")
208 | if error_count >= max_errors:
209 | logger.error("Too many errors in track_progress, aborting")
210 | return False
211 | time.sleep(1)
212 |
213 | return True
214 |
215 | def get_images(self, prompt):
216 | max_retries = 3
217 | retry_count = 0
218 |
219 | while retry_count < max_retries:
220 | try:
221 | output = self.queue_prompt(prompt)
222 | if output is None:
223 | raise RuntimeError("Failed to queue prompt - internal error")
224 |
225 | prompt_id = output['prompt_id']
226 | output_images = {}
227 |
228 |                 if not self.track_progress(prompt, prompt_id):
229 |                     raise RuntimeError(f"Failed to track progress for prompt {prompt_id}")
229 |
230 | history = self.get_history(prompt_id)[prompt_id]
231 |                 # Iterate over each node's outputs once (the previous duplicate outer loop was redundant)
232 |                 for node_id in history['outputs']:
233 |                     node_output = history['outputs'][node_id]
234 |                     # image branch
235 |                     if 'images' in node_output:
236 |                         images_output = []
237 |                         for image in node_output['images']:
238 |                             image_data = self.get_image(image['filename'], image['subfolder'], image['type'])
239 |                             images_output.append(image_data)
240 |                         output_images[node_id] = images_output
241 |                     # video branch
242 |                     if 'videos' in node_output:
243 |                         videos_output = []
244 |                         for video in node_output['videos']:
245 |                             video_data = self.get_image(video['filename'], video['subfolder'], video['type'])
246 |                             videos_output.append(video_data)
247 |                         output_images[node_id] = videos_output
248 |
249 | # If we got here, everything worked
250 | return output_images
251 |
252 | except websocket.WebSocketConnectionClosedException as e:
253 | retry_count += 1
254 | logger.warning(f"WebSocket connection closed during processing (attempt {retry_count}/{max_retries})")
255 |
256 | # Try to reconnect before retrying
257 | self.connected = False
258 | if retry_count < max_retries:
259 | logger.info("Attempting to reconnect websocket...")
260 | if self.wss_connect():
261 | logger.info("Reconnected successfully, retrying operation")
262 | time.sleep(1) # Small delay before retry
263 | else:
264 | logger.error("Failed to reconnect websocket")
265 | else:
266 | logger.error(f"Failed after {max_retries} attempts")
267 | raise RuntimeError(f"Failed to process images after {max_retries} attempts") from e
268 |
269 | except Exception as e:
270 | logger.error(f"Error processing images: {str(e)}")
271 | retry_count += 1
272 |
273 | # For non-websocket errors, we might still want to try reconnecting the websocket
274 | if not self.connected and retry_count < max_retries:
275 | logger.info("Attempting to reconnect websocket...")
276 | self.wss_connect()
277 | time.sleep(1) # Small delay before retry
278 | else:
279 | # If it's not a connection issue or we've tried enough times, re-raise
280 | if retry_count >= max_retries:
281 | raise
282 |
283 | # This should not be reached, but just in case
284 | raise RuntimeError(f"Failed to process images after {max_retries} attempts")
285 |
286 |     def parse_workflow(self, prompt_data):
287 |         logger.debug(prompt_data)
288 |         return self.get_images(prompt_data)
289 |
290 |
291 | def check_readiness(api_base_url: str) -> bool:
292 | cf = comfyuiCaller()
293 | cf.setUrl(api_base_url)
294 | logger.info("Init health check... ")
295 | try:
296 | logger.info(f"Try to connect to ComfyUI backend {api_base_url} ... ")
297 | if cf.wss_connect():
298 | logger.info(f"ComfyUI backend {api_base_url} connected.")
299 | return True
300 | else:
301 | logger.error(f"Failed to connect to ComfyUI backend {api_base_url}")
302 | return False
303 | except Exception as e:
304 | logger.error(f"Error during health check: {str(e)}")
305 | return False
306 |
307 |
308 | def handler(api_base_url: str, task_id: str, payload: dict) -> dict:
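309 |     """Entry point for a ComfyUI task: run the workflow and package the response."""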
309 | response = {
310 | "success": False,
311 | "image": [],
312 | "content": '{"code": 500}'
313 | }
314 |
315 | try:
316 | logger.info(f"Processing pipeline task with ID: {task_id}")
317 |
318 | # Attempt to invoke the pipeline
319 | try:
320 | images = invoke_pipeline(api_base_url, payload)
321 |
322 | # Process images if available
323 | imgOutputs = post_invocations(images)
324 | logger.info(f"Received {len(imgOutputs)} images")
325 |
326 | # Set success response
327 | response["success"] = True
328 | response["image"] = imgOutputs
329 | response["content"] = '{"code": 200}'
330 | logger.info(f"End process pipeline task with ID: {task_id}")
331 | except Exception as e:
332 | logger.error(f"Error processing pipeline: {str(e)}")
333 | # Keep default failure response
334 | except Exception as e:
335 | # This is a catch-all for any unexpected errors
336 | logger.error(f"Unexpected error in handler for task ID {task_id}: {str(e)}")
337 | traceback.print_exc()
338 |
339 | return response
340 |
341 | @safe_xray_capture('comfyui-pipeline')
342 | def invoke_pipeline(api_base_url: str, body) -> dict:
343 | cf = comfyuiCaller()
344 | cf.setUrl(api_base_url)
345 |
346 | # Ensure websocket connection is established before proceeding
347 | if not cf.wss_connect():
348 | raise ConnectionError(f"Failed to establish websocket connection to {api_base_url}")
349 |
350 |     return cf.parse_workflow(body)
351 |
352 | def post_invocations(image):
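353 |     """Flatten the per-node output mapping into a flat list of media byte blobs."""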
353 | img_bytes = []
354 |
355 | if len(image) > 0:
356 | for node_id in image:
357 | for image_data in image[node_id]:
358 | img_bytes.append(image_data)
359 |
360 | return img_bytes
--------------------------------------------------------------------------------
/src/backend/queue_agent/src/runtimes/sdwebui.py:
--------------------------------------------------------------------------------
1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2 | # SPDX-License-Identifier: MIT-0
3 |
4 | import base64
5 | import json
6 | import logging
7 | import time
8 | import traceback
9 |
10 | from requests.exceptions import ReadTimeout, HTTPError
11 | from modules import http_action, misc
12 |
13 | logger = logging.getLogger("queue-agent")
14 |
15 | ALWAYSON_SCRIPTS_EXCLUDE_KEYS = ['task', 'id_task', 'uid',
16 | 'sd_model_checkpoint', 'image_link', 'save_dir', 'sd_vae', 'override_settings']
17 |
18 | # Import the safe_xray_capture decorator from main module
19 | try:
20 | from src.main import safe_xray_capture, xray_enabled
21 | except ImportError:
22 | try:
23 | # Try alternative import path
24 | from ..main import safe_xray_capture, xray_enabled
25 | except ImportError:
26 | # Fallback if import fails - create a simple pass-through decorator
27 | logger.warning("Failed to import safe_xray_capture from main, using fallback")
28 | def safe_xray_capture(name):
29 | def decorator(func):
30 | return func
31 | return decorator
32 | xray_enabled = False
33 |
34 | def check_readiness(api_base_url: str, dynamic_sd_model: bool) -> bool:
35 | """Check if SD Web UI is ready by invoking /option endpoint"""
36 | while True:
37 | try:
38 | logger.info('Checking service readiness...')
39 | # checking with options "sd_model_checkpoint" also for caching current model
40 | opts = invoke_get_options(api_base_url)
41 | logger.info('Service is ready.')
42 |             if "sd_model_checkpoint" in opts:
43 |                 if opts['sd_model_checkpoint'] is not None:
44 |                     current_model_name = opts['sd_model_checkpoint']
45 |                     logger.info(f'Init model is: {current_model_name}.')
46 |                 elif dynamic_sd_model:
47 |                     logger.info('Dynamic SD model is enabled, init model is not loaded.')
48 |                 else:
49 |                     logger.error('Init model failed to load.')
50 |             break
52 | except Exception as e:
53 | logger.debug(repr(e))
54 | time.sleep(1)
55 | return True
56 |
57 | def handler(api_base_url: str, task_type: str, task_id: str, payload: dict, dynamic_sd_model: bool) -> dict:
58 | """Main handler for SD Web UI request"""
59 | response = {}
60 | try:
61 | logger.info(f"Start process {task_type} task with ID: {task_id}")
62 | match task_type:
63 |             case 'text-to-image':
64 |                 # Compatibility for v1alpha1: Ensure there is an alwayson_scripts
65 |                 if 'alwayson_scripts' in payload:
66 |                     # Switch model if necessary
67 |                     if dynamic_sd_model and payload['alwayson_scripts'].get('sd_model_checkpoint'):
68 |                         new_model = payload['alwayson_scripts']['sd_model_checkpoint']
69 |                         logger.info(f'Trying to switch model to: {new_model}.')
70 | current_model_name = switch_model(api_base_url, new_model)
71 | if current_model_name is None:
72 | raise Exception(f'Failed to switch model to {new_model}')
73 | logger.info(f'Current model is: {current_model_name}.')
74 | else:
75 | payload.update({'alwayson_scripts': {}})
76 |
77 | task_response = invoke_txt2img(api_base_url, payload)
78 |
79 | case 'image-to-image':
80 |                 # Compatibility for v1alpha1: Ensure there is an alwayson_scripts
81 |                 if 'alwayson_scripts' in payload:
82 |                     # Switch model if necessary
83 |                     if dynamic_sd_model and payload['alwayson_scripts'].get('sd_model_checkpoint'):
84 |                         new_model = payload['alwayson_scripts']['sd_model_checkpoint']
85 |                         logger.info(f'Trying to switch model to: {new_model}.')
86 | current_model_name = switch_model(api_base_url, new_model)
87 | if current_model_name is None:
88 | raise Exception(f'Failed to switch model to {new_model}')
89 | logger.info(f'Current model is: {current_model_name}.')
90 | else:
91 | payload.update({'alwayson_scripts': {}})
92 |
93 | task_response = invoke_img2img(api_base_url, payload)
94 | case 'extra-single-image':
95 | # There is no alwayson_script in API spec
96 | task_response = invoke_extra_single_image(api_base_url, payload)
97 | case 'extra-batch-image':
98 | task_response = invoke_extra_batch_images(api_base_url, payload)
99 |             case _:
100 |                 # Catch-all: fail explicitly so task_response is never left unbound
101 |                 raise Exception(f'Unsupported task type: {task_type}')
102 |
103 | imgOutputs = post_invocations(task_response)
104 | logger.info(f"Received {len(imgOutputs)} images")
105 | content = json.dumps(succeed(task_id, task_response))
106 | response["success"] = True
107 | response["image"] = imgOutputs
108 | response["content"] = content
109 | logger.info(f"End process {task_type} task with ID: {task_id}")
110 | except ReadTimeout as e:
111 | invoke_interrupt(api_base_url)
112 | content = json.dumps(failed(task_id, e))
113 |         logger.error(f"{task_type} task with ID: {task_id} timed out")
114 | traceback.print_exc()
115 | response["success"] = False
116 | response["content"] = content
117 | except Exception as e:
118 | content = json.dumps(failed(task_id, e))
119 | logger.error(f"{task_type} task with ID: {task_id} finished with error")
120 | traceback.print_exc()
121 | response["success"] = False
122 | response["content"] = content
123 | return response
124 |
125 | @safe_xray_capture('text-to-image')
126 | def invoke_txt2img(api_base_url: str, body: dict) -> dict:
127 |     # Compatibility for v1alpha1: Move override_settings from header to body
128 |     override_settings = {}
129 |     if 'override_settings' in body['alwayson_scripts']:
130 |         override_settings.update(body['alwayson_scripts']['override_settings'])
131 |     if override_settings:
132 |         if 'override_settings' in body:
133 |             body['override_settings'].update(override_settings)
134 |         else:
135 |             body.update({'override_settings': override_settings})
136 |
137 |     # Compatibility for v1alpha1: Remove headers used for routing in the v1alpha1 API request
138 |     body.update({'alwayson_scripts': misc.exclude_keys(body['alwayson_scripts'], ALWAYSON_SCRIPTS_EXCLUDE_KEYS)})
139 |
140 |     # Replace any image links elsewhere in the body with their content
141 |     body = download_image(body)
142 |
143 | response = http_action.do_invocations(api_base_url+"txt2img", body)
144 | return response
145 |
146 | @safe_xray_capture('image-to-image')
147 | def invoke_img2img(api_base_url: str, body: dict) -> dict:
148 | """Image-to-Image request"""
149 | # Process image link
150 | body = download_image(body)
151 |
152 |     # Compatibility for v1alpha1: Move override_settings from header to body
153 | override_settings = {}
154 | if 'override_settings' in body['alwayson_scripts']:
155 | override_settings.update(body['alwayson_scripts']['override_settings'])
156 | if override_settings:
157 | if 'override_settings' in body:
158 | body['override_settings'].update(override_settings)
159 | else:
160 | body.update({'override_settings': override_settings})
161 |
162 |     # Compatibility for v1alpha2: Process image link in "alwayson_scripts"
163 |     # Planned for removal in the next release
164 | if 'image_link' in body['alwayson_scripts']:
165 | body.update({"init_images": [body['alwayson_scripts']['image_link']]})
166 |
167 |     # Compatibility for v1alpha1: Remove headers used for routing in the v1alpha1 API request
168 | body.update({'alwayson_scripts': misc.exclude_keys(body['alwayson_scripts'], ALWAYSON_SCRIPTS_EXCLUDE_KEYS)})
169 |
170 | response = http_action.do_invocations(api_base_url+"img2img", body)
171 | return response
172 |
173 | @safe_xray_capture('extra-single-image')
174 | def invoke_extra_single_image(api_base_url: str, body: dict) -> dict:
175 | body = download_image(body)
176 | response = http_action.do_invocations(api_base_url+"extra-single-image", body)
177 | return response
178 |
179 | @safe_xray_capture('extra-batch-images')
180 | def invoke_extra_batch_images(api_base_url: str, body: dict) -> dict:
181 | body = download_image(body)
182 | response = http_action.do_invocations(api_base_url+"extra-batch-images", body)
183 | return response
184 |
185 | def invoke_set_options(api_base_url: str, options: dict) -> str:
186 | return http_action.do_invocations(api_base_url+"options", options)
187 |
188 | def invoke_get_options(api_base_url: str) -> str:
189 | return http_action.do_invocations(api_base_url+"options")
190 |
191 | def invoke_get_model_names(api_base_url: str) -> list:
192 | return sorted([x["title"] for x in http_action.do_invocations(api_base_url+"sd-models")])
193 |
194 | def invoke_refresh_checkpoints(api_base_url: str) -> str:
195 | return http_action.do_invocations(api_base_url+"refresh-checkpoints", {})
196 |
197 | def invoke_unload_checkpoints(api_base_url: str) -> str:
198 | return http_action.do_invocations(api_base_url+"unload-checkpoint", {})
199 |
200 | def invoke_interrupt(api_base_url: str) -> str:
201 | return http_action.do_invocations(api_base_url+"interrupt", {})
202 |
203 | def switch_model(api_base_url: str, name: str) -> str:
204 | opts = invoke_get_options(api_base_url)
205 | current_model_name = opts['sd_model_checkpoint']
206 |
207 | if current_model_name == name:
208 | logger.info(f"Model {current_model_name} is currently loaded, ignore switch.")
209 | else:
210 | # refresh then check from model list
211 | invoke_refresh_checkpoints(api_base_url)
212 | models = invoke_get_model_names(api_base_url)
213 | if name in models:
214 |             if current_model_name is not None:
215 | logger.info(f"Model {current_model_name} is currently loaded, unloading... ")
216 | try:
217 | invoke_unload_checkpoints(api_base_url)
218 | except HTTPError:
219 | logger.info(f"No model is currently loaded. Loading new model... ")
220 | options = {}
221 | options["sd_model_checkpoint"] = name
222 | invoke_set_options(api_base_url, options)
223 | current_model_name = name
224 | else:
225 | logger.error(f"Model {name} not found, keeping current model.")
226 | return None
227 |
228 | return current_model_name
229 |
230 | # Customizable for success responses
231 | def succeed(task_id, response):
232 | parameters = {}
233 | if 'parameters' in response: # text-to-image and image-to-image
234 | parameters = response['parameters']
235 | parameters['id_task'] = task_id
236 | parameters['image_seed'] = ','.join(
237 | str(x) for x in json.loads(response['info'])['all_seeds'])
238 | parameters['error_msg'] = ''
239 | elif 'html_info' in response: # extra-single-image and extra-batch-images
240 | parameters['html_info'] = response['html_info']
241 | parameters['id_task'] = task_id
242 | parameters['error_msg'] = ''
243 | return {
244 | 'images': [''],
245 | 'parameters': parameters,
246 | 'info': ''
247 | }
248 |
249 |
250 | # Customizable for failure responses
251 | def failed(task_id, exception):
252 | parameters = {}
253 | parameters['id_task'] = task_id
254 | parameters['status'] = 0
255 | parameters['error_msg'] = repr(exception)
256 | parameters['reason'] = exception.response.json() if hasattr(exception, "response") else None
257 | return {
258 | 'images': [''],
259 | 'parameters': parameters,
260 | 'info': ''
261 | }
262 |
263 | def download_image(obj, path=""):
264 | """Search URL in object, and replace all URL with content of URL"""
265 | if isinstance(obj, dict):
266 | for key, value in obj.items():
267 | new_path = f"{path}.{key}" if path else key
268 | obj[key] = download_image(value, new_path)
269 | elif isinstance(obj, list):
270 | for index, item in enumerate(obj):
271 | new_path = f"{path}[{index}]"
272 | obj[index] = download_image(item, new_path)
273 | elif isinstance(obj, str):
274 | if (obj.startswith('http') or obj.startswith('s3://')):
275 | logger.info(f"Found URL {obj} in {path}, replacing... ")
276 |         try:
277 |             image_byte = misc.encode_to_base64(http_action.get(obj))
278 |             logger.info(f"Replaced {path} with content")
279 |             return image_byte
280 |         except Exception as e:
281 |             logger.error(f"Error fetching URL: {obj}")
282 |             logger.error(f"Error: {str(e)}")
283 |             # Download failed: fall through and keep the original value
283 | return obj
284 |
285 | def post_invocations(response):
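286 |     """Decode the base64 image payload(s) of an SD Web UI response into raw bytes."""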
286 | img_bytes = []
287 |
288 | if "images" in response.keys():
289 | for i in response["images"]:
290 | img_bytes.append(base64.b64decode(i))
291 |
292 | elif "image" in response.keys():
293 | img_bytes.append(base64.b64decode(response["image"]))
294 |
295 | return img_bytes
--------------------------------------------------------------------------------
/src/charts/sd_on_eks/Chart.yaml:
--------------------------------------------------------------------------------
1 | apiVersion: v2
2 | name: sd-on-eks
3 | description: Stable Diffusion on EKS
4 | # A chart can be either an 'application' or a 'library' chart.
5 | #
6 | # Application charts are a collection of templates that can be packaged into versioned archives
7 | # to be deployed.
8 | #
9 | # Library charts provide useful utilities or functions for the chart developer. They're included as
10 | # a dependency of application charts to inject those utilities and functions into the rendering
11 | # pipeline. Library charts do not define any templates and therefore cannot be deployed.
12 | type: application
13 | # This is the chart version. This version number should be incremented each time you make changes
14 | # to the chart and its templates, including the app version.
15 | # Versions are expected to follow Semantic Versioning (https://semver.org/)
16 | version: 1.1.3
17 | # This is the version number of the application being deployed. This version number should be
18 | # incremented each time you make changes to the application. Versions are not expected to
19 | # follow Semantic Versioning. They should reflect the version the application is using.
20 | # It is recommended to use it with quotes.
21 | appVersion: "1.1.3"
22 |
--------------------------------------------------------------------------------
/src/charts/sd_on_eks/_helpers.tpl:
--------------------------------------------------------------------------------
1 | {{/*
2 | Expand the name of the chart.
3 | */}}
4 | {{- define "sdchart.name" -}}
5 | {{- default .Chart.Name .Values.nameOverride | trunc 63 | trimSuffix "-" }}
6 | {{- end }}
7 |
8 | {{/*
9 | Create a default fully qualified app name.
10 | We truncate at 63 chars because some Kubernetes name fields are limited to this (by the DNS naming spec).
11 | If release name contains chart name it will be used as a full name.
12 | */}}
13 | {{- define "sdchart.fullname" -}}
14 | {{- if .Values.fullnameOverride }}
15 | {{- .Values.fullnameOverride | trunc 63 | trimSuffix "-" }}
16 | {{- else }}
17 | {{- $name := default .Chart.Name .Values.nameOverride }}
18 | {{- if contains $name .Release.Name }}
19 | {{- .Release.Name | trunc 63 | trimSuffix "-" }}
20 | {{- else }}
21 | {{- printf "%s-%s" .Release.Name $name | trunc 63 | trimSuffix "-" }}
22 | {{- end }}
23 | {{- end }}
24 | {{- end }}
25 |
26 | {{/*
27 | Create chart name and version as used by the chart label.
28 | */}}
29 | {{- define "sdchart.chart" -}}
30 | {{- printf "%s-%s" .Chart.Name .Chart.Version | replace "+" "_" | trunc 63 | trimSuffix "-" }}
31 | {{- end }}
32 |
33 | {{/*
34 | Common labels
35 | */}}
36 | {{- define "sdchart.labels" -}}
37 | helm.sh/chart: {{ include "sdchart.chart" . }}
38 | {{ include "sdchart.selectorLabels" . }}
39 | {{- if .Chart.AppVersion }}
40 | app.kubernetes.io/version: {{ .Chart.AppVersion | quote }}
41 | {{- end }}
42 | app.kubernetes.io/managed-by: {{ .Release.Service }}
43 | {{- end }}
44 |
45 | {{/*
46 | Selector labels
47 | */}}
48 | {{- define "sdchart.selectorLabels" -}}
49 | app.kubernetes.io/name: {{ include "sdchart.name" . }}
50 | app.kubernetes.io/instance: {{ .Release.Name }}
51 | {{- end }}
52 |
53 | {{/*
54 | Create the name of the service account to use
55 | */}}
56 | {{- define "sdchart.serviceAccountName" -}}
57 | {{- if .Values.serviceAccount.create }}
58 | {{- default (include "sdchart.fullname" .) .Values.serviceAccount.name }}
59 | {{- else }}
60 | {{- default "default" .Values.serviceAccount.name }}
61 | {{- end }}
62 | {{- end }}
63 |
--------------------------------------------------------------------------------
/src/charts/sd_on_eks/templates/_helpers.tpl:
--------------------------------------------------------------------------------
1 | {{/*
2 | Expand the name of the chart.
3 | */}}
4 | {{- define "sdchart.name" -}}
5 | {{- default .Chart.Name .Values.nameOverride | trunc 63 | trimSuffix "-" }}
6 | {{- end }}
7 |
8 | {{/*
9 | Create a default fully qualified app name.
10 | We truncate at 63 chars because some Kubernetes name fields are limited to this (by the DNS naming spec).
11 | If release name contains chart name it will be used as a full name.
12 | */}}
13 | {{- define "sdchart.fullname" -}}
14 | {{- if .Values.fullnameOverride }}
15 | {{- .Values.fullnameOverride | trunc 63 | trimSuffix "-" }}
16 | {{- else }}
17 | {{- $name := default .Chart.Name .Values.nameOverride }}
18 | {{- if contains $name .Release.Name }}
19 | {{- .Release.Name | trunc 63 | trimSuffix "-" }}
20 | {{- else }}
21 | {{- printf "%s-%s" .Release.Name $name | trunc 63 | trimSuffix "-" }}
22 | {{- end }}
23 | {{- end }}
24 | {{- end }}
25 |
26 | {{/*
27 | Create chart name and version as used by the chart label.
28 | */}}
29 | {{- define "sdchart.chart" -}}
30 | {{- printf "%s-%s" .Chart.Name .Chart.Version | replace "+" "_" | trunc 63 | trimSuffix "-" }}
31 | {{- end }}
32 |
33 | {{/*
34 | Common labels
35 | */}}
36 | {{- define "sdchart.labels" -}}
37 | helm.sh/chart: {{ include "sdchart.chart" . }}
38 | {{ include "sdchart.selectorLabels" . }}
39 | {{- if .Chart.AppVersion }}
40 | app.kubernetes.io/version: {{ .Chart.AppVersion | quote }}
41 | {{- end }}
42 | app.kubernetes.io/managed-by: {{ .Release.Service }}
43 | {{- end }}
44 |
45 | {{/*
46 | Selector labels
47 | */}}
48 | {{- define "sdchart.selectorLabels" -}}
49 | app.kubernetes.io/name: {{ include "sdchart.name" . }}
50 | app.kubernetes.io/instance: {{ .Release.Name }}
51 | {{- end }}
52 |
53 | {{/*
54 | Create the name of the service account to use
55 | */}}
56 | {{- define "sdchart.serviceAccountName" -}}
57 | {{- if .Values.serviceAccount.create }}
58 | {{- default (include "sdchart.fullname" .) .Values.serviceAccount.name }}
59 | {{- else }}
60 | {{- default "default" .Values.serviceAccount.name }}
61 | {{- end }}
62 | {{- end }}
63 |
--------------------------------------------------------------------------------
/src/charts/sd_on_eks/templates/aws-sqs-queue-scaledobject.yaml:
--------------------------------------------------------------------------------
1 | {{- if .Values.runtime.scaling.enabled }}
2 | apiVersion: keda.sh/v1alpha1
3 | kind: ScaledObject
4 | metadata:
5 | name: {{ include "sdchart.fullname" . }}-aws-sqs-queue-scaledobject
6 | namespace: {{ .Release.Namespace }}
7 | labels:
8 | {{- include "sdchart.labels" . | nindent 4 }}
9 | {{- if .Values.runtime.labels }}
10 | {{- toYaml .Values.runtime.labels | nindent 4 }}
11 | {{- end }}
12 | spec:
13 | cooldownPeriod: {{ .Values.runtime.scaling.cooldownPeriod }}
14 | maxReplicaCount: {{ .Values.runtime.scaling.maxReplicaCount }}
15 | minReplicaCount: {{ .Values.runtime.scaling.minReplicaCount }}
16 | pollingInterval: {{ .Values.runtime.scaling.pollingInterval }}
17 | scaleOnInFlight: {{ .Values.runtime.scaling.scaleOnInFlight }}
18 | {{- if .Values.runtime.scaling.extraHPAConfig }}
19 | advanced:
20 | horizontalPodAutoscalerConfig:
21 | behavior:
22 | {{- toYaml .Values.runtime.scaling.extraHPAConfig | nindent 8 }}
23 | {{- end }}
24 | scaleTargetRef:
25 | name: {{ include "sdchart.fullname" . }}-inference-api
26 | triggers:
27 | - authenticationRef:
28 | name: {{ include "sdchart.fullname" . }}-keda-trigger-auth-aws-credentials
29 | metadata:
30 | awsRegion: {{ .Values.global.awsRegion }}
31 | identityOwner: operator
32 | queueLength: {{ quote .Values.runtime.scaling.queueLength }}
33 | queueURL: {{ .Values.runtime.queueAgent.sqsQueueUrl }}
34 | type: aws-sqs-queue
35 | {{- end }}
--------------------------------------------------------------------------------
/src/charts/sd_on_eks/templates/configmap-comfyui.yml:
--------------------------------------------------------------------------------
1 | {{- if (eq "comfyui" .Values.runtime.type) }}
2 | apiVersion: v1
3 | kind: ConfigMap
4 | metadata:
5 | name: {{ include "sdchart.fullname" . }}-comfyui-config
6 | namespace: {{ .Release.Namespace }}
7 | labels:
8 | {{- include "sdchart.labels" . | nindent 4 }}
9 | {{- if .Values.runtime.labels }}
10 | {{- toYaml .Values.runtime.labels | nindent 4 }}
11 | {{- end }}
12 |
13 | {{- if .Values.runtime.annotations }}
14 | annotations:
15 | {{ toYaml .Values.runtime.annotations | nindent 4 }}
16 | {{- end }}
17 | data: {}
18 | {{- end }}
--------------------------------------------------------------------------------
/src/charts/sd_on_eks/templates/configmap-queue-agent.yml:
--------------------------------------------------------------------------------
1 | apiVersion: v1
2 | kind: ConfigMap
3 | metadata:
4 | name: {{ include "sdchart.fullname" . }}-queue-agent-config
5 | namespace: {{ .Release.Namespace }}
6 | labels:
7 | {{- include "sdchart.labels" . | nindent 4 }}
8 | {{- if .Values.runtime.labels }}
9 | {{- toYaml .Values.runtime.labels | nindent 4 }}
10 | {{- end }}
11 |
12 | {{- if .Values.runtime.annotations }}
13 | annotations:
14 | {{ toYaml .Values.runtime.annotations | nindent 4 }}
15 | {{- end }}
16 | data:
17 | AWS_DEFAULT_REGION: {{ quote .Values.global.awsRegion }}
18 | RUNTIME_NAME: {{ quote .Values.global.runtime }}
19 | SQS_QUEUE_URL: {{ quote .Values.runtime.queueAgent.sqsQueueUrl }}
20 | S3_BUCKET: {{ quote .Values.runtime.queueAgent.s3Bucket }}
21 | SNS_TOPIC_ARN: {{ quote .Values.runtime.queueAgent.snsTopicArn }}
22 | RUNTIME_TYPE: {{ quote .Values.runtime.type }}
23 | {{- if .Values.runtime.queueAgent.dynamicModel }}
24 | DYNAMIC_SD_MODEL: "true"
25 | {{- end }}
26 |
--------------------------------------------------------------------------------
/src/charts/sd_on_eks/templates/configmap-sdwebui.yml:
--------------------------------------------------------------------------------
1 | {{- if (eq "sdwebui" .Values.runtime.type) }}
2 | apiVersion: v1
3 | kind: ConfigMap
4 | metadata:
5 | name: {{ include "sdchart.fullname" . }}-sdwebui-config
6 | namespace: {{ .Release.Namespace }}
7 | labels:
8 | {{- include "sdchart.labels" . | nindent 4 }}
9 | {{- if .Values.runtime.labels }}
10 | {{- toYaml .Values.runtime.labels | nindent 4 }}
11 | {{- end }}
12 |
13 | {{- if .Values.runtime.annotations }}
14 | annotations:
15 | {{ toYaml .Values.runtime.annotations | nindent 4 }}
16 | {{- end }}
17 | data:
18 | config.json: |
19 | {"disable_mmap_load_safetensors": true}
20 | {{- end }}
--------------------------------------------------------------------------------
/src/charts/sd_on_eks/templates/deployment-comfyui.yaml:
--------------------------------------------------------------------------------
1 | {{- if (eq "comfyui" .Values.runtime.type) }}
2 | apiVersion: apps/v1
3 | kind: Deployment
4 | metadata:
5 | name: {{ include "sdchart.fullname" . }}-inference-api
6 | namespace: {{ .Release.Namespace }}
7 | labels:
8 | {{- include "sdchart.labels" . | nindent 4 }}
9 | {{- if .Values.runtime.labels }}
10 | {{- toYaml .Values.runtime.labels | nindent 4 }}
11 | {{- end }}
12 | runtime-type: comfyui
13 |
14 | {{- if .Values.runtime.annotations }}
15 | annotations:
16 | {{ toYaml .Values.runtime.annotations | nindent 4 }}
17 | {{- end }}
18 |
19 | spec:
20 | replicas: {{ .Values.runtime.replicas }}
21 | selector:
22 | matchLabels:
23 | app: inference-api
24 | {{- include "sdchart.selectorLabels" . | nindent 6 }}
25 | strategy:
26 | rollingUpdate:
27 | maxSurge: 100%
28 | maxUnavailable: 0
29 | type: RollingUpdate
30 | template:
31 | metadata:
32 | labels:
33 | app: inference-api
34 | {{- include "sdchart.selectorLabels" . | nindent 8 }}
35 | spec:
36 | containers:
37 | - name: inference-api
38 | image: {{ .Values.runtime.inferenceApi.image.repository }}:{{ .Values.runtime.inferenceApi.image.tag }}
39 | env:
40 |         {{- if .Values.runtime.inferenceApi.commandArguments }}
41 | - name: EXTRA_CMD_ARG
42 | value: {{ .Values.runtime.inferenceApi.commandArguments }}
43 | {{- end }}
44 | {{- if .Values.runtime.inferenceApi.extraEnv }}
45 | {{- toYaml .Values.runtime.inferenceApi.extraEnv | nindent 8 }}
46 | {{- end }}
47 | resources:
48 | {{- toYaml .Values.runtime.inferenceApi.resources | nindent 10 }}
49 | volumeMounts:
50 | - mountPath: {{ .Values.runtime.inferenceApi.modelMountPath }}
51 | name: models
52 | imagePullPolicy: {{ .Values.runtime.inferenceApi.imagePullPolicy }}
53 | startupProbe:
54 | httpGet:
55 | path: /system_stats
56 | port: 8080
57 | failureThreshold: 120
58 | periodSeconds: 1
59 | - name: queue-agent
60 | envFrom:
61 | - configMapRef:
62 | name: {{ include "sdchart.fullname" . }}-queue-agent-config
63 | env:
64 | {{- if .Values.runtime.queueAgent.extraEnv }}
65 | {{- toYaml .Values.runtime.queueAgent.extraEnv | nindent 8 }}
66 | {{- end }}
67 | {{- if .Values.runtime.queueAgent.xray.enabled }}
68 | - name: AWS_XRAY_DAEMON_ADDRESS
69 | value: localhost:2000
70 | - name: AWS_XRAY_CONTEXT_MISSING
71 | value: IGNORE_ERROR
72 | {{- else }}
73 | - name: DISABLE_XRAY
74 | value: "true"
75 | {{- end }}
76 | image: {{ .Values.runtime.queueAgent.image.repository }}:{{ .Values.runtime.queueAgent.image.tag }}
77 | imagePullPolicy: {{ .Values.runtime.queueAgent.imagePullPolicy }}
78 | volumeMounts:
79 | - mountPath: /tmp/models
80 | name: models
81 | resources:
82 | {{- toYaml .Values.runtime.queueAgent.resources | nindent 10 }}
83 | {{- if .Values.runtime.queueAgent.xray.enabled }}
84 | - name: xray-daemon
85 | image: {{ .Values.runtime.queueAgent.xray.daemon.image.repository }}:{{ .Values.runtime.queueAgent.xray.daemon.image.tag }}
86 | ports:
87 | - containerPort: 2000
88 | protocol: UDP
89 | {{- end }}
90 | serviceAccountName: {{ .Values.runtime.serviceAccountName }}
91 | terminationGracePeriodSeconds: 60
92 | tolerations:
93 | - effect: NoSchedule
94 | key: nvidia.com/gpu
95 | operator: Exists
96 | - effect: NoSchedule
97 | key: runtime
98 | value: {{ include "sdchart.fullname" . }}
99 | volumes:
100 | - name: models
101 | {{- if .Values.runtime.persistence.enabled }}
102 | persistentVolumeClaim:
103 | {{- if .Values.runtime.persistence.existingClaim }}
104 | claimName: {{ .Values.runtime.persistence.existingClaim }}
105 | {{- else }}
106 | claimName: {{ include "sdchart.fullname" . }}-model-claim
107 | {{- end }}
108 | {{- else }}
109 | emptyDir: {}
110 | {{- end }}
111 | {{- end }}
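112 | 
113 | {{- /* maxSurge 100% with maxUnavailable 0 means an update brings up a full
114 | replacement set of pods (Karpenter provisions additional GPU nodes as needed)
115 | before any old pod is terminated. */}}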
--------------------------------------------------------------------------------
/src/charts/sd_on_eks/templates/deployment-sdwebui.yaml:
--------------------------------------------------------------------------------
1 | {{- if (eq "sdwebui" .Values.runtime.type) }}
2 | apiVersion: apps/v1
3 | kind: Deployment
4 | metadata:
5 | name: {{ include "sdchart.fullname" . }}-inference-api
6 | namespace: {{ .Release.Namespace }}
7 | labels:
8 | {{- include "sdchart.labels" . | nindent 4 }}
9 | {{- if .Values.runtime.labels }}
10 | {{- toYaml .Values.runtime.labels | nindent 4 }}
11 | {{- end }}
12 | runtime-type: sdwebui
13 | {{- if .Values.runtime.annotations }}
14 | annotations:
15 | {{- toYaml .Values.runtime.annotations | nindent 4 }}
16 | {{- end }}
17 |
18 | spec:
19 | replicas: {{ .Values.runtime.replicas }}
20 | selector:
21 | matchLabels:
22 | app: inference-api
23 | {{- include "sdchart.selectorLabels" . | nindent 6 }}
24 | strategy:
25 | rollingUpdate:
26 | maxSurge: 100%
27 | maxUnavailable: 0
28 | type: RollingUpdate
29 | template:
30 | metadata:
31 | labels:
32 | app: inference-api
33 | {{- include "sdchart.selectorLabels" . | nindent 8 }}
34 | spec:
35 | containers:
36 | - name: inference-api
37 | image: {{ .Values.runtime.inferenceApi.image.repository }}:{{ .Values.runtime.inferenceApi.image.tag }}
38 | env:
39 | {{- if .Values.runtime.queueAgent.dynamicModel }}
40 | - name: DYNAMIC_SD_MODEL
41 | value: "true"
42 | {{- else }}
43 | - name: SD_MODEL_CHECKPOINT
44 | value: {{ quote .Values.runtime.inferenceApi.modelFilename }}
45 | {{- end }}
46 | - name: API_ONLY
47 | value: "true"
48 | - name: CONFIG_FILE
49 | value: "/tmp/config.json"
50 | {{- if .Values.runtime.inferenceApi.commandArguments }}
51 | - name: EXTRA_CMD_ARG
52 |             value: {{ quote .Values.runtime.inferenceApi.commandArguments }}
53 | {{- end }}
54 | {{- if .Values.runtime.inferenceApi.extraEnv }}
55 | {{- toYaml .Values.runtime.inferenceApi.extraEnv | nindent 8 }}
56 | {{- end }}
57 | resources:
58 | {{- toYaml .Values.runtime.inferenceApi.resources | nindent 10 }}
59 | volumeMounts:
60 | - mountPath: {{ .Values.runtime.inferenceApi.modelMountPath }}
61 | name: models
62 | - mountPath: "/tmp/config.json"
63 | name: config
64 | subPath: config.json
65 | imagePullPolicy: {{ .Values.runtime.inferenceApi.imagePullPolicy }}
66 | startupProbe:
67 | httpGet:
68 | path: /sdapi/v1/memory
69 | port: 8080
70 | failureThreshold: 120
71 | periodSeconds: 1
72 | - name: queue-agent
73 | envFrom:
74 | - configMapRef:
75 | name: {{ include "sdchart.fullname" . }}-queue-agent-config
76 | env:
77 | {{- if .Values.runtime.queueAgent.extraEnv }}
78 | {{- toYaml .Values.runtime.queueAgent.extraEnv | nindent 8 }}
79 | {{- end }}
80 | {{- if .Values.runtime.queueAgent.xray.enabled }}
81 | - name: AWS_XRAY_DAEMON_ADDRESS
82 | value: localhost:2000
83 | - name: AWS_XRAY_CONTEXT_MISSING
84 | value: IGNORE_ERROR
85 | {{- else }}
86 | - name: DISABLE_XRAY
87 | value: "true"
88 | {{- end }}
89 | image: {{ .Values.runtime.queueAgent.image.repository }}:{{ .Values.runtime.queueAgent.image.tag }}
90 | imagePullPolicy: {{ .Values.runtime.queueAgent.imagePullPolicy }}
91 | resources:
92 | {{- toYaml .Values.runtime.queueAgent.resources | nindent 10 }}
93 | {{- if .Values.runtime.queueAgent.xray.enabled }}
94 | - name: xray-daemon
95 | image: {{ .Values.runtime.queueAgent.xray.daemon.image.repository }}:{{ .Values.runtime.queueAgent.xray.daemon.image.tag }}
96 | ports:
97 | - containerPort: 2000
98 | protocol: UDP
99 | {{- end }}
100 | serviceAccountName: {{ .Values.runtime.serviceAccountName }}
101 | terminationGracePeriodSeconds: 60
102 | tolerations:
103 | - effect: NoSchedule
104 | key: nvidia.com/gpu
105 | operator: Exists
106 | - effect: NoSchedule
107 | key: runtime
108 | value: {{ include "sdchart.fullname" . }}
109 | volumes:
110 | - name: models
111 | {{- if .Values.runtime.persistence.enabled }}
112 | persistentVolumeClaim:
113 | {{- if .Values.runtime.persistence.existingClaim }}
114 | claimName: {{ .Values.runtime.persistence.existingClaim }}
115 | {{- else }}
116 | claimName: {{ include "sdchart.fullname" . }}-model-claim
117 | {{- end }}
118 | {{- else }}
119 | emptyDir: {}
120 | {{- end }}
121 | - name: config
122 | configMap:
123 | name: {{ include "sdchart.fullname" . }}-sdwebui-config
124 | {{- end }}
--------------------------------------------------------------------------------
/src/charts/sd_on_eks/templates/keda-trigger-auth-aws-credentials.yaml:
--------------------------------------------------------------------------------
1 | {{- if .Values.runtime.scaling.enabled }}
2 | apiVersion: keda.sh/v1alpha1
3 | kind: TriggerAuthentication
4 | metadata:
5 | name: {{ include "sdchart.fullname" . }}-keda-trigger-auth-aws-credentials
6 | namespace: {{ .Release.Namespace }}
7 | labels:
8 | {{- include "sdchart.labels" . | nindent 4 }}
9 | {{- if .Values.runtime.labels }}
10 | {{- toYaml .Values.runtime.labels | nindent 4 }}
11 | {{- end }}
12 | spec:
13 | podIdentity:
14 | provider: aws-eks
15 | {{- end }}
--------------------------------------------------------------------------------
/src/charts/sd_on_eks/templates/nodeclass.yaml:
--------------------------------------------------------------------------------
1 | apiVersion: karpenter.k8s.aws/v1
2 | kind: EC2NodeClass
3 | metadata:
4 | name: {{ include "sdchart.fullname" . }}-nodeclass-gpu
5 | labels:
6 | {{- include "sdchart.labels" . | nindent 4 }}
7 | {{- if .Values.runtime.labels }}
8 | {{- toYaml .Values.runtime.labels | nindent 4 }}
9 | {{- end }}
10 | spec:
11 | amiSelectorTerms:
12 | - alias: {{ lower .Values.karpenter.nodeTemplate.amiFamily }}@latest
13 |
14 | subnetSelectorTerms:
15 | - tags:
16 | "aws-cdk:subnet-type": "Private"
17 | "aws:cloudformation:stack-name": {{ .Values.global.stackName }}
18 |
19 | securityGroupSelectorTerms:
20 | - tags:
21 | "aws:eks:cluster-name": {{ quote .Values.global.stackName }}
22 |
23 | role: {{ .Values.karpenter.nodeTemplate.iamRole }}
24 |
25 | tags:
26 | stack: {{ .Values.global.stackName }}
27 | runtime: {{ .Release.Name }}
28 | {{- if .Values.karpenter.nodeTemplate.tags }}
29 | {{- toYaml .Values.karpenter.nodeTemplate.tags | nindent 4 }}
30 | {{- end }}
31 |
32 | metadataOptions:
33 | httpEndpoint: enabled
34 | httpProtocolIPv6: disabled
35 | httpPutResponseHopLimit: 2
36 | httpTokens: optional
37 |
38 | blockDeviceMappings:
39 | - deviceName: /dev/xvda
40 | ebs:
41 | volumeSize: {{ .Values.karpenter.nodeTemplate.osVolume.volumeSize }}
42 | volumeType: {{ .Values.karpenter.nodeTemplate.osVolume.volumeType }}
43 | deleteOnTermination: {{ .Values.karpenter.nodeTemplate.osVolume.deleteOnTermination }}
44 | iops: {{ .Values.karpenter.nodeTemplate.osVolume.iops }}
45 | throughput: {{ .Values.karpenter.nodeTemplate.osVolume.throughput }}
46 | {{- if .Values.karpenter.nodeTemplate.dataVolume }}
47 | - deviceName: /dev/xvdb
48 | ebs:
49 | volumeSize: {{ .Values.karpenter.nodeTemplate.dataVolume.volumeSize }}
50 | volumeType: {{ .Values.karpenter.nodeTemplate.dataVolume.volumeType }}
51 | deleteOnTermination: {{ .Values.karpenter.nodeTemplate.dataVolume.deleteOnTermination }}
52 | iops: {{ .Values.karpenter.nodeTemplate.dataVolume.iops }}
53 | throughput: {{ .Values.karpenter.nodeTemplate.dataVolume.throughput }}
54 | {{- if .Values.karpenter.nodeTemplate.dataVolume.snapshotID }}
55 | snapshotID: {{ .Values.karpenter.nodeTemplate.dataVolume.snapshotID }}
56 | {{- end }}
57 | {{- end }}
58 |
59 | {{- if .Values.karpenter.nodeTemplate.userData }}
60 | userData: |-
61 | {{- tpl .Values.karpenter.nodeTemplate.userData . | nindent 4 }}
62 | {{- end }}
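62 | 
63 | {{- /* With the chart default amiFamily (Bottlerocket), the amiSelectorTerms
64 | alias resolves to "bottlerocket@latest". */}}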
--------------------------------------------------------------------------------
/src/charts/sd_on_eks/templates/nodepool.yaml:
--------------------------------------------------------------------------------
1 | apiVersion: karpenter.sh/v1
2 | kind: NodePool
3 | metadata:
4 | name: {{ include "sdchart.fullname" . }}-nodepool-gpu
5 | labels:
6 | {{- include "sdchart.labels" . | nindent 4 }}
7 | {{- if .Values.runtime.labels }}
8 | {{- toYaml .Values.runtime.labels | nindent 4 }}
9 | {{- end }}
10 | spec:
11 | template:
12 | metadata:
13 | {{- if .Values.karpenter.provisioner.labels }}
14 | labels:
15 | {{- toYaml .Values.karpenter.provisioner.labels | nindent 8 }}
16 | {{- end }}
17 | {{- if .Values.karpenter.provisioner.annotations }}
18 | annotations:
19 | {{- toYaml .Values.karpenter.provisioner.annotations | nindent 8 }}
20 | {{- end }}
21 | spec:
22 | nodeClassRef:
23 | group: karpenter.k8s.aws
24 | kind: EC2NodeClass
25 | name: {{ include "sdchart.fullname" . }}-nodeclass-gpu
26 | taints:
27 | - effect: NoSchedule
28 | key: nvidia.com/gpu
29 | - effect: NoSchedule
30 | key: runtime
31 | value: {{ include "sdchart.fullname" . }}
32 | {{- if .Values.karpenter.provisioner.extraTaints }}
33 | {{- toYaml .Values.karpenter.provisioner.extraTaints | nindent 6 }}
34 | {{- end }}
35 | {{- if .Values.karpenter.provisioner.disruption.expireAfter }}
36 | expireAfter: {{ .Values.karpenter.provisioner.disruption.expireAfter }}
37 | {{- end }}
38 | requirements:
39 | - key: karpenter.sh/capacity-type
40 | operator: In
41 | values:
42 | {{- if .Values.karpenter.provisioner.capacityType.spot }}
43 | - spot
44 | {{- end }}
45 | {{- if .Values.karpenter.provisioner.capacityType.onDemand }}
46 | - on-demand
47 | {{- end }}
48 | {{- if .Values.karpenter.provisioner.instanceType }}
49 | - key: node.kubernetes.io/instance-type
50 | operator: In
51 | values:
52 | {{- toYaml .Values.karpenter.provisioner.instanceType | nindent 8 }}
53 | {{- end }}
54 | {{- if .Values.karpenter.provisioner.extraRequirements }}
55 | {{- toYaml .Values.karpenter.provisioner.extraRequirements | nindent 6 }}
56 | {{- end }}
57 | disruption:
58 | {{- if .Values.karpenter.provisioner.consolidation }}
59 | consolidationPolicy: WhenEmptyOrUnderutilized
60 | {{- else }}
61 | consolidationPolicy: WhenEmpty
62 | {{- end }}
63 | consolidateAfter: {{ .Values.karpenter.provisioner.disruption.consolidateAfter }}
64 | {{- if .Values.karpenter.provisioner.resourceLimits }}
65 | limits:
66 | {{- toYaml .Values.karpenter.provisioner.resourceLimits | nindent 4 }}
67 | {{- end }}
68 |
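69 | {{- /* consolidation: true (the chart default) renders consolidationPolicy
70 | WhenEmptyOrUnderutilized; with false, only empty nodes are consolidated.
71 | Both policies wait disruption.consolidateAfter (30s by default). */}}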
--------------------------------------------------------------------------------
/src/charts/sd_on_eks/templates/persistentvolume-s3.yaml:
--------------------------------------------------------------------------------
1 | {{- if and (.Values.runtime.persistence.enabled) (.Values.runtime.persistence.s3.enabled) }}
2 | apiVersion: v1
3 | kind: PersistentVolume
4 | metadata:
5 | name: {{ include "sdchart.fullname" . }}-s3-model-volume
6 | {{- if .Values.runtime.persistence.annotations }}
7 | annotations:
8 | {{- toYaml .Values.runtime.persistence.annotations | nindent 4 }}
9 | {{- end }}
10 | labels:
11 | {{- include "sdchart.labels" . | nindent 4 }}
12 | {{- if .Values.runtime.labels }}
13 | {{- toYaml .Values.runtime.labels | nindent 4 }}
14 | {{- end }}
15 | {{- if .Values.runtime.persistence.labels }}
16 | {{- toYaml .Values.runtime.persistence.labels | nindent 4 }}
17 | {{- end }}
18 | spec:
19 | capacity:
20 | storage: {{ .Values.runtime.persistence.size }}
21 | accessModes:
22 | {{- toYaml .Values.runtime.persistence.accessModes | nindent 2 }}
23 | mountOptions:
24 | - allow-delete
25 | - allow-other
26 | - file-mode=777
27 | - dir-mode=777
28 | csi:
29 | driver: s3.csi.aws.com
30 | volumeHandle: s3-csi-driver-volume
31 | volumeAttributes:
32 | bucketName: {{ .Values.runtime.persistence.s3.modelBucket }}
33 | {{- end }}
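34 | 
35 | {{- /* Served by the Mountpoint for Amazon S3 CSI driver (s3.csi.aws.com); the
36 | allow-other and 777 file/dir mode options let the runtime containers read and
37 | write the model bucket without uid/gid mapping. */}}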
--------------------------------------------------------------------------------
/src/charts/sd_on_eks/templates/persistentvolumeclaim.yaml:
--------------------------------------------------------------------------------
1 | {{- if and (.Values.runtime.persistence.enabled) (not (.Values.runtime.persistence.existingClaim)) }}
2 | apiVersion: v1
3 | kind: PersistentVolumeClaim
4 | metadata:
5 | name: {{ include "sdchart.fullname" . }}-model-claim
6 | {{- if .Values.runtime.persistence.annotations }}
7 | annotations:
8 | {{- toYaml .Values.runtime.persistence.annotations | nindent 4 }}
9 | {{- end }}
10 | labels:
11 | {{- include "sdchart.labels" . | nindent 4 }}
12 | {{- if .Values.runtime.labels }}
13 | {{- toYaml .Values.runtime.labels | nindent 4 }}
14 | {{- end }}
15 | {{- if .Values.runtime.persistence.labels }}
16 | {{- toYaml .Values.runtime.persistence.labels | nindent 4 }}
17 | {{- end }}
18 | spec:
19 | accessModes:
20 | {{- toYaml .Values.runtime.persistence.accessModes | nindent 2 }}
21 | resources:
22 | requests:
23 | storage: "{{ .Values.runtime.persistence.size }}"
24 | {{- if .Values.runtime.persistence.storageClass }}
25 | {{- if (eq "-" .Values.runtime.persistence.storageClass) }}
26 | storageClassName: ""
27 | {{- else }}
28 | storageClassName: "{{ .Values.runtime.persistence.storageClass }}"
29 | {{- end }}
30 | {{- end }}
31 | {{- if .Values.runtime.persistence.existingVolume }}
32 | volumeName: "{{ .Values.runtime.persistence.existingVolume }}"
33 | {{- end }}
34 | {{- end }}
--------------------------------------------------------------------------------
/src/charts/sd_on_eks/values.yaml:
--------------------------------------------------------------------------------
1 | global:
2 | awsRegion: us-west-2
3 | stackName: ""
4 | runtime: ""
5 |
6 | karpenter:
7 | provisioner:
8 | labels: {}
9 | capacityType:
10 | onDemand: true
11 | spot: true
12 | instanceType:
13 | - "g5.xlarge"
14 | - "g5.2xlarge"
15 | extraRequirements: []
16 | extraTaints: []
17 | resourceLimits:
18 | nvidia.com/gpu: 100
19 | consolidation: true
20 | disruption:
21 | consolidateAfter: 30s
22 | expireAfter: Never
23 | nodeTemplate:
24 | iamRole: ""
25 | securityGroupSelector: {}
26 | subnetSelector: {}
27 | tags: {}
28 | amiFamily: Bottlerocket
29 | osVolume:
30 | volumeSize: 10Gi
31 | volumeType: gp3
32 | deleteOnTermination: true
33 | iops: 3000
34 | throughput: 125
35 | dataVolume:
36 | volumeSize: 150Gi
37 | volumeType: gp3
38 | deleteOnTermination: true
39 | iops: 4000
40 | throughput: 1000
41 | userData: ""
42 | runtime:
43 | type: "sdwebui"
44 | labels: {}
45 | annotations: {}
46 | serviceAccountName: runtime-sa
47 | replicas: 1
48 | scaling:
49 | enabled: true
50 | queueLength: 10
51 | cooldownPeriod: 60
52 | maxReplicaCount: 20
53 | minReplicaCount: 0
54 | pollingInterval: 1
55 | scaleOnInFlight: false
56 | extraHPAConfig: {}
57 | inferenceApi:
58 | image:
59 | repository: public.ecr.aws/bingjiao/sd-on-eks/sdwebui
60 | tag: latest
61 | modelFilename: ""
62 | modelMountPath: /opt/ml/code/models
63 | commandArguments: ""
64 | extraEnv: {}
65 | imagePullPolicy: IfNotPresent
66 | resources:
67 | limits:
68 | nvidia.com/gpu: "1"
69 | requests:
70 | nvidia.com/gpu: "1"
71 | cpu: 2500m
72 | memory: 6Gi
73 | queueAgent:
74 | image:
75 | repository: public.ecr.aws/bingjiao/sd-on-eks/queue-agent
76 | tag: latest
77 | extraEnv: {}
78 | dynamicModel: false
79 | imagePullPolicy: IfNotPresent
80 | s3Bucket: ""
81 | snsTopicArn: ""
82 | sqsQueueUrl: ""
83 | resources:
84 | requests:
85 | cpu: 500m
86 | memory: 512Mi
87 | xray:
88 | enabled: true
89 | daemon:
90 | image:
91 | repository: public.ecr.aws/xray/aws-xray-daemon
92 | tag: 3.3.14
93 | persistence:
94 | enabled: true
95 | existingClaim: ""
96 | existingVolume: ""
97 | labels: {}
98 | annotations: {}
99 | storageClass: ""
100 | size: 2Ti
101 | accessModes:
102 | - ReadWriteMany
103 | s3:
104 | enabled: true
105 | modelBucket: ""
106 |
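107 | # A minimal per-release override sketch; every name below (region, stack, role,
108 | # buckets, queue URL) is a placeholder, not a value shipped with the chart:
109 | #
110 | # global:
111 | #   awsRegion: us-east-1
112 | #   stackName: my-sd-stack
113 | # karpenter:
114 | #   nodeTemplate:
115 | #     iamRole: my-karpenter-node-role
116 | # runtime:
117 | #   inferenceApi:
118 | #     modelFilename: sd_xl_turbo_1.0.safetensors
119 | #   queueAgent:
120 | #     s3Bucket: my-output-bucket
121 | #     snsTopicArn: arn:aws:sns:us-east-1:111122223333:my-notify-topic
122 | #     sqsQueueUrl: https://sqs.us-east-1.amazonaws.com/111122223333/my-task-queue
123 | #   persistence:
124 | #     s3:
125 | #       modelBucket: my-model-bucket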
--------------------------------------------------------------------------------
/src/frontend/input_function/v1alpha1/app.py:
--------------------------------------------------------------------------------
1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2 | # SPDX-License-Identifier: MIT-0
3 |
4 | import os
5 | import boto3
6 | import json
7 | import traceback
8 |
9 | sns_client = boto3.client('sns')
10 |
11 |
12 | def lambda_handler(event, context):
13 | if event['httpMethod'] == 'POST':
14 | try:
15 | payload = json.loads(event['body'])
16 | val = validate(payload)
17 |             if val != "success":
18 | return {
19 | 'statusCode': 400,
20 | 'body': f"Incorrect payload structure, {val}"
21 | }
22 |
23 | task = payload['alwayson_scripts']['task']
24 | id_task = payload['alwayson_scripts']['id_task']
25 | sd_model_checkpoint = payload['alwayson_scripts']['sd_model_checkpoint']
26 | prefix = payload['alwayson_scripts']['save_dir']
27 | s3_output_path = f"{os.environ['S3_OUTPUT_BUCKET']}/{prefix}/{id_task}"
28 |
29 | payload['s3_output_path'] = s3_output_path
30 |
31 | print(event['headers'])
32 | print(event['queryStringParameters'])
33 |
34 | msg = {"metadata": {
35 | "id": id_task,
36 | "runtime": "legacy",
37 | "tasktype": task,
38 | "prefix": prefix,
39 | "context": {}
40 |             }, "content": payload}
41 |
42 | sns_client.publish(
43 | TargetArn=os.environ['SNS_TOPIC_ARN'],
44 | Message=json.dumps(msg),
45 | MessageAttributes={
46 | 'sd_model_checkpoint': {
47 | 'DataType': 'String',
48 | 'StringValue': sd_model_checkpoint
49 | }
50 | }
51 | )
52 |
53 | return {
54 | 'statusCode': 200,
55 | 'body': json.dumps({
56 | "id_task": id_task,
57 | "sd_model_checkpoint": sd_model_checkpoint,
58 | "output_location": f"s3://{s3_output_path}"
59 | })
60 | }
61 |
62 | except Exception as e:
63 | traceback.print_exc()
64 | return {
65 | 'statusCode': 400,
66 | 'body': str(e)
67 | }
68 | else:
69 | return {
70 | 'statusCode': 400,
71 | 'body': "Unsupported HTTP method"
72 | }
73 |
74 | def validate(body: dict) -> str:
75 | result = "success"
76 | if 'alwayson_scripts' not in body.keys():
77 | result = "alwayson_scripts is missing"
78 | else:
79 | if "task" not in body["alwayson_scripts"].keys():
80 | result = "task is missing"
81 | if "sd_model_checkpoint" not in body["alwayson_scripts"].keys():
82 | result = "sd_model_checkpoint is missing"
83 | if "id_task" not in body["alwayson_scripts"].keys():
84 | result = "id_task is missing"
85 | return result
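86 | 
87 | # Local smoke test for the validator (a sketch; it is not invoked by Lambda).
88 | # The sample mirrors test/v1alpha1/t2i.json:
89 | if __name__ == "__main__":
90 |     sample = {
91 |         "alwayson_scripts": {
92 |             "task": "text-to-image",
93 |             "sd_model_checkpoint": "sd_xl_turbo_1.0.safetensors",
94 |             "id_task": "test-t2i",
95 |             "save_dir": "outputs"
96 |         }
97 |     }
98 |     assert validate(sample) == "success"
99 |     assert validate({}) == "alwayson_scripts is missing"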
--------------------------------------------------------------------------------
/src/frontend/input_function/v1alpha2/app.py:
--------------------------------------------------------------------------------
1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2 | # SPDX-License-Identifier: MIT-0
3 |
4 | import json
5 | import os
6 | import traceback
7 |
8 | import boto3
9 |
10 | sns_client = boto3.client('sns')
11 |
12 | def lambda_handler(event, context):
13 | if event['httpMethod'] == 'POST':
14 | try:
15 | payload = json.loads(event['body'])["task"]
16 | val = validate(payload)
17 | if val != "success":
18 | return {
19 | 'statusCode': 400,
20 | 'body': 'Incorrect payload structure, ' + val
21 | }
22 |
23 |             task_id = payload["metadata"]["id"]
24 |             runtime = payload["metadata"]["runtime"]
25 |             prefix = payload["metadata"]["prefix"]
26 |             s3_output_path = f"{os.environ['S3_OUTPUT_BUCKET']}/{prefix}/{task_id}"
27 | 
28 |             print(event['headers'])
29 |             print(event['queryStringParameters'])
30 | 
31 |             sns_client.publish(
32 |                 TargetArn=os.environ['SNS_TOPIC_ARN'],
33 |                 Message=json.dumps(payload),
34 |                 MessageAttributes={
35 |                     'runtime': {
36 |                         'DataType': 'String',
37 |                         'StringValue': runtime
38 |                     }
39 |                 }
40 |             )
41 | 
42 |             return {
43 |                 'statusCode': 200,
44 |                 'body': json.dumps({
45 |                     "id": task_id,
46 |                     "runtime": runtime,
47 |                     "output_location": f"s3://{s3_output_path}"
48 |                 })
49 |             }
50 |
51 | except Exception as e:
52 | traceback.print_exc()
53 | return {
54 | 'statusCode': 400,
55 | 'body': str(e)
56 | }
57 | else:
58 | return {
59 | 'statusCode': 400,
60 | 'body': "Unsupported HTTP method"
61 | }
62 |
63 |
64 | def validate(body: dict) -> str:
65 | result = "success"
66 | if "metadata" not in body.keys():
67 | result = "metadata is missing"
68 | else:
69 | if "id" not in body["metadata"].keys():
70 | result = "id is missing"
71 | if "runtime" not in body["metadata"].keys():
72 | result = "runtime is missing"
73 | if "tasktype" not in body["metadata"].keys():
74 | result = "tasktype is missing"
75 | if "content" not in body.keys():
76 | result = "content is missing"
77 | return result
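78 | 
79 | # Local smoke test for the validator (a sketch; it is not invoked by Lambda).
80 | # The sample mirrors the "task" object in test/v1alpha2/t2i.json:
81 | if __name__ == "__main__":
82 |     sample = {
83 |         "metadata": {"id": "test-t2i", "runtime": "sdruntime", "tasktype": "text-to-image"},
84 |         "content": {"prompt": "A dog", "steps": 16}
85 |     }
86 |     assert validate(sample) == "success"
87 |     assert validate({"metadata": sample["metadata"]}) == "content is missing"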
--------------------------------------------------------------------------------
/src/tools/ebs_throughput_tuner/app.py:
--------------------------------------------------------------------------------
1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2 | # SPDX-License-Identifier: MIT-0
3 |
4 | import os
5 | 
6 | import boto3
7 | 
8 | def get_ebs_volume_id(instance_id):
9 |     # Pick the volume behind the last block device mapping; in this stack's
10 |     # node class that is the data volume (/dev/xvdb).
11 |     ec2 = boto3.client('ec2')
12 |     try:
13 |         response = ec2.describe_instances(InstanceIds=[instance_id])
14 |         volume_id = response['Reservations'][0]['Instances'][0]['BlockDeviceMappings'][-1]['Ebs']['VolumeId']
15 |         return volume_id
16 |     except Exception as e:
17 |         print("Error:", e)
18 |         return None
19 | 
20 | def get_instance_tag(instance, key):
21 |     # Return the value of the given tag key, or None if the tag is absent.
22 |     for tag in instance.get('Tags', []):
23 |         if tag['Key'] == key:
24 |             return tag['Value']
25 |     return None
26 | 
27 | def modify_ebs_throughput_and_iops(volume_id, throughput, iops):
28 |     ec2 = boto3.client('ec2')
29 |     try:
30 |         ec2.modify_volume(VolumeId=volume_id, Throughput=throughput, Iops=iops)
31 |         print(f"Successfully modified EBS throughput of {volume_id}")
32 |     except Exception as e:
33 |         print("Error:", e)
34 | 
35 | # Entrypoint
36 | def lambda_handler(event, context):
37 |     ec2 = boto3.client('ec2')
38 |     instance_id = event['detail']['instance-id']
39 |     print(f"Processing instance {instance_id}")
40 |     target_ec2_tag_key = os.environ['TARGET_EC2_TAG_KEY']
41 |     target_ec2_tag_value = os.environ['TARGET_EC2_TAG_VALUE']
42 |     # Get the instance tag
43 |     response = ec2.describe_instances(InstanceIds=[instance_id])
44 |     instance = response['Reservations'][0]['Instances'][0]
45 |     instance_tag_value = get_instance_tag(instance, target_ec2_tag_key)
46 |     volume_id = get_ebs_volume_id(instance_id)
47 |     if instance_tag_value:
48 |         # Modify EBS throughput only when the tag value matches the target
49 |         if target_ec2_tag_value in instance_tag_value:
50 |             print(f"Got matching tag from {instance_id}, processing...")
51 |             throughput_value = int(os.environ['THROUGHPUT_VALUE'])
52 |             iops_value = int(os.environ['IOPS_VALUE'])
53 |             modify_ebs_throughput_and_iops(volume_id, throughput_value, iops_value)
54 |         else:
55 |             print(f"Skipped {instance_id}")
56 | 
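57 | # A minimal sketch of the EventBridge instance state-change event this handler
58 | # consumes; the instance ID is illustrative only:
59 | #
60 | # {
61 | #     "detail-type": "EC2 Instance State-change Notification",
62 | #     "detail": {"instance-id": "i-0123456789abcdef0", "state": "running"}
63 | # }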
--------------------------------------------------------------------------------
/test/run.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | set -e
4 |
5 | SCRIPTPATH=$(realpath "$(dirname "$0")")
6 | STACK_NAME=${STACK_NAME:-"sdoneksStack"}
7 | AWS_DEFAULT_REGION=${AWS_DEFAULT_REGION:-$(aws ec2 describe-availability-zones --output text --query 'AvailabilityZones[0].[RegionName]')}
8 | declare -l RUNTIME_TYPE=${RUNTIME_TYPE:-"sdwebui"}
9 | API_VERSION=${API_VERSION:-"v1alpha2"}
10 |
11 | API_ENDPOINT=$(aws cloudformation describe-stacks --stack-name ${STACK_NAME} --output text --query 'Stacks[0].Outputs[?OutputKey==`FrontApiEndpoint`].OutputValue')
12 |
13 | printf "API Endpoint is %s\n" "${API_ENDPOINT}"
14 |
15 | API_KEY_COMMAND=$(aws cloudformation describe-stacks --stack-name ${STACK_NAME} --output text --query 'Stacks[0].Outputs[?OutputKey==`GetAPIKeyCommand`].OutputValue')
16 |
17 | API_KEY=$(echo "$API_KEY_COMMAND" | bash)
18 |
19 | printf "API Key is %s\n" "${API_KEY}"
20 |
21 | if [[ ${RUNTIME_TYPE} == "sdwebui" ]]
22 | then
23 | printf "Generating test text-to-image request... \n"
24 |
25 | curl -X POST ${API_ENDPOINT}/${API_VERSION}/ \
26 | -H "Content-Type: application/json" \
27 | -H "x-api-key: ${API_KEY}" \
28 | -d @${SCRIPTPATH}/${API_VERSION}/t2i.json
29 |
30 | printf "\nGenerating test image-to-image request... \n"
31 |
32 | curl -X POST ${API_ENDPOINT}/${API_VERSION}/ \
33 | -H "Content-Type: application/json" \
34 | -H "x-api-key: ${API_KEY}" \
35 | -d @${SCRIPTPATH}/${API_VERSION}/i2i.json
36 |
37 | printf "\nGenerating image upscaling request... \n"
38 |
39 | curl -X POST ${API_ENDPOINT}/${API_VERSION}/ \
40 | -H "Content-Type: application/json" \
41 | -H "x-api-key: ${API_KEY}" \
42 | -d @${SCRIPTPATH}/${API_VERSION}/extra-single-image.json
43 | fi
44 |
45 | if [[ ${RUNTIME_TYPE} == "comfyui" ]]
46 | then
47 | printf "Generating test pipeline request... \n"
48 |
49 | curl -X POST ${API_ENDPOINT}/${API_VERSION} \
50 | -H "Content-Type: application/json" \
51 | -H "x-api-key: ${API_KEY}" \
52 | -d @${SCRIPTPATH}/${API_VERSION}/pipeline.json
53 | fi
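54 | 
55 | # Example invocations (stack name and region are environment-specific):
56 | #   ./run.sh                          # sdwebui runtime, v1alpha2 API
57 | #   RUNTIME_TYPE=comfyui ./run.sh     # comfyui pipeline request
58 | #   STACK_NAME=myStack AWS_DEFAULT_REGION=us-east-1 ./run.sh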
--------------------------------------------------------------------------------
/test/v1alpha1/i2i.json:
--------------------------------------------------------------------------------
1 | {
2 | "alwayson_scripts": {
3 | "task": "image-to-image",
4 | "image_link": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/cat.png",
5 | "id_task": "test-i2i",
6 | "sd_model_checkpoint": "sd_xl_turbo_1.0.safetensors",
7 | "save_dir": "outputs"
8 | },
9 | "prompt": "cat wizard, gandalf, lord of the rings, detailed, fantasy, cute, adorable, Pixar, Disney, 8k",
10 | "steps": 16,
11 | "width": 512,
12 | "height": 512
13 | }
--------------------------------------------------------------------------------
/test/v1alpha1/t2i.json:
--------------------------------------------------------------------------------
1 | {
2 | "alwayson_scripts": {
3 | "task": "text-to-image",
4 | "sd_model_checkpoint": "sd_xl_turbo_1.0.safetensors",
5 | "id_task": "test-t2i",
6 | "save_dir": "outputs"
7 | },
8 | "prompt": "A dog",
9 | "steps": 16,
10 | "width": 512,
11 | "height": 512
12 | }
--------------------------------------------------------------------------------
/test/v1alpha1/t2v.json:
--------------------------------------------------------------------------------
1 | {
2 | "alwayson_scripts": {
3 | "task": "text-to-image",
4 | "sd_model_checkpoint": "sd_xl_turbo_1.0.safetensors",
5 | "id_task": "gif1",
6 | "uid": "gif1",
7 | "save_dir": "outputs",
8 | "AnimateDiff": {
9 | "args": [
10 | {
11 | "batch_size": 16,
12 | "closed_loop": "N",
13 | "enable": true,
14 | "format": [
15 | "GIF"
16 | ],
17 | "fps": 8,
18 | "interp": "Off",
19 | "interp_x": 10,
20 | "last_frame": null,
21 | "latent_power": 1,
22 | "latent_scale": 32,
23 | "loop_number": 0,
24 | "model": "mm_sd15_v3.safetensors",
25 | "overlap": -1,
26 | "stride": 1,
27 | "video_length": 32
28 | }
29 | ]
30 | }
31 | },
32 | "batch_size": 32,
33 | "cfg_scale": 7,
34 | "height": 512,
35 | "negative_prompt": "(worst quality:1.2), (low quality:1.2), (lowres:1.1), (monochrome:1.1), (greyscale), multiple views, comic, sketch, (((bad anatomy))), (((deformed))), (((disfigured))), watermark, multiple_views, mutation hands, mutation fingers, extra fingers, missing fingers, watermark, nude, nsfw",
36 | "prompt": "masterpiece, 30 year old women, cleavage, red hair, bun, ponytail, medium breast, desert, cactus vibe, sensual pose, (looking in the camera:1.2), (front view:1.2), facing the camera, close up, upper body",
37 | "steps": 20,
38 | "width": 512
39 | }
--------------------------------------------------------------------------------
/test/v1alpha2/extra-single-image.json:
--------------------------------------------------------------------------------
1 | {
2 | "task": {
3 | "metadata": {
4 | "id": "test-extra",
5 | "runtime": "sdruntime",
6 | "tasktype": "extra-single-image",
7 | "prefix": "output",
8 | "context": ""
9 | },
10 | "content": {
11 | "resize_mode":0,
12 | "show_extras_results":false,
13 | "gfpgan_visibility":0,
14 | "codeformer_visibility":0,
15 | "codeformer_weight":0,
16 | "upscaling_resize":4,
17 | "upscaling_resize_w":512,
18 | "upscaling_resize_h":512,
19 | "upscaling_crop":false,
20 | "upscaler_1":"R-ESRGAN 4x+",
21 | "upscaler_2":"None",
22 | "extras_upscaler_2_visibility":0,
23 | "upscale_first":false,
24 | "image":"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/cat.png"
25 | }
26 | }
27 | }
28 |
--------------------------------------------------------------------------------
/test/v1alpha2/i2i.json:
--------------------------------------------------------------------------------
1 | {
2 | "task": {
3 | "metadata": {
4 | "id": "test-i2i",
5 | "runtime": "sdruntime",
6 | "tasktype": "image-to-image",
7 | "prefix": "output",
8 | "context": ""
9 | },
10 | "content": {
11 | "alwayson_scripts": {},
12 | "init_images": ["https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/cat.png"],
13 | "prompt": "cat wizard, gandalf, lord of the rings, detailed, fantasy, cute, adorable, Pixar, Disney, 8k",
14 | "steps": 16,
15 | "width": 512,
16 | "height": 512
17 | }
18 | }
19 | }
--------------------------------------------------------------------------------
/test/v1alpha2/pipeline.json:
--------------------------------------------------------------------------------
1 | {
2 | "task": {
3 | "metadata": {
4 | "id": "test-pipeline",
5 | "runtime": "sdruntime",
6 | "tasktype": "pipeline",
7 | "prefix": "output",
8 | "context": ""
9 | },
10 | "content": {
11 | "3": {
12 | "inputs": {
13 | "seed": 89465186141685,
14 | "steps": 20,
15 | "cfg": 8,
16 | "sampler_name": "euler",
17 | "scheduler": "normal",
18 | "denoise": 1,
19 | "model": [
20 | "4",
21 | 0
22 | ],
23 | "positive": [
24 | "6",
25 | 0
26 | ],
27 | "negative": [
28 | "7",
29 | 0
30 | ],
31 | "latent_image": [
32 | "5",
33 | 0
34 | ]
35 | },
36 | "class_type": "KSampler",
37 | "_meta": {
38 | "title": "KSampler"
39 | }
40 | },
41 | "4": {
42 | "inputs": {
43 | "ckpt_name": "sd_xl_turbo_1.0.safetensors"
44 | },
45 | "class_type": "CheckpointLoaderSimple",
46 | "_meta": {
47 | "title": "Load Checkpoint"
48 | }
49 | },
50 | "5": {
51 | "inputs": {
52 | "width": 512,
53 | "height": 512,
54 | "batch_size": 1
55 | },
56 | "class_type": "EmptyLatentImage",
57 | "_meta": {
58 | "title": "Empty Latent Image"
59 | }
60 | },
61 | "6": {
62 | "inputs": {
63 | "text": "beautiful scenery nature glass bottle landscape, , purple galaxy bottle,",
64 | "clip": [
65 | "4",
66 | 1
67 | ]
68 | },
69 | "class_type": "CLIPTextEncode",
70 | "_meta": {
71 | "title": "CLIP Text Encode (Prompt)"
72 | }
73 | },
74 | "7": {
75 | "inputs": {
76 | "text": "text, watermark",
77 | "clip": [
78 | "4",
79 | 1
80 | ]
81 | },
82 | "class_type": "CLIPTextEncode",
83 | "_meta": {
84 | "title": "CLIP Text Encode (Prompt)"
85 | }
86 | },
87 | "8": {
88 | "inputs": {
89 | "samples": [
90 | "3",
91 | 0
92 | ],
93 | "vae": [
94 | "4",
95 | 2
96 | ]
97 | },
98 | "class_type": "VAEDecode",
99 | "_meta": {
100 | "title": "VAE Decode"
101 | }
102 | },
103 | "9": {
104 | "inputs": {
105 | "filename_prefix": "ComfyUI",
106 | "images": [
107 | "8",
108 | 0
109 | ]
110 | },
111 | "class_type": "SaveImage",
112 | "_meta": {
113 | "title": "Save Image"
114 | }
115 | }
116 | }
117 | }
118 | }
--------------------------------------------------------------------------------
/test/v1alpha2/t2i.json:
--------------------------------------------------------------------------------
1 | {
2 | "task": {
3 | "metadata": {
4 | "id": "test-t2i",
5 | "runtime": "sdruntime",
6 | "tasktype": "text-to-image",
7 | "prefix": "output",
8 | "context": ""
9 | },
10 | "content": {
11 | "alwayson_scripts": {},
12 | "prompt": "A dog",
13 | "steps": 16,
14 | "width": 512,
15 | "height": 512
16 | }
17 | }
18 | }
--------------------------------------------------------------------------------
/tsconfig.json:
--------------------------------------------------------------------------------
1 | {
2 | "compilerOptions": {
3 | "target": "ES2020",
4 | "module": "commonjs",
5 | "lib": [
6 | "es2020",
7 | "dom"
8 | ],
9 | "declaration": true,
10 | "strict": true,
11 | "noImplicitAny": true,
12 | "strictNullChecks": true,
13 | "noImplicitThis": true,
14 | "alwaysStrict": true,
15 | "noUnusedLocals": false,
16 | "noUnusedParameters": false,
17 | "noImplicitReturns": true,
18 | "noFallthroughCasesInSwitch": false,
19 | "inlineSourceMap": true,
20 | "inlineSources": true,
21 | "experimentalDecorators": true,
22 | "strictPropertyInitialization": false,
23 | "typeRoots": [
24 | "./node_modules/@types"
25 | ]
26 | },
27 | "exclude": [
28 | "node_modules",
29 | "cdk.out"
30 | ]
31 | }
32 |
--------------------------------------------------------------------------------