├── Easy_Wav2Lip_v8.3.ipynb ├── GUI.py ├── README.md ├── audio.py ├── checkpoints ├── README.md └── mobilenet.pth ├── config.ini ├── degradations.py ├── easy_functions.py ├── enhance.py ├── hparams.py ├── inference.py ├── install.py ├── models ├── __init__.py ├── conv.py ├── syncnet.py └── wav2lip.py ├── requirements.txt ├── run.py ├── run_loop.bat └── run_loop.sh /GUI.py: -------------------------------------------------------------------------------- 1 | import tkinter as tk 2 | from tkinter import filedialog, ttk 3 | import configparser 4 | import os 5 | 6 | try: 7 | with open('installed.txt', 'r') as file: 8 | version = file.read() 9 | except FileNotFoundError: 10 | print("Easy-Wav2Lip does not appear to have installed correctly.") 11 | print("Please try to install it again.") 12 | print("https://github.com/anothermartz/Easy-Wav2Lip/issues") 13 | input() 14 | exit() 15 | 16 | print("opening GUI") 17 | 18 | runfile = 'run.txt' 19 | if os.path.exists(runfile): 20 | os.remove(runfile) 21 | 22 | import webbrowser 23 | 24 | def open_github_link(event): 25 | webbrowser.open("https://github.com/anothermartz/Easy-Wav2Lip?tab=readme-ov-file#advanced-tweaking") 26 | 27 | def read_config(): 28 | # Read the config.ini file 29 | config = configparser.ConfigParser() 30 | config.read("config.ini") 31 | return config 32 | 33 | def save_config(config): 34 | # Save the updated config back to config.ini 35 | with open("config.ini", "w") as config_file: 36 | config.write(config_file) 37 | 38 | def open_video_file(): 39 | file_path = filedialog.askopenfilename(title="Select a video file", filetypes=[("All files", "*.*")]) 40 | if file_path: 41 | video_file_var.set(file_path) 42 | 43 | def open_vocal_file(): 44 | file_path = filedialog.askopenfilename(title="Select a vocal file", filetypes=[("All files", "*.*")]) 45 | if file_path: 46 | vocal_file_var.set(file_path) 47 | 48 | # feathering 49 | def validate_frame_preview(P): 50 | if P == "": 51 | return True # Allow empty input 52 | try: 53 | num = float(P) 54 | if (num.is_integer()): 55 | return True 56 | except ValueError: 57 | pass 58 | return False 59 | 60 | def start_easy_wav2lip(): 61 | # Start Easy-Wav2Lip processing 62 | print("Saving config") 63 | config["OPTIONS"]["video_file"] = str(video_file_var.get()) 64 | config["OPTIONS"]["vocal_file"] = str(vocal_file_var.get()) 65 | config["OPTIONS"]["quality"] = str(quality_var.get()) 66 | config["OPTIONS"]["output_height"] = str(output_height_combobox.get()) 67 | config["OPTIONS"]["wav2lip_version"] = str(wav2lip_version_var.get()) 68 | config["OPTIONS"]["use_previous_tracking_data"] = str(use_previous_tracking_data_var.get()) 69 | config["OPTIONS"]["nosmooth"] = str(nosmooth_var.get()) 70 | config["OPTIONS"]["preview_window"] = str(preview_window_var.get()) 71 | config["PADDING"]["u"] = str(padding_vars["u"].get()) 72 | config["PADDING"]["d"] = str(padding_vars["d"].get()) 73 | config["PADDING"]["l"] = str(padding_vars["l"].get()) 74 | config["PADDING"]["r"] = str(padding_vars["r"].get()) 75 | config["MASK"]["size"] = str(size_var.get()) 76 | config["MASK"]["feathering"] = str(feathering_var.get()) 77 | config["MASK"]["mouth_tracking"] = str(mouth_tracking_var.get()) 78 | config["MASK"]["debug_mask"] = str(debug_mask_var.get()) 79 | config["OTHER"]["batch_process"] = str(batch_process_var.get()) 80 | config["OTHER"]["output_suffix"] = str(output_suffix_var.get()) 81 | config["OTHER"]["include_settings_in_suffix"] = str(include_settings_in_suffix_var.get()) 82 | config["OTHER"]["preview_settings"] = 
str(preview_settings_var.get()) 83 | config["OTHER"]["frame_to_preview"] = str(frame_to_preview_var.get()) 84 | save_config(config) # Save the updated config 85 | with open("run.txt", "w") as f: 86 | f.write("run") 87 | exit() 88 | # Add your logic here 89 | 90 | root = tk.Tk() 91 | root.title("Easy-Wav2Lip GUI") 92 | root.geometry("800x720") 93 | root.configure(bg="lightblue") 94 | 95 | # Read the existing config.ini 96 | config = read_config() 97 | 98 | row=0 99 | tk.Label(root, text=version, bg="lightblue").grid(row=row, column=0, sticky="w") 100 | # Create a label for video file 101 | row+=1 102 | video_label = tk.Label(root, text="Video File Path:", bg="lightblue") 103 | video_label.grid(row=row, column=0, sticky="e") 104 | 105 | # Entry widget for video file path 106 | video_file_var = tk.StringVar() 107 | video_entry = tk.Entry(root, textvariable=video_file_var, width=80) 108 | video_entry.grid(row=row, column=1, sticky="w") 109 | 110 | # Create a button to open the file dialog 111 | select_button = tk.Button(root, text="...", command=open_video_file) 112 | select_button.grid(row=row, column=1, sticky="w", padx=490) 113 | 114 | # Set the default value based on the existing config 115 | video_file_var.set(config["OPTIONS"].get("video_file", "")) 116 | 117 | row+=1 118 | tk.Label(root, text="", bg="lightblue").grid(row=row, column=0, sticky="w") 119 | 120 | # String input for vocal_file 121 | row+=1 122 | 123 | # Create a label for the input box 124 | vocal_file_label = tk.Label(root, text="Vocal File Path:", bg="lightblue") 125 | vocal_file_label.grid(row=row, column=0, sticky="e") 126 | 127 | # Create an input box for the vocal file path 128 | vocal_file_var = tk.StringVar() 129 | vocal_file_entry = tk.Entry(root, textvariable=vocal_file_var, width=80) 130 | vocal_file_entry.grid(row=row, column=1, sticky="w") 131 | 132 | # Create a button to open the file dialog 133 | select_button = tk.Button(root, text="...", command=open_vocal_file) 134 | select_button.grid(row=row, column=1, sticky="w", padx=490) 135 | 136 | # Set the initial value from the 'config' dictionary (if available) 137 | vocal_file_var.set(config["OPTIONS"].get("vocal_file", "")) 138 | 139 | row+=1 140 | tk.Label(root, text="", bg="lightblue").grid(row=row, column=0, sticky="w") 141 | 142 | # Dropdown box for quality options 143 | row+=1 144 | quality_label = tk.Label(root, text="Select Quality:", bg="lightblue") 145 | quality_label.grid(row=row, column=0, sticky="e") 146 | quality_options = ["Fast", "Improved", "Enhanced"] 147 | quality_var = tk.StringVar() 148 | quality_var.set(config["OPTIONS"].get("quality", "Improved")) 149 | quality_dropdown = tk.OptionMenu(root, quality_var, *quality_options) 150 | quality_dropdown.grid(row=row, column=1, sticky="w") 151 | 152 | row+=1 153 | tk.Label(root, text="", bg="lightblue").grid(row=row, column=0, sticky="w") 154 | 155 | # Output height 156 | row+=1 157 | output_height_label = tk.Label(root, text="Output height:", bg="lightblue") 158 | output_height_label.grid(row=row, column=0, sticky="e") 159 | output_height_options = ["half resolution", "full resolution"] 160 | output_height_combobox = ttk.Combobox(root, values=output_height_options) 161 | output_height_combobox.set(config["OPTIONS"].get("output_height", "full resolution")) # Set default value 162 | output_height_combobox.grid(row=row, column=1, sticky="w") 163 | 164 | row+=1 165 | tk.Label(root, text="", bg="lightblue").grid(row=row, column=0, sticky="w") 166 | 167 | # Dropdown box for wav2lip version options 168 | 
row+=1 169 | wav2lip_version_label = tk.Label(root, text="Select Wav2Lip version:", bg="lightblue") 170 | wav2lip_version_label.grid(row=row, column=0, sticky="e") 171 | wav2lip_version_options = ["Wav2Lip", "Wav2Lip_GAN"] 172 | wav2lip_version_var = tk.StringVar() 173 | wav2lip_version_var.set(config["OPTIONS"].get("wav2lip_version", "Wav2Lip")) 174 | wav2lip_version_dropdown = tk.OptionMenu(root, wav2lip_version_var, *wav2lip_version_options) 175 | wav2lip_version_dropdown.grid(row=row, column=1, sticky="w") 176 | 177 | row+=1 178 | tk.Label(root, text="", bg="lightblue").grid(row=row, column=0, sticky="w") 179 | # output_suffix 180 | row+=1 181 | output_suffix_label = tk.Label(root, text="Output File Suffix:", bg="lightblue") 182 | output_suffix_label.grid(row=row, column=0, sticky="e") 183 | output_suffix_var = tk.StringVar() 184 | output_suffix_var.set(config["OTHER"].get("output_suffix", "_Easy-Wav2Lip")) 185 | output_suffix_entry = output_suffix_entry = tk.Entry(root, textvariable=output_suffix_var, width=20) 186 | output_suffix_entry.grid(row=row, column=1, sticky="w") 187 | 188 | include_settings_in_suffix_var = tk.BooleanVar() 189 | include_settings_in_suffix_var.set(config["OTHER"].get("include_settings_in_suffix", True)) # Set default value 190 | include_settings_in_suffix_checkbox = tk.Checkbutton(root, text="Add Settings to Suffix", variable=include_settings_in_suffix_var, bg="lightblue") 191 | include_settings_in_suffix_checkbox.grid(row=row, column=1, sticky="w", padx=130) 192 | 193 | # batch_process 194 | row+=1 195 | tk.Label(root, text="", bg="lightblue").grid(row=row, column=0, sticky="w") 196 | row+=1 197 | batch_process_label = tk.Label(root, text="Batch Process:", bg="lightblue") 198 | batch_process_label.grid(row=row, column=0, sticky="e") 199 | batch_process_var = tk.BooleanVar() 200 | batch_process_var.set(config["OTHER"].get("batch_process", True)) # Set default value 201 | batch_process_checkbox = tk.Checkbutton(root, text="", variable=batch_process_var, bg="lightblue") 202 | batch_process_checkbox.grid(row=row, column=1, sticky="w") 203 | 204 | # Dropdown box for preview window options 205 | row+=1 206 | preview_window_label = tk.Label(root, text="Preview Window:", bg="lightblue") 207 | preview_window_label.grid(row=row, column=0, sticky="e") 208 | preview_window_options = ["Face", "Full", "Both", "None"] 209 | preview_window_var = tk.StringVar() 210 | preview_window_var.set(config["OPTIONS"].get("preview_window", "Face")) 211 | preview_window_dropdown = tk.OptionMenu(root, preview_window_var, *preview_window_options) 212 | preview_window_dropdown.grid(row=row, column=1, sticky="w") 213 | 214 | row+=1 215 | tk.Label(root, text="", bg="lightblue").grid(row=row, column=0, sticky="w") 216 | 217 | # Button to start Easy-Wav2Lip 218 | row+=1 219 | start_button = tk.Button(root, text="Start Easy-Wav2Lip", command=start_easy_wav2lip, bg="#5af269", font=("Arial", 16)) 220 | start_button.grid(row=row, column=0, sticky="w", padx=290, columnspan=2) 221 | 222 | row+=1 223 | tk.Label(root, text="", bg="lightblue").grid(row=row, column=0, sticky="w") 224 | tk.Label(root, text="", bg="lightblue").grid(row=row, column=0, sticky="w") 225 | 226 | row+=1 227 | tk.Label(root, text="Advanced Tweaking:", bg="lightblue", font=("Arial", 16)).grid(row=row, column=0, sticky="w") 228 | row+=1 229 | # Create a label with a custom cursor 230 | link = tk.Label(root, text="(Click here to see readme)", bg="lightblue", fg="blue", font=("Arial", 10), cursor="hand2") 231 | link.grid(row=row, 
column=0) 232 | 233 | # Bind the click event to the label 234 | link.bind("", open_github_link) 235 | 236 | # Process one frame only 237 | preview_settings_var = tk.BooleanVar() 238 | preview_settings_var.set(config["OTHER"].get("preview_settings", True)) # Set default value 239 | preview_settings_checkbox = tk.Checkbutton(root, text="Process one frame only - Frame to process:", variable=preview_settings_var, bg="lightblue") 240 | preview_settings_checkbox.grid(row=row, column=1, sticky="w") 241 | 242 | frame_to_preview_var = tk.StringVar() 243 | frame_to_preview_entry = tk.Entry(root, textvariable=frame_to_preview_var, validate="key", width=3, validatecommand=(root.register(validate_frame_preview), "%P")) 244 | frame_to_preview_entry.grid(row=row, column=1, sticky="w", padx=255) 245 | frame_to_preview_var.set(config["OTHER"].get("frame_to_preview", "100")) 246 | 247 | # Checkbox for nosmooth option 248 | row+=1 249 | nosmooth_var = tk.BooleanVar() 250 | nosmooth_var.set(config["OPTIONS"].get("nosmooth", True)) # Set default value 251 | nosmooth_checkbox = tk.Checkbutton(root, text="nosmooth - unticking will smooth face detection between 5 frames", variable=nosmooth_var, bg="lightblue") 252 | nosmooth_checkbox.grid(row=row, column=1, sticky="w") 253 | 254 | # Checkbox for use_previous_tracking_data option 255 | row+=1 256 | use_previous_tracking_data_var = tk.BooleanVar() 257 | use_previous_tracking_data_var.set(config["OPTIONS"].get("use_previous_tracking_data", True)) # Set default value 258 | use_previous_tracking_data_checkbox = tk.Checkbutton(root, text="Keep previous face tracking data if using same video", variable=use_previous_tracking_data_var, bg="lightblue") 259 | use_previous_tracking_data_checkbox.grid(row=row, column=1, sticky="w") 260 | 261 | # padding 262 | row+=1 263 | tk.Label(root, text="Padding:", bg="lightblue", font=("Arial", 12)).grid(row=row, column=1, sticky="sw", pady=10) 264 | row+=1 265 | tk.Label(root, text="(Up, Down, Left, Right)", bg="lightblue").grid(row=row, column=1, rowspan=4, sticky="w", padx=100) 266 | padding_vars = {} 267 | 268 | # Create a list of padding labels and their corresponding keys 269 | padding_labels = [("U:", "u"), ("D:", "d"), ("L:", "l"), ("R:", "r")] 270 | 271 | # Validation function to allow only integers 272 | def validate_integer(P): 273 | if P == "" or P == "-" or P.lstrip("-").isdigit(): 274 | return True 275 | return False 276 | 277 | # Create the padding labels and entry widgets using a loop 278 | for label_text, key in padding_labels: 279 | label = tk.Label(root, text=label_text, bg="lightblue") 280 | label.grid(row=row, column=1, sticky="w", padx=50) 281 | 282 | # Create a StringVar for the current key 283 | padding_var = tk.StringVar() 284 | 285 | # Set validation to allow positive and negative integers 286 | entry = tk.Entry(root, textvariable=padding_var, width=3, validate="key", validatecommand=(root.register(validate_integer), "%P")) 287 | entry.grid(row=row, column=1, sticky="w", padx=70) 288 | 289 | # Set the default value from the 'config' dictionary 290 | padding_var.set(config["PADDING"].get(key, "")) 291 | 292 | # Store the StringVar in the dictionary 293 | padding_vars[key] = padding_var 294 | 295 | # Increment the row 296 | row += 1 297 | 298 | 299 | tk.Label(root, text="", bg="lightblue").grid(row=row, column=0, sticky="w") 300 | row+=1 301 | # mask size 302 | def validate_custom_number(P): 303 | if P == "": 304 | return True # Allow empty input 305 | try: 306 | num = float(P) 307 | if 0 <= num <= 6 and 
(num.is_integer() or (num * 10) % 1 == 0): 308 | return True 309 | except ValueError: 310 | pass 311 | return False 312 | 313 | row+=1 314 | tk.Label(root, text="Mask settings:", bg="lightblue", font=("Arial", 12)).grid(row=row, column=1, sticky="sw") 315 | row+=1 316 | size_label = tk.Label(root, text="Mask size:", bg="lightblue", padx=50) 317 | size_label.grid(row=row, column=1, sticky="w") 318 | size_var = tk.StringVar() 319 | size_entry = tk.Entry(root, textvariable=size_var, validate="key", width=3, validatecommand=(root.register(validate_custom_number), "%P")) 320 | size_entry.grid(row=row, column=1, sticky="w", padx=120) 321 | size_var.set(config["MASK"].get("size", "2.5")) 322 | 323 | # feathering 324 | def validate_feather(P): 325 | if P == "": 326 | return True # Allow empty input 327 | try: 328 | num = float(P) 329 | if 0 <= num <= 3 and (num.is_integer()): 330 | return True 331 | except ValueError: 332 | pass 333 | return False 334 | 335 | row+=1 336 | feathering_label = tk.Label(root, text="Feathering:", bg="lightblue", padx=50) 337 | feathering_label.grid(row=row, column=1, sticky="w") 338 | feathering_var = tk.StringVar() 339 | feathering_entry = tk.Entry(root, textvariable=feathering_var, validate="key", width=3, validatecommand=(root.register(validate_feather), "%P")) 340 | feathering_entry.grid(row=row, column=1, sticky="w", padx=120) 341 | feathering_var.set(config["MASK"].get("feathering", "2.5")) 342 | 343 | # mouth_tracking 344 | row+=1 345 | mouth_tracking_var = tk.BooleanVar() 346 | mouth_tracking_var.set(config["MASK"].get("mouth_tracking", True)) # Set default value 347 | mouth_tracking_checkbox = tk.Checkbutton(root, text="track mouth for mask on every frame", variable=mouth_tracking_var, bg="lightblue", padx=50) 348 | mouth_tracking_checkbox.grid(row=row, column=1, sticky="w") 349 | 350 | # debug_mask 351 | row+=1 352 | debug_mask_var = tk.BooleanVar() 353 | debug_mask_var.set(config["MASK"].get("debug_mask", True)) # Set default value 354 | debug_mask_checkbox = tk.Checkbutton(root, text="highlight mask for debugging", variable=debug_mask_var, bg="lightblue", padx=50) 355 | debug_mask_checkbox.grid(row=row, column=1, sticky="w") 356 | 357 | # Increase spacing between all rows (uniformly) 358 | for row in range(row): 359 | root.rowconfigure(row, weight=1) 360 | 361 | 362 | root.mainloop() 363 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## Contents: 2 | 1. [Introduction](https://github.com/anothermartz/Easy-Wav2Lip?tab=readme-ov-file#easy-wav2lip-improves-wav2lip-video-lipsyncing-making-it) 3 | 2. [Google Colab version (free cloud computing in-browser)](https://github.com/anothermartz/Easy-Wav2Lip?tab=readme-ov-file#google-colab) 4 | 3. [Local Installation](https://github.com/anothermartz/Easy-Wav2Lip?tab=readme-ov-file#local-installation) 5 | 4. [Support](https://github.com/anothermartz/Easy-Wav2Lip?tab=readme-ov-file#support) 6 | 5. [Best Practices](https://github.com/anothermartz/Easy-Wav2Lip?tab=readme-ov-file#best-practices) 7 | # Easy-Wav2Lip improves Wav2Lip video lipsyncing making it: 8 | 9 | ## Easier: 10 | * Simple setup and execution - locally and via colab. 11 | * no messing around manually downloading and installing prerequesits 12 | * Google Colab has only 2 cells to execute 13 | * Windows users only need one file to install, update and run. 14 | * Well documented options below. 
15 | * No more wondering what anything does! 16 | 17 | ## Faster: 18 | For my 9 second 720p 60fps test clip via Colab T4: 19 | | Original Wav2Lip | Easy-Wav2Lip | 20 | |:-------|:-----| 21 | | Execution time: 6m 53s | Execution time: 56s | 22 | 23 | That's not a typo! My clip goes from almost 7 minutes to under 1 minute! 24 | 25 | The tracking data is saved between generations of the same video, saving even more time: 26 | | Easy-Wav2Lip on the same video again | 27 | |:-----| 28 | | Execution time: 25s | 29 | 30 | ## Better looking: 31 | 32 | Easy-Wav2Lip fixes visual bugs on the lips: 33 | 34 | [![Comparison gif](https://github.com/anothermartz/Easy-Wav2Lip/releases/download/Prerequesits/wav2lipcomparison.gif)](https://github.com/anothermartz/Easy-Wav2Lip/releases/download/Prerequesits/wav2lipcomparison.gif) 35 | 36 | 3 Options for Quality: 37 | * Fast: Wav2Lip 38 | * Improved: Wav2Lip with a feathered mask around the mouth to restore the original resolution for the rest of the face 39 | * Enhanced: Wav2Lip + mask + GFPGAN upscaling done on the face 40 | 41 | [![Comparison gif](https://github.com/anothermartz/Easy-Wav2Lip/releases/download/Prerequesits/JPComparison.gif)](https://github.com/anothermartz/Easy-Wav2Lip/releases/download/Prerequesits/JPComparison.gif) 42 | 43 | 44 | # Installation: 45 | 46 | ### For the easiest and most compatible way to use this tool, use the Google Colab version: 47 | 48 | ### Google Colab: 49 | [https://colab.research.google.com/github/anothermartz/Easy-Wav2Lip/blob/v8.3/Easy_Wav2Lip_v8.3.ipynb](https://colab.research.google.com/github/anothermartz/Easy-Wav2Lip/blob/v8.3/Easy_Wav2Lip_v8.3.ipynb) 50 | 51 | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/anothermartz/Easy-Wav2Lip/blob/v8.3/Easy_Wav2Lip_v8.3.ipynb) 52 | 53 | ## Local Installation: 54 | Requirements: 55 | Nvidia card that supports cuda 12.2 56 | Or 57 | MacOS device that supports mps via Apple silicon or AMD GPU 58 | 59 | 60 | ### Automatic installation for Windows 64-bit and x86 processor: 61 | 1. Download [Easy-Wav2Lip.bat](https://github.com/anothermartz/Easy-Wav2Lip/blob/Installers/Easy-Wav2Lip.bat) 62 | 2. Place it in a folder on your PC (EG: in Documents) 63 | 3. Run it and follow the instructions. It will make a folder called Easy-Wav2Lip within whatever folder you run it from. 64 | 4. Run this file whenever you want to use Easy-Wav2Lip 65 | 66 | This should handle the installation of all required components. 67 | 68 | ### Manual installation: 69 | 1. Make sure the following are installed and can be accessed via your terminal: 70 | * Python 3.10 (I have only tested [3.10.11](https://www.python.org/ftp/python/3.10.11/) - other versions may not work!) 71 | * [Git](https://git-scm.com/) 72 | * Windows & Linux: Cuda (Just having the latest Nvidia drivers will do this, I have only tested 12.2) 73 | 74 | 2. Run the following in your terminal once you've navigated to the folder you want to install Easy-Wav2Lip: 75 | 76 | ### Windows manual installation: 77 | Sets up a venv, installs ffmpeg to it and then installs Easy-Wav2Lip: 78 | 1. Open cmd and navigate to the folder you want to install EasyWav2Lip using cd 79 | EG: 80 | `cd Documents` 81 | 82 | 2. 
Copy and paste the following code into your cmd window: 83 | Note: 2 folders will be made in this location: Easy-Wav2Lip and Easy-Wav2Lip-venv (an isolated python install) 84 | ``` 85 | py -3.10 -m venv Easy-Wav2Lip-venv 86 | Easy-Wav2Lip-venv\Scripts\activate 87 | python -m pip install --upgrade pip 88 | python -m pip install requests 89 | set url=https://github.com/BtbN/FFmpeg-Builds/releases/download/latest/ffmpeg-master-latest-win64-gpl.zip 90 | python -c "import requests; r = requests.get('%url%', stream=True); open('ffmpeg.zip', 'wb').write(r.content)" 91 | powershell -Command "Expand-Archive -Path .\\ffmpeg.zip -DestinationPath .\\" 92 | xcopy /e /i /y "ffmpeg-master-latest-win64-gpl\bin\*" .\Easy-Wav2Lip-venv\Scripts 93 | del ffmpeg.zip 94 | rmdir /s /q ffmpeg-master-latest-win64-gpl 95 | git clone https://github.com/anothermartz/Easy-Wav2Lip.git 96 | cd Easy-Wav2Lip 97 | pip install -r requirements.txt 98 | python install.py 99 | ``` 100 | Now to run Easy-Wav2Lip:
101 | 3. Close and reopen cmd then cd to the same directory as in Step 1.
102 | 4. Paste the following code: 103 | ``` 104 | Easy-Wav2Lip-venv\Scripts\activate 105 | cd Easy-Wav2Lip 106 | call run_loop.bat 107 | ``` 108 | See [Usage](https://github.com/anothermartz/Easy-Wav2Lip?tab=readme-ov-file#usage) for further instructions. 109 | 110 | ### MacOS and Linux installation (untested): 111 | Sets up a venv, installs ffmpeg to it and then installs Easy-Wav2Lip: 112 | 1. Open Terminal and navigate to the folder you want to install Easy-Wav2Lip using cd 113 | EG: 114 | `cd ~/Documents` 115 | 116 | 2. Copy and paste the following code into your terminal window: 117 | Note: 2 folders will be made in this location: Easy-Wav2Lip and Easy-Wav2Lip-venv (an isolated python install) 118 | ``` 119 | python3.10 -m venv Easy-Wav2Lip-venv 120 | source Easy-Wav2Lip-venv/bin/activate 121 | python -m pip install --upgrade pip 122 | python -m pip install requests 123 | for file in ffmpeg ffprobe ffplay; do 124 | curl -O "https://evermeet.cx/ffmpeg/${file}-6.1.1.zip" 125 | unzip "${file}-6.1.1.zip" 126 | done 127 | mv -f ffmpeg ffprobe ffplay ./Easy-Wav2Lip-venv/bin/ 128 | rm -f ffmpeg-6.1.1.zip ffprobe-6.1.1.zip ffplay-6.1.1.zip 129 | source Easy-Wav2Lip-venv/bin/activate 130 | git clone https://github.com/anothermartz/Easy-Wav2Lip.git 131 | cd Easy-Wav2Lip 132 | pip install -r requirements.txt 133 | python install.py 134 | ``` 135 | Now to run Easy-Wav2Lip:
136 | 3. Close and reopen terminal then cd to the same directory as in Step 1.
137 | 4. Paste the following code: 138 | ``` 139 | source Easy-Wav2Lip-venv/bin/activate 140 | cd Easy-Wav2Lip 141 | ./run_loop.sh 142 | ``` 143 | 144 | ## Usage: 145 | * Once everything is installed, a file called config.ini should pop up. 146 | * Add the path(s) to your video and audio files here and configure the settings to your liking.
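For example, the relevant lines near the top of config.ini might look like this (the paths shown are placeholders for illustration only, not real files):
```
[OPTIONS]
video_file = C:\Videos\clip1.mp4
vocal_file = C:\Audio\line1.wav
quality = Improved
```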

147 | **Pro Tip:** 148 | * On Windows: hold Shift while right-clicking on the file you want to use, then press "a" or click "Copy as path" to get the path that you can paste as video_file or vocal_file. 149 | * MacOS: Right-click on the file, hold the Option (Alt) key and select “Copy [filename] as Pathname” from the context menu. 150 | 151 | * Save config.ini and close it; this will start the Wav2Lip process, and your output will be saved in the same directory as your video_file. 152 | * config.ini will open again and you can change inputs and settings. 153 | * See [Best Practices](https://github.com/anothermartz/Easy-Wav2Lip?tab=readme-ov-file#best-practices) below for tips on how to get started. 154 | * See [Advanced Tweaking](https://github.com/anothermartz/Easy-Wav2Lip?tab=readme-ov-file#advanced-tweaking) below for an explanation of the settings not already explained in config.ini 155 | 156 | # Credits: 157 | * [The Original Wav2Lip](https://github.com/Rudrabha/Wav2Lip) of course. 158 | * The huge speed increase and improved base quality come from [cog-Wav2Lip](https://github.com/devxpy/cog-Wav2Lip). 159 | * Code to upscale with [GFPGAN](https://github.com/TencentARC/GFPGAN) mainly came from [wav2lip-hq-updated-ESRGAN](https://github.com/GucciFlipFlops1917/wav2lip-hq-updated-ESRGAN). 160 | * I couldn't have done this without AI assistance; before making this I had very minimal Python experience! LLM of choice: **Bing Chat** (now called 'Copilot'). 161 | * Thanks to [JustinJohn](https://github.com/justinjohn0306) for making the [Wav2Lip_simplified](https://colab.research.google.com/github/justinjohn0306/Wav2Lip/blob/master/Wav2Lip_simplified_v5.ipynb) colabs, which inspired me to make my own, even simpler version. 162 | 163 | ## Support 164 | If you're having issues running this, please look through the [issues tab](https://github.com/anothermartz/Easy-Wav2Lip/issues) to see if someone has written about it. If not, make a new thread, but make sure you include the following:
165 |
**If colab:** 166 | - Easy-Wav2Lip colab version number 167 | - Info about the files used. 168 | 169 |
**If local install:** 170 | - Easy-Wav2Lip.bat or manual install 171 | - Operating system (Windows 11, Linux, etc.) 172 | - GPU model 173 | - GPU driver version 174 | - Python version 175 | - Info about the files used and whether other files work 176 | 177 | Without this info, I'll just have to ask for it anyway, so a response about the issue itself will take longer. 178 | 179 | Chances are that if any of those differ from [the requirements](https://github.com/anothermartz/Easy-Wav2Lip?tab=readme-ov-file#local-installation) then that's the reason it's not working, and you may just have to use the colab version if you aren't already. 180 | 181 | For general chit-chat about this and any other lipsync talk, I'll be in this Discord:
182 | Invite link: https://discord.gg/FNZR9ETwKY
183 | Wav2Lip channel: https://discord.com/channels/667279414681272320/1076077584330280991 184 | 185 | # Best practices: 186 | * The best results come from lining up the speech to the actions and expressions of the speaker before you send it through wav2lip! 187 | 188 | Video files: 189 | * Must have a face in all frames or Wav2Lip will fail 190 | * Crop or mask out faces you don't want to lipsync or it'll choose randomly. 191 | * Use h264 .mp4 - other file types may be supported but this is what it outputs as 192 | * Images are currently untested. 193 | * Use a small file in every way (try <720p, <30 seconds, 30fps etc. - Bigger files may work but are usually the reason it fails) 194 | * For your first try, use a really tiny clip just to get used to the process, only once you're familiar should you try bigger files to see if they work. 195 | 196 | Audio files: 197 | * Save as .wav and the same length as your input video. 198 | * NOTE: I've noticed that about 80ms gets cut from the processed video/audio and I'm not sure how to fix this, so make sure you have a little extra than what you actually need! 199 | * You can just encode it into your video file and leave vocal_path blank, but this will add a couple of seconds to the processing time as it splits the audio from the video 200 | * OR 201 | * Select your audio file separately 202 | * I'm not certain what filetypes are supported, at least .wav and .mp3 work. 203 | 204 | # Advanced Tweaking: 205 | ## wav2lip_version: 206 | | Option | Pros | Cons | 207 | |:-------|:-----|:-----| 208 | | Wav2Lip | + More accurate lipsync
+ Attempts to keep the mouth closed when there is no sound | - Sometimes produces missing teeth (uncommon) | 209 | | Wav2Lip_GAN | + Looks nicer
+ Keeps the original expressions of the speaker more | - Not as good at masking the original lip movements, especially when there is no sound | 210 | 211 | I suggest trying Wav2Lip first and switching to the GAN version if you experience an effect where it looks like the speaker has big gaps in their teeth. 212 | 213 | ### nosmooth: 214 | * When enabled, wav2lip will crop the face on each frame independently. 215 | * Good for fast movements or cuts in the video. 216 | * May cause twitching if the face is on a weird angle. 217 | 218 | * When disabled, wav2lip will blend the detected position of the face between 5 frames. 219 | * Good for slow movements, especially for faces on an unusual angle. 220 | * Mouth can be offset when the face moves within the frame quickly, looks horrible between cuts. 221 | 222 | ## Padding: 223 | This option controls how many pixels are added or removed from the face crop in each direction. 224 | 225 | | Value | Example | Effect | 226 | |:------|:--------|:-------| 227 | | U | U = -5 | Removes 5 pixels from the top of the face | 228 | | D | D = 10 | Adds 10 pixels to the bottom of the face | 229 | | L | L = 0 | No change to the left of the face | 230 | | R | R = 15 | Adds 15 pixels to the right of the face | 231 | 232 | Padding can help remove hard lines at the chin or other edges of the face, but too much or too little padding can change the size or position of the mouth. It's common practice to add 10 pixels to the bottom, but you should experiment with different values to find the best balance for your clip. 233 | 234 | ## Mask: 235 | This option controls how the processed face is blended with the original face. This has no effect on the "Fast" quality option. 236 | 237 | * **size** will increase the size of the area that the mask covers. 238 | * **feathering** determines the amount of blending between the centre of the mask and the edges. 239 | * **mouth_tracking** will update the position of the mask to where the mouth is on every frame (slower) 240 | * * Note: The mouth position is already well approximated due to the frame being cropped to the face, enable this only if you find a video where the mask doesn't appear to follow the mouth. 241 | * **debug_mask** will make the background grayscale and the mask in colour so that you can easily see where the mask is in the frame. 242 | 243 | # Other options: 244 | 245 | # Batch processing: 246 | This option allows you to process multiple video and/or audio files automatically. 247 | * Name your files with a number at the end, eg. Video1.mp4, Video2.mp4, etc. and put them all in the same folder. 248 | * Files will be processed in numerical order starting from the one you select. For example, if you select Video3.mp4, it will process Video3.mp4, Video4.mp4, and so on. 249 | * If you select numbered video files and a non-numbered audio file, it will process each video with the same audio file. Useful for making different images/videos say the same line. 250 | * Likewise, if you select a non-numbered video file and numbered audio files, it will use the same video for each audio file. Useful for making the same image/video say different things. 251 | 252 | ### output_suffix: 253 | This adds a suffix to your output files so that they don't overwite your originals. 254 | 255 | ### include_settings_in_suffix: 256 | Adds what settings were used - good for comparing different settings as you will know what you used for each render. 
257 | Will add: Qualty_resolution_nosmooth_pads-UDLR 258 | EG: _Enhanced_720_nosmooth1_pads-U15D10L-15R30 259 | pads_UDLR will not be included if they are set to 0. 260 | resolution will not be included if it output_height is set to full resolution 261 | 262 | ### preview_input 263 | Displays the input video/audio before processing so you can check to make sure you chose the correct file(s). It may only work with .mp4, I just know it didn't work on an .avi I tried. 264 | Disabling this will save a few seconds of processing time for each video. 265 | 266 | ### preview_settings 267 | This will render only 1 frame of your video and display it at full size, this is so you can tweak the settings without having to render the entire video each time. 268 | frame_to_preview is for selecting a particular frame you want to check out - may not be completely accurate to the actual frame. 269 | -------------------------------------------------------------------------------- /audio.py: -------------------------------------------------------------------------------- 1 | import librosa 2 | import librosa.filters 3 | import numpy as np 4 | 5 | # import tensorflow as tf 6 | from scipy import signal 7 | from scipy.io import wavfile 8 | from hparams import hparams as hp 9 | 10 | 11 | def load_wav(path, sr): 12 | return librosa.core.load(path, sr=sr)[0] 13 | 14 | 15 | def save_wav(wav, path, sr): 16 | wav *= 32767 / max(0.01, np.max(np.abs(wav))) 17 | # proposed by @dsmiller 18 | wavfile.write(path, sr, wav.astype(np.int16)) 19 | 20 | 21 | def save_wavenet_wav(wav, path, sr): 22 | librosa.output.write_wav(path, wav, sr=sr) 23 | 24 | 25 | def preemphasis(wav, k, preemphasize=True): 26 | if preemphasize: 27 | return signal.lfilter([1, -k], [1], wav) 28 | return wav 29 | 30 | 31 | def inv_preemphasis(wav, k, inv_preemphasize=True): 32 | if inv_preemphasize: 33 | return signal.lfilter([1], [1, -k], wav) 34 | return wav 35 | 36 | 37 | def get_hop_size(): 38 | hop_size = hp.hop_size 39 | if hop_size is None: 40 | assert hp.frame_shift_ms is not None 41 | hop_size = int(hp.frame_shift_ms / 1000 * hp.sample_rate) 42 | return hop_size 43 | 44 | 45 | def linearspectrogram(wav): 46 | D = _stft(preemphasis(wav, hp.preemphasis, hp.preemphasize)) 47 | S = _amp_to_db(np.abs(D)) - hp.ref_level_db 48 | 49 | if hp.signal_normalization: 50 | return _normalize(S) 51 | return S 52 | 53 | 54 | def melspectrogram(wav): 55 | D = _stft(preemphasis(wav, hp.preemphasis, hp.preemphasize)) 56 | S = _amp_to_db(_linear_to_mel(np.abs(D))) - hp.ref_level_db 57 | 58 | if hp.signal_normalization: 59 | return _normalize(S) 60 | return S 61 | 62 | 63 | def _lws_processor(): 64 | import lws 65 | 66 | return lws.lws(hp.n_fft, get_hop_size(), fftsize=hp.win_size, mode="speech") 67 | 68 | 69 | def _stft(y): 70 | if hp.use_lws: 71 | return _lws_processor(hp).stft(y).T 72 | else: 73 | return librosa.stft( 74 | y=y, n_fft=hp.n_fft, hop_length=get_hop_size(), win_length=hp.win_size 75 | ) 76 | 77 | 78 | ########################################################## 79 | # Those are only correct when using lws!!! (This was messing with Wavenet quality for a long time!) 
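# A quick worked example of the two helpers below (illustrative values, not taken from hparams):
# with fsize=800 and fshift=200, pad = fsize - fshift = 600; for a signal of length 16000
# (a multiple of fshift), num_frames(16000, 800, 200) = (16000 + 2*600 - 800) // 200 + 1 = 83,
# and pad_lr(x, 800, 200) returns (600, 600 + r) where r = (83 - 1)*200 + 800 - (16000 + 2*600) = 0,
# i.e. the signal is padded by 600 samples on each side.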
80 | def num_frames(length, fsize, fshift): 81 | """Compute number of time frames of spectrogram""" 82 | pad = fsize - fshift 83 | if length % fshift == 0: 84 | M = (length + pad * 2 - fsize) // fshift + 1 85 | else: 86 | M = (length + pad * 2 - fsize) // fshift + 2 87 | return M 88 | 89 | 90 | def pad_lr(x, fsize, fshift): 91 | """Compute left and right padding""" 92 | M = num_frames(len(x), fsize, fshift) 93 | pad = fsize - fshift 94 | T = len(x) + 2 * pad 95 | r = (M - 1) * fshift + fsize - T 96 | return pad, pad + r 97 | 98 | 99 | ########################################################## 100 | # Librosa correct padding 101 | def librosa_pad_lr(x, fsize, fshift): 102 | return 0, (x.shape[0] // fshift + 1) * fshift - x.shape[0] 103 | 104 | 105 | # Conversions 106 | _mel_basis = None 107 | 108 | 109 | def _linear_to_mel(spectogram): 110 | global _mel_basis 111 | if _mel_basis is None: 112 | _mel_basis = _build_mel_basis() 113 | return np.dot(_mel_basis, spectogram) 114 | 115 | 116 | def _build_mel_basis(): 117 | assert hp.fmax <= hp.sample_rate // 2 118 | return librosa.filters.mel( 119 | sr=hp.sample_rate, 120 | n_fft=hp.n_fft, 121 | n_mels=hp.num_mels, 122 | fmin=hp.fmin, 123 | fmax=hp.fmax, 124 | ) 125 | 126 | 127 | def _amp_to_db(x): 128 | min_level = np.exp(hp.min_level_db / 20 * np.log(10)) 129 | return 20 * np.log10(np.maximum(min_level, x)) 130 | 131 | 132 | def _db_to_amp(x): 133 | return np.power(10.0, (x) * 0.05) 134 | 135 | 136 | def _normalize(S): 137 | if hp.allow_clipping_in_normalization: 138 | if hp.symmetric_mels: 139 | return np.clip( 140 | (2 * hp.max_abs_value) * ((S - hp.min_level_db) / (-hp.min_level_db)) 141 | - hp.max_abs_value, 142 | -hp.max_abs_value, 143 | hp.max_abs_value, 144 | ) 145 | else: 146 | return np.clip( 147 | hp.max_abs_value * ((S - hp.min_level_db) / (-hp.min_level_db)), 148 | 0, 149 | hp.max_abs_value, 150 | ) 151 | 152 | assert S.max() <= 0 and S.min() - hp.min_level_db >= 0 153 | if hp.symmetric_mels: 154 | return (2 * hp.max_abs_value) * ( 155 | (S - hp.min_level_db) / (-hp.min_level_db) 156 | ) - hp.max_abs_value 157 | else: 158 | return hp.max_abs_value * ((S - hp.min_level_db) / (-hp.min_level_db)) 159 | 160 | 161 | def _denormalize(D): 162 | if hp.allow_clipping_in_normalization: 163 | if hp.symmetric_mels: 164 | return ( 165 | (np.clip(D, -hp.max_abs_value, hp.max_abs_value) + hp.max_abs_value) 166 | * -hp.min_level_db 167 | / (2 * hp.max_abs_value) 168 | ) + hp.min_level_db 169 | else: 170 | return ( 171 | np.clip(D, 0, hp.max_abs_value) * -hp.min_level_db / hp.max_abs_value 172 | ) + hp.min_level_db 173 | 174 | if hp.symmetric_mels: 175 | return ( 176 | (D + hp.max_abs_value) * -hp.min_level_db / (2 * hp.max_abs_value) 177 | ) + hp.min_level_db 178 | else: 179 | return (D * -hp.min_level_db / hp.max_abs_value) + hp.min_level_db 180 | -------------------------------------------------------------------------------- /checkpoints/README.md: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /checkpoints/mobilenet.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/anothermartz/Easy-Wav2Lip/99928e95687dcf1e31e2a0ac8df3ed64a03bc5ee/checkpoints/mobilenet.pth -------------------------------------------------------------------------------- /config.ini: -------------------------------------------------------------------------------- 1 | # if this file opened 
after running Easy-Wav2Lip.bat, 2 | # please configure accordingly, save and then close this file, 3 | # processing will then begin automatically. 4 | 5 | [OPTIONS] 6 | 7 | video_file = 8 | vocal_file = 9 | 10 | quality = Improved 11 | # Options: 12 | ; Fast: Wav2Lip only 13 | ; Improved: Wav2Lip with a feathered mask around the mouth to remove the square around the face 14 | ; Enhanced: Wav2Lip + mask + GFPGAN upscaling done on the face 15 | ; Experimental: Test version of applying gfpgan - see release notes 16 | 17 | output_height = full resolution 18 | 19 | # Options: 20 | ; full resolution 21 | ; half resolution 22 | ; video height in pixels eg: 480 23 | 24 | wav2lip_version = Wav2Lip 25 | # Wav2Lip or Wav2Lip_GAN 26 | 27 | # Please consult the readme for this and the rest of the options: 28 | ; https://github.com/anothermartz/Easy-Wav2Lip#advanced-tweaking 29 | 30 | use_previous_tracking_data = True 31 | 32 | nosmooth = True 33 | 34 | preview_window = Full 35 | 36 | [PADDING] 37 | u = 0 38 | d = 0 39 | l = 0 40 | r = 0 41 | 42 | [MASK] 43 | size = 2.5 44 | feathering = 2 45 | mouth_tracking = False 46 | debug_mask = False 47 | 48 | [OTHER] 49 | batch_process = False 50 | output_suffix = _Easy-Wav2Lip 51 | include_settings_in_suffix = False 52 | preview_settings = False 53 | frame_to_preview = 100 54 | 55 | -------------------------------------------------------------------------------- /degradations.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import math 3 | import numpy as np 4 | import random 5 | import torch 6 | from scipy import special 7 | from scipy.stats import multivariate_normal 8 | from torchvision.transforms.functional import rgb_to_grayscale 9 | 10 | # -------------------------------------------------------------------- # 11 | # --------------------------- blur kernels --------------------------- # 12 | # -------------------------------------------------------------------- # 13 | 14 | 15 | # --------------------------- util functions --------------------------- # 16 | def sigma_matrix2(sig_x, sig_y, theta): 17 | """Calculate the rotated sigma matrix (two dimensional matrix). 18 | 19 | Args: 20 | sig_x (float): 21 | sig_y (float): 22 | theta (float): Radian measurement. 23 | 24 | Returns: 25 | ndarray: Rotated sigma matrix. 26 | """ 27 | d_matrix = np.array([[sig_x**2, 0], [0, sig_y**2]]) 28 | u_matrix = np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]]) 29 | return np.dot(u_matrix, np.dot(d_matrix, u_matrix.T)) 30 | 31 | 32 | def mesh_grid(kernel_size): 33 | """Generate the mesh grid, centering at zero. 34 | 35 | Args: 36 | kernel_size (int): 37 | 38 | Returns: 39 | xy (ndarray): with the shape (kernel_size, kernel_size, 2) 40 | xx (ndarray): with the shape (kernel_size, kernel_size) 41 | yy (ndarray): with the shape (kernel_size, kernel_size) 42 | """ 43 | ax = np.arange(-kernel_size // 2 + 1., kernel_size // 2 + 1.) 44 | xx, yy = np.meshgrid(ax, ax) 45 | xy = np.hstack((xx.reshape((kernel_size * kernel_size, 1)), yy.reshape(kernel_size * kernel_size, 46 | 1))).reshape(kernel_size, kernel_size, 2) 47 | return xy, xx, yy 48 | 49 | 50 | def pdf2(sigma_matrix, grid): 51 | """Calculate PDF of the bivariate Gaussian distribution. 52 | 53 | Args: 54 | sigma_matrix (ndarray): with the shape (2, 2) 55 | grid (ndarray): generated by :func:`mesh_grid`, 56 | with the shape (K, K, 2), K is the kernel size. 57 | 58 | Returns: 59 | kernel (ndarrray): un-normalized kernel. 
60 | """ 61 | inverse_sigma = np.linalg.inv(sigma_matrix) 62 | kernel = np.exp(-0.5 * np.sum(np.dot(grid, inverse_sigma) * grid, 2)) 63 | return kernel 64 | 65 | 66 | def cdf2(d_matrix, grid): 67 | """Calculate the CDF of the standard bivariate Gaussian distribution. 68 | Used in skewed Gaussian distribution. 69 | 70 | Args: 71 | d_matrix (ndarrasy): skew matrix. 72 | grid (ndarray): generated by :func:`mesh_grid`, 73 | with the shape (K, K, 2), K is the kernel size. 74 | 75 | Returns: 76 | cdf (ndarray): skewed cdf. 77 | """ 78 | rv = multivariate_normal([0, 0], [[1, 0], [0, 1]]) 79 | grid = np.dot(grid, d_matrix) 80 | cdf = rv.cdf(grid) 81 | return cdf 82 | 83 | 84 | def bivariate_Gaussian(kernel_size, sig_x, sig_y, theta, grid=None, isotropic=True): 85 | """Generate a bivariate isotropic or anisotropic Gaussian kernel. 86 | 87 | In the isotropic mode, only `sig_x` is used. `sig_y` and `theta` is ignored. 88 | 89 | Args: 90 | kernel_size (int): 91 | sig_x (float): 92 | sig_y (float): 93 | theta (float): Radian measurement. 94 | grid (ndarray, optional): generated by :func:`mesh_grid`, 95 | with the shape (K, K, 2), K is the kernel size. Default: None 96 | isotropic (bool): 97 | 98 | Returns: 99 | kernel (ndarray): normalized kernel. 100 | """ 101 | if grid is None: 102 | grid, _, _ = mesh_grid(kernel_size) 103 | if isotropic: 104 | sigma_matrix = np.array([[sig_x**2, 0], [0, sig_x**2]]) 105 | else: 106 | sigma_matrix = sigma_matrix2(sig_x, sig_y, theta) 107 | kernel = pdf2(sigma_matrix, grid) 108 | kernel = kernel / np.sum(kernel) 109 | return kernel 110 | 111 | 112 | def bivariate_generalized_Gaussian(kernel_size, sig_x, sig_y, theta, beta, grid=None, isotropic=True): 113 | """Generate a bivariate generalized Gaussian kernel. 114 | 115 | ``Paper: Parameter Estimation For Multivariate Generalized Gaussian Distributions`` 116 | 117 | In the isotropic mode, only `sig_x` is used. `sig_y` and `theta` is ignored. 118 | 119 | Args: 120 | kernel_size (int): 121 | sig_x (float): 122 | sig_y (float): 123 | theta (float): Radian measurement. 124 | beta (float): shape parameter, beta = 1 is the normal distribution. 125 | grid (ndarray, optional): generated by :func:`mesh_grid`, 126 | with the shape (K, K, 2), K is the kernel size. Default: None 127 | 128 | Returns: 129 | kernel (ndarray): normalized kernel. 130 | """ 131 | if grid is None: 132 | grid, _, _ = mesh_grid(kernel_size) 133 | if isotropic: 134 | sigma_matrix = np.array([[sig_x**2, 0], [0, sig_x**2]]) 135 | else: 136 | sigma_matrix = sigma_matrix2(sig_x, sig_y, theta) 137 | inverse_sigma = np.linalg.inv(sigma_matrix) 138 | kernel = np.exp(-0.5 * np.power(np.sum(np.dot(grid, inverse_sigma) * grid, 2), beta)) 139 | kernel = kernel / np.sum(kernel) 140 | return kernel 141 | 142 | 143 | def bivariate_plateau(kernel_size, sig_x, sig_y, theta, beta, grid=None, isotropic=True): 144 | """Generate a plateau-like anisotropic kernel. 145 | 146 | 1 / (1+x^(beta)) 147 | 148 | Reference: https://stats.stackexchange.com/questions/203629/is-there-a-plateau-shaped-distribution 149 | 150 | In the isotropic mode, only `sig_x` is used. `sig_y` and `theta` is ignored. 151 | 152 | Args: 153 | kernel_size (int): 154 | sig_x (float): 155 | sig_y (float): 156 | theta (float): Radian measurement. 157 | beta (float): shape parameter, beta = 1 is the normal distribution. 158 | grid (ndarray, optional): generated by :func:`mesh_grid`, 159 | with the shape (K, K, 2), K is the kernel size. Default: None 160 | 161 | Returns: 162 | kernel (ndarray): normalized kernel. 
163 | """ 164 | if grid is None: 165 | grid, _, _ = mesh_grid(kernel_size) 166 | if isotropic: 167 | sigma_matrix = np.array([[sig_x**2, 0], [0, sig_x**2]]) 168 | else: 169 | sigma_matrix = sigma_matrix2(sig_x, sig_y, theta) 170 | inverse_sigma = np.linalg.inv(sigma_matrix) 171 | kernel = np.reciprocal(np.power(np.sum(np.dot(grid, inverse_sigma) * grid, 2), beta) + 1) 172 | kernel = kernel / np.sum(kernel) 173 | return kernel 174 | 175 | 176 | def random_bivariate_Gaussian(kernel_size, 177 | sigma_x_range, 178 | sigma_y_range, 179 | rotation_range, 180 | noise_range=None, 181 | isotropic=True): 182 | """Randomly generate bivariate isotropic or anisotropic Gaussian kernels. 183 | 184 | In the isotropic mode, only `sigma_x_range` is used. `sigma_y_range` and `rotation_range` is ignored. 185 | 186 | Args: 187 | kernel_size (int): 188 | sigma_x_range (tuple): [0.6, 5] 189 | sigma_y_range (tuple): [0.6, 5] 190 | rotation range (tuple): [-math.pi, math.pi] 191 | noise_range(tuple, optional): multiplicative kernel noise, 192 | [0.75, 1.25]. Default: None 193 | 194 | Returns: 195 | kernel (ndarray): 196 | """ 197 | assert kernel_size % 2 == 1, 'Kernel size must be an odd number.' 198 | assert sigma_x_range[0] < sigma_x_range[1], 'Wrong sigma_x_range.' 199 | sigma_x = np.random.uniform(sigma_x_range[0], sigma_x_range[1]) 200 | if isotropic is False: 201 | assert sigma_y_range[0] < sigma_y_range[1], 'Wrong sigma_y_range.' 202 | assert rotation_range[0] < rotation_range[1], 'Wrong rotation_range.' 203 | sigma_y = np.random.uniform(sigma_y_range[0], sigma_y_range[1]) 204 | rotation = np.random.uniform(rotation_range[0], rotation_range[1]) 205 | else: 206 | sigma_y = sigma_x 207 | rotation = 0 208 | 209 | kernel = bivariate_Gaussian(kernel_size, sigma_x, sigma_y, rotation, isotropic=isotropic) 210 | 211 | # add multiplicative noise 212 | if noise_range is not None: 213 | assert noise_range[0] < noise_range[1], 'Wrong noise range.' 214 | noise = np.random.uniform(noise_range[0], noise_range[1], size=kernel.shape) 215 | kernel = kernel * noise 216 | kernel = kernel / np.sum(kernel) 217 | return kernel 218 | 219 | 220 | def random_bivariate_generalized_Gaussian(kernel_size, 221 | sigma_x_range, 222 | sigma_y_range, 223 | rotation_range, 224 | beta_range, 225 | noise_range=None, 226 | isotropic=True): 227 | """Randomly generate bivariate generalized Gaussian kernels. 228 | 229 | In the isotropic mode, only `sigma_x_range` is used. `sigma_y_range` and `rotation_range` is ignored. 230 | 231 | Args: 232 | kernel_size (int): 233 | sigma_x_range (tuple): [0.6, 5] 234 | sigma_y_range (tuple): [0.6, 5] 235 | rotation range (tuple): [-math.pi, math.pi] 236 | beta_range (tuple): [0.5, 8] 237 | noise_range(tuple, optional): multiplicative kernel noise, 238 | [0.75, 1.25]. Default: None 239 | 240 | Returns: 241 | kernel (ndarray): 242 | """ 243 | assert kernel_size % 2 == 1, 'Kernel size must be an odd number.' 244 | assert sigma_x_range[0] < sigma_x_range[1], 'Wrong sigma_x_range.' 245 | sigma_x = np.random.uniform(sigma_x_range[0], sigma_x_range[1]) 246 | if isotropic is False: 247 | assert sigma_y_range[0] < sigma_y_range[1], 'Wrong sigma_y_range.' 248 | assert rotation_range[0] < rotation_range[1], 'Wrong rotation_range.' 
249 | sigma_y = np.random.uniform(sigma_y_range[0], sigma_y_range[1]) 250 | rotation = np.random.uniform(rotation_range[0], rotation_range[1]) 251 | else: 252 | sigma_y = sigma_x 253 | rotation = 0 254 | 255 | # assume beta_range[0] < 1 < beta_range[1] 256 | if np.random.uniform() < 0.5: 257 | beta = np.random.uniform(beta_range[0], 1) 258 | else: 259 | beta = np.random.uniform(1, beta_range[1]) 260 | 261 | kernel = bivariate_generalized_Gaussian(kernel_size, sigma_x, sigma_y, rotation, beta, isotropic=isotropic) 262 | 263 | # add multiplicative noise 264 | if noise_range is not None: 265 | assert noise_range[0] < noise_range[1], 'Wrong noise range.' 266 | noise = np.random.uniform(noise_range[0], noise_range[1], size=kernel.shape) 267 | kernel = kernel * noise 268 | kernel = kernel / np.sum(kernel) 269 | return kernel 270 | 271 | 272 | def random_bivariate_plateau(kernel_size, 273 | sigma_x_range, 274 | sigma_y_range, 275 | rotation_range, 276 | beta_range, 277 | noise_range=None, 278 | isotropic=True): 279 | """Randomly generate bivariate plateau kernels. 280 | 281 | In the isotropic mode, only `sigma_x_range` is used. `sigma_y_range` and `rotation_range` is ignored. 282 | 283 | Args: 284 | kernel_size (int): 285 | sigma_x_range (tuple): [0.6, 5] 286 | sigma_y_range (tuple): [0.6, 5] 287 | rotation range (tuple): [-math.pi/2, math.pi/2] 288 | beta_range (tuple): [1, 4] 289 | noise_range(tuple, optional): multiplicative kernel noise, 290 | [0.75, 1.25]. Default: None 291 | 292 | Returns: 293 | kernel (ndarray): 294 | """ 295 | assert kernel_size % 2 == 1, 'Kernel size must be an odd number.' 296 | assert sigma_x_range[0] < sigma_x_range[1], 'Wrong sigma_x_range.' 297 | sigma_x = np.random.uniform(sigma_x_range[0], sigma_x_range[1]) 298 | if isotropic is False: 299 | assert sigma_y_range[0] < sigma_y_range[1], 'Wrong sigma_y_range.' 300 | assert rotation_range[0] < rotation_range[1], 'Wrong rotation_range.' 301 | sigma_y = np.random.uniform(sigma_y_range[0], sigma_y_range[1]) 302 | rotation = np.random.uniform(rotation_range[0], rotation_range[1]) 303 | else: 304 | sigma_y = sigma_x 305 | rotation = 0 306 | 307 | # TODO: this may be not proper 308 | if np.random.uniform() < 0.5: 309 | beta = np.random.uniform(beta_range[0], 1) 310 | else: 311 | beta = np.random.uniform(1, beta_range[1]) 312 | 313 | kernel = bivariate_plateau(kernel_size, sigma_x, sigma_y, rotation, beta, isotropic=isotropic) 314 | # add multiplicative noise 315 | if noise_range is not None: 316 | assert noise_range[0] < noise_range[1], 'Wrong noise range.' 317 | noise = np.random.uniform(noise_range[0], noise_range[1], size=kernel.shape) 318 | kernel = kernel * noise 319 | kernel = kernel / np.sum(kernel) 320 | 321 | return kernel 322 | 323 | 324 | def random_mixed_kernels(kernel_list, 325 | kernel_prob, 326 | kernel_size=21, 327 | sigma_x_range=(0.6, 5), 328 | sigma_y_range=(0.6, 5), 329 | rotation_range=(-math.pi, math.pi), 330 | betag_range=(0.5, 8), 331 | betap_range=(0.5, 8), 332 | noise_range=None): 333 | """Randomly generate mixed kernels. 
334 | 335 | Args: 336 | kernel_list (tuple): a list name of kernel types, 337 | support ['iso', 'aniso', 'skew', 'generalized', 'plateau_iso', 338 | 'plateau_aniso'] 339 | kernel_prob (tuple): corresponding kernel probability for each 340 | kernel type 341 | kernel_size (int): 342 | sigma_x_range (tuple): [0.6, 5] 343 | sigma_y_range (tuple): [0.6, 5] 344 | rotation range (tuple): [-math.pi, math.pi] 345 | beta_range (tuple): [0.5, 8] 346 | noise_range(tuple, optional): multiplicative kernel noise, 347 | [0.75, 1.25]. Default: None 348 | 349 | Returns: 350 | kernel (ndarray): 351 | """ 352 | kernel_type = random.choices(kernel_list, kernel_prob)[0] 353 | if kernel_type == 'iso': 354 | kernel = random_bivariate_Gaussian( 355 | kernel_size, sigma_x_range, sigma_y_range, rotation_range, noise_range=noise_range, isotropic=True) 356 | elif kernel_type == 'aniso': 357 | kernel = random_bivariate_Gaussian( 358 | kernel_size, sigma_x_range, sigma_y_range, rotation_range, noise_range=noise_range, isotropic=False) 359 | elif kernel_type == 'generalized_iso': 360 | kernel = random_bivariate_generalized_Gaussian( 361 | kernel_size, 362 | sigma_x_range, 363 | sigma_y_range, 364 | rotation_range, 365 | betag_range, 366 | noise_range=noise_range, 367 | isotropic=True) 368 | elif kernel_type == 'generalized_aniso': 369 | kernel = random_bivariate_generalized_Gaussian( 370 | kernel_size, 371 | sigma_x_range, 372 | sigma_y_range, 373 | rotation_range, 374 | betag_range, 375 | noise_range=noise_range, 376 | isotropic=False) 377 | elif kernel_type == 'plateau_iso': 378 | kernel = random_bivariate_plateau( 379 | kernel_size, sigma_x_range, sigma_y_range, rotation_range, betap_range, noise_range=None, isotropic=True) 380 | elif kernel_type == 'plateau_aniso': 381 | kernel = random_bivariate_plateau( 382 | kernel_size, sigma_x_range, sigma_y_range, rotation_range, betap_range, noise_range=None, isotropic=False) 383 | return kernel 384 | 385 | 386 | np.seterr(divide='ignore', invalid='ignore') 387 | 388 | 389 | def circular_lowpass_kernel(cutoff, kernel_size, pad_to=0): 390 | """2D sinc filter 391 | 392 | Reference: https://dsp.stackexchange.com/questions/58301/2-d-circularly-symmetric-low-pass-filter 393 | 394 | Args: 395 | cutoff (float): cutoff frequency in radians (pi is max) 396 | kernel_size (int): horizontal and vertical size, must be odd. 397 | pad_to (int): pad kernel size to desired size, must be odd or zero. 398 | """ 399 | assert kernel_size % 2 == 1, 'Kernel size must be an odd number.' 400 | kernel = np.fromfunction( 401 | lambda x, y: cutoff * special.j1(cutoff * np.sqrt( 402 | (x - (kernel_size - 1) / 2)**2 + (y - (kernel_size - 1) / 2)**2)) / (2 * np.pi * np.sqrt( 403 | (x - (kernel_size - 1) / 2)**2 + (y - (kernel_size - 1) / 2)**2)), [kernel_size, kernel_size]) 404 | kernel[(kernel_size - 1) // 2, (kernel_size - 1) // 2] = cutoff**2 / (4 * np.pi) 405 | kernel = kernel / np.sum(kernel) 406 | if pad_to > kernel_size: 407 | pad_size = (pad_to - kernel_size) // 2 408 | kernel = np.pad(kernel, ((pad_size, pad_size), (pad_size, pad_size))) 409 | return kernel 410 | 411 | 412 | # ------------------------------------------------------------- # 413 | # --------------------------- noise --------------------------- # 414 | # ------------------------------------------------------------- # 415 | 416 | # ----------------------- Gaussian Noise ----------------------- # 417 | 418 | 419 | def generate_gaussian_noise(img, sigma=10, gray_noise=False): 420 | """Generate Gaussian noise. 
421 | 422 | Args: 423 | img (Numpy array): Input image, shape (h, w, c), range [0, 1], float32. 424 | sigma (float): Noise scale (measured in range 255). Default: 10. 425 | 426 | Returns: 427 | (Numpy array): Returned noisy image, shape (h, w, c), range[0, 1], 428 | float32. 429 | """ 430 | if gray_noise: 431 | noise = np.float32(np.random.randn(*(img.shape[0:2]))) * sigma / 255. 432 | noise = np.expand_dims(noise, axis=2).repeat(3, axis=2) 433 | else: 434 | noise = np.float32(np.random.randn(*(img.shape))) * sigma / 255. 435 | return noise 436 | 437 | 438 | def add_gaussian_noise(img, sigma=10, clip=True, rounds=False, gray_noise=False): 439 | """Add Gaussian noise. 440 | 441 | Args: 442 | img (Numpy array): Input image, shape (h, w, c), range [0, 1], float32. 443 | sigma (float): Noise scale (measured in range 255). Default: 10. 444 | 445 | Returns: 446 | (Numpy array): Returned noisy image, shape (h, w, c), range[0, 1], 447 | float32. 448 | """ 449 | noise = generate_gaussian_noise(img, sigma, gray_noise) 450 | out = img + noise 451 | if clip and rounds: 452 | out = np.clip((out * 255.0).round(), 0, 255) / 255. 453 | elif clip: 454 | out = np.clip(out, 0, 1) 455 | elif rounds: 456 | out = (out * 255.0).round() / 255. 457 | return out 458 | 459 | 460 | def generate_gaussian_noise_pt(img, sigma=10, gray_noise=0): 461 | """Add Gaussian noise (PyTorch version). 462 | 463 | Args: 464 | img (Tensor): Shape (b, c, h, w), range[0, 1], float32. 465 | scale (float | Tensor): Noise scale. Default: 1.0. 466 | 467 | Returns: 468 | (Tensor): Returned noisy image, shape (b, c, h, w), range[0, 1], 469 | float32. 470 | """ 471 | b, _, h, w = img.size() 472 | if not isinstance(sigma, (float, int)): 473 | sigma = sigma.view(img.size(0), 1, 1, 1) 474 | if isinstance(gray_noise, (float, int)): 475 | cal_gray_noise = gray_noise > 0 476 | else: 477 | gray_noise = gray_noise.view(b, 1, 1, 1) 478 | cal_gray_noise = torch.sum(gray_noise) > 0 479 | 480 | if cal_gray_noise: 481 | noise_gray = torch.randn(*img.size()[2:4], dtype=img.dtype, device=img.device) * sigma / 255. 482 | noise_gray = noise_gray.view(b, 1, h, w) 483 | 484 | # always calculate color noise 485 | noise = torch.randn(*img.size(), dtype=img.dtype, device=img.device) * sigma / 255. 486 | 487 | if cal_gray_noise: 488 | noise = noise * (1 - gray_noise) + noise_gray * gray_noise 489 | return noise 490 | 491 | 492 | def add_gaussian_noise_pt(img, sigma=10, gray_noise=0, clip=True, rounds=False): 493 | """Add Gaussian noise (PyTorch version). 494 | 495 | Args: 496 | img (Tensor): Shape (b, c, h, w), range[0, 1], float32. 497 | scale (float | Tensor): Noise scale. Default: 1.0. 498 | 499 | Returns: 500 | (Tensor): Returned noisy image, shape (b, c, h, w), range[0, 1], 501 | float32. 502 | """ 503 | noise = generate_gaussian_noise_pt(img, sigma, gray_noise) 504 | out = img + noise 505 | if clip and rounds: 506 | out = torch.clamp((out * 255.0).round(), 0, 255) / 255. 507 | elif clip: 508 | out = torch.clamp(out, 0, 1) 509 | elif rounds: 510 | out = (out * 255.0).round() / 255. 
511 | return out 512 | 513 | 514 | # ----------------------- Random Gaussian Noise ----------------------- # 515 | def random_generate_gaussian_noise(img, sigma_range=(0, 10), gray_prob=0): 516 | sigma = np.random.uniform(sigma_range[0], sigma_range[1]) 517 | if np.random.uniform() < gray_prob: 518 | gray_noise = True 519 | else: 520 | gray_noise = False 521 | return generate_gaussian_noise(img, sigma, gray_noise) 522 | 523 | 524 | def random_add_gaussian_noise(img, sigma_range=(0, 1.0), gray_prob=0, clip=True, rounds=False): 525 | noise = random_generate_gaussian_noise(img, sigma_range, gray_prob) 526 | out = img + noise 527 | if clip and rounds: 528 | out = np.clip((out * 255.0).round(), 0, 255) / 255. 529 | elif clip: 530 | out = np.clip(out, 0, 1) 531 | elif rounds: 532 | out = (out * 255.0).round() / 255. 533 | return out 534 | 535 | 536 | def random_generate_gaussian_noise_pt(img, sigma_range=(0, 10), gray_prob=0): 537 | sigma = torch.rand( 538 | img.size(0), dtype=img.dtype, device=img.device) * (sigma_range[1] - sigma_range[0]) + sigma_range[0] 539 | gray_noise = torch.rand(img.size(0), dtype=img.dtype, device=img.device) 540 | gray_noise = (gray_noise < gray_prob).float() 541 | return generate_gaussian_noise_pt(img, sigma, gray_noise) 542 | 543 | 544 | def random_add_gaussian_noise_pt(img, sigma_range=(0, 1.0), gray_prob=0, clip=True, rounds=False): 545 | noise = random_generate_gaussian_noise_pt(img, sigma_range, gray_prob) 546 | out = img + noise 547 | if clip and rounds: 548 | out = torch.clamp((out * 255.0).round(), 0, 255) / 255. 549 | elif clip: 550 | out = torch.clamp(out, 0, 1) 551 | elif rounds: 552 | out = (out * 255.0).round() / 255. 553 | return out 554 | 555 | 556 | # ----------------------- Poisson (Shot) Noise ----------------------- # 557 | 558 | 559 | def generate_poisson_noise(img, scale=1.0, gray_noise=False): 560 | """Generate poisson noise. 561 | 562 | Reference: https://github.com/scikit-image/scikit-image/blob/main/skimage/util/noise.py#L37-L219 563 | 564 | Args: 565 | img (Numpy array): Input image, shape (h, w, c), range [0, 1], float32. 566 | scale (float): Noise scale. Default: 1.0. 567 | gray_noise (bool): Whether generate gray noise. Default: False. 568 | 569 | Returns: 570 | (Numpy array): Returned noisy image, shape (h, w, c), range[0, 1], 571 | float32. 572 | """ 573 | if gray_noise: 574 | img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) 575 | # round and clip image for counting vals correctly 576 | img = np.clip((img * 255.0).round(), 0, 255) / 255. 577 | vals = len(np.unique(img)) 578 | vals = 2**np.ceil(np.log2(vals)) 579 | out = np.float32(np.random.poisson(img * vals) / float(vals)) 580 | noise = out - img 581 | if gray_noise: 582 | noise = np.repeat(noise[:, :, np.newaxis], 3, axis=2) 583 | return noise * scale 584 | 585 | 586 | def add_poisson_noise(img, scale=1.0, clip=True, rounds=False, gray_noise=False): 587 | """Add poisson noise. 588 | 589 | Args: 590 | img (Numpy array): Input image, shape (h, w, c), range [0, 1], float32. 591 | scale (float): Noise scale. Default: 1.0. 592 | gray_noise (bool): Whether generate gray noise. Default: False. 593 | 594 | Returns: 595 | (Numpy array): Returned noisy image, shape (h, w, c), range[0, 1], 596 | float32. 597 | """ 598 | noise = generate_poisson_noise(img, scale, gray_noise) 599 | out = img + noise 600 | if clip and rounds: 601 | out = np.clip((out * 255.0).round(), 0, 255) / 255. 602 | elif clip: 603 | out = np.clip(out, 0, 1) 604 | elif rounds: 605 | out = (out * 255.0).round() / 255. 
606 | return out 607 | 608 | 609 | def generate_poisson_noise_pt(img, scale=1.0, gray_noise=0): 610 | """Generate a batch of poisson noise (PyTorch version) 611 | 612 | Args: 613 | img (Tensor): Input image, shape (b, c, h, w), range [0, 1], float32. 614 | scale (float | Tensor): Noise scale. Number or Tensor with shape (b). 615 | Default: 1.0. 616 | gray_noise (float | Tensor): 0-1 number or Tensor with shape (b). 617 | 0 for False, 1 for True. Default: 0. 618 | 619 | Returns: 620 | (Tensor): Returned noisy image, shape (b, c, h, w), range[0, 1], 621 | float32. 622 | """ 623 | b, _, h, w = img.size() 624 | if isinstance(gray_noise, (float, int)): 625 | cal_gray_noise = gray_noise > 0 626 | else: 627 | gray_noise = gray_noise.view(b, 1, 1, 1) 628 | cal_gray_noise = torch.sum(gray_noise) > 0 629 | if cal_gray_noise: 630 | img_gray = rgb_to_grayscale(img, num_output_channels=1) 631 | # round and clip image for counting vals correctly 632 | img_gray = torch.clamp((img_gray * 255.0).round(), 0, 255) / 255. 633 | # use for-loop to get the unique values for each sample 634 | vals_list = [len(torch.unique(img_gray[i, :, :, :])) for i in range(b)] 635 | vals_list = [2**np.ceil(np.log2(vals)) for vals in vals_list] 636 | vals = img_gray.new_tensor(vals_list).view(b, 1, 1, 1) 637 | out = torch.poisson(img_gray * vals) / vals 638 | noise_gray = out - img_gray 639 | noise_gray = noise_gray.expand(b, 3, h, w) 640 | 641 | # always calculate color noise 642 | # round and clip image for counting vals correctly 643 | img = torch.clamp((img * 255.0).round(), 0, 255) / 255. 644 | # use for-loop to get the unique values for each sample 645 | vals_list = [len(torch.unique(img[i, :, :, :])) for i in range(b)] 646 | vals_list = [2**np.ceil(np.log2(vals)) for vals in vals_list] 647 | vals = img.new_tensor(vals_list).view(b, 1, 1, 1) 648 | out = torch.poisson(img * vals) / vals 649 | noise = out - img 650 | if cal_gray_noise: 651 | noise = noise * (1 - gray_noise) + noise_gray * gray_noise 652 | if not isinstance(scale, (float, int)): 653 | scale = scale.view(b, 1, 1, 1) 654 | return noise * scale 655 | 656 | 657 | def add_poisson_noise_pt(img, scale=1.0, clip=True, rounds=False, gray_noise=0): 658 | """Add poisson noise to a batch of images (PyTorch version). 659 | 660 | Args: 661 | img (Tensor): Input image, shape (b, c, h, w), range [0, 1], float32. 662 | scale (float | Tensor): Noise scale. Number or Tensor with shape (b). 663 | Default: 1.0. 664 | gray_noise (float | Tensor): 0-1 number or Tensor with shape (b). 665 | 0 for False, 1 for True. Default: 0. 666 | 667 | Returns: 668 | (Tensor): Returned noisy image, shape (b, c, h, w), range[0, 1], 669 | float32. 670 | """ 671 | noise = generate_poisson_noise_pt(img, scale, gray_noise) 672 | out = img + noise 673 | if clip and rounds: 674 | out = torch.clamp((out * 255.0).round(), 0, 255) / 255. 675 | elif clip: 676 | out = torch.clamp(out, 0, 1) 677 | elif rounds: 678 | out = (out * 255.0).round() / 255. 
679 | return out 680 | 681 | 682 | # ----------------------- Random Poisson (Shot) Noise ----------------------- # 683 | 684 | 685 | def random_generate_poisson_noise(img, scale_range=(0, 1.0), gray_prob=0): 686 | scale = np.random.uniform(scale_range[0], scale_range[1]) 687 | if np.random.uniform() < gray_prob: 688 | gray_noise = True 689 | else: 690 | gray_noise = False 691 | return generate_poisson_noise(img, scale, gray_noise) 692 | 693 | 694 | def random_add_poisson_noise(img, scale_range=(0, 1.0), gray_prob=0, clip=True, rounds=False): 695 | noise = random_generate_poisson_noise(img, scale_range, gray_prob) 696 | out = img + noise 697 | if clip and rounds: 698 | out = np.clip((out * 255.0).round(), 0, 255) / 255. 699 | elif clip: 700 | out = np.clip(out, 0, 1) 701 | elif rounds: 702 | out = (out * 255.0).round() / 255. 703 | return out 704 | 705 | 706 | def random_generate_poisson_noise_pt(img, scale_range=(0, 1.0), gray_prob=0): 707 | scale = torch.rand( 708 | img.size(0), dtype=img.dtype, device=img.device) * (scale_range[1] - scale_range[0]) + scale_range[0] 709 | gray_noise = torch.rand(img.size(0), dtype=img.dtype, device=img.device) 710 | gray_noise = (gray_noise < gray_prob).float() 711 | return generate_poisson_noise_pt(img, scale, gray_noise) 712 | 713 | 714 | def random_add_poisson_noise_pt(img, scale_range=(0, 1.0), gray_prob=0, clip=True, rounds=False): 715 | noise = random_generate_poisson_noise_pt(img, scale_range, gray_prob) 716 | out = img + noise 717 | if clip and rounds: 718 | out = torch.clamp((out * 255.0).round(), 0, 255) / 255. 719 | elif clip: 720 | out = torch.clamp(out, 0, 1) 721 | elif rounds: 722 | out = (out * 255.0).round() / 255. 723 | return out 724 | 725 | 726 | # ------------------------------------------------------------------------ # 727 | # --------------------------- JPEG compression --------------------------- # 728 | # ------------------------------------------------------------------------ # 729 | 730 | 731 | def add_jpg_compression(img, quality=90): 732 | """Add JPG compression artifacts. 733 | 734 | Args: 735 | img (Numpy array): Input image, shape (h, w, c), range [0, 1], float32. 736 | quality (float): JPG compression quality. 0 for lowest quality, 100 for 737 | best quality. Default: 90. 738 | 739 | Returns: 740 | (Numpy array): Returned image after JPG, shape (h, w, c), range[0, 1], 741 | float32. 742 | """ 743 | img = np.clip(img, 0, 1) 744 | encode_param = [int(cv2.IMWRITE_JPEG_QUALITY), quality] 745 | _, encimg = cv2.imencode('.jpg', img * 255., encode_param) 746 | img = np.float32(cv2.imdecode(encimg, 1)) / 255. 747 | return img 748 | 749 | 750 | def random_add_jpg_compression(img, quality_range=(90, 100)): 751 | """Randomly add JPG compression artifacts. 752 | 753 | Args: 754 | img (Numpy array): Input image, shape (h, w, c), range [0, 1], float32. 755 | quality_range (tuple[float] | list[float]): JPG compression quality 756 | range. 0 for lowest quality, 100 for best quality. 757 | Default: (90, 100). 758 | 759 | Returns: 760 | (Numpy array): Returned image after JPG, shape (h, w, c), range[0, 1], 761 | float32. 
762 | """ 763 | quality = np.random.uniform(quality_range[0], quality_range[1]) 764 | return add_jpg_compression(img, quality) 765 | -------------------------------------------------------------------------------- /easy_functions.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import subprocess 3 | import json 4 | import os 5 | import dlib 6 | import gdown 7 | import pickle 8 | import re 9 | from models import Wav2Lip 10 | from base64 import b64encode 11 | from urllib.parse import urlparse 12 | from torch.hub import download_url_to_file, get_dir 13 | from IPython.display import HTML, display 14 | 15 | device = 'cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu' 16 | 17 | 18 | def get_video_details(filename): 19 | cmd = [ 20 | "ffprobe", 21 | "-v", 22 | "error", 23 | "-show_format", 24 | "-show_streams", 25 | "-of", 26 | "json", 27 | filename, 28 | ] 29 | result = subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT) 30 | info = json.loads(result.stdout) 31 | 32 | # Get video stream 33 | video_stream = next( 34 | stream for stream in info["streams"] if stream["codec_type"] == "video" 35 | ) 36 | 37 | # Get resolution 38 | width = int(video_stream["width"]) 39 | height = int(video_stream["height"]) 40 | resolution = width * height 41 | 42 | # Get fps 43 | fps = eval(video_stream["avg_frame_rate"]) 44 | 45 | # Get length 46 | length = float(info["format"]["duration"]) 47 | 48 | return width, height, fps, length 49 | 50 | 51 | def show_video(file_path): 52 | """Function to display video in Colab""" 53 | mp4 = open(file_path, "rb").read() 54 | data_url = "data:video/mp4;base64," + b64encode(mp4).decode() 55 | width, _, _, _ = get_video_details(file_path) 56 | display( 57 | HTML( 58 | """ 59 | 62 | """ 63 | % (min(width, 1280), data_url) 64 | ) 65 | ) 66 | 67 | 68 | def format_time(seconds): 69 | hours = int(seconds // 3600) 70 | minutes = int((seconds % 3600) // 60) 71 | seconds = int(seconds % 60) 72 | 73 | if hours > 0: 74 | return f"{hours}h {minutes}m {seconds}s" 75 | elif minutes > 0: 76 | return f"{minutes}m {seconds}s" 77 | else: 78 | return f"{seconds}s" 79 | 80 | 81 | def _load(checkpoint_path): 82 | if device != "cpu": 83 | checkpoint = torch.load(checkpoint_path) 84 | else: 85 | checkpoint = torch.load( 86 | checkpoint_path, map_location=lambda storage, loc: storage 87 | ) 88 | return checkpoint 89 | 90 | 91 | def load_model(path): 92 | # If results file exists, load it and return 93 | working_directory = os.getcwd() 94 | folder, filename_with_extension = os.path.split(path) 95 | filename, file_type = os.path.splitext(filename_with_extension) 96 | results_file = os.path.join(folder, filename + ".pk1") 97 | if os.path.exists(results_file): 98 | with open(results_file, "rb") as f: 99 | return pickle.load(f) 100 | model = Wav2Lip() 101 | print("Loading {}".format(path)) 102 | checkpoint = _load(path) 103 | s = checkpoint["state_dict"] 104 | new_s = {} 105 | for k, v in s.items(): 106 | new_s[k.replace("module.", "")] = v 107 | model.load_state_dict(new_s) 108 | 109 | model = model.to(device) 110 | # Save results to file 111 | with open(results_file, "wb") as f: 112 | pickle.dump(model.eval(), f) 113 | # os.remove(path) 114 | return model.eval() 115 | 116 | 117 | def get_input_length(filename): 118 | result = subprocess.run( 119 | [ 120 | "ffprobe", 121 | "-v", 122 | "error", 123 | "-show_entries", 124 | "format=duration", 125 | "-of", 126 | 
"default=noprint_wrappers=1:nokey=1", 127 | filename, 128 | ], 129 | stdout=subprocess.PIPE, 130 | stderr=subprocess.STDOUT, 131 | ) 132 | return float(result.stdout) 133 | 134 | 135 | def is_url(string): 136 | url_regex = re.compile(r"^(https?|ftp)://[^\s/$.?#].[^\s]*$") 137 | return bool(url_regex.match(string)) 138 | 139 | 140 | def load_predictor(): 141 | checkpoint = os.path.join( 142 | "checkpoints", "shape_predictor_68_face_landmarks_GTX.dat" 143 | ) 144 | predictor = dlib.shape_predictor(checkpoint) 145 | mouth_detector = dlib.get_frontal_face_detector() 146 | 147 | # Serialize the variables 148 | with open(os.path.join("checkpoints", "predictor.pkl"), "wb") as f: 149 | pickle.dump(predictor, f) 150 | 151 | with open(os.path.join("checkpoints", "mouth_detector.pkl"), "wb") as f: 152 | pickle.dump(mouth_detector, f) 153 | 154 | # delete the .dat file as it is no longer needed 155 | # os.remove(output) 156 | 157 | 158 | def load_file_from_url(url, model_dir=None, progress=True, file_name=None): 159 | """Load file form http url, will download models if necessary. 160 | 161 | Ref:https://github.com/1adrianb/face-alignment/blob/master/face_alignment/utils.py 162 | 163 | Args: 164 | url (str): URL to be downloaded. 165 | model_dir (str): The path to save the downloaded model. Should be a full path. If None, use pytorch hub_dir. 166 | Default: None. 167 | progress (bool): Whether to show the download progress. Default: True. 168 | file_name (str): The downloaded file name. If None, use the file name in the url. Default: None. 169 | 170 | Returns: 171 | str: The path to the downloaded file. 172 | """ 173 | if model_dir is None: # use the pytorch hub_dir 174 | hub_dir = get_dir() 175 | model_dir = os.path.join(hub_dir, "checkpoints") 176 | 177 | os.makedirs(model_dir, exist_ok=True) 178 | 179 | parts = urlparse(url) 180 | filename = os.path.basename(parts.path) 181 | if file_name is not None: 182 | filename = file_name 183 | cached_file = os.path.abspath(os.path.join(model_dir, filename)) 184 | if not os.path.exists(cached_file): 185 | print(f'Downloading: "{url}" to {cached_file}\n') 186 | download_url_to_file(url, cached_file, hash_prefix=None, progress=progress) 187 | return cached_file 188 | 189 | 190 | def g_colab(): 191 | try: 192 | import google.colab 193 | 194 | return True 195 | except ImportError: 196 | return False 197 | -------------------------------------------------------------------------------- /enhance.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | from gfpgan import GFPGANer 3 | 4 | warnings.filterwarnings("ignore") 5 | 6 | 7 | def load_sr(): 8 | run_params = GFPGANer( 9 | model_path="checkpoints/GFPGANv1.4.pth", 10 | upscale=1, 11 | arch="clean", 12 | channel_multiplier=2, 13 | bg_upsampler=None, 14 | ) 15 | return run_params 16 | 17 | 18 | def upscale(image, properties): 19 | _, _, output = properties.enhance( 20 | image, has_aligned=False, only_center_face=False, paste_back=True 21 | ) 22 | return output 23 | -------------------------------------------------------------------------------- /hparams.py: -------------------------------------------------------------------------------- 1 | from glob import glob 2 | import os 3 | 4 | 5 | def get_image_list(data_root, split): 6 | filelist = [] 7 | 8 | with open("filelists/{}.txt".format(split)) as f: 9 | for line in f: 10 | line = line.strip() 11 | if " " in line: 12 | line = line.split()[0] 13 | filelist.append(os.path.join(data_root, line)) 14 | 15 | return 
filelist 16 | 17 | 18 | class HParams: 19 | def __init__(self, **kwargs): 20 | self.data = {} 21 | 22 | for key, value in kwargs.items(): 23 | self.data[key] = value 24 | 25 | def __getattr__(self, key): 26 | if key not in self.data: 27 | raise AttributeError("'HParams' object has no attribute %s" % key) 28 | return self.data[key] 29 | 30 | def set_hparam(self, key, value): 31 | self.data[key] = value 32 | 33 | 34 | # Default hyperparameters 35 | hparams = HParams( 36 | num_mels=80, # Number of mel-spectrogram channels and local conditioning dimensionality 37 | # network 38 | rescale=True, # Whether to rescale audio prior to preprocessing 39 | rescaling_max=0.9, # Rescaling value 40 | # Use LWS (https://github.com/Jonathan-LeRoux/lws) for STFT and phase reconstruction 41 | # It"s preferred to set True to use with https://github.com/r9y9/wavenet_vocoder 42 | # Does not work if n_ffit is not multiple of hop_size!! 43 | use_lws=False, 44 | n_fft=800, # Extra window size is filled with 0 paddings to match this parameter 45 | hop_size=200, # For 16000Hz, 200 = 12.5 ms (0.0125 * sample_rate) 46 | win_size=800, # For 16000Hz, 800 = 50 ms (If None, win_size = n_fft) (0.05 * sample_rate) 47 | sample_rate=16000, # 16000Hz (corresponding to librispeech) (sox --i ) 48 | frame_shift_ms=None, # Can replace hop_size parameter. (Recommended: 12.5) 49 | # Mel and Linear spectrograms normalization/scaling and clipping 50 | signal_normalization=True, 51 | # Whether to normalize mel spectrograms to some predefined range (following below parameters) 52 | allow_clipping_in_normalization=True, # Only relevant if mel_normalization = True 53 | symmetric_mels=True, 54 | # Whether to scale the data to be symmetric around 0. (Also multiplies the output range by 2, 55 | # faster and cleaner convergence) 56 | max_abs_value=4.0, 57 | # max absolute value of data. If symmetric, data will be [-max, max] else [0, max] (Must not 58 | # be too big to avoid gradient explosion, 59 | # not too small for fast convergence) 60 | # Contribution by @begeekmyfriend 61 | # Spectrogram Pre-Emphasis (Lfilter: Reduce spectrogram noise and helps model certitude 62 | # levels. Also allows for better G&L phase reconstruction) 63 | preemphasize=True, # whether to apply filter 64 | preemphasis=0.97, # filter coefficient. 65 | # Limits 66 | min_level_db=-100, 67 | ref_level_db=20, 68 | fmin=55, 69 | # Set this to 55 if your speaker is male! if female, 95 should help taking off noise. (To 70 | # test depending on dataset. Pitch info: male~[65, 260], female~[100, 525]) 71 | fmax=7600, # To be increased/reduced depending on data. 72 | ###################### Our training parameters ################################# 73 | img_size=96, 74 | fps=25, 75 | batch_size=16, 76 | initial_learning_rate=1e-4, 77 | nepochs=200000000000000000, ### ctrl + c, stop whenever eval loss is consistently greater than train loss for ~10 epochs 78 | num_workers=16, 79 | checkpoint_interval=3000, 80 | eval_interval=3000, 81 | save_optimizer_state=True, 82 | syncnet_wt=0.0, # is initially zero, will be set automatically to 0.03 later. Leads to faster convergence. 
83 | syncnet_batch_size=64, 84 | syncnet_lr=1e-4, 85 | syncnet_eval_interval=10000, 86 | syncnet_checkpoint_interval=10000, 87 | disc_wt=0.07, 88 | disc_initial_learning_rate=1e-4, 89 | ) 90 | 91 | 92 | def hparams_debug_string(): 93 | values = hparams.values() 94 | hp = [ 95 | " %s: %s" % (name, values[name]) 96 | for name in sorted(values) 97 | if name != "sentences" 98 | ] 99 | return "Hyperparameters:\n" + "\n".join(hp) 100 | -------------------------------------------------------------------------------- /inference.py: -------------------------------------------------------------------------------- 1 | print("\rloading torch ", end="") 2 | import torch 3 | 4 | print("\rloading numpy ", end="") 5 | import numpy as np 6 | 7 | print("\rloading Image ", end="") 8 | from PIL import Image 9 | 10 | print("\rloading argparse ", end="") 11 | import argparse 12 | 13 | print("\rloading configparser", end="") 14 | import configparser 15 | 16 | print("\rloading math ", end="") 17 | import math 18 | 19 | print("\rloading os ", end="") 20 | import os 21 | 22 | print("\rloading subprocess ", end="") 23 | import subprocess 24 | 25 | print("\rloading pickle ", end="") 26 | import pickle 27 | 28 | print("\rloading cv2 ", end="") 29 | import cv2 30 | 31 | print("\rloading audio ", end="") 32 | import audio 33 | 34 | print("\rloading RetinaFace ", end="") 35 | from batch_face import RetinaFace 36 | 37 | print("\rloading re ", end="") 38 | import re 39 | 40 | print("\rloading partial ", end="") 41 | from functools import partial 42 | 43 | print("\rloading tqdm ", end="") 44 | from tqdm import tqdm 45 | 46 | print("\rloading warnings ", end="") 47 | import warnings 48 | 49 | warnings.filterwarnings( 50 | "ignore", category=UserWarning, module="torchvision.transforms.functional_tensor" 51 | ) 52 | print("\rloading upscale ", end="") 53 | from enhance import upscale 54 | 55 | print("\rloading load_sr ", end="") 56 | from enhance import load_sr 57 | 58 | print("\rloading load_model ", end="") 59 | from easy_functions import load_model, g_colab 60 | 61 | print("\rimports loaded! ") 62 | 63 | device = 'cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu' 64 | gpu_id = 0 if torch.cuda.is_available() else -1 65 | 66 | if device == 'cpu': 67 | print('Warning: No GPU detected so inference will be done on the CPU which is VERY SLOW!') 68 | parser = argparse.ArgumentParser( 69 | description="Inference code to lip-sync videos in the wild using Wav2Lip models" 70 | ) 71 | 72 | parser.add_argument( 73 | "--checkpoint_path", 74 | type=str, 75 | help="Name of saved checkpoint to load weights from", 76 | required=True, 77 | ) 78 | 79 | parser.add_argument( 80 | "--segmentation_path", 81 | type=str, 82 | default="checkpoints/face_segmentation.pth", 83 | help="Name of saved checkpoint of segmentation network", 84 | required=False, 85 | ) 86 | 87 | parser.add_argument( 88 | "--face", 89 | type=str, 90 | help="Filepath of video/image that contains faces to use", 91 | required=True, 92 | ) 93 | parser.add_argument( 94 | "--audio", 95 | type=str, 96 | help="Filepath of video/audio file to use as raw audio source", 97 | required=True, 98 | ) 99 | parser.add_argument( 100 | "--outfile", 101 | type=str, 102 | help="Video path to save result. 
See default for an e.g.", 103 | default="results/result_voice.mp4", 104 | ) 105 | 106 | parser.add_argument( 107 | "--static", 108 | type=bool, 109 | help="If True, then use only first video frame for inference", 110 | default=False, 111 | ) 112 | parser.add_argument( 113 | "--fps", 114 | type=float, 115 | help="Can be specified only if input is a static image (default: 25)", 116 | default=25.0, 117 | required=False, 118 | ) 119 | 120 | parser.add_argument( 121 | "--pads", 122 | nargs="+", 123 | type=int, 124 | default=[0, 10, 0, 0], 125 | help="Padding (top, bottom, left, right). Please adjust to include chin at least", 126 | ) 127 | 128 | parser.add_argument( 129 | "--wav2lip_batch_size", type=int, help="Batch size for Wav2Lip model(s)", default=1 130 | ) 131 | 132 | parser.add_argument( 133 | "--out_height", 134 | default=480, 135 | type=int, 136 | help="Output video height. Best results are obtained at 480 or 720", 137 | ) 138 | 139 | parser.add_argument( 140 | "--crop", 141 | nargs="+", 142 | type=int, 143 | default=[0, -1, 0, -1], 144 | help="Crop video to a smaller region (top, bottom, left, right). Applied after resize_factor and rotate arg. " 145 | "Useful if multiple face present. -1 implies the value will be auto-inferred based on height, width", 146 | ) 147 | 148 | parser.add_argument( 149 | "--box", 150 | nargs="+", 151 | type=int, 152 | default=[-1, -1, -1, -1], 153 | help="Specify a constant bounding box for the face. Use only as a last resort if the face is not detected." 154 | "Also, might work only if the face is not moving around much. Syntax: (top, bottom, left, right).", 155 | ) 156 | 157 | parser.add_argument( 158 | "--rotate", 159 | default=False, 160 | action="store_true", 161 | help="Sometimes videos taken from a phone can be flipped 90deg. If true, will flip video right by 90deg." 
162 | "Use if you get a flipped result, despite feeding a normal looking video", 163 | ) 164 | 165 | parser.add_argument( 166 | "--nosmooth", 167 | type=str, 168 | default=False, 169 | help="Prevent smoothing face detections over a short temporal window", 170 | ) 171 | 172 | parser.add_argument( 173 | "--no_seg", 174 | default=False, 175 | action="store_true", 176 | help="Prevent using face segmentation", 177 | ) 178 | 179 | parser.add_argument( 180 | "--no_sr", default=False, action="store_true", help="Prevent using super resolution" 181 | ) 182 | 183 | parser.add_argument( 184 | "--sr_model", 185 | type=str, 186 | default="gfpgan", 187 | help="Name of upscaler - gfpgan or RestoreFormer", 188 | required=False, 189 | ) 190 | 191 | parser.add_argument( 192 | "--fullres", 193 | default=3, 194 | type=int, 195 | help="used only to determine if full res is used so that no resizing needs to be done if so", 196 | ) 197 | 198 | parser.add_argument( 199 | "--debug_mask", 200 | type=str, 201 | default=False, 202 | help="Makes background grayscale to see the mask better", 203 | ) 204 | 205 | parser.add_argument( 206 | "--preview_settings", type=str, default=False, help="Processes only one frame" 207 | ) 208 | 209 | parser.add_argument( 210 | "--mouth_tracking", 211 | type=str, 212 | default=False, 213 | help="Tracks the mouth in every frame for the mask", 214 | ) 215 | 216 | parser.add_argument( 217 | "--mask_dilation", 218 | default=150, 219 | type=float, 220 | help="size of mask around mouth", 221 | required=False, 222 | ) 223 | 224 | parser.add_argument( 225 | "--mask_feathering", 226 | default=151, 227 | type=int, 228 | help="amount of feathering of mask around mouth", 229 | required=False, 230 | ) 231 | 232 | parser.add_argument( 233 | "--quality", 234 | type=str, 235 | help="Choose between Fast, Improved and Enhanced", 236 | default="Fast", 237 | ) 238 | 239 | with open(os.path.join("checkpoints", "predictor.pkl"), "rb") as f: 240 | predictor = pickle.load(f) 241 | 242 | with open(os.path.join("checkpoints", "mouth_detector.pkl"), "rb") as f: 243 | mouth_detector = pickle.load(f) 244 | 245 | # creating variables to prevent failing when a face isn't detected 246 | kernel = last_mask = x = y = w = h = None 247 | 248 | g_colab = g_colab() 249 | 250 | if not g_colab: 251 | # Load the config file 252 | config = configparser.ConfigParser() 253 | config.read('config.ini') 254 | 255 | # Get the value of the "preview_window" variable 256 | preview_window = config.get('OPTIONS', 'preview_window') 257 | 258 | all_mouth_landmarks = [] 259 | 260 | model = detector = detector_model = None 261 | 262 | def do_load(checkpoint_path): 263 | global model, detector, detector_model 264 | model = load_model(checkpoint_path) 265 | detector = RetinaFace( 266 | gpu_id=gpu_id, model_path="checkpoints/mobilenet.pth", network="mobilenet" 267 | ) 268 | detector_model = detector.model 269 | 270 | def face_rect(images): 271 | face_batch_size = 8 272 | num_batches = math.ceil(len(images) / face_batch_size) 273 | prev_ret = None 274 | for i in range(num_batches): 275 | batch = images[i * face_batch_size : (i + 1) * face_batch_size] 276 | all_faces = detector(batch) # return faces list of all images 277 | for faces in all_faces: 278 | if faces: 279 | box, landmarks, score = faces[0] 280 | prev_ret = tuple(map(int, box)) 281 | yield prev_ret 282 | 283 | def create_tracked_mask(img, original_img): 284 | global kernel, last_mask, x, y, w, h # Add last_mask to global variables 285 | 286 | # Convert color space from BGR to RGB if 
necessary 287 | cv2.cvtColor(img, cv2.COLOR_BGR2RGB, img) 288 | cv2.cvtColor(original_img, cv2.COLOR_BGR2RGB, original_img) 289 | 290 | # Detect face 291 | faces = mouth_detector(img) 292 | if len(faces) == 0: 293 | if last_mask is not None: 294 | last_mask = cv2.resize(last_mask, (img.shape[1], img.shape[0])) 295 | mask = last_mask # use the last successful mask 296 | else: 297 | cv2.cvtColor(img, cv2.COLOR_BGR2RGB, img) 298 | return img, None 299 | else: 300 | face = faces[0] 301 | shape = predictor(img, face) 302 | 303 | # Get points for mouth 304 | mouth_points = np.array( 305 | [[shape.part(i).x, shape.part(i).y] for i in range(48, 68)] 306 | ) 307 | 308 | # Calculate bounding box dimensions 309 | x, y, w, h = cv2.boundingRect(mouth_points) 310 | 311 | # Set kernel size as a fraction of bounding box size 312 | kernel_size = int(max(w, h) * args.mask_dilation) 313 | # if kernel_size % 2 == 0: # Ensure kernel size is odd 314 | # kernel_size += 1 315 | 316 | # Create kernel 317 | kernel = np.ones((kernel_size, kernel_size), np.uint8) 318 | 319 | # Create binary mask for mouth 320 | mask = np.zeros(img.shape[:2], dtype=np.uint8) 321 | cv2.fillConvexPoly(mask, mouth_points, 255) 322 | 323 | last_mask = mask # Update last_mask with the new mask 324 | 325 | # Dilate the mask 326 | dilated_mask = cv2.dilate(mask, kernel) 327 | 328 | # Calculate distance transform of dilated mask 329 | dist_transform = cv2.distanceTransform(dilated_mask, cv2.DIST_L2, 5) 330 | 331 | # Normalize distance transform 332 | cv2.normalize(dist_transform, dist_transform, 0, 255, cv2.NORM_MINMAX) 333 | 334 | # Convert normalized distance transform to binary mask and convert it to uint8 335 | _, masked_diff = cv2.threshold(dist_transform, 50, 255, cv2.THRESH_BINARY) 336 | masked_diff = masked_diff.astype(np.uint8) 337 | 338 | # make sure blur is an odd number 339 | blur = args.mask_feathering 340 | if blur % 2 == 0: 341 | blur += 1 342 | # Set blur size as a fraction of bounding box size 343 | blur = int(max(w, h) * blur) # 10% of bounding box size 344 | if blur % 2 == 0: # Ensure blur size is odd 345 | blur += 1 346 | masked_diff = cv2.GaussianBlur(masked_diff, (blur, blur), 0) 347 | 348 | # Convert numpy arrays to PIL Images 349 | input1 = Image.fromarray(img) 350 | input2 = Image.fromarray(original_img) 351 | 352 | # Convert mask to single channel where pixel values are from the alpha channel of the current mask 353 | mask = Image.fromarray(masked_diff) 354 | 355 | # Ensure images are the same size 356 | assert input1.size == input2.size == mask.size 357 | 358 | # Paste input1 onto input2 using the mask 359 | input2.paste(input1, (0, 0), mask) 360 | 361 | # Convert the final PIL Image back to a numpy array 362 | input2 = np.array(input2) 363 | 364 | # input2 = cv2.cvtColor(input2, cv2.COLOR_BGR2RGB) 365 | cv2.cvtColor(input2, cv2.COLOR_BGR2RGB, input2) 366 | 367 | return input2, mask 368 | 369 | 370 | def create_mask(img, original_img): 371 | global kernel, last_mask, x, y, w, h # Add last_mask to global variables 372 | 373 | # Convert color space from BGR to RGB if necessary 374 | cv2.cvtColor(img, cv2.COLOR_BGR2RGB, img) 375 | cv2.cvtColor(original_img, cv2.COLOR_BGR2RGB, original_img) 376 | 377 | if last_mask is not None: 378 | last_mask = np.array(last_mask) # Convert PIL Image to numpy array 379 | last_mask = cv2.resize(last_mask, (img.shape[1], img.shape[0])) 380 | mask = last_mask # use the last successful mask 381 | mask = Image.fromarray(mask) 382 | 383 | else: 384 | # Detect face 385 | faces = 
mouth_detector(img) 386 | if len(faces) == 0: 387 | cv2.cvtColor(img, cv2.COLOR_BGR2RGB, img) 388 | return img, None 389 | else: 390 | face = faces[0] 391 | shape = predictor(img, face) 392 | 393 | # Get points for mouth 394 | mouth_points = np.array( 395 | [[shape.part(i).x, shape.part(i).y] for i in range(48, 68)] 396 | ) 397 | 398 | # Calculate bounding box dimensions 399 | x, y, w, h = cv2.boundingRect(mouth_points) 400 | 401 | # Set kernel size as a fraction of bounding box size 402 | kernel_size = int(max(w, h) * args.mask_dilation) 403 | # if kernel_size % 2 == 0: # Ensure kernel size is odd 404 | # kernel_size += 1 405 | 406 | # Create kernel 407 | kernel = np.ones((kernel_size, kernel_size), np.uint8) 408 | 409 | # Create binary mask for mouth 410 | mask = np.zeros(img.shape[:2], dtype=np.uint8) 411 | cv2.fillConvexPoly(mask, mouth_points, 255) 412 | 413 | # Dilate the mask 414 | dilated_mask = cv2.dilate(mask, kernel) 415 | 416 | # Calculate distance transform of dilated mask 417 | dist_transform = cv2.distanceTransform(dilated_mask, cv2.DIST_L2, 5) 418 | 419 | # Normalize distance transform 420 | cv2.normalize(dist_transform, dist_transform, 0, 255, cv2.NORM_MINMAX) 421 | 422 | # Convert normalized distance transform to binary mask and convert it to uint8 423 | _, masked_diff = cv2.threshold(dist_transform, 50, 255, cv2.THRESH_BINARY) 424 | masked_diff = masked_diff.astype(np.uint8) 425 | 426 | if not args.mask_feathering == 0: 427 | blur = args.mask_feathering 428 | # Set blur size as a fraction of bounding box size 429 | blur = int(max(w, h) * blur) # 10% of bounding box size 430 | if blur % 2 == 0: # Ensure blur size is odd 431 | blur += 1 432 | masked_diff = cv2.GaussianBlur(masked_diff, (blur, blur), 0) 433 | 434 | # Convert mask to single channel where pixel values are from the alpha channel of the current mask 435 | mask = Image.fromarray(masked_diff) 436 | 437 | last_mask = mask # Update last_mask with the final mask after dilation and feathering 438 | 439 | # Convert numpy arrays to PIL Images 440 | input1 = Image.fromarray(img) 441 | input2 = Image.fromarray(original_img) 442 | 443 | # Resize mask to match image size 444 | # mask = Image.fromarray(mask) 445 | mask = mask.resize(input1.size) 446 | 447 | # Ensure images are the same size 448 | assert input1.size == input2.size == mask.size 449 | 450 | # Paste input1 onto input2 using the mask 451 | input2.paste(input1, (0, 0), mask) 452 | 453 | # Convert the final PIL Image back to a numpy array 454 | input2 = np.array(input2) 455 | 456 | # input2 = cv2.cvtColor(input2, cv2.COLOR_BGR2RGB) 457 | cv2.cvtColor(input2, cv2.COLOR_BGR2RGB, input2) 458 | 459 | return input2, mask 460 | 461 | 462 | def get_smoothened_boxes(boxes, T): 463 | for i in range(len(boxes)): 464 | if i + T > len(boxes): 465 | window = boxes[len(boxes) - T :] 466 | else: 467 | window = boxes[i : i + T] 468 | boxes[i] = np.mean(window, axis=0) 469 | return boxes 470 | 471 | def face_detect(images, results_file="last_detected_face.pkl"): 472 | # If results file exists, load it and return 473 | if os.path.exists(results_file): 474 | print("Using face detection data from last input") 475 | with open(results_file, "rb") as f: 476 | return pickle.load(f) 477 | 478 | results = [] 479 | pady1, pady2, padx1, padx2 = args.pads 480 | 481 | tqdm_partial = partial(tqdm, position=0, leave=True) 482 | for image, (rect) in tqdm_partial( 483 | zip(images, face_rect(images)), 484 | total=len(images), 485 | desc="detecting face in every frame", 486 | ncols=100, 487 | ): 
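        # rect is the (x1, y1, x2, y2) box from RetinaFace; expand it by the
        # user-supplied pads and clamp to the frame bounds before cropping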
488 | if rect is None: 489 | cv2.imwrite( 490 | "temp/faulty_frame.jpg", image 491 | ) # check this frame where the face was not detected. 492 | raise ValueError( 493 | "Face not detected! Ensure the video contains a face in all the frames." 494 | ) 495 | 496 | y1 = max(0, rect[1] - pady1) 497 | y2 = min(image.shape[0], rect[3] + pady2) 498 | x1 = max(0, rect[0] - padx1) 499 | x2 = min(image.shape[1], rect[2] + padx2) 500 | 501 | results.append([x1, y1, x2, y2]) 502 | 503 | 504 | boxes = np.array(results) 505 | if str(args.nosmooth) == "False": 506 | boxes = get_smoothened_boxes(boxes, T=5) 507 | results = [ 508 | [image[y1:y2, x1:x2], (y1, y2, x1, x2)] 509 | for image, (x1, y1, x2, y2) in zip(images, boxes) 510 | ] 511 | 512 | # Save results to file 513 | with open(results_file, "wb") as f: 514 | pickle.dump(results, f) 515 | 516 | return results 517 | 518 | 519 | def datagen(frames, mels): 520 | img_batch, mel_batch, frame_batch, coords_batch = [], [], [], [] 521 | print("\r" + " " * 100, end="\r") 522 | if args.box[0] == -1: 523 | if not args.static: 524 | face_det_results = face_detect(frames) # BGR2RGB for CNN face detection 525 | else: 526 | face_det_results = face_detect([frames[0]]) 527 | else: 528 | print("Using the specified bounding box instead of face detection...") 529 | y1, y2, x1, x2 = args.box 530 | face_det_results = [[f[y1:y2, x1:x2], (y1, y2, x1, x2)] for f in frames] 531 | 532 | for i, m in enumerate(mels): 533 | idx = 0 if args.static else i % len(frames) 534 | frame_to_save = frames[idx].copy() 535 | face, coords = face_det_results[idx].copy() 536 | 537 | face = cv2.resize(face, (args.img_size, args.img_size)) 538 | 539 | img_batch.append(face) 540 | mel_batch.append(m) 541 | frame_batch.append(frame_to_save) 542 | coords_batch.append(coords) 543 | 544 | if len(img_batch) >= args.wav2lip_batch_size: 545 | img_batch, mel_batch = np.asarray(img_batch), np.asarray(mel_batch) 546 | 547 | img_masked = img_batch.copy() 548 | img_masked[:, args.img_size // 2 :] = 0 549 | 550 | img_batch = np.concatenate((img_masked, img_batch), axis=3) / 255.0 551 | mel_batch = np.reshape( 552 | mel_batch, [len(mel_batch), mel_batch.shape[1], mel_batch.shape[2], 1] 553 | ) 554 | 555 | yield img_batch, mel_batch, frame_batch, coords_batch 556 | img_batch, mel_batch, frame_batch, coords_batch = [], [], [], [] 557 | 558 | if len(img_batch) > 0: 559 | img_batch, mel_batch = np.asarray(img_batch), np.asarray(mel_batch) 560 | 561 | img_masked = img_batch.copy() 562 | img_masked[:, args.img_size // 2 :] = 0 563 | 564 | img_batch = np.concatenate((img_masked, img_batch), axis=3) / 255.0 565 | mel_batch = np.reshape( 566 | mel_batch, [len(mel_batch), mel_batch.shape[1], mel_batch.shape[2], 1] 567 | ) 568 | 569 | yield img_batch, mel_batch, frame_batch, coords_batch 570 | 571 | 572 | mel_step_size = 16 573 | 574 | def _load(checkpoint_path): 575 | if device != "cpu": 576 | checkpoint = torch.load(checkpoint_path) 577 | else: 578 | checkpoint = torch.load( 579 | checkpoint_path, map_location=lambda storage, loc: storage 580 | ) 581 | return checkpoint 582 | 583 | 584 | def main(): 585 | args.img_size = 96 586 | frame_number = 11 587 | 588 | if os.path.isfile(args.face) and args.face.split(".")[1] in ["jpg", "png", "jpeg"]: 589 | args.static = True 590 | 591 | if not os.path.isfile(args.face): 592 | raise ValueError("--face argument must be a valid path to video/image file") 593 | 594 | elif args.face.split(".")[1] in ["jpg", "png", "jpeg"]: 595 | full_frames = [cv2.imread(args.face)] 596 | fps = 
args.fps 597 | 598 | else: 599 | if args.fullres != 1: 600 | print("Resizing video...") 601 | video_stream = cv2.VideoCapture(args.face) 602 | fps = video_stream.get(cv2.CAP_PROP_FPS) 603 | 604 | full_frames = [] 605 | while 1: 606 | still_reading, frame = video_stream.read() 607 | if not still_reading: 608 | video_stream.release() 609 | break 610 | 611 | if args.fullres != 1: 612 | aspect_ratio = frame.shape[1] / frame.shape[0] 613 | frame = cv2.resize( 614 | frame, (int(args.out_height * aspect_ratio), args.out_height) 615 | ) 616 | 617 | if args.rotate: 618 | frame = cv2.rotate(frame, cv2.cv2.ROTATE_90_CLOCKWISE) 619 | 620 | y1, y2, x1, x2 = args.crop 621 | if x2 == -1: 622 | x2 = frame.shape[1] 623 | if y2 == -1: 624 | y2 = frame.shape[0] 625 | 626 | frame = frame[y1:y2, x1:x2] 627 | 628 | full_frames.append(frame) 629 | 630 | if not args.audio.endswith(".wav"): 631 | print("Converting audio to .wav") 632 | subprocess.check_call( 633 | [ 634 | "ffmpeg", 635 | "-y", 636 | "-loglevel", 637 | "error", 638 | "-i", 639 | args.audio, 640 | "temp/temp.wav", 641 | ] 642 | ) 643 | args.audio = "temp/temp.wav" 644 | 645 | print("analysing audio...") 646 | wav = audio.load_wav(args.audio, 16000) 647 | mel = audio.melspectrogram(wav) 648 | 649 | if np.isnan(mel.reshape(-1)).sum() > 0: 650 | raise ValueError( 651 | "Mel contains nan! Using a TTS voice? Add a small epsilon noise to the wav file and try again" 652 | ) 653 | 654 | mel_chunks = [] 655 | 656 | mel_idx_multiplier = 80.0 / fps 657 | i = 0 658 | while 1: 659 | start_idx = int(i * mel_idx_multiplier) 660 | if start_idx + mel_step_size > len(mel[0]): 661 | mel_chunks.append(mel[:, len(mel[0]) - mel_step_size :]) 662 | break 663 | mel_chunks.append(mel[:, start_idx : start_idx + mel_step_size]) 664 | i += 1 665 | 666 | full_frames = full_frames[: len(mel_chunks)] 667 | if str(args.preview_settings) == "True": 668 | full_frames = [full_frames[0]] 669 | mel_chunks = [mel_chunks[0]] 670 | print(str(len(full_frames)) + " frames to process") 671 | batch_size = args.wav2lip_batch_size 672 | if str(args.preview_settings) == "True": 673 | gen = datagen(full_frames, mel_chunks) 674 | else: 675 | gen = datagen(full_frames.copy(), mel_chunks) 676 | 677 | for i, (img_batch, mel_batch, frames, coords) in enumerate( 678 | tqdm( 679 | gen, 680 | total=int(np.ceil(float(len(mel_chunks)) / batch_size)), 681 | desc="Processing Wav2Lip", 682 | ncols=100, 683 | ) 684 | ): 685 | if i == 0: 686 | if not args.quality == "Fast": 687 | print( 688 | f"mask size: {args.mask_dilation}, feathering: {args.mask_feathering}" 689 | ) 690 | if not args.quality == "Improved": 691 | print("Loading", args.sr_model) 692 | run_params = load_sr() 693 | 694 | print("Starting...") 695 | frame_h, frame_w = full_frames[0].shape[:-1] 696 | fourcc = cv2.VideoWriter_fourcc(*"mp4v") 697 | out = cv2.VideoWriter("temp/result.mp4", fourcc, fps, (frame_w, frame_h)) 698 | 699 | img_batch = torch.FloatTensor(np.transpose(img_batch, (0, 3, 1, 2))).to(device) 700 | mel_batch = torch.FloatTensor(np.transpose(mel_batch, (0, 3, 1, 2))).to(device) 701 | 702 | with torch.no_grad(): 703 | pred = model(mel_batch, img_batch) 704 | 705 | pred = pred.cpu().numpy().transpose(0, 2, 3, 1) * 255.0 706 | 707 | for p, f, c in zip(pred, frames, coords): 708 | # cv2.imwrite('temp/f.jpg', f) 709 | 710 | y1, y2, x1, x2 = c 711 | 712 | if ( 713 | str(args.debug_mask) == "True" 714 | ): # makes the background black & white so you can see the mask better 715 | f = cv2.cvtColor(f, cv2.COLOR_BGR2GRAY) 716 | f = 
cv2.cvtColor(f, cv2.COLOR_GRAY2BGR) 717 | 718 | p = cv2.resize(p.astype(np.uint8), (x2 - x1, y2 - y1)) 719 | cf = f[y1:y2, x1:x2] 720 | 721 | if args.quality == "Enhanced": 722 | p = upscale(p, run_params) 723 | 724 | if args.quality in ["Enhanced", "Improved"]: 725 | if str(args.mouth_tracking) == "True": 726 | p, last_mask = create_tracked_mask(p, cf) 727 | else: 728 | p, last_mask = create_mask(p, cf) 729 | 730 | f[y1:y2, x1:x2] = p 731 | 732 | if not g_colab: 733 | # Display the frame 734 | if preview_window == "Face": 735 | cv2.imshow("face preview - press Q to abort", p) 736 | elif preview_window == "Full": 737 | cv2.imshow("full preview - press Q to abort", f) 738 | elif preview_window == "Both": 739 | cv2.imshow("face preview - press Q to abort", p) 740 | cv2.imshow("full preview - press Q to abort", f) 741 | 742 | key = cv2.waitKey(1) & 0xFF 743 | if key == ord('q'): 744 | exit() # Exit the loop when 'Q' is pressed 745 | 746 | if str(args.preview_settings) == "True": 747 | cv2.imwrite("temp/preview.jpg", f) 748 | if not g_colab: 749 | cv2.imshow("preview - press Q to close", f) 750 | if cv2.waitKey(-1) & 0xFF == ord('q'): 751 | exit() # Exit the loop when 'Q' is pressed 752 | 753 | else: 754 | out.write(f) 755 | 756 | # Close the window(s) when done 757 | cv2.destroyAllWindows() 758 | 759 | out.release() 760 | 761 | if str(args.preview_settings) == "False": 762 | print("converting to final video") 763 | 764 | subprocess.check_call([ 765 | "ffmpeg", 766 | "-y", 767 | "-loglevel", 768 | "error", 769 | "-i", 770 | "temp/result.mp4", 771 | "-i", 772 | args.audio, 773 | "-c:v", 774 | "libx264", 775 | args.outfile 776 | ]) 777 | 778 | if __name__ == "__main__": 779 | args = parser.parse_args() 780 | do_load(args.checkpoint_path) 781 | main() 782 | -------------------------------------------------------------------------------- /install.py: -------------------------------------------------------------------------------- 1 | version = 'v8.3' 2 | 3 | import os 4 | import re 5 | import argparse 6 | import shutil 7 | import subprocess 8 | from IPython.display import clear_output 9 | 10 | from easy_functions import (format_time, 11 | load_file_from_url, 12 | load_model, 13 | load_predictor) 14 | # Get the location of the basicsr package 15 | import os 16 | import shutil 17 | import subprocess 18 | import warnings 19 | 20 | warnings.filterwarnings("ignore", category=UserWarning, module="torchvision.transforms.functional_tensor") 21 | 22 | # Get the location of the basicsr package 23 | def get_basicsr_location(): 24 | result = subprocess.run(['pip', 'show', 'basicsr'], capture_output=True, text=True) 25 | for line in result.stdout.split('\n'): 26 | if 'Location: ' in line: 27 | return line.split('Location: ')[1] 28 | return None 29 | 30 | # Move and replace a file to the basicsr location 31 | def move_and_replace_file_to_basicsr(file_name): 32 | basicsr_location = get_basicsr_location() 33 | if basicsr_location: 34 | destination = os.path.join(basicsr_location, file_name) 35 | # Move and replace the file 36 | shutil.copyfile(file_name, destination) 37 | print(f'File replaced at {destination}') 38 | else: 39 | print('Could not find basicsr location.') 40 | 41 | # Example usage 42 | file_to_replace = 'degradations.py' # Replace with your file name 43 | move_and_replace_file_to_basicsr(file_to_replace) 44 | 45 | 46 | from enhance import load_sr 47 | 48 | working_directory = os.getcwd() 49 | 50 | # download and initialize both wav2lip models 51 | print("downloading wav2lip essentials") 52 | 
load_file_from_url( 53 | url="https://github.com/anothermartz/Easy-Wav2Lip/releases/download/Prerequesits/Wav2Lip_GAN.pth", 54 | model_dir="checkpoints", 55 | progress=True, 56 | file_name="Wav2Lip_GAN.pth", 57 | ) 58 | model = load_model(os.path.join(working_directory, "checkpoints", "Wav2Lip_GAN.pth")) 59 | print("wav2lip_gan loaded") 60 | load_file_from_url( 61 | url="https://github.com/anothermartz/Easy-Wav2Lip/releases/download/Prerequesits/Wav2Lip.pth", 62 | model_dir="checkpoints", 63 | progress=True, 64 | file_name="Wav2Lip.pth", 65 | ) 66 | model = load_model(os.path.join(working_directory, "checkpoints", "Wav2Lip.pth")) 67 | print("wav2lip loaded") 68 | 69 | # download gfpgan files 70 | print("downloading gfpgan essentials") 71 | load_file_from_url( 72 | url="https://github.com/anothermartz/Easy-Wav2Lip/releases/download/Prerequesits/GFPGANv1.4.pth", 73 | model_dir="checkpoints", 74 | progress=True, 75 | file_name="GFPGANv1.4.pth", 76 | ) 77 | load_sr() 78 | 79 | # load face detectors 80 | print("initializing face detectors") 81 | load_file_from_url( 82 | url="https://github.com/anothermartz/Easy-Wav2Lip/releases/download/Prerequesits/shape_predictor_68_face_landmarks_GTX.dat", 83 | model_dir="checkpoints", 84 | progress=True, 85 | file_name="shape_predictor_68_face_landmarks_GTX.dat", 86 | ) 87 | 88 | load_predictor() 89 | 90 | # write a file to signify setup is done 91 | with open("installed.txt", "w") as f: 92 | f.write(version) 93 | print("Installation complete!") 94 | print( 95 | "If you just updated from v8 - make sure to download the updated Easy-Wav2Lip.bat too!" 96 | ) 97 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from .wav2lip import Wav2Lip, Wav2Lip_disc_qual 2 | from .syncnet import SyncNet_color -------------------------------------------------------------------------------- /models/conv.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import functional as F 4 | 5 | class Conv2d(nn.Module): 6 | def __init__(self, cin, cout, kernel_size, stride, padding, residual=False, *args, **kwargs): 7 | super().__init__(*args, **kwargs) 8 | self.conv_block = nn.Sequential( 9 | nn.Conv2d(cin, cout, kernel_size, stride, padding), 10 | nn.BatchNorm2d(cout) 11 | ) 12 | self.act = nn.ReLU() 13 | self.residual = residual 14 | 15 | def forward(self, x): 16 | out = self.conv_block(x) 17 | if self.residual: 18 | out += x 19 | return self.act(out) 20 | 21 | class nonorm_Conv2d(nn.Module): 22 | def __init__(self, cin, cout, kernel_size, stride, padding, residual=False, *args, **kwargs): 23 | super().__init__(*args, **kwargs) 24 | self.conv_block = nn.Sequential( 25 | nn.Conv2d(cin, cout, kernel_size, stride, padding), 26 | ) 27 | self.act = nn.LeakyReLU(0.01, inplace=True) 28 | 29 | def forward(self, x): 30 | out = self.conv_block(x) 31 | return self.act(out) 32 | 33 | class Conv2dTranspose(nn.Module): 34 | def __init__(self, cin, cout, kernel_size, stride, padding, output_padding=0, *args, **kwargs): 35 | super().__init__(*args, **kwargs) 36 | self.conv_block = nn.Sequential( 37 | nn.ConvTranspose2d(cin, cout, kernel_size, stride, padding, output_padding), 38 | nn.BatchNorm2d(cout) 39 | ) 40 | self.act = nn.ReLU() 41 | 42 | def forward(self, x): 43 | out = self.conv_block(x) 44 | return self.act(out) 45 | 
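The Conv2d and Conv2dTranspose wrappers above bundle convolution, batch normalisation and ReLU, and a block built with residual=True adds its input back in before the activation, so residual blocks must keep the channel count and spatial size unchanged. A minimal sketch of how they compose (the shapes here are illustrative and assume the snippet is run from the repository root):

import torch
from models.conv import Conv2d, Conv2dTranspose

# the stride-2 block halves the spatial size; the residual blocks preserve it
encoder = torch.nn.Sequential(
    Conv2d(32, 64, kernel_size=3, stride=2, padding=1),
    Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
    Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
)
# a transposed conv with output_padding=1 restores the original resolution
decoder = Conv2dTranspose(64, 32, kernel_size=3, stride=2, padding=1, output_padding=1)

x = torch.randn(1, 32, 96, 96)
y = encoder(x)
print(y.shape)           # torch.Size([1, 64, 48, 48])
print(decoder(y).shape)  # torch.Size([1, 32, 96, 96])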
-------------------------------------------------------------------------------- /models/syncnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import functional as F 4 | 5 | from .conv import Conv2d 6 | 7 | class SyncNet_color(nn.Module): 8 | def __init__(self): 9 | super(SyncNet_color, self).__init__() 10 | 11 | self.face_encoder = nn.Sequential( 12 | Conv2d(15, 32, kernel_size=(7, 7), stride=1, padding=3), 13 | 14 | Conv2d(32, 64, kernel_size=5, stride=(1, 2), padding=1), 15 | Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True), 16 | Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True), 17 | 18 | Conv2d(64, 128, kernel_size=3, stride=2, padding=1), 19 | Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True), 20 | Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True), 21 | Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True), 22 | 23 | Conv2d(128, 256, kernel_size=3, stride=2, padding=1), 24 | Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True), 25 | Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True), 26 | 27 | Conv2d(256, 512, kernel_size=3, stride=2, padding=1), 28 | Conv2d(512, 512, kernel_size=3, stride=1, padding=1, residual=True), 29 | Conv2d(512, 512, kernel_size=3, stride=1, padding=1, residual=True), 30 | 31 | Conv2d(512, 512, kernel_size=3, stride=2, padding=1), 32 | Conv2d(512, 512, kernel_size=3, stride=1, padding=0), 33 | Conv2d(512, 512, kernel_size=1, stride=1, padding=0),) 34 | 35 | self.audio_encoder = nn.Sequential( 36 | Conv2d(1, 32, kernel_size=3, stride=1, padding=1), 37 | Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True), 38 | Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True), 39 | 40 | Conv2d(32, 64, kernel_size=3, stride=(3, 1), padding=1), 41 | Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True), 42 | Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True), 43 | 44 | Conv2d(64, 128, kernel_size=3, stride=3, padding=1), 45 | Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True), 46 | Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True), 47 | 48 | Conv2d(128, 256, kernel_size=3, stride=(3, 2), padding=1), 49 | Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True), 50 | Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True), 51 | 52 | Conv2d(256, 512, kernel_size=3, stride=1, padding=0), 53 | Conv2d(512, 512, kernel_size=1, stride=1, padding=0),) 54 | 55 | def forward(self, audio_sequences, face_sequences): # audio_sequences := (B, dim, T) 56 | face_embedding = self.face_encoder(face_sequences) 57 | audio_embedding = self.audio_encoder(audio_sequences) 58 | 59 | audio_embedding = audio_embedding.view(audio_embedding.size(0), -1) 60 | face_embedding = face_embedding.view(face_embedding.size(0), -1) 61 | 62 | audio_embedding = F.normalize(audio_embedding, p=2, dim=1) 63 | face_embedding = F.normalize(face_embedding, p=2, dim=1) 64 | 65 | 66 | return audio_embedding, face_embedding 67 | -------------------------------------------------------------------------------- /models/wav2lip.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import functional as F 4 | import math 5 | 6 | from .conv import Conv2dTranspose, Conv2d, nonorm_Conv2d 7 | 8 | class Wav2Lip(nn.Module): 9 | 
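    # Encoder-decoder with skip connections: face_encoder_blocks downsample the
    # 6-channel input (masked lower half of the frame stacked with the reference
    # frame, 96x96) to a 1x1 feature map, audio_encoder turns a 1x80x16 mel chunk
    # into a 512-d embedding, and face_decoder_blocks upsample that embedding back
    # to 96x96, concatenating the matching encoder feature after each stage before
    # output_block produces the final RGB prediction through a sigmoid.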
def __init__(self): 10 | super(Wav2Lip, self).__init__() 11 | 12 | self.face_encoder_blocks = nn.ModuleList([ 13 | nn.Sequential(Conv2d(6, 16, kernel_size=7, stride=1, padding=3)), # 96,96 14 | 15 | nn.Sequential(Conv2d(16, 32, kernel_size=3, stride=2, padding=1), # 48,48 16 | Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True), 17 | Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True)), 18 | 19 | nn.Sequential(Conv2d(32, 64, kernel_size=3, stride=2, padding=1), # 24,24 20 | Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True), 21 | Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True), 22 | Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True)), 23 | 24 | nn.Sequential(Conv2d(64, 128, kernel_size=3, stride=2, padding=1), # 12,12 25 | Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True), 26 | Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True)), 27 | 28 | nn.Sequential(Conv2d(128, 256, kernel_size=3, stride=2, padding=1), # 6,6 29 | Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True), 30 | Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True)), 31 | 32 | nn.Sequential(Conv2d(256, 512, kernel_size=3, stride=2, padding=1), # 3,3 33 | Conv2d(512, 512, kernel_size=3, stride=1, padding=1, residual=True),), 34 | 35 | nn.Sequential(Conv2d(512, 512, kernel_size=3, stride=1, padding=0), # 1, 1 36 | Conv2d(512, 512, kernel_size=1, stride=1, padding=0)),]) 37 | 38 | self.audio_encoder = nn.Sequential( 39 | Conv2d(1, 32, kernel_size=3, stride=1, padding=1), 40 | Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True), 41 | Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True), 42 | 43 | Conv2d(32, 64, kernel_size=3, stride=(3, 1), padding=1), 44 | Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True), 45 | Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True), 46 | 47 | Conv2d(64, 128, kernel_size=3, stride=3, padding=1), 48 | Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True), 49 | Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True), 50 | 51 | Conv2d(128, 256, kernel_size=3, stride=(3, 2), padding=1), 52 | Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True), 53 | 54 | Conv2d(256, 512, kernel_size=3, stride=1, padding=0), 55 | Conv2d(512, 512, kernel_size=1, stride=1, padding=0),) 56 | 57 | self.face_decoder_blocks = nn.ModuleList([ 58 | nn.Sequential(Conv2d(512, 512, kernel_size=1, stride=1, padding=0),), 59 | 60 | nn.Sequential(Conv2dTranspose(1024, 512, kernel_size=3, stride=1, padding=0), # 3,3 61 | Conv2d(512, 512, kernel_size=3, stride=1, padding=1, residual=True),), 62 | 63 | nn.Sequential(Conv2dTranspose(1024, 512, kernel_size=3, stride=2, padding=1, output_padding=1), 64 | Conv2d(512, 512, kernel_size=3, stride=1, padding=1, residual=True), 65 | Conv2d(512, 512, kernel_size=3, stride=1, padding=1, residual=True),), # 6, 6 66 | 67 | nn.Sequential(Conv2dTranspose(768, 384, kernel_size=3, stride=2, padding=1, output_padding=1), 68 | Conv2d(384, 384, kernel_size=3, stride=1, padding=1, residual=True), 69 | Conv2d(384, 384, kernel_size=3, stride=1, padding=1, residual=True),), # 12, 12 70 | 71 | nn.Sequential(Conv2dTranspose(512, 256, kernel_size=3, stride=2, padding=1, output_padding=1), 72 | Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True), 73 | Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True),), # 24, 24 74 | 75 | 
nn.Sequential(Conv2dTranspose(320, 128, kernel_size=3, stride=2, padding=1, output_padding=1), 76 | Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True), 77 | Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),), # 48, 48 78 | 79 | nn.Sequential(Conv2dTranspose(160, 64, kernel_size=3, stride=2, padding=1, output_padding=1), 80 | Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True), 81 | Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),),]) # 96,96 82 | 83 | self.output_block = nn.Sequential(Conv2d(80, 32, kernel_size=3, stride=1, padding=1), 84 | nn.Conv2d(32, 3, kernel_size=1, stride=1, padding=0), 85 | nn.Sigmoid()) 86 | 87 | def forward(self, audio_sequences, face_sequences): 88 | # audio_sequences = (B, T, 1, 80, 16) 89 | B = audio_sequences.size(0) 90 | 91 | input_dim_size = len(face_sequences.size()) 92 | if input_dim_size > 4: 93 | audio_sequences = torch.cat([audio_sequences[:, i] for i in range(audio_sequences.size(1))], dim=0) 94 | face_sequences = torch.cat([face_sequences[:, :, i] for i in range(face_sequences.size(2))], dim=0) 95 | 96 | audio_embedding = self.audio_encoder(audio_sequences) # B, 512, 1, 1 97 | 98 | feats = [] 99 | x = face_sequences 100 | for f in self.face_encoder_blocks: 101 | x = f(x) 102 | feats.append(x) 103 | 104 | x = audio_embedding 105 | for f in self.face_decoder_blocks: 106 | x = f(x) 107 | try: 108 | x = torch.cat((x, feats[-1]), dim=1) 109 | except Exception as e: 110 | print(x.size()) 111 | print(feats[-1].size()) 112 | raise e 113 | 114 | feats.pop() 115 | 116 | x = self.output_block(x) 117 | 118 | if input_dim_size > 4: 119 | x = torch.split(x, B, dim=0) # [(B, C, H, W)] 120 | outputs = torch.stack(x, dim=2) # (B, C, T, H, W) 121 | 122 | else: 123 | outputs = x 124 | 125 | return outputs 126 | 127 | class Wav2Lip_disc_qual(nn.Module): 128 | def __init__(self): 129 | super(Wav2Lip_disc_qual, self).__init__() 130 | 131 | self.face_encoder_blocks = nn.ModuleList([ 132 | nn.Sequential(nonorm_Conv2d(3, 32, kernel_size=7, stride=1, padding=3)), # 48,96 133 | 134 | nn.Sequential(nonorm_Conv2d(32, 64, kernel_size=5, stride=(1, 2), padding=2), # 48,48 135 | nonorm_Conv2d(64, 64, kernel_size=5, stride=1, padding=2)), 136 | 137 | nn.Sequential(nonorm_Conv2d(64, 128, kernel_size=5, stride=2, padding=2), # 24,24 138 | nonorm_Conv2d(128, 128, kernel_size=5, stride=1, padding=2)), 139 | 140 | nn.Sequential(nonorm_Conv2d(128, 256, kernel_size=5, stride=2, padding=2), # 12,12 141 | nonorm_Conv2d(256, 256, kernel_size=5, stride=1, padding=2)), 142 | 143 | nn.Sequential(nonorm_Conv2d(256, 512, kernel_size=3, stride=2, padding=1), # 6,6 144 | nonorm_Conv2d(512, 512, kernel_size=3, stride=1, padding=1)), 145 | 146 | nn.Sequential(nonorm_Conv2d(512, 512, kernel_size=3, stride=2, padding=1), # 3,3 147 | nonorm_Conv2d(512, 512, kernel_size=3, stride=1, padding=1),), 148 | 149 | nn.Sequential(nonorm_Conv2d(512, 512, kernel_size=3, stride=1, padding=0), # 1, 1 150 | nonorm_Conv2d(512, 512, kernel_size=1, stride=1, padding=0)),]) 151 | 152 | self.binary_pred = nn.Sequential(nn.Conv2d(512, 1, kernel_size=1, stride=1, padding=0), nn.Sigmoid()) 153 | self.label_noise = .0 154 | 155 | def get_lower_half(self, face_sequences): 156 | return face_sequences[:, :, face_sequences.size(2)//2:] 157 | 158 | def to_2d(self, face_sequences): 159 | B = face_sequences.size(0) 160 | face_sequences = torch.cat([face_sequences[:, :, i] for i in range(face_sequences.size(2))], dim=0) 161 | return face_sequences 162 | 163 | def 
perceptual_forward(self, false_face_sequences): 164 | false_face_sequences = self.to_2d(false_face_sequences) 165 | false_face_sequences = self.get_lower_half(false_face_sequences) 166 | 167 | false_feats = false_face_sequences 168 | for f in self.face_encoder_blocks: 169 | false_feats = f(false_feats) 170 | 171 | false_pred_loss = F.binary_cross_entropy(self.binary_pred(false_feats).view(len(false_feats), -1), 172 | torch.ones((len(false_feats), 1)).cuda()) 173 | 174 | return false_pred_loss 175 | 176 | def forward(self, face_sequences): 177 | face_sequences = self.to_2d(face_sequences) 178 | face_sequences = self.get_lower_half(face_sequences) 179 | 180 | x = face_sequences 181 | for f in self.face_encoder_blocks: 182 | x = f(x) 183 | 184 | return self.binary_pred(x).view(len(x), -1) 185 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | basicsr==1.4.2 2 | batch-face==1.4.0 3 | dlib==19.24.2 4 | facexlib==0.3.0 5 | gdown==4.7.1 6 | gfpgan==1.3.8 7 | imageio-ffmpeg==0.4.9 8 | importlib-metadata==6.8.0 9 | ipython==8.16.1 10 | librosa==0.10.1 11 | moviepy==1.0.3 12 | numpy==1.26.1 13 | opencv-python==4.8.1.78 14 | scipy==1.11.3 15 | --extra-index-url https://download.pytorch.org/whl/cu121 16 | torch==2.1.0 17 | torchaudio==2.1.0 18 | torchvision==0.16.0 19 | -------------------------------------------------------------------------------- /run.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import re 4 | import argparse 5 | from easy_functions import (format_time, 6 | get_input_length, 7 | get_video_details, 8 | show_video, 9 | g_colab) 10 | import contextlib 11 | import shutil 12 | import subprocess 13 | import time 14 | from IPython.display import Audio, Image, clear_output, display 15 | from moviepy.video.io.ffmpeg_tools import ffmpeg_extract_subclip 16 | import configparser 17 | 18 | parser = argparse.ArgumentParser(description='Easy-Wav2Lip main run file') 19 | 20 | parser.add_argument('-video_file', type=str, 21 | help='Input video file path', required=False, default=False) 22 | parser.add_argument('-vocal_file', type=str, 23 | help='Input audio file path', required=False, default=False) 24 | parser.add_argument('-output_file', type=str, 25 | help='Output video file path', required=False, default=False) 26 | args = parser.parse_args() 27 | 28 | # retrieve variables from config.ini 29 | config = configparser.ConfigParser() 30 | 31 | config.read('config.ini') 32 | if args.video_file: 33 | video_file = args.video_file 34 | else: 35 | video_file = config['OPTIONS']['video_file'] 36 | 37 | if args.vocal_file: 38 | vocal_file = args.vocal_file 39 | else: 40 | vocal_file = config['OPTIONS']['vocal_file'] 41 | quality = config['OPTIONS']['quality'] 42 | output_height = config['OPTIONS']['output_height'] 43 | wav2lip_version = config['OPTIONS']['wav2lip_version'] 44 | use_previous_tracking_data = config['OPTIONS']['use_previous_tracking_data'] 45 | nosmooth = config.getboolean('OPTIONS', 'nosmooth') 46 | U = config.getint('PADDING', 'U') 47 | D = config.getint('PADDING', 'D') 48 | L = config.getint('PADDING', 'L') 49 | R = config.getint('PADDING', 'R') 50 | size = config.getfloat('MASK', 'size') 51 | feathering = config.getint('MASK', 'feathering') 52 | mouth_tracking = config.getboolean('MASK', 'mouth_tracking') 53 | debug_mask = config.getboolean('MASK', 'debug_mask') 54 | batch_process = 
config.getboolean('OTHER', 'batch_process') 55 | output_suffix = config['OTHER']['output_suffix'] 56 | include_settings_in_suffix = config.getboolean('OTHER', 'include_settings_in_suffix') 57 | 58 | if g_colab(): 59 | preview_input = config.getboolean("OTHER", "preview_input") 60 | else: 61 | preview_input = False 62 | preview_settings = config.getboolean("OTHER", "preview_settings") 63 | frame_to_preview = config.getint("OTHER", "frame_to_preview") 64 | 65 | working_directory = os.getcwd() 66 | 67 | 68 | start_time = time.time() 69 | 70 | video_file = video_file.strip('"') 71 | vocal_file = vocal_file.strip('"') 72 | 73 | # check video_file exists 74 | if video_file == "": 75 | sys.exit(f"video_file cannot be blank") 76 | 77 | if os.path.isdir(video_file): 78 | sys.exit(f"{video_file} is a directory, you need to point to a file") 79 | 80 | if not os.path.exists(video_file): 81 | sys.exit(f"Could not find file: {video_file}") 82 | 83 | if wav2lip_version == "Wav2Lip_GAN": 84 | checkpoint_path = os.path.join(working_directory, "checkpoints", "Wav2Lip_GAN.pth") 85 | else: 86 | checkpoint_path = os.path.join(working_directory, "checkpoints", "Wav2Lip.pth") 87 | 88 | if feathering == 3: 89 | feathering = 5 90 | if feathering == 2: 91 | feathering = 3 92 | 93 | resolution_scale = 1 94 | res_custom = False 95 | if output_height == "half resolution": 96 | resolution_scale = 2 97 | elif output_height == "full resolution": 98 | resolution_scale = 1 99 | else: 100 | res_custom = True 101 | resolution_scale = 3 102 | 103 | in_width, in_height, in_fps, in_length = get_video_details(video_file) 104 | out_height = round(in_height / resolution_scale) 105 | 106 | if res_custom: 107 | out_height = int(output_height) 108 | fps_for_static_image = 30 109 | 110 | 111 | if output_suffix == "" and not include_settings_in_suffix: 112 | sys.exit( 113 | "Current suffix settings will overwrite your input video! 
Please add a suffix or tick include_settings_in_suffix" 114 | ) 115 | 116 | frame_to_preview = max(frame_to_preview - 1, 0) 117 | 118 | if include_settings_in_suffix: 119 | if wav2lip_version == "Wav2Lip_GAN": 120 | output_suffix = f"{output_suffix}_GAN" 121 | output_suffix = f"{output_suffix}_{quality}" 122 | if output_height != "full resolution": 123 | output_suffix = f"{output_suffix}_{out_height}" 124 | if nosmooth: 125 | output_suffix = f"{output_suffix}_nosmooth1" 126 | else: 127 | output_suffix = f"{output_suffix}_nosmooth0" 128 | if U != 0 or D != 0 or L != 0 or R != 0: 129 | output_suffix = f"{output_suffix}_pads-" 130 | if U != 0: 131 | output_suffix = f"{output_suffix}U{U}" 132 | if D != 0: 133 | output_suffix = f"{output_suffix}D{D}" 134 | if L != 0: 135 | output_suffix = f"{output_suffix}L{L}" 136 | if R != 0: 137 | output_suffix = f"{output_suffix}R{R}" 138 | if quality != "fast": 139 | output_suffix = f"{output_suffix}_mask-S{size}F{feathering}" 140 | if mouth_tracking: 141 | output_suffix = f"{output_suffix}_mt" 142 | if debug_mask: 143 | output_suffix = f"{output_suffix}_debug" 144 | if preview_settings: 145 | output_suffix = f"{output_suffix}_preview" 146 | 147 | 148 | rescaleFactor = str(round(1 // resolution_scale)) 149 | pad_up = str(round(U * resolution_scale)) 150 | pad_down = str(round(D * resolution_scale)) 151 | pad_left = str(round(L * resolution_scale)) 152 | pad_right = str(round(R * resolution_scale)) 153 | ################################################################################ 154 | 155 | 156 | ######################### reconstruct input paths ############################## 157 | # Extract each part of the path 158 | folder, filename_with_extension = os.path.split(video_file) 159 | filename, file_type = os.path.splitext(filename_with_extension) 160 | 161 | # Extract filenumber if it exists 162 | filenumber_match = re.search(r"\d+$", filename) 163 | if filenumber_match: # if there is a filenumber - extract it 164 | filenumber = str(filenumber_match.group()) 165 | filenamenonumber = re.sub(r"\d+$", "", filename) 166 | else: # if there is no filenumber - make it blank 167 | filenumber = "" 168 | filenamenonumber = filename 169 | 170 | # if vocal_file is blank - use the video as audio 171 | if vocal_file == "": 172 | vocal_file = video_file 173 | # if not, check that the vocal_file file exists 174 | else: 175 | if not os.path.exists(vocal_file): 176 | sys.exit(f"Could not find file: {vocal_file}") 177 | if os.path.isdir(vocal_file): 178 | sys.exit(f"{vocal_file} is a directory, you need to point to a file") 179 | 180 | # Extract each part of the path 181 | audio_folder, audio_filename_with_extension = os.path.split(vocal_file) 182 | audio_filename, audio_file_type = os.path.splitext(audio_filename_with_extension) 183 | 184 | # Extract filenumber if it exists 185 | audio_filenumber_match = re.search(r"\d+$", audio_filename) 186 | if audio_filenumber_match: # if there is a filenumber - extract it 187 | audio_filenumber = str(audio_filenumber_match.group()) 188 | audio_filenamenonumber = re.sub(r"\d+$", "", audio_filename) 189 | else: # if there is no filenumber - make it blank 190 | audio_filenumber = "" 191 | audio_filenamenonumber = audio_filename 192 | ################################################################################ 193 | 194 | # set process_failed to False so that it may be set to True if one or more processings fail 195 | process_failed = False 196 | 197 | 198 | temp_output = os.path.join(working_directory, "temp", "output.mp4") 199 | 
temp_folder = os.path.join(working_directory, "temp") 200 | 201 | last_input_video = None 202 | last_input_audio = None 203 | 204 | # --------------------------Batch processing loop-------------------------------! 205 | while True: 206 | 207 | # construct input_video 208 | input_video = os.path.join(folder, filenamenonumber + str(filenumber) + file_type) 209 | input_videofile = os.path.basename(input_video) 210 | 211 | # construct input_audio 212 | input_audio = os.path.join( 213 | audio_folder, audio_filenamenonumber + str(audio_filenumber) + audio_file_type 214 | ) 215 | input_audiofile = os.path.basename(input_audio) 216 | 217 | # see if filenames are different: 218 | if filenamenonumber + str(filenumber) != audio_filenamenonumber + str( 219 | audio_filenumber 220 | ): 221 | output_filename = ( 222 | filenamenonumber 223 | + str(filenumber) 224 | + "_" 225 | + audio_filenamenonumber 226 | + str(audio_filenumber) 227 | ) 228 | else: 229 | output_filename = filenamenonumber + str(filenumber) 230 | 231 | # construct output_video 232 | output_video = os.path.join(folder, output_filename + output_suffix + ".mp4") 233 | output_video = os.path.normpath(output_video) 234 | output_videofile = os.path.basename(output_video) 235 | 236 | # remove last outputs 237 | if os.path.exists("temp"): 238 | shutil.rmtree("temp") 239 | os.makedirs("temp", exist_ok=True) 240 | 241 | # preview inputs (if enabled) 242 | if preview_input: 243 | print("input video:") 244 | show_video(input_video) 245 | if vocal_file != "": 246 | print("input audio:") 247 | display(Audio(input_audio)) 248 | else: 249 | print("using", input_videofile, "for audio") 250 | print("You may want to check now that they're the correct files!") 251 | 252 | last_input_video = input_video 253 | last_input_audio = input_audio 254 | shutil.copy(input_video, temp_folder) 255 | shutil.copy(input_audio, temp_folder) 256 | 257 | # rename temp file to include padding or else changing padding does nothing 258 | temp_input_video = os.path.join(temp_folder, input_videofile) 259 | renamed_temp_input_video = os.path.join( 260 | temp_folder, str(U) + str(D) + str(L) + str(R) + input_videofile 261 | ) 262 | shutil.copy(temp_input_video, renamed_temp_input_video) 263 | temp_input_video = renamed_temp_input_video 264 | temp_input_videofile = os.path.basename(renamed_temp_input_video) 265 | temp_input_audio = os.path.join(temp_folder, input_audiofile) 266 | 267 | # trim video if it's longer than the audio 268 | video_length = get_input_length(temp_input_video) 269 | audio_length = get_input_length(temp_input_audio) 270 | 271 | if preview_settings: 272 | batch_process = False 273 | 274 | preview_length_seconds = 1 275 | converted_preview_frame = frame_to_preview / in_fps 276 | preview_start_time = min( 277 | converted_preview_frame, video_length - preview_length_seconds 278 | ) 279 | 280 | preview_video_path = os.path.join( 281 | temp_folder, 282 | "preview_" 283 | + str(preview_start_time) 284 | + "_" 285 | + str(U) 286 | + str(D) 287 | + str(L) 288 | + str(R) 289 | + input_videofile, 290 | ) 291 | preview_audio_path = os.path.join(temp_folder, "preview_" + input_audiofile) 292 | 293 | subprocess.call( 294 | [ 295 | "ffmpeg", 296 | "-loglevel", 297 | "error", 298 | "-i", 299 | temp_input_video, 300 | "-ss", 301 | str(preview_start_time), 302 | "-to", 303 | str(preview_start_time + preview_length_seconds), 304 | "-c", 305 | "copy", 306 | preview_video_path, 307 | ] 308 | ) 309 | subprocess.call( 310 | [ 311 | "ffmpeg", 312 | "-loglevel", 313 | "error", 314 | 
"-i", 315 | temp_input_audio, 316 | "-ss", 317 | str(preview_start_time), 318 | "-to", 319 | str(preview_start_time + 1), 320 | "-c", 321 | "copy", 322 | preview_audio_path, 323 | ] 324 | ) 325 | temp_input_video = preview_video_path 326 | temp_input_audio = preview_audio_path 327 | 328 | if video_length > audio_length: 329 | trimmed_video_path = os.path.join( 330 | temp_folder, "trimmed_" + temp_input_videofile 331 | ) 332 | with open(os.devnull, "w") as devnull: 333 | with contextlib.redirect_stdout(devnull), contextlib.redirect_stderr( 334 | devnull 335 | ): 336 | ffmpeg_extract_subclip( 337 | temp_input_video, 0, audio_length, targetname=trimmed_video_path 338 | ) 339 | temp_input_video = trimmed_video_path 340 | # check if face detection has already happened on this clip 341 | last_detected_face = os.path.join(working_directory, "last_detected_face.pkl") 342 | if os.path.isfile("last_file.txt"): 343 | with open("last_file.txt", "r") as file: 344 | last_file = file.readline() 345 | if last_file != temp_input_video or use_previous_tracking_data == "False": 346 | if os.path.isfile(last_detected_face): 347 | os.remove(last_detected_face) 348 | 349 | # ----------------------------Process the inputs!-----------------------------! 350 | print( 351 | f"Processing{' preview of' if preview_settings else ''} " 352 | f"{input_videofile} using {input_audiofile} for audio" 353 | ) 354 | 355 | # execute Wav2Lip & upscaler 356 | 357 | cmd = [ 358 | sys.executable, 359 | "inference.py", 360 | "--face", 361 | temp_input_video, 362 | "--audio", 363 | temp_input_audio, 364 | "--outfile", 365 | temp_output, 366 | "--pads", 367 | str(pad_up), 368 | str(pad_down), 369 | str(pad_left), 370 | str(pad_right), 371 | "--checkpoint_path", 372 | checkpoint_path, 373 | "--out_height", 374 | str(out_height), 375 | "--fullres", 376 | str(resolution_scale), 377 | "--quality", 378 | quality, 379 | "--mask_dilation", 380 | str(size), 381 | "--mask_feathering", 382 | str(feathering), 383 | "--nosmooth", 384 | str(nosmooth), 385 | "--debug_mask", 386 | str(debug_mask), 387 | "--preview_settings", 388 | str(preview_settings), 389 | "--mouth_tracking", 390 | str(mouth_tracking), 391 | ] 392 | 393 | # Run the command 394 | subprocess.run(cmd) 395 | 396 | if preview_settings: 397 | if os.path.isfile(os.path.join(temp_folder, "preview.jpg")): 398 | print(f"preview successful! Check out temp/preview.jpg") 399 | with open("last_file.txt", "w") as f: 400 | f.write(temp_input_video) 401 | # end processing timer and format the time it took 402 | end_time = time.time() 403 | elapsed_time = end_time - start_time 404 | formatted_setup_time = format_time(elapsed_time) 405 | print(f"Execution time: {formatted_setup_time}") 406 | break 407 | 408 | else: 409 | print(f"Processing failed! :( see line above 👆") 410 | print("Consider searching the issues tab on the github:") 411 | print("https://github.com/anothermartz/Easy-Wav2Lip/issues") 412 | exit() 413 | 414 | # rename temp file and move to correct directory 415 | if os.path.isfile(temp_output): 416 | if os.path.isfile(output_video): 417 | os.remove(output_video) 418 | shutil.copy(temp_output, output_video) 419 | # show output video 420 | with open("last_file.txt", "w") as f: 421 | f.write(temp_input_video) 422 | print(f"{output_filename} successfully lip synced! 
It will be found here:") 423 | print(output_video) 424 | 425 | # end processing timer and format the time it took 426 | end_time = time.time() 427 | elapsed_time = end_time - start_time 428 | formatted_setup_time = format_time(elapsed_time) 429 | print(f"Execution time: {formatted_setup_time}") 430 | 431 | else: 432 | print(f"Processing failed! :( see line above 👆") 433 | print("Consider searching the issues tab on the github:") 434 | print("https://github.com/anothermartz/Easy-Wav2Lip/issues") 435 | process_failed = True 436 | 437 | if batch_process == False: 438 | if process_failed: 439 | exit() 440 | else: 441 | break 442 | 443 | elif filenumber == "" and audio_filenumber == "": 444 | print("Files not set for batch processing") 445 | break 446 | 447 | # -----------------------------Batch Processing!------------------------------! 448 | if filenumber != "": # if video has a filenumber 449 | match = re.search(r"\d+", filenumber) 450 | # add 1 to video filenumber 451 | filenumber = ( 452 | f"{filenumber[:match.start()]}{int(match.group())+1:0{len(match.group())}d}" 453 | ) 454 | 455 | if audio_filenumber != "": # if audio has a filenumber 456 | match = re.search(r"\d+", audio_filenumber) 457 | # add 1 to audio filenumber 458 | audio_filenumber = f"{audio_filenumber[:match.start()]}{int(match.group())+1:0{len(match.group())}d}" 459 | 460 | # construct input_video 461 | input_video = os.path.join(folder, filenamenonumber + str(filenumber) + file_type) 462 | input_videofile = os.path.basename(input_video) 463 | # construct input_audio 464 | input_audio = os.path.join( 465 | audio_folder, audio_filenamenonumber + str(audio_filenumber) + audio_file_type 466 | ) 467 | input_audiofile = os.path.basename(input_audio) 468 | 469 | # now check which input files exist and what to do for each scenario 470 | 471 | # both +1 files exist - continue processing 472 | if os.path.exists(input_video) and os.path.exists(input_audio): 473 | continue 474 | 475 | # video +1 only - continue with last audio file 476 | if os.path.exists(input_video) and input_video != last_input_video: 477 | if audio_filenumber != "": # if audio has a filenumber 478 | match = re.search(r"\d+", audio_filenumber) 479 | # take 1 from audio filenumber 480 | audio_filenumber = f"{audio_filenumber[:match.start()]}{int(match.group())-1:0{len(match.group())}d}" 481 | continue 482 | 483 | # audio +1 only - continue with last video file 484 | if os.path.exists(input_audio) and input_audio != last_input_audio: 485 | if filenumber != "": # if video has a filenumber 486 | match = re.search(r"\d+", filenumber) 487 | # take 1 from video filenumber 488 | filenumber = f"{filenumber[:match.start()]}{int(match.group())-1:0{len(match.group())}d}" 489 | continue 490 | 491 | # neither +1 files exist or current files already processed - finish processing 492 | print("Finished all sequentially numbered files") 493 | if process_failed: 494 | sys.exit("Processing failed on at least one video") 495 | else: 496 | break 497 | -------------------------------------------------------------------------------- /run_loop.bat: -------------------------------------------------------------------------------- 1 | @echo off 2 | :run_loop 3 | call GUI.py 4 | 5 | if exist "run.txt" ( 6 | echo starting Easy-Wav2Lip... 
7 | python run.py 8 | goto run_loop 9 | ) 10 | -------------------------------------------------------------------------------- /run_loop.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | while true; do 4 | python GUI.py 5 | 6 | if [ -f "run.txt" ]; then 7 | echo "Starting Easy-Wav2Lip..." 8 | python run.py 9 | else 10 | break # Exit the loop when "run.txt" does not exist 11 | fi 12 | done 13 | --------------------------------------------------------------------------------
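A minimal shape check for SyncNet_color from models/syncnet.py may help when wiring it up. The input shapes are not stated in this repository; they are the ones used by the upstream Wav2Lip sync network (a window of 5 face frames, lower halves of 96x96 crops stacked on the channel axis giving 15x48x96, and a 16-step, 80-bin mel window), so treat this as an assumption-laden sketch rather than project documentation.

import torch
from models.syncnet import SyncNet_color

net = SyncNet_color().eval()

mels = torch.randn(4, 1, 80, 16)     # (B, 1, 80, 16) mel window, assumed upstream shape
faces = torch.randn(4, 15, 48, 96)   # (B, 15, 48, 96) 5 stacked lower-half face frames, assumed

with torch.no_grad():
    audio_emb, face_emb = net(mels, faces)

# Both encoders collapse to a 512-d, L2-normalised embedding per clip.
print(audio_emb.shape, face_emb.shape)   # torch.Size([4, 512]) torch.Size([4, 512])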
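Similarly, a quick smoke test of the Wav2Lip generator in models/wav2lip.py can confirm the tensor layout. The shapes follow the comments in that file: per-frame mel chunks of (1, 80, 16) and 6-channel 96x96 face crops (in the original Wav2Lip pipeline, the masked target frame concatenated with a reference frame). The dummy batch below is purely illustrative.

import torch
from models.wav2lip import Wav2Lip

model = Wav2Lip().eval()

B, T = 2, 5                              # 2 clips, 5 frames each (arbitrary example values)
audio = torch.randn(B, T, 1, 80, 16)     # (B, T, 1, 80, 16) mel windows
faces = torch.randn(B, 6, T, 96, 96)     # (B, 6, T, 96, 96) face crops

with torch.no_grad():
    out = model(audio, faces)

# The 5-D input path reshapes to (B*T, ...) internally, then restacks per frame.
print(out.shape)                         # torch.Size([2, 3, 5, 96, 96])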
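Finally, run.py can be driven without the GUI: its -video_file and -vocal_file flags override the paths in config.ini, while every other setting (quality, padding, mask, batch options) is still read from config.ini, so that file must already contain those sections. The sketch below mirrors how run.py itself shells out to inference.py; the two media paths are placeholders, not files shipped with the project.

import subprocess
import sys

# Placeholder inputs; point these at a real video and audio file.
subprocess.run(
    [
        sys.executable, "run.py",
        "-video_file", "examples/face.mp4",
        "-vocal_file", "examples/speech.wav",
    ],
    check=True,  # raise if the lip-sync run exits with an error
)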