├── LICENSE ├── README.md ├── benchmark.py ├── beta ├── README.md ├── better.py ├── change_model_batch.py ├── faceswap_experiment - Copy.py ├── faceswap_experiment.py ├── globalsx.py ├── prune.py ├── swapperfp16 - Copy.py ├── swapperfp16.py └── utilities.py ├── chain_img_processor ├── __init__.py ├── ffmpeg_writer.py ├── image.py └── video.py ├── colab_example.ipynb ├── face.jpg ├── gfpgan ├── __init__.py ├── archs │ ├── __init__.py │ ├── arcface_arch.py │ ├── gfpgan_bilinear_arch.py │ ├── gfpganv1_arch.py │ ├── gfpganv1_clean_arch.py │ ├── restoreformer_arch.py │ ├── stylegan2_bilinear_arch.py │ └── stylegan2_clean_arch.py ├── data │ ├── __init__.py │ └── ffhq_degradation_dataset.py ├── models │ ├── __init__.py │ └── gfpgan_model.py ├── train.py ├── utils.py └── weights │ └── README.md ├── globalsz.py ├── install_directml_windows.cmd ├── install_linux.cmd ├── install_linux.sh ├── install_mac.sh ├── install_termux.sh ├── install_windows.cmd ├── jaa.py ├── ll.pkl ├── main.py ├── plugins ├── __pycache__ │ ├── codeformer_app_cv2.cpython-310.pyc │ └── codeformer_face_helper_cv2.cpython-310.pyc ├── codeformer_app_cv2.py ├── codeformer_face_helper_cv2.py ├── core.py ├── core_video.py ├── plugin_blur.py ├── plugin_codeformer.py ├── plugin_resize_cv2.py └── plugin_to_grayscale.py ├── realesrgan ├── __init__.py ├── archs │ ├── __init__.py │ ├── discriminator_arch.py │ └── srvgg_arch.py ├── data │ ├── __init__.py │ ├── realesrgan_dataset.py │ └── realesrgan_paired_dataset.py ├── models │ ├── __init__.py │ ├── realesrgan_model.py │ └── realesrnet_model.py ├── train.py └── utils.py ├── requirements.txt ├── start_mac.sh ├── start_venv_linux.sh ├── start_venv_windows.cmd ├── swapperfp16.py ├── utils.py ├── x.pkl └── zlibwapi.dll /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Richard Erkhov 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # FastFaceSwap 2 | just a little project for fast face swapping using one picture. Now supports multigpu! 
(Almost, check the ending of readme) 3 | ## join my discord server https://discord.gg/hzrJBGPpgN 4 | 5 | # If you want to use CLIP, occluder etc, please use banana branch (it's in development, has some bugs) 6 | 7 | ## requirements: 8 | -python 3.10 9 | 10 | -cuda 11.7 with cudnn 11 | 12 | -everything else is installed automatically! 13 | 14 | # how to install 15 | 16 | ## for windows: I prefer using cmd, not powershell. Sometimes it bugs out, so please, use cmd 17 | 18 | - clone the repo 19 | 20 | if you are on windows with nvidia gpu: 21 | 22 | - install_windows.cmd 23 | 24 | if you are on windows with amd gpu: 25 | 26 | - install_directml_windows.cmd 27 | 28 | if you are on linux with nvidia gpu: 29 | 30 | - install_linux.sh 31 | 32 | if you are on mac, Im not sure if it's going to work properly, but: 33 | 34 | - install_mac.sh 35 | 36 | if you are on a phone (android), and want to run it for some reason, you are not really supported, but I managed it to run in termux with ubuntu 22.04 (installed inside termux) 37 | 38 | - install_termux.sh 39 | 40 | Note for linux guys with permission denied: 41 | `chmod +x install_linux.sh` (or whatever your installer is) 42 | 43 | then to start the environment: 44 | 45 | - start_venv_windows.cmd (if you are on windows) 46 | 47 | - start_venv_linux.sh (if you are on linux or android) 48 | 49 | - start_mac.sh (if you are on mac) 50 | 51 | Note for linux guys with permission denied: same as installer, `chmod +x start_venv_linux.sh` (or whatever you need to run) 52 | 53 | 54 | # how to run 55 | 56 | ### Here is a [colab example](https://colab.research.google.com/github/RichardErkhov/FastFaceSwap/blob/main/colab_example.ipynb) 57 | 58 | ```python main.py``` by default starts face swapping from camera. 59 | 60 | flags implemented: 61 | 62 | -f, --face: uses the face from the following image. Default: face.jpg 63 | 64 | -t, --target: replaces the face in the following video/image. Use 0 for camera. Default: 0 65 | 66 | -o, --output: path to output the video. Default: video.mp4 67 | 68 | -cam-fix, --camera-fix: fix for logitech cameras that start for 40 seconds in default mode. 69 | 70 | -res, --resolution: camera resolution, given in format WxH (ex 1920x1080). Is set for camera mode only. Default: 1920x1080 71 | 72 | --threads: amount of threads to run the program in. Default: 2 73 | 74 | --image: if the target is image, you have to use this flag. 75 | 76 | --cli: run in cli mode, turns off preview and now accepts the switch of face enhancer from the command line 77 | 78 | --face-enhancer: argument works with cli, face enhancer. In gui mode, you need to choose from gui. Available options: 79 | 80 | 1) none (default) 81 | 2) gfpgan 82 | 3) ffe (fast face enhancer) 83 | 4) codeformer 84 | 5) gfpgan_onnx 85 | 6) real_esrgan 86 | 87 | --no-face-swapper: disables face swapper 88 | 89 | --experimental: experimental mode to try to optimize the perfomance of reading of frames, sometimes is faster, but requires additional modules 90 | 91 | --no-cuda: no cuda should be used (might break sometimes) 92 | 93 | --lowmem, --low-memory: attempt to make code available for people with low VRAM, might result in lower quality 94 | 95 | --batch: enables batch mode, after it provide a suffix, for example --batch="_test.mp4" will result in output %target%_test.mp4 96 | 97 | --extract-output-frames: extract frames from output video. After argument write the path to folder. 
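for example (the file names and the `frames_out` folder below are only placeholders — as described above, the folder path simply follows the flag):

``` python main.py -f face.jpg -t video.mp4 -o output.mp4 --extract-output-frames frames_out ```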
98 | 99 | --codeformer-fidelity: argument works with cli, sets up codeformer's fidelity 100 | 101 | --blend: argument works with cli, blending amount from 0.0 to 1.0 102 | 103 | --codeformer-skip_if_no_face: argument works with cli, Skip codeformer if no face found 104 | 105 | --codeformer-face-upscale: argument works with cli, Upscale the face using codeformer 106 | 107 | --codeformer-background-enhance:argument works with cli, Enhance the background using codeformer 108 | 109 | --codeformer-upscale: argument works with cli, the amount of upscale to apply to the frame using codeformer 110 | 111 | --select-face: change the face you want, not all faces. After the argument add the path to the image with face from the video. (Just open video in video player, screenshot the frame and save it to file. Put this filename after --select-face argument) 112 | 113 | --optimization: choose the mode of the model: fp32 (default), fp16 (smaller, might be faster), int8 (doesnt work properly on old gpus, I dont know about new once, please test. On old gpus it uses cpu) 114 | 115 | --fast-load: try to load as fast as possible, might break something sometimes 116 | 117 | --bbox-adjust: adjustments to do for the box around the face: x1,y1 coords of left top corner and x2,y2 are bottom right. Give in the form x1xy1xx2xy2 (default: 50x50x50x50). Just try to play to understand 118 | 119 | -vcam, --virtual-camera: allows to use OBS virtual camera as output source. Please install obs to make sure it works 120 | 121 | example: 122 | ``` python main.py -f test.jpg -t "C:/Users/user/Desktop/video.mp4" -o output/test.mp4 --threads 12 ``` 123 | 124 | 125 | fast enhancer is still in development, color correction is needed! Sorry for inconvenience, still training the model. 126 | 127 | # ABOUT MULTIGPU MODE 128 | 129 | To choose the gpu you want to run on: in globalsz.py, on the line with `select_gpu = None` you can make it `select_gpu = [0, 1]` or something similar (these numbers are id of gpus, starting from 0). 130 | 131 | to use all gpus, `select_gpu = None` 132 | 133 | Multigpu mode for now only supports just face swapping, **without the enhancer**!!! So if you want enhancer to work, for now select only one gpu. 134 | 135 | # please read at least TLDR 136 | 137 | TL;DR. This tool was created just to make fun to remake memes, put yourself in the movies and other fun things. Some people on the other hand are doing some nasty things using this software, which is not intended way to use this software. Please be a good person, and don’t do harm to other people. Do not hold my liable for anything. 138 | This tool is provided for experimental and creative purposes only. It allows users to generate and manipulate multimedia content using deep learning technology. Users are cautioned that the tool's output, particularly deepfake content, can have ethical and legal implications. 139 | TL;DR ended ==== 140 | 141 | 142 | Educational and Ethical Use: Users are encouraged to use this tool in a responsible and ethical manner. It should primarily serve educational and artistic purposes, avoiding any malicious or misleading activities that could harm individuals or deceive the public. 143 | 144 | Informed Consent: If the tool is used to create content involving real individuals, ensure that you have obtained explicit and informed consent from those individuals to use their likeness. Using someone's image without permission can infringe upon their privacy and rights. 
145 | 146 | Transparency: If you decide to share or publish content created with this tool, it is important to clearly indicate that the content is generated using deep learning technology. Transparency helps prevent misunderstandings and misinformation. 147 | 148 | Legal Considerations: Users are responsible for complying with all applicable laws and regulations related to content creation and sharing. Unauthorized use of copyrighted materials, defamation, and invasion of privacy could lead to legal consequences. 149 | 150 | Social Responsibility: Please consider the potential social impact of the content you create. Misuse of this tool could contribute to the spread of misinformation, deepening distrust, and undermining the credibility of authentic media. 151 | 152 | No Warranty: This tool is provided "as is," without any warranties or guarantees of any kind, either expressed or implied. The developers of this tool are not liable for any direct, indirect, incidental, special, or consequential damages arising from the use of the tool. 153 | 154 | Feedback and Improvement: We encourage users to provide feedback on their experiences with the tool. Your insights can contribute to refining the technology and addressing potential concerns. 155 | 156 | By using this tool, you acknowledge that you have read and understood this disclaimer. You agree to use the tool responsibly and in accordance with all applicable laws and ethical standards. The developers of this tool retain the right to modify, suspend, or terminate access to the tool at their discretion. 157 | 158 | -------------------------------------------------------------------------------- /benchmark.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple 2 | import subprocess 3 | import time 4 | from tabulate import tabulate 5 | 6 | 7 | def benchmark(command: str) -> Tuple[str, float]: 8 | start_time = time.time() 9 | process = subprocess.run(command, shell=True, text=True) 10 | end_time = time.time() 11 | process_time = end_time - start_time 12 | return process.stdout, process_time 13 | 14 | 15 | if __name__ == '__main__': 16 | commands = [ 17 | #'python main.py -f benchmark/source.jpg -t benchmark/target-240p.mp4 -o benchmark1.mp4 --cli --threads 12 --optimization fp16 --lowmem --fast-load', 18 | #'python main.py -f benchmark/source.jpg -t benchmark/target-360p.mp4 -o benchmark2.mp4 --cli --threads 12 --optimization fp32 --lowmem --fast-load', 19 | #'python main.py -f benchmark/source.jpg -t benchmark/target-540p.mp4 -o benchmark3.mp4 --cli --threads 12 --optimization fp32 --lowmem --fast-load', 20 | #'python main.py -f benchmark/source.jpg -t benchmark/target-720p.mp4 -o benchmark4.mp4 --cli --threads 12 --optimization fp32 --lowmem --fast-load', 21 | 'python main.py -f benchmark/source.jpg -t benchmark/target-1080p.mp4 -o benchmark5.mp4 --cli --threads 12 --optimization fp16 --lowmem --fast-load', 22 | 'python main.py -f benchmark/source.jpg -t benchmark/target-1080p.mp4 -o benchmark5.mp4 --cli --threads 12 --optimization fp16 --lowmem --fast-load', 23 | 'python main.py -f benchmark/source.jpg -t benchmark/target-1080p.mp4 -o benchmark5.mp4 --cli --threads 12 --optimization fp16 --lowmem --fast-load', 24 | #'python main.py -f benchmark/source.jpg -t benchmark/target-1440p.mp4 -o benchmark6.mp4 --cli --threads 12 --optimization fp32 --lowmem --fast-load', 25 | #'python main.py -f benchmark/source.jpg -t benchmark/target-2160p.mp4 -o benchmark7.mp4 --cli --threads 12 
--optimization fp32 --lowmem --fast-load' 26 | ] 27 | 28 | results = [] 29 | for command in commands: 30 | output, execution_time = benchmark(command) 31 | results.append([command, f'{execution_time:.2f} seconds']) 32 | 33 | print(tabulate(results, headers=['Command', 'Execution Time'])) 34 | -------------------------------------------------------------------------------- /beta/README.md: -------------------------------------------------------------------------------- 1 | # This is beta version, if you use don't say "omg, it's not working!". It's not suppose to, it's developing 2 | # To run: copy to main folder and 'python better.py' 3 | # pip install libraries that cause import error 4 | 5 | don't expect any guides for this version, but you can ask questions and give me your suggestions in discord 6 | -------------------------------------------------------------------------------- /beta/change_model_batch.py: -------------------------------------------------------------------------------- 1 | import onnx 2 | import os 3 | import struct 4 | 5 | from argparse import ArgumentParser 6 | 7 | 8 | def rebatch(infile, outfile, batch_size): 9 | model = onnx.load(infile) 10 | graph = model.graph 11 | 12 | # Change batch size in input, output and value_info 13 | for tensor in list(graph.input) + list(graph.value_info) + list(graph.output): 14 | try: 15 | tensor.type.tensor_type.shape.dim[0].dim_param = batch_size 16 | except Exception as e: 17 | print(e) 18 | 19 | # Set dynamic batch size in reshapes (-1) 20 | for node in graph.node: 21 | if node.op_type != 'Reshape': 22 | continue 23 | for init in graph.initializer: 24 | # node.input[1] is expected to be a reshape 25 | if init.name != node.input[1]: 26 | continue 27 | # Shape is stored as a list of ints 28 | if len(init.int64_data) > 0: 29 | # This overwrites bias nodes' reshape shape but should be fine 30 | init.int64_data[0] = -1 31 | # Shape is stored as bytes 32 | elif len(init.raw_data) > 0: 33 | shape = bytearray(init.raw_data) 34 | struct.pack_into('q', shape, 0, -1) 35 | init.raw_data = bytes(shape) 36 | 37 | onnx.save(model, outfile) 38 | 39 | if __name__ == '__main__': 40 | parser = ArgumentParser('Replace batch size with \'N\'') 41 | parser.add_argument('infile') 42 | parser.add_argument('outfile') 43 | args = parser.parse_args() 44 | 45 | rebatch(args.infile, args.outfile, 'N') -------------------------------------------------------------------------------- /beta/faceswap_experiment - Copy.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import onnxruntime as rt 3 | import insightface 4 | import tqdm 5 | import threading 6 | from threading import Thread 7 | import os 8 | import numpy as np 9 | from numpy.linalg import norm as l2norm 10 | import torch 11 | from swapperfp16 import get_model 12 | threads_per_gpu = 12 13 | if torch.cuda.is_available(): 14 | num_devices = torch.cuda.device_count() 15 | print(f"Number of available CUDA devices: {num_devices}") 16 | 17 | for i in range(num_devices): 18 | device_name = torch.cuda.get_device_name(i) 19 | print(f"Device {i}: {device_name}") 20 | 21 | os.environ['OMP_NUM_THREADS'] = '1' 22 | class ThreadWithReturnValue(Thread): 23 | def __init__(self, group=None, target=None, name=None, 24 | args=(), kwargs={}, Verbose=None): 25 | Thread.__init__(self, group, target, name, args, kwargs) 26 | self._return = None 27 | def run(self): 28 | if self._target is not None: 29 | self._return = self._target(*self._args, 30 | **self._kwargs) 31 | def 
join(self, *args): 32 | Thread.join(self, *args) 33 | return self._return 34 | cap = cv2.VideoCapture("banana.mp4") 35 | swappers = [] 36 | analysers = [] 37 | #providers = rt.get_available_providers() 38 | for idx in range(num_devices): 39 | providers = [ 40 | ('CUDAExecutionProvider', { 41 | 'device_id': idx, 42 | 'gpu_mem_limit': 8 * 1024 * 1024 * 1024, 43 | 'gpu_external_alloc': 0, 44 | 'gpu_external_free': 0, 45 | 'gpu_external_empty_cache': 1, 46 | 'cudnn_conv_algo_search': 'EXHAUSTIVE', 47 | 'cudnn_conv1d_pad_to_nc1d': 1, 48 | 'arena_extend_strategy': 'kNextPowerOfTwo', 49 | 'do_copy_in_default_stream': 1, 50 | 'enable_cuda_graph': 0, 51 | 'cudnn_conv_use_max_workspace': 1, 52 | 'tunable_op_enable': 1, 53 | 'enable_skip_layer_norm_strict_mode': 1, 54 | 'tunable_op_tuning_enable': 1 55 | }), 56 | 'CPUExecutionProvider', 57 | ] 58 | sess_options = rt.SessionOptions() 59 | sess_options.intra_op_num_threads = 1 60 | sess_options.graph_optimization_level = rt.GraphOptimizationLevel.ORT_ENABLE_ALL#rt.GraphOptimizationLevel.ORT_DISABLE_ALL 61 | sess_options.execution_mode = rt.ExecutionMode.ORT_SEQUENTIAL 62 | sess_options.execution_order = rt.ExecutionOrder.PRIORITY_BASED 63 | #swappers.append(insightface.model_zoo.get_model("inswapper_128.onnx", session_options=sess_options, providers=providers)) 64 | swappers.append(get_model("inswapper_128.fp16.onnx", session_options=sess_options, providers=providers)) 65 | analysers.append(insightface.app.FaceAnalysis(name='buffalo_l', providers=providers, session_options=sess_options)) 66 | analysers[idx].prepare(ctx_id=0, det_size=(256, 256)) 67 | '''#providers = rt.get_available_providers() 68 | providers = [ 69 | ('CUDAExecutionProvider', { 70 | 'device_id': idx, 71 | }), 72 | #'CPUExecutionProvider', 73 | ]''' 74 | 75 | input_face = cv2.imread("face.jpg") 76 | source_face = sorted(analysers[0].get(input_face), key=lambda x: x.bbox[0])[0] 77 | def process(frame, sw): 78 | faces = analysers[sw].get(frame) 79 | bboxes = [] 80 | for face in faces: 81 | bboxes.append(face.bbox) 82 | frame = swappers[sw].get(frame, face, source_face, paste_back=True) 83 | return frame 84 | temp = [] 85 | pbar = tqdm.tqdm() 86 | current_frame = 0 87 | while True: 88 | ret, frame = cap.read() 89 | if len(temp) < threads_per_gpu*len(swappers) and ret: 90 | t = ThreadWithReturnValue(target=process, args=(frame,current_frame%len(swappers))) 91 | t.start() 92 | temp.append(t) 93 | continue 94 | if len(temp) >= threads_per_gpu*len(swappers) or not ret: 95 | frame = temp.pop(0).join() 96 | #cv2.imshow('Camera', frame) 97 | pbar.update(1) 98 | current_frame += 1 99 | if not ret and len(temp) == 0: 100 | break 101 | if cv2.waitKey(1) & 0xFF == ord('q'): 102 | break 103 | 104 | cap.release() 105 | cv2.destroyAllWindows() -------------------------------------------------------------------------------- /beta/faceswap_experiment.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import onnxruntime as rt 3 | import insightface 4 | import tqdm 5 | import threading 6 | from threading import Thread 7 | import os 8 | import numpy as np 9 | from numpy.linalg import norm as l2norm 10 | import torch 11 | from swapperfp16 import get_model 12 | threads_per_gpu = 1 13 | if torch.cuda.is_available(): 14 | num_devices = torch.cuda.device_count() 15 | print(f"Number of available CUDA devices: {num_devices}") 16 | 17 | for i in range(num_devices): 18 | device_name = torch.cuda.get_device_name(i) 19 | print(f"Device {i}: {device_name}") 20 | 21 | 
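# Limit OpenMP to a single thread per process; per the note in beta/globalsx.py this
# "sometimes speeds up things" when many worker threads are running at once.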
os.environ['OMP_NUM_THREADS'] = '1' 22 | class ThreadWithReturnValue(Thread): 23 | def __init__(self, group=None, target=None, name=None, 24 | args=(), kwargs={}, Verbose=None): 25 | Thread.__init__(self, group, target, name, args, kwargs) 26 | self._return = None 27 | def run(self): 28 | if self._target is not None: 29 | self._return = self._target(*self._args, 30 | **self._kwargs) 31 | def join(self, *args): 32 | Thread.join(self, *args) 33 | return self._return 34 | cap = cv2.VideoCapture("banana.mp4") 35 | swappers = [] 36 | analysers = [] 37 | #providers = rt.get_available_providers() 38 | for idx in range(1): 39 | providers = [ 40 | ('CUDAExecutionProvider', { 41 | 'device_id': idx, 42 | 'gpu_mem_limit': 8 * 1024 * 1024 * 1024, 43 | 'gpu_external_alloc': 0, 44 | 'gpu_external_free': 0, 45 | 'gpu_external_empty_cache': 1, 46 | 'cudnn_conv_algo_search': 'EXHAUSTIVE', 47 | 'cudnn_conv1d_pad_to_nc1d': 1, 48 | 'arena_extend_strategy': 'kNextPowerOfTwo', 49 | 'do_copy_in_default_stream': 1, 50 | 'enable_cuda_graph': 0, 51 | 'cudnn_conv_use_max_workspace': 1, 52 | 'tunable_op_enable': 1, 53 | 'enable_skip_layer_norm_strict_mode': 1, 54 | 'tunable_op_tuning_enable': 1 55 | }), 56 | 'CPUExecutionProvider', 57 | ] 58 | sess_options = rt.SessionOptions() 59 | sess_options.intra_op_num_threads = 1 60 | sess_options.graph_optimization_level = rt.GraphOptimizationLevel.ORT_ENABLE_ALL#rt.GraphOptimizationLevel.ORT_DISABLE_ALL 61 | sess_options.execution_mode = rt.ExecutionMode.ORT_SEQUENTIAL 62 | sess_options.execution_order = rt.ExecutionOrder.PRIORITY_BASED 63 | #swappers.append(insightface.model_zoo.get_model("inswapper_128.onnx", session_options=sess_options, providers=providers)) 64 | swappers.append(get_model("inswapper_128.fp16_batch.onnx", session_options=sess_options, providers=providers)) 65 | analysers.append(insightface.app.FaceAnalysis(name='buffalo_l', providers=providers, session_options=sess_options)) 66 | analysers[idx].prepare(ctx_id=0, det_size=(256, 256)) 67 | '''#providers = rt.get_available_providers() 68 | providers = [ 69 | ('CUDAExecutionProvider', { 70 | 'device_id': idx, 71 | }), 72 | #'CPUExecutionProvider', 73 | ]''' 74 | 75 | input_face = cv2.imread("face.jpg") 76 | source_face = sorted(analysers[0].get(input_face), key=lambda x: x.bbox[0])[0] 77 | '''def process(frame, sw): 78 | faces = analysers[sw].get(frame) 79 | bboxes = [] 80 | for face in faces: 81 | bboxes.append(face.bbox) 82 | frame = swappers[sw].get(frame, face, source_face, paste_back=True) 83 | return frame''' 84 | def process(frames, sw): 85 | # Initialize a list to hold processed frames 86 | processed_frames = [] 87 | 88 | # Initialize lists to hold frames with detected faces and their corresponding faces 89 | frames_with_faces = [] 90 | all_faces = [] 91 | 92 | # Detect faces in all frames 93 | for frame in frames: 94 | faces = analysers[sw].get(frame) 95 | if faces: # If faces are detected in the frame 96 | frames_with_faces.append(frame) 97 | all_faces.append(faces) 98 | else: # If no faces are detected, add the original frame to processed_frames 99 | processed_frames.append(frame) 100 | 101 | if frames_with_faces: # If there are frames with detected faces 102 | # Flatten the list of faces 103 | all_faces = [face for faces in all_faces for face in faces] 104 | 105 | # Create a list of source faces with the same length as all_faces 106 | all_source_faces = [source_face] * len(all_faces) 107 | 108 | # Call the swapper on all frames with faces and all_faces 109 | swapped_frames = 
swappers[sw].get(frames_with_faces, all_faces, all_source_faces, paste_back=True) 110 | 111 | # Append swapped frames to the processed_frames list 112 | processed_frames.extend(swapped_frames) 113 | 114 | return processed_frames 115 | temp = [] 116 | pbar = tqdm.tqdm() 117 | current_frame = 0 118 | while True: 119 | frames = [] 120 | for _ in range(4): # read 4 frames 121 | ret, frame = cap.read() 122 | if ret: 123 | frames.append(frame) 124 | if frames and len(temp) < threads_per_gpu*len(swappers): 125 | t = ThreadWithReturnValue(target=process, args=(frames, current_frame%len(swappers))) 126 | t.start() 127 | temp.append(t) 128 | if len(temp) >= threads_per_gpu*len(swappers) or not frames: 129 | processed_frames = temp.pop(0).join() 130 | for i, frame in enumerate(processed_frames): 131 | cv2.imshow(f'Camera {i}', frame) 132 | pbar.update(4) 133 | current_frame += 4 134 | if not frames and len(temp) == 0: 135 | break 136 | if cv2.waitKey(1) & 0xFF == ord('q'): 137 | break 138 | 139 | cap.release() 140 | cv2.destroyAllWindows() 141 | -------------------------------------------------------------------------------- /beta/globalsx.py: -------------------------------------------------------------------------------- 1 | #just a file with globals so I can make everything better 2 | import os 3 | os.environ['OMP_NUM_THREADS'] = '1' #sometimes speeds up things 4 | args = {} 5 | swapper = None 6 | swapper_enabled = True 7 | #[face_embeddings, chosen_face] 8 | to_swap = [] 9 | source_face = None #if all_faces is not enabled, it wouldn't be used 10 | current_video = 0 11 | this_frame = None 12 | frame_move = 0 13 | render_queue = [] -------------------------------------------------------------------------------- /beta/prune.py: -------------------------------------------------------------------------------- 1 | import onnx 2 | from onnxruntime.quantization import quantize_dynamic, QuantType 3 | 4 | model_fp32 = 'inswapper_128.onnx' 5 | model_quant = 'inswapper_128.quant.onnx' 6 | #quantized_model = quantize_dynamic(model_fp32, model_quant, weight_type=QuantType.QUInt8) 7 | import onnx 8 | import onnxruntime as ort 9 | from onnxruntime.quantization import quantize_dynamic 10 | 11 | # Load the original ONNX model 12 | original_model_path = 'inswapper_128.onnx' 13 | quantized_model_path = "inswapper_128.quant.onnx" 14 | 15 | # Load the model with ONNX 16 | model = onnx.load(original_model_path) 17 | 18 | # Get the last initializer 19 | last_initializer = model.graph.initializer[-1] 20 | 21 | # Remove the last initializer from the graph 22 | model.graph.initializer.pop() 23 | 24 | # Save the modified model to a temporary file 25 | temp_model_path = "temp_model.onnx" 26 | onnx.save(model, temp_model_path) 27 | 28 | # Quantize the model 29 | quantize_dynamic(temp_model_path, quantized_model_path, per_channel=True, weight_type=QuantType.QUInt8) 30 | 31 | # Add the last initializer back to the quantized model 32 | quantized_model = onnx.load(quantized_model_path) 33 | quantized_model.graph.initializer.append(last_initializer) 34 | 35 | # Save the final quantized model 36 | onnx.save(quantized_model, quantized_model_path) 37 | 38 | # Optionally, you can remove the temporary file 39 | import os 40 | os.remove(temp_model_path) -------------------------------------------------------------------------------- /beta/swapperfp16 - Copy.py: -------------------------------------------------------------------------------- 1 | import time 2 | import numpy as np 3 | import onnxruntime 4 | import cv2 5 | import onnx 6 | 
from onnx import numpy_helper 7 | from insightface.utils import face_align 8 | from numpy.linalg import norm as l2norm 9 | 10 | 11 | 12 | class INSwapper(): 13 | def __init__(self, model_file=None, session=None): 14 | self.model_file = model_file 15 | self.session = session 16 | model = onnx.load(self.model_file) 17 | graph = model.graph 18 | self.emap = numpy_helper.to_array(graph.initializer[-1]) 19 | self.input_mean = 0.0 20 | self.input_std = 255.0 21 | #print('input mean and std:', model_file, self.input_mean, self.input_std) 22 | if self.session is None: 23 | self.session = onnxruntime.InferenceSession(self.model_file, None) 24 | inputs = self.session.get_inputs() 25 | self.input_names = [] 26 | for inp in inputs: 27 | self.input_names.append(inp.name) 28 | outputs = self.session.get_outputs() 29 | output_names = [] 30 | for out in outputs: 31 | output_names.append(out.name) 32 | self.output_names = output_names 33 | assert len(self.output_names)==1 34 | output_shape = outputs[0].shape 35 | input_cfg = inputs[0] 36 | input_shape = input_cfg.shape 37 | self.input_shape = input_shape 38 | print('inswapper-shape:', self.input_shape) 39 | self.input_size = tuple(input_shape[2:4][::-1]) 40 | 41 | def forward(self, img, latent): 42 | img = (img - self.input_mean) / self.input_std 43 | pred = self.session.run(self.output_names, {self.input_names[0]: img, self.input_names[1]: latent})[0] 44 | return pred 45 | 46 | def get(self, img, target_face, source_face, paste_back=True): 47 | aimg, M = face_align.norm_crop2(img, target_face.kps, self.input_size[0]) 48 | blob = cv2.dnn.blobFromImage(aimg, 1.0 / self.input_std, self.input_size, 49 | (self.input_mean, self.input_mean, self.input_mean), swapRB=True) 50 | s_e = source_face.normed_embedding 51 | n_e = s_e / l2norm(s_e) 52 | latent = n_e.reshape((1,-1)) 53 | 54 | latent = np.dot(latent, self.emap) 55 | latent /= np.linalg.norm(latent) 56 | pred = self.session.run(self.output_names, {self.input_names[0]: blob, self.input_names[1]: latent})[0] 57 | #print(latent.shape, latent.dtype, pred.shape) 58 | img_fake = pred.transpose((0,2,3,1))[0] 59 | bgr_fake = np.clip(255 * img_fake, 0, 255).astype(np.uint8)[:,:,::-1] 60 | if not paste_back: 61 | return bgr_fake, M 62 | else: 63 | target_img = img 64 | fake_diff = bgr_fake.astype(np.float32) - aimg.astype(np.float32) 65 | fake_diff = np.abs(fake_diff).mean(axis=2) 66 | fake_diff[:2,:] = 0 67 | fake_diff[-2:,:] = 0 68 | fake_diff[:,:2] = 0 69 | fake_diff[:,-2:] = 0 70 | IM = cv2.invertAffineTransform(M) 71 | img_mask = np.full((aimg.shape[0],aimg.shape[1]), 255, dtype=np.float32) 72 | bgr_fake = cv2.warpAffine(bgr_fake, IM, (target_img.shape[1], target_img.shape[0]), borderValue=0.0) 73 | img_mask = cv2.warpAffine(img_mask, IM, (target_img.shape[1], target_img.shape[0]), borderValue=0.0) 74 | fake_diff = cv2.warpAffine(fake_diff, IM, (target_img.shape[1], target_img.shape[0]), borderValue=0.0) 75 | img_mask[img_mask>20] = 255 76 | fthresh = 10 77 | fake_diff[fake_diff=fthresh] = 255 79 | #img_mask = img_white 80 | mask_h_inds, mask_w_inds = np.where(img_mask==255) 81 | mask_h = np.max(mask_h_inds) - np.min(mask_h_inds) 82 | mask_w = np.max(mask_w_inds) - np.min(mask_w_inds) 83 | mask_size = int(np.sqrt(mask_h*mask_w)) 84 | k = max(mask_size//10, 10) 85 | #k = max(mask_size//20, 6) 86 | #k = 6 87 | kernel = np.ones((k,k),np.uint8) 88 | img_mask = cv2.erode(img_mask,kernel,iterations = 1) 89 | kernel = np.ones((2,2),np.uint8) 90 | fake_diff = cv2.dilate(fake_diff,kernel,iterations = 1) 91 | k = 
max(mask_size//20, 5) 92 | #k = 3 93 | #k = 3 94 | kernel_size = (k, k) 95 | blur_size = tuple(2*i+1 for i in kernel_size) 96 | img_mask = cv2.GaussianBlur(img_mask, blur_size, 0) 97 | k = 5 98 | kernel_size = (k, k) 99 | blur_size = tuple(2*i+1 for i in kernel_size) 100 | fake_diff = cv2.GaussianBlur(fake_diff, blur_size, 0) 101 | img_mask /= 255 102 | fake_diff /= 255 103 | #img_mask = fake_diff 104 | img_mask = np.reshape(img_mask, [img_mask.shape[0],img_mask.shape[1],1]) 105 | fake_merged = img_mask * bgr_fake + (1-img_mask) * target_img.astype(np.float32) 106 | fake_merged = fake_merged.astype(np.uint8) 107 | return fake_merged 108 | class PickableInferenceSession(onnxruntime.InferenceSession): 109 | # This is a wrapper to make the current InferenceSession class pickable. 110 | def __init__(self, model_path, **kwargs): 111 | super().__init__(model_path, **kwargs) 112 | self.model_path = model_path 113 | 114 | def __getstate__(self): 115 | return {'model_path': self.model_path} 116 | 117 | def __setstate__(self, values): 118 | model_path = values['model_path'] 119 | self.__init__(model_path) 120 | 121 | class ModelRouter: 122 | def __init__(self, onnx_file): 123 | self.onnx_file = onnx_file 124 | 125 | def get_model(self, **kwargs): 126 | session = PickableInferenceSession(self.onnx_file, **kwargs) 127 | print(f'Applied providers: {session._providers}, with options: {session._provider_options}') 128 | inputs = session.get_inputs() 129 | input_cfg = inputs[0] 130 | input_shape = input_cfg.shape 131 | outputs = session.get_outputs() 132 | return INSwapper(model_file=self.onnx_file, session=session) 133 | 134 | def get_default_providers(): 135 | return ['CUDAExecutionProvider', 'CPUExecutionProvider'] 136 | 137 | def get_default_provider_options(): 138 | return None 139 | def get_model(name, **kwargs): 140 | router = ModelRouter(name) 141 | providers = kwargs.get('providers', get_default_providers()) 142 | provider_options = kwargs.get('provider_options', get_default_provider_options()) 143 | #session_options = kwargs.get('session_options', None) 144 | model = router.get_model(providers=providers, provider_options=provider_options)#, session_options = session_options) 145 | return model -------------------------------------------------------------------------------- /beta/swapperfp16.py: -------------------------------------------------------------------------------- 1 | import time 2 | import numpy as np 3 | import onnxruntime 4 | import cv2 5 | import onnx 6 | from onnx import numpy_helper 7 | from insightface.utils import face_align 8 | from numpy.linalg import norm as l2norm 9 | 10 | 11 | 12 | class INSwapper(): 13 | def __init__(self, model_file=None, session=None): 14 | self.model_file = model_file 15 | self.session = session 16 | model = onnx.load(self.model_file) 17 | graph = model.graph 18 | self.emap = numpy_helper.to_array(graph.initializer[-1]) 19 | self.input_mean = 0.0 20 | self.input_std = 255.0 21 | #print('input mean and std:', model_file, self.input_mean, self.input_std) 22 | if self.session is None: 23 | self.session = onnxruntime.InferenceSession(self.model_file, None) 24 | inputs = self.session.get_inputs() 25 | self.input_names = [] 26 | for inp in inputs: 27 | self.input_names.append(inp.name) 28 | outputs = self.session.get_outputs() 29 | output_names = [] 30 | for out in outputs: 31 | output_names.append(out.name) 32 | self.output_names = output_names 33 | assert len(self.output_names)==1 34 | output_shape = outputs[0].shape 35 | input_cfg = inputs[0] 36 | 
input_shape = input_cfg.shape 37 | self.input_shape = input_shape 38 | print('inswapper-shape:', self.input_shape) 39 | self.input_size = tuple(input_shape[2:4][::-1]) 40 | 41 | def forward(self, img, latent): 42 | img = (img - self.input_mean) / self.input_std 43 | pred = self.session.run(self.output_names, {self.input_names[0]: img, self.input_names[1]: latent})[0] 44 | return pred 45 | 46 | def get(self, imgs, target_faces, source_faces, paste_back=True): 47 | assert len(imgs) == len(target_faces) == len(source_faces), "The number of images, target faces, and source faces must be the same." 48 | # Initialize lists to hold results 49 | bgr_fakes = [] 50 | fake_mergeds = [] 51 | 52 | blobs = [] 53 | latents = [] 54 | for img, target_face, source_face in zip(imgs, target_faces, source_faces): 55 | aimg, M = face_align.norm_crop2(img, target_face.kps, self.input_size[0]) 56 | blob = cv2.dnn.blobFromImage(aimg, 1.0 / self.input_std, self.input_size, 57 | (self.input_mean, self.input_mean, self.input_mean), swapRB=True) 58 | s_e = source_face.normed_embedding 59 | n_e = s_e / l2norm(s_e) 60 | latent = n_e.reshape((1,-1)) 61 | 62 | latent = np.dot(latent, self.emap) 63 | latent /= np.linalg.norm(latent) 64 | 65 | blobs.append(blob) 66 | latents.append(latent) 67 | blobs = np.concatenate(blobs, axis=0) 68 | latents = np.concatenate(latents, axis=0) 69 | preds = self.session.run(self.output_names, {self.input_names[0]: blobs, self.input_names[1]: latents})[0] 70 | #print(latent.shape, latent.dtype, pred.shape) 71 | for i in range(len(imgs)): 72 | img_fake = preds[i].transpose((1,2,0)) 73 | bgr_fake = np.clip(255 * img_fake, 0, 255).astype(np.uint8)[:,:,::-1] 74 | bgr_fakes.append(bgr_fake) 75 | if not paste_back: 76 | continue 77 | #return bgr_fake, M 78 | else: 79 | target_img = img 80 | fake_diff = bgr_fake.astype(np.float32) - aimg.astype(np.float32) 81 | fake_diff = np.abs(fake_diff).mean(axis=2) 82 | fake_diff[:2,:] = 0 83 | fake_diff[-2:,:] = 0 84 | fake_diff[:,:2] = 0 85 | fake_diff[:,-2:] = 0 86 | IM = cv2.invertAffineTransform(M) 87 | img_mask = np.full((aimg.shape[0],aimg.shape[1]), 255, dtype=np.float32) 88 | bgr_fake = cv2.warpAffine(bgr_fake, IM, (target_img.shape[1], target_img.shape[0]), borderValue=0.0) 89 | img_mask = cv2.warpAffine(img_mask, IM, (target_img.shape[1], target_img.shape[0]), borderValue=0.0) 90 | fake_diff = cv2.warpAffine(fake_diff, IM, (target_img.shape[1], target_img.shape[0]), borderValue=0.0) 91 | img_mask[img_mask>20] = 255 92 | fthresh = 10 93 | fake_diff[fake_diff=fthresh] = 255 95 | #img_mask = img_white 96 | mask_h_inds, mask_w_inds = np.where(img_mask==255) 97 | mask_h = np.max(mask_h_inds) - np.min(mask_h_inds) 98 | mask_w = np.max(mask_w_inds) - np.min(mask_w_inds) 99 | mask_size = int(np.sqrt(mask_h*mask_w)) 100 | k = max(mask_size//10, 10) 101 | #k = max(mask_size//20, 6) 102 | #k = 6 103 | kernel = np.ones((k,k),np.uint8) 104 | img_mask = cv2.erode(img_mask,kernel,iterations = 1) 105 | kernel = np.ones((2,2),np.uint8) 106 | fake_diff = cv2.dilate(fake_diff,kernel,iterations = 1) 107 | k = max(mask_size//20, 5) 108 | #k = 3 109 | #k = 3 110 | kernel_size = (k, k) 111 | blur_size = tuple(2*i+1 for i in kernel_size) 112 | img_mask = cv2.GaussianBlur(img_mask, blur_size, 0) 113 | k = 5 114 | kernel_size = (k, k) 115 | blur_size = tuple(2*i+1 for i in kernel_size) 116 | fake_diff = cv2.GaussianBlur(fake_diff, blur_size, 0) 117 | img_mask /= 255 118 | fake_diff /= 255 119 | #img_mask = fake_diff 120 | img_mask = np.reshape(img_mask, 
[img_mask.shape[0],img_mask.shape[1],1]) 121 | fake_merged = img_mask * bgr_fake + (1-img_mask) * imgs[i].astype(np.float32) 122 | fake_merged = fake_merged.astype(np.uint8) 123 | fake_mergeds.append(fake_merged) 124 | #return fake_merged 125 | if paste_back: 126 | return bgr_fakes, fake_mergeds 127 | else: 128 | return bgr_fakes 129 | class PickableInferenceSession(onnxruntime.InferenceSession): 130 | # This is a wrapper to make the current InferenceSession class pickable. 131 | def __init__(self, model_path, **kwargs): 132 | super().__init__(model_path, **kwargs) 133 | self.model_path = model_path 134 | 135 | def __getstate__(self): 136 | return {'model_path': self.model_path} 137 | 138 | def __setstate__(self, values): 139 | model_path = values['model_path'] 140 | self.__init__(model_path) 141 | 142 | class ModelRouter: 143 | def __init__(self, onnx_file): 144 | self.onnx_file = onnx_file 145 | 146 | def get_model(self, **kwargs): 147 | session = PickableInferenceSession(self.onnx_file, **kwargs) 148 | print(f'Applied providers: {session._providers}, with options: {session._provider_options}') 149 | inputs = session.get_inputs() 150 | input_cfg = inputs[0] 151 | input_shape = input_cfg.shape 152 | outputs = session.get_outputs() 153 | return INSwapper(model_file=self.onnx_file, session=session) 154 | 155 | def get_default_providers(): 156 | return ['CUDAExecutionProvider', 'CPUExecutionProvider'] 157 | 158 | def get_default_provider_options(): 159 | return None 160 | def get_model(name, **kwargs): 161 | router = ModelRouter(name) 162 | providers = kwargs.get('providers', get_default_providers()) 163 | provider_options = kwargs.get('provider_options', get_default_provider_options()) 164 | #session_options = kwargs.get('session_options', None) 165 | model = router.get_model(providers=providers, provider_options=provider_options)#, session_options = session_options) 166 | return model -------------------------------------------------------------------------------- /beta/utilities.py: -------------------------------------------------------------------------------- 1 | import subprocess 2 | from threading import Thread 3 | import onnxruntime as rt 4 | import insightface 5 | import cv2 6 | import numpy as np 7 | import threading 8 | import queue 9 | import time 10 | from tkinter import messagebox 11 | from PIL import Image 12 | from numpy import asarray 13 | import os 14 | from scipy.spatial import distance 15 | import psutil 16 | import globalsx 17 | from types import NoneType 18 | def is_video_file(filename): 19 | video_extensions = ['.mp4', '.avi', '.mkv', '.mov', '.webm'] # Add more extensions as needed 20 | _, ext = os.path.splitext(filename) 21 | return ext.lower() in video_extensions 22 | def is_picture_file(filename): 23 | image_extensions = ['.jpg', '.jpeg', '.png', '.bmp', '.svg', '.tiff', '.webp'] 24 | _, ext = os.path.splitext(filename) 25 | return ext.lower() in image_extensions 26 | def get_system_usage(): 27 | # Get RAM usage in GB 28 | ram_usage = round(psutil.virtual_memory().used / 1024**3, 1) 29 | 30 | # Get total RAM in GB 31 | total_ram = round(psutil.virtual_memory().total / 1024**3, 1) 32 | 33 | # Get CPU usage in percentage 34 | cpu_usage = round(psutil.cpu_percent(), 0) 35 | return ram_usage, total_ram, cpu_usage 36 | def extract_frames_from_video(target_video, output_folder): 37 | target_video_name = os.path.basename(target_video) 38 | ffmpeg_cmd = [ 39 | 'ffmpeg', 40 | '-i', target_video, 41 | f'{output_folder}/{target_video_name}/frame_%05d.png' 42 | ] 43 | 
subprocess.run(ffmpeg_cmd, check=True) 44 | def add_audio_from_video(video_path, audio_video_path, output_path): 45 | ffmpeg_cmd = [ 46 | 'ffmpeg', 47 | '-i', video_path, 48 | '-i', audio_video_path, 49 | '-c:v', 'copy', 50 | '-map', '0:v:0', 51 | '-map', '1:a:0', 52 | '-shortest', 53 | output_path 54 | ] 55 | subprocess.run(ffmpeg_cmd, check=True) 56 | def merge_face(temp_frame, original, alpha): 57 | temp_frame = Image.blend(Image.fromarray(original), Image.fromarray(temp_frame), alpha) 58 | return asarray(temp_frame) 59 | class GFPGAN_onnxruntime: 60 | def __init__(self, model_path, use_gpu = False): 61 | sess_options = rt.SessionOptions() 62 | sess_options.intra_op_num_threads = 8 63 | providers = rt.get_available_providers() 64 | self.ort_session = rt.InferenceSession(model_path, providers=providers, session_options=sess_options) 65 | self.net_input_name = self.ort_session.get_inputs()[0].name 66 | _,self.net_input_channels,self.net_input_height,self.net_input_width = self.ort_session.get_inputs()[0].shape 67 | self.net_output_count = len(self.ort_session.get_outputs()) 68 | self.face_size = 512 69 | self.face_template = np.array([[192, 240], [319, 240], [257, 371]]) * (self.face_size / 512.0) 70 | self.upscale_factor = 2 71 | self.affine = False 72 | self.affine_matrix = None 73 | def pre_process(self, img): 74 | img = cv2.resize(img, (self.face_size, self.face_size)) 75 | img = img / 255.0 76 | img = img.astype('float32') 77 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) 78 | img[:,:,0] = (img[:,:,0]-0.5)/0.5 79 | img[:,:,1] = (img[:,:,1]-0.5)/0.5 80 | img[:,:,2] = (img[:,:,2]-0.5)/0.5 81 | img = np.float32(img[np.newaxis,:,:,:]) 82 | img = img.transpose(0, 3, 1, 2) 83 | return img 84 | def post_process(self, output, height, width): 85 | output = output.clip(-1,1) 86 | output = (output + 1) / 2 87 | output = output.transpose(1, 2, 0) 88 | output = cv2.cvtColor(output, cv2.COLOR_RGB2BGR) 89 | output = (output * 255.0).round() 90 | if self.affine: 91 | inverse_affine = cv2.invertAffineTransform(self.affine_matrix) 92 | inverse_affine *= self.upscale_factor 93 | if self.upscale_factor > 1: 94 | extra_offset = 0.5 * self.upscale_factor 95 | else: 96 | extra_offset = 0 97 | inverse_affine[:, 2] += extra_offset 98 | inv_restored = cv2.warpAffine(output, inverse_affine, (width, height)) 99 | mask = np.ones((self.face_size, self.face_size), dtype=np.float32) 100 | inv_mask = cv2.warpAffine(mask, inverse_affine, (width, height)) 101 | inv_mask_erosion = cv2.erode( 102 | inv_mask, np.ones((int(2 * self.upscale_factor), int(2 * self.upscale_factor)), np.uint8)) 103 | pasted_face = inv_mask_erosion[:, :, None] * inv_restored 104 | total_face_area = np.sum(inv_mask_erosion) 105 | # compute the fusion edge based on the area of face 106 | w_edge = int(total_face_area**0.5) // 20 107 | erosion_radius = w_edge * 2 108 | inv_mask_center = cv2.erode(inv_mask_erosion, np.ones((erosion_radius, erosion_radius), np.uint8)) 109 | blur_size = w_edge * 2 110 | inv_soft_mask = cv2.GaussianBlur(inv_mask_center, (blur_size + 1, blur_size + 1), 0) 111 | inv_soft_mask = inv_soft_mask[:, :, None] 112 | output = pasted_face 113 | else: 114 | inv_soft_mask = np.ones((height, width, 1), dtype=np.float32) 115 | output = cv2.resize(output, (width, height)) 116 | return output, inv_soft_mask 117 | 118 | def forward(self, img): 119 | height, width = img.shape[0], img.shape[1] 120 | img = self.pre_process(img) 121 | ort_inputs = {self.ort_session.get_inputs()[0].name: img} 122 | ort_outs = self.ort_session.run(None, 
ort_inputs) 123 | output = ort_outs[0][0] 124 | output, inv_soft_mask = self.post_process(output, height, width) 125 | output = output.astype(np.uint8) 126 | return output, inv_soft_mask 127 | def prepare(): 128 | import tensorflow as tf 129 | def mish_activation(x): 130 | return x * tf.keras.activations.tanh(tf.keras.activations.softplus(x)) 131 | class Mish(tf.keras.layers.Layer): 132 | def __init__(self, **kwargs): 133 | super(Mish, self).__init__() 134 | def call(self, inputs): 135 | return mish_activation(inputs) 136 | tf.keras.utils.get_custom_objects().update({'Mish': Mish}) 137 | def add_audio_from_video(video_path, audio_video_path, output_path): 138 | ffmpeg_cmd = [ 139 | 'ffmpeg', 140 | '-i', video_path, 141 | '-i', audio_video_path, 142 | '-c:v', 'copy', 143 | '-map', '0:v:0', 144 | '-map', '1:a:0', 145 | '-shortest', 146 | output_path 147 | ] 148 | subprocess.run(ffmpeg_cmd, check=True) 149 | def get_nth_frame(cap, number): 150 | cap.set(cv2.CAP_PROP_POS_FRAMES, number) 151 | ret, frame = cap.read() 152 | if ret: 153 | return frame 154 | return None 155 | class ThreadWithReturnValue(Thread): 156 | def __init__(self, group=None, target=None, name=None, 157 | args=(), kwargs={}, Verbose=None): 158 | Thread.__init__(self, group, target, name, args, kwargs) 159 | self._return = None 160 | def run(self): 161 | if self._target is not None: 162 | self._return = self._target(*self._args, **self._kwargs) 163 | def join(self, *args): 164 | Thread.join(self, *args) 165 | return self._return 166 | 167 | class VideoCaptureThread: 168 | def __init__(self, video_path, buffer_size): 169 | self.video_path = video_path 170 | self.buffer_size = buffer_size 171 | self.frame_queue = queue.Queue(maxsize=buffer_size) 172 | self.condition = threading.Condition() 173 | self.thread = None 174 | self.frame_counter = 0 175 | self.start_time = 0 176 | self.end_time = 0 177 | self.width = None 178 | self.height = None 179 | self.fps = None 180 | cap = cv2.VideoCapture(self.video_path) 181 | self.fps = cap.get(cv2.CAP_PROP_FPS) 182 | self.width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)) 183 | self.height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) 184 | cap.release() 185 | self.start() 186 | def start(self): 187 | self.thread = threading.Thread(target=self._capture_frames) 188 | self.thread.start() 189 | 190 | def stop(self): 191 | with self.condition: 192 | self.condition.notify_all() # Unblock all threads 193 | if self.thread: 194 | self.thread.join() 195 | 196 | def _capture_frames(self): 197 | cap = cv2.VideoCapture(self.video_path) 198 | fourcc = cv2.VideoWriter_fourcc(*'H265') 199 | cap.set(cv2.CAP_PROP_FOURCC, fourcc) 200 | try: 201 | self.start_time = time.time() 202 | while True: 203 | ret, frame = cap.read() 204 | if not ret: 205 | break 206 | 207 | with self.condition: 208 | # Wait until there is space in the buffer 209 | while self.frame_queue.full(): 210 | self.condition.wait() 211 | 212 | self.frame_queue.put(frame) 213 | self.frame_counter += 1 214 | self.condition.notify_all() # Notify all threads 215 | finally: 216 | self.end_time = time.time() 217 | cap.release() 218 | 219 | with self.condition: 220 | self.frame_queue.put(None) 221 | self.condition.notify_all() # Notify all threads 222 | 223 | def read(self): 224 | with self.condition: 225 | # Wait until there is a frame available in the buffer 226 | while self.frame_queue.empty(): 227 | self.condition.wait() 228 | 229 | frame = self.frame_queue.get() 230 | self.condition.notify_all() # Notify all threads 231 | return frame 232 | 233 | 234 | 
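# prepare_models builds the ONNX Runtime / insightface models used for swapping.
# With custom_return=True it returns a standalone 'buffalo_l' FaceAnalysis instance;
# otherwise it populates globalsx.face_swapper (inswapper_128.onnx, skipped when
# args['no_faceswap'] is set) and globalsx.face_analyser, using a 256x256 detector
# in low-memory mode and 640x640 otherwise when det_size is 'auto'.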
def prepare_models(det_size='auto', prepare='all', custom_return=False): 235 | providers = rt.get_available_providers() 236 | sess_options = rt.SessionOptions() 237 | sess_options.intra_op_num_threads = 8 238 | sess_options2 = rt.SessionOptions() 239 | sess_options2.graph_optimization_level = rt.GraphOptimizationLevel.ORT_DISABLE_ALL #Varying with all the options 240 | sess_options.graph_optimization_level = rt.GraphOptimizationLevel.ORT_DISABLE_ALL #Varying with all the options 241 | if custom_return: 242 | face_analyzer = insightface.app.FaceAnalysis(name='buffalo_l', providers=providers, session_options=sess_options2) 243 | face_analyzer.prepare(ctx_id=0, det_size=det_size) 244 | return face_analyzer 245 | if not globalsx.args['no_faceswap'] and (prepare == 'all' or prepare == 'face_swapper'): 246 | globalsx.face_swapper = insightface.model_zoo.get_model("inswapper_128.onnx", session_options=sess_options, providers=providers) 247 | else: 248 | globalsx.face_swapper = None 249 | if globalsx.args['lowmem']: 250 | if prepare == 'all' or prepare == 'analyser': 251 | globalsx.face_analyser = insightface.app.FaceAnalysis(name='buffalo_l', providers=providers, session_options=sess_options2) 252 | if det_size == 'auto' and (prepare == 'all' or prepare == 'analyser'): 253 | globalsx.face_analyser.prepare(ctx_id=0, det_size=(256, 256)) 254 | else: 255 | globalsx.face_analyser = insightface.app.FaceAnalysis(name='buffalo_l', providers=providers) 256 | if det_size == 'auto' and (prepare == 'all' or prepare == 'analyser'): 257 | globalsx.face_analyser.prepare(ctx_id=0, det_size=(640, 640)) 258 | if det_size != 'auto' and (prepare == 'all' or prepare == 'analyser'): 259 | globalsx.face_analyzer.prepare(ctx_id=0, det_size=det_size) 260 | 261 | #face_analyser.models.pop("landmark_3d_68") 262 | #face_analyser.models.pop("landmark_2d_106") 263 | #face_analyser.models.pop("genderage") 264 | #return face_swapper, face_analyser 265 | 266 | def upscale_image(image, generator ): 267 | image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR) 268 | image = cv2.resize(image, (256, 256)) 269 | image = (image / 255.0) #- 1 270 | image = np.expand_dims(image, axis=0).astype(np.float32) 271 | #output = generator.run(None, {'input': image}) 272 | output = generator(image)#.predict(image, verbose=0) 273 | return cv2.cvtColor((np.squeeze(output, axis=0) * 255.0), cv2.COLOR_BGR2RGB) #np.squeeze(output, axis=0)*255 274 | def show_error(message='', title='Error'): 275 | messagebox.showerror(title, message) 276 | 277 | def show_errorx(): 278 | messagebox.showerror("Error", "Preview mode does not work with camera, so please use normal mode") 279 | def show_warning(): 280 | messagebox.showwarning("Warning", "Camera is not properly working with experimental mode, sorry") 281 | 282 | def compute_cosine_distance(emb1, emb2, allowed_distance): 283 | d = distance.cosine(emb1, emb2) 284 | check = False 285 | if d < allowed_distance: 286 | check = True 287 | return d, check 288 | 289 | def open_cap(file_path=0): 290 | cap = cv2.VideoCapture(file_path) 291 | fourcc = cv2.VideoWriter_fourcc(*'H265') 292 | cap.set(cv2.CAP_PROP_FOURCC, fourcc) 293 | return cap 294 | 295 | def make_swap(frame, face, source_face, paste_back=True): #TODO make fp 16 support 296 | return face_swapper.get(frame, face, source_face, paste_back=paste_back) 297 | def swap_frame(frame): 298 | all_faces = globalsx.face_analyser.get(frame) 299 | if not isinstance(globalsx.swapper, NoneType) and globalsx.swapper_enabled: 300 | if globalsx.swap_all_faces: 301 | for face in 
faces: 302 | frame = make_swap(frame, face, globalsx.source_face) 303 | else: 304 | for face in all_faces: 305 | for check_face_list in globalsx.to_swap: 306 | emb, chosen_face = check_face_list 307 | a = emb.normed_embedding 308 | b = face.normed_embedding 309 | _, allow = compute_cosine_distance(a,b , 0.75) 310 | if not allow: 311 | continue 312 | frame = make_swap(frame, face, chosen_face) 313 | return frame -------------------------------------------------------------------------------- /chain_img_processor/__init__.py: -------------------------------------------------------------------------------- 1 | from .image import ChainImgProcessor, ChainImgPlugin, get_single_image_processor, version 2 | from .video import ChainVideoProcessor, get_single_video_processor 3 | from .ffmpeg_writer import FFMPEG_VideoWriter -------------------------------------------------------------------------------- /chain_img_processor/ffmpeg_writer.py: -------------------------------------------------------------------------------- 1 | """ 2 | FFMPEG_Writer - write set of frames to video file 3 | 4 | original from 5 | https://github.com/Zulko/moviepy/blob/master/moviepy/video/io/ffmpeg_writer.py 6 | 7 | removed unnecessary dependencies 8 | 9 | The MIT License (MIT) 10 | 11 | Copyright (c) 2015 Zulko 12 | Copyright (c) 2023 Janvarev Vladislav 13 | """ 14 | 15 | import os 16 | import subprocess as sp 17 | 18 | PIPE = -1 19 | STDOUT = -2 20 | DEVNULL = -3 21 | 22 | FFMPEG_BINARY = "ffmpeg" 23 | 24 | class FFMPEG_VideoWriter: 25 | """ A class for FFMPEG-based video writing. 26 | 27 | A class to write videos using ffmpeg. ffmpeg will write in a large 28 | choice of formats. 29 | 30 | Parameters 31 | ----------- 32 | 33 | filename 34 | Any filename like 'video.mp4' etc. but if you want to avoid 35 | complications it is recommended to use the generic extension 36 | '.avi' for all your videos. 37 | 38 | size 39 | Size (width,height) of the output video in pixels. 40 | 41 | fps 42 | Frames per second in the output video file. 43 | 44 | codec 45 | FFMPEG codec. It seems that in terms of quality the hierarchy is 46 | 'rawvideo' = 'png' > 'mpeg4' > 'libx264' 47 | 'png' manages the same lossless quality as 'rawvideo' but yields 48 | smaller files. Type ``ffmpeg -codecs`` in a terminal to get a list 49 | of accepted codecs. 50 | 51 | Note for default 'libx264': by default the pixel format yuv420p 52 | is used. If the video dimensions are not both even (e.g. 720x405) 53 | another pixel format is used, and this can cause problem in some 54 | video readers. 55 | 56 | audiofile 57 | Optional: The name of an audio file that will be incorporated 58 | to the video. 59 | 60 | preset 61 | Sets the time that FFMPEG will take to compress the video. The slower, 62 | the better the compression rate. Possibilities are: ultrafast,superfast, 63 | veryfast, faster, fast, medium (default), slow, slower, veryslow, 64 | placebo. 65 | 66 | bitrate 67 | Only relevant for codecs which accept a bitrate. "5000k" offers 68 | nice results in general. 
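    crf
      Constant Rate Factor passed to ffmpeg via '-crf' (this writer defaults
      to 14). Lower values give higher quality and larger files.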
69 | 70 | """ 71 | 72 | def __init__(self, filename, size, fps, codec="libx265", crf=14, audiofile=None, 73 | preset="medium", bitrate=None, 74 | logfile=None, threads=None, ffmpeg_params=None): 75 | 76 | if logfile is None: 77 | logfile = sp.PIPE 78 | 79 | self.filename = filename 80 | self.codec = codec 81 | self.ext = self.filename.split(".")[-1] 82 | 83 | # order is important 84 | cmd = [ 85 | FFMPEG_BINARY, 86 | '-y', 87 | '-loglevel', 'error' if logfile == sp.PIPE else 'info', 88 | '-f', 'rawvideo', 89 | '-vcodec', 'rawvideo', 90 | '-s', '%dx%d' % (size[0], size[1]), 91 | #'-pix_fmt', 'rgba' if withmask else 'rgb24', 92 | '-pix_fmt', 'bgr24', 93 | '-r', str(fps), 94 | '-an', '-i', '-' 95 | ] 96 | if audiofile is not None: 97 | cmd.extend([ 98 | '-i', audiofile, 99 | '-acodec', 'copy' 100 | ]) 101 | cmd.extend([ 102 | '-vcodec', codec, 103 | '-crf', str(crf) 104 | #'-preset', preset, 105 | ]) 106 | if ffmpeg_params is not None: 107 | cmd.extend(ffmpeg_params) 108 | if bitrate is not None: 109 | cmd.extend([ 110 | '-b', bitrate 111 | ]) 112 | 113 | if threads is not None: 114 | cmd.extend(["-threads", str(threads)]) 115 | 116 | # if ((codec == 'libx264') and 117 | # (size[0] % 2 == 0) and 118 | # (size[1] % 2 == 0)): 119 | cmd.extend([ 120 | '-pix_fmt', 'yuv420p', 121 | 122 | ]) 123 | cmd.extend([ 124 | filename 125 | ]) 126 | 127 | popen_params = {"stdout": DEVNULL, 128 | "stderr": logfile, 129 | "stdin": sp.PIPE} 130 | 131 | # This was added so that no extra unwanted window opens on windows 132 | # when the child process is created 133 | if os.name == "nt": 134 | popen_params["creationflags"] = 0x08000000 # CREATE_NO_WINDOW 135 | 136 | self.proc = sp.Popen(cmd, **popen_params) 137 | 138 | 139 | def write_frame(self, img_array): 140 | """ Writes one frame in the file.""" 141 | try: 142 | #if PY3: 143 | self.proc.stdin.write(img_array.tobytes()) 144 | # else: 145 | # self.proc.stdin.write(img_array.tostring()) 146 | except IOError as err: 147 | _, ffmpeg_error = self.proc.communicate() 148 | error = (str(err) + ("\n\nMoviePy error: FFMPEG encountered " 149 | "the following error while writing file %s:" 150 | "\n\n %s" % (self.filename, str(ffmpeg_error)))) 151 | 152 | if b"Unknown encoder" in ffmpeg_error: 153 | 154 | error = error+("\n\nThe video export " 155 | "failed because FFMPEG didn't find the specified " 156 | "codec for video encoding (%s). Please install " 157 | "this codec or change the codec when calling " 158 | "write_videofile. For instance:\n" 159 | " >>> clip.write_videofile('myvid.webm', codec='libvpx')")%(self.codec) 160 | 161 | elif b"incorrect codec parameters ?" in ffmpeg_error: 162 | 163 | error = error+("\n\nThe video export " 164 | "failed, possibly because the codec specified for " 165 | "the video (%s) is not compatible with the given " 166 | "extension (%s). Please specify a valid 'codec' " 167 | "argument in write_videofile. This would be 'libx264' " 168 | "or 'mpeg4' for mp4, 'libtheora' for ogv, 'libvpx for webm. " 169 | "Another possible reason is that the audio codec was not " 170 | "compatible with the video codec. For instance the video " 171 | "extensions 'ogv' and 'webm' only allow 'libvorbis' (default) as a" 172 | "video codec." 
173 | )%(self.codec, self.ext) 174 | 175 | elif b"encoder setup failed" in ffmpeg_error: 176 | 177 | error = error+("\n\nThe video export " 178 | "failed, possibly because the bitrate you specified " 179 | "was too high or too low for the video codec.") 180 | 181 | elif b"Invalid encoder type" in ffmpeg_error: 182 | 183 | error = error + ("\n\nThe video export failed because the codec " 184 | "or file extension you provided is not a video") 185 | 186 | 187 | raise IOError(error) 188 | 189 | def close(self): 190 | if self.proc: 191 | self.proc.stdin.close() 192 | if self.proc.stderr is not None: 193 | self.proc.stderr.close() 194 | self.proc.wait() 195 | 196 | self.proc = None 197 | 198 | # Support the Context Manager protocol, to ensure that resources are cleaned up. 199 | 200 | def __enter__(self): 201 | return self 202 | 203 | def __exit__(self, exc_type, exc_value, traceback): 204 | self.close() 205 | 206 | 207 | 208 | def ffmpeg_write_image(filename, image, logfile=False): 209 | """ Writes an image (HxWx3 or HxWx4 numpy array) to a file, using 210 | ffmpeg. """ 211 | 212 | if image.dtype != 'uint8': 213 | image = image.astype("uint8") 214 | 215 | cmd = [ FFMPEG_BINARY, '-y', 216 | '-s', "%dx%d"%(image.shape[:2][::-1]), 217 | "-f", 'rawvideo', 218 | '-pix_fmt', "rgba" if (image.shape[2] == 4) else "rgb24", 219 | '-i','-', filename] 220 | 221 | if logfile: 222 | log_file = open(filename + ".log", 'w+') 223 | else: 224 | log_file = sp.PIPE 225 | 226 | popen_params = {"stdout": DEVNULL, 227 | "stderr": log_file, 228 | "stdin": sp.PIPE} 229 | 230 | if os.name == "nt": 231 | popen_params["creationflags"] = 0x08000000 232 | 233 | proc = sp.Popen(cmd, **popen_params) 234 | out, err = proc.communicate(image.tostring()) 235 | 236 | if proc.returncode: 237 | err = "\n".join(["[MoviePy] Running : %s\n" % cmd, 238 | "WARNING: this command returned an error:", 239 | err.decode('utf8')]) 240 | raise IOError(err) 241 | 242 | del proc 243 | -------------------------------------------------------------------------------- /chain_img_processor/image.py: -------------------------------------------------------------------------------- 1 | from jaa import JaaCore 2 | 3 | from termcolor import colored, cprint 4 | 5 | from typing import Any 6 | 7 | version = "4.0.0" 8 | 9 | 10 | class ChainImgProcessor(JaaCore): 11 | def __init__(self): 12 | JaaCore.__init__(self) 13 | 14 | self.processors:dict = { 15 | } 16 | 17 | self.processors_objects:dict[str,list[ChainImgPlugin]] = {} 18 | 19 | self.default_chain = "" 20 | self.init_on_start = "" 21 | 22 | self.inited_processors = [] 23 | 24 | self.is_demo_row_render = False 25 | 26 | def process_plugin_manifest(self, modname, manifest): 27 | # adding processors from plugin manifest 28 | if "img_processor" in manifest: # process commands 29 | for cmd in manifest["img_processor"].keys(): 30 | self.processors[cmd] = manifest["img_processor"][cmd] 31 | 32 | return manifest 33 | 34 | def init_with_plugins(self): 35 | self.init_plugins(["core"]) 36 | #self.init_plugins() 37 | self.display_init_info() 38 | 39 | #self.init_translator_engine(self.default_translator) 40 | init_on_start_arr = self.init_on_start.split(",") 41 | for proc_id in init_on_start_arr: 42 | self.init_processor(proc_id) 43 | 44 | def run_chain(self, img, params:dict[str,Any] = None, chain:str = None, thread_index:int = 0): 45 | if chain is None: 46 | chain = self.default_chain 47 | if params is None: 48 | params = {} 49 | params["_thread_index"] = thread_index 50 | 51 | chain_ar = chain.split(",") 52 | # 
init all not inited processors first 53 | for proc_id in chain_ar: 54 | if proc_id != "": 55 | if not proc_id in self.inited_processors: 56 | self.init_processor(proc_id) 57 | 58 | 59 | 60 | # run processing 61 | if self.is_demo_row_render: 62 | import cv2 63 | import numpy as np 64 | height, width, channels = img.shape 65 | img_blank = np.zeros((height+30, width*(1+len(chain_ar)), 3), dtype=np.uint8) 66 | img_blank.fill(255) 67 | 68 | y = 30 69 | x = 0 70 | img_blank[y:y + height, x:x + width] = img 71 | 72 | # Set the font scale and thickness 73 | font_scale = 1 74 | thickness = 2 75 | 76 | # Set the font face to a monospace font 77 | font_face = cv2.FONT_HERSHEY_SIMPLEX 78 | 79 | cv2.putText(img_blank, "original", (x+4, y-7), font_face, font_scale, (0, 0, 0), thickness) 80 | 81 | 82 | i = 0 83 | for proc_id in chain_ar: 84 | i += 1 85 | if proc_id != "": 86 | #img = self.processors[proc_id][1](self, img, params) # params can be modified inside 87 | y = 30 88 | img = self.processors_objects[proc_id][thread_index].process(img,params) 89 | if self.is_demo_row_render: 90 | x = width*i 91 | img_blank[y:y + height, x:x + width] = img 92 | cv2.putText(img_blank, proc_id, (x + 4, y - 7), font_face, font_scale, (0, 0, 0), thickness) 93 | 94 | if self.is_demo_row_render: 95 | return img_blank, params 96 | 97 | return img, params 98 | 99 | # ---------------- init translation stuff ---------------- 100 | def fill_processors_for_thread_chains(self, threads:int = 1, chain:str = None): 101 | if chain is None: 102 | chain = self.default_chain 103 | 104 | chain_ar = chain.split(",") 105 | # init all not inited processors first 106 | for processor_id in chain_ar: 107 | if processor_id != "": 108 | if self.processors_objects.get(processor_id) is None: 109 | self.processors_objects[processor_id] = [] 110 | while len(self.processors_objects[processor_id]) < threads: 111 | self.add_processor_to_list(processor_id) 112 | 113 | def add_processor_to_list(self, processor_id: str): 114 | obj = self.processors[processor_id](self) 115 | obj.init_plugin() 116 | if self.processors_objects.get(processor_id) is None: 117 | self.processors_objects[processor_id] = [] 118 | self.processors_objects[processor_id].append(obj) 119 | def init_processor(self, processor_id: str): 120 | if processor_id == "": # blank line case 121 | return 122 | 123 | if processor_id in self.inited_processors: 124 | # already inited 125 | return 126 | 127 | try: 128 | self.print_blue("TRY: init processor plugin '{0}'...".format(processor_id)) 129 | #self.processors[processor_id][0](self) 130 | self.add_processor_to_list(processor_id) 131 | self.inited_processors.append(processor_id) 132 | self.print_blue("SUCCESS: '{0}' inited!".format(processor_id)) 133 | 134 | except Exception as e: 135 | self.print_error("Error init processor plugin {0}...".format(processor_id), e) 136 | 137 | # ------------ formatting stuff ------------------- 138 | def display_init_info(self): 139 | cprint("ChainImgProcessor v{0}:".format(version), "blue", end=' ') 140 | self.format_print_key_list("processors:", self.processors.keys()) 141 | 142 | def format_print_key_list(self, key:str, value:list): 143 | print(colored(key+": ", "blue")+", ".join(value)) 144 | 145 | def print_error(self,err_txt,e:Exception = None): 146 | cprint(err_txt,"red") 147 | # if e != None: 148 | # cprint(e,"red") 149 | import traceback 150 | traceback.print_exc() 151 | 152 | def print_red(self,txt): 153 | cprint(txt,"red") 154 | 155 | def print_blue(self, txt): 156 | cprint(txt, "blue") 157 | 158 | 
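# --- Editor's usage sketch (not part of the original file) -------------------
# How the chain API above is typically driven, assuming a plugin from the
# plugins/ folder has registered a processor id such as "blur" in its manifest
# under "img_processor" (the exact ids are an assumption here):
#
#     import cv2
#     proc = get_single_image_processor()      # defined below; loads the plugin manifests
#     img = cv2.imread("face.jpg")
#     out, params = proc.run_chain(img, params={}, chain="blur")
#     cv2.imwrite("out.jpg", out)
#
# run_chain() lazily calls init_processor() for every id in the chain string,
# so each plugin only pays its start-up cost on first use.
# -----------------------------------------------------------------------------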
class ChainImgPlugin: 159 | def __init__(self, core: ChainImgProcessor): 160 | self.core = core 161 | 162 | def init_plugin(self): # here you can init something. Called once 163 | pass 164 | def process(self, img, params:dict): # process img. Called multiple 165 | return img 166 | 167 | _img_processor:ChainImgProcessor = None 168 | def get_single_image_processor() -> ChainImgProcessor: 169 | global _img_processor 170 | if _img_processor is None: 171 | _img_processor = ChainImgProcessor() 172 | _img_processor.init_with_plugins() 173 | return _img_processor -------------------------------------------------------------------------------- /chain_img_processor/video.py: -------------------------------------------------------------------------------- 1 | from threading import Thread 2 | 3 | from chain_img_processor import ChainImgProcessor 4 | 5 | 6 | #version = "1.0.0" 7 | 8 | 9 | 10 | class ThreadWithReturnValue(Thread): 11 | 12 | def __init__(self, group=None, target=None, name=None, 13 | args=(), kwargs={}, Verbose=None): 14 | Thread.__init__(self, group, target, name, args, kwargs) 15 | self._return = None 16 | 17 | def run(self): 18 | if self._target is not None: 19 | self._return = self._target(*self._args, 20 | **self._kwargs) 21 | 22 | def join(self, *args): 23 | Thread.join(self, *args) 24 | return self._return 25 | 26 | 27 | # in beta 28 | class ChainVideoProcessor(ChainImgProcessor): 29 | def __init__(self): 30 | ChainImgProcessor.__init__(self) 31 | 32 | self.video_save_codec = "libx264" 33 | self.video_save_crf = 14 34 | 35 | def init_with_plugins(self): 36 | self.init_plugins(["core","core_video"]) 37 | self.display_init_info() 38 | 39 | init_on_start_arr = self.init_on_start.split(",") 40 | for proc_id in init_on_start_arr: 41 | self.init_processor(proc_id) 42 | 43 | def run_video_chain(self, source_video, target_video, fps, threads:int = 1, chain = None, params_frame_gen_func = None, video_audio = None): 44 | import cv2 45 | from tqdm import tqdm 46 | from chain_img_processor.ffmpeg_writer import FFMPEG_VideoWriter # ffmpeg install needed 47 | 48 | cap = cv2.VideoCapture(source_video) 49 | # width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)) 50 | # height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) 51 | frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) 52 | 53 | # first frame do manually - because upscale may happen, we need to estimate width/height 54 | ret, frame = cap.read() 55 | if params_frame_gen_func is not None: 56 | params = params_frame_gen_func(self, frame) 57 | else: 58 | params = {} 59 | frame_processed, params = self.run_chain(frame,params,chain) 60 | height, width, channels = frame_processed.shape 61 | 62 | self.fill_processors_for_thread_chains(threads,chain) 63 | #print(self.processors_objects) 64 | #import threading 65 | #locks:list[threading.Lock] = [] 66 | locks: list[bool] = [] 67 | for i in range(threads): 68 | #locks.append(threading.Lock()) 69 | locks.append(False) 70 | 71 | temp = [] 72 | with FFMPEG_VideoWriter(target_video, (width, height), fps, codec=self.video_save_codec, crf=self.video_save_crf, audiofile=video_audio) as output_video_ff: 73 | with tqdm(total=frame_count, desc='Processing', unit="frame", dynamic_ncols=True, 74 | bar_format='{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}{postfix}]') as progress: 75 | 76 | # do first frame 77 | output_video_ff.write_frame(frame_processed) 78 | progress.update(1) # 79 | cnt_frames = 0 80 | 81 | # do rest frames 82 | while True: 83 | # getting frame 84 | ret, frame = cap.read() 85 | 86 
| if not ret: 87 | break 88 | cnt_frames+=1 89 | thread_ind = cnt_frames % threads 90 | # we are having an array of length %gpu_threads%, running in parallel 91 | # so if array is equal or longer than gpu threads, waiting 92 | #while len(temp) >= threads: 93 | while locks[thread_ind]: 94 | #print('WAIT', thread_ind) 95 | # we are order dependent, so we are forced to wait for first element to finish. When finished removing thread from the list 96 | frame_processed, params = temp.pop(0).join() 97 | locks[params["_thread_index"]] = False 98 | #print('OFF',cnt_frames,locks[params["_thread_index"]],locks) 99 | # writing into output 100 | output_video_ff.write_frame(frame_processed) 101 | # updating the status 102 | progress.update(1) 103 | 104 | # calc params for frame 105 | if params_frame_gen_func is not None: 106 | params = params_frame_gen_func(self,frame) 107 | else: 108 | params = {} 109 | 110 | # adding new frame to the list and starting it 111 | locks[thread_ind] = True 112 | #print('ON', cnt_frames, thread_ind, locks) 113 | temp.append( 114 | ThreadWithReturnValue(target=self.run_chain, args=(frame, params, chain, thread_ind))) 115 | temp[-1].start() 116 | 117 | while len(temp) > 0: 118 | # we are order dependent, so we are forced to wait for first element to finish. When finished removing thread from the list 119 | frame_processed, params = temp.pop(0).join() 120 | locks[params["_thread_index"]] = False 121 | # writing into output 122 | output_video_ff.write_frame(frame_processed) 123 | 124 | progress.update(1) 125 | 126 | #print("FINAL", locks) 127 | 128 | _video_processor:ChainVideoProcessor = None 129 | def get_single_video_processor() -> ChainVideoProcessor: 130 | global _video_processor 131 | if _video_processor is None: 132 | _video_processor = ChainVideoProcessor() 133 | _video_processor.init_with_plugins() 134 | return _video_processor 135 | -------------------------------------------------------------------------------- /colab_example.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "colab": { 6 | "provenance": [], 7 | "gpuType": "T4" 8 | }, 9 | "kernelspec": { 10 | "name": "python3", 11 | "display_name": "Python 3" 12 | }, 13 | "language_info": { 14 | "name": "python" 15 | }, 16 | "accelerator": "GPU" 17 | }, 18 | "cells": [ 19 | { 20 | "cell_type": "code", 21 | "execution_count": null, 22 | "metadata": { 23 | "id": "dhOwLvtgDjNq" 24 | }, 25 | "outputs": [], 26 | "source": [ 27 | "#@title **SETUP THE PROGRAM**. IF YOU HAVE ANY ERRORS DURING PIP INSTALL IGNORE THEN AND CONTINUE. 
IF LAST CELL RUNS IT RUNS, IF NOT, SEND MESSAGE TO DISCORD.\n", 28 | "!git clone https://github.com/RichardErkhov/FastFaceSwap\n", 29 | "%cd FastFaceSwap\n", 30 | "!wget https://github.com/RichardErkhov/FastFaceSwap/releases/download/model/inswapper_128.onnx\n", 31 | "!wget https://github.com/RichardErkhov/FastFaceSwap/releases/download/model/GFPGANv1.4.pth\n", 32 | "!wget https://github.com/RichardErkhov/FastFaceSwap/releases/download/model/complex_256_v7_stage3_12999.h5\n", 33 | "!wget https://github.com/RichardErkhov/FastFaceSwap/releases/download/model/GFPGANv1.4.onnx\n", 34 | "!pip install -q torch==2.1.0 torchvision torchaudio\n", 35 | "!pip install -q onnxruntime-gpu\n", 36 | "!pip install -q -r requirements.txt\n", 37 | "!pip install -q tensorflow-gpu==2.10.1\n", 38 | "!pip install -q protobuf==3.20.2\n", 39 | "!pip uninstall opencv-python opencv-headless-python opencv-contrib-python -q -y\n", 40 | "!pip install -q opencv-python" 41 | ] 42 | }, 43 | { 44 | "cell_type": "code", 45 | "source": [ 46 | "%cd /content/FastFaceSwap\n", 47 | "#@title **RUN THE PROGRAM**\n", 48 | "#@markdown path to the face image\n", 49 | "face_image = \"seed0000.png\" #@param {type:\"string\"}\n", 50 | "#@markdown path to the video where to change the face\n", 51 | "target = \"seed0001.png\" #@param {type:\"string\"}\n", 52 | "output = \"output.png\" #@param {type:\"string\"}\n", 53 | "#@markdown the amount of the threads that program runs in\n", 54 | "threads = 16 #@param {type:\"integer\"}\n", 55 | "#@markdown if the input is image\n", 56 | "image = True #@param {type:\"boolean\"}\n", 57 | "face_enhancer = 'gfpgan' #@param [\"none\", \"gfpgan\", \"ffe\", \"codeformer\", \"gpfgan_onnx\",\"real_esrgan\"]\n", 58 | "#@markdown FFE face enhancer is still in progress, it might drastically change the color of the face!\n", 59 | "if image:\n", 60 | " !python main.py -f $face_image -t $target -o $output --threads $threads --image --cli --face-enhancer $face_enhancer\n", 61 | "else:\n", 62 | " !python main.py -f $face_image -t $target -o $output --threads $threads --cli --face-enhancer $face_enhancer" 63 | ], 64 | "metadata": { 65 | "id": "ZGgC07ndD3BX" 66 | }, 67 | "execution_count": null, 68 | "outputs": [] 69 | } 70 | ] 71 | } -------------------------------------------------------------------------------- /face.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RichardErkhov/FastFaceSwap/be10e8d0110678b955bd235a06be0f934f567890/face.jpg -------------------------------------------------------------------------------- /gfpgan/__init__.py: -------------------------------------------------------------------------------- 1 | # flake8: noqa 2 | from .archs import * 3 | from .data import * 4 | from .models import * 5 | from .utils import * 6 | 7 | # from .version import * 8 | -------------------------------------------------------------------------------- /gfpgan/archs/__init__.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | from basicsr.utils import scandir 3 | from os import path as osp 4 | 5 | # automatically scan and import arch modules for registry 6 | # scan all the files that end with '_arch.py' under the archs folder 7 | arch_folder = osp.dirname(osp.abspath(__file__)) 8 | arch_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(arch_folder) if v.endswith('_arch.py')] 9 | # import all the arch modules 10 | _arch_modules = 
[importlib.import_module(f'gfpgan.archs.{file_name}') for file_name in arch_filenames] 11 | -------------------------------------------------------------------------------- /gfpgan/archs/arcface_arch.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from basicsr.utils.registry import ARCH_REGISTRY 3 | 4 | 5 | def conv3x3(inplanes, outplanes, stride=1): 6 | """A simple wrapper for 3x3 convolution with padding. 7 | 8 | Args: 9 | inplanes (int): Channel number of inputs. 10 | outplanes (int): Channel number of outputs. 11 | stride (int): Stride in convolution. Default: 1. 12 | """ 13 | return nn.Conv2d(inplanes, outplanes, kernel_size=3, stride=stride, padding=1, bias=False) 14 | 15 | 16 | class BasicBlock(nn.Module): 17 | """Basic residual block used in the ResNetArcFace architecture. 18 | 19 | Args: 20 | inplanes (int): Channel number of inputs. 21 | planes (int): Channel number of outputs. 22 | stride (int): Stride in convolution. Default: 1. 23 | downsample (nn.Module): The downsample module. Default: None. 24 | """ 25 | expansion = 1 # output channel expansion ratio 26 | 27 | def __init__(self, inplanes, planes, stride=1, downsample=None): 28 | super(BasicBlock, self).__init__() 29 | self.conv1 = conv3x3(inplanes, planes, stride) 30 | self.bn1 = nn.BatchNorm2d(planes) 31 | self.relu = nn.ReLU(inplace=True) 32 | self.conv2 = conv3x3(planes, planes) 33 | self.bn2 = nn.BatchNorm2d(planes) 34 | self.downsample = downsample 35 | self.stride = stride 36 | 37 | def forward(self, x): 38 | residual = x 39 | 40 | out = self.conv1(x) 41 | out = self.bn1(out) 42 | out = self.relu(out) 43 | 44 | out = self.conv2(out) 45 | out = self.bn2(out) 46 | 47 | if self.downsample is not None: 48 | residual = self.downsample(x) 49 | 50 | out += residual 51 | out = self.relu(out) 52 | 53 | return out 54 | 55 | 56 | class IRBlock(nn.Module): 57 | """Improved residual block (IR Block) used in the ResNetArcFace architecture. 58 | 59 | Args: 60 | inplanes (int): Channel number of inputs. 61 | planes (int): Channel number of outputs. 62 | stride (int): Stride in convolution. Default: 1. 63 | downsample (nn.Module): The downsample module. Default: None. 64 | use_se (bool): Whether use the SEBlock (squeeze and excitation block). Default: True. 65 | """ 66 | expansion = 1 # output channel expansion ratio 67 | 68 | def __init__(self, inplanes, planes, stride=1, downsample=None, use_se=True): 69 | super(IRBlock, self).__init__() 70 | self.bn0 = nn.BatchNorm2d(inplanes) 71 | self.conv1 = conv3x3(inplanes, inplanes) 72 | self.bn1 = nn.BatchNorm2d(inplanes) 73 | self.prelu = nn.PReLU() 74 | self.conv2 = conv3x3(inplanes, planes, stride) 75 | self.bn2 = nn.BatchNorm2d(planes) 76 | self.downsample = downsample 77 | self.stride = stride 78 | self.use_se = use_se 79 | if self.use_se: 80 | self.se = SEBlock(planes) 81 | 82 | def forward(self, x): 83 | residual = x 84 | out = self.bn0(x) 85 | out = self.conv1(out) 86 | out = self.bn1(out) 87 | out = self.prelu(out) 88 | 89 | out = self.conv2(out) 90 | out = self.bn2(out) 91 | if self.use_se: 92 | out = self.se(out) 93 | 94 | if self.downsample is not None: 95 | residual = self.downsample(x) 96 | 97 | out += residual 98 | out = self.prelu(out) 99 | 100 | return out 101 | 102 | 103 | class Bottleneck(nn.Module): 104 | """Bottleneck block used in the ResNetArcFace architecture. 105 | 106 | Args: 107 | inplanes (int): Channel number of inputs. 108 | planes (int): Channel number of outputs. 
109 | stride (int): Stride in convolution. Default: 1. 110 | downsample (nn.Module): The downsample module. Default: None. 111 | """ 112 | expansion = 4 # output channel expansion ratio 113 | 114 | def __init__(self, inplanes, planes, stride=1, downsample=None): 115 | super(Bottleneck, self).__init__() 116 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 117 | self.bn1 = nn.BatchNorm2d(planes) 118 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 119 | self.bn2 = nn.BatchNorm2d(planes) 120 | self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, bias=False) 121 | self.bn3 = nn.BatchNorm2d(planes * self.expansion) 122 | self.relu = nn.ReLU(inplace=True) 123 | self.downsample = downsample 124 | self.stride = stride 125 | 126 | def forward(self, x): 127 | residual = x 128 | 129 | out = self.conv1(x) 130 | out = self.bn1(out) 131 | out = self.relu(out) 132 | 133 | out = self.conv2(out) 134 | out = self.bn2(out) 135 | out = self.relu(out) 136 | 137 | out = self.conv3(out) 138 | out = self.bn3(out) 139 | 140 | if self.downsample is not None: 141 | residual = self.downsample(x) 142 | 143 | out += residual 144 | out = self.relu(out) 145 | 146 | return out 147 | 148 | 149 | class SEBlock(nn.Module): 150 | """The squeeze-and-excitation block (SEBlock) used in the IRBlock. 151 | 152 | Args: 153 | channel (int): Channel number of inputs. 154 | reduction (int): Channel reduction ration. Default: 16. 155 | """ 156 | 157 | def __init__(self, channel, reduction=16): 158 | super(SEBlock, self).__init__() 159 | self.avg_pool = nn.AdaptiveAvgPool2d(1) # pool to 1x1 without spatial information 160 | self.fc = nn.Sequential( 161 | nn.Linear(channel, channel // reduction), nn.PReLU(), nn.Linear(channel // reduction, channel), 162 | nn.Sigmoid()) 163 | 164 | def forward(self, x): 165 | b, c, _, _ = x.size() 166 | y = self.avg_pool(x).view(b, c) 167 | y = self.fc(y).view(b, c, 1, 1) 168 | return x * y 169 | 170 | 171 | @ARCH_REGISTRY.register() 172 | class ResNetArcFace(nn.Module): 173 | """ArcFace with ResNet architectures. 174 | 175 | Ref: ArcFace: Additive Angular Margin Loss for Deep Face Recognition. 176 | 177 | Args: 178 | block (str): Block used in the ArcFace architecture. 179 | layers (tuple(int)): Block numbers in each layer. 180 | use_se (bool): Whether use the SEBlock (squeeze and excitation block). Default: True. 
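Example (editor's note, not in the original docstring): an illustrative
instantiation is ``ResNetArcFace(block='IRBlock', layers=(2, 2, 2, 2),
use_se=False)``; the exact layer counts are an assumption here. The
``conv1``/``fc5`` shapes defined below imply a 1-channel, 128x128 input
that is embedded into a 512-dimensional identity vector.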
181 | """ 182 | 183 | def __init__(self, block, layers, use_se=True): 184 | if block == 'IRBlock': 185 | block = IRBlock 186 | self.inplanes = 64 187 | self.use_se = use_se 188 | super(ResNetArcFace, self).__init__() 189 | 190 | self.conv1 = nn.Conv2d(1, 64, kernel_size=3, padding=1, bias=False) 191 | self.bn1 = nn.BatchNorm2d(64) 192 | self.prelu = nn.PReLU() 193 | self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2) 194 | self.layer1 = self._make_layer(block, 64, layers[0]) 195 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 196 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 197 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2) 198 | self.bn4 = nn.BatchNorm2d(512) 199 | self.dropout = nn.Dropout() 200 | self.fc5 = nn.Linear(512 * 8 * 8, 512) 201 | self.bn5 = nn.BatchNorm1d(512) 202 | 203 | # initialization 204 | for m in self.modules(): 205 | if isinstance(m, nn.Conv2d): 206 | nn.init.xavier_normal_(m.weight) 207 | elif isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.BatchNorm1d): 208 | nn.init.constant_(m.weight, 1) 209 | nn.init.constant_(m.bias, 0) 210 | elif isinstance(m, nn.Linear): 211 | nn.init.xavier_normal_(m.weight) 212 | nn.init.constant_(m.bias, 0) 213 | 214 | def _make_layer(self, block, planes, num_blocks, stride=1): 215 | downsample = None 216 | if stride != 1 or self.inplanes != planes * block.expansion: 217 | downsample = nn.Sequential( 218 | nn.Conv2d(self.inplanes, planes * block.expansion, kernel_size=1, stride=stride, bias=False), 219 | nn.BatchNorm2d(planes * block.expansion), 220 | ) 221 | layers = [] 222 | layers.append(block(self.inplanes, planes, stride, downsample, use_se=self.use_se)) 223 | self.inplanes = planes 224 | for _ in range(1, num_blocks): 225 | layers.append(block(self.inplanes, planes, use_se=self.use_se)) 226 | 227 | return nn.Sequential(*layers) 228 | 229 | def forward(self, x): 230 | x = self.conv1(x) 231 | x = self.bn1(x) 232 | x = self.prelu(x) 233 | x = self.maxpool(x) 234 | 235 | x = self.layer1(x) 236 | x = self.layer2(x) 237 | x = self.layer3(x) 238 | x = self.layer4(x) 239 | x = self.bn4(x) 240 | x = self.dropout(x) 241 | x = x.view(x.size(0), -1) 242 | x = self.fc5(x) 243 | x = self.bn5(x) 244 | 245 | return x 246 | -------------------------------------------------------------------------------- /gfpgan/archs/gfpgan_bilinear_arch.py: -------------------------------------------------------------------------------- 1 | import math 2 | import random 3 | import torch 4 | from basicsr.utils.registry import ARCH_REGISTRY 5 | from torch import nn 6 | 7 | from .gfpganv1_arch import ResUpBlock 8 | from .stylegan2_bilinear_arch import (ConvLayer, EqualConv2d, EqualLinear, ResBlock, ScaledLeakyReLU, 9 | StyleGAN2GeneratorBilinear) 10 | 11 | 12 | class StyleGAN2GeneratorBilinearSFT(StyleGAN2GeneratorBilinear): 13 | """StyleGAN2 Generator with SFT modulation (Spatial Feature Transform). 14 | 15 | It is the bilinear version. It does not use the complicated UpFirDnSmooth function that is not friendly for 16 | deployment. It can be easily converted to the clean version: StyleGAN2GeneratorCSFT. 17 | 18 | Args: 19 | out_size (int): The spatial size of outputs. 20 | num_style_feat (int): Channel number of style features. Default: 512. 21 | num_mlp (int): Layer number of MLP style layers. Default: 8. 22 | channel_multiplier (int): Channel multiplier for large networks of StyleGAN2. Default: 2. 23 | lr_mlp (float): Learning rate multiplier for mlp layers. Default: 0.01. 
24 | narrow (float): The narrow ratio for channels. Default: 1. 25 | sft_half (bool): Whether to apply SFT on half of the input channels. Default: False. 26 | """ 27 | 28 | def __init__(self, 29 | out_size, 30 | num_style_feat=512, 31 | num_mlp=8, 32 | channel_multiplier=2, 33 | lr_mlp=0.01, 34 | narrow=1, 35 | sft_half=False): 36 | super(StyleGAN2GeneratorBilinearSFT, self).__init__( 37 | out_size, 38 | num_style_feat=num_style_feat, 39 | num_mlp=num_mlp, 40 | channel_multiplier=channel_multiplier, 41 | lr_mlp=lr_mlp, 42 | narrow=narrow) 43 | self.sft_half = sft_half 44 | 45 | def forward(self, 46 | styles, 47 | conditions, 48 | input_is_latent=False, 49 | noise=None, 50 | randomize_noise=True, 51 | truncation=1, 52 | truncation_latent=None, 53 | inject_index=None, 54 | return_latents=False): 55 | """Forward function for StyleGAN2GeneratorBilinearSFT. 56 | 57 | Args: 58 | styles (list[Tensor]): Sample codes of styles. 59 | conditions (list[Tensor]): SFT conditions to generators. 60 | input_is_latent (bool): Whether input is latent style. Default: False. 61 | noise (Tensor | None): Input noise or None. Default: None. 62 | randomize_noise (bool): Randomize noise, used when 'noise' is False. Default: True. 63 | truncation (float): The truncation ratio. Default: 1. 64 | truncation_latent (Tensor | None): The truncation latent tensor. Default: None. 65 | inject_index (int | None): The injection index for mixing noise. Default: None. 66 | return_latents (bool): Whether to return style latents. Default: False. 67 | """ 68 | # style codes -> latents with Style MLP layer 69 | if not input_is_latent: 70 | styles = [self.style_mlp(s) for s in styles] 71 | # noises 72 | if noise is None: 73 | if randomize_noise: 74 | noise = [None] * self.num_layers # for each style conv layer 75 | else: # use the stored noise 76 | noise = [getattr(self.noises, f'noise{i}') for i in range(self.num_layers)] 77 | # style truncation 78 | if truncation < 1: 79 | style_truncation = [] 80 | for style in styles: 81 | style_truncation.append(truncation_latent + truncation * (style - truncation_latent)) 82 | styles = style_truncation 83 | # get style latents with injection 84 | if len(styles) == 1: 85 | inject_index = self.num_latent 86 | 87 | if styles[0].ndim < 3: 88 | # repeat latent code for all the layers 89 | latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1) 90 | else: # used for encoder with different latent code for each layer 91 | latent = styles[0] 92 | elif len(styles) == 2: # mixing noises 93 | if inject_index is None: 94 | inject_index = random.randint(1, self.num_latent - 1) 95 | latent1 = styles[0].unsqueeze(1).repeat(1, inject_index, 1) 96 | latent2 = styles[1].unsqueeze(1).repeat(1, self.num_latent - inject_index, 1) 97 | latent = torch.cat([latent1, latent2], 1) 98 | 99 | # main generation 100 | out = self.constant_input(latent.shape[0]) 101 | out = self.style_conv1(out, latent[:, 0], noise=noise[0]) 102 | skip = self.to_rgb1(out, latent[:, 1]) 103 | 104 | i = 1 105 | for conv1, conv2, noise1, noise2, to_rgb in zip(self.style_convs[::2], self.style_convs[1::2], noise[1::2], 106 | noise[2::2], self.to_rgbs): 107 | out = conv1(out, latent[:, i], noise=noise1) 108 | 109 | # the conditions may have fewer levels 110 | if i < len(conditions): 111 | # SFT part to combine the conditions 112 | if self.sft_half: # only apply SFT to half of the channels 113 | out_same, out_sft = torch.split(out, int(out.size(1) // 2), dim=1) 114 | out_sft = out_sft * conditions[i - 1] + conditions[i] 115 | out = 
torch.cat([out_same, out_sft], dim=1) 116 | else: # apply SFT to all the channels 117 | out = out * conditions[i - 1] + conditions[i] 118 | 119 | out = conv2(out, latent[:, i + 1], noise=noise2) 120 | skip = to_rgb(out, latent[:, i + 2], skip) # feature back to the rgb space 121 | i += 2 122 | 123 | image = skip 124 | 125 | if return_latents: 126 | return image, latent 127 | else: 128 | return image, None 129 | 130 | 131 | @ARCH_REGISTRY.register() 132 | class GFPGANBilinear(nn.Module): 133 | """The GFPGAN architecture: Unet + StyleGAN2 decoder with SFT. 134 | 135 | It is the bilinear version and it does not use the complicated UpFirDnSmooth function that is not friendly for 136 | deployment. It can be easily converted to the clean version: GFPGANv1Clean. 137 | 138 | 139 | Ref: GFP-GAN: Towards Real-World Blind Face Restoration with Generative Facial Prior. 140 | 141 | Args: 142 | out_size (int): The spatial size of outputs. 143 | num_style_feat (int): Channel number of style features. Default: 512. 144 | channel_multiplier (int): Channel multiplier for large networks of StyleGAN2. Default: 2. 145 | decoder_load_path (str): The path to the pre-trained decoder model (usually, the StyleGAN2). Default: None. 146 | fix_decoder (bool): Whether to fix the decoder. Default: True. 147 | 148 | num_mlp (int): Layer number of MLP style layers. Default: 8. 149 | lr_mlp (float): Learning rate multiplier for mlp layers. Default: 0.01. 150 | input_is_latent (bool): Whether input is latent style. Default: False. 151 | different_w (bool): Whether to use different latent w for different layers. Default: False. 152 | narrow (float): The narrow ratio for channels. Default: 1. 153 | sft_half (bool): Whether to apply SFT on half of the input channels. Default: False. 154 | """ 155 | 156 | def __init__( 157 | self, 158 | out_size, 159 | num_style_feat=512, 160 | channel_multiplier=1, 161 | decoder_load_path=None, 162 | fix_decoder=True, 163 | # for stylegan decoder 164 | num_mlp=8, 165 | lr_mlp=0.01, 166 | input_is_latent=False, 167 | different_w=False, 168 | narrow=1, 169 | sft_half=False): 170 | 171 | super(GFPGANBilinear, self).__init__() 172 | self.input_is_latent = input_is_latent 173 | self.different_w = different_w 174 | self.num_style_feat = num_style_feat 175 | 176 | unet_narrow = narrow * 0.5 # by default, use a half of input channels 177 | channels = { 178 | '4': int(512 * unet_narrow), 179 | '8': int(512 * unet_narrow), 180 | '16': int(512 * unet_narrow), 181 | '32': int(512 * unet_narrow), 182 | '64': int(256 * channel_multiplier * unet_narrow), 183 | '128': int(128 * channel_multiplier * unet_narrow), 184 | '256': int(64 * channel_multiplier * unet_narrow), 185 | '512': int(32 * channel_multiplier * unet_narrow), 186 | '1024': int(16 * channel_multiplier * unet_narrow) 187 | } 188 | 189 | self.log_size = int(math.log(out_size, 2)) 190 | first_out_size = 2**(int(math.log(out_size, 2))) 191 | 192 | self.conv_body_first = ConvLayer(3, channels[f'{first_out_size}'], 1, bias=True, activate=True) 193 | 194 | # downsample 195 | in_channels = channels[f'{first_out_size}'] 196 | self.conv_body_down = nn.ModuleList() 197 | for i in range(self.log_size, 2, -1): 198 | out_channels = channels[f'{2**(i - 1)}'] 199 | self.conv_body_down.append(ResBlock(in_channels, out_channels)) 200 | in_channels = out_channels 201 | 202 | self.final_conv = ConvLayer(in_channels, channels['4'], 3, bias=True, activate=True) 203 | 204 | # upsample 205 | in_channels = channels['4'] 206 | self.conv_body_up = nn.ModuleList() 207 | 
for i in range(3, self.log_size + 1): 208 | out_channels = channels[f'{2**i}'] 209 | self.conv_body_up.append(ResUpBlock(in_channels, out_channels)) 210 | in_channels = out_channels 211 | 212 | # to RGB 213 | self.toRGB = nn.ModuleList() 214 | for i in range(3, self.log_size + 1): 215 | self.toRGB.append(EqualConv2d(channels[f'{2**i}'], 3, 1, stride=1, padding=0, bias=True, bias_init_val=0)) 216 | 217 | if different_w: 218 | linear_out_channel = (int(math.log(out_size, 2)) * 2 - 2) * num_style_feat 219 | else: 220 | linear_out_channel = num_style_feat 221 | 222 | self.final_linear = EqualLinear( 223 | channels['4'] * 4 * 4, linear_out_channel, bias=True, bias_init_val=0, lr_mul=1, activation=None) 224 | 225 | # the decoder: stylegan2 generator with SFT modulations 226 | self.stylegan_decoder = StyleGAN2GeneratorBilinearSFT( 227 | out_size=out_size, 228 | num_style_feat=num_style_feat, 229 | num_mlp=num_mlp, 230 | channel_multiplier=channel_multiplier, 231 | lr_mlp=lr_mlp, 232 | narrow=narrow, 233 | sft_half=sft_half) 234 | 235 | # load pre-trained stylegan2 model if necessary 236 | if decoder_load_path: 237 | self.stylegan_decoder.load_state_dict( 238 | torch.load(decoder_load_path, map_location=lambda storage, loc: storage)['params_ema']) 239 | # fix decoder without updating params 240 | if fix_decoder: 241 | for _, param in self.stylegan_decoder.named_parameters(): 242 | param.requires_grad = False 243 | 244 | # for SFT modulations (scale and shift) 245 | self.condition_scale = nn.ModuleList() 246 | self.condition_shift = nn.ModuleList() 247 | for i in range(3, self.log_size + 1): 248 | out_channels = channels[f'{2**i}'] 249 | if sft_half: 250 | sft_out_channels = out_channels 251 | else: 252 | sft_out_channels = out_channels * 2 253 | self.condition_scale.append( 254 | nn.Sequential( 255 | EqualConv2d(out_channels, out_channels, 3, stride=1, padding=1, bias=True, bias_init_val=0), 256 | ScaledLeakyReLU(0.2), 257 | EqualConv2d(out_channels, sft_out_channels, 3, stride=1, padding=1, bias=True, bias_init_val=1))) 258 | self.condition_shift.append( 259 | nn.Sequential( 260 | EqualConv2d(out_channels, out_channels, 3, stride=1, padding=1, bias=True, bias_init_val=0), 261 | ScaledLeakyReLU(0.2), 262 | EqualConv2d(out_channels, sft_out_channels, 3, stride=1, padding=1, bias=True, bias_init_val=0))) 263 | 264 | def forward(self, x, return_latents=False, return_rgb=True, randomize_noise=True): 265 | """Forward function for GFPGANBilinear. 266 | 267 | Args: 268 | x (Tensor): Input images. 269 | return_latents (bool): Whether to return style latents. Default: False. 270 | return_rgb (bool): Whether return intermediate rgb images. Default: True. 271 | randomize_noise (bool): Randomize noise, used when 'noise' is False. Default: True. 
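Returns (editor's note, inferred from the code below): a tuple
``(image, out_rgbs)`` where ``image`` is the restored output of the
StyleGAN2 decoder and ``out_rgbs`` is the list of intermediate RGB
previews from ``self.toRGB`` (left empty when ``return_rgb`` is False).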
272 | """ 273 | conditions = [] 274 | unet_skips = [] 275 | out_rgbs = [] 276 | 277 | # encoder 278 | feat = self.conv_body_first(x) 279 | for i in range(self.log_size - 2): 280 | feat = self.conv_body_down[i](feat) 281 | unet_skips.insert(0, feat) 282 | 283 | feat = self.final_conv(feat) 284 | 285 | # style code 286 | style_code = self.final_linear(feat.view(feat.size(0), -1)) 287 | if self.different_w: 288 | style_code = style_code.view(style_code.size(0), -1, self.num_style_feat) 289 | 290 | # decode 291 | for i in range(self.log_size - 2): 292 | # add unet skip 293 | feat = feat + unet_skips[i] 294 | # ResUpLayer 295 | feat = self.conv_body_up[i](feat) 296 | # generate scale and shift for SFT layers 297 | scale = self.condition_scale[i](feat) 298 | conditions.append(scale.clone()) 299 | shift = self.condition_shift[i](feat) 300 | conditions.append(shift.clone()) 301 | # generate rgb images 302 | if return_rgb: 303 | out_rgbs.append(self.toRGB[i](feat)) 304 | 305 | # decoder 306 | image, _ = self.stylegan_decoder([style_code], 307 | conditions, 308 | return_latents=return_latents, 309 | input_is_latent=self.input_is_latent, 310 | randomize_noise=randomize_noise) 311 | 312 | return image, out_rgbs 313 | -------------------------------------------------------------------------------- /gfpgan/data/__init__.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | from basicsr.utils import scandir 3 | from os import path as osp 4 | 5 | # automatically scan and import dataset modules for registry 6 | # scan all the files that end with '_dataset.py' under the data folder 7 | data_folder = osp.dirname(osp.abspath(__file__)) 8 | dataset_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(data_folder) if v.endswith('_dataset.py')] 9 | # import all the dataset modules 10 | _dataset_modules = [importlib.import_module(f'gfpgan.data.{file_name}') for file_name in dataset_filenames] 11 | -------------------------------------------------------------------------------- /gfpgan/data/ffhq_degradation_dataset.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import math 3 | import numpy as np 4 | import os.path as osp 5 | import torch 6 | import torch.utils.data as data 7 | from basicsr.data import degradations as degradations 8 | from basicsr.data.data_util import paths_from_folder 9 | from basicsr.data.transforms import augment 10 | from basicsr.utils import FileClient, get_root_logger, imfrombytes, img2tensor 11 | from basicsr.utils.registry import DATASET_REGISTRY 12 | from torchvision.transforms.functional import (adjust_brightness, adjust_contrast, adjust_hue, adjust_saturation, 13 | normalize) 14 | 15 | 16 | @DATASET_REGISTRY.register() 17 | class FFHQDegradationDataset(data.Dataset): 18 | """FFHQ dataset for GFPGAN. 19 | 20 | It reads high resolution images, and then generate low-quality (LQ) images on-the-fly. 21 | 22 | Args: 23 | opt (dict): Config for train datasets. It contains the following keys: 24 | dataroot_gt (str): Data root path for gt. 25 | io_backend (dict): IO backend type and other kwarg. 26 | mean (list | tuple): Image mean. 27 | std (list | tuple): Image std. 28 | use_hflip (bool): Whether to horizontally flip. 29 | Please see more options in the codes. 
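Example (editor's sketch; every value below is illustrative rather than the
project's actual training configuration):

    opt = dict(
        dataroot_gt='datasets/ffhq_512',          # folder of aligned GT faces
        io_backend=dict(type='disk'),
        mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5],
        out_size=512, use_hflip=True,
        blur_kernel_size=41, kernel_list=['iso', 'aniso'],
        kernel_prob=[0.5, 0.5], blur_sigma=[0.1, 10],
        downsample_range=[0.8, 8], noise_range=[0, 20],
        jpeg_range=[60, 100],
    )
    dataset = FFHQDegradationDataset(opt)
    sample = dataset[0]   # {'lq': ..., 'gt': ..., 'gt_path': ...}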
30 | """ 31 | 32 | def __init__(self, opt): 33 | super(FFHQDegradationDataset, self).__init__() 34 | self.opt = opt 35 | # file client (io backend) 36 | self.file_client = None 37 | self.io_backend_opt = opt['io_backend'] 38 | 39 | self.gt_folder = opt['dataroot_gt'] 40 | self.mean = opt['mean'] 41 | self.std = opt['std'] 42 | self.out_size = opt['out_size'] 43 | 44 | self.crop_components = opt.get('crop_components', False) # facial components 45 | self.eye_enlarge_ratio = opt.get('eye_enlarge_ratio', 1) # whether enlarge eye regions 46 | 47 | if self.crop_components: 48 | # load component list from a pre-process pth files 49 | self.components_list = torch.load(opt.get('component_path')) 50 | 51 | # file client (lmdb io backend) 52 | if self.io_backend_opt['type'] == 'lmdb': 53 | self.io_backend_opt['db_paths'] = self.gt_folder 54 | if not self.gt_folder.endswith('.lmdb'): 55 | raise ValueError(f"'dataroot_gt' should end with '.lmdb', but received {self.gt_folder}") 56 | with open(osp.join(self.gt_folder, 'meta_info.txt')) as fin: 57 | self.paths = [line.split('.')[0] for line in fin] 58 | else: 59 | # disk backend: scan file list from a folder 60 | self.paths = paths_from_folder(self.gt_folder) 61 | 62 | # degradation configurations 63 | self.blur_kernel_size = opt['blur_kernel_size'] 64 | self.kernel_list = opt['kernel_list'] 65 | self.kernel_prob = opt['kernel_prob'] 66 | self.blur_sigma = opt['blur_sigma'] 67 | self.downsample_range = opt['downsample_range'] 68 | self.noise_range = opt['noise_range'] 69 | self.jpeg_range = opt['jpeg_range'] 70 | 71 | # color jitter 72 | self.color_jitter_prob = opt.get('color_jitter_prob') 73 | self.color_jitter_pt_prob = opt.get('color_jitter_pt_prob') 74 | self.color_jitter_shift = opt.get('color_jitter_shift', 20) 75 | # to gray 76 | self.gray_prob = opt.get('gray_prob') 77 | 78 | logger = get_root_logger() 79 | logger.info(f'Blur: blur_kernel_size {self.blur_kernel_size}, sigma: [{", ".join(map(str, self.blur_sigma))}]') 80 | logger.info(f'Downsample: downsample_range [{", ".join(map(str, self.downsample_range))}]') 81 | logger.info(f'Noise: [{", ".join(map(str, self.noise_range))}]') 82 | logger.info(f'JPEG compression: [{", ".join(map(str, self.jpeg_range))}]') 83 | 84 | if self.color_jitter_prob is not None: 85 | logger.info(f'Use random color jitter. Prob: {self.color_jitter_prob}, shift: {self.color_jitter_shift}') 86 | if self.gray_prob is not None: 87 | logger.info(f'Use random gray. Prob: {self.gray_prob}') 88 | self.color_jitter_shift /= 255. 
89 | 90 | @staticmethod 91 | def color_jitter(img, shift): 92 | """jitter color: randomly jitter the RGB values, in numpy formats""" 93 | jitter_val = np.random.uniform(-shift, shift, 3).astype(np.float32) 94 | img = img + jitter_val 95 | img = np.clip(img, 0, 1) 96 | return img 97 | 98 | @staticmethod 99 | def color_jitter_pt(img, brightness, contrast, saturation, hue): 100 | """jitter color: randomly jitter the brightness, contrast, saturation, and hue, in torch Tensor formats""" 101 | fn_idx = torch.randperm(4) 102 | for fn_id in fn_idx: 103 | if fn_id == 0 and brightness is not None: 104 | brightness_factor = torch.tensor(1.0).uniform_(brightness[0], brightness[1]).item() 105 | img = adjust_brightness(img, brightness_factor) 106 | 107 | if fn_id == 1 and contrast is not None: 108 | contrast_factor = torch.tensor(1.0).uniform_(contrast[0], contrast[1]).item() 109 | img = adjust_contrast(img, contrast_factor) 110 | 111 | if fn_id == 2 and saturation is not None: 112 | saturation_factor = torch.tensor(1.0).uniform_(saturation[0], saturation[1]).item() 113 | img = adjust_saturation(img, saturation_factor) 114 | 115 | if fn_id == 3 and hue is not None: 116 | hue_factor = torch.tensor(1.0).uniform_(hue[0], hue[1]).item() 117 | img = adjust_hue(img, hue_factor) 118 | return img 119 | 120 | def get_component_coordinates(self, index, status): 121 | """Get facial component (left_eye, right_eye, mouth) coordinates from a pre-loaded pth file""" 122 | components_bbox = self.components_list[f'{index:08d}'] 123 | if status[0]: # hflip 124 | # exchange right and left eye 125 | tmp = components_bbox['left_eye'] 126 | components_bbox['left_eye'] = components_bbox['right_eye'] 127 | components_bbox['right_eye'] = tmp 128 | # modify the width coordinate 129 | components_bbox['left_eye'][0] = self.out_size - components_bbox['left_eye'][0] 130 | components_bbox['right_eye'][0] = self.out_size - components_bbox['right_eye'][0] 131 | components_bbox['mouth'][0] = self.out_size - components_bbox['mouth'][0] 132 | 133 | # get coordinates 134 | locations = [] 135 | for part in ['left_eye', 'right_eye', 'mouth']: 136 | mean = components_bbox[part][0:2] 137 | half_len = components_bbox[part][2] 138 | if 'eye' in part: 139 | half_len *= self.eye_enlarge_ratio 140 | loc = np.hstack((mean - half_len + 1, mean + half_len)) 141 | loc = torch.from_numpy(loc).float() 142 | locations.append(loc) 143 | return locations 144 | 145 | def __getitem__(self, index): 146 | if self.file_client is None: 147 | self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt) 148 | 149 | # load gt image 150 | # Shape: (h, w, c); channel order: BGR; image range: [0, 1], float32. 
151 | gt_path = self.paths[index] 152 | img_bytes = self.file_client.get(gt_path) 153 | img_gt = imfrombytes(img_bytes, float32=True) 154 | 155 | # random horizontal flip 156 | img_gt, status = augment(img_gt, hflip=self.opt['use_hflip'], rotation=False, return_status=True) 157 | h, w, _ = img_gt.shape 158 | 159 | # get facial component coordinates 160 | if self.crop_components: 161 | locations = self.get_component_coordinates(index, status) 162 | loc_left_eye, loc_right_eye, loc_mouth = locations 163 | 164 | # ------------------------ generate lq image ------------------------ # 165 | # blur 166 | kernel = degradations.random_mixed_kernels( 167 | self.kernel_list, 168 | self.kernel_prob, 169 | self.blur_kernel_size, 170 | self.blur_sigma, 171 | self.blur_sigma, [-math.pi, math.pi], 172 | noise_range=None) 173 | img_lq = cv2.filter2D(img_gt, -1, kernel) 174 | # downsample 175 | scale = np.random.uniform(self.downsample_range[0], self.downsample_range[1]) 176 | img_lq = cv2.resize(img_lq, (int(w // scale), int(h // scale)), interpolation=cv2.INTER_LINEAR) 177 | # noise 178 | if self.noise_range is not None: 179 | img_lq = degradations.random_add_gaussian_noise(img_lq, self.noise_range) 180 | # jpeg compression 181 | if self.jpeg_range is not None: 182 | img_lq = degradations.random_add_jpg_compression(img_lq, self.jpeg_range) 183 | 184 | # resize to original size 185 | img_lq = cv2.resize(img_lq, (w, h), interpolation=cv2.INTER_LINEAR) 186 | 187 | # random color jitter (only for lq) 188 | if self.color_jitter_prob is not None and (np.random.uniform() < self.color_jitter_prob): 189 | img_lq = self.color_jitter(img_lq, self.color_jitter_shift) 190 | # random to gray (only for lq) 191 | if self.gray_prob and np.random.uniform() < self.gray_prob: 192 | img_lq = cv2.cvtColor(img_lq, cv2.COLOR_BGR2GRAY) 193 | img_lq = np.tile(img_lq[:, :, None], [1, 1, 3]) 194 | if self.opt.get('gt_gray'): # whether convert GT to gray images 195 | img_gt = cv2.cvtColor(img_gt, cv2.COLOR_BGR2GRAY) 196 | img_gt = np.tile(img_gt[:, :, None], [1, 1, 3]) # repeat the color channels 197 | 198 | # BGR to RGB, HWC to CHW, numpy to tensor 199 | img_gt, img_lq = img2tensor([img_gt, img_lq], bgr2rgb=True, float32=True) 200 | 201 | # random color jitter (pytorch version) (only for lq) 202 | if self.color_jitter_pt_prob is not None and (np.random.uniform() < self.color_jitter_pt_prob): 203 | brightness = self.opt.get('brightness', (0.5, 1.5)) 204 | contrast = self.opt.get('contrast', (0.5, 1.5)) 205 | saturation = self.opt.get('saturation', (0, 1.5)) 206 | hue = self.opt.get('hue', (-0.1, 0.1)) 207 | img_lq = self.color_jitter_pt(img_lq, brightness, contrast, saturation, hue) 208 | 209 | # round and clip 210 | img_lq = torch.clamp((img_lq * 255.0).round(), 0, 255) / 255. 
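# (editor's note) the round()/clamp() round-trip through the 0-255 range
# quantizes the LQ tensor to 8-bit levels, so it matches what a real
# saved-and-reloaded degraded image would contain before normalization.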
211 | 212 | # normalize 213 | normalize(img_gt, self.mean, self.std, inplace=True) 214 | normalize(img_lq, self.mean, self.std, inplace=True) 215 | 216 | if self.crop_components: 217 | return_dict = { 218 | 'lq': img_lq, 219 | 'gt': img_gt, 220 | 'gt_path': gt_path, 221 | 'loc_left_eye': loc_left_eye, 222 | 'loc_right_eye': loc_right_eye, 223 | 'loc_mouth': loc_mouth 224 | } 225 | return return_dict 226 | else: 227 | return {'lq': img_lq, 'gt': img_gt, 'gt_path': gt_path} 228 | 229 | def __len__(self): 230 | return len(self.paths) 231 | -------------------------------------------------------------------------------- /gfpgan/models/__init__.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | from basicsr.utils import scandir 3 | from os import path as osp 4 | 5 | # automatically scan and import model modules for registry 6 | # scan all the files that end with '_model.py' under the model folder 7 | model_folder = osp.dirname(osp.abspath(__file__)) 8 | model_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(model_folder) if v.endswith('_model.py')] 9 | # import all the model modules 10 | _model_modules = [importlib.import_module(f'gfpgan.models.{file_name}') for file_name in model_filenames] 11 | -------------------------------------------------------------------------------- /gfpgan/train.py: -------------------------------------------------------------------------------- 1 | # flake8: noqa 2 | import os.path as osp 3 | from basicsr.train import train_pipeline 4 | 5 | import gfpgan.archs 6 | import gfpgan.data 7 | import gfpgan.models 8 | 9 | if __name__ == '__main__': 10 | root_path = osp.abspath(osp.join(__file__, osp.pardir, osp.pardir)) 11 | train_pipeline(root_path) 12 | -------------------------------------------------------------------------------- /gfpgan/utils.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import os 3 | import torch 4 | from basicsr.utils import img2tensor, tensor2img 5 | from basicsr.utils.download_util import load_file_from_url 6 | from facexlib.utils.face_restoration_helper import FaceRestoreHelper 7 | from torchvision.transforms.functional import normalize 8 | 9 | from gfpgan.archs.gfpgan_bilinear_arch import GFPGANBilinear 10 | from gfpgan.archs.gfpganv1_arch import GFPGANv1 11 | from gfpgan.archs.gfpganv1_clean_arch import GFPGANv1Clean 12 | 13 | ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) 14 | 15 | 16 | class GFPGANer(): 17 | """Helper for restoration with GFPGAN. 18 | 19 | It will detect and crop faces, and then resize the faces to 512x512. 20 | GFPGAN is used to restored the resized faces. 21 | The background is upsampled with the bg_upsampler. 22 | Finally, the faces will be pasted back to the upsample background image. 23 | 24 | Args: 25 | model_path (str): The path to the GFPGAN model. It can be urls (will first download it automatically). 26 | upscale (float): The upscale of the final output. Default: 2. 27 | arch (str): The GFPGAN architecture. Option: clean | original. Default: clean. 28 | channel_multiplier (int): Channel multiplier for large networks of StyleGAN2. Default: 2. 29 | bg_upsampler (nn.Module): The upsampler for the background. Default: None. 
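Example (editor's sketch, not part of the original docstring; the weight file
name matches the one fetched by the install scripts, and 'face.jpg' ships with
the repository):

    import cv2
    restorer = GFPGANer(model_path='GFPGANv1.4.pth', upscale=1, arch='clean',
                        channel_multiplier=2, bg_upsampler=None)
    img = cv2.imread('face.jpg')
    cropped_faces, restored_faces, restored_img = restorer.enhance(img, paste_back=True)
    cv2.imwrite('face_restored.jpg', restored_img)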
30 | """ 31 | 32 | def __init__(self, model_path, upscale=2, arch='clean', channel_multiplier=2, bg_upsampler=None, device=None): 33 | self.upscale = upscale 34 | self.bg_upsampler = bg_upsampler 35 | 36 | # initialize model 37 | self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') if device is None else device 38 | # initialize the GFP-GAN 39 | if arch == 'clean': 40 | self.gfpgan = GFPGANv1Clean( 41 | out_size=512, 42 | num_style_feat=512, 43 | channel_multiplier=channel_multiplier, 44 | decoder_load_path=None, 45 | fix_decoder=False, 46 | num_mlp=8, 47 | input_is_latent=True, 48 | different_w=True, 49 | narrow=1, 50 | sft_half=True) 51 | elif arch == 'bilinear': 52 | self.gfpgan = GFPGANBilinear( 53 | out_size=512, 54 | num_style_feat=512, 55 | channel_multiplier=channel_multiplier, 56 | decoder_load_path=None, 57 | fix_decoder=False, 58 | num_mlp=8, 59 | input_is_latent=True, 60 | different_w=True, 61 | narrow=1, 62 | sft_half=True) 63 | elif arch == 'original': 64 | self.gfpgan = GFPGANv1( 65 | out_size=512, 66 | num_style_feat=512, 67 | channel_multiplier=channel_multiplier, 68 | decoder_load_path=None, 69 | fix_decoder=True, 70 | num_mlp=8, 71 | input_is_latent=True, 72 | different_w=True, 73 | narrow=1, 74 | sft_half=True) 75 | elif arch == 'RestoreFormer': 76 | from gfpgan.archs.restoreformer_arch import RestoreFormer 77 | self.gfpgan = RestoreFormer() 78 | # initialize face helper 79 | self.face_helper = FaceRestoreHelper( 80 | upscale, 81 | face_size=512, 82 | crop_ratio=(1, 1), 83 | det_model='retinaface_resnet50', 84 | save_ext='png', 85 | use_parse=True, 86 | device=self.device, 87 | model_rootpath='gfpgan/weights') 88 | 89 | if model_path.startswith('https://'): 90 | model_path = load_file_from_url( 91 | url=model_path, model_dir=os.path.join(ROOT_DIR, 'gfpgan/weights'), progress=True, file_name=None) 92 | loadnet = torch.load(model_path) 93 | if 'params_ema' in loadnet: 94 | keyname = 'params_ema' 95 | else: 96 | keyname = 'params' 97 | self.gfpgan.load_state_dict(loadnet[keyname], strict=True) 98 | self.gfpgan.eval() 99 | self.gfpgan = self.gfpgan.to(self.device) 100 | 101 | @torch.no_grad() 102 | def enhance(self, img, has_aligned=False, only_center_face=False, paste_back=True, weight=0.5): 103 | self.face_helper.clean_all() 104 | 105 | if has_aligned: # the inputs are already aligned 106 | img = cv2.resize(img, (512, 512)) 107 | self.face_helper.cropped_faces = [img] 108 | else: 109 | self.face_helper.read_image(img) 110 | # get face landmarks for each face 111 | self.face_helper.get_face_landmarks_5(only_center_face=only_center_face, eye_dist_threshold=5) 112 | # eye_dist_threshold=5: skip faces whose eye distance is smaller than 5 pixels 113 | # TODO: even with eye_dist_threshold, it will still introduce wrong detections and restorations. 
114 | # align and warp each face 115 | self.face_helper.align_warp_face() 116 | 117 | # face restoration 118 | for cropped_face in self.face_helper.cropped_faces: 119 | # prepare data 120 | cropped_face_t = img2tensor(cropped_face / 255., bgr2rgb=True, float32=True) 121 | normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True) 122 | cropped_face_t = cropped_face_t.unsqueeze(0).to(self.device) 123 | 124 | try: 125 | output = self.gfpgan(cropped_face_t, return_rgb=False, weight=weight)[0] 126 | # convert to image 127 | restored_face = tensor2img(output.squeeze(0), rgb2bgr=True, min_max=(-1, 1)) 128 | except RuntimeError as error: 129 | print(f'\tFailed inference for GFPGAN: {error}.') 130 | restored_face = cropped_face 131 | 132 | restored_face = restored_face.astype('uint8') 133 | self.face_helper.add_restored_face(restored_face) 134 | 135 | if not has_aligned and paste_back: 136 | # upsample the background 137 | if self.bg_upsampler is not None: 138 | # Now only support RealESRGAN for upsampling background 139 | bg_img = self.bg_upsampler.enhance(img, outscale=self.upscale)[0] 140 | else: 141 | bg_img = None 142 | 143 | self.face_helper.get_inverse_affine(None) 144 | # paste each restored face to the input image 145 | restored_img = self.face_helper.paste_faces_to_input_image(upsample_img=bg_img) 146 | return self.face_helper.cropped_faces, self.face_helper.restored_faces, restored_img 147 | else: 148 | return self.face_helper.cropped_faces, self.face_helper.restored_faces, None 149 | -------------------------------------------------------------------------------- /gfpgan/weights/README.md: -------------------------------------------------------------------------------- 1 | # Weights 2 | 3 | Put the downloaded weights to this folder. 4 | -------------------------------------------------------------------------------- /globalsz.py: -------------------------------------------------------------------------------- 1 | import threading 2 | # config 3 | select_face_swapper_gpu = [0] #None for all gpus. 
Or make it a list, like [0, 1, 2] to select gpus to use 4 | select_gfpgan_gpu = 0 #supports only 1 gpu for now 5 | select_realesrgan_gpu = 0 # supports only 1 gpu for now 6 | realeasrgan_model = "RealESRGAN_x2plus" 7 | realesrgan_fp16 = False 8 | realesrgan_outscale = 2 9 | 10 | 11 | # used by the program 12 | lowmem = True 13 | generator = None 14 | restorer = None 15 | gfpgan_onnx_model = None 16 | realeasrgan_enhancer = None 17 | THREAD_SEMAPHORE = threading.Semaphore() 18 | realesrgan_lock = threading.Semaphore() 19 | cuda = True # just debugging I think 20 | source_face = None -------------------------------------------------------------------------------- /install_directml_windows.cmd: -------------------------------------------------------------------------------- 1 | python -m pip install virtualenv 2 | python -m virtualenv venv 3 | call venv\scripts\activate.bat 4 | curl --location --remote-header-name --remote-name https://github.com/RichardErkhov/FastFaceSwap/releases/download/model/inswapper_128.onnx 5 | curl --location --remote-header-name --remote-name https://github.com/RichardErkhov/FastFaceSwap/releases/download/model/GFPGANv1.4.pth 6 | curl --location --remote-header-name --remote-name https://github.com/RichardErkhov/FastFaceSwap/releases/download/model/complex_256_v7_stage3_12999.h5 7 | curl --location --remote-header-name --remote-name https://github.com/RichardErkhov/FastFaceSwap/releases/download/model/GFPGANv1.4.onnx 8 | pip install torch-directml 9 | pip install -r requirements.txt 10 | pip install tensorflow-directml-plugin 11 | pip uninstall onnxruntime-gpu onnxruntime -q -y 12 | pip install onnxruntime-directml 13 | pip install protobuf==3.20.2 14 | pip uninstall opencv-python opencv-headless-python opencv-contrib-python -q -y 15 | pip install opencv-python 16 | -------------------------------------------------------------------------------- /install_linux.cmd: -------------------------------------------------------------------------------- 1 | python -m pip install virtualenv 2 | python -m virtualenv venv 3 | call venv/bin/activate.bat 4 | wget https://github.com/RichardErkhov/FastFaceSwap/releases/download/model/inswapper_128.onnx 5 | wget https://github.com/RichardErkhov/FastFaceSwap/releases/download/model/GFPGANv1.4.pth 6 | wget https://github.com/RichardErkhov/FastFaceSwap/releases/download/model/complex_256_v7_stage3_12999.h5 7 | wget https://github.com/RichardErkhov/FastFaceSwap/releases/download/model/GFPGANv1.4.onnx 8 | pip install torch torchvision torchaudio 9 | pip install onnxruntime-gpu 10 | pip install -r requirements.txt 11 | pip install tensorflow-gpu==2.10.1 12 | pip install protobuf==3.20.2 13 | pip uninstall opencv-python opencv-headless-python opencv-contrib-python -q -y 14 | pip install opencv-python 15 | -------------------------------------------------------------------------------- /install_linux.sh: -------------------------------------------------------------------------------- 1 | python -m pip install virtualenv 2 | python -m virtualenv venv 3 | source venv/bin/activate 4 | wget https://github.com/RichardErkhov/FastFaceSwap/releases/download/model/inswapper_128.onnx 5 | wget https://github.com/RichardErkhov/FastFaceSwap/releases/download/model/GFPGANv1.4.pth 6 | wget https://github.com/RichardErkhov/FastFaceSwap/releases/download/model/complex_256_v7_stage3_12999.h5 7 | wget https://github.com/RichardErkhov/FastFaceSwap/releases/download/model/GFPGANv1.4.onnx 8 | pip install torch torchvision torchaudio 9 | pip install -r 
requirements.txt 10 | pip install tensorflow-gpu==2.10.1 11 | pip install protobuf==3.20.2 12 | pip uninstall opencv-python opencv-headless-python opencv-contrib-python -q -y 13 | pip install opencv-python -------------------------------------------------------------------------------- /install_mac.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | python -m pip install virtualenv 3 | python -m virtualenv venv 4 | source venv/bin/activate 5 | curl -LO https://github.com/RichardErkhov/FastFaceSwap/releases/download/model/inswapper_128.onnx 6 | curl -LO https://github.com/RichardErkhov/FastFaceSwap/releases/download/model/GFPGANv1.4.pth 7 | curl -LO https://github.com/RichardErkhov/FastFaceSwap/releases/download/model/complex_256_v7_stage3_12999.h5 8 | curl -LO https://github.com/RichardErkhov/FastFaceSwap/releases/download/model/GFPGANv1.4.onnx 9 | pip install torch torchvision torchaudio 10 | pip install onnxruntime-silicon 11 | pip install -r requirements.txt 12 | pip install tensorflow-gpu==2.12 13 | pip install protobuf==3.20.2 14 | pip uninstall opencv-python opencv-headless-python opencv-contrib-python -q -y 15 | pip install opencv-python 16 | -------------------------------------------------------------------------------- /install_termux.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | python -m pip install virtualenv 3 | python -m virtualenv venv 4 | source venv/bin/activate 5 | wget https://github.com/RichardErkhov/FastFaceSwap/releases/download/model/inswapper_128.onnx 6 | wget https://github.com/RichardErkhov/FastFaceSwap/releases/download/model/GFPGANv1.4.pth 7 | wget https://github.com/RichardErkhov/FastFaceSwap/releases/download/model/complex_256_v7_stage3_12999.h5 8 | wget https://github.com/RichardErkhov/FastFaceSwap/releases/download/model/GFPGANv1.4.onnx 9 | pip install torch torchvision torchaudio 10 | pip install onnxruntime 11 | pip install -r requirements.txt 12 | pip install tensorflow==2.10.1 13 | pip install protobuf==3.20.2 14 | pip uninstall opencv-python opencv-headless-python opencv-contrib-python -q -y 15 | pip install opencv-python 16 | -------------------------------------------------------------------------------- /install_windows.cmd: -------------------------------------------------------------------------------- 1 | python -m pip install virtualenv 2 | python -m virtualenv venv 3 | call venv\scripts\activate.bat 4 | curl --location --remote-header-name --remote-name https://github.com/RichardErkhov/FastFaceSwap/releases/download/model/inswapper_128.onnx 5 | curl --location --remote-header-name --remote-name https://github.com/RichardErkhov/FastFaceSwap/releases/download/model/GFPGANv1.4.pth 6 | curl --location --remote-header-name --remote-name https://github.com/RichardErkhov/FastFaceSwap/releases/download/model/complex_256_v7_stage3_12999.h5 7 | curl --location --remote-header-name --remote-name https://github.com/RichardErkhov/FastFaceSwap/releases/download/model/GFPGANv1.4.onnx 8 | pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu117 9 | pip install onnxruntime-gpu 10 | pip install -r requirements.txt 11 | pip install tensorflow-gpu==2.10.1 12 | pip install protobuf==3.20.2 13 | 14 | pip uninstall opencv-python opencv-headless-python opencv-contrib-python -q -y 15 | pip install opencv-python 16 | -------------------------------------------------------------------------------- /ll.pkl: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/RichardErkhov/FastFaceSwap/be10e8d0110678b955bd235a06be0f934f567890/ll.pkl -------------------------------------------------------------------------------- /plugins/__pycache__/codeformer_app_cv2.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RichardErkhov/FastFaceSwap/be10e8d0110678b955bd235a06be0f934f567890/plugins/__pycache__/codeformer_app_cv2.cpython-310.pyc -------------------------------------------------------------------------------- /plugins/__pycache__/codeformer_face_helper_cv2.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RichardErkhov/FastFaceSwap/be10e8d0110678b955bd235a06be0f934f567890/plugins/__pycache__/codeformer_face_helper_cv2.cpython-310.pyc -------------------------------------------------------------------------------- /plugins/codeformer_app_cv2.py: -------------------------------------------------------------------------------- 1 | """ 2 | Modified version from codeformer-pip project 3 | 4 | S-Lab License 1.0 5 | 6 | Copyright 2022 S-Lab 7 | 8 | https://github.com/kadirnar/codeformer-pip/blob/main/LICENSE 9 | """ 10 | 11 | import os 12 | 13 | import cv2 14 | import torch 15 | from codeformer.facelib.detection import init_detection_model 16 | from codeformer.facelib.parsing import init_parsing_model 17 | from torchvision.transforms.functional import normalize 18 | 19 | from codeformer.basicsr.archs.rrdbnet_arch import RRDBNet 20 | from codeformer.basicsr.utils import img2tensor, imwrite, tensor2img 21 | from codeformer.basicsr.utils.download_util import load_file_from_url 22 | from codeformer.basicsr.utils.realesrgan_utils import RealESRGANer 23 | from codeformer.basicsr.utils.registry import ARCH_REGISTRY 24 | from codeformer.facelib.utils.face_restoration_helper import FaceRestoreHelper 25 | from codeformer.facelib.utils.misc import is_gray 26 | import threading 27 | 28 | from plugins.codeformer_face_helper_cv2 import FaceRestoreHelperOptimized 29 | 30 | THREAD_LOCK_FACE_HELPER = threading.Lock() 31 | THREAD_LOCK_FACE_HELPER_CREATE = threading.Lock() 32 | THREAD_LOCK_FACE_HELPER_PROCERSSING = threading.Lock() 33 | THREAD_LOCK_CODEFORMER_NET = threading.Lock() 34 | THREAD_LOCK_CODEFORMER_NET_CREATE = threading.Lock() 35 | THREAD_LOCK_BGUPSAMPLER = threading.Lock() 36 | 37 | pretrain_model_url = { 38 | "codeformer": "https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth", 39 | "detection": "https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/detection_Resnet50_Final.pth", 40 | "parsing": "https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/parsing_parsenet.pth", 41 | "realesrgan": "https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/RealESRGAN_x2plus.pth", 42 | } 43 | 44 | # download weights 45 | if not os.path.exists("CodeFormer/weights/CodeFormer/codeformer.pth"): 46 | load_file_from_url( 47 | url=pretrain_model_url["codeformer"], model_dir="CodeFormer/weights/CodeFormer", progress=True, file_name=None 48 | ) 49 | if not os.path.exists("CodeFormer/weights/facelib/detection_Resnet50_Final.pth"): 50 | load_file_from_url( 51 | url=pretrain_model_url["detection"], model_dir="CodeFormer/weights/facelib", progress=True, file_name=None 52 | ) 53 | if not os.path.exists("CodeFormer/weights/facelib/parsing_parsenet.pth"): 54 | 
load_file_from_url( 55 | url=pretrain_model_url["parsing"], model_dir="CodeFormer/weights/facelib", progress=True, file_name=None 56 | ) 57 | if not os.path.exists("CodeFormer/weights/realesrgan/RealESRGAN_x2plus.pth"): 58 | load_file_from_url( 59 | url=pretrain_model_url["realesrgan"], model_dir="CodeFormer/weights/realesrgan", progress=True, file_name=None 60 | ) 61 | 62 | 63 | def imread(img_path): 64 | img = cv2.imread(img_path) 65 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) 66 | return img 67 | 68 | 69 | # set enhancer with RealESRGAN 70 | def set_realesrgan(): 71 | half = True if torch.cuda.is_available() else False 72 | model = RRDBNet( 73 | num_in_ch=3, 74 | num_out_ch=3, 75 | num_feat=64, 76 | num_block=23, 77 | num_grow_ch=32, 78 | scale=2, 79 | ) 80 | upsampler = RealESRGANer( 81 | scale=2, 82 | model_path="CodeFormer/weights/realesrgan/RealESRGAN_x2plus.pth", 83 | model=model, 84 | tile=400, 85 | tile_pad=40, 86 | pre_pad=0, 87 | half=half, 88 | ) 89 | return upsampler 90 | 91 | 92 | upsampler = set_realesrgan() 93 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 94 | 95 | codeformers_cache = [] 96 | 97 | def get_codeformer(): 98 | if len(codeformers_cache) > 0: 99 | with THREAD_LOCK_CODEFORMER_NET: 100 | if len(codeformers_cache) > 0: 101 | return codeformers_cache.pop() 102 | 103 | with THREAD_LOCK_CODEFORMER_NET_CREATE: 104 | codeformer_net = ARCH_REGISTRY.get("CodeFormer")( 105 | dim_embd=512, 106 | codebook_size=1024, 107 | n_head=8, 108 | n_layers=9, 109 | connect_list=["32", "64", "128", "256"], 110 | ).to(device) 111 | ckpt_path = "CodeFormer/weights/CodeFormer/codeformer.pth" 112 | checkpoint = torch.load(ckpt_path)["params_ema"] 113 | codeformer_net.load_state_dict(checkpoint) 114 | codeformer_net.eval() 115 | return codeformer_net 116 | 117 | 118 | 119 | def release_codeformer(codeformer): 120 | with THREAD_LOCK_CODEFORMER_NET: 121 | codeformers_cache.append(codeformer) 122 | 123 | #os.makedirs("output", exist_ok=True) 124 | 125 | # ------- face restore thread cache ---------- 126 | 127 | face_restore_helper_cache = [] 128 | 129 | detection_model = "retinaface_resnet50" 130 | 131 | inited_face_restore_helper_nn = False 132 | 133 | import time 134 | 135 | def get_face_restore_helper(upscale): 136 | global inited_face_restore_helper_nn 137 | with THREAD_LOCK_FACE_HELPER: 138 | face_helper = FaceRestoreHelperOptimized( 139 | upscale, 140 | face_size=512, 141 | crop_ratio=(1, 1), 142 | det_model=detection_model, 143 | save_ext="png", 144 | use_parse=True, 145 | device=device, 146 | ) 147 | #return face_helper 148 | 149 | if inited_face_restore_helper_nn: 150 | while len(face_restore_helper_cache) == 0: 151 | time.sleep(0.05) 152 | face_detector, face_parse = face_restore_helper_cache.pop() 153 | face_helper.face_detector = face_detector 154 | face_helper.face_parse = face_parse 155 | return face_helper 156 | else: 157 | inited_face_restore_helper_nn = True 158 | face_helper.face_detector = init_detection_model(detection_model, half=False, device=face_helper.device) 159 | face_helper.face_parse = init_parsing_model(model_name="parsenet", device=face_helper.device) 160 | return face_helper 161 | 162 | def get_face_restore_helper2(upscale): # still not work well!!! 
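    # NOTE: a non-blocking variant of get_face_restore_helper above; when the cache is empty it builds a fresh detection/parsing model instead of waiting, so several copies may end up allocated under heavy threading (likely why it is flagged as not working well)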
163 | face_helper = FaceRestoreHelperOptimized( 164 | upscale, 165 | face_size=512, 166 | crop_ratio=(1, 1), 167 | det_model=detection_model, 168 | save_ext="png", 169 | use_parse=True, 170 | device=device, 171 | ) 172 | #return face_helper 173 | 174 | if len(face_restore_helper_cache) > 0: 175 | with THREAD_LOCK_FACE_HELPER: 176 | if len(face_restore_helper_cache) > 0: 177 | face_detector, face_parse = face_restore_helper_cache.pop() 178 | face_helper.face_detector = face_detector 179 | face_helper.face_parse = face_parse 180 | return face_helper 181 | 182 | with THREAD_LOCK_FACE_HELPER_CREATE: 183 | face_helper.face_detector = init_detection_model(detection_model, half=False, device=face_helper.device) 184 | face_helper.face_parse = init_parsing_model(model_name="parsenet", device=face_helper.device) 185 | return face_helper 186 | 187 | def release_face_restore_helper(face_helper): 188 | #return 189 | #with THREAD_LOCK_FACE_HELPER: 190 | face_restore_helper_cache.append((face_helper.face_detector, face_helper.face_parse)) 191 | #pass 192 | 193 | def inference_app(image, background_enhance, face_upsample, upscale, codeformer_fidelity, skip_if_no_face = False): 194 | # take the default setting for the demo 195 | has_aligned = False 196 | only_center_face = False 197 | draw_box = False 198 | 199 | #print("Inp:", image, background_enhance, face_upsample, upscale, codeformer_fidelity) 200 | if isinstance(image, str): 201 | img = cv2.imread(str(image), cv2.IMREAD_COLOR) 202 | else: 203 | img = image 204 | #print("\timage size:", img.shape) 205 | 206 | upscale = int(upscale) # convert type to int 207 | if upscale > 4: # avoid memory exceeded due to too large upscale 208 | upscale = 4 209 | if upscale > 2 and max(img.shape[:2]) > 1000: # avoid memory exceeded due to too large img resolution 210 | upscale = 2 211 | if max(img.shape[:2]) > 1500: # avoid memory exceeded due to too large img resolution 212 | upscale = 1 213 | background_enhance = False 214 | #face_upsample = False 215 | 216 | face_helper = get_face_restore_helper(upscale) 217 | 218 | bg_upsampler = upsampler if background_enhance else None 219 | face_upsampler = upsampler if face_upsample else None 220 | 221 | if has_aligned: 222 | # the input faces are already cropped and aligned 223 | img = cv2.resize(img, (512, 512), interpolation=cv2.INTER_LINEAR) 224 | face_helper.is_gray = is_gray(img, threshold=5) 225 | if face_helper.is_gray: 226 | print("\tgrayscale input: True") 227 | face_helper.cropped_faces = [img] 228 | else: 229 | with THREAD_LOCK_FACE_HELPER_PROCERSSING: 230 | face_helper.read_image(img) 231 | # get face landmarks for each face 232 | 233 | num_det_faces = face_helper.get_face_landmarks_5( 234 | only_center_face=only_center_face, resize=640, eye_dist_threshold=5 235 | ) 236 | #print(f"\tdetect {num_det_faces} faces") 237 | 238 | if num_det_faces == 0 and skip_if_no_face: 239 | release_face_restore_helper(face_helper) 240 | return img 241 | 242 | # align and warp each face 243 | face_helper.align_warp_face() 244 | 245 | 246 | 247 | # face restoration for each cropped face 248 | for idx, cropped_face in enumerate(face_helper.cropped_faces): 249 | # prepare data 250 | cropped_face_t = img2tensor(cropped_face / 255.0, bgr2rgb=True, float32=True) 251 | normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True) 252 | cropped_face_t = cropped_face_t.unsqueeze(0).to(device) 253 | 254 | #with THREAD_LOCK_CODEFORMER_NET: 255 | codeformer_net = get_codeformer() 256 | try: 257 | with torch.no_grad(): 258 | output = 
codeformer_net(cropped_face_t, w=codeformer_fidelity, adain=True)[0] 259 | restored_face = tensor2img(output, rgb2bgr=True, min_max=(-1, 1)) 260 | del output 261 | #torch.cuda.empty_cache() 262 | except RuntimeError as error: 263 | print(f"Failed inference for CodeFormer: {error}") 264 | restored_face = tensor2img(cropped_face_t, rgb2bgr=True, min_max=(-1, 1)) 265 | release_codeformer(codeformer_net) 266 | 267 | restored_face = restored_face.astype("uint8") 268 | face_helper.add_restored_face(restored_face) 269 | 270 | # paste_back 271 | if not has_aligned: 272 | # upsample the background 273 | if bg_upsampler is not None: 274 | with THREAD_LOCK_BGUPSAMPLER: 275 | # Now only support RealESRGAN for upsampling background 276 | bg_img = bg_upsampler.enhance(img, outscale=upscale)[0] 277 | else: 278 | bg_img = None 279 | face_helper.get_inverse_affine(None) 280 | # paste each restored face to the input image 281 | if face_upsample and face_upsampler is not None: 282 | restored_img = face_helper.paste_faces_to_input_image( 283 | upsample_img=bg_img, 284 | draw_box=draw_box, 285 | face_upsampler=face_upsampler, 286 | ) 287 | else: 288 | restored_img = face_helper.paste_faces_to_input_image(upsample_img=bg_img, draw_box=draw_box) 289 | 290 | release_face_restore_helper(face_helper) 291 | # save restored img 292 | if isinstance(image, str): 293 | save_path = f"output/out.png" 294 | imwrite(restored_img, str(save_path)) 295 | return save_path 296 | else: 297 | return restored_img 298 | -------------------------------------------------------------------------------- /plugins/codeformer_face_helper_cv2.py: -------------------------------------------------------------------------------- 1 | from codeformer.facelib.utils.face_restoration_helper import FaceRestoreHelper 2 | 3 | import numpy as np 4 | from codeformer.basicsr.utils.misc import get_device 5 | 6 | class FaceRestoreHelperOptimized(FaceRestoreHelper): 7 | def __init__( 8 | self, 9 | upscale_factor, 10 | face_size=512, 11 | crop_ratio=(1, 1), 12 | det_model="retinaface_resnet50", 13 | save_ext="png", 14 | template_3points=False, 15 | pad_blur=False, 16 | use_parse=False, 17 | device=None, 18 | ): 19 | self.template_3points = template_3points # improve robustness 20 | self.upscale_factor = int(upscale_factor) 21 | # the cropped face ratio based on the square face 22 | self.crop_ratio = crop_ratio # (h, w) 23 | assert self.crop_ratio[0] >= 1 and self.crop_ratio[1] >= 1, "crop ration only supports >=1" 24 | self.face_size = (int(face_size * self.crop_ratio[1]), int(face_size * self.crop_ratio[0])) 25 | self.det_model = det_model 26 | 27 | if self.det_model == "dlib": 28 | # standard 5 landmarks for FFHQ faces with 1024 x 1024 29 | self.face_template = np.array( 30 | [ 31 | [686.77227723, 488.62376238], 32 | [586.77227723, 493.59405941], 33 | [337.91089109, 488.38613861], 34 | [437.95049505, 493.51485149], 35 | [513.58415842, 678.5049505], 36 | ] 37 | ) 38 | self.face_template = self.face_template / (1024 // face_size) 39 | elif self.template_3points: 40 | self.face_template = np.array([[192, 240], [319, 240], [257, 371]]) 41 | else: 42 | # standard 5 landmarks for FFHQ faces with 512 x 512 43 | # facexlib 44 | self.face_template = np.array( 45 | [ 46 | [192.98138, 239.94708], 47 | [318.90277, 240.1936], 48 | [256.63416, 314.01935], 49 | [201.26117, 371.41043], 50 | [313.08905, 371.15118], 51 | ] 52 | ) 53 | 54 | # dlib: left_eye: 36:41 right_eye: 42:47 nose: 30,32,33,34 left mouth corner: 48 right mouth corner: 54 55 | # self.face_template = 
np.array([[193.65928, 242.98541], [318.32558, 243.06108], [255.67984, 328.82894], 56 | # [198.22603, 372.82502], [313.91018, 372.75659]]) 57 | 58 | self.face_template = self.face_template * (face_size / 512.0) 59 | if self.crop_ratio[0] > 1: 60 | self.face_template[:, 1] += face_size * (self.crop_ratio[0] - 1) / 2 61 | if self.crop_ratio[1] > 1: 62 | self.face_template[:, 0] += face_size * (self.crop_ratio[1] - 1) / 2 63 | self.save_ext = save_ext 64 | self.pad_blur = pad_blur 65 | if self.pad_blur is True: 66 | self.template_3points = False 67 | 68 | self.all_landmarks_5 = [] 69 | self.det_faces = [] 70 | self.affine_matrices = [] 71 | self.inverse_affine_matrices = [] 72 | self.cropped_faces = [] 73 | self.restored_faces = [] 74 | self.pad_input_imgs = [] 75 | 76 | if device is None: 77 | # self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 78 | self.device = get_device() 79 | else: 80 | self.device = device 81 | 82 | # init face detection model 83 | # if self.det_model == "dlib": 84 | # self.face_detector, self.shape_predictor_5 = self.init_dlib( 85 | # dlib_model_url["face_detector"], dlib_model_url["shape_predictor_5"] 86 | # ) 87 | # else: 88 | # self.face_detector = init_detection_model(det_model, half=False, device=self.device) 89 | 90 | # init face parsing model 91 | self.use_parse = use_parse 92 | #self.face_parse = init_parsing_model(model_name="parsenet", device=self.device) 93 | 94 | # MUST set face_detector and face_parse!!! -------------------------------------------------------------------------------- /plugins/core.py: -------------------------------------------------------------------------------- 1 | # Core plugin 2 | # author: Vladislav Janvarev 3 | 4 | from chain_img_processor import ChainImgProcessor 5 | 6 | # start function 7 | def start(core:ChainImgProcessor): 8 | manifest = { 9 | "name": "Core plugin", 10 | "version": "2.0", 11 | 12 | "default_options": { 13 | "default_chain": "blur,to_grayscale", # default chain to run 14 | "init_on_start": "blur,to_grayscale", # init these processors on start 15 | "is_demo_row_render": False, 16 | }, 17 | 18 | } 19 | return manifest 20 | 21 | def start_with_options(core:ChainImgProcessor, manifest:dict): 22 | #print(manifest["options"]) 23 | options = manifest["options"] 24 | 25 | core.default_chain = options["default_chain"] 26 | core.init_on_start = options["init_on_start"] 27 | 28 | core.is_demo_row_render= options["is_demo_row_render"] 29 | 30 | return manifest 31 | -------------------------------------------------------------------------------- /plugins/core_video.py: -------------------------------------------------------------------------------- 1 | # Core plugin 2 | # author: Vladislav Janvarev 3 | 4 | from chain_img_processor import ChainImgProcessor, ChainVideoProcessor 5 | 6 | # start function 7 | def start(core:ChainImgProcessor): 8 | manifest = { 9 | "name": "Core video plugin", 10 | "version": "2.0", 11 | 12 | "default_options": { 13 | "video_save_codec": "libx264", # default codec to save 14 | "video_save_crf": 14, # default crf to save 15 | }, 16 | 17 | } 18 | return manifest 19 | 20 | def start_with_options(core:ChainVideoProcessor, manifest:dict): 21 | #print(manifest["options"]) 22 | options = manifest["options"] 23 | 24 | core.video_save_codec = options["video_save_codec"] 25 | core.video_save_crf = options["video_save_crf"] 26 | 27 | return manifest 28 | -------------------------------------------------------------------------------- /plugins/plugin_blur.py: 
-------------------------------------------------------------------------------- 1 | # Blur example filter 2 | # author: Vladislav Janvarev 3 | 4 | from chain_img_processor import ChainImgProcessor, ChainImgPlugin 5 | import os 6 | 7 | modname = os.path.basename(__file__)[:-3] # calculating modname 8 | 9 | # start function 10 | def start(core:ChainImgProcessor): 11 | manifest = { # plugin settings 12 | "name": "Blur filter", # name 13 | "version": "1.0", # version 14 | 15 | "default_options": { 16 | "power": 30, # 17 | }, 18 | 19 | "img_processor": { 20 | "blur": PluginBlur 21 | } 22 | } 23 | return manifest 24 | 25 | def start_with_options(core:ChainImgProcessor, manifest:dict): 26 | pass 27 | 28 | class PluginBlur(ChainImgPlugin): 29 | def init_plugin(self): 30 | pass 31 | def process(self, img, params:dict): 32 | # params can be used to transfer some img info to next processors 33 | import cv2 34 | options = self.core.plugin_options(modname) 35 | 36 | ksize = (int(options["power"]), int(options["power"])) 37 | 38 | # Using cv2.blur() method 39 | image = cv2.blur(img, ksize) 40 | 41 | return image 42 | -------------------------------------------------------------------------------- /plugins/plugin_codeformer.py: -------------------------------------------------------------------------------- 1 | # Codeformer enchance plugin 2 | # author: Vladislav Janvarev 3 | 4 | from chain_img_processor import ChainImgProcessor, ChainImgPlugin 5 | import os 6 | 7 | modname = os.path.basename(__file__)[:-3] # calculating modname 8 | 9 | # start function 10 | def start(core:ChainImgProcessor): 11 | manifest = { # plugin settings 12 | "name": "Codeformer", # name 13 | "version": "3.0", # version 14 | 15 | "default_options": { 16 | "background_enhance": True, # 17 | "face_upsample": True, # 18 | "upscale": 2, # 19 | "codeformer_fidelity": 0.8, 20 | "skip_if_no_face":False, 21 | 22 | }, 23 | 24 | "img_processor": { 25 | "codeformer": PluginCodeformer # 1 function - init, 2 - process 26 | } 27 | } 28 | return manifest 29 | 30 | def start_with_options(core:ChainImgProcessor, manifest:dict): 31 | pass 32 | 33 | class PluginCodeformer(ChainImgPlugin): 34 | def init_plugin(self): 35 | import plugins.codeformer_app_cv2 36 | pass 37 | 38 | def process(self, img, params:dict): 39 | # params can be used to transfer some img info to next processors 40 | from plugins.codeformer_app_cv2 import inference_app 41 | options = self.core.plugin_options(modname) 42 | 43 | image = inference_app(img, options.get("background_enhance"), options.get("face_upsample"), 44 | options.get("upscale"), options.get("codeformer_fidelity"), 45 | options.get("skip_if_no_face")) 46 | 47 | return image 48 | 49 | 50 | 51 | -------------------------------------------------------------------------------- /plugins/plugin_resize_cv2.py: -------------------------------------------------------------------------------- 1 | # Resize example filter 2 | # author: Vladislav Janvarev 3 | 4 | from chain_img_processor import ChainImgProcessor, ChainImgPlugin 5 | import os 6 | 7 | modname = os.path.basename(__file__)[:-3] # calculating modname 8 | 9 | # start function 10 | def start(core:ChainImgProcessor): 11 | manifest = { # plugin settings 12 | "name": "Resize filter", # name 13 | "version": "1.0", # version 14 | 15 | "default_options": { 16 | "scale": 2.0, # 17 | }, 18 | 19 | "img_processor": { 20 | "resize_cv2": PluginResizeCv2 # 1 function - init, 2 - process 21 | } 22 | } 23 | return manifest 24 | 25 | def start_with_options(core:ChainImgProcessor, 
manifest:dict): 26 | pass 27 | 28 | class PluginResizeCv2(ChainImgPlugin): 29 | def init_plugin(self): 30 | pass 31 | 32 | def process(self, img, params: dict): 33 | # params can be used to transfer some img info to next processors 34 | import cv2 35 | options = self.core.plugin_options(modname) 36 | 37 | scale = options["scale"] 38 | # cv.INTER_CUBIC 39 | 40 | image = cv2.resize(img, None, fx=scale, fy=scale) 41 | 42 | return image 43 | -------------------------------------------------------------------------------- /plugins/plugin_to_grayscale.py: -------------------------------------------------------------------------------- 1 | # To grayscale example filter 2 | # author: Vladislav Janvarev 3 | 4 | from chain_img_processor import ChainImgProcessor, ChainImgPlugin 5 | import os 6 | 7 | modname = os.path.basename(__file__)[:-3] # calculating modname 8 | 9 | # start function 10 | def start(core:ChainImgProcessor): 11 | manifest = { # plugin settings 12 | "name": "Gray scale filter", # name 13 | "version": "1.0", # version 14 | 15 | "img_processor": { 16 | "to_grayscale": PluginGrayscale # 1 function - init, 2 - process 17 | } 18 | } 19 | return manifest 20 | 21 | 22 | class PluginGrayscale(ChainImgPlugin): 23 | def init_plugin(self): 24 | pass 25 | 26 | def process(self, img, params: dict): 27 | # params can be used to transfer some img info to next processors 28 | import cv2 29 | image = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) 30 | 31 | # Duplicate the grayscale channel to all three color channels 32 | image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB) 33 | 34 | return image 35 | 36 | 37 | -------------------------------------------------------------------------------- /realesrgan/__init__.py: -------------------------------------------------------------------------------- 1 | # flake8: noqa 2 | from .archs import * 3 | from .data import * 4 | from .models import * 5 | from .utils import * 6 | #from .version import * 7 | -------------------------------------------------------------------------------- /realesrgan/archs/__init__.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | from basicsr.utils import scandir 3 | from os import path as osp 4 | 5 | # automatically scan and import arch modules for registry 6 | # scan all the files that end with '_arch.py' under the archs folder 7 | arch_folder = osp.dirname(osp.abspath(__file__)) 8 | arch_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(arch_folder) if v.endswith('_arch.py')] 9 | # import all the arch modules 10 | _arch_modules = [importlib.import_module(f'realesrgan.archs.{file_name}') for file_name in arch_filenames] 11 | -------------------------------------------------------------------------------- /realesrgan/archs/discriminator_arch.py: -------------------------------------------------------------------------------- 1 | from basicsr.utils.registry import ARCH_REGISTRY 2 | from torch import nn as nn 3 | from torch.nn import functional as F 4 | from torch.nn.utils import spectral_norm 5 | 6 | 7 | @ARCH_REGISTRY.register() 8 | class UNetDiscriminatorSN(nn.Module): 9 | """Defines a U-Net discriminator with spectral normalization (SN) 10 | 11 | It is used in Real-ESRGAN: Training Real-World Blind Super-Resolution with Pure Synthetic Data. 12 | 13 | Arg: 14 | num_in_ch (int): Channel number of inputs. Default: 3. 15 | num_feat (int): Channel number of base intermediate features. Default: 64. 16 | skip_connection (bool): Whether to use skip connections between U-Net. 
Default: True. 17 | """ 18 | 19 | def __init__(self, num_in_ch, num_feat=64, skip_connection=True): 20 | super(UNetDiscriminatorSN, self).__init__() 21 | self.skip_connection = skip_connection 22 | norm = spectral_norm 23 | # the first convolution 24 | self.conv0 = nn.Conv2d(num_in_ch, num_feat, kernel_size=3, stride=1, padding=1) 25 | # downsample 26 | self.conv1 = norm(nn.Conv2d(num_feat, num_feat * 2, 4, 2, 1, bias=False)) 27 | self.conv2 = norm(nn.Conv2d(num_feat * 2, num_feat * 4, 4, 2, 1, bias=False)) 28 | self.conv3 = norm(nn.Conv2d(num_feat * 4, num_feat * 8, 4, 2, 1, bias=False)) 29 | # upsample 30 | self.conv4 = norm(nn.Conv2d(num_feat * 8, num_feat * 4, 3, 1, 1, bias=False)) 31 | self.conv5 = norm(nn.Conv2d(num_feat * 4, num_feat * 2, 3, 1, 1, bias=False)) 32 | self.conv6 = norm(nn.Conv2d(num_feat * 2, num_feat, 3, 1, 1, bias=False)) 33 | # extra convolutions 34 | self.conv7 = norm(nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=False)) 35 | self.conv8 = norm(nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=False)) 36 | self.conv9 = nn.Conv2d(num_feat, 1, 3, 1, 1) 37 | 38 | def forward(self, x): 39 | # downsample 40 | x0 = F.leaky_relu(self.conv0(x), negative_slope=0.2, inplace=True) 41 | x1 = F.leaky_relu(self.conv1(x0), negative_slope=0.2, inplace=True) 42 | x2 = F.leaky_relu(self.conv2(x1), negative_slope=0.2, inplace=True) 43 | x3 = F.leaky_relu(self.conv3(x2), negative_slope=0.2, inplace=True) 44 | 45 | # upsample 46 | x3 = F.interpolate(x3, scale_factor=2, mode='bilinear', align_corners=False) 47 | x4 = F.leaky_relu(self.conv4(x3), negative_slope=0.2, inplace=True) 48 | 49 | if self.skip_connection: 50 | x4 = x4 + x2 51 | x4 = F.interpolate(x4, scale_factor=2, mode='bilinear', align_corners=False) 52 | x5 = F.leaky_relu(self.conv5(x4), negative_slope=0.2, inplace=True) 53 | 54 | if self.skip_connection: 55 | x5 = x5 + x1 56 | x5 = F.interpolate(x5, scale_factor=2, mode='bilinear', align_corners=False) 57 | x6 = F.leaky_relu(self.conv6(x5), negative_slope=0.2, inplace=True) 58 | 59 | if self.skip_connection: 60 | x6 = x6 + x0 61 | 62 | # extra convolutions 63 | out = F.leaky_relu(self.conv7(x6), negative_slope=0.2, inplace=True) 64 | out = F.leaky_relu(self.conv8(out), negative_slope=0.2, inplace=True) 65 | out = self.conv9(out) 66 | 67 | return out 68 | -------------------------------------------------------------------------------- /realesrgan/archs/srvgg_arch.py: -------------------------------------------------------------------------------- 1 | from basicsr.utils.registry import ARCH_REGISTRY 2 | from torch import nn as nn 3 | from torch.nn import functional as F 4 | 5 | 6 | @ARCH_REGISTRY.register() 7 | class SRVGGNetCompact(nn.Module): 8 | """A compact VGG-style network structure for super-resolution. 9 | 10 | It is a compact network structure, which performs upsampling in the last layer and no convolution is 11 | conducted on the HR feature space. 12 | 13 | Args: 14 | num_in_ch (int): Channel number of inputs. Default: 3. 15 | num_out_ch (int): Channel number of outputs. Default: 3. 16 | num_feat (int): Channel number of intermediate features. Default: 64. 17 | num_conv (int): Number of convolution layers in the body network. Default: 16. 18 | upscale (int): Upsampling factor. Default: 4. 19 | act_type (str): Activation type, options: 'relu', 'prelu', 'leakyrelu'. Default: prelu. 
20 | """ 21 | 22 | def __init__(self, num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=16, upscale=4, act_type='prelu'): 23 | super(SRVGGNetCompact, self).__init__() 24 | self.num_in_ch = num_in_ch 25 | self.num_out_ch = num_out_ch 26 | self.num_feat = num_feat 27 | self.num_conv = num_conv 28 | self.upscale = upscale 29 | self.act_type = act_type 30 | 31 | self.body = nn.ModuleList() 32 | # the first conv 33 | self.body.append(nn.Conv2d(num_in_ch, num_feat, 3, 1, 1)) 34 | # the first activation 35 | if act_type == 'relu': 36 | activation = nn.ReLU(inplace=True) 37 | elif act_type == 'prelu': 38 | activation = nn.PReLU(num_parameters=num_feat) 39 | elif act_type == 'leakyrelu': 40 | activation = nn.LeakyReLU(negative_slope=0.1, inplace=True) 41 | self.body.append(activation) 42 | 43 | # the body structure 44 | for _ in range(num_conv): 45 | self.body.append(nn.Conv2d(num_feat, num_feat, 3, 1, 1)) 46 | # activation 47 | if act_type == 'relu': 48 | activation = nn.ReLU(inplace=True) 49 | elif act_type == 'prelu': 50 | activation = nn.PReLU(num_parameters=num_feat) 51 | elif act_type == 'leakyrelu': 52 | activation = nn.LeakyReLU(negative_slope=0.1, inplace=True) 53 | self.body.append(activation) 54 | 55 | # the last conv 56 | self.body.append(nn.Conv2d(num_feat, num_out_ch * upscale * upscale, 3, 1, 1)) 57 | # upsample 58 | self.upsampler = nn.PixelShuffle(upscale) 59 | 60 | def forward(self, x): 61 | out = x 62 | for i in range(0, len(self.body)): 63 | out = self.body[i](out) 64 | 65 | out = self.upsampler(out) 66 | # add the nearest upsampled image, so that the network learns the residual 67 | base = F.interpolate(x, scale_factor=self.upscale, mode='nearest') 68 | out += base 69 | return out 70 | -------------------------------------------------------------------------------- /realesrgan/data/__init__.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | from basicsr.utils import scandir 3 | from os import path as osp 4 | 5 | # automatically scan and import dataset modules for registry 6 | # scan all the files that end with '_dataset.py' under the data folder 7 | data_folder = osp.dirname(osp.abspath(__file__)) 8 | dataset_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(data_folder) if v.endswith('_dataset.py')] 9 | # import all the dataset modules 10 | _dataset_modules = [importlib.import_module(f'realesrgan.data.{file_name}') for file_name in dataset_filenames] 11 | -------------------------------------------------------------------------------- /realesrgan/data/realesrgan_dataset.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import math 3 | import numpy as np 4 | import os 5 | import os.path as osp 6 | import random 7 | import time 8 | import torch 9 | from basicsr.data.degradations import circular_lowpass_kernel, random_mixed_kernels 10 | from basicsr.data.transforms import augment 11 | from basicsr.utils import FileClient, get_root_logger, imfrombytes, img2tensor 12 | from basicsr.utils.registry import DATASET_REGISTRY 13 | from torch.utils import data as data 14 | 15 | 16 | @DATASET_REGISTRY.register() 17 | class RealESRGANDataset(data.Dataset): 18 | """Dataset used for Real-ESRGAN model: 19 | Real-ESRGAN: Training Real-World Blind Super-Resolution with Pure Synthetic Data. 20 | 21 | It loads gt (Ground-Truth) images, and augments them. 22 | It also generates blur kernels and sinc kernels for generating low-quality images. 
23 | Note that the low-quality images are processed in tensors on GPUS for faster processing. 24 | 25 | Args: 26 | opt (dict): Config for train datasets. It contains the following keys: 27 | dataroot_gt (str): Data root path for gt. 28 | meta_info (str): Path for meta information file. 29 | io_backend (dict): IO backend type and other kwarg. 30 | use_hflip (bool): Use horizontal flips. 31 | use_rot (bool): Use rotation (use vertical flip and transposing h and w for implementation). 32 | Please see more options in the codes. 33 | """ 34 | 35 | def __init__(self, opt): 36 | super(RealESRGANDataset, self).__init__() 37 | self.opt = opt 38 | self.file_client = None 39 | self.io_backend_opt = opt['io_backend'] 40 | self.gt_folder = opt['dataroot_gt'] 41 | 42 | # file client (lmdb io backend) 43 | if self.io_backend_opt['type'] == 'lmdb': 44 | self.io_backend_opt['db_paths'] = [self.gt_folder] 45 | self.io_backend_opt['client_keys'] = ['gt'] 46 | if not self.gt_folder.endswith('.lmdb'): 47 | raise ValueError(f"'dataroot_gt' should end with '.lmdb', but received {self.gt_folder}") 48 | with open(osp.join(self.gt_folder, 'meta_info.txt')) as fin: 49 | self.paths = [line.split('.')[0] for line in fin] 50 | else: 51 | # disk backend with meta_info 52 | # Each line in the meta_info describes the relative path to an image 53 | with open(self.opt['meta_info']) as fin: 54 | paths = [line.strip().split(' ')[0] for line in fin] 55 | self.paths = [os.path.join(self.gt_folder, v) for v in paths] 56 | 57 | # blur settings for the first degradation 58 | self.blur_kernel_size = opt['blur_kernel_size'] 59 | self.kernel_list = opt['kernel_list'] 60 | self.kernel_prob = opt['kernel_prob'] # a list for each kernel probability 61 | self.blur_sigma = opt['blur_sigma'] 62 | self.betag_range = opt['betag_range'] # betag used in generalized Gaussian blur kernels 63 | self.betap_range = opt['betap_range'] # betap used in plateau blur kernels 64 | self.sinc_prob = opt['sinc_prob'] # the probability for sinc filters 65 | 66 | # blur settings for the second degradation 67 | self.blur_kernel_size2 = opt['blur_kernel_size2'] 68 | self.kernel_list2 = opt['kernel_list2'] 69 | self.kernel_prob2 = opt['kernel_prob2'] 70 | self.blur_sigma2 = opt['blur_sigma2'] 71 | self.betag_range2 = opt['betag_range2'] 72 | self.betap_range2 = opt['betap_range2'] 73 | self.sinc_prob2 = opt['sinc_prob2'] 74 | 75 | # a final sinc filter 76 | self.final_sinc_prob = opt['final_sinc_prob'] 77 | 78 | self.kernel_range = [2 * v + 1 for v in range(3, 11)] # kernel size ranges from 7 to 21 79 | # TODO: kernel range is now hard-coded, should be in the configure file 80 | self.pulse_tensor = torch.zeros(21, 21).float() # convolving with pulse tensor brings no blurry effect 81 | self.pulse_tensor[10, 10] = 1 82 | 83 | def __getitem__(self, index): 84 | if self.file_client is None: 85 | self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt) 86 | 87 | # -------------------------------- Load gt images -------------------------------- # 88 | # Shape: (h, w, c); channel order: BGR; image range: [0, 1], float32. 
89 | gt_path = self.paths[index] 90 | # avoid errors caused by high latency in reading files 91 | retry = 3 92 | while retry > 0: 93 | try: 94 | img_bytes = self.file_client.get(gt_path, 'gt') 95 | except (IOError, OSError) as e: 96 | logger = get_root_logger() 97 | logger.warning(f'File client error: {e}, remaining retry times: {retry - 1}') 98 | # change another file to read (randint is inclusive on both ends, so cap at len - 1) 99 | index = random.randint(0, self.__len__() - 1) 100 | gt_path = self.paths[index] 101 | time.sleep(1) # sleep 1s for occasional server congestion 102 | else: 103 | break 104 | finally: 105 | retry -= 1 106 | img_gt = imfrombytes(img_bytes, float32=True) 107 | 108 | # -------------------- Do augmentation for training: flip, rotation -------------------- # 109 | img_gt = augment(img_gt, self.opt['use_hflip'], self.opt['use_rot']) 110 | 111 | # crop or pad to 400 112 | # TODO: 400 is hard-coded. You may change it accordingly 113 | h, w = img_gt.shape[0:2] 114 | crop_pad_size = 400 115 | # pad 116 | if h < crop_pad_size or w < crop_pad_size: 117 | pad_h = max(0, crop_pad_size - h) 118 | pad_w = max(0, crop_pad_size - w) 119 | img_gt = cv2.copyMakeBorder(img_gt, 0, pad_h, 0, pad_w, cv2.BORDER_REFLECT_101) 120 | # crop 121 | if img_gt.shape[0] > crop_pad_size or img_gt.shape[1] > crop_pad_size: 122 | h, w = img_gt.shape[0:2] 123 | # randomly choose top and left coordinates 124 | top = random.randint(0, h - crop_pad_size) 125 | left = random.randint(0, w - crop_pad_size) 126 | img_gt = img_gt[top:top + crop_pad_size, left:left + crop_pad_size, ...] 127 | 128 | # ------------------------ Generate kernels (used in the first degradation) ------------------------ # 129 | kernel_size = random.choice(self.kernel_range) 130 | if np.random.uniform() < self.opt['sinc_prob']: 131 | # this sinc filter setting is for kernels ranging from [7, 21] 132 | if kernel_size < 13: 133 | omega_c = np.random.uniform(np.pi / 3, np.pi) 134 | else: 135 | omega_c = np.random.uniform(np.pi / 5, np.pi) 136 | kernel = circular_lowpass_kernel(omega_c, kernel_size, pad_to=False) 137 | else: 138 | kernel = random_mixed_kernels( 139 | self.kernel_list, 140 | self.kernel_prob, 141 | kernel_size, 142 | self.blur_sigma, 143 | self.blur_sigma, [-math.pi, math.pi], 144 | self.betag_range, 145 | self.betap_range, 146 | noise_range=None) 147 | # pad kernel 148 | pad_size = (21 - kernel_size) // 2 149 | kernel = np.pad(kernel, ((pad_size, pad_size), (pad_size, pad_size))) 150 | 151 | # ------------------------ Generate kernels (used in the second degradation) ------------------------ # 152 | kernel_size = random.choice(self.kernel_range) 153 | if np.random.uniform() < self.opt['sinc_prob2']: 154 | if kernel_size < 13: 155 | omega_c = np.random.uniform(np.pi / 3, np.pi) 156 | else: 157 | omega_c = np.random.uniform(np.pi / 5, np.pi) 158 | kernel2 = circular_lowpass_kernel(omega_c, kernel_size, pad_to=False) 159 | else: 160 | kernel2 = random_mixed_kernels( 161 | self.kernel_list2, 162 | self.kernel_prob2, 163 | kernel_size, 164 | self.blur_sigma2, 165 | self.blur_sigma2, [-math.pi, math.pi], 166 | self.betag_range2, 167 | self.betap_range2, 168 | noise_range=None) 169 | 170 | # pad kernel 171 | pad_size = (21 - kernel_size) // 2 172 | kernel2 = np.pad(kernel2, ((pad_size, pad_size), (pad_size, pad_size))) 173 | 174 | # ------------------------------------- the final sinc kernel ------------------------------------- # 175 | if np.random.uniform() < self.opt['final_sinc_prob']: 176 | kernel_size = random.choice(self.kernel_range) 177 | omega_c = np.random.uniform(np.pi /
3, np.pi) 178 | sinc_kernel = circular_lowpass_kernel(omega_c, kernel_size, pad_to=21) 179 | sinc_kernel = torch.FloatTensor(sinc_kernel) 180 | else: 181 | sinc_kernel = self.pulse_tensor 182 | 183 | # BGR to RGB, HWC to CHW, numpy to tensor 184 | img_gt = img2tensor([img_gt], bgr2rgb=True, float32=True)[0] 185 | kernel = torch.FloatTensor(kernel) 186 | kernel2 = torch.FloatTensor(kernel2) 187 | 188 | return_d = {'gt': img_gt, 'kernel1': kernel, 'kernel2': kernel2, 'sinc_kernel': sinc_kernel, 'gt_path': gt_path} 189 | return return_d 190 | 191 | def __len__(self): 192 | return len(self.paths) 193 | -------------------------------------------------------------------------------- /realesrgan/data/realesrgan_paired_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | from basicsr.data.data_util import paired_paths_from_folder, paired_paths_from_lmdb 3 | from basicsr.data.transforms import augment, paired_random_crop 4 | from basicsr.utils import FileClient, imfrombytes, img2tensor 5 | from basicsr.utils.registry import DATASET_REGISTRY 6 | from torch.utils import data as data 7 | from torchvision.transforms.functional import normalize 8 | 9 | 10 | @DATASET_REGISTRY.register() 11 | class RealESRGANPairedDataset(data.Dataset): 12 | """Paired image dataset for image restoration. 13 | 14 | Read LQ (Low Quality, e.g. LR (Low Resolution), blurry, noisy, etc) and GT image pairs. 15 | 16 | There are three modes: 17 | 1. 'lmdb': Use lmdb files. 18 | If opt['io_backend'] == lmdb. 19 | 2. 'meta_info': Use meta information file to generate paths. 20 | If opt['io_backend'] != lmdb and opt['meta_info'] is not None. 21 | 3. 'folder': Scan folders to generate paths. 22 | The rest. 23 | 24 | Args: 25 | opt (dict): Config for train datasets. It contains the following keys: 26 | dataroot_gt (str): Data root path for gt. 27 | dataroot_lq (str): Data root path for lq. 28 | meta_info (str): Path for meta information file. 29 | io_backend (dict): IO backend type and other kwarg. 30 | filename_tmpl (str): Template for each filename. Note that the template excludes the file extension. 31 | Default: '{}'. 32 | gt_size (int): Cropped patched size for gt patches. 33 | use_hflip (bool): Use horizontal flips. 34 | use_rot (bool): Use rotation (use vertical flip and transposing h 35 | and w for implementation). 36 | 37 | scale (bool): Scale, which will be added automatically. 38 | phase (str): 'train' or 'val'. 
39 | """ 40 | 41 | def __init__(self, opt): 42 | super(RealESRGANPairedDataset, self).__init__() 43 | self.opt = opt 44 | self.file_client = None 45 | self.io_backend_opt = opt['io_backend'] 46 | # mean and std for normalizing the input images 47 | self.mean = opt['mean'] if 'mean' in opt else None 48 | self.std = opt['std'] if 'std' in opt else None 49 | 50 | self.gt_folder, self.lq_folder = opt['dataroot_gt'], opt['dataroot_lq'] 51 | self.filename_tmpl = opt['filename_tmpl'] if 'filename_tmpl' in opt else '{}' 52 | 53 | # file client (lmdb io backend) 54 | if self.io_backend_opt['type'] == 'lmdb': 55 | self.io_backend_opt['db_paths'] = [self.lq_folder, self.gt_folder] 56 | self.io_backend_opt['client_keys'] = ['lq', 'gt'] 57 | self.paths = paired_paths_from_lmdb([self.lq_folder, self.gt_folder], ['lq', 'gt']) 58 | elif 'meta_info' in self.opt and self.opt['meta_info'] is not None: 59 | # disk backend with meta_info 60 | # Each line in the meta_info describes the relative path to an image 61 | with open(self.opt['meta_info']) as fin: 62 | paths = [line.strip() for line in fin] 63 | self.paths = [] 64 | for path in paths: 65 | gt_path, lq_path = path.split(', ') 66 | gt_path = os.path.join(self.gt_folder, gt_path) 67 | lq_path = os.path.join(self.lq_folder, lq_path) 68 | self.paths.append(dict([('gt_path', gt_path), ('lq_path', lq_path)])) 69 | else: 70 | # disk backend 71 | # it will scan the whole folder to get meta info 72 | # it will be time-consuming for folders with too many files. It is recommended using an extra meta txt file 73 | self.paths = paired_paths_from_folder([self.lq_folder, self.gt_folder], ['lq', 'gt'], self.filename_tmpl) 74 | 75 | def __getitem__(self, index): 76 | if self.file_client is None: 77 | self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt) 78 | 79 | scale = self.opt['scale'] 80 | 81 | # Load gt and lq images. Dimension order: HWC; channel order: BGR; 82 | # image range: [0, 1], float32. 
83 | gt_path = self.paths[index]['gt_path'] 84 | img_bytes = self.file_client.get(gt_path, 'gt') 85 | img_gt = imfrombytes(img_bytes, float32=True) 86 | lq_path = self.paths[index]['lq_path'] 87 | img_bytes = self.file_client.get(lq_path, 'lq') 88 | img_lq = imfrombytes(img_bytes, float32=True) 89 | 90 | # augmentation for training 91 | if self.opt['phase'] == 'train': 92 | gt_size = self.opt['gt_size'] 93 | # random crop 94 | img_gt, img_lq = paired_random_crop(img_gt, img_lq, gt_size, scale, gt_path) 95 | # flip, rotation 96 | img_gt, img_lq = augment([img_gt, img_lq], self.opt['use_hflip'], self.opt['use_rot']) 97 | 98 | # BGR to RGB, HWC to CHW, numpy to tensor 99 | img_gt, img_lq = img2tensor([img_gt, img_lq], bgr2rgb=True, float32=True) 100 | # normalize 101 | if self.mean is not None or self.std is not None: 102 | normalize(img_lq, self.mean, self.std, inplace=True) 103 | normalize(img_gt, self.mean, self.std, inplace=True) 104 | 105 | return {'lq': img_lq, 'gt': img_gt, 'lq_path': lq_path, 'gt_path': gt_path} 106 | 107 | def __len__(self): 108 | return len(self.paths) 109 | -------------------------------------------------------------------------------- /realesrgan/models/__init__.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | from basicsr.utils import scandir 3 | from os import path as osp 4 | 5 | # automatically scan and import model modules for registry 6 | # scan all the files that end with '_model.py' under the model folder 7 | model_folder = osp.dirname(osp.abspath(__file__)) 8 | model_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(model_folder) if v.endswith('_model.py')] 9 | # import all the model modules 10 | _model_modules = [importlib.import_module(f'realesrgan.models.{file_name}') for file_name in model_filenames] 11 | -------------------------------------------------------------------------------- /realesrgan/models/realesrgan_model.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import random 3 | import torch 4 | from basicsr.data.degradations import random_add_gaussian_noise_pt, random_add_poisson_noise_pt 5 | from basicsr.data.transforms import paired_random_crop 6 | from basicsr.models.srgan_model import SRGANModel 7 | from basicsr.utils import DiffJPEG, USMSharp 8 | from basicsr.utils.img_process_util import filter2D 9 | from basicsr.utils.registry import MODEL_REGISTRY 10 | from collections import OrderedDict 11 | from torch.nn import functional as F 12 | 13 | 14 | @MODEL_REGISTRY.register() 15 | class RealESRGANModel(SRGANModel): 16 | """RealESRGAN Model for Real-ESRGAN: Training Real-World Blind Super-Resolution with Pure Synthetic Data. 17 | 18 | It mainly performs: 19 | 1. randomly synthesize LQ images in GPU tensors 20 | 2. optimize the networks with GAN training. 21 | """ 22 | 23 | def __init__(self, opt): 24 | super(RealESRGANModel, self).__init__(opt) 25 | self.jpeger = DiffJPEG(differentiable=False).cuda() # simulate JPEG compression artifacts 26 | self.usm_sharpener = USMSharp().cuda() # do usm sharpening 27 | self.queue_size = opt.get('queue_size', 180) 28 | 29 | @torch.no_grad() 30 | def _dequeue_and_enqueue(self): 31 | """It is the training pair pool for increasing the diversity in a batch. 32 | 33 | Batch processing limits the diversity of synthetic degradations in a batch. For example, samples in a 34 | batch could not have different resize scaling factors. 
Therefore, we employ this training pair pool 35 | to increase the degradation diversity in a batch. 36 | """ 37 | # initialize 38 | b, c, h, w = self.lq.size() 39 | if not hasattr(self, 'queue_lr'): 40 | assert self.queue_size % b == 0, f'queue size {self.queue_size} should be divisible by batch size {b}' 41 | self.queue_lr = torch.zeros(self.queue_size, c, h, w).cuda() 42 | _, c, h, w = self.gt.size() 43 | self.queue_gt = torch.zeros(self.queue_size, c, h, w).cuda() 44 | self.queue_ptr = 0 45 | if self.queue_ptr == self.queue_size: # the pool is full 46 | # do dequeue and enqueue 47 | # shuffle 48 | idx = torch.randperm(self.queue_size) 49 | self.queue_lr = self.queue_lr[idx] 50 | self.queue_gt = self.queue_gt[idx] 51 | # get first b samples 52 | lq_dequeue = self.queue_lr[0:b, :, :, :].clone() 53 | gt_dequeue = self.queue_gt[0:b, :, :, :].clone() 54 | # update the queue 55 | self.queue_lr[0:b, :, :, :] = self.lq.clone() 56 | self.queue_gt[0:b, :, :, :] = self.gt.clone() 57 | 58 | self.lq = lq_dequeue 59 | self.gt = gt_dequeue 60 | else: 61 | # only do enqueue 62 | self.queue_lr[self.queue_ptr:self.queue_ptr + b, :, :, :] = self.lq.clone() 63 | self.queue_gt[self.queue_ptr:self.queue_ptr + b, :, :, :] = self.gt.clone() 64 | self.queue_ptr = self.queue_ptr + b 65 | 66 | @torch.no_grad() 67 | def feed_data(self, data): 68 | """Accept data from dataloader, and then add two-order degradations to obtain LQ images. 69 | """ 70 | if self.is_train and self.opt.get('high_order_degradation', True): 71 | # training data synthesis 72 | self.gt = data['gt'].to(self.device) 73 | self.gt_usm = self.usm_sharpener(self.gt) 74 | 75 | self.kernel1 = data['kernel1'].to(self.device) 76 | self.kernel2 = data['kernel2'].to(self.device) 77 | self.sinc_kernel = data['sinc_kernel'].to(self.device) 78 | 79 | ori_h, ori_w = self.gt.size()[2:4] 80 | 81 | # ----------------------- The first degradation process ----------------------- # 82 | # blur 83 | out = filter2D(self.gt_usm, self.kernel1) 84 | # random resize 85 | updown_type = random.choices(['up', 'down', 'keep'], self.opt['resize_prob'])[0] 86 | if updown_type == 'up': 87 | scale = np.random.uniform(1, self.opt['resize_range'][1]) 88 | elif updown_type == 'down': 89 | scale = np.random.uniform(self.opt['resize_range'][0], 1) 90 | else: 91 | scale = 1 92 | mode = random.choice(['area', 'bilinear', 'bicubic']) 93 | out = F.interpolate(out, scale_factor=scale, mode=mode) 94 | # add noise 95 | gray_noise_prob = self.opt['gray_noise_prob'] 96 | if np.random.uniform() < self.opt['gaussian_noise_prob']: 97 | out = random_add_gaussian_noise_pt( 98 | out, sigma_range=self.opt['noise_range'], clip=True, rounds=False, gray_prob=gray_noise_prob) 99 | else: 100 | out = random_add_poisson_noise_pt( 101 | out, 102 | scale_range=self.opt['poisson_scale_range'], 103 | gray_prob=gray_noise_prob, 104 | clip=True, 105 | rounds=False) 106 | # JPEG compression 107 | jpeg_p = out.new_zeros(out.size(0)).uniform_(*self.opt['jpeg_range']) 108 | out = torch.clamp(out, 0, 1) # clamp to [0, 1], otherwise JPEGer will result in unpleasant artifacts 109 | out = self.jpeger(out, quality=jpeg_p) 110 | 111 | # ----------------------- The second degradation process ----------------------- # 112 | # blur 113 | if np.random.uniform() < self.opt['second_blur_prob']: 114 | out = filter2D(out, self.kernel2) 115 | # random resize 116 | updown_type = random.choices(['up', 'down', 'keep'], self.opt['resize_prob2'])[0] 117 | if updown_type == 'up': 118 | scale = np.random.uniform(1, 
self.opt['resize_range2'][1]) 119 | elif updown_type == 'down': 120 | scale = np.random.uniform(self.opt['resize_range2'][0], 1) 121 | else: 122 | scale = 1 123 | mode = random.choice(['area', 'bilinear', 'bicubic']) 124 | out = F.interpolate( 125 | out, size=(int(ori_h / self.opt['scale'] * scale), int(ori_w / self.opt['scale'] * scale)), mode=mode) 126 | # add noise 127 | gray_noise_prob = self.opt['gray_noise_prob2'] 128 | if np.random.uniform() < self.opt['gaussian_noise_prob2']: 129 | out = random_add_gaussian_noise_pt( 130 | out, sigma_range=self.opt['noise_range2'], clip=True, rounds=False, gray_prob=gray_noise_prob) 131 | else: 132 | out = random_add_poisson_noise_pt( 133 | out, 134 | scale_range=self.opt['poisson_scale_range2'], 135 | gray_prob=gray_noise_prob, 136 | clip=True, 137 | rounds=False) 138 | 139 | # JPEG compression + the final sinc filter 140 | # We also need to resize images to desired sizes. We group [resize back + sinc filter] together 141 | # as one operation. 142 | # We consider two orders: 143 | # 1. [resize back + sinc filter] + JPEG compression 144 | # 2. JPEG compression + [resize back + sinc filter] 145 | # Empirically, we find other combinations (sinc + JPEG + Resize) will introduce twisted lines. 146 | if np.random.uniform() < 0.5: 147 | # resize back + the final sinc filter 148 | mode = random.choice(['area', 'bilinear', 'bicubic']) 149 | out = F.interpolate(out, size=(ori_h // self.opt['scale'], ori_w // self.opt['scale']), mode=mode) 150 | out = filter2D(out, self.sinc_kernel) 151 | # JPEG compression 152 | jpeg_p = out.new_zeros(out.size(0)).uniform_(*self.opt['jpeg_range2']) 153 | out = torch.clamp(out, 0, 1) 154 | out = self.jpeger(out, quality=jpeg_p) 155 | else: 156 | # JPEG compression 157 | jpeg_p = out.new_zeros(out.size(0)).uniform_(*self.opt['jpeg_range2']) 158 | out = torch.clamp(out, 0, 1) 159 | out = self.jpeger(out, quality=jpeg_p) 160 | # resize back + the final sinc filter 161 | mode = random.choice(['area', 'bilinear', 'bicubic']) 162 | out = F.interpolate(out, size=(ori_h // self.opt['scale'], ori_w // self.opt['scale']), mode=mode) 163 | out = filter2D(out, self.sinc_kernel) 164 | 165 | # clamp and round 166 | self.lq = torch.clamp((out * 255.0).round(), 0, 255) / 255. 
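            # the 0-255 round trip above quantizes the synthesized LQ image to 8-bit levels, mimicking how real images are stored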
167 | 168 | # random crop 169 | gt_size = self.opt['gt_size'] 170 | (self.gt, self.gt_usm), self.lq = paired_random_crop([self.gt, self.gt_usm], self.lq, gt_size, 171 | self.opt['scale']) 172 | 173 | # training pair pool 174 | self._dequeue_and_enqueue() 175 | # sharpen self.gt again, as we have changed the self.gt with self._dequeue_and_enqueue 176 | self.gt_usm = self.usm_sharpener(self.gt) 177 | self.lq = self.lq.contiguous() # for the warning: grad and param do not obey the gradient layout contract 178 | else: 179 | # for paired training or validation 180 | self.lq = data['lq'].to(self.device) 181 | if 'gt' in data: 182 | self.gt = data['gt'].to(self.device) 183 | self.gt_usm = self.usm_sharpener(self.gt) 184 | 185 | def nondist_validation(self, dataloader, current_iter, tb_logger, save_img): 186 | # do not use the synthetic process during validation 187 | self.is_train = False 188 | super(RealESRGANModel, self).nondist_validation(dataloader, current_iter, tb_logger, save_img) 189 | self.is_train = True 190 | 191 | def optimize_parameters(self, current_iter): 192 | # usm sharpening 193 | l1_gt = self.gt_usm 194 | percep_gt = self.gt_usm 195 | gan_gt = self.gt_usm 196 | if self.opt['l1_gt_usm'] is False: 197 | l1_gt = self.gt 198 | if self.opt['percep_gt_usm'] is False: 199 | percep_gt = self.gt 200 | if self.opt['gan_gt_usm'] is False: 201 | gan_gt = self.gt 202 | 203 | # optimize net_g 204 | for p in self.net_d.parameters(): 205 | p.requires_grad = False 206 | 207 | self.optimizer_g.zero_grad() 208 | self.output = self.net_g(self.lq) 209 | 210 | l_g_total = 0 211 | loss_dict = OrderedDict() 212 | if (current_iter % self.net_d_iters == 0 and current_iter > self.net_d_init_iters): 213 | # pixel loss 214 | if self.cri_pix: 215 | l_g_pix = self.cri_pix(self.output, l1_gt) 216 | l_g_total += l_g_pix 217 | loss_dict['l_g_pix'] = l_g_pix 218 | # perceptual loss 219 | if self.cri_perceptual: 220 | l_g_percep, l_g_style = self.cri_perceptual(self.output, percep_gt) 221 | if l_g_percep is not None: 222 | l_g_total += l_g_percep 223 | loss_dict['l_g_percep'] = l_g_percep 224 | if l_g_style is not None: 225 | l_g_total += l_g_style 226 | loss_dict['l_g_style'] = l_g_style 227 | # gan loss 228 | fake_g_pred = self.net_d(self.output) 229 | l_g_gan = self.cri_gan(fake_g_pred, True, is_disc=False) 230 | l_g_total += l_g_gan 231 | loss_dict['l_g_gan'] = l_g_gan 232 | 233 | l_g_total.backward() 234 | self.optimizer_g.step() 235 | 236 | # optimize net_d 237 | for p in self.net_d.parameters(): 238 | p.requires_grad = True 239 | 240 | self.optimizer_d.zero_grad() 241 | # real 242 | real_d_pred = self.net_d(gan_gt) 243 | l_d_real = self.cri_gan(real_d_pred, True, is_disc=True) 244 | loss_dict['l_d_real'] = l_d_real 245 | loss_dict['out_d_real'] = torch.mean(real_d_pred.detach()) 246 | l_d_real.backward() 247 | # fake 248 | fake_d_pred = self.net_d(self.output.detach().clone()) # clone for pt1.9 249 | l_d_fake = self.cri_gan(fake_d_pred, False, is_disc=True) 250 | loss_dict['l_d_fake'] = l_d_fake 251 | loss_dict['out_d_fake'] = torch.mean(fake_d_pred.detach()) 252 | l_d_fake.backward() 253 | self.optimizer_d.step() 254 | 255 | if self.ema_decay > 0: 256 | self.model_ema(decay=self.ema_decay) 257 | 258 | self.log_dict = self.reduce_loss_dict(loss_dict) 259 | -------------------------------------------------------------------------------- /realesrgan/models/realesrnet_model.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import random 3 | import 
torch 4 | from basicsr.data.degradations import random_add_gaussian_noise_pt, random_add_poisson_noise_pt 5 | from basicsr.data.transforms import paired_random_crop 6 | from basicsr.models.sr_model import SRModel 7 | from basicsr.utils import DiffJPEG, USMSharp 8 | from basicsr.utils.img_process_util import filter2D 9 | from basicsr.utils.registry import MODEL_REGISTRY 10 | from torch.nn import functional as F 11 | 12 | 13 | @MODEL_REGISTRY.register() 14 | class RealESRNetModel(SRModel): 15 | """RealESRNet Model for Real-ESRGAN: Training Real-World Blind Super-Resolution with Pure Synthetic Data. 16 | 17 | It is trained without GAN losses. 18 | It mainly performs: 19 | 1. randomly synthesize LQ images in GPU tensors 20 | 2. optimize the network with pixel-wise losses. 21 | """ 22 | 23 | def __init__(self, opt): 24 | super(RealESRNetModel, self).__init__(opt) 25 | self.jpeger = DiffJPEG(differentiable=False).cuda() # simulate JPEG compression artifacts 26 | self.usm_sharpener = USMSharp().cuda() # do usm sharpening 27 | self.queue_size = opt.get('queue_size', 180) 28 | 29 | @torch.no_grad() 30 | def _dequeue_and_enqueue(self): 31 | """It is the training pair pool for increasing the diversity in a batch. 32 | 33 | Batch processing limits the diversity of synthetic degradations in a batch. For example, samples in a 34 | batch cannot have different resize scaling factors. Therefore, we employ this training pair pool 35 | to increase the degradation diversity in a batch. 36 | """ 37 | # initialize 38 | b, c, h, w = self.lq.size() 39 | if not hasattr(self, 'queue_lr'): 40 | assert self.queue_size % b == 0, f'queue size {self.queue_size} should be divisible by batch size {b}' 41 | self.queue_lr = torch.zeros(self.queue_size, c, h, w).cuda() 42 | _, c, h, w = self.gt.size() 43 | self.queue_gt = torch.zeros(self.queue_size, c, h, w).cuda() 44 | self.queue_ptr = 0 45 | if self.queue_ptr == self.queue_size: # the pool is full 46 | # do dequeue and enqueue 47 | # shuffle 48 | idx = torch.randperm(self.queue_size) 49 | self.queue_lr = self.queue_lr[idx] 50 | self.queue_gt = self.queue_gt[idx] 51 | # get first b samples 52 | lq_dequeue = self.queue_lr[0:b, :, :, :].clone() 53 | gt_dequeue = self.queue_gt[0:b, :, :, :].clone() 54 | # update the queue 55 | self.queue_lr[0:b, :, :, :] = self.lq.clone() 56 | self.queue_gt[0:b, :, :, :] = self.gt.clone() 57 | 58 | self.lq = lq_dequeue 59 | self.gt = gt_dequeue 60 | else: 61 | # only do enqueue 62 | self.queue_lr[self.queue_ptr:self.queue_ptr + b, :, :, :] = self.lq.clone() 63 | self.queue_gt[self.queue_ptr:self.queue_ptr + b, :, :, :] = self.gt.clone() 64 | self.queue_ptr = self.queue_ptr + b 65 | 66 | @torch.no_grad() 67 | def feed_data(self, data): 68 | """Accept data from dataloader, and then add second-order degradations to obtain LQ images.
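The synthesis pipeline applies two rounds of blur, random resize, Gaussian/Poisson noise and JPEG compression, followed by a final sinc filter (see the step-by-step comments below).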
69 | """ 70 | if self.is_train and self.opt.get('high_order_degradation', True): 71 | # training data synthesis 72 | self.gt = data['gt'].to(self.device) 73 | # USM sharpen the GT images 74 | if self.opt['gt_usm'] is True: 75 | self.gt = self.usm_sharpener(self.gt) 76 | 77 | self.kernel1 = data['kernel1'].to(self.device) 78 | self.kernel2 = data['kernel2'].to(self.device) 79 | self.sinc_kernel = data['sinc_kernel'].to(self.device) 80 | 81 | ori_h, ori_w = self.gt.size()[2:4] 82 | 83 | # ----------------------- The first degradation process ----------------------- # 84 | # blur 85 | out = filter2D(self.gt, self.kernel1) 86 | # random resize 87 | updown_type = random.choices(['up', 'down', 'keep'], self.opt['resize_prob'])[0] 88 | if updown_type == 'up': 89 | scale = np.random.uniform(1, self.opt['resize_range'][1]) 90 | elif updown_type == 'down': 91 | scale = np.random.uniform(self.opt['resize_range'][0], 1) 92 | else: 93 | scale = 1 94 | mode = random.choice(['area', 'bilinear', 'bicubic']) 95 | out = F.interpolate(out, scale_factor=scale, mode=mode) 96 | # add noise 97 | gray_noise_prob = self.opt['gray_noise_prob'] 98 | if np.random.uniform() < self.opt['gaussian_noise_prob']: 99 | out = random_add_gaussian_noise_pt( 100 | out, sigma_range=self.opt['noise_range'], clip=True, rounds=False, gray_prob=gray_noise_prob) 101 | else: 102 | out = random_add_poisson_noise_pt( 103 | out, 104 | scale_range=self.opt['poisson_scale_range'], 105 | gray_prob=gray_noise_prob, 106 | clip=True, 107 | rounds=False) 108 | # JPEG compression 109 | jpeg_p = out.new_zeros(out.size(0)).uniform_(*self.opt['jpeg_range']) 110 | out = torch.clamp(out, 0, 1) # clamp to [0, 1], otherwise JPEGer will result in unpleasant artifacts 111 | out = self.jpeger(out, quality=jpeg_p) 112 | 113 | # ----------------------- The second degradation process ----------------------- # 114 | # blur 115 | if np.random.uniform() < self.opt['second_blur_prob']: 116 | out = filter2D(out, self.kernel2) 117 | # random resize 118 | updown_type = random.choices(['up', 'down', 'keep'], self.opt['resize_prob2'])[0] 119 | if updown_type == 'up': 120 | scale = np.random.uniform(1, self.opt['resize_range2'][1]) 121 | elif updown_type == 'down': 122 | scale = np.random.uniform(self.opt['resize_range2'][0], 1) 123 | else: 124 | scale = 1 125 | mode = random.choice(['area', 'bilinear', 'bicubic']) 126 | out = F.interpolate( 127 | out, size=(int(ori_h / self.opt['scale'] * scale), int(ori_w / self.opt['scale'] * scale)), mode=mode) 128 | # add noise 129 | gray_noise_prob = self.opt['gray_noise_prob2'] 130 | if np.random.uniform() < self.opt['gaussian_noise_prob2']: 131 | out = random_add_gaussian_noise_pt( 132 | out, sigma_range=self.opt['noise_range2'], clip=True, rounds=False, gray_prob=gray_noise_prob) 133 | else: 134 | out = random_add_poisson_noise_pt( 135 | out, 136 | scale_range=self.opt['poisson_scale_range2'], 137 | gray_prob=gray_noise_prob, 138 | clip=True, 139 | rounds=False) 140 | 141 | # JPEG compression + the final sinc filter 142 | # We also need to resize images to desired sizes. We group [resize back + sinc filter] together 143 | # as one operation. 144 | # We consider two orders: 145 | # 1. [resize back + sinc filter] + JPEG compression 146 | # 2. JPEG compression + [resize back + sinc filter] 147 | # Empirically, we find other combinations (sinc + JPEG + Resize) will introduce twisted lines. 
148 | if np.random.uniform() < 0.5: 149 | # resize back + the final sinc filter 150 | mode = random.choice(['area', 'bilinear', 'bicubic']) 151 | out = F.interpolate(out, size=(ori_h // self.opt['scale'], ori_w // self.opt['scale']), mode=mode) 152 | out = filter2D(out, self.sinc_kernel) 153 | # JPEG compression 154 | jpeg_p = out.new_zeros(out.size(0)).uniform_(*self.opt['jpeg_range2']) 155 | out = torch.clamp(out, 0, 1) 156 | out = self.jpeger(out, quality=jpeg_p) 157 | else: 158 | # JPEG compression 159 | jpeg_p = out.new_zeros(out.size(0)).uniform_(*self.opt['jpeg_range2']) 160 | out = torch.clamp(out, 0, 1) 161 | out = self.jpeger(out, quality=jpeg_p) 162 | # resize back + the final sinc filter 163 | mode = random.choice(['area', 'bilinear', 'bicubic']) 164 | out = F.interpolate(out, size=(ori_h // self.opt['scale'], ori_w // self.opt['scale']), mode=mode) 165 | out = filter2D(out, self.sinc_kernel) 166 | 167 | # clamp and round 168 | self.lq = torch.clamp((out * 255.0).round(), 0, 255) / 255. 169 | 170 | # random crop 171 | gt_size = self.opt['gt_size'] 172 | self.gt, self.lq = paired_random_crop(self.gt, self.lq, gt_size, self.opt['scale']) 173 | 174 | # training pair pool 175 | self._dequeue_and_enqueue() 176 | self.lq = self.lq.contiguous() # for the warning: grad and param do not obey the gradient layout contract 177 | else: 178 | # for paired training or validation 179 | self.lq = data['lq'].to(self.device) 180 | if 'gt' in data: 181 | self.gt = data['gt'].to(self.device) 182 | self.gt_usm = self.usm_sharpener(self.gt) 183 | 184 | def nondist_validation(self, dataloader, current_iter, tb_logger, save_img): 185 | # do not use the synthetic process during validation 186 | self.is_train = False 187 | super(RealESRNetModel, self).nondist_validation(dataloader, current_iter, tb_logger, save_img) 188 | self.is_train = True 189 | -------------------------------------------------------------------------------- /realesrgan/train.py: -------------------------------------------------------------------------------- 1 | # flake8: noqa 2 | import os.path as osp 3 | from basicsr.train import train_pipeline 4 | 5 | import realesrgan.archs 6 | import realesrgan.data 7 | import realesrgan.models 8 | 9 | if __name__ == '__main__': 10 | root_path = osp.abspath(osp.join(__file__, osp.pardir, osp.pardir)) 11 | train_pipeline(root_path) 12 | -------------------------------------------------------------------------------- /realesrgan/utils.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import math 3 | import numpy as np 4 | import os 5 | import queue 6 | import threading 7 | import torch 8 | from basicsr.utils.download_util import load_file_from_url 9 | from torch.nn import functional as F 10 | 11 | ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) 12 | 13 | 14 | class RealESRGANer(): 15 | """A helper class for upsampling images with RealESRGAN. 16 | 17 | Args: 18 | scale (int): Upsampling scale factor used in the networks. It is usually 2 or 4. 19 | model_path (str): The path to the pretrained model. It can be urls (will first download it automatically). 20 | model (nn.Module): The defined network. Default: None. 21 | tile (int): As too large images result in the out of GPU memory issue, so this tile option will first crop 22 | input images into tiles, and then process each of them. Finally, they will be merged into one image. 23 | 0 denotes for do not use tile. Default: 0. 
24 | tile_pad (int): The pad size for each tile, to remove border artifacts. Default: 10. 25 | pre_pad (int): Pad the input images to avoid border artifacts. Default: 10. 26 | half (float): Whether to use half precision during inference. Default: False. 27 | """ 28 | 29 | def __init__(self, 30 | scale, 31 | model_path, 32 | dni_weight=None, 33 | model=None, 34 | tile=0, 35 | tile_pad=10, 36 | pre_pad=10, 37 | half=False, 38 | device=None, 39 | gpu_id=None): 40 | self.scale = scale 41 | self.tile_size = tile 42 | self.tile_pad = tile_pad 43 | self.pre_pad = pre_pad 44 | self.mod_scale = None 45 | self.half = half 46 | 47 | # initialize model 48 | if gpu_id: 49 | self.device = torch.device( 50 | f'cuda:{gpu_id}' if torch.cuda.is_available() else 'cpu') if device is None else device 51 | else: 52 | self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') if device is None else device 53 | 54 | if isinstance(model_path, list): 55 | # dni 56 | assert len(model_path) == len(dni_weight), 'model_path and dni_weight should have the save length.' 57 | loadnet = self.dni(model_path[0], model_path[1], dni_weight) 58 | else: 59 | # if the model_path starts with https, it will first download models to the folder: weights 60 | if model_path.startswith('https://'): 61 | model_path = load_file_from_url( 62 | url=model_path, model_dir=os.path.join(ROOT_DIR, 'weights'), progress=True, file_name=None) 63 | loadnet = torch.load(model_path, map_location=torch.device('cpu')) 64 | 65 | # prefer to use params_ema 66 | if 'params_ema' in loadnet: 67 | keyname = 'params_ema' 68 | else: 69 | keyname = 'params' 70 | model.load_state_dict(loadnet[keyname], strict=True) 71 | 72 | model.eval() 73 | self.model = model.to(self.device) 74 | if self.half: 75 | self.model = self.model.half() 76 | 77 | def dni(self, net_a, net_b, dni_weight, key='params', loc='cpu'): 78 | """Deep network interpolation. 79 | 80 | ``Paper: Deep Network Interpolation for Continuous Imagery Effect Transition`` 81 | """ 82 | net_a = torch.load(net_a, map_location=torch.device(loc)) 83 | net_b = torch.load(net_b, map_location=torch.device(loc)) 84 | for k, v_a in net_a[key].items(): 85 | net_a[key][k] = dni_weight[0] * v_a + dni_weight[1] * net_b[key][k] 86 | return net_a 87 | 88 | def pre_process(self, img): 89 | """Pre-process, such as pre-pad and mod pad, so that the images can be divisible 90 | """ 91 | img = torch.from_numpy(np.transpose(img, (2, 0, 1))).float() 92 | self.img = img.unsqueeze(0).to(self.device) 93 | if self.half: 94 | self.img = self.img.half() 95 | 96 | # pre_pad 97 | if self.pre_pad != 0: 98 | self.img = F.pad(self.img, (0, self.pre_pad, 0, self.pre_pad), 'reflect') 99 | # mod pad for divisible borders 100 | if self.scale == 2: 101 | self.mod_scale = 2 102 | elif self.scale == 1: 103 | self.mod_scale = 4 104 | if self.mod_scale is not None: 105 | self.mod_pad_h, self.mod_pad_w = 0, 0 106 | _, _, h, w = self.img.size() 107 | if (h % self.mod_scale != 0): 108 | self.mod_pad_h = (self.mod_scale - h % self.mod_scale) 109 | if (w % self.mod_scale != 0): 110 | self.mod_pad_w = (self.mod_scale - w % self.mod_scale) 111 | self.img = F.pad(self.img, (0, self.mod_pad_w, 0, self.mod_pad_h), 'reflect') 112 | 113 | def process(self): 114 | # model inference 115 | self.output = self.model(self.img) 116 | 117 | def tile_process(self): 118 | """It will first crop input images to tiles, and then process each tile. 119 | Finally, all the processed tiles are merged into one images. 
120 | 121 | Modified from: https://github.com/ata4/esrgan-launcher 122 | """ 123 | batch, channel, height, width = self.img.shape 124 | output_height = height * self.scale 125 | output_width = width * self.scale 126 | output_shape = (batch, channel, output_height, output_width) 127 | 128 | # start with black image 129 | self.output = self.img.new_zeros(output_shape) 130 | tiles_x = math.ceil(width / self.tile_size) 131 | tiles_y = math.ceil(height / self.tile_size) 132 | 133 | # loop over all tiles 134 | for y in range(tiles_y): 135 | for x in range(tiles_x): 136 | # extract tile from input image 137 | ofs_x = x * self.tile_size 138 | ofs_y = y * self.tile_size 139 | # input tile area on total image 140 | input_start_x = ofs_x 141 | input_end_x = min(ofs_x + self.tile_size, width) 142 | input_start_y = ofs_y 143 | input_end_y = min(ofs_y + self.tile_size, height) 144 | 145 | # input tile area on total image with padding 146 | input_start_x_pad = max(input_start_x - self.tile_pad, 0) 147 | input_end_x_pad = min(input_end_x + self.tile_pad, width) 148 | input_start_y_pad = max(input_start_y - self.tile_pad, 0) 149 | input_end_y_pad = min(input_end_y + self.tile_pad, height) 150 | 151 | # input tile dimensions 152 | input_tile_width = input_end_x - input_start_x 153 | input_tile_height = input_end_y - input_start_y 154 | tile_idx = y * tiles_x + x + 1 155 | input_tile = self.img[:, :, input_start_y_pad:input_end_y_pad, input_start_x_pad:input_end_x_pad] 156 | 157 | # upscale tile 158 | try: 159 | with torch.no_grad(): 160 | output_tile = self.model(input_tile) 161 | except RuntimeError as error: 162 | print('Error', error) 163 | print(f'\tTile {tile_idx}/{tiles_x * tiles_y}') 164 | 165 | # output tile area on total image 166 | output_start_x = input_start_x * self.scale 167 | output_end_x = input_end_x * self.scale 168 | output_start_y = input_start_y * self.scale 169 | output_end_y = input_end_y * self.scale 170 | 171 | # output tile area without padding 172 | output_start_x_tile = (input_start_x - input_start_x_pad) * self.scale 173 | output_end_x_tile = output_start_x_tile + input_tile_width * self.scale 174 | output_start_y_tile = (input_start_y - input_start_y_pad) * self.scale 175 | output_end_y_tile = output_start_y_tile + input_tile_height * self.scale 176 | 177 | # put tile into output image 178 | self.output[:, :, output_start_y:output_end_y, 179 | output_start_x:output_end_x] = output_tile[:, :, output_start_y_tile:output_end_y_tile, 180 | output_start_x_tile:output_end_x_tile] 181 | 182 | def post_process(self): 183 | # remove extra pad 184 | if self.mod_scale is not None: 185 | _, _, h, w = self.output.size() 186 | self.output = self.output[:, :, 0:h - self.mod_pad_h * self.scale, 0:w - self.mod_pad_w * self.scale] 187 | # remove prepad 188 | if self.pre_pad != 0: 189 | _, _, h, w = self.output.size() 190 | self.output = self.output[:, :, 0:h - self.pre_pad * self.scale, 0:w - self.pre_pad * self.scale] 191 | return self.output 192 | 193 | @torch.no_grad() 194 | def enhance(self, img, outscale=None, alpha_upsampler='realesrgan'): 195 | h_input, w_input = img.shape[0:2] 196 | # img: numpy 197 | img = img.astype(np.float32) 198 | if np.max(img) > 256: # 16-bit image 199 | max_range = 65535 200 | print('\tInput is a 16-bit image') 201 | else: 202 | max_range = 255 203 | img = img / max_range 204 | if len(img.shape) == 2: # gray image 205 | img_mode = 'L' 206 | img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB) 207 | elif img.shape[2] == 4: # RGBA image with alpha channel 208 | img_mode = 
'RGBA' 209 | alpha = img[:, :, 3] 210 | img = img[:, :, 0:3] 211 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) 212 | if alpha_upsampler == 'realesrgan': 213 | alpha = cv2.cvtColor(alpha, cv2.COLOR_GRAY2RGB) 214 | else: 215 | img_mode = 'RGB' 216 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) 217 | 218 | # ------------------- process image (without the alpha channel) ------------------- # 219 | self.pre_process(img) 220 | if self.tile_size > 0: 221 | self.tile_process() 222 | else: 223 | self.process() 224 | output_img = self.post_process() 225 | output_img = output_img.data.squeeze().float().cpu().clamp_(0, 1).numpy() 226 | output_img = np.transpose(output_img[[2, 1, 0], :, :], (1, 2, 0)) 227 | if img_mode == 'L': 228 | output_img = cv2.cvtColor(output_img, cv2.COLOR_BGR2GRAY) 229 | 230 | # ------------------- process the alpha channel if necessary ------------------- # 231 | if img_mode == 'RGBA': 232 | if alpha_upsampler == 'realesrgan': 233 | self.pre_process(alpha) 234 | if self.tile_size > 0: 235 | self.tile_process() 236 | else: 237 | self.process() 238 | output_alpha = self.post_process() 239 | output_alpha = output_alpha.data.squeeze().float().cpu().clamp_(0, 1).numpy() 240 | output_alpha = np.transpose(output_alpha[[2, 1, 0], :, :], (1, 2, 0)) 241 | output_alpha = cv2.cvtColor(output_alpha, cv2.COLOR_BGR2GRAY) 242 | else: # use the cv2 resize for alpha channel 243 | h, w = alpha.shape[0:2] 244 | output_alpha = cv2.resize(alpha, (w * self.scale, h * self.scale), interpolation=cv2.INTER_LINEAR) 245 | 246 | # merge the alpha channel 247 | output_img = cv2.cvtColor(output_img, cv2.COLOR_BGR2BGRA) 248 | output_img[:, :, 3] = output_alpha 249 | 250 | # ------------------------------ return ------------------------------ # 251 | if max_range == 65535: # 16-bit image 252 | output = (output_img * 65535.0).round().astype(np.uint16) 253 | else: 254 | output = (output_img * 255.0).round().astype(np.uint8) 255 | 256 | if outscale is not None and outscale != float(self.scale): 257 | output = cv2.resize( 258 | output, ( 259 | int(w_input * outscale), 260 | int(h_input * outscale), 261 | ), interpolation=cv2.INTER_LANCZOS4) 262 | 263 | return output, img_mode 264 | 265 | 266 | class PrefetchReader(threading.Thread): 267 | """Prefetch images. 268 | 269 | Args: 270 | img_list (list[str]): A image list of image paths to be read. 271 | num_prefetch_queue (int): Number of prefetch queue. 
272 | """ 273 | 274 | def __init__(self, img_list, num_prefetch_queue): 275 | super().__init__() 276 | self.que = queue.Queue(num_prefetch_queue) 277 | self.img_list = img_list 278 | 279 | def run(self): 280 | for img_path in self.img_list: 281 | img = cv2.imread(img_path, cv2.IMREAD_UNCHANGED) 282 | self.que.put(img) 283 | 284 | self.que.put(None) 285 | 286 | def __next__(self): 287 | next_item = self.que.get() 288 | if next_item is None: 289 | raise StopIteration 290 | return next_item 291 | 292 | def __iter__(self): 293 | return self 294 | 295 | 296 | class IOConsumer(threading.Thread): 297 | 298 | def __init__(self, opt, que, qid): 299 | super().__init__() 300 | self._queue = que 301 | self.qid = qid 302 | self.opt = opt 303 | 304 | def run(self): 305 | while True: 306 | msg = self._queue.get() 307 | if isinstance(msg, str) and msg == 'quit': 308 | break 309 | 310 | output = msg['output'] 311 | save_path = msg['save_path'] 312 | cv2.imwrite(save_path, output) 313 | print(f'IO worker {self.qid} is done.') 314 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | opencv-python 2 | insightface 3 | argparse 4 | basicsr>=1.4.2 5 | facexlib>=0.2.5 6 | lmdb 7 | pyyaml 8 | scipy 9 | tb-nightly 10 | tqdm 11 | yapf 12 | pynvml 13 | codeformer-pip==0.0.4 14 | psutil -------------------------------------------------------------------------------- /start_mac.sh: -------------------------------------------------------------------------------- 1 | source venv/bin/activate -------------------------------------------------------------------------------- /start_venv_linux.sh: -------------------------------------------------------------------------------- 1 | source venv/bin/activate -------------------------------------------------------------------------------- /start_venv_windows.cmd: -------------------------------------------------------------------------------- 1 | start venv\scripts\activate.bat -------------------------------------------------------------------------------- /swapperfp16.py: -------------------------------------------------------------------------------- 1 | import time 2 | import numpy as np 3 | import onnxruntime 4 | import cv2 5 | import onnx 6 | from onnx import numpy_helper 7 | from insightface.utils import face_align 8 | from numpy.linalg import norm as l2norm 9 | import tqdm 10 | import requests 11 | import os 12 | 13 | 14 | class INSwapper(): 15 | def __init__(self, model_file=None, session=None): 16 | self.model_file = model_file 17 | self.session = session 18 | model = onnx.load(self.model_file) 19 | graph = model.graph 20 | self.emap = numpy_helper.to_array(graph.initializer[-1]) 21 | self.input_mean = 0.0 22 | self.input_std = 255.0 23 | #print('input mean and std:', model_file, self.input_mean, self.input_std) 24 | if self.session is None: 25 | self.session = onnxruntime.InferenceSession(self.model_file, None) 26 | inputs = self.session.get_inputs() 27 | self.input_names = [] 28 | for inp in inputs: 29 | self.input_names.append(inp.name) 30 | outputs = self.session.get_outputs() 31 | output_names = [] 32 | for out in outputs: 33 | output_names.append(out.name) 34 | self.output_names = output_names 35 | assert len(self.output_names)==1 36 | output_shape = outputs[0].shape 37 | input_cfg = inputs[0] 38 | input_shape = input_cfg.shape 39 | self.input_shape = input_shape 40 | print('inswapper-shape:', self.input_shape) 41 | self.input_size 
= tuple(input_shape[2:4][::-1]) 42 | def forward(self, img, latent): 43 | img = (img - self.input_mean) / self.input_std 44 | pred = self.session.run(self.output_names, {self.input_names[0]: img, self.input_names[1]: latent})[0] 45 | return pred 46 | def get(self, img, target_face, source_face, paste_back=True): 47 | aimg, M = face_align.norm_crop2(img, target_face.kps, self.input_size[0]) 48 | blob = cv2.dnn.blobFromImage(aimg, 1.0 / self.input_std, self.input_size, 49 | (self.input_mean, self.input_mean, self.input_mean), swapRB=True) 50 | s_e = source_face.normed_embedding 51 | n_e = s_e / l2norm(s_e) 52 | latent = n_e.reshape((1,-1)) 53 | 54 | latent = np.dot(latent, self.emap) 55 | latent /= np.linalg.norm(latent) 56 | pred = self.session.run(self.output_names, {self.input_names[0]: blob, self.input_names[1]: latent})[0] 57 | #print(latent.shape, latent.dtype, pred.shape) 58 | img_fake = pred.transpose((0,2,3,1))[0] 59 | #print("Minimum value:", np.min(img_fake)) 60 | #print("Maximum value:", np.max(img_fake)) 61 | bgr_fake = np.clip(255 * img_fake, 0, 255).astype(np.uint8)[:,:,::-1] 62 | if not paste_back: 63 | return bgr_fake, M 64 | else: 65 | target_img = img 66 | fake_diff = bgr_fake.astype(np.float32) - aimg.astype(np.float32) 67 | fake_diff = np.abs(fake_diff).mean(axis=2) 68 | fake_diff[:2,:] = 0 69 | fake_diff[-2:,:] = 0 70 | fake_diff[:,:2] = 0 71 | fake_diff[:,-2:] = 0 72 | IM = cv2.invertAffineTransform(M) 73 | img_mask = np.full((aimg.shape[0],aimg.shape[1]), 255, dtype=np.float32) 74 | bgr_fake = cv2.warpAffine(bgr_fake, IM, (target_img.shape[1], target_img.shape[0]), borderValue=0.0) 75 | img_mask = cv2.warpAffine(img_mask, IM, (target_img.shape[1], target_img.shape[0]), borderValue=0.0) 76 | fake_diff = cv2.warpAffine(fake_diff, IM, (target_img.shape[1], target_img.shape[0]), borderValue=0.0) 77 | img_mask[img_mask>20] = 255 78 | fthresh = 10 79 | fake_diff[fake_diff < fthresh] = 0 80 | fake_diff[fake_diff >= fthresh] = 255 81 | #img_mask = img_white 82 | mask_h_inds, mask_w_inds = np.where(img_mask==255) 83 | mask_h = np.max(mask_h_inds) - np.min(mask_h_inds) 84 | mask_w = np.max(mask_w_inds) - np.min(mask_w_inds) 85 | mask_size = int(np.sqrt(mask_h*mask_w)) 86 | k = max(mask_size//10, 10) 87 | #k = max(mask_size//20, 6) 88 | #k = 6 89 | kernel = np.ones((k,k),np.uint8) 90 | img_mask = cv2.erode(img_mask,kernel,iterations = 1) 91 | kernel = np.ones((2,2),np.uint8) 92 | fake_diff = cv2.dilate(fake_diff,kernel,iterations = 1) 93 | k = max(mask_size//20, 5) 94 | #k = 3 95 | #k = 3 96 | kernel_size = (k, k) 97 | blur_size = tuple(2*i+1 for i in kernel_size) 98 | img_mask = cv2.GaussianBlur(img_mask, blur_size, 0) 99 | k = 5 100 | kernel_size = (k, k) 101 | blur_size = tuple(2*i+1 for i in kernel_size) 102 | fake_diff = cv2.GaussianBlur(fake_diff, blur_size, 0) 103 | img_mask /= 255 104 | fake_diff /= 255 105 | #img_mask = fake_diff 106 | img_mask = np.reshape(img_mask, [img_mask.shape[0],img_mask.shape[1],1]) 107 | fake_merged = img_mask * bgr_fake + (1-img_mask) * target_img.astype(np.float32) 108 | fake_merged = fake_merged.astype(np.uint8) 109 | return fake_merged 110 | class PickableInferenceSession(onnxruntime.InferenceSession): 111 | # This is a wrapper to make the current InferenceSession class pickable.
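# Only the model path is stored in the pickled state; __setstate__ recreates the ONNX Runtime session from that path, since the session object itself cannot be pickled.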
112 | def __init__(self, model_path, **kwargs): 113 | super().__init__(model_path, **kwargs) 114 | self.model_path = model_path 115 | 116 | def __getstate__(self): 117 | return {'model_path': self.model_path} 118 | 119 | def __setstate__(self, values): 120 | model_path = values['model_path'] 121 | self.__init__(model_path) 122 | 123 | class ModelRouter: 124 | def __init__(self, onnx_file): 125 | self.onnx_file = onnx_file 126 | 127 | def get_model(self, **kwargs): 128 | session = PickableInferenceSession(self.onnx_file, **kwargs) 129 | print(f'Applied providers: {session._providers}, with options: {session._provider_options}') 130 | inputs = session.get_inputs() 131 | input_cfg = inputs[0] 132 | input_shape = input_cfg.shape 133 | outputs = session.get_outputs() 134 | return INSwapper(model_file=self.onnx_file, session=session) 135 | 136 | def get_default_providers(): 137 | return ['CUDAExecutionProvider', 'CPUExecutionProvider'] 138 | 139 | def get_default_provider_options(): 140 | return None 141 | 142 | def download(link, filename): 143 | response = requests.get(link, stream=True) 144 | total_size = int(response.headers.get('content-length', 0)) 145 | block_size = 1024*16 # 16 KB 146 | progress_bar = tqdm.tqdm(total=total_size, unit='B', unit_scale=True) 147 | 148 | with open(filename, 'wb') as file: 149 | for data in response.iter_content(block_size): 150 | progress_bar.update(len(data)) 151 | file.write(data) 152 | 153 | progress_bar.close() 154 | 155 | def check_or_download(filename): 156 | exists = os.path.exists(filename) 157 | if not exists: 158 | download(f"https://github.com/RichardErkhov/FastFaceSwap/releases/download/model/{filename}", filename) 159 | def get_model(name, **kwargs): 160 | check_or_download(name) 161 | router = ModelRouter(name) 162 | providers = kwargs.get('providers', get_default_providers()) 163 | provider_options = kwargs.get('provider_options', get_default_provider_options()) 164 | #session_options = kwargs.get('session_options', None) 165 | model = router.get_model(providers=providers, provider_options=provider_options)#, session_options = session_options) 166 | return model -------------------------------------------------------------------------------- /x.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RichardErkhov/FastFaceSwap/be10e8d0110678b955bd235a06be0f934f567890/x.pkl -------------------------------------------------------------------------------- /zlibwapi.dll: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RichardErkhov/FastFaceSwap/be10e8d0110678b955bd235a06be0f934f567890/zlibwapi.dll --------------------------------------------------------------------------------