├── .gitattributes
├── .gitignore
├── README.md
├── base_layers.py
├── base_parser.py
├── base_trainer.py
├── checkpoints
│   ├── decom_net.pth
│   ├── illum_net.pth
│   └── restore_net.pth
├── config.yaml
├── dataloader.py
├── decom_trainer.py
├── evaluate.py
├── figures
│   ├── 778_epoch_-1.png
│   ├── failure_case_decom.png
│   └── official_decom_train.png
├── illum_trainer.py
├── illum_trainer_custom.py
├── losses.py
├── models.py
├── pytorch_ssim
│   └── __init__.py
├── restore_MSIA_trainer.py
├── restore_trainer.py
├── test_your_pictures.py
├── utils.py
└── utils
    └── img_generator.py
/.gitattributes:
--------------------------------------------------------------------------------
1 | # Auto detect text files and perform LF normalization
2 | * text=auto
3 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # custom ignore
2 | *.h5
3 | *.jpg
4 | *.zip
5 | weights/
6 | images/
7 |
8 | # Byte-compiled / optimized / DLL files
9 | __pycache__/
10 | *.py[cod]
11 | *$py.class
12 |
13 | # C extensions
14 | *.so
15 |
16 | # Distribution / packaging
17 | .Python
18 | build/
19 | develop-eggs/
20 | dist/
21 | downloads/
22 | eggs/
23 | .eggs/
24 | lib/
25 | lib64/
26 | parts/
27 | sdist/
28 | var/
29 | wheels/
30 | pip-wheel-metadata/
31 | share/python-wheels/
32 | *.egg-info/
33 | .installed.cfg
34 | *.egg
35 | MANIFEST
36 |
37 | # PyInstaller
38 | # Usually these files are written by a python script from a template
39 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
40 | *.manifest
41 | *.spec
42 |
43 | # Installer logs
44 | pip-log.txt
45 | pip-delete-this-directory.txt
46 |
47 | # Unit test / coverage reports
48 | htmlcov/
49 | .tox/
50 | .nox/
51 | .coverage
52 | .coverage.*
53 | .cache
54 | nosetests.xml
55 | coverage.xml
56 | *.cover
57 | *.py,cover
58 | .hypothesis/
59 | .pytest_cache/
60 |
61 | # Translations
62 | *.mo
63 | *.pot
64 |
65 | # Django stuff:
66 | *.log
67 | local_settings.py
68 | db.sqlite3
69 | db.sqlite3-journal
70 |
71 | # Flask stuff:
72 | instance/
73 | .webassets-cache
74 |
75 | # Scrapy stuff:
76 | .scrapy
77 |
78 | # Sphinx documentation
79 | docs/_build/
80 |
81 | # PyBuilder
82 | target/
83 |
84 | # Jupyter Notebook
85 | .ipynb_checkpoints
86 |
87 | # IPython
88 | profile_default/
89 | ipython_config.py
90 |
91 | # pyenv
92 | .python-version
93 |
94 | # pipenv
95 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
96 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
97 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
98 | # install all needed dependencies.
99 | #Pipfile.lock
100 |
101 | # celery beat schedule file
102 | celerybeat-schedule
103 |
104 | # SageMath parsed files
105 | *.sage.py
106 |
107 | # Environments
108 | .env
109 | .venv
110 | env/
111 | venv/
112 | ENV/
113 | env.bak/
114 | venv.bak/
115 |
116 | # Spyder project settings
117 | .spyderproject
118 | .spyproject
119 |
120 | # Rope project settings
121 | .ropeproject
122 |
123 | # mkdocs documentation
124 | /site
125 |
126 | # mypy
127 | .mypy_cache/
128 | .dmypy.json
129 | dmypy.json
130 |
131 | # Pyre type checker
132 | .pyre/
133 | *.json
134 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # KinD-pytorch
2 | This is a PyTorch implementation of KinD.
3 |
4 | The official KinD project (TensorFlow) is [KinD](https://github.com/zhangyhuaee/KinD).
5 |
6 | The KinD net was proposed in the following [Paper](http://doi.acm.org/10.1145/3343031.3350926).
7 |
8 | Kindling the Darkness: a Practical Low-light Image Enhancer. In ACM MM 2019
9 | Yonghua Zhang, Jiawan Zhang, Xiaojie Guo
10 | ****
11 |
12 | ## Environment ##
13 | 1. Python = 3.6
14 | 2. PyTorch = 1.2.0
15 | 3. Other common packages
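
For reference, one possible environment setup (the package versions beyond Python and PyTorch are assumptions based on the repository's imports; torchvision 0.4.0 is the release paired with PyTorch 1.2.0):
```shell
conda create -n kind python=3.6
conda activate kind
pip install torch==1.2.0 torchvision==0.4.0
pip install pyyaml matplotlib opencv-python pillow torchsummary tqdm
```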
16 |
17 | ## Test ##
18 | Please put test images into the './images/inputs' folder, download the pre-trained checkpoints from [BaiduNetDisk](https://pan.baidu.com/s/1e_P6_qxQqAwDG7q6NN_2ng) (extraction code: fxkl), then just run
19 | ```shell
20 | python test_your_pictures.py
21 | # -i: change input path, -o: change output path, -p: Plot more information
22 | # -c: change checkpoints path, -b: change default target brightness
23 | ```
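
The `-b` flag sets a target brightness, which is converted internally into an illumination magnification ratio before being fed to the illumination net. A minimal sketch of that mapping, mirroring the formula used in `evaluate.py` (the function name `target_ratio` is illustrative):
```python
import torch

def target_ratio(I_low: torch.Tensor, b: float = 0.7, w: float = 0.5) -> torch.Tensor:
    """Turn a target brightness b into the illumination magnification ratio,
    as in evaluate.py: bright_high = b + w * mean(I_low)."""
    bright_low = torch.mean(I_low)  # mean of the predicted low-light illumination map
    bright_high = torch.ones_like(bright_low) * b + bright_low * w
    return torch.div(bright_high, bright_low)  # ratio handed to the illumination net
```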
24 |
25 | ## Train ##
26 | The original LOLdataset can be downloaded from [here](https://daooshee.github.io/BMVC2018website/).
27 | For training, **please change the dataset path in the code**, then run
28 | ```shell
29 | python decom_trainer.py
30 | python illum_trainer.py
31 | python restore_trainer.py
32 | ```
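
The trainers read image pairs from a `pair_list.csv` inside each dataset folder; `dataloader.py` provides a helper that builds it from the `low/` and `high/` subfolders. A minimal sketch (the dataset paths are illustrative):
```python
from dataloader import build_LOLDataset_list_txt

# writes <dir>/pair_list.csv with one "low/xxx.png,high/xxx.png" row per pair
build_LOLDataset_list_txt(r'./LOLdataset/our485')
build_LOLDataset_list_txt(r'./LOLdataset/eval15')
```
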
33 | You can also evaluate on the LOLdataset; **please change the dataset path in the code**, then run
34 | ```shell
35 | python evaluate.py
36 | ```
37 |
38 | ## Problems ##
39 | I ran into some serious problems when training the decomposition net, which make the results look unpleasant.
40 |
41 | ### My PyTorch implementation's evaluation on LOLDataset: ###
42 |
43 | The problem that confuses me the most is the illumination_smoothness_loss. As soon as I add this loss, the output illumination map tends to be completely black (for the low-light input) and close to uniform gray (for the high-light input).
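
For reference, here is a minimal sketch of the edge-aware smoothness term as defined in the KinD paper (an assumption about the intended form; the actual implementation lives in `losses.py` and may differ): gradients of the illumination map are penalized except where the input image itself has strong edges.
```python
import torch

def gradient(img, direction):
    """Forward-difference gradient along 'x' (width) or 'y' (height)."""
    if direction == 'x':
        return img[:, :, :, 1:] - img[:, :, :, :-1]
    return img[:, :, 1:, :] - img[:, :, :-1, :]

def illum_smooth_loss(I, L, eps=0.01):
    """Edge-aware smoothness: mean of |grad(I)| / max(|grad(L)|, eps)."""
    L_gray = torch.mean(L, dim=1, keepdim=True)  # collapse RGB to match I's single channel
    loss = 0.
    for d in ('x', 'y'):
        grad_I = torch.abs(gradient(I, d))
        grad_L = torch.abs(gradient(L_gray, d))
        loss = loss + torch.mean(grad_I / torch.clamp(grad_L, min=eps))
    return loss
```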
44 |
45 | ### My PyTorch implementation's failure case on LOLDataset: ###
46 |
47 | I have run the official TensorFlow code to train and test the decomposition net, and the result is pretty strange. If I load the official checkpoints for testing, it performs well. However, if I retrain it on the LOLDataset, the results get worse and worse. I am really puzzled by this issue; if you have any idea about it, please tell me.
48 |
49 | An example is shown below.
50 | ### Official implementation's strange case on LOLDataset: ###
51 |
52 | The left column shows the decomposition results of the high-light image, and the right column shows the decomposition results of the low-light image. Within each image, the left half is the reflectance map and the right half is the illumination map. The first row shows the effect of the official weights, the second row the official code retrained for 100 epochs, and the third row the official code trained for 1600 epochs. The released training code cannot reproduce the quality of the official weights.
53 |
54 | Other test results on the LOLDataset (eval15) can be found in the samples-KinD folder on [BaiduNetDisk](https://pan.baidu.com/s/1e_P6_qxQqAwDG7q6NN_2ng), extraction code: fxkl.
55 |
56 | ## References ##
57 | [1] Y. Zhang, J. Zhang, and X. Guo, “Kindling the darkness: A practical low-light image enhancer,” in ACM MM, 2019, pp. 1632–1640.
58 |
--------------------------------------------------------------------------------
/base_layers.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | class MSIA(nn.Module):  # Multi-Scale Illumination Attention module (cf. KinD++): fuses features at 1x, 1/2, 1/4 scales
6 | def __init__(self, filters, activation='lrelu'):
7 | super().__init__()
8 | # Down 1
9 | self.conv_bn_relu_1 = Conv_BN_Relu(filters, activation)
10 | # Down 2
11 | self.down_2 = MaxPooling2D(2, 2)
12 | self.conv_bn_relu_2 = Conv_BN_Relu(filters, activation)
13 | self.deconv_2 = ConvTranspose2D(filters, filters)
14 | # Down 4
15 | self.down_4 = MaxPooling2D(2, 2)
16 | self.conv_bn_relu_4 = Conv_BN_Relu(filters, activation, kernel=1)
17 | self.deconv_4_1 = ConvTranspose2D(filters, filters)
18 | self.deconv_4_2 = ConvTranspose2D(filters, filters)
19 | # output
20 | self.out = Conv2D(filters*4, filters)
21 |
22 | def forward(self, R, I_att):
23 | R_att = R * I_att
24 | # Down 1
25 | msia_1 = self.conv_bn_relu_1(R_att)
26 | # Down 2
27 | down_2 = self.down_2(R_att)
28 | conv_bn_relu_2 = self.conv_bn_relu_2(down_2)
29 | msia_2 = self.deconv_2(conv_bn_relu_2)
30 | # Down 4
31 | down_4 = self.down_4(down_2)
32 | conv_bn_relu_4 = self.conv_bn_relu_4(down_4)
33 | deconv_4 = self.deconv_4_1(conv_bn_relu_4)
34 | msia_4 = self.deconv_4_2(deconv_4)
35 | # concat
36 | concat = torch.cat([R, msia_1, msia_2, msia_4], dim=1)
37 | out = self.out(concat)
38 | return out
39 |
40 |
41 | class Conv_BN_Relu(nn.Module):
42 | def __init__(self, channels, activation='lrelu', kernel=3):
43 | super().__init__()
44 | self.ActivationLayer = nn.LeakyReLU(inplace=True)
45 | if activation == 'relu':
46 | self.ActivationLayer = nn.ReLU(inplace=True)
47 | self.conv_bn_relu = nn.Sequential(
48 | nn.Conv2d(channels, channels, kernel_size=kernel, padding=kernel//2),
49 |             nn.BatchNorm2d(channels, momentum=0.99), # the original paper uses tf.layers' default parameters
50 | self.ActivationLayer,
51 | )
52 |
53 | def forward(self, x):
54 | return self.conv_bn_relu(x)
55 |
56 |
57 | class DoubleConv(nn.Module):
58 | def __init__(self, in_channels, out_channels, activation='lrelu'):
59 | super().__init__()
60 | self.doubleconv = nn.Sequential(
61 | Conv2D(in_channels, out_channels, activation),
62 | Conv2D(out_channels,out_channels, activation)
63 | )
64 |
65 | def forward(self, x):
66 | return self.doubleconv(x)
67 |
68 | class ResConv(nn.Module):
69 | def __init__(self, in_channels, out_channels, activation='lrelu'):
70 | super().__init__()
71 | self.relu = nn.LeakyReLU(0.2, inplace=True)
72 | if activation == 'relu':
73 | self.relu = nn.ReLU(inplace=True)
74 |
75 | self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
76 | self.bn1 = nn.BatchNorm2d(out_channels, momentum=0.8)
77 | self.cbam = CBAM(out_channels)
78 | self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
79 | self.bn2 = nn.BatchNorm2d(out_channels, momentum=0.8)
80 |
81 | def forward(self, x):
82 | conv1 = self.conv1(x)
83 | bn1 = self.bn1(conv1)
84 | x1 = self.relu(bn1)
85 | cbam = self.cbam(x1)
86 | conv2 = self.conv2(cbam)
87 |         bn2 = self.bn2(conv2)
88 | out = bn2 + x
89 | return out
90 |
91 | class Conv2D(nn.Module):
92 | def __init__(self, in_channels, out_channels, activation='lrelu', stride=1):
93 | super().__init__()
94 | self.ActivationLayer = nn.LeakyReLU(inplace=True)
95 | if activation == 'relu':
96 | self.ActivationLayer = nn.ReLU(inplace=True)
97 | self.conv_relu = nn.Sequential(
98 | nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1),
99 | self.ActivationLayer,
100 | )
101 |
102 | def forward(self, x):
103 | return self.conv_relu(x)
104 |
105 |
106 | class ConvTranspose2D(nn.Module):
107 | def __init__(self, in_channels, out_channels, activation='lrelu'):
108 | super().__init__()
109 | self.ActivationLayer = nn.LeakyReLU(inplace=True)
110 | if activation == 'relu':
111 | self.ActivationLayer = nn.ReLU(inplace=True)
112 | self.deconv_relu = nn.Sequential(
113 | nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2, padding=0),
114 | self.ActivationLayer,
115 | )
116 |
117 | def forward(self, x):
118 | return self.deconv_relu(x)
119 |
120 |
121 | class MaxPooling2D(nn.Module):
122 | def __init__(self, kernel_size=2, stride=2):
123 | super().__init__()
124 | self.maxpool = nn.MaxPool2d(kernel_size=kernel_size, stride=stride)
125 |
126 | def forward(self, x):
127 | return self.maxpool(x)
128 |
129 |
130 | class AvgPooling2D(nn.Module):
131 | def __init__(self, kernel_size=2, stride=2):
132 | super().__init__()
133 |         self.avgpool = nn.AvgPool2d(kernel_size=kernel_size, stride=stride)
134 |
135 | def forward(self, x):
136 | return self.avgpool(x)
137 |
138 |
139 | class ChannelAttention(nn.Module):
140 | def __init__(self, in_planes, ratio=16):
141 | super().__init__()
142 | self.avg_pool = nn.AdaptiveAvgPool2d(1)
143 | self.max_pool = nn.AdaptiveMaxPool2d(1)
144 |
145 | self.sharedMLP = nn.Sequential(
146 | nn.Conv2d(in_planes, in_planes // ratio, 1, bias=False), nn.ReLU(),
147 | nn.Conv2d(in_planes // ratio, in_planes, 1, bias=False))
148 | self.sigmoid = nn.Sigmoid()
149 |
150 | def forward(self, x):
151 | avgout = self.sharedMLP(self.avg_pool(x))
152 | maxout = self.sharedMLP(self.max_pool(x))
153 | return self.sigmoid(avgout + maxout)
154 |
155 |
156 | class SpatialAttention(nn.Module):
157 | def __init__(self, kernel_size=3):
158 | super().__init__()
159 |         self.conv = nn.Conv2d(2, 1, kernel_size, padding=kernel_size//2, bias=False)
160 | self.sigmoid = nn.Sigmoid()
161 |
162 | def forward(self, x):
163 | avgout = torch.mean(x, dim=1, keepdim=True)
164 | maxout, _ = torch.max(x, dim=1, keepdim=True)
165 | x = torch.cat([avgout, maxout], dim=1)
166 | x = self.conv(x)
167 | return self.sigmoid(x)
168 |
169 |
170 | class CBAM(nn.Module):
171 | def __init__(self, planes):
172 | super().__init__()
173 | self.ca = ChannelAttention(planes)
174 | self.sa = SpatialAttention()
175 | def forward(self, x):
176 | x = self.ca(x) * x
177 |         out = self.sa(x) * x
178 |         return out
179 |
180 |
181 | class Concat(nn.Module):
182 | def forward(self, x, y):
183 | _, _, xh, xw = x.size()
184 | _, _, yh, yw = y.size()
185 | diffY = xh - yh
186 | diffX = xw - yw
187 | y = F.pad(y, (diffX // 2, diffX - diffX//2,
188 | diffY // 2, diffY - diffY//2))
189 | return torch.cat((x, y), dim=1)
--------------------------------------------------------------------------------
/base_parser.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | class BaseParser():
4 | def __init__(self):
5 | self.parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
6 |
7 | def parse(self):
8 | self.parser.add_argument("--mode", default="train", choices=["train", "test"])
9 | self.parser.add_argument("--config", default="./config.yaml", help="path to config")
10 |         self.parser.add_argument("--checkpoint", default=True, help="path to checkpoint to restore")
11 | return self.parser.parse_args()
12 |
--------------------------------------------------------------------------------
/base_trainer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import optim
3 | import os
4 | import sys
5 | import time
6 | import numpy as np
7 | import matplotlib.pyplot as plt
8 | import cv2
9 | from torchsummary import summary
10 | from utils import *  # provides sample() used in test()
11 | class BaseTrainer:
12 | def __init__(self, config, dataloader, criterion, model,
13 | dataloader_test=None, extra_model=None):
14 | self.initialize(config)
15 | self.dataloader = dataloader
16 | self.dataloader_test = dataloader_test
17 | self.loss_fn = criterion
18 | self.model = model
19 | self.extra_model = extra_model
20 | self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
21 | self.model.to(device=self.device)
22 | # faster convolutions, but more memory
23 |         if self.device.type == 'cuda':
24 |             torch.backends.cudnn.benchmark = True
25 |
26 | def initialize(self, config):
27 | self.batch_size = config['batch_size']
28 | self.length = config['length']
29 | self.epochs = config['epochs']
30 | self.steps_per_epoch = config['steps_per_epoch']
31 | self.print_frequency = config['print_frequency']
32 | self.save_frequency = config['save_frequency']
33 | self.weights_dir = config['weights_dir']
34 | self.samples_dir = config['samples_dir'] # './logs/samples'
35 | self.learning_rate = config['learning_rate']
36 | self.noDecom = config['noDecom']
37 |
38 | def train(self):
39 | print(f'Using device {self.device}')
40 | summary(self.model, input_size=(3, 48, 48))
41 |
42 | self.model.to(device=self.device)
43 |         # faster convolutions, but more memory
44 |         torch.backends.cudnn.benchmark = True
45 |
46 | optimizer = optim.Adam(self.model.parameters(), lr=self.learning_rate)
47 | try:
48 | for iter in range(self.epochs):
49 | epoch_loss = 0
50 | steps = 0
51 | iter_start_time = time.time()
52 | for idx, data in enumerate(self.dataloader):
53 | input_ = data['input']
54 | input_ = input_.to(self.device)
55 | target = data['target']
56 | target = target.to(self.device)
57 | y_pred = self.model(input_)
58 | loss = self.loss_fn(y_pred, target)
59 |                     print("iter:", idx, "loss:", loss.item())
60 | optimizer.zero_grad()
61 | loss.backward()
62 | optimizer.step()
63 |                     epoch_loss += loss.item(); steps += 1  # accumulate for the epoch average
64 | if idx > 0 and idx % self.save_frequency == 0:
65 | # torch.save(self.model.state_dict(), './checkpoints/g_net_{}.pth'.format(str(idx % 3)))
66 | print('Saved model.')
67 |                         self.test(iter)
68 |                 iter_end_time = time.time()
69 |                 print("End of epoch {}, time taken: {:.3f}s, average loss: {:.5f}".format(
70 |                     iter, iter_end_time - iter_start_time, epoch_loss / max(steps, 1)))
71 |
72 | except KeyboardInterrupt:
73 | torch.save(self.model.state_dict(), 'INTERRUPTED.pth')
74 | print('Saved interrupt')
75 | try:
76 | sys.exit(0)
77 | except SystemExit:
78 | os._exit(0)
79 |
80 | def test(self, epoch=-1, plot_dir='./images/samples-illum'):
81 | self.model.eval()
82 | for R_low_tensor, I_low_tensor, R_high_tensor, I_high_tensor, name in self.dataloader_test:
83 | I_low = I_low_tensor.to(self.device)
84 | I_high = I_high_tensor.to(self.device)
85 | with torch.no_grad():
86 | ratio_high2low = torch.mean(torch.div((I_low + 0.0001), (I_high + 0.0001)))
87 | ratio_low2high = torch.mean(torch.div((I_high + 0.0001), (I_low + 0.0001)))
88 | ratio_high2low_map = torch.ones_like(I_low) * ratio_high2low
89 | ratio_low2high_map = torch.ones_like(I_low) * ratio_low2high
90 |
91 | I_low2high_map = self.model(I_low, ratio_low2high_map)
92 | I_high2low_map = self.model(I_high, ratio_high2low_map)
93 |
94 | I_low2high_np = I_low2high_map.detach().cpu().numpy()[0]
95 | I_high2low_np = I_high2low_map.detach().cpu().numpy()[0]
96 | I_low_np = I_low_tensor.numpy()[0]
97 | I_high_np = I_high_tensor.numpy()[0]
98 | sample_imgs = np.concatenate( (I_low_np, I_high_np, I_high2low_np, I_low2high_np), axis=0 )
99 |             filepath = os.path.join(plot_dir, f'{name[0]}_epoch_{epoch}.png')
100 | split_point = [0, 1, 2, 3, 4]
101 | sample(sample_imgs, split=split_point, figure_size=(2, 2),
102 | img_dim=self.length, path=filepath, num=epoch)
--------------------------------------------------------------------------------
/checkpoints/decom_net.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fenghansen/KinD-pytorch/815d26f538221ba5b59799f5439d491579f6bd2d/checkpoints/decom_net.pth
--------------------------------------------------------------------------------
/checkpoints/illum_net.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fenghansen/KinD-pytorch/815d26f538221ba5b59799f5439d491579f6bd2d/checkpoints/illum_net.pth
--------------------------------------------------------------------------------
/checkpoints/restore_net.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fenghansen/KinD-pytorch/815d26f538221ba5b59799f5439d491579f6bd2d/checkpoints/restore_net.pth
--------------------------------------------------------------------------------
/config.yaml:
--------------------------------------------------------------------------------
1 | batch_size: 32
2 | length: 256
3 | epochs: 401
4 | steps_per_epoch: 128
5 |
6 | print_frequency: 10
7 | save_frequency: 10
8 |
9 | samples_dir: './samples'
10 | weights_dir: './checkpoints'
11 |
12 | learning_rate: 0.0004
13 | checkpoints: True
14 |
15 | noDecom: False
--------------------------------------------------------------------------------
/dataloader.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import random
4 | import matplotlib.pyplot as plt
5 | import collections
6 | import torch
7 | import torchvision
8 | import cv2
9 | import shutil
10 | import time
11 | from PIL import Image
12 | import torchvision.transforms as transforms
13 | from torch.utils.data import Dataset, DataLoader
14 | from utils import *
15 |
16 |
17 | class CustomDataset(Dataset):
18 | def __init__(self, datapath):
19 | super().__init__()
20 | self.datapath = datapath
21 | self.img_path = [os.path.join(datapath, f) for f in os.listdir(datapath) if
22 | any(filetype in f.lower() for filetype in ['jpeg', 'png', 'jpg', 'bmp'])]
23 | self.name = [f.split(".")[0] for f in os.listdir(datapath) if any(filetype in
24 | f.lower() for filetype in ['jpeg', 'png', 'jpg', 'bmp'])]
25 |
26 | def __len__(self):
27 | return len(self.img_path)
28 |
29 | def __getitem__(self, idx):
30 | datafiles = self.img_path[idx]
31 | img = Image.open(datafiles).convert('RGB')
32 | img = np.asarray(img, np.float32).transpose((2,0,1)) / 255.
33 | return img, self.name[idx]
34 |
35 |
36 | class LOLDataset(Dataset):
37 | def __init__(self, root, list_path, crop_size=256, to_RAM=False, training=True):
38 | super(LOLDataset,self).__init__()
39 | self.training = training
40 | self.to_RAM = to_RAM
41 | self.root = root
42 | self.list_path = list_path
43 | self.crop_size = crop_size
44 | with open(list_path) as f:
45 | self.pairs = f.readlines()
46 | self.files = []
47 | for pair in self.pairs:
48 | lr_path, hr_path = pair.split(",")
49 | hr_path = hr_path[:-1]
50 | name = lr_path.split("\\")[-1][:-4]
51 | lr_file = os.path.join(self.root, lr_path)
52 | hr_file = os.path.join(self.root, hr_path)
53 | self.files.append({
54 | "lr": lr_file,
55 | "hr": hr_file,
56 | "name": name
57 | })
58 | self.data = []
59 | if self.to_RAM:
60 | for i, fileinfo in enumerate(self.files):
61 | name = fileinfo["name"]
62 | lr_img = Image.open(fileinfo["lr"])
63 | hr_img = Image.open(fileinfo["hr"])
64 | self.data.append({
65 | "lr": lr_img,
66 | "hr": hr_img,
67 | "name": name
68 | })
69 |             log("Finished loading all images to RAM...")
70 |
71 | def __len__(self):
72 | return len(self.files)
73 |
74 | def __getitem__(self, idx):
75 | datafiles = self.files[idx]
76 |
77 |         '''load the data'''
78 | if not self.to_RAM:
79 | name = datafiles["name"]
80 | lr_img = Image.open(datafiles["lr"])
81 | hr_img = Image.open(datafiles["hr"])
82 | else:
83 | name = self.data[idx]["name"]
84 | lr_img = self.data[idx]["lr"]
85 | hr_img = self.data[idx]["hr"]
86 |
87 | '''random crop the inputs'''
88 | if self.crop_size > 0:
89 |
90 |             # select a random start point for the cropping operation
91 | h_offset = random.randint(0, lr_img.size[1] - self.crop_size)
92 | w_offset = random.randint(0, lr_img.size[0] - self.crop_size)
93 | #crop the image and the label
94 | crop_box = (w_offset, h_offset, w_offset+self.crop_size, h_offset+self.crop_size)
95 | lr_crop = lr_img
96 | hr_crop = hr_img
97 | if self.training is True:
98 | lr_crop = lr_img.crop(crop_box)
99 | hr_crop = hr_img.crop(crop_box)
100 | rand_mode = np.random.randint(0, 7)
101 | lr_crop = data_augmentation(lr_crop, rand_mode)
102 | hr_crop = data_augmentation(hr_crop, rand_mode)
103 |
104 |
105 | '''convert PIL Image to numpy array'''
106 | lr_crop = np.asarray(lr_crop, np.float32).transpose((2,0,1)) / 255.
107 | hr_crop = np.asarray(hr_crop, np.float32).transpose((2,0,1)) / 255.
108 | return lr_crop, hr_crop, name
109 |
110 |
111 | class LOLDataset_Decom(Dataset):
112 | def __init__(self, root, list_path,
113 | crop_size=256, to_RAM=False, training=True):
114 | super().__init__()
115 | self.training = training
116 | self.to_RAM = to_RAM
117 | self.root = root
118 | self.list_path = list_path
119 | self.crop_size = crop_size
120 | with open(list_path) as f:
121 | self.pairs = f.readlines()
122 | self.files = []
123 | for pair in self.pairs:
124 | lr_path_R, lr_path_I, hr_path_R, hr_path_I = pair.split(",")
125 | hr_path_I = hr_path_I[:-1]
126 | name = lr_path_R.split("\\")[-1][:-4]
127 | lr_file_R = os.path.join(self.root, lr_path_R)
128 | lr_file_I = os.path.join(self.root, lr_path_I)
129 | hr_file_R = os.path.join(self.root, hr_path_R)
130 | hr_file_I = os.path.join(self.root, hr_path_I)
131 | self.files.append({
132 | "lr_R": lr_file_R,
133 | "lr_I": lr_file_I,
134 | "hr_R": hr_file_R,
135 | "hr_I": hr_file_I,
136 | "name": name
137 | })
138 | self.data = []
139 | if self.to_RAM:
140 | for i, fileinfo in enumerate(self.files):
141 | name = fileinfo["name"]
142 | lr_img_R = Image.open(fileinfo["lr_R"])
143 | hr_img_R = Image.open(fileinfo["hr_R"])
144 | lr_img_I = Image.open(fileinfo["lr_I"]).convert('L')
145 | hr_img_I = Image.open(fileinfo["hr_I"]).convert('L')
146 | self.data.append({
147 | "lr_R": lr_img_R,
148 | "lr_I": lr_img_I,
149 | "hr_R": hr_img_R,
150 | "hr_I": hr_img_I,
151 | "name": name
152 | })
153 |             log("Finished loading all images to RAM...")
154 |
155 | def __len__(self):
156 | return len(self.files)
157 |
158 | def __getitem__(self, idx):
159 | datafiles = self.files[idx]
160 |
161 |         '''load the data'''
162 | if not self.to_RAM:
163 | name = datafiles["name"]
164 | lr_img_R = Image.open(datafiles["lr_R"])
165 | hr_img_R = Image.open(datafiles["hr_R"])
166 | lr_img_I = Image.open(datafiles["lr_I"]).convert('L')
167 | hr_img_I = Image.open(datafiles["hr_I"]).convert('L')
168 | else:
169 | name = self.data[idx]["name"]
170 | lr_img_R = self.data[idx]["lr_R"]
171 | lr_img_I = self.data[idx]["lr_I"]
172 | hr_img_R = self.data[idx]["hr_R"]
173 | hr_img_I = self.data[idx]["hr_I"]
174 |
175 |
176 | '''random crop the inputs'''
177 | if self.crop_size > 0:
178 |
179 |             # select a random start point for the cropping operation
180 | h_offset = random.randint(0, lr_img_R.size[1] - self.crop_size)
181 | w_offset = random.randint(0, lr_img_R.size[0] - self.crop_size)
182 | #crop the image and the label
183 | crop_box = (w_offset, h_offset, w_offset+self.crop_size, h_offset+self.crop_size)
184 | lr_crop_R = lr_img_R
185 | lr_crop_I = lr_img_I
186 | hr_crop_R = hr_img_R
187 | hr_crop_I = hr_img_I
188 | if self.training is True:
189 | lr_crop_R = lr_crop_R.crop(crop_box)
190 | lr_crop_I = lr_crop_I.crop(crop_box)
191 | hr_crop_R = hr_crop_R.crop(crop_box)
192 | hr_crop_I = hr_crop_I.crop(crop_box)
193 | rand_mode = np.random.randint(0, 7)
194 | lr_crop_R = data_augmentation(lr_crop_R, rand_mode)
195 | lr_crop_I = data_augmentation(lr_crop_I, rand_mode)
196 | hr_crop_R = data_augmentation(hr_crop_R, rand_mode)
197 | hr_crop_I = data_augmentation(hr_crop_I, rand_mode)
198 |
199 |
200 | '''convert PIL Image to numpy array'''
201 | lr_crop_R = np.asarray(lr_crop_R, np.float32).transpose((2,0,1)) / 255.
202 | lr_crop_I = np.expand_dims(np.asarray(lr_crop_I, np.float32) , axis=0) / 255.
203 | hr_crop_R = np.asarray(hr_crop_R, np.float32).transpose((2,0,1)) / 255.
204 | hr_crop_I = np.expand_dims(np.asarray(hr_crop_I, np.float32) , axis=0) / 255.
205 | return lr_crop_R, lr_crop_I, hr_crop_R, hr_crop_I, name
206 |
207 |
208 | def build_LOLDataset_list_txt(dst_dir):
209 |     log(f"Building LOLDataset list text at {dst_dir}")
210 | lr_dir = os.path.join(dst_dir, 'low')
211 | hr_dir = os.path.join(dst_dir, 'high')
212 | img_lr_path = [os.path.join('low', name) for name in os.listdir(lr_dir)]
213 | img_hr_path = [os.path.join('high', name) for name in os.listdir(hr_dir)]
214 | list_path = os.path.join(dst_dir, 'pair_list.csv')
215 | with open(list_path, 'w') as f:
216 | for lr_path, hr_path in zip(img_lr_path, img_hr_path):
217 | f.write(f"{lr_path},{hr_path}\n")
218 |     log(f"Finished... There are {len(img_lr_path)} pairs...")
219 | return list_path
220 |
221 |
222 | def build_LOLDataset_Decom_list_txt(dst_dir):
223 |     log(f"Building LOLDataset Decom list text at {dst_dir}")
224 | dir_lists = []
225 | tail = ['low\\R', 'low\\I', 'high\\R', 'high\\I']
226 | for t in tail:
227 | dir_lists.append(os.path.join(dst_dir, t))
228 | imgs_path = [[],[],[],[]]
229 | for i, direction in enumerate(dir_lists):
230 | for name in os.listdir(direction):
231 | path = os.path.join(tail[i], name)
232 | imgs_path[i].append(path)
233 | list_path = os.path.join(dst_dir, 'pair_list.csv')
234 | with open(list_path, 'w') as f:
235 | for lr_R, lr_I, hr_R, hr_I in zip(*imgs_path):
236 | f.write(f"{lr_R},{lr_I},{hr_R},{hr_I}\n")
237 |     log(f"Finished... There are {len(imgs_path[0])} pairs...")
238 | return list_path
239 |
240 |
241 | def divide_dataset(dst_dir):
242 | lr_dir_R = os.path.join(dst_dir, 'low/R')
243 | lr_dir_I = os.path.join(dst_dir, 'low/I')
244 | hr_dir_R = os.path.join(dst_dir, 'high/R')
245 | hr_dir_I = os.path.join(dst_dir, 'high/I')
246 | for name in os.listdir(dst_dir):
247 | path = os.path.join(dst_dir, name)
248 | name = name[:-4]
249 | item = name.split("_")
250 | if item[0] == 'high' and item[-1] == 'R':
251 | shutil.move(path, os.path.join(hr_dir_R, item[1]+".png"))
252 | if item[0] == 'high' and item[-1] == 'I':
253 | shutil.move(path, os.path.join(hr_dir_I, item[1]+".png"))
254 | if item[0] == 'low' and item[-1] == 'R':
255 | shutil.move(path, os.path.join(lr_dir_R, item[1]+".png"))
256 | if item[0] == 'low' and item[-1] == 'I':
257 | shutil.move(path, os.path.join(lr_dir_I, item[1]+".png"))
258 |     log("Finished...")
259 |
260 |
261 | def change_name(dst_dir):
262 | dir_lists = []
263 | dir_lists.append(os.path.join(dst_dir, 'low\\R'))
264 | dir_lists.append(os.path.join(dst_dir, 'low\\I'))
265 | dir_lists.append(os.path.join(dst_dir, 'high\\R'))
266 | dir_lists.append(os.path.join(dst_dir, 'high\\I'))
267 | for direction in dir_lists:
268 | for name in os.listdir(direction):
269 | path = os.path.join(direction, name)
270 | name = name[:-4]
271 | item = name.split("_")
272 | os.rename(path, os.path.join(direction, item[1]+".png"))
273 |     log("Finished...")
274 |
275 |
276 | if __name__ == '__main__':
277 | # noDecom Dataloader Test
278 | # root_path_train = r'H:\datasets\Low-Light Dataset\LOLdataset_decom\our485'
279 | # root_path_test = r'H:\datasets\Low-Light Dataset\LOLdataset_decom\eval15'
280 | # list_path_train = build_LOLDataset_Decom_list_txt(root_path_train)
281 | # list_path_test = build_LOLDataset_Decom_list_txt(root_path_test)
282 | # Batch_size = 2
283 | # log("Buliding LOL Dataset...")
284 | # dst_train = LOLDataset_Decom(root_path_train, list_path_train, crop_size=128, to_RAM=True)
285 | # dst_test = LOLDataset_Decom(root_path_test, list_path_test, crop_size=128, to_RAM=False)
286 | # # But when we are training a model, the mean should have another value
287 | # trainloader = DataLoader(dst_train, batch_size = Batch_size)
288 | # testloader = DataLoader(dst_test, batch_size=1)
289 | # plt.ion()
290 | # for i, data in enumerate(trainloader):
291 | # _, _, _, imgs, name = data
292 | # log(name)
293 | # img = imgs[0].numpy()
294 | # sample(imgs[0], figure_size=(1, 1), img_dim=128)
295 |
296 | # Dataloader Test
297 | root_path_train = r'H:\datasets\Low-Light Dataset\KinD++\LOLdataset\our485'
298 | root_path_test = r'H:\datasets\Low-Light Dataset\KinD++\LOLdataset\eval15'
299 | # root_path_train = r'H:\datasets\Low-Light Dataset\KinD++\LOLdataset\our485'
300 | # root_path_test = r'H:\datasets\Low-Light Dataset\KinD++\LOLdataset\eval15'
301 | list_path_train = build_LOLDataset_list_txt(root_path_train)
302 | list_path_test = build_LOLDataset_list_txt(root_path_test)
303 | Batch_size = 2
304 |     log("Building LOL Dataset...")
305 | dst_train = LOLDataset(root_path_train, list_path_train, crop_size=128, to_RAM=False)
306 | dst_test = LOLDataset(root_path_test, list_path_test, crop_size=128, to_RAM=False)
307 | # But when we are training a model, the mean should have another value
308 | trainloader = DataLoader(dst_train, batch_size = Batch_size)
309 | testloader = DataLoader(dst_test, batch_size=1)
310 | plt.ion()
311 | for i, data in enumerate(trainloader):
312 | _, imgs, name = data
313 | img = imgs[0].numpy()
314 | sample(imgs[0], figure_size=(1, 1), img_dim=128)
315 |
--------------------------------------------------------------------------------
/decom_trainer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from torch.optim import lr_scheduler
5 | import numpy as np
6 | import time
7 | import yaml
8 | import sys
9 | from tqdm import tqdm
10 | from torchvision.utils import make_grid
11 | from torchvision import transforms
12 | from torchsummary import summary
13 | from base_trainer import BaseTrainer
14 | from losses import Decom_Loss
15 | from models import DecomNet
16 | from base_parser import BaseParser
17 | from dataloader import *
18 |
19 | class Decom_Trainer(BaseTrainer):
20 | def train(self):
21 | print(f'Using device {self.device}')
22 | self.model.to(device=self.device)
23 | summary(self.model, input_size=(3, 48, 48))
24 | # faster convolutions, but more memory
25 | # cudnn.benchmark = True
26 |
27 | optimizer = torch.optim.Adam(self.model.parameters(), lr=self.learning_rate)
28 | scheduler = lr_scheduler.ExponentialLR(optimizer, gamma=0.997)
29 | try:
30 | for iter in range(self.epochs):
31 | epoch_loss = 0
32 | idx = 0
33 | hook_number = -1
34 | iter_start_time = time.time()
35 | # with tqdm(total=self.steps_per_epoch) as pbar:
36 | for L_low_tensor, L_high_tensor, name in self.dataloader:
37 | L_low = L_low_tensor.to(self.device)
38 | L_high = L_high_tensor.to(self.device)
39 | R_low, I_low = self.model(L_low)
40 | R_high, I_high = self.model(L_high)
41 | if idx % self.print_frequency == 0:
42 | hook_number = -1
43 | loss = self.loss_fn(R_low, R_high, I_low, I_high, L_low, L_high, hook=hook_number)
44 | hook_number = -1
45 | if idx % 8 == 0:
46 | print(f"iter: {iter}_{idx}\taverage_loss: {loss.item():.6f}")
47 | optimizer.zero_grad()
48 | loss.backward()
49 | optimizer.step()
50 | idx += 1
51 | # pbar.update(1)
52 | # pbar.set_postfix({'loss':loss.item()})
53 |
54 | if iter % self.print_frequency == 0:
55 | self.test(iter, plot_dir='./images/samples-decom')
56 |
57 | if iter % self.save_frequency == 0:
58 |                     torch.save(self.model.state_dict(), './weights/decom_net_test3.pth')
59 |                     log("Weights saved as 'decom_net_test3.pth'")
60 |
61 | scheduler.step()
62 | iter_end_time = time.time()
63 | log(f"Time taken: {iter_end_time - iter_start_time:.3f} seconds\t lr={scheduler.get_lr()[0]:.6f}")
64 |
65 | except KeyboardInterrupt:
66 | torch.save(self.model.state_dict(), 'INTERRUPTED_decom.pth')
67 | log('Saved interrupt_decom')
68 | try:
69 | sys.exit(0)
70 | except SystemExit:
71 | os._exit(0)
72 |
73 | @no_grad
74 | def test(self, epoch=-1, plot_dir='./images/samples-decom'):
75 | self.model.eval()
76 | hook = 0
77 | for L_low_tensor, L_high_tensor, name in self.dataloader_test:
78 | L_low = L_low_tensor.to(self.device)
79 | L_high = L_high_tensor.to(self.device)
80 | R_low, I_low = self.model(L_low)
81 | R_high, I_high = self.model(L_high)
82 |
83 | if epoch % (self.print_frequency*10) == 0:
84 | loss = self.loss_fn(R_low, R_high, I_low, I_high, L_low, L_high, hook=hook)
85 | hook += 1
86 | loss = 0
87 |
88 | R_low_np = R_low.detach().cpu().numpy()[0]
89 | R_high_np = R_high.detach().cpu().numpy()[0]
90 | I_low_np = I_low.detach().cpu().numpy()[0]
91 | I_high_np = I_high.detach().cpu().numpy()[0]
92 | L_low_np = L_low_tensor.numpy()[0]
93 | L_high_np = L_high_tensor.numpy()[0]
94 | sample_imgs = np.concatenate( (R_low_np, I_low_np, L_low_np,
95 | R_high_np, I_high_np, L_high_np), axis=0 )
96 | filepath = os.path.join(plot_dir, f'{name[0]}_epoch_{epoch}.png')
97 | split_point = [0, 3, 4, 7, 10, 11, 14]
98 | img_dim = I_low_np.shape[1:]
99 | sample(sample_imgs, split=split_point, figure_size=(2, 3),
100 | img_dim=img_dim, path=filepath, num=epoch)
101 |
102 |
103 | if __name__ == "__main__":
104 | criterion = Decom_Loss()
105 | model = DecomNet()
106 |
107 | parser = BaseParser()
108 | args = parser.parse()
109 | args.checkpoint = True
110 | if args.checkpoint is not None:
111 | pretrain = torch.load('./weights/decom_net.pth')
112 | model.load_state_dict(pretrain)
113 | print('Model loaded from decom_net.pth')
114 |
115 | with open(args.config) as f:
116 |         config = yaml.safe_load(f)
117 |
118 | root_path_train = r'H:\datasets\Low-Light Dataset\KinD++\LOLdataset\our485'
119 | root_path_test = r'C:\DeepLearning\KinD_plus-master\LOLdataset\eval15'
120 | list_path_train = build_LOLDataset_list_txt(root_path_train)
121 | list_path_test = build_LOLDataset_list_txt(root_path_test)
122 | # list_path_train = os.path.join(root_path_train, 'pair_list.csv')
123 | # list_path_test = os.path.join(root_path_test, 'pair_list.csv')
124 |
125 |     log("Building LOL Dataset...")
126 | # transform = transforms.Compose([transforms.ToTensor(),])
127 | dst_train = LOLDataset(root_path_train, list_path_train,
128 | crop_size=config['length'], to_RAM=True)
129 | dst_test = LOLDataset(root_path_test, list_path_test,
130 | crop_size=config['length'], to_RAM=True, training=False)
131 |
132 | train_loader = DataLoader(dst_train, batch_size = config['batch_size'], shuffle=True)
133 | test_loader = DataLoader(dst_test, batch_size=1)
134 |
135 | # if args.noDecom is True:
136 | # root_path_valid = r'H:\datasets\Low-Light Dataset\LOLdataset_decom\our485'
137 | # list_path_valid = os.path.join(root_path_test, 'pair_list.csv')
138 |
139 | # log("Buliding LOL Dataset (noDecom)...")
140 | # # transform = transforms.Compose([transforms.ToTensor()])
141 | # dst_valid = LOLDataset_Decom(root_path_test, list_path_test,
142 | # crop_size=config['length'], to_RAM=True, training=False)
143 | # valid_loader = DataLoader(dst_train, batch_size = config['batch_size'], shuffle=True)
144 |
145 | trainer = Decom_Trainer(config, train_loader, criterion, model, dataloader_test=test_loader)
146 | # --config ./config/config.yaml
147 | if args.mode == 'train':
148 | trainer.train()
149 | else:
150 | trainer.test()
--------------------------------------------------------------------------------
/evaluate.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import numpy as np
5 | import time
6 | import yaml
7 | import sys
8 | from tqdm import tqdm
9 | from torch.optim import lr_scheduler
10 | from torchvision.utils import make_grid
11 | from torchvision import transforms
12 | from torchsummary import summary
13 | from base_trainer import BaseTrainer
14 | from losses import *
15 | from models import *
16 | from base_parser import BaseParser
17 | from dataloader import *
18 |
19 | class KinD_noDecom_Trainer(BaseTrainer):
20 | @no_grad
21 | def test(self, epoch=-1, plot_dir='./images/samples-KinD'):
22 | self.model.eval()
23 | self.model.to(device=self.device)
24 |         if 'decom_net' in self.model._modules:
25 | for L_low_tensor, L_high_tensor, name in self.dataloader_test:
26 | L_low = L_low_tensor.to(self.device)
27 | L_high = L_high_tensor.to(self.device)
28 |
29 | R_low, I_low = self.model.decom_net(L_low)
30 | R_high, I_high = self.model.decom_net(L_high)
31 | I_low_3 = torch.cat([I_low, I_low, I_low], dim=1)
32 | I_high_3 = torch.cat([I_high, I_high, I_high], dim=1)
33 |
34 | output_low = I_low_3 * R_low
35 | output_high = I_high_3 * R_high
36 |
37 |                 b, w = 0.7, 0.5  # target brightness offset and low-light weight
38 | bright_low = torch.mean(I_low)
39 | # bright_high = torch.mean(I_high)
40 | bright_high = torch.ones_like(bright_low) * b + bright_low * w
41 | ratio = torch.div(bright_high, bright_low)
42 |                 log(f"Brightness: {bright_high}\tIllumination Magnification: {ratio.item()}")
43 | # ratio_map = torch.ones_like(I_low) * ratio
44 |
45 | R_final, I_final, output_final = self.model(L_low, ratio)
46 |
47 | R_final_np = R_final.detach().cpu().numpy()[0]
48 | I_final_np = I_final.detach().cpu().numpy()[0]
49 | R_low_np = R_low.detach().cpu().numpy()[0]
50 | I_low_np = I_low.detach().cpu().numpy()[0]
51 | R_high_np = R_high.detach().cpu().numpy()[0]
52 | I_high_np = I_high.detach().cpu().numpy()[0]
53 | output_final_np = output_final.detach().cpu().numpy()[0]
54 | output_low_np = output_low.detach().cpu().numpy()[0]
55 | output_high_np = output_high.detach().cpu().numpy()[0]
56 | # ratio_map_np = ratio_map.detach().cpu().numpy()[0]
57 | L_low_np = L_low_tensor.numpy()[0]
58 | L_high_np = L_high_tensor.numpy()[0]
59 |
60 | sample_imgs = np.concatenate( (R_low_np, I_low_np, output_low_np, L_low_np,
61 | R_high_np, I_high_np, output_high_np, L_high_np,
62 | R_final_np, I_final_np, output_final_np, L_high_np), axis=0 )
63 |
64 | filepath = os.path.join(plot_dir, f'{name[0]}_epoch_{epoch}.png')
65 | split_point = [0, 3, 4, 7, 10, 13, 14, 17, 20, 23, 24, 27, 30]
66 | img_dim = I_high_np.shape[1:]
67 | sample(sample_imgs, split=split_point, figure_size=(3, 4),
68 | img_dim=img_dim, path=filepath, num=epoch)
69 | else:
70 | for R_low_tensor, I_low_tensor, R_high_tensor, I_high_tensor, name in self.dataloader_test:
71 | R_low = R_low_tensor.to(self.device)
72 | R_high = R_high_tensor.to(self.device)
73 | I_low = I_low_tensor.to(self.device)
74 | I_high = I_high_tensor.to(self.device)
75 | I_high_3 = torch.cat([I_high, I_high, I_high], dim=1)
76 | output_high = I_high_3 * R_high
77 |
78 | # while True:
79 |                 # b = float(input('Enter enhancement level: '))
80 | # if b <= 0: break
81 |                 b, w = 0.6, 0.5
82 | bright_low = torch.mean(I_low)
83 | bright_high = torch.ones_like(bright_low) * b + bright_low * w
84 | ratio = torch.div(bright_high, bright_low)
85 | print(bright_high, ratio)
86 | # ratio_map = torch.ones_like(I_low) * ratio
87 |
88 | R_final, I_final, output_final = self.model(R_low, I_low, ratio)
89 |
90 | R_final_np = R_final.detach().cpu().numpy()[0]
91 | I_final_np = I_final.detach().cpu().numpy()[0]
92 | output_final_np = output_final.detach().cpu().numpy()[0]
93 | output_high_np = output_high.detach().cpu().numpy()[0]
94 | # ratio_map_np = ratio_map.detach().cpu().numpy()[0]
95 | I_high_np = I_high_tensor.numpy()[0]
96 | R_high_np = R_high_tensor.numpy()[0]
97 |
98 | sample_imgs = np.concatenate( (R_high_np, I_high_np, output_high_np,
99 | R_final_np, I_final_np, output_final_np), axis=0 )
100 |
101 | filepath = os.path.join(plot_dir, f'{name[0]}_epoch_{epoch}.png')
102 | split_point = [0, 3, 4, 7, 10, 11, 14]
103 | img_dim = I_high_np.shape[1:]
104 | sample(sample_imgs, split=split_point, figure_size=(2, 3),
105 | img_dim=img_dim, path=filepath, num=epoch)
106 |
107 |
108 | if __name__ == "__main__":
109 | criterion = None
110 | parser = BaseParser()
111 | args = parser.parse()
112 | # args.noDecom = True
113 | with open(args.config) as f:
114 |         config = yaml.safe_load(f)
115 | if config['noDecom'] is True:
116 | model = KinD_noDecom()
117 | else:
118 | model = KinD()
119 |
120 | if args.checkpoint is not None:
121 | if config['noDecom'] is False:
122 | pretrain_decom = torch.load('./weights/decom_net.pth')
123 | model.decom_net.load_state_dict(pretrain_decom)
124 | log('Model loaded from decom_net.pth')
125 |         pretrain_restore = torch.load('./weights/restore_net.pth')
126 |         model.restore_net.load_state_dict(pretrain_restore)
127 | log('Model loaded from restore_net.pth')
128 | pretrain_illum = torch.load('./weights/illum_net.pth')
129 | model.illum_net.load_state_dict(pretrain_illum)
130 | log('Model loaded from illum_net.pth')
131 |
132 | if config['noDecom'] is True:
133 | root_path_test = r'H:\datasets\Low-Light Dataset\LOLdataset_decom\eval15'
134 | list_path_test = os.path.join(root_path_test, 'pair_list.csv')
135 |
136 |         log("Building LOL Dataset (noDecom)...")
137 | # transform = transforms.Compose([transforms.ToTensor()])
138 | dst_test = LOLDataset_Decom(root_path_test, list_path_test,
139 | crop_size=config['length'], to_RAM=True, training=False)
140 | else:
141 | root_path_test = r'C:\DeepLearning\KinD_plus-master\LOLdataset\eval15'
142 | list_path_test = os.path.join(root_path_test, 'pair_list.csv')
143 |
144 |         log("Building LOL Dataset...")
145 | # transform = transforms.Compose([transforms.ToTensor()])
146 | dst_test = LOLDataset(root_path_test, list_path_test,
147 | crop_size=config['length'], to_RAM=True, training=False)
148 |
149 | test_loader = DataLoader(dst_test, batch_size=1)
150 |
151 |     tester = KinD_noDecom_Trainer(config, None, criterion, model, dataloader_test=test_loader)
152 |
153 |     # Please change your output directory here
154 | output_dir = './images/samples-KinD'
155 |     tester.test(plot_dir=output_dir)
--------------------------------------------------------------------------------
/figures/778_epoch_-1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fenghansen/KinD-pytorch/815d26f538221ba5b59799f5439d491579f6bd2d/figures/778_epoch_-1.png
--------------------------------------------------------------------------------
/figures/failure_case_decom.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fenghansen/KinD-pytorch/815d26f538221ba5b59799f5439d491579f6bd2d/figures/failure_case_decom.png
--------------------------------------------------------------------------------
/figures/official_decom_train.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fenghansen/KinD-pytorch/815d26f538221ba5b59799f5439d491579f6bd2d/figures/official_decom_train.png
--------------------------------------------------------------------------------
/illum_trainer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from torch.optim import lr_scheduler
5 | import numpy as np
6 | import time
7 | import yaml
8 | import sys
9 | from tqdm import tqdm
10 | from torchvision.utils import make_grid
11 | from torchvision import transforms
12 | from torchsummary import summary
13 | from base_trainer import BaseTrainer
14 | from losses import *
15 | from models import *
16 | from base_parser import BaseParser
17 | from dataloader import *
18 |
19 | class Illum_Trainer(BaseTrainer):
20 | def __init__(self, config, dataloader, criterion, model,
21 | dataloader_test=None, decom_net=None):
22 | super().__init__(config, dataloader, criterion, model, dataloader_test)
23 | log(f'Using device {self.device}')
24 | self.decom_net = decom_net
25 | self.decom_net.to(device=self.device)
26 |
27 | def train(self):
28 | self.model.train()
29 | log(f'Using device {self.device}')
30 | self.model.to(device=self.device)
31 | print(self.model)
32 | # summary(self.model, input_size=[(1, 384, 384), (1,)], batch_size=4)
33 |
34 | optimizer = torch.optim.Adam(self.model.parameters(), lr=self.learning_rate)
35 | scheduler = lr_scheduler.ExponentialLR(optimizer, gamma=0.99426)
36 | try:
37 | for iter in range(self.epochs):
38 | epoch_loss = 0
39 | idx = 0
40 | hook_number = -1
41 | iter_start_time = time.time()
42 | if self.noDecom is True:
43 | # with tqdm(total=self.steps_per_epoch) as pbar:
44 | for R_low_tensor, I_low_tensor, R_high_tensor, I_high_tensor, name in self.dataloader:
45 | optimizer.zero_grad()
46 | I_low = I_low_tensor.to(self.device)
47 | I_high = I_high_tensor.to(self.device)
48 | with torch.no_grad():
49 | ratio_high2low = torch.mean(torch.div((I_low + 0.0001), (I_high + 0.0001)))
50 | ratio_low2high = torch.mean(torch.div((I_high + 0.0001), (I_low + 0.0001)))
51 |
52 | I_low2high_map = self.model(I_low, ratio_low2high)
53 | I_high2low_map = self.model(I_high, ratio_high2low)
54 |
55 | if idx % self.print_frequency == 0:
56 | hook_number = iter
57 | loss = self.loss_fn(I_low2high_map, I_high, hook=hook_number) + self.loss_fn(I_high2low_map, I_low, hook=hook_number)
58 | hook_number = -1
59 | if idx % 30 == 0:
60 | log(f"iter: {iter}_{idx}\taverage_loss: {loss.item():.6f}")
61 | print(ratio_high2low, ratio_low2high)
62 | loss.backward()
63 | optimizer.step()
64 | idx += 1
65 | # pbar.update(1)
66 | # pbar.set_postfix({'loss':loss.item()})
67 | else:
68 | # with tqdm(total=self.steps_per_epoch) as pbar:
69 | for L_low_tensor, L_high_tensor, name in self.dataloader:
70 | optimizer.zero_grad()
71 | L_low = L_low_tensor.to(self.device)
72 | L_high = L_high_tensor.to(self.device)
73 |
74 | with torch.no_grad():
75 | R_low, I_low = self.decom_net(L_low)
76 | R_high, I_high = self.decom_net(L_high)
77 | # ratio_high2low = torch.mean(torch.div((I_low + 0.0001), (I_high + 0.0001)))
78 | # ratio_low2high = torch.mean(torch.div((I_high + 0.0001), (I_low + 0.0001)))
79 | bright_low = torch.mean(I_low)
80 | bright_high = torch.mean(I_high)
81 | ratio_high2low = torch.div(bright_low, bright_high)
82 | ratio_low2high = torch.div(bright_high, bright_low)
83 |
84 | I_low2high_map = self.model(I_low, ratio_low2high)
85 | I_high2low_map = self.model(I_high, ratio_high2low)
86 |
87 | loss = self.loss_fn(I_low2high_map, I_high, hook=hook_number) + \
88 | self.loss_fn(I_high2low_map, I_low, hook=hook_number)
89 |
90 | if idx % 30 == 0:
91 | log(f"iter: {iter}_{idx}\taverage_loss: {loss.item():.6f}")
92 | print(ratio_high2low, ratio_low2high)
93 | loss.backward()
94 | optimizer.step()
95 | idx += 1
96 | # pbar.update(1)
97 | # pbar.set_postfix({'loss':loss.item()})
98 |
99 | if iter % self.print_frequency == 0:
100 | self.test(iter, plot_dir='./images/samples-illum')
101 |
102 | if iter % self.save_frequency == 0:
103 | torch.save(self.model.state_dict(), './weights/illum_net.pth')
104 |                     log("Weights saved as 'illum_net.pth'")
105 |
106 | scheduler.step()
107 | iter_end_time = time.time()
108 | log(f"Time taken: {iter_end_time - iter_start_time:.3f} seconds\t lr={scheduler.get_lr()[0]:.6f}")
109 |
110 | except KeyboardInterrupt:
111 | torch.save(self.model.state_dict(), './weights/INTERRUPTED_illum.pth')
112 | log('Saved interrupt')
113 | try:
114 | sys.exit(0)
115 | except SystemExit:
116 | os._exit(0)
117 |
118 | @no_grad
119 | def test(self, epoch=-1, plot_dir='./images/samples-illum'):
120 | self.model.eval()
121 | if self.noDecom:
122 | for R_low_tensor, I_low_tensor, R_high_tensor, I_high_tensor, name in self.dataloader_test:
123 | I_low = I_low_tensor.to(self.device)
124 | I_high = I_high_tensor.to(self.device)
125 |
126 | ratio_high2low = torch.mean(torch.div((I_low + 0.0001), (I_high + 0.0001)))
127 | ratio_low2high = torch.mean(torch.div((I_high + 0.0001), (I_low + 0.0001)))
128 | print(ratio_low2high)
129 |                 # use a rough estimate of the target brightness level instead
130 | bright_low = torch.mean(I_low)
131 | bright_high = torch.ones_like(bright_low) * 0.3 + bright_low * 0.55
132 | ratio_high2low = torch.div(bright_low, bright_high)
133 | ratio_low2high = torch.div(bright_high, bright_low)
134 | print(ratio_low2high)
135 |
136 | I_low2high_map = self.model(I_low, ratio_low2high)
137 | I_high2low_map = self.model(I_high, ratio_high2low)
138 |
139 | I_low2high_np = I_low2high_map.detach().cpu().numpy()[0]
140 | I_high2low_np = I_high2low_map.detach().cpu().numpy()[0]
141 | I_low_np = I_low_tensor.numpy()[0]
142 | I_high_np = I_high_tensor.numpy()[0]
143 | sample_imgs = np.concatenate( (I_low_np, I_high_np, I_high2low_np, I_low2high_np), axis=0 )
144 |
145 | filepath = os.path.join(plot_dir, f'{name[0]}_epoch_{epoch}.png')
146 | split_point = [0, 1, 2, 3, 4]
147 | img_dim = I_low_np.shape[1:]
148 | sample(sample_imgs, split=split_point, figure_size=(2, 2),
149 | img_dim=img_dim, path=filepath, num=epoch)
150 | else:
151 | for L_low_tensor, L_high_tensor, name in self.dataloader_test:
152 | L_low = L_low_tensor.to(self.device)
153 | L_high = L_high_tensor.to(self.device)
154 |
155 | R_low, I_low = self.decom_net(L_low)
156 | R_high, I_high = self.decom_net(L_high)
157 | bright_low = torch.mean(I_low)
158 | bright_high = torch.mean(I_high)
159 | ratio_high2low = torch.div(bright_low, bright_high)
160 | ratio_low2high = torch.div(bright_high, bright_low)
161 | print(ratio_low2high)
162 |
163 | I_low2high_map = self.model(I_low, ratio_low2high)
164 | I_high2low_map = self.model(I_high, ratio_high2low)
165 |
166 | I_low2high_np = I_low2high_map.detach().cpu().numpy()[0]
167 | I_high2low_np = I_high2low_map.detach().cpu().numpy()[0]
168 | I_low_np = I_low.detach().cpu().numpy()[0]
169 | I_high_np = I_high.detach().cpu().numpy()[0]
170 | sample_imgs = np.concatenate( (I_low_np, I_high_np, I_high2low_np, I_low2high_np), axis=0 )
171 |
172 | filepath = os.path.join(plot_dir, f'{name[0]}_epoch_{epoch}.png')
173 | split_point = [0, 1, 2, 3, 4]
174 | img_dim = I_low_np.shape[1:]
175 | sample(sample_imgs, split=split_point, figure_size=(2, 2),
176 | img_dim=img_dim, path=filepath, num=epoch)
177 |
178 | if __name__ == "__main__":
179 | criterion = Illum_Loss()
180 | decom_net = DecomNet()
181 | model = IllumNet()
182 |
183 | parser = BaseParser()
184 | args = parser.parse()
185 |
186 | with open(args.config) as f:
187 |         config = yaml.safe_load(f)
188 |
189 | args.checkpoint = True
190 | if args.checkpoint is not None:
191 | if config['noDecom'] is False:
192 | decom_net = load_weights(decom_net, path='./weights/decom_net.pth')
193 | log('DecomNet loaded from decom_net.pth')
194 | model = load_weights(model, path='./weights/illum_net.pth')
195 | log('Model loaded from illum_net.pth')
196 |
197 | if config['noDecom'] is True:
198 | root_path_train = r'H:\datasets\Low-Light Dataset\LOLdataset_decom\our485'
199 | root_path_test = r'H:\datasets\Low-Light Dataset\LOLdataset_decom\eval15'
200 | list_path_train = build_LOLDataset_Decom_list_txt(root_path_train)
201 | list_path_test = build_LOLDataset_Decom_list_txt(root_path_test)
202 | # list_path_train = os.path.join(root_path_train, 'pair_list.csv')
203 | # list_path_test = os.path.join(root_path_test, 'pair_list.csv')
204 |
205 |     log("Building LOL Dataset...")
206 | # transform = transforms.Compose([transforms.ToTensor()])
207 | dst_train = LOLDataset_Decom(root_path_train, list_path_train,
208 | crop_size=config['length'], to_RAM=True)
209 | dst_test = LOLDataset_Decom(root_path_test, list_path_test,
210 | crop_size=config['length'], to_RAM=True, training=False)
211 |
212 | train_loader = DataLoader(dst_train, batch_size = config['batch_size'], shuffle=True)
213 | test_loader = DataLoader(dst_test, batch_size=1)
214 |
215 | else:
216 | root_path_train = r'C:\DeepLearning\KinD_plus-master\LOLdataset\our485'
217 | root_path_test = r'C:\DeepLearning\KinD_plus-master\LOLdataset\eval15'
218 | list_path_train = build_LOLDataset_list_txt(root_path_train)
219 | list_path_test = build_LOLDataset_list_txt(root_path_test)
220 | # list_path_train = os.path.join(root_path_train, 'pair_list.csv')
221 | # list_path_test = os.path.join(root_path_test, 'pair_list.csv')
222 |
223 |     log("Building LOL Dataset...")
224 | # transform = transforms.Compose([transforms.ToTensor()])
225 | dst_train = LOLDataset(root_path_train, list_path_train,
226 | crop_size=config['length'], to_RAM=True)
227 | dst_test = LOLDataset(root_path_test, list_path_test,
228 | crop_size=config['length'], to_RAM=True, training=False)
229 |
230 | train_loader = DataLoader(dst_train, batch_size = config['batch_size'], shuffle=True)
231 | test_loader = DataLoader(dst_test, batch_size=1)
232 |
233 | trainer = Illum_Trainer(config, train_loader, criterion, model,
234 | dataloader_test=test_loader, decom_net=decom_net)
235 |
236 | if args.mode == 'train':
237 | trainer.train()
238 | else:
239 | trainer.test()
--------------------------------------------------------------------------------
/illum_trainer_custom.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from torch.optim import lr_scheduler
5 | import numpy as np
6 | import time
7 | import yaml
8 | import sys
9 | from tqdm import tqdm
10 | from torchvision.utils import make_grid
11 | from torchvision import transforms
12 | from torchsummary import summary
13 | from base_trainer import BaseTrainer
14 | from losses import *
15 | from models import *
16 | from base_parser import BaseParser
17 | from dataloader import *
18 |
19 | class Illum_Trainer(BaseTrainer):
20 | def __init__(self, config, dataloader, criterion, model,
21 | dataloader_test=None, decom_net=None):
22 | super().__init__(config, dataloader, criterion, model, dataloader_test)
23 | log(f'Using device {self.device}')
24 | self.decom_net = decom_net
25 | self.decom_net.to(device=self.device)
26 | torch.backends.cudnn.benchmark = True
27 |
28 | def train(self):
29 | self.model.train()
30 | log(f'Using device {self.device}')
31 | self.model.to(device=self.device)
32 | print(self.model)
33 | # summary(self.model, input_size=[(1, 384, 384), (1,)], batch_size=4)
34 |
35 | optimizer = torch.optim.Adam(self.model.parameters(), lr=self.learning_rate)
36 | scheduler = lr_scheduler.ExponentialLR(optimizer, gamma=0.99426)
37 | try:
38 | for iter in range(self.epochs):
39 | idx = 0
40 | iter_start_time = time.time()
41 | for L_low_tensor, L_high_tensor, name in self.dataloader:
42 | optimizer.zero_grad()
43 | L_low = L_low_tensor.to(self.device)
44 | L_high = L_high_tensor.to(self.device)
45 |
46 | with torch.no_grad():
47 | _, I_low = self.decom_net(L_low)
48 | _, I_high = self.decom_net(L_high)
49 |
50 | I_out, I_standard = self.model(I_low, 1)
51 | loss = self.loss_fn(I_out, I_high, I_standard)
52 |
53 | if idx % 6 == 0:
54 | log(f"iter: {iter}_{idx}\taverage_loss: {loss.item():.6f}")
55 | loss.backward()
56 | optimizer.step()
57 | idx += 1
58 |
59 | if iter % self.print_frequency == 0:
60 | self.test(iter, plot_dir='./images/samples-illum-custom')
61 |
62 | if iter % self.save_frequency == 0:
63 | torch.save(self.model.state_dict(), f'./weights/illum_net_custom_{iter//100}.pth')
64 | log("Weight Has saved as 'illum_net.pth'")
65 |
66 | scheduler.step()
67 | iter_end_time = time.time()
68 | w, sigma = self.model.get_parameter()
69 | log(f"w:{float(w):.4f}\t sigma:{float(sigma):.2f}")
70 | log(f"Time taken: {iter_end_time - iter_start_time:.3f} seconds\t lr={scheduler.get_lr()[0]:.6f}")
71 |
72 | except KeyboardInterrupt:
73 | torch.save(self.model.state_dict(), './weights/INTERRUPTED_illum_custom.pth')
74 | log('Saved interrupt')
75 | try:
76 | sys.exit(0)
77 | except SystemExit:
78 | os._exit(0)
79 |
80 | @no_grad
81 | def test(self, epoch=-1, plot_dir='./images/samples-illum'):
82 | self.model.eval()
83 | for L_low_tensor, L_high_tensor, name in self.dataloader_test:
84 | L_low = L_low_tensor.to(self.device)
85 | L_high = L_high_tensor.to(self.device)
86 |
87 | with torch.no_grad():
88 | _, I_low = self.decom_net(L_low)
89 | _, I_high = self.decom_net(L_high)
90 | I_out, I_standard = self.model(I_low, 1)
91 | # I_low_standard = standard_illum(I_low, w=0.72, gamma=0.53, blur=True)
92 | # I_high_standard = standard_illum(I_high, w=0.08, gamma=1.34)
93 |
94 | I_standard_np = I_standard.detach().cpu().numpy()[0]
95 | I_out_np = I_out.detach().cpu().numpy()[0]
96 | I_low_np = I_low.detach().cpu().numpy()[0]
97 | I_high_np = I_high.detach().cpu().numpy()[0]
98 | # I_low_standard = standard_illum(I_low_np, dynamic=3)
99 | # I_high_standard = standard_illum(I_high_np)
100 |
101 | sample_imgs = np.concatenate( (I_low_np, I_high_np, I_standard_np, I_out_np), axis=0 )
102 |
103 | filepath = os.path.join(plot_dir, f'{name[0]}_epoch_{epoch//100}.png')
104 | split_point = [0, 1, 2, 3, 4]
105 | img_dim = I_low_np.shape[1:]
106 | sample(sample_imgs, split=split_point, figure_size=(2, 2),
107 | img_dim=img_dim, path=filepath, num=epoch)
108 |
109 |
110 | if __name__ == "__main__":
111 | criterion = Illum_Custom_Loss()
112 | decom_net = DecomNet()
113 | model = IllumNet_Custom()
114 |
115 | parser = BaseParser()
116 | args = parser.parse()
117 |
118 | with open(args.config) as f:
119 |         config = yaml.load(f, Loader=yaml.FullLoader)
120 |
121 |     args.checkpoint = True  # force checkpoint loading
122 | if args.checkpoint is not None:
123 | if config['noDecom'] is False:
124 | decom_net = load_weights(decom_net, path='./weights/decom_net.pth')
125 | log('DecomNet loaded from decom_net.pth')
126 | model = load_weights(model, path='./weights/illum_net_custom_0.pth')
127 |         log('Model loaded from illum_net_custom_0.pth')
128 |
129 | root_path_train = r'C:\DeepLearning\KinD_plus-master\LOLdataset\our485'
130 | root_path_test = r'C:\DeepLearning\KinD_plus-master\LOLdataset\eval15'
131 | list_path_train = build_LOLDataset_list_txt(root_path_train)
132 | list_path_test = build_LOLDataset_list_txt(root_path_test)
133 | # list_path_train = os.path.join(root_path_train, 'pair_list.csv')
134 | # list_path_test = os.path.join(root_path_test, 'pair_list.csv')
135 |
136 | log("Buliding LOL Dataset...")
137 | # transform = transforms.Compose([transforms.ToTensor()])
138 | dst_train = LOLDataset(root_path_train, list_path_train,
139 | crop_size=config['length'], to_RAM=True)
140 | dst_test = LOLDataset(root_path_test, list_path_test,
141 | crop_size=config['length'], to_RAM=True, training=False)
142 |
143 | train_loader = DataLoader(dst_train, batch_size = config['batch_size'], shuffle=True)
144 | test_loader = DataLoader(dst_test, batch_size=1)
145 |
146 | trainer = Illum_Trainer(config, train_loader, criterion, model,
147 | dataloader_test=test_loader, decom_net=decom_net)
148 |
149 | if args.mode == 'train':
150 | trainer.train()
151 | else:
152 | trainer.test()
--------------------------------------------------------------------------------
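
The core of one training step above, isolated as a runnable sketch: DecomNet stays frozen under `no_grad` while IllumNet_Custom is optimized. It assumes a CUDA device (the gradient helpers in losses.py default to `device='cuda'`) and uses randomly initialized networks in place of the pretrained .pth weights:

```python
import torch
from models import DecomNet, IllumNet_Custom
from losses import Illum_Custom_Loss

device = torch.device('cuda')                  # losses.py assumes CUDA
decom_net = DecomNet().to(device).eval()       # frozen; random weights here
model = IllumNet_Custom(device=device).to(device)
criterion = Illum_Custom_Loss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

# stand-ins for one batch of paired low/normal-light LOL crops
L_low = torch.rand(4, 3, 384, 384, device=device)
L_high = torch.rand(4, 3, 384, 384, device=device)

with torch.no_grad():                          # DecomNet is not trained here
    _, I_low = decom_net(L_low)
    _, I_high = decom_net(L_high)

I_out, I_standard = model(I_low, 1)            # ratio fixed to 1, as in train()
loss = criterion(I_out, I_high, I_standard)
optimizer.zero_grad()
loss.backward()
optimizer.step()
print(loss.item())
```
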
/losses.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import numpy as np
5 | import pytorch_ssim
6 | from dataloader import *
7 |
8 | Sobel = np.array([[-1,-2,-1],
9 | [ 0, 0, 0],
10 | [ 1, 2, 1]])
11 | Robert = np.array([[0, 0],
12 | [-1, 1]])
13 | Sobel = torch.Tensor(Sobel)
14 | Robert = torch.Tensor(Robert)
15 |
16 | def feature_map_hook(*args, path=None):
17 | feature_maps = []
18 | for feature in args:
19 | feature_maps.append(feature)
20 | feature_all = torch.cat(feature_maps, dim=1)
21 | fmap = feature_all.detach().cpu().numpy()[0]
22 | fmap = np.array(fmap)
23 | fshape = fmap.shape
24 | num = fshape[0]
25 | shape = fshape[1:]
26 | sample(fmap, figure_size=(2, num//2), img_dim=shape, path=path)
27 | return fmap
28 |
29 | # Tested: this module works as intended; it extracts first-order derivative (edge) maps via filter operators
30 | def gradient(maps, direction, device='cuda', kernel='sobel'):
31 | channels = maps.size()[1]
32 | if kernel == 'robert':
33 | smooth_kernel_x = Robert.expand(channels, channels, 2, 2)
34 | maps = F.pad(maps, (0, 0, 1, 1))
35 | elif kernel == 'sobel':
36 | smooth_kernel_x = Sobel.expand(channels, channels, 3, 3)
37 | maps = F.pad(maps, (1, 1, 1, 1))
38 | smooth_kernel_y = smooth_kernel_x.permute(0, 1, 3, 2)
39 | if direction == "x":
40 | kernel = smooth_kernel_x
41 | elif direction == "y":
42 | kernel = smooth_kernel_y
43 | kernel = kernel.to(device=device)
44 |     # padding for the chosen kernel was applied above, before the convolution
45 | gradient_orig = torch.abs(F.conv2d(maps, weight=kernel, padding=0))
46 | grad_min = torch.min(gradient_orig)
47 | grad_max = torch.max(gradient_orig)
48 | grad_norm = torch.div((gradient_orig - grad_min), (grad_max - grad_min + 0.0001))
49 | return grad_norm
50 |
51 |
52 | def gradient_no_abs(maps, direction, device='cuda', kernel='sobel'):
53 | channels = maps.size()[1]
54 | if kernel == 'robert':
55 | smooth_kernel_x = Robert.expand(channels, channels, 2, 2)
56 | maps = F.pad(maps, (0, 0, 1, 1))
57 | elif kernel == 'sobel':
58 | smooth_kernel_x = Sobel.expand(channels, channels, 3, 3)
59 | maps = F.pad(maps, (1, 1, 1, 1))
60 | smooth_kernel_y = smooth_kernel_x.permute(0, 1, 3, 2)
61 | if direction == "x":
62 | kernel = smooth_kernel_x
63 | elif direction == "y":
64 | kernel = smooth_kernel_y
65 | kernel = kernel.to(device=device)
66 |     # unlike gradient(), keep the sign of the filter response (no abs)
67 |     gradient_orig = F.conv2d(maps, weight=kernel, padding=0)
68 | grad_min = torch.min(gradient_orig)
69 | grad_max = torch.max(gradient_orig)
70 | grad_norm = torch.div((gradient_orig - grad_min), (grad_max - grad_min + 0.0001))
71 | return grad_norm
72 |
73 |
74 | class Decom_Loss(nn.Module):
75 | def __init__(self):
76 | super().__init__()
77 |
78 | def reflectance_similarity(self, R_low, R_high):
79 | return torch.mean(torch.abs(R_low - R_high))
80 |
81 | def illumination_smoothness(self, I, L, name='low', hook=-1):
82 | # L_transpose = L.permute(0, 2, 3, 1)
83 | # L_gray_transpose = 0.299*L[:,:,:,0] + 0.587*L[:,:,:,1] + 0.114*L[:,:,:,2]
84 | # L_gray = L.permute(0, 3, 1, 2)
85 | L_gray = 0.299*L[:,0,:,:] + 0.587*L[:,1,:,:] + 0.114*L[:,2,:,:]
86 | L_gray = L_gray.unsqueeze(dim=1)
87 | I_gradient_x = gradient(I, "x")
88 | L_gradient_x = gradient(L_gray, "x")
89 | epsilon = 0.01*torch.ones_like(L_gradient_x)
90 | Denominator_x = torch.max(L_gradient_x, epsilon)
91 | x_loss = torch.abs(torch.div(I_gradient_x, Denominator_x))
92 | I_gradient_y = gradient(I, "y")
93 | L_gradient_y = gradient(L_gray, "y")
94 | Denominator_y = torch.max(L_gradient_y, epsilon)
95 | y_loss = torch.abs(torch.div(I_gradient_y, Denominator_y))
96 | mut_loss = torch.mean(x_loss + y_loss)
97 | if hook > -1:
98 | feature_map_hook(I, L_gray, epsilon, I_gradient_x+I_gradient_y, Denominator_x+Denominator_y,
99 | x_loss+y_loss, path=f'./images/samples-features/ilux_smooth_{name}_epoch{hook}.png')
100 | return mut_loss
101 |
102 | def mutual_consistency(self, I_low, I_high, hook=-1):
103 | low_gradient_x = gradient(I_low, "x")
104 | high_gradient_x = gradient(I_high, "x")
105 | M_gradient_x = low_gradient_x + high_gradient_x
106 | x_loss = M_gradient_x * torch.exp(-10 * M_gradient_x)
107 | low_gradient_y = gradient(I_low, "y")
108 | high_gradient_y = gradient(I_high, "y")
109 | M_gradient_y = low_gradient_y + high_gradient_y
110 | y_loss = M_gradient_y * torch.exp(-10 * M_gradient_y)
111 | mutual_loss = torch.mean(x_loss + y_loss)
112 | if hook > -1:
113 | feature_map_hook(I_low, I_high, low_gradient_x+low_gradient_y, high_gradient_x+high_gradient_y,
114 | M_gradient_x + M_gradient_y, x_loss+ y_loss, path=f'./images/samples-features/mutual_consist_epoch{hook}.png')
115 | return mutual_loss
116 |
117 | def reconstruction_error(self, R_low, R_high, I_low_3, I_high_3, L_low, L_high):
118 | recon_loss_low = torch.mean(torch.abs(R_low * I_low_3 - L_low))
119 | recon_loss_high = torch.mean(torch.abs(R_high * I_high_3 - L_high))
120 | # recon_loss_l2h = torch.mean(torch.abs(R_high * I_low_3 - L_low))
121 | # recon_loss_h2l = torch.mean(torch.abs(R_low * I_high_3 - L_high))
122 | return recon_loss_high + recon_loss_low # + recon_loss_l2h + recon_loss_h2l
123 |
124 | def forward(self, R_low, R_high, I_low, I_high, L_low, L_high, hook=-1):
125 | I_low_3 = torch.cat([I_low, I_low, I_low], dim=1)
126 | I_high_3 = torch.cat([I_high, I_high, I_high], dim=1)
127 | #network output
128 | recon_loss = self.reconstruction_error(R_low, R_high, I_low_3, I_high_3, L_low, L_high)
129 | equal_R_loss = self.reflectance_similarity(R_low, R_high)
130 | i_mutual_loss = self.mutual_consistency(I_low, I_high, hook=hook)
131 | ilux_smooth_loss = self.illumination_smoothness(I_low, L_low, hook=hook) + \
132 | self.illumination_smoothness(I_high, L_high, name='high', hook=hook)
133 |
134 | decom_loss = recon_loss + 0.009 * equal_R_loss + 0.2 * i_mutual_loss + 0.15 * ilux_smooth_loss
135 |
136 | return decom_loss
137 |
138 |
139 | class Illum_Loss(nn.Module):
140 | def __init__(self):
141 | super().__init__()
142 |
143 | def grad_loss(self, low, high, hook=-1):
144 | x_loss = F.l1_loss(gradient_no_abs(low, 'x'), gradient_no_abs(high, 'x'))
145 | y_loss = F.l1_loss(gradient_no_abs(low, 'y'), gradient_no_abs(high, 'y'))
146 | grad_loss_all = x_loss + y_loss
147 | return grad_loss_all
148 |
149 | def forward(self, I_low, I_high, hook=-1):
150 | loss_grad = self.grad_loss(I_low, I_high, hook=hook)
151 | loss_recon = F.l1_loss(I_low, I_high)
152 | loss_adjust = loss_recon + loss_grad
153 | return loss_adjust
154 |
155 | class Illum_Custom_Loss(nn.Module):
156 | def __init__(self):
157 | super().__init__()
158 |
159 | def grad_loss(self, low, high):
160 | x_loss = F.l1_loss(gradient_no_abs(low, 'x'), gradient_no_abs(high, 'x'))
161 | y_loss = F.l1_loss(gradient_no_abs(low, 'y'), gradient_no_abs(high, 'y'))
162 | grad_loss_all = x_loss + y_loss
163 | return grad_loss_all
164 |
165 | def gamma_loss(self, I_standard, I_high):
166 | loss = F.l1_loss(I_high, I_standard)
167 | return loss
168 |
169 | def forward(self, I_low, I_high, I_standard):
170 | loss_gamma = self.gamma_loss(I_standard, I_high)
171 | loss_grad = self.grad_loss(I_low, I_high)
172 | loss_recon = F.l1_loss(I_low, I_high)
173 | loss_adjust = loss_gamma + loss_recon + loss_grad
174 | return loss_adjust
175 |
176 |
177 | class Restore_Loss(nn.Module):
178 | def __init__(self):
179 | super().__init__()
180 | self.ssim_loss = pytorch_ssim.SSIM()
181 |
182 | def grad_loss(self, low, high, hook=-1):
183 | x_loss = F.mse_loss(gradient_no_abs(low, 'x'), gradient_no_abs(high, 'x'))
184 | y_loss = F.mse_loss(gradient_no_abs(low, 'y'), gradient_no_abs(high, 'y'))
185 | grad_loss_all = x_loss + y_loss
186 | return grad_loss_all
187 |
188 | def forward(self, R_low, R_high, hook=-1):
189 | # loss_grad = self.grad_loss(R_low, R_high, hook=hook)
190 | loss_recon = F.l1_loss(R_low, R_high)
191 | loss_ssim = 1-self.ssim_loss(R_low, R_high)
192 | loss_restore = loss_recon + loss_ssim #+ loss_grad
193 | return loss_restore
194 |
195 |
196 | if __name__ == "__main__":
197 | from dataloader import *
198 | from torch.utils.data import DataLoader
199 | from torchvision.utils import make_grid
200 | from matplotlib import pyplot as plt
201 | root_path_train = r'H:\datasets\Low-Light Dataset\KinD++\LOLdataset\our485'
202 | list_path_train = build_LOLDataset_list_txt(root_path_train)
203 | Batch_size = 1
204 | log("Buliding LOL Dataset...")
205 | dst_test = LOLDataset(root_path_train, list_path_train, to_RAM=True, training=False)
206 |     # quick visual sanity check of the gradient maps on one test batch
207 | testloader = DataLoader(dst_test, batch_size = Batch_size)
208 | for i, data in enumerate(testloader):
209 | L_low, L_high, name = data
210 | L_gradient_x = gradient_no_abs(L_high, "x", device='cpu', kernel='sobel')
211 | epsilon = 0.01*torch.ones_like(L_gradient_x)
212 | Denominator_x = torch.max(L_gradient_x, epsilon)
213 | imgs = Denominator_x
214 |         img = imgs[0].numpy()
215 |         sample(img, figure_size=(1, 1), img_dim=(400, 600))
--------------------------------------------------------------------------------
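
A quick smoke test for Decom_Loss with random stand-ins for the DecomNet outputs (3-channel R and L, 1-channel I). It assumes a CUDA machine, since `gradient()` above hard-codes `device='cuda'`:

```python
import torch
from losses import Decom_Loss

device = 'cuda'                      # gradient()/gradient_no_abs default to CUDA
criterion = Decom_Loss()
B, H, W = 2, 128, 128
R_low = torch.rand(B, 3, H, W, device=device)
R_high = torch.rand(B, 3, H, W, device=device)
I_low = torch.rand(B, 1, H, W, device=device)
I_high = torch.rand(B, 1, H, W, device=device)
L_low = torch.rand(B, 3, H, W, device=device)
L_high = torch.rand(B, 3, H, W, device=device)

loss = criterion(R_low, R_high, I_low, I_high, L_low, L_high)
print(loss.item())                   # one scalar combining the four loss terms
```
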
/models.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import numpy as np
5 | from base_layers import *
6 |
7 | class DecomNet(nn.Module):
8 | def __init__(self, filters=32, activation='lrelu'):
9 | super().__init__()
10 | self.conv_input = Conv2D(3, filters)
11 | # top path build Reflectance map
12 | self.maxpool_r1 = MaxPooling2D()
13 | self.conv_r1 = Conv2D(filters, filters*2)
14 | self.maxpool_r2 = MaxPooling2D()
15 | self.conv_r2 = Conv2D(filters*2, filters*4)
16 | self.deconv_r1 = ConvTranspose2D(filters*4, filters*2)
17 | self.concat_r1 = Concat()
18 | self.conv_r3 = Conv2D(filters*4, filters*2)
19 | self.deconv_r2 = ConvTranspose2D(filters*2, filters)
20 | self.concat_r2 = Concat()
21 | self.conv_r4 = Conv2D(filters*2, filters)
22 | self.conv_r5 = nn.Conv2d(filters, 3, kernel_size=3, padding=1)
23 | self.R_out = nn.Sigmoid()
24 | # bottom path build Illumination map
25 | self.conv_i1 = Conv2D(filters, filters)
26 | self.concat_i1 = Concat()
27 | self.conv_i2 = nn.Conv2d(filters*2, 1, kernel_size=3, padding=1)
28 | self.I_out = nn.Sigmoid()
29 |
30 | def forward(self, x):
31 | conv_input = self.conv_input(x)
32 | # build Reflectance map
33 | maxpool_r1 = self.maxpool_r1(conv_input)
34 | conv_r1 = self.conv_r1(maxpool_r1)
35 | maxpool_r2 = self.maxpool_r2(conv_r1)
36 | conv_r2 = self.conv_r2(maxpool_r2)
37 | deconv_r1 = self.deconv_r1(conv_r2)
38 | concat_r1 = self.concat_r1(conv_r1, deconv_r1)
39 | conv_r3 = self.conv_r3(concat_r1)
40 | deconv_r2 = self.deconv_r2(conv_r3)
41 | concat_r2 = self.concat_r2(conv_input, deconv_r2)
42 | conv_r4 = self.conv_r4(concat_r2)
43 | conv_r5 = self.conv_r5(conv_r4)
44 | R_out = self.R_out(conv_r5)
45 |
46 | # build Illumination map
47 | conv_i1 = self.conv_i1(conv_input)
48 | concat_i1 = self.concat_i1(conv_r4, conv_i1)
49 | conv_i2 = self.conv_i2(concat_i1)
50 | I_out = self.I_out(conv_i2)
51 |
52 | return R_out, I_out
53 |
54 |
55 | class IllumNet(nn.Module):
56 | def __init__(self, filters=32, activation='lrelu'):
57 | super().__init__()
58 | self.concat_input = Concat()
59 | # bottom path build Illumination map
60 | self.conv_i1 = Conv2D(2, filters)
61 | self.conv_i2 = Conv2D(filters, filters)
62 | self.conv_i3 = Conv2D(filters, filters)
63 | self.conv_i4 = nn.Conv2d(filters, 1, kernel_size=3, padding=1)
64 | self.I_out = nn.Sigmoid()
65 |
66 | def forward(self, I, ratio):
67 | with torch.no_grad():
68 | ratio_map = torch.ones_like(I) * ratio
69 | concat_input = self.concat_input(I, ratio_map)
70 | # build Illumination map
71 | conv_i1 = self.conv_i1(concat_input)
72 | conv_i2 = self.conv_i2(conv_i1)
73 | conv_i3 = self.conv_i3(conv_i2)
74 | conv_i4 = self.conv_i4(conv_i3)
75 | I_out = self.I_out(conv_i4)
76 |
77 | return I_out
78 |
79 |
80 | class IllumNet_Custom(nn.Module):
81 | def __init__(self, filters=16, activation='lrelu', device='cuda'):
82 | super().__init__()
83 | self.concat_input = Concat()
84 | # Parameter
85 | self.Gauss = torch.as_tensor(
86 | np.array([[0.0947416, 0.118318, 0.0947416],
87 | [ 0.118318, 0.147761, 0.118318],
88 | [0.0947416, 0.118318, 0.0947416]]).astype(np.float32)
89 | )
90 | self.Gauss_kernel = self.Gauss.expand(1, 1, 3, 3).to(device)
91 |         self.w = nn.Parameter(torch.full((1,), 0.72))
92 |         self.sigma = nn.Parameter(torch.full((1,), 2.0))
93 |
94 |
95 | # bottom path build Illumination map
96 | self.conv_input = Conv2D(2, filters)
97 | self.res_block = nn.Sequential(
98 | ResConv(filters, filters),
99 | ResConv(filters, filters),
100 | ResConv(filters, filters)
101 | )
102 | # self.down1 = MaxPooling2D()
103 | # self.conv_2 = Conv2D(filters, filters*2)
104 | # self.down2 = MaxPooling2D()
105 | # self.conv_3 = Conv2D(filters*2, filters*4)
106 | # self.down3 = MaxPooling2D()
107 | # self.conv_4 = Conv2D(filters*4, filters*8)
108 |
109 | # self.d = nn.Dropout2d(0.5)
110 |
111 | # self.deconv_3 = ConvTranspose2D(filters*8, filters*4)
112 | # self.concat3 = Concat()
113 | # self.cbam3 = CBAM(filters*8)
114 | # self.deconv_2 = ConvTranspose2D(filters*8, filters*2)
115 | # self.concat2 = Concat()
116 | # self.cbam2 = CBAM(filters*4)
117 | # self.deconv_1 = ConvTranspose2D(filters*4, filters*1)
118 | # self.concat1 = Concat()
119 | # self.cbam1 = CBAM(filters*2)
120 | self.conv_out = nn.Conv2d(filters, 1, kernel_size=3, padding=1)
121 |
122 | self.I_out = nn.Sigmoid()
123 |
124 | def standard_illum_map(self, I, ratio=1, blur=False):
125 |         self.w.data.clamp_(0.01, 0.99)
126 |         self.sigma.data.clamp_(0.1, 10)
127 | # if blur: # low light image have much noisy
128 | # I = torch.nn.functional.conv2d(I, weight=self.Gauss_kernel, padding=1)
129 | I = torch.log(I + 1.)
130 | I_mean = torch.mean(I, dim=[2, 3], keepdim=True)
131 | I_std = torch.std(I, dim=[2, 3], keepdim=True)
132 | I_min = I_mean - self.sigma * I_std
133 | I_max = I_mean + self.sigma * I_std
134 | I_range = I_max - I_min
135 | I_out = torch.clamp((I - I_min) / I_range, min=0.0, max=1.0)
136 |         # gamma correction with gamma = log2(1/w), so an input of 0.5 maps to w
137 |         I_out = I_out ** (-1.442695 * torch.log(self.w))
138 | return I_out
139 |
140 | def set_parameter(self, w=None):
141 | if w is None:
142 | self.w.requires_grad = True
143 | else:
144 | self.w.data.fill_(w)
145 | self.w.requires_grad = False
146 |
147 | def get_parameter(self):
148 | if self.w.device.type == 'cuda':
149 | w = self.w.detach().cpu().numpy()
150 | sigma = self.sigma.detach().cpu().numpy()
151 | else:
152 |             w = self.w.detach().numpy()
153 |             sigma = self.sigma.detach().numpy()
154 | return w, sigma
155 |
156 | def forward(self, I, ratio):
157 | I_standard = self.standard_illum_map(I, ratio)
158 | concat_input = torch.cat([I, I_standard], dim=1)
159 | # build Illumination map
160 | conv_input = self.conv_input(concat_input)
161 | res_block = self.res_block(conv_input)
162 | # down1 = self.down1(conv_1)
163 | # conv_2 = self.conv_2(down1)
164 | # down2 = self.down2(conv_2)
165 | # conv_3 = self.conv_3(down2)
166 | # down3 = self.down3(conv_3)
167 | # conv_4 = self.conv_4(down3)
168 | # d = self.d(conv_4)
169 | # deconv_3 = self.deconv_3(d)
170 |
171 | # concat3 = self.concat3(conv_3, deconv_3)
172 | # cbam3 = self.cbam3(concat3)
173 | # deconv_2 = self.deconv_2(cbam3)
174 |
175 | # concat2 = self.concat2(conv_2, deconv_2)
176 | # cbam2 = self.cbam2(concat2)
177 | # deconv_1 = self.deconv_1(cbam2)
178 |
179 | # concat1 = self.concat1(conv_1, deconv_1)
180 | # cbam1 = self.cbam1(concat1)
181 | res_out = res_block + conv_input
182 | conv_out = self.conv_out(res_out)
183 | I_out = self.I_out(conv_out)
184 |
185 | return I_out, I_standard
186 |
187 |
188 | class RestoreNet_MSIA(nn.Module):
189 | def __init__(self, filters=16, activation='relu'):
190 | super().__init__()
191 | # Illumination Attention
192 | self.i_input = nn.Conv2d(1,1,kernel_size=3,padding=1)
193 | self.i_att = nn.Sigmoid()
194 |
195 | # Network
196 | self.conv1_1 = Conv2D(3, filters, activation)
197 | self.conv1_2 = Conv2D(filters, filters*2, activation)
198 | self.msia1 = MSIA(filters*2, activation)
199 |
200 | self.conv2_1 = Conv2D(filters*2, filters*4, activation)
201 | self.conv2_2 = Conv2D(filters*4, filters*4, activation)
202 | self.msia2 = MSIA(filters*4, activation)
203 |
204 | self.conv3_1 = Conv2D(filters*4, filters*8, activation)
205 | self.dropout = nn.Dropout2d(0.5)
206 | self.conv3_2 = Conv2D(filters*8, filters*4, activation)
207 | self.msia3 = MSIA(filters*4, activation)
208 |
209 | self.conv4_1 = Conv2D(filters*4, filters*2, activation)
210 | self.conv4_2 = Conv2D(filters*2, filters*2, activation)
211 | self.msia4 = MSIA(filters*2, activation)
212 |
213 | self.conv5_1 = Conv2D(filters*2, filters*1, activation)
214 | self.conv5_2 = nn.Conv2d(filters, 3, kernel_size=1, padding=0)
215 | self.out = nn.Sigmoid()
216 |
217 | def forward(self, R, I):
218 | # Illumination Attention
219 | i_input = self.i_input(I)
220 | i_att = self.i_att(i_input)
221 |
222 | # Network
223 | conv1 = self.conv1_1(R)
224 | conv1 = self.conv1_2(conv1)
225 | msia1 = self.msia1(conv1, i_att)
226 |
227 | conv2 = self.conv2_1(msia1)
228 | conv2 = self.conv2_2(conv2)
229 | msia2 = self.msia2(conv2, i_att)
230 |
231 | conv3 = self.conv3_1(msia2)
232 | conv3 = self.conv3_2(conv3)
233 | msia3 = self.msia3(conv3, i_att)
234 |
235 | conv4 = self.conv4_1(msia3)
236 | conv4 = self.conv4_2(conv4)
237 | msia4 = self.msia4(conv4, i_att)
238 |
239 | conv5 = self.conv5_1(msia4)
240 | conv5 = self.conv5_2(conv5)
241 |
242 | # out = self.out(conv5)
243 | out = conv5.clamp(min=0.0, max=1.0)
244 | return out
245 |
246 |
247 | class RestoreNet_Unet(nn.Module):
248 | def __init__(self, filters=32, activation='lrelu'):
249 | super().__init__()
250 | self.conv1_1 = Conv2D(4, filters)
251 | self.conv1_2 = Conv2D(filters, filters)
252 | self.pool1 = MaxPooling2D()
253 |
254 | self.conv2_1 = Conv2D(filters, filters*2)
255 | self.conv2_2 = Conv2D(filters*2, filters*2)
256 | self.pool2 = MaxPooling2D()
257 |
258 | self.conv3_1 = Conv2D(filters*2, filters*4)
259 | self.conv3_2 = Conv2D(filters*4, filters*4)
260 | self.pool3 = MaxPooling2D()
261 |
262 | self.conv4_1 = Conv2D(filters*4, filters*8)
263 | self.conv4_2 = Conv2D(filters*8, filters*8)
264 | self.pool4 = MaxPooling2D()
265 |
266 | self.conv5_1 = Conv2D(filters*8, filters*16)
267 | self.conv5_2 = Conv2D(filters*16, filters*16)
268 | self.dropout = nn.Dropout2d(0.5)
269 |
270 | self.upv6 = ConvTranspose2D(filters*16, filters*8)
271 | self.concat6 = Concat()
272 | self.conv6_1 = Conv2D(filters*16, filters*8)
273 | self.conv6_2 = Conv2D(filters*8, filters*8)
274 |
275 | self.upv7 = ConvTranspose2D(filters*8, filters*4)
276 | self.concat7 = Concat()
277 | self.conv7_1 = Conv2D(filters*8, filters*4)
278 | self.conv7_2 = Conv2D(filters*4, filters*4)
279 |
280 | self.upv8 = ConvTranspose2D(filters*4, filters*2)
281 | self.concat8 = Concat()
282 | self.conv8_1 = Conv2D(filters*4, filters*2)
283 | self.conv8_2 = Conv2D(filters*2, filters*2)
284 |
285 | self.upv9 = ConvTranspose2D(filters*2, filters)
286 | self.concat9 = Concat()
287 | self.conv9_1 = Conv2D(filters*2, filters)
288 | self.conv9_2 = Conv2D(filters, filters)
289 |
290 | self.conv10_1 = nn.Conv2d(filters, 3, kernel_size=1, stride=1)
291 | self.out = nn.Sigmoid()
292 |
293 | def forward(self, R, I):
294 | x = torch.cat([R, I], dim=1)
295 | conv1 = self.conv1_1(x)
296 | conv1 = self.conv1_2(conv1)
297 | pool1 = self.pool1(conv1)
298 |
299 | conv2 = self.conv2_1(pool1)
300 | conv2 = self.conv2_2(conv2)
301 |         pool2 = self.pool2(conv2)
302 |
303 | conv3 = self.conv3_1(pool2)
304 | conv3 = self.conv3_2(conv3)
305 |         pool3 = self.pool3(conv3)
306 |
307 | conv4 = self.conv4_1(pool3)
308 | conv4 = self.conv4_2(conv4)
309 |         pool4 = self.pool4(conv4)
310 |
311 | conv5 = self.conv5_1(pool4)
312 | conv5 = self.conv5_2(conv5)
313 |
314 | # d = self.dropout(conv5)
315 | up6 = self.upv6(conv5)
316 | up6 = self.concat6(conv4, up6)
317 | conv6 = self.conv6_1(up6)
318 | conv6 = self.conv6_2(conv6)
319 |
320 | up7 = self.upv7(conv6)
321 | up7 = self.concat7(conv3, up7)
322 | conv7 = self.conv7_1(up7)
323 | conv7 = self.conv7_2(conv7)
324 |
325 | up8 = self.upv8(conv7)
326 | up8 = self.concat8(conv2, up8)
327 | conv8 = self.conv8_1(up8)
328 | conv8 = self.conv8_2(conv8)
329 |
330 | up9 = self.upv9(conv8)
331 | up9 = self.concat9(conv1, up9)
332 | conv9 = self.conv9_1(up9)
333 | conv9 = self.conv9_2(conv9)
334 |
335 | conv10 = self.conv10_1(conv9)
336 | out = self.out(conv10)
337 | return out
338 |
339 | class KinD_noDecom(nn.Module):
340 | def __init__(self, filters=32, activation='lrelu'):
341 | super().__init__()
342 | # self.decom_net = DecomNet()
343 | self.restore_net = RestoreNet_Unet()
344 | self.illum_net = IllumNet()
345 |
346 | def forward(self, R, I, ratio):
347 | I_final = self.illum_net(I, ratio)
348 | R_final = self.restore_net(R, I)
349 | I_final_3 = torch.cat([I_final, I_final, I_final], dim=1)
350 | output = I_final_3 * R_final
351 | return R_final, I_final, output
352 |
353 |
354 | class KinD(nn.Module):
355 | def __init__(self, filters=32, activation='lrelu'):
356 | super().__init__()
357 | self.decom_net = DecomNet()
358 | self.restore_net = RestoreNet_Unet()
359 | self.illum_net = IllumNet()
360 | self.KinD_noDecom = KinD_noDecom()
361 | self.KinD_noDecom.restore_net = self.restore_net
362 | self.KinD_noDecom.illum_net = self.illum_net
363 |
364 | def forward(self, L, ratio):
365 | R, I = self.decom_net(L)
366 | R_final, I_final, output = self.KinD_noDecom(R, I, ratio)
367 | # I_final = self.illum_net(I, ratio)
368 | # R_final = self.restore_net(R, I)
369 | # I_final_3 = torch.cat([I_final, I_final, I_final], dim=1)
370 | # output = I_final_3 * R_final
371 | return R_final, I_final, output
372 |
373 | class KinD_plus(nn.Module):
374 | def __init__(self, filters=32, activation='lrelu'):
375 | super().__init__()
376 | self.decom_net = DecomNet()
377 | self.restore_net = RestoreNet_MSIA()
378 | self.illum_net = IllumNet_Custom()
379 |
380 | def forward(self, L, ratio):
381 | R, I = self.decom_net(L)
382 | # R_final, I_final, output = self.KinD_noDecom(R, I, ratio)
383 | I_final, I_standard = self.illum_net(I, ratio)
384 | R_final = self.restore_net(R, I)
385 | I_final_3 = torch.cat([I_final, I_final, I_final], dim=1)
386 | output = I_final_3 * R_final
387 | return R_final, I_final, output
--------------------------------------------------------------------------------
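
An end-to-end smoke test for the composed KinD model (random weights and input; real use loads the .pth checkpoints as in test_your_pictures.py). The input side must be divisible by 16 because RestoreNet_Unet downsamples four times; the upsampling behavior of ConvTranspose2D is assumed from base_layers.py:

```python
import torch
from models import KinD

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = KinD().to(device).eval()

L = torch.rand(1, 3, 256, 256, device=device)   # low-light RGB in [0, 1]
with torch.no_grad():
    R_final, I_final, output = model(L, 2.0)    # ratio: illumination boost
print(R_final.shape, I_final.shape, output.shape)
# expected: (1, 3, 256, 256), (1, 1, 256, 256), (1, 3, 256, 256)
```
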
/pytorch_ssim/__init__.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | from torch.autograd import Variable
4 | import numpy as np
5 | from math import exp
6 |
7 | def gaussian(window_size, sigma):
8 | gauss = torch.Tensor([exp(-(x - window_size//2)**2/float(2*sigma**2)) for x in range(window_size)])
9 | return gauss/gauss.sum()
10 |
11 | def create_window(window_size, channel):
12 | _1D_window = gaussian(window_size, 1.5).unsqueeze(1)
13 | _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
14 | window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous())
15 | return window
16 |
17 | def _ssim(img1, img2, window, window_size, channel, size_average = True):
18 | mu1 = F.conv2d(img1, window, padding = window_size//2, groups = channel)
19 | mu2 = F.conv2d(img2, window, padding = window_size//2, groups = channel)
20 |
21 | mu1_sq = mu1.pow(2)
22 | mu2_sq = mu2.pow(2)
23 | mu1_mu2 = mu1*mu2
24 |
25 | sigma1_sq = F.conv2d(img1*img1, window, padding = window_size//2, groups = channel) - mu1_sq
26 | sigma2_sq = F.conv2d(img2*img2, window, padding = window_size//2, groups = channel) - mu2_sq
27 | sigma12 = F.conv2d(img1*img2, window, padding = window_size//2, groups = channel) - mu1_mu2
28 |
29 | C1 = 0.01**2
30 | C2 = 0.03**2
31 |
32 | ssim_map = ((2*mu1_mu2 + C1)*(2*sigma12 + C2))/((mu1_sq + mu2_sq + C1)*(sigma1_sq + sigma2_sq + C2))
33 |
34 | if size_average:
35 | return ssim_map.mean()
36 | else:
37 | return ssim_map.mean(1).mean(1).mean(1)
38 |
39 | class SSIM(torch.nn.Module):
40 | def __init__(self, window_size = 11, size_average = True):
41 | super(SSIM, self).__init__()
42 | self.window_size = window_size
43 | self.size_average = size_average
44 | self.channel = 1
45 | self.window = create_window(window_size, self.channel)
46 |
47 | def forward(self, img1, img2):
48 | (_, channel, _, _) = img1.size()
49 |
50 | if channel == self.channel and self.window.data.type() == img1.data.type():
51 | window = self.window
52 | else:
53 | window = create_window(self.window_size, channel)
54 |
55 | if img1.is_cuda:
56 | window = window.cuda(img1.get_device())
57 | window = window.type_as(img1)
58 |
59 | self.window = window
60 | self.channel = channel
61 |
62 |
63 | return _ssim(img1, img2, window, self.window_size, channel, self.size_average)
64 |
65 | def ssim(img1, img2, window_size = 11, size_average = True):
66 | (_, channel, _, _) = img1.size()
67 | window = create_window(window_size, channel)
68 |
69 | if img1.is_cuda:
70 | window = window.cuda(img1.get_device())
71 | window = window.type_as(img1)
72 |
73 | return _ssim(img1, img2, window, window_size, channel, size_average)
--------------------------------------------------------------------------------
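
Usage note: SSIM here returns the similarity itself (about 1 for identical images), which is why Restore_Loss uses `1 - ssim` as the penalty. A minimal check:

```python
import torch
import pytorch_ssim

img1 = torch.rand(1, 3, 64, 64)
img2 = img1.clone()

ssim_module = pytorch_ssim.SSIM(window_size=11)
print(ssim_module(img1, img2).item())   # ~1.0 for identical images
print(pytorch_ssim.ssim(img1, torch.rand(1, 3, 64, 64)).item())  # much lower
```
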
/restore_MSIA_trainer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import numpy as np
5 | import time
6 | import yaml
7 | import sys, os
8 | from tqdm import tqdm
9 | from torch.optim import lr_scheduler
10 | from torchvision.utils import make_grid
11 | from torchvision import transforms
12 | from torchsummary import summary
13 | from base_trainer import BaseTrainer
14 | from losses import *
15 | from models import *
16 | from base_parser import BaseParser
17 | from dataloader import *
18 |
19 | class Restore_Trainer(BaseTrainer):
20 | def __init__(self, config, dataloader, criterion, model,
21 | dataloader_test=None, decom_net=None):
22 | super().__init__(config, dataloader, criterion, model, dataloader_test)
23 | log(f'Using device {self.device}')
24 | self.decom_net = decom_net
25 | self.decom_net.to(device=self.device)
26 |
27 | def train(self):
28 | # print(self.model)
29 | summary(self.model, input_size=[(3, 256, 256), (1,256,256)])
30 |
31 | optimizer = torch.optim.Adam(self.model.parameters(), lr=self.learning_rate)
32 | scheduler = lr_scheduler.ExponentialLR(optimizer, gamma=0.99426) #0.977237, 0.986233
33 | try:
34 | for iter in range(self.epochs):
35 | epoch_loss = 0
36 | idx = 0
37 | hook_number = -1
38 | iter_start_time = time.time()
39 | # with tqdm(total=self.steps_per_epoch) as pbar:
40 | for L_low_tensor, L_high_tensor, name in self.dataloader:
41 | optimizer.zero_grad()
42 | L_low = L_low_tensor.to(self.device)
43 | L_high = L_high_tensor.to(self.device)
44 |
45 | with torch.no_grad():
46 | R_low, I_low = self.decom_net(L_low)
47 | R_high, I_high = self.decom_net(L_high)
48 |
49 | R_restore = self.model(R_low, I_low)
50 |
51 | if idx % self.print_frequency == 0:
52 | hook_number = iter
53 | loss = self.loss_fn(R_restore, R_high, hook=hook_number)
54 | hook_number = -1
55 | if idx % 8 == 0:
56 | log(f"iter: {iter}_{idx}\taverage_loss: {loss.item():.6f}")
57 | loss.backward()
58 | optimizer.step()
59 | idx += 1
60 | # pbar.update(1)
61 | # pbar.set_postfix({'loss':loss.item()})
62 |
63 | if iter % self.print_frequency == 0:
64 | self.test(iter, plot_dir='./images/samples-restore-MSIA')
65 |
66 | if iter % self.save_frequency == 0:
67 | torch.save(self.model.state_dict(), f'./weights/restore_net_MSIA_{iter//100}.pth')
68 | log("Weight Has saved as 'restore_net.pth'")
69 |
70 | scheduler.step()
71 | iter_end_time = time.time()
72 | log(f"Time taken: {iter_end_time - iter_start_time:.3f} seconds\t lr={scheduler.get_lr()[0]:.6f}")
73 | # print("End of epochs {.0f}, Time taken: {.3f}, average loss: {.5f}".format(
74 | # idx, iter_end_time - iter_start_time, epoch_loss / idx))
75 |
76 | except KeyboardInterrupt:
77 | torch.save(self.model.state_dict(), './weights/INTERRUPTED_restore.pth')
78 | print('Saved interrupt')
79 | try:
80 | sys.exit(0)
81 | except SystemExit:
82 | os._exit(0)
83 |
84 | @no_grad
85 | def test(self, epoch=-1, plot_dir='./images/samples-restore'):
86 | self.model.eval()
87 | for L_low_tensor, L_high_tensor, name in self.dataloader_test:
88 | L_low = L_low_tensor.to(self.device)
89 | L_high = L_high_tensor.to(self.device)
90 |
91 | R_low, I_low = self.decom_net(L_low)
92 | R_high, I_high = self.decom_net(L_high)
93 |
94 | R_restore = self.model(R_low, I_low)
95 |
96 | R_restore_np = R_restore.detach().cpu().numpy()[0]
97 | I_low_np = I_low.detach().cpu().numpy()[0]
98 | R_low_np = R_low.detach().cpu().numpy()[0]
99 | R_high_np = R_high.detach().cpu().numpy()[0]
100 | sample_imgs = np.concatenate( (I_low_np, R_low_np, R_restore_np, R_high_np), axis=0 )
101 |
102 | filepath = os.path.join(plot_dir, f'{name[0]}_epoch_{epoch//100}.png')
103 | split_point = [0, 1, 4, 7, 10]
104 | img_dim = I_low_np.shape[1:]
105 | sample(sample_imgs, split=split_point, figure_size=(2, 2),
106 | img_dim=img_dim, path=filepath, num=epoch)
107 |
108 | if __name__ == "__main__":
109 | criterion = Restore_Loss()
110 | model = RestoreNet_MSIA()
111 | decom_net = DecomNet()
112 |
113 | parser = BaseParser()
114 | args = parser.parse()
115 |
116 | with open(args.config) as f:
117 |         config = yaml.load(f, Loader=yaml.FullLoader)
118 |     args.checkpoint = True  # force checkpoint loading
119 |
120 | if args.checkpoint is not None:
121 | if config['noDecom'] is False:
122 | decom_net = load_weights(decom_net, path='./weights/decom_net.pth')
123 | log('DecomNet loaded from decom_net.pth')
124 | model = load_weights(model, path='./weights/restore_net_MSIA_1.pth')
125 |         log('Model loaded from restore_net_MSIA_1.pth')
126 |
127 | root_path_train = r'C:\DeepLearning\KinD_plus-master\LOLdataset\our485'
128 | root_path_test = r'C:\DeepLearning\KinD_plus-master\LOLdataset\eval15'
129 | list_path_train = build_LOLDataset_list_txt(root_path_train)
130 | list_path_test = build_LOLDataset_list_txt(root_path_test)
131 | # list_path_train = os.path.join(root_path_train, 'pair_list.csv')
132 | # list_path_test = os.path.join(root_path_test, 'pair_list.csv')
133 |
134 | log("Buliding LOL Dataset...")
135 | # transform = transforms.Compose([transforms.ToTensor()])
136 | dst_train = LOLDataset(root_path_train, list_path_train,
137 | crop_size=config['length'], to_RAM=True)
138 | dst_test = LOLDataset(root_path_test, list_path_test,
139 | crop_size=config['length'], to_RAM=True, training=False)
140 |
141 | train_loader = DataLoader(dst_train, batch_size = config['batch_size'], shuffle=True)
142 | # train_loader = data_prefetcher(train_loader)
143 | test_loader = DataLoader(dst_test, batch_size=1)
144 |
145 | trainer = Restore_Trainer(config, train_loader, criterion, model,
146 | dataloader_test=test_loader, decom_net=decom_net)
147 |
148 | if args.mode == 'train':
149 | trainer.train()
150 | else:
151 | trainer.test()
--------------------------------------------------------------------------------
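
Inference through the frozen DecomNet plus RestoreNet_MSIA, as a standalone sketch (random weights stand in for decom_net.pth and restore_net_MSIA_*.pth; the MSIA blocks from base_layers.py are assumed to preserve spatial size, as the trainer's usage implies):

```python
import torch
from models import DecomNet, RestoreNet_MSIA

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
decom_net = DecomNet().to(device).eval()
restore_net = RestoreNet_MSIA().to(device).eval()  # eval(): disable Dropout2d

L_low = torch.rand(1, 3, 256, 256, device=device)  # low-light input
with torch.no_grad():
    R_low, I_low = decom_net(L_low)                # reflectance + illumination
    R_restored = restore_net(R_low, I_low)         # denoised reflectance
print(R_restored.shape)                            # expected: (1, 3, 256, 256)
```
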
/restore_trainer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import numpy as np
5 | import time
6 | import yaml
7 | import sys, os
8 | from tqdm import tqdm
9 | from torch.optim import lr_scheduler
10 | from torchvision.utils import make_grid
11 | from torchvision import transforms
12 | from torchsummary import summary
13 | from base_trainer import BaseTrainer
14 | from losses import *
15 | from models import *
16 | from base_parser import BaseParser
17 | from dataloader import *
18 |
19 | class Restore_Trainer(BaseTrainer):
20 | def __init__(self, config, dataloader, criterion, model,
21 | dataloader_test=None, decom_net=None):
22 | super().__init__(config, dataloader, criterion, model, dataloader_test)
23 | log(f'Using device {self.device}')
24 | self.decom_net = decom_net
25 | self.decom_net.to(device=self.device)
26 |
27 | def train(self):
28 | # print(self.model)
29 | summary(self.model, input_size=[(3, 384, 384), (1,384,384)])
30 |
31 | optimizer = torch.optim.Adam(self.model.parameters(), lr=self.learning_rate)
32 | scheduler = lr_scheduler.ExponentialLR(optimizer, gamma=0.986233) #0.977237, 0.986233
33 | try:
34 | for iter in range(self.epochs):
35 | epoch_loss = 0
36 | idx = 0
37 | hook_number = -1
38 | iter_start_time = time.time()
39 | # with tqdm(total=self.steps_per_epoch) as pbar:
40 | if self.noDecom is True:
41 | for R_low_tensor, I_low_tensor, R_high_tensor, I_high_tensor, name in self.dataloader:
42 | optimizer.zero_grad()
43 | I_low = I_low_tensor.to(self.device)
44 | R_low = R_low_tensor.to(self.device)
45 | R_high = R_high_tensor.to(self.device)
46 | R_restore = self.model(R_low, I_low)
47 |
48 | if idx % self.print_frequency == 0:
49 | hook_number = iter
50 | loss = self.loss_fn(R_restore, R_high, hook=hook_number)
51 | hook_number = -1
52 | if idx % 30 == 0:
53 | log(f"iter: {iter}_{idx}\taverage_loss: {loss.item():.6f}")
54 | loss.backward()
55 | optimizer.step()
56 | idx += 1
57 | # pbar.update(1)
58 | # pbar.set_postfix({'loss':loss.item()})
59 |
60 | else:
61 | for L_low_tensor, L_high_tensor, name in self.dataloader:
62 | optimizer.zero_grad()
63 | L_low = L_low_tensor.to(self.device)
64 | L_high = L_high_tensor.to(self.device)
65 |
66 | with torch.no_grad():
67 | R_low, I_low = self.decom_net(L_low)
68 | R_high, I_high = self.decom_net(L_high)
69 |
70 | R_restore = self.model(R_low, I_low)
71 |
72 | if idx % self.print_frequency == 0:
73 | hook_number = iter
74 | loss = self.loss_fn(R_restore, R_high, hook=hook_number)
75 | hook_number = -1
76 | if idx % 30 == 0:
77 | log(f"iter: {iter}_{idx}\taverage_loss: {loss.item():.6f}")
78 | loss.backward()
79 | optimizer.step()
80 | idx += 1
81 | # pbar.update(1)
82 | # pbar.set_postfix({'loss':loss.item()})
83 |
84 | if iter % self.print_frequency == 0:
85 | self.test(iter, plot_dir='./images/samples-restore')
86 |
87 | if iter % self.save_frequency == 0:
88 | torch.save(self.model.state_dict(), './weights/restore_net.pth')
89 | log("Weight Has saved as 'restore_net.pth'")
90 |
91 | scheduler.step()
92 | iter_end_time = time.time()
93 | log(f"Time taken: {iter_end_time - iter_start_time:.3f} seconds\t lr={scheduler.get_lr()[0]:.6f}")
94 | # print("End of epochs {.0f}, Time taken: {.3f}, average loss: {.5f}".format(
95 | # idx, iter_end_time - iter_start_time, epoch_loss / idx))
96 |
97 | except KeyboardInterrupt:
98 | torch.save(self.model.state_dict(), './weights/INTERRUPTED_restore.pth')
99 | print('Saved interrupt')
100 | try:
101 | sys.exit(0)
102 | except SystemExit:
103 | os._exit(0)
104 |
105 | @no_grad
106 | def test(self, epoch=-1, plot_dir='./images/samples-restore'):
107 | self.model.eval()
108 | if self.noDecom:
109 | for R_low_tensor, I_low_tensor, R_high_tensor, I_high_tensor, name in self.dataloader_test:
110 | I_low = I_low_tensor.to(self.device)
111 | R_low = R_low_tensor.to(self.device)
112 | R_restore = self.model(R_low, I_low)
113 |
114 | R_restore_np = R_restore.detach().cpu().numpy()[0]
115 | I_low_np = I_low_tensor.numpy()[0]
116 | R_low_np = R_low_tensor.numpy()[0]
117 | R_high_np = R_high_tensor.numpy()[0]
118 | sample_imgs = np.concatenate( (I_low_np, R_low_np, R_restore_np, R_high_np), axis=0 )
119 |
120 | filepath = os.path.join(plot_dir, f'{name[0]}_epoch_{epoch}.png')
121 | split_point = [0, 1, 4, 7, 10]
122 | img_dim = I_low_np.shape[1:]
123 | sample(sample_imgs, split=split_point, figure_size=(2, 2),
124 | img_dim=img_dim, path=filepath, num=epoch)
125 | else:
126 | for L_low_tensor, L_high_tensor, name in self.dataloader_test:
127 | L_low = L_low_tensor.to(self.device)
128 | L_high = L_high_tensor.to(self.device)
129 |
130 | R_low, I_low = self.decom_net(L_low)
131 | R_high, I_high = self.decom_net(L_high)
132 |
133 | R_restore = self.model(R_low, I_low)
134 |
135 | R_restore_np = R_restore.detach().cpu().numpy()[0]
136 | I_low_np = I_low.detach().cpu().numpy()[0]
137 | R_low_np = R_low.detach().cpu().numpy()[0]
138 | R_high_np = R_high.detach().cpu().numpy()[0]
139 | sample_imgs = np.concatenate( (I_low_np, R_low_np, R_restore_np, R_high_np), axis=0 )
140 |
141 | filepath = os.path.join(plot_dir, f'{name[0]}_epoch_{epoch}.png')
142 | split_point = [0, 1, 4, 7, 10]
143 | img_dim = I_low_np.shape[1:]
144 | sample(sample_imgs, split=split_point, figure_size=(2, 2),
145 | img_dim=img_dim, path=filepath, num=epoch)
146 |
147 | if __name__ == "__main__":
148 | criterion = Restore_Loss()
149 | model = RestoreNet_Unet()
150 | decom_net = DecomNet()
151 |
152 | parser = BaseParser()
153 | args = parser.parse()
154 |
155 | with open(args.config) as f:
156 |         config = yaml.load(f, Loader=yaml.FullLoader)
157 |     args.checkpoint = True  # force checkpoint loading
158 |
159 | if args.checkpoint is not None:
160 | if config['noDecom'] is False:
161 | pretrain_decom = torch.load('./weights/decom_net_test3.pth')
162 | decom_net.load_state_dict(pretrain_decom)
163 |             log('DecomNet loaded from decom_net_test3.pth')
164 | pretrain = torch.load('./weights/restore_net.pth')
165 | model.load_state_dict(pretrain)
166 | print('Model loaded from restore_net.pth')
167 |
168 | if config['noDecom'] is True:
169 | root_path_train = r'H:\datasets\Low-Light Dataset\LOLdataset_decom\our485'
170 | root_path_test = r'H:\datasets\Low-Light Dataset\LOLdataset_decom\eval15'
171 | list_path_train = build_LOLDataset_Decom_list_txt(root_path_train)
172 | list_path_test = build_LOLDataset_Decom_list_txt(root_path_test)
173 | # list_path_train = os.path.join(root_path_train, 'pair_list.csv')
174 | # list_path_test = os.path.join(root_path_test, 'pair_list.csv')
175 |
176 | log("Buliding LOL Dataset...")
177 | # transform = transforms.Compose([transforms.ToTensor()])
178 | dst_train = LOLDataset_Decom(root_path_train, list_path_train,
179 | crop_size=config['length'], to_RAM=True)
180 | dst_test = LOLDataset_Decom(root_path_test, list_path_test,
181 | crop_size=config['length'], to_RAM=True, training=False)
182 |
183 | train_loader = DataLoader(dst_train, batch_size = config['batch_size'], shuffle=True)
184 | test_loader = DataLoader(dst_test, batch_size=1)
185 |
186 | else:
187 | root_path_train = r'C:\DeepLearning\KinD_plus-master\LOLdataset\our485'
188 | root_path_test = r'C:\DeepLearning\KinD_plus-master\LOLdataset\eval15'
189 | list_path_train = build_LOLDataset_list_txt(root_path_train)
190 | list_path_test = build_LOLDataset_list_txt(root_path_test)
191 | # list_path_train = os.path.join(root_path_train, 'pair_list.csv')
192 | # list_path_test = os.path.join(root_path_test, 'pair_list.csv')
193 |
194 | log("Buliding LOL Dataset...")
195 | # transform = transforms.Compose([transforms.ToTensor()])
196 | dst_train = LOLDataset(root_path_train, list_path_train,
197 | crop_size=config['length'], to_RAM=True)
198 | dst_test = LOLDataset(root_path_test, list_path_test,
199 | crop_size=config['length'], to_RAM=True, training=False)
200 |
201 | train_loader = DataLoader(dst_train, batch_size = config['batch_size'], shuffle=True)
202 | test_loader = DataLoader(dst_test, batch_size=1)
203 |
204 | trainer = Restore_Trainer(config, train_loader, criterion, model,
205 | dataloader_test=test_loader, decom_net=decom_net)
206 |
207 | if args.mode == 'train':
208 | trainer.train()
209 | else:
210 | trainer.test()
--------------------------------------------------------------------------------
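
For reference, the trainer entry points only read a few keys from config.yaml (the file itself is not shown in this section). A minimal stand-in dict with illustrative values, on the assumption that BaseTrainer takes its epoch/learning-rate/frequency settings from the same file:

```python
# Illustrative stand-in for config.yaml; values are examples, not the
# repo's defaults. Only these keys appear in the entry points above:
config = {
    'noDecom': False,   # True: train from precomputed R/I decompositions
    'length': 384,      # random-crop size passed to the LOL datasets
    'batch_size': 4,    # training batch size
}
# BaseTrainer presumably also reads epochs, learning rate and the
# print/save frequencies from the same YAML (see base_trainer.py).
```
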
/test_your_pictures.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import numpy as np
5 | import time
6 | import yaml
7 | import sys, os
8 | from torch.optim import lr_scheduler
9 | from torchvision.utils import make_grid
10 | from torchvision import transforms
11 | from torchsummary import summary
12 | from base_trainer import BaseTrainer
13 | from losses import *
14 | from models import *
15 | from base_parser import BaseParser
16 | from dataloader import *
17 |
18 | class KinD_Player(BaseTrainer):
19 | def __init__(self, model, dataloader_test, plot_more=False):
20 | self.dataloader_test = dataloader_test
21 | self.model = model
22 | self.plot_more = plot_more
23 | self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
24 | self.model.to(device=self.device)
25 |
26 | @no_grad
27 | def test(self, target_b=0.70, plot_dir='./images/samples-KinD'):
28 | self.model.eval()
29 | self.model.to(device=self.device)
30 | for L_low_tensor, name in self.dataloader_test:
31 | L_low = L_low_tensor.to(self.device)
32 |
33 | if self.plot_more:
34 |                 # Use DecomNet to decompose the input into reflectance and illumination maps
35 | R_low, I_low = self.model.decom_net(L_low)
36 | # Compute brightness ratio
37 | bright_low = torch.mean(I_low)
38 | else:
39 | bright_low = torch.mean(L_low)
40 |
41 | bright_high = torch.ones_like(bright_low) * target_b + 0.5 * bright_low
42 | ratio = torch.div(bright_high, bright_low)
43 | log(f"Brightness: {bright_high:.4f}\tIllumation Magnification: {ratio.item():.3f}")
44 |
45 | R_final, I_final, output_final = self.model(L_low, ratio)
46 |
47 | output_final_np = output_final.detach().cpu().numpy()[0]
48 | L_low_np = L_low_tensor.numpy()[0]
49 | # Only plot result
50 | filepath = os.path.join(plot_dir, f'{name[0]}.png')
51 | split_point = [0, 3]
52 | img_dim = L_low_np.shape[1:]
53 | sample(output_final_np, split=split_point, figure_size=(1, 1),
54 | img_dim=img_dim, path=filepath)
55 |
56 | if self.plot_more:
57 | R_final_np = R_final.detach().cpu().numpy()[0]
58 | I_final_np = I_final.detach().cpu().numpy()[0]
59 | R_low_np = R_low.detach().cpu().numpy()[0]
60 | I_low_np = I_low.detach().cpu().numpy()[0]
61 |
62 | sample_imgs = np.concatenate( (R_low_np, I_low_np, L_low_np,
63 | R_final_np, I_final_np, output_final_np), axis=0 )
64 | filepath = os.path.join(plot_dir, f'{name[0]}_extra.png')
65 | split_point = [0, 3, 4, 7, 10, 11, 14]
66 | img_dim = L_low_np.shape[1:]
67 | sample(sample_imgs, split=split_point, figure_size=(2, 3),
68 | img_dim=img_dim, path=filepath)
69 |
70 |
71 | class TestParser(BaseParser):
72 | def parse(self):
73 | self.parser.add_argument("-p", "--plot_more", default=True,
74 | help="Plot intermediate variables. such as R_images and I_images")
75 | self.parser.add_argument("-c", "--checkpoint", default="./weights/",
76 | help="Path of checkpoints")
77 | self.parser.add_argument("-i", "--input_dir", default="./images/inputs/",
78 | help="Path of input pictures")
79 | self.parser.add_argument("-o", "--output_dir", default="./images/outputs/",
80 | help="Path of output pictures")
81 | self.parser.add_argument("-b", "--b_target", default=0.75, help="Target brightness")
82 | # self.parser.add_argument("-u", "--use_gpu", default=True,
83 | # help="If you want to use GPU to accelerate")
84 | return self.parser.parse_args()
85 |
86 |
87 | if __name__ == "__main__":
88 | model = KinD()
89 | parser = TestParser()
90 | args = parser.parse()
91 |
92 | input_dir = args.input_dir
93 | output_dir = args.output_dir
94 | plot_more = args.plot_more
95 | checkpoint = args.checkpoint
96 | decom_net_dir = os.path.join(checkpoint, "decom_net.pth")
97 | restore_net_dir = os.path.join(checkpoint, "restore_net.pth")
98 | illum_net_dir = os.path.join(checkpoint, "illum_net.pth")
99 |
100 | pretrain_decom = torch.load(decom_net_dir)
101 | model.decom_net.load_state_dict(pretrain_decom)
102 | log('Model loaded from decom_net.pth')
103 |     pretrain_restore = torch.load(restore_net_dir)
104 |     model.restore_net.load_state_dict(pretrain_restore)
105 | log('Model loaded from restore_net.pth')
106 | pretrain_illum = torch.load(illum_net_dir)
107 | model.illum_net.load_state_dict(pretrain_illum)
108 | log('Model loaded from illum_net.pth')
109 |
110 | log("Buliding Dataset...")
111 | dst = CustomDataset(input_dir)
112 | log(f"There are {len(dst)} images in the input direction...")
113 | dataloader = DataLoader(dst, batch_size=1)
114 |
115 |     player = KinD_Player(model, dataloader, plot_more=plot_more)
116 |
117 |     player.test(plot_dir=output_dir, target_b=args.b_target)
--------------------------------------------------------------------------------
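
A worked example of the brightness heuristic in KinD_Player.test above: the target illumination mixes the fixed `target_b` with the input's mean brightness, and their ratio is what IllumNet receives:

```python
import torch

target_b = 0.70
bright_low = torch.tensor(0.12)   # mean brightness of a dark input
bright_high = torch.ones_like(bright_low) * target_b + 0.5 * bright_low
ratio = torch.div(bright_high, bright_low)
print(f"{bright_high:.4f}  {ratio.item():.3f}")   # 0.7600  6.333
```
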
/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import random
4 | import matplotlib.pyplot as plt
5 | from PIL import Image
6 | import cv2
7 | import collections
8 | import torch
9 | import torchvision
10 | import shutil
11 | import time
12 |
13 |
14 | def log(string):
15 | print(time.strftime('%H:%M:%S'), ">> ", string)
16 |
17 | def data_augmentation(image, mode):
18 | if mode == 0:
19 | # original
20 | return image
21 | elif mode == 1:
22 | # flip up and down
23 | return np.flipud(image)
24 | elif mode == 2:
25 |         # rotate 90 degrees counterclockwise
26 | return np.rot90(image)
27 | elif mode == 3:
28 |         # rotate 90 degrees and flip up and down
29 | image = np.rot90(image)
30 | return np.flipud(image)
31 | elif mode == 4:
32 |         # rotate 180 degrees
33 | return np.rot90(image, k=2)
34 | elif mode == 5:
35 |         # rotate 180 degrees and flip
36 | image = np.rot90(image, k=2)
37 | return np.flipud(image)
38 | elif mode == 6:
39 |         # rotate 270 degrees
40 | return np.rot90(image, k=3)
41 | elif mode == 7:
42 |         # rotate 270 degrees and flip
43 | image = np.rot90(image, k=3)
44 | return np.flipud(image)
45 |
46 | # Decorator: run the wrapped function with gradients disabled
47 | def no_grad(fn):
48 |     def transfer(*args, **kwargs):
49 |         with torch.no_grad():
50 |             return fn(*args, **kwargs)
51 |     return transfer
52 |
53 |
54 | def load_weights(model, path):
55 | pretrained_dict=torch.load(path)
56 | model_dict=model.state_dict()
57 | # 1. filter out unnecessary keys
58 | pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
59 | # 2. overwrite entries in the existing state dict
60 | model_dict.update(pretrained_dict)
61 | model.load_state_dict(model_dict)
62 | return model
63 |
64 |
65 | class data_prefetcher():
66 | def __init__(self, loader):
67 | self.loader = iter(loader)
68 | self.stream = torch.cuda.Stream()
69 | self.preload()
70 |
71 | def preload(self):
72 | try:
73 | self.next_low, self.next_high, self.next_name = next(self.loader)
74 | except StopIteration:
75 | self.next_low = None
76 | self.next_high = None
77 | self.next_name = None
78 | return
79 | with torch.cuda.stream(self.stream):
80 | self.next_low = self.next_low.cuda(non_blocking=True)
81 | self.next_high = self.next_high.cuda(non_blocking=True)
82 |
83 | def __iter__(self):
84 | return self
85 |
86 | def __next__(self):
87 | torch.cuda.current_stream().wait_stream(self.stream)
88 | low = self.next_low
89 | high = self.next_high
90 | name = self.next_name
91 | self.preload()
92 | return low, high, name
93 |
94 | # def rgb2hsv(img):
95 | # if torch.is_tensor:
96 | # log(f'Image tensor size is {img.size()}')
97 | # else:
98 | # log("This Function can only deal PyTorch Tensor!")
99 | # return img
100 | # r, g, b = img.split(1, 0)
101 | # tensor_max = torch.max(torch.max(r, g), b)
102 | # tensor_min = torch.min(torch.min(r, g), b)
103 | # m = tensor_max-tensor_min
104 | # if tensor_max == tensor_min:
105 | # h = 0
106 | # elif tensor_max == r:
107 | # if g >= b:
108 | # h = ((g-b)/m)*60
109 | # else:
110 | # h = ((g-b)/m)*60 + 360
111 | # elif tensor_max == g:
112 | # h = ((b-r)/m)*60 + 120
113 | # elif tensor_max == b:
114 | # h = ((r-g)/m)*60 + 240
115 | # if tensor_max == 0:
116 | # s = 0
117 | # else:
118 | # s = m/tensor_max
119 | # v = tensor_max
120 | # return h, s, v
121 |
122 | def standard_illum(I, dynamic=2, w=0.5, gamma=None, blur=False):
123 | sigma = dynamic
124 | if torch.is_tensor(I):
125 | # I = torch.log(I + 1.)
126 | if blur:
127 | Gauss = torch.as_tensor(
128 | np.array([[0.0947416, 0.118318, 0.0947416],
129 | [ 0.118318, 0.147761, 0.118318],
130 | [0.0947416, 0.118318, 0.0947416]]).astype(np.float32)
131 | ).to(I.device)
132 | channels = I.size()[1]
133 | Gauss_kernel = Gauss.expand(channels, channels, 3, 3)
134 | I = torch.nn.functional.conv2d(I, weight=Gauss_kernel, padding=1)
135 | I_mean = torch.mean(I, dim=[2, 3], keepdim=True)
136 | I_std = torch.std(I, dim=[2, 3], keepdim=True)
137 | # I_max = torch.nn.AdaptiveMaxPool2d((1, 1))(I)
138 | # I_min = 1 - torch.nn.AdaptiveMaxPool2d((1, 1))(1-I)
139 | I_min = I_mean - sigma * I_std
140 | I_max = I_mean + sigma * I_std
141 | I_range = I_max - I_min
142 | I_out = torch.clamp((I - I_min) / I_range, min=0.0, max=1.0)
143 | # if gamma is not None:
144 | # return I**gamma
145 | w = torch.as_tensor(np.array(w).astype(np.float32)).to(I.device)
146 | I_out = I_out.pow(-1.442695 * torch.log(w))
147 |         # print(-1.442695 * torch.log(w))  # debug: effective gamma = log2(1/w)
148 |
149 | else:
150 | I = np.log(I + 1.)
151 | I_mean = np.mean(I)
152 | I_std = np.std(I)
153 | I_min = I_mean - sigma * I_std
154 | I_max = I_mean + sigma * I_std
155 | I_range = I_max - I_min
156 | I_out = np.clip((I - I_min) / I_range, 0.0, 1.0)
157 |
158 | return I_out
159 |
160 |
161 | def sample(imgs, split=None, figure_size=(2, 3), img_dim=(400, 600), path=None, num=0):
162 | if type(img_dim) is int:
163 | img_dim = (img_dim, img_dim)
164 | img_dim = tuple(img_dim)
165 |     if len(img_dim) == 1:
166 |         h_dim = img_dim[0]
167 |         w_dim = img_dim[0]
168 | elif len(img_dim) == 2:
169 | h_dim, w_dim = img_dim
170 | h, w = figure_size
171 | if split is None:
172 | num_of_imgs = figure_size[0] * figure_size[1]
173 | gap = len(imgs) // num_of_imgs
174 | split = list(range(0, len(imgs)+1, gap))
175 | figure = np.zeros((h_dim*h, w_dim*w, 3))
176 | for i in range(h):
177 | for j in range(w):
178 | idx = i*w+j
179 | if idx >= len(split)-1: break
180 | digit = imgs[ split[idx] : split[idx+1] ]
181 | if len(digit) == 1:
182 | for k in range(3):
183 | figure[i*h_dim: (i+1)*h_dim,
184 | j*w_dim: (j+1)*w_dim, k] = digit
185 | elif len(digit) == 3:
186 | for k in range(3):
187 | figure[i*h_dim: (i+1)*h_dim,
188 | j*w_dim: (j+1)*w_dim, k] = digit[2-k]
189 | if path is None:
190 | cv2.imshow('Figure%d'%num, figure)
191 | cv2.waitKey()
192 | else:
193 | figure *= 255
194 | filename1 = path.split('\\')[-1]
195 | filename2 = path.split('/')[-1]
196 | if len(filename1) < len(filename2):
197 | filename = filename1
198 | else:
199 | filename = filename2
200 | root_path = path[:-len(filename)]
201 | if not os.path.exists(root_path):
202 | os.makedirs(root_path)
203 | log("Saving Image at {}".format(path))
204 | cv2.imwrite(path, figure)
--------------------------------------------------------------------------------
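
A quick check of the no_grad decorator defined above: the wrapped call runs with autograd disabled, so its outputs do not track gradients:

```python
import torch
from utils import no_grad

@no_grad
def infer(model, x):
    return model(x)

m = torch.nn.Linear(4, 2)
y = infer(m, torch.rand(1, 4))
print(y.requires_grad)   # False: the forward ran inside torch.no_grad()
```
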
/utils/img_generator.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import numpy as np
3 |
4 | def sample(imgs, split=None, figure_size=(2, 3), img_dim=96, path=None, num=0):
5 | h, w = figure_size
6 | if split is None:
7 | split = range(len(imgs)+1)
8 | figure = np.zeros((img_dim*h, img_dim*w, 3))
9 | for i in range(h):
10 | for j in range(w):
11 | idx = i*w+j
12 | if idx >= len(split)-1: break
13 | digit = imgs[ split[idx] : split[idx+1] ]
14 | if len(digit) == 1:
15 | for k in range(3):
16 | figure[i*img_dim: (i+1)*img_dim,
17 | j*img_dim: (j+1)*img_dim, k] = digit
18 | elif len(digit) == 3:
19 | for k in range(3):
20 | figure[i*img_dim: (i+1)*img_dim,
21 | j*img_dim: (j+1)*img_dim, k] = digit[2-k]
22 | if path is None:
23 | cv2.imshow('Figure%d'%num, figure)
24 | cv2.waitKey()
25 | else:
26 | figure *= 255
27 | print(">> Saving Image at {}".format(path))
28 | cv2.imwrite(path, figure)
--------------------------------------------------------------------------------
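
A tiling demo for this lightweight sample() variant (run from inside the utils/ directory so img_generator imports directly; the output filename is illustrative):

```python
import numpy as np
from img_generator import sample

imgs = np.random.rand(6, 96, 96)   # six single-channel maps in [0, 1]
# the default split places one map per cell of the 2x3 grid and writes a PNG
sample(imgs, figure_size=(2, 3), img_dim=96, path='demo_grid.png')
```
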