├── IdentityLUT33.txt
├── IdentityLUT64.txt
├── LICENSE
├── README.md
├── average_psnr_ssim.m
├── datasets.py
├── demo_eval.py
├── demo_images
│   ├── XYZ
│   │   └── a1629.png
│   └── sRGB
│       └── a1629.jpg
├── figures
│   └── framework2.png
├── image_adaptive_lut_evaluation.py
├── image_adaptive_lut_train_paired.py
├── image_adaptive_lut_train_unpaired.py
├── local_tone_mapping
│   ├── a1509.jpg
│   ├── wlsFilter.m
│   └── wlsTonemap.m
├── models.py
├── models_x.py
├── pretrained_models
│   ├── XYZ
│   │   ├── LUTs.pth
│   │   ├── LUTs_unpaired.pth
│   │   ├── classifier.pth
│   │   └── classifier_unpaired.pth
│   └── sRGB
│       ├── LUTs.pth
│       ├── LUTs_unpaired.pth
│       ├── classifier.pth
│       └── classifier_unpaired.pth
├── requirements
├── ssim.m
├── torchvision_x_functional.py
├── trilinear_c
│   ├── build.py
│   ├── make.sh
│   └── src
│       ├── trilinear.c
│       ├── trilinear.h
│       ├── trilinear_cuda.c
│       ├── trilinear_cuda.h
│       ├── trilinear_kernel.cu
│       ├── trilinear_kernel.cu.o
│       └── trilinear_kernel.h
├── trilinear_cpp
│   ├── setup.py
│   ├── setup.sh
│   └── src
│       ├── trilinear.cpp
│       ├── trilinear.h
│       ├── trilinear_cuda.cpp
│       ├── trilinear_cuda.h
│       ├── trilinear_kernel.cu
│       └── trilinear_kernel.h
├── utils
│   ├── generate_identity_3DLUT.py
│   └── visualize_lut.py
└── visualization_lut
    ├── learned_LUT_234_1.txt
    ├── learned_LUT_234_2.txt
    ├── learned_LUT_234_3.txt
    ├── save_trained_luts.py
    └── visualize_lut.m
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Image-Adaptive-3DLUT
2 | Learning Image-adaptive 3D Lookup Tables for High Performance Photo Enhancement in Real-time
3 |
4 | ## Downloads
5 | ### [Paper](https://www4.comp.polyu.edu.hk/~cslzhang/paper/PAMI_LUT.pdf), [Supplementary](https://www4.comp.polyu.edu.hk/~cslzhang/paper/Supplement_LUT.pdf), Datasets ([Google Drive](https://drive.google.com/drive/folders/1Y1Rv3uGiJkP6CIrNTSKxPn1p-WFAc48a?usp=sharing), [OneDrive](https://connectpolyu-my.sharepoint.com/:f:/g/personal/16901447r_connect_polyu_hk/EqNGuQUKZe9Cv3fPG08OmGEBbHMUXey2aU03E21dFZwJyg?e=QNCMMZ), [Baidu Yun](https://pan.baidu.com/s/1CsQRFsEPZCSjkT3Z1X_B1w), code: 5fyk)
6 | The complete datasets used in the paper exceed 300 GB. Here I only provide the FiveK dataset resized to 480p resolution (including 8-bit sRGB and 16-bit XYZ inputs, and 8-bit sRGB targets), plus 10 full-resolution images for speed testing. To obtain all of the full-resolution images, it is recommended to convert them from the original [FiveK](https://data.csail.mit.edu/graphics/fivek/) dataset.
7 |
8 | A model trained at 480p resolution can be applied directly to images of 4K (or higher) resolution without a performance drop. This significantly speeds up training, since the very heavy high-resolution images never need to be loaded.
9 |
10 | ## Abstract
11 | Recent years have witnessed the increasing popularity of learning-based methods to enhance the color and tone of photos. However, many existing photo enhancement methods either deliver unsatisfactory results or consume excessive computational and memory resources, hindering their application to high-resolution images (usually with more than 12 megapixels) in practice. In this paper, we learn image-adaptive 3-dimensional lookup tables (3D LUTs) to achieve fast and robust photo enhancement. 3D LUTs are widely used for manipulating color and tone of photos, but they are usually manually tuned and fixed in the camera imaging pipeline or photo editing tools. To the best of our knowledge, we are the first to propose learning 3D LUTs from annotated data using pairwise or unpaired learning. More importantly, our learned 3D LUT is image-adaptive for flexible photo enhancement. We learn multiple basis 3D LUTs and a small convolutional neural network (CNN) simultaneously in an end-to-end manner. The small CNN works on a down-sampled version of the input image to predict content-dependent weights to fuse the multiple basis 3D LUTs into an image-adaptive one, which is employed to transform the color and tone of source images efficiently. Our model contains less than **600K** parameters and takes **less than 2 ms** to process an image of 4K resolution using one Titan RTX GPU. While being highly efficient, our model also outperforms the state-of-the-art photo enhancement methods by a large margin in terms of PSNR, SSIM and a color difference metric on two publicly available benchmark datasets.
12 |
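The fusion step described above is just a weighted sum of a few LUT tensors. Below is a minimal, self-contained sketch of that idea (hypothetical tensor names and shapes; in this repository the corresponding modules are `Classifier` and `Generator3DLUT_*` in `models.py`):

```
import torch

# Hypothetical example: 3 basis LUTs of dimension 33, each stored as a
# (3, D, D, D) tensor of output RGB values.
N, D = 3, 33
basis_luts = torch.stack([torch.rand(3, D, D, D) for _ in range(N)])  # (N, 3, D, D, D)

# Content-dependent weights, as the small CNN would predict from a
# down-sampled version of the input image.
weights = torch.tensor([0.7, 0.2, 0.1])

# Fuse the basis LUTs into a single image-adaptive LUT.
adaptive_lut = (weights.view(N, 1, 1, 1, 1) * basis_luts).sum(dim=0)  # (3, D, D, D)
```

The adaptive LUT is then applied to the full-resolution image by trilinear interpolation (the custom op in `trilinear_c`/`trilinear_cpp`, or the `grid_sample` alternative mentioned below).
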
13 | ## Framework
14 | ![Framework of the image-adaptive 3D LUT method](figures/framework2.png)
15 |
16 | ## Usage
17 |
18 | ### Useful issues
19 | The custom trilinear interpolation op can be replaced with `torch.nn.functional.grid_sample`; see [issue #14](https://github.com/HuiZeng/Image-Adaptive-3DLUT/issues/14). A minimal sketch of the idea is shown below.
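
A hedged sketch, assuming the LUT is stored as a `(3, D, D, D)` tensor whose last three axes are indexed by (blue, green, red) and the image values lie in `[0, 1]` (the repository's own layout may differ, so adapt the axis ordering accordingly):

```
import torch
import torch.nn.functional as F

def apply_lut_grid_sample(lut, img):
    # lut: (3, D, D, D), last three axes assumed indexed by (blue, green, red)
    # img: (N, 3, H, W), RGB values in [0, 1]
    n = img.size(0)
    # grid_sample expects coordinates in [-1, 1]; the last grid dim is (x, y, z),
    # which samples the (W, H, D) axes of the volume, i.e. (red, green, blue) here.
    grid = img.permute(0, 2, 3, 1).unsqueeze(1) * 2.0 - 1.0        # (N, 1, H, W, 3)
    lut_batch = lut.unsqueeze(0).expand(n, -1, -1, -1, -1)         # (N, 3, D, D, D)
    out = F.grid_sample(lut_batch, grid, mode="bilinear",
                        padding_mode="border", align_corners=True)  # (N, 3, 1, H, W)
    return out.squeeze(2)                                           # (N, 3, H, W)
```

With `mode="bilinear"` on a 5D input, `grid_sample` performs trilinear interpolation, which is exactly what the custom CUDA/C op computes.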
20 |
21 | ### Requirements
22 | Python 3; the required packages are listed in requirements.txt.
23 |
24 | ### Build
25 | By default, we use pytorch 0.4.1:
26 |
27 | cd trilinear_c
28 | sh make.sh
29 |
30 | For pytorch 1.x:
31 |
32 | cd trilinear_cpp
33 | sh setup.sh
34 |
35 | Please also replace the following lines:
36 | ```
37 | # in image_adaptive_lut_train_paired.py, image_adaptive_lut_evaluation.py, demo_eval.py, and image_adaptive_lut_train_unpaired.py
38 | from models import * --> from models_x import *
39 | # in demo_eval.py
40 | result = trilinear_(LUT, img) --> _, result = trilinear_(LUT, img)
41 | # in image_adaptive_lut_train_paired.py and image_adaptive_lut_evaluation.py
42 | combine_A = trilinear_(LUT,img) --> _, combine_A = trilinear_(LUT,img)
43 | ```
44 |
45 | ### Training
46 | #### paired training
47 | python3 image_adaptive_lut_train_paired.py
48 | #### unpaired training
49 | python3 image_adaptive_lut_train_unpaired.py
50 |
51 | ### Evaluation
52 | 1. Use Python to generate and save the test images:
53 |
54 | python3 image_adaptive_lut_evaluation.py
55 |
56 | Speed can also be measured with the same script (see its test_speed() function).
57 |
58 | 2. Use MATLAB to compute the metrics used in our paper (a rough Python equivalent is sketched below):
59 |
60 | average_psnr_ssim.m
61 |
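If MATLAB is not available, a rough Python equivalent of `average_psnr_ssim.m` can be sketched with scikit-image (>= 0.19). The paths below are placeholders, and the SSIM and Lab values may differ slightly from the MATLAB implementations:

```
import glob
import os

import numpy as np
from PIL import Image
from skimage.color import rgb2lab
from skimage.metrics import peak_signal_noise_ratio, structural_similarity

test_path = "images/your_result_dir"        # placeholder: directory of generated PNGs
gt_path = "../data/fiveK/expertC/JPG/480p"  # placeholder: expert-C ground truth

psnrs, ssims, deltas = [], [], []
for pred_file in sorted(glob.glob(os.path.join(test_path, "*.png"))):
    name = os.path.splitext(os.path.basename(pred_file))[0]
    pred = np.array(Image.open(pred_file).convert("RGB"))
    gt = np.array(Image.open(os.path.join(gt_path, name + ".jpg")).convert("RGB"))

    psnrs.append(peak_signal_noise_ratio(gt, pred, data_range=255))
    ssims.append(structural_similarity(gt, pred, channel_axis=2))
    # mean Euclidean distance in CIELAB, as in the MATLAB script
    diff = rgb2lab(gt / 255.0) - rgb2lab(pred / 255.0)
    deltas.append(np.sqrt((diff ** 2).sum(axis=2)).mean())

print("PSNR %.4f  SSIM %.4f  deltaE %.4f" % (np.mean(psnrs), np.mean(ssims), np.mean(deltas)))
```
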
62 | ### Demo
63 |
64 | python3 demo_eval.py
65 |
66 | ### Tools
67 | 1. You can generate an identity 3D LUT of arbitrary dimension using `utils/generate_identity_3DLUT.py` as follows (a minimal sketch of the underlying idea is shown after the command):
68 |
69 | ```
70 | # you can replace 33 with any number you want
71 | python3 utils/generate_identity_3DLUT.py -d 33
72 | ```
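
Conceptually, an identity 3D LUT simply maps every (r, g, b) grid point back to itself. A minimal NumPy sketch of that idea (the on-disk text layout expected by `models.py` may differ, so use the provided script to generate files for training):

```
import numpy as np

def identity_lut(dim=33):
    # Identity mapping: every (r, g, b) grid point maps to itself.
    # Returns shape (3, dim, dim, dim) with values in [0, 1].
    step = np.arange(dim, dtype=np.float32) / (dim - 1)
    b, g, r = np.meshgrid(step, step, step, indexing="ij")
    return np.stack([r, g, b])  # channel order (R, G, B); axes ordered (B, G, R)
```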
73 |
74 | 2. You can visualize a learned 3D LUT either with the MATLAB code in `visualization_lut` or with the Python script `utils/visualize_lut.py` as follows (a matplotlib sketch of the same idea is shown after the commands):
75 |
76 | ```
77 | python3 utils/visualize_lut.py path/to/your/lut
78 | # you can also modify the dimension of the lut as follows
79 | python3 utils/visualize_lut.py path/to/your/lut --lut_dim 64
80 | ```
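
For a quick sanity check, a LUT can also be eyeballed directly with matplotlib by plotting every input grid point coloured by the LUT's output colour. This is only a hedged sketch using an identity LUT; replace `out_rgb` with the values loaded from your own LUT (its exact text layout is defined by `utils/visualize_lut.py`):

```
import matplotlib.pyplot as plt
import numpy as np

D = 17
# Identity LUT for illustration: each grid point's output colour equals its input colour.
r, g, b = np.meshgrid(*(np.linspace(0.0, 1.0, D),) * 3, indexing="ij")
out_rgb = np.stack([r, g, b], axis=-1).reshape(-1, 3)  # output colours at the D^3 grid points

fig = plt.figure()
ax = fig.add_subplot(projection="3d")
# Plot each input grid point, coloured by the LUT's output colour at that point.
ax.scatter(r.ravel(), g.ravel(), b.ravel(), c=out_rgb, s=4)
ax.set_xlabel("input R")
ax.set_ylabel("input G")
ax.set_zlabel("input B")
plt.show()
```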
81 |
82 | ## Citation
83 | ```
84 | @article{zeng2020lut,
85 | title={Learning Image-adaptive 3D Lookup Tables for High Performance Photo Enhancement in Real-time},
86 | author={Zeng, Hui and Cai, Jianrui and Li, Lida and Cao, Zisheng and Zhang, Lei},
87 | journal={IEEE Transactions on Pattern Analysis and Machine Intelligence},
88 | volume={44},
89 | number={04},
90 | pages={2058--2073},
91 | year={2022},
92 | publisher={IEEE Computer Society}
93 | }
94 |
95 | @inproceedings{zhang2022clut,
96 | title={CLUT-Net: Learning Adaptively Compressed Representations of 3DLUTs for Lightweight Image Enhancement},
97 | author={Zhang, Fengyi and Zeng, Hui and Zhang, Tianjun and Zhang, Lin},
98 | booktitle={Proceedings of the 30th ACM International Conference on Multimedia},
99 | pages={6493--6501},
100 | year={2022}
101 | }
102 | ```
103 |
--------------------------------------------------------------------------------
/average_psnr_ssim.m:
--------------------------------------------------------------------------------
1 | test_path = 'images/LUTs/paired/fiveK_480p_3LUT_sm_1e-4_mn_10_sRGB_145';
2 | gt_path = '../data/fiveK/expertC/JPG/480p';
3 | path_list = dir(fullfile(test_path,'*.png'));
4 | img_num = length(path_list);
5 | %calculate psnr
6 | total_psnr = 0;
7 | total_ssim = 0;
8 | total_color = 0;
9 | if img_num > 0
10 | for j = 1:img_num
11 | image_name = path_list(j).name;
12 | input = imread(fullfile(test_path,image_name));
13 | gt = imread(fullfile(gt_path,[image_name(1:end-3), 'jpg']));
14 |
15 | psnr_val = psnr(im2double(input), im2double(gt));
16 | total_psnr = total_psnr + psnr_val;
17 |
18 | ssim_val = ssim(input, gt);
19 | total_ssim = total_ssim + ssim_val;
20 |
21 | color = sqrt(sum((rgb2lab(gt) - rgb2lab(input)).^2,3));
22 | color = mean(color(:));
23 | total_color = total_color + color;
24 | fprintf('%d %f %f %f %s\n',j,psnr_val,ssim_val,color,fullfile(test_path,image_name));
25 | end
26 | end
27 | qm_psnr = total_psnr / img_num;
28 | avg_ssim = total_ssim / img_num;
29 | avg_color = total_color / img_num;
30 | fprintf('The average psnr is: %f\n', qm_psnr);
31 | fprintf('The average ssim is: %f\n', avg_ssim);
32 | fprintf('The average Lab color difference is: %f\n', avg_color);
33 |
--------------------------------------------------------------------------------
/datasets.py:
--------------------------------------------------------------------------------
1 | import glob
2 | import random
3 | import os
4 | import numpy as np
5 | import torch
6 | import cv2
7 |
8 | from torch.utils.data import Dataset
9 | from PIL import Image
10 | import torchvision.transforms as transforms
11 | import torchvision.transforms.functional as TF
12 | import torchvision_x_functional as TF_x
13 |
14 |
15 | class ImageDataset_sRGB(Dataset):
16 | def __init__(self, root, mode="train", unpaird_data="fiveK", combined=True):
17 | self.mode = mode
18 | self.unpaird_data = unpaird_data
19 |
20 | file = open(os.path.join(root,'train_input.txt'),'r')
21 | set1_input_files = sorted(file.readlines())
22 | self.set1_input_files = list()
23 | self.set1_expert_files = list()
24 | for i in range(len(set1_input_files)):
25 | self.set1_input_files.append(os.path.join(root,"input","JPG/480p",set1_input_files[i][:-1] + ".jpg"))
26 | self.set1_expert_files.append(os.path.join(root,"expertC","JPG/480p",set1_input_files[i][:-1] + ".jpg"))
27 |
28 | file = open(os.path.join(root,'train_label.txt'),'r')
29 | set2_input_files = sorted(file.readlines())
30 | self.set2_input_files = list()
31 | self.set2_expert_files = list()
32 | for i in range(len(set2_input_files)):
33 | self.set2_input_files.append(os.path.join(root,"input","JPG/480p",set2_input_files[i][:-1] + ".jpg"))
34 | self.set2_expert_files.append(os.path.join(root,"expertC","JPG/480p",set2_input_files[i][:-1] + ".jpg"))
35 |
36 | file = open(os.path.join(root,'test.txt'),'r')
37 | test_input_files = sorted(file.readlines())
38 | self.test_input_files = list()
39 | self.test_expert_files = list()
40 | for i in range(len(test_input_files)):
41 | self.test_input_files.append(os.path.join(root,"input","JPG/480p",test_input_files[i][:-1] + ".jpg"))
42 | self.test_expert_files.append(os.path.join(root,"expertC","JPG/480p",test_input_files[i][:-1] + ".jpg"))
43 |
44 | if combined:
45 | self.set1_input_files = self.set1_input_files + self.set2_input_files
46 | self.set1_expert_files = self.set1_expert_files + self.set2_expert_files
47 |
48 |
49 | def __getitem__(self, index):
50 |
51 | if self.mode == "train":
52 | img_name = os.path.split(self.set1_input_files[index % len(self.set1_input_files)])[-1]
53 | img_input = Image.open(self.set1_input_files[index % len(self.set1_input_files)])
54 | img_exptC = Image.open(self.set1_expert_files[index % len(self.set1_expert_files)])
55 |
56 | elif self.mode == "test":
57 | img_name = os.path.split(self.test_input_files[index % len(self.test_input_files)])[-1]
58 | img_input = Image.open(self.test_input_files[index % len(self.test_input_files)])
59 | img_exptC = Image.open(self.test_expert_files[index % len(self.test_expert_files)])
60 |
61 | if self.mode == "train":
62 |
63 | ratio_H = np.random.uniform(0.6,1.0)
64 | ratio_W = np.random.uniform(0.6,1.0)
65 | W,H = img_input._size
66 | crop_h = round(H*ratio_H)
67 | crop_w = round(W*ratio_W)
68 | i, j, h, w = transforms.RandomCrop.get_params(img_input, output_size=(crop_h, crop_w))
69 | img_input = TF.crop(img_input, i, j, h, w)
70 | img_exptC = TF.crop(img_exptC, i, j, h, w)
71 | #img_input = TF.resized_crop(img_input, i, j, h, w, (320,320))
72 | #img_exptC = TF.resized_crop(img_exptC, i, j, h, w, (320,320))
73 |
74 | if np.random.random() > 0.5:
75 | img_input = TF.hflip(img_input)
76 | img_exptC = TF.hflip(img_exptC)
77 |
78 | a = np.random.uniform(0.8,1.2)
79 | img_input = TF.adjust_brightness(img_input,a)
80 |
81 | a = np.random.uniform(0.8,1.2)
82 | img_input = TF.adjust_saturation(img_input,a)
83 |
84 | img_input = TF.to_tensor(img_input)
85 | img_exptC = TF.to_tensor(img_exptC)
86 |
87 | return {"A_input": img_input, "A_exptC": img_exptC, "input_name": img_name}
88 |
89 | def __len__(self):
90 | if self.mode == "train":
91 | return len(self.set1_input_files)
92 | elif self.mode == "test":
93 | return len(self.test_input_files)
94 |
95 |
96 | class ImageDataset_XYZ(Dataset):
97 | def __init__(self, root, mode="train", unpaird_data="fiveK", combined=True):
98 | self.mode = mode
99 |
100 | file = open(os.path.join(root,'train_input.txt'),'r')
101 | set1_input_files = sorted(file.readlines())
102 | self.set1_input_files = list()
103 | self.set1_expert_files = list()
104 | for i in range(len(set1_input_files)):
105 | self.set1_input_files.append(os.path.join(root,"input","PNG/480p_16bits_XYZ_WB",set1_input_files[i][:-1] + ".png"))
106 | self.set1_expert_files.append(os.path.join(root,"expertC","JPG/480p",set1_input_files[i][:-1] + ".jpg"))
107 |
108 | file = open(os.path.join(root,'train_label.txt'),'r')
109 | set2_input_files = sorted(file.readlines())
110 | self.set2_input_files = list()
111 | self.set2_expert_files = list()
112 | for i in range(len(set2_input_files)):
113 | self.set2_input_files.append(os.path.join(root,"input","PNG/480p_16bits_XYZ_WB",set2_input_files[i][:-1] + ".png"))
114 | self.set2_expert_files.append(os.path.join(root,"expertC","JPG/480p",set2_input_files[i][:-1] + ".jpg"))
115 |
116 | file = open(os.path.join(root,'test.txt'),'r')
117 | test_input_files = sorted(file.readlines())
118 | self.test_input_files = list()
119 | self.test_expert_files = list()
120 | for i in range(len(test_input_files)):
121 | self.test_input_files.append(os.path.join(root,"input","PNG/480p_16bits_XYZ_WB",test_input_files[i][:-1] + ".png"))
122 | self.test_expert_files.append(os.path.join(root,"expertC","JPG/480p",test_input_files[i][:-1] + ".jpg"))
123 |
124 | if combined:
125 | self.set1_input_files = self.set1_input_files + self.set2_input_files
126 | self.set1_expert_files = self.set1_expert_files + self.set2_expert_files
127 |
128 |
129 | def __getitem__(self, index):
130 |
131 | if self.mode == "train":
132 | img_name = os.path.split(self.set1_input_files[index % len(self.set1_input_files)])[-1]
133 | img_input = cv2.imread(self.set1_input_files[index % len(self.set1_input_files)],-1)
134 | img_exptC = Image.open(self.set1_expert_files[index % len(self.set1_expert_files)])
135 |
136 | elif self.mode == "test":
137 | img_name = os.path.split(self.test_input_files[index % len(self.test_input_files)])[-1]
138 | img_input = cv2.imread(self.test_input_files[index % len(self.test_input_files)],-1)
139 | img_exptC = Image.open(self.test_expert_files[index % len(self.test_expert_files)])
140 |
141 | img_input = np.array(img_input)
142 | #img_input = np.array(cv2.cvtColor(img_input,cv2.COLOR_BGR2RGB))
143 |
144 | if self.mode == "train":
145 |
146 | ratio_H = np.random.uniform(0.6,1.0)
147 | ratio_W = np.random.uniform(0.6,1.0)
148 | W,H = img_exptC._size
149 | crop_h = round(H*ratio_H)
150 | crop_w = round(W*ratio_W)
151 | i, j, h, w = transforms.RandomCrop.get_params(img_exptC, output_size=(crop_h, crop_w))
152 | img_input = TF_x.crop(img_input, i, j, h, w)
153 | img_exptC = TF.crop(img_exptC, i, j, h, w)
154 |
155 | if np.random.random() > 0.5:
156 | img_input = TF_x.hflip(img_input)
157 | img_exptC = TF.hflip(img_exptC)
158 |
159 | a = np.random.uniform(0.6,1.4)
160 | img_input = TF_x.adjust_brightness(img_input,a)
161 |
162 | img_input = TF_x.to_tensor(img_input)
163 | img_exptC = TF.to_tensor(img_exptC)
164 |
165 | return {"A_input": img_input, "A_exptC": img_exptC, "input_name": img_name}
166 |
167 | def __len__(self):
168 | if self.mode == "train":
169 | return len(self.set1_input_files)
170 | elif self.mode == "test":
171 | return len(self.test_input_files)
172 |
173 | class ImageDataset_sRGB_unpaired(Dataset):
174 | def __init__(self, root, mode="train", unpaird_data="fiveK"):
175 | self.mode = mode
176 | self.unpaird_data = unpaird_data
177 |
178 | file = open(os.path.join(root,'train_input.txt'),'r')
179 | set1_input_files = sorted(file.readlines())
180 | self.set1_input_files = list()
181 | self.set1_expert_files = list()
182 | for i in range(len(set1_input_files)):
183 | self.set1_input_files.append(os.path.join(root,"input","JPG/480p",set1_input_files[i][:-1] + ".jpg"))
184 | self.set1_expert_files.append(os.path.join(root,"expertC","JPG/480p",set1_input_files[i][:-1] + ".jpg"))
185 |
186 | file = open(os.path.join(root,'train_label.txt'),'r')
187 | set2_input_files = sorted(file.readlines())
188 | self.set2_input_files = list()
189 | self.set2_expert_files = list()
190 | for i in range(len(set2_input_files)):
191 | self.set2_input_files.append(os.path.join(root,"input","JPG/480p",set2_input_files[i][:-1] + ".jpg"))
192 | self.set2_expert_files.append(os.path.join(root,"expertC","JPG/480p",set2_input_files[i][:-1] + ".jpg"))
193 |
194 | file = open(os.path.join(root,'test.txt'),'r')
195 | test_input_files = sorted(file.readlines())
196 | self.test_input_files = list()
197 | self.test_expert_files = list()
198 | for i in range(len(test_input_files)):
199 | self.test_input_files.append(os.path.join(root,"input","JPG/480p",test_input_files[i][:-1] + ".jpg"))
200 | self.test_expert_files.append(os.path.join(root,"expertC","JPG/480p",test_input_files[i][:-1] + ".jpg"))
201 |
202 |
203 | def __getitem__(self, index):
204 |
205 | if self.mode == "train":
206 | img_name = os.path.split(self.set1_input_files[index % len(self.set1_input_files)])[-1]
207 | img_input = Image.open(self.set1_input_files[index % len(self.set1_input_files)])
208 | img_exptC = Image.open(self.set1_expert_files[index % len(self.set1_expert_files)])
209 | seed = random.randint(1,len(self.set2_expert_files))
210 | img2 = Image.open(self.set2_expert_files[(index + seed) % len(self.set2_expert_files)])
211 |
212 | elif self.mode == "test":
213 | img_name = os.path.split(self.test_input_files[index % len(self.test_input_files)])[-1]
214 | img_input = Image.open(self.test_input_files[index % len(self.test_input_files)])
215 | img_exptC = Image.open(self.test_expert_files[index % len(self.test_expert_files)])
216 | img2 = img_exptC
217 |
218 | if self.mode == "train":
219 | ratio_H = np.random.uniform(0.6,1.0)
220 | ratio_W = np.random.uniform(0.6,1.0)
221 | W,H = img_input._size
222 | crop_h = round(H*ratio_H)
223 | crop_w = round(W*ratio_W)
224 | W2,H2 = img2._size
225 | crop_h = min(crop_h,H2)
226 | crop_w = min(crop_w,W2)
227 | i, j, h, w = transforms.RandomCrop.get_params(img_input, output_size=(crop_h, crop_w))
228 | img_input = TF.crop(img_input, i, j, h, w)
229 | img_exptC = TF.crop(img_exptC, i, j, h, w)
230 | i, j, h, w = transforms.RandomCrop.get_params(img2, output_size=(crop_h, crop_w))
231 | img2 = TF.crop(img2, i, j, h, w)
232 |
233 | if np.random.random() > 0.5:
234 | img_input = TF.hflip(img_input)
235 | img_exptC = TF.hflip(img_exptC)
236 |
237 | if np.random.random() > 0.5:
238 | img2 = TF.hflip(img2)
239 |
240 | #if np.random.random() > 0.5:
241 | # img_input = TF.vflip(img_input)
242 | # img_exptC = TF.vflip(img_exptC)
243 | # img2 = TF.vflip(img2)
244 |
245 | a = np.random.uniform(0.6,1.4)
246 | img_input = TF.adjust_brightness(img_input,a)
247 |
248 | a = np.random.uniform(0.8,1.2)
249 | img_input = TF.adjust_saturation(img_input,a)
250 |
251 |
252 | img_input = TF.to_tensor(img_input)
253 | img_exptC = TF.to_tensor(img_exptC)
254 | img2 = TF.to_tensor(img2)
255 |
256 | return {"A_input": img_input, "A_exptC": img_exptC, "B_exptC": img2, "input_name": img_name}
257 |
258 | def __len__(self):
259 | if self.mode == "train":
260 | return len(self.set1_input_files)
261 | elif self.mode == "test":
262 | return len(self.test_input_files)
263 |
264 |
265 | class ImageDataset_XYZ_unpaired(Dataset):
266 | def __init__(self, root, mode="train", unpaird_data="fiveK"):
267 | self.mode = mode
268 | self.unpaird_data = unpaird_data
269 |
270 | file = open(os.path.join(root,'train_input.txt'),'r')
271 | set1_input_files = sorted(file.readlines())
272 | self.set1_input_files = list()
273 | self.set1_expert_files = list()
274 | for i in range(len(set1_input_files)):
275 | self.set1_input_files.append(os.path.join(root,"input","PNG/480p_16bits_XYZ_WB",set1_input_files[i][:-1] + ".png"))
276 | self.set1_expert_files.append(os.path.join(root,"expertC","JPG/480p",set1_input_files[i][:-1] + ".jpg"))
277 |
278 | file = open(os.path.join(root,'train_label.txt'),'r')
279 | set2_input_files = sorted(file.readlines())
280 | self.set2_input_files = list()
281 | self.set2_expert_files = list()
282 | for i in range(len(set2_input_files)):
283 | self.set2_input_files.append(os.path.join(root,"input","PNG/480p_16bits_XYZ_WB",set2_input_files[i][:-1] + ".png"))
284 | self.set2_expert_files.append(os.path.join(root,"expertC","JPG/480p",set2_input_files[i][:-1] + ".jpg"))
285 |
286 | file = open(os.path.join(root,'test.txt'),'r')
287 | test_input_files = sorted(file.readlines())
288 | self.test_input_files = list()
289 | self.test_expert_files = list()
290 | for i in range(len(test_input_files)):
291 | self.test_input_files.append(os.path.join(root,"input","PNG/480p_16bits_XYZ_WB",test_input_files[i][:-1] + ".png"))
292 | self.test_expert_files.append(os.path.join(root,"expertC","JPG/480p",test_input_files[i][:-1] + ".jpg"))
293 |
294 |
295 | def __getitem__(self, index):
296 |
297 | if self.mode == "train":
298 | img_name = os.path.split(self.set1_input_files[index % len(self.set1_input_files)])[-1]
299 | img_input = cv2.imread(self.set1_input_files[index % len(self.set1_input_files)],-1)
300 | img_exptC = Image.open(self.set1_expert_files[index % len(self.set1_expert_files)])
301 | seed = random.randint(1,len(self.set2_expert_files))
302 | img2 = Image.open(self.set2_expert_files[(index + seed) % len(self.set2_expert_files)])
303 |
304 | elif self.mode == "test":
305 | img_name = os.path.split(self.test_input_files[index % len(self.test_input_files)])[-1]
306 | img_input = cv2.imread(self.test_input_files[index % len(self.test_input_files)],-1)
307 | img_exptC = Image.open(self.test_expert_files[index % len(self.test_expert_files)])
308 | img2 = img_exptC
309 |
310 | img_input = np.array(img_input)
311 | #img_input = np.array(cv2.cvtColor(img_input,cv2.COLOR_BGR2RGB))
312 |
313 | if self.mode == "train":
314 | ratio_H = np.random.uniform(0.6,1.0)
315 | ratio_W = np.random.uniform(0.6,1.0)
316 | W,H = img_exptC._size
317 | crop_h = round(H*ratio_H)
318 | crop_w = round(W*ratio_W)
319 | W2,H2 = img2._size
320 | crop_h = min(crop_h,H2)
321 | crop_w = min(crop_w,W2)
322 | i, j, h, w = transforms.RandomCrop.get_params(img_exptC, output_size=(crop_h, crop_w))
323 | img_input = TF_x.crop(img_input, i, j, h, w)
324 | img_exptC = TF.crop(img_exptC, i, j, h, w)
325 | i, j, h, w = transforms.RandomCrop.get_params(img2, output_size=(crop_h, crop_w))
326 | img2 = TF.crop(img2, i, j, h, w)
327 |
328 | if np.random.random() > 0.5:
329 | img_input = TF_x.hflip(img_input)
330 | img_exptC = TF.hflip(img_exptC)
331 |
332 | if np.random.random() > 0.5:
333 | img2 = TF.hflip(img2)
334 |
335 | a = np.random.uniform(0.6,1.4)
336 | img_input = TF_x.adjust_brightness(img_input,a)
337 |
338 | img_input = TF_x.to_tensor(img_input)
339 | img_exptC = TF.to_tensor(img_exptC)
340 | img2 = TF.to_tensor(img2)
341 |
342 | return {"A_input": img_input, "A_exptC": img_exptC, "B_exptC": img2, "input_name": img_name}
343 |
344 | def __len__(self):
345 | if self.mode == "train":
346 | return len(self.set1_input_files)
347 | elif self.mode == "test":
348 | return len(self.test_input_files)
349 |
350 |
351 | class ImageDataset_HDRplus(Dataset):
352 | def __init__(self, root, mode="train", combined=True):
353 | self.mode = mode
354 |
355 | file = open(os.path.join(root,'train.txt'),'r')
356 | set1_input_files = sorted(file.readlines())
357 | self.set1_input_files = list()
358 | self.set1_expert_files = list()
359 | for i in range(len(set1_input_files)):
360 | self.set1_input_files.append(os.path.join(root,"middle_480p",set1_input_files[i][:-1] + ".png"))
361 | self.set1_expert_files.append(os.path.join(root,"output_480p",set1_input_files[i][:-1] + ".jpg"))
362 |
363 | file = open(os.path.join(root,'test.txt'),'r')
364 | test_input_files = sorted(file.readlines())
365 | self.test_input_files = list()
366 | self.test_expert_files = list()
367 | for i in range(len(test_input_files)):
368 | self.test_input_files.append(os.path.join(root,"middle_480p",test_input_files[i][:-1] + ".png"))
369 | self.test_expert_files.append(os.path.join(root,"output_480p",test_input_files[i][:-1] + ".jpg"))
370 |
371 |
372 | def __getitem__(self, index):
373 |
374 | if self.mode == "train":
375 | img_name = os.path.split(self.set1_input_files[index % len(self.set1_input_files)])[-1]
376 | img_input = cv2.imread(self.set1_input_files[index % len(self.set1_input_files)],-1)
377 | img_exptC = Image.open(self.set1_expert_files[index % len(self.set1_expert_files)])
378 |
379 | elif self.mode == "test":
380 | img_name = os.path.split(self.test_input_files[index % len(self.test_input_files)])[-1]
381 | img_input = cv2.imread(self.test_input_files[index % len(self.test_input_files)],-1)
382 | img_exptC = Image.open(self.test_expert_files[index % len(self.test_expert_files)])
383 |
384 | img_input = np.array(img_input)
385 | #img_input = np.array(cv2.cvtColor(img_input,cv2.COLOR_BGR2RGB))
386 |
387 | if self.mode == "train":
388 |
389 | ratio = np.random.uniform(0.6,1.0)
390 | W,H = img_exptC._size
391 | crop_h = round(H*ratio)
392 | crop_w = round(W*ratio)
393 | i, j, h, w = transforms.RandomCrop.get_params(img_exptC, output_size=(crop_h, crop_w))
394 | try:
395 | img_input = TF_x.crop(img_input, i, j, h, w)
396 |         except Exception:
397 |             print(crop_h, crop_w, img_input.shape)  # shape is an attribute, not a method
398 | img_exptC = TF.crop(img_exptC, i, j, h, w)
399 |
400 | if np.random.random() > 0.5:
401 | img_input = TF_x.hflip(img_input)
402 | img_exptC = TF.hflip(img_exptC)
403 |
404 | a = np.random.uniform(0.6,1.4)
405 | img_input = TF_x.adjust_brightness(img_input,a)
406 |
407 | #a = np.random.uniform(0.8,1.2)
408 | #img_input = TF_x.adjust_saturation(img_input,a)
409 |
410 | img_input = TF_x.to_tensor(img_input)
411 | img_exptC = TF.to_tensor(img_exptC)
412 |
413 | return {"A_input": img_input, "A_exptC": img_exptC, "input_name": img_name}
414 |
415 | def __len__(self):
416 | if self.mode == "train":
417 | return len(self.set1_input_files)
418 | elif self.mode == "test":
419 | return len(self.test_input_files)
420 |
421 | class ImageDataset_HDRplus_unpaired(Dataset):
422 | def __init__(self, root, mode="train"):
423 | self.mode = mode
424 |
425 | file = open(os.path.join(root,'train.txt'),'r')
426 | set1_input_files = sorted(file.readlines())
427 | self.set1_input_files = list()
428 | self.set1_expert_files = list()
429 | for i in range(len(set1_input_files)):
430 | self.set1_input_files.append(os.path.join(root,"middle_480p",set1_input_files[i][:-1] + ".png"))
431 | self.set1_expert_files.append(os.path.join(root,"output_480p",set1_input_files[i][:-1] + ".jpg"))
432 |
433 | file = open(os.path.join(root,'train.txt'),'r')
434 | set2_input_files = sorted(file.readlines())
435 | self.set2_input_files = list()
436 | self.set2_expert_files = list()
437 | for i in range(len(set2_input_files)):
438 | self.set2_input_files.append(os.path.join(root,"middle_480p",set2_input_files[i][:-1] + ".png"))
439 | self.set2_expert_files.append(os.path.join(root,"output_480p",set2_input_files[i][:-1] + ".jpg"))
440 |
441 | file = open(os.path.join(root,'test.txt'),'r')
442 | test_input_files = sorted(file.readlines())
443 | self.test_input_files = list()
444 | self.test_expert_files = list()
445 | for i in range(len(test_input_files)):
446 | self.test_input_files.append(os.path.join(root,"middle_480p",test_input_files[i][:-1] + ".png"))
447 | self.test_expert_files.append(os.path.join(root,"output_480p",test_input_files[i][:-1] + ".jpg"))
448 |
449 |
450 | def __getitem__(self, index):
451 |
452 | if self.mode == "train":
453 | img_name = os.path.split(self.set1_input_files[index % len(self.set1_input_files)])[-1]
454 | img_input = cv2.imread(self.set1_input_files[index % len(self.set1_input_files)],-1)
455 | img_exptC = Image.open(self.set1_expert_files[index % len(self.set1_expert_files)])
456 | seed = random.randint(1,len(self.set2_expert_files))
457 | img2 = Image.open(self.set2_expert_files[(index + seed) % len(self.set2_expert_files)])
458 |
459 | elif self.mode == "test":
460 | img_name = os.path.split(self.test_input_files[index % len(self.test_input_files)])[-1]
461 | img_input = cv2.imread(self.test_input_files[index % len(self.test_input_files)],-1)
462 | img_exptC = Image.open(self.test_expert_files[index % len(self.test_expert_files)])
463 | img2 = img_exptC
464 |
465 | img_input = np.array(img_input)
466 | #img_input = np.array(cv2.cvtColor(img_input,cv2.COLOR_BGR2RGB))
467 |
468 | if self.mode == "train":
469 | ratio = np.random.uniform(0.6,1.0)
470 | W,H = img_exptC._size
471 | crop_h = round(H*ratio)
472 | crop_w = round(W*ratio)
473 | W2,H2 = img2._size
474 | crop_h = min(crop_h,H2)
475 | crop_w = min(crop_w,W2)
476 | i, j, h, w = transforms.RandomCrop.get_params(img_exptC, output_size=(crop_h, crop_w))
477 | img_input = TF_x.crop(img_input, i, j, h, w)
478 | img_exptC = TF.crop(img_exptC, i, j, h, w)
479 | i, j, h, w = transforms.RandomCrop.get_params(img2, output_size=(crop_h, crop_w))
480 | img2 = TF.crop(img2, i, j, h, w)
481 |
482 | if np.random.random() > 0.5:
483 | img_input = TF_x.hflip(img_input)
484 | img_exptC = TF.hflip(img_exptC)
485 |
486 | if np.random.random() > 0.5:
487 | img2 = TF.hflip(img2)
488 |
489 | a = np.random.uniform(0.8,1.2)
490 | img_input = TF_x.adjust_brightness(img_input,a)
491 |
492 | img_input = TF_x.to_tensor(img_input)
493 | img_exptC = TF.to_tensor(img_exptC)
494 | img2 = TF.to_tensor(img2)
495 |
496 | return {"A_input": img_input, "A_exptC": img_exptC, "B_exptC": img2, "input_name": img_name}
497 |
498 | def __len__(self):
499 | if self.mode == "train":
500 | return len(self.set1_input_files)
501 | elif self.mode == "test":
502 | return len(self.test_input_files)
503 |
--------------------------------------------------------------------------------
/demo_eval.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import torch
3 | import os
4 | import numpy as np
5 | import cv2
6 | from PIL import Image
7 |
8 | from models import *
9 | import torchvision_x_functional as TF_x
10 | import torchvision.transforms.functional as TF
11 |
12 |
13 | parser = argparse.ArgumentParser()
14 |
15 | parser.add_argument("--image_dir", type=str, default="demo_images", help="directory of image")
16 | parser.add_argument("--image_name", type=str, default="a1629.jpg", help="name of image")
17 | parser.add_argument("--input_color_space", type=str, default="sRGB", help="input color space: sRGB or XYZ")
18 | parser.add_argument("--model_dir", type=str, default="pretrained_models", help="directory of pretrained models")
19 | parser.add_argument("--output_dir", type=str, default="demo_results", help="directory to save results")
20 | opt = parser.parse_args()
21 | opt.model_dir = opt.model_dir + '/' + opt.input_color_space
22 | opt.image_path = opt.image_dir + '/' + opt.input_color_space + '/' + opt.image_name
23 | os.makedirs(opt.output_dir, exist_ok=True)
24 |
25 | # use GPU when CUDA is available
26 | cuda = True if torch.cuda.is_available() else False
27 | # Tensor type
28 | Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor
29 |
30 | criterion_pixelwise = torch.nn.MSELoss()
31 | LUT0 = Generator3DLUT_identity()
32 | LUT1 = Generator3DLUT_zero()
33 | LUT2 = Generator3DLUT_zero()
34 | #LUT3 = Generator3DLUT_zero()
35 | #LUT4 = Generator3DLUT_zero()
36 | classifier = Classifier()
37 | trilinear_ = TrilinearInterpolation()
38 |
39 | if cuda:
40 | LUT0 = LUT0.cuda()
41 | LUT1 = LUT1.cuda()
42 | LUT2 = LUT2.cuda()
43 | #LUT3 = LUT3.cuda()
44 | #LUT4 = LUT4.cuda()
45 | classifier = classifier.cuda()
46 | criterion_pixelwise.cuda()
47 |
48 | # Load pretrained models
49 | LUTs = torch.load("%s/LUTs.pth" % opt.model_dir)
50 | LUT0.load_state_dict(LUTs["0"])
51 | LUT1.load_state_dict(LUTs["1"])
52 | LUT2.load_state_dict(LUTs["2"])
53 | #LUT3.load_state_dict(LUTs["3"])
54 | #LUT4.load_state_dict(LUTs["4"])
55 | LUT0.eval()
56 | LUT1.eval()
57 | LUT2.eval()
58 | #LUT3.eval()
59 | #LUT4.eval()
60 | classifier.load_state_dict(torch.load("%s/classifier.pth" % opt.model_dir))
61 | classifier.eval()
62 |
63 |
64 | def generate_LUT(img):
65 |
66 | pred = classifier(img).squeeze()
67 |
68 | LUT = pred[0] * LUT0.LUT + pred[1] * LUT1.LUT + pred[2] * LUT2.LUT #+ pred[3] * LUT3.LUT + pred[4] * LUT4.LUT
69 |
70 | return LUT
71 |
72 |
73 | # ----------
74 | # test
75 | # ----------
76 | # read image and transform to tensor
77 | if opt.input_color_space == 'sRGB':
78 | img = Image.open(opt.image_path)
79 | img = TF.to_tensor(img).type(Tensor)
80 | elif opt.input_color_space == 'XYZ':
81 | img = cv2.imread(opt.image_path, -1)
82 | img = np.array(img)
83 | img = TF_x.to_tensor(img).type(Tensor)
84 | img = img.unsqueeze(0)
85 |
86 | LUT = generate_LUT(img)
87 |
88 | # generate image
89 | result = trilinear_(LUT, img)
90 |
91 | # save image
92 | ndarr = result.squeeze().mul_(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to('cpu', torch.uint8).numpy()
93 | im = Image.fromarray(ndarr)
94 | im.save('%s/result.jpg' % opt.output_dir, quality=95)
95 |
96 |
97 |
--------------------------------------------------------------------------------
/demo_images/XYZ/a1629.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HuiZeng/Image-Adaptive-3DLUT/b491f6df64a588864739a157db271e5c848e1805/demo_images/XYZ/a1629.png
--------------------------------------------------------------------------------
/demo_images/sRGB/a1629.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HuiZeng/Image-Adaptive-3DLUT/b491f6df64a588864739a157db271e5c848e1805/demo_images/sRGB/a1629.jpg
--------------------------------------------------------------------------------
/figures/framework2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HuiZeng/Image-Adaptive-3DLUT/b491f6df64a588864739a157db271e5c848e1805/figures/framework2.png
--------------------------------------------------------------------------------
/image_adaptive_lut_evaluation.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import time
3 | import torch
4 | from torchvision.utils import save_image
5 | from torch.utils.data import DataLoader
6 | from torch.autograd import Variable
7 |
8 | from models import *
9 | from datasets import *
10 |
11 |
12 | parser = argparse.ArgumentParser()
13 | parser.add_argument("--epoch", type=int, default=145, help="epoch to load the saved checkpoint")
14 | parser.add_argument("--dataset_name", type=str, default="fiveK", help="name of the dataset")
15 | parser.add_argument("--input_color_space", type=str, default="sRGB", help="input color space: sRGB or XYZ")
16 | parser.add_argument("--model_dir", type=str, default="LUTs/paired/fiveK_480p_3LUT_sm_1e-4_mn_10", help="directory of saved models")
17 | opt = parser.parse_args()
18 | opt.model_dir = opt.model_dir + '_' + opt.input_color_space
19 |
20 | # use GPU when CUDA is available
21 | cuda = True if torch.cuda.is_available() else False
22 | # Tensor type
23 | Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor
24 |
25 | criterion_pixelwise = torch.nn.MSELoss()
26 | LUT0 = Generator3DLUT_identity()
27 | LUT1 = Generator3DLUT_zero()
28 | LUT2 = Generator3DLUT_zero()
29 | #LUT3 = Generator3DLUT_zero()
30 | #LUT4 = Generator3DLUT_zero()
31 | classifier = Classifier()
32 | trilinear_ = TrilinearInterpolation()
33 |
34 | if cuda:
35 | LUT0 = LUT0.cuda()
36 | LUT1 = LUT1.cuda()
37 | LUT2 = LUT2.cuda()
38 | #LUT3 = LUT3.cuda()
39 | #LUT4 = LUT4.cuda()
40 | classifier = classifier.cuda()
41 | criterion_pixelwise.cuda()
42 |
43 | # Load pretrained models
44 | LUTs = torch.load("saved_models/%s/LUTs_%d.pth" % (opt.model_dir, opt.epoch))
45 | LUT0.load_state_dict(LUTs["0"])
46 | LUT1.load_state_dict(LUTs["1"])
47 | LUT2.load_state_dict(LUTs["2"])
48 | #LUT3.load_state_dict(LUTs["3"])
49 | #LUT4.load_state_dict(LUTs["4"])
50 | LUT0.eval()
51 | LUT1.eval()
52 | LUT2.eval()
53 | #LUT3.eval()
54 | #LUT4.eval()
55 | classifier.load_state_dict(torch.load("saved_models/%s/classifier_%d.pth" % (opt.model_dir, opt.epoch)))
56 | classifier.eval()
57 |
58 | if opt.input_color_space == 'sRGB':
59 | dataloader = DataLoader(
60 | ImageDataset_sRGB("../data/%s" % opt.dataset_name, mode="test"),
61 | batch_size=1,
62 | shuffle=False,
63 | num_workers=1,
64 | )
65 | elif opt.input_color_space == 'XYZ':
66 | dataloader = DataLoader(
67 | ImageDataset_XYZ("../data/%s" % opt.dataset_name, mode="test"),
68 | batch_size=1,
69 | shuffle=False,
70 | num_workers=1,
71 | )
72 |
73 |
74 | def generator(img):
75 |
76 | pred = classifier(img).squeeze()
77 |
78 | LUT = pred[0] * LUT0.LUT + pred[1] * LUT1.LUT + pred[2] * LUT2.LUT #+ pred[3] * LUT3.LUT + pred[4] * LUT4.LUT
79 |
80 | combine_A = img.new(img.size())
81 | combine_A = trilinear_(LUT,img)
82 |
83 | return combine_A
84 |
85 |
86 | def visualize_result():
87 | """Saves a generated sample from the validation set"""
88 | out_dir = "images/%s_%d" % (opt.model_dir, opt.epoch)
89 | os.makedirs(out_dir, exist_ok=True)
90 | for i, batch in enumerate(dataloader):
91 | real_A = Variable(batch["A_input"].type(Tensor))
92 | img_name = batch["input_name"]
93 | fake_B = generator(real_A)
94 |
95 | #real_B = Variable(batch["A_exptC"].type(Tensor))
96 | #img_sample = torch.cat((real_A.data, fake_B.data, real_B.data), -1)
97 | #save_image(img_sample, "images/LUTs/paired/JPGsRGB8_to_JPGsRGB8_WB_original_5LUT/%s.png" % (img_name[0][:-4]), nrow=3, normalize=False)
98 | save_image(fake_B, os.path.join(out_dir,"%s.png" % (img_name[0][:-4])), nrow=1, normalize=False)
99 |
100 | def test_speed():
101 | t_list = []
102 | for i in range(1,10):
103 | img_input = Image.open(os.path.join("../data/fiveK/input/JPG","original","a000%d.jpg"%i))
104 | img_input = torch.unsqueeze(TF.to_tensor(TF.resize(img_input,(4000,6000))),0)
105 | real_A = Variable(img_input.type(Tensor))
106 | torch.cuda.synchronize()
107 | t0 = time.time()
108 | for j in range(0,100):
109 | fake_B = generator(real_A)
110 |
111 | torch.cuda.synchronize()
112 | t1 = time.time()
113 | t_list.append(t1 - t0)
114 | print((t1 - t0))
115 | print(t_list)
116 |
117 | # ----------
118 | # evaluation
119 | # ----------
120 | visualize_result()
121 |
122 | #test_speed()
123 |
--------------------------------------------------------------------------------
/image_adaptive_lut_train_paired.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import numpy as np
4 | import math
5 | import itertools
6 | import time
7 | import datetime
8 | import sys
9 |
10 | import torchvision.transforms as transforms
11 | from torchvision.utils import save_image
12 |
13 | from torch.utils.data import DataLoader
14 | from torchvision import datasets
15 | from torch.autograd import Variable
16 |
17 | from models import *
18 | from datasets import *
19 |
20 | import torch.nn as nn
21 | import torch.nn.functional as F
22 | import torch
23 |
24 | parser = argparse.ArgumentParser()
25 | parser.add_argument("--epoch", type=int, default=0, help="epoch to start training from, 0 starts from scratch, >0 starts from saved checkpoints")
26 | parser.add_argument("--n_epochs", type=int, default=400, help="total number of epochs of training")
27 | parser.add_argument("--dataset_name", type=str, default="fiveK", help="name of the dataset")
28 | parser.add_argument("--input_color_space", type=str, default="sRGB", help="input color space: sRGB or XYZ")
29 | parser.add_argument("--batch_size", type=int, default=1, help="size of the batches")
30 | parser.add_argument("--lr", type=float, default=0.0001, help="adam: learning rate")
31 | parser.add_argument("--b1", type=float, default=0.9, help="adam: decay of first order momentum of gradient")
32 | parser.add_argument("--b2", type=float, default=0.999, help="adam: decay of first order momentum of gradient")
33 | parser.add_argument("--lambda_smooth", type=float, default=0.0001, help="smooth regularization")
34 | parser.add_argument("--lambda_monotonicity", type=float, default=10.0, help="monotonicity regularization")
35 | parser.add_argument("--n_cpu", type=int, default=1, help="number of cpu threads to use during batch generation")
36 | parser.add_argument("--checkpoint_interval", type=int, default=1, help="interval between model checkpoints")
37 | parser.add_argument("--output_dir", type=str, default="LUTs/paired/fiveK_480p_3LUT_sm_1e-4_mn_10", help="path to save model")
38 | opt = parser.parse_args()
39 |
40 | opt.output_dir = opt.output_dir + '_' + opt.input_color_space
41 | print(opt)
42 |
43 | os.makedirs("saved_models/%s" % opt.output_dir, exist_ok=True)
44 |
45 | cuda = True if torch.cuda.is_available() else False
46 | # Tensor type
47 | Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor
48 |
49 | # Loss functions
50 | criterion_pixelwise = torch.nn.MSELoss()
51 |
52 | # Initialize generator and discriminator
53 | LUT0 = Generator3DLUT_identity()
54 | LUT1 = Generator3DLUT_zero()
55 | LUT2 = Generator3DLUT_zero()
56 | #LUT3 = Generator3DLUT_zero()
57 | #LUT4 = Generator3DLUT_zero()
58 | classifier = Classifier()
59 | TV3 = TV_3D()
60 | trilinear_ = TrilinearInterpolation()
61 |
62 | if cuda:
63 | LUT0 = LUT0.cuda()
64 | LUT1 = LUT1.cuda()
65 | LUT2 = LUT2.cuda()
66 | #LUT3 = LUT3.cuda()
67 | #LUT4 = LUT4.cuda()
68 | classifier = classifier.cuda()
69 | criterion_pixelwise.cuda()
70 | TV3.cuda()
71 | TV3.weight_r = TV3.weight_r.type(Tensor)
72 | TV3.weight_g = TV3.weight_g.type(Tensor)
73 | TV3.weight_b = TV3.weight_b.type(Tensor)
74 |
75 | if opt.epoch != 0:
76 | # Load pretrained models
77 | LUTs = torch.load("saved_models/%s/LUTs_%d.pth" % (opt.output_dir, opt.epoch))
78 | LUT0.load_state_dict(LUTs["0"])
79 | LUT1.load_state_dict(LUTs["1"])
80 | LUT2.load_state_dict(LUTs["2"])
81 | #LUT3.load_state_dict(LUTs["3"])
82 | #LUT4.load_state_dict(LUTs["4"])
83 | classifier.load_state_dict(torch.load("saved_models/%s/classifier_%d.pth" % (opt.output_dir, opt.epoch)))
84 | else:
85 | # Initialize weights
86 | classifier.apply(weights_init_normal_classifier)
87 | torch.nn.init.constant_(classifier.model[16].bias.data, 1.0)
88 |
89 | # Optimizers
90 |
91 | optimizer_G = torch.optim.Adam(itertools.chain(classifier.parameters(), LUT0.parameters(), LUT1.parameters(), LUT2.parameters()), lr=opt.lr, betas=(opt.b1, opt.b2)) #, LUT3.parameters(), LUT4.parameters()
92 |
93 | if opt.input_color_space == 'sRGB':
94 | dataloader = DataLoader(
95 | ImageDataset_sRGB("../data/%s" % opt.dataset_name, mode = "train"),
96 | batch_size=opt.batch_size,
97 | shuffle=True,
98 | num_workers=opt.n_cpu,
99 | )
100 |
101 | psnr_dataloader = DataLoader(
102 | ImageDataset_sRGB("../data/%s" % opt.dataset_name, mode="test"),
103 | batch_size=1,
104 | shuffle=False,
105 | num_workers=1,
106 | )
107 | elif opt.input_color_space == 'XYZ':
108 | dataloader = DataLoader(
109 | ImageDataset_XYZ("../data/%s" % opt.dataset_name, mode = "train"),
110 | batch_size=opt.batch_size,
111 | shuffle=True,
112 | num_workers=opt.n_cpu,
113 | )
114 |
115 | psnr_dataloader = DataLoader(
116 | ImageDataset_XYZ("../data/%s" % opt.dataset_name, mode="test"),
117 | batch_size=1,
118 | shuffle=False,
119 | num_workers=1,
120 | )
121 |
122 | def generator_train(img):
123 |
124 | pred = classifier(img).squeeze()
125 | if len(pred.shape) == 1:
126 | pred = pred.unsqueeze(0)
127 | gen_A0 = LUT0(img)
128 | gen_A1 = LUT1(img)
129 | gen_A2 = LUT2(img)
130 | #gen_A3 = LUT3(img)
131 | #gen_A4 = LUT4(img)
132 |
133 | weights_norm = torch.mean(pred ** 2)
134 |
135 | combine_A = img.new(img.size())
136 | for b in range(img.size(0)):
137 | combine_A[b,:,:,:] = pred[b,0] * gen_A0[b,:,:,:] + pred[b,1] * gen_A1[b,:,:,:] + pred[b,2] * gen_A2[b,:,:,:] #+ pred[b,3] * gen_A3[b,:,:,:] + pred[b,4] * gen_A4[b,:,:,:]
138 |
139 | return combine_A, weights_norm
140 |
141 | def generator_eval(img):
142 |
143 | pred = classifier(img).squeeze()
144 |
145 | LUT = pred[0] * LUT0.LUT + pred[1] * LUT1.LUT + pred[2] * LUT2.LUT #+ pred[3] * LUT3.LUT + pred[4] * LUT4.LUT
146 |
147 | weights_norm = torch.mean(pred ** 2)
148 |
149 | combine_A = img.new(img.size())
150 | combine_A = trilinear_(LUT,img)
151 |
152 | return combine_A, weights_norm
153 |
154 | def calculate_psnr():
155 | classifier.eval()
156 | avg_psnr = 0
157 | for i, batch in enumerate(psnr_dataloader):
158 | real_A = Variable(batch["A_input"].type(Tensor))
159 | real_B = Variable(batch["A_exptC"].type(Tensor))
160 | fake_B, weights_norm = generator_eval(real_A)
161 | fake_B = torch.round(fake_B*255)
162 | real_B = torch.round(real_B*255)
163 | mse = criterion_pixelwise(fake_B, real_B)
164 | psnr = 10 * math.log10(255.0 * 255.0 / mse.item())
165 | avg_psnr += psnr
166 |
167 | return avg_psnr/ len(psnr_dataloader)
168 |
169 |
170 | def visualize_result(epoch):
171 | """Saves a generated sample from the validation set"""
172 | classifier.eval()
173 | os.makedirs("images/%s/" % opt.output_dir +str(epoch), exist_ok=True)
174 | for i, batch in enumerate(psnr_dataloader):
175 | real_A = Variable(batch["A_input"].type(Tensor))
176 | real_B = Variable(batch["A_exptC"].type(Tensor))
177 | img_name = batch["input_name"]
178 | fake_B, weights_norm = generator_eval(real_A)
179 | img_sample = torch.cat((real_A.data, fake_B.data, real_B.data), -1)
180 | fake_B = torch.round(fake_B*255)
181 | real_B = torch.round(real_B*255)
182 | mse = criterion_pixelwise(fake_B, real_B)
183 | psnr = 10 * math.log10(255.0 * 255.0 / mse.item())
184 | save_image(img_sample, "images/%s/%s/%s.jpg" % (opt.output_dir,epoch, img_name[0]+'_'+str(psnr)[:5]), nrow=3, normalize=False)
185 |
186 | # ----------
187 | # Training
188 | # ----------
189 |
190 | prev_time = time.time()
191 | max_psnr = 0
192 | max_epoch = 0
193 | for epoch in range(opt.epoch, opt.n_epochs):
194 | mse_avg = 0
195 | psnr_avg = 0
196 | classifier.train()
197 | for i, batch in enumerate(dataloader):
198 |
199 | # Model inputs
200 | real_A = Variable(batch["A_input"].type(Tensor))
201 | real_B = Variable(batch["A_exptC"].type(Tensor))
202 |
203 | # ------------------
204 | # Train Generators
205 | # ------------------
206 |
207 | optimizer_G.zero_grad()
208 |
209 | fake_B, weights_norm = generator_train(real_A)
210 |
211 | # Pixel-wise loss
212 | mse = criterion_pixelwise(fake_B, real_B)
213 |
214 | tv0, mn0 = TV3(LUT0)
215 | tv1, mn1 = TV3(LUT1)
216 | tv2, mn2 = TV3(LUT2)
217 | #tv3, mn3 = TV3(LUT3)
218 | #tv4, mn4 = TV3(LUT4)
219 | tv_cons = tv0 + tv1 + tv2 #+ tv3 + tv4
220 | mn_cons = mn0 + mn1 + mn2 #+ mn3 + mn4
221 |
222 | loss = mse + opt.lambda_smooth * (weights_norm + tv_cons) + opt.lambda_monotonicity * mn_cons
223 |
224 | psnr_avg += 10 * math.log10(1 / mse.item())
225 |
226 | mse_avg += mse.item()
227 |
228 | loss.backward()
229 |
230 | optimizer_G.step()
231 |
232 |
233 | # --------------
234 | # Log Progress
235 | # --------------
236 |
237 | # Determine approximate time left
238 | batches_done = epoch * len(dataloader) + i
239 | batches_left = opt.n_epochs * len(dataloader) - batches_done
240 | time_left = datetime.timedelta(seconds=batches_left * (time.time() - prev_time))
241 | prev_time = time.time()
242 |
243 | # Print log
244 | sys.stdout.write(
245 | "\r[Epoch %d/%d] [Batch %d/%d] [psnr: %f, tv: %f, wnorm: %f, mn: %f] ETA: %s"
246 | % (epoch,opt.n_epochs,i,len(dataloader),psnr_avg / (i+1),tv_cons, weights_norm, mn_cons, time_left,
247 | )
248 | )
249 |
250 | avg_psnr = calculate_psnr()
251 | if avg_psnr > max_psnr:
252 | max_psnr = avg_psnr
253 | max_epoch = epoch
254 | sys.stdout.write(" [PSNR: %f] [max PSNR: %f, epoch: %d]\n"% (avg_psnr, max_psnr, max_epoch))
255 |
256 | #if (epoch+1) % 10 == 0:
257 | # visualize_result(epoch+1)
258 |
259 | if epoch % opt.checkpoint_interval == 0:
260 | # Save model checkpoints
261 | LUTs = {"0": LUT0.state_dict(),"1": LUT1.state_dict(),"2": LUT2.state_dict()} #,"3": LUT3.state_dict(),"4": LUT4.state_dict()
262 | torch.save(LUTs, "saved_models/%s/LUTs_%d.pth" % (opt.output_dir, epoch))
263 | torch.save(classifier.state_dict(), "saved_models/%s/classifier_%d.pth" % (opt.output_dir, epoch))
264 | file = open('saved_models/%s/result.txt' % opt.output_dir,'a')
265 | file.write(" [PSNR: %f] [max PSNR: %f, epoch: %d]\n"% (avg_psnr, max_psnr, max_epoch))
266 | file.close()
267 |
268 |
269 |
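
For reference, the total objective assembled in the paired training loop above is the pixel-wise MSE plus the two regularizers, i.e.

\[
\mathcal{L}_{\text{paired}} \;=\; \mathrm{MSE}(\hat{y}, y)
  \;+\; \lambda_{\text{smooth}}\Big(\operatorname{mean}(w^{2}) + \sum_{n=0}^{2}\mathrm{tv}(\mathrm{LUT}_n)\Big)
  \;+\; \lambda_{\text{mono}}\sum_{n=0}^{2}\mathrm{mn}(\mathrm{LUT}_n)
\]

where \(\hat{y}\) is the classifier-weighted combination of the three LUT outputs, \(w\) are the classifier weights (pred), and tv/mn are the two terms returned by TV_3D. This only restates what the loop computes; nothing is added to it.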
--------------------------------------------------------------------------------
/image_adaptive_lut_train_unpaired.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import numpy as np
4 | import math
5 | import itertools
6 | import time
7 | import datetime
8 | import sys
9 |
10 | import torchvision.transforms as transforms
11 | from torchvision.utils import save_image
12 |
13 | from torch.utils.data import DataLoader
14 | from torchvision import datasets
15 | from torch.autograd import Variable
16 | import torch.autograd as autograd
17 |
18 | from models_x import *
19 | from datasets import *
20 |
21 | import torch.nn as nn
22 | import torch.nn.functional as F
23 | import torch
24 |
25 | parser = argparse.ArgumentParser()
26 | parser.add_argument("--epoch", type=int, default=0, help="epoch to start training from, 0 starts from scratch, >0 starts from saved checkpoints")
27 | parser.add_argument("--n_epochs", type=int, default=800, help="total number of epochs of training")
28 | parser.add_argument("--dataset_name", type=str, default="fiveK", help="name of the dataset")
29 | parser.add_argument("--input_color_space", type=str, default="sRGB", help="input color space: sRGB or XYZ")
30 | parser.add_argument("--batch_size", type=int, default=1, help="size of the batches")
31 | parser.add_argument("--lr", type=float, default=0.0002, help="adam: learning rate")
32 | parser.add_argument("--b1", type=float, default=0.9, help="adam: decay of first order momentum of gradient")
33 | parser.add_argument("--b2", type=float, default=0.999, help="adam: decay of first order momentum of gradient")
34 | parser.add_argument("--lambda_pixel", type=float, default=1000, help="content preservation weight: 1000 for sRGB input, 10 for XYZ input")
35 | parser.add_argument("--lambda_gp", type=float, default=10, help="gradient penalty weight in wgan-gp")
36 | parser.add_argument("--lambda_smooth", type=float, default=1e-4, help="smooth regularization")
37 | parser.add_argument("--lambda_monotonicity", type=float, default=10.0, help="monotonicity regularization: 10 for sRGB input, 100 for XYZ input (slightly better)")
38 | parser.add_argument("--n_cpu", type=int, default=1, help="number of cpu threads to use during batch generation")
39 | parser.add_argument("--n_critic", type=int, default=1, help="number of training steps for discriminator per iter")
40 | parser.add_argument("--output_dir", type=str, default="LUTs/unpaired/fiveK_480p_sm_1e-4_mn_10_pixel_1000", help="path to save model")
41 | parser.add_argument("--checkpoint_interval", type=int, default=1, help="interval between model checkpoints")
42 | opt = parser.parse_args()
43 | opt.output_dir = opt.output_dir + '_' + opt.input_color_space
44 | print(opt)
45 |
46 | os.makedirs("saved_models/%s" % opt.output_dir, exist_ok=True)
47 |
48 | cuda = True if torch.cuda.is_available() else False
49 | Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor
50 |
51 | # Loss functions
52 | criterion_GAN = torch.nn.MSELoss()
53 | criterion_pixelwise = torch.nn.MSELoss()
54 |
55 | # Initialize generator and discriminator
56 | LUT0 = Generator3DLUT_identity()
57 | LUT1 = Generator3DLUT_zero()
58 | LUT2 = Generator3DLUT_zero()
59 | #LUT3 = Generator3DLUT_zero()
60 | #LUT4 = Generator3DLUT_zero()
61 | classifier = Classifier_unpaired()
62 | discriminator = Discriminator()
63 | TV3 = TV_3D()
64 |
65 | if cuda:
66 | LUT0 = LUT0.cuda()
67 | LUT1 = LUT1.cuda()
68 | LUT2 = LUT2.cuda()
69 | #LUT3 = LUT3.cuda()
70 | #LUT4 = LUT4.cuda()
71 | classifier = classifier.cuda()
72 | criterion_GAN.cuda()
73 | criterion_pixelwise.cuda()
74 | discriminator = discriminator.cuda()
75 | TV3.cuda()
76 | TV3.weight_r = TV3.weight_r.type(Tensor)
77 | TV3.weight_g = TV3.weight_g.type(Tensor)
78 | TV3.weight_b = TV3.weight_b.type(Tensor)
79 |
80 | if opt.epoch != 0:
81 | # Load pretrained models
82 | LUTs = torch.load("saved_models/%s/LUTs_%d.pth" % (opt.output_dir, opt.epoch))
83 | LUT0.load_state_dict(LUTs["0"])
84 | LUT1.load_state_dict(LUTs["1"])
85 | LUT2.load_state_dict(LUTs["2"])
86 | #LUT3.load_state_dict(LUTs["3"])
87 | #LUT4.load_state_dict(LUTs["4"])
88 | classifier.load_state_dict(torch.load("saved_models/%s/classifier_%d.pth" % (opt.output_dir, opt.epoch)))
89 | else:
90 | # Initialize weights
91 | classifier.apply(weights_init_normal_classifier)
92 | torch.nn.init.constant_(classifier.model[12].bias.data, 1.0)
93 | discriminator.apply(weights_init_normal_classifier)
94 |
95 | # Optimizers
96 | optimizer_G = torch.optim.Adam(itertools.chain(classifier.parameters(), LUT0.parameters(),LUT1.parameters(),LUT2.parameters()), lr=opt.lr, betas=(opt.b1, opt.b2)) #,LUT3.parameters(),LUT4.parameters()
97 | optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
98 |
99 | if opt.input_color_space == 'sRGB':
100 | dataloader = DataLoader(
101 | ImageDataset_sRGB_unpaired("../data/%s" % opt.dataset_name, mode="train"),
102 | batch_size=opt.batch_size,
103 | shuffle=True,
104 | num_workers=opt.n_cpu,
105 | )
106 |
107 | psnr_dataloader = DataLoader(
108 | ImageDataset_sRGB_unpaired("../data/%s" % opt.dataset_name, mode="test"),
109 | batch_size=1,
110 | shuffle=False,
111 | num_workers=1,
112 | )
113 | elif opt.input_color_space == 'XYZ':
114 | dataloader = DataLoader(
115 | ImageDataset_XYZ_unpaired("../data/%s" % opt.dataset_name, mode="train"),
116 | batch_size=opt.batch_size,
117 | shuffle=True,
118 | num_workers=opt.n_cpu,
119 | )
120 |
121 | psnr_dataloader = DataLoader(
122 | ImageDataset_XYZ_unpaired("../data/%s" % opt.dataset_name, mode="test"),
123 | batch_size=1,
124 | shuffle=False,
125 | num_workers=1,
126 | )
127 |
128 | def calculate_psnr():
129 | classifier.eval()
130 | avg_psnr = 0
131 | for i, batch in enumerate(psnr_dataloader):
132 | real_A = Variable(batch["A_input"].type(Tensor))
133 | real_B = Variable(batch["A_exptC"].type(Tensor))
134 | fake_B, weights_norm = generator(real_A)
135 | fake_B = torch.round(fake_B*255)
136 | real_B = torch.round(real_B*255)
137 | mse = criterion_pixelwise(fake_B, real_B)
138 | psnr = 10 * math.log10(255.0 * 255.0 / mse.item())
139 | avg_psnr += psnr
140 |
141 | return avg_psnr/ len(psnr_dataloader)
142 |
143 |
144 | def visualize_result(epoch):
145 | """Saves a generated sample from the validation set"""
146 | os.makedirs("images/LUTs/" +str(epoch), exist_ok=True)
147 | for i, batch in enumerate(psnr_dataloader):
148 | real_A = Variable(batch["A_input"].type(Tensor))
149 | real_B = Variable(batch["A_exptC"].type(Tensor))
150 | img_name = batch["input_name"]
151 | fake_B, weights_norm = generator(real_A)
152 | img_sample = torch.cat((real_A.data, fake_B.data, real_B.data), -1)
153 | fake_B = torch.round(fake_B*255)
154 | real_B = torch.round(real_B*255)
155 | mse = criterion_pixelwise(fake_B, real_B)
156 | psnr = 10 * math.log10(255.0 * 255.0 / mse.item())
157 | save_image(img_sample, "images/LUTs/%s/%s.jpg" % (epoch, img_name[0]+'_'+str(psnr)[:5]), nrow=3, normalize=False)
158 |
159 |
160 | def compute_gradient_penalty(D, real_samples, fake_samples):
161 | """Calculates the gradient penalty loss for WGAN GP"""
162 | # Random weight term for interpolation between real and fake samples
163 | alpha = Tensor(np.random.random((real_samples.size(0), 1, 1, 1)))
164 | # Get random interpolation between real and fake samples
165 | interpolates = (alpha * real_samples + ((1 - alpha) * fake_samples)).requires_grad_(True)
166 | d_interpolates = D(interpolates)
167 | fake = Variable(Tensor(real_samples.shape[0], 1, 1, 1).fill_(1.0), requires_grad=False)
168 | # Get gradient w.r.t. interpolates
169 | gradients = autograd.grad(
170 | outputs=d_interpolates,
171 | inputs=interpolates,
172 | grad_outputs=fake,
173 | create_graph=True,
174 | retain_graph=True,
175 | only_inputs=True,
176 | )[0]
177 | gradients = gradients.view(gradients.size(0), -1)
178 | gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean()
179 | return gradient_penalty
180 |
181 |
182 | def generator(img):
183 |
184 | pred = classifier(img).squeeze()
185 | weights_norm = torch.mean(pred ** 2)
186 | combine_A = pred[0] * LUT0(img) + pred[1] * LUT1(img) + pred[2] * LUT2(img) #+ pred[3] * LUT3(img) + pred[4] * LUT4(img)
187 |
188 | return combine_A, weights_norm
189 |
190 | # ----------
191 | # Training
192 | # ----------
193 |
194 | avg_psnr = calculate_psnr()
195 | print(avg_psnr)
196 | prev_time = time.time()
197 | max_psnr = 0
198 | max_epoch = 0
199 | for epoch in range(opt.epoch, opt.n_epochs):
200 | loss_D_avg = 0
201 | loss_G_avg = 0
202 | loss_pixel_avg = 0
203 | cnt = 0
204 | psnr_avg = 0
205 | classifier.train()
206 | for i, batch in enumerate(dataloader):
207 |
208 | # Model inputs
209 | real_A = Variable(batch["A_input"].type(Tensor))
210 | real_B = Variable(batch["B_exptC"].type(Tensor))
211 |
212 |
213 | # ---------------------
214 | # Train Discriminator
215 | # ---------------------
216 |
217 | optimizer_D.zero_grad()
218 |
219 | fake_B, weights_norm = generator(real_A)
220 | pred_real = discriminator(real_B)
221 | pred_fake = discriminator(fake_B)
222 |
223 | # Gradient penalty
224 | gradient_penalty = compute_gradient_penalty(discriminator, real_B, fake_B)
225 |
226 | # Total loss
227 | loss_D = -torch.mean(pred_real) + torch.mean(pred_fake) + opt.lambda_gp * gradient_penalty
228 |
229 | loss_D.backward()
230 | optimizer_D.step()
231 |
232 | loss_D_avg += (-torch.mean(pred_real) + torch.mean(pred_fake)) / 2
233 |
234 | # ------------------
235 | # Train Generators
236 | # ------------------
237 | if i % opt.n_critic == 0:
238 |
239 | optimizer_G.zero_grad()
240 |
241 | fake_B, weights_norm = generator(real_A)
242 | pred_fake = discriminator(fake_B)
243 | # Pixel-wise loss
244 | loss_pixel = criterion_pixelwise(fake_B, real_A)
245 |
246 | tv0, mn0 = TV3(LUT0)
247 | tv1, mn1 = TV3(LUT1)
248 | tv2, mn2 = TV3(LUT2)
249 | #tv3, mn3 = TV3(LUT3)
250 | #tv4, mn4 = TV3(LUT4)
251 |
252 | tv_cons = tv0 + tv1 + tv2 #+ tv3 + tv4
253 | mn_cons = mn0 + mn1 + mn2 #+ mn3 + mn4
254 |
255 | loss_G = -torch.mean(pred_fake) + opt.lambda_pixel * loss_pixel + opt.lambda_smooth * (weights_norm + tv_cons) + opt.lambda_monotonicity * mn_cons
256 |
257 | loss_G.backward()
258 |
259 | optimizer_G.step()
260 |
261 | cnt += 1
262 | loss_G_avg += -torch.mean(pred_fake)
263 |
264 | loss_pixel_avg += loss_pixel
265 | psnr_avg += 10 * math.log10(1 / loss_pixel.item())
266 |
267 |
268 | # --------------
269 | # Log Progress
270 | # --------------
271 |
272 | # Determine approximate time left
273 | batches_done = epoch * len(dataloader) + i
274 | batches_left = opt.n_epochs * len(dataloader) - batches_done
275 | time_left = datetime.timedelta(seconds=batches_left * (time.time() - prev_time))
276 | prev_time = time.time()
277 |
278 | # Print log
279 | sys.stdout.write(
280 | "\r[Epoch %d/%d] [Batch %d/%d] [D: %f, G: %f] [pixel: %f] [tv: %f, wnorm: %f, mn: %f] ETA: %s"
281 | % (
282 | epoch,
283 | opt.n_epochs,
284 | i,
285 | len(dataloader),
286 | loss_D_avg.item() / cnt,
287 | loss_G_avg.item() / cnt,
288 | loss_pixel_avg.item() / cnt,
289 | tv_cons, weights_norm, mn_cons,
290 | time_left,
291 | )
292 | )
293 |
294 | # If at sample interval save image
295 | avg_psnr = calculate_psnr()
296 | if avg_psnr > max_psnr:
297 | max_psnr = avg_psnr
298 | max_epoch = epoch
299 | sys.stdout.write(" [PSNR: %f] [max PSNR: %f, epoch: %d]\n"% (avg_psnr, max_psnr, max_epoch))
300 |
301 | #if (epoch+1) % 10 == 0:
302 | # visualize_result(epoch+1)
303 |
304 | if epoch % opt.checkpoint_interval == 0:
305 | # Save model checkpoints
306 | LUTs = {"0": LUT0.state_dict(), "1": LUT1.state_dict(), "2": LUT2.state_dict()} #, "3": LUT3.state_dict(), "4": LUT4.state_dict()
307 | torch.save(LUTs, "saved_models/%s/LUTs_%d.pth" % (opt.output_dir, epoch))
308 | torch.save(classifier.state_dict(), "saved_models/%s/classifier_%d.pth" % (opt.output_dir, epoch))
309 | file = open('saved_models/%s/result.txt' % opt.output_dir,'a')
310 | file.write(" [PSNR: %f] [max PSNR: %f, epoch: %d]\n"% (avg_psnr, max_psnr, max_epoch))
311 | file.close()
312 |
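
For orientation, the adversarial part of this unpaired script follows the WGAN-GP formulation. With D the discriminator, G(x) the classifier-weighted LUT output for an input x, y an unpaired target-style image, and \(\hat{x}\) a random interpolate between y and G(x) (see compute_gradient_penalty), the losses built in the loop are

\[
\mathcal{L}_D \;=\; \mathbb{E}[D(G(x))] - \mathbb{E}[D(y)]
  \;+\; \lambda_{\text{gp}}\,\mathbb{E}\big[(\lVert \nabla_{\hat{x}} D(\hat{x}) \rVert_2 - 1)^2\big]
\]
\[
\mathcal{L}_G \;=\; -\mathbb{E}[D(G(x))]
  \;+\; \lambda_{\text{pixel}}\,\mathrm{MSE}(G(x), x)
  \;+\; \lambda_{\text{smooth}}\Big(\operatorname{mean}(w^{2}) + \sum_{n}\mathrm{tv}(\mathrm{LUT}_n)\Big)
  \;+\; \lambda_{\text{mono}}\sum_{n}\mathrm{mn}(\mathrm{LUT}_n)
\]

Note that the MSE term compares G(x) against the input x itself, which is the content-preservation role described in the --lambda_pixel help string.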
--------------------------------------------------------------------------------
/local_tone_mapping/a1509.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HuiZeng/Image-Adaptive-3DLUT/b491f6df64a588864739a157db271e5c848e1805/local_tone_mapping/a1509.jpg
--------------------------------------------------------------------------------
/local_tone_mapping/wlsFilter.m:
--------------------------------------------------------------------------------
1 | function OUT = wlsFilter(IN, lambda, alpha, L)
2 | %WLSFILTER Edge-preserving smoothing based on the weighted least squares(WLS)
3 | % optimization framework, as described in Farbman, Fattal, Lischinski, and
4 | % Szeliski, "Edge-Preserving Decompositions for Multi-Scale Tone and Detail
5 | % Manipulation", ACM Transactions on Graphics, 27(3), August 2008.
6 | %
7 | % Given an input image IN, we seek a new image OUT, which, on the one hand,
8 | % is as close as possible to IN, and, at the same time, is as smooth as
9 | % possible everywhere, except across significant gradients in L.
10 | %
11 | %
12 | % Input arguments:
13 | % ----------------
14 | % IN Input image (2-D, double, N-by-M matrix).
15 | %
16 | % lambda Balances between the data term and the smoothness
17 | % term. Increasing lambda will produce smoother images.
18 | % Default value is 1.0
19 | %
20 | % alpha Gives a degree of control over the affinities by non-
21 | %                 linearly scaling the gradients. Increasing alpha will
22 | % result in sharper preserved edges. Default value: 1.2
23 | %
24 | % L Source image for the affinity matrix. Same dimensions
25 | % as the input image IN. Default: log(IN)
26 | %
27 | %
28 | % Example
29 | % -------
30 | % RGB = imread('peppers.png');
31 | % I = double(rgb2gray(RGB));
32 | % I = I./max(I(:));
33 | % res = wlsFilter(I, 0.5);
34 | % figure, imshow(I), figure, imshow(res)
35 | % res = wlsFilter(I, 2, 2);
36 | % figure, imshow(res)
37 |
38 | if(~exist('L', 'var')),
39 | L = log(IN+eps);
40 | end
41 |
42 | if(~exist('alpha', 'var')),
43 | alpha = 1.2;
44 | end
45 |
46 | if(~exist('lambda', 'var')),
47 | lambda = 1;
48 | end
49 |
50 | smallNum = 0.0001;
51 |
52 | [r,c] = size(IN);
53 | k = r*c;
54 |
55 | % Compute affinities between adjacent pixels based on gradients of L
56 | dy = diff(L, 1, 1);
57 | dy = -lambda./(abs(dy).^alpha + smallNum);
58 | dy = padarray(dy, [1 0], 'post');
59 | dy = dy(:);
60 |
61 | dx = diff(L, 1, 2);
62 | dx = -lambda./(abs(dx).^alpha + smallNum);
63 | dx = padarray(dx, [0 1], 'post');
64 | dx = dx(:);
65 |
66 |
67 | % Construct a five-point spatially inhomogeneous Laplacian matrix
68 | B(:,1) = dx;
69 | B(:,2) = dy;
70 | d = [-r,-1];
71 | A = spdiags(B,d,k,k);
72 |
73 | e = dx;
74 | w = padarray(dx, r, 'pre'); w = w(1:end-r);
75 | s = dy;
76 | n = padarray(dy, 1, 'pre'); n = n(1:end-1);
77 |
78 | D = 1-(e+w+s+n);
79 | A = A + A' + spdiags(D, 0, k, k);
80 |
81 | % Solve
82 | OUT = A\IN(:);
83 | OUT = reshape(OUT, r, c);
84 |
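
As a pointer for the solver above, the objective it minimizes is the weighted least squares formulation of Farbman et al. (2008): for output u, input g, and log-luminance guide l,

\[
\min_{u}\; \sum_{p}\Big[(u_p - g_p)^2
  + \lambda\Big(a_{x,p}(l)\,\big(\tfrac{\partial u}{\partial x}\big)_p^{2}
  + a_{y,p}(l)\,\big(\tfrac{\partial u}{\partial y}\big)_p^{2}\Big)\Big],
\qquad a_{x,p}(l) = \Big(\big|\tfrac{\partial l}{\partial x}(p)\big|^{\alpha} + \varepsilon\Big)^{-1}
\]

The code assembles exactly this quadratic form: dx and dy hold the \(-\lambda a\) affinities (with \(\varepsilon\) = smallNum), A is the five-point inhomogeneous Laplacian built from them, and OUT = A\IN(:) solves the resulting sparse linear system.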
--------------------------------------------------------------------------------
/local_tone_mapping/wlsTonemap.m:
--------------------------------------------------------------------------------
1 |
2 | %WLSTONEMAP High Dynamic Range tonemapping using WLS
3 | %
4 | % The script reduces the dynamic range of an HDR image using the method
5 | % originally proposed by Durand and Dorsey, "Fast Bilateral Filtering
6 | % for the Display of High-Dynamic-Range Images",
7 | % ACM Transactions on Graphics, 2002.
8 | %
9 | % Instead of the bilateral filter, the edge-preserving smoothing here
10 | % is based on the weighted least squares(WLS) optimization framework,
11 | % as described in Farbman, Fattal, Lischinski, and Szeliski,
12 | % "Edge-Preserving Decompositions for Multi-Scale Tone and Detail
13 | % Manipulation", ACM Transactions on Graphics, 27(3), August 2008.
14 |
15 |
16 | %% Load HDR image from file and convert to greyscale
17 | % hdr = double(hdrread('smallOffice.hdr'));
18 | hdr = double(imread('a1509.jpg'))/255.0;
19 | % hdr = imresize(hdr,0.2);
20 | % hsv = rgb2hsv(hdr);
21 | % I = hsv(:,:,3);
22 | I = 0.2989*hdr(:,:,1) + 0.587*hdr(:,:,2) + 0.114*hdr(:,:,3);
23 | logI = log(I+eps);
24 |
25 | %% Perform edge-preserving smoothing using WLS
26 | lambda = 20.0;
27 | alpha = 1.2;
28 | % base = log(wlsFilter(I, lambda, alpha));
29 | base = log(imguidedfilter(I));
30 |
31 | %% Compress the base layer and restore detail
32 | compression = 0.6;
33 | detail = logI - base;
34 | OUT = base*compression + detail;
35 | OUT = exp(OUT);
36 |
37 | %% Restore color
38 | OUT = OUT./I;
39 | OUT = hdr .* padarray(OUT, [0 0 2], 'circular' , 'post');
40 | % hsv(:,:,3) = I;
41 | % OUT = hsv2rgb(hsv);
42 |
43 | %% Finally, shift, scale, and gamma correct the result
44 | gamma = 1.0/1.0;
45 | bias = -min(OUT(:));
46 | gain = 0.8;
47 | OUT = (gain*(OUT + bias)).^gamma;
48 | % figure
49 | imshow(OUT);
50 | imwrite(OUT,'a1509_T.jpg');
51 | % imshowpair(hdr,OUT, 'montage')
52 | % imshow(hdr,OUT);
53 |
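
In summary, the script performs a standard base/detail tone-mapping. With I the luminance, B the smoothed base layer in the log domain, D = log I - B the detail layer, and c = 0.6 the compression factor,

\[
I' = \exp(cB + D), \qquad \mathrm{RGB}' = \mathrm{RGB}\cdot\frac{I'}{I},
\]

followed by the bias, gain, and gamma adjustment in the final block. Note that the edge-preserving smoothing actually used here is imguidedfilter; the wlsFilter call is left commented out.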
--------------------------------------------------------------------------------
/models.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch.nn.functional as F
3 | import torchvision.models as models
4 | import torchvision.transforms as transforms
5 | from torch.autograd import Variable
6 | import torch
7 | import numpy as np
8 | import math
9 | from trilinear_c._ext import trilinear
10 |
11 | def weights_init_normal_classifier(m):
12 | classname = m.__class__.__name__
13 | if classname.find("Conv") != -1:
14 | torch.nn.init.xavier_normal_(m.weight.data)
15 |
16 | elif classname.find("BatchNorm2d") != -1 or classname.find("InstanceNorm2d") != -1:
17 | torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
18 | torch.nn.init.constant_(m.bias.data, 0.0)
19 |
20 | class resnet18_224(nn.Module):
21 |
22 | def __init__(self, out_dim=5, aug_test=False):
23 | super(resnet18_224, self).__init__()
24 |
25 | self.aug_test = aug_test
26 | net = models.resnet18(pretrained=True)
27 | # self.mean = torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).cuda()
28 | # self.std = torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).cuda()
29 |
30 | self.upsample = nn.Upsample(size=(224,224),mode='bilinear')
31 | net.fc = nn.Linear(512, out_dim)
32 | self.model = net
33 |
34 |
35 | def forward(self, x):
36 |
37 | x = self.upsample(x)
38 | if self.aug_test:
39 | # x = torch.cat((x, torch.rot90(x, 1, [2, 3]), torch.rot90(x, 3, [2, 3])), 0)
40 | x = torch.cat((x, torch.flip(x, [3])), 0)
41 | f = self.model(x)
42 |
43 | return f
44 |
45 | ##############################
46 | # DPE
47 | ##############################
48 |
49 |
50 | class UNetDown(nn.Module):
51 | def __init__(self, in_size, out_size, normalize=True, dropout=0.0):
52 | super(UNetDown, self).__init__()
53 | layers = [nn.Conv2d(in_size, out_size, 5, 2, 2)]
54 | layers.append(nn.SELU(inplace=True))
55 | if normalize:
56 | #layers.append(nn.BatchNorm2d(out_size))
57 |             layers.append(nn.InstanceNorm2d(out_size, affine = True))
58 | if dropout:
59 | layers.append(nn.Dropout(dropout))
60 | self.model = nn.Sequential(*layers)
61 |
62 | def forward(self, x):
63 | return self.model(x)
64 |
65 |
66 | class UNetUp(nn.Module):
67 | def __init__(self, in_size, out_size, normalize=True, dropout=0.0):
68 | super(UNetUp, self).__init__()
69 | layers = [
70 | nn.Upsample(scale_factor=2, mode = 'bilinear', align_corners=True),
71 | nn.Conv2d(in_size, out_size, 3, padding=1),
72 | nn.SELU(inplace=True),
73 | ]
74 |
75 | if normalize:
76 | #layers.append(nn.BatchNorm2d(out_size))
77 |             layers.append(nn.InstanceNorm2d(out_size, affine = True))
78 |
79 | if dropout:
80 | layers.append(nn.Dropout(dropout))
81 |
82 | self.model = nn.Sequential(*layers)
83 |
84 | def forward(self, x, skip_input):
85 | x = self.model(x)
86 | x = torch.cat((x, skip_input), 1)
87 |
88 | return x
89 |
90 |
91 | class GeneratorUNet(nn.Module):
92 | def __init__(self, in_channels=3, out_channels=3):
93 | super(GeneratorUNet, self).__init__()
94 |
95 | self.conv1 = nn.Sequential(
96 | nn.Conv2d(3, 16, 3, padding=1),
97 | nn.SELU(inplace=True),
98 | #nn.BatchNorm2d(16),
99 | nn.InstanceNorm2d(16, affine = True),
100 | )
101 | self.down1 = UNetDown(16, 32)
102 | self.down2 = UNetDown(32, 64)
103 | self.down3 = UNetDown(64, 128)
104 | self.down4 = UNetDown(128, 128)
105 | self.down5 = UNetDown(128, 128)
106 | self.down6 = UNetDown(128, 128)
107 | self.down7 = nn.Sequential(
108 | nn.Conv2d(128, 128, 3, padding=1),
109 | nn.SELU(inplace=True),
110 | nn.Conv2d(128, 128, 1, padding=0),
111 | )
112 |
113 | self.upsample = nn.Upsample(scale_factor=4, mode = 'bilinear')
114 | self.conv1x1 = nn.Conv2d(256, 128, 1, padding=0)
115 |
116 | self.up1 = UNetUp(128, 128)
117 | self.up2 = UNetUp(256, 128)
118 | self.up3 = UNetUp(192, 64)
119 | self.up4 = UNetUp(96, 32)
120 |
121 | self.final = nn.Sequential(
122 | nn.Conv2d(48, 16, 3, padding=1),
123 | nn.SELU(inplace=True),
124 | #nn.BatchNorm2d(16),
125 | #nn.InstanceNorm2d(16, affine = True),
126 | nn.Conv2d(16, out_channels, 3, padding=1),
127 | #nn.Tanh(),
128 | )
129 |
130 | def forward(self, x):
131 | # U-Net generator with skip connections from encoder to decoder
132 |
133 | x1 = self.conv1(x)
134 | d1 = self.down1(x1)
135 | d2 = self.down2(d1)
136 | d3 = self.down3(d2)
137 | d4 = self.down4(d3)
138 |
139 | d5 = self.down5(d4)
140 | d6 = self.down6(d5)
141 | d7 = self.down7(d6)
142 |
143 | d8 = self.upsample(d7)
144 | d9 = torch.cat((d4, d8), 1)
145 | d9 = self.conv1x1(d9)
146 |
147 | u1 = self.up1(d9, d3)
148 | u2 = self.up2(u1, d2)
149 | u3 = self.up3(u2, d1)
150 | u4 = self.up4(u3, x1)
151 |
152 | return torch.add(self.final(u4), x)
153 |
154 |
155 | class Discriminator_UNet(nn.Module):
156 | def __init__(self, in_channels=3):
157 | super(Discriminator_UNet, self).__init__()
158 |
159 | self.model = nn.Sequential(
160 | nn.Conv2d(3, 16, 3, stride=2, padding=1),
161 | nn.LeakyReLU(0.2),
162 | nn.InstanceNorm2d(16, affine=True),
163 | *discriminator_block(16, 32),
164 | *discriminator_block(32, 64),
165 | *discriminator_block(64, 128),
166 | *discriminator_block(128, 128),
167 | *discriminator_block(128, 128),
168 | nn.Conv2d(128, 1, 4, padding=0)
169 | )
170 |
171 | def forward(self, img_input):
172 | return self.model(img_input)
173 |
174 | ##############################
175 | # Discriminator
176 | ##############################
177 |
178 |
179 | def discriminator_block(in_filters, out_filters, normalization=False):
180 | """Returns downsampling layers of each discriminator block"""
181 | layers = [nn.Conv2d(in_filters, out_filters, 3, stride=2, padding=1)]
182 | layers.append(nn.LeakyReLU(0.2))
183 | if normalization:
184 | layers.append(nn.InstanceNorm2d(out_filters, affine=True))
185 | #layers.append(nn.BatchNorm2d(out_filters))
186 |
187 | return layers
188 |
189 | class Discriminator(nn.Module):
190 | def __init__(self, in_channels=3):
191 | super(Discriminator, self).__init__()
192 |
193 | self.model = nn.Sequential(
194 | nn.Upsample(size=(256,256),mode='bilinear'),
195 | nn.Conv2d(3, 16, 3, stride=2, padding=1),
196 | nn.LeakyReLU(0.2),
197 | nn.InstanceNorm2d(16, affine=True),
198 | *discriminator_block(16, 32),
199 | *discriminator_block(32, 64),
200 | *discriminator_block(64, 128),
201 | *discriminator_block(128, 128),
202 | #*discriminator_block(128, 128),
203 | nn.Conv2d(128, 1, 8, padding=0)
204 | )
205 |
206 | def forward(self, img_input):
207 | return self.model(img_input)
208 |
209 | class Classifier(nn.Module):
210 | def __init__(self, in_channels=3):
211 | super(Classifier, self).__init__()
212 |
213 | self.model = nn.Sequential(
214 | nn.Upsample(size=(256,256),mode='bilinear'),
215 | nn.Conv2d(3, 16, 3, stride=2, padding=1),
216 | nn.LeakyReLU(0.2),
217 | nn.InstanceNorm2d(16, affine=True),
218 | *discriminator_block(16, 32, normalization=True),
219 | *discriminator_block(32, 64, normalization=True),
220 | *discriminator_block(64, 128, normalization=True),
221 | *discriminator_block(128, 128),
222 | #*discriminator_block(128, 128, normalization=True),
223 | nn.Dropout(p=0.5),
224 | nn.Conv2d(128, 3, 8, padding=0),
225 | )
226 |
227 | def forward(self, img_input):
228 | return self.model(img_input)
229 |
230 | class Classifier_unpaired(nn.Module):
231 | def __init__(self, in_channels=3):
232 | super(Classifier_unpaired, self).__init__()
233 |
234 | self.model = nn.Sequential(
235 | nn.Upsample(size=(256,256),mode='bilinear'),
236 | nn.Conv2d(3, 16, 3, stride=2, padding=1),
237 | nn.LeakyReLU(0.2),
238 | nn.InstanceNorm2d(16, affine=True),
239 | *discriminator_block(16, 32),
240 | *discriminator_block(32, 64),
241 | *discriminator_block(64, 128),
242 | *discriminator_block(128, 128),
243 | #*discriminator_block(128, 128),
244 | nn.Conv2d(128, 3, 8, padding=0),
245 | )
246 |
247 | def forward(self, img_input):
248 | return self.model(img_input)
249 |
250 |
251 | class Generator3DLUT_identity(nn.Module):
252 | def __init__(self, dim=33):
253 | super(Generator3DLUT_identity, self).__init__()
254 | if dim == 33:
255 | file = open("IdentityLUT33.txt",'r')
256 | elif dim == 64:
257 | file = open("IdentityLUT64.txt",'r')
258 | LUT = file.readlines()
259 | self.LUT = torch.zeros(3,dim,dim,dim, dtype=torch.float)
260 |
261 | for i in range(0,dim):
262 | for j in range(0,dim):
263 | for k in range(0,dim):
264 | n = i * dim*dim + j * dim + k
265 | x = LUT[n].split()
266 | self.LUT[0,i,j,k] = float(x[0])
267 | self.LUT[1,i,j,k] = float(x[1])
268 | self.LUT[2,i,j,k] = float(x[2])
269 | self.LUT = nn.Parameter(torch.tensor(self.LUT))
270 | self.TrilinearInterpolation = TrilinearInterpolation()
271 |
272 | def forward(self, x):
273 |
274 | return self.TrilinearInterpolation(self.LUT, x)
275 |
276 |
277 | class Generator3DLUT_zero(nn.Module):
278 | def __init__(self, dim=33):
279 | super(Generator3DLUT_zero, self).__init__()
280 |
281 | self.LUT = torch.zeros(3,dim,dim,dim, dtype=torch.float)
282 | self.LUT = nn.Parameter(torch.tensor(self.LUT))
283 | self.TrilinearInterpolation = TrilinearInterpolation()
284 |
285 | def forward(self, x):
286 |
287 | return self.TrilinearInterpolation(self.LUT, x)
288 |
289 | class TrilinearInterpolation(torch.autograd.Function):
290 |
291 | def forward(self, LUT, x):
292 |
293 | x = x.contiguous()
294 | output = x.new(x.size())
295 | dim = LUT.size()[-1]
296 | shift = dim ** 3
297 | binsize = 1.0001 / (dim-1)
298 | W = x.size(2)
299 | H = x.size(3)
300 | batch = x.size(0)
301 |
302 | self.x = x
303 | self.LUT = LUT
304 | self.dim = dim
305 | self.shift = shift
306 | self.binsize = binsize
307 | self.W = W
308 | self.H = H
309 | self.batch = batch
310 |
311 | if x.is_cuda:
312 | if batch == 1:
313 | trilinear.trilinear_forward_cuda(LUT,x,output,dim,shift,binsize,W,H,batch)
314 | elif batch > 1:
315 | output = output.permute(1,0,2,3).contiguous()
316 | trilinear.trilinear_forward_cuda(LUT,x.permute(1,0,2,3).contiguous(),output,dim,shift,binsize,W,H,batch)
317 | output = output.permute(1,0,2,3).contiguous()
318 |
319 | else:
320 | trilinear.trilinear_forward(LUT,x,output,dim,shift,binsize,W,H,batch)
321 |
322 | return output
323 |
324 | def backward(self, grad_x):
325 |
326 | grad_LUT = torch.zeros(3,self.dim,self.dim,self.dim,dtype=torch.float)
327 |
328 | if grad_x.is_cuda:
329 | grad_LUT = grad_LUT.cuda()
330 | if self.batch == 1:
331 | trilinear.trilinear_backward_cuda(self.x,grad_x,grad_LUT,self.dim,self.shift,self.binsize,self.W,self.H,self.batch)
332 | elif self.batch > 1:
333 | trilinear.trilinear_backward_cuda(self.x.permute(1,0,2,3).contiguous(),grad_x.permute(1,0,2,3).contiguous(),grad_LUT,self.dim,self.shift,self.binsize,self.W,self.H,self.batch)
334 | else:
335 | trilinear.trilinear_backward(self.x,grad_x,grad_LUT,self.dim,self.shift,self.binsize,self.W,self.H,self.batch)
336 |
337 | return grad_LUT, None
338 |
339 |
340 | class TV_3D(nn.Module):
341 | def __init__(self, dim=33):
342 | super(TV_3D,self).__init__()
343 |
344 | self.weight_r = torch.ones(3,dim,dim,dim-1, dtype=torch.float)
345 | self.weight_r[:,:,:,(0,dim-2)] *= 2.0
346 | self.weight_g = torch.ones(3,dim,dim-1,dim, dtype=torch.float)
347 | self.weight_g[:,:,(0,dim-2),:] *= 2.0
348 | self.weight_b = torch.ones(3,dim-1,dim,dim, dtype=torch.float)
349 | self.weight_b[:,(0,dim-2),:,:] *= 2.0
350 | self.relu = torch.nn.ReLU()
351 |
352 | def forward(self, LUT):
353 |
354 | dif_r = LUT.LUT[:,:,:,:-1] - LUT.LUT[:,:,:,1:]
355 | dif_g = LUT.LUT[:,:,:-1,:] - LUT.LUT[:,:,1:,:]
356 | dif_b = LUT.LUT[:,:-1,:,:] - LUT.LUT[:,1:,:,:]
357 | tv = torch.mean(torch.mul((dif_r ** 2),self.weight_r)) + torch.mean(torch.mul((dif_g ** 2),self.weight_g)) + torch.mean(torch.mul((dif_b ** 2),self.weight_b))
358 |
359 | mn = torch.mean(self.relu(dif_r)) + torch.mean(self.relu(dif_g)) + torch.mean(self.relu(dif_b))
360 |
361 | return tv, mn
362 |
363 |
364 |
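
The TrilinearInterpolation function above dispatches to the compiled trilinear_c extension. For readers who have not built that extension, the lookup it performs, trilinearly interpolating each pixel's colour inside the learned 3D table, can be approximated in pure PyTorch with F.grid_sample. The sketch below is only an illustration, not the repository's kernel: the helper name is ours, the (blue, green, red) axis ordering of the LUT volume and the RGB channel order of the input are assumptions, boundary handling differs slightly from the binsize used in the C/CUDA code, and align_corners requires a newer PyTorch than the version pinned in requirements.

import torch
import torch.nn.functional as F

def apply_lut_reference(lut, img):
    """Rough stand-in for the trilinear forward pass (sketch; see note above).

    lut: (3, dim, dim, dim) tensor, assumed to be indexed as [channel, b, g, r].
    img: (N, 3, H, W) tensor with RGB values in [0, 1].
    """
    n = img.size(0)
    # grid_sample expects the last grid dimension ordered (x, y, z), i.e. the
    # (W, H, D) axes of the volume, which under the assumed layout is (r, g, b).
    grid = img.permute(0, 2, 3, 1).unsqueeze(1) * 2.0 - 1.0   # (N, 1, H, W, 3) in [-1, 1]
    vol = lut.unsqueeze(0).expand(n, -1, -1, -1, -1)          # (N, 3, dim, dim, dim)
    out = F.grid_sample(vol, grid, mode='bilinear', align_corners=True)
    return out.squeeze(2)                                     # (N, 3, H, W)

Under these assumptions, applying the table from Generator3DLUT_identity this way should return the input image up to interpolation error, which is a convenient sanity check for the axis-ordering assumption.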
--------------------------------------------------------------------------------
/models_x.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch.nn.functional as F
3 | import torchvision.models as models
4 | import torchvision.transforms as transforms
5 | from torch.autograd import Variable
6 | import torch
7 | import numpy as np
8 | import math
9 | import trilinear
10 |
11 | def weights_init_normal_classifier(m):
12 | classname = m.__class__.__name__
13 | if classname.find("Conv") != -1:
14 | torch.nn.init.xavier_normal_(m.weight.data)
15 |
16 | elif classname.find("BatchNorm2d") != -1 or classname.find("InstanceNorm2d") != -1:
17 | torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
18 | torch.nn.init.constant_(m.bias.data, 0.0)
19 |
20 | class resnet18_224(nn.Module):
21 |
22 | def __init__(self, out_dim=5, aug_test=False):
23 | super(resnet18_224, self).__init__()
24 |
25 | self.aug_test = aug_test
26 | net = models.resnet18(pretrained=True)
27 | # self.mean = torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).cuda()
28 | # self.std = torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).cuda()
29 |
30 | self.upsample = nn.Upsample(size=(224,224),mode='bilinear')
31 | net.fc = nn.Linear(512, out_dim)
32 | self.model = net
33 |
34 |
35 | def forward(self, x):
36 |
37 | x = self.upsample(x)
38 | if self.aug_test:
39 | # x = torch.cat((x, torch.rot90(x, 1, [2, 3]), torch.rot90(x, 3, [2, 3])), 0)
40 | x = torch.cat((x, torch.flip(x, [3])), 0)
41 | f = self.model(x)
42 |
43 | return f
44 |
45 | ##############################
46 | # Discriminator
47 | ##############################
48 |
49 |
50 | def discriminator_block(in_filters, out_filters, normalization=False):
51 | """Returns downsampling layers of each discriminator block"""
52 | layers = [nn.Conv2d(in_filters, out_filters, 3, stride=2, padding=1)]
53 | layers.append(nn.LeakyReLU(0.2))
54 | if normalization:
55 | layers.append(nn.InstanceNorm2d(out_filters, affine=True))
56 | #layers.append(nn.BatchNorm2d(out_filters))
57 |
58 | return layers
59 |
60 | class Discriminator(nn.Module):
61 | def __init__(self, in_channels=3):
62 | super(Discriminator, self).__init__()
63 |
64 | self.model = nn.Sequential(
65 | nn.Upsample(size=(256,256),mode='bilinear'),
66 | nn.Conv2d(3, 16, 3, stride=2, padding=1),
67 | nn.LeakyReLU(0.2),
68 | nn.InstanceNorm2d(16, affine=True),
69 | *discriminator_block(16, 32),
70 | *discriminator_block(32, 64),
71 | *discriminator_block(64, 128),
72 | *discriminator_block(128, 128),
73 | #*discriminator_block(128, 128),
74 | nn.Conv2d(128, 1, 8, padding=0)
75 | )
76 |
77 | def forward(self, img_input):
78 | return self.model(img_input)
79 |
80 | class Classifier(nn.Module):
81 | def __init__(self, in_channels=3):
82 | super(Classifier, self).__init__()
83 |
84 | self.model = nn.Sequential(
85 | nn.Upsample(size=(256,256),mode='bilinear'),
86 | nn.Conv2d(3, 16, 3, stride=2, padding=1),
87 | nn.LeakyReLU(0.2),
88 | nn.InstanceNorm2d(16, affine=True),
89 | *discriminator_block(16, 32, normalization=True),
90 | *discriminator_block(32, 64, normalization=True),
91 | *discriminator_block(64, 128, normalization=True),
92 | *discriminator_block(128, 128),
93 | #*discriminator_block(128, 128, normalization=True),
94 | nn.Dropout(p=0.5),
95 | nn.Conv2d(128, 3, 8, padding=0),
96 | )
97 |
98 | def forward(self, img_input):
99 | return self.model(img_input)
100 |
101 | class Classifier_unpaired(nn.Module):
102 | def __init__(self, in_channels=3):
103 | super(Classifier_unpaired, self).__init__()
104 |
105 | self.model = nn.Sequential(
106 | nn.Upsample(size=(256,256),mode='bilinear'),
107 | nn.Conv2d(3, 16, 3, stride=2, padding=1),
108 | nn.LeakyReLU(0.2),
109 | nn.InstanceNorm2d(16, affine=True),
110 | *discriminator_block(16, 32),
111 | *discriminator_block(32, 64),
112 | *discriminator_block(64, 128),
113 | *discriminator_block(128, 128),
114 | #*discriminator_block(128, 128),
115 | nn.Conv2d(128, 3, 8, padding=0),
116 | )
117 |
118 | def forward(self, img_input):
119 | return self.model(img_input)
120 |
121 |
122 | class Generator3DLUT_identity(nn.Module):
123 | def __init__(self, dim=33):
124 | super(Generator3DLUT_identity, self).__init__()
125 | if dim == 33:
126 | file = open("IdentityLUT33.txt", 'r')
127 | elif dim == 64:
128 | file = open("IdentityLUT64.txt", 'r')
129 | lines = file.readlines()
130 | buffer = np.zeros((3,dim,dim,dim), dtype=np.float32)
131 |
132 | for i in range(0,dim):
133 | for j in range(0,dim):
134 | for k in range(0,dim):
135 | n = i * dim*dim + j * dim + k
136 | x = lines[n].split()
137 | buffer[0,i,j,k] = float(x[0])
138 | buffer[1,i,j,k] = float(x[1])
139 | buffer[2,i,j,k] = float(x[2])
140 | self.LUT = nn.Parameter(torch.from_numpy(buffer).requires_grad_(True))
141 | self.TrilinearInterpolation = TrilinearInterpolation()
142 |
143 | def forward(self, x):
144 | _, output = self.TrilinearInterpolation(self.LUT, x)
145 | #self.LUT, output = self.TrilinearInterpolation(self.LUT, x)
146 | return output
147 |
148 | class Generator3DLUT_zero(nn.Module):
149 | def __init__(self, dim=33):
150 | super(Generator3DLUT_zero, self).__init__()
151 |
152 | self.LUT = torch.zeros(3,dim,dim,dim, dtype=torch.float)
153 | self.LUT = nn.Parameter(torch.tensor(self.LUT))
154 | self.TrilinearInterpolation = TrilinearInterpolation()
155 |
156 | def forward(self, x):
157 | _, output = self.TrilinearInterpolation(self.LUT, x)
158 |
159 | return output
160 |
161 | class TrilinearInterpolationFunction(torch.autograd.Function):
162 | @staticmethod
163 | def forward(ctx, lut, x):
164 | x = x.contiguous()
165 |
166 | output = x.new(x.size())
167 | dim = lut.size()[-1]
168 | shift = dim ** 3
169 | binsize = 1.000001 / (dim-1)
170 | W = x.size(2)
171 | H = x.size(3)
172 | batch = x.size(0)
173 |
174 | assert 1 == trilinear.forward(lut,
175 | x,
176 | output,
177 | dim,
178 | shift,
179 | binsize,
180 | W,
181 | H,
182 | batch)
183 |
184 | int_package = torch.IntTensor([dim, shift, W, H, batch])
185 | float_package = torch.FloatTensor([binsize])
186 | variables = [lut, x, int_package, float_package]
187 |
188 | ctx.save_for_backward(*variables)
189 |
190 | return lut, output
191 |
192 | @staticmethod
193 | def backward(ctx, lut_grad, x_grad):
194 |
195 | lut, x, int_package, float_package = ctx.saved_variables
196 | dim, shift, W, H, batch = int_package
197 | dim, shift, W, H, batch = int(dim), int(shift), int(W), int(H), int(batch)
198 | binsize = float(float_package[0])
199 |
200 | assert 1 == trilinear.backward(x,
201 | x_grad,
202 | lut_grad,
203 | dim,
204 | shift,
205 | binsize,
206 | W,
207 | H,
208 | batch)
209 | return lut_grad, x_grad
210 |
211 |
212 | class TrilinearInterpolation(torch.nn.Module):
213 | def __init__(self):
214 | super(TrilinearInterpolation, self).__init__()
215 |
216 | def forward(self, lut, x):
217 | return TrilinearInterpolationFunction.apply(lut, x)
218 |
219 |
220 | class TV_3D(nn.Module):
221 | def __init__(self, dim=33):
222 | super(TV_3D,self).__init__()
223 |
224 | self.weight_r = torch.ones(3,dim,dim,dim-1, dtype=torch.float)
225 | self.weight_r[:,:,:,(0,dim-2)] *= 2.0
226 | self.weight_g = torch.ones(3,dim,dim-1,dim, dtype=torch.float)
227 | self.weight_g[:,:,(0,dim-2),:] *= 2.0
228 | self.weight_b = torch.ones(3,dim-1,dim,dim, dtype=torch.float)
229 | self.weight_b[:,(0,dim-2),:,:] *= 2.0
230 | self.relu = torch.nn.ReLU()
231 |
232 | def forward(self, LUT):
233 |
234 | dif_r = LUT.LUT[:,:,:,:-1] - LUT.LUT[:,:,:,1:]
235 | dif_g = LUT.LUT[:,:,:-1,:] - LUT.LUT[:,:,1:,:]
236 | dif_b = LUT.LUT[:,:-1,:,:] - LUT.LUT[:,1:,:,:]
237 | tv = torch.mean(torch.mul((dif_r ** 2),self.weight_r)) + torch.mean(torch.mul((dif_g ** 2),self.weight_g)) + torch.mean(torch.mul((dif_b ** 2),self.weight_b))
238 |
239 | mn = torch.mean(self.relu(dif_r)) + torch.mean(self.relu(dif_g)) + torch.mean(self.relu(dif_b))
240 |
241 | return tv, mn
242 |
243 |
244 |
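
The TV_3D module above (defined identically in models.py) supplies the smoothness and monotonicity penalties used by both training scripts. Writing \(\Delta_c\Phi\) for the difference between each LUT entry and its neighbour at the next index along colour axis c (lower index minus higher index), the two returned terms are

\[
\mathrm{tv}(\Phi) = \sum_{c\in\{r,g,b\}} \operatorname{mean}\big(w_c \odot (\Delta_c\Phi)^{2}\big),
\qquad
\mathrm{mn}(\Phi) = \sum_{c\in\{r,g,b\}} \operatorname{mean}\big(\mathrm{ReLU}(\Delta_c\Phi)\big),
\]

where the weights \(w_c\) double the two boundary differences along each axis. Since \(\Delta_c\Phi\) is positive exactly where the LUT output decreases as the input grows, mn penalizes non-monotonic entries, while tv discourages abrupt changes between neighbouring lattice points.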
--------------------------------------------------------------------------------
/pretrained_models/XYZ/LUTs.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HuiZeng/Image-Adaptive-3DLUT/b491f6df64a588864739a157db271e5c848e1805/pretrained_models/XYZ/LUTs.pth
--------------------------------------------------------------------------------
/pretrained_models/XYZ/LUTs_unpaired.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HuiZeng/Image-Adaptive-3DLUT/b491f6df64a588864739a157db271e5c848e1805/pretrained_models/XYZ/LUTs_unpaired.pth
--------------------------------------------------------------------------------
/pretrained_models/XYZ/classifier.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HuiZeng/Image-Adaptive-3DLUT/b491f6df64a588864739a157db271e5c848e1805/pretrained_models/XYZ/classifier.pth
--------------------------------------------------------------------------------
/pretrained_models/XYZ/classifier_unpaired.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HuiZeng/Image-Adaptive-3DLUT/b491f6df64a588864739a157db271e5c848e1805/pretrained_models/XYZ/classifier_unpaired.pth
--------------------------------------------------------------------------------
/pretrained_models/sRGB/LUTs.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HuiZeng/Image-Adaptive-3DLUT/b491f6df64a588864739a157db271e5c848e1805/pretrained_models/sRGB/LUTs.pth
--------------------------------------------------------------------------------
/pretrained_models/sRGB/LUTs_unpaired.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HuiZeng/Image-Adaptive-3DLUT/b491f6df64a588864739a157db271e5c848e1805/pretrained_models/sRGB/LUTs_unpaired.pth
--------------------------------------------------------------------------------
/pretrained_models/sRGB/classifier.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HuiZeng/Image-Adaptive-3DLUT/b491f6df64a588864739a157db271e5c848e1805/pretrained_models/sRGB/classifier.pth
--------------------------------------------------------------------------------
/pretrained_models/sRGB/classifier_unpaired.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HuiZeng/Image-Adaptive-3DLUT/b491f6df64a588864739a157db271e5c848e1805/pretrained_models/sRGB/classifier_unpaired.pth
--------------------------------------------------------------------------------
/requirements:
--------------------------------------------------------------------------------
1 | numpy==1.19.2
2 | Pillow==6.1.0
3 | torch==0.4.1
4 | torchvision==0.2.2
5 | opencv-python==3.4.3
6 |
--------------------------------------------------------------------------------
/ssim.m:
--------------------------------------------------------------------------------
1 | function [mssim, ssim_map] = ssim(img1, img2, K, window, L)
2 |
3 | % ========================================================================
4 | % SSIM Index with automatic downsampling, Version 1.0
5 | % Copyright(c) 2009 Zhou Wang
6 | % All Rights Reserved.
7 | %
8 | % ----------------------------------------------------------------------
9 | % Permission to use, copy, or modify this software and its documentation
10 | % for educational and research purposes only and without fee is hereby
11 | % granted, provided that this copyright notice and the original authors'
12 | % names appear on all copies and supporting documentation. This program
13 | % shall not be used, rewritten, or adapted as the basis of a commercial
14 | % software or hardware product without first obtaining permission of the
15 | % authors. The authors make no representations about the suitability of
16 | % this software for any purpose. It is provided "as is" without express
17 | % or implied warranty.
18 | %----------------------------------------------------------------------
19 | %
20 | % This is an implementation of the algorithm for calculating the
21 | % Structural SIMilarity (SSIM) index between two images
22 | %
23 | % Please refer to the following paper and the website with suggested usage
24 | %
25 | % Z. Wang, A. C. Bovik, H. R. Sheikh, and E. P. Simoncelli, "Image
26 | % quality assessment: From error visibility to structural similarity,"
27 | % IEEE Transactions on Image Processing, vol. 13, no. 4, pp. 600-612,
28 | % Apr. 2004.
29 | %
30 | % http://www.ece.uwaterloo.ca/~z70wang/research/ssim/
31 | %
32 | % Note: This program is different from ssim_index.m, where no automatic
33 | % downsampling is performed. (downsampling was done in the above paper
34 | % and was described as suggested usage in the above website.)
35 | %
36 | % Kindly report any suggestions or corrections to zhouwang@ieee.org
37 | %
38 | %----------------------------------------------------------------------
39 | %
40 | %Input : (1) img1: the first image being compared
41 | % (2) img2: the second image being compared
42 | % (3) K: constants in the SSIM index formula (see the above
43 | %            reference). default value: K = [0.01 0.03]
44 | % (4) window: local window for statistics (see the above
45 | %            reference). default window is Gaussian given by
46 | % window = fspecial('gaussian', 11, 1.5);
47 | % (5) L: dynamic range of the images. default: L = 255
48 | %
49 | %Output: (1) mssim: the mean SSIM index value between 2 images.
50 | % If one of the images being compared is regarded as
51 | % perfect quality, then mssim can be considered as the
52 | % quality measure of the other image.
53 | % If img1 = img2, then mssim = 1.
54 | % (2) ssim_map: the SSIM index map of the test image. The map
55 | % has a smaller size than the input images. The actual size
56 | % depends on the window size and the downsampling factor.
57 | %
58 | %Basic Usage:
59 | % Given 2 test images img1 and img2, whose dynamic range is 0-255
60 | %
61 | % [mssim, ssim_map] = ssim(img1, img2);
62 | %
63 | %Advanced Usage:
64 | % User defined parameters. For example
65 | %
66 | % K = [0.05 0.05];
67 | % window = ones(8);
68 | % L = 100;
69 | % [mssim, ssim_map] = ssim(img1, img2, K, window, L);
70 | %
71 | %Visualize the results:
72 | %
73 | % mssim %Gives the mssim value
74 | % imshow(max(0, ssim_map).^4) %Shows the SSIM index map
75 | %========================================================================
76 |
77 |
78 | if (nargin < 2 || nargin > 5)
79 | mssim = -Inf;
80 | ssim_map = -Inf;
81 | return;
82 | end
83 |
84 | if (size(img1) ~= size(img2))
85 | mssim = -Inf;
86 | ssim_map = -Inf;
87 | return;
88 | end
89 |
90 | [M N] = size(img1);
91 |
92 | if (nargin == 2)
93 | if ((M < 11) || (N < 11))
94 | mssim = -Inf;
95 | ssim_map = -Inf;
96 | return
97 | end
98 | window = fspecial('gaussian', 11, 1.5); %
99 | K(1) = 0.01; % default settings
100 | K(2) = 0.03; %
101 | L = 255; %
102 | end
103 |
104 | if (nargin == 3)
105 | if ((M < 11) || (N < 11))
106 | mssim = -Inf;
107 | ssim_map = -Inf;
108 | return
109 | end
110 | window = fspecial('gaussian', 11, 1.5);
111 | L = 255;
112 | if (length(K) == 2)
113 | if (K(1) < 0 || K(2) < 0)
114 | mssim = -Inf;
115 | ssim_map = -Inf;
116 | return;
117 | end
118 | else
119 | mssim = -Inf;
120 | ssim_map = -Inf;
121 | return;
122 | end
123 | end
124 |
125 | if (nargin == 4)
126 | [H W] = size(window);
127 | if ((H*W) < 4 || (H > M) || (W > N))
128 | mssim = -Inf;
129 | ssim_map = -Inf;
130 | return
131 | end
132 | L = 255;
133 | if (length(K) == 2)
134 | if (K(1) < 0 || K(2) < 0)
135 | mssim = -Inf;
136 | ssim_map = -Inf;
137 | return;
138 | end
139 | else
140 | mssim = -Inf;
141 | ssim_map = -Inf;
142 | return;
143 | end
144 | end
145 |
146 | if (nargin == 5)
147 | [H W] = size(window);
148 | if ((H*W) < 4 || (H > M) || (W > N))
149 | mssim = -Inf;
150 | ssim_map = -Inf;
151 | return
152 | end
153 | if (length(K) == 2)
154 | if (K(1) < 0 || K(2) < 0)
155 | mssim = -Inf;
156 | ssim_map = -Inf;
157 | return;
158 | end
159 | else
160 | mssim = -Inf;
161 | ssim_map = -Inf;
162 | return;
163 | end
164 | end
165 |
166 |
167 | img1 = double(img1);
168 | img2 = double(img2);
169 |
170 | % automatic downsampling
171 | f = max(1,round(min(M,N)/256));
172 | %downsampling by f
173 | %use a simple low-pass filter
174 | if(f>1)
175 | lpf = ones(f,f);
176 | lpf = lpf/sum(lpf(:));
177 | img1 = imfilter(img1,lpf,'symmetric','same');
178 | img2 = imfilter(img2,lpf,'symmetric','same');
179 |
180 | img1 = img1(1:f:end,1:f:end);
181 | img2 = img2(1:f:end,1:f:end);
182 | end
183 |
184 | C1 = (K(1)*L)^2;
185 | C2 = (K(2)*L)^2;
186 | window = window/sum(sum(window));
187 |
188 | mu1 = filter2(window, img1, 'valid');
189 | mu2 = filter2(window, img2, 'valid');
190 | mu1_sq = mu1.*mu1;
191 | mu2_sq = mu2.*mu2;
192 | mu1_mu2 = mu1.*mu2;
193 | sigma1_sq = filter2(window, img1.*img1, 'valid') - mu1_sq;
194 | sigma2_sq = filter2(window, img2.*img2, 'valid') - mu2_sq;
195 | sigma12 = filter2(window, img1.*img2, 'valid') - mu1_mu2;
196 |
197 | if (C1 > 0 && C2 > 0)
198 | ssim_map = ((2*mu1_mu2 + C1).*(2*sigma12 + C2))./((mu1_sq + mu2_sq + C1).*(sigma1_sq + sigma2_sq + C2));
199 | else
200 | numerator1 = 2*mu1_mu2 + C1;
201 | numerator2 = 2*sigma12 + C2;
202 | denominator1 = mu1_sq + mu2_sq + C1;
203 | denominator2 = sigma1_sq + sigma2_sq + C2;
204 | ssim_map = ones(size(mu1));
205 | index = (denominator1.*denominator2 > 0);
206 | ssim_map(index) = (numerator1(index).*numerator2(index))./(denominator1(index).*denominator2(index));
207 | index = (denominator1 ~= 0) & (denominator2 == 0);
208 | ssim_map(index) = numerator1(index)./denominator1(index);
209 | end
210 |
211 | mssim = mean2(ssim_map);
212 |
213 | return
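
For completeness, when C1, C2 > 0 the map computed above is the standard SSIM index of Wang et al.:

\[
\mathrm{SSIM}(x, y) = \frac{(2\mu_x\mu_y + C_1)(2\sigma_{xy} + C_2)}
                           {(\mu_x^{2} + \mu_y^{2} + C_1)(\sigma_x^{2} + \sigma_y^{2} + C_2)},
\qquad C_1 = (K_1 L)^2,\; C_2 = (K_2 L)^2,
\]

with the local means, variances, and covariance estimated by filtering with the 11x11 Gaussian window, and mssim the mean of ssim_map after the automatic downsampling step.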
--------------------------------------------------------------------------------
/torchvision_x_functional.py:
--------------------------------------------------------------------------------
1 | import collections
2 | import numbers
3 | from functools import wraps
4 |
5 | import cv2
6 | import numpy as np
7 | import torch
8 | from PIL import Image
9 | from scipy.ndimage.filters import gaussian_filter
10 |
11 | __numpy_type_map = {
12 | 'float64': torch.DoubleTensor,
13 | 'float32': torch.FloatTensor,
14 | 'float16': torch.HalfTensor,
15 | 'int64': torch.LongTensor,
16 | 'int32': torch.IntTensor,
17 | 'int16': torch.ShortTensor,
18 | 'uint16': torch.ShortTensor,
19 | 'int8': torch.CharTensor,
20 | 'uint8': torch.ByteTensor,
21 | }
22 |
23 | '''image functional utils
24 |
25 | '''
26 |
27 | # NOTE: all functions should receive an ndarray-like image of shape W x H x C or W x H
28 |
29 | # If every output were arranged as height, width, channel, could to_tensor be skipped?? No, it cannot.
30 | def preserve_channel_dim(func):
31 | """Preserve dummy channel dim."""
32 | @wraps(func)
33 | def wrapped_function(img, *args, **kwargs):
34 | shape = img.shape
35 | result = func(img, *args, **kwargs)
36 | if len(shape) == 3 and shape[-1] == 1 and len(result.shape) == 2:
37 | result = np.expand_dims(result, axis=-1)
38 | return result
39 |
40 | return wrapped_function
41 |
42 |
43 | def _is_tensor_image(img):
44 | return torch.is_tensor(img) and img.ndimension() == 3
45 |
46 |
47 | def _is_numpy_image(img):
48 | return isinstance(img, np.ndarray) and (img.ndim in {2, 3})
49 |
50 |
51 | def to_tensor(img):
52 | '''convert numpy.ndarray to torch tensor. \n
53 |     if the image is uint8, it will be divided by 255;\n
54 |     if the image is uint16, it will be divided by 65535;\n
55 |     if the image is float, it will not be divided; its values are assumed to lie in [0, 1];\n
56 |
57 | Arguments:
58 | img {numpy.ndarray} -- image to be converted to tensor.
59 | '''
60 | if not _is_numpy_image(img):
61 |         raise TypeError('data should be a numpy ndarray, but got {}'.format(type(img)))
62 |
63 | if img.ndim == 2:
64 | img = img[:, :, None]
65 |
66 | if img.dtype == np.uint8:
67 | img = img.astype(np.float32)/255
68 | elif img.dtype == np.uint16:
69 | img = img.astype(np.float32)/65535
70 | elif img.dtype in [np.float32, np.float64]:
71 | img = img.astype(np.float32)/1
72 | else:
73 |         raise TypeError('{} is not supported'.format(img.dtype))
74 |
75 | img = torch.from_numpy(img.transpose((2, 0, 1)))
76 |
77 | return img
78 |
79 |
80 | def to_pil_image(tensor):
81 | # TODO
82 | pass
83 |
84 |
85 | def to_tiff_image(tensor):
86 | # TODO
87 | pass
88 |
89 |
90 | def normalize(tensor, mean, std, inplace=False):
91 | """Normalize a tensor image with mean and standard deviation.
92 |
93 | .. note::
94 |         This transform acts out of place by default, i.e., it does not mutate the input tensor.
95 |
96 | See :class:`~torchsat.transforms.Normalize` for more details.
97 |
98 | Args:
99 | tensor (Tensor): Tensor image of size (C, H, W) to be normalized.
100 | mean (sequence): Sequence of means for each channel.
101 | std (sequence): Sequence of standard deviations for each channel.
102 |
103 | Returns:
104 | Tensor: Normalized Tensor image.
105 | """
106 | if not _is_tensor_image(tensor):
107 | raise TypeError('tensor is not a torch image.')
108 |
109 | if not inplace:
110 | tensor = tensor.clone()
111 |
112 | mean = torch.as_tensor(mean, dtype=tensor.dtype, device=tensor.device)
113 | std = torch.as_tensor(std, dtype=tensor.dtype, device=tensor.device)
114 | tensor.sub_(mean[:, None, None]).div_(std[:, None, None])
115 | return tensor
116 |
117 | def noise(img, mode='gaussian', percent=0.02):
118 | """
119 | TODO: Not good for uint16 data
120 | """
121 | original_dtype = img.dtype
122 | if mode == 'gaussian':
123 | mean = 0
124 | var = 0.1
125 | sigma = var*0.5
126 |
127 | if img.ndim == 2:
128 | h, w = img.shape
129 | gauss = np.random.normal(mean, sigma, (h, w))
130 | else:
131 | h, w, c = img.shape
132 | gauss = np.random.normal(mean, sigma, (h, w, c))
133 |
134 | if img.dtype not in [np.float32, np.float64]:
135 | gauss = gauss * np.iinfo(img.dtype).max
136 | img = np.clip(img.astype(np.float) + gauss, 0, np.iinfo(img.dtype).max)
137 | else:
138 | img = np.clip(img.astype(np.float) + gauss, 0, 1)
139 |
140 | elif mode == 'salt':
141 | print(img.dtype)
142 | s_vs_p = 1
143 | num_salt = np.ceil(percent * img.size * s_vs_p)
144 | coords = tuple([np.random.randint(0, i - 1, int(num_salt)) for i in img.shape])
145 |
146 | if img.dtype in [np.float32, np.float64]:
147 | img[coords] = 1
148 | else:
149 | img[coords] = np.iinfo(img.dtype).max
150 | print(img.dtype)
151 | elif mode == 'pepper':
152 | s_vs_p = 0
153 | num_pepper = np.ceil(percent * img.size * (1. - s_vs_p))
154 | coords = tuple([np.random.randint(0, i - 1, int(num_pepper)) for i in img.shape])
155 | img[coords] = 0
156 |
157 | elif mode == 's&p':
158 | s_vs_p = 0.5
159 |
160 | # Salt mode
161 | num_salt = np.ceil(percent * img.size * s_vs_p)
162 | coords = tuple([np.random.randint(0, i - 1, int(num_salt)) for i in img.shape])
163 | if img.dtype in [np.float32, np.float64]:
164 | img[coords] = 1
165 | else:
166 | img[coords] = np.iinfo(img.dtype).max
167 |
168 | # Pepper mode
169 | num_pepper = np.ceil(percent* img.size * (1. - s_vs_p))
170 | coords = tuple([np.random.randint(0, i - 1, int(num_pepper)) for i in img.shape])
171 | img[coords] = 0
172 | else:
173 |         raise ValueError('unsupported noise mode: {}'.format(mode))
174 |
175 | noisy = img.astype(original_dtype)
176 |
177 | return noisy
178 |
179 |
180 | def gaussian_blur(img, kernel_size):
181 | # When sigma=0, it is computed as `sigma = 0.3*((ksize-1)*0.5 - 1) + 0.8`
182 | return cv2.GaussianBlur(img, (kernel_size, kernel_size), sigmaX=0)
183 |
184 |
185 | def adjust_brightness(img, value=0):
186 |     if np.issubdtype(img.dtype, np.floating):
187 |         dtype_min, dtype_max = 0, 1
188 |         dtype = np.float32
189 |     else:
190 |         dtype_min = np.iinfo(img.dtype).min
191 |         dtype_max = np.iinfo(img.dtype).max
192 |         dtype = img.dtype
193 | 
194 |     result = np.clip(img.astype(np.float64)+value, dtype_min, dtype_max).astype(dtype)
195 |
196 | return result
197 |
198 |
199 | def adjust_contrast(img, factor):
200 |     if np.issubdtype(img.dtype, np.floating):
201 |         dtype_min, dtype_max = 0, 1
202 |         dtype = np.float32
203 |     else:
204 |         dtype_min = np.iinfo(img.dtype).min
205 |         dtype_max = np.iinfo(img.dtype).max
206 |         dtype = img.dtype
207 | 
208 |     result = np.clip(img.astype(np.float64)*factor, dtype_min, dtype_max).astype(dtype)
209 |
210 | return result
211 |
212 | def adjust_saturation():
213 | # TODO
214 | pass
215 |
216 | def adjust_hue():
217 | # TODO
218 | pass
219 |
220 |
221 |
222 | def to_grayscale(img, output_channels=1):
223 |     """convert the input ndarray image to a grayscale image.
224 |
225 | Arguments:
226 | img {ndarray} -- the input ndarray image
227 |
228 | Keyword Arguments:
229 | output_channels {int} -- output gray image channel (default: {1})
230 |
231 | Returns:
232 | ndarray -- gray scale ndarray image
233 | """
234 | if img.ndim == 2:
235 | gray_img = img
236 | elif img.shape[2] == 3:
237 | gray_img = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
238 | else:
239 | gray_img = np.mean(img, axis=2)
240 | gray_img = gray_img.astype(img.dtype)
241 |
242 | if output_channels != 1:
243 | gray_img = np.tile(gray_img, (output_channels, 1, 1))
244 | gray_img = np.transpose(gray_img, [1,2,0])
245 |
246 | return gray_img
247 |
248 |
249 | def shift(img, top, left):
250 | (h, w) = img.shape[0:2]
251 | matrix = np.float32([[1, 0, left], [0, 1, top]])
252 | dst = cv2.warpAffine(img, matrix, (w, h))
253 |
254 | return dst
255 |
256 |
257 | def rotate(img, angle, center=None, scale=1.0):
258 | (h, w) = img.shape[:2]
259 |
260 | if center is None:
261 | center = (w / 2, h / 2)
262 |
263 | M = cv2.getRotationMatrix2D(center, angle, scale)
264 | rotated = cv2.warpAffine(img, M, (w, h))
265 |
266 | return rotated
267 |
268 |
269 | def resize(img, size, interpolation=cv2.INTER_LINEAR):
270 |     '''resize the image
271 |     TODO: after the OpenCV resize the image values become 0~1
272 |     Arguments:
273 |         img {ndarray} -- the input ndarray image
274 |         size {int, iterable} -- the target size; if size is an integer, width and height are resized to the same value, \
275 |             otherwise size should be a tuple (height, width) or list [height, width]
276 | 
277 | 
278 |     Keyword Arguments:
279 |         interpolation {int} -- the OpenCV interpolation method (default: {cv2.INTER_LINEAR})
280 | 
281 |     Raises:
282 |         TypeError -- img should be ndarray
283 |         ValueError -- size should be an integer or an iterable of length 2.
284 | 
285 |     Returns:
286 |         img -- resized ndarray image
287 |     '''
288 |
289 | if not _is_numpy_image(img):
290 |         raise TypeError('img should be ndarray image [w, h, c] or [w, h], but got {}'.format(type(img)))
291 |     if not (isinstance(size, int) or (isinstance(size, collections.abc.Iterable) and len(size)==2)):
292 |         raise ValueError('size should be an integer or an iterable of length 2, but got {}'.format(type(size)))
293 |
294 | if isinstance(size, int):
295 | height, width = (size, size)
296 | else:
297 | height, width = (size[0], size[1])
298 |
299 | return cv2.resize(img, (width, height), interpolation=interpolation)
300 |
301 |
302 | def pad(img, padding, fill=0, padding_mode='constant'):
303 | if isinstance(padding, int):
304 | pad_left = pad_right = pad_top = pad_bottom = padding
305 |     if isinstance(padding, collections.abc.Iterable) and len(padding) == 2:
306 | pad_left = pad_right = padding[0]
307 | pad_bottom = pad_top = padding[1]
308 |     if isinstance(padding, collections.abc.Iterable) and len(padding) == 4:
309 | pad_left = padding[0]
310 | pad_top = padding[1]
311 | pad_right = padding[2]
312 | pad_bottom = padding[3]
313 |
314 | if img.ndim == 2:
315 | if padding_mode == 'constant':
316 | img = np.pad(img, ((pad_top, pad_bottom), (pad_left, pad_right)), mode=padding_mode, constant_values=fill)
317 | else:
318 | img = np.pad(img, ((pad_top, pad_bottom), (pad_left, pad_right)), mode=padding_mode)
319 | if img.ndim == 3:
320 | if padding_mode == 'constant':
321 | img = np.pad(img, ((pad_top, pad_bottom), (pad_left, pad_right), (0, 0)), mode=padding_mode, constant_values=fill)
322 | else:
323 | img = np.pad(img, ((pad_top, pad_bottom), (pad_left, pad_right), (0, 0)), mode=padding_mode)
324 | return img
325 |
326 |
327 | def crop(img, top, left, height, width):
328 | '''crop image
329 |
330 | Arguments:
331 |         img {ndarray} -- image to be cropped
332 |         top {int} -- top size
333 |         left {int} -- left size
334 |         height {int} -- cropped height
335 |         width {int} -- cropped width
336 | '''
337 | if not _is_numpy_image(img):
338 |         raise TypeError('the input image should be a numpy ndarray with dimension 2 or 3, '
339 |                         'but got {}'.format(type(img))
340 | )
341 |
342 |     if width < 0 or height < 0 or left < 0 or top < 0:
343 |         raise ValueError('the input left, top, width, height should be non-negative, '
344 |                          'but got left={}, top={} width={} height={}'.format(left, top, width, height)
345 | )
346 | if img.ndim == 2:
347 | img_height, img_width = img.shape
348 | else:
349 | img_height, img_width, _ = img.shape
350 | if (left+width) > img_width or (top+height) > img_height:
351 |         raise ValueError('the input crop width and height should be smaller than or \
352 |             equal to the image width and height.')
353 |
354 | if img.ndim == 2:
355 | return img[top:(top+height), left:(left+width)]
356 | elif img.ndim == 3:
357 | return img[top:(top+height), left:(left+width), :]
358 |
359 |
360 | def center_crop(img, output_size):
361 |     '''crop the image at the center
362 |
363 | Arguments:
364 | img {ndarray} -- input image
365 | output_size {number or sequence} -- the output image size. if sequence, should be [h, w]
366 |
367 | Raises:
368 |         ValueError -- the output size is larger than the input image size.
369 | 
370 |     Returns:
371 |         ndarray image -- the cropped ndarray image.
372 | '''
373 | if img.ndim == 2:
374 | img_height, img_width = img.shape
375 | else:
376 | img_height, img_width, _ = img.shape
377 |
378 | if isinstance(output_size, numbers.Number):
379 | output_size = (int(output_size), int(output_size))
380 | if output_size[0] > img_height or output_size[1] > img_width:
381 |         raise ValueError('the output_size should not be greater than the image size, but got {}'.format(output_size))
382 |
383 | target_height, target_width = output_size
384 |
385 | top = int(round((img_height - target_height)/2))
386 | left = int(round((img_width - target_width)/2))
387 |
388 | return crop(img, top, left, target_height, target_width)
389 |
390 |
391 | def resized_crop(img, top, left, height, width, size, interpolation=cv2.INTER_LINEAR):
392 |
393 | img = crop(img, top, left, height, width)
394 | img = resize(img, size, interpolation)
395 | return img
396 |
397 | def vflip(img):
398 | return cv2.flip(img, 0)
399 |
400 | def hflip(img):
401 | return cv2.flip(img, 1)
402 |
403 | def flip(img, flip_code):
404 | return cv2.flip(img, flip_code)
405 |
406 |
407 | def elastic_transform(image, alpha, sigma, alpha_affine, interpolation=cv2.INTER_LINEAR,
408 | border_mode=cv2.BORDER_REFLECT_101, random_state=None, approximate=False):
409 | """Elastic deformation of images as described in [Simard2003]_ (with modifications).
410 | Based on https://gist.github.com/erniejunior/601cdf56d2b424757de5
411 | .. [Simard2003] Simard, Steinkraus and Platt, "Best Practices for
412 | Convolutional Neural Networks applied to Visual Document Analysis", in
413 | Proc. of the International Conference on Document Analysis and
414 | Recognition, 2003.
415 | """
416 | if random_state is None:
417 | random_state = np.random.RandomState(1234)
418 |
419 | height, width = image.shape[:2]
420 |
421 | # Random affine
422 | center_square = np.float32((height, width)) // 2
423 | square_size = min((height, width)) // 3
424 | alpha = float(alpha)
425 | sigma = float(sigma)
426 | alpha_affine = float(alpha_affine)
427 |
428 | pts1 = np.float32([center_square + square_size, [center_square[0] + square_size, center_square[1] - square_size],
429 | center_square - square_size])
430 | pts2 = pts1 + random_state.uniform(-alpha_affine, alpha_affine, size=pts1.shape).astype(np.float32)
431 | matrix = cv2.getAffineTransform(pts1, pts2)
432 |
433 | image = cv2.warpAffine(image, matrix, (width, height), flags=interpolation, borderMode=border_mode)
434 |
435 | if approximate:
436 |         # Approximate the smooth displacement map by blurring with a fixed large kernel.
437 |         # On large images (512+) this is roughly 2x faster.
438 | dx = (random_state.rand(height, width).astype(np.float32) * 2 - 1)
439 | cv2.GaussianBlur(dx, (17, 17), sigma, dst=dx)
440 | dx *= alpha
441 |
442 | dy = (random_state.rand(height, width).astype(np.float32) * 2 - 1)
443 | cv2.GaussianBlur(dy, (17, 17), sigma, dst=dy)
444 | dy *= alpha
445 | else:
446 | dx = np.float32(gaussian_filter((random_state.rand(height, width) * 2 - 1), sigma) * alpha)
447 | dy = np.float32(gaussian_filter((random_state.rand(height, width) * 2 - 1), sigma) * alpha)
448 |
449 | x, y = np.meshgrid(np.arange(width), np.arange(height))
450 |
451 | mapx = np.float32(x + dx)
452 | mapy = np.float32(y + dy)
453 |
454 | return cv2.remap(image, mapx, mapy, interpolation, borderMode=border_mode)
455 |
456 |
457 | def bbox_shift(bboxes, top, left):
458 | pass
459 |
460 |
461 | def bbox_vflip(bboxes, img_height):
462 | """vertical flip the bboxes
463 | ...........
464 | . .
465 | . .
466 | >...........<
467 | . .
468 | . .
469 | ...........
470 |     Args:
471 |         bboxes (ndarray): bbox ndarray [box_nums, 4]
472 |         img_height (int): height of the image the bboxes belong to
473 | """
474 | flipped = bboxes.copy()
475 | flipped[...,1::2] = img_height - bboxes[...,1::2]
476 | flipped = flipped[..., [0, 3, 2, 1]]
477 | return flipped
478 |
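# Worked example for `bbox_vflip` (illustrative, assuming [x1, y1, x2, y2] boxes):
# with img_height=100 and a box [10, 20, 30, 40], the y-coordinates map to
# 100-20=80 and 100-40=60, and the reindexing [0, 3, 2, 1] restores y1 <= y2,
# giving [10, 60, 30, 80].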
479 |
480 | def bbox_hflip(bboxes, img_width):
481 | """horizontal flip the bboxes
482 | ^
483 | .............
484 | . . .
485 | . . .
486 | . . .
487 | . . .
488 | .............
489 | ^
490 |     Args:
491 |         bboxes (ndarray): bbox ndarray [box_nums, 4]
492 |         img_width (int): width of the image the bboxes belong to
493 | """
494 | flipped = bboxes.copy()
495 | flipped[..., 0::2] = img_width - bboxes[...,0::2]
496 | flipped = flipped[..., [2, 1, 0, 3]]
497 | return flipped
498 |
499 |
500 | def bbox_resize(bboxes, img_size, target_size):
501 | """resize the bbox
502 |
503 | Args:
504 | bboxes (ndarray): bbox ndarray [box_nums, 4]
505 | img_size (tuple): the image height and width
506 | target_size (int, or tuple): the target bbox size.
507 | Int or Tuple, if tuple the shape should be (height, width)
508 | """
509 | if isinstance(target_size, numbers.Number):
510 | target_size = (target_size, target_size)
511 |
512 | ratio_height = target_size[0]/img_size[0]
513 | ratio_width = target_size[1]/img_size[1]
514 |
515 |     return bboxes * [ratio_width, ratio_height, ratio_width, ratio_height]
516 |
517 |
518 | def bbox_crop(bboxes, top, left, height, width):
519 |     '''crop the bboxes to the given image crop region
520 | 
521 |     Arguments:
522 |         bboxes {ndarray} -- bboxes to be cropped, shape [box_nums, 4]
523 |         top {int} -- top offset of the crop
524 |         left {int} -- left offset of the crop
525 |         height {int} -- cropped height
526 |         width {int} -- cropped width
527 | '''
528 | croped_bboxes = bboxes.copy()
529 |
530 | right = width + left
531 | bottom = height + top
532 |
533 | croped_bboxes[..., 0::2] = bboxes[..., 0::2].clip(left, right) - left
534 | croped_bboxes[..., 1::2] = bboxes[..., 1::2].clip(top, bottom) - top
535 |
536 | return croped_bboxes
537 |
538 | def bbox_pad(bboxes, padding):
539 | if isinstance(padding, int):
540 | pad_left = pad_right = pad_top = pad_bottom = padding
541 |     if isinstance(padding, collections.abc.Iterable) and len(padding) == 2:
542 | pad_left = pad_right = padding[0]
543 | pad_bottom = pad_top = padding[1]
544 |     if isinstance(padding, collections.abc.Iterable) and len(padding) == 4:
545 | pad_left = padding[0]
546 | pad_top = padding[1]
547 | pad_right = padding[2]
548 | pad_bottom = padding[3]
549 |
550 | pad_bboxes = bboxes.copy()
551 | pad_bboxes[..., 0::2] = bboxes[..., 0::2] + pad_left
552 | pad_bboxes[..., 1::2] = bboxes[..., 1::2] + pad_top
553 |
554 | return pad_bboxes
555 |
--------------------------------------------------------------------------------
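The functional helpers above are written so that geometric transforms on the ndarray image (`resize`, `pad`, `crop`, `hflip`, `vflip`, ...) each have a `bbox_*` counterpart that applies the same coordinate mapping to the boxes. A minimal sketch of using them together (the import alias and the sample data below are assumptions for illustration, not taken from the repo):

    # sketch: apply the same geometric transforms to an image and its boxes
    import numpy as np
    import torchvision_x_functional as F  # assumed import alias

    img = np.random.randint(0, 256, (480, 640, 3), dtype=np.uint8)   # H, W, C
    boxes = np.array([[100, 120, 220, 260]], dtype=np.float32)       # [x1, y1, x2, y2]

    # horizontal flip: mirror the pixels and the box x-coordinates together
    img_f = F.hflip(img)
    boxes_f = F.bbox_hflip(boxes, img_width=img.shape[1])

    # resize to 256x256 and scale the boxes by the same height/width ratios
    img_r = F.resize(img_f, (256, 256))
    boxes_r = F.bbox_resize(boxes_f, img_size=img_f.shape[:2], target_size=(256, 256))
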
/trilinear_c/build.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 | import os
3 | import torch
4 | from torch.utils.ffi import create_extension
5 |
6 | sources = ['src/trilinear.c']
7 | headers = ['src/trilinear.h']
8 | extra_objects = []
9 | #sources = []
10 | #headers = []
11 | defines = []
12 | with_cuda = False
13 |
14 | this_file = os.path.dirname(os.path.realpath(__file__))
15 | print(this_file)
16 |
17 | if torch.cuda.is_available():
18 | print('Including CUDA code.')
19 | sources += ['src/trilinear_cuda.c']
20 | headers += ['src/trilinear_cuda.h']
21 | defines += [('WITH_CUDA', None)]
22 | with_cuda = True
23 |
24 | extra_objects = ['src/trilinear_kernel.cu.o']
25 | extra_objects = [os.path.join(this_file, fname) for fname in extra_objects]
26 |
27 | ffi = create_extension(
28 | '_ext.trilinear',
29 | headers=headers,
30 | sources=sources,
31 | define_macros=defines,
32 | relative_to=__file__,
33 | with_cuda=with_cuda,
34 | extra_objects=extra_objects
35 | )
36 |
37 | if __name__ == '__main__':
38 | ffi.build()
39 |
--------------------------------------------------------------------------------
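Note that `torch.utils.ffi` only exists in PyTorch 0.x and was removed in PyTorch 1.0, so this build script targets legacy installs; the `trilinear_cpp` directory provides the build path for newer releases via `torch.utils.cpp_extension`. For orientation, a minimal sketch of such a setup script (built around the source names under `trilinear_cpp/src`; not the exact `setup.py` shipped in the repo) could look like:

    # sketch: modern replacement for the FFI build, using torch.utils.cpp_extension
    from setuptools import setup
    from torch.utils.cpp_extension import CUDAExtension, BuildExtension

    setup(
        name='trilinear',
        ext_modules=[
            CUDAExtension(
                name='trilinear',
                sources=['src/trilinear_cuda.cpp', 'src/trilinear_kernel.cu'],
            ),
        ],
        cmdclass={'build_ext': BuildExtension},
    )
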
/trilinear_c/make.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | CUDA_PATH=/usr/local/cuda/
4 |
5 | cd src
6 | echo "Compiling my_lib kernels by nvcc..."
7 | nvcc -c -o trilinear_kernel.cu.o trilinear_kernel.cu -x cu -Xcompiler -fPIC -arch=sm_52
8 |
9 | cd ../
10 | python3 build.py
11 |
--------------------------------------------------------------------------------
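The kernel is compiled with `-arch=sm_52` (Maxwell); on a newer GPU or CUDA toolkit you may want to match the flag to the device's compute capability. A small helper to print the value to use (an illustrative snippet, not part of the repo):

    # sketch: print the -arch value matching the local GPU, for use in make.sh
    import torch

    if torch.cuda.is_available():
        major, minor = torch.cuda.get_device_capability(0)
        print('nvcc flag: -arch=sm_{}{}'.format(major, minor))
    else:
        print('no CUDA device detected; the .cu.o object is only needed for the GPU build')
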
/trilinear_c/src/trilinear.c:
--------------------------------------------------------------------------------
1 | #include