├── __init__.py ├── shared ├── __init__.py ├── logging │ ├── constants.py │ └── logging.py ├── utils.py └── file_upload │ └── s3.py ├── utils ├── __init__.py ├── refresh_target.py ├── state_refresh.py ├── enum.py ├── cookie_manager │ └── cookie.py ├── local_storage │ ├── url_storage.py │ └── local_storage.py ├── ml_processor │ ├── ml_interface.py │ ├── comfy_workflows │ │ ├── llama2_workflow_api.json │ │ ├── flux_schnell_workflow_api.json │ │ ├── llama_workflow_api.json │ │ ├── sdxl_img2img_workflow_api.json │ │ ├── creative_image_gen.json │ │ ├── ipadapter_composition_workflow_api.json │ │ ├── sd3_workflow_api.json │ │ ├── ipadapter_plus_api.json │ │ ├── dynamicrafter_api.json │ │ ├── sdxl_controlnet_workflow_api.json │ │ ├── sdxl_workflow_api.json │ │ ├── sdxl_openpose_workflow_api.json │ │ └── ipadapter_face_api.json │ ├── motion_module.py │ ├── sai │ │ ├── sai.py │ │ └── utils.py │ └── gpu │ │ ├── gpu.py │ │ └── utils.py ├── encryption.py ├── third_party_auth │ └── google │ │ └── google_auth.py ├── media_processor │ └── video.py ├── cache │ └── cache.py └── common_decorators.py ├── backend ├── __init__.py ├── migrations │ ├── __init__.py │ ├── 0017_credits_used_added.py │ ├── 0002_temp_file_list_added.py │ ├── 0005_model_type_added.py │ ├── 0010_project_metadata_added.py │ ├── 0004_credits_added_in_user.py │ ├── 0003_custom_trained_model_check_added.py │ ├── 0009_log_status_added.py │ ├── 0008_interpolated_clip_list_added.py │ ├── 0006_inference_time_converted_to_float.py │ ├── 0015_log_tag_added.py │ ├── 0016_db_col_size_update.py │ ├── 0007_log_mapped_to_file.py │ ├── 0013_filter_keys_added.py │ ├── 0011_lock_added.py │ └── 0014_file_link_added.py ├── constants.py ├── apps.py └── serializers │ └── dto.py ├── scripts ├── app_version.txt ├── app_settings.toml ├── windows_setup_online.bat ├── windows_setup.bat ├── linux_setup_online.sh ├── linux_setup.sh ├── entrypoint.sh ├── entrypoint.bat └── config.toml ├── .streamlit ├── credentials.toml └── config.toml ├── 
.env.sample ├── sample_assets ├── sample_images │ └── init_frames │ │ ├── 1.jpg │ │ ├── 2.jpg │ │ └── 3.jpg └── example_generations │ ├── guy-result.mp4 │ ├── lady-result.mp4 │ └── world-result.mp4 ├── .dockerignore ├── pyproject.toml ├── .vscode └── settings.json ├── Dockerfile ├── ui_components ├── widgets │ ├── base_theme.py │ ├── common_element.py │ ├── attach_audio_element.py │ ├── display_element.py │ ├── video_cropping_element.py │ ├── frame_switch_btn.py │ ├── download_file_progress_bar.py │ ├── image_zoom_widgets.py │ ├── image_carousal.py │ ├── add_key_frame_element.py │ └── frame_selector.py ├── components │ ├── shortlist_page.py │ ├── adjust_shot_page.py │ ├── inspiraton_engine_page.py │ ├── user_login_page.py │ ├── timeline_view_page.py │ ├── project_settings_page.py │ ├── query_logger_page.py │ └── new_project_page.py ├── methods │ ├── data_logger.py │ └── training_methods.py └── constants.py ├── manage.py ├── .gitignore ├── requirements.txt ├── django_settings.py ├── .aws └── task-definition.json ├── .github └── workflows │ └── deploy.yml ├── LICENSE.txt ├── auto_refresh.py └── app.py /__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /shared/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /backend/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /backend/migrations/__init__.py: -------------------------------------------------------------------------------- 1 | 
-------------------------------------------------------------------------------- /scripts/app_version.txt: -------------------------------------------------------------------------------- 1 | 0.9.24 -------------------------------------------------------------------------------- /utils/refresh_target.py: -------------------------------------------------------------------------------- 1 | SAVE_STATE = 256 2 | -------------------------------------------------------------------------------- /.streamlit/credentials.toml: -------------------------------------------------------------------------------- 1 | [general] 2 | email="xyz@gmil.com" -------------------------------------------------------------------------------- /scripts/app_settings.toml: -------------------------------------------------------------------------------- 1 | automatic_update = false 2 | gpu_inference = true 3 | runner_process_port = 12345 -------------------------------------------------------------------------------- /utils/state_refresh.py: -------------------------------------------------------------------------------- 1 | import streamlit as st 2 | 3 | 4 | def refresh_app(*args): 5 | st.rerun() 6 | -------------------------------------------------------------------------------- /.env.sample: -------------------------------------------------------------------------------- 1 | SERVER=development 2 | SERVER_URL=https://api.banodoco.ai 3 | OFFLINE_MODE=True 4 | HOSTED_BACKGROUND_RUNNER_MODE=False -------------------------------------------------------------------------------- /sample_assets/sample_images/init_frames/1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/banodoco/Dough/HEAD/sample_assets/sample_images/init_frames/1.jpg -------------------------------------------------------------------------------- /sample_assets/sample_images/init_frames/2.jpg: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/banodoco/Dough/HEAD/sample_assets/sample_images/init_frames/2.jpg -------------------------------------------------------------------------------- /sample_assets/sample_images/init_frames/3.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/banodoco/Dough/HEAD/sample_assets/sample_images/init_frames/3.jpg -------------------------------------------------------------------------------- /.dockerignore: -------------------------------------------------------------------------------- 1 | # system 2 | __pycache__/ 3 | /.vscode 4 | 5 | # user 6 | /venv 7 | /videos 8 | 9 | banodoco_local.db 10 | doc.py 11 | .env -------------------------------------------------------------------------------- /.streamlit/config.toml: -------------------------------------------------------------------------------- 1 | [theme] 2 | base="dark" 3 | primaryColor="#9b9bb1" 4 | backgroundColor="#2a3446" 5 | secondaryBackgroundColor="#26282f" 6 | -------------------------------------------------------------------------------- /sample_assets/example_generations/guy-result.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/banodoco/Dough/HEAD/sample_assets/example_generations/guy-result.mp4 -------------------------------------------------------------------------------- /sample_assets/example_generations/lady-result.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/banodoco/Dough/HEAD/sample_assets/example_generations/lady-result.mp4 -------------------------------------------------------------------------------- /backend/constants.py: -------------------------------------------------------------------------------- 1 | from utils.enum import ExtendedEnum 2 | 3 | 4 | class 
UserType(ExtendedEnum): 5 | USER = "user" 6 | ADMIN = "admin" 7 | -------------------------------------------------------------------------------- /sample_assets/example_generations/world-result.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/banodoco/Dough/HEAD/sample_assets/example_generations/world-result.mp4 -------------------------------------------------------------------------------- /backend/apps.py: -------------------------------------------------------------------------------- 1 | from django.apps import AppConfig 2 | 3 | 4 | class BackendConfig(AppConfig): 5 | name = "backend" 6 | verbose_name = "Local backend" 7 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "Dough" 3 | description = "Open source tool for steering AI animations with precision" 4 | readme = "README.md" 5 | 6 | [tool.black] 7 | line-length = 110 8 | target-version = ['py310'] 9 | -------------------------------------------------------------------------------- /.vscode/settings.json: -------------------------------------------------------------------------------- 1 | { 2 | "workbench.colorTheme": "Monokai", 3 | "[python]": { 4 | "editor.defaultFormatter": "ms-python.black-formatter", 5 | "editor.formatOnSave": true, 6 | "editor.formatOnSaveMode": "file", 7 | }, 8 | "terminal.integrated.scrollback": 1000000, 9 | } -------------------------------------------------------------------------------- /utils/enum.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | 3 | 4 | class ExtendedEnum(Enum): 5 | 6 | @classmethod 7 | def value_list(cls): 8 | return list(map(lambda c: c.value, cls)) 9 | 10 | @classmethod 11 | def has_value(cls, value): 12 | return value in cls._value2member_map_ 13 | 
-------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | FROM python:3.10.2 2 | 3 | RUN mkdir banodoco 4 | 5 | WORKDIR /banodoco 6 | 7 | COPY requirements.txt . 8 | 9 | RUN pip3 install --upgrade pip 10 | RUN pip3 install -r requirements.txt 11 | RUN apt-get update && apt-get install -y ffmpeg 12 | RUN echo "SERVER=production" > .env 13 | 14 | COPY . . 15 | 16 | EXPOSE 5500 17 | 18 | CMD ["sh", "entrypoint.sh"] -------------------------------------------------------------------------------- /ui_components/widgets/base_theme.py: -------------------------------------------------------------------------------- 1 | import time 2 | import streamlit as st 3 | from utils.state_refresh import refresh_app 4 | 5 | 6 | class BaseTheme: 7 | @staticmethod 8 | def success_msg(msg): 9 | st.success(msg) 10 | time.sleep(0.5) 11 | refresh_app() 12 | 13 | @staticmethod 14 | def error_msg(msg): 15 | st.error(msg) 16 | time.sleep(0.5) 17 | refresh_app() 18 | -------------------------------------------------------------------------------- /backend/migrations/0017_credits_used_added.py: -------------------------------------------------------------------------------- 1 | # Generated by Django 4.2.1 on 2024-09-17 11:34 2 | 3 | from django.db import migrations, models 4 | 5 | 6 | class Migration(migrations.Migration): 7 | 8 | dependencies = [ 9 | ("backend", "0016_db_col_size_update"), 10 | ] 11 | 12 | operations = [ 13 | migrations.AddField( 14 | model_name="inferencelog", 15 | name="credits_used", 16 | field=models.FloatField(default=0), 17 | ), 18 | ] 19 | -------------------------------------------------------------------------------- /backend/migrations/0002_temp_file_list_added.py: -------------------------------------------------------------------------------- 1 | # Generated by Django 4.2.1 on 2023-07-26 07:18 2 | 3 | from django.db import migrations, models 4 | 
5 | 6 | class Migration(migrations.Migration): 7 | 8 | dependencies = [ 9 | ("backend", "0001_initial_setup"), 10 | ] 11 | 12 | operations = [ 13 | migrations.AddField( 14 | model_name="project", 15 | name="temp_file_list", 16 | field=models.TextField(default=None, null=True), 17 | ), 18 | ] 19 | -------------------------------------------------------------------------------- /backend/migrations/0005_model_type_added.py: -------------------------------------------------------------------------------- 1 | # Generated by Django 4.2.1 on 2023-08-15 07:46 2 | 3 | from django.db import migrations, models 4 | 5 | 6 | class Migration(migrations.Migration): 7 | 8 | dependencies = [ 9 | ("backend", "0004_credits_added_in_user"), 10 | ] 11 | 12 | operations = [ 13 | migrations.AddField( 14 | model_name="aimodel", 15 | name="model_type", 16 | field=models.TextField(blank=True, default=""), 17 | ), 18 | ] 19 | -------------------------------------------------------------------------------- /backend/migrations/0010_project_metadata_added.py: -------------------------------------------------------------------------------- 1 | # Generated by Django 4.2.1 on 2023-10-14 01:55 2 | 3 | from django.db import migrations, models 4 | 5 | 6 | class Migration(migrations.Migration): 7 | 8 | dependencies = [ 9 | ("backend", "0009_log_status_added"), 10 | ] 11 | 12 | operations = [ 13 | migrations.AddField( 14 | model_name="project", 15 | name="meta_data", 16 | field=models.TextField(default=None, null=True), 17 | ), 18 | ] 19 | -------------------------------------------------------------------------------- /backend/migrations/0004_credits_added_in_user.py: -------------------------------------------------------------------------------- 1 | # Generated by Django 4.2.1 on 2023-08-02 06:43 2 | 3 | from django.db import migrations, models 4 | 5 | 6 | class Migration(migrations.Migration): 7 | 8 | dependencies = [ 9 | ("backend", "0003_custom_trained_model_check_added"), 10 | ] 11 | 12 | 
operations = [ 13 | migrations.AddField( 14 | model_name="user", 15 | name="total_credits", 16 | field=models.FloatField(default=0), 17 | ), 18 | ] 19 | -------------------------------------------------------------------------------- /backend/migrations/0003_custom_trained_model_check_added.py: -------------------------------------------------------------------------------- 1 | # Generated by Django 4.2.1 on 2023-07-31 05:47 2 | 3 | from django.db import migrations, models 4 | 5 | 6 | class Migration(migrations.Migration): 7 | 8 | dependencies = [ 9 | ("backend", "0002_temp_file_list_added"), 10 | ] 11 | 12 | operations = [ 13 | migrations.AddField( 14 | model_name="aimodel", 15 | name="custom_trained", 16 | field=models.BooleanField(default=False), 17 | ), 18 | ] 19 | -------------------------------------------------------------------------------- /backend/migrations/0009_log_status_added.py: -------------------------------------------------------------------------------- 1 | # Generated by Django 4.2.1 on 2023-10-09 14:17 2 | 3 | from django.db import migrations, models 4 | 5 | 6 | class Migration(migrations.Migration): 7 | 8 | dependencies = [ 9 | ("backend", "0008_interpolated_clip_list_added"), 10 | ] 11 | 12 | operations = [ 13 | migrations.AddField( 14 | model_name="inferencelog", 15 | name="status", 16 | field=models.CharField(default="", max_length=255), 17 | ), 18 | ] 19 | -------------------------------------------------------------------------------- /shared/logging/constants.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, field 2 | from utils.enum import ExtendedEnum 3 | 4 | 5 | @dataclass 6 | class LoggingPayload: 7 | message: str 8 | data: dict = field(default_factory=dict) 9 | 10 | 11 | class LoggingType(ExtendedEnum): 12 | INFO = "info" 13 | INFERENCE_CALL = "inference_call" 14 | INFERENCE_RESULT = "inference_result" 15 | ERROR = "error" 16 | DEBUG = "debug" 17 | 18 | 19 | 
class LoggingMode(ExtendedEnum): 20 | OFFLINE = "offline" 21 | ONLINE = "online" 22 | -------------------------------------------------------------------------------- /manage.py: -------------------------------------------------------------------------------- 1 | import os, sys 2 | 3 | if __name__ == "__main__": 4 | os.environ.setdefault("DJANGO_SETTINGS_MODULE", "django_settings") 5 | 6 | try: 7 | from django.core.management import execute_from_command_line 8 | except ImportError as exc: 9 | raise ImportError( 10 | "Couldn't import Django. Are you sure it's installed and " 11 | "available on your PYTHONPATH environment variable? Did you " 12 | "forget to activate a virtual environment?" 13 | ) from exc 14 | 15 | execute_from_command_line(sys.argv) 16 | -------------------------------------------------------------------------------- /scripts/windows_setup_online.bat: -------------------------------------------------------------------------------- 1 | @echo off 2 | set "folderName=Dough" 3 | if not exist "%folderName%\" ( 4 | if /i not "%CD%"=="%~dp0%folderName%\" ( 5 | git clone --depth 1 -b main https://github.com/banodoco/Dough.git 6 | cd Dough 7 | python -m venv dough-env 8 | call dough-env\Scripts\activate.bat 9 | python.exe -m pip install --upgrade pip 10 | pip install -r requirements.txt 11 | pip install websocket 12 | call dough-env\Scripts\deactivate.bat 13 | copy .env.sample .env 14 | cd .. 
15 | pause 16 | ) 17 | ) -------------------------------------------------------------------------------- /backend/migrations/0008_interpolated_clip_list_added.py: -------------------------------------------------------------------------------- 1 | # Generated by Django 4.2.1 on 2023-09-16 03:09 2 | 3 | from django.db import migrations, models 4 | 5 | 6 | class Migration(migrations.Migration): 7 | 8 | dependencies = [ 9 | ("backend", "0007_log_mapped_to_file"), 10 | ] 11 | 12 | operations = [ 13 | migrations.RemoveField( 14 | model_name="timing", 15 | name="interpolated_clip", 16 | ), 17 | migrations.AddField( 18 | model_name="timing", 19 | name="interpolated_clip_list", 20 | field=models.TextField(default=None, null=True), 21 | ), 22 | ] 23 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | comfyui*.log 6 | 7 | venv 8 | dough-env 9 | .vscode/launch.json 10 | 11 | # system 12 | .DS_Store 13 | 14 | # user files 15 | app_settings.csv 16 | videos 17 | videos/* 18 | temp_dir/ 19 | /temp/ 20 | training_data 21 | inference_log/* 22 | *.db 23 | test.py 24 | doc.py 25 | .env 26 | data.json 27 | comfy_runner/ 28 | ComfyUI/ 29 | output/ 30 | images.zip 31 | upload_data.json 32 | 33 | # generated file TODO: move inside videos 34 | depth.png 35 | masked_image.png 36 | concatenated.mp4 37 | temp.png 38 | scripts/app_checkpoint.json 39 | current.md 40 | refresh_lock.json 41 | -------------------------------------------------------------------------------- /ui_components/widgets/common_element.py: -------------------------------------------------------------------------------- 1 | from utils.data_repo.data_repo import DataRepo 2 | import streamlit as st 3 | import time 4 | from utils.state_refresh import refresh_app 5 | 6 | 7 | def duplicate_shot_button(shot_uuid, 
position="shot_view"): 8 | data_repo = DataRepo() 9 | shot = data_repo.get_shot_from_uuid(shot_uuid) 10 | if st.button( 11 | "Duplicate shot", 12 | key=f"duplicate_btn_{shot.uuid}_{position}", 13 | help="This will duplicate this shot.", 14 | use_container_width=True, 15 | ): 16 | data_repo.duplicate_shot(shot.uuid) 17 | st.success("Shot duplicated successfully") 18 | time.sleep(0.3) 19 | refresh_app() 20 | -------------------------------------------------------------------------------- /backend/migrations/0006_inference_time_converted_to_float.py: -------------------------------------------------------------------------------- 1 | # Generated by Django 4.2.1 on 2023-09-13 09:23 2 | 3 | from django.db import migrations, models 4 | 5 | 6 | class Migration(migrations.Migration): 7 | 8 | dependencies = [ 9 | ("backend", "0005_model_type_added"), 10 | ] 11 | 12 | operations = [ 13 | migrations.AlterField( 14 | model_name="inferencelog", 15 | name="total_inference_time", 16 | field=models.FloatField(default=0), 17 | ), 18 | migrations.AlterField( 19 | model_name="timing", 20 | name="strength", 21 | field=models.FloatField(default=1), 22 | ), 23 | ] 24 | -------------------------------------------------------------------------------- /backend/migrations/0015_log_tag_added.py: -------------------------------------------------------------------------------- 1 | # Generated by Django 4.2.1 on 2024-06-16 16:08 2 | 3 | from django.db import migrations, models 4 | 5 | 6 | class Migration(migrations.Migration): 7 | 8 | dependencies = [ 9 | ("backend", "0014_file_link_added"), 10 | ] 11 | 12 | operations = [ 13 | migrations.AddField( 14 | model_name="inferencelog", 15 | name="generation_source", 16 | field=models.CharField(blank=True, default="", max_length=255), 17 | ), 18 | migrations.AddField( 19 | model_name="inferencelog", 20 | name="generation_tag", 21 | field=models.CharField(blank=True, default="", max_length=255), 22 | ), 23 | ] 24 | 
-------------------------------------------------------------------------------- /backend/migrations/0016_db_col_size_update.py: -------------------------------------------------------------------------------- 1 | # Generated by Django 4.2.1 on 2024-09-14 09:31 2 | 3 | from django.db import migrations, models 4 | 5 | 6 | class Migration(migrations.Migration): 7 | 8 | dependencies = [ 9 | ("backend", "0015_log_tag_added"), 10 | ] 11 | 12 | operations = [ 13 | migrations.AlterField( 14 | model_name="appsetting", 15 | name="aws_access_key", 16 | field=models.CharField(blank=True, default="", max_length=1024), 17 | ), 18 | migrations.AlterField( 19 | model_name="appsetting", 20 | name="aws_secret_access_key", 21 | field=models.CharField(blank=True, default="", max_length=1024), 22 | ), 23 | ] 24 | -------------------------------------------------------------------------------- /backend/migrations/0007_log_mapped_to_file.py: -------------------------------------------------------------------------------- 1 | # Generated by Django 4.2.1 on 2023-09-15 04:30 2 | 3 | from django.db import migrations, models 4 | import django.db.models.deletion 5 | 6 | 7 | class Migration(migrations.Migration): 8 | 9 | dependencies = [ 10 | ("backend", "0006_inference_time_converted_to_float"), 11 | ] 12 | 13 | operations = [ 14 | migrations.AddField( 15 | model_name="internalfileobject", 16 | name="inference_log", 17 | field=models.ForeignKey( 18 | default=None, 19 | null=True, 20 | on_delete=django.db.models.deletion.SET_NULL, 21 | to="backend.inferencelog", 22 | ), 23 | ), 24 | ] 25 | -------------------------------------------------------------------------------- /backend/migrations/0013_filter_keys_added.py: -------------------------------------------------------------------------------- 1 | # Generated by Django 4.2.1 on 2024-02-18 07:55 2 | 3 | from django.db import migrations, models 4 | 5 | 6 | class Migration(migrations.Migration): 7 | 8 | dependencies = [ 9 | ("backend", 
"0012_shot_added_and_redundant_fields_removed"), 10 | ] 11 | 12 | operations = [ 13 | migrations.AddField( 14 | model_name="inferencelog", 15 | name="model_name", 16 | field=models.CharField(blank=True, default="", max_length=512), 17 | ), 18 | migrations.AddField( 19 | model_name="internalfileobject", 20 | name="shot_uuid", 21 | field=models.CharField(blank=True, default="", max_length=255), 22 | ), 23 | ] 24 | -------------------------------------------------------------------------------- /utils/cookie_manager/cookie.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import streamlit as st 3 | import extra_streamlit_components as stx 4 | 5 | from utils.constants import AUTH_TOKEN 6 | 7 | # NOTE: code not working properly. check again after patch from the streamlit team 8 | # @st.cache(allow_output_mutation=True) 9 | # def get_manager(): 10 | # return stx.CookieManager() 11 | 12 | 13 | # def get_cookie(key): 14 | # cookie_manager = get_manager() 15 | # return cookie_manager.get(cookie=key) 16 | 17 | # def set_cookie(key, value): 18 | # cookie_manager = get_manager() 19 | # expiration_time = datetime.datetime.now() + datetime.timedelta(days=1) 20 | # cookie_manager.set(key, value, expires_at=expiration_time) 21 | 22 | # def delete_cookie(key): 23 | # cookie_manager = get_manager() 24 | # cookie = get_cookie(key) 25 | # if cookie: 26 | # cookie_manager.delete(key) 27 | -------------------------------------------------------------------------------- /utils/local_storage/url_storage.py: -------------------------------------------------------------------------------- 1 | # no persistent storage present for streamlit at the moment, so storing things in the url. 
will update this asap 2 | import streamlit as st 3 | 4 | 5 | def get_url_param(key): 6 | params = st.experimental_get_query_params() 7 | val = params.get(key) 8 | if isinstance(val, list): 9 | res = val[0] 10 | else: 11 | res = val 12 | 13 | if not res and (key in st.session_state and st.session_state[key]): 14 | set_url_param(key, st.session_state[key]) 15 | return st.session_state[key] 16 | return res 17 | 18 | 19 | def set_url_param(key, value): 20 | st.session_state[key] = value 21 | st.experimental_set_query_params(**{key: [value]}) 22 | 23 | 24 | def delete_url_param(key): 25 | print("deleting key: ", key) 26 | if key in st.session_state: 27 | del st.session_state[key] 28 | 29 | st.experimental_set_query_params(**{key: None}) 30 | -------------------------------------------------------------------------------- /scripts/windows_setup.bat: -------------------------------------------------------------------------------- 1 | @echo off 2 | set "folderName=Dough" 3 | if not exist "%folderName%\" ( 4 | if /i not "%CD%"=="%~dp0%folderName%\" ( 5 | git clone --depth 1 -b main https://github.com/banodoco/Dough.git 6 | cd Dough 7 | git clone --depth 1 -b main https://github.com/piyushK52/comfy_runner.git 8 | git clone https://github.com/comfyanonymous/ComfyUI.git 9 | python -m venv dough-env 10 | call dough-env\Scripts\activate.bat 11 | python.exe -m pip install --upgrade pip 12 | pip install -r requirements.txt 13 | pip install websocket 14 | pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118 15 | pip install -r comfy_runner\requirements.txt 16 | pip install -r ComfyUI\requirements.txt 17 | call dough-env\Scripts\deactivate.bat 18 | copy .env.sample .env 19 | cd .. 
20 | pause 21 | ) 22 | ) -------------------------------------------------------------------------------- /utils/ml_processor/ml_interface.py: -------------------------------------------------------------------------------- 1 | from abc import ABC 2 | 3 | from shared.constants import GPU_INFERENCE_ENABLED_KEY, ConfigManager 4 | 5 | config_manager = ConfigManager() 6 | gpu_enabled = config_manager.get(GPU_INFERENCE_ENABLED_KEY, False) 7 | 8 | def get_ml_client(): 9 | from utils.ml_processor.sai.api import APIProcessor 10 | from utils.ml_processor.gpu.gpu import GPUProcessor 11 | 12 | return APIProcessor() if not gpu_enabled else GPUProcessor() 13 | 14 | 15 | class MachineLearningProcessor(ABC): 16 | def __init__(self): 17 | pass 18 | 19 | def predict_model_output_standardized(self, *args, **kwargs): 20 | pass 21 | 22 | def predict_model_output(self, *args, **kwargs): 23 | pass 24 | 25 | def upload_training_data(self, *args, **kwargs): 26 | pass 27 | 28 | # NOTE: implementation not neccessary as this functionality is removed from the app 29 | def dreambooth_training(self, *args, **kwargs): 30 | pass 31 | -------------------------------------------------------------------------------- /scripts/linux_setup_online.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Store the current directory path 4 | current_dir="$(pwd)" 5 | 6 | # Define the project directory path 7 | project_dir="$current_dir/Dough" 8 | 9 | # Check if the "Dough" directory doesn't exist and we're not already inside it 10 | if [ ! 
-d "$project_dir" ] && [ "$(basename "$current_dir")" != "Dough" ]; then 11 | # Clone the git repo 12 | git clone --depth 1 -b main https://github.com/banodoco/Dough.git "$project_dir" 13 | cd "$project_dir" 14 | 15 | # Create virtual environment 16 | python3.10 -m venv "dough-env" 17 | 18 | # Install system dependencies 19 | if command -v sudo &> /dev/null; then 20 | sudo apt-get update && sudo apt-get install -y libpq-dev python3.10-dev 21 | else 22 | apt-get update && apt-get install -y libpq-dev python3.10-dev 23 | fi 24 | 25 | echo $(pwd) 26 | . ./dough-env/bin/activate && pip install -r "requirements.txt" 27 | 28 | # Copy the environment file 29 | cp "$project_dir/.env.sample" "$project_dir/.env" 30 | fi 31 | -------------------------------------------------------------------------------- /ui_components/components/shortlist_page.py: -------------------------------------------------------------------------------- 1 | import streamlit as st 2 | from ui_components.components.explorer_page import gallery_image_view 3 | from utils.data_repo.data_repo import DataRepo 4 | from utils import st_memory 5 | 6 | 7 | def shortlist_page(project_uuid): 8 | 9 | st.markdown(f"#### :green[{st.session_state['main_view_type']}] > :red[{st.session_state['page']}]") 10 | 11 | data_repo = DataRepo() 12 | project_setting = data_repo.get_project_setting(project_uuid) 13 | # columnn_selecter() 14 | # k1,k2 = st.columns([5,1]) 15 | # shortlist_page_number = k1.radio("Select page", options=range(1, project_setting.total_shortlist_gallery_pages), horizontal=True, key="shortlist_gallery") 16 | # with k2: 17 | # open_detailed_view_for_all = st_memory.toggle("Open prompt details for all:", key='shortlist_gallery_toggle',value=False) 18 | st.markdown("***") 19 | gallery_image_view( 20 | project_uuid, 21 | True, 22 | view=["view_inference_details", "add_to_any_shot", "add_and_remove_from_shortlist"], 23 | ) 24 | -------------------------------------------------------------------------------- 
/requirements.txt: -------------------------------------------------------------------------------- 1 | streamlit==1.27.0 2 | streamlit-image-comparison==0.0.3 3 | opencv-python-headless 4 | opencv-python==4.8.0.74 5 | sahi 6 | Pillow==9.4.0 7 | moviepy==1.0.3 8 | pandas==2.2.2 9 | replicate==0.26.1 10 | requests==2.28.2 11 | boto3==1.26.54 12 | ffmpeg-python==0.2.0 13 | streamlit-drawable-canvas==0.9.0 14 | numpy==1.24.4 15 | mesa 16 | streamlit-javascript==0.1.5 17 | streamlit-cropper==0.2.1 18 | pydub==0.25.1 19 | requests-toolbelt==1.0.0 20 | streamlit-option-menu==0.3.6 21 | colorlog==6.7.0 22 | Django==4.2.1 23 | djangorestframework==3.14.0 24 | python-dotenv==0.19.2 25 | scikit-image==0.21.0 26 | sentry-sdk==1.29.2 27 | st-clickable-images==0.0.3 28 | htbuilder==0.6.1 29 | streamlit-extras==0.2.7 30 | cryptography==41.0.1 31 | watchdog==3.0.0 32 | httpx-oauth==0.13.0 33 | extra-streamlit-components==0.1.56 34 | wrapt==1.15.0 35 | pydantic==1.10.9 36 | streamlit-server-state==0.17.1 37 | setproctitle==1.3.3 38 | gitdb==4.0.11 39 | websockets 40 | psutil==5.9.8 41 | black==24.4.2 42 | onnxruntime==1.18.1 43 | opencv-contrib-python==4.10.0.84 44 | portalocker==2.10.1 45 | Flask==3.0.3 46 | PyJWT==2.9.0 -------------------------------------------------------------------------------- /backend/migrations/0011_lock_added.py: -------------------------------------------------------------------------------- 1 | # Generated by Django 4.2.1 on 2023-10-17 05:06 2 | 3 | from django.db import migrations, models 4 | import uuid 5 | 6 | 7 | class Migration(migrations.Migration): 8 | 9 | dependencies = [ 10 | ("backend", "0010_project_metadata_added"), 11 | ] 12 | 13 | operations = [ 14 | migrations.CreateModel( 15 | name="Lock", 16 | fields=[ 17 | ( 18 | "id", 19 | models.BigAutoField( 20 | auto_created=True, primary_key=True, serialize=False, verbose_name="ID" 21 | ), 22 | ), 23 | ("uuid", models.UUIDField(default=uuid.uuid4)), 24 | ("created_on", 
MOTION_LORA_DB = "data.json"


def _load_motion_lora_db(data_store):
    """Return the JSON object stored at *data_store*, or ``{}`` if it is unusable.

    A missing, unreadable or corrupted store is treated as empty, as is a store
    whose top-level JSON value is not an object.
    """
    try:
        with open(data_store, "r", encoding="utf-8") as file:
            data = json.load(file)
    except (OSError, ValueError):
        # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
        return {}

    return data if isinstance(data, dict) else {}


def write_to_motion_lora_local_db(update_data):
    """Merge *update_data* into the local JSON store and write it back.

    Args:
        update_data: Mapping of keys to JSON-serialisable values. Existing keys
            are overwritten; other keys in the store are preserved.

    Raises:
        TypeError: If a value cannot be serialised to JSON. The store on disk
            is left untouched in that case.
        OSError: If the store cannot be written.
    """
    data_store = MOTION_LORA_DB

    data = _load_motion_lora_db(data_store)
    data.update(update_data)

    # Serialise before opening the file for writing so that a serialisation
    # failure cannot leave a truncated store behind.
    serialized = json.dumps(data, indent=4)
    with open(data_store, "w", encoding="utf-8") as file:
        file.write(serialized)


def read_from_motion_lora_local_db(key=None):
    """Read the local JSON store.

    Args:
        key: Optional key to look up.

    Returns:
        The value stored under *key* when the key is present. Otherwise the
        whole store as a dict, which is empty when the store is missing or
        unreadable.
    """
    data = _load_motion_lora_db(MOTION_LORA_DB)

    return data[key] if key in data else data
"14": { 13 | "inputs": { 14 | "text": [ 15 | "15", 16 | 0 17 | ] 18 | }, 19 | "class_type": "ShowText|pysssss", 20 | "_meta": { 21 | "title": "Show Text 🐍" 22 | } 23 | }, 24 | "15": { 25 | "inputs": { 26 | "prompt": "write a poem on finding your way in about 100 words", 27 | "suffix": "", 28 | "max_response_tokens": 500, 29 | "temperature": 0.8, 30 | "top_p": 0.95, 31 | "min_p": 0.05, 32 | "typical_p": 1, 33 | "echo": false, 34 | "frequency_penalty": 0, 35 | "presence_penalty": 0, 36 | "repeat_penalty": 1.1, 37 | "top_k": 40, 38 | "seed": 273, 39 | "tfs_z": 1, 40 | "mirostat_mode": 0, 41 | "mirostat_tau": 5, 42 | "mirostat_eta": 0.1, 43 | "LLM": [ 44 | "1", 45 | 0 46 | ] 47 | }, 48 | "class_type": "Call LLM Advanced", 49 | "_meta": { 50 | "title": "Call LLM Advanced" 51 | } 52 | } 53 | } -------------------------------------------------------------------------------- /ui_components/widgets/attach_audio_element.py: -------------------------------------------------------------------------------- 1 | import streamlit as st 2 | from ui_components.methods.common_methods import save_audio_file 3 | from ui_components.models import InternalProjectObject, InternalSettingObject 4 | from utils.data_repo.data_repo import DataRepo 5 | from utils.state_refresh import refresh_app 6 | 7 | 8 | def attach_audio_element(project_uuid, expanded): 9 | data_repo = DataRepo() 10 | project_setting: InternalSettingObject = data_repo.get_project_setting(project_uuid) 11 | 12 | with st.expander("🔊 Audio", expanded=expanded): 13 | 14 | uploaded_file = st.file_uploader( 15 | "Attach audio", type=["mp3"], help="This will attach this audio when you render a video" 16 | ) 17 | if st.button("Upload and attach new audio"): 18 | if uploaded_file: 19 | save_audio_file(uploaded_file, project_uuid) 20 | refresh_app() 21 | else: 22 | st.warning("No file selected") 23 | 24 | if project_setting.audio: 25 | # TODO: store "extracted_audio.mp3" in a constant 26 | if project_setting.audio.name == 
"extracted_audio.mp3": 27 | st.info("You have attached the audio from the video you uploaded.") 28 | 29 | if project_setting.audio.location: 30 | st.audio(project_setting.audio.location) 31 | -------------------------------------------------------------------------------- /scripts/linux_setup.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Store the current directory path 4 | current_dir="$(pwd)" 5 | 6 | # Define the project directory path 7 | project_dir="$current_dir/Dough" 8 | 9 | # Check if the "Dough" directory doesn't exist and we're not already inside it 10 | if [ ! -d "$project_dir" ] && [ "$(basename "$current_dir")" != "Dough" ]; then 11 | # Clone the git repo 12 | git clone --depth 1 -b main https://github.com/banodoco/Dough.git "$project_dir" 13 | cd "$project_dir" 14 | git clone --depth 1 -b main https://github.com/piyushK52/comfy_runner.git 15 | git clone https://github.com/comfyanonymous/ComfyUI.git 16 | 17 | # Create virtual environment 18 | python3 -m venv "dough-env" 19 | 20 | # Install system dependencies 21 | if command -v sudo &> /dev/null; then 22 | sudo apt-get update && sudo apt-get install -y libpq-dev python3.10-dev 23 | else 24 | apt-get update && apt-get install -y libpq-dev python3.10-dev 25 | fi 26 | 27 | # Install Python dependencies 28 | echo $(pwd) 29 | . ./dough-env/bin/activate && pip install -r "requirements.txt" 30 | . ./dough-env/bin/activate && pip install -r "comfy_runner/requirements.txt" 31 | . 
@dataclass
class MotionModuleCheckpoint:
    """A motion-module checkpoint, identified by its file name."""

    name: str


# make sure to have unique names (streamlit limitation)
class AnimateDiffCheckpoint:
    """Registry of the known AnimateDiff motion-module checkpoints.

    Every class attribute holding a ``MotionModuleCheckpoint`` is a registered
    checkpoint. Lookups walk ``dir()`` of the class, so results come back in
    attribute-name order.
    """

    mm_v15 = MotionModuleCheckpoint(name="mm_sd_v15.ckpt")
    mm_v14 = MotionModuleCheckpoint(name="mm_sd_v14.ckpt")

    @staticmethod
    def _registered_checkpoints():
        """Collect every ``MotionModuleCheckpoint`` attribute of the class."""
        registered = []
        for attr_name in dir(AnimateDiffCheckpoint):
            candidate = getattr(AnimateDiffCheckpoint, attr_name)
            if callable(candidate) or attr_name.startswith("__"):
                continue
            if isinstance(candidate, MotionModuleCheckpoint):
                registered.append(candidate)
        return registered

    @staticmethod
    def get_name_list():
        """Return the file names of all registered checkpoints."""
        return [entry.name for entry in AnimateDiffCheckpoint._registered_checkpoints()]

    @staticmethod
    def get_model_from_name(name):
        """Return the checkpoint whose file name equals *name*, or ``None``."""
        matches = (
            entry
            for entry in AnimateDiffCheckpoint._registered_checkpoints()
            if entry.name == name
        )
        return next(matches, None)
def individual_video_display_element(file: Union[InternalFileObject, str], dont_bypass_file_size_check=True):
    """Show a video player for *file*, or a notice when it cannot be shown.

    Args:
        file: Either a location string or a file object exposing ``.location``.
        dont_bypass_file_size_check: When True (the default) the size limit is
            enforced outside development mode. Pass False to always show the
            video.
    """
    # Resolve the location first. An object without a usable location counts as
    # "no video" instead of being passed on as if it were a path.
    if isinstance(file, str):
        file_location = file
    else:
        file_location = file.location if file else None

    if not file_location:
        st.error("No video present")
        return

    # The size lookup runs only once a location is known, and only when neither
    # of the two bypass conditions applies (short-circuit evaluation).
    show_video_file = (
        SERVER == ServerType.DEVELOPMENT.value
        or not dont_bypass_file_size_check
        or get_file_size(file_location) < MAX_LOADING_FILE_SIZE
    )

    if show_video_file:
        st.video(file_location, format="mp4", start_time=0)
    else:
        st.info("Video file too large to display")
def adjust_shot_page(shot_uuid: str, h2):
    """Render the adjust-shot page.

    When the frame selector returns an empty string, the shot-level view is
    rendered: generation log, breadcrumb and the shot's keyframes. Otherwise
    rendering is delegated to ``frame_styling_page``.

    Args:
        shot_uuid: UUID used to load the shot whose name appears in the headers.
        h2: Not used by this function.
    """
    # NOTE(review): block nesting was reconstructed from a whitespace-collapsed
    # copy of this file; confirm it against the repository version.
    with st.sidebar:
        frame_selection = frame_selector_widget(show_frame_selector=True)

    data_repo = DataRepo()
    shot = data_repo.get_shot_from_uuid(shot_uuid)

    if frame_selection == "":
        with st.sidebar:
            st.write("")

            with st.expander("🔍 Generation log", expanded=True):
                sidebar_logger(st.session_state["shot_uuid"])

        st.markdown(
            f"#### :green[{st.session_state['main_view_type']}] > :red[{st.session_state['page']}] > :blue[{shot.name}]"
        )
        st.markdown("***")

        column1, column2 = st.columns([2, 1.35])
        with column1:
            st.markdown(f"### 🎬 '{shot.name}' frames")
            # Raw string: "\_" is not a valid escape sequence in a normal string
            # literal and triggers a SyntaxWarning on Python 3.12+. The text
            # sent to Streamlit is unchanged.
            st.write(r"##### _\_\_\_\_\_\_\_\_\_\_\_\_\_\_\_\_\_\_\_\_\_\_\_")
            items_per_row = st_memory.slider("Items per row:", 1, 10, 6, key="items_per_row")

        shot_keyframe_element(st.session_state["shot_uuid"], items_per_row, column2, position="Individual")

    else:
        frame_styling_page(st.session_state["shot_uuid"])
def validate_file_hash(file_path, expected_hash_list):
    """Check whether the MD5 digest of a file is one of the expected digests.

    Args:
        file_path: Path of the file to hash.
        expected_hash_list: Collection of accepted hex digests.

    Returns:
        True only if the file exists, *expected_hash_list* is non-empty and the
        file's digest is contained in it. Otherwise False.
    """
    # Nothing can match an empty list, so skip reading the file altogether.
    if not expected_hash_list:
        return False

    calculated_hash = generate_file_hash(file_path)
    return calculated_hash is not None and calculated_hash in expected_hash_list


def generate_file_hash(file_path):
    """Return the MD5 hex digest of the file at *file_path*.

    Returns:
        The digest as a hex string, or ``None`` if the path does not exist.
    """
    if not os.path.exists(file_path):
        return None

    hash_md5 = hashlib.md5()
    with open(file_path, "rb") as f:
        # Read in fixed-size chunks so large files are never loaded whole.
        for chunk in iter(lambda: f.read(4096), b""):
            hash_md5.update(chunk)

    return hash_md5.hexdigest()
def verify_auth_details(self, auth_details=None):
    """Post the auth details to the authentication endpoint and return the tokens.

    Args:
        auth_details: JSON-serialisable payload sent as the request body.

    Returns:
        A ``(token, refresh_token, user)`` tuple, where ``user`` is a dict with
        ``name`` and ``email``. Returns ``(None, None, None)`` when the
        endpoint does not answer with HTTP 200 or reports a falsy ``status``.
    """
    response = requests.post(self.url, json=auth_details, headers={"Content-Type": "application/json"})

    if response.status_code != 200:
        # The logger formats lazily with %-style arguments, so the message
        # needs a placeholder for the response text to be logged.
        logger.error("auth verification failed: %s", response.text)
        return None, None, None

    data = response.json()
    if not data["status"]:
        return None, None, None

    payload = data["payload"]
    user = {"name": payload["user"]["name"], "email": payload["user"]["email"]}
    return payload["token"], payload["refresh_token"], user
def display_frame(video_path, frame_number):
    """Display one frame of a video in the Streamlit page.

    Nothing is rendered when the frame cannot be read.

    Args:
        video_path: Path or URL handed to ``cv2.VideoCapture``.
        frame_number: Index of the frame to seek to before reading.
    """
    cap = cv2.VideoCapture(video_path)
    try:
        cap.set(cv2.CAP_PROP_POS_FRAMES, frame_number)
        ret, frame = cap.read()
    finally:
        # Release the capture handle even if seeking or decoding raises.
        cap.release()

    if ret:
        # OpenCV returns BGR; convert to RGB before handing it to Streamlit.
        st.image(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
("is_disabled", models.BooleanField(default=False)), 28 | ("transformation_type", models.TextField(default="")), 29 | ("child_entity_type", models.CharField(default="file", max_length=255)), 30 | ("parent_entity_type", models.CharField(default="file", max_length=255)), 31 | ( 32 | "child_entity", 33 | models.ForeignKey( 34 | default=None, 35 | null=True, 36 | on_delete=django.db.models.deletion.CASCADE, 37 | related_name="child_entity", 38 | to="backend.internalfileobject", 39 | ), 40 | ), 41 | ( 42 | "parent_entity", 43 | models.ForeignKey( 44 | default=None, 45 | null=True, 46 | on_delete=django.db.models.deletion.CASCADE, 47 | related_name="parent_entity", 48 | to="backend.internalfileobject", 49 | ), 50 | ), 51 | ], 52 | options={ 53 | "db_table": "file_relationship", 54 | }, 55 | ), 56 | ] 57 | -------------------------------------------------------------------------------- /utils/ml_processor/sai/sai.py: -------------------------------------------------------------------------------- 1 | import time 2 | from shared.constants import InferenceParamType 3 | from ui_components.methods.data_logger import log_model_inference 4 | from utils.constants import MLQueryObject 5 | from utils.ml_processor.constants import MLModel 6 | from utils.ml_processor.ml_interface import MachineLearningProcessor 7 | from utils.ml_processor.sai.utils import get_model_params_from_query_obj, predict_sai_output 8 | 9 | 10 | # rn only used for sd3 and doesn't have all methods of MachineLearningProcessor 11 | class StabilityProcessor(MachineLearningProcessor): 12 | def __init__(self): 13 | pass 14 | 15 | def predict_model_output_standardized( 16 | self, model: MLModel, query_obj: MLQueryObject, queue_inference=False, backlog=False 17 | ): 18 | params = get_model_params_from_query_obj(model, query_obj) 19 | 20 | if params: 21 | new_params = {} 22 | new_params[InferenceParamType.QUERY_DICT.value] = params 23 | new_params[InferenceParamType.SAI_INFERENCE.value] = params 24 | return ( 25 | 
# Database selection: a local SQLite file unless the hosted background-runner
# mode is enabled, in which case PostgreSQL credentials are read from AWS SSM.
if HOSTED_BACKGROUND_RUNNER_MODE in [False, "False"]:
    DATABASES = {
        "default": {
            "ENGINE": "django.db.backends.sqlite3",  # sqlite by default works with serializable isolation level
            "NAME": DB_LOCATION,
        }
    }
    # Apply the busy-timeout pragma on every new SQLite connection.
    connection_created.connect(set_sqlite_timeout)
else:
    import boto3
    from django.db.backends.postgresql.psycopg_any import IsolationLevel

    ssm = boto3.client("ssm", region_name="ap-south-1")
    DB_NAME = ssm.get_parameter(Name="/backend/banodoco/db/name")["Parameter"]["Value"]
    DB_USER = ssm.get_parameter(Name="/backend/banodoco/db/user")["Parameter"]["Value"]
    DB_PASS = ssm.get_parameter(Name="/backend/banodoco/db/password")["Parameter"]["Value"]
    DB_HOST = ssm.get_parameter(Name="/backend/banodoco/db/host")["Parameter"]["Value"]
    DB_PORT = ssm.get_parameter(Name="/backend/banodoco/db/port")["Parameter"]["Value"]

    DATABASES = {
        "default": {
            "ENGINE": "django.db.backends.postgresql",
            "NAME": DB_NAME,
            "USER": DB_USER,
            "PASSWORD": DB_PASS,
            "HOST": DB_HOST,
            "PORT": DB_PORT,
            # OPTIONS must live inside the alias it configures. As a sibling of
            # "default" it is read as a second database alias and the isolation
            # level is never applied. Django also expects an IsolationLevel
            # member here, not a string.
            # NOTE(review): this makes the setting take effect for the first time.
            "OPTIONS": {
                "isolation_level": IsolationLevel.REPEATABLE_READ,  # TODO: test this isolation_level if deployed on prod
            },
        },
    }
def inspiration_engine_page(shot_uuid: str, h2):
    """Render the inspiration-engine page for the project that owns a shot.

    Shows an error when the shot cannot be loaded. Otherwise renders the
    sidebar (generation log and shot timeline), the image-generation widget and
    the project gallery.

    Args:
        shot_uuid: UUID of the shot used to resolve the project.
        h2: Not used by this function.
    """
    # NOTE(review): block nesting was reconstructed from a whitespace-collapsed
    # copy of this file; confirm it against the repository version.
    data_repo = DataRepo()
    shot = data_repo.get_shot_from_uuid(shot_uuid)
    if not shot:
        st.error("Shot not found")
        return

    project_uuid = shot.project.uuid
    # The result is not read here; the lookup itself is kept as in the original.
    data_repo.get_project_from_uuid(project_uuid)

    with st.sidebar:
        views = CreativeProcessType.value_list()

        if "view" not in st.session_state:
            st.session_state["view"] = views[0]

        st.write("")

        with st.expander("🔍 Generation log", expanded=True):
            sidebar_logger(st.session_state["shot_uuid"])

        st.write("")
        with st.expander("🪄 Shots", expanded=True):
            timeline_view(shot_uuid, "🪄 Shots", view="sidebar")

    st.markdown(f"#### :green[{st.session_state['main_view_type']}] > :red[{st.session_state['page']}]")
    st.markdown("***")

    st.markdown("### ✨ Generate images")
    # Raw string: "\_" is not a valid escape sequence in a normal string literal
    # and triggers a SyntaxWarning on Python 3.12+. The text sent to Streamlit
    # is unchanged.
    st.write(r"##### _\_\_\_\_\_\_\_\_\_\_\_\_\_\_\_\_\_\_\_\_\_\_\_")

    inspiration_engine_element(
        position="explorer", project_uuid=project_uuid, timing_uuid=None, shot_uuid=None
    )

    st.markdown("***")
    gallery_image_view(
        project_uuid,
        False,
        view=[
            "add_and_remove_from_shortlist",
            "view_inference_details",
            "shot_chooser",
            "add_to_any_shot",
        ],
    )
user_login_ui(): 12 | params = st.experimental_get_query_params() 13 | api_repo = APIRepo() 14 | 15 | # http://localhost:5500/?code=4%2F0AQlEd8xV0xpyTCHnHJH8zopNtZ033s7m419wdSXLT7-fYK5KSk5PqDYR0bdM0F8UXzjJMQ&scope=email+profile+openid+https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fuserinfo.profile+https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fuserinfo.email&authuser=0&prompt=consent 16 | # params found: {'code': ['4/0AQlEd8xy2YawBOvBcWiKBpwqtRtEE5i5hbpfx'], 'scope': ['email profile openid], 'authuser': ['0'], 'prompt': ['consent']} 17 | if params and "code" in params and not st.session_state.get("retry_login"): 18 | if st.session_state.get("retry_login"): 19 | st.session_state["retry_login"] = False 20 | refresh_app() 21 | 22 | st.markdown("#### Logging you in, please wait...") 23 | # st.write(params['code']) 24 | data = {"id_token": params["code"][0]} 25 | auth_token, refresh_token, user = api_repo.auth_provider.verify_auth_details(data) 26 | if auth_token and refresh_token: 27 | st.success("Successfully logged In, settings things up...") 28 | api_repo.set_auth_token(auth_token, refresh_token, user) 29 | refresh_app() 30 | else: 31 | st.error("Unable to login..") 32 | if st.button("Retry Login", key="retry_login_btn"): 33 | st.session_state["retry_login"] = True 34 | refresh_app() 35 | else: 36 | st.session_state["retry_login"] = False 37 | st.markdown("# :green[D]:red[o]:blue[u]:orange[g]:green[h] :red[□] :blue[□] :orange[□]") 38 | st.markdown("#### Login with Google to proceed") 39 | 40 | auth_url = api_repo.auth_provider.get_auth_url(redirect_uri="http://localhost:5500") 41 | if auth_url: 42 | st.markdown(auth_url, unsafe_allow_html=True) 43 | else: 44 | time.sleep(0.1) 45 | st.warning("Unable to generate login link, please contact support") 46 | -------------------------------------------------------------------------------- /shared/logging/logging.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import 
colorlog 3 | from shared.constants import SERVER, ServerType 4 | 5 | from shared.logging.constants import LoggingMode, LoggingPayload, LoggingType 6 | 7 | 8 | class AppLogger(logging.Logger): 9 | def __init__(self, name="app_logger", log_file=None, log_level=logging.DEBUG): 10 | super().__init__(name, log_level) 11 | self.log_file = log_file 12 | 13 | self._configure_logging() 14 | 15 | def _configure_logging(self): 16 | log_formatter = colorlog.ColoredFormatter( 17 | "%(log_color)s%(levelname)s:%(name)s:%(message)s", 18 | log_colors={ 19 | "DEBUG": "cyan", 20 | "INFO": "green", 21 | "WARNING": "yellow", 22 | "ERROR": "red", 23 | "CRITICAL": "red,bg_white", 24 | }, 25 | reset=True, 26 | secondary_log_colors={}, 27 | style="%", 28 | ) 29 | if self.log_file: 30 | file_handler = logging.FileHandler(self.log_file) 31 | file_handler.setFormatter(log_formatter) 32 | self.addHandler(file_handler) 33 | 34 | console_handler = logging.StreamHandler() 35 | console_handler.setFormatter(log_formatter) 36 | self.addHandler(console_handler) 37 | 38 | # setting logging mode 39 | if SERVER != ServerType.PRODUCTION.value: 40 | self.logging_mode = LoggingMode.OFFLINE.value 41 | else: 42 | self.logging_mode = LoggingMode.ONLINE.value 43 | 44 | # def log(self, log_type: LoggingType, log_payload: LoggingPayload): 45 | # if log_type == LoggingType.DEBUG: 46 | # self.debug(log_payload.message) 47 | # elif log_type == LoggingType.INFO: 48 | # self.info(log_payload.message) 49 | # elif log_type == LoggingType.ERROR: 50 | # self.error(log_payload.message) 51 | # elif log_type in [LoggingType.INFERENCE_CALL, LoggingType.INFERENCE_RESULT]: 52 | # self.info(log_payload.message) 53 | 54 | def log(self, log_type: LoggingType, log_message, log_data=None): 55 | if log_type == LoggingType.DEBUG: 56 | self.debug(log_message) 57 | elif log_type == LoggingType.INFO: 58 | self.info(log_message) 59 | elif log_type == LoggingType.ERROR: 60 | self.error(log_message) 61 | elif log_type in 
def get_closest_aspect_ratio(width, height):
    """Pick the listed aspect ratio whose value is nearest to ``width / height``.

    Args:
        width: Image width.
        height: Image height (must be non-zero).

    Returns:
        One of the ``"W:H"`` labels below. On a tie the label listed first wins.
    """
    candidates = ("16:9", "1:1", "21:9", "2:3", "3:2", "4:5", "5:4", "9:16", "9:21")
    target = width / height

    best_label = None
    smallest_gap = float("inf")
    for label in candidates:
        numerator, _, denominator = label.partition(":")
        gap = abs(target - int(numerator) / int(denominator))
        # Strict comparison keeps the earliest candidate when gaps are equal.
        if gap < smallest_gap:
            smallest_gap, best_label = gap, label

    return best_label
sidebar_logger(st.session_state["shot_uuid"]) 32 | 33 | st.write("") 34 | 35 | with st.expander("📋 Shortlist", expanded=True): 36 | if st_memory.toggle("Open", value=True, key="explorer_shortlist_toggle"): 37 | gallery_image_view( 38 | shot.project.uuid, 39 | shortlist=True, 40 | view=["add_and_remove_from_shortlist", "add_to_any_shot"], 41 | ) 42 | 43 | st.markdown(f"#### :green[{st.session_state['main_view_type']}] > :red[{st.session_state['page']}]") 44 | st.markdown("***") 45 | slider1, slider2, slider3 = st.columns([2, 1, 1]) 46 | with slider1: 47 | st.markdown(f"### 🪄 '{project.name}' shots") 48 | st.write("##### _\_\_\_\_\_\_\_\_\_\_\_\_\_\_\_\_\_\_\_\_\_\_\_") 49 | 50 | # start_time = time.time() 51 | timeline_view(st.session_state["shot_uuid"], st.session_state["view"], view="main") 52 | 53 | # end_time = time.time() 54 | # print("///////////////// timeline laoded in: ", end_time - start_time) 55 | # generate_images_element( 56 | # position="explorer", project_uuid=project_uuid, timing_uuid=None, shot_uuid=None 57 | # ) 58 | 59 | # end_time = time.time() 60 | # print("///////////////// generate img laoded in: ", end_time - start_time) 61 | 62 | # end_time = time.time() 63 | # print("///////////////// gallery laoded in: ", end_time - start_time) 64 | -------------------------------------------------------------------------------- /utils/ml_processor/comfy_workflows/flux_schnell_workflow_api.json: -------------------------------------------------------------------------------- 1 | { 2 | "6": { 3 | "inputs": { 4 | "text": "a dark fantasy castle", 5 | "clip": [ 6 | "30", 7 | 1 8 | ] 9 | }, 10 | "class_type": "CLIPTextEncode", 11 | "_meta": { 12 | "title": "CLIP Text Encode (Positive Prompt)" 13 | } 14 | }, 15 | "8": { 16 | "inputs": { 17 | "samples": [ 18 | "31", 19 | 0 20 | ], 21 | "vae": [ 22 | "30", 23 | 2 24 | ] 25 | }, 26 | "class_type": "VAEDecode", 27 | "_meta": { 28 | "title": "VAE Decode" 29 | } 30 | }, 31 | "9": { 32 | "inputs": { 33 | 
"filename_prefix": "ComfyUI", 34 | "images": [ 35 | "8", 36 | 0 37 | ] 38 | }, 39 | "class_type": "SaveImage", 40 | "_meta": { 41 | "title": "Save Image" 42 | } 43 | }, 44 | "27": { 45 | "inputs": { 46 | "width": 1024, 47 | "height": 1024, 48 | "batch_size": 1 49 | }, 50 | "class_type": "EmptySD3LatentImage", 51 | "_meta": { 52 | "title": "EmptySD3LatentImage" 53 | } 54 | }, 55 | "30": { 56 | "inputs": { 57 | "ckpt_name": "flux1-schnell-fp8.safetensors" 58 | }, 59 | "class_type": "CheckpointLoaderSimple", 60 | "_meta": { 61 | "title": "Load Checkpoint" 62 | } 63 | }, 64 | "31": { 65 | "inputs": { 66 | "seed": 1063867382637414, 67 | "steps": 4, 68 | "cfg": 1, 69 | "sampler_name": "euler", 70 | "scheduler": "simple", 71 | "denoise": 1, 72 | "model": [ 73 | "30", 74 | 0 75 | ], 76 | "positive": [ 77 | "6", 78 | 0 79 | ], 80 | "negative": [ 81 | "33", 82 | 0 83 | ], 84 | "latent_image": [ 85 | "27", 86 | 0 87 | ] 88 | }, 89 | "class_type": "KSampler", 90 | "_meta": { 91 | "title": "KSampler" 92 | } 93 | }, 94 | "33": { 95 | "inputs": { 96 | "text": "lowres, text, error, cropped, worst quality, low quality, jpeg artifacts, ugly, duplicate, morbid, mutilated, out of frame, extra fingers, mutated hands, poorly drawn hands, poorly drawn face, mutation, deformed, blurry, dehydrated, bad anatomy, bad proportions, extra limbs, cloned face, disfigured, gross proportions, malformed limbs, missing arms, missing legs, extra arms, extra legs, fused fingers, too many fingers, long neck, username, watermark, signature", 97 | "clip": [ 98 | "30", 99 | 1 100 | ] 101 | }, 102 | "class_type": "CLIPTextEncode", 103 | "_meta": { 104 | "title": "CLIP Text Encode (Negative Prompt)" 105 | } 106 | } 107 | } -------------------------------------------------------------------------------- /utils/ml_processor/comfy_workflows/llama_workflow_api.json: -------------------------------------------------------------------------------- 1 | { 2 | "15": { 3 | "inputs": { 4 | "ckpt_name": 
"Meta-Llama-3-8B-Instruct-Q5_K_M.gguf", 5 | "max_ctx": 2048, 6 | "gpu_layers": 100, 7 | "n_threads": 2 8 | }, 9 | "class_type": "LLMLoader", 10 | "_meta": { 11 | "title": "LLMLoader" 12 | } 13 | }, 14 | "24": { 15 | "inputs": { 16 | "string": "Looking at the examples 1 and 2 below generate the sub-prompts for the FINAL EXAMPLE containing a concise story:\n\nEXAMPLE 1:\nOverall prompt:Story about Leonard Cohen's big day at the beach.\nNumber of items: 12 \nSub-prompts:Leonard Cohen looking lying in bed|Leonard Cohen brushing teeth|Leonard Cohen smiling happily|Leonard Cohen driving in car, morning|Leonard Cohen going for a swim at beach|Leonard Cohen sitting on beach towel|Leonard Cohen building sandcastle|Leonard Cohen eating sandwich|Leonard Cohen walking along beach|Leonard Cohen getting out of water at seaside|Leonard Cohen driving home, dark outside|Leonard Cohen lying in bed smiling|close of of Leonard Cohen asleep in bed---\n\nEXAMPLE 2:\nOverall prompt: Visualizing the first day of spring \nNumber of items: 24 \nSub-prompts:Frost melting off grass|Sun rising over dewy meadow|Sparrows chirping in tree|Puddles drying up|Bees flying out of hive|Flowers blooming in garden bed|Robin landing on branch|Steam rising from cup of coffee|Morning light creeping through curtains|Wind rustling through leaves|Crocus bulbs pushing through soil|Buds swelling on branches|Birds singing in chorus|Sun shining through rain-soaked pavement|Droplets clinging to spider's web|Green shoots bursting forth from roots|Garden hose dripping water|Fog burning off lake|Light filtering through stained glass window|Hummingbird sipping nectar from flower|Warm breeze rustling through wheat field|Birch trees donning new green coat|Solar eclipse casting shadow on path|Birds returning to their nests---\n\nFINAL EXAMPLE:\nOverall prompt:a magical dark castle \nNumber of items: 5 \nSub-prompts:" 17 | }, 18 | "class_type": "String Literal", 19 | "_meta": { 20 | "title": "String Literal" 21 | } 22 | }, 
23 | "26": { 24 | "inputs": { 25 | "text": [ 26 | "30", 27 | 0 28 | ] 29 | }, 30 | "class_type": "ShowText|pysssss", 31 | "_meta": { 32 | "title": "✴️ U-NAI Get Text" 33 | } 34 | }, 35 | "30": { 36 | "inputs": { 37 | "prompt": [ 38 | "24", 39 | 0 40 | ], 41 | "temperature": 0.15, 42 | "attribute_name": "Sub-prompts", 43 | "attribute_type": "str", 44 | "attribute_description": "generate Sub-prompts:", 45 | "categories": "", 46 | "model": [ 47 | "15", 48 | 0 49 | ] 50 | }, 51 | "class_type": "StructuredOutput", 52 | "_meta": { 53 | "title": "Structured Output" 54 | } 55 | } 56 | } -------------------------------------------------------------------------------- /utils/ml_processor/comfy_workflows/sdxl_img2img_workflow_api.json: -------------------------------------------------------------------------------- 1 | { 2 | "1": { 3 | "inputs": { 4 | "ckpt_name": "sd_xl_base_1.0.safetensors" 5 | }, 6 | "class_type": "CheckpointLoaderSimple", 7 | "_meta": { 8 | "title": "Load Checkpoint" 9 | } 10 | }, 11 | "31": { 12 | "inputs": { 13 | "filename_prefix": "ComfyUI", 14 | "images": [ 15 | "44:0", 16 | 0 17 | ] 18 | }, 19 | "class_type": "SaveImage", 20 | "_meta": { 21 | "title": "Save Image" 22 | } 23 | }, 24 | "37:0": { 25 | "inputs": { 26 | "image": "cca52d68-91aa-4db8-b724-2370d03ff987.png", 27 | "upload": "image" 28 | }, 29 | "class_type": "LoadImage", 30 | "_meta": { 31 | "title": "Load Image" 32 | } 33 | }, 34 | "37:1": { 35 | "inputs": { 36 | "pixels": [ 37 | "37:0", 38 | 0 39 | ], 40 | "vae": [ 41 | "1", 42 | 2 43 | ] 44 | }, 45 | "class_type": "VAEEncode", 46 | "_meta": { 47 | "title": "VAE Encode" 48 | } 49 | }, 50 | "42:0": { 51 | "inputs": { 52 | "text": "pic of a king", 53 | "clip": [ 54 | "1", 55 | 1 56 | ] 57 | }, 58 | "class_type": "CLIPTextEncode", 59 | "_meta": { 60 | "title": "Positive prompt" 61 | } 62 | }, 63 | "42:1": { 64 | "inputs": { 65 | "text": "", 66 | "clip": [ 67 | "1", 68 | 1 69 | ] 70 | }, 71 | "class_type": "CLIPTextEncode", 72 | "_meta": { 73 | 
def log_model_inference(model: MLModel, time_taken, **kwargs):
    """Store an inference log row for ``model`` and return the created log.

    Args:
        model: the model the inference was (or will be) run with.
        time_taken: inference duration in seconds; a falsy value (``None``/0) marks the
            log as not yet completed.
        **kwargs: the inference parameters. Only JSON-friendly values
            (int, float, str, list, dict) are kept in the stored ``input_params``.
            A truthy ``backlog`` entry marks an unfinished log as backlog instead of queued.

    Returns:
        Whatever ``DataRepo.create_inference_log`` returns for the new log.
    """
    # removing object like bufferedreader, image_obj ..
    # float is kept on purpose: float params are valid JSON and must not vanish from the log
    loggable_types = (int, float, str, list, dict)
    kwargs_dict = {key: value for key, value in kwargs.items() if isinstance(value, loggable_types)}

    data_str = json.dumps(kwargs_dict)
    origin_data = kwargs_dict.get(InferenceParamType.ORIGIN_DATA.value, {})
    time_taken = round(time_taken, 2) if time_taken else 0

    # storing the log in db
    data_repo = DataRepo()
    user_id = get_current_user_uuid()
    ai_model = data_repo.get_ai_model_from_name(model.name, user_id)

    # TODO: fix this - we were initially storing all the models and their versions in the database but later moved on from it
    # so earlier models are found when fetching ai_model but for the new models we are adding this hack for adding dummy model_id
    # hackish sol for insuring that inpainting logs don't have an empty model field
    if ai_model is None and model.name in [
        ML_MODEL.sdxl_inpainting.name,
        ML_MODEL.ad_interpolation.name,
        ML_MODEL.sd3.name,
    ]:
        ai_model = data_repo.get_ai_model_from_name(ML_MODEL.sdxl.name, user_id)

    # a measured duration means the inference already ran; otherwise it is waiting in a queue
    if time_taken:
        status = InferenceStatus.COMPLETED.value
    elif kwargs.get("backlog"):
        status = InferenceStatus.BACKLOG.value
    else:
        status = InferenceStatus.QUEUED.value

    log_data = {
        "project_id": st.session_state["project_uuid"],
        "model_id": ai_model.uuid if ai_model else None,
        "input_params": data_str,
        "output_details": json.dumps({"model_name": model.display_name(), "version": model.version}),
        "total_inference_time": time_taken,
        "status": status,
        "model_name": model.display_name(),
        "generation_source": origin_data.get("inference_type", ""),
        "generation_tag": origin_data.get("inference_tag", ""),
    }

    log = data_repo.create_inference_log(**log_data)
    return log
"com.amazonaws.ecs.capability.docker-remote-api.1.18" 67 | }, 68 | { 69 | "name": "ecs.capability.task-eni" 70 | } 71 | ], 72 | "placementConstraints": [], 73 | "compatibilities": [ 74 | "EC2", 75 | "FARGATE" 76 | ], 77 | "requiresCompatibilities": [ 78 | "FARGATE" 79 | ], 80 | "cpu": "512", 81 | "memory": "2048", 82 | "registeredAt": "2023-09-10T10:50:27.792Z", 83 | "registeredBy": "arn:aws:iam::861629679241:user/tf-admin", 84 | "tags": [] 85 | } -------------------------------------------------------------------------------- /utils/ml_processor/comfy_workflows/creative_image_gen.json: -------------------------------------------------------------------------------- 1 | { 2 | "3": { 3 | "inputs": { 4 | "seed": [ 5 | "28", 6 | 0 7 | ], 8 | "steps": 20, 9 | "cfg": 7, 10 | "sampler_name": "dpmpp_2m", 11 | "scheduler": "karras", 12 | "denoise": 1, 13 | "model": [ 14 | "24", 15 | 0 16 | ], 17 | "positive": [ 18 | "6", 19 | 0 20 | ], 21 | "negative": [ 22 | "7", 23 | 0 24 | ], 25 | "latent_image": [ 26 | "5", 27 | 0 28 | ] 29 | }, 30 | "class_type": "KSampler", 31 | "_meta": { 32 | "title": "KSampler" 33 | } 34 | }, 35 | "4": { 36 | "inputs": { 37 | "ckpt_name": "sd_xl_base_1.0.safetensors" 38 | }, 39 | "class_type": "CheckpointLoaderSimple", 40 | "_meta": { 41 | "title": "Load Checkpoint" 42 | } 43 | }, 44 | "5": { 45 | "inputs": { 46 | "width": 1024, 47 | "height": 1024, 48 | "batch_size": 1 49 | }, 50 | "class_type": "EmptyLatentImage", 51 | "_meta": { 52 | "title": "Empty Latent Image" 53 | } 54 | }, 55 | "6": { 56 | "inputs": { 57 | "text": "beautiful ocean scene", 58 | "clip": [ 59 | "4", 60 | 1 61 | ] 62 | }, 63 | "class_type": "CLIPTextEncode", 64 | "_meta": { 65 | "title": "CLIP Text Encode (Prompt)" 66 | } 67 | }, 68 | "7": { 69 | "inputs": { 70 | "text": "horror", 71 | "clip": [ 72 | "4", 73 | 1 74 | ] 75 | }, 76 | "class_type": "CLIPTextEncode", 77 | "_meta": { 78 | "title": "CLIP Text Encode (Prompt)" 79 | } 80 | }, 81 | "10": { 82 | "inputs": { 83 | 
def back_and_forward_buttons():
    """Render the five-column frame navigator (-2, -1, current, +1, +2).

    The current frame is read from ``st.session_state``; clicking a button moves
    ``current_frame_index`` / ``prev_frame_index`` / ``current_frame_uuid`` to the
    chosen frame of the same shot and refreshes the app. Frame indexes shown to the
    user are 1-based, hence the ``- 1`` when indexing the timing list.
    """
    data_repo = DataRepo()
    timing: InternalFrameTimingObject = data_repo.get_timing_from_uuid(st.session_state["current_frame_uuid"])
    timing_list: List[InternalFrameTimingObject] = data_repo.get_timing_list_from_shot(timing.shot.uuid)

    def _jump(offset):
        # move the selection by `offset` frames and re-render
        target_index = st.session_state["current_frame_index"] + offset
        st.session_state["current_frame_index"] = target_index
        st.session_state["prev_frame_index"] = target_index
        st.session_state["current_frame_uuid"] = timing_list[target_index - 1].uuid
        refresh_app()

    columns = st.columns([2, 2, 2, 2, 2])
    display_idx = st.session_state["current_frame_index"]
    frame_count = len(timing_list)

    # one entry per column: (offset, is_shown, label, widget key); None marks the static centre slot
    nav_buttons = [
        (-2, display_idx > 2, f"{display_idx-2} ⏮️", f"Previous Previous Image for {display_idx}"),
        # hidden on the first image
        (-1, display_idx != 1, f"{display_idx-1} ⏪", f"Previous Image for {display_idx}"),
        None,
        # hidden on the last image
        (1, display_idx != frame_count, f"{display_idx+1} ⏩", f"Next Image for {display_idx}"),
        (2, display_idx <= frame_count - 2, f"{display_idx+2} ⏭️", f"Next Next Image for {display_idx}"),
    ]

    for column, spec in zip(columns, nav_buttons):
        with column:
            if spec is None:
                st.button(f"{display_idx} 📍", disabled=True)
                continue

            offset, is_shown, label, widget_key = spec
            if is_shown:
                if st.button(label, key=widget_key):
                    _jump(offset)
class VideoProcessor:
    """Re-times a video so that it lasts a desired duration and returns the encoded bytes."""

    @staticmethod
    def update_video_speed(video_location, desired_duration):
        """Re-time the video stored at ``video_location``; returns the new video as bytes."""
        # the context manager closes the clip once encoding is done
        with VideoFileClip(video_location) as clip:
            return VideoProcessor.update_clip_speed(clip, desired_duration)

    @staticmethod
    def update_video_bytes_speed(video_bytes, desired_duration):
        """Re-time an in-memory video; returns the new video as bytes.

        The bytes are spilled to a temporary ``.mp4`` file, which is always removed
        afterwards, even when processing fails.
        """
        # delete=False because the file must outlive this handle; cleanup happens in `finally`
        temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".mp4", mode="wb")
        temp_file_path = temp_file.name
        try:
            with temp_file:
                temp_file.write(video_bytes)

            # the clip is closed before the `finally` clause runs, so the file can be deleted safely
            with VideoFileClip(temp_file_path) as clip:
                return VideoProcessor.update_clip_speed(clip, desired_duration)
        finally:
            os.remove(temp_file_path)

    @staticmethod
    def update_clip_speed(clip: VideoFileClip, desired_duration):
        """Speed ``clip`` up or down so it lasts ``desired_duration`` and return the encoded bytes.

        The speed factor is ``clip.duration / desired_duration``. The result is encoded with
        libx264 into a temporary file that is removed before returning, on success and on failure.

        Raises:
            ZeroDivisionError: when ``desired_duration`` is 0.
        """
        # only reserve a unique path here; the handle is closed so the encoder can write to it
        with tempfile.NamedTemporaryFile(delete=False, suffix=".mp4", mode="wb") as temp_output_file:
            temp_output_path = temp_output_file.name

        try:
            # Apply desired video speed change and write to the temporary output file
            input_video_duration = clip.duration
            desired_speed_change = float(input_video_duration) / float(desired_duration)
            print("Desired Speed Change: " + str(desired_speed_change))
            output_clip = clip.fx(vfx.speedx, desired_speed_change)
            output_clip.write_videofile(filename=temp_output_path, codec="libx264", preset="fast")

            # Read the processed video bytes from the temporary output file
            with open(temp_output_path, "rb") as f:
                return f.read()
        finally:
            os.remove(temp_output_path)
predict_model_output_standardized( 21 | self, 22 | model: MLModel, 23 | query_obj: MLQueryObject, 24 | queue_inference=False, 25 | backlog=False, 26 | ): 27 | ( 28 | workflow_type, 29 | workflow_json, 30 | output_node_ids, 31 | extra_model_list, 32 | ignore_list, 33 | ) = get_model_workflow_from_query(model, query_obj) 34 | 35 | file_path_list = get_file_path_list(model, query_obj) 36 | 37 | # this is the format that is expected by comfy_runner 38 | data = { 39 | "workflow_input": workflow_json, 40 | "file_path_list": file_path_list, 41 | "output_node_ids": output_node_ids, 42 | "extra_model_list": extra_model_list, 43 | "ignore_model_list": ignore_list, 44 | } 45 | 46 | params = { 47 | "prompt": query_obj.prompt, # hackish sol 48 | InferenceParamType.QUERY_DICT.value: query_obj.to_json(), 49 | InferenceParamType.GPU_INFERENCE.value: json.dumps(data), 50 | InferenceParamType.FILE_RELATION_DATA.value: query_obj.relation_data, 51 | } 52 | return ( 53 | self.predict_model_output(model, **params) 54 | if not queue_inference 55 | else self.queue_prediction(model, **params, backlog=backlog) 56 | ) 57 | 58 | def predict_model_output(self, replicate_model: MLModel, **kwargs): 59 | queue_inference = kwargs.get("queue_inference", False) 60 | if queue_inference: 61 | return self.queue_prediction(replicate_model, **kwargs) 62 | 63 | data = kwargs.get(InferenceParamType.GPU_INFERENCE.value, None) 64 | data = json.loads(data) 65 | start_time = time.time() 66 | output = predict_gpu_output(data["workflow_input"], data["file_path_list"], data["output_node_ids"]) 67 | end_time = time.time() 68 | 69 | log = log_model_inference(replicate_model, end_time - start_time, **kwargs) 70 | return output, log 71 | 72 | def queue_prediction(self, replicate_model, **kwargs): 73 | log = log_model_inference(replicate_model, None, **kwargs) 74 | return None, log 75 | 76 | def upload_training_data(self, zip_file_name, delete_after_upload=False): 77 | # TODO: fix for online hosting 78 | # return the 
local file path as it is 79 | return zip_file_name 80 | -------------------------------------------------------------------------------- /.github/workflows/deploy.yml: -------------------------------------------------------------------------------- 1 | name: Deploy to ECR 2 | 3 | on: 4 | 5 | push: 6 | branches: [ piyush-dev ] 7 | 8 | jobs: 9 | 10 | build: 11 | name: Build Image 12 | runs-on: ubuntu-latest 13 | 14 | steps: 15 | 16 | - name: Check out code 17 | uses: actions/checkout@v2 18 | 19 | - name: Configure AWS credentials 20 | uses: aws-actions/configure-aws-credentials@v1 21 | with: 22 | aws-access-key-id: ${{ secrets.AWS_ECR_ACCESS_KEY }} 23 | aws-secret-access-key: ${{ secrets.AWS_ECR_SECRET_KEY }} 24 | aws-region: ap-south-1 25 | 26 | - name: Login to Amazon ECR 27 | id: login-ecr 28 | uses: aws-actions/amazon-ecr-login@v1 29 | 30 | - name: Build, tag, and push image to Amazon ECR 31 | env: 32 | ECR_REGISTRY: ${{ steps.login-ecr.outputs.registry }} 33 | ECR_REPOSITORY: banodoco-frontend 34 | IMAGE_TAG: latest 35 | id: build-image 36 | run: | 37 | docker build -t $ECR_REGISTRY/$ECR_REPOSITORY:$IMAGE_TAG . 
38 | docker push $ECR_REGISTRY/$ECR_REPOSITORY:$IMAGE_TAG 39 | echo "image=$ECR_REGISTRY/$ECR_REPOSITORY:$IMAGE_TAG" >> $GITHUB_OUTPUT 40 | 41 | - name: Fill in the new image ID in the Amazon ECS task definition 42 | env: 43 | ECS_TASK_DEFINITION: .aws/task-definition.json 44 | CONTAINER_NAME: backend-banodoco-frontend 45 | id: task-def 46 | uses: aws-actions/amazon-ecs-render-task-definition@c804dfbdd57f713b6c079302a4c01db7017a36fc 47 | with: 48 | task-definition: ${{ env.ECS_TASK_DEFINITION }} 49 | container-name: ${{ env.CONTAINER_NAME }} 50 | image: ${{ steps.build-image.outputs.image }} 51 | 52 | - name: Deploy Amazon ECS task definition 53 | env: 54 | ECS_SERVICE: backend-banodoco-frontend-service 55 | ECS_CLUSTER: backend-banodoco-frontend-cluster 56 | uses: aws-actions/amazon-ecs-deploy-task-definition@df9643053eda01f169e64a0e60233aacca83799a 57 | with: 58 | task-definition: ${{ steps.task-def.outputs.task-definition }} 59 | service: ${{ env.ECS_SERVICE }} 60 | cluster: ${{ env.ECS_CLUSTER }} 61 | wait-for-service-stability: true 62 | 63 | update-runner: 64 | name: Update Background Runner 65 | runs-on: self-hosted 66 | 67 | steps: 68 | - name: Checkout code 69 | uses: actions/checkout@v2 70 | 71 | - name: Create and activate virtual environment 72 | run: | 73 | python3 -m venv venv 74 | source venv/bin/activate 75 | 76 | - name: Install dependencies 77 | run: | 78 | python3 -m pip install --upgrade pip 79 | source venv/bin/activate && pip install -r requirements.txt 80 | 81 | - name: Create .env file 82 | run: | 83 | touch .env 84 | echo "SERVER=production" > .env 85 | echo "SERVER_URL=https://api.banodoco.ai" >> .env 86 | echo "HOSTED_BACKGROUND_RUNNER_MODE=True" >> .env 87 | echo "admin_email=${{ secrets.ADMIN_EMAIL }}" >> .env 88 | echo "admin_password=${{ secrets.ADMIN_PASSWORD }}" >> .env 89 | echo "ENCRYPTION_KEY=${{ secrets.FERNET_ENCRYPTION_KEY }}" >> .env 90 | 91 | - name: Restart runner 92 | run: | 93 | if pkill -0 -f "banodoco_runner"; then 94 | 
def execute_shell_command(command: str):
    """Run ``command`` through the shell and wrap its stdout in an ``InternalResponse``.

    The response's third argument is ``True`` when the process exits with code 0.
    """
    import subprocess

    # NOTE(review): shell=True executes the string verbatim - callers must never pass untrusted input
    result = subprocess.run(command, shell=True, capture_output=True, text=True)
    return InternalResponse(result.stdout, "success", result.returncode == 0)


def is_online_file_path(file_path):
    """Return True when ``file_path`` is an http, https or ftp URL."""
    parsed = urllib.parse.urlparse(file_path)
    return parsed.scheme in ("http", "https", "ftp")


def is_url_valid(url):
    """Return True when a HEAD request to ``url`` (following redirects) ends in an accepted status.

    Any exception raised while making the request is treated as "not valid".
    """
    try:
        # with allow_redirects=True `response` IS the final response; `response.history`
        # only holds the intermediate redirect hops, so it must not be used for the verdict
        response = requests.head(url, allow_redirects=True)
        return response.status_code in [200, 201, 307]  # TODO: handle all possible status codes
    except Exception:
        return False


def get_file_type(url):
    """Classify ``url`` as "image", "video" or "unknown" from its Content-Type header."""
    try:
        response = requests.head(url)
        content_type = response.headers.get("Content-Type")

        if content_type and "image" in content_type:
            return "image"
        elif content_type and "video" in content_type:
            return "video"
        else:
            return "unknown"
    except Exception as e:
        print("Error:", e)
        return "unknown"


def generate_fresh_token(refresh_token):
    """Exchange ``refresh_token`` for a new ``(token, refresh_token)`` pair.

    Returns ``(None, None)`` when no refresh token is given or the server does not answer 200.
    """
    if not refresh_token:
        return None, None

    url = f"{SERVER_URL}/v1/authentication/refresh"
    headers = {"Authorization": f"Bearer {refresh_token}"}

    response = requests.request("GET", url, headers=headers, data={})
    if response.status_code == 200:
        data = response.json()
        return data["payload"]["token"], data["payload"]["refresh_token"]

    return None, None


def validate_token_through_db(token, refresh_token):
    """Ask the server whether ``token`` is still accepted.

    Returns the tokens unchanged on a 200 response, otherwise the result of
    ``generate_fresh_token(refresh_token)``.
    """
    url = f"{SERVER_URL}/v1/user/op"
    headers = {"Authorization": f"Bearer {token}"}

    response = requests.request("GET", url, headers=headers, data={})
    if response.status_code == 200:
        return token, refresh_token

    return generate_fresh_token(refresh_token)


def validate_token(
    token,
    refresh_token,
    validate_through_db=False,
):
    """Return a usable ``(token, refresh_token)`` pair, refreshing an expired token.

    The JWT is decoded without signature verification and its ``exp`` claim is compared
    with the current time. A token without ``exp`` is returned as is. With
    ``validate_through_db`` set, a non-expired token is additionally checked with the server.

    Returns ``(None, None)`` when no token is given, the token cannot be decoded,
    or it cannot be renewed.
    """
    # returns a fresh token if the old one has expired
    # returns None if the token has expired or can't be renewed
    if not token:
        return None, None

    try:
        decoded_token = jwt.decode(token, options={"verify_signature": False})
        exp = decoded_token.get("exp")

        if exp is None:
            return token, refresh_token

        now = time.time()
        if exp > now:
            if not validate_through_db:
                return token, refresh_token
            else:
                return validate_token_through_db(token, refresh_token)
        else:
            return generate_fresh_token(refresh_token)

    except jwt.ExpiredSignatureError:
        print("expired token, trying to refresh...")
        return generate_fresh_token(refresh_token)
    except Exception as e:
        print("error validating the jwt: ", str(e))
        return None, None
# Compare the locally recorded app version with the one published in the remote
# repository and, when the remote one is newer, pull the app and the
# comfy_runner checkout. Relies on compare_versions defined above.
update_app() {
    CURRENT_VERSION=$(curl -s "https://raw.githubusercontent.com/banodoco/Dough/feature/final/scripts/app_version.txt")

    ERR_MSG="Unable to fetch the current version from the remote repository."
    # Accept only an exact MAJOR.MINOR.PATCH string. An empty response or an
    # error page served for a missing file both fail this check, so no
    # separate emptiness test is needed.
    if ! echo "$CURRENT_VERSION" | grep -q '^[0-9]\+\.[0-9]\+\.[0-9]\+$'; then
        echo "$ERR_MSG"
        return
    fi

    CURRENT_DIR=$(pwd)
    # quoted so a checkout path containing spaces still works
    LOCAL_VERSION=$(cat "${CURRENT_DIR}/scripts/app_version.txt")
    echo "local version $LOCAL_VERSION"
    echo "current version $CURRENT_VERSION"
    VERSION_DIFF=$(compare_versions "$LOCAL_VERSION" "$CURRENT_VERSION")
    if [ "$VERSION_DIFF" = "-1" ]; then
        echo "A newer version ($CURRENT_VERSION) is available. Updating..."

        # NOTE: local modifications are stashed and never popped automatically
        git stash
        # Step 1: Pull from the current branch
        if ! git pull origin "$(git rev-parse --abbrev-ref HEAD)"; then
            # do not record the new version, otherwise the update would never be retried
            echo "Update failed while pulling the latest changes. Continuing with version $LOCAL_VERSION."
            return
        fi

        # Step 2: Check if the comfy_runner folder is present
        if [ -d "${CURRENT_DIR}/comfy_runner" ]; then
            # Step 3a: If comfy_runner is present, pull its main branch.
            # Run in a subshell so the working directory of this script is
            # untouched even when the cd fails.
            (cd "${CURRENT_DIR}/comfy_runner" && git pull origin main)
        else
            # Step 3b: If comfy_runner is not present, clone the repository
            echo "comfy_runner folder not found. Cloning repository."
            REPO_URL="https://github.com/piyushK52/comfy_runner.git"
            git clone "$REPO_URL" "${CURRENT_DIR}/comfy_runner"
        fi

        echo "$CURRENT_VERSION" > "${CURRENT_DIR}/scripts/app_version.txt"
    else
        echo "You have the latest version ($LOCAL_VERSION)."
    fi
}
33 | if "!v1!" equ "" set "v1=0" 34 | if "!v2!" equ "" set "v2=0" 35 | if !v1! gtr !v2! ( 36 | echo You have the latest version 37 | endlocal & exit /b 38 | ) else if !v1! lss !v2! ( 39 | echo A newer version is available. Updating... 40 | git stash 41 | rem Step 1: Pull from the current branch 42 | git pull origin "!git rev-parse --abbrev-ref HEAD!" 43 | 44 | rem Step 2: Check if the comfy_runner folder is present 45 | if exist "!CURRENT_DIR!\comfy_runner" ( 46 | rem Step 3a: If comfy_runner is present, pull from the feature/package branch 47 | rem echo comfy_runner folder found. Pulling from feature/package branch. 48 | cd comfy_runner 49 | rem git pull origin feature/package 50 | cd "!CURRENT_DIR!" 51 | ) else ( 52 | rem Step 3b: If comfy_runner is not present, clone the repository 53 | echo comfy_runner folder not found. Cloning repository. 54 | set REPO_URL=https://github.com/piyushK52/comfy_runner.git 55 | git clone "!REPO_URL!" "!CURRENT_DIR!\comfy_runner" 56 | ) 57 | 58 | echo !CURRENT_VERSION! > "!CURRENT_DIR!\scripts\app_version.txt" 59 | endlocal & exit /b 60 | ) 61 | ) 62 | echo You have the latest version 63 | endlocal & exit /b 64 | 65 | :update_app 66 | setlocal EnableDelayedExpansion 67 | set CURRENT_VERSION= 68 | for /f "delims=" %%i in ('curl -s "https://raw.githubusercontent.com/banodoco/Dough/feature/final/scripts/app_version.txt"') do set CURRENT_VERSION=%%i 69 | 70 | echo %CURRENT_VERSION% | findstr /r "^[^a-zA-Z]*$" >nul 71 | if %errorlevel% neq 0 ( 72 | echo Invalid version format: %CURRENT_VERSION%. Expected format: X.X.X (e.g., 1.2.3^) 73 | endlocal & exit /b 74 | ) 75 | 76 | if not "%CURRENT_VERSION%" == "" ( 77 | echo %CURRENT_VERSION% 78 | ) else ( 79 | set ERR_MSG=Unable to fetch the current version from the remote repository. 
class StCache:
    """List-backed object cache kept in ``st.session_state``.

    Every ``data_type`` is a session-state key holding a list of cached
    elements. An element is either a dict with a ``"uuid"`` key or an object
    with a ``uuid`` attribute; uuids are always compared as strings.
    """

    @staticmethod
    def _uuid_of(ele):
        """Return the uuid of a dict or object element, normalised to ``str``."""
        raw_uuid = ele["uuid"] if isinstance(ele, dict) else ele.uuid
        return str(raw_uuid)

    @staticmethod
    def get(uuid, data_type):
        """Return the cached element with this uuid, or None when absent."""
        uuid = str(uuid)
        if data_type in st.session_state:
            for ele in st.session_state[data_type]:
                if StCache._uuid_of(ele) == uuid:
                    return ele

        return None

    @staticmethod
    def update(data, data_type) -> bool:
        """Replace every cached element sharing ``data``'s uuid.

        Returns:
            True when at least one element was replaced, False otherwise.
        """
        object_found = False
        uuid = StCache._uuid_of(data)

        if data_type in st.session_state:
            object_list = st.session_state[data_type]
            for idx, ele in enumerate(object_list):
                if StCache._uuid_of(ele) == uuid:
                    object_list[idx] = data
                    object_found = True

            # write the list back so the session state registers the change
            st.session_state[data_type] = object_list

        return object_found

    @staticmethod
    def add(data, data_type) -> bool:
        """Insert ``data``, overwriting an existing element with the same uuid.

        Returns:
            True once the element is stored.
        """
        # update() reports whether the uuid was already cached, which avoids a
        # separate lookup and does not depend on the cached object's truthiness
        if StCache.update(data, data_type):
            return True

        if data_type in st.session_state:
            object_list = st.session_state[data_type]
        else:
            object_list = []

        object_list.append(data)
        st.session_state[data_type] = object_list
        return True

    @staticmethod
    def delete(uuid, data_type) -> bool:
        """Remove the first cached element with this uuid.

        Returns:
            True when an element was removed, False otherwise.
        """
        object_found = False
        uuid = str(uuid)
        if data_type in st.session_state:
            object_list = st.session_state[data_type]
            for idx, ele in enumerate(object_list):
                if StCache._uuid_of(ele) == uuid:
                    # delete by position so exactly the matched element goes
                    del object_list[idx]
                    object_found = True
                    break

            st.session_state[data_type] = object_list

        return object_found

    @staticmethod
    def delete_all(data_type) -> bool:
        """Drop the whole list for ``data_type``; True when it existed."""
        if data_type in st.session_state:
            del st.session_state[data_type]
            return True

        return False

    @staticmethod
    def add_all(data_list, data_type) -> bool:
        """Add (or overwrite) every element of ``data_list``; always returns True."""
        for data in data_list:
            StCache.add(data, data_type)

        return True

    @staticmethod
    def get_all(data_type):
        """Return the cached list for ``data_type``, or an empty list."""
        if data_type in st.session_state:
            return st.session_state[data_type]

        return []

    # deletes all cached objects of every data type
    @staticmethod
    def clear_entire_cache() -> bool:
        """Delete the cached list of every key in ``CacheKey``; always returns True."""
        for c in CacheKey.value_list():
            StCache.delete_all(c)

        return True
def project_settings_page(project_uuid):
    """Render the Streamlit "Project Settings" page for a single project.

    The page offers three actions: renaming the project, changing its frame
    size (from a preset list or as custom values) and, behind a confirmation
    checkbox, deleting the project.

    Args:
        project_uuid: identifier handed to every DataRepo lookup and update
            performed by this page.
    """
    data_repo = DataRepo()
    st.markdown("#### Project Settings")
    st.markdown("***")

    # --- rename -----------------------------------------------------------
    with st.expander("📋 Project name", expanded=True):
        project = data_repo.get_project_from_uuid(project_uuid)
        new_name = st.text_input("Enter new name:", project.name)
        if st.button("Save", key="project_name"):
            data_repo.update_project(uuid=project_uuid, name=new_name)
            # presumably reruns the app so the new name shows up — confirm in utils.state_refresh
            refresh_app()
    project_settings = data_repo.get_project_setting(project_uuid)

    frame_sizes = ["512x512", "768x512", "512x768", "512x896", "896x512", "512x1024", "1024x512"]
    current_size = f"{project_settings.width}x{project_settings.height}"
    # pre-select the stored size; fall back to the first preset when the stored
    # size is not one of the presets (e.g. a custom size)
    current_index = frame_sizes.index(current_size) if current_size in frame_sizes else 0

    # --- frame size -------------------------------------------------------
    with st.expander("🖼️ Frame Size", expanded=True):

        v1, v2, v3 = st.columns([4, 4, 2])
        with v1:
            st.write("Current Size = ", project_settings.width, "x", project_settings.height)

            custom_frame_size = st.checkbox("Enter custom frame size", value=False)
            # err stays False unless the custom inputs fail integer conversion
            err = False
            if not custom_frame_size:
                frame_size = st.radio(
                    "Select frame size:",
                    options=frame_sizes,
                    index=current_index,
                    key="frame_size",
                    horizontal=True,
                )
                # presets are "<width>x<height>" strings
                width, height = map(int, frame_size.split("x"))
            else:
                st.info(
                    "This is an experimental feature. There might be some issues - particularly with image generation."
                )
                width = st.text_input("Width", value=512)
                height = st.text_input("Height", value=512)
                try:
                    # the values come from free-text inputs, so convert them explicitly
                    width, height = int(width), int(height)
                    err = False
                except Exception as e:
                    st.error("Please input integer values")
                    err = True

            if not err:
                # solid-colour placeholder with the chosen dimensions, shown at
                # a fixed display width of 70 as a preview of the aspect ratio
                img = Image.new("RGB", (width, height), color=(73, 109, 137))
                st.image(img, width=70)

                if st.button("Save"):
                    # the success message is shown (and briefly held) before
                    # the two settings are written
                    st.success("Frame size updated successfully")
                    time.sleep(0.3)
                    data_repo.update_project_setting(project_uuid, width=width)
                    data_repo.update_project_setting(project_uuid, height=height)
                    refresh_app()

    # --- delete -----------------------------------------------------------
    st.write("")
    st.write("")
    st.write("")
    # the delete button stays disabled until this box is ticked
    delete_proj = st.checkbox("I confirm to delete this project entirely", value=False)
    if st.button("Delete Project", disabled=(not delete_proj)):
        project_list = data_repo.get_all_project_list(user_id=get_current_user_uuid())
        # refuse to delete when this is the user's only project
        if project_list and len(project_list) > 1:
            # the project is flagged as disabled rather than removed
            data_repo.update_project(uuid=project_uuid, is_disabled=True)
            st.success("Project deleted successfully")
            # NOTE(review): presumably resets the project selector to its first entry — confirm against the reader of this key
            st.session_state["index_of_project_name"] = 0
        else:
            st.error("You can't delete the only available project")

        # keep the success/error message visible briefly before refreshing
        time.sleep(0.7)
        refresh_app()
"class_type": "IPAdapterModelLoader", 25 | "_meta": { 26 | "title": "Load IPAdapter Model" 27 | } 28 | }, 29 | "4": { 30 | "inputs": { 31 | "clip_name": "SD1.5/pytorch_model.bin" 32 | }, 33 | "class_type": "CLIPVisionLoader", 34 | "_meta": { 35 | "title": "Load CLIP Vision" 36 | } 37 | }, 38 | "6": { 39 | "inputs": { 40 | "image": "Hulk_Hogan.jpg", 41 | "upload": "image" 42 | }, 43 | "class_type": "LoadImage", 44 | "_meta": { 45 | "title": "Load Image" 46 | } 47 | }, 48 | "7": { 49 | "inputs": { 50 | "text": "hulk hogan", 51 | "clip": [ 52 | "1", 53 | 1 54 | ] 55 | }, 56 | "class_type": "CLIPTextEncode", 57 | "_meta": { 58 | "title": "CLIP Text Encode (Prompt)" 59 | } 60 | }, 61 | "8": { 62 | "inputs": { 63 | "text": "blurry, photo, malformed", 64 | "clip": [ 65 | "1", 66 | 1 67 | ] 68 | }, 69 | "class_type": "CLIPTextEncode", 70 | "_meta": { 71 | "title": "CLIP Text Encode (Prompt)" 72 | } 73 | }, 74 | "9": { 75 | "inputs": { 76 | "seed": 16, 77 | "steps": 30, 78 | "cfg": 5, 79 | "sampler_name": "dpmpp_2m_sde", 80 | "scheduler": "exponential", 81 | "denoise": 1, 82 | "model": [ 83 | "28", 84 | 0 85 | ], 86 | "positive": [ 87 | "7", 88 | 0 89 | ], 90 | "negative": [ 91 | "8", 92 | 0 93 | ], 94 | "latent_image": [ 95 | "10", 96 | 0 97 | ] 98 | }, 99 | "class_type": "KSampler", 100 | "_meta": { 101 | "title": "KSampler" 102 | } 103 | }, 104 | "10": { 105 | "inputs": { 106 | "width": 512, 107 | "height": 512, 108 | "batch_size": 1 109 | }, 110 | "class_type": "EmptyLatentImage", 111 | "_meta": { 112 | "title": "Empty Latent Image" 113 | } 114 | }, 115 | "11": { 116 | "inputs": { 117 | "samples": [ 118 | "9", 119 | 0 120 | ], 121 | "vae": [ 122 | "2", 123 | 0 124 | ] 125 | }, 126 | "class_type": "VAEDecode", 127 | "_meta": { 128 | "title": "VAE Decode" 129 | } 130 | }, 131 | "27": { 132 | "inputs": { 133 | "filename_prefix": "ComfyUI", 134 | "images": [ 135 | "11", 136 | 0 137 | ] 138 | }, 139 | "class_type": "SaveImage", 140 | "_meta": { 141 | "title": "Save Image" 
142 | } 143 | }, 144 | "28": { 145 | "inputs": { 146 | "weight": 1, 147 | "weight_type": "linear", 148 | "combine_embeds": "concat", 149 | "embeds_scaling": "V only", 150 | "start_at": 0, 151 | "end_at": 1, 152 | "ipadapter": [ 153 | "3", 154 | 0 155 | ], 156 | "clip_vision": [ 157 | "4", 158 | 0 159 | ], 160 | "image": [ 161 | "6", 162 | 0 163 | ], 164 | "model": [ 165 | "1", 166 | 0 167 | ] 168 | }, 169 | "class_type": "IPAdapterAdvanced", 170 | "_meta": { 171 | "title": "IPAdapter Advanced" 172 | } 173 | } 174 | } -------------------------------------------------------------------------------- /scripts/config.toml: -------------------------------------------------------------------------------- 1 | [file_hash] 2 | "sd_xl_base_1.0.safetensors" = { "location" = "models/checkpoints/", "hash" = ["cf9d29192ef7433096a69fe5e32efba1"]} 3 | "sd_xl_refiner_1.0.safetensors" = { "location" = "models/checkpoints/", "hash" = ["c0df90f318abf30fcd2cd233ea638251"]} 4 | "Realistic_Vision_V5.1.safetensors" = { "location" = "models/checkpoints/", "hash" = ["03c000bbbbfede673b334897648431ba"]} 5 | "dreamshaper_8.safetensors" = { "location" = "models/checkpoints/", "hash" = ["1bfe9fce2c2c1498bcb6fdac2e44aab5"]} 6 | "sd_xl_refiner_1.0_0.9vae.safetensors" = { "location" = "models/checkpoints/", "hash" = ["74d31c96097471d5d3bd93a915ad0b30"]} 7 | "pytorch_model.bin" = { "location" = "", "hash" = ["0a0900180245762f122a681839391a9b", "ac277e29ea6d70799072d1ed52ac0f79"]} 8 | "v3_sd15_sparsectrl_rgb.ckpt" = { "location" = "models/controlnet/SD1.5/animatediff/", "hash" = ["b7d6b72ee2ba2866d167277598f27340"]} 9 | "inpainting_diffusion_pytorch_model.fp16.safetensors" = { "location" = "models/unet/inpainting_diffusion_pytorch_model.fp16.safetensors/", "hash" = ["9c688d78a75c1d3a0b33afffdd7e72f2"]} 10 | "vae-ft-mse-840000-ema-pruned.safetensors" = { "location" = "models/vae/", "hash" = ["418949762c3f321f2927e590e255f63c"]} 11 | "ip_plus_composition_sd15.safetensors" = { "location" = 
"models/ipadapter/", "hash" = ["029d345e257b72bc7be2fdf0a2295eba"]} 12 | "ip-adapter_sdxl.safetensors" = { "location" = "models/ipadapter/", "hash" = ["67b7f27bbf03d6b6921e520779abc212"]} 13 | "ip-adapter-plus_sd15.bin" = { "location" = "models/ipadapter/", "hash" = ["a50be6ac20883c4969c7bb60d5f2a46b"]} 14 | "v3_sd15_mm.ckpt" = { "location" = "models/animatediff_models/", "hash" = ["ac855686ee49fb5436c881c50a9b59ca"]} 15 | "AnimateLCM_sd15_t2v.ckpt" = { "location" = "models/animatediff_models/", "hash" = ["b5426e509e70b1b1b13ca03f877bac2a"]} 16 | "AnimateLCM_sd15_t2v_lora.safetensors" = { "location" = "models/loras/", "hash" = ["fb6fa0bf09bcfbe3d002ce0bfdfa7770"]} 17 | 18 | 19 | [node_version] 20 | "ComfyUI-Manager" = { "url" = "https://github.com/ltdrdata/ComfyUI-Manager", "commit_hash" = "7e777c5460754f4ef7a6f1ad9ef9356b89f66339"} 21 | "ComfyUI_IPAdapter_plus" = { "url" = "https://github.com/cubiq/ComfyUI_IPAdapter_plus", "commit_hash" = "b188a6cb39b512a9c6da7235b880af42c78ccd0d"} 22 | "comfyui-various" = { "url" = "https://github.com/jamesWalker55/comfyui-various", "commit_hash" = "cc66b62c0861314a4952eb96a6ae330f180bf6a1"} 23 | "comfy_PoP" = { "url" = "https://github.com/picturesonpictures/comfy_PoP", "commit_hash" = "db66c9777c6ab45f9c7056f1d8f365b0d4f2b339"} 24 | "ComfyUI-Frame-Interpolation" = { "url" = "https://github.com/Fannovel16/ComfyUI-Frame-Interpolation", "commit_hash" = "483dfe64465369e077d351ed2f1acbf7dc046864"} 25 | "efficiency-nodes-comfyui" = { "url" = "https://github.com/jags111/efficiency-nodes-comfyui", "commit_hash" = "3ead4afd120833f3bffdefeca0d6545df8051798"} 26 | "ComfyUI_FizzNodes" = { "url" = "https://github.com/FizzleDorf/ComfyUI_FizzNodes", "commit_hash" = "974b41cfdfde4f84d97712234d6e502f7da831fa"} 27 | "ComfyUI-Advanced-ControlNet" = { "url" = "https://github.com/Kosinkadink/ComfyUI-Advanced-ControlNet", "commit_hash" = "74d0c56ab3ba69663281390cc1b2072107939f96"} 28 | "ComfyUI-AnimateDiff-Evolved" = { "url" = 
def query_logger_page():
    """Render the Streamlit "Inference log" page.

    When GPU inference is not enabled in the config, the page first shows the
    user's credit balance together with controls to refresh it and to generate
    a payment link. Below that, a paginated table of inference logs is shown.
    """
    st.markdown("##### Inference log")

    data_repo = DataRepo()
    api_repo = APIRepo()
    config_manager = ConfigManager()
    gpu_enabled = config_manager.get(GPU_INFERENCE_ENABLED_KEY, False)

    # credit UI is only rendered when GPU inference is not enabled
    if not gpu_enabled:
        credits_remaining = api_repo.get_user_credits()

        c01, c02, _ = st.columns([1, 1, 2])

        credits_remaining = round(credits_remaining, 3)
        with c01:
            st.write(f"### Credit Balance: {credits_remaining}")

        with c02:
            if st.button("Refresh Credits"):
                # NOTE(review): presumably the cached value read by get_user_credits — confirm in APIRepo
                st.session_state["user_credit_data"] = None
                refresh_app()

        c1, c2, _ = st.columns([1, 1, 3])
        with c1:
            credits_to_buy = st.number_input(
                label="Credits to Buy (10 credits = $1)",
                key="credit_btn",
                min_value=50,
                step=20,
            )

        with c2:
            # two empty writes push the button down, presumably to line it up with the input
            st.write("")
            st.write("")
            if st.button("Generate payment link"):
                # the credit count is floor-divided by 10 before being sent
                # (the input label states 10 credits = $1)
                payment_link = api_repo.generate_payment_link(int(credits_to_buy // 10))
                if payment_link:
                    st.write("Please click on the link below to make the payment")
                    st.write(payment_link)
                else:
                    st.write("error occured during payment link generation, pls try again")

    b1, b2 = st.columns([1, 0.2])

    # page count remembered from a previous run, else the configured default
    total_log_table_pages = (
        st.session_state["total_log_table_pages"]
        if "total_log_table_pages" in st.session_state
        else DefaultTimingStyleParams.total_log_table_pages
    )
    # pages are 1-based
    list_of_pages = [i for i in range(1, total_log_table_pages + 1)]
    page_number = b1.radio(
        "Select page:", options=list_of_pages, key="inference_log_page_number", index=0, horizontal=True
    )
    # page_number = b1.number_input('Page number', min_value=1, max_value=total_log_table_pages, value=1, step=1)
    inference_log_list, total_page_count = data_repo.get_all_inference_log_list(
        page=page_number, data_per_page=100
    )

    # the page selector above was built from a possibly stale count: store the
    # count just returned and refresh so the selector is rebuilt with it
    if total_log_table_pages != total_page_count:
        st.session_state["total_log_table_pages"] = total_page_count
        refresh_app()

    # column name -> list of cell values, filled row by row below
    data = {
        "Project": [],
        "Prompt": [],
        "Model": [],
        "Inference time (sec)": [],
        "Credits": [],
        "Status": [],
    }

    # if SERVER != ServerType.DEVELOPMENT.value:
    #     data[] = []

    if inference_log_list and len(inference_log_list):
        for log in inference_log_list:
            data["Project"].append(log.project.name)
            # input_params is parsed as JSON; a missing value or missing
            # "prompt" key yields an empty prompt cell
            prompt = json.loads(log.input_params).get("prompt", "") if log.input_params else ""
            data["Prompt"].append(prompt)
            model_name = log.model_name
            data["Model"].append(model_name)
            data["Inference time (sec)"].append(round(log.total_inference_time, 3))
            data["Credits"].append(round(log.credits_used, 3))
            data["Status"].append(log.status)

        st.table(data=data)
        st.markdown("***")
    else:
        st.info("No logs present")
"ConditioningSetTimestepRange", 80 | "_meta": { 81 | "title": "ConditioningSetTimestepRange" 82 | } 83 | }, 84 | "71": { 85 | "inputs": { 86 | "text": "bad quality, poor quality, doll, disfigured, jpg, toy, bad anatomy, missing limbs, missing fingers, 3d, cgi", 87 | "clip": [ 88 | "252", 89 | 1 90 | ] 91 | }, 92 | "class_type": "CLIPTextEncode", 93 | "_meta": { 94 | "title": "CLIP Text Encode (Negative Prompt)" 95 | } 96 | }, 97 | "135": { 98 | "inputs": { 99 | "width": 1024, 100 | "height": 768, 101 | "batch_size": 1 102 | }, 103 | "class_type": "EmptySD3LatentImage", 104 | "_meta": { 105 | "title": "EmptySD3LatentImage" 106 | } 107 | }, 108 | "231": { 109 | "inputs": { 110 | "samples": [ 111 | "271", 112 | 0 113 | ], 114 | "vae": [ 115 | "252", 116 | 2 117 | ] 118 | }, 119 | "class_type": "VAEDecode", 120 | "_meta": { 121 | "title": "VAE Decode" 122 | } 123 | }, 124 | "233": { 125 | "inputs": { 126 | "filename_prefix": "SD3", 127 | "images": [ 128 | "231", 129 | 0 130 | ] 131 | }, 132 | "class_type": "SaveImage", 133 | "_meta": { 134 | "title": "Save Image" 135 | } 136 | }, 137 | "252": { 138 | "inputs": { 139 | "ckpt_name": "sd3_medium_incl_clips_t5xxlfp16.safetensors" 140 | }, 141 | "class_type": "CheckpointLoaderSimple", 142 | "_meta": { 143 | "title": "Load Checkpoint" 144 | } 145 | }, 146 | "271": { 147 | "inputs": { 148 | "seed": 945512652412924, 149 | "steps": 28, 150 | "cfg": 4.5, 151 | "sampler_name": "dpmpp_2m", 152 | "scheduler": "sgm_uniform", 153 | "denoise": 1, 154 | "model": [ 155 | "13", 156 | 0 157 | ], 158 | "positive": [ 159 | "6", 160 | 0 161 | ], 162 | "negative": [ 163 | "69", 164 | 0 165 | ], 166 | "latent_image": [ 167 | "135", 168 | 0 169 | ] 170 | }, 171 | "class_type": "KSampler", 172 | "_meta": { 173 | "title": "KSampler" 174 | } 175 | } 176 | } -------------------------------------------------------------------------------- /utils/ml_processor/comfy_workflows/ipadapter_plus_api.json: 
-------------------------------------------------------------------------------- 1 | { 2 | "3": { 3 | "inputs": { 4 | "seed": 641608455784125, 5 | "steps": 24, 6 | "cfg": 9.25, 7 | "sampler_name": "ddim", 8 | "scheduler": "normal", 9 | "denoise": 1, 10 | "model": [ 11 | "27", 12 | 0 13 | ], 14 | "positive": [ 15 | "6", 16 | 0 17 | ], 18 | "negative": [ 19 | "7", 20 | 0 21 | ], 22 | "latent_image": [ 23 | "5", 24 | 0 25 | ] 26 | }, 27 | "class_type": "KSampler", 28 | "_meta": { 29 | "title": "KSampler" 30 | } 31 | }, 32 | "4": { 33 | "inputs": { 34 | "ckpt_name": "sd_xl_base_1.0.safetensors" 35 | }, 36 | "class_type": "CheckpointLoaderSimple", 37 | "_meta": { 38 | "title": "Load Checkpoint" 39 | } 40 | }, 41 | "5": { 42 | "inputs": { 43 | "width": 1024, 44 | "height": 1024, 45 | "batch_size": 1 46 | }, 47 | "class_type": "EmptyLatentImage", 48 | "_meta": { 49 | "title": "Empty Latent Image" 50 | } 51 | }, 52 | "6": { 53 | "inputs": { 54 | "text": "", 55 | "clip": [ 56 | "4", 57 | 1 58 | ] 59 | }, 60 | "class_type": "CLIPTextEncode", 61 | "_meta": { 62 | "title": "CLIP Text Encode (Prompt)" 63 | } 64 | }, 65 | "7": { 66 | "inputs": { 67 | "text": "", 68 | "clip": [ 69 | "4", 70 | 1 71 | ] 72 | }, 73 | "class_type": "CLIPTextEncode", 74 | "_meta": { 75 | "title": "CLIP Text Encode (Prompt)" 76 | } 77 | }, 78 | "8": { 79 | "inputs": { 80 | "samples": [ 81 | "3", 82 | 0 83 | ], 84 | "vae": [ 85 | "4", 86 | 2 87 | ] 88 | }, 89 | "class_type": "VAEDecode", 90 | "_meta": { 91 | "title": "VAE Decode" 92 | } 93 | }, 94 | "23": { 95 | "inputs": { 96 | "clip_name": "SDXL/pytorch_model.bin" 97 | }, 98 | "class_type": "CLIPVisionLoader", 99 | "_meta": { 100 | "title": "Load CLIP Vision" 101 | } 102 | }, 103 | "26": { 104 | "inputs": { 105 | "ipadapter_file": "ip-adapter_sdxl.safetensors" 106 | }, 107 | "class_type": "IPAdapterModelLoader", 108 | "_meta": { 109 | "title": "Load IPAdapter Model" 110 | } 111 | }, 112 | "27": { 113 | "inputs": { 114 | "weight": 1, 115 | 
"weight_type": "linear", 116 | "combine_embeds": "concat", 117 | "embeds_scaling": "V only", 118 | "start_at": 0, 119 | "end_at": 1, 120 | "ipadapter": [ 121 | "26", 122 | 0 123 | ], 124 | "clip_vision": [ 125 | "23", 126 | 0 127 | ], 128 | "image": [ 129 | "39", 130 | 0 131 | ], 132 | "model": [ 133 | "4", 134 | 0 135 | ] 136 | }, 137 | "class_type": "IPAdapterAdvanced", 138 | "_meta": { 139 | "title": "IPAdapter Advanced" 140 | } 141 | }, 142 | "28": { 143 | "inputs": { 144 | "image": "714d97a3fe2dcf645f1b500d523d4d3c848acc65bfda3602c56305fc.jpg", 145 | "upload": "image" 146 | }, 147 | "class_type": "LoadImage", 148 | "_meta": { 149 | "title": "Load Image" 150 | } 151 | }, 152 | "29": { 153 | "inputs": { 154 | "filename_prefix": "ComfyUI", 155 | "images": [ 156 | "8", 157 | 0 158 | ] 159 | }, 160 | "class_type": "SaveImage", 161 | "_meta": { 162 | "title": "Save Image" 163 | } 164 | }, 165 | "39": { 166 | "inputs": { 167 | "interpolation": "LANCZOS", 168 | "crop_position": "top", 169 | "sharpening": 0, 170 | "image": [ 171 | "28", 172 | 0 173 | ] 174 | }, 175 | "class_type": "PrepImageForClipVision", 176 | "_meta": { 177 | "title": "Prepare Image For Clip Vision" 178 | } 179 | } 180 | } -------------------------------------------------------------------------------- /LICENSE.txt: -------------------------------------------------------------------------------- 1 | Open Source Native License (OSNL) 2 | Version 0.1 - March 1, 2024 3 | 4 | Preamble 5 | 6 | The Open Source Native License (OSNL) is designed to ensure that software remains free and open, fostering innovation and knowledge sharing within the community. It grants individuals, researchers, and commercial entities who open source their primary business assets, the freedom to use the software in any manner they choose. 7 | 8 | This distinctive approach aims to balance the benefits of open-source development with the realities of commercial enterprise. 
It ensures that software remains a shared, community-driven resource while enabling businesses to thrive in an open-source ecosystem. Additional licenses are available for non-open source commercial entities. 9 | 10 | 1. Definitions 11 | 12 | - "This License" refers to Version 1.0 of the Open Source Native License. 13 | - "The Program" refers to the software distributed under this License. 14 | - "You" refers to the individual or entity utilizing or contributing to the Program. 15 | - "Primary Business Assets" are the core resources, capabilities, and technology that constitute the main value proposition and operational basis of your business. 16 | 17 | 2. Grant of License 18 | 19 | Subject to the terms and conditions of this License, you are hereby granted a free, perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable license to use, reproduce, modify, distribute, and sublicense the Program, provided you comply with the following condition: 20 | 21 | - Individual or researcher: You are granted the rights to use, modify, distribute, and contribute to the Program for any purpose, including educational, research, and personal projects, without the necessity to make your personal projects open source, provided these activities do not constitute a commercial enterprise. For any use that transitions to commercial purposes, the conditions applicable to commercial entities as outlined in this License will then apply. 22 | - Commercial Entity who meets open source condition: Your primary business assets, including all core technologies, software, and platforms, must be available under an OSI-approved open source license or OSNL. This condition does not apply to ancillary or peripheral services not constituting primary business assets. 
23 | 24 | 2.1 Commercial Use by Non-Open Source Businesses 25 | 26 | Non-open source businesses that wish to utilize the Program or its derivatives as a component of their products or services are required to obtain an additional license. These entities must proactively contact Banodoco to request such a license. Banodoco reserves the right, at its own discretion, to grant or deny this additional license. Until an additional license is granted by Banodoco, non-open source businesses are not authorized to exercise any rights provided under this License regarding the use of the Program or its derivatives. 27 | 28 | 3. Redistribution 29 | 30 | You may reproduce and distribute copies of the Program or derivative works thereof in any medium, with or without modifications, provided that you meet the following conditions: 31 | 32 | - You must give any recipients of the Program a copy of this License. 33 | - You must ensure that any modified files carry prominent notices stating that you changed the files. 34 | - You must disclose the source of the Program, and if you distribute any portion of it in a compiled or object code form, you must also provide the full source code under this License. 35 | - Any distribution of the Program or derivative works must comply with the Primary Business Open Source Condition. 36 | 37 | 4. Disclaimer of Warranty 38 | 39 | THE PROGRAM IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES, OR OTHER LIABILITY ARISING FROM THE USE OF THE PROGRAM. 40 | 41 | 5. General 42 | 43 | This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Program. 
def new_project_page():
    """Render the "New Project" page and create a project when submitted.

    The page shows a project-name input, a frame-size selector with a
    solid-colour preview of the chosen dimensions, and a "Create New Project"
    button. On submit it creates the project, adds a frame to the returned
    shot, deletes the frame fetched at frame number 0 of shot number 1,
    points the session state at the new project and refreshes the app.

    Names that are empty or contain only whitespace are rejected; surrounding
    whitespace is stripped before the project is created.
    """
    # Initialize data repository
    data_repo = DataRepo()

    # title
    st.markdown("#### New Project")
    st.markdown("***")
    # Define multicolumn layout
    project_column, _ = st.columns([1, 3])

    # Prompt user for project naming within project_column
    with project_column:
        new_project_name = st.text_input("Project name:", value="")

    # Prompt user for video dimension specifications
    v1, v2, _ = st.columns([6, 3, 9])

    frame_sizes = ["512x512", "768x512", "512x768", "512x896", "896x512", "512x1024", "1024x512"]
    with v1:
        frame_size = st.radio("Select frame size:", options=frame_sizes, key="frame_size", horizontal=True)
        width, height = map(int, frame_size.split("x"))

    with v2:
        # Warn when either side is above 769px; of the presets above that
        # means 896 and 1024 (768 does not trigger the warning).
        if width > 769 or height > 769:
            st.warning("There may be issues with very wide or high frames.")
        # Solid-colour placeholder so the user can see the aspect ratio.
        img = Image.new("RGB", (width, height), color=(73, 109, 137))
        st.image(img, use_column_width=True)

    # NOTE: attaching an audio file at creation time is currently disabled
    # (the upload widget and the save_audio_file call were commented out).

    st.write("")

    if st.button("Create New Project"):
        # A name made only of spaces is as unusable as an empty one.
        project_name = new_project_name.strip()
        if not project_name:
            st.error("Please enter a project name.")
        else:
            current_user = data_repo.get_first_active_user()
            new_project, shot = create_new_project(current_user, project_name, width, height)
            create_frame_inside_shot(shot.uuid, 0)
            # removing the initial frame which moved to the 1st position
            # (since creating new project also creates a frame)
            shot = data_repo.get_shot_from_number(new_project.uuid, 1)
            initial_frame = data_repo.get_timing_from_frame_number(shot.uuid, 0)
            data_repo.delete_timing_from_uuid(initial_frame.uuid)

            reset_project_state()

            st.session_state["project_uuid"] = new_project.uuid
            project_list = data_repo.get_all_project_list(user_id=get_current_user_uuid())
            # presumably the new project is the last entry of the list — verify
            # against get_all_project_list's ordering
            st.session_state["index_of_project_name"] = len(project_list) - 1
            st.session_state["main_view_type"] = "Creative Process"
            st.session_state["app_settings"] = 0
            st.success("Project created successfully!")
            time.sleep(1)
            refresh_app()

    st.markdown("***")
3 | "inputs": { 4 | "frame_rate": 12, 5 | "loop_count": 0, 6 | "filename_prefix": "AnimateDiff", 7 | "format": "video/h264-mp4", 8 | "pix_fmt": "yuv420p", 9 | "crf": 19, 10 | "save_metadata": true, 11 | "pingpong": false, 12 | "save_output": true, 13 | "images": [ 14 | "34", 15 | 0 16 | ] 17 | }, 18 | "class_type": "VHS_VideoCombine", 19 | "_meta": { 20 | "title": "Video Combine 🎥🅥🅗🅢" 21 | } 22 | }, 23 | "11": { 24 | "inputs": { 25 | "ckpt_name": "dynamicrafter_512_interp_v1.ckpt", 26 | "dtype": "auto", 27 | "fp8_unet": false 28 | }, 29 | "class_type": "DynamiCrafterModelLoader", 30 | "_meta": { 31 | "title": "DynamiCrafterModelLoader" 32 | } 33 | }, 34 | "12": { 35 | "inputs": { 36 | "steps": 50, 37 | "cfg": 5, 38 | "eta": 1, 39 | "frames": 16, 40 | "prompt": "dolly zoom out", 41 | "seed": 262623773159722, 42 | "fs": 10, 43 | "keep_model_loaded": true, 44 | "vae_dtype": "auto", 45 | "cut_near_keyframes": 0, 46 | "model": [ 47 | "11", 48 | 0 49 | ], 50 | "images": [ 51 | "15", 52 | 0 53 | ] 54 | }, 55 | "class_type": "DynamiCrafterBatchInterpolation", 56 | "_meta": { 57 | "title": "DynamiCrafterBatchInterpolation" 58 | } 59 | }, 60 | "15": { 61 | "inputs": { 62 | "image1": [ 63 | "37", 64 | 0 65 | ], 66 | "image2": [ 67 | "38", 68 | 0 69 | ] 70 | }, 71 | "class_type": "ImageBatch", 72 | "_meta": { 73 | "title": "Batch Images" 74 | } 75 | }, 76 | "16": { 77 | "inputs": { 78 | "image": "ea47a572b4e5b52ea7da22384232381b3e62048fa715f042b38b4da9 (1) (2).jpg", 79 | "upload": "image" 80 | }, 81 | "class_type": "LoadImage", 82 | "_meta": { 83 | "title": "Load Image" 84 | } 85 | }, 86 | "17": { 87 | "inputs": { 88 | "image": "2193d9ded46130b41d09133b4b1d2502f0eaa19ea1762252c6581e86 (1) (1).jpg", 89 | "upload": "image" 90 | }, 91 | "class_type": "LoadImage", 92 | "_meta": { 93 | "title": "Load Image" 94 | } 95 | }, 96 | "34": { 97 | "inputs": { 98 | "ckpt_name": "film_net_fp32.pt", 99 | "clear_cache_after_n_frames": 10, 100 | "multiplier": 3, 101 | "frames": [ 102 | "12", 
103 | 0 104 | ] 105 | }, 106 | "class_type": "FILM VFI", 107 | "_meta": { 108 | "title": "FILM VFI" 109 | } 110 | }, 111 | "35": { 112 | "inputs": { 113 | "frame_rate": 8, 114 | "loop_count": 0, 115 | "filename_prefix": "AnimateDiff", 116 | "format": "image/gif", 117 | "pingpong": false, 118 | "save_output": true, 119 | "images": [ 120 | "12", 121 | 0 122 | ] 123 | }, 124 | "class_type": "VHS_VideoCombine", 125 | "_meta": { 126 | "title": "Video Combine 🎥🅥🅗🅢" 127 | } 128 | }, 129 | "37": { 130 | "inputs": { 131 | "mode": "rescale", 132 | "supersample": "true", 133 | "resampling": "lanczos", 134 | "rescale_factor": 0.7000000000000001, 135 | "resize_width": 1024, 136 | "resize_height": 1536, 137 | "image": [ 138 | "16", 139 | 0 140 | ] 141 | }, 142 | "class_type": "Image Resize", 143 | "_meta": { 144 | "title": "Image Resize" 145 | } 146 | }, 147 | "38": { 148 | "inputs": { 149 | "mode": "rescale", 150 | "supersample": "true", 151 | "resampling": "lanczos", 152 | "rescale_factor": 0.7000000000000001, 153 | "resize_width": 1024, 154 | "resize_height": 1536, 155 | "image": [ 156 | "17", 157 | 0 158 | ] 159 | }, 160 | "class_type": "Image Resize", 161 | "_meta": { 162 | "title": "Image Resize" 163 | } 164 | } 165 | } -------------------------------------------------------------------------------- /utils/ml_processor/comfy_workflows/sdxl_controlnet_workflow_api.json: -------------------------------------------------------------------------------- 1 | { 2 | "3": { 3 | "inputs": { 4 | "seed": 741148140596738, 5 | "steps": 20, 6 | "cfg": 8, 7 | "sampler_name": "euler", 8 | "scheduler": "normal", 9 | "denoise": 1, 10 | "model": [ 11 | "4", 12 | 0 13 | ], 14 | "positive": [ 15 | "14", 16 | 0 17 | ], 18 | "negative": [ 19 | "7", 20 | 0 21 | ], 22 | "latent_image": [ 23 | "5", 24 | 0 25 | ] 26 | }, 27 | "class_type": "KSampler", 28 | "_meta": { 29 | "title": "KSampler" 30 | } 31 | }, 32 | "4": { 33 | "inputs": { 34 | "ckpt_name": "sd_xl_base_1.0.safetensors" 35 | }, 36 | 
"class_type": "CheckpointLoaderSimple", 37 | "_meta": { 38 | "title": "Load Checkpoint" 39 | } 40 | }, 41 | "5": { 42 | "inputs": { 43 | "width": 1024, 44 | "height": 1024, 45 | "batch_size": 1 46 | }, 47 | "class_type": "EmptyLatentImage", 48 | "_meta": { 49 | "title": "Empty Latent Image" 50 | } 51 | }, 52 | "6": { 53 | "inputs": { 54 | "text": "a person standing in an open field", 55 | "clip": [ 56 | "4", 57 | 1 58 | ] 59 | }, 60 | "class_type": "CLIPTextEncode", 61 | "_meta": { 62 | "title": "CLIP Text Encode (Prompt)" 63 | } 64 | }, 65 | "7": { 66 | "inputs": { 67 | "text": "text, watermark", 68 | "clip": [ 69 | "4", 70 | 1 71 | ] 72 | }, 73 | "class_type": "CLIPTextEncode", 74 | "_meta": { 75 | "title": "CLIP Text Encode (Prompt)" 76 | } 77 | }, 78 | "8": { 79 | "inputs": { 80 | "samples": [ 81 | "3", 82 | 0 83 | ], 84 | "vae": [ 85 | "4", 86 | 2 87 | ] 88 | }, 89 | "class_type": "VAEDecode", 90 | "_meta": { 91 | "title": "VAE Decode" 92 | } 93 | }, 94 | "9": { 95 | "inputs": { 96 | "filename_prefix": "ComfyUI", 97 | "images": [ 98 | "8", 99 | 0 100 | ] 101 | }, 102 | "class_type": "SaveImage", 103 | "_meta": { 104 | "title": "Save Image" 105 | } 106 | }, 107 | "12": { 108 | "inputs": { 109 | "low_threshold": 0.19999999999999984, 110 | "high_threshold": 0.7, 111 | "image": [ 112 | "17", 113 | 0 114 | ] 115 | }, 116 | "class_type": "Canny", 117 | "_meta": { 118 | "title": "Canny" 119 | } 120 | }, 121 | "13": { 122 | "inputs": { 123 | "image": "boy_sunshine.png", 124 | "upload": "image" 125 | }, 126 | "class_type": "LoadImage", 127 | "_meta": { 128 | "title": "Load Image" 129 | } 130 | }, 131 | "14": { 132 | "inputs": { 133 | "strength": 1, 134 | "conditioning": [ 135 | "6", 136 | 0 137 | ], 138 | "control_net": [ 139 | "22", 140 | 0 141 | ], 142 | "image": [ 143 | "12", 144 | 0 145 | ] 146 | }, 147 | "class_type": "ControlNetApply", 148 | "_meta": { 149 | "title": "Apply ControlNet" 150 | } 151 | }, 152 | "17": { 153 | "inputs": { 154 | "upscale_method": 
"nearest-exact", 155 | "width": 1024, 156 | "height": 1024, 157 | "crop": "center", 158 | "image": [ 159 | "13", 160 | 0 161 | ] 162 | }, 163 | "class_type": "ImageScale", 164 | "_meta": { 165 | "title": "Upscale Image" 166 | } 167 | }, 168 | "22": { 169 | "inputs": { 170 | "control_net_name": "canny_diffusion_pytorch_model.safetensors" 171 | }, 172 | "class_type": "ControlNetLoader", 173 | "_meta": { 174 | "title": "Load ControlNet Model" 175 | } 176 | }, 177 | "25": { 178 | "inputs": { 179 | "images": [ 180 | "12", 181 | 0 182 | ] 183 | }, 184 | "class_type": "PreviewImage", 185 | "_meta": { 186 | "title": "Preview Image" 187 | } 188 | } 189 | } -------------------------------------------------------------------------------- /utils/ml_processor/comfy_workflows/sdxl_workflow_api.json: -------------------------------------------------------------------------------- 1 | { 2 | "4": { 3 | "inputs": { 4 | "ckpt_name": "sd_xl_base_1.0.safetensors" 5 | }, 6 | "class_type": "CheckpointLoaderSimple", 7 | "_meta": { 8 | "title": "Load Checkpoint - BASE" 9 | } 10 | }, 11 | "5": { 12 | "inputs": { 13 | "width": 1024, 14 | "height": 1024, 15 | "batch_size": 1 16 | }, 17 | "class_type": "EmptyLatentImage", 18 | "_meta": { 19 | "title": "Empty Latent Image" 20 | } 21 | }, 22 | "6": { 23 | "inputs": { 24 | "text": "evening sunset scenery blue sky nature, glass bottle with a galaxy in it", 25 | "clip": [ 26 | "4", 27 | 1 28 | ] 29 | }, 30 | "class_type": "CLIPTextEncode", 31 | "_meta": { 32 | "title": "CLIP Text Encode (Prompt)" 33 | } 34 | }, 35 | "7": { 36 | "inputs": { 37 | "text": "text, watermark", 38 | "clip": [ 39 | "4", 40 | 1 41 | ] 42 | }, 43 | "class_type": "CLIPTextEncode", 44 | "_meta": { 45 | "title": "CLIP Text Encode (Prompt)" 46 | } 47 | }, 48 | "10": { 49 | "inputs": { 50 | "add_noise": "enable", 51 | "noise_seed": 721897303308196, 52 | "steps": 25, 53 | "cfg": 8, 54 | "sampler_name": "euler", 55 | "scheduler": "normal", 56 | "start_at_step": 0, 57 | 
"end_at_step": 20, 58 | "return_with_leftover_noise": "enable", 59 | "model": [ 60 | "4", 61 | 0 62 | ], 63 | "positive": [ 64 | "6", 65 | 0 66 | ], 67 | "negative": [ 68 | "7", 69 | 0 70 | ], 71 | "latent_image": [ 72 | "5", 73 | 0 74 | ] 75 | }, 76 | "class_type": "KSamplerAdvanced", 77 | "_meta": { 78 | "title": "KSampler (Advanced) - BASE" 79 | } 80 | }, 81 | "11": { 82 | "inputs": { 83 | "add_noise": "disable", 84 | "noise_seed": 0, 85 | "steps": 25, 86 | "cfg": 8, 87 | "sampler_name": "euler", 88 | "scheduler": "normal", 89 | "start_at_step": 20, 90 | "end_at_step": 10000, 91 | "return_with_leftover_noise": "disable", 92 | "model": [ 93 | "12", 94 | 0 95 | ], 96 | "positive": [ 97 | "15", 98 | 0 99 | ], 100 | "negative": [ 101 | "16", 102 | 0 103 | ], 104 | "latent_image": [ 105 | "10", 106 | 0 107 | ] 108 | }, 109 | "class_type": "KSamplerAdvanced", 110 | "_meta": { 111 | "title": "KSampler (Advanced) - REFINER" 112 | } 113 | }, 114 | "12": { 115 | "inputs": { 116 | "ckpt_name": "sd_xl_refiner_1.0.safetensors" 117 | }, 118 | "class_type": "CheckpointLoaderSimple", 119 | "_meta": { 120 | "title": "Load Checkpoint - REFINER" 121 | } 122 | }, 123 | "15": { 124 | "inputs": { 125 | "text": "evening sunset scenery blue sky nature, glass bottle with a galaxy in it", 126 | "clip": [ 127 | "12", 128 | 1 129 | ] 130 | }, 131 | "class_type": "CLIPTextEncode", 132 | "_meta": { 133 | "title": "CLIP Text Encode (Prompt)" 134 | } 135 | }, 136 | "16": { 137 | "inputs": { 138 | "text": "text, watermark", 139 | "clip": [ 140 | "12", 141 | 1 142 | ] 143 | }, 144 | "class_type": "CLIPTextEncode", 145 | "_meta": { 146 | "title": "CLIP Text Encode (Prompt)" 147 | } 148 | }, 149 | "17": { 150 | "inputs": { 151 | "samples": [ 152 | "11", 153 | 0 154 | ], 155 | "vae": [ 156 | "12", 157 | 2 158 | ] 159 | }, 160 | "class_type": "VAEDecode", 161 | "_meta": { 162 | "title": "VAE Decode" 163 | } 164 | }, 165 | "19": { 166 | "inputs": { 167 | "filename_prefix": "ComfyUI", 168 | "images": 
def _download_progress(received_bytes, total_size, elapsed_time):
    """Return ``(fraction, label)`` describing the current download state.

    ``fraction`` is clamped to ``[0.0, 1.0]``; it stays at 0.0 when the server
    did not announce a size (``total_size`` of 0). The speed is reported as 0
    when no measurable time has elapsed yet, instead of dividing by zero.
    """
    mb = 1024 * 1024
    received_mb = received_bytes / mb
    download_speed = received_bytes / elapsed_time / mb if elapsed_time > 0 else 0.0

    if total_size > 0:
        # the received byte count can exceed the announced length (e.g. when
        # the body is decoded on the fly), so never report more than 100%
        fraction = min(received_bytes / total_size, 1.0)
        total_label = f"{total_size / mb:.2f} MB"
    else:
        fraction = 0.0
        total_label = "unknown size"

    label = f"Downloaded: {received_mb:.2f} MB / {total_label} | Speed: {download_speed:.2f} MB/sec"
    return fraction, label


@with_refresh_lock
def download_file_widget(url, filename, dest):
    """Download ``url`` into ``dest/filename`` while showing a progress bar.

    Args:
        url: Address of the file. URLs ending in ``.zip`` or ``.tar`` are
            extracted into ``dest`` and the archive is removed afterwards.
        filename: Name the downloaded file is stored under.
        dest: Directory to store the file in; created if missing.

    The path being written is kept in
    ``st.session_state["delete_partial_download"]`` until the download
    completes, so an interrupted run (including a click on "Cancel") leaves a
    marker that makes the next invocation delete the incomplete file.
    """
    save_directory = dest
    filepath = os.path.join(save_directory, filename)

    # ------- deleting partial downloads
    pending_delete = st.session_state.get("delete_partial_download", None)
    if pending_delete:
        st.session_state["delete_partial_download"] = None
        if os.path.exists(pending_delete):
            os.remove(pending_delete)
            st.info("Partial downloads deleted")
            time.sleep(0.3)
            refresh_app()

    # checking if the file already exists
    if os.path.exists(filepath):
        st.warning("File already present")
        time.sleep(1)
        refresh_app()
        # NOTE(review): refresh_app() presumably restarts the script; if it
        # ever returns, stop here so the existing file is not overwritten.
        return

    # setting this file for deletion, in case it's not downloaded properly
    # if it is downloaded properly then it will be removed from here (all these steps because of streamlit!)
    st.session_state["delete_partial_download"] = filepath

    with st.spinner("Downloading model..."):
        download_bar = st.progress(0, text="")
        os.makedirs(save_directory, exist_ok=True)  # Create the directory if it doesn't exist

        # Download the model and save it to the directory; the context
        # manager releases the connection even if the run is interrupted.
        with requests.get(url, stream=True) as response:
            # Clicking the button re-runs the script; the marker set here (and
            # above) makes that next run delete the partial file.
            if st.button("Cancel"):
                st.session_state["delete_partial_download"] = filepath

            if response.status_code == 200:
                # 0 when the server does not send a content-length header
                total_size = int(response.headers.get("content-length", 0))
                start_time = time.time()

                with open(filepath, "wb") as f:
                    received_bytes = 0
                    for data in response.iter_content(chunk_size=1048576):
                        f.write(data)
                        received_bytes += len(data)

                        fraction, label = _download_progress(
                            received_bytes, total_size, time.time() - start_time
                        )
                        download_bar.progress(fraction, text=label)

                st.success(f"Downloaded {filename} and saved to {save_directory}")
                time.sleep(1)
                download_bar.empty()

                if url.endswith((".zip", ".tar")):
                    st.success("Extracting the zip file. Please wait...")
                    # NOTE(review): extractall() trusts the member paths stored
                    # in the archive; only use this with trusted model sources.
                    if url.endswith(".zip"):
                        with zipfile.ZipFile(filepath, "r") as zip_ref:
                            zip_ref.extractall(save_directory)
                    else:
                        with tarfile.open(filepath, "r") as tar_ref:
                            tar_ref.extractall(save_directory)

                    os.remove(filepath)

                # download finished cleanly, nothing left to clean up
                st.session_state["delete_partial_download"] = None
            else:
                st.error("Unable to access model url")
                time.sleep(1)

    refresh_app()
| "4", 70 | 1 71 | ] 72 | }, 73 | "class_type": "CLIPTextEncode", 74 | "_meta": { 75 | "title": "CLIP Text Encode (Prompt)" 76 | } 77 | }, 78 | "8": { 79 | "inputs": { 80 | "samples": [ 81 | "3", 82 | 0 83 | ], 84 | "vae": [ 85 | "4", 86 | 2 87 | ] 88 | }, 89 | "class_type": "VAEDecode", 90 | "_meta": { 91 | "title": "VAE Decode" 92 | } 93 | }, 94 | "9": { 95 | "inputs": { 96 | "filename_prefix": "ComfyUI", 97 | "images": [ 98 | "8", 99 | 0 100 | ] 101 | }, 102 | "class_type": "SaveImage", 103 | "_meta": { 104 | "title": "Save Image" 105 | } 106 | }, 107 | "10": { 108 | "inputs": { 109 | "detect_hand": "enable", 110 | "detect_body": "enable", 111 | "detect_face": "enable", 112 | "resolution": "v1.1", 113 | "image": [ 114 | "11", 115 | 0 116 | ] 117 | }, 118 | "class_type": "OpenposePreprocessor", 119 | "_meta": { 120 | "title": "OpenPose Pose" 121 | } 122 | }, 123 | "11": { 124 | "inputs": { 125 | "upscale_method": "nearest-exact", 126 | "width": 623, 127 | "height": 623, 128 | "crop": "disabled", 129 | "image": [ 130 | "12", 131 | 0 132 | ] 133 | }, 134 | "class_type": "ImageScale", 135 | "_meta": { 136 | "title": "Upscale Image" 137 | } 138 | }, 139 | "12": { 140 | "inputs": { 141 | "image": "boy_sunshine.png", 142 | "upload": "image" 143 | }, 144 | "class_type": "LoadImage", 145 | "_meta": { 146 | "title": "Load Image" 147 | } 148 | }, 149 | "13": { 150 | "inputs": { 151 | "images": [ 152 | "10", 153 | 0 154 | ] 155 | }, 156 | "class_type": "PreviewImage", 157 | "_meta": { 158 | "title": "Preview Image" 159 | } 160 | }, 161 | "14": { 162 | "inputs": { 163 | "control_net_name": "OpenPoseXL2.safetensors" 164 | }, 165 | "class_type": "ControlNetLoader", 166 | "_meta": { 167 | "title": "Load ControlNet Model" 168 | } 169 | }, 170 | "16": { 171 | "inputs": { 172 | "strength": 1, 173 | "conditioning": [ 174 | "6", 175 | 0 176 | ], 177 | "control_net": [ 178 | "14", 179 | 0 180 | ], 181 | "image": [ 182 | "10", 183 | 0 184 | ] 185 | }, 186 | "class_type": 
"ControlNetApply", 187 | "_meta": { 188 | "title": "Apply ControlNet" 189 | } 190 | } 191 | } -------------------------------------------------------------------------------- /auto_refresh.py: -------------------------------------------------------------------------------- 1 | import os 2 | import platform 3 | import signal 4 | import sys 5 | import portalocker 6 | import json 7 | from flask import Flask, jsonify, request 8 | import logging 9 | import threading 10 | import time 11 | import random 12 | from queue import Queue 13 | 14 | import requests 15 | 16 | from utils.constants import REFRESH_LOCK_FILE, REFRESH_PROCESS_PORT, REFRESH_TARGET_FILE 17 | 18 | app = Flask(__name__) 19 | 20 | 21 | target_file = REFRESH_TARGET_FILE 22 | lock_file = REFRESH_LOCK_FILE 23 | refresh_queue = Queue() 24 | refresh_thread = None 25 | 26 | last_refreshed_on = 0 27 | REFRESH_BUFFER_TIME = 10 # seconds before making consecutive refreshes 28 | TERMINATE_SCRIPT = False 29 | 30 | 31 | def handle_termination(signal, frame): 32 | print("Received termination signal - auto refresh. Cleaning up...") 33 | global TERMINATE_SCRIPT 34 | TERMINATE_SCRIPT = True 35 | os._exit(1) 36 | 37 | 38 | if platform.system() == "Windows": 39 | signal.signal(signal.SIGINT, handle_termination) 40 | 41 | signal.signal(signal.SIGTERM, handle_termination) 42 | 43 | 44 | def check_lock(): 45 | if not os.path.exists(lock_file): 46 | return False 47 | 48 | try: 49 | with portalocker.Lock(lock_file, "r", timeout=0.1) as lock_file_handle: 50 | data = json.load(lock_file_handle) 51 | return data["status"] == "locked" 52 | except (portalocker.LockException, FileNotFoundError, json.JSONDecodeError): 53 | return False 54 | 55 | 56 | def refresh(): 57 | while check_lock(): 58 | # print("process locked.. 
sleeping") 59 | time.sleep(2) 60 | 61 | global last_refreshed_on 62 | while int(time.time()) - last_refreshed_on < REFRESH_BUFFER_TIME: 63 | # print(f"waiting {REFRESH_BUFFER_TIME} secs before the next refresh") 64 | time.sleep(2) 65 | 66 | # print("Refreshing...") 67 | last_refreshed_on = int(time.time()) 68 | with portalocker.Lock(target_file, "w") as f: 69 | f.write(f"SAVE_STATE = {random.randint(1, 1000)}") 70 | return True 71 | 72 | 73 | def refresh_worker(): 74 | while not TERMINATE_SCRIPT: 75 | refresh_queue.get() 76 | refresh() 77 | 78 | 79 | @app.route("/refresh", methods=["POST"]) 80 | def trigger_refresh(): 81 | if refresh_queue.empty(): 82 | refresh_queue.put(True) 83 | return jsonify({"success": True, "message": "Refresh request queued"}) 84 | else: 85 | return jsonify({"success": False, "message": "Refresh already queued"}) 86 | 87 | 88 | @app.route("/health", methods=["GET"]) 89 | def health_check(): 90 | return jsonify({"status": "healthy"}), 200 91 | 92 | 93 | def run_flask(): 94 | # disabling flask's default logger 95 | log = logging.getLogger("werkzeug") 96 | log.setLevel(logging.ERROR) 97 | 98 | # disabling the output stream 99 | cli = sys.modules["flask.cli"] 100 | cli.show_server_banner = lambda *x: None 101 | 102 | app.run(host="0.0.0.0", port=REFRESH_PROCESS_PORT) 103 | 104 | 105 | def main(): 106 | # flask server 107 | flask_thread = threading.Thread(target=run_flask) 108 | flask_thread.start() 109 | 110 | # refresh worker 111 | refresh_thread = threading.Thread(target=refresh_worker) 112 | refresh_thread.daemon = True 113 | refresh_thread.start() 114 | 115 | try: 116 | while not TERMINATE_SCRIPT: 117 | time.sleep(1) 118 | except KeyboardInterrupt: 119 | print("Keyboard interrupt received. 
Shutting down...") 120 | handle_termination(signal.SIGINT, None) 121 | 122 | print("Waiting for threads to finish...") 123 | # running the main thread as long as the flask server is active 124 | flask_thread.join(timeout=5) 125 | refresh_thread.join(timeout=5) 126 | 127 | if not (flask_thread.is_alive() and refresh_thread.is_alive()): 128 | print("Auto refresh terminated.") 129 | else: 130 | print("threads didn't shutdown") 131 | print("refresh thread active: ", refresh_thread.is_alive()) 132 | print("flask thread active: ", flask_thread.is_alive()) 133 | 134 | 135 | if __name__ == "__main__": 136 | main() 137 | -------------------------------------------------------------------------------- /ui_components/constants.py: -------------------------------------------------------------------------------- 1 | import os 2 | from shared.constants import COMFY_BASE_PATH, AnimationStyleType, AnimationToolType 3 | from utils.constants import ImageStage 4 | from utils.enum import ExtendedEnum 5 | 6 | 7 | class WorkflowStageType(ExtendedEnum): 8 | SOURCE = "source" 9 | STYLED = "styled" 10 | 11 | 12 | class VideoQuality(ExtendedEnum): 13 | HIGH = "High-Quality" 14 | LOW = "Low" 15 | 16 | 17 | class CreativeProcessType(ExtendedEnum): 18 | STYLING = "Key Frames" 19 | MOTION = "Shots" 20 | 21 | 22 | class ShotMetaData(ExtendedEnum): 23 | MOTION_DATA = "motion_data" # {"timing_data": [...], "main_setting_data": {}} 24 | DYNAMICRAFTER_DATA = "dynamicrafter_data" 25 | 26 | 27 | class GalleryImageViewType(ExtendedEnum): 28 | EXPLORER_ONLY = "explorer" 29 | SHOT_ONLY = "shot" 30 | ANY = "any" 31 | 32 | 33 | class DefaultTimingStyleParams: 34 | prompt = "" 35 | negative_prompt = "bad image, worst quality" 36 | strength = 1 37 | guidance_scale = 7.5 38 | seed = 0 39 | num_inference_steps = 25 40 | low_threshold = 100 41 | high_threshold = 200 42 | adapter_type = None 43 | interpolation_steps = 3 44 | transformation_stage = ImageStage.SOURCE_IMAGE.value 45 | custom_model_id_list = [] 46 | 
class DefaultProjectSettingParams:
    """Default values for the project-level ("batch") settings."""

    # prompt and sampling defaults
    batch_prompt = ""
    batch_negative_prompt = "bad image, worst quality"
    batch_strength = 1
    # NOTE(review): DefaultTimingStyleParams.guidance_scale is 7.5 — confirm
    # that 0.5 is intended here
    batch_guidance_scale = 0.5
    batch_seed = 0
    batch_num_inference_steps = 25
    batch_low_threshold = 100
    batch_high_threshold = 200
    batch_adapter_type = None
    batch_interpolation_steps = 3
    batch_transformation_stage = ImageStage.SOURCE_IMAGE.value
    # class-level list: one object shared by every reader, copy before mutating
    batch_custom_model_id_list = []
    batch_animation_tool = AnimationToolType.G_FILM.value
    batch_animation_style = AnimationStyleType.CREATIVE_INTERPOLATION.value
    batch_model = None

    # pagination and shot-size defaults
    total_log_pages = 1
    total_gallery_pages = 1
    total_shortlist_gallery_pages = 1
    max_frames_per_shot = 30


# Default motion values of a shot, keyed by setting name.
DEFAULT_SHOT_MOTION_VALUES = dict(
    strength_of_frame=0.85,
    distance_to_next_frame=3.0,
    speed_of_transition=0.5,
    freedom_between_frames=0.85,
    individual_prompt="",
    individual_negative_prompt="",
    motion_during_frame=1.25,
)

# TODO: make proper paths for every file
CROPPED_IMG_LOCAL_PATH = "videos/temp/cropped.png"

MASK_IMG_LOCAL_PATH = "videos/temp/mask.png"
TEMP_MASK_FILE = "temp_mask_file"

SECOND_MASK_FILE_PATH = "videos/temp/second_mask.png"
SECOND_MASK_FILE = "second_mask_file"

AUDIO_FILE_PATH = "videos/temp/audio.mp3"
AUDIO_FILE = "audio_file"

checkpoints_dir = os.path.join(COMFY_BASE_PATH, "models", "checkpoints")

# (checkpoint file name, download url) pairs; every checkpoint is stored under
# its own file name inside checkpoints_dir
_SD_MODEL_DOWNLOAD_URLS = (
    ("realisticVisionV60B1_v51VAE.safetensors", "https://civitai.com/api/download/models/130072"),
    ("anything_v50.safetensors", "https://civitai.com/api/download/models/30163"),
    ("dreamshaper_8.safetensors", "https://civitai.com/api/download/models/128713"),
    ("epicrealism_pureEvolutionV5.safetensors", "https://civitai.com/api/download/models/134065"),
    ("majicmixRealistic_v6.safetensors", "https://civitai.com/api/download/models/94640"),
)

# Maps a checkpoint file name to its download url, file name and destination.
SD_MODEL_DICT = {
    model_file: {"url": model_url, "filename": model_file, "dest": checkpoints_dir}
    for model_file, model_url in _SD_MODEL_DOWNLOAD_URLS
}
63 | } 64 | }, 65 | "7": { 66 | "inputs": { 67 | "text": "", 68 | "clip": [ 69 | "4", 70 | 1 71 | ] 72 | }, 73 | "class_type": "CLIPTextEncode", 74 | "_meta": { 75 | "title": "CLIP Text Encode (Prompt)" 76 | } 77 | }, 78 | "8": { 79 | "inputs": { 80 | "samples": [ 81 | "3", 82 | 0 83 | ], 84 | "vae": [ 85 | "4", 86 | 2 87 | ] 88 | }, 89 | "class_type": "VAEDecode", 90 | "_meta": { 91 | "title": "VAE Decode" 92 | } 93 | }, 94 | "21": { 95 | "inputs": { 96 | "ipadapter_file": "ip-adapter_sdxl.safetensors" 97 | }, 98 | "class_type": "IPAdapterModelLoader", 99 | "_meta": { 100 | "title": "Load IPAdapter Model" 101 | } 102 | }, 103 | "24": { 104 | "inputs": { 105 | "image": "rWA-3_T7_400x400.jpg", 106 | "upload": "image" 107 | }, 108 | "class_type": "LoadImage", 109 | "_meta": { 110 | "title": "Load Image" 111 | } 112 | }, 113 | "29": { 114 | "inputs": { 115 | "filename_prefix": "ComfyUI", 116 | "images": [ 117 | "8", 118 | 0 119 | ] 120 | }, 121 | "class_type": "SaveImage", 122 | "_meta": { 123 | "title": "Save Image" 124 | } 125 | }, 126 | "36": { 127 | "inputs": { 128 | "weight": 0.75, 129 | "noise": 0.3, 130 | "weight_faceidv2": 0.75, 131 | "weight_type": "linear", 132 | "combine_embeds": "concat", 133 | "embeds_scaling": "V only", 134 | "start_at": 0, 135 | "end_at": 1, 136 | "ipadapter": [ 137 | "21", 138 | 0 139 | ], 140 | "clip_vision": [ 141 | "41", 142 | 0 143 | ], 144 | "insightface": [ 145 | "37", 146 | 0 147 | ], 148 | "image": [ 149 | "40", 150 | 0 151 | ], 152 | "model": [ 153 | "4", 154 | 0 155 | ] 156 | }, 157 | "class_type": "IPAdapterFaceID", 158 | "_meta": { 159 | "title": "IPAdapter FaceID" 160 | } 161 | }, 162 | "37": { 163 | "inputs": { 164 | "provider": "CUDA" 165 | }, 166 | "class_type": "IPAdapterInsightFaceLoader", 167 | "_meta": { 168 | "title": "IPAdapter InsightFace Loader" 169 | } 170 | }, 171 | "40": { 172 | "inputs": { 173 | "interpolation": "LANCZOS", 174 | "crop_position": "top", 175 | "sharpening": 0, 176 | "image": [ 177 | "24", 178 
def count_calls(cls):
    """Class decorator that counts calls made through instance attributes.

    Returns a subclass of ``cls`` whose instances expose:

    * ``call_counts`` -- dict mapping attribute name to number of calls
    * ``total_count`` -- total number of calls across all attributes
    """

    # Names that must never be wrapped, otherwise the bookkeeping would recurse.
    untracked = ("__getattribute__", "call_counts", "total_count")

    class Wrapper(cls):
        def __init__(self, *args, **kwargs):
            # The counters must exist before the wrapped __init__ runs: any
            # method it calls on ``self`` already goes through
            # __getattribute__ below, which reads ``self.call_counts``.
            self.call_counts = {}
            self.total_count = 0
            super().__init__(*args, **kwargs)

        def __getattribute__(self, name):
            attr = super().__getattribute__(name)
            if callable(attr) and name not in untracked:
                if name not in self.call_counts:
                    self.call_counts[name] = 0

                def wrapped_method(*args, **kwargs):
                    self.call_counts[name] += 1
                    self.total_count += 1
                    return attr(*args, **kwargs)

                return wrapped_method

            return attr

    return Wrapper


def log_time(func):
    """Decorator that prints how long each call to ``func`` took.

    The printed label is the second positional argument when there is one,
    otherwise the ``url`` keyword argument, otherwise the function's name.
    The wrapped function's return value is passed through unchanged.
    """

    @wraps(func)
    def wrapper(*args, **kwargs):
        start_time = time.time()
        result = func(*args, **kwargs)
        execution_time = time.time() - start_time

        # Never let the logging label raise after the call itself succeeded
        # (a missing 'url' kwarg used to surface as a KeyError here).
        if len(args) >= 2:
            label = args[1]
        else:
            label = kwargs.get("url", func.__name__)
        print(f"{label} took {execution_time:.4f} seconds to execute.")
        return result

    return wrapper
def measure_execution_time(cls):
    """Class decorator that replaces ``cls`` with a timing proxy.

    The returned class builds a real ``cls`` instance internally and forwards
    attribute access to it. Callable attributes are wrapped so that every
    call prints its execution time; other attributes are returned as-is.
    """

    class WrapperClass:
        def __init__(self, *args, **kwargs):
            self.wrapped_instance = cls(*args, **kwargs)

        def __getattr__(self, name):
            # __getattr__ only runs for names missing on the proxy itself. If
            # the missing name is ``wrapped_instance`` (proxy created without
            # __init__, e.g. while being copied), looking it up again below
            # would recurse without end.
            if name == "wrapped_instance":
                raise AttributeError(name)

            attr = getattr(self.wrapped_instance, name)
            if callable(attr):
                return self.measure_method_execution(attr)
            return attr

        def measure_method_execution(self, method):
            # Callable objects (as opposed to functions) may have no __name__.
            method_name = getattr(method, "__name__", repr(method))

            def wrapper(*args, **kwargs):
                start_time = time.time()
                result = method(*args, **kwargs)
                execution_time = time.time() - start_time
                print(f"Execution time of {method_name}: {execution_time} seconds")
                return result

            return wrapper

    return WrapperClass


def session_state_attributes(default_value_cls):
    """Class decorator factory backing selected attributes with ``st.session_state``.

    Every non-dunder attribute name that exists on ``default_value_cls`` is
    stored under the key ``f"{self.uuid}_{attr}"`` in ``st.session_state``
    instead of on the instance. All other attributes use the class's original
    ``__getattribute__`` / ``__setattr__``.
    """

    def decorator(cls):
        original_getattr = cls.__getattribute__
        original_setattr = cls.__setattr__

        def is_session_attr(attr):
            # Every class has __dict__, __class__, __init__, ... so a plain
            # hasattr() check would redirect those lookups to session state.
            is_dunder = attr.startswith("__") and attr.endswith("__")
            return not is_dunder and hasattr(default_value_cls, attr)

        def custom_attr(self, attr):
            if not is_session_attr(attr):
                return original_getattr(self, attr)

            key = f"{self.uuid}_{attr}"
            # NOTE(review): a stored falsy value (0, "", False) is replaced by
            # the default on read -- confirm this is intended.
            if not (key in st.session_state and st.session_state[key]):
                st.session_state[key] = getattr(default_value_cls, attr)

            return st.session_state[key] if runtime.exists() else getattr(default_value_cls, attr)

        def custom_setattr(self, attr, value):
            if is_session_attr(attr):
                key = f"{self.uuid}_{attr}"
                st.session_state[key] = value
            else:
                original_setattr(self, attr, value)

        cls.__getattribute__ = custom_attr
        cls.__setattr__ = custom_setattr
        return cls

    return decorator
def with_refresh_lock(func):
    """Decorator that marks the refresh lock as held while ``func`` runs.

    The lock is set to "locked" before the call and always reset to
    "unlocked" afterwards. Exceptions raised by ``func`` are printed and
    swallowed, in which case the wrapper returns ``None``.
    """

    @wraps(func)
    def wrapper(*args, **kwargs):
        outcome = None
        update_refresh_lock(True)
        try:
            outcome = func(*args, **kwargs)
        except Exception as e:
            print("Error occured while processing ", str(e))
        finally:
            update_refresh_lock(False)
        return outcome

    return wrapper


def update_refresh_lock(status=False):
    """Write the refresh-lock state to ``REFRESH_LOCK_FILE`` as JSON.

    The file receives ``status`` ("locked"/"unlocked"), the current
    timestamp and the id of the writing process. The file is opened through
    ``portalocker.Lock`` for the duration of the write.
    """
    # Imported here rather than at module level, as in the original code.
    from utils.constants import REFRESH_LOCK_FILE

    state_label = "locked" if status else "unlocked"
    with portalocker.Lock(REFRESH_LOCK_FILE, "w") as lock_file_handle:
        payload = {
            "status": state_label,
            "last_action_time": time.time(),
            "process_id": os.getpid(),
        }
        json.dump(payload, lock_file_handle)
images_list passed here are converted to internal files after they are used for training 31 | # def train_dreambooth_model(instance_prompt, class_prompt, training_file_url, max_train_steps, model_name, images_list: List[str], controller_type, model_type_list): 32 | # from ui_components.methods.common_methods import convert_image_list_to_file_list 33 | 34 | # ml_client = get_ml_client() 35 | # app_setting = DataRepo().get_app_setting_from_uuid() 36 | 37 | # response = ml_client.dreambooth_training( 38 | # training_file_url, instance_prompt, class_prompt, max_train_steps, model_name, controller_type, len(images_list), app_setting.replicate_username) 39 | # training_status = response["status"] 40 | 41 | # model_id = response["id"] 42 | # if training_status == "queued": 43 | # file_list = convert_image_list_to_file_list(images_list) 44 | # file_uuid_list = [file.uuid for file in file_list] 45 | # file_uuid_list = json.dumps(file_uuid_list) 46 | 47 | # model_data = { 48 | # "name": model_name, 49 | # "user_id": get_current_user_uuid(), 50 | # "replicate_model_id": model_id, 51 | # "replicate_url": response["model"], 52 | # "diffusers_url": "", 53 | # "category": AIModelCategory.DREAMBOOTH.value, 54 | # "training_image_list": file_uuid_list, 55 | # "keyword": instance_prompt, 56 | # "custom_trained": True, 57 | # "model_type": model_type_list 58 | # } 59 | 60 | # data_repo = DataRepo() 61 | # data_repo.create_ai_model(**model_data) 62 | 63 | # return "Success - Training Started. Please wait 10-15 minutes for the model to be trained." 
64 | # else: 65 | # return "Failed" 66 | 67 | # NOTE: code not in use 68 | # INFO: images_list passed here are converted to internal files after they are used for training 69 | # def train_lora_model(training_file_url, type_of_task, resolution, model_name, images_list, model_type_list): 70 | # from ui_components.methods.common_methods import convert_image_list_to_file_list 71 | 72 | # data_repo = DataRepo() 73 | # ml_client = get_ml_client() 74 | # output = ml_client.predict_model_output(ML_MODEL.clones_lora_training, instance_data=training_file_url, 75 | # task=type_of_task, resolution=int(resolution)) 76 | 77 | # file_list = convert_image_list_to_file_list(images_list) 78 | # file_uuid_list = [file.uuid for file in file_list] 79 | # file_uuid_list = json.dumps(file_uuid_list) 80 | # model_data = { 81 | # "name": model_name, 82 | # "user_id": get_current_user_uuid(), 83 | # "replicate_url": output, 84 | # "diffusers_url": "", 85 | # "category": AIModelCategory.LORA.value, 86 | # "training_image_list": file_uuid_list, 87 | # "custom_trained": True, 88 | # "model_type": model_type_list 89 | # } 90 | 91 | # data_repo.create_ai_model(**model_data) 92 | # return f"Successfully trained - the model '{model_name}' is now available for use!" 
def upload_file(file_location, aws_access_key, aws_secret_key, bucket=AWS_S3_BUCKET):
    """Upload a local file to S3 under ``input_images/`` and make it public.

    The object key is a fresh UUID plus the file's original extension.

    Returns:
        The ``https://s3.amazonaws.com/<bucket>/<key>`` URL on success, or
        ``None`` if the upload failed (the failure is logged).
    """
    ext = os.path.splitext(file_location)[1]
    s3_file = f"input_images/{uuid.uuid4()}{ext}"
    try:
        s3 = boto3.client("s3", aws_access_key_id=aws_access_key, aws_secret_access_key=aws_secret_key)
        s3.upload_file(file_location, bucket, s3_file)
        s3.put_object_acl(ACL="public-read", Bucket=bucket, Key=s3_file)
    except Exception as e:
        # Keep the best-effort contract (return None) but record the cause.
        logger.log(LoggingType.ERROR, f"unable to upload to s3: {e}")
        return None

    return f"https://s3.amazonaws.com/{bucket}/{s3_file}"


def upload_file_from_obj(file, file_extension, bucket=AWS_S3_BUCKET):
    """Upload a file-like object to S3 under ``test/`` as a public object.

    Args:
        file: seekable file-like object; it is rewound before uploading.
        file_extension: extension including the dot (e.g. ``".png"``); used
            for the object key and to choose the content type.
        bucket: destination bucket.

    Returns:
        The region-qualified URL of the uploaded object.
    """
    aws_access_key, aws_secret_key = AWS_ACCESS_KEY, AWS_SECRET_KEY
    key = "test/" + str(uuid.uuid4()) + file_extension
    file.seek(0)

    # Images get their real content type (.jpg used to be sent as image/png);
    # everything else is uploaded as a generic binary.
    image_content_types = {".png": "image/png", ".jpg": "image/jpeg", ".jpeg": "image/jpeg"}
    content_type = image_content_types.get(file_extension.lower(), "application/octet-stream")

    s3_client = boto3.client("s3", aws_access_key_id=aws_access_key, aws_secret_access_key=aws_secret_key)
    s3_client.put_object(
        Body=file,
        Bucket=bucket,
        Key=key,
        ACL="public-read",
        ContentType=content_type,
    )
    return "https://s3-{0}.amazonaws.com/{1}/{2}".format(AWS_S3_REGION, bucket, key)
def upload_file_from_bytes(file_bytes, aws_access_key, aws_secret_key, key=None, bucket=AWS_S3_BUCKET):
    """Upload raw bytes to S3 as a public ``image/png`` object.

    Args:
        file_bytes: object body.
        key: object key; defaults to ``test/<uuid>.png``.
        bucket: destination bucket.

    Returns:
        The region-qualified URL of the uploaded object.
    """
    if not key:
        key = "test/" + str(uuid.uuid4()) + ".png"

    s3_client = boto3.client("s3", aws_access_key_id=aws_access_key, aws_secret_access_key=aws_secret_key)
    s3_client.put_object(
        Body=file_bytes,
        Bucket=bucket,
        Key=key,
        ACL="public-read",
        ContentType="image/png",
    )
    return "https://s3-{0}.amazonaws.com/{1}/{2}".format(AWS_S3_REGION, bucket, key)


# TODO: fix the structuring of s3 for different users and different files
def generate_s3_url(
    image_url,
    aws_access_key,
    aws_secret_key,
    bucket=AWS_S3_BUCKET,
    file_ext="png",
    folder="posts/",
    object_name=None,
):
    """Download ``image_url`` and re-host it on S3 as a public object.

    Args:
        image_url: URL to download the image from.
        bucket: destination bucket.
        file_ext: extension (without dot) used when a name is generated.
        folder: key prefix inside the bucket.
        object_name: file name inside ``folder``; a UUID-based name is
            generated when omitted. (Previously this name was read before
            ever being assigned, so every call failed.)

    Returns:
        The region-qualified URL of the uploaded object.

    Raises:
        Exception: if the download does not return HTTP 200.
    """
    if object_name is None:
        object_name = str(uuid.uuid4()) + "." + file_ext

    response = requests.get(image_url)
    if response.status_code != 200:
        raise Exception("Failed to download the image from the given URL")

    key = folder + object_name
    content_type = mimetypes.guess_type(object_name)[0] or "image/png"

    s3_client = boto3.client(
        service_name="s3",
        region_name=AWS_S3_REGION,
        aws_access_key_id=aws_access_key,
        aws_secret_access_key=aws_secret_key,
    )
    # Content-Disposition has to be sent with the upload; editing the
    # response dict after the fact never reached S3.
    s3_client.put_object(
        Body=response.content,
        Bucket=bucket,
        Key=key,
        ACL="public-read",
        ContentType=content_type,
        ContentDisposition=f'inline; filename="{object_name}"',
    )

    # Build the URL from the bucket actually written to, not the default one.
    return "https://s3-{0}.amazonaws.com/{1}/{2}".format(AWS_S3_REGION, bucket, key)
def is_s3_image_url(url):
    """Return True if ``url`` points at an Amazon S3 host.

    Recognised hosts end in ``.amazonaws.com`` and contain a label that is
    ``s3`` or starts with ``s3-``. This covers the forms this module emits
    (``s3.amazonaws.com``, ``s3-<region>.amazonaws.com``) as well as
    virtual-hosted bucket URLs (``<bucket>.s3.<region>.amazonaws.com``).
    Anything else, including empty or non-URL strings, returns False.
    """
    suffix = ".amazonaws.com"
    # ``hostname`` is lower-cased and excludes any port or credentials.
    host = urlparse(url).hostname
    if not host or not host.endswith(suffix):
        return False

    labels = host[: -len(suffix)].split(".")
    return any(label == "s3" or label.startswith("s3-") for label in labels)
COMFY_RUNNER_PATH = "./comfy_runner"


def predict_gpu_output(
    workflow: str,
    file_path_list=None,
    output_node=None,
    extra_model_list=None,
    ignore_model_list=None,
    log_tag=None,
) -> list:
    """Run a ComfyUI workflow through the local comfy_runner checkout.

    Blocks (polling every 2 seconds) until the ``comfy_runner`` directory
    exists, then imports ``ComfyRunner`` from it and runs the prediction
    with the comfy/node/package versions read from the toml config.

    Args:
        workflow: workflow input forwarded as ``workflow_input``.
        file_path_list: input files for the workflow (defaults to empty).
        output_node: forwarded as ``output_node_ids``.
        extra_model_list: forwarded as ``extra_models_list`` (defaults to empty).
        ignore_model_list: forwarded as ``ignore_model_list`` (defaults to empty).
        log_tag: forwarded as ``client_id``.

    Returns:
        ``output["file_paths"]`` from the runner; text output is ignored.
    """
    # Fresh lists per call instead of shared mutable default arguments.
    file_path_list = [] if file_path_list is None else file_path_list
    extra_model_list = [] if extra_model_list is None else extra_model_list
    ignore_model_list = [] if ignore_model_list is None else ignore_model_list

    # hackish sol.. waiting for comfy repo to be cloned
    while not is_comfy_runner_present():
        time.sleep(2)

    # Register the checkout once; appending on every call grew sys.path.
    runner_dir = str(os.getcwd()) + COMFY_RUNNER_PATH[1:]
    if runner_dir not in sys.path:
        sys.path.append(runner_dir)
    from comfy_runner.inf import ComfyRunner

    comfy_commit_hash = get_toml_config(TomlConfig.COMFY_VERSION.value)["commit_hash"]
    node_commit_dict = get_toml_config(TomlConfig.NODE_VERSION.value)
    pkg_versions = get_toml_config(TomlConfig.PKG_VERSIONS.value)
    extra_node_urls = []
    for k, v in node_commit_dict.items():
        v["title"] = k
        extra_node_urls.append(v)

    comfy_runner = ComfyRunner()
    output = comfy_runner.predict(
        workflow_input=workflow,
        file_path_list=file_path_list,
        stop_server_after_completion=False,
        output_node_ids=output_node,
        extra_models_list=extra_model_list,
        ignore_model_list=ignore_model_list,
        client_id=log_tag,
        extra_node_urls=extra_node_urls,
        comfy_commit_hash=comfy_commit_hash,
        strict_dep_list=pkg_versions,
    )

    return output["file_paths"]  # ignoring text output for now {"file_paths": [], "text_content": []}


def is_comfy_runner_present():
    """Return True if the ``comfy_runner`` directory exists relative to the cwd."""
    return os.path.exists(COMFY_RUNNER_PATH)  # hackish sol, will fix later
# TODO: convert comfy_runner into a package for easy import
def setup_comfy_runner():
    """Make sure comfy_runner is cloned, installed and configured.

    If the checkout already exists only its ``.env`` file is refreshed.
    Otherwise the repository is cloned, its requirements are installed and
    the ``.env`` file is written.
    """
    if is_comfy_runner_present():
        update_comfy_runner_env()
        return

    app_logger.log(LoggingType.INFO, "cloning comfy runner")
    comfy_repo_url = "https://github.com/piyushK52/comfy-runner"
    Repo.clone_from(comfy_repo_url, COMFY_RUNNER_PATH[2:], single_branch=True, branch="main")

    # installing dependencies -- through the running interpreter so they land
    # in the same environment as this app, whichever `pip` is first on PATH
    subprocess.run(
        [sys.executable, "-m", "pip", "install", "-r", COMFY_RUNNER_PATH + "/requirements.txt"],
        check=True,
    )
    update_comfy_runner_env()


def find_comfy_runner():
    """Locate the ``comfy_runner`` directory in the project root.

    Walks up from this file's directory until a directory containing
    ``.git`` is found, then looks for ``comfy_runner`` inside it.

    Returns:
        The path to ``comfy_runner``, or ``None`` if the project root has no
        such directory or no ``.git`` directory is found up to the
        filesystem root.
    """
    current_path = os.path.dirname(os.path.abspath(__file__))

    while True:
        if os.path.exists(os.path.join(current_path, ".git")):
            comfy_runner_path = os.path.join(current_path, "comfy_runner")
            # comfy_runner is only ever looked for in the project root
            return comfy_runner_path if os.path.exists(comfy_runner_path) else None

        parent_path = os.path.dirname(current_path)
        if parent_path == current_path:
            # reached the filesystem root
            return None

        current_path = parent_path


def update_comfy_runner_env():
    """Write comfy_runner's ``.env`` from ``COMFY_MODELS_BASE_PATH``.

    With a custom base path the file gets a single
    ``COMFY_RUNNER_MODELS_BASE_PATH=<path>`` entry; with the default
    ("ComfyUI") the file is emptied. Both cases write to the ``.env`` inside
    the checkout found by ``find_comfy_runner`` and report I/O problems by
    printing them instead of raising.
    """
    comfy_base_path = os.getenv("COMFY_MODELS_BASE_PATH", "ComfyUI")
    comfy_runner_path = find_comfy_runner()
    if not comfy_runner_path:
        print("comfy_runner not present")
        return

    env_file_path = os.path.join(comfy_runner_path, ".env")
    if comfy_base_path != "ComfyUI":
        expected_content = f"COMFY_RUNNER_MODELS_BASE_PATH={comfy_base_path}"
    else:
        expected_content = ""

    try:
        os.makedirs(os.path.dirname(env_file_path), exist_ok=True)

        with open(env_file_path, "w", encoding="utf-8") as f:
            f.write(expected_content)

        # read back to confirm the write actually landed
        with open(env_file_path, "r", encoding="utf-8") as f:
            written_content = f.read()

        if written_content != expected_content:
            print(
                f"File was written, but content doesn't match. Expected: {expected_content}, Got: {written_content}"
            )

    except IOError as e:
        print(f"IOError occurred while writing to {env_file_path}: {e}")
    except Exception as e:
        print(f"An unexpected error occurred: {e}")
def zoom_inputs(position="in-frame", horizontal=False, shot_uuid=None):
    """Render the zoom / rotate / shift / flip input widgets.

    Widget values are stored in ``st.session_state`` under the keys
    ``zoom_level_input``, ``rotation_angle_input``, ``x_shift``, ``y_shift``,
    ``flip_vertically`` and ``flip_horizontally``. Initial values come from
    the matching ``*_default`` session entries, which are created on first
    use.

    Args:
        position: currently unused by this function.
        horizontal: lay the widgets out in three rows of two columns instead
            of stacking them vertically.
        shot_uuid: currently unused by this function.
    """
    if horizontal:
        col1, col2 = st.columns(2)
        col3, col4 = st.columns(2)
        col5, col6 = st.columns(2)
    else:
        col1 = col2 = col3 = col4 = col5 = col6 = st

    if "zoom_level_input_default" not in st.session_state:
        st.session_state["zoom_level_input_default"] = 100
        st.session_state["rotation_angle_input_default"] = 0
        st.session_state["x_shift_default"] = 0
        st.session_state["y_shift_default"] = 0
        st.session_state["flip_vertically_default"] = False
        st.session_state["flip_horizontally_default"] = False

    col1.number_input(
        "Zoom in/out:",
        min_value=10,
        max_value=1000,
        step=5,
        key="zoom_level_input",
        value=st.session_state["zoom_level_input_default"],
    )

    col2.number_input(
        "Rotate:",
        min_value=-360,
        max_value=360,
        step=5,
        key="rotation_angle_input",
        value=st.session_state["rotation_angle_input_default"],
    )

    col3.number_input(
        "Shift left/right:",
        min_value=-1000,
        max_value=1000,
        step=5,
        key="x_shift",
        value=st.session_state["x_shift_default"],
    )

    col4.number_input(
        "Shift down/up:",
        min_value=-1000,
        max_value=1000,
        step=5,
        key="y_shift",
        value=st.session_state["y_shift_default"],
    )

    # The defaults are passed as real booleans: str(False) is the non-empty
    # string "False", which is truthy and would tick the box by default.
    col5.checkbox(
        "Flip vertically ↕️",
        key="flip_vertically",
        value=bool(st.session_state["flip_vertically_default"]),
    )

    col6.checkbox(
        "Flip horizontally ↔️",
        key="flip_horizontally",
        value=bool(st.session_state["flip_horizontally_default"]),
    )
def save_zoomed_image(image, timing_uuid, stage, promote=False):
    """Persist a transformed image and attach it to a frame.

    The image is saved (or hosted) under the project's
    ``assets/frames/modified`` folder and registered as an internal file.
    For the SOURCE stage it becomes the source image of the frame stored in
    ``st.session_state["current_frame_uuid"]``; for the STYLED stage it is
    added as a variant of ``timing_uuid`` and, if ``promote`` is set,
    promoted. Any other stage does nothing.
    """
    data_repo = DataRepo()
    timing = data_repo.get_timing_from_uuid(timing_uuid)
    project_uuid = timing.shot.project.uuid

    file_name = str(uuid.uuid4()) + ".png"

    for_source = stage == WorkflowStageType.SOURCE.value
    for_styled = not for_source and stage == WorkflowStageType.STYLED.value
    if not (for_source or for_styled):
        return

    # Saving and registering the file is identical for both stages.
    save_location = f"videos/{project_uuid}/assets/frames/modified/{file_name}"
    hosted_url = save_or_host_file(image, save_location)
    file_data = {"name": file_name, "type": InternalFileType.IMAGE.value, "project_id": project_uuid}
    if hosted_url:
        file_data["hosted_url"] = hosted_url
    else:
        file_data["local_path"] = save_location

    saved_file: InternalFileObject = data_repo.create_file(**file_data)

    if for_source:
        data_repo.update_specific_timing(
            st.session_state["current_frame_uuid"], source_image_id=saved_file.uuid, update_in_place=True
        )
        return

    number_of_image_variants = add_image_variant(saved_file.uuid, timing_uuid)
    if promote:
        promote_image_variant(timing_uuid, number_of_image_variants - 1)
def reset_zoom_element():
    """Reset every zoom widget, and its stored default, to the neutral value."""
    neutral_values = {
        "zoom_level_input": 100,
        "rotation_angle_input": 0,
        "x_shift": 0,
        "y_shift": 0,
        "flip_vertically": False,
        "flip_horizontally": False,
    }
    for widget_key, neutral in neutral_values.items():
        # the "<key>_default" entry seeds the widget, "<key>" is its live value
        st.session_state[f"{widget_key}_default"] = neutral
        st.session_state[widget_key] = neutral
def start_runner():
    """Start the background runner process if it is not already running.

    Does nothing outside development when hosted background-runner mode is
    disabled. After launching, waits briefly (up to 6 polls of 0.1s) for the
    process to come up and reruns the app if the runner port in the config
    changed.
    """
    # This file has no explicit `import sys`; import it here rather than
    # depending on `from moviepy.editor import *` to provide the name.
    import sys

    if SERVER != ServerType.DEVELOPMENT.value and HOSTED_BACKGROUND_RUNNER_MODE in [False, "False"]:
        return

    with server_state_lock["runner"]:
        app_logger = AppLogger()

        if not is_process_active(RUNNER_PROCESS_NAME, RUNNER_PROCESS_PORT):
            app_logger.info("Starting runner")
            python_executable = sys.executable
            _ = subprocess.Popen([python_executable, "banodoco_runner.py"])
            max_retries = 6
            while not is_process_active(RUNNER_PROCESS_NAME, RUNNER_PROCESS_PORT) and max_retries:
                time.sleep(0.1)
                max_retries -= 1

            # refreshing the app if the runner port has changed
            old_port = RUNNER_PROCESS_PORT
            new_port = config_manager.get("runner_process_port", fresh_pull=True)
            if old_port != new_port:
                st.rerun()
        else:
            # app_logger.debug("Runner already running")
            pass


def start_project_refresh():
    """Start the auto-refresh helper process (development server only).

    Launches ``auto_refresh.py`` if no refresh process is active and waits
    (up to 6 polls of 1s) for it to come up.
    """
    # See start_runner: `sys` is not imported at module level in this file.
    import sys

    if SERVER != ServerType.DEVELOPMENT.value:
        return

    with server_state_lock["refresh_app"]:
        app_logger = AppLogger()

        if not refresh_process_active(REFRESH_PROCESS_PORT):
            python_executable = sys.executable
            _ = subprocess.Popen([python_executable, "auto_refresh.py"])
            max_retries = 6
            while not refresh_process_active(REFRESH_PROCESS_PORT) and max_retries:
                time.sleep(1)
                max_retries -= 1
            app_logger.info("Auto refresh enabled")
        else:
            # app_logger.debug("refresh process already running")
            pass
max_retries: 84 | time.sleep(1) 85 | max_retries -= 1 86 | app_logger.info("Auto refresh enabled") 87 | else: 88 | # app_logger.debug("refresh process already running") 89 | pass 90 | 91 | 92 | def main(): 93 | st.set_page_config(page_title="Dough", layout="wide", page_icon="🎨") 94 | st.markdown( 95 | r""" 96 | 101 | """, 102 | unsafe_allow_html=True, 103 | ) 104 | update_refresh_lock(False) 105 | 106 | # if it's the first time, 107 | if "first_load" not in st.session_state: 108 | if not is_process_active(RUNNER_PROCESS_NAME, RUNNER_PROCESS_PORT): 109 | if not load_save_checkpoint(): 110 | check_and_pull_changes() # enabling auto updates only for local version 111 | else: 112 | apply_updates() 113 | refresh_app() 114 | st.session_state["first_load"] = True 115 | 116 | start_runner() 117 | start_project_refresh() 118 | project_init() 119 | 120 | from ui_components.setup import setup_app_ui 121 | from ui_components.components.welcome_page import welcome_page 122 | 123 | data_repo = DataRepo() 124 | api_repo = APIRepo() 125 | config_manager = ConfigManager() 126 | gpu_enabled = config_manager.get(GPU_INFERENCE_ENABLED_KEY, False) 127 | 128 | app_setting = data_repo.get_app_setting_from_uuid() 129 | if app_setting.welcome_state == 2: 130 | # api/online inference mode 131 | if not gpu_enabled: 132 | if not api_repo.is_user_logged_in(): 133 | # user not logged in 134 | user_login_ui() 135 | else: 136 | setup_app_ui() 137 | else: 138 | # gpu/offline inference mode 139 | setup_app_ui() 140 | else: 141 | welcome_page() 142 | 143 | st.session_state["maintain_state"] = False 144 | 145 | 146 | if __name__ == "__main__": 147 | try: 148 | main() 149 | except Exception as e: 150 | sentry_sdk.capture_exception(e) 151 | raise e 152 | -------------------------------------------------------------------------------- /ui_components/widgets/image_carousal.py: -------------------------------------------------------------------------------- 1 | import streamlit as st 2 | from 
def display_image(timing_uuid, stage=None, clickable=False):
    """Show the styled or source image of a frame, optionally clickable.

    Args:
        timing_uuid: uuid of the frame to display.
        stage: ``WorkflowStageType.STYLED.value`` or
            ``WorkflowStageType.SOURCE.value``; any other value is treated
            as "no image".
        clickable: when True the image is rendered through
            ``clickable_images`` and a click selects the frame (updates the
            current-frame session entries and refreshes the app); when False
            it is rendered with ``st.image``.
    """
    data_repo = DataRepo()
    timing = data_repo.get_timing_from_uuid(timing_uuid)

    # Check for a missing frame before touching any of its attributes.
    if not timing:
        st.write("no images")
        return

    timing_idx = timing.aux_frame_index

    # Default to "no image" so an unknown stage reports an error below
    # instead of leaving the variable unbound.
    image = ""
    if stage == WorkflowStageType.STYLED.value:
        image = timing.primary_image_location
    elif stage == WorkflowStageType.SOURCE.value:
        image = timing.source_image.location if timing.source_image else ""

    if not image:
        st.error(f"No {stage} image found for #{timing_idx + 1}")
        return

    if clickable is True:
        if "counter" not in st.session_state:
            st.session_state["counter"] = 0

        import base64

        if image.startswith("http"):
            st.write("")
        else:
            # local file: inline it as a data URI for the clickable component
            with open(image, "rb") as image_file:
                st.write("")
                encoded = base64.b64encode(image_file.read()).decode()
            image = f"data:image/jpeg;base64,{encoded}"

        clicked_key = f"{timing_idx}_{stage}_clicked"
        st.session_state[clicked_key] = clickable_images(
            [image],
            div_style={"display": "flex", "justify-content": "center", "flex-wrap": "wrap"},
            img_style={"max-width": "100%", "height": "auto", "cursor": "pointer"},
            key=f"{timing_idx}_{stage}_image_{st.session_state['counter']}",
        )

        # a stored value of 0 means the (only) image was clicked
        if st.session_state[clicked_key] == 0:
            timing_details = data_repo.get_timing_list_from_shot(timing.shot.uuid)
            st.session_state["current_frame_uuid"] = timing_details[timing_idx].uuid
            st.session_state["current_frame_index"] = timing_idx + 1
            st.session_state["prev_frame_index"] = timing_idx + 1
            st.session_state["frame_styling_view_type"] = "Individual"
            # changing the counter changes the widget key on the next run
            st.session_state["counter"] += 1
            refresh_app()

    elif clickable is False:
        st.image(image, use_column_width=True)
def carousal_of_images_element(shot_uuid, stage=WorkflowStageType.STYLED.value):
    """Render a five-column strip of frames centred on the current frame.

    The middle column shows the frame stored in
    ``st.session_state["current_frame_uuid"]``; the columns to its left and
    right show the two previous and two next frames of the shot, where they
    exist. Every image is clickable.
    """
    data_repo = DataRepo()
    shot: InternalShotObject = data_repo.get_shot_from_uuid(shot_uuid)
    timing_list = shot.timing_list

    columns = st.columns([1, 1, 1, 1, 1])

    current_frame_uuid = st.session_state["current_frame_uuid"]
    current_timing = data_repo.get_timing_from_uuid(current_frame_uuid)
    current_idx = current_timing.aux_frame_index

    for column, offset in zip(columns, (-2, -1, 0, 1, 2)):
        with column:
            if offset == 0:
                display_image(current_timing.uuid, stage=stage, clickable=True)
                st.success(f"#{current_idx + 1}")
                continue

            # Frame indices are 0-based, so the last valid one is len - 1.
            neighbour_idx = current_idx + offset
            if not 0 <= neighbour_idx < len(timing_list):
                continue

            neighbour = data_repo.get_timing_from_frame_number(shot_uuid, neighbour_idx)
            if neighbour:
                display_image(neighbour.uuid, stage=stage, clickable=True)
                st.info(f"#{neighbour.aux_frame_index + 1}")
current_timing.aux_frame_index + 2 <= len(timing_list): 108 | next_2_timing = data_repo.get_timing_from_frame_number( 109 | shot_uuid, current_timing.aux_frame_index + 2 110 | ) 111 | if next_2_timing: 112 | display_image(next_2_timing.uuid, stage=stage, clickable=True) 113 | st.info(f"#{next_2_timing.aux_frame_index + 1}") 114 | -------------------------------------------------------------------------------- /ui_components/widgets/add_key_frame_element.py: -------------------------------------------------------------------------------- 1 | import time 2 | from typing import Union 3 | import streamlit as st 4 | from shared.constants import AnimationStyleType 5 | from ui_components.models import InternalFileObject, InternalFrameTimingObject 6 | from utils.common_decorators import update_refresh_lock 7 | from utils.state_refresh import refresh_app 8 | from utils.data_repo.data_repo import DataRepo 9 | from ui_components.methods.file_methods import generate_pil_image, save_or_host_file 10 | from ui_components.methods.common_methods import add_image_variant, save_new_image 11 | from PIL import Image 12 | 13 | 14 | def add_key_frame_section(shot_uuid): 15 | data_repo = DataRepo() 16 | shot = data_repo.get_shot_from_uuid(shot_uuid) 17 | selected_image_location = "" 18 | 19 | uploaded_images = st.file_uploader( 20 | "Upload images:", 21 | type=["png", "jpg", "jpeg", "webp"], 22 | key=f"uploaded_image_{shot_uuid}", 23 | help="You can upload multiple images", 24 | accept_multiple_files=True, 25 | ) 26 | 27 | if st.button( 28 | f"Add key frame(s)", 29 | use_container_width=True, 30 | key=f"add_key_frame_btn_{shot_uuid}", 31 | type="primary", 32 | ): 33 | update_refresh_lock(True) 34 | if uploaded_images: 35 | progress_bar = st.progress(0) 36 | # Remove sorting to maintain upload order 37 | for i, uploaded_image in enumerate(uploaded_images): 38 | image = Image.open(uploaded_image) 39 | file_location = f"videos/{shot.uuid}/assets/frames/base/{uploaded_image.name}" 40 | 
selected_image_location = save_or_host_file(image, file_location) 41 | selected_image_location = selected_image_location or file_location 42 | add_key_frame(selected_image_location, shot_uuid, refresh_state=False) 43 | progress_bar.progress((i + 1) / len(uploaded_images)) 44 | else: 45 | st.error("Please generate new images or upload them") 46 | time.sleep(0.7) 47 | update_refresh_lock(False) 48 | refresh_app() 49 | 50 | 51 | def display_selected_key_frame(selected_image_location, apply_zoom_effects): 52 | selected_image = None 53 | if selected_image_location: 54 | # if apply_zoom_effects == "Yes": 55 | # image_preview = generate_pil_image(selected_image_location) 56 | # selected_image = apply_image_transformations(image_preview, st.session_state['zoom_level_input'], st.session_state['rotation_angle_input'], st.session_state['x_shift'], st.session_state['y_shift'], st.session_state['flip_vertically'], st.session_state['flip_horizontally']) 57 | 58 | selected_image = generate_pil_image(selected_image_location) 59 | st.info("Starting Image:") 60 | st.image(selected_image) 61 | else: 62 | st.error("No Starting Image Found") 63 | 64 | return selected_image 65 | 66 | 67 | def add_key_frame_element(shot_uuid): 68 | add1, add2 = st.columns(2) 69 | with add1: 70 | selected_image_location, inherit_styling_settings = add_key_frame_section(shot_uuid) 71 | with add2: 72 | selected_image = display_selected_key_frame(selected_image_location, False) 73 | 74 | return selected_image, inherit_styling_settings 75 | 76 | 77 | def add_key_frame( 78 | selected_image: Union[Image.Image, InternalFileObject], 79 | shot_uuid, 80 | target_frame_position=None, 81 | refresh_state=True, 82 | update_cur_frame_idx=True, 83 | ): 84 | """ 85 | either a pil image or a internalfileobject can be passed to this method, for adding it inside a shot 86 | """ 87 | data_repo = DataRepo() 88 | timing_list = data_repo.get_timing_list_from_shot(shot_uuid) 89 | 90 | # creating frame inside the shot at 
target_frame_position 91 | len_shot_timing_list = len(timing_list) if len(timing_list) > 0 else 0 92 | target_frame_position = len_shot_timing_list if target_frame_position is None else target_frame_position 93 | target_aux_frame_index = min(len(timing_list), target_frame_position) 94 | 95 | if isinstance(selected_image, InternalFileObject): 96 | saved_image = selected_image 97 | else: 98 | shot = data_repo.get_shot_from_uuid(shot_uuid) 99 | saved_image = save_new_image(selected_image, shot.project.uuid) 100 | 101 | timing_data = { 102 | "shot_id": shot_uuid, 103 | "animation_style": AnimationStyleType.CREATIVE_INTERPOLATION.value, 104 | "aux_frame_index": target_aux_frame_index, 105 | "source_image_id": saved_image.uuid, 106 | "primary_image_id": saved_image.uuid, 107 | } 108 | new_timing: InternalFrameTimingObject = data_repo.create_timing(**timing_data) 109 | 110 | if update_cur_frame_idx: 111 | timing_list = data_repo.get_timing_list_from_shot(shot_uuid) 112 | # this part of code updates current_frame_index when a new keyframe is added 113 | if len(timing_list) <= 1: 114 | st.session_state["current_frame_index"] = 1 115 | st.session_state["current_frame_uuid"] = timing_list[0].uuid 116 | else: 117 | st.session_state["prev_frame_index"] = min(len(timing_list), target_aux_frame_index + 1) 118 | st.session_state["current_frame_index"] = min(len(timing_list), target_aux_frame_index + 1) 119 | st.session_state["current_frame_uuid"] = timing_list[ 120 | st.session_state["current_frame_index"] - 1 121 | ].uuid 122 | 123 | print( 124 | f"Updated session state: current_frame_index: {st.session_state['current_frame_index']}, current_frame_uuid: {st.session_state['current_frame_uuid']}" 125 | ) 126 | 127 | if refresh_state: 128 | refresh_app() 129 | 130 | return new_timing 131 | -------------------------------------------------------------------------------- /ui_components/widgets/frame_selector.py: 
-------------------------------------------------------------------------------- 1 | from typing import List 2 | import streamlit as st 3 | from utils.data_repo.data_repo import DataRepo 4 | from utils import st_memory 5 | from ui_components.methods.common_methods import add_new_shot 6 | from utils.state_refresh import refresh_app 7 | 8 | 9 | def frame_selector_widget(show_frame_selector=True): 10 | data_repo = DataRepo() 11 | timing_list = data_repo.get_timing_list_from_shot(st.session_state["shot_uuid"]) 12 | shot = data_repo.get_shot_from_uuid(st.session_state["shot_uuid"]) 13 | shot_list = data_repo.get_shot_list(shot.project.uuid) 14 | len_timing_list = len(timing_list) if len(timing_list) > 0 else 1.0 15 | project_uuid = shot.project.uuid 16 | 17 | if "prev_shot_index" not in st.session_state: 18 | st.session_state["prev_shot_index"] = shot.shot_idx 19 | if "shot_name" not in st.session_state: 20 | st.session_state["shot_name"] = shot.name 21 | shot1, shot2 = st.columns([1, 1]) 22 | with shot1: 23 | shot_names = [s.name for s in shot_list] 24 | shot_names.append("**Create New Shot**") 25 | current_shot_name = st.selectbox( 26 | "Shot:", shot_names, key="current_shot_sidebar_selector", index=shot_names.index(shot.name) 27 | ) 28 | if current_shot_name != "**Create New Shot**": 29 | if current_shot_name != st.session_state["shot_name"]: 30 | st.session_state["shot_name"] = current_shot_name 31 | refresh_app() 32 | 33 | if current_shot_name == "**Create New Shot**": 34 | new_shot_name = st.text_input( 35 | "New shot name:", max_chars=40, key=f"shot_name_sidebar_{st.session_state['shot_name']}" 36 | ) 37 | if st.button("Create new shot", key=f"create_new_shot_{st.session_state['shot_name']}"): 38 | new_shot = add_new_shot(project_uuid, name=new_shot_name) 39 | st.session_state["shot_name"] = new_shot_name 40 | st.session_state["shot_uuid"] = new_shot.uuid 41 | refresh_app() 42 | 43 | # find shot index based on shot name 44 | st.session_state["current_shot_index"] 
= shot_names.index(st.session_state["shot_name"]) + 1 45 | 46 | if st.session_state["shot_name"] != shot.name: 47 | st.session_state["shot_uuid"] = shot_list[shot_names.index(st.session_state["shot_name"])].uuid 48 | refresh_app() 49 | 50 | if not ("current_shot_index" in st.session_state and st.session_state["current_shot_index"]): 51 | st.session_state["current_shot_index"] = shot_names.index(st.session_state["shot_name"]) + 1 52 | update_current_shot_index(st.session_state["current_shot_index"]) 53 | 54 | if st.session_state["page"] == "Key Frames": 55 | if st.session_state["current_frame_index"] > len_timing_list: 56 | update_current_frame_index(len_timing_list) 57 | 58 | elif st.session_state["page"] == "Shots": 59 | if st.session_state["current_shot_index"] > len(shot_list): 60 | update_current_shot_index(len(shot_list)) 61 | 62 | if show_frame_selector: 63 | if len(timing_list): 64 | if "prev_frame_index" not in st.session_state or st.session_state["prev_frame_index"] > len( 65 | timing_list 66 | ): 67 | st.session_state["prev_frame_index"] = 1 68 | 69 | # Create a list of frames with a blank value as the first item 70 | frame_list = [""] + [f"{i+1}" for i in range(len(timing_list))] 71 | with shot2: 72 | frame_selection = st_memory.selectbox( 73 | "Frame:", frame_list, key="current_frame_sidebar_selector" 74 | ) 75 | 76 | # only trigger the frame number extraction and current frame index update if a non-empty value is selected 77 | if frame_selection != "": 78 | if st.button("Jump to shot view", use_container_width=True): 79 | st.session_state["current_frame_sidebar_selector"] = 0 80 | refresh_app() 81 | 82 | st.session_state["current_frame_index"] = int(frame_selection.split(" ")[-1]) 83 | update_current_frame_index(st.session_state["current_frame_index"]) 84 | else: 85 | frame_selection = "" 86 | with shot2: 87 | st.write("") 88 | st.error("No frames present") 89 | 90 | return frame_selection 91 | 92 | 93 | def update_current_frame_index(index): 94 | 
data_repo = DataRepo() 95 | timing_list = data_repo.get_timing_list_from_shot(st.session_state["shot_uuid"]) 96 | st.session_state["current_frame_uuid"] = timing_list[index - 1].uuid 97 | if st.session_state["prev_frame_index"] != index or st.session_state["current_frame_index"] != index: 98 | st.session_state["prev_frame_index"] = index 99 | st.session_state["current_frame_index"] = index 100 | st.session_state["current_frame_uuid"] = timing_list[index - 1].uuid 101 | st.session_state["reset_canvas"] = True 102 | st.session_state["frame_styling_view_type_index"] = 0 103 | st.session_state["frame_styling_view_type"] = "Generate View" 104 | 105 | refresh_app() 106 | 107 | 108 | def update_current_shot_index(index): 109 | data_repo = DataRepo() 110 | shot_list = data_repo.get_shot_list(st.session_state["project_uuid"]) 111 | st.session_state["shot_uuid"] = shot_list[index - 1].uuid 112 | if st.session_state["prev_shot_index"] != index or st.session_state["current_shot_index"] != index: 113 | st.session_state["current_shot_index"] = index 114 | st.session_state["prev_shot_index"] = index 115 | st.session_state["shot_uuid"] = shot_list[index - 1].uuid 116 | st.session_state["reset_canvas"] = True 117 | st.session_state["frame_styling_view_type_index"] = 0 118 | st.session_state["frame_styling_view_type"] = "Individual View" 119 | 120 | refresh_app() 121 | -------------------------------------------------------------------------------- /backend/serializers/dto.py: -------------------------------------------------------------------------------- 1 | import json 2 | from rest_framework import serializers 3 | 4 | from backend.models import ( 5 | AIModel, 6 | AppSetting, 7 | BackupTiming, 8 | InferenceLog, 9 | InternalFileObject, 10 | Project, 11 | Setting, 12 | Shot, 13 | Timing, 14 | User, 15 | ) 16 | 17 | 18 | class UserDto(serializers.ModelSerializer): 19 | class Meta: 20 | model = User 21 | fields = ("uuid", "name", "email", "type", "total_credits") 22 | 23 | 24 | class 
ProjectDto(serializers.ModelSerializer): 25 | user_uuid = serializers.SerializerMethodField() 26 | 27 | class Meta: 28 | model = Project 29 | fields = ("uuid", "name", "user_uuid", "created_on", "temp_file_list", "meta_data") 30 | 31 | def get_user_uuid(self, obj): 32 | return obj.user.uuid 33 | 34 | 35 | class AIModelDto(serializers.ModelSerializer): 36 | user_uuid = serializers.SerializerMethodField() 37 | 38 | class Meta: 39 | model = AIModel 40 | fields = ( 41 | "uuid", 42 | "name", 43 | "user_uuid", 44 | "custom_trained", 45 | "version", 46 | "replicate_model_id", 47 | "replicate_url", 48 | "diffusers_url", 49 | "category", 50 | "training_image_list", 51 | "keyword", 52 | "created_on", 53 | ) 54 | 55 | def get_user_uuid(self, obj): 56 | return obj.user.uuid 57 | 58 | 59 | class InferenceLogDto(serializers.ModelSerializer): 60 | project = ProjectDto() 61 | model = AIModelDto() 62 | 63 | class Meta: 64 | model = InferenceLog 65 | fields = ( 66 | "uuid", 67 | "project", 68 | "model", 69 | "input_params", 70 | "output_details", 71 | "total_inference_time", 72 | "credits_used", 73 | "created_on", 74 | "updated_on", 75 | "status", 76 | "model_name", 77 | "generation_source", 78 | "generation_tag", 79 | ) 80 | 81 | 82 | class InternalFileDto(serializers.ModelSerializer): 83 | project = ProjectDto() # TODO: pass this as context to speed up the api 84 | inference_log = InferenceLogDto() 85 | 86 | class Meta: 87 | model = InternalFileObject 88 | fields = ( 89 | "uuid", 90 | "name", 91 | "local_path", 92 | "type", 93 | "hosted_url", 94 | "created_on", 95 | "inference_log", 96 | "project", 97 | "tag", 98 | "shot_uuid", 99 | ) 100 | 101 | 102 | class BasicShotDto(serializers.ModelSerializer): 103 | project = ProjectDto() 104 | 105 | class Meta: 106 | model = Shot 107 | fields = ( 108 | "uuid", 109 | "name", 110 | "project", 111 | "desc", 112 | "shot_idx", 113 | "project", 114 | "duration", 115 | "meta_data", 116 | ) 117 | 118 | 119 | class 
TimingDto(serializers.ModelSerializer): 120 | model = AIModelDto() 121 | source_image = InternalFileDto() 122 | mask = InternalFileDto() 123 | canny_image = InternalFileDto() 124 | primary_image = InternalFileDto() 125 | shot = BasicShotDto() 126 | 127 | class Meta: 128 | model = Timing 129 | fields = ( 130 | "uuid", 131 | "model", 132 | "source_image", 133 | "mask", 134 | "canny_image", 135 | "primary_image", 136 | "alternative_images", 137 | "notes", 138 | "aux_frame_index", 139 | "created_on", 140 | "shot", 141 | ) 142 | 143 | 144 | class AppSettingDto(serializers.ModelSerializer): 145 | user = UserDto() 146 | 147 | class Meta: 148 | model = AppSetting 149 | fields = ("uuid", "user", "previous_project", "replicate_username", "welcome_state", "created_on") 150 | 151 | 152 | class SettingDto(serializers.ModelSerializer): 153 | project = ProjectDto() 154 | default_model = AIModelDto() 155 | audio = InternalFileDto() 156 | 157 | class Meta: 158 | model = Setting 159 | fields = ( 160 | "uuid", 161 | "project", 162 | "default_model", 163 | "audio", 164 | "input_type", 165 | "width", 166 | "height", 167 | "created_on", 168 | ) 169 | 170 | 171 | class BackupDto(serializers.ModelSerializer): 172 | project = ProjectDto() 173 | 174 | class Meta: 175 | model = BackupTiming 176 | fields = ("name", "project", "note", "data_dump", "created_on") 177 | 178 | 179 | class BackupListDto(serializers.ModelSerializer): 180 | project = ProjectDto() 181 | 182 | class Meta: 183 | model = BackupTiming 184 | fields = ("uuid", "project", "name", "note", "created_on") 185 | 186 | 187 | class ShotDto(serializers.ModelSerializer): 188 | timing_list = serializers.SerializerMethodField() 189 | interpolated_clip_list = serializers.SerializerMethodField() 190 | main_clip = InternalFileDto() 191 | project = ProjectDto() 192 | 193 | class Meta: 194 | model = Shot 195 | fields = ( 196 | "uuid", 197 | "name", 198 | "desc", 199 | "shot_idx", 200 | "project", 201 | "duration", 202 | "meta_data", 203 | 
"timing_list", 204 | "interpolated_clip_list", 205 | "main_clip", 206 | ) 207 | 208 | def get_timing_list(self, obj): 209 | timing_list = self.context.get("timing_list", []) 210 | timing_list = [ 211 | TimingDto(timing).data for timing in timing_list if str(timing.shot.uuid) == str(obj.uuid) 212 | ] 213 | timing_list.sort(key=lambda x: x["aux_frame_index"]) 214 | return timing_list 215 | 216 | def get_interpolated_clip_list(self, obj): 217 | id_list = json.loads(obj.interpolated_clip_list) if obj.interpolated_clip_list else [] 218 | file_list = InternalFileObject.objects.filter(uuid__in=id_list, is_disabled=False).all() 219 | return [InternalFileDto(file).data for file in file_list] 220 | --------------------------------------------------------------------------------