├── __init__.py
├── .python-version
├── llm_modules
│   ├── __init__.py
│   └── llm_zephyr.py
├── tts_modules
│   ├── __init__.py
│   └── tts_coqui.py
├── t2i_modules
│   ├── __init__.py
│   ├── t2i_sdxl.py
│   └── t2i_juggernaut.py
├── t2v_modules
│   ├── __init__.py
│   ├── t2v_zeroscope.py
│   ├── t2v_wan.py
│   └── t2v_ltx.py
├── i2v_modules
│   ├── __init__.py
│   ├── i2v_slideshow.py
│   ├── i2v_ltx.py
│   ├── i2v_svd.py
│   └── i2v_wan.py
├── check_versions.py
├── .gitignore
├── pyproject.toml
├── todo.todo
├── module_discovery.py
├── config_manager.py
├── package_code.sh
├── mp3_to_wav_converter.py
├── system.py
├── utils.py
├── base_modules.py
├── __requirements.txt
├── ui_task_executor.py
├── task_executor.py
├── README.md
├── video_assembly.py
├── project_manager.py
└── app.py

/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/.python-version:
--------------------------------------------------------------------------------
1 | 3.12
--------------------------------------------------------------------------------
/llm_modules/__init__.py:
--------------------------------------------------------------------------------
1 | from .llm_zephyr import ZephyrLLM
--------------------------------------------------------------------------------
/tts_modules/__init__.py:
--------------------------------------------------------------------------------
1 | from .tts_coqui import CoquiTTSModule
--------------------------------------------------------------------------------
/t2i_modules/__init__.py:
--------------------------------------------------------------------------------
1 | from .t2i_juggernaut import JuggernautT2I
2 | from .t2i_sdxl import SdxlT2I
--------------------------------------------------------------------------------
/t2v_modules/__init__.py:
--------------------------------------------------------------------------------
1 | from .t2v_zeroscope import ZeroscopeT2V
2 | from .t2v_wan import WanT2V
3 | from .t2v_ltx import LtxT2V
--------------------------------------------------------------------------------
/i2v_modules/__init__.py:
--------------------------------------------------------------------------------
1 | from .i2v_ltx import LtxI2V
2 | from .i2v_svd import SvdI2V
3 | from .i2v_slideshow import SlideshowI2V
4 | from .i2v_wan import WanI2V
--------------------------------------------------------------------------------
/check_versions.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import streamlit
3 | import sys
4 | 
5 | print(f"Python version: {sys.version}")
6 | print(f"PyTorch version: {torch.__version__}")
7 | print(f"Streamlit version: {streamlit.__version__}")
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Python
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 | *.so
6 | .Python
7 | build/
8 | develop-eggs/
9 | dist/
10 | downloads/
11 | eggs/
12 | .eggs/
13 | lib/
14 | lib64/
15 | parts/
16 | sdist/
17 | var/
18 | wheels/
19 | *.egg-info/
20 | .installed.cfg
21 | *.egg
22 | 
23 | # Environment
24 | .env
25 | .venv
26 | env/
27 | venv/
28 | ENV/
29 | 
30 | # IDE
31 | .idea/
32 | .vscode/
33 | *.swp
34 | *.swo
35 | 
36 | # Project specific
37 | prompt_helpers/
38 | instagram_content/
39 | output/
40 | my_reels/
41 | *.mp4
42 | *.wav
43 | *.png
44 | 
45 | project.json
46 | system.json
47 | 
48 | modular_reels_output/
49 | 
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [project]
2 | name = "influencer"
3 | version = "0.1.0"
4 | description = "Add your description here"
5 | readme = "README.md"
6 | requires-python = ">=3.12"
7 | dependencies = [
8 |     "accelerate>=1.7.0",
9 |     "coqui-tts>=0.24.3",
10 |     "diffusers>=0.33.1",
11 |     "ftfy>=6.3.1",
12 |     "gputil>=1.4.0",
13 |     "hf-transfer>=0.1.9",
14 |     "hf-xet>=1.1.1",
15 |     "huggingface-hub[cli]>=0.31.2",
16 |     "jupyter>=1.1.1",
17 |     "llvmlite>=0.44.0",
18 |     "moviepy>=2.1.2",
19 |     "mutagen>=1.47.0",
20 |     "nicegui>=2.19.0",
21 |     "numpy>=1.26.4",
22 |     "psutil>=7.0.0",
23 |     "pydantic>=2.11.5",
24 |     "pydub>=0.25.1",
25 |     "sentencepiece>=0.2.0",
26 |     "streamlit>=1.45.1",
27 |     "torch>=2.7.1",
28 |     "torchaudio>=2.7.1",
29 |     "torchvision>=0.22.1",
30 |     "transformers>=4.51.3",
31 | ]
32 | 
--------------------------------------------------------------------------------
/todo.todo:
--------------------------------------------------------------------------------
1 | 
2 | 
3 | 
4 | narration to show or not: font, position and size selection in project
5 | try shot and script 3 times for format, then fall back to the fallback system.
6 | record user generation form in project status: what the user actually selected to create this video
7 | store time taken in each segment: image, video, audio, assembly and final
8 | show better project info in list (status, video duration etc.) and play from there.
9 | save vram and system ram in system.json and use it to filter models
10 | show module name from its config in the dropdown
11 | prompt fine-tuner button to use the llm with a different prompt (or a separate module?)
12 | are we saving the reference sound (and its path) in the project somewhere, and showing it back on the dashboard page?
13 | is the logic separate from the UI so we can change the UI part any time without changing the logic?
14 | 
15 | ✓ on dashboard keep expander of characters closed and show character names in the expander title
16 | ✗ audio tts emotion parameters (not in narration)
17 | ✗ tts language selection
18 | ✓ scene delete/add facility
19 | ✓ add all characters to all scenes as default
--------------------------------------------------------------------------------
/module_discovery.py:
--------------------------------------------------------------------------------
1 | # In module_discovery.py
2 | 
3 | import os
4 | import importlib
5 | import inspect
6 | from typing import Dict, List, Any, Type
7 | # Correctly import from base_modules
8 | from base_modules import BaseLLM, BaseTTS, BaseT2I, BaseI2V, BaseT2V, ModuleCapabilities
9 | 
10 | MODULE_TYPES = {
11 |     "llm": {"base_class": BaseLLM, "path": "llm_modules"},
12 |     "tts": {"base_class": BaseTTS, "path": "tts_modules"},
13 |     "t2i": {"base_class": BaseT2I, "path": "t2i_modules"},
14 |     "i2v": {"base_class": BaseI2V, "path": "i2v_modules"},
15 |     "t2v": {"base_class": BaseT2V, "path": "t2v_modules"},
16 | }
17 | 
18 | def discover_modules() -> Dict[str, List[Dict[str, Any]]]:
19 |     """
20 |     Scans module directories, imports classes, and gets their capabilities.
21 | """ 22 | discovered_modules = {key: [] for key in MODULE_TYPES} 23 | 24 | for module_type, info in MODULE_TYPES.items(): 25 | module_path = info["path"] 26 | base_class = info["base_class"] 27 | 28 | if not os.path.exists(module_path): 29 | continue 30 | 31 | for filename in os.listdir(module_path): 32 | if filename.endswith(".py") and not filename.startswith("__"): 33 | module_name = f"{module_path}.{filename[:-3]}" 34 | try: 35 | module = importlib.import_module(module_name) 36 | for attribute_name in dir(module): 37 | attribute = getattr(module, attribute_name) 38 | if inspect.isclass(attribute) and issubclass(attribute, base_class) and attribute is not base_class: 39 | caps = attribute.get_capabilities() 40 | discovered_modules[module_type].append({ 41 | "name": attribute.__name__, 42 | "path": f"{module_name}.{attribute.__name__}", 43 | "caps": caps, 44 | "class": attribute 45 | }) 46 | except Exception as e: 47 | print(f"Warning: Could not load module {module_name}. Error: {e}") 48 | 49 | return discovered_modules -------------------------------------------------------------------------------- /config_manager.py: -------------------------------------------------------------------------------- 1 | # In config_manager.py 2 | import os 3 | import torch 4 | import gc 5 | from pydantic import BaseModel, Field 6 | from typing import Dict, Tuple, Literal 7 | 8 | DEVICE = "cuda" if torch.cuda.is_available() else "cpu" 9 | if DEVICE == "cuda": os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" 10 | 11 | class ContentConfig(BaseModel): 12 | """Configuration for overall content generation parameters, using Pydantic.""" 13 | # --- User-defined settings from the UI --- 14 | target_video_length_hint: float = 20.0 15 | min_scenes: int = 2 16 | max_scenes: int = 5 17 | aspect_ratio_format: Literal["Portrait", "Landscape"] = "Landscape" 18 | use_svd_flow: bool = True 19 | add_narration_text_to_video: bool = True 20 | seed: int = -1 # <--- NEW: -1 means random seed 21 | 22 | # --- NEW: To be filled from UI selections --- 23 | module_selections: Dict[str, str] = Field(default_factory=dict) 24 | language: str = "en" 25 | 26 | # --- Static project-wide settings --- 27 | fps: int = 24 28 | output_dir: str = "modular_reels_output" 29 | font_for_subtitles: str = "Arial" 30 | 31 | # --- DYNAMIC settings, to be populated by the TaskExecutor --- 32 | model_max_video_shot_duration: float = 2.0 # A safe default 33 | generation_resolution: Tuple[int, int] = (1024, 1024) # A safe default 34 | 35 | @property 36 | def max_scene_narration_duration_hint(self) -> float: 37 | if self.max_scenes > 0 and self.min_scenes > 0: 38 | avg_scenes = (self.min_scenes + self.max_scenes) / 2 39 | return round(self.target_video_length_hint / avg_scenes, 1) 40 | return 6.0 41 | 42 | @property 43 | def final_output_resolution(self) -> Tuple[int, int]: 44 | if self.aspect_ratio_format == "Landscape": 45 | return (1920, 1080) 46 | return (1080, 1920) 47 | 48 | def __init__(self, **data): 49 | super().__init__(**data) 50 | os.makedirs(self.output_dir, exist_ok=True) 51 | 52 | 53 | def clear_vram_globally(*items_to_del): 54 | print(f"Attempting to clear VRAM. 
Received {len(items_to_del)} items to delete.") 55 | for item in items_to_del: 56 | if hasattr(item, 'to') and hasattr(item, 'dtype') and item.dtype != torch.float16: 57 | try: 58 | item.to('cpu') 59 | except Exception as e: 60 | print(f"Could not move item of type {type(item)} to CPU: {e}") 61 | del items_to_del 62 | gc.collect() 63 | if torch.cuda.is_available(): 64 | torch.cuda.empty_cache() 65 | torch.cuda.ipc_collect() 66 | print("VRAM clearing attempt finished.") -------------------------------------------------------------------------------- /package_code.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Default output file name 4 | OUTPUT_FILE="combined_code.txt" 5 | 6 | # --- Configuration: Define what to include --- 7 | # Add your source directories and specific files here. 8 | # Paths should be relative to where you run the script from. 9 | # Directories will be scanned recursively. 10 | # Use spaces to separate items. 11 | FILES_TO_INCLUDE=( 12 | "README.md" 13 | "base_modules.py" 14 | "utils.py" 15 | "module_discovery.py" 16 | "app.py" 17 | "project_manager.py" 18 | "task_executor.py" 19 | "ui_task_executor.py" 20 | "config_manager.py" 21 | "video_assembly.py" 22 | "llm_modules/" 23 | "tts_modules/" 24 | "t2i_modules/" 25 | "i2v_modules/" 26 | "t2v_modules/" 27 | ) 28 | 29 | # --- End of Configuration --- 30 | 31 | # Check if an output file name was provided as an argument 32 | if [ "$1" ]; then 33 | OUTPUT_FILE="$1" 34 | echo "Using custom output file name: $OUTPUT_FILE" 35 | fi 36 | 37 | # Clear the output file to start fresh 38 | > "$OUTPUT_FILE" 39 | echo "Cleared old content from $OUTPUT_FILE." 40 | 41 | # A function to process and append a file to the output 42 | process_file() { 43 | local file_path=$1 44 | echo "Processing: $file_path" 45 | 46 | # Write the header with the relative file path 47 | echo "==== $file_path ====" >> "$OUTPUT_FILE" 48 | 49 | # Append the content of the file 50 | cat "$file_path" >> "$OUTPUT_FILE" 51 | 52 | # Add multiple newlines at the end for better separation 53 | echo -e "\n\n\n" >> "$OUTPUT_FILE" 54 | } 55 | 56 | # Loop through the configured list of files and directories 57 | for item in "${FILES_TO_INCLUDE[@]}"; do 58 | if [ -f "$item" ]; then 59 | # If it's a single file, process it directly 60 | process_file "$item" 61 | elif [ -d "$item" ]; then 62 | # If it's a directory, find all relevant files inside it 63 | # - The `find` command is powerful. 64 | # - It searches for items of type 'f' (file). 65 | # - It ignores paths containing '__pycache__', '.git', '.vscode', etc. 66 | # - It only includes files ending in '.py' or other specified extensions. 67 | find "$item" -type f \( -name "*.py" -o -name "*.sh" \) \ 68 | -not -path "*/__pycache__/*" \ 69 | -not -path "*/.git/*" \ 70 | -not -path "*/.venv/*" \ 71 | -not -path "*/.vscode/*" \ 72 | | sort | while read -r file; do 73 | process_file "$file" 74 | done 75 | else 76 | echo "Warning: Item '$item' not found. Skipping." 77 | fi 78 | done 79 | 80 | echo "=========================================" 81 | echo "✅ All done!" 
82 | echo "Combined code saved to: $OUTPUT_FILE" 83 | echo "=========================================" -------------------------------------------------------------------------------- /tts_modules/tts_coqui.py: -------------------------------------------------------------------------------- 1 | # tts_modules/tts_coqui.py 2 | import os 3 | import torch 4 | import numpy as np 5 | from typing import Tuple, Optional 6 | from TTS.api import TTS as CoquiTTS 7 | from moviepy import AudioFileClip 8 | from scipy.io import wavfile 9 | 10 | from base_modules import BaseTTS, BaseModuleConfig, ModuleCapabilities 11 | from config_manager import DEVICE, clear_vram_globally 12 | 13 | class CoquiTTSConfig(BaseModuleConfig): 14 | model_id: str = "tts_models/multilingual/multi-dataset/xtts_v2" 15 | 16 | class CoquiTTSModule(BaseTTS): 17 | Config = CoquiTTSConfig 18 | 19 | @classmethod 20 | def get_capabilities(cls) -> ModuleCapabilities: 21 | return ModuleCapabilities( 22 | title="XTTS, Multi-Language, Documentary Style", 23 | vram_gb_min=2.0, # XTTS is relatively lightweight 24 | ram_gb_min=8.0, 25 | supported_tts_languages=["en", "es", "fr", "de", "it", "pt", "pl", "tr", "ru", "nl", "cs", "ar", "zh-cn", "ja", "hu", "ko", "hi"] 26 | ) 27 | 28 | def _load_model(self): 29 | if self.model is None: 30 | print(f"Loading TTS model: {self.config.model_id}...") 31 | self.model = CoquiTTS(model_name=self.config.model_id, progress_bar=True).to(DEVICE) 32 | print("TTS model loaded.") 33 | 34 | def clear_vram(self): 35 | print("Clearing TTS VRAM...") 36 | if self.model is not None: 37 | clear_vram_globally(self.model) 38 | self.model = None 39 | print("TTS VRAM cleared.") 40 | 41 | def generate_audio( 42 | self, text: str, output_dir: str, scene_idx: int, language: str, speaker_wav: Optional[str] = None 43 | ) -> Tuple[str, float]: 44 | self._load_model() 45 | 46 | print(f"Generating audio in {language} for scene {scene_idx}: \"{text[:50]}...\"") 47 | output_path = os.path.join(output_dir, f"scene_{scene_idx}_audio.wav") 48 | 49 | tts_kwargs = {"language": language, "file_path": output_path} 50 | 51 | if "xtts" in self.config.model_id.lower(): 52 | if speaker_wav and os.path.exists(speaker_wav): 53 | tts_kwargs["speaker_wav"] = speaker_wav 54 | else: 55 | if speaker_wav: print(f"Warning: Speaker WAV {speaker_wav} not found. XTTS using default voice.") 56 | 57 | self.model.tts_to_file(text, **tts_kwargs) 58 | 59 | duration = 0.0 60 | try: 61 | if os.path.exists(output_path) and os.path.getsize(output_path) > 0: 62 | with AudioFileClip(output_path) as audio_clip: 63 | duration = audio_clip.duration + 0.1 # Small buffer 64 | else: raise ValueError("Audio file not generated or is empty.") 65 | except Exception as e: 66 | print(f"Error getting duration for {output_path}: {e}. Creating fallback.") 67 | samplerate = 22050 68 | wavfile.write(output_path, samplerate, np.zeros(int(0.1 * samplerate), dtype=np.int16)) 69 | duration = 0.1 70 | 71 | print(f"Actual audio duration for scene {scene_idx}: {duration:.2f}s") 72 | return output_path, duration -------------------------------------------------------------------------------- /mp3_to_wav_converter.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """ 3 | MP3 to WAV Converter 4 | Converts all MP3 files in the Downloads folder to WAV format. 
5 | """ 6 | 7 | import os 8 | import sys 9 | from pathlib import Path 10 | from pydub import AudioSegment 11 | 12 | def convert_mp3_to_wav(downloads_folder="~/Downloads", output_folder=None): 13 | """ 14 | Convert all MP3 files in the Downloads folder to WAV format. 15 | 16 | Args: 17 | downloads_folder (str): Path to the Downloads folder 18 | output_folder (str): Path to output folder (defaults to same as input) 19 | """ 20 | # Expand the tilde to full path 21 | downloads_path = Path(downloads_folder).expanduser() 22 | 23 | if output_folder is None: 24 | output_path = downloads_path 25 | else: 26 | output_path = Path(output_folder).expanduser() 27 | 28 | # Create output directory if it doesn't exist 29 | output_path.mkdir(parents=True, exist_ok=True) 30 | 31 | # Find all MP3 files 32 | mp3_files = list(downloads_path.glob("*.mp3")) 33 | 34 | if not mp3_files: 35 | print("No MP3 files found in the Downloads folder.") 36 | return 37 | 38 | print(f"Found {len(mp3_files)} MP3 file(s) to convert:") 39 | for mp3_file in mp3_files: 40 | print(f" - {mp3_file.name}") 41 | 42 | print("\nStarting conversion...") 43 | 44 | converted_count = 0 45 | failed_count = 0 46 | 47 | for mp3_file in mp3_files: 48 | try: 49 | print(f"Converting: {mp3_file.name}") 50 | 51 | # Load the MP3 file 52 | audio = AudioSegment.from_mp3(str(mp3_file)) 53 | 54 | # Create output filename (replace .mp3 with .wav) 55 | wav_filename = mp3_file.stem + ".wav" 56 | wav_path = output_path / wav_filename 57 | 58 | # Export as WAV 59 | audio.export(str(wav_path), format="wav") 60 | 61 | print(f" ✓ Successfully converted to: {wav_filename}") 62 | converted_count += 1 63 | 64 | except Exception as e: 65 | print(f" ✗ Failed to convert {mp3_file.name}: {str(e)}") 66 | failed_count += 1 67 | 68 | print(f"\nConversion complete!") 69 | print(f"Successfully converted: {converted_count} files") 70 | if failed_count > 0: 71 | print(f"Failed conversions: {failed_count} files") 72 | 73 | def main(): 74 | """Main function to handle command line arguments.""" 75 | import argparse 76 | 77 | parser = argparse.ArgumentParser(description="Convert MP3 files to WAV format") 78 | parser.add_argument("--input", "-i", default="~/Downloads", 79 | help="Input folder containing MP3 files (default: ~/Downloads)") 80 | parser.add_argument("--output", "-o", 81 | help="Output folder for WAV files (default: same as input folder)") 82 | 83 | args = parser.parse_args() 84 | 85 | try: 86 | convert_mp3_to_wav(args.input, args.output) 87 | except KeyboardInterrupt: 88 | print("\nConversion interrupted by user.") 89 | sys.exit(1) 90 | except Exception as e: 91 | print(f"Error: {str(e)}") 92 | sys.exit(1) 93 | 94 | if __name__ == "__main__": 95 | main() 96 | -------------------------------------------------------------------------------- /system.py: -------------------------------------------------------------------------------- 1 | # In system.py 2 | import json 3 | import os 4 | from pydantic import BaseModel, Field 5 | from typing import Optional, Tuple 6 | 7 | # --- START OF MODIFICATION --- 8 | # Import necessary libraries for detection 9 | try: 10 | import psutil 11 | except ImportError: 12 | psutil = None 13 | 14 | try: 15 | import GPUtil 16 | except ImportError: 17 | GPUtil = None 18 | # --- END OF MODIFICATION --- 19 | 20 | 21 | SYSTEM_CONFIG_FILE = "system.json" 22 | 23 | class SystemConfig(BaseModel): 24 | """Stores the user's available system resources.""" 25 | vram_gb: float = Field(description="Available GPU VRAM in GB.") 26 | ram_gb: float = 
Field(description="Available system RAM in GB.") 27 | 28 | def save_system_config(vram_gb: float, ram_gb: float) -> None: 29 | """Saves the system resource configuration to system.json.""" 30 | config = SystemConfig(vram_gb=vram_gb, ram_gb=ram_gb) 31 | with open(SYSTEM_CONFIG_FILE, 'w') as f: 32 | f.write(config.model_dump_json(indent=4)) 33 | print(f"System configuration saved to {SYSTEM_CONFIG_FILE}") 34 | 35 | def load_system_config() -> Optional[SystemConfig]: 36 | """Loads the system resource configuration from system.json if it exists.""" 37 | if not os.path.exists(SYSTEM_CONFIG_FILE): 38 | return None 39 | try: 40 | with open(SYSTEM_CONFIG_FILE, 'r') as f: 41 | data = json.load(f) 42 | return SystemConfig(**data) 43 | except (json.JSONDecodeError, TypeError) as e: 44 | print(f"Error loading or parsing {SYSTEM_CONFIG_FILE}: {e}. Please re-enter details.") 45 | return None 46 | 47 | # --- START OF MODIFICATION --- 48 | def detect_system_specs() -> Tuple[float, float]: 49 | """ 50 | Attempts to detect available system RAM and GPU VRAM. 51 | Returns (vram_in_gb, ram_in_gb). 52 | Defaults to 8.0 for VRAM and 16.0 for RAM if detection fails. 53 | """ 54 | # Default values 55 | detected_ram_gb = 16.0 56 | detected_vram_gb = 8.0 57 | 58 | # 1. Detect System RAM 59 | if psutil: 60 | try: 61 | ram_bytes = psutil.virtual_memory().total 62 | # Round to the nearest whole number for a cleaner UI 63 | detected_ram_gb = round(ram_bytes / (1024**3)) 64 | print(f"Detected System RAM: {detected_ram_gb} GB") 65 | except Exception as e: 66 | print(f"Could not detect system RAM using psutil: {e}. Falling back to default.") 67 | else: 68 | print("psutil not installed. Cannot detect RAM. Falling back to default.") 69 | 70 | # 2. Detect GPU VRAM 71 | if GPUtil: 72 | try: 73 | gpus = GPUtil.getGPUs() 74 | if gpus: 75 | # Use the VRAM of the first detected GPU 76 | gpu = gpus[0] 77 | # VRAM is in MB, convert to GB and round to one decimal place 78 | detected_vram_gb = round(gpu.memoryTotal / 1024, 1) 79 | print(f"Detected GPU: {gpu.name} with {detected_vram_gb} GB VRAM") 80 | else: 81 | print("GPUtil found no GPUs. Falling back to default VRAM.") 82 | except Exception as e: 83 | print(f"Could not detect GPU VRAM using GPUtil: {e}. Falling back to default.") 84 | else: 85 | print("GPUtil not installed. Cannot detect VRAM. Falling back to default.") 86 | 87 | return detected_vram_gb, detected_ram_gb 88 | # --- END OF MODIFICATION --- -------------------------------------------------------------------------------- /i2v_modules/i2v_slideshow.py: -------------------------------------------------------------------------------- 1 | # In i2v_modules/i2v_slideshow.py 2 | from typing import Dict, Any, List, Optional, Union 3 | # --- THIS IS THE FIX: Importing ImageClip directly, matching the project's pattern --- 4 | from moviepy.video.VideoClip import ImageClip 5 | 6 | from base_modules import BaseI2V, BaseModuleConfig, ModuleCapabilities 7 | from config_manager import ContentConfig 8 | 9 | class SlideshowI2VConfig(BaseModuleConfig): 10 | # This module doesn't load a model, but the config is part of the contract. 11 | model_id: str = "moviepy_image_clip" 12 | 13 | class SlideshowI2V(BaseI2V): 14 | Config = SlideshowI2VConfig 15 | 16 | @classmethod 17 | def get_capabilities(cls) -> ModuleCapabilities: 18 | """ 19 | Defines the capabilities of this simple, non-AI module. 20 | It uses minimal resources and doesn't support AI-specific features. 
21 | """ 22 | return ModuleCapabilities( 23 | title="Slideshow (Static Image)", 24 | vram_gb_min=0.1, # Uses virtually no VRAM 25 | ram_gb_min=1.0, # Uses very little RAM 26 | supported_formats=["Portrait", "Landscape"], 27 | supports_ip_adapter=False, # Not an AI model 28 | supports_lora=False, # Not an AI model 29 | max_subjects=0, 30 | accepts_text_prompt=False, # Ignores prompts 31 | accepts_negative_prompt=False 32 | ) 33 | 34 | def get_model_capabilities(self) -> Dict[str, Any]: 35 | """ 36 | This module has no native resolution and can handle long durations. 37 | """ 38 | return { 39 | # It can handle any resolution, as it just wraps the image. 40 | "resolutions": {"Portrait": (1080, 1920), "Landscape": (1920, 1080)}, 41 | "max_shot_duration": 60.0 # Can be very long 42 | } 43 | 44 | def _load_pipeline(self): 45 | """No pipeline to load for this module.""" 46 | print("SlideshowI2V: No pipeline to load.") 47 | pass 48 | 49 | def clear_vram(self): 50 | """No VRAM to clear for this module.""" 51 | print("SlideshowI2V: No VRAM to clear.") 52 | pass 53 | 54 | def enhance_prompt(self, prompt: str, prompt_type: str = "visual") -> str: 55 | """This module ignores prompts, so no enhancement is needed.""" 56 | return prompt 57 | 58 | def generate_video_from_image(self, image_path: str, output_video_path: str, target_duration: float, content_config: ContentConfig, visual_prompt: str, motion_prompt: Optional[str], ip_adapter_image: Optional[Union[str, List[str]]] = None) -> str: 59 | """ 60 | Creates a video by holding a static image for the target duration. 61 | """ 62 | print(f"SlideshowI2V: Creating static video for {target_duration:.2f}s from {image_path}") 63 | 64 | video_clip = None 65 | try: 66 | # Create a video clip from the static image and set its duration. 67 | video_clip = ImageClip(image_path).with_duration(target_duration) 68 | 69 | # Use the correct syntax for write_videofile, matching video_assembly.py 70 | video_clip.write_videofile( 71 | output_video_path, 72 | fps=content_config.fps, 73 | codec="libx264", 74 | audio=False, # This is a visual-only shot 75 | threads=4, 76 | preset="medium", 77 | logger=None # Suppress verbose moviepy logs 78 | ) 79 | 80 | print(f"Slideshow video shot saved to {output_video_path}") 81 | return output_video_path 82 | 83 | except Exception as e: 84 | print(f"Error creating slideshow video: {e}") 85 | return "" # Return empty string on failure 86 | finally: 87 | # Ensure the clip resources are released 88 | if video_clip: 89 | video_clip.close() -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | # In utils.py 2 | import datetime 3 | import json 4 | import os 5 | from PIL import Image, ImageOps 6 | from moviepy import VideoFileClip 7 | import streamlit as st # Keep st for st.error 8 | 9 | def load_and_correct_image_orientation(image_source): 10 | """ 11 | Loads an image from a source (file path or uploaded file object) 12 | and corrects its orientation based on EXIF data. 13 | """ 14 | try: 15 | image = Image.open(image_source) 16 | # The magic is in exif_transpose 17 | corrected_image = ImageOps.exif_transpose(image) 18 | return corrected_image 19 | except Exception as e: 20 | # Using st.error here is okay for a simple app, but for true separation, 21 | # you might log the error and return None, letting the caller handle the UI. 22 | # For this project, this is fine. 
23 | st.error(f"Could not load or correct image: {e}") 24 | return None 25 | 26 | def list_projects(): 27 | """Lists all projects from the output directory with extended details including modules.""" 28 | projects = [] 29 | base_dir = "modular_reels_output" 30 | if not os.path.exists(base_dir): return [] 31 | for project_dir in os.listdir(base_dir): 32 | project_path = os.path.join(base_dir, project_dir) 33 | if os.path.isdir(project_path): 34 | project_file = os.path.join(project_path, "project.json") 35 | if os.path.exists(project_file): 36 | try: 37 | with open(project_file, 'r') as f: 38 | data = json.load(f) 39 | 40 | project_info = data.get('project_info', {}) 41 | # --- START OF MODIFICATION --- 42 | # Use title, but fall back to topic for old projects, then to dir name. 43 | title = project_info.get('title', project_info.get('topic', project_dir)) 44 | topic = project_info.get('topic', 'N/A') # Keep topic for potential detailed views 45 | # --- END OF MODIFICATION --- 46 | 47 | config = project_info.get('config', {}) 48 | final_video_info = data.get('final_video', {}) 49 | status = project_info.get('status', 'unknown') 50 | 51 | flow = "Image-to-Video" if config.get('use_svd_flow', True) else "Text-to-Video" 52 | 53 | final_video_path = None 54 | duration = 0.0 55 | if status == 'completed': 56 | stored_path = final_video_info.get('path') 57 | if stored_path and os.path.exists(stored_path): 58 | final_video_path = stored_path 59 | try: 60 | with VideoFileClip(final_video_path) as clip: 61 | duration = clip.duration 62 | except Exception as e: 63 | print(f"Could not read video duration for {final_video_path}: {e}") 64 | duration = 0.0 65 | 66 | modules = config.get('module_selections', {}) 67 | 68 | # --- START OF MODIFICATION --- 69 | projects.append({ 70 | 'name': project_dir, 71 | 'title': title, # Use the new title field 72 | 'topic': topic, # Keep topic field for completeness 73 | 'created_at': datetime.datetime.fromtimestamp(project_info.get('created_at', 0)), 74 | 'status': status, 75 | 'flow': flow, 76 | 'final_video_path': final_video_path, 77 | 'duration': duration, 78 | 'modules': modules, 79 | }) 80 | # --- END OF MODIFICATION --- 81 | except Exception as e: 82 | print(f"Error loading project {project_dir}: {e}") 83 | return sorted(projects, key=lambda p: p['created_at'], reverse=True) -------------------------------------------------------------------------------- /t2i_modules/t2i_sdxl.py: -------------------------------------------------------------------------------- 1 | # t2i_modules/t2i_sdxl.py 2 | import torch 3 | from typing import List, Optional, Dict, Any, Union 4 | from diffusers import StableDiffusionXLPipeline, DiffusionPipeline 5 | from diffusers.utils import load_image 6 | 7 | from base_modules import BaseT2I, BaseModuleConfig, ModuleCapabilities 8 | from config_manager import DEVICE, clear_vram_globally 9 | 10 | class SdxlT2IConfig(BaseModuleConfig): 11 | model_id: str = "stabilityai/stable-diffusion-xl-base-1.0" 12 | refiner_id: Optional[str] = "stabilityai/stable-diffusion-xl-refiner-1.0" 13 | num_inference_steps: int = 30 14 | guidance_scale: float = 7.5 15 | base_denoising_end: float = 0.8 16 | refiner_denoising_start: float = 0.8 17 | 18 | class SdxlT2I(BaseT2I): 19 | Config = SdxlT2IConfig 20 | 21 | def __init__(self, config: SdxlT2IConfig): 22 | super().__init__(config) 23 | self.refiner_pipe = None 24 | 25 | @classmethod 26 | def get_capabilities(cls) -> ModuleCapabilities: 27 | return ModuleCapabilities( 28 | title="SDXL + Refiner (High VRAM): No 
Subjects considered", 29 | vram_gb_min=10.0, # SDXL with refiner is heavy 30 | ram_gb_min=16.0, 31 | supported_formats=["Portrait", "Landscape"], 32 | # Even if we don't implement IP-Adapter here, we declare support 33 | # because the pipeline is capable. A more advanced version could add it. 34 | supports_ip_adapter=True, 35 | supports_lora=True, 36 | max_subjects=2, 37 | accepts_text_prompt=True, 38 | accepts_negative_prompt=True 39 | ) 40 | 41 | def get_model_capabilities(self) -> Dict[str, Any]: 42 | return { 43 | "resolutions": {"Portrait": (896, 1152), "Landscape": (1344, 768)}, 44 | "max_shot_duration": 3.0 45 | } 46 | 47 | def _load_pipeline(self): 48 | if self.pipe is None: 49 | print(f"Loading T2I pipeline (SDXL): {self.config.model_id}...") 50 | self.pipe = StableDiffusionXLPipeline.from_pretrained( 51 | self.config.model_id, torch_dtype=torch.float16, variant="fp16", use_safetensors=True 52 | ).to(DEVICE) 53 | print("SDXL Base pipeline loaded.") 54 | if self.config.refiner_id: 55 | print(f"Loading T2I Refiner pipeline: {self.config.refiner_id}...") 56 | self.refiner_pipe = DiffusionPipeline.from_pretrained( 57 | self.config.refiner_id, text_encoder_2=self.pipe.text_encoder_2, 58 | vae=self.pipe.vae, torch_dtype=torch.float16, 59 | use_safetensors=True, variant="fp16" 60 | ).to(DEVICE) 61 | print("SDXL Refiner pipeline loaded.") 62 | 63 | def clear_vram(self): 64 | print("Clearing T2I (SDXL) VRAM...") 65 | models = [m for m in [self.pipe, self.refiner_pipe] if m is not None] 66 | if models: clear_vram_globally(*models) 67 | self.pipe, self.refiner_pipe = None, None 68 | print("T2I (SDXL) VRAM cleared.") 69 | 70 | # --- START OF FIX: Updated method signature and implementation --- 71 | def generate_image(self, prompt: str, negative_prompt: str, output_path: str, width: int, height: int, ip_adapter_image: Optional[Union[str, List[str]]] = None, seed: int = -1) -> str: 72 | self._load_pipeline() 73 | 74 | if ip_adapter_image: 75 | print("Warning: SDXLT2I module received IP-Adapter image but does not currently implement its use.") 76 | 77 | generator = None 78 | if seed != -1: 79 | print(f"Using fixed seed for generation: {seed}") 80 | # Ensure the generator is on the same device as the pipeline 81 | generator = torch.Generator(device=self.pipe.device).manual_seed(seed) 82 | else: 83 | print("Using random seed for generation.") 84 | 85 | kwargs = { 86 | "prompt": prompt, 87 | "negative_prompt": negative_prompt, # Now passing this argument 88 | "width": width, "height": height, 89 | "num_inference_steps": self.config.num_inference_steps, 90 | "guidance_scale": self.config.guidance_scale, 91 | "generator": generator # Now passing the generator 92 | } 93 | if self.refiner_pipe: 94 | kwargs["output_type"] = "latent" 95 | kwargs["denoising_end"] = self.config.base_denoising_end 96 | 97 | image = self.pipe(**kwargs).images[0] 98 | 99 | if self.refiner_pipe: 100 | print("Refining image...") 101 | refiner_kwargs = { 102 | "prompt": prompt, 103 | "negative_prompt": negative_prompt, 104 | "image": image, 105 | "denoising_start": self.config.refiner_denoising_start, 106 | "num_inference_steps": self.config.num_inference_steps, 107 | "generator": generator 108 | } 109 | image = self.refiner_pipe(**refiner_kwargs).images[0] 110 | 111 | image.save(output_path) 112 | print(f"Image saved to {output_path}") 113 | return output_path 114 | # --- END OF FIX --- -------------------------------------------------------------------------------- /i2v_modules/i2v_ltx.py: 
-------------------------------------------------------------------------------- 1 | # i2v_modules/i2v_ltx.py 2 | import torch 3 | from typing import Dict, Any, List, Optional, Union 4 | from diffusers import LTXImageToVideoPipeline 5 | from diffusers.utils import export_to_video, load_image 6 | from PIL import Image 7 | 8 | from base_modules import BaseI2V, BaseModuleConfig, ModuleCapabilities 9 | from config_manager import DEVICE, clear_vram_globally, ContentConfig 10 | 11 | class LtxI2VConfig(BaseModuleConfig): 12 | model_id: str = "Lightricks/LTX-Video" 13 | num_inference_steps: int = 50 14 | guidance_scale: float = 7.5 15 | 16 | class LtxI2V(BaseI2V): 17 | Config = LtxI2VConfig 18 | 19 | @classmethod 20 | def get_capabilities(cls) -> ModuleCapabilities: 21 | return ModuleCapabilities( 22 | title="LTX, 8bit Load, Port/LandScape, 2 Sub, Take +/- Prompts, max 4 sec", 23 | vram_gb_min=8.0, 24 | ram_gb_min=12.0, 25 | supported_formats=["Portrait", "Landscape"], 26 | supports_ip_adapter=True, 27 | supports_lora=True, # Juggernaut is a fine-tune, can easily use LoRAs 28 | max_subjects=2, # Can handle one or two IP adapter images 29 | accepts_text_prompt=True, 30 | accepts_negative_prompt=True 31 | ) 32 | 33 | 34 | def get_model_capabilities(self) -> Dict[str, Any]: 35 | return { 36 | "resolutions": {"Portrait": (480, 704), "Landscape": (704, 480)}, 37 | "max_shot_duration": 4 38 | } 39 | 40 | def enhance_prompt(self, prompt: str, prompt_type: str = "visual") -> str: 41 | # SVD doesn't use text prompts, but this shows how you could add model-specific keywords. 42 | # For example, for a different model you might do: 43 | if prompt_type == "visual": 44 | return f"{prompt}, 8k, photorealistic, cinematic lighting" 45 | return prompt # Return original for SVD 46 | 47 | def _load_pipeline(self): 48 | if self.pipe is None: 49 | print(f"Loading I2V pipeline (LTX): {self.config.model_id}...") 50 | self.pipe = LTXImageToVideoPipeline.from_pretrained(self.config.model_id, torch_dtype=torch.bfloat16) 51 | self.pipe.enable_model_cpu_offload() 52 | print("I2V (LTX) pipeline loaded.") 53 | 54 | def clear_vram(self): 55 | print("Clearing I2V (LTX) VRAM...") 56 | if self.pipe is not None: clear_vram_globally(self.pipe) 57 | self.pipe = None 58 | print("I2V (LTX) VRAM cleared.") 59 | 60 | def _resize_and_pad(self, image: Image.Image, target_width: int, target_height: int) -> Image.Image: 61 | original_aspect = image.width / image.height; target_aspect = target_width / target_height 62 | if original_aspect > target_aspect: new_width, new_height = target_width, int(target_width / original_aspect) 63 | else: new_height, new_width = target_height, int(target_height * original_aspect) 64 | resized_image = image.resize((new_width, new_height), Image.LANCZOS) 65 | background = Image.new('RGB', (target_width, target_height), (0, 0, 0)) 66 | background.paste(resized_image, ((target_width - new_width) // 2, (target_height - new_height) // 2)) 67 | return background 68 | 69 | def generate_video_from_image(self, image_path: str, output_video_path: str, target_duration: float, content_config: ContentConfig, visual_prompt: str, motion_prompt: Optional[str], ip_adapter_image: Optional[Union[str, List[str]]] = None) -> str: 70 | self._load_pipeline() 71 | 72 | input_image = load_image(image_path) 73 | target_res = self.get_model_capabilities()["resolutions"] 74 | aspect_ratio = "Landscape" if input_image.width > input_image.height else "Portrait" 75 | target_width, target_height = target_res[aspect_ratio] 76 | prepared_image 
= self._resize_and_pad(input_image, target_width, target_height) 77 | 78 | num_frames = max(16, int(target_duration * content_config.fps)) 79 | full_prompt = f"{visual_prompt}, {motion_prompt}" if motion_prompt else visual_prompt 80 | 81 | # --- NEW LOGIC TO HANDLE ip_adapter_image --- 82 | # While LTX doesn't have a formal IP-Adapter, we can use the character 83 | # reference to guide the style by adding it to the prompt. 84 | if ip_adapter_image: 85 | print("LTX I2V: Using character reference to guide prompt style.") 86 | # For simplicity, we add a generic phrase. A more complex system could use an image-to-text model. 87 | full_prompt = f"in the style of the reference character, {full_prompt}" 88 | 89 | print(f"LTX I2V using prompt: {full_prompt}") 90 | 91 | video = self.pipe( 92 | prompt=full_prompt, image=prepared_image, width=target_width, height=target_height, 93 | num_frames=num_frames, num_inference_steps=self.config.num_inference_steps, 94 | guidance_scale=self.config.guidance_scale, 95 | negative_prompt="worst quality, inconsistent motion, blurry" 96 | ).frames[0] 97 | 98 | export_to_video(video, output_video_path, fps=content_config.fps) 99 | print(f"LTX video shot saved to {output_video_path}") 100 | return output_video_path -------------------------------------------------------------------------------- /t2v_modules/t2v_zeroscope.py: -------------------------------------------------------------------------------- 1 | # In t2v_modules/t2v_zeroscope.py 2 | import torch 3 | from typing import Dict, Any, List, Optional, Union 4 | from diffusers import DiffusionPipeline, DPMSolverMultistepScheduler 5 | from diffusers.utils import export_to_video 6 | 7 | from base_modules import BaseT2V, BaseModuleConfig, ModuleCapabilities 8 | from config_manager import DEVICE, clear_vram_globally 9 | 10 | class ZeroscopeT2VConfig(BaseModuleConfig): 11 | model_id: str = "cerspense/zeroscope_v2_576w" 12 | upscaler_model_id: str = "cerspense/zeroscope_v2_xl" 13 | 14 | num_inference_steps: int = 30 15 | guidance_scale: float = 9.0 16 | # --- START OF FIX: Add strength for the upscaling process --- 17 | upscaler_strength: float = 0.7 18 | # --- END OF FIX --- 19 | 20 | class ZeroscopeT2V(BaseT2V): 21 | Config = ZeroscopeT2VConfig 22 | 23 | @classmethod 24 | def get_capabilities(cls) -> ModuleCapabilities: 25 | return ModuleCapabilities( 26 | title="Zeroscope, Port/Landscape, No Subject, 2 sec", 27 | vram_gb_min=8.0, 28 | ram_gb_min=12.0, 29 | supported_formats=["Portrait", "Landscape"], 30 | supports_ip_adapter=False, # Zeroscope does not support IP-Adapter 31 | supports_lora=False, # Zeroscope does not support LoRA loading 32 | max_subjects=0, 33 | accepts_text_prompt=True, 34 | accepts_negative_prompt=True 35 | ) 36 | 37 | 38 | def __init__(self, config: ZeroscopeT2VConfig): 39 | super().__init__(config) 40 | self.upscaler_pipe = None 41 | 42 | def get_model_capabilities(self) -> Dict[str, Any]: 43 | # Zeroscope has a fixed native resolution that is then upscaled 44 | base_resolution = (576, 320) 45 | return { 46 | "resolutions": {"Portrait": base_resolution, "Landscape": base_resolution}, 47 | "max_shot_duration": 2.0 48 | } 49 | 50 | def _load_pipeline(self): 51 | if self.pipe is None: 52 | print(f"Loading T2V pipeline ({self.config.model_id})...") 53 | self.pipe = DiffusionPipeline.from_pretrained(self.config.model_id, torch_dtype=torch.float16) 54 | self.pipe.scheduler = DPMSolverMultistepScheduler.from_config(self.pipe.scheduler.config) 55 | self.pipe.enable_model_cpu_offload() 56 | print(f"T2V 
({self.config.model_id}) pipeline loaded.") 57 | 58 | if self.upscaler_pipe is None: 59 | print(f"Loading T2V Upscaler pipeline ({self.config.upscaler_model_id})...") 60 | self.upscaler_pipe = DiffusionPipeline.from_pretrained(self.config.upscaler_model_id, torch_dtype=torch.float16) 61 | self.upscaler_pipe.scheduler = DPMSolverMultistepScheduler.from_config(self.upscaler_pipe.scheduler.config) 62 | self.upscaler_pipe.enable_model_cpu_offload() 63 | print(f"T2V Upscaler ({self.config.upscaler_model_id}) pipeline loaded.") 64 | 65 | def clear_vram(self): 66 | print(f"Clearing T2V VRAM...") 67 | models_to_clear = [m for m in [self.pipe, self.upscaler_pipe] if m is not None] 68 | if models_to_clear: clear_vram_globally(*models_to_clear) 69 | self.pipe, self.upscaler_pipe = None, None 70 | print("T2V VRAM cleared.") 71 | 72 | def generate_video_from_text( 73 | self, prompt: str, output_video_path: str, num_frames: int, fps: int, width: int, height: int, ip_adapter_image: Optional[Union[str, List[str]]] = None 74 | ) -> str: 75 | self._load_pipeline() 76 | 77 | if ip_adapter_image: 78 | print("Warning: ZeroscopeT2V module received IP-Adapter image but does not currently implement its use.") 79 | 80 | negative_prompt = "blurry, low quality, watermark, bad anatomy, text, letters, distorted" 81 | 82 | # Note: Zeroscope generates at a fixed resolution, so we use its capabilities directly 83 | model_res = self.get_model_capabilities()["resolutions"]["Landscape"] 84 | 85 | print(f"Stage 1: Generating T2V ({model_res[0]}x{model_res[1]}) for prompt: \"{prompt[:70]}...\"") 86 | 87 | video_frames_tensor = self.pipe( 88 | prompt=prompt, negative_prompt=negative_prompt, 89 | num_inference_steps=self.config.num_inference_steps, 90 | height=model_res[1], width=model_res[0], num_frames=num_frames, 91 | guidance_scale=self.config.guidance_scale, output_type="pt" 92 | ).frames 93 | 94 | print("Stage 2: Upscaling video to HD...") 95 | 96 | # --- START OF FIX --- 97 | upscaled_video_frames = self.upscaler_pipe( 98 | prompt=prompt, 99 | negative_prompt=negative_prompt, 100 | video=video_frames_tensor, # The argument is 'video', not 'image'. 101 | strength=self.config.upscaler_strength, # Add the strength parameter 102 | num_inference_steps=self.config.num_inference_steps, 103 | guidance_scale=self.config.guidance_scale, 104 | ).frames[0] 105 | # --- END OF FIX --- 106 | 107 | export_to_video(upscaled_video_frames, output_video_path, fps=fps) 108 | 109 | print(f"High-quality T2V video shot saved to {output_video_path}") 110 | return output_video_path -------------------------------------------------------------------------------- /t2v_modules/t2v_wan.py: -------------------------------------------------------------------------------- 1 | # In t2v_modules/t2v_wan.py 2 | import torch 3 | from typing import Dict, Any, List, Optional, Union 4 | 5 | # --- Important: Import the specific classes for this model --- 6 | from diffusers import WanPipeline, AutoencoderKLWan 7 | from diffusers.utils import export_to_video 8 | 9 | from base_modules import BaseT2V, BaseModuleConfig, ModuleCapabilities 10 | from config_manager import DEVICE, clear_vram_globally 11 | 12 | class WanT2VConfig(BaseModuleConfig): 13 | """Configuration for the Wan 2.1 T2V model.""" 14 | model_id: str = "Wan-AI/Wan2.1-T2V-1.3B-Diffusers" 15 | # Parameters from the model card example 16 | num_inference_steps: int = 30 17 | guidance_scale: float = 5.0 18 | 19 | class WanT2V(BaseT2V): 20 | """ 21 | Text-to-Video module using Wan 2.1 T2V 1.3B model. 
22 | This model is efficient and produces high-quality video but does not support 23 | character consistency (IP-Adapter). 24 | """ 25 | Config = WanT2VConfig 26 | 27 | @classmethod 28 | def get_capabilities(cls) -> ModuleCapabilities: 29 | """Declare the capabilities of the Wan 2.1 model.""" 30 | return ModuleCapabilities( 31 | title="Wan 2.1 (1.3B, Fast, 5s Shots)", 32 | vram_gb_min=15.0, # Based on the 8.19 GB requirement from the model card 33 | ram_gb_min=12.0, 34 | supported_formats=["Portrait", "Landscape"], 35 | # This model does not support IP-Adapter, so we are honest here. 36 | supports_ip_adapter=False, 37 | supports_lora=False, # The pipeline does not have a LoRA loader 38 | max_subjects=0, 39 | accepts_text_prompt=True, 40 | accepts_negative_prompt=True 41 | ) 42 | 43 | def get_model_capabilities(self) -> Dict[str, Any]: 44 | """Return the specific resolutions and max duration for this model.""" 45 | return { 46 | # Based on the example: width=832, height=480 47 | "resolutions": {"Portrait": (480, 832), "Landscape": (832, 480)}, 48 | # Based on the example: "generate a 5-second 480P video" 49 | "max_shot_duration": 5.0 50 | } 51 | 52 | def _load_pipeline(self): 53 | """Loads the custom WanPipeline and its required VAE.""" 54 | if self.pipe is not None: 55 | return 56 | 57 | print(f"Loading T2V pipeline ({self.config.model_id})...") 58 | 59 | # This model requires loading the VAE separately first 60 | vae = AutoencoderKLWan.from_pretrained( 61 | self.config.model_id, 62 | subfolder="vae", 63 | torch_dtype=torch.float32 # VAE often works better in float32 64 | ) 65 | 66 | # Then, load the main pipeline, passing the VAE to it 67 | self.pipe = WanPipeline.from_pretrained( 68 | self.config.model_id, 69 | vae=vae, 70 | torch_dtype=torch.bfloat16 # bfloat16 is recommended in the example 71 | ) 72 | 73 | self.pipe.enable_model_cpu_offload() 74 | 75 | print(f"T2V ({self.config.model_id}) pipeline loaded to {DEVICE}.") 76 | 77 | def clear_vram(self): 78 | """Clears the VRAM used by the pipeline.""" 79 | print(f"Clearing T2V (Wan) VRAM...") 80 | if self.pipe is not None: 81 | clear_vram_globally(self.pipe) 82 | self.pipe = None 83 | print("T2V (Wan) VRAM cleared.") 84 | 85 | def generate_video_from_text( 86 | self, prompt: str, output_video_path: str, num_frames: int, fps: int, width: int, height: int, ip_adapter_image: Optional[Union[str, List[str]]] = None 87 | ) -> str: 88 | """Generates a video shot using the Wan T2V pipeline.""" 89 | self._load_pipeline() 90 | 91 | # Gracefully handle the case where character images are passed to a non-supporting model. 
92 | if ip_adapter_image: 93 | print("="*50) 94 | print("WARNING: The WanT2V module does not support IP-Adapters for character consistency.") 95 | print("The provided character images will be ignored for this T2V generation.") 96 | print("="*50) 97 | 98 | # Use the detailed negative prompt from the model card for best results 99 | negative_prompt = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards" 100 | 101 | print(f"Generating Wan T2V ({width}x{height}) for prompt: \"{prompt[:70]}...\"") 102 | 103 | video_frames = self.pipe( 104 | prompt=prompt, 105 | negative_prompt=negative_prompt, 106 | height=height, 107 | width=width, 108 | num_frames=num_frames, 109 | guidance_scale=self.config.guidance_scale, 110 | num_inference_steps=self.config.num_inference_steps 111 | ).frames[0] 112 | 113 | # The system's config determines the final FPS, not the model's example 114 | export_to_video(video_frames, output_video_path, fps=fps) 115 | 116 | print(f"Wan T2V video shot saved to {output_video_path}") 117 | return output_video_path -------------------------------------------------------------------------------- /i2v_modules/i2v_svd.py: -------------------------------------------------------------------------------- 1 | # i2v_modules/i2v_svd.py 2 | import torch 3 | from typing import Dict, Any, List, Optional, Union 4 | from diffusers import StableVideoDiffusionPipeline 5 | from diffusers.utils import load_image, export_to_video 6 | from PIL import Image 7 | 8 | from base_modules import BaseI2V, BaseModuleConfig, ModuleCapabilities 9 | from config_manager import DEVICE, clear_vram_globally, ContentConfig 10 | 11 | class SvdI2VConfig(BaseModuleConfig): 12 | model_id: str = "stabilityai/stable-video-diffusion-img2vid-xt" 13 | decode_chunk_size: int = 8 14 | motion_bucket_id: int = 127 15 | noise_aug_strength: float = 0.02 16 | model_native_frames: int = 25 17 | 18 | class SvdI2V(BaseI2V): 19 | Config = SvdI2VConfig 20 | 21 | @classmethod 22 | def get_capabilities(cls) -> ModuleCapabilities: 23 | return ModuleCapabilities( 24 | title="SVD, Float16, Port/Landscape, No Prompt just image, Max 2 Sec", 25 | vram_gb_min=8.0, 26 | ram_gb_min=12.0, 27 | supported_formats=["Portrait", "Landscape"], 28 | supports_ip_adapter=True, 29 | supports_lora=True, # Juggernaut is a fine-tune, can easily use LoRAs 30 | max_subjects=2, # Can handle one or two IP adapter images 31 | accepts_text_prompt=False, 32 | accepts_negative_prompt=True 33 | ) 34 | 35 | 36 | def get_model_capabilities(self) -> Dict[str, Any]: 37 | return { 38 | "resolutions": {"Portrait": (576, 1024), "Landscape": (1024, 576)}, 39 | "max_shot_duration": 2.0 40 | } 41 | 42 | def enhance_prompt(self, prompt: str, prompt_type: str = "visual") -> str: 43 | # SVD doesn't use text prompts, but this shows how you could add model-specific keywords. 
44 | # For example, for a different model you might do: 45 | # if prompt_type == "visual": 46 | # return f"{prompt}, 8k, photorealistic, cinematic lighting" 47 | return prompt # Return original for SVD 48 | 49 | def _load_pipeline(self): 50 | if self.pipe is None: 51 | print(f"Loading I2V pipeline (SVD): {self.config.model_id}...") 52 | self.pipe = StableVideoDiffusionPipeline.from_pretrained( 53 | self.config.model_id, torch_dtype=torch.float16 54 | ) 55 | self.pipe.enable_model_cpu_offload() 56 | print("I2V (SVD) pipeline loaded.") 57 | 58 | def clear_vram(self): 59 | print("Clearing I2V (SVD) VRAM...") 60 | if self.pipe is not None: clear_vram_globally(self.pipe) 61 | self.pipe = None 62 | print("I2V (SVD) VRAM cleared.") 63 | 64 | def _resize_and_pad(self, image: Image.Image, target_width: int, target_height: int) -> Image.Image: 65 | original_aspect = image.width / image.height; target_aspect = target_width / target_height 66 | if original_aspect > target_aspect: new_width, new_height = target_width, int(target_width / original_aspect) 67 | else: new_height, new_width = target_height, int(target_height * original_aspect) 68 | resized_image = image.resize((new_width, new_height), Image.LANCZOS) 69 | background = Image.new('RGB', (target_width, target_height), (0, 0, 0)) 70 | background.paste(resized_image, ((target_width - new_width) // 2, (target_height - new_height) // 2)) 71 | return background 72 | 73 | def generate_video_from_image(self, image_path: str, output_video_path: str, target_duration: float, content_config: ContentConfig, visual_prompt: str, motion_prompt: Optional[str], ip_adapter_image: Optional[Union[str, List[str]]] = None) -> str: 74 | self._load_pipeline() 75 | 76 | if ip_adapter_image: 77 | print("Warning: SvdI2V module received IP-Adapter image but does not currently implement its use.") 78 | 79 | input_image = load_image(image_path) 80 | svd_target_res = self.get_model_capabilities()["resolutions"] 81 | aspect_ratio = "Landscape" if input_image.width > input_image.height else "Portrait" 82 | svd_target_width, svd_target_height = svd_target_res[aspect_ratio] 83 | prepared_image = self._resize_and_pad(input_image, svd_target_width, svd_target_height) 84 | 85 | calculated_fps = max(1, round(self.config.model_native_frames / target_duration)) if target_duration > 0 else 8 86 | motion_bucket_id = self.config.motion_bucket_id 87 | if motion_prompt: 88 | motion_prompt_lower = motion_prompt.lower() 89 | if any(w in motion_prompt_lower for w in ['fast', 'quick', 'rapid', 'zoom in', 'pan right']): motion_bucket_id = min(255, motion_bucket_id + 50) 90 | elif any(w in motion_prompt_lower for w in ['slow', 'gentle', 'subtle', 'still']): motion_bucket_id = max(0, motion_bucket_id - 50) 91 | print(f"Adjusted motion_bucket_id to {motion_bucket_id} based on prompt: '{motion_prompt}'") 92 | 93 | video_frames = self.pipe( 94 | image=prepared_image, height=svd_target_height, width=svd_target_width, 95 | decode_chunk_size=self.config.decode_chunk_size, num_frames=self.config.model_native_frames, 96 | motion_bucket_id=motion_bucket_id, noise_aug_strength=self.config.noise_aug_strength, 97 | ).frames[0] 98 | 99 | export_to_video(video_frames, output_video_path, fps=calculated_fps) 100 | print(f"SVD video shot saved to {output_video_path}") 101 | return output_video_path -------------------------------------------------------------------------------- /t2v_modules/t2v_ltx.py: -------------------------------------------------------------------------------- 1 | # In t2v_modules/t2v_ltx.py 
2 | import torch 3 | from typing import Dict, Any, List, Optional, Union 4 | import os 5 | 6 | # --- Import the necessary pipelines and configs --- 7 | from diffusers import LTXPipeline, LTXVideoTransformer3DModel 8 | from diffusers.utils import export_to_video 9 | from transformers import T5EncoderModel, BitsAndBytesConfig as TransformersBitsAndBytesConfig 10 | from diffusers import BitsAndBytesConfig as DiffusersBitsAndBytesConfig 11 | 12 | from base_modules import BaseT2V, BaseModuleConfig, ModuleCapabilities 13 | from config_manager import DEVICE, clear_vram_globally 14 | 15 | class LtxT2VConfig(BaseModuleConfig): 16 | model_id: str = "Lightricks/LTX-Video" 17 | use_8bit_quantization: bool = True 18 | num_inference_steps: int = 50 19 | guidance_scale: float = 7.5 20 | decode_timestep: float = 0.03 21 | decode_noise_scale: float = 0.025 22 | # No IP-Adapter configs needed as this pipeline doesn't support them 23 | 24 | class LtxT2V(BaseT2V): 25 | Config = LtxT2VConfig 26 | 27 | # No __init__ needed if we just have the default behavior 28 | 29 | @classmethod 30 | def get_capabilities(cls) -> ModuleCapabilities: 31 | """This module is for pure T2V and does NOT support IP-Adapters.""" 32 | return ModuleCapabilities( 33 | title="LTX, Port/Landscape, No Subject, 5 sec", 34 | vram_gb_min=8.0, 35 | ram_gb_min=12.0, 36 | supported_formats=["Portrait", "Landscape"], 37 | # --- THE CRITICAL CHANGE: Be honest about capabilities --- 38 | supports_ip_adapter=False, 39 | supports_lora=False, # This pipeline doesn't have a LoRA loader either 40 | max_subjects=0, 41 | accepts_text_prompt=True, 42 | accepts_negative_prompt=True 43 | ) 44 | 45 | def get_model_capabilities(self) -> Dict[str, Any]: 46 | return {"resolutions": {"Portrait": (512, 768), "Landscape": (768, 512)}, "max_shot_duration": 5.0} 47 | 48 | def _load_pipeline(self): 49 | if self.pipe is not None: return 50 | 51 | if self.config.use_8bit_quantization: 52 | print(f"Loading T2V pipeline ({self.config.model_id}) with 8-bit quantization...") 53 | text_encoder_quant_config = TransformersBitsAndBytesConfig(load_in_8bit=True) 54 | text_encoder_8bit = T5EncoderModel.from_pretrained(self.config.model_id, subfolder="text_encoder", quantization_config=text_encoder_quant_config, torch_dtype=torch.float16) 55 | transformer_quant_config = DiffusersBitsAndBytesConfig(load_in_8bit=True) 56 | transformer_8bit = LTXVideoTransformer3DModel.from_pretrained(self.config.model_id, subfolder="transformer", quantization_config=transformer_quant_config, torch_dtype=torch.float16) 57 | 58 | # Note: We are no longer passing the `image_encoder` as it was being ignored. 
59 | self.pipe = LTXPipeline.from_pretrained( 60 | self.config.model_id, 61 | text_encoder=text_encoder_8bit, 62 | transformer=transformer_8bit, 63 | torch_dtype=torch.float16, 64 | device_map="balanced", 65 | ) 66 | print("Quantized T2V pipeline loaded successfully.") 67 | else: 68 | print(f"Loading T2V pipeline ({self.config.model_id}) in full precision...") 69 | self.pipe = LTXPipeline.from_pretrained( 70 | self.config.model_id, 71 | torch_dtype=torch.bfloat16 72 | ) 73 | self.pipe.enable_model_cpu_offload() 74 | 75 | self.pipe.vae.enable_tiling() 76 | print("VAE tiling enabled for memory efficiency.") 77 | 78 | def clear_vram(self): 79 | print(f"Clearing T2V (LTX) VRAM...") 80 | if self.pipe is not None: 81 | clear_vram_globally(self.pipe) 82 | self.pipe = None 83 | print("T2V (LTX) VRAM cleared.") 84 | 85 | def generate_video_from_text( 86 | self, prompt: str, output_video_path: str, num_frames: int, fps: int, width: int, height: int, ip_adapter_image: Optional[Union[str, List[str]]] = None 87 | ) -> str: 88 | self._load_pipeline() 89 | 90 | # --- THE GRACEFUL HANDLING --- 91 | # If character images are passed, inform the user they are being ignored. 92 | if ip_adapter_image: 93 | print("="*50) 94 | print("WARNING: The LtxT2V module does not support IP-Adapters for character consistency.") 95 | print("The provided character images will be ignored for this T2V generation.") 96 | print("="*50) 97 | 98 | # All IP-Adapter logic is removed. We just call the pipeline. 99 | pipeline_kwargs = {} 100 | 101 | negative_prompt = "worst quality, inconsistent motion, blurry, jittery, distorted, text, watermark, bad anatomy" 102 | print(f"Generating LTX T2V ({width}x{height}) for prompt: \"{prompt[:50]}...\"") 103 | 104 | video_frames = self.pipe( 105 | prompt=prompt, 106 | negative_prompt=negative_prompt, 107 | width=width, 108 | height=height, 109 | num_frames=num_frames, 110 | num_inference_steps=self.config.num_inference_steps, 111 | guidance_scale=self.config.guidance_scale, 112 | decode_timestep=self.config.decode_timestep, 113 | decode_noise_scale=self.config.decode_noise_scale, 114 | **pipeline_kwargs 115 | ).frames[0] 116 | 117 | export_to_video(video_frames, output_video_path, fps=fps) 118 | 119 | print(f"LTX T2V video shot saved to {output_video_path}") 120 | return output_video_path -------------------------------------------------------------------------------- /i2v_modules/i2v_wan.py: -------------------------------------------------------------------------------- 1 | # In i2v_modules/i2v_wan.py 2 | import torch 3 | import numpy as np 4 | from typing import Dict, Any, List, Optional, Union 5 | from PIL import Image 6 | 7 | # Import the necessary components 8 | from diffusers import WanImageToVideoPipeline, AutoencoderKLWan 9 | from diffusers.utils import export_to_video, load_image 10 | from transformers import CLIPVisionModel, UMT5EncoderModel, T5Tokenizer, CLIPImageProcessor 11 | 12 | from base_modules import BaseI2V, BaseModuleConfig, ModuleCapabilities 13 | from config_manager import DEVICE, clear_vram_globally, ContentConfig 14 | 15 | class WanI2VConfig(BaseModuleConfig): 16 | """Configuration for the Wan 2.1 I2V model.""" 17 | model_id: str = "Wan-AI/Wan2.1-I2V-14B-480P-Diffusers" 18 | 19 | num_inference_steps: int = 30 20 | guidance_scale: float = 5.0 21 | 22 | class WanI2V(BaseI2V): 23 | """ 24 | Image-to-Video module using the Wan 2.1 14B pipeline. 
25 | """ 26 | Config = WanI2VConfig 27 | 28 | @classmethod 29 | def get_capabilities(cls) -> ModuleCapabilities: 30 | """Declare the capabilities of the Wan 2.1 I2V model.""" 31 | return ModuleCapabilities( 32 | title="Wan 2.1 I2V (14B)", 33 | vram_gb_min=40.0, 34 | ram_gb_min=24.0, 35 | supported_formats=["Portrait", "Landscape"], 36 | supports_ip_adapter=False, 37 | supports_lora=False, 38 | max_subjects=0, 39 | accepts_text_prompt=True, 40 | accepts_negative_prompt=True 41 | ) 42 | 43 | def get_model_capabilities(self) -> Dict[str, Any]: 44 | """Return the specific resolutions and max duration for this model.""" 45 | return { 46 | "resolutions": {"base_pixel_area": 399360}, # 480P model base area 47 | "max_shot_duration": 4.0 48 | } 49 | 50 | def _load_pipeline(self): 51 | """ 52 | Loads the WanImageToVideoPipeline following the official documentation example. 53 | """ 54 | if self.pipe is not None: return 55 | 56 | print(f"Loading I2V pipeline ({self.config.model_id})...") 57 | 58 | # 1. Load individual components with appropriate dtypes 59 | image_encoder = CLIPVisionModel.from_pretrained( 60 | self.config.model_id, 61 | subfolder="image_encoder", 62 | torch_dtype=torch.float32 63 | ) 64 | 65 | vae = AutoencoderKLWan.from_pretrained( 66 | self.config.model_id, 67 | subfolder="vae", 68 | torch_dtype=torch.float32 69 | ) 70 | 71 | # 2. Create the pipeline with the components 72 | self.pipe = WanImageToVideoPipeline.from_pretrained( 73 | self.config.model_id, 74 | vae=vae, 75 | image_encoder=image_encoder, 76 | torch_dtype=torch.bfloat16 77 | ) 78 | 79 | # 3. Enable model CPU offload for memory efficienc y 80 | self.pipe.enable_model_cpu_offload() 81 | 82 | print("I2V (Wan 14B) pipeline loaded successfully.") 83 | 84 | def clear_vram(self): 85 | """Clears the VRAM used by all loaded components.""" 86 | print(f"Clearing I2V (Wan 14B) VRAM...") 87 | if self.pipe is not None: 88 | clear_vram_globally(self.pipe) 89 | self.pipe = None 90 | print("I2V (Wan 14B) VRAM cleared.") 91 | 92 | def generate_video_from_image( 93 | self, image_path: str, output_video_path: str, target_duration: float, 94 | content_config: ContentConfig, visual_prompt: str, motion_prompt: Optional[str], 95 | ip_adapter_image: Optional[Union[str, List[str]]] = None 96 | ) -> str: 97 | """Generates a video by animating a source image using the 14B model.""" 98 | self._load_pipeline() 99 | 100 | input_image = load_image(image_path) 101 | 102 | model_caps = self.get_model_capabilities() 103 | max_area = model_caps["resolutions"]["base_pixel_area"] 104 | aspect_ratio = input_image.height / input_image.width 105 | 106 | # Calculate dimensions using the correct scale factors 107 | mod_value = self.pipe.vae_scale_factor_spatial * self.pipe.transformer.config.patch_size[1] 108 | h = round(np.sqrt(max_area * aspect_ratio)) // mod_value * mod_value 109 | w = round(np.sqrt(max_area / aspect_ratio)) // mod_value * mod_value 110 | prepared_image = input_image.resize((w, h)) 111 | 112 | num_frames = int(target_duration * content_config.fps) 113 | full_prompt = f"{visual_prompt}, {motion_prompt}" if motion_prompt else visual_prompt 114 | negative_prompt = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking 
backwards" 115 | 116 | print(f"Generating Wan I2V ({w}x{h}) from image: {image_path}") 117 | print(f" - Prompt: \"{full_prompt[:70]}...\"") 118 | 119 | video_frames = self.pipe( 120 | image=prepared_image, 121 | prompt=full_prompt, 122 | negative_prompt=negative_prompt, 123 | height=h, 124 | width=w, 125 | num_frames=num_frames, 126 | guidance_scale=self.config.guidance_scale, 127 | num_inference_steps=self.config.num_inference_steps, 128 | ).frames[0] 129 | 130 | export_to_video(video_frames, output_video_path, fps=content_config.fps) 131 | 132 | print(f"Wan I2V 14B video shot saved to {output_video_path}") 133 | return output_video_path -------------------------------------------------------------------------------- /base_modules.py: -------------------------------------------------------------------------------- 1 | # In base_modules.py 2 | 3 | from abc import ABC, abstractmethod 4 | from typing import List, Tuple, Dict, Any, Optional, Union, Literal 5 | from pydantic import BaseModel, Field 6 | 7 | # --- NEW: Define the ModuleCapabilities Contract --- 8 | class ModuleCapabilities(BaseModel): 9 | """A standardized spec sheet for all generation modules.""" 10 | 11 | title: str = Field(description="Title to show in dropdowns") 12 | 13 | # Resource Requirements 14 | vram_gb_min: float = Field(default=4.0, description="Minimum GPU VRAM required in GB.") 15 | ram_gb_min: float = Field(default=8.0, description="Minimum system RAM required in GB.") 16 | 17 | # Format & Control Support 18 | supported_formats: List[Literal["Portrait", "Landscape"]] = Field(default=["Portrait", "Landscape"]) 19 | supports_ip_adapter: bool = Field(default=False, description="True if the module can use IP-Adapter for subject consistency.") 20 | supports_lora: bool = Field(default=False, description="True if the module supports LoRA weights.") 21 | 22 | # Subject & Prompting 23 | max_subjects: int = Field(default=0, description="Maximum number of distinct subjects/characters the module can handle at once (e.g., via IP-Adapter).") 24 | accepts_text_prompt: bool = Field(default=True, description="True if the module uses a text prompt.") 25 | accepts_negative_prompt: bool = Field(default=True, description="True if the module uses a negative prompt.") 26 | 27 | # Type-Specific 28 | supported_tts_languages: List[str] = Field(default=[], description="List of languages supported by a TTS module (e.g., ['en', 'es']).") 29 | 30 | # Forward-declare to avoid circular imports 31 | class ContentConfig(BaseModel): pass 32 | class ProjectState(BaseModel): pass 33 | 34 | # --- Base Configuration Models --- 35 | class BaseModuleConfig(BaseModel): 36 | """Base for all module-specific configurations.""" 37 | model_id: str 38 | 39 | # --- Base Module Classes --- 40 | class BaseLLM(ABC): 41 | """Abstract Base Class for Language Model modules.""" 42 | def __init__(self, config: BaseModuleConfig): 43 | self.config = config 44 | self.model = None 45 | self.tokenizer = None 46 | 47 | # --- NEW: Enforce capabilities contract --- 48 | @classmethod 49 | @abstractmethod 50 | def get_capabilities(cls) -> ModuleCapabilities: 51 | """Returns the spec sheet for this module.""" 52 | raise NotImplementedError 53 | 54 | @abstractmethod 55 | def generate_script(self, topic: str, content_config: ContentConfig) -> Dict[str, Any]: 56 | """Generates the main script, visual prompts, hashtags, and context descriptions.""" 57 | pass 58 | 59 | @abstractmethod 60 | def generate_shot_visual_prompts(self, scene_narration: str, original_scene_prompt: str, 
num_shots: int, content_config: ContentConfig, main_subject: str, setting: str) -> List[Tuple[str, str]]: 61 | """Generates visual and motion prompts for each shot within a scene.""" 62 | pass 63 | 64 | @abstractmethod 65 | def clear_vram(self): 66 | """Clears the VRAM used by the model and tokenizer.""" 67 | pass 68 | 69 | class BaseTTS(ABC): 70 | """Abstract Base Class for Text-to-Speech modules.""" 71 | def __init__(self, config: BaseModuleConfig): 72 | self.config = config 73 | self.model = None 74 | 75 | # --- NEW: Enforce capabilities contract --- 76 | @classmethod 77 | @abstractmethod 78 | def get_capabilities(cls) -> ModuleCapabilities: 79 | """Returns the spec sheet for this module.""" 80 | raise NotImplementedError 81 | 82 | @abstractmethod 83 | def generate_audio(self, text: str, output_dir: str, scene_idx: int, language: str, speaker_wav: Optional[str] = None) -> Tuple[str, float]: 84 | """Generates audio from text.""" 85 | pass 86 | 87 | @abstractmethod 88 | def clear_vram(self): 89 | """Clears the VRAM used by the TTS model.""" 90 | pass 91 | 92 | class BaseVideoGen(ABC): 93 | """A common base for all video generation modules (T2I, I2V, T2V).""" 94 | def __init__(self, config: BaseModuleConfig): 95 | self.config = config 96 | self.pipe = None 97 | 98 | # --- NEW: Enforce capabilities contract --- 99 | @classmethod 100 | @abstractmethod 101 | def get_capabilities(cls) -> ModuleCapabilities: 102 | """Returns the spec sheet for this module.""" 103 | raise NotImplementedError 104 | 105 | @abstractmethod 106 | def get_model_capabilities(self) -> Dict[str, Any]: 107 | """Returns a dictionary of the model's capabilities, like resolutions.""" 108 | pass 109 | 110 | def enhance_prompt(self, prompt: str, prompt_type: str = "visual") -> str: 111 | return prompt 112 | 113 | @abstractmethod 114 | def clear_vram(self): 115 | """Clears the VRAM used by the pipeline.""" 116 | pass 117 | 118 | class BaseT2I(BaseVideoGen): 119 | """Abstract Base Class for Text-to-Image modules.""" 120 | @abstractmethod 121 | def generate_image(self, prompt: str, negative_prompt: str, output_path: str, width: int, height: int, ip_adapter_image: Optional[Union[str, List[str]]] = None, seed: int = -1) -> str: 122 | """Generates an image from a text prompt, optionally using an IP-Adapter image.""" 123 | pass 124 | 125 | class BaseI2V(BaseVideoGen): 126 | """Abstract Base Class for Image-to-Video modules.""" 127 | @abstractmethod 128 | def generate_video_from_image(self, image_path: str, output_video_path: str, target_duration: float, content_config: ContentConfig, visual_prompt: str, motion_prompt: Optional[str], ip_adapter_image: Optional[Union[str, List[str]]] = None) -> str: 129 | """Generates a video from an initial image, optionally using an IP-Adapter image for style/subject.""" 130 | pass 131 | 132 | class BaseT2V(BaseVideoGen): 133 | """Abstract Base Class for Text-to-Video modules.""" 134 | @abstractmethod 135 | def generate_video_from_text(self, prompt: str, output_video_path: str, num_frames: int, fps: int, width: int, height: int, ip_adapter_image: Optional[Union[str, List[str]]] = None) -> str: 136 | """Generates a video directly from a text prompt, optionally using an IP-Adapter image.""" 137 | pass -------------------------------------------------------------------------------- /__requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py==2.2.2 2 | accelerate==1.6.0 3 | aiofiles==24.1.0 4 | aiohappyeyeballs==2.6.1 5 | aiohttp==3.11.18 6 | 
aiosignal==1.3.2 7 | albucore==0.0.24 8 | albumentations==2.0.7 9 | altair==5.5.0 10 | annotated-types==0.7.0 11 | anyascii==0.3.2 12 | anyio==4.9.0 13 | anykeystore==0.2 14 | apex==0.9.10.dev0 15 | argon2-cffi==23.1.0 16 | argon2-cffi-bindings==21.2.0 17 | arrow==1.3.0 18 | asttokens==3.0.0 19 | async-lru==2.0.5 20 | attrs==25.3.0 21 | audioread==3.0.1 22 | av==12.1.0 23 | babel==2.17.0 24 | beautifulsoup4==4.13.4 25 | bitsandbytes==0.45.5 26 | bleach==6.2.0 27 | blinker==1.9.0 28 | blis==0.7.11 29 | cachetools==5.5.2 30 | catalogue==2.0.10 31 | certifi==2025.4.26 32 | cffi==1.17.1 33 | charset-normalizer==3.4.2 34 | click==8.1.8 35 | cloudpathlib==0.21.1 36 | coloredlogs==15.0.1 37 | comm==0.2.2 38 | confection==0.1.5 39 | consisid-eva-clip==1.0.2 40 | contourpy>=1.3.0 41 | coqpit-config>=0.2.0 42 | coqui-tts>=0.26.0 43 | coqui-tts-trainer>=0.2.3 44 | cryptacular==1.6.2 45 | cycler==0.12.1 46 | cymem==2.0.11 47 | cython==3.1.0 48 | dateparser==1.1.8 49 | debugpy==1.8.14 50 | decorator==5.2.1 51 | defusedxml==0.7.1 52 | diffusers==0.33.1 53 | docopt==0.6.2 54 | easydict==1.13 55 | einops==0.8.1 56 | encodec==0.1.1 57 | executing==2.2.0 58 | facexlib==0.3.0 59 | fastapi==0.115.12 60 | fastjsonschema==2.21.1 61 | ffmpy==0.6.0 62 | filelock==3.18.0 63 | filterpy==1.4.5 64 | flatbuffers==25.2.10 65 | fonttools==4.58.0 66 | fqdn==1.5.1 67 | frozenlist==1.6.0 68 | fsspec==2025.5.1 69 | ftfy==6.3.1 70 | gitdb==4.0.12 71 | gitpython==3.1.44 72 | gradio==5.25.2 73 | gradio-client==1.8.0 74 | greenlet==3.2.2 75 | groovy==0.1.2 76 | grpcio==1.71.0 77 | gruut==2.4.0 78 | gruut-ipa==0.13.0 79 | gruut-lang-de==2.0.1 80 | gruut-lang-en==2.0.1 81 | gruut-lang-es==2.0.1 82 | gruut-lang-fr==2.0.2 83 | h11==0.16.0 84 | hf-xet==1.1.3 85 | httpcore==1.0.9 86 | httpx==0.28.1 87 | huggingface-hub==0.32.4 88 | humanfriendly==10.0 89 | hupper==1.12.1 90 | idna==3.10 91 | imageio==2.37.0 92 | imageio-ffmpeg==0.6.0 93 | importlib-metadata==8.7.0 94 | inflect==7.5.0 95 | inquirerpy==0.3.4 96 | insightface==0.7.3 97 | ipykernel==6.29.5 98 | ipython==9.2.0 99 | ipython-pygments-lexers==1.1.1 100 | ipywidgets==8.1.7 101 | isoduration==20.11.0 102 | jedi==0.19.2 103 | jinja2==3.1.6 104 | joblib==1.5.0 105 | json5==0.12.0 106 | jsonlines==1.2.0 107 | jsonpointer==3.0.0 108 | jsonschema==4.23.0 109 | jsonschema-specifications==2025.4.1 110 | jupyter==1.1.1 111 | jupyter-client==8.6.3 112 | jupyter-console==6.6.3 113 | jupyter-core==5.7.2 114 | jupyter-events==0.12.0 115 | jupyter-lsp==2.2.5 116 | jupyter-server==2.16.0 117 | jupyter-server-terminals==0.5.3 118 | jupyterlab==4.4.2 119 | jupyterlab-pygments==0.3.0 120 | jupyterlab-server==2.27.3 121 | jupyterlab-widgets==3.0.15 122 | kiwisolver==1.4.8 123 | langcodes==3.5.0 124 | language-data==1.3.0 125 | lazy-loader==0.4 126 | librosa>=0.11.0 127 | llvmlite==0.44.0 128 | marisa-trie==1.2.1 129 | markdown==3.8 130 | markdown-it-py==3.0.0 131 | markupsafe==3.0.2 132 | matplotlib==3.10.3 133 | matplotlib-inline==0.1.7 134 | mdurl==0.1.2 135 | mistune==3.1.3 136 | monotonic-alignment-search==0.1.1 137 | more-itertools==10.7.0 138 | moviepy==2.1.2 139 | mpmath==1.3.0 140 | msgpack==1.1.0 141 | multidict==6.4.3 142 | murmurhash==1.0.12 143 | narwhals==1.41.1 144 | nbclient==0.10.2 145 | nbconvert==7.16.6 146 | nbformat==5.10.4 147 | nest-asyncio==1.6.0 148 | networkx==3.5 149 | notebook==7.4.2 150 | notebook-shim==0.2.4 151 | num2words==0.5.14 152 | numba>=0.61.2 153 | numpy>=1.26.2 154 | nvidia-cublas-cu12==12.1.3.1 155 | nvidia-cuda-cupti-cu12==12.1.105 156 | 
nvidia-cuda-nvrtc-cu12==12.1.105 157 | nvidia-cuda-runtime-cu12==12.1.105 158 | nvidia-cudnn-cu12==9.1.0.70 159 | nvidia-cufft-cu12==11.0.2.54 160 | nvidia-cufile-cu12==1.11.1.6 161 | nvidia-curand-cu12==10.3.2.106 162 | nvidia-cusolver-cu12==11.4.5.107 163 | nvidia-cusparse-cu12==12.1.0.106 164 | nvidia-cusparselt-cu12==0.6.3 165 | nvidia-nccl-cu12==2.21.5 166 | nvidia-nvjitlink-cu12==12.6.85 167 | nvidia-nvtx-cu12==12.1.105 168 | oauthlib==3.2.2 169 | onnx==1.18.0 170 | onnxruntime-gpu==1.22.0 171 | opencv-contrib-python==4.11.0.86 172 | opencv-python==4.11.0.86 173 | opencv-python-headless==4.11.0.86 174 | orjson==3.10.18 175 | overrides==7.7.0 176 | packaging==24.2 177 | pandas==2.3.0 178 | pandocfilters==1.5.1 179 | parso==0.8.4 180 | pastedeploy==3.1.0 181 | pbkdf2==1.3 182 | peft==0.15.2 183 | pexpect==4.9.0 184 | pfzy==0.3.4 185 | pillow>=9.2.0,<11.0 186 | plaster==1.1.2 187 | plaster-pastedeploy==1.0.1 188 | platformdirs==4.3.8 189 | pooch==1.8.2 190 | preshed==3.0.9 191 | prettytable==3.16.0 192 | proglog==0.1.12 193 | prometheus-client==0.21.1 194 | prompt-toolkit==3.0.51 195 | propcache==0.3.1 196 | protobuf==6.31.0 197 | psutil==7.0.0 198 | ptyprocess==0.7.0 199 | pure-eval==0.2.3 200 | pyarrow==20.0.0 201 | pycparser==2.22 202 | pydantic==2.11.4 203 | pydantic-core==2.33.2 204 | pydeck==0.9.1 205 | pydub==0.25.1 206 | pyfacer==0.0.5 207 | pygments==2.19.1 208 | pyparsing==3.2.3 209 | pyramid==2.0.2 210 | pyramid-mailer==0.15.1 211 | pysbd==0.3.4 212 | python-crfsuite==0.9.11 213 | python-dateutil==2.9.0.post0 214 | python-dotenv==1.1.0 215 | python-json-logger==3.3.0 216 | python-multipart==0.0.20 217 | python3-openid==3.2.0 218 | pytz==2025.2 219 | pyyaml==6.0.2 220 | pyzmq==26.4.0 221 | referencing==0.36.2 222 | regex==2024.11.6 223 | repoze-sendmail==4.4.1 224 | requests==2.31.0 225 | requests-oauthlib==2.0.0 226 | rfc3339-validator==0.1.4 227 | rfc3986-validator==0.1.1 228 | rich==14.0.0 229 | rpds-py==0.25.0 230 | ruff==0.11.13 231 | safehttpx==0.1.6 232 | safetensors==0.5.3 233 | scikit-image==0.25.2 234 | scikit-learn==1.6.1 235 | scipy==1.12.0 236 | semantic-version==2.10.0 237 | send2trash==1.8.3 238 | sentencepiece==0.2.0 239 | setuptools==80.9.0 240 | shellingham==1.5.4 241 | simsimd==6.2.1 242 | six==1.17.0 243 | smart-open==7.1.0 244 | smmap==5.0.2 245 | sniffio==1.3.1 246 | soundfile==0.13.1 247 | soupsieve==2.7 248 | soxr==0.5.0.post1 249 | spacy==3.7.5 250 | spacy-legacy==3.0.12 251 | spacy-loggers==1.0.5 252 | sqlalchemy==2.0.41 253 | srsly==2.5.1 254 | stack-data==0.6.3 255 | starlette==0.46.2 256 | streamlit==1.45.0 257 | stringzilla==3.12.5 258 | sudachidict-core==20250129 259 | sudachipy==0.6.10 260 | sympy==1.13.1 261 | tenacity==9.1.2 262 | tensorboard==2.19.0 263 | tensorboard-data-server==0.7.2 264 | terminado==0.18.1 265 | thinc==8.2.5 266 | threadpoolctl==3.6.0 267 | tifffile==2025.5.10 268 | timm==1.0.15 269 | tinycss2==1.4.0 270 | tokenizers>=0.20.3 271 | toml==0.10.2 272 | tomlkit==0.13.3 273 | # torch==2.5.1+cu121 274 | # torchaudio==2.5.1+cu121 275 | torchsde==0.2.6 276 | # torchvision==0.20.1+cu121 277 | tornado==6.4.2 278 | tqdm==4.67.1 279 | traitlets==5.14.3 280 | trampoline==0.1.2 281 | transaction==5.0 282 | transformers>=4.46.2 283 | translationstring==1.4 284 | triton==3.1.0 285 | typeguard==4.4.2 286 | typer==0.15.4 287 | types-python-dateutil==2.9.0.20241206 288 | typing-extensions==4.14.0 289 | typing-inspection==0.4.0 290 | tzdata==2025.2 291 | tzlocal==5.3.1 292 | uri-template==1.3.0 293 | urllib3==2.4.0 294 | uvicorn==0.34.3 295 | 
validators==0.35.0 296 | velruse==1.1.1 297 | venusian==3.1.1 298 | wasabi==1.1.3 299 | watchdog==6.0.0 300 | wcwidth==0.2.13 301 | weasel==0.4.1 302 | webcolors==24.11.1 303 | webencodings==0.5.1 304 | webob==1.8.9 305 | websocket-client==1.8.0 306 | websockets==15.0.1 307 | werkzeug==3.1.3 308 | widgetsnbextension==4.0.14 309 | wrapt==1.17.2 310 | wtforms==3.2.1 311 | wtforms-recaptcha==0.3.2 312 | xformers==0.0.29.post1 313 | yarl==1.20.0 314 | zipp==3.22.0 315 | zope-deprecation==5.1 316 | zope-interface==7.2 317 | zope-sqlalchemy==3.1 318 | -------------------------------------------------------------------------------- /t2i_modules/t2i_juggernaut.py: -------------------------------------------------------------------------------- 1 | # In t2i_modules/t2i_juggernaut.py 2 | import torch 3 | from typing import List, Optional, Dict, Any, Union 4 | from diffusers import StableDiffusionXLPipeline, DiffusionPipeline 5 | from diffusers.utils import load_image 6 | from transformers import BitsAndBytesConfig 7 | from diffusers import DPMSolverMultistepScheduler as JuggernautScheduler 8 | 9 | from base_modules import BaseT2I, BaseModuleConfig, ModuleCapabilities 10 | from config_manager import DEVICE, clear_vram_globally 11 | 12 | class JuggernautT2IConfig(BaseModuleConfig): 13 | model_id: str = "RunDiffusion/Juggernaut-XL-v9" 14 | refiner_id: Optional[str] = None 15 | # --- NEW: Flag to control memory-saving quantization --- 16 | use_8bit_quantization: bool = True 17 | num_inference_steps: int = 35 18 | guidance_scale: float = 6.0 19 | ip_adapter_repo: str = "h94/IP-Adapter" 20 | ip_adapter_subfolder: str = "sdxl_models" 21 | ip_adapter_weight_name: str = "ip-adapter_sdxl.bin" 22 | 23 | 24 | class JuggernautT2I(BaseT2I): 25 | Config = JuggernautT2IConfig 26 | 27 | def __init__(self, config: JuggernautT2IConfig): 28 | super().__init__(config) 29 | self.refiner_pipe = None 30 | self._loaded_ip_adapter_count = 0 31 | 32 | @classmethod 33 | def get_capabilities(cls) -> ModuleCapabilities: 34 | return ModuleCapabilities( 35 | title="Juggernaut XL v9 (Quality), 2 Subjects considered", 36 | vram_gb_min=8.0, 37 | ram_gb_min=12.0, 38 | supported_formats=["Portrait", "Landscape"], 39 | supports_ip_adapter=True, 40 | supports_lora=True, 41 | max_subjects=2, 42 | accepts_text_prompt=True, 43 | accepts_negative_prompt=True 44 | ) 45 | 46 | def get_model_capabilities(self) -> Dict[str, Any]: 47 | return { 48 | "resolutions": {"Portrait": (832, 1216), "Landscape": (1216, 832)}, 49 | "max_shot_duration": 3.0 50 | } 51 | 52 | def enhance_prompt(self, prompt: str, prompt_type: str = "visual") -> str: 53 | quality_keywords = "cinematic photography, hyperdetailed, (skin details:1.1), 8k, professional lighting" 54 | if prompt.strip().endswith(','): 55 | return f"{prompt} {quality_keywords}" 56 | else: 57 | return f"{prompt}, {quality_keywords}" 58 | 59 | def _load_pipeline(self): 60 | if self.pipe is None: 61 | if self.config.use_8bit_quantization: 62 | print("Loading T2I pipeline (Juggernaut) with 8-bit quantization to save VRAM...") 63 | bnb_config = BitsAndBytesConfig( 64 | load_in_8bit=True, 65 | ) 66 | # --- START OF FIX: Remove device_map and use .to(DEVICE) instead --- 67 | # This prevents the accelerate hook conflict when loading IP-Adapters later. 
68 | self.pipe = StableDiffusionXLPipeline.from_pretrained( 69 | self.config.model_id, 70 | quantization_config=bnb_config, 71 | torch_dtype=torch.float16, 72 | variant="fp16", 73 | use_safetensors=True, 74 | ).to(DEVICE) 75 | # --- END OF FIX --- 76 | else: 77 | print(f"Loading T2I pipeline (Juggernaut) in full precision to {DEVICE}...") 78 | self.pipe = StableDiffusionXLPipeline.from_pretrained( 79 | self.config.model_id, torch_dtype=torch.float16, variant="fp16", use_safetensors=True 80 | ).to(DEVICE) 81 | 82 | self.pipe.scheduler = JuggernautScheduler.from_config(self.pipe.scheduler.config, use_karras_sigmas=True) 83 | print(f"Juggernaut pipeline configured with {JuggernautScheduler.__name__} sampler.") 84 | 85 | if self.config.refiner_id: 86 | print(f"Refiner specified but not typically used with Juggernaut, skipping load.") 87 | 88 | def clear_vram(self): 89 | print("Clearing T2I (Juggernaut) VRAM...") 90 | models = [m for m in [self.pipe, self.refiner_pipe] if m is not None] 91 | if models: clear_vram_globally(*models) 92 | self.pipe, self.refiner_pipe = None, None 93 | self._loaded_ip_adapter_count = 0 94 | print("T2I (Juggernaut) VRAM cleared.") 95 | 96 | def generate_image(self, prompt: str, negative_prompt: str, output_path: str, width: int, height: int, ip_adapter_image: Optional[Union[str, List[str]]] = None, seed: int = -1) -> str: 97 | self._load_pipeline() 98 | 99 | generator = None 100 | if seed != -1: 101 | print(f"Using fixed seed for generation: {seed}") 102 | generator = torch.Generator(device=self.pipe.device).manual_seed(seed) 103 | else: 104 | print("Using random seed for generation.") 105 | 106 | pipeline_kwargs = {"generator": generator} if generator else {} 107 | ip_images_to_load = [] 108 | 109 | if ip_adapter_image: 110 | if isinstance(ip_adapter_image, str): 111 | ip_images_to_load = [ip_adapter_image] 112 | else: 113 | ip_images_to_load = ip_adapter_image 114 | 115 | num_ip_images = len(ip_images_to_load) 116 | 117 | if num_ip_images > 0: 118 | print(f"Juggernaut T2I: Activating IP-Adapter with {num_ip_images} character image(s).") 119 | if self._loaded_ip_adapter_count != num_ip_images: 120 | print(f"Loading {num_ip_images} IP-Adapter(s) for the pipeline...") 121 | if hasattr(self.pipe, "unload_ip_adapter"): self.pipe.unload_ip_adapter() 122 | adapter_weights = [self.config.ip_adapter_weight_name] * num_ip_images 123 | self.pipe.load_ip_adapter( 124 | self.config.ip_adapter_repo, 125 | subfolder=self.config.ip_adapter_subfolder, 126 | weight_name=adapter_weights 127 | ) 128 | self._loaded_ip_adapter_count = num_ip_images 129 | print(f"Successfully loaded {self._loaded_ip_adapter_count} adapters.") 130 | 131 | scales = [0.6] * num_ip_images 132 | self.pipe.set_ip_adapter_scale(scales) 133 | ip_images = [load_image(p) for p in ip_images_to_load] 134 | pipeline_kwargs["ip_adapter_image"] = ip_images 135 | else: 136 | print("Juggernaut T2I: No IP-Adapter image provided.") 137 | if self._loaded_ip_adapter_count > 0: 138 | if hasattr(self.pipe, "unload_ip_adapter"): self.pipe.unload_ip_adapter() 139 | self._loaded_ip_adapter_count = 0 140 | 141 | enhanced_prompt = self.enhance_prompt(prompt) 142 | print(f"Juggernaut generating image with resolution: {width}x{height}") 143 | print(f" - Prompt: '{enhanced_prompt}'") 144 | print(f" - Negative: '{negative_prompt}'") 145 | 146 | image = self.pipe( 147 | prompt=enhanced_prompt, 148 | negative_prompt=negative_prompt, 149 | width=width, 150 | height=height, 151 | num_inference_steps=self.config.num_inference_steps, 152 | 
guidance_scale=self.config.guidance_scale, 153 | **pipeline_kwargs 154 | ).images[0] 155 | 156 | image.save(output_path) 157 | print(f"Image saved to {output_path}") 158 | return output_path -------------------------------------------------------------------------------- /ui_task_executor.py: -------------------------------------------------------------------------------- 1 | # In ui_task_executor.py 2 | 3 | import streamlit as st 4 | from task_executor import TaskExecutor 5 | from config_manager import ContentConfig 6 | import logging 7 | from typing import List, Optional, Any 8 | import os 9 | from utils import load_and_correct_image_orientation 10 | 11 | logger = logging.getLogger(__name__) 12 | 13 | class UITaskExecutor: 14 | """Handles task execution triggered from the Streamlit UI, providing user feedback.""" 15 | 16 | def __init__(self, project_manager): 17 | self.project_manager = project_manager 18 | self.task_executor: Optional[TaskExecutor] = None 19 | self._initialize_task_executor() 20 | 21 | def _initialize_task_executor(self): 22 | if not self.project_manager.state: 23 | st.error("Cannot initialize task executor: Project state not found.") 24 | return 25 | try: 26 | self.task_executor = TaskExecutor(self.project_manager) 27 | except Exception as e: 28 | logger.error(f"Failed to initialize TaskExecutor: {e}", exc_info=True) 29 | st.error(f"Configuration Error: {e}") 30 | 31 | def update_narration_text(self, scene_idx: int, text: str): 32 | self.project_manager.update_narration_part_text(scene_idx, text) 33 | 34 | def update_shot_prompts(self, scene_idx: int, shot_idx: int, visual_prompt: Optional[str] = None, motion_prompt: Optional[str] = None): 35 | self.project_manager.update_shot_content(scene_idx, shot_idx, visual_prompt, motion_prompt) 36 | 37 | def regenerate_audio(self, scene_idx: int, text: str, speaker_audio: Optional[str] = None) -> bool: 38 | if not self.task_executor: return False 39 | self.project_manager.update_narration_part_text(scene_idx, text) 40 | task_data = {"scene_idx": scene_idx, "text": text, "speaker_wav": speaker_audio if speaker_audio and os.path.exists(speaker_audio) else None} 41 | success = self.task_executor.execute_task("generate_audio", task_data) 42 | if success: st.toast(f"Audio for Scene {scene_idx + 1} generated!", icon="🔊") 43 | else: st.error(f"Failed to generate audio for Scene {scene_idx + 1}.") 44 | self.project_manager.load_project() 45 | return success 46 | 47 | def create_scene(self, scene_idx: int) -> bool: 48 | if not self.task_executor: return False 49 | success = self.task_executor.execute_task("create_scene", {"scene_idx": scene_idx}) 50 | if success: st.toast(f"Scene {scene_idx + 1} shots created!", icon="🎬") 51 | else: st.error(f"Failed to create shots for Scene {scene_idx + 1}.") 52 | self.project_manager.load_project() 53 | return success 54 | 55 | # --- NEW METHOD --- 56 | def regenerate_scene_shots(self, scene_idx: int) -> bool: 57 | """Resets a scene and triggers the 'create_scene' task to regenerate shots.""" 58 | if not self.task_executor: return False 59 | 60 | # First, reset the scene, clearing old shots and assets 61 | self.project_manager.reset_scene_for_shot_regeneration(scene_idx) 62 | st.toast(f"Cleared old shots for Scene {scene_idx + 1}. 
Regenerating...", icon="♻️") 63 | 64 | # Now, execute the create_scene task which will find the scene missing and create it 65 | success = self.task_executor.execute_task("create_scene", {"scene_idx": scene_idx}) 66 | 67 | if success: 68 | st.toast(f"New shots for Scene {scene_idx + 1} generated!", icon="✨") 69 | else: 70 | st.error(f"Failed to regenerate shots for Scene {scene_idx + 1}.") 71 | 72 | self.project_manager.load_project() 73 | return success 74 | 75 | def regenerate_shot_image(self, scene_idx: int, shot_idx: int) -> bool: 76 | if not self.task_executor: return False 77 | self.project_manager.update_shot_content(scene_idx, shot_idx) 78 | shot = self.project_manager.get_scene_info(scene_idx).shots[shot_idx] 79 | task_data = {"scene_idx": scene_idx, "shot_idx": shot_idx, "visual_prompt": shot.visual_prompt} 80 | success = self.task_executor.execute_task("generate_shot_image", task_data) 81 | if success: st.toast(f"Image for Shot {shot_idx + 1} generated!", icon="🖼️") 82 | else: st.error(f"Failed to generate image for Shot {shot_idx + 1}.") 83 | self.project_manager.load_project() 84 | return success 85 | 86 | def regenerate_shot_video(self, scene_idx: int, shot_idx: int) -> bool: 87 | if not self.task_executor: return False 88 | self.project_manager.update_shot_content(scene_idx, shot_idx) 89 | shot = self.project_manager.get_scene_info(scene_idx).shots[shot_idx] 90 | task_data = { 91 | "scene_idx": scene_idx, "shot_idx": shot_idx, 92 | "visual_prompt": shot.visual_prompt, 93 | "motion_prompt": shot.motion_prompt 94 | } 95 | success = self.task_executor.execute_task("generate_shot_video", task_data) 96 | if success: st.toast(f"Video for Shot {shot_idx + 1} generated!", icon="📹") 97 | else: st.error(f"Failed to generate video for Shot {shot_idx + 1}.") 98 | self.project_manager.load_project() 99 | return success 100 | 101 | def regenerate_shot_t2v(self, scene_idx: int, shot_idx: int) -> bool: 102 | if not self.task_executor: return False 103 | self.project_manager.update_shot_content(scene_idx, shot_idx) 104 | shot = self.project_manager.get_scene_info(scene_idx).shots[shot_idx] 105 | task_data = {"scene_idx": scene_idx, "shot_idx": shot_idx, "visual_prompt": shot.visual_prompt} 106 | success = self.task_executor.execute_task("generate_shot_t2v", task_data) 107 | if success: st.toast(f"T2V Shot {shot_idx + 1} generated!", icon="📹") 108 | else: st.error(f"Failed to generate T2V Shot {shot_idx + 1}.") 109 | self.project_manager.load_project() 110 | return success 111 | 112 | def assemble_final_video(self) -> bool: 113 | if not self.task_executor: return False 114 | success = self.task_executor.execute_task("assemble_final", {}) 115 | if success: st.toast("Final video assembled successfully!", icon="🏆") 116 | else: st.error("Failed to assemble final video.") 117 | self.project_manager.load_project() 118 | return success 119 | 120 | def add_character(self, name: str, image_file: "UploadedFile"): 121 | if not self.project_manager.state: return False 122 | safe_name = name.replace(" ", "_") 123 | char_dir = os.path.join(self.project_manager.output_dir, "characters", safe_name) 124 | os.makedirs(char_dir, exist_ok=True) 125 | ref_image_path = os.path.join(char_dir, "reference.png") 126 | 127 | corrected_image = load_and_correct_image_orientation(image_file) 128 | if corrected_image: 129 | corrected_image.save(ref_image_path, "PNG") 130 | char_data = {"name": name, "reference_image_path": ref_image_path} 131 | self.project_manager.add_character(char_data) 132 | st.toast(f"Character 
'{name}' added!", icon="👤") 133 | return True 134 | else: 135 | st.error(f"Could not process image for new character {name}. Aborting.") 136 | return False 137 | 138 | def update_character(self, old_name: str, new_name: str, new_image_file: Optional["UploadedFile"]): 139 | ref_image_path = None 140 | if new_image_file: 141 | safe_name = (new_name or old_name).replace(" ", "_") 142 | char_dir = os.path.join(self.project_manager.output_dir, "characters", safe_name) 143 | os.makedirs(char_dir, exist_ok=True) 144 | ref_image_path = os.path.join(char_dir, "reference.png") 145 | 146 | corrected_image = load_and_correct_image_orientation(new_image_file) 147 | if corrected_image: 148 | corrected_image.save(ref_image_path, "PNG") 149 | else: 150 | st.error("Failed to process the new image. Character image was not updated.") 151 | ref_image_path = None 152 | 153 | self.project_manager.update_character(old_name, new_name, ref_image_path) 154 | st.toast(f"Character '{old_name}' updated!", icon="✏️") 155 | return True 156 | 157 | def delete_character(self, name: str): 158 | self.project_manager.delete_character(name) 159 | st.toast(f"Character '{name}' deleted!", icon="🗑️") 160 | return True 161 | 162 | def update_project_config(self, key: str, value: Any): 163 | """UI wrapper to update a specific project configuration value.""" 164 | self.project_manager.update_config_value(key, value) 165 | st.toast(f"Setting '{key.replace('_', ' ').title()}' updated.") 166 | st.rerun() 167 | 168 | def update_scene_characters(self, scene_idx: int, character_names: List[str]): 169 | self.project_manager.update_scene_characters(scene_idx, character_names) 170 | st.toast(f"Characters for Scene {scene_idx+1} updated.", icon="🎬") 171 | 172 | def add_new_scene(self, scene_idx: int): 173 | """UI wrapper to add a new scene.""" 174 | self.project_manager.add_new_scene_at(scene_idx) 175 | st.toast(f"New scene added at position {scene_idx + 1}!", icon="➕") 176 | return True 177 | 178 | def remove_scene(self, scene_idx: int): 179 | """UI wrapper to remove a scene.""" 180 | self.project_manager.remove_scene_at(scene_idx) 181 | st.toast(f"Scene {scene_idx + 1} removed!", icon="🗑️") 182 | return True -------------------------------------------------------------------------------- /llm_modules/llm_zephyr.py: -------------------------------------------------------------------------------- 1 | # llm_modules/llm_zephyr.py 2 | import torch 3 | import json 4 | import re 5 | from typing import List, Optional, Tuple, Dict, Any 6 | from transformers import AutoModelForCausalLM, AutoTokenizer 7 | 8 | from base_modules import BaseLLM, BaseModuleConfig, ModuleCapabilities 9 | from config_manager import ContentConfig, DEVICE, clear_vram_globally 10 | 11 | class ZephyrLLMConfig(BaseModuleConfig): 12 | model_id: str = "HuggingFaceH4/zephyr-7b-beta" 13 | max_new_tokens_script: int = 2048 # Increased for new fields 14 | max_new_tokens_shot_prompt: int = 256 15 | temperature: float = 0.7 16 | top_k: int = 50 17 | top_p: float = 0.95 18 | 19 | class ZephyrLLM(BaseLLM): 20 | Config = ZephyrLLMConfig 21 | 22 | @classmethod 23 | def get_capabilities(cls) -> ModuleCapabilities: 24 | return ModuleCapabilities( 25 | title="Zephyr 7B", 26 | vram_gb_min=8.0, 27 | ram_gb_min=16.0, 28 | # LLM-specific capabilities are not the main focus, so we use defaults. 
29 | ) 30 | 31 | def _load_model_and_tokenizer(self): 32 | if self.model is None or self.tokenizer is None: 33 | print(f"Loading LLM: {self.config.model_id}...") 34 | self.tokenizer = AutoTokenizer.from_pretrained(self.config.model_id) 35 | if self.tokenizer.pad_token is None: 36 | self.tokenizer.pad_token = self.tokenizer.eos_token 37 | 38 | try: 39 | self.model = AutoModelForCausalLM.from_pretrained( 40 | self.config.model_id, torch_dtype=torch.float16 41 | ).to(DEVICE) 42 | except Exception as e: 43 | print(f"Failed to load LLM with device_map='auto' ({e}), trying with explicit device: {DEVICE}") 44 | self.model = AutoModelForCausalLM.from_pretrained( 45 | self.config.model_id, torch_dtype=torch.float16 46 | ).to(DEVICE) 47 | print("LLM loaded.") 48 | 49 | def clear_vram(self): 50 | print("Clearing LLM VRAM...") 51 | models_to_clear = [m for m in [self.model] if m is not None] 52 | if models_to_clear: clear_vram_globally(*models_to_clear) 53 | self.model, self.tokenizer = None, None 54 | print("LLM VRAM cleared.") 55 | 56 | def _parse_llm_json_response(self, decoded_output: str, context: str = "script") -> Optional[Dict]: 57 | match = re.search(r'\{[\s\S]*\}', decoded_output) 58 | json_text = match.group(0) if match else decoded_output 59 | try: 60 | return json.loads(re.sub(r',(\s*[}\]])', r'\1', json_text)) 61 | except json.JSONDecodeError as e: 62 | print(f"Error parsing LLM JSON for {context}: {e}. Raw output:\n{decoded_output}") 63 | return None 64 | 65 | def generate_script(self, topic: str, content_config: ContentConfig) -> Dict[str, Any]: 66 | self._load_model_and_tokenizer() 67 | print(f"Generating script for topic: '{topic}' in language: {content_config.language}") 68 | 69 | # --- MODIFICATION START --- 70 | # Map language code to full name for better prompting 71 | language_map = { 72 | 'en': 'English', 'es': 'Spanish', 'fr': 'French', 73 | 'de': 'German', 'it': 'Italian', 'pt': 'Portuguese', 74 | 'pl': 'Polish', 'tr': 'Turkish', 'ru': 'Russian', 75 | 'nl': 'Dutch', 'cs': 'Czech', 'ar': 'Arabic', 76 | 'zh-cn': 'Chinese (Simplified)', 'ja': 'Japanese', 77 | 'hu': 'Hungarian', 'ko': 'Korean', 'hi': 'Hindi' 78 | } 79 | target_language = language_map.get(content_config.language, 'English') 80 | 81 | system_prompt = ( 82 | "You are a multilingual AI assistant creating content for a short video. " 83 | "You will be asked to write the narration in a specific language, but all other content (visual prompts, descriptions, hashtags) must be in English for the video generation models. " 84 | "Your response must be a single, valid JSON object with these exact keys: " 85 | "\"main_subject_description\", \"setting_description\", \"narration\", \"visuals\", \"hashtags\"." 86 | ) 87 | 88 | user_prompt = f""" 89 | **IMPORTANT INSTRUCTIONS:** 90 | 1. The **"narration"** text MUST be written in **{target_language}**. Use the native script if applicable (e.g., Devanagari for Hindi). 91 | 2. Use proper punctuation (like commas and periods) in the narration for a natural-sounding voiceover. 92 | 3. All other fields ("main_subject_description", "setting_description", "visuals", "hashtags") MUST remain in **English**. 93 | 94 | --- 95 | Create content for a short video about "{topic}". 96 | The total narration should be ~{content_config.target_video_length_hint}s, with {content_config.min_scenes} to {content_config.max_scenes} scenes. 97 | Each scene's narration should be ~{content_config.max_scene_narration_duration_hint}s. 
98 | 99 | Return your response in this exact JSON format: 100 | {{ 101 | "main_subject_description": "A detailed, consistent description of the main character or subject (e.g., 'Fluffy, a chubby but cute orange tabby cat with green eyes'). MUST BE IN ENGLISH.", 102 | "setting_description": "A description of the primary environment (e.g., 'a cozy, sunlit living room with plush furniture'). MUST BE IN ENGLISH.", 103 | "narration": [ 104 | {{"scene": 1, "text": "First scene narration text, written in {target_language}.", "duration_estimate": {content_config.max_scene_narration_duration_hint}}} 105 | ], 106 | "visuals": [ 107 | {{"scene": 1, "prompt": "Detailed visual prompt for scene 1. MUST BE IN ENGLISH."}} 108 | ], 109 | "hashtags": ["relevantTag1", "relevantTag2"] 110 | }} 111 | """ 112 | # --- MODIFICATION END --- 113 | 114 | messages = [{"role": "system", "content": system_prompt}, {"role": "user", "content": user_prompt}] 115 | 116 | for attempt in range(3): 117 | print(f"Attempt {attempt + 1} of 3 to generate valid script JSON...") 118 | 119 | tokenized_chat = self.tokenizer.apply_chat_template(messages, tokenize=True, add_generation_prompt=True, return_tensors="pt").to(self.model.device) 120 | outputs = self.model.generate( 121 | input_ids=tokenized_chat, max_new_tokens=self.config.max_new_tokens_script, 122 | do_sample=True, top_k=self.config.top_k, top_p=self.config.top_p, 123 | temperature=self.config.temperature, pad_token_id=self.tokenizer.eos_token_id 124 | ) 125 | decoded_output = self.tokenizer.decode(outputs[0][tokenized_chat.shape[-1]:], skip_special_tokens=True) 126 | response_data = self._parse_llm_json_response(decoded_output, "script") 127 | 128 | if response_data and all(k in response_data for k in ["narration", "visuals", "main_subject_description"]): 129 | print("Successfully generated and parsed valid script JSON.") 130 | return { 131 | "main_subject_description": response_data.get("main_subject_description"), 132 | "setting_description": response_data.get("setting_description"), 133 | "narration": sorted(response_data.get("narration", []), key=lambda x: x["scene"]), 134 | "visuals": [p["prompt"] for p in sorted(response_data.get("visuals", []), key=lambda x: x["scene"])], 135 | "hashtags": response_data.get("hashtags", []) 136 | } 137 | else: 138 | print(f"Attempt {attempt + 1} failed. The response was not a valid JSON or was missing required keys.") 139 | if attempt < 2: 140 | print("Retrying...") 141 | 142 | print("LLM script generation failed after 3 attempts. Using fallback.") 143 | # Fallback remains in English as a safe default 144 | return { 145 | "main_subject_description": topic, "setting_description": "a simple background", 146 | "narration": [{"text": f"An intro to {topic}.", "duration_estimate": 5.0}], 147 | "visuals": [f"Cinematic overview of {topic}."], "hashtags": [f"#{topic.replace(' ', '')}"] 148 | } 149 | 150 | def generate_shot_visual_prompts(self, scene_narration: str, original_scene_prompt: str, num_shots: int, content_config: ContentConfig, main_subject: str, setting: str) -> List[Tuple[str, str]]: 151 | self._load_model_and_tokenizer() 152 | shot_prompts = [] 153 | 154 | # Define the prompts, which are the same for each shot generation call 155 | system_prompt = ( 156 | "You are an Movie director. Your task is to generate a 'visual_prompt' and a 'motion_prompt' for a short video shot " 157 | "The prompts MUST incorporate the provided main subject and setting. Do NOT change the subject. 
" 158 | "Respond in this exact JSON format: {\"visual_prompt\": \"...\", \"motion_prompt\": \"...\"}" 159 | ) 160 | 161 | for shot_idx in range(num_shots): 162 | print(f"--- Generating prompts for Shot {shot_idx + 1}/{num_shots} ---") 163 | 164 | # --- NEW: Defensive check to prevent intermittent crashes --- 165 | # This handles rare cases where the model/tokenizer might be cleared from memory 166 | # between calls within the same task execution. 167 | if self.model is None or self.tokenizer is None: 168 | print("WARNING: LLM was unloaded unexpectedly. Forcing a reload before generating shot prompt.") 169 | self._load_model_and_tokenizer() 170 | 171 | user_prompt = f""" 172 | **Main Subject (MUST BE INCLUDED):** {main_subject} 173 | **Setting (MUST BE INCLUDED):** {setting} 174 | 175 | --- 176 | **Original Scene Goal:** "{original_scene_prompt}" 177 | **This Shot's Narration:** "{scene_narration}" 178 | 179 | Based on ALL the information above, create a visual and motion prompt for shot {shot_idx + 1}/{num_shots}. 180 | The visual prompt should be a specific, detailed moment consistent with the subject and setting. 181 | try to describe the visual prompt in minimum words but in very specific details what a director would want the image to look like. 182 | Descrive character, subject and envrionment in words, only chose important words no need to make complete sentances. 183 | try to describe the visual prompt in minimum words but in very specific details what a director would want the image to look like. 184 | Descrive character, subject and envrionment in words, only chose important words no need to make complete sentances. 185 | Also descirbe camera mm, shot type, location, lighting, color, mood, etc. 186 | Do not include any other text or comments other then given json format. 187 | """ 188 | messages = [{"role": "system", "content": system_prompt}, {"role": "user", "content": user_prompt}] 189 | 190 | visual_prompt, motion_prompt = None, None 191 | 192 | # --- MODIFICATION START: Add retry loop for each shot --- 193 | for attempt in range(3): 194 | print(f"Attempt {attempt + 1} of 3 to generate valid prompt JSON for shot {shot_idx + 1}...") 195 | 196 | tokenized_chat = self.tokenizer.apply_chat_template(messages, tokenize=True, add_generation_prompt=True, return_tensors="pt").to(self.model.device) 197 | outputs = self.model.generate( 198 | input_ids=tokenized_chat, max_new_tokens=self.config.max_new_tokens_shot_prompt, 199 | do_sample=True, temperature=self.config.temperature, pad_token_id=self.tokenizer.eos_token_id 200 | ) 201 | decoded_output = self.tokenizer.decode(outputs[0][tokenized_chat.shape[-1]:], skip_special_tokens=True) 202 | response_data = self._parse_llm_json_response(decoded_output, f"shot {shot_idx+1} prompt") 203 | 204 | # Check for a dictionary with both required string keys 205 | if (isinstance(response_data, dict) and 206 | isinstance(response_data.get("visual_prompt"), str) and 207 | isinstance(response_data.get("motion_prompt"), str)): 208 | 209 | visual_prompt = response_data["visual_prompt"] 210 | motion_prompt = response_data["motion_prompt"] 211 | print(f"Successfully generated and parsed prompts for shot {shot_idx + 1}.") 212 | break # Exit the retry loop on success 213 | else: 214 | print(f"Attempt {attempt + 1} failed for shot {shot_idx + 1}. 
Invalid JSON or missing keys.") 215 | # --- MODIFICATION END --- 216 | 217 | # If after 3 attempts, we still don't have prompts, use the fallback 218 | if not visual_prompt or not motion_prompt: 219 | print(f"All attempts failed for shot {shot_idx + 1}. Using fallback prompts.") 220 | visual_prompt = f"{main_subject} in {setting}, {original_scene_prompt}" 221 | motion_prompt = "gentle camera movement" 222 | 223 | shot_prompts.append((visual_prompt, motion_prompt)) 224 | print(f" > Shot {shot_idx+1} Visual: \"{visual_prompt[:80]}...\"") 225 | print(f" > Shot {shot_idx+1} Motion: \"{motion_prompt[:80]}...\"") 226 | 227 | return shot_prompts -------------------------------------------------------------------------------- /task_executor.py: -------------------------------------------------------------------------------- 1 | # In task_executor.py 2 | import logging 3 | import math 4 | import os 5 | import random 6 | from typing import Optional, Dict 7 | import torch 8 | from importlib import import_module 9 | 10 | from project_manager import ProjectManager, STATUS_IMAGE_GENERATED, STATUS_VIDEO_GENERATED, STATUS_FAILED 11 | from config_manager import ContentConfig 12 | from video_assembly import assemble_final_reel, assemble_scene_video_from_sub_clips 13 | from base_modules import ModuleCapabilities 14 | 15 | logger = logging.getLogger(__name__) 16 | 17 | def _import_class(module_path_str: str): 18 | module_path, class_name = module_path_str.rsplit('.', 1) 19 | module = import_module(module_path) 20 | return getattr(module, class_name) 21 | 22 | class TaskExecutor: 23 | def __init__(self, project_manager: ProjectManager): 24 | self.project_manager = project_manager 25 | self.content_cfg = ContentConfig(**self.project_manager.state.project_info.config) 26 | 27 | module_selections = self.content_cfg.module_selections 28 | if not module_selections: 29 | raise ValueError("Project state is missing module selections. 
Cannot initialize TaskExecutor.") 30 | 31 | # --- START OF FIX: Use .get() for safe module loading to prevent crashes --- 32 | 33 | # LLM and TTS are always required 34 | LlmClass = _import_class(module_selections["llm"]) 35 | self.llm_module = LlmClass(LlmClass.Config()) 36 | 37 | TtsClass = _import_class(module_selections["tts"]) 38 | self.tts_module = TtsClass(TtsClass.Config()) 39 | 40 | # Video modules are optional depending on the flow 41 | self.t2i_module = None 42 | self.i2v_module = None 43 | self.t2v_module = None 44 | 45 | t2i_path = module_selections.get("t2i") 46 | if t2i_path: 47 | T2iClass = _import_class(t2i_path) 48 | self.t2i_module = T2iClass(T2iClass.Config()) 49 | 50 | i2v_path = module_selections.get("i2v") 51 | if i2v_path: 52 | I2vClass = _import_class(i2v_path) 53 | self.i2v_module = I2vClass(I2vClass.Config()) 54 | 55 | t2v_path = module_selections.get("t2v") 56 | if t2v_path: 57 | T2vClass = _import_class(t2v_path) 58 | self.t2v_module = T2vClass(T2vClass.Config()) 59 | 60 | # Determine capabilities based on which modules were actually loaded 61 | self.active_flow_supports_characters = False 62 | if self.content_cfg.use_svd_flow and self.t2i_module: 63 | t2i_caps = self.t2i_module.get_capabilities() 64 | self.active_flow_supports_characters = t2i_caps.supports_ip_adapter 65 | logger.info("Decisive module for character support: T2I module.") 66 | elif not self.content_cfg.use_svd_flow and self.t2v_module: 67 | t2v_caps = self.t2v_module.get_capabilities() 68 | self.active_flow_supports_characters = t2v_caps.supports_ip_adapter 69 | logger.info("Decisive module for character support: T2V module.") 70 | # --- END OF FIX --- 71 | 72 | logger.info(f"Holistic check: Active flow supports characters: {self.active_flow_supports_characters}") 73 | self._configure_from_model_capabilities() 74 | 75 | def _configure_from_model_capabilities(self): 76 | logger.info("--- TaskExecutor: Configuring run from model capabilities... ---") 77 | if self.content_cfg.use_svd_flow: 78 | if self.t2i_module and self.i2v_module: 79 | t2i_caps = self.t2i_module.get_model_capabilities() 80 | i2v_caps = self.i2v_module.get_model_capabilities() 81 | self.content_cfg.generation_resolution = t2i_caps["resolutions"].get(self.content_cfg.aspect_ratio_format) 82 | self.content_cfg.model_max_video_shot_duration = i2v_caps.get("max_shot_duration", 3.0) 83 | else: 84 | logger.warning("Warning: T2I or I2V module not loaded for I2V flow. Using default configurations.") 85 | else: # T2V Flow 86 | if self.t2v_module: 87 | t2v_caps = self.t2v_module.get_model_capabilities() 88 | self.content_cfg.generation_resolution = t2v_caps["resolutions"].get(self.content_cfg.aspect_ratio_format) 89 | self.content_cfg.model_max_video_shot_duration = t2v_caps.get("max_shot_duration", 2.0) 90 | else: 91 | logger.warning("Warning: T2V module not loaded for T2V flow. 
Using default configurations.") 92 | 93 | logger.info(f"Dynamically set Generation Resolution to: {self.content_cfg.generation_resolution}") 94 | logger.info(f"Dynamically set Max Shot Duration to: {self.content_cfg.model_max_video_shot_duration}s") 95 | self.project_manager.state.project_info.config = self.content_cfg.model_dump() 96 | self.project_manager._save_state() 97 | 98 | def execute_task(self, task: str, task_data: Dict) -> bool: 99 | try: 100 | # --- START OF FIX: Refresh config before every task to prevent stale state --- 101 | self.content_cfg = ContentConfig(**self.project_manager.state.project_info.config) 102 | logger.info(f"Executing task '{task}' with add_narration_text set to: {self.content_cfg.add_narration_text_to_video}") 103 | # --- END OF FIX --- 104 | 105 | task_map = { 106 | "generate_script": self._execute_generate_script, "generate_audio": self._execute_generate_audio, 107 | "create_scene": self._execute_create_scene, "generate_shot_image": self._execute_generate_shot_image, 108 | "generate_shot_video": self._execute_generate_shot_video, "generate_shot_t2v": self._execute_generate_shot_t2v, 109 | "assemble_scene": self._execute_assemble_scene, "assemble_final": self._execute_assemble_final, 110 | } 111 | if task in task_map: return task_map[task](**task_data) 112 | logger.error(f"Unknown task: {task}"); return False 113 | except Exception as e: 114 | logger.error(f"Error executing task {task}: {e}", exc_info=True); return False 115 | 116 | def _execute_generate_script(self, topic: str) -> bool: 117 | script_data = self.llm_module.generate_script(topic, self.content_cfg) 118 | self.llm_module.clear_vram() 119 | self.project_manager.update_script(script_data) 120 | return True 121 | 122 | def _execute_generate_audio(self, scene_idx: int, text: str, speaker_wav: Optional[str] = None) -> bool: 123 | path, duration = self.tts_module.generate_audio(text, self.content_cfg.output_dir, scene_idx, language=self.content_cfg.language, speaker_wav=speaker_wav) 124 | self.project_manager.update_narration_part_status(scene_idx, "generated", path, duration if duration > 0.1 else 0.0) 125 | return True 126 | 127 | def _execute_create_scene(self, scene_idx: int) -> bool: 128 | narration = self.project_manager.state.script.narration_parts[scene_idx] 129 | visual_prompt = self.project_manager.state.script.visual_prompts[scene_idx] 130 | main_subject = self.project_manager.state.script.main_subject_description 131 | setting = self.project_manager.state.script.setting_description 132 | 133 | actual_audio_duration = narration.duration 134 | max_shot_duration = self.content_cfg.model_max_video_shot_duration 135 | 136 | if actual_audio_duration <= 0 or max_shot_duration <= 0: 137 | num_shots = 1 138 | logger.warning(f"Warning: Invalid duration detected for Scene {scene_idx} (Audio: {actual_audio_duration}s, Max Shot: {max_shot_duration}s). 
Defaulting to 1 shot.") 139 | else: 140 | num_shots = math.ceil(actual_audio_duration / max_shot_duration) or 1 141 | 142 | logger.info("--- Calculating Shots for Scene {} ---".format(scene_idx)) 143 | logger.info(f" - Actual Audio Duration: {actual_audio_duration:.2f}s") 144 | logger.info(f" - Model's Max Shot Duration: {max_shot_duration:.2f}s") 145 | logger.info(f" - Calculated Number of Shots: {num_shots} ({actual_audio_duration:.2f}s / {max_shot_duration:.2f}s)") 146 | 147 | shot_prompts = self.llm_module.generate_shot_visual_prompts( 148 | narration.text, visual_prompt.prompt, num_shots, self.content_cfg, main_subject, setting 149 | ) 150 | self.llm_module.clear_vram() 151 | 152 | shots = [] 153 | for i, (visual, motion) in enumerate(shot_prompts): 154 | if i < num_shots - 1: 155 | duration = max_shot_duration 156 | else: 157 | duration = actual_audio_duration - (i * max_shot_duration) 158 | 159 | shots.append({"shot_idx": i, "target_duration": max(0.5, duration), "visual_prompt": visual, "motion_prompt": motion}) 160 | 161 | all_character_names = [char.name for char in self.project_manager.state.characters] 162 | logger.info(f"Creating Scene {scene_idx} and assigning default characters: {all_character_names}") 163 | self.project_manager.add_scene(scene_idx, shots, character_names=all_character_names) 164 | return True 165 | 166 | def _execute_generate_shot_image(self, scene_idx: int, shot_idx: int, visual_prompt: str, **kwargs) -> bool: 167 | if not self.t2i_module: 168 | logger.error("Attempted to generate image, but T2I module is not loaded for this workflow.") 169 | return False 170 | w, h = self.content_cfg.generation_resolution 171 | path = os.path.join(self.content_cfg.output_dir, f"scene_{scene_idx}_shot_{shot_idx}_keyframe.png") 172 | 173 | base_seed = self.content_cfg.seed 174 | shot_seed = random.randint(0, 2**32 - 1) if base_seed == -1 else base_seed + scene_idx * 100 + shot_idx 175 | 176 | negative_prompt = "worst quality, low quality, bad anatomy, text, watermark, jpeg artifacts, blurry" 177 | 178 | scene = self.project_manager.get_scene_info(scene_idx) 179 | ip_adapter_image_paths = [] 180 | if scene and scene.character_names: 181 | logger.info(f"Found characters for Scene {scene_idx}: {scene.character_names}") 182 | for name in scene.character_names: 183 | char = self.project_manager.get_character(name) 184 | if char and os.path.exists(char.reference_image_path): 185 | ip_adapter_image_paths.append(char.reference_image_path) 186 | 187 | self.t2i_module.generate_image( 188 | prompt=visual_prompt, negative_prompt=negative_prompt, output_path=path, 189 | width=w, height=h, ip_adapter_image=ip_adapter_image_paths or None, seed=shot_seed 190 | ) 191 | 192 | self.project_manager.update_shot_status(scene_idx, shot_idx, STATUS_IMAGE_GENERATED, keyframe_path=path) 193 | self.t2i_module.clear_vram() 194 | return True 195 | 196 | def _execute_generate_shot_video(self, scene_idx: int, shot_idx: int, visual_prompt: str, motion_prompt: Optional[str], **kwargs) -> bool: 197 | if not self.i2v_module: 198 | logger.error("Attempted to generate video from image, but I2V module is not loaded for this workflow.") 199 | return False 200 | shot = self.project_manager.get_scene_info(scene_idx).shots[shot_idx] 201 | if not shot.keyframe_image_path or not os.path.exists(shot.keyframe_image_path): return False 202 | 203 | enhanced_visual = self.i2v_module.enhance_prompt(visual_prompt, "visual") 204 | enhanced_motion = self.i2v_module.enhance_prompt(motion_prompt, "motion") 205 | 206 | scene = 
self.project_manager.get_scene_info(scene_idx) 207 | ip_adapter_image_paths = [self.project_manager.get_character(name).reference_image_path for name in scene.character_names if self.project_manager.get_character(name)] 208 | 209 | video_path = os.path.join(self.content_cfg.output_dir, f"scene_{scene_idx}_shot_{shot_idx}_svd.mp4") 210 | 211 | sub_clip_path = self.i2v_module.generate_video_from_image( 212 | image_path=shot.keyframe_image_path, output_video_path=video_path, target_duration=shot.target_duration, 213 | content_config=self.content_cfg, visual_prompt=enhanced_visual, motion_prompt=enhanced_motion, 214 | ip_adapter_image=ip_adapter_image_paths or None 215 | ) 216 | 217 | if sub_clip_path and os.path.exists(sub_clip_path): 218 | self.project_manager.update_shot_status(scene_idx, shot_idx, STATUS_VIDEO_GENERATED, video_path=sub_clip_path) 219 | return True 220 | self.project_manager.update_shot_status(scene_idx, shot_idx, STATUS_FAILED); return False 221 | 222 | def _execute_generate_shot_t2v(self, scene_idx: int, shot_idx: int, visual_prompt: str, **kwargs) -> bool: 223 | if not self.t2v_module: 224 | logger.error("Attempted to generate video from text, but T2V module is not loaded for this workflow.") 225 | return False 226 | shot = self.project_manager.get_scene_info(scene_idx).shots[shot_idx] 227 | num_frames = int(shot.target_duration * self.content_cfg.fps) 228 | w, h = self.content_cfg.generation_resolution 229 | 230 | enhanced_prompt = self.t2v_module.enhance_prompt(visual_prompt) 231 | 232 | scene = self.project_manager.get_scene_info(scene_idx) 233 | ip_adapter_image_paths = [self.project_manager.get_character(name).reference_image_path for name in scene.character_names if self.project_manager.get_character(name)] 234 | 235 | video_path = os.path.join(self.content_cfg.output_dir, f"scene_{scene_idx}_shot_{shot_idx}_t2v.mp4") 236 | 237 | sub_clip_path = self.t2v_module.generate_video_from_text( 238 | enhanced_prompt, video_path, num_frames, self.content_cfg.fps, w, h, 239 | ip_adapter_image=ip_adapter_image_paths or None 240 | ) 241 | 242 | if sub_clip_path and os.path.exists(sub_clip_path): 243 | self.project_manager.update_shot_status(scene_idx, shot_idx, STATUS_VIDEO_GENERATED, video_path=sub_clip_path) 244 | return True 245 | self.project_manager.update_shot_status(scene_idx, shot_idx, STATUS_FAILED); return False 246 | 247 | def _execute_assemble_scene(self, scene_idx: int, **kwargs) -> bool: 248 | scene = self.project_manager.get_scene_info(scene_idx) 249 | if not scene: return False 250 | video_paths = [c.video_path for c in scene.shots if c.status == STATUS_VIDEO_GENERATED] 251 | if len(video_paths) != len(scene.shots): return False 252 | 253 | narration_duration = self.project_manager.state.script.narration_parts[scene_idx].duration 254 | final_path = assemble_scene_video_from_sub_clips(video_paths, narration_duration, self.content_cfg, scene_idx) 255 | 256 | if final_path: 257 | self.project_manager.update_scene_status(scene_idx, "completed", assembled_video_path=final_path) 258 | return True 259 | self.project_manager.update_scene_status(scene_idx, "failed"); return False 260 | 261 | def _execute_assemble_final(self, **kwargs) -> bool: 262 | narration_parts = self.project_manager.state.script.narration_parts 263 | assets = [ 264 | (s.assembled_video_path, narration_parts[s.scene_idx].audio_path, {"text": narration_parts[s.scene_idx].text, "duration": narration_parts[s.scene_idx].duration}) 265 | for s in self.project_manager.state.scenes if s.status == 
"completed" 266 | ] 267 | if len(assets) != len(self.project_manager.state.scenes): return False 268 | 269 | topic = self.project_manager.state.project_info.topic 270 | final_path = assemble_final_reel(assets, self.content_cfg, output_filename=f"{topic.replace(' ','_')}_final.mp4") 271 | 272 | if final_path and os.path.exists(final_path): 273 | text = " ".join([a[2]["text"] for a in assets]) 274 | hashtags = self.project_manager.state.script.hashtags 275 | self.project_manager.update_final_video(final_path, "generated", text, hashtags) 276 | return True 277 | self.project_manager.update_final_video("", "pending", "", []); return False -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Modular AI Video Generation Pipeline 2 | 3 | [](https://www.python.org/) 4 | [](https://streamlit.io) 5 | [](LICENSE) 6 | 7 | ## ⚠️ Important Notes 8 | 9 | **Video Quality Issues**: If your generated videos appear scrambled or distorted, this typically means you're not using the optimal video dimensions that the selected model was trained on. Each AI model has specific resolution requirements for best results. Check the model documentation for recommended dimensions and adjust your video settings accordingly. 10 | 11 | **Contributors Welcome!** 🚀 This project is open to contributions from the community. If you're interested in helping improve this pipeline, adding new models, or fixing bugs, please feel free to submit pull requests or open issues. 12 | 13 | **New Project Announcement**: I've started working on a completely separate and different video generation project. If you're interested in learning more or collaborating, feel free to reach out to me on [LinkedIn](https://www.linkedin.com/in/gowravvishwakarma/)! 14 | 15 | --- 16 | 17 | An extensible, modular pipeline for generating short-form videos using a variety of AI models. This tool provides a powerful Streamlit-based web interface to define a video topic, select different AI models for each generation step (language, speech, image, video), and orchestrate the entire content creation process from script to final rendered video. 18 | 19 | ## 🎥 Demo Video 20 | 21 |
24 |
25 | Watch the full demo on YouTube
26 |
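If you want to know which resolutions a selected module was built for (see the video-quality note above), you can query its capability contract directly. The snippet below is a minimal sketch, assuming the repository root is on your `PYTHONPATH`; it only reads the declared spec via `get_capabilities()` / `get_model_capabilities()` and does not download or load any model weights:

```python
# Minimal sketch: inspect a module's declared spec before configuring a project.
from t2v_modules import LtxT2V

caps = LtxT2V.get_capabilities()        # static spec sheet (VRAM/RAM, formats, ...)
print(caps.title, caps.supported_formats)

module = LtxT2V(LtxT2V.Config())        # default config; no weights are loaded yet
print(module.get_model_capabilities()["resolutions"])
# -> {'Portrait': (512, 768), 'Landscape': (768, 512)}
```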