├── .gitignore ├── README.md ├── requirements.txt └── tune.py /.gitignore: -------------------------------------------------------------------------------- 1 | /.DS_Store 2 | /downloaded_videos 3 | /downloaded_audio 4 | /extracted_frames 5 | /extracted_frames.zip 6 | /venv 7 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # YouTune 2 | 3 | ![DALL·E 2023-10-26 19 39 10 - Photo of a sophisticated logo for 'YouTune' The design emphasizes the theme of tuning videos It incorporates a stylized film reel that intertwines w](https://github.com/cbh123/youtune/assets/14149230/808411aa-cecc-4735-a2dd-18344317601a) 4 | 5 | [YouTune Video Walkthrough](https://www.loom.com/share/193fa040b8074f44bb5ddabd4dd42b01?sid=4b09aa1b-5cd6-4e4f-a538-d3d62cb1bdc0) 6 | 7 | YouTune makes it really easy to: 8 | 9 | - fine-tune SDXL on images from YouTube videos 10 | - fine-tune MusicGen on audio from YouTube videos 11 | 12 | ## Fine-tuning SDXL 13 | 14 | Just give it a URL and a model name on Replicate, and it’ll download the video, take screenshots of every 50 frames, remove near duplicates and very light/dark images, and create a training for you. 15 | 16 | ```bash 17 | python tune.py 18 | ``` 19 | 20 | ## Fine-tuning MusicGen 21 | 22 | With `--audio`, it’ll download just the audio, convert it to mp3 and create a training for you. 23 | 24 | ```bash 25 | python tune.py --audio 26 | ``` 27 | 28 | ## Setup 29 | 30 | Clone this repo, and setup and activate a virtualenv: 31 | 32 | ```bash 33 | python3 -m pip install virtualenv 34 | python3 -m virtualenv venv 35 | source venv/bin/activate 36 | ``` 37 | 38 | Then, install the dependencies: 39 | `pip install -r requirements.txt` 40 | 41 | Make a [Replicate](https://replicate.com) account and set your token: 42 | 43 | `export REPLICATE_API_TOKEN=` 44 | 45 | ## Run it! 46 | 47 | ```bash 48 | python tune.py 49 | ``` 50 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | annotated-types==0.6.0 2 | anyio==4.0.0 3 | certifi==2023.7.22 4 | h11==0.14.0 5 | httpcore==0.18.0 6 | httpx==0.25.0 7 | idna==3.4 8 | ImageHash==4.3.1 9 | moviepy==1.0.3 10 | numpy==1.26.1 11 | opencv-python==4.8.1.78 12 | packaging==23.2 13 | Pillow==10.1.0 14 | pydantic==2.4.2 15 | pydantic_core==2.10.1 16 | pytube==15.0.0 17 | PyWavelets==1.4.1 18 | replicate==0.15.4 19 | scipy==1.11.3 20 | sniffio==1.3.0 21 | typing_extensions==4.8.0 22 | -------------------------------------------------------------------------------- /tune.py: -------------------------------------------------------------------------------- 1 | from pytube import YouTube 2 | from moviepy.editor import AudioFileClip 3 | import cv2 4 | import os 5 | from PIL import Image 6 | import imagehash 7 | import argparse 8 | import numpy as np 9 | import platform 10 | import subprocess 11 | import zipfile 12 | import webbrowser 13 | import os 14 | import re 15 | from numpy.fft import fftshift, fft2 16 | 17 | 18 | def download_youtube_video(url, save_path="", audio_only=False): 19 | print(f"Downloading {'audio' if audio_only else 'video'} from {url} ...") 20 | yt = YouTube(url) 21 | if audio_only: 22 | stream = yt.streams.get_audio_only() 23 | else: 24 | stream = ( 25 | yt.streams.filter(progressive=True, file_extension="mp4") 26 | .order_by("resolution") 27 | .desc() 28 | .first() 29 | ) 30 | if not save_path: 31 | save_path = stream.default_filename 32 | 33 | stream.download(output_path=save_path) 34 | print(f"{'Audio' if audio_only else 'Video'} downloaded: {stream.default_filename}") 35 | 36 | return os.path.join(save_path, stream.default_filename) 37 | 38 | 39 | def is_mostly_black_or_white(image, threshold, white_threshold=225): 40 | """ 41 | Check if the given image is mostly black or white. 42 | :param image: Image to be checked. 43 | :param threshold: Threshold below which the image is considered 'black'. 44 | :param white_threshold: Threshold above which the image is considered 'white'. 45 | :return: True if the image is mostly black or white; False otherwise. 46 | """ 47 | # Convert image to grayscale 48 | gray_image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY) 49 | 50 | # Calculate the average brightness of the image 51 | average_brightness = cv2.mean(gray_image)[0] 52 | 53 | return average_brightness < threshold or average_brightness > white_threshold 54 | 55 | 56 | def extract_frames( 57 | video_path, 58 | frame_interval=50, 59 | save_path="", 60 | black_white_threshold=10, 61 | hash_func=imagehash.average_hash, 62 | hash_size=8, 63 | hash_diff_threshold=10, 64 | remove_blur=True, 65 | motion_blur_threshold=-0.02, 66 | ): 67 | if save_path and not os.path.exists(save_path): 68 | os.makedirs(save_path) 69 | 70 | cap = cv2.VideoCapture(video_path) 71 | 72 | if not cap.isOpened(): 73 | print("Error: Could not open video.") 74 | return 75 | 76 | total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) 77 | fps = cap.get(cv2.CAP_PROP_FPS) 78 | print(f"Total frames: {total_frames}, FPS: {fps}") 79 | 80 | frame_index = 0 81 | saved_frame_count = 0 82 | previous_hash = None 83 | 84 | while True: 85 | ret, frame = cap.read() 86 | if ret: 87 | if frame_index % frame_interval == 0: 88 | pil_image = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)) 89 | 90 | is_black_or_white = is_mostly_black_or_white( 91 | frame, threshold=black_white_threshold 92 | ) 93 | current_hash = hash_func(pil_image, hash_size=hash_size) 94 | 95 | is_blurry, blur_score = detect_motion_blur(frame, motion_blur_threshold) 96 | 97 | should_save = ( 98 | True # Initialize the flag assuming the frame will be saved 99 | ) 100 | 101 | if previous_hash is not None: 102 | hash_diff = current_hash - previous_hash 103 | if hash_diff < hash_diff_threshold: 104 | should_save = False # If frames are similar, do not save 105 | print( 106 | f"Skipping frame {frame_index} due to perceived duplication (hash diff: {hash_diff})" 107 | ) 108 | 109 | if is_black_or_white: 110 | should_save = ( 111 | False # If frame is mostly black or white, do not save 112 | ) 113 | print( 114 | f"Skipping frame {frame_index} because it's mostly black or white" 115 | ) 116 | 117 | if remove_blur and is_blurry: 118 | should_save = False 119 | print( 120 | f"Skipping frame {frame_index} because it's blurry (blur score: {blur_score}))" 121 | ) 122 | 123 | if should_save: 124 | save_filename = os.path.join(save_path, f"frame_{frame_index}.jpg") 125 | cv2.imwrite(save_filename, frame) 126 | saved_frame_count += 1 127 | 128 | previous_hash = current_hash 129 | 130 | frame_index += 1 131 | else: 132 | break 133 | 134 | cap.release() 135 | print( 136 | f"Done extracting frames. {saved_frame_count} images are saved in '{save_path}'." 137 | ) 138 | 139 | 140 | def open_file_explorer(path): 141 | """ 142 | Open the file explorer at the specified path. 143 | """ 144 | if platform.system() == "Windows": 145 | os.startfile(path) 146 | elif platform.system() == "Darwin": # macOS 147 | subprocess.Popen(["open", path]) 148 | else: # linux 149 | subprocess.Popen(["xdg-open", path]) 150 | 151 | 152 | def user_image_confirmation(path): 153 | """ 154 | Prompt the user to check the images before proceeding. 155 | """ 156 | confirmation = input("Press Enter to open finder to check the images.") 157 | open_file_explorer( 158 | path 159 | ) # This will open the file explorer so you can check the images 160 | 161 | confirmation = input( 162 | "Have you checked the images and do you want to proceed with posting a training? (y/n): " 163 | ) 164 | return confirmation.lower() == "y" 165 | 166 | 167 | def user_audio_confirmation(path): 168 | """ 169 | Prompt the user to check the audio before proceeding. 170 | """ 171 | confirmation = input("Press Enter to check the audio file") 172 | open_file_explorer(path) 173 | 174 | confirmation = input( 175 | "Have you checked the audio and do you want to proceed with posting a training? (y/n): " 176 | ) 177 | return confirmation.lower() == "y" 178 | 179 | 180 | def user_model(): 181 | """ 182 | Prompt the user to input the fine tune model name 183 | """ 184 | does_model_exist = input("Have you already created the model on Replicate? (y/n): ") 185 | 186 | if does_model_exist.lower() == "y": 187 | model_name = input("Please input the model name (owner/model_name): ") 188 | return model_name 189 | else: 190 | name = input( 191 | "What do you want to call the model? Pick a short and memorable name. Use lowercase characters and dashes. (eg: sdxl-barbie, musicgen-ye): " 192 | ) 193 | webbrowser.open(f"https://replicate.com/create?name={name}") 194 | input( 195 | "Once you have created the model (click Create on the webpage that just opened), press Enter to continue." 196 | ) 197 | owner = input("What is your Replicate username? ") 198 | return f"{owner}/{name}" 199 | 200 | 201 | def zip_directory(folder_path, zip_path): 202 | """ 203 | Compress a directory (with all files in it) into a zip file. 204 | 205 | :param folder_path: Path of the folder you want to compress. 206 | :param zip_path: Destination file path, including the filename of the new zip file. 207 | """ 208 | print(f"Zipping {folder_path} to {zip_path} ...") 209 | with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_DEFLATED) as zipf: 210 | for root, _, files in os.walk(folder_path): 211 | for file in files: 212 | file_path = os.path.join(root, file) 213 | zipf.write(file_path, os.path.relpath(file_path, folder_path)) 214 | 215 | return zip_path 216 | 217 | 218 | def create_sdxl_training(model, save_dir, caption_prefix="in the style of TOK"): 219 | try: 220 | # Please make sure that 'replicate' is installed and available in your system's PATH. 221 | # The command assumes that "nightmare.zip" is correctly placed and accessible. 222 | command = [ 223 | "replicate", 224 | "train", 225 | "stability-ai/sdxl", 226 | "--destination", 227 | model, 228 | "--web", 229 | f"input_images=@{save_dir}", 230 | f"caption_prefix={caption_prefix}", 231 | ] 232 | subprocess.run(command, check=True) 233 | except subprocess.CalledProcessError as e: 234 | print("Error executing the command:", str(e)) 235 | except FileNotFoundError: 236 | print("Error: 'replicate' command not found. Is it installed correctly?") 237 | except Exception as e: 238 | print(f"An error occurred: {str(e)}") 239 | 240 | 241 | def create_musicgen_training(model, save_dir, audio_description, drop_vocals=False): 242 | try: 243 | # Please make sure that 'replicate' is installed and available in your system's PATH. 244 | # The command assumes that "your-audio.mp3" is correctly placed and accessible. 245 | command = [ 246 | "replicate", 247 | "train", 248 | "sakemin/musicgen-fine-tuner", 249 | "--destination", 250 | model, 251 | "--web", 252 | "model=medium", 253 | f"drop_vocals={drop_vocals}", 254 | f"one_same_description={audio_description}", 255 | f"dataset_path=@{save_dir}", 256 | ] 257 | subprocess.run(command, check=True) 258 | except subprocess.CalledProcessError as e: 259 | print("Error executing the command:", str(e)) 260 | except FileNotFoundError: 261 | print("Error: 'replicate' command not found. Is it installed correctly?") 262 | except Exception as e: 263 | print(f"An error occurred: {str(e)}") 264 | 265 | 266 | def is_replicate_api_token_set(): 267 | return "REPLICATE_API_TOKEN" in os.environ 268 | 269 | 270 | def is_replicate_cli_installed(): 271 | try: 272 | subprocess.check_output(["replicate", "--version"]) 273 | return True 274 | except subprocess.CalledProcessError: 275 | return False 276 | except FileNotFoundError: 277 | return False 278 | 279 | 280 | def slugify(title): 281 | """ 282 | Slugify a YouTube title. 283 | 284 | :param title: The title to slugify. 285 | :return: The slugified title. 286 | """ 287 | return re.sub(r"\W+", "-", title).lower() 288 | 289 | 290 | def detect_motion_blur(image, motion_blur_threshold): 291 | # Load the image 292 | image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY) 293 | 294 | # Scale the image to the range [0, 1] 295 | image = image / 255.0 296 | 297 | # Apply DCT 298 | dct = cv2.dct(np.float32(image)) 299 | 300 | # Compute the average of the DCT coefficients in the high-frequency region 301 | avg = np.mean(dct[1 : int(dct.shape[0] / 2), 1 : int(dct.shape[1] / 2)]) 302 | 303 | avg = avg * 10000 if avg else 0 304 | 305 | # If the average is below a certain threshold, the image is likely blurred 306 | # After some trial and error, around -0.02 seems to be a good threshold 307 | if avg < motion_blur_threshold: 308 | return (True, avg) 309 | else: 310 | return (False, avg) 311 | 312 | 313 | def convert_mp4_to_mp3(mp4_file_path): 314 | print(f"Converting {mp4_file_path} to MP3 ...") 315 | audio_clip = AudioFileClip(mp4_file_path) 316 | mp3_file_path = mp4_file_path.replace(".mp4", ".mp3") 317 | audio_clip.write_audiofile(mp3_file_path, codec="mp3") 318 | audio_clip.close() 319 | os.remove(mp4_file_path) 320 | return mp3_file_path 321 | 322 | 323 | def process_audio(audio_file_path): 324 | audio_file_path = convert_mp4_to_mp3(audio_file_path) 325 | 326 | if user_audio_confirmation(audio_file_path): 327 | # If the user confirms, proceed with the posting function 328 | model = user_model() 329 | 330 | print( 331 | "Please describe the audio, use 2 to 3 comma separated keywords. This could be a band name, genre or something unique. You’ll use this when prompting your fine-tune." 332 | ) 333 | audio_description = input("Describe the audio: ") 334 | 335 | # ask user if they want to drop vocals 336 | print("MusicGen does not train well with audio that has vocals") 337 | drop_vocals_input = input( 338 | "Do you want to automatically drop vocals from your audio? (y/n): " 339 | ) 340 | drop_vocals = True if drop_vocals_input.lower() == "y" else False 341 | 342 | create_musicgen_training(model, audio_file_path, audio_description, drop_vocals) 343 | else: 344 | print("Operation cancelled by the user.") 345 | 346 | 347 | def process_video(video_file_path, interval, caption_prefix): 348 | # slugify the video title 349 | video_name = video_file_path.split("/")[-1] 350 | output_directory = f"./extracted_frames/{slugify(video_name)}" 351 | 352 | extract_frames(video_file_path, frame_interval=interval, save_path=output_directory) 353 | 354 | # After extracting and saving images, ask the user to confirm 355 | if user_image_confirmation(output_directory): 356 | # If the user confirms, proceed with the posting function 357 | model = user_model() 358 | 359 | # Compress the directory with the images 360 | zip_path = zip_directory(output_directory, output_directory + ".zip") 361 | 362 | create_sdxl_training(model, zip_path, caption_prefix=caption_prefix) 363 | else: 364 | print("Operation cancelled by the user.") 365 | 366 | 367 | def main(): 368 | parser = argparse.ArgumentParser( 369 | description="Download a video from YouTube and extract frames or audio" 370 | ) 371 | parser.add_argument("url", help="URL of the YouTube video") 372 | parser.add_argument( 373 | "--interval", help="Interval between frames", default=50, type=int 374 | ) 375 | parser.add_argument( 376 | "--caption_prefix", 377 | help="automatically add this to the start of each caption", 378 | default="in the style of TOK", 379 | type=str, 380 | ) 381 | parser.add_argument("--audio", help="Download audio only", action="store_true") 382 | parser.add_argument( 383 | "--remove_blur", help="remove blurry frames", default=True, action="store_true" 384 | ) 385 | args = parser.parse_args() 386 | 387 | if not is_replicate_cli_installed(): 388 | input( 389 | "🚫 Replicate CLI is not installed. Please install it before proceeding. Link: https://github.com/replicate/cli. Press any key to open the webpage." 390 | ) 391 | webbrowser.open(f"https://github.com/replicate/cli") 392 | else: 393 | print("✅ Replicate CLI is installed. Proceeding...") 394 | 395 | if not is_replicate_api_token_set(): 396 | print( 397 | "🚫 REPLICATE_API_TOKEN is not set. Please set it with `export REPLICATE_API_TOKEN=`, then try again." 398 | ) 399 | return 400 | else: 401 | print("✅ REPLICATE_API_TOKEN is set. Proceeding...") 402 | 403 | if args.audio: 404 | print("🎵 Audio training mode. Proceeding...") 405 | else: 406 | print("🎥 Video training mode. Proceeding...") 407 | 408 | video_url = args.url 409 | interval = args.interval 410 | caption_prefix = args.caption_prefix 411 | 412 | # Directory where you want to save the downloaded video 413 | download_directory = "./downloaded_audio" if args.audio else "./downloaded_videos" 414 | 415 | if video_url.startswith("http"): 416 | video_file_path = download_youtube_video( 417 | video_url, save_path=download_directory, audio_only=args.audio 418 | ) 419 | else: 420 | video_file_path = video_url 421 | 422 | if args.audio: 423 | process_audio(video_file_path) 424 | else: 425 | process_video(video_file_path, interval, caption_prefix) 426 | 427 | 428 | if __name__ == "__main__": 429 | main() 430 | --------------------------------------------------------------------------------