├── image.jpg
├── preview.jpg
├── angulargan.jpg
├── angulargan
│   ├── scripts
│   │   ├── install_deps.sh
│   │   └── conda_deps.sh
│   ├── requirements.txt
│   ├── models
│   │   ├── angular_loss
│   │   │   ├── __pycache__
│   │   │   │   └── __init__.cpython-36.pyc
│   │   │   └── __init__.py
│   │   ├── __init__.py
│   │   ├── test_model.py
│   │   ├── angular_gan_model.py
│   │   ├── angular_gan_v2_model.py
│   │   ├── base_model.py
│   │   └── networks.py
│   ├── runtest.sh
│   ├── data
│   │   ├── base_data_loader.py
│   │   ├── single_dataset.py
│   │   ├── image_folder.py
│   │   ├── unaligned_dataset.py
│   │   ├── __init__.py
│   │   ├── aligned_dataset.py
│   │   └── base_dataset.py
│   ├── run.sh
│   ├── options
│   │   ├── test_options.py
│   │   ├── train_options.py
│   │   └── base_options.py
│   ├── util
│   │   ├── image_pool.py
│   │   ├── util.py
│   │   ├── html.py
│   │   ├── get_data.py
│   │   └── visualizer.py
│   ├── test.py
│   ├── docs
│   │   ├── qa.md
│   │   ├── tips.md
│   │   └── datasets.md
│   ├── datasets
│   │   ├── combine_A_and_B.py
│   │   └── make_dataset_aligned.py
│   └── train.py
├── real_illum_11346_Normalized.mat
├── LICENSE
├── generate_tinted_images.m
└── README.md
/image.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/acecreamu/angularGAN/HEAD/image.jpg
--------------------------------------------------------------------------------
/preview.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/acecreamu/angularGAN/HEAD/preview.jpg
--------------------------------------------------------------------------------
/angulargan.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/acecreamu/angularGAN/HEAD/angulargan.jpg
--------------------------------------------------------------------------------
/angulargan/scripts/install_deps.sh:
--------------------------------------------------------------------------------
1 | set -ex
2 | pip install visdom
3 | pip install dominate
4 |
--------------------------------------------------------------------------------
/angulargan/requirements.txt:
--------------------------------------------------------------------------------
1 | torch>=0.4.0
2 | torchvision>=0.2.1
3 | dominate>=2.3.1
4 | visdom>=0.1.8.3
5 |
--------------------------------------------------------------------------------
/real_illum_11346_Normalized.mat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/acecreamu/angularGAN/HEAD/real_illum_11346_Normalized.mat
--------------------------------------------------------------------------------
/angulargan/models/angular_loss/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/acecreamu/angularGAN/HEAD/angulargan/models/angular_loss/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/angulargan/runtest.sh:
--------------------------------------------------------------------------------
1 | a=$(ls datasets/facades/test | wc -l)
2 | echo $a
3 | python test.py --dataroot ./datasets/facades --name angular_gan --model angular_gan --which_direction BtoA --display_id -1 --how_many $a
4 |
--------------------------------------------------------------------------------
/angulargan/data/base_data_loader.py:
--------------------------------------------------------------------------------
1 | class BaseDataLoader():
2 | def __init__(self):
3 | pass
4 |
5 | def initialize(self, opt):
6 | self.opt = opt
7 | pass
8 |
 9 |     def load_data(self):
10 | return None
11 |
--------------------------------------------------------------------------------
/angulargan/scripts/conda_deps.sh:
--------------------------------------------------------------------------------
1 | set -ex
2 | conda install numpy pyyaml mkl mkl-include setuptools cmake cffi typing
3 | conda install pytorch torchvision -c pytorch # add cuda90 if CUDA 9
4 | conda install visdom dominate -c conda-forge # install visdom and dominate
5 |
--------------------------------------------------------------------------------
/angulargan/run.sh:
--------------------------------------------------------------------------------
1 | python train.py --checkpoints_dir ./checkpoints --dataroot ./datasets/facades --name angular_gan --model angular_gan --lambda_Angular 1 --lambda_L1 1 --which_model_netG unet_256 --which_direction BtoA --dataset_mode aligned --no_lsgan --norm batch --pool_size 0 --save_epoch_freq 10
2 |
--------------------------------------------------------------------------------
/angulargan/models/angular_loss/__init__.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from math import pi
3 |
4 | class angular_loss(torch.nn.Module):
5 |
6 | def __init__(self):
7 | super(angular_loss,self).__init__()
8 |
9 | def forward(self, illum_gt, illum_pred):
10 | # img_gt = img_input / illum_gt
11 | # illum_gt = img_input / img_gt
12 | # illum_pred = img_input / img_output
13 |
14 | # ACOS
15 | cos_between = torch.nn.CosineSimilarity(dim=1)
16 | cos = cos_between(illum_gt, illum_pred)
17 | cos = torch.clamp(cos,-0.99999, 0.99999)
18 | loss = torch.mean(torch.acos(cos)) * 180 / pi
19 |
20 | # MSE
21 | # loss = torch.mean((illum_gt - illum_pred)**2)
22 |
23 | # 1 - COS
24 | # loss = 1 - torch.mean(cos)
25 |
26 | # 1 - COS^2
27 | # loss = 1 - torch.mean(cos**2)
28 | return loss
29 |
--------------------------------------------------------------------------------
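
A quick sanity check of the loss defined above: two orthogonal illuminant vectors should come out at roughly 90 degrees. A minimal sketch, assuming the repository root is on `PYTHONPATH`:

```python
import torch
from models.angular_loss import angular_loss

criterion = angular_loss()
gt   = torch.tensor([[1.0, 0.0, 0.0]])   # ground-truth illuminants, shape (N, 3)
pred = torch.tensor([[0.0, 1.0, 0.0]])   # orthogonal prediction
print(criterion(gt, pred))               # ~90.0: mean angular error in degrees
```
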
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2019
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/angulargan/options/test_options.py:
--------------------------------------------------------------------------------
1 | from .base_options import BaseOptions
2 |
3 |
4 | class TestOptions(BaseOptions):
5 | def initialize(self, parser):
6 | parser = BaseOptions.initialize(self, parser)
7 | parser.add_argument('--ntest', type=int, default=float("inf"), help='# of test examples.')
8 | parser.add_argument('--results_dir', type=str, default='./results/', help='saves results here.')
9 | parser.add_argument('--aspect_ratio', type=float, default=1.0, help='aspect ratio of result images')
10 | parser.add_argument('--phase', type=str, default='test', help='train, val, test, etc')
11 | parser.add_argument('--which_epoch', type=str, default='latest', help='which epoch to load? set to latest to use latest cached model')
12 | parser.add_argument('--how_many', type=int, default=50, help='how many test images to run')
13 |
14 | parser.set_defaults(model='test')
15 | # To avoid cropping, the loadSize should be the same as fineSize
16 | parser.set_defaults(loadSize=parser.get_default('fineSize'))
17 |
18 | self.isTrain = False
19 | return parser
20 |
--------------------------------------------------------------------------------
/angulargan/util/image_pool.py:
--------------------------------------------------------------------------------
1 | import random
2 | import torch
3 |
4 |
5 | class ImagePool():
6 | def __init__(self, pool_size):
7 | self.pool_size = pool_size
8 | if self.pool_size > 0:
9 | self.num_imgs = 0
10 | self.images = []
11 |
12 | def query(self, images):
13 | if self.pool_size == 0:
14 | return images
15 | return_images = []
16 | for image in images:
17 | image = torch.unsqueeze(image.data, 0)
18 | if self.num_imgs < self.pool_size:
19 | self.num_imgs = self.num_imgs + 1
20 | self.images.append(image)
21 | return_images.append(image)
22 | else:
23 | p = random.uniform(0, 1)
24 | if p > 0.5:
25 | random_id = random.randint(0, self.pool_size - 1) # randint is inclusive
26 | tmp = self.images[random_id].clone()
27 | self.images[random_id] = image
28 | return_images.append(tmp)
29 | else:
30 | return_images.append(image)
31 | return_images = torch.cat(return_images, 0)
32 | return return_images
33 |
--------------------------------------------------------------------------------
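
For context, `query` implements a history buffer: while the pool is filling it returns the incoming fakes, and afterwards it returns, with probability 0.5, an older generated image in place of the current one. A minimal usage sketch with a stand-in batch:

```python
import torch
from util.image_pool import ImagePool

pool = ImagePool(pool_size=50)
fake_B = torch.randn(4, 3, 256, 256)   # stand-in for a batch of generator outputs
for_D = pool.query(fake_B)             # fed to the discriminator instead of fake_B
```
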
/angulargan/data/single_dataset.py:
--------------------------------------------------------------------------------
1 | import os.path
2 | from data.base_dataset import BaseDataset, get_transform
3 | from data.image_folder import make_dataset
4 | from PIL import Image
5 |
6 |
7 | class SingleDataset(BaseDataset):
8 | @staticmethod
9 | def modify_commandline_options(parser, is_train):
10 | return parser
11 |
12 | def initialize(self, opt):
13 | self.opt = opt
14 | self.root = opt.dataroot
15 | self.dir_A = os.path.join(opt.dataroot)
16 |
17 | self.A_paths = make_dataset(self.dir_A)
18 |
19 | self.A_paths = sorted(self.A_paths)
20 |
21 | self.transform = get_transform(opt)
22 |
23 | def __getitem__(self, index):
24 | A_path = self.A_paths[index]
25 | A_img = Image.open(A_path).convert('RGB')
26 | A = self.transform(A_img)
27 | if self.opt.which_direction == 'BtoA':
28 | input_nc = self.opt.output_nc
29 | else:
30 | input_nc = self.opt.input_nc
31 |
32 | if input_nc == 1: # RGB to gray
33 | tmp = A[0, ...] * 0.299 + A[1, ...] * 0.587 + A[2, ...] * 0.114
34 | A = tmp.unsqueeze(0)
35 |
36 | return {'A': A, 'A_paths': A_path}
37 |
38 | def __len__(self):
39 | return len(self.A_paths)
40 |
41 | def name(self):
42 | return 'SingleImageDataset'
43 |
--------------------------------------------------------------------------------
/angulargan/models/__init__.py:
--------------------------------------------------------------------------------
1 | import importlib
2 | from models.base_model import BaseModel
3 |
4 |
5 | def find_model_using_name(model_name):
6 | # Given the option --model [modelname],
7 | # the file "models/modelname_model.py"
8 | # will be imported.
9 | model_filename = "models." + model_name + "_model"
10 | modellib = importlib.import_module(model_filename)
11 |
12 | # In the file, the class called ModelNameModel() will
13 | # be instantiated. It has to be a subclass of BaseModel,
14 | # and it is case-insensitive.
15 | model = None
16 | target_model_name = model_name.replace('_', '') + 'model'
17 | for name, cls in modellib.__dict__.items():
18 | if name.lower() == target_model_name.lower() \
19 | and issubclass(cls, BaseModel):
20 | model = cls
21 |
22 | if model is None:
23 | print("In %s.py, there should be a subclass of BaseModel with class name that matches %s in lowercase." % (model_filename, target_model_name))
24 | exit(0)
25 |
26 | return model
27 |
28 |
29 | def get_option_setter(model_name):
30 | model_class = find_model_using_name(model_name)
31 | return model_class.modify_commandline_options
32 |
33 |
34 | def create_model(opt):
35 | model = find_model_using_name(opt.model)
36 | instance = model()
37 | instance.initialize(opt)
38 | print("model [%s] was created" % (instance.name()))
39 | return instance
40 |
--------------------------------------------------------------------------------
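
The naming convention documented in the comments above can be exercised directly; a minimal sketch, run from the repository root:

```python
from models import find_model_using_name

cls = find_model_using_name('angular_gan')  # imports models/angular_gan_model.py
print(cls.__name__)                         # AngularGANModel
```
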
/angulargan/test.py:
--------------------------------------------------------------------------------
1 | import os
2 | from options.test_options import TestOptions
3 | from data import CreateDataLoader
4 | from models import create_model
5 | from util.visualizer import save_images
6 | from util import html
7 |
8 |
9 | if __name__ == '__main__':
10 | opt = TestOptions().parse()
11 | opt.nThreads = 1 # test code only supports nThreads = 1
12 | opt.batchSize = 1 # test code only supports batchSize = 1
13 | opt.serial_batches = True # no shuffle
14 | opt.no_flip = True # no flip
15 | opt.display_id = -1 # no visdom display
16 | data_loader = CreateDataLoader(opt)
17 | dataset = data_loader.load_data()
18 | model = create_model(opt)
19 | model.setup(opt)
20 | # create website
21 | web_dir = os.path.join(opt.results_dir, opt.name, '%s_%s' % (opt.phase, opt.which_epoch))
22 | webpage = html.HTML(web_dir, 'Experiment = %s, Phase = %s, Epoch = %s' % (opt.name, opt.phase, opt.which_epoch))
23 | # test
24 | for i, data in enumerate(dataset):
25 | if i >= opt.how_many:
26 | break
27 | model.set_input(data)
28 | model.test()
29 | visuals = model.get_current_visuals()
30 | img_path = model.get_image_paths()
31 | if i % 5 == 0:
32 | print('processing (%04d)-th image... %s' % (i, img_path))
33 | save_images(webpage, visuals, img_path, aspect_ratio=opt.aspect_ratio, width=opt.display_winsize)
34 |
35 | webpage.save()
36 |
--------------------------------------------------------------------------------
/angulargan/docs/qa.md:
--------------------------------------------------------------------------------
1 | ## Frequently Asked Questions
  2 | - The color gets inverted from the beginning of training: the authors also observed that the generator unnecessarily inverts the color of the input image early in training, and then never learns to undo the inversion. In this case, two things can be tried. First, try using the identity loss `--identity 1.0` or `--identity 0.1`. We observed that the identity loss makes the generator more conservative, so it makes fewer unnecessary changes. However, because of this, the change may not be as dramatic. Second, try a smaller variance when initializing weights by changing `--init_gain`. We observed that a smaller variance in weight initialization results in less color inversion.
  3 | - Out of memory error: CycleGAN is quite memory-intensive because it needs two generator networks and two discriminator networks. If you would like to generate high-resolution images, you can do the following. First, train CycleGAN on cropped images of the training set. Please be careful not to change the aspect ratio or the scale of the original image, as this can lead to a training/test gap. You can usually do this by using the `--resize_or_crop crop` option, or `--resize_or_crop scale_width_and_crop`. Then at test time, load only one generator to generate the results in one direction only. This greatly saves memory because you are not loading the discriminators and the other generator for the opposite direction. You can probably input the whole image (we have done image generation at 1024x512 resolution). You can do this using `--model test --dataroot [path to the directory containing the actual images (ex. ./datasets/horse2zebra/trainA)] --model_suffix _A`. For more explanation, please see https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/test_model.py#L16.
4 |
--------------------------------------------------------------------------------
/generate_tinted_images.m:
--------------------------------------------------------------------------------
1 | %% Load data
2 | fID = fopen('Source_Image/file.lst');
3 | list = textscan(fID,'%s','delimiter','\n');
4 | list = [list{1,1}];
5 | fclose(fID);
6 |
7 | response = load('Source_Image/real_illum_11346_Normalized.mat');
8 | response = response.real_rgb;
9 |
10 | %% Train-Test partition
11 | rng('default')
12 | CVparts = cvpartition(11346,'KFold',15);
13 | idx = test(CVparts,1);
14 |
15 | %% Main loop
16 | for i = 1:size(list,1)
17 | img = imread(list{i,1});
18 | img = img(1:224,1:224,:);
19 |
 20 |     % color correction according to gt illumination
21 | e = reshape(response(i,:),1,1,3);
22 | img = uint8(double(img)./e);
23 |
24 | img = imresize(img,[256,256]);
25 | [img_, map] = tint_image(img);
26 | img = cat(2,img_,img);
27 | filename = strcat('train\',sprintf('%05d.jpg', i));
28 | if idx(i)
29 | filename = strcat('test\',sprintf('%05d.jpg', i));
30 | end
31 | imwrite(img,filename);
32 | filename = strcat('tint_maps\',sprintf('%05d.jpg', i));
33 | imwrite(map,filename);
34 | end
35 |
36 | %% Tint function
37 | function [img_, map] = tint_image(Im)
38 | map = zeros(256,256,3);
39 | [X,Y] = meshgrid(1:256,1:256);
40 | colors = [0, 0, 255; 0, 255, 0; 255, 0, 0];
41 | for i = 1:3
42 | mu = rand(1,2)*256;
43 | sigma = [rand(1)*25600 0; 0 rand(1)*25600];
44 | tint = mvnpdf([X(:) Y(:)], mu, sigma);
45 | tint = reshape(tint, 256, 256);
46 | tint = tint / max(tint(:));
47 | tint = repmat(tint,[1 1 3]);
48 | tint = tint .* reshape(colors(i,:)/norm(colors(i,:)), 1,1,3);
49 | tint = 1 - tint;
50 | map = map + tint ./ 3;
51 | end
52 | map = imresize(map, [size(Im,1) size(Im,2)]);
53 |
54 | img_ = uint8(map .* double(Im));
 55 |     map = uint8(map * 255);
56 | end
57 |
--------------------------------------------------------------------------------
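
For readers without MATLAB, one Gaussian tint layer from `tint_image` can be re-expressed in NumPy/SciPy. This is an illustrative sketch of a single layer (the script above sums three such layers), not code shipped with the repo:

```python
import numpy as np
from scipy.stats import multivariate_normal

X, Y = np.meshgrid(np.arange(1, 257), np.arange(1, 257))
mu = np.random.rand(2) * 256                     # random bump center
cov = np.diag(np.random.rand(2) * 25600)         # axis-aligned covariance
tint = multivariate_normal(mu, cov).pdf(np.dstack([X, Y]))
tint /= tint.max()                               # peak-normalize to [0, 1]
color = np.array([0.0, 0.0, 1.0])                # unit-norm blue direction
layer = 1.0 - tint[..., None] * color            # locally attenuates one channel
```
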
/angulargan/util/util.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 | import torch
3 | import numpy as np
4 | from PIL import Image
5 | import os
6 |
7 |
8 | # Converts a Tensor into an image array (numpy)
9 | # |imtype|: the desired type of the converted numpy array
10 | def tensor2im(input_image, imtype=np.uint8):
11 | if isinstance(input_image, torch.Tensor):
12 | image_tensor = input_image.data
13 | else:
14 | return input_image
15 | image_numpy = image_tensor[0].cpu().float().numpy()
16 | if image_numpy.shape[0] == 1:
17 | image_numpy = np.tile(image_numpy, (3, 1, 1))
18 | image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0
19 | return image_numpy.astype(imtype)
20 |
21 |
22 | def diagnose_network(net, name='network'):
23 | mean = 0.0
24 | count = 0
25 | for param in net.parameters():
26 | if param.grad is not None:
27 | mean += torch.mean(torch.abs(param.grad.data))
28 | count += 1
29 | if count > 0:
30 | mean = mean / count
31 | print(name)
32 | print(mean)
33 |
34 |
35 | def save_image(image_numpy, image_path):
36 | image_pil = Image.fromarray(image_numpy)
37 | image_pil.save(image_path)
38 |
39 |
40 | def print_numpy(x, val=True, shp=False):
41 | x = x.astype(np.float64)
42 | if shp:
43 | print('shape,', x.shape)
44 | if val:
45 | x = x.flatten()
46 | print('mean = %3.3f, min = %3.3f, max = %3.3f, median = %3.3f, std=%3.3f' % (
47 | np.mean(x), np.min(x), np.max(x), np.median(x), np.std(x)))
48 |
49 |
50 | def mkdirs(paths):
51 | if isinstance(paths, list) and not isinstance(paths, str):
52 | for path in paths:
53 | mkdir(path)
54 | else:
55 | mkdir(paths)
56 |
57 |
58 | def mkdir(path):
59 | if not os.path.exists(path):
60 | os.makedirs(path)
61 |
--------------------------------------------------------------------------------
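
A small round-trip illustration of the two helpers above; `fake_B` here is a stand-in for a `(1, 3, H, W)` generator output in `[-1, 1]`:

```python
import torch
from util.util import tensor2im, save_image

fake_B = torch.rand(1, 3, 256, 256) * 2 - 1  # stand-in generator output in [-1, 1]
img = tensor2im(fake_B)                      # first image of the batch, (H, W, 3) uint8
save_image(img, 'fake_B.png')                # written via PIL
```
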
/angulargan/models/test_model.py:
--------------------------------------------------------------------------------
1 | from .base_model import BaseModel
2 | from . import networks
3 | from .cycle_gan_model import CycleGANModel
4 |
5 |
6 | class TestModel(BaseModel):
7 | def name(self):
8 | return 'TestModel'
9 |
10 | @staticmethod
11 | def modify_commandline_options(parser, is_train=True):
12 | assert not is_train, 'TestModel cannot be used in train mode'
13 | parser = CycleGANModel.modify_commandline_options(parser, is_train=False)
14 | parser.set_defaults(dataset_mode='single')
15 |
16 | parser.add_argument('--model_suffix', type=str, default='',
17 | help='In checkpoints_dir, [which_epoch]_net_G[model_suffix].pth will'
18 | ' be loaded as the generator of TestModel')
19 |
20 | return parser
21 |
22 | def initialize(self, opt):
23 | assert(not opt.isTrain)
24 | BaseModel.initialize(self, opt)
25 |
26 | # specify the training losses you want to print out. The program will call base_model.get_current_losses
27 | self.loss_names = []
28 | # specify the images you want to save/display. The program will call base_model.get_current_visuals
29 | self.visual_names = ['real_A', 'fake_B']
30 | # specify the models you want to save to the disk. The program will call base_model.save_networks and base_model.load_networks
31 | self.model_names = ['G' + opt.model_suffix]
32 |
33 | self.netG = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.which_model_netG,
34 | opt.norm, not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids)
35 |
36 | # assigns the model to self.netG_[suffix] so that it can be loaded
37 | # please see BaseModel.load_networks
38 | setattr(self, 'netG' + opt.model_suffix, self.netG)
39 |
40 | def set_input(self, input):
41 | # we need to use single_dataset mode
42 | self.real_A = input['A'].to(self.device)
43 | self.image_paths = input['A_paths']
44 |
45 | def forward(self):
46 | self.fake_B = self.netG(self.real_A)
47 |
--------------------------------------------------------------------------------
/angulargan/util/html.py:
--------------------------------------------------------------------------------
1 | import dominate
2 | from dominate.tags import *
3 | import os
4 |
5 |
6 | class HTML:
7 | def __init__(self, web_dir, title, reflesh=0):
8 | self.title = title
9 | self.web_dir = web_dir
10 | self.img_dir = os.path.join(self.web_dir, 'images')
11 | if not os.path.exists(self.web_dir):
12 | os.makedirs(self.web_dir)
13 | if not os.path.exists(self.img_dir):
14 | os.makedirs(self.img_dir)
15 | # print(self.img_dir)
16 |
17 | self.doc = dominate.document(title=title)
18 | if reflesh > 0:
19 | with self.doc.head:
 20 |                 meta(http_equiv="refresh", content=str(reflesh))  # "refresh" is the valid http-equiv value
21 |
22 | def get_image_dir(self):
23 | return self.img_dir
24 |
 25 |     def add_header(self, text):
 26 |         with self.doc:
 27 |             h3(text)
28 |
29 | def add_table(self, border=1):
30 | self.t = table(border=border, style="table-layout: fixed;")
31 | self.doc.add(self.t)
32 |
33 | def add_images(self, ims, txts, links, width=400):
34 | self.add_table()
35 | with self.t:
36 | with tr():
37 | for im, txt, link in zip(ims, txts, links):
38 | with td(style="word-wrap: break-word;", halign="center", valign="top"):
39 | with p():
40 | with a(href=os.path.join('images', link)):
41 | img(style="width:%dpx" % width, src=os.path.join('images', im))
42 | br()
43 | p(txt)
44 |
45 | def save(self):
46 | html_file = '%s/index.html' % self.web_dir
47 | f = open(html_file, 'wt')
48 | f.write(self.doc.render())
49 | f.close()
50 |
51 |
52 | if __name__ == '__main__':
53 | html = HTML('web/', 'test_html')
54 | html.add_header('hello world')
55 |
56 | ims = []
57 | txts = []
58 | links = []
59 | for n in range(4):
60 | ims.append('image_%d.png' % n)
61 | txts.append('text_%d' % n)
62 | links.append('image_%d.png' % n)
63 | html.add_images(ims, txts, links)
64 | html.save()
65 |
--------------------------------------------------------------------------------
/angulargan/data/image_folder.py:
--------------------------------------------------------------------------------
1 | ###############################################################################
2 | # Code from
3 | # https://github.com/pytorch/vision/blob/master/torchvision/datasets/folder.py
  4 | # Modified the original code so that it loads images from the current
  5 | # directory as well as from its subdirectories
6 | ###############################################################################
7 |
8 | import torch.utils.data as data
9 |
10 | from PIL import Image
11 | import os
12 | import os.path
13 |
14 | IMG_EXTENSIONS = [
15 | '.jpg', '.JPG', '.jpeg', '.JPEG',
16 | '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
17 | ]
18 |
19 |
20 | def is_image_file(filename):
21 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
22 |
23 |
24 | def make_dataset(dir):
25 | images = []
26 | assert os.path.isdir(dir), '%s is not a valid directory' % dir
27 |
28 | for root, _, fnames in sorted(os.walk(dir)):
29 | for fname in fnames:
30 | if is_image_file(fname):
31 | path = os.path.join(root, fname)
32 | images.append(path)
33 |
34 | return images
35 |
36 |
37 | def default_loader(path):
38 | return Image.open(path).convert('RGB')
39 |
40 |
41 | class ImageFolder(data.Dataset):
42 |
43 | def __init__(self, root, transform=None, return_paths=False,
44 | loader=default_loader):
45 | imgs = make_dataset(root)
46 | if len(imgs) == 0:
47 | raise(RuntimeError("Found 0 images in: " + root + "\n"
48 | "Supported image extensions are: " +
49 | ",".join(IMG_EXTENSIONS)))
50 |
51 | self.root = root
52 | self.imgs = imgs
53 | self.transform = transform
54 | self.return_paths = return_paths
55 | self.loader = loader
56 |
57 | def __getitem__(self, index):
58 | path = self.imgs[index]
59 | img = self.loader(path)
60 | if self.transform is not None:
61 | img = self.transform(img)
62 | if self.return_paths:
63 | return img, path
64 | else:
65 | return img
66 |
67 | def __len__(self):
68 | return len(self.imgs)
69 |
--------------------------------------------------------------------------------
/angulargan/datasets/combine_A_and_B.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import cv2
4 | import argparse
5 |
6 | parser = argparse.ArgumentParser('create image pairs')
7 | parser.add_argument('--fold_A', dest='fold_A', help='input directory for image A', type=str, default='../dataset/50kshoes_edges')
8 | parser.add_argument('--fold_B', dest='fold_B', help='input directory for image B', type=str, default='../dataset/50kshoes_jpg')
9 | parser.add_argument('--fold_AB', dest='fold_AB', help='output directory', type=str, default='../dataset/test_AB')
10 | parser.add_argument('--num_imgs', dest='num_imgs', help='number of images',type=int, default=1000000)
11 | parser.add_argument('--use_AB', dest='use_AB', help='if true: (0001_A, 0001_B) to (0001_AB)',action='store_true')
12 | args = parser.parse_args()
13 |
14 | for arg in vars(args):
15 | print('[%s] = ' % arg, getattr(args, arg))
16 |
17 | splits = os.listdir(args.fold_A)
18 |
19 | for sp in splits:
20 | img_fold_A = os.path.join(args.fold_A, sp)
21 | img_fold_B = os.path.join(args.fold_B, sp)
22 | img_list = os.listdir(img_fold_A)
23 | if args.use_AB:
24 | img_list = [img_path for img_path in img_list if '_A.' in img_path]
25 |
26 | num_imgs = min(args.num_imgs, len(img_list))
27 | print('split = %s, use %d/%d images' % (sp, num_imgs, len(img_list)))
28 | img_fold_AB = os.path.join(args.fold_AB, sp)
29 | if not os.path.isdir(img_fold_AB):
30 | os.makedirs(img_fold_AB)
31 | print('split = %s, number of images = %d' % (sp, num_imgs))
32 | for n in range(num_imgs):
33 | name_A = img_list[n]
34 | path_A = os.path.join(img_fold_A, name_A)
35 | if args.use_AB:
36 | name_B = name_A.replace('_A.', '_B.')
37 | else:
38 | name_B = name_A
39 | path_B = os.path.join(img_fold_B, name_B)
40 | if os.path.isfile(path_A) and os.path.isfile(path_B):
41 | name_AB = name_A
42 | if args.use_AB:
43 | name_AB = name_AB.replace('_A.', '.') # remove _A
44 | path_AB = os.path.join(img_fold_AB, name_AB)
 45 |             im_A = cv2.imread(path_A, cv2.IMREAD_COLOR)  # cv2.CV_LOAD_IMAGE_COLOR was removed in OpenCV 3
 46 |             im_B = cv2.imread(path_B, cv2.IMREAD_COLOR)
47 | im_AB = np.concatenate([im_A, im_B], 1)
48 | cv2.imwrite(path_AB, im_AB)
49 |
--------------------------------------------------------------------------------
/angulargan/data/unaligned_dataset.py:
--------------------------------------------------------------------------------
1 | import os.path
2 | from data.base_dataset import BaseDataset, get_transform
3 | from data.image_folder import make_dataset
4 | from PIL import Image
5 | import random
6 |
7 |
8 | class UnalignedDataset(BaseDataset):
9 | @staticmethod
10 | def modify_commandline_options(parser, is_train):
11 | return parser
12 |
13 | def initialize(self, opt):
14 | self.opt = opt
15 | self.root = opt.dataroot
16 | self.dir_A = os.path.join(opt.dataroot, opt.phase + 'A')
17 | self.dir_B = os.path.join(opt.dataroot, opt.phase + 'B')
18 |
19 | self.A_paths = make_dataset(self.dir_A)
20 | self.B_paths = make_dataset(self.dir_B)
21 |
22 | self.A_paths = sorted(self.A_paths)
23 | self.B_paths = sorted(self.B_paths)
24 | self.A_size = len(self.A_paths)
25 | self.B_size = len(self.B_paths)
26 | self.transform = get_transform(opt)
27 |
28 | def __getitem__(self, index):
29 | A_path = self.A_paths[index % self.A_size]
30 | if self.opt.serial_batches:
31 | index_B = index % self.B_size
32 | else:
33 | index_B = random.randint(0, self.B_size - 1)
34 | B_path = self.B_paths[index_B]
35 | # print('(A, B) = (%d, %d)' % (index_A, index_B))
36 | A_img = Image.open(A_path).convert('RGB')
37 | B_img = Image.open(B_path).convert('RGB')
38 |
39 | A = self.transform(A_img)
40 | B = self.transform(B_img)
41 | if self.opt.which_direction == 'BtoA':
42 | input_nc = self.opt.output_nc
43 | output_nc = self.opt.input_nc
44 | else:
45 | input_nc = self.opt.input_nc
46 | output_nc = self.opt.output_nc
47 |
48 | if input_nc == 1: # RGB to gray
49 | tmp = A[0, ...] * 0.299 + A[1, ...] * 0.587 + A[2, ...] * 0.114
50 | A = tmp.unsqueeze(0)
51 |
52 | if output_nc == 1: # RGB to gray
53 | tmp = B[0, ...] * 0.299 + B[1, ...] * 0.587 + B[2, ...] * 0.114
54 | B = tmp.unsqueeze(0)
55 | return {'A': A, 'B': B,
56 | 'A_paths': A_path, 'B_paths': B_path}
57 |
58 | def __len__(self):
59 | return max(self.A_size, self.B_size)
60 |
61 | def name(self):
62 | return 'UnalignedDataset'
63 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Conditional GANs for Multi-Illuminant Color Constancy: Revolution or Yet Another Approach?
2 | Supporting code to the paper
3 | [O Sidorov. Conditional GANs for Multi-Illuminant Color Constancy: Revolution or Yet Another Approach?](https://arxiv.org/abs/1811.06604)
4 |
5 |
6 | 
7 |
8 | # AngularGAN
  9 | The work presents an extension of the supervised image-to-image translation algorithm ["pix2pix" by Isola *et al.*](https://arxiv.org/abs/1611.07004) oriented specifically to the color constancy task.
 10 | AngularGAN inherits from [this](https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix) implementation of pix2pix in PyTorch. Therefore, you may follow the original instructions for installation and dependencies. The new modules are implemented in PyTorch and do not require additional packages.
11 | #### Datasets are below!
12 | ### Getting started
13 | - Put your data in datasets/facades in the format
14 | ```
 15 | -facades/
16 | -test/
17 | -xxx.jpg
18 | -yyy.jpg
19 | -...
20 | -train/
21 | -zzz.jpg
22 | -...
23 | ```
 24 | where each image consists of a pair of images A and B (input and output) concatenated along the horizontal axis.
25 | - Run `visdom` to open training visualization (optional)
26 | - Run training (change parameter `--model angular_gan_v2` to use v2)
27 | ```
28 | chmod a+x run.sh
29 | ./run.sh
30 | ```
 31 | - Run `runtest.sh` instead for testing (change parameter `--model angular_gan_v2` to use v2)
 32 | *We thank the authors of pix2pix for their excellent work!*
33 |
34 | 
35 |
36 | # Datasets
37 | ### Tinted Multi-illuminant dataset
38 |
 39 | The MATLAB code `generate_tinted_images.m` applies a multi-illuminant color cast to the input images. The tint maps are randomized and are not coherent between frames.
 40 | You can use the provided file `real_illum_11346_Normalized.mat` or create your own by simply normalizing the original illumination vectors as `e_norm = e./norm(e)`.
41 |
42 | ### GTAV Shadow Removal Dataset
 43 | The GTAV Shadow Removal Dataset of 5,723 image pairs with and without shadows can be accessed via this [link](https://drive.google.com/open?id=1ktOXJmMQL_6U2J03mks3yWh6EMWKjUmu).
44 |
45 | #### Preview
46 |
47 | 
48 |
--------------------------------------------------------------------------------
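
For reference, the illuminant normalization mentioned in the Datasets section above can also be done in Python. A sketch: the `real_rgb` key matches what `generate_tinted_images.m` reads, but the un-normalized source file name is an assumption:

```python
import numpy as np
import scipy.io as sio

real_rgb = sio.loadmat('real_illum_11346.mat')['real_rgb']           # hypothetical source file
e_norm = real_rgb / np.linalg.norm(real_rgb, axis=1, keepdims=True)  # e_norm = e./norm(e)
sio.savemat('real_illum_11346_Normalized.mat', {'real_rgb': e_norm})
```
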
/angulargan/datasets/make_dataset_aligned.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | from PIL import Image
4 |
5 |
6 | def get_file_paths(folder):
7 | image_file_paths = []
8 | for root, dirs, filenames in os.walk(folder):
9 | filenames = sorted(filenames)
10 | for filename in filenames:
11 | input_path = os.path.abspath(root)
12 | file_path = os.path.join(input_path, filename)
13 | if filename.endswith('.png') or filename.endswith('.jpg'):
14 | image_file_paths.append(file_path)
15 |
16 | break # prevent descending into subfolders
17 | return image_file_paths
18 |
19 |
20 | def align_images(a_file_paths, b_file_paths, target_path):
21 | if not os.path.exists(target_path):
22 | os.makedirs(target_path)
23 |
24 | for i in range(len(a_file_paths)):
25 | img_a = Image.open(a_file_paths[i])
26 | img_b = Image.open(b_file_paths[i])
27 | assert(img_a.size == img_b.size)
28 |
29 | aligned_image = Image.new("RGB", (img_a.size[0] * 2, img_a.size[1]))
30 | aligned_image.paste(img_a, (0, 0))
31 | aligned_image.paste(img_b, (img_a.size[0], 0))
32 | aligned_image.save(os.path.join(target_path, '{:04d}.jpg'.format(i)))
33 |
34 |
35 | if __name__ == '__main__':
36 | import argparse
37 | parser = argparse.ArgumentParser()
38 | parser.add_argument(
39 | '--dataset-path',
40 | dest='dataset_path',
 41 |         help='Which folder to process (it should have subfolders testA, testB, trainA and trainB)'
42 | )
43 | args = parser.parse_args()
44 |
45 | dataset_folder = args.dataset_path
46 | print(dataset_folder)
47 |
48 | test_a_path = os.path.join(dataset_folder, 'testA')
49 | test_b_path = os.path.join(dataset_folder, 'testB')
50 | test_a_file_paths = get_file_paths(test_a_path)
51 | test_b_file_paths = get_file_paths(test_b_path)
52 | assert(len(test_a_file_paths) == len(test_b_file_paths))
53 | test_path = os.path.join(dataset_folder, 'test')
54 |
55 | train_a_path = os.path.join(dataset_folder, 'trainA')
56 | train_b_path = os.path.join(dataset_folder, 'trainB')
57 | train_a_file_paths = get_file_paths(train_a_path)
58 | train_b_file_paths = get_file_paths(train_b_path)
59 | assert(len(train_a_file_paths) == len(train_b_file_paths))
60 | train_path = os.path.join(dataset_folder, 'train')
61 |
62 | align_images(test_a_file_paths, test_b_file_paths, test_path)
63 | align_images(train_a_file_paths, train_b_file_paths, train_path)
64 |
--------------------------------------------------------------------------------
/angulargan/train.py:
--------------------------------------------------------------------------------
1 | import time
2 | from options.train_options import TrainOptions
3 | from data import CreateDataLoader
4 | from models import create_model
5 | from util.visualizer import Visualizer
6 |
7 | if __name__ == '__main__':
8 | opt = TrainOptions().parse()
9 | data_loader = CreateDataLoader(opt)
10 | dataset = data_loader.load_data()
11 | dataset_size = len(data_loader)
12 | print('#training images = %d' % dataset_size)
13 |
14 | model = create_model(opt)
15 | model.setup(opt)
16 | visualizer = Visualizer(opt)
17 | total_steps = 0
18 |
19 | for epoch in range(opt.epoch_count, opt.niter + opt.niter_decay + 1):
20 | epoch_start_time = time.time()
21 | iter_data_time = time.time()
22 | epoch_iter = 0
23 |
24 | for i, data in enumerate(dataset):
25 | iter_start_time = time.time()
26 | if total_steps % opt.print_freq == 0:
27 | t_data = iter_start_time - iter_data_time
28 | visualizer.reset()
29 | total_steps += opt.batchSize
30 | epoch_iter += opt.batchSize
31 | model.set_input(data)
32 | model.optimize_parameters()
33 |
34 | if total_steps % opt.display_freq == 0:
35 | save_result = total_steps % opt.update_html_freq == 0
36 | visualizer.display_current_results(model.get_current_visuals(), epoch, save_result)
37 |
38 | if total_steps % opt.print_freq == 0:
39 | losses = model.get_current_losses()
40 | t = (time.time() - iter_start_time) / opt.batchSize
41 | visualizer.print_current_losses(epoch, epoch_iter, losses, t, t_data)
42 | if opt.display_id > 0:
43 | visualizer.plot_current_losses(epoch, float(epoch_iter) / dataset_size, opt, losses)
44 |
45 | if total_steps % opt.save_latest_freq == 0:
46 | print('saving the latest model (epoch %d, total_steps %d)' %
47 | (epoch, total_steps))
48 | model.save_networks('latest')
49 |
50 | iter_data_time = time.time()
51 | if epoch % opt.save_epoch_freq == 0:
52 | print('saving the model at the end of epoch %d, iters %d' %
53 | (epoch, total_steps))
54 | model.save_networks('latest')
55 | model.save_networks(epoch)
56 |
57 | print('End of epoch %d / %d \t Time Taken: %d sec' %
58 | (epoch, opt.niter + opt.niter_decay, time.time() - epoch_start_time))
59 | model.update_learning_rate()
60 |
--------------------------------------------------------------------------------
/angulargan/options/train_options.py:
--------------------------------------------------------------------------------
1 | from .base_options import BaseOptions
2 |
3 |
4 | class TrainOptions(BaseOptions):
5 | def initialize(self, parser):
6 | parser = BaseOptions.initialize(self, parser)
7 | parser.add_argument('--display_freq', type=int, default=400, help='frequency of showing training results on screen')
8 | parser.add_argument('--display_ncols', type=int, default=4, help='if positive, display all images in a single visdom web panel with certain number of images per row.')
9 | parser.add_argument('--update_html_freq', type=int, default=1000, help='frequency of saving training results to html')
10 | parser.add_argument('--print_freq', type=int, default=100, help='frequency of showing training results on console')
11 | parser.add_argument('--save_latest_freq', type=int, default=5000, help='frequency of saving the latest results')
12 | parser.add_argument('--save_epoch_freq', type=int, default=5, help='frequency of saving checkpoints at the end of epochs')
13 | parser.add_argument('--continue_train', action='store_true', help='continue training: load the latest model')
 14 |         parser.add_argument('--epoch_count', type=int, default=1, help='the starting epoch count, we save the model by <epoch_count>, <epoch_count>+<save_latest_freq>, ...')
15 | parser.add_argument('--phase', type=str, default='train', help='train, val, test, etc')
16 | parser.add_argument('--which_epoch', type=str, default='latest', help='which epoch to load? set to latest to use latest cached model')
17 | parser.add_argument('--niter', type=int, default=100, help='# of iter at starting learning rate')
18 | parser.add_argument('--niter_decay', type=int, default=100, help='# of iter to linearly decay learning rate to zero')
19 | parser.add_argument('--beta1', type=float, default=0.5, help='momentum term of adam')
20 | parser.add_argument('--lr', type=float, default=0.0002, help='initial learning rate for adam')
21 | parser.add_argument('--no_lsgan', action='store_true', help='do *not* use least square GAN, if false, use vanilla GAN')
22 | parser.add_argument('--pool_size', type=int, default=50, help='the size of image buffer that stores previously generated images')
23 | parser.add_argument('--no_html', action='store_true', help='do not save intermediate training results to [opt.checkpoints_dir]/[opt.name]/web/')
24 | parser.add_argument('--lr_policy', type=str, default='lambda', help='learning rate policy: lambda|step|plateau')
25 | parser.add_argument('--lr_decay_iters', type=int, default=50, help='multiply by a gamma every lr_decay_iters iterations')
26 |
27 | self.isTrain = True
28 | return parser
29 |
--------------------------------------------------------------------------------
/angulargan/docs/tips.md:
--------------------------------------------------------------------------------
1 | ## Training/test Tips
2 | - Flags: see `options/train_options.py` and `options/base_options.py` for the training flags; see `options/test_options.py` and `options/base_options.py` for the test flags. There are some model-specific flags as well, which are added in the model files, such as `--lambda_A` option in `model/cycle_gan_model.py`. The default values of these options are also adjusted in the model files.
  3 | - CPU/GPU (default `--gpu_ids 0`): set `--gpu_ids -1` to use CPU mode; set `--gpu_ids 0,1,2` for multi-GPU mode. You need a large batch size (e.g. `--batchSize 32`) to benefit from multiple GPUs.
4 | - Visualization: during training, the current results can be viewed using two methods. First, if you set `--display_id` > 0, the results and loss plot will appear on a local graphics web server launched by [visdom](https://github.com/facebookresearch/visdom). To do this, you should have `visdom` installed and a server running by the command `python -m visdom.server`. The default server URL is `http://localhost:8097`. `display_id` corresponds to the window ID that is displayed on the `visdom` server. The `visdom` display functionality is turned on by default. To avoid the extra overhead of communicating with `visdom` set `--display_id -1`. Second, the intermediate results are saved to `[opt.checkpoints_dir]/[opt.name]/web/` as an HTML file. To avoid this, set `--no_html`.
  5 | - Preprocessing: images can be resized and cropped in different ways using the `--resize_or_crop` option. The default option `'resize_and_crop'` resizes the image to size `(opt.loadSize, opt.loadSize)` and does a random crop of size `(opt.fineSize, opt.fineSize)`. `'crop'` skips the resizing step and only performs random cropping. `'scale_width'` resizes the image to have width `opt.fineSize` while keeping the aspect ratio. `'scale_width_and_crop'` first resizes the image to have width `opt.loadSize` and then does random cropping of size `(opt.fineSize, opt.fineSize)`. `'none'` tries to skip all these preprocessing steps. However, if the image size is not a multiple of some number depending on the number of downsamplings of the generator, you will get an error because the size of the output image may be different from the size of the input image. Therefore, the `'none'` option still tries to adjust the image size to be a multiple of 4. You might need a bigger adjustment if you change the generator architecture. Please see `data/base_dataset.py` to see how all these were implemented.
  6 | - Fine-tuning/Resume training: to fine-tune a pre-trained model, or resume the previous training, use the `--continue_train` flag. The program will then load the model based on `which_epoch`. By default, the program will initialize the epoch count as 1. Set `--epoch_count <int>` to specify a different starting epoch count.
7 |
--------------------------------------------------------------------------------
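
The preprocessing flags described above map onto `get_transform()` in `data/base_dataset.py`. A minimal sketch of that mapping; the `loadSize`/`fineSize` values shown are typical pix2pix defaults, not values fixed by this document:

```python
from types import SimpleNamespace
from data.base_dataset import get_transform

opt = SimpleNamespace(resize_or_crop='resize_and_crop', loadSize=286,
                      fineSize=256, isTrain=True, no_flip=False)
tf = get_transform(opt)  # Resize(286) -> RandomCrop(256) -> RandomHorizontalFlip
                         # -> ToTensor -> Normalize((0.5,)*3, (0.5,)*3)
```
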
/angulargan/data/__init__.py:
--------------------------------------------------------------------------------
1 | import importlib
2 | import torch.utils.data
3 | from data.base_data_loader import BaseDataLoader
4 | from data.base_dataset import BaseDataset
5 |
6 | def find_dataset_using_name(dataset_name):
7 | # Given the option --dataset_mode [datasetname],
8 | # the file "data/datasetname_dataset.py"
9 | # will be imported.
10 | dataset_filename = "data." + dataset_name + "_dataset"
11 | datasetlib = importlib.import_module(dataset_filename)
12 |
13 | # In the file, the class called DatasetNameDataset() will
14 | # be instantiated. It has to be a subclass of BaseDataset,
15 | # and it is case-insensitive.
16 | dataset = None
17 | target_dataset_name = dataset_name.replace('_', '') + 'dataset'
18 | for name, cls in datasetlib.__dict__.items():
19 | if name.lower() == target_dataset_name.lower() \
20 | and issubclass(cls, BaseDataset):
21 | dataset = cls
22 |
23 | if dataset is None:
24 | print("In %s.py, there should be a subclass of BaseDataset with class name that matches %s in lowercase." % (dataset_filename, target_dataset_name))
25 | exit(0)
26 |
27 | return dataset
28 |
29 |
30 | def get_option_setter(dataset_name):
31 | dataset_class = find_dataset_using_name(dataset_name)
32 | return dataset_class.modify_commandline_options
33 |
34 |
35 | def create_dataset(opt):
36 | dataset = find_dataset_using_name(opt.dataset_mode)
37 | instance = dataset()
38 | instance.initialize(opt)
39 | print("dataset [%s] was created" % (instance.name()))
40 | return instance
41 |
42 |
43 | def CreateDataLoader(opt):
44 | data_loader = CustomDatasetDataLoader()
45 | data_loader.initialize(opt)
46 | return data_loader
47 |
48 |
 49 | ## Wrapper class around a Dataset that performs
 50 | ## multi-threaded data loading
51 | class CustomDatasetDataLoader(BaseDataLoader):
52 | def name(self):
53 | return 'CustomDatasetDataLoader'
54 |
55 | def initialize(self, opt):
56 | BaseDataLoader.initialize(self, opt)
57 | self.dataset = create_dataset(opt)
58 | self.dataloader = torch.utils.data.DataLoader(
59 | self.dataset,
60 | batch_size=opt.batchSize,
61 | shuffle=not opt.serial_batches,
62 | num_workers=int(opt.nThreads))
63 |
64 | def load_data(self):
65 | return self
66 |
67 | def __len__(self):
68 | return min(len(self.dataset), self.opt.max_dataset_size)
69 |
70 | def __iter__(self):
71 | for i, data in enumerate(self.dataloader):
72 | if i * self.opt.batchSize >= self.opt.max_dataset_size:
73 | break
74 | yield data
75 |
--------------------------------------------------------------------------------
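
End to end, the loader above is used exactly as in `train.py`; a short sketch, assuming the options are parsed from the command line:

```python
from options.train_options import TrainOptions
from data import CreateDataLoader

opt = TrainOptions().parse()                 # reads flags such as --dataroot, --dataset_mode
dataset = CreateDataLoader(opt).load_data()  # load_data() returns the wrapper itself
for batch in dataset:                        # aligned mode: {'A', 'B', 'A_paths', 'B_paths'}
    print(batch['A'].shape)                  # (batchSize, 3, fineSize, fineSize)
    break
```
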
/angulargan/data/aligned_dataset.py:
--------------------------------------------------------------------------------
1 | import os.path
2 | import random
3 | import torchvision.transforms as transforms
4 | import torch
5 | from data.base_dataset import BaseDataset
6 | from data.image_folder import make_dataset
7 | from PIL import Image
8 |
9 |
10 | class AlignedDataset(BaseDataset):
11 | @staticmethod
12 | def modify_commandline_options(parser, is_train):
13 | return parser
14 |
15 | def initialize(self, opt):
16 | self.opt = opt
17 | self.root = opt.dataroot
18 | self.dir_AB = os.path.join(opt.dataroot, opt.phase)
19 | self.AB_paths = sorted(make_dataset(self.dir_AB))
20 | assert(opt.resize_or_crop == 'resize_and_crop')
21 |
22 | def __getitem__(self, index):
23 | AB_path = self.AB_paths[index]
24 | AB = Image.open(AB_path).convert('RGB')
25 | w, h = AB.size
26 | w2 = int(w / 2)
27 | A = AB.crop((0, 0, w2, h)).resize((self.opt.loadSize, self.opt.loadSize), Image.BICUBIC)
28 | B = AB.crop((w2, 0, w, h)).resize((self.opt.loadSize, self.opt.loadSize), Image.BICUBIC)
29 | A = transforms.ToTensor()(A)
30 | B = transforms.ToTensor()(B)
31 | w_offset = random.randint(0, max(0, self.opt.loadSize - self.opt.fineSize - 1))
32 | h_offset = random.randint(0, max(0, self.opt.loadSize - self.opt.fineSize - 1))
33 |
34 | A = A[:, h_offset:h_offset + self.opt.fineSize, w_offset:w_offset + self.opt.fineSize]
35 | B = B[:, h_offset:h_offset + self.opt.fineSize, w_offset:w_offset + self.opt.fineSize]
36 |
37 | A = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))(A)
38 | B = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))(B)
39 |
40 | if self.opt.which_direction == 'BtoA':
41 | input_nc = self.opt.output_nc
42 | output_nc = self.opt.input_nc
43 | else:
44 | input_nc = self.opt.input_nc
45 | output_nc = self.opt.output_nc
46 |
47 | if (not self.opt.no_flip) and random.random() < 0.5:
48 | idx = [i for i in range(A.size(2) - 1, -1, -1)]
49 | idx = torch.LongTensor(idx)
50 | A = A.index_select(2, idx)
51 | B = B.index_select(2, idx)
52 |
53 | if input_nc == 1: # RGB to gray
54 | tmp = A[0, ...] * 0.299 + A[1, ...] * 0.587 + A[2, ...] * 0.114
55 | A = tmp.unsqueeze(0)
56 |
57 | if output_nc == 1: # RGB to gray
58 | tmp = B[0, ...] * 0.299 + B[1, ...] * 0.587 + B[2, ...] * 0.114
59 | B = tmp.unsqueeze(0)
60 |
61 | return {'A': A, 'B': B,
62 | 'A_paths': AB_path, 'B_paths': AB_path}
63 |
64 | def __len__(self):
65 | return len(self.AB_paths)
66 |
67 | def name(self):
68 | return 'AlignedDataset'
69 |
--------------------------------------------------------------------------------
/angulargan/data/base_dataset.py:
--------------------------------------------------------------------------------
1 | import torch.utils.data as data
2 | from PIL import Image
3 | import torchvision.transforms as transforms
4 |
5 |
6 | class BaseDataset(data.Dataset):
7 | def __init__(self):
8 | super(BaseDataset, self).__init__()
9 |
10 | def name(self):
11 | return 'BaseDataset'
12 |
13 | @staticmethod
14 | def modify_commandline_options(parser, is_train):
15 | return parser
16 |
17 | def initialize(self, opt):
18 | pass
19 |
20 | def __len__(self):
21 | return 0
22 |
23 |
24 | def get_transform(opt):
25 | transform_list = []
26 | if opt.resize_or_crop == 'resize_and_crop':
27 | osize = [opt.loadSize, opt.loadSize]
28 | transform_list.append(transforms.Resize(osize, Image.BICUBIC))
29 | transform_list.append(transforms.RandomCrop(opt.fineSize))
30 | elif opt.resize_or_crop == 'crop':
31 | transform_list.append(transforms.RandomCrop(opt.fineSize))
32 | elif opt.resize_or_crop == 'scale_width':
33 | transform_list.append(transforms.Lambda(
34 | lambda img: __scale_width(img, opt.fineSize)))
35 | elif opt.resize_or_crop == 'scale_width_and_crop':
36 | transform_list.append(transforms.Lambda(
37 | lambda img: __scale_width(img, opt.loadSize)))
38 | transform_list.append(transforms.RandomCrop(opt.fineSize))
39 | elif opt.resize_or_crop == 'none':
40 | transform_list.append(transforms.Lambda(
41 | lambda img: __adjust(img)))
42 | else:
43 | raise ValueError('--resize_or_crop %s is not a valid option.' % opt.resize_or_crop)
44 |
45 | if opt.isTrain and not opt.no_flip:
46 | transform_list.append(transforms.RandomHorizontalFlip())
47 |
48 | transform_list += [transforms.ToTensor(),
49 | transforms.Normalize((0.5, 0.5, 0.5),
50 | (0.5, 0.5, 0.5))]
51 | return transforms.Compose(transform_list)
52 |
53 | # just modify the width and height to be multiple of 4
54 | def __adjust(img):
55 | ow, oh = img.size
56 |
57 | # the size needs to be a multiple of this number,
58 | # because going through generator network may change img size
59 | # and eventually cause size mismatch error
60 | mult = 4
61 | if ow % mult == 0 and oh % mult == 0:
62 | return img
63 | w = (ow - 1) // mult
64 | w = (w + 1) * mult
65 | h = (oh - 1) // mult
66 | h = (h + 1) * mult
67 |
68 | if ow != w or oh != h:
69 | __print_size_warning(ow, oh, w, h)
70 |
71 | return img.resize((w, h), Image.BICUBIC)
72 |
73 |
74 | def __scale_width(img, target_width):
75 | ow, oh = img.size
76 |
77 | # the size needs to be a multiple of this number,
78 | # because going through generator network may change img size
79 | # and eventually cause size mismatch error
80 | mult = 4
81 | assert target_width % mult == 0, "the target width needs to be multiple of %d." % mult
82 | if (ow == target_width and oh % mult == 0):
83 | return img
84 | w = target_width
85 | target_height = int(target_width * oh / ow)
86 | m = (target_height - 1) // mult
87 | h = (m + 1) * mult
88 |
89 | if target_height != h:
90 | __print_size_warning(target_width, target_height, w, h)
91 |
92 | return img.resize((w, h), Image.BICUBIC)
93 |
94 |
95 | def __print_size_warning(ow, oh, w, h):
96 | if not hasattr(__print_size_warning, 'has_printed'):
97 | print("The image size needs to be a multiple of 4. "
98 | "The loaded image size was (%d, %d), so it was adjusted to "
99 | "(%d, %d). This adjustment will be done to all images "
100 | "whose sizes are not multiples of 4" % (ow, oh, w, h))
101 | __print_size_warning.has_printed = True
102 |
103 |
104 |
--------------------------------------------------------------------------------
/angulargan/util/get_data.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 | import os
3 | import tarfile
4 | import requests
5 | from warnings import warn
6 | from zipfile import ZipFile
7 | from bs4 import BeautifulSoup
8 | from os.path import abspath, isdir, join, basename
9 |
10 |
11 | class GetData(object):
12 | """
13 |
14 | Download CycleGAN or Pix2Pix Data.
15 |
16 | Args:
17 | technique : str
18 | One of: 'cyclegan' or 'pix2pix'.
19 | verbose : bool
20 | If True, print additional information.
21 |
22 | Examples:
23 | >>> from util.get_data import GetData
24 | >>> gd = GetData(technique='cyclegan')
25 | >>> new_data_path = gd.get(save_path='./datasets') # options will be displayed.
26 |
27 | """
28 |
29 | def __init__(self, technique='cyclegan', verbose=True):
30 | url_dict = {
31 | 'pix2pix': 'https://people.eecs.berkeley.edu/~tinghuiz/projects/pix2pix/datasets',
32 | 'cyclegan': 'https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets'
33 | }
34 | self.url = url_dict.get(technique.lower())
35 | self._verbose = verbose
36 |
37 | def _print(self, text):
38 | if self._verbose:
39 | print(text)
40 |
41 | @staticmethod
42 | def _get_options(r):
43 | soup = BeautifulSoup(r.text, 'lxml')
44 | options = [h.text for h in soup.find_all('a', href=True)
45 | if h.text.endswith(('.zip', 'tar.gz'))]
46 | return options
47 |
48 | def _present_options(self):
49 | r = requests.get(self.url)
50 | options = self._get_options(r)
51 | print('Options:\n')
52 | for i, o in enumerate(options):
53 | print("{0}: {1}".format(i, o))
54 | choice = input("\nPlease enter the number of the "
55 | "dataset above you wish to download:")
56 | return options[int(choice)]
57 |
58 | def _download_data(self, dataset_url, save_path):
59 | if not isdir(save_path):
60 | os.makedirs(save_path)
61 |
62 | base = basename(dataset_url)
63 | temp_save_path = join(save_path, base)
64 |
65 | with open(temp_save_path, "wb") as f:
66 | r = requests.get(dataset_url)
67 | f.write(r.content)
68 |
69 | if base.endswith('.tar.gz'):
70 | obj = tarfile.open(temp_save_path)
71 | elif base.endswith('.zip'):
72 | obj = ZipFile(temp_save_path, 'r')
73 | else:
74 | raise ValueError("Unknown File Type: {0}.".format(base))
75 |
76 | self._print("Unpacking Data...")
77 | obj.extractall(save_path)
78 | obj.close()
79 | os.remove(temp_save_path)
80 |
81 | def get(self, save_path, dataset=None):
82 | """
83 |
84 | Download a dataset.
85 |
86 | Args:
87 | save_path : str
88 | A directory to save the data to.
89 | dataset : str, optional
90 | A specific dataset to download.
91 | Note: this must include the file extension.
92 | If None, options will be presented for you
93 | to choose from.
94 |
95 | Returns:
96 | save_path_full : str
97 | The absolute path to the downloaded data.
98 |
99 | """
100 | if dataset is None:
101 | selected_dataset = self._present_options()
102 | else:
103 | selected_dataset = dataset
104 |
105 | save_path_full = join(save_path, selected_dataset.split('.')[0])
106 |
107 | if isdir(save_path_full):
108 | warn("\n'{0}' already exists. Voiding Download.".format(
109 | save_path_full))
110 | else:
111 | self._print('Downloading Data...')
112 | url = "{0}/{1}".format(self.url, selected_dataset)
113 | self._download_data(url, save_path=save_path)
114 |
115 | return abspath(save_path_full)
116 |
--------------------------------------------------------------------------------
/angulargan/docs/datasets.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | ### CycleGAN Datasets
4 | Download the CycleGAN datasets using the following script. Some of the datasets are collected by other researchers. Please cite their papers if you use the data.
5 | ```bash
6 | bash ./datasets/download_cyclegan_dataset.sh dataset_name
7 | ```
8 | - `facades`: 400 images from the [CMP Facades dataset](http://cmp.felk.cvut.cz/~tylecr1/facade). [[Citation](datasets/bibtex/facades.tex)]
9 | - `cityscapes`: 2975 images from the [Cityscapes training set](https://www.cityscapes-dataset.com). [[Citation](datasets/bibtex/cityscapes.tex)]
10 | - `maps`: 1096 training images scraped from Google Maps.
 11 | - `horse2zebra`: 939 horse images and 1177 zebra images downloaded from [ImageNet](http://www.image-net.org) using keywords `wild horse` and `zebra`.
12 | - `apple2orange`: 996 apple images and 1020 orange images downloaded from [ImageNet](http://www.image-net.org) using keywords `apple` and `navel orange`.
13 | - `summer2winter_yosemite`: 1273 summer Yosemite images and 854 winter Yosemite images were downloaded using Flickr API. See more details in our paper.
14 | - `monet2photo`, `vangogh2photo`, `ukiyoe2photo`, `cezanne2photo`: The art images were downloaded from [Wikiart](https://www.wikiart.org/). The real photos are downloaded from Flickr using the combination of the tags *landscape* and *landscapephotography*. The training set size of each class is Monet:1074, Cezanne:584, Van Gogh:401, Ukiyo-e:1433, Photographs:6853.
 15 | - `iphone2dslr_flower`: both classes of images were downloaded from Flickr. The training set size of each class is iPhone:1813, DSLR:3316. See more details in our paper.
16 |
17 | To train a model on your own datasets, you need to create a data folder with two subdirectories `trainA` and `trainB` that contain images from domain A and B. You can test your model on your training set by setting `--phase train` in `test.py`. You can also create subdirectories `testA` and `testB` if you have test data.
18 |
19 | You should **not** expect our method to work on just any random combination of input and output datasets (e.g. `cats<->keyboards`). From our experiments, we find it works better if two datasets share similar visual content. For example, `landscape painting<->landscape photographs` works much better than `portrait painting <-> landscape photographs`. `zebras<->horses` achieves compelling results while `cats<->dogs` completely fails.
20 |
21 | ### pix2pix datasets
22 | Download the pix2pix datasets using the following script. Some of the datasets are collected by other researchers. Please cite their papers if you use the data.
23 | ```bash
24 | bash ./datasets/download_pix2pix_dataset.sh dataset_name
25 | ```
26 | - `facades`: 400 images from [CMP Facades dataset](http://cmp.felk.cvut.cz/~tylecr1/facade). [[Citation](datasets/bibtex/facades.tex)]
27 | - `cityscapes`: 2975 images from the [Cityscapes training set](https://www.cityscapes-dataset.com). [[Citation](datasets/bibtex/cityscapes.tex)]
28 | - `maps`: 1096 training images scraped from Google Maps.
29 | - `edges2shoes`: 50k training images from [UT Zappos50K dataset](http://vision.cs.utexas.edu/projects/finegrained/utzap50k). Edges are computed by [HED](https://github.com/s9xie/hed) edge detector + post-processing. [[Citation](datasets/bibtex/shoes.tex)]
30 | - `edges2handbags`: 137K Amazon Handbag images from [iGAN project](https://github.com/junyanz/iGAN). Edges are computed by [HED](https://github.com/s9xie/hed) edge detector + post-processing. [[Citation](datasets/bibtex/handbags.tex)]
31 |
32 | We provide a python script to generate pix2pix training data in the form of pairs of images {A,B}, where A and B are two different depictions of the same underlying scene. For example, these might be pairs {label map, photo} or {bw image, color image}. Then we can learn to translate A to B or B to A:
33 |
34 | Create folder `/path/to/data` with subfolders `A` and `B`. `A` and `B` should each have their own subfolders `train`, `val`, `test`, etc. In `/path/to/data/A/train`, put training images in style A. In `/path/to/data/B/train`, put the corresponding images in style B. Repeat the same for the other data splits (`val`, `test`, etc.).
35 |
36 | Corresponding images in a pair {A,B} must be the same size and have the same filename, e.g., `/path/to/data/A/train/1.jpg` is considered to correspond to `/path/to/data/B/train/1.jpg`.
37 |
38 | Once the data is formatted this way, call:
39 | ```bash
40 | python datasets/combine_A_and_B.py --fold_A /path/to/data/A --fold_B /path/to/data/B --fold_AB /path/to/data
41 | ```
42 |
43 | This will combine each pair of images (A,B) into a single image file, ready for training (a minimal sketch of this pairing step follows below).
44 |
--------------------------------------------------------------------------------
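As referenced above, here is a minimal sketch of what combining one pair {A, B} amounts to: reading two same-sized images and writing them side by side as a single AB image. The paths reuse the placeholders from the text; the actual `combine_A_and_B.py` script loops over whole folders:

```python
# Sketch of combining one {A, B} pair into a single side-by-side AB image.
import numpy as np
from PIL import Image

im_A = np.asarray(Image.open('/path/to/data/A/train/1.jpg'))
im_B = np.asarray(Image.open('/path/to/data/B/train/1.jpg'))
assert im_A.shape == im_B.shape, 'pairs must share size and filename'

im_AB = np.concatenate([im_A, im_B], axis=1)   # stack along the width
Image.fromarray(im_AB).save('/path/to/data/train/1.jpg')
```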
/angulargan/models/angular_gan_model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from util.image_pool import ImagePool
3 | from .base_model import BaseModel
4 | from . import networks
5 | from . import angular_loss
6 |
7 |
8 | class AngularGANModel(BaseModel):
9 | def name(self):
10 | return 'AngularGANModel'
11 |
12 | @staticmethod
13 | def modify_commandline_options(parser, is_train=True):
14 |
15 |
16 | parser.set_defaults(pool_size=0, no_lsgan=True, norm='batch')
17 | parser.set_defaults(dataset_mode='aligned')
18 | parser.set_defaults(which_model_netG='unet_256')
19 | if is_train:
20 | parser.add_argument('--lambda_L1', type=float, default=1.0, help='weight for L1 loss')
21 | parser.add_argument('--lambda_Angular', type=float, default=1.0, help='weight for the angular loss')
22 |
23 | return parser
24 |
25 | def initialize(self, opt):
26 | BaseModel.initialize(self, opt)
27 | self.isTrain = opt.isTrain
28 | # specify the training losses you want to print out. The program will call base_model.get_current_losses
29 | self.loss_names = ['G_GAN', 'G_L1', 'G_Ang', 'D_real', 'D_fake']
30 | # specify the images you want to save/display. The program will call base_model.get_current_visuals
31 | self.visual_names = ['real_A', 'fake_B', 'real_B']
32 | # specify the models you want to save to the disk. The program will call base_model.save_networks and base_model.load_networks
33 | if self.isTrain:
34 | self.model_names = ['G', 'D']
35 | else: # during test time, only load Gs
36 | self.model_names = ['G']
37 | # load/define networks
38 | self.netG = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf,
39 | opt.which_model_netG, opt.norm, not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids)
40 |
41 | if self.isTrain:
42 | use_sigmoid = opt.no_lsgan
43 | self.netD = networks.define_D(opt.input_nc + opt.output_nc, opt.ndf,
44 | opt.which_model_netD,
45 | opt.n_layers_D, opt.norm, use_sigmoid, opt.init_type, opt.init_gain, self.gpu_ids)
46 |
47 | if self.isTrain:
48 | self.fake_AB_pool = ImagePool(opt.pool_size)
49 | # define loss functions
50 | self.criterionGAN = networks.GANLoss(use_lsgan=not opt.no_lsgan).to(self.device)
51 | self.criterionL1 = torch.nn.L1Loss()
52 | self.criterionAngular = angular_loss.angular_loss()
53 |
54 | # initialize optimizers
55 | self.optimizers = []
56 | self.optimizer_G = torch.optim.Adam(self.netG.parameters(),
57 | lr=opt.lr, betas=(opt.beta1, 0.999))
58 | self.optimizer_D = torch.optim.Adam(self.netD.parameters(),
59 | lr=opt.lr, betas=(opt.beta1, 0.999))
60 | self.optimizers.append(self.optimizer_G)
61 | self.optimizers.append(self.optimizer_D)
62 |
63 | def set_input(self, input):
64 | AtoB = self.opt.which_direction == 'AtoB'
65 | self.real_A = input['A' if AtoB else 'B'].to(self.device)
66 | self.real_B = input['B' if AtoB else 'A'].to(self.device)
67 | self.image_paths = input['A_paths' if AtoB else 'B_paths']
68 |
69 | def forward(self):
70 | self.fake_B = self.netG(self.real_A)
71 |
72 | def backward_D(self):
73 | # Fake
74 | # stop backprop to the generator by detaching fake_B
75 | fake_AB = self.fake_AB_pool.query(torch.cat((self.real_A, self.fake_B), 1))
76 | pred_fake = self.netD(fake_AB.detach())
77 | self.loss_D_fake = self.criterionGAN(pred_fake, False)
78 |
79 | # Real
80 | real_AB = torch.cat((self.real_A, self.real_B), 1)
81 | pred_real = self.netD(real_AB)
82 | self.loss_D_real = self.criterionGAN(pred_real, True)
83 |
84 | # Combined loss
85 | self.loss_D = (self.loss_D_fake + self.loss_D_real) * 0.5
86 |
87 | self.loss_D.backward()
88 |
89 | def backward_G(self):
90 | # First, G(A) should fake the discriminator
91 | fake_AB = torch.cat((self.real_A, self.fake_B), 1)
92 | pred_fake = self.netD(fake_AB)
93 | self.loss_G_GAN = self.criterionGAN(pred_fake, True)
94 |
95 | # Second, G(A) = B
96 | self.loss_G_L1 = self.criterionL1(self.fake_B, self.real_B) * self.opt.lambda_L1 * 100  # the fixed factor of 100 reproduces the pix2pix default L1 weight
97 |
98 | self.eps = torch.tensor(1e-04).to(self.device)  # floor to avoid division by zero
99 | self.illum_gt = torch.div(self.real_A, torch.max(self.real_B, self.eps))  # illumination implied by the ground truth
100 | self.illum_pred = torch.div(self.real_A, torch.max(self.fake_B, self.eps))  # illumination implied by the prediction
101 | self.loss_G_Ang = self.criterionAngular(self.illum_gt, self.illum_pred) * self.opt.lambda_Angular
102 |
103 | self.loss_G = self.loss_G_GAN + self.loss_G_Ang + self.loss_G_L1
104 |
105 | self.loss_G.backward()
106 |
107 | def optimize_parameters(self):
108 | self.forward()
109 | # update D
110 | self.set_requires_grad(self.netD, True)
111 | self.optimizer_D.zero_grad()
112 | self.backward_D()
113 | self.optimizer_D.step()
114 |
115 | # update G
116 | self.set_requires_grad(self.netD, False)
117 | self.optimizer_G.zero_grad()
118 | self.backward_G()
119 | self.optimizer_G.step()
120 |
--------------------------------------------------------------------------------
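The angular criterion imported above lives in `models/angular_loss/__init__.py`, whose source does not appear in this section. Below is a minimal sketch of a standard per-pixel angular error between two illumination maps, offered as an assumption about what `angular_loss()` computes rather than the packaged implementation:

```python
import torch

def angular_error(illum_gt, illum_pred, eps=1e-7):
    """Mean per-pixel angle (radians) between two (N, 3, H, W) maps."""
    dot = (illum_gt * illum_pred).sum(dim=1)            # channel-wise dot product
    norms = illum_gt.norm(dim=1) * illum_pred.norm(dim=1)
    cos = torch.clamp(dot / (norms + eps), -1.0 + eps, 1.0 - eps)
    return torch.acos(cos).mean()
```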
/angulargan/models/angular_gan_v2_model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from util.image_pool import ImagePool
3 | from .base_model import BaseModel
4 | from . import networks
5 | from . import angular_loss
6 |
7 |
8 | class AngularGANv2Model(BaseModel):
9 | def name(self):
10 | return 'AngularGANv2Model'
11 |
12 | @staticmethod
13 | def modify_commandline_options(parser, is_train=True):
14 |
15 | parser.set_defaults(pool_size=0, no_lsgan=True, norm='batch')
16 | parser.set_defaults(dataset_mode='aligned')
17 | parser.set_defaults(which_model_netG='unet_256')
18 | if is_train:
19 | parser.add_argument('--lambda_L1', type=float, default=1.0, help='weight for L1 loss')
20 | parser.add_argument('--lambda_Angular', type=float, default=1.0, help='weight for the angular loss')
21 |
22 | return parser
23 |
24 | def initialize(self, opt):
25 | BaseModel.initialize(self, opt)
26 | self.isTrain = opt.isTrain
27 | # specify the training losses you want to print out. The program will call base_model.get_current_losses
28 | self.loss_names = ['G_GAN', 'G_L1', 'G_Ang', 'D_real', 'D_fake']
29 | # specify the images you want to save/display. The program will call base_model.get_current_visuals
30 | self.visual_names = ['real_A', 'fake_B', 'real_B']
31 | # specify the models you want to save to the disk. The program will call base_model.save_networks and base_model.load_networks
32 | if self.isTrain:
33 | self.model_names = ['G', 'D']
34 | else: # during test time, only load Gs
35 | self.model_names = ['G']
36 | # load/define networks
37 | self.netG = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf,
38 | opt.which_model_netG, opt.norm, not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids)
39 |
40 | if self.isTrain:
41 | use_sigmoid = opt.no_lsgan
42 | self.netD = networks.define_D(opt.input_nc + opt.output_nc, opt.ndf,
43 | opt.which_model_netD,
44 | opt.n_layers_D, opt.norm, use_sigmoid, opt.init_type, opt.init_gain, self.gpu_ids)
45 |
46 | if self.isTrain:
47 | self.fake_AB_pool = ImagePool(opt.pool_size)
48 | # define loss functions
49 | self.criterionGAN = networks.GANLoss(use_lsgan=not opt.no_lsgan).to(self.device)
50 | self.criterionL1 = torch.nn.L1Loss()
51 | self.criterionAngular = angular_loss.angular_loss()
52 |
53 | # initialize optimizers
54 | self.optimizers = []
55 | self.optimizer_G = torch.optim.Adam(self.netG.parameters(),
56 | lr=opt.lr, betas=(opt.beta1, 0.999))
57 | self.optimizer_D = torch.optim.Adam(self.netD.parameters(),
58 | lr=opt.lr, betas=(opt.beta1, 0.999))
59 | self.optimizers.append(self.optimizer_G)
60 | self.optimizers.append(self.optimizer_D)
61 |
62 | def set_input(self, input):
63 | AtoB = self.opt.which_direction == 'AtoB'
64 | self.real_A = input['A' if AtoB else 'B'].to(self.device)
65 | self.real_B = input['B' if AtoB else 'A'].to(self.device)
66 | self.image_paths = input['A_paths' if AtoB else 'B_paths']
67 |
68 | def forward(self):
69 | self.fake_B = self.netG(self.real_A)
70 |
71 | def backward_D(self):
72 | # Fake
73 | # stop backprop to the generator by detaching fake_B
74 | fake_AB = self.fake_AB_pool.query(torch.cat((self.real_A, self.fake_B), 1))
75 | pred_fake = self.netD(fake_AB.detach())
76 | self.loss_D_fake = self.criterionGAN(pred_fake, False)
77 |
78 | # Real
79 | self.eps = torch.tensor(1e-04).to(self.device)  # floor to avoid division by zero
80 | real_AB = torch.cat((self.real_A, torch.div(self.real_A, torch.max(self.real_B, self.eps))), 1)  # v2: the "real" pair uses A / max(B, eps) rather than real_B itself
81 | pred_real = self.netD(real_AB)
82 | self.loss_D_real = self.criterionGAN(pred_real, True)
83 |
84 | # Combined loss
85 | self.loss_D = (self.loss_D_fake + self.loss_D_real) * 0.5
86 |
87 | self.loss_D.backward()
88 |
89 | def backward_G(self):
90 | # First, G(A) should fake the discriminator
91 | fake_AB = torch.cat((self.real_A, self.fake_B), 1)
92 | pred_fake = self.netD(fake_AB)
93 | self.loss_G_GAN = self.criterionGAN(pred_fake, True)
94 |
95 | # Second, G(A) = B
96 | self.eps = torch.tensor(1e-04).to(self.device)
97 | self.loss_G_L1 = self.criterionL1(self.fake_B, torch.div(self.real_A, torch.max(self.real_B, self.eps))) * self.opt.lambda_L1 * 100
98 |
99 | self.illum_gt = self.real_B
100 | self.illum_pred = torch.div(self.real_A, torch.max(self.fake_B, self.eps))
101 | self.loss_G_Ang = self.criterionAngular(self.illum_gt, self.illum_pred) * self.opt.lambda_Angular
102 |
103 | self.loss_G = self.loss_G_GAN + self.loss_G_Ang + self.loss_G_L1
104 |
105 | self.loss_G.backward()
106 |
107 | def optimize_parameters(self):
108 | self.forward()
109 | # update D
110 | self.set_requires_grad(self.netD, True)
111 | self.optimizer_D.zero_grad()
112 | self.backward_D()
113 | self.optimizer_D.step()
114 |
115 | # update G
116 | self.set_requires_grad(self.netD, False)
117 | self.optimizer_G.zero_grad()
118 | self.backward_G()
119 | self.optimizer_G.step()
120 |
--------------------------------------------------------------------------------
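Relative to v1, the v2 losses are built around the same per-pixel identity relating the tinted input, the illumination map, and the corrected image: the L1 term compares `fake_B` against `A / max(B, eps)`, while the angular term compares `A / max(fake_B, eps)` against `real_B`. A self-contained sketch of that identity with placeholder tensors and the same eps floor (values are kept away from zero so the clamp is inactive):

```python
import torch

eps = 1e-4
tinted = torch.rand(1, 3, 256, 256) * 0.9 + 0.1   # stand-in for real_A, bounded away from 0
illum  = torch.rand(1, 3, 256, 256) + 0.5         # stand-in illumination map

corrected = tinted / illum.clamp(min=eps)         # divide out the color cast
recovered = tinted / corrected.clamp(min=eps)     # invert: recover the illumination
print(torch.allclose(recovered, illum))           # True: division is invertible away from the clamp
```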
/angulargan/models/base_model.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | from collections import OrderedDict
4 | from . import networks
5 |
6 |
7 | class BaseModel():
8 |
9 | # modify parser to add command line options,
10 | # and also change the default values if needed
11 | @staticmethod
12 | def modify_commandline_options(parser, is_train):
13 | return parser
14 |
15 | def name(self):
16 | return 'BaseModel'
17 |
18 | def initialize(self, opt):
19 | self.opt = opt
20 | self.gpu_ids = opt.gpu_ids
21 | self.isTrain = opt.isTrain
22 | self.device = torch.device('cuda:{}'.format(self.gpu_ids[0])) if self.gpu_ids else torch.device('cpu')
23 | self.save_dir = os.path.join(opt.checkpoints_dir, opt.name)
24 | if opt.resize_or_crop != 'scale_width':
25 | torch.backends.cudnn.benchmark = True
26 | self.loss_names = []
27 | self.model_names = []
28 | self.visual_names = []
29 | self.image_paths = []
30 |
31 | def set_input(self, input):
32 | self.input = input
33 |
34 | def forward(self):
35 | pass
36 |
37 | # load and print networks; create schedulers
38 | def setup(self, opt, parser=None):
39 | if self.isTrain:
40 | self.schedulers = [networks.get_scheduler(optimizer, opt) for optimizer in self.optimizers]
41 |
42 | if not self.isTrain or opt.continue_train:
43 | self.load_networks(opt.which_epoch)
44 | self.print_networks(opt.verbose)
45 |
46 | # make models eval mode during test time
47 | def eval(self):
48 | for name in self.model_names:
49 | if isinstance(name, str):
50 | net = getattr(self, 'net' + name)
51 | net.eval()
52 |
53 | # used in test time, wrapping `forward` in no_grad() so we don't save
54 | # intermediate steps for backprop
55 | def test(self):
56 | with torch.no_grad():
57 | self.forward()
58 |
59 | # get image paths
60 | def get_image_paths(self):
61 | return self.image_paths
62 |
63 | def optimize_parameters(self):
64 | pass
65 |
66 | # update learning rate (called once every epoch)
67 | def update_learning_rate(self):
68 | for scheduler in self.schedulers:
69 | scheduler.step()
70 | lr = self.optimizers[0].param_groups[0]['lr']
71 | print('learning rate = %.7f' % lr)
72 |
73 | # return visualization images. train.py will display these images, and save them to an HTML file
74 | def get_current_visuals(self):
75 | visual_ret = OrderedDict()
76 | for name in self.visual_names:
77 | if isinstance(name, str):
78 | visual_ret[name] = getattr(self, name)
79 | return visual_ret
80 |
81 | # return training losses/errors. train.py will print out these errors as debugging information
82 | def get_current_losses(self):
83 | errors_ret = OrderedDict()
84 | for name in self.loss_names:
85 | if isinstance(name, str):
86 | # float(...) works for both scalar tensor and float number
87 | errors_ret[name] = float(getattr(self, 'loss_' + name))
88 | return errors_ret
89 |
90 | # save models to the disk
91 | def save_networks(self, which_epoch):
92 | for name in self.model_names:
93 | if isinstance(name, str):
94 | save_filename = '%s_net_%s.pth' % (which_epoch, name)
95 | save_path = os.path.join(self.save_dir, save_filename)
96 | net = getattr(self, 'net' + name)
97 |
98 | if len(self.gpu_ids) > 0 and torch.cuda.is_available():
99 | torch.save(net.module.cpu().state_dict(), save_path)
100 | net.cuda(self.gpu_ids[0])
101 | else:
102 | torch.save(net.cpu().state_dict(), save_path)
103 |
104 | def __patch_instance_norm_state_dict(self, state_dict, module, keys, i=0):
105 | key = keys[i]
106 | if i + 1 == len(keys): # at the end, pointing to a parameter/buffer
107 | if module.__class__.__name__.startswith('InstanceNorm') and \
108 | (key == 'running_mean' or key == 'running_var'):
109 | if getattr(module, key) is None:
110 | state_dict.pop('.'.join(keys))
111 | if module.__class__.__name__.startswith('InstanceNorm') and \
112 | (key == 'num_batches_tracked'):
113 | state_dict.pop('.'.join(keys))
114 | else:
115 | self.__patch_instance_norm_state_dict(state_dict, getattr(module, key), keys, i + 1)
116 |
117 | # load models from the disk
118 | def load_networks(self, which_epoch):
119 | for name in self.model_names:
120 | if isinstance(name, str):
121 | load_filename = '%s_net_%s.pth' % (which_epoch, name)
122 | load_path = os.path.join(self.save_dir, load_filename)
123 | net = getattr(self, 'net' + name)
124 | if isinstance(net, torch.nn.DataParallel):
125 | net = net.module
126 | print('loading the model from %s' % load_path)
127 | # if you are using PyTorch newer than 0.4 (e.g., built from
128 | # GitHub source), you can remove str() on self.device
129 | state_dict = torch.load(load_path, map_location=str(self.device))
130 | if hasattr(state_dict, '_metadata'):
131 | del state_dict._metadata
132 |
133 | # patch InstanceNorm checkpoints prior to 0.4
134 | for key in list(state_dict.keys()): # need to copy keys here because we mutate in loop
135 | self.__patch_instance_norm_state_dict(state_dict, net, key.split('.'))
136 | net.load_state_dict(state_dict)
137 |
138 | # print network information
139 | def print_networks(self, verbose):
140 | print('---------- Networks initialized -------------')
141 | for name in self.model_names:
142 | if isinstance(name, str):
143 | net = getattr(self, 'net' + name)
144 | num_params = 0
145 | for param in net.parameters():
146 | num_params += param.numel()
147 | if verbose:
148 | print(net)
149 | print('[Network %s] Total number of parameters : %.3f M' % (name, num_params / 1e6))
150 | print('-----------------------------------------------')
151 |
152 | # set requires_grad=False to avoid unnecessary computation
153 | def set_requires_grad(self, nets, requires_grad=False):
154 | if not isinstance(nets, list):
155 | nets = [nets]
156 | for net in nets:
157 | if net is not None:
158 | for param in net.parameters():
159 | param.requires_grad = requires_grad
160 |
--------------------------------------------------------------------------------
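`save_networks` and `load_networks` agree on the naming scheme `<which_epoch>_net_<name>.pth` under `checkpoints_dir/name`; a small sketch of the resulting paths (directory and experiment name are placeholders):

```python
import os

checkpoints_dir, name = './checkpoints', 'angular_gan'   # placeholder values
save_dir = os.path.join(checkpoints_dir, name)
for which_epoch in ('latest', '200'):
    for net in ('G', 'D'):
        print(os.path.join(save_dir, '%s_net_%s.pth' % (which_epoch, net)))
# e.g. ./checkpoints/angular_gan/latest_net_G.pth ... ./checkpoints/angular_gan/200_net_D.pth
```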
/angulargan/options/base_options.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | from util import util
4 | import torch
5 | import models
6 | import data
7 |
8 |
9 | class BaseOptions():
10 | def __init__(self):
11 | self.initialized = False
12 |
13 | def initialize(self, parser):
14 | parser.add_argument('--dataroot', required=True, help='path to images (should have subfolders trainA, trainB, valA, valB, etc)')
15 | parser.add_argument('--batchSize', type=int, default=1, help='input batch size')
16 | parser.add_argument('--loadSize', type=int, default=286, help='scale images to this size')
17 | parser.add_argument('--fineSize', type=int, default=256, help='then crop to this size')
18 | parser.add_argument('--input_nc', type=int, default=3, help='# of input image channels')
19 | parser.add_argument('--output_nc', type=int, default=3, help='# of output image channels')
20 | parser.add_argument('--ngf', type=int, default=64, help='# of gen filters in first conv layer')
21 | parser.add_argument('--ndf', type=int, default=64, help='# of discrim filters in first conv layer')
22 | parser.add_argument('--which_model_netD', type=str, default='basic', help='selects model to use for netD')
23 | parser.add_argument('--which_model_netG', type=str, default='resnet_9blocks', help='selects model to use for netG')
24 | parser.add_argument('--n_layers_D', type=int, default=3, help='only used if which_model_netD==n_layers')
25 | parser.add_argument('--gpu_ids', type=str, default='0', help='gpu ids, e.g. "0", "0,1,2", "0,2"; use -1 for CPU')
26 | parser.add_argument('--name', type=str, default='experiment_name', help='name of the experiment. It decides where to store samples and models')
27 | parser.add_argument('--dataset_mode', type=str, default='unaligned', help='chooses how datasets are loaded. [unaligned | aligned | single]')
28 | parser.add_argument('--model', type=str, default='cycle_gan',
29 | help='chooses which model to use. cycle_gan, pix2pix, test')
30 | parser.add_argument('--which_direction', type=str, default='AtoB', help='AtoB or BtoA')
31 | parser.add_argument('--nThreads', default=4, type=int, help='# threads for loading data')
32 | parser.add_argument('--checkpoints_dir', type=str, default='./checkpoints', help='models are saved here')
33 | parser.add_argument('--norm', type=str, default='instance', help='instance normalization or batch normalization')
34 | parser.add_argument('--serial_batches', action='store_true', help='if true, takes images in order to make batches, otherwise takes them randomly')
35 | parser.add_argument('--display_winsize', type=int, default=256, help='display window size')
36 | parser.add_argument('--display_id', type=int, default=1, help='window id of the web display')
37 | parser.add_argument('--display_server', type=str, default="http://localhost", help='visdom server of the web display')
38 | parser.add_argument('--display_env', type=str, default='main', help='visdom display environment name (default is "main")')
39 | parser.add_argument('--display_port', type=int, default=8097, help='visdom port of the web display')
40 | parser.add_argument('--no_dropout', action='store_true', help='no dropout for the generator')
41 | parser.add_argument('--max_dataset_size', type=int, default=float("inf"), help='Maximum number of samples allowed per dataset. If the dataset directory contains more than max_dataset_size, only a subset is loaded.')
42 | parser.add_argument('--resize_or_crop', type=str, default='resize_and_crop', help='scaling and cropping of images at load time [resize_and_crop|crop|scale_width|scale_width_and_crop]')
43 | parser.add_argument('--no_flip', action='store_true', help='if specified, do not flip the images for data augmentation')
44 | parser.add_argument('--init_type', type=str, default='normal', help='network initialization [normal|xavier|kaiming|orthogonal]')
45 | parser.add_argument('--init_gain', type=float, default=0.02, help='scaling factor for normal, xavier and orthogonal.')
46 | parser.add_argument('--verbose', action='store_true', help='if specified, print more debugging information')
47 | parser.add_argument('--suffix', default='', type=str, help='customized suffix: opt.name = opt.name + suffix: e.g., {model}_{which_model_netG}_size{loadSize}')
48 | self.initialized = True
49 | return parser
50 |
51 | def gather_options(self):
52 | # initialize parser with basic options
53 | if not self.initialized:
54 | parser = argparse.ArgumentParser(
55 | formatter_class=argparse.ArgumentDefaultsHelpFormatter)
56 | parser = self.initialize(parser)
57 |
58 | # get the basic options
59 | opt, _ = parser.parse_known_args()
60 |
61 | # modify model-related parser options
62 | model_name = opt.model
63 | model_option_setter = models.get_option_setter(model_name)
64 | parser = model_option_setter(parser, self.isTrain)
65 | opt, _ = parser.parse_known_args() # parse again with the new defaults
66 |
67 | # modify dataset-related parser options
68 | dataset_name = opt.dataset_mode
69 | dataset_option_setter = data.get_option_setter(dataset_name)
70 | parser = dataset_option_setter(parser, self.isTrain)
71 |
72 | self.parser = parser
73 |
74 | return parser.parse_args()
75 |
76 | def print_options(self, opt):
77 | message = ''
78 | message += '----------------- Options ---------------\n'
79 | for k, v in sorted(vars(opt).items()):
80 | comment = ''
81 | default = self.parser.get_default(k)
82 | if v != default:
83 | comment = '\t[default: %s]' % str(default)
84 | message += '{:>25}: {:<30}{}\n'.format(str(k), str(v), comment)
85 | message += '----------------- End -------------------'
86 | print(message)
87 |
88 | # save to the disk
89 | expr_dir = os.path.join(opt.checkpoints_dir, opt.name)
90 | util.mkdirs(expr_dir)
91 | file_name = os.path.join(expr_dir, 'opt.txt')
92 | with open(file_name, 'wt') as opt_file:
93 | opt_file.write(message)
94 | opt_file.write('\n')
95 |
96 | def parse(self):
97 |
98 | opt = self.gather_options()
99 | opt.isTrain = self.isTrain # train or test
100 |
101 | # process opt.suffix
102 | if opt.suffix:
103 | suffix = '_' + opt.suffix.format(**vars(opt))
104 | opt.name = opt.name + suffix
105 |
106 | self.print_options(opt)
107 |
108 | # set gpu ids
109 | str_ids = opt.gpu_ids.split(',')
110 | opt.gpu_ids = []
111 | for str_id in str_ids:
112 | id = int(str_id)
113 | if id >= 0:
114 | opt.gpu_ids.append(id)
115 | if len(opt.gpu_ids) > 0:
116 | torch.cuda.set_device(opt.gpu_ids[0])
117 |
118 | self.opt = opt
119 | return self.opt
120 |
--------------------------------------------------------------------------------
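`gather_options` parses twice so that options (and new defaults) registered by each model's and dataset's `modify_commandline_options` take effect; a minimal standalone demo of that two-pass argparse pattern (the option names here are illustrative, not the real option setters):

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--model', default='angular_gan')
argv = ['--model', 'angular_gan', '--lambda_L1', '2.0']

opt, _ = parser.parse_known_args(argv)      # first pass: unknown flags are tolerated
if opt.model.startswith('angular_gan'):     # stand-in for models.get_option_setter
    parser.add_argument('--lambda_L1', type=float, default=1.0)

print(parser.parse_args(argv).lambda_L1)    # second pass sees the new option: 2.0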
/angulargan/util/visualizer.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import os
3 | import ntpath
4 | import time
5 | from . import util
6 | from . import html
7 | from scipy.misc import imresize  # removed in SciPy >= 1.3; see the Pillow stand-in after this file
8 |
9 |
10 | # save image to the disk
11 | def save_images(webpage, visuals, image_path, aspect_ratio=1.0, width=256):
12 | image_dir = webpage.get_image_dir()
13 | short_path = ntpath.basename(image_path[0])
14 | name = os.path.splitext(short_path)[0]
15 |
16 | webpage.add_header(name)
17 | ims, txts, links = [], [], []
18 |
19 | for label, im_data in visuals.items():
20 | im = util.tensor2im(im_data)
21 | image_name = '%s_%s.png' % (name, label)
22 | save_path = os.path.join(image_dir, image_name)
23 | h, w, _ = im.shape
24 | if aspect_ratio > 1.0:
25 | im = imresize(im, (h, int(w * aspect_ratio)), interp='bicubic')
26 | if aspect_ratio < 1.0:
27 | im = imresize(im, (int(h / aspect_ratio), w), interp='bicubic')
28 | util.save_image(im, save_path)
29 |
30 | ims.append(image_name)
31 | txts.append(label)
32 | links.append(image_name)
33 | webpage.add_images(ims, txts, links, width=width)
34 |
35 |
36 | class Visualizer():
37 | def __init__(self, opt):
38 | self.display_id = opt.display_id
39 | self.use_html = opt.isTrain and not opt.no_html
40 | self.win_size = opt.display_winsize
41 | self.name = opt.name
42 | self.opt = opt
43 | self.saved = False
44 | if self.display_id > 0:
45 | import visdom
46 | self.ncols = opt.display_ncols
47 | self.vis = visdom.Visdom(server=opt.display_server, port=opt.display_port, env=opt.display_env, raise_exceptions=True)
48 |
49 | if self.use_html:
50 | self.web_dir = os.path.join(opt.checkpoints_dir, opt.name, 'web')
51 | self.img_dir = os.path.join(self.web_dir, 'images')
52 | print('create web directory %s...' % self.web_dir)
53 | util.mkdirs([self.web_dir, self.img_dir])
54 | self.log_name = os.path.join(opt.checkpoints_dir, opt.name, 'loss_log.txt')
55 | with open(self.log_name, "a") as log_file:
56 | now = time.strftime("%c")
57 | log_file.write('================ Training Loss (%s) ================\n' % now)
58 |
59 | def reset(self):
60 | self.saved = False
61 |
62 | def throw_visdom_connection_error(self):
63 | print('\n\nCould not connect to Visdom server (https://github.com/facebookresearch/visdom) for displaying training progress.\nYou can suppress the connection to Visdom with the option --display_id -1.\nTo install visdom, run\n$ pip install visdom\nand start the server with\n$ python -m visdom.server\n\n')
64 | exit(1)
65 |
66 | # |visuals|: dictionary of images to display or save
67 | def display_current_results(self, visuals, epoch, save_result):
68 | if self.display_id > 0: # show images in the browser
69 | ncols = self.ncols
70 | if ncols > 0:
71 | ncols = min(ncols, len(visuals))
72 | h, w = next(iter(visuals.values())).shape[-2:]  # image tensors are (N, C, H, W)
73 | table_css = """<style>
74 | table {border-collapse: separate; border-spacing: 4px; white-space: nowrap; text-align: center}
75 | table td {width: %dpx; height: %dpx; padding: 4px; outline: 4px solid black}
76 | </style>""" % (w, h)
77 | title = self.name
78 | label_html = ''
79 | label_html_row = ''
80 | images = []
81 | idx = 0
82 | for label, image in visuals.items():
83 | image_numpy = util.tensor2im(image)
84 | label_html_row += '<td>%s</td>' % label
85 | images.append(image_numpy.transpose([2, 0, 1]))
86 | idx += 1
87 | if idx % ncols == 0:
88 | label_html += '<tr>%s</tr>' % label_html_row
89 | label_html_row = ''
90 | white_image = np.ones_like(image_numpy.transpose([2, 0, 1])) * 255
91 | while idx % ncols != 0:
92 | images.append(white_image)
93 | label_html_row += '<td></td>'
94 | idx += 1
95 | if label_html_row != '':
96 | label_html += '<tr>%s</tr>' % label_html_row
97 | # pane col = image row
98 | try:
99 | self.vis.images(images, nrow=ncols, win=self.display_id + 1,
100 | padding=2, opts=dict(title=title + ' images'))
101 | label_html = '<table>%s</table>' % label_html
102 | self.vis.text(table_css + label_html, win=self.display_id + 2,
103 | opts=dict(title=title + ' labels'))
104 | except ConnectionError:
105 | self.throw_visdom_connection_error()
106 |
107 | else:
108 | idx = 1
109 | for label, image in visuals.items():
110 | image_numpy = util.tensor2im(image)
111 | self.vis.image(image_numpy.transpose([2, 0, 1]), opts=dict(title=label),
112 | win=self.display_id + idx)
113 | idx += 1
114 |
115 | if self.use_html and (save_result or not self.saved): # save images to a html file
116 | self.saved = True
117 | for label, image in visuals.items():
118 | image_numpy = util.tensor2im(image)
119 | img_path = os.path.join(self.img_dir, 'epoch%.3d_%s.png' % (epoch, label))
120 | util.save_image(image_numpy, img_path)
121 | # update website
122 | webpage = html.HTML(self.web_dir, 'Experiment name = %s' % self.name, reflesh=1)
123 | for n in range(epoch, 0, -1):
124 | webpage.add_header('epoch [%d]' % n)
125 | ims, txts, links = [], [], []
126 |
127 | for label, image in visuals.items():
128 | image_numpy = util.tensor2im(image)
129 | img_path = 'epoch%.3d_%s.png' % (n, label)
130 | ims.append(img_path)
131 | txts.append(label)
132 | links.append(img_path)
133 | webpage.add_images(ims, txts, links, width=self.win_size)
134 | webpage.save()
135 |
136 | # losses: dictionary of error labels and values
137 | def plot_current_losses(self, epoch, counter_ratio, opt, losses):
138 | if not hasattr(self, 'plot_data'):
139 | self.plot_data = {'X': [], 'Y': [], 'legend': list(losses.keys())}
140 | self.plot_data['X'].append(epoch + counter_ratio)
141 | self.plot_data['Y'].append([losses[k] for k in self.plot_data['legend']])
142 | try:
143 | self.vis.line(
144 | X=np.stack([np.array(self.plot_data['X'])] * len(self.plot_data['legend']), 1),
145 | Y=np.array(self.plot_data['Y']),
146 | opts={
147 | 'title': self.name + ' loss over time',
148 | 'legend': self.plot_data['legend'],
149 | 'xlabel': 'epoch',
150 | 'ylabel': 'loss'},
151 | win=self.display_id)
152 | except ConnectionError:
153 | self.throw_visdom_connection_error()
154 |
155 | # losses: same format as |losses| of plot_current_losses
156 | def print_current_losses(self, epoch, i, losses, t, t_data):
157 | message = '(epoch: %d, iters: %d, time: %.3f, data: %.3f) ' % (epoch, i, t, t_data)
158 | for k, v in losses.items():
159 | message += '%s: %.3f ' % (k, v)
160 |
161 | print(message)
162 | with open(self.log_name, "a") as log_file:
163 | log_file.write('%s\n' % message)
164 |
--------------------------------------------------------------------------------
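The `scipy.misc.imresize` import at the top of `visualizer.py` was removed in SciPy 1.3, so the file no longer runs against current SciPy. Below is a Pillow-based stand-in with the same `(h, w)` size convention as the calls in `save_images`, offered as a sketch rather than part of the original code:

```python
import numpy as np
from PIL import Image

def imresize(im, size, interp='bicubic'):
    """Pillow-based stand-in for scipy.misc.imresize (uint8 HxWxC input)."""
    resample = {'bicubic': Image.BICUBIC, 'bilinear': Image.BILINEAR,
                'nearest': Image.NEAREST}[interp]
    h, w = size                                   # scipy used (height, width)
    return np.asarray(Image.fromarray(im).resize((w, h), resample))
```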
/angulargan/models/networks.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.nn import init
4 | import functools
5 | from torch.optim import lr_scheduler
6 |
7 | ###############################################################################
8 | # Helper Functions
9 | ###############################################################################
10 |
11 |
12 | def get_norm_layer(norm_type='instance'):
13 | if norm_type == 'batch':
14 | norm_layer = functools.partial(nn.BatchNorm2d, affine=True)
15 | elif norm_type == 'instance':
16 | norm_layer = functools.partial(nn.InstanceNorm2d, affine=False, track_running_stats=True)
17 | elif norm_type == 'none':
18 | norm_layer = None
19 | else:
20 | raise NotImplementedError('normalization layer [%s] is not found' % norm_type)
21 | return norm_layer
22 |
23 |
24 | def get_scheduler(optimizer, opt):
25 | if opt.lr_policy == 'lambda':
26 | def lambda_rule(epoch):
27 | lr_l = 1.0 - max(0, epoch + 1 + opt.epoch_count - opt.niter) / float(opt.niter_decay + 1)
28 | return lr_l
29 | scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule)
30 | elif opt.lr_policy == 'step':
31 | scheduler = lr_scheduler.StepLR(optimizer, step_size=opt.lr_decay_iters, gamma=0.1)
32 | elif opt.lr_policy == 'plateau':
33 | scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.2, threshold=0.01, patience=5)
34 | else:
35 | raise NotImplementedError('learning rate policy [%s] is not implemented' % opt.lr_policy)
36 | return scheduler
37 |
38 |
39 | def init_weights(net, init_type='normal', gain=0.02):
40 | def init_func(m):
41 | classname = m.__class__.__name__
42 | if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):
43 | if init_type == 'normal':
44 | init.normal_(m.weight.data, 0.0, gain)
45 | elif init_type == 'xavier':
46 | init.xavier_normal_(m.weight.data, gain=gain)
47 | elif init_type == 'kaiming':
48 | init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
49 | elif init_type == 'orthogonal':
50 | init.orthogonal_(m.weight.data, gain=gain)
51 | else:
52 | raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
53 | if hasattr(m, 'bias') and m.bias is not None:
54 | init.constant_(m.bias.data, 0.0)
55 | elif classname.find('BatchNorm2d') != -1:
56 | init.normal_(m.weight.data, 1.0, gain)
57 | init.constant_(m.bias.data, 0.0)
58 |
59 | print('initialize network with %s' % init_type)
60 | net.apply(init_func)
61 |
62 |
63 | def init_net(net, init_type='normal', init_gain=0.02, gpu_ids=[]):
64 | if len(gpu_ids) > 0:
65 | assert(torch.cuda.is_available())
66 | net.to(gpu_ids[0])
67 | net = torch.nn.DataParallel(net, gpu_ids)
68 | init_weights(net, init_type, gain=init_gain)
69 | return net
70 |
71 |
72 | def define_G(input_nc, output_nc, ngf, which_model_netG, norm='batch', use_dropout=False, init_type='normal', init_gain=0.02, gpu_ids=[]):
73 | netG = None
74 | norm_layer = get_norm_layer(norm_type=norm)
75 |
76 | if which_model_netG == 'resnet_9blocks':
77 | netG = ResnetGenerator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, n_blocks=9)
78 | elif which_model_netG == 'resnet_6blocks':
79 | netG = ResnetGenerator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, n_blocks=6)
80 | elif which_model_netG == 'unet_128':
81 | netG = UnetGenerator(input_nc, output_nc, 7, ngf, norm_layer=norm_layer, use_dropout=use_dropout)
82 | elif which_model_netG == 'unet_256':
83 | netG = UnetGenerator(input_nc, output_nc, 8, ngf, norm_layer=norm_layer, use_dropout=use_dropout)
84 | else:
85 | raise NotImplementedError('Generator model name [%s] is not recognized' % which_model_netG)
86 | return init_net(netG, init_type, init_gain, gpu_ids)
87 |
88 |
89 | def define_D(input_nc, ndf, which_model_netD,
90 | n_layers_D=3, norm='batch', use_sigmoid=False, init_type='normal', init_gain=0.02, gpu_ids=[]):
91 | netD = None
92 | norm_layer = get_norm_layer(norm_type=norm)
93 |
94 | if which_model_netD == 'basic':
95 | netD = NLayerDiscriminator(input_nc, ndf, n_layers=3, norm_layer=norm_layer, use_sigmoid=use_sigmoid)
96 | elif which_model_netD == 'n_layers':
97 | netD = NLayerDiscriminator(input_nc, ndf, n_layers_D, norm_layer=norm_layer, use_sigmoid=use_sigmoid)
98 | elif which_model_netD == 'pixel':
99 | netD = PixelDiscriminator(input_nc, ndf, norm_layer=norm_layer, use_sigmoid=use_sigmoid)
100 | else:
101 | raise NotImplementedError('Discriminator model name [%s] is not recognized' %
102 | which_model_netD)
103 | return init_net(netD, init_type, init_gain, gpu_ids)
104 |
105 |
106 | ##############################################################################
107 | # Classes
108 | ##############################################################################
109 |
110 |
111 | # Defines the GAN loss which uses either LSGAN or the regular GAN.
112 | # When LSGAN is used, it is basically the same as MSELoss,
113 | # but it abstracts away the need to create the target label tensor
114 | # that has the same size as the input
115 | class GANLoss(nn.Module):
116 | def __init__(self, use_lsgan=True, target_real_label=1.0, target_fake_label=0.0):
117 | super(GANLoss, self).__init__()
118 | self.register_buffer('real_label', torch.tensor(target_real_label))
119 | self.register_buffer('fake_label', torch.tensor(target_fake_label))
120 | if use_lsgan:
121 | self.loss = nn.MSELoss()
122 | else:
123 | self.loss = nn.BCELoss()
124 |
125 | def get_target_tensor(self, input, target_is_real):
126 | if target_is_real:
127 | target_tensor = self.real_label
128 | else:
129 | target_tensor = self.fake_label
130 | return target_tensor.expand_as(input)
131 |
132 | def __call__(self, input, target_is_real):
133 | target_tensor = self.get_target_tensor(input, target_is_real)
134 | return self.loss(input, target_tensor)
135 |
136 |
137 | # Defines the generator that consists of Resnet blocks between a few
138 | # downsampling/upsampling operations.
139 | # Code and idea originally from Justin Johnson's architecture.
140 | # https://github.com/jcjohnson/fast-neural-style/
141 | class ResnetGenerator(nn.Module):
142 | def __init__(self, input_nc, output_nc, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False, n_blocks=6, padding_type='reflect'):
143 | assert(n_blocks >= 0)
144 | super(ResnetGenerator, self).__init__()
145 | self.input_nc = input_nc
146 | self.output_nc = output_nc
147 | self.ngf = ngf
148 | if type(norm_layer) == functools.partial:
149 | use_bias = norm_layer.func == nn.InstanceNorm2d
150 | else:
151 | use_bias = norm_layer == nn.InstanceNorm2d
152 |
153 | model = [nn.ReflectionPad2d(3),
154 | nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0,
155 | bias=use_bias),
156 | norm_layer(ngf),
157 | nn.ReLU(True)]
158 |
159 | n_downsampling = 2
160 | for i in range(n_downsampling):
161 | mult = 2**i
162 | model += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3,
163 | stride=2, padding=1, bias=use_bias),
164 | norm_layer(ngf * mult * 2),
165 | nn.ReLU(True)]
166 |
167 | mult = 2**n_downsampling
168 | for i in range(n_blocks):
169 | model += [ResnetBlock(ngf * mult, padding_type=padding_type, norm_layer=norm_layer, use_dropout=use_dropout, use_bias=use_bias)]
170 |
171 | for i in range(n_downsampling):
172 | mult = 2**(n_downsampling - i)
173 | model += [nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2),
174 | kernel_size=3, stride=2,
175 | padding=1, output_padding=1,
176 | bias=use_bias),
177 | norm_layer(int(ngf * mult / 2)),
178 | nn.ReLU(True)]
179 | model += [nn.ReflectionPad2d(3)]
180 | model += [nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)]
181 | model += [nn.Tanh()]
182 |
183 | self.model = nn.Sequential(*model)
184 |
185 | def forward(self, input):
186 | return self.model(input)
187 |
188 |
189 | # Define a resnet block
190 | class ResnetBlock(nn.Module):
191 | def __init__(self, dim, padding_type, norm_layer, use_dropout, use_bias):
192 | super(ResnetBlock, self).__init__()
193 | self.conv_block = self.build_conv_block(dim, padding_type, norm_layer, use_dropout, use_bias)
194 |
195 | def build_conv_block(self, dim, padding_type, norm_layer, use_dropout, use_bias):
196 | conv_block = []
197 | p = 0
198 | if padding_type == 'reflect':
199 | conv_block += [nn.ReflectionPad2d(1)]
200 | elif padding_type == 'replicate':
201 | conv_block += [nn.ReplicationPad2d(1)]
202 | elif padding_type == 'zero':
203 | p = 1
204 | else:
205 | raise NotImplementedError('padding [%s] is not implemented' % padding_type)
206 |
207 | conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias),
208 | norm_layer(dim),
209 | nn.ReLU(True)]
210 | if use_dropout:
211 | conv_block += [nn.Dropout(0.5)]
212 |
213 | p = 0
214 | if padding_type == 'reflect':
215 | conv_block += [nn.ReflectionPad2d(1)]
216 | elif padding_type == 'replicate':
217 | conv_block += [nn.ReplicationPad2d(1)]
218 | elif padding_type == 'zero':
219 | p = 1
220 | else:
221 | raise NotImplementedError('padding [%s] is not implemented' % padding_type)
222 | conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias),
223 | norm_layer(dim)]
224 |
225 | return nn.Sequential(*conv_block)
226 |
227 | def forward(self, x):
228 | out = x + self.conv_block(x)
229 | return out
230 |
231 |
232 | # Defines the Unet generator.
233 | # |num_downs|: number of downsamplings in UNet. For example,
234 | # if |num_downs| == 7, image of size 128x128 will become of size 1x1
235 | # at the bottleneck
236 | class UnetGenerator(nn.Module):
237 | def __init__(self, input_nc, output_nc, num_downs, ngf=64,
238 | norm_layer=nn.BatchNorm2d, use_dropout=False):
239 | super(UnetGenerator, self).__init__()
240 |
241 | # construct unet structure
242 | unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=None, norm_layer=norm_layer, innermost=True)
243 | for i in range(num_downs - 5):
244 | unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer, use_dropout=use_dropout)
245 | unet_block = UnetSkipConnectionBlock(ngf * 4, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
246 | unet_block = UnetSkipConnectionBlock(ngf * 2, ngf * 4, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
247 | unet_block = UnetSkipConnectionBlock(ngf, ngf * 2, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
248 | unet_block = UnetSkipConnectionBlock(output_nc, ngf, input_nc=input_nc, submodule=unet_block, outermost=True, norm_layer=norm_layer)
249 |
250 | self.model = unet_block
251 |
252 | def forward(self, input):
253 | return self.model(input)
254 |
255 |
256 | # Defines the submodule with skip connection.
257 | # X -------------------identity---------------------- X
258 | # |-- downsampling -- |submodule| -- upsampling --|
259 | class UnetSkipConnectionBlock(nn.Module):
260 | def __init__(self, outer_nc, inner_nc, input_nc=None,
261 | submodule=None, outermost=False, innermost=False, norm_layer=nn.BatchNorm2d, use_dropout=False):
262 | super(UnetSkipConnectionBlock, self).__init__()
263 | self.outermost = outermost
264 | if type(norm_layer) == functools.partial:
265 | use_bias = norm_layer.func == nn.InstanceNorm2d
266 | else:
267 | use_bias = norm_layer == nn.InstanceNorm2d
268 | if input_nc is None:
269 | input_nc = outer_nc
270 | downconv = nn.Conv2d(input_nc, inner_nc, kernel_size=4,
271 | stride=2, padding=1, bias=use_bias)
272 | downrelu = nn.LeakyReLU(0.2, True)
273 | downnorm = norm_layer(inner_nc)
274 | uprelu = nn.ReLU(True)
275 | upnorm = norm_layer(outer_nc)
276 |
277 | if outermost:
278 | upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc,
279 | kernel_size=4, stride=2,
280 | padding=1)
281 | down = [downconv]
282 | up = [uprelu, upconv, nn.Tanh()]
283 | model = down + [submodule] + up
284 | elif innermost:
285 | upconv = nn.ConvTranspose2d(inner_nc, outer_nc,
286 | kernel_size=4, stride=2,
287 | padding=1, bias=use_bias)
288 | down = [downrelu, downconv]
289 | up = [uprelu, upconv, upnorm]
290 | model = down + up
291 | else:
292 | upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc,
293 | kernel_size=4, stride=2,
294 | padding=1, bias=use_bias)
295 | down = [downrelu, downconv, downnorm]
296 | up = [uprelu, upconv, upnorm]
297 |
298 | if use_dropout:
299 | model = down + [submodule] + up + [nn.Dropout(0.5)]
300 | else:
301 | model = down + [submodule] + up
302 |
303 | self.model = nn.Sequential(*model)
304 |
305 | def forward(self, x):
306 | if self.outermost:
307 | return self.model(x)
308 | else:
309 | return torch.cat([x, self.model(x)], 1)
310 |
311 |
312 | # Defines the PatchGAN discriminator with the specified arguments.
313 | class NLayerDiscriminator(nn.Module):
314 | def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d, use_sigmoid=False):
315 | super(NLayerDiscriminator, self).__init__()
316 | if type(norm_layer) == functools.partial:
317 | use_bias = norm_layer.func == nn.InstanceNorm2d
318 | else:
319 | use_bias = norm_layer == nn.InstanceNorm2d
320 |
321 | kw = 4
322 | padw = 1
323 | sequence = [
324 | nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw),
325 | nn.LeakyReLU(0.2, True)
326 | ]
327 |
328 | nf_mult = 1
329 | nf_mult_prev = 1
330 | for n in range(1, n_layers):
331 | nf_mult_prev = nf_mult
332 | nf_mult = min(2**n, 8)
333 | sequence += [
334 | nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult,
335 | kernel_size=kw, stride=2, padding=padw, bias=use_bias),
336 | norm_layer(ndf * nf_mult),
337 | nn.LeakyReLU(0.2, True)
338 | ]
339 |
340 | nf_mult_prev = nf_mult
341 | nf_mult = min(2**n_layers, 8)
342 | sequence += [
343 | nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult,
344 | kernel_size=kw, stride=1, padding=padw, bias=use_bias),
345 | norm_layer(ndf * nf_mult),
346 | nn.LeakyReLU(0.2, True)
347 | ]
348 |
349 | sequence += [nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)]
350 |
351 | if use_sigmoid:
352 | sequence += [nn.Sigmoid()]
353 |
354 | self.model = nn.Sequential(*sequence)
355 |
356 | def forward(self, input):
357 | return self.model(input)
358 |
359 |
360 | class PixelDiscriminator(nn.Module):
361 | def __init__(self, input_nc, ndf=64, norm_layer=nn.BatchNorm2d, use_sigmoid=False):
362 | super(PixelDiscriminator, self).__init__()
363 | if type(norm_layer) == functools.partial:
364 | use_bias = norm_layer.func == nn.InstanceNorm2d
365 | else:
366 | use_bias = norm_layer == nn.InstanceNorm2d
367 |
368 | self.net = [
369 | nn.Conv2d(input_nc, ndf, kernel_size=1, stride=1, padding=0),
370 | nn.LeakyReLU(0.2, True),
371 | nn.Conv2d(ndf, ndf * 2, kernel_size=1, stride=1, padding=0, bias=use_bias),
372 | norm_layer(ndf * 2),
373 | nn.LeakyReLU(0.2, True),
374 | nn.Conv2d(ndf * 2, 1, kernel_size=1, stride=1, padding=0, bias=use_bias)]
375 |
376 | if use_sigmoid:
377 | self.net.append(nn.Sigmoid())
378 |
379 | self.net = nn.Sequential(*self.net)
380 |
381 | def forward(self, input):
382 | return self.net(input)
383 |
--------------------------------------------------------------------------------
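A quick smoke test of the generator factory above, assuming the `models` package layout shown in this repo; it builds the `unet_256` generator that the angular GAN models default to and runs a 256x256 tensor through it on the CPU:

```python
import torch
from models.networks import define_G

netG = define_G(input_nc=3, output_nc=3, ngf=64, which_model_netG='unet_256',
                norm='batch', use_dropout=True, gpu_ids=[])
with torch.no_grad():
    out = netG(torch.rand(1, 3, 256, 256))
print(out.shape)  # torch.Size([1, 3, 256, 256])
```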