├── .dockerignore
├── .github
│   └── workflows
│       └── python-app.yml
├── .gitignore
├── .pre-commit-config.yaml
├── Dockerfile
├── LICENSE
├── README.md
├── ShortsMaker
│   ├── __init__.py
│   ├── ask_llm.py
│   ├── generate_image.py
│   ├── moviepy_create_video.py
│   ├── shorts_maker.py
│   └── utils
│       ├── __init__.py
│       ├── audio_transcript.py
│       ├── colors_dict.py
│       ├── download_youtube_music.py
│       ├── download_youtube_video.py
│       ├── get_tts.py
│       ├── logging_config.py
│       ├── notify_discord.py
│       └── retry.py
├── assets
│   ├── credits
│   │   ├── credits.mp4
│   │   └── credits_mask.mp4
│   └── fonts
│       ├── Monaco.ttf
│       └── Roboto.ttf
├── example.py
├── example.setup.yml
├── pyproject.toml
├── requirements.txt
├── tests
│   ├── __init__.py
│   ├── ask_llm_tests
│   │   ├── test_ask_llm.py
│   │   └── test_ollama_service.py
│   ├── conftest.py
│   ├── data
│   │   ├── setup.yml
│   │   ├── test.txt
│   │   ├── test.wav
│   │   └── transcript.json
│   ├── generate_image_tests
│   │   └── test_generate_image.py
│   ├── moviepy_create_video_tests
│   │   └── test_moviepy_create_video.py
│   ├── shorts_maker_tests
│   │   ├── test_abbreviation_replacer.py
│   │   ├── test_fix_text.py
│   │   ├── test_has_alpha_and_digit.py
│   │   ├── test_shorts_maker.py
│   │   └── test_split_alpha_and_digit.py
│   └── utils_tests
│       ├── test_audio_transcript.py
│       ├── test_colors_dict.py
│       ├── test_download_youtube_music.py
│       ├── test_download_youtube_video.py
│       ├── test_get_tts.py
│       ├── test_logging_config.py
│       ├── test_notify_discord.py
│       └── test_retry.py
└── uv.lock
/.dockerignore:
--------------------------------------------------------------------------------
1 | .venv
2 | old_scripts
3 |
 4 | # ignore cache and asset directories
5 | cache
6 | assets
7 |
8 | # ignore cache files
9 | *.pyc
10 | *.pyo
11 |
12 | # ignore logs
13 | *.log
14 |
15 | # ignore pytest cache
16 | .pytest_cache
17 |
18 | # IDE caches
19 | .vscode
20 | .idea
21 |
22 | # ignore ipynb files and checkpoints, which may be used for testing
23 | *.ipynb
24 | .ipynb_checkpoints
25 | .ruff_cache
26 |
--------------------------------------------------------------------------------
/.github/workflows/python-app.yml:
--------------------------------------------------------------------------------
1 | name: ShortsMaker
2 |
3 | on:
4 | push:
5 | branches: [ "main" ]
6 |
7 | jobs:
8 | test:
9 | permissions:
10 | contents: read
11 | pull-requests: write
12 | runs-on: ubuntu-latest
13 | env:
14 | DISCORD_WEBHOOK_URL: "https://discord.com/api/webhooks/xxx"
15 | steps:
16 | - name: Setup Java
17 | uses: actions/setup-java@v4
18 | with:
19 | distribution: temurin
20 | java-version: 21
21 | - name: Setup FFmpeg
22 | uses: federicocarboni/setup-ffmpeg@v3.1
23 | - name: Checkout the repository
24 | uses: actions/checkout@main
25 | - name: Install the latest version of uv and set the python version
26 | uses: astral-sh/setup-uv@v5
27 | with:
28 | version: "latest"
29 | enable-cache: true
30 | pyproject-file: "pyproject.toml"
31 | - name: Install dependencies
32 | run: uv sync --frozen --extra cpu
33 | - name: Run tests
34 | run: uv run --frozen pytest
35 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | .venv
2 | old_scripts
3 |
 4 | # ignore generated media and cache directories
5 | assets
6 | cache
7 |
8 | # config files
9 | setup.yml
10 |
11 | # ignore cache files
12 | *.pyc
13 | *.pyo
14 |
15 | # ignore logs
16 | *.log
17 |
18 | # ignore pytest cache
19 | .pytest_cache
20 |
21 | # IDE caches
22 | .vscode
23 | .idea
24 |
25 | # Ignore any media files
26 | *.mp4
27 | *.avi
28 | *.wav
29 | *.mp3
30 |
31 | # ignore ipynb files and checkpoints, which may be used for testing
32 | *.ipynb
33 | .ipynb_checkpoints
34 | .ruff_cache
35 |
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | # See https://pre-commit.com for more information
2 | # See https://pre-commit.com/hooks.html for more hooks
3 | repos:
4 | - repo: https://github.com/astral-sh/ruff-pre-commit
5 | # Ruff version.
6 | rev: v0.9.4
7 | hooks:
8 | # Run the linter.
9 | - id: ruff
10 | # Run the formatter.
11 | - id: ruff-format
12 | - repo: https://github.com/pre-commit/pre-commit-hooks
13 | rev: v5.0.0
14 | hooks:
15 | - id: trailing-whitespace
16 | - id: end-of-file-fixer
17 | - id: check-yaml
18 | - id: debug-statements
19 | language_version: python3
20 | - id: detect-private-key
21 | - repo: https://github.com/Lucas-C/pre-commit-hooks
22 | rev: v1.5.5
23 | hooks:
24 | - id: forbid-crlf
25 | - id: remove-crlf
26 | - id: forbid-tabs
27 | - id: remove-tabs
28 | - repo: https://github.com/astral-sh/uv-pre-commit
29 | rev: 0.5.29
30 | hooks:
31 | - id: uv-lock
32 | - id: uv-export
33 | exclude: .venv
34 |
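35 | # To enable these hooks locally, run `pre-commit install` once;
36 | # the hooks will then run automatically on every `git commit`.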
--------------------------------------------------------------------------------
/Dockerfile:
--------------------------------------------------------------------------------
1 | # Use a Python 3.12 base image
2 | FROM python:3.12-slim-bookworm
3 | COPY --from=ghcr.io/astral-sh/uv:0.6.6 /uv /bin/
4 |
5 | # Set environment variables
6 | # Set this appropriately or leave empty if not using Discord
7 | ENV DISCORD_WEBHOOK_URL="https://discord.com/api/webhooks/xxxxxx"
8 |
9 | # Install ffmpeg and Java
10 | # Install dependencies for moviepy
11 | RUN apt-get update -y && \
12 | apt-get install -y --no-install-recommends ffmpeg openjdk-17-jre locales && \
13 | apt-get clean && \
14 | rm -rf /var/lib/apt/lists/* && \
15 | locale-gen C.UTF-8 && \
16 | /usr/sbin/update-locale LANG=C.UTF-8
17 |
18 | ENV LC_ALL=C.UTF-8
19 |
20 | # Copy the project into the image
21 | ADD . /shorts_maker
22 |
23 | RUN cd /shorts_maker && \
24 | uv sync --frozen
25 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # ShortsMaker
2 |
 3 | [![ShortsMaker](https://github.com/rajathjn/shorts_maker/actions/workflows/python-app.yml/badge.svg)](https://github.com/rajathjn/shorts_maker/actions/workflows/python-app.yml)
 4 | [![CodeQL](https://github.com/rajathjn/shorts_maker/actions/workflows/github-code-scanning/codeql/badge.svg)](https://github.com/rajathjn/shorts_maker/actions/workflows/github-code-scanning/codeql)
5 |
6 | ShortsMaker is a Python package designed to facilitate the creation of engaging short videos or social media clips. It leverages a variety of external services and libraries to streamline the process of generating, processing, and uploading short content.
7 |
8 | ## Support Me
 9 | Like what I do? Please consider supporting me.
10 |
11 |
12 |
13 | ## Table of Contents
14 |
15 | - [ShortsMaker](#shortsmaker)
16 | - [Support Me](#support-me)
17 | - [Table of Contents](#table-of-contents)
18 | - [Features](#features)
19 | - [Requirements](#requirements)
20 | - [Usage Via Docker](#usage-via-docker)
21 | - [Installation](#installation)
22 | - [External Dependencies](#external-dependencies)
23 | - [Environment Variables](#environment-variables)
24 | - [Usage](#usage)
25 | - [Example Video](#example-video)
26 | - [TODO](#todo)
27 | - [Development](#development)
28 | - [Contributing](#contributing)
29 | - [License](#license)
30 |
31 | ## Features
32 |
33 | - **Automated Content Creation:** Easily generate engaging short videos.
34 | - **External Service Integration:** Seamlessly integrates with services like Discord for notifications.
35 | - **GPU-Accelerated Processing:** Optional GPU support for faster processing using whisperx.
36 | - **Modular Design:** Built with extensibility in mind.
 37 | - **In Development:** AskLLM AI agent, now fully integrated, for generating metadata and creative insights.
 38 | - **In Development:** GenerateImage class, enhanced for text-to-image generation using Flux. May be resource-intensive.
39 |
40 | ## Requirements
41 |
42 | - **Python:** 3.12.8
 43 | - **Package Manager:** [`uv`](https://docs.astral.sh/uv/) is used for package management. (It's amazing! Try it out.)
44 | - **Operating System:** Windows, Mac, or Linux (ensure external dependencies are installed for your platform)
45 |
46 | ## Usage Via Docker
47 |
48 | To use ShortsMaker via Docker, follow these steps:
49 |
50 | 1. **Build the Docker Image:**
51 |
52 | Build the Docker image using the provided Dockerfile.
53 |
54 | ```bash
55 | docker build -t shorts_maker -f Dockerfile .
56 | ```
57 |
58 | 2. **Run the Docker Container:**
59 |
 60 |    The first time, run the container with the necessary volume mounts, container name, and working directory set.
61 |
62 | ```bash
 63 |    docker run --name shorts_maker_container -v $(pwd)/assets:/shorts_maker/assets -w /shorts_maker -it shorts_maker bash
64 | ```
65 |
66 | 3. **Start the Docker Container:**
67 |
68 | If the container was previously stopped, you can start it again using:
69 |
70 | ```bash
71 | docker start shorts_maker_container
72 | ```
73 |
74 | 4. **Access the Docker Container:**
75 |
76 | Execute a bash shell inside the running container.
77 |
78 | ```bash
79 | docker exec -it shorts_maker_container bash
80 | ```
81 |
82 | 5. **Run Examples and Tests:**
83 |
84 | Once you are in the bash shell of the container, you can run the example script or tests using `uv`.
85 |
86 | To run the example script:
87 |
88 | ```bash
89 | uv run example.py
90 | ```
91 |
92 | To run tests:
93 |
94 | ```bash
95 | uv run pytest
96 | ```
97 |
98 | **Note:** If you plan to use `ask_llm` or `generate_image`, it is not recommended to use the Docker image due to the high resource requirements of these features. Instead, run ShortsMaker directly on your host machine.
99 |
100 | ## Installation
101 |
102 | 1. **Clone the Repository:**
103 |
104 | ```bash
105 | git clone https://github.com/rajathjn/shorts_maker
106 | cd shorts_maker
107 | ```
108 |
109 | 2. **Install the Package Using uv:**
110 |
111 |    Note: Before starting the installation process, ensure a Python 3.12 virtual environment is set up.
112 |
113 | ```bash
114 | uv venv -p 3.12 .venv
115 |
116 |    # or
117 |
118 | python -m venv .venv
119 | ```
120 |
121 | Package Installation.
122 |
123 | ```bash
124 | uv pip install -r pyproject.toml
125 |
126 |    # or
127 |
128 | uv sync
129 | uv sync --extra cpu # for cpu
130 | uv sync --extra cu124 # for cuda 12.4 versions
131 | ```
132 |
133 | 3. **Install Any Additional Python Dependencies:**
134 |
135 |    If any dependencies are not automatically managed by uv, you can install them using pip (in most cases this is not needed):
136 |
137 | ```bash
138 | pip install -r requirements.txt
139 | ```
140 |
141 | ## External Dependencies
142 |
143 | ShortsMaker relies on several external non-Python components. Please ensure the following are installed/configured on your system:
144 |
145 | - **Discord Notifications:**
146 | - You must set your Discord webhook URL (`DISCORD_WEBHOOK_URL`) as an environment variable.
147 | - Refer to the [Discord documentation](https://discord.com/developers/docs/resources/webhook#create-webhook) for creating a webhook.
148 | - If you don't want to use Discord notifications, you can set `DISCORD_WEBHOOK_URL` to `None` or do something like
149 |
150 | ```python
151 | import os
152 | os.environ["DISCORD_WEBHOOK_URL"] = "None"
153 | ```
154 |
155 | - **Ollama:**
156 | - The external tool Ollama must be installed on your system. Refer to the [Ollama documentation](https://ollama.com/) for installation details.
157 |
158 | - **WhisperX (GPU Acceleration):**
159 | - For GPU execution, ensure that the NVIDIA libraries are installed on your system:
160 | - **cuBLAS:** Version 11.x
161 | - **cuDNN:** Version 8.x
162 | - These libraries are required for optimal performance when using whisperx for processing.
163 |
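As a quick sanity check (a minimal sketch; it assumes a PyTorch build is installed, which the project's `cpu`/`cu124` extras provide), you can confirm that a CUDA device is visible before relying on GPU transcription:

```python
import torch

# Reports whether PyTorch can see a CUDA device; whisperx falls back to CPU otherwise.
if torch.cuda.is_available():
    print(f"CUDA available: {torch.cuda.get_device_name(0)}")
else:
    print("No CUDA device found; transcription will run on the CPU.")
```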
164 | ## Environment Variables
165 |
166 | Before running ShortsMaker, make sure you set the necessary environment variables:
167 |
168 | - **DISCORD_WEBHOOK_URL:**
169 |   This webhook URL is required for sending notifications through Discord.
170 | Example (Windows Command Prompt):
171 |
172 |   ```cmd
173 | set DISCORD_WEBHOOK_URL=your_discord_webhook_url_here
174 | ```
175 |
176 | Example (Linux/macOS):
177 |
178 | ```bash
179 | export DISCORD_WEBHOOK_URL=your_discord_webhook_url_here
180 | ```
181 |
182 | From Python:
183 |
184 | ```python
185 | import os
186 | os.environ["DISCORD_WEBHOOK_URL"] = "your_discord_webhook_url_here"
187 | ```
188 |
189 | ## Usage
190 |
191 | Ensure you have a `setup.yml` configuration file in the `shorts_maker` directory. Use the [example.setup.yml](example.setup.yml) as a reference.
192 |
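For orientation, here is a minimal sketch of the keys the code reads from `setup.yml` (key names are taken from `example.py` and `ShortsMaker/moviepy_create_video.py`; the values are placeholder assumptions, so treat [example.setup.yml](example.setup.yml) as the authoritative reference):

```yaml
cache_dir: cache    # intermediate files: audio, scripts, transcripts
assets_dir: assets  # fonts, credits, downloaded videos and music
reddit_post_getter:
  record_file_txt: post.txt
audio:
  output_audio_file: output.wav
  output_script_file: script.txt
  transcript_json: transcript.json
video:
  background_videos_urls: []  # YouTube URLs for background footage
  background_music_urls: []   # YouTube URLs for background music
```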
193 | Below is a basic example to get you started with ShortsMaker:
194 |
195 | The same example is also available in [example.py](example.py).
196 |
197 | ```python
198 | from pathlib import Path
199 |
200 | import yaml
201 |
202 | from ShortsMaker import MoviepyCreateVideo, ShortsMaker
203 |
204 | setup_file = "setup.yml"
205 |
206 | with open(setup_file) as f:
207 | cfg = yaml.safe_load(f)
208 |
209 | get_post = ShortsMaker(setup_file)
210 |
211 | # You can either provide a URL for the Reddit post
212 | get_post.get_reddit_post(
213 | url="https://www.reddit.com/r/Python/comments/1j36d7a/i_got_tired_of_ai_shorts_scams_so_i_built_my_own/"
214 | )
215 | # Or just run the method to get a random post from the subreddit defined in setup.yml
216 | # get_post.get_reddit_post()
217 |
218 | with open(Path(cfg["cache_dir"]) / cfg["reddit_post_getter"]["record_file_txt"]) as f:
219 | script = f.read()
220 |
221 | get_post.generate_audio(
222 | source_txt=script,
223 | output_audio=f"{cfg['cache_dir']}/{cfg['audio']['output_audio_file']}",
224 | output_script_file=f"{cfg['cache_dir']}/{cfg['audio']['output_script_file']}",
225 | )
226 |
227 | get_post.generate_audio_transcript(
228 | source_audio_file=f"{cfg['cache_dir']}/{cfg['audio']['output_audio_file']}",
229 | source_text_file=f"{cfg['cache_dir']}/{cfg['audio']['output_script_file']}",
230 | )
231 |
232 | get_post.quit()
233 |
234 | create_video = MoviepyCreateVideo(
235 | config_file=setup_file,
236 | speed_factor=1.0,
237 | )
238 |
239 | create_video(output_path="assets/output.mp4")
240 |
241 | create_video.quit()
242 |
243 | # Do not run the below when you are using shorts_maker within a container.
244 |
245 | # ask_llm = AskLLM(config_file=setup_file)
246 | # result = ask_llm.invoke(script)
247 | # print(result["parsed"].title)
248 | # print(result["parsed"].description)
249 | # print(result["parsed"].tags)
250 | # print(result["parsed"].thumbnail_description)
251 | # ask_llm.quit_llm()
252 |
253 | # You can also use AskLLM to generate a text prompt for the image generation
254 | # image_description = ask_llm.invoke_image_describer(script = script, input_text = "A wild scenario")
255 | # print(image_description)
256 | # print(image_description["parsed"].description)
257 |
258 | # GenerateImage uses a lot of resources, so beware
259 | # generate_image = GenerateImage(config_file=setup_file)
260 | # generate_image.use_huggingface_flux_schnell(image_description["parsed"].description, "output.png")
261 | # generate_image.quit()
262 |
263 | ```
264 |
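`MoviepyCreateVideo` also accepts optional overrides, visible in its `__init__` signature in `ShortsMaker/moviepy_create_video.py`. A short sketch (the file paths here are hypothetical):

```python
from ShortsMaker import MoviepyCreateVideo

# Optional overrides from MoviepyCreateVideo.__init__; the paths are placeholders.
create_video = MoviepyCreateVideo(
    config_file="setup.yml",
    bg_video_path="assets/background_videos/clip.mp4",  # skip the YouTube download
    add_credits=False,   # omit the credits overlay
    speed_factor=1.25,   # speed up the final video
)
create_video(output_path="assets/output_fast.mp4")
create_video.quit()
```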
265 | ## Example Video
266 |
267 | Generated from this post [here](https://www.reddit.com/r/selfhosted/comments/r2a6og/comment/hm5xoas/?utm_source=share&utm_medium=web3x&utm_name=web3xcss&utm_term=1&utm_content=share_button)
268 |
269 | https://github.com/user-attachments/assets/6aad212a-bfd5-4161-a2bc-67d24a8de37f
270 |
271 | ## TODO
272 | - [ ] Explain the project's workings and usage in a blog post.
273 | - [x] Dockerize the project to avoid the complex setup process.
274 | - [x] Add option to fetch post from submission URLs.
275 | - [x] Add an example video to the README.
276 |
277 | ## Development
278 |
279 | If you want to contribute to the project, please follow these steps:
280 |
281 | 1. **Set up the development environment:**
282 | - Ensure you have Python 3.12.8 and uv installed.
283 | - Clone the repository and install the development dependencies.
284 |
285 | 2. **Run the Tests:**
286 | - Tests are located in the `tests/` directory.
287 | - Run tests using:
288 |
289 | ```bash
290 | uv run --frozen pytest
291 | ```
292 |
293 | ## Contributing
294 |
295 | To contribute to the project, follow everything in the [Development](#development) section and then:
298 |
299 | **Submit a Pull Request:**
300 | - Fork the repository.
301 | - Create a new branch for your feature or bugfix.
302 | - Commit your changes and push the branch to your fork.
303 | - Open a pull request with a detailed description of your changes.
304 |
305 |
306 | ## License
307 |
308 | This project is licensed under the GNU General Public License v3.0. See the [LICENSE](LICENSE) file for details.
309 |
--------------------------------------------------------------------------------
/ShortsMaker/__init__.py:
--------------------------------------------------------------------------------
1 | from .ask_llm import AskLLM, OllamaServiceManager
2 | from .generate_image import GenerateImage
3 | from .moviepy_create_video import MoviepyCreateVideo, VideoConfig
4 | from .shorts_maker import (
5 | ShortsMaker,
6 | abbreviation_replacer,
7 | has_alpha_and_digit,
8 | split_alpha_and_digit,
9 | )
10 | from .utils.audio_transcript import (
11 | align_transcript_with_script,
12 | generate_audio_transcription,
13 | )
14 | from .utils.colors_dict import COLORS_DICT
15 | from .utils.download_youtube_music import download_youtube_music, sanitize_filename
16 | from .utils.download_youtube_video import download_youtube_video
17 | from .utils.get_tts import VOICES, tts
18 | from .utils.logging_config import configure_logging, get_logger
19 | from .utils.notify_discord import notify_discord
20 | from .utils.retry import retry
21 |
22 | __all__ = [
23 |     "GenerateImage",
24 |     "ShortsMaker",
25 |     "AskLLM",
26 |     "OllamaServiceManager",
27 |     "MoviepyCreateVideo",
28 |     "VideoConfig",
29 |     "configure_logging",
30 |     "get_logger",
31 |     "download_youtube_music",
32 |     "download_youtube_video",
33 |     "abbreviation_replacer",
34 |     "COLORS_DICT",
35 |     "has_alpha_and_digit",
36 |     "split_alpha_and_digit",
37 |     "align_transcript_with_script",
38 |     "generate_audio_transcription",
39 |     "notify_discord",
40 |     "retry",
41 |     "sanitize_filename",
42 |     "VOICES",
43 |     "tts",
44 | ]
45 |
--------------------------------------------------------------------------------
/ShortsMaker/ask_llm.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import platform
3 | import subprocess
4 | import time
5 | from pathlib import Path
6 |
7 | import ollama
8 | import psutil
9 | import yaml
10 | from langchain_core.messages import HumanMessage, SystemMessage
11 | from langchain_core.prompts import ChatPromptTemplate
12 | from langchain_ollama import ChatOllama
13 | from pydantic import BaseModel, Field
14 |
15 | from .utils import get_logger
16 |
17 |
18 | class OllamaServiceManager:
19 | """
20 | Manages the Ollama service, including starting, stopping, and checking its status.
21 |
22 | Attributes:
23 | system (str): The operating system the service is running on (e.g., "windows", "linux").
24 | process (subprocess.Popen | None): The process object for the running Ollama service.
25 | ollama (module): The ollama module.
26 | logger (logging.Logger): The logger instance for the class.
27 | """
28 |
29 | def __init__(
30 | self,
31 | logger: logging.Logger | None = None,
32 | ):
33 | """
34 | Initializes the OllamaServiceManager.
35 |
36 | Args:
37 | logger (logging.Logger | None, optional): An optional logger instance. If not provided, a default logger is created. Defaults to None.
38 | """
39 | self.system = platform.system().lower()
40 | self.process = None
41 | self.ollama = ollama
42 |
43 | self.logger = logger if logger else logging.getLogger(__name__)
44 | self.logger.name = "OllamaServiceManager"
45 |         self.logger.info(f"OllamaServiceManager initialized on {self.system}")
46 |
47 | def start_service(self) -> bool:
48 | """
49 | Starts the Ollama service.
50 |
51 | Returns:
52 | bool: True if the service started successfully, False otherwise.
53 |
54 | Raises:
55 | Exception: If there is an error starting the service.
56 | """
57 | self.logger.info("Starting Ollama service")
58 | try:
59 | if self.system == "windows":
60 | ollama_execution_command = ["ollama app.exe", "serve"]
61 | else:
62 | ollama_execution_command = ["ollama", "serve"]
63 | self.process = subprocess.Popen(
64 | ollama_execution_command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True
65 | )
66 |
67 | # Wait a moment for the service to start
68 | time.sleep(2)
69 |
70 | # Check if the service started successfully
71 | if self.process.poll() is None:
72 | self.logger.info("Ollama service started successfully")
73 | return True
74 | except Exception as e:
75 | self.logger.error(f"Error starting Ollama service: {str(e)}")
76 | raise e
77 | return False
78 |
79 | def stop_service(self) -> bool:
80 | """
81 | Stops the Ollama service.
82 |
83 | Returns:
84 | bool: True if the service stopped successfully, False otherwise.
85 | """
86 | try:
87 | if self.process:
88 | if self.system == "windows":
89 | # For Windows
90 | subprocess.run(
91 | ["taskkill", "/F", "/IM", "ollama app.exe"], capture_output=True, text=True
92 | )
93 | subprocess.run(
94 | ["taskkill", "/F", "/IM", "ollama.exe"], capture_output=True, text=True
95 | )
96 | else:
97 | # For Linux/MacOS
98 | self.process.terminate()
99 | self.process.wait(timeout=5)
100 | del self.process
101 | self.process = None
102 | self.logger.info("Ollama service stopped successfully")
103 | return True
104 | else:
105 |                 self.logger.warning("Ollama service was either started by the user or has already stopped")
106 | return False
107 |
108 | except Exception as e:
109 |             self.logger.error(f"Error stopping Ollama service: {str(e)}")
110 | return False
111 |
112 | @staticmethod
113 | def is_ollama_running():
114 | """
115 | Checks if any Ollama process is currently running.
116 |
117 | Returns:
118 | bool: True if an Ollama process is running, False otherwise.
119 | """
120 | for proc in psutil.process_iter(["name", "cmdline"]):
121 | try:
122 | if "ollama" in proc.info["name"].lower():
123 | return True
124 | except (psutil.NoSuchProcess, psutil.AccessDenied):
125 | continue
126 | return False
127 |
128 | def is_service_running(self) -> bool:
129 | """
130 | Checks if the Ollama service managed by this instance is running.
131 |
132 | Returns:
133 | bool: True if the service is running, False otherwise.
134 | """
135 | if self.process and self.process.poll() is None:
136 | return True
137 | if self.is_ollama_running():
138 | return True
139 | return False
140 |
141 | def get_running_models(self):
142 | """
143 | Gets a list of the currently running models on Ollama.
144 |
145 | Returns:
146 | list: The models running in ollama service.
147 | """
148 | return self.ollama.ps()
149 |
150 | def stop_running_model(self, model_name: str):
151 | """
152 | Stops a specific model that is running in Ollama.
153 |
154 | Args:
155 | model_name (str): The name of the model to stop.
156 |
157 | Returns:
158 | bool: True if the model was stopped successfully, False otherwise.
159 | """
160 | try:
161 | stop_attempt = subprocess.check_output(
162 | ["ollama", "stop", model_name], stderr=subprocess.STDOUT, text=True
163 | )
164 | self.logger.info(
165 | f"Ollama service with {model_name} stopped successfully: {stop_attempt}"
166 | )
167 | return True
168 | except Exception as e:
169 | self.logger.warning(f"Failed to stop {model_name}")
170 | self.logger.warning(
171 | "Either the model was already stopped, or the Ollama service is already stopped"
172 | )
173 | self.logger.error(f"Error stopping Ollama service: {str(e)}")
174 | return False
175 |
176 | def get_llm_model(self, model_name: str):
177 | """
178 | Downloads a specific LLM model using Ollama.
179 |
180 | Args:
181 | model_name (str): The name of the model to download.
182 | """
183 | return self.ollama.pull(model_name)
184 |
185 | def get_list_of_downloaded_files(self) -> list[str]:
186 | """
187 | Retrieves a list of models that have been downloaded by Ollama.
188 |
189 | Returns:
190 | list[str]: A list of the names of the downloaded models.
191 | """
192 | model_list = []
193 | try:
194 | model_list = list(self.ollama.list()) # list[tuple[model, list[models]]]
195 | model_list = [i.model for i in model_list[0][1]]
196 | except Exception as e:
197 | self.logger.error(f"Error getting list of downloaded files: {str(e)}")
198 | return model_list
199 |
200 |
201 | class AskLLM:
202 | """
203 | A class to interact with a Large Language Model (LLM) using Ollama.
204 |
205 | This class handles loading, querying, and managing the LLM, including
206 | starting and stopping the Ollama service if necessary.
207 |
208 | Attributes:
209 | setup_cfg (Path): The path to the configuration file.
210 | cfg (dict): The configuration loaded from the config file.
211 | logger (logging.Logger): The logger instance for the class.
212 | self_started_ollama (bool): Indicates if the instance started the ollama service.
213 | ollama_service_manager (OllamaServiceManager): The manager for the ollama service.
214 | model_name (str): The name of the LLM model to use.
215 | model_temperature (float): The temperature parameter for the LLM.
216 | llm (ChatOllama): The ChatOllama instance.
217 | structured_llm (None | ChatOllama): The structured llm model used for invoke.
218 | """
219 |
220 | def __init__(
221 | self,
222 | config_file: Path | str,
223 | model_name: str = "llama3.1:latest",
224 | temperature: float = 0,
225 | ) -> None:
226 | """
227 | Initializes the AskLLM instance.
228 |
229 | Args:
230 | config_file (Path | str): The path to the configuration file.
231 | model_name (str, optional): The name of the LLM model to use. Defaults to "llama3.1:latest".
232 | temperature (float, optional): The temperature parameter for the LLM. Defaults to 0.
233 |
234 | Raises:
235 | FileNotFoundError: If the configuration file does not exist.
236 | ValueError: If the configuration file is not a YAML file.
237 | """
238 | # if config_file is str convert it to a Pathlike
239 | self.setup_cfg = Path(config_file) if isinstance(config_file, str) else config_file
240 |
241 | if not self.setup_cfg.exists():
242 | raise FileNotFoundError(f"File {str(self.setup_cfg)} does not exist")
243 |
244 | if self.setup_cfg.suffix != ".yml":
245 | raise ValueError(f"File {str(self.setup_cfg)} is not a yaml file")
246 |
247 | # load the yml file
248 | with open(self.setup_cfg) as f:
249 | self.cfg = yaml.safe_load(f)
250 |
251 | self.logger = get_logger(__name__)
252 |
253 | self.self_started_ollama: bool = False
254 | self.ollama_service_manager = OllamaServiceManager(logger=self.logger)
255 | self.model_name = model_name
256 | self.model_temperature = temperature
257 | self.llm: ChatOllama = self._load_llm_model(self.model_name, self.model_temperature)
258 | self.structured_llm = None
259 |
260 | def _load_llm_model(self, model_name: str, temperature: float) -> ChatOllama:
261 | """
262 | Loads the specified LLM model.
263 |
264 | Starts the Ollama service if it's not already running, and downloads the model if it's not already downloaded.
265 |
266 | Args:
267 | model_name (str): The name of the model to load.
268 | temperature (float): The temperature parameter for the LLM.
269 |
270 | Returns:
271 | ChatOllama: The loaded ChatOllama instance.
272 | """
273 | if not self.ollama_service_manager.is_service_running():
274 | self.logger.warning("Ollama service is not running. Attempting to start it.")
275 | self.ollama_service_manager.start_service()
276 | self.self_started_ollama = True
277 | self.logger.warning(f"Self started ollama service: {self.self_started_ollama}")
278 | self.logger.info("Ollama service found")
279 |
280 | if model_name not in self.ollama_service_manager.get_list_of_downloaded_files():
281 | self.logger.info(f"Downloading model {model_name}")
282 | self.ollama_service_manager.get_llm_model(model_name)
283 | self.logger.info(f"Model {model_name} downloaded")
284 | else:
285 | self.logger.info(f"Model {model_name} already downloaded")
286 |
287 | return ChatOllama(model=model_name, temperature=temperature)
288 |
289 | def invoke(self, input_text: str) -> dict | BaseModel:
290 | """
291 | Invokes the LLM with the given input text and returns a structured output.
292 |
293 | This method uses a predefined prompt to query the LLM and expects a response that conforms to the YoutubeDetails schema.
294 |
295 | Args:
296 | input_text (str): The input text to send to the LLM.
297 |
298 | Returns:
299 | dict | BaseModel: A dictionary or a BaseModel instance containing the LLM's response.
300 | """
301 | prompt = ChatPromptTemplate(
302 | messages=[
303 | SystemMessage(
304 | "You are a Youtubers digital assistant. Please provide creative, engaging, clickbait, key word rich and accurate information to the user."
305 | ),
306 | SystemMessage(
307 |                     "You assist a Youtuber who runs an AI automated Youtube channel. The entire process involves finding a script, making a video about it, and then using an AI image creator to make a thumbnail for the video."
308 | ),
309 | SystemMessage(
310 | "Be short and concise. Be articulate, no need to be verbose and justify your answer."
311 | ),
312 | HumanMessage(f"Script:\n{input_text}"),
313 | ],
314 | )
315 | self.structured_llm = self.llm.with_structured_output(YoutubeDetails, include_raw=True)
316 | return self.structured_llm.invoke(prompt.messages)
317 |
318 | def invoke_image_describer(self, script: str, input_text: str) -> dict | BaseModel:
319 | """
320 | Invokes the LLM to generate an image description based on the given script and input text.
321 |
322 | This method uses a predefined prompt to query the LLM and expects a response that conforms to the ImageDescriber schema.
323 |
324 | Args:
325 | script (str): The script to base the image description on.
326 | input_text (str): Additional text to guide the image description.
327 |
328 | Returns:
329 | dict | BaseModel: A dictionary or a BaseModel instance containing the LLM's image description.
330 | """
331 | prompt = ChatPromptTemplate(
332 | messages=[
333 | SystemMessage(
334 | "You are an AI image prompt generator, who specializes in image description. Helping users to create AI image prompts."
335 | ),
336 | SystemMessage(
337 |                     "The user provides the complete script and the text to generate the prompt for. You should provide a detailed and creative description of an image. Note: avoid mentioning names or text titles in the description. The more detailed and imaginative your description, the more interesting the resulting image will be."
338 | ),
339 | SystemMessage("Keep the description with 500 characters or less."),
340 | HumanMessage(f"Script:\n{script}"),
341 | HumanMessage(f"Text:\n{input_text}"),
342 | ]
343 | )
344 | self.structured_llm = self.llm.with_structured_output(ImageDescriber, include_raw=True)
345 | return self.structured_llm.invoke(prompt.messages)
346 |
347 | def quit_llm(self):
348 | """
349 | Shuts down the LLM and the Ollama service if it was started by this instance.
350 |
351 | This method stops any running models, and if the Ollama service was started by this
352 | instance, it stops the service as well. Finally, it deletes the instance variables to
353 | clean up.
354 | """
355 | self.ollama_service_manager.stop_running_model(self.model_name)
356 | if self.self_started_ollama:
357 | self.ollama_service_manager.stop_service()
358 | # Delete all instance variables
359 | for attr in list(self.__dict__.keys()):
360 | try:
361 | self.logger.debug(f"Deleting {attr}")
362 | if attr == "logger":
363 | continue
364 | delattr(self, attr)
365 | except Exception as e:
366 | self.logger.error(f"Error deleting {attr}: {e}")
367 | return
368 |
369 |
370 | # Pydantic YoutubeDetails
371 | class YoutubeDetails(BaseModel):
372 | """Details of the YouTube video."""
373 |
374 | title: str = Field(
375 | description="A fun and engaging title for the Youtube video. Has to be related to the reddit post and is not more than 100 characters."
376 | )
377 | description: str = Field(
378 | description="Description of the Youtube video. It should be a simple summary of the video."
379 | )
380 | tags: list[str] = Field(
381 | description="Tags of the Youtube video. tags are single words with no space in them."
382 | )
383 | thumbnail_description: str = Field(
384 | description="Thumbnail description of the Youtube video. provide detailed and creative descriptions that will inspire unique and interesting images from the AI. Keep in mind that the AI is capable of understanding a wide range of language and can interpret abstract concepts, so feel free to be as imaginative and descriptive as possible. For example, you could describe a scene from a futuristic city, or a surreal landscape filled with strange creatures. The more detailed and imaginative your description, the more interesting the resulting image will be."
385 | )
386 |
387 |
388 | # Image Describer
389 | class ImageDescriber(BaseModel):
390 | """Given text, Provides a detailed and creative description for an image."""
391 |
392 | description: str = Field(
393 | description="Provide a detailed and creative description that will inspire unique and interesting images from the AI. Keep in mind that the AI is capable of understanding a wide range of language and can interpret abstract concepts, so feel free to be as imaginative and descriptive as possible. For example, you could describe the scene in a pictorial way adding more details or elaborating the scenario. The more detailed and imaginative your description, the more interesting the resulting image will be."
394 | )
395 |
--------------------------------------------------------------------------------
/ShortsMaker/generate_image.py:
--------------------------------------------------------------------------------
1 | # https://huggingface.co/docs/diffusers/main/en/index
2 | import os
3 | from pathlib import Path
4 | from time import sleep
5 |
6 | import torch
7 | import yaml
8 | from diffusers import AutoencoderKL, FluxPipeline
9 | from transformers import CLIPTextModel, T5EncoderModel
10 |
11 | from .utils import get_logger
12 |
13 | MODEL_UNLOAD_DELAY = 5
14 |
15 |
16 | class GenerateImage:
17 | """
18 | A class for generating images using different Flux models from Hugging Face.
19 |
20 | This class provides methods to load and use various Flux models, including
21 | FLUX.1-dev, FLUX.1-schnell, and a custom Pixel Wave model. It handles model
22 | loading, image generation, and resource cleanup.
23 | """
24 |
25 | def __init__(self, config_file: Path | str) -> None:
26 | """
27 | Initializes the GenerateImage class.
28 |
29 | Args:
30 | config_file (Path | str): Path to the YAML configuration file containing settings
31 | such as the Hugging Face access token.
32 |
33 | Raises:
34 | FileNotFoundError: If the specified configuration file does not exist.
35 | ValueError: If the configuration file is not a YAML file.
36 | """
37 | # if config_file is str convert it to a Pathlike
38 | self.setup_cfg = Path(config_file) if isinstance(config_file, str) else config_file
39 |
40 | if not self.setup_cfg.exists():
41 | raise FileNotFoundError(f"File {str(self.setup_cfg)} does not exist")
42 |
43 | if self.setup_cfg.suffix != ".yml":
44 | raise ValueError(f"File {str(self.setup_cfg)} is not a yaml file")
45 |
46 | # load the yml file
47 | with open(self.setup_cfg) as f:
48 | self.cfg = yaml.safe_load(f)
49 |
50 | self.logger = get_logger(__name__)
51 |
52 | if "hugging_face_access_token" not in self.cfg:
53 | self.logger.warning(
54 | "Please add your huggingface access token to use Flux.1-Dev.\nDefaulting to use Flux.1-Schnell"
55 | )
56 |
57 | self.pipe: FluxPipeline | None = None
58 |
59 | def _load_model(self, model_id: str) -> bool:
60 | """
61 | Loads a Flux model from Hugging Face.
62 |
63 | Args:
64 | model_id (str): The ID of the Flux model to load from Hugging Face.
65 |
66 | Returns:
67 | bool: True if the model was loaded successfully.
68 |
69 | Raises:
70 | RuntimeError: If there is an error loading the Flux model.
71 | """
72 | try:
73 | self.pipe = FluxPipeline.from_pretrained(model_id, torch_dtype=torch.bfloat16)
74 | self.logger.info(f"Loading Flux model from {model_id}")
75 | # to run on low vram GPUs (i.e. between 4 and 32 GB VRAM)
76 | # Choose ONE of the following:
77 | # pipe.enable_model_cpu_offload() # Best for low-VRAM GPUs
78 | self.pipe.enable_sequential_cpu_offload() # Alternative for moderate VRAM GPUs
79 | self.pipe.vae.enable_slicing() # Reduces memory usage for decoding
80 | self.pipe.vae.enable_tiling() # Further optimizes VAE computation
81 |
82 | # casting here instead of in the pipeline constructor because doing so in the constructor loads all models into CPU memory at once
83 | self.pipe.to(torch.float16)
84 |
85 | self.logger.info("Flux model loaded")
86 | return True
87 | except Exception as e:
88 | self.logger.error(e)
89 | raise RuntimeError("Error in loading the Flux model")
90 |
91 | # @retry(max_retries=MAX_RETRIES, delay=DELAY, notify=NOTIFY)
92 | def use_huggingface_flux_dev(
93 | self,
94 | prompt: str,
95 | output_path: str,
96 | negative_prompt: str = "",
97 | model_id: str = "black-forest-labs/FLUX.1-dev",
98 | steps: int = 20,
99 | seed: int = 0,
100 | height: int = 1024,
101 | width: int = 1024,
102 | guidance_scale: float = 3.5,
103 | ) -> bool:
104 | """
105 | Generates an image using the FLUX.1-dev model from Hugging Face.
106 |
107 | Args:
108 | prompt (str): The text prompt to guide image generation.
109 | output_path (str): The path to save the generated image.
110 | negative_prompt (str): The text prompt to guide what the model should avoid generating. Defaults to "".
111 | model_id (str): The ID of the FLUX.1-dev model on Hugging Face. Defaults to "black-forest-labs/FLUX.1-dev".
112 | steps (int): The number of inference steps. Defaults to 20.
113 | seed (int): The random seed for image generation. Defaults to 0.
114 | height (int): The height of the output image. Defaults to 1024.
115 | width (int): The width of the output image. Defaults to 1024.
116 | guidance_scale (float): The guidance scale for image generation. Defaults to 3.5.
117 |
118 | Returns:
119 | bool: True if the image was generated and saved successfully.
120 | """
121 | self.logger.info("This image generator uses the Flux Dev model.")
122 | # Add access token to environment variable
123 | if "hugging_face_access_token" in self.cfg and os.environ.get("HF_TOKEN") is None:
124 | self.logger.info("Setting HF_TOKEN environment variable")
125 | os.environ["HF_TOKEN"] = self.cfg["hugging_face_access_token"]
126 |
127 | self._load_model(model_id)
128 |
129 | self.logger.info("Generating image")
130 | image = self.pipe(
131 | prompt,
132 | negative_prompt,
133 | guidance_scale=guidance_scale,
134 | output_type="pil",
135 | num_inference_steps=steps,
136 | max_sequence_length=512,
137 | height=height,
138 | width=width,
139 | generator=torch.Generator("cpu").manual_seed(seed),
140 | ).images[0]
141 | image.save(output_path)
142 | self.logger.info(f"Image saved to {output_path}")
143 |
144 | del self.pipe
145 | self.pipe = None
146 | if torch.cuda.is_available():
147 | torch.cuda.empty_cache()
148 |         self.logger.info(f"Waiting {MODEL_UNLOAD_DELAY} seconds so that the GPU memory can be freed")
149 | sleep(MODEL_UNLOAD_DELAY)
150 | return True
151 |
152 | # @retry(max_retries=MAX_RETRIES, delay=DELAY, notify=NOTIFY)
153 | def use_huggingface_flux_schnell(
154 | self,
155 | prompt: str,
156 | output_path: str,
157 | negative_prompt: str = "",
158 | model_id: str = "black-forest-labs/FLUX.1-schnell",
159 | steps: int = 4,
160 | seed: int = 0,
161 | height: int = 1024,
162 | width: int = 1024,
163 | guidance_scale: float = 0.0,
164 | ) -> bool:
165 | """
166 | Generates an image using the FLUX.1-schnell model from Hugging Face.
167 |
168 | Args:
169 | prompt (str): The text prompt to guide image generation.
170 | output_path (str): The path to save the generated image.
171 | negative_prompt (str): The text prompt to guide what the model should avoid generating. Defaults to "".
172 | model_id (str): The ID of the FLUX.1-schnell model on Hugging Face. Defaults to "black-forest-labs/FLUX.1-schnell".
173 | steps (int): The number of inference steps. Defaults to 4.
174 | seed (int): The random seed for image generation. Defaults to 0.
175 | height (int): The height of the output image. Defaults to 1024.
176 | width (int): The width of the output image. Defaults to 1024.
177 | guidance_scale (float): The guidance scale for image generation. Defaults to 0.0.
178 |
179 | Returns:
180 | bool: True if the image was generated and saved successfully.
181 | """
182 | self.logger.info("This image generator uses the Flux Schnell model.")
183 |
184 | self._load_model(model_id)
185 |
186 | self.logger.info("Generating image")
187 | image = self.pipe(
188 | prompt,
189 | negative_prompt,
190 | guidance_scale=guidance_scale,
191 | output_type="pil",
192 | num_inference_steps=steps,
193 | max_sequence_length=256,
194 | height=height,
195 | width=width,
196 | generator=torch.Generator("cpu").manual_seed(seed),
197 | ).images[0]
198 | image.save(output_path)
199 | self.logger.info(f"Image saved to {output_path}")
200 |
201 | del self.pipe
202 | self.pipe = None
203 | if torch.cuda.is_available():
204 | torch.cuda.empty_cache()
205 |         self.logger.info(f"Waiting {MODEL_UNLOAD_DELAY} seconds so that the GPU memory can be freed")
206 | sleep(MODEL_UNLOAD_DELAY)
207 | return True
208 |
209 | def use_flux_pixel_wave(
210 | self,
211 | prompt: str,
212 | output_path: str,
213 | model_id: str = "https://huggingface.co/mikeyandfriends/PixelWave_FLUX.1-schnell_03/blob/main/pixelwave_flux1_schnell_fp8_03.safetensors",
214 | steps: int = 4,
215 | seed: int = 595570113709576,
216 | height: int = 1024,
217 | width: int = 1024,
218 | guidance_scale: float = 3.5,
219 | ) -> bool:
220 | """
221 | Generates an image using the custom Flux Pixel Wave model.
222 |
223 | Args:
224 | prompt (str): The text prompt to guide image generation.
225 | output_path (str): The path to save the generated image.
226 | model_id (str): The URL or path to the Pixel Wave model file. Defaults to "https://huggingface.co/mikeyandfriends/PixelWave_FLUX.1-schnell_03/blob/main/pixelwave_flux1_schnell_fp8_03.safetensors".
227 | steps (int): The number of inference steps. Defaults to 4.
228 | seed (int): The random seed for image generation. Defaults to 595570113709576.
229 | height (int): The height of the output image. Defaults to 1024.
230 | width (int): The width of the output image. Defaults to 1024.
231 | guidance_scale (float): The guidance scale for image generation. Defaults to 3.5.
232 |
233 | Returns:
234 | bool: True if the image was generated and saved successfully.
235 | """
236 | self.logger.info("This image generator uses the Flux Pixel Wave model.")
237 |
238 | text_encoder = CLIPTextModel.from_pretrained(
239 | "black-forest-labs/FLUX.1-schnell", subfolder="text_encoder"
240 | )
241 | text_encoder_2 = T5EncoderModel.from_pretrained(
242 | "black-forest-labs/FLUX.1-schnell", subfolder="text_encoder_2"
243 | )
244 | vae = AutoencoderKL.from_pretrained("black-forest-labs/FLUX.1-schnell", subfolder="vae")
245 |
246 | self.pipe = FluxPipeline.from_single_file(
247 | model_id,
248 | use_safetensors=True,
249 | torch_dtype=torch.bfloat16,
250 |         # Load additional components not included in the safetensors file
251 | text_encoder=text_encoder,
252 | text_encoder_2=text_encoder_2,
253 | vae=vae,
254 | )
255 |
256 | # to run on low vram GPUs (i.e. between 4 and 32 GB VRAM)
257 | # Choose ONE of the following:
258 | # pipe.enable_model_cpu_offload() # Best for low-VRAM GPUs
259 | self.pipe.enable_sequential_cpu_offload() # Alternative for moderate VRAM GPUs
260 | self.pipe.vae.enable_slicing() # Reduces memory usage for decoding
261 | self.pipe.vae.enable_tiling() # Further optimizes VAE computation
262 |
263 | # casting here instead of in the pipeline constructor because doing so in the constructor loads all models into CPU memory at once
264 | self.pipe.to(torch.float16)
265 |
266 | self.logger.info("Generating image")
267 | image = self.pipe(
268 | prompt,
269 | guidance_scale=guidance_scale,
270 | output_type="pil",
271 | num_inference_steps=steps,
272 | max_sequence_length=256,
273 | height=height,
274 | width=width,
275 | generator=torch.Generator("cpu").manual_seed(seed),
276 | ).images[0]
277 | image.save(output_path)
278 | self.logger.info(f"Image saved to {output_path}")
279 |
280 | del self.pipe
281 | del text_encoder
282 | del text_encoder_2
283 | del vae
284 |
285 | self.pipe = None
286 | if torch.cuda.is_available():
287 | torch.cuda.empty_cache()
288 |         self.logger.info(f"Waiting {MODEL_UNLOAD_DELAY} seconds so that the GPU memory can be freed")
289 | sleep(MODEL_UNLOAD_DELAY)
290 | return True
291 |
292 | def quit(self) -> None:
293 | """
294 | Cleans up resources and exits the image generator.
295 |
296 | This method clears the CUDA cache (if available) and attempts to
297 | delete all instance variables to free up memory.
298 | """
299 | self.logger.info("Quitting the image generator")
300 | if torch.cuda.is_available():
301 | torch.cuda.empty_cache()
302 | # Delete all instance variables
303 | for attr in list(self.__dict__.keys()):
304 | try:
305 | self.logger.debug(f"Deleting {attr}")
306 | if attr == "logger":
307 | continue
308 | delattr(self, attr)
309 | except Exception as e:
310 | self.logger.error(f"Error deleting {attr}: {e}")
311 | return None
312 |
--------------------------------------------------------------------------------
/ShortsMaker/moviepy_create_video.py:
--------------------------------------------------------------------------------
1 | import random
2 | import secrets
3 | from dataclasses import dataclass
4 | from pathlib import Path
5 | from typing import Any
6 |
7 | import yaml
8 | from moviepy import (
9 | AudioFileClip,
10 | CompositeAudioClip,
11 | CompositeVideoClip,
12 | TextClip,
13 | VideoClip,
14 | VideoFileClip,
15 | afx,
16 | vfx,
17 | )
18 |
19 | from .utils import (
20 | COLORS_DICT,
21 | download_youtube_music,
22 | download_youtube_video,
23 | get_logger,
24 | )
25 |
26 | random.seed(secrets.randbelow(1000000))
27 |
28 |
29 | @dataclass
30 | class VideoConfig:
31 | cache_dir: Path
32 | assets_dir: Path
33 | audio_config: dict
34 | video_config: dict
35 | logging_config: dict
36 |
37 |
38 | class MoviepyCreateVideo:
39 | """
40 | Class for creating videos from media components using MoviePy.
41 |
42 | This class facilitates the creation of videos by integrating various media
43 | components such as background videos, audio tracks, music, fonts, and credits.
44 | It provides settings for fading effects, delays, and logging during video
45 | generation. It ensures proper initialization and handling of required
46 | directories, configurations, and external dependencies like FFmpeg.
47 | The class allows flexibility in providing media paths or configuring the
48 | video creation process dynamically.
49 |
50 | Attributes:
51 | DEFAULT_FADE_TIME (int): Default duration for fade effects applied to the video.
52 | DEFAULT_DELAY (int): Default delay applied between video transitions or sections.
53 |         DEFAULT_SPEED_FACTOR (float): Default speed factor applied to the video.
54 | REQUIRED_DIRECTORIES (list): List of essential directories required for using the class.
55 | PUNCTUATION_MARKS (list): List of punctuation marks used for processing transcripts or text inputs.
56 | """
57 |
58 | DEFAULT_FADE_TIME = 2
59 | DEFAULT_DELAY = 1
60 | DEFAULT_SPEED_FACTOR = 1
61 | REQUIRED_DIRECTORIES = ["video_dir", "music_dir", "fonts_dir", "credits_dir"]
62 | PUNCTUATION_MARKS = [".", ";", ":", "!", "?", ","]
63 |
64 | def __init__(
65 | self,
66 | config_file: Path | str,
67 | bg_video_path: Path | str = None,
68 | add_credits: bool = True,
69 | credits_path: Path | str = None,
70 | audio_path: Path | str = None,
71 | music_path: Path | str = None,
72 | transcript_path: Path | str = None,
73 | font_path: Path | str = None,
74 | fade_time: int = DEFAULT_FADE_TIME,
75 | delay: int = DEFAULT_DELAY,
76 | speed_factor: float = DEFAULT_SPEED_FACTOR,
77 | ) -> None:
78 | self.fade_time = fade_time
79 | self.delay = delay
80 | self.speed_factor = speed_factor
81 |
82 | # Initialize configuration
83 | self.config = self._load_configuration(config_file)
84 | self.logger = get_logger(__name__)
85 | self._verify_ffmpeg()
86 |
87 | # Initialize directories
88 | self._setup_directories()
89 |
90 | # Initialize media components
91 | self.audio_clip = self._initialize_audio(audio_path)
92 | self.audio_clip_bitrate = self.audio_clip.reader.bitrate
93 | self.logger.info(f"Audio Duration: {self.audio_clip.duration:.2f}s")
94 |
95 | self.audio_transcript = self._load_transcript(transcript_path)
96 | self.audio_transcript = self.preprocess_audio_transcript()
97 | self.word_transcript, self.sentences_transcript = (
98 | self.process_audio_transcript_to_word_and_sentences_transcript()
99 | )
100 |
101 | self.bg_video = self._initialize_background_video(bg_video_path)
102 | self.bg_video = self.prepare_background_video()
103 | self.bg_video_bitrate = self.bg_video.reader.bitrate
104 |
105 | self.music_clip = self._initialize_music(music_path)
106 | self.music_clip_bitrate = self.music_clip.reader.bitrate
107 |
108 | self.font_path = self._initialize_font(font_path)
109 |
110 | self.add_credits: bool = add_credits
111 | self.credits_video = self._initialize_credits(credits_path) if add_credits else None
112 |
113 | # Initialize color
114 | self.color = self._select_random_color()
115 | self.logger.info(f"Using color {self.color}")
116 |
117 | @staticmethod
118 | def _load_configuration(config_file: Path | str) -> VideoConfig:
119 | """
120 | Loads and validates a YAML configuration file, and then parses its content
121 | into a `VideoConfig` object. The method checks for the existence of the file
122 | and ensures that it has the correct `.yml` extension. If the validation fails,
123 | it raises a `ValueError`. Otherwise, the YAML content is loaded and used to
124 | instantiate a `VideoConfig` with the appropriate properties.
125 |
126 | Args:
127 | config_file: The path to the configuration file provided as a `Path` object
128 | or a `str`. The file must exist and have a `.yml` extension.
129 |
130 | Returns:
131 | A `VideoConfig` object created using the data parsed from the provided
132 | configuration file.
133 |
134 | Raises:
135 | ValueError: If the configuration file does not exist or does not have a
136 | `.yml` extension.
137 | """
138 | config_path = Path(config_file)
139 | if not config_path.exists() or config_path.suffix != ".yml":
140 | raise ValueError(f"Invalid configuration file: {config_path}")
141 |
142 | with open(config_path) as f:
143 | cfg = yaml.safe_load(f)
144 |
145 | return VideoConfig(
146 | cache_dir=Path(cfg["cache_dir"]),
147 | assets_dir=Path(cfg["assets_dir"]),
148 | audio_config=cfg.get("audio", {}),
149 | video_config=cfg.get("video", {}),
150 | logging_config=cfg.get("logging", {}),
151 | )
152 |
153 | def _verify_ffmpeg(self) -> None:
154 | """
155 | Verifies the availability of the FFmpeg utility on the system.
156 |
157 | This method checks if FFmpeg is installed and available in the system's PATH
158 | environment variable. It does so by attempting to execute the FFmpeg command
159 | with the `-version` argument, which displays FFmpeg's version details. If
160 | FFmpeg is not installed or cannot be accessed, an error is logged, and the
161 | exception is re-raised for further handling.
162 |
163 | Raises:
164 | Exception: If FFmpeg is not installed or available in the PATH. The specific
165 | exception raised during the failure will also be re-raised.
166 | """
167 | try:
168 | import subprocess
169 |
170 | subprocess.run(["ffmpeg", "-version"], check=True)
171 | except Exception as e:
172 | self.logger.error("ffmpeg is not installed or not available in path")
173 | raise e
174 |
175 | def _setup_directories(self) -> None:
176 | """
177 | Sets up the necessary directories for storing assets used in the application.
178 |
179 | This method configures various directories such as those for background
180 | videos, background music, fonts, and credits. Additionally, it ensures
181 | that the directories required for background videos and background music
182 | are created if they do not already exist.
183 |
184 | Attributes:
185 | video_dir: Path to the directory where background videos are stored.
186 | music_dir: Path to the directory where background music is stored.
187 | fonts_dir: Path to the directory where fonts are stored.
188 | credits_dir: Path to the directory where credits data is stored.
189 |
190 | Raises:
191 | FileNotFoundError: If the base assets directory does not exist or is
192 | inaccessible during operation.
193 | """
194 | self.video_dir = self.config.assets_dir / "background_videos"
195 | self.music_dir = self.config.assets_dir / "background_music"
196 | self.fonts_dir = self.config.assets_dir / "fonts"
197 | self.credits_dir = self.config.assets_dir / "credits"
198 |
199 | for directory in [self.video_dir, self.music_dir]:
200 | directory.mkdir(parents=True, exist_ok=True)
201 |
202 | def _initialize_audio(self, audio_path: Path | str = None) -> AudioFileClip:
203 | """
204 | Initializes the audio by loading an audio file from the specified path or a default
205 | directory set in the configuration. If no audio path is provided, the method will
206 | attempt to find an audio file in the configured cache directory.
207 |
208 | Args:
209 | audio_path (Path | str, optional): The path to the audio file. If None, the method
210 | will use the `cache_dir` and `output_audio_file` settings from the
211 | configuration to locate the audio file.
212 |
213 | Returns:
214 | AudioFileClip: An instance of AudioFileClip initialized with the located audio file.
215 |
216 | Raises:
217 | ValueError: If no `audio_path` is provided and the `audio_config` in the
218 | configuration does not specify an `output_audio_file` key.
219 | """
220 | if audio_path is None:
221 | self.logger.info(
222 | f"No audio path provided. Using the audio directory {self.config.cache_dir} to find an audio file."
223 | )
224 | if "output_audio_file" not in self.config.audio_config:
225 | raise ValueError("Missing 'output_audio_file' in 'audio' section in setup.yml")
226 | audio_path = self.config.cache_dir / self.config.audio_config["output_audio_file"]
227 | self.logger.info(f"Using audio file {audio_path}")
228 | return AudioFileClip(audio_path)
229 |
230 | def _initialize_background_video(self, bg_video_path: Path | str = None) -> VideoFileClip:
231 | """
232 | Initializes the background video for further video processing.
233 |
234 | If a `bg_video_path` is not explicitly provided, the method will attempt to use the URLs specified
235 | in the video configuration to download a random background video. The downloaded video path will
236 | then be used to initialize the `VideoFileClip` instance. If the video configuration does not
237 | contain `background_videos_urls`, a ValueError is raised.
238 |
239 | Args:
240 | bg_video_path (Path | str, optional): The path to the background video file. If None, a
241 | video is chosen or downloaded based on the configuration.
242 |
243 | Raises:
244 | ValueError: If `background_videos_urls` is missing in the video configuration.
245 |
246 | Returns:
247 | VideoFileClip: The background video instance without audio.
248 | """
249 | if bg_video_path is None:
250 | self.logger.info(
251 | "No background video path provided. Using the background_videos_urls to download a background video."
252 | )
253 | if "background_videos_urls" not in self.config.video_config:
254 | raise ValueError("Missing 'background_videos_urls' in 'video' section in setup.yml")
255 | # choose a random url
256 | bg_video_url = random.choice(self.config.video_config["background_videos_urls"])
257 | self.logger.info(f"Using bg_video_url {bg_video_url}")
258 | bg_videos_path = download_youtube_video(
259 | bg_video_url,
260 | self.video_dir,
261 | )
262 | bg_video_path = random.choice(bg_videos_path).absolute()
263 | self.logger.info(f"Using bg_video_path {bg_video_path}")
264 | return VideoFileClip(bg_video_path, audio=False)
265 |
266 | @staticmethod
267 | def _select_random_color() -> tuple[int, int, int, int]:
268 | """
269 | Selects a random color from a pre-defined dictionary of colors.
270 |
271 | This method accesses a dictionary containing color codes, selects a random
272 | key from the dictionary, and retrieves the corresponding color value. It is
273 | used to dynamically choose colors for various operations requiring random
274 | color assignments.
275 |
276 | Returns:
277 | tuple[int, int, int, int]: A tuple representing the RGBA color values,
278 | where each value corresponds to red, green, blue, and alpha channels.
279 |
280 | Raises:
281 | KeyError: If the dictionary is empty or an invalid key is accessed.
282 | """
283 | return COLORS_DICT[random.choice(list(COLORS_DICT.keys()))]
284 |
285 |     def _load_transcript(self, transcript_path: Path | str | None) -> list[dict[str, Any]]:
286 | """
287 |         Loads and parses a transcript file in JSON format. The function determines the path of the
288 | transcript file, either based on the provided argument or through the configuration setup.
289 | It ensures the path validity and raises appropriate errors if the file cannot be found or
290 | if key configuration fields are missing. If the transcript file exists, it reads and
291 | deserializes its content into a Python list of dictionaries.
292 |
293 | Args:
294 | transcript_path (Path | str): Path to the transcript file. If None, this argument will
295 | default to a path derived from the configuration settings.
296 |
297 | Raises:
298 | ValueError: If `transcript_path` is None and configuration settings do not specify
299 | 'transcript_json' in the 'audio' section of `setup.yml`.
300 | ValueError: If the resolved `transcript_path` does not exist or the file cannot be found.
301 |
302 | Returns:
303 | list[dict[str, Any]]: The content of the transcript file parsed as a list of dictionaries.
304 | """
305 | if transcript_path is None:
306 | self.logger.info(
307 | f"No transcript path provided. Using the audio directory {self.config.cache_dir} to find a transcript file."
308 | )
309 | if "transcript_json" not in self.config.audio_config:
310 | raise ValueError("Missing 'transcript_json' in 'audio' section in setup.yml")
311 | transcript_path = self.config.cache_dir / self.config.audio_config["transcript_json"]
312 | self.logger.info(f"Using transcript file {transcript_path}")
313 | path = Path(transcript_path)
314 | if not path.exists():
315 | raise ValueError(f"Transcript file not found: {path}")
316 |
317 | with open(transcript_path) as audio_transcript_file:
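            # the transcript is JSON; JSON is a subset of YAML, so safe_load parses it fine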
318 | return yaml.safe_load(audio_transcript_file)
319 |
320 |     def _initialize_music(self, music_path: Path | str | None) -> AudioFileClip:
321 | """
322 | Initializes and loads a music file, handling cases where the music path is not provided
323 | by downloading music from a specified source.
324 |
325 | If a music path is not given, a random URL is selected from the list of background music
326 | URLs defined in the video configuration, and the corresponding music file is downloaded
327 | and set as the music path. Ensures the music file exists before returning it as an
328 | AudioFileClip object.
329 |
330 | Args:
331 | music_path (Path | str): The path to the music file to be loaded. If None, the method
332 | will attempt to download and use a file from a configured URL.
333 |
334 | Returns:
335 | AudioFileClip: An AudioFileClip object representing the loaded music file.
336 |
337 | Raises:
338 | ValueError: If the `background_music_urls` is missing in the video configuration when
339 | no `music_path` is provided, or if the specified music file does not exist.
340 | """
341 | if music_path is None:
342 | self.logger.info(
343 | "No music path provided. Using the background_music_urls to download a background music."
344 | )
345 | if "background_music_urls" not in self.config.video_config:
346 | raise ValueError("Missing 'background_music_urls' in 'video' section in setup.yml")
347 | # choose a random url
348 | music_url = random.choice(self.config.video_config["background_music_urls"])
349 | self.logger.info(f"Using music_url {music_url}")
350 | musics_path = download_youtube_music(
351 | music_url,
352 | self.music_dir,
353 | )
354 | music_path = random.choice(musics_path).absolute()
355 | self.logger.info(f"Using music_path {music_path}")
356 | path = Path(music_path)
357 | if not path.exists():
358 | raise ValueError(f"Music file not found: {path}")
359 |
360 | return AudioFileClip(path)
361 |
362 |     def _initialize_font(self, font_path: Path | str | None) -> str:
363 | """
364 | Initializes and selects a font file for use, either provided explicitly or chosen
365 | from a predefined directory of font files.
366 |
367 | If the input `font_path` is not provided, a random `.ttf` file is selected from
368 | the directory. If the directory does not contain any `.ttf` files, the function raises
369 | an exception. If `font_path` is supplied but points to a non-existent file, an exception
370 | is also raised.
371 |
372 | Args:
373 | font_path (Path | str): The path to the font file to be initialized. If None,
374 | a font is randomly selected from the predefined font directory.
375 |
376 | Raises:
377 | ValueError: Raised when no font files exist in the predefined font directory,
378 | or when the given `font_path` does not exist.
379 |
380 | Returns:
381 | str: The absolute path of the selected font file, ensuring any valid input font
382 | file or selected file from the directory is returned in string format.
383 | """
384 | if font_path is None:
385 | self.logger.info(
386 | f"No font path provided. Using the fonts directory {self.fonts_dir} to find a font."
387 | )
388 | font_files = list(self.fonts_dir.glob("*.ttf"))
389 | if not font_files:
390 | raise ValueError(f"No font files found in {self.fonts_dir}")
391 |             return str(random.choice(font_files).absolute())
392 | path = Path(font_path)
393 | if not path.exists():
394 | raise ValueError(f"Font file not found: {path}")
395 | return str(path)
396 |
397 |     def _initialize_credits(self, credits_path: Path | str | None) -> VideoFileClip:
398 | """
399 | Initializes the credits video by either using a provided path or searching a default credits directory.
400 | Raises an error if no valid credits file is found or if the provided path does not exist.
401 | Additionally, applies a mask to the credits video if a corresponding mask file exists.
402 |
403 |         Args:
404 |             credits_path (Path | str): Path to a directory containing `credits.mp4` and
405 |                 `credits_mask.mp4`. If None, the default credits directory is used.
406 |
407 | Returns:
408 | VideoFileClip: The VideoFileClip object of the credits video with the applied mask.
409 |
410 | Raises:
411 | FileNotFoundError: If no credits files are found in the default credits directory.
412 | ValueError: If the given credits_path does not exist.
413 | """
414 | if credits_path is None:
415 | self.logger.info(
416 | f"No credits path provided. Using the credits directory {self.credits_dir} to find a credits file."
417 | )
418 | credit_videos = list(self.credits_dir.glob("*.mp4"))
419 | if not credit_videos:
420 | raise FileNotFoundError("No credits files found in the credits directory")
421 | credits_path = self.credits_dir
422 | self.logger.info(f"Using credits file {credits_path}")
423 |
424 | path = Path(credits_path)
425 | if not path.exists():
426 | raise ValueError(f"Credits file not found: {path}")
427 |
428 | self.credit_video_mask: VideoFileClip = VideoFileClip(
429 | path / "credits_mask.mp4", audio=False
430 | ).to_mask()
431 |
432 | return VideoFileClip(path / "credits.mp4", audio=False).with_mask(self.credit_video_mask)
433 |
434 | def preprocess_audio_transcript(self) -> list[dict[str, Any]]:
435 | """
436 | Preprocesses the audio transcript to adjust segment boundaries.
437 |
438 | This method modifies the transcript by updating the "end" time of each audio
439 | segment to match the "start" time of the subsequent segment. This ensures that
440 | audio segments are properly aligned in cases where boundaries are not explicitly
441 | defined. The original transcript is updated in-place and returned.
442 |
443 | Returns:
444 | list[dict[str, Any]]: The updated list of audio transcript segments,
445 | each represented as a dictionary containing at least "start" and "end" keys.
446 |
447 | """
448 | for i in range(len(self.audio_transcript) - 1):
449 | self.audio_transcript[i]["end"] = self.audio_transcript[i + 1]["start"]
450 | return self.audio_transcript
451 |
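    # Example: [{"word": "Hello", "start": 0.0, "end": 0.4}, {"word": "world.", "start": 0.5, "end": 0.9}]
    # becomes [{"word": "Hello", "start": 0.0, "end": 0.5}, {"word": "world.", "start": 0.5, "end": 0.9}],
    # closing the 0.4-0.5 gap so consecutive caption segments meet with no dead time.
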
452 | def process_audio_transcript_to_word_and_sentences_transcript(
453 | self,
454 | ) -> tuple[list[dict[str, Any]], list[dict[str, Any]]]:
455 | """
456 | Processes audio transcript data into individual word-level and sentence-level
457 | transcripts. The method parses the `audio_transcript` to create two lists: a
458 | word transcript and a sentences transcript. Each word and sentence transcript
459 | contains details such as text, start time, and end time. Sentences are delimited
460 | based on specific punctuation marks defined in `PUNCTUATION_MARKS`. If an error
461 | occurs during word processing, it logs the error using the `logger` attribute.
462 |
463 |         Returns:
464 |             tuple[list[dict[str, Any]], list[dict[str, Any]]]: A tuple of two lists. The
465 |             first is the word transcript, where each item holds the caption text accumulated
466 |             up to and including that word, with the word's start and end times. The second
467 |             is the sentences transcript, where each item holds a full sentence with its
468 |             start and end times.
469 | """
470 | word_transcript = []
471 | sentences_transcript = []
472 |
473 | sentence_start = 0
474 | sentence_end = 0
475 |
476 | word_start = 0
477 | word_end = 0
478 |
479 | sentence = ""
480 |
481 | for count, transcript in enumerate(self.audio_transcript):
482 | word = transcript["word"].strip()
483 | sentence += word + " "
484 |
485 | word_end = transcript["end"]
486 | word_transcript.append(
487 | {
488 | "word": sentence,
489 | "start": word_start,
490 | "end": word_end,
491 | }
492 | )
493 | word_start = word_end
494 |
495 | try:
496 | if word[-1] in self.PUNCTUATION_MARKS and word != "...":
497 | sentence_end = word_end
498 | sentences_transcript.append(
499 | {
500 | "sentence": sentence,
501 | "start": sentence_start,
502 | "end": sentence_end,
503 | }
504 | )
505 | sentence_start = sentence_end
506 | sentence = ""
507 | except Exception as e:
508 | self.logger.error(f"Error processing word '{word}': {e}")
509 |
510 | # Add final sentence if any remains
511 | if sentence != "":
512 | sentences_transcript.append(
513 | {
514 | "sentence": sentence,
515 | "start": sentence_start,
516 | "end": sentence_end,
517 | }
518 | )
519 |
520 | return word_transcript, sentences_transcript
521 |
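    # Example: for words "Hi" (0.0-0.2) and "there." (0.2-0.5), word_transcript holds the
    # accumulating caption text [{"word": "Hi ", ...}, {"word": "Hi there. ", ...}] and
    # sentences_transcript holds [{"sentence": "Hi there. ", "start": 0, "end": 0.5}].
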
522 | def prepare_background_video(self) -> VideoFileClip:
523 | """
524 | Prepares and processes a background video clip for use in a composition.
525 |
526 | This function modifies the background video clip by randomly selecting a segment,
527 | cropping it to a specific aspect ratio, applying crossfade effects, and logging relevant
528 | information about the original and processed video dimensions.
529 |
530 |         Returns:
531 |             VideoFileClip: The processed video clip, ready for use in the final
532 |                 composition.
533 |
534 |         Raises:
535 |             AttributeError: If required attributes such as `self.bg_video`,
536 |                 `self.audio_clip`, or `self.logger` are not set on the instance.
537 |             ValueError: If the background video is too short to fit the audio clip
538 |                 plus the required 20-second margins, making the random start range
539 |                 invalid.
540 |         """
541 | width, height = self.bg_video.size
542 |
543 | self.logger.info(
544 | f"Original video - Width: {width}, Height: {height}, FPS: {self.bg_video.fps}"
545 | )
546 |
547 | # Select random segment of appropriate length
548 | random_start = random.uniform(20, self.bg_video.duration - self.audio_clip.duration - 20)
549 | random_end = random_start + self.audio_clip.duration
550 |
551 | self.logger.info(f"Using video segment from {random_start:.2f}s to {random_end:.2f}s")
552 |
553 | # Crop and apply effects
554 | self.bg_video: VideoFileClip = self.bg_video.subclipped(
555 | start_time=random_start - self.delay, end_time=random_end + self.delay
556 | )
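        # int(height * 9 / 16) & -2 rounds the 9:16 crop width down to an even number,
        # since many codecs reject odd frame dimensions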
557 | self.bg_video: VideoFileClip = self.bg_video.cropped(
558 | x_center=width / 2, width=int(height * 9 / 16) & -2
559 | )
560 | self.bg_video: VideoFileClip = self.bg_video.with_effects(
561 | [vfx.FadeIn(self.fade_time), vfx.FadeOut(self.fade_time)]
562 | )
563 |
564 | new_width, new_height = self.bg_video.size
565 | self.logger.info(
566 | f"Processed video - Width: {new_width}, Height: {new_height}, FPS: {self.bg_video.fps}"
567 | )
568 | self.logger.info(f"Video Duration: {self.bg_video.duration:.2f}s")
569 |
570 | return self.bg_video
571 |
572 | def create_text_clips(self) -> list[TextClip]:
573 | """
574 | Creates a list of text clips for use in video editing.
575 |
576 | This method generates text clips based on a word transcript where each word is associated with its start and end
577 | time. These text clips are created with specific stylistic properties such as font size, color, background color,
578 | alignment, and other visual attributes. The generated clips are then appended to the object's `text_clips` list
579 | and returned.
580 |
581 | Returns:
582 | list[TextClip]: A list of TextClip objects representing the visualized words with specified styling and timing.
583 | """
584 | for word in self.word_transcript:
585 | clip = (
586 | TextClip(
587 | font=self.font_path,
588 | text=word["word"],
589 | font_size=int(0.06 * self.bg_video.size[0]),
590 | size=(int(0.8 * self.bg_video.size[0]), int(0.8 * self.bg_video.size[0])),
591 | color=self.color,
592 | bg_color=(0, 0, 0, 100),
593 | text_align="center",
594 | method="caption",
595 | stroke_color="black",
596 | stroke_width=1,
597 | transparent=True,
598 | )
599 | .with_start(word["start"] + self.delay)
600 | .with_end(word["end"] + self.delay)
601 | .with_position(("center", "center"))
602 | )
603 |
604 | self.text_clips.append(clip)
605 |
606 | return self.text_clips
607 |
608 | def prepare_audio(self) -> CompositeAudioClip:
609 | """
610 | Processes and combines audio clips to prepare the final audio track.
611 |
612 | This method applies a series of effects to a music clip, including looping to match
613 | the duration of a background video, adjusting the volume, and adding fade-in and fade-out
614 | effects. Additionally, it modifies the start time of another audio clip to introduce
615 | a delay. Once processed, it combines the music clip and the audio clip into a single
616 | composite audio track.
617 |
618 | Returns:
619 | CompositeAudioClip: A combined audio clip that includes the processed music and
620 | audio components.
621 | """
622 | # Process music clip
623 | self.music_clip = self.music_clip.with_effects(
624 | [afx.AudioLoop(duration=self.bg_video.duration)]
625 | )
626 | self.music_clip = self.music_clip.with_effects([afx.MultiplyVolume(factor=0.05)])
627 | self.music_clip = self.music_clip.with_effects(
628 | [afx.AudioFadeIn(self.fade_time), afx.AudioFadeOut(self.fade_time)]
629 | )
630 |
631 | self.audio_clip = self.audio_clip.with_start(self.delay)
632 |
633 | # Combine audio clips
634 | return CompositeAudioClip([self.music_clip, self.audio_clip])
635 |
636 | def __call__(
637 | self,
638 | output_path: str = "output.mp4",
639 | codec: str = "mpeg4",
640 | preset: str = "medium",
641 | threads: int = 8,
642 | ) -> bool:
643 | """
644 | Executes the video processing pipeline by assembling all video and audio elements,
645 | creating a composite video, and writing the final output to a file. The method allows
646 | customization of output file properties such as codec, preset, and threading.
647 |
648 | Args:
649 | output_path (str): The path to save the output video file. Defaults to "output.mp4".
650 | codec (str): The video codec to use for encoding the output video. Defaults to "mpeg4".
651 | preset (str): The compression preset to optimize encoding speed vs quality. Defaults to "medium".
652 | threads (int): The number of threads to use during encoding. Defaults to 8.
653 |
654 | Returns:
655 | bool: Returns True upon successful creation and saving of the video.
656 | """
657 | # Create video clips
658 | self.video_clips: list[VideoFileClip | TextClip] = [self.bg_video]
659 | self.text_clips: list[TextClip] = []
660 | self.text_clips = self.create_text_clips()
661 | self.video_clips.extend(self.text_clips)
662 |
663 | # Add the credits clip to the end of the video
664 | if self.add_credits:
665 | self.credits_video: VideoClip = self.credits_video.resized(
666 | width=int(0.8 * self.bg_video.size[0])
667 | )
668 | self.video_clips.append(
669 | self.credits_video.with_start(
670 | self.bg_video.duration - self.credits_video.duration
671 | ).with_position(("center", "bottom"))
672 | )
673 |
674 | # Combine video clips
675 | output_video = CompositeVideoClip(self.video_clips)
676 |
677 | # Prepare final audio
678 | output_audio = self.prepare_audio()
679 |
680 | # Create final video
681 | final_video: CompositeVideoClip = output_video.with_audio(output_audio)
682 |
683 | # Write output file
684 | final_video.write_videofile(
685 | output_path,
686 | codec=codec,
687 | bitrate=f"{self.bg_video_bitrate}k",
688 | fps=self.bg_video.fps,
689 | audio_bitrate=f"{max(self.music_clip_bitrate, self.audio_clip_bitrate)}k",
690 | preset=preset,
691 | threads=threads,
692 | )
693 | self.logger.info(f"Video successfully created at {output_path}")
694 |
695 | if self.speed_factor != 1:
696 | final_video.close()
697 |             speedup_output_path = f"{Path(output_path).with_suffix('')}_speed.mp4"
698 | self.logger.info(f"Speeding up video to {self.speed_factor}x")
699 | self.logger.info(f"Speed up video at {speedup_output_path}")
700 | import subprocess
701 |
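            # ffmpeg's atempo filter historically accepts factors in [0.5, 2.0] per stage
            # (newer builds allow more); chain several atempo stages for larger speed-ups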
702 | result = subprocess.run(
703 | [
704 | "ffmpeg",
705 | "-i",
706 | f"{output_path}",
707 | "-filter_complex",
708 | f"[0:v]setpts=(1/{self.speed_factor})*PTS[v];[0:a]atempo={self.speed_factor}[a]",
709 | "-map",
710 | "[v]",
711 | "-map",
712 | "[a]",
713 | f"{speedup_output_path}",
714 | "-y",
715 | ],
716 | capture_output=True,
717 | text=True,
718 | )
719 | if result.returncode == 0:
720 | self.logger.info("Speed adjustment completed successfully.")
721 | else:
722 | self.logger.error(f"Speed adjustment failed: {result.stderr}")
723 |
724 | output_video.close()
725 | output_audio.close()
726 | final_video.close()
727 |
728 | return True
729 |
730 | def quit(self) -> None:
731 | """
732 | Closes and cleans up resources used by the instance.
733 |
734 | This method ensures that all open files, clips, or other resources associated
735 | with the instance are properly closed and the corresponding instance variables
736 | are deleted. It handles exceptions gracefully, logging any issues encountered
737 | during the cleanup process.
738 |
739 | Raises:
740 | Exception: If an error occurs while closing a resource or deleting
741 | an attribute, it will log the error details.
742 | """
743 | try:
744 | # Close any open files or clips
745 | if hasattr(self, "audio_clip") and self.audio_clip:
746 | self.audio_clip.close()
747 | if hasattr(self, "bg_video") and self.bg_video:
748 | self.bg_video.close()
749 | if hasattr(self, "music_clip") and self.music_clip:
750 | self.music_clip.close()
751 | if hasattr(self, "credits_video") and self.credits_video:
752 | self.credits_video.close()
753 | if hasattr(self, "credit_video_mask") and self.credit_video_mask:
754 | self.credit_video_mask.close()
755 | except Exception as e:
756 | self.logger.error(f"Error closing resources: {e}")
757 |
758 | # Delete all instance variables
759 | for attr in list(self.__dict__.keys()):
760 | try:
761 | self.logger.debug(f"Deleting {attr}")
762 | if attr == "logger":
763 | continue
764 | delattr(self, attr)
765 | except Exception as e:
766 | self.logger.error(f"Error deleting {attr}: {e}")
767 |
768 | self.logger.debug("Resources successfully cleaned up.")
769 | return
770 |
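# Usage sketch (hedged): assumes this module's class is exported as `MoviepyCreateVideo`
# and that setup.yml provides the `cache_dir`, `audio`, and `video` sections used above;
# the paths below are illustrative.
#
# video_maker = MoviepyCreateVideo(config_file="setup.yml")
# video_maker(output_path="final_video.mp4")
# video_maker.quit()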
--------------------------------------------------------------------------------
/ShortsMaker/shorts_maker.py:
--------------------------------------------------------------------------------
1 | import json
2 | import secrets
3 | from collections.abc import Generator
4 | from pathlib import Path
5 | from pprint import pformat
6 | from typing import Any
7 |
8 | import ftfy
9 | import language_tool_python
10 | import praw
11 | import yaml
12 | from praw.models import Submission, Subreddit
13 | from unidecode import unidecode
14 |
15 | from .utils import VOICES, generate_audio_transcription, get_logger, retry, tts
16 |
17 | # needed for retry decorator
18 | MAX_RETRIES: int = 1
19 | DELAY: int = 0
20 | NOTIFY: bool = False
21 |
22 | # Constants
23 | PUNCTUATIONS = [".", ";", ":", "!", "?", '"']
24 | ESCAPE_CHARACTERS = ["\n", "\t", "\r", " "]
25 | # abbreviation, replacement, padding
26 | ABBREVIATION_TUPLES = [
27 | ("\n", " ", ""),
28 | ("\t", " ", ""),
29 | ("\r", " ", ""),
30 | ("(", "", ""),
31 | (")", "", ""),
32 | ("AITA", "Am I the asshole ", " "),
33 | ("WIBTA", "Would I be the asshole ", " "),
34 | ("NTA", "Not the asshole ", " "),
35 | ("YTA", "You're the Asshole", ""),
36 | ("YWBTA", "You Would Be the Asshole", ""),
37 | ("YWNBTA", "You Would Not be the Asshole", ""),
38 | ("ESH", "Everyone Sucks here", ""),
39 | ("NAH", "No Assholes here", ""),
40 | ("INFO", "Not Enough Info", ""),
41 | ("FIL", "father in law ", " "),
42 | ("BIL", "brother in law ", " "),
43 | ("MIL", "mother in law ", " "),
44 | ("SIL", "sister in law ", " "),
45 | (" BF ", " boyfriend ", ""),
46 | (" GF ", " girlfriend ", ""),
47 | (" bf ", " boyfriend ", ""),
48 | (" gf ", " girlfriend ", ""),
49 | (" ", " ", ""),
50 | ]
51 |
52 |
53 | def abbreviation_replacer(text: str, abbreviation: str, replacement: str, padding: str = "") -> str:
54 | """
55 | Replaces all occurrences of an abbreviation within a given text with a specified replacement.
56 |
57 | This function allows replacing abbreviations with a replacement string while considering optional padding
58 | around the abbreviation. Padding ensures the abbreviation is correctly replaced regardless of its position
59 | in the text or the surrounding characters.
60 |
61 | Args:
62 | text (str): The text where the abbreviation occurrences should be replaced.
63 | abbreviation (str): The abbreviation to be replaced in the text.
64 | replacement (str): The string to replace the abbreviation with.
65 | padding (str, optional): Additional characters surrounding the abbreviation, making the
66 | replacement match more specific. Default is an empty string.
67 |
68 | Returns:
69 | str: The text with all occurrences of the abbreviation replaced by the replacement string.
70 | """
71 | text = text.replace(abbreviation + padding, replacement)
72 | text = text.replace(padding + abbreviation, replacement)
73 | return text
74 |
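# Example: abbreviation_replacer("AITA for leaving?", "AITA", "Am I the asshole ", " ")
# returns "Am I the asshole for leaving?"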
75 |
76 | def has_alpha_and_digit(word: str) -> bool:
77 | """
78 | Determines if a string contains both alphabetic and numeric characters.
79 |
80 | This function checks whether the given string contains at least one alphabetic
81 | character and at least one numeric character. It utilizes Python's string methods
82 | to identify the required character types.
83 |
84 | Args:
85 | word: The string to check for the presence of alphabetic and numeric
86 | characters.
87 |
88 | Returns:
89 | bool: True if the string contains at least one alphabetic character and one
90 | numeric character, otherwise False.
91 | """
92 | return any(character.isalpha() for character in word) and any(
93 | character.isdigit() for character in word
94 | )
95 |
96 |
97 | def split_alpha_and_digit(word):
98 | """
99 | Splits a given string into separate segments of alphabetic and numeric sequences.
100 |
101 | This function processes each character in the input string and divides it into
102 | distinct groups of alphabetic sequences and numeric sequences. A space is added
103 | between these groups whenever a transition occurs between alphabetic and numeric
104 | characters, or vice versa. Non-alphanumeric characters are included as is without
105 | causing a split.
106 |
107 | Args:
108 | word (str): The input string to be split into alphabetic and numeric
109 | segments.
110 |
111 | Returns:
112 | str: A string where alphabetic and numeric segments from the input are
113 | separated by a space while retaining other characters.
114 | """
115 | res = ""
116 | alpha = False
117 | digit = False
118 | for character in word:
119 | if character.isalpha():
120 | alpha = True
121 | if digit:
122 | res += " "
123 | digit = False
124 | res += character
125 | elif character.isdigit():
126 | digit = True
127 | if alpha:
128 | res += " "
129 | alpha = False
130 | res += character
131 | else:
132 | res += character
133 | return res
134 |
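# Example: split_alpha_and_digit("30M") -> "30 M", split_alpha_and_digit("abc123def") -> "abc 123 def"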
135 |
136 | class ShortsMaker:
137 | """
138 | Represents a utility class to facilitate the creation of video shorts from
139 | text and audio assets. The class manages configuration, logging, processing
140 | of text for audio generation, reddit post retrieval, and asset handling among
141 | other operations.
142 |
143 | The `ShortsMaker` class is designed to be highly extensible and configurable
144 | via YAML configuration files. It includes robust error handling for invalid
145 | or missing configuration files and directories. The functionality also integrates
146 | with external tools such as Reddit API and grammar correction tools to streamline
147 | the process of creating video shorts.
148 |
149 | Attributes:
150 | VALID_CONFIG_EXTENSION (str): Expected file extension for configuration files.
151 | setup_cfg (Path): Path of the validated configuration file.
152 | cfg (dict): Parsed configuration from the loaded configuration file.
153 | cache_dir (Path): Directory for storing temporary files and intermediate data.
154 | logging_cfg (dict): Configuration for setting up logging.
155 | logger (Logger): Logger instance used for logging events and errors.
156 | retry_cfg (dict): Configuration parameters for retry logic, including maximum
157 | retries and delay between retries.
158 |         word_transcript (list | None): Word-level transcript entries with timestamps.
159 | line_transcript (str | None): Transcript represented as individual lines.
160 | transcript (str | None): Full transcript derived from the Reddit post or input text.
161 | audio_cfg (dict | None): Configuration details specific to audio processing.
162 | reddit_post (dict | None): Details related to the Reddit post being processed.
163 | reddit_cfg (dict | None): Configuration details specific to Reddit API integration.
164 | """
165 |
166 | VALID_CONFIG_EXTENSION = ".yml"
167 |
168 | def __init__(self, config_file: Path | str) -> None:
169 | self.setup_cfg = self._validate_config_path(config_file)
170 | self.cfg = self._load_config()
171 | self.logger = get_logger(__name__)
172 | self.cache_dir = self._setup_cache_directory()
173 | self.retry_cfg = self._setup_retry_config()
174 |
175 | # Initialize other instance variables
176 |         self.word_transcript: list[dict[str, str | float]] | None = None
177 | self.line_transcript: str | None = None
178 | self.transcript: str | None = None
179 | self.audio_cfg: dict | None = None
180 | self.reddit_post: dict | None = None
181 | self.reddit_cfg: dict | None = None
182 |
183 | def _validate_config_path(self, config_file: Path | str) -> Path:
184 | """
185 | Validates the given configuration file path to ensure it exists and has the correct format.
186 |
187 | This method checks whether the provided file path points to an actual file and
188 | whether its extension matches the expected configuration file format. If any
189 | of these conditions are not met, appropriate exceptions are raised.
190 |
191 | Args:
192 | config_file: A file path string or a Path object representing the
193 | configuration file to be validated.
194 |
195 | Returns:
196 | The validated configuration file path as a Path object.
197 |
198 | Raises:
199 | FileNotFoundError: If the configuration file does not exist.
200 | ValueError: If the configuration file format is invalid.
201 | """
202 | config_path = Path(config_file) if isinstance(config_file, str) else config_file
203 | if not config_path.exists():
204 | raise FileNotFoundError(f"Configuration file not found: {config_path}")
205 | if config_path.suffix != self.VALID_CONFIG_EXTENSION:
206 | raise ValueError(
207 | f"Invalid configuration file format. Expected {self.VALID_CONFIG_EXTENSION}"
208 | )
209 | return config_path
210 |
211 | def _load_config(self) -> dict[str, Any]:
212 | """
213 | Loads and parses configuration data from a YAML file.
214 |
215 | This method attempts to open and parse a YAML configuration file specified
216 | by the `setup_cfg` attribute of the class. If the file does not contain valid
217 | YAML or cannot be read, it raises an exception with an appropriate error message.
218 |
219 | Returns:
220 | Dict[str, Any]: A dictionary representation of the loaded YAML configuration.
221 |
222 | Raises:
223 | ValueError: If the YAML file contains invalid content or cannot be parsed.
224 | """
225 | try:
226 | with open(self.setup_cfg) as f:
227 | return yaml.safe_load(f)
228 | except yaml.YAMLError as e:
229 | raise ValueError(f"Invalid YAML configuration: {e}")
230 |
231 | def _setup_cache_directory(self) -> Path:
232 | """
233 | Sets up the cache directory based on the configuration and ensures its existence.
234 |
235 | This method retrieves the cache directory path from the configuration, creates the
236 | directory (including any required parent directories), and returns the Path object
237 | representing the cache directory. If the directory already exists, it will not attempt
238 | to create it again.
239 |
240 | Returns:
241 | Path: A Path object representing the cache directory.
242 | """
243 | if "cache_dir" not in self.cfg:
244 | self.logger.info("Cache directory not specified, creating it.")
245 | self.cfg["cache_dir"] = Path.cwd()
246 | cache_dir = Path(self.cfg["cache_dir"])
247 | cache_dir.mkdir(parents=True, exist_ok=True)
248 | return cache_dir
249 |
250 | def _setup_retry_config(self) -> dict[str, Any]:
251 | """
252 | Configures the retry mechanism based on the settings provided in the
253 | configuration. If retry is disabled, it updates the retry settings to default
254 | values. Updates global constants to reflect the current retry configuration.
255 |
256 | Returns:
257 | Dict[str, Any]: The updated retry configuration dictionary.
258 | """
259 |         retry_config = {"max_retries": 1, "delay": 0, "notify": False}
260 |         if "retry" in self.cfg:
261 |             retry_config.update(self.cfg["retry"])
264 |
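        # Note: the @retry decorators in this module were evaluated at import time with
        # the defaults above, so updating these globals only has an effect if `retry`
        # re-reads its parameters at call time (see utils/retry.py).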
265 | global MAX_RETRIES, DELAY, NOTIFY
266 | MAX_RETRIES = retry_config["max_retries"]
267 | DELAY = retry_config["delay"]
268 | NOTIFY = retry_config["notify"]
269 |
270 | return retry_config
271 |
272 | def get_submission_from_subreddit(
273 | self, reddit: praw.Reddit, subreddit_name: str
274 | ) -> Generator[Submission]:
275 | """
276 | Retrieves a unique Reddit submission from a specified subreddit.
277 |
278 | Args:
279 | reddit (praw.Reddit): An instance of the Reddit API client.
280 | subreddit_name (str): The name of the subreddit to fetch submissions from.
282 |
283 |         Yields:
284 |             Submission: Submissions from the subreddit's hot listing.
285 | """
286 | subreddit: Subreddit = reddit.subreddit(subreddit_name)
287 | self.logger.info(f"Subreddit title: {subreddit.title}")
288 | self.logger.info(f"Subreddit display name: {subreddit.display_name}")
289 | yield from subreddit.hot()
290 |
291 | def is_unique_submission(self, submission: Submission) -> bool:
292 | """
293 | Checks if the given Reddit submission is unique based on its ID.
294 |
295 | Args:
296 | submission (Submission): The Reddit submission to check.
297 |
298 | Returns:
299 | bool: True if the submission is unique, False otherwise.
300 | """
301 | submission_dirs = self.cache_dir / "reddit_submissions"
302 | submission_dirs.mkdir(parents=True, exist_ok=True)
303 | self.logger.debug("Checking if submission is unique")
304 | self.logger.debug(f"Submission ID: {submission.id}")
305 | if any(f"{submission.name}.json" == file.name for file in submission_dirs.iterdir()):
306 | self.logger.info(f"Submission {submission.name} - '{submission.title}' already exists")
307 | return False
308 | else:
309 | with open(submission_dirs / f"{submission.name}.json", "w") as record_file:
310 | # Object of type Reddit is not JSON serializable, hence need to use vars
311 | json.dump(
312 | {key: str(value) for key, value in vars(submission).items()},
313 | record_file,
314 | indent=4,
315 | skipkeys=True,
316 | sort_keys=True,
317 | )
318 | self.logger.debug("Unique submission found")
319 | self.logger.info(f"Submission saved to {submission_dirs / f'{submission.name}.json'}")
320 | return True
321 |
322 | @retry(max_retries=MAX_RETRIES, delay=DELAY, notify=NOTIFY)
323 | def get_reddit_post(self, url: str | None = None) -> str:
324 | """
325 |         Retrieves a Reddit post, either from a given URL or as the first not-yet-cached hot post
326 |         from the configured subreddit, saves it to a JSON and a text file, and returns its text.
327 |
328 | Args:
329 | url (str | None): The URL of the Reddit post to retrieve. If None, a random top post is retrieved.
330 |
331 | Returns:
332 | str: The text content of the retrieved Reddit post.
333 |
334 | Raises:
335 | ValueError: If any value processing errors occur.
336 | IOError: If file handling (reading/writing) fails.
337 | praw.exceptions.PRAWException: If PRAW encounters an API or authentication issue.
338 | """
339 | self.reddit_cfg = self.cfg["reddit_praw"]
340 | self.reddit_post = self.cfg["reddit_post_getter"]
341 |
342 | self.logger.info("Getting Reddit post")
343 | reddit: praw.Reddit = praw.Reddit(
344 | client_id=self.reddit_cfg["client_id"],
345 | client_secret=self.reddit_cfg["client_secret"],
346 | user_agent=self.reddit_cfg["user_agent"],
347 | # username=self.reddit_cfg["username"],
348 | # password=self.reddit_cfg["password"]
349 | )
350 | self.logger.info(f"Is reddit readonly: {reddit.read_only}")
351 |
352 |         if url:
353 |             submission = reddit.submission(url=url)
354 |         else:
355 |             submissions = self.get_submission_from_subreddit(
356 |                 reddit, self.reddit_post["subreddit_name"]
357 |             )
358 |             submission = next((s for s in submissions if self.is_unique_submission(s)), None)
359 |             if submission is None:
360 |                 raise ValueError("No unique submission found; every hot post has already been used")
361 |
362 | self.logger.info(f"Submission Url: {submission.url}")
363 | self.logger.info(f"Submission title: {submission.title}")
364 |
365 | data = dict()
366 | for key, value in vars(submission).items():
367 | data[key] = str(value)
368 |
369 | # Save the submission to a json file
370 | with open(self.cache_dir / self.reddit_post["record_file_json"], "w") as record_file:
371 | # noinspection PyTypeChecker
372 | json.dump(data, record_file, indent=4, skipkeys=True, sort_keys=True)
373 | self.logger.info(
374 | f"Submission saved to {self.cache_dir / self.reddit_post['record_file_json']}"
375 | )
376 |
377 | # Save the submission to a text file
378 | with open(self.cache_dir / self.reddit_post["record_file_txt"], "w") as text_file:
379 | text_file.write(unidecode(ftfy.fix_text(submission.title)) + "." + "\n")
380 | text_file.write(unidecode(ftfy.fix_text(submission.selftext)) + "\n")
381 | self.logger.info(
382 | f"Submission text saved to {self.cache_dir / self.reddit_post['record_file_txt']}"
383 | )
384 |
385 | # return the generated file contents
386 | with open(self.cache_dir / self.reddit_post["record_file_txt"]) as result_file:
387 | result_string = result_file.read()
388 | return result_string
389 |
390 | @retry(max_retries=MAX_RETRIES, delay=DELAY, notify=NOTIFY)
391 | def fix_text(self, source_txt: str, debug: bool = True) -> str:
392 | """
393 | Fixes and corrects grammatical and textual issues in the provided text input using language processing tools.
394 | The method processes the input text by fixing encoding issues, normalizing it, splitting it into sentences,
395 | and then correcting the grammar of each individual sentence. An optional debug mode saves the processed text
396 | to a debug file for inspection.
397 |
398 | Args:
399 | source_txt: The text to be processed and corrected.
400 | debug: If True, saves the corrected text to a debug file for further analysis.
401 |
402 | Returns:
403 | str: The corrected and formatted text.
404 |
405 | Raises:
406 | Exception: Raised if errors occur during text correction within individual sentences.
407 | """
408 | self.logger.info("Setting up language tool text fixer")
409 |         grammar_fixer = self.grammar_fixer = language_tool_python.LanguageTool("en-US")  # kept on self so quit() can close it
410 |
411 | source_txt = ftfy.fix_text(source_txt)
412 | source_txt = unidecode(source_txt)
413 | for escape_char in ESCAPE_CHARACTERS:
414 | source_txt = source_txt.replace(escape_char, " ")
415 |
416 | sentences = []
417 | res = []
418 |
419 | for word in source_txt.split(" "):
420 | if word == "":
421 | continue
422 | if word[0] in PUNCTUATIONS:
423 | sentences.append(" ".join(res))
424 | res = []
425 | res.append(word)
426 | if word[-1] in PUNCTUATIONS:
427 | sentences.append(" ".join(res))
428 | res = []
429 |
430 | self.logger.info(
431 | f"Split text into sentences and fixed text. Found {len(sentences)} sentences"
432 | )
433 |
434 | corrected_sentences = []
435 | for sentence in sentences:
436 | try:
437 | corrected_sentences.append(grammar_fixer.correct(sentence))
438 | except Exception as e:
439 | self.logger.error(f"Error: {e}")
440 | corrected_sentences.append(sentence)
441 |
442 | grammar_fixer.close()
443 | result_string = " ".join(corrected_sentences)
444 |
445 | if debug:
446 | with open(self.cache_dir / "fix_text_debug.txt", "w") as text_file:
447 | text_file.write(result_string)
448 | self.logger.info(f"Debug text saved to {self.cache_dir / 'fix_text_debug.txt'}")
449 |
450 | return result_string
451 |
452 | @retry(max_retries=MAX_RETRIES, delay=DELAY, notify=NOTIFY)
453 | def generate_audio(
454 | self,
455 | source_txt: str,
456 | output_audio: str | None = None,
457 | output_script_file: str | None = None,
458 | seed: str | None = None,
459 | ) -> bool:
460 | """
461 | Generates audio from a given textual input. The function processes the input text,
462 | performs text transformations (e.g., replacing abbreviations and splitting alphanumeric
463 | combinations), and uses a synthesized voice to create an audio file. It also writes the
464 | processed script to a text file. Speaker selection is either randomized or based on the
465 | provided seed.
466 |
467 | Args:
468 | source_txt (str): The input text to be converted into audio.
469 | output_audio (str | None): The path to save the generated audio. If not provided, a
470 | default path is generated based on the configuration or cache directory.
471 | output_script_file (str | None): The file path to save the processed text script. If
472 | not provided, a default path is generated based on the configuration or cache
473 | directory.
474 | seed (str | None): An optional seed to determine the choice of speaker. If not
475 | provided, the function randomly selects a speaker. Refer to VOICES for available
476 | speakers.
477 |
478 | Returns:
479 | bool: Returns True if the audio generation is successful; False otherwise.
480 |
481 | Raises:
482 | Exception: If an error occurs during text-to-speech processing.
483 | """
484 | self.audio_cfg = self.cfg["audio"]
485 | if output_audio is None:
486 | self.logger.info("No output audio file specified. Generating output audio file")
487 | if "output_audio_file" in self.audio_cfg:
488 | output_audio = self.cache_dir / self.audio_cfg["output_audio_file"]
489 | else:
490 | output_audio = self.cache_dir / "output.wav"
491 |
492 | if output_script_file is None:
493 | self.logger.info("No output script file specified. Generating output script file")
494 | if "output_script_file" in self.audio_cfg:
495 | output_script_file = self.cache_dir / self.audio_cfg["output_script_file"]
496 | else:
497 | output_script_file = self.cache_dir / "generated_audio_script.txt"
498 |
499 | self.logger.info("Generating audio from text")
500 | for abbreviation, replacement, padding in ABBREVIATION_TUPLES:
501 | source_txt = abbreviation_replacer(source_txt, abbreviation, replacement, padding)
502 | source_txt = source_txt.strip()
503 |
504 | for s in source_txt.split(" "):
505 | if has_alpha_and_digit(s):
506 | source_txt = source_txt.replace(s, split_alpha_and_digit(s))
507 |
508 | with open(output_script_file, "w") as text_file:
509 | text_file.write(source_txt)
510 | self.logger.info(f"Text saved to {output_script_file}")
511 |
512 | if seed is None:
513 | speaker = secrets.choice(VOICES)
514 | else:
515 | speaker = seed
516 |
517 | self.logger.info(f"Generating audio with speaker: {speaker}")
518 |
519 | try:
520 | tts(source_txt, speaker, output_audio)
521 | self.logger.info(
522 | f"Successfully generated audio.\nSpeaker: {speaker}\nOutput path: {output_audio}"
523 | )
524 | except Exception as e:
525 | self.logger.error(f"Error: {e}")
526 | self.logger.error("Failed to generate audio with tiktokvoice")
527 | return False
528 |
529 | return True
530 |
531 | @retry(max_retries=MAX_RETRIES, delay=DELAY, notify=NOTIFY)
532 | def generate_audio_transcript(
533 | self,
534 | source_audio_file: Path | str,
535 | source_text_file: Path | str,
536 | output_transcript_file: str | None = None,
537 | debug: bool = True,
538 | ) -> list[dict[str, str | float]]:
539 | """
540 | Generates an audio transcript by processing a source audio file and its corresponding text
541 | file, using predefined configurations such as model, device, and batch size. Saves the
542 | resulting transcript into a specified output file or a default cache location. Additionally,
543 | provides an option to enable debug logging.
544 |
545 | Args:
546 | source_audio_file (Path): The source audio file to be transcribed.
547 | source_text_file (Path): The text file containing the corresponding script.
548 | output_transcript_file (str | None): The file where the resulting transcript will be saved.
549 | Defaults to a predefined location if not specified.
550 | debug (bool): Whether to enable debug logging of the processed transcript.
551 |
552 | Returns:
553 | list[dict[str, str | float]]: A list of word-level transcription data, where each entry
554 | contains word-related information such as timestamps and confidence scores.
555 | """
556 | self.audio_cfg = self.cfg["audio"]
557 | self.logger.info("Generating audio transcript")
558 |
559 | # read the script
560 | with open(source_text_file) as text_file:
561 | source_text = text_file.read()
562 |
563 | self.word_transcript = generate_audio_transcription(
564 | audio_file=str(source_audio_file),
565 | script=source_text,
566 | device=self.audio_cfg["device"],
567 | model=self.audio_cfg["model"],
568 | batch_size=self.audio_cfg["batch_size"],
569 | compute_type=self.audio_cfg["compute_type"],
570 | )
571 | self.word_transcript = self._filter_word_transcript(self.word_transcript)
572 |
573 | if output_transcript_file is None:
574 | output_transcript_file = self.cache_dir / self.audio_cfg["transcript_json"]
575 |
576 | self.logger.info(f"Saving transcript to {output_transcript_file}")
577 |
578 | with open(output_transcript_file, "w") as transcript_file:
579 | # noinspection PyTypeChecker
580 | json.dump(
581 | self.word_transcript,
582 | transcript_file,
583 | indent=4,
584 | skipkeys=True,
585 | sort_keys=True,
586 | )
587 |
588 | if debug:
589 | self.logger.info(pformat(self.word_transcript))
590 |
591 | return self.word_transcript
592 |
593 | def _filter_word_transcript(
594 | self, transcript: list[dict[str, str | float]]
595 | ) -> list[dict[str, str | float]]:
596 |         # drop entries that start at 0 s or run 5 s or longer (likely alignment artifacts)
597 | return [
598 | entry
599 | for entry in transcript
600 | if entry["start"] > 0 and (entry["end"] - entry["start"]) < 5
601 | ]
602 |
603 | def quit(self) -> None:
604 | """
605 | Closes and cleans up resources used in the class instance.
606 |
607 | This method ensures that all resources, tools, and variables used within the
608 | class instance are properly closed or removed to prevent memory leaks or issues
609 | when the instance is no longer in use. It includes closing language tools, if
610 | utilized, and deleting all instance variables except the logger.
611 |
612 | Raises:
613 | Exception: If there is an issue closing the grammar fixer or deleting
614 | instance variables. Specific details are logged.
615 |
616 | Returns:
617 | None: This method does not return any value.
618 | """
619 | self.logger.debug("Closing and cleaning up resources.")
620 | # Close the language tool if it was used
621 | if hasattr(self, "grammar_fixer") and self.grammar_fixer:
622 | try:
623 | self.grammar_fixer.close()
624 | except Exception as e:
625 | self.logger.error(f"Error closing grammar fixer: {e}")
626 |
627 | # Delete all instance variables
628 | for attr in list(self.__dict__.keys()):
629 | try:
630 | self.logger.debug(f"Deleting {attr}")
631 | if attr == "logger":
632 | continue
633 | delattr(self, attr)
634 | except Exception as e:
635 | self.logger.warning(f"Error deleting {attr}: {e}")
636 |
637 | self.logger.debug("All objects in the class have been deleted.")
638 | return
639 |
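# Usage sketch: the typical flow through ShortsMaker (method names come from this file;
# the literal paths are illustrative and normally come from setup.yml's `audio` section).
#
# shorts_maker = ShortsMaker("setup.yml")
# post_text = shorts_maker.get_reddit_post()
# script = shorts_maker.fix_text(post_text)
# shorts_maker.generate_audio(script, output_audio="cache/output.wav",
#                             output_script_file="cache/script.txt")
# shorts_maker.generate_audio_transcript("cache/output.wav", "cache/script.txt")
# shorts_maker.quit()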
--------------------------------------------------------------------------------
/ShortsMaker/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .audio_transcript import align_transcript_with_script, generate_audio_transcription
2 | from .colors_dict import COLORS_DICT
3 | from .download_youtube_music import download_youtube_music, sanitize_filename
4 | from .download_youtube_video import download_youtube_video
5 | from .get_tts import VOICES, tts
6 | from .logging_config import configure_logging, get_logger
7 | from .notify_discord import notify_discord
8 | from .retry import retry
9 |
10 | __all__ = [
11 |     "COLORS_DICT",
12 |     "VOICES",
13 |     "align_transcript_with_script",
14 |     "configure_logging",
15 |     "download_youtube_music",
16 |     "download_youtube_video",
17 |     "generate_audio_transcription",
18 |     "get_logger",
19 |     "notify_discord",
20 |     "retry",
21 |     "sanitize_filename",
22 |     "tts",
23 | ]
24 |
25 | # Configure logging for the package with its default settings
26 | configure_logging()
27 |
--------------------------------------------------------------------------------
/ShortsMaker/utils/audio_transcript.py:
--------------------------------------------------------------------------------
1 | import gc
2 | from pprint import pformat
3 |
4 | import torch
5 | import whisperx
6 | from rapidfuzz import process
7 |
8 | from .logging_config import get_logger
9 |
10 | logger = get_logger(__name__)
11 |
12 |
13 | def align_transcript_with_script(transcript: list[dict], script_string: str) -> list[dict]:
14 | """
15 | Aligns the transcript entries with corresponding segments of a script string by
16 | comparing text similarities and finding the best matches. This process adjusts
17 | each transcript entry to align as closely as possible with the correct script
18 | segment while maintaining the temporal information from the transcript.
19 |
20 | Args:
21 | transcript (list[dict]): A list of dictionaries, each containing a segment
22 | of the transcript with keys "text" (text of the segment), "start"
23 | (start time), and "end" (end time).
24 | script_string (str): The entire script as a single string to which the
25 | transcript is aligned.
26 |
27 | Returns:
28 | list[dict]: A list containing the transcript with updated "text" fields
29 | that are aligned to the most similar segments of the script. The
30 | "start" and "end" fields remain unchanged from the input.
31 |
32 | Raises:
33 |         ValueError: If either the transcript or script_string is empty, in which
34 |             case alignment cannot be performed.
35 |
36 | """
37 |     if not transcript or not script_string:
38 |         raise ValueError("Both transcript and script_string must be non-empty")
39 |     temp_transcript = []
40 |     window_sizes = list(range(6))
39 | script_words = script_string.split()
40 |
41 | for entry in transcript:
42 | possible_windows = []
43 | length_of_entry_text = len(entry["text"].split())
44 |
45 |         # Generate candidate script windows slightly longer and shorter than the entry text
46 |         for window_size in window_sizes:
47 |             possible_windows.append(" ".join(script_words[: length_of_entry_text + window_size]))
48 |             if length_of_entry_text - window_size > 0:
49 |                 possible_windows.append(" ".join(script_words[: length_of_entry_text - window_size]))
49 |
50 |         # Find the best fuzzy match among all candidate windows
51 |         best_match, score, _ = process.extractOne(entry["text"], possible_windows)
52 |         logger.debug(f"Entry text: {entry['text']} | Best match: {best_match} (score: {score})")
56 |
57 | if best_match:
58 | script_words = script_words[len(best_match.split()) :]
59 |
60 |         logger.debug(f"Script words remaining: {len(script_words)}")
65 |
66 | # Add the match or original text to the new transcript
67 | temp_transcript.append(
68 | {
69 | "text": best_match if best_match else entry["text"],
70 | "start": entry["start"],
71 | "end": entry["end"],
72 | }
73 | )
74 | return temp_transcript
75 |
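# Example: an entry {"text": "hello wrld", "start": 0.0, "end": 1.0} aligned against the
# script "hello world today was fun" is emitted as {"text": "hello world", "start": 0.0,
# "end": 1.0}, and "hello world" is consumed from the front of the remaining script words.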
76 |
77 | def generate_audio_transcription(
78 | audio_file: str,
79 | script: str,
80 | device="cuda",
81 | batch_size=16,
82 | compute_type="float16",
83 | model="large-v2",
84 | ) -> list[dict[str, str | float]]:
85 | """
86 | Generates a transcription of an audio file by performing speech-to-text transcription and aligning the
87 | transcription with a given script. It utilizes whisper models for transcription and alignment to improve
88 | accuracy.
89 |
90 | This function processes the audio in batches, aligns the transcriptions with the provided script for better
91 | accuracy, and cleans GPU memory usage during its workflow. It outputs a list of word-level transcriptions
92 | with start and end times for enhanced downstream processing.
93 |
94 | Args:
95 | audio_file (str): The path to the audio file that needs to be transcribed.
96 | script (str): The text script used for alignment with the transcribed segments.
97 | device (str): The device to be used for computation, default is 'cuda'.
98 | batch_size (int): The batch size to use during transcription, default is 16.
99 | compute_type (str): The precision type to be used for the model, default is "float16".
100 | model (str): The Whisper model variant to use, default is "large-v2". Options include "medium",
101 | "large-v2", and "large-v3".
102 |
103 | Returns:
104 | list[dict[str, str | float]]: A list of dictionaries, where each dictionary represents a word in
105 | the transcription with the word text, start time, and end time.
106 |
107 | Raises:
108 |         RuntimeError: Runtime or out-of-memory errors propagated by the underlying
109 |             whisperx / torch libraries.
110 | """
111 | # 1. Transcribe with original whisper (batched)
112 | # options for models medium, large-v2, large-v3
113 | model = whisperx.load_model(model, device, compute_type=compute_type)
114 |
115 | audio = whisperx.load_audio(audio_file)
116 | result = model.transcribe(audio, batch_size=batch_size, language="en")
117 | logger.debug(f"Before Alignment:\n {pformat(result['segments'])}") # before alignment
118 |
119 | new_aligned_transcript = align_transcript_with_script(result["segments"], script)
120 |
121 |     # free the transcription model and reclaim GPU memory
122 | gc.collect()
123 | if torch.cuda.is_available():
124 | torch.cuda.empty_cache()
125 | del model
126 |
127 | # 2. Align whisper output
128 | model_a, metadata = whisperx.load_align_model(language_code=result["language"], device=device)
129 | result = whisperx.align(
130 | new_aligned_transcript,
131 | model_a,
132 | metadata,
133 | audio,
134 | device,
135 | return_char_alignments=False,
136 | )
137 |
138 | logger.debug(
139 | f"After Alignment:\n {pformat(result['segments'])}"
140 |     )  # after alignment
141 |
142 | word_transcript = []
143 | for segments in result["segments"]:
144 | for index, word in enumerate(segments["words"]):
145 | if "start" not in word:
146 | word["start"] = segments["words"][index - 1]["end"] if index > 0 else 0
147 | word["end"] = (
148 | segments["words"][index + 1]["start"]
149 | if index < len(segments["words"]) - 1
150 | else segments["words"][-1]["start"]
151 | )
152 |
153 | word_transcript.append(
154 | {"word": word["word"], "start": word["start"], "end": word["end"]}
155 | )
156 |
157 | logger.debug(f"Transcript:\n {pformat(word_transcript)}") # before alignment
158 |
159 | gc.collect()
160 | if torch.cuda.is_available():
161 | torch.cuda.empty_cache()
162 | del model_a
163 |
164 | return word_transcript
165 |
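# Usage sketch: transcribe a TTS clip against its known script. The paths are
# illustrative; device="cpu" also works (typically with compute_type="int8").
#
# with open("cache/script.txt") as f:
#     script_text = f.read()
# words = generate_audio_transcription("cache/output.wav", script_text, device="cuda")
# print(words[0])  # e.g. {"word": "Hello", "start": 0.05, "end": 0.32}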
--------------------------------------------------------------------------------
/ShortsMaker/utils/colors_dict.py:
--------------------------------------------------------------------------------
1 | # A list of colors and their equivalent RGBA values
2 | COLORS_DICT: dict[str, tuple[int, int, int, int]] = {
3 | "aquamarine": (127, 255, 212, 255),
4 | "aquamarine2": (118, 238, 198, 255),
5 | "azure1": (240, 255, 255, 255),
6 | "blue": (0, 0, 255, 255),
7 | "CadetBlue": (152, 245, 255, 255),
8 | "CadetBlue2": (142, 229, 238, 255),
9 | "CadetBlue3": (122, 197, 205, 255),
10 | "coral": (255, 127, 80, 255),
11 | "coral1": (255, 114, 86, 255),
12 | "coral2": (238, 106, 80, 255),
13 | "CornflowerBlue": (100, 149, 237, 255),
14 | "cornsilk1": (255, 248, 220, 255),
15 | "cyan": (0, 255, 255, 255),
16 | "cyan2": (0, 238, 238, 255),
17 | "cyan3": (0, 205, 205, 255),
18 | "DarkGoldenrod": (255, 185, 15, 255),
19 | "DarkGoldenrod2": (238, 173, 14, 255),
20 | "DarkOliveGreen": (202, 255, 112, 255),
21 | "DarkOliveGreen2": (188, 238, 104, 255),
22 | "DarkOliveGreen3": (162, 205, 90, 255),
23 | "DarkSalmon": (233, 150, 122, 255),
24 | "DarkSeaGreen": (143, 188, 143, 255),
25 | "DarkSeaGreen1": (193, 255, 193, 255),
26 | "DarkSeaGreen2": (180, 238, 180, 255),
27 | "DarkSeaGreen3": (155, 205, 155, 255),
28 | "DarkSlateGray": (151, 255, 255, 255),
29 | "DarkSlateGray2": (141, 238, 238, 255),
30 | "DarkSlateGray3": (121, 205, 205, 255),
31 | "DarkTurquoise": (0, 206, 209, 255),
32 | "DeepPink": (255, 20, 147, 255),
33 | "DeepSkyBlue": (0, 191, 255, 255),
34 | "DeepSkyBlue2": (0, 178, 238, 255),
35 | "DodgerBlue": (30, 144, 255, 255),
36 | "FloralWhite": (255, 250, 240, 255),
37 | "gold": (255, 215, 0, 255),
38 | "gold2": (238, 201, 0, 255),
39 | "goldenrod": (255, 193, 37, 255),
40 | "goldenrod2": (238, 180, 34, 255),
41 | "goldenrod3": (205, 155, 29, 255),
42 | "HotPink": (255, 105, 180, 255),
43 | "HotPink1": (255, 110, 180, 255),
44 | "HotPink2": (238, 106, 167, 255),
45 | "IndianRed": (255, 106, 106, 255),
46 | "IndianRed2": (238, 99, 99, 255),
47 | "khaki": (240, 230, 140, 255),
48 | "khaki1": (255, 246, 143, 255),
49 | "khaki2": (238, 230, 133, 255),
50 | "khaki3": (205, 198, 115, 255),
51 | "lavender": (230, 230, 250, 255),
52 | "LawnGreen": (124, 252, 0, 255),
53 | "LemonChiffon": (255, 250, 205, 255),
54 | "LemonChiffon2": (238, 233, 191, 255),
55 | "LemonChiffon3": (205, 201, 165, 255),
56 | "LightBlue": (173, 216, 230, 255),
57 | "LightBlue1": (191, 239, 255, 255),
58 | "LightCyan": (224, 255, 255, 255),
59 | "LightCyan2": (209, 238, 238, 255),
60 | "LightCyan3": (180, 205, 205, 255),
61 | "LightGoldenrod": (255, 236, 139, 255),
62 | "LightGoldenrod2": (238, 220, 130, 255),
63 | "LightGoldenrod3": (205, 190, 112, 255),
64 | "LightGoldenrodYellow": (250, 250, 210, 255),
65 | "LightGreen": (144, 238, 144, 255),
66 | "LightPink": (255, 182, 193, 255),
67 | "LightPink1": (255, 174, 185, 255),
68 | "LightSalmon": (255, 160, 122, 255),
69 | "LightSeaGreen": (32, 178, 170, 255),
70 | "LightSkyBlue": (135, 206, 250, 255),
71 | "LightSkyBlue1": (176, 226, 255, 255),
72 | "LightSkyBlue2": (164, 211, 238, 255),
73 | "LightSkyBlue3": (141, 182, 205, 255),
74 | "LightSlateBlue": (132, 112, 255, 255),
75 | "LightYellow": (255, 255, 224, 255),
76 | "lime": (0, 255, 0, 255),
77 | "LimeGreen": (50, 205, 50, 255),
78 | "linen": (250, 240, 230, 255),
79 | "magenta": (255, 0, 255, 255),
80 | "maroon1": (255, 52, 179, 255),
81 | "MediumAquamarine": (102, 205, 170, 255),
82 | "MediumGoldenRod": (209, 193, 102, 255),
83 | "MediumOrchid": (224, 102, 255, 255),
84 | "MediumOrchid2": (209, 95, 238, 255),
85 | "MediumSeaGreen": (60, 179, 113, 255),
86 | "MediumSpringGreen": (0, 250, 154, 255),
87 | "MediumTurquoise": (72, 209, 204, 255),
88 | "MintCream": (245, 255, 250, 255),
89 | "MistyRose": (255, 228, 225, 255),
90 | "MistyRose2": (238, 213, 210, 255),
91 | "moccasin": (255, 228, 181, 255),
92 | "NavajoWhite": (255, 222, 173, 255),
93 | "OliveDrab": (192, 255, 62, 255),
94 | "OliveDrab2": (179, 238, 58, 255),
95 | "OliveDrab3": (154, 205, 50, 255),
96 | "orange": (255, 165, 0, 255),
97 | "orchid": (218, 112, 214, 255),
98 | "orchid1": (255, 131, 250, 255),
99 | "orchid2": (238, 122, 233, 255),
100 | "PaleGoldenrod": (238, 232, 170, 255),
101 | "PaleGreen": (152, 251, 152, 255),
102 | "PaleGreen1": (154, 255, 154, 255),
103 | "PaleGreen2": (144, 238, 144, 255),
104 | "PaleGreen3": (124, 205, 124, 255),
105 | "PaleTurquoise": (175, 238, 238, 255),
106 | "PaleTurquoise1": (187, 255, 255, 255),
107 | "PapayaWhip": (255, 239, 213, 255),
108 | "PeachPuff": (255, 218, 185, 255),
109 | "PeachPuff2": (238, 203, 173, 255),
110 | "peru": (205, 133, 63, 255),
111 | "pink": (255, 192, 203, 255),
112 | "pink1": (255, 181, 197, 255),
113 | "pink2": (238, 169, 184, 255),
114 | "plum": (221, 160, 221, 255),
115 | "plum1": (255, 187, 255, 255),
116 | "plum2": (238, 174, 238, 255),
117 | "PowderBlue": (176, 224, 230, 255),
118 | "red": (255, 0, 0, 255),
119 | "RosyBrown1": (255, 193, 193, 255),
120 | "salmon1": (255, 140, 105, 255),
121 | "SandyBrown": (244, 164, 96, 255),
122 | "SeaGreen": (84, 255, 159, 255),
123 | "SeaGreen2": (78, 238, 148, 255),
124 | "SeaGreen3": (67, 205, 128, 255),
125 | "seashell": (255, 245, 238, 255),
126 | "seashell2": (238, 229, 222, 255),
127 | "sienna": (255, 130, 71, 255),
128 | "sienna2": (238, 121, 66, 255),
129 | "SkyBlue": (135, 206, 235, 255),
130 | "SkyBlue1": (135, 206, 255, 255),
131 | "SkyBlue2": (126, 192, 238, 255),
132 | "SkyBlue3": (108, 166, 205, 255),
133 | "SlateBlue": (131, 111, 255, 255),
134 | "SlateBlue2": (122, 103, 238, 255),
135 | "SlateGray1": (198, 226, 255, 255),
136 | "snow": (255, 250, 250, 255),
137 | "snow2": (238, 233, 233, 255),
138 | "SpringGreen": (0, 255, 127, 255),
139 | "SpringGreen2": (0, 238, 118, 255),
140 | "SteelBlue": (99, 184, 255, 255),
141 | "SteelBlue2": (92, 172, 238, 255),
142 | "tan": (210, 180, 140, 255),
143 | "tan1": (255, 165, 79, 255),
144 | "tan2": (238, 154, 73, 255),
145 | "thistle": (255, 225, 255, 255),
146 | "thistle2": (238, 210, 238, 255),
147 | "tomato": (255, 99, 71, 255),
148 | "turquoise": (64, 224, 208, 255),
149 | "turquoise1": (0, 245, 255, 255),
150 | "violet": (238, 130, 238, 255),
151 | "VioletRed": (255, 62, 150, 255),
152 | "VioletRed2": (238, 58, 140, 255),
153 | "white": (255, 255, 255, 255),
154 | "WhiteSmoke": (245, 245, 245, 255),
155 | "yellow": (255, 255, 0, 255),
156 | "yellow2": (238, 238, 0, 255),
157 | }
158 |
--------------------------------------------------------------------------------
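A minimal sketch of consuming this RGBA mapping, for example to pick a random
caption color; every value is an (R, G, B, A) tuple that PIL accepts directly.
The exported name `colors_dict` is an assumption here: the actual variable name
is defined near the top of the module, outside this excerpt.

    from secrets import choice

    from ShortsMaker.utils.colors_dict import colors_dict  # assumed export name

    color_name = choice(list(colors_dict))
    r, g, b, a = colors_dict[color_name]
    print(f"{color_name}: rgba({r}, {g}, {b}, {a})")
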
/ShortsMaker/utils/download_youtube_music.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 |
3 | import yt_dlp
4 |
5 | from .logging_config import get_logger
6 |
7 | logger = get_logger(__name__)
8 |
9 |
10 | def sanitize_filename(source_filename: str) -> str:
11 | """
12 | Sanitizes a given filename by removing leading and trailing spaces, replacing spaces with underscores,
13 | and replacing invalid characters with underscores.
14 |
15 | Args:
16 | source_filename (str): The original filename to be sanitized.
17 |
18 | Returns:
19 | str: The sanitized filename.
20 | """
21 | sanitized_filename = source_filename
22 | sanitized_filename = sanitized_filename.strip()
23 | sanitized_filename = sanitized_filename.strip(" .")
24 | sanitized_filename = sanitized_filename.replace(" ", "_")
25 | invalid_chars = '<>:"/\\|?*'
26 | sanitized_filename = "".join("_" if c in invalid_chars else c for c in sanitized_filename)
27 | return sanitized_filename
28 |
29 |
30 | def download_youtube_music(music_url: str, music_dir: Path, force: bool = False) -> list[Path]:
31 | """
32 | Downloads music from a YouTube URL provided, saving it to a specified directory. The method supports
33 | downloading either the full audio or splitting into chapters, if chapters are available in the video
34 | metadata. Optionally, existing files can be overwritten if the `force` flag is set.
35 |
36 | Args:
37 | music_url (str): The YouTube URL of the music video to download.
38 | music_dir (Path): The directory where the downloaded audio files will be saved.
39 | force (bool): Specifies whether existing files should be overwritten. Defaults to False.
40 |
41 | Returns:
42 | list[Path]: A list of paths to the downloaded audio files.
43 | """
44 | ydl_opts = {}
45 |
46 | with yt_dlp.YoutubeDL(ydl_opts) as ydl:
47 | extracted_info = ydl.extract_info(music_url, download=False)
48 | info_dict = ydl.sanitize_info(extracted_info)
49 |
50 | logger.info(f"Music title: {info_dict['title']}")
51 |
52 | # Handle case with no chapters
53 |     if not info_dict.get("chapters"):
54 | logger.info("No chapters found. Downloading full audio...")
55 | sanitized_filename = sanitize_filename(info_dict["title"])
56 |
57 | ydl_opts = {
58 | "format": "bestaudio",
59 | "outtmpl": str(music_dir / f"{sanitized_filename}.%(ext)s"),
60 | "postprocessors": [
61 | {
62 | "key": "FFmpegExtractAudio",
63 | "preferredcodec": "wav",
64 | "preferredquality": "0",
65 | }
66 | ],
67 | "restrictfilenames": True,
68 | }
69 |
70 | output_path = music_dir / f"{sanitized_filename}.wav"
71 | logger.info(f"Output path: {output_path.absolute()}")
72 |         if force or not output_path.exists():
73 | with yt_dlp.YoutubeDL(ydl_opts) as ydl_audio:
74 | ydl_audio.download([music_url])
75 | logger.info("Full audio downloaded successfully!")
76 | return [output_path]
77 |
78 | # Handle case with chapters
79 | for chapter in info_dict["chapters"]:
80 | logger.info(f"Found chapter: {chapter['title']}")
81 | sanitized_filename = sanitize_filename(chapter["title"])
82 |
83 | ydl_opts = {
84 | "format": "bestaudio",
85 | "outtmpl": str(music_dir / f"{sanitized_filename}.%(ext)s"),
86 | "download_ranges": lambda chapter_range, *args: [
87 | {
88 | "start_time": chapter["start_time"],
89 | "end_time": chapter["end_time"],
90 | "title": chapter["title"],
91 | }
92 | ],
93 | "force_keyframes_at_cuts": True,
94 | "postprocessors": [
95 | {
96 | "key": "FFmpegExtractAudio",
97 | "preferredcodec": "wav",
98 | "preferredquality": "0",
99 | }
100 | ],
101 | "restrictfilenames": True,
102 | }
103 |
104 | output_path = music_dir / f"{sanitized_filename}.wav"
105 | logger.info(f"Output path: {output_path.absolute()}")
106 |         if force or not output_path.exists():
107 | with yt_dlp.YoutubeDL(ydl_opts) as ydl_chapter_audio:
108 | ydl_chapter_audio.download([music_url])
109 |                 logger.info(f"Chapter downloaded: {chapter['title']}")
110 |
111 |     # Return paths to all .wav files found in the music directory
112 | music_files = list(music_dir.glob("*.wav"))
113 | return music_files
114 |
--------------------------------------------------------------------------------
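A minimal usage sketch for download_youtube_music; the URL is a placeholder.
Chapterless videos yield a single WAV, videos with chapters yield one WAV per
chapter, and the function returns every .wav found in music_dir.

    from pathlib import Path

    from ShortsMaker.utils.download_youtube_music import download_youtube_music

    music_dir = Path("assets/background_music")
    music_dir.mkdir(parents=True, exist_ok=True)

    # force=False skips tracks that already exist on disk
    tracks = download_youtube_music(
        "https://www.youtube.com/watch?v=example_music_id", music_dir, force=False
    )
    for track in tracks:
        print(track.name)
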
/ShortsMaker/utils/download_youtube_video.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 |
3 | import yt_dlp
4 |
5 | from .logging_config import get_logger
6 |
7 | logger = get_logger(__name__)
8 |
9 |
10 | def download_youtube_video(video_url: str, video_dir: Path, force: bool = False) -> list[Path]:
11 | """
12 | Downloads a YouTube video given its URL and stores it in the specified directory. The
13 | video is downloaded in the best available MP4 format, and its filename is sanitized to
14 | remove invalid characters. Provides an option to force download the video even if the
15 | target file already exists.
16 |
17 | Args:
18 | video_url: The URL of the video to be downloaded from YouTube.
19 | video_dir: The directory where the video will be saved.
20 | force: If True, forces the download even if the video file already exists. Defaults
21 | to False.
22 |
23 | Returns:
24 | list[Path]: A list containing the Path objects of the '.mp4' files in the specified
25 | directory after the download process.
26 | """
27 | ydl_opts = {
28 | "format": "bestvideo[ext=mp4]+bestaudio[ext=m4a]/best[ext=mp4]/best",
29 | "merge_output_format": "mp4",
30 | "outtmpl": str(video_dir / "%(title)s.%(ext)s"),
31 | "restrictfilenames": True,
32 | }
33 |
34 | with yt_dlp.YoutubeDL(ydl_opts) as ydl:
35 | info = ydl.extract_info(video_url, download=False)
36 | info_dict = ydl.sanitize_info(info)
37 |
38 | logger.info(f"Video title: {info_dict['title']}")
39 | sanitized_filename = ydl.prepare_filename(info)
40 | logger.info(f"Sanitized filename will be: {sanitized_filename}")
41 |
42 |         output_path = Path(sanitized_filename)  # prepare_filename already includes video_dir
43 |         if force or not output_path.exists():
44 | ydl.download([video_url])
45 | logger.info("Video downloaded successfully!")
46 |
47 | bg_files = list(video_dir.glob("*.mp4"))
48 | return bg_files
49 |
--------------------------------------------------------------------------------
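The video helper follows the same pattern; note that the returned list contains
every .mp4 already present in video_dir, not only the file just fetched.

    from pathlib import Path

    from ShortsMaker.utils.download_youtube_video import download_youtube_video

    video_dir = Path("assets/background_videos")
    video_dir.mkdir(parents=True, exist_ok=True)

    videos = download_youtube_video(
        "https://www.youtube.com/watch?v=example_video_id", video_dir
    )
    print(f"{len(videos)} background video(s) available")
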
/ShortsMaker/utils/get_tts.py:
--------------------------------------------------------------------------------
1 | import base64
2 | import io
3 | import textwrap
4 | from pprint import pformat
5 | from threading import Thread
6 |
7 | import requests
8 | from pydub import AudioSegment
9 |
10 | from .logging_config import get_logger
11 | from .retry import retry
12 |
13 | logger = get_logger(__name__)
14 |
15 | # define the endpoint data with URLs and corresponding response keys
16 | ENDPOINT_DATA = [
17 | {
18 | "url": "https://tiktok-tts.weilnet.workers.dev/api/generation",
19 | "response": "data",
20 | },
21 | {"url": "https://countik.com/api/text/speech", "response": "v_data"},
22 | {"url": "https://gesserit.co/api/tiktok-tts", "response": "base64"},
23 | ]
24 |
25 | # define available voices for text-to-speech conversion
26 | VOICES = [
27 | "en_us_001", # English US - Female (Int. 1)
28 | "en_us_002", # English US - Female (Int. 2)
29 | "en_au_002", # English AU - Male
30 | "en_uk_001", # English UK - Male 1
31 | "en_uk_003", # English UK - Male 2
32 | "en_us_006", # English US - Male 1
33 | "en_us_010", # English US - Male 4
34 | "en_female_emotional", # peaceful
35 | ]
36 |
37 |
38 | # define the text-to-speech function
39 | @retry(max_retries=3, delay=5)
40 | def tts(text: str, voice: str, output_filename: str = "output.wav") -> None:
41 | """
42 | Converts text to speech using specified voice and saves to output file.
43 |
44 | Args:
45 | text (str): Input text to convert
46 | voice (str): Voice ID to use
47 | output_filename (str): Output audio file path
48 |
49 | Raises:
50 | ValueError: If voice is invalid or text is empty
51 | """
52 | _validate_inputs(text, voice)
53 | chunks = _split_text(text)
54 | _log_chunks(text, chunks)
55 |
56 |     # Try each endpoint in turn until one produces audio
57 |
58 |     for endpoint in ENDPOINT_DATA:
59 |         audio_data = _process_chunks(chunks, endpoint, voice, [""] * len(chunks))
60 |         if audio_data is not None:
61 |             _save_audio(audio_data, output_filename)
62 |             return
63 |     raise RuntimeError("All TTS endpoints failed to generate audio")
64 |
65 |
66 | def _validate_inputs(text: str, voice: str) -> None:
67 | if voice not in VOICES:
68 | raise ValueError("voice must be valid")
69 | if not text:
70 | raise ValueError("text must not be 'None'")
71 |
72 |
73 | def _log_chunks(text: str, chunks: list[str]) -> None:
74 | logger.info(f"text: {text}")
75 | logger.info(f"Split text into {len(chunks)} chunks")
76 | for chunk in chunks:
77 | logger.info(f"Chunk: {chunk}")
78 |
79 |
80 | def _process_chunks(
81 | chunks: list[str], endpoint: dict, voice: str, audio_data: list[str]
82 | ) -> list[str] | None:
83 | valid = True
84 |
85 | def generate_audio_chunk(index: int, chunk: str) -> None:
86 | nonlocal valid
87 | if not valid:
88 | return
89 |
90 | try:
91 | logger.info(f"Using endpoint: {endpoint['url']}")
92 | response = requests.post(
93 | endpoint["url"],
94 | json={"text": chunk, "voice": voice},
95 | headers={
96 | "User-Agent": "com.zhiliaoapp.musically/2022600030 (Linux; U; Android 7.1.2; es_ES; SM-G988N; Build/NRD90M;tt-ok/3.12.13.1)",
97 | },
98 | )
99 | if response.status_code == 200:
100 | audio_data[index] = response.json()[endpoint["response"]]
101 | logger.info(
102 | f"Chunk {index} processed successfully with endpoint: {endpoint['url']}"
103 | )
104 | else:
105 | logger.warning(
106 | f"Endpoint failed with status {response.status_code}: {endpoint['url']}"
107 | )
108 | valid = False
109 |             except requests.exceptions.JSONDecodeError as e:
110 |                 logger.warning(f"Non-JSON response from endpoint {endpoint['url']}:\n{pformat(response.text)}")
111 |                 logger.error(f"JSONDecodeError for endpoint {endpoint['url']}: {e}")
112 |                 valid = False
113 |             except requests.RequestException as e:
114 |                 # `response` may be unbound here, so only the exception is logged
115 |                 logger.error(f"RequestException for endpoint {endpoint['url']}: {e}")
116 |                 valid = False
117 |
118 | threads = [
119 | Thread(target=generate_audio_chunk, args=(i, chunk)) for i, chunk in enumerate(chunks)
120 | ]
121 | for thread in threads:
122 | thread.start()
123 | for thread in threads:
124 | thread.join()
125 |
126 | return audio_data if valid else None
127 |
128 |
129 | def _save_audio(audio_data: list[str], output_filename: str) -> None:
130 | audio_bytes = b"".join(base64.b64decode(chunk) for chunk in audio_data)
131 | audio_segment: AudioSegment = AudioSegment.from_file(io.BytesIO(audio_bytes))
132 | audio_segment.export(output_filename, format="wav")
133 |
134 |
135 | def _split_text(text: str, chunk_size: int = 250) -> list[str]:
136 | """
137 | Splits a given text into smaller chunks of a specified size without breaking
138 | words or splitting on hyphens.
139 |
140 | The function wraps the input text into smaller substrings, ensuring the
141 | integrity of the text by preventing cutoff mid-word or mid-hyphen. Each chunk
142 | is at most of the specified chunk size.
143 |
144 | Args:
145 |         text (str): The input text to be split into smaller chunks.
146 |         chunk_size (int): The maximum length of each chunk. Defaults to 250.
147 | Returns:
148 | list[str]: A list of text chunks where each chunk is at most the
149 | specified size while preserving word integrity.
150 | """
151 |
152 | text_list = textwrap.wrap(
153 | text, width=chunk_size, break_long_words=False, break_on_hyphens=False
154 | )
155 |
156 | return text_list
157 |
--------------------------------------------------------------------------------
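A sketch of calling tts directly. Importing the module pulls in notify_discord
through the retry decorator, which requires DISCORD_WEBHOOK_URL at import time,
so a placeholder is set first; the voice must be one of the IDs in VOICES, and
the audio is exported as WAV.

    import os

    # notify_discord (imported via the retry decorator) checks this at import time
    os.environ.setdefault("DISCORD_WEBHOOK_URL", "https://discord.com/api/webhooks/placeholder")

    from ShortsMaker.utils.get_tts import tts

    # Long scripts are split into ~250-character chunks, fetched in parallel
    # threads, and reassembled in order before export.
    tts(
        text="This is a test to generate audio.",
        voice="en_us_001",
        output_filename="output.wav",
    )
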
/ShortsMaker/utils/logging_config.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import os
3 | from pathlib import Path
4 |
5 | from colorlog import ColoredFormatter
6 |
7 | # prevent huggingface symlink warnings
8 | os.environ["HF_HUB_DISABLE_SYMLINKS_WARNING"] = "true"
9 |
10 | # Global configuration
11 | LOG_FILE: Path | str = "ShortsMaker.log"
12 | LOG_LEVEL: str = "DEBUG"
13 | LOGGING_ENABLED: bool = True
14 | INITIALIZED: bool = False
15 |
16 | # Cache of configured loggers
17 | LOGGERS: dict[str, logging.Logger] = {}
18 |
19 |
20 | def get_logger(name: str = __name__) -> logging.Logger:
21 | """
22 | Get a logger with the specified name, typically __name__ from the module.
23 | If the logging system has not been initialized, it will use default settings.
24 |
25 | Args:
26 | name (str): The logger name, typically passed as __name__ from the module.
27 | This ensures proper hierarchical naming of loggers.
28 |
29 | Returns:
30 | logging.Logger: A configured logger instance.
31 | """
32 | global LOG_FILE, LOG_LEVEL, LOGGING_ENABLED, INITIALIZED, LOGGERS
33 |
34 | # Initialize logging system with defaults if not already done
35 | if not INITIALIZED:
36 | configure_logging()
37 |
38 | # Return existing logger if already configured
39 | if name in LOGGERS:
40 | return LOGGERS[name]
41 |
42 | # Create a new logger
43 | logger = logging.getLogger(name)
44 |
45 |     # Only attach handlers when the logger has none yet,
46 |     # so repeated lookups do not duplicate output
47 | if not logger.handlers:
48 | # Create Console
49 | console_handler = logging.StreamHandler()
50 | color_formatter = ColoredFormatter(
51 | "{log_color}{asctime} - {name} - {funcName} - {levelname} - {message}",
52 | style="{",
53 | datefmt="%Y-%m-%d %H:%M:%S",
54 | log_colors={
55 | "DEBUG": "cyan",
56 | "INFO": "green",
57 | "WARNING": "yellow",
58 | "ERROR": "red",
59 | "CRITICAL": "bold_red",
60 | },
61 | reset=True,
62 | )
63 | console_handler.setFormatter(color_formatter)
64 |
65 | # Create File Handler
66 | file_handler = logging.FileHandler(LOG_FILE, mode="a", encoding="utf-8")
67 | formatter = logging.Formatter(
68 | "{asctime} - {name} - {funcName} - {levelname} - {message}",
69 | style="{",
70 | datefmt="%Y-%m-%d %H:%M:%S",
71 | )
72 | file_handler.setFormatter(formatter)
73 |
74 | # Add handlers to logger
75 | logger.addHandler(console_handler)
76 | logger.addHandler(file_handler)
77 |
78 | # Set logging level based on enable flag
79 | if LOGGING_ENABLED:
80 | logger.setLevel(LOG_LEVEL)
81 | else:
82 | logger.setLevel(logging.CRITICAL)
83 |
84 | # Prevent propagation to avoid duplicate logs
85 | logger.propagate = False
86 |
87 | # Store the configured logger
88 | LOGGERS[name] = logger
89 |
90 | return logger
91 |
92 |
93 | def configure_logging(
94 | log_file: Path | str = LOG_FILE, level: str | int = LOG_LEVEL, enable: bool = LOGGING_ENABLED
95 | ) -> None:
96 | """
97 | Configure the global logging settings.
98 | This function can be called by users to customize logging behavior.
99 |
100 | Args:
101 | log_file (str | Path): Path to the log file.
102 | level (str | int): Logging level ('DEBUG', 'INFO', etc.).
103 | enable (bool): Whether to enable logging.
104 | """
105 | global LOG_FILE, LOG_LEVEL, LOGGING_ENABLED, INITIALIZED, LOGGERS
106 |
107 | # Update configuration with provided values
108 | LOG_FILE = Path(log_file) if isinstance(log_file, str) else log_file
109 | LOG_LEVEL = level
110 | LOGGING_ENABLED = enable
111 |
112 | # Create log directory if it doesn't exist
113 | LOG_FILE.parent.mkdir(parents=True, exist_ok=True)
114 |
115 |     # Update the level of existing loggers; their handlers keep the original log file
116 | for logger_name, logger in LOGGERS.items():
117 | # Update log level
118 | if LOGGING_ENABLED:
119 | logger.setLevel(LOG_LEVEL)
120 | else:
121 | logger.setLevel(logging.CRITICAL)
122 |
123 | INITIALIZED = True
124 |
--------------------------------------------------------------------------------
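A sketch of customizing the logging defaults. configure_logging is best called
before any module requests its logger; calling it later still adjusts the level
of cached loggers, but their file handlers keep writing to the originally
configured log file.

    from pathlib import Path

    from ShortsMaker.utils.logging_config import configure_logging, get_logger

    configure_logging(log_file=Path("logs/run.log"), level="INFO", enable=True)

    logger = get_logger(__name__)
    logger.info("visible at INFO")
    logger.debug("suppressed, since the level is INFO")
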
/ShortsMaker/utils/notify_discord.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | import textwrap
4 | from random import choice, randint
5 |
6 | import requests
7 | from bs4 import BeautifulSoup
8 | from discord_webhook import DiscordEmbed, DiscordWebhook
9 | from requests import Response
10 |
11 | if not os.environ.get("DISCORD_WEBHOOK_URL"):
12 |     raise ValueError("DISCORD_WEBHOOK_URL not set. Please set it in your environment variables.")
13 |
14 | DISCORD_URL = os.environ.get("DISCORD_WEBHOOK_URL")
15 |
16 |
17 | def get_arthas():
18 | """
19 | Fetches a random Arthas image URL from Bing image search results.
20 |
21 | This function sends a search request to Bing images for the keyword 'arthas'
22 | and retrieves a specific page of the search results. It extracts image URLs
23 | from the returned HTML content and returns one randomly selected image URL.
24 |
25 | Returns:
26 | str: A randomly selected URL of an Arthas image.
27 |
28 | Raises:
29 | requests.exceptions.RequestException: If the HTTP request fails or encounters an issue.
30 | """
31 | url = f"https://www.bing.com/images/search?q=arthas&first={randint(1, 10)}"
32 | response = requests.get(
33 | url,
34 | headers={
35 | "User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_11_5) AppleWebKit/537.36 (KHTML, like Gecko) \
36 | Chrome/50.0.2661.102 Safari/537.36"
37 | },
38 | )
39 | soup = BeautifulSoup(response.content, "lxml")
40 | divs = soup.find_all("div", class_="imgpt")
41 | imgs = []
42 | for div in divs:
43 | img = div.find("a")["m"]
44 | img = img.split("murl")[1].split('"')[2]
45 | imgs.append(img)
46 | return choice(imgs)
47 |
48 |
49 | def get_meme():
50 | """
51 | Fetches a meme image URL from the meme-api.com API.
52 |
53 | This function uses the "gimme" endpoint of the meme-api.com to fetch a
54 | meme image URL. The API allows for specifying a subreddit and quantity
55 | of meme images. The default behavior is to fetch a random meme from the
56 | API. The response is parsed and the URL of the image is extracted and
57 | returned. The function applies a User-Agent header as part of the
58 | request to prevent potential issues with the API.
59 |
60 | Returns:
61 | str: The URL of the meme image.
62 |
63 | Raises:
64 | JSONDecodeError: If the response content could not be decoded into JSON.
65 |         KeyError: If the expected "MemeURL" key is not found in the JSON response.
66 | RequestException: If there is an issue with the HTTP request.
67 |
68 | """
69 | # Uses https://meme-api.com/gimme/
70 | # Can use custom subreddit and return multiple images
71 | # Endpoint: /gimme/{subreddit}/{count}
72 | # Example: https://meme-api.com/gimme/wholesomememes/2
73 | # Returns:
74 | # Image url
75 | # Looks like the below endpoint is not working anymore
76 | # url = "https://meme-api.com/gimme"
77 |
78 | url = "https://memeapi.zachl.tech/pic/json"
79 |
80 | response = requests.get(
81 | url,
82 | headers={
83 | "User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_11_5) AppleWebKit/537.36 (KHTML, like Gecko) \
84 | Chrome/50.0.2661.102 Safari/537.36"
85 | },
86 | )
87 |
88 | soup = BeautifulSoup(response.content, "html.parser")
89 | return json.loads(soup.text)["MemeURL"]
90 |
91 |
92 | def notify_discord(message) -> Response:
93 | """
94 | Sends a notification message to a Discord webhook, splitting messages longer than the character limit,
95 | and embedding additional information such as title, description, and images.
96 |
97 | Args:
98 | message (str): The message content to be sent to the Discord webhook. If the message exceeds 4000
99 | characters, it will be split into smaller parts.
100 |
101 | Returns:
102 | Response: The response object resulting from the webhook execution, which contains information
103 | such as status code and response text.
104 | """
105 | messages = textwrap.wrap(message, 4000)
106 | response = None
107 |
108 |     for part in messages:
109 | webhook = DiscordWebhook(url=DISCORD_URL, rate_limit_retry=True)
110 |
111 | embed = DiscordEmbed()
112 |         embed.set_title(":warning: Error found while running the Automation! :warning:")
113 |         embed.set_description(part)
114 | embed.set_image(url=get_meme())
115 |
116 | try:
117 | embed.set_thumbnail(url=get_arthas())
118 | except Exception as e:
119 | print(f"Error fetching arthas: {e}")
120 |
121 | embed.set_color("ff0000")
122 | embed.set_timestamp()
123 | webhook.add_embed(embed)
124 |
125 | response = webhook.execute()
126 | print(response.status_code)
127 | print(response.text)
128 | return response
129 |
--------------------------------------------------------------------------------
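Usage sketch: the webhook URL must be in the environment before the module is
imported, and messages over 4000 characters are split across several embeds,
each decorated with a meme image and, when the Bing scrape succeeds, an Arthas
thumbnail.

    import os

    os.environ["DISCORD_WEBHOOK_URL"] = "https://discord.com/api/webhooks/your_webhook"

    from ShortsMaker.utils.notify_discord import notify_discord

    response = notify_discord("Automation failed: could not fetch the reddit post")
    print(response.status_code, response.text)
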
/ShortsMaker/utils/retry.py:
--------------------------------------------------------------------------------
1 | import functools
2 | import time
3 |
4 | from .logging_config import get_logger
5 | from .notify_discord import notify_discord
6 |
7 | logger = get_logger(__name__)
8 |
9 |
10 | def retry(max_retries: int, delay: int, notify: bool = False):
11 | """
12 | A retry decorator function that allows retrying a function based on the specified
13 | number of retries, delay between retries, and an option to send a notification upon
14 | failure. It logs all execution details, including successful executions, exceptions,
15 | and retry attempts.
16 |
17 | Args:
18 | max_retries (int): The maximum number of times the function should be retried
19 | in case of an exception.
20 | delay (int): The time in seconds to wait before retrying the function after
21 | a failure.
22 | notify (bool): Whether to send a notification if the function fails after
23 | reaching the maximum number of retries. Default is False.
24 |
25 | Returns:
26 | Callable: A decorator function that applies the retry logic to the decorated
27 | function.
28 |
29 | Raises:
30 |         RuntimeError: If all retries are exhausted and the function still fails;
31 |             the message includes the exception from the last attempt.
32 |
33 | Example:
34 | @retry(max_retries=3, delay=2)
35 | def my_function():
36 | # Function implementation
37 | pass
38 | """
39 |
40 | def decorator(func):
41 | @functools.wraps(func)
42 | def wrapper(*args, **kwargs):
43 | start_time = time.perf_counter()
44 |             logger.info(f"Using retry decorator with max_retries={max_retries} and delay={delay}s")
45 | logger.info(f"Begin function {func.__name__}")
46 | err = "Before running"
47 | for attempt in range(max_retries):
48 | try:
49 | value = func(*args, **kwargs)
50 | logger.info(f"Returned: {value}")
51 | logger.info(
52 |                         f"Completed function {func.__name__} in {round(time.perf_counter() - start_time, 2)}s after {attempt + 1} attempt(s)"
53 | )
54 | return value
55 | except Exception as e:
56 | logger.exception(f"Exception: {e}")
57 | logger.warning(f"Retrying function {func.__name__} after {delay}s")
58 | err = str(e)
59 | time.sleep(delay)
60 |             if notify:
61 |                 notify_discord(
62 |                     f"{func.__name__} failed after {max_retries} retries.\nException: {err}"
63 |                 )
64 |             raise RuntimeError(f"{func.__name__} failed after {max_retries} retries: {err}")
65 |
66 | return wrapper
67 |
68 | return decorator
69 |
--------------------------------------------------------------------------------
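A small runnable sketch of the decorator on a flaky function: the first two
calls fail, the third succeeds, and a RuntimeError is raised only when every
attempt fails. The placeholder webhook variable is needed because importing
retry also imports notify_discord.

    import os

    os.environ.setdefault("DISCORD_WEBHOOK_URL", "https://discord.com/api/webhooks/placeholder")

    from ShortsMaker.utils.retry import retry

    attempts = {"count": 0}

    @retry(max_retries=3, delay=1)
    def flaky() -> str:
        attempts["count"] += 1
        if attempts["count"] < 3:
            raise ConnectionError("transient failure")
        return "ok"

    print(flaky())  # succeeds on the third attempt
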
/assets/credits/credits.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rajathjn/shorts_maker/cb3e1823e3f271db08a969b7b7a8a5083d09df8d/assets/credits/credits.mp4
--------------------------------------------------------------------------------
/assets/credits/credits_mask.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rajathjn/shorts_maker/cb3e1823e3f271db08a969b7b7a8a5083d09df8d/assets/credits/credits_mask.mp4
--------------------------------------------------------------------------------
/assets/fonts/Monaco.ttf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rajathjn/shorts_maker/cb3e1823e3f271db08a969b7b7a8a5083d09df8d/assets/fonts/Monaco.ttf
--------------------------------------------------------------------------------
/assets/fonts/Roboto.ttf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rajathjn/shorts_maker/cb3e1823e3f271db08a969b7b7a8a5083d09df8d/assets/fonts/Roboto.ttf
--------------------------------------------------------------------------------
/example.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 |
3 | import yaml
4 |
5 | from ShortsMaker import MoviepyCreateVideo, ShortsMaker
6 |
7 | setup_file = "setup.yml"
8 |
9 | with open(setup_file) as f:
10 | cfg = yaml.safe_load(f)
11 |
12 | get_post = ShortsMaker(setup_file)
13 |
14 | # You can either provide a URL for the reddit post
15 | get_post.get_reddit_post(
16 | url="https://www.reddit.com/r/Python/comments/1j36d7a/i_got_tired_of_ai_shorts_scams_so_i_built_my_own/"
17 | )
18 | # Or just run the method to get a random post from the subreddit defined in setup.yml
19 | # get_post.get_reddit_post()
20 |
21 | with open(Path(cfg["cache_dir"]) / cfg["reddit_post_getter"]["record_file_txt"]) as f:
22 | script = f.read()
23 |
24 | get_post.generate_audio(
25 | source_txt=script,
26 | output_audio=f"{cfg['cache_dir']}/{cfg['audio']['output_audio_file']}",
27 | output_script_file=f"{cfg['cache_dir']}/{cfg['audio']['output_script_file']}",
28 | )
29 |
30 | get_post.generate_audio_transcript(
31 | source_audio_file=f"{cfg['cache_dir']}/{cfg['audio']['output_audio_file']}",
32 | source_text_file=f"{cfg['cache_dir']}/{cfg['audio']['output_script_file']}",
33 | )
34 |
35 | get_post.quit()
36 |
37 | create_video = MoviepyCreateVideo(
38 | config_file=setup_file,
39 | # speed_factor=1.25, # Set the speed factor for the video
40 | )
41 |
42 | create_video(output_path="assets/output.mp4")
43 |
44 | create_video.quit()
45 |
46 | # Do not run the below when you are using shorts_maker within a container.
47 |
48 | # ask_llm = AskLLM(config_file=setup_file)
49 | # result = ask_llm.invoke(script)
50 | # print(result["parsed"].title)
51 | # print(result["parsed"].description)
52 | # print(result["parsed"].tags)
53 | # print(result["parsed"].thumbnail_description)
54 | # ask_llm.quit_llm()
55 |
56 | # You can use, AskLLM to generate a text prompt for the image generation as well
57 | # image_description = ask_llm.invoke_image_describer(script = script, input_text = "A wild scenario")
58 | # print(image_description)
59 | # print(image_description["parsed"].description)
60 |
61 | # Generate image uses a lot of resources so beware
62 | # generate_image = GenerateImage(config_file=setup_file)
63 | # generate_image.use_huggingface_flux_schnell(image_description["parsed"].description, "output.png")
64 | # generate_image.quit()
65 |
--------------------------------------------------------------------------------
/example.setup.yml:
--------------------------------------------------------------------------------
1 | ---
2 | # Used only for fetching models for Image Generation
3 | hugging_face_access_token: "your_hugging_face_access_token_here"
4 |
5 | # A cache dir for storing results
6 | cache_dir: "cache"
7 |
8 | # A dir in which your background_videos, background_music, credits and fonts folders are located
9 | assets_dir: "assets"
10 |
11 | # Best to leave this as is
12 | retry:
13 | max_retries: 3
14 | delay: 5
15 | notify: False
16 |
17 | # Refer to https://business.reddithelp.com/s/article/Create-a-Reddit-Application
18 | # for more information on how to create a Reddit application
19 | # and get your client_id and client_secret
20 | reddit_praw:
21 | client_id: "your_reddit_client_id"
22 | client_secret: "your_reddit_client_secret"
23 | user_agent: "your_user_agent_here"
24 |
25 | # Replace with your own subreddit name
26 | reddit_post_getter:
27 | subreddit_name: "your_subreddit_name"
28 | record_file_json: "post.json"
29 | record_file_txt: "post.txt"
30 |
31 | # If you are not using your cuda device, set device to "cpu"
32 | # Refer https://github.com/m-bain/whisperX for more information
33 | audio:
34 | output_script_file: "generated_audio_script.txt"
35 | output_audio_file: "output.wav"
36 | transcript_json: "transcript.json"
37 | device: "cpu" # or "cuda"
38 | model: "large-v2" # or "medium"
39 | batch_size: 16 # or 32
40 | compute_type: "int8" # or "float16"
41 |
42 | # Replace with the video URLs and music URLs you want to use
43 | # Only YouTube URLs are supported
44 | # Note: If you want to avoid setting this,
45 | # Pass the path to the respective objects, when initializing the MoviepyCreateVideo class
46 | video:
47 | background_videos_urls:
48 | # https://www.youtube.com/watch?v=n_Dv4JMiwK8
49 | - "https://www.youtube.com/watch?v=example_video_id"
50 | background_music_urls:
51 | # https://www.youtube.com/watch?v=G8a45UZJGh4&t=1s
52 | - "https://www.youtube.com/watch?v=example_music_id"
53 | font_dir: "fonts"
54 | credits_dir: "credits"
55 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [project]
2 | name = "ShortsMaker"
3 | version = "0.2.0"
4 | description = "A Python project to make and upload YouTube Shorts videos"
5 | authors = [
6 | {name = "rajathjn",email = "rajathjnx@gmail.com"}
7 | ]
8 | license = "AGPL-3.0-or-later"
9 | readme = "README.md"
10 | requires-python = ">=3.12,<3.13"
11 | classifiers = [
12 | "Development Status :: 4 - Beta",
13 | "Intended Audience :: Developers",
14 | "Programming Language :: Python :: 3.12",
15 | ]
16 | dependencies = [
17 | "accelerate>=1.3.0",
18 | "beautifulsoup4>=4.13.3",
19 | "colorlog>=6.9.0",
20 | "diffusers>=0.32.2",
21 | "discord-webhook>=1.3.1",
22 | "ftfy>=6.3.1",
23 | "h11>=0.16.0",
24 | "jinja2>=3.1.6",
25 | "langchain-ollama>=0.3.2",
26 | "language-tool-python>=2.8.2",
27 | "lxml>=5.3.1",
28 | "moviepy>=2.1.2",
29 | "ollama>=0.4.7",
30 | "praw>=7.8.1",
31 | "psutil>=6.1.1",
32 | "pydub>=0.25.1",
33 | "pyyaml>=6.0.2",
34 | "rapidfuzz>=3.12.1",
35 | "requests>=2.32.3",
36 | "setuptools>=75.8.0",
37 | "transformers>=4.48.2",
38 | "unidecode>=1.3.8",
39 | "wheel>=0.45.1",
40 | "whisperx>=3.3.1",
41 | "yt-dlp>=2025.3.31",
42 | ]
43 |
44 | [dependency-groups]
45 | dev = [
46 | "pre-commit>=4.1.0",
47 | "pytest>=8.3.4",
48 | "pytest-cov>=6.0.0",
49 | "pytest-mock>=3.14.0",
50 | "requests-mock>=1.12.1",
51 | "ruff>=0.9.5",
52 | ]
53 |
54 | [project.optional-dependencies]
55 | cpu = [
56 | "torch>=2.7.0",
57 | "torchaudio>=2.7.0",
58 | "torchvision>=0.22.0",
59 | ]
60 | cu128 = [
61 | "torch>=2.7.0",
62 | "torchaudio>=2.7.0",
63 | "torchvision>=0.22.0",
64 | ]
65 |
66 | [tool.uv]
67 | conflicts = [
68 | [
69 | { extra = "cpu" },
70 | { extra = "cu128" },
71 | ],
72 | ]
73 |
74 | [tool.uv.sources]
75 | torch = [
76 | { index = "pytorch-cpu", extra = "cpu" },
77 | { index = "pytorch-cu128", extra = "cu128" },
78 | ]
79 | torchvision = [
80 | { index = "pytorch-cpu", extra = "cpu" },
81 | { index = "pytorch-cu128", extra = "cu128" },
82 | ]
83 | torchaudio = [
84 | { index = "pytorch-cpu", extra = "cpu" },
85 | { index = "pytorch-cu128", extra = "cu128" },
86 | ]
87 |
88 | [[tool.uv.index]]
89 | name = "pytorch-cpu"
90 | url = "https://download.pytorch.org/whl/cpu"
91 | explicit = true
92 |
93 | [[tool.uv.index]]
94 | name = "pytorch-cu128"
95 | url = "https://download.pytorch.org/whl/cu128"
96 | explicit = true
97 |
98 | [tool.pytest.ini_options]
99 | testpaths = [
100 | "tests",
101 | ]
102 |
103 | [tool.ruff]
104 | # Exclude commonly ignored directories.
105 | exclude = [
106 | ".git",
107 | ".ipynb_checkpoints",
108 | ".mypy_cache",
109 | ".nox",
110 | ".pyenv",
111 | ".pytest_cache",
112 | ".pytype",
113 | ".ruff_cache",
114 | ".tox",
115 | "_build",
116 | "build",
117 | "venv",
118 | ".pytest_cache",
119 | ".venv",
120 | ".vscode",
121 | ".idea"
122 | ]
123 | line-length = 100
124 | indent-width = 4
125 | fix = true
126 |
127 | [tool.ruff.lint]
128 | extend-select = [
129 | "UP", # pyupgrade
130 | "I", # isort
131 | ]
132 | fixable = ["ALL"]
133 |
134 | [tool.ruff.format]
135 | quote-style = "double"
136 | indent-style = "space"
137 | skip-magic-trailing-comma = false
138 | line-ending = "auto"
139 |
--------------------------------------------------------------------------------
/tests/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rajathjn/shorts_maker/cb3e1823e3f271db08a969b7b7a8a5083d09df8d/tests/__init__.py
--------------------------------------------------------------------------------
/tests/ask_llm_tests/test_ask_llm.py:
--------------------------------------------------------------------------------
1 | import os
2 | import tempfile
3 | from pathlib import Path
4 | from unittest.mock import MagicMock, patch
5 |
6 | import pytest
7 |
8 | from ShortsMaker.ask_llm import AskLLM, OllamaServiceManager, YoutubeDetails
9 |
10 |
11 | @pytest.fixture
12 | def mock_ollama_service_manager():
13 | """
14 | Fixture to provide a mocked instance of OllamaServiceManager.
15 |
16 | Returns:
17 | MagicMock: A mocked instance of OllamaServiceManager with predefined return values for start_service and stop_service methods.
18 | """
19 | ollama_service_manager = MagicMock(OllamaServiceManager)
20 | ollama_service_manager.start_service.return_value = True
21 | ollama_service_manager.stop_service.return_value = True
22 | return ollama_service_manager
23 |
24 |
25 | @patch("ShortsMaker.ask_llm.AskLLM._load_llm_model")
26 | def test_initialization_with_valid_config(
27 | mock_load_llm_model, setup_file, mock_ollama_service_manager
28 | ):
29 | mock_load_llm_model.return_value = None
30 | ask_llm = AskLLM(config_file=setup_file, model_name="test_model")
31 | assert ask_llm.model_name == "test_model"
32 |
33 |
34 | def test_initialization_with_invalid_config_path():
35 | invalid_config_path = Path("temp.yml")
36 | with pytest.raises(FileNotFoundError):
37 | AskLLM(config_file=invalid_config_path)
38 |
39 |
40 | def test_initialization_with_non_yaml_file(setup_file, tmp_path):
41 | with pytest.raises(ValueError):
42 | temp_path = tempfile.NamedTemporaryFile(suffix=".txt", delete=False)
43 | temp_path = Path(temp_path.name)
44 | try:
45 | AskLLM(config_file=temp_path)
46 | finally:
47 | temp_path.unlink()
48 |
49 |
50 | @patch("ShortsMaker.ask_llm.AskLLM._load_llm_model")
51 | def test_llm_model_loading(mock_load_llm_model, mock_ollama_service_manager, setup_file):
52 | AskLLM(config_file=setup_file, model_name="test_model")
53 | mock_load_llm_model.assert_called_once_with("test_model", 0)
54 |
55 |
56 | @patch("ShortsMaker.ask_llm.AskLLM._load_llm_model")
57 | def test_invoke_creates_chat_prompt(mock_load_llm_model, setup_file, mock_ollama_service_manager):
58 | ask_llm = AskLLM(config_file=setup_file)
59 | ask_llm.ollama_service_manager = mock_ollama_service_manager
60 | ask_llm.llm = MagicMock()
61 | ask_llm.llm.with_structured_output.return_value = ask_llm.llm
62 | ask_llm.llm.invoke.return_value = {"title": "Test Title"}
63 |
64 | input_text = "Test script input."
65 | response = ask_llm.invoke(input_text=input_text)
66 |
67 | assert "title" in response
68 | ask_llm.llm.with_structured_output.assert_called_once_with(YoutubeDetails, include_raw=True)
69 | ask_llm.llm.invoke.assert_called_once()
70 |
71 |
72 | @patch("ShortsMaker.ask_llm.AskLLM._load_llm_model")
73 | @patch("ShortsMaker.ask_llm.OllamaServiceManager.stop_service")
74 | @patch("ShortsMaker.ask_llm.subprocess.check_output")
75 | @patch("ShortsMaker.ask_llm.subprocess.run")
76 | def test_quit_llm_with_self_started_service(
77 |     mock_run,
78 |     mock_check_output,
79 |     mock_stop_service,
80 |     mock_load_llm_model,
81 | setup_file,
82 | mock_ollama_service_manager,
83 | ):
84 | mock_load_llm_model.return_value = None
85 |
86 | ask_llm = AskLLM(config_file=setup_file, model_name="test_model")
87 | ask_llm.self_started_ollama = True
88 |
89 | result = ask_llm.quit_llm()
90 | assert result is None
91 | mock_stop_service.assert_called_once()
92 |
93 |
94 | @pytest.mark.skipif("RUNALL" not in os.environ, reason="takes too long")
95 | def test_ask_llm_working(setup_file):
96 | script = "A video about a cat. Doing stunts like running around, flying, and jumping."
97 | ask_llm = AskLLM(config_file=setup_file)
98 | result = ask_llm.invoke(input_text=script)
99 | ask_llm.quit_llm()
100 | assert result["parsed"].title == "Feline Frenzy: Cat Stunt Master!"
101 | assert (
102 | result["parsed"].description
103 | == "Get ready for the most epic feline feats you've ever seen! Watch as our fearless feline friend runs, jumps, and even flies through a series of death-defying stunts."
104 | )
105 | assert result["parsed"].tags == ["cat", "stunts", "flying", "jumping"]
106 | assert (
107 | result["parsed"].thumbnail_description
108 | == "A cat in mid-air, performing a daring stunt with its paws outstretched, surrounded by a blurred cityscape with bright lights and colors."
109 | )
110 | assert result["parsing_error"] is None
111 | assert result is not None
112 |
--------------------------------------------------------------------------------
/tests/ask_llm_tests/test_ollama_service.py:
--------------------------------------------------------------------------------
1 | from unittest.mock import MagicMock, patch
2 |
3 | import pytest
4 |
5 | from ShortsMaker.ask_llm import OllamaServiceManager
6 |
7 |
8 | @pytest.fixture
9 | def ollama_service_manager():
10 | """Fixture to provide an instance of OllamaServiceManager with a mock logger.
11 |
12 | Args:
13 | mock_logger: A mock logger for the OllamaServiceManager.
14 |
15 | Returns:
16 | An instance of OllamaServiceManager.
17 | """
18 | return OllamaServiceManager()
19 |
20 |
21 | @patch("ShortsMaker.ask_llm.subprocess.Popen")
22 | def test_start_service(mock_popen, ollama_service_manager):
23 | process_mock = MagicMock()
24 | process_mock.poll.return_value = None
25 | mock_popen.return_value = process_mock
26 |
27 | result = ollama_service_manager.start_service()
28 | assert result is True
29 | mock_popen.assert_called_once()
30 |
31 |
32 | @patch("ShortsMaker.ask_llm.subprocess.run")
33 | def test_stop_service_on_windows(mock_subprocess_run, ollama_service_manager):
34 | ollama_service_manager.process = MagicMock()
35 | ollama_service_manager.system = "windows"
36 | result = ollama_service_manager.stop_service()
37 | assert result is True
38 | mock_subprocess_run.assert_called()
39 |
40 |
41 | @patch("ShortsMaker.ask_llm.psutil.process_iter")
42 | def test_is_ollama_running(mock_process_iter, ollama_service_manager):
43 | mock_process_iter.return_value = [
44 | MagicMock(info={"name": "Ollama"}),
45 | ]
46 | result = ollama_service_manager.is_ollama_running()
47 | assert result is True
48 |
49 |
50 | def test_is_service_running_with_active_process(ollama_service_manager):
51 | ollama_service_manager.process = MagicMock()
52 | ollama_service_manager.process.poll.return_value = None
53 |
54 | result = ollama_service_manager.is_service_running()
55 | assert result is True
56 |
57 |
58 | @patch("ShortsMaker.ask_llm.subprocess.check_output")
59 | def test_stop_running_model(mock_check_output, ollama_service_manager):
60 | model_name = "test_model"
61 | mock_check_output.return_value = "Stopped"
62 |
63 | result = ollama_service_manager.stop_running_model(model_name)
64 | assert result is True
65 | mock_check_output.assert_called_with(["ollama", "stop", model_name], stderr=-2, text=True)
66 |
67 |
68 | @patch("ShortsMaker.ask_llm.ollama.ps")
69 | def test_get_running_models(mock_ollama_ps, ollama_service_manager):
70 | mock_ollama_ps.return_value = ["model1", "model2"]
71 |
72 | result = ollama_service_manager.get_running_models()
73 | assert result == ["model1", "model2"]
74 |
75 |
76 | @patch("ShortsMaker.ask_llm.ollama.pull")
77 | def test_get_llm_model(mock_ollama_pull, ollama_service_manager):
78 | model_name = "test_model"
79 | mock_ollama_pull.return_value = "model_data"
80 |
81 | result = ollama_service_manager.get_llm_model(model_name)
82 | assert result == "model_data"
83 | mock_ollama_pull.assert_called_with(model_name)
84 |
85 |
86 | @patch("ShortsMaker.ask_llm.ollama.list")
87 | def test_get_list_of_downloaded_files(mock_ollama_list, ollama_service_manager):
88 | mock_ollama_list.return_value = [("models", [MagicMock(), MagicMock(), MagicMock()])]
89 | mock_ollama_list.return_value[0][1][0].model = "submodel1"
90 | mock_ollama_list.return_value[0][1][1].model = "submodel2"
91 | mock_ollama_list.return_value[0][1][2].model = "submodel3"
92 |
93 | result = ollama_service_manager.get_list_of_downloaded_files()
94 | assert result == ["submodel1", "submodel2", "submodel3"]
95 |
--------------------------------------------------------------------------------
/tests/conftest.py:
--------------------------------------------------------------------------------
1 | import logging
2 | from pathlib import Path
3 |
4 | import pytest
5 |
6 | from ShortsMaker import ShortsMaker
7 |
8 |
9 | @pytest.fixture
10 | def setup_file():
11 | return Path(__file__).parent / "data" / "setup.yml"
12 |
13 |
14 | @pytest.fixture
15 | def shorts_maker(setup_file):
16 | return ShortsMaker(setup_file)
17 |
18 |
19 | @pytest.fixture
20 | def mock_logger():
21 | return logging.getLogger("test_logger")
22 |
--------------------------------------------------------------------------------
/tests/data/setup.yml:
--------------------------------------------------------------------------------
1 | ---
2 | hugging_face_access_token: "random_token"
3 |
4 | cache_dir: "cache"
5 | assets_dir: "assets"
6 |
7 | retry:
8 | max_retries: 3
9 | delay: 5
10 | notify: False
11 |
12 | reddit_praw:
13 | client_id: blah
14 | client_secret: blah
15 | user_agent: blah
16 |
17 | reddit_post_getter:
18 | subreddit_name: RandomSubreddit
19 | record_file_json: post.json
20 | record_file_txt: post.txt
21 |
22 | audio:
23 | output_script_file: generated_audio_script.txt
24 | output_audio_file: output.wav
25 | transcript_json: transcript.json
26 | device: cuda
27 | model: large-v2
28 | batch_size: 16
29 | compute_type: float16
30 |
31 | # Random videos
32 | video:
33 | background_videos_urls:
34 | - https://www.youtube.com/watch?v=n_Dv4JMiwK8
35 | background_music_urls:
36 | - https://www.youtube.com/watch?v=G8a45UZJGh4&t=1s
37 | font_dir: fonts
38 | credits_dir: credits
39 |
--------------------------------------------------------------------------------
/tests/data/test.txt:
--------------------------------------------------------------------------------
1 | This is a test to generate audio.
2 |
--------------------------------------------------------------------------------
/tests/data/test.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rajathjn/shorts_maker/cb3e1823e3f271db08a969b7b7a8a5083d09df8d/tests/data/test.wav
--------------------------------------------------------------------------------
/tests/data/transcript.json:
--------------------------------------------------------------------------------
1 | [
2 | {
3 | "end": 0.352,
4 | "start": 0.151,
5 | "word": "This"
6 | },
7 | {
8 | "end": 0.552,
9 | "start": 0.472,
10 | "word": "is"
11 | },
12 | {
13 | "end": 0.633,
14 | "start": 0.613,
15 | "word": "a"
16 | },
17 | {
18 | "end": 0.974,
19 | "start": 0.693,
20 | "word": "test"
21 | },
22 | {
23 | "end": 1.114,
24 | "start": 1.014,
25 | "word": "to"
26 | },
27 | {
28 | "end": 1.636,
29 | "start": 1.154,
30 | "word": "generate"
31 | },
32 | {
33 | "end": 2.017,
34 | "start": 1.756,
35 | "word": "audio."
36 | }
37 | ]
38 |
--------------------------------------------------------------------------------
/tests/generate_image_tests/test_generate_image.py:
--------------------------------------------------------------------------------
1 | import tempfile
2 | from pathlib import Path
3 | from unittest.mock import MagicMock, patch
4 |
5 | import pytest
6 |
7 | from ShortsMaker import GenerateImage
8 |
9 |
10 | @pytest.fixture
11 | def generate_image(setup_file):
12 | """Fixture to initialize and return a GenerateImage instance.
13 |
14 | Args:
15 | setup_file: The configuration file for GenerateImage.
16 |
17 | Returns:
18 | GenerateImage: An instance of the GenerateImage class.
19 | """
20 | return GenerateImage(setup_file)
21 |
22 |
23 | def test_initialization_with_non_existent_file():
24 | with pytest.raises(FileNotFoundError):
25 | GenerateImage(config_file=Path("non_existent_file.yml"))
26 |
27 |
28 | def test_initialization_with_invalid_file_format():
29 | with pytest.raises(ValueError):
30 | temp_path = tempfile.NamedTemporaryFile(suffix=".txt", delete=False)
31 | temp_path = Path(temp_path.name)
32 | try:
33 | GenerateImage(config_file=temp_path)
34 | finally:
35 | temp_path.unlink()
36 |
37 |
38 | def test_load_model_failure(generate_image):
39 | with pytest.raises(RuntimeError):
40 | generate_image._load_model("model_id")
41 |
42 |
43 | @patch("ShortsMaker.generate_image.GenerateImage._load_model")
44 | def test_use_huggingface_flux_dev_success(mock_load_model, generate_image):
45 | mock_load_model.return_value = True
46 | generate_image.pipe = MagicMock()
47 | output = generate_image.use_huggingface_flux_dev(
48 | prompt="Random stuff", output_path="random_path.png"
49 | )
50 | assert output is not None
51 |
52 |
53 | @patch("ShortsMaker.generate_image.GenerateImage._load_model")
54 | def test_use_huggingface_flux_schnell_success(mock_load_model, generate_image):
55 | mock_load_model.return_value = True
56 | generate_image.pipe = MagicMock()
57 | output = generate_image.use_huggingface_flux_schnell(
58 | prompt="Random stuff", output_path="random_path.png"
59 | )
60 | assert output is not None
61 |
--------------------------------------------------------------------------------
/tests/shorts_maker_tests/test_abbreviation_replacer.py:
--------------------------------------------------------------------------------
1 | from ShortsMaker import abbreviation_replacer
2 |
3 |
4 | def test_abbreviation_replacer_basic_replacement():
5 | text = "This is an example ABB."
6 | abbreviation = "ABB"
7 | replacement = "abbreviation"
8 | result = abbreviation_replacer(text, abbreviation, replacement)
9 | assert result == "This is an example abbreviation."
10 |
11 |
12 | def test_abbreviation_replacer_replacement_with_padding():
13 | text = "This is an example ABB."
14 | abbreviation = "ABB"
15 | replacement = "abbreviation"
16 | padding = "."
17 | result = abbreviation_replacer(text, abbreviation, replacement, padding=padding)
18 | assert result == "This is an example abbreviation"
19 |
20 |
21 | def test_abbreviation_replacer_multiple_occurrences():
22 | text = "ABB is an ABBbreviation ABB."
23 | abbreviation = "ABB"
24 | replacement = "abbreviation"
25 | result = abbreviation_replacer(text, abbreviation, replacement)
26 | assert result == "abbreviation is an abbreviationbreviation abbreviation."
27 |
28 |
29 | def test_abbreviation_replacer_no_match():
30 | text = "No match here."
31 | abbreviation = "XYZ"
32 | replacement = "something"
33 | result = abbreviation_replacer(text, abbreviation, replacement)
34 | assert result == text
35 |
36 |
37 | def test_abbreviation_replacer_empty_string():
38 | text = ""
39 | abbreviation = "ABB"
40 | replacement = "abbreviation"
41 | result = abbreviation_replacer(text, abbreviation, replacement)
42 | assert result == ""
43 |
--------------------------------------------------------------------------------
/tests/shorts_maker_tests/test_fix_text.py:
--------------------------------------------------------------------------------
1 | def test_fix_text_basic(shorts_maker):
2 | source_txt = "This is a te st sentence."
3 | expected_output = "This is a test sentence."
4 | assert shorts_maker.fix_text(source_txt) == expected_output
5 |
6 |
7 | def test_fix_text_with_escape_characters(shorts_maker):
8 | source_txt = "This is a\t test sentence.\nThis is another test sentence.\r"
9 | expected_output = "This is a test sentence. This is another test sentence."
10 | assert shorts_maker.fix_text(source_txt) == expected_output
11 |
12 |
13 | def test_fix_text_with_punctuations(shorts_maker):
14 | source_txt = "Helllo!! How are you? I'm fine."
15 | expected_output = "Hello!! How are you? I'm fine."
16 | assert shorts_maker.fix_text(source_txt) == expected_output
17 |
18 |
19 | def test_fix_text_with_unicode(shorts_maker):
20 | source_txt = "Café is a Frnch word."
21 | expected_output = "Café is a French word."
22 | assert shorts_maker.fix_text(source_txt) == expected_output
23 |
24 |
25 | def test_fix_text_with_multiple_spaces(shorts_maker):
26 | source_txt = "This is a test."
27 | expected_output = "This is a test."
28 | assert shorts_maker.fix_text(source_txt) == expected_output
29 |
--------------------------------------------------------------------------------
/tests/shorts_maker_tests/test_has_alpha_and_digit.py:
--------------------------------------------------------------------------------
1 | from ShortsMaker import has_alpha_and_digit
2 |
3 |
4 | def test_has_alpha_and_digit_with_alphanumeric_input():
5 | assert has_alpha_and_digit("a1") is True
6 |
7 |
8 | def test_has_alpha_and_digit_with_only_alpha_input():
9 | assert has_alpha_and_digit("abc") is False
10 |
11 |
12 | def test_has_alpha_and_digit_with_only_digit_input():
13 | assert has_alpha_and_digit("1234") is False
14 |
15 |
16 | def test_has_alpha_and_digit_with_empty_string():
17 | assert has_alpha_and_digit("") is False
18 |
19 |
20 | def test_has_alpha_and_digit_with_special_characters():
21 | assert has_alpha_and_digit("a@1") is True
22 |
23 |
24 | def test_has_alpha_and_digit_with_whitespace():
25 | assert has_alpha_and_digit("a1 ") is True
26 |
27 |
28 | def test_has_alpha_and_digit_with_uppercase_letters():
29 | assert has_alpha_and_digit("A1") is True
30 |
--------------------------------------------------------------------------------
/tests/shorts_maker_tests/test_shorts_maker.py:
--------------------------------------------------------------------------------
1 | import os
2 | from pathlib import Path
3 | from unittest.mock import MagicMock, patch
4 |
5 | import pytest
6 | import yaml
7 |
8 | from ShortsMaker import ShortsMaker
9 |
10 |
11 | def test_validate_config_path_empty_config(tmp_path):
12 |     config_path = tmp_path / "config.yml"
13 |     config_path.touch()
14 |
15 |     # An empty YAML file parses to None, so ShortsMaker raises TypeError
16 |     with pytest.raises(TypeError):
17 |         ShortsMaker(config_path)
18 |
19 |
20 | def test_validate_config_path_invalid_extension(tmp_path):
21 | config_path = tmp_path / "config.txt"
22 | config_path.touch()
23 |
24 | with pytest.raises(ValueError):
25 | ShortsMaker(config_path)
26 |
27 |
28 | def test_validate_config_path_not_found():
29 | with pytest.raises(FileNotFoundError):
30 | ShortsMaker(Path("/nonexistent/config.yml"))
31 |
32 |
33 | @patch("ShortsMaker.ShortsMaker.is_unique_submission")
34 | @patch("praw.Reddit")
35 | def test_get_reddit_post(mock_reddit, mock_is_unique_submission, shorts_maker):
36 | mock_submission = MagicMock()
37 | mock_submission.title = "Test Title"
38 | mock_submission.selftext = "Test Content"
39 | mock_submission.name = "t3_123abc"
40 | mock_submission.id = "123abc"
41 | mock_submission.url = "https://reddit.com/r/test/123abc"
42 |
43 | mock_subreddit = MagicMock()
44 | mock_subreddit.hot.return_value = [mock_submission]
45 | mock_subreddit.title = "Test Subreddit"
46 | mock_subreddit.display_name = "test"
47 |
48 | mock_is_unique_submission.return_value = True
49 | mock_reddit.return_value.subreddit.return_value = mock_subreddit
50 |
51 | result = shorts_maker.get_reddit_post()
52 | assert "Test Title" in result
53 | assert "Test Content" in result
54 |
55 |
56 | @patch("praw.Reddit")
57 | def test_get_reddit_post_with_url(mock_reddit, shorts_maker):
58 | # Mock submission data
59 | mock_submission = MagicMock()
60 | mock_submission.title = "Test Title from URL"
61 | mock_submission.selftext = "Test Content from URL"
62 | mock_submission.name = "t3_url_submission"
63 | mock_submission.id = "url_submission"
64 | mock_submission.url = "https://www.reddit.com/r/random_subreddit/test_title_from_url/"
65 |
66 | # Mock Reddit API response
67 | mock_reddit.return_value.submission.return_value = mock_submission
68 |
69 | # Test with URL
70 | test_url = "https://www.reddit.com/r/random_subreddit/test_title_from_url/"
71 | result = shorts_maker.get_reddit_post(url=test_url)
72 |
73 | # Assertions
74 | assert "Test Title from URL" in result
75 | assert "Test Content from URL" in result
76 | mock_reddit.return_value.submission.assert_called_once_with(url=test_url)
77 |
78 |
79 | @patch("ShortsMaker.shorts_maker.tts")
80 | def test_generate_audio_success(mock_tts, shorts_maker, tmp_path):
81 | # Test successful audio generation
82 | source_text = "Test text for audio generation"
83 | output_audio = tmp_path / "test_output.wav"
84 | output_script = tmp_path / "test_script.txt"
85 | seed = "en_us_001"
86 |
87 | mock_tts.return_value = None
88 |
89 | result = shorts_maker.generate_audio(
90 | source_text, output_audio=output_audio, output_script_file=output_script, seed=seed
91 | )
92 |
93 | assert result is True
94 | mock_tts.assert_called_once()
95 |
96 |
97 | @patch("ShortsMaker.shorts_maker.tts")
98 | def test_generate_audio_failure(mock_tts, shorts_maker):
99 | # Test failed audio generation
100 | source_text = "Test text for audio generation"
101 | mock_tts.side_effect = Exception("TTS Failed")
102 |
103 | result = shorts_maker.generate_audio(source_text)
104 |
105 | assert result is False
106 | mock_tts.assert_called_once()
107 |
108 |
109 | @patch("secrets.choice")
110 | def test_generate_audio_random_speaker(mock_choice, shorts_maker):
111 | # Test random speaker selection when no seed provided
112 | mock_choice.return_value = "en_us_001"
113 | source_text = "Test text"
114 |
115 | with patch("ShortsMaker.shorts_maker.tts"):
116 | shorts_maker.generate_audio(source_text)
117 | mock_choice.assert_called_once()
118 |
119 |
120 | def test_generate_audio_text_processing(shorts_maker):
121 | # Test text processing functionality
122 | source_text = "Test123 text AITA for YTA"
123 | output_script = shorts_maker.cache_dir / "test_script.txt"
124 |
125 | with patch("ShortsMaker.shorts_maker.tts"):
126 | shorts_maker.generate_audio(source_text, output_script_file=output_script)
127 |
128 | with open(output_script) as f:
129 | processed_text = f.read()
130 |
131 | # Verify text processing
132 | assert "Test 123" in processed_text
133 | assert "Am I the asshole" in processed_text
134 | assert "You're the Asshole" in processed_text
135 |
136 |
137 | @patch("ShortsMaker.shorts_maker.generate_audio_transcription")
138 | def test_generate_audio_transcript(mock_transcription, shorts_maker, tmp_path):
139 | # Setup test files
140 | source_audio = tmp_path / "test.wav"
141 | source_audio.touch()
142 | source_text = tmp_path / "test.txt"
143 | source_text.write_text("Test transcript text")
144 | output_file = tmp_path / "transcript.json"
145 |
146 | # Setup mock transcription response
147 | mock_transcript = [
148 | {"word": "Test", "start": 0.1, "end": 0.3},
149 | {"word": "transcript", "start": 0.4, "end": 0.6},
150 | {"word": "text", "start": 0.7, "end": 0.9},
151 | ]
152 | mock_transcription.return_value = mock_transcript
153 |
154 | # Call function
155 | result = shorts_maker.generate_audio_transcript(
156 | source_audio, source_text, output_transcript_file=str(output_file), debug=False
157 | )
158 |
159 | # Verify mock was called with correct args
160 | mock_transcription.assert_called_once_with(
161 | audio_file=str(source_audio),
162 | script="Test transcript text",
163 | device=shorts_maker.audio_cfg["device"],
164 | model=shorts_maker.audio_cfg["model"],
165 | batch_size=shorts_maker.audio_cfg["batch_size"],
166 | compute_type=shorts_maker.audio_cfg["compute_type"],
167 | )
168 |
169 | # Verify result contains filtered transcript
170 | assert result == mock_transcript
171 |
172 | # Verify transcript was saved to file
173 | with open(output_file) as f:
174 |         saved_transcript = yaml.safe_load(f)  # JSON is valid YAML, so safe_load parses the JSON transcript
175 | assert saved_transcript == mock_transcript
176 |
177 |
178 | @patch("ShortsMaker.shorts_maker.generate_audio_transcription")
179 | def test_generate_audio_transcript_default_output(mock_generate_audio_transcription, shorts_maker):
180 | # Test default output file name generation
181 | source_audio = Path(__file__).parent.parent / "data" / "test.wav"
182 | source_text = Path(__file__).parent.parent / "data" / "test.txt"
183 | with open(Path(__file__).parent.parent / "data" / "transcript.json") as f:
184 | expected_transcript = yaml.safe_load(f)
185 |
186 | mock_generate_audio_transcription.return_value = expected_transcript
187 | result = shorts_maker.generate_audio_transcript(source_audio, source_text)
188 |
189 | expected_output = shorts_maker.cache_dir / shorts_maker.audio_cfg["transcript_json"]
190 | assert expected_output.exists()
191 | mock_generate_audio_transcription.assert_called_once()
192 | assert result == expected_transcript
193 |
194 |
195 | @pytest.mark.skipif("RUNALL" not in os.environ, reason="takes too long")
196 | def test_generate_audio_transcript_with_whisperx(shorts_maker):
197 |     # Run the real whisperx transcription end-to-end (no mocking)
198 | source_audio = Path(__file__).parent.parent / "data" / "test.wav"
199 | source_text = Path(__file__).parent.parent / "data" / "test.txt"
200 | with open(Path(__file__).parent.parent / "data" / "transcript.json") as f:
201 | expected_transcript = yaml.safe_load(f)
202 |
203 | result = shorts_maker.generate_audio_transcript(source_audio, source_text)
204 |
205 | expected_output = shorts_maker.cache_dir / shorts_maker.audio_cfg["transcript_json"]
206 | assert expected_output.exists()
207 | assert result == expected_transcript
208 |
209 |
210 | def test_filter_word_transcript(shorts_maker):
211 | test_transcript = [
212 | {"word": "valid", "start": 0.1, "end": 0.3},
213 |         {"word": "invalid1", "start": 0, "end": 0.5},  # expected to be filtered: start of 0
214 |         {"word": "invalid2", "start": 0.1, "end": 5.5},  # expected to be filtered: implausibly long duration
215 | {"word": "valid2", "start": 1.0, "end": 1.2},
216 | ]
217 |
218 | filtered = shorts_maker._filter_word_transcript(test_transcript)
219 |
220 | assert len(filtered) == 2
221 | assert filtered[0]["word"] == "valid"
222 | assert filtered[1]["word"] == "valid2"
223 |
--------------------------------------------------------------------------------
/tests/shorts_maker_tests/test_split_alpha_and_digit.py:
--------------------------------------------------------------------------------
1 | from ShortsMaker import split_alpha_and_digit
2 |
3 |
4 | def test_split_alpha_and_digit_with_alphanumeric_word():
5 | result = split_alpha_and_digit("abc123")
6 | assert result == "abc 123"
7 |
8 |
9 | def test_split_alpha_and_digit_with_digits_and_letters_interleaved():
10 | result = split_alpha_and_digit("a1b2c3")
11 | assert result == "a 1 b 2 c 3"
12 |
13 |
14 | def test_split_alpha_and_digit_with_only_letters():
15 | result = split_alpha_and_digit("abcdef")
16 | assert result == "abcdef"
17 |
18 |
19 | def test_split_alpha_and_digit_with_only_digits():
20 | result = split_alpha_and_digit("123456")
21 | assert result == "123456"
22 |
23 |
24 | def test_split_alpha_and_digit_with_empty_string():
25 | result = split_alpha_and_digit("")
26 | assert result == ""
27 |
28 |
29 | def test_split_alpha_and_digit_with_special_characters():
30 | result = split_alpha_and_digit("a1!b2@")
31 | assert result == "a 1! b 2@"
32 |
33 |
34 | def test_split_alpha_and_digit_with_spaces_and_tabs():
35 | result = split_alpha_and_digit("a1 b2\tc3")
36 | assert result == "a 1 b 2\t c 3"
37 |
38 |
39 | def test_split_alpha_and_digit_with_uppercase_and_numbers():
40 | result = split_alpha_and_digit("ABC123")
41 | assert result == "ABC 123"
42 |
--------------------------------------------------------------------------------
/tests/utils_tests/test_audio_transcript.py:
--------------------------------------------------------------------------------
1 | from unittest.mock import MagicMock, patch
2 |
3 | import pytest
4 |
5 | from ShortsMaker.utils.audio_transcript import (
6 | align_transcript_with_script,
7 | generate_audio_transcription,
8 | )
9 |
10 |
11 | def test_align_transcript_with_script_basic():
12 | transcript = [
13 | {"text": "hello world", "start": 0.0, "end": 1.0},
14 | {"text": "how are you", "start": 1.0, "end": 2.0},
15 | ]
16 | script = "hello world how are you today"
17 |
18 | result = align_transcript_with_script(transcript, script)
19 |
20 | assert len(result) == 2
21 | assert result[0]["text"] == "hello world"
22 | assert result[1]["text"] == "how are you"
23 | assert result[0]["start"] == 0.0
24 | assert result[0]["end"] == 1.0
25 | assert result[1]["start"] == 1.0
26 | assert result[1]["end"] == 2.0
27 |
28 |
29 | def test_align_transcript_with_script_partial_match():
30 | transcript = [
31 | {"text": "helo wrld", "start": 0.0, "end": 1.0}, # Misspelled
32 | {"text": "how r u", "start": 1.0, "end": 2.0}, # Text speak
33 | ]
34 | script = "hello world how are you"
35 |
36 | result = align_transcript_with_script(transcript, script)
37 |
38 | assert len(result) == 2
39 | assert result[0]["text"] == "hello world"
40 | assert result[1]["text"] == "how"
41 |
42 |
43 | def test_align_transcript_empty_inputs():
44 | transcript = []
45 | script = "hello world"
46 |
47 | result = align_transcript_with_script(transcript, script)
48 | assert result == []
49 |
50 | transcript = [{"text": "hello", "start": 0.0, "end": 1.0}]
51 | script = ""
52 |
53 | result = align_transcript_with_script(transcript, script)
54 | assert len(result) == 1
55 | assert result[0]["text"] == "hello"
56 |
57 |
58 | def test_align_transcript_with_longer_script():
59 | transcript = [
60 | {"text": "this is", "start": 0.0, "end": 1.0},
61 | {"text": "a test", "start": 1.0, "end": 2.0},
62 | ]
63 | script = "this is a test with extra words at the end"
64 |
65 | result = align_transcript_with_script(transcript, script)
66 |
67 | assert len(result) == 2
68 | assert result[0]["text"] == "this is"
69 | assert result[1]["text"] == "a test"
70 |     assert all("start" in entry and "end" in entry for entry in result)
71 |
72 |
73 | def test_align_transcript_timing_preserved():
74 | transcript = [
75 | {"text": "first segment", "start": 1.5, "end": 2.5},
76 | {"text": "second segment", "start": 2.5, "end": 3.5},
77 | ]
78 | script = "first segment second segment"
79 |
80 | result = align_transcript_with_script(transcript, script)
81 |
82 | assert len(result) == 2
83 | assert result[0]["start"] == 1.5
84 | assert result[0]["end"] == 2.5
85 | assert result[1]["start"] == 2.5
86 | assert result[1]["end"] == 3.5
87 |
88 |
89 | @pytest.fixture
90 | def mock_whisperx():
91 | with patch("ShortsMaker.utils.audio_transcript.whisperx") as mock_wx:
92 | # Mock the load_model and model.transcribe
93 | mock_model = MagicMock()
94 | mock_model.transcribe.return_value = {
95 | "segments": [{"text": "hello world", "start": 0.0, "end": 1.0}],
96 | "language": "en",
97 | }
98 | mock_wx.load_model.return_value = mock_model
99 |
100 | # Mock load_audio
101 | mock_wx.load_audio.return_value = "audio_data"
102 |
103 | # Mock alignment model and align
104 | mock_align_model = MagicMock()
105 | mock_metadata = MagicMock()
106 | mock_wx.load_align_model.return_value = (mock_align_model, mock_metadata)
107 | mock_wx.align.return_value = {
108 | "segments": [
109 | {
110 | "words": [
111 | {"word": "hello", "start": 0.0, "end": 0.5},
112 | {"word": "world", "start": 0.5, "end": 1.0},
113 | ]
114 | }
115 | ]
116 | }
117 | yield mock_wx
118 |
119 |
120 | def test_generate_audio_transcription_basic(mock_whisperx):
121 | # Test basic functionality
122 | result = generate_audio_transcription(
123 | audio_file="test.wav", script="hello world", device="cpu", batch_size=8
124 | )
125 |
126 | assert len(result) == 2
127 | assert result[0]["word"] == "hello"
128 | assert result[0]["start"] == 0.0
129 | assert result[0]["end"] == 0.5
130 | assert result[1]["word"] == "world"
131 | assert result[1]["start"] == 0.5
132 | assert result[1]["end"] == 1.0
133 |
134 |
135 | def test_generate_audio_transcription_model_cleanup(mock_whisperx):
136 | with patch("ShortsMaker.utils.audio_transcript.gc") as mock_gc:
137 | with patch("ShortsMaker.utils.audio_transcript.torch") as mock_torch:
138 | mock_torch.cuda.is_available.return_value = True
139 |
140 | generate_audio_transcription(audio_file="test.wav", script="hello world")
141 |
142 | # Verify cleanup calls
143 | assert mock_gc.collect.call_count == 2
144 | assert mock_torch.cuda.empty_cache.call_count == 2
145 |
146 |
147 | def test_generate_audio_transcription_missing_timestamps(mock_whisperx):
148 | # Mock align to return words without timestamps
149 | mock_whisperx.align.return_value = {
150 | "segments": [{"words": [{"word": "hello"}, {"word": "world", "start": 0.5, "end": 1.0}]}]
151 | }
152 |
153 | result = generate_audio_transcription(audio_file="test.wav", script="hello world")
154 |
155 | assert len(result) == 2
156 | assert result[0]["word"] == "hello"
157 | assert "start" in result[0]
158 | assert "end" in result[0]
159 | assert result[1]["word"] == "world"
160 | assert result[1]["start"] == 0.5
161 | assert result[1]["end"] == 1.0
162 |
163 |
164 | def test_generate_audio_transcription_parameters(mock_whisperx):
165 | # Test if parameters are correctly passed
166 | generate_audio_transcription(
167 | audio_file="test.wav",
168 | script="test script",
169 | device="test_device",
170 | batch_size=32,
171 | compute_type="float32",
172 | model="medium",
173 | )
174 |
175 | mock_whisperx.load_model.assert_called_with("medium", "test_device", compute_type="float32")
176 |
177 | mock_whisperx.load_align_model.assert_called_with(language_code="en", device="test_device")
178 |
--------------------------------------------------------------------------------
/tests/utils_tests/test_colors_dict.py:
--------------------------------------------------------------------------------
1 | from ShortsMaker import COLORS_DICT
2 |
3 |
4 | def test_colors_dict_structure():
5 | # Test that COLORS_DICT is a dictionary
6 | assert isinstance(COLORS_DICT, dict)
7 |
8 | # Test that all values are RGBA tuples
9 | for color_value in COLORS_DICT.values():
10 | assert isinstance(color_value, tuple)
11 | assert len(color_value) == 4
12 | for component in color_value:
13 | assert isinstance(component, int)
14 | assert 0 <= component <= 255
15 |
16 |
17 | def test_common_colors_present():
18 |     # Test that some common colors are present (and that "black" is absent from the dict)
19 | assert "white" in COLORS_DICT
20 | assert "black" not in COLORS_DICT
21 | assert "red" in COLORS_DICT
22 | assert "blue" in COLORS_DICT
23 | assert "yellow" in COLORS_DICT
24 |
25 |
26 | def test_color_values():
27 | # Test specific color values
28 | assert COLORS_DICT["white"] == (255, 255, 255, 255)
29 | assert COLORS_DICT["yellow"] == (255, 255, 0, 255)
30 | assert COLORS_DICT["cyan"] == (0, 255, 255, 255)
31 | assert COLORS_DICT["magenta"] == (255, 0, 255, 255)
32 |
--------------------------------------------------------------------------------
/tests/utils_tests/test_download_youtube_music.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 | from unittest.mock import MagicMock, patch
3 |
4 | import pytest
5 |
6 | from ShortsMaker.utils.download_youtube_music import download_youtube_music, sanitize_filename
7 |
8 |
9 | @pytest.fixture
10 | def mock_music_dir(tmp_path):
11 | return tmp_path / "music"
12 |
13 |
14 | @pytest.fixture
15 | def mock_ydl_no_chapters():
16 | mock = MagicMock()
17 | mock.extract_info.return_value = {"title": "test_song"}
18 | mock.sanitize_info.return_value = {"title": "test_song", "chapters": None}
19 | return mock
20 |
21 |
22 | @pytest.fixture
23 | def mock_ydl_with_chapters():
24 | mock = MagicMock()
25 | mock.extract_info.return_value = {"title": "test_song_chapters"}
26 | mock.sanitize_info.return_value = {
27 | "title": "test_song_chapters",
28 | "chapters": [
29 | {"title": "Chapter 1", "start_time": 0, "end_time": 60},
30 | {"title": "Chapter 2", "start_time": 60, "end_time": 120},
31 | ],
32 | }
33 | return mock
34 |
35 |
36 | @pytest.mark.parametrize("force", [True, False])
37 | def test_download_without_chapters(mock_music_dir, mock_ydl_no_chapters, force):
38 | with patch("yt_dlp.YoutubeDL") as mock_ydl_class:
39 | mock_ydl_class.return_value.__enter__.return_value = mock_ydl_no_chapters
40 |
41 | result = download_youtube_music("https://youtube.com/test", mock_music_dir, force=force)
42 |
43 | assert isinstance(result, list)
44 | assert isinstance(result[0], Path)
45 | assert result[0].name == "test_song.wav"
46 | mock_ydl_no_chapters.download.assert_called_once()
47 |
48 |
49 | def test_download_with_chapters(mock_music_dir, mock_ydl_with_chapters):
50 | with patch("yt_dlp.YoutubeDL") as mock_ydl_class:
51 | mock_ydl_class.return_value.__enter__.return_value = mock_ydl_with_chapters
52 |
53 | result = download_youtube_music("https://youtube.com/test", mock_music_dir)
54 |
55 | assert isinstance(result, list)
56 |         assert len(result) >= 0  # files are only collected at the end, so the list may be empty
57 |         assert mock_ydl_with_chapters.download.call_count == 2  # one download call per chapter
58 |
59 |
60 | def test_download_with_existing_files(mock_music_dir, mock_ydl_no_chapters):
61 | # Create existing file
62 | mock_music_dir.mkdir(parents=True)
63 | existing_file = mock_music_dir / "test_song.wav"
64 | existing_file.touch()
65 |
66 | with patch("yt_dlp.YoutubeDL") as mock_ydl_class:
67 | mock_ydl_class.return_value.__enter__.return_value = mock_ydl_no_chapters
68 |
69 | # Should not download when force=False
70 | result = download_youtube_music("https://youtube.com/test", mock_music_dir, force=False)
71 |
72 | assert isinstance(result, list)
73 | assert isinstance(result[0], Path)
74 | assert not mock_ydl_no_chapters.download.called
75 |
76 | # Should download when force=True
77 | result = download_youtube_music("https://youtube.com/test", mock_music_dir, force=True)
78 |
79 | assert isinstance(result, list)
80 | assert isinstance(result[0], Path)
81 | mock_ydl_no_chapters.download.assert_called_once()
82 |
83 |
84 | @pytest.mark.parametrize(
85 |     "filename, expected_filename",
86 | [
87 | ("normal filename", "normal_filename"),
88 | (" spaces ", "spaces"),
89 | ("file.name.", "file.name"),
90 | ('invalid<>:"/\\|?*chars', "invalid_________chars"),
91 | ("mixed case FiLe", "mixed_case_FiLe"),
92 | (" leading.trailing. ", "leading.trailing"),
93 | ("file with spaces", "file_with_spaces"),
94 | ("file*with?invalid:chars", "file_with_invalid_chars"),
95 | ],
96 | )
97 | def test_sanitize_filename(filename, expected_filename):
98 |     assert sanitize_filename(filename) == expected_filename
99 |
--------------------------------------------------------------------------------
/tests/utils_tests/test_download_youtube_video.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 | from unittest.mock import Mock, patch
3 |
4 | import pytest
5 |
6 | from ShortsMaker.utils.download_youtube_video import download_youtube_video
7 |
8 |
9 | @pytest.fixture
10 | def mock_ydl():
11 | with patch("yt_dlp.YoutubeDL") as mock:
12 | mock_instance = Mock()
13 | mock.return_value.__enter__.return_value = mock_instance
14 | mock_instance.extract_info.return_value = {"title": "test_video"}
15 | mock_instance.sanitize_info.return_value = {"title": "test_video"}
16 | mock_instance.prepare_filename.return_value = "test_video.mp4"
17 | yield mock_instance
18 |
19 |
20 | @pytest.fixture
21 | def tmp_path_with_video(tmp_path):
22 | video_file = tmp_path / "test_video.mp4"
23 | video_file.touch()
24 | return tmp_path
25 |
26 |
27 | def test_download_video_success(mock_ydl, tmp_path):
28 | url = "https://www.youtube.com/watch?v=test123"
29 |
30 | result = download_youtube_video(url, tmp_path)
31 |
32 | mock_ydl.extract_info.assert_called_once_with(url, download=False)
33 | mock_ydl.download.assert_called_once_with([url])
34 | assert isinstance(result, list)
35 | assert all(isinstance(p, Path) for p in result)
36 |
37 |
38 | def test_download_video_existing_file(mock_ydl, tmp_path_with_video):
39 | url = "https://www.youtube.com/watch?v=test123"
40 |
41 | result = download_youtube_video(url, tmp_path_with_video)
42 |
43 | mock_ydl.extract_info.assert_called_once_with(url, download=False)
44 | mock_ydl.download.assert_not_called()
45 | assert isinstance(result, list)
46 | assert len(result) == 1
47 | assert all(isinstance(p, Path) for p in result)
48 |
49 |
50 | def test_download_video_force(mock_ydl, tmp_path_with_video):
51 | url = "https://www.youtube.com/watch?v=test123"
52 |
53 | result = download_youtube_video(url, tmp_path_with_video, force=True)
54 |
55 | mock_ydl.extract_info.assert_called_once_with(url, download=False)
56 | mock_ydl.download.assert_called_once_with([url])
57 | assert isinstance(result, list)
58 | assert all(isinstance(p, Path) for p in result)
59 |
60 |
61 | def test_download_video_no_files(mock_ydl, tmp_path):
62 | url = "https://www.youtube.com/watch?v=test123"
63 | mock_ydl.prepare_filename.return_value = str(tmp_path / "nonexistent.mp4")
64 |
65 | result = download_youtube_video(url, tmp_path)
66 |
67 | assert isinstance(result, list)
68 | assert len(result) == 0
69 |
--------------------------------------------------------------------------------
/tests/utils_tests/test_get_tts.py:
--------------------------------------------------------------------------------
1 | from unittest.mock import Mock, patch
2 |
3 | import pytest
4 |
5 | from ShortsMaker.utils.get_tts import (
6 | ENDPOINT_DATA,
7 | VOICES,
8 | _process_chunks,
9 | _split_text,
10 | _validate_inputs,
11 | tts,
12 | )
13 |
14 |
15 | @pytest.fixture
16 | def mock_audio_segment():
17 | with patch("pydub.AudioSegment.from_file") as mock:
18 | yield mock
19 |
20 |
21 | def test_validate_inputs_valid():
22 | _validate_inputs("test text", VOICES[0])
23 | # Should not raise any exception
24 |
25 |
26 | def test_validate_inputs_invalid_voice():
27 | with pytest.raises(ValueError, match="voice must be valid"):
28 | _validate_inputs("test text", "invalid_voice")
29 |
30 |
31 | def test_validate_inputs_empty_text():
32 | with pytest.raises(ValueError, match="text must not be 'None'"):
33 | _validate_inputs("", VOICES[0])
34 |
35 |
36 | def test_split_text():
37 | text = "This is a test text that needs to be split into multiple chunks"
38 | chunks = _split_text(text, chunk_size=20)
39 | assert len(chunks) > 1
40 | assert all(len(chunk) <= 20 for chunk in chunks)
41 |
42 |
43 | @patch("requests.post")
44 | def test_process_chunks_success(mock_post):
45 | mock_response = Mock()
46 | mock_response.status_code = 200
47 | mock_response.json.return_value = {"data": "fake_base64_data"}
48 | mock_post.return_value = mock_response
49 |
50 | chunks = ["test chunk"]
51 | endpoint = {"url": "test_url", "response": "data"}
52 | voice = VOICES[0]
53 | audio_data = [""]
54 |
55 | result = _process_chunks(chunks, endpoint, voice, audio_data)
56 | assert result == ["fake_base64_data"]
57 |
58 |
59 | @patch("requests.post")
60 | def test_process_chunks_failure(mock_post):
61 | mock_response = Mock()
62 | mock_response.status_code = 404
63 | mock_post.return_value = mock_response
64 |
65 | chunks = ["test chunk"]
66 | endpoint = {"url": "test_url", "response": "data"}
67 | voice = VOICES[0]
68 | audio_data = [""]
69 |
70 | result = _process_chunks(chunks, endpoint, voice, audio_data)
71 | assert result is None
72 |
73 |
74 | @patch("pydub.AudioSegment.from_file")
75 | @patch("requests.post")
76 | def test_tts_integration(mock_post, mock_audio):
77 | mock_response = Mock()
78 | mock_response.status_code = 200
79 |     mock_response.json.return_value = {"data": "ZmFrZV9iYXNlNjRfZGF0YQ=="}  # base64 of "fake_base64_data"
80 | mock_post.return_value = mock_response
81 |
82 | mock_audio_obj = Mock()
83 | mock_audio_obj.export = Mock()
84 | mock_audio.return_value = mock_audio_obj
85 |
86 | tts("test text", VOICES[0], "test_output.wav")
87 | assert mock_audio_obj.export.called
88 |
89 |
90 | @patch("requests.post")
91 | @patch("ShortsMaker.utils.get_tts._save_audio")
92 | def test_tts_with_failing_and_successful_endpoints(mock_save_audio, mock_post):
93 | # Mock responses for endpoints
94 | def mock_post_side_effect(url, json, headers) -> Mock:
95 | if url == ENDPOINT_DATA[0]["url"]:
96 | # Simulate failure for the first endpoint
97 | response = Mock()
98 | response.status_code = 500
99 | return response
100 | elif url == ENDPOINT_DATA[1]["url"]:
101 | # Simulate success for the second endpoint
102 | response = Mock()
103 | response.status_code = 200
104 | response.json.return_value = {ENDPOINT_DATA[1]["response"]: "mock_audio_data"}
105 | return response
106 |
107 | mock_post.side_effect = mock_post_side_effect
108 |
109 | mock_save_audio.return_value = None
110 |
111 | # Input data
112 | text = "This is a test."
113 | voice = "en_us_001"
114 | output_filename = "test_output.mp3"
115 |
116 | # Call the tts function
117 | tts(text, voice, output_filename)
118 |
119 | # Assertions
120 | assert mock_post.call_count == 2
121 | mock_post.assert_any_call(
122 | ENDPOINT_DATA[0]["url"],
123 | json={"text": "This is a test.", "voice": voice},
124 | headers={
125 | "User-Agent": "com.zhiliaoapp.musically/2022600030 (Linux; U; Android 7.1.2; es_ES; SM-G988N; Build/NRD90M;tt-ok/3.12.13.1)",
126 | },
127 | )
128 | mock_post.assert_any_call(
129 | ENDPOINT_DATA[1]["url"],
130 | json={"text": "This is a test.", "voice": voice},
131 | headers={
132 | "User-Agent": "com.zhiliaoapp.musically/2022600030 (Linux; U; Android 7.1.2; es_ES; SM-G988N; Build/NRD90M;tt-ok/3.12.13.1)",
133 | },
134 | )
135 | mock_save_audio.assert_called_once_with(["mock_audio_data"], output_filename)
136 |
--------------------------------------------------------------------------------
/tests/utils_tests/test_logging_config.py:
--------------------------------------------------------------------------------
1 | import logging
2 | from unittest.mock import patch
3 |
4 | import pytest
5 |
6 | import ShortsMaker
7 | import ShortsMaker.utils
8 | from ShortsMaker.utils.logging_config import (
9 | LOGGERS,
10 | configure_logging,
11 | get_logger,
12 | )
13 |
14 |
15 | @pytest.fixture
16 | def reset_logging_state():
17 | """Reset global logging state between tests"""
18 |     # 'global' would only rebind names in this test module; reset the real module state instead
19 |     ShortsMaker.utils.logging_config.INITIALIZED = False
20 |     LOGGERS.clear()
21 |     yield
22 |     ShortsMaker.utils.logging_config.INITIALIZED = False
23 | LOGGERS.clear()
24 |
25 |
26 | def test_get_logger_creates_new_logger(reset_logging_state):
27 | logger = get_logger("test_logger")
28 | assert isinstance(logger, logging.Logger)
29 | assert logger.name == "test_logger"
30 | assert len(logger.handlers) == 2
31 | assert not logger.propagate
32 |
33 |
34 | def test_get_logger_returns_cached_logger(reset_logging_state):
35 | logger1 = get_logger("test_logger")
36 | logger2 = get_logger("test_logger")
37 | assert logger1 is logger2
38 |
39 |
40 | def test_configure_logging_updates_settings(reset_logging_state, tmp_path):
41 | test_log_file = tmp_path / "test.log"
42 | test_level = "INFO"
43 | configure_logging(log_file=test_log_file, level=test_level, enable=True)
44 | assert ShortsMaker.utils.logging_config.LOG_FILE == test_log_file
45 | assert ShortsMaker.utils.logging_config.LOG_LEVEL == test_level
46 | assert ShortsMaker.utils.logging_config.LOGGING_ENABLED is True
47 | assert ShortsMaker.utils.logging_config.INITIALIZED is True
48 |
49 |
50 | def test_configure_logging_updates_existing_loggers(reset_logging_state):
51 | logger = get_logger("test_logger")
52 |
53 | configure_logging(level="INFO", enable=True)
54 | assert logger.level == logging.INFO
55 |
56 | configure_logging(enable=False)
57 | assert logger.level == logging.CRITICAL
58 |
59 |
60 | @patch("pathlib.Path.mkdir")
61 | def test_configure_logging_creates_directory(mock_mkdir):
62 | configure_logging()
63 | mock_mkdir.assert_called_once_with(parents=True, exist_ok=True)
64 |
65 |
66 | def test_disabled_logging_sets_critical_level(reset_logging_state):
67 | configure_logging(enable=False)
68 | logger = get_logger("test_logger")
69 | assert logger.level == logging.CRITICAL
70 |
--------------------------------------------------------------------------------
/tests/utils_tests/test_notify_discord.py:
--------------------------------------------------------------------------------
1 | from unittest.mock import MagicMock, patch
2 |
3 | import pytest
4 | from requests import Response
5 |
6 | from ShortsMaker.utils.notify_discord import get_arthas, get_meme, notify_discord
7 |
8 |
9 | @pytest.fixture
10 | def mock_response():
11 | response = MagicMock(spec=Response)
12 | response.status_code = 200
13 | response.text = "Success"
14 | return response
15 |
16 |
17 | @pytest.fixture
18 | def mock_webhook():
19 | with patch("ShortsMaker.utils.notify_discord.DiscordWebhook") as mock:
20 | webhook = MagicMock()
21 | webhook.execute.return_value = MagicMock(spec=Response)
22 | mock.return_value = webhook
23 | yield mock
24 |
25 |
26 | @pytest.fixture
27 | def mock_get_meme():
28 | with patch("ShortsMaker.utils.notify_discord.get_meme") as mock:
29 | mock.return_value = "http://fake-meme.com/image.jpg"
30 | yield mock
31 |
32 |
33 | @pytest.fixture
34 | def mock_get_arthas():
35 | with patch("ShortsMaker.utils.notify_discord.get_arthas") as mock:
36 | mock.return_value = "http://fake-arthas.com/image.jpg"
37 | yield mock
38 |
39 |
40 | def test_get_arthas(requests_mock):
41 | mock_html = """
42 |