├── .gitignore
├── LICENSE
├── README.md
├── base
│   ├── file_io_base.py
│   ├── flow_base.py
│   ├── image_base.py
│   ├── imresize_right.py
│   ├── kernel_base.py
│   ├── matlab_imresize.py
│   └── os_base.py
└── utils
    ├── LPIPS
    │   └── models
    │       ├── __init__.py
    │       ├── base_model.py
    │       ├── dist_model.py
    │       ├── networks_basic.py
    │       ├── pretrained_networks.py
    │       └── weights
    │           ├── v0.0
    │           │   ├── alex.pth
    │           │   ├── squeeze.pth
    │           │   └── vgg.pth
    │           └── v0.1
    │               ├── alex.pth
    │               ├── squeeze.pth
    │               └── vgg.pth
    ├── NIQE
    │   ├── niqe.py
    │   └── niqe_image_params.mat
    ├── ResizeRight
    │   ├── LICENSE
    │   ├── README.md
    │   ├── interp_methods.py
    │   └── resize_right.py
    ├── create_gif.py
    ├── dnn_utils.py
    ├── file_regroup_utils.py
    ├── image_crop_combine_utils.py
    ├── image_edge_sharpen.py
    ├── image_metric_utils.py
    ├── image_utils.py
    ├── kernel_metric_utils.py
    ├── kernel_utils.py
    ├── plot_utils.py
    ├── string_utils.py
    ├── torchstat
    │   ├── .gitignore
    │   ├── .travis.yml
    │   ├── LICENSE
    │   ├── README.md
    │   ├── detail.md
    │   ├── example.py
    │   ├── requirements.txt
    │   ├── setup.cfg
    │   ├── setup.py
    │   ├── test_requirements.txt
    │   └── torchstat
    │       ├── __init__.py
    │       ├── __main__.py
    │       ├── compute_flops.py
    │       ├── compute_madd.py
    │       ├── compute_memory.py
    │       ├── model_hook.py
    │       ├── reporter.py
    │       ├── stat_tree.py
    │       └── statistics.py
    ├── video_metric_utils.py
    ├── video_regroup_utils.py
    └── video_utils.py

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
.idea/
*/__pycache__
**/__pycache__
ss*
draft*
temp/
.DS_Store*
**/.DS_Store*

--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2020 Haoran Bai

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# OpenUtility
[![LICENSE](https://img.shields.io/badge/license-MIT-green)](https://github.com/csbhr/Python_Tools/blob/master/LICENSE)
[![Python](https://img.shields.io/badge/python-3.6-blue.svg)](https://www.python.org/)

There are some useful tools for low-level vision tasks.

- [Calculating Metrics ( PSNR, SSIM, etc. )](#chapter-calculating-metrics)
- [Image / Video Processing ( resize, crop, shift, etc. )](#chapter-image-video-processing)
- [Deep Model Properties ( Params, Flops, etc. )](#chapter-model-properties)
- [File Processing ( csv, etc. )](#chapter-file-processing)
- [Visualize Tools ( plot, optical-flow, etc. )](#chapter-visualize-tools)

## Dependencies

- Python 3 (Recommend to use [Anaconda](https://www.anaconda.com/download/#linux))
- [PyTorch](https://pytorch.org/)
- numpy: `conda install numpy`
- matplotlib: `conda install matplotlib`
- opencv: `conda install opencv`
- pandas: `conda install pandas`

## Easy to use

Here are some simple demos. If you want to learn more about the usage of these tools, you can refer to the source code for more optional parameters.

<a name="chapter-calculating-metrics"></a>
### 1. Calculating Metrics ( PSNR, SSIM, etc. )

##### 1.1 PSNR/SSIM
- Follow the demo below for batch operation:
```python
# Images
from utils.image_metric_utils import batch_calc_image_PSNR_SSIM
root_list = [
    {
        'output': '/path/to/output images 1',
        'gt': '/path/to/gt images 1'
    },
    {
        'output': '/path/to/output images 2',
        'gt': '/path/to/gt images 2'
    },
]
batch_calc_image_PSNR_SSIM(root_list)

# Videos
from utils.video_metric_utils import batch_calc_video_PSNR_SSIM
root_list = [
    {
        'output': '/path/to/output videos 1',
        'gt': '/path/to/gt videos 1'
    },
    {
        'output': '/path/to/output videos 2',
        'gt': '/path/to/gt videos 2'
    },
]
batch_calc_video_PSNR_SSIM(root_list)
```

##### 1.2 LPIPS
- Follow the demo below for batch operation:
```python
# Images
from utils.image_metric_utils import batch_calc_image_LPIPS
root_list = [
    {
        'output': '/path/to/output images 1',
        'gt': '/path/to/gt images 1'
    },
    {
        'output': '/path/to/output images 2',
        'gt': '/path/to/gt images 2'
    },
]
batch_calc_image_LPIPS(root_list)

# Videos
from utils.video_metric_utils import batch_calc_video_LPIPS
root_list = [
    {
        'output': '/path/to/output videos 1',
        'gt': '/path/to/gt videos 1'
    },
    {
        'output': '/path/to/output videos 2',
        'gt': '/path/to/gt videos 2'
    },
]
batch_calc_video_LPIPS(root_list)
```
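- LPIPS can also be called directly on a single pair of images, without the batch helpers. Below is a minimal sketch using `PerceptualLoss` and `im2tensor` from `utils/LPIPS/models/__init__.py`; the file paths are placeholders, and `use_gpu=True` assumes a CUDA-capable machine (set it to `False` otherwise):
```python
import cv2
from utils.LPIPS.models import PerceptualLoss, im2tensor

loss_fn = PerceptualLoss(model='net-lin', net='alex', use_gpu=True)
img0 = cv2.imread('/path/to/output.png')[:, :, ::-1].copy()  # BGR -> RGB
img1 = cv2.imread('/path/to/gt.png')[:, :, ::-1].copy()
dist = loss_fn.forward(im2tensor(img0), im2tensor(img1))  # im2tensor maps [0,255] -> [-1,1]
print('LPIPS distance:', dist.item())
```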
##### 1.3 NIQE
- Follow the demo below for batch operation:
```python
# Images
from utils.image_metric_utils import batch_calc_image_NIQE
root_list = [
    {
        'output': '/path/to/output images 1',
        'gt': '/path/to/gt images 1'
    },
    {
        'output': '/path/to/output images 2',
        'gt': '/path/to/gt images 2'
    },
]
batch_calc_image_NIQE(root_list)

# Videos
from utils.video_metric_utils import batch_calc_video_NIQE
root_list = [
    {
        'output': '/path/to/output videos 1',
        'gt': '/path/to/gt videos 1'
    },
    {
        'output': '/path/to/output videos 2',
        'gt': '/path/to/gt videos 2'
    },
]
batch_calc_video_NIQE(root_list)
```

##### 1.4 Kernel Gradient Similarity
- Follow the demo below for batch operation:
```python
# Images: video_type=False
# Videos: video_type=True
from utils.kernel_metric_utils import batch_calc_kernel_metric
root_list = [
    {
        'output': '/path/to/output kernel 1',
        'gt': '/path/to/gt kernel 1'
    },
    {
        'output': '/path/to/output kernel 2',
        'gt': '/path/to/gt kernel 2'
    },
]
batch_calc_kernel_metric(root_list, video_type=False)
```

<a name="chapter-image-video-processing"></a>
### 2. Image / Video Processing ( resize, crop, shift, etc. )

##### 2.1 Image/Video Resize
- We use fatheral's Python implementation of MATLAB's imresize() function: [fatheral/matlab_imresize](https://github.com/fatheral/matlab_imresize).
- Follow the demo below for batch operation:

```python
# Images
from utils.image_utils import matlab_resize_images

ori_root = '/path/to/ori images'
dest_root = '/path/to/dest images'
matlab_resize_images(ori_root, dest_root, scale=2.0)

# Videos
from utils.video_utils import matlab_resize_videos

ori_root = '/path/to/ori videos'
dest_root = '/path/to/dest videos'
matlab_resize_videos(ori_root, dest_root, scale=2.0)
```
- We also provide OpenCV-based resizing.
- Follow the demo below for batch operation:

```python
# Images
from utils.image_utils import cv2_resize_images

ori_root = '/path/to/ori images'
dest_root = '/path/to/dest images'
cv2_resize_images(ori_root, dest_root, scale=2.0)

# Videos
from utils.video_utils import cv2_resize_videos

ori_root = '/path/to/ori videos'
dest_root = '/path/to/dest videos'
cv2_resize_videos(ori_root, dest_root, scale=2.0)
```
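- If you only need to resize a single image in memory, you can call the underlying function from `base/image_base.py` directly. A minimal sketch (the paths are placeholders):
```python
import cv2
from base.image_base import matlab_imresize

img = cv2.imread('/path/to/image.png')
up = matlab_imresize(img, scalar_scale=2.0)           # resize by scale factor
fit = matlab_imresize(img, output_shape=(480, 640))   # or resize to a target (h, w)
cv2.imwrite('/path/to/image_x2.png', up)
```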
##### 2.2 Crop and combine images
- When you need to infer a large image, you can crop it into many patches with padding by following the demo:
```python
# Notice:
#   filenames should not contain the character "-"
#   the crop flag "x-x-x-x" will be appended to the end of the filename when cropping
#   the combine operation will use the crop flag "x-x-x-x"
from utils.image_crop_combine_utils import *
ori_root = '/path/to/ori images'
dest_root = '/path/to/dest images'
batch_crop_img_with_padding(ori_root, dest_root, min_size=(800, 800), padding=100)
```
- When you have finished inferring the cropped patches, you can combine them back into the full image by following the demo:
```python
# Notice:
#   filenames should not contain the character "-" except for the crop flag
#   the crop flag "x-x-x-x" should be at the end of the filename when combining
from utils.image_crop_combine_utils import *
ori_root = '/path/to/ori images'
dest_root = '/path/to/dest images'
batch_combine_img(ori_root, dest_root, padding=100)
```
- You can traverse-crop an image into many patches with a fixed interval by following the demo:
```python
from utils.image_crop_combine_utils import *
ori_root = '/path/to/ori images'
dest_root = '/path/to/dest images'
batch_traverse_crop_img(ori_root, dest_root, dsize=(800, 800), interval=400)
```
- You can select valid patches (i.e. patches that are not too smooth) by following the demo:
```python
from utils.image_crop_combine_utils import *
ori_root = '/path/to/ori images'
dest_root = '/path/to/dest images'
batch_select_valid_patch(ori_root, dest_root)
```

##### 2.3 Image/Video Shift
- We use bilinear interpolation to shift images/videos by sub-pixel offsets.
- Follow the demo below for batch operation:

```python
# Images
from utils.image_utils import shift_images

ori_root = '/path/to/ori images'
dest_root = '/path/to/dest images'
shift_images(ori_root, dest_root, offset_x=0.5, offset_y=0.5)

# Videos
from utils.video_utils import shift_videos

ori_root = '/path/to/ori videos'
dest_root = '/path/to/dest videos'
shift_videos(ori_root, dest_root, offset_x=0.5, offset_y=0.5)
```

<a name="chapter-model-properties"></a>
### 3. Deep Model Properties ( Params, Flops, etc. )
- You can calculate only the model Params by following the demo:
```python
from utils.dnn_utils import cal_parmeters
network = None  # Please define the model
cal_parmeters(network)
```
- You can also calculate more properties of the model: Params, Memory, MAdd, Flops, etc.
- We use Swall0w's tool [Swall0w/torchstat](https://github.com/Swall0w/torchstat).
- The original [Swall0w/torchstat] cannot run on CUDA, so we modified it to support CUDA. If you want to calculate Flops on CUDA, please use the following commands to install torchstat:
```shell script
cd ./utils/torchstat
python3 setup.py install
```
- If you do not need CUDA, please use the following command to install torchstat:
```shell script
pip install torchstat  # pytorch >= 1.0.0
pip install torchstat==0.0.6  # pytorch = 0.4.1
```
- Then you can calculate the properties by following the demo:
```python
from torchstat import stat
network = None  # Please define the model
input_size = (3, 80, 80)  # the size of input (channel, height, width)
stat(network, input_size)
```

<a name="chapter-file-processing"></a>
### 4. File Processing ( csv, etc. )

##### 4.1 csv file
- You can read a csv file by following the demo (`read_csv` returns the file content as `list[list[]]`):
```python
from base import file_io_base
data = file_io_base.read_csv('filename.csv')
```
- You can write a numpy.array into a csv file by following the demo:
```python
from base import file_io_base
import numpy as np
row_names = ['r1', 'r2', 'r3']
col_names = ['c1', 'c2', 'c3', 'c4']
data_array = np.array([[1, 2, 3, 4],
                       [2, 3, 4, 5],
                       [3, 4, 5, 6]])
file_io_base.write_csv('filename.csv', data_array, col_names, row_names)
```
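##### 4.2 excel file
- You can read an excel file in the same way; `read_excel` in `base/file_io_base.py` wraps `pandas.read_excel` (the filename below is a placeholder):
```python
from base import file_io_base
data = file_io_base.read_excel('filename.xlsx', sheet_name=0)
```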
<a name="chapter-visualize-tools"></a>
### 5. Visualize Tools ( plot, optical-flow, etc. )

##### 5.1 Plot multiple curves in one figure
- You can plot multiple curves in one figure by following the demo:
```python
import numpy as np
from utils import plot_utils
array_list = [
    np.array([1, 4, 5, 3, 6]),
    np.array([2, 3, 7, 4, 5]),
]
label_list = ['curve-1', 'curve-2']
plot_utils.plot_multi_curve(array_list, label_list)
```

##### 5.2 Visualize optical flow
- You can visualize optical flow by following the demo (the function lives in `base/flow_base.py` and returns a BGR image that can be saved with cv2):
```python
from base import flow_base
flow = None  # this is the flow to visualize, shape=(h, w, 2)
bgr_image = flow_base.visual_flow(flow)
```

--------------------------------------------------------------------------------
/base/file_io_base.py:
--------------------------------------------------------------------------------
import pandas as pd
import numpy as np


#################################################################################
####                              csv file I/O                               ####
#################################################################################
def read_csv(file_path):
    '''
    function:
        read csv file into list[list[]]
    required params:
        file_path: the csv file's path
    return:
        data_list: the csv file's content
    '''
    csv_data = pd.read_csv(file_path)
    data_list = csv_data.values.tolist()
    return data_list


def write_csv(file_path, data, col_names=None, row_names=None):
    '''
    function:
        write np.array into csv file
    required params:
        file_path: the csv file's path
        data: np.array, the ndim must be 2, the content of csv file
    optional params:
        col_names: a list, default None, if not None, the len must equal to data.shape[1]
        row_names: a list, default None, if not None, the len must equal to data.shape[0]
    '''
    index = False if row_names is None else True
    header = False if col_names is None else True
    assert data.ndim == 2, 'the ndim of data must be 2!'
    if index:
        assert data.shape[0] == len(row_names), 'the number of data rows must equal to len(row_names)'
    if header:
        assert data.shape[1] == len(col_names), 'the number of data cols must equal to len(col_names)'
    data = pd.DataFrame(data, index=row_names, columns=col_names)
    data.to_csv(file_path, index=index, header=header)


#################################################################################
####                             excel file I/O                              ####
#################################################################################
def read_excel(file_path, sheet_name=0):
    '''
    function:
        read excel file into list[list[]]
    required params:
        file_path: the excel file's path
    optional params:
        sheet_name: int or str, the index/name of read sheet
    return:
        data_list: the excel file's content
    '''
    excel_data = pd.read_excel(file_path, sheet_name=sheet_name)
    data_list = excel_data.values.tolist()
    return data_list
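# --- Illustrative round-trip sketch for the helpers above (not part of the
# --- original file; 'tmp.csv' is a placeholder). Note that read_csv() keeps the
# --- row-name column as ordinary data, since pd.read_csv() is called without an
# --- index column:
# from base import file_io_base
# import numpy as np
# file_io_base.write_csv('tmp.csv', np.arange(6).reshape(2, 3),
#                        col_names=['c1', 'c2', 'c3'], row_names=['r1', 'r2'])
# print(file_io_base.read_csv('tmp.csv'))  # -> [['r1', 0, 1, 2], ['r2', 3, 4, 5]]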
--------------------------------------------------------------------------------
/base/flow_base.py:
--------------------------------------------------------------------------------
import cv2
import numpy as np


def visual_flow(flow):
    # Hue H: measured in degrees in [0, 360), counter-clockwise starting from red:
    #   red = 0, green = 120, blue = 240
    # Saturation S: in [0.0, 1.0]
    # Value V: in [0.0 (black), 1.0 (white)]
    # flow shape: [h, w, 2]
    h, w = flow.shape[:2]
    hsv = np.zeros((h, w, 3), np.uint8)
    mag, ang = cv2.cartToPolar(flow[..., 0], flow[..., 1])
    hsv[..., 0] = ang * 180 / np.pi / 2
    hsv[..., 1] = cv2.normalize(mag, None, 0, 255, cv2.NORM_MINMAX)
    # Following FlowNet, V is set to 255 and the saturation S encodes the magnitude
    # of the pixel displacement; keeping the value channel at maximum makes the
    # visualization easy to read. Some flow visualizations instead put the magnitude
    # into V (with S = 255), but the whole image then looks very dark, so that
    # variant is rarely used.
    hsv[..., 2] = 255
    bgr = cv2.cvtColor(hsv, cv2.COLOR_HSV2BGR)
    return bgr
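# --- Illustrative usage sketch (not part of the original file): compute dense
# --- optical flow with OpenCV's Farneback method and visualize it; the frame
# --- paths are placeholders.
# prev_gray = cv2.imread('/path/to/frame_0.png', cv2.IMREAD_GRAYSCALE)
# next_gray = cv2.imread('/path/to/frame_1.png', cv2.IMREAD_GRAYSCALE)
# flow = cv2.calcOpticalFlowFarneback(prev_gray, next_gray, None,
#                                     0.5, 3, 15, 3, 5, 1.2, 0)  # shape (h, w, 2)
# cv2.imwrite('/path/to/flow_vis.png', visual_flow(flow))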
--------------------------------------------------------------------------------
/base/image_base.py:
--------------------------------------------------------------------------------
import torch
import torch.nn.functional as F
import numpy as np
import math
from torch.autograd import Variable
import cv2
from base.matlab_imresize import imresize


#################################################################################
####                            Image Operation                              ####
#################################################################################
def matlab_imresize(img, scalar_scale=None, output_shape=None, method='bicubic'):
    '''same as matlab2017 imresize
        img: shape=[h, w, c]
        scalar_scale: the resize scale
            if None, using output_shape
        output_shape: the resize shape, (h, w)
            if scalar_scale=None, using this param
        method: the interpolation method
            optional: 'bicubic', 'bilinear'
            default: 'bicubic'
    '''
    return imresize(
        I=img,
        scalar_scale=scalar_scale,
        output_shape=output_shape,
        method=method
    )


def image_shift(img, offset_x=0., offset_y=0.):
    '''shift the img by (offset_x, offset_y) on (axis-x, axis-y)
        img: numpy.array, shape=[h, w, c]
        offset_x: offset pixels on axis-x
            positive=left; negative=right
        offset_y: offset pixels on axis-y
            positive=up; negative=down
    '''
    img_tensor = torch.from_numpy(img.transpose(2, 0, 1)).unsqueeze(0).float()
    B, C, H, W = img_tensor.size()

    # init flow
    flo = torch.ones(B, 2, H, W).type_as(img_tensor)
    flo[:, 0, :, :] *= offset_x
    flo[:, 1, :, :] *= offset_y

    # mesh grid
    xx = torch.arange(0, W).view(1, -1).repeat(H, 1)
    yy = torch.arange(0, H).view(-1, 1).repeat(1, W)
    xx = xx.view(1, 1, H, W).repeat(B, 1, 1, 1)
    yy = yy.view(1, 1, H, W).repeat(B, 1, 1, 1)
    grid = torch.cat((xx, yy), 1).float()
    vgrid = Variable(grid) + flo

    # scale grid to [-1,1]
    vgrid[:, 0, :, :] = 2.0 * vgrid[:, 0, :, :].clone() / max(W - 1, 1) - 1.0
    vgrid[:, 1, :, :] = 2.0 * vgrid[:, 1, :, :].clone() / max(H - 1, 1) - 1.0

    # Interpolation
    vgrid = vgrid.permute(0, 2, 3, 1)
    output_tensor = F.grid_sample(img_tensor, vgrid, padding_mode='border')

    output = output_tensor.round()[0].detach().numpy().transpose(1, 2, 0).astype(img.dtype)

    return output


def rgb2ycbcr(img, range=255., only_y=True):
    """same as matlab rgb2ycbcr, please use bgr2ycbcr when using cv2.imread
        img: shape=[h, w, 3]
        range: the data range
        only_y: only return Y channel
    """
    in_img_type = img.dtype
    img = img.astype(np.float32)  # astype returns a copy, so it must be assigned back
    range_scale = 255. / range
    img *= range_scale

    # convert
    if only_y:
        rlt = np.dot(img, [65.481, 128.553, 24.966]) / 255.0 + 16.0
    else:
        rlt = np.matmul(img, [[65.481, -37.797, 112.0], [128.553, -74.203, -93.786],
                              [24.966, 112.0, -18.214]]) / 255.0 + [16, 128, 128]

    rlt /= range_scale
    if in_img_type == np.uint8:
        rlt = rlt.round()
    return rlt.astype(in_img_type)


def bgr2ycbcr(img, range=255., only_y=True):
    """bgr version of rgb2ycbcr, for cv2.imread
        img: shape=[h, w, 3]
        range: the data range
        only_y: only return Y channel
    """
    in_img_type = img.dtype
    img = img.astype(np.float32)  # astype returns a copy, so it must be assigned back
    range_scale = 255. / range
    img *= range_scale

    # convert
    if only_y:
        rlt = np.dot(img, [24.966, 128.553, 65.481]) / 255.0 + 16.0
    else:
        rlt = np.matmul(img, [[24.966, 112.0, -18.214], [128.553, -74.203, -93.786],
                              [65.481, -37.797, 112.0]]) / 255.0 + [16, 128, 128]

    rlt /= range_scale
    if in_img_type == np.uint8:
        rlt = rlt.round()
    return rlt.astype(in_img_type)


def ycbcr2rgb(img, range=255.):
    """same as matlab ycbcr2rgb
        img: shape=[h, w, 3]
        range: the data range
    """
    in_img_type = img.dtype
    img = img.astype(np.float32)  # astype returns a copy, so it must be assigned back
    range_scale = 255. / range
    img *= range_scale

    # convert
    rlt = np.matmul(img, [[0.00456621, 0.00456621, 0.00456621], [0, -0.00153632, 0.00791071],
                          [0.00625893, -0.00318811, 0]]) * 255.0 + [-222.921, 135.576, -276.836]

    rlt /= range_scale
    if in_img_type == np.uint8:
        rlt = rlt.round()
    return rlt.astype(in_img_type)
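# --- Illustrative usage sketch for the functions above (not part of the original
# --- file; the path is a placeholder): shift an image by half a pixel and extract
# --- the Y channel, as commonly done before computing Y-channel metrics.
# img = cv2.imread('/path/to/image.png')
# shifted = image_shift(img, offset_x=0.5, offset_y=0.5)  # bilinear sub-pixel warp
# y = bgr2ycbcr(shifted, only_y=True)  # Y channel, matching MATLAB's rgb2ycbcr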
#################################################################################
####                            Image PSNR SSIM                              ####
#################################################################################

def RMSE(img1, img2):
    img1 = img1.astype(np.float64)
    img2 = img2.astype(np.float64)
    mse = np.mean((img1 - img2) ** 2)
    if mse == 0:
        return 0.  # identical inputs: the error is zero (not inf, as in PSNR)
    return math.sqrt(mse)


def PSNR(img1, img2):
    '''
        img1, img2: [0, 255]
    '''
    img1 = img1.astype(np.float64)
    img2 = img2.astype(np.float64)
    mse = np.mean((img1 - img2) ** 2)
    if mse == 0:
        return float('inf')
    return 20 * math.log10(255.0 / math.sqrt(mse))


def SSIM(img1, img2):
    '''calculate SSIM
        the same outputs as MATLAB's
        img1, img2: [0, 255]
    '''

    def ssim(img1, img2):
        C1 = (0.01 * 255) ** 2
        C2 = (0.03 * 255) ** 2

        img1 = img1.astype(np.float64)
        img2 = img2.astype(np.float64)
        kernel = cv2.getGaussianKernel(11, 1.5)
        window = np.outer(kernel, kernel.transpose())

        mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5]  # valid
        mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5]
        mu1_sq = mu1 ** 2
        mu2_sq = mu2 ** 2
        mu1_mu2 = mu1 * mu2
        sigma1_sq = cv2.filter2D(img1 ** 2, -1, window)[5:-5, 5:-5] - mu1_sq
        sigma2_sq = cv2.filter2D(img2 ** 2, -1, window)[5:-5, 5:-5] - mu2_sq
        sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2

        ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) *
                                                                (sigma1_sq + sigma2_sq + C2))
        return ssim_map.mean()

    if not img1.shape == img2.shape:
        raise ValueError('Input images must have the same dimensions.')
    if img1.ndim == 2:
        return ssim(img1, img2)
    elif img1.ndim == 3:
        if img1.shape[2] == 3:
            ssims = []
            for i in range(3):
                ssims.append(ssim(img1[..., i], img2[..., i]))  # per-channel SSIM
            return np.array(ssims).mean()
        elif img1.shape[2] == 1:
            return ssim(np.squeeze(img1), np.squeeze(img2))
    else:
        raise ValueError('Wrong input image dimensions.')


def PSNR_SSIM_Shift_Best(img1, img2, window_size=5):
    '''
        img1, img2: [0, 255]
    '''
    img1 = img1.astype(np.float64)
    img2 = img2.astype(np.float64)
    img1 = img1[window_size:-window_size, window_size:-window_size, :]
    best_psnr, best_ssim = 0., 0.
214 | for i in range(-window_size, window_size + 1): 215 | for j in range(-window_size, window_size + 1): 216 | shifted_img2 = image_shift(img2, offset_x=i, offset_y=j) 217 | shifted_img2 = shifted_img2[window_size:-window_size, window_size:-window_size, :] 218 | psnr = PSNR(img1, shifted_img2) 219 | ssim = SSIM(img1, shifted_img2) 220 | best_psnr = max(best_psnr, psnr) 221 | best_ssim = max(best_ssim, ssim) 222 | return best_psnr, best_ssim 223 | 224 | 225 | ################################################################################# 226 | #### Others #### 227 | ################################################################################# 228 | 229 | 230 | def evaluate_smooth(img): 231 | x = cv2.Sobel(img, cv2.CV_16S, 1, 0) 232 | y = cv2.Sobel(img, cv2.CV_16S, 0, 1) 233 | absX = cv2.convertScaleAbs(x) 234 | absY = cv2.convertScaleAbs(y) 235 | dst = cv2.addWeighted(absX, 0.5, absY, 0.5, 0) 236 | smooth = np.mean(dst) 237 | return smooth 238 | 239 | 240 | def calc_grad_sobel(img, device='cuda'): 241 | if not isinstance(img, torch.Tensor): 242 | raise Exception("Now just support torch.Tensor. See the Type(img)={}".format(type(img))) 243 | if not img.ndimension() == 4: 244 | raise Exception("Tensor ndimension must equal to 4. See the img.ndimension={}".format(img.ndimension())) 245 | 246 | img = torch.mean(img, dim=1, keepdim=True) 247 | 248 | # img = calc_meanFilter(img, device=device) # meanFilter 249 | 250 | sobel_filter_X = np.array([[1, 0, -1], [2, 0, -2], [1, 0, -1]]).reshape((1, 1, 3, 3)) 251 | sobel_filter_Y = np.array([[1, 2, 1], [0, 0, 0], [-1, -2, -1]]).reshape((1, 1, 3, 3)) 252 | sobel_filter_X = torch.from_numpy(sobel_filter_X).float().to(device) 253 | sobel_filter_Y = torch.from_numpy(sobel_filter_Y).float().to(device) 254 | grad_X = F.conv2d(img, sobel_filter_X, bias=None, stride=1, padding=1) 255 | grad_Y = F.conv2d(img, sobel_filter_Y, bias=None, stride=1, padding=1) 256 | grad = torch.sqrt(grad_X.pow(2) + grad_Y.pow(2)) 257 | 258 | return grad_X, grad_Y, grad 259 | 260 | 261 | def calc_meanFilter(img, kernel_size=11, n_channel=1, device='cuda'): 262 | mean_filter_X = np.ones(shape=(1, 1, kernel_size, kernel_size), dtype=np.float32) / (kernel_size * kernel_size) 263 | mean_filter_X = torch.from_numpy(mean_filter_X).float().to(device) 264 | new_img = torch.zeros_like(img) 265 | for i in range(n_channel): 266 | new_img[:, i:i + 1, :, :] = F.conv2d(img[:, i:i + 1, :, :], mean_filter_X, bias=None, 267 | stride=1, padding=kernel_size // 2) 268 | return new_img 269 | 270 | 271 | def warp_by_flow(x, flo, device='cuda'): 272 | B, C, H, W = flo.size() 273 | 274 | # mesh grid 275 | xx = torch.arange(0, W).view(1, -1).repeat(H, 1) 276 | yy = torch.arange(0, H).view(-1, 1).repeat(1, W) 277 | xx = xx.view(1, 1, H, W).repeat(B, 1, 1, 1) 278 | yy = yy.view(1, 1, H, W).repeat(B, 1, 1, 1) 279 | grid = torch.cat((xx, yy), 1).float() 280 | grid = grid.to(device) 281 | vgrid = Variable(grid) + flo 282 | 283 | # scale grid to [-1,1] 284 | vgrid[:, 0, :, :] = 2.0 * vgrid[:, 0, :, :].clone() / max(W - 1, 1) - 1.0 285 | vgrid[:, 1, :, :] = 2.0 * vgrid[:, 1, :, :].clone() / max(H - 1, 1) - 1.0 286 | 287 | vgrid = vgrid.permute(0, 2, 3, 1) 288 | output = F.grid_sample(x, vgrid, padding_mode='border') 289 | 290 | return output 291 | -------------------------------------------------------------------------------- /base/kernel_base.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import math 3 | import cv2 4 | from scipy import io as 
scio 5 | from base.image_base import RMSE, PSNR, SSIM 6 | 7 | 8 | def load_mat_kernel(mat_path): 9 | data_dict = scio.loadmat(mat_path) 10 | key = [v for v in data_dict.keys() if v not in ['__header__', '__version__', '__globals__']][0] 11 | return data_dict[key] 12 | 13 | 14 | def kernel2png(kernel): 15 | # kernel: [ks, ks], float32/float64 16 | # return kernel_png: [ks, ks, 3], uint8 17 | kernel = cv2.resize(kernel, dsize=(0, 0), fx=8, fy=8, interpolation=cv2.INTER_CUBIC) 18 | kernel = np.clip(kernel, 0, np.max(kernel)) 19 | kernel = kernel / np.sum(kernel) 20 | mi = np.min(kernel) 21 | ma = np.max(kernel) 22 | kernel = (kernel - mi) / (ma - mi) 23 | kernel = np.round(np.clip(kernel * 255., 0, 255)) 24 | kernel_png = np.stack([kernel, kernel, kernel], axis=2).astype('uint8') 25 | return kernel_png 26 | 27 | 28 | def Gradient_Similarity(gradB, gradT): 29 | ''' 30 | calculate the gradient similarity using the response of convolution devided by the norm 31 | Reference: Z. Hu and M.-H. Yang. Good regions to deblur. In ECCV 2012 32 | ''' 33 | gradB = gradB.astype(np.float64) 34 | gradT = gradT.astype(np.float64) 35 | 36 | gradB_sum = math.sqrt(np.sum(np.power(gradB, 2))) 37 | gradT_sum = math.sqrt(np.sum(np.power(gradT, 2))) 38 | 39 | kb, kt = gradB.shape[0], gradT.shape[0] 40 | psize = kt // 2 41 | gradB_norm = gradB / gradB_sum 42 | gradB_pad = np.zeros(shape=[kb + 2 * psize, kb + 2 * psize], dtype=np.float64) 43 | gradB_pad[psize:-psize, psize:-psize] = gradB_norm 44 | corr = cv2.filter2D(src=gradB_pad, ddepth=-1, kernel=gradT, borderType=cv2.BORDER_CONSTANT) 45 | 46 | similarity = np.max(corr) / gradT_sum 47 | 48 | return similarity 49 | 50 | 51 | def Kernel_RMSE_PSNR_SSIM(kernel1, kernel2): 52 | h1, w1 = kernel1.shape 53 | h2, w2 = kernel2.shape 54 | 55 | mh, mw = max(h1, h2), max(w1, w2) 56 | kernel1_pad = np.zeros(shape=(mh, mw)).astype(kernel1.dtype) 57 | kernel2_pad = np.zeros(shape=(mh, mw)).astype(kernel2.dtype) 58 | kernel1_pad[(mh - h1) // 2:(mh - h1) // 2 + h1, (mw - w1) // 2:(mw - w1) // 2 + w1] = kernel1 59 | kernel2_pad[(mh - h2) // 2:(mh - h2) // 2 + h2, (mw - w2) // 2:(mw - w2) // 2 + w2] = kernel2 60 | 61 | # kernel1_pad = cv2.resize(kernel1, dsize=(w2, w2), interpolation=cv2.INTER_CUBIC) 62 | # kernel2_pad = kernel2 63 | 64 | rmse = RMSE(kernel1_pad, kernel2_pad) 65 | psnr = PSNR(kernel1_pad * 255, kernel2_pad * 255) 66 | ssim = SSIM(kernel1_pad * 255, kernel2_pad * 255) 67 | return rmse, psnr, ssim 68 | -------------------------------------------------------------------------------- /base/matlab_imresize.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from math import ceil 3 | 4 | 5 | def deriveSizeFromScale(img_shape, scale): 6 | output_shape = [] 7 | for k in range(2): 8 | output_shape.append(int(ceil(scale[k] * img_shape[k]))) 9 | return output_shape 10 | 11 | 12 | def deriveScaleFromSize(img_shape_in, img_shape_out): 13 | scale = [] 14 | for k in range(2): 15 | scale.append(1.0 * img_shape_out[k] / img_shape_in[k]) 16 | return scale 17 | 18 | 19 | def triangle(x): 20 | x = np.array(x).astype(np.float64) 21 | lessthanzero = np.logical_and((x >= -1), x < 0) 22 | greaterthanzero = np.logical_and((x <= 1), x >= 0) 23 | f = np.multiply((x + 1), lessthanzero) + np.multiply((1 - x), greaterthanzero) 24 | return f 25 | 26 | 27 | def cubic(x): 28 | x = np.array(x).astype(np.float64) 29 | absx = np.absolute(x) 30 | absx2 = np.multiply(absx, absx) 31 | absx3 = np.multiply(absx2, absx) 32 | f = np.multiply(1.5 * 
absx3 - 2.5 * absx2 + 1, absx <= 1) + np.multiply(-0.5 * absx3 + 2.5 * absx2 - 4 * absx + 2, (1 < absx) & (absx <= 2))
    return f


def contributions(in_length, out_length, scale, kernel, k_width):
    if scale < 1:
        h = lambda x: scale * kernel(scale * x)
        kernel_width = 1.0 * k_width / scale
    else:
        h = kernel
        kernel_width = k_width
    x = np.arange(1, out_length + 1).astype(np.float64)
    u = x / scale + 0.5 * (1 - 1 / scale)
    left = np.floor(u - kernel_width / 2)
    P = int(ceil(kernel_width)) + 2
    ind = np.expand_dims(left, axis=1) + np.arange(P) - 1  # -1 because indexing from 0
    indices = ind.astype(np.int32)
    weights = h(np.expand_dims(u, axis=1) - indices - 1)  # -1 because indexing from 0
    weights = np.divide(weights, np.expand_dims(np.sum(weights, axis=1), axis=1))
    aux = np.concatenate((np.arange(in_length), np.arange(in_length - 1, -1, step=-1))).astype(np.int32)
    indices = aux[np.mod(indices, aux.size)]
    ind2store = np.nonzero(np.any(weights, axis=0))
    weights = weights[:, ind2store]
    indices = indices[:, ind2store]
    return weights, indices


def imresizemex(inimg, weights, indices, dim):
    in_shape = inimg.shape
    w_shape = weights.shape
    out_shape = list(in_shape)
    out_shape[dim] = w_shape[0]
    outimg = np.zeros(out_shape)
    if dim == 0:
        for i_img in range(in_shape[1]):
            for i_w in range(w_shape[0]):
                w = weights[i_w, :]
                ind = indices[i_w, :]
                im_slice = inimg[ind, i_img].astype(np.float64)
                outimg[i_w, i_img] = np.sum(np.multiply(np.squeeze(im_slice, axis=0), w.T), axis=0)
    elif dim == 1:
        for i_img in range(in_shape[0]):
            for i_w in range(w_shape[0]):
                w = weights[i_w, :]
                ind = indices[i_w, :]
                im_slice = inimg[i_img, ind].astype(np.float64)
                outimg[i_img, i_w] = np.sum(np.multiply(np.squeeze(im_slice, axis=0), w.T), axis=0)
    if inimg.dtype == np.uint8:
        outimg = np.clip(outimg, 0, 255)
        return np.around(outimg).astype(np.uint8)
    else:
        return outimg


def imresizevec(inimg, weights, indices, dim):
    wshape = weights.shape
    if dim == 0:
        weights = weights.reshape((wshape[0], wshape[2], 1, 1))
        outimg = np.sum(weights * ((inimg[indices].squeeze(axis=1)).astype(np.float64)), axis=1)
    elif dim == 1:
        weights = weights.reshape((1, wshape[0], wshape[2], 1))
        outimg = np.sum(weights * ((inimg[:, indices].squeeze(axis=2)).astype(np.float64)), axis=2)
    if inimg.dtype == np.uint8:
        outimg = np.clip(outimg, 0, 255)
        return np.around(outimg).astype(np.uint8)
    else:
        return outimg


def resizeAlongDim(A, dim, weights, indices, mode="vec"):
    if mode == "org":
        out = imresizemex(A, weights, indices, dim)
    else:
        out = imresizevec(A, weights, indices, dim)
    return out


def imresize(I, scalar_scale=None, method='bicubic', output_shape=None, mode="vec"):
    if method == 'bicubic':  # compare strings with '==', not 'is'
        kernel = cubic
    elif method == 'bilinear':
        kernel = triangle
    else:
        raise ValueError('Error: Unidentified method supplied')

    kernel_width = 4.0
    # Fill scale and output_size
    if scalar_scale is not None:
        scalar_scale = float(scalar_scale)
        scale = [scalar_scale, scalar_scale]
        output_size = deriveSizeFromScale(I.shape, scale)
    elif output_shape is not None:
        scale = deriveScaleFromSize(I.shape, output_shape)
        output_size =
list(output_shape) 127 | else: 128 | print('Error: scalar_scale OR output_shape should be defined!') 129 | return 130 | scale_np = np.array(scale) 131 | order = np.argsort(scale_np) 132 | weights = [] 133 | indices = [] 134 | for k in range(2): 135 | w, ind = contributions(I.shape[k], output_size[k], scale[k], kernel, kernel_width) 136 | weights.append(w) 137 | indices.append(ind) 138 | B = np.copy(I) 139 | flag2D = False 140 | if B.ndim == 2: 141 | B = np.expand_dims(B, axis=2) 142 | flag2D = True 143 | for k in range(2): 144 | dim = order[k] 145 | B = resizeAlongDim(B, dim, weights[dim], indices[dim], mode) 146 | if flag2D: 147 | B = np.squeeze(B, axis=2) 148 | return B 149 | 150 | 151 | def convertDouble2Byte(I): 152 | B = np.clip(I, 0.0, 1.0) 153 | B = 255 * B 154 | return np.around(B).astype(np.uint8) 155 | -------------------------------------------------------------------------------- /base/os_base.py: -------------------------------------------------------------------------------- 1 | import shutil 2 | import os 3 | import glob 4 | 5 | 6 | def handle_dir(dir): 7 | if not os.path.exists(dir): 8 | os.mkdir(dir) 9 | print('mkdir:', dir) 10 | 11 | 12 | def listdir(path): 13 | sys_files = ['.DS_Store'] 14 | files = os.listdir(path) 15 | for sf in sys_files: 16 | if sf in files: 17 | files.remove(sf) 18 | return files 19 | 20 | 21 | def glob_match(template): 22 | fpathes = glob.glob(template) 23 | dest_fpathes = [] 24 | for fp in fpathes: 25 | if '.DS_Store' not in fp: 26 | dest_fpathes.append(fp) 27 | return dest_fpathes 28 | 29 | 30 | def get_fname_ext(filepath): 31 | filename = os.path.basename(filepath) 32 | ext = filename.split(".")[-1] 33 | fname = filename[:-(len(ext) + 1)] 34 | return fname, ext 35 | 36 | 37 | def copy_file(src, dst): 38 | shutil.copy(src, dst) 39 | print('copy file from {} to {}'.format(os.path.basename(src), os.path.basename(dst))) 40 | 41 | 42 | def move_file(src, dst): 43 | shutil.move(src, dst) 44 | print('move file from {} to {}'.format(os.path.basename(src), os.path.basename(dst))) 45 | 46 | 47 | def rename_file(src, dst): 48 | os.rename(src, dst) 49 | print('rename file from {} to {}'.format(os.path.basename(src), os.path.basename(dst))) 50 | -------------------------------------------------------------------------------- /utils/LPIPS/models/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | from __future__ import absolute_import 3 | from __future__ import division 4 | from __future__ import print_function 5 | 6 | import numpy as np 7 | from skimage.measure import compare_ssim 8 | import torch 9 | 10 | from utils.LPIPS.models import dist_model 11 | 12 | class PerceptualLoss(torch.nn.Module): 13 | def __init__(self, model='net-lin', net='alex', colorspace='rgb', spatial=False, use_gpu=True, gpu_ids=[0], version='0.1'): # VGG using our perceptually-learned weights (LPIPS metric) 14 | # def __init__(self, model='net', net='vgg', use_gpu=True): # "default" way of using VGG as a perceptual loss 15 | super(PerceptualLoss, self).__init__() 16 | print('Setting up Perceptual loss...') 17 | self.use_gpu = use_gpu 18 | self.spatial = spatial 19 | self.gpu_ids = gpu_ids 20 | self.model = dist_model.DistModel() 21 | self.model.initialize(model=model, net=net, use_gpu=use_gpu, colorspace=colorspace, spatial=self.spatial, gpu_ids=gpu_ids, version=version) 22 | print('...[%s] initialized'%self.model.name()) 23 | print('...Done') 24 | 25 | def forward(self, pred, target, normalize=False): 26 | """ 27 | Pred and target are 
Variables. 28 | If normalize is True, assumes the images are between [0,1] and then scales them between [-1,+1] 29 | If normalize is False, assumes the images are already between [-1,+1] 30 | 31 | Inputs pred and target are Nx3xHxW 32 | Output pytorch Variable N long 33 | """ 34 | 35 | if normalize: 36 | target = 2 * target - 1 37 | pred = 2 * pred - 1 38 | 39 | return self.model.forward(target, pred) 40 | 41 | def normalize_tensor(in_feat,eps=1e-10): 42 | norm_factor = torch.sqrt(torch.sum(in_feat**2,dim=1,keepdim=True)) 43 | return in_feat/(norm_factor+eps) 44 | 45 | def l2(p0, p1, range=255.): 46 | return .5*np.mean((p0 / range - p1 / range)**2) 47 | 48 | def psnr(p0, p1, peak=255.): 49 | return 10*np.log10(peak**2/np.mean((1.*p0-1.*p1)**2)) 50 | 51 | def dssim(p0, p1, range=255.): 52 | return (1 - compare_ssim(p0, p1, data_range=range, multichannel=True)) / 2. 53 | 54 | def rgb2lab(in_img,mean_cent=False): 55 | from skimage import color 56 | img_lab = color.rgb2lab(in_img) 57 | if(mean_cent): 58 | img_lab[:,:,0] = img_lab[:,:,0]-50 59 | return img_lab 60 | 61 | def tensor2np(tensor_obj): 62 | # change dimension of a tensor object into a numpy array 63 | return tensor_obj[0].cpu().float().numpy().transpose((1,2,0)) 64 | 65 | def np2tensor(np_obj): 66 | # change dimenion of np array into tensor array 67 | return torch.Tensor(np_obj[:, :, :, np.newaxis].transpose((3, 2, 0, 1))) 68 | 69 | def tensor2tensorlab(image_tensor,to_norm=True,mc_only=False): 70 | # image tensor to lab tensor 71 | from skimage import color 72 | 73 | img = tensor2im(image_tensor) 74 | img_lab = color.rgb2lab(img) 75 | if(mc_only): 76 | img_lab[:,:,0] = img_lab[:,:,0]-50 77 | if(to_norm and not mc_only): 78 | img_lab[:,:,0] = img_lab[:,:,0]-50 79 | img_lab = img_lab/100. 80 | 81 | return np2tensor(img_lab) 82 | 83 | def tensorlab2tensor(lab_tensor,return_inbnd=False): 84 | from skimage import color 85 | import warnings 86 | warnings.filterwarnings("ignore") 87 | 88 | lab = tensor2np(lab_tensor)*100. 89 | lab[:,:,0] = lab[:,:,0]+50 90 | 91 | rgb_back = 255.*np.clip(color.lab2rgb(lab.astype('float')),0,1) 92 | if(return_inbnd): 93 | # convert back to lab, see if we match 94 | lab_back = color.rgb2lab(rgb_back.astype('uint8')) 95 | mask = 1.*np.isclose(lab_back,lab,atol=2.) 96 | mask = np2tensor(np.prod(mask,axis=2)[:,:,np.newaxis]) 97 | return (im2tensor(rgb_back),mask) 98 | else: 99 | return im2tensor(rgb_back) 100 | 101 | def rgb2lab(input): 102 | from skimage import color 103 | return color.rgb2lab(input / 255.) 104 | 105 | def tensor2im(image_tensor, imtype=np.uint8, cent=1., factor=255./2.): 106 | image_numpy = image_tensor[0].cpu().float().numpy() 107 | image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + cent) * factor 108 | return image_numpy.astype(imtype) 109 | 110 | def im2tensor(image, imtype=np.uint8, cent=1., factor=255./2.): 111 | return torch.Tensor((image / factor - cent) 112 | [:, :, :, np.newaxis].transpose((3, 2, 0, 1))) 113 | 114 | def tensor2vec(vector_tensor): 115 | return vector_tensor.data.cpu().numpy()[:, :, 0, 0] 116 | 117 | def voc_ap(rec, prec, use_07_metric=False): 118 | """ ap = voc_ap(rec, prec, [use_07_metric]) 119 | Compute VOC AP given precision and recall. 120 | If use_07_metric is true, uses the 121 | VOC 07 11 point method (default:False). 122 | """ 123 | if use_07_metric: 124 | # 11 point metric 125 | ap = 0. 126 | for t in np.arange(0., 1.1, 0.1): 127 | if np.sum(rec >= t) == 0: 128 | p = 0 129 | else: 130 | p = np.max(prec[rec >= t]) 131 | ap = ap + p / 11. 
    else:
        # correct AP calculation
        # first append sentinel values at the end
        mrec = np.concatenate(([0.], rec, [1.]))
        mpre = np.concatenate(([0.], prec, [0.]))

        # compute the precision envelope
        for i in range(mpre.size - 1, 0, -1):
            mpre[i - 1] = np.maximum(mpre[i - 1], mpre[i])

        # to calculate area under PR curve, look for points
        # where X axis (recall) changes value
        i = np.where(mrec[1:] != mrec[:-1])[0]

        # and sum (\Delta recall) * prec
        ap = np.sum((mrec[i + 1] - mrec[i]) * mpre[i + 1])
    return ap


def tensor2im(image_tensor, imtype=np.uint8, cent=1., factor=255. / 2.):
    # def tensor2im(image_tensor, imtype=np.uint8, cent=1., factor=1.):
    image_numpy = image_tensor[0].cpu().float().numpy()
    image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + cent) * factor
    return image_numpy.astype(imtype)


def im2tensor(image, imtype=np.uint8, cent=1., factor=255. / 2.):
    # def im2tensor(image, imtype=np.uint8, cent=1., factor=1.):
    return torch.Tensor((image / factor - cent)
                        [:, :, :, np.newaxis].transpose((3, 2, 0, 1)))

--------------------------------------------------------------------------------
/utils/LPIPS/models/base_model.py:
--------------------------------------------------------------------------------
import os
import numpy as np  # needed by save_done() below
import torch
from torch.autograd import Variable
from pdb import set_trace as st
# from IPython import embed


class BaseModel():
    def __init__(self):
        pass

    def name(self):
        return 'BaseModel'

    def initialize(self, use_gpu=True, gpu_ids=[0]):
        self.use_gpu = use_gpu
        self.gpu_ids = gpu_ids

    def forward(self):
        pass

    def get_image_paths(self):
        pass

    def optimize_parameters(self):
        pass

    def get_current_visuals(self):
        return self.input

    def get_current_errors(self):
        return {}

    def save(self, label):
        pass

    # helper saving function that can be used by subclasses
    def save_network(self, network, path, network_label, epoch_label):
        save_filename = '%s_net_%s.pth' % (epoch_label, network_label)
        save_path = os.path.join(path, save_filename)
        torch.save(network.state_dict(), save_path)

    # helper loading function that can be used by subclasses
    def load_network(self, network, network_label, epoch_label):
        save_filename = '%s_net_%s.pth' % (epoch_label, network_label)
        save_path = os.path.join(self.save_dir, save_filename)
        print('Loading network from %s' % save_path)
        network.load_state_dict(torch.load(save_path))

    def update_learning_rate():
        pass

    def get_image_paths(self):
        return self.image_paths

    def save_done(self, flag=False):
        np.save(os.path.join(self.save_dir, 'done_flag'), flag)
        np.savetxt(os.path.join(self.save_dir, 'done_flag'), [flag, ], fmt='%i')

--------------------------------------------------------------------------------
/utils/LPIPS/models/dist_model.py:
--------------------------------------------------------------------------------

from __future__ import absolute_import

import numpy as np
import torch
import os
from collections import OrderedDict
from torch.autograd import Variable
from .base_model import BaseModel
from scipy.ndimage import zoom
from tqdm import tqdm

# from IPython import embed

from .
import networks_basic as networks 16 | from .. import models as util 17 | 18 | 19 | class DistModel(BaseModel): 20 | def name(self): 21 | return self.model_name 22 | 23 | def initialize(self, model='net-lin', net='alex', colorspace='Lab', pnet_rand=False, pnet_tune=False, model_path=None, 24 | use_gpu=True, printNet=False, spatial=False, 25 | is_train=False, lr=.0001, beta1=0.5, version='0.1', gpu_ids=[0]): 26 | ''' 27 | INPUTS 28 | model - ['net-lin'] for linearly calibrated network 29 | ['net'] for off-the-shelf network 30 | ['L2'] for L2 distance in Lab colorspace 31 | ['SSIM'] for ssim in RGB colorspace 32 | net - ['squeeze','alex','vgg'] 33 | model_path - if None, will look in weights/[NET_NAME].pth 34 | colorspace - ['Lab','RGB'] colorspace to use for L2 and SSIM 35 | use_gpu - bool - whether or not to use a GPU 36 | printNet - bool - whether or not to print network architecture out 37 | spatial - bool - whether to output an array containing varying distances across spatial dimensions 38 | is_train - bool - [True] for training mode 39 | lr - float - initial learning rate 40 | beta1 - float - initial momentum term for adam 41 | version - 0.1 for latest, 0.0 was original (with a bug) 42 | gpu_ids - int array - [0] by default, gpus to use 43 | ''' 44 | BaseModel.initialize(self, use_gpu=use_gpu, gpu_ids=gpu_ids) 45 | 46 | self.model = model 47 | self.net = net 48 | self.is_train = is_train 49 | self.spatial = spatial 50 | self.gpu_ids = gpu_ids 51 | self.model_name = '%s [%s]'%(model,net) 52 | 53 | if(self.model == 'net-lin'): # pretrained net + linear layer 54 | self.net = networks.PNetLin(pnet_rand=pnet_rand, pnet_tune=pnet_tune, pnet_type=net, 55 | use_dropout=True, spatial=spatial, version=version, lpips=True) 56 | kw = {} 57 | if not use_gpu: 58 | kw['map_location'] = 'cpu' 59 | if(model_path is None): 60 | import inspect 61 | model_path = os.path.abspath(os.path.join(inspect.getfile(self.initialize), '..', 'weights/v%s/%s.pth'%(version,net))) 62 | 63 | if(not is_train): 64 | print('Loading model from: %s'%model_path) 65 | self.net.load_state_dict(torch.load(model_path, **kw), strict=False) 66 | 67 | elif(self.model=='net'): # pretrained network 68 | self.net = networks.PNetLin(pnet_rand=pnet_rand, pnet_type=net, lpips=False) 69 | elif(self.model in ['L2','l2']): 70 | self.net = networks.L2(use_gpu=use_gpu,colorspace=colorspace) # not really a network, only for testing 71 | self.model_name = 'L2' 72 | elif(self.model in ['DSSIM','dssim','SSIM','ssim']): 73 | self.net = networks.DSSIM(use_gpu=use_gpu,colorspace=colorspace) 74 | self.model_name = 'SSIM' 75 | else: 76 | raise ValueError("Model [%s] not recognized." 
% self.model) 77 | 78 | self.parameters = list(self.net.parameters()) 79 | 80 | if self.is_train: # training mode 81 | # extra network on top to go from distances (d0,d1) => predicted human judgment (h*) 82 | self.rankLoss = networks.BCERankingLoss() 83 | self.parameters += list(self.rankLoss.net.parameters()) 84 | self.lr = lr 85 | self.old_lr = lr 86 | self.optimizer_net = torch.optim.Adam(self.parameters, lr=lr, betas=(beta1, 0.999)) 87 | else: # test mode 88 | self.net.eval() 89 | 90 | if(use_gpu): 91 | self.net.to(gpu_ids[0]) 92 | self.net = torch.nn.DataParallel(self.net, device_ids=gpu_ids) 93 | if(self.is_train): 94 | self.rankLoss = self.rankLoss.to(device=gpu_ids[0]) # just put this on GPU0 95 | 96 | if(printNet): 97 | print('---------- Networks initialized -------------') 98 | networks.print_network(self.net) 99 | print('-----------------------------------------------') 100 | 101 | def forward(self, in0, in1, retPerLayer=False): 102 | ''' Function computes the distance between image patches in0 and in1 103 | INPUTS 104 | in0, in1 - torch.Tensor object of shape Nx3xXxY - image patch scaled to [-1,1] 105 | OUTPUT 106 | computed distances between in0 and in1 107 | ''' 108 | 109 | return self.net.forward(in0, in1, retPerLayer=retPerLayer) 110 | 111 | # ***** TRAINING FUNCTIONS ***** 112 | def optimize_parameters(self): 113 | self.forward_train() 114 | self.optimizer_net.zero_grad() 115 | self.backward_train() 116 | self.optimizer_net.step() 117 | self.clamp_weights() 118 | 119 | def clamp_weights(self): 120 | for module in self.net.modules(): 121 | if(hasattr(module, 'weight') and module.kernel_size==(1,1)): 122 | module.weight.data = torch.clamp(module.weight.data,min=0) 123 | 124 | def set_input(self, data): 125 | self.input_ref = data['ref'] 126 | self.input_p0 = data['p0'] 127 | self.input_p1 = data['p1'] 128 | self.input_judge = data['judge'] 129 | 130 | if(self.use_gpu): 131 | self.input_ref = self.input_ref.to(device=self.gpu_ids[0]) 132 | self.input_p0 = self.input_p0.to(device=self.gpu_ids[0]) 133 | self.input_p1 = self.input_p1.to(device=self.gpu_ids[0]) 134 | self.input_judge = self.input_judge.to(device=self.gpu_ids[0]) 135 | 136 | self.var_ref = Variable(self.input_ref,requires_grad=True) 137 | self.var_p0 = Variable(self.input_p0,requires_grad=True) 138 | self.var_p1 = Variable(self.input_p1,requires_grad=True) 139 | 140 | def forward_train(self): # run forward pass 141 | # print(self.net.module.scaling_layer.shift) 142 | # print(torch.norm(self.net.module.net.slice1[0].weight).item(), torch.norm(self.net.module.lin0.model[1].weight).item()) 143 | 144 | self.d0 = self.forward(self.var_ref, self.var_p0) 145 | self.d1 = self.forward(self.var_ref, self.var_p1) 146 | self.acc_r = self.compute_accuracy(self.d0,self.d1,self.input_judge) 147 | 148 | self.var_judge = Variable(1.*self.input_judge).view(self.d0.size()) 149 | 150 | self.loss_total = self.rankLoss.forward(self.d0, self.d1, self.var_judge*2.-1.) 

        return self.loss_total

    def backward_train(self):
        torch.mean(self.loss_total).backward()

    def compute_accuracy(self, d0, d1, judge):
        ''' d0, d1 are Variables, judge is a Tensor '''
        d1_lt_d0 = (d1 < d0).cpu().data.numpy().flatten()
        judge_per = judge.cpu().numpy().flatten()
        return d1_lt_d0 * judge_per + (1 - d1_lt_d0) * (1 - judge_per)

    def get_current_errors(self):
        retDict = OrderedDict([('loss_total', self.loss_total.data.cpu().numpy()),
                               ('acc_r', self.acc_r)])
        for key in retDict.keys():
            retDict[key] = np.mean(retDict[key])
        return retDict

    def get_current_visuals(self):
        zoom_factor = 256 / self.var_ref.data.size()[2]
        ref_img = util.tensor2im(self.var_ref.data)
        p0_img = util.tensor2im(self.var_p0.data)
        p1_img = util.tensor2im(self.var_p1.data)
        ref_img_vis = zoom(ref_img, [zoom_factor, zoom_factor, 1], order=0)
        p0_img_vis = zoom(p0_img, [zoom_factor, zoom_factor, 1], order=0)
        p1_img_vis = zoom(p1_img, [zoom_factor, zoom_factor, 1], order=0)
        return OrderedDict([('ref', ref_img_vis),
                            ('p0', p0_img_vis),
                            ('p1', p1_img_vis)])

    def save(self, path, label):
        if self.use_gpu:
            self.save_network(self.net.module, path, '', label)
        else:
            self.save_network(self.net, path, '', label)
        self.save_network(self.rankLoss.net, path, 'rank', label)

    def update_learning_rate(self, nepoch_decay):
        lrd = self.lr / nepoch_decay
        lr = self.old_lr - lrd
        for param_group in self.optimizer_net.param_groups:
            param_group['lr'] = lr
        print('update lr [%s] decay: %f -> %f' % (type, self.old_lr, lr))
        self.old_lr = lr


def score_2afc_dataset(data_loader, func, name=''):
    ''' Function computes Two Alternative Forced Choice (2AFC) score using
        distance function 'func' in dataset 'data_loader'
    INPUTS
        data_loader - CustomDatasetDataLoader object - contains a TwoAFCDataset inside
        func - callable distance function - calling d=func(in0,in1) should take 2
            pytorch tensors with shape Nx3xXxY, and return numpy array of length N
    OUTPUTS
        [0] - 2AFC score in [0,1], fraction of time func agrees with human evaluators
        [1] - dictionary with following elements
            d0s,d1s - N arrays containing distances between reference patch to perturbed patches
            gts - N array in [0,1], preferred patch selected by human evaluators
                (closer to "0" for left patch p0, "1" for right patch p1,
                "0.6" means 60pct people preferred right patch, 40pct preferred left)
            scores - N array in [0,1], corresponding to what percentage function agreed with humans
    CONSTS
        N - number of test triplets in data_loader
    '''

    d0s = []
    d1s = []
    gts = []

    for data in tqdm(data_loader.load_data(), desc=name):
        d0s += func(data['ref'], data['p0']).data.cpu().numpy().flatten().tolist()
        d1s += func(data['ref'], data['p1']).data.cpu().numpy().flatten().tolist()
        gts += data['judge'].cpu().numpy().flatten().tolist()

    d0s = np.array(d0s)
    d1s = np.array(d1s)
    gts = np.array(gts)
    scores = (d0s < d1s) * (1. - gts) + (d1s < d0s) * gts + (d1s == d0s) * .5

    return np.mean(scores), dict(d0s=d0s, d1s=d1s, gts=gts, scores=scores)

--------------------------------------------------------------------------------
/utils/NIQE/niqe.py:
--------------------------------------------------------------------------------
import math
import numpy as np
import scipy.io
import scipy.ndimage
import scipy.special
import scipy.linalg
from os.path import dirname, join
from PIL import Image

gamma_range = np.arange(0.2, 10, 0.001)
a = scipy.special.gamma(2.0 / gamma_range)
a *= a
b = scipy.special.gamma(1.0 / gamma_range)
c = scipy.special.gamma(3.0 / gamma_range)
prec_gammas = a / (b * c)


def aggd_features(imdata):
    # flatten imdata
    imdata.shape = (len(imdata.flat),)
    imdata2 = imdata * imdata
    left_data = imdata2[imdata < 0]
    right_data = imdata2[imdata >= 0]
    left_mean_sqrt = 0
    right_mean_sqrt = 0
    if len(left_data) > 0:
        left_mean_sqrt = np.sqrt(np.average(left_data))
    if len(right_data) > 0:
        right_mean_sqrt = np.sqrt(np.average(right_data))

    if right_mean_sqrt != 0:
        gamma_hat = left_mean_sqrt / right_mean_sqrt
    else:
        gamma_hat = np.inf
    # solve r-hat norm

    imdata2_mean = np.mean(imdata2)
    if imdata2_mean != 0:
        r_hat = (np.average(np.abs(imdata)) ** 2) / (np.average(imdata2))
    else:
        r_hat = np.inf
    rhat_norm = r_hat * (((math.pow(gamma_hat, 3) + 1) * (gamma_hat + 1)) / math.pow(math.pow(gamma_hat, 2) + 1, 2))

    # solve alpha by guessing values that minimize ro
    pos = np.argmin((prec_gammas - rhat_norm) ** 2)
    alpha = gamma_range[pos]

    gam1 = scipy.special.gamma(1.0 / alpha)
    gam2 = scipy.special.gamma(2.0 / alpha)
    gam3 = scipy.special.gamma(3.0 / alpha)

    aggdratio = np.sqrt(gam1) / np.sqrt(gam3)
    bl = aggdratio * left_mean_sqrt
    br = aggdratio * right_mean_sqrt

    # mean parameter
    N = (br - bl) * (gam2 / gam1)  # *aggdratio
    return (alpha, N, bl, br, left_mean_sqrt, right_mean_sqrt)


def ggd_features(imdata):
    nr_gam = 1 / prec_gammas
    sigma_sq = np.var(imdata)
    E = np.mean(np.abs(imdata))
    rho = sigma_sq / E ** 2
    pos = np.argmin(np.abs(nr_gam - rho))
    return gamma_range[pos], sigma_sq


def paired_product(new_im):
    shift1 = np.roll(new_im.copy(), 1, axis=1)
    shift2 = np.roll(new_im.copy(), 1, axis=0)
    shift3 = np.roll(np.roll(new_im.copy(), 1, axis=0), 1, axis=1)
    shift4 = np.roll(np.roll(new_im.copy(), 1, axis=0), -1, axis=1)

    H_img = shift1 * new_im
    V_img = shift2 * new_im
    D1_img = shift3 * new_im
    D2_img = shift4 * new_im

    return (H_img, V_img, D1_img, D2_img)


def gen_gauss_window(lw, sigma):
    sd = np.float32(sigma)
    lw = int(lw)
    weights = [0.0] * (2 * lw + 1)
    weights[lw] = 1.0
    sum = 1.0
    sd *= sd
    for ii in range(1, lw + 1):
        tmp = np.exp(-0.5 * np.float32(ii * ii) / sd)
        weights[lw + ii] = tmp
        weights[lw - ii] = tmp
        sum += 2.0 * tmp
    for ii in range(2 * lw + 1):
        weights[ii] /= sum
    return weights


def compute_image_mscn_transform(image, C=1, avg_window=None, extend_mode='constant'):
    if avg_window is None:
        avg_window = gen_gauss_window(3, 7.0 / 6.0)
    assert len(np.shape(image)) == 2
    h, w = np.shape(image)
    mu_image = np.zeros((h, w), dtype=np.float32)
    var_image = np.zeros((h, w), dtype=np.float32)
    image = np.array(image).astype('float32')
    scipy.ndimage.correlate1d(image, avg_window, 0, mu_image, mode=extend_mode)
    scipy.ndimage.correlate1d(mu_image, avg_window, 1, mu_image, mode=extend_mode)
    scipy.ndimage.correlate1d(image ** 2, avg_window, 0, var_image, mode=extend_mode)
    scipy.ndimage.correlate1d(var_image, avg_window, 1, var_image, mode=extend_mode)
    var_image = np.sqrt(np.abs(var_image - mu_image ** 2))
    return (image - mu_image) / (var_image + C), var_image, mu_image


def _niqe_extract_subband_feats(mscncoefs):
    # alpha_m, = extract_ggd_features(mscncoefs)
    alpha_m, N, bl, br, lsq, rsq = aggd_features(mscncoefs.copy())
    pps1, pps2, pps3, pps4 = paired_product(mscncoefs)
    alpha1, N1, bl1, br1, lsq1, rsq1 = aggd_features(pps1)
    alpha2, N2, bl2, br2, lsq2, rsq2 = aggd_features(pps2)
    alpha3, N3, bl3, br3, lsq3, rsq3 = aggd_features(pps3)
    alpha4, N4, bl4, br4, lsq4, rsq4 = aggd_features(pps4)
    return np.array([alpha_m, (bl + br) / 2.0,
                     alpha1, N1, bl1, br1,  # (V)
                     alpha2, N2, bl2, br2,  # (H)
                     alpha3, N3, bl3, bl3,  # (D1)
                     alpha4, N4, bl4, bl4,  # (D2)
                     ])


def get_patches_train_features(img, patch_size, stride=8):
    return _get_patches_generic(img, patch_size, 1, stride)


def get_patches_test_features(img, patch_size, stride=8):
    return _get_patches_generic(img, patch_size, 0, stride)


def extract_on_patches(img, patch_size):
    h, w = img.shape
    patch_size = int(patch_size)  # np.int is deprecated; plain int is equivalent here
    patches = []
    for j in range(0, h - patch_size + 1, patch_size):
        for i in range(0, w - patch_size + 1, patch_size):
            patch = img[j:j + patch_size, i:i + patch_size]
            patches.append(patch)

    patches = np.array(patches)

    patch_features = []
    for p in patches:
        patch_features.append(_niqe_extract_subband_feats(p))
    patch_features = np.array(patch_features)

    return patch_features


def _get_patches_generic(img, patch_size, is_train, stride):
    h, w = np.shape(img)
    if h < patch_size or w < patch_size:
        print("Input image is too small")
        exit(0)

    # ensure that the patch divides evenly into img
    hoffset = (h % patch_size)
    woffset = (w % patch_size)

    if hoffset > 0:
        img = img[:-hoffset, :]
    if woffset > 0:
        img = img[:, :-woffset]

    img = img.astype(np.float32)
    # img2 = scipy.misc.imresize(img, 0.5, interp='bicubic', mode='F')
    _h, _w = img.shape
180 | img2 = np.array(Image.fromarray(img).resize((int(_w * 0.5), int(_h * 0.5)), resample=Image.BICUBIC)) 181 | 182 | mscn1, var, mu = compute_image_mscn_transform(img) 183 | mscn1 = mscn1.astype(np.float32) 184 | 185 | mscn2, _, _ = compute_image_mscn_transform(img2) 186 | mscn2 = mscn2.astype(np.float32) 187 | 188 | feats_lvl1 = extract_on_patches(mscn1, patch_size) 189 | feats_lvl2 = extract_on_patches(mscn2, patch_size / 2) 190 | 191 | feats = np.hstack((feats_lvl1, feats_lvl2)) # feats_lvl3)) 192 | 193 | return feats 194 | 195 | 196 | def niqe(inputImgData): 197 | patch_size = 96 198 | module_path = dirname(__file__) 199 | 200 | # TODO: memoize 201 | params = scipy.io.loadmat(join(module_path, 'niqe_image_params.mat')) 202 | pop_mu = np.ravel(params["pop_mu"]) 203 | pop_cov = params["pop_cov"] 204 | 205 | inputImgData = inputImgData.astype('float32') 206 | M, N = inputImgData.shape 207 | 208 | # assert C == 1, "niqe called with videos containing %d channels. Please supply only the luminance channel" % (C,) 209 | assert M > ( 210 | patch_size * 2 + 1), "niqe called with small frame size, requires > 192x192 resolution video using current training parameters" 211 | assert N > ( 212 | patch_size * 2 + 1), "niqe called with small frame size, requires > 192x192 resolution video using current training parameters" 213 | 214 | feats = get_patches_test_features(inputImgData, patch_size) 215 | sample_mu = np.mean(feats, axis=0) 216 | sample_cov = np.cov(feats.T) 217 | 218 | X = sample_mu - pop_mu 219 | covmat = ((pop_cov + sample_cov) / 2.0) 220 | pinvmat = scipy.linalg.pinv(covmat) 221 | niqe_score = np.sqrt(np.dot(np.dot(X, pinvmat), X)) 222 | 223 | return niqe_score 224 | -------------------------------------------------------------------------------- /utils/NIQE/niqe_image_params.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/csbhr/OpenUtility/c9cf713c99523c0a2e0be6c2afa988af751ad161/utils/NIQE/niqe_image_params.mat -------------------------------------------------------------------------------- /utils/ResizeRight/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Assaf Shocher 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
 22 | 
--------------------------------------------------------------------------------
/utils/ResizeRight/README.md:
--------------------------------------------------------------------------------
 1 | # ResizeRight
 2 | This is a resizing package for images or tensors that supports both Numpy and PyTorch (**fully differentiable**) seamlessly. The main motivation for creating this is to address some **crucial incorrectness issues** (see item 3 in the list below) that exist in all other resizing packages I am aware of. As far as I know, it is the only one that performs correctly in all cases. ResizeRight is specially made for machine learning, image enhancement and restoration challenges.
 3 | 
 4 | The code is inspired by MATLAB's imresize function, but with crucial differences. It is specifically useful due to the following reasons:
 5 | 
 6 | 1. ResizeRight produces results **identical to MATLAB for the simple cases** (scale_factor * in_size is integer). None of the Python packages I am aware of currently resizes images with results similar to MATLAB's imresize (which is a common benchmark for image restoration tasks, especially super-resolution).
 7 | 
 8 | 2. No other **differentiable** method I am aware of supports **AntiAliasing** as in MATLAB. Actually, very few non-differentiable ones, including popular ones, do. This causes artifacts and inconsistency in downscaling (see [this tweet](https://twitter.com/jaakkolehtinen/status/1258102168176951299) by [Jaakko Lehtinen](https://users.aalto.fi/~lehtinj7/)
 9 | for example).
10 | 
11 | 3. The most important part: in the general case where scale_factor * in_size is non-integer, **no existing resizing method I am aware of (including MATLAB) performs consistently.** ResizeRight is accurate and consistent thanks to its ability to process **both the scale-factor and the output-size** provided by the user. This is a crucial feature for super-resolution and learning. One must acknowledge that the same output-size can result from varying scale-factors, since output-size is usually determined by *ceil(input_size * scale_factor)*. This situation creates a dangerous lack of consistency. It is best explained by example: say you have an image of size 9x9 and you resize by a scale-factor of 0.5. The result size is 5x5. Now you resize with a scale-factor of 2 and get a result sized 10x10. "No big deal", you must be thinking, "I can resize it according to output-size 9x9", right? But then you will not get the correct scale-factor, which is calculated as output-size / input-size = 1.8. A short code sketch of this round trip appears right after this list.
12 | Due to a simple observation regarding the projection of the output grid onto the input grid, ResizeRight is the only method that consistently keeps the image centered, as in optical zoom, while complying with the exact scale-factor and output size the user requires.
13 | This is one of the main reasons for creating this repository. This downscale-upscale consistency is often crucial for learning-based tasks (e.g. ["Zero-Shot Super-Resolution"](http://www.wisdom.weizmann.ac.il/~vision/zssr/)), and does not exist in other Python packages nor in MATLAB.
14 | 
15 | 4. Misalignment in resizing is a pandemic! Many existing packages actually return misaligned results. It is not visually apparent but can cause great damage to image enhancement tasks (for example, see [how tensorflow's image resize stole 60 days of my life](https://hackernoon.com/how-tensorflows-tf-image-resize-stole-60-days-of-my-life-aba5eb093f35)).
I personally also suffered many unfortunate consequences of such misalignment before and throughout making this method.
16 | 
17 | 5. Resizing supports **both Numpy and PyTorch** tensors seamlessly, determined by the type of the input tensor. Results are checked to be identical in both modes, so you can safely apply it to different tensor types and maintain consistency. No Numpy <-> Torch conversion takes place at any step. The process is done exclusively with one of the frameworks. No direct dependency is needed, so you can run it without having PyTorch installed at all, or without Numpy. You only need one of them.
18 | 
19 | 6. In the case where scale_factor * in_size is a rational number with a denominator that is not too big (this is a parameter), the calculation is done efficiently based on convolutions (currently only PyTorch is supported). This is far more efficient for big tensors and suitable for working on large batches or high resolutions. Note that this efficient calculation can be applied to the dims that maintain the conditions while performing the regular calculation for the other dims.
20 | 
21 | 7. Differently from some existing methods, including MATLAB, you can **resize N-D tensors in M-D dimensions** for any M<=N.
22 | 
23 | 8. You can specify a list of scale-factors to resize each dimension using a different scale-factor.
24 | 
25 | 9. You can easily add and embed your own interpolation methods for the resizer to use (see interp_methods.py).
26 | 
27 | 10. All padding methods of the underlying framework are supported (depending on Numpy/PyTorch mode).
28 | PyTorch: 'constant', 'reflect', 'replicate', 'circular'.
29 | Numpy: ‘constant’, ‘edge’, ‘linear_ramp’, ‘maximum’, ‘mean’, ‘median’, ‘minimum’, ‘reflect’, ‘symmetric’, ‘wrap’, ‘empty’
30 | 
31 | 11. Some general calculations are done more efficiently than in the MATLAB version (one example is that MATLAB extends the kernel size by 2 and then searches for zero columns in the weights and cancels them. ResizeRight uses the observation that resizing is actually a continuous convolution and avoids these redundancies ahead of time, see Shocher et al. ["From Discrete to Continuous Convolution Layers"](https://arxiv.org/abs/2006.11120)).
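Item 3's round trip, as a minimal sketch (Numpy mode; the import path assumes this repository's layout, and the shapes in the comments follow the 9x9 example above):
```
import numpy as np
from utils.ResizeRight import resize_right

im = np.random.rand(9, 9, 3)
small = resize_right.resize(im, scale_factors=0.5)  # 5x5x3, since ceil(9 * 0.5) = 5
naive = resize_right.resize(small, scale_factors=2)  # 10x10x3: the inconsistent round trip
# supplying both the scale-factor and the output-size keeps the round trip consistent
back = resize_right.resize(small, scale_factors=2, out_shape=(9, 9))  # 9x9x3
```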
32 | --------
33 | 
34 | ### Usage:
35 | For dynamic resize using either Numpy or PyTorch:
36 | ```
37 | resize_right.resize(input, scale_factors=None, out_shape=None,
38 |                     interp_method=interp_methods.cubic, support_sz=None,
39 |                     antialiasing=True, by_convs=False, scale_tolerance=None,
40 |                     max_denominator=10, pad_mode='constant')
41 | ```
42 | 
43 | __input__ :
44 | the input image/tensor, a Numpy or Torch tensor.
45 | 
46 | __scale_factors__:
47 | can be specified as:
48 | 1. one scalar scale - then it will be assumed that you want to resize the first two dims with this scale for Numpy, or the last two dims for PyTorch.
49 | 2. a list or tuple of scales - one for each dimension you want to resize. Note: if the length of the list is L, then the first L dims will be rescaled for Numpy and the last L for PyTorch.
50 | 3. not specified - then it will be calculated using out_shape. This is not recommended (see advantage 3 in the list above).
51 | 
52 | __out_shape__:
53 | A list or tuple. If shorter than input.shape, then only the first/last (depending on np/torch) dims are resized. If not specified, it can be calculated from scale_factors.
54 | 
55 | __interp_method__:
56 | The type of interpolation used to calculate the weights. This is a scalar-to-scalar function that can be applied to tensors pointwise. The classical methods are implemented and can be found in interp_methods.py (cubic, linear, lanczos2, lanczos3, box).
57 | 
58 | __support_sz__:
59 | This is the support of the interpolation function, i.e. the length of the non-zero segment over its 1d input domain. This is a characteristic of the function, e.g. for bicubic 4, linear 2, lanczos2 4, lanczos3 6, box 1.
60 | 
61 | __antialiasing__:
62 | This is an option similar to MATLAB's default. It is only relevant for downscaling. If true, it basically means that the kernel is stretched by 1/scale_factor to prevent aliasing (low-pass filtering).
63 | 
64 | __by_convs__:
65 | This determines whether to allow efficient calculation using convolutions, according to the tolerance. This feature should be used when scale_factor * in_size is rational with a denominator low enough (or close enough to being an integer) and the tensors are big (batches or high resolution).
66 | 
67 | __scale_tolerance__:
68 | This is the allowed distance between the closest rational fraction M/N and the float scale_factor provided. If the fraction is closer than this distance, it will be used and the efficient convolution calculation will take place.
69 | 
70 | __max_denominator__:
71 | When by_convs is on, the scale_factor is translated to a rational fraction M/N, where the denominator N is limited by this parameter. The goal is to make the calculation more efficient. The number of convolutions used is the size of the denominator.
72 | 
73 | __pad_mode__:
74 | This can be used according to the padding methods of each framework.
75 | PyTorch: 'constant', 'reflect', 'replicate', 'circular'.
76 | Numpy: ‘constant’, ‘edge’, ‘linear_ramp’, ‘maximum’, ‘mean’, ‘median’, ‘minimum’, ‘reflect’, ‘symmetric’, ‘wrap’, ‘empty’
77 | 
78 | --------
79 | 
80 | ### Cite / credit
81 | If you find our work useful in your research or publication, please cite this work:
82 | ```
83 | @misc{ResizeRight,
84 |   author = {Shocher, Assaf},
85 |   title = {ResizeRight},
86 |   year = {2018},
87 |   publisher = {GitHub},
88 |   journal = {GitHub repository},
89 |   howpublished = {\url{https://github.com/assafshocher/ResizeRight}},
90 | }
91 | ```
92 | 
--------------------------------------------------------------------------------
/utils/ResizeRight/interp_methods.py:
--------------------------------------------------------------------------------
 1 | from math import pi
 2 | 
 3 | try:
 4 |     import torch
 5 | except ImportError:
 6 |     torch = None
 7 | 
 8 | try:
 9 |     import numpy
10 | except ImportError:
11 |     numpy = None
12 | 
13 | if numpy is None and torch is None:
14 |     raise ImportError("Must have either Numpy or PyTorch but both not found")
15 | 
16 | 
17 | def set_framework_dependencies(x):
18 |     if type(x) is numpy.ndarray:
19 |         to_dtype = lambda a: a
20 |         fw = numpy
21 |     else:
22 |         to_dtype = lambda a: a.to(x.dtype)
23 |         fw = torch
24 |     eps = fw.finfo(fw.float32).eps
25 |     return fw, to_dtype, eps
26 | 
27 | 
28 | def support_sz(sz):
29 |     def wrapper(f):
30 |         f.support_sz = sz
31 |         return f
32 |     return wrapper
33 | 
34 | 
35 | @support_sz(4)
36 | def cubic(x):
37 |     fw, to_dtype, eps = set_framework_dependencies(x)
38 |     absx = fw.abs(x)
39 |     absx2 = absx ** 2
40 |     absx3 = absx ** 3
41 |     return ((1.5 * absx3 - 2.5 * absx2 + 1.) * to_dtype(absx <= 1.) +
42 |             (-0.5 * absx3 + 2.5 * absx2 - 4. * absx + 2.) *
43 |             to_dtype((1. < absx) & (absx <= 2.)))
44 | 
45 | 
46 | @support_sz(4)
47 | def lanczos2(x):
48 |     fw, to_dtype, eps = set_framework_dependencies(x)
49 |     return (((fw.sin(pi * x) * fw.sin(pi * x / 2) + eps) /
50 |              ((pi**2 * x**2 / 2) + eps)) * to_dtype(abs(x) < 2))
51 | 
52 | 
53 | @support_sz(6)
54 | def lanczos3(x):
55 |     fw, to_dtype, eps = set_framework_dependencies(x)
56 |     return (((fw.sin(pi * x) * fw.sin(pi * x / 3) + eps) /
57 |              ((pi**2 * x**2 / 3) + eps)) * to_dtype(abs(x) < 3))
58 | 
59 | 
60 | @support_sz(2)
61 | def linear(x):
62 |     fw, to_dtype, eps = set_framework_dependencies(x)
63 |     return ((x + 1) * to_dtype((-1 <= x) & (x < 0)) + (1 - x) *
64 |             to_dtype((0 <= x) & (x <= 1)))
65 | 
66 | 
67 | @support_sz(1)
68 | def box(x):
69 |     fw, to_dtype, eps = set_framework_dependencies(x)
70 |     return to_dtype((-1 <= x) & (x < 0)) + to_dtype((0 <= x) & (x <= 1))
71 | 
--------------------------------------------------------------------------------
/utils/create_gif.py:
--------------------------------------------------------------------------------
 1 | import os
 2 | import numpy as np
 3 | import imageio
 4 | 
 5 | 
 6 | frame_root_1 = '/home/csbhr/Downloads/ours'
 7 | frame_root_2 = '/home/csbhr/Downloads/input'
 8 | gif_path = '/home/csbhr/Downloads/visual.gif'
 9 | 
10 | frames_fn_1 = sorted(os.listdir(frame_root_1))
11 | frames_fn_2 = sorted(os.listdir(frame_root_2))
12 | 
13 | tmp_img = imageio.imread(os.path.join(frame_root_1, frames_fn_1[0]))
14 | H, W, C = tmp_img.shape
15 | gif = np.zeros(tmp_img.shape, dtype=tmp_img.dtype)
16 | line = np.zeros(tmp_img.shape, dtype=tmp_img.dtype)
17 | line[:, :, 0] = 0
18 | line[:, :, 1] = 255
19 | line[:, :, 2] = 0
20 | 
21 | nl = 30
22 | wl = 3
23 | gif_images = []
24 | for i, (fn1, fn2) in enumerate(zip(frames_fn_1, frames_fn_2)):
25 |     frame_1 = imageio.imread(os.path.join(frame_root_1, fn1))
26 |     gif = frame_1
27 |     if i < nl:
28 |         frame_2 = imageio.imread(os.path.join(frame_root_2, fn2))
29 |         gif[:, :W-(W // nl * i), :] = frame_2[:, :W-(W // nl * i), :]
30 |         gif[:, W-(W // nl * i + wl):W-(W // nl * i), :] = line[:, W-(W // nl * i + wl):W-(W // nl * i), :]
31 |     gif_images.append(gif.copy())
32 | imageio.mimsave(gif_path, gif_images, fps=4)
33 | 
34 | 
35 | # frames_root = '/home/csbhr/Downloads/(TPAMI) Self-Supervised Deep Blind Video Super-Resolution Supplemental Material/figures/videos'
36 | # gif_root = '/home/csbhr/Downloads/(TPAMI) Self-Supervised Deep Blind Video Super-Resolution Supplemental Material/figures'
37 | # 
38 | # methods = sorted(os.listdir(frames_root))
39 | # for me in methods:
40 | #     frames_fn = sorted(os.listdir(os.path.join(frames_root, me)))
41 | #     gif_images = []
42 | #     for fn in frames_fn:
43 | #         gif_images.append(imageio.imread(os.path.join(frames_root, me, fn)))
44 | #     imageio.mimsave(os.path.join(gif_root, "{}.gif".format(me)), gif_images, fps=2)
--------------------------------------------------------------------------------
/utils/dnn_utils.py:
--------------------------------------------------------------------------------
 1 | def cal_parmeters(model):
 2 |     params = list(model.parameters())
 3 |     total = 0
 4 |     for i in params:
 5 |         layer = 1
 6 |         for j in i.size():
 7 |             layer *= j
 8 |         print("layer shape: " + str(list(i.size())), "layer params: " + str(layer))
 9 |         total = total + layer
10 |     print("total params: %.2fM" % (total / 1e6))
--------------------------------------------------------------------------------
/utils/file_regroup_utils.py:
--------------------------------------------------------------------------------
 1 | import os
 2 | from base.os_base import handle_dir, get_fname_ext, copy_file, rename_file, glob_match
 3 | 
 4 | 
 5 | def remove_files_prefix(root, prefix=''):
 6 |     '''
 7 |     remove prefix from filename
 8 |     params:
 9 |         root: the dir of files that need to be processed
10 |         prefix: the prefix to be removed
11 |     '''
12 |     img_list = glob_match(os.path.join(root, "*"))
13 |     for im in img_list:
14 |         basename = os.path.basename(im)
15 |         now_prefix = basename[:len(prefix)]
16 |         if now_prefix == prefix:
17 |             dest_basename = basename[len(prefix):]
18 |             src = im
19 |             dst = os.path.join(root, dest_basename)
20 |             rename_file(src, dst)
21 | 
22 | 
23 | def remove_files_postfix(root, postfix=''):
24 |     '''
25 |     remove postfix from filename
26 |     params:
27 |         root: the dir of files that need to be processed
28 |         postfix: the postfix to be removed
29 |     '''
30 |     img_list = glob_match(os.path.join(root, "*"))
31 |     for im in img_list:
32 |         fname, ext = get_fname_ext(im)
33 |         now_postfix = fname[-len(postfix):]
34 |         if now_postfix == postfix:
35 |             dest_basename = "{}.{}".format(fname[:-len(postfix)], ext)
36 |             src = im
37 |             dst = os.path.join(root, dest_basename)
38 |             rename_file(src, dst)
39 | 
40 | 
41 | def add_files_postfix(root, postfix=''):
42 |     '''
43 |     add postfix to filename
44 |     params:
45 |         root: the dir of files that need to be processed
46 |         postfix: the postfix to be added
47 |     '''
48 |     img_list = glob_match(os.path.join(root, "*"))
49 |     for im in img_list:
50 |         fname, ext = get_fname_ext(im)
51 |         dest_basename = "{}{}.{}".format(fname, postfix, ext)
52 |         src = im
53 |         dst = os.path.join(root, dest_basename)
54 |         rename_file(src, dst)
55 | 
56 | 
57 | def extra_files_by_postfix(ori_root, dest_root, match_postfix='', new_postfix=None, match_ext='*'):
58 |     '''
59 |     extract files from ori_root to dest_root by match_postfix and match_ext
60 |     params:
61 |         ori_root: the dir of files that need to be processed
62 |         dest_root: the dir for saving matched files
63 |         match_postfix: the postfix to be matched
64 |         new_postfix: the postfix for matched files
65 |             default: None, that is keeping the ori postfix
66 |         match_ext: the ext to be matched
67 |     '''
68 |     if new_postfix is None:
69 |         new_postfix = match_postfix
70 | 
71 |     handle_dir(dest_root)
72 |     flag_img_list = glob_match(os.path.join(ori_root, "*{}.{}".format(match_postfix, match_ext)))
73 |     for im in flag_img_list:
74 |         fname, ext = get_fname_ext(im)
75 |         dest_basename = '{}{}.{}'.format(fname[:len(fname) - len(match_postfix)], new_postfix, ext)
76 |         src = im
77 |         dst = os.path.join(dest_root, dest_basename)
78 |         copy_file(src, dst)
79 | 
80 | 
81 | def resort_files_index(root, template='{:0>4}', start_idx=0):
82 |     '''
83 |     re-index files' filenames using the template, with the index starting from start_idx
84 |     params:
85 |         root: the dir of files that need to be processed
86 |         template: the template for processed filename
87 |         start_idx: the start index
88 |     '''
89 |     template = template + '.{}'
90 |     img_list = sorted(glob_match(os.path.join(root, "*")))
91 |     for i, im in enumerate(img_list):
92 |         fname, ext = get_fname_ext(im)
93 |         dest_basename = template.format(i + start_idx, ext)
94 |         src = im
95 |         dst = os.path.join(root, dest_basename)
96 |         rename_file(src, dst)
--------------------------------------------------------------------------------
/utils/image_crop_combine_utils.py:
--------------------------------------------------------------------------------
 1 | '''
 2 | require:
 3 |     filenames:
 4 |         filenames should not contain symbol "-"
 5 |     crop flags:
 6 |         the crop flag "x-x-x-x" is at the end of filename when cropping
 7 |         so, the crop
flag should be at the end of filename when combining 8 | 9 | ''' 10 | 11 | import cv2 12 | import os 13 | import numpy as np 14 | from base.image_base import evaluate_smooth 15 | from base.os_base import handle_dir, listdir, glob_match 16 | 17 | 18 | def crop_img_with_padding(img, min_size=(100, 100), padding=10): 19 | h, w, c = img.shape 20 | n_x, n_y = h // min_size[0], w // min_size[1] 21 | 22 | croped_imgs = {} 23 | for i in range(n_x): 24 | for j in range(n_y): 25 | xl, xr = i * min_size[0] - padding, (i + 1) * min_size[0] + padding 26 | yl, yr = j * min_size[0] - padding, (j + 1) * min_size[0] + padding 27 | if i == 0: 28 | xl = 0 29 | if i == n_x - 1: 30 | xr = h 31 | if j == 0: 32 | yl = 0 33 | if j == n_y - 1: 34 | yr = w 35 | croped_imgs['{}-{}-{}-{}'.format(n_x, n_y, i, j)] = img[xl:xr, yl:yr, :] 36 | 37 | return croped_imgs 38 | 39 | 40 | def combine_img(croped_imgs, padding=10): 41 | keys = list(croped_imgs.keys()) 42 | n_x, n_y = int(keys[0].split('-')[0]), int(keys[0].split('-')[1]) 43 | img_blocks = [["" for _ in range(n_y)] for _ in range(n_x)] 44 | for k in keys: 45 | i, j = int(k.split('-')[2]), int(k.split('-')[3]) 46 | xl, xr, yl, yr = padding, -padding, padding, -padding 47 | if i == 0: 48 | xl = 0 49 | if i == n_x - 1: 50 | xr = 9999 51 | if j == 0: 52 | yl = 0 53 | if j == n_y - 1: 54 | yr = 9999 55 | img_blocks[i][j] = croped_imgs[k][xl:xr, yl:yr, :] 56 | 57 | line_imgs = [] 58 | for i in range(n_x): 59 | img_line = img_blocks[i][0] 60 | for j in range(n_y - 1): 61 | img_line = np.concatenate((img_line, img_blocks[i][j + 1]), axis=1) 62 | line_imgs.append(img_line) 63 | combined_img = line_imgs[0] 64 | for i in range(n_x - 1): 65 | combined_img = np.concatenate((combined_img, line_imgs[i + 1]), axis=0) 66 | return combined_img 67 | 68 | 69 | def traverse_crop_img(img, dsize=(100, 100), interval=10): 70 | h, w, c = img.shape 71 | 72 | croped_imgs = [] 73 | for i in range(9999): 74 | isbreak_x = False 75 | ix = i * interval 76 | xl, xr = ix, ix + dsize[0] 77 | if xr > h: 78 | xl, xr = h - dsize[0], h 79 | isbreak_x = True 80 | for j in range(9999): 81 | isbreak_y = False 82 | iy = j * interval 83 | yl, yr = iy, iy + dsize[1] 84 | if yr > w: 85 | yl, yr = w - dsize[1], w 86 | isbreak_y = True 87 | croped = img[xl:xr, yl:yr, :] 88 | croped_imgs.append(croped) 89 | if isbreak_y: 90 | break 91 | if isbreak_x: 92 | break 93 | 94 | return croped_imgs 95 | 96 | 97 | def batch_crop_img_with_padding(ori_root, dest_root, min_size=(100, 100), padding=10): 98 | ''' 99 | function: 100 | cropping image to many patches with padding 101 | it can be used for inferring large image 102 | params: 103 | ori_root: the dir of images that need to be processed 104 | dest_root: the dir to save processed images 105 | min_size: a tuple (h, w) the min size of crop, the border patch will be larger 106 | padding: the padding size of each patch 107 | notice: 108 | filenames should not contain the character "-" 109 | the crop flag "x-x-x-x" will be at the end of filename when cropping 110 | ''' 111 | handle_dir(dest_root) 112 | images_fname = sorted(listdir(ori_root)) 113 | for imf in images_fname: 114 | img = cv2.imread(os.path.join(ori_root, imf)) 115 | img_cropped = crop_img_with_padding(img, min_size=min_size, padding=padding) 116 | for k in img_cropped.keys(): 117 | cv2.imwrite(os.path.join(dest_root, "{}_{}.png".format(os.path.basename(imf).split('.')[0], k)), 118 | img_cropped[k]) 119 | print(imf, "crop done !") 120 | 121 | 122 | def batch_combine_img(ori_root, dest_root, padding=10): 123 | 
''' 124 | function: 125 | combining many patches to image 126 | it can be used to combine patches to image, when you finish inferring large image with cropped patches 127 | params: 128 | ori_root: the dir of images that need to be processed 129 | dest_root: the dir to save processed images 130 | padding: the padding size of each patch 131 | notice: 132 | filenames should not contain the character "-" except for the crop flag 133 | the crop flag "x-x-x-x" should be at the end of filename when combining 134 | ''' 135 | handle_dir(dest_root) 136 | images_fname = [fn[:-(len(fn.split('_')[-1]) + 1)] for fn in listdir(ori_root)] 137 | images_fname = list(set(images_fname)) 138 | for imf in images_fname: 139 | croped_imgs_path = sorted(glob_match(os.path.join(ori_root, "{}*".format(imf)))) 140 | croped_imgs = {} 141 | for cip in croped_imgs_path: 142 | img = cv2.imread(cip) 143 | k = cip.split('.')[0].split('_')[-1] 144 | croped_imgs[k] = img 145 | img_combined = combine_img(croped_imgs, padding=padding) 146 | cv2.imwrite(os.path.join(dest_root, "{}.png".format(imf)), img_combined) 147 | print("{}.png".format(imf), "combine done !") 148 | 149 | 150 | def batch_traverse_crop_img(ori_root, dest_root, dsize=(100, 100), interval=10): 151 | ''' 152 | function: 153 | traversing crop image to many patches with same interval 154 | params: 155 | ori_root: the dir of images that need to be processed 156 | dest_root: the dir to save processed images 157 | dsize: a tuple (h, w) the size of crop, the border patch will be overlapped for satisfing the dsize 158 | interval: the interval when traversing 159 | ''' 160 | handle_dir(dest_root) 161 | images_fname = sorted(listdir(ori_root)) 162 | for imf in images_fname: 163 | img = cv2.imread(os.path.join(ori_root, imf)) 164 | img_cropped = traverse_crop_img(img, dsize=dsize, interval=interval) 165 | for i, cim in enumerate(img_cropped): 166 | cv2.imwrite(os.path.join(dest_root, "{}_{}.png".format(os.path.basename(imf).split('.')[0], i)), cim) 167 | print(imf, "crop done !") 168 | 169 | 170 | def batch_select_valid_patch(ori_root, dest_root, thre=7): 171 | ''' 172 | function: 173 | selecting valid patch that are not too smooth 174 | params: 175 | ori_root: the dir of patches that need to be selected 176 | dest_root: the dir to save selected patch 177 | thre: the threshold value of smooth 178 | ''' 179 | handle_dir(dest_root) 180 | images_fname = sorted(listdir(ori_root)) 181 | total_num = len(images_fname) 182 | valid_num = 0 183 | for imf in images_fname: 184 | img = cv2.imread(os.path.join(ori_root, imf)) 185 | smooth = evaluate_smooth(img) 186 | if smooth > thre: 187 | cv2.imwrite(os.path.join(dest_root, imf), img) 188 | valid_num += 1 189 | else: 190 | print(imf, "too smooth, smooth={}".format(smooth)) 191 | print("Total {} patches, valid {}, remove {}".format(total_num, valid_num, total_num - valid_num)) 192 | -------------------------------------------------------------------------------- /utils/image_edge_sharpen.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | import cv2 6 | import os 7 | import math 8 | from utils.video_metric_utils import batch_calc_video_PSNR_SSIM 9 | 10 | 11 | def handle_dir(dir): 12 | if not os.path.exists(dir): 13 | os.mkdir(dir) 14 | print("mkdir:", dir) 15 | 16 | 17 | def matlab_style_gauss2D(shape=(5, 5), sigma=0.5): 18 | """ 19 | 2D gaussian mask - should give the same result as MATLAB's 20 | 
fspecial('gaussian',[shape],[sigma]) 21 | """ 22 | m, n = [(ss - 1.) / 2. for ss in shape] 23 | y, x = np.ogrid[-m:m + 1, -n:n + 1] 24 | h = np.exp(-(x * x + y * y) / (2. * sigma * sigma)) 25 | h[h < np.finfo(h.dtype).eps * h.max()] = 0 26 | sumh = h.sum() 27 | if sumh != 0: 28 | h /= sumh 29 | return h 30 | 31 | 32 | def get_blur_kernel(): 33 | gaussian_sigma = 1.0 34 | gaussian_blur_kernel_size = int(math.ceil(gaussian_sigma * 3) * 2 + 1) 35 | kernel = matlab_style_gauss2D((gaussian_blur_kernel_size, gaussian_blur_kernel_size), gaussian_sigma) 36 | return kernel 37 | 38 | 39 | def get_blur(img, kernel): 40 | img = np.array(img).astype('float32') 41 | img_tensor = torch.from_numpy(img.transpose(2, 0, 1)).unsqueeze(0).float() 42 | 43 | kernel_size = kernel.shape[0] 44 | psize = kernel_size // 2 45 | img_tensor = F.pad(img_tensor, (psize, psize, psize, psize), mode='replicate') 46 | 47 | gaussian_blur = nn.Conv2d(in_channels=3, out_channels=3, kernel_size=kernel_size, stride=1, 48 | padding=int((kernel_size - 1) // 2), bias=False) 49 | nn.init.constant_(gaussian_blur.weight.data, 0.0) 50 | gaussian_blur.weight.data[0, 0, :, :] = torch.FloatTensor(kernel) 51 | gaussian_blur.weight.data[1, 1, :, :] = torch.FloatTensor(kernel) 52 | gaussian_blur.weight.data[2, 2, :, :] = torch.FloatTensor(kernel) 53 | 54 | blur_tensor = gaussian_blur(img_tensor) 55 | blur_tensor = blur_tensor[:, :, psize:-psize, psize:-psize] 56 | 57 | blur_img = blur_tensor[0].detach().numpy().transpose(1, 2, 0).astype('float32') 58 | 59 | return blur_img 60 | 61 | 62 | def image_post_process_results(ori_root, save_root, alpha=3.): 63 | handle_dir(save_root) 64 | 65 | guassian_kernel = get_blur_kernel() 66 | 67 | frame_names = sorted(os.listdir(os.path.join(ori_root))) 68 | for fn in frame_names: 69 | ori_img = cv2.imread(os.path.join(ori_root, fn)).astype('float32') 70 | blur_img = get_blur(ori_img, guassian_kernel).astype('float32') 71 | 72 | res_img = ori_img - blur_img 73 | 74 | result = blur_img + alpha * res_img 75 | 76 | basename = fn.split(".")[0] 77 | cv2.imwrite(os.path.join(save_root, "{}_post.png".format(basename)), result) 78 | 79 | ## 80 | # res_img = np.clip(res_img, 0, np.max(res_img)) 81 | # res_img = (res_img / np.max(res_img)) * 255. 82 | # cv2.imwrite(os.path.join(save_root, "{}_res.png".format(basename)), res_img) 83 | ## 84 | 85 | print("{} done!".format(fn)) 86 | 87 | 88 | def video_post_process_results(ori_root, save_root, alpha=3.): 89 | handle_dir(save_root) 90 | 91 | guassian_kernel = get_blur_kernel() 92 | 93 | video_names = sorted(os.listdir(ori_root)) 94 | for vn in video_names: 95 | handle_dir(os.path.join(save_root, vn)) 96 | 97 | frame_names = sorted(os.listdir(os.path.join(ori_root, vn))) 98 | for fn in frame_names: 99 | ori_img = cv2.imread(os.path.join(ori_root, vn, fn)).astype('float32') 100 | blur_img = get_blur(ori_img, guassian_kernel).astype('float32') 101 | 102 | res_img = ori_img - blur_img 103 | 104 | result = blur_img + alpha * res_img 105 | 106 | basename = fn.split(".")[0] 107 | cv2.imwrite(os.path.join(save_root, vn, "{}_post.png".format(basename)), result) 108 | 109 | ## 110 | # res_img = np.clip(res_img, 0, np.max(res_img)) 111 | # res_img = (res_img / np.max(res_img)) * 255. 
112 | # cv2.imwrite(os.path.join(save_root, vn, "{}_res.png".format(basename)), res_img) 113 | ## 114 | 115 | print("{}-{} done!".format(vn, fn)) 116 | 117 | 118 | if __name__ == '__main__': 119 | # image_post_process_results( 120 | # ori_root='/media/csbhr/Bear/Dataset/FaceSR/face/test/bic', 121 | # save_root='./temp/edge/residual', 122 | # alpha=2 123 | # ) 124 | 125 | root_list = [] 126 | for i in range(21): 127 | alpha = 1.0 + i * 0.1 128 | save_root = '/home/csbhr/Disk-2T/work/OpenUtility/temp/edge_sharpen/post_{}'.format(alpha) 129 | video_post_process_results( 130 | ori_root='/home/csbhr/Disk-2T/work/OpenUtility/temp/edge_sharpen/ori', 131 | save_root=save_root, 132 | alpha=alpha 133 | ) 134 | root_list.append({ 135 | 'output': save_root, 136 | 'gt': '/home/csbhr/Disk-2T/work/OpenUtility/temp/edge_sharpen/HR' 137 | }) 138 | batch_calc_video_PSNR_SSIM(root_list) 139 | -------------------------------------------------------------------------------- /utils/image_metric_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import torch 4 | import numpy as np 5 | from base import image_base 6 | from base.os_base import listdir 7 | from base import matlab_imresize 8 | import utils.LPIPS.models as lpips_models 9 | import utils.NIQE.niqe as cal_niqe 10 | 11 | 12 | def calc_image_PSNR_SSIM(output_root, gt_root, crop_border=4, shift_window_size=0, test_ycbcr=False, crop_GT=False): 13 | ''' 14 | 计算图片的 PSNR、SSIM,使用 EDVR 的计算方式 15 | 要求 output_root, gt_root 中的文件按顺序一一对应 16 | ''' 17 | 18 | if test_ycbcr: 19 | print('Testing Y channel.') 20 | else: 21 | print('Testing RGB channels.') 22 | 23 | PSNR_list = [] 24 | SSIM_list = [] 25 | output_img_list = sorted(listdir(output_root)) 26 | gt_img_list = sorted(listdir(gt_root)) 27 | for o_im, g_im in zip(output_img_list, gt_img_list): 28 | o_im_path = os.path.join(output_root, o_im) 29 | g_im_path = os.path.join(gt_root, g_im) 30 | im_GT = cv2.imread(g_im_path) / 255. 31 | im_Gen = cv2.imread(o_im_path) / 255. 32 | 33 | if crop_GT: 34 | h, w, c = im_Gen.shape 35 | im_GT = im_GT[:h, :w, :] # crop GT to output size 36 | 37 | if test_ycbcr and im_GT.shape[2] == 3: # evaluate on Y channel in YCbCr color space 38 | im_GT = image_base.bgr2ycbcr(im_GT, range=1.) 39 | im_Gen = image_base.bgr2ycbcr(im_Gen, range=1.) 40 | 41 | # crop borders 42 | if crop_border != 0: 43 | if im_GT.ndim == 3: 44 | cropped_GT = im_GT[crop_border:-crop_border, crop_border:-crop_border, :] 45 | cropped_Gen = im_Gen[crop_border:-crop_border, crop_border:-crop_border, :] 46 | elif im_GT.ndim == 2: 47 | cropped_GT = im_GT[crop_border:-crop_border, crop_border:-crop_border] 48 | cropped_Gen = im_Gen[crop_border:-crop_border, crop_border:-crop_border] 49 | else: 50 | raise ValueError('Wrong image dimension: {}. 
Should be 2 or 3.'.format(im_GT.ndim)) 51 | else: 52 | cropped_GT = im_GT 53 | cropped_Gen = im_Gen 54 | 55 | if shift_window_size == 0: 56 | psnr = image_base.PSNR(cropped_GT * 255, cropped_Gen * 255) 57 | ssim = image_base.SSIM(cropped_GT * 255, cropped_Gen * 255) 58 | else: 59 | psnr, ssim = image_base.PSNR_SSIM_Shift_Best(cropped_GT * 255, cropped_Gen * 255, 60 | window_size=shift_window_size) 61 | PSNR_list.append(psnr) 62 | SSIM_list.append(ssim) 63 | 64 | print("{} PSNR={:.5}, SSIM={:.4}".format(o_im, psnr, ssim)) 65 | 66 | log = 'Average PSNR={:.5}, SSIM={:.4}'.format(sum(PSNR_list) / len(PSNR_list), sum(SSIM_list) / len(SSIM_list)) 67 | print(log) 68 | 69 | return PSNR_list, SSIM_list, log 70 | 71 | 72 | def batch_calc_image_PSNR_SSIM(root_list, crop_border=4, shift_window_size=0, test_ycbcr=False, crop_GT=False): 73 | ''' 74 | required params: 75 | root_list: a list, each item should be a dictionary that given two key-values: 76 | output: the dir of output images 77 | gt: the dir of gt images 78 | optional params: 79 | crop_border: defalut=4, crop pixels when calculating PSNR/SSIM 80 | shift_window_size: defalut=0, if >0, shifting image within a window for best metric 81 | test_ycbcr: default=False, if True, applying Ycbcr color space 82 | crop_GT: default=False, if True, cropping GT to output size 83 | return: 84 | log_list: a list, each item is a dictionary that given two key-values: 85 | data_path: the evaluated dir 86 | log: the log of this dir 87 | ''' 88 | log_list = [] 89 | for i, root in enumerate(root_list): 90 | ouput_root = root['output'] 91 | gt_root = root['gt'] 92 | print(">>>> Now Evaluation >>>>") 93 | print(">>>> OUTPUT: {}".format(ouput_root)) 94 | print(">>>> GT: {}".format(gt_root)) 95 | _, _, log = calc_image_PSNR_SSIM( 96 | ouput_root, gt_root, crop_border=crop_border, shift_window_size=shift_window_size, 97 | test_ycbcr=test_ycbcr, crop_GT=crop_GT 98 | ) 99 | log_list.append({ 100 | 'data_path': ouput_root, 101 | 'log': log 102 | }) 103 | 104 | print("--------------------------------------------------------------------------------------") 105 | for i, log in enumerate(log_list): 106 | print("## The {}-th:".format(i)) 107 | print(">> ", log['data_path']) 108 | print(">> ", log['log']) 109 | 110 | return log_list 111 | 112 | 113 | def calc_image_LPIPS(output_root, gt_root, model=None, use_gpu=False, spatial=True): 114 | ''' 115 | 计算图片的 LPIPS 116 | 要求 output_root, gt_root 中的文件按顺序一一对应 117 | ''' 118 | 119 | def _load_image(path, size=(512, 512)): 120 | img = cv2.imread(path) 121 | h, w, c = img.shape 122 | if h != size[0] or w != size[1]: 123 | img = cv2.resize(img, dsize=size, interpolation=cv2.INTER_CUBIC) 124 | return img[:, :, ::-1] 125 | 126 | def _im2tensor(image, imtype=np.uint8, cent=1., factor=255. 
/ 2.): 127 | return torch.Tensor((image / factor - cent)[:, :, :, np.newaxis].transpose((3, 2, 0, 1))) 128 | 129 | if model is None: 130 | model = lpips_models.PerceptualLoss(model='net-lin', net='alex', use_gpu=use_gpu, spatial=spatial) 131 | 132 | LPIPS_list = [] 133 | output_img_list = sorted(listdir(output_root)) 134 | gt_img_list = sorted(listdir(gt_root)) 135 | for o_im, g_im in zip(output_img_list, gt_img_list): 136 | o_im_path = os.path.join(output_root, o_im) 137 | g_im_path = os.path.join(gt_root, g_im) 138 | im_GT = _im2tensor(_load_image(g_im_path)) 139 | im_Gen = _im2tensor(_load_image(o_im_path)) 140 | 141 | if use_gpu: 142 | im_GT = im_GT.cuda() 143 | im_Gen = im_Gen.cuda() 144 | 145 | lpips = model.forward(im_GT, im_Gen).mean() 146 | LPIPS_list.append(lpips) 147 | 148 | print("{} LPIPS={:.4}".format(o_im, lpips)) 149 | 150 | log = 'Average LPIPS={:.4}'.format(sum(LPIPS_list) / len(LPIPS_list)) 151 | print(log) 152 | 153 | return LPIPS_list, log 154 | 155 | 156 | def batch_calc_image_LPIPS(root_list, use_gpu=False, spatial=True): 157 | ''' 158 | required params: 159 | root_list: a list, each item should be a dictionary that given two key-values: 160 | output: the dir of output images 161 | gt: the dir of gt images 162 | optional params: 163 | use_gpu: defalut=False, if True, using gpu 164 | spatial: default=True, if True, return spatial map 165 | return: 166 | log_list: a list, each item is a dictionary that given two key-values: 167 | data_path: the evaluated dir 168 | log: the log of this dir 169 | ''' 170 | model = lpips_models.PerceptualLoss(model='net-lin', net='alex', use_gpu=use_gpu, spatial=spatial) 171 | 172 | log_list = [] 173 | for i, root in enumerate(root_list): 174 | ouput_root = root['output'] 175 | gt_root = root['gt'] 176 | print(">>>> Now Evaluation >>>>") 177 | print(">>>> OUTPUT: {}".format(ouput_root)) 178 | print(">>>> GT: {}".format(gt_root)) 179 | _, log = calc_image_LPIPS(ouput_root, gt_root, model=model, use_gpu=use_gpu, spatial=spatial) 180 | log_list.append({ 181 | 'data_path': ouput_root, 182 | 'log': log 183 | }) 184 | 185 | print("--------------------------------------------------------------------------------------") 186 | for i, log in enumerate(log_list): 187 | print("## The {}-th:".format(i)) 188 | print(">> ", log['data_path']) 189 | print(">> ", log['log']) 190 | 191 | return log_list 192 | 193 | 194 | def calc_image_NIQE(output_root, crop_border=4): 195 | ''' 196 | 计算图片的 NIQE 197 | ''' 198 | 199 | NIQE_list = [] 200 | output_img_list = sorted(listdir(output_root)) 201 | for o_im in output_img_list: 202 | o_im_path = os.path.join(output_root, o_im) 203 | im_Gen = cv2.imread(o_im_path) 204 | im_Gen = cv2.cvtColor(im_Gen, cv2.COLOR_BGR2GRAY) 205 | 206 | # crop borders 207 | if crop_border != 0: 208 | if im_Gen.ndim == 3: 209 | cropped_Gen = im_Gen[crop_border:-crop_border, crop_border:-crop_border, :] 210 | elif im_Gen.ndim == 2: 211 | cropped_Gen = im_Gen[crop_border:-crop_border, crop_border:-crop_border] 212 | else: 213 | raise ValueError('Wrong image dimension: {}. 
Should be 2 or 3.'.format(im_Gen.ndim)) 214 | else: 215 | cropped_Gen = im_Gen 216 | 217 | h, w = cropped_Gen.shape 218 | if h < 193 or w < 193: 219 | cropped_Gen = matlab_imresize.imresize(cropped_Gen, output_shape=(512, 512), method='bicubic') 220 | # cropped_Gen = cv2.resize(cropped_Gen, dsize=(512, 512), interpolation=cv2.INTER_CUBIC) 221 | 222 | niqe = cal_niqe.niqe(cropped_Gen) 223 | NIQE_list.append(niqe) 224 | 225 | print("{} NIQE={:.5}".format(o_im, niqe)) 226 | 227 | log = 'Average NIQE={:.5}'.format(sum(NIQE_list) / len(NIQE_list)) 228 | print(log) 229 | 230 | return NIQE_list, log 231 | 232 | 233 | def batch_calc_image_NIQE(root_list, crop_border=4): 234 | ''' 235 | required params: 236 | root_list: a list, each item should be a dictionary that given key-values: 237 | output: the dir of output images 238 | optional params: 239 | crop_border: defalut=4, crop pixels when calculating NIQE 240 | return: 241 | log_list: a list, each item is a dictionary that given two key-values: 242 | data_path: the evaluated dir 243 | log: the log of this dir 244 | ''' 245 | log_list = [] 246 | for i, root in enumerate(root_list): 247 | ouput_root = root['output'] 248 | print(">>>> Now Evaluation >>>>") 249 | print(">>>> OUTPUT: {}".format(ouput_root)) 250 | _, log = calc_image_NIQE( 251 | ouput_root, crop_border=crop_border 252 | ) 253 | log_list.append({ 254 | 'data_path': ouput_root, 255 | 'log': log 256 | }) 257 | 258 | print("--------------------------------------------------------------------------------------") 259 | for i, log in enumerate(log_list): 260 | print("## The {}-th:".format(i)) 261 | print(">> ", log['data_path']) 262 | print(">> ", log['log']) 263 | 264 | return log_list 265 | -------------------------------------------------------------------------------- /utils/image_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import numpy as np 4 | from base.os_base import handle_dir, get_fname_ext, listdir 5 | from base.image_base import matlab_imresize, image_shift 6 | 7 | 8 | def matlab_resize_images(ori_root, dest_root, scale=1.0, method='bicubic', filename_template="{}.png"): 9 | ''' 10 | function: 11 | resizing images in batches, same as matlab2017 imresize 12 | params: 13 | ori_root: string, the dir of images that need to be processed 14 | dest_root: string, the dir to save processed images 15 | scale: float, the resize scale 16 | method: string, the interpolation method, 17 | optional: 'bilinear', 'bicubic' 18 | default: 'bicubic' 19 | filename_template: string, the filename template for saving images 20 | ''' 21 | if method != 'bilinear' and method != 'bicubic': 22 | raise Exception('Unknown method!') 23 | 24 | handle_dir(dest_root) 25 | scale = float(scale) 26 | images_fname = sorted(listdir(ori_root)) 27 | for imf in images_fname: 28 | img = cv2.imread(os.path.join(ori_root, imf)).astype('float32') 29 | img = matlab_imresize(img, scalar_scale=scale, method=method) 30 | cv2.imwrite(os.path.join(dest_root, filename_template.format(get_fname_ext(imf)[0])), img) 31 | print("Image", imf, "resize done !") 32 | 33 | 34 | def cv2_resize_images(ori_root, dest_root, scale=1.0, method='bicubic', filename_template="{}.png"): 35 | ''' 36 | function: 37 | resizing images in batches 38 | params: 39 | ori_root: string, the dir of images that need to be processed 40 | dest_root: string, the dir to save processed images 41 | scale: float, the resize scale 42 | method: string, the interpolation method, 43 | optional: 
'nearest', 'bilinear', 'bicubic' 44 | default: 'bicubic' 45 | filename_template: string, the filename template for saving images 46 | ''' 47 | if method == 'nearest': 48 | interpolation = cv2.INTER_NEAREST 49 | elif method == 'bilinear': 50 | interpolation = cv2.INTER_LINEAR 51 | elif method == 'bicubic': 52 | interpolation = cv2.INTER_CUBIC 53 | else: 54 | raise Exception('Unknown method!') 55 | 56 | handle_dir(dest_root) 57 | scale = float(scale) 58 | images_fname = sorted(listdir(ori_root)) 59 | for imf in images_fname: 60 | img = cv2.imread(os.path.join(ori_root, imf)).astype('float32') 61 | img = cv2.resize(img, dsize=(0, 0), fx=scale, fy=scale, interpolation=interpolation) 62 | cv2.imwrite(os.path.join(dest_root, filename_template.format(get_fname_ext(imf)[0])), img) 63 | print("Image", imf, "resize done !") 64 | 65 | 66 | def shift_images(ori_root, dest_root, offset_x=0., offset_y=0., filename_template="{}.png"): # TODO 67 | ''' 68 | function: 69 | shifting images by (offset_x, offset_y) on (axis-x, axis-y) in batches 70 | params: 71 | ori_root: string, the dir of images that need to be processed 72 | dest_root: string, the dir to save processed images 73 | offset_x: float, offset pixels on axis-x 74 | positive=left; negative=right 75 | offset_y: float, offset pixels on axis-y 76 | positive=up; negative=down 77 | filename_template: string, the filename template for saving images 78 | ''' 79 | 80 | handle_dir(dest_root) 81 | offset_x, offset_y = float(offset_x), float(offset_y) 82 | images_fname = sorted(listdir(ori_root)) 83 | for imf in images_fname: 84 | img = cv2.imread(os.path.join(ori_root, imf)).astype('float32') 85 | img = image_shift(img, offset_x=offset_x, offset_y=offset_y) 86 | cv2.imwrite(os.path.join(dest_root, filename_template.format(get_fname_ext(imf)[0])), img) 87 | print("Image", imf, "shift done !") 88 | 89 | 90 | def margin_patch(patch, margin_width=5): 91 | '''margin patch by red cycle''' 92 | red = np.zeros_like(patch) 93 | red[:, :, 1] = 255 94 | mask = np.zeros_like(patch) 95 | mask[:margin_width, :, :] = mask[-margin_width:, :, :] = mask[:, :margin_width, :] = mask[:, -margin_width:, :] = 1 96 | res = red * mask + patch * (1 - mask) 97 | return res 98 | 99 | 100 | def circle_zoom_img(img, pos_1, pos_2, scale=2., hr_pos='right_down'): 101 | '''circle and zoom a patch in image''' 102 | patch_lr = img[pos_1[0]:pos_1[1], pos_2[0]:pos_2[1], :] 103 | patch_hr = cv2.resize(patch_lr, dsize=(0, 0), fx=scale, fy=scale, interpolation=cv2.INTER_CUBIC) 104 | patch_lr = margin_patch(patch_lr) 105 | patch_hr = margin_patch(patch_hr) 106 | img[pos_1[0]:pos_1[1], pos_2[0]:pos_2[1], :] = patch_lr 107 | if hr_pos == 'right_down': 108 | img[-patch_hr.shape[0]:, -patch_hr.shape[1]:, :] = patch_hr 109 | elif hr_pos == 'left_down': 110 | img[-patch_hr.shape[0]:, :patch_hr.shape[1], :] = patch_hr 111 | return img 112 | 113 | 114 | def circle_img(img, pos_1, pos_2, margin_width=2): 115 | '''circle a patch in image''' 116 | patch_lr = img[pos_1[0]:pos_1[1], pos_2[0]:pos_2[1], :] 117 | patch_lr = margin_patch(patch_lr, margin_width=margin_width) 118 | img[pos_1[0]:pos_1[1], pos_2[0]:pos_2[1], :] = patch_lr 119 | return img 120 | -------------------------------------------------------------------------------- /utils/kernel_metric_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | from base import kernel_base 3 | from base.kernel_base import load_mat_kernel 4 | from base.os_base import listdir 5 | 6 | 7 | def 
calc_kernel_metric(output_root, gt_root):
 8 |     '''
 9 |     compute kernel metrics: gradient similarity, RMSE, PSNR, SSIM
10 |     files in output_root and gt_root must correspond one-to-one in sorted order
11 |     '''
12 | 
13 |     metric_dict = {
14 |         'gradient_similarity': [],
15 |         'rmse': [],
16 |         'psnr': [],
17 |         'ssim': [],
18 |     }
19 | 
20 |     output_kernel_list = sorted(listdir(output_root))
21 |     gt_kernel_list = sorted(listdir(gt_root))
22 |     for o_k, g_k in zip(output_kernel_list, gt_kernel_list):
23 |         o_k_path = os.path.join(output_root, o_k)
24 |         g_k_path = os.path.join(gt_root, g_k)
25 |         kernel_GT = load_mat_kernel(g_k_path)
26 |         kernel_Gen = load_mat_kernel(o_k_path)
27 | 
28 |         gradient_similarity = kernel_base.Gradient_Similarity(kernel_Gen, kernel_GT)
29 |         rmse, psnr, ssim = kernel_base.Kernel_RMSE_PSNR_SSIM(kernel_Gen, kernel_GT)
30 |         metric_dict['gradient_similarity'].append(gradient_similarity)
31 |         metric_dict['rmse'].append(rmse)
32 |         metric_dict['psnr'].append(psnr)
33 |         metric_dict['ssim'].append(ssim)
34 | 
35 |         print("{} Gradient-Similarity={:.4}, RMSE={:.4}, PSNR={:.4}, SSIM={:.4}".format(
36 |             o_k, gradient_similarity, rmse, psnr, ssim
37 |         ))
38 | 
39 |     log = 'Average Gradient-Similarity={:.4}, RMSE={:.4}, PSNR={:.4}, SSIM={:.4}'.format(
40 |         sum(metric_dict['gradient_similarity']) / len(metric_dict['gradient_similarity']),
41 |         sum(metric_dict['rmse']) / len(metric_dict['rmse']),
42 |         sum(metric_dict['psnr']) / len(metric_dict['psnr']),
43 |         sum(metric_dict['ssim']) / len(metric_dict['ssim'])
44 |     )
45 |     print(log)
46 | 
47 |     return metric_dict, log
48 | 
49 | 
50 | def calc_kernel_metric_video(output_root, gt_root):
51 |     '''
52 |     compute per-video kernel metrics: gradient similarity, RMSE, PSNR, SSIM
53 |     files in output_root and gt_root must correspond one-to-one in sorted order
54 |     '''
55 |     gradient_similarity_sum = 0.
56 |     rmse_sum = 0.
57 |     psnr_sum = 0.
58 |     ssim_sum = 0.
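    # running totals over all videos, so the final average weights every kernel equally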
59 | kernel_num = 0 60 | 61 | video_metric = [] 62 | 63 | video_list = sorted(listdir(output_root)) 64 | for v in video_list: 65 | v_metric_list, _ = calc_kernel_metric( 66 | output_root=os.path.join(output_root, v), 67 | gt_root=os.path.join(gt_root, v) 68 | ) 69 | gradient_similarity_sum += sum(v_metric_list['gradient_similarity']) 70 | rmse_sum += sum(v_metric_list['rmse']) 71 | psnr_sum += sum(v_metric_list['psnr']) 72 | ssim_sum += sum(v_metric_list['ssim']) 73 | kernel_num += len(v_metric_list['gradient_similarity']) 74 | 75 | video_metric.append({ 76 | 'video_name': v, 77 | 'gradient_similarity': v_metric_list['gradient_similarity'], 78 | 'rmse': v_metric_list['rmse'], 79 | 'psnr': v_metric_list['psnr'], 80 | 'ssim': v_metric_list['ssim'] 81 | }) 82 | 83 | logs = [] 84 | for v_m in video_metric: 85 | log = 'Video: {} Gradient-Similarity={:.4}, RMSE={:.4}, PSNR={:.4}, SSIM={:.4}'.format( 86 | v_m['video_name'], 87 | sum(v_m['gradient_similarity']) / len(v_m['gradient_similarity']), 88 | sum(v_m['rmse']) / len(v_m['rmse']), 89 | sum(v_m['psnr']) / len(v_m['psnr']), 90 | sum(v_m['ssim']) / len(v_m['ssim']) 91 | ) 92 | print(log) 93 | logs.append(log) 94 | log = 'Average Gradient-Similarity={:.4}, RMSE={:.4}, PSNR={:.4}, SSIM={:.4}'.format( 95 | gradient_similarity_sum / kernel_num, 96 | rmse_sum / kernel_num, 97 | psnr_sum / kernel_num, 98 | ssim_sum / kernel_num 99 | ) 100 | print(log) 101 | logs.append(log) 102 | 103 | return video_metric, logs 104 | 105 | 106 | def batch_calc_kernel_metric(root_list, video_type=False): 107 | ''' 108 | required params: 109 | root_list: a list, each item should be a dictionary that given two key-values: 110 | output: the dir of output images 111 | gt: the dir of gt images 112 | return: 113 | log_list: a list, each item is a dictionary that given two key-values: 114 | data_path: the evaluated dir 115 | log: the log of this dir 116 | ''' 117 | log_list = [] 118 | for i, root in enumerate(root_list): 119 | ouput_root = root['output'] 120 | gt_root = root['gt'] 121 | print(">>>> Now Evaluation >>>>") 122 | print(">>>> OUTPUT: {}".format(ouput_root)) 123 | print(">>>> GT: {}".format(gt_root)) 124 | if video_type: 125 | _, log = calc_kernel_metric_video(ouput_root, gt_root) 126 | else: 127 | _, log = calc_kernel_metric(ouput_root, gt_root) 128 | log_list.append({ 129 | 'data_path': ouput_root, 130 | 'log': log 131 | }) 132 | 133 | print("--------------------------------------------------------------------------------------") 134 | for i, log in enumerate(log_list): 135 | print("## The {}-th:".format(i)) 136 | print(">> ", log['data_path']) 137 | if video_type: 138 | for lo in log['log']: 139 | print(">> ", lo) 140 | else: 141 | print(">> ", log['log']) 142 | 143 | return log_list 144 | -------------------------------------------------------------------------------- /utils/kernel_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | from base.os_base import handle_dir, get_fname_ext, listdir 4 | from base.kernel_base import kernel2png, load_mat_kernel 5 | 6 | 7 | def save_kernels_as_png(ori_root, dest_root, filename_template="{}.png"): 8 | ''' 9 | function: 10 | convert kernel for saving as png 11 | params: 12 | ori_root: string, the dir of kernels that need to be processed 13 | dest_root: string, the dir to save processed kernels 14 | filename_template: string, the filename template for saving kernels 15 | ''' 16 | 17 | handle_dir(dest_root) 18 | kernels_fname = sorted(listdir(ori_root)) 
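    # one PNG per kernel; the output name is produced by filename_template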
19 | for kerf in kernels_fname: 20 | ker = load_mat_kernel(os.path.join(ori_root, kerf)) 21 | ker_png = kernel2png(ker) 22 | cv2.imwrite(os.path.join(dest_root, filename_template.format(get_fname_ext(kerf)[0])), ker_png) 23 | print("Kernel", kerf, "save done !") 24 | -------------------------------------------------------------------------------- /utils/plot_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib.pyplot as plt 3 | 4 | 5 | def get_max(arr): 6 | new_arr = np.zeros(arr.shape) 7 | for i in range(arr.shape[0]): 8 | new_arr[i] = np.max(arr[:i + 1]) 9 | return new_arr 10 | 11 | 12 | def get_min(arr): 13 | new_arr = np.zeros(arr.shape) 14 | for i in range(arr.shape[0]): 15 | new_arr[i] = np.min(arr[:i + 1]) 16 | return new_arr 17 | 18 | 19 | def plot_curve(data, title, xlabel='Epochs', ylabel=None, save=None): 20 | ''' 21 | function: 22 | plot curve 23 | required params: 24 | data: the data of curve, numpy.array, ndim=1 25 | title: the title for each curve 26 | optional params: 27 | xlabel: the flag of x-axis 28 | ylabel: the flag of y-axis, if None: ylabel=title 29 | save: 30 | if None: just show figure 31 | else: the path to save figure, the figure will be saved 32 | ''' 33 | length = data.shape[0] 34 | axis = np.linspace(1, length, length) 35 | fig = plt.figure() 36 | plt.title('{} Graph'.format(title)) 37 | plt.plot(axis, data) 38 | plt.legend() 39 | plt.xlabel(xlabel) 40 | if not ylabel: 41 | ylabel = title 42 | plt.ylabel(ylabel) 43 | plt.grid(True) 44 | 45 | if save: 46 | plt.savefig(save) 47 | plt.close(fig) 48 | else: 49 | plt.show() 50 | 51 | 52 | def plot_multi_curve(array_list, label_list, title='compare', xlabel='Epochs', ylabel=None, save=None): 53 | ''' 54 | function: 55 | plot multi curve in one figure 56 | required params: 57 | array_list: the data of curves, a list 58 | each item should be numpy.array, ndim=1 59 | label_list: list, labels for each curve 60 | optional params: 61 | title: the title of figure 62 | xlabel: the flag of x-axis 63 | ylabel: the flag of y-axis, if None: ylabel=title 64 | save: 65 | if None: just show figure 66 | else: the path to save figure, the figure will be saved 67 | ''' 68 | assert len(array_list) == len(label_list), "length not equal" 69 | 70 | fig = plt.figure() 71 | plt.title(title) 72 | 73 | # plot curves 74 | for i in range(len(array_list)): 75 | length = array_list[i].shape[0] 76 | axis = np.linspace(1, length, length) 77 | plt.plot(axis, array_list[i], label=label_list[i]) 78 | 79 | plt.legend() 80 | plt.xlabel(xlabel) 81 | if not ylabel: 82 | ylabel = title 83 | plt.ylabel(ylabel) 84 | plt.grid(True) 85 | 86 | if save: 87 | plt.savefig(save) 88 | plt.close(fig) 89 | else: 90 | plt.show() 91 | 92 | 93 | def plot_multi_curve_given_axis(array_list, label_list, axis, title='compare', xlabel='Epochs', ylabel=None, save=None): 94 | ''' 95 | function: 96 | plot multi curve in one figure 97 | required params: 98 | array_list: the data of curves, a list 99 | each item should be numpy.array, ndim=1 100 | label_list: list, labels for each curve 101 | axis: list, x-axis flags 102 | optional params: 103 | title: the title of figure 104 | xlabel: the flag of x-axis 105 | ylabel: the flag of y-axis, if None: ylabel=title 106 | save: 107 | if None: just show figure 108 | else: the path to save figure, the figure will be saved 109 | ''' 110 | assert len(array_list) == len(label_list), "length not equal" 111 | 112 | fig = plt.figure() 113 | plt.title(title) 114 | 
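    # unlike plot_multi_curve, the shared x-axis values are supplied by the caller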
115 | # plot curves 116 | for i in range(len(array_list)): 117 | plt.plot(axis, array_list[i], label=label_list[i]) 118 | 119 | plt.legend() 120 | plt.xlabel(xlabel) 121 | if not ylabel: 122 | ylabel = title 123 | plt.ylabel(ylabel) 124 | plt.grid(True) 125 | 126 | if save: 127 | plt.savefig(save) 128 | plt.close(fig) 129 | else: 130 | plt.show() 131 | -------------------------------------------------------------------------------- /utils/string_utils.py: -------------------------------------------------------------------------------- 1 | def LCS(str1, str2): 2 | ''' 3 | Longest Common Subsequence 4 | ''' 5 | num_res = [[0 for _ in range(len(str2) + 1)] for _ in range(len(str1) + 1)] 6 | char_res = [['' for _ in range(len(str2) + 1)] for _ in range(len(str1) + 1)] 7 | for i in range(1, len(str1) + 1): 8 | for j in range(1, len(str2) + 1): 9 | if str1[i - 1] == str2[j - 1]: 10 | num_res[i][j] = num_res[i - 1][j - 1] + 1 11 | char_res[i][j] = char_res[i - 1][j - 1] + str1[i - 1] 12 | else: 13 | if num_res[i][j - 1] > num_res[i - 1][j]: 14 | num_res[i][j] = num_res[i][j - 1] 15 | char_res[i][j] = char_res[i][j - 1] 16 | else: 17 | num_res[i][j] = num_res[i - 1][j] 18 | char_res[i][j] = char_res[i - 1][j] 19 | return num_res[-1][-1], char_res[-1][-1] 20 | -------------------------------------------------------------------------------- /utils/torchstat/.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | .pypirc 106 | -------------------------------------------------------------------------------- /utils/torchstat/.travis.yml: -------------------------------------------------------------------------------- 1 | os: linux 2 | sudo: false 3 | language: python 4 | python: 5 | - "3.6" 6 | 7 | install: 8 | - pip install pycodestyle 9 | 10 | script: 11 | - pycodestyle torchstat/ 12 | 13 | cache: 14 | - pip 15 | 16 | notifications: 17 | email: false 18 | -------------------------------------------------------------------------------- /utils/torchstat/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 Swall0w - Alan 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /utils/torchstat/README.md: -------------------------------------------------------------------------------- 1 | [![Build Status](https://travis-ci.org/Swall0w/torchstat.svg?branch=master)](https://travis-ci.org/Swall0w/torchstat) 2 | 3 | # torchstat 4 | This is a lightweight neural network analyzer based on PyTorch. 5 | It is designed to make building your networks quick and easy, with the ability to debug them. 6 | **Note**: This repository is currently under development. Therefore, some APIs might be changed. 
7 | 
8 | This tool can show
9 | 
10 | * Total number of network parameters
11 | * Theoretical amount of floating point arithmetics (FLOPs)
12 | * Theoretical amount of multiply-adds (MAdd)
13 | * Memory usage
14 | 
15 | ## Installing
16 | There are two ways to install torchstat into your environment.
17 | * Install it via pip.
18 | ```bash
19 | $ pip install torchstat
20 | ```
21 | 
22 | * Install and update using **setup.py** after cloning this repository.
23 | ```bash
24 | $ python3 setup.py install
25 | ```
26 | 
27 | ## A Simple Example
28 | If you want to run torchstat right away, you can call it as a CLI tool when your network is defined in a script.
29 | Otherwise you need to import torchstat as a module.
30 | 
31 | ### CLI tool
32 | ```bash
33 | $ torchstat -f example.py -m Net
34 | [MAdd]: Dropout2d is not supported!
35 | [Flops]: Dropout2d is not supported!
36 | [Memory]: Dropout2d is not supported!
37 | module name input shape output shape params memory(MB) MAdd Flops MemRead(B) MemWrite(B) duration[%] MemR+W(B)
38 | 0 conv1 3 224 224 10 220 220 760.0 1.85 72,600,000.0 36,784,000.0 605152.0 1936000.0 57.49% 2541152.0
39 | 1 conv2 10 110 110 20 106 106 5020.0 0.86 112,360,000.0 56,404,720.0 504080.0 898880.0 26.62% 1402960.0
40 | 2 conv2_drop 20 106 106 20 106 106 0.0 0.86 0.0 0.0 0.0 0.0 4.09% 0.0
41 | 3 fc1 56180 50 2809050.0 0.00 5,617,950.0 2,809,000.0 11460920.0 200.0 11.58% 11461120.0
42 | 4 fc2 50 10 510.0 0.00 990.0 500.0 2240.0 40.0 0.22% 2280.0
43 | total 2815340.0 3.56 190,578,940.0 95,998,220.0 2240.0 40.0 100.00% 15407512.0
44 | ===============================================================================================================================================
45 | Total params: 2,815,340
46 | -----------------------------------------------------------------------------------------------------------------------------------------------
47 | Total memory: 3.56MB
48 | Total MAdd: 190.58MMAdd
49 | Total Flops: 96.0MFlops
50 | Total MemR+W: 14.69MB
51 | ```
52 | 
53 | If you're not sure how to use a specific command, run the command with the -h or --help switches.
54 | You'll see usage information and a list of options you can use with the command.
55 | 
56 | ### Module
57 | ```python
58 | from torchstat import stat
59 | import torchvision.models as models
60 | 
61 | model = models.resnet18()
62 | stat(model, (3, 224, 224))
63 | ```
64 | 
65 | ## Features & TODO
66 | **Note**: These features work only for nn.Module. Modules in torch.nn.functional are not supported yet.
67 | - [x] FLOPs
68 | - [x] Number of Parameters
69 | - [x] Total memory
70 | - [x] MAdd (FMA)
71 | - [x] MemRead
72 | - [x] MemWrite
73 | - [ ] Model summary (detail, layer-wise)
74 | - [ ] Export score table
75 | - [ ] Arbitrary input shape
76 | 
77 | For the supported layers, check out [the details](./detail.md).
78 | 
79 | 
80 | ## Requirements
81 | * Python 3.6+
82 | * PyTorch 0.4.0+
83 | * Pandas 0.23.4+
84 | * NumPy 1.14.3+
85 | 
86 | ## References
87 | Thanks to @sovrasov for the initial version of the flops computation, and @ceykmc for the backbone of the scripts.
88 | * [flops-counter.pytorch](https://github.com/sovrasov/flops-counter.pytorch)
89 | * [pytorch_model_summary](https://github.com/ceykmc/pytorch_model_summary)
90 | * [chainer_computational_cost](https://github.com/belltailjp/chainer_computational_cost)
91 | * [convnet-burden](https://github.com/albanie/convnet-burden).
92 | 
--------------------------------------------------------------------------------
/utils/torchstat/detail.md:
--------------------------------------------------------------------------------
1 | # Supported Layers
2 | |Layer|Flops|MAdd|MemRead|MemWrite|
3 | |---|---|---|---|---|
4 | |Conv2d|ok|ok|ok|ok|
5 | |ConvTranspose2d| |ok|||
6 | |BatchNorm2d|ok|ok|ok|ok|
7 | |Linear|ok|ok|ok|ok|
8 | |UpSample|ok| |||
9 | |AvgPool2d|ok|ok|ok|ok|
10 | |MaxPool2d|ok|ok|ok|ok|
11 | |ReLU|ok|ok|ok|ok|
12 | |ReLU6|ok|ok|ok|ok|
13 | |LeakyReLU|ok||ok|ok|
14 | |PReLU|ok|ok|ok|ok|
15 | |ELU|ok|ok|ok|ok|
--------------------------------------------------------------------------------
/utils/torchstat/example.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from torchstat import stat
5 | 
6 | 
7 | class Net(nn.Module):
8 |     def __init__(self):
9 |         super(Net, self).__init__()
10 |         self.conv1 = nn.Conv2d(3, 10, kernel_size=5)
11 |         self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
12 |         self.conv2_drop = nn.Dropout2d()
13 |         self.fc1 = nn.Linear(56180, 50)
14 |         self.fc2 = nn.Linear(50, 10)
15 | 
16 |     def forward(self, x):
17 |         x = F.relu(F.max_pool2d(self.conv1(x), 2))
18 |         x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
19 |         x = x.view(-1, 56180)
20 |         x = F.relu(self.fc1(x))
21 |         x = F.dropout(x, training=self.training)
22 |         x = self.fc2(x)
23 |         return F.log_softmax(x, dim=1)
24 | 
25 | 
26 | if __name__ == '__main__':
27 |     model = Net()
28 |     stat(model, (3, 224, 224))
--------------------------------------------------------------------------------
/utils/torchstat/requirements.txt:
--------------------------------------------------------------------------------
1 | torch
2 | numpy
3 | pandas
--------------------------------------------------------------------------------
/utils/torchstat/setup.cfg:
--------------------------------------------------------------------------------
1 | [pycodestyle]
2 | count = False
3 | ignore = E226,E302,E41
4 | max-line-length = 160
5 | statistics = True
--------------------------------------------------------------------------------
/utils/torchstat/setup.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | 
3 | import re
4 | from os import path
5 | 
6 | from setuptools import find_packages, setup
7 | 
8 | package_name = "torchstat"
9 | root_dir = path.abspath(path.dirname(__file__))
10 | 
11 | 
12 | def _requirements():
13 |     return [name.rstrip() for name in open(path.join(root_dir, 'requirements.txt'), encoding='utf-8').readlines()]
14 | 
15 | 
16 | def _test_requirements():
17 |     return [name.rstrip() for name in open(path.join(root_dir, 'test_requirements.txt'), encoding='utf-8').readlines()]
18 | 
19 | 
20 | with open(path.join(root_dir, package_name, '__init__.py'), encoding='utf-8') as f:
21 |     init_text = f.read()
22 |     version = re.search(r'__version__ = [\'\"](.+?)[\'\"]', init_text).group(1)
23 |     author = re.search(r'__author__ =\s*[\'\"](.+?)[\'\"]', init_text).group(1)
24 |     url = re.search(r'__url__ =\s*[\'\"](.+?)[\'\"]', init_text).group(1)
25 | 
26 | assert version
27 | assert author
28 | assert url
29 | 
30 | with open('README.md', encoding='utf-8') as f:
31 |     long_description = f.read()
32 | 
33 | setup(
34 |     name=package_name,
35 |     version=version,
36 |     description="torchstat: The PyTorch Model Analyzer.",
37 |     long_description=long_description,
38 |     long_description_content_type='text/markdown',
39 |     author=author,
40 |     url=url,
41 | 
42 |     install_requires=_requirements(),
43 |     tests_require=_test_requirements(),
44 |     include_package_data=True,
45 | 
46 |     license='MIT',
47 |     packages=find_packages(exclude=('tests',)),
48 |     test_suite='tests',
49 |     entry_points="""
50 |     [console_scripts]
51 |     torchstat = torchstat.__main__:main
52 |     """,
53 | )
54 | 
--------------------------------------------------------------------------------
/utils/torchstat/test_requirements.txt:
--------------------------------------------------------------------------------
1 | pydocstyle
2 | 
--------------------------------------------------------------------------------
/utils/torchstat/torchstat/__init__.py:
--------------------------------------------------------------------------------
1 | __copyright__ = 'Copyright (C) 2018 Swall0w'
2 | __version__ = '0.0.7'
3 | __author__ = 'Swall0w'
4 | __url__ = 'https://github.com/Swall0w/torchstat'
5 | 
6 | from torchstat.compute_memory import compute_memory
7 | from torchstat.compute_madd import compute_madd
8 | from torchstat.compute_flops import compute_flops
9 | from torchstat.stat_tree import StatTree, StatNode
10 | from torchstat.model_hook import ModelHook
11 | from torchstat.reporter import report_format
12 | from torchstat.statistics import stat, ModelStat
13 | 
14 | __all__ = ['report_format', 'StatTree', 'StatNode', 'compute_madd',
15 |            'compute_flops', 'ModelHook', 'stat', 'ModelStat', '__main__',
16 |            'compute_memory']
17 | 
--------------------------------------------------------------------------------
/utils/torchstat/torchstat/__main__.py:
--------------------------------------------------------------------------------
1 | from torchstat import stat
2 | import argparse
3 | import importlib.util
4 | import torch
5 | 
6 | 
7 | def arg():
8 |     parser = argparse.ArgumentParser(description='Torch model statistics')
9 |     parser.add_argument('--file', '-f', type=str,
10 |                         help='Module file.')
11 |     parser.add_argument('--model', '-m', type=str,
12 |                         help='Model name')
13 |     parser.add_argument('--size', '-s', type=str, default='3x224x224',
14 |                         help='Input size. channels x height x width (default: 3x224x224)')
15 |     return parser.parse_args()
16 | 
17 | 
18 | def main():
19 |     args = arg()
20 |     try:
21 |         spec = importlib.util.spec_from_file_location('models', args.file)
22 |         module = importlib.util.module_from_spec(spec)
23 |         spec.loader.exec_module(module)
24 |         model = getattr(module, args.model)()
25 |     except Exception:
26 |         import traceback
27 |         print(f'Tried to import {args.model} from {args.file}, but failed.')
28 |         traceback.print_exc()
29 | 
30 |         import sys
31 |         sys.exit()
32 | 
33 |     input_size = tuple(int(x) for x in args.size.split('x'))
34 |     stat(model, input_size, query_granularity=1)
35 | 
--------------------------------------------------------------------------------
/utils/torchstat/torchstat/compute_flops.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch
3 | import numpy as np
4 | 
5 | 
6 | def compute_flops(module, inp, out):
7 |     if isinstance(module, nn.Conv2d):
8 |         return compute_Conv2d_flops(module, inp, out)
9 |     elif isinstance(module, nn.BatchNorm2d):
10 |         return compute_BatchNorm2d_flops(module, inp, out)
11 |     elif isinstance(module, (nn.AvgPool2d, nn.MaxPool2d)):
12 |         return compute_Pool2d_flops(module, inp, out)
13 |     elif isinstance(module, (nn.ReLU, nn.ReLU6, nn.PReLU, nn.ELU, nn.LeakyReLU)):
14 |         return compute_ReLU_flops(module, inp, out)
15 |     elif isinstance(module, nn.Upsample):
16 |         return compute_Upsample_flops(module, inp, out)
17 |     elif isinstance(module, nn.Linear):
18 |         return compute_Linear_flops(module, inp, out)
19 |     else:
20 |         print(f"[Flops]: {type(module).__name__} is not supported!")
21 |         return 0
22 |     pass
23 | 
24 | 
25 | def compute_Conv2d_flops(module, inp, out):
26 |     # Can have multiple inputs, getting the first one
27 |     assert isinstance(module, nn.Conv2d)
28 |     assert len(inp.size()) == 4 and len(inp.size()) == len(out.size())
29 | 
30 |     batch_size = inp.size()[0]
31 |     in_c = inp.size()[1]
32 |     k_h, k_w = module.kernel_size
33 |     out_c, out_h, out_w = out.size()[1:]
34 |     groups = module.groups
35 | 
36 |     filters_per_channel = out_c // groups
37 |     conv_per_position_flops = k_h * k_w * in_c * filters_per_channel
38 |     active_elements_count = batch_size * out_h * out_w
39 | 
40 |     total_conv_flops = conv_per_position_flops * active_elements_count
41 | 
42 |     bias_flops = 0
43 |     if module.bias is not None:
44 |         bias_flops = out_c * active_elements_count
45 | 
46 |     total_flops = total_conv_flops + bias_flops
47 |     return total_flops
48 | 
49 | 
50 | def compute_BatchNorm2d_flops(module, inp, out):
51 |     assert isinstance(module, nn.BatchNorm2d)
52 |     assert len(inp.size()) == 4 and len(inp.size()) == len(out.size())
53 |     in_c, in_h, in_w = inp.size()[1:]
54 |     batch_flops = np.prod(inp.shape)
55 |     if module.affine:
56 |         batch_flops *= 2
57 |     return batch_flops
58 | 
59 | 
60 | def compute_ReLU_flops(module, inp, out):
61 |     assert isinstance(module, (nn.ReLU, nn.ReLU6, nn.PReLU, nn.ELU, nn.LeakyReLU))
62 |     batch_size = inp.size()[0]
63 |     active_elements_count = batch_size
64 | 
65 |     for s in inp.size()[1:]:
66 |         active_elements_count *= s
67 | 
68 |     return active_elements_count
69 | 
70 | 
71 | def compute_Pool2d_flops(module, inp, out):
72 |     assert isinstance(module, nn.MaxPool2d) or isinstance(module, nn.AvgPool2d)
73 |     assert len(inp.size()) == 4 and len(inp.size()) == len(out.size())
74 |     return np.prod(inp.shape)
75 | 
76 | 
77 | def compute_Linear_flops(module, inp, out):
78 |     assert isinstance(module, nn.Linear)
79 |     assert len(inp.size()) == 2 and len(out.size()) == 2
80 |     batch_size = inp.size()[0]
81 |     return batch_size * inp.size()[1] * out.size()[1]
82 | 
83 | def compute_Upsample_flops(module, inp, out):
84 |     assert isinstance(module, nn.Upsample)
85 |     output_size = out.size()
86 |     batch_size = inp.size()[0]
87 |     output_elements_count = batch_size
88 |     for s in output_size[1:]:
89 |         output_elements_count *= s
90 | 
91 |     return output_elements_count
92 | 
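
# Quick sanity check (an illustrative sketch): this mirrors the conv1 layer
# from example.py, and the expected value matches the conv1 Flops column in
# the README's sample output.
if __name__ == '__main__':
    _conv = nn.Conv2d(3, 10, kernel_size=5)
    _x = torch.rand(1, 3, 224, 224)
    _y = _conv(_x)
    # 5*5*3*10 ops per output position over 220*220 positions, plus
    # 10*220*220 bias adds: 36,300,000 + 484,000 = 36,784,000
    print(compute_Conv2d_flops(_conv, _x, _y))  # expected: 36784000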
-------------------------------------------------------------------------------- /utils/torchstat/torchstat/compute_madd.py: -------------------------------------------------------------------------------- 1 | """ 2 | compute Multiply-Adds(MAdd) of each leaf module 3 | """ 4 | 5 | import torch.nn as nn 6 | 7 | 8 | def compute_Conv2d_madd(module, inp, out): 9 | assert isinstance(module, nn.Conv2d) 10 | assert len(inp.size()) == 4 and len(inp.size()) == len(out.size()) 11 | 12 | in_c = inp.size()[1] 13 | k_h, k_w = module.kernel_size 14 | out_c, out_h, out_w = out.size()[1:] 15 | groups = module.groups 16 | 17 | # ops per output element 18 | kernel_mul = k_h * k_w * (in_c // groups) 19 | kernel_add = kernel_mul - 1 + (0 if module.bias is None else 1) 20 | 21 | kernel_mul_group = kernel_mul * out_h * out_w * (out_c // groups) 22 | kernel_add_group = kernel_add * out_h * out_w * (out_c // groups) 23 | 24 | total_mul = kernel_mul_group * groups 25 | total_add = kernel_add_group * groups 26 | 27 | return total_mul + total_add 28 | 29 | 30 | def compute_ConvTranspose2d_madd(module, inp, out): 31 | assert isinstance(module, nn.ConvTranspose2d) 32 | assert len(inp.size()) == 4 and len(inp.size()) == len(out.size()) 33 | 34 | in_c, in_h, in_w = inp.size()[1:] 35 | k_h, k_w = module.kernel_size 36 | out_c, out_h, out_w = out.size()[1:] 37 | groups = module.groups 38 | 39 | kernel_mul = k_h * k_w * (in_c // groups) 40 | kernel_add = kernel_mul - 1 + (0 if module.bias is None else 1) 41 | 42 | kernel_mul_group = kernel_mul * in_h * in_w * (out_c // groups) 43 | kernel_add_group = kernel_add * in_h * in_w * (out_c // groups) 44 | 45 | total_mul = kernel_mul_group * groups 46 | total_add = kernel_add_group * groups 47 | 48 | return total_mul + total_add 49 | 50 | 51 | def compute_BatchNorm2d_madd(module, inp, out): 52 | assert isinstance(module, nn.BatchNorm2d) 53 | assert len(inp.size()) == 4 and len(inp.size()) == len(out.size()) 54 | 55 | in_c, in_h, in_w = inp.size()[1:] 56 | 57 | # 1. sub mean 58 | # 2. div standard deviation 59 | # 3. mul alpha 60 | # 4. 
add beta 61 | return 4 * in_c * in_h * in_w 62 | 63 | 64 | def compute_MaxPool2d_madd(module, inp, out): 65 | assert isinstance(module, nn.MaxPool2d) 66 | assert len(inp.size()) == 4 and len(inp.size()) == len(out.size()) 67 | 68 | if isinstance(module.kernel_size, (tuple, list)): 69 | k_h, k_w = module.kernel_size 70 | else: 71 | k_h, k_w = module.kernel_size, module.kernel_size 72 | out_c, out_h, out_w = out.size()[1:] 73 | 74 | return (k_h * k_w - 1) * out_h * out_w * out_c 75 | 76 | 77 | def compute_AvgPool2d_madd(module, inp, out): 78 | assert isinstance(module, nn.AvgPool2d) 79 | assert len(inp.size()) == 4 and len(inp.size()) == len(out.size()) 80 | 81 | if isinstance(module.kernel_size, (tuple, list)): 82 | k_h, k_w = module.kernel_size 83 | else: 84 | k_h, k_w = module.kernel_size, module.kernel_size 85 | out_c, out_h, out_w = out.size()[1:] 86 | 87 | kernel_add = k_h * k_w - 1 88 | kernel_avg = 1 89 | 90 | return (kernel_add + kernel_avg) * (out_h * out_w) * out_c 91 | 92 | 93 | def compute_ReLU_madd(module, inp, out): 94 | assert isinstance(module, (nn.ReLU, nn.ReLU6)) 95 | 96 | count = 1 97 | for i in inp.size()[1:]: 98 | count *= i 99 | return count 100 | 101 | 102 | def compute_Softmax_madd(module, inp, out): 103 | assert isinstance(module, nn.Softmax) 104 | assert len(inp.size()) > 1 105 | 106 | count = 1 107 | for s in inp.size()[1:]: 108 | count *= s 109 | exp = count 110 | add = count - 1 111 | div = count 112 | return exp + add + div 113 | 114 | 115 | def compute_Linear_madd(module, inp, out): 116 | assert isinstance(module, nn.Linear) 117 | assert len(inp.size()) == 2 and len(out.size()) == 2 118 | 119 | num_in_features = inp.size()[1] 120 | num_out_features = out.size()[1] 121 | 122 | mul = num_in_features 123 | add = num_in_features - 1 124 | return num_out_features * (mul + add) 125 | 126 | 127 | def compute_Bilinear_madd(module, inp1, inp2, out): 128 | assert isinstance(module, nn.Bilinear) 129 | assert len(inp1.size()) == 2 and len(inp2.size()) == 2 and len(out.size()) == 2 130 | 131 | num_in_features_1 = inp1.size()[1] 132 | num_in_features_2 = inp2.size()[1] 133 | num_out_features = out.size()[1] 134 | 135 | mul = num_in_features_1 * num_in_features_2 + num_in_features_2 136 | add = num_in_features_1 * num_in_features_2 + num_in_features_2 - 1 137 | return num_out_features * (mul + add) 138 | 139 | 140 | def compute_madd(module, inp, out): 141 | if isinstance(module, nn.Conv2d): 142 | return compute_Conv2d_madd(module, inp, out) 143 | elif isinstance(module, nn.ConvTranspose2d): 144 | return compute_ConvTranspose2d_madd(module, inp, out) 145 | elif isinstance(module, nn.BatchNorm2d): 146 | return compute_BatchNorm2d_madd(module, inp, out) 147 | elif isinstance(module, nn.MaxPool2d): 148 | return compute_MaxPool2d_madd(module, inp, out) 149 | elif isinstance(module, nn.AvgPool2d): 150 | return compute_AvgPool2d_madd(module, inp, out) 151 | elif isinstance(module, (nn.ReLU, nn.ReLU6)): 152 | return compute_ReLU_madd(module, inp, out) 153 | elif isinstance(module, nn.Softmax): 154 | return compute_Softmax_madd(module, inp, out) 155 | elif isinstance(module, nn.Linear): 156 | return compute_Linear_madd(module, inp, out) 157 | elif isinstance(module, nn.Bilinear): 158 | return compute_Bilinear_madd(module, inp[0], inp[1], out) 159 | else: 160 | print(f"[MAdd]: {type(module).__name__} is not supported!") 161 | return 0 162 | -------------------------------------------------------------------------------- /utils/torchstat/torchstat/compute_memory.py: 
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch
3 | import numpy as np
4 | 
5 | 
6 | def compute_memory(module, inp, out):
7 |     if isinstance(module, (nn.ReLU, nn.ReLU6, nn.ELU, nn.LeakyReLU)):
8 |         return compute_ReLU_memory(module, inp, out)
9 |     elif isinstance(module, nn.PReLU):
10 |         return compute_PReLU_memory(module, inp, out)
11 |     elif isinstance(module, nn.Conv2d):
12 |         return compute_Conv2d_memory(module, inp, out)
13 |     elif isinstance(module, nn.BatchNorm2d):
14 |         return compute_BatchNorm2d_memory(module, inp, out)
15 |     elif isinstance(module, nn.Linear):
16 |         return compute_Linear_memory(module, inp, out)
17 |     elif isinstance(module, (nn.AvgPool2d, nn.MaxPool2d)):
18 |         return compute_Pool2d_memory(module, inp, out)
19 |     else:
20 |         print(f"[Memory]: {type(module).__name__} is not supported!")
21 |         return (0, 0)
22 |     pass
23 | 
24 | 
25 | def num_params(module):
26 |     return sum(p.numel() for p in module.parameters() if p.requires_grad)
27 | 
28 | 
29 | def compute_ReLU_memory(module, inp, out):
30 |     assert isinstance(module, (nn.ReLU, nn.ReLU6, nn.ELU, nn.LeakyReLU))
31 |     batch_size = inp.size()[0]
32 |     mread = batch_size * inp.size()[1:].numel()
33 |     mwrite = batch_size * inp.size()[1:].numel()
34 | 
35 |     return (mread, mwrite)
36 | 
37 | 
38 | def compute_PReLU_memory(module, inp, out):
39 |     assert isinstance(module, nn.PReLU)
40 |     batch_size = inp.size()[0]
41 |     mread = batch_size * (inp.size()[1:].numel() + num_params(module))
42 |     mwrite = batch_size * inp.size()[1:].numel()
43 | 
44 |     return (mread, mwrite)
45 | 
46 | 
47 | def compute_Conv2d_memory(module, inp, out):
48 |     # Can have multiple inputs, getting the first one
49 |     assert isinstance(module, nn.Conv2d)
50 |     assert len(inp.size()) == 4 and len(inp.size()) == len(out.size())
51 | 
52 |     batch_size = inp.size()[0]
53 |     in_c = inp.size()[1]
54 |     out_c, out_h, out_w = out.size()[1:]
55 | 
56 |     # This includes the weights, plus the bias if the module has one.
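    # mread:  every input element plus every trainable parameter element is
    #         fetched once; mwrite: one write per output element.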
57 |     mread = batch_size * (inp.size()[1:].numel() + num_params(module))
58 |     mwrite = batch_size * out_c * out_h * out_w
59 |     return (mread, mwrite)
60 | 
61 | 
62 | def compute_BatchNorm2d_memory(module, inp, out):
63 |     assert isinstance(module, nn.BatchNorm2d)
64 |     assert len(inp.size()) == 4 and len(inp.size()) == len(out.size())
65 |     batch_size, in_c, in_h, in_w = inp.size()
66 | 
67 |     mread = batch_size * (inp.size()[1:].numel() + 2 * in_c)
68 |     mwrite = inp.size().numel()
69 |     return (mread, mwrite)
70 | 
71 | 
72 | def compute_Linear_memory(module, inp, out):
73 |     assert isinstance(module, nn.Linear)
74 |     assert len(inp.size()) == 2 and len(out.size()) == 2
75 |     batch_size = inp.size()[0]
76 |     mread = batch_size * (inp.size()[1:].numel() + num_params(module))
77 |     mwrite = out.size().numel()
78 | 
79 |     return (mread, mwrite)
80 | 
81 | 
82 | def compute_Pool2d_memory(module, inp, out):
83 |     assert isinstance(module, (nn.MaxPool2d, nn.AvgPool2d))
84 |     assert len(inp.size()) == 4 and len(inp.size()) == len(out.size())
85 |     batch_size = inp.size()[0]
86 |     mread = batch_size * inp.size()[1:].numel()
87 |     mwrite = batch_size * out.size()[1:].numel()
88 |     return (mread, mwrite)
89 | 
--------------------------------------------------------------------------------
/utils/torchstat/torchstat/model_hook.py:
--------------------------------------------------------------------------------
1 | import time
2 | from collections import OrderedDict
3 | import numpy as np
4 | import torch
5 | import torch.nn as nn
6 | 
7 | from torchstat import compute_madd
8 | from torchstat import compute_flops
9 | from torchstat import compute_memory
10 | 
11 | 
12 | class ModelHook(object):
13 |     def __init__(self, model, input_size):
14 |         assert isinstance(model, nn.Module)
15 |         assert isinstance(input_size, (list, tuple))
16 | 
17 |         self._model = model
18 |         self._input_size = input_size
19 |         self._origin_call = dict()  # sub module call hook
20 | 
21 |         self._hook_model()
22 |         x = torch.rand(1, *self._input_size).to(next(self._model.parameters()).device)  # dummy forward pass to record per-module stats
23 |         self._model.eval()
24 |         self._model(x)
25 | 
26 |     @staticmethod
27 |     def _register_buffer(module):
28 |         assert isinstance(module, nn.Module)
29 | 
30 |         if len(list(module.children())) > 0:
31 |             return
32 | 
33 |         module.register_buffer('input_shape', torch.zeros(3).int())
34 |         module.register_buffer('output_shape', torch.zeros(3).int())
35 |         module.register_buffer('parameter_quantity', torch.zeros(1).int())
36 |         module.register_buffer('inference_memory', torch.zeros(1).long())
37 |         module.register_buffer('MAdd', torch.zeros(1).long())
38 |         module.register_buffer('duration', torch.zeros(1).float())
39 |         module.register_buffer('Flops', torch.zeros(1).long())
40 |         module.register_buffer('Memory', torch.zeros(2).long())
41 | 
42 |     def _sub_module_call_hook(self):
43 |         def wrap_call(module, *input, **kwargs):
44 |             assert module.__class__ in self._origin_call
45 | 
46 |             # Itemsize for memory
47 |             itemsize = input[0].cpu().detach().numpy().itemsize
48 | 
49 |             start = time.time()
50 |             output = self._origin_call[module.__class__](module, *input, **kwargs)
51 |             end = time.time()
52 |             module.duration = torch.from_numpy(
53 |                 np.array([end - start], dtype=np.float32))
54 | 
55 |             module.input_shape = torch.from_numpy(
56 |                 np.array(input[0].size()[1:], dtype=np.int32))
57 |             module.output_shape = torch.from_numpy(
58 |                 np.array(output.size()[1:], dtype=np.int32))
59 | 
60 |             parameter_quantity = 0
61 |             # iterate through parameters and count num params
62 |             for name, p in module._parameters.items():
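                # entries in _parameters can be None (declared but unset), so count them as 0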
63 |                 parameter_quantity += (0 if p is None else torch.numel(p.data))
64 |             module.parameter_quantity = torch.from_numpy(
65 |                 np.array([parameter_quantity], dtype=np.int64))
66 | 
67 |             inference_memory = 1
68 |             for s in output.size()[1:]:
69 |                 inference_memory *= s
70 |             # memory += parameters_number  # exclude parameter memory
71 |             inference_memory = inference_memory * 4 / (1024 ** 2)  # shown as MB unit
72 |             module.inference_memory = torch.from_numpy(
73 |                 np.array([inference_memory], dtype=np.float32))
74 | 
75 |             if len(input) == 1:
76 |                 madd = compute_madd(module, input[0], output)
77 |                 flops = compute_flops(module, input[0], output)
78 |                 Memory = compute_memory(module, input[0], output)
79 |             elif len(input) > 1:
80 |                 madd = compute_madd(module, input, output)
81 |                 flops = compute_flops(module, input, output)
82 |                 Memory = compute_memory(module, input, output)
83 |             else:  # error
84 |                 madd = 0
85 |                 flops = 0
86 |                 Memory = (0, 0)
87 |             module.MAdd = torch.from_numpy(
88 |                 np.array([madd], dtype=np.int64))
89 |             module.Flops = torch.from_numpy(
90 |                 np.array([flops], dtype=np.int64))
91 |             Memory = np.array(Memory, dtype=np.int32) * itemsize
92 |             module.Memory = torch.from_numpy(Memory)
93 | 
94 |             return output
95 | 
96 |         for module in self._model.modules():
97 |             if len(list(module.children())) == 0 and module.__class__ not in self._origin_call:
98 |                 self._origin_call[module.__class__] = module.__class__.__call__
99 |                 module.__class__.__call__ = wrap_call
100 | 
101 |     def _hook_model(self):
102 |         self._model.apply(self._register_buffer)
103 |         self._sub_module_call_hook()
104 | 
105 |     @staticmethod
106 |     def _retrieve_leaf_modules(model):
107 |         leaf_modules = []
108 |         for name, m in model.named_modules():
109 |             if len(list(m.children())) == 0:
110 |                 leaf_modules.append((name, m))
111 |         return leaf_modules
112 | 
113 |     def retrieve_leaf_modules(self):
114 |         return OrderedDict(self._retrieve_leaf_modules(self._model))
115 | 
--------------------------------------------------------------------------------
/utils/torchstat/torchstat/reporter.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 | 
3 | 
4 | pd.set_option('display.width', 1000)
5 | pd.set_option('display.max_rows', 10000)
6 | pd.set_option('display.max_columns', 10000)
7 | 
8 | 
9 | def round_value(value, binary=False):
10 |     divisor = 1024. if binary else 1000.
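    # binary=True -> 1024-based steps (byte quantities such as MemR+W);
    # binary=False -> 1000-based steps (operation counts such as MAdd/Flops)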
11 | 
12 |     if value // divisor**3 > 0:
13 |         return str(round(value / divisor**3, 2)) + 'G'
14 |     elif value // divisor**2 > 0:
15 |         return str(round(value / divisor**2, 2)) + 'M'
16 |     elif value // divisor > 0:
17 |         return str(round(value / divisor, 2)) + 'K'
18 |     return str(value)
19 | 
20 | 
21 | def report_format(collected_nodes):
22 |     data = list()
23 |     for node in collected_nodes:
24 |         name = node.name
25 |         input_shape = ' '.join(['{:>3d}'] * len(node.input_shape)).format(
26 |             *[e for e in node.input_shape])
27 |         output_shape = ' '.join(['{:>3d}'] * len(node.output_shape)).format(
28 |             *[e for e in node.output_shape])
29 |         parameter_quantity = node.parameter_quantity
30 |         inference_memory = node.inference_memory
31 |         MAdd = node.MAdd
32 |         Flops = node.Flops
33 |         mread, mwrite = [i for i in node.Memory]
34 |         duration = node.duration
35 |         data.append([name, input_shape, output_shape, parameter_quantity,
36 |                      inference_memory, MAdd, duration, Flops, mread,
37 |                      mwrite])
38 |     df = pd.DataFrame(data)
39 |     df.columns = ['module name', 'input shape', 'output shape',
40 |                   'params', 'memory(MB)',
41 |                   'MAdd', 'duration', 'Flops', 'MemRead(B)', 'MemWrite(B)']
42 |     df['duration[%]'] = df['duration'] / (df['duration'].sum() + 1e-7)
43 |     df['MemR+W(B)'] = df['MemRead(B)'] + df['MemWrite(B)']
44 |     total_parameters_quantity = df['params'].sum()
45 |     total_memory = df['memory(MB)'].sum()
46 |     total_operation_quantity = df['MAdd'].sum()
47 |     total_flops = df['Flops'].sum()
48 |     total_duration = df['duration[%]'].sum()
49 |     total_mread = df['MemRead(B)'].sum()
50 |     total_mwrite = df['MemWrite(B)'].sum()
51 |     total_memrw = df['MemR+W(B)'].sum()
52 |     del df['duration']
53 | 
54 |     # Add Total row
55 |     total_df = pd.Series([total_parameters_quantity, total_memory,
56 |                           total_operation_quantity, total_flops,
57 |                           total_duration, total_mread, total_mwrite, total_memrw],
58 |                          index=['params', 'memory(MB)', 'MAdd', 'Flops', 'duration[%]',
59 |                                 'MemRead(B)', 'MemWrite(B)', 'MemR+W(B)'],
60 |                          name='total')
61 |     df = df.append(total_df)
62 | 
63 |     df = df.fillna(' ')
64 |     df['memory(MB)'] = df['memory(MB)'].apply(
65 |         lambda x: '{:.2f}'.format(x))
66 |     df['duration[%]'] = df['duration[%]'].apply(lambda x: '{:.2%}'.format(x))
67 |     df['MAdd'] = df['MAdd'].apply(lambda x: '{:,}'.format(x))
68 |     df['Flops'] = df['Flops'].apply(lambda x: '{:,}'.format(x))
69 | 
70 |     summary = str(df) + '\n'
71 |     summary += "=" * len(str(df).split('\n')[0])
72 |     summary += '\n'
73 |     summary += "Total params: {:,}\n".format(total_parameters_quantity)
74 | 
75 |     summary += "-" * len(str(df).split('\n')[0])
76 |     summary += '\n'
77 |     summary += "Total memory: {:.2f}MB\n".format(total_memory)
78 |     summary += "Total MAdd: {}MAdd\n".format(round_value(total_operation_quantity))
79 |     summary += "Total Flops: {}Flops\n".format(round_value(total_flops))
80 |     summary += "Total MemR+W: {}B\n".format(round_value(total_memrw, True))
81 |     return summary
82 | 
--------------------------------------------------------------------------------
/utils/torchstat/torchstat/stat_tree.py:
--------------------------------------------------------------------------------
1 | import queue
2 | 
3 | 
4 | class StatTree(object):
5 |     def __init__(self, root_node):
6 |         assert isinstance(root_node, StatNode)
7 | 
8 |         self.root_node = root_node
9 | 
10 |     def get_same_level_max_node_depth(self, query_node):
11 |         if query_node.name == self.root_node.name:
12 |             return 0
13 |         same_level_depth = max([child.depth for child in query_node.parent.children])
14 |         return same_level_depth
15 | 
16 |     def 
update_stat_nodes_granularity(self): 17 | q = queue.Queue() 18 | q.put(self.root_node) 19 | while not q.empty(): 20 | node = q.get() 21 | node.granularity = self.get_same_level_max_node_depth(node) 22 | for child in node.children: 23 | q.put(child) 24 | 25 | def get_collected_stat_nodes(self, query_granularity): 26 | self.update_stat_nodes_granularity() 27 | 28 | collected_nodes = [] 29 | stack = list() 30 | stack.append(self.root_node) 31 | while len(stack) > 0: 32 | node = stack.pop() 33 | for child in reversed(node.children): 34 | stack.append(child) 35 | if node.depth == query_granularity: 36 | collected_nodes.append(node) 37 | if node.depth < query_granularity <= node.granularity: 38 | collected_nodes.append(node) 39 | return collected_nodes 40 | 41 | 42 | class StatNode(object): 43 | def __init__(self, name=str(), parent=None): 44 | self._name = name 45 | self._input_shape = None 46 | self._output_shape = None 47 | self._parameter_quantity = 0 48 | self._inference_memory = 0 49 | self._MAdd = 0 50 | self._Memory = (0, 0) 51 | self._Flops = 0 52 | self._duration = 0 53 | self._duration_percent = 0 54 | 55 | self._granularity = 1 56 | self._depth = 1 57 | self.parent = parent 58 | self.children = list() 59 | 60 | @property 61 | def name(self): 62 | return self._name 63 | 64 | @name.setter 65 | def name(self, name): 66 | self._name = name 67 | 68 | @property 69 | def granularity(self): 70 | return self._granularity 71 | 72 | @granularity.setter 73 | def granularity(self, g): 74 | self._granularity = g 75 | 76 | @property 77 | def depth(self): 78 | d = self._depth 79 | if len(self.children) > 0: 80 | d += max([child.depth for child in self.children]) 81 | return d 82 | 83 | @property 84 | def input_shape(self): 85 | if len(self.children) == 0: # leaf 86 | return self._input_shape 87 | else: 88 | return self.children[0].input_shape 89 | 90 | @input_shape.setter 91 | def input_shape(self, input_shape): 92 | assert isinstance(input_shape, (list, tuple)) 93 | self._input_shape = input_shape 94 | 95 | @property 96 | def output_shape(self): 97 | if len(self.children) == 0: # leaf 98 | return self._output_shape 99 | else: 100 | return self.children[-1].output_shape 101 | 102 | @output_shape.setter 103 | def output_shape(self, output_shape): 104 | assert isinstance(output_shape, (list, tuple)) 105 | self._output_shape = output_shape 106 | 107 | @property 108 | def parameter_quantity(self): 109 | # return self.parameters_quantity 110 | total_parameter_quantity = self._parameter_quantity 111 | for child in self.children: 112 | total_parameter_quantity += child.parameter_quantity 113 | return total_parameter_quantity 114 | 115 | @parameter_quantity.setter 116 | def parameter_quantity(self, parameter_quantity): 117 | assert parameter_quantity >= 0 118 | self._parameter_quantity = parameter_quantity 119 | 120 | @property 121 | def inference_memory(self): 122 | total_inference_memory = self._inference_memory 123 | for child in self.children: 124 | total_inference_memory += child.inference_memory 125 | return total_inference_memory 126 | 127 | @inference_memory.setter 128 | def inference_memory(self, inference_memory): 129 | self._inference_memory = inference_memory 130 | 131 | @property 132 | def MAdd(self): 133 | total_MAdd = self._MAdd 134 | for child in self.children: 135 | total_MAdd += child.MAdd 136 | return total_MAdd 137 | 138 | @MAdd.setter 139 | def MAdd(self, MAdd): 140 | self._MAdd = MAdd 141 | 142 | @property 143 | def Flops(self): 144 | total_Flops = self._Flops 145 | for child in 
self.children:
146 |             total_Flops += child.Flops
147 |         return total_Flops
148 | 
149 |     @Flops.setter
150 |     def Flops(self, Flops):
151 |         self._Flops = Flops
152 | 
153 |     @property
154 |     def Memory(self):
155 |         total_Memory = list(self._Memory)
156 |         for child in self.children:
157 |             total_Memory[0] += child.Memory[0]
158 |             total_Memory[1] += child.Memory[1]
159 | 
160 |         return total_Memory
161 | 
162 |     @Memory.setter
163 |     def Memory(self, Memory):
164 |         assert isinstance(Memory, (list, tuple))
165 |         self._Memory = Memory
166 | 
167 |     @property
168 |     def duration(self):
169 |         total_duration = self._duration
170 |         for child in self.children:
171 |             total_duration += child.duration
172 |         return total_duration
173 | 
174 |     @duration.setter
175 |     def duration(self, duration):
176 |         self._duration = duration
177 | 
178 |     def find_child_index(self, child_name):
179 |         assert isinstance(child_name, str)
180 | 
181 |         index = -1
182 |         for i in range(len(self.children)):
183 |             if child_name == self.children[i].name:
184 |                 index = i
185 |         return index
186 | 
187 |     def add_child(self, node):
188 |         assert isinstance(node, StatNode)
189 | 
190 |         if self.find_child_index(node.name) == -1:  # not exist
191 |             self.children.append(node)
--------------------------------------------------------------------------------
/utils/torchstat/torchstat/statistics.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torchstat import ModelHook
4 | from collections import OrderedDict
5 | from torchstat import StatTree, StatNode, report_format
6 | 
7 | 
8 | def get_parent_node(root_node, stat_node_name):
9 |     assert isinstance(root_node, StatNode)
10 | 
11 |     node = root_node
12 |     names = stat_node_name.split('.')
13 |     for i in range(len(names) - 1):
14 |         node_name = '.'.join(names[0:i+1])
15 |         child_index = node.find_child_index(node_name)
16 |         assert child_index != -1
17 |         node = node.children[child_index]
18 |     return node
19 | 
20 | 
21 | def convert_leaf_modules_to_stat_tree(leaf_modules):
22 |     assert isinstance(leaf_modules, OrderedDict)
23 | 
24 |     create_index = 1
25 |     root_node = StatNode(name='root', parent=None)
26 |     for leaf_module_name, leaf_module in leaf_modules.items():
27 |         names = leaf_module_name.split('.')
28 |         for i in range(len(names)):
29 |             create_index += 1
30 |             stat_node_name = '.'.join(names[0:i+1])
31 |             parent_node = get_parent_node(root_node, stat_node_name)
32 |             node = StatNode(name=stat_node_name, parent=parent_node)
33 |             parent_node.add_child(node)
34 |             if i == len(names) - 1:  # leaf module itself
35 |                 input_shape = leaf_module.input_shape.numpy().tolist()
36 |                 output_shape = leaf_module.output_shape.numpy().tolist()
37 |                 node.input_shape = input_shape
38 |                 node.output_shape = output_shape
39 |                 node.parameter_quantity = leaf_module.parameter_quantity.numpy()[0]
40 |                 node.inference_memory = leaf_module.inference_memory.numpy()[0]
41 |                 node.MAdd = leaf_module.MAdd.numpy()[0]
42 |                 node.Flops = leaf_module.Flops.numpy()[0]
43 |                 node.duration = leaf_module.duration.numpy()[0]
44 |                 node.Memory = leaf_module.Memory.numpy().tolist()
45 |     return StatTree(root_node)
46 | 
47 | 
48 | class ModelStat(object):
49 |     def __init__(self, model, input_size, query_granularity=1):
50 |         assert isinstance(model, nn.Module)
51 |         assert isinstance(input_size, (tuple, list)) and len(input_size) == 3
52 |         self._model = model
53 |         self._input_size = input_size
54 |         self._query_granularity = query_granularity
55 | 
56 |     def _analyze_model(self):
57 |         model_hook = ModelHook(self._model, self._input_size)
58 |         leaf_modules = model_hook.retrieve_leaf_modules()
59 |         stat_tree = convert_leaf_modules_to_stat_tree(leaf_modules)
60 |         collected_nodes = stat_tree.get_collected_stat_nodes(self._query_granularity)
61 |         return collected_nodes
62 | 
63 |     def show_report(self):
64 |         collected_nodes = self._analyze_model()
65 |         report = report_format(collected_nodes)
66 |         print(report)
67 | 
68 | 
69 | def stat(model, input_size, query_granularity=1):
70 |     ms = ModelStat(model, input_size, query_granularity)
71 |     ms.show_report()
72 | 
--------------------------------------------------------------------------------
/utils/video_metric_utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | from base.file_io_base import write_csv
4 | from base.os_base import listdir
5 | import utils.LPIPS.models as lpips_models
6 | from utils.image_metric_utils import calc_image_PSNR_SSIM, calc_image_LPIPS, calc_image_NIQE
7 | 
8 | 
9 | def calc_video_PSNR_SSIM(output_root, gt_root, crop_border=4, shift_window_size=0, test_ycbcr=False, crop_GT=False):
10 |     '''
11 |     Calculate video PSNR/SSIM, following EDVR's evaluation approach
12 |     The files in output_root and gt_root must correspond one-to-one in sorted order
13 |     '''
14 | 
15 |     PSNR_sum = 0.
16 |     SSIM_sum = 0.
17 |     img_num = 0
18 | 
19 |     video_PSNR = []
20 |     video_SSIM = []
21 | 
22 |     video_list = sorted(listdir(output_root))
23 |     for v in video_list:
24 |         v_PSNR_list, v_SSIM_list, _ = calc_image_PSNR_SSIM(
25 |             output_root=os.path.join(output_root, v),
26 |             gt_root=os.path.join(gt_root, v),
27 |             crop_border=crop_border, shift_window_size=shift_window_size,
28 |             test_ycbcr=test_ycbcr, crop_GT=crop_GT
29 |         )
30 |         PSNR_sum += sum(v_PSNR_list)
31 |         SSIM_sum += sum(v_SSIM_list)
32 |         img_num += len(v_PSNR_list)
33 | 
34 |         video_PSNR.append({
35 |             'video_name': v,
36 |             'psnr': v_PSNR_list
37 |         })
38 |         video_SSIM.append({
39 |             'video_name': v,
40 |             'ssim': v_SSIM_list
41 |         })
42 | 
43 |     logs = []
44 |     PSNR_SSIM_csv_log = {
45 |         'col_names': [],
46 |         'row_names': [output_root],
47 |         'psnr_ssim': [[]]
48 |     }
49 |     for v_psnr, v_ssim in zip(video_PSNR, video_SSIM):
50 |         PSNR_SSIM_csv_log['col_names'].append('#{}'.format(v_psnr['video_name']))
51 |         PSNR_SSIM_csv_log['psnr_ssim'][0].append('{:.5}/{:.4}'.format(sum(v_psnr['psnr']) / len(v_psnr['psnr']),
52 |                                                                       sum(v_ssim['ssim']) / len(v_ssim['ssim'])))
53 |         log = 'Video: {} PSNR={:.5}, SSIM={:.4}'.format(v_psnr['video_name'],
54 |                                                         sum(v_psnr['psnr']) / len(v_psnr['psnr']),
55 |                                                         sum(v_ssim['ssim']) / len(v_ssim['ssim']))
56 |         print(log)
57 |         logs.append(log)
58 |     PSNR_SSIM_csv_log['col_names'].append('AVG')
59 |     PSNR_SSIM_csv_log['psnr_ssim'][0].append('{:.5}/{:.4}'.format(PSNR_sum / img_num, SSIM_sum / img_num))
60 |     log = 'Average PSNR={:.5}, SSIM={:.4}'.format(PSNR_sum / img_num, SSIM_sum / img_num)
61 |     print(log)
62 |     logs.append(log)
63 | 
64 |     return PSNR_SSIM_csv_log, logs
65 | 
66 | 
67 | def batch_calc_video_PSNR_SSIM(root_list, crop_border=4, shift_window_size=0, test_ycbcr=False, crop_GT=False,
68 |                                save_log=False, save_log_root=None, combine_save=False):
69 |     '''
70 |     required params:
71 |         root_list: a list, each item should be a dictionary with two key-values:
72 |             output: the dir of output videos
73 |             gt: the dir of gt videos
74 |     optional params:
75 |         crop_border: default=4, crop pixels when calculating PSNR/SSIM
76 |         shift_window_size: default=0, if >0, shift the image within a window for the best metric
77 |         test_ycbcr: default=False, if True, applying the YCbCr color space
78 |         crop_GT: default=False, if True, cropping GT to the output size
79 |         save_log: default=False, if True, saving a csv log
80 |         save_log_root: the dir of the output log
81 |         combine_save: default=False, if True, combining all output logs into one csv file
82 |     return:
83 |         log_list: a list, each item is a dictionary with two key-values:
84 |             data_path: the evaluated dir
85 |             log: the log of this dir
86 |     '''
87 |     if save_log:
88 |         assert save_log_root is not None, "Unknown save_log_root!"
89 | 
90 |     total_csv_log = []
91 |     log_list = []
92 |     for i, root in enumerate(root_list):
93 |         output_root = root['output']
94 |         gt_root = root['gt']
95 |         print(">>>> Now Evaluating >>>>")
96 |         print(">>>> OUTPUT: {}".format(output_root))
97 |         print(">>>> GT: {}".format(gt_root))
98 |         csv_log, logs = calc_video_PSNR_SSIM(
99 |             output_root, gt_root, crop_border=crop_border, shift_window_size=shift_window_size,
100 |             test_ycbcr=test_ycbcr, crop_GT=crop_GT
101 |         )
102 |         log_list.append({
103 |             'data_path': output_root,
104 |             'log': logs
105 |         })
106 | 
107 |         # output the PSNR/SSIM log of each evaluated dir to a single csv file
108 |         if save_log:
109 |             csv_log['row_names'] = [os.path.basename(p) for p in csv_log['row_names']]
110 |             write_csv(file_path=os.path.join(save_log_root, "{}_{}.csv".format(i, csv_log['row_names'][0])),
111 |                       data=np.array(csv_log['psnr_ssim']),
112 |                       row_names=csv_log['row_names'],
113 |                       col_names=csv_log['col_names'])
114 |             total_csv_log.append(csv_log)
115 | 
116 |     # output all PSNR/SSIM logs to a csv file
117 |     if save_log and combine_save and len(total_csv_log) > 0:
118 |         com_csv_log = {
119 |             'col_names': total_csv_log[0]['col_names'],
120 |             'row_names': [],
121 |             'psnr_ssim': []
122 |         }
123 |         for csv_log in total_csv_log:
124 |             com_csv_log['row_names'].append(csv_log['row_names'][0])
125 |             com_csv_log['psnr_ssim'].append(csv_log['psnr_ssim'][0])
126 |         write_csv(file_path=os.path.join(save_log_root, "psnr_ssim.csv"),
127 |                   data=np.array(com_csv_log['psnr_ssim']),
128 |                   row_names=com_csv_log['row_names'],
129 |                   col_names=com_csv_log['col_names'])
130 | 
131 |     print("--------------------------------------------------------------------------------------")
132 |     for i, logs in enumerate(log_list):
133 |         print("## The {}-th:".format(i))
134 |         print(">> ", logs['data_path'])
135 |         for log in logs['log']:
136 |             print(">> ", log)
137 | 
138 |     return log_list
139 | 
140 | 
141 | def calc_video_LPIPS(output_root, gt_root, model=None, use_gpu=False, spatial=True):
142 |     '''
143 |     Calculate video LPIPS
144 |     The files in output_root and gt_root must correspond one-to-one in sorted order
145 |     '''
146 | 
147 |     if model is None:
148 |         model = lpips_models.PerceptualLoss(model='net-lin', net='alex', use_gpu=use_gpu, spatial=spatial)
149 | 
150 |     LPIPS_sum = 0.
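    # running totals over all videos, used for the final average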
151 |     img_num = 0
152 | 
153 |     video_LPIPS = []
154 | 
155 |     video_list = sorted(listdir(output_root))
156 |     for v in video_list:
157 |         v_LPIPS_list, _ = calc_image_LPIPS(
158 |             output_root=os.path.join(output_root, v),
159 |             gt_root=os.path.join(gt_root, v),
160 |             model=model, use_gpu=use_gpu, spatial=spatial
161 |         )
162 |         LPIPS_sum += sum(v_LPIPS_list)
163 |         img_num += len(v_LPIPS_list)
164 | 
165 |         video_LPIPS.append({
166 |             'video_name': v,
167 |             'lpips': v_LPIPS_list
168 |         })
169 | 
170 |     logs = []
171 |     LPIPS_csv_log = {
172 |         'col_names': [],
173 |         'row_names': [output_root],
174 |         'lpips': [[]]
175 |     }
176 |     for v_lpips in video_LPIPS:
177 |         LPIPS_csv_log['col_names'].append('#{}'.format(v_lpips['video_name']))
178 |         LPIPS_csv_log['lpips'][0].append('{:.4}'.format(sum(v_lpips['lpips']) / len(v_lpips['lpips'])))
179 |         log = 'Video: {} LPIPS={:.4}'.format(v_lpips['video_name'], sum(v_lpips['lpips']) / len(v_lpips['lpips']))
180 |         print(log)
181 |         logs.append(log)
182 |     LPIPS_csv_log['col_names'].append('AVG')
183 |     LPIPS_csv_log['lpips'][0].append('{:.4}'.format(LPIPS_sum / img_num))
184 |     log = 'Average LPIPS={:.4}'.format(LPIPS_sum / img_num)
185 |     print(log)
186 |     logs.append(log)
187 | 
188 |     return LPIPS_csv_log, logs
189 | 
190 | 
191 | def batch_calc_video_LPIPS(root_list, use_gpu=False, spatial=True,
192 |                            save_log=False, save_log_root=None, combine_save=False):
193 |     '''
194 |     required params:
195 |         root_list: a list, each item should be a dictionary with two key-values:
196 |             output: the dir of output videos
197 |             gt: the dir of gt videos
198 |     optional params:
199 |         use_gpu: default=False, if True, using the gpu
200 |         spatial: default=True, if True, return the spatial map
201 |         save_log: default=False, if True, saving a csv log
202 |         save_log_root: the dir of the output log
203 |         combine_save: default=False, if True, combining all output logs into one csv file
204 |     return:
205 |         log_list: a list, each item is a dictionary with two key-values:
206 |             data_path: the evaluated dir
207 |             log: the log of this dir
208 |     '''
209 |     if save_log:
210 |         assert save_log_root is not None, "Unknown save_log_root!"
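    # the LPIPS network below is built once and reused for every directory in
    # root_list, so the pretrained weights are only loaded a single time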
211 | 
212 |     model = lpips_models.PerceptualLoss(model='net-lin', net='alex', use_gpu=use_gpu, spatial=spatial)
213 | 
214 |     total_csv_log = []
215 |     log_list = []
216 |     for i, root in enumerate(root_list):
217 |         output_root = root['output']
218 |         gt_root = root['gt']
219 |         print(">>>> Now Evaluating >>>>")
220 |         print(">>>> OUTPUT: {}".format(output_root))
221 |         print(">>>> GT: {}".format(gt_root))
222 |         csv_log, logs = calc_video_LPIPS(output_root, gt_root, model=model, use_gpu=use_gpu, spatial=spatial)
223 |         log_list.append({
224 |             'data_path': output_root,
225 |             'log': logs
226 |         })
227 | 
228 |         # output the LPIPS log of each evaluated dir to a single csv file
229 |         if save_log:
230 |             csv_log['row_names'] = [os.path.basename(p) for p in csv_log['row_names']]
231 |             write_csv(file_path=os.path.join(save_log_root, "{}_{}.csv".format(i, csv_log['row_names'][0])),
232 |                       data=np.array(csv_log['lpips']),
233 |                       row_names=csv_log['row_names'],
234 |                       col_names=csv_log['col_names'])
235 |             total_csv_log.append(csv_log)
236 | 
237 |     # output all LPIPS logs to a csv file
238 |     if save_log and combine_save and len(total_csv_log) > 0:
239 |         com_csv_log = {
240 |             'col_names': total_csv_log[0]['col_names'],
241 |             'row_names': [],
242 |             'lpips': []
243 |         }
244 |         for csv_log in total_csv_log:
245 |             com_csv_log['row_names'].append(csv_log['row_names'][0])
246 |             com_csv_log['lpips'].append(csv_log['lpips'][0])
247 |         write_csv(file_path=os.path.join(save_log_root, "lpips.csv"),
248 |                   data=np.array(com_csv_log['lpips']),
249 |                   row_names=com_csv_log['row_names'],
250 |                   col_names=com_csv_log['col_names'])
251 | 
252 |     print("--------------------------------------------------------------------------------------")
253 |     for i, logs in enumerate(log_list):
254 |         print("## The {}-th:".format(i))
255 |         print(">> ", logs['data_path'])
256 |         for log in logs['log']:
257 |             print(">> ", log)
258 | 
259 |     return log_list
260 | 
261 | 
262 | def calc_video_NIQE(output_root, crop_border=4):
263 |     '''
264 |     Calculate video NIQE
265 |     '''
266 | 
267 |     NIQE_sum = 0.
268 |     img_num = 0
269 | 
270 |     video_NIQE = []
271 | 
272 |     video_list = sorted(listdir(output_root))
273 |     for v in video_list:
274 |         v_NIQE_list, _ = calc_image_NIQE(
275 |             output_root=os.path.join(output_root, v),
276 |             crop_border=crop_border
277 |         )
278 |         NIQE_sum += sum(v_NIQE_list)
279 |         img_num += len(v_NIQE_list)
280 | 
281 |         video_NIQE.append({
282 |             'video_name': v,
283 |             'niqe': v_NIQE_list
284 |         })
285 | 
286 |     logs = []
287 |     NIQE_csv_log = {
288 |         'col_names': [],
289 |         'row_names': [output_root],
290 |         'niqe': [[]]
291 |     }
292 |     for v_niqe in video_NIQE:
293 |         NIQE_csv_log['col_names'].append('#{}'.format(v_niqe['video_name']))
294 |         NIQE_csv_log['niqe'][0].append('{:.5}'.format(sum(v_niqe['niqe']) / len(v_niqe['niqe'])))
295 |         log = 'Video: {} NIQE={:.5}'.format(v_niqe['video_name'], sum(v_niqe['niqe']) / len(v_niqe['niqe']))
296 |         print(log)
297 |         logs.append(log)
298 |     NIQE_csv_log['col_names'].append('AVG')
299 |     NIQE_csv_log['niqe'][0].append('{:.5}'.format(NIQE_sum / img_num))
300 |     log = 'Average NIQE={:.5}'.format(NIQE_sum / img_num)
301 |     print(log)
302 |     logs.append(log)
303 | 
304 |     return NIQE_csv_log, logs
305 | 
306 | 
307 | def batch_calc_video_NIQE(root_list, crop_border=4, save_log=False, save_log_root=None, combine_save=False):
308 |     '''
309 |     required params:
310 |         root_list: a list, each item should be a dictionary with the key-value:
311 |             output: the dir of output videos
312 |     optional params:
313 |         crop_border: default=4, crop pixels when calculating NIQE
314 |         save_log: default=False, if True, saving a csv log
315 |         save_log_root: the dir of the output log
316 |         combine_save: default=False, if True, combining all output logs into one csv file
317 |     return:
318 |         log_list: a list, each item is a dictionary with two key-values:
319 |             data_path: the evaluated dir
320 |             log: the log of this dir
321 |     '''
322 |     if save_log:
323 |         assert save_log_root is not None, "Unknown save_log_root!"
324 | 
325 |     total_csv_log = []
326 |     log_list = []
327 |     for i, root in enumerate(root_list):
328 |         output_root = root['output']
329 |         print(">>>> Now Evaluating >>>>")
330 |         print(">>>> OUTPUT: {}".format(output_root))
331 |         csv_log, logs = calc_video_NIQE(
332 |             output_root, crop_border=crop_border
333 |         )
334 |         log_list.append({
335 |             'data_path': output_root,
336 |             'log': logs
337 |         })
338 | 
339 |         # output the NIQE log of each evaluated dir to a single csv file
340 |         if save_log:
341 |             csv_log['row_names'] = [os.path.basename(p) for p in csv_log['row_names']]
342 |             write_csv(file_path=os.path.join(save_log_root, "{}_{}.csv".format(i, csv_log['row_names'][0])),
343 |                       data=np.array(csv_log['niqe']),
344 |                       row_names=csv_log['row_names'],
345 |                       col_names=csv_log['col_names'])
346 |             total_csv_log.append(csv_log)
347 | 
348 |     # output all NIQE logs to a csv file
349 |     if save_log and combine_save and len(total_csv_log) > 0:
350 |         com_csv_log = {
351 |             'col_names': total_csv_log[0]['col_names'],
352 |             'row_names': [],
353 |             'niqe': []
354 |         }
355 |         for csv_log in total_csv_log:
356 |             com_csv_log['row_names'].append(csv_log['row_names'][0])
357 |             com_csv_log['niqe'].append(csv_log['niqe'][0])
358 |         write_csv(file_path=os.path.join(save_log_root, "niqe.csv"),
359 |                   data=np.array(com_csv_log['niqe']),
360 |                   row_names=com_csv_log['row_names'],
361 |                   col_names=com_csv_log['col_names'])
362 | 
363 |     print("--------------------------------------------------------------------------------------")
364 |     for i, logs in enumerate(log_list):
365 |         print("## The {}-th:".format(i))
366 |         print(">> ", logs['data_path'])
367 |         for log in logs['log']:
368 |             print(">> ", log)
369 | 
370 |     return log_list
371 | 
--------------------------------------------------------------------------------
/utils/video_regroup_utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | from base.os_base import handle_dir, copy_file, move_file, listdir, glob_match
3 | from utils import file_regroup_utils
4 | 
5 | 
6 | def VideoFlag2FlagVideo(ori_root, dest_root, ori_flag, dest_flag=None):
7 |     '''
8 |     videos/type/frames --> type/videos/frames
9 |     params:
10 |         ori_root: the dir of files that need to be processed
11 |         dest_root: the dir for saving matched files
12 |         ori_flag: the original video flag (e.g. blur)
13 |         dest_flag: the flag (e.g. blur) for saving videos
14 |             default: None, that is, keeping the original flag
15 |     '''
16 |     if dest_flag is None:
17 |         dest_flag = ori_flag
18 |     handle_dir(dest_root)
19 |     handle_dir(os.path.join(dest_root, dest_flag))
20 |     video_list = listdir(ori_root)
21 |     for v in video_list:
22 |         image_list = listdir(os.path.join(ori_root, v, ori_flag))
23 |         handle_dir(os.path.join(dest_root, dest_flag, v))
24 |         for im in image_list:
25 |             src = os.path.join(ori_root, v, ori_flag, im)
26 |             dst = os.path.join(dest_root, dest_flag, v, im)
27 |             copy_file(src, dst)
28 | 
29 | 
30 | def FlagVideo2VideoFlag(ori_root, dest_root, ori_flag, dest_flag=None):
31 |     '''
32 |     blur/videos/frames --> videos/blur/frames
33 |     params:
34 |         ori_root: the dir of files that need to be processed
35 |         dest_root: the dir for saving matched files
36 |         ori_flag: the original video flag (e.g. blur)
37 |         dest_flag: the flag (e.g. blur) for saving videos
38 |             default: None, that is, keeping the original flag
39 |     '''
40 |     if dest_flag is None:
41 |         dest_flag = ori_flag
42 |     handle_dir(dest_root)
43 |     video_list = listdir(os.path.join(ori_root, ori_flag))
44 |     for v in video_list:
45 |         image_list = listdir(os.path.join(ori_root, ori_flag, v))
46 |         handle_dir(os.path.join(dest_root, v))
47 |         handle_dir(os.path.join(dest_root, v, dest_flag))
48 |         for im in image_list:
49 |             src = os.path.join(ori_root, ori_flag, v, im)
50 |             dst = os.path.join(dest_root, v, dest_flag, im)
51 |             copy_file(src, dst)
52 | 
53 | 
54 | def remove_frames_prefix(root, prefix=''):
55 |     '''
56 |     remove a prefix from frames' filenames
57 |     params:
58 |         root: the dir of videos that need to be processed
59 |         prefix: the prefix to be removed
60 |     '''
61 |     video_list = listdir(root)
62 |     for v in video_list:
63 |         file_regroup_utils.remove_files_prefix(os.path.join(root, v), prefix=prefix)
64 | 
65 | 
66 | def remove_frames_postfix(root, postfix=''):
67 |     '''
68 |     remove a postfix from frames' filenames
69 |     params:
70 |         root: the dir of videos that need to be processed
71 |         postfix: the postfix to be removed
72 |     '''
73 |     video_list = listdir(root)
74 |     for v in video_list:
75 |         file_regroup_utils.remove_files_postfix(os.path.join(root, v), postfix=postfix)
76 | 
77 | 
78 | def add_frames_postfix(root, postfix=''):
79 |     '''
80 |     add a postfix to frames' filenames
81 |     params:
82 |         root: the dir of videos that need to be processed
83 |         postfix: the postfix to be added
84 |     '''
85 |     video_list = listdir(root)
86 |     for v in video_list:
87 |         file_regroup_utils.add_files_postfix(os.path.join(root, v), postfix=postfix)
88 | 
89 | 
90 | def extra_frames_by_postfix(ori_root, dest_root, match_postfix='', new_postfix=None, match_ext='*'):
91 |     '''
92 |     extract frames from ori_root to dest_root by match_postfix and match_ext
93 |     params:
94 |         ori_root: the dir of videos that need to be processed
95 |         dest_root: the dir for saving matched files
96 |         match_postfix: the postfix to be matched
97 |         new_postfix: the postfix for matched files
98 |             default: None, that is, keeping the original postfix
99 |         match_ext: the ext to be matched
100 |     '''
101 |     if new_postfix is None:
102 |         new_postfix = match_postfix
103 | 
104 |     handle_dir(dest_root)
105 |     video_list = listdir(ori_root)
106 |     for v in video_list:
107 |         file_regroup_utils.extra_files_by_postfix(
108 |             ori_root=os.path.join(ori_root, v),
109 |             dest_root=os.path.join(dest_root, v),
110 |             match_postfix=match_postfix,
111 |             new_postfix=new_postfix,
112 |             match_ext=match_ext
113 |         )
114 | 
115 | 
116 | def resort_frames_index(root, template='{:0>4}', start_idx=0):
117 |     '''
118 |     rename frames using the template, with indices starting from start_idx
119 |     params:
120 |         root: the dir of files that need to be processed
121 |         template: the template for processed filenames
122 |         start_idx: the start index
123 |     '''
124 |     video_list = listdir(root)
125 |     for v in video_list:
126 |         file_regroup_utils.resort_files_index(os.path.join(root, v), template=template, start_idx=start_idx)
127 | 
128 | 
129 | def remove_head_tail_frames(root, recycle_bin=None, num=0):
130 |     '''
131 |     remove num head & tail frames from each video
132 |     params:
133 |         root: the dir of files that need to be processed
134 |         recycle_bin: the removed frames will be put here
135 |             default: None, that is, putting the removed frames in root/_recycle_bin
136 |         num: the number of frames to be removed
137 |     '''
138 |     if recycle_bin is None:
139 |         recycle_bin = os.path.join(root, '_recycle_bin')
140 |     handle_dir(recycle_bin)
141 | 
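    # for each video, pair frame i from the head with frame -(i + 1) from the
    # tail and move both into the recycle bin, repeating num times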
def remove_frames_prefix(root, prefix=''):
    '''
    remove prefix from frames
    params:
        root: the dir of videos that need to be processed
        prefix: the prefix to be removed
    '''
    video_list = listdir(root)
    for v in video_list:
        file_regroup_utils.remove_files_prefix(os.path.join(root, v), prefix=prefix)


def remove_frames_postfix(root, postfix=''):
    '''
    remove postfix from frames
    params:
        root: the dir of videos that need to be processed
        postfix: the postfix to be removed
    '''
    video_list = listdir(root)
    for v in video_list:
        file_regroup_utils.remove_files_postfix(os.path.join(root, v), postfix=postfix)


def add_frames_postfix(root, postfix=''):
    '''
    add postfix to frames
    params:
        root: the dir of videos that need to be processed
        postfix: the postfix to be added
    '''
    video_list = listdir(root)
    for v in video_list:
        file_regroup_utils.add_files_postfix(os.path.join(root, v), postfix=postfix)


def extra_frames_by_postfix(ori_root, dest_root, match_postfix='', new_postfix=None, match_ext='*'):
    '''
    extract frames from ori_root to dest_root by match_postfix and match_ext
    params:
        ori_root: the dir of videos that need to be processed
        dest_root: the dir for saving matched files
        match_postfix: the postfix to be matched
        new_postfix: the postfix for matched files
            default: None, i.e. keep the original postfix
        match_ext: the extension to be matched
    '''
    if new_postfix is None:
        new_postfix = match_postfix

    handle_dir(dest_root)
    video_list = listdir(ori_root)
    for v in video_list:
        file_regroup_utils.extra_files_by_postfix(
            ori_root=os.path.join(ori_root, v),
            dest_root=os.path.join(dest_root, v),
            match_postfix=match_postfix,
            new_postfix=new_postfix,
            match_ext=match_ext
        )


def resort_frames_index(root, template='{:0>4}', start_idx=0):
    '''
    rename frames using the given template, with indices starting from start_idx
    params:
        root: the dir of videos that need to be processed
        template: the filename template for processed frames
        start_idx: the start index
    '''
    video_list = listdir(root)
    for v in video_list:
        file_regroup_utils.resort_files_index(os.path.join(root, v), template=template, start_idx=start_idx)


def remove_head_tail_frames(root, recycle_bin=None, num=0):
    '''
    remove num head & tail frames from each video
    params:
        root: the dir of videos that need to be processed
        recycle_bin: the removed frames will be put here
            default: None, i.e. put the removed frames in root/_recycle_bin
        num: the number of frames to be removed at each end
    '''
    if recycle_bin is None:
        recycle_bin = os.path.join(root, '_recycle_bin')
    handle_dir(recycle_bin)

    # skip the recycle bin itself in case it lives inside root
    video_list = [v for v in listdir(root) if os.path.join(root, v) != recycle_bin]
    for v in video_list:
        img_list = sorted(glob_match(os.path.join(root, v, "*")))
        handle_dir(os.path.join(recycle_bin, v))
        for i in range(num):
            src = img_list[i]
            dest = os.path.join(recycle_bin, v, os.path.basename(src))
            move_file(src, dest)

            src = img_list[-(i + 1)]
            dest = os.path.join(recycle_bin, v, os.path.basename(src))
            move_file(src, dest)
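if __name__ == '__main__':
    # A hedged example of chaining the helpers above into a typical frame-cleanup
    # pass; the root path and parameter values are hypothetical.
    demo_root = '/data/extracted_frames'
    remove_frames_prefix(demo_root, prefix='IMG_')                   # IMG_0001.png --> 0001.png
    resort_frames_index(demo_root, template='{:0>4}', start_idx=0)   # renumber from 0000
    remove_head_tail_frames(demo_root, num=2)                        # drop 2 frames at each end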
--------------------------------------------------------------------------------
/utils/video_utils.py:
--------------------------------------------------------------------------------
import os
from base.os_base import handle_dir, copy_file, listdir
from utils.image_utils import cv2_resize_images, matlab_resize_images, shift_images


def matlab_resize_videos(ori_root, dest_root, scale=1.0, method='bicubic', filename_template="{}.png"):
    '''
    function:
        resize videos in batches, same as matlab2017 imresize
    params:
        ori_root: string, the dir of videos that need to be processed
        dest_root: string, the dir to save processed videos
        scale: float, the resize scale
        method: string, the interpolation method
            optional: 'bilinear', 'bicubic'
            default: 'bicubic'
        filename_template: string, the filename template for saving images
    '''
    handle_dir(dest_root)
    videos = listdir(ori_root)
    for v in videos:
        matlab_resize_images(
            ori_root=os.path.join(ori_root, v),
            dest_root=os.path.join(dest_root, v),
            scale=scale,
            method=method,
            filename_template=filename_template
        )
        print("Video", v, "resize done!")


def cv2_resize_videos(ori_root, dest_root, scale=1.0, method='bicubic', filename_template="{}.png"):
    '''
    function:
        resize videos in batches
    params:
        ori_root: string, the dir of videos that need to be processed
        dest_root: string, the dir to save processed videos
        scale: float, the resize scale
        method: string, the interpolation method
            optional: 'nearest', 'bilinear', 'bicubic'
            default: 'bicubic'
        filename_template: string, the filename template for saving images
    '''
    handle_dir(dest_root)
    videos = listdir(ori_root)
    for v in videos:
        cv2_resize_images(
            ori_root=os.path.join(ori_root, v),
            dest_root=os.path.join(dest_root, v),
            scale=scale,
            method=method,
            filename_template=filename_template
        )
        print("Video", v, "resize done!")


def shift_videos(ori_root, dest_root, offset_x=0., offset_y=0., filename_template="{}.png"):
    '''
    function:
        shift videos by (offset_x, offset_y) on (axis-x, axis-y) in batches
    params:
        ori_root: string, the dir of videos that need to be processed
        dest_root: string, the dir to save processed videos
        offset_x: float, offset pixels on axis-x
            positive=left; negative=right
        offset_y: float, offset pixels on axis-y
            positive=up; negative=down
        filename_template: string, the filename template for saving images
    '''
    handle_dir(dest_root)
    videos = listdir(ori_root)
    for v in videos:
        shift_images(
            ori_root=os.path.join(ori_root, v),
            dest_root=os.path.join(dest_root, v),
            offset_x=offset_x,
            offset_y=offset_y,
            filename_template=filename_template
        )
        print("Video", v, "shift done!")


def extra_frames_from_videos(ori_root, save_root, fname_template='%4d.png', ffmpeg_path='ffmpeg'):
    '''
    function:
        extract frames from videos
    params:
        ori_root: string, the dir of videos that need to be processed
        save_root: string, the dir to save extracted frames
        fname_template: the template for frames' filenames
        ffmpeg_path: the path of the ffmpeg executable
    '''

    handle_dir(save_root)
    videos = sorted(listdir(ori_root))

    for v in videos:
        vn = v[:-(len(v.split('.')[-1]) + 1)]  # video filename without extension
        video_path = os.path.join(ori_root, v)
        png_dir = os.path.join(save_root, vn)
        png_path = os.path.join(png_dir, fname_template)
        handle_dir(png_dir)
        command = '{} -i {} {}'.format(ffmpeg_path, video_path, png_path)
        os.system(command)
        print("Extracted frames from {}".format(video_path))


def zip_frames_to_videos(ori_root, save_root, fname_template='%4d.png', video_ext='mkv', ffmpeg_path='ffmpeg'):
    '''
    function:
        zip frames to videos
    params:
        ori_root: string, the dir of frames that need to be processed
        save_root: string, the dir to save processed videos
        fname_template: the template of frames' filenames
        video_ext: the extension of output videos
        ffmpeg_path: the path of the ffmpeg executable
    '''

    handle_dir(save_root)
    videos_name = sorted(listdir(ori_root))

    for vn in videos_name:
        imgs_path = os.path.join(ori_root, vn, fname_template)
        video_path = os.path.join(save_root, '{}.{}'.format(vn, video_ext))
        command = '{} -i {} -c:v libx265 -x265-params lossless=1 {}'.format(
            ffmpeg_path, imgs_path, video_path
        )  # NTIRE 2022 Super-Resolution and Quality Enhancement of Compressed Video
        # command = '{} -r 24000/1001 -i {} -vcodec libx265 -pix_fmt yuv422p -crf 10 {}'.format(
        #     ffmpeg_path, imgs_path, video_path
        # )  # youku competition
        os.system(command)
        print("Zipped frames to {}".format(video_path))


def copy_frames_for_fps(ori_root, save_root, mul=12, fname_template="{:0>4}", ext="png"):
    '''
    function:
        duplicate each frame mul times, e.g. to raise the fps of a frame sequence
    params:
        ori_root: string, the dir of videos that need to be processed
        save_root: string, the dir to save processed videos
        mul: the number of copies of each frame
        fname_template: the template of frames' filenames
        ext: the extension of frames' filenames
    '''
    fname_template = fname_template + '.{}'
    videos_name = sorted(listdir(ori_root))
    handle_dir(save_root)
    for vn in videos_name:
        frames = sorted(listdir(os.path.join(ori_root, vn)))
        handle_dir(os.path.join(save_root, vn))
        for i, f in enumerate(frames):
            for j in range(mul):
                now_idx = i * mul + j
                src = os.path.join(ori_root, vn, f)
                dest = os.path.join(save_root, vn, fname_template.format(now_idx, ext))
                copy_file(src, dest)
--------------------------------------------------------------------------------
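For instance, `extra_frames_from_videos` and `zip_frames_to_videos` pair naturally into a frames-based editing round trip. A minimal sketch, assuming `ffmpeg` is on the PATH and using placeholder paths:
```python
from utils.video_utils import extra_frames_from_videos, zip_frames_to_videos

# 1. split every video under the source dir into a per-video frame folder
extra_frames_from_videos('/path/to/src_videos', '/path/to/frames')

# ... process the extracted frames here (resize, shift, regroup, etc.) ...

# 2. re-encode each frame folder into a lossless x265 video
zip_frames_to_videos('/path/to/frames', '/path/to/dst_videos', video_ext='mkv')
```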