├── datasets
│   ├── __init__.py
│   └── pix2pix.py
├── make_dataset
│   ├── create_train.py
│   └── random_select.py
├── model.py
├── model_pretrained
│   └── AOD_net_epoch_relu_10.pth
├── readme.md
├── result
│   ├── canyon1_dehaze.jpg
│   ├── canyon2_dehaze.jpg
│   ├── forest1_dehaze.jpg
│   ├── school1_dehaze.jpg
│   ├── school2_dehaze.jpg
│   ├── school3_dehaze.jpg
│   ├── school4_dehaze.jpg
│   ├── tiananmen1_dehaze.jpg
│   ├── tiananmen2_dehaze.jpg
│   ├── tiananmen3_dehaze.jpg
│   ├── underwater1_dehaze.jpg
│   ├── underwater2_dehaze.jpg
│   ├── underwater3_dehaze.jpg
│   └── underwater4_dehaze.jpg
├── test.py
├── test
│   ├── canyon1.jpg
│   ├── canyon2.jpg
│   ├── forest1.jpg
│   ├── school1.jpg
│   ├── school2.jpg
│   ├── school3.jpg
│   ├── school4.jpg
│   ├── tiananmen1.png
│   ├── tiananmen2.png
│   ├── tiananmen3.jpg
│   ├── underwater1.JPG
│   ├── underwater2.JPG
│   ├── underwater3.JPG
│   └── underwater4.jpg
├── train.py
└── transforms
    ├── __init__.py
    ├── __init__.pyc
    ├── __pycache__
    │   ├── __init__.cpython-36.pyc
    │   └── pix2pix.cpython-36.pyc
    └── pix2pix.py

/datasets/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/datasets/pix2pix.py:
--------------------------------------------------------------------------------
import os
import os.path

import numpy as np
import torch.utils.data as data
from PIL import Image

import h5py

IMG_EXTENSIONS = [
    '.jpg', '.JPG', '.jpeg', '.JPEG',
    '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
]


def is_image_file(filename):
    return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)


def make_dataset(dir):
    # Collect the synthesized .h5 sample files under the data root.
    if not os.path.isdir(dir):
        raise Exception('Check dataroot')
    images = []
    for root, _, fnames in sorted(os.walk(dir)):
        for fname in fnames:
            if fname.endswith('.h5'):
                images.append(os.path.join(root, fname))
    return images


def default_loader(path):
    return Image.open(path).convert('RGB')


class pix2pix(data.Dataset):
    def __init__(self, root, transform=None, loader=default_loader, seed=None):
        self.root = root
        self.imgs = make_dataset(root)
        # Note: `transform` is stored but not applied in __getitem__; the h5
        # arrays are used as-is (float values already in [0, 1]).
        self.transform = transform
        self.loader = loader

        if seed is not None:
            np.random.seed(seed)

    def __getitem__(self, index):
        path = self.imgs[index]
        # Each .h5 file holds one (hazy, ground-truth) pair as H x W x C float arrays.
        with h5py.File(path, 'r') as f:
            haze_image = f['haze'][:]
            GT = f['gt'][:]

        # Convert H x W x C to the C x H x W layout PyTorch expects.
        haze_image = np.transpose(haze_image, (2, 0, 1))
        GT = np.transpose(GT, (2, 0, 1))
        return haze_image, GT

    def __len__(self):
        return len(self.imgs)
--------------------------------------------------------------------------------
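Each .h5 sample produced by make_dataset/create_train.py holds a hazy image and its clear ground truth as H x W x 3 float arrays in [0, 1]. A minimal sketch for inspecting one sample and batching the dataset (the paths are placeholders for your own trainset location):
```python
# Inspect one synthesized sample, then wrap the dataset in a DataLoader.
import h5py
import torch.utils.data

from datasets.pix2pix import pix2pix

with h5py.File('path/to/train/1.h5', 'r') as f:
    print(f['haze'].shape, f['gt'].shape)  # both H x W x 3, float in [0, 1]

dataset = pix2pix(root='path/to/train')
loader = torch.utils.data.DataLoader(dataset, batch_size=4, shuffle=True)
haze, gt = next(iter(loader))  # each: 4 x 3 x H x W
```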
/make_dataset/create_train.py:
--------------------------------------------------------------------------------
from __future__ import division

# Synthesize hazy training images from NYU Depth V2 using the atmospheric
# scattering model, and save each (haze, gt) pair as an .h5 file.
import argparse
import os
from random import uniform

import numpy as np
import scipy.misc

import h5py

parser = argparse.ArgumentParser()
parser.add_argument('--nyu', type=str, required=True, help='path to nyu_depth_v2_labeled.mat')
parser.add_argument('--dataset', type=str, required=True, help='path to store the synthesized hazy image dataset')
args = parser.parse_args()
print(args)

nyu_depth = h5py.File(args.nyu + '/nyu_depth_v2_labeled.mat', 'r')

directory = args.dataset + '/train'
saveimgdir = args.dataset + '/demo'
if not os.path.exists(directory):
    os.makedirs(directory)
if not os.path.exists(saveimgdir):
    os.makedirs(saveimgdir)

image = nyu_depth['images']
depth = nyu_depth['depths']

total_num = 0
for index in range(1445):
    gt_image = (image[index, :, :, :]).astype(float)
    gt_image = np.swapaxes(gt_image, 0, 2)
    # Note: scipy.misc.imresize/imsave were removed in recent SciPy releases;
    # this script needs an older SciPy (or a Pillow-based replacement).
    gt_image = scipy.misc.imresize(gt_image, [480, 640]).astype(float)
    gt_image = gt_image / 255

    gt_depth = depth[index, :, :]
    maxhazy = gt_depth.max()
    gt_depth = gt_depth / maxhazy  # normalize depth to [0, 1]

    gt_depth = np.swapaxes(gt_depth, 0, 1)

    # 7 scattering coefficients x 3 atmospheric lights = 21 hazy variants per image.
    for j in range(7):
        for k in range(3):
            # scattering coefficient beta, jittered around 0.4 .. 1.6
            bias = 0.05
            temp_beta = 0.4 + 0.2 * j
            beta = uniform(temp_beta - bias, temp_beta + bias)

            # transmission map t(x) = exp(-beta * d(x))
            tx1 = np.exp(-beta * gt_depth)

            # atmospheric light A, jittered around 0.5 .. 0.9
            abias = 0.1
            temp_a = 0.5 + 0.2 * k
            a = uniform(temp_a - abias, temp_a + abias)
            A = [a, a, a]

            m = gt_image.shape[0]
            n = gt_image.shape[1]

            rep_atmosphere = np.tile(np.reshape(A, [1, 1, 3]), [m, n, 1])
            max_transmission = np.tile(np.reshape(tx1, [m, n, 1]), [1, 1, 3])

            # I(x) = J(x) * t(x) + A * (1 - t(x))
            haze_image = gt_image * max_transmission + rep_atmosphere * (1 - max_transmission)

            total_num = total_num + 1
            scipy.misc.imsave(saveimgdir + '/haze.jpg', haze_image)
            scipy.misc.imsave(saveimgdir + '/gt.jpg', gt_image)

            h5f = h5py.File(directory + '/' + str(total_num) + '.h5', 'w')
            h5f.create_dataset('haze', data=haze_image)
            h5f.create_dataset('gt', data=gt_image)
            h5f.close()
--------------------------------------------------------------------------------
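create_train.py implements the standard atmospheric scattering model: a hazy image is I(x) = J(x)·t(x) + A·(1 − t(x)), with transmission t(x) = exp(−β·d(x)) computed from depth. A self-contained sketch of that synthesis step (the function name and random inputs below are illustrative, not part of the repo):
```python
import numpy as np

def synthesize_haze(clear, depth, beta, A):
    """Apply I = J*t + A*(1 - t), with t = exp(-beta * depth).

    clear: H x W x 3 float image in [0, 1]; depth: H x W, normalized to [0, 1].
    """
    t = np.exp(-beta * depth)[..., None]          # H x W x 1 transmission map
    return clear * t + np.asarray(A) * (1.0 - t)  # broadcast A over H x W

# Example: a larger beta (denser haze) lowers transmission everywhere.
J = np.random.rand(480, 640, 3)
d = np.random.rand(480, 640)
I = synthesize_haze(J, d, beta=1.0, A=(0.7, 0.7, 0.7))
```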
/make_dataset/random_select.py:
--------------------------------------------------------------------------------
import argparse
import os
import random
import shutil

parser = argparse.ArgumentParser()
parser.add_argument('--trainroot', type=str, required=True, help='path to train directory')
parser.add_argument('--valroot', type=str, required=True, help='path to validation directory')
args = parser.parse_args()
print(args)

train = args.trainroot
val = args.valroot

if not os.path.exists(val):
    os.makedirs(val)
# Move 3,169 randomly chosen samples from the train set into the validation set.
for i in range(3169):
    a = random.choice(os.listdir(train))
    shutil.move(train + '/' + a, val + '/' + a)
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F


class AODnet(nn.Module):
    """AOD-Net: estimate K(x) from the hazy input I(x), then recover
    J(x) = K(x) * I(x) - K(x) + b, with a constant bias b = 1."""

    def __init__(self):
        super(AODnet, self).__init__()
        # Five conv layers with growing receptive fields (kernel sizes 1, 3, 5, 7, 3),
        # fused by concatenating earlier feature maps.
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=3, kernel_size=1)
        self.conv2 = nn.Conv2d(in_channels=3, out_channels=3, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(in_channels=6, out_channels=3, kernel_size=5, padding=2)
        self.conv4 = nn.Conv2d(in_channels=6, out_channels=3, kernel_size=7, padding=3)
        self.conv5 = nn.Conv2d(in_channels=12, out_channels=3, kernel_size=3, padding=1)
        self.b = 1

    def forward(self, x):
        x1 = F.relu(self.conv1(x))
        x2 = F.relu(self.conv2(x1))
        cat1 = torch.cat((x1, x2), 1)
        x3 = F.relu(self.conv3(cat1))
        cat2 = torch.cat((x2, x3), 1)
        x4 = F.relu(self.conv4(cat2))
        cat3 = torch.cat((x1, x2, x3, x4), 1)
        k = F.relu(self.conv5(cat3))

        if k.size() != x.size():
            raise Exception("k, haze image are different size!")

        output = k * x - k + self.b
        return F.relu(output)
--------------------------------------------------------------------------------
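The network estimates a single modulation map K(x) and applies J(x) = K(x)·I(x) − K(x) + b. Since it is fully convolutional, any 3-channel input size works and the output shape matches the input. A quick shape check:
```python
# Minimal sanity check on a dummy hazy batch.
import torch
from model import AODnet

net = AODnet()
hazy = torch.rand(1, 3, 224, 224)  # dummy hazy image batch
dehazed = net(hazy)
print(dehazed.shape)  # torch.Size([1, 3, 224, 224])
```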
/model_pretrained/AOD_net_epoch_relu_10.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/weberwcwei/AODnet-by-pytorch/3330478481b04f7cfdf696a3e5b1bc7e12970cc8/model_pretrained/AOD_net_epoch_relu_10.pth
--------------------------------------------------------------------------------
/readme.md:
--------------------------------------------------------------------------------
# AOD-Net by Pytorch
This project implements [AOD-Net: All-in-One Network for Dehazing](https://arxiv.org/abs/1707.06543) for image dehazing in Python and PyTorch. The model removes haze, smoke, and even water impurities from images.

The repository includes:
* Source code of AOD-Net
* Building code for synthesized hazy images based on [NYU Depth V2](https://cs.nyu.edu/~silberman/datasets/nyu_depth_v2.html)
* Training code for our hazy dataset
* Pre-trained model for AOD-Net

# Requirements
Python 3.6, PyTorch 0.4.0 and other common packages

## NYU Depth V2
To build the synthetic hazy dataset, you'll also need:
* The [NYU Depth V2 labeled dataset](http://horatio.cs.nyu.edu/mit/silberman/nyu_depth_v2/nyu_depth_v2_labeled.mat)

# Training Part
## Dataset Setup
1. Clone this repository
2. Create the dataset from the repository root directory
```bash
$ cd make_dataset
$ python create_train.py --nyu {Your NYU Depth V2 path} --dataset {Your trainset path}
```
3. Randomly pick 3,169 images as the validation set
```bash
$ python random_select.py --trainroot {Your trainset path} --valroot {Your valset path}
```
## Start training
4. Train AOD-Net
```bash
$ python train.py --dataroot {Your trainset path} --valDataroot {Your valset path} --cuda
```
# Testing Part
5. Test a hazy image with AOD-Net
```bash
$ python test.py --input_image test/canyon1.jpg --model model_pretrained/AOD_net_epoch_relu_10.pth --output_filename result/canyon1_dehaze.jpg --cuda
```
--------------------------------------------------------------------------------
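The checkpoint is a fully pickled model (see `checkpoint()` in train.py), so loading it only needs `torch.load` with model.py importable from the working directory. A minimal sketch:
```python
# Load the pickled model; model.py must be importable for unpickling to work.
import torch

net = torch.load('model_pretrained/AOD_net_epoch_relu_10.pth',
                 map_location='cpu')  # drop map_location when using CUDA
net.eval()
```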
/result/canyon1_dehaze.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/weberwcwei/AODnet-by-pytorch/3330478481b04f7cfdf696a3e5b1bc7e12970cc8/result/canyon1_dehaze.jpg
--------------------------------------------------------------------------------
/result/canyon2_dehaze.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/weberwcwei/AODnet-by-pytorch/3330478481b04f7cfdf696a3e5b1bc7e12970cc8/result/canyon2_dehaze.jpg
--------------------------------------------------------------------------------
/result/forest1_dehaze.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/weberwcwei/AODnet-by-pytorch/3330478481b04f7cfdf696a3e5b1bc7e12970cc8/result/forest1_dehaze.jpg
--------------------------------------------------------------------------------
/result/school1_dehaze.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/weberwcwei/AODnet-by-pytorch/3330478481b04f7cfdf696a3e5b1bc7e12970cc8/result/school1_dehaze.jpg
--------------------------------------------------------------------------------
/result/school2_dehaze.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/weberwcwei/AODnet-by-pytorch/3330478481b04f7cfdf696a3e5b1bc7e12970cc8/result/school2_dehaze.jpg
--------------------------------------------------------------------------------
/result/school3_dehaze.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/weberwcwei/AODnet-by-pytorch/3330478481b04f7cfdf696a3e5b1bc7e12970cc8/result/school3_dehaze.jpg
--------------------------------------------------------------------------------
/result/school4_dehaze.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/weberwcwei/AODnet-by-pytorch/3330478481b04f7cfdf696a3e5b1bc7e12970cc8/result/school4_dehaze.jpg
--------------------------------------------------------------------------------
/result/tiananmen1_dehaze.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/weberwcwei/AODnet-by-pytorch/3330478481b04f7cfdf696a3e5b1bc7e12970cc8/result/tiananmen1_dehaze.jpg
--------------------------------------------------------------------------------
/result/tiananmen2_dehaze.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/weberwcwei/AODnet-by-pytorch/3330478481b04f7cfdf696a3e5b1bc7e12970cc8/result/tiananmen2_dehaze.jpg
--------------------------------------------------------------------------------
/result/tiananmen3_dehaze.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/weberwcwei/AODnet-by-pytorch/3330478481b04f7cfdf696a3e5b1bc7e12970cc8/result/tiananmen3_dehaze.jpg
--------------------------------------------------------------------------------
/result/underwater1_dehaze.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/weberwcwei/AODnet-by-pytorch/3330478481b04f7cfdf696a3e5b1bc7e12970cc8/result/underwater1_dehaze.jpg
--------------------------------------------------------------------------------
/result/underwater2_dehaze.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/weberwcwei/AODnet-by-pytorch/3330478481b04f7cfdf696a3e5b1bc7e12970cc8/result/underwater2_dehaze.jpg
--------------------------------------------------------------------------------
/result/underwater3_dehaze.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/weberwcwei/AODnet-by-pytorch/3330478481b04f7cfdf696a3e5b1bc7e12970cc8/result/underwater3_dehaze.jpg
--------------------------------------------------------------------------------
/result/underwater4_dehaze.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/weberwcwei/AODnet-by-pytorch/3330478481b04f7cfdf696a3e5b1bc7e12970cc8/result/underwater4_dehaze.jpg
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
import argparse

import PIL.Image as Image
import scipy.misc
import torch
import torch.nn.parallel
import torchvision.transforms as transforms
from torch.autograd import Variable

parser = argparse.ArgumentParser()
parser.add_argument('--input_image', type=str, required=True, help='input image to use')
parser.add_argument('--model', type=str, required=True, help='model file to use')
parser.add_argument('--output_filename', type=str, required=True, help='where to save the output image')
parser.add_argument('--cuda', action='store_true', help='use cuda?')

args = parser.parse_args()
print(args)

#===== load model =====
print('===> Loading model')
net = torch.load(args.model)
if args.cuda:
    net = net.cuda()

#===== Load input image =====
# Normalize to [-1, 1], matching the training-time transforms in train.py.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
])

img = Image.open(args.input_image).convert('RGB')
imgIn = transform(img).unsqueeze_(0)

#===== Test procedures =====
varIn = Variable(imgIn)
if args.cuda:
    varIn = varIn.cuda()

prediction = net(varIn)
prediction = prediction.data.cpu().numpy().squeeze().transpose((1, 2, 0))
# scipy.misc.toimage rescales the float array to [0, 255]; it was removed
# in recent SciPy releases, so an older SciPy is required here.
scipy.misc.toimage(prediction).save(args.output_filename)
--------------------------------------------------------------------------------
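If your SciPy no longer ships `scipy.misc.toimage`, a Pillow-based save is a close replacement. Note it clips to [0, 1] rather than applying toimage's default min/max rescaling, and `save_prediction` is an illustrative helper, not part of the repo:
```python
# Sketch of a drop-in replacement for scipy.misc.toimage(...).save(...):
# scale the float prediction to 8-bit and save with Pillow.
import numpy as np
from PIL import Image

def save_prediction(prediction, filename):
    arr = np.clip(prediction, 0.0, 1.0)  # network output is ReLU-clamped below 0
    Image.fromarray((arr * 255).round().astype(np.uint8)).save(filename)
```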
/test/canyon1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/weberwcwei/AODnet-by-pytorch/3330478481b04f7cfdf696a3e5b1bc7e12970cc8/test/canyon1.jpg
--------------------------------------------------------------------------------
/test/canyon2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/weberwcwei/AODnet-by-pytorch/3330478481b04f7cfdf696a3e5b1bc7e12970cc8/test/canyon2.jpg
--------------------------------------------------------------------------------
/test/forest1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/weberwcwei/AODnet-by-pytorch/3330478481b04f7cfdf696a3e5b1bc7e12970cc8/test/forest1.jpg
--------------------------------------------------------------------------------
/test/school1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/weberwcwei/AODnet-by-pytorch/3330478481b04f7cfdf696a3e5b1bc7e12970cc8/test/school1.jpg
--------------------------------------------------------------------------------
/test/school2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/weberwcwei/AODnet-by-pytorch/3330478481b04f7cfdf696a3e5b1bc7e12970cc8/test/school2.jpg
--------------------------------------------------------------------------------
/test/school3.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/weberwcwei/AODnet-by-pytorch/3330478481b04f7cfdf696a3e5b1bc7e12970cc8/test/school3.jpg
--------------------------------------------------------------------------------
/test/school4.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/weberwcwei/AODnet-by-pytorch/3330478481b04f7cfdf696a3e5b1bc7e12970cc8/test/school4.jpg
--------------------------------------------------------------------------------
/test/tiananmen1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/weberwcwei/AODnet-by-pytorch/3330478481b04f7cfdf696a3e5b1bc7e12970cc8/test/tiananmen1.png
--------------------------------------------------------------------------------
/test/tiananmen2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/weberwcwei/AODnet-by-pytorch/3330478481b04f7cfdf696a3e5b1bc7e12970cc8/test/tiananmen2.png
--------------------------------------------------------------------------------
/test/tiananmen3.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/weberwcwei/AODnet-by-pytorch/3330478481b04f7cfdf696a3e5b1bc7e12970cc8/test/tiananmen3.jpg
--------------------------------------------------------------------------------
/test/underwater1.JPG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/weberwcwei/AODnet-by-pytorch/3330478481b04f7cfdf696a3e5b1bc7e12970cc8/test/underwater1.JPG
--------------------------------------------------------------------------------
/test/underwater2.JPG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/weberwcwei/AODnet-by-pytorch/3330478481b04f7cfdf696a3e5b1bc7e12970cc8/test/underwater2.JPG
--------------------------------------------------------------------------------
/test/underwater3.JPG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/weberwcwei/AODnet-by-pytorch/3330478481b04f7cfdf696a3e5b1bc7e12970cc8/test/underwater3.JPG
--------------------------------------------------------------------------------
/test/underwater4.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/weberwcwei/AODnet-by-pytorch/3330478481b04f7cfdf696a3e5b1bc7e12970cc8/test/underwater4.jpg
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
import argparse
import os
import random

import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable
from torch.utils.data import DataLoader

from model import AODnet

parser = argparse.ArgumentParser()
parser.add_argument('--dataset', required=False, default='pix2pix', help='dataset type (only pix2pix is implemented)')
parser.add_argument('--dataroot', required=True, help='path to train dataset')
parser.add_argument('--valDataroot', required=True, help='path to val dataset')
parser.add_argument('--valBatchSize', type=int, default=32, help='validation batch size')
parser.add_argument('--cuda', action='store_true', help='use cuda?')
parser.add_argument('--lr', type=float, default=1e-4, help='learning rate. Default=1e-4')
parser.add_argument('--threads', type=int, default=4, help='number of data loader workers (set to 0 on Windows)')
parser.add_argument('--exp', default='pretrain', help='folder for model checkpoints')
parser.add_argument('--printEvery', type=int, default=50, help='number of batches between average-loss reports')
parser.add_argument('--batchSize', type=int, default=32, help='training batch size')
parser.add_argument('--epochSize', type=int, default=840, help='number of batches per epoch (for validating once)')
parser.add_argument('--nEpochs', type=int, default=10, help='number of epochs for training')

args = parser.parse_args()
print(args)

args.manualSeed = random.randint(1, 10000)
random.seed(args.manualSeed)
torch.manual_seed(args.manualSeed)
torch.cuda.manual_seed_all(args.manualSeed)
print("Random Seed: ", args.manualSeed)

#===== Dataset =====
def getLoader(datasetName, dataroot, batchSize, workers,
              mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), split='train', shuffle=True, seed=None):

    if datasetName == 'pix2pix':
        from datasets.pix2pix import pix2pix as commonDataset
        import transforms.pix2pix as transforms
        if split == 'train':
            dataset = commonDataset(root=dataroot,
                                    transform=transforms.Compose([
                                        transforms.RandomHorizontalFlip(),
                                        transforms.ToTensor(),
                                        transforms.Normalize(mean, std),
                                    ]),
                                    seed=seed)
        else:
            dataset = commonDataset(root=dataroot,
                                    transform=transforms.Compose([
                                        transforms.ToTensor(),
                                        transforms.Normalize(mean, std),
                                    ]),
                                    seed=seed)

    dataloader = torch.utils.data.DataLoader(dataset,
                                             batch_size=batchSize,
                                             shuffle=shuffle,
                                             num_workers=int(workers))
    return dataloader

trainDataloader = getLoader(args.dataset,
                            args.dataroot,
                            args.batchSize,
                            args.threads,
                            mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
                            split='train',
                            shuffle=True,
                            seed=args.manualSeed)

valDataloader = getLoader(args.dataset,
                          args.valDataroot,
                          args.valBatchSize,
                          args.threads,
                          mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
                          split='val',
                          shuffle=False,
                          seed=args.manualSeed)

#===== DehazeNet =====
print('===> Building model')
net = AODnet()
if args.cuda:
    net = net.cuda()

#===== Loss function & optimizer =====
criterion = torch.nn.MSELoss()
if args.cuda:
    criterion = criterion.cuda()

optimizer = torch.optim.Adam(net.parameters(), lr=args.lr, weight_decay=0.0001)
# Halve the learning rate every 53,760 scheduler steps (step_size counts iterations).
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=53760, gamma=0.5)

#===== Training and validation procedures =====
def train(epoch):
    net.train()
    epoch_loss = 0
    for iteration, batch in enumerate(trainDataloader, 0):
        varIn, varTar = Variable(batch[0]), Variable(batch[1])
        varIn, varTar = varIn.float(), varTar.float()

        if args.cuda:
            varIn = varIn.cuda()
            varTar = varTar.cuda()

        optimizer.zero_grad()
        loss = criterion(net(varIn), varTar)
        epoch_loss += loss.item()
        loss.backward()
        optimizer.step()
        scheduler.step()  # step once per batch, matching the iteration-sized step_size
        if (iteration + 1) % args.printEvery == 0:
            print("===> Epoch[{}]({}/{}): Avg. Loss: {:.4f}".format(
                epoch, iteration + 1, len(trainDataloader), epoch_loss / args.printEvery))
            epoch_loss = 0

def validate():
    net.eval()
    avg_mse = 0
    for _, batch in enumerate(valDataloader, 0):
        varIn, varTar = Variable(batch[0]), Variable(batch[1])
        varIn, varTar = varIn.float(), varTar.float()

        if args.cuda:
            varIn = varIn.cuda()
            varTar = varTar.cuda()

        prediction = net(varIn)
        mse = criterion(prediction, varTar)
        avg_mse += mse.item()
    print("===> Avg. Loss: {:.4f}".format(avg_mse / len(valDataloader)))

def checkpoint(epoch):
    if not os.path.exists('./model_pretrained'):
        os.makedirs('./model_pretrained')
    model_out_path = "./model_pretrained/AOD_net_epoch_relu_{}.pth".format(epoch)
    torch.save(net, model_out_path)
    print("Checkpoint saved to {}".format(model_out_path))

#===== Main procedure =====
for epoch in range(1, args.nEpochs + 1):
    train(epoch)
    validate()
    checkpoint(epoch)
--------------------------------------------------------------------------------
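`checkpoint()` above pickles the whole module, which is what test.py and the shipped .pth expect. An alternative (not what this repo uses) is saving only the `state_dict`, which is more robust across code and PyTorch versions; a sketch:
```python
# Alternative checkpointing: save only the weights. test.py would then need
# to build AODnet() and call load_state_dict instead of a bare torch.load.
import torch
from model import AODnet

def save_weights(net, path):
    torch.save(net.state_dict(), path)

def load_weights(path):
    net = AODnet()
    net.load_state_dict(torch.load(path, map_location='cpu'))
    return net
```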
/transforms/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/transforms/__init__.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/weberwcwei/AODnet-by-pytorch/3330478481b04f7cfdf696a3e5b1bc7e12970cc8/transforms/__init__.pyc
--------------------------------------------------------------------------------
/transforms/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/weberwcwei/AODnet-by-pytorch/3330478481b04f7cfdf696a3e5b1bc7e12970cc8/transforms/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/transforms/__pycache__/pix2pix.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/weberwcwei/AODnet-by-pytorch/3330478481b04f7cfdf696a3e5b1bc7e12970cc8/transforms/__pycache__/pix2pix.cpython-36.pyc
--------------------------------------------------------------------------------
/transforms/pix2pix.py:
--------------------------------------------------------------------------------
from __future__ import division

import numbers
import random
import types

import numpy as np
import torch
from PIL import Image, ImageOps


class Compose(object):
    """Composes several transforms together, applying each to an image pair.

    Args:
        transforms (List[Transform]): list of transforms to compose.

    Example:
        >>> transforms.Compose([
        >>>     transforms.CenterCrop(10),
        >>>     transforms.ToTensor(),
        >>> ])
    """
    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, imgA, imgB):
        for t in self.transforms:
            imgA, imgB = t(imgA, imgB)
        return imgA, imgB


class ToTensor(object):
    """Converts a PIL.Image or numpy.ndarray (H x W x C) in the range
    [0, 255] to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0].
    """
    def __call__(self, picA, picB):
        pics = [picA, picB]
        output = []
        for pic in pics:
            if isinstance(pic, np.ndarray):
                # handle numpy array
                img = torch.from_numpy(pic.transpose((2, 0, 1)))
            else:
                # handle PIL Image
                img = torch.ByteTensor(torch.ByteStorage.from_buffer(pic.tobytes()))
                # PIL image mode: 1, L, P, I, F, RGB, YCbCr, RGBA, CMYK
                if pic.mode == 'YCbCr':
                    nchannel = 3
                else:
                    nchannel = len(pic.mode)
                img = img.view(pic.size[1], pic.size[0], nchannel)
                # put it from HWC to CHW format
                # yikes, this transpose takes 80% of the loading time/CPU
                img = img.transpose(0, 1).transpose(0, 2).contiguous()
            img = img.float().div(255.)
            output.append(img)
        return output[0], output[1]


class ToPILImage(object):
    """Converts a torch.*Tensor of range [0, 1] and shape C x H x W
    or numpy ndarray of dtype=uint8, range [0, 255] and shape H x W x C
    to a PIL.Image of range [0, 255].
    """
    def __call__(self, picA, picB):
        pics = [picA, picB]
        output = []
        for pic in pics:
            npimg = pic
            mode = None
            if not isinstance(npimg, np.ndarray):
                npimg = pic.mul(255).byte().numpy()
                npimg = np.transpose(npimg, (1, 2, 0))

            if npimg.shape[2] == 1:
                npimg = npimg[:, :, 0]
                mode = "L"
            output.append(Image.fromarray(npimg, mode=mode))

        return output[0], output[1]


class Normalize(object):
    """Given mean: (R, G, B) and std: (R, G, B),
    will normalize each channel of the torch.*Tensor, i.e.
    channel = (channel - mean) / std
    """
    def __init__(self, mean, std):
        self.mean = mean
        self.std = std

    def __call__(self, tensorA, tensorB):
        tensors = [tensorA, tensorB]
        output = []
        for tensor in tensors:
            # TODO: make efficient
            for t, m, s in zip(tensor, self.mean, self.std):
                t.sub_(m).div_(s)
            output.append(tensor)
        return output[0], output[1]
class Scale(object):
    """Rescales the input PIL.Image to the given 'size'.
    'size' will be the size of the smaller edge.
    For example, if height > width, then the image will be
    rescaled to (size * height / width, size).
    size: size of the smaller edge
    interpolation: Default: PIL.Image.BILINEAR
    """
    def __init__(self, size, interpolation=Image.BILINEAR):
        self.size = size
        self.interpolation = interpolation

    def __call__(self, imgA, imgB):
        imgs = [imgA, imgB]
        output = []
        for img in imgs:
            w, h = img.size
            if (w <= h and w == self.size) or (h <= w and h == self.size):
                output.append(img)
            elif w < h:
                ow = self.size
                oh = int(self.size * h / w)
                output.append(img.resize((ow, oh), self.interpolation))
            else:
                oh = self.size
                ow = int(self.size * w / h)
                output.append(img.resize((ow, oh), self.interpolation))
        return output[0], output[1]


class CenterCrop(object):
    """Crops the given PIL.Image at the center to have a region of
    the given size. size can be a tuple (target_height, target_width)
    or an integer, in which case the target will be of a square shape (size, size).
    """
    def __init__(self, size):
        if isinstance(size, numbers.Number):
            self.size = (int(size), int(size))
        else:
            self.size = size

    def __call__(self, imgA, imgB):
        imgs = [imgA, imgB]
        output = []
        for img in imgs:
            w, h = img.size
            th, tw = self.size
            x1 = int(round((w - tw) / 2.))
            y1 = int(round((h - th) / 2.))
            output.append(img.crop((x1, y1, x1 + tw, y1 + th)))
        return output[0], output[1]


class Pad(object):
    """Pads the given PIL.Image on all sides with the given "pad" value."""
    def __init__(self, padding, fill=0):
        assert isinstance(padding, numbers.Number)
        assert isinstance(fill, numbers.Number) or isinstance(fill, str) or isinstance(fill, tuple)
        self.padding = padding
        self.fill = fill

    def __call__(self, imgA, imgB):
        imgs = [imgA, imgB]
        output = []
        for img in imgs:
            output.append(ImageOps.expand(img, border=self.padding, fill=self.fill))
        return output[0], output[1]


class Lambda(object):
    """Applies a lambda as a transform."""
    def __init__(self, lambd):
        assert isinstance(lambd, types.LambdaType)
        self.lambd = lambd

    def __call__(self, imgA, imgB):
        imgs = [imgA, imgB]
        output = []
        for img in imgs:
            output.append(self.lambd(img))
        return output[0], output[1]


class RandomCrop(object):
    """Crops the given PIL.Image at a random location to have a region of
    the given size.
    size can be a tuple (target_height, target_width)
    or an integer, in which case the target will be of a square shape (size, size).
    """
    def __init__(self, size, padding=0):
        if isinstance(size, numbers.Number):
            self.size = (int(size), int(size))
        else:
            self.size = size
        self.padding = padding

    def __call__(self, imgA, imgB):
        imgs = [imgA, imgB]
        output = []
        x1 = -1
        y1 = -1
        for img in imgs:
            if self.padding > 0:
                img = ImageOps.expand(img, border=self.padding, fill=0)

            w, h = img.size
            th, tw = self.size
            if w == tw and h == th:
                output.append(img)
                continue

            # Draw the crop offset once, then reuse it for the second image
            # so the pair stays aligned.
            if x1 == -1 and y1 == -1:
                x1 = random.randint(0, w - tw)
                y1 = random.randint(0, h - th)
            output.append(img.crop((x1, y1, x1 + tw, y1 + th)))
        return output[0], output[1]


class RandomHorizontalFlip(object):
    """Randomly horizontally flips the given PIL.Image with a probability of 0.5.
    """
    def __call__(self, imgA, imgB):
        imgs = [imgA, imgB]
        output = []
        # One coin flip per call: both images flip together or not at all.
        flag = random.random() < 0.5
        for img in imgs:
            if flag:
                output.append(img.transpose(Image.FLIP_LEFT_RIGHT))
            else:
                output.append(img)
        return output[0], output[1]
--------------------------------------------------------------------------------
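These transforms are paired variants of their torchvision counterparts: each `__call__` takes the hazy image and its ground truth together, so random choices (flip decision, crop offset) are shared and the pair stays pixel-aligned. A usage sketch mirroring the train-time pipeline in train.py (the ground-truth image below is a placeholder):
```python
# Apply one pipeline jointly to a (hazy, clear) pair.
from PIL import Image
import transforms.pix2pix as T

pair_tf = T.Compose([
    T.RandomHorizontalFlip(),  # same flip for both images
    T.ToTensor(),              # HWC uint8 [0, 255] -> CHW float [0, 1]
    T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),  # [0, 1] -> [-1, 1]
])

hazy = Image.open('test/canyon1.jpg').convert('RGB')
clear = hazy.copy()  # placeholder for the real ground-truth image
hazy_t, clear_t = pair_tf(hazy, clear)
print(hazy_t.shape, hazy_t.min().item(), hazy_t.max().item())
```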