├── .dockerignore
├── .github
│   └── workflows
│       └── python-app.yml
├── .gitignore
├── .pre-commit-config.yaml
├── Dockerfile
├── LICENSE
├── README.md
├── ShortsMaker
│   ├── __init__.py
│   ├── ask_llm.py
│   ├── generate_image.py
│   ├── moviepy_create_video.py
│   ├── shorts_maker.py
│   └── utils
│       ├── __init__.py
│       ├── audio_transcript.py
│       ├── colors_dict.py
│       ├── download_youtube_music.py
│       ├── download_youtube_video.py
│       ├── get_tts.py
│       ├── logging_config.py
│       ├── notify_discord.py
│       └── retry.py
├── assets
│   ├── credits
│   │   ├── credits.mp4
│   │   └── credits_mask.mp4
│   └── fonts
│       ├── Monaco.ttf
│       └── Roboto.ttf
├── example.py
├── example.setup.yml
├── pyproject.toml
├── requirements.txt
├── tests
│   ├── __init__.py
│   ├── ask_llm_tests
│   │   ├── test_ask_llm.py
│   │   └── test_ollama_service.py
│   ├── conftest.py
│   ├── data
│   │   ├── setup.yml
│   │   ├── test.txt
│   │   ├── test.wav
│   │   └── transcript.json
│   ├── generate_image_tests
│   │   └── test_generate_image.py
│   ├── moviepy_create_video_tests
│   │   └── test_moviepy_create_video.py
│   ├── shorts_maker_tests
│   │   ├── test_abbreviation_replacer.py
│   │   ├── test_fix_text.py
│   │   ├── test_has_alpha_and_digit.py
│   │   ├── test_shorts_maker.py
│   │   └── test_split_alpha_and_digit.py
│   └── utils_tests
│       ├── test_audio_transcript.py
│       ├── test_colors_dict.py
│       ├── test_download_youtube_music.py
│       ├── test_download_youtube_video.py
│       ├── test_get_tts.py
│       ├── test_logging_config.py
│       ├── test_notify_discord.py
│       └── test_retry.py
└── uv.lock

/.dockerignore:
--------------------------------------------------------------------------------
1 | .venv
2 | old_scripts
3 |
4 | # remove any cache
5 | cache
6 | assets
7 |
8 | # ignore cache files
9 | *.pyc
10 | *.pyo
11 |
12 | # ignore logs
13 | *.log
14 |
15 | # ignore pytest cache
16 | .pytest_cache
17 |
18 | # IDE caches
19 | .vscode
20 | .idea
21 |
22 | # ignore ipynb files and checkpoints, which may be used for testing
23 | *.ipynb
24 | .ipynb_checkpoints
25 | .ruff_cache
26 |
--------------------------------------------------------------------------------
/.github/workflows/python-app.yml:
--------------------------------------------------------------------------------
1 | name: ShortsMaker
2 |
3 | on:
4 |   push:
5 |     branches: [ "main" ]
6 |
7 | jobs:
8 |   test:
9 |     permissions:
10 |       contents: read
11 |       pull-requests: write
12 |     runs-on: ubuntu-latest
13 |     env:
14 |       DISCORD_WEBHOOK_URL: "https://discord.com/api/webhooks/xxx"
15 |     steps:
16 |       - name: Setup Java
17 |         uses: actions/setup-java@v4
18 |         with:
19 |           distribution: temurin
20 |           java-version: 21
21 |       - name: Setup FFmpeg
22 |         uses: federicocarboni/setup-ffmpeg@v3.1
23 |       - name: Checkout the repository
24 |         uses: actions/checkout@main
25 |       - name: Install the latest version of uv and set the python version
26 |         uses: astral-sh/setup-uv@v5
27 |         with:
28 |           version: "latest"
29 |           enable-cache: true
30 |           pyproject-file: "pyproject.toml"
31 |       - name: Install dependencies
32 |         run: uv sync --frozen --extra cpu
33 |       - name: Run tests
34 |         run: uv run --frozen pytest
35 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | .venv
2 | old_scripts
3 |
4 | # remove any media files
5 | assets
6 | cache
7 |
8 | # config files
9 | setup.yml
10 |
11 | # ignore cache files
12 | *.pyc
13 | *.pyo
14 |
15 | # ignore logs
16 | *.log
17 |
18 | # ignore pytest cache
19 | .pytest_cache
20 |
21 | # IDE caches
22 | .vscode
23 | .idea
24 |
25 | # Ignore any media files
26 | *.mp4
27 | *.avi
28 | *.wav
29 | *.mp3
30 |
31 | # ignore ipynb files and checkpoints, which may be used for testing
32 | *.ipynb
33 | .ipynb_checkpoints
34 | .ruff_cache
35 |
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | # See https://pre-commit.com for more information
2 | # See https://pre-commit.com/hooks.html for more hooks
3 | repos:
4 |   - repo: https://github.com/astral-sh/ruff-pre-commit
5 |     # Ruff version.
6 |     rev: v0.9.4
7 |     hooks:
8 |       # Run the linter.
9 |       - id: ruff
10 |       # Run the formatter.
11 |       - id: ruff-format
12 |   - repo: https://github.com/pre-commit/pre-commit-hooks
13 |     rev: v5.0.0
14 |     hooks:
15 |       - id: trailing-whitespace
16 |       - id: end-of-file-fixer
17 |       - id: check-yaml
18 |       - id: debug-statements
19 |         language_version: python3
20 |       - id: detect-private-key
21 |   - repo: https://github.com/Lucas-C/pre-commit-hooks
22 |     rev: v1.5.5
23 |     hooks:
24 |       - id: forbid-crlf
25 |       - id: remove-crlf
26 |       - id: forbid-tabs
27 |       - id: remove-tabs
28 |   - repo: https://github.com/astral-sh/uv-pre-commit
29 |     rev: 0.5.29
30 |     hooks:
31 |       - id: uv-lock
32 |       - id: uv-export
33 |         exclude: .venv
34 |
--------------------------------------------------------------------------------
/Dockerfile:
--------------------------------------------------------------------------------
1 | # Use a Python 3.12 base image
2 | FROM python:3.12-slim-bookworm
3 | COPY --from=ghcr.io/astral-sh/uv:0.6.6 /uv /bin/
4 |
5 | # Set environment variables
6 | # Set this appropriately or leave empty if not using Discord
7 | ENV DISCORD_WEBHOOK_URL="https://discord.com/api/webhooks/xxxxxx"
8 |
9 | # Install ffmpeg and Java
10 | # Install dependencies for moviepy
11 | RUN apt-get update -y && \
12 |     apt-get install -y --no-install-recommends ffmpeg openjdk-17-jre locales && \
13 |     apt-get clean && \
14 |     rm -rf /var/lib/apt/lists/* && \
15 |     locale-gen C.UTF-8 && \
16 |     /usr/sbin/update-locale LANG=C.UTF-8
17 |
18 | ENV LC_ALL=C.UTF-8
19 |
20 | # Copy the project into the image
21 | ADD . /shorts_maker
22 |
23 | RUN cd /shorts_maker && \
24 |     uv sync --frozen
25 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # ShortsMaker
2 |
3 | [![ShortsMaker](https://github.com/rajathjn/shorts_maker/actions/workflows/python-app.yml/badge.svg)](https://github.com/rajathjn/shorts_maker/actions/workflows/python-app.yml)
4 | [![CodeQL](https://github.com/rajathjn/shorts_maker/actions/workflows/github-code-scanning/codeql/badge.svg)](https://github.com/rajathjn/shorts_maker/actions/workflows/github-code-scanning/codeql)
5 |
6 | ShortsMaker is a Python package designed to facilitate the creation of engaging short videos or social media clips. It leverages a variety of external services and libraries to streamline the process of generating, processing, and uploading short content.
7 |
8 | ## Support Me
9 | Like what I do? Please consider supporting me.
10 |
11 | Coindrop.to me
12 |
13 | ## Table of Contents
14 |
15 | - [ShortsMaker](#shortsmaker)
16 |   - [Support Me](#support-me)
17 |   - [Table of Contents](#table-of-contents)
18 |   - [Features](#features)
19 |   - [Requirements](#requirements)
20 |   - [Usage Via Docker](#usage-via-docker)
21 |   - [Installation](#installation)
22 |   - [External Dependencies](#external-dependencies)
23 |   - [Environment Variables](#environment-variables)
24 |   - [Usage](#usage)
25 |   - [Example Video](#example-video)
26 |   - [TODO](#todo)
27 |   - [Development](#development)
28 |   - [Contributing](#contributing)
29 |   - [License](#license)
30 |
31 | ## Features
32 |
33 | - **Automated Content Creation:** Easily generate engaging short videos.
34 | - **External Service Integration:** Seamlessly integrates with services like Discord for notifications.
35 | - **GPU-Accelerated Processing:** Optional GPU support for faster processing using whisperx.
36 | - **Modular Design:** Built with extensibility in mind.
37 | - **AskLLM (in development):** an AI agent for generating video metadata and creative insights.
38 | - **GenerateImage (in development):** text-to-image generation using Flux; can be resource-intensive.
39 |
40 | ## Requirements
41 |
42 | - **Python:** 3.12.8
43 | - **Package Manager:** [`uv`](https://docs.astral.sh/uv/) is used for package management. (It's amazing, try it out!)
44 | - **Operating System:** Windows, macOS, or Linux (ensure external dependencies are installed for your platform)
45 |
46 | ## Usage Via Docker
47 |
48 | To use ShortsMaker via Docker, follow these steps:
49 |
50 | 1. **Build the Docker Image:**
51 |
52 |    Build the Docker image using the provided Dockerfile.
53 |
54 |    ```bash
55 |    docker build -t shorts_maker -f Dockerfile .
56 |    ```
57 |
58 | 2. **Run the Docker Container:**
59 |
60 |    For the first time, run the container with the necessary mounts, container name, and working directory set.
61 |
62 |    ```bash
63 |    docker run --name shorts_maker_container -v $(pwd)/assets:/shorts_maker/assets -w /shorts_maker -it shorts_maker bash
64 |    ```
65 |
66 | 3. **Start the Docker Container:**
67 |
68 |    If the container was previously stopped, you can start it again using:
69 |
70 |    ```bash
71 |    docker start shorts_maker_container
72 |    ```
73 |
74 | 4. **Access the Docker Container:**
75 |
76 |    Execute a bash shell inside the running container.
77 |
78 |    ```bash
79 |    docker exec -it shorts_maker_container bash
80 |    ```
81 |
82 | 5. **Run Examples and Tests:**
83 |
84 |    Once you are in the bash shell of the container, you can run the example script or tests using `uv`.
85 |
86 |    To run the example script:
87 |
88 |    ```bash
89 |    uv run example.py
90 |    ```
91 |
92 |    To run tests:
93 |
94 |    ```bash
95 |    uv run pytest
96 |    ```
97 |
98 | **Note:** If you plan to use `ask_llm` or `generate_image`, it is not recommended to use the Docker image due to the high resource requirements of these features. Instead, run ShortsMaker directly on your host machine.
99 |
100 | ## Installation
101 |
102 | 1. **Clone the Repository:**
103 |
104 |    ```bash
105 |    git clone https://github.com/rajathjn/shorts_maker
106 |    cd shorts_maker
107 |    ```
108 |
109 | 2. **Install the Package Using uv:**
110 |
111 |    Note: Before starting the installation, ensure a Python 3.12 virtual environment is set up.
112 |
113 |    ```bash
114 |    uv venv -p 3.12 .venv
115 |
116 |    or
117 |
118 |    python -m venv .venv
119 |    ```
120 |
121 |    Package installation:
122 |
123 |    ```bash
124 |    uv pip install -r pyproject.toml
125 |
126 |    or
127 |
128 |    uv sync
129 |    uv sync --extra cpu # for cpu
130 |    uv sync --extra cu124 # for cuda 12.4 versions
131 |    ```
132 |
133 | 3. **Install Any Additional Python Dependencies:**
134 |
135 |    If any dependencies are not automatically managed by uv, you can install them with pip (in most cases this is not needed):
136 |
137 |    ```bash
138 |    pip install -r requirements.txt
139 |    ```
140 |
141 | ## External Dependencies
142 |
143 | ShortsMaker relies on several external non-Python components. Please ensure the following are installed/configured on your system:
144 |
145 | - **Discord Notifications:**
146 |   - You must set your Discord webhook URL (`DISCORD_WEBHOOK_URL`) as an environment variable.
147 |   - Refer to the [Discord documentation](https://discord.com/developers/docs/resources/webhook#create-webhook) for creating a webhook.
148 |   - If you don't want to use Discord notifications, you can set `DISCORD_WEBHOOK_URL` to `None` or do something like
149 |
150 |     ```python
151 |     import os
152 |     os.environ["DISCORD_WEBHOOK_URL"] = "None"
153 |     ```
154 |
155 | - **Ollama:**
156 |   - The external tool Ollama must be installed on your system. Refer to the [Ollama documentation](https://ollama.com/) for installation details.
157 |
158 | - **WhisperX (GPU Acceleration):**
159 |   - For GPU execution, ensure that the NVIDIA libraries are installed on your system:
160 |     - **cuBLAS:** Version 11.x
161 |     - **cuDNN:** Version 8.x
162 |   - These libraries are required for optimal performance when using whisperx for processing.
163 |
164 | ## Environment Variables
165 |
166 | Before running ShortsMaker, make sure you set the necessary environment variables:
167 |
168 | - **DISCORD_WEBHOOK_URL:**
169 |   This webhook URL is required for sending notifications through Discord.
170 |   Example (Windows Command Prompt):
171 |
172 |   ```cmd
173 |   set DISCORD_WEBHOOK_URL=your_discord_webhook_url_here
174 |   ```
175 |
176 |   Example (Linux/macOS):
177 |
178 |   ```bash
179 |   export DISCORD_WEBHOOK_URL=your_discord_webhook_url_here
180 |   ```
181 |
182 |   From Python:
183 |
184 |   ```python
185 |   import os
186 |   os.environ["DISCORD_WEBHOOK_URL"] = "your_discord_webhook_url_here"
187 |   ```
188 |
189 | ## Usage
190 |
191 | Ensure you have a `setup.yml` configuration file in the `shorts_maker` directory. Use the [example.setup.yml](example.setup.yml) as a reference.
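
For orientation, here is a minimal, illustrative sketch of the keys that `example.py` and the video builder read; the values are placeholders, so treat [example.setup.yml](example.setup.yml) as the authoritative reference:

```yaml
# Illustrative sketch only; copy example.setup.yml for the real defaults.
cache_dir: cache    # working files: audio, scripts, transcripts
assets_dir: assets  # background videos, music, fonts, credits
reddit_post_getter:
  record_file_txt: post.txt
audio:
  output_audio_file: output.wav
  output_script_file: script.txt
  transcript_json: transcript.json
video:
  background_videos_urls:
    - https://www.youtube.com/watch?v=<background-video-id>
  background_music_urls:
    - https://www.youtube.com/watch?v=<background-music-id>
```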
192 |
193 | Below is a basic example to get you started with ShortsMaker:
194 |
195 | You can also refer to the same [here](example.py)
196 |
197 | ```python
198 | from pathlib import Path
199 |
200 | import yaml
201 |
202 | from ShortsMaker import MoviepyCreateVideo, ShortsMaker
203 |
204 | setup_file = "setup.yml"
205 |
206 | with open(setup_file) as f:
207 |     cfg = yaml.safe_load(f)
208 |
209 | get_post = ShortsMaker(setup_file)
210 |
211 | # You can either provide a URL for the reddit post
212 | get_post.get_reddit_post(
213 |     url="https://www.reddit.com/r/Python/comments/1j36d7a/i_got_tired_of_ai_shorts_scams_so_i_built_my_own/"
214 | )
215 | # Or just run the method to get a random post from the subreddit defined in setup.yml
216 | # get_post.get_reddit_post()
217 |
218 | with open(Path(cfg["cache_dir"]) / cfg["reddit_post_getter"]["record_file_txt"]) as f:
219 |     script = f.read()
220 |
221 | get_post.generate_audio(
222 |     source_txt=script,
223 |     output_audio=f"{cfg['cache_dir']}/{cfg['audio']['output_audio_file']}",
224 |     output_script_file=f"{cfg['cache_dir']}/{cfg['audio']['output_script_file']}",
225 | )
226 |
227 | get_post.generate_audio_transcript(
228 |     source_audio_file=f"{cfg['cache_dir']}/{cfg['audio']['output_audio_file']}",
229 |     source_text_file=f"{cfg['cache_dir']}/{cfg['audio']['output_script_file']}",
230 | )
231 |
232 | get_post.quit()
233 |
234 | create_video = MoviepyCreateVideo(
235 |     config_file=setup_file,
236 |     speed_factor=1.0,
237 | )
238 |
239 | create_video(output_path="assets/output.mp4")
240 |
241 | create_video.quit()
242 |
243 | # Do not run the below when you are using shorts_maker within a container.
244 |
245 | # ask_llm = AskLLM(config_file=setup_file)
246 | # result = ask_llm.invoke(script)
247 | # print(result["parsed"].title)
248 | # print(result["parsed"].description)
249 | # print(result["parsed"].tags)
250 | # print(result["parsed"].thumbnail_description)
251 | # ask_llm.quit_llm()
252 |
253 | # You can also use AskLLM to generate a text prompt for the image generation
254 | # image_description = ask_llm.invoke_image_describer(script=script, input_text="A wild scenario")
255 | # print(image_description)
256 | # print(image_description["parsed"].description)
257 |
258 | # GenerateImage uses a lot of resources, so beware
259 | # generate_image = GenerateImage(config_file=setup_file)
260 | # generate_image.use_huggingface_flux_schnell(image_description["parsed"].description, "output.png")
261 | # generate_image.quit()
262 |
263 | ```
264 |
265 | ## Example Video
266 |
267 | Generated from this post [here](https://www.reddit.com/r/selfhosted/comments/r2a6og/comment/hm5xoas/?utm_source=share&utm_medium=web3x&utm_name=web3xcss&utm_term=1&utm_content=share_button)
268 |
269 | https://github.com/user-attachments/assets/6aad212a-bfd5-4161-a2bc-67d24a8de37f
270 |
271 | ## TODO
272 | - [ ] Explain working and usage in a blog post.
273 | - [x] Dockerize the project to avoid the complex setup process.
274 | - [x] Add option to fetch posts from submission URLs.
275 | - [x] Add an example video to the README.
276 |
277 | ## Development
278 |
279 | If you want to contribute to the project, please follow these steps:
280 |
281 | 1. **Set up the development environment:**
282 |    - Ensure you have Python 3.12.8 and uv installed.
283 |    - Clone the repository and install the development dependencies.
284 |
285 | 2. **Run the Tests:**
286 |    - Tests are located in the `tests/` directory.
287 |    - Run tests using:
288 |
289 |      ```bash
290 |      uv run --frozen pytest
291 |      ```
292 |
293 | ## Contributing
294 |
295 | If you want to contribute to the project, please follow these steps:
296 |
297 | First, follow everything in the [Development](#development) section, then:
298 |
299 | **Submit a Pull Request:**
300 | - Fork the repository.
301 | - Create a new branch for your feature or bugfix.
302 | - Commit your changes and push the branch to your fork.
303 | - Open a pull request with a detailed description of your changes.
304 |
305 |
306 | ## License
307 |
308 | This project is licensed under the GNU General Public License v3.0. See the [LICENSE](LICENSE) file for details.
309 |
--------------------------------------------------------------------------------
/ShortsMaker/__init__.py:
--------------------------------------------------------------------------------
1 | from .ask_llm import AskLLM, OllamaServiceManager
2 | from .generate_image import GenerateImage
3 | from .moviepy_create_video import MoviepyCreateVideo, VideoConfig
4 | from .shorts_maker import (
5 |     ShortsMaker,
6 |     abbreviation_replacer,
7 |     has_alpha_and_digit,
8 |     split_alpha_and_digit,
9 | )
10 | from .utils.audio_transcript import (
11 |     align_transcript_with_script,
12 |     generate_audio_transcription,
13 | )
14 | from .utils.colors_dict import COLORS_DICT
15 | from .utils.download_youtube_music import download_youtube_music, sanitize_filename
16 | from .utils.download_youtube_video import download_youtube_video
17 | from .utils.get_tts import VOICES, tts
18 | from .utils.logging_config import configure_logging, get_logger
19 | from .utils.notify_discord import notify_discord
20 | from .utils.retry import retry
21 |
22 | __all__ = [
23 |     "GenerateImage",
24 |     "ShortsMaker",
25 |     "AskLLM",
26 |     "OllamaServiceManager",
27 |     "MoviepyCreateVideo",
28 |     "VideoConfig",
29 |     "configure_logging",
30 |     "get_logger",
31 |     "download_youtube_music",
32 |     "download_youtube_video",
33 |     "abbreviation_replacer",
34 |     "COLORS_DICT",
35 |     "has_alpha_and_digit",
36 |     "split_alpha_and_digit",
37 |     "align_transcript_with_script",
38 |     "generate_audio_transcription",
39 |     "notify_discord",
40 |     "retry",
41 |     "sanitize_filename",
42 |     "VOICES",
43 |     "tts",
44 | ]
45 |
--------------------------------------------------------------------------------
/ShortsMaker/ask_llm.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import platform
3 | import subprocess
4 | import time
5 | from pathlib import Path
6 |
7 | import ollama
8 | import psutil
9 | import yaml
10 | from langchain_core.messages import HumanMessage, SystemMessage
11 | from langchain_core.prompts import ChatPromptTemplate
12 | from langchain_ollama import ChatOllama
13 | from pydantic import BaseModel, Field
14 |
15 | from .utils import get_logger
16 |
17 |
18 | class OllamaServiceManager:
19 |     """
20 |     Manages the Ollama service, including starting, stopping, and checking its status.
21 |
22 |     Attributes:
23 |         system (str): The operating system the service is running on (e.g., "windows", "linux").
24 |         process (subprocess.Popen | None): The process object for the running Ollama service.
25 |         ollama (module): The ollama module.
26 |         logger (logging.Logger): The logger instance for the class.
27 |     """
28 |
29 |     def __init__(
30 |         self,
31 |         logger: logging.Logger | None = None,
32 |     ):
33 |         """
34 |         Initializes the OllamaServiceManager.
35 |
36 |         Args:
37 |             logger (logging.Logger | None, optional): An optional logger instance. If not provided, a default logger is created. Defaults to None.
38 |         """
39 |         self.system = platform.system().lower()
40 |         self.process = None
41 |         self.ollama = ollama
42 |
43 |         self.logger = logger if logger else logging.getLogger(__name__)
44 |         self.logger.name = "OllamaServiceManager"
45 |         self.logger.info(f"OllamaServiceManager initialized on {self.system}")
46 |
47 |     def start_service(self) -> bool:
48 |         """
49 |         Starts the Ollama service.
50 |
51 |         Returns:
52 |             bool: True if the service started successfully, False otherwise.
53 |
54 |         Raises:
55 |             Exception: If there is an error starting the service.
56 |         """
57 |         self.logger.info("Starting Ollama service")
58 |         try:
59 |             if self.system == "windows":
60 |                 ollama_execution_command = ["ollama app.exe", "serve"]
61 |             else:
62 |                 ollama_execution_command = ["ollama", "serve"]
63 |             self.process = subprocess.Popen(
64 |                 ollama_execution_command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True
65 |             )
66 |
67 |             # Wait a moment for the service to start
68 |             time.sleep(2)
69 |
70 |             # Check if the service started successfully
71 |             if self.process.poll() is None:
72 |                 self.logger.info("Ollama service started successfully")
73 |                 return True
74 |         except Exception as e:
75 |             self.logger.error(f"Error starting Ollama service: {str(e)}")
76 |             raise e
77 |         return False
78 |
79 |     def stop_service(self) -> bool:
80 |         """
81 |         Stops the Ollama service.
82 |
83 |         Returns:
84 |             bool: True if the service stopped successfully, False otherwise.
85 |         """
86 |         try:
87 |             if self.process:
88 |                 if self.system == "windows":
89 |                     # For Windows
90 |                     subprocess.run(
91 |                         ["taskkill", "/F", "/IM", "ollama app.exe"], capture_output=True, text=True
92 |                     )
93 |                     subprocess.run(
94 |                         ["taskkill", "/F", "/IM", "ollama.exe"], capture_output=True, text=True
95 |                     )
96 |                 else:
97 |                     # For Linux/MacOS
98 |                     self.process.terminate()
99 |                     self.process.wait(timeout=5)
100 |                 del self.process
101 |                 self.process = None
102 |                 self.logger.info("Ollama service stopped successfully")
103 |                 return True
104 |             else:
105 |                 self.logger.warning("Ollama service either started by user or has already stopped")
106 |                 return False
107 |
108 |         except Exception as e:
109 |             self.logger.error(f"Error stopping Ollama service: {str(e)}")
110 |             return False
111 |
112 |     @staticmethod
113 |     def is_ollama_running():
114 |         """
115 |         Checks if any Ollama process is currently running.
116 |
117 |         Returns:
118 |             bool: True if an Ollama process is running, False otherwise.
119 |         """
120 |         for proc in psutil.process_iter(["name", "cmdline"]):
121 |             try:
122 |                 if "ollama" in proc.info["name"].lower():
123 |                     return True
124 |             except (psutil.NoSuchProcess, psutil.AccessDenied):
125 |                 continue
126 |         return False
127 |
128 |     def is_service_running(self) -> bool:
129 |         """
130 |         Checks if the Ollama service managed by this instance is running.
131 |
132 |         Returns:
133 |             bool: True if the service is running, False otherwise.
134 |         """
135 |         if self.process and self.process.poll() is None:
136 |             return True
137 |         if self.is_ollama_running():
138 |             return True
139 |         return False
140 |
141 |     def get_running_models(self):
142 |         """
143 |         Gets a list of the currently running models on Ollama.
144 |
145 |         Returns:
146 |             list: The models running in the Ollama service.
147 |         """
148 |         return self.ollama.ps()
149 |
150 |     def stop_running_model(self, model_name: str):
151 |         """
152 |         Stops a specific model that is running in Ollama.
153 |
154 |         Args:
155 |             model_name (str): The name of the model to stop.
156 | 157 | Returns: 158 | bool: True if the model was stopped successfully, False otherwise. 159 | """ 160 | try: 161 | stop_attempt = subprocess.check_output( 162 | ["ollama", "stop", model_name], stderr=subprocess.STDOUT, text=True 163 | ) 164 | self.logger.info( 165 | f"Ollama service with {model_name} stopped successfully: {stop_attempt}" 166 | ) 167 | return True 168 | except Exception as e: 169 | self.logger.warning(f"Failed to stop {model_name}") 170 | self.logger.warning( 171 | "Either the model was already stopped, or the Ollama service is already stopped" 172 | ) 173 | self.logger.error(f"Error stopping Ollama service: {str(e)}") 174 | return False 175 | 176 | def get_llm_model(self, model_name: str): 177 | """ 178 | Downloads a specific LLM model using Ollama. 179 | 180 | Args: 181 | model_name (str): The name of the model to download. 182 | """ 183 | return self.ollama.pull(model_name) 184 | 185 | def get_list_of_downloaded_files(self) -> list[str]: 186 | """ 187 | Retrieves a list of models that have been downloaded by Ollama. 188 | 189 | Returns: 190 | list[str]: A list of the names of the downloaded models. 191 | """ 192 | model_list = [] 193 | try: 194 | model_list = list(self.ollama.list()) # list[tuple[model, list[models]]] 195 | model_list = [i.model for i in model_list[0][1]] 196 | except Exception as e: 197 | self.logger.error(f"Error getting list of downloaded files: {str(e)}") 198 | return model_list 199 | 200 | 201 | class AskLLM: 202 | """ 203 | A class to interact with a Large Language Model (LLM) using Ollama. 204 | 205 | This class handles loading, querying, and managing the LLM, including 206 | starting and stopping the Ollama service if necessary. 207 | 208 | Attributes: 209 | setup_cfg (Path): The path to the configuration file. 210 | cfg (dict): The configuration loaded from the config file. 211 | logger (logging.Logger): The logger instance for the class. 212 | self_started_ollama (bool): Indicates if the instance started the ollama service. 213 | ollama_service_manager (OllamaServiceManager): The manager for the ollama service. 214 | model_name (str): The name of the LLM model to use. 215 | model_temperature (float): The temperature parameter for the LLM. 216 | llm (ChatOllama): The ChatOllama instance. 217 | structured_llm (None | ChatOllama): The structured llm model used for invoke. 218 | """ 219 | 220 | def __init__( 221 | self, 222 | config_file: Path | str, 223 | model_name: str = "llama3.1:latest", 224 | temperature: float = 0, 225 | ) -> None: 226 | """ 227 | Initializes the AskLLM instance. 228 | 229 | Args: 230 | config_file (Path | str): The path to the configuration file. 231 | model_name (str, optional): The name of the LLM model to use. Defaults to "llama3.1:latest". 232 | temperature (float, optional): The temperature parameter for the LLM. Defaults to 0. 233 | 234 | Raises: 235 | FileNotFoundError: If the configuration file does not exist. 236 | ValueError: If the configuration file is not a YAML file. 
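
        Example:
            A minimal usage sketch, mirroring the commented-out snippet in
            example.py (the config path and script text are illustrative):

                ask_llm = AskLLM(config_file="setup.yml")
                result = ask_llm.invoke("your script text")
                print(result["parsed"].title)
                ask_llm.quit_llm()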
237 | """ 238 | # if config_file is str convert it to a Pathlike 239 | self.setup_cfg = Path(config_file) if isinstance(config_file, str) else config_file 240 | 241 | if not self.setup_cfg.exists(): 242 | raise FileNotFoundError(f"File {str(self.setup_cfg)} does not exist") 243 | 244 | if self.setup_cfg.suffix != ".yml": 245 | raise ValueError(f"File {str(self.setup_cfg)} is not a yaml file") 246 | 247 | # load the yml file 248 | with open(self.setup_cfg) as f: 249 | self.cfg = yaml.safe_load(f) 250 | 251 | self.logger = get_logger(__name__) 252 | 253 | self.self_started_ollama: bool = False 254 | self.ollama_service_manager = OllamaServiceManager(logger=self.logger) 255 | self.model_name = model_name 256 | self.model_temperature = temperature 257 | self.llm: ChatOllama = self._load_llm_model(self.model_name, self.model_temperature) 258 | self.structured_llm = None 259 | 260 | def _load_llm_model(self, model_name: str, temperature: float) -> ChatOllama: 261 | """ 262 | Loads the specified LLM model. 263 | 264 | Starts the Ollama service if it's not already running, and downloads the model if it's not already downloaded. 265 | 266 | Args: 267 | model_name (str): The name of the model to load. 268 | temperature (float): The temperature parameter for the LLM. 269 | 270 | Returns: 271 | ChatOllama: The loaded ChatOllama instance. 272 | """ 273 | if not self.ollama_service_manager.is_service_running(): 274 | self.logger.warning("Ollama service is not running. Attempting to start it.") 275 | self.ollama_service_manager.start_service() 276 | self.self_started_ollama = True 277 | self.logger.warning(f"Self started ollama service: {self.self_started_ollama}") 278 | self.logger.info("Ollama service found") 279 | 280 | if model_name not in self.ollama_service_manager.get_list_of_downloaded_files(): 281 | self.logger.info(f"Downloading model {model_name}") 282 | self.ollama_service_manager.get_llm_model(model_name) 283 | self.logger.info(f"Model {model_name} downloaded") 284 | else: 285 | self.logger.info(f"Model {model_name} already downloaded") 286 | 287 | return ChatOllama(model=model_name, temperature=temperature) 288 | 289 | def invoke(self, input_text: str) -> dict | BaseModel: 290 | """ 291 | Invokes the LLM with the given input text and returns a structured output. 292 | 293 | This method uses a predefined prompt to query the LLM and expects a response that conforms to the YoutubeDetails schema. 294 | 295 | Args: 296 | input_text (str): The input text to send to the LLM. 297 | 298 | Returns: 299 | dict | BaseModel: A dictionary or a BaseModel instance containing the LLM's response. 300 | """ 301 | prompt = ChatPromptTemplate( 302 | messages=[ 303 | SystemMessage( 304 | "You are a Youtubers digital assistant. Please provide creative, engaging, clickbait, key word rich and accurate information to the user." 305 | ), 306 | SystemMessage( 307 | "The Youtuber which runs an AI automated Youtube channel. The entire process involves me finding a script, making a video about it, and then using an AI image creator to make a thumbnail for the video." 308 | ), 309 | SystemMessage( 310 | "Be short and concise. Be articulate, no need to be verbose and justify your answer." 
311 |                 ),
312 |                 HumanMessage(f"Script:\n{input_text}"),
313 |             ],
314 |         )
315 |         self.structured_llm = self.llm.with_structured_output(YoutubeDetails, include_raw=True)
316 |         return self.structured_llm.invoke(prompt.messages)
317 |
318 |     def invoke_image_describer(self, script: str, input_text: str) -> dict | BaseModel:
319 |         """
320 |         Invokes the LLM to generate an image description based on the given script and input text.
321 |
322 |         This method uses a predefined prompt to query the LLM and expects a response that conforms to the ImageDescriber schema.
323 |
324 |         Args:
325 |             script (str): The script to base the image description on.
326 |             input_text (str): Additional text to guide the image description.
327 |
328 |         Returns:
329 |             dict | BaseModel: A dictionary or a BaseModel instance containing the LLM's image description.
330 |         """
331 |         prompt = ChatPromptTemplate(
332 |             messages=[
333 |                 SystemMessage(
334 |                     "You are an AI image prompt generator, who specializes in image description. Helping users to create AI image prompts."
335 |                 ),
336 |                 SystemMessage(
337 |                     "The user provides the complete script and the text to generate the prompt for. You should provide a detailed and creative description of an image. Note: Avoid mentioning names or text titles to be in the description. The more detailed and imaginative your description, the more interesting the resulting image will be."
338 |                 ),
339 |                 SystemMessage("Keep the description to 500 characters or less."),
340 |                 HumanMessage(f"Script:\n{script}"),
341 |                 HumanMessage(f"Text:\n{input_text}"),
342 |             ]
343 |         )
344 |         self.structured_llm = self.llm.with_structured_output(ImageDescriber, include_raw=True)
345 |         return self.structured_llm.invoke(prompt.messages)
346 |
347 |     def quit_llm(self):
348 |         """
349 |         Shuts down the LLM and the Ollama service if it was started by this instance.
350 |
351 |         This method stops any running models, and if the Ollama service was started by this
352 |         instance, it stops the service as well. Finally, it deletes the instance variables to
353 |         clean up.
354 |         """
355 |         self.ollama_service_manager.stop_running_model(self.model_name)
356 |         if self.self_started_ollama:
357 |             self.ollama_service_manager.stop_service()
358 |         # Delete all instance variables
359 |         for attr in list(self.__dict__.keys()):
360 |             try:
361 |                 self.logger.debug(f"Deleting {attr}")
362 |                 if attr == "logger":
363 |                     continue
364 |                 delattr(self, attr)
365 |             except Exception as e:
366 |                 self.logger.error(f"Error deleting {attr}: {e}")
367 |         return
368 |
369 |
370 | # Pydantic YoutubeDetails
371 | class YoutubeDetails(BaseModel):
372 |     """Details of the YouTube video."""
373 |
374 |     title: str = Field(
375 |         description="A fun and engaging title for the Youtube video. Has to be related to the reddit post and is not more than 100 characters."
376 |     )
377 |     description: str = Field(
378 |         description="Description of the Youtube video. It should be a simple summary of the video."
379 |     )
380 |     tags: list[str] = Field(
381 |         description="Tags of the Youtube video. Tags are single words with no spaces in them."
382 |     )
383 |     thumbnail_description: str = Field(
384 |         description="Thumbnail description of the Youtube video. Provide detailed and creative descriptions that will inspire unique and interesting images from the AI. Keep in mind that the AI is capable of understanding a wide range of language and can interpret abstract concepts, so feel free to be as imaginative and descriptive as possible.
For example, you could describe a scene from a futuristic city, or a surreal landscape filled with strange creatures. The more detailed and imaginative your description, the more interesting the resulting image will be." 385 | ) 386 | 387 | 388 | # Image Describer 389 | class ImageDescriber(BaseModel): 390 | """Given text, Provides a detailed and creative description for an image.""" 391 | 392 | description: str = Field( 393 | description="Provide a detailed and creative description that will inspire unique and interesting images from the AI. Keep in mind that the AI is capable of understanding a wide range of language and can interpret abstract concepts, so feel free to be as imaginative and descriptive as possible. For example, you could describe the scene in a pictorial way adding more details or elaborating the scenario. The more detailed and imaginative your description, the more interesting the resulting image will be." 394 | ) 395 | -------------------------------------------------------------------------------- /ShortsMaker/generate_image.py: -------------------------------------------------------------------------------- 1 | # https://huggingface.co/docs/diffusers/main/en/index 2 | import os 3 | from pathlib import Path 4 | from time import sleep 5 | 6 | import torch 7 | import yaml 8 | from diffusers import AutoencoderKL, FluxPipeline 9 | from transformers import CLIPTextModel, T5EncoderModel 10 | 11 | from .utils import get_logger 12 | 13 | MODEL_UNLOAD_DELAY = 5 14 | 15 | 16 | class GenerateImage: 17 | """ 18 | A class for generating images using different Flux models from Hugging Face. 19 | 20 | This class provides methods to load and use various Flux models, including 21 | FLUX.1-dev, FLUX.1-schnell, and a custom Pixel Wave model. It handles model 22 | loading, image generation, and resource cleanup. 23 | """ 24 | 25 | def __init__(self, config_file: Path | str) -> None: 26 | """ 27 | Initializes the GenerateImage class. 28 | 29 | Args: 30 | config_file (Path | str): Path to the YAML configuration file containing settings 31 | such as the Hugging Face access token. 32 | 33 | Raises: 34 | FileNotFoundError: If the specified configuration file does not exist. 35 | ValueError: If the configuration file is not a YAML file. 36 | """ 37 | # if config_file is str convert it to a Pathlike 38 | self.setup_cfg = Path(config_file) if isinstance(config_file, str) else config_file 39 | 40 | if not self.setup_cfg.exists(): 41 | raise FileNotFoundError(f"File {str(self.setup_cfg)} does not exist") 42 | 43 | if self.setup_cfg.suffix != ".yml": 44 | raise ValueError(f"File {str(self.setup_cfg)} is not a yaml file") 45 | 46 | # load the yml file 47 | with open(self.setup_cfg) as f: 48 | self.cfg = yaml.safe_load(f) 49 | 50 | self.logger = get_logger(__name__) 51 | 52 | if "hugging_face_access_token" not in self.cfg: 53 | self.logger.warning( 54 | "Please add your huggingface access token to use Flux.1-Dev.\nDefaulting to use Flux.1-Schnell" 55 | ) 56 | 57 | self.pipe: FluxPipeline | None = None 58 | 59 | def _load_model(self, model_id: str) -> bool: 60 | """ 61 | Loads a Flux model from Hugging Face. 62 | 63 | Args: 64 | model_id (str): The ID of the Flux model to load from Hugging Face. 65 | 66 | Returns: 67 | bool: True if the model was loaded successfully. 68 | 69 | Raises: 70 | RuntimeError: If there is an error loading the Flux model. 
71 | """ 72 | try: 73 | self.pipe = FluxPipeline.from_pretrained(model_id, torch_dtype=torch.bfloat16) 74 | self.logger.info(f"Loading Flux model from {model_id}") 75 | # to run on low vram GPUs (i.e. between 4 and 32 GB VRAM) 76 | # Choose ONE of the following: 77 | # pipe.enable_model_cpu_offload() # Best for low-VRAM GPUs 78 | self.pipe.enable_sequential_cpu_offload() # Alternative for moderate VRAM GPUs 79 | self.pipe.vae.enable_slicing() # Reduces memory usage for decoding 80 | self.pipe.vae.enable_tiling() # Further optimizes VAE computation 81 | 82 | # casting here instead of in the pipeline constructor because doing so in the constructor loads all models into CPU memory at once 83 | self.pipe.to(torch.float16) 84 | 85 | self.logger.info("Flux model loaded") 86 | return True 87 | except Exception as e: 88 | self.logger.error(e) 89 | raise RuntimeError("Error in loading the Flux model") 90 | 91 | # @retry(max_retries=MAX_RETRIES, delay=DELAY, notify=NOTIFY) 92 | def use_huggingface_flux_dev( 93 | self, 94 | prompt: str, 95 | output_path: str, 96 | negative_prompt: str = "", 97 | model_id: str = "black-forest-labs/FLUX.1-dev", 98 | steps: int = 20, 99 | seed: int = 0, 100 | height: int = 1024, 101 | width: int = 1024, 102 | guidance_scale: float = 3.5, 103 | ) -> bool: 104 | """ 105 | Generates an image using the FLUX.1-dev model from Hugging Face. 106 | 107 | Args: 108 | prompt (str): The text prompt to guide image generation. 109 | output_path (str): The path to save the generated image. 110 | negative_prompt (str): The text prompt to guide what the model should avoid generating. Defaults to "". 111 | model_id (str): The ID of the FLUX.1-dev model on Hugging Face. Defaults to "black-forest-labs/FLUX.1-dev". 112 | steps (int): The number of inference steps. Defaults to 20. 113 | seed (int): The random seed for image generation. Defaults to 0. 114 | height (int): The height of the output image. Defaults to 1024. 115 | width (int): The width of the output image. Defaults to 1024. 116 | guidance_scale (float): The guidance scale for image generation. Defaults to 3.5. 117 | 118 | Returns: 119 | bool: True if the image was generated and saved successfully. 
120 | """ 121 | self.logger.info("This image generator uses the Flux Dev model.") 122 | # Add access token to environment variable 123 | if "hugging_face_access_token" in self.cfg and os.environ.get("HF_TOKEN") is None: 124 | self.logger.info("Setting HF_TOKEN environment variable") 125 | os.environ["HF_TOKEN"] = self.cfg["hugging_face_access_token"] 126 | 127 | self._load_model(model_id) 128 | 129 | self.logger.info("Generating image") 130 | image = self.pipe( 131 | prompt, 132 | negative_prompt, 133 | guidance_scale=guidance_scale, 134 | output_type="pil", 135 | num_inference_steps=steps, 136 | max_sequence_length=512, 137 | height=height, 138 | width=width, 139 | generator=torch.Generator("cpu").manual_seed(seed), 140 | ).images[0] 141 | image.save(output_path) 142 | self.logger.info(f"Image saved to {output_path}") 143 | 144 | del self.pipe 145 | self.pipe = None 146 | if torch.cuda.is_available(): 147 | torch.cuda.empty_cache() 148 | self.logger.info("Wait for 5 seconds, So that the GPU memory can be freed") 149 | sleep(MODEL_UNLOAD_DELAY) 150 | return True 151 | 152 | # @retry(max_retries=MAX_RETRIES, delay=DELAY, notify=NOTIFY) 153 | def use_huggingface_flux_schnell( 154 | self, 155 | prompt: str, 156 | output_path: str, 157 | negative_prompt: str = "", 158 | model_id: str = "black-forest-labs/FLUX.1-schnell", 159 | steps: int = 4, 160 | seed: int = 0, 161 | height: int = 1024, 162 | width: int = 1024, 163 | guidance_scale: float = 0.0, 164 | ) -> bool: 165 | """ 166 | Generates an image using the FLUX.1-schnell model from Hugging Face. 167 | 168 | Args: 169 | prompt (str): The text prompt to guide image generation. 170 | output_path (str): The path to save the generated image. 171 | negative_prompt (str): The text prompt to guide what the model should avoid generating. Defaults to "". 172 | model_id (str): The ID of the FLUX.1-schnell model on Hugging Face. Defaults to "black-forest-labs/FLUX.1-schnell". 173 | steps (int): The number of inference steps. Defaults to 4. 174 | seed (int): The random seed for image generation. Defaults to 0. 175 | height (int): The height of the output image. Defaults to 1024. 176 | width (int): The width of the output image. Defaults to 1024. 177 | guidance_scale (float): The guidance scale for image generation. Defaults to 0.0. 178 | 179 | Returns: 180 | bool: True if the image was generated and saved successfully. 
181 | """ 182 | self.logger.info("This image generator uses the Flux Schnell model.") 183 | 184 | self._load_model(model_id) 185 | 186 | self.logger.info("Generating image") 187 | image = self.pipe( 188 | prompt, 189 | negative_prompt, 190 | guidance_scale=guidance_scale, 191 | output_type="pil", 192 | num_inference_steps=steps, 193 | max_sequence_length=256, 194 | height=height, 195 | width=width, 196 | generator=torch.Generator("cpu").manual_seed(seed), 197 | ).images[0] 198 | image.save(output_path) 199 | self.logger.info(f"Image saved to {output_path}") 200 | 201 | del self.pipe 202 | self.pipe = None 203 | if torch.cuda.is_available(): 204 | torch.cuda.empty_cache() 205 | self.logger.info("Wait for 5 seconds, So that the GPU memory can be freed") 206 | sleep(MODEL_UNLOAD_DELAY) 207 | return True 208 | 209 | def use_flux_pixel_wave( 210 | self, 211 | prompt: str, 212 | output_path: str, 213 | model_id: str = "https://huggingface.co/mikeyandfriends/PixelWave_FLUX.1-schnell_03/blob/main/pixelwave_flux1_schnell_fp8_03.safetensors", 214 | steps: int = 4, 215 | seed: int = 595570113709576, 216 | height: int = 1024, 217 | width: int = 1024, 218 | guidance_scale: float = 3.5, 219 | ) -> bool: 220 | """ 221 | Generates an image using the custom Flux Pixel Wave model. 222 | 223 | Args: 224 | prompt (str): The text prompt to guide image generation. 225 | output_path (str): The path to save the generated image. 226 | model_id (str): The URL or path to the Pixel Wave model file. Defaults to "https://huggingface.co/mikeyandfriends/PixelWave_FLUX.1-schnell_03/blob/main/pixelwave_flux1_schnell_fp8_03.safetensors". 227 | steps (int): The number of inference steps. Defaults to 4. 228 | seed (int): The random seed for image generation. Defaults to 595570113709576. 229 | height (int): The height of the output image. Defaults to 1024. 230 | width (int): The width of the output image. Defaults to 1024. 231 | guidance_scale (float): The guidance scale for image generation. Defaults to 3.5. 232 | 233 | Returns: 234 | bool: True if the image was generated and saved successfully. 235 | """ 236 | self.logger.info("This image generator uses the Flux Pixel Wave model.") 237 | 238 | text_encoder = CLIPTextModel.from_pretrained( 239 | "black-forest-labs/FLUX.1-schnell", subfolder="text_encoder" 240 | ) 241 | text_encoder_2 = T5EncoderModel.from_pretrained( 242 | "black-forest-labs/FLUX.1-schnell", subfolder="text_encoder_2" 243 | ) 244 | vae = AutoencoderKL.from_pretrained("black-forest-labs/FLUX.1-schnell", subfolder="vae") 245 | 246 | self.pipe = FluxPipeline.from_single_file( 247 | model_id, 248 | use_safetensors=True, 249 | torch_dtype=torch.bfloat16, 250 | # Load additional not included in safetensor 251 | text_encoder=text_encoder, 252 | text_encoder_2=text_encoder_2, 253 | vae=vae, 254 | ) 255 | 256 | # to run on low vram GPUs (i.e. 
between 4 and 32 GB VRAM) 257 | # Choose ONE of the following: 258 | # pipe.enable_model_cpu_offload() # Best for low-VRAM GPUs 259 | self.pipe.enable_sequential_cpu_offload() # Alternative for moderate VRAM GPUs 260 | self.pipe.vae.enable_slicing() # Reduces memory usage for decoding 261 | self.pipe.vae.enable_tiling() # Further optimizes VAE computation 262 | 263 | # casting here instead of in the pipeline constructor because doing so in the constructor loads all models into CPU memory at once 264 | self.pipe.to(torch.float16) 265 | 266 | self.logger.info("Generating image") 267 | image = self.pipe( 268 | prompt, 269 | guidance_scale=guidance_scale, 270 | output_type="pil", 271 | num_inference_steps=steps, 272 | max_sequence_length=256, 273 | height=height, 274 | width=width, 275 | generator=torch.Generator("cpu").manual_seed(seed), 276 | ).images[0] 277 | image.save(output_path) 278 | self.logger.info(f"Image saved to {output_path}") 279 | 280 | del self.pipe 281 | del text_encoder 282 | del text_encoder_2 283 | del vae 284 | 285 | self.pipe = None 286 | if torch.cuda.is_available(): 287 | torch.cuda.empty_cache() 288 | self.logger.info("Wait for 5 seconds, So that the GPU memory can be freed") 289 | sleep(MODEL_UNLOAD_DELAY) 290 | return True 291 | 292 | def quit(self) -> None: 293 | """ 294 | Cleans up resources and exits the image generator. 295 | 296 | This method clears the CUDA cache (if available) and attempts to 297 | delete all instance variables to free up memory. 298 | """ 299 | self.logger.info("Quitting the image generator") 300 | if torch.cuda.is_available(): 301 | torch.cuda.empty_cache() 302 | # Delete all instance variables 303 | for attr in list(self.__dict__.keys()): 304 | try: 305 | self.logger.debug(f"Deleting {attr}") 306 | if attr == "logger": 307 | continue 308 | delattr(self, attr) 309 | except Exception as e: 310 | self.logger.error(f"Error deleting {attr}: {e}") 311 | return None 312 | -------------------------------------------------------------------------------- /ShortsMaker/moviepy_create_video.py: -------------------------------------------------------------------------------- 1 | import random 2 | import secrets 3 | from dataclasses import dataclass 4 | from pathlib import Path 5 | from typing import Any 6 | 7 | import yaml 8 | from moviepy import ( 9 | AudioFileClip, 10 | CompositeAudioClip, 11 | CompositeVideoClip, 12 | TextClip, 13 | VideoClip, 14 | VideoFileClip, 15 | afx, 16 | vfx, 17 | ) 18 | 19 | from .utils import ( 20 | COLORS_DICT, 21 | download_youtube_music, 22 | download_youtube_video, 23 | get_logger, 24 | ) 25 | 26 | random.seed(secrets.randbelow(1000000)) 27 | 28 | 29 | @dataclass 30 | class VideoConfig: 31 | cache_dir: Path 32 | assets_dir: Path 33 | audio_config: dict 34 | video_config: dict 35 | logging_config: dict 36 | 37 | 38 | class MoviepyCreateVideo: 39 | """ 40 | Class for creating videos from media components using MoviePy. 41 | 42 | This class facilitates the creation of videos by integrating various media 43 | components such as background videos, audio tracks, music, fonts, and credits. 44 | It provides settings for fading effects, delays, and logging during video 45 | generation. It ensures proper initialization and handling of required 46 | directories, configurations, and external dependencies like FFmpeg. 47 | The class allows flexibility in providing media paths or configuring the 48 | video creation process dynamically. 
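
    Example:
        A minimal sketch, mirroring example.py (the config and output paths are
        illustrative):

            create_video = MoviepyCreateVideo(config_file="setup.yml", speed_factor=1.0)
            create_video(output_path="assets/output.mp4")
            create_video.quit()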
49 | 50 | Attributes: 51 | DEFAULT_FADE_TIME (int): Default duration for fade effects applied to the video. 52 | DEFAULT_DELAY (int): Default delay applied between video transitions or sections. 53 | DEFAULT_SPEED_FACTOR (float): Default speed factor applied to the video.) 54 | REQUIRED_DIRECTORIES (list): List of essential directories required for using the class. 55 | PUNCTUATION_MARKS (list): List of punctuation marks used for processing transcripts or text inputs. 56 | """ 57 | 58 | DEFAULT_FADE_TIME = 2 59 | DEFAULT_DELAY = 1 60 | DEFAULT_SPEED_FACTOR = 1 61 | REQUIRED_DIRECTORIES = ["video_dir", "music_dir", "fonts_dir", "credits_dir"] 62 | PUNCTUATION_MARKS = [".", ";", ":", "!", "?", ","] 63 | 64 | def __init__( 65 | self, 66 | config_file: Path | str, 67 | bg_video_path: Path | str = None, 68 | add_credits: bool = True, 69 | credits_path: Path | str = None, 70 | audio_path: Path | str = None, 71 | music_path: Path | str = None, 72 | transcript_path: Path | str = None, 73 | font_path: Path | str = None, 74 | fade_time: int = DEFAULT_FADE_TIME, 75 | delay: int = DEFAULT_DELAY, 76 | speed_factor: float = DEFAULT_SPEED_FACTOR, 77 | ) -> None: 78 | self.fade_time = fade_time 79 | self.delay = delay 80 | self.speed_factor = speed_factor 81 | 82 | # Initialize configuration 83 | self.config = self._load_configuration(config_file) 84 | self.logger = get_logger(__name__) 85 | self._verify_ffmpeg() 86 | 87 | # Initialize directories 88 | self._setup_directories() 89 | 90 | # Initialize media components 91 | self.audio_clip = self._initialize_audio(audio_path) 92 | self.audio_clip_bitrate = self.audio_clip.reader.bitrate 93 | self.logger.info(f"Audio Duration: {self.audio_clip.duration:.2f}s") 94 | 95 | self.audio_transcript = self._load_transcript(transcript_path) 96 | self.audio_transcript = self.preprocess_audio_transcript() 97 | self.word_transcript, self.sentences_transcript = ( 98 | self.process_audio_transcript_to_word_and_sentences_transcript() 99 | ) 100 | 101 | self.bg_video = self._initialize_background_video(bg_video_path) 102 | self.bg_video = self.prepare_background_video() 103 | self.bg_video_bitrate = self.bg_video.reader.bitrate 104 | 105 | self.music_clip = self._initialize_music(music_path) 106 | self.music_clip_bitrate = self.music_clip.reader.bitrate 107 | 108 | self.font_path = self._initialize_font(font_path) 109 | 110 | self.add_credits: bool = add_credits 111 | self.credits_video = self._initialize_credits(credits_path) if add_credits else None 112 | 113 | # Initialize color 114 | self.color = self._select_random_color() 115 | self.logger.info(f"Using color {self.color}") 116 | 117 | @staticmethod 118 | def _load_configuration(config_file: Path | str) -> VideoConfig: 119 | """ 120 | Loads and validates a YAML configuration file, and then parses its content 121 | into a `VideoConfig` object. The method checks for the existence of the file 122 | and ensures that it has the correct `.yml` extension. If the validation fails, 123 | it raises a `ValueError`. Otherwise, the YAML content is loaded and used to 124 | instantiate a `VideoConfig` with the appropriate properties. 125 | 126 | Args: 127 | config_file: The path to the configuration file provided as a `Path` object 128 | or a `str`. The file must exist and have a `.yml` extension. 129 | 130 | Returns: 131 | A `VideoConfig` object created using the data parsed from the provided 132 | configuration file. 
133 | 134 | Raises: 135 | ValueError: If the configuration file does not exist or does not have a 136 | `.yml` extension. 137 | """ 138 | config_path = Path(config_file) 139 | if not config_path.exists() or config_path.suffix != ".yml": 140 | raise ValueError(f"Invalid configuration file: {config_path}") 141 | 142 | with open(config_path) as f: 143 | cfg = yaml.safe_load(f) 144 | 145 | return VideoConfig( 146 | cache_dir=Path(cfg["cache_dir"]), 147 | assets_dir=Path(cfg["assets_dir"]), 148 | audio_config=cfg.get("audio", {}), 149 | video_config=cfg.get("video", {}), 150 | logging_config=cfg.get("logging", {}), 151 | ) 152 | 153 | def _verify_ffmpeg(self) -> None: 154 | """ 155 | Verifies the availability of the FFmpeg utility on the system. 156 | 157 | This method checks if FFmpeg is installed and available in the system's PATH 158 | environment variable. It does so by attempting to execute the FFmpeg command 159 | with the `-version` argument, which displays FFmpeg's version details. If 160 | FFmpeg is not installed or cannot be accessed, an error is logged, and the 161 | exception is re-raised for further handling. 162 | 163 | Raises: 164 | Exception: If FFmpeg is not installed or available in the PATH. The specific 165 | exception raised during the failure will also be re-raised. 166 | """ 167 | try: 168 | import subprocess 169 | 170 | subprocess.run(["ffmpeg", "-version"], check=True) 171 | except Exception as e: 172 | self.logger.error("ffmpeg is not installed or not available in path") 173 | raise e 174 | 175 | def _setup_directories(self) -> None: 176 | """ 177 | Sets up the necessary directories for storing assets used in the application. 178 | 179 | This method configures various directories such as those for background 180 | videos, background music, fonts, and credits. Additionally, it ensures 181 | that the directories required for background videos and background music 182 | are created if they do not already exist. 183 | 184 | Attributes: 185 | video_dir: Path to the directory where background videos are stored. 186 | music_dir: Path to the directory where background music is stored. 187 | fonts_dir: Path to the directory where fonts are stored. 188 | credits_dir: Path to the directory where credits data is stored. 189 | 190 | Raises: 191 | FileNotFoundError: If the base assets directory does not exist or is 192 | inaccessible during operation. 193 | """ 194 | self.video_dir = self.config.assets_dir / "background_videos" 195 | self.music_dir = self.config.assets_dir / "background_music" 196 | self.fonts_dir = self.config.assets_dir / "fonts" 197 | self.credits_dir = self.config.assets_dir / "credits" 198 | 199 | for directory in [self.video_dir, self.music_dir]: 200 | directory.mkdir(parents=True, exist_ok=True) 201 | 202 | def _initialize_audio(self, audio_path: Path | str = None) -> AudioFileClip: 203 | """ 204 | Initializes the audio by loading an audio file from the specified path or a default 205 | directory set in the configuration. If no audio path is provided, the method will 206 | attempt to find an audio file in the configured cache directory. 207 | 208 | Args: 209 | audio_path (Path | str, optional): The path to the audio file. If None, the method 210 | will use the `cache_dir` and `output_audio_file` settings from the 211 | configuration to locate the audio file. 212 | 213 | Returns: 214 | AudioFileClip: An instance of AudioFileClip initialized with the located audio file. 
215 | 216 | Raises: 217 | ValueError: If no `audio_path` is provided and the `audio_config` in the 218 | configuration does not specify an `output_audio_file` key. 219 | """ 220 | if audio_path is None: 221 | self.logger.info( 222 | f"No audio path provided. Using the audio directory {self.config.cache_dir} to find an audio file." 223 | ) 224 | if "output_audio_file" not in self.config.audio_config: 225 | raise ValueError("Missing 'output_audio_file' in 'audio' section in setup.yml") 226 | audio_path = self.config.cache_dir / self.config.audio_config["output_audio_file"] 227 | self.logger.info(f"Using audio file {audio_path}") 228 | return AudioFileClip(audio_path) 229 | 230 | def _initialize_background_video(self, bg_video_path: Path | str = None) -> VideoFileClip: 231 | """ 232 | Initializes the background video for further video processing. 233 | 234 | If a `bg_video_path` is not explicitly provided, the method will attempt to use the URLs specified 235 | in the video configuration to download a random background video. The downloaded video path will 236 | then be used to initialize the `VideoFileClip` instance. If the video configuration does not 237 | contain `background_videos_urls`, a ValueError is raised. 238 | 239 | Args: 240 | bg_video_path (Path | str, optional): The path to the background video file. If None, a 241 | video is chosen or downloaded based on the configuration. 242 | 243 | Raises: 244 | ValueError: If `background_videos_urls` is missing in the video configuration. 245 | 246 | Returns: 247 | VideoFileClip: The background video instance without audio. 248 | """ 249 | if bg_video_path is None: 250 | self.logger.info( 251 | "No background video path provided. Using the background_videos_urls to download a background video." 252 | ) 253 | if "background_videos_urls" not in self.config.video_config: 254 | raise ValueError("Missing 'background_videos_urls' in 'video' section in setup.yml") 255 | # choose a random url 256 | bg_video_url = random.choice(self.config.video_config["background_videos_urls"]) 257 | self.logger.info(f"Using bg_video_url {bg_video_url}") 258 | bg_videos_path = download_youtube_video( 259 | bg_video_url, 260 | self.video_dir, 261 | ) 262 | bg_video_path = random.choice(bg_videos_path).absolute() 263 | self.logger.info(f"Using bg_video_path {bg_video_path}") 264 | return VideoFileClip(bg_video_path, audio=False) 265 | 266 | @staticmethod 267 | def _select_random_color() -> tuple[int, int, int, int]: 268 | """ 269 | Selects a random color from a pre-defined dictionary of colors. 270 | 271 | This method accesses a dictionary containing color codes, selects a random 272 | key from the dictionary, and retrieves the corresponding color value. It is 273 | used to dynamically choose colors for various operations requiring random 274 | color assignments. 275 | 276 | Returns: 277 | tuple[int, int, int, int]: A tuple representing the RGBA color values, 278 | where each value corresponds to red, green, blue, and alpha channels. 279 | 280 | Raises: 281 | KeyError: If the dictionary is empty or an invalid key is accessed. 282 | """ 283 | return COLORS_DICT[random.choice(list(COLORS_DICT.keys()))] 284 | 285 | def _load_transcript(self, transcript_path: Path | str) -> list[dict[str, Any]]: 286 | """ 287 | Loads and parses a transcript file in YAML format. The function determines the path of the 288 | transcript file, either based on the provided argument or through the configuration setup. 
289 |         It ensures the path validity and raises appropriate errors if the file cannot be found or
290 |         if key configuration fields are missing. If the transcript file exists, it reads and
291 |         deserializes its content into a Python list of dictionaries.
292 |
293 |         Args:
294 |             transcript_path (Path | str): Path to the transcript file. If None, this argument will
295 |                 default to a path derived from the configuration settings.
296 |
297 |         Raises:
298 |             ValueError: If `transcript_path` is None and configuration settings do not specify
299 |                 'transcript_json' in the 'audio' section of `setup.yml`.
300 |             ValueError: If the resolved `transcript_path` does not exist or the file cannot be found.
301 |
302 |         Returns:
303 |             list[dict[str, Any]]: The content of the transcript file parsed as a list of dictionaries.
304 |         """
305 |         if transcript_path is None:
306 |             self.logger.info(
307 |                 f"No transcript path provided. Using the audio directory {self.config.cache_dir} to find a transcript file."
308 |             )
309 |             if "transcript_json" not in self.config.audio_config:
310 |                 raise ValueError("Missing 'transcript_json' in 'audio' section in setup.yml")
311 |             transcript_path = self.config.cache_dir / self.config.audio_config["transcript_json"]
312 |             self.logger.info(f"Using transcript file {transcript_path}")
313 |         path = Path(transcript_path)
314 |         if not path.exists():
315 |             raise ValueError(f"Transcript file not found: {path}")
316 |
317 |         with open(transcript_path) as audio_transcript_file:
318 |             return yaml.safe_load(audio_transcript_file)  # JSON is valid YAML, so this parses the JSON transcript
319 |
320 |     def _initialize_music(self, music_path: Path | str) -> AudioFileClip:
321 |         """
322 |         Initializes and loads a music file, handling cases where the music path is not provided
323 |         by downloading music from a specified source.
324 |
325 |         If a music path is not given, a random URL is selected from the list of background music
326 |         URLs defined in the video configuration, and the corresponding music file is downloaded
327 |         and set as the music path. Ensures the music file exists before returning it as an
328 |         AudioFileClip object.
329 |
330 |         Args:
331 |             music_path (Path | str): The path to the music file to be loaded. If None, the method
332 |                 will attempt to download and use a file from a configured URL.
333 |
334 |         Returns:
335 |             AudioFileClip: An AudioFileClip object representing the loaded music file.
336 |
337 |         Raises:
338 |             ValueError: If `background_music_urls` is missing in the video configuration when
339 |                 no `music_path` is provided, or if the specified music file does not exist.
340 |         """
341 |         if music_path is None:
342 |             self.logger.info(
343 |                 "No music path provided. Using the background_music_urls to download background music."
344 |             )
345 |             if "background_music_urls" not in self.config.video_config:
346 |                 raise ValueError("Missing 'background_music_urls' in 'video' section in setup.yml")
347 |             # choose a random url
348 |             music_url = random.choice(self.config.video_config["background_music_urls"])
349 |             self.logger.info(f"Using music_url {music_url}")
350 |             musics_path = download_youtube_music(
351 |                 music_url,
352 |                 self.music_dir,
353 |             )
354 |             music_path = random.choice(musics_path).absolute()
355 |             self.logger.info(f"Using music_path {music_path}")
356 |         path = Path(music_path)
357 |         if not path.exists():
358 |             raise ValueError(f"Music file not found: {path}")
359 |
360 |         return AudioFileClip(path)
361 |
362 |     def _initialize_font(self, font_path: Path | str) -> str:
363 |         """
364 |         Initializes and selects a font file for use, either provided explicitly or chosen
365 |         from a predefined directory of font files.
366 |
367 |         If the input `font_path` is not provided, a random `.ttf` file is selected from
368 |         the directory. If the directory does not contain any `.ttf` files, the function raises
369 |         an exception. If `font_path` is supplied but points to a non-existent file, an exception
370 |         is also raised.
371 |
372 |         Args:
373 |             font_path (Path | str): The path to the font file to be initialized. If None,
374 |                 a font is randomly selected from the predefined font directory.
375 |
376 |         Raises:
377 |             ValueError: Raised when no font files exist in the predefined font directory,
378 |                 or when the given `font_path` does not exist.
379 |
380 |         Returns:
381 |             str: The absolute path of the selected font file, ensuring any valid input font
382 |                 file or selected file from the directory is returned in string format.
383 |         """
384 |         if font_path is None:
385 |             self.logger.info(
386 |                 f"No font path provided. Using the fonts directory {self.fonts_dir} to find a font."
387 |             )
388 |             font_files = list(self.fonts_dir.glob("*.ttf"))
389 |             if not font_files:
390 |                 raise ValueError(f"No font files found in {self.fonts_dir}")
391 |             return str(random.choice(font_files).absolute())
392 |         path = Path(font_path)
393 |         if not path.exists():
394 |             raise ValueError(f"Font file not found: {path}")
395 |         return str(path)
396 |
397 |     def _initialize_credits(self, credits_path: Path | str) -> VideoFileClip:
398 |         """
399 |         Initializes the credits video by either using a provided path or searching a default credits directory.
400 |         Raises an error if no valid credits file is found or if the provided path does not exist.
401 |         Additionally, applies the corresponding mask file to the credits video.
402 |
403 |         Args:
404 |             credits_path (Path | str): The path to the directory containing the credits video and
405 |                 its mask. If None, the function will locate them in the default credits directory.
406 |
407 |         Returns:
408 |             VideoFileClip: The VideoFileClip object of the credits video with the applied mask.
409 |
410 |         Raises:
411 |             FileNotFoundError: If no credits files are found in the default credits directory.
412 |             ValueError: If the given credits_path does not exist.
413 |         """
414 |         if credits_path is None:
415 |             self.logger.info(
416 |                 f"No credits path provided. Using the credits directory {self.credits_dir} to find a credits file."
417 | ) 418 | credit_videos = list(self.credits_dir.glob("*.mp4")) 419 | if not credit_videos: 420 | raise FileNotFoundError("No credits files found in the credits directory") 421 | credits_path = self.credits_dir 422 | self.logger.info(f"Using credits file {credits_path}") 423 | 424 | path = Path(credits_path) 425 | if not path.exists(): 426 | raise ValueError(f"Credits file not found: {path}") 427 | 428 | self.credit_video_mask: VideoFileClip = VideoFileClip( 429 | path / "credits_mask.mp4", audio=False 430 | ).to_mask() 431 | 432 | return VideoFileClip(path / "credits.mp4", audio=False).with_mask(self.credit_video_mask) 433 | 434 | def preprocess_audio_transcript(self) -> list[dict[str, Any]]: 435 | """ 436 | Preprocesses the audio transcript to adjust segment boundaries. 437 | 438 | This method modifies the transcript by updating the "end" time of each audio 439 | segment to match the "start" time of the subsequent segment. This ensures that 440 | audio segments are properly aligned in cases where boundaries are not explicitly 441 | defined. The original transcript is updated in-place and returned. 442 | 443 | Returns: 444 | list[dict[str, Any]]: The updated list of audio transcript segments, 445 | each represented as a dictionary containing at least "start" and "end" keys. 446 | 447 | """ 448 | for i in range(len(self.audio_transcript) - 1): 449 | self.audio_transcript[i]["end"] = self.audio_transcript[i + 1]["start"] 450 | return self.audio_transcript 451 | 452 | def process_audio_transcript_to_word_and_sentences_transcript( 453 | self, 454 | ) -> tuple[list[dict[str, Any]], list[dict[str, Any]]]: 455 | """ 456 | Processes audio transcript data into individual word-level and sentence-level 457 | transcripts. The method parses the `audio_transcript` to create two lists: a 458 | word transcript and a sentences transcript. Each word and sentence transcript 459 | contains details such as text, start time, and end time. Sentences are delimited 460 | based on specific punctuation marks defined in `PUNCTUATION_MARKS`. If an error 461 | occurs during word processing, it logs the error using the `logger` attribute. 462 | 463 | Returns: 464 | tuple[list[dict[str, Any]], list[dict[str, Any]]]: A tuple containing two 465 | lists. The first list is the word transcript, where each item contains 466 | information about individual words and their start and end times. The second 467 | list is the sentences transcript, where each item contains information about 468 | sentences, including the sentence text and its start and end times. 
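        Example:
            A minimal sketch with hypothetical timings; ``maker`` is an initialized
            instance and "." is assumed to be in ``PUNCTUATION_MARKS``::

                maker.audio_transcript = [
                    {"word": "Hello", "start": 0.0, "end": 0.4},
                    {"word": "world.", "start": 0.4, "end": 0.9},
                ]
                words, sentences = maker.process_audio_transcript_to_word_and_sentences_transcript()
                # words[-1]  -> {"word": "Hello world. ", "start": 0.4, "end": 0.9}
                # sentences  -> [{"sentence": "Hello world. ", "start": 0, "end": 0.9}]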
469 | """ 470 | word_transcript = [] 471 | sentences_transcript = [] 472 | 473 | sentence_start = 0 474 | sentence_end = 0 475 | 476 | word_start = 0 477 | word_end = 0 478 | 479 | sentence = "" 480 | 481 | for count, transcript in enumerate(self.audio_transcript): 482 | word = transcript["word"].strip() 483 | sentence += word + " " 484 | 485 | word_end = transcript["end"] 486 | word_transcript.append( 487 | { 488 | "word": sentence, 489 | "start": word_start, 490 | "end": word_end, 491 | } 492 | ) 493 | word_start = word_end 494 | 495 | try: 496 | if word[-1] in self.PUNCTUATION_MARKS and word != "...": 497 | sentence_end = word_end 498 | sentences_transcript.append( 499 | { 500 | "sentence": sentence, 501 | "start": sentence_start, 502 | "end": sentence_end, 503 | } 504 | ) 505 | sentence_start = sentence_end 506 | sentence = "" 507 | except Exception as e: 508 | self.logger.error(f"Error processing word '{word}': {e}") 509 | 510 | # Add final sentence if any remains 511 | if sentence != "": 512 | sentences_transcript.append( 513 | { 514 | "sentence": sentence, 515 | "start": sentence_start, 516 | "end": sentence_end, 517 | } 518 | ) 519 | 520 | return word_transcript, sentences_transcript 521 | 522 | def prepare_background_video(self) -> VideoFileClip: 523 | """ 524 | Prepares and processes a background video clip for use in a composition. 525 | 526 | This function modifies the background video clip by randomly selecting a segment, 527 | cropping it to a specific aspect ratio, applying crossfade effects, and logging relevant 528 | information about the original and processed video dimensions. 529 | 530 | Returns: 531 | VideoFileClip: The processed and modified video clip ready for use in further processing. 532 | 533 | Args: 534 | self: The instance of the class containing this method. 535 | 536 | Raises: 537 | AttributeError: If attributes such as `self.bg_video`, `self.audio_clip`, or `self.logger` 538 | are not properly defined in the class. 539 | ValueError: If the random start time exceeds the duration of the background video. 540 | """ 541 | width, height = self.bg_video.size 542 | 543 | self.logger.info( 544 | f"Original video - Width: {width}, Height: {height}, FPS: {self.bg_video.fps}" 545 | ) 546 | 547 | # Select random segment of appropriate length 548 | random_start = random.uniform(20, self.bg_video.duration - self.audio_clip.duration - 20) 549 | random_end = random_start + self.audio_clip.duration 550 | 551 | self.logger.info(f"Using video segment from {random_start:.2f}s to {random_end:.2f}s") 552 | 553 | # Crop and apply effects 554 | self.bg_video: VideoFileClip = self.bg_video.subclipped( 555 | start_time=random_start - self.delay, end_time=random_end + self.delay 556 | ) 557 | self.bg_video: VideoFileClip = self.bg_video.cropped( 558 | x_center=width / 2, width=int(height * 9 / 16) & -2 559 | ) 560 | self.bg_video: VideoFileClip = self.bg_video.with_effects( 561 | [vfx.FadeIn(self.fade_time), vfx.FadeOut(self.fade_time)] 562 | ) 563 | 564 | new_width, new_height = self.bg_video.size 565 | self.logger.info( 566 | f"Processed video - Width: {new_width}, Height: {new_height}, FPS: {self.bg_video.fps}" 567 | ) 568 | self.logger.info(f"Video Duration: {self.bg_video.duration:.2f}s") 569 | 570 | return self.bg_video 571 | 572 | def create_text_clips(self) -> list[TextClip]: 573 | """ 574 | Creates a list of text clips for use in video editing. 
575 | 576 | This method generates text clips based on a word transcript where each word is associated with its start and end 577 | time. These text clips are created with specific stylistic properties such as font size, color, background color, 578 | alignment, and other visual attributes. The generated clips are then appended to the object's `text_clips` list 579 | and returned. 580 | 581 | Returns: 582 | list[TextClip]: A list of TextClip objects representing the visualized words with specified styling and timing. 583 | """ 584 | for word in self.word_transcript: 585 | clip = ( 586 | TextClip( 587 | font=self.font_path, 588 | text=word["word"], 589 | font_size=int(0.06 * self.bg_video.size[0]), 590 | size=(int(0.8 * self.bg_video.size[0]), int(0.8 * self.bg_video.size[0])), 591 | color=self.color, 592 | bg_color=(0, 0, 0, 100), 593 | text_align="center", 594 | method="caption", 595 | stroke_color="black", 596 | stroke_width=1, 597 | transparent=True, 598 | ) 599 | .with_start(word["start"] + self.delay) 600 | .with_end(word["end"] + self.delay) 601 | .with_position(("center", "center")) 602 | ) 603 | 604 | self.text_clips.append(clip) 605 | 606 | return self.text_clips 607 | 608 | def prepare_audio(self) -> CompositeAudioClip: 609 | """ 610 | Processes and combines audio clips to prepare the final audio track. 611 | 612 | This method applies a series of effects to a music clip, including looping to match 613 | the duration of a background video, adjusting the volume, and adding fade-in and fade-out 614 | effects. Additionally, it modifies the start time of another audio clip to introduce 615 | a delay. Once processed, it combines the music clip and the audio clip into a single 616 | composite audio track. 617 | 618 | Returns: 619 | CompositeAudioClip: A combined audio clip that includes the processed music and 620 | audio components. 621 | """ 622 | # Process music clip 623 | self.music_clip = self.music_clip.with_effects( 624 | [afx.AudioLoop(duration=self.bg_video.duration)] 625 | ) 626 | self.music_clip = self.music_clip.with_effects([afx.MultiplyVolume(factor=0.05)]) 627 | self.music_clip = self.music_clip.with_effects( 628 | [afx.AudioFadeIn(self.fade_time), afx.AudioFadeOut(self.fade_time)] 629 | ) 630 | 631 | self.audio_clip = self.audio_clip.with_start(self.delay) 632 | 633 | # Combine audio clips 634 | return CompositeAudioClip([self.music_clip, self.audio_clip]) 635 | 636 | def __call__( 637 | self, 638 | output_path: str = "output.mp4", 639 | codec: str = "mpeg4", 640 | preset: str = "medium", 641 | threads: int = 8, 642 | ) -> bool: 643 | """ 644 | Executes the video processing pipeline by assembling all video and audio elements, 645 | creating a composite video, and writing the final output to a file. The method allows 646 | customization of output file properties such as codec, preset, and threading. 647 | 648 | Args: 649 | output_path (str): The path to save the output video file. Defaults to "output.mp4". 650 | codec (str): The video codec to use for encoding the output video. Defaults to "mpeg4". 651 | preset (str): The compression preset to optimize encoding speed vs quality. Defaults to "medium". 652 | threads (int): The number of threads to use during encoding. Defaults to 8. 653 | 654 | Returns: 655 | bool: Returns True upon successful creation and saving of the video. 
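        Example:
            A hedged usage sketch; the class name, constructor signature, and paths are
            illustrative assumptions rather than guarantees of this file::

                maker = MoviepyCreateVideo(config_file="setup.yml")
                maker(output_path="cache/final.mp4", codec="libx264", threads=4)
                maker.quit()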
656 | """ 657 | # Create video clips 658 | self.video_clips: list[VideoFileClip | TextClip] = [self.bg_video] 659 | self.text_clips: list[TextClip] = [] 660 | self.text_clips = self.create_text_clips() 661 | self.video_clips.extend(self.text_clips) 662 | 663 | # Add the credits clip to the end of the video 664 | if self.add_credits: 665 | self.credits_video: VideoClip = self.credits_video.resized( 666 | width=int(0.8 * self.bg_video.size[0]) 667 | ) 668 | self.video_clips.append( 669 | self.credits_video.with_start( 670 | self.bg_video.duration - self.credits_video.duration 671 | ).with_position(("center", "bottom")) 672 | ) 673 | 674 | # Combine video clips 675 | output_video = CompositeVideoClip(self.video_clips) 676 | 677 | # Prepare final audio 678 | output_audio = self.prepare_audio() 679 | 680 | # Create final video 681 | final_video: CompositeVideoClip = output_video.with_audio(output_audio) 682 | 683 | # Write output file 684 | final_video.write_videofile( 685 | output_path, 686 | codec=codec, 687 | bitrate=f"{self.bg_video_bitrate}k", 688 | fps=self.bg_video.fps, 689 | audio_bitrate=f"{max(self.music_clip_bitrate, self.audio_clip_bitrate)}k", 690 | preset=preset, 691 | threads=threads, 692 | ) 693 | self.logger.info(f"Video successfully created at {output_path}") 694 | 695 | if self.speed_factor != 1: 696 | final_video.close() 697 | speedup_output_path = f"{output_path.split('.')[0]}_speed.mp4" 698 | self.logger.info(f"Speeding up video to {self.speed_factor}x") 699 | self.logger.info(f"Speed up video at {speedup_output_path}") 700 | import subprocess 701 | 702 | result = subprocess.run( 703 | [ 704 | "ffmpeg", 705 | "-i", 706 | f"{output_path}", 707 | "-filter_complex", 708 | f"[0:v]setpts=(1/{self.speed_factor})*PTS[v];[0:a]atempo={self.speed_factor}[a]", 709 | "-map", 710 | "[v]", 711 | "-map", 712 | "[a]", 713 | f"{speedup_output_path}", 714 | "-y", 715 | ], 716 | capture_output=True, 717 | text=True, 718 | ) 719 | if result.returncode == 0: 720 | self.logger.info("Speed adjustment completed successfully.") 721 | else: 722 | self.logger.error(f"Speed adjustment failed: {result.stderr}") 723 | 724 | output_video.close() 725 | output_audio.close() 726 | final_video.close() 727 | 728 | return True 729 | 730 | def quit(self) -> None: 731 | """ 732 | Closes and cleans up resources used by the instance. 733 | 734 | This method ensures that all open files, clips, or other resources associated 735 | with the instance are properly closed and the corresponding instance variables 736 | are deleted. It handles exceptions gracefully, logging any issues encountered 737 | during the cleanup process. 738 | 739 | Raises: 740 | Exception: If an error occurs while closing a resource or deleting 741 | an attribute, it will log the error details. 
742 | """ 743 | try: 744 | # Close any open files or clips 745 | if hasattr(self, "audio_clip") and self.audio_clip: 746 | self.audio_clip.close() 747 | if hasattr(self, "bg_video") and self.bg_video: 748 | self.bg_video.close() 749 | if hasattr(self, "music_clip") and self.music_clip: 750 | self.music_clip.close() 751 | if hasattr(self, "credits_video") and self.credits_video: 752 | self.credits_video.close() 753 | if hasattr(self, "credit_video_mask") and self.credit_video_mask: 754 | self.credit_video_mask.close() 755 | except Exception as e: 756 | self.logger.error(f"Error closing resources: {e}") 757 | 758 | # Delete all instance variables 759 | for attr in list(self.__dict__.keys()): 760 | try: 761 | self.logger.debug(f"Deleting {attr}") 762 | if attr == "logger": 763 | continue 764 | delattr(self, attr) 765 | except Exception as e: 766 | self.logger.error(f"Error deleting {attr}: {e}") 767 | 768 | self.logger.debug("Resources successfully cleaned up.") 769 | return 770 | -------------------------------------------------------------------------------- /ShortsMaker/shorts_maker.py: -------------------------------------------------------------------------------- 1 | import json 2 | import secrets 3 | from collections.abc import Generator 4 | from pathlib import Path 5 | from pprint import pformat 6 | from typing import Any 7 | 8 | import ftfy 9 | import language_tool_python 10 | import praw 11 | import yaml 12 | from praw.models import Submission, Subreddit 13 | from unidecode import unidecode 14 | 15 | from .utils import VOICES, generate_audio_transcription, get_logger, retry, tts 16 | 17 | # needed for retry decorator 18 | MAX_RETRIES: int = 1 19 | DELAY: int = 0 20 | NOTIFY: bool = False 21 | 22 | # Constants 23 | PUNCTUATIONS = [".", ";", ":", "!", "?", '"'] 24 | ESCAPE_CHARACTERS = ["\n", "\t", "\r", " "] 25 | # abbreviation, replacement, padding 26 | ABBREVIATION_TUPLES = [ 27 | ("\n", " ", ""), 28 | ("\t", " ", ""), 29 | ("\r", " ", ""), 30 | ("(", "", ""), 31 | (")", "", ""), 32 | ("AITA", "Am I the asshole ", " "), 33 | ("WIBTA", "Would I be the asshole ", " "), 34 | ("NTA", "Not the asshole ", " "), 35 | ("YTA", "You're the Asshole", ""), 36 | ("YWBTA", "You Would Be the Asshole", ""), 37 | ("YWNBTA", "You Would Not be the Asshole", ""), 38 | ("ESH", "Everyone Sucks here", ""), 39 | ("NAH", "No Assholes here", ""), 40 | ("INFO", "Not Enough Info", ""), 41 | ("FIL", "father in law ", " "), 42 | ("BIL", "brother in law ", " "), 43 | ("MIL", "mother in law ", " "), 44 | ("SIL", "sister in law ", " "), 45 | (" BF ", " boyfriend ", ""), 46 | (" GF ", " girlfriend ", ""), 47 | (" bf ", " boyfriend ", ""), 48 | (" gf ", " girlfriend ", ""), 49 | (" ", " ", ""), 50 | ] 51 | 52 | 53 | def abbreviation_replacer(text: str, abbreviation: str, replacement: str, padding: str = "") -> str: 54 | """ 55 | Replaces all occurrences of an abbreviation within a given text with a specified replacement. 56 | 57 | This function allows replacing abbreviations with a replacement string while considering optional padding 58 | around the abbreviation. Padding ensures the abbreviation is correctly replaced regardless of its position 59 | in the text or the surrounding characters. 60 | 61 | Args: 62 | text (str): The text where the abbreviation occurrences should be replaced. 63 | abbreviation (str): The abbreviation to be replaced in the text. 64 | replacement (str): The string to replace the abbreviation with. 
65 | padding (str, optional): Additional characters surrounding the abbreviation, making the 66 | replacement match more specific. Default is an empty string. 67 | 68 | Returns: 69 | str: The text with all occurrences of the abbreviation replaced by the replacement string. 70 | """ 71 | text = text.replace(abbreviation + padding, replacement) 72 | text = text.replace(padding + abbreviation, replacement) 73 | return text 74 | 75 | 76 | def has_alpha_and_digit(word: str) -> bool: 77 | """ 78 | Determines if a string contains both alphabetic and numeric characters. 79 | 80 | This function checks whether the given string contains at least one alphabetic 81 | character and at least one numeric character. It utilizes Python's string methods 82 | to identify the required character types. 83 | 84 | Args: 85 | word: The string to check for the presence of alphabetic and numeric 86 | characters. 87 | 88 | Returns: 89 | bool: True if the string contains at least one alphabetic character and one 90 | numeric character, otherwise False. 91 | """ 92 | return any(character.isalpha() for character in word) and any( 93 | character.isdigit() for character in word 94 | ) 95 | 96 | 97 | def split_alpha_and_digit(word): 98 | """ 99 | Splits a given string into separate segments of alphabetic and numeric sequences. 100 | 101 | This function processes each character in the input string and divides it into 102 | distinct groups of alphabetic sequences and numeric sequences. A space is added 103 | between these groups whenever a transition occurs between alphabetic and numeric 104 | characters, or vice versa. Non-alphanumeric characters are included as is without 105 | causing a split. 106 | 107 | Args: 108 | word (str): The input string to be split into alphabetic and numeric 109 | segments. 110 | 111 | Returns: 112 | str: A string where alphabetic and numeric segments from the input are 113 | separated by a space while retaining other characters. 114 | """ 115 | res = "" 116 | alpha = False 117 | digit = False 118 | for character in word: 119 | if character.isalpha(): 120 | alpha = True 121 | if digit: 122 | res += " " 123 | digit = False 124 | res += character 125 | elif character.isdigit(): 126 | digit = True 127 | if alpha: 128 | res += " " 129 | alpha = False 130 | res += character 131 | else: 132 | res += character 133 | return res 134 | 135 | 136 | class ShortsMaker: 137 | """ 138 | Represents a utility class to facilitate the creation of video shorts from 139 | text and audio assets. The class manages configuration, logging, processing 140 | of text for audio generation, reddit post retrieval, and asset handling among 141 | other operations. 142 | 143 | The `ShortsMaker` class is designed to be highly extensible and configurable 144 | via YAML configuration files. It includes robust error handling for invalid 145 | or missing configuration files and directories. The functionality also integrates 146 | with external tools such as Reddit API and grammar correction tools to streamline 147 | the process of creating video shorts. 148 | 149 | Attributes: 150 | VALID_CONFIG_EXTENSION (str): Expected file extension for configuration files. 151 | setup_cfg (Path): Path of the validated configuration file. 152 | cfg (dict): Parsed configuration from the loaded configuration file. 153 | cache_dir (Path): Directory for storing temporary files and intermediate data. 154 | logging_cfg (dict): Configuration for setting up logging. 155 | logger (Logger): Logger instance used for logging events and errors. 
156 | retry_cfg (dict): Configuration parameters for retry logic, including maximum 157 | retries and delay between retries. 158 | word_transcript (str | None): Transcript represented as individual words. 159 | line_transcript (str | None): Transcript represented as individual lines. 160 | transcript (str | None): Full transcript derived from the Reddit post or input text. 161 | audio_cfg (dict | None): Configuration details specific to audio processing. 162 | reddit_post (dict | None): Details related to the Reddit post being processed. 163 | reddit_cfg (dict | None): Configuration details specific to Reddit API integration. 164 | """ 165 | 166 | VALID_CONFIG_EXTENSION = ".yml" 167 | 168 | def __init__(self, config_file: Path | str) -> None: 169 | self.setup_cfg = self._validate_config_path(config_file) 170 | self.cfg = self._load_config() 171 | self.logger = get_logger(__name__) 172 | self.cache_dir = self._setup_cache_directory() 173 | self.retry_cfg = self._setup_retry_config() 174 | 175 | # Initialize other instance variables 176 | self.word_transcript: str | None = None 177 | self.line_transcript: str | None = None 178 | self.transcript: str | None = None 179 | self.audio_cfg: dict | None = None 180 | self.reddit_post: dict | None = None 181 | self.reddit_cfg: dict | None = None 182 | 183 | def _validate_config_path(self, config_file: Path | str) -> Path: 184 | """ 185 | Validates the given configuration file path to ensure it exists and has the correct format. 186 | 187 | This method checks whether the provided file path points to an actual file and 188 | whether its extension matches the expected configuration file format. If any 189 | of these conditions are not met, appropriate exceptions are raised. 190 | 191 | Args: 192 | config_file: A file path string or a Path object representing the 193 | configuration file to be validated. 194 | 195 | Returns: 196 | The validated configuration file path as a Path object. 197 | 198 | Raises: 199 | FileNotFoundError: If the configuration file does not exist. 200 | ValueError: If the configuration file format is invalid. 201 | """ 202 | config_path = Path(config_file) if isinstance(config_file, str) else config_file 203 | if not config_path.exists(): 204 | raise FileNotFoundError(f"Configuration file not found: {config_path}") 205 | if config_path.suffix != self.VALID_CONFIG_EXTENSION: 206 | raise ValueError( 207 | f"Invalid configuration file format. Expected {self.VALID_CONFIG_EXTENSION}" 208 | ) 209 | return config_path 210 | 211 | def _load_config(self) -> dict[str, Any]: 212 | """ 213 | Loads and parses configuration data from a YAML file. 214 | 215 | This method attempts to open and parse a YAML configuration file specified 216 | by the `setup_cfg` attribute of the class. If the file does not contain valid 217 | YAML or cannot be read, it raises an exception with an appropriate error message. 218 | 219 | Returns: 220 | Dict[str, Any]: A dictionary representation of the loaded YAML configuration. 221 | 222 | Raises: 223 | ValueError: If the YAML file contains invalid content or cannot be parsed. 224 | """ 225 | try: 226 | with open(self.setup_cfg) as f: 227 | return yaml.safe_load(f) 228 | except yaml.YAMLError as e: 229 | raise ValueError(f"Invalid YAML configuration: {e}") 230 | 231 | def _setup_cache_directory(self) -> Path: 232 | """ 233 | Sets up the cache directory based on the configuration and ensures its existence. 
234 |
235 |         This method retrieves the cache directory path from the configuration, creates the
236 |         directory (including any required parent directories), and returns the Path object
237 |         representing the cache directory. If the directory already exists, it will not attempt
238 |         to create it again.
239 |
240 |         Returns:
241 |             Path: A Path object representing the cache directory.
242 |         """
243 |         if "cache_dir" not in self.cfg:
244 |             self.logger.info("Cache directory not specified, defaulting to the current working directory.")
245 |             self.cfg["cache_dir"] = Path.cwd()
246 |         cache_dir = Path(self.cfg["cache_dir"])
247 |         cache_dir.mkdir(parents=True, exist_ok=True)
248 |         return cache_dir
249 |
250 |     def _setup_retry_config(self) -> dict[str, Any]:
251 |         """
252 |         Configures the retry mechanism based on the settings provided in the
253 |         configuration. If no retry section is present, default values are used.
254 |         Updates global constants to reflect the current retry configuration.
255 |
256 |         Returns:
257 |             Dict[str, Any]: The updated retry configuration dictionary.
258 |         """
259 |         # Start from safe defaults, then overlay any user-provided settings
260 |         # from the optional "retry" section of the configuration.
261 |         retry_config = {"max_retries": 1, "delay": 0, "notify": False}
262 |         if "retry" in self.cfg:
263 |             retry_config.update(self.cfg["retry"])
264 |
265 |         global MAX_RETRIES, DELAY, NOTIFY
266 |         MAX_RETRIES = retry_config["max_retries"]
267 |         DELAY = retry_config["delay"]
268 |         NOTIFY = retry_config["notify"]
269 |
270 |         return retry_config
271 |
272 |     def get_submission_from_subreddit(
273 |         self, reddit: praw.Reddit, subreddit_name: str
274 |     ) -> Generator[Submission]:
275 |         """
276 |         Yields submissions from the hot listing of the specified subreddit.
277 |
278 |         Args:
279 |             reddit (praw.Reddit): An instance of the Reddit API client.
280 |             subreddit_name (str): The name of the subreddit to fetch submissions from.
281 |             submission_category (str): The category of submissions to filter by (e.g., "hot", "new") TODO.
282 |
283 |         Yields:
284 |             Submission: A Reddit submission object from the subreddit's hot listing.
285 |         """
286 |         subreddit: Subreddit = reddit.subreddit(subreddit_name)
287 |         self.logger.info(f"Subreddit title: {subreddit.title}")
288 |         self.logger.info(f"Subreddit display name: {subreddit.display_name}")
289 |         yield from subreddit.hot()
290 |
291 |     def is_unique_submission(self, submission: Submission) -> bool:
292 |         """
293 |         Checks whether the given Reddit submission has already been processed, using its
294 |         fullname as the record key.
295 |
296 |         Args:
297 |             submission (Submission): The Reddit submission to check.
298 |
299 |         Returns:
300 |             bool: True if the submission has not been seen before, False otherwise.
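        Example:
            Illustrative flow mirroring ``get_reddit_post``; ``sm`` is an initialized
            ``ShortsMaker`` and the subreddit name is a placeholder::

                for candidate in sm.get_submission_from_subreddit(reddit, "AmItheAsshole"):
                    if sm.is_unique_submission(candidate):
                        submission = candidate
                        break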
300 | """ 301 | submission_dirs = self.cache_dir / "reddit_submissions" 302 | submission_dirs.mkdir(parents=True, exist_ok=True) 303 | self.logger.debug("Checking if submission is unique") 304 | self.logger.debug(f"Submission ID: {submission.id}") 305 | if any(f"{submission.name}.json" == file.name for file in submission_dirs.iterdir()): 306 | self.logger.info(f"Submission {submission.name} - '{submission.title}' already exists") 307 | return False 308 | else: 309 | with open(submission_dirs / f"{submission.name}.json", "w") as record_file: 310 | # Object of type Reddit is not JSON serializable, hence need to use vars 311 | json.dump( 312 | {key: str(value) for key, value in vars(submission).items()}, 313 | record_file, 314 | indent=4, 315 | skipkeys=True, 316 | sort_keys=True, 317 | ) 318 | self.logger.debug("Unique submission found") 319 | self.logger.info(f"Submission saved to {submission_dirs / f'{submission.name}.json'}") 320 | return True 321 | 322 | @retry(max_retries=MAX_RETRIES, delay=DELAY, notify=NOTIFY) 323 | def get_reddit_post(self, url: str | None = None) -> str: 324 | """ 325 | Retrieves a random top Reddit post from a specified subreddit, saves the post details 326 | to both a JSON file and a text file, and returns the text content of the post. 327 | 328 | Args: 329 | url (str | None): The URL of the Reddit post to retrieve. If None, a random top post is retrieved. 330 | 331 | Returns: 332 | str: The text content of the retrieved Reddit post. 333 | 334 | Raises: 335 | ValueError: If any value processing errors occur. 336 | IOError: If file handling (reading/writing) fails. 337 | praw.exceptions.PRAWException: If PRAW encounters an API or authentication issue. 338 | """ 339 | self.reddit_cfg = self.cfg["reddit_praw"] 340 | self.reddit_post = self.cfg["reddit_post_getter"] 341 | 342 | self.logger.info("Getting Reddit post") 343 | reddit: praw.Reddit = praw.Reddit( 344 | client_id=self.reddit_cfg["client_id"], 345 | client_secret=self.reddit_cfg["client_secret"], 346 | user_agent=self.reddit_cfg["user_agent"], 347 | # username=self.reddit_cfg["username"], 348 | # password=self.reddit_cfg["password"] 349 | ) 350 | self.logger.info(f"Is reddit readonly: {reddit.read_only}") 351 | 352 | if url: 353 | submission = reddit.submission(url=url) 354 | else: 355 | for submission_found in self.get_submission_from_subreddit( 356 | reddit, self.reddit_post["subreddit_name"] 357 | ): 358 | if self.is_unique_submission(submission_found): 359 | submission = submission_found 360 | break 361 | 362 | self.logger.info(f"Submission Url: {submission.url}") 363 | self.logger.info(f"Submission title: {submission.title}") 364 | 365 | data = dict() 366 | for key, value in vars(submission).items(): 367 | data[key] = str(value) 368 | 369 | # Save the submission to a json file 370 | with open(self.cache_dir / self.reddit_post["record_file_json"], "w") as record_file: 371 | # noinspection PyTypeChecker 372 | json.dump(data, record_file, indent=4, skipkeys=True, sort_keys=True) 373 | self.logger.info( 374 | f"Submission saved to {self.cache_dir / self.reddit_post['record_file_json']}" 375 | ) 376 | 377 | # Save the submission to a text file 378 | with open(self.cache_dir / self.reddit_post["record_file_txt"], "w") as text_file: 379 | text_file.write(unidecode(ftfy.fix_text(submission.title)) + "." 
+ "\n") 380 | text_file.write(unidecode(ftfy.fix_text(submission.selftext)) + "\n") 381 | self.logger.info( 382 | f"Submission text saved to {self.cache_dir / self.reddit_post['record_file_txt']}" 383 | ) 384 | 385 | # return the generated file contents 386 | with open(self.cache_dir / self.reddit_post["record_file_txt"]) as result_file: 387 | result_string = result_file.read() 388 | return result_string 389 | 390 | @retry(max_retries=MAX_RETRIES, delay=DELAY, notify=NOTIFY) 391 | def fix_text(self, source_txt: str, debug: bool = True) -> str: 392 | """ 393 | Fixes and corrects grammatical and textual issues in the provided text input using language processing tools. 394 | The method processes the input text by fixing encoding issues, normalizing it, splitting it into sentences, 395 | and then correcting the grammar of each individual sentence. An optional debug mode saves the processed text 396 | to a debug file for inspection. 397 | 398 | Args: 399 | source_txt: The text to be processed and corrected. 400 | debug: If True, saves the corrected text to a debug file for further analysis. 401 | 402 | Returns: 403 | str: The corrected and formatted text. 404 | 405 | Raises: 406 | Exception: Raised if errors occur during text correction within individual sentences. 407 | """ 408 | self.logger.info("Setting up language tool text fixer") 409 | grammar_fixer = language_tool_python.LanguageTool("en-US") 410 | 411 | source_txt = ftfy.fix_text(source_txt) 412 | source_txt = unidecode(source_txt) 413 | for escape_char in ESCAPE_CHARACTERS: 414 | source_txt = source_txt.replace(escape_char, " ") 415 | 416 | sentences = [] 417 | res = [] 418 | 419 | for word in source_txt.split(" "): 420 | if word == "": 421 | continue 422 | if word[0] in PUNCTUATIONS: 423 | sentences.append(" ".join(res)) 424 | res = [] 425 | res.append(word) 426 | if word[-1] in PUNCTUATIONS: 427 | sentences.append(" ".join(res)) 428 | res = [] 429 | 430 | self.logger.info( 431 | f"Split text into sentences and fixed text. Found {len(sentences)} sentences" 432 | ) 433 | 434 | corrected_sentences = [] 435 | for sentence in sentences: 436 | try: 437 | corrected_sentences.append(grammar_fixer.correct(sentence)) 438 | except Exception as e: 439 | self.logger.error(f"Error: {e}") 440 | corrected_sentences.append(sentence) 441 | 442 | grammar_fixer.close() 443 | result_string = " ".join(corrected_sentences) 444 | 445 | if debug: 446 | with open(self.cache_dir / "fix_text_debug.txt", "w") as text_file: 447 | text_file.write(result_string) 448 | self.logger.info(f"Debug text saved to {self.cache_dir / 'fix_text_debug.txt'}") 449 | 450 | return result_string 451 | 452 | @retry(max_retries=MAX_RETRIES, delay=DELAY, notify=NOTIFY) 453 | def generate_audio( 454 | self, 455 | source_txt: str, 456 | output_audio: str | None = None, 457 | output_script_file: str | None = None, 458 | seed: str | None = None, 459 | ) -> bool: 460 | """ 461 | Generates audio from a given textual input. The function processes the input text, 462 | performs text transformations (e.g., replacing abbreviations and splitting alphanumeric 463 | combinations), and uses a synthesized voice to create an audio file. It also writes the 464 | processed script to a text file. Speaker selection is either randomized or based on the 465 | provided seed. 466 | 467 | Args: 468 | source_txt (str): The input text to be converted into audio. 469 | output_audio (str | None): The path to save the generated audio. 
If not provided, a 470 | default path is generated based on the configuration or cache directory. 471 | output_script_file (str | None): The file path to save the processed text script. If 472 | not provided, a default path is generated based on the configuration or cache 473 | directory. 474 | seed (str | None): An optional seed to determine the choice of speaker. If not 475 | provided, the function randomly selects a speaker. Refer to VOICES for available 476 | speakers. 477 | 478 | Returns: 479 | bool: Returns True if the audio generation is successful; False otherwise. 480 | 481 | Raises: 482 | Exception: If an error occurs during text-to-speech processing. 483 | """ 484 | self.audio_cfg = self.cfg["audio"] 485 | if output_audio is None: 486 | self.logger.info("No output audio file specified. Generating output audio file") 487 | if "output_audio_file" in self.audio_cfg: 488 | output_audio = self.cache_dir / self.audio_cfg["output_audio_file"] 489 | else: 490 | output_audio = self.cache_dir / "output.wav" 491 | 492 | if output_script_file is None: 493 | self.logger.info("No output script file specified. Generating output script file") 494 | if "output_script_file" in self.audio_cfg: 495 | output_script_file = self.cache_dir / self.audio_cfg["output_script_file"] 496 | else: 497 | output_script_file = self.cache_dir / "generated_audio_script.txt" 498 | 499 | self.logger.info("Generating audio from text") 500 | for abbreviation, replacement, padding in ABBREVIATION_TUPLES: 501 | source_txt = abbreviation_replacer(source_txt, abbreviation, replacement, padding) 502 | source_txt = source_txt.strip() 503 | 504 | for s in source_txt.split(" "): 505 | if has_alpha_and_digit(s): 506 | source_txt = source_txt.replace(s, split_alpha_and_digit(s)) 507 | 508 | with open(output_script_file, "w") as text_file: 509 | text_file.write(source_txt) 510 | self.logger.info(f"Text saved to {output_script_file}") 511 | 512 | if seed is None: 513 | speaker = secrets.choice(VOICES) 514 | else: 515 | speaker = seed 516 | 517 | self.logger.info(f"Generating audio with speaker: {speaker}") 518 | 519 | try: 520 | tts(source_txt, speaker, output_audio) 521 | self.logger.info( 522 | f"Successfully generated audio.\nSpeaker: {speaker}\nOutput path: {output_audio}" 523 | ) 524 | except Exception as e: 525 | self.logger.error(f"Error: {e}") 526 | self.logger.error("Failed to generate audio with tiktokvoice") 527 | return False 528 | 529 | return True 530 | 531 | @retry(max_retries=MAX_RETRIES, delay=DELAY, notify=NOTIFY) 532 | def generate_audio_transcript( 533 | self, 534 | source_audio_file: Path | str, 535 | source_text_file: Path | str, 536 | output_transcript_file: str | None = None, 537 | debug: bool = True, 538 | ) -> list[dict[str, str | float]]: 539 | """ 540 | Generates an audio transcript by processing a source audio file and its corresponding text 541 | file, using predefined configurations such as model, device, and batch size. Saves the 542 | resulting transcript into a specified output file or a default cache location. Additionally, 543 | provides an option to enable debug logging. 544 | 545 | Args: 546 | source_audio_file (Path): The source audio file to be transcribed. 547 | source_text_file (Path): The text file containing the corresponding script. 548 | output_transcript_file (str | None): The file where the resulting transcript will be saved. 549 | Defaults to a predefined location if not specified. 550 | debug (bool): Whether to enable debug logging of the processed transcript. 
551 |
552 |         Returns:
553 |             list[dict[str, str | float]]: A list of word-level transcription data, where each entry
554 |                 contains the word text along with its start and end timestamps.
555 |         """
556 |         self.audio_cfg = self.cfg["audio"]
557 |         self.logger.info("Generating audio transcript")
558 |
559 |         # read the script
560 |         with open(source_text_file) as text_file:
561 |             source_text = text_file.read()
562 |
563 |         self.word_transcript = generate_audio_transcription(
564 |             audio_file=str(source_audio_file),
565 |             script=source_text,
566 |             device=self.audio_cfg["device"],
567 |             model=self.audio_cfg["model"],
568 |             batch_size=self.audio_cfg["batch_size"],
569 |             compute_type=self.audio_cfg["compute_type"],
570 |         )
571 |         self.word_transcript = self._filter_word_transcript(self.word_transcript)
572 |
573 |         if output_transcript_file is None:
574 |             output_transcript_file = self.cache_dir / self.audio_cfg["transcript_json"]
575 |
576 |         self.logger.info(f"Saving transcript to {output_transcript_file}")
577 |
578 |         with open(output_transcript_file, "w") as transcript_file:
579 |             # noinspection PyTypeChecker
580 |             json.dump(
581 |                 self.word_transcript,
582 |                 transcript_file,
583 |                 indent=4,
584 |                 skipkeys=True,
585 |                 sort_keys=True,
586 |             )
587 |
588 |         if debug:
589 |             self.logger.info(pformat(self.word_transcript))
590 |
591 |         return self.word_transcript
592 |
593 |     def _filter_word_transcript(
594 |         self, transcript: list[dict[str, str | float]]
595 |     ) -> list[dict[str, str | float]]:
596 |         # Keep entries that start after 0s and last under 5s; drops likely alignment artifacts
597 |         return [
598 |             entry
599 |             for entry in transcript
600 |             if entry["start"] > 0 and (entry["end"] - entry["start"]) < 5
601 |         ]
602 |
603 |     def quit(self) -> None:
604 |         """
605 |         Closes and cleans up resources used in the class instance.
606 |
607 |         This method ensures that all resources, tools, and variables used within the
608 |         class instance are properly closed or removed to prevent memory leaks or issues
609 |         when the instance is no longer in use. It includes closing language tools, if
610 |         utilized, and deleting all instance variables except the logger.
611 |
612 |         Raises:
613 |             Exception: If there is an issue closing the grammar fixer or deleting
614 |                 instance variables. Specific details are logged.
615 |
616 |         Returns:
617 |             None: This method does not return any value.
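        Example:
            A hedged end-to-end sketch showing where ``quit`` fits; paths follow the
            defaults used elsewhere in this class::

                sm = ShortsMaker("setup.yml")
                script = sm.fix_text(sm.get_reddit_post())
                sm.generate_audio(script)
                sm.generate_audio_transcript(
                    source_audio_file=sm.cache_dir / "output.wav",
                    source_text_file=sm.cache_dir / "generated_audio_script.txt",
                )
                sm.quit()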
618 |         """
619 |         self.logger.debug("Closing and cleaning up resources.")
620 |         # Close the language tool if it was used
621 |         if hasattr(self, "grammar_fixer") and self.grammar_fixer:
622 |             try:
623 |                 self.grammar_fixer.close()
624 |             except Exception as e:
625 |                 self.logger.error(f"Error closing grammar fixer: {e}")
626 |
627 |         # Delete all instance variables
628 |         for attr in list(self.__dict__.keys()):
629 |             try:
630 |                 self.logger.debug(f"Deleting {attr}")
631 |                 if attr == "logger":
632 |                     continue
633 |                 delattr(self, attr)
634 |             except Exception as e:
635 |                 self.logger.warning(f"Error deleting {attr}: {e}")
636 |
637 |         self.logger.debug("All objects in the class have been deleted.")
638 |         return
639 |
-------------------------------------------------------------------------------- /ShortsMaker/utils/__init__.py: --------------------------------------------------------------------------------
1 | from .audio_transcript import align_transcript_with_script, generate_audio_transcription
2 | from .colors_dict import COLORS_DICT
3 | from .download_youtube_music import download_youtube_music, sanitize_filename
4 | from .download_youtube_video import download_youtube_video
5 | from .get_tts import VOICES, tts
6 | from .logging_config import configure_logging, get_logger
7 | from .notify_discord import notify_discord
8 | from .retry import retry
9 |
10 | __all__ = [
11 |     "align_transcript_with_script",
12 |     "configure_logging",
13 |     "download_youtube_music",
14 |     "download_youtube_video",
15 |     "generate_audio_transcription",
16 |     "get_logger",
17 |     "notify_discord",
18 |     "retry",
19 |     "sanitize_filename",
20 |     "tts",
21 |     "COLORS_DICT",
22 |     "VOICES",
23 | ]
24 |
25 | # Configure logging with the package defaults
26 | configure_logging()
27 |
-------------------------------------------------------------------------------- /ShortsMaker/utils/audio_transcript.py: --------------------------------------------------------------------------------
1 | import gc
2 | from pprint import pformat
3 |
4 | import torch
5 | import whisperx
6 | from rapidfuzz import process
7 |
8 | from .logging_config import get_logger
9 |
10 | logger = get_logger(__name__)
11 |
12 |
13 | def align_transcript_with_script(transcript: list[dict], script_string: str) -> list[dict]:
14 |     """
15 |     Aligns the transcript entries with corresponding segments of a script string by
16 |     comparing text similarities and finding the best matches. This process adjusts
17 |     each transcript entry to align as closely as possible with the correct script
18 |     segment while maintaining the temporal information from the transcript.
19 |
20 |     Args:
21 |         transcript (list[dict]): A list of dictionaries, each containing a segment
22 |             of the transcript with keys "text" (text of the segment), "start"
23 |             (start time), and "end" (end time).
24 |         script_string (str): The entire script as a single string to which the
25 |             transcript is aligned.
26 |
27 |     Returns:
28 |         list[dict]: A list containing the transcript with updated "text" fields
29 |             that are aligned to the most similar segments of the script. The
30 |             "start" and "end" fields remain unchanged from the input.
31 |
32 |     Raises:
33 |         ValueError: If either the transcript or script_string is empty. Applies
34 |             to cases where alignment cannot be performed.
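    Example:
        A small illustrative sketch with made-up timings::

            transcript = [{"text": "helo wrld", "start": 0.0, "end": 1.0}]
            aligned = align_transcript_with_script(transcript, "hello world")
            # aligned -> [{"text": "hello world", "start": 0.0, "end": 1.0}]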
35 | 36 | """ 37 | temp_transcript = [] 38 | window_sizes = [i for i in range(6)] 39 | script_words = script_string.split() 40 | 41 | for entry in transcript: 42 | possible_windows = [] 43 | length_of_entry_text = len(entry["text"].split()) 44 | 45 | # Generate script windows for all specified window sizes 46 | for window_size in window_sizes: 47 | possible_windows.extend([" ".join(script_words[: length_of_entry_text + window_size])]) 48 | possible_windows.extend([" ".join(script_words[: length_of_entry_text - window_size])]) 49 | 50 | # Find the best match among all possible windows 51 | # print(f"Entry text: {entry['text']}\n" 52 | # f"Possible windows: {possible_windows}" 53 | # "\n\n\n" 54 | # ) 55 | best_match, score, _ = process.extractOne(entry["text"], possible_windows) 56 | 57 | if best_match: 58 | script_words = script_words[len(best_match.split()) :] 59 | 60 | # print( 61 | # f"Best match: {best_match}, Score: {score} " 62 | # f"Script words remaining: {len(script_words)}" 63 | # f"Script words: {script_words} \n\n" 64 | # ) 65 | 66 | # Add the match or original text to the new transcript 67 | temp_transcript.append( 68 | { 69 | "text": best_match if best_match else entry["text"], 70 | "start": entry["start"], 71 | "end": entry["end"], 72 | } 73 | ) 74 | return temp_transcript 75 | 76 | 77 | def generate_audio_transcription( 78 | audio_file: str, 79 | script: str, 80 | device="cuda", 81 | batch_size=16, 82 | compute_type="float16", 83 | model="large-v2", 84 | ) -> list[dict[str, str | float]]: 85 | """ 86 | Generates a transcription of an audio file by performing speech-to-text transcription and aligning the 87 | transcription with a given script. It utilizes whisper models for transcription and alignment to improve 88 | accuracy. 89 | 90 | This function processes the audio in batches, aligns the transcriptions with the provided script for better 91 | accuracy, and cleans GPU memory usage during its workflow. It outputs a list of word-level transcriptions 92 | with start and end times for enhanced downstream processing. 93 | 94 | Args: 95 | audio_file (str): The path to the audio file that needs to be transcribed. 96 | script (str): The text script used for alignment with the transcribed segments. 97 | device (str): The device to be used for computation, default is 'cuda'. 98 | batch_size (int): The batch size to use during transcription, default is 16. 99 | compute_type (str): The precision type to be used for the model, default is "float16". 100 | model (str): The Whisper model variant to use, default is "large-v2". Options include "medium", 101 | "large-v2", and "large-v3". 102 | 103 | Returns: 104 | list[dict[str, str | float]]: A list of dictionaries, where each dictionary represents a word in 105 | the transcription with the word text, start time, and end time. 106 | 107 | Raises: 108 | Could include potential runtime or memory-related errors specific to the underlying 109 | libraries or resource management. 110 | """ 111 | # 1. 
Transcribe with original whisper (batched) 112 | # options for models medium, large-v2, large-v3 113 | model = whisperx.load_model(model, device, compute_type=compute_type) 114 | 115 | audio = whisperx.load_audio(audio_file) 116 | result = model.transcribe(audio, batch_size=batch_size, language="en") 117 | logger.debug(f"Before Alignment:\n {pformat(result['segments'])}") # before alignment 118 | 119 | new_aligned_transcript = align_transcript_with_script(result["segments"], script) 120 | 121 | # delete model if low on GPU resources 122 | gc.collect() 123 | if torch.cuda.is_available(): 124 | torch.cuda.empty_cache() 125 | del model 126 | 127 | # 2. Align whisper output 128 | model_a, metadata = whisperx.load_align_model(language_code=result["language"], device=device) 129 | result = whisperx.align( 130 | new_aligned_transcript, 131 | model_a, 132 | metadata, 133 | audio, 134 | device, 135 | return_char_alignments=False, 136 | ) 137 | 138 | logger.debug( 139 | f"After Alignment:\n {pformat(result['segments'])}" 140 | ) # before alignment # after alignment 141 | 142 | word_transcript = [] 143 | for segments in result["segments"]: 144 | for index, word in enumerate(segments["words"]): 145 | if "start" not in word: 146 | word["start"] = segments["words"][index - 1]["end"] if index > 0 else 0 147 | word["end"] = ( 148 | segments["words"][index + 1]["start"] 149 | if index < len(segments["words"]) - 1 150 | else segments["words"][-1]["start"] 151 | ) 152 | 153 | word_transcript.append( 154 | {"word": word["word"], "start": word["start"], "end": word["end"]} 155 | ) 156 | 157 | logger.debug(f"Transcript:\n {pformat(word_transcript)}") # before alignment 158 | 159 | gc.collect() 160 | if torch.cuda.is_available(): 161 | torch.cuda.empty_cache() 162 | del model_a 163 | 164 | return word_transcript 165 | -------------------------------------------------------------------------------- /ShortsMaker/utils/colors_dict.py: -------------------------------------------------------------------------------- 1 | # A list of colors and their equivalent RGBA values 2 | COLORS_DICT: dict[str, tuple[int, int, int, int]] = { 3 | "aquamarine": (127, 255, 212, 255), 4 | "aquamarine2": (118, 238, 198, 255), 5 | "azure1": (240, 255, 255, 255), 6 | "blue": (0, 0, 255, 255), 7 | "CadetBlue": (152, 245, 255, 255), 8 | "CadetBlue2": (142, 229, 238, 255), 9 | "CadetBlue3": (122, 197, 205, 255), 10 | "coral": (255, 127, 80, 255), 11 | "coral1": (255, 114, 86, 255), 12 | "coral2": (238, 106, 80, 255), 13 | "CornflowerBlue": (100, 149, 237, 255), 14 | "cornsilk1": (255, 248, 220, 255), 15 | "cyan": (0, 255, 255, 255), 16 | "cyan2": (0, 238, 238, 255), 17 | "cyan3": (0, 205, 205, 255), 18 | "DarkGoldenrod": (255, 185, 15, 255), 19 | "DarkGoldenrod2": (238, 173, 14, 255), 20 | "DarkOliveGreen": (202, 255, 112, 255), 21 | "DarkOliveGreen2": (188, 238, 104, 255), 22 | "DarkOliveGreen3": (162, 205, 90, 255), 23 | "DarkSalmon": (233, 150, 122, 255), 24 | "DarkSeaGreen": (143, 188, 143, 255), 25 | "DarkSeaGreen1": (193, 255, 193, 255), 26 | "DarkSeaGreen2": (180, 238, 180, 255), 27 | "DarkSeaGreen3": (155, 205, 155, 255), 28 | "DarkSlateGray": (151, 255, 255, 255), 29 | "DarkSlateGray2": (141, 238, 238, 255), 30 | "DarkSlateGray3": (121, 205, 205, 255), 31 | "DarkTurquoise": (0, 206, 209, 255), 32 | "DeepPink": (255, 20, 147, 255), 33 | "DeepSkyBlue": (0, 191, 255, 255), 34 | "DeepSkyBlue2": (0, 178, 238, 255), 35 | "DodgerBlue": (30, 144, 255, 255), 36 | "FloralWhite": (255, 250, 240, 255), 37 | "gold": (255, 215, 0, 255), 38 | 
"gold2": (238, 201, 0, 255), 39 | "goldenrod": (255, 193, 37, 255), 40 | "goldenrod2": (238, 180, 34, 255), 41 | "goldenrod3": (205, 155, 29, 255), 42 | "HotPink": (255, 105, 180, 255), 43 | "HotPink1": (255, 110, 180, 255), 44 | "HotPink2": (238, 106, 167, 255), 45 | "IndianRed": (255, 106, 106, 255), 46 | "IndianRed2": (238, 99, 99, 255), 47 | "khaki": (240, 230, 140, 255), 48 | "khaki1": (255, 246, 143, 255), 49 | "khaki2": (238, 230, 133, 255), 50 | "khaki3": (205, 198, 115, 255), 51 | "lavender": (230, 230, 250, 255), 52 | "LawnGreen": (124, 252, 0, 255), 53 | "LemonChiffon": (255, 250, 205, 255), 54 | "LemonChiffon2": (238, 233, 191, 255), 55 | "LemonChiffon3": (205, 201, 165, 255), 56 | "LightBlue": (173, 216, 230, 255), 57 | "LightBlue1": (191, 239, 255, 255), 58 | "LightCyan": (224, 255, 255, 255), 59 | "LightCyan2": (209, 238, 238, 255), 60 | "LightCyan3": (180, 205, 205, 255), 61 | "LightGoldenrod": (255, 236, 139, 255), 62 | "LightGoldenrod2": (238, 220, 130, 255), 63 | "LightGoldenrod3": (205, 190, 112, 255), 64 | "LightGoldenrodYellow": (250, 250, 210, 255), 65 | "LightGreen": (144, 238, 144, 255), 66 | "LightPink": (255, 182, 193, 255), 67 | "LightPink1": (255, 174, 185, 255), 68 | "LightSalmon": (255, 160, 122, 255), 69 | "LightSeaGreen": (32, 178, 170, 255), 70 | "LightSkyBlue": (135, 206, 250, 255), 71 | "LightSkyBlue1": (176, 226, 255, 255), 72 | "LightSkyBlue2": (164, 211, 238, 255), 73 | "LightSkyBlue3": (141, 182, 205, 255), 74 | "LightSlateBlue": (132, 112, 255, 255), 75 | "LightYellow": (255, 255, 224, 255), 76 | "lime": (0, 255, 0, 255), 77 | "LimeGreen": (50, 205, 50, 255), 78 | "linen": (250, 240, 230, 255), 79 | "magenta": (255, 0, 255, 255), 80 | "maroon1": (255, 52, 179, 255), 81 | "MediumAquamarine": (102, 205, 170, 255), 82 | "MediumGoldenRod": (209, 193, 102, 255), 83 | "MediumOrchid": (224, 102, 255, 255), 84 | "MediumOrchid2": (209, 95, 238, 255), 85 | "MediumSeaGreen": (60, 179, 113, 255), 86 | "MediumSpringGreen": (0, 250, 154, 255), 87 | "MediumTurquoise": (72, 209, 204, 255), 88 | "MintCream": (245, 255, 250, 255), 89 | "MistyRose": (255, 228, 225, 255), 90 | "MistyRose2": (238, 213, 210, 255), 91 | "moccasin": (255, 228, 181, 255), 92 | "NavajoWhite": (255, 222, 173, 255), 93 | "OliveDrab": (192, 255, 62, 255), 94 | "OliveDrab2": (179, 238, 58, 255), 95 | "OliveDrab3": (154, 205, 50, 255), 96 | "orange": (255, 165, 0, 255), 97 | "orchid": (218, 112, 214, 255), 98 | "orchid1": (255, 131, 250, 255), 99 | "orchid2": (238, 122, 233, 255), 100 | "PaleGoldenrod": (238, 232, 170, 255), 101 | "PaleGreen": (152, 251, 152, 255), 102 | "PaleGreen1": (154, 255, 154, 255), 103 | "PaleGreen2": (144, 238, 144, 255), 104 | "PaleGreen3": (124, 205, 124, 255), 105 | "PaleTurquoise": (175, 238, 238, 255), 106 | "PaleTurquoise1": (187, 255, 255, 255), 107 | "PapayaWhip": (255, 239, 213, 255), 108 | "PeachPuff": (255, 218, 185, 255), 109 | "PeachPuff2": (238, 203, 173, 255), 110 | "peru": (205, 133, 63, 255), 111 | "pink": (255, 192, 203, 255), 112 | "pink1": (255, 181, 197, 255), 113 | "pink2": (238, 169, 184, 255), 114 | "plum": (221, 160, 221, 255), 115 | "plum1": (255, 187, 255, 255), 116 | "plum2": (238, 174, 238, 255), 117 | "PowderBlue": (176, 224, 230, 255), 118 | "red": (255, 0, 0, 255), 119 | "RosyBrown1": (255, 193, 193, 255), 120 | "salmon1": (255, 140, 105, 255), 121 | "SandyBrown": (244, 164, 96, 255), 122 | "SeaGreen": (84, 255, 159, 255), 123 | "SeaGreen2": (78, 238, 148, 255), 124 | "SeaGreen3": (67, 205, 128, 255), 125 | "seashell": (255, 245, 238, 
255), 126 | "seashell2": (238, 229, 222, 255), 127 | "sienna": (255, 130, 71, 255), 128 | "sienna2": (238, 121, 66, 255), 129 | "SkyBlue": (135, 206, 235, 255), 130 | "SkyBlue1": (135, 206, 255, 255), 131 | "SkyBlue2": (126, 192, 238, 255), 132 | "SkyBlue3": (108, 166, 205, 255), 133 | "SlateBlue": (131, 111, 255, 255), 134 | "SlateBlue2": (122, 103, 238, 255), 135 | "SlateGray1": (198, 226, 255, 255), 136 | "snow": (255, 250, 250, 255), 137 | "snow2": (238, 233, 233, 255), 138 | "SpringGreen": (0, 255, 127, 255), 139 | "SpringGreen2": (0, 238, 118, 255), 140 | "SteelBlue": (99, 184, 255, 255), 141 | "SteelBlue2": (92, 172, 238, 255), 142 | "tan": (210, 180, 140, 255), 143 | "tan1": (255, 165, 79, 255), 144 | "tan2": (238, 154, 73, 255), 145 | "thistle": (255, 225, 255, 255), 146 | "thistle2": (238, 210, 238, 255), 147 | "tomato": (255, 99, 71, 255), 148 | "turquoise": (64, 224, 208, 255), 149 | "turquoise1": (0, 245, 255, 255), 150 | "violet": (238, 130, 238, 255), 151 | "VioletRed": (255, 62, 150, 255), 152 | "VioletRed2": (238, 58, 140, 255), 153 | "white": (255, 255, 255, 255), 154 | "WhiteSmoke": (245, 245, 245, 255), 155 | "yellow": (255, 255, 0, 255), 156 | "yellow2": (238, 238, 0, 255), 157 | } 158 | -------------------------------------------------------------------------------- /ShortsMaker/utils/download_youtube_music.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | import yt_dlp 4 | 5 | from .logging_config import get_logger 6 | 7 | logger = get_logger(__name__) 8 | 9 | 10 | def sanitize_filename(source_filename: str) -> str: 11 | """ 12 | Sanitizes a given filename by removing leading and trailing spaces, replacing spaces with underscores, 13 | and replacing invalid characters with underscores. 14 | 15 | Args: 16 | source_filename (str): The original filename to be sanitized. 17 | 18 | Returns: 19 | str: The sanitized filename. 20 | """ 21 | sanitized_filename = source_filename 22 | sanitized_filename = sanitized_filename.strip() 23 | sanitized_filename = sanitized_filename.strip(" .") 24 | sanitized_filename = sanitized_filename.replace(" ", "_") 25 | invalid_chars = '<>:"/\\|?*' 26 | sanitized_filename = "".join("_" if c in invalid_chars else c for c in sanitized_filename) 27 | return sanitized_filename 28 | 29 | 30 | def download_youtube_music(music_url: str, music_dir: Path, force: bool = False) -> list[Path]: 31 | """ 32 | Downloads music from a YouTube URL provided, saving it to a specified directory. The method supports 33 | downloading either the full audio or splitting into chapters, if chapters are available in the video 34 | metadata. Optionally, existing files can be overwritten if the `force` flag is set. 35 | 36 | Args: 37 | music_url (str): The YouTube URL of the music video to download. 38 | music_dir (Path): The directory where the downloaded audio files will be saved. 39 | force (bool): Specifies whether existing files should be overwritten. Defaults to False. 40 | 41 | Returns: 42 | list[Path]: A list of paths to the downloaded audio files. 43 | """ 44 | ydl_opts = {} 45 | 46 | with yt_dlp.YoutubeDL(ydl_opts) as ydl: 47 | extracted_info = ydl.extract_info(music_url, download=False) 48 | info_dict = ydl.sanitize_info(extracted_info) 49 | 50 | logger.info(f"Music title: {info_dict['title']}") 51 | 52 | # Handle case with no chapters 53 | if not info_dict["chapters"]: 54 | logger.info("No chapters found. 
Downloading full audio...") 55 | sanitized_filename = sanitize_filename(info_dict["title"]) 56 | 57 | ydl_opts = { 58 | "format": "bestaudio", 59 | "outtmpl": str(music_dir / f"{sanitized_filename}.%(ext)s"), 60 | "postprocessors": [ 61 | { 62 | "key": "FFmpegExtractAudio", 63 | "preferredcodec": "wav", 64 | "preferredquality": "0", 65 | } 66 | ], 67 | "restrictfilenames": True, 68 | } 69 | 70 | output_path = music_dir / f"{sanitized_filename}.wav" 71 | logger.info(f"Output path: {output_path.absolute()}") 72 | if force or not output_path.exists(): 73 | with yt_dlp.YoutubeDL(ydl_opts) as ydl_audio: 74 | ydl_audio.download([music_url]) 75 | logger.info("Full audio downloaded successfully!") 76 | return [output_path] 77 | 78 | # Handle case with chapters 79 | for chapter in info_dict["chapters"]: 80 | logger.info(f"Found chapter: {chapter['title']}") 81 | sanitized_filename = sanitize_filename(chapter["title"]) 82 | 83 | ydl_opts = { 84 | "format": "bestaudio", 85 | "outtmpl": str(music_dir / f"{sanitized_filename}.%(ext)s"), 86 | "download_ranges": lambda chapter_range, *args: [ 87 | { 88 | "start_time": chapter["start_time"], 89 | "end_time": chapter["end_time"], 90 | "title": chapter["title"], 91 | } 92 | ], 93 | "force_keyframes_at_cuts": True, 94 | "postprocessors": [ 95 | { 96 | "key": "FFmpegExtractAudio", 97 | "preferredcodec": "wav", 98 | "preferredquality": "0", 99 | } 100 | ], 101 | "restrictfilenames": True, 102 | } 103 | 104 | output_path = music_dir / f"{sanitized_filename}.wav" 105 | logger.info(f"Output path: {output_path.absolute()}") 106 | if force or not output_path.exists(): 107 | with yt_dlp.YoutubeDL(ydl_opts) as ydl_chapter_audio: 108 | ydl_chapter_audio.download([music_url]) 109 | logger.info(f"Chapter downloaded: {chapter['title']}") 110 | 111 | # Return all .wav files found in the music directory 112 | music_files = list(music_dir.glob("*.wav")) 113 | return music_files 114 | -------------------------------------------------------------------------------- /ShortsMaker/utils/download_youtube_video.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | import yt_dlp 4 | 5 | from .logging_config import get_logger 6 | 7 | logger = get_logger(__name__) 8 | 9 | 10 | def download_youtube_video(video_url: str, video_dir: Path, force: bool = False) -> list[Path]: 11 | """ 12 | Downloads a YouTube video given its URL and stores it in the specified directory. The 13 | video is downloaded in the best available MP4 format, and its filename is sanitized to 14 | remove invalid characters. Provides an option to force download the video even if the 15 | target file already exists. 16 | 17 | Args: 18 | video_url: The URL of the video to be downloaded from YouTube. 19 | video_dir: The directory where the video will be saved. 20 | force: If True, forces the download even if the video file already exists. Defaults 21 | to False. 22 | 23 | Returns: 24 | list[Path]: A list containing the Path objects of the '.mp4' files in the specified 25 | directory after the download process.
26 | """ 27 | ydl_opts = { 28 | "format": "bestvideo[ext=mp4]+bestaudio[ext=m4a]/best[ext=mp4]/best", 29 | "merge_output_format": "mp4", 30 | "outtmpl": str(video_dir / "%(title)s.%(ext)s"), 31 | "restrictfilenames": True, 32 | } 33 | 34 | with yt_dlp.YoutubeDL(ydl_opts) as ydl: 35 | info = ydl.extract_info(video_url, download=False) 36 | info_dict = ydl.sanitize_info(info) 37 | 38 | logger.info(f"Video title: {info_dict['title']}") 39 | sanitized_filename = ydl.prepare_filename(info) 40 | logger.info(f"Sanitized filename will be: {sanitized_filename}") 41 | 42 | output_path = Path(sanitized_filename) # prepare_filename renders outtmpl, which already includes video_dir 43 | if force or not output_path.exists(): 44 | ydl.download([video_url]) 45 | logger.info("Video downloaded successfully!") 46 | 47 | bg_files = list(video_dir.glob("*.mp4")) 48 | return bg_files 49 | -------------------------------------------------------------------------------- /ShortsMaker/utils/get_tts.py: -------------------------------------------------------------------------------- 1 | import base64 2 | import io 3 | import textwrap 4 | from pprint import pformat 5 | from threading import Thread 6 | 7 | import requests 8 | from pydub import AudioSegment 9 | 10 | from .logging_config import get_logger 11 | from .retry import retry 12 | 13 | logger = get_logger(__name__) 14 | 15 | # define the endpoint data with URLs and corresponding response keys 16 | ENDPOINT_DATA = [ 17 | { 18 | "url": "https://tiktok-tts.weilnet.workers.dev/api/generation", 19 | "response": "data", 20 | }, 21 | {"url": "https://countik.com/api/text/speech", "response": "v_data"}, 22 | {"url": "https://gesserit.co/api/tiktok-tts", "response": "base64"}, 23 | ] 24 | 25 | # define available voices for text-to-speech conversion 26 | VOICES = [ 27 | "en_us_001", # English US - Female (Int. 1) 28 | "en_us_002", # English US - Female (Int. 2) 29 | "en_au_002", # English AU - Male 30 | "en_uk_001", # English UK - Male 1 31 | "en_uk_003", # English UK - Male 2 32 | "en_us_006", # English US - Male 1 33 | "en_us_010", # English US - Male 4 34 | "en_female_emotional", # peaceful 35 | ] 36 | 37 | 38 | # define the text-to-speech function 39 | @retry(max_retries=3, delay=5) 40 | def tts(text: str, voice: str, output_filename: str = "output.wav") -> None: 41 | """ 42 | Converts text to speech using the specified voice and saves the audio (WAV) to the output file.
43 | 44 | Args: 45 | text (str): Input text to convert 46 | voice (str): Voice ID to use 47 | output_filename (str): Output audio file path (the audio is exported as WAV) 48 | 49 | Raises: 50 | ValueError: If voice is invalid or text is empty 51 | """ 52 | _validate_inputs(text, voice) 53 | chunks = _split_text(text) 54 | _log_chunks(text, chunks) 55 | 56 | global ENDPOINT_DATA 57 | 58 | for endpoint in ENDPOINT_DATA: 59 | audio_data = [""] * len(chunks) 60 | audio_data = _process_chunks(chunks, endpoint, voice, audio_data) 61 | if audio_data is not None: 62 | _save_audio(audio_data, output_filename) 63 | break 64 | 65 | 66 | def _validate_inputs(text: str, voice: str) -> None: 67 | if voice not in VOICES: 68 | raise ValueError("voice must be valid") 69 | if not text: 70 | raise ValueError("text must not be 'None'") 71 | 72 | 73 | def _log_chunks(text: str, chunks: list[str]) -> None: 74 | logger.info(f"text: {text}") 75 | logger.info(f"Split text into {len(chunks)} chunks") 76 | for chunk in chunks: 77 | logger.info(f"Chunk: {chunk}") 78 | 79 | 80 | def _process_chunks( 81 | chunks: list[str], endpoint: dict, voice: str, audio_data: list[str] 82 | ) -> list[str] | None: 83 | valid = True 84 | 85 | def generate_audio_chunk(index: int, chunk: str) -> None: 86 | nonlocal valid 87 | if not valid: 88 | return 89 | 90 | try: 91 | logger.info(f"Using endpoint: {endpoint['url']}") 92 | response = requests.post( 93 | endpoint["url"], 94 | json={"text": chunk, "voice": voice}, 95 | headers={ 96 | "User-Agent": "com.zhiliaoapp.musically/2022600030 (Linux; U; Android 7.1.2; es_ES; SM-G988N; Build/NRD90M;tt-ok/3.12.13.1)", 97 | }, 98 | ) 99 | if response.status_code == 200: 100 | audio_data[index] = response.json()[endpoint["response"]] 101 | logger.info( 102 | f"Chunk {index} processed successfully with endpoint: {endpoint['url']}" 103 | ) 104 | else: 105 | logger.warning( 106 | f"Endpoint failed with status {response.status_code}: {endpoint['url']}" 107 | ) 108 | valid = False 109 | except requests.exceptions.JSONDecodeError as e: 110 | logger.warning(f"Could not decode the JSON response from endpoint {endpoint['url']}") 111 | logger.error(f"JSONDecodeError for endpoint {endpoint['url']}: {e}") 112 | valid = False 113 | except requests.RequestException as e: 114 | logger.warning(f"Request to endpoint failed: {endpoint['url']}") 115 | logger.error(f"RequestException for endpoint {endpoint['url']}: {e}") 116 | valid = False 117 | 118 | threads = [ 119 | Thread(target=generate_audio_chunk, args=(i, chunk)) for i, chunk in enumerate(chunks) 120 | ] 121 | for thread in threads: 122 | thread.start() 123 | for thread in threads: 124 | thread.join() 125 | 126 | return audio_data if valid else None 127 | 128 | 129 | def _save_audio(audio_data: list[str], output_filename: str) -> None: 130 | audio_bytes = b"".join(base64.b64decode(chunk) for chunk in audio_data) 131 | audio_segment: AudioSegment = AudioSegment.from_file(io.BytesIO(audio_bytes)) 132 | audio_segment.export(output_filename, format="wav") 133 | 134 | 135 | def _split_text(text: str, chunk_size: int = 250) -> list[str]: 136 | """ 137 | Splits a given text into smaller chunks of a specified size without breaking 138 | words or splitting on hyphens. 139 | 140 | The function wraps the input text into smaller substrings, ensuring the 141 | integrity of the text by preventing cutoff mid-word or mid-hyphen. Each chunk 142 | is at most the specified chunk size. 143 | 144 | Args: 145 | text (str): The input text to be split into smaller chunks.
146 | 147 | Returns: 148 | list[str]: A list of text chunks where each chunk is at most the 149 | specified size while preserving word integrity. 150 | """ 151 | 152 | text_list = textwrap.wrap( 153 | text, width=chunk_size, break_long_words=False, break_on_hyphens=False 154 | ) 155 | 156 | return text_list 157 | -------------------------------------------------------------------------------- /ShortsMaker/utils/logging_config.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | from pathlib import Path 4 | 5 | from colorlog import ColoredFormatter 6 | 7 | # prevent huggingface symlink warnings 8 | os.environ["HF_HUB_DISABLE_SYMLINKS_WARNING"] = "true" 9 | 10 | # Global configuration 11 | LOG_FILE: Path | str = "ShortsMaker.log" 12 | LOG_LEVEL: str = "DEBUG" 13 | LOGGING_ENABLED: bool = True 14 | INITIALIZED: bool = False 15 | 16 | # Cache of configured loggers 17 | LOGGERS: dict[str, logging.Logger] = {} 18 | 19 | 20 | def get_logger(name: str = __name__) -> logging.Logger: 21 | """ 22 | Get a logger with the specified name, typically __name__ from the module. 23 | If the logging system has not been initialized, it will use default settings. 24 | 25 | Args: 26 | name (str): The logger name, typically passed as __name__ from the module. 27 | This ensures proper hierarchical naming of loggers. 28 | 29 | Returns: 30 | logging.Logger: A configured logger instance. 31 | """ 32 | global LOG_FILE, LOG_LEVEL, LOGGING_ENABLED, INITIALIZED, LOGGERS 33 | 34 | # Initialize logging system with defaults if not already done 35 | if not INITIALIZED: 36 | configure_logging() 37 | 38 | # Return existing logger if already configured 39 | if name in LOGGERS: 40 | return LOGGERS[name] 41 | 42 | # Create a new logger 43 | logger = logging.getLogger(name) 44 | 45 | # Don't add handlers if this is a child logger 46 | # Parent loggers will handle it through hierarchy 47 | if not logger.handlers: 48 | # Create Console 49 | console_handler = logging.StreamHandler() 50 | color_formatter = ColoredFormatter( 51 | "{log_color}{asctime} - {name} - {funcName} - {levelname} - {message}", 52 | style="{", 53 | datefmt="%Y-%m-%d %H:%M:%S", 54 | log_colors={ 55 | "DEBUG": "cyan", 56 | "INFO": "green", 57 | "WARNING": "yellow", 58 | "ERROR": "red", 59 | "CRITICAL": "bold_red", 60 | }, 61 | reset=True, 62 | ) 63 | console_handler.setFormatter(color_formatter) 64 | 65 | # Create File Handler 66 | file_handler = logging.FileHandler(LOG_FILE, mode="a", encoding="utf-8") 67 | formatter = logging.Formatter( 68 | "{asctime} - {name} - {funcName} - {levelname} - {message}", 69 | style="{", 70 | datefmt="%Y-%m-%d %H:%M:%S", 71 | ) 72 | file_handler.setFormatter(formatter) 73 | 74 | # Add handlers to logger 75 | logger.addHandler(console_handler) 76 | logger.addHandler(file_handler) 77 | 78 | # Set logging level based on enable flag 79 | if LOGGING_ENABLED: 80 | logger.setLevel(LOG_LEVEL) 81 | else: 82 | logger.setLevel(logging.CRITICAL) 83 | 84 | # Prevent propagation to avoid duplicate logs 85 | logger.propagate = False 86 | 87 | # Store the configured logger 88 | LOGGERS[name] = logger 89 | 90 | return logger 91 | 92 | 93 | def configure_logging( 94 | log_file: Path | str = LOG_FILE, level: str | int = LOG_LEVEL, enable: bool = LOGGING_ENABLED 95 | ) -> None: 96 | """ 97 | Configure the global logging settings. 98 | This function can be called by users to customize logging behavior. 99 | 100 | Args: 101 | log_file (str | Path): Path to the log file. 
102 | level (str | int): Logging level ('DEBUG', 'INFO', etc.). 103 | enable (bool): Whether to enable logging. 104 | """ 105 | global LOG_FILE, LOG_LEVEL, LOGGING_ENABLED, INITIALIZED, LOGGERS 106 | 107 | # Update configuration with provided values 108 | LOG_FILE = Path(log_file) if isinstance(log_file, str) else log_file 109 | LOG_LEVEL = level 110 | LOGGING_ENABLED = enable 111 | 112 | # Create log directory if it doesn't exist 113 | LOG_FILE.parent.mkdir(parents=True, exist_ok=True) 114 | 115 | # Update all existing loggers with new settings 116 | for logger_name, logger in LOGGERS.items(): 117 | # Update log level 118 | if LOGGING_ENABLED: 119 | logger.setLevel(LOG_LEVEL) 120 | else: 121 | logger.setLevel(logging.CRITICAL) 122 | 123 | INITIALIZED = True 124 | -------------------------------------------------------------------------------- /ShortsMaker/utils/notify_discord.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import textwrap 4 | from random import choice, randint 5 | 6 | import requests 7 | from bs4 import BeautifulSoup 8 | from discord_webhook import DiscordEmbed, DiscordWebhook 9 | from requests import Response 10 | 11 | if not os.environ.get("DISCORD_WEBHOOK_URL"): 12 | raise ValueError("DISCORD_WEBHOOK_URL not set, Please set it in your environment variables.") 13 | 14 | DISCORD_URL = os.environ.get("DISCORD_WEBHOOK_URL") 15 | 16 | 17 | def get_arthas(): 18 | """ 19 | Fetches a random Arthas image URL from Bing image search results. 20 | 21 | This function sends a search request to Bing images for the keyword 'arthas' 22 | and retrieves a specific page of the search results. It extracts image URLs 23 | from the returned HTML content and returns one randomly selected image URL. 24 | 25 | Returns: 26 | str: A randomly selected URL of an Arthas image. 27 | 28 | Raises: 29 | requests.exceptions.RequestException: If the HTTP request fails or encounters an issue. 30 | """ 31 | url = f"https://www.bing.com/images/search?q=arthas&first={randint(1, 10)}" 32 | response = requests.get( 33 | url, 34 | headers={ 35 | "User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_11_5) AppleWebKit/537.36 (KHTML, like Gecko) \ 36 | Chrome/50.0.2661.102 Safari/537.36" 37 | }, 38 | ) 39 | soup = BeautifulSoup(response.content, "lxml") 40 | divs = soup.find_all("div", class_="imgpt") 41 | imgs = [] 42 | for div in divs: 43 | img = div.find("a")["m"] 44 | img = img.split("murl")[1].split('"')[2] 45 | imgs.append(img) 46 | return choice(imgs) 47 | 48 | 49 | def get_meme(): 50 | """ 51 | Fetches a meme image URL from the meme-api.com API. 52 | 53 | This function uses the "gimme" endpoint of the meme-api.com to fetch a 54 | meme image URL. The API allows for specifying a subreddit and quantity 55 | of meme images. The default behavior is to fetch a random meme from the 56 | API. The response is parsed and the URL of the image is extracted and 57 | returned. The function applies a User-Agent header as part of the 58 | request to prevent potential issues with the API. 59 | 60 | Returns: 61 | str: The URL of the meme image. 62 | 63 | Raises: 64 | JSONDecodeError: If the response content could not be decoded into JSON. 65 | KeyError: If the expected "url" key is not found in the JSON response. 66 | RequestException: If there is an issue with the HTTP request. 
67 | """ 68 | # Uses https://meme-api.com/gimme/ 69 | # Can use custom subreddit and return multiple images 70 | # Endpoint: /gimme/{subreddit}/{count} 71 | # Example: https://meme-api.com/gimme/wholesomememes/2 72 | # Returns: 73 | # Image url 74 | # Looks like the below endpoint is not working anymore 75 | # url = "https://meme-api.com/gimme" 76 | 77 | 78 | url = "https://memeapi.zachl.tech/pic/json" 79 | 80 | response = requests.get( 81 | url, 82 | headers={ 83 | "User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_11_5) AppleWebKit/537.36 (KHTML, like Gecko) \ 84 | Chrome/50.0.2661.102 Safari/537.36" 85 | }, 86 | ) 87 | 88 | soup = BeautifulSoup(response.content, "html.parser") 89 | return json.loads(soup.text)["MemeURL"] 90 | 91 | 92 | def notify_discord(message) -> Response: 93 | """ 94 | Sends a notification message to a Discord webhook, splitting messages longer than the character limit, 95 | and embedding additional information such as title, description, and images. 96 | 97 | Args: 98 | message (str): The message content to be sent to the Discord webhook. If the message exceeds 4000 99 | characters, it will be split into smaller parts. 100 | 101 | Returns: 102 | Response: The response object resulting from the webhook execution, which contains information 103 | such as status code and response text. 104 | """ 105 | messages = textwrap.wrap(message, 4000) 106 | response = None 107 | 108 | for message in messages: 109 | webhook = DiscordWebhook(url=DISCORD_URL, rate_limit_retry=True) 110 | 111 | embed = DiscordEmbed() 112 | embed.set_title(":warning:Error found while running the Automation!:warning:") 113 | embed.set_description(f"{message}") 114 | embed.set_image(url=get_meme()) 115 | 116 | try: 117 | embed.set_thumbnail(url=get_arthas()) 118 | except Exception as e: 119 | print(f"Error fetching arthas: {e}") 120 | 121 | embed.set_color("ff0000") 122 | embed.set_timestamp() 123 | webhook.add_embed(embed) 124 | 125 | response = webhook.execute() 126 | print(response.status_code) 127 | print(response.text) 128 | return response 129 | -------------------------------------------------------------------------------- /ShortsMaker/utils/retry.py: -------------------------------------------------------------------------------- 1 | import functools 2 | import time 3 | 4 | from .logging_config import get_logger 5 | from .notify_discord import notify_discord 6 | 7 | logger = get_logger(__name__) 8 | 9 | 10 | def retry(max_retries: int, delay: int, notify: bool = False): 11 | """ 12 | A retry decorator function that allows retrying a function based on the specified 13 | number of retries, delay between retries, and an option to send a notification upon 14 | failure. It logs all execution details, including successful executions, exceptions, 15 | and retry attempts. 16 | 17 | Args: 18 | max_retries (int): The maximum number of times the function should be retried 19 | in case of an exception. 20 | delay (int): The time in seconds to wait before retrying the function after 21 | a failure. 22 | notify (bool): Whether to send a notification if the function fails after 23 | reaching the maximum number of retries. Default is False. 24 | 25 | Returns: 26 | Callable: A decorator function that applies the retry logic to the decorated 27 | function. 28 | 29 | Note: 30 | If all retries are exhausted, the last exception is logged (and optionally 31 | sent to Discord) and the wrapper returns None instead of raising.
32 | 33 | Example: 34 | @retry(max_retries=3, delay=2) 35 | def my_function(): 36 | # Function implementation 37 | pass 38 | """ 39 | 40 | def decorator(func): 41 | @functools.wraps(func) 42 | def wrapper(*args, **kwargs): 43 | start_time = time.perf_counter() 44 | logger.info(f"Using retry decorator with max_retries={max_retries} and a {delay}s delay") 45 | logger.info(f"Begin function {func.__name__}") 46 | err = "Before running" 47 | for attempt in range(max_retries): 48 | try: 49 | value = func(*args, **kwargs) 50 | logger.info(f"Returned: {value}") 51 | logger.info( 52 | f"Completed function {func.__name__} in {round(time.perf_counter() - start_time, 2)}s after {attempt + 1} attempt(s)" 53 | ) 54 | return value 55 | except Exception as e: 56 | logger.exception(f"Exception: {e}") 57 | logger.warning(f"Retrying function {func.__name__} after {delay}s") 58 | err = str(e) 59 | time.sleep(delay) 60 | if notify: 61 | notify_discord( 62 | f"{func.__name__} failed after {max_retries} retries.\nException: {err}" 63 | ) 64 | logger.error(f"Failed after {max_retries} retries") 65 | 66 | return wrapper 67 | 68 | return decorator 69 | -------------------------------------------------------------------------------- /assets/credits/credits.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rajathjn/shorts_maker/cb3e1823e3f271db08a969b7b7a8a5083d09df8d/assets/credits/credits.mp4 -------------------------------------------------------------------------------- /assets/credits/credits_mask.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rajathjn/shorts_maker/cb3e1823e3f271db08a969b7b7a8a5083d09df8d/assets/credits/credits_mask.mp4 -------------------------------------------------------------------------------- /assets/fonts/Monaco.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rajathjn/shorts_maker/cb3e1823e3f271db08a969b7b7a8a5083d09df8d/assets/fonts/Monaco.ttf -------------------------------------------------------------------------------- /assets/fonts/Roboto.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rajathjn/shorts_maker/cb3e1823e3f271db08a969b7b7a8a5083d09df8d/assets/fonts/Roboto.ttf -------------------------------------------------------------------------------- /example.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | import yaml 4 | 5 | from ShortsMaker import MoviepyCreateVideo, ShortsMaker 6 | 7 | setup_file = "setup.yml" 8 | 9 | with open(setup_file) as f: 10 | cfg = yaml.safe_load(f) 11 | 12 | get_post = ShortsMaker(setup_file) 13 | 14 | # You can either provide a URL for the Reddit post 15 | get_post.get_reddit_post( 16 | url="https://www.reddit.com/r/Python/comments/1j36d7a/i_got_tired_of_ai_shorts_scams_so_i_built_my_own/" 17 | ) 18 | # Or just run the method to get a random post from the subreddit defined in setup.yml 19 | # get_post.get_reddit_post() 20 | 21 | with open(Path(cfg["cache_dir"]) / cfg["reddit_post_getter"]["record_file_txt"]) as f: 22 | script = f.read() 23 | 24 | get_post.generate_audio( 25 | source_txt=script, 26 | output_audio=f"{cfg['cache_dir']}/{cfg['audio']['output_audio_file']}", 27 | output_script_file=f"{cfg['cache_dir']}/{cfg['audio']['output_script_file']}", 28 | ) 29 | 30 |
get_post.generate_audio_transcript( 31 | source_audio_file=f"{cfg['cache_dir']}/{cfg['audio']['output_audio_file']}", 32 | source_text_file=f"{cfg['cache_dir']}/{cfg['audio']['output_script_file']}", 33 | ) 34 | 35 | get_post.quit() 36 | 37 | create_video = MoviepyCreateVideo( 38 | config_file=setup_file, 39 | # speed_factor=1.25, # Set the speed factor for the video 40 | ) 41 | 42 | create_video(output_path="assets/output.mp4") 43 | 44 | create_video.quit() 45 | 46 | # Do not run the below when you are using shorts_maker within a container. 47 | 48 | # ask_llm = AskLLM(config_file=setup_file) 49 | # result = ask_llm.invoke(script) 50 | # print(result["parsed"].title) 51 | # print(result["parsed"].description) 52 | # print(result["parsed"].tags) 53 | # print(result["parsed"].thumbnail_description) 54 | # ask_llm.quit_llm() 55 | 56 | # You can use, AskLLM to generate a text prompt for the image generation as well 57 | # image_description = ask_llm.invoke_image_describer(script = script, input_text = "A wild scenario") 58 | # print(image_description) 59 | # print(image_description["parsed"].description) 60 | 61 | # Generate image uses a lot of resources so beware 62 | # generate_image = GenerateImage(config_file=setup_file) 63 | # generate_image.use_huggingface_flux_schnell(image_description["parsed"].description, "output.png") 64 | # generate_image.quit() 65 | -------------------------------------------------------------------------------- /example.setup.yml: -------------------------------------------------------------------------------- 1 | --- 2 | # Used only for fetching models for Image Generation 3 | hugging_face_access_token: "your_hugging_face_access_token_here" 4 | 5 | # A cache dir for storing results 6 | cache_dir: "cache" 7 | 8 | # A dir in which your background_videos, background_music, credits and fonts folders are located 9 | assets_dir: "assets" 10 | 11 | # Best to leave this as is 12 | retry: 13 | max_retries: 3 14 | delay: 5 15 | notify: False 16 | 17 | # Refer to https://business.reddithelp.com/s/article/Create-a-Reddit-Application 18 | # for more information on how to create a Reddit application 19 | # and get your client_id and client_secret 20 | reddit_praw: 21 | client_id: "your_reddit_client_id" 22 | client_secret: "your_reddit_client_secret" 23 | user_agent: "your_user_agent_here" 24 | 25 | # Replace with your own subreddit name 26 | reddit_post_getter: 27 | subreddit_name: "your_subreddit_name" 28 | record_file_json: "post.json" 29 | record_file_txt: "post.txt" 30 | 31 | # If you are not using your cuda device, set device to "cpu" 32 | # Refer https://github.com/m-bain/whisperX for more information 33 | audio: 34 | output_script_file: "generated_audio_script.txt" 35 | output_audio_file: "output.wav" 36 | transcript_json: "transcript.json" 37 | device: "cpu" # or "cuda" 38 | model: "large-v2" # or "medium" 39 | batch_size: 16 # or 32 40 | compute_type: "int8" # or "float16" 41 | 42 | # Replace with the video URLs and music URLs you want to use 43 | # Only YouTube URLs are supported 44 | # Note: If you want to avoid setting this, 45 | # Pass the path to the respective objects, when initializing the MoviepyCreateVideo class 46 | video: 47 | background_videos_urls: 48 | # https://www.youtube.com/watch?v=n_Dv4JMiwK8 49 | - "https://www.youtube.com/watch?v=example_video_id" 50 | background_music_urls: 51 | # https://www.youtube.com/watch?v=G8a45UZJGh4&t=1s 52 | - "https://www.youtube.com/watch?v=example_music_id" 53 | font_dir: "fonts" 54 | credits_dir: "credits" 55 | 
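A note for first-time runs: a filled-in setup.yml is the single input that example.py and the ShortsMaker classes read, so it is worth validating before kicking off a long render. The sketch below is hypothetical and not part of the repository; the key names are taken from example.setup.yml above, and the script name check_setup.py is invented for illustration.

# check_setup.py (hypothetical helper, not part of the repo):
# verify a filled-in setup.yml has the top-level keys example.py reads.
from pathlib import Path

import yaml

REQUIRED_KEYS = ("cache_dir", "assets_dir", "retry", "reddit_praw", "reddit_post_getter", "audio", "video")


def check_setup(path: str = "setup.yml") -> None:
    cfg = yaml.safe_load(Path(path).read_text())
    missing = [key for key in REQUIRED_KEYS if key not in cfg]
    if missing:
        raise SystemExit(f"{path} is missing keys: {missing}")
    # The comments in example.setup.yml pair device "cpu" with compute_type "int8"
    # and "cuda" with "float16"; warn when the two disagree.
    if cfg["audio"]["device"] == "cpu" and cfg["audio"]["compute_type"] == "float16":
        print("warning: compute_type float16 is suggested for cuda; int8 is the suggested CPU choice")


if __name__ == "__main__":
    check_setup()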
-------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "ShortsMaker" 3 | version = "0.2.0" 4 | description = "A Python project to make and upload YouTube Shorts videos" 5 | authors = [ 6 | {name = "rajathjn",email = "rajathjnx@gmail.com"} 7 | ] 8 | license = "AGPL-3.0-or-later" 9 | readme = "README.md" 10 | requires-python = ">=3.12,<3.13" 11 | classifiers = [ 12 | "Development Status :: 4 - Beta", 13 | "Intended Audience :: Developers", 14 | "Programming Language :: Python :: 3.12", 15 | ] 16 | dependencies = [ 17 | "accelerate>=1.3.0", 18 | "beautifulsoup4>=4.13.3", 19 | "colorlog>=6.9.0", 20 | "diffusers>=0.32.2", 21 | "discord-webhook>=1.3.1", 22 | "ftfy>=6.3.1", 23 | "h11>=0.16.0", 24 | "jinja2>=3.1.6", 25 | "langchain-ollama>=0.3.2", 26 | "language-tool-python>=2.8.2", 27 | "lxml>=5.3.1", 28 | "moviepy>=2.1.2", 29 | "ollama>=0.4.7", 30 | "praw>=7.8.1", 31 | "psutil>=6.1.1", 32 | "pydub>=0.25.1", 33 | "pyyaml>=6.0.2", 34 | "rapidfuzz>=3.12.1", 35 | "requests>=2.32.3", 36 | "setuptools>=75.8.0", 37 | "transformers>=4.48.2", 38 | "unidecode>=1.3.8", 39 | "wheel>=0.45.1", 40 | "whisperx>=3.3.1", 41 | "yt-dlp>=2025.3.31", 42 | ] 43 | 44 | [dependency-groups] 45 | dev = [ 46 | "pre-commit>=4.1.0", 47 | "pytest>=8.3.4", 48 | "pytest-cov>=6.0.0", 49 | "pytest-mock>=3.14.0", 50 | "requests-mock>=1.12.1", 51 | "ruff>=0.9.5", 52 | ] 53 | 54 | [project.optional-dependencies] 55 | cpu = [ 56 | "torch>=2.7.0", 57 | "torchaudio>=2.7.0", 58 | "torchvision>=0.22.0", 59 | ] 60 | cu128 = [ 61 | "torch>=2.7.0", 62 | "torchaudio>=2.7.0", 63 | "torchvision>=0.22.0", 64 | ] 65 | 66 | [tool.uv] 67 | conflicts = [ 68 | [ 69 | { extra = "cpu" }, 70 | { extra = "cu128" }, 71 | ], 72 | ] 73 | 74 | [tool.uv.sources] 75 | torch = [ 76 | { index = "pytorch-cpu", extra = "cpu" }, 77 | { index = "pytorch-cu128", extra = "cu128" }, 78 | ] 79 | torchvision = [ 80 | { index = "pytorch-cpu", extra = "cpu" }, 81 | { index = "pytorch-cu128", extra = "cu128" }, 82 | ] 83 | torchaudio = [ 84 | { index = "pytorch-cpu", extra = "cpu" }, 85 | { index = "pytorch-cu128", extra = "cu128" }, 86 | ] 87 | 88 | [[tool.uv.index]] 89 | name = "pytorch-cpu" 90 | url = "https://download.pytorch.org/whl/cpu" 91 | explicit = true 92 | 93 | [[tool.uv.index]] 94 | name = "pytorch-cu128" 95 | url = "https://download.pytorch.org/whl/cu128" 96 | explicit = true 97 | 98 | [tool.pytest.ini_options] 99 | testpaths = [ 100 | "tests", 101 | ] 102 | 103 | [tool.ruff] 104 | # Set the maximum line length to 100.
105 | exclude = [ 106 | ".git", 107 | ".ipynb_checkpoints", 108 | ".mypy_cache", 109 | ".nox", 110 | ".pyenv", 111 | ".pytest_cache", 112 | ".pytype", 113 | ".ruff_cache", 114 | ".tox", 115 | "_build", 116 | "build", 117 | "venv", 118 | ".venv", 119 | ".vscode", 120 | ".idea" 121 | ] 122 | line-length = 100 123 | indent-width = 4 124 | fix = true 125 | 126 | [tool.ruff.lint] 127 | extend-select = [ 128 | "UP", # pyupgrade 129 | "I", # isort 130 | ] 131 | fixable = ["ALL"] 132 | 133 | [tool.ruff.format] 134 | quote-style = "double" 135 | indent-style = "space" 136 | skip-magic-trailing-comma = false 137 | line-ending = "auto" 138 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rajathjn/shorts_maker/cb3e1823e3f271db08a969b7b7a8a5083d09df8d/tests/__init__.py -------------------------------------------------------------------------------- /tests/ask_llm_tests/test_ask_llm.py: -------------------------------------------------------------------------------- 1 | import os 2 | import tempfile 3 | from pathlib import Path 4 | from unittest.mock import MagicMock, patch 5 | 6 | import pytest 7 | 8 | from ShortsMaker.ask_llm import AskLLM, OllamaServiceManager, YoutubeDetails 9 | 10 | 11 | @pytest.fixture 12 | def mock_ollama_service_manager(): 13 | """ 14 | Fixture to provide a mocked instance of OllamaServiceManager. 15 | 16 | Returns: 17 | MagicMock: A mocked instance of OllamaServiceManager with predefined return values for start_service and stop_service methods. 18 | """ 19 | ollama_service_manager = MagicMock(OllamaServiceManager) 20 | ollama_service_manager.start_service.return_value = True 21 | ollama_service_manager.stop_service.return_value = True 22 | return ollama_service_manager 23 | 24 | 25 | @patch("ShortsMaker.ask_llm.AskLLM._load_llm_model") 26 | def test_initialization_with_valid_config( 27 | mock_load_llm_model, setup_file, mock_ollama_service_manager 28 | ): 29 | mock_load_llm_model.return_value = None 30 | ask_llm = AskLLM(config_file=setup_file, model_name="test_model") 31 | assert ask_llm.model_name == "test_model" 32 | 33 | 34 | def test_initialization_with_invalid_config_path(): 35 | invalid_config_path = Path("temp.yml") 36 | with pytest.raises(FileNotFoundError): 37 | AskLLM(config_file=invalid_config_path) 38 | 39 | 40 | def test_initialization_with_non_yaml_file(setup_file, tmp_path): 41 | with pytest.raises(ValueError): 42 | temp_path = tempfile.NamedTemporaryFile(suffix=".txt", delete=False) 43 | temp_path = Path(temp_path.name) 44 | try: 45 | AskLLM(config_file=temp_path) 46 | finally: 47 | temp_path.unlink() 48 | 49 | 50 | @patch("ShortsMaker.ask_llm.AskLLM._load_llm_model") 51 | def test_llm_model_loading(mock_load_llm_model, mock_ollama_service_manager, setup_file): 52 | AskLLM(config_file=setup_file, model_name="test_model") 53 | mock_load_llm_model.assert_called_once_with("test_model", 0) 54 | 55 | 56 | @patch("ShortsMaker.ask_llm.AskLLM._load_llm_model") 57 | def test_invoke_creates_chat_prompt(mock_load_llm_model, setup_file, mock_ollama_service_manager): 58 | ask_llm = AskLLM(config_file=setup_file) 59 | ask_llm.ollama_service_manager = mock_ollama_service_manager 60 | ask_llm.llm = MagicMock() 61 | ask_llm.llm.with_structured_output.return_value = ask_llm.llm 62 | ask_llm.llm.invoke.return_value = {"title": "Test Title"} 63 | 64 | input_text = "Test script input."
65 | response = ask_llm.invoke(input_text=input_text) 66 | 67 | assert "title" in response 68 | ask_llm.llm.with_structured_output.assert_called_once_with(YoutubeDetails, include_raw=True) 69 | ask_llm.llm.invoke.assert_called_once() 70 | 71 | 72 | @patch("ShortsMaker.ask_llm.AskLLM._load_llm_model") 73 | @patch("ShortsMaker.ask_llm.OllamaServiceManager.stop_service") 74 | @patch("ShortsMaker.ask_llm.subprocess.check_output") 75 | @patch("ShortsMaker.ask_llm.subprocess.run") 76 | def test_quit_llm_with_self_started_service(  # mocks are injected bottom-up from the @patch stack 77 | mock_run, 78 | mock_check_output, 79 | mock_stop_service, 80 | mock_load_llm_model, 81 | setup_file, 82 | mock_ollama_service_manager, 83 | ): 84 | mock_load_llm_model.return_value = None 85 | 86 | ask_llm = AskLLM(config_file=setup_file, model_name="test_model") 87 | ask_llm.self_started_ollama = True 88 | 89 | result = ask_llm.quit_llm() 90 | assert result is None 91 | mock_stop_service.assert_called_once() 92 | 93 | 94 | @pytest.mark.skipif("RUNALL" not in os.environ, reason="takes too long") 95 | def test_ask_llm_working(setup_file): 96 | script = "A video about a cat. Doing stunts like running around, flying, and jumping." 97 | ask_llm = AskLLM(config_file=setup_file) 98 | result = ask_llm.invoke(input_text=script) 99 | ask_llm.quit_llm() 100 | assert result["parsed"].title == "Feline Frenzy: Cat Stunt Master!" 101 | assert ( 102 | result["parsed"].description 103 | == "Get ready for the most epic feline feats you've ever seen! Watch as our fearless feline friend runs, jumps, and even flies through a series of death-defying stunts." 104 | ) 105 | assert result["parsed"].tags == ["cat", "stunts", "flying", "jumping"] 106 | assert ( 107 | result["parsed"].thumbnail_description 108 | == "A cat in mid-air, performing a daring stunt with its paws outstretched, surrounded by a blurred cityscape with bright lights and colors." 109 | ) 110 | assert result["parsing_error"] is None 111 | assert result is not None 112 | -------------------------------------------------------------------------------- /tests/ask_llm_tests/test_ollama_service.py: -------------------------------------------------------------------------------- 1 | from unittest.mock import MagicMock, patch 2 | 3 | import pytest 4 | 5 | from ShortsMaker.ask_llm import OllamaServiceManager 6 | 7 | 8 | @pytest.fixture 9 | def ollama_service_manager(): 10 | """Fixture that provides an OllamaServiceManager instance. 11 | 12 | The manager is constructed without arguments and configures its own 13 | logger internally. 14 | 15 | Returns: 16 | An instance of OllamaServiceManager.
17 | """ 18 | return OllamaServiceManager() 19 | 20 | 21 | @patch("ShortsMaker.ask_llm.subprocess.Popen") 22 | def test_start_service(mock_popen, ollama_service_manager): 23 | process_mock = MagicMock() 24 | process_mock.poll.return_value = None 25 | mock_popen.return_value = process_mock 26 | 27 | result = ollama_service_manager.start_service() 28 | assert result is True 29 | mock_popen.assert_called_once() 30 | 31 | 32 | @patch("ShortsMaker.ask_llm.subprocess.run") 33 | def test_stop_service_on_windows(mock_subprocess_run, ollama_service_manager): 34 | ollama_service_manager.process = MagicMock() 35 | ollama_service_manager.system = "windows" 36 | result = ollama_service_manager.stop_service() 37 | assert result is True 38 | mock_subprocess_run.assert_called() 39 | 40 | 41 | @patch("ShortsMaker.ask_llm.psutil.process_iter") 42 | def test_is_ollama_running(mock_process_iter, ollama_service_manager): 43 | mock_process_iter.return_value = [ 44 | MagicMock(info={"name": "Ollama"}), 45 | ] 46 | result = ollama_service_manager.is_ollama_running() 47 | assert result is True 48 | 49 | 50 | def test_is_service_running_with_active_process(ollama_service_manager): 51 | ollama_service_manager.process = MagicMock() 52 | ollama_service_manager.process.poll.return_value = None 53 | 54 | result = ollama_service_manager.is_service_running() 55 | assert result is True 56 | 57 | 58 | @patch("ShortsMaker.ask_llm.subprocess.check_output") 59 | def test_stop_running_model(mock_check_output, ollama_service_manager): 60 | model_name = "test_model" 61 | mock_check_output.return_value = "Stopped" 62 | 63 | result = ollama_service_manager.stop_running_model(model_name) 64 | assert result is True 65 | mock_check_output.assert_called_with(["ollama", "stop", model_name], stderr=-2, text=True) 66 | 67 | 68 | @patch("ShortsMaker.ask_llm.ollama.ps") 69 | def test_get_running_models(mock_ollama_ps, ollama_service_manager): 70 | mock_ollama_ps.return_value = ["model1", "model2"] 71 | 72 | result = ollama_service_manager.get_running_models() 73 | assert result == ["model1", "model2"] 74 | 75 | 76 | @patch("ShortsMaker.ask_llm.ollama.pull") 77 | def test_get_llm_model(mock_ollama_pull, ollama_service_manager): 78 | model_name = "test_model" 79 | mock_ollama_pull.return_value = "model_data" 80 | 81 | result = ollama_service_manager.get_llm_model(model_name) 82 | assert result == "model_data" 83 | mock_ollama_pull.assert_called_with(model_name) 84 | 85 | 86 | @patch("ShortsMaker.ask_llm.ollama.list") 87 | def test_get_list_of_downloaded_files(mock_ollama_list, ollama_service_manager): 88 | mock_ollama_list.return_value = [("models", [MagicMock(), MagicMock(), MagicMock()])] 89 | mock_ollama_list.return_value[0][1][0].model = "submodel1" 90 | mock_ollama_list.return_value[0][1][1].model = "submodel2" 91 | mock_ollama_list.return_value[0][1][2].model = "submodel3" 92 | 93 | result = ollama_service_manager.get_list_of_downloaded_files() 94 | assert result == ["submodel1", "submodel2", "submodel3"] 95 | -------------------------------------------------------------------------------- /tests/conftest.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from pathlib import Path 3 | 4 | import pytest 5 | 6 | from ShortsMaker import ShortsMaker 7 | 8 | 9 | @pytest.fixture 10 | def setup_file(): 11 | return Path(__file__).parent / "data" / "setup.yml" 12 | 13 | 14 | @pytest.fixture 15 | def shorts_maker(setup_file): 16 | return ShortsMaker(setup_file) 17 | 18 | 19 | @pytest.fixture 
20 | def mock_logger(): 21 | return logging.getLogger("test_logger") 22 | -------------------------------------------------------------------------------- /tests/data/setup.yml: -------------------------------------------------------------------------------- 1 | --- 2 | hugging_face_access_token: "random_token" 3 | 4 | cache_dir: "cache" 5 | assets_dir: "assets" 6 | 7 | retry: 8 | max_retries: 3 9 | delay: 5 10 | notify: False 11 | 12 | reddit_praw: 13 | client_id: blah 14 | client_secret: blah 15 | user_agent: blah 16 | 17 | reddit_post_getter: 18 | subreddit_name: RandomSubreddit 19 | record_file_json: post.json 20 | record_file_txt: post.txt 21 | 22 | audio: 23 | output_script_file: generated_audio_script.txt 24 | output_audio_file: output.wav 25 | transcript_json: transcript.json 26 | device: cuda 27 | model: large-v2 28 | batch_size: 16 29 | compute_type: float16 30 | 31 | # Random videos 32 | video: 33 | background_videos_urls: 34 | - https://www.youtube.com/watch?v=n_Dv4JMiwK8 35 | background_music_urls: 36 | - https://www.youtube.com/watch?v=G8a45UZJGh4&t=1s 37 | font_dir: fonts 38 | credits_dir: credits 39 | -------------------------------------------------------------------------------- /tests/data/test.txt: -------------------------------------------------------------------------------- 1 | This is a test to generate audio. 2 | -------------------------------------------------------------------------------- /tests/data/test.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rajathjn/shorts_maker/cb3e1823e3f271db08a969b7b7a8a5083d09df8d/tests/data/test.wav -------------------------------------------------------------------------------- /tests/data/transcript.json: -------------------------------------------------------------------------------- 1 | [ 2 | { 3 | "end": 0.352, 4 | "start": 0.151, 5 | "word": "This" 6 | }, 7 | { 8 | "end": 0.552, 9 | "start": 0.472, 10 | "word": "is" 11 | }, 12 | { 13 | "end": 0.633, 14 | "start": 0.613, 15 | "word": "a" 16 | }, 17 | { 18 | "end": 0.974, 19 | "start": 0.693, 20 | "word": "test" 21 | }, 22 | { 23 | "end": 1.114, 24 | "start": 1.014, 25 | "word": "to" 26 | }, 27 | { 28 | "end": 1.636, 29 | "start": 1.154, 30 | "word": "generate" 31 | }, 32 | { 33 | "end": 2.017, 34 | "start": 1.756, 35 | "word": "audio." 36 | } 37 | ] 38 | -------------------------------------------------------------------------------- /tests/generate_image_tests/test_generate_image.py: -------------------------------------------------------------------------------- 1 | import tempfile 2 | from pathlib import Path 3 | from unittest.mock import MagicMock, patch 4 | 5 | import pytest 6 | 7 | from ShortsMaker import GenerateImage 8 | 9 | 10 | @pytest.fixture 11 | def generate_image(setup_file): 12 | """Fixture to initialize and return a GenerateImage instance. 13 | 14 | Args: 15 | setup_file: The configuration file for GenerateImage. 16 | 17 | Returns: 18 | GenerateImage: An instance of the GenerateImage class. 
19 | """ 20 | return GenerateImage(setup_file) 21 | 22 | 23 | def test_initialization_with_non_existent_file(): 24 | with pytest.raises(FileNotFoundError): 25 | GenerateImage(config_file=Path("non_existent_file.yml")) 26 | 27 | 28 | def test_initialization_with_invalid_file_format(): 29 | with pytest.raises(ValueError): 30 | temp_path = tempfile.NamedTemporaryFile(suffix=".txt", delete=False) 31 | temp_path = Path(temp_path.name) 32 | try: 33 | GenerateImage(config_file=temp_path) 34 | finally: 35 | temp_path.unlink() 36 | 37 | 38 | def test_load_model_failure(generate_image): 39 | with pytest.raises(RuntimeError): 40 | generate_image._load_model("model_id") 41 | 42 | 43 | @patch("ShortsMaker.generate_image.GenerateImage._load_model") 44 | def test_use_huggingface_flux_dev_success(mock_load_model, generate_image): 45 | mock_load_model.return_value = True 46 | generate_image.pipe = MagicMock() 47 | output = generate_image.use_huggingface_flux_dev( 48 | prompt="Random stuff", output_path="random_path.png" 49 | ) 50 | assert output is not None 51 | 52 | 53 | @patch("ShortsMaker.generate_image.GenerateImage._load_model") 54 | def test_use_huggingface_flux_schnell_success(mock_load_model, generate_image): 55 | mock_load_model.return_value = True 56 | generate_image.pipe = MagicMock() 57 | output = generate_image.use_huggingface_flux_schnell( 58 | prompt="Random stuff", output_path="random_path.png" 59 | ) 60 | assert output is not None 61 | -------------------------------------------------------------------------------- /tests/shorts_maker_tests/test_abbreviation_replacer.py: -------------------------------------------------------------------------------- 1 | from ShortsMaker import abbreviation_replacer 2 | 3 | 4 | def test_abbreviation_replacer_basic_replacement(): 5 | text = "This is an example ABB." 6 | abbreviation = "ABB" 7 | replacement = "abbreviation" 8 | result = abbreviation_replacer(text, abbreviation, replacement) 9 | assert result == "This is an example abbreviation." 10 | 11 | 12 | def test_abbreviation_replacer_replacement_with_padding(): 13 | text = "This is an example ABB." 14 | abbreviation = "ABB" 15 | replacement = "abbreviation" 16 | padding = "." 17 | result = abbreviation_replacer(text, abbreviation, replacement, padding=padding) 18 | assert result == "This is an example abbreviation" 19 | 20 | 21 | def test_abbreviation_replacer_multiple_occurrences(): 22 | text = "ABB is an ABBbreviation ABB." 23 | abbreviation = "ABB" 24 | replacement = "abbreviation" 25 | result = abbreviation_replacer(text, abbreviation, replacement) 26 | assert result == "abbreviation is an abbreviationbreviation abbreviation." 27 | 28 | 29 | def test_abbreviation_replacer_no_match(): 30 | text = "No match here." 31 | abbreviation = "XYZ" 32 | replacement = "something" 33 | result = abbreviation_replacer(text, abbreviation, replacement) 34 | assert result == text 35 | 36 | 37 | def test_abbreviation_replacer_empty_string(): 38 | text = "" 39 | abbreviation = "ABB" 40 | replacement = "abbreviation" 41 | result = abbreviation_replacer(text, abbreviation, replacement) 42 | assert result == "" 43 | -------------------------------------------------------------------------------- /tests/shorts_maker_tests/test_fix_text.py: -------------------------------------------------------------------------------- 1 | def test_fix_text_basic(shorts_maker): 2 | source_txt = "This is a te st sentence." 3 | expected_output = "This is a test sentence." 
4 | assert shorts_maker.fix_text(source_txt) == expected_output 5 | 6 | 7 | def test_fix_text_with_escape_characters(shorts_maker): 8 | source_txt = "This is a\t test sentence.\nThis is another test sentence.\r" 9 | expected_output = "This is a test sentence. This is another test sentence." 10 | assert shorts_maker.fix_text(source_txt) == expected_output 11 | 12 | 13 | def test_fix_text_with_punctuations(shorts_maker): 14 | source_txt = "Helllo!! How are you? I'm fine." 15 | expected_output = "Hello!! How are you? I'm fine." 16 | assert shorts_maker.fix_text(source_txt) == expected_output 17 | 18 | 19 | def test_fix_text_with_unicode(shorts_maker): 20 | source_txt = "Café is a Frnch word." 21 | expected_output = "Café is a French word." 22 | assert shorts_maker.fix_text(source_txt) == expected_output 23 | 24 | 25 | def test_fix_text_with_multiple_spaces(shorts_maker): 26 | source_txt = "This is a test." 27 | expected_output = "This is a test." 28 | assert shorts_maker.fix_text(source_txt) == expected_output 29 | -------------------------------------------------------------------------------- /tests/shorts_maker_tests/test_has_alpha_and_digit.py: -------------------------------------------------------------------------------- 1 | from ShortsMaker import has_alpha_and_digit 2 | 3 | 4 | def test_has_alpha_and_digit_with_alphanumeric_input(): 5 | assert has_alpha_and_digit("a1") is True 6 | 7 | 8 | def test_has_alpha_and_digit_with_only_alpha_input(): 9 | assert has_alpha_and_digit("abc") is False 10 | 11 | 12 | def test_has_alpha_and_digit_with_only_digit_input(): 13 | assert has_alpha_and_digit("1234") is False 14 | 15 | 16 | def test_has_alpha_and_digit_with_empty_string(): 17 | assert has_alpha_and_digit("") is False 18 | 19 | 20 | def test_has_alpha_and_digit_with_special_characters(): 21 | assert has_alpha_and_digit("a@1") is True 22 | 23 | 24 | def test_has_alpha_and_digit_with_whitespace(): 25 | assert has_alpha_and_digit("a1 ") is True 26 | 27 | 28 | def test_has_alpha_and_digit_with_uppercase_letters(): 29 | assert has_alpha_and_digit("A1") is True 30 | -------------------------------------------------------------------------------- /tests/shorts_maker_tests/test_shorts_maker.py: -------------------------------------------------------------------------------- 1 | import os 2 | from pathlib import Path 3 | from unittest.mock import MagicMock, patch 4 | 5 | import pytest 6 | import yaml 7 | 8 | from ShortsMaker import ShortsMaker 9 | 10 | 11 | def test_validate_config_path_valid(tmp_path): 12 | config_path = tmp_path / "config.yml" 13 | config_path.touch() 14 | 15 | with pytest.raises(TypeError): 16 | maker = ShortsMaker(config_path) 17 | assert maker.setup_cfg == config_path 18 | 19 | 20 | def test_validate_config_path_invalid_extension(tmp_path): 21 | config_path = tmp_path / "config.txt" 22 | config_path.touch() 23 | 24 | with pytest.raises(ValueError): 25 | ShortsMaker(config_path) 26 | 27 | 28 | def test_validate_config_path_not_found(): 29 | with pytest.raises(FileNotFoundError): 30 | ShortsMaker(Path("/nonexistent/config.yml")) 31 | 32 | 33 | @patch("ShortsMaker.ShortsMaker.is_unique_submission") 34 | @patch("praw.Reddit") 35 | def test_get_reddit_post(mock_reddit, mock_is_unique_submission, shorts_maker): 36 | mock_submission = MagicMock() 37 | mock_submission.title = "Test Title" 38 | mock_submission.selftext = "Test Content" 39 | mock_submission.name = "t3_123abc" 40 | mock_submission.id = "123abc" 41 | mock_submission.url = "https://reddit.com/r/test/123abc" 42 | 43 | 
mock_subreddit = MagicMock() 44 | mock_subreddit.hot.return_value = [mock_submission] 45 | mock_subreddit.title = "Test Subreddit" 46 | mock_subreddit.display_name = "test" 47 | 48 | mock_is_unique_submission.return_value = True 49 | mock_reddit.return_value.subreddit.return_value = mock_subreddit 50 | 51 | result = shorts_maker.get_reddit_post() 52 | assert "Test Title" in result 53 | assert "Test Content" in result 54 | 55 | 56 | @patch("praw.Reddit") 57 | def test_get_reddit_post_with_url(mock_reddit, shorts_maker): 58 | # Mock submission data 59 | mock_submission = MagicMock() 60 | mock_submission.title = "Test Title from URL" 61 | mock_submission.selftext = "Test Content from URL" 62 | mock_submission.name = "t3_url_submission" 63 | mock_submission.id = "url_submission" 64 | mock_submission.url = "https://www.reddit.com/r/random_subreddit/test_title_from_url/" 65 | 66 | # Mock Reddit API response 67 | mock_reddit.return_value.submission.return_value = mock_submission 68 | 69 | # Test with URL 70 | test_url = "https://www.reddit.com/r/random_subreddit/test_title_from_url/" 71 | result = shorts_maker.get_reddit_post(url=test_url) 72 | 73 | # Assertions 74 | assert "Test Title from URL" in result 75 | assert "Test Content from URL" in result 76 | mock_reddit.return_value.submission.assert_called_once_with(url=test_url) 77 | 78 | 79 | @patch("ShortsMaker.shorts_maker.tts") 80 | def test_generate_audio_success(mock_tts, shorts_maker, tmp_path): 81 | # Test successful audio generation 82 | source_text = "Test text for audio generation" 83 | output_audio = tmp_path / "test_output.wav" 84 | output_script = tmp_path / "test_script.txt" 85 | seed = "en_us_001" 86 | 87 | mock_tts.return_value = None 88 | 89 | result = shorts_maker.generate_audio( 90 | source_text, output_audio=output_audio, output_script_file=output_script, seed=seed 91 | ) 92 | 93 | assert result is True 94 | mock_tts.assert_called_once() 95 | 96 | 97 | @patch("ShortsMaker.shorts_maker.tts") 98 | def test_generate_audio_failure(mock_tts, shorts_maker): 99 | # Test failed audio generation 100 | source_text = "Test text for audio generation" 101 | mock_tts.side_effect = Exception("TTS Failed") 102 | 103 | result = shorts_maker.generate_audio(source_text) 104 | 105 | assert result is False 106 | mock_tts.assert_called_once() 107 | 108 | 109 | @patch("secrets.choice") 110 | def test_generate_audio_random_speaker(mock_choice, shorts_maker): 111 | # Test random speaker selection when no seed provided 112 | mock_choice.return_value = "en_us_001" 113 | source_text = "Test text" 114 | 115 | with patch("ShortsMaker.shorts_maker.tts"): 116 | shorts_maker.generate_audio(source_text) 117 | mock_choice.assert_called_once() 118 | 119 | 120 | def test_generate_audio_text_processing(shorts_maker): 121 | # Test text processing functionality 122 | source_text = "Test123 text AITA for YTA" 123 | output_script = shorts_maker.cache_dir / "test_script.txt" 124 | 125 | with patch("ShortsMaker.shorts_maker.tts"): 126 | shorts_maker.generate_audio(source_text, output_script_file=output_script) 127 | 128 | with open(output_script) as f: 129 | processed_text = f.read() 130 | 131 | # Verify text processing 132 | assert "Test 123" in processed_text 133 | assert "Am I the asshole" in processed_text 134 | assert "You're the Asshole" in processed_text 135 | 136 | 137 | @patch("ShortsMaker.shorts_maker.generate_audio_transcription") 138 | def test_generate_audio_transcript(mock_transcription, shorts_maker, tmp_path): 139 | # Setup test files 140 | source_audio 
= tmp_path / "test.wav" 141 | source_audio.touch() 142 | source_text = tmp_path / "test.txt" 143 | source_text.write_text("Test transcript text") 144 | output_file = tmp_path / "transcript.json" 145 | 146 | # Setup mock transcription response 147 | mock_transcript = [ 148 | {"word": "Test", "start": 0.1, "end": 0.3}, 149 | {"word": "transcript", "start": 0.4, "end": 0.6}, 150 | {"word": "text", "start": 0.7, "end": 0.9}, 151 | ] 152 | mock_transcription.return_value = mock_transcript 153 | 154 | # Call function 155 | result = shorts_maker.generate_audio_transcript( 156 | source_audio, source_text, output_transcript_file=str(output_file), debug=False 157 | ) 158 | 159 | # Verify mock was called with correct args 160 | mock_transcription.assert_called_once_with( 161 | audio_file=str(source_audio), 162 | script="Test transcript text", 163 | device=shorts_maker.audio_cfg["device"], 164 | model=shorts_maker.audio_cfg["model"], 165 | batch_size=shorts_maker.audio_cfg["batch_size"], 166 | compute_type=shorts_maker.audio_cfg["compute_type"], 167 | ) 168 | 169 | # Verify result contains filtered transcript 170 | assert result == mock_transcript 171 | 172 | # Verify transcript was saved to file 173 | with open(output_file) as f: 174 | saved_transcript = yaml.safe_load(f) 175 | assert saved_transcript == mock_transcript 176 | 177 | 178 | @patch("ShortsMaker.shorts_maker.generate_audio_transcription") 179 | def test_generate_audio_transcript_default_output(mock_generate_audio_transcription, shorts_maker): 180 | # Test default output file name generation 181 | source_audio = Path(__file__).parent.parent / "data" / "test.wav" 182 | source_text = Path(__file__).parent.parent / "data" / "test.txt" 183 | with open(Path(__file__).parent.parent / "data" / "transcript.json") as f: 184 | expected_transcript = yaml.safe_load(f) 185 | 186 | mock_generate_audio_transcription.return_value = expected_transcript 187 | result = shorts_maker.generate_audio_transcript(source_audio, source_text) 188 | 189 | expected_output = shorts_maker.cache_dir / shorts_maker.audio_cfg["transcript_json"] 190 | assert expected_output.exists() 191 | mock_generate_audio_transcription.assert_called_once() 192 | assert result == expected_transcript 193 | 194 | 195 | @pytest.mark.skipif("RUNALL" not in os.environ, reason="takes too long") 196 | def test_generate_audio_transcript_with_whisperx(shorts_maker): 197 | # Test default output file name generation 198 | source_audio = Path(__file__).parent.parent / "data" / "test.wav" 199 | source_text = Path(__file__).parent.parent / "data" / "test.txt" 200 | with open(Path(__file__).parent.parent / "data" / "transcript.json") as f: 201 | expected_transcript = yaml.safe_load(f) 202 | 203 | result = shorts_maker.generate_audio_transcript(source_audio, source_text) 204 | 205 | expected_output = shorts_maker.cache_dir / shorts_maker.audio_cfg["transcript_json"] 206 | assert expected_output.exists() 207 | assert result == expected_transcript 208 | 209 | 210 | def test_filter_word_transcript(shorts_maker): 211 | test_transcript = [ 212 | {"word": "valid", "start": 0.1, "end": 0.3}, 213 | {"word": "invalid1", "start": 0, "end": 0.5}, 214 | {"word": "invalid2", "start": 0.1, "end": 5.5}, 215 | {"word": "valid2", "start": 1.0, "end": 1.2}, 216 | ] 217 | 218 | filtered = shorts_maker._filter_word_transcript(test_transcript) 219 | 220 | assert len(filtered) == 2 221 | assert filtered[0]["word"] == "valid" 222 | assert filtered[1]["word"] == "valid2" 223 | 
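The filtering behaviour pinned down by test_filter_word_transcript above implies a small contract: a word survives only when its start time is positive and its duration stays within a short bound. The sketch below mirrors those fixtures only; it is not the repository's implementation (that lives in ShortsMaker._filter_word_transcript), and the 5-second duration bound is an assumption inferred from the rejected "invalid2" entry.

# Sketch of the contract the tests above assert; not the repository's code.
def filter_word_transcript(transcript: list[dict], max_duration: float = 5.0) -> list[dict]:
    # Keep entries with a positive start time and a plausible duration.
    return [
        word
        for word in transcript
        if word["start"] > 0 and (word["end"] - word["start"]) <= max_duration
    ]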
-------------------------------------------------------------------------------- /tests/shorts_maker_tests/test_split_alpha_and_digit.py: -------------------------------------------------------------------------------- 1 | from ShortsMaker import split_alpha_and_digit 2 | 3 | 4 | def test_split_alpha_and_digit_with_alphanumeric_word(): 5 | result = split_alpha_and_digit("abc123") 6 | assert result == "abc 123" 7 | 8 | 9 | def test_split_alpha_and_digit_with_digits_and_letters_interleaved(): 10 | result = split_alpha_and_digit("a1b2c3") 11 | assert result == "a 1 b 2 c 3" 12 | 13 | 14 | def test_split_alpha_and_digit_with_only_letters(): 15 | result = split_alpha_and_digit("abcdef") 16 | assert result == "abcdef" 17 | 18 | 19 | def test_split_alpha_and_digit_with_only_digits(): 20 | result = split_alpha_and_digit("123456") 21 | assert result == "123456" 22 | 23 | 24 | def test_split_alpha_and_digit_with_empty_string(): 25 | result = split_alpha_and_digit("") 26 | assert result == "" 27 | 28 | 29 | def test_split_alpha_and_digit_with_special_characters(): 30 | result = split_alpha_and_digit("a1!b2@") 31 | assert result == "a 1! b 2@" 32 | 33 | 34 | def test_split_alpha_and_digit_with_spaces_and_tabs(): 35 | result = split_alpha_and_digit("a1 b2\tc3") 36 | assert result == "a 1 b 2\t c 3" 37 | 38 | 39 | def test_split_alpha_and_digit_with_uppercase_and_numbers(): 40 | result = split_alpha_and_digit("ABC123") 41 | assert result == "ABC 123" 42 | -------------------------------------------------------------------------------- /tests/utils_tests/test_audio_transcript.py: -------------------------------------------------------------------------------- 1 | from unittest.mock import MagicMock, patch 2 | 3 | import pytest 4 | 5 | from ShortsMaker.utils.audio_transcript import ( 6 | align_transcript_with_script, 7 | generate_audio_transcription, 8 | ) 9 | 10 | 11 | def test_align_transcript_with_script_basic(): 12 | transcript = [ 13 | {"text": "hello world", "start": 0.0, "end": 1.0}, 14 | {"text": "how are you", "start": 1.0, "end": 2.0}, 15 | ] 16 | script = "hello world how are you today" 17 | 18 | result = align_transcript_with_script(transcript, script) 19 | 20 | assert len(result) == 2 21 | assert result[0]["text"] == "hello world" 22 | assert result[1]["text"] == "how are you" 23 | assert result[0]["start"] == 0.0 24 | assert result[0]["end"] == 1.0 25 | assert result[1]["start"] == 1.0 26 | assert result[1]["end"] == 2.0 27 | 28 | 29 | def test_align_transcript_with_script_partial_match(): 30 | transcript = [ 31 | {"text": "helo wrld", "start": 0.0, "end": 1.0}, # Misspelled 32 | {"text": "how r u", "start": 1.0, "end": 2.0}, # Text speak 33 | ] 34 | script = "hello world how are you" 35 | 36 | result = align_transcript_with_script(transcript, script) 37 | 38 | assert len(result) == 2 39 | assert result[0]["text"] == "hello world" 40 | assert result[1]["text"] == "how" 41 | 42 | 43 | def test_align_transcript_empty_inputs(): 44 | transcript = [] 45 | script = "hello world" 46 | 47 | result = align_transcript_with_script(transcript, script) 48 | assert result == [] 49 | 50 | transcript = [{"text": "hello", "start": 0.0, "end": 1.0}] 51 | script = "" 52 | 53 | result = align_transcript_with_script(transcript, script) 54 | assert len(result) == 1 55 | assert result[0]["text"] == "hello" 56 | 57 | 58 | def test_align_transcript_with_longer_script(): 59 | transcript = [ 60 | {"text": "this is", "start": 0.0, "end": 1.0}, 61 | {"text": "a test", "start": 1.0, "end": 2.0}, 62 | ] 63 | script = 
"this is a test with extra words at the end" 64 | 65 | result = align_transcript_with_script(transcript, script) 66 | 67 | assert len(result) == 2 68 | assert result[0]["text"] == "this is" 69 | assert result[1]["text"] == "a test" 70 | assert all(["start" in entry and "end" in entry for entry in result]) 71 | 72 | 73 | def test_align_transcript_timing_preserved(): 74 | transcript = [ 75 | {"text": "first segment", "start": 1.5, "end": 2.5}, 76 | {"text": "second segment", "start": 2.5, "end": 3.5}, 77 | ] 78 | script = "first segment second segment" 79 | 80 | result = align_transcript_with_script(transcript, script) 81 | 82 | assert len(result) == 2 83 | assert result[0]["start"] == 1.5 84 | assert result[0]["end"] == 2.5 85 | assert result[1]["start"] == 2.5 86 | assert result[1]["end"] == 3.5 87 | 88 | 89 | @pytest.fixture 90 | def mock_whisperx(): 91 | with patch("ShortsMaker.utils.audio_transcript.whisperx") as mock_wx: 92 | # Mock the load_model and model.transcribe 93 | mock_model = MagicMock() 94 | mock_model.transcribe.return_value = { 95 | "segments": [{"text": "hello world", "start": 0.0, "end": 1.0}], 96 | "language": "en", 97 | } 98 | mock_wx.load_model.return_value = mock_model 99 | 100 | # Mock load_audio 101 | mock_wx.load_audio.return_value = "audio_data" 102 | 103 | # Mock alignment model and align 104 | mock_align_model = MagicMock() 105 | mock_metadata = MagicMock() 106 | mock_wx.load_align_model.return_value = (mock_align_model, mock_metadata) 107 | mock_wx.align.return_value = { 108 | "segments": [ 109 | { 110 | "words": [ 111 | {"word": "hello", "start": 0.0, "end": 0.5}, 112 | {"word": "world", "start": 0.5, "end": 1.0}, 113 | ] 114 | } 115 | ] 116 | } 117 | yield mock_wx 118 | 119 | 120 | def test_generate_audio_transcription_basic(mock_whisperx): 121 | # Test basic functionality 122 | result = generate_audio_transcription( 123 | audio_file="test.wav", script="hello world", device="cpu", batch_size=8 124 | ) 125 | 126 | assert len(result) == 2 127 | assert result[0]["word"] == "hello" 128 | assert result[0]["start"] == 0.0 129 | assert result[0]["end"] == 0.5 130 | assert result[1]["word"] == "world" 131 | assert result[1]["start"] == 0.5 132 | assert result[1]["end"] == 1.0 133 | 134 | 135 | def test_generate_audio_transcription_model_cleanup(mock_whisperx): 136 | with patch("ShortsMaker.utils.audio_transcript.gc") as mock_gc: 137 | with patch("ShortsMaker.utils.audio_transcript.torch") as mock_torch: 138 | mock_torch.cuda.is_available.return_value = True 139 | 140 | generate_audio_transcription(audio_file="test.wav", script="hello world") 141 | 142 | # Verify cleanup calls 143 | assert mock_gc.collect.call_count == 2 144 | assert mock_torch.cuda.empty_cache.call_count == 2 145 | 146 | 147 | def test_generate_audio_transcription_missing_timestamps(mock_whisperx): 148 | # Mock align to return words without timestamps 149 | mock_whisperx.align.return_value = { 150 | "segments": [{"words": [{"word": "hello"}, {"word": "world", "start": 0.5, "end": 1.0}]}] 151 | } 152 | 153 | result = generate_audio_transcription(audio_file="test.wav", script="hello world") 154 | 155 | assert len(result) == 2 156 | assert result[0]["word"] == "hello" 157 | assert "start" in result[0] 158 | assert "end" in result[0] 159 | assert result[1]["word"] == "world" 160 | assert result[1]["start"] == 0.5 161 | assert result[1]["end"] == 1.0 162 | 163 | 164 | def test_generate_audio_transcription_parameters(mock_whisperx): 165 | # Test if parameters are correctly passed 166 | 
generate_audio_transcription( 167 | audio_file="test.wav", 168 | script="test script", 169 | device="test_device", 170 | batch_size=32, 171 | compute_type="float32", 172 | model="medium", 173 | ) 174 | 175 | mock_whisperx.load_model.assert_called_with("medium", "test_device", compute_type="float32") 176 | 177 | mock_whisperx.load_align_model.assert_called_with(language_code="en", device="test_device") 178 | -------------------------------------------------------------------------------- /tests/utils_tests/test_colors_dict.py: -------------------------------------------------------------------------------- 1 | from ShortsMaker import COLORS_DICT 2 | 3 | 4 | def test_colors_dict_structure(): 5 | # Test that COLORS_DICT is a dictionary 6 | assert isinstance(COLORS_DICT, dict) 7 | 8 | # Test that all values are RGBA tuples 9 | for color_value in COLORS_DICT.values(): 10 | assert isinstance(color_value, tuple) 11 | assert len(color_value) == 4 12 | for component in color_value: 13 | assert isinstance(component, int) 14 | assert 0 <= component <= 255 15 | 16 | 17 | def test_common_colors_present(): 18 | # Test that some common colors are present and that black is not included 19 | assert "white" in COLORS_DICT 20 | assert "black" not in COLORS_DICT 21 | assert "red" in COLORS_DICT 22 | assert "blue" in COLORS_DICT 23 | assert "yellow" in COLORS_DICT 24 | 25 | 26 | def test_color_values(): 27 | # Test specific color values 28 | assert COLORS_DICT["white"] == (255, 255, 255, 255) 29 | assert COLORS_DICT["yellow"] == (255, 255, 0, 255) 30 | assert COLORS_DICT["cyan"] == (0, 255, 255, 255) 31 | assert COLORS_DICT["magenta"] == (255, 0, 255, 255) 32 | -------------------------------------------------------------------------------- /tests/utils_tests/test_download_youtube_music.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from unittest.mock import MagicMock, patch 3 | 4 | import pytest 5 | 6 | from ShortsMaker.utils.download_youtube_music import download_youtube_music, sanitize_filename 7 | 8 | 9 | @pytest.fixture 10 | def mock_music_dir(tmp_path): 11 | return tmp_path / "music" 12 | 13 | 14 | @pytest.fixture 15 | def mock_ydl_no_chapters(): 16 | mock = MagicMock() 17 | mock.extract_info.return_value = {"title": "test_song"} 18 | mock.sanitize_info.return_value = {"title": "test_song", "chapters": None} 19 | return mock 20 | 21 | 22 | @pytest.fixture 23 | def mock_ydl_with_chapters(): 24 | mock = MagicMock() 25 | mock.extract_info.return_value = {"title": "test_song_chapters"} 26 | mock.sanitize_info.return_value = { 27 | "title": "test_song_chapters", 28 | "chapters": [ 29 | {"title": "Chapter 1", "start_time": 0, "end_time": 60}, 30 | {"title": "Chapter 2", "start_time": 60, "end_time": 120}, 31 | ], 32 | } 33 | return mock 34 | 35 | 36 | @pytest.mark.parametrize("force", [True, False]) 37 | def test_download_without_chapters(mock_music_dir, mock_ydl_no_chapters, force): 38 | with patch("yt_dlp.YoutubeDL") as mock_ydl_class: 39 | mock_ydl_class.return_value.__enter__.return_value = mock_ydl_no_chapters 40 | 41 | result = download_youtube_music("https://youtube.com/test", mock_music_dir, force=force) 42 | 43 | assert isinstance(result, list) 44 | assert isinstance(result[0], Path) 45 | assert result[0].name == "test_song.wav" 46 | mock_ydl_no_chapters.download.assert_called_once() 47 | 48 | 49 | def test_download_with_chapters(mock_music_dir, mock_ydl_with_chapters): 50 | with patch("yt_dlp.YoutubeDL") as mock_ydl_class: 51 |
mock_ydl_class.return_value.__enter__.return_value = mock_ydl_with_chapters 52 | 53 | result = download_youtube_music("https://youtube.com/test", mock_music_dir) 54 | 55 | assert isinstance(result, list) 56 | assert len(result) >= 0 # Since files are only checked at the end, the list may be empty 57 | assert mock_ydl_with_chapters.download.call_count == 2 # two chapters should trigger two downloads 58 | 59 | 60 | def test_download_with_existing_files(mock_music_dir, mock_ydl_no_chapters): 61 | # Create existing file 62 | mock_music_dir.mkdir(parents=True) 63 | existing_file = mock_music_dir / "test_song.wav" 64 | existing_file.touch() 65 | 66 | with patch("yt_dlp.YoutubeDL") as mock_ydl_class: 67 | mock_ydl_class.return_value.__enter__.return_value = mock_ydl_no_chapters 68 | 69 | # Should not download when force=False 70 | result = download_youtube_music("https://youtube.com/test", mock_music_dir, force=False) 71 | 72 | assert isinstance(result, list) 73 | assert isinstance(result[0], Path) 74 | assert not mock_ydl_no_chapters.download.called 75 | 76 | # Should download when force=True 77 | result = download_youtube_music("https://youtube.com/test", mock_music_dir, force=True) 78 | 79 | assert isinstance(result, list) 80 | assert isinstance(result[0], Path) 81 | mock_ydl_no_chapters.download.assert_called_once() 82 | 83 | 84 | @pytest.mark.parametrize( 85 | "filename, expected_filenames", 86 | [ 87 | ("normal filename", "normal_filename"), 88 | (" spaces ", "spaces"), 89 | ("file.name.", "file.name"), 90 | ('invalid<>:"/\\|?*chars', "invalid_________chars"), 91 | ("mixed case FiLe", "mixed_case_FiLe"), 92 | (" leading.trailing. ", "leading.trailing"), 93 | ("file with spaces", "file_with_spaces"), 94 | ("file*with?invalid:chars", "file_with_invalid_chars"), 95 | ], 96 | ) 97 | def test_sanitize_filename(filename, expected_filenames): 98 | assert sanitize_filename(filename) == expected_filenames 99 | -------------------------------------------------------------------------------- /tests/utils_tests/test_download_youtube_video.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from unittest.mock import Mock, patch 3 | 4 | import pytest 5 | 6 | from ShortsMaker.utils.download_youtube_video import download_youtube_video 7 | 8 | 9 | @pytest.fixture 10 | def mock_ydl(): 11 | with patch("yt_dlp.YoutubeDL") as mock: 12 | mock_instance = Mock() 13 | mock.return_value.__enter__.return_value = mock_instance 14 | mock_instance.extract_info.return_value = {"title": "test_video"} 15 | mock_instance.sanitize_info.return_value = {"title": "test_video"} 16 | mock_instance.prepare_filename.return_value = "test_video.mp4" 17 | yield mock_instance 18 | 19 | 20 | @pytest.fixture 21 | def tmp_path_with_video(tmp_path): 22 | video_file = tmp_path / "test_video.mp4" 23 | video_file.touch() 24 | return tmp_path 25 | 26 | 27 | def test_download_video_success(mock_ydl, tmp_path): 28 | url = "https://www.youtube.com/watch?v=test123" 29 | 30 | result = download_youtube_video(url, tmp_path) 31 | 32 | mock_ydl.extract_info.assert_called_once_with(url, download=False) 33 | mock_ydl.download.assert_called_once_with([url]) 34 | assert isinstance(result, list) 35 | assert all(isinstance(p, Path) for p in result) 36 | 37 | 38 | def test_download_video_existing_file(mock_ydl, tmp_path_with_video): 39 | url = "https://www.youtube.com/watch?v=test123" 40 | 41 | result = download_youtube_video(url, tmp_path_with_video) 42 | 43 | mock_ydl.extract_info.assert_called_once_with(url, download=False) 44 | mock_ydl.download.assert_not_called()
45 | assert isinstance(result, list) 46 | assert len(result) == 1 47 | assert all(isinstance(p, Path) for p in result) 48 | 49 | 50 | def test_download_video_force(mock_ydl, tmp_path_with_video): 51 | url = "https://www.youtube.com/watch?v=test123" 52 | 53 | result = download_youtube_video(url, tmp_path_with_video, force=True) 54 | 55 | mock_ydl.extract_info.assert_called_once_with(url, download=False) 56 | mock_ydl.download.assert_called_once_with([url]) 57 | assert isinstance(result, list) 58 | assert all(isinstance(p, Path) for p in result) 59 | 60 | 61 | def test_download_video_no_files(mock_ydl, tmp_path): 62 | url = "https://www.youtube.com/watch?v=test123" 63 | mock_ydl.prepare_filename.return_value = str(tmp_path / "nonexistent.mp4") 64 | 65 | result = download_youtube_video(url, tmp_path) 66 | 67 | assert isinstance(result, list) 68 | assert len(result) == 0 69 | -------------------------------------------------------------------------------- /tests/utils_tests/test_get_tts.py: -------------------------------------------------------------------------------- 1 | from unittest.mock import Mock, patch 2 | 3 | import pytest 4 | 5 | from ShortsMaker.utils.get_tts import ( 6 | ENDPOINT_DATA, 7 | VOICES, 8 | _process_chunks, 9 | _split_text, 10 | _validate_inputs, 11 | tts, 12 | ) 13 | 14 | 15 | @pytest.fixture 16 | def mock_audio_segment(): 17 | with patch("pydub.AudioSegment.from_file") as mock: 18 | yield mock 19 | 20 | 21 | def test_validate_inputs_valid(): 22 | _validate_inputs("test text", VOICES[0]) 23 | # Should not raise any exception 24 | 25 | 26 | def test_validate_inputs_invalid_voice(): 27 | with pytest.raises(ValueError, match="voice must be valid"): 28 | _validate_inputs("test text", "invalid_voice") 29 | 30 | 31 | def test_validate_inputs_empty_text(): 32 | with pytest.raises(ValueError, match="text must not be 'None'"): 33 | _validate_inputs("", VOICES[0]) 34 | 35 | 36 | def test_split_text(): 37 | text = "This is a test text that needs to be split into multiple chunks" 38 | chunks = _split_text(text, chunk_size=20) 39 | assert len(chunks) > 1 40 | assert all(len(chunk) <= 20 for chunk in chunks) 41 | 42 | 43 | @patch("requests.post") 44 | def test_process_chunks_success(mock_post): 45 | mock_response = Mock() 46 | mock_response.status_code = 200 47 | mock_response.json.return_value = {"data": "fake_base64_data"} 48 | mock_post.return_value = mock_response 49 | 50 | chunks = ["test chunk"] 51 | endpoint = {"url": "test_url", "response": "data"} 52 | voice = VOICES[0] 53 | audio_data = [""] 54 | 55 | result = _process_chunks(chunks, endpoint, voice, audio_data) 56 | assert result == ["fake_base64_data"] 57 | 58 | 59 | @patch("requests.post") 60 | def test_process_chunks_failure(mock_post): 61 | mock_response = Mock() 62 | mock_response.status_code = 404 63 | mock_post.return_value = mock_response 64 | 65 | chunks = ["test chunk"] 66 | endpoint = {"url": "test_url", "response": "data"} 67 | voice = VOICES[0] 68 | audio_data = [""] 69 | 70 | result = _process_chunks(chunks, endpoint, voice, audio_data) 71 | assert result is None 72 | 73 | 74 | @patch("pydub.AudioSegment.from_file") 75 | @patch("requests.post") 76 | def test_tts_integration(mock_post, mock_audio): 77 | mock_response = Mock() 78 | mock_response.status_code = 200 79 | mock_response.json.return_value = {"data": "ZmFrZV9iYXNlNjRfZGF0YQ=="} 80 | mock_post.return_value = mock_response 81 | 82 | mock_audio_obj = Mock() 83 | mock_audio_obj.export = Mock() 84 | mock_audio.return_value = mock_audio_obj 85 | 86 | 
tts("test text", VOICES[0], "test_output.wav") 87 | assert mock_audio_obj.export.called 88 | 89 | 90 | @patch("requests.post") 91 | @patch("ShortsMaker.utils.get_tts._save_audio") 92 | def test_tts_with_failing_and_successful_endpoints(mock_save_audio, mock_post): 93 | # Mock responses for endpoints 94 | def mock_post_side_effect(url, json, headers) -> Mock: 95 | if url == ENDPOINT_DATA[0]["url"]: 96 | # Simulate failure for the first endpoint 97 | response = Mock() 98 | response.status_code = 500 99 | return response 100 | elif url == ENDPOINT_DATA[1]["url"]: 101 | # Simulate success for the second endpoint 102 | response = Mock() 103 | response.status_code = 200 104 | response.json.return_value = {ENDPOINT_DATA[1]["response"]: "mock_audio_data"} 105 | return response 106 | 107 | mock_post.side_effect = mock_post_side_effect 108 | 109 | mock_save_audio.return_value = None 110 | 111 | # Input data 112 | text = "This is a test." 113 | voice = "en_us_001" 114 | output_filename = "test_output.mp3" 115 | 116 | # Call the tts function 117 | tts(text, voice, output_filename) 118 | 119 | # Assertions 120 | assert mock_post.call_count == 2 121 | mock_post.assert_any_call( 122 | ENDPOINT_DATA[0]["url"], 123 | json={"text": "This is a test.", "voice": voice}, 124 | headers={ 125 | "User-Agent": "com.zhiliaoapp.musically/2022600030 (Linux; U; Android 7.1.2; es_ES; SM-G988N; Build/NRD90M;tt-ok/3.12.13.1)", 126 | }, 127 | ) 128 | mock_post.assert_any_call( 129 | ENDPOINT_DATA[1]["url"], 130 | json={"text": "This is a test.", "voice": voice}, 131 | headers={ 132 | "User-Agent": "com.zhiliaoapp.musically/2022600030 (Linux; U; Android 7.1.2; es_ES; SM-G988N; Build/NRD90M;tt-ok/3.12.13.1)", 133 | }, 134 | ) 135 | mock_save_audio.assert_called_once_with(["mock_audio_data"], output_filename) 136 | -------------------------------------------------------------------------------- /tests/utils_tests/test_logging_config.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from unittest.mock import patch 3 | 4 | import pytest 5 | 6 | import ShortsMaker 7 | import ShortsMaker.utils 8 | from ShortsMaker.utils.logging_config import ( 9 | LOGGERS, 10 | configure_logging, 11 | get_logger, 12 | ) 13 | 14 | 15 | @pytest.fixture 16 | def reset_logging_state(): 17 | """Reset global logging state between tests""" 18 | global INITIALIZED, LOGGERS 19 | INITIALIZED = False 20 | LOGGERS.clear() 21 | yield 22 | INITIALIZED = False 23 | LOGGERS.clear() 24 | 25 | 26 | def test_get_logger_creates_new_logger(reset_logging_state): 27 | logger = get_logger("test_logger") 28 | assert isinstance(logger, logging.Logger) 29 | assert logger.name == "test_logger" 30 | assert len(logger.handlers) == 2 31 | assert not logger.propagate 32 | 33 | 34 | def test_get_logger_returns_cached_logger(reset_logging_state): 35 | logger1 = get_logger("test_logger") 36 | logger2 = get_logger("test_logger") 37 | assert logger1 is logger2 38 | 39 | 40 | def test_configure_logging_updates_settings(reset_logging_state, tmp_path): 41 | test_log_file = tmp_path / "test.log" 42 | test_level = "INFO" 43 | configure_logging(log_file=test_log_file, level=test_level, enable=True) 44 | assert ShortsMaker.utils.logging_config.LOG_FILE == test_log_file 45 | assert ShortsMaker.utils.logging_config.LOG_LEVEL == test_level 46 | assert ShortsMaker.utils.logging_config.LOGGING_ENABLED is True 47 | assert ShortsMaker.utils.logging_config.INITIALIZED is True 48 | 49 | 50 | def 
test_configure_logging_updates_existing_loggers(reset_logging_state): 51 | logger = get_logger("test_logger") 52 | 53 | configure_logging(level="INFO", enable=True) 54 | assert logger.level == logging.INFO 55 | 56 | configure_logging(enable=False) 57 | assert logger.level == logging.CRITICAL 58 | 59 | 60 | @patch("pathlib.Path.mkdir") 61 | def test_configure_logging_creates_directory(mock_mkdir): 62 | configure_logging() 63 | mock_mkdir.assert_called_once_with(parents=True, exist_ok=True) 64 | 65 | 66 | def test_disabled_logging_sets_critical_level(reset_logging_state): 67 | configure_logging(enable=False) 68 | logger = get_logger("test_logger") 69 | assert logger.level == logging.CRITICAL 70 | -------------------------------------------------------------------------------- /tests/utils_tests/test_notify_discord.py: -------------------------------------------------------------------------------- 1 | from unittest.mock import MagicMock, patch 2 | 3 | import pytest 4 | from requests import Response 5 | 6 | from ShortsMaker.utils.notify_discord import get_arthas, get_meme, notify_discord 7 | 8 | 9 | @pytest.fixture 10 | def mock_response(): 11 | response = MagicMock(spec=Response) 12 | response.status_code = 200 13 | response.text = "Success" 14 | return response 15 | 16 | 17 | @pytest.fixture 18 | def mock_webhook(): 19 | with patch("ShortsMaker.utils.notify_discord.DiscordWebhook") as mock: 20 | webhook = MagicMock() 21 | webhook.execute.return_value = MagicMock(spec=Response) 22 | mock.return_value = webhook 23 | yield mock 24 | 25 | 26 | @pytest.fixture 27 | def mock_get_meme(): 28 | with patch("ShortsMaker.utils.notify_discord.get_meme") as mock: 29 | mock.return_value = "http://fake-meme.com/image.jpg" 30 | yield mock 31 | 32 | 33 | @pytest.fixture 34 | def mock_get_arthas(): 35 | with patch("ShortsMaker.utils.notify_discord.get_arthas") as mock: 36 | mock.return_value = "http://fake-arthas.com/image.jpg" 37 | yield mock 38 | 39 | 40 | def test_get_arthas(requests_mock): 41 | mock_html = """ 42 |
<img src="test_image.jpg" alt="Test">
43 | """ 44 | requests_mock.get("https://www.bing.com/images/search", text=mock_html) 45 | result = get_arthas() 46 | assert isinstance(result, str) 47 | assert result == "test_image.jpg" 48 | 49 | 50 | def test_get_meme(requests_mock): 51 | mock_response = {"MemeURL": "http://test-meme.com/image.jpg"} 52 | requests_mock.get("https://memeapi.zachl.tech/pic/json", json=mock_response) 53 | result = get_meme() 54 | assert result == "http://test-meme.com/image.jpg" 55 | 56 | 57 | def test_notify_discord_short_message(mock_webhook): 58 | message = "Test message" 59 | mock_webhook.return_value.execute.return_value = MagicMock(status_code=200) 60 | result = notify_discord(message) 61 | mock_webhook.assert_called_once() 62 | assert result is not None 63 | 64 | 65 | def test_notify_discord_long_message(mock_webhook, mock_get_meme, mock_get_arthas, mock_response): 66 | message = "x" * 5000 # Message longer than 4000 chars 67 | mock_webhook.return_value.execute.return_value = MagicMock(status_code=200) 68 | result = notify_discord(message) 69 | assert mock_webhook.call_count > 1 70 | assert result is not None 71 | 72 | 73 | @pytest.mark.parametrize( 74 | "status_code,expected_text", 75 | [ 76 | (200, "Success"), 77 | (400, "Bad Request"), 78 | ], 79 | ) 80 | def test_notify_discord_response( 81 | mock_webhook, mock_get_meme, mock_get_arthas, status_code, expected_text 82 | ): 83 | webhook = mock_webhook.return_value 84 | response = MagicMock(spec=Response) 85 | response.status_code = status_code 86 | response.text = expected_text 87 | webhook.execute.return_value = response 88 | 89 | result = notify_discord("Test message") 90 | assert result.status_code == status_code 91 | assert result.text == expected_text 92 | -------------------------------------------------------------------------------- /tests/utils_tests/test_retry.py: -------------------------------------------------------------------------------- 1 | from unittest.mock import Mock, patch 2 | 3 | from ShortsMaker.utils.retry import retry 4 | 5 | 6 | def test_retry_preserves_function_docstring(): 7 | @retry(max_retries=3, delay=0) 8 | def mock_func(): 9 | """This is a test function.""" 10 | pass 11 | 12 | assert mock_func.__doc__ == "This is a test function." 
13 | 14 | 15 | @patch("ShortsMaker.utils.retry.logger") 16 | def test_retry_logs_function_name(mock_logger): 17 | # Test that function name is correctly logged 18 | mock_func = Mock(return_value="success") 19 | mock_func.__name__ = "test_func" 20 | decorated = retry(max_retries=1, delay=0)(mock_func) 21 | 22 | decorated() 23 | 24 | mock_logger.info.assert_any_call("Begin function test_func") 25 | 26 | 27 | @patch("ShortsMaker.utils.retry.logger") 28 | def test_retry_logs_return_value(mock_logger): 29 | # Test that return value is correctly logged 30 | mock_func = Mock(return_value="test_value") 31 | mock_func.__name__ = "test_func" 32 | decorated = retry(max_retries=1, delay=0)(mock_func) 33 | 34 | decorated() 35 | 36 | mock_logger.info.assert_any_call("Returned: test_value") 37 | 38 | 39 | @patch("ShortsMaker.utils.retry.logger") 40 | def test_retry_logs_execution_time(mock_logger): 41 | # Test that execution time is logged 42 | mock_func = Mock(return_value="success") 43 | mock_func.__name__ = "test_func" 44 | decorated = retry(max_retries=1, delay=0)(mock_func) 45 | 46 | decorated() 47 | 48 | # Check that completion log contains function name and execution time 49 | completion_log_call = mock_logger.info.call_args_list[-1] 50 | assert "Completed function test_func in" in str(completion_log_call) 51 | assert "s after 1 max_retries" in str(completion_log_call) 52 | 53 | 54 | @patch("ShortsMaker.utils.retry.logger") 55 | def test_retry_logs_max_retries_and_delay(mock_logger): 56 | # Test that retry parameters are logged 57 | mock_func = Mock(return_value="success") 58 | mock_func.__name__ = "test_func" 59 | decorated = retry(max_retries=5, delay=10)(mock_func) 60 | 61 | decorated() 62 | 63 | mock_logger.info.assert_any_call("Using retry decorator with 5 max_retries and 10s delay") 64 | 65 | 66 | @patch("ShortsMaker.utils.retry.logger") 67 | def test_retry_logs_attempts(mock_logger): 68 | mock_func = Mock(side_effect=[Exception("error"), "success"]) 69 | mock_func.__name__ = "test_func" 70 | decorated = retry(max_retries=2, delay=0)(mock_func) 71 | 72 | decorated() 73 | 74 | assert mock_logger.info.call_count >= 2 75 | assert mock_logger.warning.call_count == 1 76 | assert mock_logger.exception.call_count == 1 77 | 78 | 79 | def test_retry_successful_second_attempt(): 80 | mock_func = Mock(side_effect=[Exception("error"), "success"]) 81 | mock_func.__name__ = "test_func" 82 | decorated = retry(max_retries=3, delay=0)(mock_func) 83 | 84 | result = decorated() 85 | 86 | assert result == "success" 87 | assert mock_func.call_count == 2 88 | 89 | 90 | def test_retry_all_attempts_failed(): 91 | mock_func = Mock(side_effect=Exception("error")) 92 | mock_func.__name__ = "test_func" 93 | decorated = retry(max_retries=3, delay=0)(mock_func) 94 | 95 | result = decorated() 96 | 97 | assert result is None 98 | assert mock_func.call_count == 3 99 | 100 | 101 | @patch("ShortsMaker.utils.retry.notify_discord") 102 | def test_retry_with_notify(mock_notify): 103 | mock_func = Mock(side_effect=Exception("error")) 104 | mock_func.__name__ = "test_func" 105 | decorated = retry(max_retries=2, delay=0, notify=True)(mock_func) 106 | 107 | decorated() 108 | 109 | mock_notify.assert_called_once_with("test_func Failed after 2 max_retries.\nException: error") 110 | 111 | 112 | @patch("time.sleep") 113 | def test_retry_respects_delay(mock_sleep): 114 | mock_func = Mock(side_effect=[Exception("error"), "success"]) 115 | mock_func.__name__ = "test_func" 116 | delay = 5 117 | decorated = retry(max_retries=3, 
delay=delay)(mock_func) 118 | 119 | decorated() 120 | 121 | mock_sleep.assert_called_once_with(delay) 122 | 123 | 124 | def test_retry_preserves_function_args(): 125 | mock_func = Mock(return_value="success") 126 | mock_func.__name__ = "test_func" 127 | decorated = retry(max_retries=3, delay=0)(mock_func) 128 | 129 | decorated(1, key="value") 130 | 131 | mock_func.assert_called_once_with(1, key="value") 132 | --------------------------------------------------------------------------------