├── __init__.py
├── assets
│   └── demo.PNG
├── export_trt.py
├── readme.md
├── requirements.txt
└── trt_utilities.py
/__init__.py:
--------------------------------------------------------------------------------
1 | import os
2 | import folder_paths
3 | import numpy as np
4 | import torch.nn.functional as F
5 | import torch
6 | from comfy.utils import ProgressBar
7 | import cv2
8 | from .trt_utilities import Engine
9 | from torchvision.transforms.functional import normalize
10 | from torchvision.utils import make_grid  # used by tensor2img's 4D branch
11 | import math  # used by tensor2img's 4D branch
12 |
11 | def tensor2img(tensor, rgb2bgr=True, out_type=np.uint8, min_max=(0, 1)):
12 | """Convert torch Tensors into image numpy arrays.
13 |
14 | After clamping to [min, max], values will be normalized to [0, 1].
15 |
16 | Args:
17 | tensor (Tensor or list[Tensor]): Accept shapes:
18 | 1) 4D mini-batch Tensor of shape (B x 3/1 x H x W);
19 | 2) 3D Tensor of shape (3/1 x H x W);
20 | 3) 2D Tensor of shape (H x W).
21 | Tensor channel should be in RGB order.
22 | rgb2bgr (bool): Whether to change rgb to bgr.
23 | out_type (numpy type): output types. If ``np.uint8``, transform outputs
24 | to uint8 type with range [0, 255]; otherwise, float type with
25 | range [0, 1]. Default: ``np.uint8``.
26 | min_max (tuple[int]): min and max values for clamp.
27 |
28 |     Returns:
29 |         (ndarray or list[ndarray]): 3D ndarray of shape (H x W x C) or 2D ndarray
30 |             of shape (H x W). Channel order is BGR if ``rgb2bgr`` is True, else RGB.
31 | """
32 | if not (torch.is_tensor(tensor) or (isinstance(tensor, list) and all(torch.is_tensor(t) for t in tensor))):
33 | raise TypeError(f'tensor or list of tensors expected, got {type(tensor)}')
34 |
35 | if torch.is_tensor(tensor):
36 | tensor = [tensor]
37 | result = []
38 | for _tensor in tensor:
39 | _tensor = _tensor.squeeze(0).float().detach().cpu().clamp_(*min_max)
40 | _tensor = (_tensor - min_max[0]) / (min_max[1] - min_max[0])
41 |
42 | n_dim = _tensor.dim()
43 | if n_dim == 4:
44 | img_np = make_grid(_tensor, nrow=int(math.sqrt(_tensor.size(0))), normalize=False).numpy()
45 | img_np = img_np.transpose(1, 2, 0)
46 | if rgb2bgr:
47 | img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR)
48 | elif n_dim == 3:
49 | img_np = _tensor.numpy()
50 | img_np = img_np.transpose(1, 2, 0)
51 | if img_np.shape[2] == 1: # gray image
52 | img_np = np.squeeze(img_np, axis=2)
53 | else:
54 | if rgb2bgr:
55 | img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR)
56 | elif n_dim == 2:
57 | img_np = _tensor.numpy()
58 | else:
59 |             raise TypeError(f'Only 4D, 3D and 2D tensors are supported, but got a tensor with {n_dim} dimensions')
60 | if out_type == np.uint8:
61 |             # Unlike MATLAB, numpy.uint8() will not round by default.
62 | img_np = (img_np * 255.0).round()
63 | img_np = img_np.astype(out_type)
64 | result.append(img_np)
65 | if len(result) == 1:
66 | result = result[0]
67 | return result
68 |
69 |
70 | ENGINE_DIR = os.path.join(folder_paths.models_dir, "tensorrt", "facerestore")
71 | os.makedirs(ENGINE_DIR, exist_ok=True)  # ensure the folder exists so os.listdir in INPUT_TYPES doesn't fail
71 |
72 | class FaceRestoreTensorrt:
73 | @classmethod
74 | def INPUT_TYPES(s):
75 | return {
76 | "required": {
77 | "images": ("IMAGE",),
78 | "engine": (os.listdir(ENGINE_DIR),),
79 | }
80 | }
81 | RETURN_NAMES = ("IMAGE",)
82 | RETURN_TYPES = ("IMAGE",)
83 | FUNCTION = "main"
84 | CATEGORY = "tensorrt"
85 |
86 | def main(self, images, engine):
87 |
88 |         # (re)load the tensorrt engine only when a different engine file is selected
89 | if (not hasattr(self, 'engine') or self.engine_label != engine):
90 | self.engine = Engine(os.path.join(ENGINE_DIR,engine))
91 | self.engine.load()
92 | self.engine.activate()
93 | self.engine.allocate_buffers()
94 | self.engine_label = engine
95 |
96 | cudaStream = torch.cuda.current_stream().cuda_stream
97 | pbar = ProgressBar(images.shape[0])
98 | images = images.permute(0, 3, 1, 2)
99 | images_resized = F.interpolate(images, size=(512,512), mode='bilinear', align_corners=False)
100 | images_list = list(torch.split(images_resized, split_size_or_sections=1))
101 |
102 | output_frames = []
103 |
104 | for img in images_list:
105 | normalize(img, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
106 | result = self.engine.infer({"input": img},cudaStream)
107 | output = result['output']
108 |
109 |             # tensor2img returns a BGR uint8 array; the final color swap restores
110 |             # RGB channel order for ComfyUI
111 |             output = tensor2img(output, rgb2bgr=True, min_max=(-1, 1))
112 |             output = cv2.resize(output, (images.shape[3], images.shape[2]))
113 |             output = cv2.cvtColor(output, cv2.COLOR_BGR2RGB)
113 |
114 | output_frames.append(output)
115 | pbar.update(1)
116 |
117 |
118 | output_frames = np.array(output_frames).astype(np.float32) / 255.0
119 | return (torch.from_numpy(output_frames),)
120 |
121 | NODE_CLASS_MAPPINGS = {
122 | "FaceRestoreTensorrt" : FaceRestoreTensorrt,
123 | }
124 |
125 | NODE_DISPLAY_NAME_MAPPINGS = {
126 | "FaceRestoreTensorrt" : "Face Restore Tensorrt",
127 | }
128 |
129 | __all__ = ['NODE_CLASS_MAPPINGS', 'NODE_DISPLAY_NAME_MAPPINGS']
--------------------------------------------------------------------------------
/assets/demo.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yuvraj108c/ComfyUI-Facerestore-Tensorrt/86eba4377a2c1dccacbf05b13ae93a764d9c0520/assets/demo.PNG
--------------------------------------------------------------------------------
/export_trt.py:
--------------------------------------------------------------------------------
1 | import os
2 | import time
3 | import torch
4 | from trt_utilities import Engine
4 |
5 |
6 | def export_trt(trt_path: str, onnx_path: str, use_fp16: bool):
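7 |     # build a TensorRT engine from onnx_path and serialize it to trt_path;
8 |     # use_fp16 enables FP16 kernels (faster, with a small precision trade-off)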
7 | engine = Engine(trt_path)
8 |
9 | torch.cuda.empty_cache()
10 |
11 | s = time.time()
12 | ret = engine.build(
13 | onnx_path,
14 | use_fp16,
15 | enable_preview=True,
16 | )
17 | e = time.time()
18 | print(f"Time taken to build: {(e-s)} seconds")
19 |
20 | return ret
21 |
22 |
23 | if __name__ == "__main__":
24 |     # build an engine for each ONNX model present in the current directory
25 |     if os.path.exists("./codeformer.onnx"):
26 |         export_trt(trt_path="./codeformer.engine",
27 |                    onnx_path="./codeformer.onnx", use_fp16=True)
28 |
29 |     if os.path.exists("./gfqgan.onnx"):
30 |         export_trt(trt_path="./gfqgan.engine",
31 |                    onnx_path="./gfqgan.onnx", use_fp16=False)
--------------------------------------------------------------------------------
/readme.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | # ComfyUI Facerestore TensorRT
4 |
5 | [Python 3.10](https://www.python.org/downloads/release/python-31012/)
6 | [CUDA](https://developer.nvidia.com/cuda-downloads)
7 | [TensorRT](https://developer.nvidia.com/tensorrt)
8 | [License: CC BY-NC-SA 4.0](https://creativecommons.org/licenses/by-nc-sa/4.0/deed.en)
9 |
10 |
11 | ![demo](./assets/demo.PNG)
12 |
17 | This project provides an experimental TensorRT implementation for ultra-fast face restoration inside ComfyUI.
18 |
19 | Note: This project doesn't do any pre-/post-processing; for now it only works on images that are already cropped to a face.
20 |
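21 | Since the node expects an already-cropped face, you may need to crop faces yourself first. A minimal pre-processing sketch (illustrative only, not part of this repo; it assumes OpenCV's bundled Haar cascade and a hypothetical input path):
22 |
23 | ```python
24 | import cv2
25 |
26 | img = cv2.imread("photo.jpg")  # hypothetical input path
27 | gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
28 | cascade = cv2.CascadeClassifier(
29 |     cv2.data.haarcascades + "haarcascade_frontalface_default.xml")
30 | x, y, w, h = cascade.detectMultiScale(gray, 1.1, 5)[0]  # first detected face
31 | face = cv2.resize(img[y:y + h, x:x + w], (512, 512))  # the node works at 512x512
32 | cv2.imwrite("face.png", face)
33 | ```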
21 |
22 |
23 | If you like the project, please give it a star! ⭐
24 |
25 | ---
26 |
27 | ## ⏱️ Performance
28 |
29 | _Note: The following results were benchmarked inside ComfyUI, using 100 similar frames_
30 |
31 | | Device   | Model      | Precision | FPS  |
32 | |----------|------------|-----------|------|
33 | | RTX 3090 | Codeformer | FP16      | 15.6 |
34 | | RTX 3090 | Gfqgan     | FP32      | 13.1 |
35 |
36 | ## 🚀 Installation
37 |
38 | Navigate to the ComfyUI `/custom_nodes` directory
39 |
40 | ```bash
41 | git clone https://github.com/yuvraj108c/ComfyUI-Facerestore-Tensorrt
42 | cd ./ComfyUI-Facerestore-Tensorrt
43 | pip install -r requirements.txt
44 | ```
45 |
46 | ## 🛠️ Building TensorRT Engines
47 |
48 | 1. Download one of the following ONNX models:
49 | - [gfqgan.onnx](https://huggingface.co/yuvraj108c/facerestore-onnx/resolve/main/gfqgan.onnx)
50 | - [codeformer.onnx](https://huggingface.co/yuvraj108c/facerestore-onnx/resolve/main/codeformer.onnx)
51 | 2. Build TensorRT engines for the downloaded models by running:
52 |
53 | - `python export_trt.py`
54 |
55 | 3. Place the exported engines inside the ComfyUI `/models/tensorrt/facerestore` directory
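56 |
57 | If you only downloaded one of the models, you can also call `export_trt` directly instead of running the full script (a sketch; adjust the paths to wherever your ONNX file lives):
58 |
59 | ```python
60 | from export_trt import export_trt
61 |
62 | # builds the engine next to the ONNX file; returns 0 on success
63 | export_trt(trt_path="./codeformer.engine",
64 |            onnx_path="./codeformer.onnx", use_fp16=True)
65 | ```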
56 |
57 | ## ☀️ Usage
58 |
59 | - Insert node by `Right Click -> tensorrt -> Face Restore Tensorrt`
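60 |
61 | Outside ComfyUI, the bundled `Engine` wrapper in `trt_utilities.py` can also be driven directly. A minimal sketch mirroring what the node's `main()` does (engine path hypothetical):
62 |
63 | ```python
64 | import torch
65 | from trt_utilities import Engine
66 |
67 | engine = Engine("./codeformer.engine")
68 | engine.load()              # deserialize the .engine file
69 | engine.activate()          # create an execution context
70 | engine.allocate_buffers()  # one torch tensor per engine I/O tensor
71 |
72 | img = torch.rand(1, 3, 512, 512, device="cuda") * 2 - 1  # values in [-1, 1]
73 | stream = torch.cuda.current_stream().cuda_stream
74 | out = engine.infer({"input": img}, stream)["output"]  # (1, 3, 512, 512), [-1, 1]
75 | ```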
60 |
61 | ## 🤖 Environment tested
62 |
63 | - Ubuntu 22.04 LTS, CUDA 12.4, TensorRT 10.4.0, Python 3.10, RTX 3090 GPU
64 | - Windows (not tested, but should work)
65 |
66 | ## 👏 Credits
67 |
68 | - https://github.com/bychen7/Face-Restoration-TensorRT
69 | - https://github.com/yuvraj108c/Codeformer-Tensorrt
70 |
71 | ## License
72 |
73 | [Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International (CC BY-NC-SA 4.0)](https://creativecommons.org/licenses/by-nc-sa/4.0/)
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | tensorrt==10.4.0
2 | polygraphy
3 | colored
4 | opencv-python
--------------------------------------------------------------------------------
/trt_utilities.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.cuda import nvtx
3 | from collections import OrderedDict
4 | import numpy as np
5 | from polygraphy.backend.common import bytes_from_path
6 | from polygraphy import util
7 | from polygraphy.backend.trt import ModifyNetworkOutputs, Profile
8 | from polygraphy.backend.trt import (
9 | engine_from_bytes,
10 | engine_from_network,
11 | network_from_onnx_path,
12 | save_engine,
13 | )
14 | from polygraphy.logger import G_LOGGER
15 | import tensorrt as trt
16 | from logging import error, warning
17 | from tqdm import tqdm
18 | import copy
19 |
20 | TRT_LOGGER = trt.Logger(trt.Logger.ERROR)
21 | G_LOGGER.module_severity = G_LOGGER.ERROR
22 |
23 | # Map of numpy dtype -> torch dtype
24 | numpy_to_torch_dtype_dict = {
25 | np.uint8: torch.uint8,
26 | np.int8: torch.int8,
27 | np.int16: torch.int16,
28 | np.int32: torch.int32,
29 | np.int64: torch.int64,
30 | np.float16: torch.float16,
31 | np.float32: torch.float32,
32 | np.float64: torch.float64,
33 | np.complex64: torch.complex64,
34 | np.complex128: torch.complex128,
35 | }
36 | if tuple(int(v) for v in np.__version__.split(".")[:2]) >= (1, 24):  # compare numerically, not as strings
37 |     numpy_to_torch_dtype_dict[np.bool_] = torch.bool
38 | else:
39 |     numpy_to_torch_dtype_dict[np.bool] = torch.bool
40 |
41 | # Map of torch dtype -> numpy dtype
42 | torch_to_numpy_dtype_dict = {
43 | value: key for (key, value) in numpy_to_torch_dtype_dict.items()
44 | }
45 |
46 |
47 | class TQDMProgressMonitor(trt.IProgressMonitor):
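48 |     # Renders nested tqdm progress bars for TensorRT build phases;
49 |     # returning False from step_complete asks TensorRT to cancel the build.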
48 | def __init__(self):
49 | trt.IProgressMonitor.__init__(self)
50 | self._active_phases = {}
51 | self._step_result = True
52 | self.max_indent = 5
53 |
54 | def phase_start(self, phase_name, parent_phase, num_steps):
55 | leave = False
56 | try:
57 | if parent_phase is not None:
58 | nbIndents = (
59 | self._active_phases.get(parent_phase, {}).get(
60 | "nbIndents", self.max_indent
61 | )
62 | + 1
63 | )
64 | if nbIndents >= self.max_indent:
65 | return
66 | else:
67 | nbIndents = 0
68 | leave = True
69 | self._active_phases[phase_name] = {
70 | "tq": tqdm(
71 | total=num_steps, desc=phase_name, leave=leave, position=nbIndents
72 | ),
73 | "nbIndents": nbIndents,
74 | "parent_phase": parent_phase,
75 | }
76 |         except KeyboardInterrupt:
77 |             # the callback cannot cancel the build directly; flag step_complete to request cancellation
78 |             self._step_result = False
79 |
80 | def phase_finish(self, phase_name):
81 | try:
82 | if phase_name in self._active_phases.keys():
83 | self._active_phases[phase_name]["tq"].update(
84 | self._active_phases[phase_name]["tq"].total
85 | - self._active_phases[phase_name]["tq"].n
86 | )
87 |
88 | parent_phase = self._active_phases[phase_name].get(
89 | "parent_phase", None)
90 | while parent_phase is not None:
91 | self._active_phases[parent_phase]["tq"].refresh()
92 | parent_phase = self._active_phases[parent_phase].get(
93 | "parent_phase", None
94 | )
95 |                 if (
96 |                     self._active_phases[phase_name]["parent_phase"]
97 |                     in self._active_phases.keys()
98 |                 ):
99 |                     self._active_phases[
100 |                         self._active_phases[phase_name]["parent_phase"]
101 |                     ]["tq"].refresh()
102 |                 del self._active_phases[phase_name]
103 |         except KeyboardInterrupt:
104 |             self._step_result = False
106 |
107 | def step_complete(self, phase_name, step):
108 | try:
109 | if phase_name in self._active_phases.keys():
110 | self._active_phases[phase_name]["tq"].update(
111 | step - self._active_phases[phase_name]["tq"].n
112 | )
113 | return self._step_result
114 | except KeyboardInterrupt:
115 | # There is no need to propagate this exception to TensorRT. We can simply cancel the build.
116 | return False
117 |
118 |
119 | class Engine:
120 | def __init__(
121 | self,
122 | engine_path,
123 | ):
124 | self.engine_path = engine_path
125 | self.engine = None
126 | self.context = None
127 | self.buffers = OrderedDict()
128 | self.tensors = OrderedDict()
129 | self.cuda_graph_instance = None # cuda graph
130 |
131 | def __del__(self):
132 | del self.engine
133 | del self.context
134 | del self.buffers
135 | del self.tensors
136 |
137 | def reset(self, engine_path=None):
138 | del self.engine
139 | del self.context
140 | del self.buffers
141 | del self.tensors
142 | self.engine_path = engine_path
143 |
144 | self.buffers = OrderedDict()
145 | self.tensors = OrderedDict()
146 | self.inputs = {}
147 | self.outputs = {}
148 |
149 | def build(
150 | self,
151 | onnx_path,
152 | fp16,
153 | input_profile=None,
154 | enable_refit=False,
155 | enable_preview=False,
156 | enable_all_tactics=False,
157 | timing_cache=None,
158 | update_output_names=None,
159 | ):
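160 |         # build a TensorRT engine from the given ONNX model and serialize it
161 |         # to self.engine_path; returns 0 on success, 1 on failure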
160 | print(f"Building TensorRT engine for {onnx_path}: {self.engine_path}")
161 | p = [Profile()]
162 | if input_profile:
163 | p = [Profile() for i in range(len(input_profile))]
164 | for _p, i_profile in zip(p, input_profile):
165 | for name, dims in i_profile.items():
166 | assert len(dims) == 3
167 | _p.add(name, min=dims[0], opt=dims[1], max=dims[2])
168 |
169 |         # NOTE: enable_all_tactics, enable_preview and timing_cache are
170 |         # currently accepted but not applied to the builder config
172 |
173 | network = network_from_onnx_path(
174 | onnx_path, flags=[trt.OnnxParserFlag.NATIVE_INSTANCENORM]
175 | )
176 | if update_output_names:
177 | print(f"Updating network outputs to {update_output_names}")
178 | network = ModifyNetworkOutputs(network, update_output_names)
179 |
180 | builder = network[0]
181 | config = builder.create_builder_config()
182 | config.progress_monitor = TQDMProgressMonitor()
183 |
184 | config.set_flag(trt.BuilderFlag.FP16) if fp16 else None
185 | config.set_flag(trt.BuilderFlag.REFIT) if enable_refit else None
186 |
187 | profiles = copy.deepcopy(p)
188 | for profile in profiles:
189 | # Last profile is used for set_calibration_profile.
190 | calib_profile = profile.fill_defaults(network[1]).to_trt(
191 | builder, network[1]
192 | )
193 | config.add_optimization_profile(calib_profile)
194 |
195 | try:
196 | engine = engine_from_network(
197 | network,
198 | config,
199 | )
200 | except Exception as e:
201 | error(f"Failed to build engine: {e}")
202 | return 1
203 | try:
204 | save_engine(engine, path=self.engine_path)
205 | except Exception as e:
206 | error(f"Failed to save engine: {e}")
207 | return 1
208 | return 0
209 |
210 | def load(self):
211 | print(f"Loading TensorRT engine: {self.engine_path}")
212 | self.engine = engine_from_bytes(bytes_from_path(self.engine_path))
213 |
214 | def activate(self, reuse_device_memory=None):
215 | if reuse_device_memory:
216 | self.context = self.engine.create_execution_context_without_device_memory()
217 |             self.context.device_memory = reuse_device_memory
218 | else:
219 | self.context = self.engine.create_execution_context()
220 |
221 | def allocate_buffers(self, shape_dict=None, device="cuda"):
222 | nvtx.range_push("allocate_buffers")
223 | for idx in range(self.engine.num_io_tensors):
224 | name = self.engine.get_tensor_name(idx)
225 |             # index I/O tensors by name (the bindings API was removed in TensorRT 10)
226 |             if shape_dict and name in shape_dict:
227 |                 shape = shape_dict[name]["shape"]
228 | else:
229 | shape = self.context.get_tensor_shape(name)
230 |
231 | dtype = trt.nptype(self.engine.get_tensor_dtype(name))
232 | if self.engine.get_tensor_mode(name) == trt.TensorIOMode.INPUT:
233 | self.context.set_input_shape(name, shape)
234 | tensor = torch.empty(
235 | tuple(shape), dtype=numpy_to_torch_dtype_dict[dtype]
236 | ).to(device=device)
237 |             self.tensors[name] = tensor
238 | nvtx.range_pop()
239 |
240 | def infer(self, feed_dict, stream, use_cuda_graph=False):
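241 |         # copy feed_dict values into the buffers preallocated by allocate_buffers(),
242 |         # bind their addresses, then run execute_async_v3 on the given stream;
243 |         # returns self.tensors, which also holds the output tensors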
241 | nvtx.range_push("set_tensors")
242 | for name, buf in feed_dict.items():
243 | self.tensors[name].copy_(buf)
244 |
245 | for name, tensor in self.tensors.items():
246 | self.context.set_tensor_address(name, tensor.data_ptr())
247 | nvtx.range_pop()
248 | nvtx.range_push("execute")
249 | noerror = self.context.execute_async_v3(stream)
250 | if not noerror:
251 | raise ValueError("ERROR: inference failed.")
252 | nvtx.range_pop()
253 | return self.tensors
254 |
255 |     def __str__(self):
256 |         out = ""
257 |         for opt_profile in range(self.engine.num_optimization_profiles):
258 |             for idx in range(self.engine.num_io_tensors):
259 |                 name = self.engine.get_tensor_name(idx)
260 |                 if self.engine.get_tensor_mode(name) != trt.TensorIOMode.INPUT:
261 |                     continue  # profile shapes only exist for input tensors
262 |                 shape = self.engine.get_tensor_profile_shape(name, opt_profile)
263 |                 out += f"\t{name} = {shape}\n"
264 |         return out
--------------------------------------------------------------------------------