├── __init__.py
├── assets
│   └── demo.PNG
├── export_trt.py
├── readme.md
├── requirements.txt
└── trt_utilities.py

/__init__.py:
--------------------------------------------------------------------------------
1 | import os
2 | import math
3 | import folder_paths
4 | import numpy as np
5 | import torch.nn.functional as F
6 | import torch
7 | from comfy.utils import ProgressBar
8 | import cv2
9 | from .trt_utilities import Engine
10 | from torchvision.transforms.functional import normalize
11 | from torchvision.utils import make_grid  # used by the 4D branch of tensor2img
12 | 
13 | def tensor2img(tensor, rgb2bgr=True, out_type=np.uint8, min_max=(0, 1)):
14 |     """Convert torch Tensors into image numpy arrays.
15 | 
16 |     After clamping to [min, max], values will be normalized to [0, 1].
17 | 
18 |     Args:
19 |         tensor (Tensor or list[Tensor]): Accept shapes:
20 |             1) 4D mini-batch Tensor of shape (B x 3/1 x H x W);
21 |             2) 3D Tensor of shape (3/1 x H x W);
22 |             3) 2D Tensor of shape (H x W).
23 |             Tensor channel should be in RGB order.
24 |         rgb2bgr (bool): Whether to change rgb to bgr.
25 |         out_type (numpy type): output types. If ``np.uint8``, transform outputs
26 |             to uint8 type with range [0, 255]; otherwise, float type with
27 |             range [0, 1]. Default: ``np.uint8``.
28 |         min_max (tuple[int]): min and max values for clamp.
29 | 
30 |     Returns:
31 |         (ndarray or list): 3D ndarray of shape (H x W x C) OR 2D ndarray of
32 |         shape (H x W). The channel order is BGR.
33 |     """
34 |     if not (torch.is_tensor(tensor) or (isinstance(tensor, list) and all(torch.is_tensor(t) for t in tensor))):
35 |         raise TypeError(f'tensor or list of tensors expected, got {type(tensor)}')
36 | 
37 |     if torch.is_tensor(tensor):
38 |         tensor = [tensor]
39 |     result = []
40 |     for _tensor in tensor:
41 |         _tensor = _tensor.squeeze(0).float().detach().cpu().clamp_(*min_max)
42 |         _tensor = (_tensor - min_max[0]) / (min_max[1] - min_max[0])
43 | 
44 |         n_dim = _tensor.dim()
45 |         if n_dim == 4:
46 |             img_np = make_grid(_tensor, nrow=int(math.sqrt(_tensor.size(0))), normalize=False).numpy()
47 |             img_np = img_np.transpose(1, 2, 0)
48 |             if rgb2bgr:
49 |                 img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR)
50 |         elif n_dim == 3:
51 |             img_np = _tensor.numpy()
52 |             img_np = img_np.transpose(1, 2, 0)
53 |             if img_np.shape[2] == 1:  # gray image
54 |                 img_np = np.squeeze(img_np, axis=2)
55 |             else:
56 |                 if rgb2bgr:
57 |                     img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR)
58 |         elif n_dim == 2:
59 |             img_np = _tensor.numpy()
60 |         else:
61 |             raise TypeError(f'Only support 4D, 3D or 2D tensor. But received with dimension: {n_dim}')
62 |         if out_type == np.uint8:
63 |             # Unlike MATLAB, numpy.uint8() WILL NOT round by default.
64 |             img_np = (img_np * 255.0).round()
65 |         img_np = img_np.astype(out_type)
66 |         result.append(img_np)
67 |     if len(result) == 1:
68 |         result = result[0]
69 |     return result
70 | 
71 | 
72 | ENGINE_DIR = os.path.join(folder_paths.models_dir, "tensorrt", "facerestore")
73 | 
74 | class FaceRestoreTensorrt:
75 |     @classmethod
76 |     def INPUT_TYPES(s):
77 |         return {
78 |             "required": {
79 |                 "images": ("IMAGE",),
80 |                 "engine": (os.listdir(ENGINE_DIR) if os.path.isdir(ENGINE_DIR) else [],),
81 |             }
82 |         }
83 |     RETURN_NAMES = ("IMAGE",)
84 |     RETURN_TYPES = ("IMAGE",)
85 |     FUNCTION = "main"
86 |     CATEGORY = "tensorrt"
87 | 
88 |     def main(self, images, engine):
89 | 
90 |         # setup tensorrt engine (reused across calls unless a different engine is selected)
91 |         if (not hasattr(self, 'engine') or self.engine_label != engine):
92 |             self.engine = Engine(os.path.join(ENGINE_DIR, engine))
93 |             self.engine.load()
94 |             self.engine.activate()
95 |             self.engine.allocate_buffers()
96 |             self.engine_label = engine
97 | 
98 |         cudaStream = torch.cuda.current_stream().cuda_stream
99 |         pbar = ProgressBar(images.shape[0])
100 |         images = images.permute(0, 3, 1, 2)  # NHWC -> NCHW
101 |         images_resized = F.interpolate(images, size=(512, 512), mode='bilinear', align_corners=False)
102 |         images_list = list(torch.split(images_resized, split_size_or_sections=1))
103 | 
104 |         output_frames = []
105 | 
106 |         for img in images_list:
107 |             normalize(img, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)  # scale to [-1, 1]
108 |             result = self.engine.infer({"input": img}, cudaStream)
109 |             output = result['output']
110 | 
111 |             output = tensor2img(output, rgb2bgr=True, min_max=(-1, 1))
112 |             output = output.astype('uint8')
113 |             output = cv2.resize(output, (images.shape[3], images.shape[2]))
114 |             output = cv2.cvtColor(output, cv2.COLOR_BGR2RGB)  # tensor2img returned BGR; ComfyUI expects RGB
115 | 
116 |             output_frames.append(output)
117 |             pbar.update(1)
118 | 
119 |         output_frames = np.array(output_frames).astype(np.float32) / 255.0
120 |         return (torch.from_numpy(output_frames),)
121 | 
122 | NODE_CLASS_MAPPINGS = {
123 |     "FaceRestoreTensorrt": FaceRestoreTensorrt,
124 | }
125 | 
126 | NODE_DISPLAY_NAME_MAPPINGS = {
127 |     "FaceRestoreTensorrt": "Face Restore Tensorrt",
128 | }
129 | 
130 | __all__ = ['NODE_CLASS_MAPPINGS', 'NODE_DISPLAY_NAME_MAPPINGS']
--------------------------------------------------------------------------------
/assets/demo.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yuvraj108c/ComfyUI-Facerestore-Tensorrt/86eba4377a2c1dccacbf05b13ae93a764d9c0520/assets/demo.PNG
--------------------------------------------------------------------------------
/export_trt.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import time
3 | from trt_utilities import Engine
4 | 
5 | 
6 | def export_trt(trt_path: str, onnx_path: str, use_fp16: bool):
7 |     engine = Engine(trt_path)
8 | 
9 |     torch.cuda.empty_cache()
10 | 
11 |     s = time.time()
12 |     ret = engine.build(
13 |         onnx_path,
14 |         use_fp16,
15 |         enable_preview=True,
16 |     )
17 |     e = time.time()
18 |     print(f"Time taken to build: {(e-s)} seconds")
19 | 
20 |     return ret
21 | 
22 | 
23 | # Guarded so the module can be imported without immediately building engines.
24 | if __name__ == "__main__":
25 |     export_trt(trt_path="./codeformer.engine",
26 |                onnx_path="./codeformer.onnx", use_fp16=True)
27 | 
28 |     export_trt(trt_path="./gfqgan.engine",
29 |                onnx_path="./gfqgan.onnx", use_fp16=False)
--------------------------------------------------------------------------------
/readme.md:
--------------------------------------------------------------------------------
1 | 
2 | # ComfyUI Facerestore TensorRT
3 | 
4 | [![python](https://img.shields.io/badge/python-3.10.12-green)](https://www.python.org/downloads/release/python-31012/)
5 | [![cuda](https://img.shields.io/badge/cuda-12.4-green)](https://developer.nvidia.com/cuda-downloads)
6 | [![trt](https://img.shields.io/badge/TRT-10.4-green)](https://developer.nvidia.com/tensorrt)
7 | [![by-nc-sa/4.0](https://img.shields.io/badge/license-CC--BY--NC--SA--4.0-lightgrey)](https://creativecommons.org/licenses/by-nc-sa/4.0/deed.en)
8 | 
9 | ![demo](assets/demo.PNG)
10 | 
11 | This project provides an experimental TensorRT implementation for ultra-fast face restoration inside ComfyUI.
12 | 
13 | Note: This project doesn't do any pre/post-processing; for now it works only on cropped faces.
14 | 
15 | If you like the project, please give it a star! ⭐
16 | 
17 | ---
18 | 
19 | ## ⏱️ Performance
20 | 
21 | _Note: The following results were benchmarked on ComfyUI, using 100 similar frames_
22 | 
23 | | Device   | Model      | Precision | FPS  |
24 | | -------- | ---------- | --------- | ---- |
25 | | RTX 3090 | Codeformer | FP16      | 15.6 |
26 | | RTX 3090 | Gfqgan     | FP32      | 13.1 |
27 | 
28 | ## 🚀 Installation
29 | 
30 | Navigate to the ComfyUI `/custom_nodes` directory
31 | 
32 | ```bash
33 | git clone https://github.com/yuvraj108c/ComfyUI-Facerestore-Tensorrt
34 | cd ./ComfyUI-Facerestore-Tensorrt
35 | pip install -r requirements.txt
36 | ```
37 | 
38 | ## 🛠️ Building TensorRT Engine
39 | 
40 | 1. Download one of the following onnx models:
41 |    - [gfqgan.onnx](https://huggingface.co/yuvraj108c/facerestore-onnx/resolve/main/gfqgan.onnx)
42 |    - [codeformer.onnx](https://huggingface.co/yuvraj108c/facerestore-onnx/resolve/main/codeformer.onnx)
43 | 2. Build TensorRT engines for these models by running:
44 |    - `python export_trt.py` (or call `export_trt` from your own code, as sketched below)
45 | 3. Place the exported engines inside the ComfyUI `/models/tensorrt/facerestore` directory
46 | 
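47 | The paths in `export_trt.py` assume the ONNX files sit next to the script and always build both engines. If your files live elsewhere, or you only need one engine, you can call `export_trt` directly — a minimal sketch, run from this node's directory so the import resolves; the paths below are placeholders for your own files:
48 | 
49 | ```python
50 | from export_trt import export_trt
51 | 
52 | # Build a single engine; the benchmarks above use FP16 for Codeformer and FP32 for Gfqgan.
53 | export_trt(trt_path="./codeformer.engine", onnx_path="/path/to/codeformer.onnx", use_fp16=True)
54 | ```
55 | 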
56 | ## ☀️ Usage
57 | 
58 | - Insert the node via `Right Click -> tensorrt -> Face Restore Tensorrt`
59 | 
60 | ## 🤖 Environment tested
61 | 
62 | - Ubuntu 22.04 LTS, Cuda 12.4, TensorRT 10.4.0, Python 3.10, RTX 3090 GPU
63 | - Windows (Not tested, but should work)
64 | 
65 | ## 👏 Credits
66 | 
67 | - https://github.com/bychen7/Face-Restoration-TensorRT
68 | - https://github.com/yuvraj108c/Codeformer-Tensorrt
69 | 
70 | ## License
71 | 
72 | [Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International (CC BY-NC-SA 4.0)](https://creativecommons.org/licenses/by-nc-sa/4.0/)
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | tensorrt==10.4.0
2 | polygraphy
3 | colored
--------------------------------------------------------------------------------
/trt_utilities.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.cuda import nvtx
3 | from collections import OrderedDict
4 | import numpy as np
5 | from polygraphy.backend.common import bytes_from_path
6 | from polygraphy import util
7 | from polygraphy.backend.trt import ModifyNetworkOutputs, Profile
8 | from polygraphy.backend.trt import (
9 |     engine_from_bytes,
10 |     engine_from_network,
11 |     network_from_onnx_path,
12 |     save_engine,
13 | )
14 | from polygraphy.logger import G_LOGGER
15 | import tensorrt as trt
16 | from logging import error, warning
17 | from tqdm import tqdm
18 | import copy
19 | 
20 | TRT_LOGGER = trt.Logger(trt.Logger.ERROR)
21 | G_LOGGER.module_severity = G_LOGGER.ERROR
22 | 
23 | # Map of numpy dtype -> torch dtype
24 | numpy_to_torch_dtype_dict = {
25 |     np.uint8: torch.uint8,
26 |     np.int8: torch.int8,
27 |     np.int16: torch.int16,
28 |     np.int32: torch.int32,
29 |     np.int64: torch.int64,
30 |     np.float16: torch.float16,
31 |     np.float32: torch.float32,
32 |     np.float64: torch.float64,
33 |     np.complex64: torch.complex64,
34 |     np.complex128: torch.complex128,
35 | }
36 | if tuple(int(v) for v in np.version.full_version.split(".")[:2]) >= (1, 24):  # compare numerically; string comparison misorders versions
37 |     numpy_to_torch_dtype_dict[np.bool_] = torch.bool
38 | else:
39 |     numpy_to_torch_dtype_dict[np.bool] = torch.bool
40 | 
41 | # Map of torch dtype -> numpy dtype
42 | torch_to_numpy_dtype_dict = {
43 |     value: key for (key, value) in numpy_to_torch_dtype_dict.items()
44 | }
45 | 
46 | 
47 | class TQDMProgressMonitor(trt.IProgressMonitor):
48 |     def __init__(self):
49 |         trt.IProgressMonitor.__init__(self)
50 |         self._active_phases = {}
51 |         self._step_result = True
52 |         self.max_indent = 5
53 | 
54 |     def phase_start(self, phase_name, parent_phase, num_steps):
55 |         leave = False
56 |         try:
57 |             if parent_phase is not None:
58 |                 nbIndents = (
59 |                     self._active_phases.get(parent_phase, {}).get(
60 |                         "nbIndents", self.max_indent
61 |                     )
62 |                     + 1
63 |                 )
64 |                 if nbIndents >= self.max_indent:
65 |                     return
66 |             else:
67 |                 nbIndents = 0
68 |                 leave = True
69 |             self._active_phases[phase_name] = {
70 |                 "tq": tqdm(
71 |                     total=num_steps, desc=phase_name, leave=leave, position=nbIndents
72 |                 ),
73 |                 "nbIndents": nbIndents,
74 |                 "parent_phase": parent_phase,
75 |             }
76 |         except KeyboardInterrupt:
77 |             # The phase_start callback cannot directly cancel the build, so request the cancellation from within step_complete.
78 |             self._step_result = False
79 | 
80 |     def phase_finish(self, phase_name):
81 |         try:
82 |             if phase_name in self._active_phases.keys():
83 |                 self._active_phases[phase_name]["tq"].update(
84 |                     self._active_phases[phase_name]["tq"].total
85 |                     - self._active_phases[phase_name]["tq"].n
86 |                 )
87 | 
88 |                 parent_phase = self._active_phases[phase_name].get(
89 |                     "parent_phase", None)
90 |                 while parent_phase is not None:
91 |                     self._active_phases[parent_phase]["tq"].refresh()
92 |                     parent_phase = self._active_phases[parent_phase].get(
93 |                         "parent_phase", None
94 |                     )
95 |                 if (
96 |                     self._active_phases[phase_name]["parent_phase"]
97 |                     in self._active_phases.keys()
98 |                 ):
99 |                     self._active_phases[
100 |                         self._active_phases[phase_name]["parent_phase"]
101 |                     ]["tq"].refresh()
102 |                 del self._active_phases[phase_name]
103 |             pass
104 |         except KeyboardInterrupt:
105 |             self._step_result = False
106 | 
107 |     def step_complete(self, phase_name, step):
108 |         try:
109 |             if phase_name in self._active_phases.keys():
110 |                 self._active_phases[phase_name]["tq"].update(
111 |                     step - self._active_phases[phase_name]["tq"].n
112 |                 )
113 |             return self._step_result
114 |         except KeyboardInterrupt:
115 |             # There is no need to propagate this exception to TensorRT. We can simply cancel the build.
116 |             return False
117 | 
118 | 
119 | class Engine:
120 |     def __init__(
121 |         self,
122 |         engine_path,
123 |     ):
124 |         self.engine_path = engine_path
125 |         self.engine = None
126 |         self.context = None
127 |         self.buffers = OrderedDict()
128 |         self.tensors = OrderedDict()
129 |         self.cuda_graph_instance = None  # cuda graph
130 | 
131 |     def __del__(self):
132 |         del self.engine
133 |         del self.context
134 |         del self.buffers
135 |         del self.tensors
136 | 
137 |     def reset(self, engine_path=None):
138 |         del self.engine
139 |         del self.context
140 |         del self.buffers
141 |         del self.tensors
142 |         self.engine_path = engine_path
143 | 
144 |         self.buffers = OrderedDict()
145 |         self.tensors = OrderedDict()
146 |         self.inputs = {}
147 |         self.outputs = {}
148 | 
149 |     def build(
150 |         self,
151 |         onnx_path,
152 |         fp16,
153 |         input_profile=None,
154 |         enable_refit=False,
155 |         enable_preview=False,
156 |         enable_all_tactics=False,
157 |         timing_cache=None,
158 |         update_output_names=None,
159 |     ):
160 |         print(f"Building TensorRT engine for {onnx_path}: {self.engine_path}")
161 |         p = [Profile()]
162 |         if input_profile:
163 |             p = [Profile() for i in range(len(input_profile))]
164 |             for _p, i_profile in zip(p, input_profile):
165 |                 for name, dims in i_profile.items():
166 |                     assert len(dims) == 3
167 |                     _p.add(name, min=dims[0], opt=dims[1], max=dims[2])
168 | 
169 |         config_kwargs = {}
170 |         if not enable_all_tactics:
171 |             config_kwargs["tactic_sources"] = []  # note: collected but not currently applied to the builder config
172 | 
173 |         network = network_from_onnx_path(
174 |             onnx_path, flags=[trt.OnnxParserFlag.NATIVE_INSTANCENORM]
175 |         )
176 |         if update_output_names:
177 |             print(f"Updating network outputs to {update_output_names}")
178 |             network = ModifyNetworkOutputs(network, update_output_names)
179 | 
180 |         builder = network[0]
181 |         config = builder.create_builder_config()
182 |         config.progress_monitor = TQDMProgressMonitor()
183 | 
184 |         if fp16: config.set_flag(trt.BuilderFlag.FP16)
185 |         if enable_refit: config.set_flag(trt.BuilderFlag.REFIT)
186 | 
187 |         profiles = copy.deepcopy(p)
188 |         for profile in profiles:
189 |             # Convert each Polygraphy profile to a TensorRT one and register it as an optimization profile.
190 |             trt_profile = profile.fill_defaults(network[1]).to_trt(
191 |                 builder, network[1]
192 |             )
193 |             config.add_optimization_profile(trt_profile)
194 | 
195 |         try:
196 |             engine = engine_from_network(
197 |                 network,
198 |                 config,
199 |             )
200 |         except Exception as e:
201 |             error(f"Failed to build engine: {e}")
202 |             return 1
203 |         try:
204 |             save_engine(engine, path=self.engine_path)
205 |         except Exception as e:
206 |             error(f"Failed to save engine: {e}")
207 |             return 1
208 |         return 0
209 | 
210 |     def load(self):
211 |         print(f"Loading TensorRT engine: {self.engine_path}")
212 |         self.engine = engine_from_bytes(bytes_from_path(self.engine_path))
213 | 
214 |     def activate(self, reuse_device_memory=None):
215 |         if reuse_device_memory:
216 |             self.context = self.engine.create_execution_context_without_device_memory()
217 |             # self.context.device_memory = reuse_device_memory
218 |         else:
219 |             self.context = self.engine.create_execution_context()
220 | 
221 |     def allocate_buffers(self, shape_dict=None, device="cuda"):
222 |         nvtx.range_push("allocate_buffers")
223 |         for idx in range(self.engine.num_io_tensors):
224 |             name = self.engine.get_tensor_name(idx)
225 |             if shape_dict and name in shape_dict:
226 |                 shape = shape_dict[name]["shape"]
227 |             else:
228 |                 shape = self.context.get_tensor_shape(name)
229 | 
230 |             dtype = trt.nptype(self.engine.get_tensor_dtype(name))
231 |             if self.engine.get_tensor_mode(name) == trt.TensorIOMode.INPUT:
232 |                 self.context.set_input_shape(name, shape)
233 |             tensor = torch.empty(
234 |                 tuple(shape), dtype=numpy_to_torch_dtype_dict[dtype], device=device
235 |             )
236 |             self.tensors[name] = tensor
237 |         nvtx.range_pop()
238 | 
239 |     def infer(self, feed_dict, stream, use_cuda_graph=False):
240 |         nvtx.range_push("set_tensors")
241 |         for name, buf in feed_dict.items():
242 |             self.tensors[name].copy_(buf)
243 | 
244 |         for name, tensor in self.tensors.items():
245 |             self.context.set_tensor_address(name, tensor.data_ptr())
246 |         nvtx.range_pop()
247 |         nvtx.range_push("execute")
248 |         noerror = self.context.execute_async_v3(stream)
249 |         if not noerror:
250 |             raise ValueError("ERROR: inference failed.")
251 |         nvtx.range_pop()
252 |         return self.tensors
253 | 
254 |     def __str__(self):
255 |         # Use the TensorRT 10 tensor-based API (the legacy binding API was removed in TRT 10).
256 |         out = ""
257 |         for opt_profile in range(self.engine.num_optimization_profiles):
258 |             for idx in range(self.engine.num_io_tensors):
259 |                 name = self.engine.get_tensor_name(idx)
260 |                 if self.engine.get_tensor_mode(name) != trt.TensorIOMode.INPUT:
261 |                     continue  # profile shapes only exist for inputs
262 |                 shape = self.engine.get_tensor_profile_shape(name, opt_profile)
263 |                 out += f"\t{name} = {shape}\n"
264 |         return out
--------------------------------------------------------------------------------
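
For reference, a minimal sketch of driving `Engine` directly outside ComfyUI, mirroring what `__init__.py` does per frame. It assumes an engine built from the ONNX models above, whose I/O tensors are named `input` and `output`, shaped 1x3x512x512, with values normalized to [-1, 1]; the engine path is a placeholder for your own build:

```python
import torch
from trt_utilities import Engine

engine = Engine("./codeformer.engine")  # example path; use your built engine
engine.load()                # deserialize the engine from disk
engine.activate()            # create an execution context
engine.allocate_buffers()    # allocate input/output tensors on the GPU

# Dummy face: 1x3x512x512 RGB, normalized to [-1, 1] as the node does via normalize().
face = torch.rand(1, 3, 512, 512, device="cuda") * 2 - 1

stream = torch.cuda.current_stream().cuda_stream
restored = engine.infer({"input": face}, stream)["output"]  # also 1x3x512x512 in [-1, 1]
print(restored.shape)
```

Note that `infer` returns the engine's own reused output buffers, so clone the result (`restored.clone()`) before the next call if you need to keep a frame.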