├── image.jpg
├── preview.jpg
├── angulargan.jpg
├── angulargan
│   ├── scripts
│   │   ├── install_deps.sh
│   │   └── conda_deps.sh
│   ├── requirements.txt
│   ├── models
│   │   ├── angular_loss
│   │   │   ├── __pycache__
│   │   │   │   └── __init__.cpython-36.pyc
│   │   │   └── __init__.py
│   │   ├── __init__.py
│   │   ├── test_model.py
│   │   ├── angular_gan_model.py
│   │   ├── angular_gan_v2_model.py
│   │   ├── base_model.py
│   │   └── networks.py
│   ├── runtest.sh
│   ├── data
│   │   ├── base_data_loader.py
│   │   ├── single_dataset.py
│   │   ├── image_folder.py
│   │   ├── unaligned_dataset.py
│   │   ├── __init__.py
│   │   ├── aligned_dataset.py
│   │   └── base_dataset.py
│   ├── run.sh
│   ├── options
│   │   ├── test_options.py
│   │   ├── train_options.py
│   │   └── base_options.py
│   ├── util
│   │   ├── image_pool.py
│   │   ├── util.py
│   │   ├── html.py
│   │   ├── get_data.py
│   │   └── visualizer.py
│   ├── test.py
│   ├── docs
│   │   ├── qa.md
│   │   ├── tips.md
│   │   └── datasets.md
│   ├── datasets
│   │   ├── combine_A_and_B.py
│   │   └── make_dataset_aligned.py
│   └── train.py
├── real_illum_11346_Normalized.mat
├── LICENSE
├── generate_tinted_images.m
└── README.md
/image.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/acecreamu/angularGAN/HEAD/image.jpg -------------------------------------------------------------------------------- /preview.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/acecreamu/angularGAN/HEAD/preview.jpg -------------------------------------------------------------------------------- /angulargan.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/acecreamu/angularGAN/HEAD/angulargan.jpg -------------------------------------------------------------------------------- /angulargan/scripts/install_deps.sh: -------------------------------------------------------------------------------- 1 | set -ex 2 | pip install visdom 3 | pip install dominate 4 | -------------------------------------------------------------------------------- /angulargan/requirements.txt: -------------------------------------------------------------------------------- 1 | torch>=0.4.0 2 | torchvision>=0.2.1 3 | dominate>=2.3.1 4 | visdom>=0.1.8.3 5 | -------------------------------------------------------------------------------- /real_illum_11346_Normalized.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/acecreamu/angularGAN/HEAD/real_illum_11346_Normalized.mat -------------------------------------------------------------------------------- /angulargan/models/angular_loss/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/acecreamu/angularGAN/HEAD/angulargan/models/angular_loss/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /angulargan/runtest.sh: -------------------------------------------------------------------------------- 1 | a=$(ls datasets/facades/test | wc -l) 2 | echo $a 3 | python test.py --dataroot ./datasets/facades --name angular_gan --model angular_gan --which_direction BtoA --display_id -1 --how_many $a 4 | -------------------------------------------------------------------------------- /angulargan/data/base_data_loader.py: -------------------------------------------------------------------------------- 1 | class BaseDataLoader(): 2 | def __init__(self): 3 | pass 4 | 5 | def initialize(self, opt): 6 | self.opt = opt 7 | pass 8 | 9 |
def load_data(self): 10 | return None 11 | -------------------------------------------------------------------------------- /angulargan/scripts/conda_deps.sh: -------------------------------------------------------------------------------- 1 | set -ex 2 | conda install numpy pyyaml mkl mkl-include setuptools cmake cffi typing 3 | conda install pytorch torchvision -c pytorch # add cuda90 if CUDA 9 4 | conda install visdom dominate -c conda-forge # install visdom and dominate 5 | -------------------------------------------------------------------------------- /angulargan/run.sh: -------------------------------------------------------------------------------- 1 | python train.py --checkpoints_dir ./checkpoints --dataroot ./datasets/facades --name angular_gan --model angular_gan --lambda_Angular 1 --lambda_L1 1 --which_model_netG unet_256 --which_direction BtoA --dataset_mode aligned --no_lsgan --norm batch --pool_size 0 --save_epoch_freq 10 2 | -------------------------------------------------------------------------------- /angulargan/models/angular_loss/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from math import pi 3 | 4 | class angular_loss(torch.nn.Module): 5 | 6 | def __init__(self): 7 | super(angular_loss,self).__init__() 8 | 9 | def forward(self, illum_gt, illum_pred): 10 | # img_gt = img_input / illum_gt 11 | # illum_gt = img_input / img_gt 12 | # illum_pred = img_input / img_output 13 | 14 | # ACOS: mean angular error in degrees between the illumination estimates 15 | cos_between = torch.nn.CosineSimilarity(dim=1) 16 | cos = cos_between(illum_gt, illum_pred) 17 | cos = torch.clamp(cos, -0.99999, 0.99999) # keep acos away from the +-1 singularities 18 | loss = torch.mean(torch.acos(cos)) * 180 / pi 19 | 20 | # MSE 21 | # loss = torch.mean((illum_gt - illum_pred)**2) 22 | 23 | # 1 - COS 24 | # loss = 1 - torch.mean(cos) 25 | 26 | # 1 - COS^2 27 | # loss = 1 - torch.mean(cos**2) 28 | return loss 29 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE.
22 | -------------------------------------------------------------------------------- /angulargan/options/test_options.py: -------------------------------------------------------------------------------- 1 | from .base_options import BaseOptions 2 | 3 | 4 | class TestOptions(BaseOptions): 5 | def initialize(self, parser): 6 | parser = BaseOptions.initialize(self, parser) 7 | parser.add_argument('--ntest', type=int, default=float("inf"), help='# of test examples.') 8 | parser.add_argument('--results_dir', type=str, default='./results/', help='saves results here.') 9 | parser.add_argument('--aspect_ratio', type=float, default=1.0, help='aspect ratio of result images') 10 | parser.add_argument('--phase', type=str, default='test', help='train, val, test, etc') 11 | parser.add_argument('--which_epoch', type=str, default='latest', help='which epoch to load? set to latest to use latest cached model') 12 | parser.add_argument('--how_many', type=int, default=50, help='how many test images to run') 13 | 14 | parser.set_defaults(model='test') 15 | # To avoid cropping, the loadSize should be the same as fineSize 16 | parser.set_defaults(loadSize=parser.get_default('fineSize')) 17 | 18 | self.isTrain = False 19 | return parser 20 | -------------------------------------------------------------------------------- /angulargan/util/image_pool.py: -------------------------------------------------------------------------------- 1 | import random 2 | import torch 3 | 4 | 5 | class ImagePool(): 6 | def __init__(self, pool_size): 7 | self.pool_size = pool_size 8 | if self.pool_size > 0: 9 | self.num_imgs = 0 10 | self.images = [] 11 | 12 | def query(self, images): 13 | if self.pool_size == 0: 14 | return images 15 | return_images = [] 16 | for image in images: 17 | image = torch.unsqueeze(image.data, 0) 18 | if self.num_imgs < self.pool_size: 19 | self.num_imgs = self.num_imgs + 1 20 | self.images.append(image) 21 | return_images.append(image) 22 | else: 23 | p = random.uniform(0, 1) 24 | if p > 0.5: 25 | random_id = random.randint(0, self.pool_size - 1) # randint is inclusive 26 | tmp = self.images[random_id].clone() 27 | self.images[random_id] = image 28 | return_images.append(tmp) 29 | else: 30 | return_images.append(image) 31 | return_images = torch.cat(return_images, 0) 32 | return return_images 33 | -------------------------------------------------------------------------------- /angulargan/data/single_dataset.py: -------------------------------------------------------------------------------- 1 | import os.path 2 | from data.base_dataset import BaseDataset, get_transform 3 | from data.image_folder import make_dataset 4 | from PIL import Image 5 | 6 | 7 | class SingleDataset(BaseDataset): 8 | @staticmethod 9 | def modify_commandline_options(parser, is_train): 10 | return parser 11 | 12 | def initialize(self, opt): 13 | self.opt = opt 14 | self.root = opt.dataroot 15 | self.dir_A = os.path.join(opt.dataroot) 16 | 17 | self.A_paths = make_dataset(self.dir_A) 18 | 19 | self.A_paths = sorted(self.A_paths) 20 | 21 | self.transform = get_transform(opt) 22 | 23 | def __getitem__(self, index): 24 | A_path = self.A_paths[index] 25 | A_img = Image.open(A_path).convert('RGB') 26 | A = self.transform(A_img) 27 | if self.opt.which_direction == 'BtoA': 28 | input_nc = self.opt.output_nc 29 | else: 30 | input_nc = self.opt.input_nc 31 | 32 | if input_nc == 1: # RGB to gray 33 | tmp = A[0, ...] * 0.299 + A[1, ...] * 0.587 + A[2, ...] 
* 0.114 34 | A = tmp.unsqueeze(0) 35 | 36 | return {'A': A, 'A_paths': A_path} 37 | 38 | def __len__(self): 39 | return len(self.A_paths) 40 | 41 | def name(self): 42 | return 'SingleImageDataset' 43 | -------------------------------------------------------------------------------- /angulargan/models/__init__.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | from models.base_model import BaseModel 3 | 4 | 5 | def find_model_using_name(model_name): 6 | # Given the option --model [modelname], 7 | # the file "models/modelname_model.py" 8 | # will be imported. 9 | model_filename = "models." + model_name + "_model" 10 | modellib = importlib.import_module(model_filename) 11 | 12 | # In the file, the class called ModelNameModel() will 13 | # be instantiated. It has to be a subclass of BaseModel, 14 | # and it is case-insensitive. 15 | model = None 16 | target_model_name = model_name.replace('_', '') + 'model' 17 | for name, cls in modellib.__dict__.items(): 18 | if name.lower() == target_model_name.lower() \ 19 | and issubclass(cls, BaseModel): 20 | model = cls 21 | 22 | if model is None: 23 | print("In %s.py, there should be a subclass of BaseModel with class name that matches %s in lowercase." % (model_filename, target_model_name)) 24 | exit(0) 25 | 26 | return model 27 | 28 | 29 | def get_option_setter(model_name): 30 | model_class = find_model_using_name(model_name) 31 | return model_class.modify_commandline_options 32 | 33 | 34 | def create_model(opt): 35 | model = find_model_using_name(opt.model) 36 | instance = model() 37 | instance.initialize(opt) 38 | print("model [%s] was created" % (instance.name())) 39 | return instance 40 | -------------------------------------------------------------------------------- /angulargan/test.py: -------------------------------------------------------------------------------- 1 | import os 2 | from options.test_options import TestOptions 3 | from data import CreateDataLoader 4 | from models import create_model 5 | from util.visualizer import save_images 6 | from util import html 7 | 8 | 9 | if __name__ == '__main__': 10 | opt = TestOptions().parse() 11 | opt.nThreads = 1 # test code only supports nThreads = 1 12 | opt.batchSize = 1 # test code only supports batchSize = 1 13 | opt.serial_batches = True # no shuffle 14 | opt.no_flip = True # no flip 15 | opt.display_id = -1 # no visdom display 16 | data_loader = CreateDataLoader(opt) 17 | dataset = data_loader.load_data() 18 | model = create_model(opt) 19 | model.setup(opt) 20 | # create website 21 | web_dir = os.path.join(opt.results_dir, opt.name, '%s_%s' % (opt.phase, opt.which_epoch)) 22 | webpage = html.HTML(web_dir, 'Experiment = %s, Phase = %s, Epoch = %s' % (opt.name, opt.phase, opt.which_epoch)) 23 | # test 24 | for i, data in enumerate(dataset): 25 | if i >= opt.how_many: 26 | break 27 | model.set_input(data) 28 | model.test() 29 | visuals = model.get_current_visuals() 30 | img_path = model.get_image_paths() 31 | if i % 5 == 0: 32 | print('processing (%04d)-th image... 
%s' % (i, img_path)) 33 | save_images(webpage, visuals, img_path, aspect_ratio=opt.aspect_ratio, width=opt.display_winsize) 34 | 35 | webpage.save() 36 | -------------------------------------------------------------------------------- /angulargan/docs/qa.md: -------------------------------------------------------------------------------- 1 | ## Frequently Asked Questions 2 | - The color gets inverted from the beginning of training: the authors also observed that the generator unnecessarily inverts the color of the input image early in the training, and then never learns to undo the inversion. In this case, two things can be tried. First, try using identity loss `--identity 1.0` or `--identity 0.1`. We observed that the identity loss makes the generator more conservative, so it makes fewer unnecessary changes. However, because of this, the change may not be as dramatic. Second, try a smaller variance when initializing weights by changing `--init_gain`. We observed that smaller variance in weight initialization results in less color inversion. 3 | - Out of memory error: CycleGAN is quite memory intensive because it needs two generator networks and two discriminator networks. If you would like to generate high resolution images, you can do the following. First, train CycleGAN on cropped images of the training set. Please be careful not to change the aspect ratio or the scale of the original image, as this can lead to a training/test gap. You can usually do this by using the `--resize_or_crop crop` option, or `--resize_or_crop scale_width_and_crop`. Then at test time, load only one generator to generate the results in one direction only. This greatly saves memory because you are not loading the discriminators and the other generator in the opposite direction. You can probably input the whole image (we have done image generation of 1024x512 resolution). You can do this using `--model test --dataroot [path to the directory containing the actual images (ex. ./datasets/horse2zebra/trainA)] --model_suffix _A`. For more explanation, please see https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/test_model.py#L16.
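  For example, such a one-direction test run might look like this (the dataset path and experiment name here are only illustrative; all flags are defined in `options/` and `models/test_model.py`):

  ```bash
  python test.py --model test --model_suffix _A \
      --dataroot ./datasets/horse2zebra/trainA \
      --name horse2zebra_experiment --resize_or_crop none
  ```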
4 | -------------------------------------------------------------------------------- /generate_tinted_images.m: -------------------------------------------------------------------------------- 1 | %% Load data 2 | fID = fopen('Source_Image/file.lst'); 3 | list = textscan(fID,'%s','delimiter','\n'); 4 | list = [list{1,1}]; 5 | fclose(fID); 6 | 7 | response = load('Source_Image/real_illum_11346_Normalized.mat'); 8 | response = response.real_rgb; 9 | 10 | %% Train-Test partition 11 | rng('default') 12 | CVparts = cvpartition(11346,'KFold',15); 13 | idx = test(CVparts,1); 14 | 15 | %% Main loop 16 | for i = 1:size(list,1) 17 | img = imread(list{i,1}); 18 | img = img(1:224,1:224,:); 19 | 20 | % color correction according to gt illumination 21 | e = reshape(response(i,:),1,1,3); 22 | img = uint8(double(img)./e); 23 | 24 | img = imresize(img,[256,256]); 25 | [img_, map] = tint_image(img); 26 | img = cat(2,img_,img); 27 | filename = strcat('train\',sprintf('%05d.jpg', i)); 28 | if idx(i) 29 | filename = strcat('test\',sprintf('%05d.jpg', i)); 30 | end 31 | imwrite(img,filename); 32 | filename = strcat('tint_maps\',sprintf('%05d.jpg', i)); 33 | imwrite(map,filename); 34 | end 35 | 36 | %% Tint function 37 | function [img_, map] = tint_image(Im) 38 | map = zeros(256,256,3); 39 | [X,Y] = meshgrid(1:256,1:256); 40 | colors = [0, 0, 255; 0, 255, 0; 255, 0, 0]; 41 | for i = 1:3 42 | mu = rand(1,2)*256; 43 | sigma = [rand(1)*25600 0; 0 rand(1)*25600]; 44 | tint = mvnpdf([X(:) Y(:)], mu, sigma); 45 | tint = reshape(tint, 256, 256); 46 | tint = tint / max(tint(:)); 47 | tint = repmat(tint,[1 1 3]); 48 | tint = tint .* reshape(colors(i,:)/norm(colors(i,:)), 1,1,3); 49 | tint = 1 - tint; 50 | map = map + tint ./ 3; 51 | end 52 | map = imresize(map, [size(Im,1) size(Im,2)]); 53 | 54 | img_ = uint8(map .* double(Im)); 55 | map = uint8(map * 255); % return the tint map as an 8-bit image 56 | end 57 | -------------------------------------------------------------------------------- /angulargan/util/util.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import torch 3 | import numpy as np 4 | from PIL import Image 5 | import os 6 | 7 | 8 | # Converts a Tensor into an image array (numpy) 9 | # |imtype|: the desired type of the converted numpy array 10 | def tensor2im(input_image, imtype=np.uint8): 11 | if isinstance(input_image, torch.Tensor): 12 | image_tensor = input_image.data 13 | else: 14 | return input_image 15 | image_numpy = image_tensor[0].cpu().float().numpy() 16 | if image_numpy.shape[0] == 1: 17 | image_numpy = np.tile(image_numpy, (3, 1, 1)) 18 | image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0 19 | return image_numpy.astype(imtype) 20 | 21 | 22 | def diagnose_network(net, name='network'): 23 | mean = 0.0 24 | count = 0 25 | for param in net.parameters(): 26 | if param.grad is not None: 27 | mean += torch.mean(torch.abs(param.grad.data)) 28 | count += 1 29 | if count > 0: 30 | mean = mean / count 31 | print(name) 32 | print(mean) 33 | 34 | 35 | def save_image(image_numpy, image_path): 36 | image_pil = Image.fromarray(image_numpy) 37 | image_pil.save(image_path) 38 | 39 | 40 | def print_numpy(x, val=True, shp=False): 41 | x = x.astype(np.float64) 42 | if shp: 43 | print('shape,', x.shape) 44 | if val: 45 | x = x.flatten() 46 | print('mean = %3.3f, min = %3.3f, max = %3.3f, median = %3.3f, std=%3.3f' % ( 47 | np.mean(x), np.min(x), np.max(x), np.median(x), np.std(x))) 48 | 49 | 50 | def mkdirs(paths): 51 | if isinstance(paths, list) and not 
isinstance(paths, str): 52 | for path in paths: 53 | mkdir(path) 54 | else: 55 | mkdir(paths) 56 | 57 | 58 | def mkdir(path): 59 | if not os.path.exists(path): 60 | os.makedirs(path) 61 | -------------------------------------------------------------------------------- /angulargan/models/test_model.py: -------------------------------------------------------------------------------- 1 | from .base_model import BaseModel 2 | from . import networks 3 | from .cycle_gan_model import CycleGANModel 4 | 5 | 6 | class TestModel(BaseModel): 7 | def name(self): 8 | return 'TestModel' 9 | 10 | @staticmethod 11 | def modify_commandline_options(parser, is_train=True): 12 | assert not is_train, 'TestModel cannot be used in train mode' 13 | parser = CycleGANModel.modify_commandline_options(parser, is_train=False) 14 | parser.set_defaults(dataset_mode='single') 15 | 16 | parser.add_argument('--model_suffix', type=str, default='', 17 | help='In checkpoints_dir, [which_epoch]_net_G[model_suffix].pth will' 18 | ' be loaded as the generator of TestModel') 19 | 20 | return parser 21 | 22 | def initialize(self, opt): 23 | assert(not opt.isTrain) 24 | BaseModel.initialize(self, opt) 25 | 26 | # specify the training losses you want to print out. The program will call base_model.get_current_losses 27 | self.loss_names = [] 28 | # specify the images you want to save/display. The program will call base_model.get_current_visuals 29 | self.visual_names = ['real_A', 'fake_B'] 30 | # specify the models you want to save to the disk. The program will call base_model.save_networks and base_model.load_networks 31 | self.model_names = ['G' + opt.model_suffix] 32 | 33 | self.netG = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.which_model_netG, 34 | opt.norm, not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids) 35 | 36 | # assigns the model to self.netG_[suffix] so that it can be loaded 37 | # please see BaseModel.load_networks 38 | setattr(self, 'netG' + opt.model_suffix, self.netG) 39 | 40 | def set_input(self, input): 41 | # we need to use single_dataset mode 42 | self.real_A = input['A'].to(self.device) 43 | self.image_paths = input['A_paths'] 44 | 45 | def forward(self): 46 | self.fake_B = self.netG(self.real_A) 47 | -------------------------------------------------------------------------------- /angulargan/util/html.py: -------------------------------------------------------------------------------- 1 | import dominate 2 | from dominate.tags import * 3 | import os 4 | 5 | 6 | class HTML: 7 | def __init__(self, web_dir, title, reflesh=0): 8 | self.title = title 9 | self.web_dir = web_dir 10 | self.img_dir = os.path.join(self.web_dir, 'images') 11 | if not os.path.exists(self.web_dir): 12 | os.makedirs(self.web_dir) 13 | if not os.path.exists(self.img_dir): 14 | os.makedirs(self.img_dir) 15 | # print(self.img_dir) 16 | 17 | self.doc = dominate.document(title=title) 18 | if reflesh > 0: 19 | with self.doc.head: 20 | meta(http_equiv="refresh", content=str(reflesh)) # auto-refresh the page every `reflesh` seconds 21 | 22 | def get_image_dir(self): 23 | return self.img_dir 24 | 25 | def add_header(self, text): 26 | with self.doc: 27 | h3(text) 28 | 29 | def add_table(self, border=1): 30 | self.t = table(border=border, style="table-layout: fixed;") 31 | self.doc.add(self.t) 32 | 33 | def add_images(self, ims, txts, links, width=400): 34 | self.add_table() 35 | with self.t: 36 | with tr(): 37 | for im, txt, link in zip(ims, txts, links): 38 | with td(style="word-wrap: break-word;", halign="center", valign="top"): 39 | with p(): 40 | with 
a(href=os.path.join('images', link)): 41 | img(style="width:%dpx" % width, src=os.path.join('images', im)) 42 | br() 43 | p(txt) 44 | 45 | def save(self): 46 | html_file = '%s/index.html' % self.web_dir 47 | f = open(html_file, 'wt') 48 | f.write(self.doc.render()) 49 | f.close() 50 | 51 | 52 | if __name__ == '__main__': 53 | html = HTML('web/', 'test_html') 54 | html.add_header('hello world') 55 | 56 | ims = [] 57 | txts = [] 58 | links = [] 59 | for n in range(4): 60 | ims.append('image_%d.png' % n) 61 | txts.append('text_%d' % n) 62 | links.append('image_%d.png' % n) 63 | html.add_images(ims, txts, links) 64 | html.save() 65 | -------------------------------------------------------------------------------- /angulargan/data/image_folder.py: -------------------------------------------------------------------------------- 1 | ############################################################################### 2 | # Code from 3 | # https://github.com/pytorch/vision/blob/master/torchvision/datasets/folder.py 4 | # Modified the original code so that it also loads images from the current 5 | # directory as well as the subdirectories 6 | ############################################################################### 7 | 8 | import torch.utils.data as data 9 | 10 | from PIL import Image 11 | import os 12 | import os.path 13 | 14 | IMG_EXTENSIONS = [ 15 | '.jpg', '.JPG', '.jpeg', '.JPEG', 16 | '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', 17 | ] 18 | 19 | 20 | def is_image_file(filename): 21 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS) 22 | 23 | 24 | def make_dataset(dir): 25 | images = [] 26 | assert os.path.isdir(dir), '%s is not a valid directory' % dir 27 | 28 | for root, _, fnames in sorted(os.walk(dir)): 29 | for fname in fnames: 30 | if is_image_file(fname): 31 | path = os.path.join(root, fname) 32 | images.append(path) 33 | 34 | return images 35 | 36 | 37 | def default_loader(path): 38 | return Image.open(path).convert('RGB') 39 | 40 | 41 | class ImageFolder(data.Dataset): 42 | 43 | def __init__(self, root, transform=None, return_paths=False, 44 | loader=default_loader): 45 | imgs = make_dataset(root) 46 | if len(imgs) == 0: 47 | raise(RuntimeError("Found 0 images in: " + root + "\n" 48 | "Supported image extensions are: " + 49 | ",".join(IMG_EXTENSIONS))) 50 | 51 | self.root = root 52 | self.imgs = imgs 53 | self.transform = transform 54 | self.return_paths = return_paths 55 | self.loader = loader 56 | 57 | def __getitem__(self, index): 58 | path = self.imgs[index] 59 | img = self.loader(path) 60 | if self.transform is not None: 61 | img = self.transform(img) 62 | if self.return_paths: 63 | return img, path 64 | else: 65 | return img 66 | 67 | def __len__(self): 68 | return len(self.imgs) 69 | -------------------------------------------------------------------------------- /angulargan/datasets/combine_A_and_B.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import cv2 4 | import argparse 5 | 6 | parser = argparse.ArgumentParser('create image pairs') 7 | parser.add_argument('--fold_A', dest='fold_A', help='input directory for image A', type=str, default='../dataset/50kshoes_edges') 8 | parser.add_argument('--fold_B', dest='fold_B', help='input directory for image B', type=str, default='../dataset/50kshoes_jpg') 9 | parser.add_argument('--fold_AB', dest='fold_AB', help='output directory', type=str, default='../dataset/test_AB') 10 | parser.add_argument('--num_imgs', dest='num_imgs', 
help='number of images', type=int, default=1000000) 11 | parser.add_argument('--use_AB', dest='use_AB', help='if true: (0001_A, 0001_B) to (0001_AB)', action='store_true') 12 | args = parser.parse_args() 13 | 14 | for arg in vars(args): 15 | print('[%s] = ' % arg, getattr(args, arg)) 16 | 17 | splits = os.listdir(args.fold_A) 18 | 19 | for sp in splits: 20 | img_fold_A = os.path.join(args.fold_A, sp) 21 | img_fold_B = os.path.join(args.fold_B, sp) 22 | img_list = os.listdir(img_fold_A) 23 | if args.use_AB: 24 | img_list = [img_path for img_path in img_list if '_A.' in img_path] 25 | 26 | num_imgs = min(args.num_imgs, len(img_list)) 27 | print('split = %s, use %d/%d images' % (sp, num_imgs, len(img_list))) 28 | img_fold_AB = os.path.join(args.fold_AB, sp) 29 | if not os.path.isdir(img_fold_AB): 30 | os.makedirs(img_fold_AB) 31 | print('split = %s, number of images = %d' % (sp, num_imgs)) 32 | for n in range(num_imgs): 33 | name_A = img_list[n] 34 | path_A = os.path.join(img_fold_A, name_A) 35 | if args.use_AB: 36 | name_B = name_A.replace('_A.', '_B.') 37 | else: 38 | name_B = name_A 39 | path_B = os.path.join(img_fold_B, name_B) 40 | if os.path.isfile(path_A) and os.path.isfile(path_B): 41 | name_AB = name_A 42 | if args.use_AB: 43 | name_AB = name_AB.replace('_A.', '.') # remove _A 44 | path_AB = os.path.join(img_fold_AB, name_AB) 45 | im_A = cv2.imread(path_A, cv2.IMREAD_COLOR) # IMREAD_COLOR replaces the removed CV_LOAD_IMAGE_COLOR constant 46 | im_B = cv2.imread(path_B, cv2.IMREAD_COLOR) 47 | im_AB = np.concatenate([im_A, im_B], 1) 48 | cv2.imwrite(path_AB, im_AB) 49 | -------------------------------------------------------------------------------- /angulargan/data/unaligned_dataset.py: -------------------------------------------------------------------------------- 1 | import os.path 2 | from data.base_dataset import BaseDataset, get_transform 3 | from data.image_folder import make_dataset 4 | from PIL import Image 5 | import random 6 | 7 | 8 | class UnalignedDataset(BaseDataset): 9 | @staticmethod 10 | def modify_commandline_options(parser, is_train): 11 | return parser 12 | 13 | def initialize(self, opt): 14 | self.opt = opt 15 | self.root = opt.dataroot 16 | self.dir_A = os.path.join(opt.dataroot, opt.phase + 'A') 17 | self.dir_B = os.path.join(opt.dataroot, opt.phase + 'B') 18 | 19 | self.A_paths = make_dataset(self.dir_A) 20 | self.B_paths = make_dataset(self.dir_B) 21 | 22 | self.A_paths = sorted(self.A_paths) 23 | self.B_paths = sorted(self.B_paths) 24 | self.A_size = len(self.A_paths) 25 | self.B_size = len(self.B_paths) 26 | self.transform = get_transform(opt) 27 | 28 | def __getitem__(self, index): 29 | A_path = self.A_paths[index % self.A_size] 30 | if self.opt.serial_batches: 31 | index_B = index % self.B_size 32 | else: 33 | index_B = random.randint(0, self.B_size - 1) 34 | B_path = self.B_paths[index_B] 35 | # print('(A, B) = (%d, %d)' % (index_A, index_B)) 36 | A_img = Image.open(A_path).convert('RGB') 37 | B_img = Image.open(B_path).convert('RGB') 38 | 39 | A = self.transform(A_img) 40 | B = self.transform(B_img) 41 | if self.opt.which_direction == 'BtoA': 42 | input_nc = self.opt.output_nc 43 | output_nc = self.opt.input_nc 44 | else: 45 | input_nc = self.opt.input_nc 46 | output_nc = self.opt.output_nc 47 | 48 | if input_nc == 1: # RGB to gray 49 | tmp = A[0, ...] * 0.299 + A[1, ...] * 0.587 + A[2, ...] * 0.114 50 | A = tmp.unsqueeze(0) 51 | 52 | if output_nc == 1: # RGB to gray 53 | tmp = B[0, ...] * 0.299 + B[1, ...] * 0.587 + B[2, ...] 
* 0.114 54 | B = tmp.unsqueeze(0) 55 | return {'A': A, 'B': B, 56 | 'A_paths': A_path, 'B_paths': B_path} 57 | 58 | def __len__(self): 59 | return max(self.A_size, self.B_size) 60 | 61 | def name(self): 62 | return 'UnalignedDataset' 63 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Conditional GANs for Multi-Illuminant Color Constancy: Revolution or Yet Another Approach? 2 | Supporting code to the paper
3 | [O Sidorov. Conditional GANs for Multi-Illuminant Color Constancy: Revolution or Yet Another Approach?](https://arxiv.org/abs/1811.06604) 4 |
5 | 6 | ![image preview](https://github.com/acecreamu/angularGAN/blob/master/image.jpg) 7 | 8 | # AngularGAN 9 | The work presents an extension of the supervised image-to-image translation algorithm ["pix2pix" by Isola *et al.*](https://arxiv.org/abs/1611.07004) oriented specifically to the color constancy task.
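The key addition is an angular loss term that penalizes the angle between the ground-truth and predicted illumination. Its core (quoted from `models/angular_loss/__init__.py` in this repository, where `illum_gt` and `illum_pred` are the two illumination estimates) is the mean angular error in degrees:

```python
import torch
from math import pi

# mean angular error in degrees between ground-truth and predicted illumination
cos = torch.nn.CosineSimilarity(dim=1)(illum_gt, illum_pred)
cos = torch.clamp(cos, -0.99999, 0.99999)  # keep acos away from the +-1 singularities
loss = torch.mean(torch.acos(cos)) * 180 / pi
```

The illumination maps themselves are estimated in `models/angular_gan_model.py` as the element-wise ratio between the input image and the (real or generated) output image.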

10 | AngularGAN inherits from [this](https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix) implementation of pix2pix in PyTorch. Therefore, you may follow the original instructions for installation and dependencies. The new modules are implemented in PyTorch and do not require additional packages. 11 | #### Datasets are below!
12 | ### Getting started 13 | - Put your data in `datasets/facades` in the format 14 | ``` 15 | -facades/ 16 | -test/ 17 | -xxx.jpg 18 | -yyy.jpg 19 | -... 20 | -train/ 21 | -zzz.jpg 22 | -... 23 | ``` 24 | where each image consists of a pair of images A and B (input and output) concatenated along the horizontal axis (see the sketch after this list).
25 | - Run `visdom` to open training visualization (optional) 26 | - Run training (change the parameter `--model angular_gan_v2` to use v2) 27 | ``` 28 | chmod a+x run.sh 29 | ./run.sh 30 | ``` 31 | - Run `runtest.sh` instead for testing (change the parameter `--model angular_gan_v2` to use v2)
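A minimal sketch of producing one such A|B pair with PIL (the file names here are illustrative; `datasets/combine_A_and_B.py` and `datasets/make_dataset_aligned.py` build this format for whole folders):

```python
from PIL import Image

# paste input (A) and ground truth (B) side by side into one training image
img_a = Image.open('input.jpg')    # left half:  image A
img_b = Image.open('target.jpg')   # right half: image B
assert img_a.size == img_b.size
pair = Image.new('RGB', (img_a.size[0] * 2, img_a.size[1]))
pair.paste(img_a, (0, 0))
pair.paste(img_b, (img_a.size[0], 0))
pair.save('datasets/facades/train/0001.jpg')
```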

32 | *We thank the authors of pix2pix for their excellent work!* 33 | 34 | ![angulargan_framework](https://github.com/acecreamu/angularGAN/blob/master/angulargan.jpg) 35 | 36 | # Datasets 37 | ### Tinted Multi-illuminant dataset 38 | 39 | The MATLAB code `generate_tinted_images.m` applies a multi-illuminant color cast to the input images. The tint maps are randomized and are not coherent between frames.
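A rough Python equivalent of the script's `tint_image` function, for readers without MATLAB (a sketch assuming NumPy and SciPy; the MATLAB file is the reference implementation):

```python
import numpy as np
from scipy.stats import multivariate_normal

def random_tint_map(size=256):
    """Three randomly placed Gaussian color tints (blue, green, red), as in tint_image."""
    ys, xs = np.mgrid[0:size, 0:size]
    grid = np.column_stack([xs.ravel(), ys.ravel()])
    tint_map = np.zeros((size, size, 3))
    for color in np.eye(3)[::-1]:                 # unit vectors for blue, green, red
        mu = np.random.rand(2) * size
        cov = np.diag(np.random.rand(2) * 25600)  # same spread as the MATLAB code (for size 256)
        tint = multivariate_normal(mu, cov).pdf(grid).reshape(size, size)
        tint /= tint.max()
        tint_map += (1.0 - tint[..., None] * color) / 3.0
    return tint_map

# applied element-wise, matching img_ = uint8(map .* double(Im)) in the MATLAB:
# img_tinted = (random_tint_map() * img.astype(float)).astype(np.uint8)
```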
40 | You can use the provided file `real_illum_11346_Normalized.mat` or create your own by simply normalizing the original illumination vectors as `e_norm = e./norm(e)`. 41 | 42 | ### GTAV Shadow Removal Dataset 43 | The GTAV Shadow Removal Dataset of 5,723 image pairs with and without shadows can be accessed via this [link](https://drive.google.com/open?id=1ktOXJmMQL_6U2J03mks3yWh6EMWKjUmu).
44 | 45 | #### Preview 46 | 47 | ![dataset preview](https://github.com/acecreamu/angularGAN/blob/master/preview.jpg) 48 | -------------------------------------------------------------------------------- /angulargan/datasets/make_dataset_aligned.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from PIL import Image 4 | 5 | 6 | def get_file_paths(folder): 7 | image_file_paths = [] 8 | for root, dirs, filenames in os.walk(folder): 9 | filenames = sorted(filenames) 10 | for filename in filenames: 11 | input_path = os.path.abspath(root) 12 | file_path = os.path.join(input_path, filename) 13 | if filename.endswith('.png') or filename.endswith('.jpg'): 14 | image_file_paths.append(file_path) 15 | 16 | break # prevent descending into subfolders 17 | return image_file_paths 18 | 19 | 20 | def align_images(a_file_paths, b_file_paths, target_path): 21 | if not os.path.exists(target_path): 22 | os.makedirs(target_path) 23 | 24 | for i in range(len(a_file_paths)): 25 | img_a = Image.open(a_file_paths[i]) 26 | img_b = Image.open(b_file_paths[i]) 27 | assert(img_a.size == img_b.size) 28 | 29 | aligned_image = Image.new("RGB", (img_a.size[0] * 2, img_a.size[1])) 30 | aligned_image.paste(img_a, (0, 0)) 31 | aligned_image.paste(img_b, (img_a.size[0], 0)) 32 | aligned_image.save(os.path.join(target_path, '{:04d}.jpg'.format(i))) 33 | 34 | 35 | if __name__ == '__main__': 36 | import argparse 37 | parser = argparse.ArgumentParser() 38 | parser.add_argument( 39 | '--dataset-path', 40 | dest='dataset_path', 41 | help='Which folder to process (it should have subfolders testA, testB, trainA and trainB)' 42 | ) 43 | args = parser.parse_args() 44 | 45 | dataset_folder = args.dataset_path 46 | print(dataset_folder) 47 | 48 | test_a_path = os.path.join(dataset_folder, 'testA') 49 | test_b_path = os.path.join(dataset_folder, 'testB') 50 | test_a_file_paths = get_file_paths(test_a_path) 51 | test_b_file_paths = get_file_paths(test_b_path) 52 | assert(len(test_a_file_paths) == len(test_b_file_paths)) 53 | test_path = os.path.join(dataset_folder, 'test') 54 | 55 | train_a_path = os.path.join(dataset_folder, 'trainA') 56 | train_b_path = os.path.join(dataset_folder, 'trainB') 57 | train_a_file_paths = get_file_paths(train_a_path) 58 | train_b_file_paths = get_file_paths(train_b_path) 59 | assert(len(train_a_file_paths) == len(train_b_file_paths)) 60 | train_path = os.path.join(dataset_folder, 'train') 61 | 62 | align_images(test_a_file_paths, test_b_file_paths, test_path) 63 | align_images(train_a_file_paths, train_b_file_paths, train_path) 64 | -------------------------------------------------------------------------------- /angulargan/train.py: -------------------------------------------------------------------------------- 1 | import time 2 | from options.train_options import TrainOptions 3 | from data import CreateDataLoader 4 | from models import create_model 5 | from util.visualizer import Visualizer 6 | 7 | if __name__ == '__main__': 8 | opt = TrainOptions().parse() 9 | data_loader = CreateDataLoader(opt) 10 | dataset = data_loader.load_data() 11 | dataset_size = len(data_loader) 12 | print('#training images = %d' % dataset_size) 13 | 14 | model = create_model(opt) 15 | model.setup(opt) 16 | visualizer = Visualizer(opt) 17 | total_steps = 0 18 | 19 | for epoch in range(opt.epoch_count, opt.niter + opt.niter_decay + 1): 20 | epoch_start_time = time.time() 21 | iter_data_time = time.time() 22 | epoch_iter = 0 23 | 24 | for i, data in enumerate(dataset): 25 | 
iter_start_time = time.time() 26 | if total_steps % opt.print_freq == 0: 27 | t_data = iter_start_time - iter_data_time 28 | visualizer.reset() 29 | total_steps += opt.batchSize 30 | epoch_iter += opt.batchSize 31 | model.set_input(data) 32 | model.optimize_parameters() 33 | 34 | if total_steps % opt.display_freq == 0: 35 | save_result = total_steps % opt.update_html_freq == 0 36 | visualizer.display_current_results(model.get_current_visuals(), epoch, save_result) 37 | 38 | if total_steps % opt.print_freq == 0: 39 | losses = model.get_current_losses() 40 | t = (time.time() - iter_start_time) / opt.batchSize 41 | visualizer.print_current_losses(epoch, epoch_iter, losses, t, t_data) 42 | if opt.display_id > 0: 43 | visualizer.plot_current_losses(epoch, float(epoch_iter) / dataset_size, opt, losses) 44 | 45 | if total_steps % opt.save_latest_freq == 0: 46 | print('saving the latest model (epoch %d, total_steps %d)' % 47 | (epoch, total_steps)) 48 | model.save_networks('latest') 49 | 50 | iter_data_time = time.time() 51 | if epoch % opt.save_epoch_freq == 0: 52 | print('saving the model at the end of epoch %d, iters %d' % 53 | (epoch, total_steps)) 54 | model.save_networks('latest') 55 | model.save_networks(epoch) 56 | 57 | print('End of epoch %d / %d \t Time Taken: %d sec' % 58 | (epoch, opt.niter + opt.niter_decay, time.time() - epoch_start_time)) 59 | model.update_learning_rate() 60 | -------------------------------------------------------------------------------- /angulargan/options/train_options.py: -------------------------------------------------------------------------------- 1 | from .base_options import BaseOptions 2 | 3 | 4 | class TrainOptions(BaseOptions): 5 | def initialize(self, parser): 6 | parser = BaseOptions.initialize(self, parser) 7 | parser.add_argument('--display_freq', type=int, default=400, help='frequency of showing training results on screen') 8 | parser.add_argument('--display_ncols', type=int, default=4, help='if positive, display all images in a single visdom web panel with certain number of images per row.') 9 | parser.add_argument('--update_html_freq', type=int, default=1000, help='frequency of saving training results to html') 10 | parser.add_argument('--print_freq', type=int, default=100, help='frequency of showing training results on console') 11 | parser.add_argument('--save_latest_freq', type=int, default=5000, help='frequency of saving the latest results') 12 | parser.add_argument('--save_epoch_freq', type=int, default=5, help='frequency of saving checkpoints at the end of epochs') 13 | parser.add_argument('--continue_train', action='store_true', help='continue training: load the latest model') 14 | parser.add_argument('--epoch_count', type=int, default=1, help='the starting epoch count, we save the model by <epoch_count>, <epoch_count>+<save_latest_freq>, ...') 15 | parser.add_argument('--phase', type=str, default='train', help='train, val, test, etc') 16 | parser.add_argument('--which_epoch', type=str, default='latest', help='which epoch to load? 
set to latest to use latest cached model') 17 | parser.add_argument('--niter', type=int, default=100, help='# of iter at starting learning rate') 18 | parser.add_argument('--niter_decay', type=int, default=100, help='# of iter to linearly decay learning rate to zero') 19 | parser.add_argument('--beta1', type=float, default=0.5, help='momentum term of adam') 20 | parser.add_argument('--lr', type=float, default=0.0002, help='initial learning rate for adam') 21 | parser.add_argument('--no_lsgan', action='store_true', help='do *not* use least square GAN, if false, use vanilla GAN') 22 | parser.add_argument('--pool_size', type=int, default=50, help='the size of image buffer that stores previously generated images') 23 | parser.add_argument('--no_html', action='store_true', help='do not save intermediate training results to [opt.checkpoints_dir]/[opt.name]/web/') 24 | parser.add_argument('--lr_policy', type=str, default='lambda', help='learning rate policy: lambda|step|plateau') 25 | parser.add_argument('--lr_decay_iters', type=int, default=50, help='multiply by a gamma every lr_decay_iters iterations') 26 | 27 | self.isTrain = True 28 | return parser 29 | -------------------------------------------------------------------------------- /angulargan/docs/tips.md: -------------------------------------------------------------------------------- 1 | ## Training/test Tips 2 | - Flags: see `options/train_options.py` and `options/base_options.py` for the training flags; see `options/test_options.py` and `options/base_options.py` for the test flags. There are some model-specific flags as well, which are added in the model files, such as the `--lambda_Angular` option in `models/angular_gan_model.py`. The default values of these options are also adjusted in the model files. 3 | - CPU/GPU (default `--gpu_ids 0`): set `--gpu_ids -1` to use CPU mode; set `--gpu_ids 0,1,2` for multi-GPU mode. You need a large batch size (e.g. `--batchSize 32`) to benefit from multiple GPUs. 4 | - Visualization: during training, the current results can be viewed using two methods. First, if you set `--display_id` > 0, the results and loss plot will appear on a local graphics web server launched by [visdom](https://github.com/facebookresearch/visdom). To do this, you should have `visdom` installed and a server running by the command `python -m visdom.server`. The default server URL is `http://localhost:8097`. `display_id` corresponds to the window ID that is displayed on the `visdom` server. The `visdom` display functionality is turned on by default. To avoid the extra overhead of communicating with `visdom`, set `--display_id -1`. Second, the intermediate results are saved to `[opt.checkpoints_dir]/[opt.name]/web/` as an HTML file. To avoid this, set `--no_html`. 5 | - Preprocessing: images can be resized and cropped in different ways using the `--resize_or_crop` option. The default option `'resize_and_crop'` resizes the image to be of size `(opt.loadSize, opt.loadSize)` and does a random crop of size `(opt.fineSize, opt.fineSize)`. `'crop'` skips the resizing step and only performs random cropping. `'scale_width'` resizes the image to have width `opt.fineSize` while keeping the aspect ratio. `'scale_width_and_crop'` first resizes the image to have width `opt.loadSize` and then does random cropping of size `(opt.fineSize, opt.fineSize)`. `'none'` tries to skip all these preprocessing steps. 
However, if the image size is not a multiple of some number depending on the number of downsamplings of the generator, you will get an error because the size of the output image may be different from the size of the input image. Therefore, the `'none'` option still tries to adjust the image size to be a multiple of 4. You might need a bigger adjustment if you change the generator architecture. Please see `data/base_dataset.py` to see how all these are implemented. 6 | - Fine-tuning/Resume training: to fine-tune a pre-trained model, or resume the previous training, use the `--continue_train` flag. The program will then load the model based on `which_epoch`. By default, the program will initialize the epoch count as 1. Set `--epoch_count <int>` to specify a different starting epoch count. 7 | -------------------------------------------------------------------------------- /angulargan/data/__init__.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | import torch.utils.data 3 | from data.base_data_loader import BaseDataLoader 4 | from data.base_dataset import BaseDataset 5 | 6 | def find_dataset_using_name(dataset_name): 7 | # Given the option --dataset_mode [datasetname], 8 | # the file "data/datasetname_dataset.py" 9 | # will be imported. 10 | dataset_filename = "data." + dataset_name + "_dataset" 11 | datasetlib = importlib.import_module(dataset_filename) 12 | 13 | # In the file, the class called DatasetNameDataset() will 14 | # be instantiated. It has to be a subclass of BaseDataset, 15 | # and it is case-insensitive. 16 | dataset = None 17 | target_dataset_name = dataset_name.replace('_', '') + 'dataset' 18 | for name, cls in datasetlib.__dict__.items(): 19 | if name.lower() == target_dataset_name.lower() \ 20 | and issubclass(cls, BaseDataset): 21 | dataset = cls 22 | 23 | if dataset is None: 24 | print("In %s.py, there should be a subclass of BaseDataset with class name that matches %s in lowercase." 
% (dataset_filename, target_dataset_name)) 25 | exit(0) 26 | 27 | return dataset 28 | 29 | 30 | def get_option_setter(dataset_name): 31 | dataset_class = find_dataset_using_name(dataset_name) 32 | return dataset_class.modify_commandline_options 33 | 34 | 35 | def create_dataset(opt): 36 | dataset = find_dataset_using_name(opt.dataset_mode) 37 | instance = dataset() 38 | instance.initialize(opt) 39 | print("dataset [%s] was created" % (instance.name())) 40 | return instance 41 | 42 | 43 | def CreateDataLoader(opt): 44 | data_loader = CustomDatasetDataLoader() 45 | data_loader.initialize(opt) 46 | return data_loader 47 | 48 | 49 | ## Wrapper class of Dataset class that performs 50 | ## multi-threaded data loading 51 | class CustomDatasetDataLoader(BaseDataLoader): 52 | def name(self): 53 | return 'CustomDatasetDataLoader' 54 | 55 | def initialize(self, opt): 56 | BaseDataLoader.initialize(self, opt) 57 | self.dataset = create_dataset(opt) 58 | self.dataloader = torch.utils.data.DataLoader( 59 | self.dataset, 60 | batch_size=opt.batchSize, 61 | shuffle=not opt.serial_batches, 62 | num_workers=int(opt.nThreads)) 63 | 64 | def load_data(self): 65 | return self 66 | 67 | def __len__(self): 68 | return min(len(self.dataset), self.opt.max_dataset_size) 69 | 70 | def __iter__(self): 71 | for i, data in enumerate(self.dataloader): 72 | if i * self.opt.batchSize >= self.opt.max_dataset_size: 73 | break 74 | yield data 75 | -------------------------------------------------------------------------------- /angulargan/data/aligned_dataset.py: -------------------------------------------------------------------------------- 1 | import os.path 2 | import random 3 | import torchvision.transforms as transforms 4 | import torch 5 | from data.base_dataset import BaseDataset 6 | from data.image_folder import make_dataset 7 | from PIL import Image 8 | 9 | 10 | class AlignedDataset(BaseDataset): 11 | @staticmethod 12 | def modify_commandline_options(parser, is_train): 13 | return parser 14 | 15 | def initialize(self, opt): 16 | self.opt = opt 17 | self.root = opt.dataroot 18 | self.dir_AB = os.path.join(opt.dataroot, opt.phase) 19 | self.AB_paths = sorted(make_dataset(self.dir_AB)) 20 | assert(opt.resize_or_crop == 'resize_and_crop') 21 | 22 | def __getitem__(self, index): 23 | AB_path = self.AB_paths[index] 24 | AB = Image.open(AB_path).convert('RGB') 25 | w, h = AB.size 26 | w2 = int(w / 2) 27 | A = AB.crop((0, 0, w2, h)).resize((self.opt.loadSize, self.opt.loadSize), Image.BICUBIC) 28 | B = AB.crop((w2, 0, w, h)).resize((self.opt.loadSize, self.opt.loadSize), Image.BICUBIC) 29 | A = transforms.ToTensor()(A) 30 | B = transforms.ToTensor()(B) 31 | w_offset = random.randint(0, max(0, self.opt.loadSize - self.opt.fineSize - 1)) 32 | h_offset = random.randint(0, max(0, self.opt.loadSize - self.opt.fineSize - 1)) 33 | 34 | A = A[:, h_offset:h_offset + self.opt.fineSize, w_offset:w_offset + self.opt.fineSize] 35 | B = B[:, h_offset:h_offset + self.opt.fineSize, w_offset:w_offset + self.opt.fineSize] 36 | 37 | A = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))(A) 38 | B = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))(B) 39 | 40 | if self.opt.which_direction == 'BtoA': 41 | input_nc = self.opt.output_nc 42 | output_nc = self.opt.input_nc 43 | else: 44 | input_nc = self.opt.input_nc 45 | output_nc = self.opt.output_nc 46 | 47 | if (not self.opt.no_flip) and random.random() < 0.5: 48 | idx = [i for i in range(A.size(2) - 1, -1, -1)] 49 | idx = torch.LongTensor(idx) 50 | A = A.index_select(2, idx) 51 | B = 
B.index_select(2, idx) 52 | 53 | if input_nc == 1: # RGB to gray 54 | tmp = A[0, ...] * 0.299 + A[1, ...] * 0.587 + A[2, ...] * 0.114 55 | A = tmp.unsqueeze(0) 56 | 57 | if output_nc == 1: # RGB to gray 58 | tmp = B[0, ...] * 0.299 + B[1, ...] * 0.587 + B[2, ...] * 0.114 59 | B = tmp.unsqueeze(0) 60 | 61 | return {'A': A, 'B': B, 62 | 'A_paths': AB_path, 'B_paths': AB_path} 63 | 64 | def __len__(self): 65 | return len(self.AB_paths) 66 | 67 | def name(self): 68 | return 'AlignedDataset' 69 | -------------------------------------------------------------------------------- /angulargan/data/base_dataset.py: -------------------------------------------------------------------------------- 1 | import torch.utils.data as data 2 | from PIL import Image 3 | import torchvision.transforms as transforms 4 | 5 | 6 | class BaseDataset(data.Dataset): 7 | def __init__(self): 8 | super(BaseDataset, self).__init__() 9 | 10 | def name(self): 11 | return 'BaseDataset' 12 | 13 | @staticmethod 14 | def modify_commandline_options(parser, is_train): 15 | return parser 16 | 17 | def initialize(self, opt): 18 | pass 19 | 20 | def __len__(self): 21 | return 0 22 | 23 | 24 | def get_transform(opt): 25 | transform_list = [] 26 | if opt.resize_or_crop == 'resize_and_crop': 27 | osize = [opt.loadSize, opt.loadSize] 28 | transform_list.append(transforms.Resize(osize, Image.BICUBIC)) 29 | transform_list.append(transforms.RandomCrop(opt.fineSize)) 30 | elif opt.resize_or_crop == 'crop': 31 | transform_list.append(transforms.RandomCrop(opt.fineSize)) 32 | elif opt.resize_or_crop == 'scale_width': 33 | transform_list.append(transforms.Lambda( 34 | lambda img: __scale_width(img, opt.fineSize))) 35 | elif opt.resize_or_crop == 'scale_width_and_crop': 36 | transform_list.append(transforms.Lambda( 37 | lambda img: __scale_width(img, opt.loadSize))) 38 | transform_list.append(transforms.RandomCrop(opt.fineSize)) 39 | elif opt.resize_or_crop == 'none': 40 | transform_list.append(transforms.Lambda( 41 | lambda img: __adjust(img))) 42 | else: 43 | raise ValueError('--resize_or_crop %s is not a valid option.' % opt.resize_or_crop) 44 | 45 | if opt.isTrain and not opt.no_flip: 46 | transform_list.append(transforms.RandomHorizontalFlip()) 47 | 48 | transform_list += [transforms.ToTensor(), 49 | transforms.Normalize((0.5, 0.5, 0.5), 50 | (0.5, 0.5, 0.5))] 51 | return transforms.Compose(transform_list) 52 | 53 | # just modify the width and height to be multiple of 4 54 | def __adjust(img): 55 | ow, oh = img.size 56 | 57 | # the size needs to be a multiple of this number, 58 | # because going through generator network may change img size 59 | # and eventually cause size mismatch error 60 | mult = 4 61 | if ow % mult == 0 and oh % mult == 0: 62 | return img 63 | w = (ow - 1) // mult 64 | w = (w + 1) * mult 65 | h = (oh - 1) // mult 66 | h = (h + 1) * mult 67 | 68 | if ow != w or oh != h: 69 | __print_size_warning(ow, oh, w, h) 70 | 71 | return img.resize((w, h), Image.BICUBIC) 72 | 73 | 74 | def __scale_width(img, target_width): 75 | ow, oh = img.size 76 | 77 | # the size needs to be a multiple of this number, 78 | # because going through generator network may change img size 79 | # and eventually cause size mismatch error 80 | mult = 4 81 | assert target_width % mult == 0, "the target width needs to be multiple of %d." 
% mult 82 | if (ow == target_width and oh % mult == 0): 83 | return img 84 | w = target_width 85 | target_height = int(target_width * oh / ow) 86 | m = (target_height - 1) // mult 87 | h = (m + 1) * mult 88 | 89 | if target_height != h: 90 | __print_size_warning(target_width, target_height, w, h) 91 | 92 | return img.resize((w, h), Image.BICUBIC) 93 | 94 | 95 | def __print_size_warning(ow, oh, w, h): 96 | if not hasattr(__print_size_warning, 'has_printed'): 97 | print("The image size needs to be a multiple of 4. " 98 | "The loaded image size was (%d, %d), so it was adjusted to " 99 | "(%d, %d). This adjustment will be done to all images " 100 | "whose sizes are not multiples of 4" % (ow, oh, w, h)) 101 | __print_size_warning.has_printed = True 102 | 103 | 104 | -------------------------------------------------------------------------------- /angulargan/util/get_data.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import os 3 | import tarfile 4 | import requests 5 | from warnings import warn 6 | from zipfile import ZipFile 7 | from bs4 import BeautifulSoup 8 | from os.path import abspath, isdir, join, basename 9 | 10 | 11 | class GetData(object): 12 | """ 13 | 14 | Download CycleGAN or Pix2Pix Data. 15 | 16 | Args: 17 | technique : str 18 | One of: 'cyclegan' or 'pix2pix'. 19 | verbose : bool 20 | If True, print additional information. 21 | 22 | Examples: 23 | >>> from util.get_data import GetData 24 | >>> gd = GetData(technique='cyclegan') 25 | >>> new_data_path = gd.get(save_path='./datasets') # options will be displayed. 26 | 27 | """ 28 | 29 | def __init__(self, technique='cyclegan', verbose=True): 30 | url_dict = { 31 | 'pix2pix': 'https://people.eecs.berkeley.edu/~tinghuiz/projects/pix2pix/datasets', 32 | 'cyclegan': 'https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets' 33 | } 34 | self.url = url_dict.get(technique.lower()) 35 | self._verbose = verbose 36 | 37 | def _print(self, text): 38 | if self._verbose: 39 | print(text) 40 | 41 | @staticmethod 42 | def _get_options(r): 43 | soup = BeautifulSoup(r.text, 'lxml') 44 | options = [h.text for h in soup.find_all('a', href=True) 45 | if h.text.endswith(('.zip', 'tar.gz'))] 46 | return options 47 | 48 | def _present_options(self): 49 | r = requests.get(self.url) 50 | options = self._get_options(r) 51 | print('Options:\n') 52 | for i, o in enumerate(options): 53 | print("{0}: {1}".format(i, o)) 54 | choice = input("\nPlease enter the number of the " 55 | "dataset above you wish to download:") 56 | return options[int(choice)] 57 | 58 | def _download_data(self, dataset_url, save_path): 59 | if not isdir(save_path): 60 | os.makedirs(save_path) 61 | 62 | base = basename(dataset_url) 63 | temp_save_path = join(save_path, base) 64 | 65 | with open(temp_save_path, "wb") as f: 66 | r = requests.get(dataset_url) 67 | f.write(r.content) 68 | 69 | if base.endswith('.tar.gz'): 70 | obj = tarfile.open(temp_save_path) 71 | elif base.endswith('.zip'): 72 | obj = ZipFile(temp_save_path, 'r') 73 | else: 74 | raise ValueError("Unknown File Type: {0}.".format(base)) 75 | 76 | self._print("Unpacking Data...") 77 | obj.extractall(save_path) 78 | obj.close() 79 | os.remove(temp_save_path) 80 | 81 | def get(self, save_path, dataset=None): 82 | """ 83 | 84 | Download a dataset. 85 | 86 | Args: 87 | save_path : str 88 | A directory to save the data to. 89 | dataset : str, optional 90 | A specific dataset to download. 91 | Note: this must include the file extension. 
92 | If None, options will be presented for you 93 | to choose from. 94 | 95 | Returns: 96 | save_path_full : str 97 | The absolute path to the downloaded data. 98 | 99 | """ 100 | if dataset is None: 101 | selected_dataset = self._present_options() 102 | else: 103 | selected_dataset = dataset 104 | 105 | save_path_full = join(save_path, selected_dataset.split('.')[0]) 106 | 107 | if isdir(save_path_full): 108 | warn("\n'{0}' already exists. Voiding Download.".format( 109 | save_path_full)) 110 | else: 111 | self._print('Downloading Data...') 112 | url = "{0}/{1}".format(self.url, selected_dataset) 113 | self._download_data(url, save_path=save_path) 114 | 115 | return abspath(save_path_full) 116 | -------------------------------------------------------------------------------- /angulargan/docs/datasets.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | ### CycleGAN Datasets 4 | Download the CycleGAN datasets using the following script. Some of the datasets are collected by other researchers. Please cite their papers if you use the data. 5 | ```bash 6 | bash ./datasets/download_cyclegan_dataset.sh dataset_name 7 | ``` 8 | - `facades`: 400 images from the [CMP Facades dataset](http://cmp.felk.cvut.cz/~tylecr1/facade). [[Citation](datasets/bibtex/facades.tex)] 9 | - `cityscapes`: 2975 images from the [Cityscapes training set](https://www.cityscapes-dataset.com). [[Citation](datasets/bibtex/cityscapes.tex)] 10 | - `maps`: 1096 training images scraped from Google Maps. 11 | - `horse2zebra`: 939 horse images and 1177 zebra images downloaded from [ImageNet](http://www.image-net.org) using keywords `wild horse` and `zebra`. 12 | - `apple2orange`: 996 apple images and 1020 orange images downloaded from [ImageNet](http://www.image-net.org) using keywords `apple` and `navel orange`. 13 | - `summer2winter_yosemite`: 1273 summer Yosemite images and 854 winter Yosemite images were downloaded using the Flickr API. See more details in our paper. 14 | - `monet2photo`, `vangogh2photo`, `ukiyoe2photo`, `cezanne2photo`: The art images were downloaded from [Wikiart](https://www.wikiart.org/). The real photos were downloaded from Flickr using the combination of the tags *landscape* and *landscapephotography*. The training set size of each class is Monet:1074, Cezanne:584, Van Gogh:401, Ukiyo-e:1433, Photographs:6853. 15 | - `iphone2dslr_flower`: both classes of images were downloaded from Flickr. The training set size of each class is iPhone:1813, DSLR:3316. See more details in our paper. 16 | 17 | To train a model on your own datasets, you need to create a data folder with two subdirectories `trainA` and `trainB` that contain images from domain A and B. You can test your model on your training set by setting `--phase train` in `test.py`. You can also create subdirectories `testA` and `testB` if you have test data. 18 | 19 | You should **not** expect our method to work on just any random combination of input and output datasets (e.g. `cats<->keyboards`). From our experiments, we find it works better if two datasets share similar visual content. For example, `landscape painting<->landscape photographs` works much better than `portrait painting <-> landscape photographs`. `zebras<->horses` achieves compelling results while `cats<->dogs` completely fails. 20 | 21 | ### pix2pix datasets 22 | Download the pix2pix datasets using the following script. Some of the datasets are collected by other researchers. Please cite their papers if you use the data. 
23 | ```bash
24 | bash ./datasets/download_pix2pix_dataset.sh dataset_name
25 | ```
26 | - `facades`: 400 images from the [CMP Facades dataset](http://cmp.felk.cvut.cz/~tylecr1/facade). [[Citation](datasets/bibtex/facades.tex)]
27 | - `cityscapes`: 2975 images from the [Cityscapes training set](https://www.cityscapes-dataset.com). [[Citation](datasets/bibtex/cityscapes.tex)]
28 | - `maps`: 1096 training images scraped from Google Maps.
29 | - `edges2shoes`: 50k training images from the [UT Zappos50K dataset](http://vision.cs.utexas.edu/projects/finegrained/utzap50k). Edges are computed by the [HED](https://github.com/s9xie/hed) edge detector + post-processing. [[Citation](datasets/bibtex/shoes.tex)]
30 | - `edges2handbags`: 137K Amazon handbag images from the [iGAN project](https://github.com/junyanz/iGAN). Edges are computed by the [HED](https://github.com/s9xie/hed) edge detector + post-processing. [[Citation](datasets/bibtex/handbags.tex)]
31 | 
32 | We provide a Python script to generate pix2pix training data in the form of pairs of images {A,B}, where A and B are two different depictions of the same underlying scene. For example, these might be pairs {label map, photo} or {bw image, color image}. Then we can learn to translate A to B or B to A:
33 | 
34 | Create a folder `/path/to/data` with subfolders `A` and `B`. `A` and `B` should each have their own subfolders `train`, `val`, `test`, etc. In `/path/to/data/A/train`, put training images in style A. In `/path/to/data/B/train`, put the corresponding images in style B. Repeat the same for the other data splits (`val`, `test`, etc.).
35 | 
36 | Corresponding images in a pair {A,B} must be the same size and have the same filename, e.g., `/path/to/data/A/train/1.jpg` is considered to correspond to `/path/to/data/B/train/1.jpg`.
37 | 
38 | Once the data is formatted this way, call:
39 | ```bash
40 | python datasets/combine_A_and_B.py --fold_A /path/to/data/A --fold_B /path/to/data/B --fold_AB /path/to/data
41 | ```
42 | 
43 | This will combine each pair of images (A,B) into a single image file, ready for training.
--------------------------------------------------------------------------------
/angulargan/models/angular_gan_model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from util.image_pool import ImagePool
3 | from .base_model import BaseModel
4 | from . import networks
5 | from . import angular_loss
6 | 
7 | 
8 | class AngularGANModel(BaseModel):
9 |     def name(self):
10 |         return 'AngularGANModel'
11 | 
12 |     @staticmethod
13 |     def modify_commandline_options(parser, is_train=True):
14 | 
15 | 
16 |         parser.set_defaults(pool_size=0, no_lsgan=True, norm='batch')
17 |         parser.set_defaults(dataset_mode='aligned')
18 |         parser.set_defaults(which_model_netG='unet_256')
19 |         if is_train:
20 |             parser.add_argument('--lambda_L1', type=float, default=1.0, help='weight for L1 loss')
21 |             parser.add_argument('--lambda_Angular', type=float, default=1.0, help='influence of angular loss')
22 | 
23 |         return parser
24 | 
25 |     def initialize(self, opt):
26 |         BaseModel.initialize(self, opt)
27 |         self.isTrain = opt.isTrain
28 |         # specify the training losses you want to print out. The program will call base_model.get_current_losses
29 |         self.loss_names = ['G_GAN', 'G_L1', 'G_Ang', 'D_real', 'D_fake']
30 |         # specify the images you want to save/display. The program will call base_model.get_current_visuals
31 |         self.visual_names = ['real_A', 'fake_B', 'real_B']
32 |         # specify the models you want to save to the disk. 
The program will call base_model.save_networks and base_model.load_networks 33 | if self.isTrain: 34 | self.model_names = ['G', 'D'] 35 | else: # during test time, only load Gs 36 | self.model_names = ['G'] 37 | # load/define networks 38 | self.netG = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, 39 | opt.which_model_netG, opt.norm, not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids) 40 | 41 | if self.isTrain: 42 | use_sigmoid = opt.no_lsgan 43 | self.netD = networks.define_D(opt.input_nc + opt.output_nc, opt.ndf, 44 | opt.which_model_netD, 45 | opt.n_layers_D, opt.norm, use_sigmoid, opt.init_type, opt.init_gain, self.gpu_ids) 46 | 47 | if self.isTrain: 48 | self.fake_AB_pool = ImagePool(opt.pool_size) 49 | # define loss functions 50 | self.criterionGAN = networks.GANLoss(use_lsgan=not opt.no_lsgan).to(self.device) 51 | self.criterionL1 = torch.nn.L1Loss() 52 | self.criterionAngular = angular_loss.angular_loss() 53 | 54 | # initialize optimizers 55 | self.optimizers = [] 56 | self.optimizer_G = torch.optim.Adam(self.netG.parameters(), 57 | lr=opt.lr, betas=(opt.beta1, 0.999)) 58 | self.optimizer_D = torch.optim.Adam(self.netD.parameters(), 59 | lr=opt.lr, betas=(opt.beta1, 0.999)) 60 | self.optimizers.append(self.optimizer_G) 61 | self.optimizers.append(self.optimizer_D) 62 | 63 | def set_input(self, input): 64 | AtoB = self.opt.which_direction == 'AtoB' 65 | self.real_A = input['A' if AtoB else 'B'].to(self.device) 66 | self.real_B = input['B' if AtoB else 'A'].to(self.device) 67 | self.image_paths = input['A_paths' if AtoB else 'B_paths'] 68 | 69 | def forward(self): 70 | self.fake_B = self.netG(self.real_A) 71 | 72 | def backward_D(self): 73 | # Fake 74 | # stop backprop to the generator by detaching fake_B 75 | fake_AB = self.fake_AB_pool.query(torch.cat((self.real_A, self.fake_B), 1)) 76 | pred_fake = self.netD(fake_AB.detach()) 77 | self.loss_D_fake = self.criterionGAN(pred_fake, False) 78 | 79 | # Real 80 | real_AB = torch.cat((self.real_A, self.real_B), 1) 81 | pred_real = self.netD(real_AB) 82 | self.loss_D_real = self.criterionGAN(pred_real, True) 83 | 84 | # Combined loss 85 | self.loss_D = (self.loss_D_fake + self.loss_D_real) * 0.5 86 | 87 | self.loss_D.backward() 88 | 89 | def backward_G(self): 90 | # First, G(A) should fake the discriminator 91 | fake_AB = torch.cat((self.real_A, self.fake_B), 1) 92 | pred_fake = self.netD(fake_AB) 93 | self.loss_G_GAN = self.criterionGAN(pred_fake, True) 94 | 95 | # Second, G(A) = B 96 | self.loss_G_L1 = self.criterionL1(self.fake_B, self.real_B) * self.opt.lambda_L1 * 100 97 | 98 | self.eps = torch.tensor(1e-04).to(self.device) 99 | self.illum_gt = torch.div(self.real_A, torch.max(self.real_B, self.eps)) 100 | self.illum_pred = torch.div(self.real_A, torch.max(self.fake_B, self.eps)) 101 | self.loss_G_Ang = self.criterionAngular(self.illum_gt, self.illum_pred) * self.opt.lambda_Angular 102 | 103 | self.loss_G = self.loss_G_GAN + self.loss_G_Ang + self.loss_G_L1 104 | 105 | self.loss_G.backward() 106 | 107 | def optimize_parameters(self): 108 | self.forward() 109 | # update D 110 | self.set_requires_grad(self.netD, True) 111 | self.optimizer_D.zero_grad() 112 | self.backward_D() 113 | self.optimizer_D.step() 114 | 115 | # update G 116 | self.set_requires_grad(self.netD, False) 117 | self.optimizer_G.zero_grad() 118 | self.backward_G() 119 | self.optimizer_G.step() 120 | -------------------------------------------------------------------------------- /angulargan/models/angular_gan_v2_model.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | from util.image_pool import ImagePool 3 | from .base_model import BaseModel 4 | from . import networks 5 | from . import angular_loss 6 | 7 | 8 | class AngularGANv2Model(BaseModel): 9 | def name(self): 10 | return 'AngularGANv2Model' 11 | 12 | @staticmethod 13 | def modify_commandline_options(parser, is_train=True): 14 | 15 | parser.set_defaults(pool_size=0, no_lsgan=True, norm='batch') 16 | parser.set_defaults(dataset_mode='aligned') 17 | parser.set_defaults(which_model_netG='unet_256') 18 | if is_train: 19 | parser.add_argument('--lambda_L1', type=float, default=1.0, help='weight for L1 loss') 20 | parser.add_argument('--lambda_Angular', type=float, default=1.0, help='influence of angular loss') 21 | 22 | return parser 23 | 24 | def initialize(self, opt): 25 | BaseModel.initialize(self, opt) 26 | self.isTrain = opt.isTrain 27 | # specify the training losses you want to print out. The program will call base_model.get_current_losses 28 | self.loss_names = ['G_GAN', 'G_L1', 'G_Ang', 'D_real', 'D_fake'] 29 | # specify the images you want to save/display. The program will call base_model.get_current_visuals 30 | self.visual_names = ['real_A', 'fake_B', 'real_B'] 31 | # specify the models you want to save to the disk. The program will call base_model.save_networks and base_model.load_networks 32 | if self.isTrain: 33 | self.model_names = ['G', 'D'] 34 | else: # during test time, only load Gs 35 | self.model_names = ['G'] 36 | # load/define networks 37 | self.netG = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, 38 | opt.which_model_netG, opt.norm, not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids) 39 | 40 | if self.isTrain: 41 | use_sigmoid = opt.no_lsgan 42 | self.netD = networks.define_D(opt.input_nc + opt.output_nc, opt.ndf, 43 | opt.which_model_netD, 44 | opt.n_layers_D, opt.norm, use_sigmoid, opt.init_type, opt.init_gain, self.gpu_ids) 45 | 46 | if self.isTrain: 47 | self.fake_AB_pool = ImagePool(opt.pool_size) 48 | # define loss functions 49 | self.criterionGAN = networks.GANLoss(use_lsgan=not opt.no_lsgan).to(self.device) 50 | self.criterionL1 = torch.nn.L1Loss() 51 | self.criterionAngular = angular_loss.angular_loss() 52 | 53 | # initialize optimizers 54 | self.optimizers = [] 55 | self.optimizer_G = torch.optim.Adam(self.netG.parameters(), 56 | lr=opt.lr, betas=(opt.beta1, 0.999)) 57 | self.optimizer_D = torch.optim.Adam(self.netD.parameters(), 58 | lr=opt.lr, betas=(opt.beta1, 0.999)) 59 | self.optimizers.append(self.optimizer_G) 60 | self.optimizers.append(self.optimizer_D) 61 | 62 | def set_input(self, input): 63 | AtoB = self.opt.which_direction == 'AtoB' 64 | self.real_A = input['A' if AtoB else 'B'].to(self.device) 65 | self.real_B = input['B' if AtoB else 'A'].to(self.device) 66 | self.image_paths = input['A_paths' if AtoB else 'B_paths'] 67 | 68 | def forward(self): 69 | self.fake_B = self.netG(self.real_A) 70 | 71 | def backward_D(self): 72 | # Fake 73 | # stop backprop to the generator by detaching fake_B 74 | fake_AB = self.fake_AB_pool.query(torch.cat((self.real_A, self.fake_B), 1)) 75 | pred_fake = self.netD(fake_AB.detach()) 76 | self.loss_D_fake = self.criterionGAN(pred_fake, False) 77 | 78 | # Real 79 | self.eps = torch.tensor(1e-04).to(self.device) 80 | real_AB = torch.cat((self.real_A, torch.div(self.real_A, torch.max(self.real_B, self.eps))), 1) 81 | pred_real = self.netD(real_AB) 82 | self.loss_D_real = self.criterionGAN(pred_real, True) 83 | 
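        # (editor's note, inferred from the code) unlike AngularGANModel above, the "real"
        # pair shown to the discriminator here is (real_A, real_A / real_B), i.e. the input
        # image paired with its per-pixel illuminant map (eps-clamped to avoid division by
        # zero), so in this variant D judges illuminant estimates rather than corrected images.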
84 |         # Combined loss
85 |         self.loss_D = (self.loss_D_fake + self.loss_D_real) * 0.5
86 | 
87 |         self.loss_D.backward()
88 | 
89 |     def backward_G(self):
90 |         # First, G(A) should fake the discriminator
91 |         fake_AB = torch.cat((self.real_A, self.fake_B), 1)
92 |         pred_fake = self.netD(fake_AB)
93 |         self.loss_G_GAN = self.criterionGAN(pred_fake, True)
94 | 
95 |         # Second, G(A) should match the per-pixel illuminant map real_A / real_B
96 |         self.eps = torch.tensor(1e-04).to(self.device)
97 |         self.loss_G_L1 = self.criterionL1(self.fake_B, torch.div(self.real_A, torch.max(self.real_B, self.eps))) * self.opt.lambda_L1 * 100
98 | 
99 |         self.illum_gt = self.real_B
100 |         self.illum_pred = torch.div(self.real_A, torch.max(self.fake_B, self.eps))
101 |         self.loss_G_Ang = self.criterionAngular(self.illum_gt, self.illum_pred) * self.opt.lambda_Angular
102 | 
103 |         self.loss_G = self.loss_G_GAN + self.loss_G_Ang + self.loss_G_L1
104 | 
105 |         self.loss_G.backward()
106 | 
107 |     def optimize_parameters(self):
108 |         self.forward()
109 |         # update D
110 |         self.set_requires_grad(self.netD, True)
111 |         self.optimizer_D.zero_grad()
112 |         self.backward_D()
113 |         self.optimizer_D.step()
114 | 
115 |         # update G
116 |         self.set_requires_grad(self.netD, False)
117 |         self.optimizer_G.zero_grad()
118 |         self.backward_G()
119 |         self.optimizer_G.step()
120 | 
--------------------------------------------------------------------------------
/angulargan/models/base_model.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | from collections import OrderedDict
4 | from . import networks
5 | 
6 | 
7 | class BaseModel():
8 | 
9 |     # modify parser to add command line options,
10 |     # and also change the default values if needed
11 |     @staticmethod
12 |     def modify_commandline_options(parser, is_train):
13 |         return parser
14 | 
15 |     def name(self):
16 |         return 'BaseModel'
17 | 
18 |     def initialize(self, opt):
19 |         self.opt = opt
20 |         self.gpu_ids = opt.gpu_ids
21 |         self.isTrain = opt.isTrain
22 |         self.device = torch.device('cuda:{}'.format(self.gpu_ids[0])) if self.gpu_ids else torch.device('cpu')
23 |         self.save_dir = os.path.join(opt.checkpoints_dir, opt.name)
24 |         if opt.resize_or_crop != 'scale_width':
25 |             torch.backends.cudnn.benchmark = True
26 |         self.loss_names = []
27 |         self.model_names = []
28 |         self.visual_names = []
29 |         self.image_paths = []
30 | 
31 |     def set_input(self, input):
32 |         self.input = input
33 | 
34 |     def forward(self):
35 |         pass
36 | 
37 |     # load and print networks; create schedulers
38 |     def setup(self, opt, parser=None):
39 |         if self.isTrain:
40 |             self.schedulers = [networks.get_scheduler(optimizer, opt) for optimizer in self.optimizers]
41 | 
42 |         if not self.isTrain or opt.continue_train:
43 |             self.load_networks(opt.which_epoch)
44 |         self.print_networks(opt.verbose)
45 | 
46 |     # make models eval mode during test time
47 |     def eval(self):
48 |         for name in self.model_names:
49 |             if isinstance(name, str):
50 |                 net = getattr(self, 'net' + name)
51 |                 net.eval()
52 | 
53 |     # used at test time, wrapping `forward` in no_grad() so we don't save
54 |     # intermediate steps for backprop
55 |     def test(self):
56 |         with torch.no_grad():
57 |             self.forward()
58 | 
59 |     # get image paths
60 |     def get_image_paths(self):
61 |         return self.image_paths
62 | 
63 |     def optimize_parameters(self):
64 |         pass
65 | 
66 |     # update learning rate (called once every epoch)
67 |     def update_learning_rate(self):
68 |         for scheduler in self.schedulers:
69 |             scheduler.step()
70 |         lr = self.optimizers[0].param_groups[0]['lr']
71 |         print('learning rate = %.7f' % lr)
72 | 
73 |     # return visualization images. train.py will display these images and save them to an HTML file
74 |     def get_current_visuals(self):
75 |         visual_ret = OrderedDict()
76 |         for name in self.visual_names:
77 |             if isinstance(name, str):
78 |                 visual_ret[name] = getattr(self, name)
79 |         return visual_ret
80 | 
81 |     # return training losses/errors. train.py will print out these errors as debugging information
82 |     def get_current_losses(self):
83 |         errors_ret = OrderedDict()
84 |         for name in self.loss_names:
85 |             if isinstance(name, str):
86 |                 # float(...) works for both scalar tensor and float number
87 |                 errors_ret[name] = float(getattr(self, 'loss_' + name))
88 |         return errors_ret
89 | 
90 |     # save models to the disk
91 |     def save_networks(self, which_epoch):
92 |         for name in self.model_names:
93 |             if isinstance(name, str):
94 |                 save_filename = '%s_net_%s.pth' % (which_epoch, name)
95 |                 save_path = os.path.join(self.save_dir, save_filename)
96 |                 net = getattr(self, 'net' + name)
97 | 
98 |                 if len(self.gpu_ids) > 0 and torch.cuda.is_available():
99 |                     torch.save(net.module.cpu().state_dict(), save_path)
100 |                     net.cuda(self.gpu_ids[0])
101 |                 else:
102 |                     torch.save(net.cpu().state_dict(), save_path)
103 | 
104 |     def __patch_instance_norm_state_dict(self, state_dict, module, keys, i=0):
105 |         key = keys[i]
106 |         if i + 1 == len(keys):  # at the end, pointing to a parameter/buffer
107 |             if module.__class__.__name__.startswith('InstanceNorm') and \
108 |                     (key == 'running_mean' or key == 'running_var'):
109 |                 if getattr(module, key) is None:
110 |                     state_dict.pop('.'.join(keys))
111 |             if module.__class__.__name__.startswith('InstanceNorm') and \
112 |                     (key == 'num_batches_tracked'):
113 |                 state_dict.pop('.'.join(keys))
114 |         else:
115 |             self.__patch_instance_norm_state_dict(state_dict, getattr(module, key), keys, i + 1)
116 | 
117 |     # load models from the disk
118 |     def load_networks(self, which_epoch):
119 |         for name in self.model_names:
120 |             if isinstance(name, str):
121 |                 load_filename = '%s_net_%s.pth' % (which_epoch, name)
122 |                 load_path = os.path.join(self.save_dir, load_filename)
123 |                 net = getattr(self, 'net' + name)
124 |                 if isinstance(net, torch.nn.DataParallel):
125 |                     net = net.module
126 |                 print('loading the model from %s' % load_path)
127 |                 # if you are using PyTorch newer than 0.4 (e.g., built from
128 |                 # GitHub source), you can remove str() on self.device
129 |                 state_dict = torch.load(load_path, map_location=str(self.device))
130 |                 if hasattr(state_dict, '_metadata'):
131 |                     del state_dict._metadata
132 | 
133 |                 # patch InstanceNorm checkpoints prior to 0.4
134 |                 for key in list(state_dict.keys()):  # need to copy keys here because we mutate in loop
135 |                     self.__patch_instance_norm_state_dict(state_dict, net, key.split('.'))
136 |                 net.load_state_dict(state_dict)
137 | 
138 |     # print network information
139 |     def print_networks(self, verbose):
140 |         print('---------- Networks initialized -------------')
141 |         for name in self.model_names:
142 |             if isinstance(name, str):
143 |                 net = getattr(self, 'net' + name)
144 |                 num_params = 0
145 |                 for param in net.parameters():
146 |                     num_params += param.numel()
147 |                 if verbose:
148 |                     print(net)
149 |                 print('[Network %s] Total number of parameters : %.3f M' % (name, num_params / 1e6))
150 |         print('-----------------------------------------------')
151 | 
152 |     # set requires_grad=False to avoid unnecessary gradient computation
153 |     def set_requires_grad(self, nets, requires_grad=False):
154 |         if not isinstance(nets, list):
155 |             nets = [nets]
156 |         for net in nets:
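            # (editor's note) the angular GAN models call this with requires_grad=False on netD
            # before each generator update, so no discriminator gradients are computed there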
157 |             if net is not None:
158 |                 for param in net.parameters():
159 |                     param.requires_grad = requires_grad
160 | 
--------------------------------------------------------------------------------
/angulargan/options/base_options.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | from util import util
4 | import torch
5 | import models
6 | import data
7 | 
8 | 
9 | class BaseOptions():
10 |     def __init__(self):
11 |         self.initialized = False
12 | 
13 |     def initialize(self, parser):
14 |         parser.add_argument('--dataroot', required=True, help='path to images (should have subfolders trainA, trainB, valA, valB, etc)')
15 |         parser.add_argument('--batchSize', type=int, default=1, help='input batch size')
16 |         parser.add_argument('--loadSize', type=int, default=286, help='scale images to this size')
17 |         parser.add_argument('--fineSize', type=int, default=256, help='then crop to this size')
18 |         parser.add_argument('--input_nc', type=int, default=3, help='# of input image channels')
19 |         parser.add_argument('--output_nc', type=int, default=3, help='# of output image channels')
20 |         parser.add_argument('--ngf', type=int, default=64, help='# of gen filters in first conv layer')
21 |         parser.add_argument('--ndf', type=int, default=64, help='# of discrim filters in first conv layer')
22 |         parser.add_argument('--which_model_netD', type=str, default='basic', help='selects model to use for netD')
23 |         parser.add_argument('--which_model_netG', type=str, default='resnet_9blocks', help='selects model to use for netG')
24 |         parser.add_argument('--n_layers_D', type=int, default=3, help='only used if which_model_netD==n_layers')
25 |         parser.add_argument('--gpu_ids', type=str, default='0', help='gpu ids, e.g. 0 or 0,1,2 or 0,2; use -1 for CPU')
26 |         parser.add_argument('--name', type=str, default='experiment_name', help='name of the experiment. It decides where to store samples and models')
27 |         parser.add_argument('--dataset_mode', type=str, default='unaligned', help='chooses how datasets are loaded. [unaligned | aligned | single]')
28 |         parser.add_argument('--model', type=str, default='cycle_gan',
29 |                             help='chooses which model to use. 
cycle_gan, pix2pix, test') 30 | parser.add_argument('--which_direction', type=str, default='AtoB', help='AtoB or BtoA') 31 | parser.add_argument('--nThreads', default=4, type=int, help='# threads for loading data') 32 | parser.add_argument('--checkpoints_dir', type=str, default='./checkpoints', help='models are saved here') 33 | parser.add_argument('--norm', type=str, default='instance', help='instance normalization or batch normalization') 34 | parser.add_argument('--serial_batches', action='store_true', help='if true, takes images in order to make batches, otherwise takes them randomly') 35 | parser.add_argument('--display_winsize', type=int, default=256, help='display window size') 36 | parser.add_argument('--display_id', type=int, default=1, help='window id of the web display') 37 | parser.add_argument('--display_server', type=str, default="http://localhost", help='visdom server of the web display') 38 | parser.add_argument('--display_env', type=str, default='main', help='visdom display environment name (default is "main")') 39 | parser.add_argument('--display_port', type=int, default=8097, help='visdom port of the web display') 40 | parser.add_argument('--no_dropout', action='store_true', help='no dropout for the generator') 41 | parser.add_argument('--max_dataset_size', type=int, default=float("inf"), help='Maximum number of samples allowed per dataset. If the dataset directory contains more than max_dataset_size, only a subset is loaded.') 42 | parser.add_argument('--resize_or_crop', type=str, default='resize_and_crop', help='scaling and cropping of images at load time [resize_and_crop|crop|scale_width|scale_width_and_crop]') 43 | parser.add_argument('--no_flip', action='store_true', help='if specified, do not flip the images for data augmentation') 44 | parser.add_argument('--init_type', type=str, default='normal', help='network initialization [normal|xavier|kaiming|orthogonal]') 45 | parser.add_argument('--init_gain', type=float, default=0.02, help='scaling factor for normal, xavier and orthogonal.') 46 | parser.add_argument('--verbose', action='store_true', help='if specified, print more debugging information') 47 | parser.add_argument('--suffix', default='', type=str, help='customized suffix: opt.name = opt.name + suffix: e.g., {model}_{which_model_netG}_size{loadSize}') 48 | self.initialized = True 49 | return parser 50 | 51 | def gather_options(self): 52 | # initialize parser with basic options 53 | if not self.initialized: 54 | parser = argparse.ArgumentParser( 55 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 56 | parser = self.initialize(parser) 57 | 58 | # get the basic options 59 | opt, _ = parser.parse_known_args() 60 | 61 | # modify model-related parser options 62 | model_name = opt.model 63 | model_option_setter = models.get_option_setter(model_name) 64 | parser = model_option_setter(parser, self.isTrain) 65 | opt, _ = parser.parse_known_args() # parse again with the new defaults 66 | 67 | # modify dataset-related parser options 68 | dataset_name = opt.dataset_mode 69 | dataset_option_setter = data.get_option_setter(dataset_name) 70 | parser = dataset_option_setter(parser, self.isTrain) 71 | 72 | self.parser = parser 73 | 74 | return parser.parse_args() 75 | 76 | def print_options(self, opt): 77 | message = '' 78 | message += '----------------- Options ---------------\n' 79 | for k, v in sorted(vars(opt).items()): 80 | comment = '' 81 | default = self.parser.get_default(k) 82 | if v != default: 83 | comment = '\t[default: %s]' % str(default) 84 | message += 
'{:>25}: {:<30}{}\n'.format(str(k), str(v), comment) 85 | message += '----------------- End -------------------' 86 | print(message) 87 | 88 | # save to the disk 89 | expr_dir = os.path.join(opt.checkpoints_dir, opt.name) 90 | util.mkdirs(expr_dir) 91 | file_name = os.path.join(expr_dir, 'opt.txt') 92 | with open(file_name, 'wt') as opt_file: 93 | opt_file.write(message) 94 | opt_file.write('\n') 95 | 96 | def parse(self): 97 | 98 | opt = self.gather_options() 99 | opt.isTrain = self.isTrain # train or test 100 | 101 | # process opt.suffix 102 | if opt.suffix: 103 | suffix = ('_' + opt.suffix.format(**vars(opt))) if opt.suffix != '' else '' 104 | opt.name = opt.name + suffix 105 | 106 | self.print_options(opt) 107 | 108 | # set gpu ids 109 | str_ids = opt.gpu_ids.split(',') 110 | opt.gpu_ids = [] 111 | for str_id in str_ids: 112 | id = int(str_id) 113 | if id >= 0: 114 | opt.gpu_ids.append(id) 115 | if len(opt.gpu_ids) > 0: 116 | torch.cuda.set_device(opt.gpu_ids[0]) 117 | 118 | self.opt = opt 119 | return self.opt 120 | -------------------------------------------------------------------------------- /angulargan/util/visualizer.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | import ntpath 4 | import time 5 | from . import util 6 | from . import html 7 | from scipy.misc import imresize 8 | 9 | 10 | # save image to the disk 11 | def save_images(webpage, visuals, image_path, aspect_ratio=1.0, width=256): 12 | image_dir = webpage.get_image_dir() 13 | short_path = ntpath.basename(image_path[0]) 14 | name = os.path.splitext(short_path)[0] 15 | 16 | webpage.add_header(name) 17 | ims, txts, links = [], [], [] 18 | 19 | for label, im_data in visuals.items(): 20 | im = util.tensor2im(im_data) 21 | image_name = '%s_%s.png' % (name, label) 22 | save_path = os.path.join(image_dir, image_name) 23 | h, w, _ = im.shape 24 | if aspect_ratio > 1.0: 25 | im = imresize(im, (h, int(w * aspect_ratio)), interp='bicubic') 26 | if aspect_ratio < 1.0: 27 | im = imresize(im, (int(h / aspect_ratio), w), interp='bicubic') 28 | util.save_image(im, save_path) 29 | 30 | ims.append(image_name) 31 | txts.append(label) 32 | links.append(image_name) 33 | webpage.add_images(ims, txts, links, width=width) 34 | 35 | 36 | class Visualizer(): 37 | def __init__(self, opt): 38 | self.display_id = opt.display_id 39 | self.use_html = opt.isTrain and not opt.no_html 40 | self.win_size = opt.display_winsize 41 | self.name = opt.name 42 | self.opt = opt 43 | self.saved = False 44 | if self.display_id > 0: 45 | import visdom 46 | self.ncols = opt.display_ncols 47 | self.vis = visdom.Visdom(server=opt.display_server, port=opt.display_port, env=opt.display_env, raise_exceptions=True) 48 | 49 | if self.use_html: 50 | self.web_dir = os.path.join(opt.checkpoints_dir, opt.name, 'web') 51 | self.img_dir = os.path.join(self.web_dir, 'images') 52 | print('create web directory %s...' 
% self.web_dir)
53 |             util.mkdirs([self.web_dir, self.img_dir])
54 |         self.log_name = os.path.join(opt.checkpoints_dir, opt.name, 'loss_log.txt')
55 |         with open(self.log_name, "a") as log_file:
56 |             now = time.strftime("%c")
57 |             log_file.write('================ Training Loss (%s) ================\n' % now)
58 | 
59 |     def reset(self):
60 |         self.saved = False
61 | 
62 |     def throw_visdom_connection_error(self):
63 |         print('\n\nCould not connect to Visdom server (https://github.com/facebookresearch/visdom) for displaying training progress.\nYou can suppress connection to Visdom using the option --display_id -1. To install visdom, run \n$ pip install visdom\n, and start the server by \n$ python -m visdom.server.\n\n')
64 |         exit(1)
65 | 
66 |     # |visuals|: dictionary of images to display or save
67 |     def display_current_results(self, visuals, epoch, save_result):
68 |         if self.display_id > 0:  # show images in the browser
69 |             ncols = self.ncols
70 |             if ncols > 0:
71 |                 ncols = min(ncols, len(visuals))
72 |                 h, w = next(iter(visuals.values())).shape[:2]
73 |                 table_css = """<style>
74 |                     table {border-collapse: separate; border-spacing:4px; white-space:nowrap; text-align:center}
75 |                     table td {width: %dpx; height: %dpx; padding: 4px; outline: 4px solid black}
76 |                     </style>""" % (w, h)
77 |                 title = self.name
78 |                 label_html = ''
79 |                 label_html_row = ''
80 |                 images = []
81 |                 idx = 0
82 |                 for label, image in visuals.items():
83 |                     image_numpy = util.tensor2im(image)
84 |                     label_html_row += '<td>%s</td>' % label
85 |                     images.append(image_numpy.transpose([2, 0, 1]))
86 |                     idx += 1
87 |                     if idx % ncols == 0:
88 |                         label_html += '<tr>%s</tr>' % label_html_row
89 |                         label_html_row = ''
90 |                 white_image = np.ones_like(image_numpy.transpose([2, 0, 1])) * 255
91 |                 while idx % ncols != 0:
92 |                     images.append(white_image)
93 |                     label_html_row += '<td></td>'
94 |                     idx += 1
95 |                 if label_html_row != '':
96 |                     label_html += '<tr>%s</tr>' % label_html_row
97 |                 # pane col = image row
98 |                 try:
99 |                     self.vis.images(images, nrow=ncols, win=self.display_id + 1,
100 |                                     padding=2, opts=dict(title=title + ' images'))
101 |                     label_html = '<table>%s</table>' % label_html
102 |                     self.vis.text(table_css + label_html, win=self.display_id + 2,
103 |                                   opts=dict(title=title + ' labels'))
104 |                 except ConnectionError:
105 |                     self.throw_visdom_connection_error()
106 | 
107 |             else:
108 |                 idx = 1
109 |                 for label, image in visuals.items():
110 |                     image_numpy = util.tensor2im(image)
111 |                     self.vis.image(image_numpy.transpose([2, 0, 1]), opts=dict(title=label),
112 |                                    win=self.display_id + idx)
113 |                     idx += 1
114 | 
115 |         if self.use_html and (save_result or not self.saved):  # save images to an HTML file
116 |             self.saved = True
117 |             for label, image in visuals.items():
118 |                 image_numpy = util.tensor2im(image)
119 |                 img_path = os.path.join(self.img_dir, 'epoch%.3d_%s.png' % (epoch, label))
120 |                 util.save_image(image_numpy, img_path)
121 |             # update website
122 |             webpage = html.HTML(self.web_dir, 'Experiment name = %s' % self.name, reflesh=1)
123 |             for n in range(epoch, 0, -1):
124 |                 webpage.add_header('epoch [%d]' % n)
125 |                 ims, txts, links = [], [], []
126 | 
127 |                 for label, image_numpy in visuals.items():
128 |                     image_numpy = util.tensor2im(image_numpy)
129 |                     img_path = 'epoch%.3d_%s.png' % (n, label)
130 |                     ims.append(img_path)
131 |                     txts.append(label)
132 |                     links.append(img_path)
133 |                 webpage.add_images(ims, txts, links, width=self.win_size)
134 |             webpage.save()
135 | 
136 |     # losses: dictionary of error labels and values
137 |     def plot_current_losses(self, epoch, counter_ratio, opt, losses):
138 |         if not hasattr(self, 'plot_data'):
139 |             self.plot_data = {'X': [], 'Y': [], 'legend': list(losses.keys())}
140 |         self.plot_data['X'].append(epoch + counter_ratio)
141 |         self.plot_data['Y'].append([losses[k] for k in self.plot_data['legend']])
142 |         try:
143 |             self.vis.line(
144 |                 X=np.stack([np.array(self.plot_data['X'])] * len(self.plot_data['legend']), 1),
145 |                 Y=np.array(self.plot_data['Y']),
146 |                 opts={
147 |                     'title': self.name + ' loss over time',
148 |                     'legend': self.plot_data['legend'],
149 |                     'xlabel': 'epoch',
150 |                     'ylabel': 'loss'},
151 |                 win=self.display_id)
152 |         except ConnectionError:
153 |             self.throw_visdom_connection_error()
154 | 
155 |     # losses: same format as |losses| of plot_current_losses
156 |     def print_current_losses(self, epoch, i, losses, t, t_data):
157 |         message = '(epoch: %d, iters: %d, time: %.3f, data: %.3f) ' % (epoch, i, t, t_data)
158 |         for k, v in losses.items():
159 |             message += '%s: %.3f ' % (k, v)
160 | 
161 |         print(message)
162 |         with open(self.log_name, "a") as log_file:
163 |             log_file.write('%s\n' % message)
164 | 
--------------------------------------------------------------------------------
/angulargan/models/networks.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.nn import init
4 | import functools
5 | from torch.optim import lr_scheduler
6 | 
7 | ###############################################################################
8 | # Helper Functions
9 | ###############################################################################
10 | 
11 | 
12 | def get_norm_layer(norm_type='instance'):
13 |     if norm_type == 'batch':
14 |         norm_layer = functools.partial(nn.BatchNorm2d, affine=True)
15 |     elif norm_type == 'instance':
16 |         norm_layer = functools.partial(nn.InstanceNorm2d, affine=False, track_running_stats=True)
17 |     elif norm_type == 'none':
18 |         norm_layer = None
19 |     else:
20 |         raise NotImplementedError('normalization layer [%s] is not found' % norm_type)
21 |     return norm_layer
22 | 
23 | 
24 | def get_scheduler(optimizer, opt):
25 |     if opt.lr_policy == 'lambda':
26 |         def lambda_rule(epoch):
27 |             lr_l = 1.0 - max(0, epoch + 1 + opt.epoch_count - opt.niter) / float(opt.niter_decay + 1)
28 |             return lr_l
29 |         scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule)
30 |     elif opt.lr_policy == 'step':
31 |         scheduler = lr_scheduler.StepLR(optimizer, step_size=opt.lr_decay_iters, gamma=0.1)
32 |     elif opt.lr_policy == 'plateau':
33 |         scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.2, threshold=0.01, patience=5)
34 |     else:
35 |         raise NotImplementedError('learning rate policy [%s] is not implemented' % opt.lr_policy)
36 |     return scheduler
37 | 
38 | 
39 | def init_weights(net, init_type='normal', gain=0.02):
40 |     def init_func(m):
41 |         classname = m.__class__.__name__
42 |         if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):
43 |             if init_type == 'normal':
44 |                 init.normal_(m.weight.data, 0.0, gain)
45 |             elif init_type == 'xavier':
46 |                 init.xavier_normal_(m.weight.data, gain=gain)
47 |             elif init_type == 'kaiming':
48 |                 init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
49 |             elif init_type == 'orthogonal':
50 |                 init.orthogonal_(m.weight.data, gain=gain)
51 |             else:
52 |                 raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
53 |             if hasattr(m, 'bias') and m.bias is not None:
54 |                 init.constant_(m.bias.data, 0.0)
55 |         elif classname.find('BatchNorm2d') != -1:
56 |             init.normal_(m.weight.data, 1.0, gain)
57 |             init.constant_(m.bias.data, 0.0)
58 | 
59 |     print('initialize network with %s' % init_type)
60 |     net.apply(init_func)
61 | 
62 | 
63 | def init_net(net, init_type='normal', init_gain=0.02, gpu_ids=[]):
64 |     if len(gpu_ids) > 0:
65 |         assert(torch.cuda.is_available())
66 |         net.to(gpu_ids[0])
67 |         net = torch.nn.DataParallel(net, gpu_ids)
68 |     init_weights(net, init_type, gain=init_gain)
69 |     return net
70 | 
71 | 
72 | def define_G(input_nc, output_nc, ngf, which_model_netG, norm='batch', use_dropout=False, init_type='normal', init_gain=0.02, gpu_ids=[]):
73 |     netG = None
74 |     norm_layer = get_norm_layer(norm_type=norm)
75 | 
76 |     if which_model_netG == 'resnet_9blocks':
77 |         netG = ResnetGenerator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, n_blocks=9)
78 |     elif which_model_netG == 'resnet_6blocks':
79 |         netG = ResnetGenerator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, n_blocks=6)
80 |     elif which_model_netG == 'unet_128':
81 |         netG = UnetGenerator(input_nc, output_nc, 7, ngf, norm_layer=norm_layer, use_dropout=use_dropout)
82 |     elif which_model_netG == 'unet_256':
83 |         netG = UnetGenerator(input_nc, output_nc, 8, ngf, norm_layer=norm_layer, use_dropout=use_dropout)
84 |     else:
85 |         raise NotImplementedError('Generator model name [%s] is not recognized' % which_model_netG)
86 |     return init_net(netG, init_type, init_gain, gpu_ids)
87 | 
88 | 
89 | def define_D(input_nc, ndf, which_model_netD,
90 |              n_layers_D=3, norm='batch', use_sigmoid=False, init_type='normal', init_gain=0.02, gpu_ids=[]):
91 |     netD = None
92 |     norm_layer = get_norm_layer(norm_type=norm)
93 | 
94 |     if which_model_netD == 'basic':
95 |         netD = NLayerDiscriminator(input_nc, ndf, n_layers=3, norm_layer=norm_layer, use_sigmoid=use_sigmoid)
96 |     elif which_model_netD == 'n_layers':
97 |         netD = NLayerDiscriminator(input_nc, ndf, n_layers_D, norm_layer=norm_layer, use_sigmoid=use_sigmoid)
98 |     elif which_model_netD == 'pixel':
99 |         netD = PixelDiscriminator(input_nc, ndf, norm_layer=norm_layer, use_sigmoid=use_sigmoid)
100 |     else:
101 |         raise NotImplementedError('Discriminator model name [%s] is not recognized' %
102 |                                   which_model_netD)
103 |     return init_net(netD, init_type, init_gain, gpu_ids)
104 | 
105 | 
106 | ##############################################################################
107 | # Classes
108 | ##############################################################################
109 | 
110 | 
111 | # Defines the GAN loss which uses either LSGAN or the regular GAN.
112 | # When LSGAN is used, it is basically the same as MSELoss,
113 | # but it abstracts away the need to create the target label tensor
114 | # that has the same size as the input
115 | class GANLoss(nn.Module):
116 |     def __init__(self, use_lsgan=True, target_real_label=1.0, target_fake_label=0.0):
117 |         super(GANLoss, self).__init__()
118 |         self.register_buffer('real_label', torch.tensor(target_real_label))
119 |         self.register_buffer('fake_label', torch.tensor(target_fake_label))
120 |         if use_lsgan:
121 |             self.loss = nn.MSELoss()
122 |         else:
123 |             self.loss = nn.BCELoss()
124 | 
125 |     def get_target_tensor(self, input, target_is_real):
126 |         if target_is_real:
127 |             target_tensor = self.real_label
128 |         else:
129 |             target_tensor = self.fake_label
130 |         return target_tensor.expand_as(input)
131 | 
132 |     def __call__(self, input, target_is_real):
133 |         target_tensor = self.get_target_tensor(input, target_is_real)
134 |         return self.loss(input, target_tensor)
135 | 
136 | 
137 | # Defines the generator that consists of Resnet blocks between a few
138 | # downsampling/upsampling operations.
139 | # Code and idea originally from Justin Johnson's architecture.
140 | # https://github.com/jcjohnson/fast-neural-style/
141 | class ResnetGenerator(nn.Module):
142 |     def __init__(self, input_nc, output_nc, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False, n_blocks=6, padding_type='reflect'):
143 |         assert(n_blocks >= 0)
144 |         super(ResnetGenerator, self).__init__()
145 |         self.input_nc = input_nc
146 |         self.output_nc = output_nc
147 |         self.ngf = ngf
148 |         if type(norm_layer) == functools.partial:
149 |             use_bias = norm_layer.func == nn.InstanceNorm2d
150 |         else:
151 |             use_bias = norm_layer == nn.InstanceNorm2d
152 | 
153 |         model = [nn.ReflectionPad2d(3),
154 |                  nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0,
155 |                            bias=use_bias),
156 |                  norm_layer(ngf),
157 |                  nn.ReLU(True)]
158 | 
159 |         n_downsampling = 2
160 |         for i in range(n_downsampling):
161 |             mult = 2**i
162 |             model += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3,
163 |                                 stride=2, padding=1, bias=use_bias),
164 |                       norm_layer(ngf * mult * 2),
165 |                       nn.ReLU(True)]
166 | 
167 |         mult = 2**n_downsampling
168 |         for i in range(n_blocks):
169 |             model += [ResnetBlock(ngf * mult, padding_type=padding_type, norm_layer=norm_layer, use_dropout=use_dropout, use_bias=use_bias)]
170 | 
171 |         for i in range(n_downsampling):
172 |             mult = 2**(n_downsampling - i)
173 |             model += [nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2),
174 |                                          kernel_size=3, stride=2,
175 |                                          padding=1, output_padding=1,
176 |                                          bias=use_bias),
177 |                       norm_layer(int(ngf * mult / 2)),
178 |                       nn.ReLU(True)]
179 |         model += [nn.ReflectionPad2d(3)]
180 |         model += [nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)]
181 |         model += [nn.Tanh()]
182 | 
183 |         self.model = nn.Sequential(*model)
184 | 
185 |     def forward(self, input):
186 |         return self.model(input)
187 | 
188 | 
189 | # Define a resnet block
190 | class ResnetBlock(nn.Module):
191 |     def __init__(self, dim, padding_type, norm_layer, use_dropout, use_bias):
192 | 
super(ResnetBlock, self).__init__() 193 | self.conv_block = self.build_conv_block(dim, padding_type, norm_layer, use_dropout, use_bias) 194 | 195 | def build_conv_block(self, dim, padding_type, norm_layer, use_dropout, use_bias): 196 | conv_block = [] 197 | p = 0 198 | if padding_type == 'reflect': 199 | conv_block += [nn.ReflectionPad2d(1)] 200 | elif padding_type == 'replicate': 201 | conv_block += [nn.ReplicationPad2d(1)] 202 | elif padding_type == 'zero': 203 | p = 1 204 | else: 205 | raise NotImplementedError('padding [%s] is not implemented' % padding_type) 206 | 207 | conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias), 208 | norm_layer(dim), 209 | nn.ReLU(True)] 210 | if use_dropout: 211 | conv_block += [nn.Dropout(0.5)] 212 | 213 | p = 0 214 | if padding_type == 'reflect': 215 | conv_block += [nn.ReflectionPad2d(1)] 216 | elif padding_type == 'replicate': 217 | conv_block += [nn.ReplicationPad2d(1)] 218 | elif padding_type == 'zero': 219 | p = 1 220 | else: 221 | raise NotImplementedError('padding [%s] is not implemented' % padding_type) 222 | conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias), 223 | norm_layer(dim)] 224 | 225 | return nn.Sequential(*conv_block) 226 | 227 | def forward(self, x): 228 | out = x + self.conv_block(x) 229 | return out 230 | 231 | 232 | # Defines the Unet generator. 233 | # |num_downs|: number of downsamplings in UNet. For example, 234 | # if |num_downs| == 7, image of size 128x128 will become of size 1x1 235 | # at the bottleneck 236 | class UnetGenerator(nn.Module): 237 | def __init__(self, input_nc, output_nc, num_downs, ngf=64, 238 | norm_layer=nn.BatchNorm2d, use_dropout=False): 239 | super(UnetGenerator, self).__init__() 240 | 241 | # construct unet structure 242 | unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=None, norm_layer=norm_layer, innermost=True) 243 | for i in range(num_downs - 5): 244 | unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer, use_dropout=use_dropout) 245 | unet_block = UnetSkipConnectionBlock(ngf * 4, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer) 246 | unet_block = UnetSkipConnectionBlock(ngf * 2, ngf * 4, input_nc=None, submodule=unet_block, norm_layer=norm_layer) 247 | unet_block = UnetSkipConnectionBlock(ngf, ngf * 2, input_nc=None, submodule=unet_block, norm_layer=norm_layer) 248 | unet_block = UnetSkipConnectionBlock(output_nc, ngf, input_nc=input_nc, submodule=unet_block, outermost=True, norm_layer=norm_layer) 249 | 250 | self.model = unet_block 251 | 252 | def forward(self, input): 253 | return self.model(input) 254 | 255 | 256 | # Defines the submodule with skip connection. 
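# (editor's note, inferred from forward() below: every non-outermost block returns
# torch.cat([x, self.model(x)], 1), which is what gives the U-Net its skip connections.)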
257 | # X -------------------identity---------------------- X 258 | # |-- downsampling -- |submodule| -- upsampling --| 259 | class UnetSkipConnectionBlock(nn.Module): 260 | def __init__(self, outer_nc, inner_nc, input_nc=None, 261 | submodule=None, outermost=False, innermost=False, norm_layer=nn.BatchNorm2d, use_dropout=False): 262 | super(UnetSkipConnectionBlock, self).__init__() 263 | self.outermost = outermost 264 | if type(norm_layer) == functools.partial: 265 | use_bias = norm_layer.func == nn.InstanceNorm2d 266 | else: 267 | use_bias = norm_layer == nn.InstanceNorm2d 268 | if input_nc is None: 269 | input_nc = outer_nc 270 | downconv = nn.Conv2d(input_nc, inner_nc, kernel_size=4, 271 | stride=2, padding=1, bias=use_bias) 272 | downrelu = nn.LeakyReLU(0.2, True) 273 | downnorm = norm_layer(inner_nc) 274 | uprelu = nn.ReLU(True) 275 | upnorm = norm_layer(outer_nc) 276 | 277 | if outermost: 278 | upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc, 279 | kernel_size=4, stride=2, 280 | padding=1) 281 | down = [downconv] 282 | up = [uprelu, upconv, nn.Tanh()] 283 | model = down + [submodule] + up 284 | elif innermost: 285 | upconv = nn.ConvTranspose2d(inner_nc, outer_nc, 286 | kernel_size=4, stride=2, 287 | padding=1, bias=use_bias) 288 | down = [downrelu, downconv] 289 | up = [uprelu, upconv, upnorm] 290 | model = down + up 291 | else: 292 | upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc, 293 | kernel_size=4, stride=2, 294 | padding=1, bias=use_bias) 295 | down = [downrelu, downconv, downnorm] 296 | up = [uprelu, upconv, upnorm] 297 | 298 | if use_dropout: 299 | model = down + [submodule] + up + [nn.Dropout(0.5)] 300 | else: 301 | model = down + [submodule] + up 302 | 303 | self.model = nn.Sequential(*model) 304 | 305 | def forward(self, x): 306 | if self.outermost: 307 | return self.model(x) 308 | else: 309 | return torch.cat([x, self.model(x)], 1) 310 | 311 | 312 | # Defines the PatchGAN discriminator with the specified arguments. 
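# (editor's note) with the default n_layers=3, the stride-2 4x4 convolutions reduce a
# 256x256 input to a 30x30 map of patch scores, each covering a receptive field of
# roughly 70x70 pixels; GANLoss above is averaged over these per-patch decisions.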
313 | class NLayerDiscriminator(nn.Module): 314 | def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d, use_sigmoid=False): 315 | super(NLayerDiscriminator, self).__init__() 316 | if type(norm_layer) == functools.partial: 317 | use_bias = norm_layer.func == nn.InstanceNorm2d 318 | else: 319 | use_bias = norm_layer == nn.InstanceNorm2d 320 | 321 | kw = 4 322 | padw = 1 323 | sequence = [ 324 | nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), 325 | nn.LeakyReLU(0.2, True) 326 | ] 327 | 328 | nf_mult = 1 329 | nf_mult_prev = 1 330 | for n in range(1, n_layers): 331 | nf_mult_prev = nf_mult 332 | nf_mult = min(2**n, 8) 333 | sequence += [ 334 | nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, 335 | kernel_size=kw, stride=2, padding=padw, bias=use_bias), 336 | norm_layer(ndf * nf_mult), 337 | nn.LeakyReLU(0.2, True) 338 | ] 339 | 340 | nf_mult_prev = nf_mult 341 | nf_mult = min(2**n_layers, 8) 342 | sequence += [ 343 | nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, 344 | kernel_size=kw, stride=1, padding=padw, bias=use_bias), 345 | norm_layer(ndf * nf_mult), 346 | nn.LeakyReLU(0.2, True) 347 | ] 348 | 349 | sequence += [nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)] 350 | 351 | if use_sigmoid: 352 | sequence += [nn.Sigmoid()] 353 | 354 | self.model = nn.Sequential(*sequence) 355 | 356 | def forward(self, input): 357 | return self.model(input) 358 | 359 | 360 | class PixelDiscriminator(nn.Module): 361 | def __init__(self, input_nc, ndf=64, norm_layer=nn.BatchNorm2d, use_sigmoid=False): 362 | super(PixelDiscriminator, self).__init__() 363 | if type(norm_layer) == functools.partial: 364 | use_bias = norm_layer.func == nn.InstanceNorm2d 365 | else: 366 | use_bias = norm_layer == nn.InstanceNorm2d 367 | 368 | self.net = [ 369 | nn.Conv2d(input_nc, ndf, kernel_size=1, stride=1, padding=0), 370 | nn.LeakyReLU(0.2, True), 371 | nn.Conv2d(ndf, ndf * 2, kernel_size=1, stride=1, padding=0, bias=use_bias), 372 | norm_layer(ndf * 2), 373 | nn.LeakyReLU(0.2, True), 374 | nn.Conv2d(ndf * 2, 1, kernel_size=1, stride=1, padding=0, bias=use_bias)] 375 | 376 | if use_sigmoid: 377 | self.net.append(nn.Sigmoid()) 378 | 379 | self.net = nn.Sequential(*self.net) 380 | 381 | def forward(self, input): 382 | return self.net(input) 383 | --------------------------------------------------------------------------------