├── .gitignore ├── LICENSE ├── README.md ├── mindspore ├── README.md ├── build_mindir.py ├── convert_weight.py ├── dataset │ ├── triplet.py │ └── video.py ├── environment-full.yml ├── environment.yml ├── inference.py ├── model │ ├── __init__.py │ ├── deform_conv.py │ ├── model.py │ ├── model_sep.py │ └── train.py ├── train.py └── util │ ├── converter.py │ ├── normalize.py │ └── rmse.py ├── tensorrt-conda └── recipe │ ├── LICENSE │ ├── bld.bat │ ├── build.sh │ ├── meta.yaml │ └── run_test.py ├── tensorrt ├── .clang-format ├── .gitignore ├── CMakeLists.txt ├── README.md ├── app │ ├── inference_y4m.cpp │ └── optimizer.cpp ├── cmake │ └── Modules │ │ ├── FindTensorRT.cmake │ │ └── FindVapourSynth.cmake ├── demo.vpy ├── include │ ├── config.h │ ├── debug │ │ └── reveal.h │ ├── helper.h │ ├── inference.h │ ├── logging.h │ ├── md_view.h │ ├── optimize.h │ ├── reformat.h │ └── utils.h ├── layers │ ├── impl │ │ └── dcn_layer.cu │ ├── include │ │ ├── internal │ │ │ ├── config.h │ │ │ ├── dcn_layer.h │ │ │ └── dcn_layer_impl.h │ │ └── layers.h │ ├── src │ │ ├── dcn_layer.cpp │ │ └── layers.cpp │ └── test │ │ └── dcn_layer_test.cpp ├── model_src │ └── .gitkeep ├── models │ └── .gitkeep └── src │ ├── inference.cpp │ ├── optimize.cpp │ ├── reformat.cu │ └── vs-plugin.cpp └── torch ├── README.md ├── cycmunet ├── model.py └── run.py ├── cycmunet_export_onnx.py ├── cycmunet_test.py ├── cycmunet_train.py ├── dataset ├── __init__.py ├── sequence.py ├── util.py └── video.py ├── environment.yml └── model ├── __init__.py ├── cycmunet ├── __init__.py ├── module.py └── part.py ├── part.py └── util.py /.gitignore: -------------------------------------------------------------------------------- 1 | .idea 2 | __pycache__ 3 | checkpoints -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright 2023 Yuan Tong 2 | 3 | Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: 4 | 5 | 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. 6 | 7 | 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. 8 | 9 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS “AS IS” AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Implementation for CycMuNet+ 2 | 3 | This repo contains the implementation of the CVPR 2022 paper 4 | [*CycMuNet+: Cycle-Projected Mutual Learning for Spatial-Temporal Video Super-Resolution*](https://openaccess.thecvf.com/content/CVPR2022/html/Hu_Spatial-Temporal_Space_Hand-in-Hand_Spatial-Temporal_Video_Super-Resolution_via_Cycle-Projected_Mutual_Learning_CVPR_2022_paper.html) 5 | in multiple platforms. 6 | See README in each folder for detail information. 7 | 8 | [torch](https://github.com/tongyuantongyu/cycmunet/tree/main/torch) contains PyTorch implementation of train and test code. 9 | 10 | [mindspore](https://github.com/tongyuantongyu/cycmunet/tree/main/mindspore) contains MindSpore implementation of train and test code 11 | capable of running on Huawei Ascend platform. 12 | 13 | [tensorrt](https://github.com/tongyuantongyu/cycmunet/tree/main/tensorrt) contains TensorRT implementation of inference code, 14 | as well as a VapourSynth plugin ready for use. 15 | -------------------------------------------------------------------------------- /mindspore/README.md: -------------------------------------------------------------------------------- 1 | # MindSpore implementation of CycMuNet+ 2 | 3 | This is the MindSpore implementation of CycMuNet+ capable of running on CPU and 4 | Huawei Ascend AI Processor. 5 | 6 | ## Installation 7 | 8 | We recommend using a prebuilt image provided by Huawei with necessary Ascend 9 | and MindSpore dependencies installed. Run following command to install extra 10 | necessary dependencies: 11 | 12 | ```bash 13 | conda env update -f environment.yml 14 | ``` 15 | 16 | `environment-full.yml` contains a full list of libraries in the environment. 17 | Note that you may not be able to recreate the environment using this file. 18 | Instead see [https://mindspore.cn/install](https://mindspore.cn/install) for 19 | instructions to manually install MindSpore if you prefer. 20 | 21 | ## Contents 22 | 23 | ### `train.py` 24 | 25 | Train the network. 26 | 27 | ### `inference.py` 28 | 29 | Do network inference. 30 | 31 | ### `convert_weight.py` 32 | 33 | Import checkpoints from PyTorch to MindSpore. 34 | 35 | ### `build_mindir.py` 36 | 37 | Build MindIR file. 38 | -------------------------------------------------------------------------------- /mindspore/build_mindir.py: -------------------------------------------------------------------------------- 1 | from collections import namedtuple 2 | import time 3 | 4 | import numpy as np 5 | import mindspore as ms 6 | 7 | from model import CycMuNet 8 | 9 | 10 | def export_rgb(checkpoint, size): 11 | dummyArg = namedtuple('dummyArg', ( 12 | 'nf', 'groups', 'upscale_factor', 'format', 'layers', 'cycle_count', 'batch_mode', 'all_frames', 13 | 'stop_at_conf')) 14 | 15 | args = dummyArg(nf=64, groups=8, upscale_factor=4, format='rgb', layers=3, cycle_count=5, batch_mode='sequence', 16 | all_frames=False, stop_at_conf=False) 17 | 18 | ms.set_context(mode=ms.GRAPH_MODE, device_target="Ascend") 19 | 20 | print('Init done') 21 | build_start = time.time() 22 | model = CycMuNet(args) 23 | ms.load_checkpoint(checkpoint, model) 24 | 25 | inp = ms.Tensor(np.ones((2, 3, *size), dtype=np.float32)) 26 | inp = inp.astype(ms.float16) 27 | model = model.to_float(ms.float16) 28 | model.compile(inp) 29 | print(f'Load done in {time.time() - build_start}s') 30 | # verify 31 | model(inp) 32 | 33 | ms.export(model, inp, file_name=model, file_format='MINDIR') 34 | print('Export done') 35 | 36 | 37 | def export_yuv(checkpoint, size): 38 | dummyArg = namedtuple('dummyArg', ( 39 | 'nf', 'groups', 'upscale_factor', 'format', 'layers', 'cycle_count', 'batch_mode', 'all_frames', 40 | 'stop_at_conf')) 41 | 42 | args = dummyArg(nf=64, groups=8, upscale_factor=2, format='yuv420', layers=4, cycle_count=3, batch_mode='sequence', 43 | all_frames=False, stop_at_conf=False) 44 | 45 | ms.set_context(mode=ms.GRAPH_MODE, device_target="Ascend") 46 | 47 | print('Init done') 48 | build_start = time.time() 49 | model = CycMuNet(args) 50 | ms.load_checkpoint(checkpoint, model) 51 | 52 | inp_y = ms.Tensor(np.zeros((2, 1, *size), dtype=np.float16)) 53 | inp_uv = ms.Tensor(np.zeros((2, 2, size[0] // 2, size[1] // 2), dtype=np.float16)) 54 | model = model.to_float(ms.float16) 55 | model.compile(inp_y, inp_uv) 56 | print(f'Load done in {time.time() - build_start}s') 57 | # verify 58 | model(inp_y, inp_uv) 59 | 60 | ms.export(model, inp_y, inp_uv, file_name=model, file_format='MINDIR') 61 | print('Export done') 62 | 63 | 64 | if __name__ == '__main__': 65 | # export_rgb('model-files/2x_rgb_base.ckpt', 'model-files/cycmunet_2x_rgb', (64, 64)) 66 | export_yuv('model-files/2x_yuv420_cycle3_layer4.ckpt', 'model-files/cycmunet_2x_yuv420_cycle3_layer4', (1920, 1088)) 67 | -------------------------------------------------------------------------------- /mindspore/convert_weight.py: -------------------------------------------------------------------------------- 1 | import re 2 | from collections import namedtuple 3 | 4 | import mindspore as ms 5 | from mindspore import ops 6 | import torch 7 | 8 | import model 9 | 10 | 11 | def transform_dcnpack(weights): 12 | result = { 13 | 'dcn_weight': weights['dcn.weight'], 14 | 'dcn_bias': weights['dcn.bias'], 15 | 'conv_mask.weight': weights['conv_mask.weight'], 16 | 'conv_mask.bias': weights['conv_mask.bias'], 17 | } 18 | 19 | w = weights['conv_offset.weight'].reshape(72, 2, 64, 3, 3) 20 | b = weights['conv_offset.bias'].reshape(72, 2) 21 | w = w[:, ::-1, ...].transpose(1, 0, 2, 3, 4).reshape(144, 64, 3, 3) 22 | b = b[:, ::-1, ...].transpose(1, 0).reshape(144) 23 | 24 | result['conv_offset.weight'] = w 25 | result['conv_offset.bias'] = b 26 | return result 27 | 28 | 29 | if __name__ == '__main__': 30 | torch_source = 'checkpoints/2x_cycle3_yuv420_sparsity_epoch_20.pth' 31 | ms_normal = 'checkpoints/2x_yuv420_cycle3_layer4.ckpt' 32 | ms_sep = 'checkpoints/2x_sep_yuv420_cycle3_layer4.ckpt' 33 | 34 | dummyArg = namedtuple('dummyArg', ( 35 | 'nf', 'groups', 'upscale_factor', 'format', 'layers', 'cycle_count', 'batch_mode', 'all_frames', 36 | 'stop_at_conf')) 37 | 38 | size = (64, 64) 39 | args = dummyArg(nf=64, groups=8, upscale_factor=2, format='yuv420', layers=4, cycle_count=3, batch_mode='sequence', 40 | all_frames=False, stop_at_conf=False) 41 | 42 | print('Init done') 43 | 44 | rewrite_names = { 45 | ".Pro_align.conv1x1.": 3, 46 | ".Pro_align.conv1_3x3.": 2, 47 | ".offset_conv1.": 2, 48 | ".offset_conv2.": 2, 49 | ".fea_conv.": 2, 50 | "ff.fusion.": 2, 51 | "mu.conv.": 2 * args.cycle_count + 1 52 | } 53 | 54 | rewrite_names_re = { 55 | r"(merge1?\.(\d+)\.)": lambda match: int(match[2]) + 1, 56 | } 57 | 58 | # normal model 59 | model_normal = model.CycMuNet(args) 60 | 61 | source = torch.load(torch_source, map_location=torch.device('cpu')) 62 | source = {k: v for k, v in source.items() if '__weight_mma_mask' not in k} 63 | template = model_normal.parameters_dict() 64 | 65 | dest = dict() 66 | pending_dcn = dict() 67 | for k, v in source.items(): 68 | if '.dcnpack.' in k: 69 | module, name = k.split('.dcnpack.') 70 | if module in pending_dcn: 71 | pending_dcn[module][name] = v.numpy() 72 | else: 73 | pending_dcn[module] = {name: v.numpy()} 74 | continue 75 | 76 | for name in rewrite_names: 77 | k = k.replace(name, name + 'conv.') 78 | for re_name in rewrite_names_re: 79 | k = re.sub(re_name, "\\1conv.", k) 80 | if k in template: 81 | dest[k] = ms.Parameter(v.numpy()) 82 | else: 83 | print(f"Unknown parameter {k} ignored.") 84 | 85 | for m, ws in pending_dcn.items(): 86 | for name, w in transform_dcnpack(ws).items(): 87 | dest[f'{m}.dcnpack.{name}'] = ms.Parameter(w) 88 | 89 | print(ms.load_param_into_net(model_normal, dest, strict_load=True)) 90 | ms.save_checkpoint(model_normal, ms_normal) 91 | 92 | print('Done normal model') 93 | 94 | # sep model: concat + conv is separated to multiple conv + add, to reduce memory footprint 95 | model_separate = model.cycmunet_sep(args) 96 | template = model_separate.parameters_dict() 97 | 98 | dest = dict() 99 | pending_dcn = dict() 100 | 101 | def filter_catconv(k, tensor): 102 | for name, n in rewrite_names.items(): 103 | if name in k: 104 | t = ms.Parameter(tensor.numpy()) 105 | if k.endswith('.weight'): 106 | dest.update({k.replace(name, f'{name}convs.{i}.'): ms.Parameter(v) for i, v in 107 | enumerate(ops.split(t, axis=1, output_num=n))}) 108 | elif k.endswith('.bias'): 109 | dest[k.replace(name, name + 'convs.0.')] = t 110 | return True 111 | for name, get_n in rewrite_names_re.items(): 112 | search_result = re.search(name, k) 113 | if not search_result: 114 | continue 115 | n = get_n(search_result) 116 | t = ms.Parameter(tensor.numpy()) 117 | if k.endswith('.weight'): 118 | dest.update({re.sub(name, f'\\1convs.{i}.', k): ms.Parameter(v) for i, v in 119 | enumerate(ops.split(t, axis=1, output_num=n))}) 120 | elif k.endswith('.bias'): 121 | dest[re.sub(name, f'\\1convs.0.', k)] = t 122 | return True 123 | return False 124 | 125 | for k, v in source.items(): 126 | if '.dcnpack.' in k: 127 | module, name = k.split('.dcnpack.') 128 | if module in pending_dcn: 129 | pending_dcn[module][name] = v.numpy() 130 | else: 131 | pending_dcn[module] = {name: v.numpy()} 132 | continue 133 | 134 | if filter_catconv(k, v): 135 | continue 136 | if k in template: 137 | dest[k] = ms.Parameter(v.numpy()) 138 | else: 139 | print(f"Unknown parameter {k} ignored.") 140 | 141 | for m, ws in pending_dcn.items(): 142 | for name, w in transform_dcnpack(ws).items(): 143 | dest[f'{m}.dcnpack.{name}'] = ms.Parameter(w) 144 | 145 | print(ms.load_param_into_net(model_separate, dest, strict_load=True)) 146 | ms.save_checkpoint(model_separate, ms_sep) 147 | 148 | print('Done separate model') 149 | -------------------------------------------------------------------------------- /mindspore/dataset/triplet.py: -------------------------------------------------------------------------------- 1 | import pathlib 2 | import random 3 | from typing import List 4 | 5 | import cv2 6 | import mindspore as ms 7 | from mindspore import dataset 8 | from mindspore.dataset import vision, transforms 9 | import numpy as np 10 | from PIL import Image, ImageFilter 11 | 12 | 13 | class ImageTripletGenerator: 14 | def __init__(self, index_file, patch_size, scale_factor, augment, seed=0): 15 | self.dataset_base = pathlib.Path(index_file).parent 16 | self.triplets = [i for i in open(index_file, 'r', encoding='utf-8').read().split('\n') 17 | if i if not i.startswith('#')] 18 | self.patch_size = patch_size 19 | self.scale_factor = scale_factor 20 | self.augment = augment 21 | self.rand = random.Random(seed) 22 | 23 | def _load_triplet(self, path): 24 | path = self.dataset_base / "sequences" / path 25 | images = [Image.open(path / f"im{i + 1}.png") for i in range(3)] 26 | if not (images[0].size == images[1].size and images[0].size == images[2].size): 27 | raise ValueError("triplet has different dimensions") 28 | return images 29 | 30 | def _prepare_images(self, images: List[Image.Image]): 31 | w, h = images[0].size 32 | f = self.scale_factor 33 | s = self.patch_size * f 34 | dh, dw = self.rand.randrange(0, h - s, 2) * f, self.rand.randrange(0, w - s, 2) * f 35 | images = [i.crop((dw, dh, dw + s, dh + s)) for i in images] 36 | return images 37 | 38 | trans_groups = { 39 | 'none': [None], 40 | 'rotate': [None, Image.ROTATE_90, Image.ROTATE_180, Image.ROTATE_270], 41 | 'mirror': [None, Image.FLIP_LEFT_RIGHT], 42 | 'flip': [None, Image.FLIP_LEFT_RIGHT, Image.FLIP_TOP_BOTTOM, Image.ROTATE_180], 43 | 'all': [None] + [e.value for e in Image.Transpose], 44 | } 45 | 46 | trans_names = [e.name for e in Image.Transpose] 47 | 48 | def _augment_images(self, images: List[Image.Image], trans_mode='all'): 49 | trans_action = 'none' 50 | trans_op = self.rand.choice(self.trans_groups[trans_mode]) 51 | if trans_op is not None: 52 | images = [i.transpose(trans_op) for i in images] 53 | trans_action = self.trans_names[trans_op] 54 | return images, trans_action 55 | 56 | scale_filters = [Image.BILINEAR, Image.BICUBIC, Image.LANCZOS] 57 | 58 | def _scale_images(self, images: List[Image.Image]): 59 | f = self.scale_factor 60 | return [i.resize((i.width // f, i.height // f), self.rand.choice(self.scale_filters)) for i in images] 61 | 62 | def _degrade_images(self, images: List[Image.Image]): 63 | degrade_action = None 64 | decision = self.rand.randrange(4) 65 | if decision == 1: 66 | degrade_action = 'box' 67 | arr = [np.array(Image.blend(j, j.copy().filter(ImageFilter.BoxBlur(1)), 0.5)) for j in images] 68 | elif decision == 2: 69 | degrade_action = 'gaussian' 70 | radius = self.rand.random() * 2 71 | arr = [np.array(j.filter(ImageFilter.GaussianBlur(radius))) for j in images] 72 | elif decision == 3: 73 | degrade_action = 'halo' 74 | radius = 1 + self.rand.random() * 2 75 | modulation = 0.1 + radius * 0.3 76 | contour = [np.array(i.copy().filter(ImageFilter.CONTOUR).filter(ImageFilter.GaussianBlur(radius))) 77 | for i in images] 78 | arr = [cv2.addWeighted(np.array(i), 1, j, modulation, 0) for i, j in zip(images, contour)] 79 | else: 80 | arr = [np.array(i) for i in images] 81 | 82 | return arr, degrade_action 83 | 84 | def __len__(self): 85 | return len(self.triplets) 86 | 87 | def __getitem__(self, idx): 88 | triplet = self._load_triplet(self.triplets[idx]) 89 | triplet = self._prepare_images(triplet) # crop to requested size 90 | original, _ = self._augment_images(triplet) # flip and rotates 91 | lf1 = original[1] 92 | lf1 = np.array(lf1.resize((lf1.width // self.scale_factor, lf1.height // self.scale_factor), Image.LANCZOS)) 93 | degraded, _ = self._degrade_images(self._scale_images([original[0], original[2]])) 94 | degraded.insert(1, lf1) 95 | return (*original, *degraded) 96 | 97 | 98 | def ImageTripletDataset(index_file, patch_size, scale_factor, augment, normalizer): 99 | ds = dataset.GeneratorDataset( 100 | ImageTripletGenerator(index_file, patch_size, scale_factor, augment), 101 | column_names=[f'{s}{i}' for s in ('h', 'l') for i in range(3)], 102 | ) 103 | 104 | mean, std = normalizer.rgb_dist() 105 | 106 | for col in [f'{s}{i}' for s in ('h', 'l') for i in range(3)]: 107 | ds = ds.map([ 108 | transforms.TypeCast(ms.float32), 109 | vision.Rescale(1.0 / 255.0, 0), 110 | vision.Normalize(mean, std, is_hwc=True), 111 | vision.HWC2CHW() 112 | ], col) 113 | 114 | return ds 115 | -------------------------------------------------------------------------------- /mindspore/dataset/video.py: -------------------------------------------------------------------------------- 1 | import bisect 2 | import collections 3 | import functools 4 | import itertools 5 | import pathlib 6 | import random 7 | import time 8 | from typing import List, Tuple 9 | 10 | import av 11 | import av.logging 12 | import cv2 13 | import numpy as np 14 | import mindspore as ms 15 | from mindspore import dataset 16 | from mindspore.dataset import vision, transforms 17 | 18 | 19 | av.logging.set_level(av.logging.FATAL) 20 | 21 | perf_debug = False 22 | 23 | 24 | class Video: 25 | def __init__(self, file, kf): 26 | self.container = av.open(file) 27 | self.stream = self.container.streams.video[0] 28 | self.stream.thread_type = "AUTO" 29 | self.at = 0 30 | self.kf = kf 31 | 32 | def get_frames(self, pts, n=1): 33 | frames = [] 34 | if bisect.bisect_left(self.kf, pts) != bisect.bisect_left(self.kf, self.at) or pts <= self.at: 35 | self.container.seek(pts, stream=self.stream) 36 | found = False 37 | first = True 38 | if perf_debug: 39 | print(f'Seek {pts} done at {time.perf_counter()}') 40 | for frame in self.container.decode(video=0): 41 | if first: 42 | if perf_debug: 43 | print(f'Search {pts} from {frame.pts} at {time.perf_counter()}') 44 | first = False 45 | if not found and frame.pts != pts: 46 | continue 47 | found = True 48 | if perf_debug: 49 | print(f'Found {frame.pts} at {time.perf_counter()}') 50 | self.at = frame.pts 51 | yuv = frame.to_ndarray() 52 | h, w = frame.height, frame.width 53 | y, uv = yuv[:h, :].reshape(1, h, w), yuv[h:, :].reshape(2, h // 2, w // 2) 54 | frames.append((y, uv)) 55 | if len(frames) == n: 56 | return frames 57 | raise ValueError("unexpected end") 58 | 59 | def __del__(self): 60 | self.container.close() 61 | 62 | 63 | video_info = collections.namedtuple('video_info', [ 64 | 'org', 65 | 'deg', 66 | 'frames', 67 | 'pts_org', 68 | 'pts_deg', 69 | 'key_org', 70 | 'key_deg' 71 | ]) 72 | 73 | 74 | def flatten_once(it): 75 | return itertools.chain.from_iterable(it) 76 | 77 | 78 | class VideoFrameGenerator: 79 | def __init__(self, index_file, patch_size, scale_factor, augment, seed=0): 80 | self.dataset_base = pathlib.PurePath(index_file).parent 81 | index_lines = [i for i in open(index_file, 'r', encoding='utf-8').read().split('\n') 82 | if i if not i.startswith('#')] 83 | files = [tuple(i.split(',')) for i in index_lines] 84 | self.files = [] 85 | self.indexes = [] 86 | for org, deg, frames, pts_org, pts_deg, key_org, key_deg in files: 87 | info = video_info( 88 | org, 89 | deg, 90 | int(frames), 91 | tuple(int(i) for i in pts_org.split(' ')), 92 | tuple(int(i) for i in pts_deg.split(' ')), 93 | tuple(int(i) for i in key_org.split(' ')), 94 | tuple(int(i) for i in key_deg.split(' ')), 95 | ) 96 | self.files.append(info) 97 | self.indexes.append(info.frames) 98 | self.indexes = list(itertools.accumulate(self.indexes)) 99 | self.patch_size = (patch_size, patch_size) if isinstance(patch_size, int) else patch_size 100 | self.scale_factor = scale_factor 101 | self.augment = augment 102 | self.rand = random.Random(seed) 103 | 104 | @functools.lru_cache(2) 105 | def get_video(self, v_idx): 106 | info = self.files[v_idx] 107 | return Video(str(self.dataset_base / info.org), info.key_org), \ 108 | Video(str(self.dataset_base / info.deg), info.key_deg), info.pts_org, info.pts_deg 109 | 110 | def _augment_frame(self, org: List[Tuple[np.ndarray]], deg: List[Tuple[np.ndarray]]): 111 | if self.rand.random() > 0.5: 112 | org = [(y[..., ::-1].copy(), uv[..., ::-1].copy()) for y, uv in org] 113 | deg = [(y[..., ::-1].copy(), uv[..., ::-1].copy()) for y, uv in deg] 114 | return org, deg 115 | 116 | def _prepare_frame(self, org: List[Tuple[np.ndarray]], deg: List[Tuple[np.ndarray]]): 117 | _, h, w = deg[0][0].shape 118 | sw, sh = self.patch_size 119 | sh_uv, sw_uv = sh // 2, sw // 2 120 | dh, dw = self.rand.randrange(0, h - sh, 2), self.rand.randrange(0, w - sw, 2) 121 | dh_uv, dw_uv = dh // 2, dw // 2 122 | deg = [(y[:, dh:dh+sh, dw:dw+sw], uv[:, dh_uv:dh_uv+sh_uv, dw_uv:dw_uv+sw_uv]) for y, uv in deg] 123 | f = self.scale_factor 124 | size, size_uv = (sw, sh), (sw_uv, sh_uv) 125 | sh, sw, sh_uv, sw_uv = sh * f, sw * f, sh_uv * f, sw_uv * f 126 | dh, dw, dh_uv, dw_uv = dh * f, dw * f, dh_uv * f, dw_uv * f 127 | org = [(y[:, dh:dh+sh, dw:dw+sw], uv[:, dh_uv:dh_uv+sh_uv, dw_uv:dw_uv+sw_uv]) for y, uv in org] 128 | 129 | deg1_y = cv2.resize(org[1][0][0], size, interpolation=cv2.INTER_LANCZOS4) 130 | deg1_u = cv2.resize(org[1][1][0], size_uv, interpolation=cv2.INTER_LANCZOS4) 131 | deg1_v = cv2.resize(org[1][1][1], size_uv, interpolation=cv2.INTER_LANCZOS4) 132 | deg.insert(1, (deg1_y.reshape((1, *size[::-1])), np.stack((deg1_u, deg1_v)).reshape((2, *size_uv[::-1])))) 133 | return org, deg 134 | 135 | def __len__(self): 136 | return self.indexes[-1] 137 | 138 | def __getitem__(self, idx): 139 | start = time.perf_counter() 140 | v_idx = bisect.bisect_right(self.indexes, idx) 141 | f_idx = idx if v_idx == 0 else idx - self.indexes[v_idx - 1] 142 | org, deg, pts_org, pts_deg = self.get_video(v_idx) 143 | org_frames = org.get_frames(pts_org[f_idx], 3) 144 | deg_frames = deg.get_frames(pts_deg[f_idx], 3) 145 | deg_frames.pop(1) 146 | org_frames, deg_frames = self._prepare_frame(org_frames, deg_frames) 147 | if self.augment: 148 | org_frames, deg_frames = self._augment_frame(org_frames, deg_frames) 149 | ret = (*flatten_once(org_frames), *flatten_once(deg_frames)) 150 | if perf_debug: 151 | print(f'Prepared data {idx}, in {time.perf_counter() - start}') 152 | return ret 153 | 154 | 155 | def VideoFrameDataset(index_file, patch_size, scale_factor, augment, normalizer): 156 | ds = dataset.GeneratorDataset( 157 | VideoFrameGenerator(index_file, patch_size, scale_factor, augment), 158 | column_names=[f'{s}{i}_{p}' for s in ('h', 'l') for i in range(3) for p in ('y', 'uv')], 159 | shuffle=False, 160 | python_multiprocessing=True, 161 | num_parallel_workers=3, 162 | ) 163 | 164 | mean, std = normalizer.yuv_dist() 165 | 166 | for col in [f'{s}{i}' for s in ('h', 'l') for i in range(3)]: 167 | ds = ds.map([ 168 | transforms.TypeCast(ms.float32), 169 | vision.Rescale(1.0 / 255.0, 0), 170 | vision.Normalize(mean[:1], std[:1], is_hwc=False) 171 | ], col + '_y') 172 | ds = ds.map([ 173 | transforms.TypeCast(ms.float32), 174 | vision.Rescale(1.0 / 255.0, 0), 175 | vision.Normalize(mean[1:], std[1:], is_hwc=False) 176 | ], col + '_uv') 177 | 178 | return ds 179 | -------------------------------------------------------------------------------- /mindspore/environment-full.yml: -------------------------------------------------------------------------------- 1 | name: MindSpore 2 | channels: 3 | - conda-forge 4 | dependencies: 5 | - _openmp_mutex=4.5=2_gnu 6 | - alsa-lib=1.2.8=h4e544f5_0 7 | - aom=3.5.0=headf329_0 8 | - attr=2.5.1=h4e544f5_1 9 | - av=10.0.0=py37hc9521df_1 10 | - bzip2=1.0.8=hf897c2e_4 11 | - c-ares=1.18.1=hf897c2e_0 12 | - ca-certificates=2022.12.7=h4fd8a4c_0 13 | - cairo=1.16.0=hd19fb6e_1014 14 | - colorama=0.4.6=pyhd8ed1ab_0 15 | - dbus=1.13.6=h12b9eeb_3 16 | - expat=2.5.0=ha18d298_0 17 | - ffmpeg=5.1.2=gpl_h8bd3c30_106 18 | - fftw=3.3.10=nompi_ha1d0423_106 19 | - font-ttf-dejavu-sans-mono=2.37=hab24e00_0 20 | - font-ttf-inconsolata=3.000=h77eed37_0 21 | - font-ttf-source-code-pro=2.038=h77eed37_0 22 | - font-ttf-ubuntu=0.83=hab24e00_0 23 | - fontconfig=2.14.2=ha9a116f_0 24 | - fonts-conda-ecosystem=1=0 25 | - fonts-conda-forge=1=0 26 | - freeglut=3.2.2=h01db608_1 27 | - freetype=2.12.1=hbbbf32d_1 28 | - gettext=0.21.1=ha18d298_0 29 | - glib=2.74.1=h7866ba4_1 30 | - glib-tools=2.74.1=h7866ba4_1 31 | - gmp=6.2.1=h7fd3ca4_0 32 | - gnutls=3.7.8=h5e100cc_0 33 | - graphite2=1.3.13=h7fd3ca4_1001 34 | - gst-plugins-base=1.21.3=h8a62080_1 35 | - gstreamer=1.21.3=h1f26242_1 36 | - gstreamer-orc=0.4.33=h4e544f5_0 37 | - harfbuzz=5.3.0=h6f3452c_0 38 | - hdf5=1.12.2=nompi_h3900512_101 39 | - icu=70.1=ha18d298_0 40 | - jack=1.9.22=hf8b18a5_0 41 | - jasper=2.0.33=h31812aa_1 42 | - jpeg=9e=h2a766a3_3 43 | - keyutils=1.6.1=h4e544f5_0 44 | - krb5=1.20.1=h113d92e_0 45 | - lame=3.100=h4e544f5_1003 46 | - lcms2=2.14=h5246980_0 47 | - ld_impl_linux-aarch64=2.36.1=h02ad14f_2 48 | - lerc=4.0.0=h4de3ea5_0 49 | - libaec=1.0.6=hd600fc2_1 50 | - libblas=3.9.0=16_linuxaarch64_openblas 51 | - libcap=2.66=hbb70f59_0 52 | - libcblas=3.9.0=16_linuxaarch64_openblas 53 | - libclang=15.0.7=default_h3fabc39_1 54 | - libclang13=15.0.7=default_haf32b04_1 55 | - libcups=2.3.3=h4303303_3 56 | - libcurl=7.88.1=h6ad7c7a_0 57 | - libdb=6.2.32=h01db608_0 58 | - libdeflate=1.14=h4e544f5_0 59 | - libedit=3.1.20191231=he28a2e2_2 60 | - libev=4.33=h516909a_1 61 | - libevent=2.1.10=h4f30969_4 62 | - libffi=3.4.2=h3557bc0_5 63 | - libflac=1.4.2=h4de3ea5_0 64 | - libgcc-ng=12.1.0=h3242a24_16 65 | - libgcrypt=1.10.1=h4e544f5_0 66 | - libgfortran-ng=12.2.0=he9431aa_19 67 | - libgfortran5=12.2.0=hf695500_19 68 | - libglib=2.74.1=h01e6fbd_1 69 | - libglu=9.0.0=he1b5a44_1001 70 | - libgomp=12.1.0=h3242a24_16 71 | - libgpg-error=1.46=haae8ae4_0 72 | - libiconv=1.17=h9cdd2b7_0 73 | - libidn2=2.3.4=h4e544f5_0 74 | - liblapack=3.9.0=16_linuxaarch64_openblas 75 | - liblapacke=3.9.0=16_linuxaarch64_openblas 76 | - libllvm15=15.0.7=h87099f9_0 77 | - libnghttp2=1.51.0=h674c3cc_0 78 | - libnsl=2.0.0=hf897c2e_0 79 | - libogg=1.3.4=h3557bc0_1 80 | - libopenblas=0.3.21=pthreads_h6cb6f83_3 81 | - libopencv=4.6.0=py37h50d0c34_5 82 | - libopus=1.3.1=hf897c2e_1 83 | - libpng=1.6.39=hf9034f9_0 84 | - libpq=15.2=hd21d9e6_0 85 | - libprotobuf=3.21.12=h7fe2111_0 86 | - libsndfile=1.2.0=h693ebdd_0 87 | - libsqlite=3.40.0=hf9034f9_0 88 | - libssh2=1.10.0=he5a64b1_3 89 | - libstdcxx-ng=12.1.0=hd01590b_16 90 | - libsystemd0=252=hc6ee767_0 91 | - libtasn1=4.19.0=h4e544f5_0 92 | - libtiff=4.4.0=hfcd36d1_5 93 | - libtool=2.4.7=h4de3ea5_0 94 | - libudev1=253=hb4cce97_0 95 | - libunistring=0.9.10=hf897c2e_0 96 | - libuuid=2.32.1=hf897c2e_1000 97 | - libvorbis=1.3.7=h01db608_0 98 | - libvpx=1.11.0=h01db608_3 99 | - libwebp-base=1.2.4=h4e544f5_0 100 | - libxcb=1.13=h3557bc0_1004 101 | - libxkbcommon=1.5.0=h4f22d97_0 102 | - libxml2=2.10.3=h249b6dd_0 103 | - libzlib=1.2.13=h4e544f5_4 104 | - lz4-c=1.9.4=hd600fc2_0 105 | - mpg123=1.31.2=hd600fc2_0 106 | - mysql-common=8.0.32=hed3ad84_0 107 | - mysql-libs=8.0.32=h42d7160_0 108 | - nano=7.2=h9b990bf_0 109 | - ncurses=6.3=headf329_1 110 | - nettle=3.8.1=hcc5b78b_1 111 | - nspr=4.35=h4de3ea5_0 112 | - nss=3.88=hf608148_0 113 | - opencv=4.6.0=py37hd9ded2f_5 114 | - openh264=2.3.1=hd600fc2_2 115 | - openjpeg=2.5.0=h9b6de37_1 116 | - openssl=3.0.8=hb4cce97_0 117 | - p11-kit=0.24.1=h9f2702f_0 118 | - pcre2=10.40=he7b27c6_0 119 | - pillow=9.2.0=py37h5e1b7c7_2 120 | - pixman=0.40.0=hb9de7d4_0 121 | - pthread-stubs=0.4=hb9de7d4_1001 122 | - pulseaudio=16.1=h7a898ed_1 123 | - py-opencv=4.6.0=py37ha4e61c6_5 124 | - python=3.7.10=h47f6e27_104_cpython 125 | - python_abi=3.7=2_cp37m 126 | - qt-main=5.15.6=h9e5b47b_5 127 | - readline=8.1.2=h38e3740_0 128 | - setuptools=63.1.0=py37hd9ded2f_0 129 | - sqlite=3.39.0=hc74f5b8_0 130 | - svt-av1=1.4.1=hd600fc2_0 131 | - tk=8.6.12=hd8af866_0 132 | - wheel=0.37.1=pyhd8ed1ab_0 133 | - x264=1!164.3095=h4e544f5_2 134 | - x265=3.5=hdd96247_3 135 | - xcb-util=0.4.0=h4e544f5_0 136 | - xcb-util-image=0.4.0=h4e544f5_0 137 | - xcb-util-keysyms=0.4.0=h4e544f5_0 138 | - xcb-util-renderutil=0.3.9=h4e544f5_0 139 | - xcb-util-wm=0.4.1=h4e544f5_0 140 | - xorg-fixesproto=5.0=h3557bc0_1002 141 | - xorg-inputproto=2.3.2=h3557bc0_1002 142 | - xorg-kbproto=1.0.7=h3557bc0_1002 143 | - xorg-libice=1.0.10=h3557bc0_0 144 | - xorg-libsm=1.2.3=h965e137_1000 145 | - xorg-libx11=1.7.2=h3557bc0_0 146 | - xorg-libxau=1.0.9=h3557bc0_0 147 | - xorg-libxdmcp=1.1.3=h3557bc0_0 148 | - xorg-libxext=1.3.4=h2a766a3_2 149 | - xorg-libxfixes=5.0.3=h3557bc0_1004 150 | - xorg-libxi=1.7.10=h3557bc0_0 151 | - xorg-libxrender=0.9.10=h3557bc0_1003 152 | - xorg-renderproto=0.11.1=h3557bc0_1002 153 | - xorg-xextproto=7.3.0=h2a766a3_1003 154 | - xorg-xproto=7.0.31=h3557bc0_1007 155 | - xz=5.2.6=h9cdd2b7_0 156 | - zlib=1.2.13=h4e544f5_4 157 | - zstd=1.5.2=h44f6412_6 158 | - pip: 159 | - absl-py==0.13.0 160 | - addict==2.4.0 161 | - albumentations==0.4.5 162 | - asgiref==3.5.2 163 | - astor==0.8.1 164 | - asttokens==2.0.8 165 | - astunparse==1.6.3 166 | - attrs==19.3.0 167 | - backcall==0.2.0 168 | - boto3==1.12.22 169 | - botocore==1.15.49 170 | - certifi==2022.6.15 171 | - cffi==1.14.0 172 | - chardet==3.0.4 173 | - charset-normalizer==2.0.12 174 | - click==8.1.3 175 | - cloudpickle==1.3.0 176 | - cycler==0.11.0 177 | - cython==0.29.14 178 | - dask==2.18.1 179 | - decorator==4.4.1 180 | - django==3.2.15 181 | - docutils==0.15.2 182 | - easydict==1.9 183 | - entrypoints==0.4 184 | - esdk-obs-python==3.20.1 185 | - et-xmlfile==1.1.0 186 | - flask==2.0.1 187 | - fonttools==4.37.1 188 | - future==0.18.2.post20200723173923 189 | - gast==0.2.2 190 | - google-pasta==0.2.0 191 | - grpcio==1.48.1 192 | - grpcio-tools==1.26.0 193 | - gunicorn==20.1.0 194 | - h5py==2.10.0 195 | - huaweicloud-sdk-python-modelarts-dataset==0.1.5 196 | - idna==2.10 197 | - image==1.5.28 198 | - imageio==2.9.0 199 | - imgaug==0.2.6 200 | - importlib-metadata==4.12.0 201 | - ipykernel==5.3.4 202 | - ipython==7.34.0 203 | - ipython-genutils==0.2.0 204 | - itsdangerous==2.1.2 205 | - jdcal==1.4.1 206 | - jedi==0.18.1 207 | - jinja2==3.0.1 208 | - jmespath==0.10.0 209 | - joblib==1.1.0 210 | - jupyter-client==7.3.4 211 | - jupyter-core==4.11.1 212 | - keras==2.3.1 213 | - keras-applications==1.0.8 214 | - keras-preprocessing==1.1.2 215 | - kfac==0.2.0 216 | - kiwisolver==1.1.0 217 | - lazy-import==0.2.2 218 | - llvmlite==0.31.0 219 | - lxml==4.4.2 220 | - markdown==3.4.1 221 | - markupsafe==2.1.1 222 | - marshmallow==3.17.1 223 | - matplotlib==3.5.1 224 | - matplotlib-inline==0.1.3 225 | - mindarmour==1.7.0 226 | - mindinsight==1.7.0 227 | - mindspore-ascend==1.9.0 228 | - mmcv==0.2.14 229 | - modelarts-mindspore-model-server==1.0.4 230 | - moxing-framework==2.0.1.rc0.ffd1c0c8 231 | - mpmath==1.2.1 232 | - nest-asyncio==1.5.5 233 | - networkx==2.6.3 234 | - numba==0.47.0 235 | - numexpr==2.7.1 236 | - numpy==1.21.2 237 | - opencv-python==4.2.0.34 238 | - openpyxl==3.0.3 239 | - opt-einsum==3.3.0 240 | - packaging==21.3 241 | - pandas==1.1.3 242 | - parso==0.8.3 243 | - pathlib2==2.3.7.post1 244 | - pexpect==4.8.0 245 | - pickleshare==0.7.5 246 | - pip==22.1.2 247 | - prometheus-client==0.8.0 248 | - prompt-toolkit==3.0.30 249 | - protobuf==4.21.5 250 | - psutil==5.7.0 251 | - ptyprocess==0.7.0 252 | - pycocotools==2.0.0 253 | - pycparser==2.21 254 | - pygments==2.12.0 255 | - pyparsing==3.0.9 256 | - python-dateutil==2.8.2 257 | - pytz==2022.2.1 258 | - pywavelets==1.1.1 259 | - pyyaml==5.3.1 260 | - pyzmq==23.2.0 261 | - requests==2.27.1 262 | - s3transfer==0.3.7 263 | - scikit-image==0.17.2 264 | - scikit-learn==0.24.0 265 | - scipy==1.5.4 266 | - shapely==1.8.4 267 | - six==1.16.0 268 | - sqlparse==0.4.2 269 | - sympy==1.4 270 | - tables==3.6.1 271 | - tensorboard==1.15.0 272 | - tensorflow==1.15.0 273 | - tensorflow-estimator==1.15.1 274 | - tensorflow-probability==0.10.1 275 | - termcolor==1.1.0 276 | - terminaltables==3.1.0 277 | - threadpoolctl==3.1.0 278 | - tifffile==2021.11.2 279 | - toml==0.10.1 280 | - tornado==6.2 281 | - tqdm==4.46.1 282 | - traitlets==5.3.0 283 | - treelib==1.6.1 284 | - typing-extensions==4.3.0 285 | - umap-learn-modified==0.3.8 286 | - urllib3==1.26.12 287 | - wcwidth==0.2.5 288 | - werkzeug==2.2.2 289 | - wrapt==1.14.1 290 | - xlsxwriter==3.0.3 291 | - xmltodict==0.12.0 292 | - yapf==0.32.0 293 | - zipp==3.8.1 294 | prefix: /home/ma-user/anaconda3/envs/MindSpore 295 | -------------------------------------------------------------------------------- /mindspore/environment.yml: -------------------------------------------------------------------------------- 1 | channels: 2 | - conda-forge 3 | dependencies: 4 | - av 5 | - opencv 6 | - tqdm 7 | -------------------------------------------------------------------------------- /mindspore/inference.py: -------------------------------------------------------------------------------- 1 | import time 2 | from collections import namedtuple 3 | 4 | import numpy as np 5 | import mindspore as ms 6 | 7 | from model import CycMuNet 8 | from util.normalize import Normalizer 9 | from dataset.video import VideoFrameDataset 10 | 11 | dummyArg = namedtuple('dummyArg', ( 12 | 'nf', 'groups', 'upscale_factor', 'format', 'layers', 'cycle_count', 'batch_mode', 'all_frames', 13 | 'stop_at_conf')) 14 | 15 | size = 128 16 | args = dummyArg(nf=64, groups=8, upscale_factor=2, format='yuv420', layers=4, cycle_count=3, batch_mode='batch', 17 | all_frames=True, stop_at_conf=False) 18 | 19 | ds_path = r"/home/ma-user/work/cctv-scaled/" 20 | # ds_path = r"./test-files/" 21 | 22 | if __name__ == '__main__': 23 | ms.set_context(mode=ms.GRAPH_MODE, device_target="Ascend") 24 | # ms.set_context(mode=ms.PYNATIVE_MODE, device_target="Ascend") 25 | # ms.set_context(mode=ms.GRAPH_MODE, device_target="CPU") 26 | # ms.set_context(mode=ms.PYNATIVE_MODE, device_target="CPU") 27 | 28 | print('Init done') 29 | model = CycMuNet(args) 30 | ms.load_checkpoint("model-files/2x_yuv420_cycle3_layer4.ckpt", model) 31 | 32 | # model = model.to_float(ms.float16) 33 | print('Load done') 34 | 35 | nm = Normalizer() 36 | ds_test = VideoFrameDataset(ds_path + "index-test.txt", size, args.upscale_factor, True, nm) 37 | ds_test = ds_test.batch(1) 38 | 39 | start = time.time() 40 | for n, data in enumerate(ds_test.create_tuple_iterator()): 41 | start = time.time() 42 | # data = [t.astype(ms.float16) for t in data] 43 | inputs = [data[6:8], data[10:12]] 44 | model(*inputs) 45 | print(f"#{n:0>3} inference in {time.time() - start}s") 46 | -------------------------------------------------------------------------------- /mindspore/model/__init__.py: -------------------------------------------------------------------------------- 1 | from .model import CycMuNet 2 | from .model_sep import CycMuNet as CycMuNetSep 3 | from .train import TrainModel 4 | -------------------------------------------------------------------------------- /mindspore/model/deform_conv.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import mindspore as ms 4 | from mindspore import nn, ops 5 | from mindspore.common import initializer as init 6 | from mindspore._checkparam import twice 7 | 8 | 9 | def _Conv2d(in_channels, 10 | out_channels, 11 | kernel_size, 12 | stride=1, 13 | padding=0, 14 | dilation=1): 15 | return nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride, 16 | pad_mode='pad', padding=padding, dilation=dilation, has_bias=True) 17 | 18 | 19 | class DCN_sep(nn.Cell): 20 | def __init__(self, 21 | in_channels, 22 | in_channels_features, 23 | out_channels, 24 | kernel_size, 25 | stride=1, 26 | padding=0, 27 | dilation=1, 28 | groups=1, 29 | deformable_groups=1, 30 | bias=True, 31 | mask=True): 32 | super(DCN_sep, self).__init__() 33 | 34 | kernel_size_ = twice(kernel_size) 35 | 36 | self.dcn_weight = ms.Parameter( 37 | ms.Tensor( 38 | shape=(out_channels, in_channels // groups, kernel_size_[0], kernel_size_[1]), 39 | dtype=ms.float32, 40 | init=init.HeUniform(negative_slope=math.sqrt(5)) 41 | ) 42 | ) 43 | 44 | fan_in = in_channels // groups * kernel_size_[0] * kernel_size_[1] 45 | bound = 1 / math.sqrt(fan_in) 46 | 47 | self.dcn_bias = ms.Parameter( 48 | ms.Tensor( 49 | shape=(out_channels,), 50 | dtype=ms.float32, 51 | init=init.Uniform(bound) 52 | ) 53 | ) if bias else None 54 | 55 | self.dcn_kwargs = { 56 | 'kernel_size': kernel_size_, 57 | 'strides': (1, 1, *twice(stride)), 58 | 'padding': (padding,) * 4 if isinstance(padding, int) else padding, 59 | 'dilations': (1, 1, *twice(dilation)), 60 | 'groups': groups, 61 | 'deformable_groups': deformable_groups, 62 | 'modulated': mask 63 | } 64 | 65 | offset_channels = deformable_groups * kernel_size_[0] * kernel_size_[1] 66 | 67 | self.conv_offset = _Conv2d(in_channels_features, offset_channels * 2, kernel_size=kernel_size, 68 | stride=stride, padding=padding, dilation=dilation) 69 | if mask: 70 | self.conv_mask = _Conv2d(in_channels_features, offset_channels, kernel_size=kernel_size, 71 | stride=stride, padding=padding, dilation=dilation) 72 | else: 73 | raise NotImplementedError() 74 | self.relu = nn.ReLU() 75 | 76 | def construct(self, input, feature): 77 | offset = self.conv_offset(feature) 78 | mask = ops.sigmoid(self.conv_mask(feature)) 79 | offsets = ops.concat([offset, mask], axis=1) 80 | 81 | return ops.deformable_conv2d(input, self.dcn_weight, offsets, bias=self.dcn_bias, **self.dcn_kwargs) 82 | 83 | 84 | # Same as DCN_sep but compatible with Ascend. 85 | # Can be removed once deformable_groups can be values other than 1 on Ascend. 86 | class DCN_sep_compat(nn.Cell): 87 | def __init__(self, 88 | in_channels, 89 | in_channels_features, 90 | out_channels, 91 | kernel_size, 92 | stride=1, 93 | padding=0, 94 | dilation=1, 95 | groups=1, 96 | deformable_groups=1, 97 | bias=True, 98 | mask=True): 99 | super(DCN_sep_compat, self).__init__() 100 | if deformable_groups == 1: 101 | raise ValueError("Use DCN_sep") 102 | 103 | if groups != 1: 104 | raise NotImplementedError() 105 | 106 | self.separated = groups != 1 107 | 108 | kernel_size_ = twice(kernel_size) 109 | 110 | self.deformable_groups = deformable_groups 111 | self.dcn_weight = ms.Parameter( 112 | ms.Tensor( 113 | shape=(out_channels, in_channels // groups, kernel_size_[0], kernel_size_[1]), 114 | dtype=ms.float32, 115 | init=init.HeUniform(negative_slope=math.sqrt(5)) 116 | ) 117 | ) 118 | 119 | fan_in = in_channels // groups * kernel_size_[0] * kernel_size_[1] 120 | bound = 1 / math.sqrt(fan_in) 121 | 122 | self.dcn_bias = ms.Parameter( 123 | ms.Tensor( 124 | shape=(out_channels,), 125 | dtype=ms.float32, 126 | init=init.Uniform(bound) 127 | ) 128 | ) if bias else None 129 | 130 | self.dcn_kwargs = { 131 | 'kernel_size': kernel_size_, 132 | 'strides': (1, 1, *twice(stride)), 133 | 'padding': (padding,) * 4 if isinstance(padding, int) else padding, 134 | 'dilations': (1, 1, *twice(dilation)), 135 | 'groups': 1, 136 | 'deformable_groups': 1, 137 | 'modulated': mask is not None 138 | } 139 | 140 | offset_channels = deformable_groups * kernel_size_[0] * kernel_size_[1] 141 | 142 | self.conv_offset = _Conv2d(in_channels_features, offset_channels * 2, kernel_size=kernel_size, 143 | stride=stride, padding=padding, dilation=dilation) 144 | if mask: 145 | self.conv_mask = _Conv2d(in_channels_features, offset_channels, kernel_size=kernel_size, 146 | stride=stride, padding=padding, dilation=dilation) 147 | else: 148 | raise NotImplementedError() 149 | self.relu = nn.ReLU() 150 | 151 | def construct(self, input, feature): 152 | offset = self.conv_offset(feature) 153 | mask = ops.sigmoid(self.conv_mask(feature)) 154 | offset_y, offset_x = ops.split(offset, axis=1, output_num=2) 155 | 156 | inputs = ops.split(input, axis=1, output_num=self.deformable_groups) 157 | dcn_weights = ops.split(self.dcn_weight, axis=1, output_num=self.deformable_groups) 158 | offset_ys = ops.split(offset_y, axis=1, output_num=self.deformable_groups) 159 | offset_xs = ops.split(offset_y, axis=1, output_num=self.deformable_groups) 160 | masks = ops.split(mask, axis=1, output_num=self.deformable_groups) 161 | 162 | output = None 163 | for i in range(self.deformable_groups): 164 | offsets = ops.concat([offset_ys[i], offset_xs[i], masks[i]], axis=1) 165 | if output is None: 166 | output = ops.deformable_conv2d(inputs[i], dcn_weights[i], offsets, bias=self.dcn_bias, 167 | **self.dcn_kwargs) 168 | else: 169 | output += ops.deformable_conv2d(inputs[i], dcn_weights[i], offsets, bias=None, **self.dcn_kwargs) 170 | 171 | return output -------------------------------------------------------------------------------- /mindspore/model/train.py: -------------------------------------------------------------------------------- 1 | from mindspore import nn 2 | 3 | 4 | class TrainModel(nn.Cell): 5 | def __init__(self, network, loss): 6 | super().__init__() 7 | self.net = network 8 | self.loss = loss 9 | 10 | def loss_frame(self, expected, actual): 11 | return self.loss(expected[0], actual[0]) + self.loss(expected[1], actual[1]) 12 | 13 | def construct(self, *data): 14 | inputs = [data[6:8], data[10:12]] 15 | expected = [data[0:2], data[2:4], data[4:6], data[8:10]] 16 | actual = self.net(*inputs) 17 | loss = self.loss_frame(expected[0], actual[0]) * 0.5 + \ 18 | self.loss_frame(expected[1], actual[1]) + \ 19 | self.loss_frame(expected[2], actual[2]) * 0.5 + \ 20 | self.loss_frame(expected[3], actual[3]) 21 | return loss 22 | -------------------------------------------------------------------------------- /mindspore/train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | from collections import namedtuple 4 | import argparse 5 | 6 | import numpy as np 7 | import tqdm 8 | import mindspore as ms 9 | from mindspore import nn, ops 10 | 11 | from model import CycMuNet, TrainModel 12 | from util.rmse import RMSELoss 13 | from util.normalize import Normalizer 14 | from dataset.video import VideoFrameDataset 15 | 16 | print("Initialized.") 17 | 18 | dummyArg = namedtuple('dummyArg', ( 19 | 'nf', 'groups', 'upscale_factor', 'format', 'layers', 'cycle_count', 'batch_mode', 'all_frames', 'stop_at_conf')) 20 | 21 | size = 128 22 | args = dummyArg(nf=64, groups=8, upscale_factor=2, format='yuv420', layers=4, cycle_count=3, batch_mode='batch', 23 | all_frames=True, stop_at_conf=False) 24 | 25 | epochs = 1 26 | batch_size = 1 27 | learning_rate = 0.001 28 | save_prefix = 'monitor' 29 | 30 | ds_path = r"/home/ma-user/work/cctv-scaled/" 31 | pretrained = "/home/ma-user/work/cycmunet-ms/model-files/2x_yuv420_cycle3_layer4.ckpt" 32 | save_path = "/home/ma-user/work/cycmunet-ms/checkpoints/" 33 | 34 | # ds_path = r"./test-files/index-train.txt" 35 | # pretrained = "./model-files/2x_yuv420_cycle3_layer4.ckpt" 36 | # save_path = "./checkpoints/" 37 | 38 | # ds_path = r"D:\Python\cycmunet-ms\test-files/" 39 | # pretrained = "" 40 | 41 | # parser = argparse.ArgumentParser( 42 | # prog='CycMuNet+ MindSpore Training') 43 | # 44 | # parser.add_argument('--dataset') 45 | # parser.add_argument('--pretrained') 46 | # parser.add_argument('--save') 47 | # 48 | # cmd_args = parser.parse_args() 49 | # 50 | # ds_path = cmd_args.dataset 51 | # pretrained = cmd_args.pretrained 52 | # save_path = cmd_args.save 53 | 54 | 55 | save_prefix = f'{save_prefix}_{args.upscale_factor}x_l{args.layers}_c{args.cycle_count}' 56 | 57 | network = CycMuNet(args) 58 | if pretrained: 59 | ms.load_checkpoint(pretrained, network) 60 | 61 | 62 | ms.set_context(mode=ms.GRAPH_MODE, device_target="Ascend") 63 | 64 | inp = (ms.Tensor(np.zeros((batch_size, 1, size, size), dtype=np.float32)), 65 | ms.Tensor(np.zeros((batch_size, 2, size // 2, size // 2), dtype=np.float32))) 66 | network.compile(inp, inp) 67 | inp = None 68 | 69 | 70 | loss_fn = RMSELoss() 71 | 72 | nm = Normalizer() 73 | ds_train = VideoFrameDataset(ds_path + "index-train.txt", size, args.upscale_factor, True, nm) 74 | ds_test = VideoFrameDataset(ds_path + "index-test.txt", size, args.upscale_factor, True, nm) 75 | 76 | ds_train = ds_train.batch(batch_size) 77 | ds_test = ds_test.batch(batch_size) 78 | 79 | scheduler = nn.CosineDecayLR(min_lr=1e-7, max_lr=learning_rate, decay_steps=640000) 80 | optimizer = nn.AdaMax(network.trainable_params(), learning_rate=learning_rate) 81 | 82 | 83 | model = TrainModel(network, loss_fn) 84 | model = ms.Model(model, optimizer=optimizer, eval_network=model, boost_level="O1") 85 | 86 | 87 | def save_model(epoch): 88 | if epoch == -1: 89 | name = "snapshot" 90 | else: 91 | name = f"epoch_{epoch}" 92 | if not os.path.exists(save_path): 93 | os.makedirs(save_path) 94 | output_path = save_path + f"{save_prefix}_{name}.ckpt" 95 | ms.save_checkpoint(network, str(output_path)) 96 | print(f"Checkpoint saved to {output_path}") 97 | 98 | 99 | print("Start train.") 100 | 101 | profiler = ms.Profiler(output_path='./profiler_data') 102 | 103 | for t in range(1, epochs + 1): 104 | try: 105 | print(f"Epoch {t}\n-------------------------------") 106 | model.train(t, ds_train, dataset_sink_mode=True) 107 | save_model(t) 108 | except KeyboardInterrupt: 109 | save_model(-1) 110 | 111 | profiler.analyse() 112 | 113 | print("Done.") 114 | -------------------------------------------------------------------------------- /mindspore/util/converter.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tongyuantongyu/cycmunet/dfb04885a9c2b79e0f7ee1b24fdb8de648e9321c/mindspore/util/converter.py -------------------------------------------------------------------------------- /mindspore/util/normalize.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | 4 | class Normalizer: 5 | sqrt1_2 = 1 / math.sqrt(2) 6 | 7 | def __init__(self, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225), kr=0.2126, kb=0.0722, depth=8): 8 | self.mean = mean 9 | self.std = std 10 | self.krgb = (kr, 1 - kr - kb, kb) 11 | self.depth = depth 12 | self.uv_bias = (1 << (depth - 1)) / ((1 << depth) - 1) 13 | 14 | @staticmethod 15 | def _inv(mean, std): 16 | inv_std = tuple(1 / i for i in std) 17 | inv_mean = tuple(-j * i for i, j in zip(inv_std, mean)) 18 | return inv_mean, inv_std 19 | 20 | def rgb_dist(self): 21 | return self.mean, self.std 22 | 23 | def _yuv_dist(self): 24 | rm, gm, bm = self.mean 25 | rs, gs, bs = self.std 26 | kr, kg, kb = self.krgb 27 | 28 | ym = rm * kr + gm * kg + bm * kb 29 | ys = math.sqrt((rs * kr) ** 2 + (gs * kg) ** 2 + (bs * kb) ** 2) 30 | um = (bm - ym) / (1 - kb) / 2 + self.uv_bias 31 | us = math.sqrt(bs * bs + ys * ys) / (1 - kb) / 2 32 | vm = (rm - ym) / (1 - kr) / 2 + self.uv_bias 33 | vs = math.sqrt(rs * rs + ys * ys) / (1 - kr) / 2 34 | return [ym, um, vm], [ys, us, vs] 35 | 36 | def yuv_dist(self, mode='yuv420'): 37 | mean, std = self._yuv_dist() 38 | if mode == 'yuv422': 39 | std[1], std[2] = std[1] * self.sqrt1_2, std[2] * self.sqrt1_2 40 | elif mode == 'yuv420': 41 | std[1], std[2] = std[1] * 0.5, std[2] * 0.5 42 | return mean, std 43 | -------------------------------------------------------------------------------- /mindspore/util/rmse.py: -------------------------------------------------------------------------------- 1 | from mindspore import nn 2 | from mindspore.ops import functional as F 3 | 4 | 5 | class RMSELoss(nn.LossBase): 6 | def __init__(self, epsilon=0.001): 7 | super(RMSELoss, self).__init__() 8 | self.epsilon = epsilon * epsilon 9 | self.MSELoss = nn.MSELoss() 10 | 11 | def construct(self, logits, label): 12 | rmse_loss = F.sqrt(self.MSELoss(logits, label) + self.epsilon) 13 | return rmse_loss 14 | -------------------------------------------------------------------------------- /tensorrt-conda/recipe/LICENSE: -------------------------------------------------------------------------------- 1 | Copyright 2023 Yuan Tong 2 | 3 | Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: 4 | 5 | 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. 6 | 7 | 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. 8 | 9 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS “AS IS” AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -------------------------------------------------------------------------------- /tensorrt-conda/recipe/bld.bat: -------------------------------------------------------------------------------- 1 | mkdir build 2 | cd build 3 | 4 | cmake -G Ninja -DVAPOURSYNTH_PLUGIN=ON -DCMAKE_BUILD_TYPE=Release .. 5 | if %ERRORLEVEL% neq 0 exit 1 6 | ninja vs-cycmunet 7 | if %ERRORLEVEL% neq 0 exit 1 8 | 9 | mkdir %PREFIX%\vapoursynth64\plugins 10 | copy vs-cycmunet.dll %PREFIX%\vapoursynth64\plugins\vs-cycmunet.dll 11 | -------------------------------------------------------------------------------- /tensorrt-conda/recipe/build.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -ex 3 | 4 | mkdir build 5 | cd build 6 | cmake -G Ninja -DVAPOURSYNTH_PLUGIN=ON -DCMAKE_BUILD_TYPE=Release .. 7 | ninja vs-cycmunet 8 | 9 | mkdir -p $PREFIX/lib/vapoursynth 10 | cp libvs-cycmunet.so $PREFIX/lib/vapoursynth/libvs-cycmunet.so 11 | -------------------------------------------------------------------------------- /tensorrt-conda/recipe/meta.yaml: -------------------------------------------------------------------------------- 1 | {% set name = "VapourSynth-CycMuNet" %} 2 | {% set version = "0.0.2" %} 3 | 4 | package: 5 | name: {{ name|lower }} 6 | version: {{ version }} 7 | 8 | source: 9 | - path: ../../tensorrt 10 | 11 | build: 12 | number: 0 13 | 14 | requirements: 15 | build: 16 | - {{ compiler('c') }} 17 | - {{ compiler('cxx') }} 18 | - cmake 19 | - ninja 20 | - cuda-nvcc 21 | - cuda-cudart-dev 22 | - libcublas-dev 23 | - cuda-cudart-static # [linux] 24 | - sysroot_{{ target_platform }} >=2.17 # [linux] 25 | host: 26 | - cuda-cudart-dev 27 | - libnvinfer-dev 28 | - vapoursynth 29 | run: 30 | - cuda-cudart-dev # [win] 31 | - cuda-cudart # [linux] 32 | - libcublas-dev #[win] 33 | - libcublas # [linux] 34 | - libnvinfer 35 | - vapoursynth 36 | run_constrained: 37 | # Only GLIBC_2.17 or older symbols present 38 | - __glibc >=2.17 # [linux] 39 | 40 | about: 41 | home: https://github.com/tongyuantongyu/cycmunet 42 | license: BSD-2-Clause 43 | license_family: BSD 44 | license_file: LICENSE 45 | summary: VapourSynth plugin of CycMuNet+ 46 | description: | 47 | VapourSynth plugin of CycMuNet+, a Spatio-Temporal Super Resolution Neural Network. 48 | dev_url: https://github.com/tongyuantongyu/cycmunet 49 | doc_url: https://github.com/tongyuantongyu/cycmunet/blob/main/tensorrt/README.md 50 | 51 | extra: 52 | recipe-maintainers: 53 | - tongyuantongyu 54 | -------------------------------------------------------------------------------- /tensorrt-conda/recipe/run_test.py: -------------------------------------------------------------------------------- 1 | from vapoursynth import core 2 | 3 | print(core.cycmunet.CycMuNetVersion()) 4 | -------------------------------------------------------------------------------- /tensorrt/.clang-format: -------------------------------------------------------------------------------- 1 | # Generated from CLion C/C++ Code Style settings 2 | BasedOnStyle: LLVM 3 | AccessModifierOffset: -1 4 | AlignAfterOpenBracket: Align 5 | AlignConsecutiveAssignments: None 6 | AlignOperands: Align 7 | AllowAllArgumentsOnNextLine: false 8 | AllowAllConstructorInitializersOnNextLine: false 9 | AllowAllParametersOfDeclarationOnNextLine: false 10 | AllowShortBlocksOnASingleLine: Empty 11 | AllowShortCaseLabelsOnASingleLine: true 12 | AllowShortFunctionsOnASingleLine: Inline 13 | AllowShortIfStatementsOnASingleLine: Never 14 | AllowShortLambdasOnASingleLine: All 15 | AllowShortLoopsOnASingleLine: false 16 | AlwaysBreakAfterReturnType: None 17 | AlwaysBreakTemplateDeclarations: Yes 18 | BreakBeforeBraces: Custom 19 | BraceWrapping: 20 | AfterCaseLabel: false 21 | AfterClass: false 22 | AfterControlStatement: Never 23 | AfterEnum: false 24 | AfterFunction: false 25 | AfterNamespace: false 26 | AfterUnion: false 27 | BeforeCatch: false 28 | BeforeElse: true 29 | IndentBraces: false 30 | SplitEmptyFunction: false 31 | SplitEmptyRecord: true 32 | BreakBeforeBinaryOperators: None 33 | BreakBeforeTernaryOperators: true 34 | BreakConstructorInitializers: BeforeColon 35 | BreakInheritanceList: BeforeColon 36 | ColumnLimit: 120 37 | CompactNamespaces: false 38 | ContinuationIndentWidth: 4 39 | IndentCaseLabels: true 40 | IndentPPDirectives: None 41 | IndentWidth: 2 42 | KeepEmptyLinesAtTheStartOfBlocks: true 43 | MaxEmptyLinesToKeep: 1 44 | NamespaceIndentation: None 45 | ObjCSpaceAfterProperty: false 46 | ObjCSpaceBeforeProtocolList: false 47 | PointerAlignment: Right 48 | ReflowComments: false 49 | SpaceAfterCStyleCast: true 50 | SpaceAfterLogicalNot: false 51 | SpaceAfterTemplateKeyword: false 52 | SpaceBeforeAssignmentOperators: true 53 | SpaceBeforeCpp11BracedList: true 54 | SpaceBeforeCtorInitializerColon: true 55 | SpaceBeforeInheritanceColon: true 56 | SpaceBeforeParens: ControlStatements 57 | SpaceBeforeRangeBasedForLoopColon: false 58 | SpaceInEmptyParentheses: false 59 | SpacesBeforeTrailingComments: 0 60 | SpacesInAngles: false 61 | SpacesInCStyleCastParentheses: false 62 | SpacesInContainerLiterals: false 63 | SpacesInParentheses: false 64 | SpacesInSquareBrackets: false 65 | TabWidth: 4 66 | UseTab: Never 67 | -------------------------------------------------------------------------------- /tensorrt/.gitignore: -------------------------------------------------------------------------------- 1 | ### C++ template 2 | # Prerequisites 3 | *.d 4 | 5 | # Compiled Object files 6 | *.slo 7 | *.lo 8 | *.o 9 | *.obj 10 | 11 | # Precompiled Headers 12 | *.gch 13 | *.pch 14 | 15 | # Compiled Dynamic libraries 16 | *.so 17 | *.dylib 18 | *.dll 19 | 20 | # Fortran module files 21 | *.mod 22 | *.smod 23 | 24 | # Compiled Static libraries 25 | *.lai 26 | *.la 27 | *.a 28 | *.lib 29 | 30 | # Executables 31 | *.exe 32 | *.out 33 | *.app 34 | 35 | # Debug files 36 | *.dSYM/ 37 | *.su 38 | *.idb 39 | *.pdb 40 | 41 | ### CUDA 42 | *.i 43 | *.ii 44 | *.gpu 45 | *.ptx 46 | *.cubin 47 | *.fatbin 48 | 49 | ### JetBrains 50 | .idea 51 | 52 | ### CMake 53 | CMakeLists.txt.user 54 | CMakeCache.txt 55 | CMakeFiles 56 | CMakeScripts 57 | Testing 58 | Makefile 59 | cmake_install.cmake 60 | install_manifest.txt 61 | compile_commands.json 62 | CTestTestfile.cmake 63 | _deps 64 | CMakeSettings.json 65 | cmake-build-*/ 66 | 67 | ### Project 68 | _* 69 | *.onnx 70 | *.pb 71 | *.engine 72 | *.cache 73 | *.y4m 74 | *.ffindex -------------------------------------------------------------------------------- /tensorrt/CMakeLists.txt: -------------------------------------------------------------------------------- 1 | cmake_minimum_required(VERSION 3.18) 2 | list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake/Modules") 3 | 4 | set(CMAKE_CXX_STANDARD 20) 5 | set(CMAKE_CUDA_STANDARD 20) 6 | 7 | if(NOT DEFINED CMAKE_CUDA_ARCHITECTURES) 8 | if(CMAKE_SYSTEM_PROCESSOR STREQUAL aarch64) 9 | set(CMAKE_CUDA_ARCHITECTURES 53 62 72 87) 10 | else() 11 | set(CMAKE_CUDA_ARCHITECTURES 61 70 75 80 86 89 90) 12 | endif() 13 | endif() 14 | 15 | project(Cycmunet-TRT LANGUAGES CXX CUDA) 16 | 17 | enable_testing() 18 | 19 | option(CUDA_DEVICE_DEBUG "Enable device debug" OFF) 20 | if(CUDA_DEVICE_DEBUG) 21 | set(CMAKE_CUDA_FLAGS_DEBUG "${CMAKE_CUDA_FLAGS_DEBUG} -g -G") 22 | endif() 23 | 24 | if(MSVC) 25 | add_compile_options("$<$:-Xcompiler=/source-charset:utf-8 /execution-charset:us-ascii /wd4996>") 26 | add_compile_options($<$:/source-charset:utf-8$/execution-charset:us-ascii$/wd4996>) 27 | endif() 28 | 29 | find_package(CUDAToolkit 12.0 REQUIRED COMPONENTS cublas) 30 | find_package(TensorRT 8.6 REQUIRED COMPONENTS OnnxParser) 31 | 32 | option(BUILD_TESTS "Build unit test" OFF) 33 | if(BUILD_TESTS) 34 | find_package(GTest 1 REQUIRED) 35 | endif() 36 | 37 | option(VAPOURSYNTH_PLUGIN "Enable Vapoursynth plugin" OFF) 38 | if(VAPOURSYNTH_PLUGIN) 39 | find_package(VapourSynth 4 REQUIRED) 40 | endif() 41 | 42 | include_directories(${CMAKE_CURRENT_SOURCE_DIR}/include) 43 | 44 | set(COMMON_HEADERS include/helper.h include/md_view.h include/utils.h include/logging.h) 45 | if(CMAKE_BUILD_TYPE MATCHES DEBUG) 46 | set(COMMON_HEADERS ${COMMON_HEADERS} include/debug/reveal.h) 47 | include_directories(${CMAKE_CURRENT_SOURCE_DIR}/include/debug) 48 | endif() 49 | 50 | add_library(_dcn_layer_impl OBJECT 51 | ${COMMON_HEADERS} 52 | layers/include/internal/config.h 53 | 54 | layers/include/internal/dcn_layer_impl.h 55 | layers/impl/dcn_layer.cu) 56 | target_include_directories(_dcn_layer_impl PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/layers/include/internal) 57 | target_link_libraries(_dcn_layer_impl CUDA::cublas) 58 | 59 | if(BUILD_TESTS) 60 | add_executable(dcn_layer_test layers/test/dcn_layer_test.cpp) 61 | target_include_directories(dcn_layer_test PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/layers/include/internal) 62 | target_link_libraries(dcn_layer_test _dcn_layer_impl CUDA::cudart GTest::gtest_main) 63 | add_test(NAME dcn_layer_test COMMAND dcn_layer_test WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/testdata) 64 | endif() 65 | 66 | option(AUTO_REGISTER_PLUGIN "Automatically register plugin on load of shared library" OFF) 67 | if(AUTO_REGISTER_PLUGIN) 68 | add_library(trt_layers_plugin SHARED 69 | ${COMMON_HEADERS} 70 | layers/include/internal/config.h 71 | 72 | layers/include/internal/dcn_layer.h 73 | layers/include/internal/dcn_layer_impl.h 74 | layers/src/dcn_layer.cpp 75 | 76 | layers/src/layers.cpp) 77 | target_compile_definitions(trt_layers_plugin PRIVATE AUTO_REGISTER_PLUGIN) 78 | else() 79 | add_library(trt_layers_plugin STATIC 80 | ${COMMON_HEADERS} 81 | layers/include/internal/config.h 82 | 83 | layers/include/internal/dcn_layer.h 84 | layers/include/internal/dcn_layer_impl.h 85 | layers/src/dcn_layer.cpp 86 | 87 | layers/src/layers.cpp) 88 | endif() 89 | 90 | target_compile_definitions(trt_layers_plugin PRIVATE BUILDING_PLUGIN) 91 | target_include_directories(trt_layers_plugin PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/layers/include/internal) 92 | target_include_directories(trt_layers_plugin PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/layers/include) 93 | target_link_libraries(trt_layers_plugin PUBLIC _dcn_layer_impl TensorRT::NvInfer) 94 | set_target_properties(trt_layers_plugin PROPERTIES 95 | POSITION_INDEPENDENT_CODE ON) 96 | 97 | add_library(cycmunet STATIC include/config.h include/optimize.h include/inference.h src/optimize.cpp src/inference.cpp) 98 | target_link_libraries(cycmunet PUBLIC trt_layers_plugin TensorRT::OnnxParser TensorRT::NvInfer CUDA::cudart) 99 | set_target_properties(cycmunet PROPERTIES 100 | POSITION_INDEPENDENT_CODE ON) 101 | 102 | add_executable(optimizer app/optimizer.cpp) 103 | target_link_libraries(optimizer PRIVATE cycmunet) 104 | 105 | add_library(reformat_cuda OBJECT include/reformat.h src/reformat.cu) 106 | 107 | add_executable(cycmunet_y4m app/inference_y4m.cpp) 108 | target_link_libraries(cycmunet_y4m PRIVATE cycmunet reformat_cuda) 109 | 110 | if(VAPOURSYNTH_PLUGIN) 111 | add_library(vs-cycmunet SHARED src/vs-plugin.cpp) 112 | target_link_libraries(vs-cycmunet PRIVATE cycmunet reformat_cuda VapourSynth) 113 | set_target_properties(vs-cycmunet PROPERTIES 114 | C_VISIBILITY_PRESET hidden 115 | CXX_VISIBILITY_PRESET hidden 116 | CUDA_VISIBILITY_PRESET hidden) 117 | if(MSVC) 118 | target_link_libraries(vs-cycmunet PUBLIC delayimp) 119 | target_link_options(vs-cycmunet PUBLIC 120 | /DELAYLOAD:cudart64_12.dll 121 | /DELAYLOAD:cublas64_12.dll 122 | /DELAYLOAD:nvinfer.dll 123 | /DELAYLOAD:nvonnxparser.dll) 124 | endif() 125 | endif() 126 | -------------------------------------------------------------------------------- /tensorrt/README.md: -------------------------------------------------------------------------------- 1 | # TensorRT implementation of CycMuNet+ 2 | 3 | This is the TensorRT implementation of CycMuNet+ capable of running on NVIDIA GPU. 4 | 5 | ## Installation 6 | 7 | We provide precompiled binary of VapourSynth plugin for Windows and 8 | Linux x64 platforms on Anaconda, and recommend installing via conda: 9 | 10 | ```bash 11 | conda create -n cycmunet -c conda-forge -c nvidia -c tongyuantongyu vapoursynth-cycmunet 12 | ``` 13 | 14 | We recommend enable CUDA lazy loading by running the following command: 15 | 16 | ```bash 17 | conda env config vars set CUDA_MODULE_LOADING=LAZY -n cycmunet 18 | ``` 19 | 20 | ## Build 21 | 22 | Building requires CUDAToolkit and TensorRT installed. As usual, you can set 23 | `CUDAToolkit_ROOT` and `TensorRT_ROOT` if they are not installed in default location. 24 | 25 | To build VapourSynth plugin you need to provide VapourSynth headers as well. 26 | 27 | ```bash 28 | mkdir build && cd build 29 | cmake -G Ninja -DVAPOURSYNTH_PLUGIN=ON .. 30 | ninja 31 | ``` 32 | 33 | ## Usage 34 | 35 | ### `cycmunet_y4m` 36 | 37 | Standalone inference runner accepts y4m input and produces y4m output. 38 | 39 | Read input from stdin, and write output to stdout: 40 | 41 | ```bash 42 | cycmunet_y4m 43 | ``` 44 | 45 | Read input from stdin, and write output to `output.y4m`: 46 | 47 | ```bash 48 | cycmunet_y4m output.y4m 49 | ``` 50 | 51 | Read input from `input.y4m`, and write output to `output.y4m`: 52 | 53 | ```bash 54 | cycmunet_y4m input.y4m output.y4m 55 | ``` 56 | 57 | ### `vs-cycmunet` 58 | 59 | VapourSynth plugin. Supports RGB and YUV inputs. 60 | See [VapourSynth documentation](http://vapoursynth.com/doc) for the usage of 61 | VapourSynth. 62 | 63 | ```python 64 | def core.cycmunet.CycMuNet(clip: vs.VideoNode, 65 | scale: float, 66 | batch_size: int, 67 | batch_size_fusion: int, 68 | use_fp16: bool, 69 | extraction_layers: int, 70 | norm_mean: List[float, float, float], 71 | norm_std: List[float, float, float], 72 | raw_norm: bool, 73 | model: str, 74 | model_path: str, 75 | low_mem: bool): 76 | """ 77 | Run CycMuNet+ Spatio-Temporal Super Resolution on the input clip. 78 | 79 | :param clip: Input clip 80 | :param scale: Spatial scale ratio 81 | :param batch_size: Batch size for feature extract phase. Default: 1 82 | :param batch_size_fusion: Batch size for feature fusion phase. Default: batch_size 83 | :param use_fp16: Use FP16 (half precision) data format during inference. 84 | Half memory consumption, and runs faster. Requires at least Volta architecture. Default: False 85 | :param extraction_layers: Model extraction layers number. Default 4 86 | :param norm_mean: Mean value of each channel for normalization. Default: [0.485, 0.456, 0.406] 87 | :param norm_std: Standard derivation of each channel for normalization. Default: [0.229, 0.224, 0.225] 88 | :param raw_norm: Whether passed norm_mean and norm_std are values for input. If False, then mean and std are 89 | values of RGB in range [0, 1], and will be internally converted to values for input based on video properties. 90 | Default: False 91 | :param model: Model name used for inference. Default: "." 92 | :param model_path: Model storage location. Default is in "dev.tyty.aim.cycmunet" folder next to plugin file. 93 | :param low_mem: Enable tweaks to reduce memory consumption. Default: False 94 | :return: Output clip 95 | """ 96 | pass 97 | ``` 98 | 99 | See demo.vpy for a basic example of how to use the plugin in VapourSynth script. 100 | 101 | The following command can be used to get output video: 102 | ```bash 103 | vspipe demo.vpy -c y4m - | ffmpeg -i - -c:v libx264 output.mp4 104 | ``` 105 | 106 | The plugin reads `_SceneChangePrev` property on frame to handle scene change. 107 | You can use other plugins to add scene change information. -------------------------------------------------------------------------------- /tensorrt/app/optimizer.cpp: -------------------------------------------------------------------------------- 1 | // 2 | // Created by TYTY on 2023-01-13 013. 3 | // 4 | 5 | #include "layers.h" 6 | #include "optimize.h" 7 | #include "logging.h" 8 | 9 | static Logger gLogger(Logger::Severity::kINFO); 10 | 11 | int main() { 12 | UDOLayers::registerPlugins(); 13 | // OptimizationConfig config{256, 256, 1, 1, 64, 2, IOFormat::YUV420, 4, false}; 14 | OptimizationConfig config{1920, 1088, 1, 1, 64, 2, IOFormat::YUV420, 4, true}; 15 | // OptimizationConfig config{768, 544, {1, 3, 3}, {1, 3, 3}, 64, 2, IOFormat::YUV420, 4, true}; 16 | 17 | OptimizationContext ctx(config, gLogger, "models"); 18 | 19 | ctx.optimize("y4m_any"); 20 | } 21 | -------------------------------------------------------------------------------- /tensorrt/cmake/Modules/FindTensorRT.cmake: -------------------------------------------------------------------------------- 1 | find_path(TensorRT_INCLUDE_DIR 2 | NAMES NvInfer.h) 3 | 4 | find_library(TensorRT_LIBRARY 5 | NAMES nvinfer) 6 | 7 | if (TensorRT_LIBRARY) 8 | set(TensorRT_LIBRARIES 9 | ${TensorRT_LIBRARIES} 10 | ${TensorRT_LIBRARY}) 11 | endif (TensorRT_LIBRARY) 12 | 13 | function(_tensorrt_get_version) 14 | unset(TensorRT_VERSION_STRING PARENT_SCOPE) 15 | set(_hdr_file "${TensorRT_INCLUDE_DIR}/NvInferVersion.h") 16 | 17 | if (NOT EXISTS "${_hdr_file}") 18 | return() 19 | endif () 20 | 21 | file(STRINGS "${_hdr_file}" VERSION_STRINGS REGEX "#define NV_TENSORRT_.*") 22 | 23 | foreach(TYPE MAJOR MINOR PATCH BUILD) 24 | string(REGEX MATCH "NV_TENSORRT_${TYPE} [0-9]" TRT_TYPE_STRING ${VERSION_STRINGS}) 25 | string(REGEX MATCH "[0-9]" TensorRT_VERSION_${TYPE} ${TRT_TYPE_STRING}) 26 | endforeach(TYPE) 27 | 28 | set(TensorRT_VERSION_MAJOR ${TensorRT_VERSION_MAJOR} PARENT_SCOPE) 29 | set(TensorRT_VERSION_MINOR ${TensorRT_VERSION_MINOR} PARENT_SCOPE) 30 | set(TensorRT_VERSION_PATCH ${TensorRT_VERSION_PATCH} PARENT_SCOPE) 31 | set(TensorRT_VERSION_BUILD ${TensorRT_VERSION_BUILD} PARENT_SCOPE) 32 | 33 | set(TensorRT_VERSION_STRING "${TensorRT_VERSION_MAJOR}.${TensorRT_VERSION_MINOR}.${TensorRT_VERSION_PATCH}.${TensorRT_VERSION_BUILD}" PARENT_SCOPE) 34 | endfunction(_tensorrt_get_version) 35 | 36 | _tensorrt_get_version() 37 | 38 | if(TensorRT_FIND_COMPONENTS) 39 | list(REMOVE_ITEM TensorRT_FIND_COMPONENTS "nvinfer") 40 | 41 | if ("OnnxParser" IN_LIST TensorRT_FIND_COMPONENTS) 42 | find_path(TensorRT_OnnxParser_INCLUDE_DIR 43 | NAMES NvOnnxParser.h) 44 | 45 | find_library(TensorRT_OnnxParser_LIBRARY 46 | NAMES nvonnxparser) 47 | if (TensorRT_OnnxParser_LIBRARY AND TensorRT_LIBRARIES) 48 | set(TensorRT_LIBRARIES 49 | ${TensorRT_LIBRARIES} 50 | ${TensorRT_OnnxParser_LIBRARY}) 51 | set(TensorRT_OnnxParser_FOUND TRUE) 52 | endif () 53 | endif() 54 | 55 | if ("Plugin" IN_LIST TensorRT_FIND_COMPONENTS) 56 | find_path(TensorRT_Plugin_INCLUDE_DIR 57 | NAMES NvInferPlugin.h) 58 | 59 | find_library(TensorRT_Plugin_LIBRARY 60 | NAMES nvinfer_plugin) 61 | 62 | if (TensorRT_Plugin_LIBRARY AND TensorRT_LIBRARIES) 63 | set(TensorRT_LIBRARIES 64 | ${TensorRT_LIBRARIES} 65 | ${TensorRT_Plugin_LIBRARY}) 66 | set(TensorRT_Plugin_FOUND TRUE) 67 | endif () 68 | endif() 69 | endif() 70 | 71 | include(FindPackageHandleStandardArgs) 72 | find_package_handle_standard_args(TensorRT 73 | FOUND_VAR TensorRT_FOUND 74 | REQUIRED_VARS TensorRT_LIBRARY TensorRT_LIBRARIES TensorRT_INCLUDE_DIR 75 | VERSION_VAR TensorRT_VERSION_STRING 76 | HANDLE_COMPONENTS) 77 | 78 | add_library(TensorRT::NvInfer UNKNOWN IMPORTED) 79 | target_include_directories(TensorRT::NvInfer SYSTEM INTERFACE "${TensorRT_INCLUDE_DIR}") 80 | set_property(TARGET TensorRT::NvInfer PROPERTY IMPORTED_LOCATION "${TensorRT_LIBRARY}") 81 | 82 | if ("OnnxParser" IN_LIST TensorRT_FIND_COMPONENTS) 83 | add_library(TensorRT::OnnxParser UNKNOWN IMPORTED) 84 | target_include_directories(TensorRT::OnnxParser SYSTEM INTERFACE "${TensorRT_OnnxParser_INCLUDE_DIR}") 85 | target_link_libraries(TensorRT::OnnxParser INTERFACE TensorRT::NvInfer) 86 | set_property(TARGET TensorRT::OnnxParser PROPERTY IMPORTED_LOCATION "${TensorRT_OnnxParser_LIBRARY}") 87 | endif() 88 | 89 | if ("Plugin" IN_LIST TensorRT_FIND_COMPONENTS) 90 | add_library(TensorRT::Plugin UNKNOWN IMPORTED) 91 | target_include_directories(TensorRT::Plugin SYSTEM INTERFACE "${TensorRT_Plugin_INCLUDE_DIR}") 92 | target_link_libraries(TensorRT::Plugin INTERFACE TensorRT::NvInfer) 93 | set_property(TARGET TensorRT::Plugin PROPERTY IMPORTED_LOCATION "${TensorRT_Plugin_LIBRARY}") 94 | endif() 95 | 96 | mark_as_advanced(TensorRT_INCLUDE_DIR TensorRT_LIBRARY TensorRT_LIBRARIES) -------------------------------------------------------------------------------- /tensorrt/cmake/Modules/FindVapourSynth.cmake: -------------------------------------------------------------------------------- 1 | find_path(VapourSynth_INCLUDE_DIR 2 | NAMES VapourSynth4.h 3 | PATH_SUFFIXES vapoursynth) 4 | 5 | 6 | function(_vapoursynth_get_version) 7 | unset(VapourSynth_VERSION_STRING PARENT_SCOPE) 8 | set(_hdr_file "${VapourSynth_INCLUDE_DIR}/VapourSynth4.h") 9 | 10 | if(NOT EXISTS "${_hdr_file}") 11 | return() 12 | endif() 13 | 14 | file(STRINGS "${_hdr_file}" VERSION_STRINGS REGEX "#define VAPOURSYNTH_API_.*") 15 | 16 | foreach(TYPE MAJOR MINOR) 17 | string(REGEX MATCH "VAPOURSYNTH_API_${TYPE} [0-9]" TRT_TYPE_STRING ${VERSION_STRINGS}) 18 | string(REGEX MATCH "[0-9]" VapourSynth_VERSION_${TYPE} ${TRT_TYPE_STRING}) 19 | endforeach(TYPE) 20 | 21 | set(VapourSynth_VERSION_MAJOR ${VapourSynth_VERSION_MAJOR} PARENT_SCOPE) 22 | set(VapourSynth_VERSION_MINOR ${VapourSynth_VERSION_MINOR} PARENT_SCOPE) 23 | 24 | set(VapourSynth_VERSION_STRING "${VapourSynth_VERSION_MAJOR}.${VapourSynth_VERSION_MINOR}" PARENT_SCOPE) 25 | endfunction(_vapoursynth_get_version) 26 | 27 | _vapoursynth_get_version() 28 | 29 | include(FindPackageHandleStandardArgs) 30 | find_package_handle_standard_args(VapourSynth 31 | FOUND_VAR VapourSynth_FOUND 32 | REQUIRED_VARS VapourSynth_INCLUDE_DIR 33 | VERSION_VAR VapourSynth_VERSION_STRING 34 | HANDLE_COMPONENTS) 35 | 36 | add_library(VapourSynth INTERFACE IMPORTED) 37 | target_include_directories(VapourSynth SYSTEM INTERFACE "${VapourSynth_INCLUDE_DIR}") 38 | 39 | mark_as_advanced(VapourSynth_INCLUDE_DIR) 40 | -------------------------------------------------------------------------------- /tensorrt/demo.vpy: -------------------------------------------------------------------------------- 1 | import vapoursynth as vs 2 | 3 | core = vs.core 4 | 5 | # Load example.mp4 as input video 6 | # The ffms2 plugin can be installed by running 7 | # $ conda install -c conda-forge -c tongyuantongyu vapoursynth-ffms2 8 | clip = core.ffms2.Source(source="example.mp4") 9 | 10 | clip = core.cycmunet.CycMuNet(clip, 11 | scale_factor=2, 12 | batch_size=1, 13 | batch_size_fusion=1, 14 | use_fp16=True, 15 | low_mem=True, 16 | model_path=r"./models", 17 | model="2x_vimeo" 18 | ) 19 | 20 | clip.set_output() 21 | 22 | # - `scale_factor_h`: float. The scale factor of height of the network in use. 23 | # This is default tobe the same as `scale_factor`. (*) 24 | # - `batch_size_extract`: int. The batch size of Extract model. Default automatically 25 | # selected depending on other parameters. 26 | # - `batch_size_fusion`: int. The batch size of Fusion model. Default to 1. 27 | # - `input_count`: int. The number of input frames network needed. Default to 1. (*) 28 | # - `feature_count`: int. The "feature" (`C` channel) size. Default to 64. (*) 29 | # - `extraction_layers`: int. The number of layers Extract model outputs. Default to 1. (*) 30 | # - `interpolation`: bool. If the network is doing frame interpolation 31 | # (i.e. output clip) will have double framerate Default to False. 32 | # - `extra_frame`: bool. If network need 1 more input frame than consumed. 33 | # Default to False. (*) 34 | # - `double_frame`: bool. If network outputs 2 times of frames than input. 35 | # Default to False. (*) 36 | # - `use_fp16`: bool. Use half precision during inference. On supported GPUs 37 | # (starting from Volta), this is usually ~2x faster and consumes half 38 | # amount of GPU memory, but may cause numeric instability for some 39 | # networks. 40 | # - `low_mem`: bool. Tweak TensorRT configurations to reduce memory usage. 41 | # May cause performance degradation and effectiveness varies depending on 42 | # actual model. Default to False. 43 | # - `norm_mean` and `norm_std`: float[3]. Normalization mean and std applied 44 | # to inputs and output. The interpretation of these values depending on the 45 | # following option. Defaults to [0.485, 0.456, 0.406] and [0.229, 0.224, 0.225] (*) 46 | # - `raw_norm`; bool. If True, `norm_mean` and `norm_std` are applied directly 47 | # to the input and output frame pixel value of each channel. 48 | # If False, `norm_mean` and `norm_std` are values of RGB channels from 49 | # 0-1 range. The actual value used for normalization is inferred automatically 50 | # from colorspace information of input clip. Default to False. (*) 51 | # - `model`: str. The name of the model to be used. Default to ".". (*) 52 | # - `model_path`: str. The path that stores the model files. 53 | # Default to `dev.tyty.aim.nnvisr` folder under the folder of plugin 54 | # DLL. 55 | 56 | 57 | -------------------------------------------------------------------------------- /tensorrt/include/config.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | 5 | #include "utils.h" 6 | 7 | struct optimization_axis { 8 | optimization_axis(int32_t min, int32_t opt, int32_t max) : min(min), opt(opt), max(max) {} 9 | optimization_axis(int32_t same) : min(same), opt(same), max(same) {} 10 | optimization_axis() : min(0), opt(0), max(0) {} 11 | int32_t min, opt, max; 12 | }; 13 | 14 | enum IOFormat { 15 | RGB, 16 | YUV420, 17 | }; 18 | 19 | struct OptimizationConfig { 20 | optimization_axis input_width; 21 | optimization_axis input_height; 22 | 23 | optimization_axis batch_extract; 24 | optimization_axis batch_fusion; 25 | 26 | int32_t input_count; 27 | int32_t feature_count; 28 | int32_t extraction_layers; 29 | bool extra_frame; 30 | bool double_frame; 31 | bool interpolation; 32 | 33 | float scale_factor_w; 34 | float scale_factor_h; 35 | IOFormat format; 36 | 37 | bool use_fp16; 38 | bool low_mem; 39 | }; 40 | 41 | struct InferenceConfig { 42 | int32_t input_width; 43 | int32_t input_height; 44 | 45 | int32_t batch_extract; 46 | int32_t batch_fusion; 47 | 48 | int32_t input_count; 49 | int32_t feature_count; 50 | int32_t extraction_layers; 51 | bool extra_frame; 52 | bool double_frame; 53 | bool interpolation; 54 | 55 | float scale_factor_w; 56 | float scale_factor_h; 57 | IOFormat format; 58 | 59 | bool use_fp16; 60 | bool low_mem; 61 | }; 62 | -------------------------------------------------------------------------------- /tensorrt/include/debug/reveal.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | #include "cuda_runtime_api.h" 6 | #include "md_view.h" 7 | 8 | template 9 | void dump_value(md_view t, std::wstring name) { 10 | auto size = t.size(); 11 | auto host_pointer = std::make_unique(size); 12 | cudaMemcpy(host_pointer.get(), t.data, size * sizeof(T), cudaMemcpyDeviceToHost); 13 | auto h = host_pointer.get(); 14 | 15 | // Now feel free to examine host_pointer (or through h) 16 | std::ofstream p(name, std::ios::binary); 17 | p.write((const char*)(h), size * sizeof(T)); 18 | p.close(); 19 | } 20 | 21 | template 22 | void dump_value(const T* t, size_t size, std::wstring name) { 23 | auto host_pointer = std::make_unique(size); 24 | cudaMemcpy(host_pointer.get(), t, size * sizeof(T), cudaMemcpyDeviceToHost); 25 | auto h = host_pointer.get(); 26 | 27 | // Now feel free to examine host_pointer (or through h) 28 | std::ofstream p(name, std::ios::binary); 29 | p.write((const char*)(h), size * sizeof(T)); 30 | p.close(); 31 | } 32 | 33 | template 34 | void debug_me_show_memory(md_view t) { 35 | using _T = typename std::decay::type; 36 | auto size = t.size(); 37 | auto host_pointer = std::make_unique<_T[]>(size); 38 | cudaMemcpy(host_pointer.get(), t.data, size * sizeof(_T), cudaMemcpyDeviceToHost); 39 | auto h = host_pointer.get(); 40 | 41 | // Now feel free to examine host_pointer (or through h) 42 | return; 43 | } 44 | -------------------------------------------------------------------------------- /tensorrt/include/helper.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #if defined(_WIN32) 4 | #if defined(BUILDING_PLUGIN) 5 | #define PLUGIN_EXPORT __declspec(dllexport) 6 | #else 7 | #define PLUGIN_EXPORT __declspec(dllimport) 8 | #endif 9 | #elif defined(__GNUC__) && __GNUC__ >= 4 10 | #if defined(BUILDING_PLUGIN) 11 | #define PLUGIN_EXPORT __attribute__((visibility("default"))) 12 | #else 13 | #define PLUGIN_EXPORT 14 | #endif 15 | #else 16 | #define PLUGIN_EXPORT 17 | #endif 18 | 19 | #if defined(_MSC_VER) 20 | #define PLUGIN_UNREACHABLE __assume(false) 21 | #elif defined(__GNUC__) 22 | #define PLUGIN_UNREACHABLE __builtin_unreachable() 23 | #else 24 | #define PLUGIN_UNREACHABLE void(0) 25 | #endif 26 | 27 | 28 | #if defined(__CUDACC__) 29 | #define util_attrs __host__ __device__ inline 30 | #else 31 | #define util_attrs inline 32 | #endif 33 | -------------------------------------------------------------------------------- /tensorrt/include/inference.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | #include 6 | #include 7 | #include 8 | 9 | #include "NvInferRuntime.h" 10 | 11 | #include "config.h" 12 | #include "md_view.h" 13 | #include "utils.h" 14 | 15 | template 16 | struct ModelStuff { 17 | T feature_extract; 18 | T feature_fusion; 19 | }; 20 | 21 | class InferenceSession; 22 | 23 | class InferenceContext { 24 | nvinfer1::ILogger &logger; 25 | nvinfer1::IRuntime *runtime; 26 | std::filesystem::path path_prefix; 27 | ModelStuff engine; 28 | 29 | friend class InferenceSession; 30 | 31 | public: 32 | InferenceConfig config; 33 | InferenceContext(InferenceConfig config, nvinfer1::ILogger &logger, std::filesystem::path path_prefix); 34 | bool has_file(); 35 | bool load_engine(); 36 | 37 | bool good() { return runtime != nullptr && engine.feature_extract != nullptr && engine.feature_fusion != nullptr; } 38 | }; 39 | 40 | class InferenceSession { 41 | InferenceContext ctx; 42 | 43 | ModelStuff context; 44 | std::vector cudaBuffers; 45 | void *executionMemory; 46 | ModelStuff last_batch, last_offset_in, last_offset_out; 47 | bool good_; 48 | 49 | void trace(const std::string &info) { 50 | // no fold 51 | // ctx.logger.log(nvinfer1::ILogger::Severity::kINFO, ("Infer Trace: " + info).c_str()); 52 | } 53 | 54 | public: 55 | cudaStream_t stream; 56 | 57 | md_view input, input_uv; 58 | md_view output, output_uv; 59 | std::vector> features; 60 | 61 | explicit InferenceSession(InferenceContext &ctx); 62 | ~InferenceSession(); 63 | 64 | bool good() const { return good_; } 65 | 66 | void extractBatch(int32_t offset_in, int32_t offset_out, int32_t batch); 67 | void fusionBatch(int32_t batch); 68 | void fusionGroupedOffset(int32_t group_idx); 69 | void fusionCustomOffset(const std::vector &indexes); 70 | 71 | void duplicateExtractOutput(int32_t from, int32_t to); 72 | 73 | int32_t internalFeatureIndex(int32_t idx); 74 | shape_t<3> outputIndex(offset_t idx); 75 | 76 | bool extract(); 77 | bool fusion(); 78 | }; 79 | -------------------------------------------------------------------------------- /tensorrt/include/optimize.h: -------------------------------------------------------------------------------- 1 | // 2 | // Created by TYTY on 2023-01-13 013. 3 | // 4 | 5 | #pragma once 6 | 7 | #include "NvInfer.h" 8 | 9 | #include 10 | #include 11 | #include 12 | #include 13 | #include 14 | 15 | #include "config.h" 16 | 17 | class OptimizationContext { 18 | OptimizationConfig config; 19 | nvinfer1::ILogger &logger; 20 | std::filesystem::path path_prefix; 21 | 22 | nvinfer1::IBuilder *builder; 23 | nvinfer1::ITimingCache *cache; 24 | 25 | cudaDeviceProp prop; 26 | size_t total_memory; 27 | 28 | [[nodiscard]] nvinfer1::IBuilderConfig *prepareConfig() const; 29 | [[nodiscard]] nvinfer1::INetworkDefinition *createNetwork() const; 30 | int buildFeatureExtract(std::vector input, const std::filesystem::path& output); 31 | int buildFeatureFusion(std::vector input, const std::filesystem::path& output); 32 | 33 | public: 34 | OptimizationContext(OptimizationConfig config, nvinfer1::ILogger &logger, std::filesystem::path path_prefix); 35 | int optimize(const std::filesystem::path &folder); 36 | ~OptimizationContext(); 37 | }; -------------------------------------------------------------------------------- /tensorrt/include/reformat.h: -------------------------------------------------------------------------------- 1 | // 2 | // Created by TYTY on 2023-01-04 004. 3 | // 4 | 5 | #ifndef CYCMUNET_TRT_INCLUDE_REFORMAT_H_ 6 | #define CYCMUNET_TRT_INCLUDE_REFORMAT_H_ 7 | 8 | #include "md_view.h" 9 | 10 | template 11 | void import_pixel(md_view dst, md_view src, float a, float b, cudaStream_t stream); 12 | 13 | template 14 | void export_pixel(md_view dst, md_view src, float a, float b, float min, float max, 15 | cudaStream_t stream); 16 | 17 | #endif//CYCMUNET_TRT_INCLUDE_REFORMAT_H_ 18 | -------------------------------------------------------------------------------- /tensorrt/include/utils.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include "helper.h" 5 | 6 | template 7 | struct hw { 8 | T h; 9 | T w; 10 | 11 | template 12 | constexpr util_attrs operator hw() const noexcept { 13 | return {static_cast(h), static_cast(w)}; 14 | } 15 | 16 | constexpr util_attrs hw operator+(const hw& o) const noexcept { 17 | return {h + o.h, w + o.w}; 18 | } 19 | 20 | constexpr util_attrs hw operator+(const T& o) const noexcept { 21 | return {h + o, w + o}; 22 | } 23 | 24 | constexpr util_attrs hw& operator+=(const hw& o) noexcept { 25 | *this = {h + o.h, w + o.w}; 26 | return *this; 27 | } 28 | 29 | constexpr util_attrs hw& operator+=(const T& o) noexcept { 30 | *this = {h + o, w + o}; 31 | return *this; 32 | } 33 | 34 | constexpr util_attrs hw operator-(const hw& o) const noexcept { 35 | return {h - o.h, w - o.w}; 36 | } 37 | 38 | constexpr util_attrs hw operator-(const T& o) const noexcept { 39 | return {h - o, w - o}; 40 | } 41 | 42 | constexpr util_attrs hw operator-() const noexcept { 43 | return {-h, -w}; 44 | } 45 | }; 46 | -------------------------------------------------------------------------------- /tensorrt/layers/impl/dcn_layer.cu: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | #include "dcn_layer_impl.h" 4 | 5 | #include 6 | #include 7 | 8 | #include 9 | //#include "reveal.h" 10 | 11 | #ifdef __LP64__ 12 | template<> 13 | template<> 14 | util_attrs hw::operator hw() const noexcept { 15 | return {static_cast(static_cast(h)), 16 | static_cast(static_cast(w))}; 17 | } 18 | #endif 19 | 20 | constexpr std::size_t threadCount = 1024; 21 | constexpr std::size_t threadCountIm2Col = 512; 22 | 23 | struct im2col_parameters { 24 | hw<> stride; 25 | hw<> padding; 26 | hw<> dilation; 27 | int32_t channel_per_deformable_group; 28 | }; 29 | 30 | template 31 | struct bias_activation_parameters { 32 | // see NvInfer.h ActivationType 33 | int32_t activation_type; 34 | F alpha, beta; 35 | }; 36 | 37 | half __device__ floor(const half f) { 38 | return hfloor(f); 39 | } 40 | 41 | half __device__ ceil(const half f) { 42 | return hceil(f); 43 | } 44 | 45 | // Gather data from coordination. Use bilinear interpolation. 46 | template 47 | static inline F __device__ DCNGather(const md_view &input, 48 | offset_t n, 49 | offset_t c, 50 | hw pos, 51 | hw offset) { 52 | const F z {}; 53 | F result {}; 54 | 55 | const auto [h, w] = input.shape.template slice<2, 2>(); 56 | const hw ol_float {floor(offset.h), floor(offset.w)}; 57 | const hw oh_float {ceil(offset.h), ceil(offset.w)}; 58 | 59 | hw pl = ol_float, ph = oh_float; 60 | pl += pos; 61 | ph += pos; 62 | 63 | if (ph.h < 0 || ph.w < 0 || pl.h >= h || pl.w >= w) { 64 | return result; 65 | } 66 | 67 | // w(eight) of data at l(ow)/h(igh) pos 68 | const hw wh = offset - ol_float; 69 | const hw wl = -wh + 1; 70 | 71 | // should we read data at l(ow)/h(igh) h(eight)/w(idth) pos 72 | const bool lh = wl.h != z && pl.h >= 0; 73 | const bool lw = wl.w != z && pl.w >= 0; 74 | const bool hh = wh.h != z && ph.h < h; 75 | const bool hw = wh.w != z && ph.w < w; 76 | if (lh && lw) { 77 | result += input.at(n, c, pl.h, pl.w) * wl.h * wl.w; 78 | } 79 | if (lh && hw) { 80 | result += input.at(n, c, pl.h, ph.w) * wl.h * wh.w; 81 | } 82 | if (hh && lw) { 83 | result += input.at(n, c, ph.h, pl.w) * wh.h * wl.w; 84 | } 85 | if (hh && hw) { 86 | result += input.at(n, c, ph.h, ph.w) * wh.h * wh.w; 87 | } 88 | 89 | return result; 90 | } 91 | 92 | // Gather data from input into matrix form. 93 | template 94 | static void __global__ DCNIm2colKernel(md_view input, md_view offset, md_view mask, 95 | md_view col, im2col_parameters p, offset_t count) { 96 | offset_t idx = threadIdx.x + blockDim.x * blockIdx.x; 97 | if (idx >= count) { return; } 98 | 99 | const auto [n, c, h, w] = col.shape.template gather<0, 1, 4, 5>().indexes(idx); 100 | // kernel h, w 101 | const auto [kh, kw] = col.shape.template slice<2, 2>(); 102 | // index of deformable group 103 | const auto g = c / p.channel_per_deformable_group; 104 | // input h, w base offset 105 | const auto hin = h * p.stride.h - p.padding.h; 106 | const auto win = w * p.stride.w - p.padding.w; 107 | 108 | #pragma unroll K 109 | for (uint32_t i = 0; i < uint32_t(kh); ++i) { 110 | #pragma unroll K 111 | for (uint32_t j = 0; j < uint32_t(kw); ++j) { 112 | F data = DCNGather(input, n, c, {hin + i * p.dilation.h, win + j * p.dilation.w}, 113 | {offset.at(n, g, i, j, 0, h, w), offset.at(n, g, i, j, 1, h, w)}); 114 | col.at(n, c, i, j, h, w) = data * mask.at(n, g, i, j, h, w); 115 | } 116 | } 117 | } 118 | 119 | // Broadcast bias to output result. 120 | template 121 | static void __global__ BiasBroadcastKernel(md_view output, 122 | md_view bias, 123 | offset_t count) { 124 | offset_t idx = threadIdx.x + blockDim.x * blockIdx.x; 125 | if (idx >= count) { 126 | return; 127 | } 128 | 129 | const auto [n, c, h, w] = output.shape.indexes(idx); 130 | output.at(n, c, h, w) = bias.at(c); 131 | } 132 | 133 | // Add bias to output result. 134 | template 135 | static void __global__ BiasActivationKernel(md_view output, 136 | md_view bias, 137 | bias_activation_parameters p, 138 | offset_t count) { 139 | offset_t idx = threadIdx.x + blockDim.x * blockIdx.x; 140 | if (idx >= count) { 141 | return; 142 | } 143 | 144 | const auto [n, c, h, w] = output.shape.indexes(idx); 145 | F v = output.at(n, c, h, w) + bias.at(c); 146 | 147 | // support other when there's need 148 | assert(p.activation_type == 3); 149 | 150 | // leaky_relu 151 | if (v < F {}) { 152 | v *= p.alpha; 153 | } 154 | 155 | output.at(n, c, h, w) = v; 156 | } 157 | 158 | template 159 | struct cuda_type_trait {}; 160 | 161 | template<> 162 | struct cuda_type_trait { 163 | constexpr static cudaDataType_t data_type = CUDA_R_32F; 164 | constexpr static cublasComputeType_t compute_type = CUBLAS_COMPUTE_32F; 165 | }; 166 | 167 | template<> 168 | struct cuda_type_trait { 169 | constexpr static cudaDataType_t data_type = CUDA_R_16F; 170 | constexpr static cublasComputeType_t compute_type = CUBLAS_COMPUTE_32F_FAST_TF32; 171 | }; 172 | 173 | const static float fOne = 1; 174 | const static float fZero = 0; 175 | 176 | template 177 | void compute(DCNLayerInput inputs, DCNLayerOutput outputs, DCNLayerConfig config, DCNLayerExtra extra, 178 | cudaStream_t stream) { 179 | assert(cudaGetLastError() == cudaSuccess); 180 | std::size_t count = inputs.im2col_buffer.shape.template gather<0, 1, 4, 5>().count(); 181 | auto blocks = (count + threadCountIm2Col - 1) / threadCountIm2Col; 182 | 183 | im2col_parameters im2col_p{ 184 | config.stride, config.padding, config.dilation, 185 | int32_t((inputs.input.shape[1] + config.deformable_groups - 1) / config.deformable_groups)}; 186 | 187 | if (inputs.weight.shape[2] == 3 && inputs.weight.shape[3] == 3) { 188 | DCNIm2colKernel<<>>(inputs.input, inputs.offset, inputs.mask, 189 | inputs.im2col_buffer, im2col_p, count); 190 | } 191 | else if (inputs.weight.shape[2] == 1 && inputs.weight.shape[3] == 1) { 192 | DCNIm2colKernel<<>>(inputs.input, inputs.offset, inputs.mask, 193 | inputs.im2col_buffer, im2col_p, count); 194 | } 195 | else { 196 | DCNIm2colKernel<<>>(inputs.input, inputs.offset, inputs.mask, 197 | inputs.im2col_buffer, im2col_p, count); 198 | } 199 | 200 | assert(cudaGetLastError() == cudaSuccess); 201 | 202 | const offset_t m = outputs.output.shape.template slice<2, 2>().count(); 203 | const offset_t n = outputs.output.shape[1]; 204 | const offset_t k = inputs.im2col_buffer.shape.template slice<1, 3>().count(); 205 | 206 | const float *alpha = &fOne; 207 | const float *beta = &fZero; 208 | 209 | count = outputs.output.size(); 210 | blocks = (count + threadCount - 1) / threadCount; 211 | 212 | // If activation not needed, broadcast bias and let gemm do the add for us. 213 | if (config.activation_type == -1) { 214 | BiasBroadcastKernel<<>>(outputs.output, inputs.bias, count); 215 | beta = &fOne; 216 | assert(cudaGetLastError() == cudaSuccess); 217 | } 218 | 219 | const auto cublasResult = cublasGemmStridedBatchedEx_64( 220 | static_cast(extra.cublasHandle), CUBLAS_OP_N, CUBLAS_OP_N, m, n, k, alpha, 221 | inputs.im2col_buffer.data, cuda_type_trait::data_type, m, 222 | inputs.im2col_buffer.shape.template slice<1, 5>().count(), inputs.weight.data, cuda_type_trait::data_type, k, 223 | 0, beta, outputs.output.data, cuda_type_trait::data_type, m, 224 | outputs.output.shape.template slice<1, 3>().count(), outputs.output.shape[0], cuda_type_trait::compute_type, 225 | CUBLAS_GEMM_DEFAULT); 226 | 227 | assert(cublasResult == CUBLAS_STATUS_SUCCESS); 228 | 229 | // Fuse bias and activation, if there are. 230 | if (config.activation_type != -1) { 231 | bias_activation_parameters ba_p {config.activation_type, F {config.alpha}, F {config.beta}}; 232 | BiasActivationKernel<<>>(outputs.output, inputs.bias, ba_p, count); 233 | 234 | assert(cudaGetLastError() == cudaSuccess); 235 | } 236 | } 237 | 238 | // Explicit template instantiation. Keep these after template definition. 239 | template void compute(DCNLayerInput inputs, 240 | DCNLayerOutput outputs, 241 | DCNLayerConfig config, 242 | DCNLayerExtra extra, 243 | cudaStream_t stream); 244 | template void compute(DCNLayerInput inputs, 245 | DCNLayerOutput outputs, 246 | DCNLayerConfig config, 247 | DCNLayerExtra extra, 248 | cudaStream_t stream); -------------------------------------------------------------------------------- /tensorrt/layers/include/internal/config.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "md_view.h" 4 | #include "utils.h" 5 | 6 | struct DCNLayerInternal { 7 | int32_t data_type; 8 | 9 | // N, Cin, Hin, Win 10 | shape_t<4> input; 11 | // N, deformable_groups, Hker, Wker, 2, Hout, Wout 12 | shape_t<7> offset; 13 | // N, deformable_groups, Hker, Wker, Hout, Wout 14 | shape_t<6> mask; 15 | // Cout, Cin, Hker, Wker 16 | shape_t<4> weight; 17 | // Cout 18 | shape_t<1> bias; 19 | 20 | // N, Cout, Hout, Wout 21 | shape_t<4> output; 22 | 23 | // N, Cin, Hker, Wker, Hout, Wout 24 | shape_t<6> im2col_buffer; 25 | }; 26 | 27 | struct DCNLayerConfig { 28 | hw<> stride; 29 | hw<> padding; 30 | hw<> dilation; 31 | int32_t deformable_groups; 32 | int32_t activation_type; 33 | float alpha, beta; 34 | }; 35 | 36 | template 37 | struct DCNLayerInput { 38 | md_view input; 39 | // offset_h, offset_w 40 | md_view offset; 41 | md_view mask; 42 | md_view weight; 43 | md_view bias; 44 | 45 | md_view im2col_buffer; 46 | }; 47 | 48 | template 49 | struct DCNLayerOutput { 50 | md_view output; 51 | }; 52 | 53 | struct DCNLayerExtra { 54 | void* cublasHandle; 55 | int blasMode; 56 | }; 57 | -------------------------------------------------------------------------------- /tensorrt/layers/include/internal/dcn_layer.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | #include "NvInfer.h" 6 | #include "config.h" 7 | 8 | #define PLUGIN_NAME "DeformConv2d" 9 | 10 | namespace nvinfer1 { 11 | class DCNLayerPlugin : public IPluginV2DynamicExt { 12 | public: 13 | int getNbOutputs() const noexcept override { 14 | return 1; 15 | } 16 | 17 | DimsExprs getOutputDimensions(int outputIndex, const DimsExprs *inputs, 18 | int nbInputs, IExprBuilder &exprBuilder) noexcept override; 19 | 20 | bool supportsFormatCombination(int pos, const nvinfer1::PluginTensorDesc *inOut, int nbInputs, int nbOutputs) noexcept override; 21 | 22 | // boilerplate 23 | explicit DCNLayerPlugin(const DCNLayerConfig *config) noexcept; 24 | DCNLayerPlugin(const DCNLayerPlugin &plugin) noexcept; 25 | DCNLayerPlugin(const void *data, size_t length) noexcept; 26 | ~DCNLayerPlugin() noexcept override = default; 27 | 28 | int initialize() noexcept override { 29 | return 0; 30 | }; 31 | 32 | void terminate() noexcept override {}; 33 | 34 | size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc *inputs, int nbInputs, 35 | const nvinfer1::PluginTensorDesc *outputs, int nbOutputs) const noexcept override; 36 | 37 | int enqueue(const PluginTensorDesc *inputDesc, const PluginTensorDesc *outputDesc, 38 | const void *const *inputs, void *const *outputs, void *workspace, 39 | cudaStream_t stream) noexcept override; 40 | 41 | size_t getSerializationSize() const noexcept override; 42 | 43 | void serialize(void *buffer) const noexcept override; 44 | 45 | const char *getPluginType() const noexcept override { 46 | return PLUGIN_NAME; 47 | }; 48 | 49 | const char *getPluginVersion() const noexcept override { 50 | return "1"; 51 | }; 52 | 53 | void destroy() noexcept override {}; 54 | 55 | IPluginV2DynamicExt *clone() const noexcept override; 56 | 57 | void setPluginNamespace(const char *pluginNamespace) noexcept override { 58 | this->mPluginNamespace = pluginNamespace; 59 | }; 60 | 61 | const char *getPluginNamespace() const noexcept override { 62 | return this->mPluginNamespace.c_str(); 63 | }; 64 | 65 | DataType getOutputDataType(int index, const nvinfer1::DataType *inputTypes, int nbInputs) const noexcept override { 66 | return inputTypes[0]; 67 | }; 68 | 69 | void attachToContext( 70 | cudnnContext *cudnnContext, cublasContext *cublasContext, IGpuAllocator *gpuAllocator) noexcept override { 71 | cublas_context = cublasContext; 72 | }; 73 | 74 | void detachFromContext() noexcept override {}; 75 | 76 | void configurePlugin( 77 | const DynamicPluginTensorDesc* in, int nbInputs, 78 | const DynamicPluginTensorDesc* out, int nbOutputs) noexcept override; 79 | 80 | private: 81 | template 82 | std::pair, DCNLayerOutput> makeView(const void *const *inputs, 83 | void *const *outputs, 84 | void *workspace); 85 | std::string mPluginNamespace; 86 | DCNLayerConfig config{}; 87 | DCNLayerInternal internal{}; 88 | 89 | cublasContext * cublas_context; 90 | }; 91 | 92 | class DCNLayerPluginCreator : public IPluginCreator { 93 | public: 94 | // boilerplate 95 | DCNLayerPluginCreator() noexcept = default; 96 | ~DCNLayerPluginCreator() noexcept override = default; 97 | 98 | const char *getPluginName() const noexcept override { 99 | return PLUGIN_NAME; 100 | }; 101 | 102 | const char *getPluginVersion() const noexcept override { 103 | return "1"; 104 | }; 105 | 106 | const PluginFieldCollection *getFieldNames() noexcept override; 107 | 108 | IPluginV2DynamicExt *createPlugin(const char *name, const PluginFieldCollection *fc) noexcept override; 109 | 110 | IPluginV2DynamicExt *deserializePlugin(const char *name, const void *serialData, size_t serialLength) noexcept override; 111 | 112 | void setPluginNamespace(const char *libNamespace) noexcept override { 113 | mNamespace = libNamespace; 114 | } 115 | 116 | const char *getPluginNamespace() const noexcept override { 117 | return mNamespace.c_str(); 118 | } 119 | 120 | private: 121 | std::string mNamespace; 122 | 123 | const static PluginField mPluginAttributes[]; 124 | const static PluginFieldCollection mPFC; 125 | 126 | }; 127 | 128 | }; 129 | 130 | #undef PLUGIN_NAME 131 | -------------------------------------------------------------------------------- /tensorrt/layers/include/internal/dcn_layer_impl.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include "config.h" 5 | 6 | template 7 | void compute(DCNLayerInput inputs, 8 | DCNLayerOutput outputs, 9 | DCNLayerConfig config, 10 | DCNLayerExtra extra, 11 | cudaStream_t stream); 12 | 13 | -------------------------------------------------------------------------------- /tensorrt/layers/include/layers.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "helper.h" 4 | 5 | namespace UDOLayers { 6 | 7 | bool registerDCNLayerPlugin(); 8 | 9 | bool registerPlugins(); 10 | 11 | } -------------------------------------------------------------------------------- /tensorrt/layers/src/dcn_layer.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include "helper.h" 6 | #include "config.h" 7 | #include "dcn_layer.h" 8 | #include "dcn_layer_impl.h" 9 | 10 | namespace nvinfer1 { 11 | 12 | constexpr static struct { 13 | uint16_t hdr; 14 | int padding_test; 15 | float float_test; 16 | } signature = {0xfeff, 0x1, 1.0f}; 17 | 18 | DCNLayerPlugin::DCNLayerPlugin(const DCNLayerConfig *config) noexcept { 19 | this->config = *config; 20 | } 21 | 22 | DCNLayerPlugin::DCNLayerPlugin(const DCNLayerPlugin &plugin) noexcept { 23 | this->config = plugin.config; 24 | this->mPluginNamespace = plugin.mPluginNamespace; 25 | } 26 | 27 | DCNLayerPlugin::DCNLayerPlugin(const void *data, size_t length) noexcept { 28 | assert(length == sizeof(config)); 29 | std::memcpy(&config, (const char *) (data), sizeof(config)); 30 | } 31 | 32 | size_t DCNLayerPlugin::getSerializationSize() const noexcept { 33 | return sizeof(signature) + sizeof(config); 34 | } 35 | 36 | void DCNLayerPlugin::serialize(void *buffer) const noexcept { 37 | std::memcpy(buffer, &signature, sizeof(signature)); 38 | std::memcpy((char *) (buffer) + sizeof(signature), &config, sizeof(config)); 39 | } 40 | 41 | IPluginV2DynamicExt *DCNLayerPlugin::clone() const noexcept { 42 | auto p = new DCNLayerPlugin(&config); 43 | p->setPluginNamespace(this->getPluginNamespace()); 44 | return p; 45 | } 46 | 47 | enum { 48 | Input, 49 | Offset, 50 | Mask, 51 | Weight, 52 | Bias, 53 | nInput 54 | }; 55 | 56 | void DCNLayerPlugin::configurePlugin( 57 | const DynamicPluginTensorDesc *in, int nbInputs, 58 | const DynamicPluginTensorDesc *out, int nbOutputs) noexcept { 59 | assert(nbInputs == nInput); 60 | assert(nbOutputs == 1); 61 | 62 | shape_t<4> offset, mask; 63 | 64 | internal.data_type = int32_t(in->desc.type); 65 | internal.input.gather_from(in[Input].desc.dims.d, 0, 1, 2, 3); 66 | offset.gather_from(in[Offset].desc.dims.d, 0, 1, 2, 3); 67 | mask.gather_from(in[Mask].desc.dims.d, 0, 1, 2, 3); 68 | internal.weight.gather_from(in[Weight].desc.dims.d, 0, 1, 2, 3); 69 | internal.bias[0] = in[Bias].desc.dims.d[0]; 70 | 71 | if (internal.input[0] != -1 && 72 | internal.input[1] != -1 && 73 | internal.input[2] != -1 && 74 | internal.input[3] != -1) { 75 | // check and calc only when we know exact dimension. 76 | 77 | const auto [n, cin, h, w] = internal.input; 78 | 79 | // weight: input channel match 80 | assert(cin == internal.weight[1]); 81 | 82 | // offset: stack at dim 1 of offset_h, offset_w 83 | assert(n == offset[0]); 84 | const auto [cout, _, kh, kw] = internal.weight; 85 | const offset_t deformable_channels = kh * kw * config.deformable_groups; 86 | assert(deformable_channels * 2 == offset[1]); 87 | const offset_t oh = (h + 2 * config.padding.h - config.dilation.h * (kh - 1) - 1) / config.stride.h + 1; 88 | const offset_t ow = (w + 2 * config.padding.w - config.dilation.w * (kw - 1) - 1) / config.stride.w + 1; 89 | assert(oh == offset[2]); 90 | assert(ow == offset[3]); 91 | internal.offset = {n, config.deformable_groups, kh, kw, 2, offset[2], offset[3]}; 92 | 93 | // mask 94 | assert(n == mask[0]); 95 | assert(deformable_channels == mask[1]); 96 | assert(oh == mask[2]); 97 | assert(ow == mask[3]); 98 | internal.mask = {n, config.deformable_groups, kh, kw, mask[2], mask[3]}; 99 | 100 | // bias: output channel match 101 | assert(internal.weight[0] == internal.bias[0]); 102 | 103 | internal.output = {n, cout, offset[2], offset[3]}; 104 | internal.im2col_buffer = {n, cin, kh, kw, offset[2], offset[3]}; 105 | } 106 | 107 | } 108 | 109 | DimsExprs DCNLayerPlugin::getOutputDimensions(int outputIndex, const DimsExprs *inputs, 110 | int nbInputs, IExprBuilder &exprBuilder) noexcept { 111 | 112 | switch (outputIndex) { 113 | case 0: { 114 | const auto n = inputs[Input].d[0]; 115 | const auto h = inputs[Input].d[2]; 116 | const auto w = inputs[Input].d[3]; 117 | const auto c = inputs[Weight].d[0]; 118 | const auto kh = inputs[Weight].d[2]; 119 | const auto kw = inputs[Weight].d[3]; 120 | 121 | using op = DimensionOperation; 122 | 123 | // (h - config.dilation.h * (kh - 1) + 2 * config.padding.h - 1) / config.stride.h + 1 124 | auto oh = exprBuilder.operation(op::kSUB, *kh, *exprBuilder.constant(1)); 125 | oh = exprBuilder.operation(op::kPROD, *oh, *exprBuilder.constant(config.dilation.h)); 126 | oh = exprBuilder.operation(op::kSUB, *h, *oh); 127 | oh = exprBuilder.operation(op::kSUM, *oh, *exprBuilder.constant(2 * config.padding.h - 1)); 128 | oh = exprBuilder.operation(op::kFLOOR_DIV, *oh, *exprBuilder.constant(config.stride.h)); 129 | oh = exprBuilder.operation(op::kSUM, *oh, *exprBuilder.constant(1)); 130 | 131 | auto ow = exprBuilder.operation(op::kSUB, *kw, *exprBuilder.constant(1)); 132 | ow = exprBuilder.operation(op::kPROD, *ow, *exprBuilder.constant(config.dilation.w)); 133 | ow = exprBuilder.operation(op::kSUB, *w, *ow); 134 | ow = exprBuilder.operation(op::kSUM, *ow, *exprBuilder.constant(2 * config.padding.w - 1)); 135 | ow = exprBuilder.operation(op::kFLOOR_DIV, *ow, *exprBuilder.constant(config.stride.w)); 136 | ow = exprBuilder.operation(op::kSUM, *ow, *exprBuilder.constant(1)); 137 | return DimsExprs{4, { 138 | n, // batch size 139 | c, // out channel 140 | oh, // out height 141 | ow // out width 142 | }}; 143 | } 144 | 145 | default: 146 | // output count exceed. 147 | assert(false); 148 | PLUGIN_UNREACHABLE; 149 | } 150 | }; 151 | 152 | bool DCNLayerPlugin::supportsFormatCombination(int pos, 153 | const PluginTensorDesc *inOut, 154 | int nbInputs, 155 | int nbOutputs) noexcept { 156 | if (inOut[pos].format != TensorFormat::kLINEAR) { 157 | return false; 158 | } 159 | 160 | switch (pos) { 161 | case Input:return inOut[Input].type == DataType::kFLOAT || inOut[Input].type == DataType::kHALF; 162 | 163 | case Offset: 164 | case Mask: 165 | case Weight: 166 | case Bias: 167 | case Bias + 1: 168 | return inOut[pos].type == inOut[Input].type; 169 | 170 | default: 171 | // inOut count exceed. 172 | assert(false); 173 | PLUGIN_UNREACHABLE; 174 | } 175 | } 176 | 177 | size_t DCNLayerPlugin::getWorkspaceSize(const nvinfer1::PluginTensorDesc *inputs, 178 | int nbInputs, 179 | const nvinfer1::PluginTensorDesc *outputs, 180 | int nbOutputs) const noexcept { 181 | shape_t <6> im2col { 182 | inputs[Input].dims.d[0], 183 | inputs[Input].dims.d[1], 184 | inputs[Weight].dims.d[2], 185 | inputs[Weight].dims.d[3], 186 | inputs[Offset].dims.d[2], 187 | inputs[Offset].dims.d[3] 188 | }; 189 | return im2col.count() * (inputs->type == DataType::kHALF ? 2 : 4); 190 | } 191 | 192 | template 193 | std::pair, DCNLayerOutput> DCNLayerPlugin::makeView(const void *const *inputs, 194 | void *const *outputs, 195 | void *workspace) { 196 | return { 197 | { 198 | { 199 | static_cast(inputs[Input]), 200 | internal.input 201 | }, 202 | { 203 | static_cast(inputs[Offset]), 204 | internal.offset 205 | }, 206 | { 207 | static_cast(inputs[Mask]), 208 | internal.mask 209 | }, 210 | { 211 | static_cast(inputs[Weight]), 212 | internal.weight 213 | }, 214 | { 215 | static_cast(inputs[Bias]), 216 | internal.bias 217 | }, 218 | { 219 | static_cast(workspace), 220 | internal.im2col_buffer 221 | }, 222 | }, 223 | 224 | { 225 | { 226 | static_cast(outputs[0]), 227 | internal.output 228 | }, 229 | } 230 | }; 231 | } 232 | 233 | int DCNLayerPlugin::enqueue(const PluginTensorDesc *inputDesc, 234 | const PluginTensorDesc *outputDesc, 235 | const void *const *inputs, 236 | void *const *outputs, 237 | void *workspace, 238 | cudaStream_t stream) noexcept { 239 | 240 | DCNLayerExtra extra {cublas_context, 0}; 241 | 242 | switch ((DataType) internal.data_type) { 243 | case DataType::kFLOAT: { 244 | auto [in, out] = makeView(inputs, outputs, workspace); 245 | compute(in, out, config, extra, stream); 246 | return 0; 247 | } 248 | 249 | case DataType::kHALF: { 250 | auto [in, out] = makeView(inputs, outputs, workspace); 251 | compute(in, out, config, extra, stream); 252 | return 0; 253 | } 254 | 255 | default:return 1; 256 | } 257 | } 258 | 259 | const PluginField DCNLayerPluginCreator::mPluginAttributes[]{ 260 | {"stride", nullptr, PluginFieldType::kINT32, 2}, 261 | {"padding", nullptr, PluginFieldType::kINT32, 2}, 262 | {"dilation", nullptr, PluginFieldType::kINT32, 2}, 263 | {"deformable_groups", nullptr, PluginFieldType::kINT32, 1}, 264 | {"activation_type", nullptr, PluginFieldType::kINT32, 1}, 265 | {"alpha", nullptr, PluginFieldType::kFLOAT32, 1}, 266 | {"beta", nullptr, PluginFieldType::kFLOAT32, 1}, 267 | }; 268 | 269 | IPluginV2DynamicExt *DCNLayerPluginCreator::createPlugin(const char *name, 270 | const PluginFieldCollection *fc) noexcept { 271 | if (fc->nbFields != mPFC.nbFields) { 272 | return nullptr; 273 | } 274 | 275 | DCNLayerConfig config{ 276 | {1, 1}, 277 | {1, 1}, 278 | {1, 1}, 279 | 1, 280 | 3, 281 | 0.1, 282 | 0 283 | }; 284 | 285 | for (int32_t idx = 0; idx < fc->nbFields; ++idx) { 286 | const auto &field = fc->fields[idx]; 287 | 288 | if (std::strcmp(field.name, "stride") == 0) { 289 | std::memcpy(&config.stride, field.data, sizeof(config.stride)); 290 | } 291 | else if (std::strcmp(field.name, "padding") == 0) { 292 | std::memcpy(&config.padding, field.data, sizeof(config.padding)); 293 | } 294 | else if (std::strcmp(field.name, "dilation") == 0) { 295 | std::memcpy(&config.dilation, field.data, sizeof(config.dilation)); 296 | } 297 | else if (std::strcmp(field.name, "deformable_groups") == 0) { 298 | std::memcpy(&config.deformable_groups, field.data, sizeof(config.deformable_groups)); 299 | } 300 | else if (std::strcmp(field.name, "activation_type") == 0) { 301 | std::memcpy(&config.activation_type, field.data, sizeof(config.activation_type)); 302 | } 303 | else if (std::strcmp(field.name, "alpha") == 0) { 304 | std::memcpy(&config.alpha, field.data, sizeof(config.alpha)); 305 | } 306 | else if (std::strcmp(field.name, "beta") == 0) { 307 | std::memcpy(&config.beta, field.data, sizeof(config.beta)); 308 | } 309 | } 310 | 311 | auto p = new DCNLayerPlugin(&config); 312 | p->setPluginNamespace(mNamespace.c_str()); 313 | return p; 314 | } 315 | 316 | const PluginFieldCollection DCNLayerPluginCreator::mPFC{ 317 | sizeof(DCNLayerPluginCreator::mPluginAttributes) / sizeof(PluginField), 318 | DCNLayerPluginCreator::mPluginAttributes, 319 | }; 320 | 321 | const PluginFieldCollection *DCNLayerPluginCreator::getFieldNames() noexcept { 322 | return &nvinfer1::DCNLayerPluginCreator::mPFC; 323 | } 324 | 325 | IPluginV2DynamicExt *DCNLayerPluginCreator::deserializePlugin(const char *name, 326 | const void *serialData, 327 | size_t serialLength) noexcept { 328 | if (serialLength != sizeof(signature) + sizeof(DCNLayerConfig)) { 329 | return nullptr; 330 | } 331 | 332 | if (std::memcmp(&signature, serialData, sizeof(signature)) != 0) { 333 | return nullptr; 334 | } 335 | 336 | auto p = new DCNLayerPlugin((const char *) (serialData) + sizeof(signature), sizeof(DCNLayerConfig)); 337 | p->setPluginNamespace(mNamespace.c_str()); 338 | return p; 339 | } 340 | 341 | } 342 | -------------------------------------------------------------------------------- /tensorrt/layers/src/layers.cpp: -------------------------------------------------------------------------------- 1 | #include "helper.h" 2 | #include "dcn_layer.h" 3 | 4 | namespace UDOLayers { 5 | 6 | template 7 | struct PluginRegistrar { 8 | #ifdef AUTO_REGISTER_PLUGIN 9 | PluginRegistrar() noexcept 10 | { 11 | getPluginRegistry()->registerCreator(instance, ""); 12 | } 13 | #endif 14 | 15 | T instance{}; 16 | }; 17 | 18 | namespace { 19 | PluginRegistrar _mDCNLayerPluginCreator{}; 20 | } 21 | 22 | bool registerDCNLayerPlugin() { 23 | return getPluginRegistry()->registerCreator(_mDCNLayerPluginCreator.instance, ""); 24 | } 25 | 26 | bool registerPlugins() { 27 | return registerDCNLayerPlugin(); 28 | } 29 | 30 | } 31 | -------------------------------------------------------------------------------- /tensorrt/layers/test/dcn_layer_test.cpp: -------------------------------------------------------------------------------- 1 | // 2 | // Created by TYTY on 2022-10-04 004. 3 | // 4 | 5 | #include 6 | #include 7 | #include 8 | #include 9 | 10 | #include "cublas_v2.h" 11 | #include "cuda_fp16.h" 12 | #include "cuda_runtime_api.h" 13 | 14 | #include "dcn_layer_impl.h" 15 | #include "md_view.h" 16 | 17 | #include "gtest/gtest.h" 18 | 19 | #define CUDA_CHECK(status) ASSERT_EQ(status, cudaSuccess) 20 | #define COND_CHECK(cond, message) ASSERT_TRUE(cond) 21 | 22 | // small enough to not cause pixel drift 23 | constexpr double Epsilon = 0.5 / 255; 24 | 25 | // increase allowed epsilon for fp16. 26 | constexpr double EpsilonHalf = 0.025; 27 | 28 | typedef std::chrono::duration> millisecond; 29 | 30 | template 31 | void loadFile(const std::string &path, md_view &data) { 32 | std::ifstream file(path, std::ios::binary); 33 | COND_CHECK(file.good(), "can't open input file."); 34 | 35 | auto size = data.size() * sizeof(T); 36 | file.read((char *) data.data, size); 37 | file.close(); 38 | } 39 | 40 | template 41 | void loadFileNv(const std::string &path, md_view &data) { 42 | std::ifstream file(path, std::ios::binary); 43 | COND_CHECK(file.good(), "can't open input file."); 44 | 45 | auto size = data.size() * sizeof(T); 46 | auto tmp = std::make_unique(size); 47 | file.read(tmp.get(), size); 48 | file.close(); 49 | 50 | CUDA_CHECK(cudaMemcpy((void *) data.data, tmp.get(), size, cudaMemcpyHostToDevice)); 51 | } 52 | 53 | template 54 | void loadFileNvF2H(const std::string &path, md_view &data) { 55 | std::ifstream file(path, std::ios::binary); 56 | COND_CHECK(file.good(), "can't open input file."); 57 | 58 | auto size = data.size() * sizeof(float); 59 | auto tmp = std::make_unique(size); 60 | file.read(tmp.get(), size); 61 | file.close(); 62 | 63 | auto tmpF = (float *)tmp.get(); 64 | auto tmpH = ((half *)tmp.get()) + data.size(); 65 | for (offset_t i = 0; i < data.size(); ++i) { 66 | offset_t idx = data.size() - i - 1; 67 | tmpH[idx] = tmpF[idx]; 68 | } 69 | 70 | CUDA_CHECK(cudaMemcpy((void *) data.data, tmpH, size / 2, cudaMemcpyHostToDevice)); 71 | } 72 | 73 | template 74 | void ComputeFloat(DCNLayerInput &inputs, 75 | DCNLayerOutput &outputs, 76 | DCNLayerConfig &config, 77 | DCNLayerExtra &extra, 78 | cudaStream_t stream, 79 | int repeat) { 80 | for (int i = 0; i < repeat; ++i) { 81 | auto all_begin = std::chrono::steady_clock::now(); 82 | millisecond elapsed; 83 | 84 | compute(inputs, outputs, config, extra, stream); 85 | CUDA_CHECK(cudaStreamSynchronize(stream)); 86 | 87 | elapsed = std::chrono::steady_clock::now() - all_begin; 88 | std::cerr << "Inference done after " << elapsed.count() << "ms." << std::endl; 89 | } 90 | } 91 | 92 | void RunTest(const std::string &file_prefix, DCNLayerInput &inputs, DCNLayerOutput &outputs, DCNLayerConfig &config, int repeat = 1) { 93 | CUDA_CHECK(cudaMalloc((void **) &inputs.input.data, inputs.input.size() * sizeof(float))); 94 | CUDA_CHECK(cudaMalloc((void **) &inputs.offset.data, inputs.offset.size() * sizeof(float))); 95 | CUDA_CHECK(cudaMalloc((void **) &inputs.mask.data, inputs.mask.size() * sizeof(float))); 96 | CUDA_CHECK(cudaMalloc((void **) &inputs.weight.data, inputs.weight.size() * sizeof(float))); 97 | CUDA_CHECK(cudaMalloc((void **) &inputs.bias.data, inputs.bias.size() * sizeof(float))); 98 | CUDA_CHECK(cudaMalloc((void **) &inputs.im2col_buffer.data, inputs.im2col_buffer.size() * sizeof(float))); 99 | 100 | CUDA_CHECK(cudaMalloc((void **) &outputs.output.data, outputs.output.size() * sizeof(float))); 101 | 102 | loadFileNv(file_prefix + "input.bin", inputs.input); 103 | loadFileNv(file_prefix + "offset.bin", inputs.offset); 104 | loadFileNv(file_prefix + "mask.bin", inputs.mask); 105 | loadFileNv(file_prefix + "weight.bin", inputs.weight); 106 | loadFileNv(file_prefix + "bias.bin", inputs.bias); 107 | 108 | cudaStream_t stream; 109 | CUDA_CHECK(cudaStreamCreate(&stream)); 110 | 111 | cublasHandle_t cublas; 112 | 113 | DCNLayerExtra extra{}; 114 | ASSERT_EQ(cublasCreate_v2(&cublas), CUBLAS_STATUS_SUCCESS); 115 | ASSERT_EQ(cublasSetStream_v2(cublas, stream), CUBLAS_STATUS_SUCCESS); 116 | extra.cublasHandle = cublas; 117 | extra.blasMode = 0; 118 | 119 | ComputeFloat(inputs, outputs, config, extra, stream, repeat); 120 | 121 | CUDA_CHECK(cudaGetLastError()); 122 | CUDA_CHECK(cudaStreamDestroy(stream)); 123 | 124 | auto oshape = outputs.output.shape; 125 | auto output_ref_storage = std::make_unique(oshape.count()); 126 | auto output_cpu_storage = std::make_unique(oshape.count()); 127 | md_view output_ref{output_ref_storage.get(), oshape}; 128 | md_view output_cpu{output_cpu_storage.get(), oshape}; 129 | 130 | loadFile(file_prefix + "output.bin", output_ref); 131 | CUDA_CHECK(cudaMemcpy((void *) output_cpu.data, outputs.output.data, oshape.count() * sizeof(float), cudaMemcpyDeviceToHost)); 132 | 133 | float max = 0; 134 | double total = 0; 135 | 136 | for (offset_t n = 0; n < oshape[0]; ++n) { 137 | for (offset_t c = 0; c < oshape[1]; ++c) { 138 | for (offset_t h = 0; h < oshape[2]; ++h) { 139 | for (offset_t w = 0; w < oshape[3]; ++w) { 140 | EXPECT_NEAR(output_ref.at(n, c, h, w), output_cpu.at(n, c, h, w), Epsilon) << "The coordinate is [" << n << "," << c << "," << h << "," << w << "]"; 141 | float diff = std::abs(output_ref.at(n, c, h, w) - output_cpu.at(n, c, h, w)); 142 | total += diff; 143 | max = diff > max ? diff : max; 144 | } 145 | } 146 | } 147 | } 148 | 149 | std::cerr << "Diff: max " << max << ", avg " << total / double(oshape.count()) << std::endl; 150 | 151 | ASSERT_EQ(cublasDestroy_v2(cublas), CUBLAS_STATUS_SUCCESS); 152 | 153 | CUDA_CHECK(cudaFree((void *) inputs.input.data)); 154 | CUDA_CHECK(cudaFree((void *) inputs.offset.data)); 155 | CUDA_CHECK(cudaFree((void *) inputs.mask.data)); 156 | CUDA_CHECK(cudaFree((void *) inputs.weight.data)); 157 | CUDA_CHECK(cudaFree((void *) inputs.bias.data)); 158 | CUDA_CHECK(cudaFree((void *) inputs.im2col_buffer.data)); 159 | 160 | CUDA_CHECK(cudaFree((void *) outputs.output.data)); 161 | } 162 | 163 | void RunTest(const std::string &file_prefix, DCNLayerInput &inputs, DCNLayerOutput &outputs, DCNLayerConfig &config, int repeat = 1) { 164 | CUDA_CHECK(cudaMalloc((void **) &inputs.input.data, inputs.input.size() * sizeof(half))); 165 | CUDA_CHECK(cudaMalloc((void **) &inputs.offset.data, inputs.offset.size() * sizeof(half))); 166 | CUDA_CHECK(cudaMalloc((void **) &inputs.mask.data, inputs.mask.size() * sizeof(half))); 167 | CUDA_CHECK(cudaMalloc((void **) &inputs.weight.data, inputs.weight.size() * sizeof(half))); 168 | CUDA_CHECK(cudaMalloc((void **) &inputs.bias.data, inputs.bias.size() * sizeof(half))); 169 | CUDA_CHECK(cudaMalloc((void **) &inputs.im2col_buffer.data, inputs.im2col_buffer.size() * sizeof(half))); 170 | 171 | CUDA_CHECK(cudaMalloc((void **) &outputs.output.data, outputs.output.size() * sizeof(half))); 172 | 173 | loadFileNvF2H(file_prefix + "input.bin", inputs.input); 174 | loadFileNvF2H(file_prefix + "offset.bin", inputs.offset); 175 | loadFileNvF2H(file_prefix + "mask.bin", inputs.mask); 176 | loadFileNvF2H(file_prefix + "weight.bin", inputs.weight); 177 | loadFileNvF2H(file_prefix + "bias.bin", inputs.bias); 178 | 179 | cudaStream_t stream; 180 | CUDA_CHECK(cudaStreamCreate(&stream)); 181 | 182 | cublasHandle_t cublas; 183 | 184 | DCNLayerExtra extra{}; 185 | ASSERT_EQ(cublasCreate_v2(&cublas), CUBLAS_STATUS_SUCCESS); 186 | ASSERT_EQ(cublasSetStream_v2(cublas, stream), CUBLAS_STATUS_SUCCESS); 187 | extra.cublasHandle = cublas; 188 | extra.blasMode = 0; 189 | 190 | ComputeFloat(inputs, outputs, config, extra, stream, repeat); 191 | 192 | CUDA_CHECK(cudaGetLastError()); 193 | CUDA_CHECK(cudaStreamDestroy(stream)); 194 | 195 | auto oshape = outputs.output.shape; 196 | auto output_ref_storage = std::make_unique(oshape.count()); 197 | auto output_cpu_storage = std::make_unique(oshape.count()); 198 | md_view output_ref{output_ref_storage.get(), oshape}; 199 | md_view output_cpu{output_cpu_storage.get(), oshape}; 200 | 201 | loadFile(file_prefix + "output.bin", output_ref); 202 | CUDA_CHECK(cudaMemcpy((void *) output_cpu.data, outputs.output.data, oshape.count() * sizeof(half), cudaMemcpyDeviceToHost)); 203 | 204 | float max = 0; 205 | double total = 0; 206 | 207 | for (offset_t n = 0; n < oshape[0]; ++n) { 208 | for (offset_t c = 0; c < oshape[1]; ++c) { 209 | for (offset_t h = 0; h < oshape[2]; ++h) { 210 | for (offset_t w = 0; w < oshape[3]; ++w) { 211 | EXPECT_NEAR(output_ref.at(n, c, h, w), output_cpu.at(n, c, h, w), EpsilonHalf) << "The coordinate is [" << n << "," << c << "," << h << "," << w << "]"; 212 | float diff = std::abs(output_ref.at(n, c, h, w) - output_cpu.at(n, c, h, w)); 213 | total += diff; 214 | max = diff > max ? diff : max; 215 | } 216 | } 217 | } 218 | } 219 | 220 | std::cerr << "Diff: max " << max << ", avg " << total / double(oshape.count()) << std::endl; 221 | 222 | ASSERT_EQ(cublasDestroy_v2(cublas), CUBLAS_STATUS_SUCCESS); 223 | 224 | CUDA_CHECK(cudaFree((void *) inputs.input.data)); 225 | CUDA_CHECK(cudaFree((void *) inputs.offset.data)); 226 | CUDA_CHECK(cudaFree((void *) inputs.mask.data)); 227 | CUDA_CHECK(cudaFree((void *) inputs.weight.data)); 228 | CUDA_CHECK(cudaFree((void *) inputs.bias.data)); 229 | CUDA_CHECK(cudaFree((void *) inputs.im2col_buffer.data)); 230 | 231 | CUDA_CHECK(cudaFree((void *) outputs.output.data)); 232 | } 233 | 234 | TEST(DCNLayerTest, SmallInput) { 235 | DCNLayerInput input{ 236 | {nullptr, {1, 1, 5, 5}}, 237 | {nullptr, {1, 1, 3, 3, 2, 5, 5}}, 238 | {nullptr, {1, 1, 3, 3, 5, 5}}, 239 | {nullptr, {1, 1, 3, 3}}, 240 | {nullptr, {1}}, 241 | 242 | {nullptr, {1, 1, 3, 3, 5, 5}}}; 243 | 244 | DCNLayerOutput output{ 245 | {nullptr, {1, 1, 5, 5}}}; 246 | 247 | DCNLayerConfig config{ 248 | {1, 1}, 249 | {1, 1}, 250 | {1, 1}, 251 | 1, 252 | 3, 253 | 0.1, 254 | 0}; 255 | 256 | RunTest("./small/", input, output, config); 257 | } 258 | 259 | TEST(DCNLayerTest, RealCase) { 260 | DCNLayerInput input{ 261 | {nullptr, {2, 64, 180, 320}}, 262 | {nullptr, {2, 8, 3, 3, 2, 180, 320}}, 263 | {nullptr, {2, 8, 3, 3, 180, 320}}, 264 | {nullptr, {64, 64, 3, 3}}, 265 | {nullptr, {64}}, 266 | 267 | {nullptr, {2, 64, 3, 3, 180, 320}}}; 268 | 269 | DCNLayerOutput output{ 270 | {nullptr, {2, 64, 180, 320}}}; 271 | 272 | DCNLayerConfig config{ 273 | {1, 1}, 274 | {1, 1}, 275 | {1, 1}, 276 | 8, 277 | -1, 278 | 0, 279 | 0}; 280 | 281 | for (int i = 0; i < 1000; ++i) { 282 | RunTest("./real/", input, output, config, 100); 283 | } 284 | 285 | } 286 | 287 | TEST(DCNLayerTest, RealCaseHalf) { 288 | DCNLayerInput input{ 289 | {nullptr, {2, 64, 180, 320}}, 290 | {nullptr, {2, 8, 3, 3, 2, 180, 320}}, 291 | {nullptr, {2, 8, 3, 3, 180, 320}}, 292 | {nullptr, {64, 64, 3, 3}}, 293 | {nullptr, {64}}, 294 | 295 | {nullptr, {2, 64, 3, 3, 180, 320}}}; 296 | 297 | DCNLayerOutput output{ 298 | {nullptr, {2, 64, 180, 320}}}; 299 | 300 | DCNLayerConfig config{ 301 | {1, 1}, 302 | {1, 1}, 303 | {1, 1}, 304 | 8, 305 | -1, 306 | 0, 307 | 0}; 308 | 309 | RunTest("./real/", input, output, config, 20); 310 | } 311 | -------------------------------------------------------------------------------- /tensorrt/model_src/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tongyuantongyu/cycmunet/dfb04885a9c2b79e0f7ee1b24fdb8de648e9321c/tensorrt/model_src/.gitkeep -------------------------------------------------------------------------------- /tensorrt/models/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tongyuantongyu/cycmunet/dfb04885a9c2b79e0f7ee1b24fdb8de648e9321c/tensorrt/models/.gitkeep -------------------------------------------------------------------------------- /tensorrt/src/optimize.cpp: -------------------------------------------------------------------------------- 1 | // 2 | // Created by TYTY on 2021-12-23 023. 3 | // 4 | 5 | #include "optimize.h" 6 | 7 | #include "NvOnnxParser.h" 8 | 9 | #define COND_CHECK_EMPTY(cond, message) \ 10 | do { \ 11 | if (!(cond)) { \ 12 | std::stringstream s; \ 13 | s << "Check failed " __FILE__ ":" << __LINE__ << ": " #cond ", " << message; \ 14 | logger.log(nvinfer1::ILogger::Severity::kERROR, s.str().c_str()); \ 15 | return -1; \ 16 | } \ 17 | } while (0) 18 | 19 | static void inspectNetwork(nvinfer1::INetworkDefinition *network) { 20 | auto count = network->getNbLayers(); 21 | for (int i = 0; i < count; ++i) { 22 | auto layer = network->getLayer(i); 23 | auto i_count = layer->getNbInputs(); 24 | std::cerr << "#" << i << ": " << layer->getName() << ", " << int32_t(layer->getType()) << ": "; 25 | 26 | std::cerr << "from {"; 27 | for (int j = 0; j < i_count; ++j) { 28 | std::string name = layer->getInput(j)->getName(); 29 | if (name.size() > 15) { 30 | name = std::to_string(atoi(name.c_str() + 15)); 31 | } 32 | std::cerr << name; 33 | auto size = layer->getInput(j)->getDimensions(); 34 | std::cerr << "("; 35 | for (int k = 0; k < size.nbDims; ++k) { 36 | std::cerr << size.d[k] << ","; 37 | } 38 | std::cerr << "\x08), "; 39 | } 40 | std::cerr << "\x08\x08}\n"; 41 | } 42 | } 43 | 44 | // feature extract: 45 | // 0 2 4 6 46 | // | 47 | // 0 2 4 6 (3x) 48 | 49 | // feature fusion 50 | // 0 2 4 (3x) 51 | // 2 4 6 (3x) 52 | // | 53 | // 1 3 5 54 | 55 | // main 56 | // 0 2 4 57 | // 1 3 5 58 | // 2 4 6 59 | // | 60 | // HR of input 61 | 62 | static constexpr size_t ceil_half(size_t size) { 63 | return (size + 1) / 2; 64 | } 65 | 66 | static std::string model_name_suffix(const OptimizationConfig &config) { 67 | std::stringstream ss; 68 | ss << "_n" << config.input_count; 69 | ss << "_" << config.scale_factor_w << "x" << config.scale_factor_h << "_l" << config.extraction_layers; 70 | if (config.format == IOFormat::YUV420) { 71 | ss << "_yuv1-1"; 72 | } 73 | ss << ".onnx"; 74 | return ss.str(); 75 | } 76 | 77 | static std::string fe_engine_name(const OptimizationConfig &config) { 78 | std::stringstream ss; 79 | ss << "fe_"; 80 | ss << config.input_width.opt << 'x' << config.input_height.opt << '_' << config.scale_factor_w << "x" 81 | << config.scale_factor_h << "_b" << config.batch_extract.opt << "_l" << config.extraction_layers; 82 | if (config.format == IOFormat::YUV420) { 83 | ss << "_yuv1-1"; 84 | } 85 | if (config.use_fp16) { 86 | ss << "_fp16"; 87 | } 88 | if (config.low_mem) { 89 | ss << "_lm"; 90 | } 91 | ss << ".engine"; 92 | return ss.str(); 93 | } 94 | 95 | static std::string ff_engine_name(const OptimizationConfig &config) { 96 | std::stringstream ss; 97 | ss << "ff_"; 98 | ss << "n" << config.input_count; 99 | if (config.double_frame) { 100 | ss << "a"; 101 | } 102 | if (config.extra_frame) { 103 | ss << "+"; 104 | } 105 | ss << "_"; 106 | ss << config.input_width.opt << 'x' << config.input_height.opt << '_' << config.scale_factor_w << "x" 107 | << config.scale_factor_h << "_b" << config.batch_fusion.opt << "_l" << config.extraction_layers; 108 | if (config.format == IOFormat::YUV420) { 109 | ss << "_yuv1-1"; 110 | } 111 | if (config.use_fp16) { 112 | ss << "_fp16"; 113 | } 114 | if (config.low_mem) { 115 | ss << "_lm"; 116 | } 117 | ss << ".engine"; 118 | return ss.str(); 119 | } 120 | 121 | nvinfer1::IBuilderConfig *OptimizationContext::prepareConfig() const { 122 | auto conf = builder->createBuilderConfig(); 123 | if (config.use_fp16) { 124 | conf->setFlag(nvinfer1::BuilderFlag::kFP16); 125 | } 126 | conf->setFlag(nvinfer1::BuilderFlag::kTF32); 127 | conf->setFlag(nvinfer1::BuilderFlag::kSPARSE_WEIGHTS); 128 | conf->setFlag(nvinfer1::BuilderFlag::kPREFER_PRECISION_CONSTRAINTS); 129 | // /usr/src/tensorrt/bin/trtexec --verbose --noDataTransfers --useCudaGraph --separateProfileRun --useSpinWait --nvtxMode=verbose --loadEngine=./mutual_cycle.engine --exportTimes=./mutual_cycle.timing.json --exportProfile=./mutual_cycle.profile.json --exportLayerInfo=./mutual_cycle.graph.json --timingCacheFile=./timing.cache --best --avgRuns=1000 "--shapes=lf0:1x64x180x270,lf1:1x64x180x270,lf2:1x64x180x270" 130 | conf->setProfilingVerbosity(nvinfer1::ProfilingVerbosity::kDETAILED); 131 | conf->setTacticSources(conf->getTacticSources() & ~nvinfer1::TacticSources(1u << int32_t(nvinfer1::TacticSource::kCUDNN))); 132 | if (config.low_mem) { 133 | conf->setTacticSources(conf->getTacticSources() & ~nvinfer1::TacticSources(1u << int32_t(nvinfer1::TacticSource::kEDGE_MASK_CONVOLUTIONS))); 134 | } 135 | 136 | if (cache != nullptr) { 137 | conf->setTimingCache(*cache, false); 138 | } 139 | 140 | return conf; 141 | } 142 | 143 | nvinfer1::INetworkDefinition *OptimizationContext::createNetwork() const { 144 | return builder->createNetworkV2(1u << uint32_t(nvinfer1::NetworkDefinitionCreationFlag::kEXPLICIT_BATCH)); 145 | } 146 | 147 | OptimizationContext::OptimizationContext(OptimizationConfig config, nvinfer1::ILogger &logger, 148 | std::filesystem::path path_prefix_) 149 | : config(config), logger(logger), path_prefix(std::move(path_prefix_)), 150 | builder(nvinfer1::createInferBuilder(logger)), cache(nullptr), prop {}, total_memory {} { 151 | auto conf = builder->createBuilderConfig(); 152 | cudaMemGetInfo(nullptr, &total_memory); 153 | cudaGetDeviceProperties(&prop, 0); 154 | logger.log(nvinfer1::ILogger::Severity::kINFO, 155 | ("Device has " + std::to_string(total_memory) + " bytes memory.").c_str()); 156 | 157 | if (builder->platformHasFastFp16() && !config.use_fp16) { 158 | // CUDA Architecture 6.1 (Pascal, GTX10xx series) does not have really useful FP16. 159 | if (prop.major != 6 || prop.minor != 1) { 160 | logger.log(nvinfer1::ILogger::Severity::kWARNING, "Fast FP16 is available but not enabled."); 161 | } 162 | } 163 | 164 | auto cache_file = path_prefix / "timing.cache"; 165 | std::ifstream input(cache_file, std::ios::binary | std::ios::in); 166 | if (input.is_open()) { 167 | auto size = std::filesystem::file_size(cache_file); 168 | auto *values = new char[size]; 169 | input.read(values, size); 170 | cache = conf->createTimingCache(values, size); 171 | delete[] values; 172 | input.close(); 173 | } 174 | if (cache == nullptr) { 175 | cache = conf->createTimingCache(nullptr, 0); 176 | } 177 | } 178 | 179 | OptimizationContext::~OptimizationContext() { 180 | if (cache != nullptr) { 181 | std::ofstream output(path_prefix / "timing.cache", std::ios::binary | std::ios::out); 182 | auto memory = cache->serialize(); 183 | output.write(static_cast(memory->data()), memory->size()); 184 | output.close(); 185 | } 186 | } 187 | 188 | int OptimizationContext::optimize(const std::filesystem::path &folder) { 189 | auto fe_target = path_prefix / "engines" / folder / fe_engine_name(config); 190 | if (!exists(fe_target)) { 191 | auto fe_source_file = path_prefix / "models" / folder / ("fe" + model_name_suffix(config)); 192 | std::ifstream input_fe(fe_source_file, std::ios::binary | std::ios::in); 193 | COND_CHECK_EMPTY(input_fe.is_open(), "Source model file not exist:" << fe_source_file); 194 | std::vector fe_source(std::filesystem::file_size(fe_source_file)); 195 | input_fe.read((char *) (fe_source.data()), fe_source.size()); 196 | auto ret = buildFeatureExtract(std::move(fe_source), fe_target); 197 | if (ret != 0) { 198 | return ret; 199 | } 200 | } 201 | 202 | auto ff_target = path_prefix / "engines" / folder / ff_engine_name(config); 203 | if (!exists(ff_target)) { 204 | auto ff_source_file = path_prefix / "models" / folder / ("ff" + model_name_suffix(config)); 205 | std::ifstream input_ff(ff_source_file, std::ios::binary | std::ios::in); 206 | COND_CHECK_EMPTY(input_ff.is_open(), "Source model file not exist:" << ff_source_file); 207 | std::vector ff_source(std::filesystem::file_size(ff_source_file)); 208 | input_ff.read((char *) (ff_source.data()), ff_source.size()); 209 | auto ret = buildFeatureFusion(std::move(ff_source), ff_target); 210 | if (ret != 0) { 211 | return ret; 212 | } 213 | } 214 | 215 | return 0; 216 | } 217 | 218 | int OptimizationContext::buildFeatureExtract(std::vector input, const std::filesystem::path &output) { 219 | auto network = createNetwork(); 220 | auto profile = builder->createOptimizationProfile(); 221 | auto parser = nvonnxparser::createParser(*network, logger); 222 | COND_CHECK_EMPTY(parser->parse(input.data(), input.size()), "Failed parse source model."); 223 | input.clear(); 224 | 225 | auto ioDataType = config.use_fp16 ? nvinfer1::DataType::kHALF : nvinfer1::DataType::kFLOAT; 226 | 227 | for (int32_t i = 0; i < config.input_count; ++i) { 228 | if (config.format == IOFormat::RGB) { 229 | profile->setDimensions( 230 | "rgb", nvinfer1::OptProfileSelector::kMIN, 231 | nvinfer1::Dims4 {config.batch_extract.min, 3, config.input_height.min, config.input_width.min}); 232 | profile->setDimensions( 233 | "rgb", nvinfer1::OptProfileSelector::kOPT, 234 | nvinfer1::Dims4 {config.batch_extract.opt, 3, config.input_height.opt, config.input_width.opt}); 235 | profile->setDimensions( 236 | "rgb", nvinfer1::OptProfileSelector::kMAX, 237 | nvinfer1::Dims4 {config.batch_extract.max, 3, config.input_height.max, config.input_width.max}); 238 | } 239 | else { 240 | profile->setDimensions( 241 | "y", nvinfer1::OptProfileSelector::kMIN, 242 | nvinfer1::Dims4 {config.batch_extract.min, 1, config.input_height.min, config.input_width.min}); 243 | profile->setDimensions( 244 | "y", nvinfer1::OptProfileSelector::kOPT, 245 | nvinfer1::Dims4 {config.batch_extract.opt, 1, config.input_height.opt, config.input_width.opt}); 246 | profile->setDimensions( 247 | "y", nvinfer1::OptProfileSelector::kMAX, 248 | nvinfer1::Dims4 {config.batch_extract.max, 1, config.input_height.max, config.input_width.max}); 249 | 250 | profile->setDimensions( 251 | "uv", nvinfer1::OptProfileSelector::kMIN, 252 | nvinfer1::Dims4 {config.batch_extract.min, 2, config.input_height.min / 2, config.input_width.min / 2}); 253 | profile->setDimensions( 254 | "uv", nvinfer1::OptProfileSelector::kOPT, 255 | nvinfer1::Dims4 {config.batch_extract.opt, 2, config.input_height.opt / 2, config.input_width.opt / 2}); 256 | profile->setDimensions( 257 | "uv", nvinfer1::OptProfileSelector::kMAX, 258 | nvinfer1::Dims4 {config.batch_extract.max, 2, config.input_height.max / 2, config.input_width.max / 2}); 259 | } 260 | } 261 | 262 | for (int i = 0; i < network->getNbInputs(); ++i) { 263 | network->getInput(i)->setType(ioDataType); 264 | } 265 | 266 | for (int i = 0; i < network->getNbOutputs(); ++i) { 267 | network->getOutput(i)->setType(ioDataType); 268 | } 269 | logger.log(nvinfer1::ILogger::Severity::kINFO, "Done define feature extract net."); 270 | 271 | auto optimize_config = prepareConfig(); 272 | // value from experience 273 | // optimize_config->setMemoryPoolLimit(nvinfer1::MemoryPoolType::kWORKSPACE, total_memory / 24); 274 | optimize_config->addOptimizationProfile(profile); 275 | auto modelStream = builder->buildSerializedNetwork(*network, *optimize_config); 276 | COND_CHECK_EMPTY(modelStream != nullptr, "Failed build feature extract net."); 277 | logger.log(nvinfer1::ILogger::Severity::kINFO, "Done build feature extract net."); 278 | 279 | auto parent = output; 280 | std::filesystem::create_directories(parent.remove_filename()); 281 | std::ofstream p(output, std::ios::binary); 282 | COND_CHECK_EMPTY(p.is_open(), "Unable to open engine file for output."); 283 | p.write(static_cast(modelStream->data()), modelStream->size()); 284 | p.close(); 285 | logger.log(nvinfer1::ILogger::Severity::kINFO, "Done save feature extract net."); 286 | 287 | return 0; 288 | } 289 | 290 | int OptimizationContext::buildFeatureFusion(std::vector input, const std::filesystem::path &output) { 291 | auto network = createNetwork(); 292 | auto profile = builder->createOptimizationProfile(); 293 | auto parser = nvonnxparser::createParser(*network, logger); 294 | COND_CHECK_EMPTY(parser->parse(input.data(), input.size()), "Failed parse source model."); 295 | input.clear(); 296 | 297 | auto ioDataType = config.use_fp16 ? nvinfer1::DataType::kHALF : nvinfer1::DataType::kFLOAT; 298 | auto layer_height = config.input_height; 299 | auto layer_width = config.input_width; 300 | 301 | for (int i = 0; i < config.extraction_layers; ++i) { 302 | for (int j = 0; j < config.input_count; ++j) { 303 | auto name = "f" + std::to_string((config.interpolation ? 2 : 1) * j) + "l" + std::to_string(i); 304 | profile->setDimensions( 305 | name.c_str(), nvinfer1::OptProfileSelector::kMIN, 306 | nvinfer1::Dims4 {config.batch_fusion.min, config.feature_count, layer_height.min, layer_width.min}); 307 | profile->setDimensions( 308 | name.c_str(), nvinfer1::OptProfileSelector::kOPT, 309 | nvinfer1::Dims4 {config.batch_fusion.opt, config.feature_count, layer_height.opt, layer_width.opt}); 310 | profile->setDimensions( 311 | name.c_str(), nvinfer1::OptProfileSelector::kMAX, 312 | nvinfer1::Dims4 {config.batch_fusion.max, config.feature_count, layer_height.max, layer_width.max}); 313 | } 314 | layer_height.min = ceil_half(layer_height.min); 315 | layer_height.opt = ceil_half(layer_height.opt); 316 | layer_height.max = ceil_half(layer_height.max); 317 | layer_width.min = ceil_half(layer_width.min); 318 | layer_width.opt = ceil_half(layer_width.opt); 319 | layer_width.max = ceil_half(layer_width.max); 320 | } 321 | 322 | for (int i = 0; i < network->getNbInputs(); ++i) { 323 | network->getInput(i)->setType(ioDataType); 324 | } 325 | 326 | for (int i = 0; i < network->getNbOutputs(); ++i) { 327 | network->getOutput(i)->setType(ioDataType); 328 | } 329 | logger.log(nvinfer1::ILogger::Severity::kINFO, "Done define feature fusion net."); 330 | 331 | auto optimize_config = prepareConfig(); 332 | // value from experience 333 | // optimize_config->setMemoryPoolLimit(nvinfer1::MemoryPoolType::kWORKSPACE, total_memory / 2); 334 | optimize_config->addOptimizationProfile(profile); 335 | auto modelStream = builder->buildSerializedNetwork(*network, *optimize_config); 336 | COND_CHECK_EMPTY(modelStream != nullptr, "Failed build feature fusion net."); 337 | logger.log(nvinfer1::ILogger::Severity::kINFO, "Done build feature fusion net."); 338 | 339 | auto parent = output; 340 | std::filesystem::create_directories(parent.remove_filename()); 341 | std::ofstream p(output, std::ios::binary); 342 | COND_CHECK_EMPTY(p.is_open(), "Unable to open engine file for output."); 343 | p.write(static_cast(modelStream->data()), modelStream->size()); 344 | p.close(); 345 | logger.log(nvinfer1::ILogger::Severity::kINFO, "Done save feature fusion net."); 346 | 347 | return 0; 348 | } 349 | -------------------------------------------------------------------------------- /tensorrt/src/reformat.cu: -------------------------------------------------------------------------------- 1 | #include "reformat.h" 2 | #include 3 | #include 4 | 5 | half __device__ round(half f) { 6 | const half v0_5 = float(0.5); 7 | return hfloor(f + v0_5); 8 | } 9 | 10 | template 11 | static void __global__ fma_from(md_view dst, md_view src, F a, F b) { 12 | uint32_t dst_x = threadIdx.x + blockDim.x * blockIdx.x; 13 | uint32_t dst_y = threadIdx.y + blockDim.y * blockIdx.y; 14 | 15 | auto [dst_h, dst_w] = dst.shape; 16 | if (dst_x >= dst_w || dst_y >= dst_h) { 17 | return; 18 | } 19 | 20 | auto [src_h, src_w] = src.shape; 21 | uint32_t src_x = dst_x >= src_w ? src_w - 1 : dst_x; 22 | uint32_t src_y = dst_y >= src_h ? src_h - 1 : dst_y; 23 | 24 | F value = static_cast(src.at(src_y, src_x)); 25 | value = a * value + b; 26 | dst.at(dst_y, dst_x) = value; 27 | } 28 | 29 | template 30 | static void __global__ fma_to(md_view dst, md_view src, F a, F b, F min, F max) { 31 | uint32_t dst_x = threadIdx.x + blockDim.x * blockIdx.x; 32 | uint32_t dst_y = threadIdx.y + blockDim.y * blockIdx.y; 33 | 34 | auto [dst_h, dst_w] = dst.shape; 35 | if (dst_x >= dst_w || dst_y >= dst_h) { 36 | return; 37 | } 38 | 39 | F value = static_cast(src.at(dst_y, dst_x)); 40 | value = a * value + b; 41 | if constexpr (std::is_integral_v) { 42 | value = round(value); 43 | } 44 | 45 | if (value < min) { 46 | value = min; 47 | } 48 | else if (value > max) { 49 | value = max; 50 | } 51 | 52 | if constexpr (std::is_integral_v && sizeof(U) == 1) { 53 | dst.at(dst_y, dst_x) = static_cast(static_cast(value)); 54 | } else { 55 | dst.at(dst_y, dst_x) = static_cast(value); 56 | } 57 | } 58 | 59 | template 60 | void import_pixel(md_view dst, md_view src, float a, float b, cudaStream_t stream) { 61 | dim3 dimBlock(32, 32); 62 | dim3 dimGrid; 63 | auto [dst_h, dst_w] = dst.shape; 64 | dimGrid.x = (dst_w + 31) / 32; 65 | dimGrid.y = (dst_h + 31) / 32; 66 | 67 | fma_from<<>>(dst, src, F(a), F(b)); 68 | } 69 | 70 | template void import_pixel(md_view dst, md_view src, float a, float b, 71 | cudaStream_t stream); 72 | template void import_pixel(md_view dst, md_view src, float a, float b, 73 | cudaStream_t stream); 74 | template void import_pixel(md_view dst, md_view src, float a, float b, 75 | cudaStream_t stream); 76 | template void import_pixel(md_view dst, md_view src, float a, float b, 77 | cudaStream_t stream); 78 | template void import_pixel(md_view dst, md_view src, float a, float b, 79 | cudaStream_t stream); 80 | template void import_pixel(md_view dst, md_view src, float a, float b, 81 | cudaStream_t stream); 82 | template void import_pixel(md_view dst, md_view src, float a, float b, 83 | cudaStream_t stream); 84 | template void import_pixel(md_view dst, md_view src, float a, float b, 85 | cudaStream_t stream); 86 | 87 | template 88 | void export_pixel(md_view dst, md_view src, float a, float b, float min, float max, cudaStream_t stream) { 89 | dim3 dimBlock(32, 32); 90 | dim3 dimGrid; 91 | auto [dst_h, dst_w] = dst.shape; 92 | dimGrid.x = (dst_w + 31) / 32; 93 | dimGrid.y = (dst_h + 31) / 32; 94 | 95 | fma_to<<>>(dst, src, F(a), F(b), F(min), F(max)); 96 | } 97 | 98 | template void export_pixel(md_view dst, md_view src, float a, float b, float min, 99 | float max, cudaStream_t stream); 100 | template void export_pixel(md_view dst, md_view src, float a, float b, float min, 101 | float max, cudaStream_t stream); 102 | template void export_pixel(md_view dst, md_view src, float a, float b, float min, 103 | float max, cudaStream_t stream); 104 | template void export_pixel(md_view dst, md_view src, float a, float b, float min, 105 | float max, cudaStream_t stream); 106 | template void export_pixel(md_view dst, md_view src, float a, float b, float min, 107 | float max, cudaStream_t stream); 108 | template void export_pixel(md_view dst, md_view src, float a, float b, float min, 109 | float max, cudaStream_t stream); 110 | template void export_pixel(md_view dst, md_view src, float a, float b, float min, 111 | float max, cudaStream_t stream); 112 | template void export_pixel(md_view dst, md_view src, float a, float b, float min, 113 | float max, cudaStream_t stream); -------------------------------------------------------------------------------- /torch/README.md: -------------------------------------------------------------------------------- 1 | # PyTorch implementation of CycMuNet+ 2 | 3 | This is the PyTorch implementation of CycMuNet+. 4 | 5 | ## Installation 6 | 7 | ```bash 8 | conda create -f environment.yml 9 | ``` 10 | 11 | ## Contents 12 | 13 | ### `train.py` 14 | 15 | Train the network. 16 | 17 | ### `test.py` 18 | 19 | Test the network accuracy. 20 | 21 | ### `export_onnx.py` 22 | 23 | Export trained network to ONNX format. 24 | -------------------------------------------------------------------------------- /torch/cycmunet/model.py: -------------------------------------------------------------------------------- 1 | from collections import namedtuple 2 | 3 | model_arg = namedtuple('model_arg', ('nf', # number of feature channel 4 | 'groups', # number of deformable convolution group 5 | 'upscale_factor', # model upscale factor 6 | 'format', # model I/O format (rgb, yuv420) 7 | 'layers', # feature fusion pyramid layers 8 | 'cycle_count' # mutual cycle count 9 | )) 10 | -------------------------------------------------------------------------------- /torch/cycmunet/run.py: -------------------------------------------------------------------------------- 1 | from collections import namedtuple 2 | 3 | _share_args = ('size', # input size 4 | 'dataset_type', # type of dataset 5 | 'dataset_indexes', # index files for dataset 6 | 'preview_interval', # interval to save network output for previewing 7 | 'batch_size', # process batch size 8 | 'seed', # seed for random number generators 9 | ) 10 | 11 | _train_args = ('lr', # init learning rate 12 | 'pretrained', # pretrained checkpoint 13 | 'start_epoch', # start epoch index 14 | 'end_epoch', # end epoch index (exclusive) 15 | 'sparsity', # train network with sparsity 16 | 'autocast', # train with auto mixed precision 17 | 'loss_type', # loss type for optimization 18 | 'save_path', # checkpoint save path 19 | 'save_prefix', # prefix of checkpoint file name 20 | ) 21 | 22 | _test_args = ('checkpoints', # checkpoint to test 23 | 'fp16', # use fp16 to run network forward 24 | ) 25 | 26 | train_arg = namedtuple('train_arg', (*_share_args, *_train_args)) 27 | test_arg = namedtuple('test_arg', (*_share_args, *_test_args)) 28 | -------------------------------------------------------------------------------- /torch/cycmunet_export_onnx.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pathlib 3 | 4 | import torch 5 | from torch.onnx import symbolic_helper 6 | import onnx 7 | import onnx.shape_inference 8 | import onnxsim 9 | import onnx_graphsurgeon as gs 10 | 11 | from model import CycMuNet, use_fold_catconv 12 | from cycmunet.model import model_arg 13 | 14 | model_args = model_arg(nf=64, 15 | groups=8, 16 | upscale_factor=2, 17 | format='yuv420', 18 | layers=4, 19 | cycle_count=3 20 | ) 21 | 22 | checkpoint_file = 'checkpoints/triplet_s2_2x_l4_c3_snapshot.pth' 23 | output_path = 'onnx/triplet_new' 24 | 25 | 26 | size = 1 << model_args.layers 27 | size_in = (size, size) 28 | size_out = tuple(i * model_args.upscale_factor for i in size_in) 29 | output_path = pathlib.Path(output_path) 30 | config_string = f"_{model_args.upscale_factor}x_l{model_args.layers}" 31 | if model_args.format == 'yuv420': 32 | size_uv_in = tuple(i // 2 for i in size_in) 33 | config_string += '_yuv1-1' 34 | 35 | fe_onnx = str(output_path / f'fe{config_string}.onnx') 36 | ff_onnx = str(output_path / f'ff{config_string}.onnx') 37 | os.makedirs(output_path, exist_ok=True) 38 | 39 | 40 | # Placeholder to export DeformConv 41 | @symbolic_helper.parse_args("v", "v", "v", "v", "v", "i", "i", "i", "i", "i", "i", "i", "i", "b") 42 | def symbolic_deform_conv2d_forward(g, 43 | input, 44 | weight, 45 | offset, 46 | mask, 47 | bias, 48 | stride_h, 49 | stride_w, 50 | pad_h, 51 | pad_w, 52 | dil_h, 53 | dil_w, 54 | n_weight_grps, 55 | n_offset_grps, 56 | use_mask): 57 | if n_weight_grps != 1 or not use_mask: 58 | raise NotImplementedError() 59 | return g.op("custom::DeformConv2d", input, offset, mask, weight, bias, stride_i=[stride_h, stride_w], 60 | padding_i=[pad_h, pad_w], dilation_i=[dil_h, dil_w], deformable_groups_i=n_offset_grps) 61 | 62 | 63 | # Register custom symbolic function 64 | torch.onnx.register_custom_op_symbolic("torchvision::deform_conv2d", symbolic_deform_conv2d_forward, 13) 65 | 66 | 67 | def clean_fp16_subnormal(t: torch.Tensor): 68 | threshold = 0.00006103515625 69 | mask = torch.logical_and(t > -threshold, t < threshold) 70 | t[mask] = 0 71 | return t 72 | 73 | 74 | def as_module(func): 75 | class mod(torch.nn.Module): 76 | def __init__(self): 77 | super().__init__() 78 | 79 | def forward(self, *x): 80 | return func(*x) 81 | 82 | return mod().eval() 83 | 84 | 85 | def simplify(name): 86 | model, other = onnxsim.simplify(name) 87 | graph = gs.import_onnx(model) 88 | graph.fold_constants().cleanup() 89 | 90 | for n in graph.nodes: 91 | if n.op == 'DeformConv2d': 92 | if n.outputs[0].outputs[0].op == 'LeakyRelu': 93 | lrelu = n.outputs[0].outputs[0] 94 | n.attrs['activation_type'] = 3 95 | n.attrs['alpha'] = lrelu.attrs['alpha'] 96 | n.attrs['beta'] = 0.0 97 | n.outputs = lrelu.outputs 98 | lrelu.inputs = [] 99 | lrelu.outputs = [] 100 | else: 101 | n.attrs['activation_type'] = -1 102 | n.attrs['alpha'] = 0.0 103 | n.attrs['beta'] = 0.0 104 | 105 | graph.cleanup().toposort() 106 | 107 | model = gs.export_onnx(graph) 108 | onnx.save_model(model, name) 109 | print(f'Simplify {name} done') 110 | 111 | 112 | use_fold_catconv() 113 | model = CycMuNet(model_args) 114 | state_dict = torch.load(checkpoint_file, map_location='cpu') 115 | state_dict = {k: clean_fp16_subnormal(v) for k, v in state_dict.items() if '__weight_mma_mask' not in k} 116 | model.load_state_dict(state_dict) 117 | model = model.eval() 118 | for v in model.parameters(recurse=True): 119 | v.requires_grad = False 120 | 121 | 122 | if __name__ == '__main__': 123 | with torch.no_grad(): 124 | print("Exporting fe...") 125 | 126 | if model_args.format == 'rgb': 127 | fe_i = torch.zeros((2, 3, *size)) 128 | dynamic_axes = { 129 | "x": { 130 | 0: "batch_size", 131 | 2: "input_height", 132 | 3: "input_width" 133 | }, 134 | } 135 | elif model_args.format == 'yuv420': 136 | fe_i = tuple([torch.zeros((2, 1, *size_in)), torch.zeros((2, 2, *size_uv_in))]) 137 | dynamic_axes = { 138 | "y": { 139 | 0: "batch_size", 140 | 2: "input_height", 141 | 3: "input_width" 142 | }, 143 | "uv": { 144 | 0: "batch_size", 145 | 2: "input_height_uv", 146 | 3: "input_width_uv" 147 | }, 148 | } 149 | else: 150 | raise NotImplementedError() 151 | input_names = list(dynamic_axes.keys()) 152 | output_names = [f'l{i}' for i in range(model_args.layers)[::-1]] 153 | dynamic_axes.update({f'l{i}': { 154 | 0: "batch_size", 155 | 2: f"feature_height_{i}", 156 | 3: f"feature_width_{i}" 157 | } for i in range(model_args.layers)[::-1]}) 158 | 159 | @as_module 160 | def fe(*x_or_y_uv: torch.Tensor): 161 | if model_args.format == 'rgb': 162 | return model.head_fe(x_or_y_uv[0]) 163 | else: 164 | return model.head_fe(x_or_y_uv) 165 | 166 | 167 | torch.onnx.export(fe, fe_i, fe_onnx, opset_version=13, 168 | export_params=True, 169 | input_names=input_names, output_names=output_names, 170 | dynamic_axes=dynamic_axes) 171 | 172 | print("Exporting ff...") 173 | 174 | ff_i = [] 175 | input_axes = dict() 176 | cur_size = size_in 177 | for i in range(model_args.layers)[::-1]: 178 | ff_i.insert(0, torch.zeros(1, model_args.nf, *cur_size)) 179 | cur_size = tuple((i + 1) // 2 for i in cur_size) 180 | 181 | for i in range(model_args.layers): 182 | axes = { 183 | 0: "batch_size", 184 | 2: f"feature_height_{i}", 185 | 3: f"feature_width_{i}" 186 | } 187 | input_axes[f'f0l{i}'] = axes 188 | input_axes[f'f2l{i}'] = axes 189 | input_names = [f'f0l{i}' for i in range(model_args.layers)[::-1]] + \ 190 | [f'f2l{i}' for i in range(model_args.layers)[::-1]] 191 | output_names = ['f1'] 192 | dynamic_axes = dict(input_axes) 193 | dynamic_axes[f'f1'] = { 194 | 0: "batch_size", 195 | 2: f"feature_height_{model_args.layers - 1}", 196 | 3: f"feature_width_{model_args.layers - 1}" 197 | } 198 | 199 | if model_args.format == 'rgb': 200 | output_axes = { 201 | "h0": { 202 | 0: "batch_size", 203 | 2: "output_height", 204 | 3: "output_width" 205 | }, 206 | "h1": { 207 | 0: "batch_size", 208 | 2: "output_height", 209 | 3: "output_width" 210 | }, 211 | } 212 | elif model_args.format == 'yuv420': 213 | output_axes = { 214 | "h0_y": { 215 | 0: "batch_size", 216 | 2: "output_height", 217 | 3: "output_width" 218 | }, 219 | "h0_uv": { 220 | 0: "batch_size", 221 | 2: "output_height_uv", 222 | 3: "output_width_uv" 223 | }, 224 | "h1_y": { 225 | 0: "batch_size", 226 | 2: "output_height", 227 | 3: "output_width" 228 | }, 229 | "h1_uv": { 230 | 0: "batch_size", 231 | 2: "output_height_uv", 232 | 3: "output_width_uv" 233 | }, 234 | } 235 | else: 236 | raise NotImplementedError() 237 | output_names = list(output_axes.keys()) 238 | dynamic_axes = dict(input_axes) 239 | dynamic_axes.update(output_axes) 240 | 241 | @as_module 242 | def ff(input1, input2): 243 | fea = [input1[-1], model.ff(input1, input2)[0], input2[-1]] 244 | outs = model.mu_fr_tail(fea, all_frames=False) 245 | return outs 246 | 247 | 248 | torch.onnx.export(ff, (ff_i, ff_i), 249 | str(output_path / f'ff{config_string}.onnx'), opset_version=13, 250 | export_params=True, 251 | input_names=input_names, output_names=output_names, 252 | dynamic_axes=dynamic_axes) 253 | 254 | simplify(fe_onnx) 255 | simplify(ff_onnx) 256 | -------------------------------------------------------------------------------- /torch/cycmunet_test.py: -------------------------------------------------------------------------------- 1 | import math 2 | import logging 3 | import os 4 | import time 5 | 6 | import tqdm 7 | import torch 8 | import torch.nn.functional as F 9 | from torch.utils.data import DataLoader 10 | import torch.backends.cuda 11 | import torch.backends.cudnn 12 | import torchvision.utils 13 | from pytorch_msssim import SSIM 14 | 15 | from model import CycMuNet 16 | from model.util import converter, normalizer 17 | import dataset 18 | from cycmunet.model import model_arg 19 | from cycmunet.run import test_arg 20 | 21 | model_args = model_arg(nf=64, 22 | groups=8, 23 | upscale_factor=2, 24 | format='yuv420', 25 | layers=4, 26 | cycle_count=3 27 | ) 28 | 29 | test_args = test_arg( 30 | size=(256, 256), 31 | checkpoint='checkpoints/monitor-ugly-sparsity_2x_l4_c3_epoch_2.pth', 32 | dataset_indexes=[ 33 | "/root/videos/cctv-scaled/index-test-good.txt", 34 | "/root/videos/cctv-scaled/index-test-ugly.txt", 35 | "/root/videos/cctv-scaled/index-test-smooth.txt", 36 | "/root/videos/cctv-scaled/index-test-sharp.txt", 37 | ], 38 | preview_interval=100, 39 | seed=0, 40 | batch_size=4, 41 | fp16=True, 42 | ) 43 | 44 | torch.backends.cudnn.benchmark = True 45 | torch.backends.cudnn.allow_tf32 = True 46 | torch.backends.cuda.matmul.allow_tf32 = True 47 | torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False 48 | torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = False 49 | 50 | force_data_dtype = torch.float16 if test_args.fp16 else None 51 | 52 | # -------------------------------------- 53 | # Start of code 54 | 55 | preview_interval = 100 \ 56 | if (len(test_args.dataset_indexes) == 1 or math.gcd(100, len(test_args.dataset_indexes)) == 1) \ 57 | else 101 58 | 59 | nrow = 1 if test_args.size[0] * 9 > test_args.size[1] * 16 else 3 60 | 61 | torch.manual_seed(test_args.seed) 62 | torch.cuda.manual_seed(test_args.seed) 63 | 64 | formatter = logging.Formatter('%(asctime)s %(levelname)s [%(name)s]: %(message)s') 65 | 66 | ch = logging.StreamHandler() 67 | ch.setFormatter(formatter) 68 | ch.setLevel(logging.DEBUG) 69 | 70 | logger = logging.getLogger('test_progress') 71 | logger.addHandler(ch) 72 | logger.setLevel(logging.DEBUG) 73 | 74 | logger_init = logging.getLogger('initialization') 75 | logger_init.addHandler(ch) 76 | logger_init.setLevel(logging.DEBUG) 77 | 78 | cvt = converter() 79 | norm = normalizer() 80 | 81 | dataset_types = { 82 | 'triplet': dataset.ImageSequenceDataset, 83 | 'video': dataset.VideoFrameDataset 84 | } 85 | Dataset = dataset_types[test_args.dataset_type] 86 | if len(test_args.dataset_indexes) == 1: 87 | ds_test = Dataset(test_args.dataset_indexes[0], 88 | test_args.size, 89 | model_args.upscale_factor, 90 | augment=True, 91 | seed=test_args.seed) 92 | else: 93 | ds_test = dataset.InterleavedDataset(*[ 94 | Dataset(dataset_index, 95 | test_args.size, 96 | model_args.upscale_factor, 97 | augment=True, 98 | seed=test_args.seed + i) 99 | for i, dataset_index in enumerate(test_args.dataset_indexes)]) 100 | ds_test = DataLoader(ds_test, 101 | num_workers=1, 102 | batch_size=test_args.batch_size, 103 | shuffle=Dataset.want_shuffle, # Video dataset friendly 104 | drop_last=True) 105 | 106 | model = CycMuNet(model_args) 107 | model.eval() 108 | num_params = 0 109 | for param in model.parameters(): 110 | num_params += param.numel() 111 | logger_init.info(f"Model has {num_params} parameters.") 112 | 113 | if not os.path.exists(test_args.checkpoint): 114 | logger_init.error(f"Checkpoint weight {test_args.checkpoint} not exist.") 115 | exit(1) 116 | state_dict = torch.load(test_args.checkpoint, map_location=lambda storage, loc: storage) 117 | load_result = model.load_state_dict(state_dict, strict=False) 118 | if load_result.unexpected_keys: 119 | logger_init.warning(f"Unknown parameters ignored: {load_result.unexpected_keys}") 120 | if load_result.missing_keys: 121 | logger_init.warning(f"Missing parameters not initialized: {load_result.missing_keys}") 122 | logger_init.info("Checkpoint loaded.") 123 | 124 | model = model.cuda() 125 | if force_data_dtype: 126 | model = model.to(force_data_dtype) 127 | 128 | epsilon = (1 / 255) ** 2 129 | 130 | 131 | def rmse(a, b): 132 | return torch.mean(torch.sqrt((a - b) ** 2 + epsilon)) 133 | 134 | 135 | ssim_module = SSIM(data_range=1.0, nonnegative_ssim=True).cuda() 136 | 137 | 138 | def ssim(a, b): 139 | return 1 - ssim_module(a, b) 140 | 141 | 142 | def recursive_cuda(li, force_data_dtype): 143 | if isinstance(li, (list, tuple)): 144 | return tuple(recursive_cuda(i, force_data_dtype) for i in li) 145 | else: 146 | if force_data_dtype is not None: 147 | return li.cuda().to(force_data_dtype) 148 | else: 149 | return li.cuda() 150 | 151 | 152 | if __name__ == '__main__': 153 | with torch.no_grad(): 154 | total_loss = [0.0] * 4 155 | total_iter = len(ds_test) 156 | with tqdm.tqdm(total=total_iter, desc=f"Test") as progress: 157 | for it, data in enumerate(ds_test): 158 | (hf0, hf1, hf2), (lf0, lf1, lf2) = recursive_cuda(data, force_data_dtype) 159 | if Dataset.pix_type == 'yuv': 160 | target = [cvt.yuv2rgb(*inp) for inp in (hf0, hf1, hf2, lf1)] 161 | else: 162 | target = [hf0, hf1, hf2, lf1] 163 | 164 | if it % preview_interval == 0: 165 | if Dataset.pix_type == 'yuv': 166 | org = [F.interpolate(cvt.yuv2rgb(y[0:1], uv[0:1]), 167 | scale_factor=(model_args.upscale_factor, model_args.upscale_factor), 168 | mode='nearest').detach().float().cpu() 169 | for y, uv in (lf0, lf1, lf2)] 170 | else: 171 | org = [F.interpolate(lf[0:1], 172 | scale_factor=(model_args.upscale_factor, model_args.upscale_factor), 173 | mode='nearest').detach().float().cpu() 174 | for lf in (lf0, lf1, lf2)] 175 | 176 | if Dataset.pix_type == 'rgb': 177 | lf0, lf2 = cvt.rgb2yuv(lf0), cvt.rgb2yuv(lf2) 178 | 179 | t0 = time.perf_counter() 180 | lf0, lf2 = norm.normalize_yuv_420(*lf0), norm.normalize_yuv_420(*lf2) 181 | outs = model(lf0, lf2, batch_mode='batch') 182 | 183 | t1 = time.perf_counter() 184 | t_forward = t1 - t0 185 | actual = [cvt.yuv2rgb(*norm.denormalize_yuv_420(*out)).float() for out in outs] 186 | 187 | if it % preview_interval == 0: 188 | out = [i[0:1].detach().float().cpu() for i in actual[:3]] 189 | ref = [i[0:1].detach().float().cpu() for i in target[:3]] 190 | 191 | for idx, ts in enumerate(zip(org, out, ref)): 192 | torchvision.utils.save_image(torch.concat(ts), f"./result/out{idx}.png", 193 | value_range=(0, 1), nrow=nrow, padding=0) 194 | 195 | rmse_loss = [rmse(a, t).item() for a, t in zip(actual, target)] 196 | ssim_loss = [ssim(a, t).item() for a, t in zip(actual, target)] 197 | 198 | t2 = time.perf_counter() 199 | t_loss = t2 - t1 200 | 201 | rmse_h = sum(rmse_loss[:3]) / 3 202 | rmse_l = rmse_loss[3] 203 | ssim_h = sum(ssim_loss[:3]) / 3 204 | ssim_l = ssim_loss[3] 205 | 206 | total_loss[0] += rmse_h 207 | total_loss[1] += rmse_l 208 | total_loss[2] += ssim_h 209 | total_loss[3] += ssim_l 210 | 211 | progress.set_postfix(ordered_dict={ 212 | "rmse_h": f"{rmse_h:.4f}", 213 | "rmse_l": f"{rmse_l:.4f}", 214 | "ssim_h": f"{ssim_h:.4f}", 215 | "ssim_l": f"{ssim_l:.4f}", 216 | "f": f"{t_forward:.4f}s", 217 | "l": f"{t_loss:.4f}s", 218 | }) 219 | progress.update() 220 | 221 | logger.info(f"Test Complete: " 222 | f"RMSE HQ: {total_loss[0] / total_iter:.4f} " 223 | f"RMSE LQ: {total_loss[1] / total_iter:.4f} " 224 | f"SSIM HQ: {total_loss[2] / total_iter:.4f} " 225 | f"SSIM LQ: {total_loss[3] / total_iter:.4f}") 226 | -------------------------------------------------------------------------------- /torch/cycmunet_train.py: -------------------------------------------------------------------------------- 1 | import math 2 | import logging 3 | import os 4 | import pathlib 5 | import time 6 | 7 | import tqdm 8 | import torch 9 | import torch.nn.functional as F 10 | from torch.utils.data import DataLoader 11 | import torch.optim as optim 12 | import torch.backends.cuda 13 | import torch.backends.cudnn 14 | import torchvision.utils 15 | from pytorch_msssim import SSIM 16 | 17 | from model import CycMuNet 18 | from model.util import converter, normalizer 19 | import dataset 20 | from cycmunet.model import model_arg 21 | from cycmunet.run import train_arg 22 | 23 | # ------------------------------------------ 24 | # Configs 25 | 26 | model_args = model_arg(nf=64, 27 | groups=8, 28 | upscale_factor=2, 29 | format='yuv420', 30 | layers=4, 31 | cycle_count=3 32 | ) 33 | 34 | train_args = train_arg( 35 | size=(128, 128), 36 | pretrained="/root/cycmunet-new/checkpoints/monitor-ugly_2x_l4_c3_epoch_19.pth", 37 | # dataset_type="video", 38 | # dataset_indexes=[ 39 | # "/root/videos/cctv-scaled/index-train-good.txt", 40 | # "/root/videos/cctv-scaled/index-train-ugly.txt", 41 | # "/root/videos/cctv-scaled/index-train-smooth.txt", 42 | # "/root/videos/cctv-scaled/index-train-sharp.txt", 43 | # ], 44 | dataset_type="triplet", 45 | dataset_indexes=[ 46 | "/root/dataset/vimeo_triplet/tri_trainlist.txt" 47 | ], 48 | preview_interval=100, 49 | seed=0, 50 | lr=0.001, 51 | start_epoch=1, 52 | end_epoch=11, 53 | sparsity=True, 54 | batch_size=2, 55 | autocast=False, 56 | loss_type='rmse', 57 | save_path='checkpoints', 58 | save_prefix='triplet', 59 | ) 60 | 61 | torch.backends.cudnn.benchmark = True 62 | torch.backends.cudnn.allow_tf32 = True 63 | torch.backends.cuda.matmul.allow_tf32 = True 64 | torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False 65 | torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = False 66 | 67 | # -------------------------------------- 68 | # Start of code 69 | 70 | preview_interval = 100 \ 71 | if (len(train_args.dataset_indexes) == 1 or math.gcd(100, len(train_args.dataset_indexes)) == 1) \ 72 | else 101 73 | 74 | save_prefix = f'{train_args.save_prefix}_{model_args.upscale_factor}x_l{model_args.layers}_c{model_args.cycle_count}' 75 | 76 | save_path = pathlib.Path(train_args.save_path) 77 | 78 | nrow = 1 if train_args.size[0] * 9 > train_args.size[1] * 16 else 3 79 | 80 | torch.manual_seed(train_args.seed) 81 | torch.cuda.manual_seed(train_args.seed) 82 | 83 | formatter = logging.Formatter('%(asctime)s %(levelname)s [%(name)s]: %(message)s') 84 | 85 | ch = logging.StreamHandler() 86 | ch.setFormatter(formatter) 87 | ch.setLevel(logging.DEBUG) 88 | 89 | logger = logging.getLogger('train_progress') 90 | logger.addHandler(ch) 91 | logger.setLevel(logging.DEBUG) 92 | 93 | logger_init = logging.getLogger('initialization') 94 | logger_init.addHandler(ch) 95 | logger_init.setLevel(logging.DEBUG) 96 | 97 | cvt = converter() 98 | norm = normalizer() 99 | 100 | dataset_types = { 101 | 'triplet': dataset.ImageSequenceDataset, 102 | 'video': dataset.VideoFrameDataset 103 | } 104 | Dataset = dataset_types[train_args.dataset_type] 105 | if len(train_args.dataset_indexes) == 1: 106 | ds_train = Dataset(train_args.dataset_indexes[0], 107 | train_args.size, 108 | model_args.upscale_factor, 109 | augment=True, 110 | seed=train_args.seed) 111 | else: 112 | ds_train = dataset.InterleavedDataset(*[ 113 | Dataset(dataset_index, 114 | train_args.size, 115 | model_args.upscale_factor, 116 | augment=True, 117 | seed=train_args.seed + i) 118 | for i, dataset_index in enumerate(train_args.dataset_indexes)]) 119 | ds_train = DataLoader(ds_train, 120 | num_workers=1, 121 | batch_size=train_args.batch_size, 122 | shuffle=Dataset.want_shuffle, # Video dataset friendly 123 | drop_last=True) 124 | 125 | model = CycMuNet(model_args) 126 | model.train() 127 | model_updated = False 128 | num_params = 0 129 | for param in model.parameters(): 130 | num_params += param.numel() 131 | logger_init.info(f"Model has {num_params} parameters.") 132 | 133 | if train_args.pretrained: 134 | if not os.path.exists(train_args.pretrained): 135 | logger_init.warning(f"Pretrained weight {train_args.pretrained} not exist.") 136 | state_dict = torch.load(train_args.pretrained, map_location=lambda storage, loc: storage) 137 | load_result = model.load_state_dict(state_dict, strict=False) 138 | if load_result.unexpected_keys: 139 | logger_init.warning(f"Unknown parameters ignored: {load_result.unexpected_keys}") 140 | if load_result.missing_keys: 141 | logger_init.warning(f"Missing parameters not initialized: {load_result.missing_keys}") 142 | logger_init.info("Pretrained weights loaded.") 143 | 144 | model = model.cuda() 145 | optimizer = optim.Adamax(model.parameters(), lr=train_args.lr, betas=(0.9, 0.999), eps=1e-8) 146 | # Or, train only some parts 147 | # optimizer = optim.Adamax(itertools.chain( 148 | # model.head.parameters(), 149 | # model.fe.parameters(), 150 | # model.fr.parameters(), 151 | # model.tail.parameters() 152 | # ), lr=args.lr, betas=(0.9, 0.999), eps=1e-8) 153 | scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=40000, eta_min=1e-7) 154 | 155 | num_params_train = 0 156 | for group in optimizer.param_groups: 157 | for params in group.get('params', []): 158 | num_params_train += params.numel() 159 | logger_init.info(f"Model has {num_params} parameters to train.") 160 | 161 | if train_args.sparsity: 162 | from apex.contrib.sparsity import ASP 163 | 164 | target_layers = [] 165 | target_layers.extend('mu.' + name for name, _ in model.mu.named_modules()) 166 | target_layers.extend('fr.' + name for name, _ in model.fr.named_modules()) 167 | 168 | ASP.init_model_for_pruning(model, 169 | mask_calculator="m4n2_1d", 170 | allowed_layer_names=target_layers, 171 | verbosity=2, 172 | whitelist=[torch.nn.Linear, torch.nn.Conv2d], 173 | allow_recompute_mask=False, 174 | allow_permutation=False, 175 | ) 176 | ASP.init_optimizer_for_pruning(optimizer) 177 | 178 | # import torch.fx 179 | # original_symbolic_trace = torch.fx.symbolic_trace 180 | # torch.fx.symbolic_trace = functools.partial(original_symbolic_trace, concrete_args={ 181 | # 'batch_mode': '_no_use_sparsity_pseudo', 182 | # 'stop_at_conf': False, 183 | # 'all_frames': True 184 | # }) 185 | 186 | ASP.compute_sparse_masks() 187 | # torch.fx.symbolic_trace = original_symbolic_trace 188 | logger.info('Training with sparsity.') 189 | 190 | 191 | epsilon = (1 / 255) ** 2 192 | 193 | 194 | def rmse(a, b): 195 | return torch.mean(torch.sqrt((a - b) ** 2 + epsilon)) 196 | 197 | 198 | ssim_module = SSIM(data_range=1.0, nonnegative_ssim=True).cuda() 199 | 200 | 201 | def ssim(a, b): 202 | return 1 - ssim_module(a, b) 203 | 204 | 205 | def recursive_cuda(li, force_data_dtype): 206 | if isinstance(li, (list, tuple)): 207 | return tuple(recursive_cuda(i, force_data_dtype) for i in li) 208 | else: 209 | if force_data_dtype is not None: 210 | return li.cuda().to(force_data_dtype) 211 | else: 212 | return li.cuda() 213 | 214 | 215 | def train(epoch): 216 | epoch_loss = 0 217 | total_iter = len(ds_train) 218 | loss_coeff = [1, 0.5, 1, 0.5] 219 | with tqdm.tqdm(total=total_iter, desc=f"Epoch {epoch}") as progress: 220 | for it, data in enumerate(ds_train): 221 | optimizer.zero_grad() 222 | 223 | def compute_loss(force_data_dtype=None): 224 | (hf0, hf1, hf2), (lf0, lf1, lf2) = recursive_cuda(data, force_data_dtype) 225 | if Dataset.pix_type == 'yuv': 226 | target = [cvt.yuv2rgb(*inp) for inp in (hf0, hf1, hf2, lf1)] 227 | else: 228 | target = [hf0, hf1, hf2, lf1] 229 | 230 | if it % preview_interval == 0: 231 | if Dataset.pix_type == 'yuv': 232 | org = [F.interpolate(cvt.yuv2rgb(y[0:1], uv[0:1]), 233 | scale_factor=(model_args.upscale_factor, model_args.upscale_factor), 234 | mode='nearest').detach().float().cpu() 235 | for y, uv in (lf0, lf1, lf2)] 236 | else: 237 | org = [F.interpolate(lf[0:1], 238 | scale_factor=(model_args.upscale_factor, model_args.upscale_factor), 239 | mode='nearest').detach().float().cpu() 240 | for lf in (lf0, lf1, lf2)] 241 | 242 | if Dataset.pix_type == 'rgb': 243 | lf0, lf2 = cvt.rgb2yuv(lf0), cvt.rgb2yuv(lf2) 244 | 245 | t0 = time.perf_counter() 246 | lf0, lf2 = norm.normalize_yuv_420(*lf0), norm.normalize_yuv_420(*lf2) 247 | outs = model(lf0, lf2, batch_mode='batch') 248 | 249 | t1 = time.perf_counter() 250 | actual = [cvt.yuv2rgb(*norm.denormalize_yuv_420(*out)) for out in outs] 251 | 252 | if train_args.loss_type == 'rmse': 253 | loss = [rmse(a, t) * c for a, t, c in zip(actual, target, loss_coeff)] 254 | elif train_args.loss_type == 'ssim': 255 | loss = [ssim(a, t) * c for a, t, c in zip(actual, target, loss_coeff)] 256 | else: 257 | raise ValueError("Unknown loss type: " + train_args.loss_type) 258 | 259 | assert not any(torch.any(torch.isnan(i)).item() for i in loss) 260 | 261 | t2 = time.perf_counter() 262 | 263 | if it % preview_interval == 0: 264 | out = [i[0:1].detach().float().cpu() for i in actual[:3]] 265 | ref = [i[0:1].detach().float().cpu() for i in target[:3]] 266 | 267 | for idx, ts in enumerate(zip(org, out, ref)): 268 | torchvision.utils.save_image(torch.concat(ts), f"./result/out{idx}.png", 269 | value_range=(0, 1), nrow=nrow, padding=0) 270 | 271 | return loss, t1 - t0, t2 - t1 272 | 273 | if train_args.autocast: 274 | with torch.autocast(device_type='cuda', dtype=torch.float16): 275 | loss, t_forward, t_loss = compute_loss(torch.float16) 276 | else: 277 | loss, t_forward, t_loss = compute_loss() 278 | 279 | total_loss = sum(loss) 280 | epoch_loss += total_loss.item() 281 | 282 | t3 = time.perf_counter() 283 | total_loss.backward() 284 | optimizer.step() 285 | scheduler.step() 286 | t_backward = time.perf_counter() - t3 287 | 288 | global model_updated 289 | model_updated = True 290 | 291 | progress.set_postfix(ordered_dict={ 292 | "loss": f"{total_loss.item():.4f}", 293 | "lr": f"{optimizer.param_groups[0]['lr']:.6e}", 294 | "f": f"{t_forward:.4f}s", 295 | "l": f"{t_loss:.4f}s", 296 | "b": f"{t_backward:.4f}s", 297 | }) 298 | progress.update() 299 | 300 | logger.info(f"Epoch {epoch} Complete: Avg. Loss: {epoch_loss / total_iter:.4f}") 301 | 302 | 303 | def save_model(epoch): 304 | if epoch == -1: 305 | name = "snapshot" 306 | else: 307 | name = f"epoch_{epoch}" 308 | if not os.path.exists(save_path): 309 | os.makedirs(save_path) 310 | output_path = save_path / f"{save_prefix}_{name}.pth" 311 | torch.save(model.state_dict(), output_path) 312 | logger.info(f"Checkpoint saved to {output_path}") 313 | 314 | 315 | if __name__ == '__main__': 316 | try: 317 | for epoch in range(train_args.start_epoch, train_args.end_epoch): 318 | # with torch.autograd.detect_anomaly(): 319 | # train(epoch) 320 | train(epoch) 321 | save_model(epoch) 322 | except KeyboardInterrupt: 323 | if model_updated: 324 | save_model(-1) 325 | -------------------------------------------------------------------------------- /torch/dataset/__init__.py: -------------------------------------------------------------------------------- 1 | from .sequence import ImageSequenceDataset 2 | from .video import VideoFrameDataset 3 | from .util import InterleavedDataset 4 | -------------------------------------------------------------------------------- /torch/dataset/sequence.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import itertools 3 | import pathlib 4 | import random 5 | from typing import List 6 | 7 | import torch.utils.data as data 8 | import numpy as np 9 | import torchvision.transforms 10 | from PIL import Image, ImageFilter 11 | 12 | 13 | class ImageSequenceDataset(data.Dataset): 14 | want_shuffle = True 15 | pix_type = 'rgb' 16 | 17 | def __init__(self, index_file, patch_size, scale_factor, augment, seed=0): 18 | self.dataset_base = pathlib.Path(index_file).parent 19 | self.sequences = [i for i in open(index_file, 'r', encoding='utf-8').read().split('\n') 20 | if i if not i.startswith('#')] 21 | self.patch_size = patch_size 22 | self.scale_factor = scale_factor 23 | self.augment = augment 24 | self.rand = random.Random(seed) 25 | self.transform = torchvision.transforms.ToTensor() 26 | 27 | def _load_sequence(self, path): 28 | path = self.dataset_base / "sequences" / path 29 | files = glob.glob("*.png", root_dir=path) 30 | assert len(files) > 1 31 | images = [Image.open(file) for file in files] 32 | if not all(i.size != images[0].size for i in images[1:]): 33 | raise ValueError("sequence has different dimensions") 34 | return images 35 | 36 | def _prepare_images(self, images: List[Image.Image]): 37 | w, h = images[0].size 38 | f = self.scale_factor 39 | sw, sh = self.patch_size 40 | sw, sh = sw * f, sh * f 41 | assert h >= sh and w >= sw 42 | dh, dw = self.rand.randint(0, h - sh), self.rand.randint(0, w - sw) 43 | images = [i.crop((dw, dh, dw + sw, dh + sh)) for i in images] 44 | return images 45 | 46 | trans_groups = { 47 | 'none': [None], 48 | 'rotate': [None, Image.ROTATE_90, Image.ROTATE_180, Image.ROTATE_270], 49 | 'mirror': [None, Image.FLIP_LEFT_RIGHT], 50 | 'flip': [None, Image.FLIP_LEFT_RIGHT, Image.FLIP_TOP_BOTTOM, Image.ROTATE_180], 51 | 'all': [None] + [e.value for e in Image.Transpose], 52 | } 53 | 54 | trans_names = [e.name for e in Image.Transpose] 55 | 56 | def _augment_images(self, images: List[Image.Image], trans_mode='all'): 57 | trans_action = 'none' 58 | trans_op = self.rand.choice(self.trans_groups[trans_mode]) 59 | if trans_op is not None: 60 | images = [i.transpose(trans_op) for i in images] 61 | trans_action = self.trans_names[trans_op] 62 | return images, trans_action 63 | 64 | scale_filters = [Image.BILINEAR, Image.BICUBIC, Image.LANCZOS] 65 | 66 | def _scale_images(self, images: List[Image.Image]): 67 | f = self.scale_factor 68 | return [i.resize((i.width // f, i.height // f), self.rand.choice(self.scale_filters)) for i in images] 69 | 70 | def _degrade_images(self, images: List[Image.Image]): 71 | degrade_action = None 72 | decision = self.rand.randrange(4) 73 | if decision == 1: 74 | degrade_action = 'box' 75 | percent = 0.5 + 0.5 * self.rand.random() 76 | images = [Image.blend(j, j.copy().filter(ImageFilter.BoxBlur(1)), percent) for j in images] 77 | elif decision == 2: 78 | degrade_action = 'gaussian' 79 | radius = self.rand.random() 80 | images = [j.filter(ImageFilter.GaussianBlur(radius)) for j in images] 81 | elif decision == 3: 82 | degrade_action = 'halo' 83 | percent = 0.5 + 0.5 * self.rand.random() 84 | images = [Image.blend(i, 85 | i.resize((i.width // 2, i.height // 2), resample=Image.LANCZOS) 86 | .resize(i.size, resample=Image.BILINEAR), percent) 87 | for i in images] 88 | 89 | return images, degrade_action 90 | 91 | def __len__(self): 92 | return len(self.sequences) 93 | 94 | def __getitem__(self, idx): 95 | sequence = self._load_sequence(self.sequences[idx]) 96 | sequence = self._prepare_images(sequence) # crop to requested size 97 | original, _ = self._augment_images(sequence) # flip and rotates 98 | lfs_pred = [np.array(lf.resize((lf.width // self.scale_factor, lf.height // self.scale_factor), Image.LANCZOS)) 99 | for lf in original[1::2]] 100 | lfs_deg = self._scale_images(original[::2]) 101 | # lfs_deg, _ = self._degrade_images(lfs_deg) 102 | degraded = [i for i in itertools.zip_longest(lfs_deg, lfs_pred) if i is not None] 103 | original = [self.transform(i) for i in original] 104 | degraded = [self.transform(i) for i in degraded] 105 | return original, degraded 106 | -------------------------------------------------------------------------------- /torch/dataset/util.py: -------------------------------------------------------------------------------- 1 | import bisect 2 | 3 | import torch.utils.data as data 4 | 5 | 6 | class InterleavedDataset(data.Dataset): 7 | def __init__(self, *datasets: data.Dataset): 8 | self.datasets = datasets 9 | if not all(hasattr(i, '__len__') for i in datasets): 10 | raise AttributeError('need datasets with known length') 11 | sizes = [len(i) for i in datasets] 12 | self.total = sum(sizes) 13 | self.sizes = [0] + sorted(set(sizes)) 14 | self.index = { 15 | 0: datasets 16 | } 17 | total, last_n = 0, len(datasets) 18 | for last_size, size in zip(self.sizes, self.sizes[1:]): 19 | total += (size - last_size) * last_n 20 | this_datasets = [ds for ds in datasets if len(ds) > size] 21 | self.index[total] = this_datasets 22 | last_n = len(this_datasets) 23 | self.index.popitem() 24 | self.index_keys = list(self.index.keys()) 25 | 26 | def __len__(self): 27 | return self.total 28 | 29 | def __getitem__(self, idx): 30 | stage = bisect.bisect_right(self.index_keys, idx) - 1 31 | offset = self.sizes[stage] 32 | begin = self.index_keys[stage] 33 | idx -= begin 34 | datasets = self.index[begin] 35 | n, i = idx % len(datasets), idx // len(datasets) 36 | return datasets[n][offset + i] 37 | -------------------------------------------------------------------------------- /torch/dataset/video.py: -------------------------------------------------------------------------------- 1 | import bisect 2 | import collections 3 | import functools 4 | import itertools 5 | import pathlib 6 | import random 7 | from typing import List, Tuple 8 | 9 | import av 10 | import av.logging 11 | import numpy as np 12 | import cv2 13 | import torch 14 | import torch.utils.data as data 15 | 16 | 17 | av.logging.set_level(av.logging.FATAL) 18 | 19 | 20 | class Video: 21 | def __init__(self, file, kf): 22 | self.container = av.open(file) 23 | self.stream = self.container.streams.video[0] 24 | self.stream.thread_type = "AUTO" 25 | self.at = 0 26 | self.kf = kf 27 | 28 | def get_frames(self, pts, n=1): 29 | frames = [] 30 | if bisect.bisect_left(self.kf, pts) != bisect.bisect_left(self.kf, self.at) or pts <= self.at: 31 | self.container.seek(pts, stream=self.stream) 32 | found = False 33 | for frame in self.container.decode(video=0): 34 | if not found and frame.pts != pts: 35 | continue 36 | found = True 37 | self.at = frame.pts 38 | yuv = frame.to_ndarray() 39 | h, w = frame.height, frame.width 40 | y, uv = yuv[:h, :].reshape(1, h, w), yuv[h:, :].reshape(2, h // 2, w // 2) 41 | frames.append((y, uv)) 42 | if len(frames) == n: 43 | return frames 44 | raise ValueError("unexpected end") 45 | 46 | def __del__(self): 47 | self.container.close() 48 | 49 | 50 | video_info = collections.namedtuple('video_info', [ 51 | 'org', 52 | 'deg', 53 | 'frames', 54 | 'pts_org', 55 | 'pts_deg', 56 | 'key_org', 57 | 'key_deg' 58 | ]) 59 | 60 | 61 | class VideoFrameDataset(data.Dataset): 62 | want_shuffle = False 63 | pix_type = 'yuv' 64 | 65 | def __init__(self, index_file, patch_size, scale_factor, augment, seed=0): 66 | self.dataset_base = pathlib.PurePath(index_file).parent 67 | index_lines = [i for i in open(index_file, 'r', encoding='utf-8').read().split('\n') 68 | if i if not i.startswith('#')] 69 | files = [tuple(i.split(',')) for i in index_lines] 70 | self.files = [] 71 | self.indexes = [] 72 | for org, deg, frames, pts_org, pts_deg, key_org, key_deg in files: 73 | info = video_info( 74 | org, 75 | deg, 76 | int(frames), 77 | tuple(int(i) for i in pts_org.split(' ')), 78 | tuple(int(i) for i in pts_deg.split(' ')), 79 | tuple(int(i) for i in key_org.split(' ')), 80 | tuple(int(i) for i in key_deg.split(' ')), 81 | ) 82 | self.files.append(info) 83 | self.indexes.append(info.frames) 84 | self.indexes = list(itertools.accumulate(self.indexes)) 85 | self.patch_size = (patch_size, patch_size) if isinstance(patch_size, int) else patch_size 86 | self.scale_factor = scale_factor 87 | self.augment = augment 88 | self.rand = random.Random(seed) 89 | 90 | @staticmethod 91 | def transform(yuv): 92 | return tuple(torch.from_numpy(i).contiguous().to(dtype=torch.float32).div(255) for i in yuv) 93 | 94 | @functools.lru_cache(2) 95 | def get_video(self, v_idx): 96 | info = self.files[v_idx] 97 | return Video(str(self.dataset_base / info.org), info.key_org), \ 98 | Video(str(self.dataset_base / info.deg), info.key_deg), info.pts_org, info.pts_deg 99 | 100 | def _augment_frame(self, org: List[Tuple[np.ndarray]], deg: List[Tuple[np.ndarray]]): 101 | if self.rand.random() > 0.5: 102 | org = [(y[..., ::-1].copy(), uv[..., ::-1].copy()) for y, uv in org] 103 | deg = [(y[..., ::-1].copy(), uv[..., ::-1].copy()) for y, uv in deg] 104 | return org, deg 105 | 106 | def _prepare_frame(self, org: List[Tuple[np.ndarray]], deg: List[Tuple[np.ndarray]]): 107 | _, h, w = deg[0][0].shape 108 | sw, sh = self.patch_size 109 | sh_uv, sw_uv = sh // 2, sw // 2 110 | assert h >= sh and w >= sw 111 | dh, dw = self.rand.randrange(0, h - sh + 2, 2), self.rand.randrange(0, w - sw + 2, 2) 112 | dh_uv, dw_uv = dh // 2, dw // 2 113 | deg = [(y[:, dh:dh+sh, dw:dw+sw], uv[:, dh_uv:dh_uv+sh_uv, dw_uv:dw_uv+sw_uv]) for y, uv in deg] 114 | f = self.scale_factor 115 | size, size_uv = (sw, sh), (sw_uv, sh_uv) 116 | sh, sw, sh_uv, sw_uv = sh * f, sw * f, sh_uv * f, sw_uv * f 117 | dh, dw, dh_uv, dw_uv = dh * f, dw * f, dh_uv * f, dw_uv * f 118 | org = [(y[:, dh:dh+sh, dw:dw+sw], uv[:, dh_uv:dh_uv+sh_uv, dw_uv:dw_uv+sw_uv]) for y, uv in org] 119 | 120 | deg1_y = cv2.resize(org[1][0][0], size, interpolation=cv2.INTER_LANCZOS4) 121 | deg1_u = cv2.resize(org[1][1][0], size_uv, interpolation=cv2.INTER_LANCZOS4) 122 | deg1_v = cv2.resize(org[1][1][1], size_uv, interpolation=cv2.INTER_LANCZOS4) 123 | deg.insert(1, (deg1_y.reshape((1, *size[::-1])), np.stack((deg1_u, deg1_v)).reshape((2, *size_uv[::-1])))) 124 | return org, deg 125 | 126 | def __len__(self): 127 | return self.indexes[-1] 128 | 129 | def __getitem__(self, idx): 130 | v_idx = bisect.bisect_right(self.indexes, idx) 131 | f_idx = idx if v_idx == 0 else idx - self.indexes[v_idx - 1] 132 | org, deg, pts_org, pts_deg = self.get_video(v_idx) 133 | org_frames = org.get_frames(pts_org[f_idx], 3) 134 | deg_frames = deg.get_frames(pts_deg[f_idx], 3) 135 | deg_frames.pop(1) 136 | org_frames, deg_frames = self._prepare_frame(org_frames, deg_frames) 137 | if self.augment: 138 | org_frames, deg_frames = self._augment_frame(org_frames, deg_frames) 139 | org_frames = [self.transform(i) for i in org_frames] 140 | deg_frames = [self.transform(i) for i in deg_frames] 141 | return org_frames, deg_frames 142 | -------------------------------------------------------------------------------- /torch/environment.yml: -------------------------------------------------------------------------------- 1 | name: CycMuNet 2 | channels: 3 | - conda-forge 4 | - pytorch 5 | - fastai 6 | - nvidia 7 | dependencies: 8 | - pytorch 9 | - torchvision 10 | - pytorch-cuda=11.8 11 | - ipython 12 | - opencv-python-headless 13 | - av 14 | - tqdm 15 | - tabulate 16 | - onnx 17 | - pip: 18 | - onnx-simplifier==0.4.17 19 | - pytorch-msssim==0.2.1 20 | -------------------------------------------------------------------------------- /torch/model/__init__.py: -------------------------------------------------------------------------------- 1 | from .util import use_fold_catconv 2 | 3 | from .cycmunet import CycMuNet 4 | -------------------------------------------------------------------------------- /torch/model/cycmunet/__init__.py: -------------------------------------------------------------------------------- 1 | from typing import Union 2 | 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | from ..util import RGBOrYUV 7 | 8 | from .module import head, feature_extract, feature_fusion, mutual_cycle, feature_recon, tail 9 | 10 | 11 | class CycMuNet(nn.Module): 12 | def __init__(self, args): 13 | super(CycMuNet, self).__init__() 14 | self.args = args 15 | self.factor = (self.args.upscale_factor, self.args.upscale_factor) 16 | self.upsample_mode = 'bilinear' 17 | 18 | self.head = head(args) 19 | self.fe = feature_extract(args) 20 | self.ff = feature_fusion(args) 21 | self.mu = mutual_cycle(args) 22 | self.fr = feature_recon(args) 23 | self.tail = tail(args) 24 | 25 | def merge_hf(self, lf, hf): 26 | return F.interpolate(lf, scale_factor=self.factor, mode='bilinear', align_corners=False) + hf 27 | 28 | def head_fe(self, x_or_yuv: RGBOrYUV): 29 | x = self.head(x_or_yuv) 30 | return self.fe(x) 31 | 32 | def mu_fr_tail(self, lf, all_frames): 33 | mu_out = self.mu(*lf, all_frames=all_frames) 34 | if all_frames: 35 | *hf, lf1 = mu_out 36 | lf1 = self.fr(lf1) 37 | hf = tuple(self.merge_hf(l, self.fr(h)) for l, h in zip(lf, hf)) 38 | outs = tuple(self.tail(i) for i in (*hf, lf1)) 39 | else: 40 | outs = tuple(self.tail(self.merge_hf(l, self.fr(h))) for l, h in zip(lf, mu_out)) 41 | return outs 42 | 43 | def forward_batch(self, lf0: RGBOrYUV, lf2: RGBOrYUV, all_frames=True, stop_at_conf=False): 44 | lf0s, lf2s = self.head_fe(lf0), self.head_fe(lf2) 45 | lf1, _ = self.ff(lf0s, lf2s) 46 | if stop_at_conf: # TODO detect frame difference and exit if too big 47 | return 48 | lf = (lf0s[-1], lf1, lf2s[-1]) 49 | return self.mu_fr_tail(lf, all_frames) 50 | 51 | def forward_sequence(self, x_or_yuv: RGBOrYUV, all_frames=False): 52 | ls = self.head_fe(x_or_yuv) 53 | n = ls[0].shape[0] 54 | lf1, _ = self.ff([layer[:n - 1] for layer in ls], [layer[1:] for layer in ls]) 55 | lf = (ls[-1][:n - 1], lf1, ls[-1][1:]) 56 | return self.mu_fr_tail(lf, all_frames) 57 | 58 | # This is for symbolic tracing for sparsity 59 | def pseudo_forward_sparsity(self, lf0, lf1, lf2): 60 | hf0, *_ = self.mu(lf0, lf1, lf2, all_frames=True) 61 | return self.fr(hf0) 62 | 63 | def forward(self, lf0: RGBOrYUV, lf2: Union[RGBOrYUV, None] = None, sparsity_ex=None, /, batch_mode='batch', 64 | **kwargs): 65 | if batch_mode == '_no_use_sparsity_pseudo': 66 | return self.pseudo_forward_sparsity(lf0, lf2, sparsity_ex) 67 | if batch_mode == 'batch': 68 | outs = self.forward_batch(lf0, lf2, **kwargs) 69 | elif batch_mode == 'sequence': 70 | outs = self.forward_sequence(lf0, **kwargs) 71 | else: 72 | raise ValueError(f"Invalid batch_mode: {batch_mode}") 73 | return tuple(outs) 74 | 75 | 76 | __all__ = ['CycMuNet'] 77 | -------------------------------------------------------------------------------- /torch/model/cycmunet/module.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple, List 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | from ..util import cat_conv 7 | from ..part import ResidualBlock_noBN, PCDLayer 8 | 9 | from .part import Down_projection, Up_projection 10 | 11 | """CycMuNet model partitions""" 12 | 13 | 14 | class head(nn.Module): 15 | def __init__(self, args): 16 | super(head, self).__init__() 17 | self.args = args 18 | self.nf = self.args.nf 19 | self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True) 20 | 21 | match self.args.format: 22 | case 'rgb': 23 | self.conv_first = nn.Conv2d(3, self.nf, 3, 1, 1) 24 | self.forward = self.forward_rgb 25 | case 'yuv444': 26 | self.conv_first = nn.Conv2d(3, self.nf, 3, 1, 1) 27 | self.forward = self.forward_yuv444 28 | case 'yuv422': 29 | self.conv_first_y = nn.Conv2d(1, self.nf, 3, 1, 1) 30 | self.conv_up = nn.ConvTranspose2d(2, self.nf, (1, 3), (1, 2), (0, 1), (0, 1)) 31 | self.forward = self.forward_yuv42x 32 | case 'yuv420': 33 | self.conv_first_y = nn.Conv2d(1, self.nf, 3, 1, 1) 34 | self.conv_up = nn.ConvTranspose2d(2, self.nf, 3, 2, 1, 1) 35 | self.forward = self.forward_yuv42x 36 | case unk: 37 | raise ValueError(f'unknown input pixel format: {unk}') 38 | 39 | def forward_rgb(self, x: torch.Tensor): 40 | x = self.lrelu(self.conv_first(x)) 41 | return x 42 | 43 | def forward_yuv444(self, yuv: Tuple[torch.Tensor, torch.Tensor]): 44 | x = torch.cat(yuv, dim=1) 45 | x = self.lrelu(self.conv_first(x)) 46 | return x 47 | 48 | def forward_yuv42x(self, yuv: Tuple[torch.Tensor, torch.Tensor]): 49 | y, uv = yuv 50 | y = self.conv_first_y(y) 51 | uv = self.conv_up(uv) 52 | x = self.lrelu(y + uv) 53 | return x 54 | 55 | 56 | class feature_extract(nn.Module): 57 | def __init__(self, args): 58 | super(feature_extract, self).__init__() 59 | self.args = args 60 | self.nf = self.args.nf 61 | self.groups = self.args.groups 62 | self.layers = self.args.layers 63 | self.front_RBs = 5 64 | 65 | self.feature_extraction = nn.Sequential(*(ResidualBlock_noBN(nf=self.nf) for _ in range(self.front_RBs))) 66 | self.fea_conv1s = nn.ModuleList(nn.Conv2d(self.nf, self.nf, 3, 2, 1, bias=True) for _ in range(self.layers - 1)) 67 | self.fea_conv2s = nn.ModuleList(nn.Conv2d(self.nf, self.nf, 3, 1, 1, bias=True) for _ in range(self.layers - 1)) 68 | self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True) 69 | 70 | def forward(self, x: torch.Tensor): 71 | features: List[torch.Tensor] = [self.feature_extraction(x)] 72 | for i in range(self.layers - 1): 73 | feature = features[-1] 74 | _, _, h, w = feature.shape 75 | h = torch.div(h + 1, 2, rounding_mode="trunc") * 2 - h 76 | w = torch.div(w + 1, 2, rounding_mode="trunc") * 2 - w 77 | feature = F.pad(feature, (0, w, 0, h), mode="replicate") 78 | feature = self.lrelu(self.fea_conv1s[i](feature)) 79 | feature = self.lrelu(self.fea_conv2s[i](feature)) 80 | features.append(feature) 81 | return tuple(features[::-1]) # lowest dimension layer at first 82 | 83 | 84 | class feature_fusion(nn.Module): 85 | def __init__(self, args): 86 | super(feature_fusion, self).__init__() 87 | self.args = args 88 | self.nf = self.args.nf 89 | self.groups = self.args.groups 90 | self.layers = self.args.layers 91 | 92 | # from small to big. 93 | self.modules12 = nn.ModuleList(PCDLayer(args, i == 0) for i in range(self.layers)) 94 | self.modules21 = nn.ModuleList(PCDLayer(args, i == 0) for i in range(self.layers)) 95 | 96 | self.fusion = nn.Conv2d(2 * self.nf, self.nf, 1, 1) 97 | 98 | self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True) 99 | 100 | @staticmethod 101 | def fuse_features(modules, f1, f2): 102 | offset, feature = None, None 103 | for idx, sources in enumerate(zip(f1, f2)): 104 | offset, feature = modules[idx](sources, offset, feature) 105 | return feature 106 | 107 | def forward(self, f1, f2): 108 | feature1 = self.fuse_features(self.modules12, f1, f2) 109 | feature2 = self.fuse_features(self.modules21, f2, f1) 110 | fused_feature = cat_conv(self.fusion, (feature1, feature2)) 111 | return fused_feature, None 112 | 113 | 114 | class mutual_cycle(nn.Module): 115 | def __init__(self, args): 116 | super(mutual_cycle, self).__init__() 117 | self.args = args 118 | self.nf = self.args.nf 119 | self.cycle_count = self.args.cycle_count 120 | 121 | self.merge = nn.ModuleList(nn.Conv2d(64 * (i + 1), 64, 1, 1, 0) for i in range(self.cycle_count)) 122 | self.merge1 = nn.ModuleList(nn.Conv2d(64 * (i + 1), 64, 1, 1, 0) for i in range(self.cycle_count)) 123 | 124 | self.down = nn.ModuleList(Down_projection(args) for _ in range(self.cycle_count)) 125 | self.up = nn.ModuleList(Up_projection(args) for _ in range(self.cycle_count + 1)) 126 | 127 | self.conv = nn.Conv2d(self.nf * (2 * self.cycle_count + 1), self.nf, 1, 1, 0) 128 | self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True) 129 | 130 | def forward(self, lf0, lf1, lf2, all_frames=False): 131 | assert self.cycle_count > 0 132 | 133 | l_out, h_out = [(lf0, lf1, lf2)], [] 134 | for j in range(self.cycle_count): 135 | l_feats = tuple(self.lrelu(cat_conv(self.merge[j], frame_outs)) for frame_outs in zip(*l_out)) 136 | h_feat = self.up[j](*l_feats) 137 | h_out.append(h_feat) 138 | h_feats = tuple(self.lrelu(cat_conv(self.merge1[j], frame_outs)) for frame_outs in zip(*h_out)) 139 | l_feat = self.down[j](*h_feats) 140 | l_out.append(l_feat) 141 | 142 | lf_out, hf_out = [l_out[-1]], [] 143 | for j in range(self.cycle_count): 144 | l_feats = tuple(self.lrelu(cat_conv(self.merge[j], frame_outs)) for frame_outs in zip(*lf_out)) 145 | h_feat = self.up[j](*l_feats) 146 | hf_out.append(h_feat) 147 | l_feat = self.down[j](*h_feat) 148 | lf_out.append(l_feat) 149 | hf_out.append(self.up[self.cycle_count](*l_feats)) 150 | 151 | if all_frames: 152 | h_outs = zip(*h_out, *hf_out) # packed 3 frames 153 | _, l1_out, _ = zip(*l_out, *lf_out[1:]) 154 | 155 | h_outs = tuple(self.lrelu(cat_conv(self.conv, h_frame)) for h_frame in h_outs) 156 | l1_out = self.lrelu(cat_conv(self.conv, l1_out)) 157 | return *h_outs, l1_out 158 | else: 159 | h1_out, h2_out, _ = zip(*h_out, *hf_out) 160 | h1_out = self.lrelu(cat_conv(self.conv, h1_out)) 161 | h2_out = self.lrelu(cat_conv(self.conv, h2_out)) 162 | 163 | return h1_out, h2_out 164 | 165 | 166 | class feature_recon(nn.Module): 167 | def __init__(self, args): 168 | super(feature_recon, self).__init__() 169 | self.args = args 170 | self.nf = self.args.nf 171 | self.back_RBs = 40 172 | self.factor = (self.args.upscale_factor, self.args.upscale_factor) 173 | 174 | self.recon_trunk = nn.Sequential(*(ResidualBlock_noBN(nf=self.nf) for _ in range(self.back_RBs))) 175 | 176 | def forward(self, x): 177 | out = self.recon_trunk(x) 178 | return out 179 | 180 | 181 | class tail(nn.Module): 182 | def __init__(self, args): 183 | super(tail, self).__init__() 184 | self.args = args 185 | self.nf = self.args.nf 186 | match self.args.format: 187 | case 'rgb': 188 | self.conv_last2 = nn.Conv2d(self.nf, 3, 3, 1, 1) 189 | self.forward = self.forward_rgb 190 | case 'yuv444': 191 | self.conv_last2 = nn.Conv2d(self.nf, 3, 3, 1, 1) 192 | self.forward = self.forward_yuv444 193 | case 'yuv422': 194 | self.conv_last_y = nn.Conv2d(self.nf, 1, 3, 1, 1) 195 | self.conv_last_uv = nn.Conv2d(self.nf, 2, (1, 3), (1, 2), (0, 1)) 196 | self.forward = self.forward_yuv42x 197 | case 'yuv420': 198 | self.conv_last_y = nn.Conv2d(self.nf, 1, 3, 1, 1) 199 | self.conv_last_uv = nn.Conv2d(self.nf, 2, 3, 2, 1) 200 | self.forward = self.forward_yuv42x 201 | case unk: 202 | raise ValueError(f'unknown input pixel format: {unk}') 203 | 204 | def forward_rgb(self, x): 205 | out = self.conv_last2(x) 206 | return out, 207 | 208 | def forward_yuv444(self, x): 209 | out = self.conv_last2(x) 210 | return out[:, :1, ...], out[:, 1:, ...] 211 | 212 | def forward_yuv42x(self, x): 213 | y = self.conv_last_y(x) 214 | uv = self.conv_last_uv(x) 215 | return y, uv 216 | -------------------------------------------------------------------------------- /torch/model/cycmunet/part.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | 4 | from ..util import cat_conv 5 | 6 | """CycMuNet Private Network Build Block""" 7 | 8 | 9 | class Pro_align(nn.Module): 10 | def __init__(self, args): 11 | super(Pro_align, self).__init__() 12 | self.args = args 13 | self.nf = self.args.nf 14 | self.conv1x1 = nn.Conv2d(self.nf * 3, self.nf, 1, 1, 0) 15 | self.conv3x3 = nn.Conv2d(self.nf, self.nf, 3, 1, 1) 16 | self.conv1_3x3 = nn.Conv2d(self.nf * 2, self.nf, 3, 1, 1) 17 | self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True) 18 | 19 | def forward(self, l1, l2, l3): 20 | r1 = self.lrelu(self.conv3x3(l1)) 21 | r2 = self.lrelu(self.conv3x3(l2)) 22 | r3 = self.lrelu(self.conv3x3(l3)) 23 | fuse = self.lrelu(cat_conv(self.conv1x1, [r1, r2, r3])) 24 | r1 = self.lrelu(cat_conv(self.conv1_3x3, [r1, fuse])) 25 | r2 = self.lrelu(cat_conv(self.conv1_3x3, [r2, fuse])) 26 | r3 = self.lrelu(cat_conv(self.conv1_3x3, [r3, fuse])) 27 | return l1 + r1, l2 + r2, l3 + r3 28 | 29 | 30 | class SR(nn.Module): 31 | def __init__(self, args): 32 | super(SR, self).__init__() 33 | self.args = args 34 | self.nf = self.args.nf 35 | self.factor = (self.args.upscale_factor, self.args.upscale_factor) 36 | self.Pro_align = Pro_align(args) 37 | self.conv1x1 = nn.Conv2d(self.nf, self.nf, 1, 1, 0) 38 | self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True) 39 | 40 | def upsample(self, x): 41 | x = F.interpolate(x, scale_factor=self.factor, mode='bilinear', align_corners=False) 42 | return self.lrelu(self.conv1x1(x)) 43 | 44 | def forward(self, l1, l2, l3): 45 | l1, l2, l3 = self.Pro_align(l1, l2, l3) 46 | return tuple(self.upsample(i) for i in (l1, l2, l3)) 47 | 48 | 49 | class DR(nn.Module): 50 | def __init__(self, args): 51 | super(DR, self).__init__() 52 | self.args = args 53 | self.nf = self.args.nf 54 | self.factor = (1 / self.args.upscale_factor, 1 / self.args.upscale_factor) 55 | self.Pro_align = Pro_align(args) 56 | self.conv = nn.Conv2d(self.nf, self.nf, 1, 1, 0) 57 | self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True) 58 | 59 | def downsample(self, x): 60 | x = F.interpolate(x, scale_factor=self.factor, mode='bilinear', align_corners=False) 61 | return self.lrelu(self.conv(x)) 62 | 63 | def forward(self, l1, l2, l3): 64 | l1 = self.downsample(l1) 65 | l2 = self.downsample(l2) 66 | l3 = self.downsample(l3) 67 | return self.Pro_align(l1, l2, l3) 68 | 69 | 70 | class Up_projection(nn.Module): 71 | def __init__(self, args): 72 | super(Up_projection, self).__init__() 73 | self.args = args 74 | self.SR = SR(args) 75 | self.DR = DR(args) 76 | self.SR1 = SR(args) 77 | 78 | def forward(self, l1, l2, l3): 79 | h1, h2, h3 = self.SR(l1, l2, l3) 80 | d1, d2, d3 = self.DR(h1, h2, h3) 81 | r1, r2, r3 = d1 - l1, d2 - l2, d3 - l3 82 | s1, s2, s3 = self.SR1(r1, r2, r3) 83 | return h1 + s1, h2 + s3, h3 + s3 84 | 85 | 86 | class Down_projection(nn.Module): 87 | def __init__(self, args): 88 | super(Down_projection, self).__init__() 89 | self.args = args 90 | self.SR = SR(args) 91 | self.DR = DR(args) 92 | self.DR1 = DR(args) 93 | 94 | def forward(self, h1, h2, h3): 95 | l1, l2, l3 = self.DR(h1, h2, h3) 96 | s1, s2, s3 = self.SR(l1, l2, l3) 97 | r1, r2, r3 = s1 - h1, s2 - h2, s3 - h3 98 | d1, d2, d3 = self.DR1(r1, r2, r3) 99 | return l1 + d1, l2 + d2, l3 + d3 100 | -------------------------------------------------------------------------------- /torch/model/part.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.init as init 6 | import torch.nn.functional as F 7 | from torch.nn.common_types import _size_2_t 8 | from torch.nn.modules.utils import _pair 9 | import torchvision.ops 10 | 11 | from .util import cat_conv 12 | 13 | 14 | class ResidualBlock_noBN(nn.Module): 15 | '''Residual block w/o BN 16 | ---Conv-ReLU-Conv-+- 17 | |________________| 18 | ''' 19 | 20 | def __init__(self, nf=64): 21 | super(ResidualBlock_noBN, self).__init__() 22 | self.conv1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) 23 | self.conv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) 24 | 25 | # initialization 26 | self.init_weights(self.conv1) 27 | self.init_weights(self.conv2) 28 | 29 | @staticmethod 30 | def init_weights(conv): 31 | init.kaiming_normal_(conv.weight, a=0, mode='fan_in') 32 | conv.weight.data *= 0.1 # for residual block 33 | if conv.bias is not None: 34 | conv.bias.data.zero_() 35 | 36 | def forward(self, x): 37 | identity = x 38 | out = F.relu(self.conv1(x), inplace=True) 39 | out = self.conv2(out) 40 | return identity + out 41 | 42 | 43 | class DCN_sep(nn.Module): 44 | def __init__(self, 45 | in_channels: int, 46 | in_channels_features: int, 47 | out_channels: int, 48 | kernel_size: _size_2_t, 49 | stride: _size_2_t = 1, 50 | padding: _size_2_t = 0, 51 | dilation: _size_2_t = 1, 52 | groups: int = 1, 53 | deformable_groups: int = 1, 54 | bias: bool = True, 55 | mask: bool = True): 56 | super(DCN_sep, self).__init__() 57 | 58 | self.dcn = torchvision.ops.DeformConv2d(in_channels, out_channels, kernel_size, stride, padding, dilation, 59 | groups, bias) 60 | 61 | kernel_size_ = _pair(kernel_size) 62 | offset_channels = deformable_groups * kernel_size_[0] * kernel_size_[1] 63 | 64 | self.conv_offset = nn.Conv2d(in_channels_features, offset_channels * 2, kernel_size=kernel_size, 65 | stride=stride, padding=padding, dilation=dilation, bias=True) 66 | self.conv_mask = nn.Conv2d(in_channels_features, offset_channels, kernel_size=kernel_size, 67 | stride=stride, padding=padding, dilation=dilation, bias=True) if mask else None 68 | self.relu = nn.ReLU(inplace=True) 69 | 70 | def forward(self, input: torch.Tensor, feature: torch.Tensor): 71 | offset = self.conv_offset(feature) 72 | mask = torch.sigmoid(self.conv_mask(feature)) if self.conv_mask else None 73 | 74 | return self.dcn(input, offset, mask) 75 | 76 | 77 | class PCDLayer(nn.Module): 78 | """ Alignment module using Pyramid, Cascading and Deformable convolution""" 79 | def __init__(self, args, first_layer: bool): 80 | super(PCDLayer, self).__init__() 81 | self.args = args 82 | self.nf = self.args.nf 83 | self.groups = self.args.groups 84 | 85 | self.offset_conv1 = nn.Conv2d(2 * self.nf, self.nf, 3, 1, 1) 86 | self.offset_conv3 = nn.Conv2d(self.nf, self.nf, 3, 1, 1) 87 | self.dcnpack = DCN_sep(self.nf, self.nf, self.nf, 3, stride=1, padding=1, dilation=1, 88 | deformable_groups=self.groups) 89 | self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True) 90 | 91 | if not first_layer: 92 | self.offset_conv2 = nn.Conv2d(2 * self.nf, self.nf, 3, 1, 1) 93 | self.fea_conv = nn.Conv2d(2 * self.nf, self.nf, 3, 1, 1) 94 | 95 | def forward(self, current_sources: Tuple[torch.Tensor, torch.Tensor], 96 | last_offset: torch.Tensor, last_feature: torch.Tensor): 97 | offset = self.lrelu(cat_conv(self.offset_conv1, current_sources)) 98 | if last_offset is not None: 99 | last_offset = F.interpolate(last_offset, scale_factor=2, mode='bilinear', align_corners=False) 100 | _, _, h, w = offset.shape 101 | last_offset = last_offset[..., :h, :w] 102 | offset = self.lrelu(cat_conv(self.offset_conv2, (offset, last_offset * 2))) 103 | offset = self.lrelu(self.offset_conv3(offset)) 104 | feature = self.dcnpack(current_sources[0], offset) 105 | if last_feature is not None: 106 | last_feature = F.interpolate(last_feature, scale_factor=2, mode='bilinear', align_corners=False) 107 | _, _, h, w = feature.shape 108 | last_feature = last_feature[..., :h, :w] 109 | feature = cat_conv(self.fea_conv, (feature, last_feature)) 110 | feature = self.lrelu(feature) 111 | return offset, feature 112 | -------------------------------------------------------------------------------- /torch/model/util.py: -------------------------------------------------------------------------------- 1 | import functools 2 | import itertools 3 | import math 4 | from typing import Tuple, Union 5 | 6 | 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | import torchvision 11 | 12 | RGBOrYUV = Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]] 13 | 14 | _use_fold_catconv = False 15 | 16 | 17 | def use_fold_catconv(value=True): 18 | global _use_fold_catconv 19 | _use_fold_catconv = value 20 | 21 | 22 | def cat_simp(ts, *args, **kwargs): 23 | """auto eliminate cat if there's only one input""" 24 | if len(ts) == 1: 25 | return ts[0] 26 | return torch.cat(ts, *args, **kwargs) 27 | 28 | 29 | def cat_conv(conv: nn.Conv2d, tensors, scale=None): 30 | """separate cat+conv into multiple conv to reduce memory footprint""" 31 | if _use_fold_catconv: 32 | w = conv.weight.detach() 33 | b = conv.bias.detach() 34 | if scale is not None: 35 | w *= scale 36 | b *= scale 37 | output = None 38 | channels = [0] 39 | channels.extend(itertools.accumulate(int(tensor.shape[1]) for tensor in tensors)) 40 | for ti, cb, ce in zip(tensors, channels, channels[1:]): 41 | c = ti.shape[1] 42 | convi = nn.Conv2d(c, conv.out_channels, conv.kernel_size, conv.stride, conv.padding, 43 | conv.dilation, bias=output is None).eval() 44 | convi.weight = nn.Parameter(w[:, cb:ce, :, :], requires_grad=False) 45 | if output is None: 46 | convi.bias = nn.Parameter(b, requires_grad=False) 47 | outputi = convi(ti) 48 | output = outputi if output is None else output + outputi 49 | return output 50 | else: 51 | return conv(torch.cat(tensors, dim=1)) 52 | 53 | 54 | # mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] 55 | class normalizer: 56 | nm = torchvision.transforms.Normalize 57 | sqrt2 = math.sqrt(2) 58 | 59 | def __init__(self, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225), kr=0.2126, kb=0.0722, depth=8): 60 | self.mean = mean 61 | self.std = std 62 | self.krgb = (kr, 1 - kr - kb, kb) 63 | self.depth = depth 64 | self.uv_bias = (1 << (depth - 1)) / ((1 << depth) - 1) 65 | 66 | @staticmethod 67 | def _inv(mean, std): 68 | inv_std = tuple(1 / i for i in std) 69 | inv_mean = tuple(-j * i for i, j in zip(inv_std, mean)) 70 | return inv_mean, inv_std 71 | 72 | def _yuv_dist(self): 73 | rm, gm, bm = self.mean 74 | rs, gs, bs = self.std 75 | kr, kg, kb = self.krgb 76 | 77 | ym = rm * kr + gm * kg + bm * kb 78 | ys = math.sqrt((rs * kr) ** 2 + (gs * kg) ** 2 + (bs * kb) ** 2) 79 | um = (bm - ym) / (1 - kb) / 2 + self.uv_bias 80 | us = math.sqrt(bs * bs + ys * ys) / (1 - kb) / 2 81 | vm = (rm - ym) / (1 - kr) / 2 + self.uv_bias 82 | vs = math.sqrt(rs * rs + ys * ys) / (1 - kr) / 2 83 | return [ym, um, vm], [ys, us, vs] 84 | 85 | def normalize_rgb(self, rgb: torch.Tensor): 86 | return self.nm(self.mean, self.std)(rgb) 87 | 88 | def denormalize_rgb(self, rgb: torch.Tensor): 89 | return self.nm(*self._inv(self.mean, self.std))(rgb) 90 | 91 | def normalize_yuv_444(self, yuv: torch.Tensor): 92 | return self.nm(*self._yuv_dist())(yuv) 93 | 94 | def denormalize_yuv_444(self, yuv: torch.Tensor): 95 | return self.nm(*self._inv(*self._yuv_dist()))(yuv) 96 | 97 | def _normalize_yuv_42x(self, y: torch.Tensor, uv: torch.Tensor, scale): 98 | mean, std = self._yuv_dist() 99 | std[1], std[2] = std[1] * scale, std[2] * scale 100 | y = self.nm(mean[0], std[0])(y) 101 | uv = self.nm(mean[1:], std[1:])(uv) 102 | return y, uv 103 | 104 | def _denormalize_yuv_42x(self, y: torch.Tensor, uv: torch.Tensor, scale): 105 | mean, std = self._yuv_dist() 106 | std[1], std[2] = std[1] * scale, std[2] * scale 107 | mean, std = self._inv(mean, std) 108 | y = self.nm(mean[0], std[0])(y) 109 | uv = self.nm(mean[1:], std[1:])(uv) 110 | return y, uv 111 | 112 | def normalize_yuv_422(self, y: torch.Tensor, uv: torch.Tensor): 113 | return self._normalize_yuv_42x(y, uv, 1 / self.sqrt2) 114 | 115 | def denormalize_yuv_422(self, y: torch.Tensor, uv: torch.Tensor): 116 | return self._denormalize_yuv_42x(y, uv, 1 / self.sqrt2) 117 | 118 | def normalize_yuv_420(self, y: torch.Tensor, uv: torch.Tensor): 119 | return self._normalize_yuv_42x(y, uv, 1 / 2) 120 | 121 | def denormalize_yuv_420(self, y: torch.Tensor, uv: torch.Tensor): 122 | return self._denormalize_yuv_42x(y, uv, 1 / 2) 123 | 124 | 125 | class converter: 126 | def __init__(self, kr=0.2126, kb=0.0722, depth=8, format='yuv420', upsample_mode='bilinear'): 127 | self.krgb = (kr, 1 - kr - kb, kb) 128 | self.depth = depth 129 | self.uv_bias = (1 << (depth - 1)) / ((1 << depth) - 1) 130 | match format: 131 | case 'yuv444': 132 | self.downsample = lambda x: x 133 | self.upsample = lambda x: x 134 | case 'yuv422': 135 | self.downsample = functools.partial(F.interpolate, scale_factor=(1, 1 / 2), mode='bilinear', 136 | align_corners=False) 137 | self.upsample = functools.partial(F.interpolate, scale_factor=(1, 2), mode=upsample_mode, 138 | align_corners=False) 139 | case 'yuv420': 140 | self.downsample = functools.partial(F.interpolate, scale_factor=(1 / 2, 1 / 2), mode='bilinear', 141 | align_corners=False) 142 | self.upsample = functools.partial(F.interpolate, scale_factor=(2, 2), mode=upsample_mode, 143 | align_corners=False) 144 | 145 | def rgb2yuv(self, x: torch.Tensor): 146 | kr, kg, kb = self.krgb 147 | 148 | r, g, b = torch.chunk(x, 3, 1) 149 | y = kr * r + kg * g + kb * b 150 | u = (y - b) / (kb - 1) / 2 + self.uv_bias 151 | v = (y - r) / (kr - 1) / 2 + self.uv_bias 152 | uv = torch.cat((u, v), dim=1) 153 | return y, self.downsample(uv) 154 | 155 | def yuv2rgb(self, y: torch.Tensor, uv: torch.Tensor): 156 | kr, kg, kb = self.krgb 157 | 158 | uv = self.upsample(uv - self.uv_bias) 159 | u, v = torch.chunk(uv, 2, 1) 160 | r = y + 2 * (1 - kr) * v 161 | b = y + 2 * (1 - kb) * u 162 | g = y - 2 * (1 - kr) * kr * v - 2 * (1 - kb) * kb * u 163 | return torch.cat((r, g, b), dim=1) 164 | --------------------------------------------------------------------------------