├── results.jpg ├── run_exps.sh ├── concat_results.py ├── README.md ├── Interp.py ├── synthdata.py ├── pairedtransforms.py ├── datasets.py ├── main.py └── networks.py
/results.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zmurez/TurbulentWater/HEAD/results.jpg
--------------------------------------------------------------------------------
/run_exps.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | echo "Hello World!"
4 | CUDA_VISIBLE_DEVICES=1 python main.py --exp-name both_L1VGGAdv
5 | CUDA_VISIBLE_DEVICES=1 python main.py --exp-name color_L1VGGAdv --no-warp-net
6 | CUDA_VISIBLE_DEVICES=1 python main.py --exp-name warp_L1VGG --no-color-net
7 | CUDA_VISIBLE_DEVICES=1 python main.py --exp-name warp_L1 --no-color-net --weight-Y-VGG 0
8 | CUDA_VISIBLE_DEVICES=1 python main.py --exp-name color_L1VGG --no-warp-net --weight-Z-Adv 0
9 | CUDA_VISIBLE_DEVICES=1 python main.py --exp-name warp_L1VGG_synth --no-color-net --synth-data
10 | python concat_results.py
11 |
--------------------------------------------------------------------------------
/concat_results.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import os
3 | from datasets import ImageFolder
4 | import torchvision.transforms as transforms
5 | from PIL import Image
6 |
7 |
8 | rawroot='/mnt/Data1/Water_Real'
9 | outroot='./results'
10 | outname='concat'
11 |
12 | datasets = []
13 |
14 | # input images
15 | datasets.append( ImageFolder(rawroot, transform=transforms.Compose([transforms.Resize(256), transforms.CenterCrop(256)]), return_path=True) )
16 |
17 | # results images
18 | for exp_name in ['warp_L1', 'warp_L1VGG', 'color_L1VGG', 'color_L1VGGAdv', 'both_L1VGGAdv']:
19 |     datasets.append( ImageFolder(os.path.join(outroot,'%s_test'%exp_name), return_path=True) )
20 |
21 | # concatenate and save each image
22 | for i, imgs in enumerate(zip(*datasets)):
23 |     name = imgs[0][-1]
24 |     print('%d/%d %s'%(i, len(datasets[0]), name))
25 |
26 |     if not os.path.exists(os.path.join(outroot, outname, os.path.dirname(name))):
27 |         os.makedirs(os.path.join(outroot, outname, os.path.dirname(name)))
28 |
29 |     im = Image.fromarray( np.hstack([np.asarray(img[0]) for img in imgs]) )
30 |     im.save(os.path.join(outroot, outname, name))
31 |
32 | # concatenate the best examples into a figure
33 | imgs=[]
34 | for name in ['Tank/262A4109.JPG','Wild/262A4895.JPG','Wild/262A4984.JPG']:
35 |     imgs.append( Image.open(os.path.join(outroot, outname, name)) )
36 | im = Image.fromarray( np.vstack([np.asarray(img) for img in imgs]) )
37 | im.save(os.path.join(outroot, outname+'.jpg'))
38 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # TurbulentWater
2 | Code for "Learning to See through Turbulent Water", WACV 2018
3 |
4 | Data and pretrained models are available at http://cseweb.ucsd.edu/~viscomp/projects/WACV18Water/
5 |
6 | ## Instructions
7 | - download train.zip, val.zip and test.zip from http://cseweb.ucsd.edu/~viscomp/projects/WACV18Water/
8 | - unzip train.zip into DATAROOT/Water
9 | - unzip val.zip into DATAROOT/Water (note these images correspond to the ImageNet test set)
10 | - unzip test.zip into DATAROOT/Water_Real
11 | - download the original ImageNet training and test sets to DATAROOT/ImageNet/
12 |
13 | ### Training
14 | ```
15 | python main.py --dataroot DATAROOT
16 | ```
17 | ### Testing
18 | ```
19 | python main.py --test --dataroot DATAROOT --exp-name warp_L1 --no-color-net --weight-Y-VGG 0
20 | python main.py --test --dataroot DATAROOT --exp-name warp_L1VGG --no-color-net
21 | python main.py --test --dataroot DATAROOT --exp-name color_L1VGG --no-warp-net --weight-Z-Adv 0
22 | python main.py --test --dataroot DATAROOT --exp-name color_L1VGGAdv --no-warp-net
23 | python main.py --test --dataroot DATAROOT --exp-name both_L1VGGAdv
24 |
25 | python main.py --test --dataroot DATAROOT --exp-name warp_L1VGG_synth --no-color-net
26 | ```
27 |
28 | ## Minor modifications from the paper
29 | - added a reconstruction (L1) and perceptual (VGG) loss to the output of the WarpNet
30 | - all networks are trained for 3 epochs with all the losses enabled from the start, with a constant learning rate of 0.0002
31 | - all hyper-parameter weights are set to 1.0, except the perceptual and adversarial losses on the final output, which are set to 0.5 and 0.2 respectively
32 | - replaced transposed convolutions with nearest-neighbor upsampling
33 | - replaced instance normalization (and final-layer denormalization) with group normalization
34 | - added a 7x7 conv layer (to project from RGB to features and vice versa) to the beginning and end of the networks
35 |
36 | ## Results
37 | ![results](results.jpg)
38 |
--------------------------------------------------------------------------------
/Interp.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 |
4 |
5 | def interp1(f,i):
6 |     # f is the signal to be interpolated with size [num_batch, channels, length]
7 |     # i are the indices into f with size [num_batch, new_height, new_width]
8 |     # returns a new signal of size [num_batch, channels, new_height, new_width]
9 |     f = f.transpose(1,2).contiguous()
10 |     num_batch, length, channels = f.size()
11 |     new_size = (i.size()+(channels,))
12 |     f_flat = f.view(-1,channels)
13 |     i = i.clamp(0,length-1)
14 |     i0 = i.floor()
15 |     i0_ = i0.long()
16 |     i1_ = (i0_+1).clamp(0,length-1)
17 |     batch_ind = torch.arange(0,num_batch).long().view(-1,1,1).expand_as(i0)
18 |     if f.is_cuda:
19 |         batch_ind = batch_ind.cuda()
20 |     idx0 = batch_ind*length + i0_
21 |     idx1 = batch_ind*length + i1_
22 |     f0 = f_flat.index_select(0,idx0.view(-1)).view(*new_size)
23 |     f1 = f_flat.index_select(0,idx1.view(-1)).view(*new_size)
24 |     di = (i-i0).unsqueeze(3).expand_as(f0)
25 |     f = f0*(1-di) + f1*di
26 |     return f.transpose(2,3).transpose(1,2).contiguous()
27 |
28 |
29 | def interp2(f,i,j):
30 |     # f is the image to be interpolated with size [num_batch, channels, height, width]
31 |     # i,j are grids of indices into f with size [num_batch, new_height, new_width]
32 |     # returns a new image of size [num_batch, channels, new_height, new_width]
33 |     f = f.transpose(1,2).transpose(2,3).contiguous()
34 |     num_batch, height, width, channels = f.size()
35 |     new_size = (i.size()+(channels,))
36 |     f_flat = f.view(-1,channels)
37 |     i = i.clamp(0,height-1)
38 |     j = j.clamp(0,width-1)
39 |     i0 = i.floor()
40 |     j0 = j.floor()
41 |     i0_ = i0.long()
42 |     j0_ = j0.long()
43 |     i1_ = (i0_+1).clamp(0,height-1)
44 |     j1_ = (j0_+1).clamp(0,width-1)
45 |     batch_ind = torch.arange(0,num_batch).long().view(-1,1,1).expand_as(i0)
46 |     if f.is_cuda:
47 |         batch_ind = batch_ind.cuda()
48 |     idx00 = batch_ind*width*height + i0_*width + j0_
49 |     idx01 = batch_ind*width*height + i0_*width + j1_
50 |     idx10 = batch_ind*width*height + i1_*width + j0_
51 |     idx11 =
batch_ind*width*height + i1_*width + j1_ 52 | f00 = f_flat.index_select(0,idx00.view(-1)).view(*new_size) 53 | f01 = f_flat.index_select(0,idx01.view(-1)).view(*new_size) 54 | f10 = f_flat.index_select(0,idx10.view(-1)).view(*new_size) 55 | f11 = f_flat.index_select(0,idx11.view(-1)).view(*new_size) 56 | di = (i-i0).unsqueeze(3).expand_as(f00) 57 | dj = (j-j0).unsqueeze(3).expand_as(f00) 58 | f0 = f00*(1-dj) + f01*dj 59 | f1 = f10*(1-dj) + f11*dj 60 | f = f0*(1-di) + f1*di 61 | return f.transpose(2,3).transpose(1,2).contiguous() 62 | 63 | def warp(im,di,dj): 64 | # f is the image to be interpolated with size [num_batch, channels, height, width] 65 | # di,dj are grids of index offsets into f with size [num_batch, height, width] 66 | i,j = np.meshgrid(np.arange(di.size()[1], dtype='float32'), np.arange(di.size()[2], dtype='float32'), indexing='ij') 67 | i = torch.from_numpy(i).unsqueeze(0).expand_as(im[:,0,:,:]).float() 68 | j = torch.from_numpy(j).unsqueeze(0).expand_as(im[:,0,:,:]).float() 69 | if im.is_cuda: 70 | i,j=i.cuda(), j.cuda() 71 | return interp2(im, i+di, j+dj) 72 | 73 | -------------------------------------------------------------------------------- /synthdata.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import Interp 4 | import numbers 5 | 6 | class SynthData(object): 7 | def __init__(self, size, n=1, reset_after=250): 8 | if isinstance(size, numbers.Number): 9 | size = (int(size), int(size)) 10 | 11 | # coordinate grid 12 | x,y=np.meshgrid(np.linspace(-1,1,size[1]+2),np.linspace(1,-1,size[0]+2)) 13 | self.x=torch.from_numpy(x).float() 14 | self.y=torch.from_numpy(y).float() 15 | # height field 16 | self.u=[torch.zeros(size[0]+2,size[1]+2) for _ in range(n)] 17 | # velocity field 18 | self.v=[torch.zeros(size[0]+2,size[1]+2) for _ in range(n)] 19 | 20 | self.n = n # number of simulations (for independent batches) 21 | self.reset_after=reset_after 22 | self.reset() 23 | 24 | def reset(self): 25 | for u in self.u+self.v: 26 | u.fill_(0) 27 | self.ct=0 28 | self.random_config() 29 | #for i in range(self.n): 30 | # self.step(i, 50) 31 | 32 | def random_config(self): 33 | # initialize random distribution for perturb to sample from 34 | self.damp = np.random.rand(self.n)*.005+.994 35 | self.motion_blur = np.random.randint(7,15, size=self.n) 36 | self.window = np.hstack(( -np.ones((self.n,1)), np.ones((self.n,1)), -np.ones((self.n,1)), np.ones((self.n,1)) )) 37 | self.size = np.hstack(( .4-np.random.rand(self.n,1)*.25, .4+np.random.rand(self.n,1)*.25 )) 38 | self.ecin = np.hstack(( .7-np.random.rand(self.n,1)*.25, .7+np.random.rand(self.n,1)*.25 )) 39 | self.strength = np.hstack(( 10-np.random.rand(self.n,1)*4, 10+np.random.rand(self.n,1)*4 )) 40 | theta0 = np.random.rand(self.n,1)*180*np.pi/180 41 | thetad = np.random.rand(self.n,1)*30*np.pi/180 42 | self.theta = np.hstack(( theta0-thetad , theta0+thetad )) 43 | self.prob = np.random.rand(self.n)*.1+.1 44 | 45 | 46 | def perturb(self, i=0): 47 | # perturb the surface 48 | if np.random.rand()>self.prob[i]: 49 | return 50 | 51 | size = np.random.rand()**5*(self.size[i,1]-self.size[i,0])+self.size[i,0] 52 | strength = np.random.rand()*(self.strength[i,1]-self.strength[i,0])+self.strength[i,0] 53 | xc=np.random.rand()*(self.window[i,1]-self.window[i,0])+self.window[i,0] 54 | yc=np.random.rand()*(self.window[i,3]-self.window[i,2])+self.window[i,2] 55 | ecin=np.random.rand()*(self.ecin[i,1]-self.ecin[i,0])+self.ecin[i,0] 56 | 
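        # (added note) the lines below rotate the coordinate grid by the sampled angle about the
        # perturbation center and add an elongated Gaussian bump of random sign, size, eccentricity,
        # and strength to the height field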
theta=np.random.rand()*(self.theta[i,1]-self.theta[i,0])+self.theta[i,0] 57 | 58 | x = (self.x-xc)*np.cos(theta) + (self.y-yc)*np.sin(theta) 59 | y = -(self.x-xc)*np.sin(theta) + (self.y-yc)*np.cos(theta) 60 | r2=(x-xc)**2/size**2+(y-yc)**2/(1-ecin)/size**2 61 | self.u[i] += np.random.choice([-1,1])*torch.exp(-r2)*strength*25 62 | 63 | def step(self, i=0, steps=1, lock=False): 64 | # step the simulation 65 | u=self.u[i] 66 | v=self.v[i] 67 | for _ in range(steps): 68 | if not lock: 69 | self.perturb(i) 70 | f1 = (u[:-2,1:-1] + u[2:,1:-1] + u[1:-1,:-2] + u[1:-1,2:] - 4*u[1:-1,1:-1])/4 71 | f2 = (u[:-2,:-2] + u[2:,:-2] + u[:-2,2:] + u[2:,2:] - 4*u[1:-1,1:-1])/2/4 72 | v[1:-1,1:-1] = v[1:-1,1:-1] + (f1+f2) 73 | v*=self.damp[i] 74 | u+=v 75 | u[0,:]=u[1,:]*1 76 | u[:,0]=u[:,1]*1 77 | u[-1,:]=u[-2,:]*1 78 | u[:,-1]=u[:,-2]*1 79 | 80 | 81 | 82 | def __call__(self, img): 83 | # warp an image 84 | if self.ct >= self.reset_after*self.n: 85 | self.reset() 86 | i = self.ct % self.n 87 | img1 = torch.zeros_like(img) 88 | self.step(i, 10) 89 | for j in range(self.motion_blur[i]): 90 | if j>0: 91 | self.step(i, 2, lock=True) 92 | ux=(self.u[i][1:-1,2:]-self.u[i][1:-1,:-2])/2 93 | uy=(self.u[i][2:,1:-1]-self.u[i][:-2,1:-1])/2 94 | img1+=Interp.warp(img.unsqueeze(0),ux.unsqueeze(0),uy.unsqueeze(0)).squeeze(0) 95 | 96 | img1 = img1/self.motion_blur[i] 97 | self.ct+=1 98 | return [img1, img] 99 | 100 | 101 | if __name__ == '__main__': 102 | # test visualizations 103 | 104 | import matplotlib.pyplot as plt 105 | plt.ion() 106 | plt.show() 107 | 108 | x = SynthData(256, n=1) 109 | """ 110 | for i in range(100): 111 | x.step(steps=3) 112 | plt.imshow(x.u[0].numpy()) 113 | plt.title('%d %f %f %f'%(i,x.u[0].min(), x.u[0].mean(), x.u[0].max())) 114 | plt.draw() 115 | plt.pause(.1) 116 | 117 | 118 | """ 119 | from datasets import ImageFolder 120 | import torchvision.transforms as transforms 121 | 122 | img_dir='/mnt/Data1/ImageNet/val' 123 | data = ImageFolder(img_dir, transform= 124 | transforms.Compose([ 125 | transforms.Resize(256), 126 | transforms.CenterCrop(256), 127 | transforms.ToTensor(), 128 | ])) 129 | img = data[10][0] 130 | 131 | for i in range(100): 132 | img1 = x(img)[0] 133 | plt.imshow(img1.permute(1, 2, 0)) 134 | plt.draw() 135 | plt.pause(.1) 136 | 137 | 138 | -------------------------------------------------------------------------------- /pairedtransforms.py: -------------------------------------------------------------------------------- 1 | # minor modifications from main pytorch 2 | 3 | from torchvision.transforms import functional as F 4 | from PIL import Image 5 | import math 6 | import numbers 7 | import random 8 | 9 | _pil_interpolation_to_str = { 10 | Image.NEAREST: 'PIL.Image.NEAREST', 11 | Image.BILINEAR: 'PIL.Image.BILINEAR', 12 | Image.BICUBIC: 'PIL.Image.BICUBIC', 13 | Image.LANCZOS: 'PIL.Image.LANCZOS', 14 | } 15 | 16 | 17 | class ToTensor(object): 18 | """Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor. 19 | 20 | Converts a pair of PIL Images or numpy.ndarray (H x W x C) in the range 21 | [0, 255] to torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0]. 22 | """ 23 | 24 | def __call__(self, pics): 25 | """ 26 | Args: 27 | pic (PIL Image or numpy.ndarray): Image to be converted to tensor. 28 | 29 | Returns: 30 | Tensor: Converted image. 
31 | """ 32 | return [F.to_tensor(pic) for pic in pics] 33 | 34 | 35 | def __repr__(self): 36 | return self.__class__.__name__ + '()' 37 | 38 | 39 | 40 | class Normalize(object): 41 | """Normalize a tensor image with mean and standard deviation. 42 | Given mean: ``(M1,...,Mn)`` and std: ``(S1,..,Sn)`` for ``n`` channels, this transform 43 | will normalize each channel of the input ``torch.*Tensor`` i.e. 44 | ``input[channel] = (input[channel] - mean[channel]) / std[channel]`` 45 | Args: 46 | mean (sequence): Sequence of means for each channel. 47 | std (sequence): Sequence of standard deviations for each channel. 48 | """ 49 | 50 | def __init__(self, mean, std): 51 | self.mean = mean 52 | self.std = std 53 | 54 | def __call__(self, tensors): 55 | """ 56 | Args: 57 | tensor (Tensor): Tensor image of size (C, H, W) to be normalized. 58 | Returns: 59 | Tensor: Normalized Tensor image. 60 | """ 61 | return [F.normalize(tensor, self.mean, self.std) for tensor in tensors] 62 | 63 | def __repr__(self): 64 | return self.__class__.__name__ + '(mean={0}, std={1})'.format(self.mean, self.std) 65 | 66 | 67 | 68 | class Resize(object): 69 | """Resize the input PIL Image to the given size. 70 | Args: 71 | size (sequence or int): Desired output size. If size is a sequence like 72 | (h, w), output size will be matched to this. If size is an int, 73 | smaller edge of the image will be matched to this number. 74 | i.e, if height > width, then image will be rescaled to 75 | (size * height / width, size) 76 | interpolation (int, optional): Desired interpolation. Default is 77 | ``PIL.Image.BILINEAR`` 78 | """ 79 | 80 | def __init__(self, size, interpolation=Image.BILINEAR): 81 | assert isinstance(size, int) or (isinstance(size, collections.Iterable) and len(size) == 2) 82 | self.size = size 83 | self.interpolation = interpolation 84 | 85 | def __call__(self, imgs): 86 | """ 87 | Args: 88 | img (PIL Image): Image to be scaled. 89 | Returns: 90 | PIL Image: Rescaled image. 91 | """ 92 | return [F.resize(img, self.size, self.interpolation) for img in imgs] 93 | 94 | def __repr__(self): 95 | interpolate_str = _pil_interpolation_to_str[self.interpolation] 96 | return self.__class__.__name__ + '(size={0}, interpolation={1})'.format(self.size, interpolate_str) 97 | 98 | 99 | class CenterCrop(object): 100 | """Crops the given PIL Image at the center. 101 | Args: 102 | size (sequence or int): Desired output size of the crop. If size is an 103 | int instead of sequence like (h, w), a square crop (size, size) is 104 | made. 105 | """ 106 | 107 | def __init__(self, size): 108 | if isinstance(size, numbers.Number): 109 | self.size = (int(size), int(size)) 110 | else: 111 | self.size = size 112 | 113 | def __call__(self, imgs): 114 | """ 115 | Args: 116 | img (PIL Image): Image to be cropped. 117 | Returns: 118 | PIL Image: Cropped image. 119 | """ 120 | return [F.center_crop(img, self.size) for img in imgs] 121 | 122 | def __repr__(self): 123 | return self.__class__.__name__ + '(size={0})'.format(self.size) 124 | 125 | 126 | 127 | class RandomResizedCrop(object): 128 | """Crop the given PIL Image to random size and aspect ratio. 129 | A crop of random size (default: of 0.08 to 1.0) of the original size and a random 130 | aspect ratio (default: of 3/4 to 4/3) of the original aspect ratio is made. This crop 131 | is finally resized to given size. 132 | This is popularly used to train the Inception networks. 
133 | Args: 134 | size: expected output size of each edge 135 | scale: range of size of the origin size cropped 136 | ratio: range of aspect ratio of the origin aspect ratio cropped 137 | interpolation: Default: PIL.Image.BILINEAR 138 | """ 139 | 140 | def __init__(self, size, scale=(0.08, 1.0), ratio=(3. / 4., 4. / 3.), interpolation=Image.BILINEAR): 141 | self.size = (size, size) 142 | self.interpolation = interpolation 143 | self.scale = scale 144 | self.ratio = ratio 145 | 146 | @staticmethod 147 | def get_params(img, scale, ratio): 148 | """Get parameters for ``crop`` for a random sized crop. 149 | Args: 150 | img (PIL Image): Image to be cropped. 151 | scale (tuple): range of size of the origin size cropped 152 | ratio (tuple): range of aspect ratio of the origin aspect ratio cropped 153 | Returns: 154 | tuple: params (i, j, h, w) to be passed to ``crop`` for a random 155 | sized crop. 156 | """ 157 | for attempt in range(10): 158 | area = img.size[0] * img.size[1] 159 | target_area = random.uniform(*scale) * area 160 | aspect_ratio = random.uniform(*ratio) 161 | 162 | w = int(round(math.sqrt(target_area * aspect_ratio))) 163 | h = int(round(math.sqrt(target_area / aspect_ratio))) 164 | 165 | if random.random() < 0.5: 166 | w, h = h, w 167 | 168 | if w <= img.size[0] and h <= img.size[1]: 169 | i = random.randint(0, img.size[1] - h) 170 | j = random.randint(0, img.size[0] - w) 171 | return i, j, h, w 172 | 173 | # Fallback 174 | w = min(img.size[0], img.size[1]) 175 | i = (img.size[1] - w) // 2 176 | j = (img.size[0] - w) // 2 177 | return i, j, w, w 178 | 179 | def __call__(self, imgs): 180 | """ 181 | Args: 182 | img (PIL Image): Image to be cropped and resized. 183 | Returns: 184 | PIL Image: Randomly cropped and resized image. 185 | """ 186 | i, j, h, w = self.get_params(imgs[0], self.scale, self.ratio) 187 | return [F.resized_crop(img, i, j, h, w, self.size, self.interpolation) for img in imgs] 188 | 189 | def __repr__(self): 190 | interpolate_str = _pil_interpolation_to_str[self.interpolation] 191 | format_string = self.__class__.__name__ + '(size={0}'.format(self.size) 192 | format_string += ', scale={0}'.format(tuple(round(s, 4) for s in self.scale)) 193 | format_string += ', ratio={0}'.format(tuple(round(r, 4) for r in self.ratio)) 194 | format_string += ', interpolation={0})'.format(interpolate_str) 195 | return format_string 196 | 197 | 198 | class RandomHorizontalFlip(object): 199 | """Horizontally flip the given PIL Image randomly with a given probability. 200 | Args: 201 | p (float): probability of the image being flipped. Default value is 0.5 202 | """ 203 | 204 | def __init__(self, p=0.5): 205 | self.p = p 206 | 207 | def __call__(self, imgs): 208 | """ 209 | Args: 210 | img (PIL Image): Image to be flipped. 211 | Returns: 212 | PIL Image: Randomly flipped image. 
213 | """ 214 | if random.random() < self.p: 215 | return [F.hflip(img) for img in imgs] 216 | return imgs 217 | 218 | def __repr__(self): 219 | return self.__class__.__name__ + '(p={})'.format(self.p) 220 | -------------------------------------------------------------------------------- /datasets.py: -------------------------------------------------------------------------------- 1 | # minor modifications from main pytorch 2 | 3 | import torch.utils.data as data 4 | from PIL import Image 5 | import os 6 | import os.path 7 | 8 | IMG_EXTENSIONS = ['.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif'] 9 | 10 | def has_file_allowed_extension(filename, extensions): 11 | """Checks if a file is an allowed extension. 12 | 13 | Args: 14 | filename (string): path to a file 15 | 16 | Returns: 17 | bool: True if the filename ends with a known image extension 18 | """ 19 | filename_lower = filename.lower() 20 | return any(filename_lower.endswith(ext) for ext in extensions) 21 | 22 | def pil_loader(path): 23 | # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835) 24 | with open(path, 'rb') as f: 25 | img = Image.open(f) 26 | return img.convert('RGB') 27 | 28 | 29 | def accimage_loader(path): 30 | import accimage 31 | try: 32 | return accimage.Image(path) 33 | except IOError: 34 | # Potentially a decoding problem, fall back to PIL.Image 35 | return pil_loader(path) 36 | 37 | 38 | def default_loader(path): 39 | from torchvision import get_image_backend 40 | if get_image_backend() == 'accimage': 41 | return accimage_loader(path) 42 | else: 43 | return pil_loader(path) 44 | 45 | 46 | def find_classes(dir): 47 | classes = [d for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d))] 48 | classes.sort() 49 | class_to_idx = {classes[i]: i for i in range(len(classes))} 50 | return classes, class_to_idx 51 | 52 | 53 | def make_dataset(dir, class_to_idx, extensions): 54 | images = [] 55 | dir = os.path.expanduser(dir) 56 | dir_len = len(dir)+1 57 | for target in sorted(os.listdir(dir)): 58 | d = os.path.join(dir, target) 59 | if not os.path.isdir(d): 60 | continue 61 | 62 | for root, _, fnames in sorted(os.walk(d)): 63 | for fname in sorted(fnames): 64 | if has_file_allowed_extension(fname, extensions): 65 | path = os.path.join(root[dir_len:], fname) 66 | item = (path, class_to_idx[target]) 67 | images.append(item) 68 | 69 | return images 70 | 71 | 72 | class ImageFolder(data.Dataset): 73 | """A generic data loader where the samples are arranged in this way: :: 74 | 75 | root1/class_x/xxx.ext 76 | root1/class_x/xxy.ext 77 | root1/class_x/xxz.ext 78 | 79 | root1/class_y/123.ext 80 | root1/class_y/nsdf3.ext 81 | root1/class_y/asd932_.ext 82 | 83 | 84 | Args: 85 | root1 (string): Root 1 directory path. 86 | loader (callable): A function to load a sample given its path. 87 | extensions (list[string]): A list of allowed extensions. 88 | transform (callable, optional): A function/transform that takes in 89 | a sample and returns a transformed version. 90 | E.g, ``transforms.RandomCrop`` for images. 91 | target_transform (callable, optional): A function/transform that takes 92 | in the target and transforms it. 93 | return_path: also returns the name of the image 94 | 95 | Attributes: 96 | classes (list): List of the class names. 97 | class_to_idx (dict): Dict with items (class_name, class_index). 
98 | samples (list): List of (sample path, class_index) tuples 99 | """ 100 | 101 | def __init__(self, root, loader=default_loader, extensions=IMG_EXTENSIONS, transform=None, target_transform=None, return_path=False): 102 | classes, class_to_idx = find_classes(root) 103 | samples = make_dataset(root, class_to_idx, extensions) 104 | if len(samples) == 0: 105 | raise(RuntimeError("Found 0 files in subfolders of: " + root + "\n" 106 | "Supported extensions are: " + ",".join(extensions))) 107 | 108 | self.root = root 109 | self.loader = loader 110 | self.extensions = extensions 111 | 112 | self.classes = classes 113 | self.class_to_idx = class_to_idx 114 | self.samples = samples 115 | 116 | self.transform = transform 117 | self.target_transform = target_transform 118 | 119 | self.return_path = return_path 120 | 121 | def __getitem__(self, index): 122 | """ 123 | Args: 124 | index (int): Index 125 | 126 | Returns: 127 | tuple: (sample, target) where target is class_index of the target class. 128 | """ 129 | path, target = self.samples[index] 130 | sample = self.loader(os.path.join(self.root,path)) 131 | if self.transform is not None: 132 | sample = self.transform(sample) 133 | if self.target_transform is not None: 134 | target = self.target_transform(target) 135 | 136 | if self.return_path: 137 | return sample, target, path 138 | else: 139 | return sample, target 140 | 141 | 142 | def __len__(self): 143 | return len(self.samples) 144 | 145 | def __repr__(self): 146 | fmt_str = 'Dataset ' + self.__class__.__name__ + '\n' 147 | fmt_str += ' Number of datapoints: {}\n'.format(self.__len__()) 148 | fmt_str += ' Root Location: {}\n'.format(self.root) 149 | tmp = ' Transforms (if any): ' 150 | fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp))) 151 | tmp = ' Target Transforms (if any): ' 152 | fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp))) 153 | return fmt_str 154 | 155 | 156 | 157 | class PairedImageFolder(data.Dataset): 158 | """A generic data loader where the samples are arranged in this way: :: 159 | 160 | root1/class_x/xxx.ext root2/class_x/xxx.ext 161 | root1/class_x/xxy.ext root2/class_x/xxy.ext 162 | root1/class_x/xxz.ext root2/class_x/xxz.ext 163 | 164 | root1/class_y/123.ext root2/class_y/123.ext 165 | root1/class_y/nsdf3.ext root2/class_y/nsdf3.ext 166 | root1/class_y/asd932_.ext root2/class_y/asd932_.ext 167 | 168 | 169 | Args: 170 | root1 (string): Root 1 directory path. 171 | root2 (string): Root 2 directory path. 172 | loader (callable): A function to load a sample given its path. 173 | extensions (list[string]): A list of allowed extensions. 174 | transform (callable, optional): A function/transform that takes in 175 | a sample and returns a transformed version. 176 | E.g, ``transforms.RandomCrop`` for images. 177 | target_transform (callable, optional): A function/transform that takes 178 | in the target and transforms it. 179 | 180 | Attributes: 181 | classes (list): List of the class names. 182 | class_to_idx (dict): Dict with items (class_name, class_index). 
183 |         samples (list): List of (sample path, class_index) tuples
184 |     """
185 |
186 |     def __init__(self, root1, root2, loader=default_loader, extensions=IMG_EXTENSIONS, transform=None, target_transform=None):
187 |         classes, class_to_idx = find_classes(root1)
188 |         samples = make_dataset(root1, class_to_idx, extensions)
189 |         if len(samples) == 0:
190 |             raise(RuntimeError("Found 0 files in subfolders of: " + root1 + "\n"
191 |                                "Supported extensions are: " + ",".join(extensions)))
192 |
193 |         self.root1 = root1
194 |         self.root2 = root2
195 |         self.loader = loader
196 |         self.extensions = extensions
197 |
198 |         self.classes = classes
199 |         self.class_to_idx = class_to_idx
200 |         self.samples = samples
201 |
202 |         self.transform = transform
203 |         self.target_transform = target_transform
204 |
205 |
206 |     def __getitem__(self, index):
207 |         """
208 |         Args:
209 |             index (int): Index
210 |
211 |         Returns:
212 |             tuple: (sample, target) where target is class_index of the target class.
213 |         """
214 |         path, target = self.samples[index]
215 |         sample1 = self.loader(os.path.join(self.root1,path))
216 |         sample2 = self.loader(os.path.join(self.root2,path))
217 |         if self.transform is not None:
218 |             sample1, sample2 = self.transform([sample1, sample2])
219 |         if self.target_transform is not None:
220 |             target = self.target_transform(target)
221 |
222 |         return [sample1, sample2], target
223 |
224 |
225 |     def __len__(self):
226 |         return len(self.samples)
227 |
228 |     def __repr__(self):
229 |         fmt_str = 'Dataset ' + self.__class__.__name__ + '\n'
230 |         fmt_str += ' Number of datapoints: {}\n'.format(self.__len__())
231 |         fmt_str += ' Root Location: {} {}\n'.format(self.root1, self.root2)
232 |         tmp = ' Transforms (if any): '
233 |         fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
234 |         tmp = ' Target Transforms (if any): '
235 |         fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
236 |         return fmt_str
237 |
238 |
239 |
240 |
241 |
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import time
4 | import torch
5 | import torchvision
6 | from datasets import ImageFolder, PairedImageFolder
7 | import torchvision.transforms as transforms
8 | import pairedtransforms
9 | import networks
10 | from PIL import Image
11 | import synthdata
12 |
13 | def main():
14 |     parser = argparse.ArgumentParser()
15 |     parser.add_argument('--dataroot', default='/mnt/Data1', help='path to images')
16 |     parser.add_argument('--workers', default=4, type=int, help='number of data loading workers (default: 4)')
17 |     parser.add_argument('--batch-size', default=16, type=int, help='mini-batch size')
18 |     parser.add_argument('--outroot', default='./results', help='path to save the results')
19 |     parser.add_argument('--exp-name', default='test', help='name of experiment')
20 |     parser.add_argument('--load', default='', help='name of pth to load weights from')
21 |     parser.add_argument('--freeze-cc-net', dest='freeze_cc_net', action='store_true', help='dont train the color corrector net')
22 |     parser.add_argument('--freeze-warp-net', dest='freeze_warp_net', action='store_true', help='dont train the warp net')
23 |     parser.add_argument('--test', dest='test', action='store_true', help='only test the network')
24 |     parser.add_argument('--synth-data', dest='synth_data', action='store_true', help='use
synthetic data instead of tank data for training') 25 | parser.add_argument('--epochs', default=3, type=int, help='number of epochs to train for') 26 | parser.add_argument('--no-warp-net', dest='warp_net', action='store_false', help='do not include warp net in the model') 27 | parser.add_argument('--warp-net-downsample', default=3, type=int, help='number of downsampling layers in warp net') 28 | parser.add_argument('--no-color-net', dest='color_net', action='store_false', help='do not include color net in the model') 29 | parser.add_argument('--color-net-downsample', default=3, type=int, help='number of downsampling layers in color net') 30 | parser.add_argument('--no-color-net-skip', dest='color_net_skip', action='store_false', help='dont use u-net skip connections in the color net') 31 | parser.add_argument('--dim', default=32, type=int, help='initial feature dimension (doubled at each downsampling layer)') 32 | parser.add_argument('--n-res', default=8, type=int, help='number of residual blocks') 33 | parser.add_argument('--norm', default='gn', type=str, help='type of normalization layer') 34 | parser.add_argument('--denormalize', dest='denormalize', action='store_true', help='denormalize output image by input image mean/var') 35 | parser.add_argument('--weight-X-L1', default=1., type=float, help='weight of L1 reconstruction loss after color corrector net') 36 | parser.add_argument('--weight-Y-L1', default=1., type=float, help='weight of L1 reconstruction loss after warp net') 37 | parser.add_argument('--weight-Y-VGG', default=1., type=float, help='weight of perceptual loss after warp net') 38 | parser.add_argument('--weight-Z-L1', default=1., type=float, help='weight of L1 reconstruction loss after color net') 39 | parser.add_argument('--weight-Z-VGG', default=.5, type=float, help='weight of perceptual loss after color net') 40 | parser.add_argument('--weight-Z-Adv', default=0.2, type=float, help='weight of adversarial loss after color net') 41 | args = parser.parse_args() 42 | 43 | # set random seed for consistent fixed batch 44 | torch.manual_seed(8) 45 | 46 | # set weights of losses of intermediate outputs to zero if not necessary 47 | if not args.warp_net: 48 | args.weight_Y_L1=0 49 | args.weight_Y_VGG=0 50 | if not args.color_net: 51 | args.weight_Z_L1=0 52 | args.weight_Z_VGG=0 53 | args.weight_Z_Adv=0 54 | 55 | 56 | # datasets 57 | train_dir_1 = os.path.join(args.dataroot,'Water', 'train') 58 | train_dir_2 = os.path.join(args.dataroot,'ImageNet', 'train') 59 | val_dir_1 = os.path.join(args.dataroot,'Water', 'test') 60 | val_dir_2 = os.path.join(args.dataroot,'ImageNet', 'test') 61 | test_dir = os.path.join(args.dataroot,'Water_Real') 62 | 63 | if args.synth_data: 64 | train_data = ImageFolder(train_dir_2, transform= 65 | transforms.Compose([ 66 | transforms.RandomResizedCrop(224), 67 | transforms.RandomHorizontalFlip(), 68 | transforms.ToTensor(), 69 | transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]), 70 | synthdata.SynthData(224, n=args.batch_size), 71 | ])) 72 | else: 73 | train_data = PairedImageFolder(train_dir_1, train_dir_2, transform= 74 | transforms.Compose([ 75 | pairedtransforms.RandomResizedCrop(224), 76 | pairedtransforms.RandomHorizontalFlip(), 77 | pairedtransforms.ToTensor(), 78 | pairedtransforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]), 79 | ])) 80 | val_data = PairedImageFolder(val_dir_1, val_dir_2, transform= 81 | transforms.Compose([ 82 | pairedtransforms.Resize(256), 83 | pairedtransforms.CenterCrop(256), 84 | 
pairedtransforms.ToTensor(),
85 |             pairedtransforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
86 |         ]))
87 |     test_data = ImageFolder(test_dir, transform=
88 |         transforms.Compose([
89 |             transforms.Resize(256),
90 |             transforms.CenterCrop(256),
91 |             transforms.ToTensor(),
92 |             transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
93 |         ]), return_path=True)
94 |
95 |     train_loader = torch.utils.data.DataLoader(train_data, batch_size=args.batch_size, num_workers=args.workers, pin_memory=True, shuffle=True)
96 |     val_loader = torch.utils.data.DataLoader(val_data, batch_size=args.batch_size, num_workers=args.workers, pin_memory=True, shuffle=True)
97 |     test_loader = torch.utils.data.DataLoader(test_data, batch_size=args.batch_size, num_workers=args.workers, pin_memory=True, shuffle=False)
98 |
99 |     # fixed test batch for visualization during training
100 |     fixed_batch = next(iter(val_loader))[0]
101 |
102 |     # model
103 |     model=networks.Model(args)
104 |     model.cuda()
105 |
106 |     # load weights from checkpoint
107 |     if args.test and not args.load:
108 |         args.load = args.exp_name
109 |     if args.load:
110 |         model.load_state_dict(torch.load(os.path.join(args.outroot, '%s_net.pth'%args.load)), strict=args.test)
111 |
112 |     # create outroot if necessary
113 |     if not os.path.exists(args.outroot):
114 |         os.makedirs(args.outroot)
115 |
116 |     # if args.test only run test script
117 |     if args.test:
118 |         test(test_loader, model, args)
119 |         return
120 |
121 |     # main training loop
122 |     for epoch in range(args.epochs):
123 |         train(train_loader, model, fixed_batch, epoch, args)
124 |         torch.save(model.state_dict(), os.path.join(args.outroot, '%s_net.pth'%args.exp_name))
125 |     test(test_loader, model, args)
126 |
127 |
128 | def train(loader, model, fixed_batch, epoch, args):
129 |     model.train()
130 |
131 |     end_time = time.time()
132 |     for i, ((input, target), _) in enumerate(loader):
133 |         input = input.cuda()
134 |         target = target.cuda()
135 |         data_time = time.time() - end_time
136 |
137 |         # take an optimization step
138 |         losses = model.optimize_parameters(input, target)
139 |         batch_time = time.time() - end_time
140 |
141 |         # display progress
142 |         print('%s Epoch: %02d/%02d %04d/%04d time: %.3f %.3f '%(args.exp_name, epoch, args.epochs,
143 |             i, len(loader), data_time, batch_time) + model.print_losses(losses))
144 |
145 |         # visualize progress
146 |         if i%100==0:
147 |             visualize(input, target, model, os.path.join(args.outroot, '%s_train.jpg'%args.exp_name) )
148 |             visualize(fixed_batch[0].cuda(), fixed_batch[1], model, os.path.join(args.outroot, '%s_val.jpg'%args.exp_name) )
149 |             model.train()
150 |
151 |         end_time = time.time()
152 |
153 |         #if i==10:
154 |         #    break
155 |
156 | def visualize(input, target, model, name):
157 |     model.eval()
158 |     with torch.no_grad():
159 |         x, warp, y, z = model(input)
160 |         warp = (warp+5)/10
161 |         warp = torch.cat((warp, torch.ones_like(y[:,:1,:,:])), dim=1)
162 |         visuals = torch.cat([input.cpu(), x.cpu(), warp.cpu(), y.cpu(), z.cpu(), target.cpu()], dim=2)
163 |         torchvision.utils.save_image(visuals, name, nrow=16, normalize=True, range=(-1,1), pad_value=1)
164 |
165 |
166 | def test(loader, model, args):
167 |     model.eval()
168 |     with torch.no_grad():
169 |         end_time = time.time()
170 |         for i, data in enumerate(loader):
171 |             input = data[0].cuda()
172 |             data_time = time.time() - end_time
173 |             x, warp, y, z = model(input, cc=False)
174 |
175 |             # save the output for each image by name
176 |             for out, name in zip(z,data[-1]):
177 |                 if not
os.path.exists(os.path.join(args.outroot, '%s_test'%args.exp_name, os.path.dirname(name))):
178 |                     os.makedirs(os.path.join(args.outroot, '%s_test'%args.exp_name, os.path.dirname(name)))
179 |
180 |                 im = Image.fromarray( (out*.5+.5).mul(255).clamp(0, 255).byte().permute(1, 2, 0).cpu().numpy() )
181 |                 im.save(os.path.join(args.outroot, '%s_test'%args.exp_name, name))
182 |
183 |
184 |             batch_time = time.time() - end_time
185 |             print('%s Test: %04d/%04d time: %.3f %.3f '%(args.exp_name, i, len(loader), data_time, batch_time))
186 |             end_time = time.time()
187 |
188 |             #if i==10:
189 |             #    break
190 |
191 |
192 |
193 | if __name__ == '__main__':
194 |     main()
195 |
--------------------------------------------------------------------------------
/networks.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | from torch import nn
4 | import itertools
5 | from collections import OrderedDict
6 | import Interp
7 | from torchvision.models import vgg16
8 | import math
9 |
10 | class Model(nn.Module):
11 |     def __init__(self, args):
12 |         super(Model, self).__init__()
13 |
14 |         # weights of losses
15 |         self.weights = {'X_L1':args.weight_X_L1, 'Y_L1':args.weight_Y_L1, 'Z_L1':args.weight_Z_L1, 'Y_VGG':args.weight_Y_VGG, 'Z_VGG':args.weight_Z_VGG, 'Z_Adv':args.weight_Z_Adv}
16 |
17 |         # networks
18 |         params_G = []
19 |
20 |         # color corrector network (used to correct color shift between monitor images and real images in the training set)
21 |         if args.weight_X_L1>0:
22 |             self.cc_net = CCNet()
23 |             if args.freeze_cc_net:
24 |                 for param in self.cc_net.parameters():
25 |                     param.requires_grad=False
26 |             else:
27 |                 params_G = itertools.chain(params_G, self.cc_net.parameters())
28 |         else:
29 |             self.cc_net = None
30 |
31 |         # warp net
32 |         if args.warp_net:
33 |             self.warp_net = I2INet(3, 2, args.warp_net_downsample, False, args.dim, args.n_res, args.norm, 'relu', 'reflect', final_activ='none')
34 |             if args.freeze_warp_net:
35 |                 for param in self.warp_net.parameters():
36 |                     param.requires_grad=False
37 |             else:
38 |                 params_G = itertools.chain(params_G, self.warp_net.parameters())
39 |         else:
40 |             self.warp_net = None
41 |
42 |         # color net
43 |         if args.color_net:
44 |             self.color_net = I2INet(3, 3, args.color_net_downsample, args.color_net_skip, args.dim, args.n_res, args.norm, 'relu', 'reflect', args.denormalize)
45 |             params_G = itertools.chain(params_G, self.color_net.parameters())
46 |         else:
47 |             self.color_net = None
48 |
49 |         # optimizer
50 |         self.optimizer_G = torch.optim.Adam(params_G, lr=2e-4, betas=(.5, 0.999))
51 |
52 |         # for reconstruction loss
53 |         self.recon_criterion = nn.L1Loss()
54 |
55 |         # discriminator for adversarial loss
56 |         if self.weights['Z_Adv']>0:
57 |             self.discriminator = Discriminator(dim=args.dim)
58 |             self.optimizer_D = torch.optim.Adam(self.discriminator.parameters(), lr=2e-4, betas=(.5, 0.999))
59 |
60 |         # vgg for perceptual loss
61 |         if self.weights['Y_VGG']>0 or self.weights['Z_VGG']>0:
62 |             self.vgg = vgg_features()
63 |
64 |
65 |     def forward(self, w, cc=True):
66 |         # color corrector network
67 |         if self.cc_net is not None and cc:
68 |             x = self.cc_net(w)
69 |         else:
70 |             x = w
71 |
72 |         # warp net
73 |         if self.warp_net is not None:
74 |             warp = self.warp_net(x)*10
75 |             y = Interp.warp(x,warp[:,0,:,:],warp[:,1,:,:])
76 |         else:
77 |             warp = torch.zeros_like(x[:,:2,:,:])
78 |             y = x
79 |
80 |         # color net
81 |         if self.color_net is not None:
82 |             z = self.color_net(y)
83 |         else:
84 |             z=y
85 |
86 |         return x, warp, y, z
87 |
88 |
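    # Usage sketch (assumes a constructed Model `model` and a [-1,1]-normalized image batch `input` on the GPU):
    #   x, warp, y, z = model(input)
    #   x    - color-corrected input (identity when the CC net is absent or cc=False)
    #   warp - predicted per-pixel (di, dj) offsets from the warp net, scaled by 10
    #   y    - x resampled by Interp.warp to undo the water distortion
    #   z    - final color-refined output in [-1, 1]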
89 | def compute_loss_G(self, x, y, z, target): 90 | losses = OrderedDict() 91 | 92 | # Reconstruction loss 93 | if self.weights['X_L1']>0: 94 | losses['X_L1'] = self.recon_criterion(x, target) 95 | if self.weights['Y_L1']>0: 96 | losses['Y_L1'] = self.recon_criterion(y, target) 97 | if self.weights['Z_L1']>0: 98 | losses['Z_L1'] = self.recon_criterion(z, target) 99 | 100 | # Perceptual loss 101 | if self.weights['Y_VGG']>0: 102 | losses['Y_VGG'] = self.recon_criterion(self.vgg(y), self.vgg(target)) 103 | if self.weights['Z_VGG']>0: 104 | losses['Z_VGG'] = self.recon_criterion(self.vgg(z), self.vgg(target)) 105 | 106 | # Adversarial loss 107 | if self.weights['Z_Adv']>0: 108 | losses['Z_Adv'] = self.discriminator.calc_gen_loss(z) 109 | 110 | return losses 111 | 112 | 113 | def optimize_parameters(self, input, target): 114 | x, warp, y, z = self.forward(input) 115 | # update discriminator 116 | if self.weights['Z_Adv']>0: 117 | self.optimizer_D.zero_grad() 118 | loss_d = self.discriminator.calc_dis_loss(z, target) 119 | loss_d.backward() 120 | self.optimizer_D.step() 121 | 122 | # update generators 123 | self.optimizer_G.zero_grad() 124 | losses = self.compute_loss_G(x, y, z, target) 125 | loss = sum([losses[key]*self.weights[key] for key in losses.keys()]) 126 | loss.backward() 127 | self.optimizer_G.step() 128 | 129 | if self.weights['Z_Adv']>0: 130 | losses['Dis'] = loss_d 131 | return losses 132 | 133 | 134 | def print_losses(self, losses): 135 | return ' '.join(['Loss %s: %.4f'%(key, val.item()) for key,val in losses.items()]) 136 | 137 | 138 | 139 | class I2INet(nn.Module): 140 | def __init__(self, input_dim=3, output_dim=3, n_downsample=3, skip=True, dim=32, n_res=8, norm='in', activ='relu', pad_type='reflect', denormalize=False, final_activ='tanh'): 141 | super(I2INet, self).__init__() 142 | 143 | self.skip=skip 144 | self.denormalize = denormalize 145 | 146 | # project to feature space 147 | self.conv_in = Conv2dBlock(input_dim, dim, 7, 1, 3, norm=norm, activation=activ, pad_type=pad_type) 148 | 149 | # downsampling blocks 150 | self.down_blocks = nn.ModuleList() 151 | for i in range(n_downsample): 152 | self.down_blocks.append( Conv2dBlock(dim, 2 * dim, 4, 2, 1, norm=norm, activation=activ, pad_type=pad_type) ) 153 | dim *= 2 154 | 155 | # residual blocks 156 | self.res_blocks = nn.ModuleList() 157 | for i in range(n_res): 158 | self.res_blocks.append( ResBlock(dim, norm=norm, activation=activ, pad_type=pad_type) ) 159 | 160 | # upsampling blocks 161 | self.up_blocks = nn.ModuleList() 162 | for i in range(n_downsample): 163 | self.up_blocks.append( nn.Sequential(nn.Upsample(scale_factor=2), 164 | Conv2dBlock(dim, dim // 2, 5, 1, 2, norm=norm, activation=activ, pad_type=pad_type)) ) 165 | dim //= 2 166 | 167 | # project to image space 168 | self.conv_out = Conv2dBlock(dim, output_dim, 7, 1, 3, norm='none', activation=final_activ, pad_type=pad_type) 169 | 170 | #self.apply(weights_init('kaiming')) 171 | #self.apply(weights_init('gaussian')) 172 | 173 | 174 | def forward(self, x): 175 | # normalize image and save mean/var if using denormalization 176 | if self.denormalize: 177 | x_mean = x.view(x.size(0), x.size(1), -1).mean(2).view(x.size(0), x.size(1), 1, 1) 178 | x_var = x.view(x.size(0), x.size(1), -1).var(2).view(x.size(0), x.size(1), 1, 1) 179 | x = (x-x_mean)/x_var 180 | 181 | # project to feature space 182 | x = self.conv_in(x) 183 | 184 | # downsampling blocks 185 | xs = [] 186 | for block in self.down_blocks: 187 | xs += [x] 188 | x = block(x) 189 | 190 | # residual blocks 
191 | for block in self.res_blocks: 192 | x = block(x) 193 | 194 | # upsampling blocks 195 | for block, skip in zip(self.up_blocks, reversed(xs)): 196 | x = block(x) 197 | if self.skip: 198 | x = x + skip 199 | 200 | # project to image space 201 | x = self.conv_out(x) 202 | 203 | # denormalize if necessary 204 | if self.denormalize: 205 | x = x*x_var+x_mean 206 | return x 207 | 208 | 209 | class CCNet(nn.Module): 210 | def __init__(self, input_dim=3, output_dim=3, layers=5, dim=32, norm='gn', activ='relu', pad_type='reflect', final_activ='tanh'): 211 | super(CCNet, self).__init__() 212 | self.model = [] 213 | #self.model += [Conv2dBlock(input_dim, dim, 3, 1, 1, norm=norm, activation=activ, pad_type=pad_type)] 214 | self.model += [Conv2dBlock(input_dim, dim, 1, 1, 0, norm=norm, activation=activ, pad_type=pad_type)] 215 | for i in range(layers-2): 216 | self.model += [Conv2dBlock(dim, dim, 1, 1, 0, norm=norm, activation=activ, pad_type=pad_type)] 217 | self.model += [Conv2dBlock(dim, output_dim, 1, 1, 0, norm='none', activation=final_activ, pad_type=pad_type)] 218 | self.model = nn.Sequential(*self.model) 219 | def forward(self, x): 220 | return self.model(x) 221 | 222 | 223 | class vgg_features(nn.Module): 224 | def __init__(self): 225 | super(vgg_features, self).__init__() 226 | # get vgg16 features up to conv 4_3 227 | self.model = nn.Sequential(*list(vgg16(pretrained=True).features)[:23]) 228 | # will not need to compute gradients 229 | for param in self.parameters(): 230 | param.requires_grad=False 231 | 232 | def forward(self, x, renormalize=True): 233 | # change normaliztion form [-1,1] to VGG normalization 234 | if renormalize: 235 | x = ((x*.5+.5)-torch.cuda.FloatTensor([0.485, 0.456, 0.406]).view(1,3,1,1))/torch.cuda.FloatTensor([0.229, 0.224, 0.225]).view(1,3,1,1) 236 | return self.model(x) 237 | 238 | 239 | class Discriminator(nn.Module): 240 | def __init__(self, gan_type='lsgan', input_dim=3, dim=64, n_layers=4, norm='bn', activ='lrelu', pad_type='reflect'): 241 | super(Discriminator, self).__init__() 242 | self.gan_type = gan_type 243 | self.model = [] 244 | self.model += [Conv2dBlock(input_dim, dim, 4, 2, 1, norm='none', activation=activ, pad_type=pad_type)] 245 | for i in range(n_layers - 1): 246 | self.model += [Conv2dBlock(dim, dim * 2, 4, 2, 1, norm=norm, activation=activ, pad_type=pad_type)] 247 | dim *= 2 248 | self.model += [nn.Conv2d(dim, 1, 1, 1, 0)] 249 | self.model = nn.Sequential(*self.model) 250 | #self.apply(weights_init('gaussian')) 251 | 252 | def forward(self, input): 253 | return self.model(input).mean(3).mean(2).squeeze() 254 | 255 | def calc_dis_loss(self, input_fake, input_real): 256 | input_fake = input_fake.detach() 257 | input_real = input_real.detach() 258 | out0 = self.forward(input_fake) 259 | out1 = self.forward(input_real) 260 | if self.gan_type == 'lsgan': 261 | loss = torch.mean((out0 - 0)**2) + torch.mean((out1 - 1)**2) 262 | elif self.gan_type == 'nsgan': 263 | all0 = torch.zeros_like(out0, requires_grad=False).cuda() 264 | all1 = torch.ones_like(out1, requires_grad=False).cuda() 265 | loss = torch.mean(F.binary_cross_entropy(F.sigmoid(out0), all0) + 266 | F.binary_cross_entropy(F.sigmoid(out1), all1)) 267 | elif self.gan_type == 'wgan': 268 | loss = out0.mean()-out1.mean() 269 | # grad penalty 270 | BatchSize = input_fake.size(0) 271 | alpha = torch.rand(BatchSize,1,1,1, requires_grad=False).cuda() 272 | interpolates = (alpha * input_real) + (( 1 - alpha ) * input_fake) 273 | interpolates.requires_grad=True 274 | outi = 
self.forward(interpolates) 275 | all1 = torch.ones_like(out1, requires_grad=False).cuda() 276 | gradients = torch.autograd.grad(outi, interpolates, grad_outputs=all1, create_graph=True)[0] 277 | #gradient_penalty = ((gradients.view(BatchSize,-1).norm(2, dim=1) - 1) ** 2).mean() 278 | gradient_penalty = ((gradients.view(BatchSize,-1).norm(1, dim=1) - 1).clamp(0) ** 2).mean() 279 | loss += 10*gradient_penalty 280 | else: 281 | assert 0, "Unsupported GAN type: {}".format(self.gan_type) 282 | return loss 283 | 284 | def calc_gen_loss(self, input_fake): 285 | out0 = self.forward(input_fake) 286 | if self.gan_type == 'lsgan': 287 | loss = torch.mean((out0 - 1)**2) 288 | elif self.gan_type == 'nsgan': 289 | all1 = torch.ones_like(out0.data, requires_grad=False).cuda() 290 | loss = torch.mean(F.binary_cross_entropy(F.sigmoid(out0), all1)) 291 | elif self.gan_type == 'wgan': 292 | loss = -out0.mean() 293 | else: 294 | assert 0, "Unsupported GAN type: {}".format(self.gan_type) 295 | return loss 296 | 297 | 298 | 299 | ################################################################################## 300 | # Basic Blocks 301 | ################################################################################## 302 | class ResBlock(nn.Module): 303 | def __init__(self, dim, norm='in', activation='relu', pad_type='zero'): 304 | super(ResBlock, self).__init__() 305 | 306 | model = [] 307 | model += [Conv2dBlock(dim ,dim, 3, 1, 1, norm=norm, activation=activation, pad_type=pad_type)] 308 | model += [Conv2dBlock(dim ,dim, 3, 1, 1, norm=norm, activation='none', pad_type=pad_type)] 309 | self.model = nn.Sequential(*model) 310 | 311 | def forward(self, x): 312 | residual = x 313 | out = self.model(x) 314 | out += residual 315 | return out 316 | 317 | class Conv2dBlock(nn.Module): 318 | def __init__(self, input_dim ,output_dim, kernel_size, stride, 319 | padding=0, norm='none', activation='relu', pad_type='zero', transposed=False): 320 | super(Conv2dBlock, self).__init__() 321 | self.use_bias = True 322 | # initialize padding 323 | if pad_type == 'reflect': 324 | self.pad = nn.ReflectionPad2d(padding) 325 | elif pad_type == 'zero': 326 | self.pad = nn.ZeroPad2d(padding) 327 | else: 328 | assert 0, "Unsupported padding type: {}".format(pad_type) 329 | 330 | # initialize normalization 331 | norm_dim = output_dim 332 | if norm == 'bn': 333 | self.norm = nn.BatchNorm2d(norm_dim) 334 | elif norm == 'in': 335 | self.norm = nn.InstanceNorm2d(norm_dim) 336 | elif norm == 'gn': 337 | self.norm = nn.GroupNorm(norm_dim/8, norm_dim) 338 | elif norm == 'none': 339 | self.norm = None 340 | else: 341 | assert 0, "Unsupported normalization: {}".format(norm) 342 | 343 | # initialize activation 344 | if activation == 'relu': 345 | self.activation = nn.ReLU(inplace=True) 346 | elif activation == 'lrelu': 347 | self.activation = nn.LeakyReLU(0.2, inplace=True) 348 | elif activation == 'prelu': 349 | self.activation = nn.PReLU() 350 | elif activation == 'selu': 351 | self.activation = nn.SELU(inplace=True) 352 | elif activation == 'elu': 353 | self.activation = nn.ELU(inplace=True) 354 | elif activation == 'tanh': 355 | self.activation = nn.Tanh() 356 | elif activation == 'none': 357 | self.activation = None 358 | else: 359 | assert 0, "Unsupported activation: {}".format(activation) 360 | 361 | # initialize convolution 362 | if transposed: 363 | self.conv = nn.ConvTranspose2d(input_dim, output_dim, kernel_size, stride, bias=self.use_bias) 364 | else: 365 | self.conv = nn.Conv2d(input_dim, output_dim, kernel_size, stride, 
bias=self.use_bias) 366 | 367 | def forward(self, x): 368 | x = self.conv(self.pad(x)) 369 | if self.norm: 370 | x = self.norm(x) 371 | if self.activation: 372 | x = self.activation(x) 373 | return x 374 | 375 | 376 | 377 | 378 | ################################################################################## 379 | # weight initialization 380 | ################################################################################## 381 | 382 | def weights_init(init_type='gaussian'): 383 | def init_fun(m): 384 | classname = m.__class__.__name__ 385 | if (classname.find('Conv') == 0 or classname.find('Linear') == 0) and hasattr(m, 'weight'): 386 | # print m.__class__.__name__ 387 | if init_type == 'gaussian': 388 | nn.init.normal(m.weight.data, 0.0, 0.02) 389 | elif init_type == 'xavier': 390 | nn.init.xavier_normal(m.weight.data, gain=math.sqrt(2)) 391 | elif init_type == 'kaiming': 392 | nn.init.kaiming_normal(m.weight.data, a=0, mode='fan_in') 393 | elif init_type == 'orthogonal': 394 | nn.init.orthogonal(m.weight.data, gain=math.sqrt(2)) 395 | elif init_type == 'default': 396 | pass 397 | else: 398 | assert 0, "Unsupported initialization: {}".format(init_type) 399 | if hasattr(m, 'bias') and m.bias is not None: 400 | nn.init.constant(m.bias.data, 0.0) 401 | return init_fun 402 | 403 | 404 | 405 | --------------------------------------------------------------------------------