├── Figures └── M-LVC.PNG ├── RD_Results ├── ClassB_MSSSIM.png ├── ClassB_PSNR.png ├── ClassC_MSSSIM.png ├── ClassC_PSNR.png ├── ClassD_MSSSIM.png ├── ClassD_PSNR.png ├── ClassE_MSSSIM.png ├── ClassE_PSNR.png ├── UVG_MSSSIM.png ├── UVG_PSNR.png └── data.txt ├── README.md ├── SDL.dll ├── SDL_image.dll ├── bpgdec.exe ├── bpgenc.exe ├── exp_data_dir ├── I_frames_enc │ ├── RaceHorsesC_832x448.png │ ├── dec_RaceHorsesC_832x448_30_qp21.png │ └── enc_RaceHorsesC_832x448_30_qp21.bpg └── figure │ └── ClassC_RaceHorsesC_curidx1.png ├── flow_utils.py ├── flownet_models.py ├── helper.py ├── libgcc_s_seh-1.dll ├── libjpeg-62.dll ├── libpng16-16.dll ├── libstdc++-6.dll ├── libtiff-5.dll ├── libwinpthread-1.dll ├── model.py ├── modules.py ├── msssim.py ├── tensorflow_compression ├── __init__.py ├── __pycache__ │ └── __init__.cpython-37.pyc └── python │ ├── __init__.py │ ├── __pycache__ │ └── __init__.cpython-37.pyc │ ├── layers │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-37.pyc │ │ ├── entropy_models.cpython-37.pyc │ │ ├── entropy_models_gauss.cpython-37.pyc │ │ ├── gdn.cpython-37.pyc │ │ ├── initializers.cpython-37.pyc │ │ ├── parameterizers.cpython-37.pyc │ │ └── signal_conv.cpython-37.pyc │ ├── entropy_models.py │ ├── entropy_models_gauss.py │ ├── entropy_models_test.py │ ├── gdn.py │ ├── gdn_test.py │ ├── initializers.py │ ├── parameterizers.py │ ├── parameterizers_test.py │ ├── signal_conv.py │ └── signal_conv_test.py │ └── ops │ ├── __init__.py │ ├── __pycache__ │ ├── __init__.cpython-37.pyc │ ├── coder_ops.cpython-37.pyc │ ├── math_ops.cpython-37.pyc │ ├── padding_ops.cpython-37.pyc │ └── spectral_ops.cpython-37.pyc │ ├── coder_ops.py │ ├── coder_ops_test.py │ ├── math_ops.py │ ├── math_ops_test.py │ ├── padding_ops.py │ ├── padding_ops_test.py │ ├── spectral_ops.py │ └── spectral_ops_test.py ├── test.py ├── utils.py ├── yuv_import.py └── zlib1.dll /Figures/M-LVC.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JianpingLin/M-LVC_CVPR2020/d2c20635c61aa786115622694d03d3ba0e4f1d15/Figures/M-LVC.PNG -------------------------------------------------------------------------------- /RD_Results/ClassB_MSSSIM.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JianpingLin/M-LVC_CVPR2020/d2c20635c61aa786115622694d03d3ba0e4f1d15/RD_Results/ClassB_MSSSIM.png -------------------------------------------------------------------------------- /RD_Results/ClassB_PSNR.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JianpingLin/M-LVC_CVPR2020/d2c20635c61aa786115622694d03d3ba0e4f1d15/RD_Results/ClassB_PSNR.png -------------------------------------------------------------------------------- /RD_Results/ClassC_MSSSIM.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JianpingLin/M-LVC_CVPR2020/d2c20635c61aa786115622694d03d3ba0e4f1d15/RD_Results/ClassC_MSSSIM.png -------------------------------------------------------------------------------- /RD_Results/ClassC_PSNR.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JianpingLin/M-LVC_CVPR2020/d2c20635c61aa786115622694d03d3ba0e4f1d15/RD_Results/ClassC_PSNR.png -------------------------------------------------------------------------------- /RD_Results/ClassD_MSSSIM.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/JianpingLin/M-LVC_CVPR2020/d2c20635c61aa786115622694d03d3ba0e4f1d15/RD_Results/ClassD_MSSSIM.png -------------------------------------------------------------------------------- /RD_Results/ClassD_PSNR.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JianpingLin/M-LVC_CVPR2020/d2c20635c61aa786115622694d03d3ba0e4f1d15/RD_Results/ClassD_PSNR.png -------------------------------------------------------------------------------- /RD_Results/ClassE_MSSSIM.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JianpingLin/M-LVC_CVPR2020/d2c20635c61aa786115622694d03d3ba0e4f1d15/RD_Results/ClassE_MSSSIM.png -------------------------------------------------------------------------------- /RD_Results/ClassE_PSNR.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JianpingLin/M-LVC_CVPR2020/d2c20635c61aa786115622694d03d3ba0e4f1d15/RD_Results/ClassE_PSNR.png -------------------------------------------------------------------------------- /RD_Results/UVG_MSSSIM.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JianpingLin/M-LVC_CVPR2020/d2c20635c61aa786115622694d03d3ba0e4f1d15/RD_Results/UVG_MSSSIM.png -------------------------------------------------------------------------------- /RD_Results/UVG_PSNR.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JianpingLin/M-LVC_CVPR2020/d2c20635c61aa786115622694d03d3ba0e4f1d15/RD_Results/UVG_PSNR.png -------------------------------------------------------------------------------- /RD_Results/data.txt: -------------------------------------------------------------------------------- 1 | Class UVG 2 | bpp [0.0598, 0.07934285714285713, 0.11425714285714283, 0.1600142857142857] 3 | psnr [35.25747142857143, 36.02594285714286, 37.205371428571425, 38.112585714285714] 4 | msssim [0.9581257142857142, 0.9618857142857143, 0.9703157142857144, 0.9766199999999998] 5 | 6 | Class B 7 | bpp [0.0728, 0.10398, 0.16044, 0.24101999999999996] 8 | psnr [33.07536, 34.00686, 35.10922000000001, 36.0504] 9 | msssim [0.9590679999999999, 0.9648019999999999, 0.971544, 0.9768839999999999] 10 | 11 | Class D 12 | bpp [0.12439999999999998, 0.1943, 0.28912499999999997, 0.41047500000000003] 13 | psnr [29.340875, 30.594125, 32.229025, 33.7589] 14 | msssim [0.967345, 0.975635, 0.9827699999999999, 0.9877174999999999] 15 | 16 | Class C 17 | bpp [0.13219999999999998, 0.198125, 0.29325, 0.41425] 18 | psnr [29.151575, 30.419525, 32.000975, 33.491125] 19 | msssim [0.9579425, 0.9684275, 0.9776150000000001, 0.98363] 20 | 21 | Class E 22 | bpp [0.026033333333333335, 0.033800000000000004, 0.0509, 0.0732] 23 | psnr [36.108133333333335, 37.1209, 38.249700000000004, 38.92263333333333] 24 | msssim [0.9798166666666667, 0.9823266666666667, 0.9848766666666667, 0.9864033333333334] -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # M-LVC: Multiple Frames Prediction for Learned Video Compression 2 | 3 | The project page for the paper: 4 | 5 | Jianping Lin, Dong Liu, Houqiang Li, Feng Wu, “M-LVC: Multiple Frames Prediction for Learned Video 
Compression”, in IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR), 2020. [[OpenAccess](https://openaccess.thecvf.com/content_CVPR_2020/html/Lin_M-LVC_Multiple_Frames_Prediction_for_Learned_Video_Compression_CVPR_2020_paper.html)][[arXiv](https://arxiv.org/abs/2004.10290)]
6 |
7 | If our paper and code are useful for your research, please cite:
8 | ```
9 | @inproceedings{lin2020m,
10 | title={M-LVC: Multiple Frames Prediction for Learned Video Compression},
11 | author={Lin, Jianping and Liu, Dong and Li, Houqiang and Wu, Feng},
12 | booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
13 | pages={3546--3554},
14 | year={2020}
15 | }
16 | ```
17 | If you have any questions or find any bugs, please feel free to contact:
18 |
19 | Jianping Lin @ University of Science and Technology of China (USTC)
20 |
21 | Email: ljp105@mail.ustc.edu.cn
22 |
23 | ## Introduction
24 |
25 | ![ ](Figures/M-LVC.PNG)
26 |
27 | We propose an end-to-end learned video compression scheme for low-latency scenarios. Previous methods are limited to using only the single previous frame as the reference; our method instead uses multiple previous frames as references. In our scheme, the motion vector (MV) field is calculated between the current frame and the previous one. With multiple reference frames and their associated MV fields, our network generates a more accurate prediction of the current frame, yielding a smaller residual. Multiple reference frames also help to generate the MV prediction, which reduces the coding cost of the MV field. We use two deep auto-encoders to compress the residual and the MV, respectively. To compensate for the compression error of the auto-encoders, we further design an MV refinement network and a residual refinement network, both of which also make use of the multiple reference frames. All the modules in our scheme are jointly optimized through a single rate-distortion loss function, and we use a step-by-step training strategy to optimize the entire scheme. Experimental results show that the proposed method outperforms the existing learned video compression methods in low-latency mode. Our method also performs better than H.265 in both PSNR and MS-SSIM. Our code and models are publicly available.
28 |
29 | ## Codes
30 | The currently available code is for evaluation, but it can also be adapted for training, since the implementation of the network is included.
31 |
32 | ### Dependency
33 | - tensorflow-gpu >=1.13.1 (the code can only be run in GPU mode)
34 |
35 | - opencv-python, matplotlib, scipy, pillow
36 |
37 | - Pre-trained models ([Download link](https://drive.google.com/file/d/1DaYh6_WTmrp0RoTfEPZSjZUujr4rGhqx/view?usp=sharing))
38 |
39 | - BPG ([Download link](https://bellard.org/bpg/))
40 |
41 | (*In our code, we use BPG to compress I-frames instead of training learned image compression models. Here, we provide the BPG executables for Windows.*)
42 |
43 | ### Compressing video sequences
44 |
45 | Since our code currently only supports sequences whose height and width are multiples of 64, we first use ffmpeg to resize the original sequences accordingly, e.g.,
46 | ```
47 | ffmpeg -pix_fmt yuv420p -s 1920x1080 -i input_video.yuv -vf scale="1920:1024" output_video.yuv
48 | ```
49 | Our resized sequences for the JCT-VC Class C dataset can be downloaded from this [link](https://drive.google.com/file/d/1gFNscYeZ3C-ZZj1T9IsWOtXBT3qtue4-/view?usp=sharing).
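For convenience, the target size can be derived by rounding each dimension down to the nearest multiple of 64. A minimal sketch (the helper name `floor_to_multiple` is our own, not part of this codebase):
```
def floor_to_multiple(value, base=64):
    # Largest multiple of `base` not exceeding `value` (e.g., 1080 -> 1024).
    return (value // base) * base

w, h = 1920, 1080
print('ffmpeg -pix_fmt yuv420p -s %dx%d -i input_video.yuv '
      '-vf scale="%d:%d" output_video.yuv'
      % (w, h, floor_to_multiple(w), floor_to_multiple(h)))
```
This reproduces the 1920x1080 -> 1920x1024 example above.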
50 |
51 | You can use the following command to compress any class of the UVG and JCT-VC datasets:
52 |
53 | ```
54 | python test.py --command compress --test_seq_dir directory_containing_testSequence --test_class ClassC --exp_data_dir ./exp_data_dir -r path_to_model/model_name.ckpt --lambda 16
55 | ```
56 | ```
57 | --test_class, the video class to be compressed (e.g., ClassB, ClassC, ClassD, ClassE, ClassUVG)
58 | --lambda, the lambda value of the trained model to use (i.e., 16, 24, 40, or 64)
59 | ```
60 |
61 | ### Entropy coding
62 | Currently, we do not provide the entropy coding module; instead, we report the estimated bpp for the quantized latent representations (see the formula sketch at the end of this README). It is straightforward to compress them with traditional entropy coding tools, such as a range coder.
63 |
64 | ### Experimental Results
65 | We test the proposed method on the JCT-VC (Classes B, C, D and E) and the [UVG](http://ultravideo.cs.tut.fi/#testsequences) datasets. Note that the UVG dataset has been enlarged recently; to compare with previous approaches, we only test on the original 7 videos in UVG, i.e., *Beauty*, *Bosphorus*, *HoneyBee*, *Jockey*, *ReadySetGo*, *ShakeNDry* and *YachtRide*.
66 |
67 | The detailed results (bpp, PSNR and MS-SSIM values) on each video dataset are shown in [data.txt](/RD_Results). The RD curves of our method, compared with [Lu *et al.*, DVC](http://openaccess.thecvf.com/content_CVPR_2019/papers/Lu_DVC_An_End-To-End_Deep_Video_Compression_Framework_CVPR_2019_paper.pdf) and x264/x265 in *LDP very fast* mode, are shown in the figures in the /RD_Results folder. Following DVC, for each video sequence we obtain the average PSNR by averaging the PSNRs over all frames, and for each dataset (e.g., ClassB) we average the PSNR over its video sequences. Note that the overall RD results here are slightly better than those in our paper, as we set more appropriate BPG quantization parameters for compressing the I-frames.
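As mentioned in the Entropy coding section, the reported bpp is the estimate derived from the likelihoods of the learned entropy models. A minimal NumPy restatement of the formula used in `model.py` (the helper name `estimated_bpp` is ours):
```
import numpy as np

def estimated_bpp(likelihoods, num_pixels):
    # Cross-entropy estimate in bits per pixel: -sum(log2(p)) / num_pixels,
    # the same expression as
    # tf.reduce_sum(tf.log(likelihoods)) / (-np.log(2) * num_pixels) in model.py.
    return np.sum(np.log(likelihoods)) / (-np.log(2) * num_pixels)
```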
68 | -------------------------------------------------------------------------------- /SDL.dll: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JianpingLin/M-LVC_CVPR2020/d2c20635c61aa786115622694d03d3ba0e4f1d15/SDL.dll -------------------------------------------------------------------------------- /SDL_image.dll: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JianpingLin/M-LVC_CVPR2020/d2c20635c61aa786115622694d03d3ba0e4f1d15/SDL_image.dll -------------------------------------------------------------------------------- /bpgdec.exe: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JianpingLin/M-LVC_CVPR2020/d2c20635c61aa786115622694d03d3ba0e4f1d15/bpgdec.exe -------------------------------------------------------------------------------- /bpgenc.exe: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JianpingLin/M-LVC_CVPR2020/d2c20635c61aa786115622694d03d3ba0e4f1d15/bpgenc.exe -------------------------------------------------------------------------------- /exp_data_dir/I_frames_enc/RaceHorsesC_832x448.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JianpingLin/M-LVC_CVPR2020/d2c20635c61aa786115622694d03d3ba0e4f1d15/exp_data_dir/I_frames_enc/RaceHorsesC_832x448.png -------------------------------------------------------------------------------- /exp_data_dir/I_frames_enc/dec_RaceHorsesC_832x448_30_qp21.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JianpingLin/M-LVC_CVPR2020/d2c20635c61aa786115622694d03d3ba0e4f1d15/exp_data_dir/I_frames_enc/dec_RaceHorsesC_832x448_30_qp21.png -------------------------------------------------------------------------------- /exp_data_dir/I_frames_enc/enc_RaceHorsesC_832x448_30_qp21.bpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JianpingLin/M-LVC_CVPR2020/d2c20635c61aa786115622694d03d3ba0e4f1d15/exp_data_dir/I_frames_enc/enc_RaceHorsesC_832x448_30_qp21.bpg -------------------------------------------------------------------------------- /exp_data_dir/figure/ClassC_RaceHorsesC_curidx1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JianpingLin/M-LVC_CVPR2020/d2c20635c61aa786115622694d03d3ba0e4f1d15/exp_data_dir/figure/ClassC_RaceHorsesC_curidx1.png -------------------------------------------------------------------------------- /flow_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib 3 | 4 | matplotlib.use('Agg') 5 | from pylab import box 6 | import matplotlib.pyplot as plt 7 | # import cv2 8 | import sys 9 | import argparse 10 | import math 11 | 12 | __all__ = ['load_flow', 'save_flow', 'vis_flow'] 13 | 14 | 15 | def load_flow(path): 16 | with open(path, 'rb') as f: 17 | magic = float(np.fromfile(f, np.float32, count=1)[0]) 18 | if magic == 202021.25: 19 | w, h = np.fromfile(f, np.int32, count=1)[0], np.fromfile(f, np.int32, count=1)[0] 20 | data = np.fromfile(f, np.float32, count=h * w * 2) 21 | data.resize((h, w, 2)) 22 | return data 23 | return None 24 | 25 | 26 | def save_flow(path, flow): 27 | magic = np.array([202021.25], 
np.float32) 28 | h, w = flow.shape[:2] 29 | h, w = np.array([h], np.int32), np.array([w], np.int32) 30 | 31 | with open(path, 'wb') as f: 32 | magic.tofile(f); 33 | w.tofile(f); 34 | h.tofile(f); 35 | flow.tofile(f) 36 | 37 | 38 | def makeColorwheel(): 39 | # color encoding scheme 40 | 41 | # adapted from the color circle idea described at 42 | # http://members.shaw.ca/quadibloc/other/colint.htm 43 | 44 | RY = 15 45 | YG = 6 46 | GC = 4 47 | CB = 11 48 | BM = 13 49 | MR = 6 50 | 51 | ncols = RY + YG + GC + CB + BM + MR 52 | 53 | colorwheel = np.zeros([ncols, 3]) # r g b 54 | 55 | col = 0 56 | # RY 57 | colorwheel[0:RY, 0] = 255 58 | colorwheel[0:RY, 1] = np.floor(255 * np.arange(0, RY, 1) / RY) 59 | col += RY 60 | 61 | # YG 62 | colorwheel[col:YG + col, 0] = 255 - np.floor(255 * np.arange(0, YG, 1) / YG) 63 | colorwheel[col:YG + col, 1] = 255; 64 | col += YG; 65 | 66 | # GC 67 | colorwheel[col:GC + col, 1] = 255 68 | colorwheel[col:GC + col, 2] = np.floor(255 * np.arange(0, GC, 1) / GC) 69 | col += GC; 70 | 71 | # CB 72 | colorwheel[col:CB + col, 1] = 255 - np.floor(255 * np.arange(0, CB, 1) / CB) 73 | colorwheel[col:CB + col, 2] = 255 74 | col += CB; 75 | 76 | # BM 77 | colorwheel[col:BM + col, 2] = 255 78 | colorwheel[col:BM + col, 0] = np.floor(255 * np.arange(0, BM, 1) / BM) 79 | col += BM; 80 | 81 | # MR 82 | colorwheel[col:MR + col, 2] = 255 - np.floor(255 * np.arange(0, MR, 1) / MR) 83 | colorwheel[col:MR + col, 0] = 255 84 | return colorwheel 85 | 86 | 87 | def computeColor(u, v): 88 | colorwheel = makeColorwheel(); 89 | nan_u = np.isnan(u) 90 | nan_v = np.isnan(v) 91 | nan_u = np.where(nan_u) 92 | nan_v = np.where(nan_v) 93 | 94 | u[nan_u] = 0 95 | u[nan_v] = 0 96 | v[nan_u] = 0 97 | v[nan_v] = 0 98 | 99 | ncols = colorwheel.shape[0] 100 | radius = np.sqrt(u ** 2 + v ** 2) 101 | a = np.arctan2(-v, -u) / np.pi 102 | fk = (a + 1) / 2 * (ncols - 1) # -1~1 maped to 1~ncols 103 | k0 = fk.astype(np.uint8) # 1, 2, ..., ncols 104 | k1 = k0 + 1 105 | k1[k1 == ncols] = 0 106 | f = fk - k0 107 | 108 | img = np.empty([k1.shape[0], k1.shape[1], 3]) 109 | ncolors = colorwheel.shape[1] 110 | for i in range(ncolors): 111 | tmp = colorwheel[:, i] 112 | col0 = tmp[k0] / 255 113 | col1 = tmp[k1] / 255 114 | col = (1 - f) * col0 + f * col1 115 | idx = radius <= 1 116 | col[idx] = 1 - radius[idx] * (1 - col[idx]) # increase saturation with radius 117 | col[~idx] *= 0.75 # out of range 118 | img[:, :, 2 - i] = np.floor(255 * col).astype(np.uint8) 119 | 120 | return img.astype(np.uint8) 121 | 122 | 123 | def vis_flow(flow): 124 | eps = sys.float_info.epsilon 125 | UNKNOWN_FLOW_THRESH = 1e9 126 | UNKNOWN_FLOW = 1e10 127 | 128 | u = flow[:, :, 0] 129 | v = flow[:, :, 1] 130 | 131 | maxu = -999 132 | maxv = -999 133 | 134 | minu = 999 135 | minv = 999 136 | 137 | maxrad = -1 138 | # fix unknown flow 139 | greater_u = np.where(u > UNKNOWN_FLOW_THRESH) 140 | greater_v = np.where(v > UNKNOWN_FLOW_THRESH) 141 | u[greater_u] = 0 142 | u[greater_v] = 0 143 | v[greater_u] = 0 144 | v[greater_v] = 0 145 | 146 | maxu = max([maxu, np.amax(u)]) 147 | minu = min([minu, np.amin(u)]) 148 | 149 | maxv = max([maxv, np.amax(v)]) 150 | minv = min([minv, np.amin(v)]) 151 | rad = np.sqrt(np.multiply(u, u) + np.multiply(v, v)) 152 | maxrad = max([maxrad, np.amax(rad)]) 153 | # print('max flow: %.4f flow range: u = %.3f .. %.3f; v = %.3f .. 
%.3f\n' % (maxrad, minu, maxu, minv, maxv)) 154 | 155 | u = u / (maxrad + eps) 156 | v = v / (maxrad + eps) 157 | img = computeColor(u, v) 158 | return img[:, :, [2, 1, 0]] 159 | 160 | def vis_flow_image_final(flow_pyramid, flow_gt_pyramid, images_list,gray_images_list, filename='./flow.png'): 161 | num_contents = len(flow_pyramid) + len(flow_gt_pyramid) + len(images_list)+ len(gray_images_list) 162 | nums_list = [len(flow_pyramid), len(flow_gt_pyramid), len(images_list), len(gray_images_list)] 163 | nums_list.sort() 164 | cols = nums_list[-2] 165 | if cols <= 3: 166 | cols = nums_list[-1] 167 | cols=4 168 | rows = math.ceil(num_contents / cols) 169 | 170 | fig_dpi=200 171 | plt.rcParams['savefig.dpi'] = fig_dpi 172 | plt.rcParams['figure.dpi'] = fig_dpi 173 | 174 | fig = plt.figure() 175 | 176 | fig_id = 1 177 | 178 | for image in images_list: 179 | plt.subplot(rows, cols, fig_id) 180 | plt.imshow(image) 181 | plt.tick_params(labelbottom=False, bottom=False) 182 | plt.tick_params(labelleft=False, left=False) 183 | plt.xticks([]) 184 | box(False) 185 | fig_id += 1 186 | 187 | for image in gray_images_list: 188 | plt.subplot(rows, cols, fig_id) 189 | plt.imshow(image,cmap='gray') 190 | plt.tick_params(labelbottom=False, bottom=False) 191 | plt.tick_params(labelleft=False, left=False) 192 | plt.xticks([]) 193 | box(False) 194 | fig_id += 1 195 | 196 | for flow_gt in flow_gt_pyramid: 197 | plt.subplot(rows, cols, fig_id) 198 | plt.imshow(vis_flow(flow_gt)) 199 | plt.tick_params(labelbottom=False, bottom=False) 200 | plt.tick_params(labelleft=False, left=False) 201 | plt.xticks([]) 202 | box(False) 203 | 204 | fig_id += 1 205 | for flow in flow_pyramid: 206 | plt.subplot(rows, cols, fig_id) 207 | plt.imshow(vis_flow(flow)) 208 | plt.tick_params(labelbottom=False, bottom=False) 209 | plt.tick_params(labelleft=False, left=False) 210 | plt.xticks([]) 211 | box(False) 212 | 213 | fig_id += 1 214 | 215 | plt.tight_layout() 216 | plt.savefig(filename, bbox_inches='tight', pad_inches=0.1, dpi=fig_dpi) 217 | plt.close() 218 | 219 | if __name__ == '__main__': 220 | import matplotlib.pyplot as plt 221 | 222 | flow = load_flow('13382_flow.flo') 223 | flow = load_flow('datasets/Sintel/training/flow/alley_1/frame_0001.flo') 224 | img = vis_flow(flow) 225 | import imageio 226 | 227 | imageio.imsave('test.png', img) 228 | # import cv2 229 | # cv2.imshow('', img[:,:,:]) 230 | # cv2.waitKey() 231 | -------------------------------------------------------------------------------- /helper.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | class Flow(object): 5 | """ 6 | based on https://github.com/cgtuebingen/learning-blind-motion-deblurring/blob/master/synthblur/src/flow.cpp#L44 7 | """ 8 | def __init__(self): 9 | super(Flow, self).__init__() 10 | self.wheel = None 11 | self._construct_wheel() 12 | 13 | @staticmethod 14 | def read(file): 15 | # https://stackoverflow.com/a/44906777/7443104 16 | with open(file, 'rb') as f: 17 | magic = np.fromfile(f, np.float32, count=1) 18 | if 202021.25 != magic: 19 | raise Exception('Magic number incorrect. 
Invalid .flo file') 20 | else: 21 | w = np.fromfile(f, np.int32, count=1)[0] 22 | h = np.fromfile(f, np.int32, count=1)[0] 23 | data = np.fromfile(f, np.float32, count=2 * w * h) 24 | return np.resize(data, (h, w, 2)) 25 | 26 | def _construct_wheel(self): 27 | k = 0 28 | 29 | RY, YG, GC = 15, 6, 4 30 | YG, GC, CB = 6, 4, 11 31 | BM, MR = 13, 6 32 | 33 | self.wheel = np.zeros((55, 3), dtype=np.float32) 34 | 35 | for i in range(RY): 36 | self.wheel[k] = np.array([255., 255. * i / float(RY), 0]) 37 | k += 1 38 | 39 | for i in range(YG): 40 | self.wheel[k] = np.array([255. - 255. * i / float(YG), 255., 0]) 41 | k += 1 42 | 43 | for i in range(GC): 44 | self.wheel[k] = np.array([0, 255., 255. * i / float(GC)]) 45 | k += 1 46 | 47 | for i in range(CB): 48 | self.wheel[k] = np.array([0, 255. - 255. * i / float(CB), 255.]) 49 | k += 1 50 | 51 | for i in range(BM): 52 | self.wheel[k] = np.array([255. * i / float(BM), 0, 255.]) 53 | k += 1 54 | 55 | for i in range(MR): 56 | self.wheel[k] = np.array([255., 0, 255. - 255. * i / float(MR)]) 57 | k += 1 58 | 59 | self.wheel = self.wheel / 255. 60 | 61 | def visualize(self, nnf): 62 | assert len(nnf.shape) == 3 63 | assert nnf.shape[2] == 2 64 | 65 | RY, YG, GC = 15, 6, 4 66 | YG, GC, CB = 6, 4, 11 67 | BM, MR = 13, 6 68 | NCOLS = RY + YG + GC + CB + BM + MR 69 | 70 | fx = nnf[:, :, 0].astype(np.float32) 71 | fy = nnf[:, :, 1].astype(np.float32) 72 | 73 | h, w = fx.shape[:2] 74 | fx = fx.reshape([-1]) 75 | fy = fy.reshape([-1]) 76 | 77 | rad = np.sqrt(fx * fx + fy * fy) 78 | 79 | max_rad = rad.max() 80 | 81 | a = np.arctan2(-fy, -fx) / np.pi 82 | fk = (a + 1.0) / 2.0 * (NCOLS - 1) 83 | k0 = fk.astype(np.int32) 84 | k1 = (k0 + 1) % NCOLS 85 | f = (fk - k0).astype(np.float32) 86 | 87 | color0 = self.wheel[k0, :] 88 | color1 = self.wheel[k1, :] 89 | 90 | f = np.stack([f, f, f], axis=-1) 91 | color = (1 - f) * color0 + f * color1 92 | 93 | color = 1 - (np.expand_dims(rad, axis=-1) / max_rad) * (1 - color) 94 | 95 | return color.reshape(h, w, 3)[:, :, ::-1] 96 | 97 | 98 | if __name__ == '__main__': 99 | import cv2 100 | nnf = Flow.read('/tmp/data2/07446_flow.flo') 101 | v = Flow() 102 | rgb = v.visualize(nnf) 103 | cv2.imshow('rgb', rgb) 104 | cv2.waitKey(0) -------------------------------------------------------------------------------- /libgcc_s_seh-1.dll: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JianpingLin/M-LVC_CVPR2020/d2c20635c61aa786115622694d03d3ba0e4f1d15/libgcc_s_seh-1.dll -------------------------------------------------------------------------------- /libjpeg-62.dll: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JianpingLin/M-LVC_CVPR2020/d2c20635c61aa786115622694d03d3ba0e4f1d15/libjpeg-62.dll -------------------------------------------------------------------------------- /libpng16-16.dll: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JianpingLin/M-LVC_CVPR2020/d2c20635c61aa786115622694d03d3ba0e4f1d15/libpng16-16.dll -------------------------------------------------------------------------------- /libstdc++-6.dll: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JianpingLin/M-LVC_CVPR2020/d2c20635c61aa786115622694d03d3ba0e4f1d15/libstdc++-6.dll -------------------------------------------------------------------------------- /libtiff-5.dll: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/JianpingLin/M-LVC_CVPR2020/d2c20635c61aa786115622694d03d3ba0e4f1d15/libtiff-5.dll -------------------------------------------------------------------------------- /libwinpthread-1.dll: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JianpingLin/M-LVC_CVPR2020/d2c20635c61aa786115622694d03d3ba0e4f1d15/libwinpthread-1.dll -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import tensorflow as tf 6 | 7 | from modules import * 8 | 9 | import tensorflow_compression as tfc 10 | import numpy as np 11 | 12 | class MAMVPNet(object): 13 | def __init__(self, num_levels=6, 14 | warp_type='bilinear', use_dc=False, 15 | output_level=4, name='pwcmenet'): 16 | self.num_levels = num_levels 17 | self.warp_type = warp_type 18 | self.use_dc = use_dc 19 | self.output_level = output_level 20 | self.name = name 21 | 22 | self.fp_extractor = FeaturePyramidExtractor_custom_low(self.num_levels) 23 | self.warp_layer = WarpingLayer(self.warp_type) 24 | self.of_estimators = [OpticalFlowEstimator_custom_ME(use_dc=self.use_dc, name=f'optflow_{l}') \ 25 | for l in range(self.num_levels + 1)] 26 | 27 | def __call__(self, flows_2_pyramid, flows_1_pyramid, flows_0_pyramid, reuse=False): 28 | with tf.variable_scope(self.name, reuse=reuse) as vs: 29 | pyramid_0 = self.fp_extractor(flows_0_pyramid[-1], reuse=reuse) 30 | pyramid_1 = self.fp_extractor(flows_1_pyramid[-1]) 31 | pyramid_2 = self.fp_extractor(flows_2_pyramid[-1]) 32 | 33 | pyramid_1_warped = [] 34 | pyramid_2_warped = [] 35 | for l, (flows_0, flows_1, features_1, features_2) in enumerate( 36 | zip(flows_0_pyramid, flows_1_pyramid, pyramid_1, pyramid_2)): 37 | print(f'Warp Optical Flow Level {l}') 38 | 39 | features_1_warped = self.warp_layer(features_1, flows_0) 40 | pyramid_1_warped.append(features_1_warped) 41 | features_2_warped = self.warp_layer(features_2, (flows_0 + self.warp_layer(flows_1, flows_0))) 42 | pyramid_2_warped.append(features_2_warped) 43 | 44 | flows_pyramid = [] 45 | flows_up, features_up = None, None 46 | for l, (features_0, features_1, features_2) in enumerate( 47 | zip(pyramid_0, pyramid_1_warped, pyramid_2_warped)): 48 | print(f'Level {l}') 49 | 50 | # Optical flow estimation 51 | features_total = tf.concat([features_0, features_1, features_2], axis=3) 52 | if l < self.output_level: 53 | flows, flows_up, features_up \ 54 | = self.of_estimators[l](features_total, flows_up, features_up) 55 | else: 56 | # At output level 57 | flows = self.of_estimators[l](features_total, flows_up, features_up, 58 | is_output=True) 59 | return flows, None, None 60 | 61 | flows_pyramid.append(flows) 62 | 63 | @property 64 | def vars(self): 65 | return [var for var in tf.global_variables() if self.name in var.name] 66 | 67 | class MCNet_Multiple(object): 68 | def __init__(self, name='mcnet'): 69 | self.name = name 70 | self.warp_layer = WarpingLayer('bilinear') 71 | self.f_extractor = FeatureExtractor_custom_RGB_new() 72 | self.context_RGB = ContextNetwork_RGB_ResNet(name='context_RGB') 73 | 74 | def __call__(self, images_pre_rec_4, images_pre_rec_3, images_pre_rec_2, images_pre_rec, flow3, flow2, flow1, flow, reuse=False): 75 | with 
tf.variable_scope(self.name, reuse=reuse) as vs: 76 | images_pre_rec_warped = self.warp_layer(images_pre_rec, flow*20.0) 77 | features = self.f_extractor(images_pre_rec) 78 | features_warped = self.warp_layer(features, flow*20.0) 79 | flow1_warped = self.warp_layer(flow1, flow * 20.0) 80 | flow2_warped = self.warp_layer(flow2, (flow1_warped + flow) * 20.0) 81 | flow3_warped = self.warp_layer(flow3, (flow2_warped + flow1_warped + flow) * 20.0) 82 | features_4 = self.f_extractor(images_pre_rec_4, reuse=True) 83 | features_4_warped = self.warp_layer(features_4, (flow3_warped + flow2_warped + flow1_warped + flow) * 20.0) 84 | features_3 = self.f_extractor(images_pre_rec_3, reuse=True) 85 | features_3_warped = self.warp_layer(features_3, (flow2_warped + flow1_warped + flow) * 20.0) 86 | features_2 = self.f_extractor(images_pre_rec_2, reuse=True) 87 | features_2_warped = self.warp_layer(features_2, (flow1_warped + flow) * 20.0) 88 | features = tf.concat([features_4_warped, features_3_warped, features_2_warped, features_warped], axis=3) 89 | output = self.context_RGB(images_pre_rec_warped, features) 90 | return output, features 91 | 92 | @property 93 | def vars(self): 94 | return [var for var in tf.global_variables() if self.name in var.name] 95 | 96 | @property 97 | def vars_restore(self): 98 | return [var for var in tf.global_variables() if ((self.name in var.name) and ('context_RGB' not in var.name))] 99 | 100 | class MCNet(object): 101 | def __init__(self, name='mcnet'): 102 | self.name = name 103 | self.warp_layer = WarpingLayer('bilinear') 104 | self.f_extractor = FeatureExtractor_custom_RGB() 105 | self.context_RGB = ContextNetwork_RGB_ResNet(name='context_RGB') 106 | # self.scales = [None, 0.625, 1.25, 2.5, 5.0, 10., 20.] 107 | 108 | def __call__(self, images_pre_rec, flow, reuse=False): 109 | """Y_frames (n, h, w, 1)""" 110 | with tf.variable_scope(self.name, reuse=reuse) as vs: 111 | images_pre_rec_warped = self.warp_layer(images_pre_rec, flow) 112 | features = self.f_extractor(images_pre_rec) 113 | features_warped = self.warp_layer(features, flow) 114 | output = self.context_RGB(images_pre_rec_warped, features_warped) 115 | return output 116 | 117 | @property 118 | def vars(self): 119 | return [var for var in tf.global_variables() if self.name in var.name] 120 | 121 | @property 122 | def vars_restore(self): 123 | return [var for var in tf.global_variables() if 124 | ((self.name in var.name) and ('context_RGB' not in var.name) and ('f_extractor' not in var.name))] 125 | 126 | class ResiDeBlurNet(object): 127 | def __init__(self, name='resideblurmodel'): 128 | self.name = name 129 | self.f_extractor = FeatureExtractor_custom_RGB_new() 130 | self.warp_layer = WarpingLayer('bilinear') 131 | self.resideblur_ResNet=Resideblur_ResNet_RGB('Resideblur_ResNet') 132 | # self.scales = [None, 0.625, 1.25, 2.5, 5.0, 10., 20.] 
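        # NB: the MV fields handled in this file appear to be stored at 1/20 of
        # their pixel magnitude: MCNet_Multiple and MVLoopFiltering multiply them
        # by 20.0 before warping, matching the last entry of the commented-out
        # PWC-Net-style `scales` list above.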
133 | 134 | def __call__(self, tensor, images_pred, features_warped, reuse=False): 135 | """Y_frames (n, h, w, 1)""" 136 | with tf.variable_scope(self.name, reuse=reuse) as vs: 137 | features = self.f_extractor(images_pred) 138 | features = tf.concat([features_warped, features], axis=3) 139 | output = self.resideblur_ResNet(tensor, features) 140 | return output 141 | 142 | @property 143 | def vars(self): 144 | return [var for var in tf.global_variables() if self.name in var.name] 145 | 146 | class bls2017ImgCompression_mvd_factor(object): 147 | def __init__(self, input_channel=2, num_filters=128, name='bls2017ImgCompression'): 148 | self.input_channel=input_channel 149 | self.num_filters=num_filters 150 | self.name = name 151 | 152 | def analysis_transform(self, tensor, num_filters): 153 | """Builds the analysis transform.""" 154 | 155 | with tf.variable_scope("analysis"): 156 | with tf.variable_scope("layer_0"): 157 | layer = tfc.SignalConv2D( 158 | num_filters, (5, 5), corr=True, strides_down=2, padding="same_zeros", 159 | use_bias=True, activation=tfc.GDN()) 160 | tensor = layer(tensor) 161 | with tf.variable_scope("layer_1"): 162 | layer = tfc.SignalConv2D( 163 | num_filters, (5, 5), corr=True, strides_down=2, padding="same_zeros", 164 | use_bias=True, activation=tfc.GDN()) 165 | tensor = layer(tensor) 166 | with tf.variable_scope("layer_2"): 167 | layer = tfc.SignalConv2D( 168 | num_filters, (5, 5), corr=True, strides_down=2, padding="same_zeros", 169 | use_bias=True, activation=tfc.GDN()) 170 | tensor = layer(tensor) 171 | 172 | with tf.variable_scope("layer_3"): 173 | layer = tfc.SignalConv2D( 174 | num_filters, (5, 5), corr=True, strides_down=2, padding="same_zeros", 175 | use_bias=False, activation=None) 176 | tensor = layer(tensor) 177 | 178 | return tensor 179 | 180 | def synthesis_transform(self, tensor, input_channel, num_filters): 181 | """Builds the synthesis transform.""" 182 | 183 | with tf.variable_scope("synthesis"): 184 | with tf.variable_scope("layer_0"): 185 | layer = tfc.SignalConv2D( 186 | num_filters, (5, 5), corr=False, strides_up=2, padding="same_zeros", 187 | use_bias=True, activation=tfc.GDN(inverse=True)) 188 | tensor = layer(tensor) 189 | with tf.variable_scope("layer_1"): 190 | layer = tfc.SignalConv2D( 191 | num_filters, (5, 5), corr=False, strides_up=2, padding="same_zeros", 192 | use_bias=True, activation=tfc.GDN(inverse=True)) 193 | tensor = layer(tensor) 194 | with tf.variable_scope("layer_2"): 195 | layer = tfc.SignalConv2D( 196 | num_filters, (5, 5), corr=False, strides_up=2, padding="same_zeros", 197 | use_bias=True, activation=tfc.GDN(inverse=True)) 198 | tensor = layer(tensor) 199 | with tf.variable_scope("layer_3"): 200 | layer = tfc.SignalConv2D( 201 | input_channel, (5, 5), corr=False, strides_up=2, padding="same_zeros", 202 | use_bias=True, activation=None) 203 | tensor = layer(tensor) 204 | 205 | return tensor 206 | def __call__(self, x, num_pixels, reuse=False, isTrain=True): 207 | with tf.variable_scope(self.name, reuse=reuse) as vs: 208 | y = self.analysis_transform(x, self.num_filters) 209 | entropy_bottleneck = tfc.EntropyBottleneck() 210 | bit_string = None 211 | if isTrain: 212 | y_tilde, likelihoods = entropy_bottleneck(y, training=True) 213 | else: 214 | string = entropy_bottleneck.compress(y) 215 | bit_string = tf.squeeze(string, axis=0) 216 | y_tilde, likelihoods = entropy_bottleneck(y, training=False) 217 | x_tilde = self.synthesis_transform(y_tilde, self.input_channel, self.num_filters) 218 | train_bpp = 
tf.reduce_sum(tf.log(likelihoods)) / (-np.log(2) * num_pixels) 219 | return bit_string, entropy_bottleneck, x_tilde, train_bpp 220 | 221 | @property 222 | def vars(self): 223 | return [var for var in tf.global_variables() if self.name in var.name] 224 | 225 | class bls2017ImgCompression_resi_RGB(object): 226 | def __init__(self, input_channel=2, N_filters=128,M_filters=128, name='bls2017ImgCompression'): 227 | self.input_channel = input_channel 228 | self.N_filters = N_filters 229 | self.M_filters = M_filters 230 | self.name = name 231 | self.hyperModel=HyperPrior_resi(M_filters,N_filters,'hyper_resi') 232 | 233 | def analysis_transform(self, tensor, N_filters, M_filters): 234 | """Builds the analysis transform.""" 235 | 236 | with tf.variable_scope("analysis"): 237 | with tf.variable_scope("layer_0"): 238 | layer = tfc.SignalConv2D( 239 | N_filters, (5, 5), corr=True, strides_down=2, padding="same_zeros", 240 | use_bias=True, activation=tfc.GDN()) 241 | tensor = layer(tensor) 242 | with tf.variable_scope("layer_1"): 243 | layer = tfc.SignalConv2D( 244 | N_filters, (5, 5), corr=True, strides_down=2, padding="same_zeros", 245 | use_bias=True, activation=tfc.GDN()) 246 | tensor = layer(tensor) 247 | with tf.variable_scope("layer_2"): 248 | layer = tfc.SignalConv2D( 249 | N_filters, (5, 5), corr=True, strides_down=2, padding="same_zeros", 250 | use_bias=True, activation=tfc.GDN()) 251 | tensor = layer(tensor) 252 | 253 | with tf.variable_scope("layer_3"): 254 | layer = tfc.SignalConv2D( 255 | M_filters, (5, 5), corr=True, strides_down=2, padding="same_zeros", 256 | use_bias=False, activation=None) 257 | tensor = layer(tensor) 258 | 259 | return tensor 260 | 261 | def synthesis_transform(self, tensor, N_filters): 262 | """Builds the synthesis transform.""" 263 | 264 | with tf.variable_scope("synthesis"): 265 | with tf.variable_scope("layer_0"): 266 | layer = tfc.SignalConv2D( 267 | N_filters, (5, 5), corr=False, strides_up=2, padding="same_zeros", 268 | use_bias=True, activation=tfc.GDN(inverse=True)) 269 | tensor = layer(tensor) 270 | with tf.variable_scope("layer_1"): 271 | layer = tfc.SignalConv2D( 272 | N_filters, (5, 5), corr=False, strides_up=2, padding="same_zeros", 273 | use_bias=True, activation=tfc.GDN(inverse=True)) 274 | tensor = layer(tensor) 275 | with tf.variable_scope("layer_2"): 276 | layer = tfc.SignalConv2D( 277 | N_filters, (5, 5), corr=False, strides_up=2, padding="same_zeros", 278 | use_bias=True, activation=tfc.GDN(inverse=True)) 279 | tensor = layer(tensor) 280 | with tf.variable_scope("layer_out"): 281 | layer = tfc.SignalConv2D( 282 | 32, (5, 5), corr=False, strides_up=2, padding="same_zeros", 283 | use_bias=True, activation=None) 284 | tensor = layer(tensor) 285 | ''' 286 | with tf.variable_scope("layer_4"): 287 | layer = tfc.SignalConv2D( 288 | 3, (3, 3), corr=False, strides_up=1, padding="same_zeros", 289 | use_bias=True, activation=None) 290 | tensor = layer(tensor) 291 | ''' 292 | 293 | return tensor 294 | 295 | def __call__(self, resi_frames, num_pixels, reuse=False, isTrain=True): 296 | with tf.variable_scope(self.name, reuse=reuse) as vs: 297 | 298 | y = self.analysis_transform(resi_frames, self.N_filters, self.M_filters) 299 | 300 | entropy_bottleneck = tfc.EntropyBottleneck_gauss() 301 | bit_string = None 302 | bit_string_dev = None 303 | if isTrain: 304 | _, entropy_bottleneck_dev, dev_tilde, train_bpp_dev = self.hyperModel(y, num_pixels, reuse=False, isTrain=True) 305 | y_tilde, likelihoods = entropy_bottleneck(y, dev_tilde, training=True) 306 | else: 307 
| bit_string_dev, entropy_bottleneck_dev, dev_tilde, train_bpp_dev = self.hyperModel(y, num_pixels, 308 | reuse=False, 309 | isTrain=False) 310 | with tf.device("/cpu:0"): 311 | string = entropy_bottleneck.compress(y, dev_tilde) 312 | bit_string = tf.squeeze(string, axis=0) 313 | y_tilde, likelihoods = entropy_bottleneck(y, dev_tilde, training=False) 314 | tensor_tilde = self.synthesis_transform(y_tilde, self.N_filters) 315 | train_bpp = tf.reduce_sum(tf.log(likelihoods)) / (-np.log(2) * num_pixels) 316 | return bit_string, entropy_bottleneck, tensor_tilde, train_bpp, bit_string_dev, entropy_bottleneck_dev, dev_tilde, train_bpp_dev, 317 | 318 | @property 319 | def vars(self): 320 | return [var for var in tf.global_variables() if self.name in var.name] 321 | 322 | 323 | class MVLoopFiltering(object): 324 | def __init__(self, name='mvlfmodel'): 325 | self.name = name 326 | self.f_extractor_0=FeatureExtractor_mvlf('FeatureExtractor_ilf_0') 327 | self.f_extractor_1 = FeatureExtractor_mvlf('FeatureExtractor_ilf_1') 328 | self.f_extractor_2 = FeatureExtractor_mvlf('FeatureExtractor_ilf_2') 329 | self.warp_layer = WarpingLayer('bilinear') 330 | self.context_mv_Unet = ContextNetwork_mv_Unet(name='context_mv_Unet') 331 | # self.scales = [None, 0.625, 1.25, 2.5, 5.0, 10., 20.] 332 | 333 | def __call__(self, flow3, flow2, flow1, flow, images_pre_rec, reuse=False): 334 | """Y_frames (n, h, w, 1)""" 335 | with tf.variable_scope(self.name, reuse=reuse) as vs: 336 | flow3_f = self.f_extractor_0(flow3) 337 | flow2_f = self.f_extractor_0(flow2, reuse=True) 338 | flow1_f = self.f_extractor_0(flow1, reuse=True) 339 | flow_f = self.f_extractor_1(flow) 340 | images_pre_rec_f = self.f_extractor_2(images_pre_rec) 341 | flow1_warped = self.warp_layer(flow1, flow * 20.0) 342 | flow2_warped = self.warp_layer(flow2, (flow1_warped + flow) * 20.0) 343 | 344 | flow3_f_warped = self.warp_layer(flow3_f, (flow2_warped + flow1_warped + flow) * 20.) 345 | flow2_f_warped = self.warp_layer(flow2_f, (flow1_warped + flow) * 20.) 346 | flow1_f_warped = self.warp_layer(flow1_f, flow * 20.) 
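            # Chained alignment: the features of each earlier MV field are warped
            # by the motion accumulated toward the current frame (flow, then
            # flow1_warped + flow, ...), so all inputs below are spatially aligned
            # before concatenation.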
347 | input = tf.concat([flow3_f_warped, flow2_f_warped, flow1_f_warped, flow_f, images_pre_rec_f], axis=3) 348 | output = self.context_mv_Unet(input, flow) 349 | return output 350 | 351 | @property 352 | def vars(self): 353 | return [var for var in tf.global_variables() if self.name in var.name] -------------------------------------------------------------------------------- /modules.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | from functools import partial 4 | 5 | import tensorflow_compression as tfc 6 | 7 | # Feature pyramid extractor module simple/original ----------------------- 8 | 9 | class FeaturePyramidExtractor_custom_low(object): 10 | """ Feature pyramid extractor module""" 11 | 12 | def __init__(self, num_levels=6, name='fp_extractor'): 13 | self.num_levels = num_levels 14 | self.filters = [16, 24, 24, 24] 15 | self.name = name 16 | 17 | def __call__(self, images, reuse=True): 18 | with tf.variable_scope(self.name, reuse=reuse) as vs: 19 | features_pyramid = [] 20 | x = images 21 | for l in range(self.num_levels): 22 | if l==0: 23 | x = tf.layers.Conv2D(self.filters[l], (3, 3), (1, 1), 'same')(x) 24 | else: 25 | x = tf.layers.Conv2D(self.filters[l], (3, 3), (2, 2), 'same')(x) 26 | x = tf.nn.leaky_relu(x, 0.1) 27 | x = tf.layers.Conv2D(self.filters[l], (3, 3), (1, 1), 'same')(x) 28 | x = tf.nn.leaky_relu(x, 0.1) 29 | features_pyramid.append(x) 30 | 31 | # return feature pyramid by ascent order 32 | return features_pyramid[::-1] 33 | 34 | class OpticalFlowEstimator_custom_ME(object): 35 | """ Optical flow estimator module """ 36 | def __init__(self, use_dc = False, name = 'of_estimator'): 37 | self.filters = [32, 32, 32, 32, 32] 38 | self.use_dc = use_dc 39 | self.name = name 40 | 41 | def __call__(self, features, flows_up_prev = None, features_up_prev = None, 42 | is_output = False): 43 | with tf.variable_scope(self.name) as vs: 44 | features = features 45 | for f in [flows_up_prev, features_up_prev]: 46 | if f is not None: 47 | features = tf.concat([features, f], axis = 3) 48 | 49 | for f in self.filters: 50 | conv = tf.layers.Conv2D(f, (3, 3), (1, 1), 'same')(features) 51 | conv = tf.nn.leaky_relu(conv, 0.1) 52 | if self.use_dc: 53 | features = tf.concat([conv, features], axis = 3) 54 | else: 55 | features = conv 56 | 57 | flows = tf.layers.Conv2D(2, (3, 3), (1, 1), 'same')(features) 58 | if flows_up_prev is not None: 59 | # Residual connection 60 | flows += flows_up_prev 61 | 62 | if is_output: 63 | return flows 64 | else: 65 | _, h, w, _ = tf.unstack(tf.shape(flows)) 66 | flows_up = tf.image.resize_bilinear(flows, (2*h, 2*w))*2.0 67 | features_up = tf.image.resize_bilinear(features, (2*h, 2*w)) 68 | return flows, flows_up, features_up 69 | 70 | 71 | # Warping layer --------------------------------- 72 | def get_grid(x): 73 | batch_size, height, width, filters = tf.unstack(tf.shape(x)) 74 | Bg, Yg, Xg = tf.meshgrid(tf.range(batch_size), tf.range(height), tf.range(width), 75 | indexing = 'ij') 76 | # return indices volume indicate (batch, y, x) 77 | # return tf.stack([Bg, Yg, Xg], axis = 3) 78 | return Bg, Yg, Xg # return collectively for elementwise processing 79 | 80 | def nearest_warp(x, flow): 81 | grid_b, grid_y, grid_x = get_grid(x) 82 | flow = tf.cast(flow, tf.int32) 83 | 84 | warped_gy = tf.add(grid_y, flow[:,:,:,1]) # flow_y 85 | warped_gx = tf.add(grid_x, flow[:,:,:,0]) # flow_x 86 | # clip value by height/width limitation 87 | _, h, w, _ = tf.unstack(tf.shape(x)) 88 | 
warped_gy = tf.clip_by_value(warped_gy, 0, h-1) 89 | warped_gx = tf.clip_by_value(warped_gx, 0, w-1) 90 | 91 | warped_indices = tf.stack([grid_b, warped_gy, warped_gx], axis = 3) 92 | 93 | warped_x = tf.gather_nd(x, warped_indices) 94 | return warped_x 95 | 96 | def bilinear_warp(x, flow): 97 | _, h, w, _ = tf.unstack(tf.shape(x)) 98 | grid_b, grid_y, grid_x = get_grid(x) 99 | grid_b = tf.cast(grid_b, tf.float32) 100 | grid_y = tf.cast(grid_y, tf.float32) 101 | grid_x = tf.cast(grid_x, tf.float32) 102 | 103 | fx, fy = tf.unstack(flow, axis = -1) 104 | fx_0 = tf.floor(fx) 105 | fx_1 = fx_0+1 106 | fy_0 = tf.floor(fy) 107 | fy_1 = fy_0+1 108 | 109 | # warping indices 110 | h_lim = tf.cast(h-1, tf.float32) 111 | w_lim = tf.cast(w-1, tf.float32) 112 | gy_0 = tf.clip_by_value(grid_y + fy_0, 0., h_lim) 113 | gy_1 = tf.clip_by_value(grid_y + fy_1, 0., h_lim) 114 | gx_0 = tf.clip_by_value(grid_x + fx_0, 0., w_lim) 115 | gx_1 = tf.clip_by_value(grid_x + fx_1, 0., w_lim) 116 | 117 | g_00 = tf.cast(tf.stack([grid_b, gy_0, gx_0], axis = 3), tf.int32) 118 | g_01 = tf.cast(tf.stack([grid_b, gy_0, gx_1], axis = 3), tf.int32) 119 | g_10 = tf.cast(tf.stack([grid_b, gy_1, gx_0], axis = 3), tf.int32) 120 | g_11 = tf.cast(tf.stack([grid_b, gy_1, gx_1], axis = 3), tf.int32) 121 | 122 | # gather contents 123 | x_00 = tf.gather_nd(x, g_00) 124 | x_01 = tf.gather_nd(x, g_01) 125 | x_10 = tf.gather_nd(x, g_10) 126 | x_11 = tf.gather_nd(x, g_11) 127 | 128 | # coefficients 129 | c_00 = tf.expand_dims((fy_1 - fy)*(fx_1 - fx), axis = 3) 130 | c_01 = tf.expand_dims((fy_1 - fy)*(fx - fx_0), axis = 3) 131 | c_10 = tf.expand_dims((fy - fy_0)*(fx_1 - fx), axis = 3) 132 | c_11 = tf.expand_dims((fy - fy_0)*(fx - fx_0), axis = 3) 133 | 134 | return c_00*x_00 + c_01*x_01 + c_10*x_10 + c_11*x_11 135 | 136 | class WarpingLayer(object): 137 | def __init__(self, warp_type = 'nearest', name = 'warping'): 138 | self.warp = warp_type 139 | self.name = name 140 | 141 | def __call__(self, x, flow): 142 | with tf.name_scope(self.name) as ns: 143 | assert self.warp in ['nearest', 'bilinear'] 144 | if self.warp == 'nearest': 145 | x_warped = nearest_warp(x, flow) 146 | else: 147 | x_warped = bilinear_warp(x, flow) 148 | return x_warped 149 | # Context module ----------------------------------------------- 150 | class FeatureExtractor_custom_RGB(object): 151 | def __init__(self, name='f_extractor'): 152 | self.filters = [6] 153 | self.name = name 154 | 155 | def __call__(self, images, reuse=False): 156 | with tf.variable_scope(self.name, reuse=reuse) as vs: 157 | x=images 158 | x = tf.layers.Conv2D(32, (3, 3), (1, 1), 'same')(x) 159 | x = tf.nn.leaky_relu(x, 0.1) 160 | x = tf.layers.Conv2D(24, (3, 3), (1, 1), 'same')(x) 161 | x = tf.nn.leaky_relu(x, 0.1) 162 | x = tf.layers.Conv2D(12, (3, 3), (1, 1), 'same')(x) 163 | return x 164 | 165 | class FeatureExtractor_custom_RGB_new(object): 166 | """ Feature pyramid extractor module""" 167 | 168 | def __init__(self, name='f_extractor'): 169 | self.filters = [6] 170 | self.name = name 171 | 172 | def __call__(self, images, reuse=False): 173 | with tf.variable_scope(self.name, reuse=reuse) as vs: 174 | x=images 175 | x = tf.layers.Conv2D(32, (3, 3), (1, 1), 'same')(x) 176 | x = tf.nn.leaky_relu(x, 0.1) 177 | return x 178 | 179 | class FeatureExtractor_custom(object): 180 | """ Feature pyramid extractor module""" 181 | 182 | def __init__(self, name='f_extractor'): 183 | self.filters = [6] 184 | self.name = name 185 | 186 | def __call__(self, images, reuse=False): 187 | """ 188 | Args: 189 | - images 
[Y_images, U_images, V_images] 190 | - Y_images (batch, h, w, 1),U_images (batch, h, w, 1),V_images (batch, h, w, 1) 191 | 192 | Returns: 193 | - features_pyramid (batch, h_l, w_l, nch_l) for each scale levels: 194 | extracted feature pyramid (deep -> shallow order) 195 | """ 196 | with tf.variable_scope(self.name, reuse=reuse) as vs: 197 | l=0 198 | Y=images[0] 199 | _, h, w, _ = tf.unstack(tf.shape(Y)) 200 | U_up = tf.image.resize_bilinear(images[1], (h, w)) 201 | V_up = tf.image.resize_bilinear(images[2], (h, w)) 202 | YUV = tf.concat([Y,U_up,V_up],axis=3) 203 | x = tf.layers.Conv2D(32, (3, 3), (1, 1), 'same')(YUV) 204 | x = tf.nn.leaky_relu(x, 0.1) 205 | x = tf.layers.Conv2D(24, (3, 3), (1, 1), 'same')(x) 206 | x = tf.nn.leaky_relu(x, 0.1) 207 | x = tf.layers.Conv2D(12, (3, 3), (1, 1), 'same')(x) 208 | return x 209 | 210 | class ContextNetwork_RGB(object): 211 | """ Context module """ 212 | def __init__(self, name = 'context'): 213 | self.name = name 214 | 215 | def __call__(self, images, features): 216 | """ 217 | Args: 218 | - flows (batch, h, w, 2): optical flow 219 | - features (batch, h, w, 2): feature map passed from previous OF-estimator 220 | 221 | Returns: 222 | - flows (batch, h, w, 2): convolved optical flow 223 | """ 224 | with tf.variable_scope(self.name) as vs: 225 | x = tf.concat([images, features], axis = 3) 226 | x = tf.layers.Conv2D(64, (3, 3), (1, 1),'same', 227 | dilation_rate = (1, 1))(x) 228 | x = tf.nn.leaky_relu(x, 0.1) 229 | x = tf.layers.Conv2D(64, (3, 3), (1, 1),'same', 230 | dilation_rate = (2, 2))(x) 231 | x = tf.nn.leaky_relu(x, 0.1) 232 | x = tf.layers.Conv2D(32, (3, 3), (1, 1),'same', 233 | dilation_rate = (4, 4))(x) 234 | x = tf.nn.leaky_relu(x, 0.1) 235 | x = tf.layers.Conv2D(32, (3, 3), (1, 1),'same', 236 | dilation_rate = (8, 8))(x) 237 | x = tf.nn.leaky_relu(x, 0.1) 238 | x = tf.layers.Conv2D(32, (3, 3), (1, 1),'same', 239 | dilation_rate = (1, 1))(x) 240 | x = tf.nn.leaky_relu(x, 0.1) 241 | x = tf.layers.Conv2D(3, (3, 3), (1, 1),'same', 242 | dilation_rate = (1, 1))(x) 243 | output = images + x 244 | 245 | return output 246 | 247 | class ContextNetwork_RGB_ResNet(object): 248 | def __init__(self, name='context'): 249 | self.name = name 250 | 251 | def resblock(self, x): 252 | tmp = tf.nn.relu(tf.layers.Conv2D(64, (3, 3), (1, 1), 'same')(x)) 253 | tmp = tf.layers.Conv2D(64, (3, 3), (1, 1), 'same')(tmp) 254 | return x + tmp 255 | def __call__(self, images, features, reuse=False): 256 | with tf.variable_scope(self.name, reuse=reuse) as vs: 257 | x = tf.layers.Conv2D(64, (3, 3), (1, 1), 'same')(tf.concat([images,features],axis=3)) 258 | x = self.resblock(self.resblock(self.resblock(x))) 259 | x = tf.nn.leaky_relu(tf.layers.Conv2D(64, (3, 3), (1, 1), 'same')(x), 0.1) 260 | x1 = tf.nn.leaky_relu(tf.layers.Conv2D(64, (3, 3), (2, 2), 'same')(x), 0.1) 261 | x2 = tf.nn.leaky_relu(tf.layers.Conv2D(64, (3, 3), (2, 2), 'same')(x1), 0.1) 262 | _, h, w, _ = tf.unstack(tf.shape(x2)) 263 | x2_up = tf.image.resize_bilinear(self.resblock(x2), (2*h, 2*w)) 264 | x1 = self.resblock(x1)+x2_up 265 | _, h, w, _ = tf.unstack(tf.shape(x1)) 266 | x1_up = tf.image.resize_bilinear(self.resblock(x1), (2 * h, 2 * w)) 267 | x = self.resblock(self.resblock(x)) + x1_up 268 | x = self.resblock(self.resblock(self.resblock(x))) 269 | output = tf.layers.Conv2D(3, (3, 3), (1, 1), 'same')(x) + images 270 | return output 271 | 272 | class HyperPrior_resi(object): 273 | def __init__(self, input_channel=128, num_filters=128, name='bls2017ImgCompression'): 274 | self.input_channel = 
input_channel 275 | self.num_filters = num_filters 276 | self.name = name 277 | 278 | def analysis_transform(self, tensor, num_filters): 279 | """Builds the analysis transform.""" 280 | 281 | with tf.variable_scope("analysis"): 282 | tensor = tf.abs(tensor) 283 | with tf.variable_scope("layer_0"): 284 | tensor = tf.layers.Conv2D(num_filters, (3, 3), (1, 1), 'same')(tensor) 285 | tensor = tf.nn.relu(tensor) 286 | with tf.variable_scope("layer_1"): 287 | tensor = tf.layers.Conv2D(num_filters, (5, 5), (2, 2), 'same')(tensor) 288 | tensor = tf.nn.relu(tensor) 289 | with tf.variable_scope("layer_2"): 290 | tensor = tf.layers.Conv2D(num_filters, (5, 5), (2, 2), 'same')(tensor) 291 | return tensor 292 | 293 | def synthesis_transform(self, tensor, input_channel, num_filters): 294 | """Builds the synthesis transform.""" 295 | 296 | with tf.variable_scope("synthesis"): 297 | with tf.variable_scope("layer_0"): 298 | tensor = tf.layers.Conv2DTranspose(num_filters, (5, 5), (2, 2), 'same')(tensor) 299 | tensor = tf.nn.relu(tensor) 300 | with tf.variable_scope("layer_1"): 301 | tensor = tf.layers.Conv2DTranspose(num_filters, (5, 5), (2, 2), 'same')(tensor) 302 | tensor = tf.nn.relu(tensor) 303 | with tf.variable_scope("layer_2"): 304 | tensor = tf.layers.Conv2D(input_channel, (3, 3), (1, 1), 'same')(tensor) 305 | tensor = tf.nn.relu(tensor) 306 | return tensor 307 | 308 | def __call__(self, x, num_pixels, reuse=False, isTrain=True): 309 | with tf.variable_scope(self.name, reuse=reuse) as vs: 310 | y = self.analysis_transform(x, self.num_filters) 311 | entropy_bottleneck = tfc.EntropyBottleneck() 312 | bit_string = None 313 | if isTrain: 314 | y_tilde, likelihoods = entropy_bottleneck(y, training=True) 315 | else: 316 | string = entropy_bottleneck.compress(y) 317 | bit_string = tf.squeeze(string, axis=0) 318 | y_tilde, likelihoods = entropy_bottleneck(y, training=False) 319 | x_tilde = self.synthesis_transform(y_tilde, self.input_channel, self.num_filters) + 0.00000001 320 | train_bpp = tf.reduce_sum(tf.log(likelihoods)) / (-np.log(2) * num_pixels) 321 | return bit_string, entropy_bottleneck, x_tilde, train_bpp 322 | 323 | @property 324 | def vars(self): 325 | return [var for var in tf.global_variables() if self.name in var.name] 326 | 327 | class FeatureExtractor_mvlf(object): 328 | def __init__(self, name='mvlf_extractor'): 329 | self.name = name 330 | def __call__(self, x, reuse=False): 331 | with tf.variable_scope(self.name, reuse=reuse) as vs: 332 | x = tf.layers.Conv2D(32, (3, 3), (1, 1), 'same')(x) 333 | x = tf.nn.leaky_relu(x, 0.1) 334 | return x 335 | 336 | class ContextNetwork_mv_Unet(object): 337 | """ Context module """ 338 | def __init__(self, name = 'ContextNetwork_mv_Unet'): 339 | self.name = name 340 | 341 | def resblock(self, x): 342 | tmp = tf.nn.relu(tf.layers.Conv2D(64, (3, 3), (1, 1), 'same')(x)) 343 | tmp = tf.layers.Conv2D(64, (3, 3), (1, 1), 'same')(tmp) 344 | return x + tmp 345 | 346 | def __call__(self, x, flow): 347 | with tf.variable_scope(self.name) as vs: 348 | x = tf.layers.Conv2D(64, (3, 3), (1, 1), 'same', 349 | dilation_rate=(1, 1))(x) 350 | x = tf.nn.leaky_relu(x, 0.1) 351 | x = tf.layers.Conv2D(64, (3, 3), (1, 1), 'same', 352 | dilation_rate=(2, 2))(x) 353 | x = tf.nn.leaky_relu(x, 0.1) 354 | x = tf.layers.Conv2D(64, (3, 3), (1, 1), 'same', 355 | dilation_rate=(4, 4))(x) 356 | x = tf.nn.leaky_relu(x, 0.1) 357 | x = tf.layers.Conv2D(48, (3, 3), (1, 1), 'same', 358 | dilation_rate=(8, 8))(x) 359 | x = tf.nn.leaky_relu(x, 0.1) 360 | x = tf.layers.Conv2D(32, (3, 3), (1, 1), 
'same', 361 | dilation_rate=(16, 16))(x) 362 | x = tf.nn.leaky_relu(x, 0.1) 363 | x = tf.layers.Conv2D(32, (3, 3), (1, 1), 'same', 364 | dilation_rate=(1, 1))(x) 365 | x = tf.nn.leaky_relu(x, 0.1) 366 | x = tf.layers.Conv2D(2, (3, 3), (1, 1), 'same', 367 | dilation_rate=(1, 1))(x) 368 | return x+flow 369 | 370 | class Resideblur_ResNet_RGB(object): 371 | def __init__(self, name='Resideblur_ResNet'): 372 | self.name = name 373 | 374 | def resblock(self, x): 375 | tmp = tf.nn.relu(tf.layers.Conv2D(48, (3, 3), (1, 1), 'same')(x)) 376 | tmp = tf.layers.Conv2D(48, (3, 3), (1, 1), 'same')(tmp) 377 | return x + tmp 378 | def __call__(self, x, features, reuse=False): 379 | with tf.variable_scope(self.name, reuse=reuse) as vs: 380 | x = tf.layers.Conv2D(48, (3, 3), (1, 1), 'same')(tf.concat([x,features],axis=3)) 381 | x = self.resblock(x) 382 | x = tf.nn.leaky_relu(tf.layers.Conv2D(48, (3, 3), (1, 1), 'same')(x), 0.1) 383 | x1 = tf.nn.leaky_relu(tf.layers.Conv2D(48, (3, 3), (2, 2), 'same')(x), 0.1) 384 | x2 = tf.nn.leaky_relu(tf.layers.Conv2D(48, (3, 3), (2, 2), 'same')(x1), 0.1) 385 | _, h, w, _ = tf.unstack(tf.shape(x2)) 386 | x2_up = tf.image.resize_bilinear(self.resblock(x2), (2*h, 2*w)) 387 | x1 = self.resblock(x1)+x2_up 388 | _, h, w, _ = tf.unstack(tf.shape(x1)) 389 | x1_up = tf.image.resize_bilinear(self.resblock(x1), (2 * h, 2 * w)) 390 | x = self.resblock(x) + x1_up 391 | x = self.resblock(self.resblock(x)) 392 | output_RGB = tf.layers.Conv2D(3, (3, 3), (1, 1), 'same')(x) 393 | return output_RGB -------------------------------------------------------------------------------- /msssim.py: -------------------------------------------------------------------------------- 1 | # This is adapted from 2 | # https://github.com/tensorflow/models/blob/master/research/compression/image_encoder/msssim.py 3 | # 4 | # ============================================================================== 5 | #!/usr/bin/python 6 | # 7 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 8 | # 9 | # Licensed under the Apache License, Version 2.0 (the "License"); 10 | # you may not use this file except in compliance with the License. 11 | # You may obtain a copy of the License at 12 | # 13 | # http://www.apache.org/licenses/LICENSE-2.0 14 | # 15 | # Unless required by applicable law or agreed to in writing, software 16 | # distributed under the License is distributed on an "AS IS" BASIS, 17 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 18 | # See the License for the specific language governing permissions and 19 | # limitations under the License. 20 | # ============================================================================== 21 | 22 | """Python implementation of MS-SSIM. 
23 | 
24 | Usage: 
25 | 
26 | python msssim.py --original_image=original.png --compared_image=distorted.png 
27 | """ 
28 | import numpy as np 
29 | from scipy import signal 
30 | from scipy.ndimage.filters import convolve 
31 | 
32 | 
33 | def _FSpecialGauss(size, sigma): 
34 | """Function to mimic the 'fspecial' gaussian MATLAB function.""" 
35 | radius = size // 2 
36 | offset = 0.0 
37 | start, stop = -radius, radius + 1 
38 | if size % 2 == 0: 
39 | offset = 0.5 
40 | stop -= 1 
41 | x, y = np.mgrid[offset + start:stop, offset + start:stop] 
42 | assert len(x) == size 
43 | g = np.exp(-((x**2 + y**2)/(2.0 * sigma**2))) 
44 | return g / g.sum() 
45 | 
46 | 
47 | def _SSIMForMultiScale(img1, img2, max_val=255, filter_size=11, 
48 | filter_sigma=1.5, k1=0.01, k2=0.03): 
49 | """Return the Structural Similarity Map between `img1` and `img2`. 
50 | 
51 | This function attempts to match the functionality of ssim_index_new.m by 
52 | Zhou Wang: http://www.cns.nyu.edu/~lcv/ssim/msssim.zip 
53 | 
54 | Arguments: 
55 | img1: Numpy array holding the first RGB image batch. 
56 | img2: Numpy array holding the second RGB image batch. 
57 | max_val: the dynamic range of the images (i.e., the difference between the 
58 | maximum and the minimum allowed values). 
59 | filter_size: Size of blur kernel to use (will be reduced for small images). 
60 | filter_sigma: Standard deviation for Gaussian blur kernel (will be reduced 
61 | for small images). 
62 | k1: Constant used to maintain stability in the SSIM calculation (0.01 in 
63 | the original paper). 
64 | k2: Constant used to maintain stability in the SSIM calculation (0.03 in 
65 | the original paper). 
66 | 
67 | Returns: 
68 | Pair containing the mean SSIM and contrast sensitivity between `img1` and 
69 | `img2`. 
70 | 
71 | Raises: 
72 | RuntimeError: If input images don't have the same shape or don't have four 
73 | dimensions: [batch_size, height, width, depth]. 
74 | """ 
75 | if img1.shape != img2.shape: 
76 | raise RuntimeError('Input images must have the same shape (%s vs. %s).' 
77 | % (img1.shape, img2.shape)) 
78 | if img1.ndim != 4: 
79 | raise RuntimeError('Input images must have four dimensions, not %d' 
80 | % img1.ndim) 
81 | 
82 | img1 = img1.astype(np.float64) 
83 | img2 = img2.astype(np.float64) 
84 | _, height, width, _ = img1.shape 
85 | 
86 | # Filter size can't be larger than height or width of images. 
87 | size = min(filter_size, height, width) 
88 | 
89 | # Scale down sigma if a smaller filter size is used. 
90 | sigma = size * filter_sigma / filter_size if filter_size else 0 
91 | 
92 | if filter_size: 
93 | window = np.reshape(_FSpecialGauss(size, sigma), (1, size, size, 1)) 
94 | mu1 = signal.fftconvolve(img1, window, mode='valid') 
95 | mu2 = signal.fftconvolve(img2, window, mode='valid') 
96 | sigma11 = signal.fftconvolve(img1 * img1, window, mode='valid') 
97 | sigma22 = signal.fftconvolve(img2 * img2, window, mode='valid') 
98 | sigma12 = signal.fftconvolve(img1 * img2, window, mode='valid') 
99 | else: 
100 | # Empty blur kernel so no need to convolve. 
101 | mu1, mu2 = img1, img2 
102 | sigma11 = img1 * img1 
103 | sigma22 = img2 * img2 
104 | sigma12 = img1 * img2 
105 | 
106 | mu11 = mu1 * mu1 
107 | mu22 = mu2 * mu2 
108 | mu12 = mu1 * mu2 
109 | sigma11 -= mu11 
110 | sigma22 -= mu22 
111 | sigma12 -= mu12 
112 | 
113 | # Calculate intermediate values used by both ssim and cs_map. 
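# For reference: with means mu1/mu2, variances sigma11/sigma22 and
# covariance sigma12 as computed above, the code below follows the standard
# SSIM definition,
#   ssim = mean(((2*mu1*mu2 + c1) * (2*sigma12 + c2)) /
#               ((mu1^2 + mu2^2 + c1) * (sigma11 + sigma22 + c2)))
#   cs   = mean((2*sigma12 + c2) / (sigma11 + sigma22 + c2))
# where c1 = (k1*max_val)^2 and c2 = (k2*max_val)^2 stabilize the ratios.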
114 | c1 = (k1 * max_val) ** 2 
115 | c2 = (k2 * max_val) ** 2 
116 | v1 = 2.0 * sigma12 + c2 
117 | v2 = sigma11 + sigma22 + c2 
118 | ssim = np.mean((((2.0 * mu12 + c1) * v1) / ((mu11 + mu22 + c1) * v2))) 
119 | cs = np.mean(v1 / v2) 
120 | return ssim, cs 
121 | 
122 | 
123 | def MultiScaleSSIM(img1, img2, max_val=255, filter_size=11, filter_sigma=1.5, 
124 | k1=0.01, k2=0.03, weights=None): 
125 | """Return the MS-SSIM score between `img1` and `img2`. 
126 | 
127 | This function implements Multi-Scale Structural Similarity (MS-SSIM) Image 
128 | Quality Assessment according to Zhou Wang's paper, "Multi-scale structural 
129 | similarity for image quality assessment" (2003). 
130 | Link: https://ece.uwaterloo.ca/~z70wang/publications/msssim.pdf 
131 | 
132 | Author's MATLAB implementation: 
133 | http://www.cns.nyu.edu/~lcv/ssim/msssim.zip 
134 | 
135 | Arguments: 
136 | img1: Numpy array holding the first RGB image batch. 
137 | img2: Numpy array holding the second RGB image batch. 
138 | max_val: the dynamic range of the images (i.e., the difference between the 
139 | maximum and the minimum allowed values). 
140 | filter_size: Size of blur kernel to use (will be reduced for small images). 
141 | filter_sigma: Standard deviation for Gaussian blur kernel (will be reduced 
142 | for small images). 
143 | k1: Constant used to maintain stability in the SSIM calculation (0.01 in 
144 | the original paper). 
145 | k2: Constant used to maintain stability in the SSIM calculation (0.03 in 
146 | the original paper). 
147 | weights: List of weights for each level; if none, use five levels and the 
148 | weights from the original paper. 
149 | 
150 | Returns: 
151 | MS-SSIM score between `img1` and `img2`. 
152 | 
153 | Raises: 
154 | RuntimeError: If input images don't have the same shape or don't have four 
155 | dimensions: [batch_size, height, width, depth]. 
156 | """ 
157 | img1 = np.expand_dims(np.expand_dims(img1, axis=-1), axis=0) 
158 | img2 = np.expand_dims(np.expand_dims(img2, axis=-1), axis=0) 
159 | if img1.shape != img2.shape: 
160 | raise RuntimeError('Input images must have the same shape (%s vs. %s).' 
161 | % (img1.shape, img2.shape)) 
162 | if img1.ndim != 4: 
163 | raise RuntimeError('Input images must have four dimensions, not %d' 
164 | % img1.ndim) 
165 | 
166 | # Note: default weights don't sum to 1.0 but do match the paper / matlab code. 
167 | weights = np.array(weights if weights else 
168 | [0.0448, 0.2856, 0.3001, 0.2363, 0.1333]) 
169 | levels = weights.size 
170 | downsample_filter = np.ones((1, 2, 2, 1)) / 4.0 
171 | im1, im2 = [x.astype(np.float64) for x in [img1, img2]] 
172 | mssim = np.array([]) 
173 | mcs = np.array([]) 
174 | for _ in range(levels): 
175 | ssim, cs = _SSIMForMultiScale( 
176 | im1, im2, max_val=max_val, filter_size=filter_size, 
177 | filter_sigma=filter_sigma, k1=k1, k2=k2) 
178 | mssim = np.append(mssim, ssim) 
179 | mcs = np.append(mcs, cs) 
180 | filtered = [convolve(im, downsample_filter, mode='reflect') 
181 | for im in [im1, im2]] 
182 | im1, im2 = [x[:, ::2, ::2, :] for x in filtered] 
183 | return (np.prod(mcs[0:levels-1] ** weights[0:levels-1]) * 
184 | (mssim[levels-1] ** weights[levels-1])) 
185 | 
-------------------------------------------------------------------------------- 
/tensorflow_compression/__init__.py: 
-------------------------------------------------------------------------------- 
1 | # -*- coding: utf-8 -*- 
2 | # Copyright 2018 Google LLC. All Rights Reserved. 
3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ============================================================================== 16 | """Data compression tools.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | # Dependency imports 23 | 24 | from tensorflow.python.util.all_util import remove_undocumented 25 | 26 | # pylint: disable=wildcard-import 27 | from tensorflow_compression.python.layers.entropy_models import * 28 | from tensorflow_compression.python.layers.entropy_models_gauss import * 29 | from tensorflow_compression.python.layers.gdn import * 30 | from tensorflow_compression.python.layers.initializers import * 31 | from tensorflow_compression.python.layers.parameterizers import * 32 | from tensorflow_compression.python.layers.signal_conv import * 33 | from tensorflow_compression.python.ops.coder_ops import * 34 | from tensorflow_compression.python.ops.math_ops import * 35 | from tensorflow_compression.python.ops.padding_ops import * 36 | from tensorflow_compression.python.ops.spectral_ops import * 37 | # pylint: enable=wildcard-import 38 | 39 | remove_undocumented(__name__, [ 40 | "EntropyBottleneck","EntropyBottleneck_gauss", "GDN", "IdentityInitializer", "Parameterizer", 41 | "StaticParameterizer", "RDFTParameterizer", "NonnegativeParameterizer", 42 | "SignalConv1D", "SignalConv2D", "SignalConv3D", 43 | "upper_bound", "lower_bound", "same_padding_for_kernel", "irdft_matrix", 44 | "pmf_to_quantized_cdf", "range_decode", "range_encode", 45 | ]) 46 | -------------------------------------------------------------------------------- /tensorflow_compression/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JianpingLin/M-LVC_CVPR2020/d2c20635c61aa786115622694d03d3ba0e4f1d15/tensorflow_compression/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /tensorflow_compression/python/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Copyright 2018 Google LLC. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | # ============================================================================== 16 | -------------------------------------------------------------------------------- /tensorflow_compression/python/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JianpingLin/M-LVC_CVPR2020/d2c20635c61aa786115622694d03d3ba0e4f1d15/tensorflow_compression/python/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /tensorflow_compression/python/layers/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Copyright 2018 Google LLC. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ============================================================================== 16 | -------------------------------------------------------------------------------- /tensorflow_compression/python/layers/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JianpingLin/M-LVC_CVPR2020/d2c20635c61aa786115622694d03d3ba0e4f1d15/tensorflow_compression/python/layers/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /tensorflow_compression/python/layers/__pycache__/entropy_models.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JianpingLin/M-LVC_CVPR2020/d2c20635c61aa786115622694d03d3ba0e4f1d15/tensorflow_compression/python/layers/__pycache__/entropy_models.cpython-37.pyc -------------------------------------------------------------------------------- /tensorflow_compression/python/layers/__pycache__/entropy_models_gauss.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JianpingLin/M-LVC_CVPR2020/d2c20635c61aa786115622694d03d3ba0e4f1d15/tensorflow_compression/python/layers/__pycache__/entropy_models_gauss.cpython-37.pyc -------------------------------------------------------------------------------- /tensorflow_compression/python/layers/__pycache__/gdn.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JianpingLin/M-LVC_CVPR2020/d2c20635c61aa786115622694d03d3ba0e4f1d15/tensorflow_compression/python/layers/__pycache__/gdn.cpython-37.pyc -------------------------------------------------------------------------------- /tensorflow_compression/python/layers/__pycache__/initializers.cpython-37.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/JianpingLin/M-LVC_CVPR2020/d2c20635c61aa786115622694d03d3ba0e4f1d15/tensorflow_compression/python/layers/__pycache__/initializers.cpython-37.pyc -------------------------------------------------------------------------------- /tensorflow_compression/python/layers/__pycache__/parameterizers.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JianpingLin/M-LVC_CVPR2020/d2c20635c61aa786115622694d03d3ba0e4f1d15/tensorflow_compression/python/layers/__pycache__/parameterizers.cpython-37.pyc -------------------------------------------------------------------------------- /tensorflow_compression/python/layers/__pycache__/signal_conv.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JianpingLin/M-LVC_CVPR2020/d2c20635c61aa786115622694d03d3ba0e4f1d15/tensorflow_compression/python/layers/__pycache__/signal_conv.cpython-37.pyc -------------------------------------------------------------------------------- /tensorflow_compression/python/layers/entropy_models_gauss.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Copyright 2018 Google LLC. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ============================================================================== 16 | """Entropy bottleneck layer.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | # Dependency imports 23 | 24 | import tensorflow as tf 25 | 26 | from tensorflow.python.eager import context 27 | from tensorflow.python.framework import constant_op 28 | from tensorflow.python.framework import dtypes 29 | from tensorflow.python.framework import ops 30 | from tensorflow.python.framework import tensor_shape 31 | from tensorflow.python.keras.engine import base_layer 32 | from tensorflow.python.keras.engine import input_spec 33 | from tensorflow.python.ops import array_ops 34 | from tensorflow.python.ops import init_ops 35 | from tensorflow.python.ops import functional_ops 36 | from tensorflow.python.ops import math_ops 37 | from tensorflow.python.ops import nn 38 | from tensorflow.python.ops import random_ops 39 | from tensorflow.python.ops import state_ops 40 | from tensorflow.python.summary import summary 41 | 42 | from tensorflow_compression.python.ops import coder_ops 43 | from tensorflow_compression.python.ops import math_ops as tfc_math_ops 44 | import tensorflow.contrib.distributions as tfd 45 | 46 | class EntropyBottleneck_gauss(base_layer.Layer): 47 | """Entropy bottleneck layer. 48 | 49 | This layer models the entropy of the tensor passing through it. During 50 | training, this can be used to impose a (soft) entropy constraint on its 51 | activations, limiting the amount of information flowing through the layer. 
52 | After training, the layer can be used to compress any input tensor to a 53 | string, which may be written to a file, and to decompress a file which it 54 | previously generated back to a reconstructed tensor. The entropies estimated 55 | during training or evaluation are approximately equal to the average length of 56 | the strings in bits. 57 | 58 | The layer implements a flexible probability density model to estimate entropy 59 | of its input tensor, which is described in the appendix of the paper (please 60 | cite the paper if you use this code for scientific work): 61 | 62 | > "Variational image compression with a scale hyperprior"
63 | > J. Ballé, D. Minnen, S. Singh, S. J. Hwang, N. Johnston
64 | > https://arxiv.org/abs/1802.01436 65 | 66 | The layer assumes that the input tensor is at least 2D, with a batch dimension 67 | at the beginning and a channel dimension as specified by `data_format`. The 68 | layer trains an independent probability density model for each channel, but 69 | assumes that across all other dimensions, the inputs are i.i.d. (independent 70 | and identically distributed). 71 | 72 | Because data compression always involves discretization, the outputs of the 73 | layer are generally only approximations of its inputs. During training, 74 | discretization is modeled using additive uniform noise to ensure 75 | differentiability. The entropies computed during training are differential 76 | entropies. During evaluation, the data is actually quantized, and the 77 | entropies are discrete (Shannon entropies). To make sure the approximated 78 | tensor values are good enough for practical purposes, the training phase must 79 | be used to balance the quality of the approximation with the entropy, by 80 | adding an entropy term to the training loss. See the example in the package 81 | documentation to get started. 82 | 83 | Note: the layer always produces exactly one auxiliary loss and one update op, 84 | which are only significant for compression and decompression. To use the 85 | compression feature, the auxiliary loss must be minimized during or after 86 | training. After that, the update op must be executed at least once. 87 | 88 | Arguments: 89 | init_scale: Float. A scaling factor determining the initial width of the 90 | probability densities. This should be chosen big enough so that the 91 | range of values of the layer inputs roughly falls within the interval 92 | [`-init_scale`, `init_scale`] at the beginning of training. 93 | filters: An iterable of ints, giving the number of filters at each layer of 94 | the density model. Generally, the more filters and layers, the more 95 | expressive is the density model in terms of modeling more complicated 96 | distributions of the layer inputs. For details, refer to the paper 97 | referenced above. The default is `[3, 3, 3]`, which should be sufficient 98 | for most practical purposes. 99 | tail_mass: Float, between 0 and 1. The bottleneck layer automatically 100 | determines the range of input values that should be represented based on 101 | their frequency of occurrence. Values occurring in the tails of the 102 | distributions will be clipped to that range during compression. 103 | `tail_mass` determines the amount of probability mass in the tails which 104 | is cut off in the worst case. For example, the default value of `1e-9` 105 | means that at most 1 in a billion input samples will be clipped to the 106 | range. 107 | optimize_integer_offset: Boolean. Typically, the input values of this layer 108 | are floats, which means that quantization during evaluation can be 109 | performed with an arbitrary offset. By default, the layer determines that 110 | offset automatically. In special situations, such as when it is known that 111 | the layer will receive only full integer values during evaluation, it can 112 | be desirable to set this argument to `False` instead, in order to always 113 | quantize to full integer values. 114 | likelihood_bound: Float. If positive, the returned likelihood values are 115 | ensured to be greater than or equal to this value. This prevents very 116 | large gradients with a typical entropy loss (defaults to 1e-9). 117 | range_coder_precision: Integer, between 1 and 16. 
The precision of the range 118 | coder used for compression and decompression. This trades off computation 119 | speed with compression efficiency, where 16 is the slowest but most 120 | efficient setting. Choosing lower values may increase the average 121 | codelength slightly compared to the estimated entropies. 122 | data_format: Either `'channels_first'` or `'channels_last'` (default). 123 | trainable: Boolean. Whether the layer should be trained. 124 | name: String. The name of the layer. 125 | dtype: Default dtype of the layer's parameters (default of `None` means use 126 | the type of the first input). 127 | 128 | Read-only properties: 129 | init_scale: See above. 130 | filters: See above. 131 | tail_mass: See above. 132 | optimize_integer_offset: See above. 133 | likelihood_bound: See above. 134 | range_coder_precision: See above. 135 | data_format: See above. 136 | name: String. See above. 137 | dtype: See above. 138 | trainable_variables: List of trainable variables. 139 | non_trainable_variables: List of non-trainable variables. 140 | variables: List of all variables of this layer, trainable and non-trainable. 141 | updates: List of update ops of this layer. Always contains exactly one 142 | update op, which must be run once after the last training step, before 143 | `compress` or `decompress` is used. 144 | losses: List of losses added by this layer. Always contains exactly one 145 | auxiliary loss, which must be added to the training loss. 146 | 147 | Mutable properties: 148 | trainable: Boolean. Whether the layer should be trained. 149 | input_spec: Optional `InputSpec` object specifying the constraints on inputs 150 | that can be accepted by the layer. 151 | """ 152 | 153 | def __init__(self, init_scale=10, filters=(3, 3, 3), tail_mass=1e-9, 154 | optimize_integer_offset=True, likelihood_bound=1e-9, 155 | range_coder_precision=16, data_format="channels_last", **kwargs): 156 | super(EntropyBottleneck_gauss, self).__init__(**kwargs) 157 | self._init_scale = float(init_scale) 158 | self._filters = tuple(int(f) for f in filters) 159 | self._tail_mass = float(tail_mass) 160 | if not 0 < self.tail_mass < 1: 161 | raise ValueError( 162 | "`tail_mass` must be between 0 and 1, got {}.".format(self.tail_mass)) 163 | self._optimize_integer_offset = bool(optimize_integer_offset) 164 | self._likelihood_bound = float(likelihood_bound) 165 | self._range_coder_precision = int(range_coder_precision) 166 | self._data_format = data_format 167 | self._channel_axis(2) # trigger ValueError early 168 | self.input_spec = base_layer.InputSpec(min_ndim=2) 169 | 170 | @property 171 | def init_scale(self): 172 | return self._init_scale 173 | 174 | @property 175 | def filters(self): 176 | return self._filters 177 | 178 | @property 179 | def tail_mass(self): 180 | return self._tail_mass 181 | 182 | @property 183 | def optimize_integer_offset(self): 184 | return self._optimize_integer_offset 185 | 186 | @property 187 | def likelihood_bound(self): 188 | return self._likelihood_bound 189 | 190 | @property 191 | def range_coder_precision(self): 192 | return self._range_coder_precision 193 | 194 | @property 195 | def data_format(self): 196 | return self._data_format 197 | 198 | def _channel_axis(self, ndim): 199 | try: 200 | return {"channels_first": 1, "channels_last": ndim - 1}[self.data_format] 201 | except KeyError: 202 | raise ValueError("Unsupported `data_format` for {} layer: {}.".format( 203 | self.__class__.__name__, self.data_format)) 204 | 205 | def _logits_cumulative(self, inputs, stop_gradient): 
206 | """Evaluate logits of the cumulative densities. 
207 | 
208 | Args: 
209 | inputs: The values at which to evaluate the cumulative densities, expected 
210 | to be a `Tensor` of shape `(channels, 1, batch)`. 
211 | stop_gradient: Boolean. Whether to add `array_ops.stop_gradient` calls so 
212 | that the gradient of the output with respect to the density model 
213 | parameters is disconnected (the gradient with respect to `inputs` is 
214 | left untouched). 
215 | 
216 | Returns: 
217 | A `Tensor` of the same shape as `inputs`, containing the logits of the 
218 | cumulative densities evaluated at the given inputs. 
219 | """ 
220 | logits = inputs 
221 | 
222 | for i in range(len(self.filters) + 1): 
223 | matrix = self._matrices[i] 
224 | if stop_gradient: 
225 | matrix = array_ops.stop_gradient(matrix) 
226 | logits = math_ops.matmul(matrix, logits) 
227 | 
228 | bias = self._biases[i] 
229 | if stop_gradient: 
230 | bias = array_ops.stop_gradient(bias) 
231 | logits += bias 
232 | 
233 | if i < len(self._factors): 
234 | factor = self._factors[i] 
235 | if stop_gradient: 
236 | factor = array_ops.stop_gradient(factor) 
237 | logits += factor * math_ops.tanh(logits) 
238 | 
239 | return logits 
240 | 
241 | def build(self, input_shape): 
242 | """Builds the layer. 
243 | 
244 | Unlike the generic `EntropyBottleneck`, this Gaussian variant creates no 
245 | density model variables here. It only records the static input 
246 | dimensions and sets the input spec; the density model itself is 
247 | constructed later, from the standard deviations passed to 
248 | `build_gauss`. 
249 | 
250 | Args: 
251 | input_shape: Shape of the input tensor, used to get the number of 
252 | channels. 
253 | 
254 | Raises: 
255 | ValueError: if `input_shape` doesn't specify the length of the channel 
256 | dimension. 
257 | """ 
258 | input_shape = tensor_shape.TensorShape(input_shape) 
259 | channel_axis = self._channel_axis(input_shape.ndims) 
260 | channels = input_shape[channel_axis].value 
261 | self.n = input_shape[0].value 
262 | self.h = input_shape[1].value 
263 | self.w = input_shape[2].value 
264 | self.c = input_shape[3].value 
265 | if channels is None: 
266 | raise ValueError("The channel dimension of the inputs must be defined.") 
267 | self.input_spec = base_layer.InputSpec( 
268 | ndim=input_shape.ndims, axes={channel_axis: channels}) 
269 | super(EntropyBottleneck_gauss, self).build(input_shape) 
270 | def build_gauss(self, stddev): 
271 | """Builds the Gaussian density model. 
272 | 
273 | Unlike `build`, this method receives the standard deviations of the 
274 | zero-mean Gaussian conditionals directly. It computes their median and 
275 | tail quantiles in closed form, uses them to pick the integer range over 
276 | which each density needs to be sampled, and then tabulates the 
277 | probability mass functions and the quantized cumulative density 
278 | functions used by the range coder. `build` must have been called first, 
279 | as it records the static input dimensions used here. 
280 | 
281 | Args: 
282 | stddev: `Tensor` of standard deviations of the Gaussian conditionals, 
283 | one per input element; callers add a trailing singleton dimension, 
284 | giving shape `(n, h, w, c, 1)`. 
285 | 
286 | """ 
287 | 
288 | # To figure out what range of the densities to sample, we need to compute 
289 | # the quantiles given by `tail_mass / 2` and `1 - tail_mass / 2`. Since we 
290 | # can't take inverses of the cumulative directly, we make it an optimization 
291 | # problem: 
292 | # `quantiles = argmin(|logit(cumulative) - target|)` 
293 | # where `target` is `logit(tail_mass / 2)` or `logit(1 - tail_mass / 2)`. 
294 | # For the Gaussian conditionals used here, however, the quantiles are 
295 | # available in closed form via `norm_dist.quantile`, so no optimization is needed. 
296 | 
297 | # Here, the tail probability is fixed to 1 / (65536 * 100), i.e. well 
298 | # below one count at the default 16-bit range coder precision. 
299 | # Compute lower and upper tail quantile as well as median. 
300 | 
301 | tail_prob = constant_op.constant(1/(65536.0*100), dtype=self.dtype, shape=(self.n, self.h, self.w, self.c, 1)) 
302 | median_prob = constant_op.constant(0.5, dtype=self.dtype, shape=(self.n, self.h, self.w, self.c, 1)) 
303 | head_prob = constant_op.constant(1 - 1/(65536.0*100), dtype=self.dtype, shape=(self.n, self.h, self.w, self.c, 1)) 
304 | mean = constant_op.constant(0., dtype=self.dtype, shape=(self.n, self.h, self.w, self.c, 1)) 
305 | norm_dist = tfd.Normal(loc=mean, scale=stddev) 
306 | tail_quantile = norm_dist.quantile(tail_prob) 
307 | median_quantile = norm_dist.quantile(median_prob) 
308 | head_quantile = norm_dist.quantile(head_prob) 
309 | 
310 | # Save medians for `call`, `compress`, and `decompress`. 
311 | self._medians = median_quantile 
312 | if not self.optimize_integer_offset: 
313 | self._medians = math_ops.round(self._medians) 
314 | 
315 | # Largest distance observed between lower tail quantile and median, 
316 | # or between median and upper tail quantile. 
317 | minima = math_ops.reduce_max(self._medians - tail_quantile) 
318 | maxima = math_ops.reduce_max(head_quantile - self._medians) 
319 | minmax = math_ops.maximum(minima, maxima) 
320 | minmax = math_ops.ceil(minmax) 
321 | minmax = math_ops.maximum(minmax, 1) 
322 | minmax = math_ops.minimum(minmax, 1e3) 
323 | 
324 | # Sample the density up to `minmax` around the median. 
325 | samples = math_ops.range(-minmax, minmax + 1, dtype=self.dtype) 
326 | samples += self._medians 
327 | 
328 | half = constant_op.constant(.5, dtype=self.dtype) 
329 | # The probability mass of each integer bin is the difference of the 
330 | # Gaussian CDF at the two bin edges, i.e. 
331 | # P(round(x) == s) = CDF(s + .5) - CDF(s - .5). 
332 | pmf = abs(norm_dist.cdf(samples + half) - norm_dist.cdf(samples - half)) 
333 | # (The abs() above is defensive; the CDF is monotonically increasing.) 
334 | # Add tail masses to first and last bin of pmf, as we clip values for 
335 | # compression, meaning that out-of-range values get mapped to these bins. 
336 | pmf = array_ops.concat([ 
337 | math_ops.add_n([pmf[:, :, :, :, :1], norm_dist.cdf(samples[:, :, :, :, :1] - half)]), 
338 | pmf[:, :, :, :, 1:-1], 
339 | math_ops.add_n([pmf[:, :, :, :, -1:], norm_dist.cdf(-(samples[:, :, :, :, -1:] + half))]), 
340 | ], axis=-1) 
341 | self._pmf = pmf 
342 | 
343 | self._quantized_cdf = coder_ops.pmf_to_quantized_cdf( 
344 | pmf, precision=self.range_coder_precision) 
345 | self._quantized_cdf = array_ops.squeeze(self._quantized_cdf, [0]) 
346 | 
347 | # We need to supply an initializer without fully defined static shape here, 
348 | # or the variable will return the wrong dynamic shape later. A placeholder 
349 | # with default gets the trick done. 
350 | 
351 | def call_gauss(self, inputs, input_stddev, training): 
352 | """Pass a tensor through the bottleneck. 
353 | 
354 | Args: 
355 | inputs: The tensor to be passed through the bottleneck. 
356 | training: Boolean. If `True`, returns a differentiable approximation of 
357 | the inputs, and their likelihoods under the modeled probability 
358 | densities. If `False`, returns the quantized inputs and their 
359 | likelihoods under the corresponding probability mass function. 
These 360 | quantities can't be used for training, as they are not differentiable, 361 | but represent actual compression more closely. 362 | 363 | Returns: 364 | values: `Tensor` with the same shape as `inputs` containing the perturbed 365 | or quantized input values. 366 | likelihood: `Tensor` with the same shape as `inputs` containing the 367 | likelihood of `values` under the modeled probability distributions. 368 | 369 | Raises: 370 | ValueError: if `inputs` has different `dtype` or number of channels than 371 | a previous set of inputs the model was invoked with earlier. 372 | """ 373 | inputs = ops.convert_to_tensor(inputs) 374 | input_stddev = ops.convert_to_tensor(input_stddev) 375 | inputs = array_ops.expand_dims(inputs, axis=4) 376 | input_stddev = array_ops.expand_dims(input_stddev, axis=4) 377 | #self.build_gauss(input_stddev) 378 | half = constant_op.constant(.5, dtype=self.dtype) 379 | 380 | # Convert to (channels, 1, batch) format by commuting channels to front 381 | # and then collapsing. 382 | values = inputs 383 | stddev = input_stddev 384 | 385 | # Add noise or quantize. 386 | if training: 387 | noise = random_ops.random_uniform(array_ops.shape(values), -half, half) 388 | values = math_ops.add_n([values, noise]) 389 | elif self.optimize_integer_offset: 390 | values = math_ops.round(values - self._medians) + self._medians 391 | else: 392 | values = math_ops.round(values) 393 | 394 | mean = constant_op.constant(0., dtype=self.dtype, shape=(self.n, self.h, self.w, self.c, 1)) 395 | norm_dist = tfd.Normal(loc=mean, scale=stddev) 396 | likelihood = abs(norm_dist.cdf(values + half) - norm_dist.cdf(values - half)) 397 | if self.likelihood_bound > 0: 398 | likelihood_bound = constant_op.constant( 399 | self.likelihood_bound, dtype=self.dtype) 400 | likelihood = tfc_math_ops.lower_bound(likelihood, likelihood_bound) 401 | 402 | if not context.executing_eagerly(): 403 | values_shape, likelihood_shape = self.compute_output_shape(inputs.shape) 404 | values.set_shape(values_shape) 405 | likelihood.set_shape(likelihood_shape) 406 | 407 | values = array_ops.squeeze(values, [-1]) 408 | likelihood = array_ops.squeeze(likelihood, [-1]) 409 | 410 | return values, likelihood 411 | 412 | def call(self, inputs, input_stddev, training): 413 | """Pass a tensor through the bottleneck. 414 | 415 | Args: 416 | inputs: The tensor to be passed through the bottleneck. 417 | training: Boolean. If `True`, returns a differentiable approximation of 418 | the inputs, and their likelihoods under the modeled probability 419 | densities. If `False`, returns the quantized inputs and their 420 | likelihoods under the corresponding probability mass function. These 421 | quantities can't be used for training, as they are not differentiable, 422 | but represent actual compression more closely. 423 | 424 | Returns: 425 | values: `Tensor` with the same shape as `inputs` containing the perturbed 426 | or quantized input values. 427 | likelihood: `Tensor` with the same shape as `inputs` containing the 428 | likelihood of `values` under the modeled probability distributions. 429 | 430 | Raises: 431 | ValueError: if `inputs` has different `dtype` or number of channels than 432 | a previous set of inputs the model was invoked with earlier. 433 | """ 434 | values, likelihood = self.call_gauss(inputs, input_stddev, training) 435 | 436 | return values, likelihood 437 | 438 | def compress(self, inputs, input_stddev): 439 | """Compress inputs and store their binary representations into strings. 
440 | 
441 | Args: 
442 | inputs: `Tensor` with values to be compressed. 
443 | input_stddev: `Tensor` of standard deviations of the Gaussian conditionals. 
444 | Returns: 
445 | String `Tensor` vector containing the compressed representation of each 
446 | batch element of `inputs`. 
447 | """ 
448 | with ops.name_scope(self._name_scope()): 
449 | inputs = ops.convert_to_tensor(inputs) 
450 | if not self.built: 
451 | # Check input assumptions set before layer building, e.g. input rank. 
452 | input_spec.assert_input_compatibility(self.input_spec, inputs, 
453 | self.name) 
454 | if self.dtype is None: 
455 | self._dtype = inputs.dtype.base_dtype.name 
456 | self.build(inputs.shape) 
457 | input_stddev = ops.convert_to_tensor(input_stddev) 
458 | inputs = array_ops.expand_dims(inputs, axis=4) 
459 | input_stddev = array_ops.expand_dims(input_stddev, axis=4) 
460 | self.build_gauss(input_stddev) 
461 | return tf.zeros(shape=inputs.shape[:1], dtype=tf.string) 
462 | # Note: the unconditional return above makes the remainder of this method unreachable. 
463 | 
464 | # Expand dimensions of CDF to input dimensions, keeping the channels along 
465 | # the right dimension. 
466 | cdf = self._quantized_cdf 
467 | num_levels = array_ops.shape(cdf)[-1] - 1 
468 | half_num_levels = math_ops.cast(num_levels // 2, self.dtype) 
469 | 
470 | # Bring inputs to the right range by centering the range on the medians. 
471 | half = constant_op.constant(.5, dtype=self.dtype) 
472 | offsets = - self._medians + ( half_num_levels + half) 
473 | # Expand offsets to input dimensions and add to inputs. 
474 | values = inputs + offsets 
475 | 
476 | # Clip to range and cast to integers. Because we have added .5 above, and 
477 | # all values are positive, the cast effectively implements rounding. 
478 | values = math_ops.maximum(values, half) 
479 | values = math_ops.minimum( 
480 | values, math_ops.cast(num_levels, self.dtype) - half) 
481 | values = math_ops.cast(values, dtypes.int16) 
482 | 
483 | values = array_ops.squeeze(values, [-1]) 
484 | 
485 | def loop_body(tensor): 
486 | return coder_ops.range_encode( 
487 | tensor, cdf, precision=self.range_coder_precision) 
488 | strings = functional_ops.map_fn( 
489 | loop_body, values, dtype=dtypes.string, back_prop=False) 
490 | 
491 | if not context.executing_eagerly(): 
492 | strings.set_shape(inputs.shape[:1]) 
493 | 
494 | return strings 
495 | 
496 | def decompress(self, strings, shape, channels=None): 
497 | """Decompress values from their compressed string representations. 
498 | 
499 | Args: 
500 | strings: A string `Tensor` vector containing the compressed data. 
501 | shape: A `Tensor` vector of int32 type. Contains the shape of the tensor 
502 | to be decompressed, excluding the batch dimension. 
503 | channels: Integer. Specifies the number of channels statically. Need only 
504 | be set if the layer hasn't been built yet (i.e., this is the first input 
505 | it receives). 
506 | 
507 | Returns: 
508 | The decompressed `Tensor`. Its shape will be equal to `shape` prepended 
509 | with the batch dimension from `strings`. 
510 | 
511 | Raises: 
512 | ValueError: If the length of `shape` isn't available at graph construction 
513 | time. 
514 | """ 
515 | with ops.name_scope(self._name_scope()): 
516 | strings = ops.convert_to_tensor(strings) 
517 | shape = ops.convert_to_tensor(shape) 
518 | if self.built: 
519 | ndim = self.input_spec.ndim 
520 | channel_axis = self._channel_axis(ndim) 
521 | if channels is None: 
522 | channels = self.input_spec.axes[channel_axis] 
523 | else: 
524 | if not (shape.shape.is_fully_defined() and shape.shape.ndims == 1): 
525 | raise ValueError("`shape` must be a vector with known length.") 
526 | ndim = shape.shape[0].value + 1 
527 | channel_axis = self._channel_axis(ndim) 
528 | input_shape = ndim * [None] 
529 | input_shape[channel_axis] = channels 
530 | self.build(input_shape) 
531 | 
532 | # Tuple of slices for expanding dimensions of tensors below. 
533 | slices = ndim * [None] + [slice(None)] 
534 | slices[channel_axis] = slice(None) 
535 | slices = tuple(slices) 
536 | 
537 | # Expand dimensions of CDF to input dimensions, keeping the channels along 
538 | # the right dimension. 
539 | cdf = self._quantized_cdf[slices[1:]] 
540 | num_levels = array_ops.shape(cdf)[-1] - 1 
541 | 
542 | def loop_body(string): 
543 | return coder_ops.range_decode( 
544 | string, shape, cdf, precision=self.range_coder_precision) 
545 | outputs = functional_ops.map_fn( 
546 | loop_body, strings, dtype=dtypes.int16, back_prop=False) 
547 | outputs = math_ops.cast(outputs, self.dtype) 
548 | 
549 | medians = array_ops.squeeze(self._medians, [1, 2]) 
550 | offsets = math_ops.cast(num_levels // 2, self.dtype) - medians 
551 | outputs -= offsets[slices[:-1]] 
552 | 
553 | if not context.executing_eagerly(): 
554 | outputs_shape = ndim * [None] 
555 | outputs_shape[0] = strings.shape[0] 
556 | outputs_shape[channel_axis] = channels 
557 | outputs.set_shape(outputs_shape) 
558 | 
559 | return outputs 
560 | 
561 | def visualize(self): 
562 | """Multi-channel visualization of densities as images. 
563 | 
564 | Creates and returns an image summary visualizing the current probability 
565 | density estimates. The image contains one row for each channel. Within each 
566 | row, the pixel intensities are proportional to probability values, and each 
567 | row is centered on the median of the corresponding distribution. 
568 | 
569 | Returns: 
570 | The created image summary. 
571 | """ 
572 | with ops.name_scope(self._name_scope()): 
573 | image = self._pmf 
574 | image *= 255 / math_ops.reduce_max(image, axis=1, keepdims=True) 
575 | image = math_ops.cast(image + .5, dtypes.uint8) 
576 | image = image[None, :, :, None] 
577 | return summary.image("pmf", image, max_outputs=1) 
578 | 
579 | def compute_output_shape(self, input_shape): 
580 | input_shape = tensor_shape.TensorShape(input_shape) 
581 | return input_shape, input_shape 
582 | 
-------------------------------------------------------------------------------- 
/tensorflow_compression/python/layers/entropy_models_test.py: 
-------------------------------------------------------------------------------- 
1 | # -*- coding: utf-8 -*- 
2 | # Copyright 2018 Google LLC. All Rights Reserved. 
3 | # 
4 | # Licensed under the Apache License, Version 2.0 (the "License"); 
5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 
7 | # 
8 | # http://www.apache.org/licenses/LICENSE-2.0 
9 | # 
10 | # Unless required by applicable law or agreed to in writing, software 
11 | # distributed under the License is distributed on an "AS IS" BASIS, 
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ============================================================================== 16 | """Tests of EntropyBottleneck class.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | # Dependency imports 23 | 24 | import numpy as np 25 | import tensorflow as tf 26 | 27 | from tensorflow.python.platform import test 28 | 29 | import tensorflow_compression as tfc 30 | 31 | 32 | class EntropyBottleneckTest(test.TestCase): 33 | 34 | def test_noise(self): 35 | # Tests that the noise added is uniform noise between -0.5 and 0.5. 36 | inputs = tf.placeholder(tf.float32, (None, 1)) 37 | layer = tfc.EntropyBottleneck() 38 | noisy, _ = layer(inputs, training=True) 39 | with self.test_session() as sess: 40 | sess.run(tf.global_variables_initializer()) 41 | values = np.linspace(-50, 50, 100)[:, None] 42 | noisy, = sess.run([noisy], {inputs: values}) 43 | self.assertFalse(np.allclose(values, noisy, rtol=0, atol=.49)) 44 | self.assertAllClose(values, noisy, rtol=0, atol=.5) 45 | 46 | def test_quantization(self): 47 | # Tests that inputs are quantized to full integer values, even after 48 | # quantiles have been updated. 49 | inputs = tf.placeholder(tf.float32, (None, 1)) 50 | layer = tfc.EntropyBottleneck(optimize_integer_offset=False) 51 | quantized, _ = layer(inputs, training=False) 52 | opt = tf.train.GradientDescentOptimizer(learning_rate=1) 53 | self.assertTrue(len(layer.losses) == 1) 54 | step = opt.minimize(layer.losses[0]) 55 | with self.test_session() as sess: 56 | sess.run(tf.global_variables_initializer()) 57 | sess.run(step) 58 | values = np.linspace(-50, 50, 100)[:, None] 59 | quantized, = sess.run([quantized], {inputs: values}) 60 | self.assertAllClose(np.around(values), quantized, rtol=0, atol=1e-6) 61 | 62 | def test_quantization_optimized_offset(self): 63 | # Tests that inputs are not quantized to full integer values after quantiles 64 | # have been updated. However, the difference between input and output should 65 | # be between -0.5 and 0.5, and the offset must be consistent. 66 | inputs = tf.placeholder(tf.float32, (None, 1)) 67 | layer = tfc.EntropyBottleneck(optimize_integer_offset=True) 68 | quantized, _ = layer(inputs, training=False) 69 | opt = tf.train.GradientDescentOptimizer(learning_rate=1) 70 | self.assertTrue(len(layer.losses) == 1) 71 | step = opt.minimize(layer.losses[0]) 72 | with self.test_session() as sess: 73 | sess.run(tf.global_variables_initializer()) 74 | sess.run(step) 75 | values = np.linspace(-50, 50, 100)[:, None] 76 | quantized, = sess.run([quantized], {inputs: values}) 77 | self.assertAllClose(values, quantized, rtol=0, atol=.5) 78 | diff = np.ravel(np.around(values) - quantized) % 1 79 | self.assertAllClose(diff, np.full_like(diff, diff[0]), rtol=0, atol=5e-6) 80 | self.assertNotEqual(diff[0], 0) 81 | 82 | def test_codec(self): 83 | # Tests that inputs are compressed and decompressed correctly, and quantized 84 | # to full integer values, even after quantiles have been updated. 
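# For reference, the round trip below is the canonical usage pattern:
#   bitstrings = layer.compress(inputs)
#   decoded = layer.decompress(bitstrings, tf.shape(inputs)[1:])
# after running the layer's single update op to freeze the quantized CDF.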
85 | inputs = tf.placeholder(tf.float32, (1, None, 1)) 86 | layer = tfc.EntropyBottleneck( 87 | data_format="channels_last", init_scale=60, 88 | optimize_integer_offset=False) 89 | bitstrings = layer.compress(inputs) 90 | decoded = layer.decompress(bitstrings, tf.shape(inputs)[1:]) 91 | opt = tf.train.GradientDescentOptimizer(learning_rate=1) 92 | self.assertTrue(len(layer.losses) == 1) 93 | step = opt.minimize(layer.losses[0]) 94 | with self.test_session() as sess: 95 | sess.run(tf.global_variables_initializer()) 96 | sess.run(step) 97 | self.assertTrue(len(layer.updates) == 1) 98 | sess.run(layer.updates[0]) 99 | values = np.linspace(-50, 50, 100)[None, :, None] 100 | decoded, = sess.run([decoded], {inputs: values}) 101 | self.assertAllClose(np.around(values), decoded, rtol=0, atol=1e-6) 102 | 103 | def test_codec_optimized_offset(self): 104 | # Tests that inputs are compressed and decompressed correctly, and not 105 | # quantized to full integer values after quantiles have been updated. 106 | # However, the difference between input and output should be between -0.5 107 | # and 0.5, and the offset must be consistent. 108 | inputs = tf.placeholder(tf.float32, (1, None, 1)) 109 | layer = tfc.EntropyBottleneck( 110 | data_format="channels_last", init_scale=60, 111 | optimize_integer_offset=True) 112 | bitstrings = layer.compress(inputs) 113 | decoded = layer.decompress(bitstrings, tf.shape(inputs)[1:]) 114 | opt = tf.train.GradientDescentOptimizer(learning_rate=1) 115 | self.assertTrue(len(layer.losses) == 1) 116 | step = opt.minimize(layer.losses[0]) 117 | with self.test_session() as sess: 118 | sess.run(tf.global_variables_initializer()) 119 | sess.run(step) 120 | self.assertTrue(len(layer.updates) == 1) 121 | sess.run(layer.updates[0]) 122 | values = np.linspace(-50, 50, 100)[None, :, None] 123 | decoded, = sess.run([decoded], {inputs: values}) 124 | self.assertAllClose(values, decoded, rtol=0, atol=.5) 125 | diff = np.ravel(np.around(values) - decoded) % 1 126 | self.assertAllClose(diff, np.full_like(diff, diff[0]), rtol=0, atol=5e-6) 127 | self.assertNotEqual(diff[0], 0) 128 | 129 | def test_codec_clipping(self): 130 | # Tests that inputs are compressed and decompressed correctly, and clipped 131 | # to the expected range. 132 | inputs = tf.placeholder(tf.float32, (1, None, 1)) 133 | layer = tfc.EntropyBottleneck( 134 | data_format="channels_last", init_scale=40) 135 | bitstrings = layer.compress(inputs) 136 | decoded = layer.decompress(bitstrings, tf.shape(inputs)[1:]) 137 | with self.test_session() as sess: 138 | sess.run(tf.global_variables_initializer()) 139 | self.assertTrue(len(layer.updates) == 1) 140 | sess.run(layer.updates[0]) 141 | values = np.linspace(-50, 50, 100)[None, :, None] 142 | decoded, = sess.run([decoded], {inputs: values}) 143 | expected = np.clip(np.around(values), -40, 40) 144 | self.assertAllClose(expected, decoded, rtol=0, atol=1e-6) 145 | 146 | def test_channels_last(self): 147 | # Test the layer with more than one channel and multiple input dimensions, 148 | # with the channels in the last dimension. 
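# With `data_format="channels_last"`, an independent density model is fit
# to each of the two channels in the last dimension; all other dimensions
# are treated as i.i.d. samples.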
149 | inputs = tf.placeholder(tf.float32, (None, None, None, 2)) 150 | layer = tfc.EntropyBottleneck( 151 | data_format="channels_last", init_scale=50) 152 | noisy, _ = layer(inputs, training=True) 153 | quantized, _ = layer(inputs, training=False) 154 | bitstrings = layer.compress(inputs) 155 | decoded = layer.decompress(bitstrings, tf.shape(inputs)[1:]) 156 | with self.test_session() as sess: 157 | sess.run(tf.global_variables_initializer()) 158 | self.assertTrue(len(layer.updates) == 1) 159 | sess.run(layer.updates[0]) 160 | values = 5 * np.random.normal(size=(7, 5, 3, 2)) 161 | noisy, quantized, decoded = sess.run( 162 | [noisy, quantized, decoded], {inputs: values}) 163 | self.assertAllClose(values, noisy, rtol=0, atol=.5) 164 | self.assertAllClose(values, quantized, rtol=0, atol=.5) 165 | self.assertAllClose(values, decoded, rtol=0, atol=.5) 166 | 167 | def test_channels_first(self): 168 | # Test the layer with more than one channel and multiple input dimensions, 169 | # with the channel dimension right after the batch dimension. 170 | inputs = tf.placeholder(tf.float32, (None, 3, None, None)) 171 | layer = tfc.EntropyBottleneck( 172 | data_format="channels_first", init_scale=50) 173 | noisy, _ = layer(inputs, training=True) 174 | quantized, _ = layer(inputs, training=False) 175 | bitstrings = layer.compress(inputs) 176 | decoded = layer.decompress(bitstrings, tf.shape(inputs)[1:]) 177 | with self.test_session() as sess: 178 | sess.run(tf.global_variables_initializer()) 179 | self.assertTrue(len(layer.updates) == 1) 180 | sess.run(layer.updates[0]) 181 | values = 5 * np.random.normal(size=(2, 3, 5, 7)) 182 | noisy, quantized, decoded = sess.run( 183 | [noisy, quantized, decoded], {inputs: values}) 184 | self.assertAllClose(values, noisy, rtol=0, atol=.5) 185 | self.assertAllClose(values, quantized, rtol=0, atol=.5) 186 | self.assertAllClose(values, decoded, rtol=0, atol=.5) 187 | 188 | def test_compress(self): 189 | # Test compression and decompression, and produce test data for 190 | # `test_decompress`. If you set the constant at the end to `True`, this test 191 | # will fail and the log will contain the new test data. 192 | inputs = tf.placeholder(tf.float32, (2, 3, 10)) 193 | layer = tfc.EntropyBottleneck( 194 | data_format="channels_first", filters=(), init_scale=2) 195 | bitstrings = layer.compress(inputs) 196 | decoded = layer.decompress(bitstrings, tf.shape(inputs)[1:]) 197 | with self.test_session() as sess: 198 | sess.run(tf.global_variables_initializer()) 199 | self.assertTrue(len(layer.updates) == 1) 200 | sess.run(layer.updates[0]) 201 | values = 5 * np.random.uniform(size=(2, 3, 10)) - 2.5 202 | bitstrings, quantized_cdf, decoded = sess.run( 203 | [bitstrings, layer._quantized_cdf, decoded], {inputs: values}) 204 | self.assertAllClose(values, decoded, rtol=0, atol=.5) 205 | # Set this constant to `True` to log new test data for `test_decompress`. 206 | if False: # pylint:disable=using-constant-test 207 | assert False, (bitstrings, quantized_cdf, decoded) 208 | 209 | # Data generated by `test_compress`. 
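# These class-level constants are frozen fixtures consumed by
# `test_decompress` below, guarding against changes to the bitstream format.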
210 | # pylint:disable=g-inconsistent-quotes,bad-whitespace 211 | bitstrings = np.array([ 212 | b'\x1e\xbag}\xc2\xdaN\x8b\xbd.', 213 | b'\x8dF\xf0%\x1cv\xccllW' 214 | ], dtype=object) 215 | 216 | quantized_cdf = np.array([ 217 | [ 0, 15636, 22324, 30145, 38278, 65536], 218 | [ 0, 19482, 26927, 35052, 42904, 65535], 219 | [ 0, 21093, 28769, 36919, 44578, 65536] 220 | ], dtype=np.int32) 221 | 222 | expected = np.array([ 223 | [[-2., 1., 0., -2., -1., -2., -2., -2., 2., -1.], 224 | [ 1., 2., 1., 0., -2., -2., 1., 2., 0., 1.], 225 | [ 2., 0., -2., 2., 0., -1., -2., 0., 2., 0.]], 226 | [[ 1., 2., 0., -1., 1., 2., 1., 1., 2., -2.], 227 | [ 2., -1., -1., 0., -1., 2., 0., 2., -2., 2.], 228 | [ 2., -2., -2., -1., -2., 1., -2., 0., 0., 0.]] 229 | ], dtype=np.float32) 230 | # pylint:enable=g-inconsistent-quotes,bad-whitespace 231 | 232 | def test_decompress(self): 233 | # Test that decompression of values compressed with a previous version 234 | # works, i.e. that the file format doesn't change across revisions. 235 | bitstrings = tf.placeholder(tf.string) 236 | input_shape = tf.placeholder(tf.int32) 237 | quantized_cdf = tf.placeholder(tf.int32) 238 | layer = tfc.EntropyBottleneck( 239 | data_format="channels_first", filters=(), dtype=tf.float32) 240 | layer.build(self.expected.shape) 241 | layer._quantized_cdf = quantized_cdf 242 | decoded = layer.decompress(bitstrings, input_shape[1:]) 243 | with self.test_session() as sess: 244 | sess.run(tf.global_variables_initializer()) 245 | decoded, = sess.run([decoded], { 246 | bitstrings: self.bitstrings, input_shape: self.expected.shape, 247 | quantized_cdf: self.quantized_cdf}) 248 | self.assertAllClose(self.expected, decoded, rtol=0, atol=1e-6) 249 | 250 | def test_build_decompress(self): 251 | # Test that layer can be built when `decompress` is the first call to it. 252 | bitstrings = tf.placeholder(tf.string) 253 | input_shape = tf.placeholder(tf.int32, shape=[3]) 254 | layer = tfc.EntropyBottleneck(dtype=tf.float32) 255 | layer.decompress(bitstrings, input_shape[1:], channels=5) 256 | self.assertTrue(layer.built) 257 | 258 | def test_pmf_normalization(self): 259 | # Test that probability mass functions are normalized correctly. 260 | layer = tfc.EntropyBottleneck(dtype=tf.float32) 261 | layer.build((None, 10)) 262 | with self.test_session() as sess: 263 | sess.run(tf.global_variables_initializer()) 264 | pmf, = sess.run([layer._pmf]) 265 | self.assertAllClose(np.ones(10), np.sum(pmf, axis=-1), rtol=0, atol=1e-6) 266 | 267 | def test_visualize(self): 268 | # Test that summary op can be constructed. 269 | layer = tfc.EntropyBottleneck(dtype=tf.float32) 270 | layer.build((None, 10)) 271 | summary = layer.visualize() 272 | with self.test_session() as sess: 273 | sess.run(tf.global_variables_initializer()) 274 | sess.run([summary]) 275 | 276 | def test_normalization(self): 277 | # Test that densities are normalized correctly. 278 | inputs = tf.placeholder(tf.float32, (None, 1)) 279 | layer = tfc.EntropyBottleneck(filters=(2,)) 280 | _, likelihood = layer(inputs, training=True) 281 | with self.test_session() as sess: 282 | sess.run(tf.global_variables_initializer()) 283 | x = np.repeat(np.arange(-200, 201), 1000)[:, None] 284 | likelihood, = sess.run([likelihood], {inputs: x}) 285 | self.assertEqual(x.shape, likelihood.shape) 286 | integral = np.sum(likelihood) * .001 287 | self.assertAllClose(1, integral, rtol=0, atol=2e-4) 288 | 289 | def test_entropy_estimates(self): 290 | # Test that entropy estimates match actual range coding. 
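# Both estimates below are computed as sum(-log2(likelihood)). The
# differential (training=True) and discrete (training=False) entropies
# should agree with the actual range-coded length to within about 0.5%,
# with the true codelength slightly larger due to coding overhead.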
291 | inputs = tf.placeholder(tf.float32, (1, None, 1)) 292 | layer = tfc.EntropyBottleneck( 293 | filters=(2, 3), data_format="channels_last") 294 | _, likelihood = layer(inputs, training=True) 295 | diff_entropy = tf.reduce_sum(tf.log(likelihood)) / -np.log(2) 296 | _, likelihood = layer(inputs, training=False) 297 | disc_entropy = tf.reduce_sum(tf.log(likelihood)) / -np.log(2) 298 | bitstrings = layer.compress(inputs) 299 | with self.test_session() as sess: 300 | sess.run(tf.global_variables_initializer()) 301 | self.assertTrue(len(layer.updates) == 1) 302 | sess.run(layer.updates[0]) 303 | diff_entropy, disc_entropy, bitstrings = sess.run( 304 | [diff_entropy, disc_entropy, bitstrings], 305 | {inputs: np.random.normal(size=(1, 10000, 1))}) 306 | codelength = 8 * sum(len(bitstring) for bitstring in bitstrings) 307 | self.assertAllClose(diff_entropy, disc_entropy, rtol=5e-3, atol=0) 308 | self.assertAllClose(disc_entropy, codelength, rtol=5e-3, atol=0) 309 | self.assertGreater(codelength, disc_entropy) 310 | 311 | 312 | if __name__ == "__main__": 313 | test.main() 314 | -------------------------------------------------------------------------------- /tensorflow_compression/python/layers/gdn.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Copyright 2018 Google LLC. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ============================================================================== 16 | """GDN layer.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | # Dependency imports 23 | 24 | from tensorflow.python.eager import context 25 | from tensorflow.python.framework import ops 26 | from tensorflow.python.framework import tensor_shape 27 | from tensorflow.python.layers import base 28 | from tensorflow.python.ops import array_ops 29 | from tensorflow.python.ops import init_ops 30 | from tensorflow.python.ops import math_ops 31 | from tensorflow.python.ops import nn 32 | 33 | from tensorflow_compression.python.layers import parameterizers 34 | 35 | 36 | _default_beta_param = parameterizers.NonnegativeParameterizer( 37 | minimum=1e-6) 38 | _default_gamma_param = parameterizers.NonnegativeParameterizer() 39 | 40 | 41 | class GDN(base.Layer): 42 | """Generalized divisive normalization layer. 43 | 44 | Based on the papers: 45 | 46 | > "Density modeling of images using a generalized normalization 47 | > transformation"
48 | > J. Ballé, V. Laparra, E.P. Simoncelli
49 | > https://arxiv.org/abs/1511.06281 50 | 51 | > "End-to-end optimized image compression"
52 | > J. Ballé, V. Laparra, E.P. Simoncelli
53 | > https://arxiv.org/abs/1611.01704
54 |
55 | Implements an activation function that is essentially a multivariate
56 | generalization of a particular sigmoid-type function:
57 |
58 | ```
59 | y[i] = x[i] / sqrt(beta[i] + sum_j(gamma[j, i] * x[j]^2))
60 | ```
61 |
62 | where `i` and `j` run over channels. This implementation never sums across
63 | spatial dimensions. It is similar to local response normalization, but much
64 | more flexible, as `beta` and `gamma` are trainable parameters.
65 |
66 | Arguments:
67 | inverse: Boolean. If `False` (default), compute GDN response. If `True`,
68 | compute IGDN response (one step of fixed point iteration to invert GDN;
69 | the division is replaced by multiplication).
70 | rectify: Boolean. If `True`, apply a `relu` nonlinearity to the inputs
71 | before calculating GDN response.
72 | gamma_init: The gamma matrix will be initialized as the identity matrix
73 | multiplied with this value. If set to zero, the layer is effectively
74 | initialized to the identity operation, since beta is initialized as one.
75 | A good default setting is somewhere between 0 and 0.5.
76 | data_format: Format of input tensor. Currently supports `'channels_first'`
77 | and `'channels_last'`.
78 | beta_parameterizer: Reparameterization for beta parameter. Defaults to
79 | `NonnegativeParameterizer` with a minimum value of `1e-6`.
80 | gamma_parameterizer: Reparameterization for gamma parameter. Defaults to
81 | `NonnegativeParameterizer` with a minimum value of `0`.
82 | activity_regularizer: Regularizer function for the output.
83 | trainable: Boolean, if `True`, also add variables to the graph collection
84 | `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`).
85 | name: String, the name of the layer. Layers with the same name will
86 | share weights, but to avoid mistakes we require `reuse=True` in such
87 | cases.
88 |
89 | Properties:
90 | inverse: Boolean, whether IGDN is computed (`True`) or GDN (`False`).
91 | rectify: Boolean, whether to apply `relu` before normalization or not.
92 | data_format: Format of input tensor. Currently supports `'channels_first'`
93 | and `'channels_last'`.
94 | beta: The beta parameter as defined above (1D `Tensor`).
95 | gamma: The gamma parameter as defined above (2D `Tensor`).
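Example (a minimal usage sketch; the class is exported as `tfc.GDN`, as in
`gdn_test.py` below):

```
x = tf.placeholder(tf.float32, (None, 32, 32, 64))
y = tfc.GDN(data_format="channels_last")(x)  # `y` has the same shape as `x`
```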
96 | """ 97 | 98 | def __init__(self, 99 | inverse=False, 100 | rectify=False, 101 | gamma_init=.1, 102 | data_format="channels_last", 103 | beta_parameterizer=_default_beta_param, 104 | gamma_parameterizer=_default_gamma_param, 105 | activity_regularizer=None, 106 | trainable=True, 107 | name=None, 108 | **kwargs): 109 | super(GDN, self).__init__(trainable=trainable, name=name, 110 | activity_regularizer=activity_regularizer, 111 | **kwargs) 112 | self.inverse = bool(inverse) 113 | self.rectify = bool(rectify) 114 | self._gamma_init = float(gamma_init) 115 | self.data_format = data_format 116 | self._beta_parameterizer = beta_parameterizer 117 | self._gamma_parameterizer = gamma_parameterizer 118 | self._channel_axis() # trigger ValueError early 119 | self.input_spec = base.InputSpec(min_ndim=2) 120 | 121 | def _channel_axis(self): 122 | try: 123 | return {"channels_first": 1, "channels_last": -1}[self.data_format] 124 | except KeyError: 125 | raise ValueError("Unsupported `data_format` for GDN layer: {}.".format( 126 | self.data_format)) 127 | 128 | def build(self, input_shape): 129 | channel_axis = self._channel_axis() 130 | input_shape = tensor_shape.TensorShape(input_shape) 131 | num_channels = input_shape[channel_axis].value 132 | if num_channels is None: 133 | raise ValueError("The channel dimension of the inputs to `GDN` " 134 | "must be defined.") 135 | self._input_rank = input_shape.ndims 136 | self.input_spec = base.InputSpec(ndim=input_shape.ndims, 137 | axes={channel_axis: num_channels}) 138 | 139 | self.beta = self._beta_parameterizer( 140 | name="beta", shape=[num_channels], dtype=self.dtype, 141 | getter=self.add_variable, initializer=init_ops.Ones()) 142 | 143 | self.gamma = self._gamma_parameterizer( 144 | name="gamma", shape=[num_channels, num_channels], dtype=self.dtype, 145 | getter=self.add_variable, 146 | initializer=init_ops.Identity(gain=self._gamma_init)) 147 | 148 | self.built = True 149 | 150 | def call(self, inputs): 151 | inputs = ops.convert_to_tensor(inputs, dtype=self.dtype) 152 | ndim = self._input_rank 153 | 154 | if self.rectify: 155 | inputs = nn.relu(inputs) 156 | 157 | # Compute normalization pool. 158 | if ndim == 2: 159 | norm_pool = math_ops.matmul(math_ops.square(inputs), self.gamma) 160 | norm_pool = nn.bias_add(norm_pool, self.beta) 161 | elif self.data_format == "channels_last" and ndim <= 5: 162 | shape = self.gamma.shape.as_list() 163 | gamma = array_ops.reshape(self.gamma, (ndim - 2) * [1] + shape) 164 | norm_pool = nn.convolution(math_ops.square(inputs), gamma, "VALID") 165 | norm_pool = nn.bias_add(norm_pool, self.beta) 166 | else: # generic implementation 167 | # This puts channels in the last dimension regardless of input. 168 | norm_pool = math_ops.tensordot( 169 | math_ops.square(inputs), self.gamma, [[self._channel_axis()], [0]]) 170 | norm_pool += self.beta 171 | if self.data_format == "channels_first": 172 | # Return to channels_first format if necessary. 
173 | axes = list(range(ndim - 1)) 174 | axes.insert(1, ndim - 1) 175 | norm_pool = array_ops.transpose(norm_pool, axes) 176 | 177 | if self.inverse: 178 | norm_pool = math_ops.sqrt(norm_pool) 179 | else: 180 | norm_pool = math_ops.rsqrt(norm_pool) 181 | outputs = inputs * norm_pool 182 | 183 | if not context.executing_eagerly(): 184 | outputs.set_shape(self.compute_output_shape(inputs.shape)) 185 | return outputs 186 | 187 | def compute_output_shape(self, input_shape): 188 | return tensor_shape.TensorShape(input_shape) 189 | -------------------------------------------------------------------------------- /tensorflow_compression/python/layers/gdn_test.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Copyright 2018 Google LLC. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ============================================================================== 16 | """Tests of GDN layer.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | # Dependency imports 23 | 24 | import numpy as np 25 | import tensorflow as tf 26 | import tensorflow_compression as tfc 27 | 28 | 29 | class GDNTest(tf.test.TestCase): 30 | 31 | def _run_gdn(self, x, shape, inverse, rectify, data_format): 32 | inputs = tf.placeholder(tf.float32, shape) 33 | layer = tfc.GDN( 34 | inverse=inverse, rectify=rectify, data_format=data_format) 35 | outputs = layer(inputs) 36 | with self.test_session() as sess: 37 | tf.global_variables_initializer().run() 38 | y, = sess.run([outputs], {inputs: x}) 39 | return y 40 | 41 | def test_invalid_data_format(self): 42 | x = np.random.uniform(size=(1, 2, 3, 4)) 43 | with self.assertRaises(ValueError): 44 | self._run_gdn(x, x.shape, False, False, "NHWC") 45 | 46 | def test_unknown_dim(self): 47 | x = np.random.uniform(size=(1, 2, 3, 4)) 48 | with self.assertRaises(ValueError): 49 | self._run_gdn(x, 4 * [None], False, False, "channels_last") 50 | 51 | def test_channels_last(self): 52 | for ndim in [2, 3, 4, 5, 6]: 53 | x = np.random.uniform(size=(1, 2, 3, 4, 5, 6)[:ndim]) 54 | y = self._run_gdn(x, x.shape, False, False, "channels_last") 55 | self.assertEqual(x.shape, y.shape) 56 | self.assertAllClose(y, x / np.sqrt(1 + .1 * (x ** 2)), rtol=0, atol=1e-6) 57 | 58 | def test_channels_first(self): 59 | for ndim in [2, 3, 4, 5, 6]: 60 | x = np.random.uniform(size=(6, 5, 4, 3, 2, 1)[:ndim]) 61 | y = self._run_gdn(x, x.shape, False, False, "channels_first") 62 | self.assertEqual(x.shape, y.shape) 63 | self.assertAllClose( 64 | y, x / np.sqrt(1 + .1 * (x ** 2)), rtol=0, atol=1e-6) 65 | 66 | def test_wrong_dims(self): 67 | x = np.random.uniform(size=(3,)) 68 | with self.assertRaises(ValueError): 69 | self._run_gdn(x, x.shape, False, False, "channels_last") 70 | with self.assertRaises(ValueError): 71 | self._run_gdn(x, x.shape, True, True, "channels_first") 72 | 73 | def test_igdn(self): 
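# IGDN replaces the division by a multiplication, so with the default
# initialization (beta = 1, gamma = 0.1 * I) the expected output is
# x * sqrt(1 + 0.1 * x^2), mirroring the GDN tests above.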
74 | x = np.random.uniform(size=(1, 2, 3, 4)) 75 | y = self._run_gdn(x, x.shape, True, False, "channels_last") 76 | self.assertEqual(x.shape, y.shape) 77 | self.assertAllClose(y, x * np.sqrt(1 + .1 * (x ** 2)), rtol=0, atol=1e-6) 78 | 79 | def test_rgdn(self): 80 | x = np.random.uniform(-.5, .5, size=(1, 2, 3, 4)) 81 | y = self._run_gdn(x, x.shape, False, True, "channels_last") 82 | self.assertEqual(x.shape, y.shape) 83 | x = np.maximum(x, 0) 84 | self.assertAllClose(y, x / np.sqrt(1 + .1 * (x ** 2)), rtol=0, atol=1e-6) 85 | 86 | 87 | if __name__ == "__main__": 88 | tf.test.main() 89 | -------------------------------------------------------------------------------- /tensorflow_compression/python/layers/initializers.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Copyright 2018 Google LLC. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ============================================================================== 16 | """Initializers for layer classes.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | # Dependency imports 23 | 24 | from tensorflow.python.ops import array_ops 25 | from tensorflow.python.ops import linalg_ops 26 | 27 | 28 | class IdentityInitializer(object): 29 | """Initialize to the identity kernel with the given shape. 30 | 31 | This creates an n-D kernel suitable for `SignalConv*` with the requested 32 | support that produces an output identical to its input (except possibly at the 33 | signal boundaries). 34 | 35 | Note: The identity initializer in `tf.initializers` is only suitable for 36 | matrices, not for n-D convolution kernels (i.e., no spatial support). 37 | """ 38 | 39 | def __init__(self, gain=1): 40 | self.gain = float(gain) 41 | 42 | def __call__(self, shape, dtype=None, partition_info=None): 43 | del partition_info # unused 44 | assert len(shape) > 2, shape 45 | 46 | support = tuple(shape[:-2]) + (1, 1) 47 | indices = [[s // 2 for s in support]] 48 | updates = array_ops.constant([self.gain], dtype=dtype) 49 | kernel = array_ops.scatter_nd(indices, updates, support) 50 | 51 | assert shape[-2] == shape[-1], shape 52 | if shape[-1] != 1: 53 | kernel *= linalg_ops.eye(shape[-1], dtype=dtype) 54 | 55 | return kernel 56 | -------------------------------------------------------------------------------- /tensorflow_compression/python/layers/parameterizers.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Copyright 2018 Google LLC. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ============================================================================== 16 | """Parameterizations for layer classes.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | # Dependency imports 23 | 24 | from tensorflow.python.ops import array_ops 25 | from tensorflow.python.ops import math_ops 26 | 27 | from tensorflow_compression.python.ops import math_ops as cmath_ops 28 | from tensorflow_compression.python.ops import spectral_ops as spectral_ops 29 | 30 | 31 | class Parameterizer(object): 32 | """Parameterizer object (abstract base class). 33 | 34 | Parameterizer objects are immutable objects designed to facilitate 35 | reparameterization of model parameters (tensor variables). They are called 36 | just like `tf.get_variable` with an additional argument `getter` specifying 37 | the actual function call to generate a variable (in many cases, `getter` would 38 | be `tf.get_variable`). 39 | 40 | To achieve reparameterization, a parameterizer object wraps the provided 41 | initializer, regularizer, and the returned variable in its own Tensorflow 42 | code. 43 | """ 44 | pass 45 | 46 | 47 | class StaticParameterizer(Parameterizer): 48 | """A parameterization object that always returns a constant tensor. 49 | 50 | No variables are created, hence the parameter never changes. 51 | 52 | Args: 53 | initializer: An initializer object which will be called to produce the 54 | static parameter. 55 | """ 56 | 57 | def __init__(self, initializer): 58 | self.initializer = initializer 59 | 60 | def __call__(self, getter, name, shape, dtype, initializer, regularizer=None): 61 | del getter, name, initializer, regularizer # unused 62 | return self.initializer(shape, dtype) 63 | 64 | 65 | class RDFTParameterizer(Parameterizer): 66 | """Object encapsulating RDFT reparameterization. 67 | 68 | This uses the real-input discrete Fourier transform (RDFT) of a kernel as 69 | its parameterization. The inverse RDFT is applied to the variable to produce 70 | the parameter. 71 | 72 | (see https://en.wikipedia.org/wiki/Discrete_Fourier_transform) 73 | 74 | Args: 75 | dc: Boolean. If `False`, the DC component of the kernel RDFTs is not 76 | represented, forcing the filters to be highpass. Defaults to `True`. 
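Example (a short sketch in the spirit of `parameterizers_test.py`; the class
is exported as `tfc.RDFTParameterizer`):

```
param = tfc.RDFTParameterizer(dc=True)
kernel = param(getter=tf.get_variable, name="kernel", shape=(5, 5, 1, 8),
               dtype=tf.float32, initializer=tf.initializers.ones())
# `kernel` is backed by a variable named "kernel_rdft" that lives in the
# RDFT domain, so gradient updates happen in that domain as well.
```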
77 | """ 78 | 79 | def __init__(self, dc=True): 80 | self.dc = bool(dc) 81 | 82 | def __call__(self, getter, name, shape, dtype, initializer, regularizer=None): 83 | if all(s == 1 for s in shape[:-2]): 84 | return getter(name=name, shape=shape, dtype=dtype, 85 | initializer=initializer, regularizer=regularizer) 86 | var_shape = shape 87 | var_dtype = dtype 88 | size = var_shape[0] 89 | for s in var_shape[1:-2]: 90 | size *= s 91 | irdft_matrix = spectral_ops.irdft_matrix(var_shape[:-2], dtype=var_dtype) 92 | if self.dc: 93 | rdft_shape = (size, var_shape[-2] * var_shape[-1]) 94 | else: 95 | irdft_matrix = irdft_matrix[:, 1:] 96 | rdft_shape = (size - 1, var_shape[-2] * var_shape[-1]) 97 | rdft_dtype = var_dtype 98 | rdft_name = name + "_rdft" 99 | 100 | def rdft_initializer(shape, dtype=None, partition_info=None): 101 | assert tuple(shape) == rdft_shape, shape 102 | assert dtype == rdft_dtype, dtype 103 | init = initializer( 104 | var_shape, dtype=var_dtype, partition_info=partition_info) 105 | init = array_ops.reshape(init, (-1, rdft_shape[-1])) 106 | init = math_ops.matmul(irdft_matrix, init, transpose_a=True) 107 | return init 108 | 109 | def reparam(rdft): 110 | var = math_ops.matmul(irdft_matrix, rdft) 111 | var = array_ops.reshape(var, var_shape) 112 | return var 113 | 114 | if regularizer is not None: 115 | regularizer = lambda rdft: regularizer(reparam(rdft)) 116 | 117 | rdft = getter( 118 | name=rdft_name, shape=rdft_shape, dtype=rdft_dtype, 119 | initializer=rdft_initializer, regularizer=regularizer) 120 | return reparam(rdft) 121 | 122 | 123 | class NonnegativeParameterizer(Parameterizer): 124 | """Object encapsulating nonnegative parameterization as needed for GDN. 125 | 126 | The variable is subjected to an invertible transformation that slows down the 127 | learning rate for small values. 128 | 129 | Args: 130 | minimum: Float. Lower bound for parameters (defaults to zero). 131 | reparam_offset: Float. Offset added to the reparameterization of beta and 132 | gamma. The reparameterization of beta and gamma as their square roots lets 133 | the training slow down when their values are close to zero, which is 134 | desirable as small values in the denominator can lead to a situation where 135 | gradient noise on beta/gamma leads to extreme amounts of noise in the GDN 136 | activations. However, without the offset, we would get zero gradients if 137 | any elements of beta or gamma were exactly zero, and thus the training 138 | could get stuck. To prevent this, we add this small constant. The default 139 | value was empirically determined as a good starting point. Making it 140 | bigger potentially leads to more gradient noise on the activations, making 141 | it too small may lead to numerical precision issues. 
142 | """ 143 | 144 | def __init__(self, minimum=0, reparam_offset=2 ** -18): 145 | self.minimum = float(minimum) 146 | self.reparam_offset = float(reparam_offset) 147 | 148 | def __call__(self, getter, name, shape, dtype, initializer, regularizer=None): 149 | pedestal = array_ops.constant(self.reparam_offset ** 2, dtype=dtype) 150 | bound = array_ops.constant( 151 | (self.minimum + self.reparam_offset ** 2) ** .5, dtype=dtype) 152 | reparam_name = "reparam_" + name 153 | 154 | def reparam_initializer(shape, dtype=None, partition_info=None): 155 | init = initializer(shape, dtype=dtype, partition_info=partition_info) 156 | init = math_ops.sqrt(init + pedestal) 157 | return init 158 | 159 | def reparam(var): 160 | var = cmath_ops.lower_bound(var, bound) 161 | var = math_ops.square(var) - pedestal 162 | return var 163 | 164 | if regularizer is not None: 165 | regularizer = lambda rdft: regularizer(reparam(rdft)) 166 | 167 | var = getter( 168 | name=reparam_name, shape=shape, dtype=dtype, 169 | initializer=reparam_initializer, regularizer=regularizer) 170 | return reparam(var) 171 | -------------------------------------------------------------------------------- /tensorflow_compression/python/layers/parameterizers_test.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Copyright 2018 Google LLC. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | # ============================================================================== 16 | """Tests of parameterizers.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | # Dependency imports 23 | 24 | import numpy as np 25 | import tensorflow as tf 26 | import tensorflow_compression as tfc 27 | 28 | 29 | class ParameterizersTest(tf.test.TestCase): 30 | 31 | def _test_parameterizer(self, param, init, shape): 32 | var = param( 33 | getter=tf.get_variable, name="test", shape=shape, dtype=tf.float32, 34 | initializer=init, regularizer=None) 35 | with self.test_session() as sess: 36 | tf.global_variables_initializer().run() 37 | var, = sess.run([var]) 38 | return var 39 | 40 | def test_static_parameterizer(self): 41 | shape = (1, 2, 3, 4) 42 | var = self._test_parameterizer( 43 | tfc.StaticParameterizer(tf.initializers.zeros()), 44 | tf.initializers.random_uniform(), shape) 45 | self.assertEqual(var.shape, shape) 46 | self.assertAllClose(var, np.zeros(shape), rtol=0, atol=1e-7) 47 | 48 | def test_rdft_parameterizer(self): 49 | shape = (3, 4, 2, 1) 50 | var = self._test_parameterizer( 51 | tfc.RDFTParameterizer(), 52 | tf.initializers.ones(), shape) 53 | self.assertEqual(var.shape, shape) 54 | self.assertAllClose(var, np.ones(shape), rtol=0, atol=1e-6) 55 | 56 | def test_nonnegative_parameterizer(self): 57 | shape = (1, 2, 3, 4) 58 | var = self._test_parameterizer( 59 | tfc.NonnegativeParameterizer(), 60 | tf.initializers.random_uniform(), shape) 61 | self.assertEqual(var.shape, shape) 62 | self.assertTrue(np.all(var >= 0)) 63 | 64 | def test_positive_parameterizer(self): 65 | shape = (1, 2, 3, 4) 66 | var = self._test_parameterizer( 67 | tfc.NonnegativeParameterizer(minimum=.1), 68 | tf.initializers.random_uniform(), shape) 69 | self.assertEqual(var.shape, shape) 70 | self.assertTrue(np.all(var >= .1)) 71 | 72 | 73 | if __name__ == "__main__": 74 | tf.test.main() 75 | -------------------------------------------------------------------------------- /tensorflow_compression/python/ops/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Copyright 2018 Google LLC. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | # ============================================================================== 16 | -------------------------------------------------------------------------------- /tensorflow_compression/python/ops/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JianpingLin/M-LVC_CVPR2020/d2c20635c61aa786115622694d03d3ba0e4f1d15/tensorflow_compression/python/ops/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /tensorflow_compression/python/ops/__pycache__/coder_ops.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JianpingLin/M-LVC_CVPR2020/d2c20635c61aa786115622694d03d3ba0e4f1d15/tensorflow_compression/python/ops/__pycache__/coder_ops.cpython-37.pyc -------------------------------------------------------------------------------- /tensorflow_compression/python/ops/__pycache__/math_ops.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JianpingLin/M-LVC_CVPR2020/d2c20635c61aa786115622694d03d3ba0e4f1d15/tensorflow_compression/python/ops/__pycache__/math_ops.cpython-37.pyc -------------------------------------------------------------------------------- /tensorflow_compression/python/ops/__pycache__/padding_ops.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JianpingLin/M-LVC_CVPR2020/d2c20635c61aa786115622694d03d3ba0e4f1d15/tensorflow_compression/python/ops/__pycache__/padding_ops.cpython-37.pyc -------------------------------------------------------------------------------- /tensorflow_compression/python/ops/__pycache__/spectral_ops.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JianpingLin/M-LVC_CVPR2020/d2c20635c61aa786115622694d03d3ba0e4f1d15/tensorflow_compression/python/ops/__pycache__/spectral_ops.cpython-37.pyc -------------------------------------------------------------------------------- /tensorflow_compression/python/ops/coder_ops.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Copyright 2018 Google LLC. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | # ============================================================================== 16 | """Range coder operations.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | # Dependency imports 23 | 24 | from tensorflow.contrib.coder.python.ops import coder_ops 25 | 26 | pmf_to_quantized_cdf = coder_ops.pmf_to_quantized_cdf 27 | range_decode = coder_ops.range_decode 28 | range_encode = coder_ops.range_encode 29 | -------------------------------------------------------------------------------- /tensorflow_compression/python/ops/coder_ops_test.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Copyright 2018 Google LLC. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ============================================================================== 16 | """Coder operations tests.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | # Dependency imports 23 | 24 | import tensorflow as tf 25 | 26 | from tensorflow.python.platform import test 27 | 28 | import tensorflow_compression as tfc 29 | 30 | 31 | class CoderOpsTest(test.TestCase): 32 | """Coder ops test. 33 | 34 | Coder ops have C++ tests. Python test just ensures that Python binding is not 35 | broken. 36 | """ 37 | 38 | def testReadmeExample(self): 39 | data = tf.random_uniform((128, 128), 0, 10, dtype=tf.int32) 40 | histogram = tf.bincount(data, minlength=10, maxlength=10) 41 | cdf = tf.cumsum(histogram, exclusive=False) 42 | cdf = tf.pad(cdf, [[1, 0]]) 43 | cdf = tf.reshape(cdf, [1, 1, -1]) 44 | 45 | data = tf.cast(data, tf.int16) 46 | encoded = tfc.range_encode(data, cdf, precision=14) 47 | decoded = tfc.range_decode(encoded, tf.shape(data), cdf, precision=14) 48 | 49 | with self.test_session() as sess: 50 | self.assertAllEqual(*sess.run((data, decoded))) 51 | 52 | 53 | if __name__ == "__main__": 54 | test.main() 55 | -------------------------------------------------------------------------------- /tensorflow_compression/python/ops/math_ops.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Copyright 2018 Google LLC. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | # ============================================================================== 16 | """Math operations.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | # Dependency imports 23 | 24 | from tensorflow.python.framework import ops 25 | from tensorflow.python.ops import math_ops 26 | 27 | 28 | @ops.RegisterGradient("IdentityFirstOfTwoInputs") 29 | def _identity_first_of_two_inputs_grad(op, grad): 30 | """Gradient for `lower_bound` or `upper_bound` if `gradient == 'identity'`. 31 | 32 | Args: 33 | op: The op for which to calculate a gradient. 34 | grad: Gradient with respect to the output of the op. 35 | 36 | Returns: 37 | Gradient with respect to the inputs of the op. 38 | """ 39 | del op # unused 40 | return [grad, None] 41 | 42 | 43 | @ops.RegisterGradient("UpperBound") 44 | def _upper_bound_grad(op, grad): 45 | """Gradient for `upper_bound` if `gradient == 'identity_if_towards'`. 46 | 47 | Args: 48 | op: The op for which to calculate a gradient. 49 | grad: Gradient with respect to the output of the op. 50 | 51 | Returns: 52 | Gradient with respect to the inputs of the op. 53 | """ 54 | inputs, bound = op.inputs 55 | pass_through_if = math_ops.logical_or(inputs <= bound, grad > 0) 56 | return [math_ops.cast(pass_through_if, grad.dtype) * grad, None] 57 | 58 | 59 | @ops.RegisterGradient("LowerBound") 60 | def _lower_bound_grad(op, grad): 61 | """Gradient for `lower_bound` if `gradient == 'identity_if_towards'`. 62 | 63 | Args: 64 | op: The op for which to calculate a gradient. 65 | grad: Gradient with respect to the output of the op. 66 | 67 | Returns: 68 | Gradient with respect to the inputs of the op. 69 | """ 70 | inputs, bound = op.inputs 71 | pass_through_if = math_ops.logical_or(inputs >= bound, grad < 0) 72 | return [math_ops.cast(pass_through_if, grad.dtype) * grad, None] 73 | 74 | 75 | def upper_bound(inputs, bound, gradient="identity_if_towards", name=None): 76 | """Same as `tf.minimum`, but with helpful gradient for `inputs > bound`. 77 | 78 | This function behaves just like `tf.minimum`, but the behavior of the gradient 79 | with respect to `inputs` for input values that hit the bound depends on 80 | `gradient`: 81 | 82 | If set to `'disconnected'`, the returned gradient is zero for values that hit 83 | the bound. This is identical to the behavior of `tf.minimum`. 84 | 85 | If set to `'identity'`, the gradient is unconditionally replaced with the 86 | identity function (i.e., pretending this function does not exist). 87 | 88 | If set to `'identity_if_towards'`, the gradient is replaced with the identity 89 | function, but only if applying gradient descent would push the values of 90 | `inputs` towards the bound. For gradient values that push away from the bound, 91 | the returned gradient is still zero. 92 | 93 | Note: In the latter two cases, no gradient is returned for `bound`. 94 | Also, the implementation of `gradient == 'identity_if_towards'` currently 95 | assumes that the shape of `inputs` is the same as the shape of the output. It 96 | won't work reliably for all possible broadcasting scenarios. 97 | 98 | Args: 99 | inputs: Input tensor. 100 | bound: Upper bound for the input tensor. 101 | gradient: 'disconnected', 'identity', or 'identity_if_towards' (default). 102 | name: Name for this op. 103 | 104 | Returns: 105 | `tf.minimum(inputs, bound)` 106 | 107 | Raises: 108 | ValueError: for invalid value of `gradient`. 
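Example (a sketch matching the expectations coded in `math_ops_test.py`):

```
x = tf.placeholder(tf.float32)
y = tfc.upper_bound(x, 0.)  # forward pass equals tf.minimum(x, 0.)
# For x = [-1., 1.], y evaluates to [-1., 0.]. In the default
# 'identity_if_towards' mode, a gradient that would push the clipped value
# back below the bound passes through; one pushing it further above is
# zeroed. `tfc.lower_bound` mirrors this behavior for tf.maximum.
```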
109 | """ 110 | try: 111 | gradient = { 112 | "identity_if_towards": "UpperBound", 113 | "identity": "IdentityFirstOfTwoInputs", 114 | "disconnected": None, 115 | }[gradient] 116 | except KeyError: 117 | raise ValueError("Invalid value for `gradient`: '{}'.".format(gradient)) 118 | 119 | with ops.name_scope(name, "UpperBound", [inputs, bound]) as scope: 120 | inputs = ops.convert_to_tensor(inputs, name="inputs") 121 | bound = ops.convert_to_tensor( 122 | bound, name="bound", dtype=inputs.dtype) 123 | if gradient: 124 | with ops.get_default_graph().gradient_override_map({"Minimum": gradient}): 125 | return math_ops.minimum(inputs, bound, name=scope) 126 | else: 127 | return math_ops.minimum(inputs, bound, name=scope) 128 | 129 | 130 | def lower_bound(inputs, bound, gradient="identity_if_towards", name=None): 131 | """Same as `tf.maximum`, but with helpful gradient for `inputs < bound`. 132 | 133 | This function behaves just like `tf.maximum`, but the behavior of the gradient 134 | with respect to `inputs` for input values that hit the bound depends on 135 | `gradient`: 136 | 137 | If set to `'disconnected'`, the returned gradient is zero for values that hit 138 | the bound. This is identical to the behavior of `tf.maximum`. 139 | 140 | If set to `'identity'`, the gradient is unconditionally replaced with the 141 | identity function (i.e., pretending this function does not exist). 142 | 143 | If set to `'identity_if_towards'`, the gradient is replaced with the identity 144 | function, but only if applying gradient descent would push the values of 145 | `inputs` towards the bound. For gradient values that push away from the bound, 146 | the returned gradient is still zero. 147 | 148 | Note: In the latter two cases, no gradient is returned for `bound`. 149 | Also, the implementation of `gradient == 'identity_if_towards'` currently 150 | assumes that the shape of `inputs` is the same as the shape of the output. It 151 | won't work reliably for all possible broadcasting scenarios. 152 | 153 | Args: 154 | inputs: Input tensor. 155 | bound: Lower bound for the input tensor. 156 | gradient: 'disconnected', 'identity', or 'identity_if_towards' (default). 157 | name: Name for this op. 158 | 159 | Returns: 160 | `tf.maximum(inputs, bound)` 161 | 162 | Raises: 163 | ValueError: for invalid value of `gradient`. 164 | """ 165 | try: 166 | gradient = { 167 | "identity_if_towards": "LowerBound", 168 | "identity": "IdentityFirstOfTwoInputs", 169 | "disconnected": None, 170 | }[gradient] 171 | except KeyError: 172 | raise ValueError("Invalid value for `gradient`: '{}'.".format(gradient)) 173 | 174 | with ops.name_scope(name, "LowerBound", [inputs, bound]) as scope: 175 | inputs = ops.convert_to_tensor(inputs, name="inputs") 176 | bound = ops.convert_to_tensor( 177 | bound, name="bound", dtype=inputs.dtype) 178 | if gradient: 179 | with ops.get_default_graph().gradient_override_map({"Maximum": gradient}): 180 | return math_ops.maximum(inputs, bound, name=scope) 181 | else: 182 | return math_ops.maximum(inputs, bound, name=scope) 183 | -------------------------------------------------------------------------------- /tensorflow_compression/python/ops/math_ops_test.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Copyright 2018 Google LLC. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ============================================================================== 16 | """Tests for the math operations.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | # Dependency imports 23 | 24 | import tensorflow as tf 25 | import tensorflow_compression as tfc 26 | 27 | 28 | class MathTest(tf.test.TestCase): 29 | 30 | def _test_upper_bound(self, gradient): 31 | inputs = tf.placeholder(dtype=tf.float32) 32 | outputs = tfc.upper_bound(inputs, 0, gradient=gradient) 33 | pgrads, = tf.gradients([outputs], [inputs], [tf.ones_like(inputs)]) 34 | ngrads, = tf.gradients([outputs], [inputs], [-tf.ones_like(inputs)]) 35 | 36 | inputs_feed = [-1, 1] 37 | outputs_expected = [-1, 0] 38 | if gradient == "disconnected": 39 | pgrads_expected = [1, 0] 40 | ngrads_expected = [-1, 0] 41 | elif gradient == "identity": 42 | pgrads_expected = [1, 1] 43 | ngrads_expected = [-1, -1] 44 | else: 45 | pgrads_expected = [1, 1] 46 | ngrads_expected = [-1, 0] 47 | 48 | with self.test_session() as sess: 49 | outputs, pgrads, ngrads = sess.run( 50 | [outputs, pgrads, ngrads], {inputs: inputs_feed}) 51 | self.assertAllEqual(outputs, outputs_expected) 52 | self.assertAllEqual(pgrads, pgrads_expected) 53 | self.assertAllEqual(ngrads, ngrads_expected) 54 | 55 | def test_upper_bound_disconnected(self): 56 | self._test_upper_bound("disconnected") 57 | 58 | def test_upper_bound_identity(self): 59 | self._test_upper_bound("identity") 60 | 61 | def test_upper_bound_identity_if_towards(self): 62 | self._test_upper_bound("identity_if_towards") 63 | 64 | def test_upper_bound_invalid(self): 65 | with self.assertRaises(ValueError): 66 | self._test_upper_bound("invalid") 67 | 68 | def _test_lower_bound(self, gradient): 69 | inputs = tf.placeholder(dtype=tf.float32) 70 | outputs = tfc.lower_bound(inputs, 0, gradient=gradient) 71 | pgrads, = tf.gradients([outputs], [inputs], [tf.ones_like(inputs)]) 72 | ngrads, = tf.gradients([outputs], [inputs], [-tf.ones_like(inputs)]) 73 | 74 | inputs_feed = [-1, 1] 75 | outputs_expected = [0, 1] 76 | if gradient == "disconnected": 77 | pgrads_expected = [0, 1] 78 | ngrads_expected = [0, -1] 79 | elif gradient == "identity": 80 | pgrads_expected = [1, 1] 81 | ngrads_expected = [-1, -1] 82 | else: 83 | pgrads_expected = [0, 1] 84 | ngrads_expected = [-1, -1] 85 | 86 | with self.test_session() as sess: 87 | outputs, pgrads, ngrads = sess.run( 88 | [outputs, pgrads, ngrads], {inputs: inputs_feed}) 89 | self.assertAllEqual(outputs, outputs_expected) 90 | self.assertAllEqual(pgrads, pgrads_expected) 91 | self.assertAllEqual(ngrads, ngrads_expected) 92 | 93 | def test_lower_bound_disconnected(self): 94 | self._test_lower_bound("disconnected") 95 | 96 | def test_lower_bound_identity(self): 97 | self._test_lower_bound("identity") 98 | 99 | def test_lower_bound_identity_if_towards(self): 100 | self._test_lower_bound("identity_if_towards") 101 | 102 | def test_lower_bound_invalid(self): 103 | with self.assertRaises(ValueError): 104 | self._test_lower_bound("invalid") 105 
| 106 | 107 | if __name__ == "__main__": 108 | tf.test.main() 109 | -------------------------------------------------------------------------------- /tensorflow_compression/python/ops/padding_ops.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Copyright 2018 Google LLC. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ============================================================================== 16 | """Padding ops.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | # Dependency imports 23 | 24 | 25 | def same_padding_for_kernel(shape, corr, strides_up=None): 26 | """Determine correct amount of padding for `same` convolution. 27 | 28 | To implement `'same'` convolutions, we first pad the image, and then perform a 29 | `'valid'` convolution or correlation. Given the kernel shape, this function 30 | determines the correct amount of padding so that the output of the convolution 31 | or correlation is the same size as the pre-padded input. 32 | 33 | Args: 34 | shape: Shape of the convolution kernel (without the channel dimensions). 35 | corr: Boolean. If `True`, assume cross correlation, if `False`, convolution. 36 | strides_up: If this is used for an upsampled convolution, specify the 37 | strides here. (For downsampled convolutions, specify `(1, 1)`: in that 38 | case, the strides don't matter.) 39 | 40 | Returns: 41 | The amount of padding at the beginning and end for each dimension. 42 | """ 43 | rank = len(shape) 44 | if strides_up is None: 45 | strides_up = rank * (1,) 46 | 47 | if corr: 48 | padding = [(s // 2, (s - 1) // 2) for s in shape] 49 | else: 50 | padding = [((s - 1) // 2, s // 2) for s in shape] 51 | 52 | padding = [((padding[i][0] - 1) // strides_up[i] + 1, 53 | (padding[i][1] - 1) // strides_up[i] + 1) for i in range(rank)] 54 | return padding 55 | -------------------------------------------------------------------------------- /tensorflow_compression/python/ops/padding_ops_test.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Copyright 2018 Google LLC. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | # ============================================================================== 16 | """Tests of padding ops.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | # Dependency imports 23 | 24 | import numpy as np 25 | import tensorflow as tf 26 | import tensorflow_compression as tfc 27 | 28 | 29 | class PaddingOpsTest(tf.test.TestCase): 30 | 31 | def test_same_padding_corr(self): 32 | for ishape in [[10], [11]]: 33 | inputs = np.zeros(ishape, dtype=np.float32) 34 | inputs[len(inputs) // 2] = 1 35 | for kshape in [[4], [5]]: 36 | kernel = np.zeros(kshape, dtype=np.float32) 37 | kernel[len(kernel) // 2] = 1 38 | outputs = tf.nn.convolution( 39 | tf.reshape(inputs, (1, 1, -1, 1)), 40 | tf.reshape(kernel, (1, -1, 1, 1)), 41 | padding="VALID", data_format="NHWC") 42 | with self.test_session() as sess: 43 | outputs = np.squeeze(sess.run(outputs)) 44 | pos_inp = np.squeeze(np.nonzero(inputs)) 45 | pos_out = np.squeeze(np.nonzero(outputs)) 46 | padding = tfc.same_padding_for_kernel(kshape, True) 47 | self.assertEqual(padding[0][0], pos_inp - pos_out) 48 | 49 | def test_same_padding_conv(self): 50 | for ishape in [[10], [11]]: 51 | inputs = np.zeros(ishape, dtype=np.float32) 52 | inputs[len(inputs) // 2] = 1 53 | for kshape in [[4], [5]]: 54 | kernel = np.zeros(kshape, dtype=np.float32) 55 | kernel[len(kernel) // 2] = 1 56 | outputs = tf.nn.conv2d_transpose( 57 | tf.reshape(inputs, (1, 1, -1, 1)), 58 | tf.reshape(kernel, (1, -1, 1, 1)), 59 | (1, 1, ishape[0] + kshape[0] - 1, 1), 60 | strides=(1, 1, 1, 1), padding="VALID", data_format="NHWC") 61 | outputs = outputs[:, :, (kshape[0] - 1):-(kshape[0] - 1), :] 62 | with self.test_session() as sess: 63 | outputs = np.squeeze(sess.run(outputs)) 64 | pos_inp = np.squeeze(np.nonzero(inputs)) 65 | pos_out = np.squeeze(np.nonzero(outputs)) 66 | padding = tfc.same_padding_for_kernel(kshape, False) 67 | self.assertEqual(padding[0][0], pos_inp - pos_out) 68 | 69 | 70 | if __name__ == "__main__": 71 | tf.test.main() 72 | -------------------------------------------------------------------------------- /tensorflow_compression/python/ops/spectral_ops.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Copyright 2018 Google LLC. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | # ==============================================================================
16 | """Spectral operations."""
17 |
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 |
22 | # Dependency imports
23 |
24 | import numpy as np
25 | from scipy import fftpack
26 |
27 | from tensorflow.python.framework import dtypes
28 | from tensorflow.python.framework import ops
29 | from tensorflow.python.ops import array_ops
30 |
31 |
32 | _matrix_cache = {}
33 |
34 |
35 | def irdft_matrix(shape, dtype=dtypes.float32):
36 | """Matrix for implementing kernel reparameterization with `tf.matmul`.
37 |
38 | This can be used to represent a kernel with the provided shape in the RDFT
39 | domain.
40 |
41 | Example code for kernel creation, assuming 2D kernels:
42 |
43 | ```
44 | def create_kernel(init):
45 | shape = init.shape.as_list()
46 | matrix = irdft_matrix(shape[:2])
47 | init = tf.reshape(init, (shape[0] * shape[1], shape[2] * shape[3]))
48 | init = tf.matmul(tf.transpose(matrix), init)
49 | kernel = tf.Variable(init)
50 | kernel = tf.matmul(matrix, kernel)
51 | kernel = tf.reshape(kernel, shape)
52 | return kernel
53 | ```
54 |
55 | Args:
56 | shape: Iterable of integers. Shape of kernel to apply this matrix to.
57 | dtype: `dtype` of returned matrix.
58 |
59 | Returns:
60 | `Tensor` of shape `(prod(shape), prod(shape))` and dtype `dtype`.
61 | """
62 | shape = tuple(int(s) for s in shape)
63 | dtype = dtypes.as_dtype(dtype)
64 | key = (ops.get_default_graph(), "irdft", shape, dtype.as_datatype_enum)
65 | matrix = _matrix_cache.get(key)
66 | if matrix is None:
67 | size = np.prod(shape)
68 | rank = len(shape)
69 | matrix = np.identity(size, dtype=np.float64).reshape((size,) + shape)
70 | for axis in range(rank):
71 | matrix = fftpack.rfft(matrix, axis=axis + 1)
72 | slices = (rank + 1) * [slice(None)]
73 | if shape[axis] % 2 == 1:
74 | slices[axis + 1] = slice(1, None)
75 | else:
76 | slices[axis + 1] = slice(1, -1)
77 | matrix[tuple(slices)] *= np.sqrt(2)
78 | matrix /= np.sqrt(size)
79 | matrix = np.reshape(matrix, (size, size))
80 | matrix = array_ops.constant(
81 | matrix, dtype=dtype, name="irdft_" + "x".join([str(s) for s in shape]))
82 | _matrix_cache[key] = matrix
83 | return matrix
84 |
--------------------------------------------------------------------------------
/tensorflow_compression/python/ops/spectral_ops_test.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # Copyright 2018 Google LLC. All Rights Reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | # ============================================================================== 16 | """Tests of spectral_ops.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | # Dependency imports 23 | 24 | import numpy as np 25 | import tensorflow as tf 26 | import tensorflow_compression as tfc 27 | 28 | 29 | class SpectralOpsTest(tf.test.TestCase): 30 | 31 | def test_irdft1_matrix(self): 32 | for shape in [(4,), (3,)]: 33 | size = shape[0] 34 | matrix = tfc.irdft_matrix(shape) 35 | # Test that the matrix is orthonormal. 36 | result = tf.matmul(matrix, tf.transpose(matrix)) 37 | with self.test_session() as sess: 38 | result, = sess.run([result]) 39 | self.assertAllClose(result, np.identity(size)) 40 | 41 | def test_irdft2_matrix(self): 42 | for shape in [(7, 4), (8, 9)]: 43 | size = shape[0] * shape[1] 44 | matrix = tfc.irdft_matrix(shape) 45 | # Test that the matrix is orthonormal. 46 | result = tf.matmul(matrix, tf.transpose(matrix)) 47 | with self.test_session() as sess: 48 | result, = sess.run([result]) 49 | self.assertAllClose(result, np.identity(size)) 50 | 51 | def test_irdft3_matrix(self): 52 | for shape in [(3, 4, 2), (6, 3, 1)]: 53 | size = shape[0] * shape[1] * shape[2] 54 | matrix = tfc.irdft_matrix(shape) 55 | # Test that the matrix is orthonormal. 56 | result = tf.matmul(matrix, tf.transpose(matrix)) 57 | with self.test_session() as sess: 58 | result, = sess.run([result]) 59 | self.assertAllClose(result, np.identity(size)) 60 | 61 | 62 | if __name__ == "__main__": 63 | tf.test.main() 64 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import time 4 | import tensorflow as tf 5 | import numpy as np 6 | import cv2 7 | from model import * 8 | from utils import * 9 | from flow_utils import vis_flow_image_final 10 | from yuv_import import * 11 | import flownet_models as models 12 | from PIL import Image 13 | 14 | 15 | class Tester(object): 16 | def __init__(self, args): 17 | self.args = args 18 | config = tf.ConfigProto() 19 | config.gpu_options.allow_growth = True 20 | self.sess = tf.Session(config=config) 21 | value_class = get_test_class_dic(self.args.test_class)[self.args.test_class] 22 | size_list = value_class['resolution'].split('x') 23 | width = int(size_list[0]) 24 | height = int(size_list[1]) 25 | self.image_size = [height, width] 26 | if not ((self.image_size[0] > 512) and (self.image_size[1] > 512)): 27 | self.patch_height, self.patch_width = self.image_size 28 | else: 29 | self.patches_h_list, self.patches_w_list, self.real_patches_h_list, self.real_patches_w_list, self.patch_height, self.patch_width, self.patch_num_heigt, self.patch_num_width = get_patches( 30 | self.image_size[0], self.image_size[1]) 31 | self.Chroma_size = [size // 2 for size in self.image_size] 32 | self._build_graph() 33 | 34 | def _build_graph(self): 35 | # Input images and ground truth optical flow definition 36 | with tf.name_scope('Data'): 37 | self.images_pre_rec_multiframes = tf.placeholder(tf.float32, 38 | shape=(self.args.batch_size, *self.image_size, 3, 4), 39 | name='images_pre_rec_multiframes') 40 | self.images_pre_rec_4, self.images_pre_rec_3, self.images_pre_rec_2, self.images_pre_rec = tf.unstack( 41 | self.images_pre_rec_multiframes, axis=4) 42 | self.images_cur_ori = tf.placeholder(tf.float32, shape=(self.args.batch_size, *self.image_size, 
3), 43 | name='images_cur_ori') 44 | self.images_cur_ori_patch = tf.placeholder(tf.float32, shape=( 45 | self.args.batch_size, self.patch_height, self.patch_width, 3), 46 | name='images_cur_ori_patch') 47 | self.images_pre_rec_patch = tf.placeholder(tf.float32, shape=( 48 | self.args.batch_size, self.patch_height, self.patch_width, 3), 49 | name='.images_pre_rec_patch') 50 | self.flow_ori_input = tf.placeholder(tf.float32, shape=(self.args.batch_size, *self.image_size, 2), 51 | name='flow_ori_input') 52 | self.flow_rec_input = tf.placeholder(tf.float32, shape=(self.args.batch_size, *self.image_size, 2), 53 | name='flow_rec_input') 54 | self.images_cur_pred_input = tf.placeholder(tf.float32, shape=(self.args.batch_size, *self.image_size, 3), 55 | name='images_cur_pred_input') 56 | self.flows_pre_rec = tf.placeholder(tf.float32, shape=(self.args.batch_size, *self.image_size, 2, 3), 57 | name='flows_pre_rec') 58 | self.flow_3_pre_rec, self.flow_2_pre_rec, self.flow_1_pre_rec = tf.unstack(self.flows_pre_rec, axis=4) 59 | ########ME-Net######## 60 | memodel = models.FlowNet2(height=self.patch_height, width=self.patch_width, name='flownet2') 61 | self.flow_ori_patch = tf.transpose(memodel.build_graph(tf.transpose(self.images_cur_ori_patch, [0, 3, 1, 2]), 62 | tf.transpose(self.images_pre_rec_patch, [0, 3, 1, 2])),[0, 2, 3, 1]) / 20.0 63 | 64 | ########MAMVP-Net######## 65 | flows_1_pre_rec_pyramid = build_flows_pyramid(self.flow_1_pre_rec, self.args.num_levels) 66 | flows_2_pre_rec_pyramid = build_flows_pyramid(self.flow_2_pre_rec, self.args.num_levels) 67 | flows_3_pre_rec_pyramid = build_flows_pyramid(self.flow_3_pre_rec, self.args.num_levels) 68 | mamvpmodel = MAMVPNet(num_levels=self.args.num_levels, 69 | warp_type=self.args.warp_type, 70 | use_dc=self.args.use_dc, 71 | output_level=self.args.output_level, 72 | name='amvpnet') 73 | self.flow_pred, _, _ = mamvpmodel(flows_3_pre_rec_pyramid, flows_2_pre_rec_pyramid, 74 | flows_1_pre_rec_pyramid) 75 | 76 | self.flow_diff = self.flow_ori_input - self.flow_pred 77 | ########MVD Autoencoder######## 78 | num_pixels = self.args.batch_size * self.image_size[0] * self.image_size[1] 79 | mvdmodel = bls2017ImgCompression_mvd_factor(2, self.args.mvd_M_filters, name='mvdnet') 80 | self.bit_string_mvd, entropy_bottleneck_mvd, self.flow_diff_rec_0, mvd_train_bpp = mvdmodel( 81 | self.flow_diff, num_pixels, reuse=False, isTrain=False) 82 | 83 | flow_rec = self.flow_pred + self.flow_diff_rec_0 84 | self.flow_rec_0 = flow_rec 85 | 86 | ########MV Refine-Net######## 87 | mvlfmodel = MVLoopFiltering(name='mvlfmodel') 88 | self.flow_rec = mvlfmodel(self.flow_3_pre_rec, self.flow_2_pre_rec, self.flow_1_pre_rec, flow_rec, 89 | self.images_pre_rec) 90 | self.flow_diff_rec = self.flow_rec - self.flow_pred 91 | 92 | ########MMC-Net######## 93 | mcmodel = MCNet_Multiple(name='MCNet') 94 | self.images_cur_pred, features_warped = mcmodel(self.images_pre_rec_4, self.images_pre_rec_3, 95 | self.images_pre_rec_2, 96 | self.images_pre_rec, self.flow_3_pre_rec, self.flow_2_pre_rec, 97 | self.flow_1_pre_rec, self.flow_rec_input) 98 | 99 | self.images_cur_resi = self.images_cur_ori - self.images_cur_pred_input 100 | ########Residual Autoencoder######## 101 | resimodel = bls2017ImgCompression_resi_RGB(3, self.args.resi_N_filters, self.args.resi_M_filters, 102 | name='resinet') 103 | self.bit_string_resi, entropy_bottleneck_resi, tensor_tilde, images_cur_resi_train_bpp, self.bit_string_resi_dev, entropy_bottleneck_dev, _, resi_dev_train_bpp = resimodel( 104 | 
self.images_cur_resi, num_pixels, reuse=False, isTrain=False) 105 | ########Residual Refine-Net######## 106 | resideblurmodel = ResiDeBlurNet(name='resideblurmodel') 107 | self.images_cur_resi_rec = resideblurmodel(tensor_tilde, self.images_cur_pred, features_warped) 108 | self.images_cur_rec = self.images_cur_pred_input + self.images_cur_resi_rec 109 | 110 | model_vars_restore = memodel.vars + mamvpmodel.vars + mvdmodel.vars + mvlfmodel.vars + mcmodel.vars + resimodel.vars + resideblurmodel.vars 111 | 112 | with tf.name_scope('Loss'): 113 | self._losses_mvd = [] 114 | self._losses_resi = [] 115 | loss = tf.reduce_mean(tf.squared_difference(self.images_cur_ori * 255.0, self.images_cur_rec * 255.0)) 116 | self._losses_resi.append(loss) 117 | self._losses_mvd.append(mvd_train_bpp) 118 | self._losses_resi.append(images_cur_resi_train_bpp) 119 | self._losses_resi.append(resi_dev_train_bpp) 120 | 121 | # Initialization 122 | self.sess.run(tf.global_variables_initializer()) 123 | 124 | if self.args.resume is not None: 125 | saver_0 = tf.train.Saver(model_vars_restore) 126 | print(f'Loading learned model from checkpoint {self.args.resume}') 127 | saver_0.restore(self.sess, self.args.resume) 128 | 129 | componet = ['RGB'] 130 | PSNR_sum_list = [] 131 | self._PSNR_list = [] 132 | self._MSSSIM_list = [] 133 | for i in range(len(componet)): 134 | ori = self.images_cur_ori[i] * 255 135 | rec = tf.round(tf.clip_by_value(self.images_cur_rec[i], 0, 1) * 255) 136 | PSNR = tf.squeeze(tf.image.psnr(ori, rec, 255)) 137 | MSSSIM = tf.squeeze(tf.image.ssim_multiscale(ori, rec, 255)) 138 | self._PSNR_list.append(PSNR) 139 | self._MSSSIM_list.append(MSSSIM) 140 | PSNR_sum_list.append(tf.summary.scalar('PSNR/' + componet[i], PSNR)) 141 | 142 | def test(self): 143 | Orig_dir = self.args.test_seq_dir 144 | x265enc_dir = os.path.join(self.args.exp_data_dir, 'I_frames_enc') 145 | # crf_list=[15,19,23,27,31,35,39,43] 146 | qp_dic = {16: 21, 24: 23, 40: 25, 64: 27} 147 | qp_list = [qp_dic[self.args.lmbda]] 148 | frames_to_be_encoded = 100 149 | Org_frm_list = list(range(frames_to_be_encoded)) 150 | classes_dict = get_test_class_dic(self.args.test_class) 151 | for key_class, value_class in classes_dict.items(): 152 | size_list = value_class['resolution'].split('x') 153 | width = int(size_list[0]) 154 | height = int(size_list[1]) 155 | for seq_idx in range(len(value_class['sequence_name'])): 156 | for qp in qp_list: 157 | ori_filename = os.path.join(Orig_dir, value_class['ori_yuv'][seq_idx]) 158 | print(key_class, value_class['sequence_name'][seq_idx], 'qp' + str(qp)) 159 | bits_list = [] 160 | RGB_PSNR_list = [] 161 | RGB_MSSSIM_list = [] 162 | 163 | ori_all_Y_list, ori_all_U_list, ori_all_V_list = yuv420_import(ori_filename, height, width, 164 | Org_frm_list, len(Org_frm_list), 165 | False, False, False, 0, False) 166 | ori_Y = ori_all_Y_list[0][np.newaxis, :, :, np.newaxis] 167 | ori_U = ori_all_U_list[0][np.newaxis, :, :, np.newaxis] 168 | ori_V = ori_all_V_list[0][np.newaxis, :, :, np.newaxis] 169 | RGB_ori = np.squeeze(YUV2RGB420_custom(ori_Y, ori_U, ori_V)) 170 | 171 | ori_file = os.path.join(x265enc_dir, value_class['sequence_name'][seq_idx] + '_' + value_class[ 172 | 'resolution'] + '.png') 173 | img_ori_save = Image.fromarray(RGB_ori) 174 | img_ori_save.save(ori_file) 175 | bin_file = os.path.join(x265enc_dir, 'enc_' + value_class['sequence_name'][seq_idx] + '_' + value_class[ 176 | 'resolution'] + '_' + str(value_class['frameRate'][seq_idx]) + '_qp' + str(qp) + '.bpg') 177 | rec_file = 
os.path.join(x265enc_dir, 'dec_' + value_class['sequence_name'][seq_idx] + '_' + \ 178 | value_class['resolution'] + '_' + str(value_class['frameRate'][seq_idx]) + '_qp' + str(qp) + '.png') 179 | os.system('bpgenc -f 444 -b 8 -q ' + str(qp) + ' ' + ori_file + ' -o ' + bin_file) 180 | os.system('bpgdec -o ' + rec_file + ' ' + bin_file) 181 | img_dec = Image.open(rec_file) 182 | RGB_rec = np.array(img_dec) 183 | 184 | Bits = os.path.getsize(bin_file) * 8 185 | bits_list.append(Bits) 186 | 187 | rgb_psnr, rgb_msssim = evaluate(RGB_ori, RGB_rec) 188 | RGB_PSNR_list.append(rgb_psnr) 189 | RGB_MSSSIM_list.append(rgb_msssim) 190 | print('I frame, total_bits:[%d],PSNR_RGB:[%.4f],MSSSIM_RGB:[%.5f]' % (bits_list[0], RGB_PSNR_list[0], RGB_MSSSIM_list[0])) 191 | 192 | images_prev_rec_tmp = RGB_rec[np.newaxis, :, :, :] / 255.0 193 | images_prev_rec = np.zeros((1, self.image_size[0], self.image_size[1], 3, 4), np.float32) 194 | for fr in range(4): 195 | images_prev_rec[:, :, :, :, fr] = images_prev_rec_tmp[:, :, :, :] 196 | 197 | flows_pre_rec = np.zeros((1, self.image_size[0], self.image_size[1], 2, 3), np.float32) 198 | start_time = time.time() 199 | for cur_indx in range(1, frames_to_be_encoded): 200 | cur_ori_Y = ori_all_Y_list[cur_indx][np.newaxis, :, :, np.newaxis] 201 | cur_ori_U = ori_all_U_list[cur_indx][np.newaxis, :, :, np.newaxis] 202 | cur_ori_V = ori_all_V_list[cur_indx][np.newaxis, :, :, np.newaxis] 203 | images_cur_ori = YUV2RGB420_custom(cur_ori_Y, cur_ori_U,cur_ori_V) / 255.0 204 | images_pre_rec = images_prev_rec[:, :, :, :, 3] 205 | if not ((self.image_size[0] > self.patch_height) and (self.image_size[1] > self.patch_width)): 206 | flow_ori_test = self.sess.run(self.flow_ori_patch, 207 | feed_dict={self.images_cur_ori_patch: images_cur_ori, 208 | self.images_pre_rec_patch: images_pre_rec}) 209 | else: 210 | images_cur_ori_patches_list = reshape2patches_tesnsor(images_cur_ori, self.patches_h_list, 211 | self.patches_w_list) 212 | images_pre_rec_patches_list = reshape2patches_tesnsor(images_pre_rec, self.patches_h_list, 213 | self.patches_w_list) 214 | flow_ori_patches_list = [] 215 | for idx, (images_cur_ori_patch, images_pre_rec_patch) in enumerate( 216 | zip(images_cur_ori_patches_list, images_pre_rec_patches_list)): 217 | flow_ori_patch = self.sess.run(self.flow_ori_patch, 218 | feed_dict={ 219 | self.images_cur_ori_patch: images_cur_ori_patch, 220 | self.images_pre_rec_patch: images_pre_rec_patch}) 221 | flow_ori_patches_list.append(flow_ori_patch) 222 | 223 | flow_ori_test = reshape2image_tesnsor(flow_ori_patches_list, self.real_patches_h_list, 224 | self.real_patches_w_list, self.patch_num_heigt, 225 | self.patch_num_width) 226 | 227 | bit_string_mvd_test, flow_pred_test, flow_diff_test, flow_diff_rec_0_test, flow_diff_rec_test, flow_rec_0_test, flow_rec_test, losses_mvd_test = self.sess.run( 228 | [self.bit_string_mvd, 229 | self.flow_pred, 230 | self.flow_diff, self.flow_diff_rec_0, 231 | self.flow_diff_rec, self.flow_rec_0, self.flow_rec, self._losses_mvd], 232 | feed_dict={self.flow_ori_input: flow_ori_test, 233 | self.flows_pre_rec: flows_pre_rec, 234 | self.images_pre_rec_multiframes: images_prev_rec}) 235 | images_cur_pred_test = self.sess.run( 236 | self.images_cur_pred, 237 | feed_dict={self.images_pre_rec_multiframes: images_prev_rec, 238 | self.flow_rec_input: flow_rec_test, 239 | self.flows_pre_rec: flows_pre_rec}) 240 | bit_string_resi_test, bit_string_resi_dev_test, flow_3_pre_rec_test, flow_2_pre_rec_test, flow_1_pre_rec_test, images_cur_resi_test, 
images_cur_resi_rec_test, images_cur_rec_test, losses_resi_test, PSNR_list_test, MSSSIM_list_test = self.sess.run( 241 | [self.bit_string_resi, self.bit_string_resi_dev, self.flow_3_pre_rec, self.flow_2_pre_rec, 242 | self.flow_1_pre_rec, 243 | self.images_cur_resi, 244 | self.images_cur_resi_rec, self.images_cur_rec, self._losses_resi, 245 | self._PSNR_list, self._MSSSIM_list], 246 | feed_dict={self.images_pre_rec_multiframes: images_prev_rec, 247 | self.flow_rec_input: flow_rec_test, 248 | self.images_cur_ori: images_cur_ori, 249 | self.images_cur_pred_input: images_cur_pred_test, 250 | self.flows_pre_rec: flows_pre_rec}) 251 | mvd_bpp_test = losses_mvd_test[0] 252 | 253 | resi_bpp_test = losses_resi_test[1] 254 | resi_dev_bpp_test = losses_resi_test[2] 255 | mvd_bits_info = int(mvd_bpp_test * width * height + 0.5) 256 | resi_bits_info = int(resi_bpp_test * width * height + 0.5) 257 | resi_dev_bits_info = int(resi_dev_bpp_test * width * height + 0.5) 258 | cur_bits = mvd_bits_info + resi_bits_info + resi_dev_bits_info 259 | bits_list.append(cur_bits) 260 | RGB_PSNR_list.append(PSNR_list_test[0]) 261 | RGB_MSSSIM_list.append(MSSSIM_list_test[0]) 262 | if True: 263 | print( 264 | "cur_idx[%2d],time:[%4.4f],mvd_bits_info:[%d],resi_bits_info:[%d],resi_dev_bits_info:[%d],total_bits:[%d],PSNR_RGB:[%.4f],MSSSIM_RGB:[%.5f]" 265 | % (cur_indx, time.time() - start_time, 266 | mvd_bits_info, resi_bits_info, resi_dev_bits_info, cur_bits, 267 | PSNR_list_test[0], 268 | MSSSIM_list_test[0])) 269 | if self.args.visualize: 270 | All_flows_to_vis = [] 271 | All_flows_to_vis.append(np.squeeze(flow_3_pre_rec_test[0] * 20)) 272 | All_flows_to_vis.append(np.squeeze(flow_2_pre_rec_test[0] * 20)) 273 | All_flows_to_vis.append(np.squeeze(flow_1_pre_rec_test[0] * 20)) 274 | All_flows_to_vis.append(np.squeeze(flow_ori_test[0] * 20)) 275 | All_flows_to_vis.append(np.squeeze(flow_pred_test[0] * 20)) 276 | All_flows_to_vis.append(np.squeeze(flow_diff_test[0] * 20)) 277 | All_flows_to_vis.append(np.squeeze(flow_diff_rec_0_test[0] * 20)) 278 | All_flows_to_vis.append(np.squeeze(flow_diff_rec_test[0] * 20)) 279 | All_flows_to_vis.append(np.squeeze(flow_rec_0_test[0] * 20)) 280 | All_flows_to_vis.append(np.squeeze(flow_rec_test[0] * 20)) 281 | 282 | All_RGB_images_to_vis = [] 283 | All_Gray_images_to_vis = [] 284 | 285 | RGB_frames = np.clip(images_prev_rec[:, :, :, :, 0] * 255, 0, 255).astype(np.uint8) 286 | All_RGB_images_to_vis.append(np.squeeze(RGB_frames[0])) 287 | RGB_frames = np.clip(images_prev_rec[:, :, :, :, 1] * 255, 0, 255).astype(np.uint8) 288 | All_RGB_images_to_vis.append(np.squeeze(RGB_frames[0])) 289 | RGB_frames = np.clip(images_prev_rec[:, :, :, :, 2] * 255, 0, 255).astype(np.uint8) 290 | All_RGB_images_to_vis.append(np.squeeze(RGB_frames[0])) 291 | RGB_frames = np.clip(images_prev_rec[:, :, :, :, 3] * 255, 0, 255).astype(np.uint8) 292 | All_RGB_images_to_vis.append(np.squeeze(RGB_frames[0])) 293 | 294 | RGB_frames = np.clip(images_cur_ori * 255, 0, 255).astype(np.uint8) 295 | cur_ori_Gray_frames = cv2.cvtColor(np.squeeze(RGB_frames[0]), cv2.COLOR_RGB2GRAY) 296 | All_RGB_images_to_vis.append(np.squeeze(RGB_frames[0])) 297 | 298 | RGB_frames = np.clip(images_cur_pred_test * 255, 0, 255).astype(np.uint8) 299 | cur_pred_Gray_frames = cv2.cvtColor(np.squeeze(RGB_frames[0]), cv2.COLOR_RGB2GRAY) 300 | All_RGB_images_to_vis.append(np.squeeze(RGB_frames[0])) 301 | 302 | RGB_frames = np.clip(images_cur_rec_test * 255, 0, 255).astype(np.uint8) 303 | cur_rec_Gray_frames = cv2.cvtColor(np.squeeze(RGB_frames[0]), 
cv2.COLOR_RGB2GRAY) 304 | All_RGB_images_to_vis.append(np.squeeze(RGB_frames[0])) 305 | All_Gray_images_to_vis.append( 306 | np.squeeze(np.clip((cur_ori_Gray_frames.astype(np.int64) - cur_pred_Gray_frames.astype( 307 | np.int64)) * 5.0 + 128 + 0.5, 308 | 0, 255).astype(np.uint8))) 309 | All_Gray_images_to_vis.append( 310 | np.squeeze(np.clip((cur_rec_Gray_frames.astype(np.int64) - cur_pred_Gray_frames.astype( 311 | np.int64)) * 5.0 + 128 + 0.5, 312 | 0, 255).astype(np.uint8))) 313 | 314 | vis_flow_image_final([], All_flows_to_vis, All_RGB_images_to_vis, All_Gray_images_to_vis, 315 | filename=os.path.join(self.args.exp_data_dir, 'figure', 316 | key_class + '_' + 317 | value_class['sequence_name'][ 318 | seq_idx] + '_curidx' + str( 319 | cur_indx) + '.png')) 320 | 321 | images_prev_rec[:, :, :, :, 0:3] = images_prev_rec[:, :, :, :, 1:4] 322 | images_prev_rec[:, :, :, :, 3] = np.clip(images_cur_rec_test * 255 + 0.5, 0, 255).astype( 323 | np.uint8) / 255.0 324 | flows_pre_rec[:, :, :, :, 0:2] = flows_pre_rec[:, :, :, :, 1:3] 325 | flows_pre_rec[:, :, :, :, 2] = flow_rec_test 326 | 327 | Bpp_avg = np.mean(bits_list) / float(width * height) 328 | RGB_PSNR_mean = np.mean(RGB_PSNR_list) 329 | MSSSIM_avg = np.mean(RGB_MSSSIM_list) 330 | print("Summary: Bpp: [%.4f],PSNR_RGB: [%.4f],MSSSIM_RGB: [%.5f]" 331 | % (Bpp_avg, RGB_PSNR_mean, MSSSIM_avg)) 332 | 333 | 334 | if __name__ == '__main__': 335 | parser = argparse.ArgumentParser() 336 | parser.add_argument('-tsd', '--test_seq_dir', type=str, default='/testSequence', 337 | help='Directory containing test sequences') 338 | parser.add_argument('--test_class', type=str, default='ClassC', 339 | help='Test class to evaluate: ClassB, ClassC, ClassD, ClassE or ClassUVG') 340 | parser.add_argument('-edd', '--exp_data_dir', type=str, required=True, 341 | help='Directory containing experiment data') 342 | 343 | parser.add_argument("--batch_size", type=int, default=1, 344 | help="Batch size.") 345 | 346 | parser.add_argument('--num_levels', type=int, default=4, 347 | help='# of levels for feature extraction [4]') 348 | parser.add_argument('--warp_type', default='bilinear', choices=['bilinear', 'nearest'], 349 | help='Warping protocol, [bilinear] or nearest') 350 | parser.add_argument('--use-dc', dest='use_dc', action='store_true', 351 | help='Enable dense connection in optical flow estimator, [disabled] as default') 352 | parser.add_argument('--no-dc', dest='use_dc', action='store_false', 353 | help='Disable dense connection in optical flow estimator, [disabled] as default') 354 | parser.set_defaults(use_dc=False) 355 | parser.add_argument('--output_level', type=int, default=3, 356 | help='Final output level for estimated flow [3]') 357 | 358 | parser.add_argument('-v', '--visualize', dest='visualize', action='store_true', 359 | help='Enable estimated flow visualization, [enabled] as default') 360 | parser.add_argument('--no-visualize', dest='visualize', action='store_false', 361 | help='Disable estimated flow visualization, [enabled] as default') 362 | parser.set_defaults(visualize=True) 363 | parser.add_argument('-r', '--resume', type=str, default=None, 364 | help='Learned parameter checkpoint file [None]') 365 | parser.add_argument('-rME', '--resumeMEnet', type=str, default=None, 366 | help='Learned parameter checkpoint file for ME-Net [None]') 367 | parser.add_argument('-rMC', '--resumeMCnet', type=str, default=None, 368 | help='Learned parameter checkpoint file for MC-Net [None]') 369 | 370 | parser.add_argument( 371 | "--command", choices=["train", "compress", "decompress"], 372 | help="What
to do: 'train' loads training data and trains (or continues " 373 | "to train) a new model. 'compress' reads an image file (lossless " 374 | "PNG format) and writes a compressed binary file. 'decompress' " 375 | "reads a binary file and reconstructs the image (in PNG format). " 376 | "Input and output filenames need to be provided for the latter " 377 | "two options.") 378 | parser.add_argument( 379 | "--mvd_N_filters", type=int, default=128, 380 | help="Number of filters per layer.") 381 | parser.add_argument( 382 | "--mvd_M_filters", type=int, default=192, 383 | help="Number of filters per layer.") 384 | parser.add_argument( 385 | "--resi_N_filters", type=int, default=128, 386 | help="Number of filters per layer.") 387 | parser.add_argument( 388 | "--resi_M_filters", type=int, default=192, 389 | help="Number of filters per layer.") 390 | parser.add_argument( 391 | "--lambda", type=int, default=16, dest="lmbda", 392 | help="Lambda for rate-distortion tradeoff.") 393 | 394 | args = parser.parse_args() 395 | for key, item in vars(args).items(): 396 | print(f'{key} : {item}') 397 | 398 | os.environ['CUDA_VISIBLE_DEVICES'] = input('GPU id to use (-1 for CPU): ') 399 | 400 | tester = Tester(args) 401 | tester.test() -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import json 2 | import sys 3 | import shutil 4 | from collections import OrderedDict 5 | from datetime import datetime 6 | from pathlib import Path 7 | import numpy as np 8 | import cv2 9 | import random 10 | import os 11 | import tensorflow as tf 12 | from yuv_import import * 13 | from msssim import MultiScaleSSIM as msssim_ 14 | 15 | 16 | def build_flows_pyramid(flows, num_levels): 17 | flows_pyramid = [] 18 | flow_scales = [2.5, 5.0, 10., 20.]
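    # NOTE (annotation added for clarity; not in the original source): flows
    # reach this function already divided by 20.0 (see the ME-Net output in
    # test.py), and flow_scales undoes that normalization per pyramid level.
    # Level l is downsampled by 2**(num_levels - 1 - l), so with num_levels=4
    # a 448x832 flow becomes 56x104 at l=0 and is scaled by 2.5 (= 20/8),
    # while at l=3 it keeps full resolution and is scaled by 20.0. The scale
    # halves with each 2x spatial reduction, which keeps motion expressed in
    # pixel units at each level's own resolution.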
19 | for l in range(num_levels): 20 | # Downsampling the scaled ground truth flow 21 | _, h, w, _ = tf.unstack(tf.shape(flows)) 22 | downscale = 2 ** (num_levels - 1 - l) 23 | if l == num_levels - 1: 24 | flow_down = flows * flow_scales[l] 25 | else: 26 | flow_down = tf.image.resize_nearest_neighbor(flows, 27 | (tf.floordiv(h, downscale), tf.floordiv(w, downscale))) * \ 28 | flow_scales[l] 29 | flows_pyramid.append(flow_down) 30 | 31 | return flows_pyramid 32 | 33 | def get_patches(image_height, image_width): 34 | real_patch_heigt = 512 35 | real_patch_width = 512 36 | overlap = 64 37 | patch_heigt = real_patch_heigt + overlap 38 | patch_width = real_patch_width + overlap 39 | patches_num_height = image_height // real_patch_heigt 40 | patches_num_width = image_width // real_patch_width 41 | if image_height % real_patch_heigt > 0: 42 | patches_num_height += 1 43 | if image_width % real_patch_width > 0: 44 | patches_num_width += 1 45 | patches_h_list = [] 46 | patches_w_list = [] 47 | real_patches_h_list = [] 48 | real_patches_w_list = [] 49 | for patch_idx_h in range(patches_num_height): 50 | for patch_idx_w in range(patches_num_width): 51 | if patch_idx_h == 0: 52 | patches_h_start = 0 53 | real_patches_h_start = 0 54 | real_patches_h_end = real_patches_h_start + real_patch_heigt 55 | elif patch_idx_h == patches_num_height - 1: 56 | patches_h_start = image_height - patch_heigt 57 | real_patches_h_start = patch_heigt-(image_height - (patches_num_height-1)*real_patch_heigt) 58 | real_patches_h_end = patch_heigt 59 | else: 60 | patches_h_start = patch_idx_h * real_patch_heigt - overlap // 2 61 | real_patches_h_start = overlap // 2 62 | real_patches_h_end = real_patches_h_start + real_patch_heigt 63 | patches_h_end = patches_h_start + patch_heigt 64 | 65 | if patch_idx_w == 0: 66 | patches_w_start = 0 67 | real_patches_w_start = 0 68 | real_patches_w_end = real_patches_w_start + real_patch_width 69 | elif patch_idx_w == patches_num_width - 1: 70 | patches_w_start = image_width - patch_width 71 | real_patches_w_start = patch_width-(image_width - (patches_num_width-1)*real_patch_width) 72 | real_patches_w_end = patch_width 73 | else: 74 | patches_w_start = patch_idx_w * real_patch_width - overlap // 2 75 | real_patches_w_start = overlap // 2 76 | real_patches_w_end = real_patches_w_start + real_patch_width 77 | patches_w_end = patches_w_start + patch_width 78 | 79 | patches_h_list.append([patches_h_start, patches_h_end]) 80 | patches_w_list.append([patches_w_start, patches_w_end]) 81 | real_patches_h_list.append([real_patches_h_start, real_patches_h_end]) 82 | real_patches_w_list.append([real_patches_w_start, real_patches_w_end]) 83 | return patches_h_list,patches_w_list,real_patches_h_list,real_patches_w_list, patch_heigt, patch_width, patches_num_height, patches_num_width 84 | 85 | def reshape2patches_tesnsor(tensor, patches_h_list, patches_w_list): 86 | patches_tensor_list = [] 87 | for (patches_h, patches_w) in zip(patches_h_list, patches_w_list): 88 | patches_h_start, patches_h_end = patches_h 89 | patches_w_start, patches_w_end = patches_w 90 | patches_tensor_list.append(tensor[:,patches_h_start:patches_h_end,patches_w_start:patches_w_end,:]) 91 | #patches_tensor = tf.stack(patches_tensor_list, axis=0) 92 | #patches_tensor = tf.squeeze(patches_tensor,axis=1) 93 | return patches_tensor_list 94 | 95 | def reshape2image_tesnsor(patches_tensor_list, real_patches_h_list, real_patches_w_list, patches_num_height, patches_num_width): 96 | idx = 0 97 | patch_h_list = [] 98 | for patch_idx_h in 
range(patches_num_height): 99 | patch_w_list = [] 100 | for patch_idx_w in range(patches_num_width): 101 | real_patches_h_start, real_patches_h_end = real_patches_h_list[idx] 102 | real_patches_w_start, real_patches_w_end = real_patches_w_list[idx] 103 | patch = patches_tensor_list[idx][:, real_patches_h_start:real_patches_h_end, real_patches_w_start:real_patches_w_end, :] 104 | patch_w_list.append(patch) 105 | idx += 1 106 | patch_h = np.concatenate(tuple(patch_w_list),axis=2) 107 | patch_h_list.append(patch_h) 108 | tensor = np.concatenate(tuple(patch_h_list), axis=1) 109 | return tensor 110 | 111 | def YUV2RGB420_custom(Y_frames,U_frames,V_frames): 112 | shape = np.shape(Y_frames) 113 | n = shape[0] 114 | h = shape[1] 115 | w = shape[2] 116 | RGB_frames = np.zeros([n, h, w, 3], np.uint8) 117 | yuv_frame=np.zeros([h*3//2,w],np.uint8) 118 | 119 | for n_i in range(n): 120 | for f_i in range(1): 121 | Y = Y_frames[n_i, :, :, 0] 122 | U = U_frames[n_i, :, :, 0] 123 | V = V_frames[n_i, :, :, 0] 124 | yuv_frame[:h,:]=Y 125 | yuv_frame[h:5 * h // 4, :] = np.reshape(U, [h // 4, w]) 126 | yuv_frame[5 * h // 4 :, :] = np.reshape(V, [h // 4, w]) 127 | rgb = cv2.cvtColor(yuv_frame, cv2.COLOR_YUV2RGB_I420) 128 | RGB_frames[n_i,:,:,:]=rgb 129 | return RGB_frames 130 | 131 | def log10(x): 132 | numerator = tf.log(x) 133 | denominator = tf.log(tf.constant(10, dtype=numerator.dtype)) 134 | return numerator / denominator 135 | 136 | def get_test_class_dic(class_name): 137 | classes_dict = { 138 | 'ClassD': 139 | {'sequence_name': ['RaceHorses', 'BQSquare', 'BlowingBubbles', 'BasketballPass'], 140 | 'ori_yuv': ['RaceHorses_384x192_30.yuv', 'BQSquare_384x192_60.yuv', 141 | 'BlowingBubbles_384x192_50.yuv', 'BasketballPass_384x192_50.yuv'], 142 | 'resolution': '384x192', 143 | 'frameRate': [30, 60, 50, 50] 144 | }, 145 | 'ClassC': 146 | {'sequence_name': ['RaceHorsesC', 'BQMall', 'PartyScene', 'BasketballDrill'], 147 | 'ori_yuv': ['RaceHorses_832x448_30.yuv', 'BQMall_832x448_60.yuv', 'PartyScene_832x448_50.yuv', 148 | 'BasketballDrill_832x448_50.yuv'], 149 | 'resolution': '832x448', 150 | 'frameRate': [30, 60, 50, 50] 151 | }, 152 | 'ClassE': 153 | {'sequence_name': ['FourPeople', 'Johnny', 'KristenAndSara'], 154 | 'ori_yuv': ['FourPeople_1280x704_60.yuv', 'Johnny_1280x704_60.yuv', 155 | 'KristenAndSara_1280x704_60.yuv'], 156 | 'resolution': '1280x704', 157 | 'frameRate': [60, 60, 60] 158 | }, 159 | 'ClassB': 160 | {'sequence_name': ['Kimono', 'ParkScene', 'Cactus', 'BasketballDrive', 'BQTerrace'], 161 | 'ori_yuv': ['Kimono_1920x1024_24.yuv', 'ParkScene_1920x1024_24.yuv', 'Cactus_1920x1024_50.yuv', 162 | 'BasketballDrive_1920x1024_50.yuv', 'BQTerrace_1920x1024_60.yuv'], 163 | 'resolution': '1920x1024', 164 | 'frameRate': [24, 24, 50, 50, 60] 165 | }, 166 | 'ClassUVG': 167 | {'sequence_name': ['Beauty', 'Bosphorus', 'HoneyBee', 'Jockey', 168 | 'ReadySteadyGo', 'ShakeNDry', 'YachtRide'], 169 | 'ori_yuv': ['Beauty_1920x1024_120fps.yuv', 'Bosphorus_1920x1024_120fps.yuv', 170 | 'HoneyBee_1920x1024_120fps.yuv', 'Jockey_1920x1024_120fps.yuv', 171 | 'ReadySteadyGo_1920x1024_120fps.yuv', 'ShakeNDry_1920x1024_120fps.yuv', 172 | 'YachtRide_1920x1024_120fps.yuv'], 173 | 'resolution': '1920x1024', 174 | 'frameRate': [120, 120, 120, 120, 120, 120, 120] 175 | } 176 | } 177 | 178 | return {class_name: classes_dict[class_name]} 179 | 180 | def rgb_psnr_(img1, img2): 181 | mse = np.mean((img1.astype(np.float32) - img2.astype(np.float32)) ** 2) 182 | if mse == 0: 183 | return 100 184 | PIXEL_MAX = 255.0 185 | return 10 * 
math.log10(PIXEL_MAX**2 / mse) 186 | def evaluate(img0, img1): 187 | img0 = img0.astype('float32') 188 | img1 = img1.astype('float32') 189 | rgb_psnr = rgb_psnr_(img0, img1) 190 | r_msssim = msssim_(img0[:, :, 0], img1[:, :, 0]) 191 | g_msssim = msssim_(img0[:, :, 1], img1[:, :, 1]) 192 | b_msssim = msssim_(img0[:, :, 2], img1[:, :, 2]) 193 | rgb_msssim = (r_msssim + g_msssim + b_msssim)/3 194 | return rgb_psnr, rgb_msssim -------------------------------------------------------------------------------- /yuv_import.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | import cv2 4 | import math 5 | # import random 6 | # import ipdb 7 | MSB2Num = [0,256,512,768] # for 10-bit YUV, the high 2 bits can only be 00, 01, 10 or 11; pixel_value = LSB_value + MSB2Num[MSB_value] 8 | ## extract selected frames (given by Org_frm_list) from a sequence 9 | def yuv420_import(filename,height,width,Org_frm_list,extract_frm,flag,isRef,isList0,deltaPOC,isLR): 10 | # print filename, height, width 11 | if flag: 12 | frm_size = int(float(height*width*3)/float(2*2)) ## for 10-bit YUV, each sample occupies 2 bytes 13 | else: 14 | frm_size = int(float(height*width*3)/2) 15 | if isLR: 16 | frm_size=frm_size*4 17 | if flag: 18 | row_size = width * 2 19 | else: 20 | row_size = width 21 | Luma = [] 22 | U=[] 23 | V=[] 24 | # Org_frm_list = range(1,numfrm) 25 | # random.shuffle(Org_frm_list) 26 | with open(filename,'rb') as fd: 27 | for extract_index in range(extract_frm): 28 | if isRef: 29 | if isList0: 30 | current_frm = Org_frm_list[extract_index]-deltaPOC 31 | else: 32 | current_frm = Org_frm_list[extract_index]+deltaPOC 33 | else: 34 | current_frm = Org_frm_list[extract_index] 35 | fd.seek(frm_size*current_frm,0) 36 | # ipdb.set_trace() 37 | if flag: 38 | Yt = np.zeros((height,width),np.uint16,'C') 39 | for m in range(height): 40 | for n in range(width): 41 | symbol = fd.read(2) 42 | LSB = symbol[0]  # indexing a bytes object yields an int in Python 3; ord() would raise TypeError here 43 | MSB = symbol[1] 44 | Pixel_Value = LSB+MSB2Num[MSB] 45 | Yt[m,n]=Pixel_Value 46 | if isLR: 47 | fd.seek(row_size, 1) 48 | Luma.append(Yt) 49 | del Yt 50 | else: 51 | Yt = np.zeros((height,width),np.uint8,'C') 52 | for m in range(height): 53 | for n in range(width): 54 | symbol = fd.read(1) 55 | Pixel_Value = ord(symbol) 56 | Yt[m,n]=Pixel_Value 57 | if isLR: 58 | fd.seek(row_size, 1) 59 | Luma.append(Yt) 60 | del Yt 61 | Ut = np.zeros((height//2, width//2), np.uint8, 'C') 62 | for m in range(height//2): 63 | for n in range(width//2): 64 | symbol = fd.read(1) 65 | Pixel_Value = ord(symbol) 66 | Ut[m, n] = Pixel_Value 67 | if isLR: 68 | fd.seek(row_size, 1) 69 | U.append(Ut) 70 | del Ut 71 | Vt = np.zeros((height // 2, width // 2), np.uint8, 'C') 72 | for m in range(height // 2): 73 | for n in range(width // 2): 74 | symbol = fd.read(1) 75 | Pixel_Value = ord(symbol) 76 | Vt[m, n] = Pixel_Value 77 | if isLR: 78 | fd.seek(row_size, 1) 79 | V.append(Vt) 80 | del Vt 81 | return Luma,U,V 82 | 83 | def psnr(target, ref): 84 | target_data = np.asarray(target, 'f') 85 | ref_data = np.asarray(ref, 'f') 86 | diff = ref_data - target_data 87 | diff = diff.flatten('C') 88 | mse = np.mean(diff ** 2.) 89 | return 10 * math.log10(255 ** 2.
/ mse) 90 | 91 | def YUV2RGB420_custom(Y_frames,U_frames,V_frames): 92 | shape = np.shape(Y_frames) 93 | n = shape[0] 94 | h = shape[1] 95 | w = shape[2] 96 | RGB_frames = np.zeros([n, h, w, 3], np.uint8) 97 | yuv_frame=np.zeros([h*3//2,w],np.uint8) 98 | 99 | for n_i in range(n): 100 | for f_i in range(1): 101 | Y = Y_frames[n_i, :, :, 0] 102 | U = U_frames[n_i, :, :, 0] 103 | V = V_frames[n_i, :, :, 0] 104 | yuv_frame[:h,:]=Y 105 | yuv_frame[h:5 * h // 4, :] = np.reshape(U, [h // 4, w]) 106 | yuv_frame[5 * h // 4 :, :] = np.reshape(V, [h // 4, w]) 107 | rgb = cv2.cvtColor(yuv_frame, cv2.COLOR_YUV2RGB_I420) 108 | RGB_frames[n_i,:,:,:]=rgb 109 | return RGB_frames -------------------------------------------------------------------------------- /zlib1.dll: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JianpingLin/M-LVC_CVPR2020/d2c20635c61aa786115622694d03d3ba0e4f1d15/zlib1.dll --------------------------------------------------------------------------------
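Usage note (a sketch inferred from the argparse definitions in test.py above, not from the project's documentation; the checkpoint path below is hypothetical):

    python test.py --test_class ClassC -tsd /testSequence -edd ./exp_data_dir \
        -r ./checkpoints/model_lambda16.ckpt --lambda 16

--lambda selects the rate-distortion point: qp_dic in Tester.test() maps the supported values {16, 24, 40, 64} to BPG QPs {21, 23, 25, 27} for I-frame coding, and exp_data_dir must contain the I_frames_enc and figure subdirectories that test.py writes into. After parsing the arguments, the script prompts for a GPU id (-1 for CPU) before building the graph.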