├── IdentityLUT33.txt ├── IdentityLUT64.txt ├── LICENSE ├── README.md ├── average_psnr_ssim.m ├── datasets.py ├── demo_eval.py ├── demo_images ├── XYZ │ └── a1629.png └── sRGB │ └── a1629.jpg ├── figures └── framework2.png ├── image_adaptive_lut_evaluation.py ├── image_adaptive_lut_train_paired.py ├── image_adaptive_lut_train_unpaired.py ├── local_tone_mapping ├── a1509.jpg ├── wlsFilter.m └── wlsTonemap.m ├── models.py ├── models_x.py ├── pretrained_models ├── XYZ │ ├── LUTs.pth │ ├── LUTs_unpaired.pth │ ├── classifier.pth │ └── classifier_unpaired.pth └── sRGB │ ├── LUTs.pth │ ├── LUTs_unpaired.pth │ ├── classifier.pth │ └── classifier_unpaired.pth ├── requirements ├── ssim.m ├── torchvision_x_functional.py ├── trilinear_c ├── build.py ├── make.sh └── src │ ├── trilinear.c │ ├── trilinear.h │ ├── trilinear_cuda.c │ ├── trilinear_cuda.h │ ├── trilinear_kernel.cu │ ├── trilinear_kernel.cu.o │ └── trilinear_kernel.h ├── trilinear_cpp ├── setup.py ├── setup.sh └── src │ ├── trilinear.cpp │ ├── trilinear.h │ ├── trilinear_cuda.cpp │ ├── trilinear_cuda.h │ ├── trilinear_kernel.cu │ └── trilinear_kernel.h ├── utils ├── generate_identity_3DLUT.py └── visualize_lut.py └── visualization_lut ├── learned_LUT_234_1.txt ├── learned_LUT_234_2.txt ├── learned_LUT_234_3.txt ├── save_trained_luts.py └── visualize_lut.m /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. 
For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Image-Adaptive-3DLUT 2 | Learning Image-adaptive 3D Lookup Tables for High Performance Photo Enhancement in Real-time 3 | 4 | ## Downloads 5 | ### [Paper](https://www4.comp.polyu.edu.hk/~cslzhang/paper/PAMI_LUT.pdf), [Supplementary](https://www4.comp.polyu.edu.hk/~cslzhang/paper/Supplement_LUT.pdf), Datasets([[GoogleDrive](https://drive.google.com/drive/folders/1Y1Rv3uGiJkP6CIrNTSKxPn1p-WFAc48a?usp=sharing)],[[onedrive](https://connectpolyu-my.sharepoint.com/:f:/g/personal/16901447r_connect_polyu_hk/EqNGuQUKZe9Cv3fPG08OmGEBbHMUXey2aU03E21dFZwJyg?e=QNCMMZ)],[[baiduyun](https://pan.baidu.com/s/1CsQRFsEPZCSjkT3Z1X_B1w):5fyk]) 6 | The full datasets used in the paper are over 300 GB. Here I only provide the FiveK dataset resized to 480p resolution (including 8-bit sRGB and 16-bit XYZ inputs, and 8-bit sRGB targets). I also provide 10 full-resolution images for speed testing. To obtain all of the full-resolution images, it is recommended to convert them from the original [FiveK](https://data.csail.mit.edu/graphics/fivek/) dataset. 7 | 8 | A model trained at 480p resolution can be directly applied to images of 4K (or higher) resolution without a performance drop. This significantly speeds up training by avoiding the loading of very heavy high-resolution images. 9 | 10 | ## Abstract 11 | Recent years have witnessed the increasing popularity of learning-based methods for enhancing the color and tone of photos. However, many existing photo enhancement methods either deliver unsatisfactory results or consume excessive computational and memory resources, hindering their application to high-resolution images (usually with more than 12 megapixels) in practice. In this paper, we learn image-adaptive 3-dimensional lookup tables (3D LUTs) to achieve fast and robust photo enhancement. 3D LUTs are widely used for manipulating the color and tone of photos, but they are usually manually tuned and fixed in the camera imaging pipeline or photo editing tools. We, for the first time to the best of our knowledge, propose to learn 3D LUTs from annotated data using pairwise or unpaired learning. More importantly, our learned 3D LUT is image-adaptive for flexible photo enhancement. We learn multiple basis 3D LUTs and a small convolutional neural network (CNN) simultaneously in an end-to-end manner. The small CNN works on a down-sampled version of the input image to predict content-dependent weights that fuse the multiple basis 3D LUTs into an image-adaptive one, which is employed to transform the color and tone of source images efficiently. Our model contains less than **600K** parameters and takes **less than 2 ms** to process an image of 4K resolution using one Titan RTX GPU. While being highly efficient, our model also outperforms state-of-the-art photo enhancement methods by a large margin in terms of PSNR, SSIM and a color difference metric on two publicly available benchmark datasets. 12 | 13 | ## Framework 14 | ![framework](figures/framework2.png) 15 | 16 | ## Usage 17 | 18 | ### Useful issues 19 | Replace the trilinear interpolation with torch.nn.functional.grid_sample [https://github.com/HuiZeng/Image-Adaptive-3DLUT/issues/14].
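As a rough illustration of that issue, the sketch below applies a 3D LUT with `torch.nn.functional.grid_sample` instead of the custom trilinear CUDA extension. This is a minimal sketch, not the repository's implementation: it assumes the LUT is stored as a `(3, D, D, D)` tensor with values in `[0, 1]`, and the mapping between the RGB channels and the LUT axes may need to be flipped depending on how the LUT tensor was saved.

```
import torch
import torch.nn.functional as F

def apply_lut_grid_sample(lut, img):
    """Trilinearly sample a 3D LUT with grid_sample (sketch).

    lut: (3, D, D, D) LUT tensor, values in [0, 1]
    img: (B, 3, H, W) image tensor, values in [0, 1]
    """
    b = img.size(0)
    # Use each pixel's color as a sampling coordinate, rescaled from [0, 1] to [-1, 1].
    # grid_sample orders the last grid dimension as (x, y, z); swap channels here if
    # the output colors come out permuted for your LUT layout.
    grid = img.permute(0, 2, 3, 1).unsqueeze(1) * 2.0 - 1.0         # (B, 1, H, W, 3)
    lut = lut.unsqueeze(0).expand(b, -1, -1, -1, -1)                # (B, 3, D, D, D)
    out = F.grid_sample(lut, grid, mode="bilinear",
                        padding_mode="border", align_corners=True)  # (B, 3, 1, H, W)
    return out.squeeze(2)                                           # (B, 3, H, W)

if __name__ == "__main__":
    lut = torch.rand(3, 33, 33, 33)       # stand-in for a learned 33x33x33 LUT
    img = torch.rand(1, 3, 480, 720)
    print(apply_lut_grid_sample(lut, img).shape)  # torch.Size([1, 3, 480, 720])
```

With `align_corners=True`, a color value of 0 samples the first LUT entry and 1 the last, matching the usual 3D LUT convention; `mode="bilinear"` on a 5D input performs trilinear interpolation.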
20 | 21 | ### Requirements 22 | Python3, requirements.txt 23 | 24 | ### Build 25 | By default, we use pytorch 0.4.1: 26 | 27 | cd trilinear_c 28 | sh make.sh 29 | 30 | For pytorch 1.x: 31 | 32 | cd trilinear_cpp 33 | sh setup.sh 34 | 35 | Please also replace the following lines: 36 | ``` 37 | # in image_adaptive_lut_train_paired.py, image_adaptive_lut_evaluation.py, demo_eval.py, and image_adaptive_lut_train_unpaired.py 38 | from models import * --> from models_x import * 39 | # in demo_eval.py 40 | result = trilinear_(LUT, img) --> _, result = trilinear_(LUT, img) 41 | # in image_adaptive_lut_train_paired.py and image_adaptive_lut_evaluation.py 42 | combine_A = trilinear_(LUT,img) --> _, combine_A = trilinear_(LUT,img) 43 | ``` 44 | 45 | ### Training 46 | #### paired training 47 | python3 image_adaptive_lut_train_paired.py 48 | #### unpaired training 49 | python3 image_adaptive_lut_train_unpaired.py 50 | 51 | ### Evaluation 52 | 1. use python to generate and save the test images: 53 | 54 | python3 image_adaptive_lut_evaluation.py 55 | 56 | speed can also be tested in above code. 57 | 58 | 2. use matlab to calculate the indexes used in our paper: 59 | 60 | average_psnr_ssim.m 61 | 62 | ### Demo 63 | 64 | python3 demo_eval.py 65 | 66 | ### Tools 67 | 1. You can generate identity 3DLUT with arbitrary dimension by using `utils/generate_identity_3DLUT.py` as follows: 68 | 69 | ``` 70 | # you can replace 33 with any number you want 71 | python3 utils/generate_identity_3DLUT.py -d 33 72 | ``` 73 | 74 | 2. You can visualize the learned 3D LUT either by using the matlab code in `visualization_lut` or using the python code `utils/visualize_lut.py` as follows: 75 | 76 | ``` 77 | python3 utils/visualize_lut.py path/to/your/lut 78 | # you can also modify the dimension of the lut as follows 79 | python3 utils/visualize_lut.py path/to/your/lut --lut_dim 64 80 | ``` 81 | 82 | ## Citation 83 | ``` 84 | @article{zeng2020lut, 85 | title={Learning Image-adaptive 3D Lookup Tables for High Performance Photo Enhancement in Real-time}, 86 | author={Zeng, Hui and Cai, Jianrui and Li, Lida and Cao, Zisheng and Zhang, Lei}, 87 | journal={IEEE Transactions on Pattern Analysis and Machine Intelligence}, 88 | volume={44}, 89 | number={04}, 90 | pages={2058--2073}, 91 | year={2022}, 92 | publisher={IEEE Computer Society} 93 | } 94 | 95 | @inproceedings{zhang2022clut, 96 | title={CLUT-Net: Learning Adaptively Compressed Representations of 3DLUTs for Lightweight Image Enhancement}, 97 | author={Zhang, Fengyi and Zeng, Hui and Zhang, Tianjun and Zhang, Lin}, 98 | booktitle={Proceedings of the 30th ACM International Conference on Multimedia}, 99 | pages={6493--6501}, 100 | year={2022} 101 | } 102 | ``` 103 | -------------------------------------------------------------------------------- /average_psnr_ssim.m: -------------------------------------------------------------------------------- 1 | test_path = 'images/LUTs/paired/fiveK_480p_3LUT_sm_1e-4_mn_10_sRGB_145'; 2 | gt_path = '../data/fiveK/expertC/JPG/480p'; 3 | path_list = dir(fullfile(test_path,'*.png')); 4 | img_num = length(path_list); 5 | %calculate psnr 6 | total_psnr = 0; 7 | total_ssim = 0; 8 | total_color = 0; 9 | if img_num > 0 10 | for j = 1:img_num 11 | image_name = path_list(j).name; 12 | input = imread(fullfile(test_path,image_name)); 13 | gt = imread(fullfile(gt_path,[image_name(1:end-3), 'jpg'])); 14 | 15 | psnr_val = psnr(im2double(input), im2double(gt)); 16 | total_psnr = total_psnr + psnr_val; 17 | 18 | ssim_val = ssim(input, gt); 19 | total_ssim = 
total_ssim + ssim_val; 20 | 21 | color = sqrt(sum((rgb2lab(gt) - rgb2lab(input)).^2,3)); 22 | color = mean(color(:)); 23 | total_color = total_color + color; 24 | fprintf('%d %f %f %f %s\n',j,psnr_val,ssim_val,color,fullfile(test_path,image_name)); 25 | end 26 | end 27 | qm_psnr = total_psnr / img_num; 28 | avg_ssim = total_ssim / img_num; 29 | avg_color = total_color / img_num; 30 | fprintf('The avgerage psnr is: %f', qm_psnr); 31 | fprintf('The avgerage ssim is: %f', avg_ssim); 32 | fprintf('The avgerage lab is: %f', avg_color); 33 | -------------------------------------------------------------------------------- /datasets.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import random 3 | import os 4 | import numpy as np 5 | import torch 6 | import cv2 7 | 8 | from torch.utils.data import Dataset 9 | from PIL import Image 10 | import torchvision.transforms as transforms 11 | import torchvision.transforms.functional as TF 12 | import torchvision_x_functional as TF_x 13 | 14 | 15 | class ImageDataset_sRGB(Dataset): 16 | def __init__(self, root, mode="train", unpaird_data="fiveK", combined=True): 17 | self.mode = mode 18 | self.unpaird_data = unpaird_data 19 | 20 | file = open(os.path.join(root,'train_input.txt'),'r') 21 | set1_input_files = sorted(file.readlines()) 22 | self.set1_input_files = list() 23 | self.set1_expert_files = list() 24 | for i in range(len(set1_input_files)): 25 | self.set1_input_files.append(os.path.join(root,"input","JPG/480p",set1_input_files[i][:-1] + ".jpg")) 26 | self.set1_expert_files.append(os.path.join(root,"expertC","JPG/480p",set1_input_files[i][:-1] + ".jpg")) 27 | 28 | file = open(os.path.join(root,'train_label.txt'),'r') 29 | set2_input_files = sorted(file.readlines()) 30 | self.set2_input_files = list() 31 | self.set2_expert_files = list() 32 | for i in range(len(set2_input_files)): 33 | self.set2_input_files.append(os.path.join(root,"input","JPG/480p",set2_input_files[i][:-1] + ".jpg")) 34 | self.set2_expert_files.append(os.path.join(root,"expertC","JPG/480p",set2_input_files[i][:-1] + ".jpg")) 35 | 36 | file = open(os.path.join(root,'test.txt'),'r') 37 | test_input_files = sorted(file.readlines()) 38 | self.test_input_files = list() 39 | self.test_expert_files = list() 40 | for i in range(len(test_input_files)): 41 | self.test_input_files.append(os.path.join(root,"input","JPG/480p",test_input_files[i][:-1] + ".jpg")) 42 | self.test_expert_files.append(os.path.join(root,"expertC","JPG/480p",test_input_files[i][:-1] + ".jpg")) 43 | 44 | if combined: 45 | self.set1_input_files = self.set1_input_files + self.set2_input_files 46 | self.set1_expert_files = self.set1_expert_files + self.set2_expert_files 47 | 48 | 49 | def __getitem__(self, index): 50 | 51 | if self.mode == "train": 52 | img_name = os.path.split(self.set1_input_files[index % len(self.set1_input_files)])[-1] 53 | img_input = Image.open(self.set1_input_files[index % len(self.set1_input_files)]) 54 | img_exptC = Image.open(self.set1_expert_files[index % len(self.set1_expert_files)]) 55 | 56 | elif self.mode == "test": 57 | img_name = os.path.split(self.test_input_files[index % len(self.test_input_files)])[-1] 58 | img_input = Image.open(self.test_input_files[index % len(self.test_input_files)]) 59 | img_exptC = Image.open(self.test_expert_files[index % len(self.test_expert_files)]) 60 | 61 | if self.mode == "train": 62 | 63 | ratio_H = np.random.uniform(0.6,1.0) 64 | ratio_W = np.random.uniform(0.6,1.0) 65 | W,H = img_input._size 66 | 
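# Random-crop augmentation (training only): keep 60-100% of each side and apply the
# same crop window to the input and the expert-retouched target so the pair stays aligned.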
crop_h = round(H*ratio_H) 67 | crop_w = round(W*ratio_W) 68 | i, j, h, w = transforms.RandomCrop.get_params(img_input, output_size=(crop_h, crop_w)) 69 | img_input = TF.crop(img_input, i, j, h, w) 70 | img_exptC = TF.crop(img_exptC, i, j, h, w) 71 | #img_input = TF.resized_crop(img_input, i, j, h, w, (320,320)) 72 | #img_exptC = TF.resized_crop(img_exptC, i, j, h, w, (320,320)) 73 | 74 | if np.random.random() > 0.5: 75 | img_input = TF.hflip(img_input) 76 | img_exptC = TF.hflip(img_exptC) 77 | 78 | a = np.random.uniform(0.8,1.2) 79 | img_input = TF.adjust_brightness(img_input,a) 80 | 81 | a = np.random.uniform(0.8,1.2) 82 | img_input = TF.adjust_saturation(img_input,a) 83 | 84 | img_input = TF.to_tensor(img_input) 85 | img_exptC = TF.to_tensor(img_exptC) 86 | 87 | return {"A_input": img_input, "A_exptC": img_exptC, "input_name": img_name} 88 | 89 | def __len__(self): 90 | if self.mode == "train": 91 | return len(self.set1_input_files) 92 | elif self.mode == "test": 93 | return len(self.test_input_files) 94 | 95 | 96 | class ImageDataset_XYZ(Dataset): 97 | def __init__(self, root, mode="train", unpaird_data="fiveK", combined=True): 98 | self.mode = mode 99 | 100 | file = open(os.path.join(root,'train_input.txt'),'r') 101 | set1_input_files = sorted(file.readlines()) 102 | self.set1_input_files = list() 103 | self.set1_expert_files = list() 104 | for i in range(len(set1_input_files)): 105 | self.set1_input_files.append(os.path.join(root,"input","PNG/480p_16bits_XYZ_WB",set1_input_files[i][:-1] + ".png")) 106 | self.set1_expert_files.append(os.path.join(root,"expertC","JPG/480p",set1_input_files[i][:-1] + ".jpg")) 107 | 108 | file = open(os.path.join(root,'train_label.txt'),'r') 109 | set2_input_files = sorted(file.readlines()) 110 | self.set2_input_files = list() 111 | self.set2_expert_files = list() 112 | for i in range(len(set2_input_files)): 113 | self.set2_input_files.append(os.path.join(root,"input","PNG/480p_16bits_XYZ_WB",set2_input_files[i][:-1] + ".png")) 114 | self.set2_expert_files.append(os.path.join(root,"expertC","JPG/480p",set2_input_files[i][:-1] + ".jpg")) 115 | 116 | file = open(os.path.join(root,'test.txt'),'r') 117 | test_input_files = sorted(file.readlines()) 118 | self.test_input_files = list() 119 | self.test_expert_files = list() 120 | for i in range(len(test_input_files)): 121 | self.test_input_files.append(os.path.join(root,"input","PNG/480p_16bits_XYZ_WB",test_input_files[i][:-1] + ".png")) 122 | self.test_expert_files.append(os.path.join(root,"expertC","JPG/480p",test_input_files[i][:-1] + ".jpg")) 123 | 124 | if combined: 125 | self.set1_input_files = self.set1_input_files + self.set2_input_files 126 | self.set1_expert_files = self.set1_expert_files + self.set2_expert_files 127 | 128 | 129 | def __getitem__(self, index): 130 | 131 | if self.mode == "train": 132 | img_name = os.path.split(self.set1_input_files[index % len(self.set1_input_files)])[-1] 133 | img_input = cv2.imread(self.set1_input_files[index % len(self.set1_input_files)],-1) 134 | img_exptC = Image.open(self.set1_expert_files[index % len(self.set1_expert_files)]) 135 | 136 | elif self.mode == "test": 137 | img_name = os.path.split(self.test_input_files[index % len(self.test_input_files)])[-1] 138 | img_input = cv2.imread(self.test_input_files[index % len(self.test_input_files)],-1) 139 | img_exptC = Image.open(self.test_expert_files[index % len(self.test_expert_files)]) 140 | 141 | img_input = np.array(img_input) 142 | #img_input = np.array(cv2.cvtColor(img_input,cv2.COLOR_BGR2RGB)) 143 | 144 | if 
self.mode == "train": 145 | 146 | ratio_H = np.random.uniform(0.6,1.0) 147 | ratio_W = np.random.uniform(0.6,1.0) 148 | W,H = img_exptC._size 149 | crop_h = round(H*ratio_H) 150 | crop_w = round(W*ratio_W) 151 | i, j, h, w = transforms.RandomCrop.get_params(img_exptC, output_size=(crop_h, crop_w)) 152 | img_input = TF_x.crop(img_input, i, j, h, w) 153 | img_exptC = TF.crop(img_exptC, i, j, h, w) 154 | 155 | if np.random.random() > 0.5: 156 | img_input = TF_x.hflip(img_input) 157 | img_exptC = TF.hflip(img_exptC) 158 | 159 | a = np.random.uniform(0.6,1.4) 160 | img_input = TF_x.adjust_brightness(img_input,a) 161 | 162 | img_input = TF_x.to_tensor(img_input) 163 | img_exptC = TF.to_tensor(img_exptC) 164 | 165 | return {"A_input": img_input, "A_exptC": img_exptC, "input_name": img_name} 166 | 167 | def __len__(self): 168 | if self.mode == "train": 169 | return len(self.set1_input_files) 170 | elif self.mode == "test": 171 | return len(self.test_input_files) 172 | 173 | class ImageDataset_sRGB_unpaired(Dataset): 174 | def __init__(self, root, mode="train", unpaird_data="fiveK"): 175 | self.mode = mode 176 | self.unpaird_data = unpaird_data 177 | 178 | file = open(os.path.join(root,'train_input.txt'),'r') 179 | set1_input_files = sorted(file.readlines()) 180 | self.set1_input_files = list() 181 | self.set1_expert_files = list() 182 | for i in range(len(set1_input_files)): 183 | self.set1_input_files.append(os.path.join(root,"input","JPG/480p",set1_input_files[i][:-1] + ".jpg")) 184 | self.set1_expert_files.append(os.path.join(root,"expertC","JPG/480p",set1_input_files[i][:-1] + ".jpg")) 185 | 186 | file = open(os.path.join(root,'train_label.txt'),'r') 187 | set2_input_files = sorted(file.readlines()) 188 | self.set2_input_files = list() 189 | self.set2_expert_files = list() 190 | for i in range(len(set2_input_files)): 191 | self.set2_input_files.append(os.path.join(root,"input","JPG/480p",set2_input_files[i][:-1] + ".jpg")) 192 | self.set2_expert_files.append(os.path.join(root,"expertC","JPG/480p",set2_input_files[i][:-1] + ".jpg")) 193 | 194 | file = open(os.path.join(root,'test.txt'),'r') 195 | test_input_files = sorted(file.readlines()) 196 | self.test_input_files = list() 197 | self.test_expert_files = list() 198 | for i in range(len(test_input_files)): 199 | self.test_input_files.append(os.path.join(root,"input","JPG/480p",test_input_files[i][:-1] + ".jpg")) 200 | self.test_expert_files.append(os.path.join(root,"expertC","JPG/480p",test_input_files[i][:-1] + ".jpg")) 201 | 202 | 203 | def __getitem__(self, index): 204 | 205 | if self.mode == "train": 206 | img_name = os.path.split(self.set1_input_files[index % len(self.set1_input_files)])[-1] 207 | img_input = Image.open(self.set1_input_files[index % len(self.set1_input_files)]) 208 | img_exptC = Image.open(self.set1_expert_files[index % len(self.set1_expert_files)]) 209 | seed = random.randint(1,len(self.set2_expert_files)) 210 | img2 = Image.open(self.set2_expert_files[(index + seed) % len(self.set2_expert_files)]) 211 | 212 | elif self.mode == "test": 213 | img_name = os.path.split(self.test_input_files[index % len(self.test_input_files)])[-1] 214 | img_input = Image.open(self.test_input_files[index % len(self.test_input_files)]) 215 | img_exptC = Image.open(self.test_expert_files[index % len(self.test_expert_files)]) 216 | img2 = img_exptC 217 | 218 | if self.mode == "train": 219 | ratio_H = np.random.uniform(0.6,1.0) 220 | ratio_W = np.random.uniform(0.6,1.0) 221 | W,H = img_input._size 222 | crop_h = round(H*ratio_H) 223 | crop_w = 
round(W*ratio_W) 224 | W2,H2 = img2._size 225 | crop_h = min(crop_h,H2) 226 | crop_w = min(crop_w,W2) 227 | i, j, h, w = transforms.RandomCrop.get_params(img_input, output_size=(crop_h, crop_w)) 228 | img_input = TF.crop(img_input, i, j, h, w) 229 | img_exptC = TF.crop(img_exptC, i, j, h, w) 230 | i, j, h, w = transforms.RandomCrop.get_params(img2, output_size=(crop_h, crop_w)) 231 | img2 = TF.crop(img2, i, j, h, w) 232 | 233 | if np.random.random() > 0.5: 234 | img_input = TF.hflip(img_input) 235 | img_exptC = TF.hflip(img_exptC) 236 | 237 | if np.random.random() > 0.5: 238 | img2 = TF.hflip(img2) 239 | 240 | #if np.random.random() > 0.5: 241 | # img_input = TF.vflip(img_input) 242 | # img_exptC = TF.vflip(img_exptC) 243 | # img2 = TF.vflip(img2) 244 | 245 | a = np.random.uniform(0.6,1.4) 246 | img_input = TF.adjust_brightness(img_input,a) 247 | 248 | a = np.random.uniform(0.8,1.2) 249 | img_input = TF.adjust_saturation(img_input,a) 250 | 251 | 252 | img_input = TF.to_tensor(img_input) 253 | img_exptC = TF.to_tensor(img_exptC) 254 | img2 = TF.to_tensor(img2) 255 | 256 | return {"A_input": img_input, "A_exptC": img_exptC, "B_exptC": img2, "input_name": img_name} 257 | 258 | def __len__(self): 259 | if self.mode == "train": 260 | return len(self.set1_input_files) 261 | elif self.mode == "test": 262 | return len(self.test_input_files) 263 | 264 | 265 | class ImageDataset_XYZ_unpaired(Dataset): 266 | def __init__(self, root, mode="train", unpaird_data="fiveK"): 267 | self.mode = mode 268 | self.unpaird_data = unpaird_data 269 | 270 | file = open(os.path.join(root,'train_input.txt'),'r') 271 | set1_input_files = sorted(file.readlines()) 272 | self.set1_input_files = list() 273 | self.set1_expert_files = list() 274 | for i in range(len(set1_input_files)): 275 | self.set1_input_files.append(os.path.join(root,"input","PNG/480p_16bits_XYZ_WB",set1_input_files[i][:-1] + ".png")) 276 | self.set1_expert_files.append(os.path.join(root,"expertC","JPG/480p",set1_input_files[i][:-1] + ".jpg")) 277 | 278 | file = open(os.path.join(root,'train_label.txt'),'r') 279 | set2_input_files = sorted(file.readlines()) 280 | self.set2_input_files = list() 281 | self.set2_expert_files = list() 282 | for i in range(len(set2_input_files)): 283 | self.set2_input_files.append(os.path.join(root,"input","PNG/480p_16bits_XYZ_WB",set2_input_files[i][:-1] + ".png")) 284 | self.set2_expert_files.append(os.path.join(root,"expertC","JPG/480p",set2_input_files[i][:-1] + ".jpg")) 285 | 286 | file = open(os.path.join(root,'test.txt'),'r') 287 | test_input_files = sorted(file.readlines()) 288 | self.test_input_files = list() 289 | self.test_expert_files = list() 290 | for i in range(len(test_input_files)): 291 | self.test_input_files.append(os.path.join(root,"input","PNG/480p_16bits_XYZ_WB",test_input_files[i][:-1] + ".png")) 292 | self.test_expert_files.append(os.path.join(root,"expertC","JPG/480p",test_input_files[i][:-1] + ".jpg")) 293 | 294 | 295 | def __getitem__(self, index): 296 | 297 | if self.mode == "train": 298 | img_name = os.path.split(self.set1_input_files[index % len(self.set1_input_files)])[-1] 299 | img_input = cv2.imread(self.set1_input_files[index % len(self.set1_input_files)],-1) 300 | img_exptC = Image.open(self.set1_expert_files[index % len(self.set1_expert_files)]) 301 | seed = random.randint(1,len(self.set2_expert_files)) 302 | img2 = Image.open(self.set2_expert_files[(index + seed) % len(self.set2_expert_files)]) 303 | 304 | elif self.mode == "test": 305 | img_name = os.path.split(self.test_input_files[index 
% len(self.test_input_files)])[-1] 306 | img_input = cv2.imread(self.test_input_files[index % len(self.test_input_files)],-1) 307 | img_exptC = Image.open(self.test_expert_files[index % len(self.test_expert_files)]) 308 | img2 = img_exptC 309 | 310 | img_input = np.array(img_input) 311 | #img_input = np.array(cv2.cvtColor(img_input,cv2.COLOR_BGR2RGB)) 312 | 313 | if self.mode == "train": 314 | ratio_H = np.random.uniform(0.6,1.0) 315 | ratio_W = np.random.uniform(0.6,1.0) 316 | W,H = img_exptC._size 317 | crop_h = round(H*ratio_H) 318 | crop_w = round(W*ratio_W) 319 | W2,H2 = img2._size 320 | crop_h = min(crop_h,H2) 321 | crop_w = min(crop_w,W2) 322 | i, j, h, w = transforms.RandomCrop.get_params(img_exptC, output_size=(crop_h, crop_w)) 323 | img_input = TF_x.crop(img_input, i, j, h, w) 324 | img_exptC = TF.crop(img_exptC, i, j, h, w) 325 | i, j, h, w = transforms.RandomCrop.get_params(img2, output_size=(crop_h, crop_w)) 326 | img2 = TF.crop(img2, i, j, h, w) 327 | 328 | if np.random.random() > 0.5: 329 | img_input = TF_x.hflip(img_input) 330 | img_exptC = TF.hflip(img_exptC) 331 | 332 | if np.random.random() > 0.5: 333 | img2 = TF.hflip(img2) 334 | 335 | a = np.random.uniform(0.6,1.4) 336 | img_input = TF_x.adjust_brightness(img_input,a) 337 | 338 | img_input = TF_x.to_tensor(img_input) 339 | img_exptC = TF.to_tensor(img_exptC) 340 | img2 = TF.to_tensor(img2) 341 | 342 | return {"A_input": img_input, "A_exptC": img_exptC, "B_exptC": img2, "input_name": img_name} 343 | 344 | def __len__(self): 345 | if self.mode == "train": 346 | return len(self.set1_input_files) 347 | elif self.mode == "test": 348 | return len(self.test_input_files) 349 | 350 | 351 | class ImageDataset_HDRplus(Dataset): 352 | def __init__(self, root, mode="train", combined=True): 353 | self.mode = mode 354 | 355 | file = open(os.path.join(root,'train.txt'),'r') 356 | set1_input_files = sorted(file.readlines()) 357 | self.set1_input_files = list() 358 | self.set1_expert_files = list() 359 | for i in range(len(set1_input_files)): 360 | self.set1_input_files.append(os.path.join(root,"middle_480p",set1_input_files[i][:-1] + ".png")) 361 | self.set1_expert_files.append(os.path.join(root,"output_480p",set1_input_files[i][:-1] + ".jpg")) 362 | 363 | file = open(os.path.join(root,'test.txt'),'r') 364 | test_input_files = sorted(file.readlines()) 365 | self.test_input_files = list() 366 | self.test_expert_files = list() 367 | for i in range(len(test_input_files)): 368 | self.test_input_files.append(os.path.join(root,"middle_480p",test_input_files[i][:-1] + ".png")) 369 | self.test_expert_files.append(os.path.join(root,"output_480p",test_input_files[i][:-1] + ".jpg")) 370 | 371 | 372 | def __getitem__(self, index): 373 | 374 | if self.mode == "train": 375 | img_name = os.path.split(self.set1_input_files[index % len(self.set1_input_files)])[-1] 376 | img_input = cv2.imread(self.set1_input_files[index % len(self.set1_input_files)],-1) 377 | img_exptC = Image.open(self.set1_expert_files[index % len(self.set1_expert_files)]) 378 | 379 | elif self.mode == "test": 380 | img_name = os.path.split(self.test_input_files[index % len(self.test_input_files)])[-1] 381 | img_input = cv2.imread(self.test_input_files[index % len(self.test_input_files)],-1) 382 | img_exptC = Image.open(self.test_expert_files[index % len(self.test_expert_files)]) 383 | 384 | img_input = np.array(img_input) 385 | #img_input = np.array(cv2.cvtColor(img_input,cv2.COLOR_BGR2RGB)) 386 | 387 | if self.mode == "train": 388 | 389 | ratio = np.random.uniform(0.6,1.0) 390 | 
W,H = img_exptC._size 391 | crop_h = round(H*ratio) 392 | crop_w = round(W*ratio) 393 | i, j, h, w = transforms.RandomCrop.get_params(img_exptC, output_size=(crop_h, crop_w)) 394 | try: 395 | img_input = TF_x.crop(img_input, i, j, h, w) 396 | except: 397 | print(crop_h,crop_w,img_input.shape()) 398 | img_exptC = TF.crop(img_exptC, i, j, h, w) 399 | 400 | if np.random.random() > 0.5: 401 | img_input = TF_x.hflip(img_input) 402 | img_exptC = TF.hflip(img_exptC) 403 | 404 | a = np.random.uniform(0.6,1.4) 405 | img_input = TF_x.adjust_brightness(img_input,a) 406 | 407 | #a = np.random.uniform(0.8,1.2) 408 | #img_input = TF_x.adjust_saturation(img_input,a) 409 | 410 | img_input = TF_x.to_tensor(img_input) 411 | img_exptC = TF.to_tensor(img_exptC) 412 | 413 | return {"A_input": img_input, "A_exptC": img_exptC, "input_name": img_name} 414 | 415 | def __len__(self): 416 | if self.mode == "train": 417 | return len(self.set1_input_files) 418 | elif self.mode == "test": 419 | return len(self.test_input_files) 420 | 421 | class ImageDataset_HDRplus_unpaired(Dataset): 422 | def __init__(self, root, mode="train"): 423 | self.mode = mode 424 | 425 | file = open(os.path.join(root,'train.txt'),'r') 426 | set1_input_files = sorted(file.readlines()) 427 | self.set1_input_files = list() 428 | self.set1_expert_files = list() 429 | for i in range(len(set1_input_files)): 430 | self.set1_input_files.append(os.path.join(root,"middle_480p",set1_input_files[i][:-1] + ".png")) 431 | self.set1_expert_files.append(os.path.join(root,"output_480p",set1_input_files[i][:-1] + ".jpg")) 432 | 433 | file = open(os.path.join(root,'train.txt'),'r') 434 | set2_input_files = sorted(file.readlines()) 435 | self.set2_input_files = list() 436 | self.set2_expert_files = list() 437 | for i in range(len(set2_input_files)): 438 | self.set2_input_files.append(os.path.join(root,"middle_480p",set2_input_files[i][:-1] + ".png")) 439 | self.set2_expert_files.append(os.path.join(root,"output_480p",set2_input_files[i][:-1] + ".jpg")) 440 | 441 | file = open(os.path.join(root,'test.txt'),'r') 442 | test_input_files = sorted(file.readlines()) 443 | self.test_input_files = list() 444 | self.test_expert_files = list() 445 | for i in range(len(test_input_files)): 446 | self.test_input_files.append(os.path.join(root,"middle_480p",test_input_files[i][:-1] + ".png")) 447 | self.test_expert_files.append(os.path.join(root,"output_480p",test_input_files[i][:-1] + ".jpg")) 448 | 449 | 450 | def __getitem__(self, index): 451 | 452 | if self.mode == "train": 453 | img_name = os.path.split(self.set1_input_files[index % len(self.set1_input_files)])[-1] 454 | img_input = cv2.imread(self.set1_input_files[index % len(self.set1_input_files)],-1) 455 | img_exptC = Image.open(self.set1_expert_files[index % len(self.set1_expert_files)]) 456 | seed = random.randint(1,len(self.set2_expert_files)) 457 | img2 = Image.open(self.set2_expert_files[(index + seed) % len(self.set2_expert_files)]) 458 | 459 | elif self.mode == "test": 460 | img_name = os.path.split(self.test_input_files[index % len(self.test_input_files)])[-1] 461 | img_input = cv2.imread(self.test_input_files[index % len(self.test_input_files)],-1) 462 | img_exptC = Image.open(self.test_expert_files[index % len(self.test_expert_files)]) 463 | img2 = img_exptC 464 | 465 | img_input = np.array(img_input) 466 | #img_input = np.array(cv2.cvtColor(img_input,cv2.COLOR_BGR2RGB)) 467 | 468 | if self.mode == "train": 469 | ratio = np.random.uniform(0.6,1.0) 470 | W,H = img_exptC._size 471 | crop_h = round(H*ratio) 472 
| crop_w = round(W*ratio) 473 | W2,H2 = img2._size 474 | crop_h = min(crop_h,H2) 475 | crop_w = min(crop_w,W2) 476 | i, j, h, w = transforms.RandomCrop.get_params(img_exptC, output_size=(crop_h, crop_w)) 477 | img_input = TF_x.crop(img_input, i, j, h, w) 478 | img_exptC = TF.crop(img_exptC, i, j, h, w) 479 | i, j, h, w = transforms.RandomCrop.get_params(img2, output_size=(crop_h, crop_w)) 480 | img2 = TF.crop(img2, i, j, h, w) 481 | 482 | if np.random.random() > 0.5: 483 | img_input = TF_x.hflip(img_input) 484 | img_exptC = TF.hflip(img_exptC) 485 | 486 | if np.random.random() > 0.5: 487 | img2 = TF.hflip(img2) 488 | 489 | a = np.random.uniform(0.8,1.2) 490 | img_input = TF_x.adjust_brightness(img_input,a) 491 | 492 | img_input = TF_x.to_tensor(img_input) 493 | img_exptC = TF.to_tensor(img_exptC) 494 | img2 = TF.to_tensor(img2) 495 | 496 | return {"A_input": img_input, "A_exptC": img_exptC, "B_exptC": img2, "input_name": img_name} 497 | 498 | def __len__(self): 499 | if self.mode == "train": 500 | return len(self.set1_input_files) 501 | elif self.mode == "test": 502 | return len(self.test_input_files) 503 | -------------------------------------------------------------------------------- /demo_eval.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | import os 4 | import numpy as np 5 | import cv2 6 | from PIL import Image 7 | 8 | from models import * 9 | import torchvision_x_functional as TF_x 10 | import torchvision.transforms.functional as TF 11 | 12 | 13 | parser = argparse.ArgumentParser() 14 | 15 | parser.add_argument("--image_dir", type=str, default="demo_images", help="directory of image") 16 | parser.add_argument("--image_name", type=str, default="a1629.jpg", help="name of image") 17 | parser.add_argument("--input_color_space", type=str, default="sRGB", help="input color space: sRGB or XYZ") 18 | parser.add_argument("--model_dir", type=str, default="pretrained_models", help="directory of pretrained models") 19 | parser.add_argument("--output_dir", type=str, default="demo_results", help="directory to save results") 20 | opt = parser.parse_args() 21 | opt.model_dir = opt.model_dir + '/' + opt.input_color_space 22 | opt.image_path = opt.image_dir + '/' + opt.input_color_space + '/' + opt.image_name 23 | os.makedirs(opt.output_dir, exist_ok=True) 24 | 25 | # use gpu when detect cuda 26 | cuda = True if torch.cuda.is_available() else False 27 | # Tensor type 28 | Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor 29 | 30 | criterion_pixelwise = torch.nn.MSELoss() 31 | LUT0 = Generator3DLUT_identity() 32 | LUT1 = Generator3DLUT_zero() 33 | LUT2 = Generator3DLUT_zero() 34 | #LUT3 = Generator3DLUT_zero() 35 | #LUT4 = Generator3DLUT_zero() 36 | classifier = Classifier() 37 | trilinear_ = TrilinearInterpolation() 38 | 39 | if cuda: 40 | LUT0 = LUT0.cuda() 41 | LUT1 = LUT1.cuda() 42 | LUT2 = LUT2.cuda() 43 | #LUT3 = LUT3.cuda() 44 | #LUT4 = LUT4.cuda() 45 | classifier = classifier.cuda() 46 | criterion_pixelwise.cuda() 47 | 48 | # Load pretrained models 49 | LUTs = torch.load("%s/LUTs.pth" % opt.model_dir) 50 | LUT0.load_state_dict(LUTs["0"]) 51 | LUT1.load_state_dict(LUTs["1"]) 52 | LUT2.load_state_dict(LUTs["2"]) 53 | #LUT3.load_state_dict(LUTs["3"]) 54 | #LUT4.load_state_dict(LUTs["4"]) 55 | LUT0.eval() 56 | LUT1.eval() 57 | LUT2.eval() 58 | #LUT3.eval() 59 | #LUT4.eval() 60 | classifier.load_state_dict(torch.load("%s/classifier.pth" % opt.model_dir)) 61 | classifier.eval() 62 | 63 | 64 | def 
generate_LUT(img): 65 | 66 | pred = classifier(img).squeeze() 67 | 68 | LUT = pred[0] * LUT0.LUT + pred[1] * LUT1.LUT + pred[2] * LUT2.LUT #+ pred[3] * LUT3.LUT + pred[4] * LUT4.LUT 69 | 70 | return LUT 71 | 72 | 73 | # ---------- 74 | # test 75 | # ---------- 76 | # read image and transform to tensor 77 | if opt.input_color_space == 'sRGB': 78 | img = Image.open(opt.image_path) 79 | img = TF.to_tensor(img).type(Tensor) 80 | elif opt.input_color_space == 'XYZ': 81 | img = cv2.imread(opt.image_path, -1) 82 | img = np.array(img) 83 | img = TF_x.to_tensor(img).type(Tensor) 84 | img = img.unsqueeze(0) 85 | 86 | LUT = generate_LUT(img) 87 | 88 | # generate image 89 | result = trilinear_(LUT, img) 90 | 91 | # save image 92 | ndarr = result.squeeze().mul_(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to('cpu', torch.uint8).numpy() 93 | im = Image.fromarray(ndarr) 94 | im.save('%s/result.jpg' % opt.output_dir, quality=95) 95 | 96 | 97 | -------------------------------------------------------------------------------- /demo_images/XYZ/a1629.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HuiZeng/Image-Adaptive-3DLUT/b491f6df64a588864739a157db271e5c848e1805/demo_images/XYZ/a1629.png -------------------------------------------------------------------------------- /demo_images/sRGB/a1629.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HuiZeng/Image-Adaptive-3DLUT/b491f6df64a588864739a157db271e5c848e1805/demo_images/sRGB/a1629.jpg -------------------------------------------------------------------------------- /figures/framework2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HuiZeng/Image-Adaptive-3DLUT/b491f6df64a588864739a157db271e5c848e1805/figures/framework2.png -------------------------------------------------------------------------------- /image_adaptive_lut_evaluation.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import time 3 | import torch 4 | from torchvision.utils import save_image 5 | from torch.utils.data import DataLoader 6 | from torch.autograd import Variable 7 | 8 | from models import * 9 | from datasets import * 10 | 11 | 12 | parser = argparse.ArgumentParser() 13 | parser.add_argument("--epoch", type=int, default=145, help="epoch to load the saved checkpoint") 14 | parser.add_argument("--dataset_name", type=str, default="fiveK", help="name of the dataset") 15 | parser.add_argument("--input_color_space", type=str, default="sRGB", help="input color space: sRGB or XYZ") 16 | parser.add_argument("--model_dir", type=str, default="LUTs/paired/fiveK_480p_3LUT_sm_1e-4_mn_10", help="directory of saved models") 17 | opt = parser.parse_args() 18 | opt.model_dir = opt.model_dir + '_' + opt.input_color_space 19 | 20 | # use gpu when detect cuda 21 | cuda = True if torch.cuda.is_available() else False 22 | # Tensor type 23 | Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor 24 | 25 | criterion_pixelwise = torch.nn.MSELoss() 26 | LUT0 = Generator3DLUT_identity() 27 | LUT1 = Generator3DLUT_zero() 28 | LUT2 = Generator3DLUT_zero() 29 | #LUT3 = Generator3DLUT_zero() 30 | #LUT4 = Generator3DLUT_zero() 31 | classifier = Classifier() 32 | trilinear_ = TrilinearInterpolation() 33 | 34 | if cuda: 35 | LUT0 = LUT0.cuda() 36 | LUT1 = LUT1.cuda() 37 | LUT2 = LUT2.cuda() 38 | #LUT3 = LUT3.cuda() 39 | #LUT4 = LUT4.cuda() 
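# Only three basis LUTs are loaded below; uncomment the LUT3/LUT4 lines throughout
# this script to evaluate a five-LUT variant.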
40 | classifier = classifier.cuda() 41 | criterion_pixelwise.cuda() 42 | 43 | # Load pretrained models 44 | LUTs = torch.load("saved_models/%s/LUTs_%d.pth" % (opt.model_dir, opt.epoch)) 45 | LUT0.load_state_dict(LUTs["0"]) 46 | LUT1.load_state_dict(LUTs["1"]) 47 | LUT2.load_state_dict(LUTs["2"]) 48 | #LUT3.load_state_dict(LUTs["3"]) 49 | #LUT4.load_state_dict(LUTs["4"]) 50 | LUT0.eval() 51 | LUT1.eval() 52 | LUT2.eval() 53 | #LUT3.eval() 54 | #LUT4.eval() 55 | classifier.load_state_dict(torch.load("saved_models/%s/classifier_%d.pth" % (opt.model_dir, opt.epoch))) 56 | classifier.eval() 57 | 58 | if opt.input_color_space == 'sRGB': 59 | dataloader = DataLoader( 60 | ImageDataset_sRGB("../data/%s" % opt.dataset_name, mode="test"), 61 | batch_size=1, 62 | shuffle=False, 63 | num_workers=1, 64 | ) 65 | elif opt.input_color_space == 'XYZ': 66 | dataloader = DataLoader( 67 | ImageDataset_XYZ("../data/%s" % opt.dataset_name, mode="test"), 68 | batch_size=1, 69 | shuffle=False, 70 | num_workers=1, 71 | ) 72 | 73 | 74 | def generator(img): 75 | 76 | pred = classifier(img).squeeze() 77 | 78 | LUT = pred[0] * LUT0.LUT + pred[1] * LUT1.LUT + pred[2] * LUT2.LUT #+ pred[3] * LUT3.LUT + pred[4] * LUT4.LUT 79 | 80 | combine_A = img.new(img.size()) 81 | combine_A = trilinear_(LUT,img) 82 | 83 | return combine_A 84 | 85 | 86 | def visualize_result(): 87 | """Saves a generated sample from the validation set""" 88 | out_dir = "images/%s_%d" % (opt.model_dir, opt.epoch) 89 | os.makedirs(out_dir, exist_ok=True) 90 | for i, batch in enumerate(dataloader): 91 | real_A = Variable(batch["A_input"].type(Tensor)) 92 | img_name = batch["input_name"] 93 | fake_B = generator(real_A) 94 | 95 | #real_B = Variable(batch["A_exptC"].type(Tensor)) 96 | #img_sample = torch.cat((real_A.data, fake_B.data, real_B.data), -1) 97 | #save_image(img_sample, "images/LUTs/paired/JPGsRGB8_to_JPGsRGB8_WB_original_5LUT/%s.png" % (img_name[0][:-4]), nrow=3, normalize=False) 98 | save_image(fake_B, os.path.join(out_dir,"%s.png" % (img_name[0][:-4])), nrow=1, normalize=False) 99 | 100 | def test_speed(): 101 | t_list = [] 102 | for i in range(1,10): 103 | img_input = Image.open(os.path.join("../data/fiveK/input/JPG","original","a000%d.jpg"%i)) 104 | img_input = torch.unsqueeze(TF.to_tensor(TF.resize(img_input,(4000,6000))),0) 105 | real_A = Variable(img_input.type(Tensor)) 106 | torch.cuda.synchronize() 107 | t0 = time.time() 108 | for j in range(0,100): 109 | fake_B = generator(real_A) 110 | 111 | torch.cuda.synchronize() 112 | t1 = time.time() 113 | t_list.append(t1 - t0) 114 | print((t1 - t0)) 115 | print(t_list) 116 | 117 | # ---------- 118 | # evaluation 119 | # ---------- 120 | visualize_result() 121 | 122 | #test_speed() 123 | -------------------------------------------------------------------------------- /image_adaptive_lut_train_paired.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import numpy as np 4 | import math 5 | import itertools 6 | import time 7 | import datetime 8 | import sys 9 | 10 | import torchvision.transforms as transforms 11 | from torchvision.utils import save_image 12 | 13 | from torch.utils.data import DataLoader 14 | from torchvision import datasets 15 | from torch.autograd import Variable 16 | 17 | from models import * 18 | from datasets import * 19 | 20 | import torch.nn as nn 21 | import torch.nn.functional as F 22 | import torch 23 | 24 | parser = argparse.ArgumentParser() 25 | parser.add_argument("--epoch", type=int, default=0, 
help="epoch to start training from, 0 starts from scratch, >0 starts from saved checkpoints") 26 | parser.add_argument("--n_epochs", type=int, default=400, help="total number of epochs of training") 27 | parser.add_argument("--dataset_name", type=str, default="fiveK", help="name of the dataset") 28 | parser.add_argument("--input_color_space", type=str, default="sRGB", help="input color space: sRGB or XYZ") 29 | parser.add_argument("--batch_size", type=int, default=1, help="size of the batches") 30 | parser.add_argument("--lr", type=float, default=0.0001, help="adam: learning rate") 31 | parser.add_argument("--b1", type=float, default=0.9, help="adam: decay of first order momentum of gradient") 32 | parser.add_argument("--b2", type=float, default=0.999, help="adam: decay of first order momentum of gradient") 33 | parser.add_argument("--lambda_smooth", type=float, default=0.0001, help="smooth regularization") 34 | parser.add_argument("--lambda_monotonicity", type=float, default=10.0, help="monotonicity regularization") 35 | parser.add_argument("--n_cpu", type=int, default=1, help="number of cpu threads to use during batch generation") 36 | parser.add_argument("--checkpoint_interval", type=int, default=1, help="interval between model checkpoints") 37 | parser.add_argument("--output_dir", type=str, default="LUTs/paired/fiveK_480p_3LUT_sm_1e-4_mn_10", help="path to save model") 38 | opt = parser.parse_args() 39 | 40 | opt.output_dir = opt.output_dir + '_' + opt.input_color_space 41 | print(opt) 42 | 43 | os.makedirs("saved_models/%s" % opt.output_dir, exist_ok=True) 44 | 45 | cuda = True if torch.cuda.is_available() else False 46 | # Tensor type 47 | Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor 48 | 49 | # Loss functions 50 | criterion_pixelwise = torch.nn.MSELoss() 51 | 52 | # Initialize generator and discriminator 53 | LUT0 = Generator3DLUT_identity() 54 | LUT1 = Generator3DLUT_zero() 55 | LUT2 = Generator3DLUT_zero() 56 | #LUT3 = Generator3DLUT_zero() 57 | #LUT4 = Generator3DLUT_zero() 58 | classifier = Classifier() 59 | TV3 = TV_3D() 60 | trilinear_ = TrilinearInterpolation() 61 | 62 | if cuda: 63 | LUT0 = LUT0.cuda() 64 | LUT1 = LUT1.cuda() 65 | LUT2 = LUT2.cuda() 66 | #LUT3 = LUT3.cuda() 67 | #LUT4 = LUT4.cuda() 68 | classifier = classifier.cuda() 69 | criterion_pixelwise.cuda() 70 | TV3.cuda() 71 | TV3.weight_r = TV3.weight_r.type(Tensor) 72 | TV3.weight_g = TV3.weight_g.type(Tensor) 73 | TV3.weight_b = TV3.weight_b.type(Tensor) 74 | 75 | if opt.epoch != 0: 76 | # Load pretrained models 77 | LUTs = torch.load("saved_models/%s/LUTs_%d.pth" % (opt.output_dir, opt.epoch)) 78 | LUT0.load_state_dict(LUTs["0"]) 79 | LUT1.load_state_dict(LUTs["1"]) 80 | LUT2.load_state_dict(LUTs["2"]) 81 | #LUT3.load_state_dict(LUTs["3"]) 82 | #LUT4.load_state_dict(LUTs["4"]) 83 | classifier.load_state_dict(torch.load("saved_models/%s/classifier_%d.pth" % (opt.output_dir, opt.epoch))) 84 | else: 85 | # Initialize weights 86 | classifier.apply(weights_init_normal_classifier) 87 | torch.nn.init.constant_(classifier.model[16].bias.data, 1.0) 88 | 89 | # Optimizers 90 | 91 | optimizer_G = torch.optim.Adam(itertools.chain(classifier.parameters(), LUT0.parameters(), LUT1.parameters(), LUT2.parameters()), lr=opt.lr, betas=(opt.b1, opt.b2)) #, LUT3.parameters(), LUT4.parameters() 92 | 93 | if opt.input_color_space == 'sRGB': 94 | dataloader = DataLoader( 95 | ImageDataset_sRGB("../data/%s" % opt.dataset_name, mode = "train"), 96 | batch_size=opt.batch_size, 97 | shuffle=True, 98 | 
num_workers=opt.n_cpu, 99 | ) 100 | 101 | psnr_dataloader = DataLoader( 102 | ImageDataset_sRGB("../data/%s" % opt.dataset_name, mode="test"), 103 | batch_size=1, 104 | shuffle=False, 105 | num_workers=1, 106 | ) 107 | elif opt.input_color_space == 'XYZ': 108 | dataloader = DataLoader( 109 | ImageDataset_XYZ("../data/%s" % opt.dataset_name, mode = "train"), 110 | batch_size=opt.batch_size, 111 | shuffle=True, 112 | num_workers=opt.n_cpu, 113 | ) 114 | 115 | psnr_dataloader = DataLoader( 116 | ImageDataset_XYZ("../data/%s" % opt.dataset_name, mode="test"), 117 | batch_size=1, 118 | shuffle=False, 119 | num_workers=1, 120 | ) 121 | 122 | def generator_train(img): 123 | 124 | pred = classifier(img).squeeze() 125 | if len(pred.shape) == 1: 126 | pred = pred.unsqueeze(0) 127 | gen_A0 = LUT0(img) 128 | gen_A1 = LUT1(img) 129 | gen_A2 = LUT2(img) 130 | #gen_A3 = LUT3(img) 131 | #gen_A4 = LUT4(img) 132 | 133 | weights_norm = torch.mean(pred ** 2) 134 | 135 | combine_A = img.new(img.size()) 136 | for b in range(img.size(0)): 137 | combine_A[b,:,:,:] = pred[b,0] * gen_A0[b,:,:,:] + pred[b,1] * gen_A1[b,:,:,:] + pred[b,2] * gen_A2[b,:,:,:] #+ pred[b,3] * gen_A3[b,:,:,:] + pred[b,4] * gen_A4[b,:,:,:] 138 | 139 | return combine_A, weights_norm 140 | 141 | def generator_eval(img): 142 | 143 | pred = classifier(img).squeeze() 144 | 145 | LUT = pred[0] * LUT0.LUT + pred[1] * LUT1.LUT + pred[2] * LUT2.LUT #+ pred[3] * LUT3.LUT + pred[4] * LUT4.LUT 146 | 147 | weights_norm = torch.mean(pred ** 2) 148 | 149 | combine_A = img.new(img.size()) 150 | combine_A = trilinear_(LUT,img) 151 | 152 | return combine_A, weights_norm 153 | 154 | def calculate_psnr(): 155 | classifier.eval() 156 | avg_psnr = 0 157 | for i, batch in enumerate(psnr_dataloader): 158 | real_A = Variable(batch["A_input"].type(Tensor)) 159 | real_B = Variable(batch["A_exptC"].type(Tensor)) 160 | fake_B, weights_norm = generator_eval(real_A) 161 | fake_B = torch.round(fake_B*255) 162 | real_B = torch.round(real_B*255) 163 | mse = criterion_pixelwise(fake_B, real_B) 164 | psnr = 10 * math.log10(255.0 * 255.0 / mse.item()) 165 | avg_psnr += psnr 166 | 167 | return avg_psnr/ len(psnr_dataloader) 168 | 169 | 170 | def visualize_result(epoch): 171 | """Saves a generated sample from the validation set""" 172 | classifier.eval() 173 | os.makedirs("images/%s/" % opt.output_dir +str(epoch), exist_ok=True) 174 | for i, batch in enumerate(psnr_dataloader): 175 | real_A = Variable(batch["A_input"].type(Tensor)) 176 | real_B = Variable(batch["A_exptC"].type(Tensor)) 177 | img_name = batch["input_name"] 178 | fake_B, weights_norm = generator_eval(real_A) 179 | img_sample = torch.cat((real_A.data, fake_B.data, real_B.data), -1) 180 | fake_B = torch.round(fake_B*255) 181 | real_B = torch.round(real_B*255) 182 | mse = criterion_pixelwise(fake_B, real_B) 183 | psnr = 10 * math.log10(255.0 * 255.0 / mse.item()) 184 | save_image(img_sample, "images/%s/%s/%s.jpg" % (opt.output_dir,epoch, img_name[0]+'_'+str(psnr)[:5]), nrow=3, normalize=False) 185 | 186 | # ---------- 187 | # Training 188 | # ---------- 189 | 190 | prev_time = time.time() 191 | max_psnr = 0 192 | max_epoch = 0 193 | for epoch in range(opt.epoch, opt.n_epochs): 194 | mse_avg = 0 195 | psnr_avg = 0 196 | classifier.train() 197 | for i, batch in enumerate(dataloader): 198 | 199 | # Model inputs 200 | real_A = Variable(batch["A_input"].type(Tensor)) 201 | real_B = Variable(batch["A_exptC"].type(Tensor)) 202 | 203 | # ------------------ 204 | # Train Generators 205 | # ------------------ 206 | 207 | 
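# One generator step: fuse the three basis-LUT outputs with the CNN-predicted weights,
# then minimize the MSE against the expert target plus the smoothness (TV + weight-norm)
# and monotonicity regularizers on the LUTs.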
optimizer_G.zero_grad() 208 | 209 | fake_B, weights_norm = generator_train(real_A) 210 | 211 | # Pixel-wise loss 212 | mse = criterion_pixelwise(fake_B, real_B) 213 | 214 | tv0, mn0 = TV3(LUT0) 215 | tv1, mn1 = TV3(LUT1) 216 | tv2, mn2 = TV3(LUT2) 217 | #tv3, mn3 = TV3(LUT3) 218 | #tv4, mn4 = TV3(LUT4) 219 | tv_cons = tv0 + tv1 + tv2 #+ tv3 + tv4 220 | mn_cons = mn0 + mn1 + mn2 #+ mn3 + mn4 221 | 222 | loss = mse + opt.lambda_smooth * (weights_norm + tv_cons) + opt.lambda_monotonicity * mn_cons 223 | 224 | psnr_avg += 10 * math.log10(1 / mse.item()) 225 | 226 | mse_avg += mse.item() 227 | 228 | loss.backward() 229 | 230 | optimizer_G.step() 231 | 232 | 233 | # -------------- 234 | # Log Progress 235 | # -------------- 236 | 237 | # Determine approximate time left 238 | batches_done = epoch * len(dataloader) + i 239 | batches_left = opt.n_epochs * len(dataloader) - batches_done 240 | time_left = datetime.timedelta(seconds=batches_left * (time.time() - prev_time)) 241 | prev_time = time.time() 242 | 243 | # Print log 244 | sys.stdout.write( 245 | "\r[Epoch %d/%d] [Batch %d/%d] [psnr: %f, tv: %f, wnorm: %f, mn: %f] ETA: %s" 246 | % (epoch,opt.n_epochs,i,len(dataloader),psnr_avg / (i+1),tv_cons, weights_norm, mn_cons, time_left, 247 | ) 248 | ) 249 | 250 | avg_psnr = calculate_psnr() 251 | if avg_psnr > max_psnr: 252 | max_psnr = avg_psnr 253 | max_epoch = epoch 254 | sys.stdout.write(" [PSNR: %f] [max PSNR: %f, epoch: %d]\n"% (avg_psnr, max_psnr, max_epoch)) 255 | 256 | #if (epoch+1) % 10 == 0: 257 | # visualize_result(epoch+1) 258 | 259 | if epoch % opt.checkpoint_interval == 0: 260 | # Save model checkpoints 261 | LUTs = {"0": LUT0.state_dict(),"1": LUT1.state_dict(),"2": LUT2.state_dict()} #,"3": LUT3.state_dict(),"4": LUT4.state_dict() 262 | torch.save(LUTs, "saved_models/%s/LUTs_%d.pth" % (opt.output_dir, epoch)) 263 | torch.save(classifier.state_dict(), "saved_models/%s/classifier_%d.pth" % (opt.output_dir, epoch)) 264 | file = open('saved_models/%s/result.txt' % opt.output_dir,'a') 265 | file.write(" [PSNR: %f] [max PSNR: %f, epoch: %d]\n"% (avg_psnr, max_psnr, max_epoch)) 266 | file.close() 267 | 268 | 269 | -------------------------------------------------------------------------------- /image_adaptive_lut_train_unpaired.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import numpy as np 4 | import math 5 | import itertools 6 | import time 7 | import datetime 8 | import sys 9 | 10 | import torchvision.transforms as transforms 11 | from torchvision.utils import save_image 12 | 13 | from torch.utils.data import DataLoader 14 | from torchvision import datasets 15 | from torch.autograd import Variable 16 | import torch.autograd as autograd 17 | 18 | from models_x import * 19 | from datasets import * 20 | 21 | import torch.nn as nn 22 | import torch.nn.functional as F 23 | import torch 24 | 25 | parser = argparse.ArgumentParser() 26 | parser.add_argument("--epoch", type=int, default=0, help="epoch to start training from, 0 starts from scratch, >0 starts from saved checkpoints") 27 | parser.add_argument("--n_epochs", type=int, default=800, help="total number of epochs of training") 28 | parser.add_argument("--dataset_name", type=str, default="fiveK", help="name of the dataset") 29 | parser.add_argument("--input_color_space", type=str, default="sRGB", help="input color space: sRGB or XYZ") 30 | parser.add_argument("--batch_size", type=int, default=1, help="size of the batches") 31 | parser.add_argument("--lr", 
type=float, default=0.0002, help="adam: learning rate") 32 | parser.add_argument("--b1", type=float, default=0.9, help="adam: decay of first order momentum of gradient") 33 | parser.add_argument("--b2", type=float, default=0.999, help="adam: decay of first order momentum of gradient") 34 | parser.add_argument("--lambda_pixel", type=float, default=1000, help="content preservation weight: 1000 for sRGB input, 10 for XYZ input") 35 | parser.add_argument("--lambda_gp", type=float, default=10, help="gradient penalty weight in wgan-gp") 36 | parser.add_argument("--lambda_smooth", type=float, default=1e-4, help="smooth regularization") 37 | parser.add_argument("--lambda_monotonicity", type=float, default=10.0, help="monotonicity regularization: 10 for sRGB input, 100 for XYZ input (slightly better)") 38 | parser.add_argument("--n_cpu", type=int, default=1, help="number of cpu threads to use during batch generation") 39 | parser.add_argument("--n_critic", type=int, default=1, help="number of training steps for discriminator per iter") 40 | parser.add_argument("--output_dir", type=str, default="LUTs/unpaired/fiveK_480p_sm_1e-4_mn_10_pixel_1000", help="path to save model") 41 | parser.add_argument("--checkpoint_interval", type=int, default=1, help="interval between model checkpoints") 42 | opt = parser.parse_args() 43 | opt.output_dir = opt.output_dir + '_' + opt.input_color_space 44 | print(opt) 45 | 46 | os.makedirs("saved_models/%s" % opt.output_dir, exist_ok=True) 47 | 48 | cuda = True if torch.cuda.is_available() else False 49 | Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor 50 | 51 | # Loss functions 52 | criterion_GAN = torch.nn.MSELoss() 53 | criterion_pixelwise = torch.nn.MSELoss() 54 | 55 | # Initialize generator and discriminator 56 | LUT0 = Generator3DLUT_identity() 57 | LUT1 = Generator3DLUT_zero() 58 | LUT2 = Generator3DLUT_zero() 59 | #LUT3 = Generator3DLUT_zero() 60 | #LUT4 = Generator3DLUT_zero() 61 | classifier = Classifier_unpaired() 62 | discriminator = Discriminator() 63 | TV3 = TV_3D() 64 | 65 | if cuda: 66 | LUT0 = LUT0.cuda() 67 | LUT1 = LUT1.cuda() 68 | LUT2 = LUT2.cuda() 69 | #LUT3 = LUT3.cuda() 70 | #LUT4 = LUT4.cuda() 71 | classifier = classifier.cuda() 72 | criterion_GAN.cuda() 73 | criterion_pixelwise.cuda() 74 | discriminator = discriminator.cuda() 75 | TV3.cuda() 76 | TV3.weight_r = TV3.weight_r.type(Tensor) 77 | TV3.weight_g = TV3.weight_g.type(Tensor) 78 | TV3.weight_b = TV3.weight_b.type(Tensor) 79 | 80 | if opt.epoch != 0: 81 | # Load pretrained models 82 | LUTs = torch.load("saved_models/%s/LUTs_%d.pth" % (opt.output_dir, opt.epoch)) 83 | LUT0.load_state_dict(LUTs["0"]) 84 | LUT1.load_state_dict(LUTs["1"]) 85 | LUT2.load_state_dict(LUTs["2"]) 86 | #LUT3.load_state_dict(LUTs["3"]) 87 | #LUT4.load_state_dict(LUTs["4"]) 88 | classifier.load_state_dict(torch.load("saved_models/%s/classifier_%d.pth" % (opt.output_dir, opt.epoch))) 89 | else: 90 | # Initialize weights 91 | classifier.apply(weights_init_normal_classifier) 92 | torch.nn.init.constant_(classifier.model[12].bias.data, 1.0) 93 | discriminator.apply(weights_init_normal_classifier) 94 | 95 | # Optimizers 96 | optimizer_G = torch.optim.Adam(itertools.chain(classifier.parameters(), LUT0.parameters(),LUT1.parameters(),LUT2.parameters()), lr=opt.lr, betas=(opt.b1, opt.b2)) #,LUT3.parameters(),LUT4.parameters() 97 | optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2)) 98 | 99 | if opt.input_color_space == 'sRGB': 100 | dataloader = DataLoader( 101 | 
ImageDataset_sRGB_unpaired("../data/%s" % opt.dataset_name, mode="train"), 102 | batch_size=opt.batch_size, 103 | shuffle=True, 104 | num_workers=opt.n_cpu, 105 | ) 106 | 107 | psnr_dataloader = DataLoader( 108 | ImageDataset_sRGB_unpaired("../data/%s" % opt.dataset_name, mode="test"), 109 | batch_size=1, 110 | shuffle=False, 111 | num_workers=1, 112 | ) 113 | elif opt.input_color_space == 'XYZ': 114 | dataloader = DataLoader( 115 | ImageDataset_XYZ_unpaired("../data/%s" % opt.dataset_name, mode="train"), 116 | batch_size=opt.batch_size, 117 | shuffle=True, 118 | num_workers=opt.n_cpu, 119 | ) 120 | 121 | psnr_dataloader = DataLoader( 122 | ImageDataset_XYZ_unpaired("../data/%s" % opt.dataset_name, mode="test"), 123 | batch_size=1, 124 | shuffle=False, 125 | num_workers=1, 126 | ) 127 | 128 | def calculate_psnr(): 129 | classifier.eval() 130 | avg_psnr = 0 131 | for i, batch in enumerate(psnr_dataloader): 132 | real_A = Variable(batch["A_input"].type(Tensor)) 133 | real_B = Variable(batch["A_exptC"].type(Tensor)) 134 | fake_B, weights_norm = generator(real_A) 135 | fake_B = torch.round(fake_B*255) 136 | real_B = torch.round(real_B*255) 137 | mse = criterion_pixelwise(fake_B, real_B) 138 | psnr = 10 * math.log10(255.0 * 255.0 / mse.item()) 139 | avg_psnr += psnr 140 | 141 | return avg_psnr/ len(psnr_dataloader) 142 | 143 | 144 | def visualize_result(epoch): 145 | """Saves a generated sample from the validation set""" 146 | os.makedirs("images/LUTs/" +str(epoch), exist_ok=True) 147 | for i, batch in enumerate(psnr_dataloader): 148 | real_A = Variable(batch["A_input"].type(Tensor)) 149 | real_B = Variable(batch["A_exptC"].type(Tensor)) 150 | img_name = batch["input_name"] 151 | fake_B, weights_norm = generator(real_A) 152 | img_sample = torch.cat((real_A.data, fake_B.data, real_B.data), -1) 153 | fake_B = torch.round(fake_B*255) 154 | real_B = torch.round(real_B*255) 155 | mse = criterion_pixelwise(fake_B, real_B) 156 | psnr = 10 * math.log10(255.0 * 255.0 / mse.item()) 157 | save_image(img_sample, "images/LUTs/%s/%s.jpg" % (epoch, img_name[0]+'_'+str(psnr)[:5]), nrow=3, normalize=False) 158 | 159 | 160 | def compute_gradient_penalty(D, real_samples, fake_samples): 161 | """Calculates the gradient penalty loss for WGAN GP""" 162 | # Random weight term for interpolation between real and fake samples 163 | alpha = Tensor(np.random.random((real_samples.size(0), 1, 1, 1))) 164 | # Get random interpolation between real and fake samples 165 | interpolates = (alpha * real_samples + ((1 - alpha) * fake_samples)).requires_grad_(True) 166 | d_interpolates = D(interpolates) 167 | fake = Variable(Tensor(real_samples.shape[0], 1, 1, 1).fill_(1.0), requires_grad=False) 168 | # Get gradient w.r.t. 
interpolates 169 | gradients = autograd.grad( 170 | outputs=d_interpolates, 171 | inputs=interpolates, 172 | grad_outputs=fake, 173 | create_graph=True, 174 | retain_graph=True, 175 | only_inputs=True, 176 | )[0] 177 | gradients = gradients.view(gradients.size(0), -1) 178 | gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean() 179 | return gradient_penalty 180 | 181 | 182 | def generator(img): 183 | 184 | pred = classifier(img).squeeze() 185 | weights_norm = torch.mean(pred ** 2) 186 | combine_A = pred[0] * LUT0(img) + pred[1] * LUT1(img) + pred[2] * LUT2(img) #+ pred[3] * LUT3(img) + pred[4] * LUT4(img) 187 | 188 | return combine_A, weights_norm 189 | 190 | # ---------- 191 | # Training 192 | # ---------- 193 | 194 | avg_psnr = calculate_psnr() 195 | print(avg_psnr) 196 | prev_time = time.time() 197 | max_psnr = 0 198 | max_epoch = 0 199 | for epoch in range(opt.epoch, opt.n_epochs): 200 | loss_D_avg = 0 201 | loss_G_avg = 0 202 | loss_pixel_avg = 0 203 | cnt = 0 204 | psnr_avg = 0 205 | classifier.train() 206 | for i, batch in enumerate(dataloader): 207 | 208 | # Model inputs 209 | real_A = Variable(batch["A_input"].type(Tensor)) 210 | real_B = Variable(batch["B_exptC"].type(Tensor)) 211 | 212 | 213 | # --------------------- 214 | # Train Discriminator 215 | # --------------------- 216 | 217 | optimizer_D.zero_grad() 218 | 219 | fake_B, weights_norm = generator(real_A) 220 | pred_real = discriminator(real_B) 221 | pred_fake = discriminator(fake_B) 222 | 223 | # Gradient penalty 224 | gradient_penalty = compute_gradient_penalty(discriminator, real_B, fake_B) 225 | 226 | # Total loss 227 | loss_D = -torch.mean(pred_real) + torch.mean(pred_fake) + opt.lambda_gp * gradient_penalty 228 | 229 | loss_D.backward() 230 | optimizer_D.step() 231 | 232 | loss_D_avg += (-torch.mean(pred_real) + torch.mean(pred_fake)) / 2 233 | 234 | # ------------------ 235 | # Train Generators 236 | # ------------------ 237 | if i % opt.n_critic == 0: 238 | 239 | optimizer_G.zero_grad() 240 | 241 | fake_B, weights_norm = generator(real_A) 242 | pred_fake = discriminator(fake_B) 243 | # Pixel-wise loss 244 | loss_pixel = criterion_pixelwise(fake_B, real_A) 245 | 246 | tv0, mn0 = TV3(LUT0) 247 | tv1, mn1 = TV3(LUT1) 248 | tv2, mn2 = TV3(LUT2) 249 | #tv3, mn3 = TV3(LUT3) 250 | #tv4, mn4 = TV3(LUT4) 251 | 252 | tv_cons = tv0 + tv1 + tv2 #+ tv3 + tv4 253 | mn_cons = mn0 + mn1 + mn2 #+ mn3 + mn4 254 | 255 | loss_G = -torch.mean(pred_fake) + opt.lambda_pixel * loss_pixel + opt.lambda_smooth * (weights_norm + tv_cons) + opt.lambda_monotonicity * mn_cons 256 | 257 | loss_G.backward() 258 | 259 | optimizer_G.step() 260 | 261 | cnt += 1 262 | loss_G_avg += -torch.mean(pred_fake) 263 | 264 | loss_pixel_avg += loss_pixel 265 | psnr_avg += 10 * math.log10(1 / loss_pixel.item()) 266 | 267 | 268 | # -------------- 269 | # Log Progress 270 | # -------------- 271 | 272 | # Determine approximate time left 273 | batches_done = epoch * len(dataloader) + i 274 | batches_left = opt.n_epochs * len(dataloader) - batches_done 275 | time_left = datetime.timedelta(seconds=batches_left * (time.time() - prev_time)) 276 | prev_time = time.time() 277 | 278 | # Print log 279 | sys.stdout.write( 280 | "\r[Epoch %d/%d] [Batch %d/%d] [D: %f, G: %f] [pixel: %f] [tv: %f, wnorm: %f, mn: %f] ETA: %s" 281 | % ( 282 | epoch, 283 | opt.n_epochs, 284 | i, 285 | len(dataloader), 286 | loss_D_avg.item() / cnt, 287 | loss_G_avg.item() / cnt, 288 | loss_pixel_avg.item() / cnt, 289 | tv_cons, weights_norm, mn_cons, 290 | time_left, 291 | ) 292 | ) 293 | 
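# For reference, each iteration above optimizes the WGAN-GP objectives
#   L_D = E[D(fake)] - E[D(real)] + lambda_gp * E[(||grad_xhat D(xhat)||_2 - 1)^2],
#         with xhat = alpha * real + (1 - alpha) * fake  (see compute_gradient_penalty)
#   L_G = -E[D(fake)] + lambda_pixel * MSE(fake_B, real_A)
#         + lambda_smooth * (weights_norm + tv) + lambda_monotonicity * mn
# The pixel term compares the enhanced image with its own input (content
# preservation), since no paired ground truth exists in the unpaired setting;
# lambda_pixel defaults to 1000 for sRGB input and 10 for XYZ input.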
294 | # If at sample interval save image 295 | avg_psnr = calculate_psnr() 296 | if avg_psnr > max_psnr: 297 | max_psnr = avg_psnr 298 | max_epoch = epoch 299 | sys.stdout.write(" [PSNR: %f] [max PSNR: %f, epoch: %d]\n"% (avg_psnr, max_psnr, max_epoch)) 300 | 301 | #if (epoch+1) % 10 == 0: 302 | # visualize_result(epoch+1) 303 | 304 | if epoch % opt.checkpoint_interval == 0: 305 | # Save model checkpoints 306 | LUTs = {"0": LUT0.state_dict(), "1": LUT1.state_dict(), "2": LUT2.state_dict()} #, "3": LUT3.state_dict(), "4": LUT4.state_dict() 307 | torch.save(LUTs, "saved_models/%s/LUTs_%d.pth" % (opt.output_dir, epoch)) 308 | torch.save(classifier.state_dict(), "saved_models/%s/classifier_%d.pth" % (opt.output_dir, epoch)) 309 | file = open('saved_models/%s/result.txt' % opt.output_dir,'a') 310 | file.write(" [PSNR: %f] [max PSNR: %f, epoch: %d]\n"% (avg_psnr, max_psnr, max_epoch)) 311 | file.close() 312 | -------------------------------------------------------------------------------- /local_tone_mapping/a1509.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HuiZeng/Image-Adaptive-3DLUT/b491f6df64a588864739a157db271e5c848e1805/local_tone_mapping/a1509.jpg -------------------------------------------------------------------------------- /local_tone_mapping/wlsFilter.m: -------------------------------------------------------------------------------- 1 | function OUT = wlsFilter(IN, lambda, alpha, L) 2 | %WLSFILTER Edge-preserving smoothing based on the weighted least squares(WLS) 3 | % optimization framework, as described in Farbman, Fattal, Lischinski, and 4 | % Szeliski, "Edge-Preserving Decompositions for Multi-Scale Tone and Detail 5 | % Manipulation", ACM Transactions on Graphics, 27(3), August 2008. 6 | % 7 | % Given an input image IN, we seek a new image OUT, which, on the one hand, 8 | % is as close as possible to IN, and, at the same time, is as smooth as 9 | % possible everywhere, except across significant gradients in L. 10 | % 11 | % 12 | % Input arguments: 13 | % ---------------- 14 | % IN Input image (2-D, double, N-by-M matrix). 15 | % 16 | % lambda Balances between the data term and the smoothness 17 | % term. Increasing lambda will produce smoother images. 18 | % Default value is 1.0 19 | % 20 | % alpha Gives a degree of control over the affinities by non- 21 | % lineary scaling the gradients. Increasing alpha will 22 | % result in sharper preserved edges. Default value: 1.2 23 | % 24 | % L Source image for the affinity matrix. Same dimensions 25 | % as the input image IN. 
Default: log(IN) 26 | % 27 | % 28 | % Example 29 | % ------- 30 | % RGB = imread('peppers.png'); 31 | % I = double(rgb2gray(RGB)); 32 | % I = I./max(I(:)); 33 | % res = wlsFilter(I, 0.5); 34 | % figure, imshow(I), figure, imshow(res) 35 | % res = wlsFilter(I, 2, 2); 36 | % figure, imshow(res) 37 | 38 | if(~exist('L', 'var')), 39 | L = log(IN+eps); 40 | end 41 | 42 | if(~exist('alpha', 'var')), 43 | alpha = 1.2; 44 | end 45 | 46 | if(~exist('lambda', 'var')), 47 | lambda = 1; 48 | end 49 | 50 | smallNum = 0.0001; 51 | 52 | [r,c] = size(IN); 53 | k = r*c; 54 | 55 | % Compute affinities between adjacent pixels based on gradients of L 56 | dy = diff(L, 1, 1); 57 | dy = -lambda./(abs(dy).^alpha + smallNum); 58 | dy = padarray(dy, [1 0], 'post'); 59 | dy = dy(:); 60 | 61 | dx = diff(L, 1, 2); 62 | dx = -lambda./(abs(dx).^alpha + smallNum); 63 | dx = padarray(dx, [0 1], 'post'); 64 | dx = dx(:); 65 | 66 | 67 | % Construct a five-point spatially inhomogeneous Laplacian matrix 68 | B(:,1) = dx; 69 | B(:,2) = dy; 70 | d = [-r,-1]; 71 | A = spdiags(B,d,k,k); 72 | 73 | e = dx; 74 | w = padarray(dx, r, 'pre'); w = w(1:end-r); 75 | s = dy; 76 | n = padarray(dy, 1, 'pre'); n = n(1:end-1); 77 | 78 | D = 1-(e+w+s+n); 79 | A = A + A' + spdiags(D, 0, k, k); 80 | 81 | % Solve 82 | OUT = A\IN(:); 83 | OUT = reshape(OUT, r, c); 84 | -------------------------------------------------------------------------------- /local_tone_mapping/wlsTonemap.m: -------------------------------------------------------------------------------- 1 | 2 | %WLSTONEMAP High Dynamic Range tonemapping using WLS 3 | % 4 | % The script reduces the dynamic range of an HDR image using the method 5 | % originally proposed by Durand and Dorsey, "Fast Bilateral Filtering 6 | % for the Display of High-Dynamic-Range Images", 7 | % ACM Transactions on Graphics, 2002. 8 | % 9 | % Instead of the bilateral filter, the edge-preserving smoothing here 10 | % is based on the weighted least squares(WLS) optimization framework, 11 | % as described in Farbman, Fattal, Lischinski, and Szeliski, 12 | % "Edge-Preserving Decompositions for Multi-Scale Tone and Detail 13 | % Manipulation", ACM Transactions on Graphics, 27(3), August 2008. 
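% For reference, wlsFilter.m above computes the minimizer u of the Farbman et
% al. (2008) energy
%
%   sum_p ( (u_p - g_p)^2 + lambda*( a_x(p)*(du/dx)_p^2 + a_y(p)*(du/dy)_p^2 ) )
%
% with data term g = IN and gradient-dependent affinities
%   a_x(p) = 1 / ( |d(log g)/dx|_p^alpha + smallNum ),   (likewise for y)
% which reduces to the single sparse solve OUT = A\IN(:), A = I + lambda*Lg,
% where Lg is the five-point inhomogeneous Laplacian assembled from dx and dy.
% Note that, as committed, this script builds the base layer with
% imguidedfilter and keeps the wlsFilter call commented out.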
14 | 15 | 16 | %% Load HDR image from file and convert to greyscale 17 | % hdr = double(hdrread('smallOffice.hdr')); 18 | hdr = double(imread('a1509.jpg'))/255.0; 19 | % hdr = imresize(hdr,0.2); 20 | % hsv = rgb2hsv(hdr); 21 | % I = hsv(:,:,3); 22 | I = 0.2989*hdr(:,:,1) + 0.587*hdr(:,:,2) + 0.114*hdr(:,:,3); 23 | logI = log(I+eps); 24 | 25 | %% Perform edge-preserving smoothing using WLS 26 | lambda = 20.0; 27 | alpha = 1.2; 28 | % base = log(wlsFilter(I, lambda, alpha)); 29 | base = log(imguidedfilter(I)); 30 | 31 | %% Compress the base layer and restore detail 32 | compression = 0.6; 33 | detail = logI - base; 34 | OUT = base*compression + detail; 35 | OUT = exp(OUT); 36 | 37 | %% Restore color 38 | OUT = OUT./I; 39 | OUT = hdr .* padarray(OUT, [0 0 2], 'circular' , 'post'); 40 | % hsv(:,:,3) = I; 41 | % OUT = hsv2rgb(hsv); 42 | 43 | %% Finally, shift, scale, and gamma correct the result 44 | gamma = 1.0/1.0; 45 | bias = -min(OUT(:)); 46 | gain = 0.8; 47 | OUT = (gain*(OUT + bias)).^gamma; 48 | % figure 49 | imshow(OUT); 50 | imwrite(OUT,'a1509_T.jpg'); 51 | % imshowpair(hdr,OUT, 'montage') 52 | % imshow(hdr,OUT); 53 | -------------------------------------------------------------------------------- /models.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | import torchvision.models as models 4 | import torchvision.transforms as transforms 5 | from torch.autograd import Variable 6 | import torch 7 | import numpy as np 8 | import math 9 | from trilinear_c._ext import trilinear 10 | 11 | def weights_init_normal_classifier(m): 12 | classname = m.__class__.__name__ 13 | if classname.find("Conv") != -1: 14 | torch.nn.init.xavier_normal_(m.weight.data) 15 | 16 | elif classname.find("BatchNorm2d") != -1 or classname.find("InstanceNorm2d") != -1: 17 | torch.nn.init.normal_(m.weight.data, 1.0, 0.02) 18 | torch.nn.init.constant_(m.bias.data, 0.0) 19 | 20 | class resnet18_224(nn.Module): 21 | 22 | def __init__(self, out_dim=5, aug_test=False): 23 | super(resnet18_224, self).__init__() 24 | 25 | self.aug_test = aug_test 26 | net = models.resnet18(pretrained=True) 27 | # self.mean = torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).cuda() 28 | # self.std = torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).cuda() 29 | 30 | self.upsample = nn.Upsample(size=(224,224),mode='bilinear') 31 | net.fc = nn.Linear(512, out_dim) 32 | self.model = net 33 | 34 | 35 | def forward(self, x): 36 | 37 | x = self.upsample(x) 38 | if self.aug_test: 39 | # x = torch.cat((x, torch.rot90(x, 1, [2, 3]), torch.rot90(x, 3, [2, 3])), 0) 40 | x = torch.cat((x, torch.flip(x, [3])), 0) 41 | f = self.model(x) 42 | 43 | return f 44 | 45 | ############################## 46 | # DPE 47 | ############################## 48 | 49 | 50 | class UNetDown(nn.Module): 51 | def __init__(self, in_size, out_size, normalize=True, dropout=0.0): 52 | super(UNetDown, self).__init__() 53 | layers = [nn.Conv2d(in_size, out_size, 5, 2, 2)] 54 | layers.append(nn.SELU(inplace=True)) 55 | if normalize: 56 | #layers.append(nn.BatchNorm2d(out_size)) 57 | nn.InstanceNorm2d(out_size, affine = True) 58 | if dropout: 59 | layers.append(nn.Dropout(dropout)) 60 | self.model = nn.Sequential(*layers) 61 | 62 | def forward(self, x): 63 | return self.model(x) 64 | 65 | 66 | class UNetUp(nn.Module): 67 | def __init__(self, in_size, out_size, normalize=True, dropout=0.0): 68 | super(UNetUp, self).__init__() 69 | layers = [ 70 | nn.Upsample(scale_factor=2, mode = 
'bilinear', align_corners=True), 71 | nn.Conv2d(in_size, out_size, 3, padding=1), 72 | nn.SELU(inplace=True), 73 | ] 74 | 75 | if normalize: 76 | #layers.append(nn.BatchNorm2d(out_size)) 77 | nn.InstanceNorm2d(out_size, affine = True) 78 | 79 | if dropout: 80 | layers.append(nn.Dropout(dropout)) 81 | 82 | self.model = nn.Sequential(*layers) 83 | 84 | def forward(self, x, skip_input): 85 | x = self.model(x) 86 | x = torch.cat((x, skip_input), 1) 87 | 88 | return x 89 | 90 | 91 | class GeneratorUNet(nn.Module): 92 | def __init__(self, in_channels=3, out_channels=3): 93 | super(GeneratorUNet, self).__init__() 94 | 95 | self.conv1 = nn.Sequential( 96 | nn.Conv2d(3, 16, 3, padding=1), 97 | nn.SELU(inplace=True), 98 | #nn.BatchNorm2d(16), 99 | nn.InstanceNorm2d(16, affine = True), 100 | ) 101 | self.down1 = UNetDown(16, 32) 102 | self.down2 = UNetDown(32, 64) 103 | self.down3 = UNetDown(64, 128) 104 | self.down4 = UNetDown(128, 128) 105 | self.down5 = UNetDown(128, 128) 106 | self.down6 = UNetDown(128, 128) 107 | self.down7 = nn.Sequential( 108 | nn.Conv2d(128, 128, 3, padding=1), 109 | nn.SELU(inplace=True), 110 | nn.Conv2d(128, 128, 1, padding=0), 111 | ) 112 | 113 | self.upsample = nn.Upsample(scale_factor=4, mode = 'bilinear') 114 | self.conv1x1 = nn.Conv2d(256, 128, 1, padding=0) 115 | 116 | self.up1 = UNetUp(128, 128) 117 | self.up2 = UNetUp(256, 128) 118 | self.up3 = UNetUp(192, 64) 119 | self.up4 = UNetUp(96, 32) 120 | 121 | self.final = nn.Sequential( 122 | nn.Conv2d(48, 16, 3, padding=1), 123 | nn.SELU(inplace=True), 124 | #nn.BatchNorm2d(16), 125 | #nn.InstanceNorm2d(16, affine = True), 126 | nn.Conv2d(16, out_channels, 3, padding=1), 127 | #nn.Tanh(), 128 | ) 129 | 130 | def forward(self, x): 131 | # U-Net generator with skip connections from encoder to decoder 132 | 133 | x1 = self.conv1(x) 134 | d1 = self.down1(x1) 135 | d2 = self.down2(d1) 136 | d3 = self.down3(d2) 137 | d4 = self.down4(d3) 138 | 139 | d5 = self.down5(d4) 140 | d6 = self.down6(d5) 141 | d7 = self.down7(d6) 142 | 143 | d8 = self.upsample(d7) 144 | d9 = torch.cat((d4, d8), 1) 145 | d9 = self.conv1x1(d9) 146 | 147 | u1 = self.up1(d9, d3) 148 | u2 = self.up2(u1, d2) 149 | u3 = self.up3(u2, d1) 150 | u4 = self.up4(u3, x1) 151 | 152 | return torch.add(self.final(u4), x) 153 | 154 | 155 | class Discriminator_UNet(nn.Module): 156 | def __init__(self, in_channels=3): 157 | super(Discriminator_UNet, self).__init__() 158 | 159 | self.model = nn.Sequential( 160 | nn.Conv2d(3, 16, 3, stride=2, padding=1), 161 | nn.LeakyReLU(0.2), 162 | nn.InstanceNorm2d(16, affine=True), 163 | *discriminator_block(16, 32), 164 | *discriminator_block(32, 64), 165 | *discriminator_block(64, 128), 166 | *discriminator_block(128, 128), 167 | *discriminator_block(128, 128), 168 | nn.Conv2d(128, 1, 4, padding=0) 169 | ) 170 | 171 | def forward(self, img_input): 172 | return self.model(img_input) 173 | 174 | ############################## 175 | # Discriminator 176 | ############################## 177 | 178 | 179 | def discriminator_block(in_filters, out_filters, normalization=False): 180 | """Returns downsampling layers of each discriminator block""" 181 | layers = [nn.Conv2d(in_filters, out_filters, 3, stride=2, padding=1)] 182 | layers.append(nn.LeakyReLU(0.2)) 183 | if normalization: 184 | layers.append(nn.InstanceNorm2d(out_filters, affine=True)) 185 | #layers.append(nn.BatchNorm2d(out_filters)) 186 | 187 | return layers 188 | 189 | class Discriminator(nn.Module): 190 | def __init__(self, in_channels=3): 191 | super(Discriminator, 
self).__init__() 192 | 193 | self.model = nn.Sequential( 194 | nn.Upsample(size=(256,256),mode='bilinear'), 195 | nn.Conv2d(3, 16, 3, stride=2, padding=1), 196 | nn.LeakyReLU(0.2), 197 | nn.InstanceNorm2d(16, affine=True), 198 | *discriminator_block(16, 32), 199 | *discriminator_block(32, 64), 200 | *discriminator_block(64, 128), 201 | *discriminator_block(128, 128), 202 | #*discriminator_block(128, 128), 203 | nn.Conv2d(128, 1, 8, padding=0) 204 | ) 205 | 206 | def forward(self, img_input): 207 | return self.model(img_input) 208 | 209 | class Classifier(nn.Module): 210 | def __init__(self, in_channels=3): 211 | super(Classifier, self).__init__() 212 | 213 | self.model = nn.Sequential( 214 | nn.Upsample(size=(256,256),mode='bilinear'), 215 | nn.Conv2d(3, 16, 3, stride=2, padding=1), 216 | nn.LeakyReLU(0.2), 217 | nn.InstanceNorm2d(16, affine=True), 218 | *discriminator_block(16, 32, normalization=True), 219 | *discriminator_block(32, 64, normalization=True), 220 | *discriminator_block(64, 128, normalization=True), 221 | *discriminator_block(128, 128), 222 | #*discriminator_block(128, 128, normalization=True), 223 | nn.Dropout(p=0.5), 224 | nn.Conv2d(128, 3, 8, padding=0), 225 | ) 226 | 227 | def forward(self, img_input): 228 | return self.model(img_input) 229 | 230 | class Classifier_unpaired(nn.Module): 231 | def __init__(self, in_channels=3): 232 | super(Classifier_unpaired, self).__init__() 233 | 234 | self.model = nn.Sequential( 235 | nn.Upsample(size=(256,256),mode='bilinear'), 236 | nn.Conv2d(3, 16, 3, stride=2, padding=1), 237 | nn.LeakyReLU(0.2), 238 | nn.InstanceNorm2d(16, affine=True), 239 | *discriminator_block(16, 32), 240 | *discriminator_block(32, 64), 241 | *discriminator_block(64, 128), 242 | *discriminator_block(128, 128), 243 | #*discriminator_block(128, 128), 244 | nn.Conv2d(128, 3, 8, padding=0), 245 | ) 246 | 247 | def forward(self, img_input): 248 | return self.model(img_input) 249 | 250 | 251 | class Generator3DLUT_identity(nn.Module): 252 | def __init__(self, dim=33): 253 | super(Generator3DLUT_identity, self).__init__() 254 | if dim == 33: 255 | file = open("IdentityLUT33.txt",'r') 256 | elif dim == 64: 257 | file = open("IdentityLUT64.txt",'r') 258 | LUT = file.readlines() 259 | self.LUT = torch.zeros(3,dim,dim,dim, dtype=torch.float) 260 | 261 | for i in range(0,dim): 262 | for j in range(0,dim): 263 | for k in range(0,dim): 264 | n = i * dim*dim + j * dim + k 265 | x = LUT[n].split() 266 | self.LUT[0,i,j,k] = float(x[0]) 267 | self.LUT[1,i,j,k] = float(x[1]) 268 | self.LUT[2,i,j,k] = float(x[2]) 269 | self.LUT = nn.Parameter(torch.tensor(self.LUT)) 270 | self.TrilinearInterpolation = TrilinearInterpolation() 271 | 272 | def forward(self, x): 273 | 274 | return self.TrilinearInterpolation(self.LUT, x) 275 | 276 | 277 | class Generator3DLUT_zero(nn.Module): 278 | def __init__(self, dim=33): 279 | super(Generator3DLUT_zero, self).__init__() 280 | 281 | self.LUT = torch.zeros(3,dim,dim,dim, dtype=torch.float) 282 | self.LUT = nn.Parameter(torch.tensor(self.LUT)) 283 | self.TrilinearInterpolation = TrilinearInterpolation() 284 | 285 | def forward(self, x): 286 | 287 | return self.TrilinearInterpolation(self.LUT, x) 288 | 289 | class TrilinearInterpolation(torch.autograd.Function): 290 | 291 | def forward(self, LUT, x): 292 | 293 | x = x.contiguous() 294 | output = x.new(x.size()) 295 | dim = LUT.size()[-1] 296 | shift = dim ** 3 297 | binsize = 1.0001 / (dim-1) 298 | W = x.size(2) 299 | H = x.size(3) 300 | batch = x.size(0) 301 | 302 | self.x = x 303 | self.LUT = LUT 
304 | self.dim = dim 305 | self.shift = shift 306 | self.binsize = binsize 307 | self.W = W 308 | self.H = H 309 | self.batch = batch 310 | 311 | if x.is_cuda: 312 | if batch == 1: 313 | trilinear.trilinear_forward_cuda(LUT,x,output,dim,shift,binsize,W,H,batch) 314 | elif batch > 1: 315 | output = output.permute(1,0,2,3).contiguous() 316 | trilinear.trilinear_forward_cuda(LUT,x.permute(1,0,2,3).contiguous(),output,dim,shift,binsize,W,H,batch) 317 | output = output.permute(1,0,2,3).contiguous() 318 | 319 | else: 320 | trilinear.trilinear_forward(LUT,x,output,dim,shift,binsize,W,H,batch) 321 | 322 | return output 323 | 324 | def backward(self, grad_x): 325 | 326 | grad_LUT = torch.zeros(3,self.dim,self.dim,self.dim,dtype=torch.float) 327 | 328 | if grad_x.is_cuda: 329 | grad_LUT = grad_LUT.cuda() 330 | if self.batch == 1: 331 | trilinear.trilinear_backward_cuda(self.x,grad_x,grad_LUT,self.dim,self.shift,self.binsize,self.W,self.H,self.batch) 332 | elif self.batch > 1: 333 | trilinear.trilinear_backward_cuda(self.x.permute(1,0,2,3).contiguous(),grad_x.permute(1,0,2,3).contiguous(),grad_LUT,self.dim,self.shift,self.binsize,self.W,self.H,self.batch) 334 | else: 335 | trilinear.trilinear_backward(self.x,grad_x,grad_LUT,self.dim,self.shift,self.binsize,self.W,self.H,self.batch) 336 | 337 | return grad_LUT, None 338 | 339 | 340 | class TV_3D(nn.Module): 341 | def __init__(self, dim=33): 342 | super(TV_3D,self).__init__() 343 | 344 | self.weight_r = torch.ones(3,dim,dim,dim-1, dtype=torch.float) 345 | self.weight_r[:,:,:,(0,dim-2)] *= 2.0 346 | self.weight_g = torch.ones(3,dim,dim-1,dim, dtype=torch.float) 347 | self.weight_g[:,:,(0,dim-2),:] *= 2.0 348 | self.weight_b = torch.ones(3,dim-1,dim,dim, dtype=torch.float) 349 | self.weight_b[:,(0,dim-2),:,:] *= 2.0 350 | self.relu = torch.nn.ReLU() 351 | 352 | def forward(self, LUT): 353 | 354 | dif_r = LUT.LUT[:,:,:,:-1] - LUT.LUT[:,:,:,1:] 355 | dif_g = LUT.LUT[:,:,:-1,:] - LUT.LUT[:,:,1:,:] 356 | dif_b = LUT.LUT[:,:-1,:,:] - LUT.LUT[:,1:,:,:] 357 | tv = torch.mean(torch.mul((dif_r ** 2),self.weight_r)) + torch.mean(torch.mul((dif_g ** 2),self.weight_g)) + torch.mean(torch.mul((dif_b ** 2),self.weight_b)) 358 | 359 | mn = torch.mean(self.relu(dif_r)) + torch.mean(self.relu(dif_g)) + torch.mean(self.relu(dif_b)) 360 | 361 | return tv, mn 362 | 363 | 364 | -------------------------------------------------------------------------------- /models_x.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | import torchvision.models as models 4 | import torchvision.transforms as transforms 5 | from torch.autograd import Variable 6 | import torch 7 | import numpy as np 8 | import math 9 | import trilinear 10 | 11 | def weights_init_normal_classifier(m): 12 | classname = m.__class__.__name__ 13 | if classname.find("Conv") != -1: 14 | torch.nn.init.xavier_normal_(m.weight.data) 15 | 16 | elif classname.find("BatchNorm2d") != -1 or classname.find("InstanceNorm2d") != -1: 17 | torch.nn.init.normal_(m.weight.data, 1.0, 0.02) 18 | torch.nn.init.constant_(m.bias.data, 0.0) 19 | 20 | class resnet18_224(nn.Module): 21 | 22 | def __init__(self, out_dim=5, aug_test=False): 23 | super(resnet18_224, self).__init__() 24 | 25 | self.aug_test = aug_test 26 | net = models.resnet18(pretrained=True) 27 | # self.mean = torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).cuda() 28 | # self.std = torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).cuda() 29 | 30 | self.upsample = 
nn.Upsample(size=(224,224),mode='bilinear') 31 | net.fc = nn.Linear(512, out_dim) 32 | self.model = net 33 | 34 | 35 | def forward(self, x): 36 | 37 | x = self.upsample(x) 38 | if self.aug_test: 39 | # x = torch.cat((x, torch.rot90(x, 1, [2, 3]), torch.rot90(x, 3, [2, 3])), 0) 40 | x = torch.cat((x, torch.flip(x, [3])), 0) 41 | f = self.model(x) 42 | 43 | return f 44 | 45 | ############################## 46 | # Discriminator 47 | ############################## 48 | 49 | 50 | def discriminator_block(in_filters, out_filters, normalization=False): 51 | """Returns downsampling layers of each discriminator block""" 52 | layers = [nn.Conv2d(in_filters, out_filters, 3, stride=2, padding=1)] 53 | layers.append(nn.LeakyReLU(0.2)) 54 | if normalization: 55 | layers.append(nn.InstanceNorm2d(out_filters, affine=True)) 56 | #layers.append(nn.BatchNorm2d(out_filters)) 57 | 58 | return layers 59 | 60 | class Discriminator(nn.Module): 61 | def __init__(self, in_channels=3): 62 | super(Discriminator, self).__init__() 63 | 64 | self.model = nn.Sequential( 65 | nn.Upsample(size=(256,256),mode='bilinear'), 66 | nn.Conv2d(3, 16, 3, stride=2, padding=1), 67 | nn.LeakyReLU(0.2), 68 | nn.InstanceNorm2d(16, affine=True), 69 | *discriminator_block(16, 32), 70 | *discriminator_block(32, 64), 71 | *discriminator_block(64, 128), 72 | *discriminator_block(128, 128), 73 | #*discriminator_block(128, 128), 74 | nn.Conv2d(128, 1, 8, padding=0) 75 | ) 76 | 77 | def forward(self, img_input): 78 | return self.model(img_input) 79 | 80 | class Classifier(nn.Module): 81 | def __init__(self, in_channels=3): 82 | super(Classifier, self).__init__() 83 | 84 | self.model = nn.Sequential( 85 | nn.Upsample(size=(256,256),mode='bilinear'), 86 | nn.Conv2d(3, 16, 3, stride=2, padding=1), 87 | nn.LeakyReLU(0.2), 88 | nn.InstanceNorm2d(16, affine=True), 89 | *discriminator_block(16, 32, normalization=True), 90 | *discriminator_block(32, 64, normalization=True), 91 | *discriminator_block(64, 128, normalization=True), 92 | *discriminator_block(128, 128), 93 | #*discriminator_block(128, 128, normalization=True), 94 | nn.Dropout(p=0.5), 95 | nn.Conv2d(128, 3, 8, padding=0), 96 | ) 97 | 98 | def forward(self, img_input): 99 | return self.model(img_input) 100 | 101 | class Classifier_unpaired(nn.Module): 102 | def __init__(self, in_channels=3): 103 | super(Classifier_unpaired, self).__init__() 104 | 105 | self.model = nn.Sequential( 106 | nn.Upsample(size=(256,256),mode='bilinear'), 107 | nn.Conv2d(3, 16, 3, stride=2, padding=1), 108 | nn.LeakyReLU(0.2), 109 | nn.InstanceNorm2d(16, affine=True), 110 | *discriminator_block(16, 32), 111 | *discriminator_block(32, 64), 112 | *discriminator_block(64, 128), 113 | *discriminator_block(128, 128), 114 | #*discriminator_block(128, 128), 115 | nn.Conv2d(128, 3, 8, padding=0), 116 | ) 117 | 118 | def forward(self, img_input): 119 | return self.model(img_input) 120 | 121 | 122 | class Generator3DLUT_identity(nn.Module): 123 | def __init__(self, dim=33): 124 | super(Generator3DLUT_identity, self).__init__() 125 | if dim == 33: 126 | file = open("IdentityLUT33.txt", 'r') 127 | elif dim == 64: 128 | file = open("IdentityLUT64.txt", 'r') 129 | lines = file.readlines() 130 | buffer = np.zeros((3,dim,dim,dim), dtype=np.float32) 131 | 132 | for i in range(0,dim): 133 | for j in range(0,dim): 134 | for k in range(0,dim): 135 | n = i * dim*dim + j * dim + k 136 | x = lines[n].split() 137 | buffer[0,i,j,k] = float(x[0]) 138 | buffer[1,i,j,k] = float(x[1]) 139 | buffer[2,i,j,k] = float(x[2]) 140 | self.LUT = 
nn.Parameter(torch.from_numpy(buffer).requires_grad_(True)) 141 | self.TrilinearInterpolation = TrilinearInterpolation() 142 | 143 | def forward(self, x): 144 | _, output = self.TrilinearInterpolation(self.LUT, x) 145 | #self.LUT, output = self.TrilinearInterpolation(self.LUT, x) 146 | return output 147 | 148 | class Generator3DLUT_zero(nn.Module): 149 | def __init__(self, dim=33): 150 | super(Generator3DLUT_zero, self).__init__() 151 | 152 | self.LUT = torch.zeros(3,dim,dim,dim, dtype=torch.float) 153 | self.LUT = nn.Parameter(torch.tensor(self.LUT)) 154 | self.TrilinearInterpolation = TrilinearInterpolation() 155 | 156 | def forward(self, x): 157 | _, output = self.TrilinearInterpolation(self.LUT, x) 158 | 159 | return output 160 | 161 | class TrilinearInterpolationFunction(torch.autograd.Function): 162 | @staticmethod 163 | def forward(ctx, lut, x): 164 | x = x.contiguous() 165 | 166 | output = x.new(x.size()) 167 | dim = lut.size()[-1] 168 | shift = dim ** 3 169 | binsize = 1.000001 / (dim-1) 170 | W = x.size(2) 171 | H = x.size(3) 172 | batch = x.size(0) 173 | 174 | assert 1 == trilinear.forward(lut, 175 | x, 176 | output, 177 | dim, 178 | shift, 179 | binsize, 180 | W, 181 | H, 182 | batch) 183 | 184 | int_package = torch.IntTensor([dim, shift, W, H, batch]) 185 | float_package = torch.FloatTensor([binsize]) 186 | variables = [lut, x, int_package, float_package] 187 | 188 | ctx.save_for_backward(*variables) 189 | 190 | return lut, output 191 | 192 | @staticmethod 193 | def backward(ctx, lut_grad, x_grad): 194 | 195 | lut, x, int_package, float_package = ctx.saved_variables 196 | dim, shift, W, H, batch = int_package 197 | dim, shift, W, H, batch = int(dim), int(shift), int(W), int(H), int(batch) 198 | binsize = float(float_package[0]) 199 | 200 | assert 1 == trilinear.backward(x, 201 | x_grad, 202 | lut_grad, 203 | dim, 204 | shift, 205 | binsize, 206 | W, 207 | H, 208 | batch) 209 | return lut_grad, x_grad 210 | 211 | 212 | class TrilinearInterpolation(torch.nn.Module): 213 | def __init__(self): 214 | super(TrilinearInterpolation, self).__init__() 215 | 216 | def forward(self, lut, x): 217 | return TrilinearInterpolationFunction.apply(lut, x) 218 | 219 | 220 | class TV_3D(nn.Module): 221 | def __init__(self, dim=33): 222 | super(TV_3D,self).__init__() 223 | 224 | self.weight_r = torch.ones(3,dim,dim,dim-1, dtype=torch.float) 225 | self.weight_r[:,:,:,(0,dim-2)] *= 2.0 226 | self.weight_g = torch.ones(3,dim,dim-1,dim, dtype=torch.float) 227 | self.weight_g[:,:,(0,dim-2),:] *= 2.0 228 | self.weight_b = torch.ones(3,dim-1,dim,dim, dtype=torch.float) 229 | self.weight_b[:,(0,dim-2),:,:] *= 2.0 230 | self.relu = torch.nn.ReLU() 231 | 232 | def forward(self, LUT): 233 | 234 | dif_r = LUT.LUT[:,:,:,:-1] - LUT.LUT[:,:,:,1:] 235 | dif_g = LUT.LUT[:,:,:-1,:] - LUT.LUT[:,:,1:,:] 236 | dif_b = LUT.LUT[:,:-1,:,:] - LUT.LUT[:,1:,:,:] 237 | tv = torch.mean(torch.mul((dif_r ** 2),self.weight_r)) + torch.mean(torch.mul((dif_g ** 2),self.weight_g)) + torch.mean(torch.mul((dif_b ** 2),self.weight_b)) 238 | 239 | mn = torch.mean(self.relu(dif_r)) + torch.mean(self.relu(dif_g)) + torch.mean(self.relu(dif_b)) 240 | 241 | return tv, mn 242 | 243 | 244 | -------------------------------------------------------------------------------- /pretrained_models/XYZ/LUTs.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HuiZeng/Image-Adaptive-3DLUT/b491f6df64a588864739a157db271e5c848e1805/pretrained_models/XYZ/LUTs.pth 
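The checkpoints under pretrained_models/ come in pairs: LUTs.pth stores a dict with keys "0", "1" and "2" (one state_dict per learned LUT, exactly as saved by the training scripts above), and classifier.pth stores the weight predictor. The snippet below is a minimal inference sketch, not the repository's own code (see demo_eval.py and image_adaptive_lut_evaluation.py for that): it mirrors generator_eval() from the paired training script, assumes the trilinear extension from trilinear_cpp has been built, and should be run from the repo root so that IdentityLUT33.txt is found; the enhance() helper and the sRGB paths are illustrative.

import torch
from models_x import Classifier, Generator3DLUT_identity, Generator3DLUT_zero, TrilinearInterpolation

# Build the three LUTs and the weight predictor, then load a paired checkpoint.
LUT0, LUT1, LUT2 = Generator3DLUT_identity(), Generator3DLUT_zero(), Generator3DLUT_zero()
classifier = Classifier()
luts = torch.load("pretrained_models/sRGB/LUTs.pth", map_location="cpu")
LUT0.load_state_dict(luts["0"]); LUT1.load_state_dict(luts["1"]); LUT2.load_state_dict(luts["2"])
classifier.load_state_dict(torch.load("pretrained_models/sRGB/classifier.pth", map_location="cpu"))
classifier.eval()
trilinear_ = TrilinearInterpolation()

def enhance(img):
    # img: 1 x 3 x H x W float tensor in [0, 1]
    with torch.no_grad():
        pred = classifier(img).squeeze()   # three image-adaptive fusion weights
        lut = pred[0] * LUT0.LUT + pred[1] * LUT1.LUT + pred[2] * LUT2.LUT
        _, out = trilinear_(lut, img)      # one trilinear lookup of the fused LUT
    return out

Because the fusion happens in LUT space, test time needs only a single 3D lookup per pixel regardless of how many basis LUTs are learned, which is why generator_eval() fuses the LUTs before interpolation while generator_train() blends the per-LUT outputs instead.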
-------------------------------------------------------------------------------- /pretrained_models/XYZ/LUTs_unpaired.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HuiZeng/Image-Adaptive-3DLUT/b491f6df64a588864739a157db271e5c848e1805/pretrained_models/XYZ/LUTs_unpaired.pth -------------------------------------------------------------------------------- /pretrained_models/XYZ/classifier.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HuiZeng/Image-Adaptive-3DLUT/b491f6df64a588864739a157db271e5c848e1805/pretrained_models/XYZ/classifier.pth -------------------------------------------------------------------------------- /pretrained_models/XYZ/classifier_unpaired.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HuiZeng/Image-Adaptive-3DLUT/b491f6df64a588864739a157db271e5c848e1805/pretrained_models/XYZ/classifier_unpaired.pth -------------------------------------------------------------------------------- /pretrained_models/sRGB/LUTs.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HuiZeng/Image-Adaptive-3DLUT/b491f6df64a588864739a157db271e5c848e1805/pretrained_models/sRGB/LUTs.pth -------------------------------------------------------------------------------- /pretrained_models/sRGB/LUTs_unpaired.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HuiZeng/Image-Adaptive-3DLUT/b491f6df64a588864739a157db271e5c848e1805/pretrained_models/sRGB/LUTs_unpaired.pth -------------------------------------------------------------------------------- /pretrained_models/sRGB/classifier.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HuiZeng/Image-Adaptive-3DLUT/b491f6df64a588864739a157db271e5c848e1805/pretrained_models/sRGB/classifier.pth -------------------------------------------------------------------------------- /pretrained_models/sRGB/classifier_unpaired.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HuiZeng/Image-Adaptive-3DLUT/b491f6df64a588864739a157db271e5c848e1805/pretrained_models/sRGB/classifier_unpaired.pth -------------------------------------------------------------------------------- /requirements: -------------------------------------------------------------------------------- 1 | numpy==1.19.2 2 | PIL==6.1.0 3 | torch==0.4.1 4 | torchvision==0.2.2 5 | opencv-python==3.4.3 6 | -------------------------------------------------------------------------------- /ssim.m: -------------------------------------------------------------------------------- 1 | function [mssim, ssim_map] = ssim(img1, img2, K, window, L) 2 | 3 | % ======================================================================== 4 | % SSIM Index with automatic downsampling, Version 1.0 5 | % Copyright(c) 2009 Zhou Wang 6 | % All Rights Reserved. 7 | % 8 | % ---------------------------------------------------------------------- 9 | % Permission to use, copy, or modify this software and its documentation 10 | % for educational and research purposes only and without fee is hereby 11 | % granted, provided that this copyright notice and the original authors' 12 | % names appear on all copies and supporting documentation. 
This program 13 | % shall not be used, rewritten, or adapted as the basis of a commercial 14 | % software or hardware product without first obtaining permission of the 15 | % authors. The authors make no representations about the suitability of 16 | % this software for any purpose. It is provided "as is" without express 17 | % or implied warranty. 18 | %---------------------------------------------------------------------- 19 | % 20 | % This is an implementation of the algorithm for calculating the 21 | % Structural SIMilarity (SSIM) index between two images 22 | % 23 | % Please refer to the following paper and the website with suggested usage 24 | % 25 | % Z. Wang, A. C. Bovik, H. R. Sheikh, and E. P. Simoncelli, "Image 26 | % quality assessment: From error visibility to structural similarity," 27 | % IEEE Transactios on Image Processing, vol. 13, no. 4, pp. 600-612, 28 | % Apr. 2004. 29 | % 30 | % http://www.ece.uwaterloo.ca/~z70wang/research/ssim/ 31 | % 32 | % Note: This program is different from ssim_index.m, where no automatic 33 | % downsampling is performed. (downsampling was done in the above paper 34 | % and was described as suggested usage in the above website.) 35 | % 36 | % Kindly report any suggestions or corrections to zhouwang@ieee.org 37 | % 38 | %---------------------------------------------------------------------- 39 | % 40 | %Input : (1) img1: the first image being compared 41 | % (2) img2: the second image being compared 42 | % (3) K: constants in the SSIM index formula (see the above 43 | % reference). defualt value: K = [0.01 0.03] 44 | % (4) window: local window for statistics (see the above 45 | % reference). default widnow is Gaussian given by 46 | % window = fspecial('gaussian', 11, 1.5); 47 | % (5) L: dynamic range of the images. default: L = 255 48 | % 49 | %Output: (1) mssim: the mean SSIM index value between 2 images. 50 | % If one of the images being compared is regarded as 51 | % perfect quality, then mssim can be considered as the 52 | % quality measure of the other image. 53 | % If img1 = img2, then mssim = 1. 54 | % (2) ssim_map: the SSIM index map of the test image. The map 55 | % has a smaller size than the input images. The actual size 56 | % depends on the window size and the downsampling factor. 57 | % 58 | %Basic Usage: 59 | % Given 2 test images img1 and img2, whose dynamic range is 0-255 60 | % 61 | % [mssim, ssim_map] = ssim(img1, img2); 62 | % 63 | %Advanced Usage: 64 | % User defined parameters. 
For example 65 | % 66 | % K = [0.05 0.05]; 67 | % window = ones(8); 68 | % L = 100; 69 | % [mssim, ssim_map] = ssim(img1, img2, K, window, L); 70 | % 71 | %Visualize the results: 72 | % 73 | % mssim %Gives the mssim value 74 | % imshow(max(0, ssim_map).^4) %Shows the SSIM index map 75 | %======================================================================== 76 | 77 | 78 | if (nargin < 2 || nargin > 5) 79 | mssim = -Inf; 80 | ssim_map = -Inf; 81 | return; 82 | end 83 | 84 | if (size(img1) ~= size(img2)) 85 | mssim = -Inf; 86 | ssim_map = -Inf; 87 | return; 88 | end 89 | 90 | [M N] = size(img1); 91 | 92 | if (nargin == 2) 93 | if ((M < 11) || (N < 11)) 94 | mssim = -Inf; 95 | ssim_map = -Inf; 96 | return 97 | end 98 | window = fspecial('gaussian', 11, 1.5); % 99 | K(1) = 0.01; % default settings 100 | K(2) = 0.03; % 101 | L = 255; % 102 | end 103 | 104 | if (nargin == 3) 105 | if ((M < 11) || (N < 11)) 106 | mssim = -Inf; 107 | ssim_map = -Inf; 108 | return 109 | end 110 | window = fspecial('gaussian', 11, 1.5); 111 | L = 255; 112 | if (length(K) == 2) 113 | if (K(1) < 0 || K(2) < 0) 114 | mssim = -Inf; 115 | ssim_map = -Inf; 116 | return; 117 | end 118 | else 119 | mssim = -Inf; 120 | ssim_map = -Inf; 121 | return; 122 | end 123 | end 124 | 125 | if (nargin == 4) 126 | [H W] = size(window); 127 | if ((H*W) < 4 || (H > M) || (W > N)) 128 | mssim = -Inf; 129 | ssim_map = -Inf; 130 | return 131 | end 132 | L = 255; 133 | if (length(K) == 2) 134 | if (K(1) < 0 || K(2) < 0) 135 | mssim = -Inf; 136 | ssim_map = -Inf; 137 | return; 138 | end 139 | else 140 | mssim = -Inf; 141 | ssim_map = -Inf; 142 | return; 143 | end 144 | end 145 | 146 | if (nargin == 5) 147 | [H W] = size(window); 148 | if ((H*W) < 4 || (H > M) || (W > N)) 149 | mssim = -Inf; 150 | ssim_map = -Inf; 151 | return 152 | end 153 | if (length(K) == 2) 154 | if (K(1) < 0 || K(2) < 0) 155 | mssim = -Inf; 156 | ssim_map = -Inf; 157 | return; 158 | end 159 | else 160 | mssim = -Inf; 161 | ssim_map = -Inf; 162 | return; 163 | end 164 | end 165 | 166 | 167 | img1 = double(img1); 168 | img2 = double(img2); 169 | 170 | % automatic downsampling 171 | f = max(1,round(min(M,N)/256)); 172 | %downsampling by f 173 | %use a simple low-pass filter 174 | if(f>1) 175 | lpf = ones(f,f); 176 | lpf = lpf/sum(lpf(:)); 177 | img1 = imfilter(img1,lpf,'symmetric','same'); 178 | img2 = imfilter(img2,lpf,'symmetric','same'); 179 | 180 | img1 = img1(1:f:end,1:f:end); 181 | img2 = img2(1:f:end,1:f:end); 182 | end 183 | 184 | C1 = (K(1)*L)^2; 185 | C2 = (K(2)*L)^2; 186 | window = window/sum(sum(window)); 187 | 188 | mu1 = filter2(window, img1, 'valid'); 189 | mu2 = filter2(window, img2, 'valid'); 190 | mu1_sq = mu1.*mu1; 191 | mu2_sq = mu2.*mu2; 192 | mu1_mu2 = mu1.*mu2; 193 | sigma1_sq = filter2(window, img1.*img1, 'valid') - mu1_sq; 194 | sigma2_sq = filter2(window, img2.*img2, 'valid') - mu2_sq; 195 | sigma12 = filter2(window, img1.*img2, 'valid') - mu1_mu2; 196 | 197 | if (C1 > 0 && C2 > 0) 198 | ssim_map = ((2*mu1_mu2 + C1).*(2*sigma12 + C2))./((mu1_sq + mu2_sq + C1).*(sigma1_sq + sigma2_sq + C2)); 199 | else 200 | numerator1 = 2*mu1_mu2 + C1; 201 | numerator2 = 2*sigma12 + C2; 202 | denominator1 = mu1_sq + mu2_sq + C1; 203 | denominator2 = sigma1_sq + sigma2_sq + C2; 204 | ssim_map = ones(size(mu1)); 205 | index = (denominator1.*denominator2 > 0); 206 | ssim_map(index) = (numerator1(index).*numerator2(index))./(denominator1(index).*denominator2(index)); 207 | index = (denominator1 ~= 0) & (denominator2 == 0); 208 | ssim_map(index) = 
numerator1(index)./denominator1(index); 209 | end 210 | 211 | mssim = mean2(ssim_map); 212 | 213 | return -------------------------------------------------------------------------------- /torchvision_x_functional.py: -------------------------------------------------------------------------------- 1 | import collections 2 | import numbers 3 | from functools import wraps 4 | 5 | import cv2 6 | import numpy as np 7 | import torch 8 | from PIL import Image 9 | from scipy.ndimage.filters import gaussian_filter 10 | 11 | __numpy_type_map = { 12 | 'float64': torch.DoubleTensor, 13 | 'float32': torch.FloatTensor, 14 | 'float16': torch.HalfTensor, 15 | 'int64': torch.LongTensor, 16 | 'int32': torch.IntTensor, 17 | 'int16': torch.ShortTensor, 18 | 'uint16': torch.ShortTensor, 19 | 'int8': torch.CharTensor, 20 | 'uint8': torch.ByteTensor, 21 | } 22 | 23 | '''image functional utils 24 | 25 | ''' 26 | 27 | # NOTE: all the function should recive the ndarray like image, should be W x H x C or W x H 28 | 29 | # 如果将所有输出的维度够搞成height,width,channel 那么可以不用to_tensor??, 不行 30 | def preserve_channel_dim(func): 31 | """Preserve dummy channel dim.""" 32 | @wraps(func) 33 | def wrapped_function(img, *args, **kwargs): 34 | shape = img.shape 35 | result = func(img, *args, **kwargs) 36 | if len(shape) == 3 and shape[-1] == 1 and len(result.shape) == 2: 37 | result = np.expand_dims(result, axis=-1) 38 | return result 39 | 40 | return wrapped_function 41 | 42 | 43 | def _is_tensor_image(img): 44 | return torch.is_tensor(img) and img.ndimension() == 3 45 | 46 | 47 | def _is_numpy_image(img): 48 | return isinstance(img, np.ndarray) and (img.ndim in {2, 3}) 49 | 50 | 51 | def to_tensor(img): 52 | '''convert numpy.ndarray to torch tensor. \n 53 | if the image is uint8 , it will be divided by 255;\n 54 | if the image is uint16 , it will be divided by 65535;\n 55 | if the image is float , it will not be divided, we suppose your image range should between [0~1] ;\n 56 | 57 | Arguments: 58 | img {numpy.ndarray} -- image to be converted to tensor. 59 | ''' 60 | if not _is_numpy_image(img): 61 | raise TypeError('data should be numpy ndarray. but got {}'.format(type(img))) 62 | 63 | if img.ndim == 2: 64 | img = img[:, :, None] 65 | 66 | if img.dtype == np.uint8: 67 | img = img.astype(np.float32)/255 68 | elif img.dtype == np.uint16: 69 | img = img.astype(np.float32)/65535 70 | elif img.dtype in [np.float32, np.float64]: 71 | img = img.astype(np.float32)/1 72 | else: 73 | raise TypeError('{} is not support'.format(img.dtype)) 74 | 75 | img = torch.from_numpy(img.transpose((2, 0, 1))) 76 | 77 | return img 78 | 79 | 80 | def to_pil_image(tensor): 81 | # TODO 82 | pass 83 | 84 | 85 | def to_tiff_image(tensor): 86 | # TODO 87 | pass 88 | 89 | 90 | def normalize(tensor, mean, std, inplace=False): 91 | """Normalize a tensor image with mean and standard deviation. 92 | 93 | .. note:: 94 | This transform acts out of place by default, i.e., it does not mutates the input tensor. 95 | 96 | See :class:`~torchsat.transforms.Normalize` for more details. 97 | 98 | Args: 99 | tensor (Tensor): Tensor image of size (C, H, W) to be normalized. 100 | mean (sequence): Sequence of means for each channel. 101 | std (sequence): Sequence of standard deviations for each channel. 102 | 103 | Returns: 104 | Tensor: Normalized Tensor image. 
105 | """ 106 | if not _is_tensor_image(tensor): 107 | raise TypeError('tensor is not a torch image.') 108 | 109 | if not inplace: 110 | tensor = tensor.clone() 111 | 112 | mean = torch.as_tensor(mean, dtype=tensor.dtype, device=tensor.device) 113 | std = torch.as_tensor(std, dtype=tensor.dtype, device=tensor.device) 114 | tensor.sub_(mean[:, None, None]).div_(std[:, None, None]) 115 | return tensor 116 | 117 | def noise(img, mode='gaussain', percent=0.02): 118 | """ 119 | TODO: Not good for uint16 data 120 | """ 121 | original_dtype = img.dtype 122 | if mode == 'gaussian': 123 | mean = 0 124 | var = 0.1 125 | sigma = var*0.5 126 | 127 | if img.ndim == 2: 128 | h, w = img.shape 129 | gauss = np.random.normal(mean, sigma, (h, w)) 130 | else: 131 | h, w, c = img.shape 132 | gauss = np.random.normal(mean, sigma, (h, w, c)) 133 | 134 | if img.dtype not in [np.float32, np.float64]: 135 | gauss = gauss * np.iinfo(img.dtype).max 136 | img = np.clip(img.astype(np.float) + gauss, 0, np.iinfo(img.dtype).max) 137 | else: 138 | img = np.clip(img.astype(np.float) + gauss, 0, 1) 139 | 140 | elif mode == 'salt': 141 | print(img.dtype) 142 | s_vs_p = 1 143 | num_salt = np.ceil(percent * img.size * s_vs_p) 144 | coords = tuple([np.random.randint(0, i - 1, int(num_salt)) for i in img.shape]) 145 | 146 | if img.dtype in [np.float32, np.float64]: 147 | img[coords] = 1 148 | else: 149 | img[coords] = np.iinfo(img.dtype).max 150 | print(img.dtype) 151 | elif mode == 'pepper': 152 | s_vs_p = 0 153 | num_pepper = np.ceil(percent * img.size * (1. - s_vs_p)) 154 | coords = tuple([np.random.randint(0, i - 1, int(num_pepper)) for i in img.shape]) 155 | img[coords] = 0 156 | 157 | elif mode == 's&p': 158 | s_vs_p = 0.5 159 | 160 | # Salt mode 161 | num_salt = np.ceil(percent * img.size * s_vs_p) 162 | coords = tuple([np.random.randint(0, i - 1, int(num_salt)) for i in img.shape]) 163 | if img.dtype in [np.float32, np.float64]: 164 | img[coords] = 1 165 | else: 166 | img[coords] = np.iinfo(img.dtype).max 167 | 168 | # Pepper mode 169 | num_pepper = np.ceil(percent* img.size * (1. 
- s_vs_p)) 170 | coords = tuple([np.random.randint(0, i - 1, int(num_pepper)) for i in img.shape]) 171 | img[coords] = 0 172 | else: 173 | raise ValueError('not support mode for {}'.format(mode)) 174 | 175 | noisy = img.astype(original_dtype) 176 | 177 | return noisy 178 | 179 | 180 | def gaussian_blur(img, kernel_size): 181 | # When sigma=0, it is computed as `sigma = 0.3*((ksize-1)*0.5 - 1) + 0.8` 182 | return cv2.GaussianBlur(img, (kernel_size, kernel_size), sigmaX=0) 183 | 184 | 185 | def adjust_brightness(img, value=0): 186 | if img.dtype in [np.float, np.float32, np.float64, np.float128]: 187 | dtype_min, dtype_max = 0, 1 188 | dtype = np.float32 189 | else: 190 | dtype_min = np.iinfo(img.dtype).min 191 | dtype_max = np.iinfo(img.dtype).max 192 | dtype = np.iinfo(img.dtype) 193 | 194 | result = np.clip(img.astype(np.float)+value, dtype_min, dtype_max).astype(dtype) 195 | 196 | return result 197 | 198 | 199 | def adjust_contrast(img, factor): 200 | if img.dtype in [np.float, np.float32, np.float64, np.float128]: 201 | dtype_min, dtype_max = 0, 1 202 | dtype = np.float32 203 | else: 204 | dtype_min = np.iinfo(img.dtype).min 205 | dtype_max = np.iinfo(img.dtype).max 206 | dtype = np.iinfo(img.dtype) 207 | 208 | result = np.clip(img.astype(np.float)*factor, dtype_min, dtype_max).astype(dtype) 209 | 210 | return result 211 | 212 | def adjust_saturation(): 213 | # TODO 214 | pass 215 | 216 | def adjust_hue(): 217 | # TODO 218 | pass 219 | 220 | 221 | 222 | def to_grayscale(img, output_channels=1): 223 | """convert input ndarray image to gray sacle image. 224 | 225 | Arguments: 226 | img {ndarray} -- the input ndarray image 227 | 228 | Keyword Arguments: 229 | output_channels {int} -- output gray image channel (default: {1}) 230 | 231 | Returns: 232 | ndarray -- gray scale ndarray image 233 | """ 234 | if img.ndim == 2: 235 | gray_img = img 236 | elif img.shape[2] == 3: 237 | gray_img = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY) 238 | else: 239 | gray_img = np.mean(img, axis=2) 240 | gray_img = gray_img.astype(img.dtype) 241 | 242 | if output_channels != 1: 243 | gray_img = np.tile(gray_img, (output_channels, 1, 1)) 244 | gray_img = np.transpose(gray_img, [1,2,0]) 245 | 246 | return gray_img 247 | 248 | 249 | def shift(img, top, left): 250 | (h, w) = img.shape[0:2] 251 | matrix = np.float32([[1, 0, left], [0, 1, top]]) 252 | dst = cv2.warpAffine(img, matrix, (w, h)) 253 | 254 | return dst 255 | 256 | 257 | def rotate(img, angle, center=None, scale=1.0): 258 | (h, w) = img.shape[:2] 259 | 260 | if center is None: 261 | center = (w / 2, h / 2) 262 | 263 | M = cv2.getRotationMatrix2D(center, angle, scale) 264 | rotated = cv2.warpAffine(img, M, (w, h)) 265 | 266 | return rotated 267 | 268 | 269 | def resize(img, size, interpolation=Image.BILINEAR): 270 | '''resize the image 271 | TODO: opencv resize 之后图像就成了0~1了 272 | Arguments: 273 | img {ndarray} -- the input ndarray image 274 | size {int, iterable} -- the target size, if size is intger, width and height will be resized to same \ 275 | otherwise, the size should be tuple (height, width) or list [height, width] 276 | 277 | 278 | Keyword Arguments: 279 | interpolation {Image} -- the interpolation method (default: {Image.BILINEAR}) 280 | 281 | Raises: 282 | TypeError -- img should be ndarray 283 | ValueError -- size should be intger or iterable vaiable and length should be 2. 
284 | 285 | Returns: 286 | img -- resized ndarray image 287 | ''' 288 | 289 | if not _is_numpy_image(img): 290 | raise TypeError('img should be ndarray image [w, h, c] or [w, h], but got {}'.format(type(img))) 291 | if not (isinstance(size, int) or (isinstance(size, collections.Iterable) and len(size)==2)): 292 | raise ValueError('size should be an integer or an iterable of length 2, but got {}'.format(type(size))) 293 | 294 | if isinstance(size, int): 295 | height, width = (size, size) 296 | else: 297 | height, width = (size[0], size[1]) 298 | 299 | return cv2.resize(img, (width, height), interpolation=interpolation) 300 | 301 | 302 | def pad(img, padding, fill=0, padding_mode='constant'): 303 | if isinstance(padding, int): 304 | pad_left = pad_right = pad_top = pad_bottom = padding 305 | if isinstance(padding, collections.Iterable) and len(padding) == 2: 306 | pad_left = pad_right = padding[0] 307 | pad_bottom = pad_top = padding[1] 308 | if isinstance(padding, collections.Iterable) and len(padding) == 4: 309 | pad_left = padding[0] 310 | pad_top = padding[1] 311 | pad_right = padding[2] 312 | pad_bottom = padding[3] 313 | 314 | if img.ndim == 2: 315 | if padding_mode == 'constant': 316 | img = np.pad(img, ((pad_top, pad_bottom), (pad_left, pad_right)), mode=padding_mode, constant_values=fill) 317 | else: 318 | img = np.pad(img, ((pad_top, pad_bottom), (pad_left, pad_right)), mode=padding_mode) 319 | if img.ndim == 3: 320 | if padding_mode == 'constant': 321 | img = np.pad(img, ((pad_top, pad_bottom), (pad_left, pad_right), (0, 0)), mode=padding_mode, constant_values=fill) 322 | else: 323 | img = np.pad(img, ((pad_top, pad_bottom), (pad_left, pad_right), (0, 0)), mode=padding_mode) 324 | return img 325 | 326 | 327 | def crop(img, top, left, height, width): 328 | '''crop image 329 | 330 | Arguments: 331 | img {ndarray} -- image to be cropped 332 | top {int} -- top size 333 | left {int} -- left size 334 | height {int} -- cropped height 335 | width {int} -- cropped width 336 | ''' 337 | if not _is_numpy_image(img): 338 | raise TypeError('the input image should be numpy ndarray with dimension 2 or 3, ' 339 | 'but got {}'.format(type(img)) 340 | ) 341 | 342 | if width < 0 or height < 0 or left < 0 or top < 0: 343 | raise ValueError('the input left, top, width, height should not be negative, ' 344 | 'but got left={}, top={} width={} height={}'.format(left, top, width, height) 345 | ) 346 | if img.ndim == 2: 347 | img_height, img_width = img.shape 348 | else: 349 | img_height, img_width, _ = img.shape 350 | if (left+width) > img_width or (top+height) > img_height: 351 | raise ValueError('the input crop width and height should be smaller than or \ 352 | equal to the image width and height. ') 353 | 354 | if img.ndim == 2: 355 | return img[top:(top+height), left:(left+width)] 356 | elif img.ndim == 3: 357 | return img[top:(top+height), left:(left+width), :] 358 | 359 | 360 | def center_crop(img, output_size): 361 | '''crop image at the center 362 | 363 | Arguments: 364 | img {ndarray} -- input image 365 | output_size {number or sequence} -- the output image size. if sequence, should be [h, w] 366 | 367 | Raises: 368 | ValueError -- the requested output size is larger than the input image. 369 | 370 | Returns: 371 | ndarray image -- the cropped ndarray image.
372 | ''' 373 | if img.ndim == 2: 374 | img_height, img_width = img.shape 375 | else: 376 | img_height, img_width, _ = img.shape 377 | 378 | if isinstance(output_size, numbers.Number): 379 | output_size = (int(output_size), int(output_size)) 380 | if output_size[0] > img_height or output_size[1] > img_width: 381 | raise ValueError('the output_size should not greater than image size, but got {}'.format(output_size)) 382 | 383 | target_height, target_width = output_size 384 | 385 | top = int(round((img_height - target_height)/2)) 386 | left = int(round((img_width - target_width)/2)) 387 | 388 | return crop(img, top, left, target_height, target_width) 389 | 390 | 391 | def resized_crop(img, top, left, height, width, size, interpolation=Image.BILINEAR): 392 | 393 | img = crop(img, top, left, height, width) 394 | img = resize(img, size, interpolation) 395 | return img 396 | 397 | def vflip(img): 398 | return cv2.flip(img, 0) 399 | 400 | def hflip(img): 401 | return cv2.flip(img, 1) 402 | 403 | def flip(img, flip_code): 404 | return cv2.flip(img, flip_code) 405 | 406 | 407 | def elastic_transform(image, alpha, sigma, alpha_affine, interpolation=cv2.INTER_LINEAR, 408 | border_mode=cv2.BORDER_REFLECT_101, random_state=None, approximate=False): 409 | """Elastic deformation of images as described in [Simard2003]_ (with modifications). 410 | Based on https://gist.github.com/erniejunior/601cdf56d2b424757de5 411 | .. [Simard2003] Simard, Steinkraus and Platt, "Best Practices for 412 | Convolutional Neural Networks applied to Visual Document Analysis", in 413 | Proc. of the International Conference on Document Analysis and 414 | Recognition, 2003. 415 | """ 416 | if random_state is None: 417 | random_state = np.random.RandomState(1234) 418 | 419 | height, width = image.shape[:2] 420 | 421 | # Random affine 422 | center_square = np.float32((height, width)) // 2 423 | square_size = min((height, width)) // 3 424 | alpha = float(alpha) 425 | sigma = float(sigma) 426 | alpha_affine = float(alpha_affine) 427 | 428 | pts1 = np.float32([center_square + square_size, [center_square[0] + square_size, center_square[1] - square_size], 429 | center_square - square_size]) 430 | pts2 = pts1 + random_state.uniform(-alpha_affine, alpha_affine, size=pts1.shape).astype(np.float32) 431 | matrix = cv2.getAffineTransform(pts1, pts2) 432 | 433 | image = cv2.warpAffine(image, matrix, (width, height), flags=interpolation, borderMode=border_mode) 434 | 435 | if approximate: 436 | # Approximate computation smooth displacement map with a large enough kernel. 437 | # On large images (512+) this is approximately 2X times faster 438 | dx = (random_state.rand(height, width).astype(np.float32) * 2 - 1) 439 | cv2.GaussianBlur(dx, (17, 17), sigma, dst=dx) 440 | dx *= alpha 441 | 442 | dy = (random_state.rand(height, width).astype(np.float32) * 2 - 1) 443 | cv2.GaussianBlur(dy, (17, 17), sigma, dst=dy) 444 | dy *= alpha 445 | else: 446 | dx = np.float32(gaussian_filter((random_state.rand(height, width) * 2 - 1), sigma) * alpha) 447 | dy = np.float32(gaussian_filter((random_state.rand(height, width) * 2 - 1), sigma) * alpha) 448 | 449 | x, y = np.meshgrid(np.arange(width), np.arange(height)) 450 | 451 | mapx = np.float32(x + dx) 452 | mapy = np.float32(y + dy) 453 | 454 | return cv2.remap(image, mapx, mapy, interpolation, borderMode=border_mode) 455 | 456 | 457 | def bbox_shift(bboxes, top, left): 458 | pass 459 | 460 | 461 | def bbox_vflip(bboxes, img_height): 462 | """vertical flip the bboxes 463 | ........... 464 | . . 465 | . . 
466 | >...........< 467 | . . 468 | . . 469 | ........... 470 | Args: 471 | bbox (ndarray): bbox ndarray [box_nums, 4] 472 | flip_code (int, optional): [description]. Defaults to 0. 473 | """ 474 | flipped = bboxes.copy() 475 | flipped[...,1::2] = img_height - bboxes[...,1::2] 476 | flipped = flipped[..., [0, 3, 2, 1]] 477 | return flipped 478 | 479 | 480 | def bbox_hflip(bboxes, img_width): 481 | """horizontal flip the bboxes 482 | ^ 483 | ............. 484 | . . . 485 | . . . 486 | . . . 487 | . . . 488 | ............. 489 | ^ 490 | Args: 491 | bbox (ndarray): bbox ndarray [box_nums, 4] 492 | flip_code (int, optional): [description]. Defaults to 0. 493 | """ 494 | flipped = bboxes.copy() 495 | flipped[..., 0::2] = img_width - bboxes[...,0::2] 496 | flipped = flipped[..., [2, 1, 0, 3]] 497 | return flipped 498 | 499 | 500 | def bbox_resize(bboxes, img_size, target_size): 501 | """resize the bbox 502 | 503 | Args: 504 | bboxes (ndarray): bbox ndarray [box_nums, 4] 505 | img_size (tuple): the image height and width 506 | target_size (int, or tuple): the target bbox size. 507 | Int or Tuple, if tuple the shape should be (height, width) 508 | """ 509 | if isinstance(target_size, numbers.Number): 510 | target_size = (target_size, target_size) 511 | 512 | ratio_height = target_size[0]/img_size[0] 513 | ratio_width = target_size[1]/img_size[1] 514 | 515 | return bboxes[...,]*[ratio_width,ratio_height,ratio_width,ratio_height] 516 | 517 | 518 | def bbox_crop(bboxes, top, left, height, width): 519 | '''crop bbox 520 | 521 | Arguments: 522 | img {ndarray} -- image to be croped 523 | top {int} -- top size 524 | left {int} -- left size 525 | height {int} -- croped height 526 | width {int} -- croped width 527 | ''' 528 | croped_bboxes = bboxes.copy() 529 | 530 | right = width + left 531 | bottom = height + top 532 | 533 | croped_bboxes[..., 0::2] = bboxes[..., 0::2].clip(left, right) - left 534 | croped_bboxes[..., 1::2] = bboxes[..., 1::2].clip(top, bottom) - top 535 | 536 | return croped_bboxes 537 | 538 | def bbox_pad(bboxes, padding): 539 | if isinstance(padding, int): 540 | pad_left = pad_right = pad_top = pad_bottom = padding 541 | if isinstance(padding, collections.Iterable) and len(padding) == 2: 542 | pad_left = pad_right = padding[0] 543 | pad_bottom = pad_top = padding[1] 544 | if isinstance(padding, collections.Iterable) and len(padding) == 4: 545 | pad_left = padding[0] 546 | pad_top = padding[1] 547 | pad_right = padding[2] 548 | pad_bottom = padding[3] 549 | 550 | pad_bboxes = bboxes.copy() 551 | pad_bboxes[..., 0::2] = bboxes[..., 0::2] + pad_left 552 | pad_bboxes[..., 1::2] = bboxes[..., 1::2] + pad_top 553 | 554 | return pad_bboxes 555 | -------------------------------------------------------------------------------- /trilinear_c/build.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import os 3 | import torch 4 | from torch.utils.ffi import create_extension 5 | 6 | sources = ['src/trilinear.c'] 7 | headers = ['src/trilinear.h'] 8 | extra_objects = [] 9 | #sources = [] 10 | #headers = [] 11 | defines = [] 12 | with_cuda = False 13 | 14 | this_file = os.path.dirname(os.path.realpath(__file__)) 15 | print(this_file) 16 | 17 | if torch.cuda.is_available(): 18 | print('Including CUDA code.') 19 | sources += ['src/trilinear_cuda.c'] 20 | headers += ['src/trilinear_cuda.h'] 21 | defines += [('WITH_CUDA', None)] 22 | with_cuda = True 23 | 24 | extra_objects = ['src/trilinear_kernel.cu.o'] 25 | extra_objects = 
[os.path.join(this_file, fname) for fname in extra_objects] 26 | 27 | ffi = create_extension( 28 | '_ext.trilinear', 29 | headers=headers, 30 | sources=sources, 31 | define_macros=defines, 32 | relative_to=__file__, 33 | with_cuda=with_cuda, 34 | extra_objects=extra_objects 35 | ) 36 | 37 | if __name__ == '__main__': 38 | ffi.build() 39 | -------------------------------------------------------------------------------- /trilinear_c/make.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | CUDA_PATH=/usr/local/cuda/ 4 | 5 | cd src 6 | echo "Compiling my_lib kernels by nvcc..." 7 | nvcc -c -o trilinear_kernel.cu.o trilinear_kernel.cu -x cu -Xcompiler -fPIC -arch=sm_52 8 | 9 | cd ../ 10 | python3 build.py 11 | -------------------------------------------------------------------------------- /trilinear_c/src/trilinear.c: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | 5 | 6 | void TriLinearForwardCpu(const float* lut, const float* image, float* output, const int dim, const int shift, const float binsize, const int width, const int height, const int channels); 7 | 8 | void TriLinearBackwardCpu(const float* image, const float* image_grad, float* lut_grad, const int dim, const int shift, const float binsize, const int width, const int height, const int channels); 9 | 10 | int trilinear_forward(THFloatTensor * lut, THFloatTensor * image, THFloatTensor * output, 11 | int lut_dim, int shift, float binsize, int width, int height, int batch) 12 | { 13 | // Grab the input tensor 14 | float * lut_flat = THFloatTensor_data(lut); 15 | float * image_flat = THFloatTensor_data(image); 16 | float * output_flat = THFloatTensor_data(output); 17 | 18 | // whether color image 19 | int channels = THFloatTensor_size(image, 1); 20 | if (channels != 3) 21 | { 22 | return 0; 23 | } 24 | 25 | TriLinearForwardCpu(lut_flat, image_flat, output_flat, lut_dim, shift, binsize, width, height, channels); 26 | 27 | return 1; 28 | } 29 | 30 | int trilinear_backward(THFloatTensor * image, THFloatTensor * image_grad, THFloatTensor * lut_grad, 31 | int lut_dim, int shift, float binsize, int width, int height, int batch) 32 | { 33 | // Grab the input tensor 34 | float * image_grad_flat = THFloatTensor_data(image_grad); 35 | float * image_flat = THFloatTensor_data(image); 36 | float * lut_grad_flat = THFloatTensor_data(lut_grad); 37 | 38 | // whether color image 39 | int channels = THFloatTensor_size(image, 1); 40 | if (channels != 3) 41 | { 42 | return 0; 43 | } 44 | 45 | TriLinearBackwardCpu(image_flat, image_grad_flat, lut_grad_flat, lut_dim, shift, binsize, width, height, channels); 46 | 47 | return 1; 48 | } 49 | 50 | void TriLinearForwardCpu(const float* lut, const float* image, float* output, const int dim, const int shift, const float binsize, const int width, const int height, const int channels) 51 | { 52 | const int output_size = height * width;; 53 | 54 | int index = 0; 55 | for (index = 0; index < output_size; ++index) 56 | { 57 | float r = image[index]; 58 | float g = image[index + width * height]; 59 | float b = image[index + width * height * 2]; 60 | 61 | int r_id = floor(r / binsize); 62 | int g_id = floor(g / binsize); 63 | int b_id = floor(b / binsize); 64 | 65 | float r_d = fmod(r,binsize) / binsize; 66 | float g_d = fmod(g,binsize) / binsize; 67 | float b_d = fmod(b,binsize) / binsize; 68 | 69 | int id000 = r_id + g_id * dim + b_id * dim * dim; 70 | int id100 = r_id + 1 + g_id 
* dim + b_id * dim * dim; 71 | int id010 = r_id + (g_id + 1) * dim + b_id * dim * dim; 72 | int id110 = r_id + 1 + (g_id + 1) * dim + b_id * dim * dim; 73 | int id001 = r_id + g_id * dim + (b_id + 1) * dim * dim; 74 | int id101 = r_id + 1 + g_id * dim + (b_id + 1) * dim * dim; 75 | int id011 = r_id + (g_id + 1) * dim + (b_id + 1) * dim * dim; 76 | int id111 = r_id + 1 + (g_id + 1) * dim + (b_id + 1) * dim * dim; 77 | 78 | float w000 = (1-r_d)*(1-g_d)*(1-b_d); 79 | float w100 = r_d*(1-g_d)*(1-b_d); 80 | float w010 = (1-r_d)*g_d*(1-b_d); 81 | float w110 = r_d*g_d*(1-b_d); 82 | float w001 = (1-r_d)*(1-g_d)*b_d; 83 | float w101 = r_d*(1-g_d)*b_d; 84 | float w011 = (1-r_d)*g_d*b_d; 85 | float w111 = r_d*g_d*b_d; 86 | 87 | output[index] = w000 * lut[id000] + w100 * lut[id100] + 88 | w010 * lut[id010] + w110 * lut[id110] + 89 | w001 * lut[id001] + w101 * lut[id101] + 90 | w011 * lut[id011] + w111 * lut[id111]; 91 | 92 | output[index + width * height] = w000 * lut[id000 + shift] + w100 * lut[id100 + shift] + 93 | w010 * lut[id010 + shift] + w110 * lut[id110 + shift] + 94 | w001 * lut[id001 + shift] + w101 * lut[id101 + shift] + 95 | w011 * lut[id011 + shift] + w111 * lut[id111 + shift]; 96 | 97 | output[index + width * height * 2] = w000 * lut[id000 + shift * 2] + w100 * lut[id100 + shift * 2] + 98 | w010 * lut[id010 + shift * 2] + w110 * lut[id110 + shift * 2] + 99 | w001 * lut[id001 + shift * 2] + w101 * lut[id101 + shift * 2] + 100 | w011 * lut[id011 + shift * 2] + w111 * lut[id111 + shift * 2]; 101 | } 102 | } 103 | 104 | void TriLinearBackwardCpu(const float* image, const float* image_grad, float* lut_grad, const int dim, const int shift, const float binsize, const int width, const int height, const int channels) 105 | { 106 | const int output_size = height * width; 107 | 108 | int index = 0; 109 | for (index = 0; index < output_size; ++index) 110 | { 111 | float r = image[index]; 112 | float g = image[index + width * height]; 113 | float b = image[index + width * height * 2]; 114 | 115 | int r_id = floor(r / binsize); 116 | int g_id = floor(g / binsize); 117 | int b_id = floor(b / binsize); 118 | 119 | float r_d = fmod(r,binsize) / binsize; 120 | float g_d = fmod(g,binsize) / binsize; 121 | float b_d = fmod(b,binsize) / binsize; 122 | 123 | int id000 = r_id + g_id * dim + b_id * dim * dim; 124 | int id100 = r_id + 1 + g_id * dim + b_id * dim * dim; 125 | int id010 = r_id + (g_id + 1) * dim + b_id * dim * dim; 126 | int id110 = r_id + 1 + (g_id + 1) * dim + b_id * dim * dim; 127 | int id001 = r_id + g_id * dim + (b_id + 1) * dim * dim; 128 | int id101 = r_id + 1 + g_id * dim + (b_id + 1) * dim * dim; 129 | int id011 = r_id + (g_id + 1) * dim + (b_id + 1) * dim * dim; 130 | int id111 = r_id + 1 + (g_id + 1) * dim + (b_id + 1) * dim * dim; 131 | 132 | float w000 = (1-r_d)*(1-g_d)*(1-b_d); 133 | float w100 = r_d*(1-g_d)*(1-b_d); 134 | float w010 = (1-r_d)*g_d*(1-b_d); 135 | float w110 = r_d*g_d*(1-b_d); 136 | float w001 = (1-r_d)*(1-g_d)*b_d; 137 | float w101 = r_d*(1-g_d)*b_d; 138 | float w011 = (1-r_d)*g_d*b_d; 139 | float w111 = r_d*g_d*b_d; 140 | 141 | lut_grad[id000] += w000 * image_grad[index]; 142 | lut_grad[id100] += w100 * image_grad[index]; 143 | lut_grad[id010] += w010 * image_grad[index]; 144 | lut_grad[id110] += w110 * image_grad[index]; 145 | lut_grad[id001] += w001 * image_grad[index]; 146 | lut_grad[id101] += w101 * image_grad[index]; 147 | lut_grad[id011] += w011 * image_grad[index]; 148 | lut_grad[id111] += w111 * image_grad[index]; 149 | 150 | lut_grad[id000 + shift] += w000 * 
image_grad[index + width * height]; 151 | lut_grad[id100 + shift] += w100 * image_grad[index + width * height]; 152 | lut_grad[id010 + shift] += w010 * image_grad[index + width * height]; 153 | lut_grad[id110 + shift] += w110 * image_grad[index + width * height]; 154 | lut_grad[id001 + shift] += w001 * image_grad[index + width * height]; 155 | lut_grad[id101 + shift] += w101 * image_grad[index + width * height]; 156 | lut_grad[id011 + shift] += w011 * image_grad[index + width * height]; 157 | lut_grad[id111 + shift] += w111 * image_grad[index + width * height]; 158 | 159 | lut_grad[id000 + shift* 2] += w000 * image_grad[index + width * height * 2]; 160 | lut_grad[id100 + shift* 2] += w100 * image_grad[index + width * height * 2]; 161 | lut_grad[id010 + shift* 2] += w010 * image_grad[index + width * height * 2]; 162 | lut_grad[id110 + shift* 2] += w110 * image_grad[index + width * height * 2]; 163 | lut_grad[id001 + shift* 2] += w001 * image_grad[index + width * height * 2]; 164 | lut_grad[id101 + shift* 2] += w101 * image_grad[index + width * height * 2]; 165 | lut_grad[id011 + shift* 2] += w011 * image_grad[index + width * height * 2]; 166 | lut_grad[id111 + shift* 2] += w111 * image_grad[index + width * height * 2]; 167 | } 168 | } 169 | -------------------------------------------------------------------------------- /trilinear_c/src/trilinear.h: -------------------------------------------------------------------------------- 1 | int trilinear_forward(THFloatTensor * lut, THFloatTensor * image, THFloatTensor * output, 2 | int lut_dim, int shift, float binsize, int width, int height, int batch); 3 | 4 | int trilinear_backward(THFloatTensor * image, THFloatTensor * image_grad, THFloatTensor * lut_grad, 5 | int lut_dim, int shift, float binsize, int width, int height, int batch); 6 | 7 | -------------------------------------------------------------------------------- /trilinear_c/src/trilinear_cuda.c: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include "trilinear_kernel.h" 4 | 5 | extern THCState *state; 6 | 7 | int trilinear_forward_cuda(THCudaTensor * lut, THCudaTensor * image, THCudaTensor * output, 8 | int lut_dim, int shift, float binsize, int width, int height, int batch) 9 | { 10 | // Grab the input tensor 11 | float * lut_flat = THCudaTensor_data(state, lut); 12 | float * image_flat = THCudaTensor_data(state, image); 13 | float * output_flat = THCudaTensor_data(state, output); 14 | 15 | // whether color image 16 | //int channels = THCudaTensor_size(state,image, 1); 17 | //if (channels != 3) 18 | //{ 19 | // return 0; 20 | //} 21 | 22 | cudaStream_t stream = THCState_getCurrentStream(state); 23 | 24 | TriLinearForwardLaucher(lut_flat, image_flat, output_flat, lut_dim, shift, binsize, width, height, batch, stream); 25 | 26 | return 1; 27 | } 28 | 29 | int trilinear_backward_cuda(THCudaTensor * image, THCudaTensor * image_grad, THCudaTensor * lut_grad, 30 | int lut_dim, int shift, float binsize, int width, int height, int batch) 31 | { 32 | // Grab the input tensor 33 | float * image_grad_flat = THCudaTensor_data(state, image_grad); 34 | float * image_flat = THCudaTensor_data(state, image); 35 | float * lut_grad_flat = THCudaTensor_data(state, lut_grad); 36 | 37 | // whether color image 38 | //int channels = THCudaTensor_size(state,image, 1); 39 | //if (channels != 3) 40 | //{ 41 | // return 0; 42 | //} 43 | 44 | cudaStream_t stream = THCState_getCurrentStream(state); 45 | TriLinearBackwardLaucher(image_flat, 
image_grad_flat, lut_grad_flat, lut_dim, shift, binsize, width, height, batch, stream); 46 | 47 | return 1; 48 | } 49 | -------------------------------------------------------------------------------- /trilinear_c/src/trilinear_cuda.h: -------------------------------------------------------------------------------- 1 | int trilinear_forward_cuda(THCudaTensor * lut, THCudaTensor * image, THCudaTensor * output, 2 | int lut_dim, int shift, float binsize, int width, int height, int batch); 3 | 4 | int trilinear_backward_cuda(THCudaTensor * image, THCudaTensor * image_grad, THCudaTensor * lut_grad, 5 | int lut_dim, int shift, float binsize, int width, int height, int batch); 6 | -------------------------------------------------------------------------------- /trilinear_c/src/trilinear_kernel.cu: -------------------------------------------------------------------------------- 1 | #ifdef __cplusplus 2 | extern "C" { 3 | #endif 4 | 5 | #include 6 | #include 7 | #include 8 | #include "trilinear_kernel.h" 9 | 10 | #define CUDA_1D_KERNEL_LOOP(i, n) \ 11 | for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n; \ 12 | i += blockDim.x * gridDim.x) 13 | 14 | 15 | __global__ void TriLinearForward(const int nthreads, const float* lut, const float* image, float* output, const int dim, const int shift, const float binsize, const int width, const int height, const int batch) { 16 | CUDA_1D_KERNEL_LOOP(index, nthreads) { 17 | 18 | float r = image[index]; 19 | float g = image[index + width * height * batch]; 20 | float b = image[index + width * height * batch * 2]; 21 | 22 | int r_id = floor(r / binsize); 23 | int g_id = floor(g / binsize); 24 | int b_id = floor(b / binsize); 25 | 26 | float r_d = fmod(r,binsize) / binsize; 27 | float g_d = fmod(g,binsize) / binsize; 28 | float b_d = fmod(b,binsize) / binsize; 29 | 30 | int id000 = r_id + g_id * dim + b_id * dim * dim; 31 | int id100 = r_id + 1 + g_id * dim + b_id * dim * dim; 32 | int id010 = r_id + (g_id + 1) * dim + b_id * dim * dim; 33 | int id110 = r_id + 1 + (g_id + 1) * dim + b_id * dim * dim; 34 | int id001 = r_id + g_id * dim + (b_id + 1) * dim * dim; 35 | int id101 = r_id + 1 + g_id * dim + (b_id + 1) * dim * dim; 36 | int id011 = r_id + (g_id + 1) * dim + (b_id + 1) * dim * dim; 37 | int id111 = r_id + 1 + (g_id + 1) * dim + (b_id + 1) * dim * dim; 38 | 39 | float w000 = (1-r_d)*(1-g_d)*(1-b_d); 40 | float w100 = r_d*(1-g_d)*(1-b_d); 41 | float w010 = (1-r_d)*g_d*(1-b_d); 42 | float w110 = r_d*g_d*(1-b_d); 43 | float w001 = (1-r_d)*(1-g_d)*b_d; 44 | float w101 = r_d*(1-g_d)*b_d; 45 | float w011 = (1-r_d)*g_d*b_d; 46 | float w111 = r_d*g_d*b_d; 47 | 48 | output[index] = w000 * lut[id000] + w100 * lut[id100] + 49 | w010 * lut[id010] + w110 * lut[id110] + 50 | w001 * lut[id001] + w101 * lut[id101] + 51 | w011 * lut[id011] + w111 * lut[id111]; 52 | 53 | output[index + width * height * batch] = w000 * lut[id000 + shift] + w100 * lut[id100 + shift] + 54 | w010 * lut[id010 + shift] + w110 * lut[id110 + shift] + 55 | w001 * lut[id001 + shift] + w101 * lut[id101 + shift] + 56 | w011 * lut[id011 + shift] + w111 * lut[id111 + shift]; 57 | 58 | output[index + width * height * batch * 2] = w000 * lut[id000 + shift * 2] + w100 * lut[id100 + shift * 2] + 59 | w010 * lut[id010 + shift * 2] + w110 * lut[id110 + shift * 2] + 60 | w001 * lut[id001 + shift * 2] + w101 * lut[id101 + shift * 2] + 61 | w011 * lut[id011 + shift * 2] + w111 * lut[id111 + shift * 2]; 62 | 63 | } 64 | } 65 | 66 | 67 | int TriLinearForwardLaucher(const float* lut, const float* image, float* 
output, const int lut_dim, const int shift, const float binsize, const int width, const int height, const int batch, cudaStream_t stream) { 68 | const int kThreadsPerBlock = 1024; 69 | const int output_size = height * width * batch; 70 | cudaError_t err; 71 | 72 | 73 | TriLinearForward<<<(output_size + kThreadsPerBlock - 1) / kThreadsPerBlock, kThreadsPerBlock, 0, stream>>>( 74 | output_size, lut, image, output, lut_dim, shift, binsize, width, height, batch); 75 | 76 | err = cudaGetLastError(); 77 | if(cudaSuccess != err) { 78 | fprintf( stderr, "cudaCheckError() failed : %s\n", cudaGetErrorString( err ) ); 79 | exit( -1 ); 80 | } 81 | 82 | return 1; 83 | } 84 | 85 | 86 | __global__ void TriLinearBackward(const int nthreads, const float* image, const float* image_grad, float* lut_grad, const int dim, const int shift, const float binsize, const int width, const int height, const int batch) { 87 | CUDA_1D_KERNEL_LOOP(index, nthreads) { 88 | 89 | float r = image[index]; 90 | float g = image[index + width * height * batch]; 91 | float b = image[index + width * height * batch * 2]; 92 | 93 | int r_id = floor(r / binsize); 94 | int g_id = floor(g / binsize); 95 | int b_id = floor(b / binsize); 96 | 97 | float r_d = fmod(r,binsize) / binsize; 98 | float g_d = fmod(g,binsize) / binsize; 99 | float b_d = fmod(b,binsize) / binsize; 100 | 101 | int id000 = r_id + g_id * dim + b_id * dim * dim; 102 | int id100 = r_id + 1 + g_id * dim + b_id * dim * dim; 103 | int id010 = r_id + (g_id + 1) * dim + b_id * dim * dim; 104 | int id110 = r_id + 1 + (g_id + 1) * dim + b_id * dim * dim; 105 | int id001 = r_id + g_id * dim + (b_id + 1) * dim * dim; 106 | int id101 = r_id + 1 + g_id * dim + (b_id + 1) * dim * dim; 107 | int id011 = r_id + (g_id + 1) * dim + (b_id + 1) * dim * dim; 108 | int id111 = r_id + 1 + (g_id + 1) * dim + (b_id + 1) * dim * dim; 109 | 110 | float w000 = (1-r_d)*(1-g_d)*(1-b_d); 111 | float w100 = r_d*(1-g_d)*(1-b_d); 112 | float w010 = (1-r_d)*g_d*(1-b_d); 113 | float w110 = r_d*g_d*(1-b_d); 114 | float w001 = (1-r_d)*(1-g_d)*b_d; 115 | float w101 = r_d*(1-g_d)*b_d; 116 | float w011 = (1-r_d)*g_d*b_d; 117 | float w111 = r_d*g_d*b_d; 118 | 119 | atomicAdd(lut_grad + id000, image_grad[index] * w000); 120 | atomicAdd(lut_grad + id100, image_grad[index] * w100); 121 | atomicAdd(lut_grad + id010, image_grad[index] * w010); 122 | atomicAdd(lut_grad + id110, image_grad[index] * w110); 123 | atomicAdd(lut_grad + id001, image_grad[index] * w001); 124 | atomicAdd(lut_grad + id101, image_grad[index] * w101); 125 | atomicAdd(lut_grad + id011, image_grad[index] * w011); 126 | atomicAdd(lut_grad + id111, image_grad[index] * w111); 127 | 128 | atomicAdd(lut_grad + id000 + shift, image_grad[index + width * height * batch] * w000); 129 | atomicAdd(lut_grad + id100 + shift, image_grad[index + width * height * batch] * w100); 130 | atomicAdd(lut_grad + id010 + shift, image_grad[index + width * height * batch] * w010); 131 | atomicAdd(lut_grad + id110 + shift, image_grad[index + width * height * batch] * w110); 132 | atomicAdd(lut_grad + id001 + shift, image_grad[index + width * height * batch] * w001); 133 | atomicAdd(lut_grad + id101 + shift, image_grad[index + width * height * batch] * w101); 134 | atomicAdd(lut_grad + id011 + shift, image_grad[index + width * height * batch] * w011); 135 | atomicAdd(lut_grad + id111 + shift, image_grad[index + width * height * batch] * w111); 136 | 137 | atomicAdd(lut_grad + id000 + shift * 2, image_grad[index + width * height * batch * 2] * w000); 138 | 
atomicAdd(lut_grad + id100 + shift * 2, image_grad[index + width * height * batch * 2] * w100); 139 | atomicAdd(lut_grad + id010 + shift * 2, image_grad[index + width * height * batch * 2] * w010); 140 | atomicAdd(lut_grad + id110 + shift * 2, image_grad[index + width * height * batch * 2] * w110); 141 | atomicAdd(lut_grad + id001 + shift * 2, image_grad[index + width * height * batch * 2] * w001); 142 | atomicAdd(lut_grad + id101 + shift * 2, image_grad[index + width * height * batch * 2] * w101); 143 | atomicAdd(lut_grad + id011 + shift * 2, image_grad[index + width * height * batch * 2] * w011); 144 | atomicAdd(lut_grad + id111 + shift * 2, image_grad[index + width * height * batch * 2] * w111); 145 | 146 | } 147 | } 148 | 149 | int TriLinearBackwardLaucher(const float* image, const float* image_grad, float* lut_grad, const int lut_dim, const int shift, const float binsize, const int width, const int height, const int batch, cudaStream_t stream) { 150 | const int kThreadsPerBlock = 1024; 151 | const int output_size = height * width * batch; 152 | cudaError_t err; 153 | 154 | TriLinearBackward<<<(output_size + kThreadsPerBlock - 1) / kThreadsPerBlock, kThreadsPerBlock, 0, stream>>>( 155 | output_size, image, image_grad, lut_grad, lut_dim, shift, binsize, width, height, batch); 156 | 157 | err = cudaGetLastError(); 158 | if(cudaSuccess != err) { 159 | fprintf( stderr, "cudaCheckError() failed : %s\n", cudaGetErrorString( err ) ); 160 | exit( -1 ); 161 | } 162 | 163 | return 1; 164 | } 165 | 166 | 167 | #ifdef __cplusplus 168 | } 169 | #endif 170 | -------------------------------------------------------------------------------- /trilinear_c/src/trilinear_kernel.cu.o: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HuiZeng/Image-Adaptive-3DLUT/b491f6df64a588864739a157db271e5c848e1805/trilinear_c/src/trilinear_kernel.cu.o -------------------------------------------------------------------------------- /trilinear_c/src/trilinear_kernel.h: -------------------------------------------------------------------------------- 1 | #ifndef _TRILINEAR_KERNEL 2 | #define _TRILINEAR_KERNEL 3 | 4 | #ifdef __cplusplus 5 | extern "C" { 6 | #endif 7 | 8 | __global__ void TriLinearForward(const int nthreads, const float* lut, const float* image, float* output, const int dim, const int shift, const float binsize, const int width, const int height, const int batch); 9 | 10 | int TriLinearForwardLaucher(const float* lut, const float* image, float* output, const int lut_dim, const int shift, const float binsize, const int width, const int height, const int batch, cudaStream_t stream); 11 | 12 | __global__ void TriLinearBackward(const int nthreads, const float* image, const float* image_grad, float* lut_grad, const int dim, const int shift, const float binsize, const int width, const int height, const int batch); 13 | 14 | int TriLinearBackwardLaucher(const float* image, const float* image_grad, float* lut_grad, const int lut_dim, const int shift, const float binsize, const int width, const int height, const int batch, cudaStream_t stream); 15 | 16 | #ifdef __cplusplus 17 | } 18 | #endif 19 | 20 | #endif 21 | 22 | -------------------------------------------------------------------------------- /trilinear_cpp/setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | import torch 3 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension, CppExtension 4 | 5 | if 
torch.cuda.is_available(): 6 | print('Including CUDA code.') 7 | setup( 8 | name='trilinear', 9 | ext_modules=[ 10 | CUDAExtension('trilinear', [ 11 | 'src/trilinear_cuda.cpp', 12 | 'src/trilinear_kernel.cu', 13 | ]) 14 | ], 15 | cmdclass={ 16 | 'build_ext': BuildExtension 17 | }) 18 | else: 19 | print('NO CUDA is found. Fall back to CPU.') 20 | setup(name='trilinear', 21 | ext_modules=[CppExtension('trilinear', ['src/trilinear.cpp'])], 22 | cmdclass={'build_ext': BuildExtension}) 23 | -------------------------------------------------------------------------------- /trilinear_cpp/setup.sh: -------------------------------------------------------------------------------- 1 | export CUDA_HOME=/usr/local/cuda-10.2 && python3 setup.py install 2 | -------------------------------------------------------------------------------- /trilinear_cpp/src/trilinear.cpp: -------------------------------------------------------------------------------- 1 | #include "trilinear.h" 2 | 3 | 4 | void TriLinearForwardCpu(const float* lut, const float* image, float* output, const int dim, const int shift, const float binsize, const int width, const int height, const int channels); 5 | 6 | void TriLinearBackwardCpu(const float* image, const float* image_grad, float* lut_grad, const int dim, const int shift, const float binsize, const int width, const int height, const int channels); 7 | 8 | int trilinear_forward(torch::Tensor lut, torch::Tensor image, torch::Tensor output, 9 | int lut_dim, int shift, float binsize, int width, int height, int batch) 10 | { 11 | // Grab the input tensor 12 | float * lut_flat = lut.data(); 13 | float * image_flat = image.data(); 14 | float * output_flat = output.data(); 15 | 16 | // whether color image 17 | auto image_size = image.sizes(); 18 | int channels = image_size[1]; 19 | if (channels != 3) 20 | { 21 | return 0; 22 | } 23 | 24 | TriLinearForwardCpu(lut_flat, image_flat, output_flat, lut_dim, shift, binsize, width, height, channels); 25 | 26 | return 1; 27 | } 28 | 29 | int trilinear_backward(torch::Tensor image, torch::Tensor image_grad, torch::Tensor lut_grad, 30 | int lut_dim, int shift, float binsize, int width, int height, int batch) 31 | { 32 | // Grab the input tensor 33 | float * image_grad_flat = image_grad.data(); 34 | float * image_flat = image.data(); 35 | float * lut_grad_flat = lut_grad.data(); 36 | 37 | // whether color image 38 | auto image_size = image.sizes(); 39 | int channels = image_size[1]; 40 | if (channels != 3) 41 | { 42 | return 0; 43 | } 44 | 45 | TriLinearBackwardCpu(image_flat, image_grad_flat, lut_grad_flat, lut_dim, shift, binsize, width, height, channels); 46 | 47 | return 1; 48 | } 49 | 50 | void TriLinearForwardCpu(const float* lut, const float* image, float* output, const int dim, const int shift, const float binsize, const int width, const int height, const int channels) 51 | { 52 | const int output_size = height * width;; 53 | 54 | int index = 0; 55 | for (index = 0; index < output_size; ++index) 56 | { 57 | float r = image[index]; 58 | float g = image[index + width * height]; 59 | float b = image[index + width * height * 2]; 60 | 61 | int r_id = floor(r / binsize); 62 | int g_id = floor(g / binsize); 63 | int b_id = floor(b / binsize); 64 | 65 | float r_d = fmod(r,binsize) / binsize; 66 | float g_d = fmod(g,binsize) / binsize; 67 | float b_d = fmod(b,binsize) / binsize; 68 | 69 | int id000 = r_id + g_id * dim + b_id * dim * dim; 70 | int id100 = r_id + 1 + g_id * dim + b_id * dim * dim; 71 | int id010 = r_id + (g_id + 1) * dim + b_id * dim * 
dim; 72 | int id110 = r_id + 1 + (g_id + 1) * dim + b_id * dim * dim; 73 | int id001 = r_id + g_id * dim + (b_id + 1) * dim * dim; 74 | int id101 = r_id + 1 + g_id * dim + (b_id + 1) * dim * dim; 75 | int id011 = r_id + (g_id + 1) * dim + (b_id + 1) * dim * dim; 76 | int id111 = r_id + 1 + (g_id + 1) * dim + (b_id + 1) * dim * dim; 77 | 78 | float w000 = (1-r_d)*(1-g_d)*(1-b_d); 79 | float w100 = r_d*(1-g_d)*(1-b_d); 80 | float w010 = (1-r_d)*g_d*(1-b_d); 81 | float w110 = r_d*g_d*(1-b_d); 82 | float w001 = (1-r_d)*(1-g_d)*b_d; 83 | float w101 = r_d*(1-g_d)*b_d; 84 | float w011 = (1-r_d)*g_d*b_d; 85 | float w111 = r_d*g_d*b_d; 86 | 87 | output[index] = w000 * lut[id000] + w100 * lut[id100] + 88 | w010 * lut[id010] + w110 * lut[id110] + 89 | w001 * lut[id001] + w101 * lut[id101] + 90 | w011 * lut[id011] + w111 * lut[id111]; 91 | 92 | output[index + width * height] = w000 * lut[id000 + shift] + w100 * lut[id100 + shift] + 93 | w010 * lut[id010 + shift] + w110 * lut[id110 + shift] + 94 | w001 * lut[id001 + shift] + w101 * lut[id101 + shift] + 95 | w011 * lut[id011 + shift] + w111 * lut[id111 + shift]; 96 | 97 | output[index + width * height * 2] = w000 * lut[id000 + shift * 2] + w100 * lut[id100 + shift * 2] + 98 | w010 * lut[id010 + shift * 2] + w110 * lut[id110 + shift * 2] + 99 | w001 * lut[id001 + shift * 2] + w101 * lut[id101 + shift * 2] + 100 | w011 * lut[id011 + shift * 2] + w111 * lut[id111 + shift * 2]; 101 | } 102 | } 103 | 104 | void TriLinearBackwardCpu(const float* image, const float* image_grad, float* lut_grad, const int dim, const int shift, const float binsize, const int width, const int height, const int channels) 105 | { 106 | const int output_size = height * width; 107 | 108 | int index = 0; 109 | for (index = 0; index < output_size; ++index) 110 | { 111 | float r = image[index]; 112 | float g = image[index + width * height]; 113 | float b = image[index + width * height * 2]; 114 | 115 | int r_id = floor(r / binsize); 116 | int g_id = floor(g / binsize); 117 | int b_id = floor(b / binsize); 118 | 119 | float r_d = fmod(r,binsize) / binsize; 120 | float g_d = fmod(g,binsize) / binsize; 121 | float b_d = fmod(b,binsize) / binsize; 122 | 123 | int id000 = r_id + g_id * dim + b_id * dim * dim; 124 | int id100 = r_id + 1 + g_id * dim + b_id * dim * dim; 125 | int id010 = r_id + (g_id + 1) * dim + b_id * dim * dim; 126 | int id110 = r_id + 1 + (g_id + 1) * dim + b_id * dim * dim; 127 | int id001 = r_id + g_id * dim + (b_id + 1) * dim * dim; 128 | int id101 = r_id + 1 + g_id * dim + (b_id + 1) * dim * dim; 129 | int id011 = r_id + (g_id + 1) * dim + (b_id + 1) * dim * dim; 130 | int id111 = r_id + 1 + (g_id + 1) * dim + (b_id + 1) * dim * dim; 131 | 132 | float w000 = (1-r_d)*(1-g_d)*(1-b_d); 133 | float w100 = r_d*(1-g_d)*(1-b_d); 134 | float w010 = (1-r_d)*g_d*(1-b_d); 135 | float w110 = r_d*g_d*(1-b_d); 136 | float w001 = (1-r_d)*(1-g_d)*b_d; 137 | float w101 = r_d*(1-g_d)*b_d; 138 | float w011 = (1-r_d)*g_d*b_d; 139 | float w111 = r_d*g_d*b_d; 140 | 141 | lut_grad[id000] += w000 * image_grad[index]; 142 | lut_grad[id100] += w100 * image_grad[index]; 143 | lut_grad[id010] += w010 * image_grad[index]; 144 | lut_grad[id110] += w110 * image_grad[index]; 145 | lut_grad[id001] += w001 * image_grad[index]; 146 | lut_grad[id101] += w101 * image_grad[index]; 147 | lut_grad[id011] += w011 * image_grad[index]; 148 | lut_grad[id111] += w111 * image_grad[index]; 149 | 150 | lut_grad[id000 + shift] += w000 * image_grad[index + width * height]; 151 | lut_grad[id100 + shift] += w100 * 
image_grad[index + width * height]; 152 | lut_grad[id010 + shift] += w010 * image_grad[index + width * height]; 153 | lut_grad[id110 + shift] += w110 * image_grad[index + width * height]; 154 | lut_grad[id001 + shift] += w001 * image_grad[index + width * height]; 155 | lut_grad[id101 + shift] += w101 * image_grad[index + width * height]; 156 | lut_grad[id011 + shift] += w011 * image_grad[index + width * height]; 157 | lut_grad[id111 + shift] += w111 * image_grad[index + width * height]; 158 | 159 | lut_grad[id000 + shift* 2] += w000 * image_grad[index + width * height * 2]; 160 | lut_grad[id100 + shift* 2] += w100 * image_grad[index + width * height * 2]; 161 | lut_grad[id010 + shift* 2] += w010 * image_grad[index + width * height * 2]; 162 | lut_grad[id110 + shift* 2] += w110 * image_grad[index + width * height * 2]; 163 | lut_grad[id001 + shift* 2] += w001 * image_grad[index + width * height * 2]; 164 | lut_grad[id101 + shift* 2] += w101 * image_grad[index + width * height * 2]; 165 | lut_grad[id011 + shift* 2] += w011 * image_grad[index + width * height * 2]; 166 | lut_grad[id111 + shift* 2] += w111 * image_grad[index + width * height * 2]; 167 | } 168 | } 169 | 170 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 171 | m.def("forward", &trilinear_forward, "Trilinear forward"); 172 | m.def("backward", &trilinear_backward, "Trilinear backward"); 173 | } 174 | -------------------------------------------------------------------------------- /trilinear_cpp/src/trilinear.h: -------------------------------------------------------------------------------- 1 | #ifndef TRILINEAR_H 2 | #define TRILINEAR_H 3 | 4 | #include 5 | 6 | int trilinear_forward(torch::Tensor lut, torch::Tensor image, torch::Tensor output, 7 | int lut_dim, int shift, float binsize, int width, int height, int batch); 8 | 9 | int trilinear_backward(torch::Tensor image, torch::Tensor image_grad, torch::Tensor lut_grad, 10 | int lut_dim, int shift, float binsize, int width, int height, int batch); 11 | 12 | #endif 13 | -------------------------------------------------------------------------------- /trilinear_cpp/src/trilinear_cuda.cpp: -------------------------------------------------------------------------------- 1 | #include "trilinear_kernel.h" 2 | #include 3 | #include 4 | 5 | int trilinear_forward_cuda(torch::Tensor lut, torch::Tensor image, torch::Tensor output, 6 | int lut_dim, int shift, float binsize, int width, int height, int batch) 7 | { 8 | // Grab the input tensor 9 | float * lut_flat = lut.data(); 10 | float * image_flat = image.data(); 11 | float * output_flat = output.data(); 12 | 13 | TriLinearForwardLaucher(lut_flat, image_flat, output_flat, lut_dim, shift, binsize, width, height, batch, at::cuda::getCurrentCUDAStream()); 14 | 15 | return 1; 16 | } 17 | 18 | int trilinear_backward_cuda(torch::Tensor image, torch::Tensor image_grad, torch::Tensor lut_grad, 19 | int lut_dim, int shift, float binsize, int width, int height, int batch) 20 | { 21 | // Grab the input tensor 22 | float * image_grad_flat = image_grad.data(); 23 | float * image_flat = image.data(); 24 | float * lut_grad_flat = lut_grad.data(); 25 | 26 | TriLinearBackwardLaucher(image_flat, image_grad_flat, lut_grad_flat, lut_dim, shift, binsize, width, height, batch, at::cuda::getCurrentCUDAStream()); 27 | 28 | return 1; 29 | } 30 | 31 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 32 | m.def("forward", &trilinear_forward_cuda, "Trilinear forward"); 33 | m.def("backward", &trilinear_backward_cuda, "Trilinear backward"); 34 | } 35 | 36 | 
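Note on using the extension: the bindings above expose raw forward/backward entry points that fill a pre-allocated output tensor and accumulate gradients into lut_grad; they do not register autograd on their own (the repository presumably wires this up in models.py / models_x.py). The snippet below is a minimal sketch of such a wrapper, not the repository's exact implementation: the class name TrilinearInterpolationSketch and the choice binsize = 1.0 / (dim - 1) (the grid spacing used by utils/generate_identity_3DLUT.py) are illustrative assumptions.

# Hypothetical autograd wrapper around the compiled `trilinear` extension.
# Assumes lut is a contiguous float32 tensor of shape [3, dim, dim, dim]; image is a
# contiguous float32 tensor whose memory layout matches the kernel indexing
# (for batch size 1, a [1, 3, H, W] tensor).
import torch
import trilinear  # extension built by trilinear_cpp/setup.py


class TrilinearInterpolationSketch(torch.autograd.Function):
    @staticmethod
    def forward(ctx, lut, image):
        dim = lut.shape[-1]
        shift = dim ** 3              # offset between the R/G/B tables in the flattened LUT
        binsize = 1.0 / (dim - 1)     # assumed grid spacing; the training code may pad this slightly
        batch, _, height, width = image.shape

        output = torch.zeros_like(image)
        trilinear.forward(lut, image, output, dim, shift, binsize, width, height, batch)

        ctx.save_for_backward(lut, image)
        ctx.params = (dim, shift, binsize, width, height, batch)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        lut, image = ctx.saved_tensors
        dim, shift, binsize, width, height, batch = ctx.params

        lut_grad = torch.zeros_like(lut)
        trilinear.backward(image, grad_output.contiguous(), lut_grad, dim, shift, binsize, width, height, batch)
        # The kernels only accumulate a gradient for the LUT; no image gradient is produced here.
        return lut_grad, None

Usage would then look like output = TrilinearInterpolationSketch.apply(fused_lut, img), where fused_lut is the weighted combination of the learned basis LUTs predicted for img.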
-------------------------------------------------------------------------------- /trilinear_cpp/src/trilinear_cuda.h: -------------------------------------------------------------------------------- 1 | #ifndef TRILINEAR_CUDA_H 2 | #define TRILINEAR_CUDA_H 3 | 4 | #import 5 | 6 | int trilinear_forward_cuda(torch::Tensor lut, torch::Tensor image, torch::Tensor output, 7 | int lut_dim, int shift, float binsize, int width, int height, int batch); 8 | 9 | int trilinear_backward_cuda(torch::Tensor image, torch::Tensor image_grad, torch::Tensor lut_grad, 10 | int lut_dim, int shift, float binsize, int width, int height, int batch); 11 | 12 | #endif 13 | -------------------------------------------------------------------------------- /trilinear_cpp/src/trilinear_kernel.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include "trilinear_kernel.h" 4 | 5 | #define CUDA_1D_KERNEL_LOOP(i, n) \ 6 | for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n; \ 7 | i += blockDim.x * gridDim.x) 8 | 9 | 10 | __global__ void TriLinearForward(const int nthreads, const float* lut, const float* image, float* output, const int dim, const int shift, const float binsize, const int width, const int height, const int batch) { 11 | CUDA_1D_KERNEL_LOOP(index, nthreads) { 12 | 13 | float r = image[index]; 14 | float g = image[index + width * height * batch]; 15 | float b = image[index + width * height * batch * 2]; 16 | 17 | int r_id = floor(r / binsize); 18 | int g_id = floor(g / binsize); 19 | int b_id = floor(b / binsize); 20 | 21 | float r_d = fmod(r,binsize) / binsize; 22 | float g_d = fmod(g,binsize) / binsize; 23 | float b_d = fmod(b,binsize) / binsize; 24 | 25 | int id000 = r_id + g_id * dim + b_id * dim * dim; 26 | int id100 = r_id + 1 + g_id * dim + b_id * dim * dim; 27 | int id010 = r_id + (g_id + 1) * dim + b_id * dim * dim; 28 | int id110 = r_id + 1 + (g_id + 1) * dim + b_id * dim * dim; 29 | int id001 = r_id + g_id * dim + (b_id + 1) * dim * dim; 30 | int id101 = r_id + 1 + g_id * dim + (b_id + 1) * dim * dim; 31 | int id011 = r_id + (g_id + 1) * dim + (b_id + 1) * dim * dim; 32 | int id111 = r_id + 1 + (g_id + 1) * dim + (b_id + 1) * dim * dim; 33 | 34 | float w000 = (1-r_d)*(1-g_d)*(1-b_d); 35 | float w100 = r_d*(1-g_d)*(1-b_d); 36 | float w010 = (1-r_d)*g_d*(1-b_d); 37 | float w110 = r_d*g_d*(1-b_d); 38 | float w001 = (1-r_d)*(1-g_d)*b_d; 39 | float w101 = r_d*(1-g_d)*b_d; 40 | float w011 = (1-r_d)*g_d*b_d; 41 | float w111 = r_d*g_d*b_d; 42 | 43 | output[index] = w000 * lut[id000] + w100 * lut[id100] + 44 | w010 * lut[id010] + w110 * lut[id110] + 45 | w001 * lut[id001] + w101 * lut[id101] + 46 | w011 * lut[id011] + w111 * lut[id111]; 47 | 48 | output[index + width * height * batch] = w000 * lut[id000 + shift] + w100 * lut[id100 + shift] + 49 | w010 * lut[id010 + shift] + w110 * lut[id110 + shift] + 50 | w001 * lut[id001 + shift] + w101 * lut[id101 + shift] + 51 | w011 * lut[id011 + shift] + w111 * lut[id111 + shift]; 52 | 53 | output[index + width * height * batch * 2] = w000 * lut[id000 + shift * 2] + w100 * lut[id100 + shift * 2] + 54 | w010 * lut[id010 + shift * 2] + w110 * lut[id110 + shift * 2] + 55 | w001 * lut[id001 + shift * 2] + w101 * lut[id101 + shift * 2] + 56 | w011 * lut[id011 + shift * 2] + w111 * lut[id111 + shift * 2]; 57 | 58 | } 59 | } 60 | 61 | 62 | int TriLinearForwardLaucher(const float* lut, const float* image, float* output, const int lut_dim, const int shift, const float binsize, const int width, const int 
height, const int batch, cudaStream_t stream) { 63 | const int kThreadsPerBlock = 1024; 64 | const int output_size = height * width * batch; 65 | cudaError_t err; 66 | 67 | 68 | TriLinearForward<<<(output_size + kThreadsPerBlock - 1) / kThreadsPerBlock, kThreadsPerBlock, 0, stream>>>(output_size, lut, image, output, lut_dim, shift, binsize, width, height, batch); 69 | 70 | err = cudaGetLastError(); 71 | if(cudaSuccess != err) { 72 | fprintf( stderr, "cudaCheckError() failed : %s\n", cudaGetErrorString( err ) ); 73 | exit( -1 ); 74 | } 75 | 76 | return 1; 77 | } 78 | 79 | 80 | __global__ void TriLinearBackward(const int nthreads, const float* image, const float* image_grad, float* lut_grad, const int dim, const int shift, const float binsize, const int width, const int height, const int batch) { 81 | CUDA_1D_KERNEL_LOOP(index, nthreads) { 82 | 83 | float r = image[index]; 84 | float g = image[index + width * height * batch]; 85 | float b = image[index + width * height * batch * 2]; 86 | 87 | int r_id = floor(r / binsize); 88 | int g_id = floor(g / binsize); 89 | int b_id = floor(b / binsize); 90 | 91 | float r_d = fmod(r,binsize) / binsize; 92 | float g_d = fmod(g,binsize) / binsize; 93 | float b_d = fmod(b,binsize) / binsize; 94 | 95 | int id000 = r_id + g_id * dim + b_id * dim * dim; 96 | int id100 = r_id + 1 + g_id * dim + b_id * dim * dim; 97 | int id010 = r_id + (g_id + 1) * dim + b_id * dim * dim; 98 | int id110 = r_id + 1 + (g_id + 1) * dim + b_id * dim * dim; 99 | int id001 = r_id + g_id * dim + (b_id + 1) * dim * dim; 100 | int id101 = r_id + 1 + g_id * dim + (b_id + 1) * dim * dim; 101 | int id011 = r_id + (g_id + 1) * dim + (b_id + 1) * dim * dim; 102 | int id111 = r_id + 1 + (g_id + 1) * dim + (b_id + 1) * dim * dim; 103 | 104 | float w000 = (1-r_d)*(1-g_d)*(1-b_d); 105 | float w100 = r_d*(1-g_d)*(1-b_d); 106 | float w010 = (1-r_d)*g_d*(1-b_d); 107 | float w110 = r_d*g_d*(1-b_d); 108 | float w001 = (1-r_d)*(1-g_d)*b_d; 109 | float w101 = r_d*(1-g_d)*b_d; 110 | float w011 = (1-r_d)*g_d*b_d; 111 | float w111 = r_d*g_d*b_d; 112 | 113 | atomicAdd(lut_grad + id000, image_grad[index] * w000); 114 | atomicAdd(lut_grad + id100, image_grad[index] * w100); 115 | atomicAdd(lut_grad + id010, image_grad[index] * w010); 116 | atomicAdd(lut_grad + id110, image_grad[index] * w110); 117 | atomicAdd(lut_grad + id001, image_grad[index] * w001); 118 | atomicAdd(lut_grad + id101, image_grad[index] * w101); 119 | atomicAdd(lut_grad + id011, image_grad[index] * w011); 120 | atomicAdd(lut_grad + id111, image_grad[index] * w111); 121 | 122 | atomicAdd(lut_grad + id000 + shift, image_grad[index + width * height * batch] * w000); 123 | atomicAdd(lut_grad + id100 + shift, image_grad[index + width * height * batch] * w100); 124 | atomicAdd(lut_grad + id010 + shift, image_grad[index + width * height * batch] * w010); 125 | atomicAdd(lut_grad + id110 + shift, image_grad[index + width * height * batch] * w110); 126 | atomicAdd(lut_grad + id001 + shift, image_grad[index + width * height * batch] * w001); 127 | atomicAdd(lut_grad + id101 + shift, image_grad[index + width * height * batch] * w101); 128 | atomicAdd(lut_grad + id011 + shift, image_grad[index + width * height * batch] * w011); 129 | atomicAdd(lut_grad + id111 + shift, image_grad[index + width * height * batch] * w111); 130 | 131 | atomicAdd(lut_grad + id000 + shift * 2, image_grad[index + width * height * batch * 2] * w000); 132 | atomicAdd(lut_grad + id100 + shift * 2, image_grad[index + width * height * batch * 2] * w100); 133 | atomicAdd(lut_grad 
+ id010 + shift * 2, image_grad[index + width * height * batch * 2] * w010); 134 | atomicAdd(lut_grad + id110 + shift * 2, image_grad[index + width * height * batch * 2] * w110); 135 | atomicAdd(lut_grad + id001 + shift * 2, image_grad[index + width * height * batch * 2] * w001); 136 | atomicAdd(lut_grad + id101 + shift * 2, image_grad[index + width * height * batch * 2] * w101); 137 | atomicAdd(lut_grad + id011 + shift * 2, image_grad[index + width * height * batch * 2] * w011); 138 | atomicAdd(lut_grad + id111 + shift * 2, image_grad[index + width * height * batch * 2] * w111); 139 | } 140 | } 141 | 142 | int TriLinearBackwardLaucher(const float* image, const float* image_grad, float* lut_grad, const int lut_dim, const int shift, const float binsize, const int width, const int height, const int batch, cudaStream_t stream) { 143 | const int kThreadsPerBlock = 1024; 144 | const int output_size = height * width * batch; 145 | cudaError_t err; 146 | 147 | TriLinearBackward<<<(output_size + kThreadsPerBlock - 1) / kThreadsPerBlock, kThreadsPerBlock, 0, stream>>>(output_size, image, image_grad, lut_grad, lut_dim, shift, binsize, width, height, batch); 148 | 149 | err = cudaGetLastError(); 150 | if(cudaSuccess != err) { 151 | fprintf( stderr, "cudaCheckError() failed : %s\n", cudaGetErrorString( err ) ); 152 | exit( -1 ); 153 | } 154 | 155 | return 1; 156 | } 157 | -------------------------------------------------------------------------------- /trilinear_cpp/src/trilinear_kernel.h: -------------------------------------------------------------------------------- 1 | #ifndef _TRILINEAR_KERNEL 2 | #define _TRILINEAR_KERNEL 3 | 4 | #include 5 | 6 | __global__ void TriLinearForward(const int nthreads, const float* lut, const float* image, float* output, const int dim, const int shift, const float binsize, const int width, const int height, const int batch); 7 | 8 | int TriLinearForwardLaucher(const float* lut, const float* image, float* output, const int lut_dim, const int shift, const float binsize, const int width, const int height, const int batch, cudaStream_t stream); 9 | 10 | __global__ void TriLinearBackward(const int nthreads, const float* image, const float* image_grad, float* lut_grad, const int dim, const int shift, const float binsize, const int width, const int height, const int batch); 11 | 12 | int TriLinearBackwardLaucher(const float* image, const float* image_grad, float* lut_grad, const int lut_dim, const int shift, const float binsize, const int width, const int height, const int batch, cudaStream_t stream); 13 | 14 | 15 | #endif 16 | 17 | -------------------------------------------------------------------------------- /utils/generate_identity_3DLUT.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | 4 | def generate_identity_3DLUT(dim, output_file): 5 | step = 1.0 / (dim - 1) 6 | with open(output_file, 'w') as f: 7 | for k in range(dim): 8 | for j in range(dim): 9 | for i in range(dim): 10 | f.write('{:.6f} {:.6f} {:.6f}\n'.format( 11 | step * i, step * j, step * k)) 12 | 13 | 14 | if __name__ == '__main__': 15 | parser = argparse.ArgumentParser() 16 | parser.add_argument('-d', '--dim', type=int, required=True) 17 | 18 | args = parser.parse_args() 19 | 20 | output_file = 'IdentityLUT{}.txt'.format(args.dim) 21 | generate_identity_3DLUT(args.dim, output_file) 22 | -------------------------------------------------------------------------------- /utils/visualize_lut.py: 
-------------------------------------------------------------------------------- 1 | """ 2 | @author: Hongkai Zhang 3 | @contact: kevin.hkzhang@gmail.com 4 | """ 5 | 6 | import argparse 7 | 8 | import matplotlib.pyplot as plt 9 | import numpy as np 10 | import torch 11 | from mpl_toolkits.mplot3d import Axes3D 12 | 13 | 14 | def vis_lut(lut, lut_dim): 15 | step = 1.0 / (lut_dim - 1) 16 | fig = plt.figure() 17 | ax = fig.add_subplot(111, projection='3d') 18 | for b in range(lut_dim): 19 | for g in range(lut_dim): 20 | # vectorization for efficiency 21 | r = np.arange(lut_dim) 22 | ax.scatter(b * step * np.ones(lut_dim), 23 | g * step * np.ones(lut_dim), 24 | r * step, 25 | c=lut[b, g, r].numpy(), 26 | marker='o', 27 | alpha=1.0) 28 | ax.set_xlabel('B') 29 | ax.set_ylabel('G') 30 | ax.set_zlabel('R') 31 | plt.show() 32 | 33 | 34 | if __name__ == '__main__': 35 | parser = argparse.ArgumentParser() 36 | parser.add_argument('lut_path', type=str, help='path to the LUT') 37 | parser.add_argument('--lut_dim', 38 | type=int, 39 | default=33, 40 | help='dimension of the LUT') 41 | args = parser.parse_args() 42 | 43 | lut = torch.load(args.lut_path, map_location='cpu') 44 | lut_dim = args.lut_dim 45 | lut0, lut1, lut2 = [lut[str(i)]['LUT'] for i in range(3)] 46 | 47 | # convert [3, 17, 17, 17] to [17, 17, 17, 3] 48 | lut0 = lut0.permute(1, 2, 3, 0) 49 | lut1 = lut1.permute(1, 2, 3, 0) 50 | lut2 = lut2.permute(1, 2, 3, 0) 51 | 52 | # TODO: better ways for this process 53 | # normalization 54 | lut0 = (lut0 - lut0.min()) / (lut0.max() - lut0.min()) 55 | lut1 = (lut1 - lut1.min()) / (lut1.max() - lut1.min()) 56 | lut2 = (lut2 - lut2.min()) / (lut2.max() - lut2.min()) 57 | 58 | # visualize the LUT, take lut0 as an example 59 | vis_lut(lut0, lut_dim) 60 | -------------------------------------------------------------------------------- /visualization_lut/save_trained_luts.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | 4 | from models import * 5 | from datasets import * 6 | 7 | 8 | parser = argparse.ArgumentParser() 9 | parser.add_argument("--epoch", type=int, default=234, help="epoch to start training from") 10 | parser.add_argument("--model_dir", type=str, default="LUTs/paired/fiveK_480p_3LUT_sm_1e-4_mn_10", help="path to save model") 11 | opt = parser.parse_args() 12 | 13 | cuda = True if torch.cuda.is_available() else False 14 | 15 | criterion_pixelwise = torch.nn.MSELoss() 16 | # Initialize generator and discriminator 17 | LUT1 = Generator3DLUT_identity() 18 | LUT2 = Generator3DLUT_zero() 19 | LUT3 = Generator3DLUT_zero() 20 | #LUT4 = Generator3DLUT_2() 21 | #LUT5 = Generator3DLUT_2() 22 | 23 | 24 | # Load pretrained models 25 | LUTs = torch.load("saved_models/%s/LUTs_%d.pth" % (opt.model_dir, opt.epoch)) 26 | LUT1.load_state_dict(LUTs["1"]) 27 | LUT2.load_state_dict(LUTs["2"]) 28 | LUT3.load_state_dict(LUTs["3"]) 29 | #LUT4.load_state_dict(LUTs["4"]) 30 | #LUT5.load_state_dict(LUTs["5"]) 31 | LUT1.eval() 32 | LUT2.eval() 33 | LUT3.eval() 34 | #LUT4.eval() 35 | #LUT5.eval() 36 | 37 | 38 | f = open('visualization/learned_LUT_%d_1.txt'%opt.epoch,'a') 39 | for p in range(0,LUT1.LUT.shape[0]): 40 | for i in range(0,LUT1.LUT.shape[1]): 41 | for j in range(0,LUT1.LUT.shape[2]): 42 | for k in range(0,LUT1.LUT.shape[3]): 43 | f.write("%f\n"%LUT1.LUT[p,i,j,k].detach().numpy()) 44 | f.close() 45 | f = open('visualization/learned_LUT_%d_2.txt'%opt.epoch,'a') 46 | for p in range(0,LUT2.LUT.shape[0]): 47 | for i in 
range(0,LUT2.LUT.shape[1]): 48 | for j in range(0,LUT2.LUT.shape[2]): 49 | for k in range(0,LUT2.LUT.shape[3]): 50 | f.write("%f\n"%LUT2.LUT[p,i,j,k].detach().numpy()) 51 | f.close() 52 | f = open('visualization/learned_LUT_%d_3.txt'%opt.epoch,'a') 53 | for p in range(0,LUT3.LUT.shape[0]): 54 | for i in range(0,LUT3.LUT.shape[1]): 55 | for j in range(0,LUT3.LUT.shape[2]): 56 | for k in range(0,LUT3.LUT.shape[3]): 57 | f.write("%f\n"%LUT3.LUT[p,i,j,k].detach().numpy()) 58 | f.close() 59 | 60 | -------------------------------------------------------------------------------- /visualization_lut/visualize_lut.m: -------------------------------------------------------------------------------- 1 | n = 33; % dimension of the learned 3D LUTs; must be defined before the files are read 2 | LUT1 = []; 3 | LUT_name = ['visualization/learned_LUT_234_1.txt']; 4 | f = fopen(LUT_name,'r'); 5 | for i = 1:n^3*3 6 | LUT1(i) = str2double(fgetl(f)); 7 | end 8 | fclose(f); 9 | 10 | LUT2 = []; 11 | LUT_name = ['visualization/learned_LUT_234_2.txt']; 12 | f = fopen(LUT_name,'r'); 13 | for i = 1:n^3*3 14 | LUT2(i) = str2double(fgetl(f)); 15 | end 16 | fclose(f); 17 | 18 | LUT3 = []; 19 | LUT_name = ['visualization/learned_LUT_234_3.txt']; 20 | f = fopen(LUT_name,'r'); 21 | for i = 1:n^3*3 22 | LUT3(i) = str2double(fgetl(f)); 23 | end 24 | fclose(f); 25 | 26 | r0 = repmat(linspace(0,1,n)',1,n,n); 27 | g0 = repmat(linspace(0,1,n),n,1,n); 28 | b0 = repmat(reshape(linspace(0,1,n),[1,1,n]),n,n); 29 | 30 | %adaptive weight 1 31 | % a1 = 2.49; 32 | % a2 = -1.92; 33 | % a3 = -0.33; 34 | a1 = 1.85; 35 | a2 = -0.09; 36 | a3 = -0.91; 37 | LUT = LUT1 * a1 + LUT2 * a2 + LUT3 * a3; 38 | r = LUT(1:n^3);r = reshape(r,[n,n,n]); 39 | g = LUT(n^3+1:n^3*2);g = reshape(g,[n,n,n]); 40 | b = LUT(n^3*2+1:n^3*3);b = reshape(b,[n,n,n]); 41 | C = [r(:),g(:),b(:)]; 42 | figure(1); 43 | scatter3(r0(:),g0(:),b0(:),20,C,'filled'); 44 | 45 | % adaptive weight 2 46 | a1 = 1.59; 47 | a2 = 0.99; 48 | a3 = -1.18; 49 | LUT = LUT1 * a1 + LUT2 * a2 + LUT3 * a3; 50 | r = LUT(1:n^3);r = reshape(r,[n,n,n]); 51 | g = LUT(n^3+1:n^3*2);g = reshape(g,[n,n,n]); 52 | b = LUT(n^3*2+1:n^3*3);b = reshape(b,[n,n,n]); 53 | C = [r(:),g(:),b(:)]; 54 | figure(2); 55 | scatter3(r0(:),g0(:),b0(:),20,C,'filled'); 56 | 57 | % plot used in the paper 58 | n = 33; 59 | fontsize = 12; 60 | set(gca,'FontSize',fontsize) 61 | figure(3); 62 | for i = 1:8:n%[1,17,33] 63 | r_slice = squeeze(r(i,:,:)); 64 | g_slice = squeeze(g(:,i,:)); 65 | b_slice = squeeze(b(:,:,i)); 66 | r2 = repmat(linspace(0,1,n)',1,n); 67 | g2 = repmat(linspace(0,1,n),n,1); 68 | 69 | subplot(1,3,1) 70 | surface(g2,r2,r_slice);view(3) 71 | set(gca,'FontSize',fontsize) 72 | % view(90,0) 73 | % zlim([-0.3,0.9]); 74 | 75 | subplot(1,3,2) 76 | surface(g2,r2,g_slice);view(3) 77 | set(gca,'FontSize',fontsize) 78 | % view(90,0) 79 | % zlim([-0.3,0.9]); 80 | 81 | subplot(1,3,3) 82 | surface(g2,r2,b_slice);view(3) 83 | set(gca,'FontSize',fontsize) 84 | % view(90,0) 85 | % zlim([-0.3,0.9]); 86 | end 87 | 88 | --------------------------------------------------------------------------------
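For reference, the LUT fusion that visualize_lut.m performs (weighting the three learned basis LUTs by image-adaptive coefficients) can also be reproduced in Python. This is a hedged sketch: it assumes the text-file layout written by save_trained_luts.py (one float per line, channel block first, dim^3 values per channel block), a 33-point LUT, that the script is run next to the exported txt files, and it reuses the example weights a1/a2/a3 from visualize_lut.m; the helper names load_lut and fuse_luts are illustrative.

# Sketch: load the exported learned_LUT_*.txt files and fuse them with adaptive weights,
# mirroring the first "adaptive weight" example in visualize_lut.m.
import numpy as np

def load_lut(path, dim=33):
    values = np.loadtxt(path)        # one value per line, 3 * dim**3 values in total
    # The reshape follows the nesting of the write loops in save_trained_luts.py:
    # outermost index is the output channel, the remaining three are the LUT grid axes.
    return values.reshape(3, dim, dim, dim)

def fuse_luts(luts, weights):
    # Weighted sum of the basis LUTs; the weights would normally come from the classifier.
    return sum(w * lut for w, lut in zip(weights, luts))

if __name__ == '__main__':
    dim = 33
    luts = [load_lut('learned_LUT_234_%d.txt' % i, dim) for i in (1, 2, 3)]
    fused = fuse_luts(luts, (1.85, -0.09, -0.91))   # example weights taken from visualize_lut.m
    print(fused.shape)                              # (3, 33, 33, 33)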