├── .gitignore ├── LICENSE ├── README.md ├── correlation └── correlation.py ├── flolpips.py ├── pretrained_networks.py ├── pwcnet.py ├── sample_script.py ├── utils.py └── weights └── v0.1 └── alex.pth /.gitignore: -------------------------------------------------------------------------------- 1 | *__pycache__* 2 | *.ipynb 3 | *delete* -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 danielism97 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # FloLPIPS: A bespoke video quality metric for frame interpolation 2 | 3 | ### Duolikun Danier, Fan Zhang, David Bull 4 | 5 | 6 | [Project](https://danielism97.github.io/FloLPIPS) | [arXiv](https://arxiv.org/abs/2207.08119) 7 | 8 | 9 | ## Dependencies 10 | The following packages were used to evaluate the model. 
11 | 12 | - python==3.8.8 13 | - pytorch==1.7.1 14 | - torchvision==0.8.2 15 | - cudatoolkit==10.1.243 16 | - opencv-python==4.5.1.48 17 | - numpy==1.19.2 18 | - pillow==8.1.2 19 | - cupy==9.0.0 20 | 21 | 22 | ## Usage 23 | ### Video-based Evaluation 24 | ```python 25 | from flolpips import calc_flolpips 26 | ref_video = '<path to reference video>.mp4' 27 | dis_video = '<path to distorted video>.mp4' 28 | res = calc_flolpips(dis_video, ref_video) 29 | ``` 30 | 31 | ### Triplet Frame-based Evaluation 32 | ```python 33 | from flolpips import Flolpips 34 | import torch 35 | 36 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 37 | eval_metric = Flolpips().to(device) 38 | 39 | batch = 8 40 | I0 = torch.rand(batch, 3, 256, 448).to(device) # first frame of the triplet 41 | I1 = torch.rand(batch, 3, 256, 448).to(device) # third frame of the triplet 42 | frame_dis = torch.rand(batch, 3, 256, 448).to(device) # prediction of the intermediate frame 43 | frame_ref = torch.rand(batch, 3, 256, 448).to(device) # ground-truth of the intermediate frame 44 | 45 | flolpips = eval_metric.forward(I0, I1, frame_dis, frame_ref) 46 | ``` 47 | 48 | 49 | ## Citation 50 | ``` 51 | @article{danier2022flolpips, 52 | title={FloLPIPS: A Bespoke Video Quality Metric for Frame Interpolation}, 53 | author={Danier, Duolikun and Zhang, Fan and Bull, David}, 54 | journal={arXiv preprint arXiv:2207.08119}, 55 | year={2022} 56 | } 57 | ``` 58 | 59 | ## Acknowledgement 60 | Much of the code in this repository is adapted or taken from the following repositories: 61 | 62 | - [LPIPS](https://github.com/richzhang/PerceptualSimilarity) 63 | - [pytorch-pwc](https://github.com/sniklaus/pytorch-pwc) 64 | 65 | We would like to thank the authors for sharing their code. -------------------------------------------------------------------------------- /correlation/correlation.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import torch 4 | 5 | import cupy 6 | import re 7 | 8 | kernel_Correlation_rearrange = ''' 9 | extern "C" __global__ void kernel_Correlation_rearrange( 10 | const int n, 11 | const float* input, 12 | float* output 13 | ) { 14 | int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; 15 | 16 | if (intIndex >= n) { 17 | return; 18 | } 19 | 20 | int intSample = blockIdx.z; 21 | int intChannel = blockIdx.y; 22 | 23 | float fltValue = input[(((intSample * SIZE_1(input)) + intChannel) * SIZE_2(input) * SIZE_3(input)) + intIndex]; 24 | 25 | __syncthreads(); 26 | 27 | int intPaddedY = (intIndex / SIZE_3(input)) + 4; 28 | int intPaddedX = (intIndex % SIZE_3(input)) + 4; 29 | int intRearrange = ((SIZE_3(input) + 8) * intPaddedY) + intPaddedX; 30 | 31 | output[(((intSample * SIZE_1(output) * SIZE_2(output)) + intRearrange) * SIZE_1(input)) + intChannel] = fltValue; 32 | } 33 | ''' 34 | 35 | kernel_Correlation_updateOutput = ''' 36 | extern "C" __global__ void kernel_Correlation_updateOutput( 37 | const int n, 38 | const float* rbot0, 39 | const float* rbot1, 40 | float* top 41 | ) { 42 | extern __shared__ char patch_data_char[]; 43 | 44 | float *patch_data = (float *)patch_data_char; 45 | 46 | // First (upper left) position of kernel upper-left corner in current center position of neighborhood in image 1 47 | int x1 = blockIdx.x + 4; 48 | int y1 = blockIdx.y + 4; 49 | int item = blockIdx.z; 50 | int ch_off = threadIdx.x; 51 | 52 | // Load 3D patch into shared memory 53 | for (int j = 0; j < 1; j++) { // HEIGHT 54 | for (int i = 0; i < 1; i++) { // WIDTH 55 | int ji_off = (j + i) * SIZE_3(rbot0); 56 | for (int 
ch = ch_off; ch < SIZE_3(rbot0); ch += 32) { // CHANNELS 57 | int idx1 = ((item * SIZE_1(rbot0) + y1+j) * SIZE_2(rbot0) + x1+i) * SIZE_3(rbot0) + ch; 58 | int idxPatchData = ji_off + ch; 59 | patch_data[idxPatchData] = rbot0[idx1]; 60 | } 61 | } 62 | } 63 | 64 | __syncthreads(); 65 | 66 | __shared__ float sum[32]; 67 | 68 | // Compute correlation 69 | for (int top_channel = 0; top_channel < SIZE_1(top); top_channel++) { 70 | sum[ch_off] = 0; 71 | 72 | int s2o = top_channel % 9 - 4; 73 | int s2p = top_channel / 9 - 4; 74 | 75 | for (int j = 0; j < 1; j++) { // HEIGHT 76 | for (int i = 0; i < 1; i++) { // WIDTH 77 | int ji_off = (j + i) * SIZE_3(rbot0); 78 | for (int ch = ch_off; ch < SIZE_3(rbot0); ch += 32) { // CHANNELS 79 | int x2 = x1 + s2o; 80 | int y2 = y1 + s2p; 81 | 82 | int idxPatchData = ji_off + ch; 83 | int idx2 = ((item * SIZE_1(rbot0) + y2+j) * SIZE_2(rbot0) + x2+i) * SIZE_3(rbot0) + ch; 84 | 85 | sum[ch_off] += patch_data[idxPatchData] * rbot1[idx2]; 86 | } 87 | } 88 | } 89 | 90 | __syncthreads(); 91 | 92 | if (ch_off == 0) { 93 | float total_sum = 0; 94 | for (int idx = 0; idx < 32; idx++) { 95 | total_sum += sum[idx]; 96 | } 97 | const int sumelems = SIZE_3(rbot0); 98 | const int index = ((top_channel*SIZE_2(top) + blockIdx.y)*SIZE_3(top))+blockIdx.x; 99 | top[index + item*SIZE_1(top)*SIZE_2(top)*SIZE_3(top)] = total_sum / (float)sumelems; 100 | } 101 | } 102 | } 103 | ''' 104 | 105 | kernel_Correlation_updateGradFirst = ''' 106 | #define ROUND_OFF 50000 107 | 108 | extern "C" __global__ void kernel_Correlation_updateGradFirst( 109 | const int n, 110 | const int intSample, 111 | const float* rbot0, 112 | const float* rbot1, 113 | const float* gradOutput, 114 | float* gradFirst, 115 | float* gradSecond 116 | ) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) { 117 | int n = intIndex % SIZE_1(gradFirst); // channels 118 | int l = (intIndex / SIZE_1(gradFirst)) % SIZE_3(gradFirst) + 4; // w-pos 119 | int m = (intIndex / SIZE_1(gradFirst) / SIZE_3(gradFirst)) % SIZE_2(gradFirst) + 4; // h-pos 120 | 121 | // round_off is a trick to enable integer division with ceil, even for negative numbers 122 | // We use a large offset, for the inner part not to become negative. 
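// note (added for clarity): with a stride of 1 there is no integer division below, so xmin/xmax reduce to l - 4 and ymin/ymax to m - 4; the round_off terms cancel and only matter for correct rounding of negative values in the general strided case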
123 | const int round_off = ROUND_OFF; 124 | const int round_off_s1 = round_off; 125 | 126 | // We add round_off before the int division and subtract round_off after it, to ensure the formula matches ceil behavior: 127 | int xmin = (l - 4 + round_off_s1 - 1) + 1 - round_off; // ceil (l - 4) 128 | int ymin = (m - 4 + round_off_s1 - 1) + 1 - round_off; // ceil (m - 4) 129 | 130 | // Same here: 131 | int xmax = (l - 4 + round_off_s1) - round_off; // floor (l - 4) 132 | int ymax = (m - 4 + round_off_s1) - round_off; // floor (m - 4) 133 | 134 | float sum = 0; 135 | if (xmax>=0 && ymax>=0 && (xmin<=SIZE_3(gradOutput)-1) && (ymin<=SIZE_2(gradOutput)-1)) { 136 | xmin = max(0,xmin); 137 | xmax = min(SIZE_3(gradOutput)-1,xmax); 138 | 139 | ymin = max(0,ymin); 140 | ymax = min(SIZE_2(gradOutput)-1,ymax); 141 | 142 | for (int p = -4; p <= 4; p++) { 143 | for (int o = -4; o <= 4; o++) { 144 | // Get rbot1 data: 145 | int s2o = o; 146 | int s2p = p; 147 | int idxbot1 = ((intSample * SIZE_1(rbot0) + (m+s2p)) * SIZE_2(rbot0) + (l+s2o)) * SIZE_3(rbot0) + n; 148 | float bot1tmp = rbot1[idxbot1]; // rbot1[l+s2o,m+s2p,n] 149 | 150 | // Index offset for gradOutput in following loops: 151 | int op = (p+4) * 9 + (o+4); // index[o,p] 152 | int idxopoffset = (intSample * SIZE_1(gradOutput) + op); 153 | 154 | for (int y = ymin; y <= ymax; y++) { 155 | for (int x = xmin; x <= xmax; x++) { 156 | int idxgradOutput = (idxopoffset * SIZE_2(gradOutput) + y) * SIZE_3(gradOutput) + x; // gradOutput[x,y,o,p] 157 | sum += gradOutput[idxgradOutput] * bot1tmp; 158 | } 159 | } 160 | } 161 | } 162 | } 163 | const int sumelems = SIZE_1(gradFirst); 164 | const int bot0index = ((n * SIZE_2(gradFirst)) + (m-4)) * SIZE_3(gradFirst) + (l-4); 165 | gradFirst[bot0index + intSample*SIZE_1(gradFirst)*SIZE_2(gradFirst)*SIZE_3(gradFirst)] = sum / (float)sumelems; 166 | } } 167 | ''' 168 | 169 | kernel_Correlation_updateGradSecond = ''' 170 | #define ROUND_OFF 50000 171 | 172 | extern "C" __global__ void kernel_Correlation_updateGradSecond( 173 | const int n, 174 | const int intSample, 175 | const float* rbot0, 176 | const float* rbot1, 177 | const float* gradOutput, 178 | float* gradFirst, 179 | float* gradSecond 180 | ) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) { 181 | int n = intIndex % SIZE_1(gradSecond); // channels 182 | int l = (intIndex / SIZE_1(gradSecond)) % SIZE_3(gradSecond) + 4; // w-pos 183 | int m = (intIndex / SIZE_1(gradSecond) / SIZE_3(gradSecond)) % SIZE_2(gradSecond) + 4; // h-pos 184 | 185 | // round_off is a trick to enable integer division with ceil, even for negative numbers 186 | // We use a large offset, for the inner part not to become negative. 
187 | const int round_off = ROUND_OFF; 188 | const int round_off_s1 = round_off; 189 | 190 | float sum = 0; 191 | for (int p = -4; p <= 4; p++) { 192 | for (int o = -4; o <= 4; o++) { 193 | int s2o = o; 194 | int s2p = p; 195 | 196 | //Get X,Y ranges and clamp 197 | // We add round_off before the int division and subtract round_off after it, to ensure the formula matches ceil behavior: 198 | int xmin = (l - 4 - s2o + round_off_s1 - 1) + 1 - round_off; // ceil (l - 4 - s2o) 199 | int ymin = (m - 4 - s2p + round_off_s1 - 1) + 1 - round_off; // ceil (m - 4 - s2p) 200 | 201 | // Same here: 202 | int xmax = (l - 4 - s2o + round_off_s1) - round_off; // floor (l - 4 - s2o) 203 | int ymax = (m - 4 - s2p + round_off_s1) - round_off; // floor (m - 4 - s2p) 204 | 205 | if (xmax>=0 && ymax>=0 && (xmin<=SIZE_3(gradOutput)-1) && (ymin<=SIZE_2(gradOutput)-1)) { 206 | xmin = max(0,xmin); 207 | xmax = min(SIZE_3(gradOutput)-1,xmax); 208 | 209 | ymin = max(0,ymin); 210 | ymax = min(SIZE_2(gradOutput)-1,ymax); 211 | 212 | // Get rbot0 data: 213 | int idxbot0 = ((intSample * SIZE_1(rbot0) + (m-s2p)) * SIZE_2(rbot0) + (l-s2o)) * SIZE_3(rbot0) + n; 214 | float bot0tmp = rbot0[idxbot0]; // rbot0[l-s2o,m-s2p,n] 215 | 216 | // Index offset for gradOutput in following loops: 217 | int op = (p+4) * 9 + (o+4); // index[o,p] 218 | int idxopoffset = (intSample * SIZE_1(gradOutput) + op); 219 | 220 | for (int y = ymin; y <= ymax; y++) { 221 | for (int x = xmin; x <= xmax; x++) { 222 | int idxgradOutput = (idxopoffset * SIZE_2(gradOutput) + y) * SIZE_3(gradOutput) + x; // gradOutput[x,y,o,p] 223 | sum += gradOutput[idxgradOutput] * bot0tmp; 224 | } 225 | } 226 | } 227 | } 228 | } 229 | const int sumelems = SIZE_1(gradSecond); 230 | const int bot1index = ((n * SIZE_2(gradSecond)) + (m-4)) * SIZE_3(gradSecond) + (l-4); 231 | gradSecond[bot1index + intSample*SIZE_1(gradSecond)*SIZE_2(gradSecond)*SIZE_3(gradSecond)] = sum / (float)sumelems; 232 | } } 233 | ''' 234 | 235 | def cupy_kernel(strFunction, objVariables): 236 | strKernel = globals()[strFunction] 237 | 238 | while True: 239 | objMatch = re.search('(SIZE_)([0-4])(\()([^\)]*)(\))', strKernel) 240 | 241 | if objMatch is None: 242 | break 243 | # end 244 | 245 | intArg = int(objMatch.group(2)) 246 | 247 | strTensor = objMatch.group(4) 248 | intSizes = objVariables[strTensor].size() 249 | 250 | strKernel = strKernel.replace(objMatch.group(), str(intSizes[intArg])) 251 | # end 252 | 253 | while True: 254 | objMatch = re.search('(VALUE_)([0-4])(\()([^\)]+)(\))', strKernel) 255 | 256 | if objMatch is None: 257 | break 258 | # end 259 | 260 | intArgs = int(objMatch.group(2)) 261 | strArgs = objMatch.group(4).split(',') 262 | 263 | strTensor = strArgs[0] 264 | intStrides = objVariables[strTensor].stride() 265 | strIndex = [ '((' + strArgs[intArg + 1].replace('{', '(').replace('}', ')').strip() + ')*' + str(intStrides[intArg]) + ')' for intArg in range(intArgs) ] 266 | 267 | strKernel = strKernel.replace(objMatch.group(0), strTensor + '[' + str.join('+', strIndex) + ']') 268 | # end 269 | 270 | return strKernel 271 | # end 272 | 273 | @cupy.memoize(for_each_device=True) 274 | def cupy_launch(strFunction, strKernel): 275 | return cupy.cuda.compile_with_cache(strKernel).get_function(strFunction) 276 | # end 277 | 278 | class _FunctionCorrelation(torch.autograd.Function): 279 | @staticmethod 280 | def forward(self, first, second): 281 | rbot0 = first.new_zeros([ first.shape[0], first.shape[2] + 8, first.shape[3] + 8, first.shape[1] ]) 282 | rbot1 = first.new_zeros([ 
first.shape[0], first.shape[2] + 8, first.shape[3] + 8, first.shape[1] ]) 283 | 284 | self.save_for_backward(first, second, rbot0, rbot1) 285 | 286 | first = first.contiguous(); assert(first.is_cuda == True) 287 | second = second.contiguous(); assert(second.is_cuda == True) 288 | 289 | output = first.new_zeros([ first.shape[0], 81, first.shape[2], first.shape[3] ]) 290 | 291 | if first.is_cuda == True: 292 | n = first.shape[2] * first.shape[3] 293 | cupy_launch('kernel_Correlation_rearrange', cupy_kernel('kernel_Correlation_rearrange', { 294 | 'input': first, 295 | 'output': rbot0 296 | }))( 297 | grid=tuple([ int((n + 16 - 1) / 16), first.shape[1], first.shape[0] ]), 298 | block=tuple([ 16, 1, 1 ]), 299 | args=[ n, first.data_ptr(), rbot0.data_ptr() ] 300 | ) 301 | 302 | n = second.shape[2] * second.shape[3] 303 | cupy_launch('kernel_Correlation_rearrange', cupy_kernel('kernel_Correlation_rearrange', { 304 | 'input': second, 305 | 'output': rbot1 306 | }))( 307 | grid=tuple([ int((n + 16 - 1) / 16), second.shape[1], second.shape[0] ]), 308 | block=tuple([ 16, 1, 1 ]), 309 | args=[ n, second.data_ptr(), rbot1.data_ptr() ] 310 | ) 311 | 312 | n = output.shape[1] * output.shape[2] * output.shape[3] 313 | cupy_launch('kernel_Correlation_updateOutput', cupy_kernel('kernel_Correlation_updateOutput', { 314 | 'rbot0': rbot0, 315 | 'rbot1': rbot1, 316 | 'top': output 317 | }))( 318 | grid=tuple([ output.shape[3], output.shape[2], output.shape[0] ]), 319 | block=tuple([ 32, 1, 1 ]), 320 | shared_mem=first.shape[1] * 4, 321 | args=[ n, rbot0.data_ptr(), rbot1.data_ptr(), output.data_ptr() ] 322 | ) 323 | 324 | elif first.is_cuda == False: 325 | raise NotImplementedError() 326 | 327 | # end 328 | 329 | return output 330 | # end 331 | 332 | @staticmethod 333 | def backward(self, gradOutput): 334 | first, second, rbot0, rbot1 = self.saved_tensors 335 | 336 | gradOutput = gradOutput.contiguous(); assert(gradOutput.is_cuda == True) 337 | 338 | gradFirst = first.new_zeros([ first.shape[0], first.shape[1], first.shape[2], first.shape[3] ]) if self.needs_input_grad[0] == True else None 339 | gradSecond = first.new_zeros([ first.shape[0], first.shape[1], first.shape[2], first.shape[3] ]) if self.needs_input_grad[1] == True else None 340 | 341 | if first.is_cuda == True: 342 | if gradFirst is not None: 343 | for intSample in range(first.shape[0]): 344 | n = first.shape[1] * first.shape[2] * first.shape[3] 345 | cupy_launch('kernel_Correlation_updateGradFirst', cupy_kernel('kernel_Correlation_updateGradFirst', { 346 | 'rbot0': rbot0, 347 | 'rbot1': rbot1, 348 | 'gradOutput': gradOutput, 349 | 'gradFirst': gradFirst, 350 | 'gradSecond': None 351 | }))( 352 | grid=tuple([ int((n + 512 - 1) / 512), 1, 1 ]), 353 | block=tuple([ 512, 1, 1 ]), 354 | args=[ n, intSample, rbot0.data_ptr(), rbot1.data_ptr(), gradOutput.data_ptr(), gradFirst.data_ptr(), None ] 355 | ) 356 | # end 357 | # end 358 | 359 | if gradSecond is not None: 360 | for intSample in range(first.shape[0]): 361 | n = first.shape[1] * first.shape[2] * first.shape[3] 362 | cupy_launch('kernel_Correlation_updateGradSecond', cupy_kernel('kernel_Correlation_updateGradSecond', { 363 | 'rbot0': rbot0, 364 | 'rbot1': rbot1, 365 | 'gradOutput': gradOutput, 366 | 'gradFirst': None, 367 | 'gradSecond': gradSecond 368 | }))( 369 | grid=tuple([ int((n + 512 - 1) / 512), 1, 1 ]), 370 | block=tuple([ 512, 1, 1 ]), 371 | args=[ n, intSample, rbot0.data_ptr(), rbot1.data_ptr(), gradOutput.data_ptr(), None, gradSecond.data_ptr() ] 372 | ) 373 | # end 374 | # end 375 | 
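# note (added for clarity): no CPU implementation is provided; both forward and backward require CUDA tensors, since the kernels are compiled with CuPy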
376 | elif first.is_cuda == False: 377 | raise NotImplementedError() 378 | 379 | # end 380 | 381 | return gradFirst, gradSecond 382 | # end 383 | # end 384 | 385 | def FunctionCorrelation(tenFirst, tenSecond): 386 | return _FunctionCorrelation.apply(tenFirst, tenSecond) 387 | # end 388 | 389 | class ModuleCorrelation(torch.nn.Module): 390 | def __init__(self): 391 | super(ModuleCorrelation, self).__init__() 392 | # end 393 | 394 | def forward(self, tenFirst, tenSecond): 395 | return _FunctionCorrelation.apply(tenFirst, tenSecond) 396 | # end 397 | # end -------------------------------------------------------------------------------- /flolpips.py: -------------------------------------------------------------------------------- 1 | 2 | from __future__ import absolute_import 3 | import os 4 | import numpy as np 5 | import torch 6 | import torch.nn as nn 7 | from torch.autograd import Variable 8 | import pretrained_networks as pn 9 | import torch.nn 10 | import torch.nn.functional as F 11 | import torchvision.transforms.functional as TF 12 | import cv2 13 | 14 | from pwcnet import Network as PWCNet 15 | import utils 16 | 17 | def spatial_average(in_tens, keepdim=True): 18 | return in_tens.mean([2,3],keepdim=keepdim) 19 | 20 | def mw_spatial_average(in_tens, flow, keepdim=True): 21 | _,_,h,w = in_tens.shape 22 | flow = F.interpolate(flow, (h,w), align_corners=False, mode='bilinear') 23 | flow_mag = torch.sqrt(flow[:,0:1]**2 + flow[:,1:2]**2) 24 | flow_mag = flow_mag / torch.sum(flow_mag, dim=[1,2,3], keepdim=True) 25 | return torch.sum(in_tens*flow_mag, dim=[2,3],keepdim=keepdim) 26 | 27 | 28 | def mtw_spatial_average(in_tens, flow, texture, keepdim=True): 29 | _,_,h,w = in_tens.shape 30 | flow = F.interpolate(flow, (h,w), align_corners=False, mode='bilinear') 31 | texture = F.interpolate(texture, (h,w), align_corners=False, mode='bilinear') 32 | flow_mag = torch.sqrt(flow[:,0:1]**2 + flow[:,1:2]**2) 33 | flow_mag = (flow_mag - flow_mag.min()) / (flow_mag.max() - flow_mag.min()) + 1e-6 34 | texture = (texture - texture.min()) / (texture.max() - texture.min()) + 1e-6 35 | weight = flow_mag / texture 36 | weight /= torch.sum(weight) 37 | return torch.sum(in_tens*weight, dim=[2,3],keepdim=keepdim) 38 | 39 | 40 | 41 | def m2w_spatial_average(in_tens, flow, keepdim=True): 42 | _,_,h,w = in_tens.shape 43 | flow = F.interpolate(flow, (h,w), align_corners=False, mode='bilinear') 44 | flow_mag = flow[:,0:1]**2 + flow[:,1:2]**2 # B,1,H,W 45 | flow_mag = flow_mag / torch.sum(flow_mag) 46 | return torch.sum(in_tens*flow_mag, dim=[2,3],keepdim=keepdim) 47 | 48 | def upsample(in_tens, out_HW=(64,64)): # assumes scale factor is same for H and W 49 | in_H, in_W = in_tens.shape[2], in_tens.shape[3] 50 | return nn.Upsample(size=out_HW, mode='bilinear', align_corners=False)(in_tens) 51 | 52 | # Learned perceptual metric 53 | class LPIPS(nn.Module): 54 | def __init__(self, pretrained=True, net='alex', version='0.1', lpips=True, spatial=False, 55 | pnet_rand=False, pnet_tune=False, use_dropout=True, model_path=None, eval_mode=True, verbose=False): 56 | # lpips - [True] means with linear calibration on top of base network 57 | # pretrained - [True] means load linear weights 58 | 59 | super(LPIPS, self).__init__() 60 | if(verbose): 61 | print('Setting up [%s] perceptual loss: trunk [%s], v[%s], spatial [%s]'% 62 | ('LPIPS' if lpips else 'baseline', net, version, 'on' if spatial else 'off')) 63 | 64 | self.pnet_type = net 65 | self.pnet_tune = pnet_tune 66 | self.pnet_rand = pnet_rand 67 | self.spatial = spatial 68 | 
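# spatial - [False] means per-layer distances are pooled to scalars; True returns upsampled spatial distance maps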
self.lpips = lpips # false means baseline of just averaging all layers 69 | self.version = version 70 | self.scaling_layer = ScalingLayer() 71 | 72 | if(self.pnet_type in ['vgg','vgg16']): 73 | net_type = pn.vgg16 74 | self.chns = [64,128,256,512,512] 75 | elif(self.pnet_type=='alex'): 76 | net_type = pn.alexnet 77 | self.chns = [64,192,384,256,256] 78 | elif(self.pnet_type=='squeeze'): 79 | net_type = pn.squeezenet 80 | self.chns = [64,128,256,384,384,512,512] 81 | self.L = len(self.chns) 82 | 83 | self.net = net_type(pretrained=not self.pnet_rand, requires_grad=self.pnet_tune) 84 | 85 | if(lpips): 86 | self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout) 87 | self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout) 88 | self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout) 89 | self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout) 90 | self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout) 91 | self.lins = [self.lin0,self.lin1,self.lin2,self.lin3,self.lin4] 92 | if(self.pnet_type=='squeeze'): # 7 layers for squeezenet 93 | self.lin5 = NetLinLayer(self.chns[5], use_dropout=use_dropout) 94 | self.lin6 = NetLinLayer(self.chns[6], use_dropout=use_dropout) 95 | self.lins+=[self.lin5,self.lin6] 96 | self.lins = nn.ModuleList(self.lins) 97 | 98 | if(pretrained): 99 | if(model_path is None): 100 | import inspect 101 | import os 102 | model_path = os.path.abspath(os.path.join(inspect.getfile(self.__init__), '..', 'weights/v%s/%s.pth'%(version,net))) 103 | 104 | if(verbose): 105 | print('Loading model from: %s'%model_path) 106 | self.load_state_dict(torch.load(model_path, map_location='cpu'), strict=False) 107 | 108 | if(eval_mode): 109 | self.eval() 110 | 111 | def forward(self, in0, in1, retPerLayer=False, normalize=False): 112 | if normalize: # turn on this flag if input is [0,1] so it can be adjusted to [-1, +1] 113 | in0 = 2 * in0 - 1 114 | in1 = 2 * in1 - 1 115 | 116 | # v0.0 - original release had a bug, where input was not scaled 117 | in0_input, in1_input = (self.scaling_layer(in0), self.scaling_layer(in1)) if self.version=='0.1' else (in0, in1) 118 | outs0, outs1 = self.net.forward(in0_input), self.net.forward(in1_input) 119 | feats0, feats1, diffs = {}, {}, {} 120 | 121 | for kk in range(self.L): 122 | feats0[kk], feats1[kk] = utils.normalize_tensor(outs0[kk]), utils.normalize_tensor(outs1[kk]) 123 | diffs[kk] = (feats0[kk]-feats1[kk])**2 124 | 125 | if(self.lpips): 126 | if(self.spatial): 127 | res = [upsample(self.lins[kk](diffs[kk]), out_HW=in0.shape[2:]) for kk in range(self.L)] 128 | else: 129 | res = [spatial_average(self.lins[kk](diffs[kk]), keepdim=True) for kk in range(self.L)] 130 | else: 131 | if(self.spatial): 132 | res = [upsample(diffs[kk].sum(dim=1,keepdim=True), out_HW=in0.shape[2:]) for kk in range(self.L)] 133 | else: 134 | res = [spatial_average(diffs[kk].sum(dim=1,keepdim=True), keepdim=True) for kk in range(self.L)] 135 | 136 | # val = res[0] 137 | # for l in range(1,self.L): 138 | # val += res[l] 139 | # print(val) 140 | 141 | # a = spatial_average(self.lins[kk](diffs[kk]), keepdim=True) 142 | # b = torch.max(self.lins[kk](feats0[kk]**2)) 143 | # for kk in range(self.L): 144 | # a += spatial_average(self.lins[kk](diffs[kk]), keepdim=True) 145 | # b = torch.max(b,torch.max(self.lins[kk](feats0[kk]**2))) 146 | # a = a/self.L 147 | # from IPython import embed 148 | # embed() 149 | # return 10*torch.log10(b/a) 150 | 151 | # if(retPerLayer): 152 | # return (val, res) 153 | # else: 154 | return torch.sum(torch.cat(res, 1), 
dim=(1,2,3), keepdims=False) 155 | 156 | 157 | class ScalingLayer(nn.Module): 158 | def __init__(self): 159 | super(ScalingLayer, self).__init__() 160 | self.register_buffer('shift', torch.Tensor([-.030,-.088,-.188])[None,:,None,None]) 161 | self.register_buffer('scale', torch.Tensor([.458,.448,.450])[None,:,None,None]) 162 | 163 | def forward(self, inp): 164 | return (inp - self.shift) / self.scale 165 | 166 | 167 | class NetLinLayer(nn.Module): 168 | ''' A single linear layer which does a 1x1 conv ''' 169 | def __init__(self, chn_in, chn_out=1, use_dropout=False): 170 | super(NetLinLayer, self).__init__() 171 | 172 | layers = [nn.Dropout(),] if(use_dropout) else [] 173 | layers += [nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False),] 174 | self.model = nn.Sequential(*layers) 175 | 176 | def forward(self, x): 177 | return self.model(x) 178 | 179 | class Dist2LogitLayer(nn.Module): 180 | ''' takes 2 distances, puts through fc layers, spits out value between [0,1] (if use_sigmoid is True) ''' 181 | def __init__(self, chn_mid=32, use_sigmoid=True): 182 | super(Dist2LogitLayer, self).__init__() 183 | 184 | layers = [nn.Conv2d(5, chn_mid, 1, stride=1, padding=0, bias=True),] 185 | layers += [nn.LeakyReLU(0.2,True),] 186 | layers += [nn.Conv2d(chn_mid, chn_mid, 1, stride=1, padding=0, bias=True),] 187 | layers += [nn.LeakyReLU(0.2,True),] 188 | layers += [nn.Conv2d(chn_mid, 1, 1, stride=1, padding=0, bias=True),] 189 | if(use_sigmoid): 190 | layers += [nn.Sigmoid(),] 191 | self.model = nn.Sequential(*layers) 192 | 193 | def forward(self,d0,d1,eps=0.1): 194 | return self.model.forward(torch.cat((d0,d1,d0-d1,d0/(d1+eps),d1/(d0+eps)),dim=1)) 195 | 196 | class BCERankingLoss(nn.Module): 197 | def __init__(self, chn_mid=32): 198 | super(BCERankingLoss, self).__init__() 199 | self.net = Dist2LogitLayer(chn_mid=chn_mid) 200 | # self.parameters = list(self.net.parameters()) 201 | self.loss = torch.nn.BCELoss() 202 | 203 | def forward(self, d0, d1, judge): 204 | per = (judge+1.)/2. 
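# map the human judgement from [-1, 1] to a probability target in [0, 1] for the BCE loss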
205 | self.logit = self.net.forward(d0,d1) 206 | return self.loss(self.logit, per) 207 | 208 | # L2, DSSIM metrics 209 | class FakeNet(nn.Module): 210 | def __init__(self, use_gpu=True, colorspace='Lab'): 211 | super(FakeNet, self).__init__() 212 | self.use_gpu = use_gpu 213 | self.colorspace = colorspace 214 | 215 | class L2(FakeNet): 216 | def forward(self, in0, in1, retPerLayer=None): 217 | assert(in0.size()[0]==1) # currently only supports batchSize 1 218 | 219 | if(self.colorspace=='RGB'): 220 | (N,C,X,Y) = in0.size() 221 | value = torch.mean(torch.mean(torch.mean((in0-in1)**2,dim=1).view(N,1,X,Y),dim=2).view(N,1,1,Y),dim=3).view(N) 222 | return value 223 | elif(self.colorspace=='Lab'): 224 | value = utils.l2(utils.tensor2np(utils.tensor2tensorlab(in0.data,to_norm=False)), 225 | utils.tensor2np(utils.tensor2tensorlab(in1.data,to_norm=False)), range=100.).astype('float') 226 | ret_var = Variable( torch.Tensor((value,) ) ) 227 | if(self.use_gpu): 228 | ret_var = ret_var.cuda() 229 | return ret_var 230 | 231 | class DSSIM(FakeNet): 232 | 233 | def forward(self, in0, in1, retPerLayer=None): 234 | assert(in0.size()[0]==1) # currently only supports batchSize 1 235 | 236 | if(self.colorspace=='RGB'): 237 | value = utils.dssim(1.*utils.tensor2im(in0.data), 1.*utils.tensor2im(in1.data), range=255.).astype('float') 238 | elif(self.colorspace=='Lab'): 239 | value = utils.dssim(utils.tensor2np(utils.tensor2tensorlab(in0.data,to_norm=False)), 240 | utils.tensor2np(utils.tensor2tensorlab(in1.data,to_norm=False)), range=100.).astype('float') 241 | ret_var = Variable( torch.Tensor((value,) ) ) 242 | if(self.use_gpu): 243 | ret_var = ret_var.cuda() 244 | return ret_var 245 | 246 | def print_network(net): 247 | num_params = 0 248 | for param in net.parameters(): 249 | num_params += param.numel() 250 | print('Network',net) 251 | print('Total number of parameters: %d' % num_params) 252 | 253 | 254 | class FloLPIPS(LPIPS): 255 | def __init__(self, pretrained=True, net='alex', version='0.1', lpips=True, spatial=False, pnet_rand=False, pnet_tune=False, use_dropout=True, model_path=None, eval_mode=True, verbose=False): 256 | super(FloLPIPS, self).__init__(pretrained, net, version, lpips, spatial, pnet_rand, pnet_tune, use_dropout, model_path, eval_mode, verbose) 257 | 258 | def forward(self, in0, in1, flow, retPerLayer=False, normalize=False): 259 | if normalize: # turn on this flag if input is [0,1] so it can be adjusted to [-1, +1] 260 | in0 = 2 * in0 - 1 261 | in1 = 2 * in1 - 1 262 | 263 | in0_input, in1_input = (self.scaling_layer(in0), self.scaling_layer(in1)) if self.version=='0.1' else (in0, in1) 264 | outs0, outs1 = self.net.forward(in0_input), self.net.forward(in1_input) 265 | feats0, feats1, diffs = {}, {}, {} 266 | 267 | for kk in range(self.L): 268 | feats0[kk], feats1[kk] = utils.normalize_tensor(outs0[kk]), utils.normalize_tensor(outs1[kk]) 269 | diffs[kk] = (feats0[kk]-feats1[kk])**2 270 | 271 | res = [mw_spatial_average(self.lins[kk](diffs[kk]), flow, keepdim=True) for kk in range(self.L)] 272 | 273 | 274 | return torch.sum(torch.cat(res, 1), dim=(1,2,3), keepdims=False) 275 | 276 | 277 | 278 | def calc_flolpips(dis_path, ref_path): 279 | 280 | batch_size = 8 281 | 282 | # convert to yuv first 283 | os.system('ffmpeg -hide_banner -loglevel error -i {} flolpips_ref.yuv'.format(ref_path)) 284 | os.system('ffmpeg -hide_banner -loglevel error -i {} flolpips_dis.yuv'.format(dis_path)) 285 | 286 | loss_fn = FloLPIPS(net='alex',version='0.1').cuda() 287 | flownet = PWCNet().cuda() 288 | # batch_size 
= 128 289 | 290 | cap_dis = cv2.VideoCapture(dis_path) 291 | cap_ref = cv2.VideoCapture(ref_path) 292 | assert int(cap_dis.get(cv2.CAP_PROP_FRAME_COUNT)) == int(cap_ref.get(cv2.CAP_PROP_FRAME_COUNT)) 293 | num_frames = int(cap_ref.get(cv2.CAP_PROP_FRAME_COUNT)) 294 | width = int(cap_ref.get(cv2.CAP_PROP_FRAME_WIDTH)) 295 | height = int(cap_ref.get(cv2.CAP_PROP_FRAME_HEIGHT)) 296 | cap_dis.release() 297 | cap_ref.release() 298 | stream_dis = open('flolpips_dis.yuv', 'rb') # binary mode: the yuv streams are raw bytes 299 | stream_ref = open('flolpips_ref.yuv', 'rb') 300 | 301 | flolpips_list = [] 302 | batch_ref_list, batch_dis_list = [], [] 303 | batch_ref_next_list, batch_dis_next_list = [], [] 304 | for iFrame in range(num_frames-1): 305 | frame_dis = TF.to_tensor(utils.read_frame_yuv2rgb(stream_dis, width, height, iFrame, 8, '420')) 306 | frame_dis_next = TF.to_tensor(utils.read_frame_yuv2rgb(stream_dis, width, height, iFrame+1, 8, '420')) 307 | frame_ref = TF.to_tensor(utils.read_frame_yuv2rgb(stream_ref, width, height, iFrame, 8, '420')) 308 | frame_ref_next = TF.to_tensor(utils.read_frame_yuv2rgb(stream_ref, width, height, iFrame+1, 8, '420')) 309 | batch_dis_list.append(frame_dis) 310 | batch_dis_next_list.append(frame_dis_next) 311 | batch_ref_list.append(frame_ref) 312 | batch_ref_next_list.append(frame_ref_next) 313 | if len(batch_ref_list) % batch_size == 0: 314 | with torch.no_grad(): 315 | frames_ref = torch.stack(batch_ref_list, dim=0).cuda() 316 | frames_dis = torch.stack(batch_dis_list, dim=0).cuda() 317 | frames_ref_next = torch.stack(batch_ref_next_list, dim=0).cuda() 318 | frames_dis_next = torch.stack(batch_dis_next_list, dim=0).cuda() 319 | flow_ref = flownet(frames_ref, frames_ref_next) 320 | flow_dis = flownet(frames_dis, frames_dis_next) 321 | flow_diff = flow_ref - flow_dis 322 | flolpips = loss_fn.forward(frames_ref, frames_dis, flow_diff, normalize=True) 323 | batch_ref_list, batch_dis_list, batch_ref_next_list, batch_dis_next_list = [], [], [], [] 324 | flolpips_list = flolpips_list + flolpips.cpu().tolist() 325 | if len(batch_ref_list) > 0: 326 | with torch.no_grad(): 327 | frames_ref = torch.stack(batch_ref_list, dim=0).cuda() 328 | frames_dis = torch.stack(batch_dis_list, dim=0).cuda() 329 | frames_ref_next = torch.stack(batch_ref_next_list, dim=0).cuda() 330 | frames_dis_next = torch.stack(batch_dis_next_list, dim=0).cuda() 331 | flow_ref = flownet(frames_ref, frames_ref_next) 332 | flow_dis = flownet(frames_dis, frames_dis_next) 333 | flow_diff = flow_ref - flow_dis 334 | flolpips = loss_fn.forward(frames_ref, frames_dis, flow_diff, normalize=True) 335 | flolpips_list = flolpips_list + flolpips.cpu().tolist() 336 | 337 | stream_dis.close() 338 | stream_ref.close() 339 | 340 | # delete the temporary yuv files 341 | os.remove('flolpips_dis.yuv') 342 | os.remove('flolpips_ref.yuv') 343 | 344 | return np.mean(flolpips_list) 345 | 346 | 347 | class Flolpips(nn.Module): 348 | def __init__(self): 349 | super(Flolpips, self).__init__() 350 | self.loss_fn = FloLPIPS(net='alex',version='0.1') 351 | self.flownet = PWCNet() 352 | 353 | @torch.no_grad() 354 | def forward(self, I0, I1, frame_dis, frame_ref): 355 | """ 356 | args: 357 | I0: first frame of the triplet, shape: [B, C, H, W] 358 | I1: third frame of the triplet, shape: [B, C, H, W] 359 | frame_dis: prediction of the intermediate frame, shape: [B, C, H, W] 360 | frame_ref: ground-truth of the intermediate frame, shape: [B, C, H, W] 361 | """ 362 | assert I0.size() == I1.size() == frame_dis.size() == frame_ref.size(), \ 363 | "the 4 input tensors should have the same size" 364 | 365 | flow_ref = 
self.flownet(frame_ref, I0) 366 | flow_dis = self.flownet(frame_dis, I0) 367 | flow_diff = flow_ref - flow_dis 368 | flolpips_wrt_I0 = self.loss_fn.forward(frame_ref, frame_dis, flow_diff, normalize=True) 369 | 370 | flow_ref = self.flownet(frame_ref, I1) 371 | flow_dis = self.flownet(frame_dis, I1) 372 | flow_diff = flow_ref - flow_dis 373 | flolpips_wrt_I1 = self.loss_fn.forward(frame_ref, frame_dis, flow_diff, normalize=True) 374 | 375 | flolpips = (flolpips_wrt_I0 + flolpips_wrt_I1) / 2 376 | return flolpips -------------------------------------------------------------------------------- /pretrained_networks.py: -------------------------------------------------------------------------------- 1 | from collections import namedtuple 2 | import torch 3 | from torchvision import models as tv 4 | 5 | class squeezenet(torch.nn.Module): 6 | def __init__(self, requires_grad=False, pretrained=True): 7 | super(squeezenet, self).__init__() 8 | pretrained_features = tv.squeezenet1_1(pretrained=pretrained).features 9 | self.slice1 = torch.nn.Sequential() 10 | self.slice2 = torch.nn.Sequential() 11 | self.slice3 = torch.nn.Sequential() 12 | self.slice4 = torch.nn.Sequential() 13 | self.slice5 = torch.nn.Sequential() 14 | self.slice6 = torch.nn.Sequential() 15 | self.slice7 = torch.nn.Sequential() 16 | self.N_slices = 7 17 | for x in range(2): 18 | self.slice1.add_module(str(x), pretrained_features[x]) 19 | for x in range(2,5): 20 | self.slice2.add_module(str(x), pretrained_features[x]) 21 | for x in range(5, 8): 22 | self.slice3.add_module(str(x), pretrained_features[x]) 23 | for x in range(8, 10): 24 | self.slice4.add_module(str(x), pretrained_features[x]) 25 | for x in range(10, 11): 26 | self.slice5.add_module(str(x), pretrained_features[x]) 27 | for x in range(11, 12): 28 | self.slice6.add_module(str(x), pretrained_features[x]) 29 | for x in range(12, 13): 30 | self.slice7.add_module(str(x), pretrained_features[x]) 31 | if not requires_grad: 32 | for param in self.parameters(): 33 | param.requires_grad = False 34 | 35 | def forward(self, X): 36 | h = self.slice1(X) 37 | h_relu1 = h 38 | h = self.slice2(h) 39 | h_relu2 = h 40 | h = self.slice3(h) 41 | h_relu3 = h 42 | h = self.slice4(h) 43 | h_relu4 = h 44 | h = self.slice5(h) 45 | h_relu5 = h 46 | h = self.slice6(h) 47 | h_relu6 = h 48 | h = self.slice7(h) 49 | h_relu7 = h 50 | vgg_outputs = namedtuple("SqueezeOutputs", ['relu1','relu2','relu3','relu4','relu5','relu6','relu7']) 51 | out = vgg_outputs(h_relu1,h_relu2,h_relu3,h_relu4,h_relu5,h_relu6,h_relu7) 52 | 53 | return out 54 | 55 | 56 | class alexnet(torch.nn.Module): 57 | def __init__(self, requires_grad=False, pretrained=True): 58 | super(alexnet, self).__init__() 59 | alexnet_pretrained_features = tv.alexnet(pretrained=pretrained).features 60 | self.slice1 = torch.nn.Sequential() 61 | self.slice2 = torch.nn.Sequential() 62 | self.slice3 = torch.nn.Sequential() 63 | self.slice4 = torch.nn.Sequential() 64 | self.slice5 = torch.nn.Sequential() 65 | self.N_slices = 5 66 | for x in range(2): 67 | self.slice1.add_module(str(x), alexnet_pretrained_features[x]) 68 | for x in range(2, 5): 69 | self.slice2.add_module(str(x), alexnet_pretrained_features[x]) 70 | for x in range(5, 8): 71 | self.slice3.add_module(str(x), alexnet_pretrained_features[x]) 72 | for x in range(8, 10): 73 | self.slice4.add_module(str(x), alexnet_pretrained_features[x]) 74 | for x in range(10, 12): 75 | self.slice5.add_module(str(x), alexnet_pretrained_features[x]) 76 | if not requires_grad: 77 | for param in 
self.parameters(): 78 | param.requires_grad = False 79 | 80 | def forward(self, X): 81 | h = self.slice1(X) 82 | h_relu1 = h 83 | h = self.slice2(h) 84 | h_relu2 = h 85 | h = self.slice3(h) 86 | h_relu3 = h 87 | h = self.slice4(h) 88 | h_relu4 = h 89 | h = self.slice5(h) 90 | h_relu5 = h 91 | alexnet_outputs = namedtuple("AlexnetOutputs", ['relu1', 'relu2', 'relu3', 'relu4', 'relu5']) 92 | out = alexnet_outputs(h_relu1, h_relu2, h_relu3, h_relu4, h_relu5) 93 | 94 | return out 95 | 96 | class vgg16(torch.nn.Module): 97 | def __init__(self, requires_grad=False, pretrained=True): 98 | super(vgg16, self).__init__() 99 | vgg_pretrained_features = tv.vgg16(pretrained=pretrained).features 100 | self.slice1 = torch.nn.Sequential() 101 | self.slice2 = torch.nn.Sequential() 102 | self.slice3 = torch.nn.Sequential() 103 | self.slice4 = torch.nn.Sequential() 104 | self.slice5 = torch.nn.Sequential() 105 | self.N_slices = 5 106 | for x in range(4): 107 | self.slice1.add_module(str(x), vgg_pretrained_features[x]) 108 | for x in range(4, 9): 109 | self.slice2.add_module(str(x), vgg_pretrained_features[x]) 110 | for x in range(9, 16): 111 | self.slice3.add_module(str(x), vgg_pretrained_features[x]) 112 | for x in range(16, 23): 113 | self.slice4.add_module(str(x), vgg_pretrained_features[x]) 114 | for x in range(23, 30): 115 | self.slice5.add_module(str(x), vgg_pretrained_features[x]) 116 | if not requires_grad: 117 | for param in self.parameters(): 118 | param.requires_grad = False 119 | 120 | def forward(self, X): 121 | h = self.slice1(X) 122 | h_relu1_2 = h 123 | h = self.slice2(h) 124 | h_relu2_2 = h 125 | h = self.slice3(h) 126 | h_relu3_3 = h 127 | h = self.slice4(h) 128 | h_relu4_3 = h 129 | h = self.slice5(h) 130 | h_relu5_3 = h 131 | vgg_outputs = namedtuple("VggOutputs", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3']) 132 | out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3) 133 | 134 | return out 135 | 136 | 137 | 138 | class resnet(torch.nn.Module): 139 | def __init__(self, requires_grad=False, pretrained=True, num=18): 140 | super(resnet, self).__init__() 141 | if(num==18): 142 | self.net = tv.resnet18(pretrained=pretrained) 143 | elif(num==34): 144 | self.net = tv.resnet34(pretrained=pretrained) 145 | elif(num==50): 146 | self.net = tv.resnet50(pretrained=pretrained) 147 | elif(num==101): 148 | self.net = tv.resnet101(pretrained=pretrained) 149 | elif(num==152): 150 | self.net = tv.resnet152(pretrained=pretrained) 151 | self.N_slices = 5 152 | 153 | self.conv1 = self.net.conv1 154 | self.bn1 = self.net.bn1 155 | self.relu = self.net.relu 156 | self.maxpool = self.net.maxpool 157 | self.layer1 = self.net.layer1 158 | self.layer2 = self.net.layer2 159 | self.layer3 = self.net.layer3 160 | self.layer4 = self.net.layer4 161 | 162 | def forward(self, X): 163 | h = self.conv1(X) 164 | h = self.bn1(h) 165 | h = self.relu(h) 166 | h_relu1 = h 167 | h = self.maxpool(h) 168 | h = self.layer1(h) 169 | h_conv2 = h 170 | h = self.layer2(h) 171 | h_conv3 = h 172 | h = self.layer3(h) 173 | h_conv4 = h 174 | h = self.layer4(h) 175 | h_conv5 = h 176 | 177 | outputs = namedtuple("Outputs", ['relu1','conv2','conv3','conv4','conv5']) 178 | out = outputs(h_relu1, h_conv2, h_conv3, h_conv4, h_conv5) 179 | 180 | return out 181 | -------------------------------------------------------------------------------- /pwcnet.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import torch 4 | 5 | import getopt 6 | import math 7 
| import numpy 8 | import os 9 | import PIL 10 | import PIL.Image 11 | import sys 12 | 13 | # try: 14 | from correlation import correlation # the custom cost volume layer 15 | # except: 16 | # sys.path.insert(0, './correlation'); import correlation # you should consider upgrading python 17 | # end 18 | 19 | ########################################################## 20 | 21 | # assert(int(str('').join(torch.__version__.split('.')[0:2])) >= 13) # requires at least pytorch version 1.3.0 22 | 23 | # torch.set_grad_enabled(False) # make sure to not compute gradients for computational performance 24 | 25 | # torch.backends.cudnn.enabled = True # make sure to use cudnn for computational performance 26 | 27 | # ########################################################## 28 | 29 | # arguments_strModel = 'default' # 'default', or 'chairs-things' 30 | # arguments_strFirst = './images/first.png' 31 | # arguments_strSecond = './images/second.png' 32 | # arguments_strOut = './out.flo' 33 | 34 | # for strOption, strArgument in getopt.getopt(sys.argv[1:], '', [ strParameter[2:] + '=' for strParameter in sys.argv[1::2] ])[0]: 35 | # if strOption == '--model' and strArgument != '': arguments_strModel = strArgument # which model to use 36 | # if strOption == '--first' and strArgument != '': arguments_strFirst = strArgument # path to the first frame 37 | # if strOption == '--second' and strArgument != '': arguments_strSecond = strArgument # path to the second frame 38 | # if strOption == '--out' and strArgument != '': arguments_strOut = strArgument # path to where the output should be stored 39 | # end 40 | 41 | ########################################################## 42 | 43 | 44 | 45 | def backwarp(tenInput, tenFlow): 46 | backwarp_tenGrid = {} 47 | backwarp_tenPartial = {} 48 | if str(tenFlow.shape) not in backwarp_tenGrid: 49 | tenHor = torch.linspace(-1.0 + (1.0 / tenFlow.shape[3]), 1.0 - (1.0 / tenFlow.shape[3]), tenFlow.shape[3]).view(1, 1, 1, -1).expand(-1, -1, tenFlow.shape[2], -1) 50 | tenVer = torch.linspace(-1.0 + (1.0 / tenFlow.shape[2]), 1.0 - (1.0 / tenFlow.shape[2]), tenFlow.shape[2]).view(1, 1, -1, 1).expand(-1, -1, -1, tenFlow.shape[3]) 51 | 52 | backwarp_tenGrid[str(tenFlow.shape)] = torch.cat([ tenHor, tenVer ], 1).cuda() 53 | # end 54 | 55 | if str(tenFlow.shape) not in backwarp_tenPartial: 56 | backwarp_tenPartial[str(tenFlow.shape)] = tenFlow.new_ones([ tenFlow.shape[0], 1, tenFlow.shape[2], tenFlow.shape[3] ]) 57 | # end 58 | 59 | tenFlow = torch.cat([ tenFlow[:, 0:1, :, :] / ((tenInput.shape[3] - 1.0) / 2.0), tenFlow[:, 1:2, :, :] / ((tenInput.shape[2] - 1.0) / 2.0) ], 1) 60 | tenInput = torch.cat([ tenInput, backwarp_tenPartial[str(tenFlow.shape)] ], 1) 61 | 62 | tenOutput = torch.nn.functional.grid_sample(input=tenInput, grid=(backwarp_tenGrid[str(tenFlow.shape)] + tenFlow).permute(0, 2, 3, 1), mode='bilinear', padding_mode='zeros', align_corners=False) 63 | 64 | tenMask = tenOutput[:, -1:, :, :]; tenMask[tenMask > 0.999] = 1.0; tenMask[tenMask < 1.0] = 0.0 65 | 66 | return tenOutput[:, :-1, :, :] * tenMask 67 | # end 68 | 69 | ########################################################## 70 | 71 | class Network(torch.nn.Module): 72 | def __init__(self): 73 | super(Network, self).__init__() 74 | 75 | class Extractor(torch.nn.Module): 76 | def __init__(self): 77 | super(Extractor, self).__init__() 78 | 79 | self.netOne = torch.nn.Sequential( 80 | torch.nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, stride=2, padding=1), 81 | torch.nn.LeakyReLU(inplace=False, 
negative_slope=0.1), 82 | torch.nn.Conv2d(in_channels=16, out_channels=16, kernel_size=3, stride=1, padding=1), 83 | torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), 84 | torch.nn.Conv2d(in_channels=16, out_channels=16, kernel_size=3, stride=1, padding=1), 85 | torch.nn.LeakyReLU(inplace=False, negative_slope=0.1) 86 | ) 87 | 88 | self.netTwo = torch.nn.Sequential( 89 | torch.nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, stride=2, padding=1), 90 | torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), 91 | torch.nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, stride=1, padding=1), 92 | torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), 93 | torch.nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, stride=1, padding=1), 94 | torch.nn.LeakyReLU(inplace=False, negative_slope=0.1) 95 | ) 96 | 97 | self.netThr = torch.nn.Sequential( 98 | torch.nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=2, padding=1), 99 | torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), 100 | torch.nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1), 101 | torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), 102 | torch.nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1), 103 | torch.nn.LeakyReLU(inplace=False, negative_slope=0.1) 104 | ) 105 | 106 | self.netFou = torch.nn.Sequential( 107 | torch.nn.Conv2d(in_channels=64, out_channels=96, kernel_size=3, stride=2, padding=1), 108 | torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), 109 | torch.nn.Conv2d(in_channels=96, out_channels=96, kernel_size=3, stride=1, padding=1), 110 | torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), 111 | torch.nn.Conv2d(in_channels=96, out_channels=96, kernel_size=3, stride=1, padding=1), 112 | torch.nn.LeakyReLU(inplace=False, negative_slope=0.1) 113 | ) 114 | 115 | self.netFiv = torch.nn.Sequential( 116 | torch.nn.Conv2d(in_channels=96, out_channels=128, kernel_size=3, stride=2, padding=1), 117 | torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), 118 | torch.nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=1), 119 | torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), 120 | torch.nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=1), 121 | torch.nn.LeakyReLU(inplace=False, negative_slope=0.1) 122 | ) 123 | 124 | self.netSix = torch.nn.Sequential( 125 | torch.nn.Conv2d(in_channels=128, out_channels=196, kernel_size=3, stride=2, padding=1), 126 | torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), 127 | torch.nn.Conv2d(in_channels=196, out_channels=196, kernel_size=3, stride=1, padding=1), 128 | torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), 129 | torch.nn.Conv2d(in_channels=196, out_channels=196, kernel_size=3, stride=1, padding=1), 130 | torch.nn.LeakyReLU(inplace=False, negative_slope=0.1) 131 | ) 132 | # end 133 | 134 | def forward(self, tenInput): 135 | tenOne = self.netOne(tenInput) 136 | tenTwo = self.netTwo(tenOne) 137 | tenThr = self.netThr(tenTwo) 138 | tenFou = self.netFou(tenThr) 139 | tenFiv = self.netFiv(tenFou) 140 | tenSix = self.netSix(tenFiv) 141 | 142 | return [ tenOne, tenTwo, tenThr, tenFou, tenFiv, tenSix ] 143 | # end 144 | # end 145 | 146 | class Decoder(torch.nn.Module): 147 | def __init__(self, intLevel): 148 | super(Decoder, self).__init__() 149 | 150 | intPrevious = [ None, None, 81 + 32 + 2 + 2, 81 + 64 + 2 + 2, 81 + 96 + 2 + 2, 81 + 128 + 2 + 2, 81, None ][intLevel + 1] 151 | intCurrent = [ None, None, 81 + 32 + 2 + 2, 
81 + 64 + 2 + 2, 81 + 96 + 2 + 2, 81 + 128 + 2 + 2, 81, None ][intLevel + 0] 152 | 153 | if intLevel < 6: self.netUpflow = torch.nn.ConvTranspose2d(in_channels=2, out_channels=2, kernel_size=4, stride=2, padding=1) 154 | if intLevel < 6: self.netUpfeat = torch.nn.ConvTranspose2d(in_channels=intPrevious + 128 + 128 + 96 + 64 + 32, out_channels=2, kernel_size=4, stride=2, padding=1) 155 | if intLevel < 6: self.fltBackwarp = [ None, None, None, 5.0, 2.5, 1.25, 0.625, None ][intLevel + 1] 156 | 157 | self.netOne = torch.nn.Sequential( 158 | torch.nn.Conv2d(in_channels=intCurrent, out_channels=128, kernel_size=3, stride=1, padding=1), 159 | torch.nn.LeakyReLU(inplace=False, negative_slope=0.1) 160 | ) 161 | 162 | self.netTwo = torch.nn.Sequential( 163 | torch.nn.Conv2d(in_channels=intCurrent + 128, out_channels=128, kernel_size=3, stride=1, padding=1), 164 | torch.nn.LeakyReLU(inplace=False, negative_slope=0.1) 165 | ) 166 | 167 | self.netThr = torch.nn.Sequential( 168 | torch.nn.Conv2d(in_channels=intCurrent + 128 + 128, out_channels=96, kernel_size=3, stride=1, padding=1), 169 | torch.nn.LeakyReLU(inplace=False, negative_slope=0.1) 170 | ) 171 | 172 | self.netFou = torch.nn.Sequential( 173 | torch.nn.Conv2d(in_channels=intCurrent + 128 + 128 + 96, out_channels=64, kernel_size=3, stride=1, padding=1), 174 | torch.nn.LeakyReLU(inplace=False, negative_slope=0.1) 175 | ) 176 | 177 | self.netFiv = torch.nn.Sequential( 178 | torch.nn.Conv2d(in_channels=intCurrent + 128 + 128 + 96 + 64, out_channels=32, kernel_size=3, stride=1, padding=1), 179 | torch.nn.LeakyReLU(inplace=False, negative_slope=0.1) 180 | ) 181 | 182 | self.netSix = torch.nn.Sequential( 183 | torch.nn.Conv2d(in_channels=intCurrent + 128 + 128 + 96 + 64 + 32, out_channels=2, kernel_size=3, stride=1, padding=1) 184 | ) 185 | # end 186 | 187 | def forward(self, tenFirst, tenSecond, objPrevious): 188 | tenFlow = None 189 | tenFeat = None 190 | 191 | if objPrevious is None: 192 | tenFlow = None 193 | tenFeat = None 194 | 195 | tenVolume = torch.nn.functional.leaky_relu(input=correlation.FunctionCorrelation(tenFirst=tenFirst, tenSecond=tenSecond), negative_slope=0.1, inplace=False) 196 | 197 | tenFeat = torch.cat([ tenVolume ], 1) 198 | 199 | elif objPrevious is not None: 200 | tenFlow = self.netUpflow(objPrevious['tenFlow']) 201 | tenFeat = self.netUpfeat(objPrevious['tenFeat']) 202 | 203 | tenVolume = torch.nn.functional.leaky_relu(input=correlation.FunctionCorrelation(tenFirst=tenFirst, tenSecond=backwarp(tenInput=tenSecond, tenFlow=tenFlow * self.fltBackwarp)), negative_slope=0.1, inplace=False) 204 | 205 | tenFeat = torch.cat([ tenVolume, tenFirst, tenFlow, tenFeat ], 1) 206 | 207 | # end 208 | 209 | tenFeat = torch.cat([ self.netOne(tenFeat), tenFeat ], 1) 210 | tenFeat = torch.cat([ self.netTwo(tenFeat), tenFeat ], 1) 211 | tenFeat = torch.cat([ self.netThr(tenFeat), tenFeat ], 1) 212 | tenFeat = torch.cat([ self.netFou(tenFeat), tenFeat ], 1) 213 | tenFeat = torch.cat([ self.netFiv(tenFeat), tenFeat ], 1) 214 | 215 | tenFlow = self.netSix(tenFeat) 216 | 217 | return { 218 | 'tenFlow': tenFlow, 219 | 'tenFeat': tenFeat 220 | } 221 | # end 222 | # end 223 | 224 | class Refiner(torch.nn.Module): 225 | def __init__(self): 226 | super(Refiner, self).__init__() 227 | 228 | self.netMain = torch.nn.Sequential( 229 | torch.nn.Conv2d(in_channels=81 + 32 + 2 + 2 + 128 + 128 + 96 + 64 + 32, out_channels=128, kernel_size=3, stride=1, padding=1, dilation=1), 230 | torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), 231 | 
torch.nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=2, dilation=2), 232 | torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), 233 | torch.nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=4, dilation=4), 234 | torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), 235 | torch.nn.Conv2d(in_channels=128, out_channels=96, kernel_size=3, stride=1, padding=8, dilation=8), 236 | torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), 237 | torch.nn.Conv2d(in_channels=96, out_channels=64, kernel_size=3, stride=1, padding=16, dilation=16), 238 | torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), 239 | torch.nn.Conv2d(in_channels=64, out_channels=32, kernel_size=3, stride=1, padding=1, dilation=1), 240 | torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), 241 | torch.nn.Conv2d(in_channels=32, out_channels=2, kernel_size=3, stride=1, padding=1, dilation=1) 242 | ) 243 | # end 244 | 245 | def forward(self, tenInput): 246 | return self.netMain(tenInput) 247 | # end 248 | # end 249 | 250 | self.netExtractor = Extractor() 251 | 252 | self.netTwo = Decoder(2) 253 | self.netThr = Decoder(3) 254 | self.netFou = Decoder(4) 255 | self.netFiv = Decoder(5) 256 | self.netSix = Decoder(6) 257 | 258 | self.netRefiner = Refiner() 259 | 260 | self.load_state_dict({ strKey.replace('module', 'net'): tenWeight for strKey, tenWeight in torch.hub.load_state_dict_from_url(url='http://content.sniklaus.com/github/pytorch-pwc/network-' + 'default' + '.pytorch').items() }) 261 | # end 262 | 263 | def forward(self, tenFirst, tenSecond): 264 | intWidth = tenFirst.shape[3] 265 | intHeight = tenFirst.shape[2] 266 | 267 | intPreprocessedWidth = int(math.floor(math.ceil(intWidth / 64.0) * 64.0)) 268 | intPreprocessedHeight = int(math.floor(math.ceil(intHeight / 64.0) * 64.0)) 269 | 270 | tenPreprocessedFirst = torch.nn.functional.interpolate(input=tenFirst, size=(intPreprocessedHeight, intPreprocessedWidth), mode='bilinear', align_corners=False) 271 | tenPreprocessedSecond = torch.nn.functional.interpolate(input=tenSecond, size=(intPreprocessedHeight, intPreprocessedWidth), mode='bilinear', align_corners=False) 272 | 273 | tenFirst = self.netExtractor(tenPreprocessedFirst) 274 | tenSecond = self.netExtractor(tenPreprocessedSecond) 275 | 276 | 277 | objEstimate = self.netSix(tenFirst[-1], tenSecond[-1], None) 278 | objEstimate = self.netFiv(tenFirst[-2], tenSecond[-2], objEstimate) 279 | objEstimate = self.netFou(tenFirst[-3], tenSecond[-3], objEstimate) 280 | objEstimate = self.netThr(tenFirst[-4], tenSecond[-4], objEstimate) 281 | objEstimate = self.netTwo(tenFirst[-5], tenSecond[-5], objEstimate) 282 | 283 | tenFlow = objEstimate['tenFlow'] + self.netRefiner(objEstimate['tenFeat']) 284 | tenFlow = 20.0 * torch.nn.functional.interpolate(input=tenFlow, size=(intHeight, intWidth), mode='bilinear', align_corners=False) 285 | tenFlow[:, 0, :, :] *= float(intWidth) / float(intPreprocessedWidth) 286 | tenFlow[:, 1, :, :] *= float(intHeight) / float(intPreprocessedHeight) 287 | 288 | return tenFlow 289 | # end 290 | # end 291 | 292 | netNetwork = None 293 | 294 | ########################################################## 295 | 296 | def estimate(tenFirst, tenSecond): 297 | global netNetwork 298 | 299 | if netNetwork is None: 300 | netNetwork = Network().cuda().eval() 301 | # end 302 | 303 | assert(tenFirst.shape[1] == tenSecond.shape[1]) 304 | assert(tenFirst.shape[2] == tenSecond.shape[2]) 305 | 306 | intWidth = tenFirst.shape[2] 307 | intHeight = tenFirst.shape[1] 
308 | 309 | assert(intWidth == 1024) # remember that there is no guarantee for correctness, comment this line out if you acknowledge this and want to continue 310 | assert(intHeight == 436) # remember that there is no guarantee for correctness, comment this line out if you acknowledge this and want to continue 311 | 312 | tenPreprocessedFirst = tenFirst.cuda().view(1, 3, intHeight, intWidth) 313 | tenPreprocessedSecond = tenSecond.cuda().view(1, 3, intHeight, intWidth) 314 | 315 | intPreprocessedWidth = int(math.floor(math.ceil(intWidth / 64.0) * 64.0)) 316 | intPreprocessedHeight = int(math.floor(math.ceil(intHeight / 64.0) * 64.0)) 317 | 318 | tenPreprocessedFirst = torch.nn.functional.interpolate(input=tenPreprocessedFirst, size=(intPreprocessedHeight, intPreprocessedWidth), mode='bilinear', align_corners=False) 319 | tenPreprocessedSecond = torch.nn.functional.interpolate(input=tenPreprocessedSecond, size=(intPreprocessedHeight, intPreprocessedWidth), mode='bilinear', align_corners=False) 320 | 321 | tenFlow = 20.0 * torch.nn.functional.interpolate(input=netNetwork(tenPreprocessedFirst, tenPreprocessedSecond), size=(intHeight, intWidth), mode='bilinear', align_corners=False) 322 | 323 | tenFlow[:, 0, :, :] *= float(intWidth) / float(intPreprocessedWidth) 324 | tenFlow[:, 1, :, :] *= float(intHeight) / float(intPreprocessedHeight) 325 | 326 | return tenFlow[0, :, :, :].cpu() 327 | # end 328 | 329 | ########################################################## 330 | 331 | # if __name__ == '__main__': 332 | # tenFirst = torch.FloatTensor(numpy.ascontiguousarray(numpy.array(PIL.Image.open(arguments_strFirst))[:, :, ::-1].transpose(2, 0, 1).astype(numpy.float32) * (1.0 / 255.0))) 333 | # tenSecond = torch.FloatTensor(numpy.ascontiguousarray(numpy.array(PIL.Image.open(arguments_strSecond))[:, :, ::-1].transpose(2, 0, 1).astype(numpy.float32) * (1.0 / 255.0))) 334 | 335 | # tenOutput = estimate(tenFirst, tenSecond) 336 | 337 | # objOutput = open(arguments_strOut, 'wb') 338 | 339 | # numpy.array([ 80, 73, 69, 72 ], numpy.uint8).tofile(objOutput) 340 | # numpy.array([ tenOutput.shape[2], tenOutput.shape[1] ], numpy.int32).tofile(objOutput) 341 | # numpy.array(tenOutput.numpy().transpose(1, 2, 0), numpy.float32).tofile(objOutput) 342 | 343 | # objOutput.close() 344 | # end -------------------------------------------------------------------------------- /sample_script.py: -------------------------------------------------------------------------------- 1 | from flolpips import Flolpips 2 | import torch 3 | 4 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 5 | eval_metric = Flolpips().to(device) 6 | 7 | batch = 8 8 | I0 = torch.rand(batch, 3, 256, 448).to(device) # first frame of the triplet 9 | I1 = torch.rand(batch, 3, 256, 448).to(device) # third frame of the triplet 10 | frame_dis = torch.rand(batch, 3, 256, 448).to(device) # prediction of the intermediate frame 11 | frame_ref = torch.rand(batch, 3, 256, 448).to(device) # ground-truth of the intermediate frame 12 | 13 | flolpips = eval_metric.forward(I0, I1, frame_dis, frame_ref) 14 | print(flolpips) -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import cv2 3 | import torch 4 | 5 | 6 | def normalize_tensor(in_feat,eps=1e-10): 7 | norm_factor = torch.sqrt(torch.sum(in_feat**2,dim=1,keepdim=True)) 8 | return in_feat/(norm_factor+eps) 9 | 10 | def l2(p0, p1, range=255.): 11 | return .5*np.mean((p0 / range - p1 / range)**2) 12 | 13 | def dssim(p0, p1, range=255.): 14 | from skimage.measure import compare_ssim 15 | return (1 - 
compare_ssim(p0, p1, data_range=range, multichannel=True)) / 2. 16 | 17 | def tensor2im(image_tensor, imtype=np.uint8, cent=1., factor=255./2.): 18 | image_numpy = image_tensor[0].cpu().float().numpy() 19 | image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + cent) * factor 20 | return image_numpy.astype(imtype) 21 | 22 | def tensor2np(tensor_obj): 23 | # change dimension of a tensor object into a numpy array 24 | return tensor_obj[0].cpu().float().numpy().transpose((1,2,0)) 25 | 26 | def np2tensor(np_obj): 27 | # change dimension of np array into tensor array 28 | return torch.Tensor(np_obj[:, :, :, np.newaxis].transpose((3, 2, 0, 1))) 29 | 30 | def tensor2tensorlab(image_tensor,to_norm=True,mc_only=False): 31 | # image tensor to lab tensor 32 | from skimage import color 33 | 34 | img = tensor2im(image_tensor) 35 | img_lab = color.rgb2lab(img) 36 | if(mc_only): 37 | img_lab[:,:,0] = img_lab[:,:,0]-50 38 | if(to_norm and not mc_only): 39 | img_lab[:,:,0] = img_lab[:,:,0]-50 40 | img_lab = img_lab/100. 41 | 42 | return np2tensor(img_lab) 43 | 44 | def read_frame_yuv2rgb(stream, width, height, iFrame, bit_depth, pix_fmt='420'): 45 | if pix_fmt == '420': 46 | multiplier = 1 47 | uv_factor = 2 48 | elif pix_fmt == '444': 49 | multiplier = 2 50 | uv_factor = 1 51 | else: 52 | print('Pixel format {} is not supported'.format(pix_fmt)) 53 | return 54 | 55 | if bit_depth == 8: 56 | datatype = np.uint8 57 | stream.seek(int(iFrame*1.5*width*height*multiplier)) # file offsets must be integers 58 | Y = np.fromfile(stream, dtype=datatype, count=width*height).reshape((height, width)) 59 | 60 | # read chroma samples (cv2.cvtColor below handles any 4:2:0 upsampling) 61 | U = np.fromfile(stream, dtype=datatype, count=(width//uv_factor)*(height//uv_factor)).\ 62 | reshape((height//uv_factor, width//uv_factor)) 63 | V = np.fromfile(stream, dtype=datatype, count=(width//uv_factor)*(height//uv_factor)).\ 64 | reshape((height//uv_factor, width//uv_factor)) 65 | 66 | else: 67 | datatype = np.uint16 68 | stream.seek(iFrame*3*width*height*multiplier) 69 | Y = np.fromfile(stream, dtype=datatype, count=width*height).reshape((height, width)) 70 | 71 | U = np.fromfile(stream, dtype=datatype, count=(width//uv_factor)*(height//uv_factor)).\ 72 | reshape((height//uv_factor, width//uv_factor)) 73 | V = np.fromfile(stream, dtype=datatype, count=(width//uv_factor)*(height//uv_factor)).\ 74 | reshape((height//uv_factor, width//uv_factor)) 75 | 76 | if pix_fmt == '420': 77 | yuv = np.empty((height*3//2, width), dtype=datatype) 78 | yuv[0:height,:] = Y 79 | 80 | yuv[height:height+height//4,:] = U.reshape(-1, width) 81 | yuv[height+height//4:,:] = V.reshape(-1, width) 82 | 83 | if bit_depth != 8: 84 | yuv = (yuv/(2**bit_depth-1)*255).astype(np.uint8) 85 | 86 | #convert to rgb 87 | rgb = cv2.cvtColor(yuv, cv2.COLOR_YUV2RGB_I420) 88 | 89 | else: 90 | yvu = np.stack([Y,V,U],axis=2) 91 | if bit_depth != 8: 92 | yvu = (yvu/(2**bit_depth-1)*255).astype(np.uint8) 93 | rgb = cv2.cvtColor(yvu, cv2.COLOR_YCrCb2RGB) 94 | 95 | return rgb 96 | -------------------------------------------------------------------------------- /weights/v0.1/alex.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/danier97/flolpips/66dd39937e961c56cf162651074eacbfb5f9aab1/weights/v0.1/alex.pth --------------------------------------------------------------------------------
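For quick reference, the step that distinguishes FloLPIPS from standard LPIPS is the pooling performed by `mw_spatial_average` in flolpips.py: per-pixel LPIPS distance maps are averaged with weights proportional to the magnitude of the flow difference between the reference and distorted videos, rather than uniformly. Below is a minimal, self-contained sketch of that pooling step; the function name and the toy check are illustrative and not part of the repository.

```python
import torch
import torch.nn.functional as F

def motion_weighted_average(dist_map, flow_diff):
    """Pool a per-pixel distance map [B,1,h,w] with weights proportional to
    the magnitude of a flow difference [B,2,H,W] (illustrative name)."""
    _, _, h, w = dist_map.shape
    # resize the flow difference to the resolution of this feature level
    flow = F.interpolate(flow_diff, (h, w), mode='bilinear', align_corners=False)
    mag = torch.sqrt(flow[:, 0:1] ** 2 + flow[:, 1:2] ** 2)  # motion-error magnitude
    # normalise the weights to sum to 1 per sample
    weight = mag / torch.sum(mag, dim=[1, 2, 3], keepdim=True)
    return torch.sum(dist_map * weight, dim=[2, 3], keepdim=True)

# sanity check: a constant flow difference reduces to the plain spatial mean
d = torch.rand(2, 1, 8, 8)
f = torch.ones(2, 2, 32, 32)
print(torch.allclose(motion_weighted_average(d, f).flatten(), d.mean(dim=[2, 3]).flatten()))
```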