├── __init__.py
├── shared
├── __init__.py
├── logging
│ ├── constants.py
│ └── logging.py
├── utils.py
└── file_upload
│ └── s3.py
├── utils
├── __init__.py
├── refresh_target.py
├── state_refresh.py
├── enum.py
├── cookie_manager
│ └── cookie.py
├── local_storage
│ ├── url_storage.py
│ └── local_storage.py
├── ml_processor
│ ├── ml_interface.py
│ ├── comfy_workflows
│ │ ├── llama2_workflow_api.json
│ │ ├── flux_schnell_workflow_api.json
│ │ ├── llama_workflow_api.json
│ │ ├── sdxl_img2img_workflow_api.json
│ │ ├── creative_image_gen.json
│ │ ├── ipadapter_composition_workflow_api.json
│ │ ├── sd3_workflow_api.json
│ │ ├── ipadapter_plus_api.json
│ │ ├── dynamicrafter_api.json
│ │ ├── sdxl_controlnet_workflow_api.json
│ │ ├── sdxl_workflow_api.json
│ │ ├── sdxl_openpose_workflow_api.json
│ │ └── ipadapter_face_api.json
│ ├── motion_module.py
│ ├── sai
│ │ ├── sai.py
│ │ └── utils.py
│ └── gpu
│ │ ├── gpu.py
│ │ └── utils.py
├── encryption.py
├── third_party_auth
│ └── google
│ │ └── google_auth.py
├── media_processor
│ └── video.py
├── cache
│ └── cache.py
└── common_decorators.py
├── backend
├── __init__.py
├── migrations
│ ├── __init__.py
│ ├── 0017_credits_used_added.py
│ ├── 0002_temp_file_list_added.py
│ ├── 0005_model_type_added.py
│ ├── 0010_project_metadata_added.py
│ ├── 0004_credits_added_in_user.py
│ ├── 0003_custom_trained_model_check_added.py
│ ├── 0009_log_status_added.py
│ ├── 0008_interpolated_clip_list_added.py
│ ├── 0006_inference_time_converted_to_float.py
│ ├── 0015_log_tag_added.py
│ ├── 0016_db_col_size_update.py
│ ├── 0007_log_mapped_to_file.py
│ ├── 0013_filter_keys_added.py
│ ├── 0011_lock_added.py
│ └── 0014_file_link_added.py
├── constants.py
├── apps.py
└── serializers
│ └── dto.py
├── scripts
├── app_version.txt
├── app_settings.toml
├── windows_setup_online.bat
├── windows_setup.bat
├── linux_setup_online.sh
├── linux_setup.sh
├── entrypoint.sh
├── entrypoint.bat
└── config.toml
├── .streamlit
├── credentials.toml
└── config.toml
├── .env.sample
├── sample_assets
├── sample_images
│ └── init_frames
│ │ ├── 1.jpg
│ │ ├── 2.jpg
│ │ └── 3.jpg
└── example_generations
│ ├── guy-result.mp4
│ ├── lady-result.mp4
│ └── world-result.mp4
├── .dockerignore
├── pyproject.toml
├── .vscode
└── settings.json
├── Dockerfile
├── ui_components
├── widgets
│ ├── base_theme.py
│ ├── common_element.py
│ ├── attach_audio_element.py
│ ├── display_element.py
│ ├── video_cropping_element.py
│ ├── frame_switch_btn.py
│ ├── download_file_progress_bar.py
│ ├── image_zoom_widgets.py
│ ├── image_carousal.py
│ ├── add_key_frame_element.py
│ └── frame_selector.py
├── components
│ ├── shortlist_page.py
│ ├── adjust_shot_page.py
│ ├── inspiraton_engine_page.py
│ ├── user_login_page.py
│ ├── timeline_view_page.py
│ ├── project_settings_page.py
│ ├── query_logger_page.py
│ └── new_project_page.py
├── methods
│ ├── data_logger.py
│ └── training_methods.py
└── constants.py
├── manage.py
├── .gitignore
├── requirements.txt
├── django_settings.py
├── .aws
└── task-definition.json
├── .github
└── workflows
│ └── deploy.yml
├── LICENSE.txt
├── auto_refresh.py
└── app.py
/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/shared/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/backend/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/backend/migrations/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/scripts/app_version.txt:
--------------------------------------------------------------------------------
1 | 0.9.24
--------------------------------------------------------------------------------
/utils/refresh_target.py:
--------------------------------------------------------------------------------
1 | SAVE_STATE = 256
2 |
--------------------------------------------------------------------------------
/.streamlit/credentials.toml:
--------------------------------------------------------------------------------
1 | [general]
2 | email="xyz@gmil.com"
--------------------------------------------------------------------------------
/scripts/app_settings.toml:
--------------------------------------------------------------------------------
1 | automatic_update = false
2 | gpu_inference = true
3 | runner_process_port = 12345
--------------------------------------------------------------------------------
/utils/state_refresh.py:
--------------------------------------------------------------------------------
1 | import streamlit as st
2 |
3 |
def refresh_app(*args):
    """Re-run the Streamlit script from the top; any positional args are accepted and ignored."""
    st.rerun()
6 |
--------------------------------------------------------------------------------
/.env.sample:
--------------------------------------------------------------------------------
1 | SERVER=development
2 | SERVER_URL=https://api.banodoco.ai
3 | OFFLINE_MODE=True
4 | HOSTED_BACKGROUND_RUNNER_MODE=False
--------------------------------------------------------------------------------
/sample_assets/sample_images/init_frames/1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/banodoco/Dough/HEAD/sample_assets/sample_images/init_frames/1.jpg
--------------------------------------------------------------------------------
/sample_assets/sample_images/init_frames/2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/banodoco/Dough/HEAD/sample_assets/sample_images/init_frames/2.jpg
--------------------------------------------------------------------------------
/sample_assets/sample_images/init_frames/3.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/banodoco/Dough/HEAD/sample_assets/sample_images/init_frames/3.jpg
--------------------------------------------------------------------------------
/.dockerignore:
--------------------------------------------------------------------------------
1 | # system
2 | __pycache__/
3 | /.vscode
4 |
5 | # user
6 | /venv
7 | /videos
8 |
9 | banodoco_local.db
10 | doc.py
11 | .env
--------------------------------------------------------------------------------
/.streamlit/config.toml:
--------------------------------------------------------------------------------
1 | [theme]
2 | base="dark"
3 | primaryColor="#9b9bb1"
4 | backgroundColor="#2a3446"
5 | secondaryBackgroundColor="#26282f"
6 |
--------------------------------------------------------------------------------
/sample_assets/example_generations/guy-result.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/banodoco/Dough/HEAD/sample_assets/example_generations/guy-result.mp4
--------------------------------------------------------------------------------
/sample_assets/example_generations/lady-result.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/banodoco/Dough/HEAD/sample_assets/example_generations/lady-result.mp4
--------------------------------------------------------------------------------
/backend/constants.py:
--------------------------------------------------------------------------------
1 | from utils.enum import ExtendedEnum
2 |
3 |
class UserType(ExtendedEnum):
    """Account role stored for each user."""

    USER = "user"
    ADMIN = "admin"
7 |
--------------------------------------------------------------------------------
/sample_assets/example_generations/world-result.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/banodoco/Dough/HEAD/sample_assets/example_generations/world-result.mp4
--------------------------------------------------------------------------------
/backend/apps.py:
--------------------------------------------------------------------------------
1 | from django.apps import AppConfig
2 |
3 |
class BackendConfig(AppConfig):
    """Django application config for the local `backend` app."""

    name = "backend"
    verbose_name = "Local backend"
7 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [project]
2 | name = "Dough"
3 | description = "Open source tool for steering AI animations with precision"
4 | readme = "README.md"
5 |
6 | [tool.black]
7 | line-length = 110
8 | target-version = ['py310']
9 |
--------------------------------------------------------------------------------
/.vscode/settings.json:
--------------------------------------------------------------------------------
1 | {
2 | "workbench.colorTheme": "Monokai",
3 | "[python]": {
4 | "editor.defaultFormatter": "ms-python.black-formatter",
5 | "editor.formatOnSave": true,
6 | "editor.formatOnSaveMode": "file",
7 | },
8 | "terminal.integrated.scrollback": 1000000,
9 | }
--------------------------------------------------------------------------------
/utils/enum.py:
--------------------------------------------------------------------------------
1 | from enum import Enum
2 |
3 |
class ExtendedEnum(Enum):
    """Enum base class with small introspection helpers used across the app."""

    @classmethod
    def value_list(cls):
        """Return every member's value, in definition order."""
        # Comprehension instead of list(map(lambda ...)) — same result, clearer.
        return [member.value for member in cls]

    @classmethod
    def has_value(cls, value):
        """Return True when *value* is the value of some member."""
        return value in cls._value2member_map_
13 |
--------------------------------------------------------------------------------
/Dockerfile:
--------------------------------------------------------------------------------
FROM python:3.10.2

RUN mkdir banodoco

WORKDIR /banodoco

# Copy only the dependency manifest first so the pip layer stays cached
# until requirements.txt actually changes.
COPY requirements.txt .

RUN pip3 install --upgrade pip
RUN pip3 install -r requirements.txt
# --no-install-recommends and clearing the apt lists keep the image smaller;
# combining the commands avoids a stale-cache layer on rebuilds.
RUN apt-get update \
    && apt-get install -y --no-install-recommends ffmpeg \
    && rm -rf /var/lib/apt/lists/*
RUN echo "SERVER=production" > .env

COPY . .

EXPOSE 5500

CMD ["sh", "entrypoint.sh"]
--------------------------------------------------------------------------------
/ui_components/widgets/base_theme.py:
--------------------------------------------------------------------------------
1 | import time
2 | import streamlit as st
3 | from utils.state_refresh import refresh_app
4 |
5 |
class BaseTheme:
    """Flash a status banner, pause briefly so it is visible, then rerun the app."""

    @staticmethod
    def _notify(render, msg):
        # Shared flow for both banner kinds.
        render(msg)
        time.sleep(0.5)
        refresh_app()

    @staticmethod
    def success_msg(msg):
        """Show *msg* as a success banner, then refresh."""
        BaseTheme._notify(st.success, msg)

    @staticmethod
    def error_msg(msg):
        """Show *msg* as an error banner, then refresh."""
        BaseTheme._notify(st.error, msg)
18 |
--------------------------------------------------------------------------------
/backend/migrations/0017_credits_used_added.py:
--------------------------------------------------------------------------------
1 | # Generated by Django 4.2.1 on 2024-09-17 11:34
2 |
3 | from django.db import migrations, models
4 |
5 |
class Migration(migrations.Migration):
    """Add `credits_used` (float, default 0) to InferenceLog."""

    dependencies = [
        ("backend", "0016_db_col_size_update"),
    ]

    operations = [
        migrations.AddField(
            model_name="inferencelog",
            name="credits_used",
            field=models.FloatField(default=0),
        ),
    ]
19 |
--------------------------------------------------------------------------------
/backend/migrations/0002_temp_file_list_added.py:
--------------------------------------------------------------------------------
1 | # Generated by Django 4.2.1 on 2023-07-26 07:18
2 |
3 | from django.db import migrations, models
4 |
5 |
class Migration(migrations.Migration):
    """Add nullable `temp_file_list` text field to Project."""

    dependencies = [
        ("backend", "0001_initial_setup"),
    ]

    operations = [
        migrations.AddField(
            model_name="project",
            name="temp_file_list",
            field=models.TextField(default=None, null=True),
        ),
    ]
19 |
--------------------------------------------------------------------------------
/backend/migrations/0005_model_type_added.py:
--------------------------------------------------------------------------------
1 | # Generated by Django 4.2.1 on 2023-08-15 07:46
2 |
3 | from django.db import migrations, models
4 |
5 |
class Migration(migrations.Migration):
    """Add `model_type` text field (blank, default "") to AIModel."""

    dependencies = [
        ("backend", "0004_credits_added_in_user"),
    ]

    operations = [
        migrations.AddField(
            model_name="aimodel",
            name="model_type",
            field=models.TextField(blank=True, default=""),
        ),
    ]
19 |
--------------------------------------------------------------------------------
/backend/migrations/0010_project_metadata_added.py:
--------------------------------------------------------------------------------
1 | # Generated by Django 4.2.1 on 2023-10-14 01:55
2 |
3 | from django.db import migrations, models
4 |
5 |
class Migration(migrations.Migration):
    """Add nullable `meta_data` text field to Project."""

    dependencies = [
        ("backend", "0009_log_status_added"),
    ]

    operations = [
        migrations.AddField(
            model_name="project",
            name="meta_data",
            field=models.TextField(default=None, null=True),
        ),
    ]
19 |
--------------------------------------------------------------------------------
/backend/migrations/0004_credits_added_in_user.py:
--------------------------------------------------------------------------------
1 | # Generated by Django 4.2.1 on 2023-08-02 06:43
2 |
3 | from django.db import migrations, models
4 |
5 |
class Migration(migrations.Migration):
    """Add `total_credits` (float, default 0) to User."""

    dependencies = [
        ("backend", "0003_custom_trained_model_check_added"),
    ]

    operations = [
        migrations.AddField(
            model_name="user",
            name="total_credits",
            field=models.FloatField(default=0),
        ),
    ]
19 |
--------------------------------------------------------------------------------
/backend/migrations/0003_custom_trained_model_check_added.py:
--------------------------------------------------------------------------------
1 | # Generated by Django 4.2.1 on 2023-07-31 05:47
2 |
3 | from django.db import migrations, models
4 |
5 |
class Migration(migrations.Migration):
    """Add `custom_trained` boolean flag (default False) to AIModel."""

    dependencies = [
        ("backend", "0002_temp_file_list_added"),
    ]

    operations = [
        migrations.AddField(
            model_name="aimodel",
            name="custom_trained",
            field=models.BooleanField(default=False),
        ),
    ]
19 |
--------------------------------------------------------------------------------
/backend/migrations/0009_log_status_added.py:
--------------------------------------------------------------------------------
1 | # Generated by Django 4.2.1 on 2023-10-09 14:17
2 |
3 | from django.db import migrations, models
4 |
5 |
class Migration(migrations.Migration):
    """Add `status` char field (default "", max 255) to InferenceLog."""

    dependencies = [
        ("backend", "0008_interpolated_clip_list_added"),
    ]

    operations = [
        migrations.AddField(
            model_name="inferencelog",
            name="status",
            field=models.CharField(default="", max_length=255),
        ),
    ]
19 |
--------------------------------------------------------------------------------
/shared/logging/constants.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass, field
2 | from utils.enum import ExtendedEnum
3 |
4 |
@dataclass
class LoggingPayload:
    """Message plus optional structured context handed to the logger."""

    message: str  # human-readable log text
    data: dict = field(default_factory=dict)  # extra structured fields, defaults to empty
9 |
10 |
class LoggingType(ExtendedEnum):
    """Category of a log event."""

    INFO = "info"
    INFERENCE_CALL = "inference_call"
    INFERENCE_RESULT = "inference_result"
    ERROR = "error"
    DEBUG = "debug"
17 |
18 |
class LoggingMode(ExtendedEnum):
    """Connectivity mode for the logger (semantics defined by its consumers)."""

    OFFLINE = "offline"
    ONLINE = "online"
22 |
--------------------------------------------------------------------------------
/manage.py:
--------------------------------------------------------------------------------
1 | import os, sys
2 |
if __name__ == "__main__":
    # Point Django at this project's settings module before anything imports it.
    os.environ.setdefault("DJANGO_SETTINGS_MODULE", "django_settings")

    try:
        from django.core.management import execute_from_command_line
    except ImportError as exc:
        raise ImportError(
            "Couldn't import Django. Are you sure it's installed and "
            "available on your PYTHONPATH environment variable? Did you "
            "forget to activate a virtual environment?"
        ) from exc

    # Dispatch to the management command named on the CLI (migrate, runserver, ...).
    execute_from_command_line(sys.argv)
16 |
--------------------------------------------------------------------------------
/scripts/windows_setup_online.bat:
--------------------------------------------------------------------------------
@echo off
rem Online installer: clone the Dough repo and install its Python deps into a venv.
set "folderName=Dough"
rem Skip if the Dough folder already exists, or if we are already inside it.
if not exist "%folderName%\" (
    if /i not "%CD%"=="%~dp0%folderName%\" (
        git clone --depth 1 -b main https://github.com/banodoco/Dough.git
        cd Dough
        python -m venv dough-env
        call dough-env\Scripts\activate.bat
        python.exe -m pip install --upgrade pip
        pip install -r requirements.txt
        pip install websocket
        call dough-env\Scripts\deactivate.bat
        copy .env.sample .env
        cd ..
        pause
    )
)
--------------------------------------------------------------------------------
/backend/migrations/0008_interpolated_clip_list_added.py:
--------------------------------------------------------------------------------
1 | # Generated by Django 4.2.1 on 2023-09-16 03:09
2 |
3 | from django.db import migrations, models
4 |
5 |
class Migration(migrations.Migration):
    """Replace Timing.interpolated_clip with a nullable `interpolated_clip_list` text field."""

    dependencies = [
        ("backend", "0007_log_mapped_to_file"),
    ]

    operations = [
        migrations.RemoveField(
            model_name="timing",
            name="interpolated_clip",
        ),
        migrations.AddField(
            model_name="timing",
            name="interpolated_clip_list",
            field=models.TextField(default=None, null=True),
        ),
    ]
23 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 | comfyui*.log
6 |
7 | venv
8 | dough-env
9 | .vscode/launch.json
10 |
11 | # system
12 | .DS_Store
13 |
14 | # user files
15 | app_settings.csv
16 | videos
17 | videos/*
18 | temp_dir/
19 | /temp/
20 | training_data
21 | inference_log/*
22 | *.db
23 | test.py
24 | doc.py
25 | .env
26 | data.json
27 | comfy_runner/
28 | ComfyUI/
29 | output/
30 | images.zip
31 | upload_data.json
32 |
33 | # generated file TODO: move inside videos
34 | depth.png
35 | masked_image.png
36 | concatenated.mp4
37 | temp.png
38 | scripts/app_checkpoint.json
39 | current.md
40 | refresh_lock.json
41 |
--------------------------------------------------------------------------------
/ui_components/widgets/common_element.py:
--------------------------------------------------------------------------------
1 | from utils.data_repo.data_repo import DataRepo
2 | import streamlit as st
3 | import time
4 | from utils.state_refresh import refresh_app
5 |
6 |
def duplicate_shot_button(shot_uuid, position="shot_view"):
    """Render a button that duplicates the given shot, then refresh the app.

    *position* is only used to keep the widget key unique per placement.
    """
    repo = DataRepo()
    shot = repo.get_shot_from_uuid(shot_uuid)
    clicked = st.button(
        "Duplicate shot",
        key=f"duplicate_btn_{shot.uuid}_{position}",
        help="This will duplicate this shot.",
        use_container_width=True,
    )
    if not clicked:
        return
    repo.duplicate_shot(shot.uuid)
    st.success("Shot duplicated successfully")
    time.sleep(0.3)
    refresh_app()
20 |
--------------------------------------------------------------------------------
/backend/migrations/0006_inference_time_converted_to_float.py:
--------------------------------------------------------------------------------
1 | # Generated by Django 4.2.1 on 2023-09-13 09:23
2 |
3 | from django.db import migrations, models
4 |
5 |
class Migration(migrations.Migration):
    """Convert InferenceLog.total_inference_time and Timing.strength to float fields."""

    dependencies = [
        ("backend", "0005_model_type_added"),
    ]

    operations = [
        migrations.AlterField(
            model_name="inferencelog",
            name="total_inference_time",
            field=models.FloatField(default=0),
        ),
        migrations.AlterField(
            model_name="timing",
            name="strength",
            field=models.FloatField(default=1),
        ),
    ]
24 |
--------------------------------------------------------------------------------
/backend/migrations/0015_log_tag_added.py:
--------------------------------------------------------------------------------
1 | # Generated by Django 4.2.1 on 2024-06-16 16:08
2 |
3 | from django.db import migrations, models
4 |
5 |
class Migration(migrations.Migration):
    """Add `generation_source` and `generation_tag` char fields to InferenceLog."""

    dependencies = [
        ("backend", "0014_file_link_added"),
    ]

    operations = [
        migrations.AddField(
            model_name="inferencelog",
            name="generation_source",
            field=models.CharField(blank=True, default="", max_length=255),
        ),
        migrations.AddField(
            model_name="inferencelog",
            name="generation_tag",
            field=models.CharField(blank=True, default="", max_length=255),
        ),
    ]
24 |
--------------------------------------------------------------------------------
/backend/migrations/0016_db_col_size_update.py:
--------------------------------------------------------------------------------
1 | # Generated by Django 4.2.1 on 2024-09-14 09:31
2 |
3 | from django.db import migrations, models
4 |
5 |
class Migration(migrations.Migration):
    """Widen AppSetting AWS key columns to max_length=1024."""

    dependencies = [
        ("backend", "0015_log_tag_added"),
    ]

    operations = [
        migrations.AlterField(
            model_name="appsetting",
            name="aws_access_key",
            field=models.CharField(blank=True, default="", max_length=1024),
        ),
        migrations.AlterField(
            model_name="appsetting",
            name="aws_secret_access_key",
            field=models.CharField(blank=True, default="", max_length=1024),
        ),
    ]
24 |
--------------------------------------------------------------------------------
/backend/migrations/0007_log_mapped_to_file.py:
--------------------------------------------------------------------------------
1 | # Generated by Django 4.2.1 on 2023-09-15 04:30
2 |
3 | from django.db import migrations, models
4 | import django.db.models.deletion
5 |
6 |
class Migration(migrations.Migration):
    """Link InternalFileObject to InferenceLog via a nullable FK (SET_NULL on delete)."""

    dependencies = [
        ("backend", "0006_inference_time_converted_to_float"),
    ]

    operations = [
        migrations.AddField(
            model_name="internalfileobject",
            name="inference_log",
            field=models.ForeignKey(
                default=None,
                null=True,
                on_delete=django.db.models.deletion.SET_NULL,
                to="backend.inferencelog",
            ),
        ),
    ]
25 |
--------------------------------------------------------------------------------
/backend/migrations/0013_filter_keys_added.py:
--------------------------------------------------------------------------------
1 | # Generated by Django 4.2.1 on 2024-02-18 07:55
2 |
3 | from django.db import migrations, models
4 |
5 |
class Migration(migrations.Migration):
    """Add filter keys: InferenceLog.model_name and InternalFileObject.shot_uuid."""

    dependencies = [
        ("backend", "0012_shot_added_and_redundant_fields_removed"),
    ]

    operations = [
        migrations.AddField(
            model_name="inferencelog",
            name="model_name",
            field=models.CharField(blank=True, default="", max_length=512),
        ),
        migrations.AddField(
            model_name="internalfileobject",
            name="shot_uuid",
            field=models.CharField(blank=True, default="", max_length=255),
        ),
    ]
24 |
--------------------------------------------------------------------------------
/utils/cookie_manager/cookie.py:
--------------------------------------------------------------------------------
1 | import datetime
2 | import streamlit as st
3 | import extra_streamlit_components as stx
4 |
5 | from utils.constants import AUTH_TOKEN
6 |
7 | # NOTE: code not working properly. check again after patch from the streamlit team
8 | # @st.cache(allow_output_mutation=True)
9 | # def get_manager():
10 | # return stx.CookieManager()
11 |
12 |
13 | # def get_cookie(key):
14 | # cookie_manager = get_manager()
15 | # return cookie_manager.get(cookie=key)
16 |
17 | # def set_cookie(key, value):
18 | # cookie_manager = get_manager()
19 | # expiration_time = datetime.datetime.now() + datetime.timedelta(days=1)
20 | # cookie_manager.set(key, value, expires_at=expiration_time)
21 |
22 | # def delete_cookie(key):
23 | # cookie_manager = get_manager()
24 | # cookie = get_cookie(key)
25 | # if cookie:
26 | # cookie_manager.delete(key)
27 |
--------------------------------------------------------------------------------
/utils/local_storage/url_storage.py:
--------------------------------------------------------------------------------
1 | # no persistent storage present for streamlit at the moment, so storing things in the url. will update this asap
2 | import streamlit as st
3 |
4 |
def get_url_param(key):
    """Fetch *key* from the URL query params, falling back to session state.

    When the URL has no value but session state does, the URL is re-synced
    from the session copy and that value is returned.
    """
    raw = st.experimental_get_query_params().get(key)
    # Query-param values come back as lists; unwrap a single value.
    res = raw[0] if isinstance(raw, list) else raw
    if res:
        return res

    cached = st.session_state[key] if key in st.session_state else None
    if cached:
        set_url_param(key, cached)
        return cached
    return res
17 |
18 |
def set_url_param(key, value):
    """Store *value* under *key* in both session state and the URL query params."""
    st.session_state[key] = value
    st.experimental_set_query_params(**{key: [value]})
22 |
23 |
def delete_url_param(key):
    """Remove *key* from session state and blank it out of the URL query params."""
    print("deleting key: ", key)
    # pop with a default mirrors the original `if key in ...: del` guard
    st.session_state.pop(key, None)
    st.experimental_set_query_params(**{key: None})
30 |
--------------------------------------------------------------------------------
/scripts/windows_setup.bat:
--------------------------------------------------------------------------------
@echo off
rem Full offline-capable installer: clone Dough plus comfy_runner and ComfyUI,
rem then install all Python deps (incl. CUDA 11.8 torch) into a venv.
set "folderName=Dough"
rem Skip if the Dough folder already exists, or if we are already inside it.
if not exist "%folderName%\" (
    if /i not "%CD%"=="%~dp0%folderName%\" (
        git clone --depth 1 -b main https://github.com/banodoco/Dough.git
        cd Dough
        git clone --depth 1 -b main https://github.com/piyushK52/comfy_runner.git
        git clone https://github.com/comfyanonymous/ComfyUI.git
        python -m venv dough-env
        call dough-env\Scripts\activate.bat
        python.exe -m pip install --upgrade pip
        pip install -r requirements.txt
        pip install websocket
        pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
        pip install -r comfy_runner\requirements.txt
        pip install -r ComfyUI\requirements.txt
        call dough-env\Scripts\deactivate.bat
        copy .env.sample .env
        cd ..
        pause
    )
)
--------------------------------------------------------------------------------
/utils/ml_processor/ml_interface.py:
--------------------------------------------------------------------------------
1 | from abc import ABC
2 |
3 | from shared.constants import GPU_INFERENCE_ENABLED_KEY, ConfigManager
4 |
5 | config_manager = ConfigManager()
6 | gpu_enabled = config_manager.get(GPU_INFERENCE_ENABLED_KEY, False)
7 |
def get_ml_client():
    """Return the ML processor matching the configured gpu-inference flag."""
    # Deferred imports: the backend modules are only resolved when a client
    # is actually requested.
    from utils.ml_processor.sai.api import APIProcessor
    from utils.ml_processor.gpu.gpu import GPUProcessor

    return GPUProcessor() if gpu_enabled else APIProcessor()
13 |
14 |
class MachineLearningProcessor(ABC):
    """Common interface for the ML backends returned by get_ml_client.

    All methods are no-op stubs here; concrete processors are expected to
    override the ones they support.
    """

    def __init__(self):
        pass

    def predict_model_output_standardized(self, *args, **kwargs):
        # Inference entry point taking params in the app's standardized format.
        pass

    def predict_model_output(self, *args, **kwargs):
        # Raw inference call.
        pass

    def upload_training_data(self, *args, **kwargs):
        # Upload training data (no-op in the base class).
        pass

    # NOTE: implementation not necessary as this functionality is removed from the app
    def dreambooth_training(self, *args, **kwargs):
        pass
31 |
--------------------------------------------------------------------------------
/scripts/linux_setup_online.sh:
--------------------------------------------------------------------------------
#!/bin/bash

# Store the current directory path
current_dir="$(pwd)"

# Define the project directory path
project_dir="$current_dir/Dough"

# Check if the "Dough" directory doesn't exist and we're not already inside it
if [ ! -d "$project_dir" ] && [ "$(basename "$current_dir")" != "Dough" ]; then
    # Clone the git repo (shallow, main branch only)
    git clone --depth 1 -b main https://github.com/banodoco/Dough.git "$project_dir"
    # Abort if the clone/cd failed instead of installing into the wrong directory
    cd "$project_dir" || exit 1

    # Create virtual environment
    python3.10 -m venv "dough-env"

    # Install system dependencies (sudo when available, e.g. not inside containers)
    if command -v sudo &> /dev/null; then
        sudo apt-get update && sudo apt-get install -y libpq-dev python3.10-dev
    else
        apt-get update && apt-get install -y libpq-dev python3.10-dev
    fi

    # Log the working directory (plain pwd; `echo $(pwd)` is a useless subshell — SC2005)
    pwd
    . ./dough-env/bin/activate && pip install -r "requirements.txt"

    # Copy the environment file
    cp "$project_dir/.env.sample" "$project_dir/.env"
fi
31 |
--------------------------------------------------------------------------------
/ui_components/components/shortlist_page.py:
--------------------------------------------------------------------------------
1 | import streamlit as st
2 | from ui_components.components.explorer_page import gallery_image_view
3 | from utils.data_repo.data_repo import DataRepo
4 | from utils import st_memory
5 |
6 |
def shortlist_page(project_uuid):
    """Render the shortlist view: breadcrumb header plus the shortlisted-image gallery."""
    st.markdown(f"#### :green[{st.session_state['main_view_type']}] > :red[{st.session_state['page']}]")

    repo = DataRepo()
    # Fetched eagerly; the value is not read by this view.
    project_setting = repo.get_project_setting(project_uuid)

    st.markdown("***")
    gallery_image_view(
        project_uuid,
        True,
        view=["view_inference_details", "add_to_any_shot", "add_and_remove_from_shortlist"],
    )
24 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | streamlit==1.27.0
2 | streamlit-image-comparison==0.0.3
3 | opencv-python-headless
4 | opencv-python==4.8.0.74
5 | sahi
6 | Pillow==9.4.0
7 | moviepy==1.0.3
8 | pandas==2.2.2
9 | replicate==0.26.1
10 | requests==2.28.2
11 | boto3==1.26.54
12 | ffmpeg-python==0.2.0
13 | streamlit-drawable-canvas==0.9.0
14 | numpy==1.24.4
15 | mesa
16 | streamlit-javascript==0.1.5
17 | streamlit-cropper==0.2.1
18 | pydub==0.25.1
19 | requests-toolbelt==1.0.0
20 | streamlit-option-menu==0.3.6
21 | colorlog==6.7.0
22 | Django==4.2.1
23 | djangorestframework==3.14.0
24 | python-dotenv==0.19.2
25 | scikit-image==0.21.0
26 | sentry-sdk==1.29.2
27 | st-clickable-images==0.0.3
28 | htbuilder==0.6.1
29 | streamlit-extras==0.2.7
30 | cryptography==41.0.1
31 | watchdog==3.0.0
32 | httpx-oauth==0.13.0
33 | extra-streamlit-components==0.1.56
34 | wrapt==1.15.0
35 | pydantic==1.10.9
36 | streamlit-server-state==0.17.1
37 | setproctitle==1.3.3
38 | gitdb==4.0.11
39 | websockets
40 | psutil==5.9.8
41 | black==24.4.2
42 | onnxruntime==1.18.1
43 | opencv-contrib-python==4.10.0.84
44 | portalocker==2.10.1
45 | Flask==3.0.3
46 | PyJWT==2.9.0
--------------------------------------------------------------------------------
/backend/migrations/0011_lock_added.py:
--------------------------------------------------------------------------------
1 | # Generated by Django 4.2.1 on 2023-10-17 05:06
2 |
3 | from django.db import migrations, models
4 | import uuid
5 |
6 |
class Migration(migrations.Migration):
    """Create the Lock table (unique `row_key`) used for row-level locking."""

    dependencies = [
        ("backend", "0010_project_metadata_added"),
    ]

    operations = [
        migrations.CreateModel(
            name="Lock",
            fields=[
                (
                    "id",
                    models.BigAutoField(
                        auto_created=True, primary_key=True, serialize=False, verbose_name="ID"
                    ),
                ),
                ("uuid", models.UUIDField(default=uuid.uuid4)),
                ("created_on", models.DateTimeField(auto_now_add=True)),
                ("updated_on", models.DateTimeField(auto_now=True)),
                ("is_disabled", models.BooleanField(default=False)),
                ("row_key", models.CharField(max_length=255, unique=True)),
            ],
            options={
                "db_table": "lock",
            },
        ),
    ]
34 |
--------------------------------------------------------------------------------
/utils/local_storage/local_storage.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 |
4 |
5 | MOTION_LORA_DB = "data.json"
6 |
7 |
def is_file_present(filename):
    """Return True when *filename* exists as a regular file next to this module."""
    module_dir = os.path.dirname(os.path.abspath(__file__))
    return os.path.isfile(os.path.join(module_dir, filename))
12 |
13 |
def write_to_motion_lora_local_db(update_data):
    """Merge *update_data* (a dict) into the motion-LoRA JSON store.

    Existing keys are overwritten, unrelated keys are preserved.  A missing
    or corrupt store file is treated as empty (best-effort read) rather than
    crashing the UI.
    """
    data_store = MOTION_LORA_DB

    data = {}
    if os.path.exists(data_store):
        try:
            with open(data_store, "r", encoding="utf-8") as file:
                data = json.load(file)
        except (OSError, ValueError):
            # narrow the original bare `except Exception: pass`: only an
            # unreadable file or invalid JSON should fall back to a fresh store
            data = {}

    data.update(update_data)

    with open(data_store, "w", encoding="utf-8") as file:
        json.dump(data, file, indent=4)
31 |
32 |
def read_from_motion_lora_local_db(key=None):
    """Read the motion-LoRA store.

    Returns the entry for *key* when present; otherwise returns the whole
    dict (including when *key* is None or the store file is absent).
    """
    data_store = MOTION_LORA_DB

    contents = {}
    if os.path.exists(data_store):
        with open(data_store, "r", encoding="utf-8") as file:
            contents = json.load(file)

    if key in contents:
        return contents[key]
    return contents
42 |
--------------------------------------------------------------------------------
/utils/ml_processor/comfy_workflows/llama2_workflow_api.json:
--------------------------------------------------------------------------------
1 | {
2 | "1": {
3 | "inputs": {
4 | "Model": "llama-2-7b.Q5_0.gguf",
5 | "n_ctx": 0
6 | },
7 | "class_type": "Load LLM Model Basic",
8 | "_meta": {
9 | "title": "Load LLM Model Basic"
10 | }
11 | },
12 | "14": {
13 | "inputs": {
14 | "text": [
15 | "15",
16 | 0
17 | ]
18 | },
19 | "class_type": "ShowText|pysssss",
20 | "_meta": {
21 | "title": "Show Text 🐍"
22 | }
23 | },
24 | "15": {
25 | "inputs": {
26 | "prompt": "write a poem on finding your way in about 100 words",
27 | "suffix": "",
28 | "max_response_tokens": 500,
29 | "temperature": 0.8,
30 | "top_p": 0.95,
31 | "min_p": 0.05,
32 | "typical_p": 1,
33 | "echo": false,
34 | "frequency_penalty": 0,
35 | "presence_penalty": 0,
36 | "repeat_penalty": 1.1,
37 | "top_k": 40,
38 | "seed": 273,
39 | "tfs_z": 1,
40 | "mirostat_mode": 0,
41 | "mirostat_tau": 5,
42 | "mirostat_eta": 0.1,
43 | "LLM": [
44 | "1",
45 | 0
46 | ]
47 | },
48 | "class_type": "Call LLM Advanced",
49 | "_meta": {
50 | "title": "Call LLM Advanced"
51 | }
52 | }
53 | }
--------------------------------------------------------------------------------
/ui_components/widgets/attach_audio_element.py:
--------------------------------------------------------------------------------
1 | import streamlit as st
2 | from ui_components.methods.common_methods import save_audio_file
3 | from ui_components.models import InternalProjectObject, InternalSettingObject
4 | from utils.data_repo.data_repo import DataRepo
5 | from utils.state_refresh import refresh_app
6 |
7 |
def attach_audio_element(project_uuid, expanded):
    """Render the audio expander: upload/attach an mp3 and preview the current one."""
    repo = DataRepo()
    settings: InternalSettingObject = repo.get_project_setting(project_uuid)

    with st.expander("🔊 Audio", expanded=expanded):

        audio_upload = st.file_uploader(
            "Attach audio", type=["mp3"], help="This will attach this audio when you render a video"
        )
        if st.button("Upload and attach new audio"):
            if not audio_upload:
                st.warning("No file selected")
            else:
                save_audio_file(audio_upload, project_uuid)
                refresh_app()

        attached_audio = settings.audio
        if attached_audio:
            # TODO: store "extracted_audio.mp3" in a constant
            if attached_audio.name == "extracted_audio.mp3":
                st.info("You have attached the audio from the video you uploaded.")

            if attached_audio.location:
                st.audio(attached_audio.location)
31 |
--------------------------------------------------------------------------------
/scripts/linux_setup.sh:
--------------------------------------------------------------------------------
#!/bin/bash

# Store the current directory path
current_dir="$(pwd)"

# Define the project directory path
project_dir="$current_dir/Dough"

# Check if the "Dough" directory doesn't exist and we're not already inside it
if [ ! -d "$project_dir" ] && [ "$(basename "$current_dir")" != "Dough" ]; then
    # Clone the git repo
    git clone --depth 1 -b main https://github.com/banodoco/Dough.git "$project_dir"
    # BUG FIX: abort if the clone failed — everything below assumes we are
    # inside the project directory (shellcheck SC2164)
    cd "$project_dir" || exit 1
    git clone --depth 1 -b main https://github.com/piyushK52/comfy_runner.git
    git clone https://github.com/comfyanonymous/ComfyUI.git

    # Create virtual environment
    python3 -m venv "dough-env"

    # Install system dependencies (sudo when available; plain apt-get in containers)
    if command -v sudo &> /dev/null; then
        sudo apt-get update && sudo apt-get install -y libpq-dev python3.10-dev
    else
        apt-get update && apt-get install -y libpq-dev python3.10-dev
    fi

    # Install Python dependencies
    echo "$(pwd)"
    . ./dough-env/bin/activate && pip install -r "requirements.txt"
    . ./dough-env/bin/activate && pip install -r "comfy_runner/requirements.txt"
    . ./dough-env/bin/activate && pip install -r "ComfyUI/requirements.txt"

    # Copy the environment file
    cp "$project_dir/.env.sample" "$project_dir/.env"
fi
36 |
--------------------------------------------------------------------------------
/utils/ml_processor/motion_module.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 |
3 |
@dataclass
class MotionModuleCheckpoint:
    """A named AnimateDiff motion-module checkpoint file."""

    name: str


# make sure to have unique names (streamlit limitation)
class AnimateDiffCheckpoint:
    """Registry of the known AnimateDiff motion-module checkpoints."""

    mm_v15 = MotionModuleCheckpoint(name="mm_sd_v15.ckpt")
    mm_v14 = MotionModuleCheckpoint(name="mm_sd_v14.ckpt")

    @staticmethod
    def _checkpoint_list():
        # Every MotionModuleCheckpoint class attribute, in dir() (alphabetical)
        # order.  Extracted to remove the copy-pasted comprehension that both
        # public helpers previously duplicated.
        return [
            getattr(AnimateDiffCheckpoint, attr)
            for attr in dir(AnimateDiffCheckpoint)
            if not attr.startswith("__")
            and isinstance(getattr(AnimateDiffCheckpoint, attr), MotionModuleCheckpoint)
        ]

    @staticmethod
    def get_name_list():
        """Return the checkpoint file names (used to feed streamlit selectors)."""
        return [ckpt.name for ckpt in AnimateDiffCheckpoint._checkpoint_list()]

    @staticmethod
    def get_model_from_name(name):
        """Return the checkpoint whose file name equals *name*, or None."""
        for ckpt in AnimateDiffCheckpoint._checkpoint_list():
            if ckpt.name == name:
                return ckpt

        return None
40 |
--------------------------------------------------------------------------------
/ui_components/widgets/display_element.py:
--------------------------------------------------------------------------------
1 | from typing import Union
2 | import streamlit as st
3 | from shared.constants import SERVER, ServerType
4 | from ui_components.methods.file_methods import get_file_size
5 | from ui_components.models import InternalFileObject
6 | from utils.local_storage.local_storage import read_from_motion_lora_local_db
7 |
8 | MAX_LOADING_FILE_SIZE = 10
9 |
10 |
def individual_video_display_element(file: Union[InternalFileObject, str], dont_bypass_file_size_check=True):
    """Render a video player for *file* (an InternalFileObject or a raw location string).

    Large files are suppressed outside development unless the size check is bypassed.
    """
    if file and not isinstance(file, str) and file.location:
        file_location = file.location
    else:
        file_location = file

    # evaluated before the location check, matching the original call order
    show_video_file = (
        SERVER == ServerType.DEVELOPMENT.value
        or not dont_bypass_file_size_check
        or get_file_size(file_location) < MAX_LOADING_FILE_SIZE
    )

    if not file_location:
        st.error("No video present")
        return

    if show_video_file:
        st.video(file_location, format="mp4", start_time=0)
    else:
        st.info("Video file too large to display")
26 |
27 |
def display_motion_lora(motion_lora, lora_file_dict=None):
    """Show the preview image for *motion_lora*.

    Looks in the local motion-LoRA store first, then falls back to matching
    *motion_lora* against the basenames of *lora_file_dict*'s keys.

    The original used a mutable default argument (``{}``); ``None`` is the
    safe equivalent and remains backward-compatible since the dict is only read.
    """
    lora_file_dict = lora_file_dict if lora_file_dict is not None else {}
    filename_video_dict = read_from_motion_lora_local_db()

    if motion_lora and motion_lora in filename_video_dict and filename_video_dict[motion_lora]:
        st.image(filename_video_dict[motion_lora])
    elif motion_lora in lora_file_dict:
        # keys look like paths; match against their basenames
        lora_keys = list(lora_file_dict.keys())
        basenames = [key.split("/")[-1] for key in lora_keys]
        try:
            idx = basenames.index(motion_lora)
            preview = lora_file_dict[lora_keys[idx]]
            if preview:
                st.image(preview)
        except ValueError:
            st.write("")
41 |
--------------------------------------------------------------------------------
/ui_components/components/adjust_shot_page.py:
--------------------------------------------------------------------------------
1 | import streamlit as st
2 | from ui_components.widgets.shot_view import shot_keyframe_element
3 | from ui_components.components.explorer_page import gallery_image_view
4 | from ui_components.components.explorer_page import generate_images_element
5 | from ui_components.components.frame_styling_page import frame_styling_page
6 | from ui_components.widgets.frame_selector import frame_selector_widget
7 | from utils import st_memory
8 | from ui_components.widgets.sidebar_logger import sidebar_logger
9 | from utils.data_repo.data_repo import DataRepo
10 |
11 |
def adjust_shot_page(shot_uuid: str, h2):
    """Frame-adjustment page for a shot.

    With no frame selected, renders the shot's keyframe grid plus the
    generation-log sidebar; with a frame selected, defers to the
    frame-styling page.  ``h2`` is accepted for parity with the other page
    callbacks but is not used here.
    """
    with st.sidebar:
        frame_selection = frame_selector_widget(show_frame_selector=True)

    data_repo = DataRepo()
    shot = data_repo.get_shot_from_uuid(shot_uuid)

    # empty selection -> whole-shot keyframe view
    if frame_selection == "":
        with st.sidebar:
            st.write("")

            with st.expander("🔍 Generation log", expanded=True):
                # if st_memory.toggle("Open", value=True, key="generaton_log_toggle"):
                sidebar_logger(st.session_state["shot_uuid"])

        # breadcrumb: main view > page > shot
        st.markdown(
            f"#### :green[{st.session_state['main_view_type']}] > :red[{st.session_state['page']}] > :blue[{shot.name}]"
        )
        st.markdown("***")

        column1, column2 = st.columns([2, 1.35])
        with column1:
            st.markdown(f"### 🎬 '{shot.name}' frames")
            st.write("##### _\_\_\_\_\_\_\_\_\_\_\_\_\_\_\_\_\_\_\_\_\_\_\_")
            items_per_row = st_memory.slider("Items per row:", 1, 10, 6, key="items_per_row")

        shot_keyframe_element(st.session_state["shot_uuid"], items_per_row, column2, position="Individual")

    else:
        frame_styling_page(st.session_state["shot_uuid"])
42 |
--------------------------------------------------------------------------------
/utils/encryption.py:
--------------------------------------------------------------------------------
1 | import base64
2 | import hashlib
3 | import os
4 | from cryptography.fernet import Fernet
5 | from shared.constants import ENCRYPTION_KEY
6 |
7 |
class Encryptor:
    """Thin wrapper around Fernet symmetric encryption using the app-wide key."""

    def __init__(self):
        self.cipher = Fernet(ENCRYPTION_KEY)

    def encrypt(self, data: str):
        """Encrypt a UTF-8 string; returns the raw Fernet token bytes."""
        return self.cipher.encrypt(data.encode("utf-8"))

    def decrypt(self, data: str):
        """Decrypt a token string back to plain text.

        Tolerates tokens that were stringified from bytes (``"b'...'"``) by
        stripping the ``b'`` prefix and the trailing quote before decoding.
        """
        if data.startswith("b'"):
            token_bytes = data[2:-1].encode()
        else:
            token_bytes = data.encode()
        return self.cipher.decrypt(token_bytes).decode("utf-8")

    def encrypt_json(self, data: str):
        """Encrypt then base64-encode, yielding an ASCII-safe string."""
        return base64.b64encode(self.encrypt(data)).decode("utf-8")

    def decrypt_json(self, data: str):
        """Inverse of :meth:`encrypt_json`."""
        return self.decrypt(base64.b64decode(data).decode("utf-8"))
31 |
32 |
def validate_file_hash(file_path, expected_hash_list):
    """Return True iff *file_path* exists and its MD5 digest is in *expected_hash_list*.

    Always returns a bool: the original expression returned the int ``0``
    when the expected list was empty.  MD5 is used here as a content
    fingerprint, not for security.
    """
    if not os.path.exists(file_path):
        return False

    hash_md5 = hashlib.md5()
    with open(file_path, "rb") as f:
        for chunk in iter(lambda: f.read(4096), b""):
            hash_md5.update(chunk)
    calculated_hash = hash_md5.hexdigest()

    return bool(expected_hash_list) and calculated_hash in expected_hash_list
44 |
45 |
def generate_file_hash(file_path):
    """Return the MD5 hex digest of *file_path*, or None when the file is absent."""
    if not os.path.exists(file_path):
        return None

    digest = hashlib.md5()
    with open(file_path, "rb") as handle:
        # read in 4 KiB chunks so large files never load fully into memory
        while chunk := handle.read(4096):
            digest.update(chunk)

    return digest.hexdigest()
56 |
--------------------------------------------------------------------------------
/utils/third_party_auth/google/google_auth.py:
--------------------------------------------------------------------------------
1 | from abc import ABC
2 | from dotenv import load_dotenv
3 | import requests
4 |
5 | from shared.constants import SERVER_URL
6 | from shared.logging.logging import AppLogger
7 |
8 | load_dotenv()
9 |
10 | logger = AppLogger()
11 |
12 |
def get_auth_provider():
    """Factory for the active third-party auth provider (currently always Google)."""
    return GoogleAuthProvider()  # TODO: add test provider for development
15 |
16 |
class AuthProvider(ABC):
    """Interface for third-party auth providers.

    NOTE(review): the methods are plain no-ops rather than ``@abstractmethod``,
    so the base class is instantiable — presumably left loose for a future
    test provider; confirm before tightening.
    """

    def get_auth_url(self, *args, **kwargs):
        # Should return a renderable login link for the provider, or None.
        pass

    def verify_auth_details(self, *args, **kwargs):
        # Should exchange provider auth details for (token, refresh_token, user).
        pass
23 |
24 |
25 | # TODO: make this implementation 'proper'
class GoogleAuthProvider(AuthProvider):
    """Auth provider that delegates the Google OAuth flow to the backend service."""

    def __init__(self):
        self.url = f"{SERVER_URL}/v1/authentication/google"

    def get_auth_url(self, redirect_uri):
        """Fetch the Google OAuth URL and return it as a clickable HTML link.

        Returns None when the backend call fails.  The caller renders the
        result with ``unsafe_allow_html=True``.
        """
        params = {"redirect_uri": redirect_uri}

        response = requests.get(self.url, params=params)  # TODO: add a request timeout

        if response.status_code == 200:
            data = response.json()
            auth_url = data["payload"]["data"]["url"]
            # BUG FIX: the fetched auth_url was never interpolated into the
            # returned f-string, so the rendered text had no actual link.
            return f"""<a href='{auth_url}' target='_self'> Google login -> </a>"""
        else:
            print(f"Error: {response.status_code} - {response.text}")

        return None

    def verify_auth_details(self, auth_details=None):
        """POST auth details to the backend and return (token, refresh_token, user).

        Returns ``(None, None, None)`` on any failure.
        """
        response = requests.post(self.url, json=auth_details, headers={"Content-Type": "application/json"})

        if response.status_code == 200:
            data = response.json()
            if not data["status"]:
                return None, None, None

            user = {"name": data["payload"]["user"]["name"], "email": data["payload"]["user"]["email"]}
            return data["payload"]["token"], data["payload"]["refresh_token"], user
        else:
            # the original passed response.text as a %-format argument with no
            # placeholder in the message, which breaks stdlib lazy formatting
            logger.error("auth verification failed: %s", response.text)

        return None, None, None
58 |
--------------------------------------------------------------------------------
/ui_components/widgets/video_cropping_element.py:
--------------------------------------------------------------------------------
1 | import streamlit as st
2 | import cv2
3 | from PIL import Image
4 | import numpy as np
5 | import tempfile
6 | from moviepy.editor import VideoFileClip
7 |
8 |
def video_cropping_element(shot_uuid):
    """Streamlit widget for trimming an uploaded (or linked) video to a time range.

    ``shot_uuid`` is currently unused; kept for parity with the other widgets.
    """
    st.title("Video Cropper")

    video_file = st.file_uploader("Upload a video", type=["mp4", "mov", "avi", "mkv"])
    video_url = st.text_input("...or enter a video URL")

    if video_file or video_url:
        if video_file:
            # close the handle after writing so OpenCV/moviepy can reopen the
            # file (the original leaked the handle, which also breaks on Windows)
            tfile = tempfile.NamedTemporaryFile(delete=False)
            tfile.write(video_file.read())
            tfile.close()
            video_path = tfile.name
        else:
            video_path = video_url

        cap = cv2.VideoCapture(video_path)
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        fps = cap.get(cv2.CAP_PROP_FPS)
        cap.release()

        # guard: unreadable streams report 0 fps, which previously crashed
        # with a ZeroDivisionError
        if not fps:
            st.error("Could not read the video's frame rate")
            return
        duration = total_frames / fps

        start_time = st.slider("Start Time", 0.0, float(duration), 0.0, 0.1)
        end_time = st.slider("End Time", 0.0, float(duration), float(duration), 0.1)

        starting1, starting2 = st.columns(2)
        with starting1:
            starting_frame_number = int(start_time * fps)
            display_frame(video_path, starting_frame_number)
        with starting2:
            ending_frame_number = int(end_time * fps)
            display_frame(video_path, ending_frame_number)

        if st.button("Save New Video"):
            with st.spinner("Processing..."):
                clip = VideoFileClip(video_path).subclip(start_time, end_time)
                # rsplit keeps dots in directory names intact (split(".")[0] did not)
                output_file = video_path.rsplit(".", 1)[0] + "_cropped.mp4"
                clip.write_videofile(output_file)
                st.success("Saved as {}".format(output_file))
46 |
47 |
def display_frame(video_path, frame_number):
    """Render frame *frame_number* of the video at *video_path* with st.image."""
    capture = cv2.VideoCapture(video_path)
    capture.set(cv2.CAP_PROP_POS_FRAMES, frame_number)
    success, bgr_frame = capture.read()
    if success:
        # OpenCV decodes BGR; streamlit expects RGB
        st.image(cv2.cvtColor(bgr_frame, cv2.COLOR_BGR2RGB))
    capture.release()
56 |
--------------------------------------------------------------------------------
/backend/migrations/0014_file_link_added.py:
--------------------------------------------------------------------------------
1 | # Generated by Django 4.2.1 on 2024-06-05 14:48
2 |
3 | from django.db import migrations, models
4 | import django.db.models.deletion
5 | import uuid
6 |
7 |
class Migration(migrations.Migration):
    """Creates the ``file_relationship`` table linking two internal file objects.

    Both ends are nullable FKs to ``internalfileobject``; ``transformation_type``
    presumably records how the child was derived from the parent — verify
    against the model definition.
    """

    dependencies = [
        ("backend", "0013_filter_keys_added"),
    ]

    operations = [
        migrations.CreateModel(
            name="FileRelationship",
            fields=[
                (
                    "id",
                    models.BigAutoField(
                        auto_created=True, primary_key=True, serialize=False, verbose_name="ID"
                    ),
                ),
                ("uuid", models.UUIDField(default=uuid.uuid4)),
                ("created_on", models.DateTimeField(auto_now_add=True)),
                ("updated_on", models.DateTimeField(auto_now=True)),
                ("is_disabled", models.BooleanField(default=False)),
                ("transformation_type", models.TextField(default="")),
                ("child_entity_type", models.CharField(default="file", max_length=255)),
                ("parent_entity_type", models.CharField(default="file", max_length=255)),
                (
                    "child_entity",
                    models.ForeignKey(
                        default=None,
                        null=True,
                        on_delete=django.db.models.deletion.CASCADE,
                        related_name="child_entity",
                        to="backend.internalfileobject",
                    ),
                ),
                (
                    "parent_entity",
                    models.ForeignKey(
                        default=None,
                        null=True,
                        on_delete=django.db.models.deletion.CASCADE,
                        related_name="parent_entity",
                        to="backend.internalfileobject",
                    ),
                ),
            ],
            options={
                "db_table": "file_relationship",
            },
        ),
    ]
57 |
--------------------------------------------------------------------------------
/utils/ml_processor/sai/sai.py:
--------------------------------------------------------------------------------
1 | import time
2 | from shared.constants import InferenceParamType
3 | from ui_components.methods.data_logger import log_model_inference
4 | from utils.constants import MLQueryObject
5 | from utils.ml_processor.constants import MLModel
6 | from utils.ml_processor.ml_interface import MachineLearningProcessor
7 | from utils.ml_processor.sai.utils import get_model_params_from_query_obj, predict_sai_output
8 |
9 |
10 | # rn only used for sd3 and doesn't have all methods of MachineLearningProcessor
class StabilityProcessor(MachineLearningProcessor):
    """Processor that runs image generation against the Stability AI API."""

    def __init__(self):
        pass

    def predict_model_output_standardized(
        self, model: MLModel, query_obj: MLQueryObject, queue_inference=False, backlog=False
    ):
        """Translate *query_obj* into SAI params, then run (or queue) the prediction."""
        params = get_model_params_from_query_obj(model, query_obj)
        if not params:
            return None  # unsupported model/query combination

        new_params = {
            InferenceParamType.QUERY_DICT.value: params,
            InferenceParamType.SAI_INFERENCE.value: params,
        }
        if queue_inference:
            return self.queue_prediction(model, **new_params)
        return self.predict_model_output(model, **new_params)  # add backlog later

    def predict_model_output(self, model: MLModel, **kwargs):
        """Run SAI inference synchronously (or hand off to the queue) and log it."""
        queue_inference = kwargs.get("queue_inference", False)
        if queue_inference:
            del kwargs["queue_inference"]
            return self.queue_prediction(model, **kwargs)

        start_time = time.time()
        output = predict_sai_output(kwargs.get(InferenceParamType.SAI_INFERENCE.value, None))
        elapsed = time.time() - start_time

        # log_model_inference takes its own "model" keyword; rename to avoid a clash
        if "model" in kwargs:
            kwargs["inf_model"] = kwargs.pop("model")

        log = log_model_inference(model, elapsed, **kwargs)
        # TODO: update usage credits in the api mode
        # self.update_usage_credits(end_time - start_time)

        return output, log

    def queue_prediction(self, model, **kwargs):
        """Only log the request; actual execution happens later via the queue."""
        return None, log_model_inference(model, None, **kwargs)
53 |
--------------------------------------------------------------------------------
/django_settings.py:
--------------------------------------------------------------------------------
1 | import os
2 | from pathlib import Path
3 | import sys
4 | from django.db.backends.signals import connection_created
5 |
6 | sys.path.append("../")
7 |
8 |
9 | from dotenv import load_dotenv
10 |
11 | from shared.constants import HOSTED_BACKGROUND_RUNNER_MODE, LOCAL_DATABASE_NAME, SERVER, ServerType
12 |
13 |
load_dotenv()

# Local (development) runs use the on-disk sqlite database; any other server
# type leaves the location blank.
if SERVER == ServerType.DEVELOPMENT.value:
    DB_LOCATION = LOCAL_DATABASE_NAME
else:
    DB_LOCATION = ""

# NOTE(review): parent.parent of this top-level module resolves one level
# ABOVE the repository root — confirm this is intentional.
BASE_DIR = Path(__file__).resolve().parent.parent
22 |
def set_sqlite_timeout(sender, connection, **kwargs):
    """connection_created hook: give sqlite a busy timeout to ride out row locks."""
    if connection.vendor != "sqlite":
        return
    connection.cursor().execute("PRAGMA busy_timeout = 30000;")  # 30 seconds
27 |
if HOSTED_BACKGROUND_RUNNER_MODE in [False, "False"]:
    # Local mode: sqlite on disk.
    DATABASES = {
        "default": {
            "ENGINE": "django.db.backends.sqlite3",  # sqlite by default works with serializable isolation level
            "NAME": DB_LOCATION,
        }
    }
    connection_created.connect(set_sqlite_timeout)
else:
    # Hosted mode: pull postgres credentials from the AWS SSM parameter store.
    import boto3

    ssm = boto3.client("ssm", region_name="ap-south-1")
    DB_NAME = ssm.get_parameter(Name="/backend/banodoco/db/name")["Parameter"]["Value"]
    DB_USER = ssm.get_parameter(Name="/backend/banodoco/db/user")["Parameter"]["Value"]
    DB_PASS = ssm.get_parameter(Name="/backend/banodoco/db/password")["Parameter"]["Value"]
    DB_HOST = ssm.get_parameter(Name="/backend/banodoco/db/host")["Parameter"]["Value"]
    DB_PORT = ssm.get_parameter(Name="/backend/banodoco/db/port")["Parameter"]["Value"]

    DATABASES = {
        "default": {
            "ENGINE": "django.db.backends.postgresql",
            "NAME": DB_NAME,
            "USER": DB_USER,
            "PASSWORD": DB_PASS,
            "HOST": DB_HOST,
            "PORT": DB_PORT,
            # BUG FIX: OPTIONS previously sat as a sibling of "default" at the
            # top of DATABASES, where Django never reads it — database options
            # must live inside the alias dict.
            "OPTIONS": {
                "isolation_level": "repeatable_read",  # TODO: test this isolation_level if deployed on prod
            },
        },
    }
59 |
INSTALLED_APPS = ("backend",)

DEFAULT_AUTO_FIELD = "django.db.models.BigAutoField"

# SECURITY NOTE(review): SECRET_KEY is hard-coded and committed.  Acceptable
# for a purely local app, but it must come from the environment if this
# settings module is ever used for a hosted deployment.
SECRET_KEY = "4e&6aw+(5&cg^_!05r(&7_#dghg_pdgopq(yk)xa^bog7j)^*j"
65 |
--------------------------------------------------------------------------------
/ui_components/components/inspiraton_engine_page.py:
--------------------------------------------------------------------------------
1 | from shared.constants import COMFY_BASE_PATH
2 | import streamlit as st
3 | from ui_components.constants import CreativeProcessType
4 | from ui_components.widgets.inspiration_engine import inspiration_engine_element
5 | from ui_components.widgets.timeline_view import timeline_view
6 | from ui_components.components.explorer_page import gallery_image_view
7 | from utils import st_memory
8 | from utils.data_repo.data_repo import DataRepo
9 |
10 | from ui_components.widgets.sidebar_logger import sidebar_logger
11 | from ui_components.components.explorer_page import generate_images_element
12 |
13 |
def inspiration_engine_page(shot_uuid: str, h2):
    """Inspiration-engine page: sidebar (generation log + shots timeline) and the
    image generator with its gallery underneath.

    ``h2`` is accepted for parity with the other page callbacks but unused here.
    """
    data_repo = DataRepo()
    shot = data_repo.get_shot_from_uuid(shot_uuid)
    if not shot:
        st.error("Shot not found")
    else:
        project_uuid = shot.project.uuid
        project = data_repo.get_project_from_uuid(project_uuid)

        with st.sidebar:
            views = CreativeProcessType.value_list()

            if "view" not in st.session_state:
                st.session_state["view"] = views[0]

            st.write("")

            with st.expander("🔍 Generation log", expanded=True):
                sidebar_logger(st.session_state["shot_uuid"])



            st.write("")
            with st.expander("🪄 Shots", expanded=True):
                timeline_view(shot_uuid, "🪄 Shots", view="sidebar")

        # breadcrumb: main view > page
        st.markdown(f"#### :green[{st.session_state['main_view_type']}] > :red[{st.session_state['page']}]")
        st.markdown("***")

        st.markdown("### ✨ Generate images")
        st.write("##### _\_\_\_\_\_\_\_\_\_\_\_\_\_\_\_\_\_\_\_\_\_\_\_")

        # "explorer" position: generation is project-wide (no timing/shot binding)
        inspiration_engine_element(
            position="explorer", project_uuid=project_uuid, timing_uuid=None, shot_uuid=None
        )

        st.markdown("***")
        gallery_image_view(
            project_uuid,
            False,
            view=[
                "add_and_remove_from_shortlist",
                "view_inference_details",
                "shot_chooser",
                "add_to_any_shot",
            ],
        )
61 |
--------------------------------------------------------------------------------
/ui_components/components/user_login_page.py:
--------------------------------------------------------------------------------
1 | import time
2 | import streamlit as st
3 |
4 | from utils.data_repo.api_repo import APIRepo
5 | from utils.local_storage.url_storage import set_url_param
6 | from utils.state_refresh import refresh_app
7 | from utils.third_party_auth.google.google_auth import get_auth_provider
8 | from streamlit.web.server.server import Server
9 |
10 |
def user_login_ui():
    """Login screen: handles the Google OAuth redirect or renders the login link.

    When Google redirects back, the URL carries a ``code`` query param which is
    exchanged (via the backend) for auth/refresh tokens; otherwise a login link
    is shown.  ``retry_login`` in session state forces the link view after a
    failed exchange.
    """
    params = st.experimental_get_query_params()
    api_repo = APIRepo()

    # redirect example: http://localhost:5500/?code=4%2F0AQlEd8x...&scope=email+profile+openid...&authuser=0&prompt=consent
    # params found: {'code': ['4/0AQlEd8x...'], 'scope': ['email profile openid'], 'authuser': ['0'], 'prompt': ['consent']}
    if params and "code" in params and not st.session_state.get("retry_login"):
        # (a former inner "reset retry_login" branch here was unreachable:
        # the condition above already guarantees retry_login is falsy)
        st.markdown("#### Logging you in, please wait...")
        data = {"id_token": params["code"][0]}
        auth_token, refresh_token, user = api_repo.auth_provider.verify_auth_details(data)
        if auth_token and refresh_token:
            st.success("Successfully logged In, settings things up...")
            api_repo.set_auth_token(auth_token, refresh_token, user)
            refresh_app()
        else:
            st.error("Unable to login..")
            if st.button("Retry Login", key="retry_login_btn"):
                st.session_state["retry_login"] = True
                refresh_app()
    else:
        st.session_state["retry_login"] = False
        st.markdown("# :green[D]:red[o]:blue[u]:orange[g]:green[h] :red[□] :blue[□] :orange[□]")
        st.markdown("#### Login with Google to proceed")

        auth_url = api_repo.auth_provider.get_auth_url(redirect_uri="http://localhost:5500")
        if auth_url:
            st.markdown(auth_url, unsafe_allow_html=True)
        else:
            time.sleep(0.1)
            st.warning("Unable to generate login link, please contact support")
46 |
--------------------------------------------------------------------------------
/shared/logging/logging.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import colorlog
3 | from shared.constants import SERVER, ServerType
4 |
5 | from shared.logging.constants import LoggingMode, LoggingPayload, LoggingType
6 |
7 |
class AppLogger(logging.Logger):
    """Colorized application logger (console plus optional file handler).

    NOTE: ``log`` deliberately shadows ``logging.Logger.log`` with a
    LoggingType-based signature used throughout the app.
    """

    def __init__(self, name="app_logger", log_file=None, log_level=logging.DEBUG):
        super().__init__(name, log_level)
        self.log_file = log_file

        self._configure_logging()

    def _configure_logging(self):
        """Attach color-formatted handlers and select the logging mode."""
        formatter = colorlog.ColoredFormatter(
            "%(log_color)s%(levelname)s:%(name)s:%(message)s",
            log_colors={
                "DEBUG": "cyan",
                "INFO": "green",
                "WARNING": "yellow",
                "ERROR": "red",
                "CRITICAL": "red,bg_white",
            },
            reset=True,
            secondary_log_colors={},
            style="%",
        )

        # file handler (when configured) first, console handler always
        handlers = [logging.StreamHandler()]
        if self.log_file:
            handlers.insert(0, logging.FileHandler(self.log_file))
        for handler in handlers:
            handler.setFormatter(formatter)
            self.addHandler(handler)

        # offline logging everywhere except production
        if SERVER != ServerType.PRODUCTION.value:
            self.logging_mode = LoggingMode.OFFLINE.value
        else:
            self.logging_mode = LoggingMode.ONLINE.value

    def log(self, log_type: LoggingType, log_message, log_data=None):
        """Route *log_message* to the stdlib level matching *log_type*.

        ``log_data`` is accepted but currently unused.  Unknown log types are
        silently ignored.
        """
        if log_type == LoggingType.DEBUG:
            self.debug(log_message)
        elif log_type == LoggingType.INFO:
            self.info(log_message)
        elif log_type == LoggingType.ERROR:
            self.error(log_message)
        elif log_type in [LoggingType.INFERENCE_CALL, LoggingType.INFERENCE_RESULT]:
            self.info(log_message)


app_logger = AppLogger()
66 |
--------------------------------------------------------------------------------
/utils/ml_processor/sai/utils.py:
--------------------------------------------------------------------------------
1 | from copy import deepcopy
2 | import json
3 | import os
4 | import random
5 | import uuid
6 |
7 | import requests
8 | from shared.constants import InferenceParamType
9 | from utils.constants import MLQueryObject
10 | from utils.encryption import Encryptor
11 | from utils.ml_processor.constants import ML_MODEL, MLModel
12 |
13 |
def predict_sai_output(data):
    """Call Stability AI SD3 generation and save the resulting image to ``output/``.

    *data* is the SAI parameter dict; its nested ``data`` key carries the
    encrypted API key and is stripped from the request payload.  Returns the
    saved image path, or None when *data* is empty.  Raises ``Exception`` with
    the API's JSON body on a non-200 response.
    """
    if not data:
        return None

    # TODO: decouple encryptor from this function
    encryptor = Encryptor()
    sai_key = encryptor.decrypt_json(data["data"]["data"]["stability_key"])
    input_params = deepcopy(data)
    del input_params["data"]  # never send key material in the request payload

    response = requests.post(
        f"https://api.stability.ai/v2beta/stable-image/generate/sd3",
        headers={"authorization": f"Bearer {sai_key}", "accept": "image/*"},
        files={"none": ""},
        data=input_params,
    )

    if response.status_code == 200:
        # exist_ok avoids the check-then-create race of the original
        os.makedirs("output", exist_ok=True)
        unique_filename = os.path.join("output", str(uuid.uuid4()) + ".png")
        with open(unique_filename, "wb") as file:
            file.write(response.content)

        return unique_filename
    else:
        raise Exception(str(response.json()))
41 |
42 |
def get_closest_aspect_ratio(width, height):
    """Return the SAI-supported aspect-ratio string closest to width/height.

    Ties resolve to the earliest entry in the supported list, matching the
    original strict-less-than scan.
    """
    supported = ["16:9", "1:1", "21:9", "2:3", "3:2", "4:5", "5:4", "9:16", "9:21"]
    target = width / height

    def distance(ratio_str):
        w, h = ratio_str.split(":")
        return abs(target - int(w) / int(h))

    return min(supported, key=distance)
59 |
60 |
def get_model_params_from_query_obj(model: MLModel, query_obj: MLQueryObject):
    """Map a generic ML query object to SAI request params (sd3 only for now).

    Returns None for any other model.
    """
    if model != ML_MODEL.sd3:
        return None

    return {
        "mode": "text-to-image",
        "prompt": query_obj.prompt,
        "negative_prompt": query_obj.negative_prompt,
        "model": "sd3",
        "seed": random_seed(),
        "output_format": "png",
        "data": query_obj.data,
        "aspect_ratio": get_closest_aspect_ratio(query_obj.width, query_obj.height),
    }
75 |
76 |
def random_seed():
    """Return a pseudo-random seed with 7 or 8 digits.

    ``randrange(a, b)`` is equivalent to ``randint(a, b - 1)`` and consumes
    the RNG identically, so results match the previous implementation.
    """
    return random.randrange(10**6, 10**8)
79 |
--------------------------------------------------------------------------------
/ui_components/components/timeline_view_page.py:
--------------------------------------------------------------------------------
1 | from shared.constants import COMFY_BASE_PATH
2 | import streamlit as st
3 | from ui_components.constants import CreativeProcessType
4 | from ui_components.widgets.timeline_view import timeline_view
5 |
6 | from utils import st_memory
7 | from utils.data_repo.data_repo import DataRepo
8 | from ui_components.widgets.sidebar_logger import sidebar_logger
9 | from ui_components.components.explorer_page import gallery_image_view
10 |
11 |
def timeline_view_page(shot_uuid: str, h2):
    """Render the timeline page for a shot: sidebar tools plus the main shot grid.

    Streamlit renders widgets in call order, so statement order here IS the
    page layout.  Sidebar: generation log and a shortlist gallery; main area:
    breadcrumb header and the timeline of the project's shots.

    Args:
        shot_uuid: uuid of the shot whose project timeline is shown.
        h2: sidebar/header container handle passed by the caller
            (unused in this function - presumably reserved; verify with callers).
    """
    data_repo = DataRepo()
    shot = data_repo.get_shot_from_uuid(shot_uuid)
    if not shot:
        st.error("Shot not found")
    else:
        project_uuid = shot.project.uuid
        project = data_repo.get_project_from_uuid(project_uuid)

        with st.sidebar:
            views = CreativeProcessType.value_list()

            # default to the first creative-process view on first render
            if "view" not in st.session_state:
                st.session_state["view"] = views[0]

            st.write("")

            with st.expander("🔍 Generation log", expanded=True):
                # if st_memory.toggle("Open", value=True, key="generaton_log_toggle"):
                sidebar_logger(st.session_state["shot_uuid"])

            st.write("")

            # shortlisted generations with controls to move them into shots
            with st.expander("📋 Shortlist", expanded=True):
                if st_memory.toggle("Open", value=True, key="explorer_shortlist_toggle"):
                    gallery_image_view(
                        shot.project.uuid,
                        shortlist=True,
                        view=["add_and_remove_from_shortlist", "add_to_any_shot"],
                    )

        # breadcrumb: main view type > current page
        st.markdown(f"#### :green[{st.session_state['main_view_type']}] > :red[{st.session_state['page']}]")
        st.markdown("***")
        slider1, slider2, slider3 = st.columns([2, 1, 1])
        with slider1:
            st.markdown(f"### 🪄 '{project.name}' shots")
            st.write("##### _\_\_\_\_\_\_\_\_\_\_\_\_\_\_\_\_\_\_\_\_\_\_\_")

        # start_time = time.time()
        timeline_view(st.session_state["shot_uuid"], st.session_state["view"], view="main")

        # end_time = time.time()
        # print("///////////////// timeline laoded in: ", end_time - start_time)
        # generate_images_element(
        #     position="explorer", project_uuid=project_uuid, timing_uuid=None, shot_uuid=None
        # )

        # end_time = time.time()
        # print("///////////////// generate img laoded in: ", end_time - start_time)

        # end_time = time.time()
        # print("///////////////// gallery laoded in: ", end_time - start_time)
64 |
--------------------------------------------------------------------------------
/utils/ml_processor/comfy_workflows/flux_schnell_workflow_api.json:
--------------------------------------------------------------------------------
1 | {
2 | "6": {
3 | "inputs": {
4 | "text": "a dark fantasy castle",
5 | "clip": [
6 | "30",
7 | 1
8 | ]
9 | },
10 | "class_type": "CLIPTextEncode",
11 | "_meta": {
12 | "title": "CLIP Text Encode (Positive Prompt)"
13 | }
14 | },
15 | "8": {
16 | "inputs": {
17 | "samples": [
18 | "31",
19 | 0
20 | ],
21 | "vae": [
22 | "30",
23 | 2
24 | ]
25 | },
26 | "class_type": "VAEDecode",
27 | "_meta": {
28 | "title": "VAE Decode"
29 | }
30 | },
31 | "9": {
32 | "inputs": {
33 | "filename_prefix": "ComfyUI",
34 | "images": [
35 | "8",
36 | 0
37 | ]
38 | },
39 | "class_type": "SaveImage",
40 | "_meta": {
41 | "title": "Save Image"
42 | }
43 | },
44 | "27": {
45 | "inputs": {
46 | "width": 1024,
47 | "height": 1024,
48 | "batch_size": 1
49 | },
50 | "class_type": "EmptySD3LatentImage",
51 | "_meta": {
52 | "title": "EmptySD3LatentImage"
53 | }
54 | },
55 | "30": {
56 | "inputs": {
57 | "ckpt_name": "flux1-schnell-fp8.safetensors"
58 | },
59 | "class_type": "CheckpointLoaderSimple",
60 | "_meta": {
61 | "title": "Load Checkpoint"
62 | }
63 | },
64 | "31": {
65 | "inputs": {
66 | "seed": 1063867382637414,
67 | "steps": 4,
68 | "cfg": 1,
69 | "sampler_name": "euler",
70 | "scheduler": "simple",
71 | "denoise": 1,
72 | "model": [
73 | "30",
74 | 0
75 | ],
76 | "positive": [
77 | "6",
78 | 0
79 | ],
80 | "negative": [
81 | "33",
82 | 0
83 | ],
84 | "latent_image": [
85 | "27",
86 | 0
87 | ]
88 | },
89 | "class_type": "KSampler",
90 | "_meta": {
91 | "title": "KSampler"
92 | }
93 | },
94 | "33": {
95 | "inputs": {
96 | "text": "lowres, text, error, cropped, worst quality, low quality, jpeg artifacts, ugly, duplicate, morbid, mutilated, out of frame, extra fingers, mutated hands, poorly drawn hands, poorly drawn face, mutation, deformed, blurry, dehydrated, bad anatomy, bad proportions, extra limbs, cloned face, disfigured, gross proportions, malformed limbs, missing arms, missing legs, extra arms, extra legs, fused fingers, too many fingers, long neck, username, watermark, signature",
97 | "clip": [
98 | "30",
99 | 1
100 | ]
101 | },
102 | "class_type": "CLIPTextEncode",
103 | "_meta": {
104 | "title": "CLIP Text Encode (Negative Prompt)"
105 | }
106 | }
107 | }
--------------------------------------------------------------------------------
/utils/ml_processor/comfy_workflows/llama_workflow_api.json:
--------------------------------------------------------------------------------
1 | {
2 | "15": {
3 | "inputs": {
4 | "ckpt_name": "Meta-Llama-3-8B-Instruct-Q5_K_M.gguf",
5 | "max_ctx": 2048,
6 | "gpu_layers": 100,
7 | "n_threads": 2
8 | },
9 | "class_type": "LLMLoader",
10 | "_meta": {
11 | "title": "LLMLoader"
12 | }
13 | },
14 | "24": {
15 | "inputs": {
16 | "string": "Looking at the examples 1 and 2 below generate the sub-prompts for the FINAL EXAMPLE containing a concise story:\n\nEXAMPLE 1:\nOverall prompt:Story about Leonard Cohen's big day at the beach.\nNumber of items: 12 \nSub-prompts:Leonard Cohen looking lying in bed|Leonard Cohen brushing teeth|Leonard Cohen smiling happily|Leonard Cohen driving in car, morning|Leonard Cohen going for a swim at beach|Leonard Cohen sitting on beach towel|Leonard Cohen building sandcastle|Leonard Cohen eating sandwich|Leonard Cohen walking along beach|Leonard Cohen getting out of water at seaside|Leonard Cohen driving home, dark outside|Leonard Cohen lying in bed smiling|close of of Leonard Cohen asleep in bed---\n\nEXAMPLE 2:\nOverall prompt: Visualizing the first day of spring \nNumber of items: 24 \nSub-prompts:Frost melting off grass|Sun rising over dewy meadow|Sparrows chirping in tree|Puddles drying up|Bees flying out of hive|Flowers blooming in garden bed|Robin landing on branch|Steam rising from cup of coffee|Morning light creeping through curtains|Wind rustling through leaves|Crocus bulbs pushing through soil|Buds swelling on branches|Birds singing in chorus|Sun shining through rain-soaked pavement|Droplets clinging to spider's web|Green shoots bursting forth from roots|Garden hose dripping water|Fog burning off lake|Light filtering through stained glass window|Hummingbird sipping nectar from flower|Warm breeze rustling through wheat field|Birch trees donning new green coat|Solar eclipse casting shadow on path|Birds returning to their nests---\n\nFINAL EXAMPLE:\nOverall prompt:a magical dark castle \nNumber of items: 5 \nSub-prompts:"
17 | },
18 | "class_type": "String Literal",
19 | "_meta": {
20 | "title": "String Literal"
21 | }
22 | },
23 | "26": {
24 | "inputs": {
25 | "text": [
26 | "30",
27 | 0
28 | ]
29 | },
30 | "class_type": "ShowText|pysssss",
31 | "_meta": {
32 | "title": "✴️ U-NAI Get Text"
33 | }
34 | },
35 | "30": {
36 | "inputs": {
37 | "prompt": [
38 | "24",
39 | 0
40 | ],
41 | "temperature": 0.15,
42 | "attribute_name": "Sub-prompts",
43 | "attribute_type": "str",
44 | "attribute_description": "generate Sub-prompts:",
45 | "categories": "",
46 | "model": [
47 | "15",
48 | 0
49 | ]
50 | },
51 | "class_type": "StructuredOutput",
52 | "_meta": {
53 | "title": "Structured Output"
54 | }
55 | }
56 | }
--------------------------------------------------------------------------------
/utils/ml_processor/comfy_workflows/sdxl_img2img_workflow_api.json:
--------------------------------------------------------------------------------
1 | {
2 | "1": {
3 | "inputs": {
4 | "ckpt_name": "sd_xl_base_1.0.safetensors"
5 | },
6 | "class_type": "CheckpointLoaderSimple",
7 | "_meta": {
8 | "title": "Load Checkpoint"
9 | }
10 | },
11 | "31": {
12 | "inputs": {
13 | "filename_prefix": "ComfyUI",
14 | "images": [
15 | "44:0",
16 | 0
17 | ]
18 | },
19 | "class_type": "SaveImage",
20 | "_meta": {
21 | "title": "Save Image"
22 | }
23 | },
24 | "37:0": {
25 | "inputs": {
26 | "image": "cca52d68-91aa-4db8-b724-2370d03ff987.png",
27 | "upload": "image"
28 | },
29 | "class_type": "LoadImage",
30 | "_meta": {
31 | "title": "Load Image"
32 | }
33 | },
34 | "37:1": {
35 | "inputs": {
36 | "pixels": [
37 | "37:0",
38 | 0
39 | ],
40 | "vae": [
41 | "1",
42 | 2
43 | ]
44 | },
45 | "class_type": "VAEEncode",
46 | "_meta": {
47 | "title": "VAE Encode"
48 | }
49 | },
50 | "42:0": {
51 | "inputs": {
52 | "text": "pic of a king",
53 | "clip": [
54 | "1",
55 | 1
56 | ]
57 | },
58 | "class_type": "CLIPTextEncode",
59 | "_meta": {
60 | "title": "Positive prompt"
61 | }
62 | },
63 | "42:1": {
64 | "inputs": {
65 | "text": "",
66 | "clip": [
67 | "1",
68 | 1
69 | ]
70 | },
71 | "class_type": "CLIPTextEncode",
72 | "_meta": {
73 | "title": "Negative prompt (not used)"
74 | }
75 | },
76 | "42:2": {
77 | "inputs": {
78 | "seed": 89273174590337,
79 | "steps": 20,
80 | "cfg": 7,
81 | "sampler_name": "euler",
82 | "scheduler": "normal",
83 | "denoise": 0.6,
84 | "model": [
85 | "1",
86 | 0
87 | ],
88 | "positive": [
89 | "42:0",
90 | 0
91 | ],
92 | "negative": [
93 | "42:1",
94 | 0
95 | ],
96 | "latent_image": [
97 | "37:1",
98 | 0
99 | ]
100 | },
101 | "class_type": "KSampler",
102 | "_meta": {
103 | "title": "KSampler"
104 | }
105 | },
106 | "44:0": {
107 | "inputs": {
108 | "samples": [
109 | "42:2",
110 | 0
111 | ],
112 | "vae": [
113 | "1",
114 | 2
115 | ]
116 | },
117 | "class_type": "VAEDecode",
118 | "_meta": {
119 | "title": "VAE Decode"
120 | }
121 | },
122 | "44:1": {
123 | "inputs": {
124 | "images": [
125 | "44:0",
126 | 0
127 | ]
128 | },
129 | "class_type": "PreviewImage",
130 | "_meta": {
131 | "title": "Preview Image"
132 | }
133 | }
134 | }
--------------------------------------------------------------------------------
/ui_components/methods/data_logger.py:
--------------------------------------------------------------------------------
1 | import json
2 | import streamlit as st
3 | import time
4 | from shared.constants import InferenceParamType, InferenceStatus
5 | from shared.logging.constants import LoggingPayload, LoggingType
6 | from utils.common_utils import get_current_user_uuid
7 | from utils.data_repo.data_repo import DataRepo
8 |
9 | from utils.ml_processor.constants import ML_MODEL, MLModel
10 |
11 |
def log_model_inference(model: MLModel, time_taken, **kwargs):
    """Persist an inference-call record in the DB and return the created log.

    A falsy ``time_taken`` marks the run as queued (or backlogged when
    ``kwargs["backlog"]`` is truthy) rather than completed.
    """
    # keep only JSON-serializable kwargs (drops buffered readers, image objects, ...)
    serializable_kwargs = {
        key: value for key, value in kwargs.items() if isinstance(value, (int, str, list, dict))
    }

    data_str = json.dumps(serializable_kwargs)
    origin_data = serializable_kwargs.get(InferenceParamType.ORIGIN_DATA.value, {})
    time_taken = round(time_taken, 2) if time_taken else 0

    # system_logger = AppLogger()
    # logging_payload = LoggingPayload(message="logging inference data", data=data)

    # # logging in console
    # system_logger.log(LoggingType.INFERENCE_CALL, logging_payload)

    # storing the log in db
    data_repo = DataRepo()
    user_id = get_current_user_uuid()
    ai_model = data_repo.get_ai_model_from_name(model.name, user_id)

    # TODO: fix this - we were initially storing all the models and their versions in the database but later moved on from it
    # so earlier models are found when fetching ai_model but for the new models we are adding this hack for adding dummy model_id
    # hackish sol for insuring that inpainting logs don't have an empty model field
    if ai_model is None and model.name in [
        ML_MODEL.sdxl_inpainting.name,
        ML_MODEL.ad_interpolation.name,
        ML_MODEL.sd3.name,
    ]:
        ai_model = data_repo.get_ai_model_from_name(ML_MODEL.sdxl.name, user_id)

    if time_taken:
        status = InferenceStatus.COMPLETED.value
    elif kwargs.get("backlog"):
        status = InferenceStatus.BACKLOG.value
    else:
        status = InferenceStatus.QUEUED.value

    log_data = {
        "project_id": st.session_state["project_uuid"],
        "model_id": ai_model.uuid if ai_model else None,
        "input_params": data_str,
        "output_details": json.dumps({"model_name": model.display_name(), "version": model.version}),
        "total_inference_time": time_taken,
        "status": status,
        "model_name": model.display_name(),
        "generation_source": origin_data.get("inference_type", ""),
        "generation_tag": origin_data.get("inference_tag", ""),
    }

    return data_repo.create_inference_log(**log_data)
67 |
--------------------------------------------------------------------------------
/.aws/task-definition.json:
--------------------------------------------------------------------------------
1 | {
2 | "taskDefinitionArn": "arn:aws:ecs:ap-south-1:861629679241:task-definition/backend-banodoco-frontend-task:14",
3 | "containerDefinitions": [
4 | {
5 | "name": "backend-banodoco-frontend",
6 | "image": "861629679241.dkr.ecr.ap-south-1.amazonaws.com/banodoco-frontend:latest",
7 | "cpu": 512,
8 | "memory": 2048,
9 | "portMappings": [
10 | {
11 | "containerPort": 5500,
12 | "hostPort": 5500,
13 | "protocol": "tcp"
14 | }
15 | ],
16 | "essential": true,
17 | "environment": [
18 | {
19 | "name": "SERVER_URL",
20 | "value": "banodoco-backend.mvp.internal"
21 | },
22 | {
23 | "name": "TEST_2",
24 | "value": "test"
25 | }
26 | ],
27 | "mountPoints": [],
28 | "volumesFrom": [],
29 | "logConfiguration": {
30 | "logDriver": "awslogs",
31 | "options": {
32 | "awslogs-group": "/backend/banodoco-frontend",
33 | "awslogs-region": "ap-south-1",
34 | "awslogs-stream-prefix": "backend/banodoco-frontend"
35 | }
36 | }
37 | }
38 | ],
39 | "family": "backend-banodoco-frontend-task",
40 | "taskRoleArn": "arn:aws:iam::861629679241:role/ecs-task-role",
41 | "executionRoleArn": "arn:aws:iam::861629679241:role/ecs-task-execution-role",
42 | "networkMode": "awsvpc",
43 | "revision": 14,
44 | "volumes": [],
45 | "status": "ACTIVE",
46 | "requiresAttributes": [
47 | {
48 | "name": "com.amazonaws.ecs.capability.logging-driver.awslogs"
49 | },
50 | {
51 | "name": "ecs.capability.execution-role-awslogs"
52 | },
53 | {
54 | "name": "com.amazonaws.ecs.capability.ecr-auth"
55 | },
56 | {
57 | "name": "com.amazonaws.ecs.capability.docker-remote-api.1.19"
58 | },
59 | {
60 | "name": "com.amazonaws.ecs.capability.task-iam-role"
61 | },
62 | {
63 | "name": "ecs.capability.execution-role-ecr-pull"
64 | },
65 | {
66 | "name": "com.amazonaws.ecs.capability.docker-remote-api.1.18"
67 | },
68 | {
69 | "name": "ecs.capability.task-eni"
70 | }
71 | ],
72 | "placementConstraints": [],
73 | "compatibilities": [
74 | "EC2",
75 | "FARGATE"
76 | ],
77 | "requiresCompatibilities": [
78 | "FARGATE"
79 | ],
80 | "cpu": "512",
81 | "memory": "2048",
82 | "registeredAt": "2023-09-10T10:50:27.792Z",
83 | "registeredBy": "arn:aws:iam::861629679241:user/tf-admin",
84 | "tags": []
85 | }
--------------------------------------------------------------------------------
/utils/ml_processor/comfy_workflows/creative_image_gen.json:
--------------------------------------------------------------------------------
1 | {
2 | "3": {
3 | "inputs": {
4 | "seed": [
5 | "28",
6 | 0
7 | ],
8 | "steps": 20,
9 | "cfg": 7,
10 | "sampler_name": "dpmpp_2m",
11 | "scheduler": "karras",
12 | "denoise": 1,
13 | "model": [
14 | "24",
15 | 0
16 | ],
17 | "positive": [
18 | "6",
19 | 0
20 | ],
21 | "negative": [
22 | "7",
23 | 0
24 | ],
25 | "latent_image": [
26 | "5",
27 | 0
28 | ]
29 | },
30 | "class_type": "KSampler",
31 | "_meta": {
32 | "title": "KSampler"
33 | }
34 | },
35 | "4": {
36 | "inputs": {
37 | "ckpt_name": "sd_xl_base_1.0.safetensors"
38 | },
39 | "class_type": "CheckpointLoaderSimple",
40 | "_meta": {
41 | "title": "Load Checkpoint"
42 | }
43 | },
44 | "5": {
45 | "inputs": {
46 | "width": 1024,
47 | "height": 1024,
48 | "batch_size": 1
49 | },
50 | "class_type": "EmptyLatentImage",
51 | "_meta": {
52 | "title": "Empty Latent Image"
53 | }
54 | },
55 | "6": {
56 | "inputs": {
57 | "text": "beautiful ocean scene",
58 | "clip": [
59 | "4",
60 | 1
61 | ]
62 | },
63 | "class_type": "CLIPTextEncode",
64 | "_meta": {
65 | "title": "CLIP Text Encode (Prompt)"
66 | }
67 | },
68 | "7": {
69 | "inputs": {
70 | "text": "horror",
71 | "clip": [
72 | "4",
73 | 1
74 | ]
75 | },
76 | "class_type": "CLIPTextEncode",
77 | "_meta": {
78 | "title": "CLIP Text Encode (Prompt)"
79 | }
80 | },
81 | "10": {
82 | "inputs": {
83 | "vae_name": "sdxl_vae.safetensors"
84 | },
85 | "class_type": "VAELoader",
86 | "_meta": {
87 | "title": "Load VAE"
88 | }
89 | },
90 | "11": {
91 | "inputs": {
92 | "preset": "PLUS (high strength)",
93 | "model": [
94 | "4",
95 | 0
96 | ]
97 | },
98 | "class_type": "IPAdapterUnifiedLoader",
99 | "_meta": {
100 | "title": "IPAdapter Unified Loader"
101 | }
102 | },
103 | "16": {
104 | "inputs": {
105 | "samples": [
106 | "3",
107 | 0
108 | ],
109 | "vae": [
110 | "10",
111 | 0
112 | ]
113 | },
114 | "class_type": "VAEDecode",
115 | "_meta": {
116 | "title": "VAE Decode"
117 | }
118 | },
119 | "27": {
120 | "inputs": {
121 | "filename_prefix": "IPAdapter",
122 | "images": [
123 | "16",
124 | 0
125 | ]
126 | },
127 | "class_type": "SaveImage",
128 | "_meta": {
129 | "title": "Save Image"
130 | }
131 | },
132 | "28": {
133 | "inputs": {
134 | "seed": 832322928972889
135 | },
136 | "class_type": "ttN seed",
137 | "_meta": {
138 | "title": "seed"
139 | }
140 | }
141 | }
--------------------------------------------------------------------------------
/ui_components/widgets/frame_switch_btn.py:
--------------------------------------------------------------------------------
1 | from typing import List
2 | import streamlit as st
3 |
4 | from ui_components.models import InternalFrameTimingObject
5 | from utils.data_repo.data_repo import DataRepo
6 | from utils.state_refresh import refresh_app
7 |
8 |
def back_and_forward_buttons():
    """Render the frame-navigation button row (jump -2 / -1 / [current] / +1 / +2).

    Reads and writes ``current_frame_index`` / ``current_frame_uuid`` /
    ``prev_frame_index`` in streamlit session state and reruns the app on any
    jump.  The four jump handlers were copy-pasted in the previous version;
    they now share a single ``_jump`` helper.
    """
    data_repo = DataRepo()
    timing: InternalFrameTimingObject = data_repo.get_timing_from_uuid(
        st.session_state["current_frame_uuid"]
    )
    timing_list: List[InternalFrameTimingObject] = data_repo.get_timing_list_from_shot(timing.shot.uuid)

    smallbutton0, smallbutton1, smallbutton2, smallbutton3, smallbutton4 = st.columns([2, 2, 2, 2, 2])

    display_idx = st.session_state["current_frame_index"]
    frame_count = len(timing_list)

    def _jump(offset):
        # shift the 1-based frame pointer, remember it, resolve the uuid, rerun
        st.session_state["current_frame_index"] = st.session_state["current_frame_index"] + offset
        st.session_state["prev_frame_index"] = st.session_state["current_frame_index"]
        st.session_state["current_frame_uuid"] = timing_list[
            st.session_state["current_frame_index"] - 1
        ].uuid
        refresh_app()

    with smallbutton0:
        if display_idx > 2:  # a frame two steps back exists
            if st.button(f"{display_idx-2} ⏮️", key=f"Previous Previous Image for {display_idx}"):
                _jump(-2)
    with smallbutton1:
        # if it's not the first image
        if display_idx != 1:
            if st.button(f"{display_idx-1} ⏪", key=f"Previous Image for {display_idx}"):
                _jump(-1)

    with smallbutton2:
        st.button(f"{display_idx} 📍", disabled=True)
    with smallbutton3:
        # if it's not the last image
        if display_idx != frame_count:
            if st.button(f"{display_idx+1} ⏩", key=f"Next Image for {display_idx}"):
                _jump(1)
    with smallbutton4:
        if display_idx <= frame_count - 2:  # a frame two steps ahead exists
            if st.button(f"{display_idx+2} ⏭️", key=f"Next Next Image for {display_idx}"):
                _jump(2)
58 |
--------------------------------------------------------------------------------
/utils/media_processor/video.py:
--------------------------------------------------------------------------------
1 | import os
2 | import tempfile
3 | from moviepy.editor import VideoFileClip, vfx
4 |
5 |
class VideoProcessor:
    """Helpers for re-timing a video so it lasts a desired duration."""

    @staticmethod
    def update_video_speed(video_location, desired_duration):
        """Load the video at ``video_location`` and re-time it.

        Returns the processed video as mp4 bytes.
        """
        clip = VideoFileClip(video_location)

        return VideoProcessor.update_clip_speed(clip, desired_duration)

    @staticmethod
    def update_video_bytes_speed(video_bytes, desired_duration):
        """Re-time raw mp4 ``video_bytes`` to ``desired_duration`` seconds.

        Returns the processed video as mp4 bytes.  The intermediate temp file
        is removed even when processing raises.
        """
        # moviepy needs a real file, so spill the bytes to a temp path first
        with tempfile.NamedTemporaryFile(delete=False, suffix=".mp4", mode="wb") as temp_file:
            temp_file.write(video_bytes)
            temp_file_path = temp_file.name

        try:
            with VideoFileClip(temp_file_path) as clip:
                result = VideoProcessor.update_clip_speed(clip, desired_duration)
        finally:
            # previous version leaked this file when processing raised
            os.remove(temp_file_path)

        return result

    @staticmethod
    def update_clip_speed(clip: VideoFileClip, desired_duration):
        """Speed ``clip`` up or down so it lasts ``desired_duration`` seconds.

        Returns the re-encoded video as mp4 bytes.  The temp output file is
        removed even when encoding raises.
        """
        # reserve a temp output path (file handle closed immediately)
        with tempfile.NamedTemporaryFile(delete=False, suffix=".mp4", mode="wb") as temp_output_file:
            temp_output_path = temp_output_file.name

        try:
            input_video_duration = clip.duration
            desired_speed_change = float(input_video_duration) / float(desired_duration)
            print("Desired Speed Change: " + str(desired_speed_change))
            output_clip = clip.fx(vfx.speedx, desired_speed_change)
            output_clip.write_videofile(filename=temp_output_path, codec="libx264", preset="fast")

            with open(temp_output_path, "rb") as f:
                video_bytes = f.read()
        finally:
            # previous version leaked this file when encoding raised
            os.remove(temp_output_path)

        return video_bytes
69 |
--------------------------------------------------------------------------------
/utils/ml_processor/gpu/gpu.py:
--------------------------------------------------------------------------------
1 | import json
2 | from shared.constants import InferenceParamType
3 | from ui_components.methods.data_logger import log_model_inference
4 | from utils.constants import MLQueryObject
5 | from utils.data_repo.data_repo import DataRepo
6 | from utils.ml_processor.comfy_data_transform import get_file_path_list, get_model_workflow_from_query
7 | from utils.ml_processor.constants import MLModel
8 | from utils.ml_processor.gpu.utils import predict_gpu_output, setup_comfy_runner
9 | from utils.ml_processor.ml_interface import MachineLearningProcessor
10 | import time
11 |
12 |
class GPUProcessor(MachineLearningProcessor):
    """ML processor that runs inference on the local GPU through comfy_runner."""

    def __init__(self):
        setup_comfy_runner()
        self.app_settings = DataRepo().get_app_secrets_from_user_uuid()
        super().__init__()

    def predict_model_output_standardized(
        self,
        model: MLModel,
        query_obj: MLQueryObject,
        queue_inference=False,
        backlog=False,
    ):
        """Build the comfy_runner payload for ``query_obj`` and run or queue it."""
        (
            workflow_type,
            workflow_json,
            output_node_ids,
            extra_model_list,
            ignore_list,
        ) = get_model_workflow_from_query(model, query_obj)

        # payload format expected by comfy_runner
        gpu_payload = {
            "workflow_input": workflow_json,
            "file_path_list": get_file_path_list(model, query_obj),
            "output_node_ids": output_node_ids,
            "extra_model_list": extra_model_list,
            "ignore_model_list": ignore_list,
        }

        params = {
            "prompt": query_obj.prompt,  # hackish sol
            InferenceParamType.QUERY_DICT.value: query_obj.to_json(),
            InferenceParamType.GPU_INFERENCE.value: json.dumps(gpu_payload),
            InferenceParamType.FILE_RELATION_DATA.value: query_obj.relation_data,
        }

        if queue_inference:
            return self.queue_prediction(model, **params, backlog=backlog)
        return self.predict_model_output(model, **params)

    def predict_model_output(self, replicate_model: MLModel, **kwargs):
        """Run inference synchronously; returns (output, inference log)."""
        if kwargs.get("queue_inference", False):
            return self.queue_prediction(replicate_model, **kwargs)

        payload = json.loads(kwargs.get(InferenceParamType.GPU_INFERENCE.value, None))
        start_time = time.time()
        output = predict_gpu_output(
            payload["workflow_input"], payload["file_path_list"], payload["output_node_ids"]
        )
        elapsed = time.time() - start_time

        return output, log_model_inference(replicate_model, elapsed, **kwargs)

    def queue_prediction(self, replicate_model, **kwargs):
        """Record the request as queued without running it; returns (None, log)."""
        return None, log_model_inference(replicate_model, None, **kwargs)

    def upload_training_data(self, zip_file_name, delete_after_upload=False):
        # TODO: fix for online hosting
        # return the local file path as it is
        return zip_file_name
80 |
--------------------------------------------------------------------------------
/.github/workflows/deploy.yml:
--------------------------------------------------------------------------------
# CI/CD pipeline: on pushes to piyush-dev, build the frontend Docker image,
# push it to ECR, roll the new image out to the ECS service, and restart the
# self-hosted background runner.
name: Deploy to ECR

on:

  push:
    branches: [ piyush-dev ]

jobs:

  build:
    name: Build Image
    runs-on: ubuntu-latest

    steps:

    - name: Check out code
      uses: actions/checkout@v2

    - name: Configure AWS credentials
      uses: aws-actions/configure-aws-credentials@v1
      with:
        aws-access-key-id: ${{ secrets.AWS_ECR_ACCESS_KEY }}
        aws-secret-access-key: ${{ secrets.AWS_ECR_SECRET_KEY }}
        aws-region: ap-south-1

    - name: Login to Amazon ECR
      id: login-ecr
      uses: aws-actions/amazon-ecr-login@v1

    # Image is always tagged "latest"; ECS picks it up via the rendered
    # task definition below.
    - name: Build, tag, and push image to Amazon ECR
      env:
        ECR_REGISTRY: ${{ steps.login-ecr.outputs.registry }}
        ECR_REPOSITORY: banodoco-frontend
        IMAGE_TAG: latest
      id: build-image
      run: |
        docker build -t $ECR_REGISTRY/$ECR_REPOSITORY:$IMAGE_TAG .
        docker push $ECR_REGISTRY/$ECR_REPOSITORY:$IMAGE_TAG
        echo "image=$ECR_REGISTRY/$ECR_REPOSITORY:$IMAGE_TAG" >> $GITHUB_OUTPUT

    # The two ECS actions below are pinned by commit SHA rather than tag.
    - name: Fill in the new image ID in the Amazon ECS task definition
      env:
        ECS_TASK_DEFINITION: .aws/task-definition.json
        CONTAINER_NAME: backend-banodoco-frontend
      id: task-def
      uses: aws-actions/amazon-ecs-render-task-definition@c804dfbdd57f713b6c079302a4c01db7017a36fc
      with:
        task-definition: ${{ env.ECS_TASK_DEFINITION }}
        container-name: ${{ env.CONTAINER_NAME }}
        image: ${{ steps.build-image.outputs.image }}

    - name: Deploy Amazon ECS task definition
      env:
        ECS_SERVICE: backend-banodoco-frontend-service
        ECS_CLUSTER: backend-banodoco-frontend-cluster
      uses: aws-actions/amazon-ecs-deploy-task-definition@df9643053eda01f169e64a0e60233aacca83799a
      with:
        task-definition: ${{ steps.task-def.outputs.task-definition }}
        service: ${{ env.ECS_SERVICE }}
        cluster: ${{ env.ECS_CLUSTER }}
        wait-for-service-stability: true

  update-runner:
    name: Update Background Runner
    runs-on: self-hosted

    steps:
    - name: Checkout code
      uses: actions/checkout@v2

    # NOTE(review): every `run` step gets a fresh shell, so the `source`
    # here does not persist to later steps - they re-activate explicitly.
    - name: Create and activate virtual environment
      run: |
        python3 -m venv venv
        source venv/bin/activate

    # NOTE(review): the pip upgrade below runs against the system python3
    # (outside the venv), while requirements install inside it - confirm
    # this is intended.
    - name: Install dependencies
      run: |
        python3 -m pip install --upgrade pip
        source venv/bin/activate && pip install -r requirements.txt

    - name: Create .env file
      run: |
        touch .env
        echo "SERVER=production" > .env
        echo "SERVER_URL=https://api.banodoco.ai" >> .env
        echo "HOSTED_BACKGROUND_RUNNER_MODE=True" >> .env
        echo "admin_email=${{ secrets.ADMIN_EMAIL }}" >> .env
        echo "admin_password=${{ secrets.ADMIN_PASSWORD }}" >> .env
        echo "ENCRYPTION_KEY=${{ secrets.FERNET_ENCRYPTION_KEY }}" >> .env

    # `pkill -0` only probes for a matching process; the real kill happens
    # inside the if, then a fresh runner is started detached via nohup.
    - name: Restart runner
      run: |
        if pkill -0 -f "banodoco_runner"; then
          pkill -f "banodoco_runner"
        fi
        . venv/bin/activate && nohup python banodoco_runner.py > script_output.log 2>&1 &
--------------------------------------------------------------------------------
/shared/utils.py:
--------------------------------------------------------------------------------
1 | import datetime
2 | import time
3 | import jwt
4 | import requests
5 | from shared.constants import SERVER_URL, InternalResponse
6 | import urllib.parse
7 |
8 |
def execute_shell_command(command: str):
    """Run ``command`` through the shell and wrap the outcome.

    Returns an InternalResponse whose payload is the command's stdout and
    whose success flag reflects a zero exit code.

    NOTE(security): uses ``shell=True`` - ``command`` must never contain
    untrusted input.
    """
    import subprocess

    completed = subprocess.run(command, shell=True, capture_output=True, text=True)
    # print("Error:\n", completed.stderr)
    succeeded = completed.returncode == 0
    return InternalResponse(completed.stdout, "success", succeeded)
15 |
16 |
def is_online_file_path(file_path):
    """Return True when ``file_path`` is a remote URL (http/https/ftp scheme)."""
    scheme = urllib.parse.urlparse(file_path).scheme
    return scheme in ("http", "https", "ftp")
20 |
21 |
def is_url_valid(url):
    """Best-effort reachability check for ``url``.

    Issues a HEAD request (following redirects) and returns True when the
    inspected response carries an accepted status code; any exception
    (network error, bad URL, ...) yields False.
    """
    try:
        response = requests.head(url, allow_redirects=True)
        # NOTE(review): when redirects occurred, history[-1] is the LAST
        # redirect hop (a 3xx), not the final page; 307 being in the accepted
        # list suggests this is intentional - confirm with callers.
        final_response = response.history[-1] if response.history else response

        return final_response.status_code in [200, 201, 307]  # TODO: handle all possible status codes
    except Exception as e:
        return False
30 |
31 |
def get_file_type(url):
    """Classify ``url`` as "image", "video" or "unknown" via its Content-Type.

    Any request failure is logged and reported as "unknown".
    """
    try:
        content_type = requests.head(url).headers.get("Content-Type") or ""
    except Exception as e:
        print("Error:", e)
        return "unknown"

    if "image" in content_type:
        return "image"
    if "video" in content_type:
        return "video"
    return "unknown"
46 |
47 |
def generate_fresh_token(refresh_token):
    """Exchange ``refresh_token`` for a fresh (token, refresh_token) pair.

    Returns (None, None) when no refresh token is supplied or the backend
    rejects the refresh request.
    """
    if not refresh_token:
        return None, None

    response = requests.request(
        "GET",
        f"{SERVER_URL}/v1/authentication/refresh",
        headers={"Authorization": f"Bearer {refresh_token}"},
        data={},
    )
    if response.status_code != 200:
        return None, None

    payload = response.json()["payload"]
    return payload["token"], payload["refresh_token"]
63 |
64 |
def validate_token_through_db(token, refresh_token):
    """Check *token* against the backend, falling back to a refresh on rejection.

    Returns (token, refresh_token) unchanged when the backend accepts the
    token, otherwise whatever generate_fresh_token() yields ((None, None)
    when the refresh also fails).
    """
    url = f"{SERVER_URL}/v1/user/op"

    headers = {"Authorization": f"Bearer {token}"}

    # timeout so a stalled backend cannot hang the caller; the response body
    # was previously parsed into an unused local — only the status matters
    response = requests.get(url, headers=headers, timeout=10)
    if response.status_code == 200:
        return token, refresh_token
    else:
        return generate_fresh_token(refresh_token)
77 |
78 |
def validate_token(
    token,
    refresh_token,
    validate_through_db=False,
):
    """Return a usable (token, refresh_token) pair, refreshing when expired.

    Returns (None, None) when no token is supplied or it cannot be renewed.
    The JWT signature is deliberately NOT verified here — only the "exp"
    claim is inspected locally; optionally the backend is consulted too.
    """
    if not token:
        return None, None

    try:
        claims = jwt.decode(token, options={"verify_signature": False})
        expiry = claims.get("exp")

        # tokens without an expiry claim are treated as always valid
        if expiry is None:
            return token, refresh_token

        if expiry <= time.time():
            return generate_fresh_token(refresh_token)

        if validate_through_db:
            return validate_token_through_db(token, refresh_token)
        return token, refresh_token

    except jwt.ExpiredSignatureError:
        print("expired token, trying to refresh...")
        return generate_fresh_token(refresh_token)
    except Exception as e:
        print("error validating the jwt: ", str(e))
        return None, None
111 |
--------------------------------------------------------------------------------
/scripts/entrypoint.sh:
--------------------------------------------------------------------------------
#!/bin/bash

# Base command: launch the Streamlit app on port 5500 with fast reruns
# disabled and run-on-save enabled.
COMMAND="streamlit run app.py --runner.fastReruns false --server.runOnSave true --server.port 5500"
4 |
# compare_versions VER1 VER2
# Echoes "1" when VER1 > VER2, "-1" when VER1 < VER2, "0" when equal.
# Versions are dot-separated numeric strings; missing fields count as 0.
compare_versions() {
    IFS='.' read -r -a ver1 <<< "$1"
    IFS='.' read -r -a ver2 <<< "$2"

    # Pad BOTH arrays to the longer length. The original only padded ver1,
    # so a VER1 with more fields than VER2 hit an empty ${ver2[i]} and made
    # the $((10#...)) arithmetic below fail.
    local len=${#ver1[@]}
    if [ ${#ver2[@]} -gt $len ]; then
        len=${#ver2[@]}
    fi

    local i=0
    while [ $i -lt $len ]; do
        [ -z "${ver1[i]}" ] && ver1[i]=0
        [ -z "${ver2[i]}" ] && ver2[i]=0
        # 10# forces base-10 so leading zeros are not read as octal
        if [ $((10#${ver1[i]})) -gt $((10#${ver2[i]})) ]; then
            echo "1"
            return
        elif [ $((10#${ver1[i]})) -lt $((10#${ver2[i]})) ]; then
            echo "-1"
            return
        fi
        i=$((i + 1))
    done

    echo "0"
}
51 |
# update_app
# Fetches the latest published version number, compares it against the
# local scripts/app_version.txt and, when newer, pulls the repo (and the
# comfy_runner dependency) before recording the new version.
update_app() {
    CURRENT_VERSION=$(curl -s "https://raw.githubusercontent.com/banodoco/Dough/feature/final/scripts/app_version.txt")

    ERR_MSG="Unable to fetch the current version from the remote repository."
    # empty response — checked FIRST; in the original this check sat after
    # the regex check, which an empty string also fails, so it was dead code
    if [ -z "$CURRENT_VERSION" ]; then
        echo "$ERR_MSG"
        return
    fi
    # response is not an X.Y.Z version string (e.g. an HTML 404 page)
    if ! echo "$CURRENT_VERSION" | grep -q '^[0-9]\+\.[0-9]\+\.[0-9]\+$'; then
        echo "$ERR_MSG"
        return
    fi

    CURRENT_DIR=$(pwd)
    LOCAL_VERSION=$(cat "${CURRENT_DIR}/scripts/app_version.txt")
    echo "local version $LOCAL_VERSION"
    echo "current version $CURRENT_VERSION"
    VERSION_DIFF=$(compare_versions "$LOCAL_VERSION" "$CURRENT_VERSION")
    if [ "$VERSION_DIFF" == "-1" ]; then
        echo "A newer version ($CURRENT_VERSION) is available. Updating..."

        git stash
        # Step 1: Pull from the current branch
        git pull origin "$(git rev-parse --abbrev-ref HEAD)"

        # Step 2: ensure the comfy_runner dependency is present and current
        if [ -d "${CURRENT_DIR}/comfy_runner" ]; then
            cd comfy_runner && git pull origin main
            cd ..
        else
            echo "comfy_runner folder not found. Cloning repository."
            REPO_URL="https://github.com/piyushK52/comfy_runner.git"
            git clone "$REPO_URL" "${CURRENT_DIR}/comfy_runner"
        fi

        # record the freshly installed version (path quoted against spaces)
        echo "$CURRENT_VERSION" > "${CURRENT_DIR}/scripts/app_version.txt"
    else
        echo "You have the latest version ($LOCAL_VERSION)."
    fi
}
97 |
# Argument loop: only --update is recognised (runs the self-updater);
# any other flag aborts with an error. The app launches afterwards.
while [ "$#" -gt 0 ]; do
    case $1 in
    --update)
        update_app
        ;;
    *)
        echo "Invalid option: $1" >&2
        exit 1
        ;;
    esac
    shift
done

# Execute the base command
eval $COMMAND
--------------------------------------------------------------------------------
/scripts/entrypoint.bat:
--------------------------------------------------------------------------------
@echo off

rem Base command: launch the Streamlit app on port 5500.
set COMMAND=streamlit run app.py --runner.fastReruns false --server.runOnSave true --server.port 5500
goto :loop

rem compare_versions VER1 VER2
rem Splits both dot-separated versions into pseudo-array variables and,
rem when VER2 is newer, pulls the repo and refreshes comfy_runner before
rem recording the new version number.
rem NOTE - review: the compare loop below iterates indices 1 to 3, but the
rem arrays are filled starting at index 0, so the MAJOR version field is
rem never compared and index 3 is always empty. Confirm and fix separately.
rem NOTE - review: batch has no command substitution, so the quoted
rem "git rev-parse" expression passed to git pull below is never executed
rem as a command - the branch argument does not resolve. Verify behaviour.
:compare_versions
setlocal EnableDelayedExpansion
set "ver1=%~1"
set "ver2=%~2"

rem turn dots into spaces so the for loops can split on fields
set "ver1=!ver1:.= !"
set "ver2=!ver2:.= !"
set "arr1="
set "arr2="

set "index=0"
for %%c in (%ver1%) do (
    set "arr1[!index!]=%%c"
    set /a "index+=1"
)

set "index=0"
for %%c in (%ver2%) do (
    set "arr2[!index!]=%%c"
    set /a "index+=1"
)

set CURRENT_DIR=%cd%

rem field-by-field compare; missing fields default to 0
for /L %%i in (1,1,3) do (
    set "v1=!arr1[%%i]!"
    set "v2=!arr2[%%i]!"
    if "!v1!" equ "" set "v1=0"
    if "!v2!" equ "" set "v2=0"
    if !v1! gtr !v2! (
        echo You have the latest version
        endlocal & exit /b
    ) else if !v1! lss !v2! (
        echo A newer version is available. Updating...
        git stash
        rem Step 1: Pull from the current branch
        git pull origin "!git rev-parse --abbrev-ref HEAD!"

        rem Step 2: Check if the comfy_runner folder is present
        if exist "!CURRENT_DIR!\comfy_runner" (
            rem Step 3a: If comfy_runner is present, pull from the feature/package branch
            rem echo comfy_runner folder found. Pulling from feature/package branch.
            cd comfy_runner
            rem git pull origin feature/package
            cd "!CURRENT_DIR!"
        ) else (
            rem Step 3b: If comfy_runner is not present, clone the repository
            echo comfy_runner folder not found. Cloning repository.
            set REPO_URL=https://github.com/piyushK52/comfy_runner.git
            git clone "!REPO_URL!" "!CURRENT_DIR!\comfy_runner"
        )

        echo !CURRENT_VERSION! > "!CURRENT_DIR!\scripts\app_version.txt"
        endlocal & exit /b
    )
)
echo You have the latest version
endlocal & exit /b

rem update_app
rem Downloads the published version number, validates it, and hands both
rem versions to compare_versions above.
rem NOTE - review: CURRENT_VERSION is read inside compare_versions via
rem delayed expansion; this works only because compare_versions is called
rem from within this routine's setlocal scope.
:update_app
setlocal EnableDelayedExpansion
set CURRENT_VERSION=
for /f "delims=" %%i in ('curl -s "https://raw.githubusercontent.com/banodoco/Dough/feature/final/scripts/app_version.txt"') do set CURRENT_VERSION=%%i

rem reject responses containing letters, e.g. an HTML 404 page
echo %CURRENT_VERSION% | findstr /r "^[^a-zA-Z]*$" >nul
if %errorlevel% neq 0 (
    echo Invalid version format: %CURRENT_VERSION%. Expected format: X.X.X (e.g., 1.2.3^)
    endlocal & exit /b
)

if not "%CURRENT_VERSION%" == "" (
    echo %CURRENT_VERSION%
) else (
    set ERR_MSG=Unable to fetch the current version from the remote repository.
    echo %ERR_MSG%
    exit /b
)

set CURRENT_DIR=%cd%
set LOCAL_VERSION=
for /f "delims=" %%i in ('type "%CURRENT_DIR%\scripts\app_version.txt"') do set LOCAL_VERSION=%%i

echo Local version %LOCAL_VERSION%
echo Current version %CURRENT_VERSION%

call :compare_versions "%LOCAL_VERSION%" "%CURRENT_VERSION%"
endlocal & exit /b

rem Argument loop: --update runs the self-updater then launches the app;
rem any other argument aborts. With no arguments the app just launches.
:loop
if "%~1" == "--update" (
    call :update_app
    %COMMAND%
    exit /b
)
if not "%~1" == "" (
    echo Invalid option: %1 >&2
    exit /b
)
shift
if not "%~1" == "" goto :loop

%COMMAND%
--------------------------------------------------------------------------------
/utils/cache/cache.py:
--------------------------------------------------------------------------------
1 | import streamlit as st
2 |
3 | from utils.enum import ExtendedEnum
4 |
5 |
class CacheKey(ExtendedEnum):
    """st.session_state keys under which StCache stores each cached object list."""

    TIMING_DETAILS = "timing_details"
    APP_SETTING = "app_setting"
    APP_SECRET = "app_secret"
    PROJECT_SETTING = "project_setting"
    AI_MODEL = "ai_model"
    LOGGED_USER = "logged_user"
    FILE = "file"
    SHOT = "shot"
    # temp items (only cached for speed boost)
    LOG = "log"
    LOG_PAGES = "log_pages"
    PROJECT = "project"
    USER = "user"
20 |
21 |
class StCache:
    """Thin CRUD layer over st.session_state, keyed by data type.

    Objects are stored as lists under their data-type key; each cached
    object is either a dict with a "uuid" key or an instance exposing a
    ``.uuid`` attribute.
    """

    @staticmethod
    def _extract_uuid(obj) -> str:
        # Objects arrive either as raw dicts or as DTO instances; always
        # compare uuids as strings (the original get() skipped str() on the
        # dict branch, unlike update()/delete()).
        return str(obj["uuid"]) if isinstance(obj, dict) else str(obj.uuid)

    @staticmethod
    def get(uuid, data_type):
        """Return the cached object with this uuid, or None when absent."""
        uuid = str(uuid)
        if data_type in st.session_state:
            for ele in st.session_state[data_type]:
                if StCache._extract_uuid(ele) == uuid:
                    return ele

        return None

    @staticmethod
    def update(data, data_type) -> bool:
        """Replace every cached object sharing data's uuid; True if any matched."""
        object_found = False
        uuid = StCache._extract_uuid(data)

        if data_type in st.session_state:
            object_list = st.session_state[data_type]
            for idx, ele in enumerate(object_list):
                if StCache._extract_uuid(ele) == uuid:
                    object_list[idx] = data
                    object_found = True

            st.session_state[data_type] = object_list

        return object_found

    @staticmethod
    def add(data, data_type) -> bool:
        """Insert data, or overwrite an existing entry with the same uuid."""
        uuid = StCache._extract_uuid(data)
        if StCache.get(uuid, data_type):
            StCache.update(data, data_type)
        else:
            object_list = st.session_state[data_type] if data_type in st.session_state else []
            object_list.append(data)
            st.session_state[data_type] = object_list

        # the original fell off the end and returned None despite the
        # declared bool return type; report success explicitly
        return True

    @staticmethod
    def delete(uuid, data_type) -> bool:
        """Remove the first cached object with this uuid; True if found."""
        object_found = False
        uuid = str(uuid)
        if data_type in st.session_state:
            object_list = st.session_state[data_type]
            for ele in object_list:
                if StCache._extract_uuid(ele) == uuid:
                    object_list.remove(ele)
                    object_found = True
                    break

            st.session_state[data_type] = object_list

        return object_found

    @staticmethod
    def delete_all(data_type) -> bool:
        """Drop the entire list for this data type; True if it existed."""
        if data_type in st.session_state:
            del st.session_state[data_type]
            return True

        return False

    @staticmethod
    def add_all(data_list, data_type) -> bool:
        """Upsert every object in data_list by uuid."""
        for data in data_list:
            StCache.add(data, data_type)

        return True

    @staticmethod
    def get_all(data_type):
        """Return the full cached list for this data type ([] when empty)."""
        if data_type in st.session_state:
            return st.session_state[data_type]

        return []

    # deletes all cached objects of every data type
    @staticmethod
    def clear_entire_cache() -> bool:
        for c in CacheKey.value_list():
            StCache.delete_all(c)

        return True
--------------------------------------------------------------------------------
/ui_components/components/project_settings_page.py:
--------------------------------------------------------------------------------
1 | import shutil
2 | import streamlit as st
3 | import os
4 | import time
5 | from ui_components.widgets.attach_audio_element import attach_audio_element
6 | from PIL import Image
7 |
8 | from utils.common_utils import get_current_user_uuid
9 | from utils.data_repo.data_repo import DataRepo
10 | from utils.state_refresh import refresh_app
11 |
12 |
def project_settings_page(project_uuid):
    """Render the Project Settings tab: rename, frame size, and deletion.

    Args:
        project_uuid: uuid of the project whose settings are shown.
    """
    data_repo = DataRepo()
    st.markdown("#### Project Settings")
    st.markdown("***")

    with st.expander("📋 Project name", expanded=True):
        project = data_repo.get_project_from_uuid(project_uuid)
        new_name = st.text_input("Enter new name:", project.name)
        if st.button("Save", key="project_name"):
            data_repo.update_project(uuid=project_uuid, name=new_name)
            refresh_app()
    project_settings = data_repo.get_project_setting(project_uuid)

    # preset frame sizes offered in the radio selector below
    frame_sizes = ["512x512", "768x512", "512x768", "512x896", "896x512", "512x1024", "1024x512"]
    current_size = f"{project_settings.width}x{project_settings.height}"
    current_index = frame_sizes.index(current_size) if current_size in frame_sizes else 0

    with st.expander("🖼️ Frame Size", expanded=True):

        v1, v2, v3 = st.columns([4, 4, 2])
        with v1:
            st.write("Current Size = ", project_settings.width, "x", project_settings.height)

        custom_frame_size = st.checkbox("Enter custom frame size", value=False)
        err = False
        if not custom_frame_size:
            frame_size = st.radio(
                "Select frame size:",
                options=frame_sizes,
                index=current_index,
                key="frame_size",
                horizontal=True,
            )
            width, height = map(int, frame_size.split("x"))
        else:
            st.info(
                "This is an experimental feature. There might be some issues - particularly with image generation."
            )
            # free-form text inputs; parsed to int just below
            width = st.text_input("Width", value=512)
            height = st.text_input("Height", value=512)
            try:
                width, height = int(width), int(height)
                err = False
            except Exception as e:
                st.error("Please input integer values")
                err = True

        if not err:
            # small solid-color swatch so the user can eyeball the aspect ratio
            img = Image.new("RGB", (width, height), color=(73, 109, 137))
            st.image(img, width=70)

            if st.button("Save"):
                # NOTE(review): success toast fires before the settings are
                # persisted — confirm the ordering is intentional
                st.success("Frame size updated successfully")
                time.sleep(0.3)
                data_repo.update_project_setting(project_uuid, width=width)
                data_repo.update_project_setting(project_uuid, height=height)
                refresh_app()

    st.write("")
    st.write("")
    st.write("")
    delete_proj = st.checkbox("I confirm to delete this project entirely", value=False)
    if st.button("Delete Project", disabled=(not delete_proj)):
        project_list = data_repo.get_all_project_list(user_id=get_current_user_uuid())
        if project_list and len(project_list) > 1:
            # soft delete: the project is disabled, not removed from the db
            data_repo.update_project(uuid=project_uuid, is_disabled=True)
            st.success("Project deleted successfully")
            st.session_state["index_of_project_name"] = 0
        else:
            st.error("You can't delete the only available project")

        time.sleep(0.7)
        refresh_app()
86 |
--------------------------------------------------------------------------------
/utils/ml_processor/comfy_workflows/ipadapter_composition_workflow_api.json:
--------------------------------------------------------------------------------
1 | {
2 | "1": {
3 | "inputs": {
4 | "ckpt_name": "Realistic_Vision_V5.1.safetensors"
5 | },
6 | "class_type": "CheckpointLoaderSimple",
7 | "_meta": {
8 | "title": "Load Checkpoint"
9 | }
10 | },
11 | "2": {
12 | "inputs": {
13 | "vae_name": "vae-ft-mse-840000-ema-pruned.safetensors"
14 | },
15 | "class_type": "VAELoader",
16 | "_meta": {
17 | "title": "Load VAE"
18 | }
19 | },
20 | "3": {
21 | "inputs": {
22 | "ipadapter_file": "ip_plus_composition_sd15.safetensors"
23 | },
24 | "class_type": "IPAdapterModelLoader",
25 | "_meta": {
26 | "title": "Load IPAdapter Model"
27 | }
28 | },
29 | "4": {
30 | "inputs": {
31 | "clip_name": "SD1.5/pytorch_model.bin"
32 | },
33 | "class_type": "CLIPVisionLoader",
34 | "_meta": {
35 | "title": "Load CLIP Vision"
36 | }
37 | },
38 | "6": {
39 | "inputs": {
40 | "image": "Hulk_Hogan.jpg",
41 | "upload": "image"
42 | },
43 | "class_type": "LoadImage",
44 | "_meta": {
45 | "title": "Load Image"
46 | }
47 | },
48 | "7": {
49 | "inputs": {
50 | "text": "hulk hogan",
51 | "clip": [
52 | "1",
53 | 1
54 | ]
55 | },
56 | "class_type": "CLIPTextEncode",
57 | "_meta": {
58 | "title": "CLIP Text Encode (Prompt)"
59 | }
60 | },
61 | "8": {
62 | "inputs": {
63 | "text": "blurry, photo, malformed",
64 | "clip": [
65 | "1",
66 | 1
67 | ]
68 | },
69 | "class_type": "CLIPTextEncode",
70 | "_meta": {
71 | "title": "CLIP Text Encode (Prompt)"
72 | }
73 | },
74 | "9": {
75 | "inputs": {
76 | "seed": 16,
77 | "steps": 30,
78 | "cfg": 5,
79 | "sampler_name": "dpmpp_2m_sde",
80 | "scheduler": "exponential",
81 | "denoise": 1,
82 | "model": [
83 | "28",
84 | 0
85 | ],
86 | "positive": [
87 | "7",
88 | 0
89 | ],
90 | "negative": [
91 | "8",
92 | 0
93 | ],
94 | "latent_image": [
95 | "10",
96 | 0
97 | ]
98 | },
99 | "class_type": "KSampler",
100 | "_meta": {
101 | "title": "KSampler"
102 | }
103 | },
104 | "10": {
105 | "inputs": {
106 | "width": 512,
107 | "height": 512,
108 | "batch_size": 1
109 | },
110 | "class_type": "EmptyLatentImage",
111 | "_meta": {
112 | "title": "Empty Latent Image"
113 | }
114 | },
115 | "11": {
116 | "inputs": {
117 | "samples": [
118 | "9",
119 | 0
120 | ],
121 | "vae": [
122 | "2",
123 | 0
124 | ]
125 | },
126 | "class_type": "VAEDecode",
127 | "_meta": {
128 | "title": "VAE Decode"
129 | }
130 | },
131 | "27": {
132 | "inputs": {
133 | "filename_prefix": "ComfyUI",
134 | "images": [
135 | "11",
136 | 0
137 | ]
138 | },
139 | "class_type": "SaveImage",
140 | "_meta": {
141 | "title": "Save Image"
142 | }
143 | },
144 | "28": {
145 | "inputs": {
146 | "weight": 1,
147 | "weight_type": "linear",
148 | "combine_embeds": "concat",
149 | "embeds_scaling": "V only",
150 | "start_at": 0,
151 | "end_at": 1,
152 | "ipadapter": [
153 | "3",
154 | 0
155 | ],
156 | "clip_vision": [
157 | "4",
158 | 0
159 | ],
160 | "image": [
161 | "6",
162 | 0
163 | ],
164 | "model": [
165 | "1",
166 | 0
167 | ]
168 | },
169 | "class_type": "IPAdapterAdvanced",
170 | "_meta": {
171 | "title": "IPAdapter Advanced"
172 | }
173 | }
174 | }
--------------------------------------------------------------------------------
/scripts/config.toml:
--------------------------------------------------------------------------------
1 | [file_hash]
2 | "sd_xl_base_1.0.safetensors" = { "location" = "models/checkpoints/", "hash" = ["cf9d29192ef7433096a69fe5e32efba1"]}
3 | "sd_xl_refiner_1.0.safetensors" = { "location" = "models/checkpoints/", "hash" = ["c0df90f318abf30fcd2cd233ea638251"]}
4 | "Realistic_Vision_V5.1.safetensors" = { "location" = "models/checkpoints/", "hash" = ["03c000bbbbfede673b334897648431ba"]}
5 | "dreamshaper_8.safetensors" = { "location" = "models/checkpoints/", "hash" = ["1bfe9fce2c2c1498bcb6fdac2e44aab5"]}
6 | "sd_xl_refiner_1.0_0.9vae.safetensors" = { "location" = "models/checkpoints/", "hash" = ["74d31c96097471d5d3bd93a915ad0b30"]}
7 | "pytorch_model.bin" = { "location" = "", "hash" = ["0a0900180245762f122a681839391a9b", "ac277e29ea6d70799072d1ed52ac0f79"]}
8 | "v3_sd15_sparsectrl_rgb.ckpt" = { "location" = "models/controlnet/SD1.5/animatediff/", "hash" = ["b7d6b72ee2ba2866d167277598f27340"]}
9 | "inpainting_diffusion_pytorch_model.fp16.safetensors" = { "location" = "models/unet/inpainting_diffusion_pytorch_model.fp16.safetensors/", "hash" = ["9c688d78a75c1d3a0b33afffdd7e72f2"]}
10 | "vae-ft-mse-840000-ema-pruned.safetensors" = { "location" = "models/vae/", "hash" = ["418949762c3f321f2927e590e255f63c"]}
11 | "ip_plus_composition_sd15.safetensors" = { "location" = "models/ipadapter/", "hash" = ["029d345e257b72bc7be2fdf0a2295eba"]}
12 | "ip-adapter_sdxl.safetensors" = { "location" = "models/ipadapter/", "hash" = ["67b7f27bbf03d6b6921e520779abc212"]}
13 | "ip-adapter-plus_sd15.bin" = { "location" = "models/ipadapter/", "hash" = ["a50be6ac20883c4969c7bb60d5f2a46b"]}
14 | "v3_sd15_mm.ckpt" = { "location" = "models/animatediff_models/", "hash" = ["ac855686ee49fb5436c881c50a9b59ca"]}
15 | "AnimateLCM_sd15_t2v.ckpt" = { "location" = "models/animatediff_models/", "hash" = ["b5426e509e70b1b1b13ca03f877bac2a"]}
16 | "AnimateLCM_sd15_t2v_lora.safetensors" = { "location" = "models/loras/", "hash" = ["fb6fa0bf09bcfbe3d002ce0bfdfa7770"]}
17 |
18 |
19 | [node_version]
20 | "ComfyUI-Manager" = { "url" = "https://github.com/ltdrdata/ComfyUI-Manager", "commit_hash" = "7e777c5460754f4ef7a6f1ad9ef9356b89f66339"}
21 | "ComfyUI_IPAdapter_plus" = { "url" = "https://github.com/cubiq/ComfyUI_IPAdapter_plus", "commit_hash" = "b188a6cb39b512a9c6da7235b880af42c78ccd0d"}
22 | "comfyui-various" = { "url" = "https://github.com/jamesWalker55/comfyui-various", "commit_hash" = "cc66b62c0861314a4952eb96a6ae330f180bf6a1"}
23 | "comfy_PoP" = { "url" = "https://github.com/picturesonpictures/comfy_PoP", "commit_hash" = "db66c9777c6ab45f9c7056f1d8f365b0d4f2b339"}
24 | "ComfyUI-Frame-Interpolation" = { "url" = "https://github.com/Fannovel16/ComfyUI-Frame-Interpolation", "commit_hash" = "483dfe64465369e077d351ed2f1acbf7dc046864"}
25 | "efficiency-nodes-comfyui" = { "url" = "https://github.com/jags111/efficiency-nodes-comfyui", "commit_hash" = "3ead4afd120833f3bffdefeca0d6545df8051798"}
26 | "ComfyUI_FizzNodes" = { "url" = "https://github.com/FizzleDorf/ComfyUI_FizzNodes", "commit_hash" = "974b41cfdfde4f84d97712234d6e502f7da831fa"}
27 | "ComfyUI-Advanced-ControlNet" = { "url" = "https://github.com/Kosinkadink/ComfyUI-Advanced-ControlNet", "commit_hash" = "74d0c56ab3ba69663281390cc1b2072107939f96"}
28 | "ComfyUI-AnimateDiff-Evolved" = { "url" = "https://github.com/Kosinkadink/ComfyUI-AnimateDiff-Evolved", "commit_hash" = "83fe8d40638e3491a31fa2865107bdc14a308a35"}
29 | "ComfyUI-VideoHelperSuite" = { "url" = "https://github.com/Kosinkadink/ComfyUI-VideoHelperSuite", "commit_hash" = "0376e577442c236fbba6ef410a4e5ec64aed5017"}
30 | "ComfyUI_essentials" = { "url" = "https://github.com/cubiq/ComfyUI_essentials", "commit_hash" = "99aad72c84e1dac2f924941ecb12c93007512a8c"}
31 | "steerable-motion" = { "url" = "https://github.com/banodoco/steerable-motion", "commit_hash" = "b95e9eed09741bd3f0fc3933ab45d09f6f055aae"}
32 | "ComfyUI_tinyterraNodes" = { "url" = "https://github.com/TinyTerra/ComfyUI_tinyterraNodes", "commit_hash" = "52711d57e97c0255e6c9d627cfc0dfec4ebacf22"}
33 |
34 |
35 | [comfy]
36 | commit_hash = "56e8f5e4fd0a048811095f44d2147bce48b02457"
37 |
38 |
39 | [strict_pkg_versions]
40 | "numpy" = "1.24.4"
--------------------------------------------------------------------------------
/ui_components/components/query_logger_page.py:
--------------------------------------------------------------------------------
1 | import json
2 | import time
3 | import requests
4 | import streamlit as st
5 | from ui_components.constants import DefaultTimingStyleParams
6 | from utils.common_utils import get_current_user
7 | from shared.constants import ServerType, ConfigManager, GPU_INFERENCE_ENABLED_KEY
8 | from utils.data_repo.api_repo import APIRepo
9 | from utils.data_repo.data_repo import DataRepo
10 | from utils.state_refresh import refresh_app
11 |
12 |
def query_logger_page():
    """Render the inference-log page.

    Shows a credit balance / top-up section when inference is hosted (no
    local GPU), then a paginated table of past inference runs.
    """
    st.markdown("##### Inference log")

    data_repo = DataRepo()
    api_repo = APIRepo()
    config_manager = ConfigManager()
    gpu_enabled = config_manager.get(GPU_INFERENCE_ENABLED_KEY, False)

    # credit balance + purchase UI only applies when inference is hosted
    if not gpu_enabled:
        credits_remaining = api_repo.get_user_credits()

        c01, c02, _ = st.columns([1, 1, 2])

        credits_remaining = round(credits_remaining, 3)
        with c01:
            st.write(f"### Credit Balance: {credits_remaining}")

        with c02:
            if st.button("Refresh Credits"):
                # drop the cached balance so it is re-fetched on rerun
                st.session_state["user_credit_data"] = None
                refresh_app()

        c1, c2, _ = st.columns([1, 1, 3])
        with c1:
            credits_to_buy = st.number_input(
                label="Credits to Buy (10 credits = $1)",
                key="credit_btn",
                min_value=50,
                step=20,
            )

        with c2:
            st.write("")
            st.write("")
            if st.button("Generate payment link"):
                # credits // 10 — presumably a dollar amount given the 10:1
                # label above; confirm against the API
                payment_link = api_repo.generate_payment_link(int(credits_to_buy // 10))
                if payment_link:
                    st.write("Please click on the link below to make the payment")
                    st.write(payment_link)
                else:
                    # message typo fixed ("occured"/"pls")
                    st.write("error occurred during payment link generation, please try again")

    b1, b2 = st.columns([1, 0.2])

    total_log_table_pages = (
        st.session_state["total_log_table_pages"]
        if "total_log_table_pages" in st.session_state
        else DefaultTimingStyleParams.total_log_table_pages
    )
    list_of_pages = [i for i in range(1, total_log_table_pages + 1)]
    page_number = b1.radio(
        "Select page:", options=list_of_pages, key="inference_log_page_number", index=0, horizontal=True
    )
    # page_number = b1.number_input('Page number', min_value=1, max_value=total_log_table_pages, value=1, step=1)
    inference_log_list, total_page_count = data_repo.get_all_inference_log_list(
        page=page_number, data_per_page=100
    )

    # keep the cached page count in sync with what the backend reports
    if total_log_table_pages != total_page_count:
        st.session_state["total_log_table_pages"] = total_page_count
        refresh_app()

    data = {
        "Project": [],
        "Prompt": [],
        "Model": [],
        "Inference time (sec)": [],
        "Credits": [],
        "Status": [],
    }

    # if SERVER != ServerType.DEVELOPMENT.value:
    #     data[] = []

    if inference_log_list and len(inference_log_list):
        for log in inference_log_list:
            data["Project"].append(log.project.name)
            prompt = json.loads(log.input_params).get("prompt", "") if log.input_params else ""
            data["Prompt"].append(prompt)
            model_name = log.model_name
            data["Model"].append(model_name)
            data["Inference time (sec)"].append(round(log.total_inference_time, 3))
            data["Credits"].append(round(log.credits_used, 3))
            data["Status"].append(log.status)

        st.table(data=data)
        st.markdown("***")
    else:
        st.info("No logs present")
102 |
--------------------------------------------------------------------------------
/utils/ml_processor/comfy_workflows/sd3_workflow_api.json:
--------------------------------------------------------------------------------
1 | {
2 | "6": {
3 | "inputs": {
4 | "text": "a ninja in the night, red atmosphere with a giant red moon behind him",
5 | "clip": [
6 | "252",
7 | 1
8 | ]
9 | },
10 | "class_type": "CLIPTextEncode",
11 | "_meta": {
12 | "title": "CLIP Text Encode (Prompt)"
13 | }
14 | },
15 | "13": {
16 | "inputs": {
17 | "shift": 3,
18 | "model": [
19 | "252",
20 | 0
21 | ]
22 | },
23 | "class_type": "ModelSamplingSD3",
24 | "_meta": {
25 | "title": "ModelSamplingSD3"
26 | }
27 | },
28 | "67": {
29 | "inputs": {
30 | "conditioning": [
31 | "71",
32 | 0
33 | ]
34 | },
35 | "class_type": "ConditioningZeroOut",
36 | "_meta": {
37 | "title": "ConditioningZeroOut"
38 | }
39 | },
40 | "68": {
41 | "inputs": {
42 | "start": 0.1,
43 | "end": 1,
44 | "conditioning": [
45 | "67",
46 | 0
47 | ]
48 | },
49 | "class_type": "ConditioningSetTimestepRange",
50 | "_meta": {
51 | "title": "ConditioningSetTimestepRange"
52 | }
53 | },
54 | "69": {
55 | "inputs": {
56 | "conditioning_1": [
57 | "68",
58 | 0
59 | ],
60 | "conditioning_2": [
61 | "70",
62 | 0
63 | ]
64 | },
65 | "class_type": "ConditioningCombine",
66 | "_meta": {
67 | "title": "Conditioning (Combine)"
68 | }
69 | },
70 | "70": {
71 | "inputs": {
72 | "start": 0,
73 | "end": 0.1,
74 | "conditioning": [
75 | "71",
76 | 0
77 | ]
78 | },
79 | "class_type": "ConditioningSetTimestepRange",
80 | "_meta": {
81 | "title": "ConditioningSetTimestepRange"
82 | }
83 | },
84 | "71": {
85 | "inputs": {
86 | "text": "bad quality, poor quality, doll, disfigured, jpg, toy, bad anatomy, missing limbs, missing fingers, 3d, cgi",
87 | "clip": [
88 | "252",
89 | 1
90 | ]
91 | },
92 | "class_type": "CLIPTextEncode",
93 | "_meta": {
94 | "title": "CLIP Text Encode (Negative Prompt)"
95 | }
96 | },
97 | "135": {
98 | "inputs": {
99 | "width": 1024,
100 | "height": 768,
101 | "batch_size": 1
102 | },
103 | "class_type": "EmptySD3LatentImage",
104 | "_meta": {
105 | "title": "EmptySD3LatentImage"
106 | }
107 | },
108 | "231": {
109 | "inputs": {
110 | "samples": [
111 | "271",
112 | 0
113 | ],
114 | "vae": [
115 | "252",
116 | 2
117 | ]
118 | },
119 | "class_type": "VAEDecode",
120 | "_meta": {
121 | "title": "VAE Decode"
122 | }
123 | },
124 | "233": {
125 | "inputs": {
126 | "filename_prefix": "SD3",
127 | "images": [
128 | "231",
129 | 0
130 | ]
131 | },
132 | "class_type": "SaveImage",
133 | "_meta": {
134 | "title": "Save Image"
135 | }
136 | },
137 | "252": {
138 | "inputs": {
139 | "ckpt_name": "sd3_medium_incl_clips_t5xxlfp16.safetensors"
140 | },
141 | "class_type": "CheckpointLoaderSimple",
142 | "_meta": {
143 | "title": "Load Checkpoint"
144 | }
145 | },
146 | "271": {
147 | "inputs": {
148 | "seed": 945512652412924,
149 | "steps": 28,
150 | "cfg": 4.5,
151 | "sampler_name": "dpmpp_2m",
152 | "scheduler": "sgm_uniform",
153 | "denoise": 1,
154 | "model": [
155 | "13",
156 | 0
157 | ],
158 | "positive": [
159 | "6",
160 | 0
161 | ],
162 | "negative": [
163 | "69",
164 | 0
165 | ],
166 | "latent_image": [
167 | "135",
168 | 0
169 | ]
170 | },
171 | "class_type": "KSampler",
172 | "_meta": {
173 | "title": "KSampler"
174 | }
175 | }
176 | }
--------------------------------------------------------------------------------
/utils/ml_processor/comfy_workflows/ipadapter_plus_api.json:
--------------------------------------------------------------------------------
1 | {
2 | "3": {
3 | "inputs": {
4 | "seed": 641608455784125,
5 | "steps": 24,
6 | "cfg": 9.25,
7 | "sampler_name": "ddim",
8 | "scheduler": "normal",
9 | "denoise": 1,
10 | "model": [
11 | "27",
12 | 0
13 | ],
14 | "positive": [
15 | "6",
16 | 0
17 | ],
18 | "negative": [
19 | "7",
20 | 0
21 | ],
22 | "latent_image": [
23 | "5",
24 | 0
25 | ]
26 | },
27 | "class_type": "KSampler",
28 | "_meta": {
29 | "title": "KSampler"
30 | }
31 | },
32 | "4": {
33 | "inputs": {
34 | "ckpt_name": "sd_xl_base_1.0.safetensors"
35 | },
36 | "class_type": "CheckpointLoaderSimple",
37 | "_meta": {
38 | "title": "Load Checkpoint"
39 | }
40 | },
41 | "5": {
42 | "inputs": {
43 | "width": 1024,
44 | "height": 1024,
45 | "batch_size": 1
46 | },
47 | "class_type": "EmptyLatentImage",
48 | "_meta": {
49 | "title": "Empty Latent Image"
50 | }
51 | },
52 | "6": {
53 | "inputs": {
54 | "text": "",
55 | "clip": [
56 | "4",
57 | 1
58 | ]
59 | },
60 | "class_type": "CLIPTextEncode",
61 | "_meta": {
62 | "title": "CLIP Text Encode (Prompt)"
63 | }
64 | },
65 | "7": {
66 | "inputs": {
67 | "text": "",
68 | "clip": [
69 | "4",
70 | 1
71 | ]
72 | },
73 | "class_type": "CLIPTextEncode",
74 | "_meta": {
75 | "title": "CLIP Text Encode (Prompt)"
76 | }
77 | },
78 | "8": {
79 | "inputs": {
80 | "samples": [
81 | "3",
82 | 0
83 | ],
84 | "vae": [
85 | "4",
86 | 2
87 | ]
88 | },
89 | "class_type": "VAEDecode",
90 | "_meta": {
91 | "title": "VAE Decode"
92 | }
93 | },
94 | "23": {
95 | "inputs": {
96 | "clip_name": "SDXL/pytorch_model.bin"
97 | },
98 | "class_type": "CLIPVisionLoader",
99 | "_meta": {
100 | "title": "Load CLIP Vision"
101 | }
102 | },
103 | "26": {
104 | "inputs": {
105 | "ipadapter_file": "ip-adapter_sdxl.safetensors"
106 | },
107 | "class_type": "IPAdapterModelLoader",
108 | "_meta": {
109 | "title": "Load IPAdapter Model"
110 | }
111 | },
112 | "27": {
113 | "inputs": {
114 | "weight": 1,
115 | "weight_type": "linear",
116 | "combine_embeds": "concat",
117 | "embeds_scaling": "V only",
118 | "start_at": 0,
119 | "end_at": 1,
120 | "ipadapter": [
121 | "26",
122 | 0
123 | ],
124 | "clip_vision": [
125 | "23",
126 | 0
127 | ],
128 | "image": [
129 | "39",
130 | 0
131 | ],
132 | "model": [
133 | "4",
134 | 0
135 | ]
136 | },
137 | "class_type": "IPAdapterAdvanced",
138 | "_meta": {
139 | "title": "IPAdapter Advanced"
140 | }
141 | },
142 | "28": {
143 | "inputs": {
144 | "image": "714d97a3fe2dcf645f1b500d523d4d3c848acc65bfda3602c56305fc.jpg",
145 | "upload": "image"
146 | },
147 | "class_type": "LoadImage",
148 | "_meta": {
149 | "title": "Load Image"
150 | }
151 | },
152 | "29": {
153 | "inputs": {
154 | "filename_prefix": "ComfyUI",
155 | "images": [
156 | "8",
157 | 0
158 | ]
159 | },
160 | "class_type": "SaveImage",
161 | "_meta": {
162 | "title": "Save Image"
163 | }
164 | },
165 | "39": {
166 | "inputs": {
167 | "interpolation": "LANCZOS",
168 | "crop_position": "top",
169 | "sharpening": 0,
170 | "image": [
171 | "28",
172 | 0
173 | ]
174 | },
175 | "class_type": "PrepImageForClipVision",
176 | "_meta": {
177 | "title": "Prepare Image For Clip Vision"
178 | }
179 | }
180 | }
--------------------------------------------------------------------------------
/LICENSE.txt:
--------------------------------------------------------------------------------
1 | Open Source Native License (OSNL)
2 | Version 0.1 - March 1, 2024
3 |
4 | Preamble
5 |
6 | The Open Source Native License (OSNL) is designed to ensure that software remains free and open, fostering innovation and knowledge sharing within the community. It grants individuals, researchers, and commercial entities who open source their primary business assets, the freedom to use the software in any manner they choose.
7 |
8 | This distinctive approach aims to balance the benefits of open-source development with the realities of commercial enterprise. It ensures that software remains a shared, community-driven resource while enabling businesses to thrive in an open-source ecosystem. Additional licenses are available for non-open source commercial entities.
9 |
10 | 1. Definitions
11 |
12 | - "This License" refers to Version 1.0 of the Open Source Native License.
13 | - "The Program" refers to the software distributed under this License.
14 | - "You" refers to the individual or entity utilizing or contributing to the Program.
15 | - "Primary Business Assets" are the core resources, capabilities, and technology that constitute the main value proposition and operational basis of your business.
16 |
17 | 2. Grant of License
18 |
19 | Subject to the terms and conditions of this License, you are hereby granted a free, perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable license to use, reproduce, modify, distribute, and sublicense the Program, provided you comply with the following condition:
20 |
21 | - Individual or researcher: You are granted the rights to use, modify, distribute, and contribute to the Program for any purpose, including educational, research, and personal projects, without the necessity to make your personal projects open source, provided these activities do not constitute a commercial enterprise. For any use that transitions to commercial purposes, the conditions applicable to commercial entities as outlined in this License will then apply.
22 | - Commercial Entity who meets open source condition: Your primary business assets, including all core technologies, software, and platforms, must be available under an OSI-approved open source license or OSNL. This condition does not apply to ancillary or peripheral services not constituting primary business assets.
23 |
24 | 2.1 Commercial Use by Non-Open Source Businesses
25 |
26 | Non-open source businesses that wish to utilize the Program or its derivatives as a component of their products or services are required to obtain an additional license. These entities must proactively contact Banodoco to request such a license. Banodoco reserves the right, at its own discretion, to grant or deny this additional license. Until an additional license is granted by Banodoco, non-open source businesses are not authorized to exercise any rights provided under this License regarding the use of the Program or its derivatives.
27 |
28 | 3. Redistribution
29 |
30 | You may reproduce and distribute copies of the Program or derivative works thereof in any medium, with or without modifications, provided that you meet the following conditions:
31 |
32 | - You must give any recipients of the Program a copy of this License.
33 | - You must ensure that any modified files carry prominent notices stating that you changed the files.
34 | - You must disclose the source of the Program, and if you distribute any portion of it in a compiled or object code form, you must also provide the full source code under this License.
35 | - Any distribution of the Program or derivative works must comply with the Primary Business Open Source Condition.
36 |
37 | 4. Disclaimer of Warranty
38 |
39 | THE PROGRAM IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES, OR OTHER LIABILITY ARISING FROM THE USE OF THE PROGRAM.
40 |
41 | 5. General
42 |
43 | This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Program.
--------------------------------------------------------------------------------
/ui_components/components/new_project_page.py:
--------------------------------------------------------------------------------
1 | import streamlit as st
2 | from banodoco_settings import create_new_project
3 | from ui_components.methods.common_methods import (
4 | save_audio_file,
5 | create_frame_inside_shot,
6 | save_and_promote_image,
7 | )
8 | from utils.state_refresh import refresh_app
9 | from utils.common_utils import get_current_user_uuid, reset_project_state
10 | from utils.data_repo.data_repo import DataRepo
11 | import time
12 | from PIL import Image
13 |
14 | import utils.local_storage.local_storage as local_storage
15 |
16 |
def new_project_page():
    """Render the "New Project" page.

    Collects a project name and frame size, then creates the project (which
    also creates an initial shot/frame), resets per-project session state,
    and switches the app to the new project's Creative Process view.
    """

    # Initialize data repository
    data_repo = DataRepo()

    # title
    st.markdown("#### New Project")
    st.markdown("***")
    # Define multicolumn layout
    project_column, _ = st.columns([1, 3])

    # Prompt user for project naming within project_column
    with project_column:
        new_project_name = st.text_input("Project name:", value="")

    # Prompt user for video dimension specifications
    v1, v2, _ = st.columns([6, 3, 9])

    frame_sizes = ["512x512", "768x512", "512x768", "512x896", "896x512", "512x1024", "1024x512"]
    with v1:

        frame_size = st.radio("Select frame size:", options=frame_sizes, key="frame_size", horizontal=True)
        width, height = map(int, frame_size.split("x"))

    with v2:
        # warn on large frames; threshold is 769 so the 768-wide preset does
        # not trigger the warning (only the 896/1024 presets do)
        if width > 769 or height > 769:
            st.warning("There may be issues with very wide or high frames.")
        # preview rectangle showing the selected aspect ratio
        img = Image.new("RGB", (width, height), color=(73, 109, 137))
        st.image(img, use_column_width=True)

    # NOTE: audio upload at project creation is currently disabled; audio can
    # be attached later (see save_audio_file, still imported above).

    st.write("")

    if st.button("Create New Project"):
        # Add checks for project name existence and format
        if not new_project_name:
            st.error("Please enter a project name.")
        else:
            current_user = data_repo.get_first_active_user()
            new_project, shot = create_new_project(current_user, new_project_name, width, height)
            # create a fresh frame; the return value is not needed here
            create_frame_inside_shot(shot.uuid, 0)
            # removing the initial frame which moved to the 1st position
            # (since creating a new project also creates a frame)
            shot = data_repo.get_shot_from_number(new_project.uuid, 1)
            initial_frame = data_repo.get_timing_from_frame_number(shot.uuid, 0)
            data_repo.delete_timing_from_uuid(initial_frame.uuid)

            reset_project_state()

            # point the session at the newly created project
            st.session_state["project_uuid"] = new_project.uuid
            project_list = data_repo.get_all_project_list(user_id=get_current_user_uuid())
            st.session_state["index_of_project_name"] = len(project_list) - 1
            st.session_state["main_view_type"] = "Creative Process"
            st.session_state["app_settings"] = 0
            st.success("Project created successfully!")
            time.sleep(1)  # let the success message render before rerun
            refresh_app()

    st.markdown("***")
95 |
--------------------------------------------------------------------------------
/utils/ml_processor/comfy_workflows/dynamicrafter_api.json:
--------------------------------------------------------------------------------
1 | {
2 | "2": {
3 | "inputs": {
4 | "frame_rate": 12,
5 | "loop_count": 0,
6 | "filename_prefix": "AnimateDiff",
7 | "format": "video/h264-mp4",
8 | "pix_fmt": "yuv420p",
9 | "crf": 19,
10 | "save_metadata": true,
11 | "pingpong": false,
12 | "save_output": true,
13 | "images": [
14 | "34",
15 | 0
16 | ]
17 | },
18 | "class_type": "VHS_VideoCombine",
19 | "_meta": {
20 | "title": "Video Combine 🎥🅥🅗🅢"
21 | }
22 | },
23 | "11": {
24 | "inputs": {
25 | "ckpt_name": "dynamicrafter_512_interp_v1.ckpt",
26 | "dtype": "auto",
27 | "fp8_unet": false
28 | },
29 | "class_type": "DynamiCrafterModelLoader",
30 | "_meta": {
31 | "title": "DynamiCrafterModelLoader"
32 | }
33 | },
34 | "12": {
35 | "inputs": {
36 | "steps": 50,
37 | "cfg": 5,
38 | "eta": 1,
39 | "frames": 16,
40 | "prompt": "dolly zoom out",
41 | "seed": 262623773159722,
42 | "fs": 10,
43 | "keep_model_loaded": true,
44 | "vae_dtype": "auto",
45 | "cut_near_keyframes": 0,
46 | "model": [
47 | "11",
48 | 0
49 | ],
50 | "images": [
51 | "15",
52 | 0
53 | ]
54 | },
55 | "class_type": "DynamiCrafterBatchInterpolation",
56 | "_meta": {
57 | "title": "DynamiCrafterBatchInterpolation"
58 | }
59 | },
60 | "15": {
61 | "inputs": {
62 | "image1": [
63 | "37",
64 | 0
65 | ],
66 | "image2": [
67 | "38",
68 | 0
69 | ]
70 | },
71 | "class_type": "ImageBatch",
72 | "_meta": {
73 | "title": "Batch Images"
74 | }
75 | },
76 | "16": {
77 | "inputs": {
78 | "image": "ea47a572b4e5b52ea7da22384232381b3e62048fa715f042b38b4da9 (1) (2).jpg",
79 | "upload": "image"
80 | },
81 | "class_type": "LoadImage",
82 | "_meta": {
83 | "title": "Load Image"
84 | }
85 | },
86 | "17": {
87 | "inputs": {
88 | "image": "2193d9ded46130b41d09133b4b1d2502f0eaa19ea1762252c6581e86 (1) (1).jpg",
89 | "upload": "image"
90 | },
91 | "class_type": "LoadImage",
92 | "_meta": {
93 | "title": "Load Image"
94 | }
95 | },
96 | "34": {
97 | "inputs": {
98 | "ckpt_name": "film_net_fp32.pt",
99 | "clear_cache_after_n_frames": 10,
100 | "multiplier": 3,
101 | "frames": [
102 | "12",
103 | 0
104 | ]
105 | },
106 | "class_type": "FILM VFI",
107 | "_meta": {
108 | "title": "FILM VFI"
109 | }
110 | },
111 | "35": {
112 | "inputs": {
113 | "frame_rate": 8,
114 | "loop_count": 0,
115 | "filename_prefix": "AnimateDiff",
116 | "format": "image/gif",
117 | "pingpong": false,
118 | "save_output": true,
119 | "images": [
120 | "12",
121 | 0
122 | ]
123 | },
124 | "class_type": "VHS_VideoCombine",
125 | "_meta": {
126 | "title": "Video Combine 🎥🅥🅗🅢"
127 | }
128 | },
129 | "37": {
130 | "inputs": {
131 | "mode": "rescale",
132 | "supersample": "true",
133 | "resampling": "lanczos",
134 | "rescale_factor": 0.7000000000000001,
135 | "resize_width": 1024,
136 | "resize_height": 1536,
137 | "image": [
138 | "16",
139 | 0
140 | ]
141 | },
142 | "class_type": "Image Resize",
143 | "_meta": {
144 | "title": "Image Resize"
145 | }
146 | },
147 | "38": {
148 | "inputs": {
149 | "mode": "rescale",
150 | "supersample": "true",
151 | "resampling": "lanczos",
152 | "rescale_factor": 0.7000000000000001,
153 | "resize_width": 1024,
154 | "resize_height": 1536,
155 | "image": [
156 | "17",
157 | 0
158 | ]
159 | },
160 | "class_type": "Image Resize",
161 | "_meta": {
162 | "title": "Image Resize"
163 | }
164 | }
165 | }
--------------------------------------------------------------------------------
/utils/ml_processor/comfy_workflows/sdxl_controlnet_workflow_api.json:
--------------------------------------------------------------------------------
1 | {
2 | "3": {
3 | "inputs": {
4 | "seed": 741148140596738,
5 | "steps": 20,
6 | "cfg": 8,
7 | "sampler_name": "euler",
8 | "scheduler": "normal",
9 | "denoise": 1,
10 | "model": [
11 | "4",
12 | 0
13 | ],
14 | "positive": [
15 | "14",
16 | 0
17 | ],
18 | "negative": [
19 | "7",
20 | 0
21 | ],
22 | "latent_image": [
23 | "5",
24 | 0
25 | ]
26 | },
27 | "class_type": "KSampler",
28 | "_meta": {
29 | "title": "KSampler"
30 | }
31 | },
32 | "4": {
33 | "inputs": {
34 | "ckpt_name": "sd_xl_base_1.0.safetensors"
35 | },
36 | "class_type": "CheckpointLoaderSimple",
37 | "_meta": {
38 | "title": "Load Checkpoint"
39 | }
40 | },
41 | "5": {
42 | "inputs": {
43 | "width": 1024,
44 | "height": 1024,
45 | "batch_size": 1
46 | },
47 | "class_type": "EmptyLatentImage",
48 | "_meta": {
49 | "title": "Empty Latent Image"
50 | }
51 | },
52 | "6": {
53 | "inputs": {
54 | "text": "a person standing in an open field",
55 | "clip": [
56 | "4",
57 | 1
58 | ]
59 | },
60 | "class_type": "CLIPTextEncode",
61 | "_meta": {
62 | "title": "CLIP Text Encode (Prompt)"
63 | }
64 | },
65 | "7": {
66 | "inputs": {
67 | "text": "text, watermark",
68 | "clip": [
69 | "4",
70 | 1
71 | ]
72 | },
73 | "class_type": "CLIPTextEncode",
74 | "_meta": {
75 | "title": "CLIP Text Encode (Prompt)"
76 | }
77 | },
78 | "8": {
79 | "inputs": {
80 | "samples": [
81 | "3",
82 | 0
83 | ],
84 | "vae": [
85 | "4",
86 | 2
87 | ]
88 | },
89 | "class_type": "VAEDecode",
90 | "_meta": {
91 | "title": "VAE Decode"
92 | }
93 | },
94 | "9": {
95 | "inputs": {
96 | "filename_prefix": "ComfyUI",
97 | "images": [
98 | "8",
99 | 0
100 | ]
101 | },
102 | "class_type": "SaveImage",
103 | "_meta": {
104 | "title": "Save Image"
105 | }
106 | },
107 | "12": {
108 | "inputs": {
109 | "low_threshold": 0.19999999999999984,
110 | "high_threshold": 0.7,
111 | "image": [
112 | "17",
113 | 0
114 | ]
115 | },
116 | "class_type": "Canny",
117 | "_meta": {
118 | "title": "Canny"
119 | }
120 | },
121 | "13": {
122 | "inputs": {
123 | "image": "boy_sunshine.png",
124 | "upload": "image"
125 | },
126 | "class_type": "LoadImage",
127 | "_meta": {
128 | "title": "Load Image"
129 | }
130 | },
131 | "14": {
132 | "inputs": {
133 | "strength": 1,
134 | "conditioning": [
135 | "6",
136 | 0
137 | ],
138 | "control_net": [
139 | "22",
140 | 0
141 | ],
142 | "image": [
143 | "12",
144 | 0
145 | ]
146 | },
147 | "class_type": "ControlNetApply",
148 | "_meta": {
149 | "title": "Apply ControlNet"
150 | }
151 | },
152 | "17": {
153 | "inputs": {
154 | "upscale_method": "nearest-exact",
155 | "width": 1024,
156 | "height": 1024,
157 | "crop": "center",
158 | "image": [
159 | "13",
160 | 0
161 | ]
162 | },
163 | "class_type": "ImageScale",
164 | "_meta": {
165 | "title": "Upscale Image"
166 | }
167 | },
168 | "22": {
169 | "inputs": {
170 | "control_net_name": "canny_diffusion_pytorch_model.safetensors"
171 | },
172 | "class_type": "ControlNetLoader",
173 | "_meta": {
174 | "title": "Load ControlNet Model"
175 | }
176 | },
177 | "25": {
178 | "inputs": {
179 | "images": [
180 | "12",
181 | 0
182 | ]
183 | },
184 | "class_type": "PreviewImage",
185 | "_meta": {
186 | "title": "Preview Image"
187 | }
188 | }
189 | }
--------------------------------------------------------------------------------
/utils/ml_processor/comfy_workflows/sdxl_workflow_api.json:
--------------------------------------------------------------------------------
1 | {
2 | "4": {
3 | "inputs": {
4 | "ckpt_name": "sd_xl_base_1.0.safetensors"
5 | },
6 | "class_type": "CheckpointLoaderSimple",
7 | "_meta": {
8 | "title": "Load Checkpoint - BASE"
9 | }
10 | },
11 | "5": {
12 | "inputs": {
13 | "width": 1024,
14 | "height": 1024,
15 | "batch_size": 1
16 | },
17 | "class_type": "EmptyLatentImage",
18 | "_meta": {
19 | "title": "Empty Latent Image"
20 | }
21 | },
22 | "6": {
23 | "inputs": {
24 | "text": "evening sunset scenery blue sky nature, glass bottle with a galaxy in it",
25 | "clip": [
26 | "4",
27 | 1
28 | ]
29 | },
30 | "class_type": "CLIPTextEncode",
31 | "_meta": {
32 | "title": "CLIP Text Encode (Prompt)"
33 | }
34 | },
35 | "7": {
36 | "inputs": {
37 | "text": "text, watermark",
38 | "clip": [
39 | "4",
40 | 1
41 | ]
42 | },
43 | "class_type": "CLIPTextEncode",
44 | "_meta": {
45 | "title": "CLIP Text Encode (Prompt)"
46 | }
47 | },
48 | "10": {
49 | "inputs": {
50 | "add_noise": "enable",
51 | "noise_seed": 721897303308196,
52 | "steps": 25,
53 | "cfg": 8,
54 | "sampler_name": "euler",
55 | "scheduler": "normal",
56 | "start_at_step": 0,
57 | "end_at_step": 20,
58 | "return_with_leftover_noise": "enable",
59 | "model": [
60 | "4",
61 | 0
62 | ],
63 | "positive": [
64 | "6",
65 | 0
66 | ],
67 | "negative": [
68 | "7",
69 | 0
70 | ],
71 | "latent_image": [
72 | "5",
73 | 0
74 | ]
75 | },
76 | "class_type": "KSamplerAdvanced",
77 | "_meta": {
78 | "title": "KSampler (Advanced) - BASE"
79 | }
80 | },
81 | "11": {
82 | "inputs": {
83 | "add_noise": "disable",
84 | "noise_seed": 0,
85 | "steps": 25,
86 | "cfg": 8,
87 | "sampler_name": "euler",
88 | "scheduler": "normal",
89 | "start_at_step": 20,
90 | "end_at_step": 10000,
91 | "return_with_leftover_noise": "disable",
92 | "model": [
93 | "12",
94 | 0
95 | ],
96 | "positive": [
97 | "15",
98 | 0
99 | ],
100 | "negative": [
101 | "16",
102 | 0
103 | ],
104 | "latent_image": [
105 | "10",
106 | 0
107 | ]
108 | },
109 | "class_type": "KSamplerAdvanced",
110 | "_meta": {
111 | "title": "KSampler (Advanced) - REFINER"
112 | }
113 | },
114 | "12": {
115 | "inputs": {
116 | "ckpt_name": "sd_xl_refiner_1.0.safetensors"
117 | },
118 | "class_type": "CheckpointLoaderSimple",
119 | "_meta": {
120 | "title": "Load Checkpoint - REFINER"
121 | }
122 | },
123 | "15": {
124 | "inputs": {
125 | "text": "evening sunset scenery blue sky nature, glass bottle with a galaxy in it",
126 | "clip": [
127 | "12",
128 | 1
129 | ]
130 | },
131 | "class_type": "CLIPTextEncode",
132 | "_meta": {
133 | "title": "CLIP Text Encode (Prompt)"
134 | }
135 | },
136 | "16": {
137 | "inputs": {
138 | "text": "text, watermark",
139 | "clip": [
140 | "12",
141 | 1
142 | ]
143 | },
144 | "class_type": "CLIPTextEncode",
145 | "_meta": {
146 | "title": "CLIP Text Encode (Prompt)"
147 | }
148 | },
149 | "17": {
150 | "inputs": {
151 | "samples": [
152 | "11",
153 | 0
154 | ],
155 | "vae": [
156 | "12",
157 | 2
158 | ]
159 | },
160 | "class_type": "VAEDecode",
161 | "_meta": {
162 | "title": "VAE Decode"
163 | }
164 | },
165 | "19": {
166 | "inputs": {
167 | "filename_prefix": "ComfyUI",
168 | "images": [
169 | "17",
170 | 0
171 | ]
172 | },
173 | "class_type": "SaveImage",
174 | "_meta": {
175 | "title": "Save Image"
176 | }
177 | }
178 | }
--------------------------------------------------------------------------------
/ui_components/widgets/download_file_progress_bar.py:
--------------------------------------------------------------------------------
1 | import os
2 | import random
3 | import string
4 | import tarfile
5 | import time
6 | import zipfile
7 | import requests
8 | import streamlit as st
9 | from utils.common_decorators import with_refresh_lock
10 | from utils.state_refresh import refresh_app
11 |
12 |
@with_refresh_lock
def download_file_widget(url, filename, dest):
    """Download `url` to `dest/filename`, showing a Streamlit progress bar.

    Archives (``.zip``/``.tar``) are extracted into `dest` and the archive is
    deleted afterwards. The target path is recorded in session state as a
    "partial download" so it gets cleaned up on the next rerun if this run
    fails; the marker is cleared once the download completes successfully.
    """
    save_directory = dest
    zip_filename = filename
    filepath = os.path.join(save_directory, zip_filename)

    # ------- deleting partial downloads left over from a previous failed run
    if st.session_state.get("delete_partial_download", None):
        fp = st.session_state["delete_partial_download"]
        st.session_state["delete_partial_download"] = None
        if os.path.exists(fp):
            os.remove(fp)
        st.info("Partial downloads deleted")
        time.sleep(0.3)
        refresh_app()

    # checking if the file already exists
    if os.path.exists(os.path.join(dest, filename)):
        st.warning("File already present")
        time.sleep(1)
        refresh_app()

    # setting this file for deletion, incase it's not downloaded properly
    # if it is downloaded properly then it will be removed from here (all these steps because of streamlit!)
    st.session_state["delete_partial_download"] = filepath

    with st.spinner("Downloading model..."):
        download_bar = st.progress(0, text="")
        os.makedirs(save_directory, exist_ok=True)  # Create the directory if it doesn't exist

        # Download the model and save it to the directory
        response = requests.get(url, stream=True)
        cancel_download = False

        if st.button("Cancel"):
            # BUG FIX: the flag was never set, making the in-loop cancel
            # check dead code; a click now aborts the (restarted) download
            cancel_download = True
            st.session_state["delete_partial_download"] = filepath

        if response.status_code == 200:
            total_size = int(response.headers.get("content-length", 0))
            total_size_mb = total_size / (1024 * 1024)

            start_time = time.time()

            with open(filepath, "wb") as f:
                received_bytes = 0
                for data in response.iter_content(chunk_size=1048576):
                    if cancel_download:
                        raise Exception("download cancelled")

                    f.write(data)
                    received_bytes += len(data)
                    # BUG FIX: guard against servers that omit content-length
                    progress = received_bytes / total_size if total_size else 0.0
                    received_mb = received_bytes / (1024 * 1024)

                    elapsed_time = time.time() - start_time
                    # BUG FIX: avoid ZeroDivisionError on a very fast first chunk
                    download_speed = received_bytes / max(elapsed_time, 1e-9) / (1024 * 1024)

                    download_bar.progress(
                        progress,
                        text=f"Downloaded: {received_mb:.2f} MB / {total_size_mb:.2f} MB | Speed: {download_speed:.2f} MB/sec",
                    )

            # BUG FIX: message previously said "(unknown)" instead of the file
            st.success(f"Downloaded {filename} and saved to {save_directory}")
            time.sleep(1)
            download_bar.empty()

            if url.endswith(".zip") or url.endswith(".tar"):
                st.success("Extracting the zip file. Please wait...")
                new_filepath = filepath.replace(zip_filename, "")
                if url.endswith(".zip"):
                    with zipfile.ZipFile(f"{filepath}", "r") as zip_ref:
                        zip_ref.extractall(new_filepath)
                else:
                    with tarfile.open(f"{filepath}", "r") as tar_ref:
                        tar_ref.extractall(new_filepath)

                os.remove(filepath)
            # BUG FIX: the original `else` branch renamed the file onto its own
            # path (zip_filename == filename), a no-op on POSIX that raises
            # FileExistsError on Windows — dropped.

            st.session_state["delete_partial_download"] = None
        else:
            st.error("Unable to access model url")
            time.sleep(1)

    refresh_app()
100 |
--------------------------------------------------------------------------------
/utils/ml_processor/comfy_workflows/sdxl_openpose_workflow_api.json:
--------------------------------------------------------------------------------
1 | {
2 | "3": {
3 | "inputs": {
4 | "seed": 253663277835217,
5 | "steps": 20,
6 | "cfg": 8,
7 | "sampler_name": "euler",
8 | "scheduler": "normal",
9 | "denoise": 1,
10 | "model": [
11 | "4",
12 | 0
13 | ],
14 | "positive": [
15 | "16",
16 | 0
17 | ],
18 | "negative": [
19 | "7",
20 | 0
21 | ],
22 | "latent_image": [
23 | "5",
24 | 0
25 | ]
26 | },
27 | "class_type": "KSampler",
28 | "_meta": {
29 | "title": "KSampler"
30 | }
31 | },
32 | "4": {
33 | "inputs": {
34 | "ckpt_name": "sd_xl_base_1.0_0.9vae.safetensors"
35 | },
36 | "class_type": "CheckpointLoaderSimple",
37 | "_meta": {
38 | "title": "Load Checkpoint"
39 | }
40 | },
41 | "5": {
42 | "inputs": {
43 | "width": 624,
44 | "height": 624,
45 | "batch_size": 1
46 | },
47 | "class_type": "EmptyLatentImage",
48 | "_meta": {
49 | "title": "Empty Latent Image"
50 | }
51 | },
52 | "6": {
53 | "inputs": {
54 | "text": "a ballerina, romantic sunset, 4k photo",
55 | "clip": [
56 | "4",
57 | 1
58 | ]
59 | },
60 | "class_type": "CLIPTextEncode",
61 | "_meta": {
62 | "title": "CLIP Text Encode (Prompt)"
63 | }
64 | },
65 | "7": {
66 | "inputs": {
67 | "text": "text, watermark",
68 | "clip": [
69 | "4",
70 | 1
71 | ]
72 | },
73 | "class_type": "CLIPTextEncode",
74 | "_meta": {
75 | "title": "CLIP Text Encode (Prompt)"
76 | }
77 | },
78 | "8": {
79 | "inputs": {
80 | "samples": [
81 | "3",
82 | 0
83 | ],
84 | "vae": [
85 | "4",
86 | 2
87 | ]
88 | },
89 | "class_type": "VAEDecode",
90 | "_meta": {
91 | "title": "VAE Decode"
92 | }
93 | },
94 | "9": {
95 | "inputs": {
96 | "filename_prefix": "ComfyUI",
97 | "images": [
98 | "8",
99 | 0
100 | ]
101 | },
102 | "class_type": "SaveImage",
103 | "_meta": {
104 | "title": "Save Image"
105 | }
106 | },
107 | "10": {
108 | "inputs": {
109 | "detect_hand": "enable",
110 | "detect_body": "enable",
111 | "detect_face": "enable",
112 | "resolution": "v1.1",
113 | "image": [
114 | "11",
115 | 0
116 | ]
117 | },
118 | "class_type": "OpenposePreprocessor",
119 | "_meta": {
120 | "title": "OpenPose Pose"
121 | }
122 | },
123 | "11": {
124 | "inputs": {
125 | "upscale_method": "nearest-exact",
126 | "width": 623,
127 | "height": 623,
128 | "crop": "disabled",
129 | "image": [
130 | "12",
131 | 0
132 | ]
133 | },
134 | "class_type": "ImageScale",
135 | "_meta": {
136 | "title": "Upscale Image"
137 | }
138 | },
139 | "12": {
140 | "inputs": {
141 | "image": "boy_sunshine.png",
142 | "upload": "image"
143 | },
144 | "class_type": "LoadImage",
145 | "_meta": {
146 | "title": "Load Image"
147 | }
148 | },
149 | "13": {
150 | "inputs": {
151 | "images": [
152 | "10",
153 | 0
154 | ]
155 | },
156 | "class_type": "PreviewImage",
157 | "_meta": {
158 | "title": "Preview Image"
159 | }
160 | },
161 | "14": {
162 | "inputs": {
163 | "control_net_name": "OpenPoseXL2.safetensors"
164 | },
165 | "class_type": "ControlNetLoader",
166 | "_meta": {
167 | "title": "Load ControlNet Model"
168 | }
169 | },
170 | "16": {
171 | "inputs": {
172 | "strength": 1,
173 | "conditioning": [
174 | "6",
175 | 0
176 | ],
177 | "control_net": [
178 | "14",
179 | 0
180 | ],
181 | "image": [
182 | "10",
183 | 0
184 | ]
185 | },
186 | "class_type": "ControlNetApply",
187 | "_meta": {
188 | "title": "Apply ControlNet"
189 | }
190 | }
191 | }
--------------------------------------------------------------------------------
/auto_refresh.py:
--------------------------------------------------------------------------------
1 | import os
2 | import platform
3 | import signal
4 | import sys
5 | import portalocker
6 | import json
7 | from flask import Flask, jsonify, request
8 | import logging
9 | import threading
10 | import time
11 | import random
12 | from queue import Queue
13 |
14 | import requests
15 |
16 | from utils.constants import REFRESH_LOCK_FILE, REFRESH_PROCESS_PORT, REFRESH_TARGET_FILE
17 |
app = Flask(__name__)


# file rewritten with a random value to trigger a refresh in the watcher
target_file = REFRESH_TARGET_FILE
# lock file consulted before refreshing; presumably written by the main app — TODO confirm
lock_file = REFRESH_LOCK_FILE
refresh_queue = Queue()  # pending refresh requests (at most one queued at a time)
refresh_thread = None

# epoch seconds of the last performed refresh (throttling state)
last_refreshed_on = 0
REFRESH_BUFFER_TIME = 10  # seconds before making consecutive refreshes
TERMINATE_SCRIPT = False  # set by the signal handler to stop worker loops
29 |
30 |
def handle_termination(signum, frame):
    """Signal handler: flag the script for termination and exit immediately.

    Args:
        signum: received signal number (renamed from ``signal``, which
            shadowed the imported ``signal`` module).
        frame: current stack frame (unused, required by the handler API).
    """
    print("Received termination signal - auto refresh. Cleaning up...")
    global TERMINATE_SCRIPT
    TERMINATE_SCRIPT = True
    # os._exit skips normal cleanup so a blocked worker thread cannot hang exit
    os._exit(1)
36 |
37 |
38 | if platform.system() == "Windows":
39 | signal.signal(signal.SIGINT, handle_termination)
40 |
41 | signal.signal(signal.SIGTERM, handle_termination)
42 |
43 |
def check_lock():
    """Return True when the lock file exists and its status is "locked".

    Any failure to read or parse the file is treated as "not locked".
    """
    if not os.path.exists(lock_file):
        return False

    try:
        with portalocker.Lock(lock_file, "r", timeout=0.1) as handle:
            return json.load(handle)["status"] == "locked"
    except (portalocker.LockException, FileNotFoundError, json.JSONDecodeError):
        return False
54 |
55 |
def refresh():
    """Perform one debounced refresh by rewriting the target file.

    Blocks while the lock file reports "locked", then waits out the buffer
    window since the previous refresh before writing.
    """
    global last_refreshed_on

    # wait until no one holds the refresh lock
    while check_lock():
        time.sleep(2)

    # throttle: at most one refresh per REFRESH_BUFFER_TIME seconds
    while int(time.time()) - last_refreshed_on < REFRESH_BUFFER_TIME:
        time.sleep(2)

    last_refreshed_on = int(time.time())
    with portalocker.Lock(target_file, "w") as fh:
        fh.write(f"SAVE_STATE = {random.randint(1, 1000)}")
    return True
71 |
72 |
def refresh_worker():
    """Consume queued refresh requests one at a time until termination."""
    while not TERMINATE_SCRIPT:
        refresh_queue.get()  # blocks until a request arrives
        refresh()
77 |
78 |
79 | @app.route("/refresh", methods=["POST"])
80 | def trigger_refresh():
81 | if refresh_queue.empty():
82 | refresh_queue.put(True)
83 | return jsonify({"success": True, "message": "Refresh request queued"})
84 | else:
85 | return jsonify({"success": False, "message": "Refresh already queued"})
86 |
87 |
88 | @app.route("/health", methods=["GET"])
89 | def health_check():
90 | return jsonify({"status": "healthy"}), 200
91 |
92 |
def run_flask():
    """Start the Flask server with its default console output silenced."""
    # quiet werkzeug's per-request log lines
    logging.getLogger("werkzeug").setLevel(logging.ERROR)

    # suppress the "Running on ..." startup banner
    sys.modules["flask.cli"].show_server_banner = lambda *x: None

    app.run(host="0.0.0.0", port=REFRESH_PROCESS_PORT)
103 |
104 |
def main():
    """Run the Flask trigger server and the refresh worker until terminated."""
    # flask server
    flask_thread = threading.Thread(target=run_flask)
    flask_thread.start()

    # refresh worker (daemon so it cannot keep the process alive on exit)
    refresh_thread = threading.Thread(target=refresh_worker)
    refresh_thread.daemon = True
    refresh_thread.start()

    try:
        # idle until a signal handler flips TERMINATE_SCRIPT
        while not TERMINATE_SCRIPT:
            time.sleep(1)
    except KeyboardInterrupt:
        print("Keyboard interrupt received. Shutting down...")
        handle_termination(signal.SIGINT, None)

    print("Waiting for threads to finish...")
    # running the main thread as long as the flask server is active
    flask_thread.join(timeout=5)
    refresh_thread.join(timeout=5)

    # BUG FIX: the original used `and`, which printed "Auto refresh terminated."
    # whenever just one thread had stopped; a clean shutdown requires BOTH dead.
    if not (flask_thread.is_alive() or refresh_thread.is_alive()):
        print("Auto refresh terminated.")
    else:
        print("threads didn't shutdown")
        print("refresh thread active: ", refresh_thread.is_alive())
        print("flask thread active: ", flask_thread.is_alive())
131 | print("refresh thread active: ", refresh_thread.is_alive())
132 | print("flask thread active: ", flask_thread.is_alive())
133 |
134 |
135 | if __name__ == "__main__":
136 | main()
137 |
--------------------------------------------------------------------------------
/ui_components/constants.py:
--------------------------------------------------------------------------------
1 | import os
2 | from shared.constants import COMFY_BASE_PATH, AnimationStyleType, AnimationToolType
3 | from utils.constants import ImageStage
4 | from utils.enum import ExtendedEnum
5 |
6 |
class WorkflowStageType(ExtendedEnum):
    """Workflow stage of an image: the raw source or the styled result."""

    SOURCE = "source"
    STYLED = "styled"
10 |
11 |
class VideoQuality(ExtendedEnum):
    """Output video quality options."""

    HIGH = "High-Quality"
    LOW = "Low"
15 |
16 |
class CreativeProcessType(ExtendedEnum):
    """Sections of the creative process (values look like UI labels — confirm)."""

    STYLING = "Key Frames"
    MOTION = "Shots"
20 |
21 |
class ShotMetaData(ExtendedEnum):
    """Keys under which per-shot metadata is stored."""

    MOTION_DATA = "motion_data"  # {"timing_data": [...], "main_setting_data": {}}
    DYNAMICRAFTER_DATA = "dynamicrafter_data"
25 |
26 |
class GalleryImageViewType(ExtendedEnum):
    """Scope in which a gallery image view applies."""

    EXPLORER_ONLY = "explorer"
    SHOT_ONLY = "shot"
    ANY = "any"
31 |
32 |
class DefaultTimingStyleParams:
    """Default generation/styling parameters for a single timing (frame).

    NOTE(review): attributes appear to be read as class-level defaults;
    custom_model_id_list is a mutable class attribute — confirm it is never
    mutated in place by callers.
    """

    prompt = ""
    negative_prompt = "bad image, worst quality"
    strength = 1
    guidance_scale = 7.5
    seed = 0
    num_inference_steps = 25
    # Canny edge-detection thresholds — TODO confirm
    low_threshold = 100
    high_threshold = 200
    adapter_type = None
    interpolation_steps = 3
    transformation_stage = ImageStage.SOURCE_IMAGE.value
    custom_model_id_list = []
    animation_tool = AnimationToolType.G_FILM.value
    animation_style = AnimationStyleType.CREATIVE_INTERPOLATION.value
    model = None
    total_log_table_pages = 1
50 |
51 |
class DefaultProjectSettingParams:
    """Project-level defaults; batch_* fields mirror DefaultTimingStyleParams.

    NOTE(review): batch_custom_model_id_list is a mutable class attribute —
    confirm it is never mutated in place by callers.
    """

    batch_prompt = ""
    batch_negative_prompt = "bad image, worst quality"
    batch_strength = 1
    batch_guidance_scale = 0.5
    batch_seed = 0
    batch_num_inference_steps = 25
    batch_low_threshold = 100
    batch_high_threshold = 200
    batch_adapter_type = None
    batch_interpolation_steps = 3
    batch_transformation_stage = ImageStage.SOURCE_IMAGE.value
    batch_custom_model_id_list = []
    batch_animation_tool = AnimationToolType.G_FILM.value
    batch_animation_style = AnimationStyleType.CREATIVE_INTERPOLATION.value
    batch_model = None
    # pagination defaults for logs and galleries
    total_log_pages = 1
    total_gallery_pages = 1
    total_shortlist_gallery_pages = 1
    max_frames_per_shot = 30
72 |
73 |
# Per-frame motion defaults applied when a shot frame has no saved settings.
DEFAULT_SHOT_MOTION_VALUES = {
    "strength_of_frame": 0.85,
    "distance_to_next_frame": 3.0,
    "speed_of_transition": 0.5,
    "freedom_between_frames": 0.85,
    "individual_prompt": "",
    "individual_negative_prompt": "",
    "motion_during_frame": 1.25,
}

# TODO: make proper paths for every file
# Scratch-file locations for cropped/mask/audio temp assets.
CROPPED_IMG_LOCAL_PATH = "videos/temp/cropped.png"

MASK_IMG_LOCAL_PATH = "videos/temp/mask.png"
TEMP_MASK_FILE = "temp_mask_file"

SECOND_MASK_FILE_PATH = "videos/temp/second_mask.png"
SECOND_MASK_FILE = "second_mask_file"

AUDIO_FILE_PATH = "videos/temp/audio.mp3"
AUDIO_FILE = "audio_file"

# Downloadable checkpoints, keyed by filename; each entry holds the civitai
# download URL and the destination inside the ComfyUI models folder.
checkpoints_dir = os.path.join(COMFY_BASE_PATH, "models", "checkpoints")
SD_MODEL_DICT = {
    "realisticVisionV60B1_v51VAE.safetensors": {
        "url": "https://civitai.com/api/download/models/130072",
        "filename": "realisticVisionV60B1_v51VAE.safetensors",
        "dest": checkpoints_dir,
    },
    "anything_v50.safetensors": {
        "url": "https://civitai.com/api/download/models/30163",
        "filename": "anything_v50.safetensors",
        "dest": checkpoints_dir,
    },
    "dreamshaper_8.safetensors": {
        "url": "https://civitai.com/api/download/models/128713",
        "filename": "dreamshaper_8.safetensors",
        "dest": checkpoints_dir,
    },
    "epicrealism_pureEvolutionV5.safetensors": {
        "url": "https://civitai.com/api/download/models/134065",
        "filename": "epicrealism_pureEvolutionV5.safetensors",
        "dest": checkpoints_dir,
    },
    "majicmixRealistic_v6.safetensors": {
        "url": "https://civitai.com/api/download/models/94640",
        "filename": "majicmixRealistic_v6.safetensors",
        "dest": checkpoints_dir,
    },
}
124 |
--------------------------------------------------------------------------------
/utils/ml_processor/comfy_workflows/ipadapter_face_api.json:
--------------------------------------------------------------------------------
1 | {
2 | "3": {
3 | "inputs": {
4 | "seed": 862782529735965,
5 | "steps": 24,
6 | "cfg": 9.25,
7 | "sampler_name": "ddim",
8 | "scheduler": "normal",
9 | "denoise": 1,
10 | "model": [
11 | "36",
12 | 0
13 | ],
14 | "positive": [
15 | "6",
16 | 0
17 | ],
18 | "negative": [
19 | "7",
20 | 0
21 | ],
22 | "latent_image": [
23 | "5",
24 | 0
25 | ]
26 | },
27 | "class_type": "KSampler",
28 | "_meta": {
29 | "title": "KSampler"
30 | }
31 | },
32 | "4": {
33 | "inputs": {
34 | "ckpt_name": "sd_xl_base_1.0.safetensors"
35 | },
36 | "class_type": "CheckpointLoaderSimple",
37 | "_meta": {
38 | "title": "Load Checkpoint"
39 | }
40 | },
41 | "5": {
42 | "inputs": {
43 | "width": 1024,
44 | "height": 1024,
45 | "batch_size": 1
46 | },
47 | "class_type": "EmptyLatentImage",
48 | "_meta": {
49 | "title": "Empty Latent Image"
50 | }
51 | },
52 | "6": {
53 | "inputs": {
54 | "text": "",
55 | "clip": [
56 | "4",
57 | 1
58 | ]
59 | },
60 | "class_type": "CLIPTextEncode",
61 | "_meta": {
62 | "title": "CLIP Text Encode (Prompt)"
63 | }
64 | },
65 | "7": {
66 | "inputs": {
67 | "text": "",
68 | "clip": [
69 | "4",
70 | 1
71 | ]
72 | },
73 | "class_type": "CLIPTextEncode",
74 | "_meta": {
75 | "title": "CLIP Text Encode (Prompt)"
76 | }
77 | },
78 | "8": {
79 | "inputs": {
80 | "samples": [
81 | "3",
82 | 0
83 | ],
84 | "vae": [
85 | "4",
86 | 2
87 | ]
88 | },
89 | "class_type": "VAEDecode",
90 | "_meta": {
91 | "title": "VAE Decode"
92 | }
93 | },
94 | "21": {
95 | "inputs": {
96 | "ipadapter_file": "ip-adapter_sdxl.safetensors"
97 | },
98 | "class_type": "IPAdapterModelLoader",
99 | "_meta": {
100 | "title": "Load IPAdapter Model"
101 | }
102 | },
103 | "24": {
104 | "inputs": {
105 | "image": "rWA-3_T7_400x400.jpg",
106 | "upload": "image"
107 | },
108 | "class_type": "LoadImage",
109 | "_meta": {
110 | "title": "Load Image"
111 | }
112 | },
113 | "29": {
114 | "inputs": {
115 | "filename_prefix": "ComfyUI",
116 | "images": [
117 | "8",
118 | 0
119 | ]
120 | },
121 | "class_type": "SaveImage",
122 | "_meta": {
123 | "title": "Save Image"
124 | }
125 | },
126 | "36": {
127 | "inputs": {
128 | "weight": 0.75,
129 | "noise": 0.3,
130 | "weight_faceidv2": 0.75,
131 | "weight_type": "linear",
132 | "combine_embeds": "concat",
133 | "embeds_scaling": "V only",
134 | "start_at": 0,
135 | "end_at": 1,
136 | "ipadapter": [
137 | "21",
138 | 0
139 | ],
140 | "clip_vision": [
141 | "41",
142 | 0
143 | ],
144 | "insightface": [
145 | "37",
146 | 0
147 | ],
148 | "image": [
149 | "40",
150 | 0
151 | ],
152 | "model": [
153 | "4",
154 | 0
155 | ]
156 | },
157 | "class_type": "IPAdapterFaceID",
158 | "_meta": {
159 | "title": "IPAdapter FaceID"
160 | }
161 | },
162 | "37": {
163 | "inputs": {
164 | "provider": "CUDA"
165 | },
166 | "class_type": "IPAdapterInsightFaceLoader",
167 | "_meta": {
168 | "title": "IPAdapter InsightFace Loader"
169 | }
170 | },
171 | "40": {
172 | "inputs": {
173 | "interpolation": "LANCZOS",
174 | "crop_position": "top",
175 | "sharpening": 0,
176 | "image": [
177 | "24",
178 | 0
179 | ]
180 | },
181 | "class_type": "PrepImageForClipVision",
182 | "_meta": {
183 | "title": "Prep Image For ClipVision"
184 | }
185 | },
186 | "41": {
187 | "inputs": {
188 | "clip_name": "SDXL/pytorch_model.bin"
189 | },
190 | "class_type": "CLIPVisionLoader",
191 | "_meta": {
192 | "title": "Load CLIP Vision"
193 | }
194 | }
195 | }
--------------------------------------------------------------------------------
/utils/common_decorators.py:
--------------------------------------------------------------------------------
1 | from functools import wraps
2 | import json
3 | import os
4 | import time
5 | import portalocker
6 | import streamlit as st
7 | from streamlit import runtime
8 |
9 |
def count_calls(cls):
    """Class decorator that counts how many times each method of *cls* is called.

    Instances of the returned class gain:
      - ``call_counts``: dict mapping method name -> number of invocations
      - ``total_count``: total invocations across all methods

    Note: accessing a callable attribute without calling it still registers a
    zero entry in ``call_counts``.
    """

    class Wrapper(cls):
        def __init__(self, *args, **kwargs):
            # BUG FIX: the counters must exist before super().__init__ runs —
            # a parent initializer that calls any of its own methods would
            # otherwise hit __getattribute__ and fail with AttributeError.
            self.call_counts = {}
            self.total_count = 0
            super().__init__(*args, **kwargs)

        def __getattribute__(self, name):
            attr = super().__getattribute__(name)
            if callable(attr) and name not in ["__getattribute__", "call_counts", "total_count"]:
                if name not in self.call_counts:
                    self.call_counts[name] = 0

                def wrapped_method(*args, **kwargs):
                    self.call_counts[name] += 1
                    self.total_count += 1
                    return attr(*args, **kwargs)

                return wrapped_method

            return attr

    return Wrapper
33 |
34 |
def log_time(func):
    """Decorator that prints how long each call of *func* took.

    The printed label is the second positional argument when present (this
    matches the repo-method call style ``method(self, url, ...)``), else the
    ``url`` keyword argument, else the function's own name.
    """

    @wraps(func)  # preserve the wrapped function's name/docstring
    def wrapper(*args, **kwargs):
        start_time = time.time()
        result = func(*args, **kwargs)
        execution_time = time.time() - start_time
        # BUG FIX: previously raised KeyError when called without a second
        # positional argument and without a "url" keyword argument.
        label = args[1] if len(args) >= 2 else kwargs.get("url", func.__name__)
        print(f"{label} took {execution_time:.4f} seconds to execute.")
        return result

    return wrapper
47 |
48 |
def measure_execution_time(cls):
    """Class decorator that prints the wall-clock run time of every method call.

    Returns a proxy class: construction builds the real instance, attribute
    access is forwarded to it, and any callable attribute is handed back
    wrapped in a timing shim.
    """

    class WrapperClass:
        def __init__(self, *args, **kwargs):
            self.wrapped_instance = cls(*args, **kwargs)

        def __getattr__(self, name):
            target = getattr(self.wrapped_instance, name)
            return self.measure_method_execution(target) if callable(target) else target

        def measure_method_execution(self, method):
            def wrapper(*args, **kwargs):
                started = time.time()
                outcome = method(*args, **kwargs)
                elapsed = time.time() - started
                print(f"Execution time of {method.__name__}: {elapsed} seconds")
                return outcome

            return wrapper

    return WrapperClass
72 |
73 |
def session_state_attributes(default_value_cls):
    """Class decorator that backs selected attributes with Streamlit session state.

    For every attribute name that exists on ``default_value_cls``, reads and
    writes on the decorated class are redirected to
    ``st.session_state[f"{self.uuid}_{attr}"]`` — the decorated class must
    therefore expose a ``uuid``. All other attributes behave normally.
    """

    def decorator(cls):
        original_getattr = cls.__getattribute__
        original_setattr = cls.__setattr__

        def custom_attr(self, attr):
            if hasattr(default_value_cls, attr):
                key = f"{self.uuid}_{attr}"
                # NOTE(review): falsy stored values (0, "", False) are treated
                # as "unset" here and reset to the default — confirm intended.
                if not (key in st.session_state and st.session_state[key]):
                    st.session_state[key] = getattr(default_value_cls, attr)

                # outside a live Streamlit runtime, fall back to the plain default
                return st.session_state[key] if runtime.exists() else getattr(default_value_cls, attr)
            else:
                return original_getattr(self, attr)

        def custom_setattr(self, attr, value):
            if hasattr(default_value_cls, attr):
                key = f"{self.uuid}_{attr}"
                st.session_state[key] = value
            else:
                original_setattr(self, attr, value)

        cls.__getattribute__ = custom_attr
        cls.__setattr__ = custom_setattr
        return cls

    return decorator
101 |
102 |
def with_refresh_lock(func):
    """Decorator that marks the refresh lock as held while ``func`` runs.

    The lock status is always reset in ``finally``. Exceptions raised by
    ``func`` are swallowed (only printed), so the wrapper returns ``None``
    on failure.
    """

    @wraps(func)
    def wrapper(*args, **kwargs):
        update_refresh_lock(True)
        try:
            return func(*args, **kwargs)
        except Exception as e:
            print("Error occured while processing ", str(e))
        finally:
            update_refresh_lock(False)

    return wrapper
115 |
116 |
def update_refresh_lock(status=False):
    """Persist the current refresh-lock state to the shared lock file.

    Writes a JSON blob with the lock status, the time of this action and the
    writing process's PID, under an exclusive portalocker file lock.
    """
    from utils.constants import REFRESH_LOCK_FILE

    payload = {
        "status": "locked" if status else "unlocked",
        "last_action_time": time.time(),
        "process_id": os.getpid(),
    }
    with portalocker.Lock(REFRESH_LOCK_FILE, "w") as lock_file_handle:
        json.dump(payload, lock_file_handle)
127 |
--------------------------------------------------------------------------------
/ui_components/methods/training_methods.py:
--------------------------------------------------------------------------------
1 | import json
2 | from typing import List
3 | from shared.constants import AIModelCategory
4 | from utils.common_utils import get_current_user_uuid
5 | from utils.data_repo.data_repo import DataRepo
6 | from utils.ml_processor.ml_interface import get_ml_client
7 | from utils.ml_processor.constants import ML_MODEL
8 |
9 | # NOTE: code not in use
10 | # NOTE: making an exception for this function, passing just the image urls instead of image files
11 | # def train_model(images_list, instance_prompt, class_prompt, max_train_steps,
12 | # model_name, type_of_model, type_of_task, resolution, controller_type, model_type_list):
13 | # # prepare and upload the training data (images.zip)
14 | # ml_client = get_ml_client()
15 | # try:
16 | # training_file_url = ml_client.upload_training_data(images_list)
17 | # except Exception as e:
18 | # raise e
19 |
20 | # # training the model
21 | # model_name = model_name.replace(" ", "-").lower()
22 | # if type_of_model == "Dreambooth":
23 | # return train_dreambooth_model(instance_prompt, class_prompt, training_file_url,
24 | # max_train_steps, model_name, images_list, controller_type, model_type_list)
25 | # elif type_of_model == "LoRA":
26 | # return train_lora_model(training_file_url, type_of_task, resolution, model_name, images_list, model_type_list)
27 |
28 |
29 | # NOTE: code not in use
30 | # INFO: images_list passed here are converted to internal files after they are used for training
31 | # def train_dreambooth_model(instance_prompt, class_prompt, training_file_url, max_train_steps, model_name, images_list: List[str], controller_type, model_type_list):
32 | # from ui_components.methods.common_methods import convert_image_list_to_file_list
33 |
34 | # ml_client = get_ml_client()
35 | # app_setting = DataRepo().get_app_setting_from_uuid()
36 |
37 | # response = ml_client.dreambooth_training(
38 | # training_file_url, instance_prompt, class_prompt, max_train_steps, model_name, controller_type, len(images_list), app_setting.replicate_username)
39 | # training_status = response["status"]
40 |
41 | # model_id = response["id"]
42 | # if training_status == "queued":
43 | # file_list = convert_image_list_to_file_list(images_list)
44 | # file_uuid_list = [file.uuid for file in file_list]
45 | # file_uuid_list = json.dumps(file_uuid_list)
46 |
47 | # model_data = {
48 | # "name": model_name,
49 | # "user_id": get_current_user_uuid(),
50 | # "replicate_model_id": model_id,
51 | # "replicate_url": response["model"],
52 | # "diffusers_url": "",
53 | # "category": AIModelCategory.DREAMBOOTH.value,
54 | # "training_image_list": file_uuid_list,
55 | # "keyword": instance_prompt,
56 | # "custom_trained": True,
57 | # "model_type": model_type_list
58 | # }
59 |
60 | # data_repo = DataRepo()
61 | # data_repo.create_ai_model(**model_data)
62 |
63 | # return "Success - Training Started. Please wait 10-15 minutes for the model to be trained."
64 | # else:
65 | # return "Failed"
66 |
67 | # NOTE: code not in use
68 | # INFO: images_list passed here are converted to internal files after they are used for training
69 | # def train_lora_model(training_file_url, type_of_task, resolution, model_name, images_list, model_type_list):
70 | # from ui_components.methods.common_methods import convert_image_list_to_file_list
71 |
72 | # data_repo = DataRepo()
73 | # ml_client = get_ml_client()
74 | # output = ml_client.predict_model_output(ML_MODEL.clones_lora_training, instance_data=training_file_url,
75 | # task=type_of_task, resolution=int(resolution))
76 |
77 | # file_list = convert_image_list_to_file_list(images_list)
78 | # file_uuid_list = [file.uuid for file in file_list]
79 | # file_uuid_list = json.dumps(file_uuid_list)
80 | # model_data = {
81 | # "name": model_name,
82 | # "user_id": get_current_user_uuid(),
83 | # "replicate_url": output,
84 | # "diffusers_url": "",
85 | # "category": AIModelCategory.LORA.value,
86 | # "training_image_list": file_uuid_list,
87 | # "custom_trained": True,
88 | # "model_type": model_type_list
89 | # }
90 |
91 | # data_repo.create_ai_model(**model_data)
92 | # return f"Successfully trained - the model '{model_name}' is now available for use!"
93 |
--------------------------------------------------------------------------------
/shared/file_upload/s3.py:
--------------------------------------------------------------------------------
1 | import hashlib
2 | import mimetypes
3 | from urllib.parse import urlparse
4 | import boto3
5 | import uuid
6 | import os
7 | import shutil
8 |
9 | import requests
10 | from shared.constants import AWS_ACCESS_KEY, AWS_S3_BUCKET, AWS_S3_REGION, AWS_SECRET_KEY
11 | from shared.logging.logging import AppLogger
12 | from shared.logging.constants import LoggingPayload, LoggingType
13 | from ui_components.methods.file_methods import convert_file_to_base64
14 |
# module-level logger shared by all upload helpers below
logger = AppLogger()

# TODO: fix proper paths for file uploads
18 |
19 |
def upload_file(file_location, aws_access_key, aws_secret_key, bucket=AWS_S3_BUCKET):
    """Upload a local file to S3 under input_images/ with a random UUID name.

    Returns:
        The public object URL on success, or ``None`` when the upload fails
        (the error is logged, not raised).
    """
    url = None
    ext = os.path.splitext(file_location)[1]
    unique_file_name = str(uuid.uuid4()) + ext
    try:
        s3_file = f"input_images/{unique_file_name}"
        s3 = boto3.client("s3", aws_access_key_id=aws_access_key, aws_secret_access_key=aws_secret_key)
        s3.upload_file(file_location, bucket, s3_file)
        s3.put_object_acl(ACL="public-read", Bucket=bucket, Key=s3_file)
        url = f"https://s3.amazonaws.com/{bucket}/{s3_file}"
    except Exception as e:
        # include the underlying error so upload failures are diagnosable
        # (previously the exception detail was silently dropped)
        logger.log(LoggingType.ERROR, f"unable to upload to s3: {e}")

    return url
34 |
35 |
def upload_file_from_obj(file, file_extension, bucket=AWS_S3_BUCKET):
    """Upload an open file object to S3 (under the test/ folder) and return its public URL."""
    aws_access_key, aws_secret_key = AWS_ACCESS_KEY, AWS_SECRET_KEY
    folder = "test/"
    filename = str(uuid.uuid4()) + file_extension
    file.seek(0)

    # hackish sol, will fix later: known image extensions get a fixed PNG
    # content type, everything else is treated as a raw byte stream
    if file_extension in [".png", ".jpg"]:
        content_type = "image/png"
    else:
        content_type = "application/octet-stream"

    put_kwargs = {"Body": file, "Bucket": bucket, "Key": folder + filename, "ACL": "public-read"}
    if content_type:
        put_kwargs["ContentType"] = content_type

    s3_client = boto3.client("s3", aws_access_key_id=aws_access_key, aws_secret_access_key=aws_secret_key)
    s3_client.put_object(**put_kwargs)
    return "https://s3-{0}.amazonaws.com/{1}/{2}".format(AWS_S3_REGION, bucket, folder + filename)
55 |
56 |
def upload_file_from_bytes(file_bytes, aws_access_key, aws_secret_key, key=None, bucket=AWS_S3_BUCKET):
    """Upload raw bytes to S3 as a public PNG object and return its URL.

    When *key* is falsy, a random object name under test/ is generated.
    """
    object_key = key if key else "test/" + str(uuid.uuid4()) + ".png"

    put_kwargs = {
        "Body": file_bytes,
        "Bucket": bucket,
        "Key": object_key,
        "ACL": "public-read",
        "ContentType": "image/png",
    }

    s3_client = boto3.client("s3", aws_access_key_id=aws_access_key, aws_secret_access_key=aws_secret_key)
    s3_client.put_object(**put_kwargs)
    return "https://s3-{0}.amazonaws.com/{1}/{2}".format(AWS_S3_REGION, bucket, object_key)
70 |
71 |
# TODO: fix the structuring of s3 for different users and different files
def generate_s3_url(
    image_url,
    aws_access_key,
    aws_secret_key,
    bucket=AWS_S3_BUCKET,
    file_ext="png",
    folder="posts/",
    object_name=None,
):
    """Download the image at *image_url* and re-upload it to S3, returning the new URL.

    Args:
        image_url: publicly reachable URL of the source image.
        object_name: optional S3 object name; a random UUID-based name with
            *file_ext* is generated when omitted.

    Raises:
        Exception: when the source image cannot be downloaded.
    """
    # BUG FIX: object_name was referenced before assignment (it was never a
    # parameter), making every call raise NameError; it is now an optional arg.
    if object_name is None:
        object_name = str(uuid.uuid4()) + "." + file_ext

    response = requests.get(image_url)
    if response.status_code != 200:
        raise Exception("Failed to download the image from the given URL")

    file = response.content

    content_type = mimetypes.guess_type(object_name)[0]
    data = {"Body": file, "Bucket": bucket, "Key": folder + object_name, "ACL": "public-read"}
    # fall back to PNG when mimetypes cannot identify the extension
    data["ContentType"] = content_type if content_type else "image/png"

    s3_client = boto3.client(
        service_name="s3",
        region_name=AWS_S3_REGION,
        aws_access_key_id=aws_access_key,
        aws_secret_access_key=aws_secret_key,
    )
    resp = s3_client.put_object(**data)

    # BUG FIX: the filename was appended twice to the disposition string;
    # build it once, only when there is an extension to advertise.
    extension = os.path.splitext(object_name)[1]
    disposition = "inline"
    if extension:
        disposition += f'; filename="{object_name}"'
    resp["ResponseMetadata"]["HTTPHeaders"]["Content-Disposition"] = disposition

    # BUG FIX: use the bucket actually written to (previously the global
    # AWS_S3_BUCKET was hard-coded even when a different bucket was passed)
    object_url = "https://s3-{0}.amazonaws.com/{1}/{2}".format(AWS_S3_REGION, bucket, folder + object_name)
    return object_url
110 |
111 |
def is_s3_image_url(url):
    """Return True when *url* points at an AWS S3 endpoint.

    Recognizes path-style hosts ("s3.amazonaws.com", "s3-us-west-2.amazonaws.com")
    and virtual-hosted hosts ("bucket.s3.us-east-1.amazonaws.com").
    """
    parsed_url = urlparse(url)
    netloc = parsed_url.netloc.lower()

    if not netloc.endswith(".amazonaws.com"):
        return False

    # Split the host prefix on both "." and "-" and look for an "s3" label.
    # BUG FIX: the previous check (first dash-separated token == "s3") missed
    # "s3.amazonaws.com" — the exact host this module's upload_file() produces —
    # and all virtual-hosted-style bucket URLs.
    labels = netloc[: -len(".amazonaws.com")].replace("-", ".").split(".")
    return "s3" in labels
122 |
--------------------------------------------------------------------------------
/utils/ml_processor/gpu/utils.py:
--------------------------------------------------------------------------------
1 | import importlib
2 | import os
3 | import sys
4 | import subprocess
5 | import time
6 | from git import Repo
7 | from shared.constants import COMFY_BASE_PATH
8 | from shared.logging.constants import LoggingType
9 | from shared.logging.logging import app_logger
10 | from utils.common_utils import get_toml_config
11 | from utils.constants import TomlConfig
12 |
13 |
14 | COMFY_RUNNER_PATH = "./comfy_runner"
15 |
16 |
def predict_gpu_output(
    workflow: str,
    file_path_list=None,
    output_node=None,
    extra_model_list=None,
    ignore_model_list=None,
    log_tag=None,
) -> str:
    """Run a ComfyUI workflow locally through comfy_runner and return output file paths.

    Args:
        workflow: serialized workflow passed straight to ComfyRunner.predict.
        file_path_list: input files the workflow references.
        output_node: node id(s) whose output should be collected.
        extra_model_list / ignore_model_list: model download overrides.
        log_tag: used as the comfy client id for log correlation.
    """
    # avoid the shared-mutable-default pitfall; omitted args behave as before
    file_path_list = [] if file_path_list is None else file_path_list
    extra_model_list = [] if extra_model_list is None else extra_model_list
    ignore_model_list = [] if ignore_model_list is None else ignore_model_list

    # hackish sol.. waiting for comfy repo to be cloned
    while not is_comfy_runner_present():
        time.sleep(2)

    # make the cloned repo importable (os.getcwd() is already a str)
    sys.path.append(os.getcwd() + COMFY_RUNNER_PATH[1:])
    from comfy_runner.inf import ComfyRunner

    # pinned versions for comfy, custom nodes and python packages
    comfy_commit_hash = get_toml_config(TomlConfig.COMFY_VERSION.value)["commit_hash"]
    node_commit_dict = get_toml_config(TomlConfig.NODE_VERSION.value)
    pkg_versions = get_toml_config(TomlConfig.PKG_VERSIONS.value)
    extra_node_urls = []
    for k, v in node_commit_dict.items():
        v["title"] = k
        extra_node_urls.append(v)

    comfy_runner = ComfyRunner()
    output = comfy_runner.predict(
        workflow_input=workflow,
        file_path_list=file_path_list,
        stop_server_after_completion=False,
        output_node_ids=output_node,
        extra_models_list=extra_model_list,
        ignore_model_list=ignore_model_list,
        client_id=log_tag,
        extra_node_urls=extra_node_urls,
        comfy_commit_hash=comfy_commit_hash,
        strict_dep_list=pkg_versions,
    )

    return output["file_paths"]  # ignoring text output for now {"file_paths": [], "text_content": []}
59 |
60 |
def is_comfy_runner_present():
    """Report whether the comfy_runner checkout already exists on disk."""
    # hackish sol, will fix later
    return os.path.exists(COMFY_RUNNER_PATH)
63 |
64 |
# TODO: convert comfy_runner into a package for easy import
def setup_comfy_runner():
    """Clone the comfy-runner repo if missing, install its requirements and sync its .env."""
    if is_comfy_runner_present():
        update_comfy_runner_env()
        return

    app_logger.log(LoggingType.INFO, "cloning comfy runner")
    comfy_repo_url = "https://github.com/piyushK52/comfy-runner"
    Repo.clone_from(comfy_repo_url, COMFY_RUNNER_PATH[2:], single_branch=True, branch="main")

    # install dependencies into the *current* interpreter's environment;
    # a bare "pip" on PATH could belong to a different Python installation
    subprocess.run(
        [sys.executable, "-m", "pip", "install", "-r", COMFY_RUNNER_PATH + "/requirements.txt"],
        check=True,
    )
    update_comfy_runner_env()
78 |
79 |
def find_comfy_runner():
    """Walk up from this file's directory to the repo root (marked by .git)
    and return the path of its comfy_runner checkout, or None when absent."""
    current_path = os.path.dirname(os.path.abspath(__file__))

    while True:
        if os.path.exists(os.path.join(current_path, ".git")):
            candidate = os.path.join(current_path, "comfy_runner")
            # comfy_runner must live directly inside the project root
            return candidate if os.path.exists(candidate) else None

        parent = os.path.dirname(current_path)
        if parent == current_path:  # reached the filesystem root
            return None
        current_path = parent
97 |
def update_comfy_runner_env():
    """Sync comfy_runner/.env with the COMFY_MODELS_BASE_PATH environment variable.

    When a custom base path is configured, writes
    ``COMFY_RUNNER_MODELS_BASE_PATH=<path>`` into the checkout's .env (and
    verifies the write); otherwise the .env file is truncated.
    """
    comfy_base_path = os.getenv("COMFY_MODELS_BASE_PATH", "ComfyUI")
    comfy_runner_path = find_comfy_runner()
    if not comfy_runner_path:
        print("comfy_runner not present")
        return

    env_file_path = os.path.join(comfy_runner_path, ".env")
    if comfy_base_path != "ComfyUI":
        try:
            os.makedirs(os.path.dirname(env_file_path), exist_ok=True)

            expected = f"COMFY_RUNNER_MODELS_BASE_PATH={comfy_base_path}"
            with open(env_file_path, "w", encoding="utf-8") as f:
                f.write(expected)

            # read back to verify the write actually landed on disk
            with open(env_file_path, "r", encoding="utf-8") as f:
                written_content = f.read()

            if written_content != expected:
                print(f"File was written, but content doesn't match. Expected: {comfy_base_path}, Got: {written_content}")

        except IOError as e:
            print(f"IOError occurred while writing to {env_file_path}: {e}")
        except Exception as e:
            print(f"An unexpected error occurred: {e}")

    else:
        # BUG FIX: previously wrote to the hard-coded relative path
        # "comfy_runner/.env", which breaks (or writes to the wrong file)
        # whenever the working directory is not the repo root
        with open(env_file_path, "w", encoding="utf-8") as f:
            f.write("")
127 |
--------------------------------------------------------------------------------
/ui_components/widgets/image_zoom_widgets.py:
--------------------------------------------------------------------------------
1 | import uuid
2 | import streamlit as st
3 | from backend.models import InternalFileObject
4 | from shared.constants import InternalFileType
5 | from ui_components.constants import WorkflowStageType
6 | from ui_components.methods.common_methods import add_image_variant, promote_image_variant
7 | from ui_components.methods.file_methods import save_or_host_file
8 |
9 | from utils.data_repo.data_repo import DataRepo
10 |
11 |
def zoom_inputs(position="in-frame", horizontal=False, shot_uuid=None):
    """Render the zoom / rotate / shift / flip input widgets.

    Args:
        position: unused at present; kept for caller compatibility.
        horizontal: lay the six widgets out as two-column rows instead of stacked.
        shot_uuid: unused at present; kept for caller compatibility.
    """
    if horizontal:
        col1, col2 = st.columns(2)
        col3, col4 = st.columns(2)
        col5, col6 = st.columns(2)
    else:
        col1 = col2 = col3 = col4 = col5 = col6 = st

    # seed the widget defaults once per session
    if "zoom_level_input_default" not in st.session_state:
        st.session_state["zoom_level_input_default"] = 100
        st.session_state["rotation_angle_input_default"] = 0
        st.session_state["x_shift_default"] = 0
        st.session_state["y_shift_default"] = 0
        st.session_state["flip_vertically_default"] = False
        st.session_state["flip_horizontally_default"] = False

    col1.number_input(
        "Zoom in/out:",
        min_value=10,
        max_value=1000,
        step=5,
        key="zoom_level_input",
        value=st.session_state["zoom_level_input_default"],
    )

    col2.number_input(
        "Rotate:",
        min_value=-360,
        max_value=360,
        step=5,
        key="rotation_angle_input",
        value=st.session_state["rotation_angle_input_default"],
    )

    col3.number_input(
        "Shift left/right:",
        min_value=-1000,
        max_value=1000,
        step=5,
        key="x_shift",
        value=st.session_state["x_shift_default"],
    )

    col4.number_input(
        "Shift down/up:",
        min_value=-1000,
        max_value=1000,
        step=5,
        key="y_shift",
        value=st.session_state["y_shift_default"],
    )

    # BUG FIX: the checkbox defaults were passed through str(), so the string
    # "False" (truthy) pre-checked the boxes even when the default was False.
    col5.checkbox(
        "Flip vertically ↕️", key="flip_vertically", value=st.session_state["flip_vertically_default"]
    )

    col6.checkbox(
        "Flip horizontally ↔️",
        key="flip_horizontally",
        value=st.session_state["flip_horizontally_default"],
    )
74 |
75 |
def save_zoomed_image(image, timing_uuid, stage, promote=False):
    """Persist a zoomed frame image for *timing_uuid* at the given workflow stage.

    For the SOURCE stage the new file becomes the current frame's source image;
    for the STYLED stage it is added as an image variant (and optionally
    promoted to primary). Any other stage value is a no-op.
    """
    data_repo = DataRepo()
    timing = data_repo.get_timing_from_uuid(timing_uuid)
    project_uuid = timing.shot.project.uuid

    file_name = str(uuid.uuid4()) + ".png"

    def _save_image_file():
        # shared by both stages: save (or host) the image and register the file
        # record — previously this logic was duplicated verbatim in each branch
        save_location = f"videos/{project_uuid}/assets/frames/modified/{file_name}"
        hosted_url = save_or_host_file(image, save_location)
        file_data = {"name": file_name, "type": InternalFileType.IMAGE.value, "project_id": project_uuid}

        if hosted_url:
            file_data.update({"hosted_url": hosted_url})
        else:
            file_data.update({"local_path": save_location})

        return data_repo.create_file(**file_data)

    if stage == WorkflowStageType.SOURCE.value:
        source_image: InternalFileObject = _save_image_file()
        data_repo.update_specific_timing(
            st.session_state["current_frame_uuid"], source_image_id=source_image.uuid, update_in_place=True
        )
    elif stage == WorkflowStageType.STYLED.value:
        styled_image: InternalFileObject = _save_image_file()
        number_of_image_variants = add_image_variant(styled_image.uuid, timing_uuid)
        if promote:
            promote_image_variant(timing_uuid, number_of_image_variants - 1)
111 |
112 |
def reset_zoom_element():
    """Reset every zoom-related widget (and its stored default) to its initial value."""
    initial_values = {
        "zoom_level_input": 100,
        "rotation_angle_input": 0,
        "x_shift": 0,
        "y_shift": 0,
        "flip_vertically": False,
        "flip_horizontally": False,
    }
    for widget_key, initial in initial_values.items():
        st.session_state[f"{widget_key}_default"] = initial
        st.session_state[widget_key] = initial
126 |
--------------------------------------------------------------------------------
/app.py:
--------------------------------------------------------------------------------
import time
import streamlit as st
from moviepy.editor import *
import subprocess
import os
import sys
import django
import sentry_sdk

from utils.data_repo.api_repo import APIRepo
10 |
11 | os.environ.setdefault("DJANGO_SETTINGS_MODULE", "django_settings")
12 | django.setup()
13 | st.session_state["django_init"] = True
14 |
15 | from shared.constants import (
16 | GPU_INFERENCE_ENABLED_KEY,
17 | HOSTED_BACKGROUND_RUNNER_MODE,
18 | OFFLINE_MODE,
19 | SERVER,
20 | ServerType,
21 | ConfigManager,
22 | )
23 |
24 | from shared.logging.logging import AppLogger
25 | from ui_components.components.user_login_page import user_login_ui
26 | from ui_components.models import InternalUserObject
27 | from utils.app_update_utils import apply_updates, check_and_pull_changes, load_save_checkpoint
28 | from utils.common_decorators import update_refresh_lock
29 | from utils.common_utils import is_process_active, refresh_process_active
30 | from utils.state_refresh import refresh_app
31 |
32 | from utils.constants import (
33 | REFRESH_PROCESS_PORT,
34 | RUNNER_PROCESS_NAME,
35 | )
36 | from streamlit_server_state import server_state_lock
37 | from utils.refresh_target import SAVE_STATE
38 | from banodoco_settings import project_init
39 | from utils.data_repo.data_repo import DataRepo
40 |
41 |
# shared config; the runner port is read once here and re-pulled in start_runner
config_manager = ConfigManager()
RUNNER_PROCESS_PORT = config_manager.get("runner_process_port")
44 |
45 |
def start_runner():
    """Start the background runner process (banodoco_runner.py) if it is not already active.

    Skipped on non-development servers unless hosted background-runner mode is
    enabled. Reruns the Streamlit app if the runner came up on a new port.
    """
    if SERVER != ServerType.DEVELOPMENT.value and HOSTED_BACKGROUND_RUNNER_MODE in [False, "False"]:
        return

    with server_state_lock["runner"]:
        app_logger = AppLogger()

        if not is_process_active(RUNNER_PROCESS_NAME, RUNNER_PROCESS_PORT):
            app_logger.info("Starting runner")
            # BUG FIX: sys.executable was used without an explicit "import sys"
            # in this module (now imported at the top of the file)
            python_executable = sys.executable
            _ = subprocess.Popen([python_executable, "banodoco_runner.py"])

            # poll briefly until the runner reports itself active
            max_retries = 6
            while not is_process_active(RUNNER_PROCESS_NAME, RUNNER_PROCESS_PORT) and max_retries:
                time.sleep(0.1)
                max_retries -= 1

            # refreshing the app if the runner port has changed
            old_port = RUNNER_PROCESS_PORT
            new_port = config_manager.get("runner_process_port", fresh_pull=True)
            if old_port != new_port:
                st.rerun()
70 |
71 |
def start_project_refresh():
    """Launch the auto-refresh helper process when running a development server."""
    if SERVER != ServerType.DEVELOPMENT.value:
        return

    with server_state_lock["refresh_app"]:
        app_logger = AppLogger()

        if refresh_process_active(REFRESH_PROCESS_PORT):
            # refresh process already running
            return

        _ = subprocess.Popen([sys.executable, "auto_refresh.py"])
        retries_left = 6
        while retries_left and not refresh_process_active(REFRESH_PROCESS_PORT):
            time.sleep(1)
            retries_left -= 1
        app_logger.info("Auto refresh enabled")
90 |
91 |
def main():
    """App entry point.

    Configures the Streamlit page, performs first-load bootstrap (update
    check / apply), starts the background runner and auto-refresh
    processes, then routes the user to the login, main setup, or welcome
    UI depending on app settings and inference mode.
    """
    st.set_page_config(page_title="Dough", layout="wide", page_icon="🎨")
    # page-level markdown injection (body appears empty here)
    st.markdown(
        r"""

        """,
        unsafe_allow_html=True,
    )
    update_refresh_lock(False)

    # if it's the first time,
    if "first_load" not in st.session_state:
        # only check/apply updates while the runner process isn't up yet
        if not is_process_active(RUNNER_PROCESS_NAME, RUNNER_PROCESS_PORT):
            if not load_save_checkpoint():
                check_and_pull_changes()  # enabling auto updates only for local version
            else:
                apply_updates()
                refresh_app()
        st.session_state["first_load"] = True

    start_runner()
    start_project_refresh()
    project_init()

    # imported lazily so the bootstrap above runs before the UI modules load
    from ui_components.setup import setup_app_ui
    from ui_components.components.welcome_page import welcome_page

    data_repo = DataRepo()
    api_repo = APIRepo()
    config_manager = ConfigManager()
    gpu_enabled = config_manager.get(GPU_INFERENCE_ENABLED_KEY, False)

    app_setting = data_repo.get_app_setting_from_uuid()
    # NOTE(review): welcome_state == 2 presumably means onboarding complete — confirm
    if app_setting.welcome_state == 2:
        # api/online inference mode
        if not gpu_enabled:
            if not api_repo.is_user_logged_in():
                # user not logged in
                user_login_ui()
            else:
                setup_app_ui()
        else:
            # gpu/offline inference mode
            setup_app_ui()
    else:
        welcome_page()

    st.session_state["maintain_state"] = False
144 |
145 |
if __name__ == "__main__":
    try:
        main()
    except Exception as e:
        # Report the crash to Sentry, then re-raise. A bare `raise` keeps the
        # original traceback intact (previously `raise e` added a frame).
        sentry_sdk.capture_exception(e)
        raise
152 |
--------------------------------------------------------------------------------
/ui_components/widgets/image_carousal.py:
--------------------------------------------------------------------------------
1 | import streamlit as st
2 | from st_clickable_images import clickable_images
3 |
4 | from ui_components.constants import WorkflowStageType
5 | from ui_components.models import InternalShotObject
6 | from utils.data_repo.data_repo import DataRepo
7 | from utils.state_refresh import refresh_app
8 |
9 |
def display_image(timing_uuid, stage=None, clickable=False):
    """Render the image of a single frame timing.

    Args:
        timing_uuid: uuid of the frame timing to display.
        stage: WorkflowStageType value selecting the styled or source image.
        clickable: when True, render via st_clickable_images; clicking the
            image makes that frame the current one (updates session state
            and reruns the app). When False, render a plain st.image.
    """
    data_repo = DataRepo()
    timing = data_repo.get_timing_from_uuid(timing_uuid)

    # BUG FIX: previously timing.aux_frame_index was read *before* this
    # None-check, so a failed lookup crashed instead of showing the message.
    if not timing:
        st.write("no images")
        return

    timing_idx = timing.aux_frame_index

    # BUG FIX: default to "" so an unrecognised stage falls through to the
    # "not found" error instead of raising NameError on an unbound `image`.
    image = ""
    if stage == WorkflowStageType.STYLED.value:
        image = timing.primary_image_location
    elif stage == WorkflowStageType.SOURCE.value:
        image = timing.source_image.location if timing.source_image else ""

    if image != "":
        if clickable is True:
            if "counter" not in st.session_state:
                st.session_state["counter"] = 0

            import base64

            if image.startswith("http"):
                st.write("")
            else:
                # local files must be inlined as a data URI for clickable_images
                # (no longer shadows `image` with the file handle)
                with open(image, "rb") as image_file:
                    st.write("")
                    encoded = base64.b64encode(image_file.read()).decode()
                image = f"data:image/jpeg;base64,{encoded}"

            st.session_state[f"{timing_idx}_{stage}_clicked"] = clickable_images(
                [image],
                div_style={"display": "flex", "justify-content": "center", "flex-wrap": "wrap"},
                img_style={"max-width": "100%", "height": "auto", "cursor": "pointer"},
                key=f"{timing_idx}_{stage}_image_{st.session_state['counter']}",
            )

            # clickable_images returns the clicked index; 0 means our single image
            if st.session_state[f"{timing_idx}_{stage}_clicked"] == 0:
                timing_details = data_repo.get_timing_list_from_shot(timing.shot.uuid)
                st.session_state["current_frame_uuid"] = timing_details[timing_idx].uuid
                st.session_state["current_frame_index"] = timing_idx + 1
                st.session_state["prev_frame_index"] = timing_idx + 1
                st.session_state["frame_styling_view_type"] = "Individual"
                # bump counter so the widget gets a fresh key on the next run
                st.session_state["counter"] += 1
                refresh_app()

        elif clickable is False:
            st.image(image, use_column_width=True)
    else:
        st.error(f"No {stage} image found for #{timing_idx + 1}")
61 |
62 |
def carousal_of_images_element(shot_uuid, stage=WorkflowStageType.STYLED.value):
    """Show a five-wide strip of frames centred on the current frame.

    Neighbours at offsets -2..+2 around the current frame are rendered as
    clickable images; the centre frame is highlighted with st.success and
    neighbours with st.info. Columns for out-of-range offsets stay empty.
    """
    data_repo = DataRepo()
    shot: InternalShotObject = data_repo.get_shot_from_uuid(shot_uuid)
    timing_list = shot.timing_list

    columns = st.columns([1, 1, 1, 1, 1])

    current_frame_uuid = st.session_state["current_frame_uuid"]
    current_timing = data_repo.get_timing_from_uuid(current_frame_uuid)
    current_idx = current_timing.aux_frame_index

    # One column per offset. Previously this was five near-identical blocks,
    # and the centre column re-fetched the timing it already had in hand.
    for column, offset in zip(columns, (-2, -1, 0, 1, 2)):
        with column:
            if offset == 0:
                display_image(current_timing.uuid, stage=stage, clickable=True)
                st.success(f"#{current_idx + 1}")
                continue

            target_idx = current_idx + offset
            # Bounds kept from the original: >= 0 on the left,
            # <= len(timing_list) on the right. NOTE(review): "<=" looks
            # off-by-one for a 0-based index, but the None-guard below keeps
            # it harmless — confirm against get_timing_from_frame_number.
            if target_idx < 0 or target_idx > len(timing_list):
                continue

            neighbour = data_repo.get_timing_from_frame_number(shot_uuid, target_idx)
            if neighbour:
                display_image(neighbour.uuid, stage=stage, clickable=True)
                st.info(f"#{neighbour.aux_frame_index + 1}")
114 |
--------------------------------------------------------------------------------
/ui_components/widgets/add_key_frame_element.py:
--------------------------------------------------------------------------------
1 | import time
2 | from typing import Union
3 | import streamlit as st
4 | from shared.constants import AnimationStyleType
5 | from ui_components.models import InternalFileObject, InternalFrameTimingObject
6 | from utils.common_decorators import update_refresh_lock
7 | from utils.state_refresh import refresh_app
8 | from utils.data_repo.data_repo import DataRepo
9 | from ui_components.methods.file_methods import generate_pil_image, save_or_host_file
10 | from ui_components.methods.common_methods import add_image_variant, save_new_image
11 | from PIL import Image
12 |
13 |
def add_key_frame_section(shot_uuid):
    """Render the "upload images → add key frame(s)" section for a shot.

    Uploaded files are saved under the shot's frame directory (in upload
    order) and appended as key frames, with a progress bar; the app is
    rerun afterwards.

    Returns:
        (selected_image_location, inherit_styling_settings) — the location
        of the last saved upload ("" when nothing was added), and the
        styling-inheritance choice. No styling control is rendered here, so
        the second value is always None — TODO(review): confirm callers
        handle None.
    """
    data_repo = DataRepo()
    shot = data_repo.get_shot_from_uuid(shot_uuid)
    selected_image_location = ""

    uploaded_images = st.file_uploader(
        "Upload images:",
        type=["png", "jpg", "jpeg", "webp"],
        key=f"uploaded_image_{shot_uuid}",
        help="You can upload multiple images",
        accept_multiple_files=True,
    )

    if st.button(
        f"Add key frame(s)",
        use_container_width=True,
        key=f"add_key_frame_btn_{shot_uuid}",
        type="primary",
    ):
        update_refresh_lock(True)
        if uploaded_images:
            progress_bar = st.progress(0)
            # keep the user's upload order (no sorting)
            for i, uploaded_image in enumerate(uploaded_images):
                image = Image.open(uploaded_image)
                file_location = f"videos/{shot.uuid}/assets/frames/base/{uploaded_image.name}"
                # save_or_host_file returns the hosted url, or None for local saves
                selected_image_location = save_or_host_file(image, file_location) or file_location
                add_key_frame(selected_image_location, shot_uuid, refresh_state=False)
                progress_bar.progress((i + 1) / len(uploaded_images))
        else:
            st.error("Please generate new images or upload them")
            time.sleep(0.7)
        update_refresh_lock(False)
        refresh_app()

    # BUG FIX: previously this function returned None implicitly while
    # add_key_frame_element unpacks two values from it (TypeError).
    return selected_image_location, None
49 |
50 |
def display_selected_key_frame(selected_image_location, apply_zoom_effects):
    """Preview the starting image at *selected_image_location*.

    Shows an error when no location is given. ``apply_zoom_effects`` is
    currently unused (the zoom preview path is disabled).

    Returns:
        The loaded PIL image, or None when there was nothing to display.
    """
    if not selected_image_location:
        st.error("No Starting Image Found")
        return None

    preview = generate_pil_image(selected_image_location)
    st.info("Starting Image:")
    st.image(preview)
    return preview
65 |
66 |
def add_key_frame_element(shot_uuid):
    """Two-column layout: the add-key-frame controls on the left, a preview
    of the selected starting image on the right."""
    left_col, right_col = st.columns(2)

    with left_col:
        selected_image_location, inherit_styling_settings = add_key_frame_section(shot_uuid)
    with right_col:
        selected_image = display_selected_key_frame(selected_image_location, False)

    return selected_image, inherit_styling_settings
75 |
76 |
def add_key_frame(
    selected_image: Union[Image.Image, InternalFileObject],
    shot_uuid,
    target_frame_position=None,
    refresh_state=True,
    update_cur_frame_idx=True,
):
    """Insert a key frame into a shot.

    Either a PIL image or an InternalFileObject can be passed: an existing
    file object is reused directly, a PIL image is first saved as a new
    project file.

    Args:
        selected_image: image to attach as both source and primary image.
        shot_uuid: shot to add the frame to.
        target_frame_position: 0-based insert position; defaults to the end
            of the shot and is clamped to the current frame count.
        refresh_state: rerun the app after creating the frame.
        update_cur_frame_idx: move the session's current-frame pointers to
            the newly created frame.

    Returns:
        The created InternalFrameTimingObject.
    """
    data_repo = DataRepo()
    timing_list = data_repo.get_timing_list_from_shot(shot_uuid)

    # default to appending at the end, clamped to the list length
    # (previously written as the redundant `len(x) if len(x) > 0 else 0`)
    if target_frame_position is None:
        target_frame_position = len(timing_list)
    target_aux_frame_index = min(len(timing_list), target_frame_position)

    if isinstance(selected_image, InternalFileObject):
        saved_image = selected_image
    else:
        shot = data_repo.get_shot_from_uuid(shot_uuid)
        saved_image = save_new_image(selected_image, shot.project.uuid)

    timing_data = {
        "shot_id": shot_uuid,
        "animation_style": AnimationStyleType.CREATIVE_INTERPOLATION.value,
        "aux_frame_index": target_aux_frame_index,
        "source_image_id": saved_image.uuid,
        "primary_image_id": saved_image.uuid,
    }
    new_timing: InternalFrameTimingObject = data_repo.create_timing(**timing_data)

    if update_cur_frame_idx:
        timing_list = data_repo.get_timing_list_from_shot(shot_uuid)
        # point the session at the newly added keyframe
        if len(timing_list) <= 1:
            st.session_state["current_frame_index"] = 1
            st.session_state["current_frame_uuid"] = timing_list[0].uuid
        else:
            new_index = min(len(timing_list), target_aux_frame_index + 1)
            st.session_state["prev_frame_index"] = new_index
            st.session_state["current_frame_index"] = new_index
            st.session_state["current_frame_uuid"] = timing_list[new_index - 1].uuid
        # (debug print of the session state removed)

    if refresh_state:
        refresh_app()

    return new_timing
131 |
--------------------------------------------------------------------------------
/ui_components/widgets/frame_selector.py:
--------------------------------------------------------------------------------
1 | from typing import List
2 | import streamlit as st
3 | from utils.data_repo.data_repo import DataRepo
4 | from utils import st_memory
5 | from ui_components.methods.common_methods import add_new_shot
6 | from utils.state_refresh import refresh_app
7 |
8 |
def frame_selector_widget(show_frame_selector=True):
    """Sidebar widget for picking the active shot and (optionally) frame.

    Renders a shot selectbox (with a "**Create New Shot**" entry), keeps the
    session-state pointers (shot_uuid, shot_name, current_shot_index,
    current_frame_index, ...) in sync — triggering an app rerun on changes —
    and, when ``show_frame_selector`` is True and frames exist, a frame
    selectbox plus a "Jump to shot view" button.

    Returns:
        The selected frame label ("" when none / no frames).
        NOTE(review): ``frame_selection`` is never assigned when
        ``show_frame_selector`` is False — callers appear to always pass
        True; confirm.
    """
    data_repo = DataRepo()
    timing_list = data_repo.get_timing_list_from_shot(st.session_state["shot_uuid"])
    shot = data_repo.get_shot_from_uuid(st.session_state["shot_uuid"])
    shot_list = data_repo.get_shot_list(shot.project.uuid)
    # NOTE(review): falls back to the float 1.0 when the shot has no frames —
    # presumably just a sentinel for the comparison below; confirm intent.
    len_timing_list = len(timing_list) if len(timing_list) > 0 else 1.0
    project_uuid = shot.project.uuid

    if "prev_shot_index" not in st.session_state:
        st.session_state["prev_shot_index"] = shot.shot_idx
    if "shot_name" not in st.session_state:
        st.session_state["shot_name"] = shot.name
    shot1, shot2 = st.columns([1, 1])
    with shot1:
        shot_names = [s.name for s in shot_list]
        shot_names.append("**Create New Shot**")
        current_shot_name = st.selectbox(
            "Shot:", shot_names, key="current_shot_sidebar_selector", index=shot_names.index(shot.name)
        )
        # a real shot was picked: sync the session and rerun on change
        if current_shot_name != "**Create New Shot**":
            if current_shot_name != st.session_state["shot_name"]:
                st.session_state["shot_name"] = current_shot_name
                refresh_app()

        if current_shot_name == "**Create New Shot**":
            new_shot_name = st.text_input(
                "New shot name:", max_chars=40, key=f"shot_name_sidebar_{st.session_state['shot_name']}"
            )
            if st.button("Create new shot", key=f"create_new_shot_{st.session_state['shot_name']}"):
                new_shot = add_new_shot(project_uuid, name=new_shot_name)
                st.session_state["shot_name"] = new_shot_name
                st.session_state["shot_uuid"] = new_shot.uuid
                refresh_app()

    # find shot index based on shot name
    st.session_state["current_shot_index"] = shot_names.index(st.session_state["shot_name"]) + 1

    # the session's shot name diverged from the loaded shot: re-point and rerun
    if st.session_state["shot_name"] != shot.name:
        st.session_state["shot_uuid"] = shot_list[shot_names.index(st.session_state["shot_name"])].uuid
        refresh_app()

    if not ("current_shot_index" in st.session_state and st.session_state["current_shot_index"]):
        st.session_state["current_shot_index"] = shot_names.index(st.session_state["shot_name"]) + 1
        update_current_shot_index(st.session_state["current_shot_index"])

    # clamp stale indices after frames/shots were deleted elsewhere
    if st.session_state["page"] == "Key Frames":
        if st.session_state["current_frame_index"] > len_timing_list:
            update_current_frame_index(len_timing_list)

    elif st.session_state["page"] == "Shots":
        if st.session_state["current_shot_index"] > len(shot_list):
            update_current_shot_index(len(shot_list))

    if show_frame_selector:
        if len(timing_list):
            if "prev_frame_index" not in st.session_state or st.session_state["prev_frame_index"] > len(
                timing_list
            ):
                st.session_state["prev_frame_index"] = 1

            # Create a list of frames with a blank value as the first item
            frame_list = [""] + [f"{i+1}" for i in range(len(timing_list))]
            with shot2:
                frame_selection = st_memory.selectbox(
                    "Frame:", frame_list, key="current_frame_sidebar_selector"
                )

            # only trigger the frame number extraction and current frame index update if a non-empty value is selected
            if frame_selection != "":
                if st.button("Jump to shot view", use_container_width=True):
                    # reset the selector to the blank entry and rerun
                    st.session_state["current_frame_sidebar_selector"] = 0
                    refresh_app()

                st.session_state["current_frame_index"] = int(frame_selection.split(" ")[-1])
                update_current_frame_index(st.session_state["current_frame_index"])
        else:
            frame_selection = ""
            with shot2:
                st.write("")
                st.error("No frames present")

    return frame_selection
91 |
92 |
def update_current_frame_index(index):
    """Point the session at frame *index* (1-based) of the current shot,
    reset the frame-styling view when the index actually changed, and rerun."""
    repo = DataRepo()
    frames = repo.get_timing_list_from_shot(st.session_state["shot_uuid"])
    target_uuid = frames[index - 1].uuid

    st.session_state["current_frame_uuid"] = target_uuid
    index_changed = (
        st.session_state["prev_frame_index"] != index
        or st.session_state["current_frame_index"] != index
    )
    if index_changed:
        st.session_state.update(
            {
                "prev_frame_index": index,
                "current_frame_index": index,
                "current_frame_uuid": target_uuid,
                "reset_canvas": True,
                "frame_styling_view_type_index": 0,
                "frame_styling_view_type": "Generate View",
            }
        )

    refresh_app()
106 |
107 |
def update_current_shot_index(index):
    """Point the session at shot *index* (1-based) of the current project,
    reset the frame-styling view when the index actually changed, and rerun."""
    repo = DataRepo()
    shots = repo.get_shot_list(st.session_state["project_uuid"])
    target_uuid = shots[index - 1].uuid

    st.session_state["shot_uuid"] = target_uuid
    index_changed = (
        st.session_state["prev_shot_index"] != index
        or st.session_state["current_shot_index"] != index
    )
    if index_changed:
        st.session_state.update(
            {
                "current_shot_index": index,
                "prev_shot_index": index,
                "shot_uuid": target_uuid,
                "reset_canvas": True,
                "frame_styling_view_type_index": 0,
                "frame_styling_view_type": "Individual View",
            }
        )

    refresh_app()
121 |
--------------------------------------------------------------------------------
/backend/serializers/dto.py:
--------------------------------------------------------------------------------
1 | import json
2 | from rest_framework import serializers
3 |
4 | from backend.models import (
5 | AIModel,
6 | AppSetting,
7 | BackupTiming,
8 | InferenceLog,
9 | InternalFileObject,
10 | Project,
11 | Setting,
12 | Shot,
13 | Timing,
14 | User,
15 | )
16 |
17 |
class UserDto(serializers.ModelSerializer):
    """Flat User representation: identity fields plus the credit balance."""

    class Meta:
        model = User
        fields = ("uuid", "name", "email", "type", "total_credits")
22 |
23 |
class ProjectDto(serializers.ModelSerializer):
    """Project representation; the owning user is flattened to its uuid."""

    # exposes obj.user.uuid as "user_uuid" instead of nesting the whole user
    user_uuid = serializers.SerializerMethodField()

    class Meta:
        model = Project
        fields = ("uuid", "name", "user_uuid", "created_on", "temp_file_list", "meta_data")

    def get_user_uuid(self, obj):
        # resolver for the user_uuid SerializerMethodField above
        return obj.user.uuid
33 |
34 |
class AIModelDto(serializers.ModelSerializer):
    """AIModel representation; the owning user is flattened to its uuid."""

    # exposes obj.user.uuid as "user_uuid" instead of nesting the whole user
    user_uuid = serializers.SerializerMethodField()

    class Meta:
        model = AIModel
        fields = (
            "uuid",
            "name",
            "user_uuid",
            "custom_trained",
            "version",
            "replicate_model_id",
            "replicate_url",
            "diffusers_url",
            "category",
            "training_image_list",
            "keyword",
            "created_on",
        )

    def get_user_uuid(self, obj):
        # resolver for the user_uuid SerializerMethodField above
        return obj.user.uuid
57 |
58 |
class InferenceLogDto(serializers.ModelSerializer):
    """Inference log entry with its related project and model fully nested."""

    project = ProjectDto()
    model = AIModelDto()

    class Meta:
        model = InferenceLog
        fields = (
            "uuid",
            "project",
            "model",
            "input_params",
            "output_details",
            "total_inference_time",
            "credits_used",
            "created_on",
            "updated_on",
            "status",
            "model_name",
            "generation_source",
            "generation_tag",
        )
80 |
81 |
class InternalFileDto(serializers.ModelSerializer):
    """File record with its project and originating inference log nested."""

    project = ProjectDto()  # TODO: pass this as context to speed up the api
    inference_log = InferenceLogDto()

    class Meta:
        model = InternalFileObject
        fields = (
            "uuid",
            "name",
            "local_path",
            "type",
            "hosted_url",
            "created_on",
            "inference_log",
            "project",
            "tag",
            "shot_uuid",
        )
100 |
101 |
class BasicShotDto(serializers.ModelSerializer):
    """Lightweight Shot representation (no timing/clip expansion) with the
    parent project nested."""

    project = ProjectDto()

    class Meta:
        model = Shot
        # "project" was previously listed twice in this tuple; duplicate removed
        fields = (
            "uuid",
            "name",
            "project",
            "desc",
            "shot_idx",
            "duration",
            "meta_data",
        )
117 |
118 |
class TimingDto(serializers.ModelSerializer):
    """Frame timing with its model, related images, and parent shot nested."""

    model = AIModelDto()
    source_image = InternalFileDto()
    mask = InternalFileDto()
    canny_image = InternalFileDto()
    primary_image = InternalFileDto()
    shot = BasicShotDto()

    class Meta:
        model = Timing
        fields = (
            "uuid",
            "model",
            "source_image",
            "mask",
            "canny_image",
            "primary_image",
            "alternative_images",
            "notes",
            "aux_frame_index",
            "created_on",
            "shot",
        )
142 |
143 |
class AppSettingDto(serializers.ModelSerializer):
    """Per-user app settings with the owning user nested."""

    user = UserDto()

    class Meta:
        model = AppSetting
        fields = ("uuid", "user", "previous_project", "replicate_username", "welcome_state", "created_on")
150 |
151 |
class SettingDto(serializers.ModelSerializer):
    """Per-project settings with the project, default model, and audio file nested."""

    project = ProjectDto()
    default_model = AIModelDto()
    audio = InternalFileDto()

    class Meta:
        model = Setting
        fields = (
            "uuid",
            "project",
            "default_model",
            "audio",
            "input_type",
            "width",
            "height",
            "created_on",
        )
169 |
170 |
class BackupDto(serializers.ModelSerializer):
    """Full backup record including the data_dump payload; project nested."""

    project = ProjectDto()

    class Meta:
        model = BackupTiming
        fields = ("name", "project", "note", "data_dump", "created_on")
177 |
178 |
class BackupListDto(serializers.ModelSerializer):
    """Backup list entry: metadata only (no data_dump payload); project nested."""

    project = ProjectDto()

    class Meta:
        model = BackupTiming
        fields = ("uuid", "project", "name", "note", "created_on")
185 |
186 |
class ShotDto(serializers.ModelSerializer):
    """Full Shot representation: nested project and main clip, plus the
    shot's ordered timings and interpolated clips resolved via method fields."""

    timing_list = serializers.SerializerMethodField()
    interpolated_clip_list = serializers.SerializerMethodField()
    main_clip = InternalFileDto()
    project = ProjectDto()

    class Meta:
        model = Shot
        fields = (
            "uuid",
            "name",
            "desc",
            "shot_idx",
            "project",
            "duration",
            "meta_data",
            "timing_list",
            "interpolated_clip_list",
            "main_clip",
        )

    def get_timing_list(self, obj):
        # Candidate timings arrive through serializer context; keep only this
        # shot's frames and order them by their position inside the shot.
        candidates = self.context.get("timing_list", [])
        serialized = [
            TimingDto(candidate).data
            for candidate in candidates
            if str(candidate.shot.uuid) == str(obj.uuid)
        ]
        return sorted(serialized, key=lambda entry: entry["aux_frame_index"])

    def get_interpolated_clip_list(self, obj):
        # interpolated_clip_list is stored as a JSON-encoded list of file uuids
        raw = obj.interpolated_clip_list
        uuid_list = json.loads(raw) if raw else []
        files = InternalFileObject.objects.filter(uuid__in=uuid_list, is_disabled=False).all()
        return [InternalFileDto(f).data for f in files]
220 |
--------------------------------------------------------------------------------