├── scripts ├── __init__.py ├── wandb │ ├── latest-run │ ├── debug.log │ ├── debug-internal.log │ └── run-20240429_145519-r0eclldx │ │ ├── files │ │ ├── wandb-summary.json │ │ ├── requirements.txt │ │ ├── output.log │ │ ├── config.yaml │ │ ├── wandb-metadata.json │ │ └── diff.patch │ │ ├── run-r0eclldx.wandb │ │ └── logs │ │ └── debug.log ├── output.mp4 ├── __pycache__ │ ├── config.cpython-310.pyc │ ├── __init__.cpython-310.pyc │ ├── api_utils.cpython-310.pyc │ ├── outpainting.cpython-310.pyc │ ├── s3_manager.cpython-310.pyc │ ├── flux_inference.cpython-310.pyc │ ├── image_to_video.cpython-310.pyc │ ├── controlnet_union.cpython-310.pyc │ └── pipeline_fill_sd_xl.cpython-310.pyc ├── load_pipeline.py ├── sdxl_lora_inference.py ├── s3_manager.py ├── config.py ├── inpainting_pipeline.py ├── products10k_captions.py ├── image_to_video.py ├── flux_inference.py ├── api_utils.py └── outpainting.py ├── picpilot.egg-info ├── dependency_links.txt ├── top_level.txt ├── PKG-INFO └── SOURCES.txt ├── sample_data ├── mask.png ├── image.jpg ├── example2.jpg ├── example3.jpg ├── example4.jpg ├── product_img.jpg └── example5.jpg ├── config_template.env ├── .dockerignore ├── test-scripts ├── generated_video.mp4 ├── test_outpainting.py ├── test.py ├── test_image2video.py └── test_sdxl.py ├── api ├── __pycache__ │ ├── flux_serve.cpython-310.pyc │ ├── sdxl_serve.cpython-310.pyc │ ├── image2video_serve.cpython-310.pyc │ └── outpainting_serve.cpython-310.pyc ├── requirements.txt ├── client.py ├── picpilot.py ├── image2video_serve.py ├── sdxl_serve.py ├── flux_serve.py └── outpainting_serve.py ├── __pycache__ └── config_settings.cpython-310.pyc ├── configs ├── __pycache__ │ └── tti_settings.cpython-310.pyc └── tti_settings.py ├── setup.py ├── .gitmodules ├── .gitignore ├── .gitattributes ├── iac ├── terraform.tfstate.backup ├── project.tf ├── .terraform.lock.hcl └── terraform.tfstate ├── .vscode └── settings.json ├── config_settings.py ├── run.sh ├── .github └── workflows │ ├── docker-ci.yml │ └── docker-ci-serverless.yml ├── serverless ├── inpainting │ ├── Dockerfile │ └── run_inpainting.py ├── outpainting │ ├── Dockerfile │ └── run_outpainting.py ├── image-to-video │ ├── Dockerfile │ └── run_image-to-video.py └── text-to-image │ ├── Dockerfile │ └── run_text-to-image.py ├── LICENSE ├── DockerFileFolder └── Dockerfile ├── README.md └── Readme.svg /scripts/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /picpilot.egg-info/dependency_links.txt: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /picpilot.egg-info/top_level.txt: -------------------------------------------------------------------------------- 1 | scripts 2 | -------------------------------------------------------------------------------- /scripts/wandb/latest-run: -------------------------------------------------------------------------------- 1 | run-20240518_183749-6g8f40nj -------------------------------------------------------------------------------- /scripts/wandb/debug.log: -------------------------------------------------------------------------------- 1 | run-20240518_183749-6g8f40nj/logs/debug.log -------------------------------------------------------------------------------- /scripts/wandb/debug-internal.log: -------------------------------------------------------------------------------- 1 | 
run-20240518_183749-6g8f40nj/logs/debug-internal.log -------------------------------------------------------------------------------- /sample_data/mask.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VikramxD/PicPilot/HEAD/sample_data/mask.png -------------------------------------------------------------------------------- /scripts/output.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VikramxD/PicPilot/HEAD/scripts/output.mp4 -------------------------------------------------------------------------------- /scripts/wandb/run-20240429_145519-r0eclldx/files/wandb-summary.json: -------------------------------------------------------------------------------- 1 | {"_wandb": {"runtime": 3}} -------------------------------------------------------------------------------- /sample_data/image.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VikramxD/PicPilot/HEAD/sample_data/image.jpg -------------------------------------------------------------------------------- /config_template.env: -------------------------------------------------------------------------------- 1 | AWS_ACCESS_KEY_ID 2 | AWS_SECRET_ACCESS_KEY 3 | AWS_REGION 4 | AWS_BUCKET_NAME 5 | -------------------------------------------------------------------------------- /sample_data/example2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VikramxD/PicPilot/HEAD/sample_data/example2.jpg -------------------------------------------------------------------------------- /sample_data/example3.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VikramxD/PicPilot/HEAD/sample_data/example3.jpg -------------------------------------------------------------------------------- /sample_data/example4.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VikramxD/PicPilot/HEAD/sample_data/example4.jpg -------------------------------------------------------------------------------- /sample_data/product_img.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VikramxD/PicPilot/HEAD/sample_data/product_img.jpg -------------------------------------------------------------------------------- /.dockerignore: -------------------------------------------------------------------------------- 1 | iac 2 | .venv 3 | sample_data 4 | run.sh 5 | ui 6 | outputs 7 | .vscode 8 | config_template 9 | -------------------------------------------------------------------------------- /picpilot.egg-info/PKG-INFO: -------------------------------------------------------------------------------- 1 | Metadata-Version: 2.1 2 | Name: picpilot 3 | Version: 0.1.1 4 | License-File: LICENSE 5 | -------------------------------------------------------------------------------- /test-scripts/generated_video.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VikramxD/PicPilot/HEAD/test-scripts/generated_video.mp4 -------------------------------------------------------------------------------- /api/__pycache__/flux_serve.cpython-310.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/VikramxD/PicPilot/HEAD/api/__pycache__/flux_serve.cpython-310.pyc -------------------------------------------------------------------------------- /api/__pycache__/sdxl_serve.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VikramxD/PicPilot/HEAD/api/__pycache__/sdxl_serve.cpython-310.pyc -------------------------------------------------------------------------------- /scripts/__pycache__/config.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VikramxD/PicPilot/HEAD/scripts/__pycache__/config.cpython-310.pyc -------------------------------------------------------------------------------- /__pycache__/config_settings.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VikramxD/PicPilot/HEAD/__pycache__/config_settings.cpython-310.pyc -------------------------------------------------------------------------------- /scripts/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VikramxD/PicPilot/HEAD/scripts/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /scripts/__pycache__/api_utils.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VikramxD/PicPilot/HEAD/scripts/__pycache__/api_utils.cpython-310.pyc -------------------------------------------------------------------------------- /scripts/__pycache__/outpainting.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VikramxD/PicPilot/HEAD/scripts/__pycache__/outpainting.cpython-310.pyc -------------------------------------------------------------------------------- /scripts/__pycache__/s3_manager.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VikramxD/PicPilot/HEAD/scripts/__pycache__/s3_manager.cpython-310.pyc -------------------------------------------------------------------------------- /api/__pycache__/image2video_serve.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VikramxD/PicPilot/HEAD/api/__pycache__/image2video_serve.cpython-310.pyc -------------------------------------------------------------------------------- /api/__pycache__/outpainting_serve.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VikramxD/PicPilot/HEAD/api/__pycache__/outpainting_serve.cpython-310.pyc -------------------------------------------------------------------------------- /configs/__pycache__/tti_settings.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VikramxD/PicPilot/HEAD/configs/__pycache__/tti_settings.cpython-310.pyc -------------------------------------------------------------------------------- /scripts/__pycache__/flux_inference.cpython-310.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/VikramxD/PicPilot/HEAD/scripts/__pycache__/flux_inference.cpython-310.pyc -------------------------------------------------------------------------------- /scripts/__pycache__/image_to_video.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VikramxD/PicPilot/HEAD/scripts/__pycache__/image_to_video.cpython-310.pyc -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | setup(name='picpilot', 4 | version='0.1.1', 5 | packages=find_packages(), 6 | ) -------------------------------------------------------------------------------- /scripts/__pycache__/controlnet_union.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VikramxD/PicPilot/HEAD/scripts/__pycache__/controlnet_union.cpython-310.pyc -------------------------------------------------------------------------------- /scripts/__pycache__/pipeline_fill_sd_xl.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VikramxD/PicPilot/HEAD/scripts/__pycache__/pipeline_fill_sd_xl.cpython-310.pyc -------------------------------------------------------------------------------- /sample_data/example5.jpg: -------------------------------------------------------------------------------- 1 | version https://git-lfs.github.com/spec/v1 2 | oid sha256:f6205e9183e7730f1e2e42f20fa0203ac91154a508b259130a7eca9b97a13296 3 | size 30073 4 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "/Users/vikram/Projects/PicPilot/Diffree"] 2 | path = /Users/vikram/Projects/PicPilot/Diffree 3 | url = https://github.com/OpenGVLab/Diffree.git 4 | -------------------------------------------------------------------------------- /scripts/wandb/run-20240429_145519-r0eclldx/run-r0eclldx.wandb: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VikramxD/PicPilot/HEAD/scripts/wandb/run-20240429_145519-r0eclldx/run-r0eclldx.wandb -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .venv 2 | /scripts/wandb 3 | variables.tf 4 | .terraform 5 | config.env 6 | /scripts/yolov8s* 7 | /scripts/*jpg 8 | /scripts/outputs 9 | scripts/wandb 10 | /scripts/outputs 11 | /scripts/outputs 12 | bfg.jar 13 | 14 | -------------------------------------------------------------------------------- /.gitattributes: -------------------------------------------------------------------------------- 1 | scripts/output filter=lfs diff=lfs merge=lfs -text 2 | *.png filter=lfs diff=lfs merge=lfs -text 3 | *.pt filter=lfs diff=lfs merge=lfs -text 4 | *.psd filter=lfs diff=lfs merge=lfs -text 5 | *.jpg filter=lfs diff=lfs merge=lfs -text 6 | -------------------------------------------------------------------------------- /iac/terraform.tfstate.backup: -------------------------------------------------------------------------------- 1 | { 2 | "version": 4, 3 | "terraform_version": "1.8.2", 4 | "serial": 14, 5 | "lineage": 
"7f24f129-9566-aba8-5a57-55c73ab4a868", 6 | "outputs": {}, 7 | "resources": [], 8 | "check_results": null 9 | } 10 | -------------------------------------------------------------------------------- /.vscode/settings.json: -------------------------------------------------------------------------------- 1 | { 2 | "python.REPL.enableREPLSmartSend": false, 3 | "circleci.persistedProjectSelection": [ 4 | "circleci/BQyaZDyai3ejufsymJgXtx/G7GFeA574Ga7r2VtdHXBbu" 5 | ], 6 | "python.analysis.extraPaths": [ 7 | "./scripts" 8 | ] 9 | } -------------------------------------------------------------------------------- /config_settings.py: -------------------------------------------------------------------------------- 1 | from pydantic_settings import BaseSettings 2 | class Settings(BaseSettings): 3 | LOGFIRE_TOKEN:str = '' 4 | AWS_ACCESS_KEY_ID: str = '' 5 | AWS_SECRET_ACCESS_KEY: str = '' 6 | AWS_REGION: str = "ap-south-1" 7 | AWS_BUCKET_NAME: str="diffusion-model-bucket" 8 | 9 | 10 | 11 | settings = Settings() 12 | 13 | -------------------------------------------------------------------------------- /api/requirements.txt: -------------------------------------------------------------------------------- 1 | sentencepiece 2 | git+https://github.com/huggingface/diffusers 3 | lightning 4 | Pillow 5 | pydantic 6 | torch>=1.4 7 | utils 8 | uvicorn 9 | boto3 10 | ultralytics 11 | git+https://github.com/huggingface/transformers 12 | accelerate 13 | peft 14 | pydantic-settings 15 | controlnet-aux 16 | rembg 17 | mediapipe 18 | git+https://github.com/bhimrazy/LitServe.git@feat/multi-endpoints 19 | torchao 20 | runpod -------------------------------------------------------------------------------- /picpilot.egg-info/SOURCES.txt: -------------------------------------------------------------------------------- 1 | LICENSE 2 | README.md 3 | setup.py 4 | picpilot.egg-info/PKG-INFO 5 | picpilot.egg-info/SOURCES.txt 6 | picpilot.egg-info/dependency_links.txt 7 | picpilot.egg-info/top_level.txt 8 | scripts/__init__.py 9 | scripts/api_utils.py 10 | scripts/config.py 11 | scripts/controlnet_union.py 12 | scripts/flux_inference.py 13 | scripts/image_to_video.py 14 | scripts/inpainting_pipeline.py 15 | scripts/load_pipeline.py 16 | scripts/outpainting.py 17 | scripts/pipeline_fill_sd_xl.py 18 | scripts/products10k_captions.py 19 | scripts/s3_manager.py 20 | scripts/sdxl_lora_inference.py 21 | scripts/sdxl_lora_tuner.py -------------------------------------------------------------------------------- /api/client.py: -------------------------------------------------------------------------------- 1 | # Copyright The Lightning AI team. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | import requests 15 | response = requests.post("http://127.0.0.1:8000/predict", json={"input": 4.0}) 16 | print(f"Status: {response.status_code}\nResponse:\n {response.text}") 17 | -------------------------------------------------------------------------------- /configs/tti_settings.py: -------------------------------------------------------------------------------- 1 | from pydantic_settings import BaseSettings 2 | 3 | class TTI_SETTINGS(BaseSettings): 4 | 5 | MODEL_NAME:str="stabilityai/stable-diffusion-xl-base-1.0" 6 | ADAPTER_NAME:str = "VikramSingh178/sdxl-lora-finetune-product-caption" 7 | ENABLE_COMPILE: bool = False 8 | DEVICE: str = "cuda" 9 | TRITON_MODEL_NAME: str = "PICPILOT_PRODUCTION_SERVER" 10 | MAX_BATCH_SIZE: int = 32 11 | MAX_QUEUE_DELAY_MICROSECONDS: int = 100 12 | TORCH_INDUCTOR_CONFIG: dict = { 13 | "conv_1x1_as_mm": True, 14 | "coordinate_descent_tuning": True, 15 | "epilogue_fusion": False, 16 | "coordinate_descent_check_all_directions": True, 17 | "force_fuse_int_mm_with_mul": True, 18 | "use_mixed_mm": True 19 | } 20 | LOG_FORMAT: str = "%(asctime)s - %(levelname)s - %(name)s: %(message)s" 21 | LOG_LEVEL: str = "INFO" 22 | 23 | tti_settings = TTI_SETTINGS() -------------------------------------------------------------------------------- /run.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Install necessary packages 4 | apt-get install -y python3-venv python3-pip 5 | 6 | 7 | # Remove the old virtual environment if it exists and create a new one 8 | if [ -d ".venv" ]; then 9 | rm -rf .venv 10 | fi 11 | 12 | python3 -m venv .venv 13 | 14 | # Activate the virtual environment 15 | if [ -f ".venv/bin/activate" ]; then 16 | source .venv/bin/activate 17 | else 18 | echo "Failed to create the virtual environment. Please check for errors." 19 | exit 1 20 | fi 21 | 22 | # Install the uv package within the virtual environment 23 | pip install uv 24 | 25 | # Change directory to api and install the required packages 26 | cd api || { echo "API directory not found"; exit 1; } 27 | 28 | 29 | # Ensure the uv command is available 30 | if ! command -v uv &> /dev/null; then 31 | echo "uv command not found. Ensure it is installed and accessible." 32 | exit 1 33 | fi 34 | 35 | uv pip install -r requirements.txt 36 | 37 | 38 | 39 | -------------------------------------------------------------------------------- /.github/workflows/docker-ci.yml: -------------------------------------------------------------------------------- 1 | name: Docker Build and Push 2 | 3 | on: 4 | push: 5 | branches: [ "main" ] 6 | pull_request: 7 | branches: [ "main" ] 8 | 9 | jobs: 10 | build-and-push: 11 | runs-on: ubuntu-latest 12 | 13 | steps: 14 | - name: Checkout code 15 | uses: actions/checkout@v3 16 | 17 | - name: Set up Docker Buildx 18 | uses: docker/setup-buildx-action@v2 19 | 20 | - name: Login to Docker Hub 21 | uses: docker/login-action@v2 22 | with: 23 | username: ${{secrets.DOCKER_USERNAME }} 24 | password: ${{ secrets.DOCKER_PASSWORD }} 25 | 26 | - name: Build and push Docker image 27 | id: docker_build 28 | uses: docker/build-push-action@v4 29 | with: 30 | context: . 
31 | file: ./DockerFileFolder/Dockerfile 32 | push: true 33 | tags: ${{secrets.DOCKER_USERNAME }}/picpilot:latest 34 | 35 | - name: Image digest 36 | run: echo ${{ steps.docker_build.outputs.digest }} 37 | -------------------------------------------------------------------------------- /serverless/inpainting/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM nvidia/cuda:12.2.2-cudnn8-runtime-ubuntu22.04 2 | 3 | ENV PYTHONDONTWRITEBYTECODE=1 \ 4 | PYTHONUNBUFFERED=1 \ 5 | PYTHONPATH=/app 6 | 7 | WORKDIR /app 8 | 9 | RUN apt-get update && apt-get install -y --no-install-recommends \ 10 | python3.11 \ 11 | python3-pip \ 12 | python3.11-dev \ 13 | python3.11-venv \ 14 | ffmpeg \ 15 | libsm6 \ 16 | libxext6 \ 17 | libgl1-mesa-glx \ 18 | git \ 19 | build-essential \ 20 | && rm -rf /var/lib/apt/lists/* 21 | 22 | RUN useradd -m -u 1000 user && \ 23 | chown -R user:user /app 24 | 25 | USER user 26 | 27 | RUN python3.11 -m venv /app/venv 28 | ENV PATH="/app/venv/bin:$PATH" 29 | 30 | COPY --chown=user:user . . 31 | 32 | RUN pip install --no-cache-dir -U pip setuptools wheel && \ 33 | pip install --no-cache-dir -r api/requirements.txt && \ 34 | pip install --no-cache-dir runpod && \ 35 | pip install --no-cache-dir -e . 36 | 37 | CMD ["python", "serverless/inpainting/run_inpainting.py"] -------------------------------------------------------------------------------- /serverless/outpainting/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM nvidia/cuda:12.2.2-cudnn8-runtime-ubuntu22.04 2 | 3 | ENV PYTHONDONTWRITEBYTECODE=1 \ 4 | PYTHONUNBUFFERED=1 \ 5 | PYTHONPATH=/app 6 | 7 | WORKDIR /app 8 | 9 | RUN apt-get update && apt-get install -y --no-install-recommends \ 10 | python3.11 \ 11 | python3-pip \ 12 | python3.11-dev \ 13 | python3.11-venv \ 14 | ffmpeg \ 15 | libsm6 \ 16 | libxext6 \ 17 | libgl1-mesa-glx \ 18 | git \ 19 | build-essential \ 20 | && rm -rf /var/lib/apt/lists/* 21 | 22 | RUN useradd -m -u 1000 user && \ 23 | chown -R user:user /app 24 | 25 | USER user 26 | 27 | RUN python3.11 -m venv /app/venv 28 | ENV PATH="/app/venv/bin:$PATH" 29 | 30 | COPY --chown=user:user . . 31 | 32 | RUN pip install --no-cache-dir -U pip setuptools wheel && \ 33 | pip install --no-cache-dir -r api/requirements.txt && \ 34 | pip install --no-cache-dir runpod && \ 35 | pip install --no-cache-dir -e . 36 | 37 | CMD ["python", "serverless/outpainting/run_outpainting.py"] -------------------------------------------------------------------------------- /serverless/image-to-video/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM nvidia/cuda:12.2.2-cudnn8-runtime-ubuntu22.04 2 | 3 | ENV PYTHONDONTWRITEBYTECODE=1 \ 4 | PYTHONUNBUFFERED=1 \ 5 | PYTHONPATH=/app 6 | 7 | WORKDIR /app 8 | 9 | RUN apt-get update && apt-get install -y --no-install-recommends \ 10 | python3.11 \ 11 | python3-pip \ 12 | python3.11-dev \ 13 | python3.11-venv \ 14 | ffmpeg \ 15 | libsm6 \ 16 | libxext6 \ 17 | libgl1-mesa-glx \ 18 | git \ 19 | build-essential \ 20 | && rm -rf /var/lib/apt/lists/* 21 | 22 | RUN useradd -m -u 1000 user && \ 23 | chown -R user:user /app 24 | 25 | USER user 26 | 27 | RUN python3.11 -m venv /app/venv 28 | ENV PATH="/app/venv/bin:$PATH" 29 | 30 | COPY --chown=user:user . . 
31 | 32 | RUN pip install --no-cache-dir -U pip setuptools wheel && \ 33 | pip install --no-cache-dir -r api/requirements.txt && \ 34 | pip install --no-cache-dir runpod && \ 35 | pip install --no-cache-dir -e . 36 | 37 | CMD ["python", "serverless/image-to-video/run_image-to-video.py"] -------------------------------------------------------------------------------- /serverless/text-to-image/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM nvidia/cuda:12.2.2-cudnn8-runtime-ubuntu22.04 2 | 3 | ENV PYTHONDONTWRITEBYTECODE=1 \ 4 | PYTHONUNBUFFERED=1 \ 5 | PYTHONPATH=/app 6 | 7 | WORKDIR /app 8 | 9 | RUN apt-get update && apt-get install -y --no-install-recommends \ 10 | python3.11 \ 11 | python3-pip \ 12 | python3.11-dev \ 13 | python3.11-venv \ 14 | ffmpeg \ 15 | libsm6 \ 16 | libxext6 \ 17 | libgl1-mesa-glx \ 18 | git \ 19 | build-essential \ 20 | && rm -rf /var/lib/apt/lists/* 21 | 22 | RUN useradd -m -u 1000 user && \ 23 | chown -R user:user /app 24 | 25 | USER user 26 | 27 | RUN python3.11 -m venv /app/venv 28 | ENV PATH="/app/venv/bin:$PATH" 29 | 30 | COPY --chown=user:user . . 31 | 32 | RUN pip install --no-cache-dir -U pip setuptools wheel && \ 33 | pip install --no-cache-dir -r api/requirements.txt && \ 34 | pip install --no-cache-dir runpod && \ 35 | pip install --no-cache-dir -e . 36 | 37 | CMD ["python", "serverless/text-to-image/run_text-to-image.py"] -------------------------------------------------------------------------------- /scripts/load_pipeline.py: -------------------------------------------------------------------------------- 1 | from config import MODEL_NAME,ADAPTER_NAME 2 | import torch 3 | from diffusers import DiffusionPipeline 4 | from wandb.integration.diffusers import autolog 5 | from scripts.config import PROJECT_NAME 6 | autolog(init=dict(project=PROJECT_NAME)) 7 | 8 | 9 | 10 | 11 | 12 | 13 | def load_pipeline(model_name, adapter_name): 14 | pipe = DiffusionPipeline.from_pretrained(model_name, torch_dtype=torch.float16).to( 15 | "cuda" 16 | ) 17 | pipe.load_lora_weights(adapter_name) 18 | pipe.unet.to(memory_format=torch.channels_last) 19 | pipe.vae.to(memory_format=torch.channels_last) 20 | pipe.unet = torch.compile(pipe.unet, mode="max-autotune", fullgraph=True) 21 | pipe.vae.decode = torch.compile( 22 | pipe.vae.decode, mode="max-autotune", fullgraph=True 23 | ) 24 | pipe.fuse_qkv_projections() 25 | 26 | return pipe 27 | 28 | loaded_pipeline = load_pipeline(MODEL_NAME, ADAPTER_NAME) 29 | images = loaded_pipeline('toaster', num_inference_steps=30).images[0] 30 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Vikramjeet Singh 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /iac/project.tf: -------------------------------------------------------------------------------- 1 | provider "aws" { 2 | region = "ap-south-1" 3 | access_key = var.aws_access_key 4 | secret_key = var.aws_secret_key 5 | } 6 | 7 | 8 | 9 | 10 | 11 | resource "aws_s3_bucket" "diffusion_model_bucket" { 12 | bucket = "diffusion-model-bucket" 13 | tags = { 14 | Name = "Diffusion Model Bucket" 15 | Task = "SDXL LORA" 16 | Product = "Product Diffusion API" 17 | } 18 | 19 | } 20 | 21 | resource "aws_s3_bucket_ownership_controls" "s3_bucket_acl_ownership" { 22 | bucket = aws_s3_bucket.diffusion_model_bucket.id 23 | rule { 24 | object_ownership = "ObjectWriter" 25 | } 26 | 27 | } 28 | 29 | resource "aws_s3_bucket_public_access_block" "s3_bucket_public_access_block" { 30 | bucket = aws_s3_bucket.diffusion_model_bucket.id 31 | block_public_acls = false 32 | block_public_policy = false 33 | ignore_public_acls = true 34 | restrict_public_buckets = true 35 | } 36 | 37 | 38 | resource "aws_s3_bucket_acl" "acl_access" { 39 | depends_on = [ 40 | aws_s3_bucket_ownership_controls.s3_bucket_acl_ownership, 41 | aws_s3_bucket_public_access_block.s3_bucket_public_access_block, 42 | ] 43 | 44 | bucket = aws_s3_bucket.diffusion_model_bucket.id 45 | acl = "public-read" 46 | } 47 | 48 | 49 | -------------------------------------------------------------------------------- /api/picpilot.py: -------------------------------------------------------------------------------- 1 | from litserve.server import LitServer,run_all 2 | from flux_serve import FluxInpaintingAPI 3 | from sdxl_serve import SDXLLoraAPI 4 | from configs.tti_settings import tti_settings 5 | from outpainting_serve import OutpaintingAPI 6 | from image2video_serve import ImageToVideoAPI 7 | from starlette.middleware.cors import CORSMiddleware 8 | 9 | cors_middleware = (CORSMiddleware, {"allow_origins": ["*"], "allow_credentials": True, "allow_methods": ["*"], "allow_headers": ["*"],}) 10 | 11 | 12 | 13 | flux_server = LitServer(FluxInpaintingAPI(), api_path='/api/v2/painting/flux', accelerator="auto",devices='auto', max_batch_size=4, batch_timeout=0.1,middlewares=[cors_middleware]) 14 | sdxl_server = LitServer(SDXLLoraAPI(), api_path='/api/v2/generate/sdxl', accelerator="auto",devices='auto', max_batch_size=tti_settings.MAX_BATCH_SIZE, batch_timeout=tti_settings.MAX_QUEUE_DELAY_MICROSECONDS / 1e6, middlewares=[cors_middleware]) 15 | outpainting_server = LitServer(OutpaintingAPI(), api_path='/api/v2/painting/sdxl_outpainting', accelerator='auto',devices='auto', max_batch_size=4, batch_timeout=0.1,middlewares=[cors_middleware]) 16 | image2video_server = LitServer(ImageToVideoAPI(), api_path='/api/v2/image2video/cogvideox', accelerator='auto',devices='auto', max_batch_size=1, batch_timeout=0.1,middlewares=[cors_middleware]) 17 | 18 | 19 | if __name__ == '__main__': 20 | run_all([flux_server,sdxl_server,outpainting_server,image2video_server], port=8000) 21 | 
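# A minimal client sketch, assuming the composed server above is running on port 8000:
# each endpoint is exposed at its api_path, and the outpainting route accepts the same
# JSON payload that test-scripts/test_outpainting.py builds (image, width, height,
# overlap_percentage, num_inference_steps, resize_option, custom_resize_percentage,
# prompt_input, alignment and the overlap_* flags).
#
#   import requests
#   response = requests.post(
#       "http://localhost:8000/api/v2/painting/sdxl_outpainting",
#       json=payload,  # dict with the fields listed above
#   )
#   response.raise_for_status()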
-------------------------------------------------------------------------------- /iac/.terraform.lock.hcl: -------------------------------------------------------------------------------- 1 | # This file is maintained automatically by "terraform init". 2 | # Manual edits may be lost in future updates. 3 | 4 | provider "registry.terraform.io/hashicorp/aws" { 5 | version = "5.48.0" 6 | hashes = [ 7 | "h1:rMyeKizkPgNuYQ1UQpWGDvGdJQs5vDPDlYtS4jVxxcI=", 8 | "zh:0876d94be46be905d1f6c149461979cd6e9bec80d5ffad43fd6267fe7c3a924d", 9 | "zh:3a853f887e6f61c2ba383c46e71bcec97ecd31d25a78dab08958f43bbbaecb86", 10 | "zh:43235595e26dd131f00704b5b64a65c4e7c4984a559b30d4272170e1b78e99b7", 11 | "zh:6866f7535ec2ef8fe6ed16eecee2e31418a2bd86cec73e1d18e47bd3bb87f68e", 12 | "zh:756a4ed97f30ea6e8871c16446b24ce55601143a715e067b7f9ebdae8349da34", 13 | "zh:793e8414962934be9805186874f207ca1dc8d162b6665e4938893ad827a545c6", 14 | "zh:79b2f886507f21ff1b752ff140ed95ed551f389abf0c8177c7b5f5bbbd95da8e", 15 | "zh:8653b1bc6f7e62404e02f940d962d0c2ba0c4dd4c28bd595945454cf348c2697", 16 | "zh:9b12af85486a96aedd8d7984b0ff811a4b42e3d88dad1a3fb4c0b580d04fa425", 17 | "zh:9bd612f013c075685c129e5d0ec9243572cd51359599b7218459babe0e9b6ac7", 18 | "zh:aabafe758ee8392f56d2c894017203de4dae38c1e3e0d274c54e194b9b8fccff", 19 | "zh:aded7d6034115ca512b79ca17da65ebb8906e3b8def78dcbc0640142b0c05ca0", 20 | "zh:ce383ee19b37666aba60db6b01cbe7a1fcbb40c6dd54d0cb36b2ba114ee5ae62", 21 | "zh:ceaf1b998b9ced4b63f35da386358e4c7ad6def582438987c91bceffefb9e258", 22 | "zh:d79225a9ae6a7391c33aa2d794bf9b167db66398c4f054f94d557615b051a40d", 23 | ] 24 | } 25 | -------------------------------------------------------------------------------- /DockerFileFolder/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM nvidia/cuda:12.2.2-cudnn8-runtime-ubuntu22.04 2 | 3 | # Set environment variables 4 | ENV PYTHONDONTWRITEBYTECODE=1 \ 5 | PYTHONUNBUFFERED=1 \ 6 | PATH="/home/user/.local/bin:$PATH" \ 7 | PYTHONPATH=/app 8 | 9 | # Set working directory in the container 10 | WORKDIR /app 11 | 12 | # Install system dependencies, Python, and build tools 13 | RUN apt-get update && apt-get install -y --no-install-recommends \ 14 | python3.11 \ 15 | python3-pip \ 16 | python3.11-venv \ 17 | python3.11-dev \ 18 | ffmpeg \ 19 | libsm6 \ 20 | libxext6 \ 21 | libgl1-mesa-glx \ 22 | git \ 23 | build-essential \ 24 | && rm -rf /var/lib/apt/lists/* 25 | 26 | # Create a non-root user to run the application 27 | RUN useradd -m -u 1000 user 28 | 29 | # Create a virtual environment and give ownership to the non-root user 30 | RUN python3.11 -m venv /home/user/venv && chown -R user:user /home/user/venv 31 | ENV PATH="/home/user/venv/bin:$PATH" 32 | 33 | # Switch to the non-root user 34 | USER user 35 | 36 | # Copy the entire application code 37 | COPY --chown=user:user . . 38 | 39 | # Install Python dependencies 40 | RUN pip install --no-cache-dir -U pip setuptools wheel && \ 41 | pip install --no-cache-dir -r api/requirements.txt 42 | 43 | # Install the application in editable mode using setup.py 44 | RUN pip install --no-cache-dir -e . 
45 | 46 | # Set working directory for the application 47 | WORKDIR /app/api 48 | 49 | # Command to run the application 50 | CMD ["python3", "picpilot.py"] -------------------------------------------------------------------------------- /.github/workflows/docker-ci-serverless.yml: -------------------------------------------------------------------------------- 1 | name: Docker Build and Push Serverless 2 | 3 | on: 4 | push: 5 | branches: [ "main" ] 6 | paths: 7 | - 'serverless/**' 8 | - 'api/**' 9 | - 'scripts/**' 10 | - '.github/workflows/**' 11 | pull_request: 12 | branches: [ "main" ] 13 | 14 | jobs: 15 | build-and-push: 16 | runs-on: ubuntu-latest 17 | strategy: 18 | matrix: 19 | service: 20 | - image-to-video 21 | - text-to-image 22 | - inpainting 23 | - outpainting 24 | 25 | steps: 26 | - name: Checkout code 27 | uses: actions/checkout@v3 28 | 29 | - name: Set up Docker Buildx 30 | uses: docker/setup-buildx-action@v2 31 | 32 | - name: Login to Docker Hub 33 | uses: docker/login-action@v2 34 | with: 35 | username: ${{secrets.DOCKER_USERNAME }} 36 | password: ${{ secrets.DOCKER_PASSWORD }} 37 | 38 | - name: Build and push Docker image 39 | id: docker_build 40 | uses: docker/build-push-action@v4 41 | with: 42 | context: . 43 | file: serverless/${{ matrix.service }}/Dockerfile 44 | push: true 45 | tags: ${{secrets.DOCKER_USERNAME }}/picpilot_${{ matrix.service }}:latest 46 | 47 | - name: Image digest 48 | run: echo ${{ steps.docker_build.outputs.digest }} 49 | 50 | notify: 51 | needs: build-and-push 52 | runs-on: ubuntu-latest 53 | steps: 54 | - name: Check build status 55 | run: | 56 | if [ "${{ needs.build-and-push.result }}" == "success" ]; then 57 | echo "All images built and pushed successfully" 58 | else 59 | echo "Some builds failed, check the logs" 60 | exit 1 61 | fi -------------------------------------------------------------------------------- /scripts/sdxl_lora_inference.py: -------------------------------------------------------------------------------- 1 | from wandb.integration.diffusers import autolog 2 | from diffusers import DiffusionPipeline 3 | import torch 4 | from config import PROJECT_NAME 5 | autolog(init=dict(project=PROJECT_NAME)) 6 | 7 | class SDXLLoraInference: 8 | """ 9 | Class for running inference using the SDXL-LoRA model to generate stunning product photographs. 10 | 11 | Args: 12 | num_inference_steps (int): The number of inference steps to perform. 13 | guidance_scale (float): The scale factor for guidance during inference. 14 | """ 15 | def __init__(self, num_inference_steps: int, guidance_scale: float) -> None: 16 | self.model_path = "VikramSingh178/sdxl-lora-finetune-product-caption" 17 | self.pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16) 18 | self.pipe.to("cuda") 19 | self.pipe.load_lora_weights(self.model_path) 20 | self.num_inference_steps = num_inference_steps 21 | self.guidance_scale = guidance_scale 22 | 23 | def run_inference(self, prompt): 24 | """ 25 | Runs inference using the SDXL-LoRA model to generate a stunning product photograph. 26 | 27 | Args: 28 | prompt: The input prompt for generating the product photograph. 29 | 30 | Returns: 31 | images: The generated product photograph(s). 
32 | """ 33 | 34 | prompt = prompt 35 | images = self.pipe(prompt, num_inference_steps=self.num_inference_steps, guidance_scale=self.guidance_scale).images 36 | return images 37 | 38 | inference = SDXLLoraInference(num_inference_steps=100, guidance_scale=2.5) 39 | inference.run_inference(prompt= "A stunning 4k Shot of a Balenciaga X Anime Hoodie with a person wearing it in a party" ) 40 | -------------------------------------------------------------------------------- /scripts/s3_manager.py: -------------------------------------------------------------------------------- 1 | import base64 2 | import io 3 | import os 4 | import boto3 5 | from botocore.config import Config 6 | import random 7 | import string 8 | from config_settings import settings 9 | 10 | 11 | 12 | 13 | class S3ManagerService: 14 | def __init__(self): 15 | self.s3 = boto3.client( 16 | "s3", 17 | config=Config(signature_version="s3v4"), 18 | aws_access_key_id=settings.AWS_ACCESS_KEY_ID, 19 | aws_secret_access_key=settings.AWS_SECRET_ACCESS_KEY, 20 | region_name=settings.AWS_REGION, 21 | ) 22 | 23 | def generate_signed_url(self, file_name: str, exp: int = 43200) -> str: 24 | try: 25 | url = self.s3.generate_presigned_url( 26 | ClientMethod='get_object', 27 | Params={ 28 | 'Bucket': settings.AWS_BUCKET_NAME, 29 | 'Key': file_name 30 | }, 31 | ExpiresIn=exp, 32 | HttpMethod='GET' 33 | ) 34 | return url 35 | except Exception as e: 36 | print(f"Error generating presigned URL: {e}") 37 | return None 38 | 39 | def generate_unique_file_name(self, file_name: str) -> str: 40 | random_string = "".join( 41 | random.choices(string.ascii_uppercase + string.digits, k=10) 42 | ) 43 | file_extension = "png" 44 | file_real_name = file_name.split(".")[0] 45 | return f"{file_real_name}-{random_string}.{file_extension}" 46 | 47 | def upload_file(self, file, file_name) -> str: 48 | self.s3.upload_fileobj(file, settings.AWS_BUCKET_NAME, file_name) 49 | return file_name 50 | 51 | def upload_base64_file(self, base64_file: str, file_name: str) -> str: 52 | return self.upload_file(io.BytesIO(base64.b64decode(base64_file)), file_name) 53 | 54 | def get_object(self, file_name: str, bucket: str): 55 | try: 56 | return self.s3.get_object(Bucket=bucket, Key=file_name) 57 | except self.s3.exceptions.NoSuchKey: 58 | print(f"The file {file_name} does not exist in the bucket {bucket}.") 59 | return None 60 | except Exception as e: 61 | print(f"An error occurred: {e}") 62 | return None 63 | -------------------------------------------------------------------------------- /scripts/wandb/run-20240429_145519-r0eclldx/files/requirements.txt: -------------------------------------------------------------------------------- 1 | GitPython==3.1.43 2 | Jinja2==3.1.3 3 | Markdown==3.6 4 | MarkupSafe==2.1.5 5 | PyYAML==6.0.1 6 | Pygments==2.17.2 7 | Werkzeug==3.0.2 8 | absl-py==2.1.0 9 | accelerate==0.29.3 10 | aiohttp==3.9.5 11 | aiosignal==1.3.1 12 | annotated-types==0.6.0 13 | anyio==4.3.0 14 | appdirs==1.4.4 15 | async-timeout==4.0.3 16 | attrs==23.2.0 17 | bitsandbytes==0.43.1 18 | certifi==2024.2.2 19 | charset-normalizer==3.3.2 20 | click==8.1.7 21 | contourpy==1.2.1 22 | cycler==0.12.1 23 | datasets==2.19.0 24 | diffusers==0.27.2 25 | dill==0.3.8 26 | docker-pycreds==0.4.0 27 | exceptiongroup==1.2.1 28 | fastapi==0.110.2 29 | filelock==3.13.4 30 | fonttools==4.51.0 31 | frozenlist==1.4.1 32 | fsspec==2024.3.1 33 | ftfy==6.2.0 34 | gitdb==4.0.11 35 | grpcio==1.62.2 36 | h11==0.14.0 37 | huggingface-hub==0.22.2 38 | idna==3.7 39 | importlib_metadata==7.1.0 40 | 
kiwisolver==1.4.5 41 | lightning-utilities==0.11.2 42 | lightning==2.2.3 43 | markdown-it-py==3.0.0 44 | matplotlib==3.8.4 45 | mdurl==0.1.2 46 | mpmath==1.3.0 47 | multidict==6.0.5 48 | multiprocess==0.70.16 49 | networkx==3.3 50 | numpy==1.26.4 51 | nvidia-cublas-cu12==12.1.3.1 52 | nvidia-cuda-cupti-cu12==12.1.105 53 | nvidia-cuda-nvrtc-cu12==12.1.105 54 | nvidia-cuda-runtime-cu12==12.1.105 55 | nvidia-cudnn-cu12==8.9.2.26 56 | nvidia-cufft-cu12==11.0.2.54 57 | nvidia-curand-cu12==10.3.2.106 58 | nvidia-cusolver-cu12==11.4.5.107 59 | nvidia-cusparse-cu12==12.1.0.106 60 | nvidia-nccl-cu12==2.20.5 61 | nvidia-nvjitlink-cu12==12.4.127 62 | nvidia-nvtx-cu12==12.1.105 63 | opencv-python-headless==4.9.0.80 64 | opencv-python==4.9.0.80 65 | packaging==24.0 66 | pandas==2.2.2 67 | peft==0.10.0 68 | pillow==10.3.0 69 | pip==24.0 70 | protobuf==4.25.3 71 | psutil==5.9.8 72 | py-cpuinfo==9.0.0 73 | pyarrow-hotfix==0.6 74 | pyarrow==16.0.0 75 | pydantic==2.7.1 76 | pydantic_core==2.18.2 77 | pyparsing==3.1.2 78 | python-dateutil==2.9.0.post0 79 | pytorch-lightning==2.2.3 80 | pytz==2024.1 81 | regex==2024.4.16 82 | requests==2.31.0 83 | rich==13.7.1 84 | safetensors==0.4.3 85 | scipy==1.13.0 86 | seaborn==0.13.2 87 | sentry-sdk==2.0.1 88 | setproctitle==1.3.3 89 | setuptools==69.5.1 90 | six==1.16.0 91 | smmap==5.0.1 92 | sniffio==1.3.1 93 | starlette==0.37.2 94 | sympy==1.12 95 | tensorboard-data-server==0.7.2 96 | tensorboard==2.16.2 97 | thop==0.1.1-2209072238 98 | tokenizers==0.19.1 99 | torch==2.3.0 100 | torchmetrics==1.3.2 101 | torchvision==0.18.0 102 | tqdm==4.66.2 103 | transformers==4.40.1 104 | triton==2.3.0 105 | typing_extensions==4.11.0 106 | tzdata==2024.1 107 | ultralytics==8.2.4 108 | urllib3==2.2.1 109 | uvicorn==0.29.0 110 | wandb==0.16.6 111 | wcwidth==0.2.13 112 | xxhash==3.4.1 113 | yarl==1.9.4 114 | zipp==3.18.1 -------------------------------------------------------------------------------- /scripts/config.py: -------------------------------------------------------------------------------- 1 | 2 | MODEL_NAME:str="stabilityai/stable-diffusion-xl-base-1.0" 3 | ADAPTER_NAME:str = "VikramSingh178/sdxl-lora-finetune-product-caption" 4 | ADAPTER_NAME_2:str = "VikramSingh178/Products10k-SDXL-Lora" 5 | VAE_NAME:str= "madebyollin/sdxl-vae-fp16-fix" 6 | DATASET_NAME:str = "hahminlew/kream-product-blip-captions" 7 | PROJECT_NAME:str = "Product Photography" 8 | PRODUCTS_10k_DATASET:str = "VikramSingh178/Products-10k-BLIP-captions" 9 | CAPTIONING_MODEL_NAME:str = "Salesforce/blip-image-captioning-base" 10 | SEGMENTATION_MODEL_NAME:str = "facebook/sam-vit-large" 11 | DETECTION_MODEL_NAME:str = "yolov8l" 12 | ENABLE_COMPILE:bool = False 13 | INPAINTING_MODEL_NAME:str = 'kandinsky-community/kandinsky-2-2-decoder-inpaint' 14 | 15 | 16 | 17 | 18 | class Config: 19 | def __init__(self): 20 | self.pretrained_model_name_or_path = MODEL_NAME 21 | self.pretrained_vae_model_name_or_path = VAE_NAME 22 | self.revision = None 23 | self.variant = None 24 | self.dataset_name = PRODUCTS_10k_DATASET 25 | self.dataset_config_name = None 26 | self.train_data_dir = None 27 | self.image_column = 'image' 28 | self.caption_column = 'text' 29 | self.validation_prompt = None 30 | self.num_validation_images = 4 31 | self.validation_epochs = 1 32 | self.max_train_samples = 7 33 | self.output_dir = "output" 34 | self.cache_dir = None 35 | self.seed = 42 36 | self.resolution = 512 37 | self.center_crop = True 38 | self.random_flip = True 39 | self.train_text_encoder = False 40 | self.train_batch_size = 64 41 | 
self.num_train_epochs = 400 42 | self.max_train_steps = None 43 | self.checkpointing_steps = 500 44 | self.checkpoints_total_limit = None 45 | self.resume_from_checkpoint = None 46 | self.gradient_accumulation_steps = 1 47 | self.gradient_checkpointing = False 48 | self.learning_rate = 1e-4 49 | self.scale_lr = False 50 | self.lr_scheduler = "constant" 51 | self.lr_warmup_steps = 500 52 | self.snr_gamma = None 53 | self.allow_tf32 = True 54 | self.dataloader_num_workers = 0 55 | self.use_8bit_adam = True 56 | self.adam_beta1 = 0.9 57 | self.adam_beta2 = 0.999 58 | self.adam_weight_decay = 1e-2 59 | self.adam_epsilon = 1e-08 60 | self.max_grad_norm = 1.0 61 | self.push_to_hub = True 62 | self.hub_token = None 63 | self.prediction_type = None 64 | self.hub_model_id = None 65 | self.logging_dir = "logs" 66 | self.report_to = "wandb" 67 | self.mixed_precision = 'fp16' 68 | self.local_rank = -1 69 | self.enable_xformers_memory_efficient_attention = False 70 | self.noise_offset = 0 71 | self.rank = 4 72 | self.debug_loss = False 73 | 74 | 75 | -------------------------------------------------------------------------------- /test-scripts/test_outpainting.py: -------------------------------------------------------------------------------- 1 | import requests 2 | import base64 3 | import json 4 | from PIL import Image 5 | import io 6 | 7 | def encode_image_to_base64(image_path): 8 | """ 9 | Encode an image file to base64 string. 10 | 11 | Args: 12 | image_path (str): Path to the image file. 13 | 14 | Returns: 15 | str: Base64 encoded string of the image. 16 | """ 17 | with open(image_path, "rb") as image_file: 18 | return base64.b64encode(image_file.read()).decode('utf-8') 19 | 20 | def test_outpainting_api(server_url, input_image_path, prompt): 21 | """ 22 | Test the Outpainting API by sending a request and processing the response. 23 | 24 | Args: 25 | server_url (str): URL of the Outpainting API server. 26 | input_image_path (str): Path to the input image file. 27 | prompt (str): The prompt for outpainting. 
28 | 29 | Returns: 30 | None 31 | """ 32 | # Prepare the request payload 33 | payload = { 34 | "image": encode_image_to_base64(input_image_path), 35 | "width": 2560, 36 | "height": 1440, 37 | "overlap_percentage": 10, 38 | "num_inference_steps": 8, 39 | "resize_option": "Full", 40 | "custom_resize_percentage": 100, 41 | "prompt_input": prompt, 42 | "alignment": "Middle", 43 | "overlap_left": True, 44 | "overlap_right": True, 45 | "overlap_top": True, 46 | "overlap_bottom": True 47 | } 48 | 49 | # Send POST request to the server 50 | try: 51 | response = requests.post(server_url, json=payload) 52 | response.raise_for_status() # Raise an exception for bad status codes 53 | except requests.exceptions.RequestException as e: 54 | print(f"Error sending request: {e}") 55 | return 56 | 57 | # Process the response 58 | try: 59 | result = response.json() 60 | print(f"Response received:") 61 | print(f"Completion time: {result['completion_time']} seconds") 62 | print(f"Prompt ratio: {result['prompt_ratio']}") 63 | print(f"Image resolution: {result['image_resolution']}") 64 | 65 | # Decode and save the result image 66 | image_data = base64.b64decode(result['result']) 67 | result_image = Image.open(io.BytesIO(image_data)) 68 | result_image.save("outpainting_result.png") 69 | print("Result image saved as 'outpainting_result.png'") 70 | 71 | except (json.JSONDecodeError, KeyError) as e: 72 | print(f"Error processing response: {e}") 73 | print(f"Response content: {response.text}") 74 | 75 | if __name__ == "__main__": 76 | SERVER_URL = "http://localhost:8000/predict" # Adjust this to your server's address 77 | INPUT_IMAGE_PATH = "/root/PicPilot/sample_data/example3.jpg" # Replace with your input image path 78 | PROMPT = "A beautiful landscape " # Replace with your desired prompt 79 | 80 | test_outpainting_api(SERVER_URL, INPUT_IMAGE_PATH, PROMPT) -------------------------------------------------------------------------------- /test-scripts/test.py: -------------------------------------------------------------------------------- 1 | import requests 2 | import base64 3 | import json 4 | from PIL import Image 5 | import io 6 | 7 | def encode_image_to_base64(image_path): 8 | """ 9 | Encode an image file to base64 string. 10 | 11 | Args: 12 | image_path (str): Path to the image file. 13 | 14 | Returns: 15 | str: Base64 encoded string of the image. 16 | """ 17 | with open(image_path, "rb") as image_file: 18 | return base64.b64encode(image_file.read()).decode('utf-8') 19 | 20 | def test_flux_inpainting_api(server_url, input_image_path, mask_image_path, prompt): 21 | """ 22 | Test the Flux Inpainting API by sending a request and processing the response. 23 | 24 | Args: 25 | server_url (str): URL of the Flux Inpainting API server. 26 | input_image_path (str): Path to the input image file. 27 | mask_image_path (str): Path to the mask image file. 28 | prompt (str): The prompt for inpainting. 
29 | 30 | Returns: 31 | None 32 | """ 33 | # Prepare the request payload 34 | payload = { 35 | "prompt": prompt, 36 | "strength": 0.8, 37 | "seed": 42, 38 | "num_inference_steps": 50, 39 | "input_image": encode_image_to_base64(input_image_path), 40 | "mask_image": encode_image_to_base64(mask_image_path) 41 | } 42 | 43 | # Send POST request to the server 44 | try: 45 | response = requests.post(server_url, json=payload) 46 | response.raise_for_status() # Raise an exception for bad status codes 47 | except requests.exceptions.RequestException as e: 48 | print(f"Error sending request: {e}") 49 | return 50 | 51 | # Process the response 52 | try: 53 | result = response.json() 54 | print(f"Response received:") 55 | print(f"Result URL: {result['result_url']}") 56 | print(f"Prompt: {result['prompt']}") 57 | print(f"Seed: {result['seed']}") 58 | print(f"Time taken: {result['time_taken']} seconds") 59 | 60 | # Download and save the result image 61 | image_response = requests.get(result['result_url']) 62 | image_response.raise_for_status() 63 | result_image = Image.open(io.BytesIO(image_response.content)) 64 | result_image.save("inpainting_result.png") 65 | print("Result image saved as 'inpainting_result.png'") 66 | 67 | except (json.JSONDecodeError, KeyError) as e: 68 | print(f"Error processing response: {e}") 69 | print(f"Response content: {response.text}") 70 | 71 | if __name__ == "__main__": 72 | SERVER_URL = "http://localhost:8000/api/v2/inpainting/flux" # Adjust this to your server's address 73 | INPUT_IMAGE_PATH = "/root/PicPilot/sample_data/image.jpg" # Replace with your input image path 74 | MASK_IMAGE_PATH = "/root/PicPilot/sample_data/mask.png" # Replace with your mask image path 75 | PROMPT = "Signora Cooker" # Replace with your desired prompt 76 | 77 | test_flux_inpainting_api(SERVER_URL, INPUT_IMAGE_PATH, MASK_IMAGE_PATH, PROMPT) -------------------------------------------------------------------------------- /scripts/inpainting_pipeline.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from diffusers import AutoPipelineForInpainting,DiffusionPipeline 3 | from diffusers.utils import load_image 4 | from scripts.api_utils import accelerator, ImageAugmentation 5 | import hydra 6 | from omegaconf import DictConfig 7 | from PIL import Image 8 | 9 | def load_pipeline(model_name: str, device, enable_compile: bool = True): 10 | pipeline = AutoPipelineForInpainting.from_pretrained(model_name, torch_dtype=torch.float16) 11 | if enable_compile: 12 | pipeline.unet.to(memory_format=torch.channels_last) 13 | pipeline.unet = torch.compile(pipeline.unet, mode='reduce-overhead',fullgraph=True) 14 | pipeline.to(device) 15 | return pipeline 16 | 17 | 18 | class AutoPaintingPipeline: 19 | def __init__(self, pipeline, image: Image, mask_image: Image, target_width: int, target_height: int): 20 | self.pipeline = pipeline 21 | self.image = image 22 | self.mask_image = mask_image 23 | self.target_width = target_width 24 | self.target_height = target_height 25 | 26 | def run_inference(self, prompt: str, negative_prompt: str, num_inference_steps: int, strength: float, guidance_scale: float,num_images): 27 | output = self.pipeline( 28 | prompt=prompt, 29 | negative_prompt=negative_prompt, 30 | image=self.image, 31 | mask_image=self.mask_image, 32 | num_inference_steps=num_inference_steps, 33 | strength=strength, 34 | guidance_scale=guidance_scale, 35 | num_images_per_prompt = num_images, 36 | height=self.target_height, 37 | width=self.target_width 38 | 39 | 
).images[0] 40 | return output 41 | 42 | @hydra.main(version_base=None, config_path="../configs", config_name="inpainting") 43 | def inference(cfg: DictConfig): 44 | # Load the pipeline once and cache it 45 | pipeline = load_pipeline(cfg.model, accelerator(), True) 46 | 47 | # Image augmentation and preparation 48 | augmenter = ImageAugmentation(target_width=cfg.target_width, target_height=cfg.target_height) 49 | image_path = "../sample_data/example3.jpg" 50 | image = Image.open(image_path) 51 | extended_image = augmenter.extend_image(image) 52 | mask_image = augmenter.generate_mask_from_bbox(extended_image, cfg.segmentation_model, cfg.detection_model) 53 | mask_image = augmenter.invert_mask(mask_image) 54 | 55 | 56 | painting_pipeline = AutoPaintingPipeline( 57 | pipeline=pipeline, 58 | image=extended_image, 59 | mask_image=mask_image, 60 | target_height=cfg.target_height, 61 | target_width=cfg.target_width 62 | ) 63 | 64 | # Run inference 65 | output = painting_pipeline.run_inference( 66 | prompt=cfg.prompt, 67 | negative_prompt=cfg.negative_prompt, 68 | num_inference_steps=cfg.num_inference_steps, 69 | strength=cfg.strength, 70 | guidance_scale=cfg.guidance_scale 71 | ) 72 | 73 | # Save output and mask images 74 | output.save(f'{cfg.output_path}/output.jpg') 75 | mask_image.save(f'{cfg.output_path}/mask.jpg') 76 | 77 | if __name__ == "__main__": 78 | inference() 79 | 80 | -------------------------------------------------------------------------------- /scripts/wandb/run-20240429_145519-r0eclldx/files/output.log: -------------------------------------------------------------------------------- 1 | 04/29/2024 14:55:21 - INFO - __main__ - ***** Running training ***** 2 | 04/29/2024 14:55:21 - INFO - __main__ - Num examples = 14904 3 | 04/29/2024 14:55:21 - INFO - __main__ - Num Epochs = 200 4 | 04/29/2024 14:55:21 - INFO - __main__ - Instantaneous batch size per device = 16 5 | 04/29/2024 14:55:21 - INFO - __main__ - Total train batch size (w. 
parallel, distributed & accumulation) = 16 6 | 04/29/2024 14:55:21 - INFO - __main__ - Gradient Accumulation steps = 1 7 | 04/29/2024 14:55:21 - INFO - __main__ - Total optimization steps = 186400 8 | Steps: 0%| | 0/186400 [00:00 10 | main() 11 | File "/home/product_diffusion_api/scripts/sdxl_lora_tuner.py", line 814, in main 12 | model_input = vae.encode(pixel_values).latent_dist.sample() 13 | File "/home/product_diffusion_api/.venv/lib/python3.10/site-packages/diffusers/utils/accelerate_utils.py", line 46, in wrapper 14 | return method(self, *args, **kwargs) 15 | File "/home/product_diffusion_api/.venv/lib/python3.10/site-packages/diffusers/models/autoencoders/autoencoder_kl.py", line 260, in encode 16 | h = self.encoder(x) 17 | File "/home/product_diffusion_api/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl 18 | return self._call_impl(*args, **kwargs) 19 | File "/home/product_diffusion_api/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl 20 | return forward_call(*args, **kwargs) 21 | File "/home/product_diffusion_api/.venv/lib/python3.10/site-packages/diffusers/models/autoencoders/vae.py", line 172, in forward 22 | sample = down_block(sample) 23 | File "/home/product_diffusion_api/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl 24 | return self._call_impl(*args, **kwargs) 25 | File "/home/product_diffusion_api/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl 26 | return forward_call(*args, **kwargs) 27 | File "/home/product_diffusion_api/.venv/lib/python3.10/site-packages/diffusers/models/unets/unet_2d_blocks.py", line 1465, in forward 28 | hidden_states = resnet(hidden_states, temb=None) 29 | File "/home/product_diffusion_api/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl 30 | return self._call_impl(*args, **kwargs) 31 | File "/home/product_diffusion_api/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl 32 | return forward_call(*args, **kwargs) 33 | File "/home/product_diffusion_api/.venv/lib/python3.10/site-packages/diffusers/models/resnet.py", line 332, in forward 34 | hidden_states = self.norm1(hidden_states) 35 | File "/home/product_diffusion_api/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl 36 | return self._call_impl(*args, **kwargs) 37 | File "/home/product_diffusion_api/.venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl 38 | return forward_call(*args, **kwargs) 39 | File "/home/product_diffusion_api/.venv/lib/python3.10/site-packages/torch/nn/modules/normalization.py", line 287, in forward 40 | return F.group_norm( 41 | File "/home/product_diffusion_api/.venv/lib/python3.10/site-packages/torch/nn/functional.py", line 2588, in group_norm 42 | return torch.group_norm(input, num_groups, weight, bias, eps, torch.backends.cudnn.enabled) 43 | torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate 8.00 GiB. 
GPU -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | ![Picpilot Amplify Dark](https://github.com/user-attachments/assets/4c125f98-329e-4661-bc74-15a5b23899e1) 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | [![GitHub Stars](https://img.shields.io/github/stars/VikramxD/Picpilot?style=social)](https://github.com/VikramxD/picpilot/stargazers) 13 | [![License](https://img.shields.io/badge/license-MIT-blue.svg)](https://github.com/VikramxD/Picpilot/blob/main/LICENSE) 14 | 15 | 16 | > PicPilot is a scalable solution that leverages state-of-the-art Text to Image Models to extend and enhance images and create product photography in seconds for your brand. Whether you're working with existing product images or need to generate new visuals, PicPilot prepares your visual content to be stunning, professional, and ready for marketing applications. 17 | 18 | 19 | 20 | ## Features 21 | ✅ Flux Inpainting for detailed image editing 22 | ✅ SDXL (Stable Diffusion XL) with LoRA for high-quality image generation 23 | ✅ SDXL Outpainting for extending images seamlessly 24 | ✅ Image to Video Generation using CogvideoX 25 | ✅ Batch processing support with configurable batch sizes and timeouts 26 | 27 | ### Why PicPilot? 28 | Creating professional product photography and visual narratives can be time-consuming and expensive. PicPilot aims to revolutionize this process by offering an AI-powered platform where you can enhance existing images, generate new ones, or even convert images to videos, creating stunning visuals for your brand in seconds. 29 | 30 | ## Installation 31 | 32 | ```bash 33 | git clone https://github.com/VikramxD/Picpilot 34 | cd Picpilot 35 | ``` 36 | 37 | 38 | 39 | Install Dependencies: 40 | 41 | ```bash 42 | ./run.sh 43 | ``` 44 | 45 | ### 🛳️ Docker 46 | 47 | To use PicPilot with Docker, execute the following commands: 48 | 49 | ```bash 50 | docker pull vikram1202/picpilot:latest 51 | docker run --gpus all -p 8000:8000 vikram1202/picpilot:latest 52 | ``` 53 | 54 | Alternatively, if you prefer to build the Docker image locally: 55 | 56 | ```bash 57 | docker build -t picpilot . 58 | docker run --gpus all -p 8000:8000 picpilot 59 | ``` 60 | 61 | ## Usage 62 | 63 | Run the Server: 64 | 65 | ```bash 66 | cd api 67 | python picpilot.py 68 | ``` 69 | 70 | This will start the server on port 8000 with all available API endpoints. 
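Once the server is running, every endpoint in the table below accepts a JSON POST request. The snippet that follows is a minimal sketch of a text-to-image call against the SDXL endpoint; it mirrors the payload shape used in `test-scripts/test_sdxl.py`, and the host, prompt, and parameter values are illustrative placeholders you should adapt to your own deployment.

```python
import requests

# Assumes a local deployment; adjust the host/port to match your server
url = "http://localhost:8000/api/v2/generate/sdxl"

payload = {
    "prompt": "A studio product shot of a leather backpack on a wooden table",
    "negative_prompt": "low resolution, blurry",
    "num_images": 1,
    "num_inference_steps": 30,
    "guidance_scale": 7.5,
    "mode": "s3_json",  # use "b64_json" to get the image back as base64 instead of an S3 URL
}

response = requests.post(url, json=payload, timeout=300)
response.raise_for_status()
print(response.json())  # contains an S3 URL (s3_json) or a base64-encoded image (b64_json)
```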
71 | 72 | ## API Endpoints 73 | 74 | PicPilot offers the following API endpoints: 75 | 76 | | Endpoint | Path | Purpose | Max Batch Size | Batch Timeout | 77 | |----------|------|---------|----------------|---------------| 78 | | Flux Inpainting | `/api/v2/painting/flux` | Detailed image editing and inpainting | 4 | 0.1 seconds | 79 | | SDXL Generation | `/api/v2/generate/sdxl` | High-quality image generation using SDXL with LoRA | Configured in `tti_settings` | Configured in `tti_settings` | 80 | | SDXL Outpainting | `/api/v2/painting/sdxl_outpainting` | Extending images seamlessly | 4 | 0.1 seconds | 81 | | Image to Video | `/api/v2/image2video/cogvideox` | Converting images to videos | 1 | 0.1 seconds | 82 | 83 | ## Next Features 84 | - Support for Image Editing in FLUX Models 85 | - Support for Custom Flux LoRAs 86 | - Support for CogvideoX fine-tuning 87 | 88 | ## Limitations 89 | - Requires powerful GPUs for optimal performance, especially for the FLUX models 90 | - Processing time may vary depending on the complexity of the task and input size 91 | - Image-to-video conversion is limited to one image at a time 92 | 93 | ## License 94 | PicPilot is licensed under the MIT license. See `LICENSE` for more information. 95 | 96 | ## Acknowledgements 97 | 98 | This project utilizes several open-source models and libraries. We express our gratitude to the authors and contributors of: 99 | 100 | - Diffusers 101 | - LitServe 102 | - Transformers 103 | 104 | --- 105 | 106 | -------------------------------------------------------------------------------- /scripts/wandb/run-20240429_145519-r0eclldx/files/config.yaml: -------------------------------------------------------------------------------- 1 | wandb_version: 1 2 | 3 | _wandb: 4 | desc: null 5 | value: 6 | code_path: code/scripts/sdxl_lora_tuner.py 7 | python_version: 3.10.12 8 | cli_version: 0.16.6 9 | framework: huggingface 10 | huggingface_version: 4.40.1 11 | is_jupyter_run: false 12 | is_kaggle_kernel: false 13 | start_time: 1714402519.0 14 | t: 15 | 1: 16 | - 1 17 | - 11 18 | - 41 19 | - 49 20 | - 51 21 | - 55 22 | - 71 23 | - 83 24 | - 98 25 | 2: 26 | - 1 27 | - 11 28 | - 41 29 | - 49 30 | - 51 31 | - 55 32 | - 71 33 | - 83 34 | - 98 35 | 3: 36 | - 23 37 | 4: 3.10.12 38 | 5: 0.16.6 39 | 6: 4.40.1 40 | 8: 41 | - 5 42 | 13: linux-x86_64 43 | pretrained_model_name_or_path: 44 | desc: null 45 | value: stabilityai/stable-diffusion-xl-base-1.0 46 | pretrained_vae_model_name_or_path: 47 | desc: null 48 | value: madebyollin/sdxl-vae-fp16-fix 49 | revision: 50 | desc: null 51 | value: null 52 | variant: 53 | desc: null 54 | value: null 55 | dataset_name: 56 | desc: null 57 | value: hahminlew/kream-product-blip-captions 58 | dataset_config_name: 59 | desc: null 60 | value: null 61 | train_data_dir: 62 | desc: null 63 | value: null 64 | image_column: 65 | desc: null 66 | value: image 67 | caption_column: 68 | desc: null 69 | value: text 70 | validation_prompt: 71 | desc: null 72 | value: null 73 | num_validation_images: 74 | desc: null 75 | value: 4 76 | validation_epochs: 77 | desc: null 78 | value: 1 79 | max_train_samples: 80 | desc: null 81 | value: null 82 | output_dir: 83 | desc: null 84 | value: output 85 | cache_dir: 86 | desc: null 87 | value: null 88 | seed: 89 | desc: null 90 | value: null 91 | resolution: 92 | desc: null 93 | value: 1024 94 | center_crop: 95 | desc: null 96 | value: false 97 | random_flip: 98 | desc: null 99 | value: false 100 | train_text_encoder: 101 | desc: null 102 | value: false 103 | train_batch_size: 104 | 
desc: null 105 | value: 16 106 | num_train_epochs: 107 | desc: null 108 | value: 200 109 | max_train_steps: 110 | desc: null 111 | value: 186400 112 | checkpointing_steps: 113 | desc: null 114 | value: 500 115 | checkpoints_total_limit: 116 | desc: null 117 | value: null 118 | resume_from_checkpoint: 119 | desc: null 120 | value: null 121 | gradient_accumulation_steps: 122 | desc: null 123 | value: 1 124 | gradient_checkpointing: 125 | desc: null 126 | value: false 127 | learning_rate: 128 | desc: null 129 | value: 0.0001 130 | scale_lr: 131 | desc: null 132 | value: false 133 | lr_scheduler: 134 | desc: null 135 | value: constant 136 | lr_warmup_steps: 137 | desc: null 138 | value: 500 139 | snr_gamma: 140 | desc: null 141 | value: null 142 | allow_tf32: 143 | desc: null 144 | value: false 145 | dataloader_num_workers: 146 | desc: null 147 | value: 0 148 | use_8bit_adam: 149 | desc: null 150 | value: true 151 | adam_beta1: 152 | desc: null 153 | value: 0.9 154 | adam_beta2: 155 | desc: null 156 | value: 0.999 157 | adam_weight_decay: 158 | desc: null 159 | value: 0.01 160 | adam_epsilon: 161 | desc: null 162 | value: 1.0e-08 163 | max_grad_norm: 164 | desc: null 165 | value: 1.0 166 | push_to_hub: 167 | desc: null 168 | value: false 169 | hub_token: 170 | desc: null 171 | value: null 172 | prediction_type: 173 | desc: null 174 | value: null 175 | hub_model_id: 176 | desc: null 177 | value: null 178 | logging_dir: 179 | desc: null 180 | value: logs 181 | report_to: 182 | desc: null 183 | value: wandb 184 | mixed_precision: 185 | desc: null 186 | value: null 187 | local_rank: 188 | desc: null 189 | value: -1 190 | enable_xformers_memory_efficient_attention: 191 | desc: null 192 | value: false 193 | noise_offset: 194 | desc: null 195 | value: 0 196 | rank: 197 | desc: null 198 | value: 4 199 | debug_loss: 200 | desc: null 201 | value: false 202 | -------------------------------------------------------------------------------- /test-scripts/test_image2video.py: -------------------------------------------------------------------------------- 1 | import requests 2 | import base64 3 | import json 4 | from PIL import Image 5 | import io 6 | import os 7 | 8 | def encode_image_to_base64(image_path): 9 | """ 10 | Encode an image file to base64 string. 11 | 12 | Args: 13 | image_path (str): Path to the image file. 14 | 15 | Returns: 16 | str: Base64 encoded string of the image. 17 | """ 18 | try: 19 | with open(image_path, "rb") as image_file: 20 | return base64.b64encode(image_file.read()).decode('utf-8') 21 | except Exception as e: 22 | print(f"Error encoding image: {e}") 23 | return None 24 | 25 | def validate_image(base64_string): 26 | """ 27 | Validate the base64 encoded image by attempting to open it with PIL. 28 | 29 | Args: 30 | base64_string (str): Base64 encoded image string. 31 | 32 | Returns: 33 | bool: True if valid, False otherwise. 34 | """ 35 | try: 36 | image_data = base64.b64decode(base64_string) 37 | Image.open(io.BytesIO(image_data)) 38 | return True 39 | except Exception as e: 40 | print(f"Error validating image: {e}") 41 | return False 42 | 43 | def test_image_to_video_api(server_url, input_image_path, prompt): 44 | """ 45 | Test the Image-to-Video API by sending a request and processing the response. 46 | 47 | Args: 48 | server_url (str): URL of the Image-to-Video API server. 49 | input_image_path (str): Path to the input image file. 50 | prompt (str): The prompt for video generation. 
51 | 52 | Returns: 53 | None 54 | """ 55 | # Encode and validate the image 56 | base64_image = encode_image_to_base64(input_image_path) 57 | if not base64_image: 58 | print("Failed to encode image.") 59 | return 60 | 61 | if not validate_image(base64_image): 62 | print("Encoded image is not valid.") 63 | return 64 | 65 | # Prepare the request payload 66 | payload = { 67 | "image": base64_image, 68 | "prompt": prompt, 69 | "num_frames": 49, 70 | "num_inference_steps": 20, 71 | "guidance_scale": 6.0, 72 | "height": 480, 73 | "width": 720, 74 | "use_dynamic_cfg": True, 75 | "fps": 10 76 | } 77 | 78 | # Send POST request to the server 79 | try: 80 | response = requests.post(server_url, json=payload) 81 | response.raise_for_status() # Raise an exception for bad status codes 82 | except requests.exceptions.RequestException as e: 83 | print(f"Error sending request: {e}") 84 | print(f"Response status code: {response.status_code}") 85 | print(f"Response content: {response.text}") 86 | return 87 | 88 | # Process the response 89 | try: 90 | result = response.json() 91 | print(f"Response received:") 92 | print(f"Completion time: {result['completion_time']} seconds") 93 | print(f"Video resolution: {result['video_resolution']}") 94 | print(f"FPS: {result['fps']}") 95 | 96 | # Save the result video 97 | video_url = result['result']['url'] 98 | video_response = requests.get(video_url) 99 | video_response.raise_for_status() 100 | 101 | output_path = "generated_video.mp4" 102 | with open(output_path, "wb") as video_file: 103 | video_file.write(video_response.content) 104 | 105 | print(f"Result video saved as '{output_path}'") 106 | 107 | except (json.JSONDecodeError, KeyError, requests.exceptions.RequestException) as e: 108 | print(f"Error processing response: {e}") 109 | print(f"Response content: {response.text}") 110 | 111 | if __name__ == "__main__": 112 | SERVER_URL = "http://localhost:8000/predict" # Adjust this to your server's address 113 | INPUT_IMAGE_PATH = "/root/PicPilot/sample_data/product_img.jpg" # Replace with your input image path 114 | PROMPT = "A product shot of a Nike Shoe in a studio High quality, ultrarealistic detail and breath-taking movie-like camera shot" # Replace with your desired prompt 115 | 116 | test_image_to_video_api(SERVER_URL, INPUT_IMAGE_PATH, PROMPT) -------------------------------------------------------------------------------- /scripts/products10k_captions.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from datasets import load_dataset, Dataset 3 | from transformers import BlipProcessor, BlipForConditionalGeneration 4 | from tqdm import tqdm 5 | from config import PRODUCTS_10k_DATASET, CAPTIONING_MODEL_NAME 6 | 7 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 8 | 9 | class ImageCaptioner: 10 | """ 11 | A class for generating captions for images using a pre-trained model. 12 | 13 | Args: 14 | dataset (str): The path to the dataset. 15 | processor (str): The pre-trained processor model to use for image processing. 16 | model (str): The pre-trained model to use for caption generation. 17 | prompt (str): The conditioning prompt to use for caption generation. 18 | 19 | Attributes: 20 | dataset: The loaded dataset. 21 | processor: The pre-trained processor model. 22 | model: The pre-trained caption generation model. 23 | prompt: The conditioning prompt for generating captions. 24 | 25 | Methods: 26 | process_dataset: Preprocesses the dataset. 
27 | generate_caption: Generates a caption for a single image. 28 | generate_captions: Generates captions for all images in the dataset. 29 | """ 30 | 31 | def __init__(self, dataset: str, processor: str, model: str, prompt: str = "Product photo of"): 32 | self.dataset = load_dataset(dataset, split="test") 33 | self.dataset = self.dataset.select(range(10000)) # For demonstration purposes 34 | self.processor = BlipProcessor.from_pretrained(processor) 35 | self.model = BlipForConditionalGeneration.from_pretrained(model).to(device) 36 | self.prompt = prompt 37 | 38 | def process_dataset(self): 39 | """ 40 | Preprocesses the dataset by renaming the image column and removing unwanted columns. 41 | 42 | Returns: 43 | The preprocessed dataset. 44 | """ 45 | # Check if 'image' column exists, otherwise use 'pixel_values' if it exists 46 | image_column = "image" if "image" in self.dataset.column_names else "pixel_values" 47 | self.dataset = self.dataset.rename_column(image_column, "image") 48 | 49 | if "label" in self.dataset.column_names: 50 | self.dataset = self.dataset.remove_columns(["label"]) 51 | 52 | # Add an empty 'text' column for captions if it doesn't exist 53 | if "text" not in self.dataset.column_names: 54 | new_column = [""] * len(self.dataset) 55 | self.dataset = self.dataset.add_column("text", new_column) 56 | 57 | return self.dataset 58 | 59 | def generate_caption(self, example): 60 | """ 61 | Generates a caption for a single image. 62 | 63 | Args: 64 | example (dict): A dictionary containing the image data. 65 | 66 | Returns: 67 | dict: The dictionary with the generated caption. 68 | """ 69 | image = example["image"].convert("RGB") 70 | inputs = self.processor(images=image, return_tensors="pt").to(device) 71 | prompt_inputs = self.processor(text=[self.prompt], return_tensors="pt").to(device) 72 | outputs = self.model.generate(**inputs, **prompt_inputs) 73 | blip_caption = self.processor.decode(outputs[0], skip_special_tokens=True) 74 | example["text"] = blip_caption 75 | return example 76 | 77 | def generate_captions(self): 78 | """ 79 | Generates captions for all images in the dataset. 80 | 81 | Returns: 82 | Dataset: The dataset with generated captions. 83 | """ 84 | self.dataset = self.process_dataset() 85 | self.dataset = self.dataset.map(self.generate_caption, batched=False) 86 | return self.dataset 87 | 88 | # Initialize ImageCaptioner 89 | ic = ImageCaptioner( 90 | dataset=PRODUCTS_10k_DATASET, 91 | processor=CAPTIONING_MODEL_NAME, 92 | model=CAPTIONING_MODEL_NAME, 93 | prompt='Commercial photography of' 94 | ) 95 | 96 | # Generate captions for the dataset 97 | products10k_dataset = ic.generate_captions() 98 | 99 | # Save the dataset to the hub 100 | products10k_dataset.push_to_hub("VikramSingh178/Products-10k-BLIP-captions") 101 | -------------------------------------------------------------------------------- /scripts/image_to_video.py: -------------------------------------------------------------------------------- 1 | """ 2 | This module provides a class for generating videos from images using the CogVideoX model. 3 | """ 4 | 5 | import torch 6 | from diffusers import CogVideoXImageToVideoPipeline 7 | from diffusers.utils import load_image 8 | 9 | 10 | class ImageToVideoPipeline: 11 | """ 12 | A class to generate videos from images using the CogVideoX model. 13 | 14 | This class encapsulates the functionality of the CogVideoXImageToVideoPipeline, 15 | providing methods to generate video frames from an input image and save them as a video file. 
16 | 17 | Attributes: 18 | pipe (CogVideoXImageToVideoPipeline): The underlying CogVideoX pipeline. 19 | """ 20 | 21 | def __init__( 22 | self, 23 | model_path: str = "THUDM/CogVideoX-5b-I2V", 24 | device: str = "cuda:2", 25 | torch_dtype: torch.dtype = torch.bfloat16 26 | ): 27 | """ 28 | Initialize the ImageToVideoPipeline. 29 | 30 | Args: 31 | model_path (str): Path to the pretrained CogVideoX model. 32 | device (str): The device to run the model on (e.g., "cuda:2", "cpu"). 33 | torch_dtype (torch.dtype): The torch data type to use for computations. 34 | """ 35 | self.pipe = CogVideoXImageToVideoPipeline.from_pretrained( 36 | model_path, 37 | torch_dtype=torch_dtype 38 | ) 39 | self.pipe.to(device) 40 | 41 | def generate( 42 | self, 43 | prompt: str, 44 | image: str | torch.Tensor, 45 | negative_prompt: str | None = None, 46 | num_frames: int = 49, 47 | num_inference_steps: int = 50, 48 | guidance_scale: float = 6.0, 49 | use_dynamic_cfg: bool = True, 50 | height: int = 480, 51 | width: int = 720, 52 | num_videos_per_prompt: int = 1 53 | ) -> list: 54 | """ 55 | Generate video frames from an input image. 56 | 57 | Args: 58 | prompt (str): The text prompt to guide the video generation. 59 | image (str | torch.Tensor): The input image path or tensor. 60 | negative_prompt (str | None): The negative prompt to guide the generation. 61 | num_frames (int): The number of frames to generate. 62 | num_inference_steps (int): The number of denoising steps. 63 | guidance_scale (float): The scale for classifier-free guidance. 64 | use_dynamic_cfg (bool): Whether to use dynamic CFG. 65 | height (int): The height of the output video frames. 66 | width (int): The width of the output video frames. 67 | num_videos_per_prompt (int): The number of videos to generate per prompt. 68 | 69 | Returns: 70 | list: A list of generated video frames. 71 | """ 72 | if isinstance(image, str): 73 | image = load_image(image) 74 | 75 | result = self.pipe( 76 | image=image, 77 | prompt=prompt, 78 | negative_prompt=negative_prompt, 79 | num_frames=num_frames, 80 | num_inference_steps=num_inference_steps, 81 | guidance_scale=guidance_scale, 82 | use_dynamic_cfg=use_dynamic_cfg, 83 | height=height, 84 | width=width, 85 | num_videos_per_prompt=num_videos_per_prompt 86 | ) 87 | return result.frames[0] 88 | 89 | 90 | if __name__ == "__main__": 91 | # Initialize the pipeline 92 | pipeline = ImageToVideoPipeline(device="cuda:2") 93 | prompt = ("An astronaut hatching from an egg, on the surface of the moon the darkness and depth of space realised in the background. ,High quality, ultrarealistic detail and breath-taking movie-like camera shot.") 94 | image_url = ( 95 | "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/astronaut.jpg" 96 | ) 97 | 98 | # Generate the video frames, then write them to disk as an MP4 99 | from diffusers.utils import export_to_video 100 | frames = pipeline.generate( 101 | prompt=prompt, 102 | image=image_url, 103 | num_frames=60, 104 | num_inference_steps=75, 105 | guidance_scale=7.5, 106 | use_dynamic_cfg=True, 107 | height=640, 108 | width=960 109 | ) 110 | export_to_video(frames, "custom_output.mp4", fps=30) -------------------------------------------------------------------------------- /test-scripts/test_sdxl.py: -------------------------------------------------------------------------------- 1 | import requests 2 | import json 3 | import base64 4 | from PIL import Image 5 | import io 6 | 7 | class SDXLLoraClient: 8 | """ 9 | Client for interacting with the SDXL LoRA server. 
10 | """ 11 | 12 | def __init__(self, base_url: str): 13 | """ 14 | Initialize the client with the server's base URL. 15 | 16 | Args: 17 | base_url (str): The base URL of the SDXL LoRA server. 18 | """ 19 | self.base_url = base_url 20 | 21 | def generate_image(self, prompt: str, negative_prompt: str = "", num_images: int = 1, 22 | num_inference_steps: int = 50, guidance_scale: float = 7.5, 23 | mode: str = "b64_json") -> list: 24 | """ 25 | Send a request to the server to generate images. 26 | 27 | Args: 28 | prompt (str): The prompt for image generation. 29 | negative_prompt (str, optional): The negative prompt. Defaults to "". 30 | num_images (int, optional): Number of images to generate. Defaults to 1. 31 | num_inference_steps (int, optional): Number of inference steps. Defaults to 50. 32 | guidance_scale (float, optional): Guidance scale. Defaults to 7.5. 33 | mode (str, optional): Response mode ('b64_json' or 's3_json'). Defaults to "b64_json". 34 | 35 | Returns: 36 | list: A list of generated images (as PIL Image objects) or S3 URLs. 37 | """ 38 | payload = { 39 | "prompt": prompt, 40 | "negative_prompt": negative_prompt, 41 | "num_images": num_images, 42 | "num_inference_steps": num_inference_steps, 43 | "guidance_scale": guidance_scale, 44 | "mode": mode 45 | } 46 | 47 | response = requests.post(f"{self.base_url}", json=payload) 48 | response.raise_for_status() 49 | 50 | result = response.json() 51 | print(f"Server response: {result}") # Debug print 52 | 53 | if isinstance(result, str): 54 | print(f"Unexpected string result: {result}") 55 | return [result] 56 | 57 | if mode == "b64_json": 58 | if isinstance(result, list): 59 | return [Image.open(io.BytesIO(base64.b64decode(img["base64"]))) for img in result] 60 | elif isinstance(result, dict) and "base64" in result: 61 | return [Image.open(io.BytesIO(base64.b64decode(result["base64"])))] 62 | else: 63 | raise ValueError(f"Unexpected result format for b64_json mode: {result}") 64 | elif mode == "s3_json": 65 | if isinstance(result, list): 66 | return [img["url"] for img in result if "url" in img] 67 | elif isinstance(result, dict) and "url" in result: 68 | return [result["url"]] 69 | else: 70 | raise ValueError(f"Unexpected result format for s3_json mode: {result}") 71 | else: 72 | raise ValueError("Invalid mode. Supported modes are 'b64_json' and 's3_json'.") 73 | 74 | def main(): 75 | """ 76 | Main function to demonstrate the usage of the SDXLLoraClient. 
77 | """ 78 | client = SDXLLoraClient("https://tqfdmm6n7udye4-8000.proxy.runpod.net/api/v2/generate/sdxl") 79 | 80 | # Test case 1: Generate a single image 81 | print("Generating a single image...") 82 | images = client.generate_image( 83 | prompt="A serene landscape with mountains and a lake", 84 | negative_prompt='Low resolution , Poor Resolution', 85 | mode="s3_json" 86 | ) 87 | 88 | # Test case 2: Generate multiple images 89 | print("\nGenerating multiple images...") 90 | images = client.generate_image( 91 | prompt="A futuristic cityscape at night", 92 | num_images=3, 93 | num_inference_steps=30, 94 | guidance_scale=8.0, 95 | mode="s3_json" 96 | ) 97 | for i, img in enumerate(images): 98 | if isinstance(img, Image.Image): 99 | img.save(f"test_image_2_{i+1}.png") 100 | print(f"Image saved as test_image_2_{i+1}.png") 101 | else: 102 | print(f"Unexpected result for image {i+1}: {img}") 103 | 104 | # Test case 3: Generate image with S3 storage 105 | print("\nGenerating image with S3 storage...") 106 | urls = client.generate_image( 107 | prompt="An abstract painting with vibrant colors", 108 | mode="s3_json" 109 | ) 110 | print(f"S3 URLS for the generated image: {urls}") 111 | 112 | if __name__ == "__main__": 113 | main() -------------------------------------------------------------------------------- /api/image2video_serve.py: -------------------------------------------------------------------------------- 1 | import io 2 | import os 3 | import tempfile 4 | from typing import Dict, Any, Tuple 5 | from PIL import Image 6 | import base64 7 | from pydantic import BaseModel, Field 8 | import time 9 | from diffusers.utils import export_to_video 10 | from litserve import LitAPI, LitServer 11 | from scripts.api_utils import mp4_to_s3_json 12 | from scripts.image_to_video import ImageToVideoPipeline 13 | 14 | class ImageToVideoRequest(BaseModel): 15 | """ 16 | Pydantic model representing a request for image-to-video generation. 17 | """ 18 | image: str = Field(..., description="Base64 encoded input image") 19 | prompt: str = Field(..., description="Text prompt for video generation") 20 | num_frames: int = Field(49, description="Number of frames to generate") 21 | num_inference_steps: int = Field(50, description="Number of inference steps") 22 | guidance_scale: float = Field(6.0, description="Guidance scale") 23 | height: int = Field(480, description="Height of the output video") 24 | width: int = Field(720, description="Width of the output video") 25 | use_dynamic_cfg: bool = Field(True, description="Use dynamic CFG") 26 | fps: int = Field(30, description="Frames per second for the output video") 27 | 28 | class ImageToVideoAPI(LitAPI): 29 | """ 30 | LitAPI implementation for Image-to-Video model serving. 31 | """ 32 | 33 | def setup(self, device: str) -> None: 34 | """ 35 | Set up the Image-to-Video pipeline and associated resources. 36 | """ 37 | self.device = device 38 | self.pipeline = ImageToVideoPipeline(device=device) 39 | 40 | def decode_request(self, request: Dict[str, Any]) -> Dict[str, Any]: 41 | """ 42 | Decode the incoming request and prepare inputs for the model. 
43 | """ 44 | try: 45 | video_request = ImageToVideoRequest(**request) 46 | image_data = base64.b64decode(video_request.image) 47 | image = Image.open(io.BytesIO(image_data)).convert("RGB") 48 | 49 | return { 50 | 'image': image, 51 | 'params': video_request.model_dump() 52 | } 53 | except Exception as e: 54 | raise ValueError(f"Invalid request: {str(e)}") 55 | 56 | def predict(self, inputs: Dict[str, Any]) -> Tuple[list, float, int]: 57 | """ 58 | Run predictions on the input. 59 | """ 60 | image = inputs['image'] 61 | params = inputs['params'] 62 | 63 | start_time = time.time() 64 | 65 | result = self.pipeline.generate( 66 | prompt=params['prompt'], 67 | image=image, 68 | num_frames=params['num_frames'], 69 | num_inference_steps=params['num_inference_steps'], 70 | guidance_scale=params['guidance_scale'], 71 | height=params['height'], 72 | width=params['width'], 73 | use_dynamic_cfg=params['use_dynamic_cfg'] 74 | ) 75 | 76 | if isinstance(result, tuple): 77 | frames = result[0] 78 | elif hasattr(result, 'frames'): 79 | frames = result.frames 80 | else: 81 | frames = result 82 | 83 | completion_time = time.time() - start_time 84 | return frames, completion_time, params['fps'] 85 | 86 | def encode_response(self, output: Tuple[list, float, int]) -> Dict[str, Any]: 87 | """ 88 | Encode the model output and additional information into a response payload. 89 | """ 90 | frames, completion_time, fps = output 91 | try: 92 | # Create a temporary directory to store the video file 93 | with tempfile.TemporaryDirectory() as temp_dir: 94 | temp_video_path = os.path.join(temp_dir, "generated_video.mp4") 95 | 96 | # Export the video to the temporary file 97 | export_to_video(frames, temp_video_path, fps=fps) 98 | 99 | # Read the video file and upload to S3 100 | with open(temp_video_path, "rb") as video_file: 101 | s3_response = mp4_to_s3_json(video_file, "generated_video.mp4") 102 | 103 | return { 104 | "result": s3_response, 105 | "completion_time": round(completion_time, 2), 106 | "video_resolution": f"{frames[0].width}x{frames[0].height}", 107 | "fps": fps 108 | } 109 | except Exception as e: 110 | # Log the error for debugging 111 | print(f"Error in encode_response: {str(e)}") 112 | raise 113 | 114 | if __name__ == "__main__": 115 | api = ImageToVideoAPI() 116 | server = LitServer(api, accelerator="cuda", max_batch_size=1) 117 | server.run(port=8000) -------------------------------------------------------------------------------- /scripts/wandb/run-20240429_145519-r0eclldx/logs/debug.log: -------------------------------------------------------------------------------- 1 | 2024-04-29 14:55:19,803 INFO MainThread:4560 [wandb_setup.py:_flush():76] Current SDK version is 0.16.6 2 | 2024-04-29 14:55:19,803 INFO MainThread:4560 [wandb_setup.py:_flush():76] Configure stats pid to 4560 3 | 2024-04-29 14:55:19,803 INFO MainThread:4560 [wandb_setup.py:_flush():76] Loading settings from /home/.config/wandb/settings 4 | 2024-04-29 14:55:19,804 INFO MainThread:4560 [wandb_setup.py:_flush():76] Loading settings from /home/product_diffusion_api/scripts/wandb/settings 5 | 2024-04-29 14:55:19,804 INFO MainThread:4560 [wandb_setup.py:_flush():76] Loading settings from environment variables: {} 6 | 2024-04-29 14:55:19,804 INFO MainThread:4560 [wandb_setup.py:_flush():76] Applying setup settings: {'_disable_service': False} 7 | 2024-04-29 14:55:19,804 INFO MainThread:4560 [wandb_setup.py:_flush():76] Inferring run settings from compute environment: {'program_relpath': 'scripts/sdxl_lora_tuner.py', 'program_abspath': 
'/home/product_diffusion_api/scripts/sdxl_lora_tuner.py', 'program': '/home/product_diffusion_api/scripts/sdxl_lora_tuner.py'} 8 | 2024-04-29 14:55:19,804 INFO MainThread:4560 [wandb_setup.py:_flush():76] Applying login settings: {} 9 | 2024-04-29 14:55:19,804 INFO MainThread:4560 [wandb_init.py:_log_setup():521] Logging user logs to /home/product_diffusion_api/scripts/wandb/run-20240429_145519-r0eclldx/logs/debug.log 10 | 2024-04-29 14:55:19,804 INFO MainThread:4560 [wandb_init.py:_log_setup():522] Logging internal logs to /home/product_diffusion_api/scripts/wandb/run-20240429_145519-r0eclldx/logs/debug-internal.log 11 | 2024-04-29 14:55:19,804 INFO MainThread:4560 [wandb_init.py:init():561] calling init triggers 12 | 2024-04-29 14:55:19,804 INFO MainThread:4560 [wandb_init.py:init():568] wandb.init called with sweep_config: {} 13 | config: {} 14 | 2024-04-29 14:55:19,804 INFO MainThread:4560 [wandb_init.py:init():611] starting backend 15 | 2024-04-29 14:55:19,804 INFO MainThread:4560 [wandb_init.py:init():615] setting up manager 16 | 2024-04-29 14:55:19,806 INFO MainThread:4560 [backend.py:_multiprocessing_setup():105] multiprocessing start_methods=fork,spawn,forkserver, using: spawn 17 | 2024-04-29 14:55:19,811 INFO MainThread:4560 [wandb_init.py:init():623] backend started and connected 18 | 2024-04-29 14:55:19,814 INFO MainThread:4560 [wandb_init.py:init():715] updated telemetry 19 | 2024-04-29 14:55:19,821 INFO MainThread:4560 [wandb_init.py:init():748] communicating run to backend with 90.0 second timeout 20 | 2024-04-29 14:55:20,569 INFO MainThread:4560 [wandb_run.py:_on_init():2357] communicating current version 21 | 2024-04-29 14:55:20,817 INFO MainThread:4560 [wandb_run.py:_on_init():2366] got version response 22 | 2024-04-29 14:55:20,818 INFO MainThread:4560 [wandb_init.py:init():799] starting run threads in backend 23 | 2024-04-29 14:55:21,118 INFO MainThread:4560 [wandb_run.py:_console_start():2335] atexit reg 24 | 2024-04-29 14:55:21,118 INFO MainThread:4560 [wandb_run.py:_redirect():2190] redirect: wrap_raw 25 | 2024-04-29 14:55:21,118 INFO MainThread:4560 [wandb_run.py:_redirect():2255] Wrapping output streams. 26 | 2024-04-29 14:55:21,118 INFO MainThread:4560 [wandb_run.py:_redirect():2280] Redirects installed. 
27 | 2024-04-29 14:55:21,119 INFO MainThread:4560 [wandb_init.py:init():842] run started, returning control to user process 28 | 2024-04-29 14:55:21,120 INFO MainThread:4560 [wandb_run.py:_config_callback():1347] config_cb None None {'pretrained_model_name_or_path': 'stabilityai/stable-diffusion-xl-base-1.0', 'pretrained_vae_model_name_or_path': 'madebyollin/sdxl-vae-fp16-fix', 'revision': None, 'variant': None, 'dataset_name': 'hahminlew/kream-product-blip-captions', 'dataset_config_name': None, 'train_data_dir': None, 'image_column': 'image', 'caption_column': 'text', 'validation_prompt': None, 'num_validation_images': 4, 'validation_epochs': 1, 'max_train_samples': None, 'output_dir': 'output', 'cache_dir': None, 'seed': None, 'resolution': 1024, 'center_crop': False, 'random_flip': False, 'train_text_encoder': False, 'train_batch_size': 16, 'num_train_epochs': 200, 'max_train_steps': 186400, 'checkpointing_steps': 500, 'checkpoints_total_limit': None, 'resume_from_checkpoint': None, 'gradient_accumulation_steps': 1, 'gradient_checkpointing': False, 'learning_rate': 0.0001, 'scale_lr': False, 'lr_scheduler': 'constant', 'lr_warmup_steps': 500, 'snr_gamma': None, 'allow_tf32': False, 'dataloader_num_workers': 0, 'use_8bit_adam': True, 'adam_beta1': 0.9, 'adam_beta2': 0.999, 'adam_weight_decay': 0.01, 'adam_epsilon': 1e-08, 'max_grad_norm': 1.0, 'push_to_hub': False, 'hub_token': None, 'prediction_type': None, 'hub_model_id': None, 'logging_dir': 'logs', 'report_to': 'wandb', 'mixed_precision': None, 'local_rank': -1, 'enable_xformers_memory_efficient_attention': False, 'noise_offset': 0, 'rank': 4, 'debug_loss': False} 29 | 2024-04-29 14:55:32,175 WARNING MsgRouterThr:4560 [router.py:message_loop():77] message_loop has been closed 30 | -------------------------------------------------------------------------------- /api/sdxl_serve.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from diffusers import DiffusionPipeline 3 | import litserve as ls 4 | from typing import Dict, Any, List 5 | from PIL import Image 6 | from config_settings import settings 7 | from configs.tti_settings import tti_settings 8 | from scripts.api_utils import pil_to_b64_json, pil_to_s3_json 9 | 10 | DEVICE = 'cuda:1' 11 | 12 | class SDXLLoraAPI(ls.LitAPI): 13 | """ 14 | LitAPI implementation for serving SDXL (Stable Diffusion XL) model with LoRA. 15 | 16 | This class defines the API for the SDXL model with LoRA, including methods for 17 | setup, request decoding, batching, prediction, and response encoding. 18 | """ 19 | 20 | def setup(self, device: str) -> None: 21 | """ 22 | Set up the SDXL pipeline with LoRA and optimize it for inference. 23 | 24 | Args: 25 | device (str): The device to run the model on (e.g., 'cuda:1'). 
26 | """ 27 | self.device = device 28 | self.sdxl_pipeline = DiffusionPipeline.from_pretrained( 29 | tti_settings.MODEL_NAME, 30 | torch_dtype=torch.bfloat16 31 | ).to(self.device) 32 | self.sdxl_pipeline.load_lora_weights(tti_settings.ADAPTER_NAME) 33 | self.sdxl_pipeline.fuse_lora() 34 | self.sdxl_pipeline.unet.to(memory_format=torch.channels_last) 35 | if tti_settings.ENABLE_COMPILE: 36 | self.sdxl_pipeline.unet = torch.compile(self.sdxl_pipeline.unet, mode="max-autotune") 37 | self.sdxl_pipeline.vae.decode = torch.compile(self.sdxl_pipeline.vae.decode, mode="max-autotune") 38 | self.sdxl_pipeline.fuse_qkv_projections() 39 | 40 | def decode_request(self, request: Dict[str, Any]) -> Dict[str, Any]: 41 | """ 42 | Decode the incoming request and prepare inputs for the model. 43 | 44 | Args: 45 | request (Dict[str, Any]): The raw request data. 46 | 47 | Returns: 48 | Dict[str, Any]: The decoded request with processed inputs. 49 | """ 50 | return { 51 | "prompt": request["prompt"], 52 | "negative_prompt": request.get("negative_prompt", ""), 53 | "num_images": request.get("num_images", 1), 54 | "num_inference_steps": request.get("num_inference_steps", 50), 55 | "guidance_scale": request.get("guidance_scale", 7.5), 56 | "mode": request.get("mode", "s3_json") 57 | } 58 | 59 | def batch(self, inputs: List[Dict[str, Any]]) -> Dict[str, List[Any]]: 60 | """ 61 | Batch multiple inputs together for efficient processing. 62 | 63 | Args: 64 | inputs (List[Dict[str, Any]]): A list of individual inputs. 65 | 66 | Returns: 67 | Dict[str, List[Any]]: A dictionary of batched inputs. 68 | """ 69 | return { 70 | "prompt": [input["prompt"] for input in inputs], 71 | "negative_prompt": [input["negative_prompt"] for input in inputs], 72 | "num_images": [input["num_images"] for input in inputs], 73 | "num_inference_steps": [input["num_inference_steps"] for input in inputs], 74 | "guidance_scale": [input["guidance_scale"] for input in inputs], 75 | "mode": [input["mode"] for input in inputs] 76 | } 77 | 78 | def predict(self, inputs: Dict[str, List[Any]]) -> List[Dict[str, Any]]: 79 | """ 80 | Run predictions on the batched inputs. 81 | 82 | Args: 83 | inputs (Dict[str, List[Any]]): Batched inputs for the model. 84 | 85 | Returns: 86 | List[Dict[str, Any]]: A list of dictionaries containing generated images and their modes. 87 | """ 88 | total_images = sum(inputs["num_images"]) 89 | images = self.sdxl_pipeline( 90 | prompt=inputs["prompt"], 91 | negative_prompt=inputs["negative_prompt"], 92 | num_images_per_prompt=1, # Generate one image per prompt 93 | num_inference_steps=inputs["num_inference_steps"][0], # Use the first value 94 | guidance_scale=inputs["guidance_scale"][0], # Use the first value 95 | ).images 96 | 97 | # Repeat images based on num_images and pair with modes 98 | results = [] 99 | for img, num, mode in zip(images, inputs["num_images"], inputs["mode"]): 100 | results.extend([{"image": img, "mode": mode} for _ in range(num)]) 101 | 102 | return results[:total_images] 103 | 104 | def unbatch(self, outputs: List[Dict[str, Any]]) -> List[Dict[str, Any]]: 105 | """ 106 | Unbatch the outputs from the predict method. 107 | 108 | Args: 109 | outputs (List[Dict[str, Any]]): The batched outputs from predict. 110 | 111 | Returns: 112 | List[Dict[str, Any]]: The unbatched list of outputs. 113 | """ 114 | return outputs 115 | 116 | def encode_response(self, output: Dict[str, Any]) -> Dict[str, Any]: 117 | """ 118 | Encode the model output into a response payload. 
119 | 120 | Args: 121 | output (Dict[str, Any]): The generated image and its mode. 122 | 123 | Returns: 124 | Dict[str, Any]: The encoded response with either S3 URL or base64 encoded image. 125 | """ 126 | mode = output["mode"] 127 | image = output["image"] 128 | if mode == "s3_json": 129 | return pil_to_s3_json(image, "sdxl_image") 130 | elif mode == "b64_json": 131 | return pil_to_b64_json(image) 132 | else: 133 | raise ValueError("Invalid mode. Supported modes are 'b64_json' and 's3_json'.") 134 | 135 | if __name__ == "__main__": 136 | api = SDXLLoraAPI() 137 | #server = ls.LitServer( 138 | # api, 139 | # accelerator="auto", 140 | # max_batch_size=tti_settings.MAX_BATCH_SIZE, 141 | # batch_timeout=tti_settings.MAX_QUEUE_DELAY_MICROSECONDS / 1e6, 142 | # ) 143 | #server.run(port=8000) -------------------------------------------------------------------------------- /scripts/flux_inference.py: -------------------------------------------------------------------------------- 1 | import random 2 | from typing import Tuple 3 | from functools import lru_cache 4 | import numpy as np 5 | import torch 6 | from PIL import Image 7 | from diffusers import FluxInpaintPipeline, FluxTransformer2DModel 8 | from torchao.quantization.quant_api import quantize_, int8_weight_only 9 | 10 | class FluxInpaintingInference: 11 | """ 12 | A class to perform image inpainting using the FLUX model with int8 quantization for efficient inference. 13 | 14 | Attributes: 15 | MAX_SEED (int): The maximum value for a random seed. 16 | DEVICE (str): The device to run the model on ('cuda' or 'cpu'). 17 | IMAGE_SIZE (int): The maximum size for the input image dimensions. 18 | """ 19 | 20 | MAX_SEED = np.iinfo(np.int32).max 21 | IMAGE_SIZE = 1024 22 | DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu' 23 | 24 | _pipeline = None 25 | 26 | @classmethod 27 | @lru_cache(maxsize=1) 28 | def get_pipeline(cls, model_name: str, torch_dtype): 29 | """ 30 | Loads and caches the FluxInpaintPipeline with int8 quantization. 31 | 32 | Args: 33 | model_name (str): The name of the model to be loaded from Hugging Face Hub. 34 | torch_dtype: The data type to be used by PyTorch. 35 | 36 | Returns: 37 | FluxInpaintPipeline: The loaded and optimized pipeline. 38 | """ 39 | if cls._pipeline is None: 40 | # Load the transformer with int8 quantization 41 | transformer = FluxTransformer2DModel.from_pretrained( 42 | model_name, 43 | subfolder="transformer", 44 | torch_dtype=torch_dtype 45 | ) 46 | quantize_(transformer, int8_weight_only()) 47 | 48 | # Load the rest of the pipeline 49 | cls._pipeline = FluxInpaintPipeline.from_pretrained( 50 | model_name, 51 | transformer=transformer, 52 | torch_dtype=torch_dtype 53 | ) 54 | 55 | # Additional optimizations 56 | cls._pipeline.vae.enable_slicing() 57 | cls._pipeline.vae.enable_tiling() 58 | cls._pipeline.to(cls.DEVICE) 59 | 60 | return cls._pipeline 61 | 62 | def __init__( 63 | self, 64 | model_name: str = "black-forest-labs/FLUX.1-schnell", 65 | torch_dtype=torch.bfloat16, 66 | ): 67 | """ 68 | Initializes the FluxInpaintingInference class with a specified model and optimizations. 69 | 70 | Args: 71 | model_name (str): The name of the model to be loaded from Hugging Face Hub. 72 | torch_dtype: The data type to be used by PyTorch (e.g., torch.bfloat16). 
73 | """ 74 | self.pipeline = self.get_pipeline(model_name, torch_dtype) 75 | 76 | @staticmethod 77 | def calculate_new_dimensions( 78 | original_dimensions: Tuple[int, int], max_dimension: int = 1024 79 | ) -> Tuple[int, int]: 80 | """ 81 | Calculates new image dimensions while maintaining aspect ratio and ensuring divisibility by 32. 82 | 83 | Args: 84 | original_dimensions (Tuple[int, int]): The original width and height of the image. 85 | max_dimension (int): The maximum dimension size. 86 | 87 | Returns: 88 | Tuple[int, int]: The new width and height. 89 | """ 90 | width, height = original_dimensions 91 | 92 | # Calculate scaling factor 93 | scaling_factor = min(max_dimension / width, max_dimension / height, 1.0) 94 | 95 | # Calculate new dimensions and make them divisible by 32 96 | new_width = int((width * scaling_factor) // 32 * 32) 97 | new_height = int((height * scaling_factor) // 32 * 32) 98 | 99 | # Ensure minimum size of 32x32 100 | new_width = max(32, new_width) 101 | new_height = max(32, new_height) 102 | 103 | return new_width, new_height 104 | 105 | def generate_inpainting( 106 | self, 107 | input_image: Image.Image, 108 | mask_image: Image.Image, 109 | prompt: str, 110 | seed: int = None, 111 | randomize_seed: bool = False, 112 | strength: float = 0.8, 113 | num_inference_steps: int = 50, 114 | guidance_scale: float = 0.0, 115 | max_sequence_length: int = 256, 116 | ) -> Image.Image: 117 | """ 118 | Generates an inpainted image based on the provided inputs. 119 | 120 | Args: 121 | input_image (Image.Image): The original image to be inpainted. 122 | mask_image (Image.Image): The mask indicating areas to be inpainted (white areas are inpainted). 123 | prompt (str): Text prompt guiding the inpainting. 124 | seed (int, optional): Seed for random number generation. Defaults to None. 125 | randomize_seed (bool, optional): Whether to randomize the seed. Defaults to False. 126 | strength (float, optional): Strength of the inpainting effect (0.0 to 1.0). Defaults to 0.8. 127 | num_inference_steps (int, optional): Number of denoising steps. Defaults to 50. 128 | guidance_scale (float, optional): Scale for classifier-free guidance. Defaults to 0.0. 129 | max_sequence_length (int, optional): Maximum sequence length for the transformer. Defaults to 256. 130 | 131 | Returns: 132 | Image.Image: The resulting inpainted image. 
133 | """ 134 | if randomize_seed or seed is None: 135 | seed = random.randint(0, self.MAX_SEED) 136 | 137 | generator = torch.Generator(device=self.DEVICE).manual_seed(seed) 138 | 139 | # Resize images 140 | new_width, new_height = self.calculate_new_dimensions(input_image.size) 141 | input_image = input_image.resize((new_width, new_height), Image.LANCZOS) 142 | mask_image = mask_image.resize((new_width, new_height), Image.LANCZOS) 143 | 144 | # Run inference 145 | result = self.pipeline( 146 | prompt=prompt, 147 | image=input_image, 148 | mask_image=mask_image, 149 | strength=strength, 150 | num_inference_steps=num_inference_steps, 151 | generator=generator, 152 | guidance_scale=guidance_scale, 153 | max_sequence_length=max_sequence_length, 154 | ).images[0] 155 | 156 | return result -------------------------------------------------------------------------------- /iac/terraform.tfstate: -------------------------------------------------------------------------------- 1 | { 2 | "version": 4, 3 | "terraform_version": "1.8.2", 4 | "serial": 19, 5 | "lineage": "7f24f129-9566-aba8-5a57-55c73ab4a868", 6 | "outputs": {}, 7 | "resources": [ 8 | { 9 | "mode": "managed", 10 | "type": "aws_s3_bucket", 11 | "name": "diffusion_model_bucket", 12 | "provider": "provider[\"registry.terraform.io/hashicorp/aws\"]", 13 | "instances": [ 14 | { 15 | "schema_version": 0, 16 | "attributes": { 17 | "acceleration_status": "", 18 | "acl": null, 19 | "arn": "arn:aws:s3:::diffusion-model-bucket", 20 | "bucket": "diffusion-model-bucket", 21 | "bucket_domain_name": "diffusion-model-bucket.s3.amazonaws.com", 22 | "bucket_prefix": "", 23 | "bucket_regional_domain_name": "diffusion-model-bucket.s3.ap-south-1.amazonaws.com", 24 | "cors_rule": [], 25 | "force_destroy": false, 26 | "grant": [ 27 | { 28 | "id": "ad7ce3402e9e4f4834521dfae4941257bbcec051b293bd2322daed9a2632f03a", 29 | "permissions": [ 30 | "FULL_CONTROL" 31 | ], 32 | "type": "CanonicalUser", 33 | "uri": "" 34 | } 35 | ], 36 | "hosted_zone_id": "Z11RGJOFQNVJUP", 37 | "id": "diffusion-model-bucket", 38 | "lifecycle_rule": [], 39 | "logging": [], 40 | "object_lock_configuration": [], 41 | "object_lock_enabled": false, 42 | "policy": "", 43 | "region": "ap-south-1", 44 | "replication_configuration": [], 45 | "request_payer": "BucketOwner", 46 | "server_side_encryption_configuration": [ 47 | { 48 | "rule": [ 49 | { 50 | "apply_server_side_encryption_by_default": [ 51 | { 52 | "kms_master_key_id": "", 53 | "sse_algorithm": "AES256" 54 | } 55 | ], 56 | "bucket_key_enabled": false 57 | } 58 | ] 59 | } 60 | ], 61 | "tags": { 62 | "Name": "Diffusion Model Bucket", 63 | "Product": "Product Diffusion API", 64 | "Task": "SDXL LORA" 65 | }, 66 | "tags_all": { 67 | "Name": "Diffusion Model Bucket", 68 | "Product": "Product Diffusion API", 69 | "Task": "SDXL LORA" 70 | }, 71 | "timeouts": null, 72 | "versioning": [ 73 | { 74 | "enabled": false, 75 | "mfa_delete": false 76 | } 77 | ], 78 | "website": [], 79 | "website_domain": null, 80 | "website_endpoint": null 81 | }, 82 | "sensitive_attributes": [], 83 | "private": "eyJlMmJmYjczMC1lY2FhLTExZTYtOGY4OC0zNDM2M2JjN2M0YzAiOnsiY3JlYXRlIjoxMjAwMDAwMDAwMDAwLCJkZWxldGUiOjM2MDAwMDAwMDAwMDAsInJlYWQiOjEyMDAwMDAwMDAwMDAsInVwZGF0ZSI6MTIwMDAwMDAwMDAwMH19" 84 | } 85 | ] 86 | }, 87 | { 88 | "mode": "managed", 89 | "type": "aws_s3_bucket_acl", 90 | "name": "acl_access", 91 | "provider": "provider[\"registry.terraform.io/hashicorp/aws\"]", 92 | "instances": [ 93 | { 94 | "schema_version": 0, 95 | "attributes": { 96 | "access_control_policy": 
[ 97 | { 98 | "grant": [ 99 | { 100 | "grantee": [ 101 | { 102 | "display_name": "", 103 | "email_address": "", 104 | "id": "ad7ce3402e9e4f4834521dfae4941257bbcec051b293bd2322daed9a2632f03a", 105 | "type": "CanonicalUser", 106 | "uri": "" 107 | } 108 | ], 109 | "permission": "FULL_CONTROL" 110 | } 111 | ], 112 | "owner": [ 113 | { 114 | "display_name": "", 115 | "id": "ad7ce3402e9e4f4834521dfae4941257bbcec051b293bd2322daed9a2632f03a" 116 | } 117 | ] 118 | } 119 | ], 120 | "acl": "public-read", 121 | "bucket": "diffusion-model-bucket", 122 | "expected_bucket_owner": "", 123 | "id": "diffusion-model-bucket,public-read" 124 | }, 125 | "sensitive_attributes": [], 126 | "private": "bnVsbA==", 127 | "dependencies": [ 128 | "aws_s3_bucket.diffusion_model_bucket", 129 | "aws_s3_bucket_ownership_controls.s3_bucket_acl_ownership", 130 | "aws_s3_bucket_public_access_block.s3_bucket_public_access_block" 131 | ] 132 | } 133 | ] 134 | }, 135 | { 136 | "mode": "managed", 137 | "type": "aws_s3_bucket_ownership_controls", 138 | "name": "s3_bucket_acl_ownership", 139 | "provider": "provider[\"registry.terraform.io/hashicorp/aws\"]", 140 | "instances": [ 141 | { 142 | "schema_version": 0, 143 | "attributes": { 144 | "bucket": "diffusion-model-bucket", 145 | "id": "diffusion-model-bucket", 146 | "rule": [ 147 | { 148 | "object_ownership": "ObjectWriter" 149 | } 150 | ] 151 | }, 152 | "sensitive_attributes": [], 153 | "private": "bnVsbA==", 154 | "dependencies": [ 155 | "aws_s3_bucket.diffusion_model_bucket" 156 | ] 157 | } 158 | ] 159 | }, 160 | { 161 | "mode": "managed", 162 | "type": "aws_s3_bucket_public_access_block", 163 | "name": "s3_bucket_public_access_block", 164 | "provider": "provider[\"registry.terraform.io/hashicorp/aws\"]", 165 | "instances": [ 166 | { 167 | "schema_version": 0, 168 | "attributes": { 169 | "block_public_acls": false, 170 | "block_public_policy": false, 171 | "bucket": "diffusion-model-bucket", 172 | "id": "diffusion-model-bucket", 173 | "ignore_public_acls": true, 174 | "restrict_public_buckets": true 175 | }, 176 | "sensitive_attributes": [], 177 | "private": "bnVsbA==", 178 | "dependencies": [ 179 | "aws_s3_bucket.diffusion_model_bucket" 180 | ] 181 | } 182 | ] 183 | } 184 | ], 185 | "check_results": null 186 | } 187 | -------------------------------------------------------------------------------- /serverless/text-to-image/run_text-to-image.py: -------------------------------------------------------------------------------- 1 | import runpod 2 | import torch 3 | import asyncio 4 | import logging 5 | from typing import Dict, Any, List, AsyncGenerator 6 | from PIL import Image 7 | from diffusers import DiffusionPipeline 8 | from config_settings import settings 9 | from configs.tti_settings import tti_settings 10 | from scripts.api_utils import pil_to_b64_json, pil_to_s3_json 11 | 12 | # Set up logging 13 | logging.basicConfig(level=logging.INFO) 14 | logger = logging.getLogger(__name__) 15 | 16 | # Global pipeline instance 17 | global_pipeline = None 18 | device = "cuda" if torch.cuda.is_available() else "cpu" 19 | 20 | async def initialize_pipeline(): 21 | """ 22 | Initialize and optimize the SDXL pipeline with LoRA. 
23 | """ 24 | global global_pipeline 25 | 26 | if global_pipeline is None: 27 | logger.info("Initializing SDXL pipeline...") 28 | 29 | # Run model loading in thread pool 30 | global_pipeline = await asyncio.to_thread( 31 | DiffusionPipeline.from_pretrained, 32 | tti_settings.MODEL_NAME, 33 | torch_dtype=torch.bfloat16 34 | ) 35 | global_pipeline.to(device) 36 | 37 | logger.info("Loading LoRA weights...") 38 | await asyncio.to_thread(global_pipeline.load_lora_weights, tti_settings.ADAPTER_NAME) 39 | await asyncio.to_thread(global_pipeline.fuse_lora) 40 | 41 | logger.info("Optimizing pipeline...") 42 | global_pipeline.unet.to(memory_format=torch.channels_last) 43 | if tti_settings.ENABLE_COMPILE: 44 | global_pipeline.unet = await asyncio.to_thread( 45 | torch.compile, 46 | global_pipeline.unet, 47 | mode="max-autotune" 48 | ) 49 | global_pipeline.vae.decode = await asyncio.to_thread( 50 | torch.compile, 51 | global_pipeline.vae.decode, 52 | mode="max-autotune" 53 | ) 54 | await asyncio.to_thread(global_pipeline.fuse_qkv_projections) 55 | logger.info("Pipeline initialization complete") 56 | 57 | return global_pipeline 58 | 59 | def decode_request(request: Dict[str, Any]) -> Dict[str, Any]: 60 | """Decode and validate the incoming request.""" 61 | return { 62 | "prompt": request["prompt"], 63 | "negative_prompt": request.get("negative_prompt", ""), 64 | "num_images": request.get("num_images", 1), 65 | "num_inference_steps": request.get("num_inference_steps", 50), 66 | "guidance_scale": request.get("guidance_scale", 7.5), 67 | "mode": request.get("mode", "s3_json") 68 | } 69 | 70 | async def generate_images(params: Dict[str, Any], pipeline: DiffusionPipeline) -> List[Dict[str, Any]]: 71 | """Generate images using the SDXL pipeline asynchronously.""" 72 | images = await asyncio.to_thread( 73 | pipeline, 74 | prompt=params["prompt"], 75 | negative_prompt=params["negative_prompt"], 76 | num_images_per_prompt=params["num_images"], 77 | num_inference_steps=params["num_inference_steps"], 78 | guidance_scale=params["guidance_scale"], 79 | ) 80 | 81 | return [{"image": img, "mode": params["mode"]} for img in images.images] 82 | 83 | async def encode_response(output: Dict[str, Any]) -> Dict[str, Any]: 84 | """Encode the generated image asynchronously.""" 85 | mode = output["mode"] 86 | image = output["image"] 87 | 88 | if mode == "s3_json": 89 | return await asyncio.to_thread(pil_to_s3_json, image, "sdxl_image") 90 | elif mode == "b64_json": 91 | return await asyncio.to_thread(pil_to_b64_json, image) 92 | else: 93 | raise ValueError("Invalid mode. Supported modes are 'b64_json' and 's3_json'.") 94 | 95 | async def async_generator_handler(job: Dict[str, Any]) -> AsyncGenerator[Dict[str, Any], None]: 96 | """ 97 | Async generator handler for RunPod with progress updates. 
98 | """ 99 | try: 100 | # Initial status 101 | yield {"status": "starting", "message": "Initializing image generation process"} 102 | 103 | # Initialize pipeline 104 | pipeline = await initialize_pipeline() 105 | yield {"status": "processing", "message": "Pipeline loaded successfully"} 106 | 107 | # Decode request 108 | try: 109 | params = decode_request(job['input']) 110 | yield { 111 | "status": "processing", 112 | "message": "Request decoded successfully", 113 | "params": { 114 | "prompt": params["prompt"], 115 | "num_images": params["num_images"], 116 | "steps": params["num_inference_steps"] 117 | } 118 | } 119 | except Exception as e: 120 | logger.error(f"Request decode error: {e}") 121 | yield {"status": "error", "message": f"Error decoding request: {str(e)}"} 122 | return 123 | 124 | # Generate images 125 | try: 126 | yield {"status": "processing", "message": "Generating images"} 127 | outputs = await generate_images(params, pipeline) 128 | yield {"status": "processing", "message": f"Generated {len(outputs)} images successfully"} 129 | except Exception as e: 130 | logger.error(f"Generation error: {e}") 131 | yield {"status": "error", "message": f"Error generating images: {str(e)}"} 132 | return 133 | 134 | # Encode responses 135 | try: 136 | yield {"status": "processing", "message": "Encoding and uploading images"} 137 | results = [] 138 | for idx, output in enumerate(outputs, 1): 139 | result = await encode_response(output) 140 | results.append(result) 141 | yield { 142 | "status": "processing", 143 | "message": f"Processed image {idx}/{len(outputs)}" 144 | } 145 | except Exception as e: 146 | logger.error(f"Encoding error: {e}") 147 | yield {"status": "error", "message": f"Error encoding images: {str(e)}"} 148 | return 149 | 150 | # Final response 151 | final_response = results[0] if len(results) == 1 else {"results": results} 152 | yield { 153 | "status": "completed", 154 | "output": final_response 155 | } 156 | 157 | except Exception as e: 158 | logger.error(f"Unexpected error: {e}") 159 | yield { 160 | "status": "error", 161 | "message": f"Unexpected error: {str(e)}" 162 | } 163 | 164 | # Initialize pipeline at startup 165 | logger.info("Initializing service...") 166 | asyncio.get_event_loop().run_until_complete(initialize_pipeline()) 167 | logger.info("Service initialization complete") 168 | 169 | if __name__ == "__main__": 170 | runpod.serverless.start({ 171 | "handler": async_generator_handler, 172 | "return_aggregate_stream": True 173 | }) -------------------------------------------------------------------------------- /serverless/outpainting/run_outpainting.py: -------------------------------------------------------------------------------- 1 | import io 2 | import base64 3 | import time 4 | import asyncio 5 | from typing import Dict, Any, Tuple, AsyncGenerator 6 | from PIL import Image 7 | from pydantic import BaseModel, Field 8 | from scripts.outpainting import Outpainter 9 | from scripts.api_utils import pil_to_s3_json 10 | 11 | class OutpaintingRequest(BaseModel): 12 | """ 13 | Pydantic model representing a request for outpainting inference. 14 | 15 | This model defines the structure and validation rules for incoming API requests. 16 | All fields are required unless otherwise specified. 
17 | """ 18 | image: str = Field(..., description="Base64 encoded input image") 19 | width: int = Field(1024, description="Target width") 20 | height: int = Field(1024, description="Target height") 21 | overlap_percentage: int = Field(10, description="Mask overlap percentage") 22 | num_inference_steps: int = Field(8, description="Number of inference steps") 23 | resize_option: str = Field("Full", description="Resize option") 24 | custom_resize_percentage: int = Field(100, description="Custom resize percentage") 25 | prompt_input: str = Field("", description="Prompt for generation") 26 | alignment: str = Field("Middle", description="Image alignment") 27 | overlap_left: bool = Field(True, description="Apply overlap on left side") 28 | overlap_right: bool = Field(True, description="Apply overlap on right side") 29 | overlap_top: bool = Field(True, description="Apply overlap on top side") 30 | overlap_bottom: bool = Field(True, description="Apply overlap on bottom side") 31 | 32 | class OutpaintingService: 33 | """ 34 | Service class for handling outpainting operations. 35 | Based on LitAPI implementation but adapted for RunPod. 36 | """ 37 | 38 | def __init__(self, device: str = "cuda"): 39 | """Initialize the outpainting service.""" 40 | self.device = device 41 | self.outpainter = Outpainter() 42 | 43 | async def decode_request(self, request: Dict[str, Any]) -> Dict[str, Any]: 44 | """ 45 | Decode the incoming request and prepare inputs for the model. 46 | 47 | Args: 48 | request: The raw request data. 49 | 50 | Returns: 51 | Dict containing decoded image and request parameters. 52 | 53 | Raises: 54 | ValueError: If request is invalid or cannot be processed. 55 | """ 56 | try: 57 | outpainting_request = OutpaintingRequest(**request) 58 | # Run decode in thread pool 59 | image_data = await asyncio.to_thread( 60 | base64.b64decode, outpainting_request.image 61 | ) 62 | image = await asyncio.to_thread( 63 | lambda: Image.open(io.BytesIO(image_data)).convert("RGBA") 64 | ) 65 | 66 | return { 67 | 'image': image, 68 | 'params': outpainting_request.model_dump() 69 | } 70 | except Exception as e: 71 | raise ValueError(f"Invalid request: {str(e)}") 72 | 73 | async def predict(self, inputs: Dict[str, Any]) -> Tuple[Image.Image, float]: 74 | """ 75 | Run predictions on the input. 76 | 77 | Args: 78 | inputs: Dict containing image and outpainting parameters. 79 | 80 | Returns: 81 | Tuple containing the resulting image and completion time. 82 | """ 83 | image = inputs['image'] 84 | params = inputs['params'] 85 | 86 | start_time = time.time() 87 | 88 | # Run outpainting in thread pool 89 | result = await asyncio.to_thread( 90 | self.outpainter.outpaint, 91 | image, 92 | params['width'], 93 | params['height'], 94 | params['overlap_percentage'], 95 | params['num_inference_steps'], 96 | params['resize_option'], 97 | params['custom_resize_percentage'], 98 | params['prompt_input'], 99 | params['alignment'], 100 | params['overlap_left'], 101 | params['overlap_right'], 102 | params['overlap_top'], 103 | params['overlap_bottom'] 104 | ) 105 | 106 | completion_time = time.time() - start_time 107 | return result, completion_time 108 | 109 | async def encode_response(self, output: Tuple[Image.Image, float]) -> Dict[str, Any]: 110 | """ 111 | Encode the model output into a response payload. 112 | 113 | Args: 114 | output: Tuple containing outpainted image and completion time. 115 | 116 | Returns: 117 | Dict containing S3 URL and metadata. 
118 | """ 119 | image, completion_time = output 120 | # Run S3 upload in thread pool 121 | img_str = await asyncio.to_thread(pil_to_s3_json, image, "outpainting_image") 122 | 123 | return { 124 | "result": img_str, 125 | "completion_time": round(completion_time, 2), 126 | "image_resolution": f"{image.width}x{image.height}" 127 | } 128 | 129 | async def async_generator_handler(job: Dict[str, Any]) -> AsyncGenerator[Dict[str, Any], None]: 130 | """ 131 | Async generator handler for RunPod with progress updates. 132 | """ 133 | try: 134 | # Create service instance 135 | service = OutpaintingService(device="cuda") 136 | yield {"status": "starting", "message": "Service initialized"} 137 | 138 | # Decode request 139 | try: 140 | inputs = await service.decode_request(job['input']) 141 | yield { 142 | "status": "processing", 143 | "message": "Request decoded successfully", 144 | "input_resolution": f"{inputs['image'].width}x{inputs['image'].height}" 145 | } 146 | except Exception as e: 147 | yield {"status": "error", "message": f"Error decoding request: {str(e)}"} 148 | return 149 | 150 | # Generate prediction 151 | try: 152 | yield {"status": "processing", "message": "Starting outpainting"} 153 | result = await service.predict(inputs) 154 | yield { 155 | "status": "processing", 156 | "message": "Outpainting completed", 157 | "completion_time": f"{result[1]:.2f}s" 158 | } 159 | except Exception as e: 160 | yield {"status": "error", "message": f"Error during outpainting: {str(e)}"} 161 | return 162 | 163 | # Encode response 164 | try: 165 | yield {"status": "processing", "message": "Encoding result"} 166 | response = await service.encode_response(result) 167 | yield {"status": "processing", "message": "Result encoded successfully"} 168 | except Exception as e: 169 | yield {"status": "error", "message": f"Error encoding result: {str(e)}"} 170 | return 171 | 172 | # Final response 173 | yield { 174 | "status": "completed", 175 | "output": response 176 | } 177 | 178 | except Exception as e: 179 | yield { 180 | "status": "error", 181 | "message": f"Unexpected error: {str(e)}" 182 | } 183 | 184 | if __name__ == "__main__": 185 | import runpod 186 | runpod.serverless.start({ 187 | "handler": async_generator_handler, 188 | "return_aggregate_stream": True 189 | 190 | }) -------------------------------------------------------------------------------- /serverless/image-to-video/run_image-to-video.py: -------------------------------------------------------------------------------- 1 | import io 2 | import os 3 | import runpod 4 | import tempfile 5 | import time 6 | import asyncio 7 | from typing import Dict, Any, List, Union, Tuple, AsyncGenerator 8 | from PIL import Image 9 | import base64 10 | from pydantic import BaseModel, Field 11 | from diffusers.utils import export_to_video 12 | from scripts.api_utils import mp4_to_s3_json 13 | from scripts.image_to_video import ImageToVideoPipeline 14 | 15 | # Global pipeline instance 16 | global_pipeline = None 17 | 18 | class ImageToVideoRequest(BaseModel): 19 | """ 20 | Pydantic model representing a request for image-to-video generation. 
21 | """ 22 | image: str = Field(..., description="Base64 encoded input image") 23 | prompt: str = Field(..., description="Text prompt for video generation") 24 | num_frames: int = Field(49, description="Number of frames to generate") 25 | num_inference_steps: int = Field(50, description="Number of inference steps") 26 | guidance_scale: float = Field(6.0, description="Guidance scale") 27 | height: int = Field(480, description="Height of the output video") 28 | width: int = Field(720, description="Width of the output video") 29 | use_dynamic_cfg: bool = Field(True, description="Use dynamic CFG") 30 | fps: int = Field(30, description="Frames per second for the output video") 31 | 32 | async def initialize_pipeline(): 33 | """Initialize the pipeline if not already loaded""" 34 | global global_pipeline 35 | if global_pipeline is None: 36 | print("Initializing Image to Video pipeline...") 37 | global_pipeline = ImageToVideoPipeline(device="cuda") 38 | print("Pipeline initialized successfully") 39 | return global_pipeline 40 | 41 | async def decode_request(request: Dict[str, Any]) -> Dict[str, Any]: 42 | """ 43 | Decode and validate the incoming video generation request asynchronously. 44 | """ 45 | try: 46 | video_request = ImageToVideoRequest(**request) 47 | # Run decode in thread pool 48 | image_data = await asyncio.to_thread(base64.b64decode, video_request.image) 49 | image = await asyncio.to_thread( 50 | lambda: Image.open(io.BytesIO(image_data)).convert("RGB") 51 | ) 52 | 53 | return { 54 | 'image': image, 55 | 'params': video_request.model_dump() 56 | } 57 | except Exception as e: 58 | raise ValueError(f"Invalid request: {str(e)}") 59 | 60 | async def generate_frames(inputs: Dict[str, Any], pipeline: ImageToVideoPipeline) -> Tuple[List[Image.Image], float]: 61 | """ 62 | Generate video frames using the pipeline asynchronously. 63 | """ 64 | start_time = time.time() 65 | 66 | # Run generation in thread pool 67 | frames = await asyncio.to_thread( 68 | pipeline.generate, 69 | prompt=inputs['params']['prompt'], 70 | image=inputs['image'], 71 | num_frames=inputs['params']['num_frames'], 72 | num_inference_steps=inputs['params']['num_inference_steps'], 73 | guidance_scale=inputs['params']['guidance_scale'], 74 | height=inputs['params']['height'], 75 | width=inputs['params']['width'], 76 | use_dynamic_cfg=inputs['params']['use_dynamic_cfg'] 77 | ) 78 | 79 | if isinstance(frames, tuple): 80 | frames = frames[0] 81 | elif hasattr(frames, 'frames'): 82 | frames = frames.frames[0] 83 | 84 | completion_time = time.time() - start_time 85 | return frames, completion_time 86 | 87 | async def create_video_response(frames: List[Image.Image], completion_time: float, fps: int) -> Dict[str, Any]: 88 | """ 89 | Create video file and generate response with S3 URL asynchronously. 
90 | """ 91 | def create_video(): 92 | with tempfile.TemporaryDirectory() as temp_dir: 93 | temp_video_path = os.path.join(temp_dir, "generated_video.mp4") 94 | export_to_video(frames, temp_video_path, fps=fps) 95 | 96 | with open(temp_video_path, "rb") as video_file: 97 | return mp4_to_s3_json( 98 | video_file, 99 | f"generated_video_{int(time.time())}.mp4" 100 | ) 101 | 102 | # Run video creation and upload in thread pool 103 | s3_response = await asyncio.to_thread(create_video) 104 | 105 | return { 106 | "result": s3_response, 107 | "completion_time": round(completion_time, 2), 108 | "video_resolution": f"{frames[0].width}x{frames[0].height}", 109 | "fps": fps 110 | } 111 | 112 | async def async_generator_handler(job: Dict[str, Any]) -> AsyncGenerator[Dict[str, Any], None]: 113 | """ 114 | Async generator handler for RunPod with progress updates. 115 | """ 116 | try: 117 | # Initial status 118 | yield {"status": "starting", "message": "Initializing video generation process"} 119 | 120 | # Initialize pipeline 121 | pipeline = await initialize_pipeline() 122 | yield {"status": "processing", "message": "Pipeline loaded successfully"} 123 | 124 | # Decode request 125 | try: 126 | inputs = await decode_request(job['input']) 127 | yield {"status": "processing", "message": "Request decoded successfully"} 128 | except Exception as e: 129 | yield {"status": "error", "message": f"Error decoding request: {str(e)}"} 130 | return 131 | 132 | # Generate frames with progress updates 133 | try: 134 | yield {"status": "processing", "message": "Generating video frames"} 135 | frames, completion_time = await generate_frames(inputs, pipeline) 136 | yield {"status": "processing", "message": f"Generated {len(frames)} frames successfully"} 137 | except Exception as e: 138 | yield {"status": "error", "message": f"Error generating frames: {str(e)}"} 139 | return 140 | 141 | # Create and upload video 142 | try: 143 | yield {"status": "processing", "message": "Creating and uploading video"} 144 | response = await create_video_response( 145 | frames, 146 | completion_time, 147 | inputs['params']['fps'] 148 | ) 149 | yield {"status": "processing", "message": "Video uploaded successfully"} 150 | except Exception as e: 151 | yield {"status": "error", "message": f"Error creating video: {str(e)}"} 152 | return 153 | 154 | # Final response 155 | yield { 156 | "status": "completed", 157 | "output": response 158 | } 159 | 160 | except Exception as e: 161 | yield { 162 | "status": "error", 163 | "message": f"Unexpected error: {str(e)}" 164 | } 165 | 166 | def calculate_progress(current_frame: int, total_frames: int) -> dict: 167 | """Calculate progress percentage and create status update.""" 168 | progress = (current_frame / total_frames) * 100 169 | return { 170 | "status": "processing", 171 | "progress": round(progress, 2), 172 | "message": f"Generating frame {current_frame}/{total_frames}" 173 | } 174 | 175 | # Initialize the pipeline when the service starts 176 | print("Initializing service...") 177 | asyncio.get_event_loop().run_until_complete(initialize_pipeline()) 178 | print("Service initialization complete") 179 | 180 | if __name__ == "__main__": 181 | runpod.serverless.start({ 182 | "handler": async_generator_handler, 183 | "return_aggregate_stream": True 184 | }) -------------------------------------------------------------------------------- /scripts/api_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | This module provides utilities for device management, 
memory handling, and file operations. 3 | It includes functions for accelerator selection, memory clearing, and image/video processing. 4 | """ 5 | 6 | import os 7 | import platform 8 | import subprocess 9 | import sys 10 | from functools import lru_cache 11 | from typing import List, Optional, Union 12 | import gc 13 | import io 14 | from io import BytesIO 15 | import base64 16 | import uuid 17 | import torch 18 | from PIL import Image 19 | from scripts.s3_manager import S3ManagerService 20 | 21 | # Device Management 22 | class DeviceManager: 23 | """ 24 | Manages device selection for accelerated computing. 25 | 26 | This class handles the selection of appropriate computing devices (CPU, CUDA, MPS) 27 | based on system availability and user preferences. 28 | 29 | Attributes: 30 | _accelerator (str): The type of accelerator being used. 31 | _devices (Union[List[int], int]): The devices available for use. 32 | """ 33 | 34 | def __init__(self, accelerator: str = "auto", devices: Union[List[int], int, str] = "auto"): 35 | """ 36 | Initialize the DeviceManager. 37 | 38 | Args: 39 | accelerator (str): The type of accelerator to use. Defaults to "auto". 40 | devices (Union[List[int], int, str]): The devices to use. Defaults to "auto". 41 | """ 42 | self._accelerator = self._sanitize_accelerator(accelerator) 43 | self._devices = self._setup_devices(devices) 44 | 45 | @property 46 | def accelerator(self): 47 | """Get the current accelerator type.""" 48 | return self._accelerator 49 | 50 | @property 51 | def devices(self): 52 | """Get the current devices in use.""" 53 | return self._devices 54 | 55 | @staticmethod 56 | def _sanitize_accelerator(accelerator: Optional[str]): 57 | """Sanitize the accelerator input.""" 58 | if isinstance(accelerator, str): 59 | accelerator = accelerator.lower() 60 | if accelerator not in ["auto", "cpu", "mps", "cuda", "gpu", None]: 61 | raise ValueError("accelerator must be one of 'auto', 'cpu', 'mps', 'cuda', or 'gpu'") 62 | return "auto" if accelerator is None else accelerator 63 | 64 | def _setup_devices(self, devices: Union[List[int], int, str]): 65 | """Set up the devices based on input and availability.""" 66 | if devices == "auto": 67 | return self._auto_device_count() 68 | elif isinstance(devices, int): 69 | return min(devices, self._auto_device_count()) 70 | elif isinstance(devices, list): 71 | return [dev for dev in devices if dev < self._auto_device_count()] 72 | else: 73 | raise ValueError("devices must be 'auto', an integer, or a list of integers") 74 | 75 | def _auto_device_count(self) -> int: 76 | """Automatically determine the number of available devices.""" 77 | if self._accelerator == "cuda": 78 | return check_cuda_with_nvidia_smi() 79 | elif self._accelerator == "mps": 80 | return 1 81 | elif self._accelerator == "cpu": 82 | return os.cpu_count() or 1 83 | else: 84 | return 1 85 | 86 | def _choose_auto_accelerator(self): 87 | """Choose the best available accelerator automatically.""" 88 | gpu_backend = self._choose_gpu_accelerator_backend() 89 | return gpu_backend if gpu_backend else "cpu" 90 | 91 | @staticmethod 92 | def _choose_gpu_accelerator_backend(): 93 | """Choose the appropriate GPU backend if available.""" 94 | if check_cuda_with_nvidia_smi() > 0: 95 | return "cuda" 96 | if torch.backends.mps.is_available() and platform.processor() in ("arm", "arm64"): 97 | return "mps" 98 | return None 99 | 100 | @lru_cache(maxsize=1) 101 | def check_cuda_with_nvidia_smi() -> int: 102 | """ 103 | Check CUDA availability using nvidia-smi. 
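    The "nvidia-smi -L" command typically prints one line per device (for example, "GPU 0: NVIDIA RTX A5000 (UUID: GPU-...)"); this function counts those lines and, when CUDA_VISIBLE_DEVICES is set, keeps only the listed indices.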
104 | 105 | Returns: 106 | int: The number of available CUDA devices. 107 | """ 108 | try: 109 | nvidia_smi_output = subprocess.check_output(["nvidia-smi", "-L"]).decode("utf-8").strip() 110 | devices = [el for el in nvidia_smi_output.split("\n") if el.startswith("GPU")] 111 | devices = [el.split(":")[0].split()[1] for el in devices] 112 | visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES") 113 | if visible_devices: 114 | devices = [el for el in devices if el in visible_devices.split(",")] 115 | return len(devices) 116 | except (subprocess.CalledProcessError, FileNotFoundError): 117 | return 0 118 | 119 | def accelerator(devices: Union[List[int], int, str] = "auto") -> tuple: 120 | """ 121 | Determine the device accelerator to use based on availability. 122 | 123 | Args: 124 | devices (Union[List[int], int, str]): Specifies the devices to use. 125 | 126 | Returns: 127 | tuple: A tuple containing the accelerator type and available devices. 128 | """ 129 | device_manager = DeviceManager(accelerator="auto", devices=devices) 130 | return device_manager.accelerator, device_manager.devices 131 | 132 | # Memory Management 133 | def clear_memory(): 134 | """ 135 | Clear memory by collecting garbage and emptying the CUDA cache. 136 | """ 137 | gc.collect() 138 | torch.cuda.empty_cache() 139 | 140 | # File Operations 141 | def pil_to_b64_json(image: Image.Image) -> dict: 142 | """ 143 | Convert a PIL image to a base64-encoded JSON object. 144 | 145 | Args: 146 | image (PIL.Image.Image): The PIL image object to be converted. 147 | 148 | Returns: 149 | dict: A dictionary containing the image ID and the base64-encoded image. 150 | """ 151 | image_id = str(uuid.uuid4()) 152 | buffered = BytesIO() 153 | image.save(buffered, format="PNG") 154 | b64_image = base64.b64encode(buffered.getvalue()).decode("utf-8") 155 | return {"image_id": image_id, "b64_image": b64_image} 156 | 157 | def pil_to_s3_json(image: Image.Image, file_name: str) -> dict: 158 | """ 159 | Upload a PIL image to Amazon S3 and return a JSON object with the image ID and signed URL. 160 | 161 | Args: 162 | image (PIL.Image.Image): The PIL image to be uploaded. 163 | file_name (str): The name of the file. 164 | 165 | Returns: 166 | dict: A JSON object containing the image ID and the signed URL. 167 | """ 168 | image_id = str(uuid.uuid4()) 169 | s3_uploader = S3ManagerService() 170 | image_bytes = io.BytesIO() 171 | image.save(image_bytes, format="PNG") 172 | image_bytes.seek(0) 173 | 174 | unique_file_name = s3_uploader.generate_unique_file_name(file_name) 175 | s3_uploader.upload_file(image_bytes, unique_file_name) 176 | signed_url = s3_uploader.generate_signed_url(unique_file_name, exp=43200) # 12 hours 177 | return {"image_id": image_id, "url": signed_url} 178 | 179 | def mp4_to_s3_json(video_bytes: io.BytesIO, file_name: str) -> dict: 180 | """ 181 | Upload an MP4 video to Amazon S3 and return a JSON object with the video ID and signed URL. 182 | 183 | Args: 184 | video_bytes (io.BytesIO): The video data as bytes. 185 | file_name (str): The name of the file. 186 | 187 | Returns: 188 | dict: A JSON object containing the video ID and the signed URL. 
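        Illustrative return value (placeholders only): {"video_id": "<uuid4>", "url": "<pre-signed S3 URL, valid for 12 hours>"}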
189 | """ 190 | video_id = str(uuid.uuid4()) 191 | s3_uploader = S3ManagerService() 192 | 193 | unique_file_name = s3_uploader.generate_unique_file_name(file_name) 194 | s3_uploader.upload_file(video_bytes, unique_file_name) 195 | signed_url = s3_uploader.generate_signed_url(unique_file_name, exp=43200) # 12 hours 196 | return {"video_id": video_id, "url": signed_url} 197 | 198 | if __name__ == "__main__": 199 | acc, devs = accelerator() 200 | print(f"Selected accelerator: {acc}") 201 | print(f"Available devices: {devs}") -------------------------------------------------------------------------------- /api/flux_serve.py: -------------------------------------------------------------------------------- 1 | import io 2 | import json 3 | import base64 4 | import time 5 | import logging 6 | from typing import Dict, Any, List 7 | from pydantic import BaseModel, Field 8 | from PIL import Image 9 | from litserve import LitAPI, LitServer 10 | from scripts.s3_manager import S3ManagerService 11 | from config_settings import settings 12 | from scripts.flux_inference import FluxInpaintingInference 13 | 14 | logging.basicConfig(level=logging.INFO) 15 | logger = logging.getLogger(__name__) 16 | 17 | s3_manager = S3ManagerService() 18 | 19 | class InpaintingRequest(BaseModel): 20 | """ 21 | Model representing an inpainting request. 22 | 23 | Attributes: 24 | prompt (str): The prompt for inpainting. 25 | strength (float): Strength of inpainting effect, between 0.0 and 1.0. 26 | seed (int): Random seed for reproducibility. 27 | num_inference_steps (int): Number of inference steps, between 1 and 1000. 28 | input_image (str): Base64 encoded input image. 29 | mask_image (str): Base64 encoded mask image. 30 | """ 31 | prompt: str = Field(..., description="The prompt for inpainting") 32 | strength: float = Field(0.8, ge=0.0, le=1.0, description="Strength of inpainting effect") 33 | seed: int = Field(42, description="Random seed for reproducibility") 34 | num_inference_steps: int = Field(50, ge=1, le=1000, description="Number of inference steps") 35 | input_image: str = Field(..., description="Base64 encoded input image") 36 | mask_image: str = Field(..., description="Base64 encoded mask image") 37 | 38 | class FluxInpaintingAPI(LitAPI): 39 | """ 40 | API for Flux Inpainting using LitServer. 41 | 42 | This class implements the LitAPI interface to provide inpainting functionality 43 | using the Flux Inpainting model. It handles request decoding, batching, 44 | prediction, and response encoding. 45 | """ 46 | 47 | def setup(self, device: str) -> None: 48 | """ 49 | Initialize the Flux Inpainting model. 50 | 51 | Args: 52 | device (str): The device to run the model on (e.g., 'cpu', 'cuda'). 53 | """ 54 | self.flux_inpainter = FluxInpaintingInference() 55 | self.device = device 56 | 57 | def decode_request(self, request: Dict[str, Any]) -> Dict[str, Any]: 58 | """ 59 | Decode the incoming request into a format suitable for processing. 60 | 61 | Args: 62 | request (Dict[str, Any]): The raw incoming request data. 63 | 64 | Returns: 65 | Dict[str, Any]: A dictionary containing the decoded request data. 66 | 67 | Raises: 68 | Exception: If there's an error in decoding the request. 
69 | """ 70 | try: 71 | inpainting_request = InpaintingRequest(**request) 72 | 73 | input_image = Image.open(io.BytesIO(base64.b64decode(inpainting_request.input_image))) 74 | mask_image = Image.open(io.BytesIO(base64.b64decode(inpainting_request.mask_image))) 75 | 76 | return { 77 | "prompt": inpainting_request.prompt, 78 | "input_image": input_image, 79 | "mask_image": mask_image, 80 | "strength": inpainting_request.strength, 81 | "seed": inpainting_request.seed, 82 | "num_inference_steps": inpainting_request.num_inference_steps 83 | } 84 | except Exception as e: 85 | logger.error(f"Error in decode_request: {e}") 86 | raise 87 | 88 | def batch(self, inputs: List[Dict[str, Any]]) -> Dict[str, List[Any]]: 89 | """ 90 | Prepare a batch of inputs for processing. 91 | 92 | Args: 93 | inputs (List[Dict[str, Any]]): A list of individual input dictionaries. 94 | 95 | Returns: 96 | Dict[str, List[Any]]: A dictionary containing batched inputs. 97 | """ 98 | return { 99 | "prompt": [input["prompt"] for input in inputs], 100 | "input_image": [input["input_image"] for input in inputs], 101 | "mask_image": [input["mask_image"] for input in inputs], 102 | "strength": [input["strength"] for input in inputs], 103 | "seed": [input["seed"] for input in inputs], 104 | "num_inference_steps": [input["num_inference_steps"] for input in inputs] 105 | } 106 | 107 | def predict(self, inputs: Dict[str, List[Any]]) -> List[Dict[str, Any]]: 108 | """ 109 | Process a batch of inputs and return the results. 110 | 111 | Args: 112 | inputs (Dict[str, List[Any]]): A dictionary containing batched inputs. 113 | 114 | Returns: 115 | List[Dict[str, Any]]: A list of dictionaries containing the prediction results. 116 | """ 117 | results = [] 118 | for i in range(len(inputs["prompt"])): 119 | start_time = time.time() 120 | try: 121 | result_image = self.flux_inpainter.generate_inpainting( 122 | input_image=inputs["input_image"][i], 123 | mask_image=inputs["mask_image"][i], 124 | prompt=inputs["prompt"][i], 125 | seed=inputs["seed"][i], 126 | strength=inputs["strength"][i], 127 | num_inference_steps=inputs["num_inference_steps"][i] 128 | ) 129 | end_time = time.time() 130 | results.append({ 131 | "image": result_image, 132 | "prompt": inputs["prompt"][i], 133 | "seed": inputs["seed"][i], 134 | "time_taken": end_time - start_time 135 | }) 136 | except Exception as e: 137 | logger.error(f"Error in predict for item {i}: {e}") 138 | results.append(None) 139 | return results 140 | 141 | def unbatch(self, outputs: List[Dict[str, Any]]) -> List[Dict[str, Any]]: 142 | """ 143 | Convert batched outputs back to individual results. 144 | 145 | Args: 146 | outputs (List[Dict[str, Any]]): A list of output dictionaries from the predict method. 147 | 148 | Returns: 149 | List[Dict[str, Any]]: The same list of output dictionaries. 150 | """ 151 | return outputs 152 | 153 | def encode_response(self, output: Dict[str, Any]) -> Dict[str, Any]: 154 | """ 155 | Encode the output image and prepare the response. 156 | 157 | Args: 158 | output (Dict[str, Any]): A dictionary containing the prediction output. 159 | 160 | Returns: 161 | Dict[str, Any]: A dictionary containing the encoded response with the result URL, 162 | prompt, seed, and time taken. 163 | 164 | Raises: 165 | Exception: If there's an error in encoding the response. 
166 | """ 167 | if output is None: 168 | return {"error": "Failed to generate image"} 169 | 170 | try: 171 | result_image = output["image"] 172 | buffered = io.BytesIO() 173 | result_image.save(buffered, format="PNG") 174 | 175 | unique_filename = s3_manager.generate_unique_file_name("result.png") 176 | s3_manager.upload_file(io.BytesIO(buffered.getvalue()), unique_filename) 177 | signed_url = s3_manager.generate_signed_url(unique_filename, exp=43200) 178 | 179 | return { 180 | "result_url": signed_url, 181 | "prompt": output["prompt"], 182 | "seed": output["seed"], 183 | "time_taken": output["time_taken"] 184 | } 185 | except Exception as e: 186 | logger.error(f"Error in encode_response: {e}") 187 | return {"error": str(e)} 188 | 189 | if __name__ == "__main__": 190 | api = FluxInpaintingAPI() 191 | server = LitServer( 192 | api, 193 | api_path='/api/v2/inpainting/flux', 194 | accelerator="auto", 195 | max_batch_size=4, 196 | batch_timeout=0.1 197 | ) 198 | server.run(port=8000) -------------------------------------------------------------------------------- /api/outpainting_serve.py: -------------------------------------------------------------------------------- 1 | from litserve import LitAPI, LitServer 2 | from typing import Dict, Any, Tuple 3 | from PIL import Image 4 | import io 5 | import base64 6 | from pydantic import BaseModel, Field 7 | import time 8 | from scripts.outpainting import Outpainter 9 | from scripts.api_utils import pil_to_s3_json 10 | 11 | class OutpaintingRequest(BaseModel): 12 | """ 13 | Pydantic model representing a request for outpainting inference. 14 | 15 | This model defines the structure and validation rules for incoming API requests. 16 | All fields are required unless otherwise specified. 17 | 18 | Attributes: 19 | image (str): Base64 encoded input image. 20 | width (int): Target width for the outpainted image. 21 | height (int): Target height for the outpainted image. 22 | overlap_percentage (int): Percentage of overlap for the mask. 23 | num_inference_steps (int): Number of inference steps for the diffusion process. 24 | resize_option (str): Option for resizing the input image ("Full", "50%", "33%", "25%", or "Custom"). 25 | custom_resize_percentage (int): Custom resize percentage when resize_option is "Custom". 26 | prompt_input (str): Text prompt to guide the outpainting process. 27 | alignment (str): Alignment of the original image within the new canvas. 28 | overlap_left (bool): Whether to apply overlap on the left side. 29 | overlap_right (bool): Whether to apply overlap on the right side. 30 | overlap_top (bool): Whether to apply overlap on the top side. 31 | overlap_bottom (bool): Whether to apply overlap on the bottom side. 
32 | """ 33 | 34 | image: str = Field(..., description="Base64 encoded input image") 35 | width: int = Field(1024, description="Target width") 36 | height: int = Field(1024, description="Target height") 37 | overlap_percentage: int = Field(10, description="Mask overlap percentage") 38 | num_inference_steps: int = Field(8, description="Number of inference steps") 39 | resize_option: str = Field("Full", description="Resize option") 40 | custom_resize_percentage: int = Field(100, description="Custom resize percentage") 41 | prompt_input: str = Field("", description="Prompt for generation") 42 | alignment: str = Field("Middle", description="Image alignment") 43 | overlap_left: bool = Field(True, description="Apply overlap on left side") 44 | overlap_right: bool = Field(True, description="Apply overlap on right side") 45 | overlap_top: bool = Field(True, description="Apply overlap on top side") 46 | overlap_bottom: bool = Field(True, description="Apply overlap on bottom side") 47 | 48 | class OutpaintingAPI(LitAPI): 49 | """ 50 | LitAPI implementation for Outpainting model serving. 51 | 52 | This class defines the API for the Outpainting model, including methods for 53 | request decoding, prediction, and response encoding. It uses the Outpainter 54 | class to perform the actual outpainting operations. 55 | 56 | Attributes: 57 | outpainter (Outpainter): An instance of the Outpainter class for performing outpainting. 58 | 59 | Methods: 60 | setup: Initialize the Outpainter and set up any necessary resources. 61 | decode_request: Decode and validate incoming API requests. 62 | predict: Perform the outpainting operation on the input image. 63 | encode_response: Encode the outpainted image and additional information as a response. 64 | """ 65 | 66 | def setup(self, device: str) -> None: 67 | """ 68 | Set up the Outpainting model and associated resources. 69 | 70 | This method is called once when the API is initialized. It creates an instance 71 | of the Outpainter class and performs any necessary setup. 72 | 73 | Args: 74 | device (str): The device to run the model on (e.g., 'cpu', 'cuda'). 75 | 76 | Returns: 77 | None 78 | """ 79 | self.device = device 80 | self.outpainter = Outpainter() 81 | 82 | def decode_request(self, request: Dict[str, Any]) -> Dict[str, Any]: 83 | """ 84 | Decode the incoming request and prepare inputs for the model. 85 | 86 | This method validates the incoming request against the OutpaintingRequest model, 87 | decodes the base64 encoded image, and prepares the inputs for the outpainting process. 88 | 89 | Args: 90 | request (Dict[str, Any]): The raw request data. 91 | 92 | Returns: 93 | Dict[str, Any]: A dictionary containing the decoded image and request parameters. 94 | 95 | Raises: 96 | ValueError: If the request is invalid or cannot be processed. 97 | """ 98 | try: 99 | outpainting_request = OutpaintingRequest(**request) 100 | image_data = base64.b64decode(outpainting_request.image) 101 | image = Image.open(io.BytesIO(image_data)).convert("RGBA") 102 | 103 | return { 104 | 'image': image, 105 | 'params': outpainting_request.dict() 106 | } 107 | except Exception as e: 108 | raise ValueError(f"Invalid request: {str(e)}") 109 | 110 | def predict(self, inputs: Dict[str, Any]) -> Tuple[Image.Image, float, float]: 111 | """ 112 | Run predictions on the input. 113 | 114 | This method performs the outpainting operation using the Outpainter instance. 115 | It takes the decoded inputs from decode_request and passes them to the outpainter. 
116 | It also measures the completion time of the outpainting call. 117 | 118 | Args: 119 | inputs (Dict[str, Any]): A dictionary containing the image and outpainting parameters. 120 | 121 | Returns: 122 | Tuple[Image.Image, float]: A tuple containing: 123 | - The resulting outpainted image 124 | - The completion time in seconds 125 | 126 | """ 127 | image = inputs['image'] 128 | params = inputs['params'] 129 | 130 | start_time = time.time() 131 | 132 | result = self.outpainter.outpaint( 133 | image, 134 | params['width'], 135 | params['height'], 136 | params['overlap_percentage'], 137 | params['num_inference_steps'], 138 | params['resize_option'], 139 | params['custom_resize_percentage'], 140 | params['prompt_input'], 141 | params['alignment'], 142 | params['overlap_left'], 143 | params['overlap_right'], 144 | params['overlap_top'], 145 | params['overlap_bottom'] 146 | ) 147 | 148 | completion_time = time.time() - start_time 149 | return result, completion_time 150 | 151 | def encode_response(self, output: Tuple[Image.Image, float]) -> Dict[str, Any]: 152 | """ 153 | Encode the model output and additional information into a response payload. 154 | 155 | This method takes the outpainted image and completion time, uploads the 156 | image to S3 via pil_to_s3_json, and prepares the final API response 157 | with additional information. 158 | 159 | Args: 160 | output (Tuple[Image.Image, float]): A tuple containing: 161 | - The outpainted image produced by the predict method 162 | - The completion time in seconds 163 | 164 | 165 | Returns: 166 | Dict[str, Any]: A dictionary containing the S3 upload result (image ID and 167 | signed URL), the completion time, and the image resolution. 168 | """ 169 | image, completion_time = output 170 | img_str = pil_to_s3_json(image, "outpainting_image") 171 | 172 | return { 173 | "result": img_str, 174 | "completion_time": round(completion_time, 2), 175 | "image_resolution": f"{image.width}x{image.height}" 176 | } 177 | 178 | 179 | 180 | 181 | if __name__ == "__main__": 182 | api = OutpaintingAPI() 183 | server = LitServer(api, accelerator="cuda", max_batch_size=1) 184 | server.run(port=8000) -------------------------------------------------------------------------------- /serverless/inpainting/run_inpainting.py: -------------------------------------------------------------------------------- 1 | import io 2 | import base64 3 | import time 4 | import logging 5 | import asyncio 6 | from typing import Dict, Any, Tuple, Optional, AsyncGenerator 7 | from pydantic import BaseModel, Field 8 | from PIL import Image 9 | from scripts.s3_manager import S3ManagerService 10 | from scripts.flux_inference import FluxInpaintingInference 11 | from config_settings import settings 12 | import runpod 13 | 14 | logging.basicConfig(level=logging.INFO) 15 | logger = logging.getLogger(__name__) 16 | 17 | class InpaintingRequest(BaseModel): 18 | """ 19 | Model representing an inpainting request.
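    Illustrative example payload (base64 strings and prompt are placeholders; unlisted fields use the defaults declared below):

        {
            "prompt": "fill the masked area with a matching background",
            "strength": 0.8,
            "seed": 42,
            "num_inference_steps": 50,
            "input_image": "<base64-encoded image>",
            "mask_image": "<base64-encoded mask>"
        }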
20 | """ 21 | prompt: str = Field(..., description="The prompt for inpainting") 22 | strength: float = Field(0.8, ge=0.0, le=1.0, description="Strength of inpainting effect") 23 | seed: int = Field(42, description="Random seed for reproducibility") 24 | num_inference_steps: int = Field(50, ge=1, le=1000, description="Number of inference steps") 25 | input_image: str = Field(..., description="Base64 encoded input image") 26 | mask_image: str = Field(..., description="Base64 encoded mask image") 27 | 28 | # Global instances 29 | global_inpainter = None 30 | global_s3_manager = None 31 | 32 | async def initialize_services(): 33 | """Initialize global services if not already initialized""" 34 | global global_inpainter, global_s3_manager 35 | 36 | if global_inpainter is None: 37 | logger.info("Initializing Flux Inpainting model...") 38 | global_inpainter = FluxInpaintingInference() 39 | logger.info("Flux Inpainting model initialized successfully") 40 | 41 | if global_s3_manager is None: 42 | logger.info("Initializing S3 manager...") 43 | global_s3_manager = S3ManagerService() 44 | logger.info("S3 manager initialized successfully") 45 | 46 | return global_inpainter, global_s3_manager 47 | 48 | async def decode_request(request: Dict[str, Any]) -> Dict[str, Any]: 49 | """ 50 | Decode and validate the incoming inpainting request asynchronously. 51 | """ 52 | try: 53 | logger.info("Decoding inpainting request") 54 | inpainting_request = InpaintingRequest(**request) 55 | 56 | # Run image decoding in thread pool 57 | input_image_data = await asyncio.to_thread( 58 | base64.b64decode, inpainting_request.input_image 59 | ) 60 | mask_image_data = await asyncio.to_thread( 61 | base64.b64decode, inpainting_request.mask_image 62 | ) 63 | 64 | input_image = await asyncio.to_thread( 65 | lambda: Image.open(io.BytesIO(input_image_data)) 66 | ) 67 | mask_image = await asyncio.to_thread( 68 | lambda: Image.open(io.BytesIO(mask_image_data)) 69 | ) 70 | 71 | logger.info("Request decoded successfully") 72 | return { 73 | "prompt": inpainting_request.prompt, 74 | "input_image": input_image, 75 | "mask_image": mask_image, 76 | "strength": inpainting_request.strength, 77 | "seed": inpainting_request.seed, 78 | "num_inference_steps": inpainting_request.num_inference_steps 79 | } 80 | except Exception as e: 81 | logger.error(f"Error in decode_request: {e}") 82 | raise 83 | 84 | async def generate_inpainting(inputs: Dict[str, Any], inpainter: FluxInpaintingInference) -> Tuple[Optional[Image.Image], Dict[str, Any]]: 85 | """ 86 | Perform inpainting operation using the Flux model asynchronously. 87 | """ 88 | start_time = time.time() 89 | 90 | # Run inpainting in thread pool 91 | result_image = await asyncio.to_thread( 92 | inpainter.generate_inpainting, 93 | input_image=inputs["input_image"], 94 | mask_image=inputs["mask_image"], 95 | prompt=inputs["prompt"], 96 | seed=inputs["seed"], 97 | strength=inputs["strength"], 98 | num_inference_steps=inputs["num_inference_steps"] 99 | ) 100 | 101 | output = { 102 | "image": result_image, 103 | "prompt": inputs["prompt"], 104 | "seed": inputs["seed"], 105 | "time_taken": time.time() - start_time 106 | } 107 | 108 | return result_image, output 109 | 110 | async def upload_result(image: Image.Image, metadata: Dict[str, Any], s3_manager: S3ManagerService) -> Dict[str, Any]: 111 | """ 112 | Upload the generated image to S3 and prepare the response asynchronously. 
113 | """ 114 | try: 115 | # Prepare image buffer 116 | buffered = io.BytesIO() 117 | await asyncio.to_thread(image.save, buffered, format="PNG") 118 | buffered.seek(0) 119 | 120 | # Generate unique filename and upload 121 | unique_filename = await asyncio.to_thread( 122 | s3_manager.generate_unique_file_name, "result.png" 123 | ) 124 | await asyncio.to_thread( 125 | s3_manager.upload_file, 126 | io.BytesIO(buffered.getvalue()), 127 | unique_filename 128 | ) 129 | 130 | # Generate signed URL 131 | signed_url = await asyncio.to_thread( 132 | s3_manager.generate_signed_url, 133 | unique_filename, 134 | exp=43200 135 | ) 136 | 137 | return { 138 | "result_url": signed_url, 139 | "prompt": metadata["prompt"], 140 | "seed": metadata["seed"], 141 | "time_taken": metadata["time_taken"] 142 | } 143 | except Exception as e: 144 | logger.error(f"Error in upload_result: {e}") 145 | raise 146 | 147 | async def async_generator_handler(job: Dict[str, Any]) -> AsyncGenerator[Dict[str, Any], None]: 148 | """ 149 | Async generator handler for RunPod with progress updates. 150 | """ 151 | try: 152 | # Initial status 153 | yield {"status": "starting", "message": "Initializing inpainting process"} 154 | 155 | # Initialize services 156 | inpainter, s3_manager = await initialize_services() 157 | yield {"status": "processing", "message": "Services initialized successfully"} 158 | 159 | # Decode request 160 | try: 161 | inputs = await decode_request(job['input']) 162 | yield {"status": "processing", "message": "Request decoded successfully"} 163 | except Exception as e: 164 | logger.error(f"Request decode error: {e}") 165 | yield {"status": "error", "message": f"Error decoding request: {str(e)}"} 166 | return 167 | 168 | # Generate inpainting 169 | try: 170 | yield {"status": "processing", "message": "Starting inpainting generation"} 171 | result_image, metadata = await generate_inpainting(inputs, inpainter) 172 | 173 | if result_image is None: 174 | yield {"status": "error", "message": "Failed to generate image"} 175 | return 176 | 177 | yield { 178 | "status": "processing", 179 | "message": "Inpainting generated successfully", 180 | "completion": f"{metadata['time_taken']:.2f}s" 181 | } 182 | except Exception as e: 183 | logger.error(f"Inpainting error: {e}") 184 | yield {"status": "error", "message": f"Error during inpainting: {str(e)}"} 185 | return 186 | 187 | # Upload result 188 | try: 189 | yield {"status": "processing", "message": "Uploading result"} 190 | response = await upload_result(result_image, metadata, s3_manager) 191 | yield {"status": "processing", "message": "Result uploaded successfully"} 192 | except Exception as e: 193 | logger.error(f"Upload error: {e}") 194 | yield {"status": "error", "message": f"Error uploading result: {str(e)}"} 195 | return 196 | 197 | # Final response 198 | yield { 199 | "status": "completed", 200 | "output": response 201 | } 202 | 203 | except Exception as e: 204 | logger.error(f"Unexpected error: {e}") 205 | yield { 206 | "status": "error", 207 | "message": f"Unexpected error: {str(e)}" 208 | } 209 | 210 | # Initialize services when the service starts 211 | print("Initializing service...") 212 | asyncio.get_event_loop().run_until_complete(initialize_services()) 213 | print("Service initialization complete") 214 | 215 | if __name__ == "__main__": 216 | runpod.serverless.start({ 217 | "handler": async_generator_handler, 218 | "return_aggregate_stream": True 219 | }) -------------------------------------------------------------------------------- 
/scripts/wandb/run-20240429_145519-r0eclldx/files/wandb-metadata.json: -------------------------------------------------------------------------------- 1 | { 2 | "os": "Linux-6.5.0-21-generic-x86_64-with-glibc2.35", 3 | "python": "3.10.12", 4 | "heartbeatAt": "2024-04-29T14:55:20.877614", 5 | "startedAt": "2024-04-29T14:55:19.801537", 6 | "docker": null, 7 | "cuda": null, 8 | "args": [], 9 | "state": "running", 10 | "program": "/home/product_diffusion_api/scripts/sdxl_lora_tuner.py", 11 | "codePathLocal": "sdxl_lora_tuner.py", 12 | "codePath": "scripts/sdxl_lora_tuner.py", 13 | "git": { 14 | "remote": "https://github.com/VikramxD/product_diffusion_api.git", 15 | "commit": "e33275e65d0cd88e0c809d4b5d54039ac777c99d" 16 | }, 17 | "email": "singh.vikram.1782000@gmail.com", 18 | "root": "/home/product_diffusion_api", 19 | "host": "78f51589a5b5", 20 | "username": "root", 21 | "executable": "/home/product_diffusion_api/.venv/bin/python3", 22 | "cpu_count": 32, 23 | "cpu_count_logical": 64, 24 | "cpu_freq": { 25 | "current": 1344.2777499999997, 26 | "min": 800.0, 27 | "max": 3200.0 28 | }, 29 | "cpu_freq_per_core": [ 30 | { 31 | "current": 2901.3, 32 | "min": 800.0, 33 | "max": 3200.0 34 | }, 35 | { 36 | "current": 2926.233, 37 | "min": 800.0, 38 | "max": 3200.0 39 | }, 40 | { 41 | "current": 800.0, 42 | "min": 800.0, 43 | "max": 3200.0 44 | }, 45 | { 46 | "current": 800.0, 47 | "min": 800.0, 48 | "max": 3200.0 49 | }, 50 | { 51 | "current": 800.0, 52 | "min": 800.0, 53 | "max": 3200.0 54 | }, 55 | { 56 | "current": 800.0, 57 | "min": 800.0, 58 | "max": 3200.0 59 | }, 60 | { 61 | "current": 800.0, 62 | "min": 800.0, 63 | "max": 3200.0 64 | }, 65 | { 66 | "current": 2100.324, 67 | "min": 800.0, 68 | "max": 3200.0 69 | }, 70 | { 71 | "current": 2100.469, 72 | "min": 800.0, 73 | "max": 3200.0 74 | }, 75 | { 76 | "current": 2199.995, 77 | "min": 800.0, 78 | "max": 3200.0 79 | }, 80 | { 81 | "current": 2100.338, 82 | "min": 800.0, 83 | "max": 3200.0 84 | }, 85 | { 86 | "current": 1244.531, 87 | "min": 800.0, 88 | "max": 3200.0 89 | }, 90 | { 91 | "current": 1364.104, 92 | "min": 800.0, 93 | "max": 3200.0 94 | }, 95 | { 96 | "current": 1516.545, 97 | "min": 800.0, 98 | "max": 3200.0 99 | }, 100 | { 101 | "current": 800.0, 102 | "min": 800.0, 103 | "max": 3200.0 104 | }, 105 | { 106 | "current": 800.0, 107 | "min": 800.0, 108 | "max": 3200.0 109 | }, 110 | { 111 | "current": 800.045, 112 | "min": 800.0, 113 | "max": 3200.0 114 | }, 115 | { 116 | "current": 1012.875, 117 | "min": 800.0, 118 | "max": 3200.0 119 | }, 120 | { 121 | "current": 800.0, 122 | "min": 800.0, 123 | "max": 3200.0 124 | }, 125 | { 126 | "current": 800.0, 127 | "min": 800.0, 128 | "max": 3200.0 129 | }, 130 | { 131 | "current": 800.0, 132 | "min": 800.0, 133 | "max": 3200.0 134 | }, 135 | { 136 | "current": 800.0, 137 | "min": 800.0, 138 | "max": 3200.0 139 | }, 140 | { 141 | "current": 800.0, 142 | "min": 800.0, 143 | "max": 3200.0 144 | }, 145 | { 146 | "current": 800.0, 147 | "min": 800.0, 148 | "max": 3200.0 149 | }, 150 | { 151 | "current": 3200.0, 152 | "min": 800.0, 153 | "max": 3200.0 154 | }, 155 | { 156 | "current": 800.0, 157 | "min": 800.0, 158 | "max": 3200.0 159 | }, 160 | { 161 | "current": 800.0, 162 | "min": 800.0, 163 | "max": 3200.0 164 | }, 165 | { 166 | "current": 3200.0, 167 | "min": 800.0, 168 | "max": 3200.0 169 | }, 170 | { 171 | "current": 2933.878, 172 | "min": 800.0, 173 | "max": 3200.0 174 | }, 175 | { 176 | "current": 2966.521, 177 | "min": 800.0, 178 | "max": 3200.0 179 | }, 180 | { 181 | "current": 
2966.524, 182 | "min": 800.0, 183 | "max": 3200.0 184 | }, 185 | { 186 | "current": 2966.525, 187 | "min": 800.0, 188 | "max": 3200.0 189 | }, 190 | { 191 | "current": 2966.202, 192 | "min": 800.0, 193 | "max": 3200.0 194 | }, 195 | { 196 | "current": 2966.202, 197 | "min": 800.0, 198 | "max": 3200.0 199 | }, 200 | { 201 | "current": 2966.212, 202 | "min": 800.0, 203 | "max": 3200.0 204 | }, 205 | { 206 | "current": 800.026, 207 | "min": 800.0, 208 | "max": 3200.0 209 | }, 210 | { 211 | "current": 1583.189, 212 | "min": 800.0, 213 | "max": 3200.0 214 | }, 215 | { 216 | "current": 800.0, 217 | "min": 800.0, 218 | "max": 3200.0 219 | }, 220 | { 221 | "current": 800.0, 222 | "min": 800.0, 223 | "max": 3200.0 224 | }, 225 | { 226 | "current": 1431.135, 227 | "min": 800.0, 228 | "max": 3200.0 229 | }, 230 | { 231 | "current": 800.0, 232 | "min": 800.0, 233 | "max": 3200.0 234 | }, 235 | { 236 | "current": 800.0, 237 | "min": 800.0, 238 | "max": 3200.0 239 | }, 240 | { 241 | "current": 800.0, 242 | "min": 800.0, 243 | "max": 3200.0 244 | }, 245 | { 246 | "current": 800.0, 247 | "min": 800.0, 248 | "max": 3200.0 249 | }, 250 | { 251 | "current": 800.0, 252 | "min": 800.0, 253 | "max": 3200.0 254 | }, 255 | { 256 | "current": 800.0, 257 | "min": 800.0, 258 | "max": 3200.0 259 | }, 260 | { 261 | "current": 800.0, 262 | "min": 800.0, 263 | "max": 3200.0 264 | }, 265 | { 266 | "current": 800.0, 267 | "min": 800.0, 268 | "max": 3200.0 269 | }, 270 | { 271 | "current": 800.0, 272 | "min": 800.0, 273 | "max": 3200.0 274 | }, 275 | { 276 | "current": 976.385, 277 | "min": 800.0, 278 | "max": 3200.0 279 | }, 280 | { 281 | "current": 800.0, 282 | "min": 800.0, 283 | "max": 3200.0 284 | }, 285 | { 286 | "current": 800.0, 287 | "min": 800.0, 288 | "max": 3200.0 289 | }, 290 | { 291 | "current": 800.0, 292 | "min": 800.0, 293 | "max": 3200.0 294 | }, 295 | { 296 | "current": 800.062, 297 | "min": 800.0, 298 | "max": 3200.0 299 | }, 300 | { 301 | "current": 800.0, 302 | "min": 800.0, 303 | "max": 3200.0 304 | }, 305 | { 306 | "current": 800.0, 307 | "min": 800.0, 308 | "max": 3200.0 309 | }, 310 | { 311 | "current": 800.0, 312 | "min": 800.0, 313 | "max": 3200.0 314 | }, 315 | { 316 | "current": 800.0, 317 | "min": 800.0, 318 | "max": 3200.0 319 | }, 320 | { 321 | "current": 800.015, 322 | "min": 800.0, 323 | "max": 3200.0 324 | }, 325 | { 326 | "current": 800.0, 327 | "min": 800.0, 328 | "max": 3200.0 329 | }, 330 | { 331 | "current": 2892.103, 332 | "min": 800.0, 333 | "max": 3200.0 334 | }, 335 | { 336 | "current": 2902.559, 337 | "min": 800.0, 338 | "max": 3200.0 339 | }, 340 | { 341 | "current": 800.0, 342 | "min": 800.0, 343 | "max": 3200.0 344 | }, 345 | { 346 | "current": 2949.129, 347 | "min": 800.0, 348 | "max": 3200.0 349 | } 350 | ], 351 | "disk": { 352 | "/": { 353 | "total": 1876.216781616211, 354 | "used": 704.5339584350586 355 | } 356 | }, 357 | "gpu": "NVIDIA RTX A5000", 358 | "gpu_count": 1, 359 | "gpu_devices": [ 360 | { 361 | "name": "NVIDIA RTX A5000", 362 | "memory_total": 25757220864 363 | } 364 | ], 365 | "memory": { 366 | "total": 251.53280639648438 367 | } 368 | } 369 | -------------------------------------------------------------------------------- /scripts/outpainting.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from PIL import Image, ImageDraw 3 | import numpy as np 4 | from diffusers import AutoencoderKL, TCDScheduler 5 | from diffusers.models.model_loading_utils import load_state_dict 6 | from huggingface_hub import 
hf_hub_download 7 | from scripts.controlnet_union import ControlNetModel_Union 8 | from scripts.pipeline_fill_sd_xl import StableDiffusionXLFillPipeline 9 | 10 | 11 | 12 | 13 | class Outpainter: 14 | """ 15 | A class for performing outpainting operations using Stable Diffusion XL. 16 | 17 | This class handles the setup and execution of outpainting tasks, including 18 | model initialization, image preparation, and the actual outpainting process. 19 | """ 20 | 21 | def __init__(self): 22 | """Initialize the Outpainter by setting up the required models.""" 23 | self.setup_model() 24 | 25 | def setup_model(self): 26 | """ 27 | Set up and configure the SDXL model with ControlNet and VAE components. 28 | 29 | Downloads necessary model files, initializes components, and configures 30 | the pipeline for inference. 31 | """ 32 | config_file = hf_hub_download( 33 | "xinsir/controlnet-union-sdxl-1.0", 34 | filename="config_promax.json", 35 | ) 36 | config = ControlNetModel_Union.load_config(config_file) 37 | controlnet_model = ControlNetModel_Union.from_config(config) 38 | model_file = hf_hub_download( 39 | "xinsir/controlnet-union-sdxl-1.0", 40 | filename="diffusion_pytorch_model_promax.safetensors", 41 | ) 42 | state_dict = load_state_dict(model_file) 43 | model, _, _, _, _ = ControlNetModel_Union._load_pretrained_model( 44 | controlnet_model, state_dict, model_file, "xinsir/controlnet-union-sdxl-1.0" 45 | ) 46 | model.to(device="cuda", dtype=torch.float16) 47 | 48 | vae = AutoencoderKL.from_pretrained( 49 | "madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16 50 | ).to("cuda") 51 | 52 | self.pipe = StableDiffusionXLFillPipeline.from_pretrained( 53 | "SG161222/RealVisXL_V5.0_Lightning", 54 | torch_dtype=torch.float16, 55 | vae=vae, 56 | controlnet=model, 57 | variant="fp16", 58 | ).to("cuda") 59 | 60 | self.pipe.scheduler = TCDScheduler.from_config(self.pipe.scheduler.config) 61 | 62 | def calculate_margins(self, target_size: tuple, new_width: int, new_height: int, alignment: str) -> tuple: 63 | """ 64 | Calculate image margins based on alignment and dimensions. 
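        For example (illustrative numbers only): with target_size=(1024, 1024), new_width=512, new_height=512 and alignment="Right", the method returns (512, 256): the image sits flush against the right edge and is centered vertically.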
65 | 66 | Args: 67 | target_size: Tuple of (width, height) for the target canvas size 68 | new_width: Width of the resized image 69 | new_height: Height of the resized image 70 | alignment: Position alignment ("Middle", "Left", "Right", "Top", "Bottom") 71 | 72 | Returns: 73 | tuple: (margin_x, margin_y) coordinates for image placement 74 | """ 75 | if alignment == "Middle": 76 | margin_x = (target_size[0] - new_width) // 2 77 | margin_y = (target_size[1] - new_height) // 2 78 | elif alignment == "Left": 79 | margin_x = 0 80 | margin_y = (target_size[1] - new_height) // 2 81 | elif alignment == "Right": 82 | margin_x = target_size[0] - new_width 83 | margin_y = (target_size[1] - new_height) // 2 84 | elif alignment == "Top": 85 | margin_x = (target_size[0] - new_width) // 2 86 | margin_y = 0 87 | elif alignment == "Bottom": 88 | margin_x = (target_size[0] - new_width) // 2 89 | margin_y = target_size[1] - new_height 90 | else: 91 | margin_x = (target_size[0] - new_width) // 2 92 | margin_y = (target_size[1] - new_height) // 2 93 | 94 | margin_x = max(0, min(margin_x, target_size[0] - new_width)) 95 | margin_y = max(0, min(margin_y, target_size[1] - new_height)) 96 | 97 | return margin_x, margin_y 98 | 99 | def prepare_image_and_mask(self, image: Image, width: int, height: int, 100 | overlap_percentage: int, resize_option: str, 101 | custom_resize_percentage: int, alignment: str, 102 | overlap_left: bool, overlap_right: bool, 103 | overlap_top: bool, overlap_bottom: bool) -> tuple: 104 | """ 105 | Prepare the input image and generate a mask for outpainting. 106 | 107 | Args: 108 | image: Input PIL Image 109 | width: Target width for output 110 | height: Target height for output 111 | overlap_percentage: Percentage of overlap for mask 112 | resize_option: Image resize option ("Full", "50%", "33%", "25%", "Custom") 113 | custom_resize_percentage: Custom resize percentage if resize_option is "Custom" 114 | alignment: Image alignment in the canvas 115 | overlap_left: Apply overlap on left side 116 | overlap_right: Apply overlap on right side 117 | overlap_top: Apply overlap on top side 118 | overlap_bottom: Apply overlap on bottom side 119 | 120 | Returns: 121 | tuple: (background_image, mask_image) prepared for outpainting 122 | """ 123 | target_size = (width, height) 124 | scale_factor = min(target_size[0] / image.width, target_size[1] / image.height) 125 | new_width = int(image.width * scale_factor) 126 | new_height = int(image.height * scale_factor) 127 | 128 | source = image.resize((new_width, new_height), Image.LANCZOS) 129 | 130 | resize_percentage = { 131 | "Full": 100, 132 | "50%": 50, 133 | "33%": 33, 134 | "25%": 25 135 | }.get(resize_option, custom_resize_percentage) 136 | 137 | resize_factor = resize_percentage / 100 138 | new_width = max(int(source.width * resize_factor), 64) 139 | new_height = max(int(source.height * resize_factor), 64) 140 | 141 | source = source.resize((new_width, new_height), Image.LANCZOS) 142 | 143 | overlap_x = max(int(new_width * (overlap_percentage / 100)), 1) 144 | overlap_y = max(int(new_height * (overlap_percentage / 100)), 1) 145 | 146 | margin_x, margin_y = self.calculate_margins(target_size, new_width, new_height, alignment) 147 | 148 | background = Image.new('RGB', target_size, (255, 255, 255)) 149 | background.paste(source, (margin_x, margin_y)) 150 | 151 | mask = Image.new('L', target_size, 255) 152 | white_gaps_patch = 2 153 | 154 | left_overlap = margin_x + (overlap_x if overlap_left else white_gaps_patch) 155 | right_overlap = margin_x + 
new_width - (overlap_x if overlap_right else white_gaps_patch) 156 | top_overlap = margin_y + (overlap_y if overlap_top else white_gaps_patch) 157 | bottom_overlap = margin_y + new_height - (overlap_y if overlap_bottom else white_gaps_patch) 158 | 159 | if alignment == "Left": 160 | left_overlap = margin_x + (overlap_x if overlap_left else 0) 161 | elif alignment == "Right": 162 | right_overlap = margin_x + new_width - (overlap_x if overlap_right else 0) 163 | elif alignment == "Top": 164 | top_overlap = margin_y + (overlap_y if overlap_top else 0) 165 | elif alignment == "Bottom": 166 | bottom_overlap = margin_y + new_height - (overlap_y if overlap_bottom else 0) 167 | 168 | mask_draw = ImageDraw.Draw(mask) 169 | mask_draw.rectangle([left_overlap, top_overlap, right_overlap, bottom_overlap], fill=0) 170 | 171 | return background, mask 172 | 173 | def outpaint(self, image: Image, width: int, height: int, 174 | overlap_percentage: int, num_inference_steps: int, 175 | resize_option: str, custom_resize_percentage: int, 176 | prompt_input: str, alignment: str, 177 | overlap_left: bool, overlap_right: bool, 178 | overlap_top: bool, overlap_bottom: bool) -> Image: 179 | """ 180 | Perform outpainting on the input image. 181 | 182 | Args: 183 | image: Input PIL Image to outpaint 184 | width: Target width for output 185 | height: Target height for output 186 | overlap_percentage: Percentage of overlap for mask 187 | num_inference_steps: Number of denoising steps 188 | resize_option: Image resize option 189 | custom_resize_percentage: Custom resize percentage 190 | prompt_input: Text prompt for generation 191 | alignment: Image alignment in canvas 192 | overlap_left: Apply overlap on left 193 | overlap_right: Apply overlap on right 194 | overlap_top: Apply overlap on top 195 | overlap_bottom: Apply overlap on bottom 196 | 197 | Returns: 198 | PIL.Image: Outpainted image 199 | """ 200 | background, mask = self.prepare_image_and_mask( 201 | image, width, height, overlap_percentage, resize_option, 202 | custom_resize_percentage, alignment, overlap_left, overlap_right, 203 | overlap_top, overlap_bottom 204 | ) 205 | 206 | cnet_image = background.copy() 207 | cnet_image.paste(0, (0, 0), mask) 208 | 209 | final_prompt = f"{prompt_input}, high quality, 4k" 210 | 211 | ( 212 | prompt_embeds, 213 | negative_prompt_embeds, 214 | pooled_prompt_embeds, 215 | negative_pooled_prompt_embeds, 216 | ) = self.pipe.encode_prompt(final_prompt, "cuda", True) 217 | 218 | generator = self.pipe( 219 | prompt_embeds=prompt_embeds, 220 | negative_prompt_embeds=negative_prompt_embeds, 221 | pooled_prompt_embeds=pooled_prompt_embeds, 222 | negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, 223 | image=cnet_image, 224 | num_inference_steps=num_inference_steps 225 | ) 226 | 227 | for output in generator: 228 | final_image = output 229 | 230 | final_image = final_image.convert("RGBA") 231 | cnet_image.paste(final_image, (0, 0), mask) 232 | 233 | return cnet_image -------------------------------------------------------------------------------- /Readme.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 |
4 | 171 |
172 |

🚀 PicPilot

173 | 174 |
175 | GitHub Stars 176 | MIT License 177 |
178 | 179 |

PicPilot is a scalable solution that leverages state-of-the-art text-to-image models to extend and enhance images and create product photography for your brand in seconds. Whether you're working with existing product images or need to generate new visuals, PicPilot turns your visual content into stunning, professional assets that are ready for marketing.

180 | 181 |

Features

182 |
183 |
184 |

Flux Inpainting

185 |

Detailed image editing for precise enhancements.

186 |
187 |
188 |

SDXL Generation

189 |

High-quality image generation using SDXL with LoRA.

190 |
191 |
192 |

SDXL Outpainting

193 |

Extend images seamlessly for creative compositions.

194 |
195 |
196 |

Image to Video

197 |

Convert static images into dynamic videos with CogVideoX.

198 |
199 |
200 |

Batch Processing

201 |

Handle multiple images with configurable batch sizes and timeouts.

202 |
203 |
204 |

Scalable Solution

205 |

Built to handle your growing visual content needs.

206 |
207 |
208 | 209 |

Why PicPilot?

210 |

Creating professional product photography and visual narratives can be time-consuming and expensive. PicPilot aims to revolutionize this process by offering an AI-powered platform where you can enhance existing images, generate new ones, or even convert images to videos, creating stunning visuals for your brand in seconds.

211 | 212 |

Installation

213 |
git clone https://github.com/VikramxD/Picpilot
214 | cd Picpilot
215 | ./run.sh
216 | 217 |

Docker Support

218 |

Using Pre-built Docker Image

219 |
docker pull vikram1202/picpilot:latest
220 | docker run --gpus all -p 8000:8000 vikram1202/picpilot:latest
221 | 222 |

Building Docker Image Locally

223 |
docker build -t picpilot .
224 | docker run --gpus all -p 8000:8000 picpilot
225 | 226 |

Usage

227 |

Run the Server:

228 |
cd api
229 | python picpilot.py
230 |

This will start the server on port 8000 with all available API endpoints.

231 | 232 |

API Endpoints

Endpoint | Path | Purpose | Max Batch Size | Batch Timeout
Flux Inpainting | /api/v2/painting/flux | Detailed image editing and inpainting | 4 | 0.1 seconds
SDXL Generation | /api/v2/generate/sdxl | High-quality image generation using SDXL with LoRA | Configured in tti_settings | Configured in tti_settings
SDXL Outpainting | /api/v2/painting/sdxl_outpainting | Extending images seamlessly | 4 | 0.1 seconds
Image to Video | /api/v2/image2video/cogvideox | Converting images to videos | 1 | 0.1 seconds
270 | 271 |
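For illustration, here is a minimal client sketch for the SDXL generation endpoint listed above. The request and response field names (prompt, num_inference_steps, and the shape of the JSON reply) are assumptions made for this sketch rather than the documented schema; see api/client.py in the repository for the actual payload format.

import requests

# Minimal sketch of calling the SDXL generation endpoint on a locally running server.
# NOTE: the payload field names below are assumed for illustration only;
# consult api/client.py for the real request schema.
BASE_URL = "http://localhost:8000"

payload = {
    "prompt": "studio product shot of a leather backpack, soft lighting",
    "num_inference_steps": 30,
}

response = requests.post(f"{BASE_URL}/api/v2/generate/sdxl", json=payload, timeout=300)
response.raise_for_status()
print(response.json())  # inspect the result (e.g., an image URL or base64-encoded image)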

Next Features

272 |
    273 |
  • Support for Image Editing in FLUX Models
  • 274 |
  • Support for custom FLUX LoRAs
  • 275 |
  • Support for CogVideoX fine-tuning
  • 276 |
277 | 278 |

Limitations

279 |
    280 |
  • Requires powerful GPUs for optimal performance, especially for the FLUX models
  • 281 |
  • Processing time may vary depending on the complexity of the task and input size
  • 282 |
  • Image to video conversion is limited to one image at a time
  • 283 |
284 | 285 |

License

286 |

PicPilot is licensed under the MIT license. See LICENSE for more information.

287 | 288 |

Acknowledgements

289 |

This project utilizes several open-source models and libraries. We express our gratitude to the authors and contributors of:

290 |
    291 |
  • Diffusers
  • 292 |
  • LitServe
  • 293 |
  • Transformers
  • 294 |
295 |
296 |
297 |
298 |
299 | -------------------------------------------------------------------------------- /scripts/wandb/run-20240429_145519-r0eclldx/files/diff.patch: -------------------------------------------------------------------------------- 1 | diff --git a/.gitignore b/.gitignore 2 | index 5bbee1b..1d17dae 100644 3 | --- a/.gitignore 4 | +++ b/.gitignore 5 | @@ -1,5 +1 @@ 6 | .venv 7 | -data 8 | -scripts/wandb 9 | -models 10 | -scripts/yolov8* 11 | diff --git a/requirements.txt b/requirements.txt 12 | index d1c8048..85f0bbc 100644 13 | --- a/requirements.txt 14 | +++ b/requirements.txt 15 | @@ -9,7 +9,13 @@ numpy 16 | rich 17 | tqdm 18 | transformers 19 | -opencv-python-headless 20 | fastapi 21 | uvicorn 22 | matplotlib 23 | +accelerate 24 | +torchvision 25 | +ftfy 26 | +tensorboard 27 | +Jinja2 28 | +datasets 29 | +peft 30 | diff --git a/scripts/clear_memory.py b/scripts/clear_memory.py 31 | deleted file mode 100644 32 | index 7b6010e..0000000 33 | --- a/scripts/clear_memory.py 34 | +++ /dev/null 35 | @@ -1,18 +0,0 @@ 36 | -import gc 37 | -import torch 38 | -from logger import rich_logger as l 39 | - 40 | -def clear_memory(): 41 | - """ 42 | - Clears the memory by collecting garbage and emptying the CUDA cache. 43 | - 44 | - This function is useful when dealing with memory-intensive operations in Python, especially when using libraries like PyTorch. 45 | - 46 | - Note: 47 | - This function requires the `gc` and `torch` modules to be imported. 48 | - 49 | - """ 50 | - gc.collect() 51 | - torch.cuda.empty_cache() 52 | - l.info("Memory Cleared") 53 | - 54 | \ No newline at end of file 55 | diff --git a/scripts/config.py b/scripts/config.py 56 | index b620197..10947d3 100644 57 | --- a/scripts/config.py 58 | +++ b/scripts/config.py 59 | @@ -1,13 +1,60 @@ 60 | -LOGS_DIR = '../logs' 61 | -DATA_DIR = '../data' 62 | -Project_Name = 'product_placement_api' 63 | -entity = 'vikramxd' 64 | -image_dir = '../sample_data' 65 | -mask_dir = '../masks' 66 | -segmentation_model = 'facebook/sam-vit-large' 67 | -detection_model = 'yolov8l' 68 | -kandinsky_model_name = 'kandinsky-community/kandinsky-2-2-decoder-inpaint' 69 | -video_model_name = 'stabilityai/stable-video-diffusion-img2vid-xt' 70 | -target_width = 2560 71 | -target_height = 1440 72 | -roi_scale = 0.6 73 | +MODEL_NAME="stabilityai/stable-diffusion-xl-base-1.0" 74 | +VAE_NAME= "madebyollin/sdxl-vae-fp16-fix" 75 | +DATASET_NAME= "hahminlew/kream-product-blip-captions" 76 | +PROJECT_NAME = "Product Photography" 77 | + 78 | +class Config: 79 | + def __init__(self): 80 | + self.pretrained_model_name_or_path = MODEL_NAME 81 | + self.pretrained_vae_model_name_or_path = VAE_NAME 82 | + self.revision = None 83 | + self.variant = None 84 | + self.dataset_name = DATASET_NAME 85 | + self.dataset_config_name = None 86 | + self.train_data_dir = None 87 | + self.image_column = 'image' 88 | + self.caption_column = 'text' 89 | + self.validation_prompt = None 90 | + self.num_validation_images = 4 91 | + self.validation_epochs = 1 92 | + self.max_train_samples = None 93 | + self.output_dir = "output" 94 | + self.cache_dir = None 95 | + self.seed = None 96 | + self.resolution = 1024 97 | + self.center_crop = False 98 | + self.random_flip = False 99 | + self.train_text_encoder = False 100 | + self.train_batch_size = 16 101 | + self.num_train_epochs = 200 102 | + self.max_train_steps = None 103 | + self.checkpointing_steps = 500 104 | + self.checkpoints_total_limit = None 105 | + self.resume_from_checkpoint = None 106 | + self.gradient_accumulation_steps = 1 107 | + 
self.gradient_checkpointing = False 108 | + self.learning_rate = 1e-4 109 | + self.scale_lr = False 110 | + self.lr_scheduler = "constant" 111 | + self.lr_warmup_steps = 500 112 | + self.snr_gamma = None 113 | + self.allow_tf32 = False 114 | + self.dataloader_num_workers = 0 115 | + self.use_8bit_adam = True 116 | + self.adam_beta1 = 0.9 117 | + self.adam_beta2 = 0.999 118 | + self.adam_weight_decay = 1e-2 119 | + self.adam_epsilon = 1e-08 120 | + self.max_grad_norm = 1.0 121 | + self.push_to_hub = False 122 | + self.hub_token = None 123 | + self.prediction_type = None 124 | + self.hub_model_id = None 125 | + self.logging_dir = "logs" 126 | + self.report_to = "wandb" 127 | + self.mixed_precision = None 128 | + self.local_rank = -1 129 | + self.enable_xformers_memory_efficient_attention = False 130 | + self.noise_offset = 0 131 | + self.rank = 4 132 | + self.debug_loss = False 133 | diff --git a/scripts/endpoint.py b/scripts/endpoint.py 134 | deleted file mode 100644 135 | index cbb9ebe..0000000 136 | --- a/scripts/endpoint.py 137 | +++ /dev/null 138 | @@ -1,65 +0,0 @@ 139 | -from fastapi import FastAPI,HTTPException 140 | -from fastapi.responses import FileResponse 141 | -from fastapi.middleware.cors import CORSMiddleware 142 | -from models import kandinsky_inpainting_inference 143 | -from segment_everything import extend_image, generate_mask_from_bbox, invert_mask 144 | -from video_pipeline import fetch_video_pipeline 145 | -from diffusers.utils import load_image 146 | -from logger import rich_logger as l 147 | -from fastapi import UploadFile, File 148 | -from config import segmentation_model, detection_model,target_height, target_width, roi_scale 149 | -from PIL import Image 150 | -import io 151 | -import tempfile 152 | - 153 | - 154 | - 155 | - 156 | - 157 | - 158 | -app = FastAPI(title="Product Diffusion API", 159 | - description="API for Product Diffusion", 160 | - version="0.1.0", 161 | - openapi_url="/api/v1/openapi.json") 162 | - 163 | - 164 | -app.add_middleware( 165 | - CORSMiddleware, 166 | - allow_origins=["*"], 167 | - allow_methods=["*"], 168 | - allow_headers=["*"], 169 | - allow_credentials=True 170 | - 171 | -) 172 | - 173 | -@app.post("/api/v1/image_outpainting") 174 | -async def image_outpainting(image: UploadFile, prompt: str, negative_prompt: str,num_inference_steps:int=30): 175 | - """ 176 | - Perform Outpainting on an image. 177 | - 178 | - Args: 179 | - image (UploadFile): The input image file. 180 | - prompt (str): The prompt for the outpainting. 181 | - negative_prompt (str): The negative prompt for the outpainting. 182 | - 183 | - Returns: 184 | - JSONResponse: The output image path. 
185 | - """ 186 | - image_data = await image.read() 187 | - image = Image.open(io.BytesIO(image_data)) 188 | - image = load_image(image) 189 | - image = extend_image(image, target_width=target_width, target_height=target_height, roi_scale=roi_scale) 190 | - mask_image = generate_mask_from_bbox(image, segmentation_model, detection_model) 191 | - mask_image = Image.fromarray(mask_image) 192 | - mask_image = invert_mask(mask_image) 193 | - output_image = kandinsky_inpainting_inference(prompt, negative_prompt, image, mask_image,num_inference_steps=num_inference_steps) 194 | - with tempfile.NamedTemporaryFile(suffix='.jpg', delete=False) as temp_file: 195 | - output_image.save(temp_file, format='JPEG') 196 | - temp_file_path = temp_file.name 197 | - return FileResponse(temp_file_path, media_type='image/jpeg', filename='output_image.jpg') 198 | - 199 | - 200 | - 201 | - 202 | - 203 | - 204 | \ No newline at end of file 205 | diff --git a/scripts/logger.py b/scripts/logger.py 206 | index 2e0f42f..c493b93 100644 207 | --- a/scripts/logger.py 208 | +++ b/scripts/logger.py 209 | @@ -25,5 +25,4 @@ for level in log_levels: 210 | file_handler = RotatingFileHandler(log_file, maxBytes=10 * 1024 * 1024, backupCount=5) 211 | file_handler.setLevel(level) 212 | file_handler.setFormatter(logging.Formatter('%(asctime)s [%(levelname)s] %(module)s - %(message)s')) 213 | - rich_logger.addHandler(file_handler) 214 | - 215 | + rich_logger.addHandler(file_handler) 216 | \ No newline at end of file 217 | diff --git a/scripts/models.py b/scripts/models.py 218 | deleted file mode 100644 219 | index 2ca9eea..0000000 220 | --- a/scripts/models.py 221 | +++ /dev/null 222 | @@ -1,82 +0,0 @@ 223 | -from logger import rich_logger as l 224 | -from wandb.integration.diffusers import autolog 225 | -from config import Project_Name 226 | -from clear_memory import clear_memory 227 | -import numpy as np 228 | -import torch 229 | -from diffusers.utils import load_image 230 | -from pipeline import fetch_kandinsky_pipeline 231 | -from config import controlnet_adapter_model_name,controlnet_base_model_name,kandinsky_model_name 232 | -from diffusers import StableDiffusionInpaintPipeline, DPMSolverMultistepScheduler 233 | -from video_pipeline import fetch_video_pipeline 234 | -from config import video_model_name 235 | - 236 | - 237 | - 238 | - 239 | - 240 | - 241 | - 242 | - 243 | - 244 | - 245 | - 246 | - 247 | - 248 | - 249 | - 250 | - 251 | -def kandinsky_inpainting_inference(prompt, negative_prompt, image, mask_image,num_inference_steps=800,strength=1.0,guidance_scale = 7.8): 252 | - """ 253 | - Perform Kandinsky inpainting inference on the given image. 254 | - 255 | - Args: 256 | - prompt (str): The prompt for the inpainting process. 257 | - negative_prompt (str): The negative prompt for the inpainting process. 258 | - image (PIL.Image.Image): The input image to be inpainted. 259 | - mask_image (PIL.Image.Image): The mask image indicating the areas to be inpainted. 260 | - 261 | - Returns: 262 | - PIL.Image.Image: The output inpainted image. 
263 | - """ 264 | - clear_memory() 265 | - l.info("Kandinsky Inpainting Inference ->") 266 | - pipe = fetch_kandinsky_pipeline(controlnet_adapter_model_name, controlnet_base_model_name,kandinsky_model_name, image) 267 | - output_image = pipe(prompt=prompt,negative_prompt=negative_prompt,image=image,mask_image=mask_image,num_inference_steps=num_inference_steps,strength=strength,guidance_scale = guidance_scale,height = 1472, width = 2560).images[0] 268 | - return output_image 269 | - 270 | - 271 | - 272 | - 273 | - 274 | - 275 | - 276 | - 277 | - 278 | -def image_to_video_pipeline(image, video_model_name, decode_chunk_size, motion_bucket_id, generator=torch.manual_seed(42)): 279 | - """ 280 | - Converts an image to a video using a specified video model. 281 | - 282 | - Args: 283 | - image (Image): The input image to convert to video. 284 | - video_model_name (str): The name of the video model to use. 285 | - decode_chunk_size (int): The size of the chunks to decode. 286 | - motion_bucket_id (str): The ID of the motion bucket. 287 | - generator (torch.Generator, optional): The random number generator. Defaults to torch.manual_seed(42). 288 | - 289 | - Returns: 290 | - list: The frames of the generated video. 291 | - """ 292 | - clear_memory() 293 | - l.info("Stable Video Diffusion Image 2 Video pipeline Inference ->") 294 | - pipe = fetch_video_pipeline(video_model_name) 295 | - frames = pipe(image=image, decode_chunk_size=decode_chunk_size, motion_bucket_id=motion_bucket_id, generator=generator).frames[0] 296 | - return frames 297 | - 298 | - 299 | - 300 | - 301 | - 302 | - 303 | - 304 | - 305 | diff --git a/scripts/pipeline.py b/scripts/pipeline.py 306 | deleted file mode 100644 307 | index af0e6bf..0000000 308 | --- a/scripts/pipeline.py 309 | +++ /dev/null 310 | @@ -1,100 +0,0 @@ 311 | -from diffusers import ControlNetModel,StableDiffusionControlNetInpaintPipeline,AutoPipelineForInpainting 312 | -import torch 313 | - 314 | - 315 | - 316 | - 317 | - 318 | - 319 | - 320 | -class PipelineFetcher: 321 | - """ 322 | - A class that fetches different pipelines for image processing. 323 | - 324 | - Args: 325 | - controlnet_adapter_model_name (str): The name of the controlnet adapter model. 326 | - controlnet_base_model_name (str): The name of the controlnet base model. 327 | - kandinsky_model_name (str): The name of the Kandinsky model. 328 | - image (str): The image to be processed. 329 | - 330 | - """ 331 | - 332 | - def __init__(self, controlnet_adapter_model_name, controlnet_base_model_name, kandinsky_model_name, image: str): 333 | - self.controlnet_adapter_model_name = controlnet_adapter_model_name 334 | - self.controlnet_base_model_name = controlnet_base_model_name 335 | - self.kandinsky_model_name = kandinsky_model_name 336 | - self.image = image 337 | - 338 | - def ControlNetInpaintPipeline(self): 339 | - """ 340 | - Fetches the ControlNet inpainting pipeline. 341 | - 342 | - Returns: 343 | - pipe (StableDiffusionControlNetInpaintPipeline): The ControlNet inpainting pipeline. 344 | - 345 | - """ 346 | - controlnet = ControlNetModel.from_pretrained(self.controlnet_adapter_model_name, torch_dtype=torch.float16) 347 | - pipe = StableDiffusionControlNetInpaintPipeline.from_pretrained( 348 | - self.controlnet_base_model_name, controlnet=controlnet, torch_dtype=torch.float16 349 | - ) 350 | - pipe.to('cuda') 351 | - 352 | - return pipe 353 | - 354 | - def KandinskyPipeline(self): 355 | - """ 356 | - Fetches the Kandinsky pipeline. 
357 | - 358 | - Returns: 359 | - pipe (AutoPipelineForInpainting): The Kandinsky pipeline. 360 | - 361 | - """ 362 | - pipe = AutoPipelineForInpainting.from_pretrained(self.kandinsky_model_name, torch_dtype=torch.float16) 363 | - pipe = pipe.to('cuda') 364 | - pipe.unet = torch.compile(pipe.unet) 365 | - 366 | - return pipe 367 | - 368 | - 369 | - 370 | - 371 | - 372 | - 373 | -def fetch_control_pipeline(controlnet_adapter_model_name, controlnet_base_model_name, kandinsky_model_name, image): 374 | - """ 375 | - Fetches the control pipeline for image processing. 376 | - 377 | - Args: 378 | - controlnet_adapter_model_name (str): The name of the controlnet adapter model. 379 | - controlnet_base_model_name (str): The name of the controlnet base model. 380 | - kandinsky_model_name (str): The name of the Kandinsky model. 381 | - image: The input image for processing. 382 | - 383 | - Returns: 384 | - pipe: The control pipeline for image processing. 385 | - """ 386 | - pipe_fetcher = PipelineFetcher(controlnet_adapter_model_name, controlnet_base_model_name, kandinsky_model_name, image) 387 | - pipe = pipe_fetcher.ControlNetInpaintPipeline() 388 | - return pipe 389 | - 390 | - 391 | -def fetch_kandinsky_pipeline(controlnet_adapter_model_name, controlnet_base_model_name, kandinsky_model_name, image): 392 | - """ 393 | - Fetches the Kandinsky pipeline. 394 | - 395 | - Args: 396 | - controlnet_adapter_model_name (str): The name of the controlnet adapter model. 397 | - controlnet_base_model_name (str): The name of the controlnet base model. 398 | - kandinsky_model_name (str): The name of the Kandinsky model. 399 | - image: The input image. 400 | - 401 | - Returns: 402 | - pipe: The Kandinsky pipeline. 403 | - """ 404 | - pipe_fetcher = PipelineFetcher(controlnet_adapter_model_name, controlnet_base_model_name, kandinsky_model_name, image) 405 | - pipe = pipe_fetcher.KandinskyPipeline() 406 | - pipe = pipe.to('cuda') 407 | - 408 | - return pipe 409 | - 410 | - 411 | diff --git a/scripts/run.py b/scripts/run.py 412 | deleted file mode 100644 413 | index cccc06a..0000000 414 | --- a/scripts/run.py 415 | +++ /dev/null 416 | @@ -1,39 +0,0 @@ 417 | -import argparse 418 | -import os 419 | -from segment_everything import generate_mask_from_bbox, extend_image, invert_mask 420 | -from models import kandinsky_inpainting_inference, load_image 421 | -from PIL import Image 422 | -from config import segmentation_model, detection_model,target_height, target_width, roi_scale 423 | - 424 | -def main(args): 425 | - """ 426 | - Main function that performs the product diffusion process. 427 | - 428 | - Args: 429 | - args (Namespace): Command-line arguments. 
430 | - 431 | - Returns: 432 | - None 433 | - """ 434 | - os.makedirs(args.output_dir, exist_ok=True) 435 | - os.makedirs(args.mask_dir, exist_ok=True) 436 | - output_image_path = os.path.join(args.output_dir, f'{args.uid}_output.jpg') 437 | - image = load_image(args.image_path) 438 | - extended_image = extend_image(image, target_width=target_width, target_height=target_height, roi_scale=roi_scale) 439 | - mask = generate_mask_from_bbox(extended_image, segmentation_model, detection_model) 440 | - mask_image = Image.fromarray(mask) 441 | - inverted_mask = invert_mask(mask_image) 442 | - #inverted_mask = Image.fromarray(inverted_mask) 443 | - output_image = kandinsky_inpainting_inference(args.prompt, args.negative_prompt, extended_image, inverted_mask) 444 | - output_image.save(output_image_path) 445 | - 446 | -if __name__ == "__main__": 447 | - parser = argparse.ArgumentParser(description='Perform Outpainting on an image.') 448 | - parser.add_argument('--image_path', type=str, required=True, help='Path to the input image.') 449 | - parser.add_argument('--prompt', type=str, required=True, help='Prompt for the Kandinsky inpainting.') 450 | - parser.add_argument('--negative_prompt', type=str, required=True, help='Negative prompt for the Kandinsky inpainting.') 451 | - parser.add_argument('--output_dir', type=str, required=True, help='Directory to save the output image.') 452 | - parser.add_argument('--mask_dir', type=str, required=True, help='Directory to save the mask image.') 453 | - parser.add_argument('--uid', type=str, required=True, help='Unique identifier for the image and mask.') 454 | - args = parser.parse_args() 455 | - main(args) 456 | \ No newline at end of file 457 | diff --git a/scripts/segment_everything.py b/scripts/segment_everything.py 458 | deleted file mode 100644 459 | index c2e9532..0000000 460 | --- a/scripts/segment_everything.py 461 | +++ /dev/null 462 | @@ -1,125 +0,0 @@ 463 | -from ultralytics import YOLO 464 | -from transformers import SamModel, SamProcessor 465 | -import torch 466 | -from diffusers.utils import load_image 467 | -from PIL import Image, ImageOps 468 | -import numpy as np 469 | -import torch 470 | -from diffusers import StableVideoDiffusionPipeline 471 | - 472 | - 473 | - 474 | - 475 | - 476 | - 477 | - 478 | - 479 | - 480 | -device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 481 | - 482 | - 483 | - 484 | - 485 | - 486 | - 487 | - 488 | - 489 | - 490 | -def extend_image(image, target_width, target_height, roi_scale=0.5): 491 | - """ 492 | - Extends an image to fit within the specified target dimensions while maintaining the aspect ratio. 493 | - 494 | - Args: 495 | - image (PIL.Image.Image): The image to be extended. 496 | - target_width (int): The desired width of the extended image. 497 | - target_height (int): The desired height of the extended image. 498 | - roi_scale (float, optional): The scale factor applied to the resized image. Defaults to 0.5. 499 | - 500 | - Returns: 501 | - PIL.Image.Image: The extended image. 
502 | - """ 503 | - original_image = image 504 | - original_width, original_height = original_image.size 505 | - scale = min(target_width / original_width, target_height / original_height) 506 | - new_width = int(original_width * scale * roi_scale) 507 | - new_height = int(original_height * scale * roi_scale) 508 | - original_image_resized = original_image.resize((new_width, new_height)) 509 | - extended_image = Image.new("RGB", (target_width, target_height), "white") 510 | - paste_x = (target_width - new_width) // 2 511 | - paste_y = (target_height - new_height) // 2 512 | - extended_image.paste(original_image_resized, (paste_x, paste_y)) 513 | - return extended_image 514 | - 515 | - 516 | - 517 | - 518 | - 519 | -def generate_mask_from_bbox(image: Image, segmentation_model: str ,detection_model) -> Image: 520 | - """ 521 | - Generates a mask from the bounding box of an image using YOLO and SAM-ViT models. 522 | - 523 | - Args: 524 | - image_path (str): The path to the input image. 525 | - 526 | - Returns: 527 | - numpy.ndarray: The generated mask as a NumPy array. 528 | - """ 529 | - 530 | - yolo = YOLO(detection_model) 531 | - processor = SamProcessor.from_pretrained(segmentation_model) 532 | - model = SamModel.from_pretrained(segmentation_model).to(device) 533 | - results = yolo(image) 534 | - bboxes = results[0].boxes.xyxy.tolist() 535 | - input_boxes = [[[bboxes[0]]]] 536 | - inputs = processor(load_image(image), input_boxes=input_boxes, return_tensors="pt").to("cuda") 537 | - with torch.no_grad(): 538 | - outputs = model(**inputs) 539 | - mask = processor.image_processor.post_process_masks( 540 | - outputs.pred_masks.cpu(), 541 | - inputs["original_sizes"].cpu(), 542 | - inputs["reshaped_input_sizes"].cpu() 543 | - )[0][0][0].numpy() 544 | - return mask 545 | - 546 | - 547 | - 548 | - 549 | - 550 | - 551 | -def invert_mask(mask_image: Image) -> np.ndarray: 552 | - """Method to invert mask 553 | - Args: 554 | - mask_image (np.ndarray): input mask image 555 | - Returns: 556 | - np.ndarray: inverted mask image 557 | - """ 558 | - inverted_mask_image = ImageOps.invert(mask_image) 559 | - return inverted_mask_image 560 | - 561 | - 562 | - 563 | - 564 | - 565 | - 566 | - 567 | - 568 | -def fetch_video_pipeline(video_model_name): 569 | - """ 570 | - Fetches the video pipeline for image processing. 571 | - 572 | - Args: 573 | - video_model_name (str): The name of the video model. 574 | - 575 | - Returns: 576 | - pipe (StableVideoDiffusionPipeline): The video pipeline. 577 | - 578 | - """ 579 | - pipe = StableVideoDiffusionPipeline.from_pretrained( 580 | - video_model_name, torch_dtype=torch.float16, 581 | - ) 582 | - pipe = pipe.to('cuda') 583 | - pipe.unet= torch.compile(pipe.unet) 584 | - 585 | - 586 | - return pipe 587 | - 588 | diff --git a/scripts/video_pipeline.py b/scripts/video_pipeline.py 589 | deleted file mode 100644 590 | index e69de29..0000000 591 | --------------------------------------------------------------------------------