├── color ├── place onnx model here ├── deoldify.py └── deoldify_fp16.py ├── requirements.txt ├── render_factor.txt ├── setup.txt ├── README.md ├── image.py ├── convert_onnx_to_fp16_gui.py ├── convert_to_onnx.py ├── video.py ├── image_GUI.py └── video_GUI.py /color/place onnx model here: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | opencv-python 2 | numpy 3 | tqdm 4 | onnxruntime==1.14.1 5 | 6 | -------------------------------------------------------------------------------- /render_factor.txt: -------------------------------------------------------------------------------- 1 | eg. 2 | --render_factor 8 3 | 4 | Like the original deoldify but step 2 5 | So original render_factor 16 now is render_factor 8 (default) 6 | (Range 1 to 40 original = range 1 to 20 onnx version) 7 | -------------------------------------------------------------------------------- /setup.txt: -------------------------------------------------------------------------------- 1 | clone repo 2 | 3 | conda create -n ENV_NAME python==3.10 4 | conda activate ENV_NAME 5 | cd c:\env\env_path 6 | 7 | pip install -r requirements.txt 8 | 9 | ############# GPU ############### 10 | 11 | conda install -c conda-forge cudatoolkit=11.2 cudnn=8.1.0 12 | pip uninstall onnxuntime 13 | pip install onnxruntime-gpu 14 | 15 | model download: 16 | https://drive.google.com/drive/folders/1bU9Zj7zGVEujIzvDTb1b9cyWU3s__WQR?usp=sharing 17 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # deoldify-onnx 2 | 3 | Updated version: option render factor added (only commandline version) 4 | 5 | New models for use with render factor: 6 | 7 | https://drive.google.com/drive/folders/1bU9Zj7zGVEujIzvDTb1b9cyWU3s__WQR?usp=sharing 8 | 9 | . 10 | 11 | Simple image and video colorization using onnx converted deoldify model. 12 | 13 | Easy to install. Can be run on CPU or nVidia GPU 14 | 15 | ffmpeg for video colorzation required. 16 | 17 | Added floating point 16 model for 100% faster inference and simple GUI version. 18 | 19 | For inference run: 20 | 21 | Image: 22 | python image.py --source_image "image.jpg" 23 | 24 | Video: 25 | python video.py --source "video.mp4" --result "video_colorized.mp4" --audio 26 | 27 | Image example: 28 | ![colorizer1](https://github.com/instant-high/deoldify-onnx/assets/77229558/171642dd-9034-4ca7-8d29-c07c6e5e9f0a) 29 | 30 | 31 | https://github.com/instant-high/deoldify-onnx/assets/77229558/3824e96d-fffc-494e-8ce1-193e6a77c8b6 32 | 33 | https://github.com/instant-high/deoldify-onnx/assets/77229558/543e1dd1-27da-4c63-95a9-9c0696adea51 34 | 35 | 36 | original deoldify: 37 | 38 | https://github.com/jantic/DeOldify 39 | 40 | -------------------------------------------------------------------------------- /image.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import numpy as np 4 | from argparse import ArgumentParser 5 | 6 | import onnxruntime as rt 7 | rt.set_default_logger_severity(3) 8 | parser = ArgumentParser() 9 | parser.add_argument("--source_image", default='source.jpg', help="path to source image") 10 | parser.add_argument("--result_image", default='result.jpg', help="path to result image") 11 | parser.add_argument("--render_factor", type=int, default=8, help=" - ") 12 | opt = parser.parse_args() 13 | 14 | ''' 15 | The render factor determines the resolution at which the image is rendered for inference. 16 | When set at a low value, the process is faster and the colors tend to be more vibrant 17 | but the results are less stable. 18 | original torch model accepts input divisible by 16 19 | ONNX models currently accept only divisible by 32 20 | ''' 21 | 22 | # 23 | render_factor = opt.render_factor * 32 24 | # 25 | 26 | # old model - you cannot set render_factor 27 | #from color.deoldify_fp16 import DEOLDIFY 28 | #colorizer = DEOLDIFY(model_path="color/deoldify_fp16.onnx", device="cpu") 29 | #from color.deoldify import DEOLDIFY 30 | #colorizer = DEOLDIFY(model_path="color/deoldify.onnx", device="cuda") 31 | 32 | # new onnx models - render_factor - dynamic axes input: 33 | from color.deoldify_fp16 import DEOLDIFY 34 | colorizer = DEOLDIFY(model_path="color/ColorizeArtistic_dyn_fp16.onnx", device="cuda") 35 | #from color.deoldify import DEOLDIFY 36 | #colorizer = DEOLDIFY(model_path="color/ColorizeArtistic_dyn.onnx", device="cuda") 37 | 38 | image = cv2.imread(opt.source_image) 39 | 40 | colorized = colorizer.colorize(image, render_factor) 41 | 42 | cv2.imwrite(opt.result_image, colorized) 43 | cv2.imshow("Colorized image saved - press any key",colorized) 44 | cv2.waitKey() 45 | -------------------------------------------------------------------------------- /convert_onnx_to_fp16_gui.py: -------------------------------------------------------------------------------- 1 | import tkinter as tk 2 | from tkinter import filedialog 3 | import os 4 | import onnx 5 | from onnxconverter_common import float16 6 | 7 | def select_model_file(): 8 | file_path = filedialog.askopenfilename(filetypes=[("ONNX files", "*.onnx")]) 9 | if file_path: 10 | entry.delete(0, tk.END) 11 | entry.insert(0, file_path) 12 | result_label.config(text="Model loaded") 13 | 14 | def convert_to_float16(): 15 | model_path = entry.get() 16 | if model_path: 17 | model = onnx.load(model_path) 18 | result_file= os.path.splitext(model_path)[0] 19 | result_extension = os.path.splitext(model_path)[1] 20 | model_converted = result_file + "_fp16" + result_extension 21 | model_fp16 = float16.convert_float_to_float16(model, min_positive_val=1e-7, max_finite_val=1e4, keep_io_types=False,disable_shape_infer=False, op_block_list=None, node_block_list=None) 22 | onnx.save(model_fp16, model_converted) 23 | result_label.config(text="Model converted successfully!") 24 | 25 | # Create the main window 26 | root = tk.Tk() 27 | root.title("ONNX Model Converter") 28 | 29 | # Create widgets 30 | label = tk.Label(root, text="Select ONNX model file:") 31 | label.grid(row=0, column=0, padx=10, pady=5) 32 | 33 | entry = tk.Entry(root, width=50) 34 | entry.grid(row=0, column=1, padx=10, pady=5, columnspan=2) 35 | 36 | browse_button = tk.Button(root, text="Browse", command=select_model_file) 37 | browse_button.grid(row=0, column=3, padx=5, pady=5) 38 | 39 | convert_button = tk.Button(root, text="Convert to float16", command=convert_to_float16) 40 | convert_button.grid(row=1, column=1, columnspan=2, padx=10, pady=5) 41 | 42 | result_label = tk.Label(root, text="") 43 | result_label.grid(row=2, column=1, columnspan=2, padx=10, pady=5) 44 | 45 | # Start the GUI 46 | root.mainloop() 47 | -------------------------------------------------------------------------------- /color/deoldify.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | import onnxruntime 4 | 5 | 6 | class DEOLDIFY: 7 | def __init__(self, model_path="deoldify.onnx", device='cpu'): 8 | session_options = onnxruntime.SessionOptions() 9 | session_options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL 10 | providers = ["CPUExecutionProvider"] 11 | if device == 'cuda': 12 | providers = [("CUDAExecutionProvider", {"cudnn_conv_algo_search": "DEFAULT"}),"CPUExecutionProvider"] 13 | self.session = onnxruntime.InferenceSession(model_path, sess_options=session_options, providers=providers) 14 | self.resolution = self.session.get_inputs()[0].shape[-2:] 15 | 16 | 17 | def colorize(self, image, r_factor): 18 | 19 | # preprocess image: 20 | targetL = cv2.cvtColor(image,cv2.COLOR_BGR2LAB) 21 | targetL,_,_=cv2.split(image) 22 | 23 | image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY) 24 | image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB) 25 | 26 | h, w, channels = image.shape 27 | 28 | image = cv2.resize(image,(r_factor, r_factor)) 29 | image = image.astype(np.float32) 30 | image = image.transpose((2, 0, 1)) 31 | image = np.expand_dims(image, axis=0).astype(np.float32) 32 | 33 | # run deoldify: 34 | colorized = self.session.run(None, {(self.session.get_inputs()[0].name):image})[0][0] 35 | 36 | # postprocess image: 37 | colorized = colorized.transpose(1,2,0) 38 | colorized = cv2.cvtColor(colorized, cv2.COLOR_BGR2RGB).astype(np.uint8) 39 | colorized = cv2.resize(colorized,(w,h)) 40 | colorized = cv2.GaussianBlur(colorized,(13,13),0) 41 | colorizedLAB = cv2.cvtColor(colorized,cv2.COLOR_BGR2LAB) 42 | L,A,B=cv2.split(colorizedLAB) 43 | colorizedLAB = cv2.resize(colorizedLAB,(w, h)) 44 | colorized = cv2.merge((targetL,A,B)) 45 | colorized = cv2.cvtColor(colorized,cv2.COLOR_LAB2BGR) 46 | 47 | return colorized 48 | 49 | -------------------------------------------------------------------------------- /color/deoldify_fp16.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | import onnxruntime 4 | 5 | 6 | class DEOLDIFY: 7 | def __init__(self, model_path="deoldify.onnx", device='cpu'): 8 | session_options = onnxruntime.SessionOptions() 9 | session_options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL 10 | providers = ["CPUExecutionProvider"] 11 | if device == 'cuda': 12 | providers = [("CUDAExecutionProvider", {"cudnn_conv_algo_search": "DEFAULT"}),"CPUExecutionProvider"] 13 | self.session = onnxruntime.InferenceSession(model_path, sess_options=session_options, providers=providers) 14 | self.resolution = self.session.get_inputs()[0].shape[-2:] 15 | 16 | 17 | def colorize(self, image, r_factor): 18 | 19 | # preprocess image: 20 | targetL = cv2.cvtColor(image,cv2.COLOR_BGR2LAB) 21 | targetL,_,_=cv2.split(image) 22 | 23 | image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY) 24 | image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB) 25 | h, w, channels = image.shape 26 | 27 | image = cv2.resize(image,(r_factor, r_factor)) 28 | image = image.astype(np.float16) 29 | image = image.transpose((2, 0, 1)) 30 | image = np.expand_dims(image, axis=0).astype(np.float16) 31 | 32 | # run deoldify: 33 | colorized = self.session.run(None, {(self.session.get_inputs()[0].name):image})[0][0] 34 | 35 | # postprocess image: 36 | colorized = colorized.transpose(1,2,0).astype(np.float32) 37 | colorized = cv2.cvtColor(colorized, cv2.COLOR_BGR2RGB).astype(np.uint8) 38 | colorized = cv2.resize(colorized,(w,h)) 39 | colorized = cv2.GaussianBlur(colorized,(13,13),0) 40 | colorizedLAB = cv2.cvtColor(colorized,cv2.COLOR_BGR2LAB) 41 | L,A,B=cv2.split(colorizedLAB) 42 | colorizedLAB = cv2.resize(colorizedLAB,(w,h)) 43 | colorized = cv2.merge((targetL,A,B)) 44 | colorized = cv2.cvtColor(colorized,cv2.COLOR_LAB2BGR) 45 | 46 | return colorized 47 | 48 | -------------------------------------------------------------------------------- /convert_to_onnx.py: -------------------------------------------------------------------------------- 1 | ''' 2 | run this script in the original deoldify repo 3 | thanks to henry ruhs - face fusion for helping 4 | ''' 5 | 6 | import os 7 | import torch 8 | from deoldify.generators import gen_inference_deep 9 | from deoldify.generators import gen_inference_wide 10 | import torch.nn as nn 11 | from pathlib import Path 12 | 13 | from fastai.vision.data import normalize_funcs, imagenet_stats 14 | 15 | norm, denorm = normalize_funcs(*imagenet_stats) 16 | 17 | class ImageScaleInput(nn.Module): 18 | def __init__(self): 19 | super().__init__() 20 | 21 | def forward(self, x): 22 | out = (x.div(255.0)).type(torch.float32) 23 | out, _ = norm((out, out)) 24 | return out 25 | 26 | class ImageScaleOutput(nn.Module): 27 | def __init__(self): 28 | super().__init__() 29 | 30 | def forward(self, x): 31 | out = denorm(x) 32 | out = out.float().clamp(min=0, max=1) 33 | out = (out.mul(255.0)).type(torch.float32) 34 | return out 35 | 36 | root_folder = Path('./deoldify') 37 | 38 | # select the original model to be converted: 39 | #raw_model = gen_inference_deep(root_folder=Path('./deoldify'), weights_name='./deoldify/ColorizeArtistic_gen').model 40 | #onnx_path = 'ColorizeArtistic_dyn.onnx' 41 | 42 | #raw_model = gen_inference_wide(root_folder=Path('./deoldify'), weights_name='./deoldify/ColorizeStable_gen').model 43 | #onnx_path = 'ColorizeStable_dyn.onnx' 44 | 45 | raw_model = gen_inference_wide(root_folder=Path('./deoldify'), weights_name='./deoldify/DeoldifyVideo_gen').model 46 | onnx_path = 'DeoldifyVideo_dyn.onnx' 47 | 48 | dummy_input = torch.randn(1, 3, 256, 256) 49 | 50 | # Wenn CUDA verfügbar ist, auf CUDA umschalten 51 | dummy_input = dummy_input.to('cuda') 52 | 53 | final_pytorch_model = nn.Sequential(ImageScaleInput(), raw_model, ImageScaleOutput()) 54 | 55 | torch.onnx.export( 56 | final_pytorch_model, 57 | dummy_input, 58 | onnx_path, 59 | do_constant_folding=False, 60 | input_names=['input'], 61 | output_names=['output'], 62 | opset_version=12, 63 | dynamic_axes={'input': {2: 'height', 3: 'width'}, 'output': {2: 'height', 3: 'width'}} 64 | ) 65 | -------------------------------------------------------------------------------- /video.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | import sys 4 | import cv2 5 | import numpy as np 6 | import subprocess 7 | import platform 8 | 9 | from argparse import ArgumentParser 10 | from tqdm import tqdm 11 | 12 | import onnxruntime as rt 13 | rt.set_default_logger_severity(3) 14 | 15 | parser = ArgumentParser() 16 | parser.add_argument("--source", help="path to source video") 17 | parser.add_argument("--result", help="path to result video") 18 | parser.add_argument("--audio", default=False, action="store_true", help="Keep audio") 19 | parser.add_argument("--render_factor", type=int, default=8, help=" - ") 20 | opt = parser.parse_args() 21 | 22 | 23 | ''' 24 | The render factor determines the resolution at which the image is rendered for inference. 25 | When set at a low value, the process is faster and the colors tend to be more vibrant 26 | but the results are less stable. 27 | original torch model accepts input divisible by 16 28 | ONNX models currently accept only divisible by 32 29 | ''' 30 | 31 | # 32 | render_factor = opt.render_factor * 32 33 | # 34 | 35 | # old model - you cannot set render_factor 36 | #from color.deoldify_fp16 import DEOLDIFY 37 | #colorizer = DEOLDIFY(model_path="color/deoldify_fp16.onnx", device="cpu") 38 | #from color.deoldify import DEOLDIFY 39 | #colorizer = DEOLDIFY(model_path="color/deoldify.onnx", device="cuda") 40 | 41 | # new onnx models - render_factor - dynamic axes input: 42 | from color.deoldify_fp16 import DEOLDIFY 43 | colorizer = DEOLDIFY(model_path="color/ColorizeArtistic_dyn_fp16.onnx", device="cuda") 44 | #from color.deoldify import DEOLDIFY 45 | #colorizer = DEOLDIFY(model_path="color/ColorizeArtistic_dyn.onnx", device="cuda") 46 | 47 | 48 | video = cv2.VideoCapture(opt.source) 49 | 50 | w = int(video.get(cv2.CAP_PROP_FRAME_WIDTH)) 51 | h = int(video.get(cv2.CAP_PROP_FRAME_HEIGHT)) 52 | n_frames = int(video.get(cv2.CAP_PROP_FRAME_COUNT)) 53 | fps = video.get(cv2.CAP_PROP_FPS) 54 | 55 | if opt.audio: 56 | writer = cv2.VideoWriter('temp.mp4',cv2.VideoWriter_fourcc('m','p','4','v'), fps, (w, h)) 57 | else: 58 | writer = cv2.VideoWriter(opt.result,cv2.VideoWriter_fourcc('m','p','4','v'), fps, (w, h)) 59 | 60 | for frame_idx in tqdm(range(n_frames)): 61 | 62 | ret, frame = video.read() 63 | if not ret: 64 | break 65 | 66 | result = colorizer.colorize(frame, render_factor) 67 | 68 | writer.write(result) 69 | cv2.imshow ("Result",result) 70 | k = cv2.waitKey(1) 71 | if k == 27: 72 | writer.release() 73 | break 74 | 75 | if opt.audio: 76 | # lossless remuxing audio/video 77 | command = 'ffmpeg.exe -y -vn -i ' + '"' + opt.source + '"' + ' -an -i ' + 'temp.mp4' + ' -c:v copy -acodec libmp3lame -ac 2 -ar 44100 -ab 128000 -map 0:1 -map 1:0 -shortest ' + '"' + opt.result + '"' 78 | subprocess.call(command, shell=platform.system() != 'Windows') 79 | os.remove('temp.mp4') 80 | -------------------------------------------------------------------------------- /image_GUI.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import numpy as np 4 | import tkinter as tk 5 | from tkinter import filedialog, messagebox 6 | from PIL import Image, ImageTk 7 | from color.deoldify import DEOLDIFY 8 | import onnxruntime as rt 9 | 10 | rt.set_default_logger_severity(3) 11 | 12 | # old model - you cannot set render_factor 13 | #from color.deoldify_fp16 import DEOLDIFY 14 | #colorizer = DEOLDIFY(model_path="color/deoldify_fp16.onnx", device="cpu") 15 | #from color.deoldify import DEOLDIFY 16 | #colorizer = DEOLDIFY(model_path="color/deoldify.onnx", device="cuda") 17 | 18 | from color.deoldify import DEOLDIFY 19 | colorizer = DEOLDIFY(model_path="color/ColorizeArtistic_dyn.onnx", device="cuda") 20 | 21 | 22 | def resize_image(image): 23 | max_height = root.winfo_screenheight() - 100 # Adjusted for padding 24 | max_width = root.winfo_screenwidth() - 200 # Adjusted for padding 25 | height, width = image.shape[:2] 26 | if height > max_height or width > max_width: 27 | scale = min(max_height/height, max_width/width) 28 | new_width = int(width * scale) 29 | new_height = int(height * scale) 30 | return cv2.resize(image, (new_width, new_height)) 31 | return image 32 | 33 | def select_image(): 34 | file_path = filedialog.askopenfilename(filetypes=[("Image files", "*.jpg;*.jpeg;*.png;*.bmp")]) 35 | if file_path: 36 | process_image(file_path) 37 | 38 | def adjust_saturation(image, saturation_factor): 39 | # Convert BGR image to HSV 40 | hsv = cv2.cvtColor(image, cv2.COLOR_BGR2HSV) 41 | # Split into Hue, Saturation, and Value channels 42 | h, s, v = cv2.split(hsv) 43 | # Scale the saturation channel 44 | s = np.clip(s * saturation_factor, 0, 255).astype(np.uint8) 45 | # Merge the channels back 46 | adjusted_hsv = cv2.merge([h, s, v]) 47 | # Convert back to BGR 48 | adjusted_bgr = cv2.cvtColor(adjusted_hsv, cv2.COLOR_HSV2BGR) 49 | return adjusted_bgr 50 | 51 | def process_image(file_path): 52 | render_factor = 8 53 | render_factor = render_factor * 32 54 | 55 | image = cv2.imread(file_path) 56 | if image is None: 57 | messagebox.showerror("Error", "Failed to load the image.") 58 | return 59 | 60 | # Resize the image if it's too big 61 | image = resize_image(image) 62 | 63 | colorized = colorizer.colorize(image, render_factor) 64 | 65 | #colorized = adjust_saturation(colorized, 1) # 0.1 - 2.0 66 | 67 | # Convert the OpenCV BGR image to RGB 68 | colorized_rgb = cv2.cvtColor(colorized, cv2.COLOR_BGR2RGB) 69 | 70 | # Convert the colorized image to PIL format 71 | img_colorized = Image.fromarray(colorized_rgb) 72 | 73 | # Get the screen width and height 74 | screen_width = root.winfo_screenwidth() 75 | screen_height = root.winfo_screenheight() 76 | 77 | # Resize the colorized image to fit the screen dimensions 78 | img_width, img_height = img_colorized.size 79 | aspect_ratio = img_width / img_height 80 | max_width = screen_width - 200 # Adjusted for padding 81 | max_height = screen_height - 100 # Adjusted for padding 82 | 83 | if max_width / aspect_ratio < max_height: 84 | new_width = max_width 85 | new_height = int(new_width / aspect_ratio) 86 | else: 87 | new_height = max_height 88 | new_width = int(new_height * aspect_ratio) 89 | 90 | img_colorized = img_colorized.resize((new_width, new_height)) 91 | 92 | # Create a PhotoImage object to display in the Tkinter window 93 | img_colorized_tk = ImageTk.PhotoImage(img_colorized) 94 | 95 | # Update the label with the colorized image 96 | colorized_label.configure(image=img_colorized_tk) 97 | colorized_label.image = img_colorized_tk 98 | 99 | # Set the window size to fit the image and center it on the screen 100 | root.geometry(f"{new_width}x{new_height}+{(screen_width - new_width) // 2}+{(screen_height - new_height) // 2}") 101 | 102 | # Save the colorized image 103 | result_file = os.path.splitext(file_path)[0] 104 | result_extension = os.path.splitext(file_path)[1] 105 | result_image = result_file + "_colorized" + result_extension 106 | img_colorized.save(result_image) 107 | messagebox.showinfo("Done", "File saved as " + result_image) 108 | 109 | # Ensure the window stays centered after displaying the image 110 | root.update_idletasks() 111 | 112 | 113 | 114 | def main(): 115 | global root, original_label, colorized_label 116 | root = tk.Tk() 117 | root.title("Image Colorization - deoldify.onnx") 118 | 119 | # Calculate the center position 120 | screen_width = root.winfo_screenwidth() 121 | screen_height = root.winfo_screenheight() 122 | window_width = 400 # Adjust as needed 123 | window_height = 70 # Adjust as needed 124 | x = (screen_width - window_width) // 2 125 | y = (screen_height - window_height) // 2 126 | 127 | # Set the window size and position 128 | root.geometry(f"{window_width}x{window_height}+{x}+{y}") 129 | 130 | select_button = tk.Button(root, text="Select Image", command=select_image) 131 | select_button.pack(pady=10) 132 | 133 | original_label = tk.Label(root) 134 | original_label.pack(side=tk.LEFT, padx=10) 135 | 136 | colorized_label = tk.Label(root) 137 | colorized_label.pack(side=tk.RIGHT, padx=10) 138 | 139 | root.mainloop() 140 | 141 | 142 | if __name__ == "__main__": 143 | main() 144 | -------------------------------------------------------------------------------- /video_GUI.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import numpy as np 4 | import subprocess 5 | import platform 6 | import tkinter as tk 7 | from tkinter import filedialog, messagebox 8 | from argparse import Namespace 9 | from tqdm import tqdm 10 | import onnxruntime as rt 11 | rt.set_default_logger_severity(3) 12 | 13 | # old model - you cannot set render_factor 14 | #from color.deoldify_fp16 import DEOLDIFY 15 | #colorizer = DEOLDIFY(model_path="color/deoldify_fp16.onnx", device="cpu") 16 | #from color.deoldify import DEOLDIFY 17 | #colorizer = DEOLDIFY(model_path="color/deoldify.onnx", device="cuda") 18 | 19 | from color.deoldify import DEOLDIFY 20 | colorizer = DEOLDIFY(model_path="color/ColorizeArtistic_dyn.onnx", device="cuda") 21 | 22 | 23 | class DEOLDIFY_GUI: 24 | def __init__(self, master): 25 | self.master = master 26 | master.title("Video Colorizer - Deoldify-ONNX") 27 | 28 | # Calculate the center position 29 | screen_width = master.winfo_screenwidth() 30 | screen_height = master.winfo_screenheight() 31 | window_width = 440 # Adjust as needed 32 | window_height = 120 # Adjust as needed 33 | x = (screen_width - window_width) // 2 34 | y = (screen_height - window_height) // 2 35 | 36 | # Set the window size and position 37 | master.geometry(f"{window_width}x{window_height}+{x}+{y}") 38 | 39 | self.source_label = tk.Label(master, text="Source Video:") 40 | self.source_label.grid(row=0, column=0, sticky='w') 41 | 42 | self.source_path = tk.StringVar() 43 | self.source_entry = tk.Entry(master, textvariable=self.source_path, width=50) 44 | self.source_entry.grid(row=0, column=1) 45 | 46 | self.source_button = tk.Button(master, text="Browse", command=self.browse_source) 47 | self.source_button.grid(row=0, column=2) 48 | 49 | self.result_label = tk.Label(master, text="Result Video:") 50 | self.result_label.grid(row=1, column=0, sticky='w') 51 | 52 | self.result_path = tk.StringVar() 53 | self.result_entry = tk.Entry(master, textvariable=self.result_path, width=50) 54 | self.result_entry.grid(row=1, column=1) 55 | 56 | self.result_button = tk.Button(master, text="Browse", command=self.browse_result) 57 | self.result_button.grid(row=1, column=2) 58 | 59 | self.audio_var = tk.BooleanVar() 60 | self.audio_checkbox = tk.Checkbutton(master, text="Keep Audio", variable=self.audio_var) 61 | self.audio_checkbox.grid(row=2, column=0, columnspan=3) 62 | 63 | self.run_button = tk.Button(master, text="Run", command=self.run_colorizer) 64 | self.run_button.grid(row=3, columnspan=3) 65 | 66 | def browse_source(self): 67 | file_path = filedialog.askopenfilename(filetypes=[("Video files", "*.mp4;*.avi;*.mkv")]) 68 | self.source_path.set(file_path) 69 | 70 | def browse_result(self): 71 | file_path = filedialog.asksaveasfilename(defaultextension=".mp4", filetypes=[("Video files", "*.mp4")]) 72 | self.result_path.set(file_path) 73 | 74 | def run_colorizer(self): 75 | source = self.source_path.get() 76 | result = self.result_path.get() 77 | if not source or not result: 78 | messagebox.showerror("Error", "Please select source and result paths.") 79 | return 80 | 81 | opt = Namespace(source=source, result=result, audio=self.audio_var.get()) 82 | self.colorize_video(opt) 83 | messagebox.showinfo("Done", f"Inference done. Output file: {result}") 84 | 85 | 86 | def colorize_video(self, opt): 87 | render_factor = 8 88 | render_factor = render_factor * 32 89 | 90 | video = cv2.VideoCapture(opt.source) 91 | w = int(video.get(cv2.CAP_PROP_FRAME_WIDTH)) 92 | h = int(video.get(cv2.CAP_PROP_FRAME_HEIGHT)) 93 | n_frames = int(video.get(cv2.CAP_PROP_FRAME_COUNT)) 94 | fps = video.get(cv2.CAP_PROP_FPS) 95 | 96 | if opt.audio: 97 | writer = cv2.VideoWriter('temp.mp4', cv2.VideoWriter_fourcc('m', 'p', '4', 'v'), fps, (w, h)) 98 | else: 99 | writer = cv2.VideoWriter(opt.result, cv2.VideoWriter_fourcc('m', 'p', '4', 'v'), fps, (w, h)) 100 | 101 | for frame_idx in tqdm(range(n_frames)): 102 | ret, frame = video.read() 103 | if not ret: 104 | break 105 | result = colorizer.colorize(frame, render_factor) 106 | 107 | #result = adjust_saturation(result, 1) # 0.1 - 2.0 108 | 109 | writer.write(result) 110 | 111 | cv2.imshow ("Result - press ESC to stop",result) 112 | k = cv2.waitKey(1) 113 | if k == 27: 114 | writer.release() 115 | break 116 | 117 | writer.release() 118 | video.release() 119 | cv2.destroyAllWindows() 120 | if opt.audio: 121 | # lossless remuxing audio/video 122 | command = 'ffmpeg.exe -y -vn -i ' + '"' + opt.source + '"' + ' -an -i ' + 'temp.mp4' + \ 123 | ' -c:v copy -acodec libmp3lame -ac 2 -ar 44100 -ab 128000 -map 0:1 -map 1:0 -shortest ' + '"' + opt.result + '"' 124 | subprocess.call(command, shell=platform.system() != 'Windows') 125 | os.remove('temp.mp4') 126 | 127 | def adjust_saturation(image, saturation_factor): 128 | # Convert BGR image to HSV 129 | hsv = cv2.cvtColor(image, cv2.COLOR_BGR2HSV) 130 | # Split into Hue, Saturation, and Value channels 131 | h, s, v = cv2.split(hsv) 132 | # Scale the saturation channel 133 | s = np.clip(s * saturation_factor, 0, 255).astype(np.uint8) 134 | # Merge the channels back 135 | adjusted_hsv = cv2.merge([h, s, v]) 136 | # Convert back to BGR 137 | adjusted_bgr = cv2.cvtColor(adjusted_hsv, cv2.COLOR_HSV2BGR) 138 | return adjusted_bgr 139 | 140 | def main(): 141 | 142 | root = tk.Tk() 143 | deoldify_gui = DEOLDIFY_GUI(root) 144 | root.mainloop() 145 | 146 | if __name__ == "__main__": 147 | main() 148 | --------------------------------------------------------------------------------