├── .gitignore
├── CODE_OF_CONDUCT.md
├── LICENSE
├── Dockerfile
├── README.md
├── CONTRIBUTING.md
└── ecs_run.py

/.gitignore:
--------------------------------------------------------------------------------
aws
awscliv2.zip
--------------------------------------------------------------------------------
/CODE_OF_CONDUCT.md:
--------------------------------------------------------------------------------
## Code of Conduct
This project has adopted the [Amazon Open Source Code of Conduct](https://aws.github.io/code-of-conduct).
For more information see the [Code of Conduct FAQ](https://aws.github.io/code-of-conduct-faq) or contact
opensource-codeofconduct@amazon.com with any additional questions or comments.
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.

Permission is hereby granted, free of charge, to any person obtaining a copy of
this software and associated documentation files (the "Software"), to deal in
the Software without restriction, including without limitation the rights to
use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
the Software, and to permit persons to whom the Software is furnished to do so.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
--------------------------------------------------------------------------------
/Dockerfile:
--------------------------------------------------------------------------------
FROM nvidia/cuda:12.0.0-runtime-ubuntu22.04
WORKDIR /app

RUN apt-get update && apt-get install -y wget git && apt-get clean

RUN git clone https://github.com/db0/nataili.git .
# Check out a specific version of the above repository
RUN git checkout 6c2f1862bacf25b6bc74e95e3174ca45a116f85b
# Append boto3 (needed by ecs_run.py) to the cloned repository's requirements.txt
RUN echo "boto3>=1.21.32">>requirements.txt

# Download and install Miniconda
RUN wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh
RUN bash Miniconda3-latest-Linux-x86_64.sh -b -p /miniconda

# Add Miniconda to the PATH
ENV PATH=/miniconda/bin:$PATH

# Update conda and create the conda environment defined in environment.yaml
RUN conda update --name base --channel defaults conda && \
    conda env create -f /app/environment.yaml --force && \
    conda clean -a -y

# Name of the conda environment created above. It is built into the image, so it is not reinstalled every time the container starts.
ENV ENV_NAME=ldm

COPY ecs_run.py /app/

SHELL ["conda", "run", "-n", "ldm", "/bin/bash", "-c"]

# Start the worker inside the ldm environment; --no-capture-output streams its logs
ENTRYPOINT ["conda", "run", "--no-capture-output", "-n", "ldm", "python", "ecs_run.py"]
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
## amazon-scalable-discord-diffusion

This project builds a container that runs Stable Diffusion text-to-image generation on Amazon Web Services (AWS). It works together with the infrastructure defined in [amazon-scalable-infra-discord-diffusion](https://github.com/aws-samples/amazon-scalable-infra-discord-diffusion).

## Security

See [CONTRIBUTING](CONTRIBUTING.md#security-issue-notifications) for more information.

## License

This library is licensed under the MIT-0 License. See the LICENSE file.

## Legal disclaimer

This package requires and may incorporate or retrieve a number of third-party
software packages (such as open source packages) at install-time or build-time
or run-time ("External Dependencies"). The External Dependencies are subject to
license terms that you must accept in order to use this package. If you do not
accept all of the applicable license terms, you should not use this package. We
recommend that you consult your company’s open source approval policy before
proceeding.

Provided below is a list of External Dependencies and the applicable license
identification as indicated by the documentation associated with the External
Dependencies as of Amazon's most recent review.

THIS INFORMATION IS PROVIDED FOR CONVENIENCE ONLY. AMAZON DOES NOT PROMISE THAT
THE LIST OR THE APPLICABLE TERMS AND CONDITIONS ARE COMPLETE, ACCURATE, OR
UP-TO-DATE, AND AMAZON WILL HAVE NO LIABILITY FOR ANY INACCURACIES. YOU SHOULD
CONSULT THE DOWNLOAD SITES FOR THE EXTERNAL DEPENDENCIES FOR THE MOST COMPLETE
AND UP-TO-DATE LICENSING INFORMATION.

YOUR USE OF THE EXTERNAL DEPENDENCIES IS AT YOUR SOLE RISK. IN NO EVENT WILL
AMAZON BE LIABLE FOR ANY DAMAGES, INCLUDING WITHOUT LIMITATION ANY DIRECT,
INDIRECT, CONSEQUENTIAL, SPECIAL, INCIDENTAL, OR PUNITIVE DAMAGES (INCLUDING
FOR ANY LOSS OF GOODWILL, BUSINESS INTERRUPTION, LOST PROFITS OR DATA, OR
COMPUTER FAILURE OR MALFUNCTION) ARISING FROM OR RELATING TO THE EXTERNAL
DEPENDENCIES, HOWEVER CAUSED AND REGARDLESS OF THE THEORY OF LIABILITY, EVEN
IF AMAZON HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGES. THESE LIMITATIONS
AND DISCLAIMERS APPLY EXCEPT TO THE EXTENT PROHIBITED BY APPLICABLE LAW.

https://github.com/Sygil-Dev/nataili — AGPL-3.0
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
# Contributing Guidelines

Thank you for your interest in contributing to our project. Whether it's a bug report, new feature, correction, or additional
documentation, we greatly value feedback and contributions from our community.

Please read through this document before submitting any issues or pull requests to ensure we have all the necessary
information to effectively respond to your bug report or contribution.

## Reporting Bugs/Feature Requests

We welcome you to use the GitHub issue tracker to report bugs or suggest features.

When filing an issue, please check existing open, or recently closed, issues to make sure somebody else hasn't already
reported the issue. Please try to include as much information as you can. Details like these are incredibly useful:

* A reproducible test case or series of steps
* The version of our code being used
* Any modifications you've made relevant to the bug
* Anything unusual about your environment or deployment

## Contributing via Pull Requests
Contributions via pull requests are much appreciated. Before sending us a pull request, please ensure that:

1. You are working against the latest source on the *main* branch.
2. You check existing open, and recently merged, pull requests to make sure someone else hasn't addressed the problem already.
3. You open an issue to discuss any significant work - we would hate for your time to be wasted.

To send us a pull request, please:

1. Fork the repository.
2. Modify the source; please focus on the specific change you are contributing. If you also reformat all the code, it will be hard for us to focus on your change.
3. Ensure local tests pass.
4. Commit to your fork using clear commit messages.
5. Send us a pull request, answering any default questions in the pull request interface.
6. Pay attention to any automated CI failures reported in the pull request, and stay involved in the conversation.

GitHub provides additional documentation on [forking a repository](https://help.github.com/articles/fork-a-repo/) and
[creating a pull request](https://help.github.com/articles/creating-a-pull-request/).

## Finding contributions to work on
Looking at the existing issues is a great way to find something to work on. As our projects, by default, use the default GitHub issue labels (enhancement/bug/duplicate/help wanted/invalid/question/wontfix), looking at any 'help wanted' issues is a great place to start.

## Code of Conduct
This project has adopted the [Amazon Open Source Code of Conduct](https://aws.github.io/code-of-conduct).
For more information see the [Code of Conduct FAQ](https://aws.github.io/code-of-conduct-faq) or contact
opensource-codeofconduct@amazon.com with any additional questions or comments.

## Security issue notifications
If you discover a potential security issue in this project we ask that you notify AWS/Amazon Security via our [vulnerability reporting page](http://aws.amazon.com/security/vulnerability-reporting/). Please do **not** create a public GitHub issue.

## Licensing

See the [LICENSE](LICENSE) file for our project's licensing. We will ask you to confirm the licensing of your contribution.
--------------------------------------------------------------------------------
/ecs_run.py:
--------------------------------------------------------------------------------
from nataili.model_manager import ModelManager
from nataili.inference.compvis.txt2img import txt2img
from nataili.util.cache import torch_gc
import os
from PIL import Image

# Cloud Requirements
import boto3
import json
import requests
import random

REGION = os.environ['REGION']
ssm = boto3.client('ssm', region_name=REGION)
USER_HG = ssm.get_parameter(Name='/USER_HG')['Parameter']['Value']
PASSWORD_HG = ssm.get_parameter(Name='/PASSWORD_HG', WithDecryption=True)['Parameter']['Value']

# Create SQS client
SQS = boto3.client('sqs', region_name=REGION)

QUEUE_URL = os.environ['SQSQUEUEURL']
WAIT_TIME_SECONDS = 20  # SQS long-poll wait; 20 seconds is the maximum SQS allows

### SQS Functions ###
def getSQSMessage(queue_url, time_wait):
    # Receive at most one message, long-polling for up to time_wait seconds
    response = SQS.receive_message(
        QueueUrl=queue_url,
        AttributeNames=[
            'SentTimestamp'
        ],
        MaxNumberOfMessages=1,
        MessageAttributeNames=[
            'All'
        ],
        WaitTimeSeconds=time_wait,
    )

    try:
        message = response['Messages'][0]
    except KeyError:
        # No message arrived before the long poll timed out
        return None, None

    receipt_handle = message['ReceiptHandle']
    return message, receipt_handle

def deleteSQSMessage(queue_url, receipt_handle, prompt):
    # Delete the processed message so it is not delivered again
    SQS.delete_message(
        QueueUrl=queue_url,
        ReceiptHandle=receipt_handle
    )
    print(f'Received and deleted message: "{prompt}"')

def convertMessageToDict(message):
    # The body is JSON mapping each field name to {"StringValue": ...}; flatten it to {name: value}
    cleaned_message = {}
    body = json.loads(message['Body'])
    for item in body:
        cleaned_message[item] = body[item]['StringValue']
    return cleaned_message

def validateRequest(r):
    # Log the outcome of a Discord API call without stopping the worker
    if not r.ok:
        print("Failure")
        print(r.text)
        # raise Exception(r.text)
    else:
        print("Success")
    return

### Discord required functions ###
def updateDiscordPicture(application_id, interaction_token, file_path):
    # Upload the image at file_path to the original interaction response
    url = f'https://discord.com/api/v10/webhooks/{application_id}/{interaction_token}/messages/@original'
    with open(file_path, 'rb') as image_file:
        r = requests.patch(url, files={'stable-diffusion.png': image_file})
    validateRequest(r)
    return

def picturesToDiscord(file_path, message_dict, message_response):
    # Edit the original interaction response: first the completion text, then the picture
    url = f"https://discord.com/api/v10/webhooks/{message_dict['applicationId']}/{message_dict['interactionToken']}/messages/@original"

    # Replace the "Processing" text with the completion text
    json_payload = {
        "content": f"*Completed your Sparkle!*```{message_response}```",
        "embeds": [],
        "attachments": [],
        "allowed_mentions": { "parse": [] },
    }
    r = requests.patch(url, json=json_payload)
    validateRequest(r)

    # Upload the picture. requests ignores json= whenever files= is given,
    # so only the file is sent in this request.
    with open(file_path, 'rb') as image_file:
        r = requests.patch(url, files={'stable-diffusion.png': image_file})
    validateRequest(r)

    return

def messageResponse(customer_data):
    message_response = f"\nPrompt: {customer_data['prompt']}"
    if 'negative_prompt' in customer_data:
        message_response += f"\nNegative Prompt: {customer_data['negative_prompt']}"
    if 'seed' in customer_data:
        message_response += f"\nSeed: {customer_data['seed']}"
    if 'steps' in customer_data:
        message_response += f"\nSteps: {customer_data['steps']}"
    if 'sampler' in customer_data:
        message_response += f"\nSampler: {customer_data['sampler']}"
    return message_response

def submitInitialResponse(application_id, interaction_token, message_response):
    # Edit the original interaction response to show that the request is being processed
    url = f'https://discord.com/api/v10/webhooks/{application_id}/{interaction_token}/messages/@original'
    json_payload = {
        "content": f"Processing your Sparkle```{message_response}```",
        "embeds": [],
        "attachments": [],
        "allowed_mentions": { "parse": [] },
    }
    r = requests.patch(url, json=json_payload)
    validateRequest(r)

    return

def cleanupPictures(path_to_file):
    # Clean up file(s) created during creation.
    os.remove(path_to_file)
    return

### Stable Diffusion functions ###
def image_grid(imgs, rows, cols):
    # Paste rows*cols equally sized images into one grid, filling it row by row
    assert len(imgs) == rows*cols

    w, h = imgs[0].size
    grid = Image.new('RGB', size=(cols*w, rows*h))

    for i, img in enumerate(imgs):
        grid.paste(img, box=(i%cols*w, i//cols*h))
    return grid

def runStableDiffusion(model_manager, model, user_inputs):
    # Generate one image for each of four consecutive seeds, starting at user_inputs['seed']
    image_list = []
    start_seed = int(user_inputs['seed'])
    for my_seed in range(start_seed, start_seed + 4):
        generator = txt2img(model_manager.loaded_models[model]["model"], model_manager.loaded_models[model]["device"], 'output_dir')
        generator.generate(user_inputs['prompt'], sampler_name=user_inputs['sampler'], ddim_steps=int(user_inputs['steps']), save_individual_images=False, n_iter=1, batch_size=1, seed=my_seed)
        image_list.append(generator.images[0]["image"])
        torch_gc()
    return image_list

def saveImage(image_list):
    # Combine the four images into a 2x2 grid and save it as a PNG in the working directory
    my_grid = image_grid(image_list, 2, 2)
    my_grid.save('tmp.png', format="PNG")
    return 'tmp.png'

def decideInputs(user_dict):
    # Fill in defaults for any optional fields the user did not supply
    if 'seed' not in user_dict:
        user_dict['seed'] = random.randint(0,99999)

    if 'steps' not in user_dict:
        user_dict['steps'] = 16

    if 'sampler' not in user_dict:
        user_dict['sampler'] = 'k_euler_a'
    return user_dict

def runMain():
    # The model manager loads and unloads the SD models and can download them or find their location
    model_manager = ModelManager(hf_auth={'username': USER_HG, 'password': PASSWORD_HG})
    model_manager.init()

    # The model to use for the generation.
    model = "stable_diffusion"
    # Load the model; if that fails, download it with the Hugging Face credentials and load it again
    if model_manager.load_model(model):
        print(f'{model} loaded')
    else:
        if not model_manager.download_model(model):
            raise RuntimeError(f'{model} download error')
        print(f'{model} downloaded')
        if not model_manager.load_model(model):
            raise RuntimeError(f'{model} load error')
        print(f'{model} loaded')

    # Process messages from the queue one at a time, forever
    while True:
        print("Waiting for next message from Queue...")
        message, receipt_handle = None, None
        while not message:
            # Each call long-polls for up to WAIT_TIME_SECONDS; keep polling until a message arrives
            message, receipt_handle = getSQSMessage(QUEUE_URL, WAIT_TIME_SECONDS)

        ## Run Stable Diffusion
        print("Found a message! Running Stable Diffusion")
        message_dict = convertMessageToDict(message)
        message_dict = decideInputs(message_dict)
        message_response = messageResponse(message_dict)
        print(message_response)
        submitInitialResponse(message_dict['applicationId'], message_dict['interactionToken'], message_response)
        image_list = runStableDiffusion(model_manager, model, message_dict)
        file_path = saveImage(image_list)
        picturesToDiscord(file_path, message_dict, message_response)
        cleanupPictures(file_path)
        ## Delete the message only after the result is posted, so a crash leaves it on the queue
        deleteSQSMessage(QUEUE_URL, receipt_handle, message_dict['prompt'])


if __name__ == "__main__":
    runMain()
--------------------------------------------------------------------------------
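`convertMessageToDict` expects each SQS message `Body` to be a JSON object whose values are wrappers carrying a `StringValue` key (the shape SQS uses for message attributes). It flattens those wrappers into plain strings. The sketch below mirrors that flattening using only the standard library. The example body is hypothetical: its keys are the ones `ecs_run.py` reads (`prompt`, `seed`, `applicationId`, `interactionToken`), but the exact payload sent by the companion infrastructure is an assumption.

```python
import json


def convert_message_to_dict(message):
    """Flatten {"field": {"StringValue": value, ...}} into {"field": value},
    mirroring convertMessageToDict in ecs_run.py."""
    body = json.loads(message["Body"])
    return {field: wrapper["StringValue"] for field, wrapper in body.items()}


# Hypothetical SQS message: only "Body" is read, and only StringValue inside it
example_message = {
    "ReceiptHandle": "example-receipt-handle",
    "Body": json.dumps({
        "prompt": {"DataType": "String", "StringValue": "a lighthouse at dusk"},
        "seed": {"DataType": "Number", "StringValue": "1234"},
        "applicationId": {"DataType": "String", "StringValue": "APP_ID"},
        "interactionToken": {"DataType": "String", "StringValue": "TOKEN"},
    }),
}

fields = convert_message_to_dict(example_message)
print(fields["seed"], type(fields["seed"]).__name__)  # 1234 str
```

Every value comes out as a string. This is why `runStableDiffusion` wraps `seed` and `steps` in `int()` before using them.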
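`runStableDiffusion` renders one image for each of four consecutive seeds, starting at the requested seed. `image_grid` then pastes image `i` at `(i % cols * w, i // cols * h)`, filling the 2x2 grid row by row. The hypothetical helper `grid_layout` below computes that mapping without PIL, so you can tell which seed produced which quadrant. The 512x512 image size used in the example is an assumption; `ecs_run.py` does not set it.

```python
def grid_layout(start_seed, width, height, rows=2, cols=2):
    """Hypothetical helper: map each seed to the top-left pixel where
    image_grid would paste the image generated with that seed."""
    layout = {}
    seeds = range(start_seed, start_seed + rows * cols)
    for i, seed in enumerate(seeds):
        # Same arithmetic as image_grid: column i % cols, row i // cols
        layout[seed] = (i % cols * width, i // cols * height)
    return layout


# Assumed 512x512 images; seeds 1234..1237, as runStableDiffusion would use for seed=1234
print(grid_layout(1234, 512, 512))
# {1234: (0, 0), 1235: (512, 0), 1236: (0, 512), 1237: (512, 512)}
```

As a result, the top-left image always corresponds to the seed echoed back in the Discord message, and the other three use the next seeds in order.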
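`runMain` deletes a message only after the image has been posted back to Discord. SQS keeps a received but undeleted message and makes it visible again after the visibility timeout. A worker that fails mid-generation therefore leaves the request to be retried, which gives at-least-once processing. The sketch below illustrates that receive, process, delete ordering with a hypothetical in-memory queue. `InMemoryQueue` and `run_worker` are not part of `ecs_run.py` or boto3. Unlike real SQS, this queue makes an undeleted message visible again immediately. It also catches the failure in-process, whereas `ecs_run.py` simply crashes and relies on the container being restarted (for example, by an ECS service).

```python
class InMemoryQueue:
    """Hypothetical stand-in for an SQS queue: a received message stays queued
    until it is deleted by its receipt handle. Real SQS hides it for the
    visibility timeout first; this sketch makes it visible again immediately."""

    def __init__(self, bodies):
        self._messages = {f"handle-{i}": body for i, body in enumerate(bodies)}

    def receive(self):
        # Oldest undeleted message, or None when the queue is empty
        first = next(iter(self._messages.items()), None)
        if first is None:
            return None
        handle, body = first
        return {"Body": body, "ReceiptHandle": handle}

    def delete(self, receipt_handle):
        del self._messages[receipt_handle]


def run_worker(queue, process, max_receives=10):
    """Receive, process, then delete, in the same order runMain uses.
    Returns every body received, including retries."""
    received = []
    for _ in range(max_receives):
        message = queue.receive()
        if message is None:
            break
        received.append(message["Body"])
        try:
            process(message["Body"])
        except RuntimeError:
            # Not deleted, so the same message is received again
            continue
        queue.delete(message["ReceiptHandle"])
    return received


# Hypothetical handler: request "b" fails once, then succeeds
remaining_failures = {"b": 1}


def flaky_process(body):
    if remaining_failures.get(body, 0) > 0:
        remaining_failures[body] -= 1
        raise RuntimeError("generation failed")


print(run_worker(InMemoryQueue(["a", "b"]), flaky_process))  # ['a', 'b', 'b']
```

The trade-off is that a request can be processed more than once. For example, this happens if the worker stops after posting to Discord but before deleting the message.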