├── .gitignore
├── LICENSE
├── README.md
├── base
│   ├── file_io_base.py
│   ├── flow_base.py
│   ├── image_base.py
│   ├── imresize_right.py
│   ├── kernel_base.py
│   ├── matlab_imresize.py
│   └── os_base.py
└── utils
    ├── LPIPS
    │   └── models
    │       ├── __init__.py
    │       ├── base_model.py
    │       ├── dist_model.py
    │       ├── networks_basic.py
    │       ├── pretrained_networks.py
    │       └── weights
    │           ├── v0.0
    │           │   ├── alex.pth
    │           │   ├── squeeze.pth
    │           │   └── vgg.pth
    │           └── v0.1
    │               ├── alex.pth
    │               ├── squeeze.pth
    │               └── vgg.pth
    ├── NIQE
    │   ├── niqe.py
    │   └── niqe_image_params.mat
    ├── ResizeRight
    │   ├── LICENSE
    │   ├── README.md
    │   ├── interp_methods.py
    │   └── resize_right.py
    ├── create_gif.py
    ├── dnn_utils.py
    ├── file_regroup_utils.py
    ├── image_crop_combine_utils.py
    ├── image_edge_sharpen.py
    ├── image_metric_utils.py
    ├── image_utils.py
    ├── kernel_metric_utils.py
    ├── kernel_utils.py
    ├── plot_utils.py
    ├── string_utils.py
    ├── torchstat
    │   ├── .gitignore
    │   ├── .travis.yml
    │   ├── LICENSE
    │   ├── README.md
    │   ├── detail.md
    │   ├── example.py
    │   ├── requirements.txt
    │   ├── setup.cfg
    │   ├── setup.py
    │   ├── test_requirements.txt
    │   └── torchstat
    │       ├── __init__.py
    │       ├── __main__.py
    │       ├── compute_flops.py
    │       ├── compute_madd.py
    │       ├── compute_memory.py
    │       ├── model_hook.py
    │       ├── reporter.py
    │       ├── stat_tree.py
    │       └── statistics.py
    ├── video_metric_utils.py
    ├── video_regroup_utils.py
    └── video_utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | .idea/
2 | */__pycache__
3 | **/__pycache__
4 | ss*
5 | draft*
6 | temp/
7 | .DS_Store*
8 | **/.DS_Store*
9 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2020 Haoran Bai
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # OpenUtility
2 | [License](https://github.com/csbhr/Python_Tools/blob/master/LICENSE)
3 | [Python 3](https://www.python.org/)
4 |
5 | This repository collects some useful tools for low-level vision tasks.
6 |
7 | - [Calculating Metrics ( PSNR, SSIM, etc.)](#chapter-calculating-metrics)
8 | - [Image / Video Processing ( resize, crop, shift, etc.)](#chapter-image-video-processing)
9 | - [Deep Model Properties ( Params, Flops, etc. )](#chapter-model-properties)
10 | - [File Processing ( csv, etc. )](#chapter-file-processing)
11 | - [Visualize Tools ( plot, optical-flow, etc. )](#chapter-visualize-tools)
12 |
13 |
14 | ## Dependencies
15 |
16 | - Python 3 (we recommend [Anaconda](https://www.anaconda.com/download/#linux))
17 | - [PyTorch](https://pytorch.org/)
18 | - numpy: `conda install numpy`
19 | - matplotlib: `conda install matplotlib`
20 | - opencv: `conda install opencv`
21 | - pandas: `conda install pandas`
22 |
23 | ## Easy to use
24 | Here are some simple demos. For more optional parameters and advanced usage, please refer to the source code.
25 |
26 |
27 |
28 | ### 1. Calculating Metrics ( PSNR, SSIM, etc.)
29 |
30 | ##### 1.1 PSNR/SSIM
31 | - Follow the demo below for batch operation:
32 | ```python
33 | # Images
34 | from utils.image_metric_utils import batch_calc_image_PSNR_SSIM
35 | root_list = [
36 | {
37 | 'output': '/path/to/output images 1',
38 | 'gt': '/path/to/gt images 1'
39 | },
40 | {
41 | 'output': '/path/to/output images 2',
42 | 'gt': '/path/to/gt images 2'
43 | },
44 | ]
45 | batch_calc_image_PSNR_SSIM(root_list)
46 |
47 | # Videos
48 | from utils.video_metric_utils import batch_calc_video_PSNR_SSIM
49 | root_list = [
50 | {
51 | 'output': '/path/to/output videos 1',
52 | 'gt': '/path/to/gt videos 1'
53 | },
54 | {
55 | 'output': '/path/to/output videos 2',
56 | 'gt': '/path/to/gt videos 2'
57 | },
58 | ]
59 | batch_calc_video_PSNR_SSIM(root_list)
60 | ```
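- For a single image pair, the underlying functions in `base/image_base.py` can also be called directly. A minimal sketch (the paths are placeholders; this computes the metrics on the arrays as given, while the batch utilities may first convert to the Y channel):
```python
import cv2
from base.image_base import PSNR, SSIM

output = cv2.imread('/path/to/output.png')  # uint8, range [0, 255]
gt = cv2.imread('/path/to/gt.png')
print('PSNR: {:.2f} dB'.format(PSNR(output, gt)))
print('SSIM: {:.4f}'.format(SSIM(output, gt)))
```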
61 |
62 | ##### 1.2 LPIPS
63 | - Follow the demo below for batch operation:
64 | ```python
65 | # Images
66 | from utils.image_metric_utils import batch_calc_image_LPIPS
67 | root_list = [
68 | {
69 | 'output': '/path/to/output images 1',
70 | 'gt': '/path/to/gt images 1'
71 | },
72 | {
73 | 'output': '/path/to/output images 2',
74 | 'gt': '/path/to/gt images 2'
75 | },
76 | ]
77 | batch_calc_image_LPIPS(root_list)
78 |
79 | # Videos
80 | from utils.video_metric_utils import batch_calc_video_LPIPS
81 | root_list = [
82 | {
83 | 'output': '/path/to/output videos 1',
84 | 'gt': '/path/to/gt videos 1'
85 | },
86 | {
87 | 'output': '/path/to/output videos 2',
88 | 'gt': '/path/to/gt videos 2'
89 | },
90 | ]
91 | batch_calc_video_LPIPS(root_list)
92 | ```
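- For a single image pair, you can also call the bundled LPIPS model directly. A minimal sketch, assuming a CUDA-capable GPU (set `use_gpu=False` otherwise); `im2tensor` scales uint8 images to [-1, 1]:
```python
import cv2
from utils.LPIPS.models import PerceptualLoss, im2tensor

# build the LPIPS model once and reuse it
model = PerceptualLoss(model='net-lin', net='alex', use_gpu=True)
output = im2tensor(cv2.imread('/path/to/output.png')[:, :, ::-1])  # BGR -> RGB
gt = im2tensor(cv2.imread('/path/to/gt.png')[:, :, ::-1])
dist = model.forward(output.cuda(), gt.cuda())
print('LPIPS: {:.4f}'.format(dist.item()))
```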
93 |
94 | ##### 1.3 NIQE
95 | - Follow the demo below for batch operation:
96 | ```python
97 | # Images
98 | from utils.image_metric_utils import batch_calc_image_NIQE
99 | root_list = [
100 | {
101 | 'output': '/path/to/output images 1',
102 | 'gt': '/path/to/gt images 1'
103 | },
104 | {
105 | 'output': '/path/to/output images 2',
106 | 'gt': '/path/to/gt images 2'
107 | },
108 | ]
109 | batch_calc_image_NIQE(root_list)
110 |
111 | # Videos
112 | from utils.video_metric_utils import batch_calc_video_NIQE
113 | root_list = [
114 | {
115 | 'output': '/path/to/output videos 1',
116 | 'gt': '/path/to/gt videos 1'
117 | },
118 | {
119 | 'output': '/path/to/output videos 2',
120 | 'gt': '/path/to/gt videos 2'
121 | },
122 | ]
123 | batch_calc_video_NIQE(root_list)
124 | ```
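- NIQE is a no-reference metric, so a single luminance image is enough. A minimal sketch (the assertions in `utils/NIQE/niqe.py` require the input to be larger than 193x193):
```python
import cv2
from utils.NIQE.niqe import niqe

img = cv2.imread('/path/to/output.png', cv2.IMREAD_GRAYSCALE)  # 2-D luminance image
print('NIQE: {:.4f}'.format(niqe(img)))
```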
125 |
126 | ##### 1.4 Kernel Gradient Similarity
127 | - Follow the demo below for batch operation:
128 | ```python
129 | # Images: video_type=False
130 | # Videos: video_type=True
131 | from utils.kernel_metric_utils import batch_calc_kernel_metric
132 | root_list = [
133 | {
134 | 'output': '/path/to/output kernel 1',
135 | 'gt': '/path/to/gt kernel 1'
136 | },
137 | {
138 | 'output': '/path/to/output kernel 2',
139 | 'gt': '/path/to/gt kernel 2'
140 | },
141 | ]
142 | batch_calc_kernel_metric(root_list, video_type=False)
143 | ```
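- For a single pair of kernels, the helpers in `base/kernel_base.py` can be called directly. A minimal sketch (the `.mat` paths are placeholders):
```python
import cv2
from base.kernel_base import load_mat_kernel, Kernel_RMSE_PSNR_SSIM, kernel2png

k_out = load_mat_kernel('/path/to/output_kernel.mat')
k_gt = load_mat_kernel('/path/to/gt_kernel.mat')
rmse, psnr, ssim = Kernel_RMSE_PSNR_SSIM(k_out, k_gt)
print('RMSE: {:.4f}, PSNR: {:.2f}, SSIM: {:.4f}'.format(rmse, psnr, ssim))
cv2.imwrite('kernel_vis.png', kernel2png(k_gt))  # save a viewable visualization of the kernel
```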
144 |
145 |
146 |
147 | ### 2. Image / Video Processing ( resize, crop, shift, etc.)
148 |
149 | ##### 2.1 Image/Video Resize
150 | - We use fatheral's Python implementation of MATLAB's imresize() function: [fatheral/matlab_imresize](https://github.com/fatheral/matlab_imresize).
151 | - Follow the demo below for batch operation:
152 |
153 | ```python
154 | # Images
155 | from utils.image_utils import matlab_resize_images
156 |
157 | ori_root = '/path/to/ori images'
158 | dest_root = '/path/to/dest images'
159 | matlab_resize_images(ori_root, dest_root, scale=2.0)
160 |
161 | # Videos
162 | from utils.video_utils import matlab_resize_videos
163 |
164 | ori_root = '/path/to/ori videos'
165 | dest_root = '/path/to/dest videos'
166 | matlab_resize_videos(ori_root, dest_root, scale=2.0)
167 | ```
168 | - We also provide OpenCV-based resizing.
169 | - Follow the demo below for batch operation:
170 |
171 | ```python
172 | from utils.image_utils import cv2_resize_images
173 |
174 | ori_root = '/path/to/ori images'
175 | dest_root = '/path/to/dest images'
176 | cv2_resize_images(ori_root, dest_root, scale=2.0)
177 |
178 | # Videos
179 | from utils.video_utils import cv2_resize_videos
180 |
181 | ori_root = '/path/to/ori videos'
182 | dest_root = '/path/to/dest videos'
183 | cv2_resize_videos(ori_root, dest_root, scale=2.0)
184 | ```
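- For a single image, the MATLAB-style resize in `base/image_base.py` can be called directly, either with a scale factor or with a target shape. A minimal sketch (paths are placeholders):
```python
import cv2
from base.image_base import matlab_imresize

img = cv2.imread('/path/to/ori.png')
up_2x = matlab_imresize(img, scalar_scale=2.0)           # resize by a scale factor
fixed = matlab_imresize(img, output_shape=(480, 640))    # resize to a target (h, w)
cv2.imwrite('/path/to/dest.png', up_2x)
```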
185 |
186 | ##### 2.2 Crop and combine images
187 | - When you need to run inference on a large image, you can crop it into many patches with padding by following the demo:
188 | ```python
189 | # Notice:
190 | # filenames should not contain the character "-"
191 | # a crop flag "x-x-x-x" will be appended to the end of the filename when cropping
192 | # the combine operation relies on this crop flag "x-x-x-x"
193 | from utils.image_crop_combine_utils import *
194 | ori_root = '/path/to/ori images'
195 | dest_root = '/path/to/dest images'
196 | batch_crop_img_with_padding(ori_root, dest_root, min_size=(800, 800), padding=100)
197 | ```
198 | - After inferring on the cropped patches, you can combine them back into the full image by following the demo:
199 | ```python
200 | # Notice:
201 | # filenames should not contain the character "-" except for the crop flag
202 | # the crop flag "x-x-x-x" should be at the end of filename when combining
203 | from utils.image_crop_combine_utils import *
204 | ori_root = '/path/to/ori images'
205 | dest_root = '/path/to/dest images'
206 | batch_combine_img(ori_root, dest_root, padding=100)
207 | ```
208 | - You can also traverse-crop an image into many patches with a fixed interval by following the demo:
209 | ```python
210 | from utils.image_crop_combine_utils import *
211 | ori_root = '/path/to/ori images'
212 | dest_root = '/path/to/dest images'
213 | batch_traverse_crop_img(ori_root, dest_root, dsize=(800, 800), interval=400)
214 | ```
215 | - You can select valid patches that are not too smooth by following the demo:
216 | ```python
217 | from utils.image_crop_combine_utils import *
218 | ori_root = '/path/to/ori images'
219 | dest_root = '/path/to/dest images'
220 | batch_select_valid_patch(ori_root, dest_root)
221 | ```
222 |
223 | ##### 2.3 Image/Video Shift
224 | - We use bilinear interpolation to shift images/videos by sub-pixel offsets.
225 | - Follow the demo below for batch operation:
226 |
227 | ```python
228 | # Images
229 | from utils.image_utils import shift_images
230 |
231 | ori_root = '/path/to/ori images'
232 | dest_root = '/path/to/dest images'
233 | shift_images(ori_root, dest_root, offset_x=0.5, offset_y=0.5)
234 |
235 | # Videos
236 | from utils.video_utils import shift_videos
237 |
238 | ori_root = '/path/to/ori videos'
239 | dest_root = '/path/to/dest videos'
240 | shift_videos(ori_root, dest_root, offset_x=0.5, offset_y=0.5)
241 | ```
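- For a single image, the underlying warp in `base/image_base.py` can be called directly. A minimal sketch (positive offsets shift left/up, as documented in `image_shift`):
```python
import cv2
from base.image_base import image_shift

img = cv2.imread('/path/to/ori.png')
shifted = image_shift(img, offset_x=0.5, offset_y=0.5)  # half-pixel shift via bilinear warping
cv2.imwrite('/path/to/dest.png', shifted)
```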
242 |
243 |
244 |
245 | ### 3. Deep Model Properties ( Params, Flops, etc. )
246 | - If you only need the model's number of parameters (Params), follow the demo:
247 | ```python
248 | from utils.dnn_utils import cal_parmeters
249 | network = None # Please define the model
250 | cal_parmeters(network)
251 | ```
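- For example, with a toy network (a minimal sketch; any `torch.nn.Module` works, and `cal_parmeters` is the repo's own helper, note the spelling):
```python
import torch.nn as nn
from utils.dnn_utils import cal_parmeters

# a toy model standing in for "Please define the model"
network = nn.Sequential(
    nn.Conv2d(3, 64, kernel_size=3, padding=1),
    nn.ReLU(inplace=True),
    nn.Conv2d(64, 3, kernel_size=3, padding=1),
)
cal_parmeters(network)
```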
252 | - You can also calculate more properties of the model: Params, Memory, MAdd, Flops, etc.
253 | - We use Swall0w's tool [Swall0w/torchstat](https://github.com/Swall0w/torchstat).
254 | - The original [Swall0w/torchstat] cannot run on CUDA, so we modified it to support CUDA. If you want to calculate Flops on CUDA, please install torchstat with the following commands:
255 | ```shell script
256 | cd ./utils/torchstat
257 | python3 setup.py install
258 | ```
259 | - If you do not need CUDA, please install torchstat from pip:
260 | ```shell script
261 | pip install torchstat # pytorch >= 1.0.0
262 | pip install torchstat==0.0.6 # pytorch = 0.4.1
263 | ```
264 | - Then you can calculate the properties by following the demo:
265 | ```python
266 | from torchstat import stat
267 | network = None # Please define the model
268 | input_size = (3, 80, 80) # the size of input (channel, height, width)
269 | stat(network, input_size)
270 | ```
271 |
272 |
273 |
274 | ### 4. File Processing ( csv, etc. )
275 |
276 | ##### 4.1 csv file
277 | - You can read a csv file by following the demo:
278 | ```python
279 | from base import file_io_base
280 | data_list = file_io_base.read_csv('filename.csv')
281 | ```
282 | - You can write a numpy.array into a csv file by following the demo:
283 | ```python
284 | from base import file_io_base
285 | import numpy as np
286 | row_names = ['r1', 'r2', 'r3']
287 | col_names = ['c1', 'c2', 'c3', 'c4']
288 | data_array = np.array([[1, 2, 3, 4],
289 | [2, 3, 4, 5],
290 | [3, 4, 5, 6]])
291 | file_io_base.write_csv('filename.csv', data_array, col_names, row_names)
292 | ```
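- With the inputs above, the generated `filename.csv` should look like this (row names in the index column, column names in the header):
```
,c1,c2,c3,c4
r1,1,2,3,4
r2,2,3,4,5
r3,3,4,5,6
```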
293 |
294 |
295 |
296 | ### 5. Visualize Tools ( plot, optical-flow, etc. )
297 |
298 | ##### 5.1 Plot multiple curves in one figure
299 | - You can plot multiple curves in one figure by following the demo:
300 | ```python
301 | import numpy as np
302 | from utils import plot_utils
303 | array_list = [
304 | np.array([1, 4, 5, 3, 6]),
305 | np.array([2, 3, 7, 4, 5]),
306 | ]
307 | label_list = ['curve-1', 'curve-2']
308 | plot_utils.plot_multi_curve(array_list, label_list)
309 | ```
310 |
311 | ##### 5.2 Visualize optical flow
312 | - You can visualize optical flow by following the demo:
313 | ```python
314 | from base.flow_base import visual_flow
315 | flow = None  # optical flow array, shape=(h, w, 2)
316 | rgb_image = visual_flow(flow)
317 | ```
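- For example, you can build a toy flow field and save its visualization (a minimal sketch; `visual_flow` in `base/flow_base.py` maps flow direction to hue and flow magnitude to saturation, and returns a BGR uint8 image):
```python
import cv2
import numpy as np
from base.flow_base import visual_flow

h, w = 256, 256
flow = np.zeros((h, w, 2), dtype=np.float32)
flow[..., 0] = np.linspace(-5., 5., w)[None, :]  # horizontal displacement varying along axis-x
rgb_image = visual_flow(flow)
cv2.imwrite('flow_vis.png', rgb_image)
```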
318 |
--------------------------------------------------------------------------------
/base/file_io_base.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 | import numpy as np
3 |
4 |
5 | #################################################################################
6 | #### csv file I/O ####
7 | #################################################################################
8 | def read_csv(file_path):
9 | '''
10 | function:
11 | read csv file into list[list[]]
12 | required params:
13 | file_path: the csv file's path
14 | return:
15 | data_list: the csv file's content
16 | '''
17 | csv_data = pd.read_csv(file_path)
18 | data_list = csv_data.values.tolist()
19 | return data_list
20 |
21 |
22 | def write_csv(file_path, data, col_names=None, row_names=None):
23 | '''
24 | function:
25 | write np.array into csv file
26 | required params:
27 | file_path: the csv file's path
28 | data: np.array, the ndim must be 2, the content of csv file
29 | optional params:
30 | col_names: a list, default None, if not None, the len must equal to data.shape[1]
31 | row_names: a list, default None, if not None, the len must equal to data.shape[0]
32 | '''
33 | index = False if row_names is None else True
34 | header = False if col_names is None else True
35 | assert data.ndim == 2, 'the ndim of data must be 2!'
36 | if index:
37 | assert data.shape[0] == len(row_names), 'the number of data rows must equal to len(row_names)'
38 | if header:
39 | assert data.shape[1] == len(col_names), 'the number of data cols must equal to len(col_names)'
40 | data = pd.DataFrame(data, index=row_names, columns=col_names)
41 | data.to_csv(file_path, index=index, header=header)
42 |
43 |
44 | #################################################################################
45 | #### excel file I/O ####
46 | #################################################################################
47 | def read_excel(file_path, sheet_name=0):
48 | '''
49 | function:
50 | read excel file into list[list[]]
51 | required params:
52 | file_path: the excel file's path
53 | optional params:
54 | sheet_name: int or str, the index/name of read sheet
55 | return:
56 | data_list: the excel file's content
57 | '''
58 | excel_data = pd.read_excel(file_path, sheet_name=sheet_name)
59 | data_list = excel_data.values.tolist()
60 | return data_list
61 |
--------------------------------------------------------------------------------
/base/flow_base.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import numpy as np
3 |
4 |
5 | def visual_flow(flow):
6 |     # Hue H: measured in degrees in [0°, 360°), counter-clockwise from red; red is 0°, green is 120°, blue is 240°
7 |     # Saturation S: in the range [0.0, 1.0]
8 |     # Value V: in the range [0.0 (black), 1.0 (white)]
9 | # flow shape: [h, w, 2]
10 | h, w = flow.shape[:2]
11 | hsv = np.zeros((h, w, 3), np.uint8)
12 | mag, ang = cv2.cartToPolar(flow[..., 0], flow[..., 1])
13 | hsv[..., 0] = ang * 180 / np.pi / 2
14 | hsv[..., 1] = cv2.normalize(mag, None, 0, 255, cv2.NORM_MINMAX)
15 |     # FlowNet sets V to 255; this function follows FlowNet: saturation S encodes the displacement magnitude and value V is maximal, which is easy to view
16 |     # Some flow visualizations instead set S to 255 and let V encode the displacement magnitude, but the image becomes very dark, so this is rarely used
17 | hsv[..., 2] = 255
18 | bgr = cv2.cvtColor(hsv, cv2.COLOR_HSV2BGR)
19 | return bgr
20 |
--------------------------------------------------------------------------------
/base/image_base.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | import numpy as np
4 | import math
5 | from torch.autograd import Variable
6 | import cv2
7 | from base.matlab_imresize import imresize
8 |
9 |
10 | #################################################################################
11 | #### Image Operation ####
12 | #################################################################################
13 | def matlab_imresize(img, scalar_scale=None, output_shape=None, method='bicubic'):
14 | '''same as matlab2017 imresize
15 | img: shape=[h, w, c]
16 | scalar_scale: the resize scale
17 | if None, using output_shape
18 | output_shape: the resize shape, (h, w)
19 | if scalar_scale=None, using this param
20 | method: the interpolation method
21 | optional: 'bicubic', 'bilinear'
22 | default: 'bicubic'
23 | '''
24 | return imresize(
25 | I=img,
26 | scalar_scale=scalar_scale,
27 | output_shape=output_shape,
28 | method=method
29 | )
30 |
31 |
32 | def image_shift(img, offset_x=0., offset_y=0.):
33 | '''shift the img by (offset_x, offset_y) on (axis-x, axis-y)
34 | img: numpy.array, shape=[h, w, c]
35 | offset_x: offset pixels on axis-x
36 | positive=left; negative=right
37 | offset_y: offset pixels on axis-y
38 | positive=up; negative=down
39 | '''
40 | img_tensor = torch.from_numpy(img.transpose(2, 0, 1)).unsqueeze(0).float()
41 | B, C, H, W = img_tensor.size()
42 |
43 | # init flow
44 | flo = torch.ones(B, 2, H, W).type_as(img_tensor)
45 | flo[:, 0, :, :] *= offset_x
46 | flo[:, 1, :, :] *= offset_y
47 |
48 | # mesh grid
49 | xx = torch.arange(0, W).view(1, -1).repeat(H, 1)
50 | yy = torch.arange(0, H).view(-1, 1).repeat(1, W)
51 | xx = xx.view(1, 1, H, W).repeat(B, 1, 1, 1)
52 | yy = yy.view(1, 1, H, W).repeat(B, 1, 1, 1)
53 | grid = torch.cat((xx, yy), 1).float()
54 | vgrid = Variable(grid) + flo
55 |
56 | # scale grid to [-1,1]
57 | vgrid[:, 0, :, :] = 2.0 * vgrid[:, 0, :, :].clone() / max(W - 1, 1) - 1.0
58 | vgrid[:, 1, :, :] = 2.0 * vgrid[:, 1, :, :].clone() / max(H - 1, 1) - 1.0
59 |
60 | # Interpolation
61 | vgrid = vgrid.permute(0, 2, 3, 1)
62 | output_tensor = F.grid_sample(img_tensor, vgrid, padding_mode='border')
63 |
64 | output = output_tensor.round()[0].detach().numpy().transpose(1, 2, 0).astype(img.dtype)
65 |
66 | return output
67 |
68 |
69 | def rgb2ycbcr(img, range=255., only_y=True):
70 | """same as matlab rgb2ycbcr, please use bgr2ycbcr when using cv2.imread
71 | img: shape=[h, w, 3]
72 | range: the data range
73 | only_y: only return Y channel
74 | """
75 | in_img_type = img.dtype
76 |     img = img.astype(np.float32)
77 | range_scale = 255. / range
78 | img *= range_scale
79 |
80 | # convert
81 | if only_y:
82 | rlt = np.dot(img, [65.481, 128.553, 24.966]) / 255.0 + 16.0
83 | else:
84 | rlt = np.matmul(img, [[65.481, -37.797, 112.0], [128.553, -74.203, -93.786],
85 | [24.966, 112.0, -18.214]]) / 255.0 + [16, 128, 128]
86 |
87 | rlt /= range_scale
88 | if in_img_type == np.uint8:
89 | rlt = rlt.round()
90 | return rlt.astype(in_img_type)
91 |
92 |
93 | def bgr2ycbcr(img, range=255., only_y=True):
94 | """bgr version of rgb2ycbcr, for cv2.imread
95 | img: shape=[h, w, 3]
96 | range: the data range
97 | only_y: only return Y channel
98 | """
99 | in_img_type = img.dtype
100 |     img = img.astype(np.float32)
101 | range_scale = 255. / range
102 | img *= range_scale
103 |
104 | # convert
105 | if only_y:
106 | rlt = np.dot(img, [24.966, 128.553, 65.481]) / 255.0 + 16.0
107 | else:
108 | rlt = np.matmul(img, [[24.966, 112.0, -18.214], [128.553, -74.203, -93.786],
109 | [65.481, -37.797, 112.0]]) / 255.0 + [16, 128, 128]
110 |
111 | rlt /= range_scale
112 | if in_img_type == np.uint8:
113 | rlt = rlt.round()
114 | return rlt.astype(in_img_type)
115 |
116 |
117 | def ycbcr2rgb(img, range=255.):
118 | """same as matlab ycbcr2rgb
119 | img: shape=[h, w, 3]
120 | range: the data range
121 | """
122 | in_img_type = img.dtype
123 |     img = img.astype(np.float32)
124 | range_scale = 255. / range
125 | img *= range_scale
126 |
127 | # convert
128 | rlt = np.matmul(img, [[0.00456621, 0.00456621, 0.00456621], [0, -0.00153632, 0.00791071],
129 | [0.00625893, -0.00318811, 0]]) * 255.0 + [-222.921, 135.576, -276.836]
130 |
131 | rlt /= range_scale
132 | if in_img_type == np.uint8:
133 | rlt = rlt.round()
134 | return rlt.astype(in_img_type)
135 |
136 |
137 | #################################################################################
138 | #### Image PSNR SSIM ####
139 | #################################################################################
140 |
141 | def RMSE(img1, img2):
142 | img1 = img1.astype(np.float64)
143 | img2 = img2.astype(np.float64)
144 | mse = np.mean((img1 - img2) ** 2)
145 | if mse == 0:
146 | return float('inf')
147 | return math.sqrt(mse)
148 |
149 |
150 | def PSNR(img1, img2):
151 | '''
152 | img1, img2: [0, 255]
153 | '''
154 | img1 = img1.astype(np.float64)
155 | img2 = img2.astype(np.float64)
156 | mse = np.mean((img1 - img2) ** 2)
157 | if mse == 0:
158 | return float('inf')
159 | return 20 * math.log10(255.0 / math.sqrt(mse))
160 |
161 |
162 | def SSIM(img1, img2):
163 | '''calculate SSIM
164 | the same outputs as MATLAB's
165 | img1, img2: [0, 255]
166 | '''
167 |
168 | def ssim(img1, img2):
169 | C1 = (0.01 * 255) ** 2
170 | C2 = (0.03 * 255) ** 2
171 |
172 | img1 = img1.astype(np.float64)
173 | img2 = img2.astype(np.float64)
174 | kernel = cv2.getGaussianKernel(11, 1.5)
175 | window = np.outer(kernel, kernel.transpose())
176 |
177 | mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5] # valid
178 | mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5]
179 | mu1_sq = mu1 ** 2
180 | mu2_sq = mu2 ** 2
181 | mu1_mu2 = mu1 * mu2
182 | sigma1_sq = cv2.filter2D(img1 ** 2, -1, window)[5:-5, 5:-5] - mu1_sq
183 | sigma2_sq = cv2.filter2D(img2 ** 2, -1, window)[5:-5, 5:-5] - mu2_sq
184 | sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2
185 |
186 | ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) *
187 | (sigma1_sq + sigma2_sq + C2))
188 | return ssim_map.mean()
189 |
190 | if not img1.shape == img2.shape:
191 | raise ValueError('Input images must have the same dimensions.')
192 | if img1.ndim == 2:
193 | return ssim(img1, img2)
194 | elif img1.ndim == 3:
195 | if img1.shape[2] == 3:
196 | ssims = []
197 | for i in range(3):
198 |             ssims.append(ssim(img1[..., i], img2[..., i]))
199 | return np.array(ssims).mean()
200 | elif img1.shape[2] == 1:
201 | return ssim(np.squeeze(img1), np.squeeze(img2))
202 | else:
203 | raise ValueError('Wrong input image dimensions.')
204 |
205 |
206 | def PSNR_SSIM_Shift_Best(img1, img2, window_size=5):
207 | '''
208 | img1, img2: [0, 255]
209 | '''
210 | img1 = img1.astype(np.float64)
211 | img2 = img2.astype(np.float64)
212 | img1 = img1[window_size:-window_size, window_size:-window_size, :]
213 | best_psnr, best_ssim = 0., 0.
214 | for i in range(-window_size, window_size + 1):
215 | for j in range(-window_size, window_size + 1):
216 | shifted_img2 = image_shift(img2, offset_x=i, offset_y=j)
217 | shifted_img2 = shifted_img2[window_size:-window_size, window_size:-window_size, :]
218 | psnr = PSNR(img1, shifted_img2)
219 | ssim = SSIM(img1, shifted_img2)
220 | best_psnr = max(best_psnr, psnr)
221 | best_ssim = max(best_ssim, ssim)
222 | return best_psnr, best_ssim
223 |
224 |
225 | #################################################################################
226 | #### Others ####
227 | #################################################################################
228 |
229 |
230 | def evaluate_smooth(img):
231 | x = cv2.Sobel(img, cv2.CV_16S, 1, 0)
232 | y = cv2.Sobel(img, cv2.CV_16S, 0, 1)
233 | absX = cv2.convertScaleAbs(x)
234 | absY = cv2.convertScaleAbs(y)
235 | dst = cv2.addWeighted(absX, 0.5, absY, 0.5, 0)
236 | smooth = np.mean(dst)
237 | return smooth
238 |
239 |
240 | def calc_grad_sobel(img, device='cuda'):
241 | if not isinstance(img, torch.Tensor):
242 | raise Exception("Now just support torch.Tensor. See the Type(img)={}".format(type(img)))
243 | if not img.ndimension() == 4:
244 | raise Exception("Tensor ndimension must equal to 4. See the img.ndimension={}".format(img.ndimension()))
245 |
246 | img = torch.mean(img, dim=1, keepdim=True)
247 |
248 | # img = calc_meanFilter(img, device=device) # meanFilter
249 |
250 | sobel_filter_X = np.array([[1, 0, -1], [2, 0, -2], [1, 0, -1]]).reshape((1, 1, 3, 3))
251 | sobel_filter_Y = np.array([[1, 2, 1], [0, 0, 0], [-1, -2, -1]]).reshape((1, 1, 3, 3))
252 | sobel_filter_X = torch.from_numpy(sobel_filter_X).float().to(device)
253 | sobel_filter_Y = torch.from_numpy(sobel_filter_Y).float().to(device)
254 | grad_X = F.conv2d(img, sobel_filter_X, bias=None, stride=1, padding=1)
255 | grad_Y = F.conv2d(img, sobel_filter_Y, bias=None, stride=1, padding=1)
256 | grad = torch.sqrt(grad_X.pow(2) + grad_Y.pow(2))
257 |
258 | return grad_X, grad_Y, grad
259 |
260 |
261 | def calc_meanFilter(img, kernel_size=11, n_channel=1, device='cuda'):
262 | mean_filter_X = np.ones(shape=(1, 1, kernel_size, kernel_size), dtype=np.float32) / (kernel_size * kernel_size)
263 | mean_filter_X = torch.from_numpy(mean_filter_X).float().to(device)
264 | new_img = torch.zeros_like(img)
265 | for i in range(n_channel):
266 | new_img[:, i:i + 1, :, :] = F.conv2d(img[:, i:i + 1, :, :], mean_filter_X, bias=None,
267 | stride=1, padding=kernel_size // 2)
268 | return new_img
269 |
270 |
271 | def warp_by_flow(x, flo, device='cuda'):
272 | B, C, H, W = flo.size()
273 |
274 | # mesh grid
275 | xx = torch.arange(0, W).view(1, -1).repeat(H, 1)
276 | yy = torch.arange(0, H).view(-1, 1).repeat(1, W)
277 | xx = xx.view(1, 1, H, W).repeat(B, 1, 1, 1)
278 | yy = yy.view(1, 1, H, W).repeat(B, 1, 1, 1)
279 | grid = torch.cat((xx, yy), 1).float()
280 | grid = grid.to(device)
281 | vgrid = Variable(grid) + flo
282 |
283 | # scale grid to [-1,1]
284 | vgrid[:, 0, :, :] = 2.0 * vgrid[:, 0, :, :].clone() / max(W - 1, 1) - 1.0
285 | vgrid[:, 1, :, :] = 2.0 * vgrid[:, 1, :, :].clone() / max(H - 1, 1) - 1.0
286 |
287 | vgrid = vgrid.permute(0, 2, 3, 1)
288 | output = F.grid_sample(x, vgrid, padding_mode='border')
289 |
290 | return output
291 |
--------------------------------------------------------------------------------
/base/kernel_base.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import math
3 | import cv2
4 | from scipy import io as scio
5 | from base.image_base import RMSE, PSNR, SSIM
6 |
7 |
8 | def load_mat_kernel(mat_path):
9 | data_dict = scio.loadmat(mat_path)
10 | key = [v for v in data_dict.keys() if v not in ['__header__', '__version__', '__globals__']][0]
11 | return data_dict[key]
12 |
13 |
14 | def kernel2png(kernel):
15 | # kernel: [ks, ks], float32/float64
16 | # return kernel_png: [ks, ks, 3], uint8
17 | kernel = cv2.resize(kernel, dsize=(0, 0), fx=8, fy=8, interpolation=cv2.INTER_CUBIC)
18 | kernel = np.clip(kernel, 0, np.max(kernel))
19 | kernel = kernel / np.sum(kernel)
20 | mi = np.min(kernel)
21 | ma = np.max(kernel)
22 | kernel = (kernel - mi) / (ma - mi)
23 | kernel = np.round(np.clip(kernel * 255., 0, 255))
24 | kernel_png = np.stack([kernel, kernel, kernel], axis=2).astype('uint8')
25 | return kernel_png
26 |
27 |
28 | def Gradient_Similarity(gradB, gradT):
29 | '''
30 |     calculate the gradient similarity using the convolution response divided by the norm
31 | Reference: Z. Hu and M.-H. Yang. Good regions to deblur. In ECCV 2012
32 | '''
33 | gradB = gradB.astype(np.float64)
34 | gradT = gradT.astype(np.float64)
35 |
36 | gradB_sum = math.sqrt(np.sum(np.power(gradB, 2)))
37 | gradT_sum = math.sqrt(np.sum(np.power(gradT, 2)))
38 |
39 | kb, kt = gradB.shape[0], gradT.shape[0]
40 | psize = kt // 2
41 | gradB_norm = gradB / gradB_sum
42 | gradB_pad = np.zeros(shape=[kb + 2 * psize, kb + 2 * psize], dtype=np.float64)
43 | gradB_pad[psize:-psize, psize:-psize] = gradB_norm
44 | corr = cv2.filter2D(src=gradB_pad, ddepth=-1, kernel=gradT, borderType=cv2.BORDER_CONSTANT)
45 |
46 | similarity = np.max(corr) / gradT_sum
47 |
48 | return similarity
49 |
50 |
51 | def Kernel_RMSE_PSNR_SSIM(kernel1, kernel2):
52 | h1, w1 = kernel1.shape
53 | h2, w2 = kernel2.shape
54 |
55 | mh, mw = max(h1, h2), max(w1, w2)
56 | kernel1_pad = np.zeros(shape=(mh, mw)).astype(kernel1.dtype)
57 | kernel2_pad = np.zeros(shape=(mh, mw)).astype(kernel2.dtype)
58 | kernel1_pad[(mh - h1) // 2:(mh - h1) // 2 + h1, (mw - w1) // 2:(mw - w1) // 2 + w1] = kernel1
59 | kernel2_pad[(mh - h2) // 2:(mh - h2) // 2 + h2, (mw - w2) // 2:(mw - w2) // 2 + w2] = kernel2
60 |
61 | # kernel1_pad = cv2.resize(kernel1, dsize=(w2, w2), interpolation=cv2.INTER_CUBIC)
62 | # kernel2_pad = kernel2
63 |
64 | rmse = RMSE(kernel1_pad, kernel2_pad)
65 | psnr = PSNR(kernel1_pad * 255, kernel2_pad * 255)
66 | ssim = SSIM(kernel1_pad * 255, kernel2_pad * 255)
67 | return rmse, psnr, ssim
68 |
--------------------------------------------------------------------------------
/base/matlab_imresize.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from math import ceil
3 |
4 |
5 | def deriveSizeFromScale(img_shape, scale):
6 | output_shape = []
7 | for k in range(2):
8 | output_shape.append(int(ceil(scale[k] * img_shape[k])))
9 | return output_shape
10 |
11 |
12 | def deriveScaleFromSize(img_shape_in, img_shape_out):
13 | scale = []
14 | for k in range(2):
15 | scale.append(1.0 * img_shape_out[k] / img_shape_in[k])
16 | return scale
17 |
18 |
19 | def triangle(x):
20 | x = np.array(x).astype(np.float64)
21 | lessthanzero = np.logical_and((x >= -1), x < 0)
22 | greaterthanzero = np.logical_and((x <= 1), x >= 0)
23 | f = np.multiply((x + 1), lessthanzero) + np.multiply((1 - x), greaterthanzero)
24 | return f
25 |
26 |
27 | def cubic(x):
28 | x = np.array(x).astype(np.float64)
29 | absx = np.absolute(x)
30 | absx2 = np.multiply(absx, absx)
31 | absx3 = np.multiply(absx2, absx)
32 | f = np.multiply(1.5 * absx3 - 2.5 * absx2 + 1, absx <= 1) + np.multiply(-0.5 * absx3 + 2.5 * absx2 - 4 * absx + 2,
33 | (1 < absx) & (absx <= 2))
34 | return f
35 |
36 |
37 | def contributions(in_length, out_length, scale, kernel, k_width):
38 | if scale < 1:
39 | h = lambda x: scale * kernel(scale * x)
40 | kernel_width = 1.0 * k_width / scale
41 | else:
42 | h = kernel
43 | kernel_width = k_width
44 | x = np.arange(1, out_length + 1).astype(np.float64)
45 | u = x / scale + 0.5 * (1 - 1 / scale)
46 | left = np.floor(u - kernel_width / 2)
47 | P = int(ceil(kernel_width)) + 2
48 | ind = np.expand_dims(left, axis=1) + np.arange(P) - 1 # -1 because indexing from 0
49 | indices = ind.astype(np.int32)
50 | weights = h(np.expand_dims(u, axis=1) - indices - 1) # -1 because indexing from 0
51 | weights = np.divide(weights, np.expand_dims(np.sum(weights, axis=1), axis=1))
52 | aux = np.concatenate((np.arange(in_length), np.arange(in_length - 1, -1, step=-1))).astype(np.int32)
53 | indices = aux[np.mod(indices, aux.size)]
54 | ind2store = np.nonzero(np.any(weights, axis=0))
55 | weights = weights[:, ind2store]
56 | indices = indices[:, ind2store]
57 | return weights, indices
58 |
59 |
60 | def imresizemex(inimg, weights, indices, dim):
61 | in_shape = inimg.shape
62 | w_shape = weights.shape
63 | out_shape = list(in_shape)
64 | out_shape[dim] = w_shape[0]
65 | outimg = np.zeros(out_shape)
66 | if dim == 0:
67 | for i_img in range(in_shape[1]):
68 | for i_w in range(w_shape[0]):
69 | w = weights[i_w, :]
70 | ind = indices[i_w, :]
71 | im_slice = inimg[ind, i_img].astype(np.float64)
72 | outimg[i_w, i_img] = np.sum(np.multiply(np.squeeze(im_slice, axis=0), w.T), axis=0)
73 | elif dim == 1:
74 | for i_img in range(in_shape[0]):
75 | for i_w in range(w_shape[0]):
76 | w = weights[i_w, :]
77 | ind = indices[i_w, :]
78 | im_slice = inimg[i_img, ind].astype(np.float64)
79 | outimg[i_img, i_w] = np.sum(np.multiply(np.squeeze(im_slice, axis=0), w.T), axis=0)
80 | if inimg.dtype == np.uint8:
81 | outimg = np.clip(outimg, 0, 255)
82 | return np.around(outimg).astype(np.uint8)
83 | else:
84 | return outimg
85 |
86 |
87 | def imresizevec(inimg, weights, indices, dim):
88 | wshape = weights.shape
89 | if dim == 0:
90 | weights = weights.reshape((wshape[0], wshape[2], 1, 1))
91 | outimg = np.sum(weights * ((inimg[indices].squeeze(axis=1)).astype(np.float64)), axis=1)
92 | elif dim == 1:
93 | weights = weights.reshape((1, wshape[0], wshape[2], 1))
94 | outimg = np.sum(weights * ((inimg[:, indices].squeeze(axis=2)).astype(np.float64)), axis=2)
95 | if inimg.dtype == np.uint8:
96 | outimg = np.clip(outimg, 0, 255)
97 | return np.around(outimg).astype(np.uint8)
98 | else:
99 | return outimg
100 |
101 |
102 | def resizeAlongDim(A, dim, weights, indices, mode="vec"):
103 | if mode == "org":
104 | out = imresizemex(A, weights, indices, dim)
105 | else:
106 | out = imresizevec(A, weights, indices, dim)
107 | return out
108 |
109 |
110 | def imresize(I, scalar_scale=None, method='bicubic', output_shape=None, mode="vec"):
111 |     if method == 'bicubic':
112 |         kernel = cubic
113 |     elif method == 'bilinear':
114 |         kernel = triangle
115 |     else:
116 |         raise ValueError('Error: unidentified method supplied: {}'.format(method))
117 |
118 | kernel_width = 4.0
119 | # Fill scale and output_size
120 | if scalar_scale is not None:
121 | scalar_scale = float(scalar_scale)
122 | scale = [scalar_scale, scalar_scale]
123 | output_size = deriveSizeFromScale(I.shape, scale)
124 | elif output_shape is not None:
125 | scale = deriveScaleFromSize(I.shape, output_shape)
126 | output_size = list(output_shape)
127 | else:
128 | print('Error: scalar_scale OR output_shape should be defined!')
129 | return
130 | scale_np = np.array(scale)
131 | order = np.argsort(scale_np)
132 | weights = []
133 | indices = []
134 | for k in range(2):
135 | w, ind = contributions(I.shape[k], output_size[k], scale[k], kernel, kernel_width)
136 | weights.append(w)
137 | indices.append(ind)
138 | B = np.copy(I)
139 | flag2D = False
140 | if B.ndim == 2:
141 | B = np.expand_dims(B, axis=2)
142 | flag2D = True
143 | for k in range(2):
144 | dim = order[k]
145 | B = resizeAlongDim(B, dim, weights[dim], indices[dim], mode)
146 | if flag2D:
147 | B = np.squeeze(B, axis=2)
148 | return B
149 |
150 |
151 | def convertDouble2Byte(I):
152 | B = np.clip(I, 0.0, 1.0)
153 | B = 255 * B
154 | return np.around(B).astype(np.uint8)
155 |
--------------------------------------------------------------------------------
/base/os_base.py:
--------------------------------------------------------------------------------
1 | import shutil
2 | import os
3 | import glob
4 |
5 |
6 | def handle_dir(dir):
7 | if not os.path.exists(dir):
8 | os.mkdir(dir)
9 | print('mkdir:', dir)
10 |
11 |
12 | def listdir(path):
13 | sys_files = ['.DS_Store']
14 | files = os.listdir(path)
15 | for sf in sys_files:
16 | if sf in files:
17 | files.remove(sf)
18 | return files
19 |
20 |
21 | def glob_match(template):
22 | fpathes = glob.glob(template)
23 | dest_fpathes = []
24 | for fp in fpathes:
25 | if '.DS_Store' not in fp:
26 | dest_fpathes.append(fp)
27 | return dest_fpathes
28 |
29 |
30 | def get_fname_ext(filepath):
31 | filename = os.path.basename(filepath)
32 | ext = filename.split(".")[-1]
33 | fname = filename[:-(len(ext) + 1)]
34 | return fname, ext
35 |
36 |
37 | def copy_file(src, dst):
38 | shutil.copy(src, dst)
39 | print('copy file from {} to {}'.format(os.path.basename(src), os.path.basename(dst)))
40 |
41 |
42 | def move_file(src, dst):
43 | shutil.move(src, dst)
44 | print('move file from {} to {}'.format(os.path.basename(src), os.path.basename(dst)))
45 |
46 |
47 | def rename_file(src, dst):
48 | os.rename(src, dst)
49 | print('rename file from {} to {}'.format(os.path.basename(src), os.path.basename(dst)))
50 |
--------------------------------------------------------------------------------
/utils/LPIPS/models/__init__.py:
--------------------------------------------------------------------------------
1 |
2 | from __future__ import absolute_import
3 | from __future__ import division
4 | from __future__ import print_function
5 |
6 | import numpy as np
7 | from skimage.measure import compare_ssim
8 | import torch
9 |
10 | from utils.LPIPS.models import dist_model
11 |
12 | class PerceptualLoss(torch.nn.Module):
13 | def __init__(self, model='net-lin', net='alex', colorspace='rgb', spatial=False, use_gpu=True, gpu_ids=[0], version='0.1'): # VGG using our perceptually-learned weights (LPIPS metric)
14 | # def __init__(self, model='net', net='vgg', use_gpu=True): # "default" way of using VGG as a perceptual loss
15 | super(PerceptualLoss, self).__init__()
16 | print('Setting up Perceptual loss...')
17 | self.use_gpu = use_gpu
18 | self.spatial = spatial
19 | self.gpu_ids = gpu_ids
20 | self.model = dist_model.DistModel()
21 | self.model.initialize(model=model, net=net, use_gpu=use_gpu, colorspace=colorspace, spatial=self.spatial, gpu_ids=gpu_ids, version=version)
22 | print('...[%s] initialized'%self.model.name())
23 | print('...Done')
24 |
25 | def forward(self, pred, target, normalize=False):
26 | """
27 | Pred and target are Variables.
28 | If normalize is True, assumes the images are between [0,1] and then scales them between [-1,+1]
29 | If normalize is False, assumes the images are already between [-1,+1]
30 |
31 | Inputs pred and target are Nx3xHxW
32 | Output pytorch Variable N long
33 | """
34 |
35 | if normalize:
36 | target = 2 * target - 1
37 | pred = 2 * pred - 1
38 |
39 | return self.model.forward(target, pred)
40 |
41 | def normalize_tensor(in_feat,eps=1e-10):
42 | norm_factor = torch.sqrt(torch.sum(in_feat**2,dim=1,keepdim=True))
43 | return in_feat/(norm_factor+eps)
44 |
45 | def l2(p0, p1, range=255.):
46 | return .5*np.mean((p0 / range - p1 / range)**2)
47 |
48 | def psnr(p0, p1, peak=255.):
49 | return 10*np.log10(peak**2/np.mean((1.*p0-1.*p1)**2))
50 |
51 | def dssim(p0, p1, range=255.):
52 | return (1 - compare_ssim(p0, p1, data_range=range, multichannel=True)) / 2.
53 |
54 | def rgb2lab(in_img,mean_cent=False):
55 | from skimage import color
56 | img_lab = color.rgb2lab(in_img)
57 | if(mean_cent):
58 | img_lab[:,:,0] = img_lab[:,:,0]-50
59 | return img_lab
60 |
61 | def tensor2np(tensor_obj):
62 | # change dimension of a tensor object into a numpy array
63 | return tensor_obj[0].cpu().float().numpy().transpose((1,2,0))
64 |
65 | def np2tensor(np_obj):
66 | # change dimenion of np array into tensor array
67 | return torch.Tensor(np_obj[:, :, :, np.newaxis].transpose((3, 2, 0, 1)))
68 |
69 | def tensor2tensorlab(image_tensor,to_norm=True,mc_only=False):
70 | # image tensor to lab tensor
71 | from skimage import color
72 |
73 | img = tensor2im(image_tensor)
74 | img_lab = color.rgb2lab(img)
75 | if(mc_only):
76 | img_lab[:,:,0] = img_lab[:,:,0]-50
77 | if(to_norm and not mc_only):
78 | img_lab[:,:,0] = img_lab[:,:,0]-50
79 | img_lab = img_lab/100.
80 |
81 | return np2tensor(img_lab)
82 |
83 | def tensorlab2tensor(lab_tensor,return_inbnd=False):
84 | from skimage import color
85 | import warnings
86 | warnings.filterwarnings("ignore")
87 |
88 | lab = tensor2np(lab_tensor)*100.
89 | lab[:,:,0] = lab[:,:,0]+50
90 |
91 | rgb_back = 255.*np.clip(color.lab2rgb(lab.astype('float')),0,1)
92 | if(return_inbnd):
93 | # convert back to lab, see if we match
94 | lab_back = color.rgb2lab(rgb_back.astype('uint8'))
95 | mask = 1.*np.isclose(lab_back,lab,atol=2.)
96 | mask = np2tensor(np.prod(mask,axis=2)[:,:,np.newaxis])
97 | return (im2tensor(rgb_back),mask)
98 | else:
99 | return im2tensor(rgb_back)
100 |
101 | def rgb2lab(input):
102 | from skimage import color
103 | return color.rgb2lab(input / 255.)
104 |
105 | def tensor2im(image_tensor, imtype=np.uint8, cent=1., factor=255./2.):
106 | image_numpy = image_tensor[0].cpu().float().numpy()
107 | image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + cent) * factor
108 | return image_numpy.astype(imtype)
109 |
110 | def im2tensor(image, imtype=np.uint8, cent=1., factor=255./2.):
111 | return torch.Tensor((image / factor - cent)
112 | [:, :, :, np.newaxis].transpose((3, 2, 0, 1)))
113 |
114 | def tensor2vec(vector_tensor):
115 | return vector_tensor.data.cpu().numpy()[:, :, 0, 0]
116 |
117 | def voc_ap(rec, prec, use_07_metric=False):
118 | """ ap = voc_ap(rec, prec, [use_07_metric])
119 | Compute VOC AP given precision and recall.
120 | If use_07_metric is true, uses the
121 | VOC 07 11 point method (default:False).
122 | """
123 | if use_07_metric:
124 | # 11 point metric
125 | ap = 0.
126 | for t in np.arange(0., 1.1, 0.1):
127 | if np.sum(rec >= t) == 0:
128 | p = 0
129 | else:
130 | p = np.max(prec[rec >= t])
131 | ap = ap + p / 11.
132 | else:
133 | # correct AP calculation
134 | # first append sentinel values at the end
135 | mrec = np.concatenate(([0.], rec, [1.]))
136 | mpre = np.concatenate(([0.], prec, [0.]))
137 |
138 | # compute the precision envelope
139 | for i in range(mpre.size - 1, 0, -1):
140 | mpre[i - 1] = np.maximum(mpre[i - 1], mpre[i])
141 |
142 | # to calculate area under PR curve, look for points
143 | # where X axis (recall) changes value
144 | i = np.where(mrec[1:] != mrec[:-1])[0]
145 |
146 | # and sum (\Delta recall) * prec
147 | ap = np.sum((mrec[i + 1] - mrec[i]) * mpre[i + 1])
148 | return ap
149 |
150 | def tensor2im(image_tensor, imtype=np.uint8, cent=1., factor=255./2.):
151 | # def tensor2im(image_tensor, imtype=np.uint8, cent=1., factor=1.):
152 | image_numpy = image_tensor[0].cpu().float().numpy()
153 | image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + cent) * factor
154 | return image_numpy.astype(imtype)
155 |
156 | def im2tensor(image, imtype=np.uint8, cent=1., factor=255./2.):
157 | # def im2tensor(image, imtype=np.uint8, cent=1., factor=1.):
158 | return torch.Tensor((image / factor - cent)
159 | [:, :, :, np.newaxis].transpose((3, 2, 0, 1)))
160 |
--------------------------------------------------------------------------------
/utils/LPIPS/models/base_model.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | from torch.autograd import Variable
4 | from pdb import set_trace as st
5 | # from IPython import embed
6 |
7 | class BaseModel():
8 | def __init__(self):
9 |         pass
10 |
11 | def name(self):
12 | return 'BaseModel'
13 |
14 | def initialize(self, use_gpu=True, gpu_ids=[0]):
15 | self.use_gpu = use_gpu
16 | self.gpu_ids = gpu_ids
17 |
18 | def forward(self):
19 | pass
20 |
21 | def get_image_paths(self):
22 | pass
23 |
24 | def optimize_parameters(self):
25 | pass
26 |
27 | def get_current_visuals(self):
28 | return self.input
29 |
30 | def get_current_errors(self):
31 | return {}
32 |
33 | def save(self, label):
34 | pass
35 |
36 | # helper saving function that can be used by subclasses
37 | def save_network(self, network, path, network_label, epoch_label):
38 | save_filename = '%s_net_%s.pth' % (epoch_label, network_label)
39 | save_path = os.path.join(path, save_filename)
40 | torch.save(network.state_dict(), save_path)
41 |
42 | # helper loading function that can be used by subclasses
43 | def load_network(self, network, network_label, epoch_label):
44 | save_filename = '%s_net_%s.pth' % (epoch_label, network_label)
45 | save_path = os.path.join(self.save_dir, save_filename)
46 | print('Loading network from %s'%save_path)
47 | network.load_state_dict(torch.load(save_path))
48 |
49 |     def update_learning_rate(self):
50 | pass
51 |
52 | def get_image_paths(self):
53 | return self.image_paths
54 |
55 | def save_done(self, flag=False):
56 |         import numpy as np; np.save(os.path.join(self.save_dir, 'done_flag'), flag)
57 | np.savetxt(os.path.join(self.save_dir, 'done_flag'),[flag,],fmt='%i')
58 |
59 |
--------------------------------------------------------------------------------
/utils/LPIPS/models/dist_model.py:
--------------------------------------------------------------------------------
1 |
2 | from __future__ import absolute_import
3 |
4 | import numpy as np
5 | import torch
6 | import os
7 | from collections import OrderedDict
8 | from torch.autograd import Variable
9 | from .base_model import BaseModel
10 | from scipy.ndimage import zoom
11 | from tqdm import tqdm
12 |
13 | # from IPython import embed
14 |
15 | from . import networks_basic as networks
16 | from .. import models as util
17 |
18 |
19 | class DistModel(BaseModel):
20 | def name(self):
21 | return self.model_name
22 |
23 | def initialize(self, model='net-lin', net='alex', colorspace='Lab', pnet_rand=False, pnet_tune=False, model_path=None,
24 | use_gpu=True, printNet=False, spatial=False,
25 | is_train=False, lr=.0001, beta1=0.5, version='0.1', gpu_ids=[0]):
26 | '''
27 | INPUTS
28 | model - ['net-lin'] for linearly calibrated network
29 | ['net'] for off-the-shelf network
30 | ['L2'] for L2 distance in Lab colorspace
31 | ['SSIM'] for ssim in RGB colorspace
32 | net - ['squeeze','alex','vgg']
33 | model_path - if None, will look in weights/[NET_NAME].pth
34 | colorspace - ['Lab','RGB'] colorspace to use for L2 and SSIM
35 | use_gpu - bool - whether or not to use a GPU
36 | printNet - bool - whether or not to print network architecture out
37 | spatial - bool - whether to output an array containing varying distances across spatial dimensions
38 | is_train - bool - [True] for training mode
39 | lr - float - initial learning rate
40 | beta1 - float - initial momentum term for adam
41 | version - 0.1 for latest, 0.0 was original (with a bug)
42 | gpu_ids - int array - [0] by default, gpus to use
43 | '''
44 | BaseModel.initialize(self, use_gpu=use_gpu, gpu_ids=gpu_ids)
45 |
46 | self.model = model
47 | self.net = net
48 | self.is_train = is_train
49 | self.spatial = spatial
50 | self.gpu_ids = gpu_ids
51 | self.model_name = '%s [%s]'%(model,net)
52 |
53 | if(self.model == 'net-lin'): # pretrained net + linear layer
54 | self.net = networks.PNetLin(pnet_rand=pnet_rand, pnet_tune=pnet_tune, pnet_type=net,
55 | use_dropout=True, spatial=spatial, version=version, lpips=True)
56 | kw = {}
57 | if not use_gpu:
58 | kw['map_location'] = 'cpu'
59 | if(model_path is None):
60 | import inspect
61 | model_path = os.path.abspath(os.path.join(inspect.getfile(self.initialize), '..', 'weights/v%s/%s.pth'%(version,net)))
62 |
63 | if(not is_train):
64 | print('Loading model from: %s'%model_path)
65 | self.net.load_state_dict(torch.load(model_path, **kw), strict=False)
66 |
67 | elif(self.model=='net'): # pretrained network
68 | self.net = networks.PNetLin(pnet_rand=pnet_rand, pnet_type=net, lpips=False)
69 | elif(self.model in ['L2','l2']):
70 | self.net = networks.L2(use_gpu=use_gpu,colorspace=colorspace) # not really a network, only for testing
71 | self.model_name = 'L2'
72 | elif(self.model in ['DSSIM','dssim','SSIM','ssim']):
73 | self.net = networks.DSSIM(use_gpu=use_gpu,colorspace=colorspace)
74 | self.model_name = 'SSIM'
75 | else:
76 | raise ValueError("Model [%s] not recognized." % self.model)
77 |
78 | self.parameters = list(self.net.parameters())
79 |
80 | if self.is_train: # training mode
81 | # extra network on top to go from distances (d0,d1) => predicted human judgment (h*)
82 | self.rankLoss = networks.BCERankingLoss()
83 | self.parameters += list(self.rankLoss.net.parameters())
84 | self.lr = lr
85 | self.old_lr = lr
86 | self.optimizer_net = torch.optim.Adam(self.parameters, lr=lr, betas=(beta1, 0.999))
87 | else: # test mode
88 | self.net.eval()
89 |
90 | if(use_gpu):
91 | self.net.to(gpu_ids[0])
92 | self.net = torch.nn.DataParallel(self.net, device_ids=gpu_ids)
93 | if(self.is_train):
94 | self.rankLoss = self.rankLoss.to(device=gpu_ids[0]) # just put this on GPU0
95 |
96 | if(printNet):
97 | print('---------- Networks initialized -------------')
98 | networks.print_network(self.net)
99 | print('-----------------------------------------------')
100 |
101 | def forward(self, in0, in1, retPerLayer=False):
102 | ''' Function computes the distance between image patches in0 and in1
103 | INPUTS
104 | in0, in1 - torch.Tensor object of shape Nx3xXxY - image patch scaled to [-1,1]
105 | OUTPUT
106 | computed distances between in0 and in1
107 | '''
108 |
109 | return self.net.forward(in0, in1, retPerLayer=retPerLayer)
110 |
111 | # ***** TRAINING FUNCTIONS *****
112 | def optimize_parameters(self):
113 | self.forward_train()
114 | self.optimizer_net.zero_grad()
115 | self.backward_train()
116 | self.optimizer_net.step()
117 | self.clamp_weights()
118 |
119 | def clamp_weights(self):
120 | for module in self.net.modules():
121 | if(hasattr(module, 'weight') and module.kernel_size==(1,1)):
122 | module.weight.data = torch.clamp(module.weight.data,min=0)
123 |
124 | def set_input(self, data):
125 | self.input_ref = data['ref']
126 | self.input_p0 = data['p0']
127 | self.input_p1 = data['p1']
128 | self.input_judge = data['judge']
129 |
130 | if(self.use_gpu):
131 | self.input_ref = self.input_ref.to(device=self.gpu_ids[0])
132 | self.input_p0 = self.input_p0.to(device=self.gpu_ids[0])
133 | self.input_p1 = self.input_p1.to(device=self.gpu_ids[0])
134 | self.input_judge = self.input_judge.to(device=self.gpu_ids[0])
135 |
136 | self.var_ref = Variable(self.input_ref,requires_grad=True)
137 | self.var_p0 = Variable(self.input_p0,requires_grad=True)
138 | self.var_p1 = Variable(self.input_p1,requires_grad=True)
139 |
140 | def forward_train(self): # run forward pass
141 | # print(self.net.module.scaling_layer.shift)
142 | # print(torch.norm(self.net.module.net.slice1[0].weight).item(), torch.norm(self.net.module.lin0.model[1].weight).item())
143 |
144 | self.d0 = self.forward(self.var_ref, self.var_p0)
145 | self.d1 = self.forward(self.var_ref, self.var_p1)
146 | self.acc_r = self.compute_accuracy(self.d0,self.d1,self.input_judge)
147 |
148 | self.var_judge = Variable(1.*self.input_judge).view(self.d0.size())
149 |
150 | self.loss_total = self.rankLoss.forward(self.d0, self.d1, self.var_judge*2.-1.)
151 |
152 | return self.loss_total
153 |
154 | def backward_train(self):
155 | torch.mean(self.loss_total).backward()
156 |
157 | def compute_accuracy(self,d0,d1,judge):
158 | ''' d0, d1 are Variables, judge is a Tensor '''
159 |         d1_lt_d0 = (d1 < d0).cpu().data.numpy().flatten()
160 |         judge_per = judge.cpu().numpy().flatten()
161 |         return d1_lt_d0 * judge_per + (1 - d1_lt_d0) * (1 - judge_per)
162 | 
163 |     def get_current_errors(self):
164 |         retDict = OrderedDict([('loss_total', self.loss_total.data.cpu().numpy()),
165 |                                ('acc_r', self.acc_r)])
166 | 
167 |         for key in retDict.keys():
168 |             retDict[key] = np.mean(retDict[key])
169 | 
170 |         return retDict
171 | 
172 |     def get_current_visuals(self):
173 |         zoom_factor = 256 / self.var_ref.data.size()[2]
174 | 
175 |         ref_img = util.tensor2im(self.var_ref.data)
176 |         p0_img = util.tensor2im(self.var_p0.data)
177 |         p1_img = util.tensor2im(self.var_p1.data)
178 | 
179 |         ref_img_vis = zoom(ref_img, [zoom_factor, zoom_factor, 1], order=0)
180 |         p0_img_vis = zoom(p0_img, [zoom_factor, zoom_factor, 1], order=0)
181 |         p1_img_vis = zoom(p1_img, [zoom_factor, zoom_factor, 1], order=0)
182 | 
183 |         return OrderedDict([('ref', ref_img_vis),
184 |                             ('p0', p0_img_vis),
185 |                             ('p1', p1_img_vis)])
186 | 
187 |     def save(self, path, label):
188 |         if self.use_gpu:
189 |             self.save_network(self.net.module, path, '', label)
190 |         else:
191 |             self.save_network(self.net, path, '', label)
192 |         self.save_network(self.rankLoss.net, path, 'rank', label)
193 | 
194 |     # decay the learning rate linearly over nepoch_decay epochs
195 |     def update_learning_rate(self, nepoch_decay):
196 |         lrd = self.lr / nepoch_decay
197 |         lr = self.old_lr - lrd
198 | 
199 |         for param_group in self.optimizer_net.param_groups:
200 |             param_group['lr'] = lr
201 |         print('update lr [%s] decay: %f -> %f' % (type, self.old_lr, lr))
202 | self.old_lr = lr
203 |
204 | def score_2afc_dataset(data_loader, func, name=''):
205 | ''' Function computes Two Alternative Forced Choice (2AFC) score using
206 | distance function 'func' in dataset 'data_loader'
207 | INPUTS
208 | data_loader - CustomDatasetDataLoader object - contains a TwoAFCDataset inside
209 | func - callable distance function - calling d=func(in0,in1) should take 2
210 | pytorch tensors with shape Nx3xXxY, and return numpy array of length N
211 | OUTPUTS
212 | [0] - 2AFC score in [0,1], fraction of time func agrees with human evaluators
213 | [1] - dictionary with following elements
214 | d0s,d1s - N arrays containing distances between reference patch to perturbed patches
215 | gts - N array in [0,1], preferred patch selected by human evaluators
216 | (closer to "0" for left patch p0, "1" for right patch p1,
217 | "0.6" means 60pct people preferred right patch, 40pct preferred left)
218 | scores - N array in [0,1], corresponding to what percentage function agreed with humans
219 | CONSTS
220 | N - number of test triplets in data_loader
221 | '''
222 |
223 | d0s = []
224 | d1s = []
225 | gts = []
226 |
227 | for data in tqdm(data_loader.load_data(), desc=name):
228 | d0s+=func(data['ref'],data['p0']).data.cpu().numpy().flatten().tolist()
229 | d1s+=func(data['ref'],data['p1']).data.cpu().numpy().flatten().tolist()
230 | gts+=data['judge'].cpu().numpy().flatten().tolist()
231 |
232 | d0s = np.array(d0s)
233 | d1s = np.array(d1s)
234 | gts = np.array(gts)
235 |     scores = (d0s < d1s) * (1. - gts) + (d1s < d0s) * gts + (d0s == d1s) * .5
236 | 
237 |     return (np.mean(scores), dict(d0s=d0s, d1s=d1s, gts=gts, scores=scores))
238 | 
--------------------------------------------------------------------------------
/utils/NIQE/niqe.py:
--------------------------------------------------------------------------------
1 | import math
2 | from os.path import dirname, join
3 | 
4 | import numpy as np
5 | import scipy.io
6 | import scipy.linalg
7 | import scipy.ndimage
8 | import scipy.special
9 | from PIL import Image
10 | 
11 | gamma_range = np.arange(0.2, 10, 0.001)
12 | a = scipy.special.gamma(2.0 / gamma_range)
13 | a *= a
14 | b = scipy.special.gamma(1.0 / gamma_range)
15 | c = scipy.special.gamma(3.0 / gamma_range)
16 | prec_gammas = a / (b * c)
17 | 
18 | 
19 | def aggd_features(imdata):
20 |     # flatten imdata
21 |     imdata.shape = (len(imdata.flat),)
22 |     imdata2 = imdata * imdata
23 |     # split the squared coefficients by the sign of the coefficient
24 |     left_data = imdata2[imdata < 0]
25 |     right_data = imdata2[imdata >= 0]
26 | left_mean_sqrt = 0
27 | right_mean_sqrt = 0
28 | if len(left_data) > 0:
29 | left_mean_sqrt = np.sqrt(np.average(left_data))
30 | if len(right_data) > 0:
31 | right_mean_sqrt = np.sqrt(np.average(right_data))
32 |
33 | if right_mean_sqrt != 0:
34 | gamma_hat = left_mean_sqrt / right_mean_sqrt
35 | else:
36 | gamma_hat = np.inf
37 | # solve r-hat norm
38 |
39 | imdata2_mean = np.mean(imdata2)
40 | if imdata2_mean != 0:
41 | r_hat = (np.average(np.abs(imdata)) ** 2) / (np.average(imdata2))
42 | else:
43 | r_hat = np.inf
44 | rhat_norm = r_hat * (((math.pow(gamma_hat, 3) + 1) * (gamma_hat + 1)) / math.pow(math.pow(gamma_hat, 2) + 1, 2))
45 |
46 | # solve alpha by guessing values that minimize ro
47 |     pos = np.argmin((prec_gammas - rhat_norm) ** 2)
48 | alpha = gamma_range[pos]
49 |
50 | gam1 = scipy.special.gamma(1.0 / alpha)
51 | gam2 = scipy.special.gamma(2.0 / alpha)
52 | gam3 = scipy.special.gamma(3.0 / alpha)
53 |
54 | aggdratio = np.sqrt(gam1) / np.sqrt(gam3)
55 | bl = aggdratio * left_mean_sqrt
56 | br = aggdratio * right_mean_sqrt
57 |
58 | # mean parameter
59 | N = (br - bl) * (gam2 / gam1) # *aggdratio
60 | return (alpha, N, bl, br, left_mean_sqrt, right_mean_sqrt)
61 |
62 |
63 | def ggd_features(imdata):
64 | nr_gam = 1 / prec_gammas
65 | sigma_sq = np.var(imdata)
66 | E = np.mean(np.abs(imdata))
67 | rho = sigma_sq / E ** 2
68 |     pos = np.argmin(np.abs(nr_gam - rho))
69 | return gamma_range[pos], sigma_sq
70 |
71 |
72 | def paired_product(new_im):
73 | shift1 = np.roll(new_im.copy(), 1, axis=1)
74 | shift2 = np.roll(new_im.copy(), 1, axis=0)
75 | shift3 = np.roll(np.roll(new_im.copy(), 1, axis=0), 1, axis=1)
76 | shift4 = np.roll(np.roll(new_im.copy(), 1, axis=0), -1, axis=1)
77 |
78 | H_img = shift1 * new_im
79 | V_img = shift2 * new_im
80 | D1_img = shift3 * new_im
81 | D2_img = shift4 * new_im
82 |
83 | return (H_img, V_img, D1_img, D2_img)
84 |
85 |
86 | def gen_gauss_window(lw, sigma):
87 | sd = np.float32(sigma)
88 | lw = int(lw)
89 | weights = [0.0] * (2 * lw + 1)
90 | weights[lw] = 1.0
91 | sum = 1.0
92 | sd *= sd
93 | for ii in range(1, lw + 1):
94 | tmp = np.exp(-0.5 * np.float32(ii * ii) / sd)
95 | weights[lw + ii] = tmp
96 | weights[lw - ii] = tmp
97 | sum += 2.0 * tmp
98 | for ii in range(2 * lw + 1):
99 | weights[ii] /= sum
100 | return weights
101 |
102 |
103 | def compute_image_mscn_transform(image, C=1, avg_window=None, extend_mode='constant'):
104 | if avg_window is None:
105 | avg_window = gen_gauss_window(3, 7.0 / 6.0)
106 | assert len(np.shape(image)) == 2
107 | h, w = np.shape(image)
108 | mu_image = np.zeros((h, w), dtype=np.float32)
109 | var_image = np.zeros((h, w), dtype=np.float32)
110 | image = np.array(image).astype('float32')
111 | scipy.ndimage.correlate1d(image, avg_window, 0, mu_image, mode=extend_mode)
112 | scipy.ndimage.correlate1d(mu_image, avg_window, 1, mu_image, mode=extend_mode)
113 | scipy.ndimage.correlate1d(image ** 2, avg_window, 0, var_image, mode=extend_mode)
114 | scipy.ndimage.correlate1d(var_image, avg_window, 1, var_image, mode=extend_mode)
115 | var_image = np.sqrt(np.abs(var_image - mu_image ** 2))
116 | return (image - mu_image) / (var_image + C), var_image, mu_image
117 |
118 |
119 | def _niqe_extract_subband_feats(mscncoefs):
120 | # alpha_m, = extract_ggd_features(mscncoefs)
121 | alpha_m, N, bl, br, lsq, rsq = aggd_features(mscncoefs.copy())
122 | pps1, pps2, pps3, pps4 = paired_product(mscncoefs)
123 | alpha1, N1, bl1, br1, lsq1, rsq1 = aggd_features(pps1)
124 | alpha2, N2, bl2, br2, lsq2, rsq2 = aggd_features(pps2)
125 | alpha3, N3, bl3, br3, lsq3, rsq3 = aggd_features(pps3)
126 | alpha4, N4, bl4, br4, lsq4, rsq4 = aggd_features(pps4)
127 | return np.array([alpha_m, (bl + br) / 2.0,
128 | alpha1, N1, bl1, br1, # (V)
129 | alpha2, N2, bl2, br2, # (H)
130 | alpha3, N3, bl3, bl3, # (D1)
131 | alpha4, N4, bl4, bl4, # (D2)
132 | ])
133 |
134 |
135 | def get_patches_train_features(img, patch_size, stride=8):
136 | return _get_patches_generic(img, patch_size, 1, stride)
137 |
138 |
139 | def get_patches_test_features(img, patch_size, stride=8):
140 | return _get_patches_generic(img, patch_size, 0, stride)
141 |
142 |
143 | def extract_on_patches(img, patch_size):
144 | h, w = img.shape
145 |     patch_size = int(patch_size)
146 | patches = []
147 | for j in range(0, h - patch_size + 1, patch_size):
148 | for i in range(0, w - patch_size + 1, patch_size):
149 | patch = img[j:j + patch_size, i:i + patch_size]
150 | patches.append(patch)
151 |
152 | patches = np.array(patches)
153 |
154 | patch_features = []
155 | for p in patches:
156 | patch_features.append(_niqe_extract_subband_feats(p))
157 | patch_features = np.array(patch_features)
158 |
159 | return patch_features
160 |
161 |
162 | def _get_patches_generic(img, patch_size, is_train, stride):
163 | h, w = np.shape(img)
164 | if h < patch_size or w < patch_size:
165 |         # fail loudly instead of exiting with status 0 on an error
166 |         raise ValueError("Input image is too small")
167 |
168 | # ensure that the patch divides evenly into img
169 | hoffset = (h % patch_size)
170 | woffset = (w % patch_size)
171 |
172 | if hoffset > 0:
173 | img = img[:-hoffset, :]
174 | if woffset > 0:
175 | img = img[:, :-woffset]
176 |
177 | img = img.astype(np.float32)
178 | # img2 = scipy.misc.imresize(img, 0.5, interp='bicubic', mode='F')
179 | _h, _w = img.shape
180 | img2 = np.array(Image.fromarray(img).resize((int(_w * 0.5), int(_h * 0.5)), resample=Image.BICUBIC))
181 |
182 | mscn1, var, mu = compute_image_mscn_transform(img)
183 | mscn1 = mscn1.astype(np.float32)
184 |
185 | mscn2, _, _ = compute_image_mscn_transform(img2)
186 | mscn2 = mscn2.astype(np.float32)
187 |
188 | feats_lvl1 = extract_on_patches(mscn1, patch_size)
189 |     feats_lvl2 = extract_on_patches(mscn2, patch_size // 2)
190 |
191 | feats = np.hstack((feats_lvl1, feats_lvl2)) # feats_lvl3))
192 |
193 | return feats
194 |
195 |
196 | def niqe(inputImgData):
197 | patch_size = 96
198 | module_path = dirname(__file__)
199 |
200 | # TODO: memoize
201 | params = scipy.io.loadmat(join(module_path, 'niqe_image_params.mat'))
202 | pop_mu = np.ravel(params["pop_mu"])
203 | pop_cov = params["pop_cov"]
204 |
205 | inputImgData = inputImgData.astype('float32')
206 | M, N = inputImgData.shape
207 |
208 | # assert C == 1, "niqe called with videos containing %d channels. Please supply only the luminance channel" % (C,)
209 |     assert M > (patch_size * 2 + 1), \
210 |         "niqe called with small frame size, requires > 192x192 resolution video using current training parameters"
211 |     assert N > (patch_size * 2 + 1), \
212 |         "niqe called with small frame size, requires > 192x192 resolution video using current training parameters"
213 |
214 | feats = get_patches_test_features(inputImgData, patch_size)
215 | sample_mu = np.mean(feats, axis=0)
216 | sample_cov = np.cov(feats.T)
217 |
218 | X = sample_mu - pop_mu
219 | covmat = ((pop_cov + sample_cov) / 2.0)
220 | pinvmat = scipy.linalg.pinv(covmat)
221 | niqe_score = np.sqrt(np.dot(np.dot(X, pinvmat), X))
222 |
223 | return niqe_score
224 |
--------------------------------------------------------------------------------
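A minimal usage sketch for the `niqe()` entry point above (illustrative; `test.png` is a hypothetical path, and the image must be larger than 192x192, as the assertions require):

```
import cv2
import utils.NIQE.niqe as cal_niqe

# NIQE expects a single (luminance) channel
img = cv2.imread('test.png', cv2.IMREAD_GRAYSCALE)
score = cal_niqe.niqe(img.astype('float32'))  # lower scores indicate better quality
print('NIQE = {:.4f}'.format(score))
```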
/utils/NIQE/niqe_image_params.mat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/csbhr/OpenUtility/c9cf713c99523c0a2e0be6c2afa988af751ad161/utils/NIQE/niqe_image_params.mat
--------------------------------------------------------------------------------
/utils/ResizeRight/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2020 Assaf Shocher
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/utils/ResizeRight/README.md:
--------------------------------------------------------------------------------
1 | # ResizeRight
2 | This is a resizing package for images or tensors that supports both Numpy and PyTorch (**fully differentiable**) seamlessly. The main motivation for creating this is to address some **crucial incorrectness issues** (see item 3 in the list below) that exist in all other resizing packages I am aware of. As far as I know, it is the only one that performs correctly in all cases. ResizeRight is specially made for machine learning, image enhancement and restoration challenges.
3 |
4 | The code is inspired by MATLAB's imresize function, but with crucial differences. It is specifically useful due to the following reasons:
5 |
6 | 1. ResizeRight produces results **identical to MATLAB for the simple cases** (scale_factor * in_size is integer). None of the Python packages I am aware of currently resize images with results similar to MATLAB's imresize (which is a common benchmark for image restoration tasks, especially super-resolution).
7 |
8 | 2. No other **differentiable** method I am aware of supports **AntiAliasing** as in MATLAB. Actually very few non-differentiable ones, including popular ones, do. This causes artifacts and inconsistency in downscaling. (see [this tweet](https://twitter.com/jaakkolehtinen/status/1258102168176951299) by [Jaakko Lehtinen](https://users.aalto.fi/~lehtinj7/)
9 | for example).
10 |
11 | 3. The most important part: In the general case where scale_factor * in_size is non-integer, **no existing resizing method I am aware of (including MATLAB) performs consistently.** ResizeRight is accurate and consistent due to its ability to process **both scale-factor and output-size** provided by the user. This is a super important feature for super-resolution and learning. One must acknowledge that the same output-size can result from varying scale-factors, since output-size is usually determined by *ceil(input_size * scale_factor)*. This situation creates a dangerous lack of consistency. It is best explained by example: say you have an image of size 9x9 and you resize by a scale-factor of 0.5. The result size is 5x5. Now you resize with a scale-factor of 2 and get a result sized 10x10. "No big deal", you must be thinking, "I can resize it according to output-size 9x9", right? But then you will not get the correct scale-factor, which is calculated as output-size / input-size = 1.8.
12 | Due to a simple observation regarding the projection of the output grid to the input grid, ResizeRight is the only one that consistently maintains the image centered, as in optical zoom while complying with the exact scale-factor and output size the user requires.
13 | This is one of the main reasons for creating this repository. This downscale-upscale consistency is often crucial for learning based tasks (e.g. ["Zero-Shot Super-Resolution"](http://www.wisdom.weizmann.ac.il/~vision/zssr/)), and does not exist in other Python packages nor in MATLAB.
14 |
15 | 4. Misalignment in resizing is a pandemic! Many existing packages actually return misaligned results. It is visually not apparent but can cause great damage to image enhancement tasks (for example, see [how tensorflow's image resize stole 60 days of my life](https://hackernoon.com/how-tensorflows-tf-image-resize-stole-60-days-of-my-life-aba5eb093f35)). I personally suffered many unfortunate consequences of such misalignment before and throughout making this method.
16 |
17 | 5. Resizing supports **both Numpy and PyTorch** tensors seamlessly, just by the type of the input tensor given. Results are checked to be identical in both modes, so you can safely apply it to different tensor types and maintain consistency. No Numpy <-> Torch conversion takes place at any step. The process is done exclusively with one of the frameworks. No direct dependency is needed, so you can run it without having PyTorch installed at all, or without Numpy. You only need one of them.
18 |
19 | 6. In the case where scale_factor * in_size is a rational number with a denominator that is not too big (this is a parameter), calculation is done efficiently based on convolutions (currently only PyTorch is supported). This is far more efficient for big tensors and suitable for working on large batches or high resolution. Note that this efficient calculation can be applied to certain dims that maintain the conditions while performing the regular calculation for the other dims.
20 |
21 | 7. Unlike some existing methods, including MATLAB, you can **resize N-D tensors in M dimensions**, for any M<=N.
22 |
23 | 8. You can specify a list of scale-factors to resize each dimension using a different scale-factor.
24 |
25 | 9. You can easily add and embed your own interpolation methods for the resizer to use (see interp_methods.py).
26 |
27 | 10. All used framework padding methods are supported (depends on numpy/PyTorch mode)
28 | PyTorch: 'constant', 'reflect', 'replicate', 'circular'.
29 | Numpy: ‘constant’, ‘edge’, ‘linear_ramp’, ‘maximum’, ‘mean’, ‘median’, ‘minimum’, ‘reflect’, ‘symmetric’, ‘wrap’, ‘empty’
30 |
31 | 11. Some general calculations are done more efficiently than the MATLAB version (one example is that MATLAB extends the kernel size by 2, and then searches for zero columns in the weights and cancels them. ResizeRight uses an observation that resizing is actually a continuous convolution and avoids having these redundancies ahead, see Shocher et al. ["From Discrete to Continuous Convolution Layers"](https://arxiv.org/abs/2006.11120)).
32 | --------
33 |
34 | ### Usage:
35 | For dynamic resize using either Numpy or PyTorch (a short example is given at the end of this file):
36 | ```
37 | resize_right.resize(input, scale_factors=None, out_shape=None,
38 | interp_method=interp_methods.cubic, support_sz=None,
39 | antialiasing=True, by_convs=False, scale_tolerance=None,
40 |                     max_denominator=10, pad_mode='constant')
41 | ```
42 |
43 | __input__ :
44 | the input image/tensor, a Numpy or Torch tensor.
45 |
46 | __scale_factors__:
47 | can be specified as-
48 | 1. one scalar scale - then it will be assumed that you want to resize the first two dims with this scale for Numpy, or the last two dims for PyTorch.
49 | 2. a list or tuple of scales - one for each dimension you want to resize. Note: if the length of the list is L, then the first L dims will be rescaled for Numpy and the last L for PyTorch.
50 | 3. not specified - then it will be calculated using output_size. This is not recommended (see advantage 3 in the list above).
51 |
52 | __out_shape__:
53 | A list or tuple. If shorter than input.shape, then only the first/last (depending on Numpy/PyTorch) dims are resized. If not specified, it can be calculated from scale_factor.
54 |
55 | __interp_method__:
56 | The type of interpolation used to calculate the weights. This is a scalar-to-scalar function that can be applied to tensors pointwise. The classical methods are implemented and can be found in interp_methods.py (cubic, linear, lanczos2, lanczos3, box).
57 |
58 | __support_sz__:
59 | This is the support of the interpolation function, i.e. the length of the non-zero segment over its 1d input domain. This is a characteristic of the function, e.g. 4 for cubic, 2 for linear, 4 for lanczos2, 6 for lanczos3, 1 for box.
60 |
61 | __antialiasing__:
62 | This is an option similar to MATLAB's default. It is only relevant for downscaling. If True, it basically means that the kernel is stretched by 1/scale_factor to prevent aliasing (low-pass filtering).
63 |
64 | __by_convs__:
65 | This determines whether to allow efficient calculation using convolutions according to tolerance. This feature should be used when scale_factor * in_size is rational with a denominator low enough (or close enough to being an integer) and the tensors are big (batches or high-resolution).
66 |
67 | __scale_tolerance__:
68 | This is the allowed distance between the closest rational fraction M/N and the float scale_factor provided. If the fraction is closer than this distance, it will be used and the efficient convolution calculation will take place.
69 |
70 | __max_denominator__:
71 | When by_convs is on, the scale_factor is translated to a rational fraction M/N, where the denominator N is limited by this parameter. The goal is to make the calculation more efficient. The number of convolutions used is the size of the denominator.
72 |
73 | __pad_mode__:
74 | This can be used according to the padding methods of each framework.
75 | PyTorch: 'constant', 'reflect', 'replicate', 'circular'.
76 | Numpy: ‘constant’, ‘edge’, ‘linear_ramp’, ‘maximum’, ‘mean’, ‘median’, ‘minimum’, ‘reflect’, ‘symmetric’, ‘wrap’, ‘empty’
77 |
78 | --------
79 |
80 | ### Cite / credit
81 | If you find our work useful in your research or publication, please cite this work:
82 | ```
83 | @misc{ResizeRight,
84 | author = {Shocher, Assaf},
85 | title = {ResizeRight},
86 | year = {2018},
87 | publisher = {GitHub},
88 | journal = {GitHub repository},
89 | howpublished = {\url{https://github.com/assafshocher/ResizeRight}},
90 | }
91 | ```
92 |
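93 | ### Example
94 | A minimal usage sketch (illustrative; assumes `x` is a PyTorch tensor of shape NxCxHxW):
95 | ```
96 | import torch
97 | from resize_right import resize
98 | import interp_methods
99 | 
100 | x = torch.rand(1, 3, 96, 96)
101 | 
102 | # downscale the last two dims by 0.5, with antialiasing (the default)
103 | y = resize(x, scale_factors=0.5)
104 | 
105 | # upscale back, giving both the scale-factor and the exact output size
106 | x_rec = resize(y, scale_factors=2, out_shape=x.shape[-2:], interp_method=interp_methods.lanczos3)
107 | ```
108 | 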
--------------------------------------------------------------------------------
/utils/ResizeRight/interp_methods.py:
--------------------------------------------------------------------------------
1 | from math import pi
2 |
3 | try:
4 | import torch
5 | except ImportError:
6 | torch = None
7 |
8 | try:
9 | import numpy
10 | except ImportError:
11 | numpy = None
12 |
13 | if numpy is None and torch is None:
14 | raise ImportError("Must have either Numpy or PyTorch but both not found")
15 |
16 |
17 | def set_framework_dependencies(x):
18 | if type(x) is numpy.ndarray:
19 | to_dtype = lambda a: a
20 | fw = numpy
21 | else:
22 | to_dtype = lambda a: a.to(x.dtype)
23 | fw = torch
24 | eps = fw.finfo(fw.float32).eps
25 | return fw, to_dtype, eps
26 |
27 |
28 | def support_sz(sz):
29 | def wrapper(f):
30 | f.support_sz = sz
31 | return f
32 | return wrapper
33 |
34 |
35 | @support_sz(4)
36 | def cubic(x):
37 | fw, to_dtype, eps = set_framework_dependencies(x)
38 | absx = fw.abs(x)
39 | absx2 = absx ** 2
40 | absx3 = absx ** 3
41 | return ((1.5 * absx3 - 2.5 * absx2 + 1.) * to_dtype(absx <= 1.) +
42 | (-0.5 * absx3 + 2.5 * absx2 - 4. * absx + 2.) *
43 | to_dtype((1. < absx) & (absx <= 2.)))
44 |
45 |
46 | @support_sz(4)
47 | def lanczos2(x):
48 | fw, to_dtype, eps = set_framework_dependencies(x)
49 | return (((fw.sin(pi * x) * fw.sin(pi * x / 2) + eps) /
50 | ((pi**2 * x**2 / 2) + eps)) * to_dtype(abs(x) < 2))
51 |
52 |
53 | @support_sz(6)
54 | def lanczos3(x):
55 | fw, to_dtype, eps = set_framework_dependencies(x)
56 | return (((fw.sin(pi * x) * fw.sin(pi * x / 3) + eps) /
57 | ((pi**2 * x**2 / 3) + eps)) * to_dtype(abs(x) < 3))
58 |
59 |
60 | @support_sz(2)
61 | def linear(x):
62 | fw, to_dtype, eps = set_framework_dependencies(x)
63 | return ((x + 1) * to_dtype((-1 <= x) & (x < 0)) + (1 - x) *
64 | to_dtype((0 <= x) & (x <= 1)))
65 |
66 |
67 | @support_sz(1)
68 | def box(x):
69 | fw, to_dtype, eps = set_framework_dependencies(x)
70 | return to_dtype((-1 <= x) & (x < 0)) + to_dtype((0 <= x) & (x <= 1))
71 |
--------------------------------------------------------------------------------
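The `@support_sz` decorator above makes it easy to register custom kernels, as the ResizeRight README notes. A minimal sketch of a user-defined nearest-neighbor method (illustrative; not part of the original file):

```
from utils.ResizeRight.interp_methods import support_sz, set_framework_dependencies

@support_sz(1)
def nearest(x):
    # 1 inside [-0.5, 0.5), 0 elsewhere: the classic nearest-neighbor kernel
    fw, to_dtype, eps = set_framework_dependencies(x)
    return to_dtype((-0.5 <= x) & (x < 0.5))
```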
/utils/create_gif.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import imageio
4 |
5 |
6 | frame_root_1 = '/home/csbhr/Downloads/ours'
7 | frame_root_2 = '/home/csbhr/Downloads/input'
8 | gif_path = '/home/csbhr/Downloads/visual.gif'
9 |
10 | frames_fn_1 = sorted(os.listdir(frame_root_1))
11 | frames_fn_2 = sorted(os.listdir(frame_root_2))
12 |
13 | tmp_img = imageio.imread(os.path.join(frame_root_1, frames_fn_1[0]))
14 | H, W, C = tmp_img.shape
15 | gif = np.zeros(tmp_img.shape, dtype=tmp_img.dtype)
16 | line = np.zeros(tmp_img.shape, dtype=tmp_img.dtype)
17 | line[:, :, 0] = 0
18 | line[:, :, 1] = 255
19 | line[:, :, 2] = 0
20 |
21 | nl = 30
22 | wl = 3
23 | gif_images = []
24 | for i, (fn1, fn2) in enumerate(zip(frames_fn_1, frames_fn_2)):
25 | frame_1 = imageio.imread(os.path.join(frame_root_1, fn1))
26 | gif = frame_1
27 | if i < nl:
28 | frame_2 = imageio.imread(os.path.join(frame_root_2, fn2))
29 | gif[:, :W-(W // nl * i), :] = frame_2[:, :W-(W // nl * i), :]
30 | gif[:, W-(W // nl * i + wl):W-(W // nl * i), :] = line[:, W-(W // nl * i + wl):W-(W // nl * i), :]
31 | gif_images.append(gif.copy())
32 | imageio.mimsave(gif_path, gif_images, fps=4)
33 |
34 |
35 | # frames_root = '/home/csbhr/Downloads/(TPAMI) Self-Supervised Deep Blind Video Super-Resolution Supplemental Material/figures/videos'
36 | # gif_root = '/home/csbhr/Downloads/(TPAMI) Self-Supervised Deep Blind Video Super-Resolution Supplemental Material/figures'
37 | #
38 | # methods = sorted(os.listdir(frames_root))
39 | # for me in methods:
40 | # frames_fn = sorted(os.listdir(os.path.join(frames_root, me)))
41 | # gif_images = []
42 | # for fn in frames_fn:
43 | # gif_images.append(imageio.imread(os.path.join(frames_root, me, fn)))
44 | # imageio.mimsave(os.path.join(gif_root, "{}.gif".format(me)), gif_images, fps=2)
45 |
--------------------------------------------------------------------------------
/utils/dnn_utils.py:
--------------------------------------------------------------------------------
1 | def cal_parmeters(model):
2 |     params = list(model.parameters())
3 |     total = 0
4 |     for i in params:
5 |         layer = 1
6 |         for j in i.size():
7 |             layer *= j
8 |         print("Layer shape: " + str(list(i.size())), "Layer params: " + str(layer))
9 |         total = total + layer
10 |     print("Total params: %.2fM" % (total / 1e6))
--------------------------------------------------------------------------------
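A minimal usage sketch for `cal_parmeters` (illustrative; assumes PyTorch is installed):

```
import torch.nn as nn
from utils.dnn_utils import cal_parmeters

model = nn.Sequential(nn.Conv2d(3, 16, 3), nn.ReLU(), nn.Conv2d(16, 3, 3))
cal_parmeters(model)  # prints each parameter tensor's shape and size, then the total in millions
```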
/utils/file_regroup_utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | from base.os_base import handle_dir, get_fname_ext, copy_file, rename_file, glob_match
3 |
4 |
5 | def remove_files_prefix(root, prefix=''):
6 | '''
7 | remove prefix from filename
8 | params:
9 | root: the dir of files that need to be processed
10 | prefix: the prefix to be removed
11 | '''
12 | img_list = glob_match(os.path.join(root, "*"))
13 | for im in img_list:
14 | basename = os.path.basename(im)
15 | now_prefix = basename[:len(prefix)]
16 | if now_prefix == prefix:
17 | dest_basename = basename[len(prefix):]
18 | src = im
19 | dst = os.path.join(root, dest_basename)
20 | rename_file(src, dst)
21 |
22 |
23 | def remove_files_postfix(root, postfix=''):
24 | '''
25 | remove postfix from filename
26 | params:
27 | root: the dir of files that need to be processed
28 | postfix: the postfix to be removed
29 | '''
30 | img_list = glob_match(os.path.join(root, "*"))
31 | for im in img_list:
32 | fname, ext = get_fname_ext(im)
33 | now_postfix = fname[-len(postfix):]
34 | if now_postfix == postfix:
35 | dest_basename = "{}.{}".format(fname[:-len(postfix)], ext)
36 | src = im
37 | dst = os.path.join(root, dest_basename)
38 | rename_file(src, dst)
39 |
40 |
41 | def add_files_postfix(root, postfix=''):
42 | '''
43 | add postfix to filename
44 | params:
45 | root: the dir of files that need to be processed
46 | postfix: the postfix to be added
47 | '''
48 | img_list = glob_match(os.path.join(root, "*"))
49 | for im in img_list:
50 | fname, ext = get_fname_ext(im)
51 | dest_basename = "{}{}.{}".format(fname, postfix, ext)
52 | src = im
53 | dst = os.path.join(root, dest_basename)
54 | rename_file(src, dst)
55 |
56 |
57 | def extra_files_by_postfix(ori_root, dest_root, match_postfix='', new_postfix=None, match_ext='*'):
58 | '''
59 |     extract files from ori_root to dest_root by match_postfix and match_ext
60 | params:
61 | ori_root: the dir of files that need to be processed
62 | dest_root: the dir for saving matched files
63 | match_postfix: the postfix to be matched
64 | new_postfix: the postfix for matched files
65 | default: None, that is keeping the ori postfix
66 | match_ext: the ext to be matched
67 | '''
68 | if new_postfix is None:
69 | new_postfix = match_postfix
70 |
71 | handle_dir(dest_root)
72 | flag_img_list = glob_match(os.path.join(ori_root, "*{}.{}".format(match_postfix, match_ext)))
73 | for im in flag_img_list:
74 | fname, ext = get_fname_ext(im)
75 |         dest_basename = '{}{}.{}'.format(fname[:len(fname) - len(match_postfix)], new_postfix, ext)  # strip the matched postfix, then append the new one
76 | src = im
77 | dst = os.path.join(dest_root, dest_basename)
78 | copy_file(src, dst)
79 |
80 |
81 | def resort_files_index(root, template='{:0>4}', start_idx=0):
82 | '''
83 | resort files' filename using template that index start from start_idx
84 | params:
85 | root: the dir of files that need to be processed
86 | template: the template for processed filename
87 | start_idx: the start index
88 | '''
89 | template = template + '.{}'
90 | img_list = sorted(glob_match(os.path.join(root, "*")))
91 | for i, im in enumerate(img_list):
92 | fname, ext = get_fname_ext(im)
93 | dest_basename = template.format(i + start_idx, ext)
94 | src = im
95 | dst = os.path.join(root, dest_basename)
96 | rename_file(src, dst)
97 |
--------------------------------------------------------------------------------
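A minimal usage sketch for the renaming helpers above (illustrative; `./frames` is a hypothetical directory):

```
from utils.file_regroup_utils import remove_files_prefix, resort_files_index

remove_files_prefix('./frames', prefix='out_')  # 'out_0001.png' -> '0001.png'
resort_files_index('./frames', template='{:0>4}', start_idx=0)  # rename to 0000.png, 0001.png, ...
```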
/utils/image_crop_combine_utils.py:
--------------------------------------------------------------------------------
1 | '''
2 | require:
3 | filenames:
4 |         filenames should not contain the symbol "-"
5 |     crop flags:
6 |         the crop flag "x-x-x-x" is appended to the end of the filename when cropping
7 |         so, the crop flag should be at the end of the filename when combining
8 |
9 | '''
10 |
11 | import cv2
12 | import os
13 | import numpy as np
14 | from base.image_base import evaluate_smooth
15 | from base.os_base import handle_dir, listdir, glob_match
16 |
17 |
18 | def crop_img_with_padding(img, min_size=(100, 100), padding=10):
19 | h, w, c = img.shape
20 | n_x, n_y = h // min_size[0], w // min_size[1]
21 |
22 | croped_imgs = {}
23 | for i in range(n_x):
24 | for j in range(n_y):
25 | xl, xr = i * min_size[0] - padding, (i + 1) * min_size[0] + padding
26 |             yl, yr = j * min_size[1] - padding, (j + 1) * min_size[1] + padding  # use min_size[1] along the width axis
27 | if i == 0:
28 | xl = 0
29 | if i == n_x - 1:
30 | xr = h
31 | if j == 0:
32 | yl = 0
33 | if j == n_y - 1:
34 | yr = w
35 | croped_imgs['{}-{}-{}-{}'.format(n_x, n_y, i, j)] = img[xl:xr, yl:yr, :]
36 |
37 | return croped_imgs
38 |
39 |
40 | def combine_img(croped_imgs, padding=10):
41 | keys = list(croped_imgs.keys())
42 | n_x, n_y = int(keys[0].split('-')[0]), int(keys[0].split('-')[1])
43 | img_blocks = [["" for _ in range(n_y)] for _ in range(n_x)]
44 | for k in keys:
45 | i, j = int(k.split('-')[2]), int(k.split('-')[3])
46 | xl, xr, yl, yr = padding, -padding, padding, -padding
47 | if i == 0:
48 | xl = 0
49 | if i == n_x - 1:
50 | xr = 9999
51 | if j == 0:
52 | yl = 0
53 | if j == n_y - 1:
54 | yr = 9999
55 | img_blocks[i][j] = croped_imgs[k][xl:xr, yl:yr, :]
56 |
57 | line_imgs = []
58 | for i in range(n_x):
59 | img_line = img_blocks[i][0]
60 | for j in range(n_y - 1):
61 | img_line = np.concatenate((img_line, img_blocks[i][j + 1]), axis=1)
62 | line_imgs.append(img_line)
63 | combined_img = line_imgs[0]
64 | for i in range(n_x - 1):
65 | combined_img = np.concatenate((combined_img, line_imgs[i + 1]), axis=0)
66 | return combined_img
67 |
68 |
69 | def traverse_crop_img(img, dsize=(100, 100), interval=10):
70 | h, w, c = img.shape
71 |
72 | croped_imgs = []
73 | for i in range(9999):
74 | isbreak_x = False
75 | ix = i * interval
76 | xl, xr = ix, ix + dsize[0]
77 | if xr > h:
78 | xl, xr = h - dsize[0], h
79 | isbreak_x = True
80 | for j in range(9999):
81 | isbreak_y = False
82 | iy = j * interval
83 | yl, yr = iy, iy + dsize[1]
84 | if yr > w:
85 | yl, yr = w - dsize[1], w
86 | isbreak_y = True
87 | croped = img[xl:xr, yl:yr, :]
88 | croped_imgs.append(croped)
89 | if isbreak_y:
90 | break
91 | if isbreak_x:
92 | break
93 |
94 | return croped_imgs
95 |
96 |
97 | def batch_crop_img_with_padding(ori_root, dest_root, min_size=(100, 100), padding=10):
98 | '''
99 | function:
100 | cropping image to many patches with padding
101 | it can be used for inferring large image
102 | params:
103 | ori_root: the dir of images that need to be processed
104 | dest_root: the dir to save processed images
105 | min_size: a tuple (h, w) the min size of crop, the border patch will be larger
106 | padding: the padding size of each patch
107 | notice:
108 | filenames should not contain the character "-"
109 | the crop flag "x-x-x-x" will be at the end of filename when cropping
110 | '''
111 | handle_dir(dest_root)
112 | images_fname = sorted(listdir(ori_root))
113 | for imf in images_fname:
114 | img = cv2.imread(os.path.join(ori_root, imf))
115 | img_cropped = crop_img_with_padding(img, min_size=min_size, padding=padding)
116 | for k in img_cropped.keys():
117 | cv2.imwrite(os.path.join(dest_root, "{}_{}.png".format(os.path.basename(imf).split('.')[0], k)),
118 | img_cropped[k])
119 | print(imf, "crop done !")
120 |
121 |
122 | def batch_combine_img(ori_root, dest_root, padding=10):
123 | '''
124 | function:
125 | combining many patches to image
126 | it can be used to combine patches to image, when you finish inferring large image with cropped patches
127 | params:
128 | ori_root: the dir of images that need to be processed
129 | dest_root: the dir to save processed images
130 | padding: the padding size of each patch
131 | notice:
132 | filenames should not contain the character "-" except for the crop flag
133 | the crop flag "x-x-x-x" should be at the end of filename when combining
134 | '''
135 | handle_dir(dest_root)
136 | images_fname = [fn[:-(len(fn.split('_')[-1]) + 1)] for fn in listdir(ori_root)]
137 | images_fname = list(set(images_fname))
138 | for imf in images_fname:
139 | croped_imgs_path = sorted(glob_match(os.path.join(ori_root, "{}*".format(imf))))
140 | croped_imgs = {}
141 | for cip in croped_imgs_path:
142 | img = cv2.imread(cip)
143 | k = cip.split('.')[0].split('_')[-1]
144 | croped_imgs[k] = img
145 | img_combined = combine_img(croped_imgs, padding=padding)
146 | cv2.imwrite(os.path.join(dest_root, "{}.png".format(imf)), img_combined)
147 | print("{}.png".format(imf), "combine done !")
148 |
149 |
150 | def batch_traverse_crop_img(ori_root, dest_root, dsize=(100, 100), interval=10):
151 | '''
152 | function:
153 | traversing crop image to many patches with same interval
154 | params:
155 | ori_root: the dir of images that need to be processed
156 | dest_root: the dir to save processed images
157 |         dsize: a tuple (h, w) the size of crop, the border patch will be overlapped to satisfy the dsize
158 | interval: the interval when traversing
159 | '''
160 | handle_dir(dest_root)
161 | images_fname = sorted(listdir(ori_root))
162 | for imf in images_fname:
163 | img = cv2.imread(os.path.join(ori_root, imf))
164 | img_cropped = traverse_crop_img(img, dsize=dsize, interval=interval)
165 | for i, cim in enumerate(img_cropped):
166 | cv2.imwrite(os.path.join(dest_root, "{}_{}.png".format(os.path.basename(imf).split('.')[0], i)), cim)
167 | print(imf, "crop done !")
168 |
169 |
170 | def batch_select_valid_patch(ori_root, dest_root, thre=7):
171 | '''
172 | function:
173 | selecting valid patch that are not too smooth
174 | params:
175 | ori_root: the dir of patches that need to be selected
176 | dest_root: the dir to save selected patch
177 | thre: the threshold value of smooth
178 | '''
179 | handle_dir(dest_root)
180 | images_fname = sorted(listdir(ori_root))
181 | total_num = len(images_fname)
182 | valid_num = 0
183 | for imf in images_fname:
184 | img = cv2.imread(os.path.join(ori_root, imf))
185 | smooth = evaluate_smooth(img)
186 | if smooth > thre:
187 | cv2.imwrite(os.path.join(dest_root, imf), img)
188 | valid_num += 1
189 | else:
190 | print(imf, "too smooth, smooth={}".format(smooth))
191 | print("Total {} patches, valid {}, remove {}".format(total_num, valid_num, total_num - valid_num))
192 |
--------------------------------------------------------------------------------
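A minimal usage sketch of the crop-then-combine workflow described in the docstrings above (illustrative; the directories are hypothetical):

```
from utils.image_crop_combine_utils import batch_crop_img_with_padding, batch_combine_img

# split large images into overlapping patches for inference on limited memory
batch_crop_img_with_padding('./large_inputs', './patches', min_size=(256, 256), padding=10)

# ... run the model on './patches', writing results to './patches_out' ...

# stitch the processed patches (named with the "x-x-x-x" crop flag) back together
batch_combine_img('./patches_out', './combined', padding=10)
```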
/utils/image_edge_sharpen.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import numpy as np
5 | import cv2
6 | import os
7 | import math
8 | from utils.video_metric_utils import batch_calc_video_PSNR_SSIM
9 |
10 |
11 | def handle_dir(dir):
12 |     if not os.path.exists(dir):
13 |         os.makedirs(dir)  # makedirs also creates missing parent dirs
14 |         print("mkdir:", dir)
15 |
16 |
17 | def matlab_style_gauss2D(shape=(5, 5), sigma=0.5):
18 | """
19 | 2D gaussian mask - should give the same result as MATLAB's
20 | fspecial('gaussian',[shape],[sigma])
21 | """
22 | m, n = [(ss - 1.) / 2. for ss in shape]
23 | y, x = np.ogrid[-m:m + 1, -n:n + 1]
24 | h = np.exp(-(x * x + y * y) / (2. * sigma * sigma))
25 | h[h < np.finfo(h.dtype).eps * h.max()] = 0
26 | sumh = h.sum()
27 | if sumh != 0:
28 | h /= sumh
29 | return h
30 |
31 |
32 | def get_blur_kernel():
33 | gaussian_sigma = 1.0
34 | gaussian_blur_kernel_size = int(math.ceil(gaussian_sigma * 3) * 2 + 1)
35 | kernel = matlab_style_gauss2D((gaussian_blur_kernel_size, gaussian_blur_kernel_size), gaussian_sigma)
36 | return kernel
37 |
38 |
39 | def get_blur(img, kernel):
40 | img = np.array(img).astype('float32')
41 | img_tensor = torch.from_numpy(img.transpose(2, 0, 1)).unsqueeze(0).float()
42 |
43 | kernel_size = kernel.shape[0]
44 | psize = kernel_size // 2
45 | img_tensor = F.pad(img_tensor, (psize, psize, psize, psize), mode='replicate')
46 |
47 | gaussian_blur = nn.Conv2d(in_channels=3, out_channels=3, kernel_size=kernel_size, stride=1,
48 | padding=int((kernel_size - 1) // 2), bias=False)
49 | nn.init.constant_(gaussian_blur.weight.data, 0.0)
50 | gaussian_blur.weight.data[0, 0, :, :] = torch.FloatTensor(kernel)
51 | gaussian_blur.weight.data[1, 1, :, :] = torch.FloatTensor(kernel)
52 | gaussian_blur.weight.data[2, 2, :, :] = torch.FloatTensor(kernel)
53 |
54 | blur_tensor = gaussian_blur(img_tensor)
55 | blur_tensor = blur_tensor[:, :, psize:-psize, psize:-psize]
56 |
57 | blur_img = blur_tensor[0].detach().numpy().transpose(1, 2, 0).astype('float32')
58 |
59 | return blur_img
60 |
61 |
62 | def image_post_process_results(ori_root, save_root, alpha=3.):
63 | handle_dir(save_root)
64 |
65 | guassian_kernel = get_blur_kernel()
66 |
67 | frame_names = sorted(os.listdir(os.path.join(ori_root)))
68 | for fn in frame_names:
69 | ori_img = cv2.imread(os.path.join(ori_root, fn)).astype('float32')
70 | blur_img = get_blur(ori_img, guassian_kernel).astype('float32')
71 |
72 | res_img = ori_img - blur_img
73 |
74 | result = blur_img + alpha * res_img
75 |
76 | basename = fn.split(".")[0]
77 | cv2.imwrite(os.path.join(save_root, "{}_post.png".format(basename)), result)
78 |
79 | ##
80 | # res_img = np.clip(res_img, 0, np.max(res_img))
81 | # res_img = (res_img / np.max(res_img)) * 255.
82 | # cv2.imwrite(os.path.join(save_root, "{}_res.png".format(basename)), res_img)
83 | ##
84 |
85 | print("{} done!".format(fn))
86 |
87 |
88 | def video_post_process_results(ori_root, save_root, alpha=3.):
89 | handle_dir(save_root)
90 |
91 | guassian_kernel = get_blur_kernel()
92 |
93 | video_names = sorted(os.listdir(ori_root))
94 | for vn in video_names:
95 | handle_dir(os.path.join(save_root, vn))
96 |
97 | frame_names = sorted(os.listdir(os.path.join(ori_root, vn)))
98 | for fn in frame_names:
99 | ori_img = cv2.imread(os.path.join(ori_root, vn, fn)).astype('float32')
100 | blur_img = get_blur(ori_img, guassian_kernel).astype('float32')
101 |
102 | res_img = ori_img - blur_img
103 |
104 | result = blur_img + alpha * res_img
105 |
106 | basename = fn.split(".")[0]
107 | cv2.imwrite(os.path.join(save_root, vn, "{}_post.png".format(basename)), result)
108 |
109 | ##
110 | # res_img = np.clip(res_img, 0, np.max(res_img))
111 | # res_img = (res_img / np.max(res_img)) * 255.
112 | # cv2.imwrite(os.path.join(save_root, vn, "{}_res.png".format(basename)), res_img)
113 | ##
114 |
115 | print("{}-{} done!".format(vn, fn))
116 |
117 |
118 | if __name__ == '__main__':
119 | # image_post_process_results(
120 | # ori_root='/media/csbhr/Bear/Dataset/FaceSR/face/test/bic',
121 | # save_root='./temp/edge/residual',
122 | # alpha=2
123 | # )
124 |
125 | root_list = []
126 | for i in range(21):
127 | alpha = 1.0 + i * 0.1
128 | save_root = '/home/csbhr/Disk-2T/work/OpenUtility/temp/edge_sharpen/post_{}'.format(alpha)
129 | video_post_process_results(
130 | ori_root='/home/csbhr/Disk-2T/work/OpenUtility/temp/edge_sharpen/ori',
131 | save_root=save_root,
132 | alpha=alpha
133 | )
134 | root_list.append({
135 | 'output': save_root,
136 | 'gt': '/home/csbhr/Disk-2T/work/OpenUtility/temp/edge_sharpen/HR'
137 | })
138 | batch_calc_video_PSNR_SSIM(root_list)
139 |
--------------------------------------------------------------------------------
/utils/image_metric_utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import cv2
3 | import torch
4 | import numpy as np
5 | from base import image_base
6 | from base.os_base import listdir
7 | from base import matlab_imresize
8 | import utils.LPIPS.models as lpips_models
9 | import utils.NIQE.niqe as cal_niqe
10 |
11 |
12 | def calc_image_PSNR_SSIM(output_root, gt_root, crop_border=4, shift_window_size=0, test_ycbcr=False, crop_GT=False):
13 | '''
14 |     Calculate PSNR and SSIM of images, following EDVR's evaluation protocol
15 |     Files in output_root and gt_root must correspond one-to-one in sorted order
16 | '''
17 |
18 | if test_ycbcr:
19 | print('Testing Y channel.')
20 | else:
21 | print('Testing RGB channels.')
22 |
23 | PSNR_list = []
24 | SSIM_list = []
25 | output_img_list = sorted(listdir(output_root))
26 | gt_img_list = sorted(listdir(gt_root))
27 | for o_im, g_im in zip(output_img_list, gt_img_list):
28 | o_im_path = os.path.join(output_root, o_im)
29 | g_im_path = os.path.join(gt_root, g_im)
30 | im_GT = cv2.imread(g_im_path) / 255.
31 | im_Gen = cv2.imread(o_im_path) / 255.
32 |
33 | if crop_GT:
34 | h, w, c = im_Gen.shape
35 | im_GT = im_GT[:h, :w, :] # crop GT to output size
36 |
37 | if test_ycbcr and im_GT.shape[2] == 3: # evaluate on Y channel in YCbCr color space
38 | im_GT = image_base.bgr2ycbcr(im_GT, range=1.)
39 | im_Gen = image_base.bgr2ycbcr(im_Gen, range=1.)
40 |
41 | # crop borders
42 | if crop_border != 0:
43 | if im_GT.ndim == 3:
44 | cropped_GT = im_GT[crop_border:-crop_border, crop_border:-crop_border, :]
45 | cropped_Gen = im_Gen[crop_border:-crop_border, crop_border:-crop_border, :]
46 | elif im_GT.ndim == 2:
47 | cropped_GT = im_GT[crop_border:-crop_border, crop_border:-crop_border]
48 | cropped_Gen = im_Gen[crop_border:-crop_border, crop_border:-crop_border]
49 | else:
50 | raise ValueError('Wrong image dimension: {}. Should be 2 or 3.'.format(im_GT.ndim))
51 | else:
52 | cropped_GT = im_GT
53 | cropped_Gen = im_Gen
54 |
55 | if shift_window_size == 0:
56 | psnr = image_base.PSNR(cropped_GT * 255, cropped_Gen * 255)
57 | ssim = image_base.SSIM(cropped_GT * 255, cropped_Gen * 255)
58 | else:
59 | psnr, ssim = image_base.PSNR_SSIM_Shift_Best(cropped_GT * 255, cropped_Gen * 255,
60 | window_size=shift_window_size)
61 | PSNR_list.append(psnr)
62 | SSIM_list.append(ssim)
63 |
64 | print("{} PSNR={:.5}, SSIM={:.4}".format(o_im, psnr, ssim))
65 |
66 | log = 'Average PSNR={:.5}, SSIM={:.4}'.format(sum(PSNR_list) / len(PSNR_list), sum(SSIM_list) / len(SSIM_list))
67 | print(log)
68 |
69 | return PSNR_list, SSIM_list, log
70 |
71 |
72 | def batch_calc_image_PSNR_SSIM(root_list, crop_border=4, shift_window_size=0, test_ycbcr=False, crop_GT=False):
73 | '''
74 | required params:
75 | root_list: a list, each item should be a dictionary that given two key-values:
76 | output: the dir of output images
77 | gt: the dir of gt images
78 | optional params:
79 |         crop_border: default=4, crop pixels when calculating PSNR/SSIM
80 |         shift_window_size: default=0, if >0, shifting image within a window for best metric
81 | test_ycbcr: default=False, if True, applying Ycbcr color space
82 | crop_GT: default=False, if True, cropping GT to output size
83 | return:
84 | log_list: a list, each item is a dictionary that given two key-values:
85 | data_path: the evaluated dir
86 | log: the log of this dir
87 | '''
88 | log_list = []
89 | for i, root in enumerate(root_list):
90 | ouput_root = root['output']
91 | gt_root = root['gt']
92 | print(">>>> Now Evaluation >>>>")
93 | print(">>>> OUTPUT: {}".format(ouput_root))
94 | print(">>>> GT: {}".format(gt_root))
95 | _, _, log = calc_image_PSNR_SSIM(
96 | ouput_root, gt_root, crop_border=crop_border, shift_window_size=shift_window_size,
97 | test_ycbcr=test_ycbcr, crop_GT=crop_GT
98 | )
99 | log_list.append({
100 | 'data_path': ouput_root,
101 | 'log': log
102 | })
103 |
104 | print("--------------------------------------------------------------------------------------")
105 | for i, log in enumerate(log_list):
106 | print("## The {}-th:".format(i))
107 | print(">> ", log['data_path'])
108 | print(">> ", log['log'])
109 |
110 | return log_list
111 |
112 |
113 | def calc_image_LPIPS(output_root, gt_root, model=None, use_gpu=False, spatial=True):
114 | '''
115 |     Calculate LPIPS of images
116 |     Files in output_root and gt_root must correspond one-to-one in sorted order
117 | '''
118 |
119 | def _load_image(path, size=(512, 512)):
120 | img = cv2.imread(path)
121 | h, w, c = img.shape
122 | if h != size[0] or w != size[1]:
123 | img = cv2.resize(img, dsize=size, interpolation=cv2.INTER_CUBIC)
124 | return img[:, :, ::-1]
125 |
126 | def _im2tensor(image, imtype=np.uint8, cent=1., factor=255. / 2.):
127 | return torch.Tensor((image / factor - cent)[:, :, :, np.newaxis].transpose((3, 2, 0, 1)))
128 |
129 | if model is None:
130 | model = lpips_models.PerceptualLoss(model='net-lin', net='alex', use_gpu=use_gpu, spatial=spatial)
131 |
132 | LPIPS_list = []
133 | output_img_list = sorted(listdir(output_root))
134 | gt_img_list = sorted(listdir(gt_root))
135 | for o_im, g_im in zip(output_img_list, gt_img_list):
136 | o_im_path = os.path.join(output_root, o_im)
137 | g_im_path = os.path.join(gt_root, g_im)
138 | im_GT = _im2tensor(_load_image(g_im_path))
139 | im_Gen = _im2tensor(_load_image(o_im_path))
140 |
141 | if use_gpu:
142 | im_GT = im_GT.cuda()
143 | im_Gen = im_Gen.cuda()
144 |
145 | lpips = model.forward(im_GT, im_Gen).mean()
146 | LPIPS_list.append(lpips)
147 |
148 | print("{} LPIPS={:.4}".format(o_im, lpips))
149 |
150 | log = 'Average LPIPS={:.4}'.format(sum(LPIPS_list) / len(LPIPS_list))
151 | print(log)
152 |
153 | return LPIPS_list, log
154 |
155 |
156 | def batch_calc_image_LPIPS(root_list, use_gpu=False, spatial=True):
157 | '''
158 | required params:
159 | root_list: a list, each item should be a dictionary that given two key-values:
160 | output: the dir of output images
161 | gt: the dir of gt images
162 | optional params:
163 |         use_gpu: default=False, if True, use GPU
164 | spatial: default=True, if True, return spatial map
165 | return:
166 | log_list: a list, each item is a dictionary that given two key-values:
167 | data_path: the evaluated dir
168 | log: the log of this dir
169 | '''
170 | model = lpips_models.PerceptualLoss(model='net-lin', net='alex', use_gpu=use_gpu, spatial=spatial)
171 |
172 | log_list = []
173 | for i, root in enumerate(root_list):
174 | ouput_root = root['output']
175 | gt_root = root['gt']
176 | print(">>>> Now Evaluation >>>>")
177 | print(">>>> OUTPUT: {}".format(ouput_root))
178 | print(">>>> GT: {}".format(gt_root))
179 | _, log = calc_image_LPIPS(ouput_root, gt_root, model=model, use_gpu=use_gpu, spatial=spatial)
180 | log_list.append({
181 | 'data_path': ouput_root,
182 | 'log': log
183 | })
184 |
185 | print("--------------------------------------------------------------------------------------")
186 | for i, log in enumerate(log_list):
187 | print("## The {}-th:".format(i))
188 | print(">> ", log['data_path'])
189 | print(">> ", log['log'])
190 |
191 | return log_list
192 |
193 |
194 | def calc_image_NIQE(output_root, crop_border=4):
195 | '''
196 |     Calculate NIQE of images
197 | '''
198 |
199 | NIQE_list = []
200 | output_img_list = sorted(listdir(output_root))
201 | for o_im in output_img_list:
202 | o_im_path = os.path.join(output_root, o_im)
203 | im_Gen = cv2.imread(o_im_path)
204 | im_Gen = cv2.cvtColor(im_Gen, cv2.COLOR_BGR2GRAY)
205 |
206 | # crop borders
207 | if crop_border != 0:
208 | if im_Gen.ndim == 3:
209 | cropped_Gen = im_Gen[crop_border:-crop_border, crop_border:-crop_border, :]
210 | elif im_Gen.ndim == 2:
211 | cropped_Gen = im_Gen[crop_border:-crop_border, crop_border:-crop_border]
212 | else:
213 | raise ValueError('Wrong image dimension: {}. Should be 2 or 3.'.format(im_Gen.ndim))
214 | else:
215 | cropped_Gen = im_Gen
216 |
217 | h, w = cropped_Gen.shape
218 | if h < 193 or w < 193:
219 | cropped_Gen = matlab_imresize.imresize(cropped_Gen, output_shape=(512, 512), method='bicubic')
220 | # cropped_Gen = cv2.resize(cropped_Gen, dsize=(512, 512), interpolation=cv2.INTER_CUBIC)
221 |
222 | niqe = cal_niqe.niqe(cropped_Gen)
223 | NIQE_list.append(niqe)
224 |
225 | print("{} NIQE={:.5}".format(o_im, niqe))
226 |
227 | log = 'Average NIQE={:.5}'.format(sum(NIQE_list) / len(NIQE_list))
228 | print(log)
229 |
230 | return NIQE_list, log
231 |
232 |
233 | def batch_calc_image_NIQE(root_list, crop_border=4):
234 | '''
235 | required params:
236 | root_list: a list, each item should be a dictionary that given key-values:
237 | output: the dir of output images
238 | optional params:
239 |         crop_border: default=4, crop pixels when calculating NIQE
240 | return:
241 | log_list: a list, each item is a dictionary that given two key-values:
242 | data_path: the evaluated dir
243 | log: the log of this dir
244 | '''
245 | log_list = []
246 | for i, root in enumerate(root_list):
247 | ouput_root = root['output']
248 | print(">>>> Now Evaluation >>>>")
249 | print(">>>> OUTPUT: {}".format(ouput_root))
250 | _, log = calc_image_NIQE(
251 | ouput_root, crop_border=crop_border
252 | )
253 | log_list.append({
254 | 'data_path': ouput_root,
255 | 'log': log
256 | })
257 |
258 | print("--------------------------------------------------------------------------------------")
259 | for i, log in enumerate(log_list):
260 | print("## The {}-th:".format(i))
261 | print(">> ", log['data_path'])
262 | print(">> ", log['log'])
263 |
264 | return log_list
265 |
--------------------------------------------------------------------------------
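A minimal usage sketch for the batch evaluation entry points above (illustrative; the directories are hypothetical):

```
from utils.image_metric_utils import batch_calc_image_PSNR_SSIM

log_list = batch_calc_image_PSNR_SSIM(
    root_list=[
        {'output': './results/method_a', 'gt': './data/gt'},
        {'output': './results/method_b', 'gt': './data/gt'},
    ],
    crop_border=4,
    test_ycbcr=True,  # evaluate on the Y channel, as is common for SR benchmarks
)
```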
/utils/image_utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import cv2
3 | import numpy as np
4 | from base.os_base import handle_dir, get_fname_ext, listdir
5 | from base.image_base import matlab_imresize, image_shift
6 |
7 |
8 | def matlab_resize_images(ori_root, dest_root, scale=1.0, method='bicubic', filename_template="{}.png"):
9 | '''
10 | function:
11 | resizing images in batches, same as matlab2017 imresize
12 | params:
13 | ori_root: string, the dir of images that need to be processed
14 | dest_root: string, the dir to save processed images
15 | scale: float, the resize scale
16 | method: string, the interpolation method,
17 | optional: 'bilinear', 'bicubic'
18 | default: 'bicubic'
19 | filename_template: string, the filename template for saving images
20 | '''
21 | if method != 'bilinear' and method != 'bicubic':
22 | raise Exception('Unknown method!')
23 |
24 | handle_dir(dest_root)
25 | scale = float(scale)
26 | images_fname = sorted(listdir(ori_root))
27 | for imf in images_fname:
28 | img = cv2.imread(os.path.join(ori_root, imf)).astype('float32')
29 | img = matlab_imresize(img, scalar_scale=scale, method=method)
30 | cv2.imwrite(os.path.join(dest_root, filename_template.format(get_fname_ext(imf)[0])), img)
31 | print("Image", imf, "resize done !")
32 |
33 |
34 | def cv2_resize_images(ori_root, dest_root, scale=1.0, method='bicubic', filename_template="{}.png"):
35 | '''
36 | function:
37 | resizing images in batches
38 | params:
39 | ori_root: string, the dir of images that need to be processed
40 | dest_root: string, the dir to save processed images
41 | scale: float, the resize scale
42 | method: string, the interpolation method,
43 | optional: 'nearest', 'bilinear', 'bicubic'
44 | default: 'bicubic'
45 | filename_template: string, the filename template for saving images
46 | '''
47 | if method == 'nearest':
48 | interpolation = cv2.INTER_NEAREST
49 | elif method == 'bilinear':
50 | interpolation = cv2.INTER_LINEAR
51 | elif method == 'bicubic':
52 | interpolation = cv2.INTER_CUBIC
53 | else:
54 | raise Exception('Unknown method!')
55 |
56 | handle_dir(dest_root)
57 | scale = float(scale)
58 | images_fname = sorted(listdir(ori_root))
59 | for imf in images_fname:
60 | img = cv2.imread(os.path.join(ori_root, imf)).astype('float32')
61 | img = cv2.resize(img, dsize=(0, 0), fx=scale, fy=scale, interpolation=interpolation)
62 | cv2.imwrite(os.path.join(dest_root, filename_template.format(get_fname_ext(imf)[0])), img)
63 | print("Image", imf, "resize done !")
64 |
65 |
66 | def shift_images(ori_root, dest_root, offset_x=0., offset_y=0., filename_template="{}.png"): # TODO
67 | '''
68 | function:
69 | shifting images by (offset_x, offset_y) on (axis-x, axis-y) in batches
70 | params:
71 | ori_root: string, the dir of images that need to be processed
72 | dest_root: string, the dir to save processed images
73 | offset_x: float, offset pixels on axis-x
74 | positive=left; negative=right
75 | offset_y: float, offset pixels on axis-y
76 | positive=up; negative=down
77 | filename_template: string, the filename template for saving images
78 | '''
79 |
80 | handle_dir(dest_root)
81 | offset_x, offset_y = float(offset_x), float(offset_y)
82 | images_fname = sorted(listdir(ori_root))
83 | for imf in images_fname:
84 | img = cv2.imread(os.path.join(ori_root, imf)).astype('float32')
85 | img = image_shift(img, offset_x=offset_x, offset_y=offset_y)
86 | cv2.imwrite(os.path.join(dest_root, filename_template.format(get_fname_ext(imf)[0])), img)
87 | print("Image", imf, "shift done !")
88 |
89 |
90 | def margin_patch(patch, margin_width=5):
91 |     '''outline a patch with a red border'''
92 |     red = np.zeros_like(patch)
93 |     red[:, :, 2] = 255  # red channel (images are handled in BGR by cv2)
94 | mask = np.zeros_like(patch)
95 | mask[:margin_width, :, :] = mask[-margin_width:, :, :] = mask[:, :margin_width, :] = mask[:, -margin_width:, :] = 1
96 | res = red * mask + patch * (1 - mask)
97 | return res
98 |
99 |
100 | def circle_zoom_img(img, pos_1, pos_2, scale=2., hr_pos='right_down'):
101 | '''circle and zoom a patch in image'''
102 | patch_lr = img[pos_1[0]:pos_1[1], pos_2[0]:pos_2[1], :]
103 | patch_hr = cv2.resize(patch_lr, dsize=(0, 0), fx=scale, fy=scale, interpolation=cv2.INTER_CUBIC)
104 | patch_lr = margin_patch(patch_lr)
105 | patch_hr = margin_patch(patch_hr)
106 | img[pos_1[0]:pos_1[1], pos_2[0]:pos_2[1], :] = patch_lr
107 | if hr_pos == 'right_down':
108 | img[-patch_hr.shape[0]:, -patch_hr.shape[1]:, :] = patch_hr
109 | elif hr_pos == 'left_down':
110 | img[-patch_hr.shape[0]:, :patch_hr.shape[1], :] = patch_hr
111 | return img
112 |
113 |
114 | def circle_img(img, pos_1, pos_2, margin_width=2):
115 | '''circle a patch in image'''
116 | patch_lr = img[pos_1[0]:pos_1[1], pos_2[0]:pos_2[1], :]
117 | patch_lr = margin_patch(patch_lr, margin_width=margin_width)
118 | img[pos_1[0]:pos_1[1], pos_2[0]:pos_2[1], :] = patch_lr
119 | return img
120 |
--------------------------------------------------------------------------------
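A minimal usage sketch for the batch resizing helpers above (illustrative; the directories are hypothetical):

```
from utils.image_utils import matlab_resize_images

# generate 4x-downscaled bicubic LR images that match MATLAB's imresize
matlab_resize_images('./HR', './LR_bicubic_x4', scale=0.25, method='bicubic')
```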
/utils/kernel_metric_utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | from base import kernel_base
3 | from base.kernel_base import load_mat_kernel
4 | from base.os_base import listdir
5 |
6 |
7 | def calc_kernel_metric(output_root, gt_root):
8 | '''
9 |     Calculate kernel metrics: gradient similarity, RMSE, PSNR, SSIM
10 |     Files in output_root and gt_root must correspond one-to-one in sorted order
11 | '''
12 |
13 | metric_dict = {
14 | 'gradient_similarity': [],
15 | 'rmse': [],
16 | 'psnr': [],
17 | 'ssim': [],
18 | }
19 |
20 | output_kernel_list = sorted(listdir(output_root))
21 | gt_kernel_list = sorted(listdir(gt_root))
22 | for o_k, g_k in zip(output_kernel_list, gt_kernel_list):
23 | o_k_path = os.path.join(output_root, o_k)
24 | g_k_path = os.path.join(gt_root, g_k)
25 | kernel_GT = load_mat_kernel(g_k_path)
26 | kernel_Gen = load_mat_kernel(o_k_path)
27 |
28 | gradient_similarity = kernel_base.Gradient_Similarity(kernel_Gen, kernel_GT)
29 | rmse, psnr, ssim = kernel_base.Kernel_RMSE_PSNR_SSIM(kernel_Gen, kernel_GT)
30 | metric_dict['gradient_similarity'].append(gradient_similarity)
31 | metric_dict['rmse'].append(rmse)
32 | metric_dict['psnr'].append(psnr)
33 | metric_dict['ssim'].append(ssim)
34 |
35 | print("{} Gradient-Similarity={:.4}, RMSE={:.4}, PSNR={:.4}, SSIM={:.4}".format(
36 | o_k, gradient_similarity, rmse, psnr, ssim
37 | ))
38 |
39 | log = 'Average Gradient-Similarity={:.4}, RMSE={:.4}, PSNR={:.4}, SSIM={:.4}'.format(
40 | sum(metric_dict['gradient_similarity']) / len(metric_dict['gradient_similarity']),
41 | sum(metric_dict['rmse']) / len(metric_dict['rmse']),
42 | sum(metric_dict['psnr']) / len(metric_dict['psnr']),
43 | sum(metric_dict['ssim']) / len(metric_dict['ssim'])
44 | )
45 | print(log)
46 |
47 | return metric_dict, log
48 |
49 |
50 | def calc_kernel_metric_video(output_root, gt_root):
51 | '''
52 | 计算视频的 kernel 的 metric: gradient similarity, rmse
53 | 要求 output_root, gt_root 中的文件按顺序一一对应
54 | '''
55 | gradient_similarity_sum = 0.
56 | rmse_sum = 0.
57 | psnr_sum = 0.
58 | ssim_sum = 0.
59 | kernel_num = 0
60 |
61 | video_metric = []
62 |
63 | video_list = sorted(listdir(output_root))
64 | for v in video_list:
65 | v_metric_list, _ = calc_kernel_metric(
66 | output_root=os.path.join(output_root, v),
67 | gt_root=os.path.join(gt_root, v)
68 | )
69 | gradient_similarity_sum += sum(v_metric_list['gradient_similarity'])
70 | rmse_sum += sum(v_metric_list['rmse'])
71 | psnr_sum += sum(v_metric_list['psnr'])
72 | ssim_sum += sum(v_metric_list['ssim'])
73 | kernel_num += len(v_metric_list['gradient_similarity'])
74 |
75 | video_metric.append({
76 | 'video_name': v,
77 | 'gradient_similarity': v_metric_list['gradient_similarity'],
78 | 'rmse': v_metric_list['rmse'],
79 | 'psnr': v_metric_list['psnr'],
80 | 'ssim': v_metric_list['ssim']
81 | })
82 |
83 | logs = []
84 | for v_m in video_metric:
85 | log = 'Video: {} Gradient-Similarity={:.4}, RMSE={:.4}, PSNR={:.4}, SSIM={:.4}'.format(
86 | v_m['video_name'],
87 | sum(v_m['gradient_similarity']) / len(v_m['gradient_similarity']),
88 | sum(v_m['rmse']) / len(v_m['rmse']),
89 | sum(v_m['psnr']) / len(v_m['psnr']),
90 | sum(v_m['ssim']) / len(v_m['ssim'])
91 | )
92 | print(log)
93 | logs.append(log)
94 | log = 'Average Gradient-Similarity={:.4}, RMSE={:.4}, PSNR={:.4}, SSIM={:.4}'.format(
95 | gradient_similarity_sum / kernel_num,
96 | rmse_sum / kernel_num,
97 | psnr_sum / kernel_num,
98 | ssim_sum / kernel_num
99 | )
100 | print(log)
101 | logs.append(log)
102 |
103 | return video_metric, logs
104 |
105 |
106 | def batch_calc_kernel_metric(root_list, video_type=False):
107 | '''
108 | required params:
109 | root_list: a list, each item should be a dictionary that given two key-values:
110 | output: the dir of output images
111 | gt: the dir of gt images
112 | return:
113 | log_list: a list, each item is a dictionary that given two key-values:
114 | data_path: the evaluated dir
115 | log: the log of this dir
116 | '''
117 | log_list = []
118 | for i, root in enumerate(root_list):
119 | ouput_root = root['output']
120 | gt_root = root['gt']
121 | print(">>>> Now Evaluation >>>>")
122 | print(">>>> OUTPUT: {}".format(ouput_root))
123 | print(">>>> GT: {}".format(gt_root))
124 | if video_type:
125 | _, log = calc_kernel_metric_video(ouput_root, gt_root)
126 | else:
127 | _, log = calc_kernel_metric(ouput_root, gt_root)
128 | log_list.append({
129 | 'data_path': ouput_root,
130 | 'log': log
131 | })
132 |
133 | print("--------------------------------------------------------------------------------------")
134 | for i, log in enumerate(log_list):
135 | print("## The {}-th:".format(i))
136 | print(">> ", log['data_path'])
137 | if video_type:
138 | for lo in log['log']:
139 | print(">> ", lo)
140 | else:
141 | print(">> ", log['log'])
142 |
143 | return log_list
144 |
--------------------------------------------------------------------------------
/utils/kernel_utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import cv2
3 | from base.os_base import handle_dir, get_fname_ext, listdir
4 | from base.kernel_base import kernel2png, load_mat_kernel
5 |
6 |
7 | def save_kernels_as_png(ori_root, dest_root, filename_template="{}.png"):
8 | '''
9 | function:
10 | convert kernel for saving as png
11 | params:
12 | ori_root: string, the dir of kernels that need to be processed
13 | dest_root: string, the dir to save processed kernels
14 | filename_template: string, the filename template for saving kernels
15 | '''
16 |
17 | handle_dir(dest_root)
18 | kernels_fname = sorted(listdir(ori_root))
19 | for kerf in kernels_fname:
20 | ker = load_mat_kernel(os.path.join(ori_root, kerf))
21 | ker_png = kernel2png(ker)
22 | cv2.imwrite(os.path.join(dest_root, filename_template.format(get_fname_ext(kerf)[0])), ker_png)
23 | print("Kernel", kerf, "save done !")
24 |
--------------------------------------------------------------------------------
/utils/plot_utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import matplotlib.pyplot as plt
3 |
4 |
5 | def get_max(arr):
6 | new_arr = np.zeros(arr.shape)
7 | for i in range(arr.shape[0]):
8 | new_arr[i] = np.max(arr[:i + 1])
9 | return new_arr
10 |
11 |
12 | def get_min(arr):
13 | new_arr = np.zeros(arr.shape)
14 | for i in range(arr.shape[0]):
15 | new_arr[i] = np.min(arr[:i + 1])
16 | return new_arr
17 |
18 |
19 | def plot_curve(data, title, xlabel='Epochs', ylabel=None, save=None):
20 | '''
21 | function:
22 | plot curve
23 | required params:
24 | data: the data of curve, numpy.array, ndim=1
25 | title: the title for each curve
26 | optional params:
27 | xlabel: the flag of x-axis
28 | ylabel: the flag of y-axis, if None: ylabel=title
29 | save:
30 | if None: just show figure
31 | else: the path to save figure, the figure will be saved
32 | '''
33 | length = data.shape[0]
34 | axis = np.linspace(1, length, length)
35 | fig = plt.figure()
36 | plt.title('{} Graph'.format(title))
37 |     plt.plot(axis, data, label=title)
38 |     plt.legend()
39 | plt.xlabel(xlabel)
40 | if not ylabel:
41 | ylabel = title
42 | plt.ylabel(ylabel)
43 | plt.grid(True)
44 |
45 | if save:
46 | plt.savefig(save)
47 | plt.close(fig)
48 | else:
49 | plt.show()
50 |
51 |
52 | def plot_multi_curve(array_list, label_list, title='compare', xlabel='Epochs', ylabel=None, save=None):
53 | '''
54 | function:
55 | plot multi curve in one figure
56 | required params:
57 | array_list: the data of curves, a list
58 | each item should be numpy.array, ndim=1
59 | label_list: list, labels for each curve
60 | optional params:
61 | title: the title of figure
62 | xlabel: the flag of x-axis
63 | ylabel: the flag of y-axis, if None: ylabel=title
64 | save:
65 | if None: just show figure
66 | else: the path to save figure, the figure will be saved
67 | '''
68 | assert len(array_list) == len(label_list), "length not equal"
69 |
70 | fig = plt.figure()
71 | plt.title(title)
72 |
73 | # plot curves
74 | for i in range(len(array_list)):
75 | length = array_list[i].shape[0]
76 | axis = np.linspace(1, length, length)
77 | plt.plot(axis, array_list[i], label=label_list[i])
78 |
79 | plt.legend()
80 | plt.xlabel(xlabel)
81 | if not ylabel:
82 | ylabel = title
83 | plt.ylabel(ylabel)
84 | plt.grid(True)
85 |
86 | if save:
87 | plt.savefig(save)
88 | plt.close(fig)
89 | else:
90 | plt.show()
91 |
92 |
93 | def plot_multi_curve_given_axis(array_list, label_list, axis, title='compare', xlabel='Epochs', ylabel=None, save=None):
94 | '''
95 | function:
96 | plot multi curve in one figure
97 | required params:
98 | array_list: the data of curves, a list
99 | each item should be numpy.array, ndim=1
100 | label_list: list, labels for each curve
101 | axis: list, x-axis flags
102 | optional params:
103 | title: the title of figure
104 | xlabel: the flag of x-axis
105 | ylabel: the flag of y-axis, if None: ylabel=title
106 | save:
107 | if None: just show figure
108 | else: the path to save figure, the figure will be saved
109 | '''
110 | assert len(array_list) == len(label_list), "length not equal"
111 |
112 | fig = plt.figure()
113 | plt.title(title)
114 |
115 | # plot curves
116 | for i in range(len(array_list)):
117 | plt.plot(axis, array_list[i], label=label_list[i])
118 |
119 | plt.legend()
120 | plt.xlabel(xlabel)
121 | if not ylabel:
122 | ylabel = title
123 | plt.ylabel(ylabel)
124 | plt.grid(True)
125 |
126 | if save:
127 | plt.savefig(save)
128 | plt.close(fig)
129 | else:
130 | plt.show()
131 |
--------------------------------------------------------------------------------
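A minimal usage sketch for the helpers above (assuming the repository root is on PYTHONPATH; the loss curves and the save path are made up for illustration):

```python
import numpy as np
from utils.plot_utils import plot_curve, plot_multi_curve

# two hypothetical training-loss curves over 100 epochs
loss_a = np.exp(-np.linspace(0, 5, 100))
loss_b = np.exp(-np.linspace(0, 4, 100))

plot_curve(loss_a, title='Loss')  # show a single curve interactively
plot_multi_curve([loss_a, loss_b], ['model-A', 'model-B'],
                 title='Loss compare', save='./loss_compare.png')  # save instead of show
```

--------------------------------------------------------------------------------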
/utils/string_utils.py:
--------------------------------------------------------------------------------
1 | def LCS(str1, str2):
2 |     '''
3 |     Longest Common Subsequence (LCS) via dynamic programming; returns (length, subsequence)
4 |     '''
5 | num_res = [[0 for _ in range(len(str2) + 1)] for _ in range(len(str1) + 1)]
6 | char_res = [['' for _ in range(len(str2) + 1)] for _ in range(len(str1) + 1)]
7 | for i in range(1, len(str1) + 1):
8 | for j in range(1, len(str2) + 1):
9 | if str1[i - 1] == str2[j - 1]:
10 | num_res[i][j] = num_res[i - 1][j - 1] + 1
11 | char_res[i][j] = char_res[i - 1][j - 1] + str1[i - 1]
12 | else:
13 | if num_res[i][j - 1] > num_res[i - 1][j]:
14 | num_res[i][j] = num_res[i][j - 1]
15 | char_res[i][j] = char_res[i][j - 1]
16 | else:
17 | num_res[i][j] = num_res[i - 1][j]
18 | char_res[i][j] = char_res[i - 1][j]
19 | return num_res[-1][-1], char_res[-1][-1]
20 |
--------------------------------------------------------------------------------
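A quick sketch of LCS on the classic textbook pair (the printed subsequence depends on how ties are broken, but its length is always 4):

```python
from utils.string_utils import LCS

length, subseq = LCS('ABCBDAB', 'BDCABA')
print(length)  # 4
print(subseq)  # one longest common subsequence of length 4
```

--------------------------------------------------------------------------------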
/utils/torchstat/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | *.egg-info/
24 | .installed.cfg
25 | *.egg
26 | MANIFEST
27 |
28 | # PyInstaller
29 | # Usually these files are written by a python script from a template
30 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
31 | *.manifest
32 | *.spec
33 |
34 | # Installer logs
35 | pip-log.txt
36 | pip-delete-this-directory.txt
37 |
38 | # Unit test / coverage reports
39 | htmlcov/
40 | .tox/
41 | .coverage
42 | .coverage.*
43 | .cache
44 | nosetests.xml
45 | coverage.xml
46 | *.cover
47 | .hypothesis/
48 | .pytest_cache/
49 |
50 | # Translations
51 | *.mo
52 | *.pot
53 |
54 | # Django stuff:
55 | *.log
56 | local_settings.py
57 | db.sqlite3
58 |
59 | # Flask stuff:
60 | instance/
61 | .webassets-cache
62 |
63 | # Scrapy stuff:
64 | .scrapy
65 |
66 | # Sphinx documentation
67 | docs/_build/
68 |
69 | # PyBuilder
70 | target/
71 |
72 | # Jupyter Notebook
73 | .ipynb_checkpoints
74 |
75 | # pyenv
76 | .python-version
77 |
78 | # celery beat schedule file
79 | celerybeat-schedule
80 |
81 | # SageMath parsed files
82 | *.sage.py
83 |
84 | # Environments
85 | .env
86 | .venv
87 | env/
88 | venv/
89 | ENV/
90 | env.bak/
91 | venv.bak/
92 |
93 | # Spyder project settings
94 | .spyderproject
95 | .spyproject
96 |
97 | # Rope project settings
98 | .ropeproject
99 |
100 | # mkdocs documentation
101 | /site
102 |
103 | # mypy
104 | .mypy_cache/
105 | .pypirc
106 |
--------------------------------------------------------------------------------
/utils/torchstat/.travis.yml:
--------------------------------------------------------------------------------
1 | os: linux
2 | sudo: false
3 | language: python
4 | python:
5 | - "3.6"
6 |
7 | install:
8 | - pip install pycodestyle
9 |
10 | script:
11 | - pycodestyle torchstat/
12 |
13 | cache:
14 | - pip
15 |
16 | notifications:
17 | email: false
18 |
--------------------------------------------------------------------------------
/utils/torchstat/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2018 Swall0w - Alan
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/utils/torchstat/README.md:
--------------------------------------------------------------------------------
1 | [](https://travis-ci.org/Swall0w/torchstat)
2 |
3 | # torchstat
4 | This is a lightweight neural network analyzer based on PyTorch.
5 | It is designed to make building your networks quick and easy, with the ability to debug them.
6 | **Note**: This repository is currently under development. Therefore, some APIs might be changed.
7 |
8 | This tool can show
9 |
10 | * Total number of network parameters
11 | * Theoretical number of floating-point operations (FLOPs)
12 | * Theoretical number of multiply-adds (MAdd)
13 | * Memory usage
14 |
15 | ## Installing
16 | There are two ways to install torchstat into your environment.
17 | * Install it via pip.
18 | ```bash
19 | $ pip install torchstat
20 | ```
21 |
22 | * Install and update using **setup.py** after cloning this repository.
23 | ```bash
24 | $ python3 setup.py install
25 | ```
26 |
27 | ## A Simple Example
28 | If you want to run torchstat right away, you can call it as a CLI tool when your network is defined in a script.
29 | Otherwise, you need to import torchstat as a module.
30 |
31 | ### CLI tool
32 | ```bash
33 | $ torchstat -f example.py -m Net
34 | [MAdd]: Dropout2d is not supported!
35 | [Flops]: Dropout2d is not supported!
36 | [Memory]: Dropout2d is not supported!
37 | module name input shape output shape params memory(MB) MAdd Flops MemRead(B) MemWrite(B) duration[%] MemR+W(B)
38 | 0 conv1 3 224 224 10 220 220 760.0 1.85 72,600,000.0 36,784,000.0 605152.0 1936000.0 57.49% 2541152.0
39 | 1 conv2 10 110 110 20 106 106 5020.0 0.86 112,360,000.0 56,404,720.0 504080.0 898880.0 26.62% 1402960.0
40 | 2 conv2_drop 20 106 106 20 106 106 0.0 0.86 0.0 0.0 0.0 0.0 4.09% 0.0
41 | 3 fc1 56180 50 2809050.0 0.00 5,617,950.0 2,809,000.0 11460920.0 200.0 11.58% 11461120.0
42 | 4 fc2 50 10 510.0 0.00 990.0 500.0 2240.0 40.0 0.22% 2280.0
43 | total                                                      2815340.0       3.56  190,578,940.0  95,998,220.0  12572392.0    2835120.0     100.00%  15407512.0
44 | ===============================================================================================================================================
45 | Total params: 2,815,340
46 | -----------------------------------------------------------------------------------------------------------------------------------------------
47 | Total memory: 3.56MB
48 | Total MAdd: 190.58MMAdd
49 | Total Flops: 96.0MFlops
50 | Total MemR+W: 14.69MB
51 | ```
52 |
53 | If you're not sure how to use a specific command, run the command with the -h or --help switch.
54 | You'll see usage information and a list of options you can use with the command.
55 |
56 | ### Module
57 | ```python
58 | from torchstat import stat
59 | import torchvision.models as models
60 |
61 | model = models.resnet18()
62 | stat(model, (3, 224, 224))
63 | ```
64 |
65 | ## Features & TODO
66 | **Note**: These features work only for nn.Module subclasses. Functions in torch.nn.functional are not supported yet.
67 | - [x] FLOPs
68 | - [x] Number of Parameters
69 | - [x] Total memory
70 | - [x] MAdd (FMA)
71 | - [x] MemRead
72 | - [x] MemWrite
73 | - [ ] Model summary(detail, layer-wise)
74 | - [ ] Export score table
75 | - [ ] Arbitrary input shape
76 |
77 | For the supported layers, check out [the details](./detail.md).
78 |
79 |
80 | ## Requirements
81 | * Python 3.6+
82 | * PyTorch 0.4.0+
83 | * Pandas 0.23.4+
84 | * NumPy 1.14.3+
85 |
86 | ## References
87 | Thanks to @sovrasov for the initial version of flops computation, @ceykmc for the backbone of scripts.
88 | * [flops-counter.pytorch](https://github.com/sovrasov/flops-counter.pytorch)
89 | * [pytorch_model_summary](https://github.com/ceykmc/pytorch_model_summary)
90 | * [chainer_computational_cost](https://github.com/belltailjp/chainer_computational_cost)
91 | * [convnet-burden](https://github.com/albanie/convnet-burden).
92 |
--------------------------------------------------------------------------------
/utils/torchstat/detail.md:
--------------------------------------------------------------------------------
1 | # Supported Layers
2 | |Layer|Flops|Madd|MemRead|MemWrite|
3 | |---|---|---|---|---|
4 | |Conv2d|ok|ok|ok| ok|
5 | |ConvTranspose2d| |ok|||
6 | |BatchNorm2d|ok|ok|ok|ok|
7 | |Linear|ok|ok|ok|ok|
8 | |UpSample|ok| |||
9 | |AvgPool2d|ok|ok|ok|ok|
10 | |MaxPool2d|ok|ok|ok|ok|
11 | |ReLU|ok|ok|ok|ok|
12 | |ReLU6|ok|ok|ok|ok|
13 | |LeakyReLU|ok||ok|ok|
14 | |PReLU|ok|ok|ok|ok|
15 | |ELU|ok|ok|ok|ok|
16 |
--------------------------------------------------------------------------------
/utils/torchstat/example.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from torchstat import stat
5 |
6 |
7 | class Net(nn.Module):
8 | def __init__(self):
9 | super(Net, self).__init__()
10 | self.conv1 = nn.Conv2d(3, 10, kernel_size=5)
11 | self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
12 | self.conv2_drop = nn.Dropout2d()
13 | self.fc1 = nn.Linear(56180, 50)
14 | self.fc2 = nn.Linear(50, 10)
15 |
16 | def forward(self, x):
17 | x = F.relu(F.max_pool2d(self.conv1(x), 2))
18 | x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
19 | x = x.view(-1, 56180)
20 | x = F.relu(self.fc1(x))
21 | x = F.dropout(x, training=self.training)
22 | x = self.fc2(x)
23 | return F.log_softmax(x, dim=1)
24 |
25 |
26 | if __name__ == '__main__':
27 | model = Net()
28 | stat(model, (3, 224, 224))
29 |
--------------------------------------------------------------------------------
/utils/torchstat/requirements.txt:
--------------------------------------------------------------------------------
1 | torch
2 | numpy
3 | pandas
4 |
--------------------------------------------------------------------------------
/utils/torchstat/setup.cfg:
--------------------------------------------------------------------------------
1 | [pycodestyle]
2 | count = False
3 | ignore = E226,E302,E41
4 | max-line-length = 160
5 | statistics = True
6 |
--------------------------------------------------------------------------------
/utils/torchstat/setup.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 |
3 | import re
4 | from os import path
5 |
6 | from setuptools import find_packages, setup
7 |
8 | package_name = "torchstat"
9 | root_dir = path.abspath(path.dirname(__file__))
10 |
11 |
12 | def _requirements():
13 | return [name.rstrip() for name in open(path.join(root_dir, 'requirements.txt'), encoding='utf-8').readlines()]
14 |
15 |
16 | def _test_requirements():
17 | return [name.rstrip() for name in open(path.join(root_dir, 'test_requirements.txt'), encoding='utf-8').readlines()]
18 |
19 |
20 | with open(path.join(root_dir, package_name, '__init__.py'), encoding='utf-8') as f:
21 | init_text = f.read()
22 | version = re.search(r'__version__ = [\'\"](.+?)[\'\"]', init_text).group(1)
23 | author = re.search(r'__author__ =\s*[\'\"](.+?)[\'\"]', init_text).group(1)
24 | url = re.search(r'__url__ =\s*[\'\"](.+?)[\'\"]', init_text).group(1)
25 |
26 | assert version
27 | assert author
28 | assert url
29 |
30 | with open('README.md', encoding='utf-8') as f:
31 | long_description = f.read()
32 |
33 | setup(
34 | name=package_name,
35 | version=version,
36 | description="torchstat: The Pytorch Model Analyzer.",
37 | long_description=long_description,
38 | long_description_content_type='text/markdown',
39 | author=author,
40 | url=url,
41 |
42 |     install_requires=_requirements(),
43 |     tests_require=_test_requirements(),
44 | include_package_data=True,
45 |
46 |     license='MIT',
47 |     packages=find_packages(exclude=('tests',)),
48 | test_suite='tests',
49 | entry_points="""
50 | [console_scripts]
51 | torchstat = torchstat.__main__:main
52 | """,
53 | )
54 |
--------------------------------------------------------------------------------
/utils/torchstat/test_requirements.txt:
--------------------------------------------------------------------------------
1 | pydocstyle
2 |
--------------------------------------------------------------------------------
/utils/torchstat/torchstat/__init__.py:
--------------------------------------------------------------------------------
1 | __copyright__ = 'Copyright (C) 2018 Swall0w'
2 | __version__ = '0.0.7'
3 | __author__ = 'Swall0w'
4 | __url__ = 'https://github.com/Swall0w/torchstat'
5 |
6 | from torchstat.compute_memory import compute_memory
7 | from torchstat.compute_madd import compute_madd
8 | from torchstat.compute_flops import compute_flops
9 | from torchstat.stat_tree import StatTree, StatNode
10 | from torchstat.model_hook import ModelHook
11 | from torchstat.reporter import report_format
12 | from torchstat.statistics import stat, ModelStat
13 |
14 | __all__ = ['report_format', 'StatTree', 'StatNode', 'compute_madd',
15 | 'compute_flops', 'ModelHook', 'stat', 'ModelStat', '__main__',
16 | 'compute_memory']
17 |
--------------------------------------------------------------------------------
/utils/torchstat/torchstat/__main__.py:
--------------------------------------------------------------------------------
1 | from torchstat import stat
2 | import argparse
3 | import importlib.util
4 | import torch
5 |
6 |
7 | def arg():
8 | parser = argparse.ArgumentParser(description='Torch model statistics')
9 | parser.add_argument('--file', '-f', type=str,
10 | help='Module file.')
11 | parser.add_argument('--model', '-m', type=str,
12 | help='Model name')
13 | parser.add_argument('--size', '-s', type=str, default='3x224x224',
14 | help='Input size. channels x height x width (default: 3x224x224)')
15 | return parser.parse_args()
16 |
17 |
18 | def main():
19 | args = arg()
20 | try:
21 | spec = importlib.util.spec_from_file_location('models', args.file)
22 | module = importlib.util.module_from_spec(spec)
23 | spec.loader.exec_module(module)
24 | model = getattr(module, args.model)()
25 | except Exception:
26 | import traceback
27 |         print(f'Tried to import {args.model} from {args.file}, but failed.')
28 | traceback.print_exc()
29 |
30 | import sys
31 | sys.exit()
32 |
33 | input_size = tuple(int(x) for x in args.size.split('x'))
34 | stat(model, input_size, query_granularity=1)
35 |
--------------------------------------------------------------------------------
/utils/torchstat/torchstat/compute_flops.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch
3 | import numpy as np
4 |
5 |
6 | def compute_flops(module, inp, out):
7 | if isinstance(module, nn.Conv2d):
8 | return compute_Conv2d_flops(module, inp, out)
9 | elif isinstance(module, nn.BatchNorm2d):
10 | return compute_BatchNorm2d_flops(module, inp, out)
11 | elif isinstance(module, (nn.AvgPool2d, nn.MaxPool2d)):
12 | return compute_Pool2d_flops(module, inp, out)
13 | elif isinstance(module, (nn.ReLU, nn.ReLU6, nn.PReLU, nn.ELU, nn.LeakyReLU)):
14 | return compute_ReLU_flops(module, inp, out)
15 | elif isinstance(module, nn.Upsample):
16 | return compute_Upsample_flops(module, inp, out)
17 | elif isinstance(module, nn.Linear):
18 | return compute_Linear_flops(module, inp, out)
19 | else:
20 | print(f"[Flops]: {type(module).__name__} is not supported!")
21 | return 0
22 |
23 |
24 |
25 | def compute_Conv2d_flops(module, inp, out):
26 | # Can have multiple inputs, getting the first one
27 | assert isinstance(module, nn.Conv2d)
28 | assert len(inp.size()) == 4 and len(inp.size()) == len(out.size())
29 |
30 | batch_size = inp.size()[0]
31 | in_c = inp.size()[1]
32 | k_h, k_w = module.kernel_size
33 | out_c, out_h, out_w = out.size()[1:]
34 | groups = module.groups
35 |
36 | filters_per_channel = out_c // groups
37 | conv_per_position_flops = k_h * k_w * in_c * filters_per_channel
38 | active_elements_count = batch_size * out_h * out_w
39 |
40 | total_conv_flops = conv_per_position_flops * active_elements_count
41 |
42 | bias_flops = 0
43 | if module.bias is not None:
44 | bias_flops = out_c * active_elements_count
45 |
46 | total_flops = total_conv_flops + bias_flops
47 | return total_flops
48 |
49 |
50 | def compute_BatchNorm2d_flops(module, inp, out):
51 | assert isinstance(module, nn.BatchNorm2d)
52 | assert len(inp.size()) == 4 and len(inp.size()) == len(out.size())
53 | in_c, in_h, in_w = inp.size()[1:]
54 | batch_flops = np.prod(inp.shape)
55 | if module.affine:
56 | batch_flops *= 2
57 | return batch_flops
58 |
59 |
60 | def compute_ReLU_flops(module, inp, out):
61 | assert isinstance(module, (nn.ReLU, nn.ReLU6, nn.PReLU, nn.ELU, nn.LeakyReLU))
62 | batch_size = inp.size()[0]
63 | active_elements_count = batch_size
64 |
65 | for s in inp.size()[1:]:
66 | active_elements_count *= s
67 |
68 | return active_elements_count
69 |
70 |
71 | def compute_Pool2d_flops(module, inp, out):
72 | assert isinstance(module, nn.MaxPool2d) or isinstance(module, nn.AvgPool2d)
73 | assert len(inp.size()) == 4 and len(inp.size()) == len(out.size())
74 | return np.prod(inp.shape)
75 |
76 |
77 | def compute_Linear_flops(module, inp, out):
78 | assert isinstance(module, nn.Linear)
79 | assert len(inp.size()) == 2 and len(out.size()) == 2
80 | batch_size = inp.size()[0]
81 | return batch_size * inp.size()[1] * out.size()[1]
82 |
83 | def compute_Upsample_flops(module, inp, out):
84 |     assert isinstance(module, nn.Upsample)
85 |     batch_size = inp.size()[0]
86 |     # one flop per output element: batch * all remaining dims of the output tensor
87 |     output_elements_count = batch_size
88 |     for s in out.size()[1:]:
89 |         output_elements_count *= s
90 |
91 |     return output_elements_count
92 |
--------------------------------------------------------------------------------
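A sanity check on compute_Conv2d_flops with an arbitrary small convolution (note that the function counts one FLOP per multiply-accumulate, not two):

```python
import torch
import torch.nn as nn
from torchstat.compute_flops import compute_Conv2d_flops

conv = nn.Conv2d(3, 8, kernel_size=3, padding=1, bias=True)
x = torch.rand(1, 3, 16, 16)
y = conv(x)  # -> (1, 8, 16, 16); padding=1 preserves the spatial size

# 3*3*3 MACs per filter position, 8*16*16 output elements, plus one bias add each
expected = (3 * 3 * 3) * (8 * 16 * 16) + (8 * 16 * 16)  # 55296 + 2048
assert compute_Conv2d_flops(conv, x, y) == expected == 57344
```

--------------------------------------------------------------------------------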
/utils/torchstat/torchstat/compute_madd.py:
--------------------------------------------------------------------------------
1 | """
2 | Compute the number of multiply-adds (MAdd) for each leaf module
3 | """
4 |
5 | import torch.nn as nn
6 |
7 |
8 | def compute_Conv2d_madd(module, inp, out):
9 | assert isinstance(module, nn.Conv2d)
10 | assert len(inp.size()) == 4 and len(inp.size()) == len(out.size())
11 |
12 | in_c = inp.size()[1]
13 | k_h, k_w = module.kernel_size
14 | out_c, out_h, out_w = out.size()[1:]
15 | groups = module.groups
16 |
17 | # ops per output element
18 | kernel_mul = k_h * k_w * (in_c // groups)
19 | kernel_add = kernel_mul - 1 + (0 if module.bias is None else 1)
20 |
21 | kernel_mul_group = kernel_mul * out_h * out_w * (out_c // groups)
22 | kernel_add_group = kernel_add * out_h * out_w * (out_c // groups)
23 |
24 | total_mul = kernel_mul_group * groups
25 | total_add = kernel_add_group * groups
26 |
27 | return total_mul + total_add
28 |
29 |
30 | def compute_ConvTranspose2d_madd(module, inp, out):
31 | assert isinstance(module, nn.ConvTranspose2d)
32 | assert len(inp.size()) == 4 and len(inp.size()) == len(out.size())
33 |
34 | in_c, in_h, in_w = inp.size()[1:]
35 | k_h, k_w = module.kernel_size
36 | out_c, out_h, out_w = out.size()[1:]
37 | groups = module.groups
38 |
39 | kernel_mul = k_h * k_w * (in_c // groups)
40 | kernel_add = kernel_mul - 1 + (0 if module.bias is None else 1)
41 |
42 | kernel_mul_group = kernel_mul * in_h * in_w * (out_c // groups)
43 | kernel_add_group = kernel_add * in_h * in_w * (out_c // groups)
44 |
45 | total_mul = kernel_mul_group * groups
46 | total_add = kernel_add_group * groups
47 |
48 | return total_mul + total_add
49 |
50 |
51 | def compute_BatchNorm2d_madd(module, inp, out):
52 | assert isinstance(module, nn.BatchNorm2d)
53 | assert len(inp.size()) == 4 and len(inp.size()) == len(out.size())
54 |
55 | in_c, in_h, in_w = inp.size()[1:]
56 |
57 | # 1. sub mean
58 | # 2. div standard deviation
59 | # 3. mul alpha
60 | # 4. add beta
61 | return 4 * in_c * in_h * in_w
62 |
63 |
64 | def compute_MaxPool2d_madd(module, inp, out):
65 | assert isinstance(module, nn.MaxPool2d)
66 | assert len(inp.size()) == 4 and len(inp.size()) == len(out.size())
67 |
68 | if isinstance(module.kernel_size, (tuple, list)):
69 | k_h, k_w = module.kernel_size
70 | else:
71 | k_h, k_w = module.kernel_size, module.kernel_size
72 | out_c, out_h, out_w = out.size()[1:]
73 |
74 | return (k_h * k_w - 1) * out_h * out_w * out_c
75 |
76 |
77 | def compute_AvgPool2d_madd(module, inp, out):
78 | assert isinstance(module, nn.AvgPool2d)
79 | assert len(inp.size()) == 4 and len(inp.size()) == len(out.size())
80 |
81 | if isinstance(module.kernel_size, (tuple, list)):
82 | k_h, k_w = module.kernel_size
83 | else:
84 | k_h, k_w = module.kernel_size, module.kernel_size
85 | out_c, out_h, out_w = out.size()[1:]
86 |
87 | kernel_add = k_h * k_w - 1
88 | kernel_avg = 1
89 |
90 | return (kernel_add + kernel_avg) * (out_h * out_w) * out_c
91 |
92 |
93 | def compute_ReLU_madd(module, inp, out):
94 | assert isinstance(module, (nn.ReLU, nn.ReLU6))
95 |
96 | count = 1
97 | for i in inp.size()[1:]:
98 | count *= i
99 | return count
100 |
101 |
102 | def compute_Softmax_madd(module, inp, out):
103 | assert isinstance(module, nn.Softmax)
104 | assert len(inp.size()) > 1
105 |
106 | count = 1
107 | for s in inp.size()[1:]:
108 | count *= s
109 | exp = count
110 | add = count - 1
111 | div = count
112 | return exp + add + div
113 |
114 |
115 | def compute_Linear_madd(module, inp, out):
116 | assert isinstance(module, nn.Linear)
117 | assert len(inp.size()) == 2 and len(out.size()) == 2
118 |
119 | num_in_features = inp.size()[1]
120 | num_out_features = out.size()[1]
121 |
122 | mul = num_in_features
123 | add = num_in_features - 1
124 | return num_out_features * (mul + add)
125 |
126 |
127 | def compute_Bilinear_madd(module, inp1, inp2, out):
128 | assert isinstance(module, nn.Bilinear)
129 | assert len(inp1.size()) == 2 and len(inp2.size()) == 2 and len(out.size()) == 2
130 |
131 | num_in_features_1 = inp1.size()[1]
132 | num_in_features_2 = inp2.size()[1]
133 | num_out_features = out.size()[1]
134 |
135 | mul = num_in_features_1 * num_in_features_2 + num_in_features_2
136 | add = num_in_features_1 * num_in_features_2 + num_in_features_2 - 1
137 | return num_out_features * (mul + add)
138 |
139 |
140 | def compute_madd(module, inp, out):
141 | if isinstance(module, nn.Conv2d):
142 | return compute_Conv2d_madd(module, inp, out)
143 | elif isinstance(module, nn.ConvTranspose2d):
144 | return compute_ConvTranspose2d_madd(module, inp, out)
145 | elif isinstance(module, nn.BatchNorm2d):
146 | return compute_BatchNorm2d_madd(module, inp, out)
147 | elif isinstance(module, nn.MaxPool2d):
148 | return compute_MaxPool2d_madd(module, inp, out)
149 | elif isinstance(module, nn.AvgPool2d):
150 | return compute_AvgPool2d_madd(module, inp, out)
151 | elif isinstance(module, (nn.ReLU, nn.ReLU6)):
152 | return compute_ReLU_madd(module, inp, out)
153 | elif isinstance(module, nn.Softmax):
154 | return compute_Softmax_madd(module, inp, out)
155 | elif isinstance(module, nn.Linear):
156 | return compute_Linear_madd(module, inp, out)
157 | elif isinstance(module, nn.Bilinear):
158 | return compute_Bilinear_madd(module, inp[0], inp[1], out)
159 | else:
160 | print(f"[MAdd]: {type(module).__name__} is not supported!")
161 | return 0
162 |
--------------------------------------------------------------------------------
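For intuition, compute_Linear_madd counts, per output feature, num_in multiplies and num_in - 1 adds (the bias add is not included for Linear). A quick check with a hypothetical 4-to-2 layer:

```python
import torch
import torch.nn as nn
from torchstat.compute_madd import compute_madd

fc = nn.Linear(4, 2)
x = torch.rand(1, 4)
y = fc(x)

# 2 outputs * (4 multiplies + 3 adds) = 14 MAdd
print(compute_madd(fc, x, y))  # 14
```

--------------------------------------------------------------------------------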
/utils/torchstat/torchstat/compute_memory.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch
3 | import numpy as np
4 |
5 |
6 | def compute_memory(module, inp, out):
7 | if isinstance(module, (nn.ReLU, nn.ReLU6, nn.ELU, nn.LeakyReLU)):
8 | return compute_ReLU_memory(module, inp, out)
9 | elif isinstance(module, nn.PReLU):
10 | return compute_PReLU_memory(module, inp, out)
11 | elif isinstance(module, nn.Conv2d):
12 | return compute_Conv2d_memory(module, inp, out)
13 | elif isinstance(module, nn.BatchNorm2d):
14 | return compute_BatchNorm2d_memory(module, inp, out)
15 | elif isinstance(module, nn.Linear):
16 | return compute_Linear_memory(module, inp, out)
17 | elif isinstance(module, (nn.AvgPool2d, nn.MaxPool2d)):
18 | return compute_Pool2d_memory(module, inp, out)
19 | else:
20 | print(f"[Memory]: {type(module).__name__} is not supported!")
21 | return (0, 0)
22 |
23 |
24 |
25 | def num_params(module):
26 | return sum(p.numel() for p in module.parameters() if p.requires_grad)
27 |
28 |
29 | def compute_ReLU_memory(module, inp, out):
30 | assert isinstance(module, (nn.ReLU, nn.ReLU6, nn.ELU, nn.LeakyReLU))
31 | batch_size = inp.size()[0]
32 | mread = batch_size * inp.size()[1:].numel()
33 | mwrite = batch_size * inp.size()[1:].numel()
34 |
35 | return (mread, mwrite)
36 |
37 |
38 | def compute_PReLU_memory(module, inp, out):
39 | assert isinstance(module, (nn.PReLU))
40 | batch_size = inp.size()[0]
41 | mread = batch_size * (inp.size()[1:].numel() + num_params(module))
42 | mwrite = batch_size * inp.size()[1:].numel()
43 |
44 | return (mread, mwrite)
45 |
46 |
47 | def compute_Conv2d_memory(module, inp, out):
48 | # Can have multiple inputs, getting the first one
49 | assert isinstance(module, nn.Conv2d)
50 | assert len(inp.size()) == 4 and len(inp.size()) == len(out.size())
51 |
52 | batch_size = inp.size()[0]
53 | in_c = inp.size()[1]
54 | out_c, out_h, out_w = out.size()[1:]
55 |
56 |     # This includes the weights, plus the bias if the module has one.
57 | mread = batch_size * (inp.size()[1:].numel() + num_params(module))
58 | mwrite = batch_size * out_c * out_h * out_w
59 | return (mread, mwrite)
60 |
61 |
62 | def compute_BatchNorm2d_memory(module, inp, out):
63 | assert isinstance(module, nn.BatchNorm2d)
64 | assert len(inp.size()) == 4 and len(inp.size()) == len(out.size())
65 | batch_size, in_c, in_h, in_w = inp.size()
66 |
67 | mread = batch_size * (inp.size()[1:].numel() + 2 * in_c)
68 | mwrite = inp.size().numel()
69 | return (mread, mwrite)
70 |
71 |
72 | def compute_Linear_memory(module, inp, out):
73 | assert isinstance(module, nn.Linear)
74 | assert len(inp.size()) == 2 and len(out.size()) == 2
75 | batch_size = inp.size()[0]
76 | mread = batch_size * (inp.size()[1:].numel() + num_params(module))
77 | mwrite = out.size().numel()
78 |
79 | return (mread, mwrite)
80 |
81 |
82 | def compute_Pool2d_memory(module, inp, out):
83 | assert isinstance(module, (nn.MaxPool2d, nn.AvgPool2d))
84 | assert len(inp.size()) == 4 and len(inp.size()) == len(out.size())
85 | batch_size = inp.size()[0]
86 | mread = batch_size * inp.size()[1:].numel()
87 | mwrite = batch_size * out.size()[1:].numel()
88 | return (mread, mwrite)
89 |
--------------------------------------------------------------------------------
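The memory functions return element counts (mread, mwrite); the model hook later multiplies them by the dtype's itemsize to get bytes. A small check against the Conv2d formula above:

```python
import torch
import torch.nn as nn
from torchstat.compute_memory import compute_memory

conv = nn.Conv2d(3, 8, kernel_size=3, padding=1)
x = torch.rand(1, 3, 16, 16)
mread, mwrite = compute_memory(conv, x, conv(x))

# read: input (3*16*16 = 768) + params (3*3*3*8 + 8 = 224); write: output (8*16*16)
print(mread, mwrite)  # 992 2048
```

--------------------------------------------------------------------------------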
/utils/torchstat/torchstat/model_hook.py:
--------------------------------------------------------------------------------
1 | import time
2 | from collections import OrderedDict
3 | import numpy as np
4 | import torch
5 | import torch.nn as nn
6 |
7 | from torchstat import compute_madd
8 | from torchstat import compute_flops
9 | from torchstat import compute_memory
10 |
11 |
12 | class ModelHook(object):
13 | def __init__(self, model, input_size):
14 | assert isinstance(model, nn.Module)
15 | assert isinstance(input_size, (list, tuple))
16 |
17 | self._model = model
18 | self._input_size = input_size
19 | self._origin_call = dict() # sub module call hook
20 |
21 | self._hook_model()
22 |         x = torch.rand(1, *self._input_size).to(next(self._model.parameters()).device)  # one forward pass on the model's own device to record per-module stats
23 | self._model.eval()
24 | self._model(x)
25 |
26 | @staticmethod
27 | def _register_buffer(module):
28 | assert isinstance(module, nn.Module)
29 |
30 | if len(list(module.children())) > 0:
31 | return
32 |
33 | module.register_buffer('input_shape', torch.zeros(3).int())
34 | module.register_buffer('output_shape', torch.zeros(3).int())
35 | module.register_buffer('parameter_quantity', torch.zeros(1).int())
36 | module.register_buffer('inference_memory', torch.zeros(1).long())
37 | module.register_buffer('MAdd', torch.zeros(1).long())
38 | module.register_buffer('duration', torch.zeros(1).float())
39 | module.register_buffer('Flops', torch.zeros(1).long())
40 | module.register_buffer('Memory', torch.zeros(2).long())
41 |
42 | def _sub_module_call_hook(self):
43 | def wrap_call(module, *input, **kwargs):
44 | assert module.__class__ in self._origin_call
45 |
46 | # Itemsize for memory
47 | itemsize = input[0].cpu().detach().numpy().itemsize
48 |
49 | start = time.time()
50 | output = self._origin_call[module.__class__](module, *input, **kwargs)
51 | end = time.time()
52 | module.duration = torch.from_numpy(
53 | np.array([end - start], dtype=np.float32))
54 |
55 | module.input_shape = torch.from_numpy(
56 | np.array(input[0].size()[1:], dtype=np.int32))
57 | module.output_shape = torch.from_numpy(
58 | np.array(output.size()[1:], dtype=np.int32))
59 |
60 | parameter_quantity = 0
61 | # iterate through parameters and count num params
62 | for name, p in module._parameters.items():
63 | parameter_quantity += (0 if p is None else torch.numel(p.data))
64 | module.parameter_quantity = torch.from_numpy(
65 |                 np.array([parameter_quantity], dtype=np.int64))  # np.long was removed in NumPy 1.24
66 |
67 | inference_memory = 1
68 | for s in output.size()[1:]:
69 | inference_memory *= s
70 | # memory += parameters_number # exclude parameter memory
71 | inference_memory = inference_memory * 4 / (1024 ** 2) # shown as MB unit
72 | module.inference_memory = torch.from_numpy(
73 | np.array([inference_memory], dtype=np.float32))
74 |
75 | if len(input) == 1:
76 | madd = compute_madd(module, input[0], output)
77 | flops = compute_flops(module, input[0], output)
78 | Memory = compute_memory(module, input[0], output)
79 | elif len(input) > 1:
80 | madd = compute_madd(module, input, output)
81 | flops = compute_flops(module, input, output)
82 | Memory = compute_memory(module, input, output)
83 | else: # error
84 | madd = 0
85 | flops = 0
86 | Memory = (0, 0)
87 | module.MAdd = torch.from_numpy(
88 | np.array([madd], dtype=np.int64))
89 | module.Flops = torch.from_numpy(
90 | np.array([flops], dtype=np.int64))
91 |             Memory = np.array(Memory, dtype=np.int64) * itemsize  # int64: byte counts can exceed the int32 range
92 | module.Memory = torch.from_numpy(Memory)
93 |
94 | return output
95 |
96 | for module in self._model.modules():
97 | if len(list(module.children())) == 0 and module.__class__ not in self._origin_call:
98 | self._origin_call[module.__class__] = module.__class__.__call__
99 | module.__class__.__call__ = wrap_call
100 |
101 | def _hook_model(self):
102 | self._model.apply(self._register_buffer)
103 | self._sub_module_call_hook()
104 |
105 | @staticmethod
106 | def _retrieve_leaf_modules(model):
107 | leaf_modules = []
108 | for name, m in model.named_modules():
109 | if len(list(m.children())) == 0:
110 | leaf_modules.append((name, m))
111 | return leaf_modules
112 |
113 | def retrieve_leaf_modules(self):
114 | return OrderedDict(self._retrieve_leaf_modules(self._model))
115 |
--------------------------------------------------------------------------------
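ModelHook is normally driven by ModelStat, but it can be used directly; constructing it runs one forward pass and fills the per-module stat buffers. A sketch (torchvision's resnet18 is just a convenient example model):

```python
import torchvision.models as models
from torchstat.model_hook import ModelHook

hook = ModelHook(models.resnet18(), (3, 224, 224))
leaf_modules = hook.retrieve_leaf_modules()  # OrderedDict: qualified name -> module

name, m = next(iter(leaf_modules.items()))
print(name, int(m.Flops), tuple(m.output_shape.tolist()))  # e.g. conv1 ... (64, 112, 112)
```

--------------------------------------------------------------------------------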
/utils/torchstat/torchstat/reporter.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 |
3 |
4 | pd.set_option('display.width', 1000)
5 | pd.set_option('display.max_rows', 10000)
6 | pd.set_option('display.max_columns', 10000)
7 |
8 |
9 | def round_value(value, binary=False):
10 | divisor = 1024. if binary else 1000.
11 |
12 | if value // divisor**3 > 0:
13 | return str(round(value / divisor**3, 2)) + 'G'
14 | elif value // divisor**2 > 0:
15 | return str(round(value / divisor**2, 2)) + 'M'
16 | elif value // divisor > 0:
17 | return str(round(value / divisor, 2)) + 'K'
18 | return str(value)
19 |
20 |
21 | def report_format(collected_nodes):
22 | data = list()
23 | for node in collected_nodes:
24 | name = node.name
25 | input_shape = ' '.join(['{:>3d}'] * len(node.input_shape)).format(
26 | *[e for e in node.input_shape])
27 | output_shape = ' '.join(['{:>3d}'] * len(node.output_shape)).format(
28 | *[e for e in node.output_shape])
29 | parameter_quantity = node.parameter_quantity
30 | inference_memory = node.inference_memory
31 | MAdd = node.MAdd
32 | Flops = node.Flops
33 | mread, mwrite = [i for i in node.Memory]
34 | duration = node.duration
35 | data.append([name, input_shape, output_shape, parameter_quantity,
36 | inference_memory, MAdd, duration, Flops, mread,
37 | mwrite])
38 | df = pd.DataFrame(data)
39 | df.columns = ['module name', 'input shape', 'output shape',
40 | 'params', 'memory(MB)',
41 | 'MAdd', 'duration', 'Flops', 'MemRead(B)', 'MemWrite(B)']
42 | df['duration[%]'] = df['duration'] / (df['duration'].sum() + 1e-7)
43 | df['MemR+W(B)'] = df['MemRead(B)'] + df['MemWrite(B)']
44 | total_parameters_quantity = df['params'].sum()
45 | total_memory = df['memory(MB)'].sum()
46 | total_operation_quantity = df['MAdd'].sum()
47 | total_flops = df['Flops'].sum()
48 | total_duration = df['duration[%]'].sum()
49 | total_mread = df['MemRead(B)'].sum()
50 | total_mwrite = df['MemWrite(B)'].sum()
51 | total_memrw = df['MemR+W(B)'].sum()
52 | del df['duration']
53 |
54 | # Add Total row
55 | total_df = pd.Series([total_parameters_quantity, total_memory,
56 | total_operation_quantity, total_flops,
57 |                           total_duration, total_mread, total_mwrite, total_memrw],
58 | index=['params', 'memory(MB)', 'MAdd', 'Flops', 'duration[%]',
59 | 'MemRead(B)', 'MemWrite(B)', 'MemR+W(B)'],
60 | name='total')
61 |     df = pd.concat([df, total_df.to_frame().T])  # DataFrame.append was removed in pandas 2.0
62 |
63 | df = df.fillna(' ')
64 | df['memory(MB)'] = df['memory(MB)'].apply(
65 | lambda x: '{:.2f}'.format(x))
66 | df['duration[%]'] = df['duration[%]'].apply(lambda x: '{:.2%}'.format(x))
67 | df['MAdd'] = df['MAdd'].apply(lambda x: '{:,}'.format(x))
68 | df['Flops'] = df['Flops'].apply(lambda x: '{:,}'.format(x))
69 |
70 | summary = str(df) + '\n'
71 | summary += "=" * len(str(df).split('\n')[0])
72 | summary += '\n'
73 | summary += "Total params: {:,}\n".format(total_parameters_quantity)
74 |
75 | summary += "-" * len(str(df).split('\n')[0])
76 | summary += '\n'
77 | summary += "Total memory: {:.2f}MB\n".format(total_memory)
78 | summary += "Total MAdd: {}MAdd\n".format(round_value(total_operation_quantity))
79 | summary += "Total Flops: {}Flops\n".format(round_value(total_flops))
80 | summary += "Total MemR+W: {}B\n".format(round_value(total_memrw, True))
81 | return summary
82 |
--------------------------------------------------------------------------------
/utils/torchstat/torchstat/stat_tree.py:
--------------------------------------------------------------------------------
1 | import queue
2 |
3 |
4 | class StatTree(object):
5 | def __init__(self, root_node):
6 | assert isinstance(root_node, StatNode)
7 |
8 | self.root_node = root_node
9 |
10 | def get_same_level_max_node_depth(self, query_node):
11 | if query_node.name == self.root_node.name:
12 | return 0
13 | same_level_depth = max([child.depth for child in query_node.parent.children])
14 | return same_level_depth
15 |
16 | def update_stat_nodes_granularity(self):
17 | q = queue.Queue()
18 | q.put(self.root_node)
19 | while not q.empty():
20 | node = q.get()
21 | node.granularity = self.get_same_level_max_node_depth(node)
22 | for child in node.children:
23 | q.put(child)
24 |
25 | def get_collected_stat_nodes(self, query_granularity):
26 | self.update_stat_nodes_granularity()
27 |
28 | collected_nodes = []
29 | stack = list()
30 | stack.append(self.root_node)
31 | while len(stack) > 0:
32 | node = stack.pop()
33 | for child in reversed(node.children):
34 | stack.append(child)
35 | if node.depth == query_granularity:
36 | collected_nodes.append(node)
37 | if node.depth < query_granularity <= node.granularity:
38 | collected_nodes.append(node)
39 | return collected_nodes
40 |
41 |
42 | class StatNode(object):
43 | def __init__(self, name=str(), parent=None):
44 | self._name = name
45 | self._input_shape = None
46 | self._output_shape = None
47 | self._parameter_quantity = 0
48 | self._inference_memory = 0
49 | self._MAdd = 0
50 | self._Memory = (0, 0)
51 | self._Flops = 0
52 | self._duration = 0
53 | self._duration_percent = 0
54 |
55 | self._granularity = 1
56 | self._depth = 1
57 | self.parent = parent
58 | self.children = list()
59 |
60 | @property
61 | def name(self):
62 | return self._name
63 |
64 | @name.setter
65 | def name(self, name):
66 | self._name = name
67 |
68 | @property
69 | def granularity(self):
70 | return self._granularity
71 |
72 | @granularity.setter
73 | def granularity(self, g):
74 | self._granularity = g
75 |
76 | @property
77 | def depth(self):
78 | d = self._depth
79 | if len(self.children) > 0:
80 | d += max([child.depth for child in self.children])
81 | return d
82 |
83 | @property
84 | def input_shape(self):
85 | if len(self.children) == 0: # leaf
86 | return self._input_shape
87 | else:
88 | return self.children[0].input_shape
89 |
90 | @input_shape.setter
91 | def input_shape(self, input_shape):
92 | assert isinstance(input_shape, (list, tuple))
93 | self._input_shape = input_shape
94 |
95 | @property
96 | def output_shape(self):
97 | if len(self.children) == 0: # leaf
98 | return self._output_shape
99 | else:
100 | return self.children[-1].output_shape
101 |
102 | @output_shape.setter
103 | def output_shape(self, output_shape):
104 | assert isinstance(output_shape, (list, tuple))
105 | self._output_shape = output_shape
106 |
107 | @property
108 | def parameter_quantity(self):
109 | # return self.parameters_quantity
110 | total_parameter_quantity = self._parameter_quantity
111 | for child in self.children:
112 | total_parameter_quantity += child.parameter_quantity
113 | return total_parameter_quantity
114 |
115 | @parameter_quantity.setter
116 | def parameter_quantity(self, parameter_quantity):
117 | assert parameter_quantity >= 0
118 | self._parameter_quantity = parameter_quantity
119 |
120 | @property
121 | def inference_memory(self):
122 | total_inference_memory = self._inference_memory
123 | for child in self.children:
124 | total_inference_memory += child.inference_memory
125 | return total_inference_memory
126 |
127 | @inference_memory.setter
128 | def inference_memory(self, inference_memory):
129 | self._inference_memory = inference_memory
130 |
131 | @property
132 | def MAdd(self):
133 | total_MAdd = self._MAdd
134 | for child in self.children:
135 | total_MAdd += child.MAdd
136 | return total_MAdd
137 |
138 | @MAdd.setter
139 | def MAdd(self, MAdd):
140 | self._MAdd = MAdd
141 |
142 | @property
143 | def Flops(self):
144 | total_Flops = self._Flops
145 | for child in self.children:
146 | total_Flops += child.Flops
147 | return total_Flops
148 |
149 | @Flops.setter
150 | def Flops(self, Flops):
151 | self._Flops = Flops
152 |
153 | @property
154 | def Memory(self):
155 |         total_Memory = list(self._Memory)  # copy to a list: the stored value may be an immutable tuple
156 | for child in self.children:
157 | total_Memory[0] += child.Memory[0]
158 | total_Memory[1] += child.Memory[1]
159 |         # (mread, mwrite) totals over this node and all of its children
160 | return total_Memory
161 |
162 | @Memory.setter
163 | def Memory(self, Memory):
164 | assert isinstance(Memory, (list, tuple))
165 | self._Memory = Memory
166 |
167 | @property
168 | def duration(self):
169 | total_duration = self._duration
170 | for child in self.children:
171 | total_duration += child.duration
172 | return total_duration
173 |
174 | @duration.setter
175 | def duration(self, duration):
176 | self._duration = duration
177 |
178 | def find_child_index(self, child_name):
179 | assert isinstance(child_name, str)
180 |
181 | index = -1
182 | for i in range(len(self.children)):
183 | if child_name == self.children[i].name:
184 | index = i
185 | return index
186 |
187 | def add_child(self, node):
188 | assert isinstance(node, StatNode)
189 |
190 | if self.find_child_index(node.name) == -1: # not exist
191 | self.children.append(node)
192 |
--------------------------------------------------------------------------------
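A toy StatTree to show how node stats aggregate upward and how get_collected_stat_nodes picks nodes at a given granularity (the names and the Flops value are made up):

```python
from torchstat.stat_tree import StatNode, StatTree

root = StatNode(name='root')
leaf = StatNode(name='root.conv1', parent=root)
root.add_child(leaf)
leaf.Flops = 1000  # parent properties sum over children, so root.Flops is also 1000

tree = StatTree(root)
nodes = tree.get_collected_stat_nodes(query_granularity=1)
print([n.name for n in nodes], root.Flops)  # ['root.conv1'] 1000
```

--------------------------------------------------------------------------------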
/utils/torchstat/torchstat/statistics.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torchstat import ModelHook
4 | from collections import OrderedDict
5 | from torchstat import StatTree, StatNode, report_format
6 |
7 |
8 | def get_parent_node(root_node, stat_node_name):
9 | assert isinstance(root_node, StatNode)
10 |
11 | node = root_node
12 | names = stat_node_name.split('.')
13 | for i in range(len(names) - 1):
14 | node_name = '.'.join(names[0:i+1])
15 | child_index = node.find_child_index(node_name)
16 | assert child_index != -1
17 | node = node.children[child_index]
18 | return node
19 |
20 |
21 | def convert_leaf_modules_to_stat_tree(leaf_modules):
22 | assert isinstance(leaf_modules, OrderedDict)
23 |
24 | create_index = 1
25 | root_node = StatNode(name='root', parent=None)
26 | for leaf_module_name, leaf_module in leaf_modules.items():
27 | names = leaf_module_name.split('.')
28 | for i in range(len(names)):
29 | create_index += 1
30 | stat_node_name = '.'.join(names[0:i+1])
31 | parent_node = get_parent_node(root_node, stat_node_name)
32 | node = StatNode(name=stat_node_name, parent=parent_node)
33 | parent_node.add_child(node)
34 | if i == len(names) - 1: # leaf module itself
35 | input_shape = leaf_module.input_shape.numpy().tolist()
36 | output_shape = leaf_module.output_shape.numpy().tolist()
37 | node.input_shape = input_shape
38 | node.output_shape = output_shape
39 | node.parameter_quantity = leaf_module.parameter_quantity.numpy()[0]
40 | node.inference_memory = leaf_module.inference_memory.numpy()[0]
41 | node.MAdd = leaf_module.MAdd.numpy()[0]
42 | node.Flops = leaf_module.Flops.numpy()[0]
43 | node.duration = leaf_module.duration.numpy()[0]
44 | node.Memory = leaf_module.Memory.numpy().tolist()
45 | return StatTree(root_node)
46 |
47 |
48 | class ModelStat(object):
49 | def __init__(self, model, input_size, query_granularity=1):
50 | assert isinstance(model, nn.Module)
51 | assert isinstance(input_size, (tuple, list)) and len(input_size) == 3
52 | self._model = model
53 | self._input_size = input_size
54 | self._query_granularity = query_granularity
55 |
56 | def _analyze_model(self):
57 | model_hook = ModelHook(self._model, self._input_size)
58 | leaf_modules = model_hook.retrieve_leaf_modules()
59 | stat_tree = convert_leaf_modules_to_stat_tree(leaf_modules)
60 | collected_nodes = stat_tree.get_collected_stat_nodes(self._query_granularity)
61 | return collected_nodes
62 |
63 | def show_report(self):
64 | collected_nodes = self._analyze_model()
65 | report = report_format(collected_nodes)
66 | print(report)
67 |
68 |
69 | def stat(model, input_size, query_granularity=1):
70 | ms = ModelStat(model, input_size, query_granularity)
71 | ms.show_report()
72 |
--------------------------------------------------------------------------------
/utils/video_metric_utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | from base.file_io_base import write_csv
4 | from base.os_base import listdir
5 | import utils.LPIPS.models as lpips_models
6 | from utils.image_metric_utils import calc_image_PSNR_SSIM, calc_image_LPIPS, calc_image_NIQE
7 |
8 |
9 | def calc_video_PSNR_SSIM(output_root, gt_root, crop_border=4, shift_window_size=0, test_ycbcr=False, crop_GT=False):
10 |     '''
11 |     Calculate video PSNR/SSIM, following EDVR's evaluation protocol.
12 |     Files in output_root and gt_root must correspond one-to-one in sorted order.
13 |     '''
14 |
15 | PSNR_sum = 0.
16 | SSIM_sum = 0.
17 | img_num = 0
18 |
19 | video_PSNR = []
20 | video_SSIM = []
21 |
22 | video_list = sorted(listdir(output_root))
23 | for v in video_list:
24 | v_PSNR_list, v_SSIM_list, _ = calc_image_PSNR_SSIM(
25 | output_root=os.path.join(output_root, v),
26 | gt_root=os.path.join(gt_root, v),
27 | crop_border=crop_border, shift_window_size=shift_window_size,
28 | test_ycbcr=test_ycbcr, crop_GT=crop_GT
29 | )
30 | PSNR_sum += sum(v_PSNR_list)
31 | SSIM_sum += sum(v_SSIM_list)
32 | img_num += len(v_PSNR_list)
33 |
34 | video_PSNR.append({
35 | 'video_name': v,
36 | 'psnr': v_PSNR_list
37 | })
38 | video_SSIM.append({
39 | 'video_name': v,
40 | 'ssim': v_SSIM_list
41 | })
42 |
43 | logs = []
44 | PSNR_SSIM_csv_log = {
45 | 'col_names': [],
46 | 'row_names': [output_root],
47 | 'psnr_ssim': [[]]
48 | }
49 | for v_psnr, v_ssim in zip(video_PSNR, video_SSIM):
50 | PSNR_SSIM_csv_log['col_names'].append('#{}'.format(v_psnr['video_name']))
51 | PSNR_SSIM_csv_log['psnr_ssim'][0].append('{:.5}/{:.4}'.format(sum(v_psnr['psnr']) / len(v_psnr['psnr']),
52 | sum(v_ssim['ssim']) / len(v_ssim['ssim'])))
53 | log = 'Video: {} PSNR={:.5}, SSIM={:.4}'.format(v_psnr['video_name'],
54 | sum(v_psnr['psnr']) / len(v_psnr['psnr']),
55 | sum(v_ssim['ssim']) / len(v_ssim['ssim']))
56 | print(log)
57 | logs.append(log)
58 | PSNR_SSIM_csv_log['col_names'].append('AVG')
59 | PSNR_SSIM_csv_log['psnr_ssim'][0].append('{:.5}/{:.4}'.format(PSNR_sum / img_num, SSIM_sum / img_num))
60 | log = 'Average PSNR={:.5}, SSIM={:.4}'.format(PSNR_sum / img_num, SSIM_sum / img_num)
61 | print(log)
62 | logs.append(log)
63 |
64 | return PSNR_SSIM_csv_log, logs
65 |
66 |
67 | def batch_calc_video_PSNR_SSIM(root_list, crop_border=4, shift_window_size=0, test_ycbcr=False, crop_GT=False,
68 | save_log=False, save_log_root=None, combine_save=False):
69 |     '''
70 |     required params:
71 |         root_list: a list, each item should be a dictionary with two keys:
72 |             output: the dir of output videos
73 |             gt: the dir of gt videos
74 |     optional params:
75 |         crop_border: default=4, pixels to crop from the border when calculating PSNR/SSIM
76 |         shift_window_size: default=0, if >0, shift the image within a window and keep the best metric
77 |         test_ycbcr: default=False, if True, compute metrics in the YCbCr color space
78 |         crop_GT: default=False, if True, crop GT to the output size
79 |         save_log: default=False, if True, save a csv log
80 |         save_log_root: the dir of the output log
81 |         combine_save: default=False, if True, combine all output logs into one csv file
82 |     return:
83 |         log_list: a list, each item is a dictionary with two keys:
84 |             data_path: the evaluated dir
85 |             log: the log of this dir
86 |     '''
87 | if save_log:
88 | assert save_log_root is not None, "Unknown save_log_root!"
89 |
90 | total_csv_log = []
91 | log_list = []
92 | for i, root in enumerate(root_list):
93 |         output_root = root['output']
94 | gt_root = root['gt']
95 | print(">>>> Now Evaluation >>>>")
96 | print(">>>> OUTPUT: {}".format(ouput_root))
97 | print(">>>> GT: {}".format(gt_root))
98 | csv_log, logs = calc_video_PSNR_SSIM(
99 |             output_root, gt_root, crop_border=crop_border, shift_window_size=shift_window_size,
100 | test_ycbcr=test_ycbcr, crop_GT=crop_GT
101 | )
102 | log_list.append({
103 |             'data_path': output_root,
104 | 'log': logs
105 | })
106 |
107 | # output the PSNR/SSIM log of each evaluated dir to a single csv file
108 | if save_log:
109 | csv_log['row_names'] = [os.path.basename(p) for p in csv_log['row_names']]
110 | write_csv(file_path=os.path.join(save_log_root, "{}_{}.csv".format(i, csv_log['row_names'][0])),
111 | data=np.array(csv_log['psnr_ssim']),
112 | row_names=csv_log['row_names'],
113 | col_names=csv_log['col_names'])
114 | total_csv_log.append(csv_log)
115 |
116 | # output all PSNR/SSIM log to a csv file
117 | if save_log and combine_save and len(total_csv_log) > 0:
118 | com_csv_log = {
119 | 'col_names': total_csv_log[0]['col_names'],
120 | 'row_names': [],
121 | 'psnr_ssim': []
122 | }
123 | for csv_log in total_csv_log:
124 | com_csv_log['row_names'].append(csv_log['row_names'][0])
125 | com_csv_log['psnr_ssim'].append(csv_log['psnr_ssim'][0])
126 | write_csv(file_path=os.path.join(save_log_root, "psnr_ssim.csv"),
127 | data=np.array(com_csv_log['psnr_ssim']),
128 | row_names=com_csv_log['row_names'],
129 | col_names=com_csv_log['col_names'])
130 |
131 | print("--------------------------------------------------------------------------------------")
132 | for i, logs in enumerate(log_list):
133 | print("## The {}-th:".format(i))
134 | print(">> ", logs['data_path'])
135 | for log in logs['log']:
136 | print(">> ", log)
137 |
138 | return log_list
139 |
140 |
141 | def calc_video_LPIPS(output_root, gt_root, model=None, use_gpu=False, spatial=True):
142 |     '''
143 |     Calculate video LPIPS.
144 |     Files in output_root and gt_root must correspond one-to-one in sorted order.
145 |     '''
146 |
147 | if model is None:
148 | model = lpips_models.PerceptualLoss(model='net-lin', net='alex', use_gpu=use_gpu, spatial=spatial)
149 |
150 | LPIPS_sum = 0.
151 | img_num = 0
152 |
153 | video_LPIPS = []
154 |
155 | video_list = sorted(listdir(output_root))
156 | for v in video_list:
157 | v_LPIPS_list, _ = calc_image_LPIPS(
158 | output_root=os.path.join(output_root, v),
159 | gt_root=os.path.join(gt_root, v),
160 | model=model, use_gpu=use_gpu, spatial=spatial
161 | )
162 | LPIPS_sum += sum(v_LPIPS_list)
163 | img_num += len(v_LPIPS_list)
164 |
165 | video_LPIPS.append({
166 | 'video_name': v,
167 | 'lpips': v_LPIPS_list
168 | })
169 |
170 | logs = []
171 | LPIPS_csv_log = {
172 | 'col_names': [],
173 | 'row_names': [output_root],
174 | 'lpips': [[]]
175 | }
176 | for v_lpips in video_LPIPS:
177 | LPIPS_csv_log['col_names'].append('#{}'.format(v_lpips['video_name']))
178 | LPIPS_csv_log['lpips'][0].append('{:.4}'.format(sum(v_lpips['lpips']) / len(v_lpips['lpips'])))
179 | log = 'Video: {} LPIPS={:.4}'.format(v_lpips['video_name'], sum(v_lpips['lpips']) / len(v_lpips['lpips']))
180 | print(log)
181 | logs.append(log)
182 | LPIPS_csv_log['col_names'].append('AVG')
183 | LPIPS_csv_log['lpips'][0].append('{:.4}'.format(LPIPS_sum / img_num))
184 | log = 'Average LPIPS={:.4}'.format(LPIPS_sum / img_num)
185 | print(log)
186 | logs.append(log)
187 |
188 | return LPIPS_csv_log, logs
189 |
190 |
191 | def batch_calc_video_LPIPS(root_list, use_gpu=False, spatial=True,
192 | save_log=False, save_log_root=None, combine_save=False):
193 |     '''
194 |     required params:
195 |         root_list: a list, each item should be a dictionary with two keys:
196 |             output: the dir of output videos
197 |             gt: the dir of gt videos
198 |     optional params:
199 |         use_gpu: default=False, if True, use the gpu
200 |         spatial: default=True, if True, return a spatial LPIPS map
201 |         save_log: default=False, if True, save a csv log
202 |         save_log_root: the dir of the output log
203 |         combine_save: default=False, if True, combine all output logs into one csv file
204 |     return:
205 |         log_list: a list, each item is a dictionary with two keys:
206 |             data_path: the evaluated dir
207 |             log: the log of this dir
208 |     '''
209 | if save_log:
210 | assert save_log_root is not None, "Unknown save_log_root!"
211 |
212 | model = lpips_models.PerceptualLoss(model='net-lin', net='alex', use_gpu=use_gpu, spatial=spatial)
213 |
214 | total_csv_log = []
215 | log_list = []
216 | for i, root in enumerate(root_list):
217 |         output_root = root['output']
218 | gt_root = root['gt']
219 | print(">>>> Now Evaluation >>>>")
220 | print(">>>> OUTPUT: {}".format(ouput_root))
221 | print(">>>> GT: {}".format(gt_root))
222 |         csv_log, logs = calc_video_LPIPS(output_root, gt_root, model=model, use_gpu=use_gpu, spatial=spatial)
223 | log_list.append({
224 |             'data_path': output_root,
225 | 'log': logs
226 | })
227 |
228 | # output the LPIPS log of each evaluated dir to a single csv file
229 | if save_log:
230 | csv_log['row_names'] = [os.path.basename(p) for p in csv_log['row_names']]
231 | write_csv(file_path=os.path.join(save_log_root, "{}_{}.csv".format(i, csv_log['row_names'][0])),
232 | data=np.array(csv_log['lpips']),
233 | row_names=csv_log['row_names'],
234 | col_names=csv_log['col_names'])
235 | total_csv_log.append(csv_log)
236 |
237 | # output all LPIPS log to a csv file
238 | if save_log and combine_save and len(total_csv_log) > 0:
239 | com_csv_log = {
240 | 'col_names': total_csv_log[0]['col_names'],
241 | 'row_names': [],
242 | 'lpips': []
243 | }
244 | for csv_log in total_csv_log:
245 | com_csv_log['row_names'].append(csv_log['row_names'][0])
246 | com_csv_log['lpips'].append(csv_log['lpips'][0])
247 | write_csv(file_path=os.path.join(save_log_root, "lpips.csv"),
248 | data=np.array(com_csv_log['lpips']),
249 | row_names=com_csv_log['row_names'],
250 | col_names=com_csv_log['col_names'])
251 |
252 | print("--------------------------------------------------------------------------------------")
253 | for i, logs in enumerate(log_list):
254 | print("## The {}-th:".format(i))
255 | print(">> ", logs['data_path'])
256 | for log in logs['log']:
257 | print(">> ", log)
258 |
259 | return log_list
260 |
261 |
262 | def calc_video_NIQE(output_root, crop_border=4):
263 |     '''
264 |     Calculate video NIQE (no-reference metric, so no GT is needed).
265 |     '''
266 |
267 | NIQE_sum = 0.
268 | img_num = 0
269 |
270 | video_NIQE = []
271 |
272 | video_list = sorted(listdir(output_root))
273 | for v in video_list:
274 | v_NIQE_list, _ = calc_image_NIQE(
275 | output_root=os.path.join(output_root, v),
276 | crop_border=crop_border
277 | )
278 | NIQE_sum += sum(v_NIQE_list)
279 | img_num += len(v_NIQE_list)
280 |
281 | video_NIQE.append({
282 | 'video_name': v,
283 | 'niqe': v_NIQE_list
284 | })
285 |
286 | logs = []
287 | NIQE_csv_log = {
288 | 'col_names': [],
289 | 'row_names': [output_root],
290 | 'niqe': [[]]
291 | }
292 | for v_niqe in video_NIQE:
293 | NIQE_csv_log['col_names'].append('#{}'.format(v_niqe['video_name']))
294 | NIQE_csv_log['niqe'][0].append('{:.5}'.format(sum(v_niqe['niqe']) / len(v_niqe['niqe'])))
295 | log = 'Video: {} NIQE={:.5}'.format(v_niqe['video_name'], sum(v_niqe['niqe']) / len(v_niqe['niqe']))
296 | print(log)
297 | logs.append(log)
298 | NIQE_csv_log['col_names'].append('AVG')
299 | NIQE_csv_log['niqe'][0].append('{:.5}'.format(NIQE_sum / img_num))
300 | log = 'Average NIQE={:.5}'.format(NIQE_sum / img_num)
301 | print(log)
302 | logs.append(log)
303 |
304 | return NIQE_csv_log, logs
305 |
306 |
307 | def batch_calc_video_NIQE(root_list, crop_border=4, save_log=False, save_log_root=None, combine_save=False):
308 |     '''
309 |     required params:
310 |         root_list: a list, each item should be a dictionary with the key:
311 |             output: the dir of output videos
312 |     optional params:
313 |         crop_border: default=4, pixels to crop from the border when calculating NIQE
314 |         save_log: default=False, if True, save a csv log
315 |         save_log_root: the dir of the output log
316 |         combine_save: default=False, if True, combine all output logs into one csv file
317 |     return:
318 |         log_list: a list, each item is a dictionary with two keys:
319 |             data_path: the evaluated dir
320 |             log: the log of this dir
321 |     '''
322 | if save_log:
323 | assert save_log_root is not None, "Unknown save_log_root!"
324 |
325 | total_csv_log = []
326 | log_list = []
327 | for i, root in enumerate(root_list):
328 |         output_root = root['output']
329 | print(">>>> Now Evaluation >>>>")
330 | print(">>>> OUTPUT: {}".format(ouput_root))
331 | csv_log, logs = calc_video_NIQE(
332 |             output_root, crop_border=crop_border
333 | )
334 | log_list.append({
335 |             'data_path': output_root,
336 | 'log': logs
337 | })
338 |
339 | # output the NIQE log of each evaluated dir to a single csv file
340 | if save_log:
341 | csv_log['row_names'] = [os.path.basename(p) for p in csv_log['row_names']]
342 | write_csv(file_path=os.path.join(save_log_root, "{}_{}.csv".format(i, csv_log['row_names'][0])),
343 | data=np.array(csv_log['niqe']),
344 | row_names=csv_log['row_names'],
345 | col_names=csv_log['col_names'])
346 | total_csv_log.append(csv_log)
347 |
348 | # output all NIQE log to a csv file
349 | if save_log and combine_save and len(total_csv_log) > 0:
350 | com_csv_log = {
351 | 'col_names': total_csv_log[0]['col_names'],
352 | 'row_names': [],
353 | 'niqe': []
354 | }
355 | for csv_log in total_csv_log:
356 | com_csv_log['row_names'].append(csv_log['row_names'][0])
357 | com_csv_log['niqe'].append(csv_log['niqe'][0])
358 | write_csv(file_path=os.path.join(save_log_root, "niqe.csv"),
359 | data=np.array(com_csv_log['niqe']),
360 | row_names=com_csv_log['row_names'],
361 | col_names=com_csv_log['col_names'])
362 |
363 | print("--------------------------------------------------------------------------------------")
364 | for i, logs in enumerate(log_list):
365 | print("## The {}-th:".format(i))
366 | print(">> ", logs['data_path'])
367 | for log in logs['log']:
368 | print(">> ", log)
369 |
370 | return log_list
371 |
--------------------------------------------------------------------------------
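A usage sketch for the batch evaluators above; the directory paths are hypothetical, and each output/gt directory is expected to contain one sub-directory of frames per video:

```python
from utils.video_metric_utils import batch_calc_video_PSNR_SSIM

root_list = [
    {'output': './results/model_a', 'gt': './dataset/gt'},  # hypothetical paths
    {'output': './results/model_b', 'gt': './dataset/gt'},
]
log_list = batch_calc_video_PSNR_SSIM(root_list, crop_border=4, test_ycbcr=True,
                                      save_log=True, save_log_root='./logs',
                                      combine_save=True)
```

--------------------------------------------------------------------------------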
/utils/video_regroup_utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | from base.os_base import handle_dir, copy_file, move_file, listdir, glob_match
3 | from utils import file_regroup_utils
4 |
5 |
6 | def VideoFlag2FlagVideo(ori_root, dest_root, ori_flag, dest_flag=None):
7 | '''
8 | videos/type/frames --> type/videos/frames
9 | params:
10 |         ori_root: the dir of videos that need to be processed
11 |         dest_root: the dir for saving regrouped videos
12 |         ori_flag: the original video flag (e.g. blur)
13 |         dest_flag: the flag (e.g. blur) for saving videos
14 |             default: None, that is, keep the original flag
15 | '''
16 | if dest_flag is None:
17 | dest_flag = ori_flag
18 | handle_dir(dest_root)
19 | handle_dir(os.path.join(dest_root, dest_flag))
20 | video_list = listdir(ori_root)
21 | for v in video_list:
22 | image_list = listdir(os.path.join(ori_root, v, ori_flag))
23 | handle_dir(os.path.join(dest_root, dest_flag, v))
24 | for im in image_list:
25 | src = os.path.join(ori_root, v, ori_flag, im)
26 | dst = os.path.join(dest_root, dest_flag, v, im)
27 | copy_file(src, dst)
28 |
29 |
30 | def FlagVideo2VideoFlag(ori_root, dest_root, ori_flag, dest_flag=None):
31 | '''
32 | blur/videos/frames --> videos/blur/frames
33 | params:
34 |         ori_root: the dir of videos that need to be processed
35 |         dest_root: the dir for saving regrouped videos
36 |         ori_flag: the original video flag (e.g. blur)
37 |         dest_flag: the flag (e.g. blur) for saving videos
38 |             default: None, that is, keep the original flag
39 | '''
40 | if dest_flag is None:
41 | dest_flag = ori_flag
42 | handle_dir(dest_root)
43 | video_list = listdir(os.path.join(ori_root, ori_flag))
44 | for v in video_list:
45 | image_list = listdir(os.path.join(ori_root, ori_flag, v))
46 | handle_dir(os.path.join(dest_root, v))
47 | handle_dir(os.path.join(dest_root, v, dest_flag))
48 | for im in image_list:
49 | src = os.path.join(ori_root, ori_flag, v, im)
50 | dst = os.path.join(dest_root, v, dest_flag, im)
51 | copy_file(src, dst)
52 |
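# To illustrate the two regroupings above (directory names are placeholders):
#
#   VideoFlag2FlagVideo('./data', './out', ori_flag='blur') turns
#       ./data/video_001/blur/0000.png  -->  ./out/blur/video_001/0000.png
#
#   FlagVideo2VideoFlag('./out', './data2', ori_flag='blur') is the inverse:
#       ./out/blur/video_001/0000.png   -->  ./data2/video_001/blur/0000.png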
53 |
54 | def remove_frames_prefix(root, prefix=''):
55 | '''
56 | remove prefix from frames
57 | params:
58 | root: the dir of videos that need to be processed
59 | prefix: the prefix to be removed
60 | '''
61 | video_list = listdir(root)
62 | for v in video_list:
63 | file_regroup_utils.remove_files_prefix(os.path.join(root, v), prefix=prefix)
64 |
65 |
66 | def remove_frames_postfix(root, postfix=''):
67 | '''
68 | remove postfix from frames
69 | params:
70 | root: the dir of videos that need to be processed
71 | postfix: the postfix to be removed
72 | '''
73 | video_list = listdir(root)
74 | for v in video_list:
75 | file_regroup_utils.remove_files_postfix(os.path.join(root, v), postfix=postfix)
76 |
77 |
78 | def add_frames_postfix(root, postfix=''):
79 | '''
80 | add postfix to frames
81 | params:
82 | root: the dir of videos that need to be processed
83 | postfix: the postfix to be added
84 | '''
85 | video_list = listdir(root)
86 | for v in video_list:
87 | file_regroup_utils.add_files_postfix(os.path.join(root, v), postfix=postfix)
88 |
89 |
90 | def extra_frames_by_postfix(ori_root, dest_root, match_postfix='', new_postfix=None, match_ext='*'):
91 | '''
92 |     extract frames from ori_root to dest_root, matched by match_postfix and match_ext
93 | params:
94 | ori_root: the dir of videos that need to be processed
95 | dest_root: the dir for saving matched files
96 | match_postfix: the postfix to be matched
97 | new_postfix: the postfix for matched files
98 |             default: None, that is, keep the original postfix
99 | match_ext: the ext to be matched
100 | '''
101 | if new_postfix is None:
102 | new_postfix = match_postfix
103 |
104 | handle_dir(dest_root)
105 | video_list = listdir(ori_root)
106 | for v in video_list:
107 | file_regroup_utils.extra_files_by_postfix(
108 | ori_root=os.path.join(ori_root, v),
109 | dest_root=os.path.join(dest_root, v),
110 | match_postfix=match_postfix,
111 | new_postfix=new_postfix,
112 | match_ext=match_ext
113 | )
114 |
115 |
116 | def resort_frames_index(root, template='{:0>4}', start_idx=0):
117 | '''
118 |     re-index frames' filenames using the template, with indices starting from start_idx
119 |     params:
120 |         root: the dir of videos that need to be processed
121 | template: the template for processed filename
122 | start_idx: the start index
123 | '''
124 | video_list = listdir(root)
125 | for v in video_list:
126 | file_regroup_utils.resort_files_index(os.path.join(root, v), template=template, start_idx=start_idx)
127 |
128 |
129 | def remove_head_tail_frames(root, recycle_bin=None, num=0):
130 | '''
131 |     remove num head & tail frames from each video
132 |     params:
133 |         root: the dir of videos that need to be processed
134 |         recycle_bin: the removed frames will be put here
135 |             default: None, that is, put the removed frames in root/_recycle_bin
136 |         num: the number of frames to be removed from each end
137 | '''
138 | if recycle_bin is None:
139 | recycle_bin = os.path.join(root, '_recycle_bin')
140 | handle_dir(recycle_bin)
141 |
142 | video_list = listdir(root)
143 | for v in video_list:
144 | img_list = sorted(glob_match(os.path.join(root, v, "*")))
145 | handle_dir(os.path.join(recycle_bin, v))
146 | for i in range(num):
147 | src = img_list[i]
148 | dest = os.path.join(recycle_bin, v, os.path.basename(src))
149 | move_file(src, dest)
150 |
151 | src = img_list[-(i + 1)]
152 | dest = os.path.join(recycle_bin, v, os.path.basename(src))
153 | move_file(src, dest)
154 |
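# For example, with num=2 and frames 0000.png ... 0099.png in each video,
# remove_head_tail_frames('./videos', num=2) moves 0000, 0001, 0098 and 0099
# of each video into ./videos/_recycle_bin/<video>/ (paths are placeholders).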
--------------------------------------------------------------------------------
/utils/video_utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | from base.os_base import handle_dir, copy_file, listdir
3 | from utils.image_utils import cv2_resize_images, matlab_resize_images, shift_images
4 |
5 |
6 | def matlab_resize_videos(ori_root, dest_root, scale=1.0, method='bicubic', filename_template="{}.png"):
7 | '''
8 | function:
9 |         resizing videos in batches, consistent with MATLAB 2017's imresize
10 | params:
11 | ori_root: string, the dir of videos that need to be processed
12 | dest_root: string, the dir to save processed videos
13 | scale: float, the resize scale
14 | method: string, the interpolation method,
15 | optional: 'bilinear', 'bicubic'
16 | default: 'bicubic'
17 | filename_template: string, the filename template for saving images
18 | '''
19 | handle_dir(dest_root)
20 | videos = listdir(ori_root)
21 | for v in videos:
22 | matlab_resize_images(
23 | ori_root=os.path.join(ori_root, v),
24 | dest_root=os.path.join(dest_root, v),
25 | scale=scale,
26 | method=method,
27 | filename_template=filename_template
28 | )
29 |         print("Video", v, "resize done!")
30 |
31 |
32 | def cv2_resize_videos(ori_root, dest_root, scale=1.0, method='bicubic', filename_template="{}.png"):
33 | '''
34 | function:
35 | resizing videos in batches
36 | params:
37 | ori_root: string, the dir of videos that need to be processed
38 | dest_root: string, the dir to save processed videos
39 | scale: float, the resize scale
40 | method: string, the interpolation method,
41 | optional: 'nearest', 'bilinear', 'bicubic'
42 | default: 'bicubic'
43 | filename_template: string, the filename template for saving images
44 | '''
45 | handle_dir(dest_root)
46 | videos = listdir(ori_root)
47 | for v in videos:
48 | cv2_resize_images(
49 | ori_root=os.path.join(ori_root, v),
50 | dest_root=os.path.join(dest_root, v),
51 | scale=scale,
52 | method=method,
53 | filename_template=filename_template
54 | )
55 |         print("Video", v, "resize done!")
56 |
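# A minimal usage sketch for the batch resizers above (paths and scales are
# placeholders):
#
#   matlab_resize_videos('./HR', './LR_x4', scale=0.25, method='bicubic')
#   cv2_resize_videos('./HR', './LR_x2', scale=0.5, method='bilinear')
#
# Both expect the layout ori_root/<video>/<frames> and mirror it under dest_root.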
57 |
58 | def shift_videos(ori_root, dest_root, offset_x=0., offset_y=0., filename_template="{}.png"):
59 | '''
60 | function:
61 | shifting videos by (offset_x, offset_y) on (axis-x, axis-y) in batches
62 | params:
63 | ori_root: string, the dir of videos that need to be processed
64 | dest_root: string, the dir to save processed videos
65 | offset_x: float, offset pixels on axis-x
66 | positive=left; negative=right
67 | offset_y: float, offset pixels on axis-y
68 | positive=up; negative=down
69 | filename_template: string, the filename template for saving images
70 | '''
71 | handle_dir(dest_root)
72 | videos = listdir(ori_root)
73 | for v in videos:
74 | shift_images(
75 | ori_root=os.path.join(ori_root, v),
76 | dest_root=os.path.join(dest_root, v),
77 | offset_x=offset_x,
78 | offset_y=offset_y,
79 | filename_template=filename_template
80 | )
81 |         print("Video", v, "shift done!")
82 |
83 |
84 | def extra_frames_from_videos(ori_root, save_root, fname_template='%4d.png', ffmpeg_path='ffmpeg'):
85 | '''
86 | function:
87 |         extract frames from videos
88 | params:
89 | ori_root: string, the dir of videos that need to be processed
90 | save_root: string, the dir to save processed frames
91 | fname_template: the template for frames' filename
92 | ffmpeg_path: ffmpeg path
93 | '''
94 |
95 | handle_dir(save_root)
96 | videos = sorted(listdir(ori_root))
97 |
98 | for v in videos:
99 |         vn = v[:-(len(v.split('.')[-1]) + 1)]  # strip the extension to get the video name
100 | video_path = os.path.join(ori_root, v)
101 | png_dir = os.path.join(save_root, vn)
102 | png_path = os.path.join(png_dir, fname_template)
103 | handle_dir(png_dir)
104 | command = '{} -i {} {}'.format(ffmpeg_path, video_path, png_path)
105 | os.system(command)
106 |         print("Extracted frames from {}".format(video_path))
107 |
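# For a video ./in/clip.mp4 with the default fname_template, the command
# assembled above expands to roughly (paths here are placeholders):
#
#   ffmpeg -i ./in/clip.mp4 ./frames/clip/%4d.png
#
# ffmpeg substitutes the printf-style pattern with the frame index when writing.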
108 |
109 | def zip_frames_to_videos(ori_root, save_root, fname_template='%4d.png', video_ext='mkv', ffmpeg_path='ffmpeg'):
110 | '''
111 | function:
112 | zip frames to videos
113 | params:
114 | ori_root: string, the dir of frames that need to be processed
115 | save_root: string, the dir to save processed videos
116 | fname_template: the template of frames' filename
117 | video_ext: the extension of videos
118 | ffmpeg_path: ffmpeg path
119 | '''
120 |
121 | handle_dir(save_root)
122 | videos_name = sorted(listdir(ori_root))
123 |
124 | for vn in videos_name:
125 | imgs_path = os.path.join(ori_root, vn, fname_template)
126 | video_path = os.path.join(save_root, '{}.{}'.format(vn, video_ext))
127 | command = '{} -i {} -c:v libx265 -x265-params lossless=1 {}'.format(
128 | ffmpeg_path, imgs_path, video_path
129 | ) # NTIRE 2022 Super-Resolution and Quality Enhancement of Compressed Video
130 | # command = '{} -r 24000/1001 -i {} -vcodec libx265 -pix_fmt yuv422p -crf 10 {}'.format(
131 | # ffmpeg_path, imgs_path, video_path
132 | # ) # youku competition
133 | os.system(command)
134 | print("Zip frames to {}".format(video_path))
135 |
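# A minimal usage sketch (paths are placeholders): given frames under
# ./frames/clip/, zip_frames_to_videos('./frames', './videos') runs roughly:
#
#   ffmpeg -i ./frames/clip/%4d.png -c:v libx265 -x265-params lossless=1 ./videos/clip.mkv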
136 |
137 | def copy_frames_for_fps(ori_root, save_root, mul=12, fname_template="{:0>4}", ext="png"):
138 | '''
139 | function:
140 |         duplicate each frame mul times to increase the fps
141 |     params:
142 |         ori_root: string, the dir of videos that need to be processed
143 |         save_root: string, the dir to save processed videos
144 |         mul: the number of copies of each frame
145 | fname_template: the template of frames' filename
146 | ext: the ext of frames' filename
147 | '''
148 | fname_template = fname_template + '.{}'
149 | videos_name = sorted(listdir(ori_root))
150 | handle_dir(save_root)
151 | for vn in videos_name:
152 |         frames = sorted(listdir(os.path.join(ori_root, vn)))
153 |         handle_dir(os.path.join(save_root, vn))
154 |         for i, f in enumerate(frames):
155 | for j in range(mul):
156 | now_idx = i * mul + j
157 | src = os.path.join(ori_root, vn, f)
158 | dest = os.path.join(save_root, vn, fname_template.format(now_idx, ext))
159 | copy_file(src, dest)
160 |
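# For example, with mul=12 the i-th source frame is copied to indices
# i*12 ... i*12+11, so a 24-frame clip becomes 288 frames that can then be
# re-encoded at 12x the original fps (frame counts here are illustrative).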
--------------------------------------------------------------------------------