├── .gitignore
├── README.md
├── dali_data.py
├── data.py
├── data_noise.py
├── models
│   ├── __init__.py
│   └── full_mix3_deep_encoder_decoder.py
├── psnr.py
├── readme.txt
├── result_ensemble.py
├── test-full.py
├── test-pad.py
├── test-val.py
├── test.py
├── train-noise.py
├── train.py
└── utils.py

/.gitignore:
--------------------------------------------------------------------------------
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# pyenv
.python-version

# celery beat schedule file
celerybeat-schedule

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# RAW2RGBNet
This is a PyTorch implementation of RAW2RGBNet. Our team, The First Team of Hogwarts School, achieved 22.34 dB on the validation set.
For more details, please refer to the official website of the challenge: https://competitions.codalab.org/competitions/20158#results

## Training
```bash
python train.py --name full_mix3_deep_encoder_decoder --model full_mix3_deep_encoder_decoder --batchSize 16 --data_root /data1/kangfu/Datasets/RAW2RGB/ --checkpoint /data1/kangfu/Checkpoints/RAW2RGB/ --cuda --size 64
```

## Validation
```bash
python test-val.py --model mix3_deep_encoder_decoder --checkpoint /data1/kangfu/Checkpoints/RAW2RGB/mix3_deep_encoder_decoder_32_10_16_8_216_f_f_f/94.pth --output /data1/kangfu/Datasets/RAW2RGB/val_results --data /data1/kangfu/Datasets/RAW2RGB/RAW/

python psnr.py --data /data1/kangfu/Datasets/RAW2RGB/val_results --gt /data1/kangfu/Datasets/RAW2RGB/RGB/
```

## Testing
```bash
CUDA_VISIBLE_DEVICES=0 python test.py --model mix3_deep_encoder_decoder --checkpoint ./80.pth --output /data1/kangfu/Datasets/RAW2RGB/val_results --data ~/ram_data/RAW2RGB/Validation
```

## Testing Full-Resolution Images on a Single Titan XP (12 GB)
```bash
CUDA_VISIBLE_DEVICES=0 python test-pad.py --model full_mix3_deep_encoder_decoder --checkpoint ./112.pth --output /data1/kangfu/Datasets/RAW2RGB/testing_full_results_full_mix3_bacth_224_ep_112 --data /data1/kangfu/Datasets/RAW2RGB/FullResTestingPhoneRaw
```

## Testing Full-Resolution Images on a Single Tesla M40 (24 GB)
```bash
python3 test-full.py --model full_mix3_deep_encoder_decoder --checkpoint ./114.pth --output ../testing_full_results_full_mix3_bacth_224_ep_114 --data ../FullResTestingPhoneRaw/
```

## Reproduce the results in the challenge submission
You can download the pre-trained models from here: [114.pth](https://cuhko365-my.sharepoint.com/:u:/g/personal/219019003_link_cuhk_edu_cn/EZrS367uMMlPjVEQ41j1N30B-4d6fcfNESWNi0JPH2Pyfg?e=IlRYiU) [115.pth](https://cuhko365-my.sharepoint.com/:u:/g/personal/219019003_link_cuhk_edu_cn/Ea7hSVs-cXFHhKGxiTAt6BUBCh66brqiaeiqSNRfigoc2Q?e=wD8WwN)
```bash
# For track 1
# generate results using 114.pth and 115.pth respectively
python test.py --model full_mix3_deep_encoder_decoder --checkpoint ./114.pth --output /data1/kangfu/Datasets/RAW2RGB/validation_results_full_mix3_bacth_224_ep_114 --data /data1/kangfu/Datasets/RAW2RGB/Validation

python test.py --model full_mix3_deep_encoder_decoder --checkpoint ./115.pth --output /data1/kangfu/Datasets/RAW2RGB/validation_results_full_mix3_bacth_224_ep_115 --data /data1/kangfu/Datasets/RAW2RGB/Validation

# ensemble the results from 114.pth and 115.pth
python result_ensemble.py --datas /data1/kangfu/Datasets/RAW2RGB/validation_results_full_mix3_bacth_224_ep_114,/data1/kangfu/Datasets/RAW2RGB/validation_results_full_mix3_bacth_224_ep_115 --output /data1/kangfu/Datasets/RAW2RGB/validation_results_ensemble_114_115

# For track 2
# generate results using 115.pth only
python3 test-full.py --model full_mix3_deep_encoder_decoder --checkpoint ./115.pth --output ../testing_full_results_full_mix3_bacth_224_ep_115 --data ../FullResTestingPhoneRaw/

```


## Contact
If you have any questions about the code, please contact kangfumei@link.cuhk.edu.cn
--------------------------------------------------------------------------------
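Before running the scripts below, it can help to verify that a downloaded checkpoint loads cleanly. A minimal sketch, assuming only the checkpoint format used throughout this repo (a dict carrying a `state_dict_model` key); the checkpoint path is a placeholder:

```python
# Smoke-test a downloaded checkpoint; './114.pth' is a placeholder path.
import torch
from importlib import import_module

model = import_module('models.full_mix3_deep_encoder_decoder').make_model(None)
state = torch.load('./114.pth', map_location='cpu')['state_dict_model']
model.load_state_dict(state)
model.eval()  # ready for the inference scripts below
```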
/dali_data.py:
--------------------------------------------------------------------------------
import nvidia.dali.ops as ops
import nvidia.dali.types as types
from nvidia.dali.pipeline import Pipeline
from random import shuffle
from os import listdir
from os.path import join
import numpy as np


def is_image_file(filename):
    filename_lower = filename.lower()
    return any(filename_lower.endswith(extension) for extension in ['.png', '.jpg', '.jpeg', '.tif'])


class RAW2RGBInputIterator(object):
    def __init__(self, dataset_dir, batch_size, div=88000, test=False):
        self.batch_size = batch_size
        data_dir = join(dataset_dir, "RAW")
        label_dir = join(dataset_dir, "RGB")

        data_filenames = [join(data_dir, x) for x in listdir(data_dir) if is_image_file(x)]
        label_filenames = [join(label_dir, x) for x in listdir(label_dir) if is_image_file(x)]

        data_filenames.sort()
        label_filenames.sort()

        # the first `div` pairs are used for training, the remainder for testing
        data_filenames = data_filenames[div:] if test else data_filenames[:div]
        label_filenames = label_filenames[div:] if test else label_filenames[:div]

        # shuffle the RAW/RGB pairs jointly so they stay aligned
        data_label_filenames = list(zip(data_filenames, label_filenames))
        shuffle(data_label_filenames)
        data_filenames, label_filenames = zip(*data_label_filenames)

        self.data_filenames = data_filenames
        self.label_filenames = label_filenames

    def __iter__(self):
        self.i = 0
        self.n = len(self.data_filenames)
        return self

    def __next__(self):
        batch_data = []
        batch_label = []
        for _ in range(self.batch_size):
            data_path = self.data_filenames[self.i]
            label_path = self.label_filenames[self.i]
            # read raw bytes only; decoding happens on the GPU inside the pipeline
            with open(data_path, 'rb') as f_data, open(label_path, 'rb') as f_label:
                batch_data.append(np.frombuffer(f_data.read(), dtype=np.uint8))
                batch_label.append(np.frombuffer(f_label.read(), dtype=np.uint8))
            self.i = (self.i + 1) % self.n
        return batch_data, batch_label

    next = __next__


class HybridTrainPipe(Pipeline):
    def __init__(self, dataset_dir, batch_size, num_threads, device_id, crop, dali_cpu=False, local_rank=0, test=False):
        super(HybridTrainPipe, self).__init__(batch_size, num_threads, device_id, seed=666)
        self.raw2rgbit = iter(RAW2RGBInputIterator(dataset_dir, batch_size, test=test))
        dali_device = "gpu"
        self.input_data = ops.ExternalSource()
        self.input_label = ops.ExternalSource()
        self.data_decode = ops.ImageDecoder(device="mixed")
        self.label_decode = ops.ImageDecoder(device="mixed", output_type=types.RGB)
        # the crop / normalize / flip ops below are declared but not yet wired into define_graph
        self.uniform = ops.Uniform(range=(0., 1.))
        self.crop = ops.Crop(device=dali_device, crop_h=crop, crop_w=crop)
        self.cmnp = ops.CropMirrorNormalize(device="gpu",
                                            output_dtype=types.FLOAT,
                                            output_layout=types.NCHW)
        self.coin = ops.CoinFlip(probability=0.5)

    def iter_setup(self):
        data, label = self.raw2rgbit.next()
        self.feed_input(self.data, data)
        self.feed_input(self.label, label)

    def define_graph(self):
        self.data = self.input_data()
        self.label = self.input_label()
        data_im = self.data_decode(self.data)
        label_im = self.label_decode(self.label)
        return data_im, label_im
--------------------------------------------------------------------------------
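For reference, a sketch of how the pipeline above would be driven. This assumes a legacy DALI release in which the deprecated `ops` API used here still exists (it was removed in later versions); the dataset path is a placeholder:

```python
# Build the pipeline and pull one GPU-decoded batch (legacy DALI Pipeline API).
pipe = HybridTrainPipe(dataset_dir="/path/to/RAW2RGB", batch_size=16,
                       num_threads=4, device_id=0, crop=64)
pipe.build()
raw_batch, rgb_batch = pipe.run()  # DALI TensorLists resident on the GPU
```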
/data.py:
--------------------------------------------------------------------------------
from PIL import Image
import torch.utils.data as data
from os import listdir
from os.path import join
import random
import numpy as np
import torch


def is_image_file(filename):
    filename_lower = filename.lower()
    return any(filename_lower.endswith(extension) for extension in ['.png', '.jpg', '.jpeg', '.tif'])


def get_patch(*args, patch_size):
    # crop the same random patch_size x patch_size window from every array
    if patch_size == 0:
        return args
    ih, iw = args[0].shape[:2]
    ix = random.randrange(0, iw - patch_size + 1)
    iy = random.randrange(0, ih - patch_size + 1)

    ret = [*[a[iy:iy + patch_size, ix:ix + patch_size, :] for a in args]]

    return ret


def augment(*args, hflip=True, rot=False):
    hflip = hflip and random.random() < 0.5
    vflip = rot and random.random() < 0.5
    rot90 = rot and random.random() < 0.5

    def _augment(img):
        if hflip: img = img[:, ::-1, :]
        if vflip: img = img[::-1, :, :]
        if rot90: img = img.transpose(1, 0, 2)

        return img

    return [_augment(a) for a in args]


def np2Tensor(*args, rgb_range=1.):
    def _np2Tensor(img):
        np_transpose = np.ascontiguousarray(img.transpose((2, 0, 1)))
        tensor = torch.from_numpy(np_transpose).float()
        tensor.mul_(rgb_range / 255)

        return tensor

    return [_np2Tensor(a) for a in args]


class RAW2RGBData(data.Dataset):
    def __init__(self, dataset_dir, patch_size=0, test=False):
        super(RAW2RGBData, self).__init__()
        self.patch_size = patch_size
        self.test = test
        data_dir = join(dataset_dir, "RAW")
        label_dir = join(dataset_dir, "RGB")

        data_filenames = [join(data_dir, x) for x in listdir(data_dir) if is_image_file(x)]
        label_filenames = [join(label_dir, x) for x in listdir(label_dir) if is_image_file(x)]

        label_filenames.sort()
        data_filenames.sort()

        # data_filenames = data_filenames[:1200]
        # label_filenames = label_filenames[:1200]

        # hold out every 200th pair for validation; train on the rest
        data_filenames = data_filenames[::200] if test else list(set(data_filenames) - set(data_filenames[::200]))
        label_filenames = label_filenames[::200] if test else list(set(label_filenames) - set(label_filenames[::200]))
        label_filenames.sort()
        data_filenames.sort()

        self.data_filenames = data_filenames
        self.label_filenames = label_filenames

    def __getitem__(self, index):
        data = np.asarray(Image.open(self.data_filenames[index]))
        label = np.asarray(Image.open(self.label_filenames[index]))

        data, label = get_patch(data, label, patch_size=self.patch_size)
        if not self.test:
            data, label = augment(data, label)
        data, label = np2Tensor(data, label)

        return data, label

    def __len__(self):
        return len(self.data_filenames)
--------------------------------------------------------------------------------
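To make the transform chain above concrete, here is a small self-contained sketch on synthetic arrays; the 448x448 shapes and the random 4-channel "RAW" stand-in are illustrative only:

```python
import numpy as np
from data import get_patch, augment, np2Tensor

raw = np.random.randint(0, 256, (448, 448, 4), dtype=np.uint8)  # packed-RAW stand-in
rgb = np.random.randint(0, 256, (448, 448, 3), dtype=np.uint8)  # RGB ground-truth stand-in

raw_p, rgb_p = get_patch(raw, rgb, patch_size=64)  # the same 64x64 window from both
raw_p, rgb_p = augment(raw_p, rgb_p)               # jointly applied random horizontal flip
raw_t, rgb_t = np2Tensor(raw_p, rgb_p)             # CHW float tensors scaled to [0, 1]
print(raw_t.shape, rgb_t.shape)  # torch.Size([4, 64, 64]) torch.Size([3, 64, 64])
```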
/data_noise.py:
--------------------------------------------------------------------------------
from PIL import Image
import torch.utils.data as data
from os import listdir
from os.path import join
import random
import numpy as np
import torch


def add_noise(x, noise='.'):
    # noise spec: 'G<sigma>' for Gaussian, 'S<value>' for shot (Poisson) noise, '.' for none
    if noise != '.':
        noise_type = noise[0]
        noise_value = int(noise[1:])
        if noise_type == 'G':
            noises = np.random.normal(scale=noise_value, size=x.shape)
            noises = noises.round()
        elif noise_type == 'S':
            noises = np.random.poisson(x * noise_value) / noise_value
            noises = noises - noises.mean(axis=0).mean(axis=0)

        x_noise = x.astype(np.int16) + noises.astype(np.int16)
        x_noise = x_noise.clip(0, 255).astype(np.uint8)
        return x_noise
    else:
        return x


def is_image_file(filename):
    filename_lower = filename.lower()
    return any(filename_lower.endswith(extension) for extension in ['.png', '.jpg', '.jpeg', '.tif'])


def get_patch(*args, patch_size):
    if patch_size == 0:
        return args
    ih, iw = args[0].shape[:2]
    ix = random.randrange(0, iw - patch_size + 1)
    iy = random.randrange(0, ih - patch_size + 1)

    ret = [*[a[iy:iy + patch_size, ix:ix + patch_size, :] for a in args]]

    return ret


def augment(*args, hflip=True, rot=False):
    hflip = hflip and random.random() < 0.5
    vflip = rot and random.random() < 0.5
    rot90 = rot and random.random() < 0.5

    def _augment(img):
        if hflip: img = img[:, ::-1, :]
        if vflip: img = img[::-1, :, :]
        if rot90: img = img.transpose(1, 0, 2)

        return img

    return [_augment(a) for a in args]


def np2Tensor(*args, rgb_range=1.):
    def _np2Tensor(img):
        np_transpose = np.ascontiguousarray(img.transpose((2, 0, 1)))
        tensor = torch.from_numpy(np_transpose).float()
        tensor.mul_(rgb_range / 255)

        return tensor

    return [_np2Tensor(a) for a in args]


class RAW2RGBData(data.Dataset):
    def __init__(self, dataset_dir, patch_size=0, test=False):
        super(RAW2RGBData, self).__init__()
        self.patch_size = patch_size
        self.test = test
        data_dir = join(dataset_dir, "RAW")
        label_dir = join(dataset_dir, "RGB")

        data_filenames = [join(data_dir, x) for x in listdir(data_dir) if is_image_file(x)]
        label_filenames = [join(label_dir, x) for x in listdir(label_dir) if is_image_file(x)]

        label_filenames.sort()
        data_filenames.sort()

        # data_filenames = data_filenames[:1200]
        # label_filenames = label_filenames[:1200]

        # hold out every 200th pair for validation; train on the rest
        data_filenames = data_filenames[::200] if test else list(set(data_filenames) - set(data_filenames[::200]))
        label_filenames = label_filenames[::200] if test else list(set(label_filenames) - set(label_filenames[::200]))
        label_filenames.sort()
        data_filenames.sort()

        self.data_filenames = data_filenames
        self.label_filenames = label_filenames

    def __getitem__(self, index):
        data = np.asarray(Image.open(self.data_filenames[index]))
        # keep the returned array: add_noise does not modify its input in place
        data = add_noise(data, 'G1')
        label = np.asarray(Image.open(self.label_filenames[index]))

        data, label = get_patch(data, label, patch_size=self.patch_size)
        if not self.test:
            data, label = augment(data, label)
        data, label = np2Tensor(data, label)

        return data, label

    def __len__(self):
        return len(self.data_filenames)
--------------------------------------------------------------------------------
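A short sketch of the noise-spec strings `add_noise` accepts, mirroring the `'G1'` call used in `__getitem__` above (synthetic input):

```python
import numpy as np
from data_noise import add_noise

clean = np.random.randint(0, 256, (64, 64, 4), dtype=np.uint8)
noisy = add_noise(clean, 'G1')  # 'G<sigma>': zero-mean Gaussian noise, clipped back to [0, 255]
same = add_noise(clean, '.')    # '.' leaves the image untouched
```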
/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MKFMIKU/RAW2RGBNet/0aff2a764dba656f45557f6cc404b79a88acec4c/models/__init__.py
--------------------------------------------------------------------------------
/models/full_mix3_deep_encoder_decoder.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F


def make_model(opts):
    return EncoderDecoderNet(n_feats=32, n_blocks=10, n_resgroups=16)


class MSRB(nn.Module):
    """Multi-scale residual block: parallel 3x3 and 5x5 branches fused by a 1x1 conv."""

    def __init__(self, n_feats=64):
        super(MSRB, self).__init__()

        kernel_size_1 = 3
        kernel_size_2 = 5

        self.conv_3_1 = nn.Conv2d(n_feats, n_feats, kernel_size_1, stride=1, padding=kernel_size_1 // 2)
        self.conv_3_2 = nn.Conv2d(n_feats * 2, n_feats * 2, kernel_size_1, stride=1, padding=kernel_size_1 // 2)
        self.conv_5_1 = nn.Conv2d(n_feats, n_feats, kernel_size_2, stride=1, padding=kernel_size_2 // 2)
        self.conv_5_2 = nn.Conv2d(n_feats * 2, n_feats * 2, kernel_size_2, stride=1, padding=kernel_size_2 // 2)
        self.confusion = nn.Conv2d(n_feats * 4, n_feats, 1, padding=0, stride=1)
        self.relu = nn.PReLU()

    def forward(self, x):
        input_1 = x
        output_3_1 = self.relu(self.conv_3_1(input_1))
        output_5_1 = self.relu(self.conv_5_1(input_1))
        input_2 = torch.cat([output_3_1, output_5_1], 1)
        output_3_2 = self.relu(self.conv_3_2(input_2))
        output_5_2 = self.relu(self.conv_5_2(input_2))
        input_3 = torch.cat([output_3_2, output_5_2], 1)
        output = self.confusion(input_3)
        output += x
        return output


class RB(nn.Module):
    """Residual block with an optional instance/batch-norm layer after each conv."""

    def __init__(self, n_feats, nm='in'):
        super(RB, self).__init__()
        module_body = []
        for i in range(2):
            module_body.append(nn.Conv2d(n_feats, n_feats, kernel_size=3, stride=1, padding=1, bias=True))
            if nm == 'in':
                module_body.append(nn.InstanceNorm2d(n_feats, affine=True))
            if nm == 'bn':
                module_body.append(nn.BatchNorm2d(n_feats))
            if i == 0:
                module_body.append(nn.PReLU())
        self.module_body = nn.Sequential(*module_body)

    def forward(self, x):
        res = self.module_body(x)
        res += x
        return res


class RBGroup(nn.Module):
    """A group of n_blocks residual blocks with a trailing conv and a long skip."""

    def __init__(self, n_feats, n_blocks, nm='in'):
        super(RBGroup, self).__init__()
        module_body = [
            RB(n_feats, nm) for _ in range(n_blocks)
        ]
        module_body.append(nn.Conv2d(n_feats, n_feats, kernel_size=3, stride=1, padding=1, bias=True))
        self.module_body = nn.Sequential(*module_body)

    def forward(self, x):
        res = self.module_body(x)
        res += x
        return res


class EncoderDecoderNet(nn.Module):
    def __init__(self, n_feats, n_blocks, n_resgroups, nm=None):
        super(EncoderDecoderNet, self).__init__()
        self.n_feats = n_feats
        self.n_blocks = n_blocks
        self.n_resgroups = n_resgroups
        self.nm = nm
        self.__build_model()

    def __build_model(self):
        # global-statistics branch: the input is resized to a fixed 192 resolution,
        # downsampled by strided convs, and pooled into a per-image 1x1 descriptor
        # (note: padding=2 with a 3x3 kernel slightly enlarges the map, which is
        # harmless before the adaptive pooling)
        self.fix_path = nn.Sequential(
            nn.Conv2d(4, self.n_feats, kernel_size=3, stride=1, padding=2, bias=True),
            nn.PReLU(),
            nn.Conv2d(self.n_feats, self.n_feats, kernel_size=3, stride=2, padding=1, bias=True),
            nn.PReLU(),
            nn.Conv2d(self.n_feats, self.n_feats, kernel_size=3, stride=2, padding=1, bias=True),
            nn.PReLU(),
            nn.Conv2d(self.n_feats, self.n_feats, kernel_size=3, stride=2, padding=1, bias=True),
            nn.PReLU(),
            nn.Conv2d(self.n_feats, self.n_feats, kernel_size=3, stride=2, padding=1, bias=True),
            nn.PReLU(),
            nn.Conv2d(self.n_feats, self.n_feats * 2, kernel_size=3, stride=2, padding=1, bias=False),
            nn.PReLU(),
            nn.AdaptiveAvgPool2d(1)
        )

        self.head = nn.Sequential(
            nn.Conv2d(4, self.n_feats * 2, kernel_size=3, stride=1, padding=1, bias=True),
            nn.PReLU()
        )

        # local path: encode at 1/4 resolution, run the residual groups, decode back
        self.downer = nn.Sequential(
            nn.Conv2d(self.n_feats * 2, self.n_feats * 2, kernel_size=3, stride=2, padding=1, bias=True),
            nn.PReLU(),
            nn.Conv2d(self.n_feats * 2, self.n_feats * 4, kernel_size=3, stride=2, padding=1, bias=True)
        )
        local_path = [
            RBGroup(n_feats=self.n_feats * 4, nm=self.nm, n_blocks=self.n_blocks) for _ in range(self.n_resgroups)
        ]
        local_path.append(nn.Conv2d(self.n_feats * 4, self.n_feats * 4, kernel_size=3, stride=1, padding=1, bias=True))
        self.local_path = nn.Sequential(*local_path)
        self.uper = nn.Sequential(
            nn.ConvTranspose2d(self.n_feats * 4, self.n_feats * 2, kernel_size=4, stride=2, padding=1, bias=True),
            nn.PReLU(),
            nn.ConvTranspose2d(self.n_feats * 2, self.n_feats * 2, kernel_size=4, stride=2, padding=1, bias=True)
        )

        # global path: a chain of MSRBs whose outputs are densely concatenated
        self.global_path = nn.Sequential(
            MSRB(self.n_feats * 2),
            MSRB(self.n_feats * 2),
            MSRB(self.n_feats * 2),
            MSRB(self.n_feats * 2),
            MSRB(self.n_feats * 2),
            MSRB(self.n_feats * 2),
            MSRB(self.n_feats * 2),
            MSRB(self.n_feats * 2),
        )
        self.global_down = nn.Conv2d(self.n_feats * 8 * 2, self.n_feats * 2, kernel_size=3, stride=1, padding=1, bias=True)

        self.linear = nn.Sequential(
            nn.Conv2d(self.n_feats * 4, self.n_feats * 2, kernel_size=3, stride=1, padding=1, bias=True),
            nn.PReLU()
        )

        self.tail = nn.Conv2d(self.n_feats * 2, 3, kernel_size=3, stride=1, padding=1, bias=True)

    def forward(self, x):
        fix_s = F.interpolate(x, size=192, mode='bilinear')
        fix_s = self.fix_path(fix_s)  # (B, 2*n_feats, 1, 1), broadcast onto the fused features

        x = self.head(x)

        x_down = self.downer(x)
        local_fea = self.local_path(x_down)
        local_fea += x_down
        local_fea = self.uper(local_fea)

        out = x
        msrb_out = []
        for i in range(8):
            out = self.global_path[i](out)
            msrb_out.append(out)
        global_fea = torch.cat(msrb_out, 1)
        global_fea = self.global_down(global_fea)

        fused_fea = torch.cat([global_fea, local_fea], 1)

        fused_fea = self.linear(fused_fea)
        fused_fea += fix_s

        x = self.tail(fused_fea)
        return torch.tanh(x)  # F.tanh is deprecated; torch.tanh is equivalent
--------------------------------------------------------------------------------
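A quick shape check for the network above, assuming the default configuration from `make_model`. The input height and width must be multiples of 4, since `downer` halves the resolution twice before `uper` restores it:

```python
import torch
from models.full_mix3_deep_encoder_decoder import make_model

net = make_model(None).eval()  # make_model ignores its opts argument
with torch.no_grad():
    out = net(torch.randn(1, 4, 64, 64))  # one 4-channel packed-RAW patch
print(out.shape)  # torch.Size([1, 3, 64, 64]); tanh keeps outputs in (-1, 1)
```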
/psnr.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python
import argparse
import utils
from PIL import Image
import numpy as np
from os.path import join
from tqdm import tqdm
import os

parser = argparse.ArgumentParser(description="PSNR evaluation")
parser.add_argument("--data", type=str, default="output", help="path to load output images")
parser.add_argument("--gt", type=str, help="path to load ground-truth images")
parser.add_argument("--view", type=str, default="view", help="path to save side-by-side output/gt images")

opt = parser.parse_args()
print(opt)

if not os.path.exists(opt.view):
    os.makedirs(opt.view)

datas = utils.load_all_image(opt.data)
datas.sort()


def output_psnr_mse(img_orig, img_out):
    squared_error = np.square(img_orig - img_out)
    mse = np.mean(squared_error)
    psnr = 10 * np.log10(1.0 / mse)
    return psnr


psnrs = []
for data_p in tqdm(datas):
    data = Image.open(data_p)
    gt = Image.open(join(opt.gt, data_p.split('/')[-1][:-3] + 'jpg'))
    # save output and ground truth side by side for visual inspection
    w, h = data.size
    new_im = Image.new('RGB', (w * 2, h))
    new_im.paste(data, (0, 0))
    new_im.paste(gt, (w, 0))
    new_im.save(os.path.join(opt.view, data_p.split('/')[-1]))

    data = np.asarray(data).astype(float) / 255.0
    gt = np.asarray(gt).astype(float) / 255.0
    psnr = output_psnr_mse(data, gt)
    psnrs.append(psnr)
print("mean PSNR:", np.mean(psnrs))
--------------------------------------------------------------------------------
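As a quick numeric check of the formula in `output_psnr_mse` (images as floats in [0, 1]): a uniform per-pixel error of 0.01 gives MSE = 1e-4, i.e. 40 dB:

```python
import numpy as np

a = np.full((8, 8, 3), 0.50)
b = np.full((8, 8, 3), 0.51)     # every pixel off by 0.01
mse = np.mean(np.square(a - b))  # 1e-4
print(10 * np.log10(1.0 / mse))  # ~40.0 dB
```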
/readme.txt:
--------------------------------------------------------------------------------
runtime per image [s] : 0.7
CPU[1] / GPU[0] : 0
Extra Data [1] / No Extra Data [0] : 0
Other description : Solution based on RCAN and HDRNet
--------------------------------------------------------------------------------
/result_ensemble.py:
--------------------------------------------------------------------------------
import argparse
import os
from PIL import Image
import utils
import numpy as np

parser = argparse.ArgumentParser(description="Result Ensemble Script")
parser.add_argument("--output", type=str, required=True, help="path to save output images")
parser.add_argument("--datas", type=str, required=True, help="comma-separated list of result directories to average")

opt = parser.parse_args()
print(opt)

if not os.path.exists(opt.output):
    os.makedirs(opt.output)

datas = opt.datas.split(',')
data_images = [utils.load_all_image(data) for data in datas]
[data_image.sort() for data_image in data_images]
for i, p in enumerate(data_images[0]):
    filename = p.split('/')[-1]
    # average the i-th image across all result directories
    image_paths = [ps[i] for ps in data_images]
    images = [Image.open(ip) for ip in image_paths]
    images_np = [np.asarray(image) for image in images]
    output = np.mean(images_np, axis=0)
    output = output.round()
    output[output >= 255] = 255
    output[output <= 0] = 0
    output = Image.fromarray(output.astype(np.uint8), mode='RGB')
    output.save(os.path.join(opt.output, filename))
--------------------------------------------------------------------------------
/test-full.py:
--------------------------------------------------------------------------------
import argparse
import os
from importlib import import_module
from PIL import Image
import torch
from torchvision import transforms
from tqdm import tqdm
import utils
import torchvision.transforms.functional as F
import numpy as np

parser = argparse.ArgumentParser(description="Test Script")
parser.add_argument(
    "--model",
    required=True,
    type=str,
    help="name of model for this training"
)
parser.add_argument("--checkpoint", type=str, required=True, help="path to load model checkpoint")
parser.add_argument("--output", type=str, required=True, help="path to save output images")
parser.add_argument("--data", type=str, required=True, help="path to load data images")

opt = parser.parse_args()
print(opt)

if not os.path.exists(opt.output):
    os.makedirs(opt.output)

model = import_module('models.' + opt.model.lower()).make_model(opt)
model.load_state_dict(torch.load(opt.checkpoint)['state_dict_model'])
model = model.cuda()
model = model.eval()

images = utils.load_all_image(opt.data)
images.sort()


def infer(im):
    # reflect-pad to a multiple of 4, matching the two stride-2 convolutions in the encoder
    w, h = im.size
    pad_w = 4 - w % 4
    pad_h = 4 - h % 4
    to_tensor = transforms.ToTensor()

    im_pad = transforms.Pad(padding=(pad_w//2, pad_h//2, pad_w - pad_w//2, pad_h - pad_h//2), padding_mode='reflect')(im)
    # x4 self-ensemble: run the model on flipped copies and undo the flips afterwards
    im_augs = [
        to_tensor(im_pad),
        to_tensor(F.hflip(im_pad)),
        to_tensor(F.vflip(im_pad)),
        to_tensor(F.hflip(F.vflip(im_pad)))
    ]
    output_augs = []
    for im_pad in im_augs:
        im_pad_th = im_pad.unsqueeze(0).cuda()
        with torch.no_grad():
            torch.cuda.empty_cache()
            output = model(im_pad_th)
        output_augs.append(np.transpose(output.squeeze(0).cpu().numpy(), (1, 2, 0)))
    output_augs = [
        output_augs[0],
        np.fliplr(output_augs[1]),
        np.flipud(output_augs[2]),
        np.fliplr(np.flipud(output_augs[3]))
    ]
    output = np.mean(output_augs, axis=0) * 255.
    # crop the padding back off and quantize to 8-bit
    output = output[pad_h // 2:-(pad_h - pad_h // 2), pad_w // 2:-(pad_w - pad_w // 2), :]
    output = output.round()
    output[output >= 255] = 255
    output[output <= 0] = 0
    output = Image.fromarray(output.astype(np.uint8), mode='RGB')
    return output


for im_path in tqdm(images):
    filename = im_path.split('/')[-1]
    img = Image.open(im_path)
    output = infer(img)
    assert output.size == img.size
    output.save(os.path.join(opt.output, filename))
--------------------------------------------------------------------------------
/test-pad.py:
--------------------------------------------------------------------------------
import argparse
import os
from importlib import import_module

from PIL import Image
import numpy as np
import torch
from torch import nn, optim
from torchvision import transforms
from tqdm import tqdm
from skimage.io import imsave

import utils
import gc

parser = argparse.ArgumentParser(description="Test Script")
parser.add_argument(
    "--model",
    required=True,
    type=str,
    help="name of model for this training"
)
parser.add_argument("--checkpoint", type=str, required=True, help="path to load model checkpoint")
parser.add_argument("--output", type=str, required=True, help="path to save output images")
parser.add_argument("--data", type=str, required=True, help="path to load data images")

opt = parser.parse_args()
print(opt)

if not os.path.exists(opt.output):
    os.makedirs(opt.output)

model = import_module('models.' + opt.model.lower()).make_model(opt)
model.load_state_dict(torch.load(opt.checkpoint)['state_dict_model'])
model = model.cuda()
model = model.eval()

images = utils.load_all_image(opt.data)
images.sort()


def add_noise(x, noise='.'):
    # noise spec: 'G<sigma>' for Gaussian, 'S<value>' for shot (Poisson) noise, '.' for none
    if noise != '.':
        noise_type = noise[0]
        noise_value = int(noise[1:])
        if noise_type == 'G':
            noises = np.random.normal(scale=noise_value, size=x.shape)
            noises = noises.round()
        elif noise_type == 'S':
            noises = np.random.poisson(x * noise_value) / noise_value
            noises = noises - noises.mean(axis=0).mean(axis=0)

        x_noise = x.astype(np.int16) + noises.astype(np.int16)
        x_noise = x_noise.clip(0, 255).astype(np.uint8)
        return x_noise
    else:
        return x


def infer(im):
    w, h = im.size
    pad_w = 8 - w % 8
    pad_h = 8 - h % 8
    padding = 100  # overlap (in pixels) between the two halves

    im_pad = transforms.Pad(padding=(pad_w//2, pad_h//2, pad_w - pad_w//2, pad_h - pad_h//2), padding_mode='reflect')(im)
    im_pad = np.asarray(im_pad)
    im_pad = add_noise(im_pad, 'G1')
    im_pad_th = transforms.ToTensor()(im_pad)
    im_pad_th = im_pad_th.unsqueeze(0).cuda()
    _, _, _, ww = im_pad_th.shape
    # split the frame into two overlapping halves so each forward pass fits in GPU memory
    im_pad_th_l, im_pad_th_r = im_pad_th[:, :, :, :ww//2 + padding], im_pad_th[:, :, :, ww//2-padding:]
    with torch.no_grad():
        torch.cuda.empty_cache()
        im_pad_th_l = model(im_pad_th_l)
        torch.cuda.empty_cache()
        im_pad_th_r = model(im_pad_th_r)
    # average the overlapping strip, then stitch the halves back together
    pad_th = (im_pad_th_l[:, :, :, -padding * 2:] + im_pad_th_r[:, :, :, :padding * 2]) / 2
    output = torch.cat((im_pad_th_l[:, :, :, :-padding*2], pad_th, im_pad_th_r[:, :, :, padding*2:]), dim=-1)
    output = output.squeeze(0).cpu()
    output = torch.clamp(output, 0., 1.)
    output = transforms.ToPILImage()(output)
    return output


for im_path in tqdm(images):
    filename = im_path.split('/')[-1]
    img = Image.open(im_path)
    output = infer(img)
    output.save(os.path.join(opt.output, filename))
--------------------------------------------------------------------------------
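The half-split blending in `infer` above is the memory trick that fits full-resolution frames on a 12 GB card: the two halves overlap by `2 * padding` pixels around the seam, and the overlapping columns are averaged. A 1-D sketch of the same arithmetic with illustrative sizes; with an identity "model", the stitch is lossless:

```python
import torch

padding = 4
x = torch.arange(32.).view(1, 1, 1, 32)  # stand-in for one full-width row
left, right = x[..., :16 + padding], x[..., 16 - padding:]
blend = (left[..., -padding * 2:] + right[..., :padding * 2]) / 2
full = torch.cat((left[..., :-padding * 2], blend, right[..., padding * 2:]), dim=-1)
assert torch.equal(full, x)  # the overlapping strip stitches back exactly
```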
/test-val.py:
--------------------------------------------------------------------------------
import argparse
import os
from importlib import import_module
from PIL import Image
import torch
from torchvision import transforms
from tqdm import tqdm
import utils
import numpy as np
import torchvision.transforms.functional as F

parser = argparse.ArgumentParser(description="Test Script")
parser.add_argument(
    "--model",
    required=True,
    type=str,
    help="name of model for this training"
)
parser.add_argument("--checkpoint", type=str, required=True, help="path to load model checkpoint")
parser.add_argument("--output", type=str, required=True, help="path to save output images")
parser.add_argument("--data", type=str, required=True, help="path to load data images")

opt = parser.parse_args()
print(opt)

if not os.path.exists(opt.output):
    os.makedirs(opt.output)

model = import_module('models.' + opt.model.lower()).make_model(opt)
model.load_state_dict(torch.load(opt.checkpoint)['state_dict_model'])
model = model.cuda()
model = model.eval()

images = utils.load_all_image(opt.data)
images.sort()
images = images[::200]  # the same every-200th holdout used by data.py


def infer(im):
    to_tensor = transforms.ToTensor()
    im_augs = [
        to_tensor(im),
        to_tensor(F.hflip(im)),
        to_tensor(F.vflip(im)),
        to_tensor(F.hflip(F.vflip(im))),
    ]
    im_augs = torch.stack(im_augs)
    im_augs = im_augs.cuda()
    with torch.no_grad():
        output_augs = model(im_augs)
    output_augs = np.transpose(output_augs.cpu().numpy(), (0, 2, 3, 1))
    output_augs = [
        output_augs[0],
        np.fliplr(output_augs[1]),
        np.flipud(output_augs[2]),
        np.fliplr(np.flipud(output_augs[3])),
    ]
    output = np.mean(output_augs, axis=0) * 255.
    output = output.round()
    output[output >= 255] = 255
    output[output <= 0] = 0
    output = Image.fromarray(output.astype(np.uint8), mode='RGB')
    return output


for im_path in tqdm(images):
    filename = im_path.split('/')[-1]
    img = Image.open(im_path)
    img = infer(img)
    img.save(os.path.join(opt.output, filename))
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
import argparse
import os
from importlib import import_module
from PIL import Image
import torch
from torchvision import transforms
from tqdm import tqdm
import utils
import numpy as np
import torchvision.transforms.functional as F

parser = argparse.ArgumentParser(description="Test Script")
parser.add_argument(
    "--model",
    required=True,
    type=str,
    help="name of model for this training"
)
parser.add_argument("--checkpoint", type=str, required=True, help="path to load model checkpoint")
parser.add_argument("--output", type=str, required=True, help="path to save output images")
parser.add_argument("--data", type=str, required=True, help="path to load data images")

opt = parser.parse_args()
print(opt)

if not os.path.exists(opt.output):
    os.makedirs(opt.output)

model = import_module('models.' + opt.model.lower()).make_model(opt)
model.load_state_dict(torch.load(opt.checkpoint)['state_dict_model'])
model = model.cuda()
model = model.eval()

images = utils.load_all_image(opt.data)
images.sort()


def infer(im):
    to_tensor = transforms.ToTensor()
    im_augs = [
        to_tensor(im),
        to_tensor(F.hflip(im)),
        to_tensor(F.vflip(im)),
        to_tensor(F.hflip(F.vflip(im))),
    ]
    im_augs = torch.stack(im_augs)
    im_augs = im_augs.cuda()
    with torch.no_grad():
        output_augs = model(im_augs)
    output_augs = np.transpose(output_augs.cpu().numpy(), (0, 2, 3, 1))
    output_augs = [
        output_augs[0],
        np.fliplr(output_augs[1]),
        np.flipud(output_augs[2]),
        np.fliplr(np.flipud(output_augs[3])),
    ]
    output = np.mean(output_augs, axis=0) * 255.
    output = output.round()
    output[output >= 255] = 255
    output[output <= 0] = 0
    output = Image.fromarray(output.astype(np.uint8), mode='RGB')
    return output


for im_path in tqdm(images):
    filename = im_path.split('/')[-1]
    img = Image.open(im_path)
    img = infer(img)
    img.save(os.path.join(opt.output, filename))
--------------------------------------------------------------------------------
/train-noise.py:
--------------------------------------------------------------------------------
# coding=utf-8
import argparse
import sys
import os
from importlib import import_module

import numpy as np
import torch
from torch import nn, optim
import torch.nn.functional as F
from torch.backends import cudnn
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from torchvision.utils import make_grid

from data_noise import RAW2RGBData

from tqdm import tqdm

from utils import save_checkpoint, plot_grad_flow, init_weights

parser = argparse.ArgumentParser(description="Training Script")
parser.add_argument("--name", required=True, type=str, help="name for training version")
parser.add_argument("--div", type=int, default=88800, help="split point between train and test data (currently unused). Default=88800")
parser.add_argument("--batchSize", type=int, default=64, help="training batch size. Default=64")
parser.add_argument("--threads", type=int, default=8, help="threads for data loader to use. Default=8")
parser.add_argument("--decay_epoch", type=int, default=200, help="epoch from which to start lr decay. Default=200")
parser.add_argument("--resume", default="", type=str, help="path to checkpoint. Default: none")
parser.add_argument("--start-epoch", default=1, type=int, help="manual epoch number. Default=1")
parser.add_argument("--n-epoch", type=int, default=2000, help="number of epochs to train. Default=2000")
parser.add_argument("--cuda", default=True, action="store_true", help="use cuda?")
parser.add_argument("--lr", type=float, default=0.0001, help="learning rate. Default=1e-4")
parser.add_argument("--size", type=int, default=64, help="size to crop the training patches to")
parser.add_argument(
    "--model",
    required=True,
    type=str,
    help="name of model for this training"
)
parser.add_argument(
    "--data_root",
    required=True,
    type=str,
    help="path to load train datasets"
)
parser.add_argument(
    "--checkpoint",
    required=True,
    type=str,
    help="path to save checkpoints"
)

opts = parser.parse_args()
print(opts)

writer = SummaryWriter(comment=opts.name)
writer.add_text("command", " ".join(sys.argv))
KWAI_SEED = 666
torch.manual_seed(KWAI_SEED)
np.random.seed(KWAI_SEED)


cuda = opts.cuda
cudnn.benchmark = True

train_dataset = RAW2RGBData(opts.data_root, patch_size=opts.size)
test_datasets = RAW2RGBData(opts.data_root, test=True)

training_data_loader = DataLoader(
    dataset=train_dataset,
    batch_size=opts.batchSize,
    pin_memory=True,
    shuffle=True,
    num_workers=opts.threads,
)

testing_data_loader = DataLoader(
    dataset=test_datasets,
    batch_size=1,
    num_workers=1,
)

model = import_module('models.' + opts.model.lower()).make_model(opts)
# log the model definition to TensorBoard for reproducibility
with open(os.path.join("models", opts.model.lower() + ".py"), 'r') as model_define_r:
    model_define = model_define_r.read()
writer.add_text("models", model_define)
criterion = nn.L1Loss()

# init_weights(model, 'orthogonal')

if opts.resume:
    if os.path.isfile(opts.resume):
        print("======> loading checkpoint at '{}'".format(opts.resume))
        checkpoint = torch.load(opts.resume)
        model.load_state_dict(checkpoint["state_dict_model"], strict=False)
    else:
        print("======> found no checkpoint at '{}'".format(opts.resume))

if cuda:
    model = nn.DataParallel(model).cuda()

optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=opts.lr, betas=(0.9, 0.999))
# optimizer = optim.Adam(model.parameters(), lr=opts.lr, betas=(0.9, 0.999))
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=opts.decay_epoch, gamma=0.1)

for epoch in range(opts.start_epoch, opts.n_epoch + 1):
    print("epoch =", epoch, "lr =", optimizer.param_groups[0]["lr"])
    model.train()

    pbar = tqdm(training_data_loader)
    output = None
    for iteration, batch in enumerate(pbar):
        data, label = batch[0], batch[1]
        data = data.cuda() if opts.cuda else data.cpu()
        label = label.cuda() if opts.cuda else label.cpu()

        model.zero_grad()
        output = model(data)
        loss = criterion(output, label)

        loss.backward()
        optimizer.step()

        if iteration % 100 == 0:
            pbar.set_description("Epoch[{}]({}/{}): Loss: {:.4f}".format(
                epoch, iteration, len(training_data_loader), loss.item())
            )
            writer.add_scalar("l1loss", loss.item(), iteration + (epoch - 1) * len(training_data_loader))
    lr_scheduler.step(epoch=epoch)
    writer.add_image("output", make_grid(output, range=[0., 1.]), epoch)
    save_checkpoint(model, opts.name, None, epoch, opts.checkpoint)
    if epoch % 1 == 0:
        mean_psnr = 0
        model.eval()
        for iteration, batch in enumerate(testing_data_loader, 1):
            data, label = batch[0], batch[1]
            data = data.cuda() if opts.cuda else data.cpu()
            label = label.cuda() if opts.cuda else label.cpu()

            with torch.no_grad():
                output = model(data)
                output = torch.clamp(output, 0.0, 1.0)
            mse = F.mse_loss(output, label)
            psnr = 10 * np.log10(1.0 / mse.item())
            mean_psnr += psnr
        mean_psnr /= len(testing_data_loader)
        writer.add_scalar("mean_psnr", mean_psnr, epoch)
        print("Valid epoch %d psnr: %f" % (epoch, mean_psnr))
writer.close()
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
# coding=utf-8
import argparse
import sys
import os
from importlib import import_module

import numpy as np
import torch
from torch import nn, optim
import torch.nn.functional as F
from torch.backends import cudnn
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from torchvision.utils import make_grid

from data import RAW2RGBData

from tqdm import tqdm

from utils import save_checkpoint, plot_grad_flow, init_weights

parser = argparse.ArgumentParser(description="Training Script")
parser.add_argument("--name", required=True, type=str, help="name for training version")
parser.add_argument("--div", type=int, default=88800, help="split point between train and test data (currently unused). Default=88800")
parser.add_argument("--batchSize", type=int, default=64, help="training batch size. Default=64")
parser.add_argument("--threads", type=int, default=8, help="threads for data loader to use. Default=8")
parser.add_argument("--decay_epoch", type=int, default=200, help="epoch from which to start lr decay. Default=200")
parser.add_argument("--resume", default="", type=str, help="path to checkpoint. Default: none")
parser.add_argument("--start-epoch", default=1, type=int, help="manual epoch number. Default=1")
parser.add_argument("--n-epoch", type=int, default=2000, help="number of epochs to train. Default=2000")
parser.add_argument("--cuda", default=True, action="store_true", help="use cuda?")
parser.add_argument("--lr", type=float, default=0.0001, help="learning rate. Default=1e-4")
parser.add_argument("--size", type=int, default=64, help="size to crop the training patches to")
parser.add_argument(
    "--model",
    required=True,
    type=str,
    help="name of model for this training"
)
parser.add_argument(
    "--data_root",
    required=True,
    type=str,
    help="path to load train datasets"
)
parser.add_argument(
    "--checkpoint",
    required=True,
    type=str,
    help="path to save checkpoints"
)

opts = parser.parse_args()
print(opts)

writer = SummaryWriter(comment=opts.name)
writer.add_text("command", " ".join(sys.argv))
KWAI_SEED = 666
torch.manual_seed(KWAI_SEED)
np.random.seed(KWAI_SEED)


cuda = opts.cuda
cudnn.benchmark = True

train_dataset = RAW2RGBData(opts.data_root, patch_size=opts.size)
test_datasets = RAW2RGBData(opts.data_root, test=True)

training_data_loader = DataLoader(
    dataset=train_dataset,
    batch_size=opts.batchSize,
    pin_memory=True,
    shuffle=True,
    num_workers=opts.threads,
)

testing_data_loader = DataLoader(
    dataset=test_datasets,
    batch_size=1,
    num_workers=1,
)

model = import_module('models.' + opts.model.lower()).make_model(opts)
# log the model definition to TensorBoard for reproducibility
with open(os.path.join("models", opts.model.lower() + ".py"), 'r') as model_define_r:
    model_define = model_define_r.read()
writer.add_text("models", model_define)
criterion = nn.L1Loss()

# init_weights(model, 'orthogonal')

if opts.resume:
    if os.path.isfile(opts.resume):
        print("======> loading checkpoint at '{}'".format(opts.resume))
        checkpoint = torch.load(opts.resume)
        model.load_state_dict(checkpoint["state_dict_model"], strict=False)
    else:
        print("======> found no checkpoint at '{}'".format(opts.resume))

if cuda:
    model = nn.DataParallel(model).cuda()

optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=opts.lr, betas=(0.9, 0.999))
# optimizer = optim.Adam(model.parameters(), lr=opts.lr, betas=(0.9, 0.999))
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=opts.decay_epoch, gamma=0.1)

for epoch in range(opts.start_epoch, opts.n_epoch + 1):
    print("epoch =", epoch, "lr =", optimizer.param_groups[0]["lr"])
    model.train()

    pbar = tqdm(training_data_loader)
    output = None
    for iteration, batch in enumerate(pbar):
        data, label = batch[0], batch[1]
        data = data.cuda() if opts.cuda else data.cpu()
        label = label.cuda() if opts.cuda else label.cpu()

        model.zero_grad()
        output = model(data)
        loss = criterion(output, label)

        loss.backward()
        optimizer.step()

        if iteration % 100 == 0:
            pbar.set_description("Epoch[{}]({}/{}): Loss: {:.4f}".format(
                epoch, iteration, len(training_data_loader), loss.item())
            )
            writer.add_scalar("l1loss", loss.item(), iteration + (epoch - 1) * len(training_data_loader))
    lr_scheduler.step(epoch=epoch)
    writer.add_image("output", make_grid(output, range=[0., 1.]), epoch)
    save_checkpoint(model, opts.name, None, epoch, opts.checkpoint)
    if epoch % 1 == 0:
        mean_psnr = 0
        model.eval()
        for iteration, batch in enumerate(testing_data_loader, 1):
            data, label = batch[0], batch[1]
            data = data.cuda() if opts.cuda else data.cpu()
            label = label.cuda() if opts.cuda else label.cpu()

            with torch.no_grad():
                output = model(data)
                output = torch.clamp(output, 0.0, 1.0)
            mse = F.mse_loss(output, label)
            psnr = 10 * np.log10(1.0 / mse.item())
            mean_psnr += psnr
        mean_psnr /= len(testing_data_loader)
        writer.add_scalar("mean_psnr", mean_psnr, epoch)
        print("Valid epoch %d psnr: %f" % (epoch, mean_psnr))
writer.close()
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
import os
from os import listdir
from os.path import join
from os.path import exists
import torch  # needed by save_checkpoint below
import random
# from torch.autograd import Variable
from torch.nn import init  # needed by init_weights below

from matplotlib.lines import Line2D
import matplotlib.pyplot as plt
import numpy as np
# from torch import nn

import math


def plot_grad_flow(named_parameters):
    '''Plots the gradients flowing through different layers in the net during training.
    Can be used for checking for possible gradient vanishing / exploding problems.

    Usage: Plug this function in Trainer class after loss.backward() as
    "plot_grad_flow(self.model.named_parameters())" to visualize the gradient flow'''
    plt.switch_backend('agg')
    figure = plt.figure(figsize=(16, 4))
    ave_grads = []
    max_grads = []
    layers = []
    for n, p in named_parameters:
        if (p.requires_grad) and ("bias" not in n):
            layers.append(n)
            ave_grads.append(p.grad.abs().mean())
            max_grads.append(p.grad.abs().max())
    plt.bar(np.arange(len(max_grads)), max_grads, alpha=0.1, lw=1, color="c")
    plt.bar(np.arange(len(max_grads)), ave_grads, alpha=0.1, lw=1, color="b")
    plt.hlines(0, 0, len(ave_grads) + 1, lw=2, color="k")
    plt.xticks(range(0, len(ave_grads), 1), layers, rotation="vertical")
    plt.xlim(left=0, right=len(ave_grads))
    plt.ylim(bottom=-0.001, top=0.02)  # zoom in on the lower gradient regions
    plt.xlabel("Layers")
    plt.ylabel("average gradient")
    plt.title("Gradient flow")
    plt.grid(True)
    plt.legend([Line2D([0], [0], color="c", lw=4),
                Line2D([0], [0], color="b", lw=4),
                Line2D([0], [0], color="k", lw=4)], ['max-gradient', 'mean-gradient', 'zero-gradient'])
    return figure


def is_image_file(filename):
    filename_lower = filename.lower()
    return any(filename_lower.endswith(extension) for extension in ['.png', '.jpg', '.jpeg', '.tif'])


def load_all_image(path):
    return [join(path, x) for x in listdir(path) if is_image_file(x)]


def save_checkpoint(model, name, discriminator, epoch, model_folder):
    if not exists(model_folder):
        os.makedirs(model_folder)
    if not exists(join(model_folder, name)):
        os.makedirs(join(model_folder, name))
    model_out_path = "%s/%d.pth" % (join(model_folder, name), epoch)

    # model is expected to be wrapped in nn.DataParallel, hence model.module
    state_dict_model = model.module.state_dict()

    for key in state_dict_model.keys():
        state_dict_model[key] = state_dict_model[key].cpu()

    if discriminator:
        state_dict_discriminator = discriminator.module.state_dict()
        for key in state_dict_discriminator.keys():
            state_dict_discriminator[key] = state_dict_discriminator[key].cpu()

        torch.save({"epoch": epoch,
                    "state_dict_model": state_dict_model,
                    "state_dict_discriminator": state_dict_discriminator}, model_out_path)
    else:
        torch.save({"epoch": epoch,
                    "state_dict_model": state_dict_model}, model_out_path)
    print("Checkpoint saved to {}".format(model_out_path))


def init_weights(net, init_type='normal', init_gain=0.02):
    """Initialize network weights.
    Parameters:
        net (network)     -- network to be initialized
        init_type (str)   -- the name of an initialization method: normal | xavier | kaiming | orthogonal
        init_gain (float) -- scaling factor for normal, xavier and orthogonal.
    We use 'normal' in the original pix2pix and CycleGAN paper. But xavier and kaiming might
    work better for some applications. Feel free to try yourself.
    """

    def init_func(m):  # define the initialization function
        classname = m.__class__.__name__
        if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):
            if init_type == 'normal':
                init.normal_(m.weight.data, 0.0, init_gain)
            elif init_type == 'xavier':
                init.xavier_normal_(m.weight.data, gain=init_gain)
            elif init_type == 'kaiming':
                init.kaiming_uniform_(m.weight.data, a=math.sqrt(5))
            elif init_type == 'orthogonal':
                init.orthogonal_(m.weight.data, gain=init_gain)
            else:
                raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
            if hasattr(m, 'bias') and m.bias is not None:
                fan_in, _ = init._calculate_fan_in_and_fan_out(m.weight.data)
                bound = 1 / math.sqrt(fan_in)
                init.uniform_(m.bias.data, -bound, bound)

        elif classname.find('BatchNorm2d') != -1:
            init.normal_(m.weight.data, 1.0, init_gain)
            init.constant_(m.bias.data, 0.0)

    print('initialize network with %s' % init_type)
    net.apply(init_func)  # apply the initialization function


def quantize(img, rgb_range):
    return img.mul(rgb_range).clamp(0, 255).round()
--------------------------------------------------------------------------------
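Finally, a round-trip sketch for `save_checkpoint` and the loading convention used by the test scripts; the tiny wrapped conv and the paths are illustrative. `save_checkpoint` reads `model.module`, so the model must be wrapped in `nn.DataParallel` as the training scripts do:

```python
import torch
from torch import nn
from utils import save_checkpoint

net = nn.DataParallel(nn.Conv2d(4, 3, 3, padding=1))
save_checkpoint(net, "demo", None, 1, "./checkpoints")  # writes ./checkpoints/demo/1.pth
state = torch.load("./checkpoints/demo/1.pth")["state_dict_model"]
net.module.load_state_dict(state)
```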