├── .gitattributes ├── .gitignore ├── README.md ├── base_layers.py ├── base_parser.py ├── base_trainer.py ├── checkpoints ├── decom_net.pth ├── illum_net.pth └── restore_net.pth ├── config.yaml ├── dataloader.py ├── decom_trainer.py ├── evaluate.py ├── figures ├── 778_epoch_-1.png ├── failure_case_decom.png └── official_decom_train.png ├── illum_trainer.py ├── illum_trainer_custom.py ├── losses.py ├── models.py ├── pytorch_ssim └── __init__.py ├── restore_MSIA_trainer.py ├── restore_trainer.py ├── test_your_pictures.py ├── utils.py └── utils └── img_generator.py /.gitattributes: -------------------------------------------------------------------------------- 1 | # Auto detect text files and perform LF normalization 2 | * text=auto 3 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # custom ignore 2 | *.h5 3 | *.jpg 4 | *.zip 5 | weights/ 6 | images/ 7 | 8 | # Byte-compiled / optimized / DLL files 9 | __pycache__/ 10 | *.py[cod] 11 | *$py.class 12 | 13 | # C extensions 14 | *.so 15 | 16 | # Distribution / packaging 17 | .Python 18 | build/ 19 | develop-eggs/ 20 | dist/ 21 | downloads/ 22 | eggs/ 23 | .eggs/ 24 | lib/ 25 | lib64/ 26 | parts/ 27 | sdist/ 28 | var/ 29 | wheels/ 30 | pip-wheel-metadata/ 31 | share/python-wheels/ 32 | *.egg-info/ 33 | .installed.cfg 34 | *.egg 35 | MANIFEST 36 | 37 | # PyInstaller 38 | # Usually these files are written by a python script from a template 39 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 40 | *.manifest 41 | *.spec 42 | 43 | # Installer logs 44 | pip-log.txt 45 | pip-delete-this-directory.txt 46 | 47 | # Unit test / coverage reports 48 | htmlcov/ 49 | .tox/ 50 | .nox/ 51 | .coverage 52 | .coverage.* 53 | .cache 54 | nosetests.xml 55 | coverage.xml 56 | *.cover 57 | *.py,cover 58 | .hypothesis/ 59 | .pytest_cache/ 60 | 61 | # Translations 62 | *.mo 63 | *.pot 64 | 65 | # Django stuff: 66 | *.log 67 | local_settings.py 68 | db.sqlite3 69 | db.sqlite3-journal 70 | 71 | # Flask stuff: 72 | instance/ 73 | .webassets-cache 74 | 75 | # Scrapy stuff: 76 | .scrapy 77 | 78 | # Sphinx documentation 79 | docs/_build/ 80 | 81 | # PyBuilder 82 | target/ 83 | 84 | # Jupyter Notebook 85 | .ipynb_checkpoints 86 | 87 | # IPython 88 | profile_default/ 89 | ipython_config.py 90 | 91 | # pyenv 92 | .python-version 93 | 94 | # pipenv 95 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 96 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 97 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 98 | # install all needed dependencies. 99 | #Pipfile.lock 100 | 101 | # celery beat schedule file 102 | celerybeat-schedule 103 | 104 | # SageMath parsed files 105 | *.sage.py 106 | 107 | # Environments 108 | .env 109 | .venv 110 | env/ 111 | venv/ 112 | ENV/ 113 | env.bak/ 114 | venv.bak/ 115 | 116 | # Spyder project settings 117 | .spyderproject 118 | .spyproject 119 | 120 | # Rope project settings 121 | .ropeproject 122 | 123 | # mkdocs documentation 124 | /site 125 | 126 | # mypy 127 | .mypy_cache/ 128 | .dmypy.json 129 | dmypy.json 130 | 131 | # Pyre type checker 132 | .pyre/ 133 | *.json 134 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # KinD-pytorch 2 | This is a PyTorch implementation of KinD. 3 | 4 | The official KinD project(TensorFlow) is [KinD](https://github.com/zhangyhuaee/KinD). 5 | 6 | The KinD net was proposed in the following [Paper](http://doi.acm.org/10.1145/3343031.3350926). 7 | 8 | Kindling the Darkness: a Practical Low-light Image Enhancer. In ACM MM 2019
9 | Yonghua Zhang, Jiawan Zhang, Xiaojie Guo 10 | **** 11 | 12 | ## Enviroment ## 13 | 1. Python = 3.6 14 | 2. PyTorch = 1.2.0 15 | 3. Other common packages 16 | 17 | ## Test ## 18 | Please put test images into './images/inputs' folder and download the pre-trained checkpoints from [BaiduNetDisk](https://pan.baidu.com/s/1e_P6_qxQqAwDG7q6NN_2ng), 提取码:fxkl, then just run 19 | ```shell 20 | python test_your_picture.py 21 | # -i: change input path, -o: change output path, -p: Plot more information 22 | # -c: change checkpoints path, -b: change default target brightness 23 | ``` 24 | 25 | ## Train ## 26 | The original LOLdataset can be downloaded from [here](https://daooshee.github.io/BMVC2018website/). 27 | For training, **please change the dataset path in the code**, then run 28 | ```shell 29 | python decom_trainer.py 30 | python illum_trainer.py 31 | python restore_trainer.py 32 | ``` 33 | You can also evaluate on the LOLdataset, **please change the dataset path in the code**, then run 34 | ```shell 35 | python evaluate_LOLdataset.py 36 | ``` 37 | 38 | ## Problems ## 39 | I meet some serious problems when I try to train the decomposition net, which makes results look unpleasant. 40 | 41 | ### My PyTorch implementation's evaluation on LOLDataset: ### 42 | 43 | The problem that confuses me the most is the illumination_smoothness_loss. As long as I add this loss, my illuminance map output will tend to be completely black (low light map) and close to gray (high light map). 44 | 45 | ### My PyTorch implementation's failure case on LOLDataset: ### 46 | 47 | I have run the official TensorFlow code to train and test the decompostion net. The result is pretty strange. If I load official checkpoints to test it, it will perform well. However, if I use LOLDataset to train it, it will be worse and worse. I am really puzzled about this issue. If you have any idea about it, please tell me. 48 | 49 | I show the example below. 50 | ### Official implementation's strange case on LOLDataset: ### 51 | 52 | The left column shows the decomposition results of the high light map, and the right column shows the decomposition results of the low light map. The left side of each image is a reflection map, and the right side is a illumination map. The first line is the effect of the official weight, the second line is the effect when the official code is retrained by 100 epoch, and the third line is the result of the official code training of 1600 epoch. The training code cannot achieve the effect of official weights. 53 | 54 | Other test results on LOLDataset(eval15) can be found at the samples-KinD in the [BaiduNetDisk](https://pan.baidu.com/s/1e_P6_qxQqAwDG7q6NN_2ng), 提取码:fxkl. 55 | 56 | ## References ## 57 | [1] Y. Zhang, J. Zhang, and X. Guo, “Kindling the darkness: A practical low-light image enhancer,” in ACM MM, 2019, pp. 1632–1640. 58 | -------------------------------------------------------------------------------- /base_layers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | class MSIA(nn.Module): 6 | def __init__(self, filters, activation='lrelu'): 7 | super().__init__() 8 | # Down 1 9 | self.conv_bn_relu_1 = Conv_BN_Relu(filters, activation) 10 | # Down 2 11 | self.down_2 = MaxPooling2D(2, 2) 12 | self.conv_bn_relu_2 = Conv_BN_Relu(filters, activation) 13 | self.deconv_2 = ConvTranspose2D(filters, filters) 14 | # Down 4 15 | self.down_4 = MaxPooling2D(2, 2) 16 | self.conv_bn_relu_4 = Conv_BN_Relu(filters, activation, kernel=1) 17 | self.deconv_4_1 = ConvTranspose2D(filters, filters) 18 | self.deconv_4_2 = ConvTranspose2D(filters, filters) 19 | # output 20 | self.out = Conv2D(filters*4, filters) 21 | 22 | def forward(self, R, I_att): 23 | R_att = R * I_att 24 | # Down 1 25 | msia_1 = self.conv_bn_relu_1(R_att) 26 | # Down 2 27 | down_2 = self.down_2(R_att) 28 | conv_bn_relu_2 = self.conv_bn_relu_2(down_2) 29 | msia_2 = self.deconv_2(conv_bn_relu_2) 30 | # Down 4 31 | down_4 = self.down_4(down_2) 32 | conv_bn_relu_4 = self.conv_bn_relu_4(down_4) 33 | deconv_4 = self.deconv_4_1(conv_bn_relu_4) 34 | msia_4 = self.deconv_4_2(deconv_4) 35 | # concat 36 | concat = torch.cat([R, msia_1, msia_2, msia_4], dim=1) 37 | out = self.out(concat) 38 | return out 39 | 40 | 41 | class Conv_BN_Relu(nn.Module): 42 | def __init__(self, channels, activation='lrelu', kernel=3): 43 | super().__init__() 44 | self.ActivationLayer = nn.LeakyReLU(inplace=True) 45 | if activation == 'relu': 46 | self.ActivationLayer = nn.ReLU(inplace=True) 47 | self.conv_bn_relu = nn.Sequential( 48 | nn.Conv2d(channels, channels, kernel_size=kernel, padding=kernel//2), 49 | nn.BatchNorm2d(channels, momentum=0.99), # 原论文用的tf.layer的默认参数 50 | self.ActivationLayer, 51 | ) 52 | 53 | def forward(self, x): 54 | return self.conv_bn_relu(x) 55 | 56 | 57 | class DoubleConv(nn.Module): 58 | def __init__(self, in_channels, out_channels, activation='lrelu'): 59 | super().__init__() 60 | self.doubleconv = nn.Sequential( 61 | Conv2D(in_channels, out_channels, activation), 62 | Conv2D(out_channels,out_channels, activation) 63 | ) 64 | 65 | def forward(self, x): 66 | return self.doubleconv(x) 67 | 68 | class ResConv(nn.Module): 69 | def __init__(self, in_channels, out_channels, activation='lrelu'): 70 | super().__init__() 71 | self.relu = nn.LeakyReLU(0.2, inplace=True) 72 | if activation == 'relu': 73 | self.relu = nn.ReLU(inplace=True) 74 | 75 | self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1) 76 | self.bn1 = nn.BatchNorm2d(out_channels, momentum=0.8) 77 | self.cbam = CBAM(out_channels) 78 | self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1) 79 | self.bn2 = nn.BatchNorm2d(out_channels, momentum=0.8) 80 | 81 | def forward(self, x): 82 | conv1 = self.conv1(x) 83 | bn1 = self.bn1(conv1) 84 | x1 = self.relu(bn1) 85 | cbam = self.cbam(x1) 86 | conv2 = self.conv2(cbam) 87 | bn2 = self.bn1(conv2) 88 | out = bn2 + x 89 | return out 90 | 91 | class Conv2D(nn.Module): 92 | def __init__(self, in_channels, out_channels, activation='lrelu', stride=1): 93 | super().__init__() 94 | self.ActivationLayer = nn.LeakyReLU(inplace=True) 95 | if activation == 'relu': 96 | self.ActivationLayer = nn.ReLU(inplace=True) 97 | self.conv_relu = nn.Sequential( 98 | nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1), 99 | self.ActivationLayer, 100 | ) 101 | 102 | def forward(self, x): 103 | return self.conv_relu(x) 104 | 105 | 106 | class ConvTranspose2D(nn.Module): 107 | def __init__(self, in_channels, out_channels, activation='lrelu'): 108 | super().__init__() 109 | self.ActivationLayer = nn.LeakyReLU(inplace=True) 110 | if activation == 'relu': 111 | self.ActivationLayer = nn.ReLU(inplace=True) 112 | self.deconv_relu = nn.Sequential( 113 | nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2, padding=0), 114 | self.ActivationLayer, 115 | ) 116 | 117 | def forward(self, x): 118 | return self.deconv_relu(x) 119 | 120 | 121 | class MaxPooling2D(nn.Module): 122 | def __init__(self, kernel_size=2, stride=2): 123 | super().__init__() 124 | self.maxpool = nn.MaxPool2d(kernel_size=kernel_size, stride=stride) 125 | 126 | def forward(self, x): 127 | return self.maxpool(x) 128 | 129 | 130 | class AvgPooling2D(nn.Module): 131 | def __init__(self, kernel_size=2, stride=2): 132 | super().__init__() 133 | self.avgpool = nn.AvgPool2d(kernel_size=2, stride=2) 134 | 135 | def forward(self, x): 136 | return self.avgpool(x) 137 | 138 | 139 | class ChannelAttention(nn.Module): 140 | def __init__(self, in_planes, ratio=16): 141 | super().__init__() 142 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 143 | self.max_pool = nn.AdaptiveMaxPool2d(1) 144 | 145 | self.sharedMLP = nn.Sequential( 146 | nn.Conv2d(in_planes, in_planes // ratio, 1, bias=False), nn.ReLU(), 147 | nn.Conv2d(in_planes // ratio, in_planes, 1, bias=False)) 148 | self.sigmoid = nn.Sigmoid() 149 | 150 | def forward(self, x): 151 | avgout = self.sharedMLP(self.avg_pool(x)) 152 | maxout = self.sharedMLP(self.max_pool(x)) 153 | return self.sigmoid(avgout + maxout) 154 | 155 | 156 | class SpatialAttention(nn.Module): 157 | def __init__(self, kernel_size=3): 158 | super().__init__() 159 | self.conv = nn.Conv2d(2,1,kernel_size, padding=1, bias=False) 160 | self.sigmoid = nn.Sigmoid() 161 | 162 | def forward(self, x): 163 | avgout = torch.mean(x, dim=1, keepdim=True) 164 | maxout, _ = torch.max(x, dim=1, keepdim=True) 165 | x = torch.cat([avgout, maxout], dim=1) 166 | x = self.conv(x) 167 | return self.sigmoid(x) 168 | 169 | 170 | class CBAM(nn.Module): 171 | def __init__(self, planes): 172 | super().__init__() 173 | self.ca = ChannelAttention(planes) 174 | self.sa = SpatialAttention() 175 | def forward(self, x): 176 | x = self.ca(x) * x 177 | out = self.sa(x) * x 178 | return x 179 | 180 | 181 | class Concat(nn.Module): 182 | def forward(self, x, y): 183 | _, _, xh, xw = x.size() 184 | _, _, yh, yw = y.size() 185 | diffY = xh - yh 186 | diffX = xw - yw 187 | y = F.pad(y, (diffX // 2, diffX - diffX//2, 188 | diffY // 2, diffY - diffY//2)) 189 | return torch.cat((x, y), dim=1) -------------------------------------------------------------------------------- /base_parser.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | class BaseParser(): 4 | def __init__(self): 5 | self.parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) 6 | 7 | def parse(self): 8 | self.parser.add_argument("--mode", default="train", choices=["train", "test"]) 9 | self.parser.add_argument("--config", default="./config.yaml", help="path to config") 10 | self.parser.add_argument("--checkpoint", default=True,help="path to checkpoint to restore") 11 | return self.parser.parse_args() 12 | -------------------------------------------------------------------------------- /base_trainer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import optim 3 | import os 4 | import sys 5 | import time 6 | import numpy as np 7 | import matplotlib.pyplot as plt 8 | import cv2 9 | from torchsummary import summary 10 | 11 | class BaseTrainer: 12 | def __init__(self, config, dataloader, criterion, model, 13 | dataloader_test=None, extra_model=None): 14 | self.initialize(config) 15 | self.dataloader = dataloader 16 | self.dataloader_test = dataloader_test 17 | self.loss_fn = criterion 18 | self.model = model 19 | self.extra_model = extra_model 20 | self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 21 | self.model.to(device=self.device) 22 | # faster convolutions, but more memory 23 | if self.device == 'cuda': 24 | torch.backends.cudnn.benchmark=True 25 | 26 | def initialize(self, config): 27 | self.batch_size = config['batch_size'] 28 | self.length = config['length'] 29 | self.epochs = config['epochs'] 30 | self.steps_per_epoch = config['steps_per_epoch'] 31 | self.print_frequency = config['print_frequency'] 32 | self.save_frequency = config['save_frequency'] 33 | self.weights_dir = config['weights_dir'] 34 | self.samples_dir = config['samples_dir'] # './logs/samples' 35 | self.learning_rate = config['learning_rate'] 36 | self.noDecom = config['noDecom'] 37 | 38 | def train(self): 39 | print(f'Using device {self.device}') 40 | summary(self.model, input_size=(3, 48, 48)) 41 | 42 | self.model.to(device=self.device) 43 | # faster convolutions, but more memory 44 | cudnn.benchmark = True 45 | 46 | optimizer = optim.Adam(self.model.parameters(), lr=self.learning_rate) 47 | try: 48 | for iter in range(self.epochs): 49 | epoch_loss = 0 50 | steps = 0 51 | iter_start_time = time.time() 52 | for idx, data in enumerate(self.dataloader): 53 | input_ = data['input'] 54 | input_ = input_.to(self.device) 55 | target = data['target'] 56 | target = target.to(self.device) 57 | y_pred = self.model(input_) 58 | loss = self.loss_fn(y_pred, target) 59 | print("iter: ", idx, "average_loss: ", loss.item()) 60 | optimizer.zero_grad() 61 | loss.backward() 62 | optimizer.step() 63 | steps += 1 64 | if idx > 0 and idx % self.save_frequency == 0: 65 | # torch.save(self.model.state_dict(), './checkpoints/g_net_{}.pth'.format(str(idx % 3))) 66 | print('Saved model.') 67 | self.test(iter, idx, plotImage=True, saveImage=True) 68 | iter_end_time = time.time() 69 | print("End of epochs {}, Time taken: {},average loss: {}".format(iter, iter_end_time - iter_start_time, epoch_loss / steps)) 70 | iter_end_time = time.time() 71 | print("End of epochs {}, Time taken: {.3f}, average loss: {.5f}".format(iter, iter_end_time - iter_start_time, epoch_loss / steps)) 72 | except KeyboardInterrupt: 73 | torch.save(self.model.state_dict(), 'INTERRUPTED.pth') 74 | print('Saved interrupt') 75 | try: 76 | sys.exit(0) 77 | except SystemExit: 78 | os._exit(0) 79 | 80 | def test(self, epoch=-1, plot_dir='./images/samples-illum'): 81 | self.model.eval() 82 | for R_low_tensor, I_low_tensor, R_high_tensor, I_high_tensor, name in self.dataloader_test: 83 | I_low = I_low_tensor.to(self.device) 84 | I_high = I_high_tensor.to(self.device) 85 | with torch.no_grad(): 86 | ratio_high2low = torch.mean(torch.div((I_low + 0.0001), (I_high + 0.0001))) 87 | ratio_low2high = torch.mean(torch.div((I_high + 0.0001), (I_low + 0.0001))) 88 | ratio_high2low_map = torch.ones_like(I_low) * ratio_high2low 89 | ratio_low2high_map = torch.ones_like(I_low) * ratio_low2high 90 | 91 | I_low2high_map = self.model(I_low, ratio_low2high_map) 92 | I_high2low_map = self.model(I_high, ratio_high2low_map) 93 | 94 | I_low2high_np = I_low2high_map.detach().cpu().numpy()[0] 95 | I_high2low_np = I_high2low_map.detach().cpu().numpy()[0] 96 | I_low_np = I_low_tensor.numpy()[0] 97 | I_high_np = I_high_tensor.numpy()[0] 98 | sample_imgs = np.concatenate( (I_low_np, I_high_np, I_high2low_np, I_low2high_np), axis=0 ) 99 | filepath = os.path.join(plot_dir, f'{name}_epoch_{epoch}.png') 100 | split_point = [0, 1, 2, 3, 4] 101 | sample(sample_imgs, split=split_point, figure_size=(2, 2), 102 | img_dim=self.length, path=filepath, num=epoch) -------------------------------------------------------------------------------- /checkpoints/decom_net.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fenghansen/KinD-pytorch/815d26f538221ba5b59799f5439d491579f6bd2d/checkpoints/decom_net.pth -------------------------------------------------------------------------------- /checkpoints/illum_net.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fenghansen/KinD-pytorch/815d26f538221ba5b59799f5439d491579f6bd2d/checkpoints/illum_net.pth -------------------------------------------------------------------------------- /checkpoints/restore_net.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fenghansen/KinD-pytorch/815d26f538221ba5b59799f5439d491579f6bd2d/checkpoints/restore_net.pth -------------------------------------------------------------------------------- /config.yaml: -------------------------------------------------------------------------------- 1 | batch_size: 32 2 | length: 256 3 | epochs: 401 4 | steps_per_epoch: 128 5 | 6 | print_frequency: 10 7 | save_frequency: 10 8 | 9 | samples_dir: './samples' 10 | weights_dir: './checkpoints' 11 | 12 | learning_rate: 0.0004 13 | checkpoints: True 14 | 15 | noDecom: False -------------------------------------------------------------------------------- /dataloader.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import random 4 | import matplotlib.pyplot as plt 5 | import collections 6 | import torch 7 | import torchvision 8 | import cv2 9 | import shutil 10 | import time 11 | from PIL import Image 12 | import torchvision.transforms as transforms 13 | from torch.utils.data import Dataset, DataLoader 14 | from utils import * 15 | 16 | 17 | class CustomDataset(Dataset): 18 | def __init__(self, datapath): 19 | super().__init__() 20 | self.datapath = datapath 21 | self.img_path = [os.path.join(datapath, f) for f in os.listdir(datapath) if 22 | any(filetype in f.lower() for filetype in ['jpeg', 'png', 'jpg', 'bmp'])] 23 | self.name = [f.split(".")[0] for f in os.listdir(datapath) if any(filetype in 24 | f.lower() for filetype in ['jpeg', 'png', 'jpg', 'bmp'])] 25 | 26 | def __len__(self): 27 | return len(self.img_path) 28 | 29 | def __getitem__(self, idx): 30 | datafiles = self.img_path[idx] 31 | img = Image.open(datafiles).convert('RGB') 32 | img = np.asarray(img, np.float32).transpose((2,0,1)) / 255. 33 | return img, self.name[idx] 34 | 35 | 36 | class LOLDataset(Dataset): 37 | def __init__(self, root, list_path, crop_size=256, to_RAM=False, training=True): 38 | super(LOLDataset,self).__init__() 39 | self.training = training 40 | self.to_RAM = to_RAM 41 | self.root = root 42 | self.list_path = list_path 43 | self.crop_size = crop_size 44 | with open(list_path) as f: 45 | self.pairs = f.readlines() 46 | self.files = [] 47 | for pair in self.pairs: 48 | lr_path, hr_path = pair.split(",") 49 | hr_path = hr_path[:-1] 50 | name = lr_path.split("\\")[-1][:-4] 51 | lr_file = os.path.join(self.root, lr_path) 52 | hr_file = os.path.join(self.root, hr_path) 53 | self.files.append({ 54 | "lr": lr_file, 55 | "hr": hr_file, 56 | "name": name 57 | }) 58 | self.data = [] 59 | if self.to_RAM: 60 | for i, fileinfo in enumerate(self.files): 61 | name = fileinfo["name"] 62 | lr_img = Image.open(fileinfo["lr"]) 63 | hr_img = Image.open(fileinfo["hr"]) 64 | self.data.append({ 65 | "lr": lr_img, 66 | "hr": hr_img, 67 | "name": name 68 | }) 69 | log("Finish loading all images to RAM...") 70 | 71 | def __len__(self): 72 | return len(self.files) 73 | 74 | def __getitem__(self, idx): 75 | datafiles = self.files[idx] 76 | 77 | '''load the datas''' 78 | if not self.to_RAM: 79 | name = datafiles["name"] 80 | lr_img = Image.open(datafiles["lr"]) 81 | hr_img = Image.open(datafiles["hr"]) 82 | else: 83 | name = self.data[idx]["name"] 84 | lr_img = self.data[idx]["lr"] 85 | hr_img = self.data[idx]["hr"] 86 | 87 | '''random crop the inputs''' 88 | if self.crop_size > 0: 89 | 90 | #select a random start-point for croping operation 91 | h_offset = random.randint(0, lr_img.size[1] - self.crop_size) 92 | w_offset = random.randint(0, lr_img.size[0] - self.crop_size) 93 | #crop the image and the label 94 | crop_box = (w_offset, h_offset, w_offset+self.crop_size, h_offset+self.crop_size) 95 | lr_crop = lr_img 96 | hr_crop = hr_img 97 | if self.training is True: 98 | lr_crop = lr_img.crop(crop_box) 99 | hr_crop = hr_img.crop(crop_box) 100 | rand_mode = np.random.randint(0, 7) 101 | lr_crop = data_augmentation(lr_crop, rand_mode) 102 | hr_crop = data_augmentation(hr_crop, rand_mode) 103 | 104 | 105 | '''convert PIL Image to numpy array''' 106 | lr_crop = np.asarray(lr_crop, np.float32).transpose((2,0,1)) / 255. 107 | hr_crop = np.asarray(hr_crop, np.float32).transpose((2,0,1)) / 255. 108 | return lr_crop, hr_crop, name 109 | 110 | 111 | class LOLDataset_Decom(Dataset): 112 | def __init__(self, root, list_path, 113 | crop_size=256, to_RAM=False, training=True): 114 | super().__init__() 115 | self.training = training 116 | self.to_RAM = to_RAM 117 | self.root = root 118 | self.list_path = list_path 119 | self.crop_size = crop_size 120 | with open(list_path) as f: 121 | self.pairs = f.readlines() 122 | self.files = [] 123 | for pair in self.pairs: 124 | lr_path_R, lr_path_I, hr_path_R, hr_path_I = pair.split(",") 125 | hr_path_I = hr_path_I[:-1] 126 | name = lr_path_R.split("\\")[-1][:-4] 127 | lr_file_R = os.path.join(self.root, lr_path_R) 128 | lr_file_I = os.path.join(self.root, lr_path_I) 129 | hr_file_R = os.path.join(self.root, hr_path_R) 130 | hr_file_I = os.path.join(self.root, hr_path_I) 131 | self.files.append({ 132 | "lr_R": lr_file_R, 133 | "lr_I": lr_file_I, 134 | "hr_R": hr_file_R, 135 | "hr_I": hr_file_I, 136 | "name": name 137 | }) 138 | self.data = [] 139 | if self.to_RAM: 140 | for i, fileinfo in enumerate(self.files): 141 | name = fileinfo["name"] 142 | lr_img_R = Image.open(fileinfo["lr_R"]) 143 | hr_img_R = Image.open(fileinfo["hr_R"]) 144 | lr_img_I = Image.open(fileinfo["lr_I"]).convert('L') 145 | hr_img_I = Image.open(fileinfo["hr_I"]).convert('L') 146 | self.data.append({ 147 | "lr_R": lr_img_R, 148 | "lr_I": lr_img_I, 149 | "hr_R": hr_img_R, 150 | "hr_I": hr_img_I, 151 | "name": name 152 | }) 153 | log("Finish loading all images to RAM...") 154 | 155 | def __len__(self): 156 | return len(self.files) 157 | 158 | def __getitem__(self, idx): 159 | datafiles = self.files[idx] 160 | 161 | '''load the datas''' 162 | if not self.to_RAM: 163 | name = datafiles["name"] 164 | lr_img_R = Image.open(datafiles["lr_R"]) 165 | hr_img_R = Image.open(datafiles["hr_R"]) 166 | lr_img_I = Image.open(datafiles["lr_I"]).convert('L') 167 | hr_img_I = Image.open(datafiles["hr_I"]).convert('L') 168 | else: 169 | name = self.data[idx]["name"] 170 | lr_img_R = self.data[idx]["lr_R"] 171 | lr_img_I = self.data[idx]["lr_I"] 172 | hr_img_R = self.data[idx]["hr_R"] 173 | hr_img_I = self.data[idx]["hr_I"] 174 | 175 | 176 | '''random crop the inputs''' 177 | if self.crop_size > 0: 178 | 179 | #select a random start-point for croping operation 180 | h_offset = random.randint(0, lr_img_R.size[1] - self.crop_size) 181 | w_offset = random.randint(0, lr_img_R.size[0] - self.crop_size) 182 | #crop the image and the label 183 | crop_box = (w_offset, h_offset, w_offset+self.crop_size, h_offset+self.crop_size) 184 | lr_crop_R = lr_img_R 185 | lr_crop_I = lr_img_I 186 | hr_crop_R = hr_img_R 187 | hr_crop_I = hr_img_I 188 | if self.training is True: 189 | lr_crop_R = lr_crop_R.crop(crop_box) 190 | lr_crop_I = lr_crop_I.crop(crop_box) 191 | hr_crop_R = hr_crop_R.crop(crop_box) 192 | hr_crop_I = hr_crop_I.crop(crop_box) 193 | rand_mode = np.random.randint(0, 7) 194 | lr_crop_R = data_augmentation(lr_crop_R, rand_mode) 195 | lr_crop_I = data_augmentation(lr_crop_I, rand_mode) 196 | hr_crop_R = data_augmentation(hr_crop_R, rand_mode) 197 | hr_crop_I = data_augmentation(hr_crop_I, rand_mode) 198 | 199 | 200 | '''convert PIL Image to numpy array''' 201 | lr_crop_R = np.asarray(lr_crop_R, np.float32).transpose((2,0,1)) / 255. 202 | lr_crop_I = np.expand_dims(np.asarray(lr_crop_I, np.float32) , axis=0) / 255. 203 | hr_crop_R = np.asarray(hr_crop_R, np.float32).transpose((2,0,1)) / 255. 204 | hr_crop_I = np.expand_dims(np.asarray(hr_crop_I, np.float32) , axis=0) / 255. 205 | return lr_crop_R, lr_crop_I, hr_crop_R, hr_crop_I, name 206 | 207 | 208 | def build_LOLDataset_list_txt(dst_dir): 209 | log(f"Buliding LOLDataset list text at {dst_dir}") 210 | lr_dir = os.path.join(dst_dir, 'low') 211 | hr_dir = os.path.join(dst_dir, 'high') 212 | img_lr_path = [os.path.join('low', name) for name in os.listdir(lr_dir)] 213 | img_hr_path = [os.path.join('high', name) for name in os.listdir(hr_dir)] 214 | list_path = os.path.join(dst_dir, 'pair_list.csv') 215 | with open(list_path, 'w') as f: 216 | for lr_path, hr_path in zip(img_lr_path, img_hr_path): 217 | f.write(f"{lr_path},{hr_path}\n") 218 | log(f"Finish... There are {len(img_lr_path)} pairs...") 219 | return list_path 220 | 221 | 222 | def build_LOLDataset_Decom_list_txt(dst_dir): 223 | log(f"Buliding LOLDataset Decom list text at {dst_dir}") 224 | dir_lists = [] 225 | tail = ['low\\R', 'low\\I', 'high\\R', 'high\\I'] 226 | for t in tail: 227 | dir_lists.append(os.path.join(dst_dir, t)) 228 | imgs_path = [[],[],[],[]] 229 | for i, direction in enumerate(dir_lists): 230 | for name in os.listdir(direction): 231 | path = os.path.join(tail[i], name) 232 | imgs_path[i].append(path) 233 | list_path = os.path.join(dst_dir, 'pair_list.csv') 234 | with open(list_path, 'w') as f: 235 | for lr_R, lr_I, hr_R, hr_I in zip(*imgs_path): 236 | f.write(f"{lr_R},{lr_I},{hr_R},{hr_I}\n") 237 | log(f"Finish... There are {len(imgs_path[0])} pairs...") 238 | return list_path 239 | 240 | 241 | def divide_dataset(dst_dir): 242 | lr_dir_R = os.path.join(dst_dir, 'low/R') 243 | lr_dir_I = os.path.join(dst_dir, 'low/I') 244 | hr_dir_R = os.path.join(dst_dir, 'high/R') 245 | hr_dir_I = os.path.join(dst_dir, 'high/I') 246 | for name in os.listdir(dst_dir): 247 | path = os.path.join(dst_dir, name) 248 | name = name[:-4] 249 | item = name.split("_") 250 | if item[0] == 'high' and item[-1] == 'R': 251 | shutil.move(path, os.path.join(hr_dir_R, item[1]+".png")) 252 | if item[0] == 'high' and item[-1] == 'I': 253 | shutil.move(path, os.path.join(hr_dir_I, item[1]+".png")) 254 | if item[0] == 'low' and item[-1] == 'R': 255 | shutil.move(path, os.path.join(lr_dir_R, item[1]+".png")) 256 | if item[0] == 'low' and item[-1] == 'I': 257 | shutil.move(path, os.path.join(lr_dir_I, item[1]+".png")) 258 | log(f"Finish...") 259 | 260 | 261 | def change_name(dst_dir): 262 | dir_lists = [] 263 | dir_lists.append(os.path.join(dst_dir, 'low\\R')) 264 | dir_lists.append(os.path.join(dst_dir, 'low\\I')) 265 | dir_lists.append(os.path.join(dst_dir, 'high\\R')) 266 | dir_lists.append(os.path.join(dst_dir, 'high\\I')) 267 | for direction in dir_lists: 268 | for name in os.listdir(direction): 269 | path = os.path.join(direction, name) 270 | name = name[:-4] 271 | item = name.split("_") 272 | os.rename(path, os.path.join(direction, item[1]+".png")) 273 | log(f"Finish...") 274 | 275 | 276 | if __name__ == '__main__': 277 | # noDecom Dataloader Test 278 | # root_path_train = r'H:\datasets\Low-Light Dataset\LOLdataset_decom\our485' 279 | # root_path_test = r'H:\datasets\Low-Light Dataset\LOLdataset_decom\eval15' 280 | # list_path_train = build_LOLDataset_Decom_list_txt(root_path_train) 281 | # list_path_test = build_LOLDataset_Decom_list_txt(root_path_test) 282 | # Batch_size = 2 283 | # log("Buliding LOL Dataset...") 284 | # dst_train = LOLDataset_Decom(root_path_train, list_path_train, crop_size=128, to_RAM=True) 285 | # dst_test = LOLDataset_Decom(root_path_test, list_path_test, crop_size=128, to_RAM=False) 286 | # # But when we are training a model, the mean should have another value 287 | # trainloader = DataLoader(dst_train, batch_size = Batch_size) 288 | # testloader = DataLoader(dst_test, batch_size=1) 289 | # plt.ion() 290 | # for i, data in enumerate(trainloader): 291 | # _, _, _, imgs, name = data 292 | # log(name) 293 | # img = imgs[0].numpy() 294 | # sample(imgs[0], figure_size=(1, 1), img_dim=128) 295 | 296 | # Dataloader Test 297 | root_path_train = r'H:\datasets\Low-Light Dataset\KinD++\LOLdataset\our485' 298 | root_path_test = r'H:\datasets\Low-Light Dataset\KinD++\LOLdataset\eval15' 299 | # root_path_train = r'H:\datasets\Low-Light Dataset\KinD++\LOLdataset\our485' 300 | # root_path_test = r'H:\datasets\Low-Light Dataset\KinD++\LOLdataset\eval15' 301 | list_path_train = build_LOLDataset_list_txt(root_path_train) 302 | list_path_test = build_LOLDataset_list_txt(root_path_test) 303 | Batch_size = 2 304 | log("Buliding LOL Dataset...") 305 | dst_train = LOLDataset(root_path_train, list_path_train, crop_size=128, to_RAM=False) 306 | dst_test = LOLDataset(root_path_test, list_path_test, crop_size=128, to_RAM=False) 307 | # But when we are training a model, the mean should have another value 308 | trainloader = DataLoader(dst_train, batch_size = Batch_size) 309 | testloader = DataLoader(dst_test, batch_size=1) 310 | plt.ion() 311 | for i, data in enumerate(trainloader): 312 | _, imgs, name = data 313 | img = imgs[0].numpy() 314 | sample(imgs[0], figure_size=(1, 1), img_dim=128) 315 | -------------------------------------------------------------------------------- /decom_trainer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.optim import lr_scheduler 5 | import numpy as np 6 | import time 7 | import yaml 8 | import sys 9 | from tqdm import tqdm 10 | from torchvision.utils import make_grid 11 | from torchvision import transforms 12 | from torchsummary import summary 13 | from base_trainer import BaseTrainer 14 | from losses import Decom_Loss 15 | from models import DecomNet 16 | from base_parser import BaseParser 17 | from dataloader import * 18 | 19 | class Decom_Trainer(BaseTrainer): 20 | def train(self): 21 | print(f'Using device {self.device}') 22 | self.model.to(device=self.device) 23 | summary(self.model, input_size=(3, 48, 48)) 24 | # faster convolutions, but more memory 25 | # cudnn.benchmark = True 26 | 27 | optimizer = torch.optim.Adam(self.model.parameters(), lr=self.learning_rate) 28 | scheduler = lr_scheduler.ExponentialLR(optimizer, gamma=0.997) 29 | try: 30 | for iter in range(self.epochs): 31 | epoch_loss = 0 32 | idx = 0 33 | hook_number = -1 34 | iter_start_time = time.time() 35 | # with tqdm(total=self.steps_per_epoch) as pbar: 36 | for L_low_tensor, L_high_tensor, name in self.dataloader: 37 | L_low = L_low_tensor.to(self.device) 38 | L_high = L_high_tensor.to(self.device) 39 | R_low, I_low = self.model(L_low) 40 | R_high, I_high = self.model(L_high) 41 | if idx % self.print_frequency == 0: 42 | hook_number = -1 43 | loss = self.loss_fn(R_low, R_high, I_low, I_high, L_low, L_high, hook=hook_number) 44 | hook_number = -1 45 | if idx % 8 == 0: 46 | print(f"iter: {iter}_{idx}\taverage_loss: {loss.item():.6f}") 47 | optimizer.zero_grad() 48 | loss.backward() 49 | optimizer.step() 50 | idx += 1 51 | # pbar.update(1) 52 | # pbar.set_postfix({'loss':loss.item()}) 53 | 54 | if iter % self.print_frequency == 0: 55 | self.test(iter, plot_dir='./images/samples-decom') 56 | 57 | if iter % self.save_frequency == 0: 58 | torch.save(self.model.state_dict(), './weights/decom_net_test3.pth') 59 | log("Weight Has saved as 'decom_net.pth'") 60 | 61 | scheduler.step() 62 | iter_end_time = time.time() 63 | log(f"Time taken: {iter_end_time - iter_start_time:.3f} seconds\t lr={scheduler.get_lr()[0]:.6f}") 64 | 65 | except KeyboardInterrupt: 66 | torch.save(self.model.state_dict(), 'INTERRUPTED_decom.pth') 67 | log('Saved interrupt_decom') 68 | try: 69 | sys.exit(0) 70 | except SystemExit: 71 | os._exit(0) 72 | 73 | @no_grad 74 | def test(self, epoch=-1, plot_dir='./images/samples-decom'): 75 | self.model.eval() 76 | hook = 0 77 | for L_low_tensor, L_high_tensor, name in self.dataloader_test: 78 | L_low = L_low_tensor.to(self.device) 79 | L_high = L_high_tensor.to(self.device) 80 | R_low, I_low = self.model(L_low) 81 | R_high, I_high = self.model(L_high) 82 | 83 | if epoch % (self.print_frequency*10) == 0: 84 | loss = self.loss_fn(R_low, R_high, I_low, I_high, L_low, L_high, hook=hook) 85 | hook += 1 86 | loss = 0 87 | 88 | R_low_np = R_low.detach().cpu().numpy()[0] 89 | R_high_np = R_high.detach().cpu().numpy()[0] 90 | I_low_np = I_low.detach().cpu().numpy()[0] 91 | I_high_np = I_high.detach().cpu().numpy()[0] 92 | L_low_np = L_low_tensor.numpy()[0] 93 | L_high_np = L_high_tensor.numpy()[0] 94 | sample_imgs = np.concatenate( (R_low_np, I_low_np, L_low_np, 95 | R_high_np, I_high_np, L_high_np), axis=0 ) 96 | filepath = os.path.join(plot_dir, f'{name[0]}_epoch_{epoch}.png') 97 | split_point = [0, 3, 4, 7, 10, 11, 14] 98 | img_dim = I_low_np.shape[1:] 99 | sample(sample_imgs, split=split_point, figure_size=(2, 3), 100 | img_dim=img_dim, path=filepath, num=epoch) 101 | 102 | 103 | if __name__ == "__main__": 104 | criterion = Decom_Loss() 105 | model = DecomNet() 106 | 107 | parser = BaseParser() 108 | args = parser.parse() 109 | args.checkpoint = True 110 | if args.checkpoint is not None: 111 | pretrain = torch.load('./weights/decom_net.pth') 112 | model.load_state_dict(pretrain) 113 | print('Model loaded from decom_net.pth') 114 | 115 | with open(args.config) as f: 116 | config = yaml.load(f) 117 | 118 | root_path_train = r'H:\datasets\Low-Light Dataset\KinD++\LOLdataset\our485' 119 | root_path_test = r'C:\DeepLearning\KinD_plus-master\LOLdataset\eval15' 120 | list_path_train = build_LOLDataset_list_txt(root_path_train) 121 | list_path_test = build_LOLDataset_list_txt(root_path_test) 122 | # list_path_train = os.path.join(root_path_train, 'pair_list.csv') 123 | # list_path_test = os.path.join(root_path_test, 'pair_list.csv') 124 | 125 | log("Buliding LOL Dataset...") 126 | # transform = transforms.Compose([transforms.ToTensor(),]) 127 | dst_train = LOLDataset(root_path_train, list_path_train, 128 | crop_size=config['length'], to_RAM=True) 129 | dst_test = LOLDataset(root_path_test, list_path_test, 130 | crop_size=config['length'], to_RAM=True, training=False) 131 | 132 | train_loader = DataLoader(dst_train, batch_size = config['batch_size'], shuffle=True) 133 | test_loader = DataLoader(dst_test, batch_size=1) 134 | 135 | # if args.noDecom is True: 136 | # root_path_valid = r'H:\datasets\Low-Light Dataset\LOLdataset_decom\our485' 137 | # list_path_valid = os.path.join(root_path_test, 'pair_list.csv') 138 | 139 | # log("Buliding LOL Dataset (noDecom)...") 140 | # # transform = transforms.Compose([transforms.ToTensor()]) 141 | # dst_valid = LOLDataset_Decom(root_path_test, list_path_test, 142 | # crop_size=config['length'], to_RAM=True, training=False) 143 | # valid_loader = DataLoader(dst_train, batch_size = config['batch_size'], shuffle=True) 144 | 145 | trainer = Decom_Trainer(config, train_loader, criterion, model, dataloader_test=test_loader) 146 | # --config ./config/config.yaml 147 | if args.mode == 'train': 148 | trainer.train() 149 | else: 150 | trainer.test() -------------------------------------------------------------------------------- /evaluate.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | import time 6 | import yaml 7 | import sys 8 | from tqdm import tqdm 9 | from torch.optim import lr_scheduler 10 | from torchvision.utils import make_grid 11 | from torchvision import transforms 12 | from torchsummary import summary 13 | from base_trainer import BaseTrainer 14 | from losses import * 15 | from models import * 16 | from base_parser import BaseParser 17 | from dataloader import * 18 | 19 | class KinD_noDecom_Trainer(BaseTrainer): 20 | @no_grad 21 | def test(self, epoch=-1, plot_dir='./images/samples-KinD'): 22 | self.model.eval() 23 | self.model.to(device=self.device) 24 | if 'decom_net' in model._modules: 25 | for L_low_tensor, L_high_tensor, name in self.dataloader_test: 26 | L_low = L_low_tensor.to(self.device) 27 | L_high = L_high_tensor.to(self.device) 28 | 29 | R_low, I_low = self.model.decom_net(L_low) 30 | R_high, I_high = self.model.decom_net(L_high) 31 | I_low_3 = torch.cat([I_low, I_low, I_low], dim=1) 32 | I_high_3 = torch.cat([I_high, I_high, I_high], dim=1) 33 | 34 | output_low = I_low_3 * R_low 35 | output_high = I_high_3 * R_high 36 | 37 | b = 0.7; w=0.5 38 | bright_low = torch.mean(I_low) 39 | # bright_high = torch.mean(I_high) 40 | bright_high = torch.ones_like(bright_low) * b + bright_low * w 41 | ratio = torch.div(bright_high, bright_low) 42 | log(f"Brightness: {bright_high}\tIllumation Magnification: {ratio.item()}") 43 | # ratio_map = torch.ones_like(I_low) * ratio 44 | 45 | R_final, I_final, output_final = self.model(L_low, ratio) 46 | 47 | R_final_np = R_final.detach().cpu().numpy()[0] 48 | I_final_np = I_final.detach().cpu().numpy()[0] 49 | R_low_np = R_low.detach().cpu().numpy()[0] 50 | I_low_np = I_low.detach().cpu().numpy()[0] 51 | R_high_np = R_high.detach().cpu().numpy()[0] 52 | I_high_np = I_high.detach().cpu().numpy()[0] 53 | output_final_np = output_final.detach().cpu().numpy()[0] 54 | output_low_np = output_low.detach().cpu().numpy()[0] 55 | output_high_np = output_high.detach().cpu().numpy()[0] 56 | # ratio_map_np = ratio_map.detach().cpu().numpy()[0] 57 | L_low_np = L_low_tensor.numpy()[0] 58 | L_high_np = L_high_tensor.numpy()[0] 59 | 60 | sample_imgs = np.concatenate( (R_low_np, I_low_np, output_low_np, L_low_np, 61 | R_high_np, I_high_np, output_high_np, L_high_np, 62 | R_final_np, I_final_np, output_final_np, L_high_np), axis=0 ) 63 | 64 | filepath = os.path.join(plot_dir, f'{name[0]}_epoch_{epoch}.png') 65 | split_point = [0, 3, 4, 7, 10, 13, 14, 17, 20, 23, 24, 27, 30] 66 | img_dim = I_high_np.shape[1:] 67 | sample(sample_imgs, split=split_point, figure_size=(3, 4), 68 | img_dim=img_dim, path=filepath, num=epoch) 69 | else: 70 | for R_low_tensor, I_low_tensor, R_high_tensor, I_high_tensor, name in self.dataloader_test: 71 | R_low = R_low_tensor.to(self.device) 72 | R_high = R_high_tensor.to(self.device) 73 | I_low = I_low_tensor.to(self.device) 74 | I_high = I_high_tensor.to(self.device) 75 | I_high_3 = torch.cat([I_high, I_high, I_high], dim=1) 76 | output_high = I_high_3 * R_high 77 | 78 | # while True: 79 | # b = float(input('请输入增强水平:')) 80 | # if b <= 0: break 81 | b = 0.6; w = 0.5 82 | bright_low = torch.mean(I_low) 83 | bright_high = torch.ones_like(bright_low) * b + bright_low * w 84 | ratio = torch.div(bright_high, bright_low) 85 | print(bright_high, ratio) 86 | # ratio_map = torch.ones_like(I_low) * ratio 87 | 88 | R_final, I_final, output_final = self.model(R_low, I_low, ratio) 89 | 90 | R_final_np = R_final.detach().cpu().numpy()[0] 91 | I_final_np = I_final.detach().cpu().numpy()[0] 92 | output_final_np = output_final.detach().cpu().numpy()[0] 93 | output_high_np = output_high.detach().cpu().numpy()[0] 94 | # ratio_map_np = ratio_map.detach().cpu().numpy()[0] 95 | I_high_np = I_high_tensor.numpy()[0] 96 | R_high_np = R_high_tensor.numpy()[0] 97 | 98 | sample_imgs = np.concatenate( (R_high_np, I_high_np, output_high_np, 99 | R_final_np, I_final_np, output_final_np), axis=0 ) 100 | 101 | filepath = os.path.join(plot_dir, f'{name[0]}_epoch_{epoch}.png') 102 | split_point = [0, 3, 4, 7, 10, 11, 14] 103 | img_dim = I_high_np.shape[1:] 104 | sample(sample_imgs, split=split_point, figure_size=(2, 3), 105 | img_dim=img_dim, path=filepath, num=epoch) 106 | 107 | 108 | if __name__ == "__main__": 109 | criterion = None 110 | parser = BaseParser() 111 | args = parser.parse() 112 | # args.noDecom = True 113 | with open(args.config) as f: 114 | config = yaml.load(f) 115 | if config['noDecom'] is True: 116 | model = KinD_noDecom() 117 | else: 118 | model = KinD() 119 | 120 | if args.checkpoint is not None: 121 | if config['noDecom'] is False: 122 | pretrain_decom = torch.load('./weights/decom_net.pth') 123 | model.decom_net.load_state_dict(pretrain_decom) 124 | log('Model loaded from decom_net.pth') 125 | pretrain_resotre = torch.load('./weights/restore_net.pth') 126 | model.restore_net.load_state_dict(pretrain_resotre) 127 | log('Model loaded from restore_net.pth') 128 | pretrain_illum = torch.load('./weights/illum_net.pth') 129 | model.illum_net.load_state_dict(pretrain_illum) 130 | log('Model loaded from illum_net.pth') 131 | 132 | if config['noDecom'] is True: 133 | root_path_test = r'H:\datasets\Low-Light Dataset\LOLdataset_decom\eval15' 134 | list_path_test = os.path.join(root_path_test, 'pair_list.csv') 135 | 136 | log("Buliding LOL Dataset (noDecom)...") 137 | # transform = transforms.Compose([transforms.ToTensor()]) 138 | dst_test = LOLDataset_Decom(root_path_test, list_path_test, 139 | crop_size=config['length'], to_RAM=True, training=False) 140 | else: 141 | root_path_test = r'C:\DeepLearning\KinD_plus-master\LOLdataset\eval15' 142 | list_path_test = os.path.join(root_path_test, 'pair_list.csv') 143 | 144 | log("Buliding LOL Dataset...") 145 | # transform = transforms.Compose([transforms.ToTensor()]) 146 | dst_test = LOLDataset(root_path_test, list_path_test, 147 | crop_size=config['length'], to_RAM=True, training=False) 148 | 149 | test_loader = DataLoader(dst_test, batch_size=1) 150 | 151 | KinD = KinD_noDecom_Trainer(config, None, criterion, model, dataloader_test=test_loader) 152 | 153 | # Please change your output direction here 154 | output_dir = './images/samples-KinD' 155 | KinD.test(plot_dir=output_dir) -------------------------------------------------------------------------------- /figures/778_epoch_-1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fenghansen/KinD-pytorch/815d26f538221ba5b59799f5439d491579f6bd2d/figures/778_epoch_-1.png -------------------------------------------------------------------------------- /figures/failure_case_decom.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fenghansen/KinD-pytorch/815d26f538221ba5b59799f5439d491579f6bd2d/figures/failure_case_decom.png -------------------------------------------------------------------------------- /figures/official_decom_train.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fenghansen/KinD-pytorch/815d26f538221ba5b59799f5439d491579f6bd2d/figures/official_decom_train.png -------------------------------------------------------------------------------- /illum_trainer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.optim import lr_scheduler 5 | import numpy as np 6 | import time 7 | import yaml 8 | import sys 9 | from tqdm import tqdm 10 | from torchvision.utils import make_grid 11 | from torchvision import transforms 12 | from torchsummary import summary 13 | from base_trainer import BaseTrainer 14 | from losses import * 15 | from models import * 16 | from base_parser import BaseParser 17 | from dataloader import * 18 | 19 | class Illum_Trainer(BaseTrainer): 20 | def __init__(self, config, dataloader, criterion, model, 21 | dataloader_test=None, decom_net=None): 22 | super().__init__(config, dataloader, criterion, model, dataloader_test) 23 | log(f'Using device {self.device}') 24 | self.decom_net = decom_net 25 | self.decom_net.to(device=self.device) 26 | 27 | def train(self): 28 | self.model.train() 29 | log(f'Using device {self.device}') 30 | self.model.to(device=self.device) 31 | print(self.model) 32 | # summary(self.model, input_size=[(1, 384, 384), (1,)], batch_size=4) 33 | 34 | optimizer = torch.optim.Adam(self.model.parameters(), lr=self.learning_rate) 35 | scheduler = lr_scheduler.ExponentialLR(optimizer, gamma=0.99426) 36 | try: 37 | for iter in range(self.epochs): 38 | epoch_loss = 0 39 | idx = 0 40 | hook_number = -1 41 | iter_start_time = time.time() 42 | if self.noDecom is True: 43 | # with tqdm(total=self.steps_per_epoch) as pbar: 44 | for R_low_tensor, I_low_tensor, R_high_tensor, I_high_tensor, name in self.dataloader: 45 | optimizer.zero_grad() 46 | I_low = I_low_tensor.to(self.device) 47 | I_high = I_high_tensor.to(self.device) 48 | with torch.no_grad(): 49 | ratio_high2low = torch.mean(torch.div((I_low + 0.0001), (I_high + 0.0001))) 50 | ratio_low2high = torch.mean(torch.div((I_high + 0.0001), (I_low + 0.0001))) 51 | 52 | I_low2high_map = self.model(I_low, ratio_low2high) 53 | I_high2low_map = self.model(I_high, ratio_high2low) 54 | 55 | if idx % self.print_frequency == 0: 56 | hook_number = iter 57 | loss = self.loss_fn(I_low2high_map, I_high, hook=hook_number) + self.loss_fn(I_high2low_map, I_low, hook=hook_number) 58 | hook_number = -1 59 | if idx % 30 == 0: 60 | log(f"iter: {iter}_{idx}\taverage_loss: {loss.item():.6f}") 61 | print(ratio_high2low, ratio_low2high) 62 | loss.backward() 63 | optimizer.step() 64 | idx += 1 65 | # pbar.update(1) 66 | # pbar.set_postfix({'loss':loss.item()}) 67 | else: 68 | # with tqdm(total=self.steps_per_epoch) as pbar: 69 | for L_low_tensor, L_high_tensor, name in self.dataloader: 70 | optimizer.zero_grad() 71 | L_low = L_low_tensor.to(self.device) 72 | L_high = L_high_tensor.to(self.device) 73 | 74 | with torch.no_grad(): 75 | R_low, I_low = self.decom_net(L_low) 76 | R_high, I_high = self.decom_net(L_high) 77 | # ratio_high2low = torch.mean(torch.div((I_low + 0.0001), (I_high + 0.0001))) 78 | # ratio_low2high = torch.mean(torch.div((I_high + 0.0001), (I_low + 0.0001))) 79 | bright_low = torch.mean(I_low) 80 | bright_high = torch.mean(I_high) 81 | ratio_high2low = torch.div(bright_low, bright_high) 82 | ratio_low2high = torch.div(bright_high, bright_low) 83 | 84 | I_low2high_map = self.model(I_low, ratio_low2high) 85 | I_high2low_map = self.model(I_high, ratio_high2low) 86 | 87 | loss = self.loss_fn(I_low2high_map, I_high, hook=hook_number) + \ 88 | self.loss_fn(I_high2low_map, I_low, hook=hook_number) 89 | 90 | if idx % 30 == 0: 91 | log(f"iter: {iter}_{idx}\taverage_loss: {loss.item():.6f}") 92 | print(ratio_high2low, ratio_low2high) 93 | loss.backward() 94 | optimizer.step() 95 | idx += 1 96 | # pbar.update(1) 97 | # pbar.set_postfix({'loss':loss.item()}) 98 | 99 | if iter % self.print_frequency == 0: 100 | self.test(iter, plot_dir='./images/samples-illum') 101 | 102 | if iter % self.save_frequency == 0: 103 | torch.save(self.model.state_dict(), './weights/illum_net.pth') 104 | log("Weight Has saved as 'illum_net.pth'") 105 | 106 | scheduler.step() 107 | iter_end_time = time.time() 108 | log(f"Time taken: {iter_end_time - iter_start_time:.3f} seconds\t lr={scheduler.get_lr()[0]:.6f}") 109 | 110 | except KeyboardInterrupt: 111 | torch.save(self.model.state_dict(), './weights/INTERRUPTED_illum.pth') 112 | log('Saved interrupt') 113 | try: 114 | sys.exit(0) 115 | except SystemExit: 116 | os._exit(0) 117 | 118 | @no_grad 119 | def test(self, epoch=-1, plot_dir='./images/samples-illum'): 120 | self.model.eval() 121 | if self.noDecom: 122 | for R_low_tensor, I_low_tensor, R_high_tensor, I_high_tensor, name in self.dataloader_test: 123 | I_low = I_low_tensor.to(self.device) 124 | I_high = I_high_tensor.to(self.device) 125 | 126 | ratio_high2low = torch.mean(torch.div((I_low + 0.0001), (I_high + 0.0001))) 127 | ratio_low2high = torch.mean(torch.div((I_high + 0.0001), (I_low + 0.0001))) 128 | print(ratio_low2high) 129 | # 采用粗略的亮度水平估计 130 | bright_low = torch.mean(I_low) 131 | bright_high = torch.ones_like(bright_low) * 0.3 + bright_low * 0.55 132 | ratio_high2low = torch.div(bright_low, bright_high) 133 | ratio_low2high = torch.div(bright_high, bright_low) 134 | print(ratio_low2high) 135 | 136 | I_low2high_map = self.model(I_low, ratio_low2high) 137 | I_high2low_map = self.model(I_high, ratio_high2low) 138 | 139 | I_low2high_np = I_low2high_map.detach().cpu().numpy()[0] 140 | I_high2low_np = I_high2low_map.detach().cpu().numpy()[0] 141 | I_low_np = I_low_tensor.numpy()[0] 142 | I_high_np = I_high_tensor.numpy()[0] 143 | sample_imgs = np.concatenate( (I_low_np, I_high_np, I_high2low_np, I_low2high_np), axis=0 ) 144 | 145 | filepath = os.path.join(plot_dir, f'{name[0]}_epoch_{epoch}.png') 146 | split_point = [0, 1, 2, 3, 4] 147 | img_dim = I_low_np.shape[1:] 148 | sample(sample_imgs, split=split_point, figure_size=(2, 2), 149 | img_dim=img_dim, path=filepath, num=epoch) 150 | else: 151 | for L_low_tensor, L_high_tensor, name in self.dataloader_test: 152 | L_low = L_low_tensor.to(self.device) 153 | L_high = L_high_tensor.to(self.device) 154 | 155 | R_low, I_low = self.decom_net(L_low) 156 | R_high, I_high = self.decom_net(L_high) 157 | bright_low = torch.mean(I_low) 158 | bright_high = torch.mean(I_high) 159 | ratio_high2low = torch.div(bright_low, bright_high) 160 | ratio_low2high = torch.div(bright_high, bright_low) 161 | print(ratio_low2high) 162 | 163 | I_low2high_map = self.model(I_low, ratio_low2high) 164 | I_high2low_map = self.model(I_high, ratio_high2low) 165 | 166 | I_low2high_np = I_low2high_map.detach().cpu().numpy()[0] 167 | I_high2low_np = I_high2low_map.detach().cpu().numpy()[0] 168 | I_low_np = I_low.detach().cpu().numpy()[0] 169 | I_high_np = I_high.detach().cpu().numpy()[0] 170 | sample_imgs = np.concatenate( (I_low_np, I_high_np, I_high2low_np, I_low2high_np), axis=0 ) 171 | 172 | filepath = os.path.join(plot_dir, f'{name[0]}_epoch_{epoch}.png') 173 | split_point = [0, 1, 2, 3, 4] 174 | img_dim = I_low_np.shape[1:] 175 | sample(sample_imgs, split=split_point, figure_size=(2, 2), 176 | img_dim=img_dim, path=filepath, num=epoch) 177 | 178 | if __name__ == "__main__": 179 | criterion = Illum_Loss() 180 | decom_net = DecomNet() 181 | model = IllumNet() 182 | 183 | parser = BaseParser() 184 | args = parser.parse() 185 | 186 | with open(args.config) as f: 187 | config = yaml.load(f) 188 | 189 | args.checkpoint = True 190 | if args.checkpoint is not None: 191 | if config['noDecom'] is False: 192 | decom_net = load_weights(decom_net, path='./weights/decom_net.pth') 193 | log('DecomNet loaded from decom_net.pth') 194 | model = load_weights(model, path='./weights/illum_net.pth') 195 | log('Model loaded from illum_net.pth') 196 | 197 | if config['noDecom'] is True: 198 | root_path_train = r'H:\datasets\Low-Light Dataset\LOLdataset_decom\our485' 199 | root_path_test = r'H:\datasets\Low-Light Dataset\LOLdataset_decom\eval15' 200 | list_path_train = build_LOLDataset_Decom_list_txt(root_path_train) 201 | list_path_test = build_LOLDataset_Decom_list_txt(root_path_test) 202 | # list_path_train = os.path.join(root_path_train, 'pair_list.csv') 203 | # list_path_test = os.path.join(root_path_test, 'pair_list.csv') 204 | 205 | log("Buliding LOL Dataset...") 206 | # transform = transforms.Compose([transforms.ToTensor()]) 207 | dst_train = LOLDataset_Decom(root_path_train, list_path_train, 208 | crop_size=config['length'], to_RAM=True) 209 | dst_test = LOLDataset_Decom(root_path_test, list_path_test, 210 | crop_size=config['length'], to_RAM=True, training=False) 211 | 212 | train_loader = DataLoader(dst_train, batch_size = config['batch_size'], shuffle=True) 213 | test_loader = DataLoader(dst_test, batch_size=1) 214 | 215 | else: 216 | root_path_train = r'C:\DeepLearning\KinD_plus-master\LOLdataset\our485' 217 | root_path_test = r'C:\DeepLearning\KinD_plus-master\LOLdataset\eval15' 218 | list_path_train = build_LOLDataset_list_txt(root_path_train) 219 | list_path_test = build_LOLDataset_list_txt(root_path_test) 220 | # list_path_train = os.path.join(root_path_train, 'pair_list.csv') 221 | # list_path_test = os.path.join(root_path_test, 'pair_list.csv') 222 | 223 | log("Buliding LOL Dataset...") 224 | # transform = transforms.Compose([transforms.ToTensor()]) 225 | dst_train = LOLDataset(root_path_train, list_path_train, 226 | crop_size=config['length'], to_RAM=True) 227 | dst_test = LOLDataset(root_path_test, list_path_test, 228 | crop_size=config['length'], to_RAM=True, training=False) 229 | 230 | train_loader = DataLoader(dst_train, batch_size = config['batch_size'], shuffle=True) 231 | test_loader = DataLoader(dst_test, batch_size=1) 232 | 233 | trainer = Illum_Trainer(config, train_loader, criterion, model, 234 | dataloader_test=test_loader, decom_net=decom_net) 235 | 236 | if args.mode == 'train': 237 | trainer.train() 238 | else: 239 | trainer.test() -------------------------------------------------------------------------------- /illum_trainer_custom.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.optim import lr_scheduler 5 | import numpy as np 6 | import time 7 | import yaml 8 | import sys 9 | from tqdm import tqdm 10 | from torchvision.utils import make_grid 11 | from torchvision import transforms 12 | from torchsummary import summary 13 | from base_trainer import BaseTrainer 14 | from losses import * 15 | from models import * 16 | from base_parser import BaseParser 17 | from dataloader import * 18 | 19 | class Illum_Trainer(BaseTrainer): 20 | def __init__(self, config, dataloader, criterion, model, 21 | dataloader_test=None, decom_net=None): 22 | super().__init__(config, dataloader, criterion, model, dataloader_test) 23 | log(f'Using device {self.device}') 24 | self.decom_net = decom_net 25 | self.decom_net.to(device=self.device) 26 | torch.backends.cudnn.benchmark = True 27 | 28 | def train(self): 29 | self.model.train() 30 | log(f'Using device {self.device}') 31 | self.model.to(device=self.device) 32 | print(self.model) 33 | # summary(self.model, input_size=[(1, 384, 384), (1,)], batch_size=4) 34 | 35 | optimizer = torch.optim.Adam(self.model.parameters(), lr=self.learning_rate) 36 | scheduler = lr_scheduler.ExponentialLR(optimizer, gamma=0.99426) 37 | try: 38 | for iter in range(self.epochs): 39 | idx = 0 40 | iter_start_time = time.time() 41 | for L_low_tensor, L_high_tensor, name in self.dataloader: 42 | optimizer.zero_grad() 43 | L_low = L_low_tensor.to(self.device) 44 | L_high = L_high_tensor.to(self.device) 45 | 46 | with torch.no_grad(): 47 | _, I_low = self.decom_net(L_low) 48 | _, I_high = self.decom_net(L_high) 49 | 50 | I_out, I_standard = self.model(I_low, 1) 51 | loss = self.loss_fn(I_out, I_high, I_standard) 52 | 53 | if idx % 6 == 0: 54 | log(f"iter: {iter}_{idx}\taverage_loss: {loss.item():.6f}") 55 | loss.backward() 56 | optimizer.step() 57 | idx += 1 58 | 59 | if iter % self.print_frequency == 0: 60 | self.test(iter, plot_dir='./images/samples-illum-custom') 61 | 62 | if iter % self.save_frequency == 0: 63 | torch.save(self.model.state_dict(), f'./weights/illum_net_custom_{iter//100}.pth') 64 | log("Weight Has saved as 'illum_net.pth'") 65 | 66 | scheduler.step() 67 | iter_end_time = time.time() 68 | w, sigma = self.model.get_parameter() 69 | log(f"w:{float(w):.4f}\t sigma:{float(sigma):.2f}") 70 | log(f"Time taken: {iter_end_time - iter_start_time:.3f} seconds\t lr={scheduler.get_lr()[0]:.6f}") 71 | 72 | except KeyboardInterrupt: 73 | torch.save(self.model.state_dict(), './weights/INTERRUPTED_illum_custom.pth') 74 | log('Saved interrupt') 75 | try: 76 | sys.exit(0) 77 | except SystemExit: 78 | os._exit(0) 79 | 80 | @no_grad 81 | def test(self, epoch=-1, plot_dir='./images/samples-illum'): 82 | self.model.eval() 83 | for L_low_tensor, L_high_tensor, name in self.dataloader_test: 84 | L_low = L_low_tensor.to(self.device) 85 | L_high = L_high_tensor.to(self.device) 86 | 87 | with torch.no_grad(): 88 | _, I_low = self.decom_net(L_low) 89 | _, I_high = self.decom_net(L_high) 90 | I_out, I_standard = self.model(I_low, 1) 91 | # I_low_standard = standard_illum(I_low, w=0.72, gamma=0.53, blur=True) 92 | # I_high_standard = standard_illum(I_high, w=0.08, gamma=1.34) 93 | 94 | I_standard_np = I_standard.detach().cpu().numpy()[0] 95 | I_out_np = I_out.detach().cpu().numpy()[0] 96 | I_low_np = I_low.detach().cpu().numpy()[0] 97 | I_high_np = I_high.detach().cpu().numpy()[0] 98 | # I_low_standard = standard_illum(I_low_np, dynamic=3) 99 | # I_high_standard = standard_illum(I_high_np) 100 | 101 | sample_imgs = np.concatenate( (I_low_np, I_high_np, I_standard_np, I_out_np), axis=0 ) 102 | 103 | filepath = os.path.join(plot_dir, f'{name[0]}_epoch_{epoch//100}.png') 104 | split_point = [0, 1, 2, 3, 4] 105 | img_dim = I_low_np.shape[1:] 106 | sample(sample_imgs, split=split_point, figure_size=(2, 2), 107 | img_dim=img_dim, path=filepath, num=epoch) 108 | 109 | 110 | if __name__ == "__main__": 111 | criterion = Illum_Custom_Loss() 112 | decom_net = DecomNet() 113 | model = IllumNet_Custom() 114 | 115 | parser = BaseParser() 116 | args = parser.parse() 117 | 118 | with open(args.config) as f: 119 | config = yaml.load(f) 120 | 121 | args.checkpoint = True 122 | if args.checkpoint is not None: 123 | if config['noDecom'] is False: 124 | decom_net = load_weights(decom_net, path='./weights/decom_net.pth') 125 | log('DecomNet loaded from decom_net.pth') 126 | model = load_weights(model, path='./weights/illum_net_custom_0.pth') 127 | log('Model loaded from illum_net.pth') 128 | 129 | root_path_train = r'C:\DeepLearning\KinD_plus-master\LOLdataset\our485' 130 | root_path_test = r'C:\DeepLearning\KinD_plus-master\LOLdataset\eval15' 131 | list_path_train = build_LOLDataset_list_txt(root_path_train) 132 | list_path_test = build_LOLDataset_list_txt(root_path_test) 133 | # list_path_train = os.path.join(root_path_train, 'pair_list.csv') 134 | # list_path_test = os.path.join(root_path_test, 'pair_list.csv') 135 | 136 | log("Buliding LOL Dataset...") 137 | # transform = transforms.Compose([transforms.ToTensor()]) 138 | dst_train = LOLDataset(root_path_train, list_path_train, 139 | crop_size=config['length'], to_RAM=True) 140 | dst_test = LOLDataset(root_path_test, list_path_test, 141 | crop_size=config['length'], to_RAM=True, training=False) 142 | 143 | train_loader = DataLoader(dst_train, batch_size = config['batch_size'], shuffle=True) 144 | test_loader = DataLoader(dst_test, batch_size=1) 145 | 146 | trainer = Illum_Trainer(config, train_loader, criterion, model, 147 | dataloader_test=test_loader, decom_net=decom_net) 148 | 149 | if args.mode == 'train': 150 | trainer.train() 151 | else: 152 | trainer.test() -------------------------------------------------------------------------------- /losses.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | import pytorch_ssim 6 | from dataloader import * 7 | 8 | Sobel = np.array([[-1,-2,-1], 9 | [ 0, 0, 0], 10 | [ 1, 2, 1]]) 11 | Robert = np.array([[0, 0], 12 | [-1, 1]]) 13 | Sobel = torch.Tensor(Sobel) 14 | Robert = torch.Tensor(Robert) 15 | 16 | def feature_map_hook(*args, path=None): 17 | feature_maps = [] 18 | for feature in args: 19 | feature_maps.append(feature) 20 | feature_all = torch.cat(feature_maps, dim=1) 21 | fmap = feature_all.detach().cpu().numpy()[0] 22 | fmap = np.array(fmap) 23 | fshape = fmap.shape 24 | num = fshape[0] 25 | shape = fshape[1:] 26 | sample(fmap, figure_size=(2, num//2), img_dim=shape, path=path) 27 | return fmap 28 | 29 | # 已测试本模块没有问题,作用为提取一阶导数算子滤波图(边缘图) 30 | def gradient(maps, direction, device='cuda', kernel='sobel'): 31 | channels = maps.size()[1] 32 | if kernel == 'robert': 33 | smooth_kernel_x = Robert.expand(channels, channels, 2, 2) 34 | maps = F.pad(maps, (0, 0, 1, 1)) 35 | elif kernel == 'sobel': 36 | smooth_kernel_x = Sobel.expand(channels, channels, 3, 3) 37 | maps = F.pad(maps, (1, 1, 1, 1)) 38 | smooth_kernel_y = smooth_kernel_x.permute(0, 1, 3, 2) 39 | if direction == "x": 40 | kernel = smooth_kernel_x 41 | elif direction == "y": 42 | kernel = smooth_kernel_y 43 | kernel = kernel.to(device=device) 44 | # kernel size is (2, 2) so need pad bottom and right side 45 | gradient_orig = torch.abs(F.conv2d(maps, weight=kernel, padding=0)) 46 | grad_min = torch.min(gradient_orig) 47 | grad_max = torch.max(gradient_orig) 48 | grad_norm = torch.div((gradient_orig - grad_min), (grad_max - grad_min + 0.0001)) 49 | return grad_norm 50 | 51 | 52 | def gradient_no_abs(maps, direction, device='cuda', kernel='sobel'): 53 | channels = maps.size()[1] 54 | if kernel == 'robert': 55 | smooth_kernel_x = Robert.expand(channels, channels, 2, 2) 56 | maps = F.pad(maps, (0, 0, 1, 1)) 57 | elif kernel == 'sobel': 58 | smooth_kernel_x = Sobel.expand(channels, channels, 3, 3) 59 | maps = F.pad(maps, (1, 1, 1, 1)) 60 | smooth_kernel_y = smooth_kernel_x.permute(0, 1, 3, 2) 61 | if direction == "x": 62 | kernel = smooth_kernel_x 63 | elif direction == "y": 64 | kernel = smooth_kernel_y 65 | kernel = kernel.to(device=device) 66 | # kernel size is (2, 2) so need pad bottom and right side 67 | gradient_orig = torch.abs(F.conv2d(maps, weight=kernel, padding=0)) 68 | grad_min = torch.min(gradient_orig) 69 | grad_max = torch.max(gradient_orig) 70 | grad_norm = torch.div((gradient_orig - grad_min), (grad_max - grad_min + 0.0001)) 71 | return grad_norm 72 | 73 | 74 | class Decom_Loss(nn.Module): 75 | def __init__(self): 76 | super().__init__() 77 | 78 | def reflectance_similarity(self, R_low, R_high): 79 | return torch.mean(torch.abs(R_low - R_high)) 80 | 81 | def illumination_smoothness(self, I, L, name='low', hook=-1): 82 | # L_transpose = L.permute(0, 2, 3, 1) 83 | # L_gray_transpose = 0.299*L[:,:,:,0] + 0.587*L[:,:,:,1] + 0.114*L[:,:,:,2] 84 | # L_gray = L.permute(0, 3, 1, 2) 85 | L_gray = 0.299*L[:,0,:,:] + 0.587*L[:,1,:,:] + 0.114*L[:,2,:,:] 86 | L_gray = L_gray.unsqueeze(dim=1) 87 | I_gradient_x = gradient(I, "x") 88 | L_gradient_x = gradient(L_gray, "x") 89 | epsilon = 0.01*torch.ones_like(L_gradient_x) 90 | Denominator_x = torch.max(L_gradient_x, epsilon) 91 | x_loss = torch.abs(torch.div(I_gradient_x, Denominator_x)) 92 | I_gradient_y = gradient(I, "y") 93 | L_gradient_y = gradient(L_gray, "y") 94 | Denominator_y = torch.max(L_gradient_y, epsilon) 95 | y_loss = torch.abs(torch.div(I_gradient_y, Denominator_y)) 96 | mut_loss = torch.mean(x_loss + y_loss) 97 | if hook > -1: 98 | feature_map_hook(I, L_gray, epsilon, I_gradient_x+I_gradient_y, Denominator_x+Denominator_y, 99 | x_loss+y_loss, path=f'./images/samples-features/ilux_smooth_{name}_epoch{hook}.png') 100 | return mut_loss 101 | 102 | def mutual_consistency(self, I_low, I_high, hook=-1): 103 | low_gradient_x = gradient(I_low, "x") 104 | high_gradient_x = gradient(I_high, "x") 105 | M_gradient_x = low_gradient_x + high_gradient_x 106 | x_loss = M_gradient_x * torch.exp(-10 * M_gradient_x) 107 | low_gradient_y = gradient(I_low, "y") 108 | high_gradient_y = gradient(I_high, "y") 109 | M_gradient_y = low_gradient_y + high_gradient_y 110 | y_loss = M_gradient_y * torch.exp(-10 * M_gradient_y) 111 | mutual_loss = torch.mean(x_loss + y_loss) 112 | if hook > -1: 113 | feature_map_hook(I_low, I_high, low_gradient_x+low_gradient_y, high_gradient_x+high_gradient_y, 114 | M_gradient_x + M_gradient_y, x_loss+ y_loss, path=f'./images/samples-features/mutual_consist_epoch{hook}.png') 115 | return mutual_loss 116 | 117 | def reconstruction_error(self, R_low, R_high, I_low_3, I_high_3, L_low, L_high): 118 | recon_loss_low = torch.mean(torch.abs(R_low * I_low_3 - L_low)) 119 | recon_loss_high = torch.mean(torch.abs(R_high * I_high_3 - L_high)) 120 | # recon_loss_l2h = torch.mean(torch.abs(R_high * I_low_3 - L_low)) 121 | # recon_loss_h2l = torch.mean(torch.abs(R_low * I_high_3 - L_high)) 122 | return recon_loss_high + recon_loss_low # + recon_loss_l2h + recon_loss_h2l 123 | 124 | def forward(self, R_low, R_high, I_low, I_high, L_low, L_high, hook=-1): 125 | I_low_3 = torch.cat([I_low, I_low, I_low], dim=1) 126 | I_high_3 = torch.cat([I_high, I_high, I_high], dim=1) 127 | #network output 128 | recon_loss = self.reconstruction_error(R_low, R_high, I_low_3, I_high_3, L_low, L_high) 129 | equal_R_loss = self.reflectance_similarity(R_low, R_high) 130 | i_mutual_loss = self.mutual_consistency(I_low, I_high, hook=hook) 131 | ilux_smooth_loss = self.illumination_smoothness(I_low, L_low, hook=hook) + \ 132 | self.illumination_smoothness(I_high, L_high, name='high', hook=hook) 133 | 134 | decom_loss = recon_loss + 0.009 * equal_R_loss + 0.2 * i_mutual_loss + 0.15 * ilux_smooth_loss 135 | 136 | return decom_loss 137 | 138 | 139 | class Illum_Loss(nn.Module): 140 | def __init__(self): 141 | super().__init__() 142 | 143 | def grad_loss(self, low, high, hook=-1): 144 | x_loss = F.l1_loss(gradient_no_abs(low, 'x'), gradient_no_abs(high, 'x')) 145 | y_loss = F.l1_loss(gradient_no_abs(low, 'y'), gradient_no_abs(high, 'y')) 146 | grad_loss_all = x_loss + y_loss 147 | return grad_loss_all 148 | 149 | def forward(self, I_low, I_high, hook=-1): 150 | loss_grad = self.grad_loss(I_low, I_high, hook=hook) 151 | loss_recon = F.l1_loss(I_low, I_high) 152 | loss_adjust = loss_recon + loss_grad 153 | return loss_adjust 154 | 155 | class Illum_Custom_Loss(nn.Module): 156 | def __init__(self): 157 | super().__init__() 158 | 159 | def grad_loss(self, low, high): 160 | x_loss = F.l1_loss(gradient_no_abs(low, 'x'), gradient_no_abs(high, 'x')) 161 | y_loss = F.l1_loss(gradient_no_abs(low, 'y'), gradient_no_abs(high, 'y')) 162 | grad_loss_all = x_loss + y_loss 163 | return grad_loss_all 164 | 165 | def gamma_loss(self, I_standard, I_high): 166 | loss = F.l1_loss(I_high, I_standard) 167 | return loss 168 | 169 | def forward(self, I_low, I_high, I_standard): 170 | loss_gamma = self.gamma_loss(I_standard, I_high) 171 | loss_grad = self.grad_loss(I_low, I_high) 172 | loss_recon = F.l1_loss(I_low, I_high) 173 | loss_adjust = loss_gamma + loss_recon + loss_grad 174 | return loss_adjust 175 | 176 | 177 | class Restore_Loss(nn.Module): 178 | def __init__(self): 179 | super().__init__() 180 | self.ssim_loss = pytorch_ssim.SSIM() 181 | 182 | def grad_loss(self, low, high, hook=-1): 183 | x_loss = F.mse_loss(gradient_no_abs(low, 'x'), gradient_no_abs(high, 'x')) 184 | y_loss = F.mse_loss(gradient_no_abs(low, 'y'), gradient_no_abs(high, 'y')) 185 | grad_loss_all = x_loss + y_loss 186 | return grad_loss_all 187 | 188 | def forward(self, R_low, R_high, hook=-1): 189 | # loss_grad = self.grad_loss(R_low, R_high, hook=hook) 190 | loss_recon = F.l1_loss(R_low, R_high) 191 | loss_ssim = 1-self.ssim_loss(R_low, R_high) 192 | loss_restore = loss_recon + loss_ssim #+ loss_grad 193 | return loss_restore 194 | 195 | 196 | if __name__ == "__main__": 197 | from dataloader import * 198 | from torch.utils.data import DataLoader 199 | from torchvision.utils import make_grid 200 | from matplotlib import pyplot as plt 201 | root_path_train = r'H:\datasets\Low-Light Dataset\KinD++\LOLdataset\our485' 202 | list_path_train = build_LOLDataset_list_txt(root_path_train) 203 | Batch_size = 1 204 | log("Buliding LOL Dataset...") 205 | dst_test = LOLDataset(root_path_train, list_path_train, to_RAM=True, training=False) 206 | # But when we are training a model, the mean should have another value 207 | testloader = DataLoader(dst_test, batch_size = Batch_size) 208 | for i, data in enumerate(testloader): 209 | L_low, L_high, name = data 210 | L_gradient_x = gradient_no_abs(L_high, "x", device='cpu', kernel='sobel') 211 | epsilon = 0.01*torch.ones_like(L_gradient_x) 212 | Denominator_x = torch.max(L_gradient_x, epsilon) 213 | imgs = Denominator_x 214 | img = imgs[1].numpy() 215 | sample(img, figure_size=(1,1), img_dim=400) -------------------------------------------------------------------------------- /models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | from base_layers import * 6 | 7 | class DecomNet(nn.Module): 8 | def __init__(self, filters=32, activation='lrelu'): 9 | super().__init__() 10 | self.conv_input = Conv2D(3, filters) 11 | # top path build Reflectance map 12 | self.maxpool_r1 = MaxPooling2D() 13 | self.conv_r1 = Conv2D(filters, filters*2) 14 | self.maxpool_r2 = MaxPooling2D() 15 | self.conv_r2 = Conv2D(filters*2, filters*4) 16 | self.deconv_r1 = ConvTranspose2D(filters*4, filters*2) 17 | self.concat_r1 = Concat() 18 | self.conv_r3 = Conv2D(filters*4, filters*2) 19 | self.deconv_r2 = ConvTranspose2D(filters*2, filters) 20 | self.concat_r2 = Concat() 21 | self.conv_r4 = Conv2D(filters*2, filters) 22 | self.conv_r5 = nn.Conv2d(filters, 3, kernel_size=3, padding=1) 23 | self.R_out = nn.Sigmoid() 24 | # bottom path build Illumination map 25 | self.conv_i1 = Conv2D(filters, filters) 26 | self.concat_i1 = Concat() 27 | self.conv_i2 = nn.Conv2d(filters*2, 1, kernel_size=3, padding=1) 28 | self.I_out = nn.Sigmoid() 29 | 30 | def forward(self, x): 31 | conv_input = self.conv_input(x) 32 | # build Reflectance map 33 | maxpool_r1 = self.maxpool_r1(conv_input) 34 | conv_r1 = self.conv_r1(maxpool_r1) 35 | maxpool_r2 = self.maxpool_r2(conv_r1) 36 | conv_r2 = self.conv_r2(maxpool_r2) 37 | deconv_r1 = self.deconv_r1(conv_r2) 38 | concat_r1 = self.concat_r1(conv_r1, deconv_r1) 39 | conv_r3 = self.conv_r3(concat_r1) 40 | deconv_r2 = self.deconv_r2(conv_r3) 41 | concat_r2 = self.concat_r2(conv_input, deconv_r2) 42 | conv_r4 = self.conv_r4(concat_r2) 43 | conv_r5 = self.conv_r5(conv_r4) 44 | R_out = self.R_out(conv_r5) 45 | 46 | # build Illumination map 47 | conv_i1 = self.conv_i1(conv_input) 48 | concat_i1 = self.concat_i1(conv_r4, conv_i1) 49 | conv_i2 = self.conv_i2(concat_i1) 50 | I_out = self.I_out(conv_i2) 51 | 52 | return R_out, I_out 53 | 54 | 55 | class IllumNet(nn.Module): 56 | def __init__(self, filters=32, activation='lrelu'): 57 | super().__init__() 58 | self.concat_input = Concat() 59 | # bottom path build Illumination map 60 | self.conv_i1 = Conv2D(2, filters) 61 | self.conv_i2 = Conv2D(filters, filters) 62 | self.conv_i3 = Conv2D(filters, filters) 63 | self.conv_i4 = nn.Conv2d(filters, 1, kernel_size=3, padding=1) 64 | self.I_out = nn.Sigmoid() 65 | 66 | def forward(self, I, ratio): 67 | with torch.no_grad(): 68 | ratio_map = torch.ones_like(I) * ratio 69 | concat_input = self.concat_input(I, ratio_map) 70 | # build Illumination map 71 | conv_i1 = self.conv_i1(concat_input) 72 | conv_i2 = self.conv_i2(conv_i1) 73 | conv_i3 = self.conv_i3(conv_i2) 74 | conv_i4 = self.conv_i4(conv_i3) 75 | I_out = self.I_out(conv_i4) 76 | 77 | return I_out 78 | 79 | 80 | class IllumNet_Custom(nn.Module): 81 | def __init__(self, filters=16, activation='lrelu', device='cuda'): 82 | super().__init__() 83 | self.concat_input = Concat() 84 | # Parameter 85 | self.Gauss = torch.as_tensor( 86 | np.array([[0.0947416, 0.118318, 0.0947416], 87 | [ 0.118318, 0.147761, 0.118318], 88 | [0.0947416, 0.118318, 0.0947416]]).astype(np.float32) 89 | ) 90 | self.Gauss_kernel = self.Gauss.expand(1, 1, 3, 3).to(device) 91 | self.w = nn.Parameter(torch.FloatTensor(1), requires_grad=True).to(device).data.fill_(0.72) 92 | self.sigma = nn.Parameter(torch.FloatTensor(1), requires_grad=True).to(device).data.fill_(2.0) 93 | 94 | 95 | # bottom path build Illumination map 96 | self.conv_input = Conv2D(2, filters) 97 | self.res_block = nn.Sequential( 98 | ResConv(filters, filters), 99 | ResConv(filters, filters), 100 | ResConv(filters, filters) 101 | ) 102 | # self.down1 = MaxPooling2D() 103 | # self.conv_2 = Conv2D(filters, filters*2) 104 | # self.down2 = MaxPooling2D() 105 | # self.conv_3 = Conv2D(filters*2, filters*4) 106 | # self.down3 = MaxPooling2D() 107 | # self.conv_4 = Conv2D(filters*4, filters*8) 108 | 109 | # self.d = nn.Dropout2d(0.5) 110 | 111 | # self.deconv_3 = ConvTranspose2D(filters*8, filters*4) 112 | # self.concat3 = Concat() 113 | # self.cbam3 = CBAM(filters*8) 114 | # self.deconv_2 = ConvTranspose2D(filters*8, filters*2) 115 | # self.concat2 = Concat() 116 | # self.cbam2 = CBAM(filters*4) 117 | # self.deconv_1 = ConvTranspose2D(filters*4, filters*1) 118 | # self.concat1 = Concat() 119 | # self.cbam1 = CBAM(filters*2) 120 | self.conv_out = nn.Conv2d(filters, 1, kernel_size=3, padding=1) 121 | 122 | self.I_out = nn.Sigmoid() 123 | 124 | def standard_illum_map(self, I, ratio=1, blur=False): 125 | self.w.clamp_(0.01, 0.99) 126 | self.sigma.clamp_(0.1, 10) 127 | # if blur: # low light image have much noisy 128 | # I = torch.nn.functional.conv2d(I, weight=self.Gauss_kernel, padding=1) 129 | I = torch.log(I + 1.) 130 | I_mean = torch.mean(I, dim=[2, 3], keepdim=True) 131 | I_std = torch.std(I, dim=[2, 3], keepdim=True) 132 | I_min = I_mean - self.sigma * I_std 133 | I_max = I_mean + self.sigma * I_std 134 | I_range = I_max - I_min 135 | I_out = torch.clamp((I - I_min) / I_range, min=0.0, max=1.0) 136 | # Transfer to gamma correction, center intensity is w 137 | I_out = I_out ** (-1.442695 * torch.log(self.w)) 138 | return I_out 139 | 140 | def set_parameter(self, w=None): 141 | if w is None: 142 | self.w.requires_grad = True 143 | else: 144 | self.w.data.fill_(w) 145 | self.w.requires_grad = False 146 | 147 | def get_parameter(self): 148 | if self.w.device.type == 'cuda': 149 | w = self.w.detach().cpu().numpy() 150 | sigma = self.sigma.detach().cpu().numpy() 151 | else: 152 | w = self.w.numpy() 153 | sigma = self.sigma.numpy() 154 | return w, sigma 155 | 156 | def forward(self, I, ratio): 157 | I_standard = self.standard_illum_map(I, ratio) 158 | concat_input = torch.cat([I, I_standard], dim=1) 159 | # build Illumination map 160 | conv_input = self.conv_input(concat_input) 161 | res_block = self.res_block(conv_input) 162 | # down1 = self.down1(conv_1) 163 | # conv_2 = self.conv_2(down1) 164 | # down2 = self.down2(conv_2) 165 | # conv_3 = self.conv_3(down2) 166 | # down3 = self.down3(conv_3) 167 | # conv_4 = self.conv_4(down3) 168 | # d = self.d(conv_4) 169 | # deconv_3 = self.deconv_3(d) 170 | 171 | # concat3 = self.concat3(conv_3, deconv_3) 172 | # cbam3 = self.cbam3(concat3) 173 | # deconv_2 = self.deconv_2(cbam3) 174 | 175 | # concat2 = self.concat2(conv_2, deconv_2) 176 | # cbam2 = self.cbam2(concat2) 177 | # deconv_1 = self.deconv_1(cbam2) 178 | 179 | # concat1 = self.concat1(conv_1, deconv_1) 180 | # cbam1 = self.cbam1(concat1) 181 | res_out = res_block + conv_input 182 | conv_out = self.conv_out(res_out) 183 | I_out = self.I_out(conv_out) 184 | 185 | return I_out, I_standard 186 | 187 | 188 | class RestoreNet_MSIA(nn.Module): 189 | def __init__(self, filters=16, activation='relu'): 190 | super().__init__() 191 | # Illumination Attention 192 | self.i_input = nn.Conv2d(1,1,kernel_size=3,padding=1) 193 | self.i_att = nn.Sigmoid() 194 | 195 | # Network 196 | self.conv1_1 = Conv2D(3, filters, activation) 197 | self.conv1_2 = Conv2D(filters, filters*2, activation) 198 | self.msia1 = MSIA(filters*2, activation) 199 | 200 | self.conv2_1 = Conv2D(filters*2, filters*4, activation) 201 | self.conv2_2 = Conv2D(filters*4, filters*4, activation) 202 | self.msia2 = MSIA(filters*4, activation) 203 | 204 | self.conv3_1 = Conv2D(filters*4, filters*8, activation) 205 | self.dropout = nn.Dropout2d(0.5) 206 | self.conv3_2 = Conv2D(filters*8, filters*4, activation) 207 | self.msia3 = MSIA(filters*4, activation) 208 | 209 | self.conv4_1 = Conv2D(filters*4, filters*2, activation) 210 | self.conv4_2 = Conv2D(filters*2, filters*2, activation) 211 | self.msia4 = MSIA(filters*2, activation) 212 | 213 | self.conv5_1 = Conv2D(filters*2, filters*1, activation) 214 | self.conv5_2 = nn.Conv2d(filters, 3, kernel_size=1, padding=0) 215 | self.out = nn.Sigmoid() 216 | 217 | def forward(self, R, I): 218 | # Illumination Attention 219 | i_input = self.i_input(I) 220 | i_att = self.i_att(i_input) 221 | 222 | # Network 223 | conv1 = self.conv1_1(R) 224 | conv1 = self.conv1_2(conv1) 225 | msia1 = self.msia1(conv1, i_att) 226 | 227 | conv2 = self.conv2_1(msia1) 228 | conv2 = self.conv2_2(conv2) 229 | msia2 = self.msia2(conv2, i_att) 230 | 231 | conv3 = self.conv3_1(msia2) 232 | conv3 = self.conv3_2(conv3) 233 | msia3 = self.msia3(conv3, i_att) 234 | 235 | conv4 = self.conv4_1(msia3) 236 | conv4 = self.conv4_2(conv4) 237 | msia4 = self.msia4(conv4, i_att) 238 | 239 | conv5 = self.conv5_1(msia4) 240 | conv5 = self.conv5_2(conv5) 241 | 242 | # out = self.out(conv5) 243 | out = conv5.clamp(min=0.0, max=1.0) 244 | return out 245 | 246 | 247 | class RestoreNet_Unet(nn.Module): 248 | def __init__(self, filters=32, activation='lrelu'): 249 | super().__init__() 250 | self.conv1_1 = Conv2D(4, filters) 251 | self.conv1_2 = Conv2D(filters, filters) 252 | self.pool1 = MaxPooling2D() 253 | 254 | self.conv2_1 = Conv2D(filters, filters*2) 255 | self.conv2_2 = Conv2D(filters*2, filters*2) 256 | self.pool2 = MaxPooling2D() 257 | 258 | self.conv3_1 = Conv2D(filters*2, filters*4) 259 | self.conv3_2 = Conv2D(filters*4, filters*4) 260 | self.pool3 = MaxPooling2D() 261 | 262 | self.conv4_1 = Conv2D(filters*4, filters*8) 263 | self.conv4_2 = Conv2D(filters*8, filters*8) 264 | self.pool4 = MaxPooling2D() 265 | 266 | self.conv5_1 = Conv2D(filters*8, filters*16) 267 | self.conv5_2 = Conv2D(filters*16, filters*16) 268 | self.dropout = nn.Dropout2d(0.5) 269 | 270 | self.upv6 = ConvTranspose2D(filters*16, filters*8) 271 | self.concat6 = Concat() 272 | self.conv6_1 = Conv2D(filters*16, filters*8) 273 | self.conv6_2 = Conv2D(filters*8, filters*8) 274 | 275 | self.upv7 = ConvTranspose2D(filters*8, filters*4) 276 | self.concat7 = Concat() 277 | self.conv7_1 = Conv2D(filters*8, filters*4) 278 | self.conv7_2 = Conv2D(filters*4, filters*4) 279 | 280 | self.upv8 = ConvTranspose2D(filters*4, filters*2) 281 | self.concat8 = Concat() 282 | self.conv8_1 = Conv2D(filters*4, filters*2) 283 | self.conv8_2 = Conv2D(filters*2, filters*2) 284 | 285 | self.upv9 = ConvTranspose2D(filters*2, filters) 286 | self.concat9 = Concat() 287 | self.conv9_1 = Conv2D(filters*2, filters) 288 | self.conv9_2 = Conv2D(filters, filters) 289 | 290 | self.conv10_1 = nn.Conv2d(filters, 3, kernel_size=1, stride=1) 291 | self.out = nn.Sigmoid() 292 | 293 | def forward(self, R, I): 294 | x = torch.cat([R, I], dim=1) 295 | conv1 = self.conv1_1(x) 296 | conv1 = self.conv1_2(conv1) 297 | pool1 = self.pool1(conv1) 298 | 299 | conv2 = self.conv2_1(pool1) 300 | conv2 = self.conv2_2(conv2) 301 | pool2 = self.pool1(conv2) 302 | 303 | conv3 = self.conv3_1(pool2) 304 | conv3 = self.conv3_2(conv3) 305 | pool3 = self.pool1(conv3) 306 | 307 | conv4 = self.conv4_1(pool3) 308 | conv4 = self.conv4_2(conv4) 309 | pool4 = self.pool1(conv4) 310 | 311 | conv5 = self.conv5_1(pool4) 312 | conv5 = self.conv5_2(conv5) 313 | 314 | # d = self.dropout(conv5) 315 | up6 = self.upv6(conv5) 316 | up6 = self.concat6(conv4, up6) 317 | conv6 = self.conv6_1(up6) 318 | conv6 = self.conv6_2(conv6) 319 | 320 | up7 = self.upv7(conv6) 321 | up7 = self.concat7(conv3, up7) 322 | conv7 = self.conv7_1(up7) 323 | conv7 = self.conv7_2(conv7) 324 | 325 | up8 = self.upv8(conv7) 326 | up8 = self.concat8(conv2, up8) 327 | conv8 = self.conv8_1(up8) 328 | conv8 = self.conv8_2(conv8) 329 | 330 | up9 = self.upv9(conv8) 331 | up9 = self.concat9(conv1, up9) 332 | conv9 = self.conv9_1(up9) 333 | conv9 = self.conv9_2(conv9) 334 | 335 | conv10 = self.conv10_1(conv9) 336 | out = self.out(conv10) 337 | return out 338 | 339 | class KinD_noDecom(nn.Module): 340 | def __init__(self, filters=32, activation='lrelu'): 341 | super().__init__() 342 | # self.decom_net = DecomNet() 343 | self.restore_net = RestoreNet_Unet() 344 | self.illum_net = IllumNet() 345 | 346 | def forward(self, R, I, ratio): 347 | I_final = self.illum_net(I, ratio) 348 | R_final = self.restore_net(R, I) 349 | I_final_3 = torch.cat([I_final, I_final, I_final], dim=1) 350 | output = I_final_3 * R_final 351 | return R_final, I_final, output 352 | 353 | 354 | class KinD(nn.Module): 355 | def __init__(self, filters=32, activation='lrelu'): 356 | super().__init__() 357 | self.decom_net = DecomNet() 358 | self.restore_net = RestoreNet_Unet() 359 | self.illum_net = IllumNet() 360 | self.KinD_noDecom = KinD_noDecom() 361 | self.KinD_noDecom.restore_net = self.restore_net 362 | self.KinD_noDecom.illum_net = self.illum_net 363 | 364 | def forward(self, L, ratio): 365 | R, I = self.decom_net(L) 366 | R_final, I_final, output = self.KinD_noDecom(R, I, ratio) 367 | # I_final = self.illum_net(I, ratio) 368 | # R_final = self.restore_net(R, I) 369 | # I_final_3 = torch.cat([I_final, I_final, I_final], dim=1) 370 | # output = I_final_3 * R_final 371 | return R_final, I_final, output 372 | 373 | class KinD_plus(nn.Module): 374 | def __init__(self, filters=32, activation='lrelu'): 375 | super().__init__() 376 | self.decom_net = DecomNet() 377 | self.restore_net = RestoreNet_MSIA() 378 | self.illum_net = IllumNet_Custom() 379 | 380 | def forward(self, L, ratio): 381 | R, I = self.decom_net(L) 382 | # R_final, I_final, output = self.KinD_noDecom(R, I, ratio) 383 | I_final, I_standard = self.illum_net(I, ratio) 384 | R_final = self.restore_net(R, I) 385 | I_final_3 = torch.cat([I_final, I_final, I_final], dim=1) 386 | output = I_final_3 * R_final 387 | return R_final, I_final, output -------------------------------------------------------------------------------- /pytorch_ssim/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torch.autograd import Variable 4 | import numpy as np 5 | from math import exp 6 | 7 | def gaussian(window_size, sigma): 8 | gauss = torch.Tensor([exp(-(x - window_size//2)**2/float(2*sigma**2)) for x in range(window_size)]) 9 | return gauss/gauss.sum() 10 | 11 | def create_window(window_size, channel): 12 | _1D_window = gaussian(window_size, 1.5).unsqueeze(1) 13 | _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0) 14 | window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous()) 15 | return window 16 | 17 | def _ssim(img1, img2, window, window_size, channel, size_average = True): 18 | mu1 = F.conv2d(img1, window, padding = window_size//2, groups = channel) 19 | mu2 = F.conv2d(img2, window, padding = window_size//2, groups = channel) 20 | 21 | mu1_sq = mu1.pow(2) 22 | mu2_sq = mu2.pow(2) 23 | mu1_mu2 = mu1*mu2 24 | 25 | sigma1_sq = F.conv2d(img1*img1, window, padding = window_size//2, groups = channel) - mu1_sq 26 | sigma2_sq = F.conv2d(img2*img2, window, padding = window_size//2, groups = channel) - mu2_sq 27 | sigma12 = F.conv2d(img1*img2, window, padding = window_size//2, groups = channel) - mu1_mu2 28 | 29 | C1 = 0.01**2 30 | C2 = 0.03**2 31 | 32 | ssim_map = ((2*mu1_mu2 + C1)*(2*sigma12 + C2))/((mu1_sq + mu2_sq + C1)*(sigma1_sq + sigma2_sq + C2)) 33 | 34 | if size_average: 35 | return ssim_map.mean() 36 | else: 37 | return ssim_map.mean(1).mean(1).mean(1) 38 | 39 | class SSIM(torch.nn.Module): 40 | def __init__(self, window_size = 11, size_average = True): 41 | super(SSIM, self).__init__() 42 | self.window_size = window_size 43 | self.size_average = size_average 44 | self.channel = 1 45 | self.window = create_window(window_size, self.channel) 46 | 47 | def forward(self, img1, img2): 48 | (_, channel, _, _) = img1.size() 49 | 50 | if channel == self.channel and self.window.data.type() == img1.data.type(): 51 | window = self.window 52 | else: 53 | window = create_window(self.window_size, channel) 54 | 55 | if img1.is_cuda: 56 | window = window.cuda(img1.get_device()) 57 | window = window.type_as(img1) 58 | 59 | self.window = window 60 | self.channel = channel 61 | 62 | 63 | return _ssim(img1, img2, window, self.window_size, channel, self.size_average) 64 | 65 | def ssim(img1, img2, window_size = 11, size_average = True): 66 | (_, channel, _, _) = img1.size() 67 | window = create_window(window_size, channel) 68 | 69 | if img1.is_cuda: 70 | window = window.cuda(img1.get_device()) 71 | window = window.type_as(img1) 72 | 73 | return _ssim(img1, img2, window, window_size, channel, size_average) -------------------------------------------------------------------------------- /restore_MSIA_trainer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | import time 6 | import yaml 7 | import sys 8 | from tqdm import tqdm 9 | from torch.optim import lr_scheduler 10 | from torchvision.utils import make_grid 11 | from torchvision import transforms 12 | from torchsummary import summary 13 | from base_trainer import BaseTrainer 14 | from losses import * 15 | from models import * 16 | from base_parser import BaseParser 17 | from dataloader import * 18 | 19 | class Restore_Trainer(BaseTrainer): 20 | def __init__(self, config, dataloader, criterion, model, 21 | dataloader_test=None, decom_net=None): 22 | super().__init__(config, dataloader, criterion, model, dataloader_test) 23 | log(f'Using device {self.device}') 24 | self.decom_net = decom_net 25 | self.decom_net.to(device=self.device) 26 | 27 | def train(self): 28 | # print(self.model) 29 | summary(self.model, input_size=[(3, 256, 256), (1,256,256)]) 30 | 31 | optimizer = torch.optim.Adam(self.model.parameters(), lr=self.learning_rate) 32 | scheduler = lr_scheduler.ExponentialLR(optimizer, gamma=0.99426) #0.977237, 0.986233 33 | try: 34 | for iter in range(self.epochs): 35 | epoch_loss = 0 36 | idx = 0 37 | hook_number = -1 38 | iter_start_time = time.time() 39 | # with tqdm(total=self.steps_per_epoch) as pbar: 40 | for L_low_tensor, L_high_tensor, name in self.dataloader: 41 | optimizer.zero_grad() 42 | L_low = L_low_tensor.to(self.device) 43 | L_high = L_high_tensor.to(self.device) 44 | 45 | with torch.no_grad(): 46 | R_low, I_low = self.decom_net(L_low) 47 | R_high, I_high = self.decom_net(L_high) 48 | 49 | R_restore = self.model(R_low, I_low) 50 | 51 | if idx % self.print_frequency == 0: 52 | hook_number = iter 53 | loss = self.loss_fn(R_restore, R_high, hook=hook_number) 54 | hook_number = -1 55 | if idx % 8 == 0: 56 | log(f"iter: {iter}_{idx}\taverage_loss: {loss.item():.6f}") 57 | loss.backward() 58 | optimizer.step() 59 | idx += 1 60 | # pbar.update(1) 61 | # pbar.set_postfix({'loss':loss.item()}) 62 | 63 | if iter % self.print_frequency == 0: 64 | self.test(iter, plot_dir='./images/samples-restore-MSIA') 65 | 66 | if iter % self.save_frequency == 0: 67 | torch.save(self.model.state_dict(), f'./weights/restore_net_MSIA_{iter//100}.pth') 68 | log("Weight Has saved as 'restore_net.pth'") 69 | 70 | scheduler.step() 71 | iter_end_time = time.time() 72 | log(f"Time taken: {iter_end_time - iter_start_time:.3f} seconds\t lr={scheduler.get_lr()[0]:.6f}") 73 | # print("End of epochs {.0f}, Time taken: {.3f}, average loss: {.5f}".format( 74 | # idx, iter_end_time - iter_start_time, epoch_loss / idx)) 75 | 76 | except KeyboardInterrupt: 77 | torch.save(self.model.state_dict(), './weights/INTERRUPTED_restore.pth') 78 | print('Saved interrupt') 79 | try: 80 | sys.exit(0) 81 | except SystemExit: 82 | os._exit(0) 83 | 84 | @no_grad 85 | def test(self, epoch=-1, plot_dir='./images/samples-restore'): 86 | self.model.eval() 87 | for L_low_tensor, L_high_tensor, name in self.dataloader_test: 88 | L_low = L_low_tensor.to(self.device) 89 | L_high = L_high_tensor.to(self.device) 90 | 91 | R_low, I_low = self.decom_net(L_low) 92 | R_high, I_high = self.decom_net(L_high) 93 | 94 | R_restore = self.model(R_low, I_low) 95 | 96 | R_restore_np = R_restore.detach().cpu().numpy()[0] 97 | I_low_np = I_low.detach().cpu().numpy()[0] 98 | R_low_np = R_low.detach().cpu().numpy()[0] 99 | R_high_np = R_high.detach().cpu().numpy()[0] 100 | sample_imgs = np.concatenate( (I_low_np, R_low_np, R_restore_np, R_high_np), axis=0 ) 101 | 102 | filepath = os.path.join(plot_dir, f'{name[0]}_epoch_{epoch//100}.png') 103 | split_point = [0, 1, 4, 7, 10] 104 | img_dim = I_low_np.shape[1:] 105 | sample(sample_imgs, split=split_point, figure_size=(2, 2), 106 | img_dim=img_dim, path=filepath, num=epoch) 107 | 108 | if __name__ == "__main__": 109 | criterion = Restore_Loss() 110 | model = RestoreNet_MSIA() 111 | decom_net = DecomNet() 112 | 113 | parser = BaseParser() 114 | args = parser.parse() 115 | 116 | with open(args.config) as f: 117 | config = yaml.load(f) 118 | args.checkpoint = True 119 | 120 | if args.checkpoint is not None: 121 | if config['noDecom'] is False: 122 | decom_net = load_weights(decom_net, path='./weights/decom_net.pth') 123 | log('DecomNet loaded from decom_net.pth') 124 | model = load_weights(model, path='./weights/restore_net_MSIA_1.pth') 125 | log('Model loaded from restore_net_MSIA.pth') 126 | 127 | root_path_train = r'C:\DeepLearning\KinD_plus-master\LOLdataset\our485' 128 | root_path_test = r'C:\DeepLearning\KinD_plus-master\LOLdataset\eval15' 129 | list_path_train = build_LOLDataset_list_txt(root_path_train) 130 | list_path_test = build_LOLDataset_list_txt(root_path_test) 131 | # list_path_train = os.path.join(root_path_train, 'pair_list.csv') 132 | # list_path_test = os.path.join(root_path_test, 'pair_list.csv') 133 | 134 | log("Buliding LOL Dataset...") 135 | # transform = transforms.Compose([transforms.ToTensor()]) 136 | dst_train = LOLDataset(root_path_train, list_path_train, 137 | crop_size=config['length'], to_RAM=True) 138 | dst_test = LOLDataset(root_path_test, list_path_test, 139 | crop_size=config['length'], to_RAM=True, training=False) 140 | 141 | train_loader = DataLoader(dst_train, batch_size = config['batch_size'], shuffle=True) 142 | # train_loader = data_prefetcher(train_loader) 143 | test_loader = DataLoader(dst_test, batch_size=1) 144 | 145 | trainer = Restore_Trainer(config, train_loader, criterion, model, 146 | dataloader_test=test_loader, decom_net=decom_net) 147 | 148 | if args.mode == 'train': 149 | trainer.train() 150 | else: 151 | trainer.test() -------------------------------------------------------------------------------- /restore_trainer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | import time 6 | import yaml 7 | import sys 8 | from tqdm import tqdm 9 | from torch.optim import lr_scheduler 10 | from torchvision.utils import make_grid 11 | from torchvision import transforms 12 | from torchsummary import summary 13 | from base_trainer import BaseTrainer 14 | from losses import * 15 | from models import * 16 | from base_parser import BaseParser 17 | from dataloader import * 18 | 19 | class Restore_Trainer(BaseTrainer): 20 | def __init__(self, config, dataloader, criterion, model, 21 | dataloader_test=None, decom_net=None): 22 | super().__init__(config, dataloader, criterion, model, dataloader_test) 23 | log(f'Using device {self.device}') 24 | self.decom_net = decom_net 25 | self.decom_net.to(device=self.device) 26 | 27 | def train(self): 28 | # print(self.model) 29 | summary(self.model, input_size=[(3, 384, 384), (1,384,384)]) 30 | 31 | optimizer = torch.optim.Adam(self.model.parameters(), lr=self.learning_rate) 32 | scheduler = lr_scheduler.ExponentialLR(optimizer, gamma=0.986233) #0.977237, 0.986233 33 | try: 34 | for iter in range(self.epochs): 35 | epoch_loss = 0 36 | idx = 0 37 | hook_number = -1 38 | iter_start_time = time.time() 39 | # with tqdm(total=self.steps_per_epoch) as pbar: 40 | if self.noDecom is True: 41 | for R_low_tensor, I_low_tensor, R_high_tensor, I_high_tensor, name in self.dataloader: 42 | optimizer.zero_grad() 43 | I_low = I_low_tensor.to(self.device) 44 | R_low = R_low_tensor.to(self.device) 45 | R_high = R_high_tensor.to(self.device) 46 | R_restore = self.model(R_low, I_low) 47 | 48 | if idx % self.print_frequency == 0: 49 | hook_number = iter 50 | loss = self.loss_fn(R_restore, R_high, hook=hook_number) 51 | hook_number = -1 52 | if idx % 30 == 0: 53 | log(f"iter: {iter}_{idx}\taverage_loss: {loss.item():.6f}") 54 | loss.backward() 55 | optimizer.step() 56 | idx += 1 57 | # pbar.update(1) 58 | # pbar.set_postfix({'loss':loss.item()}) 59 | 60 | else: 61 | for L_low_tensor, L_high_tensor, name in self.dataloader: 62 | optimizer.zero_grad() 63 | L_low = L_low_tensor.to(self.device) 64 | L_high = L_high_tensor.to(self.device) 65 | 66 | with torch.no_grad(): 67 | R_low, I_low = self.decom_net(L_low) 68 | R_high, I_high = self.decom_net(L_high) 69 | 70 | R_restore = self.model(R_low, I_low) 71 | 72 | if idx % self.print_frequency == 0: 73 | hook_number = iter 74 | loss = self.loss_fn(R_restore, R_high, hook=hook_number) 75 | hook_number = -1 76 | if idx % 30 == 0: 77 | log(f"iter: {iter}_{idx}\taverage_loss: {loss.item():.6f}") 78 | loss.backward() 79 | optimizer.step() 80 | idx += 1 81 | # pbar.update(1) 82 | # pbar.set_postfix({'loss':loss.item()}) 83 | 84 | if iter % self.print_frequency == 0: 85 | self.test(iter, plot_dir='./images/samples-restore') 86 | 87 | if iter % self.save_frequency == 0: 88 | torch.save(self.model.state_dict(), './weights/restore_net.pth') 89 | log("Weight Has saved as 'restore_net.pth'") 90 | 91 | scheduler.step() 92 | iter_end_time = time.time() 93 | log(f"Time taken: {iter_end_time - iter_start_time:.3f} seconds\t lr={scheduler.get_lr()[0]:.6f}") 94 | # print("End of epochs {.0f}, Time taken: {.3f}, average loss: {.5f}".format( 95 | # idx, iter_end_time - iter_start_time, epoch_loss / idx)) 96 | 97 | except KeyboardInterrupt: 98 | torch.save(self.model.state_dict(), './weights/INTERRUPTED_restore.pth') 99 | print('Saved interrupt') 100 | try: 101 | sys.exit(0) 102 | except SystemExit: 103 | os._exit(0) 104 | 105 | @no_grad 106 | def test(self, epoch=-1, plot_dir='./images/samples-restore'): 107 | self.model.eval() 108 | if self.noDecom: 109 | for R_low_tensor, I_low_tensor, R_high_tensor, I_high_tensor, name in self.dataloader_test: 110 | I_low = I_low_tensor.to(self.device) 111 | R_low = R_low_tensor.to(self.device) 112 | R_restore = self.model(R_low, I_low) 113 | 114 | R_restore_np = R_restore.detach().cpu().numpy()[0] 115 | I_low_np = I_low_tensor.numpy()[0] 116 | R_low_np = R_low_tensor.numpy()[0] 117 | R_high_np = R_high_tensor.numpy()[0] 118 | sample_imgs = np.concatenate( (I_low_np, R_low_np, R_restore_np, R_high_np), axis=0 ) 119 | 120 | filepath = os.path.join(plot_dir, f'{name[0]}_epoch_{epoch}.png') 121 | split_point = [0, 1, 4, 7, 10] 122 | img_dim = I_low_np.shape[1:] 123 | sample(sample_imgs, split=split_point, figure_size=(2, 2), 124 | img_dim=img_dim, path=filepath, num=epoch) 125 | else: 126 | for L_low_tensor, L_high_tensor, name in self.dataloader_test: 127 | L_low = L_low_tensor.to(self.device) 128 | L_high = L_high_tensor.to(self.device) 129 | 130 | R_low, I_low = self.decom_net(L_low) 131 | R_high, I_high = self.decom_net(L_high) 132 | 133 | R_restore = self.model(R_low, I_low) 134 | 135 | R_restore_np = R_restore.detach().cpu().numpy()[0] 136 | I_low_np = I_low.detach().cpu().numpy()[0] 137 | R_low_np = R_low.detach().cpu().numpy()[0] 138 | R_high_np = R_high.detach().cpu().numpy()[0] 139 | sample_imgs = np.concatenate( (I_low_np, R_low_np, R_restore_np, R_high_np), axis=0 ) 140 | 141 | filepath = os.path.join(plot_dir, f'{name[0]}_epoch_{epoch}.png') 142 | split_point = [0, 1, 4, 7, 10] 143 | img_dim = I_low_np.shape[1:] 144 | sample(sample_imgs, split=split_point, figure_size=(2, 2), 145 | img_dim=img_dim, path=filepath, num=epoch) 146 | 147 | if __name__ == "__main__": 148 | criterion = Restore_Loss() 149 | model = RestoreNet_Unet() 150 | decom_net = DecomNet() 151 | 152 | parser = BaseParser() 153 | args = parser.parse() 154 | 155 | with open(args.config) as f: 156 | config = yaml.load(f) 157 | args.checkpoint = True 158 | 159 | if args.checkpoint is not None: 160 | if config['noDecom'] is False: 161 | pretrain_decom = torch.load('./weights/decom_net_test3.pth') 162 | decom_net.load_state_dict(pretrain_decom) 163 | log('DecomNet loaded from decom_net.pth') 164 | pretrain = torch.load('./weights/restore_net.pth') 165 | model.load_state_dict(pretrain) 166 | print('Model loaded from restore_net.pth') 167 | 168 | if config['noDecom'] is True: 169 | root_path_train = r'H:\datasets\Low-Light Dataset\LOLdataset_decom\our485' 170 | root_path_test = r'H:\datasets\Low-Light Dataset\LOLdataset_decom\eval15' 171 | list_path_train = build_LOLDataset_Decom_list_txt(root_path_train) 172 | list_path_test = build_LOLDataset_Decom_list_txt(root_path_test) 173 | # list_path_train = os.path.join(root_path_train, 'pair_list.csv') 174 | # list_path_test = os.path.join(root_path_test, 'pair_list.csv') 175 | 176 | log("Buliding LOL Dataset...") 177 | # transform = transforms.Compose([transforms.ToTensor()]) 178 | dst_train = LOLDataset_Decom(root_path_train, list_path_train, 179 | crop_size=config['length'], to_RAM=True) 180 | dst_test = LOLDataset_Decom(root_path_test, list_path_test, 181 | crop_size=config['length'], to_RAM=True, training=False) 182 | 183 | train_loader = DataLoader(dst_train, batch_size = config['batch_size'], shuffle=True) 184 | test_loader = DataLoader(dst_test, batch_size=1) 185 | 186 | else: 187 | root_path_train = r'C:\DeepLearning\KinD_plus-master\LOLdataset\our485' 188 | root_path_test = r'C:\DeepLearning\KinD_plus-master\LOLdataset\eval15' 189 | list_path_train = build_LOLDataset_list_txt(root_path_train) 190 | list_path_test = build_LOLDataset_list_txt(root_path_test) 191 | # list_path_train = os.path.join(root_path_train, 'pair_list.csv') 192 | # list_path_test = os.path.join(root_path_test, 'pair_list.csv') 193 | 194 | log("Buliding LOL Dataset...") 195 | # transform = transforms.Compose([transforms.ToTensor()]) 196 | dst_train = LOLDataset(root_path_train, list_path_train, 197 | crop_size=config['length'], to_RAM=True) 198 | dst_test = LOLDataset(root_path_test, list_path_test, 199 | crop_size=config['length'], to_RAM=True, training=False) 200 | 201 | train_loader = DataLoader(dst_train, batch_size = config['batch_size'], shuffle=True) 202 | test_loader = DataLoader(dst_test, batch_size=1) 203 | 204 | trainer = Restore_Trainer(config, train_loader, criterion, model, 205 | dataloader_test=test_loader, decom_net=decom_net) 206 | 207 | if args.mode == 'train': 208 | trainer.train() 209 | else: 210 | trainer.test() -------------------------------------------------------------------------------- /test_your_pictures.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | import time 6 | import yaml 7 | import sys 8 | from torch.optim import lr_scheduler 9 | from torchvision.utils import make_grid 10 | from torchvision import transforms 11 | from torchsummary import summary 12 | from base_trainer import BaseTrainer 13 | from losses import * 14 | from models import * 15 | from base_parser import BaseParser 16 | from dataloader import * 17 | 18 | class KinD_Player(BaseTrainer): 19 | def __init__(self, model, dataloader_test, plot_more=False): 20 | self.dataloader_test = dataloader_test 21 | self.model = model 22 | self.plot_more = plot_more 23 | self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 24 | self.model.to(device=self.device) 25 | 26 | @no_grad 27 | def test(self, target_b=0.70, plot_dir='./images/samples-KinD'): 28 | self.model.eval() 29 | self.model.to(device=self.device) 30 | for L_low_tensor, name in self.dataloader_test: 31 | L_low = L_low_tensor.to(self.device) 32 | 33 | if self.plot_more: 34 | # Use DecomNet to decomposite Reflectance Map and Illumation Map 35 | R_low, I_low = self.model.decom_net(L_low) 36 | # Compute brightness ratio 37 | bright_low = torch.mean(I_low) 38 | else: 39 | bright_low = torch.mean(L_low) 40 | 41 | bright_high = torch.ones_like(bright_low) * target_b + 0.5 * bright_low 42 | ratio = torch.div(bright_high, bright_low) 43 | log(f"Brightness: {bright_high:.4f}\tIllumation Magnification: {ratio.item():.3f}") 44 | 45 | R_final, I_final, output_final = self.model(L_low, ratio) 46 | 47 | output_final_np = output_final.detach().cpu().numpy()[0] 48 | L_low_np = L_low_tensor.numpy()[0] 49 | # Only plot result 50 | filepath = os.path.join(plot_dir, f'{name[0]}.png') 51 | split_point = [0, 3] 52 | img_dim = L_low_np.shape[1:] 53 | sample(output_final_np, split=split_point, figure_size=(1, 1), 54 | img_dim=img_dim, path=filepath) 55 | 56 | if self.plot_more: 57 | R_final_np = R_final.detach().cpu().numpy()[0] 58 | I_final_np = I_final.detach().cpu().numpy()[0] 59 | R_low_np = R_low.detach().cpu().numpy()[0] 60 | I_low_np = I_low.detach().cpu().numpy()[0] 61 | 62 | sample_imgs = np.concatenate( (R_low_np, I_low_np, L_low_np, 63 | R_final_np, I_final_np, output_final_np), axis=0 ) 64 | filepath = os.path.join(plot_dir, f'{name[0]}_extra.png') 65 | split_point = [0, 3, 4, 7, 10, 11, 14] 66 | img_dim = L_low_np.shape[1:] 67 | sample(sample_imgs, split=split_point, figure_size=(2, 3), 68 | img_dim=img_dim, path=filepath) 69 | 70 | 71 | class TestParser(BaseParser): 72 | def parse(self): 73 | self.parser.add_argument("-p", "--plot_more", default=True, 74 | help="Plot intermediate variables. such as R_images and I_images") 75 | self.parser.add_argument("-c", "--checkpoint", default="./weights/", 76 | help="Path of checkpoints") 77 | self.parser.add_argument("-i", "--input_dir", default="./images/inputs/", 78 | help="Path of input pictures") 79 | self.parser.add_argument("-o", "--output_dir", default="./images/outputs/", 80 | help="Path of output pictures") 81 | self.parser.add_argument("-b", "--b_target", default=0.75, help="Target brightness") 82 | # self.parser.add_argument("-u", "--use_gpu", default=True, 83 | # help="If you want to use GPU to accelerate") 84 | return self.parser.parse_args() 85 | 86 | 87 | if __name__ == "__main__": 88 | model = KinD() 89 | parser = TestParser() 90 | args = parser.parse() 91 | 92 | input_dir = args.input_dir 93 | output_dir = args.output_dir 94 | plot_more = args.plot_more 95 | checkpoint = args.checkpoint 96 | decom_net_dir = os.path.join(checkpoint, "decom_net.pth") 97 | restore_net_dir = os.path.join(checkpoint, "restore_net.pth") 98 | illum_net_dir = os.path.join(checkpoint, "illum_net.pth") 99 | 100 | pretrain_decom = torch.load(decom_net_dir) 101 | model.decom_net.load_state_dict(pretrain_decom) 102 | log('Model loaded from decom_net.pth') 103 | pretrain_resotre = torch.load(restore_net_dir) 104 | model.restore_net.load_state_dict(pretrain_resotre) 105 | log('Model loaded from restore_net.pth') 106 | pretrain_illum = torch.load(illum_net_dir) 107 | model.illum_net.load_state_dict(pretrain_illum) 108 | log('Model loaded from illum_net.pth') 109 | 110 | log("Buliding Dataset...") 111 | dst = CustomDataset(input_dir) 112 | log(f"There are {len(dst)} images in the input direction...") 113 | dataloader = DataLoader(dst, batch_size=1) 114 | 115 | KinD = KinD_Player(model, dataloader, plot_more=plot_more) 116 | 117 | KinD.test(plot_dir=output_dir, target_b=args.b_target) -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import random 4 | import matplotlib.pyplot as plt 5 | from PIL import Image 6 | import cv2 7 | import collections 8 | import torch 9 | import torchvision 10 | import shutil 11 | import time 12 | 13 | 14 | def log(string): 15 | print(time.strftime('%H:%M:%S'), ">> ", string) 16 | 17 | def data_augmentation(image, mode): 18 | if mode == 0: 19 | # original 20 | return image 21 | elif mode == 1: 22 | # flip up and down 23 | return np.flipud(image) 24 | elif mode == 2: 25 | # rotate counterwise 90 degree 26 | return np.rot90(image) 27 | elif mode == 3: 28 | # rotate 90 degree and flip up and down 29 | image = np.rot90(image) 30 | return np.flipud(image) 31 | elif mode == 4: 32 | # rotate 180 degree 33 | return np.rot90(image, k=2) 34 | elif mode == 5: 35 | # rotate 180 degree and flip 36 | image = np.rot90(image, k=2) 37 | return np.flipud(image) 38 | elif mode == 6: 39 | # rotate 270 degree 40 | return np.rot90(image, k=3) 41 | elif mode == 7: 42 | # rotate 270 degree and flip 43 | image = np.rot90(image, k=3) 44 | return np.flipud(image) 45 | 46 | # 作为装饰器函数 47 | def no_grad(fn): 48 | with torch.no_grad(): 49 | def transfer(*args,**kwargs): 50 | fn(*args,**kwargs) 51 | return fn 52 | 53 | 54 | def load_weights(model, path): 55 | pretrained_dict=torch.load(path) 56 | model_dict=model.state_dict() 57 | # 1. filter out unnecessary keys 58 | pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict} 59 | # 2. overwrite entries in the existing state dict 60 | model_dict.update(pretrained_dict) 61 | model.load_state_dict(model_dict) 62 | return model 63 | 64 | 65 | class data_prefetcher(): 66 | def __init__(self, loader): 67 | self.loader = iter(loader) 68 | self.stream = torch.cuda.Stream() 69 | self.preload() 70 | 71 | def preload(self): 72 | try: 73 | self.next_low, self.next_high, self.next_name = next(self.loader) 74 | except StopIteration: 75 | self.next_low = None 76 | self.next_high = None 77 | self.next_name = None 78 | return 79 | with torch.cuda.stream(self.stream): 80 | self.next_low = self.next_low.cuda(non_blocking=True) 81 | self.next_high = self.next_high.cuda(non_blocking=True) 82 | 83 | def __iter__(self): 84 | return self 85 | 86 | def __next__(self): 87 | torch.cuda.current_stream().wait_stream(self.stream) 88 | low = self.next_low 89 | high = self.next_high 90 | name = self.next_name 91 | self.preload() 92 | return low, high, name 93 | 94 | # def rgb2hsv(img): 95 | # if torch.is_tensor: 96 | # log(f'Image tensor size is {img.size()}') 97 | # else: 98 | # log("This Function can only deal PyTorch Tensor!") 99 | # return img 100 | # r, g, b = img.split(1, 0) 101 | # tensor_max = torch.max(torch.max(r, g), b) 102 | # tensor_min = torch.min(torch.min(r, g), b) 103 | # m = tensor_max-tensor_min 104 | # if tensor_max == tensor_min: 105 | # h = 0 106 | # elif tensor_max == r: 107 | # if g >= b: 108 | # h = ((g-b)/m)*60 109 | # else: 110 | # h = ((g-b)/m)*60 + 360 111 | # elif tensor_max == g: 112 | # h = ((b-r)/m)*60 + 120 113 | # elif tensor_max == b: 114 | # h = ((r-g)/m)*60 + 240 115 | # if tensor_max == 0: 116 | # s = 0 117 | # else: 118 | # s = m/tensor_max 119 | # v = tensor_max 120 | # return h, s, v 121 | 122 | def standard_illum(I, dynamic=2, w=0.5, gamma=None, blur=False): 123 | sigma = dynamic 124 | if torch.is_tensor(I): 125 | # I = torch.log(I + 1.) 126 | if blur: 127 | Gauss = torch.as_tensor( 128 | np.array([[0.0947416, 0.118318, 0.0947416], 129 | [ 0.118318, 0.147761, 0.118318], 130 | [0.0947416, 0.118318, 0.0947416]]).astype(np.float32) 131 | ).to(I.device) 132 | channels = I.size()[1] 133 | Gauss_kernel = Gauss.expand(channels, channels, 3, 3) 134 | I = torch.nn.functional.conv2d(I, weight=Gauss_kernel, padding=1) 135 | I_mean = torch.mean(I, dim=[2, 3], keepdim=True) 136 | I_std = torch.std(I, dim=[2, 3], keepdim=True) 137 | # I_max = torch.nn.AdaptiveMaxPool2d((1, 1))(I) 138 | # I_min = 1 - torch.nn.AdaptiveMaxPool2d((1, 1))(1-I) 139 | I_min = I_mean - sigma * I_std 140 | I_max = I_mean + sigma * I_std 141 | I_range = I_max - I_min 142 | I_out = torch.clamp((I - I_min) / I_range, min=0.0, max=1.0) 143 | # if gamma is not None: 144 | # return I**gamma 145 | w = torch.as_tensor(np.array(w).astype(np.float32)).to(I.device) 146 | I_out = I_out.pow(-1.442695 * torch.log(w)) 147 | print((-1.442695 * torch.log(w))) 148 | 149 | else: 150 | I = np.log(I + 1.) 151 | I_mean = np.mean(I) 152 | I_std = np.std(I) 153 | I_min = I_mean - sigma * I_std 154 | I_max = I_mean + sigma * I_std 155 | I_range = I_max - I_min 156 | I_out = np.clip((I - I_min) / I_range, 0.0, 1.0) 157 | 158 | return I_out 159 | 160 | 161 | def sample(imgs, split=None ,figure_size=(2, 3), img_dim=(400, 600), path=None, num=0): 162 | if type(img_dim) is int: 163 | img_dim = (img_dim, img_dim) 164 | img_dim = tuple(img_dim) 165 | if len(img_dim) == 1: 166 | h_dim = img_dim 167 | w_dim = img_dim 168 | elif len(img_dim) == 2: 169 | h_dim, w_dim = img_dim 170 | h, w = figure_size 171 | if split is None: 172 | num_of_imgs = figure_size[0] * figure_size[1] 173 | gap = len(imgs) // num_of_imgs 174 | split = list(range(0, len(imgs)+1, gap)) 175 | figure = np.zeros((h_dim*h, w_dim*w, 3)) 176 | for i in range(h): 177 | for j in range(w): 178 | idx = i*w+j 179 | if idx >= len(split)-1: break 180 | digit = imgs[ split[idx] : split[idx+1] ] 181 | if len(digit) == 1: 182 | for k in range(3): 183 | figure[i*h_dim: (i+1)*h_dim, 184 | j*w_dim: (j+1)*w_dim, k] = digit 185 | elif len(digit) == 3: 186 | for k in range(3): 187 | figure[i*h_dim: (i+1)*h_dim, 188 | j*w_dim: (j+1)*w_dim, k] = digit[2-k] 189 | if path is None: 190 | cv2.imshow('Figure%d'%num, figure) 191 | cv2.waitKey() 192 | else: 193 | figure *= 255 194 | filename1 = path.split('\\')[-1] 195 | filename2 = path.split('/')[-1] 196 | if len(filename1) < len(filename2): 197 | filename = filename1 198 | else: 199 | filename = filename2 200 | root_path = path[:-len(filename)] 201 | if not os.path.exists(root_path): 202 | os.makedirs(root_path) 203 | log("Saving Image at {}".format(path)) 204 | cv2.imwrite(path, figure) -------------------------------------------------------------------------------- /utils/img_generator.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | 4 | def sample(imgs, split=None ,figure_size=(2, 3), img_dim=96, path=None, num=0): 5 | h, w = figure_size 6 | if split is None: 7 | split = range(len(imgs)+1) 8 | figure = np.zeros((img_dim*h, img_dim*w, 3)) 9 | for i in range(h): 10 | for j in range(w): 11 | idx = i*w+j 12 | if idx >= len(split)-1: break 13 | digit = imgs[ split[idx] : split[idx+1] ] 14 | if len(digit) == 1: 15 | for k in range(3): 16 | figure[i*img_dim: (i+1)*img_dim, 17 | j*img_dim: (j+1)*img_dim, k] = digit 18 | elif len(digit) == 3: 19 | for k in range(3): 20 | figure[i*img_dim: (i+1)*img_dim, 21 | j*img_dim: (j+1)*img_dim, k] = digit[2-k] 22 | if path is None: 23 | cv2.imshow('Figure%d'%num, figure) 24 | cv2.waitKey() 25 | else: 26 | figure *= 255 27 | print(">> Saving Image at {}".format(path)) 28 | cv2.imwrite(path, figure) --------------------------------------------------------------------------------