├── .gitattributes ├── .gitignore ├── .project ├── configpacks └── spec.yaml ├── LICENSE.txt ├── README.md ├── apt.txt ├── base-command.sh ├── code ├── .gitkeep ├── FineTuning-SDXL.ipynb ├── chatui │ ├── __init__.py │ ├── __main__.py │ ├── api.py │ ├── assets │ │ ├── __init__.py │ │ ├── kaizen-theme.css │ │ └── kaizen-theme.json │ ├── chat_client.py │ ├── configuration.py │ ├── configuration_wizard.py │ ├── pages │ │ ├── __init__.py │ │ └── converse.py │ ├── static │ │ ├── converse.html │ │ ├── favicon.ico │ │ ├── index.html │ │ ├── kb.html │ │ └── next.svg │ └── utils │ │ ├── __init__.py │ │ └── logger.py └── output.log ├── data ├── .gitkeep ├── charles-3 │ └── .gitkeep ├── my-data │ └── .gitkeep ├── scratch │ └── .gitkeep └── toy-jensen │ ├── tj1.png │ ├── tj2.png │ ├── tj3.png │ ├── tj4.png │ ├── tj5.png │ ├── tj6.png │ ├── tj7.png │ └── tj8.png ├── models └── .gitkeep ├── postBuild.bash ├── preBuild.bash ├── requirements.txt ├── third-party └── LICENSE.sdxl └── variables.env /.gitattributes: -------------------------------------------------------------------------------- 1 | models/** filter=lfs diff=lfs merge=lfs -text 2 | data/** filter=lfs diff=lfs merge=lfs -text 3 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Ignore generated or temporary files managed by the Workbench 2 | .project/* 3 | !.project/spec.yaml 4 | !.project/configpacks 5 | 6 | # General ignores 7 | .DS_Store 8 | __pycache__/ 9 | *.py[cod] 10 | *$py.class 11 | .ipynb_checkpoints 12 | *:Zone.Identifier 13 | 14 | # Workbench Project Layout 15 | data/scratch/* 16 | !data/scratch/.gitkeep 17 | 18 | # Byte-compiled / optimized / DLL files 19 | 20 | # Temp directories, notebooks created by jupyterlab 21 | .Trash-*/ 22 | .jupyter/ 23 | .local/ 24 | 25 | # Python distribution / packaging 26 | .Python 27 | build/ 28 | develop-eggs/ 29 | dist/ 30 | downloads/ 31 | 
eggs/ 32 | .eggs/ 33 | lib/ 34 | lib64/ 35 | parts/ 36 | sdist/ 37 | var/ 38 | wheels/ 39 | share/python-wheels/ 40 | *.egg-info/ 41 | .installed.cfg 42 | *.egg 43 | MANIFEST 44 | 45 | # Unit test / coverage reports 46 | htmlcov/ 47 | .tox/ 48 | .nox/ 49 | .coverage 50 | .coverage.* 51 | .cache 52 | nosetests.xml 53 | coverage.xml 54 | *.cover 55 | *.py,cover 56 | .hypothesis/ 57 | .pytest_cache/ 58 | cover/ 59 | code/output.log 60 | data/generated_images/ -------------------------------------------------------------------------------- /.project/configpacks: -------------------------------------------------------------------------------- 1 | *defaults.ContainerUser 2 | *bash.PreBuild 3 | *defaults.CA 4 | *cuda.CUDA 5 | *defaults.EnvVars 6 | *defaults.Readme 7 | *defaults.Entrypoint 8 | *apt.PackageManager 9 | *bash.PreLanguage 10 | *python.PipPackageManager 11 | *bash.PostBuild 12 | *jupyterlab.JupyterLab -------------------------------------------------------------------------------- /.project/spec.yaml: -------------------------------------------------------------------------------- 1 | specVersion: v2 2 | specMinorVersion: 2 3 | meta: 4 | name: sdxl-customization 5 | image: project-sdxl-customization 6 | description: Image generation customization 7 | labels: [] 8 | createdOn: "2023-08-07T02:40:39Z" 9 | defaultBranch: main 10 | layout: 11 | - path: code/ 12 | type: code 13 | storage: git 14 | - path: models/ 15 | type: models 16 | storage: gitlfs 17 | - path: data/ 18 | type: data 19 | storage: gitlfs 20 | - path: data/scratch/ 21 | type: data 22 | storage: gitignore 23 | environment: 24 | base: 25 | registry: nvcr.io 26 | image: nvidia/pytorch:23.09-py3 27 | build_timestamp: "20230322105648" 28 | name: Pytorch 23.09 Base Container 29 | supported_architectures: 30 | - amd64 31 | cuda_version: "12.2" 32 | description: A base for fine-tuning SDXL with Dreambooth 33 | entrypoint_script: "" 34 | labels: 35 | - ubuntu 36 | - python3 37 | - jupyterlab 38 | - cuda12.2 
39 | - pytorch2.1 40 | apps: 41 | - name: inference 42 | type: custom 43 | class: webapp 44 | start_command: cd /project/code/ && PROXY_PREFIX=$PROXY_PREFIX python3 -m chatui 45 | health_check_command: curl -f "http://localhost:8080/" 46 | stop_command: pkill -f "^python3 -m chatui" 47 | user_msg: "" 48 | logfile_path: "" 49 | timeout_seconds: 120 50 | icon_url: "" 51 | webapp_options: 52 | autolaunch: true 53 | port: "8080" 54 | proxy: 55 | trim_prefix: true 56 | url: http://localhost:8080/ 57 | programming_languages: 58 | - python3 59 | icon_url: "" 60 | image_version: 1.0.0 61 | os: linux 62 | os_distro: ubuntu 63 | os_distro_release: "22.04" 64 | schema_version: v2 65 | user_info: 66 | uid: "" 67 | gid: "" 68 | username: "" 69 | package_managers: 70 | - name: pip 71 | binary_path: /opt/conda/bin/pip 72 | installed_packages: [] 73 | - name: conda 74 | binary_path: /opt/conda/bin/conda 75 | installed_packages: 76 | - jupyterlab=3.4.4 77 | - tensorboard 78 | - nodejs 79 | - notebook 80 | - python=3.10.9 81 | - pytorch=2.1.0 82 | - torchvision 83 | - pytorch-cuda=12.1 84 | - name: apt 85 | binary_path: /usr/bin/apt 86 | installed_packages: 87 | - build-essential 88 | - ca-certificates 89 | - curl 90 | - locales 91 | - git 92 | - git-lfs 93 | - openssl=1.1.1f-1ubuntu2.17 94 | - libssl1.1=1.1.1f-1ubuntu2.17 95 | - vim 96 | package_manager_environment: 97 | name: "" 98 | target: "" 99 | compose_file_path: "" 100 | execution: 101 | apps: 102 | - name: jupyterlab 103 | type: jupyterlab 104 | class: webapp 105 | start_command: jupyter lab --allow-root --port 8888 --ip 0.0.0.0 --no-browser --NotebookApp.base_url=\$PROXY_PREFIX --NotebookApp.default_url=/lab --notebook-dir=/project/ 106 | health_check_command: '[ \$(echo url=\$(jupyter lab list 2>&1 | head -n 2 | tail -n 1 | cut -f1 -d'''' '''' | grep -v ''''Currently'''' | sed ''''s@/?@/lab?@g'''') | curl -o /dev/null -s -w ''''%{http_code}'''' --config -) == ''''200'''' ]' 107 | stop_command: jupyter lab stop 8888 108 | 
user_msg: "" 109 | logfile_path: "" 110 | timeout_seconds: 60 111 | icon_url: "" 112 | webapp_options: 113 | autolaunch: true 114 | port: "8888" 115 | proxy: 116 | trim_prefix: false 117 | url_command: jupyter lab list 2>&1 | head -n 2 | tail -n 1 | cut -f1 -d' ' | grep -v 'Currently' 118 | resources: 119 | gpu: 120 | requested: 1 121 | sharedMemoryMB: 1024 122 | secrets: [] 123 | mounts: 124 | - type: project 125 | target: /project/ 126 | description: Project directory 127 | options: rw 128 | - type: volume 129 | target: /mnt/cache/ 130 | description: Huggingface cache root 131 | options: "" 132 | -------------------------------------------------------------------------------- /LICENSE.txt: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 
25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. 
For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. 
If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. 
You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "{}" 182 | replaced with your own identifying information. 
(Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright 2023 NVIDIA Corporation 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # NVIDIA AI Workbench: Introduction 2 | This is an [NVIDIA AI Workbench](https://www.nvidia.com/en-us/deep-learning-ai/solutions/data-science/workbench/) example Project that demonstrates how to customize a Stable Diffusion XL (SDXL) model. This project takes the latest SDXL model and familiarizes it with Toy Jensen via finetuning on a few pictures, thereby teaching it to generate new images which include him when it didn't recognize him previously. Next, we will also enable the user to bring their own custom image data to fine-tune the model on. Users who have [installed AI Workbench](https://www.nvidia.com/en-us/deep-learning-ai/solutions/data-science/workbench/) can get up and running with this project in minutes. 3 | 4 | Have questions? 
Please direct any issues, fixes, suggestions, and discussion on this project to the DevZone Members Only Forum thread [here](https://forums.developer.nvidia.com/t/support-workbench-example-project-sdxl-customization/278374/1). 5 | 6 | ## Project Description 7 | For a short demo and project overview, check out this [video](https://www.youtube.com/watch?v=ntMRzPzSvM4)! 8 | 9 | Over the past few years Generative AI models have popped up everywhere - from creating realistic responses to complex questions, to generating images and music to impress art critics around the globe. In this project we use the Hugging Face Stable Diffusion XL (SDXL) model to create images from text prompts. You'll see how to import the SDXL model and use it to generate an image. 10 | 11 | From there, you'll see how you can fine-tune the model using DreamBooth. We'll use a small number of photos of Toy Jensen to fine-tune the model. This will allow us to generate new images that include Toy Jensen when the model didn't previously recognize him. After that, you'll have the chance to fine-tune the model on your own images. Perhaps you want to create an image of you at the bottom of the ocean, or in outer space? By the end of this notebook you will be able to! 12 | 13 | ## System Requirements: 14 | * Operating System: Ubuntu 22.04 15 | * CPU requirements: None, tested with Intel® Xeon® Platinum 8380 CPU @ 2.30GHz 16 | * GPU requirements: Any NVIDIA training GPU, tested with 1x NVIDIA A100-80GB 17 | * NVIDIA driver requirements: Latest driver version 18 | * Storage requirements: 40GB 19 | 20 | # Quickstart 21 | If you have NVIDIA AI Workbench already installed, you can use this Project in AI Workbench on your choice of machine by: 22 | 1. Forking this Project to your own GitHub namespace and copying the link 23 | 24 | ``` 25 | https://github.com/[your_namespace]/ 26 | ``` 27 | 28 | 2. 
Opening a shell and activating the Context you want to clone into by 29 | 30 | ``` 31 | $ nvwb list contexts 32 | 33 | $ nvwb activate 34 | ``` 35 | 36 | 3. Cloning this Project onto your desired machine by running 37 | 38 | ``` 39 | $ nvwb clone project 40 | ``` 41 | 42 | 4. Opening the Project by 43 | 44 | ``` 45 | $ nvwb list projects 46 | 47 | $ nvwb open 48 | ``` 49 | 50 | 5. Starting JupyterLab by 51 | 52 | ``` 53 | $ nvwb start jupyterlab 54 | ``` 55 | 56 | 6. Navigate to the code directory of the project. Then, open the notebook titled ```FineTuning-SDXL.ipynb``` and get started. Happy coding! 57 | 58 | --- 59 | **Tip:** Use ```nvwb help``` to see a full list of NVIDIA AI Workbench commands. 60 | 61 | --- 62 | 63 | ## Tested On 64 | This notebook has been tested with an NVIDIA RTX A6000 GPU and the following version of NVIDIA AI Workbench: ```nvwb 0.21.3 (internal; linux; amd64; go1.21.3; Tue Mar 5 03:55:43 UTC 2024)``` 65 | 66 | ## License 67 | This NVIDIA AI Workbench example project is under the [Apache 2.0 License](https://github.com/nv-edwli/sdxl-customization/blob/main/LICENSE.txt) 68 | 69 | This project will utilize additional third-party open source software projects. Review the license terms of these open source projects before use. Third party components used as part of this project are subject to their separate legal notices or terms that accompany the components. You are responsible for confirming compliance with third-party component license terms and requirements. 
70 | -------------------------------------------------------------------------------- /apt.txt: -------------------------------------------------------------------------------- 1 | # apt packages to install should be listed one per line 2 | -------------------------------------------------------------------------------- /base-command.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # *** AUTO-GENERATED FILE - DO NOT EDIT **** 4 | # This script is required when running projects in NGC Base Command from Workbench 5 | # It is safe to delete if you do not plan to run this project in NGC Base Command 6 | 7 | if [ -z "$NVWB_BASE_COMMAND" ] 8 | then 9 | echo "base-command.sh should only be used while executing in Base Command." 10 | exit 1 11 | fi 12 | 13 | # Wipe workspace contents prior to clone 14 | echo "Erasing project workspace mounted at /project" 15 | cd /project 16 | rm -rf ..?* .[!.]* * 17 | ls -la /project 18 | 19 | # Clone the repo into the workspace 20 | echo "Cloning https://${NVWB_PROJECT_URL} into /project" 21 | git clone https://${NVWB_GIT_USERNAME}:${NVWB_GIT_PASSWORD}@${NVWB_PROJECT_URL} /project 22 | 23 | # Run the command 24 | echo "running $@" 25 | cd /project$NVWB_SCRIPT_DIR 26 | $@ 27 | -------------------------------------------------------------------------------- /code/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVIDIA/workbench-example-sdxl-customization/3a13b42bb9b9dd27322fc0c57dabcaa73c126451/code/.gitkeep -------------------------------------------------------------------------------- /code/FineTuning-SDXL.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "e09c673d", 6 | "metadata": {}, 7 | "source": [ 8 | "# Fine-Tuning StableDiffusion XL with DreamBooth" 9 | ] 10 | }, 11 | { 12 | "cell_type": 
"markdown", 13 | "id": "19aac13c", 14 | "metadata": {}, 15 | "source": [ 16 | "Over the past few years Generative AI models have popped up everywhere - from creating realistic responses to complex questions, to generating images and music to impress art critics around the globe. In this notebook we use the Hugging Face [Stable Diffusion XL (SDXL)](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) model to create images from text prompts. You'll see how to import the SDXL model and use it to generate an image. \n", 17 | "\n", 18 | "From there, you'll see how you can fine-tune the model using [DreamBooth](https://huggingface.co/docs/diffusers/training/dreambooth), a method for easily fine-tuning a text-to-image model. We'll use a small number of photos of [Toy Jensen](https://blogs.nvidia.com/blog/2022/12/22/toy-jensen-jingle-bells/) in this notebook to fine-tune SDXL. This will allow us to generate new images that include Toy Jensen! \n", 19 | "\n", 20 | "After that, you'll have the chance to fine-tune the model on your own images. Perhaps you want to create an image of you at the bottom of the ocean, or in outer space? By the end of this notebook you will be able to! \n", 21 | "\n", 22 | "**IMPORTANT:** This project will utilize additional third-party open source software. Review the license terms of these open source projects before use. Third party components used as part of this project are subject to their separate legal notices or terms that accompany the components. You are responsible for confirming compliance with third-party component license terms and requirements." 23 | ] 24 | }, 25 | { 26 | "cell_type": "markdown", 27 | "id": "d0ee02dd", 28 | "metadata": {}, 29 | "source": [ 30 | "### Stable Diffusion XL Model\n", 31 | "\n", 32 | "First, we import the classes and libraries we need to run the notebook." 
33 | ] 34 | }, 35 | { 36 | "cell_type": "code", 37 | "execution_count": null, 38 | "id": "9ff19a5c", 39 | "metadata": {}, 40 | "outputs": [], 41 | "source": [ 42 | "import torch\n", 43 | "from diffusers import StableDiffusionXLPipeline, DiffusionPipeline" 44 | ] 45 | }, 46 | { 47 | "cell_type": "markdown", 48 | "id": "33488fe1", 49 | "metadata": {}, 50 | "source": [ 51 | "Next, from the Hugging Face `diffusers` library, we create a `StableDiffusionXLPipeline` object from the SDXL base model. " 52 | ] 53 | }, 54 | { 55 | "cell_type": "code", 56 | "execution_count": null, 57 | "id": "2db6042e", 58 | "metadata": { 59 | "scrolled": true 60 | }, 61 | "outputs": [], 62 | "source": [ 63 | "pipe = StableDiffusionXLPipeline.from_pretrained(\n", 64 | " \"stabilityai/stable-diffusion-xl-base-1.0\", torch_dtype=torch.float16, variant=\"fp16\", use_safetensors=True\n", 65 | ")\n", 66 | "pipe.to(\"cuda\")" 67 | ] 68 | }, 69 | { 70 | "cell_type": "markdown", 71 | "id": "ac0e1a6e", 72 | "metadata": {}, 73 | "source": [ 74 | "Let's use the SDXL model to generate an image. " 75 | ] 76 | }, 77 | { 78 | "cell_type": "code", 79 | "execution_count": null, 80 | "id": "4655c4f7", 81 | "metadata": {}, 82 | "outputs": [], 83 | "source": [ 84 | "prompt = \"toy jensen in space\"\n", 85 | "image = pipe(prompt=prompt).images[0]\n", 86 | "\n", 87 | "image" 88 | ] 89 | }, 90 | { 91 | "cell_type": "markdown", 92 | "id": "72f0dad2", 93 | "metadata": {}, 94 | "source": [ 95 | "Hmmm, looks like the Hugging Face SDXL model doesn't know about Toy Jensen! Imagine that! \n", 96 | "\n", 97 | "✅ Try using the SDXL model to generate some other images by editing the text in the first line of the cell above. \n" 98 | ] 99 | }, 100 | { 101 | "cell_type": "markdown", 102 | "id": "2d5795a7", 103 | "metadata": {}, 104 | "source": [ 105 | "## Fine-Tuning the model with DreamBooth\n", 106 | "\n", 107 | "Fine-Tuning is used to train an existing Machine Learning Model, given new information. 
In our case, we want to teach the SDXL model about Toy Jensen. This will allow us to create the perfect image of Toy Jensen in Space!\n", 108 | "\n", 109 | "[DreamBooth](https://arxiv.org/abs/2208.12242) provides a way to fine-tune a text-to-image model using only a few images. Let's use this to tune our SDXL Model so that it knows about Toy Jensen!\n", 110 | "\n", 111 | "We have 8 photos of Toy Jensen in our dataset - let's take a look at one of them." 112 | ] 113 | }, 114 | { 115 | "cell_type": "code", 116 | "execution_count": null, 117 | "id": "c048e134", 118 | "metadata": {}, 119 | "outputs": [], 120 | "source": [ 121 | "from IPython.display import Image\n", 122 | "\n", 123 | "display(Image(filename='../data/toy-jensen/tj1.png'))" 124 | ] 125 | }, 126 | { 127 | "cell_type": "markdown", 128 | "id": "1611c557", 129 | "metadata": {}, 130 | "source": [ 131 | "Now we can use Hugging Face and DreamBooth to fine-tune this model. To do this we create a config, then specify some flags like an instance prompt, a resolution and a number of training steps for the fine-tuning algorithm to run. 
" 132 | ] 133 | }, 134 | { 135 | "cell_type": "code", 136 | "execution_count": null, 137 | "id": "3f7d9ac2", 138 | "metadata": {}, 139 | "outputs": [], 140 | "source": [ 141 | "from accelerate.utils import write_basic_config\n", 142 | "write_basic_config()" 143 | ] 144 | }, 145 | { 146 | "cell_type": "code", 147 | "execution_count": null, 148 | "id": "782e0ee5", 149 | "metadata": {}, 150 | "outputs": [], 151 | "source": [ 152 | "!accelerate launch /workspace/diffusers/examples/dreambooth/train_dreambooth_lora_sdxl.py \\\n", 153 | " --pretrained_model_name_or_path=stabilityai/stable-diffusion-xl-base-1.0 \\\n", 154 | " --instance_data_dir=/project/data/toy-jensen \\\n", 155 | " --output_dir=/project/models/tuned-toy-jensen \\\n", 156 | " --mixed_precision=\"fp16\" \\\n", 157 | " --instance_prompt=\"a photo of toy jensen\" \\\n", 158 | " --resolution=1024 \\\n", 159 | " --train_batch_size=1 \\\n", 160 | " --gradient_accumulation_steps=4 \\\n", 161 | " --learning_rate=1e-4 \\\n", 162 | " --lr_scheduler=\"constant\" \\\n", 163 | " --lr_warmup_steps=0 \\\n", 164 | " --max_train_steps=100 \\\n", 165 | " --seed=\"0\" " 166 | ] 167 | }, 168 | { 169 | "cell_type": "markdown", 170 | "id": "83aa7268", 171 | "metadata": {}, 172 | "source": [ 173 | "Now that the model is fine-tuned, let's tell our notebook where to find it." 174 | ] 175 | }, 176 | { 177 | "cell_type": "code", 178 | "execution_count": null, 179 | "id": "21f899ea", 180 | "metadata": {}, 181 | "outputs": [], 182 | "source": [ 183 | "base_model_id = \"stabilityai/stable-diffusion-xl-base-1.0\"\n", 184 | "pipe = DiffusionPipeline.from_pretrained(base_model_id, torch_dtype=torch.float16)\n", 185 | "pipe = pipe.to(\"cuda\")\n", 186 | "pipe.load_lora_weights(\"/project/models/tuned-toy-jensen\")" 187 | ] 188 | }, 189 | { 190 | "cell_type": "markdown", 191 | "id": "08b5ee85", 192 | "metadata": {}, 193 | "source": [ 194 | "Finally, we can use our fine-tuned model to create an image with Toy Jensen in it. 
Let's give it a go! " 195 | ] 196 | }, 197 | { 198 | "cell_type": "code", 199 | "execution_count": null, 200 | "id": "00d3ac0c", 201 | "metadata": {}, 202 | "outputs": [], 203 | "source": [ 204 | "image = pipe(\"A picture of toy jensen in space\", num_inference_steps=75).images[0]\n", 205 | "\n", 206 | "image" 207 | ] 208 | }, 209 | { 210 | "cell_type": "markdown", 211 | "id": "3fb09409", 212 | "metadata": {}, 213 | "source": [ 214 | "Wow - look at him go! " 215 | ] 216 | }, 217 | { 218 | "cell_type": "markdown", 219 | "id": "b39851c0", 220 | "metadata": {}, 221 | "source": [ 222 | "### Trying out some more examples\n", 223 | "\n", 224 | "\n", 225 | "The SDXL model we are using was trained on historical data, and knows about everything from celebrities to famous buildings. However, it was trained on data up to a fixed point in time and isn't up to date with things and people who have become famous in the last few months.\n", 226 | "\n", 227 | "For example, King Charles III became king of the United Kingdom in September 2022. Let's ask our SDXL Model for an image of King Charles in Space:" 228 | ] 229 | }, 230 | { 231 | "cell_type": "code", 232 | "execution_count": null, 233 | "id": "a692804d", 234 | "metadata": {}, 235 | "outputs": [], 236 | "source": [ 237 | "prompt = \"King Charles in space\"\n", 238 | "image = pipe(prompt=prompt).images[0]\n", 239 | "\n", 240 | "image" 241 | ] 242 | }, 243 | { 244 | "cell_type": "markdown", 245 | "id": "1be0a65a", 246 | "metadata": {}, 247 | "source": [ 248 | "Did it give you an image of a King Charles spaniel? Or maybe King Charles II? That's not what we were hoping for! \n", 249 | "\n", 250 | "1. Let's gather some (10ish) images of King Charles III from your favourite search engine. Copy those images into the `data/charles-3/` folder. You can download then to your machine and move them to this folder. 
\n", 251 | "\n", 252 | " **Reminder:** Third party components used as part of this project are subject to their separate legal notices or terms that accompany the components; you are responsible for reviewing and confirming compliance with third-party component license terms and requirements.\n", 253 | "2. Run the code below to fine-tune the model on your images of King Charles. " 254 | ] 255 | }, 256 | { 257 | "cell_type": "code", 258 | "execution_count": null, 259 | "id": "8bf45d79", 260 | "metadata": {}, 261 | "outputs": [], 262 | "source": [ 263 | "# Remove the .gitkeep file in the 'charles-3' folder.\n", 264 | "!rm ../data/charles-3/.gitkeep\n", 265 | "!rm -rf ../data/charles-3/.ipynb_checkpoints" 266 | ] 267 | }, 268 | { 269 | "cell_type": "code", 270 | "execution_count": null, 271 | "id": "deb5c7cb", 272 | "metadata": {}, 273 | "outputs": [], 274 | "source": [ 275 | "!accelerate launch /workspace/diffusers/examples/dreambooth/train_dreambooth_lora_sdxl.py \\\n", 276 | " --pretrained_model_name_or_path=stabilityai/stable-diffusion-xl-base-1.0 \\\n", 277 | " --instance_data_dir=/project/data/charles-3 \\\n", 278 | " --output_dir=/project/models/tuned-charles-3 \\\n", 279 | " --mixed_precision=\"fp16\" \\\n", 280 | " --instance_prompt=\"a photo of King Charles\" \\\n", 281 | " --resolution=1024 \\\n", 282 | " --train_batch_size=1 \\\n", 283 | " --gradient_accumulation_steps=4 \\\n", 284 | " --learning_rate=2e-4 \\\n", 285 | " --lr_scheduler=\"constant\" \\\n", 286 | " --lr_warmup_steps=0 \\\n", 287 | " --max_train_steps=150 \\\n", 288 | " --seed=\"0\" " 289 | ] 290 | }, 291 | { 292 | "cell_type": "markdown", 293 | "id": "70e68be4", 294 | "metadata": {}, 295 | "source": [ 296 | "Now we load the model and use it to generate an image of King Charles. 
" 297 | ] 298 | }, 299 | { 300 | "cell_type": "code", 301 | "execution_count": null, 302 | "id": "da82c1b6", 303 | "metadata": {}, 304 | "outputs": [], 305 | "source": [ 306 | "base_model_id = \"stabilityai/stable-diffusion-xl-base-1.0\"\n", 307 | "pipe = DiffusionPipeline.from_pretrained(base_model_id, torch_dtype=torch.float16)\n", 308 | "pipe = pipe.to(\"cuda\")\n", 309 | "pipe.load_lora_weights(\"/project/models/tuned-charles-3\")" 310 | ] 311 | }, 312 | { 313 | "cell_type": "code", 314 | "execution_count": null, 315 | "id": "fdd059b2", 316 | "metadata": {}, 317 | "outputs": [], 318 | "source": [ 319 | "image = pipe(\"A picture of King Charles in space\", num_inference_steps=75).images[0]\n", 320 | "\n", 321 | "image" 322 | ] 323 | }, 324 | { 325 | "cell_type": "markdown", 326 | "id": "9333d426", 327 | "metadata": {}, 328 | "source": [ 329 | "How is the model performing? Do you need to train it on a few more images? If so, add some more images to the folder then run the cells above to retrain. \n", 330 | "\n", 331 | "Now, the model knows what King Charles III looks like and is able to generate realistic images." 332 | ] 333 | }, 334 | { 335 | "cell_type": "markdown", 336 | "id": "c0436ad5", 337 | "metadata": {}, 338 | "source": [ 339 | "\n", 340 | "## Fine-tuning the Model on your own data\n", 341 | "\n", 342 | "✅ Why not try out training the SDXL model on your own set of images? Follow the steps below to get set up to train your own model. \n", 343 | "\n", 344 | "**Reminder:** Third party components used as part of this project are subject to their separate legal notices or terms that accompany the components; you are responsible for reviewing and confirming compliance with third-party component license terms and requirements.\n", 345 | "\n", 346 | "\n", 347 | "1. You'll need to find around 10 different pictures of your chosen item. Why not find some of your pet or your car? \n", 348 | "\n", 349 | "2. 
Save those images into the `data/my-data` folder we have created for you, similarly to as you have done with the input images of King Charles III.\n", 350 | "\n", 351 | "3. Edit the 'instance_prompt' line the code below so that it reflects your item. For example, you could change it to \n", 352 | "```--instance_prompt=\"a photo of my cat alice\"```\n", 353 | "\n", 354 | "4. Once you've updated the prompt, run the cells below to train the model on your data. \n" 355 | ] 356 | }, 357 | { 358 | "cell_type": "code", 359 | "execution_count": null, 360 | "id": "751b5fab", 361 | "metadata": {}, 362 | "outputs": [], 363 | "source": [ 364 | "# Remove the .gitkeep file in the 'my-data' folder.\n", 365 | "!rm ../data/my-data/.gitkeep" 366 | ] 367 | }, 368 | { 369 | "cell_type": "code", 370 | "execution_count": null, 371 | "id": "1d4daa75", 372 | "metadata": {}, 373 | "outputs": [], 374 | "source": [ 375 | "!accelerate launch /workspace/diffusers/examples/dreambooth/train_dreambooth_lora_sdxl.py \\\n", 376 | " --pretrained_model_name_or_path=stabilityai/stable-diffusion-xl-base-1.0 \\\n", 377 | " --instance_data_dir=/project/data/my-data \\\n", 378 | " --output_dir=/project/models/tuned-my-data \\\n", 379 | " --mixed_precision=\"fp16\" \\\n", 380 | " --instance_prompt=\"a photo of [CHANGE THIS]\" \\\n", 381 | " --resolution=1024 \\\n", 382 | " --train_batch_size=1 \\\n", 383 | " --gradient_accumulation_steps=4 \\\n", 384 | " --learning_rate=1e-4 \\\n", 385 | " --lr_scheduler=\"constant\" \\\n", 386 | " --lr_warmup_steps=0 \\\n", 387 | " --max_train_steps=100 " 388 | ] 389 | }, 390 | { 391 | "cell_type": "markdown", 392 | "id": "7e3c2e7f", 393 | "metadata": {}, 394 | "source": [ 395 | "Now that your model has been trained we can load it:" 396 | ] 397 | }, 398 | { 399 | "cell_type": "code", 400 | "execution_count": null, 401 | "id": "96f2f540", 402 | "metadata": {}, 403 | "outputs": [], 404 | "source": [ 405 | "base_model_id = \"stabilityai/stable-diffusion-xl-base-1.0\"\n", 406 
| "pipe = DiffusionPipeline.from_pretrained(base_model_id, torch_dtype=torch.float16)\n", 407 | "pipe = pipe.to(\"cuda\")\n", 408 | "pipe.load_lora_weights(\"/project/models/tuned-my-data\")" 409 | ] 410 | }, 411 | { 412 | "cell_type": "markdown", 413 | "id": "64d1ac18", 414 | "metadata": {}, 415 | "source": [ 416 | "And finally, use the code below to generate images. Change the prompt to something which includes your item. For example:\n", 417 | "\n", 418 | "`image = pipe(\"A picture of my cat alice in space)`. " 419 | ] 420 | }, 421 | { 422 | "cell_type": "code", 423 | "execution_count": null, 424 | "id": "fcb50913", 425 | "metadata": {}, 426 | "outputs": [], 427 | "source": [ 428 | "image = pipe(\"A picture of [CHANGE THIS] in space\", num_inference_steps=75).images[0]\n", 429 | "\n", 430 | "image" 431 | ] 432 | }, 433 | { 434 | "cell_type": "code", 435 | "execution_count": null, 436 | "id": "e672d403", 437 | "metadata": {}, 438 | "outputs": [], 439 | "source": [] 440 | } 441 | ], 442 | "metadata": { 443 | "kernelspec": { 444 | "display_name": "Python 3 (ipykernel)", 445 | "language": "python", 446 | "name": "python3" 447 | }, 448 | "language_info": { 449 | "codemirror_mode": { 450 | "name": "ipython", 451 | "version": 3 452 | }, 453 | "file_extension": ".py", 454 | "mimetype": "text/x-python", 455 | "name": "python", 456 | "nbconvert_exporter": "python", 457 | "pygments_lexer": "ipython3", 458 | "version": "3.10.12" 459 | } 460 | }, 461 | "nbformat": 4, 462 | "nbformat_minor": 5 463 | } 464 | -------------------------------------------------------------------------------- /code/chatui/__init__.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
2 | # SPDX-License-Identifier: Apache-2.0 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Document Retrieval Service. 17 | 18 | Handle document ingestion and retrieval from a VectorDB. 19 | """ 20 | 21 | import logging 22 | import os 23 | import sys 24 | import typing 25 | 26 | if typing.TYPE_CHECKING: 27 | from chatui.api import APIServer 28 | 29 | 30 | _LOG_FMT = f"[{os.getpid()}] %(asctime)15s [%(levelname)7s] - %(name)s - %(message)s" 31 | _LOG_DATE_FMT = "%b %d %H:%M:%S" 32 | _LOGGER = logging.getLogger(__name__) 33 | 34 | 35 | def bootstrap_logging(verbosity: int = 0) -> None: 36 | """Configure Python's logger according to the given verbosity level. 37 | 38 | :param verbosity: The desired verbosity level. Must be one of 0, 1, or 2. 
39 | :type verbosity: typing.Literal[0, 1, 2] 40 | """ 41 | # determine log level 42 | verbosity = min(2, max(0, verbosity)) # limit verbosity to 0-2 43 | log_level = [logging.WARN, logging.INFO, logging.DEBUG][verbosity] 44 | 45 | # configure python's logger 46 | logging.basicConfig(filename='chatui.log', filemode='w',format=_LOG_FMT, datefmt=_LOG_DATE_FMT, level=log_level) 47 | # update existing loggers 48 | _LOGGER.setLevel(logging.DEBUG) 49 | for logger in [ 50 | __name__, 51 | "uvicorn", 52 | "uvicorn.access", 53 | "uvicorn.error", 54 | ]: 55 | for handler in logging.getLogger(logger).handlers: 56 | handler.setFormatter(logging.Formatter(fmt=_LOG_FMT, datefmt=_LOG_DATE_FMT)) 57 | -------------------------------------------------------------------------------- /code/chatui/__main__.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Entrypoint for the Conversation GUI. 17 | 18 | The functions in this module are responsible for bootstrapping then executing the Conversation GUI server. 19 | """ 20 | 21 | import argparse 22 | import os 23 | import sys 24 | 25 | import uvicorn 26 | 27 | 28 | def parse_args() -> argparse.Namespace: 29 | """Parse command-line arguments for the program. 
30 | 31 | :returns: A namespace containing the parsed arguments. 32 | :rtype: argparse.Namespace 33 | """ 34 | parser = argparse.ArgumentParser(description="Document Retrieval Service") 35 | 36 | parser.add_argument( 37 | "--help-config", 38 | action="store_true", 39 | default=False, 40 | help="show the configuration help text", 41 | ) 42 | 43 | parser.add_argument( 44 | "-c", 45 | "--config", 46 | metavar="CONFIGURATION_FILE", 47 | default="/dev/null", 48 | help="path to the configuration file (json or yaml)", 49 | ) 50 | parser.add_argument( 51 | "-v", 52 | "--verbose", 53 | action="count", 54 | default=1, 55 | help="increase output verbosity", 56 | ) 57 | parser.add_argument( 58 | "-q", 59 | "--quiet", 60 | action="count", 61 | default=0, 62 | help="decrease output verbosity", 63 | ) 64 | 65 | parser.add_argument( 66 | "--host", 67 | metavar="HOSTNAME", 68 | type=str, 69 | default="0.0.0.0", # nosec # this is intentional 70 | help="Bind socket to this host.", 71 | ) 72 | parser.add_argument( 73 | "--port", 74 | metavar="PORT_NUM", 75 | type=int, 76 | default=8080, 77 | help="Bind socket to this port.", 78 | ) 79 | parser.add_argument( 80 | "--workers", 81 | metavar="NUM_WORKERS", 82 | type=int, 83 | default=1, 84 | help="Number of worker processes.", 85 | ) 86 | parser.add_argument( 87 | "--ssl-keyfile", metavar="SSL_KEY", type=str, default=None, help="SSL key file" 88 | ) 89 | parser.add_argument( 90 | "--ssl-certfile", 91 | metavar="SSL_CERT", 92 | type=str, 93 | default=None, 94 | help="SSL certificate file", 95 | ) 96 | 97 | cliargs = parser.parse_args() 98 | if cliargs.help_config: 99 | # pylint: disable=import-outside-toplevel; this is intentional to allow for the environment to be configured 100 | # before any of the application libraries are loaded. 
101 | from chatui.configuration import AppConfig 102 | 103 | sys.stdout.write("\nconfiguration file format:\n") 104 | AppConfig.print_help(sys.stdout.write) 105 | sys.exit(0) 106 | 107 | return cliargs 108 | 109 | 110 | if __name__ == "__main__": 111 | args = parse_args() 112 | os.environ["APP_VERBOSITY"] = f"{args.verbose - args.quiet}" 113 | os.environ["APP_CONFIG_FILE"] = args.config 114 | 115 | from chatui import api, chat_client, configuration, pages 116 | 117 | # load config 118 | config_file = os.environ.get("APP_CONFIG_FILE", "/dev/null") 119 | config = configuration.AppConfig.from_file(config_file) 120 | if not config: 121 | sys.exit(1) 122 | 123 | # connect to other services 124 | api_url = f"{config.server_url}:{config.server_port}" 125 | print(api_url) 126 | client = chat_client.ChatClient( 127 | api_url, config.model_name 128 | ) 129 | proxy_prefix = os.environ.get("PROXY_PREFIX") 130 | blocks = pages.converse.build_page(client) 131 | blocks.queue(max_size=10) 132 | blocks.launch(server_name="0.0.0.0", server_port=8080, root_path=proxy_prefix) -------------------------------------------------------------------------------- /code/chatui/api.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | """This module contains the Server that will host the chatui and API.""" 17 | import os 18 | 19 | import gradio as gr 20 | from fastapi import FastAPI 21 | from fastapi.responses import FileResponse 22 | from fastapi.staticfiles import StaticFiles 23 | from chatui.chat_client import ChatClient 24 | 25 | from chatui import pages 26 | 27 | STATIC_DIR = os.path.join(os.path.abspath(os.path.dirname(__file__)), "static") 28 | 29 | 30 | class APIServer(FastAPI): 31 | """A class that hosts the service api. 32 | 33 | :cvar title: The title of the server. 34 | :type title: str 35 | :cvar desc: A description of the server. 36 | :type desc: str 37 | """ 38 | 39 | title = "Chat" 40 | desc = "This service provides a sample conversation chatui flow." 41 | 42 | def __init__(self, client: ChatClient) -> None: 43 | """Initialize the API server.""" 44 | self._client = client 45 | super().__init__(title=self.title, description=self.desc) 46 | 47 | def configure_routes(self) -> None: 48 | """Configure the routes in the API Server.""" 49 | _ = gr.mount_gradio_app( 50 | self, 51 | blocks=pages.converse.build_page(self._client), 52 | path=f"/content{pages.converse.PATH}", 53 | ) 54 | _ = gr.mount_gradio_app( 55 | self, 56 | blocks=pages.kb.build_page(self._client), 57 | path=f"/content{pages.kb.PATH}", 58 | ) 59 | 60 | @self.get("/") 61 | async def root_redirect() -> FileResponse: 62 | return FileResponse(os.path.join(STATIC_DIR, "converse.html")) 63 | 64 | @self.get("/converse") 65 | async def converse_redirect() -> FileResponse: 66 | return FileResponse(os.path.join(STATIC_DIR, "converse.html")) 67 | 68 | @self.get("/kb") 69 | async def kb_redirect() -> FileResponse: 70 | return FileResponse(os.path.join(STATIC_DIR, "kb.html")) 71 | 72 | self.mount("/", StaticFiles(directory=STATIC_DIR, html=True)) 73 | -------------------------------------------------------------------------------- /code/chatui/assets/__init__.py: 
-------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """This module contains theming assets.""" 17 | import os.path 18 | from typing import Tuple 19 | 20 | import gradio as gr 21 | 22 | _ASSET_DIR = os.path.dirname(__file__) 23 | 24 | 25 | def load_theme(name: str) -> Tuple[gr.Theme, str]: 26 | """Load a pre-defined chatui theme. 27 | 28 | :param name: The name of the theme to load. 29 | :type name: str 30 | :returns: A tuple containing the Gradio theme and custom CSS. 
31 | :rtype: Tuple[gr.Theme, str] 32 | """ 33 | theme_json_path = os.path.join(_ASSET_DIR, f"{name}-theme.json") 34 | theme_css_path = os.path.join(_ASSET_DIR, f"{name}-theme.css") 35 | return ( 36 | gr.themes.Default().load(theme_json_path), 37 | open(theme_css_path, encoding="UTF-8").read(), 38 | ) 39 | -------------------------------------------------------------------------------- /code/chatui/assets/kaizen-theme.css: -------------------------------------------------------------------------------- 1 | .tabitem { 2 | background-color: var(--block-background-fill); 3 | } 4 | 5 | .gradio-container { 6 | /* This needs to be !important, otherwise the breakpoint override the container being full width */ 7 | max-width: 100% !important; 8 | padding: 10px !important; 9 | } 10 | 11 | footer { 12 | visibility: hidden; 13 | } 14 | -------------------------------------------------------------------------------- /code/chatui/assets/kaizen-theme.json: -------------------------------------------------------------------------------- 1 | { 2 | "theme": { 3 | "_font": [ 4 | { 5 | "__gradio_font__": true, 6 | "name": "NVIDIA Sans", 7 | "class": "font" 8 | }, 9 | { 10 | "__gradio_font__": true, 11 | "name": "ui-sans-serif", 12 | "class": "font" 13 | }, 14 | { 15 | "__gradio_font__": true, 16 | "name": "system-ui", 17 | "class": "font" 18 | }, 19 | { 20 | "__gradio_font__": true, 21 | "name": "sans-serif", 22 | "class": "font" 23 | } 24 | ], 25 | "_font_mono": [ 26 | { 27 | "__gradio_font__": true, 28 | "name": "JetBrains Mono", 29 | "class": "google" 30 | }, 31 | { 32 | "__gradio_font__": true, 33 | "name": "ui-monospace", 34 | "class": "font" 35 | }, 36 | { 37 | "__gradio_font__": true, 38 | "name": "Consolas", 39 | "class": "font" 40 | }, 41 | { 42 | "__gradio_font__": true, 43 | "name": "monospace", 44 | "class": "font" 45 | } 46 | ], 47 | "_stylesheets": [ 48 | 
"https://fonts.googleapis.com/css2?family=JetBrains+Mono&family=Roboto:ital,wght@0,100;0,300;0,400;0,500;0,700;0,900;1,100;1,300;1,400;1,500;1,700;1,900&display=swap", 49 | "https://brand-assets.cne.ngc.nvidia.com/assets/fonts/nvidia-sans/1.0.0/NVIDIASans_Lt.woff2", 50 | "https://brand-assets.cne.ngc.nvidia.com/assets/fonts/nvidia-sans/1.0.0/NVIDIASans_LtIt.woff2", 51 | "https://brand-assets.cne.ngc.nvidia.com/assets/fonts/nvidia-sans/1.0.0/NVIDIASans_Rg.woff2", 52 | "https://brand-assets.cne.ngc.nvidia.com/assets/fonts/nvidia-sans/1.0.0/NVIDIASans_It.woff2", 53 | "https://brand-assets.cne.ngc.nvidia.com/assets/fonts/nvidia-sans/1.0.0/NVIDIASans_Md.woff2", 54 | "https://brand-assets.cne.ngc.nvidia.com/assets/fonts/nvidia-sans/1.0.0/NVIDIASans_MdIt.woff2", 55 | "https://brand-assets.cne.ngc.nvidia.com/assets/fonts/nvidia-sans/1.0.0/NVIDIASans_Bd.woff2", 56 | "https://brand-assets.cne.ngc.nvidia.com/assets/fonts/nvidia-sans/1.0.0/NVIDIASans_BdIt.woff2" 57 | ], 58 | "background_fill_primary": "#ffffff", 59 | "background_fill_primary_dark": "#292929", 60 | "background_fill_secondary": "*neutral_50", 61 | "background_fill_secondary_dark": "*neutral_900", 62 | "block_background_fill": "#ffffff", 63 | "block_background_fill_dark": "#292929", 64 | "block_border_color": "#d8d8d8", 65 | "block_border_color_dark": "*border_color_primary", 66 | "block_border_width": "1px", 67 | "block_info_text_color": "*body_text_color_subdued", 68 | "block_info_text_color_dark": "*body_text_color_subdued", 69 | "block_info_text_size": "*text_sm", 70 | "block_info_text_weight": "400", 71 | "block_label_background_fill": "#e4fabe", 72 | "block_label_background_fill_dark": "#e4fabe", 73 | "block_label_border_color": "#e4fabe", 74 | "block_label_border_color_dark": "#e4fabe", 75 | "block_label_border_width": "1px", 76 | "block_label_margin": "0", 77 | "block_label_padding": "*spacing_sm *spacing_lg", 78 | "block_label_radius": "calc(*radius_lg - 1px) 0 calc(*radius_lg - 1px) 0", 79 | 
"block_label_right_radius": "0 calc(*radius_lg - 1px) 0 calc(*radius_lg - 1px)", 80 | "block_label_shadow": "*block_shadow", 81 | "block_label_text_color": "#4d6721", 82 | "block_label_text_color_dark": "#4d6721", 83 | "block_label_text_size": "*text_sm", 84 | "block_label_text_weight": "400", 85 | "block_padding": "*spacing_xl calc(*spacing_xl + 2px)", 86 | "block_radius": "*radius_lg", 87 | "block_shadow": "*shadow_drop", 88 | "block_title_background_fill": "none", 89 | "block_title_border_color": "none", 90 | "block_title_border_width": "0px", 91 | "block_title_padding": "0", 92 | "block_title_radius": "none", 93 | "block_title_text_color": "*neutral_500", 94 | "block_title_text_color_dark": "*neutral_200", 95 | "block_title_text_size": "*text_md", 96 | "block_title_text_weight": "500", 97 | "body_background_fill": "#f2f2f2", 98 | "body_background_fill_dark": "#202020", 99 | "body_text_color": "#202020", 100 | "body_text_color_dark": "#f2f2f2", 101 | "body_text_color_subdued": "*neutral_400", 102 | "body_text_color_subdued_dark": "*neutral_400", 103 | "body_text_size": "*text_md", 104 | "body_text_weight": "400", 105 | "border_color_accent": "*primary_300", 106 | "border_color_accent_dark": "*neutral_600", 107 | "border_color_primary": "#d8d8d8", 108 | "border_color_primary_dark": "#343434", 109 | "button_border_width": "1px", 110 | "button_border_width_dark": "1px", 111 | "button_cancel_background_fill": "#dc3528", 112 | "button_cancel_background_fill_dark": "#dc3528", 113 | "button_cancel_background_fill_hover": "#b6251b", 114 | "button_cancel_background_fill_hover_dark": "#b6251b", 115 | "button_cancel_border_color": "#dc3528", 116 | "button_cancel_border_color_dark": "#dc3528", 117 | "button_cancel_border_color_hover": "#b6251b", 118 | "button_cancel_border_color_hover_dark": "#b6251b", 119 | "button_cancel_text_color": "#ffffff", 120 | "button_cancel_text_color_dark": "#ffffff", 121 | "button_cancel_text_color_hover": "#ffffff", 122 | 
"button_cancel_text_color_hover_dark": "#ffffff", 123 | "button_large_padding": "*spacing_lg calc(2 * *spacing_lg)", 124 | "button_large_radius": "*radius_lg", 125 | "button_large_text_size": "*text_lg", 126 | "button_large_text_weight": "500", 127 | "button_primary_background_fill": "#76b900", 128 | "button_primary_background_fill_dark": "#76b900", 129 | "button_primary_background_fill_hover": "#659f00", 130 | "button_primary_background_fill_hover_dark": "#659f00", 131 | "button_primary_border_color": "#76b900", 132 | "button_primary_border_color_dark": "#76b900", 133 | "button_primary_border_color_hover": "#659f00", 134 | "button_primary_border_color_hover_dark": "#659f00", 135 | "button_primary_text_color": "#202020", 136 | "button_primary_text_color_dark": "#202020", 137 | "button_primary_text_color_hover": "#202020", 138 | "button_primary_text_color_hover_dark": "#202020", 139 | "button_secondary_background_fill": "#ffffff", 140 | "button_secondary_background_fill_dark": "#292929", 141 | "button_secondary_background_fill_hover": "#e2e2e2", 142 | "button_secondary_background_fill_hover_dark": "#202020", 143 | "button_secondary_border_color": "#5e5e5e", 144 | "button_secondary_border_color_dark": "#c6c6c6", 145 | "button_secondary_border_color_hover": "#5e5e5e", 146 | "button_secondary_border_color_hover_dark": "#c6c6c6", 147 | "button_secondary_text_color": "#5e5e5e", 148 | "button_secondary_text_color_dark": "#e2e2e2", 149 | "button_secondary_text_color_hover": "#343434", 150 | "button_secondary_text_color_hover_dark": "#ffffff", 151 | "button_shadow": "*shadow_drop", 152 | "button_shadow_active": "*shadow_inset", 153 | "button_shadow_hover": "*shadow_drop_lg", 154 | "button_small_padding": "*spacing_sm calc(2 * *spacing_sm)", 155 | "button_small_radius": "*radius_lg", 156 | "button_small_text_size": "*text_md", 157 | "button_small_text_weight": "400", 158 | "button_transition": "none", 159 | "chatbot_code_background_color": "*neutral_100", 160 | 
"chatbot_code_background_color_dark": "*neutral_800", 161 | "checkbox_background_color": "*background_fill_primary", 162 | "checkbox_background_color_dark": "*neutral_800", 163 | "checkbox_background_color_focus": "*checkbox_background_color", 164 | "checkbox_background_color_focus_dark": "*checkbox_background_color", 165 | "checkbox_background_color_hover": "*checkbox_background_color", 166 | "checkbox_background_color_hover_dark": "*checkbox_background_color", 167 | "checkbox_background_color_selected": "#659f00", 168 | "checkbox_background_color_selected_dark": "#659f00", 169 | "checkbox_border_color": "*neutral_300", 170 | "checkbox_border_color_dark": "*neutral_700", 171 | "checkbox_border_color_focus": "*secondary_500", 172 | "checkbox_border_color_focus_dark": "*secondary_500", 173 | "checkbox_border_color_hover": "*neutral_300", 174 | "checkbox_border_color_hover_dark": "*neutral_600", 175 | "checkbox_border_color_selected": "#659f00", 176 | "checkbox_border_color_selected_dark": "#659f00", 177 | "checkbox_border_radius": "*radius_sm", 178 | "checkbox_border_width": "2px", 179 | "checkbox_border_width_dark": "*input_border_width", 180 | "checkbox_check": "url(\"data:image/svg+xml,%3csvg viewBox='0 0 16 16' fill='white' xmlns='http://www.w3.org/2000/svg'%3e%3cpath d='M12.207 4.793a1 1 0 010 1.414l-5 5a1 1 0 01-1.414 0l-2-2a1 1 0 011.414-1.414L6.5 9.086l4.293-4.293a1 1 0 011.414 0z'/%3e%3c/svg%3e\")", 181 | "checkbox_label_background_fill": "#ffffff", 182 | "checkbox_label_background_fill_dark": "#292929", 183 | "checkbox_label_background_fill_hover": "#ffffff", 184 | "checkbox_label_background_fill_hover_dark": "#292929", 185 | "checkbox_label_background_fill_selected": "*checkbox_label_background_fill", 186 | "checkbox_label_background_fill_selected_dark": "*checkbox_label_background_fill", 187 | "checkbox_label_border_color": "#ffffff", 188 | "checkbox_label_border_color_dark": "#292929", 189 | "checkbox_label_border_color_hover": 
"*checkbox_label_border_color", 190 | "checkbox_label_border_color_hover_dark": "*checkbox_label_border_color", 191 | "checkbox_label_border_width": "0", 192 | "checkbox_label_border_width_dark": "*input_border_width", 193 | "checkbox_label_gap": "16px", 194 | "checkbox_label_padding": "", 195 | "checkbox_label_shadow": "none", 196 | "checkbox_label_text_color": "*body_text_color", 197 | "checkbox_label_text_color_dark": "*body_text_color", 198 | "checkbox_label_text_color_selected": "*checkbox_label_text_color", 199 | "checkbox_label_text_color_selected_dark": "*checkbox_label_text_color", 200 | "checkbox_label_text_size": "*text_md", 201 | "checkbox_label_text_weight": "400", 202 | "checkbox_shadow": "*input_shadow", 203 | "color_accent": "*primary_500", 204 | "color_accent_soft": "*primary_50", 205 | "color_accent_soft_dark": "*neutral_700", 206 | "container_radius": "*radius_lg", 207 | "embed_radius": "*radius_lg", 208 | "error_background_fill": "#fef2f2", 209 | "error_background_fill_dark": "*neutral_900", 210 | "error_border_color": "#fee2e2", 211 | "error_border_color_dark": "#ef4444", 212 | "error_border_width": "1px", 213 | "error_icon_color": "#b91c1c", 214 | "error_icon_color_dark": "#ef4444", 215 | "error_text_color": "#b91c1c", 216 | "error_text_color_dark": "#fef2f2", 217 | "font": "'NVIDIA Sans', 'ui-sans-serif', 'system-ui', sans-serif", 218 | "font_mono": "'JetBrains Mono', 'ui-monospace', 'Consolas', monospace", 219 | "form_gap_width": "1px", 220 | "input_background_fill": "white", 221 | "input_background_fill_dark": "*neutral_800", 222 | "input_background_fill_focus": "*secondary_500", 223 | "input_background_fill_focus_dark": "*secondary_600", 224 | "input_background_fill_hover": "*input_background_fill", 225 | "input_background_fill_hover_dark": "*input_background_fill", 226 | "input_border_color": "#d8d8d8", 227 | "input_border_color_dark": "#343434", 228 | "input_border_color_focus": "*secondary_300", 229 | "input_border_color_focus_dark": 
"*neutral_700", 230 | "input_border_color_hover": "*input_border_color", 231 | "input_border_color_hover_dark": "*input_border_color", 232 | "input_border_width": "2px", 233 | "input_padding": "*spacing_xl", 234 | "input_placeholder_color": "*neutral_400", 235 | "input_placeholder_color_dark": "*neutral_500", 236 | "input_radius": "*radius_lg", 237 | "input_shadow": "0 0 0 *shadow_spread transparent, *shadow_inset", 238 | "input_shadow_focus": "0 0 0 *shadow_spread *secondary_50, *shadow_inset", 239 | "input_shadow_focus_dark": "0 0 0 *shadow_spread *neutral_700, *shadow_inset", 240 | "input_text_size": "*text_md", 241 | "input_text_weight": "400", 242 | "layout_gap": "*spacing_xxl", 243 | "link_text_color": "*secondary_600", 244 | "link_text_color_active": "*secondary_600", 245 | "link_text_color_active_dark": "*secondary_500", 246 | "link_text_color_dark": "*secondary_500", 247 | "link_text_color_hover": "*secondary_700", 248 | "link_text_color_hover_dark": "*secondary_400", 249 | "link_text_color_visited": "*secondary_500", 250 | "link_text_color_visited_dark": "*secondary_600", 251 | "loader_color": "*color_accent", 252 | "name": "default", 253 | "neutral_100": "#e2e2e2", 254 | "neutral_200": "#d8d8d8", 255 | "neutral_300": "#c6c6c6", 256 | "neutral_400": "#8f8f8f", 257 | "neutral_50": "#f2f2f2", 258 | "neutral_500": "#767676", 259 | "neutral_600": "#5e5e5e", 260 | "neutral_700": "#343434", 261 | "neutral_800": "#292929", 262 | "neutral_900": "#202020", 263 | "neutral_950": "#121212", 264 | "panel_background_fill": "*background_fill_secondary", 265 | "panel_background_fill_dark": "*background_fill_secondary", 266 | "panel_border_color": "*border_color_primary", 267 | "panel_border_color_dark": "*border_color_primary", 268 | "panel_border_width": "0", 269 | "primary_100": "#caf087", 270 | "primary_200": "#b6e95d", 271 | "primary_300": "#9fd73d", 272 | "primary_400": "#76b900", 273 | "primary_50": "#e4fabe", 274 | "primary_500": "#659f00", 275 | "primary_600": 
"#538300", 276 | "primary_700": "#4d6721", 277 | "primary_800": "#253a00", 278 | "primary_900": "#1d2e00", 279 | "primary_950": "#172400", 280 | "prose_header_text_weight": "600", 281 | "prose_text_size": "*text_md", 282 | "prose_text_weight": "400", 283 | "radio_circle": "url(\"data:image/svg+xml,%3csvg viewBox='0 0 16 16' fill='white' xmlns='http://www.w3.org/2000/svg'%3e%3ccircle cx='8' cy='8' r='3'/%3e%3c/svg%3e\")", 284 | "radius_lg": "0px", 285 | "radius_md": "0px", 286 | "radius_sm": "0px", 287 | "radius_xl": "0px", 288 | "radius_xs": "0px", 289 | "radius_xxl": "0px", 290 | "radius_xxs": "0px", 291 | "secondary_100": "#cde6fa", 292 | "secondary_200": "#badef8", 293 | "secondary_300": "#9accf2", 294 | "secondary_400": "#3a96d9", 295 | "secondary_50": "#e9f4fb", 296 | "secondary_500": "#2378ca", 297 | "secondary_600": "#2a63ba", 298 | "secondary_700": "#013076", 299 | "secondary_800": "#00265e", 300 | "secondary_900": "#001e4b", 301 | "secondary_950": "#00112c", 302 | "section_header_text_size": "*text_md", 303 | "section_header_text_weight": "500", 304 | "shadow_drop": "none", 305 | "shadow_drop_lg": "0 1px 3px 0 rgb(0 0 0 / 0.1), 0 1px 2px -1px rgb(0 0 0 / 0.1)", 306 | "shadow_inset": "rgba(0,0,0,0.05) 0px 2px 4px 0px inset", 307 | "shadow_spread": "3px", 308 | "shadow_spread_dark": "1px", 309 | "slider_color": "#9fd73d", 310 | "spacing_lg": "8px", 311 | "spacing_md": "6px", 312 | "spacing_sm": "4px", 313 | "spacing_xl": "10px", 314 | "spacing_xs": "2px", 315 | "spacing_xxl": "16px", 316 | "spacing_xxs": "1px", 317 | "stat_background_fill": "linear-gradient(to right, *primary_400, *primary_200)", 318 | "stat_background_fill_dark": "linear-gradient(to right, *primary_400, *primary_600)", 319 | "table_border_color": "*neutral_300", 320 | "table_border_color_dark": "*neutral_700", 321 | "table_even_background_fill": "white", 322 | "table_even_background_fill_dark": "*neutral_950", 323 | "table_odd_background_fill": "*neutral_50", 324 | 
"table_odd_background_fill_dark": "*neutral_900", 325 | "table_radius": "*radius_lg", 326 | "table_row_focus": "*color_accent_soft", 327 | "table_row_focus_dark": "*color_accent_soft", 328 | "text_lg": "16px", 329 | "text_md": "14px", 330 | "text_sm": "12px", 331 | "text_xl": "22px", 332 | "text_xs": "10px", 333 | "text_xxl": "26px", 334 | "text_xxs": "9px" 335 | } 336 | } 337 | -------------------------------------------------------------------------------- /code/chatui/chat_client.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
class ChatClient:
    """A client for connecting to the langchain-esque service."""

    def __init__(self, server_url: str, model_name: str) -> None:
        """Initialize the client.

        :param server_url: Base URL of the chat API server (no trailing slash).
        :param model_name: Friendly name of the hosted LLM model.
        """
        self.server_url = server_url
        self._model_name = model_name
        # Default model identifier understood by the server.
        self.default_model = "local"

    @property
    def model_name(self) -> str:
        """Return the friendly model name."""
        return self._model_name

    def search(
        self, prompt: str
    ) -> typing.List[typing.Dict[str, typing.Union[str, float]]]:
        """Search for relevant documents and return json data.

        :param prompt: The query text used for document retrieval.
        :returns: The parsed JSON response from the ``documentSearch`` endpoint.
        """
        data = {"content": prompt, "num_docs": 4}
        headers = {"accept": "application/json", "Content-Type": "application/json"}
        url = f"{self.server_url}/documentSearch"
        _LOGGER.debug(
            "looking up documents - %s", str({"server_url": url, "post_data": data})
        )

        with requests.post(url, headers=headers, json=data, timeout=30) as req:
            response = req.json()
            return typing.cast(
                typing.List[typing.Dict[str, typing.Union[str, float]]], response
            )

    def predict(
        self,
        query: str,
        mode: str,
        local_model_id: str,
        nvcf_model_id: str,
        nim_model_ip: str,
        nim_model_port: str,
        nim_model_id: str,
        temp_slider: float,
        top_p_slider: float,
        freq_pen_slider: float,
        pres_pen_slider: float,
        use_knowledge_base: bool,
        num_tokens: int
    ) -> typing.Generator[str, None, None]:
        """Make a model prediction, yielding decoded response chunks as they stream in.

        :param query: The user question to send to the ``generate`` endpoint.
        :param mode: Inference mode forwarded to the server.
        :param use_knowledge_base: Whether the server should augment with retrieved docs.
        :param num_tokens: Maximum number of tokens to generate.
        :yields: UTF-8 decoded chunks of the streamed response body.
        """
        data = {
            "question": query,
            "context": "",
            "use_knowledge_base": use_knowledge_base,
            "num_tokens": num_tokens,
            "inference_mode": mode,
            "local_model_id": local_model_id,
            "nvcf_model_id": nvcf_model_id,
            "nim_model_ip": nim_model_ip,
            "nim_model_port": nim_model_port,
            "nim_model_id": nim_model_id,
            "temp": temp_slider,
            "top_p": top_p_slider,
            "freq_pen": freq_pen_slider,
            "pres_pen": pres_pen_slider,
        }
        url = f"{self.server_url}/generate"
        _LOGGER.info(
            "making inference request - %s", str({"server_url": url, "post_data": data})
        )
        # Also print to stdout; elsewhere in the app stdout is redirected to a
        # log file surfaced in the UI console, so this is intentional duplication.
        msg = str({"server_url": url, "post_data": data})
        print(f"making inference request - {msg}")

        # NOTE(review): timeout=10 applies to the connect and per-chunk reads of
        # the stream, not total generation time — confirm this is long enough.
        with requests.post(url, stream=True, json=data, timeout=10) as req:
            for chunk in req.iter_content(16):
                yield chunk.decode("UTF-8")

    def upload_documents(self, file_paths: typing.List[str]) -> None:
        """Upload documents to the kb.

        :param file_paths: Local filesystem paths of the documents to upload.
        """
        url = f"{self.server_url}/uploadDocument"
        headers = {
            "accept": "application/json",
        }

        for fpath in file_paths:
            mime_type, _ = mimetypes.guess_type(fpath)

            _LOGGER.debug(
                "uploading file - %s",
                str({"server_url": url, "file": fpath}),
            )

            # Open via a context manager so the handle is always closed, even if
            # the request raises. The previous code opened the file inline and
            # never closed it, leaking one file descriptor per upload.
            with open(fpath, "rb") as fhandle:
                files = {"file": (fpath, fhandle, mime_type)}
                _ = requests.post(
                    url, headers=headers, files=files, verify=False, timeout=120  # type: ignore [arg-type]
                )  # nosec # verify=false is intentional for now
@configclass
class AppConfig(ConfigWizard):
    """Configuration class for the application.

    :cvar server_url: The location of the chat API server.
    :type server_url: str
    :cvar server_port: The port on which the chat server listens for HTTP requests.
    :type server_port: str
    :cvar server_prefix: The URL prefix on which the server is running.
    :type server_prefix: str
    :cvar model_name: The name of the hosted LLM model.
    :type model_name: str
    """

    server_url: str = configfield(
        "serverUrl",
        default="http://localhost",
        help_txt="The location of the chat API server.",
    )
    server_port: str = configfield(
        "serverPort",
        default="8000",
        help_txt="The port on which the chat server is listening for HTTP requests.",
    )
    server_prefix: str = configfield(
        "serverPrefix",
        default="/projects/retrieval-augmented-generation/applications/rag-api/",
        help_txt="The prefix on which the server is running.",
    )
    model_name: str = configfield(
        "modelName",
        default="local",
        help_txt="The name of the hosted LLM model.",
    )
2 | # SPDX-License-Identifier: Apache-2.0 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """A module containing utilities for defining application configuration. 17 | 18 | This module provides a configuration wizard class that can read configuration data from YAML, JSON, and environment 19 | variables. The configuration wizard is based heavily off of the JSON and YAML wizards from the `dataclass-wizard` 20 | Python package. That package is in-turn based heavily off of the built-in `dataclass` module. 21 | 22 | This module adds Environment Variable parsing to config file reading. 
def configfield(
    name: str, *, env: bool = True, help_txt: str = "", **kwargs: Any
) -> JSONField:
    """Create a data class field with the specified name in JSON format.

    :param name: The name of the field.
    :type name: str
    :param env: Whether this field should be configurable from an environment variable.
    :type env: bool
    :param help_txt: The description of this field that is used in help docs.
    :type help_txt: str
    :param **kwargs: Optional keyword arguments to customize the JSON field. More information here:
        https://dataclass-wizard.readthedocs.io/en/latest/dataclass_wizard.html#dataclass_wizard.json_field
    :type **kwargs: Any
    :returns: A JSONField instance with the specified name and optional parameters.
    :rtype: JSONField

    :raises TypeError: If the provided name is not a string.
    """
    # sanitize specified name
    if not isinstance(name, str):
        raise TypeError("Provided name must be a string.")
    json_name = to_camel_case(name)

    # Copy any caller-supplied metadata before annotating it. The previous code
    # wrote the "env" and "help" keys directly into the dict object passed by
    # the caller, mutating the caller's data as a side effect.
    meta = dict(kwargs.get("metadata", {}))
    meta["env"] = env
    meta["help"] = help_txt
    kwargs["metadata"] = meta

    # create the data class field
    return json_field(json_name, **kwargs)
class _Color:
    """A collection of ANSI escape codes used when writing output to the shell."""

    # pylint: disable=too-few-public-methods; this class does not require methods.

    PURPLE = "\033[95m"
    BLUE = "\033[94m"
    GREEN = "\033[92m"
    YELLOW = "\033[93m"
    RED = "\033[91m"
    BOLD = "\033[1m"
    UNDERLINE = "\033[4m"
    END = "\033[0m"


class ConfigWizard(JSONWizard, YAMLWizard):  # type: ignore[misc] # dataclass-wizard doesn't provide stubs
    """A configuration wizard class that can read configuration data from YAML, JSON, and environment variables."""

    # pylint: disable=arguments-differ,arguments-renamed; this class intentionally reduces arguments for some methods.

    @classmethod
    def print_help(
        cls,
        help_printer: Callable[[str], Any],
        *,
        env_parent: Optional[str] = None,
        json_parent: Optional[Tuple[str, ...]] = None,
    ) -> None:
        """Print the help documentation for the application configuration with the provided `write` function.

        The output is YAML-like: one entry per config field with its default value,
        help text, type, and the environment variable that can override it.
        Recurses into embedded config classes.

        :param help_printer: The `write` function that will be used to output the data.
        :param help_printer: Callable[[str], None]
        :param env_parent: The name of the parent environment variable. Leave blank, used for recursion.
        :type env_parent: Optional[str]
        :param json_parent: The name of the parent JSON key. Leave blank, used for recursion.
        :type json_parent: Optional[Tuple[str, ...]]
        :returns: None; all output goes through ``help_printer``.
        """
        # Only the top-level (non-recursive) call prints the YAML document marker.
        if not env_parent:
            env_parent = ""
            help_printer("---\n")
        if not json_parent:
            json_parent = ()

        for (
            _,
            val,
        ) in (
            cls.__dataclass_fields__.items()  # pylint: disable=no-member; false positive
        ):  # pylint: disable=no-member; member is added by dataclass.
            jsonname = val.json.keys[0]
            envname = jsonname.upper()
            full_envname = f"{ENV_BASE}{env_parent}_{envname}"
            # A field whose type itself exposes `envvars` is a nested ConfigWizard.
            is_embedded_config = hasattr(val.type, "envvars")

            # print the help data
            indent = len(json_parent) * 2
            if is_embedded_config:
                default = ""
            elif not isinstance(val.default_factory, _MISSING_TYPE):
                default = val.default_factory()
            elif isinstance(val.default, _MISSING_TYPE):
                default = "NO-DEFAULT-VALUE"
            else:
                default = val.default
            help_printer(
                f"{_Color.BOLD}{' ' * indent}{jsonname}:{_Color.END} {default}\n"
            )

            # print comments
            if is_embedded_config:
                indent += 2
            if val.metadata.get("help"):
                help_printer(f"{' ' * indent}# {val.metadata['help']}\n")
            if not is_embedded_config:
                typestr = getattr(val.type, "__name__", None) or str(val.type).replace(
                    "typing.", ""
                )
                help_printer(f"{' ' * indent}# Type: {typestr}\n")
            if val.metadata.get("env", True):
                help_printer(f"{' ' * indent}# ENV Variable: {full_envname}\n")
            # if not is_embedded_config:
            help_printer("\n")

            # Recurse into embedded config classes with extended prefixes.
            if is_embedded_config:
                new_env_parent = f"{env_parent}_{envname}"
                new_json_parent = json_parent + (jsonname,)
                val.type.print_help(
                    help_printer, env_parent=new_env_parent, json_parent=new_json_parent
                )

        help_printer("\n")

    @classmethod
    def envvars(
        cls,
        env_parent: Optional[str] = None,
        json_parent: Optional[Tuple[str, ...]] = None,
    ) -> List[Tuple[str, Tuple[str, ...], type]]:
        """Calculate valid environment variables and their config structure location.

        :param env_parent: The name of the parent environment variable.
        :type env_parent: Optional[str]
        :param json_parent: The name of the parent JSON key.
        :type json_parent: Optional[Tuple[str, ...]]
        :returns: A list of tuples with one item per configuration value. Each item will have the environment variable,
            a tuple to the path in configuration, and the type of the value.
        :rtype: List[Tuple[str, Tuple[str, ...], type]]
        """
        if not env_parent:
            env_parent = ""
        if not json_parent:
            json_parent = ()
        output = []

        for (
            _,
            val,
        ) in (
            cls.__dataclass_fields__.items()  # pylint: disable=no-member; false positive
        ):  # pylint: disable=no-member; member is added by dataclass.
            jsonname = val.json.keys[0]
            envname = jsonname.upper()
            full_envname = f"{ENV_BASE}{env_parent}_{envname}"
            is_embedded_config = hasattr(val.type, "envvars")

            # add entry to output list; embedded configs contribute their own
            # (recursively prefixed) entries instead of one flat entry.
            if is_embedded_config:
                new_env_parent = f"{env_parent}_{envname}"
                new_json_parent = json_parent + (jsonname,)
                output += val.type.envvars(
                    env_parent=new_env_parent, json_parent=new_json_parent
                )
            elif val.metadata.get("env", True):
                output += [(full_envname, json_parent + (jsonname,), val.type)]

        return output

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "ConfigWizard":
        """Create a ConfigWizard instance from a dictionary.

        Environment variables are merged into ``data`` first via ``update_dict``
        with ``overwrite=False``, so values already present in the dictionary
        take precedence over the environment.

        :param data: The dictionary containing the configuration data.
        :type data: Dict[str, Any]
        :returns: A ConfigWizard instance created from the input dictionary.
        :rtype: ConfigWizard

        :raises RuntimeError: If the configuration data is not a dictionary.
        """
        # sanitize data
        if not data:
            data = {}
        if not isinstance(data, dict):
            raise RuntimeError("Configuration data is not a dictionary.")

        # parse env variables
        for envvar in cls.envvars():
            var_name, conf_path, var_type = envvar
            var_value = os.environ.get(var_name)
            if var_value:
                # Values that look like JSON (numbers, bools, lists) are parsed;
                # anything else stays a plain string.
                var_value = try_json_load(var_value)
                update_dict(data, conf_path, var_value)
                _LOGGER.debug(
                    "Found EnvVar Config - %s:%s = %s",
                    var_name,
                    str(var_type),
                    repr(var_value),
                )

        LoadMeta(key_transform="CAMEL").bind_to(cls)
        return fromdict(cls, data)  # type: ignore[no-any-return] # dataclass-wizard doesn't provide stubs

    @classmethod
    def from_file(cls, filepath: str) -> Optional["ConfigWizard"]:
        """Load the application configuration from the specified file.

        The file must be either in JSON or YAML format.

        :returns: The fully processed configuration file contents. If the file was unreadable, None will be returned.
        :rtype: Optional["ConfigWizard"]
        """
        # open the file
        try:
            # pylint: disable-next=consider-using-with; using a with would make exception handling even more ugly
            file = open(filepath, encoding="utf-8")
        except FileNotFoundError:
            _LOGGER.error("The configuration file cannot be found.")
            file = None
        except PermissionError:
            _LOGGER.error(
                "Permission denied when trying to read the configuration file."
            )
            file = None
        if not file:
            return None

        # read the file
        try:
            data = read_json_or_yaml(file)
        except ValueError as err:
            _LOGGER.error(
                "Configuration file must be valid JSON or YAML. The following errors occured:\n%s",
                str(err),
            )
            # NOTE(review): data=None sends us into the `else` branch below, so a
            # malformed file yields a *default* config rather than None — confirm
            # this is intended; the `config = None` here is immediately dead.
            data = None
            config = None
        finally:
            file.close()

        # parse the file
        if data:
            try:
                config = cls.from_dict(data)
            except errors.MissingFields as err:
                _LOGGER.error(
                    "Configuration is missing required fields: \n%s", str(err)
                )
                config = None
            except errors.ParseError as err:
                _LOGGER.error("Invalid configuration value provided:\n%s", str(err))
                config = None
        else:
            # Empty or unparseable file: fall back to defaults + env vars only.
            config = cls.from_dict({})

        return config
def read_json_or_yaml(stream: TextIO) -> Dict[str, Any]:
    """Read a file without knowing if it is JSON or YAML formatted.

    The file will first be assumed to be JSON formatted. If this fails, an attempt to parse the file with the YAML
    parser will be made. If both of these fail, an exception will be raised that contains the exception strings returned
    by both the parsers.

    :param stream: An IO stream that allows seeking.
    :type stream: typing.TextIO
    :returns: The parsed file contents.
    :rtype: typing.Dict[str, typing.Any]:
    :raises ValueError: If the IO stream is not seekable or if the file doesn't appear to be JSON or YAML formatted.
    """
    exceptions: Dict[str, Union[None, ValueError, yaml.error.YAMLError]] = {
        "JSON": None,
        "YAML": None,
    }
    data: Dict[str, Any]

    # ensure we can rewind the file between parse attempts
    if not stream.seekable():
        raise ValueError("The provided stream must be seekable.")

    # attempt to read json
    try:
        data = json.loads(stream.read())
    except ValueError as err:
        exceptions["JSON"] = err
    else:
        return data
    finally:
        # rewind so the YAML attempt (if needed) sees the whole stream
        stream.seek(0)

    # attempt to read yaml
    try:
        data = yaml.safe_load(stream.read())
    except (yaml.error.YAMLError, ValueError) as err:
        exceptions["YAML"] = err
    else:
        return data

    # neither json nor yaml: surface both parsers' errors to the caller
    err_msg = "\n\n".join(
        [key + " Parser Errors:\n" + str(val) for key, val in exceptions.items()]
    )
    raise ValueError(err_msg)


def try_json_load(value: str) -> Any:
    """Try parsing the value as JSON and silently ignore errors.

    :param value: The value on which a JSON load should be attempted.
    :type value: str
    :returns: Either the parsed JSON or the provided value.
    :rtype: typing.Any
    """
    try:
        return json.loads(value)
    except json.JSONDecodeError:
        return value


def update_dict(
    data: Dict[str, Any],
    path: Tuple[str, ...],
    value: Any,
    overwrite: bool = False,
) -> None:
    """Update a dictionary with a new value at a given path.

    :param data: The dictionary to be updated.
    :type data: Dict[str, Any]
    :param path: The path to the key that should be updated.
    :type path: Tuple[str, ...]
    :param value: The new value to be set at the specified path.
    :type value: Any
    :param overwrite: If True, overwrite the existing value. Otherwise, don't update if the key already exists.
    :type overwrite: bool
    :returns: None
    """
    end = len(path)
    target = data
    for idx, key in enumerate(path, 1):
        # on the last field in path, update the dict if necessary.
        # Membership is checked with `in` rather than truthiness: the previous
        # `not target.get(key)` treated existing falsy values (0, "", False)
        # as missing and clobbered them, contradicting the documented contract.
        if idx == end:
            if overwrite or key not in target:
                target[key] = value
            return

        # verify the next hop exists without destroying an existing falsy value
        if key not in target:
            target[key] = {}

        # if the next hop is not a dict, exit
        if not isinstance(target.get(key), dict):
            return

        # get next hop
        target = target.get(key)  # type: ignore[assignment] # type has already been enforced.
15 | 16 | """This module contains definitions for all the chatui pages.""" 17 | from chatui.pages import converse 18 | 19 | __all__ = ["converse"] 20 | -------------------------------------------------------------------------------- /code/chatui/pages/converse.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | ### This module contains the chatui gui for having a conversation. ### 17 | 18 | import functools 19 | from typing import Any, Dict, List, Tuple, Union 20 | 21 | import gradio as gr 22 | import shutil 23 | import os 24 | import subprocess 25 | import time 26 | import sys 27 | import torch 28 | import logging 29 | import gc 30 | 31 | from chatui import assets, chat_client 32 | from chatui.utils import logger 33 | from diffusers import StableDiffusionXLPipeline, DiffusionPipeline 34 | 35 | PATH = "/" 36 | TITLE = "SDXL Image Generation" 37 | OUTPUT_TOKENS = 250 38 | MAX_DOCS = 5 39 | GENERATED_IMG_DIR = "/project/data/generated_images" 40 | BASE_MODEL = "stabilityai/stable-diffusion-xl-base-1.0" 41 | 42 | ### Load in CSS here for components that need custom styling. 
### 43 | 44 | _LOCAL_CSS = """ 45 | #contextbox { 46 | overflow-y: scroll !important; 47 | max-height: 400px; 48 | } 49 | 50 | #params .tabs { 51 | display: flex; 52 | flex-direction: column; 53 | flex-grow: 1; 54 | } 55 | #params .tabitem[style="display: block;"] { 56 | flex-grow: 1; 57 | display: flex !important; 58 | } 59 | #params .gap { 60 | flex-grow: 1; 61 | } 62 | #params .form { 63 | flex-grow: 1 !important; 64 | } 65 | #params .form > :last-child{ 66 | flex-grow: 1; 67 | } 68 | #accordion { 69 | } 70 | #rag-inputs .svelte-1gfkn6j { 71 | color: #76b900; 72 | } 73 | #rag-inputs .svelte-s1r2yt { 74 | color: #76b900; 75 | } 76 | """ 77 | 78 | INSTRUCTIONS = """ 79 |
def build_page(client: chat_client.ChatClient) -> gr.Blocks:
    """
    Build the gradio page to be mounted in the frame.

    Parameters:
        client (chat_client.ChatClient): The chat client running the application.

    Returns:
        page (gr.Blocks): A Gradio page.
    """
    kui_theme, kui_styles = assets.load_theme("kaizen")

    # Get a list of models: every non-hidden subdirectory of /project/models is
    # treated as a fine-tuned (LoRA) model; the base model is prepended first.
    entries = os.listdir("/project/models")
    models = [entry for entry in entries if os.path.isdir(os.path.join("/project/models", entry)) and entry[0] != '.']
    models.insert(0, BASE_MODEL)

    # Prep base model (loaded at module import time; referenced here as a global)
    logging.basicConfig(level=logging.INFO)
    global pipe

    with gr.Blocks(title=TITLE, theme=kui_theme, css=kui_styles + _LOCAL_CSS) as page:
        gr.Markdown(f"# {TITLE}")

        """ Keep state of which model pipe to use. """

        # gr.State holds the active pipeline per browser session, wrapped in a
        # dict so event handlers can swap it out wholesale.
        current_pipe = gr.State({"pipe": pipe})

        """ Build the Chat Application. """

        with gr.Row(equal_height=True):

            # Left Column will display the chatbot
            with gr.Column(scale=15, min_width=350):

                # Main chatbot panel.
                with gr.Row(equal_height=True):
                    with gr.Column(min_width=350):
                        chatbot = gr.Chatbot(show_label=False, height=575)

                # Message box for user input
                with gr.Row(equal_height=True):
                    with gr.Column(scale=3, min_width=450):
                        msg = gr.Textbox(
                            show_label=False,
                            placeholder="Enter your image prompt and press ENTER",
                            container=False,
                            interactive=True,
                        )

                    with gr.Column(scale=1, min_width=150):
                        clear = gr.ClearButton([msg, chatbot], value="Clear history")

            # Hidden column to be rendered when the user collapses all settings.
            with gr.Column(scale=1, min_width=100, visible=False) as hidden_settings_column:
                show_settings = gr.Button(value="< Expand", size="sm")

            # Right column to display all relevant settings
            with gr.Column(scale=10, min_width=350) as settings_column:
                with gr.Tabs(selected=0) as settings_tabs:
                    with gr.TabItem("Settings", id=0) as model_settings:

                        gr.Markdown(INSTRUCTIONS)

                        model = gr.Dropdown(models,
                                            label="Select a model",
                                            elem_id="rag-inputs",
                                            value=BASE_MODEL)

                        # Read-only console mirroring the redirected stdout log.
                        logs = gr.Textbox(label="Console",
                                         elem_id="rag-inputs",
                                         lines=12,
                                         max_lines=12,
                                         interactive=False)

                    with gr.TabItem("Hide All Settings", id=1) as hide_all_settings:
                        gr.Markdown("")

        # Collapse the settings column; show the thin "Expand" column instead.
        def _toggle_hide_all_settings():
            print("--- SETTINGS: Hiding Settings ---")
            return {
                settings_column: gr.update(visible=False),
                hidden_settings_column: gr.update(visible=True),
            }

        # Restore the settings column and reselect the Settings tab.
        def _toggle_show_all_settings():
            print("--- SETTINGS: Expanding Settings ---")
            return {
                settings_column: gr.update(visible=True),
                settings_tabs: gr.update(selected=0),
                hidden_settings_column: gr.update(visible=False),
            }

        hide_all_settings.select(_toggle_hide_all_settings, None, [settings_column, hidden_settings_column])
        show_settings.click(_toggle_show_all_settings, None, [settings_column, settings_tabs, hidden_settings_column])

        # Delete every previously generated .png when the user clears history.
        def clear_imgs():
            print("--- IMAGES: Clearing Images... ---")
            for file in os.listdir(GENERATED_IMG_DIR):
                if not file.endswith(".png"):
                    continue
                # NOTE(review): log tag says SETTINGS but this is image cleanup.
                print("--- SETTINGS: Deleting '" + file + "' ---")
                os.remove(os.path.join(GENERATED_IMG_DIR, file))

        clear.click(clear_imgs, [], [])

        # Reload the pipeline when the dropdown changes. `model` here shadows
        # the gr.Dropdown component in the enclosing scope (it is the selected
        # string value, not the component).
        def load_model(model: str):
            # NOTE(review): this `pipe = None` binds a *local* name; it does not
            # release the previously loaded global/session pipeline, so the
            # gc.collect()/empty_cache() below may not actually free GPU memory
            # before the new model loads — TODO confirm memory behavior.
            pipe = None
            gc.collect()
            torch.cuda.empty_cache()
            # Both branches load the same base pipeline; the non-base branch
            # additionally applies LoRA weights from /project/models/<model>.
            if model == BASE_MODEL:
                print("--- MODELS: Loading Model " + model + " ---")
                pipe = StableDiffusionXLPipeline.from_pretrained(
                    BASE_MODEL, torch_dtype=torch.float16, variant="fp16", use_safetensors=True
                )
                print("--- MODELS: Configuring Pipe ---")
                pipe.to("cuda")
            else:
                print("--- MODELS: Loading Model: " + model + " ---")
                pipe = StableDiffusionXLPipeline.from_pretrained(
                    BASE_MODEL, torch_dtype=torch.float16, variant="fp16", use_safetensors=True
                )
                print("--- MODELS: Configuring Pipe ---")
                pipe.to("cuda")
                print("--- MODELS: Loading LoRA Weights for: " + model + " ---")
                pipe.load_lora_weights("/project/models/" + model)
            print("--- MODELS: Model is ready for inference ---")
            return {
                current_pipe: {"pipe": pipe},
                msg: gr.update(visible=True),
            }

        model.change(load_model, [model], [current_pipe, msg])

        # Poll the log file every second to refresh the Console textbox.
        page.load(logger.read_logs, None, logs, every=1)

        """ This helper function builds out the submission function call when a user submits a query. """

        _my_build_stream = functools.partial(_stream_predict, client)
        msg.submit(
            _my_build_stream, [msg,
                               chatbot,
                               current_pipe], [msg, chatbot]
        )

    page.queue()
    return page
def create_img_dir():
    """Create the generated-image output directory if it does not already exist."""
    if not os.path.exists(GENERATED_IMG_DIR):
        print("--- IMAGES: Creating Image Directory ---")
        # exist_ok guards against a concurrent creation between the check above
        # and this call.
        os.makedirs(GENERATED_IMG_DIR, exist_ok=True)

def get_image_count():
    """Count all .png files in GENERATED_IMG_DIR."""
    count = 0
    for filename in os.listdir(GENERATED_IMG_DIR):
        if filename.endswith('.png'):
            count += 1
    return count

def gen_new_img_name():
    """Return a path for the next generated image that does not clobber an existing file.

    The previous implementation derived the name purely from the current .png
    count; after non-contiguous deletions (e.g. removing generated_image-0
    while -1 and -2 remain) the count-based name already existed and the new
    image silently overwrote it. Probe upward from the count until a free
    index is found instead.
    """
    idx = get_image_count()
    while os.path.exists(GENERATED_IMG_DIR + "/generated_image-" + str(idx) + ".png"):
        idx += 1
    print("--- IMAGES: Created Image 'generated_image-" + str(idx) + ".png' ---")
    return GENERATED_IMG_DIR + "/generated_image-" + str(idx) + ".png"

def generate_image(pipe, prompt):
    """Run the diffusion pipeline on the prompt, save the image, and return its path.

    Parameters:
        pipe: A loaded diffusion pipeline; called as pipe(prompt=...).
        prompt (str): The user's image-generation prompt.

    Returns:
        str: Filesystem path of the saved .png.
    """
    create_img_dir()
    print("--- IMAGES: Image is generating... ---")
    image = pipe(prompt=prompt).images[0]
    name = gen_new_img_name()
    image.save(name)
    return name

""" This helper function executes and generates a response to the user query. """

def _stream_predict(
    client: chat_client.ChatClient,
    prompt: str,
    chat_history: List[Tuple[str, str]],
    pipe,
) -> Any:
    """Yield chat-history updates while generating an image for the prompt.

    Yields a "Generating..." placeholder first, then either the generated image
    (as a (filename,) tuple the chatbot renders inline) or an error message.
    """
    try:
        yield "", chat_history + [[prompt, "Generating your image..."]]
        filename = generate_image(pipe["pipe"], prompt)
        yield "", chat_history + [[prompt, (filename,)]]
    except Exception as e:
        yield "", chat_history + [[prompt, "*** ERR: Unable to process query. ***\n\nException: " + str(e)]]
-------------------------------------------------------------------------------- /code/chatui/static/favicon.ico: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVIDIA/workbench-example-sdxl-customization/3a13b42bb9b9dd27322fc0c57dabcaa73c126451/code/chatui/static/favicon.ico -------------------------------------------------------------------------------- /code/chatui/static/index.html: -------------------------------------------------------------------------------- 1 |
-------------------------------------------------------------------------------- /code/chatui/static/kb.html: -------------------------------------------------------------------------------- 1 |
-------------------------------------------------------------------------------- /code/chatui/static/next.svg: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /code/chatui/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | -------------------------------------------------------------------------------- /code/chatui/utils/logger.py: -------------------------------------------------------------------------------- 1 | # SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # SPDX-License-Identifier: Apache-2.0 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | # Logger class for capturing stdout 17 | import sys 18 | 19 | class Logger: 20 | def __init__(self, filename): 21 | self.terminal = sys.stdout 22 | self.log = open(filename, "w") 23 | 24 | def write(self, message): 25 | self.terminal.write(message) 26 | self.log.write(message) 27 | 28 | def flush(self): 29 | self.terminal.flush() 30 | self.log.flush() 31 | 32 | def isatty(self): 33 | return False 34 | 35 | def read_logs(): 36 | sys.stdout.flush() 37 | with open("/project/code/output.log", "r") as f: 38 | return f.read() -------------------------------------------------------------------------------- /code/output.log: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVIDIA/workbench-example-sdxl-customization/3a13b42bb9b9dd27322fc0c57dabcaa73c126451/code/output.log -------------------------------------------------------------------------------- /data/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVIDIA/workbench-example-sdxl-customization/3a13b42bb9b9dd27322fc0c57dabcaa73c126451/data/.gitkeep -------------------------------------------------------------------------------- /data/charles-3/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVIDIA/workbench-example-sdxl-customization/3a13b42bb9b9dd27322fc0c57dabcaa73c126451/data/charles-3/.gitkeep -------------------------------------------------------------------------------- /data/my-data/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVIDIA/workbench-example-sdxl-customization/3a13b42bb9b9dd27322fc0c57dabcaa73c126451/data/my-data/.gitkeep 
-------------------------------------------------------------------------------- /data/scratch/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVIDIA/workbench-example-sdxl-customization/3a13b42bb9b9dd27322fc0c57dabcaa73c126451/data/scratch/.gitkeep -------------------------------------------------------------------------------- /data/toy-jensen/tj1.png: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:1885ff2f6007dea62b82460fbd659ec1a9a57d69b760553abf5733224dd880ed 3 | size 2967492 4 | -------------------------------------------------------------------------------- /data/toy-jensen/tj2.png: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:595556fb6e0a2ba44576cda03675a58e27f5b399779d8dc3c14755833f77d180 3 | size 1135508 4 | -------------------------------------------------------------------------------- /data/toy-jensen/tj3.png: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:794e290c5557ab60d7ed6756e5ffc84c966c06def7af10b7020d185cc7f78c40 3 | size 212813 4 | -------------------------------------------------------------------------------- /data/toy-jensen/tj4.png: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:f3e3189c7f94b6ea193501a6188cc65577b5b2ac9a6b97a96981b72306b301fd 3 | size 967498 4 | -------------------------------------------------------------------------------- /data/toy-jensen/tj5.png: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:1f9271cecd0f74cda99514e49e83995eb60a1fa621f799649aab6aeb1bec6e6c 3 | size 613242 
4 | -------------------------------------------------------------------------------- /data/toy-jensen/tj6.png: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:e3d3eb9fcdfe976ca6e38323125bb3ff19ade07bd698e38418dc142fe1672347 3 | size 1471810 4 | -------------------------------------------------------------------------------- /data/toy-jensen/tj7.png: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:8995e52d2520b29e5336bc985876ea5f29e19182cee0ccf94cfe558214a287e6 3 | size 2037986 4 | -------------------------------------------------------------------------------- /data/toy-jensen/tj8.png: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:f1925c59a9c8438508e21f5b0ce54206962ed1dacc58ee67c32b10f3b236fe20 3 | size 1668641 4 | -------------------------------------------------------------------------------- /models/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVIDIA/workbench-example-sdxl-customization/3a13b42bb9b9dd27322fc0c57dabcaa73c126451/models/.gitkeep -------------------------------------------------------------------------------- /postBuild.bash: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # This file contains bash commands that will be executed at the end of the container build process, 3 | # after all system packages and programming language specific package have been installed. 
4 | # 5 | # Note: This file may be removed if you don't need to use it 6 | 7 | sudo mkdir -p /mnt/cache/ 8 | sudo chown $NVWB_UID:$NVWB_GID /mnt/cache/ -------------------------------------------------------------------------------- /preBuild.bash: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # This file contains bash commands that will be executed at the beginning of the container build process, 3 | # before any system packages or programming language specific package have been installed. 4 | # 5 | # Note: This file may be removed if you don't need to use it 6 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py==1.4.0 2 | accelerate==1.7.0 3 | aiofiles==23.2.1 4 | altair==5.5.0 5 | annotated-types==0.7.0 6 | anyio==4.9.0 7 | argon2-cffi==23.1.0 8 | argon2-cffi-bindings==21.2.0 9 | arrow==1.3.0 10 | asttokens==2.4.0 11 | astunparse==1.6.3 12 | async-lru==2.0.5 13 | attrs==23.1.0 14 | audioread==3.0.0 15 | babel==2.17.0 16 | backcall==0.2.0 17 | beautifulsoup4==4.12.2 18 | bleach==6.0.0 19 | blis==0.7.10 20 | cachetools==5.3.1 21 | catalogue==2.0.9 22 | certifi==2023.7.22 23 | cffi==1.15.1 24 | charset-normalizer==3.2.0 25 | cmake==3.27.4.1 26 | comm==0.1.4 27 | confection==0.1.3 28 | contourpy==1.1.0 29 | cycler==0.11.0 30 | cymem==2.0.7 31 | Cython==3.0.2 32 | dataclass-wizard==0.35.0 33 | debugpy==1.8.0 34 | decorator==5.1.1 35 | defusedxml==0.7.1 36 | diffusers==0.32.2 37 | dm-tree==0.1.8 38 | einops==0.6.1 39 | exceptiongroup==1.3.0 40 | execnet==2.0.2 41 | executing==1.2.0 42 | expecttest==0.1.3 43 | fastapi==0.112.2 44 | fastjsonschema==2.18.0 45 | ffmpy==0.6.0 46 | filelock==3.12.4 47 | flash-attn==2.0.4 48 | fonttools==4.42.1 49 | fqdn==1.5.1 50 | gast==0.5.4 51 | gradio==4.35.0 52 | gradio_client==1.0.1 53 | grpcio==1.58.0 54 | h11==0.16.0 55 | hf-xet==1.1.3 56 
| httpcore==1.0.9 57 | httpx==0.28.1 58 | huggingface-hub==0.32.4 59 | hypothesis==5.35.1 60 | idna==3.10 61 | importlib_resources==6.5.2 62 | iniconfig==2.0.0 63 | inquirerpy==0.3.4 64 | intel-openmp==2021.4.0 65 | ipykernel==6.25.2 66 | ipython==8.15.0 67 | ipython-genutils==0.2.0 68 | ipywidgets==8.1.7 69 | isoduration==20.11.0 70 | jedi==0.19.0 71 | Jinja2==3.1.2 72 | joblib==1.3.2 73 | json5==0.9.14 74 | jsonpointer==3.0.0 75 | jsonschema==4.19.0 76 | jsonschema-specifications==2023.7.1 77 | jupyterlab>3.0 78 | kiwisolver==1.4.5 79 | langcodes==3.3.0 80 | librosa==0.9.2 81 | llvmlite==0.40.1 82 | Markdown==3.4.4 83 | markdown-it-py==3.0.0 84 | MarkupSafe==2.1.3 85 | matplotlib==3.7.3 86 | matplotlib-inline==0.1.6 87 | mdit-py-plugins==0.4.0 88 | mdurl==0.1.2 89 | mistune==3.0.1 90 | mkl==2021.1.1 91 | mkl-devel==2021.1.1 92 | mkl-include==2021.1.1 93 | mock==5.1.0 94 | mpmath==1.3.0 95 | murmurhash==1.0.9 96 | narwhals==1.42.0 97 | nbclient==0.8.0 98 | nbconvert==7.8.0 99 | nbformat==5.9.2 100 | nest-asyncio==1.5.7 101 | networkx==2.6.3 102 | ninja==1.11.1 103 | notebook==6.4.10 104 | notebook_shim==0.2.4 105 | numpy==1.22.2 106 | nvidia-cublas-cu12==12.4.5.8 107 | nvidia-cuda-cupti-cu12==12.4.127 108 | nvidia-cuda-nvrtc-cu12==12.4.127 109 | nvidia-cuda-runtime-cu12==12.4.127 110 | nvidia-cudnn-cu12==9.1.0.70 111 | nvidia-cufft-cu12==11.2.1.3 112 | nvidia-curand-cu12==10.3.5.147 113 | nvidia-cusolver-cu12==11.6.1.9 114 | nvidia-cusparse-cu12==12.3.1.170 115 | nvidia-cusparselt-cu12==0.6.2 116 | nvidia-dali-cuda120==1.29.0 117 | nvidia-nccl-cu12==2.21.5 118 | nvidia-nvjitlink-cu12==12.4.127 119 | nvidia-nvtx-cu12==12.4.127 120 | nvidia-pyindex==1.0.9 121 | oauthlib==3.2.2 122 | orjson==3.10.18 123 | overrides==7.7.0 124 | packaging==23.1 125 | pandocfilters==1.5.0 126 | parso==0.8.3 127 | peft==0.15.2 128 | pexpect==4.8.0 129 | pfzy==0.3.4 130 | pickleshare==0.7.5 131 | platformdirs==3.10.0 132 | pluggy==1.3.0 133 | polygraphy==0.49.0 134 | pooch==1.7.0 135 | 
preshed==3.0.8 136 | prettytable==3.9.0 137 | prometheus-client==0.17.1 138 | prompt-toolkit==3.0.39 139 | protobuf==4.24.3 140 | ptyprocess==0.7.0 141 | pure-eval==0.2.2 142 | pyasn1==0.5.0 143 | pyasn1-modules==0.3.0 144 | pybind11==2.11.1 145 | pycparser==2.21 146 | pydantic==2.8.2 147 | pydantic_core==2.20.1 148 | pydub==0.25.1 149 | Pygments==2.16.1 150 | pyparsing==3.1.1 151 | pytest==7.4.2 152 | pytest-flakefinder==1.1.0 153 | pytest-rerunfailures==12.0 154 | pytest-shard==0.1.2 155 | pytest-xdist==3.3.1 156 | python-dateutil==2.8.2 157 | python-hostlist==1.23.0 158 | python-json-logger==3.3.0 159 | python-multipart==0.0.20 160 | pytorch-quantization==2.1.2 161 | PyYAML==6.0.1 162 | pyzmq==25.1.1 163 | referencing==0.30.2 164 | regex==2023.8.8 165 | requests==2.31.0 166 | requests-oauthlib==1.3.1 167 | resampy==0.4.2 168 | rfc3339-validator==0.1.4 169 | rfc3986-validator==0.1.1 170 | rich==14.0.0 171 | rpds-py==0.10.3 172 | rsa==4.9 173 | ruff==0.11.13 174 | safetensors==0.5.3 175 | semantic-version==2.10.0 176 | Send2Trash==1.8.2 177 | shellingham==1.5.4 178 | six==1.16.0 179 | smart-open==6.4.0 180 | sniffio==1.3.1 181 | sortedcontainers==2.4.0 182 | soundfile==0.12.1 183 | soupsieve==2.5 184 | sphinx-glpi-theme==0.3 185 | srsly==2.4.7 186 | stack-data==0.6.2 187 | starlette==0.38.6 188 | sympy==1.13.1 189 | tabulate==0.9.0 190 | tbb==2021.10.0 191 | tensorboard==2.9.0 192 | tensorboard-data-server==0.6.1 193 | tensorboard-plugin-wit==1.8.1 194 | terminado==0.17.1 195 | thinc==8.1.12 196 | threadpoolctl==3.2.0 197 | tinycss2==1.2.1 198 | tokenizers==0.21.1 199 | toml==0.10.2 200 | tomli==2.0.1 201 | tomlkit==0.12.0 202 | torch==2.6.0 203 | torchvision==0.21.0 204 | tornado==6.3.3 205 | tqdm==4.66.1 206 | traitlets==5.9.0 207 | transformers==4.48.2 208 | triton==3.2.0 209 | types-dataclasses==0.6.6 210 | typing-inspection==0.4.1 211 | typing_extensions==4.14.0 212 | tzdata==2025.2 213 | uri-template==1.3.0 214 | urllib3==2.4.0 215 | uvicorn==0.34.3 216 | 
wasabi==1.1.2 217 | wcwidth==0.2.6 218 | webcolors==24.11.1 219 | webencodings==0.5.1 220 | websocket-client==1.8.0 221 | websockets==11.0.3 222 | Werkzeug==2.3.7 223 | widgetsnbextension==4.0.14 224 | xdoctest==1.0.2 225 | -------------------------------------------------------------------------------- /third-party/LICENSE.sdxl: -------------------------------------------------------------------------------- 1 | Copyright (c) 2023 Stability AI CreativeML Open RAIL++-M License dated July 26, 2023 2 | 3 | Section I: PREAMBLE Multimodal generative models are being widely adopted and used, and 4 | have the potential to transform the way artists, among other individuals, conceive and 5 | benefit from AI or ML technologies as a tool for content creation. Notwithstanding the 6 | current and potential benefits that these artifacts can bring to society at large, there 7 | are also concerns about potential misuses of them, either due to their technical 8 | limitations or ethical considerations. In short, this license strives for both the open 9 | and responsible downstream use of the accompanying model. When it comes to the open 10 | character, we took inspiration from open source permissive licenses regarding the grant 11 | of IP rights. Referring to the downstream responsible use, we added use-based 12 | restrictions not permitting the use of the model in very specific scenarios, in order 13 | for the licensor to be able to enforce the license in case potential misuses of the 14 | Model may occur. At the same time, we strive to promote open and responsible research on 15 | generative models for art and content generation. Even though downstream derivative 16 | versions of the model could be released under different licensing terms, the latter will 17 | always have to include - at minimum - the same use-based restrictions as the ones in the 18 | original license (this license). 
We believe in the intersection between open and 19 | responsible AI development; thus, this agreement aims to strike a balance between both 20 | in order to enable responsible open-science in the field of AI. This CreativeML Open 21 | RAIL++-M License governs the use of the model (and its derivatives) and is informed by 22 | the model card associated with the model. NOW THEREFORE, You and Licensor agree as 23 | follows: Definitions "License" means the terms and conditions for use, reproduction, and 24 | Distribution as defined in this document. "Data" means a collection of information 25 | and/or content extracted from the dataset used with the Model, including to train, 26 | pretrain, or otherwise evaluate the Model. The Data is not licensed under this License. 27 | "Output" means the results of operating a Model as embodied in informational content 28 | resulting therefrom. "Model" means any accompanying machine-learning based assemblies 29 | (including checkpoints), consisting of learnt weights, parameters (including optimizer 30 | states), corresponding to the model architecture as embodied in the Complementary 31 | Material, that have been trained or tuned, in whole or in part on the Data, using the 32 | Complementary Material. "Derivatives of the Model" means all modifications to the Model, 33 | works based on the Model, or any other model which is created or initialized by transfer 34 | of patterns of the weights, parameters, activations or output of the Model, to the other 35 | model, in order to cause the other model to perform similarly to the Model, including - 36 | but not limited to - distillation methods entailing the use of intermediate data 37 | representations or methods based on the generation of synthetic data by the Model for 38 | training the other model. 
"Complementary Material" means the accompanying source code 39 | and scripts used to define, run, load, benchmark or evaluate the Model, and used to 40 | prepare data for training or evaluation, if any. This includes any accompanying 41 | documentation, tutorials, examples, etc, if any. "Distribution" means any transmission, 42 | reproduction, publication or other sharing of the Model or Derivatives of the Model to a 43 | third party, including providing the Model as a hosted service made available by 44 | electronic or other remote means - e.g. API-based or web access. "Licensor" means the 45 | copyright owner or entity authorized by the copyright owner that is granting the 46 | License, including the persons or entities that may have rights in the Model and/or 47 | distributing the Model. "You" (or "Your") means an individual or Legal Entity exercising 48 | permissions granted by this License and/or making use of the Model for whichever purpose 49 | and in any field of use, including usage of the Model in an end-use application - e.g. 50 | chatbot, translator, image generator. "Third Parties" means individuals or legal 51 | entities that are not under common control with Licensor or You. "Contribution" means 52 | any work of authorship, including the original version of the Model and any 53 | modifications or additions to that Model or Derivatives of the Model thereof, that is 54 | intentionally submitted to Licensor for inclusion in the Model by the copyright owner or 55 | by an individual or Legal Entity authorized to submit on behalf of the copyright owner. 
56 | For the purposes of this definition, "submitted" means any form of electronic, verbal, 57 | or written communication sent to the Licensor or its representatives, including but not 58 | limited to communication on electronic mailing lists, source code control systems, and 59 | issue tracking systems that are managed by, or on behalf of, the Licensor for the 60 | purpose of discussing and improving the Model, but excluding communication that is 61 | conspicuously marked or otherwise designated in writing by the copyright owner as "Not a 62 | Contribution." "Contributor" means Licensor and any individual or Legal Entity on behalf 63 | of whom a Contribution has been received by Licensor and subsequently incorporated 64 | within the Model. 65 | 66 | Section II: INTELLECTUAL PROPERTY RIGHTS Both copyright and patent grants apply to the 67 | Model, Derivatives of the Model and Complementary Material. The Model and Derivatives of 68 | the Model are subject to additional terms as described in 69 | 70 | Section III. Grant of Copyright License. Subject to the terms and conditions of this 71 | License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, 72 | no-charge, royalty-free, irrevocable copyright license to reproduce, prepare, publicly 73 | display, publicly perform, sublicense, and distribute the Complementary Material, the 74 | Model, and Derivatives of the Model. Grant of Patent License. 
Subject to the terms and 75 | conditions of this License and where and as applicable, each Contributor hereby grants 76 | to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this paragraph) patent license to make, have made, use, offer to 78 | sell, sell, import, and otherwise transfer the Model and the Complementary Material, 79 | where such license applies only to those patent claims licensable by such Contributor 80 | that are necessarily infringed by their Contribution(s) alone or by combination of their 81 | Contribution(s) with the Model to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a cross-claim or counterclaim 83 | in a lawsuit) alleging that the Model and/or Complementary Material or a Contribution 84 | incorporated within the Model and/or Complementary Material constitutes direct or 85 | contributory patent infringement, then any patent licenses granted to You under this 86 | License for the Model and/or Work shall terminate as of the date such litigation is 87 | asserted or filed. Section III: CONDITIONS OF USAGE, DISTRIBUTION AND REDISTRIBUTION 88 | Distribution and Redistribution. You may host for Third Party remote access purposes 89 | (e.g. software-as-a-service), reproduce and distribute copies of the Model or 90 | Derivatives of the Model thereof in any medium, with or without modifications, provided 91 | that You meet the following conditions: Use-based restrictions as referenced in 92 | paragraph 5 MUST be included as an enforceable provision by You in any type of legal 93 | agreement (e.g. a license) governing the use and/or distribution of the Model or 94 | Derivatives of the Model, and You shall give notice to subsequent users You Distribute 95 | to, that the Model or Derivatives of the Model are subject to paragraph 5. This 96 | provision does not apply to the use of Complementary Material. 
You must give any Third 97 | Party recipients of the Model or Derivatives of the Model a copy of this License; You 98 | must cause any modified files to carry prominent notices stating that You changed the 99 | files; You must retain all copyright, patent, trademark, and attribution notices 100 | excluding those notices that do not pertain to any part of the Model, Derivatives of the 101 | Model. You may add Your own copyright statement to Your modifications and may provide 102 | additional or different license terms and conditions - respecting paragraph 4.a. - for 103 | use, reproduction, or Distribution of Your modifications, or for any such Derivatives of 104 | the Model as a whole, provided Your use, reproduction, and Distribution of the Model 105 | otherwise complies with the conditions stated in this License. Use-based restrictions. 106 | The restrictions set forth in Attachment A are considered Use-based restrictions. 107 | Therefore You cannot use the Model and the Derivatives of the Model for the specified 108 | restricted uses. You may use the Model subject to this License, including only for 109 | lawful purposes and in accordance with the License. Use may include creating any content 110 | with, finetuning, updating, running, training, evaluating and/or reparametrizing the 111 | Model. You shall require all of Your users who use the Model or a Derivative of the 112 | Model to comply with the terms of this paragraph (paragraph 5). The Output You Generate. 113 | Except as set forth herein, Licensor claims no rights in the Output You generate using 114 | the Model. You are accountable for the Output you generate and its subsequent uses. No 115 | use of the output can contravene any provision as stated in the License. 116 | 117 | Section IV: OTHER PROVISIONS Updates and Runtime Restrictions. To the maximum extent 118 | permitted by law, Licensor reserves the right to restrict (remotely or otherwise) usage 119 | of the Model in violation of this License. 
Trademarks and related. Nothing in this 120 | License permits You to make use of Licensors’ trademarks, trade names, logos or to 121 | otherwise suggest endorsement or misrepresent the relationship between the parties; and 122 | any rights not expressly granted herein are reserved by the Licensors. Disclaimer of 123 | Warranty. Unless required by applicable law or agreed to in writing, Licensor provides 124 | the Model and the Complementary Material (and each Contributor provides its 125 | Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either 126 | express or implied, including, without limitation, any warranties or conditions of 127 | TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are 128 | solely responsible for determining the appropriateness of using or redistributing the 129 | Model, Derivatives of the Model, and the Complementary Material and assume any risks 130 | associated with Your exercise of permissions under this License. Limitation of 131 | Liability. In no event and under no legal theory, whether in tort (including 132 | negligence), contract, or otherwise, unless required by applicable law (such as 133 | deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be 134 | liable to You for damages, including any direct, indirect, special, incidental, or 135 | consequential damages of any character arising as a result of this License or out of the 136 | use or inability to use the Model and the Complementary Material (including but not 137 | limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, 138 | or any and all other commercial damages or losses), even if such Contributor has been 139 | advised of the possibility of such damages. Accepting Warranty or Additional Liability. 
140 | While redistributing the Model, Derivatives of the Model and the Complementary Material 141 | thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, 142 | indemnity, or other liability obligations and/or rights consistent with this License. 143 | However, in accepting such obligations, You may act only on Your own behalf and on Your 144 | sole responsibility, not on behalf of any other Contributor, and only if You agree to 145 | indemnify, defend, and hold each Contributor harmless for any liability incurred by, or 146 | claims asserted against, such Contributor by reason of your accepting any such warranty 147 | or additional liability. If any provision of this License is held to be invalid, illegal 148 | or unenforceable, the remaining provisions shall be unaffected thereby and remain valid 149 | as if such provision had not been set forth herein. 150 | 151 | END OF TERMS AND CONDITIONS 152 | 153 | Attachment A Use Restrictions 154 | You agree not to use the Model or Derivatives of the Model: 155 | In any way that violates any applicable national, federal, state, local or 156 | international law or regulation; For the purpose of exploiting, harming or attempting to 157 | exploit or harm minors in any way; To generate or disseminate verifiably false 158 | information and/or content with the purpose of harming others; To generate or 159 | disseminate personal identifiable information that can be used to harm an individual; To 160 | defame, disparage or otherwise harass others; For fully automated decision making that 161 | adversely impacts an individual’s legal rights or otherwise creates or modifies a 162 | binding, enforceable obligation; For any use intended to or which has the effect of 163 | discriminating against or harming individuals or groups based on online or offline 164 | social behavior or known or predicted personal or personality characteristics; To 165 | exploit any of the vulnerabilities of a specific group 
of persons based on their age, 166 | social, physical or mental characteristics, in order to materially distort the behavior 167 | of a person pertaining to that group in a manner that causes or is likely to cause that 168 | person or another person physical or psychological harm; For any use intended to or 169 | which has the effect of discriminating against individuals or groups based on legally 170 | protected characteristics or categories; To provide medical advice and medical results 171 | interpretation; To generate or disseminate information for the purpose to be used for 172 | administration of justice, law enforcement, immigration or asylum processes, such as 173 | predicting an individual will commit fraud/crime commitment (e.g. by text profiling, 174 | drawing causal relationships between assertions made in documents, indiscriminate and 175 | arbitrarily-targeted use). 176 | -------------------------------------------------------------------------------- /variables.env: -------------------------------------------------------------------------------- 1 | # Set environment variables in the format KEY=VALUE, 1 per line 2 | # This file will be sourced inside the project container when started. 3 | # NOTE: If you change this file while the project is running, you must restart the project container for changes to take effect. 4 | 5 | HUGGINGFACE_HUB_CACHE=/mnt/cache 6 | 7 | APP_TRITONHOST="localhost" 8 | APP_TRITONPORT="8001" 9 | APP_MODELNAME="models" 10 | APP_MODELVERSION="1" 11 | --------------------------------------------------------------------------------