├── data ├── __init__.py ├── data_loader.py ├── base_data_loader.py ├── custom_dataset_data_loader.py ├── image_folder.py ├── base_dataset.py └── aligned_dataset.py ├── models ├── __init__.py ├── models.py ├── base_model.py ├── pix2pixHD_model.py ├── ui_model.py └── networks.py ├── util ├── __init__.py ├── image_pool.py ├── html.py ├── util.py └── visualizer.py ├── options ├── __init__.py ├── test_options.py ├── train_options.py └── base_options.py ├── _config.yml ├── .gitattributes ├── requirements.txt ├── .gitignore ├── personal_scripts ├── clear_data.py └── preprocess.py ├── LICENSE ├── precompute_feature_maps.py ├── encode_features.py ├── test.py ├── train.py ├── README.md └── run_engine.py /data/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /util/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /options/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /_config.yml: -------------------------------------------------------------------------------- 1 | theme: jekyll-theme-minimal -------------------------------------------------------------------------------- /.gitattributes: -------------------------------------------------------------------------------- 1 | # Auto detect text files and perform LF normalization 2 | * text=auto 3 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | # To install these Python dependencies, please type: 2 | # pip install -r requirements.txt 3 | 4 | streamlit 5 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | **/checkpoints 2 | **/datasets 3 | **/apex 4 | **/results 5 | **/raw_data 6 | **/imgs 7 | *.pyc 8 | *.prototxt 9 | *.tex 10 | -------------------------------------------------------------------------------- /data/data_loader.py: -------------------------------------------------------------------------------- 1 | 2 | def CreateDataLoader(opt): 3 | from data.custom_dataset_data_loader import CustomDatasetDataLoader 4 | data_loader = CustomDatasetDataLoader() 5 | print(data_loader.name()) 6 | data_loader.initialize(opt) 7 | return data_loader 8 | -------------------------------------------------------------------------------- /data/base_data_loader.py: -------------------------------------------------------------------------------- 1 | 2 | class BaseDataLoader(): 3 | def __init__(self): 4 | pass 5 | 6 | def initialize(self, opt): 7 | self.opt = opt 8 | 9 | 10 | def load_data(self): 11 | return None 12 | 13 | 14 | 15 | -------------------------------------------------------------------------------- /models/models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | def create_model(opt): 4 | if opt.model == 'pix2pixHD': 5 | from .pix2pixHD_model import Pix2PixHDModel, InferenceModel 6 | if opt.isTrain: 7 | model = Pix2PixHDModel() 8 | else: 9 | model = InferenceModel() 10 | else: 11 | from .ui_model import UIModel 12 | model = UIModel() 13 | model.initialize(opt) 14 | if opt.verbose: 15 | print("model [%s] was created" % (model.name())) 16 | 17 | if opt.isTrain and len(opt.gpu_ids) and not opt.fp16: 18 | model = torch.nn.DataParallel(model, device_ids=opt.gpu_ids) 19 | 20 | return model 21 | --------------------------------------------------------------------------------
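# --- Editor's sketch (not a repository file): CreateDataLoader and create_model
# --- above are the only entry points the scripts use; this wiring is condensed
# --- from train.py and test.py, which appear later in this listing.
from options.train_options import TrainOptions
from data.data_loader import CreateDataLoader
from models.models import create_model

opt = TrainOptions().parse()
data_loader = CreateDataLoader(opt)   # builds a CustomDatasetDataLoader
dataset = data_loader.load_data()     # the wrapped torch.utils.data.DataLoader
model = create_model(opt)             # Pix2PixHDModel when training, InferenceModel otherwise

for data in dataset:
    # each batch is the dict produced by AlignedDataset:
    # {'label', 'inst', 'image', 'feat', 'path'}
    break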
/personal_scripts/clear_data.py: -------------------------------------------------------------------------------- 1 | import os, argparse 2 | 3 | def main(): 4 | parser = argparse.ArgumentParser(description='Remove all files with the given extension') 5 | parser.add_argument('--fold_A', dest='fold_A', help='input directory for image A', type=str, default='./A') 6 | parser.add_argument('--fold_B', dest='fold_B', help='input directory for image B', type=str, default='./B') 7 | parser.add_argument('--ext', dest='ext', help='file extension you want removed', type=str, default='.png') 8 | args = parser.parse_args() 9 | for root, dirs, files in os.walk(args.fold_A): 10 | for file in files: 11 | if file.endswith(args.ext): 12 | os.remove(os.path.join(root,file)) 13 | for root, dirs, files in os.walk(args.fold_B): 14 | for file in files: 15 | if file.endswith(args.ext): 16 | os.remove(os.path.join(root,file)) 17 | print(args.ext + " removed") 18 | 19 | if __name__ == "__main__": 20 | main() 21 | -------------------------------------------------------------------------------- /data/custom_dataset_data_loader.py: -------------------------------------------------------------------------------- 1 | import torch.utils.data 2 | from data.base_data_loader import BaseDataLoader 3 | 4 | 5 | def CreateDataset(opt): 6 | dataset = None 7 | from data.aligned_dataset import AlignedDataset 8 | dataset = AlignedDataset() 9 | 10 | print("dataset [%s] was created" % (dataset.name())) 11 | dataset.initialize(opt) 12 | return dataset 13 | 14 | class CustomDatasetDataLoader(BaseDataLoader): 15 | def name(self): 16 | return 'CustomDatasetDataLoader' 17 | 18 | def initialize(self, opt): 19 | BaseDataLoader.initialize(self, opt) 20 | self.dataset = CreateDataset(opt) 21 | self.dataloader = torch.utils.data.DataLoader( 22 | self.dataset, 23 | batch_size=opt.batchSize, 24 | shuffle=not opt.serial_batches, 25 | num_workers=int(opt.nThreads)) 26 | 27 | def load_data(self): 28 | return self.dataloader 29 | 30 | def __len__(self): 31 | return min(len(self.dataset), self.opt.max_dataset_size) 32 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Thomas Huang 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /util/image_pool.py: -------------------------------------------------------------------------------- 1 | import random 2 | import torch 3 | from torch.autograd import Variable 4 | class ImagePool(): 5 | def __init__(self, pool_size): 6 | self.pool_size = pool_size 7 | if self.pool_size > 0: 8 | self.num_imgs = 0 9 | self.images = [] 10 | 11 | def query(self, images): 12 | if self.pool_size == 0: 13 | return images 14 | return_images = [] 15 | for image in images.data: 16 | image = torch.unsqueeze(image, 0) 17 | if self.num_imgs < self.pool_size: 18 | self.num_imgs = self.num_imgs + 1 19 | self.images.append(image) 20 | return_images.append(image) 21 | else: 22 | p = random.uniform(0, 1) 23 | if p > 0.5: 24 | random_id = random.randint(0, self.pool_size-1) 25 | tmp = self.images[random_id].clone() 26 | self.images[random_id] = image 27 | return_images.append(tmp) 28 | else: 29 | return_images.append(image) 30 | return_images = Variable(torch.cat(return_images, 0)) 31 | return return_images 32 | -------------------------------------------------------------------------------- /precompute_feature_maps.py: -------------------------------------------------------------------------------- 1 | from options.train_options import TrainOptions 2 | from data.data_loader import CreateDataLoader 3 | from models.models import create_model 4 | import os 5 | import util.util as util 6 | from torch.autograd import Variable 7 | import torch.nn as nn 8 | 9 | opt = TrainOptions().parse() 10 | opt.nThreads = 1 11 | opt.batchSize = 1 12 | opt.serial_batches = True 13 | opt.no_flip = True 14 | opt.instance_feat = True 15 | 16 | name = 'features' 17 | save_path = os.path.join(opt.checkpoints_dir, opt.name) 18 | 19 | ############ Initialize ######### 20 | data_loader = CreateDataLoader(opt) 21 | dataset = data_loader.load_data() 22 | dataset_size = len(data_loader) 23 | model = create_model(opt) 24 | util.mkdirs(os.path.join(opt.dataroot, opt.phase + '_feat')) 25 | 26 | ######## Save precomputed feature maps for 1024p training ####### 27 | for i, data in enumerate(dataset): 28 | print('%d / %d images' % (i+1, dataset_size)) 29 | feat_map = model.module.netE.forward(Variable(data['image'].cuda(), volatile=True), data['inst'].cuda()) 30 | feat_map = nn.Upsample(scale_factor=2, mode='nearest')(feat_map) 31 | image_numpy = util.tensor2im(feat_map.data[0]) 32 | save_path = data['path'][0].replace('/train_label/', '/train_feat/') 33 | util.save_image(image_numpy, save_path) 34 | -------------------------------------------------------------------------------- /options/test_options.py: -------------------------------------------------------------------------------- 1 | from .base_options import BaseOptions 2 | 3 | class TestOptions(BaseOptions): 4 | def initialize(self): 5 | BaseOptions.initialize(self) 6 | self.parser.add_argument('--ntest', type=int, default=float("inf"), help='# of test examples.') 7 | self.parser.add_argument('--results_dir', 
type=str, default='./results/', help='saves results here.') 8 | self.parser.add_argument('--aspect_ratio', type=float, default=1.0, help='aspect ratio of result images') 9 | self.parser.add_argument('--phase', type=str, default='test', help='train, val, test, etc') 10 | self.parser.add_argument('--which_epoch', type=str, default='latest', help='which epoch to load? set to latest to use latest cached model') 11 | self.parser.add_argument('--how_many', type=int, default=50, help='how many test images to run') 12 | self.parser.add_argument('--cluster_path', type=str, default='features_clustered_010.npy', help='the path for clustered results of encoded features') 13 | self.parser.add_argument('--use_encoded_image', action='store_true', help='if specified, encode the real image to get the feature map') 14 | self.parser.add_argument("--export_onnx", type=str, help="export ONNX model to a given file") 15 | self.parser.add_argument("--engine", type=str, help="run serialized TRT engine") 16 | self.parser.add_argument("--onnx", type=str, help="run ONNX model via TRT") 17 | self.isTrain = False 18 | -------------------------------------------------------------------------------- /encode_features.py: -------------------------------------------------------------------------------- 1 | from options.train_options import TrainOptions 2 | from data.data_loader import CreateDataLoader 3 | from models.models import create_model 4 | import numpy as np 5 | import os 6 | 7 | opt = TrainOptions().parse() 8 | opt.nThreads = 1 9 | opt.batchSize = 1 10 | opt.serial_batches = True 11 | opt.no_flip = True 12 | opt.instance_feat = True 13 | opt.continue_train = True 14 | 15 | name = 'features' 16 | save_path = os.path.join(opt.checkpoints_dir, opt.name) 17 | 18 | ############ Initialize ######### 19 | data_loader = CreateDataLoader(opt) 20 | dataset = data_loader.load_data() 21 | dataset_size = len(data_loader) 22 | model = create_model(opt) 23 | 24 | ########### Encode features ########### 25 | reencode = True 26 | if reencode: 27 | features = {} 28 | for label in range(opt.label_nc): 29 | features[label] = np.zeros((0, opt.feat_num+1)) 30 | for i, data in enumerate(dataset): 31 | feat = model.module.encode_features(data['image'], data['inst']) 32 | for label in range(opt.label_nc): 33 | features[label] = np.append(features[label], feat[label], axis=0) 34 | 35 | print('%d / %d images' % (i+1, dataset_size)) 36 | save_name = os.path.join(save_path, name + '.npy') 37 | np.save(save_name, features) 38 | 39 | ############## Clustering ########### 40 | n_clusters = opt.n_clusters 41 | load_name = os.path.join(save_path, name + '.npy') 42 | features = np.load(load_name).item() 43 | from sklearn.cluster import KMeans 44 | centers = {} 45 | for label in range(opt.label_nc): 46 | feat = features[label] 47 | feat = feat[feat[:,-1] > 0.5, :-1] 48 | if feat.shape[0]: 49 | n_clusters = min(feat.shape[0], opt.n_clusters) 50 | kmeans = KMeans(n_clusters=n_clusters, random_state=0).fit(feat) 51 | centers[label] = kmeans.cluster_centers_ 52 | save_name = os.path.join(save_path, name + '_clustered_%03d.npy' % opt.n_clusters) 53 | np.save(save_name, centers) 54 | print('saving to %s' % save_name) 55 | -------------------------------------------------------------------------------- /util/html.py: -------------------------------------------------------------------------------- 1 | import dominate 2 | from dominate.tags import * 3 | import os 4 | 5 | 6 | class HTML: 7 | def __init__(self, web_dir, title, refresh=0): 8 | self.title = title 
9 | self.web_dir = web_dir 10 | self.img_dir = os.path.join(self.web_dir, 'images') 11 | if not os.path.exists(self.web_dir): 12 | os.makedirs(self.web_dir) 13 | if not os.path.exists(self.img_dir): 14 | os.makedirs(self.img_dir) 15 | 16 | self.doc = dominate.document(title=title) 17 | if refresh > 0: 18 | with self.doc.head: 19 | meta(http_equiv="refresh", content=str(refresh)) 20 | 21 | def get_image_dir(self): 22 | return self.img_dir 23 | 24 | def add_header(self, str): 25 | with self.doc: 26 | h3(str) 27 | 28 | def add_table(self, border=1): 29 | self.t = table(border=border, style="table-layout: fixed;") 30 | self.doc.add(self.t) 31 | 32 | def add_images(self, ims, txts, links, width=512): 33 | self.add_table() 34 | with self.t: 35 | with tr(): 36 | for im, txt, link in zip(ims, txts, links): 37 | with td(style="word-wrap: break-word;", halign="center", valign="top"): 38 | with p(): 39 | with a(href=os.path.join('images', link)): 40 | img(style="width:%dpx" % (width), src=os.path.join('images', im)) 41 | br() 42 | p(txt) 43 | 44 | def save(self): 45 | html_file = '%s/index.html' % self.web_dir 46 | f = open(html_file, 'wt') 47 | f.write(self.doc.render()) 48 | f.close() 49 | 50 | 51 | if __name__ == '__main__': 52 | html = HTML('web/', 'test_html') 53 | html.add_header('hello world') 54 | 55 | ims = [] 56 | txts = [] 57 | links = [] 58 | for n in range(4): 59 | ims.append('image_%d.jpg' % n) 60 | txts.append('text_%d' % n) 61 | links.append('image_%d.jpg' % n) 62 | html.add_images(ims, txts, links) 63 | html.save() 64 | -------------------------------------------------------------------------------- /data/image_folder.py: -------------------------------------------------------------------------------- 1 | ############################################################################### 2 | # Code from 3 | # https://github.com/pytorch/vision/blob/master/torchvision/datasets/folder.py 4 | # Modified the original code so that it also loads images from the current 5 | # directory as well as the subdirectories 6 | ############################################################################### 7 | import torch.utils.data as data 8 | from PIL import Image 9 | import os 10 | 11 | IMG_EXTENSIONS = [ 12 | '.jpg', '.JPG', '.jpeg', '.JPEG', 13 | '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', '.tiff' 14 | ] 15 | 16 | 17 | def is_image_file(filename): 18 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS) 19 | 20 | 21 | def make_dataset(dir): 22 | images = [] 23 | assert os.path.isdir(dir), '%s is not a valid directory' % dir 24 | 25 | for root, _, fnames in sorted(os.walk(dir)): 26 | for fname in fnames: 27 | if is_image_file(fname): 28 | path = os.path.join(root, fname) 29 | images.append(path) 30 | 31 | return images 32 | 33 | 34 | def default_loader(path): 35 | return Image.open(path).convert('RGB') 36 | 37 | 38 | class ImageFolder(data.Dataset): 39 | 40 | def __init__(self, root, transform=None, return_paths=False, 41 | loader=default_loader): 42 | imgs = make_dataset(root) 43 | if len(imgs) == 0: 44 | raise(RuntimeError("Found 0 images in: " + root + "\n" 45 | "Supported image extensions are: " + 46 | ",".join(IMG_EXTENSIONS))) 47 | 48 | self.root = root 49 | self.imgs = imgs 50 | self.transform = transform 51 | self.return_paths = return_paths 52 | self.loader = loader 53 | 54 | def __getitem__(self, index): 55 | path = self.imgs[index] 56 | img = self.loader(path) 57 | if self.transform is not None: 58 | img = self.transform(img) 59 | if self.return_paths: 60 | return 
img, path 61 | else: 62 | return img 63 | 64 | def __len__(self): 65 | return len(self.imgs) 66 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import os 2 | from collections import OrderedDict 3 | from torch.autograd import Variable 4 | from options.test_options import TestOptions 5 | from data.data_loader import CreateDataLoader 6 | from models.models import create_model 7 | import util.util as util 8 | from util.visualizer import Visualizer 9 | from util import html 10 | import torch 11 | 12 | opt = TestOptions().parse(save=False) 13 | opt.nThreads = 1 # test code only supports nThreads = 1 14 | opt.batchSize = 1 # test code only supports batchSize = 1 15 | opt.serial_batches = True # no shuffle 16 | opt.no_flip = True # no flip 17 | 18 | data_loader = CreateDataLoader(opt) 19 | dataset = data_loader.load_data() 20 | visualizer = Visualizer(opt) 21 | # create website 22 | web_dir = os.path.join(opt.results_dir, opt.name, '%s_%s' % (opt.phase, opt.which_epoch)) 23 | webpage = html.HTML(web_dir, 'Experiment = %s, Phase = %s, Epoch = %s' % (opt.name, opt.phase, opt.which_epoch)) 24 | 25 | # test 26 | if not opt.engine and not opt.onnx: 27 | model = create_model(opt) 28 | if opt.data_type == 16: 29 | model.half() 30 | elif opt.data_type == 8: 31 | model.type(torch.uint8) 32 | 33 | if opt.verbose: 34 | print(model) 35 | else: 36 | from run_engine import run_trt_engine, run_onnx 37 | 38 | for i, data in enumerate(dataset): 39 | if i >= opt.how_many: 40 | break 41 | if opt.data_type == 16: 42 | data['label'] = data['label'].half() 43 | data['inst'] = data['inst'].half() 44 | elif opt.data_type == 8: 45 | data['label'] = data['label'].byte() # .byte() casts to uint8; torch tensors have no .uint8() method 46 | data['inst'] = data['inst'].byte() 47 | if opt.export_onnx: 48 | print("Exporting to ONNX: ", opt.export_onnx) 49 | assert opt.export_onnx.endswith("onnx"), "Export model file should end with .onnx" 50 | torch.onnx.export(model, [data['label'], data['inst']], 51 | opt.export_onnx, verbose=True) 52 | exit(0) 53 | minibatch = 1 54 | if opt.engine: 55 | generated = run_trt_engine(opt.engine, minibatch, [data['label'], data['inst']]) 56 | elif opt.onnx: 57 | generated = run_onnx(opt.onnx, opt.data_type, minibatch, [data['label'], data['inst']]) 58 | else: 59 | generated = model.inference(data['label'], data['inst'], data['image']) 60 | 61 | visuals = OrderedDict([('input_label', util.tensor2label(data['label'][0], opt.label_nc)), 62 | ('synthesized_image', util.tensor2im(generated.data[0]))]) 63 | img_path = data['path'] 64 | print('process image... 
%s' % img_path) 65 | visualizer.save_images(webpage, visuals, img_path) 66 | 67 | webpage.save() 68 | -------------------------------------------------------------------------------- /options/train_options.py: -------------------------------------------------------------------------------- 1 | from .base_options import BaseOptions 2 | 3 | class TrainOptions(BaseOptions): 4 | def initialize(self): 5 | BaseOptions.initialize(self) 6 | # for displays 7 | self.parser.add_argument('--display_freq', type=int, default=100, help='frequency of showing training results on screen') 8 | self.parser.add_argument('--print_freq', type=int, default=100, help='frequency of showing training results on console') 9 | self.parser.add_argument('--save_latest_freq', type=int, default=1000, help='frequency of saving the latest results') 10 | self.parser.add_argument('--save_epoch_freq', type=int, default=10, help='frequency of saving checkpoints at the end of epochs') 11 | self.parser.add_argument('--no_html', action='store_true', help='do not save intermediate training results to [opt.checkpoints_dir]/[opt.name]/web/') 12 | self.parser.add_argument('--debug', action='store_true', help='only do one epoch and displays at each iteration') 13 | 14 | # for training 15 | self.parser.add_argument('--continue_train', action='store_true', help='continue training: load the latest model') 16 | self.parser.add_argument('--load_pretrain', type=str, default='', help='load the pretrained model from the specified location') 17 | self.parser.add_argument('--which_epoch', type=str, default='latest', help='which epoch to load? set to latest to use latest cached model') 18 | self.parser.add_argument('--phase', type=str, default='train', help='train, val, test, etc') 19 | self.parser.add_argument('--niter', type=int, default=100, help='# of iter at starting learning rate') 20 | self.parser.add_argument('--niter_decay', type=int, default=100, help='# of iter to linearly decay learning rate to zero') 21 | self.parser.add_argument('--beta1', type=float, default=0.5, help='momentum term of adam') 22 | self.parser.add_argument('--lr', type=float, default=0.0002, help='initial learning rate for adam') 23 | 24 | # for discriminators 25 | self.parser.add_argument('--num_D', type=int, default=2, help='number of discriminators to use') 26 | self.parser.add_argument('--n_layers_D', type=int, default=3, help='only used if which_model_netD==n_layers') 27 | self.parser.add_argument('--ndf', type=int, default=64, help='# of discrim filters in first conv layer') 28 | self.parser.add_argument('--lambda_feat', type=float, default=10.0, help='weight for feature matching loss') 29 | self.parser.add_argument('--no_ganFeat_loss', action='store_true', help='if specified, do *not* use discriminator feature matching loss') 30 | self.parser.add_argument('--no_vgg_loss', action='store_true', help='if specified, do *not* use VGG feature matching loss') 31 | self.parser.add_argument('--no_lsgan', action='store_true', help='do *not* use least square GAN, if false, use vanilla GAN') 32 | self.parser.add_argument('--pool_size', type=int, default=0, help='the size of image buffer that stores previously generated images') 33 | 34 | self.isTrain = True 35 | -------------------------------------------------------------------------------- /data/base_dataset.py: -------------------------------------------------------------------------------- 1 | import torch.utils.data as data 2 | from PIL import Image 3 | import torchvision.transforms as transforms 4 | import numpy as 
np 5 | import random 6 | 7 | class BaseDataset(data.Dataset): 8 | def __init__(self): 9 | super(BaseDataset, self).__init__() 10 | 11 | def name(self): 12 | return 'BaseDataset' 13 | 14 | def initialize(self, opt): 15 | pass 16 | 17 | def get_params(opt, size): 18 | w, h = size 19 | new_h = h 20 | new_w = w 21 | if opt.resize_or_crop == 'resize_and_crop': 22 | new_h = new_w = opt.loadSize 23 | elif opt.resize_or_crop == 'scale_width_and_crop': 24 | new_w = opt.loadSize 25 | new_h = opt.loadSize * h // w 26 | 27 | x = random.randint(0, np.maximum(0, new_w - opt.fineSize)) 28 | y = random.randint(0, np.maximum(0, new_h - opt.fineSize)) 29 | 30 | flip = random.random() > 0.5 31 | return {'crop_pos': (x, y), 'flip': flip} 32 | 33 | def get_transform(opt, params, method=Image.BICUBIC, normalize=True): 34 | transform_list = [] 35 | if 'resize' in opt.resize_or_crop: 36 | osize = [opt.loadSize, opt.loadSize] 37 | transform_list.append(transforms.Scale(osize, method)) 38 | elif 'scale_width' in opt.resize_or_crop: 39 | transform_list.append(transforms.Lambda(lambda img: __scale_width(img, opt.loadSize, method))) 40 | 41 | if 'crop' in opt.resize_or_crop: 42 | transform_list.append(transforms.Lambda(lambda img: __crop(img, params['crop_pos'], opt.fineSize))) 43 | 44 | if opt.resize_or_crop == 'none': 45 | base = float(2 ** opt.n_downsample_global) 46 | if opt.netG == 'local': 47 | base *= (2 ** opt.n_local_enhancers) 48 | transform_list.append(transforms.Lambda(lambda img: __make_power_2(img, base, method))) 49 | 50 | if opt.isTrain and not opt.no_flip: 51 | transform_list.append(transforms.Lambda(lambda img: __flip(img, params['flip']))) 52 | 53 | transform_list += [transforms.ToTensor()] 54 | 55 | if normalize: 56 | transform_list += [transforms.Normalize((0.5, 0.5, 0.5), 57 | (0.5, 0.5, 0.5))] 58 | return transforms.Compose(transform_list) 59 | 60 | def normalize(): 61 | return transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) 62 | 63 | def __make_power_2(img, base, method=Image.BICUBIC): 64 | ow, oh = img.size 65 | h = int(round(oh / base) * base) 66 | w = int(round(ow / base) * base) 67 | if (h == oh) and (w == ow): 68 | return img 69 | return img.resize((w, h), method) 70 | 71 | def __scale_width(img, target_width, method=Image.BICUBIC): 72 | ow, oh = img.size 73 | if (ow == target_width): 74 | return img 75 | w = target_width 76 | h = int(target_width * oh / ow) 77 | return img.resize((w, h), method) 78 | 79 | def __crop(img, pos, size): 80 | ow, oh = img.size 81 | x1, y1 = pos 82 | tw = th = size 83 | if (ow > tw or oh > th): 84 | return img.crop((x1, y1, x1 + tw, y1 + th)) 85 | return img 86 | 87 | def __flip(img, flip): 88 | if flip: 89 | return img.transpose(Image.FLIP_LEFT_RIGHT) 90 | return img 91 | -------------------------------------------------------------------------------- /data/aligned_dataset.py: -------------------------------------------------------------------------------- 1 | import os.path 2 | from data.base_dataset import BaseDataset, get_params, get_transform, normalize 3 | from data.image_folder import make_dataset 4 | from PIL import Image 5 | 6 | class AlignedDataset(BaseDataset): 7 | def initialize(self, opt): 8 | self.opt = opt 9 | self.root = opt.dataroot 10 | 11 | ### input A (label maps) 12 | dir_A = '_A' if self.opt.label_nc == 0 else '_label' 13 | self.dir_A = os.path.join(opt.dataroot, opt.phase + dir_A) 14 | self.A_paths = sorted(make_dataset(self.dir_A)) 15 | 16 | ### input B (real images) 17 | if opt.isTrain or opt.use_encoded_image: 18 | dir_B = '_B' 
if self.opt.label_nc == 0 else '_img' 19 | self.dir_B = os.path.join(opt.dataroot, opt.phase + dir_B) 20 | self.B_paths = sorted(make_dataset(self.dir_B)) 21 | 22 | ### instance maps 23 | if not opt.no_instance: 24 | self.dir_inst = os.path.join(opt.dataroot, opt.phase + '_inst') 25 | self.inst_paths = sorted(make_dataset(self.dir_inst)) 26 | 27 | ### load precomputed instance-wise encoded features 28 | if opt.load_features: 29 | self.dir_feat = os.path.join(opt.dataroot, opt.phase + '_feat') 30 | print('----------- loading features from %s ----------' % self.dir_feat) 31 | self.feat_paths = sorted(make_dataset(self.dir_feat)) 32 | 33 | self.dataset_size = len(self.A_paths) 34 | 35 | def __getitem__(self, index): 36 | ### input A (label maps) 37 | A_path = self.A_paths[index] 38 | A = Image.open(A_path) 39 | params = get_params(self.opt, A.size) 40 | if self.opt.label_nc == 0: 41 | transform_A = get_transform(self.opt, params) 42 | A_tensor = transform_A(A.convert('RGB')) 43 | else: 44 | transform_A = get_transform(self.opt, params, method=Image.NEAREST, normalize=False) 45 | A_tensor = transform_A(A) * 255.0 46 | 47 | B_tensor = inst_tensor = feat_tensor = 0 48 | ### input B (real images) 49 | if self.opt.isTrain or self.opt.use_encoded_image: 50 | B_path = self.B_paths[index] 51 | B = Image.open(B_path).convert('RGB') 52 | transform_B = get_transform(self.opt, params) 53 | B_tensor = transform_B(B) 54 | 55 | ### if using instance maps 56 | if not self.opt.no_instance: 57 | inst_path = self.inst_paths[index] 58 | inst = Image.open(inst_path) 59 | inst_tensor = transform_A(inst) 60 | 61 | if self.opt.load_features: 62 | feat_path = self.feat_paths[index] 63 | feat = Image.open(feat_path).convert('RGB') 64 | norm = normalize() 65 | feat_tensor = norm(transform_A(feat)) 66 | 67 | input_dict = {'label': A_tensor, 'inst': inst_tensor, 'image': B_tensor, 68 | 'feat': feat_tensor, 'path': A_path} 69 | 70 | return input_dict 71 | 72 | def __len__(self): 73 | return len(self.A_paths) // self.opt.batchSize * self.opt.batchSize 74 | 75 | def name(self): 76 | return 'AlignedDataset' -------------------------------------------------------------------------------- /models/base_model.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import sys 4 | 5 | class BaseModel(torch.nn.Module): 6 | def name(self): 7 | return 'BaseModel' 8 | 9 | def initialize(self, opt): 10 | self.opt = opt 11 | self.gpu_ids = opt.gpu_ids 12 | self.isTrain = opt.isTrain 13 | self.Tensor = torch.cuda.FloatTensor if self.gpu_ids else torch.Tensor 14 | self.save_dir = os.path.join(opt.checkpoints_dir, opt.name) 15 | 16 | def set_input(self, input): 17 | self.input = input 18 | 19 | def forward(self): 20 | pass 21 | 22 | # used in test time, no backprop 23 | def test(self): 24 | pass 25 | 26 | def get_image_paths(self): 27 | pass 28 | 29 | def optimize_parameters(self): 30 | pass 31 | 32 | def get_current_visuals(self): 33 | return self.input 34 | 35 | def get_current_errors(self): 36 | return {} 37 | 38 | def save(self, label): 39 | pass 40 | 41 | # helper saving function that can be used by subclasses 42 | def save_network(self, network, network_label, epoch_label, gpu_ids): 43 | save_filename = '%s_net_%s.pth' % (epoch_label, network_label) 44 | save_path = os.path.join(self.save_dir, save_filename) 45 | torch.save(network.cpu().state_dict(), save_path) 46 | if len(gpu_ids) and torch.cuda.is_available(): 47 | network.cuda() 48 | 49 | # helper loading function 
that can be used by subclasses 50 | def load_network(self, network, network_label, epoch_label, save_dir=''): 51 | save_filename = '%s_net_%s.pth' % (epoch_label, network_label) 52 | if not save_dir: 53 | save_dir = self.save_dir 54 | save_path = os.path.join(save_dir, save_filename) 55 | if not os.path.isfile(save_path): 56 | print('%s does not exist yet!' % save_path) 57 | if network_label == 'G': 58 | raise RuntimeError('Generator must exist!') 59 | else: 60 | #network.load_state_dict(torch.load(save_path)) 61 | try: 62 | network.load_state_dict(torch.load(save_path)) 63 | except: 64 | pretrained_dict = torch.load(save_path) 65 | model_dict = network.state_dict() 66 | try: 67 | pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict} 68 | network.load_state_dict(pretrained_dict) 69 | if self.opt.verbose: 70 | print('Pretrained network %s has excessive layers; Only loading layers that are used' % network_label) 71 | except: 72 | print('Pretrained network %s has fewer layers; The following are not initialized:' % network_label) 73 | for k, v in pretrained_dict.items(): 74 | if v.size() == model_dict[k].size(): 75 | model_dict[k] = v 76 | 77 | if sys.version_info >= (3,0): 78 | not_initialized = set() 79 | else: 80 | from sets import Set 81 | not_initialized = Set() 82 | 83 | for k, v in model_dict.items(): 84 | if k not in pretrained_dict or v.size() != pretrained_dict[k].size(): 85 | not_initialized.add(k.split('.')[0]) 86 | 87 | print(sorted(not_initialized)) 88 | network.load_state_dict(model_dict) 89 | 90 | def update_learning_rate(self): 91 | pass 92 | -------------------------------------------------------------------------------- /util/util.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import torch 3 | import numpy as np 4 | from PIL import Image 5 | import os 6 | 7 | 8 | # Converts a Tensor into a Numpy array 9 | # |imtype|: the desired type of the converted numpy array 10 | def tensor2im(image_tensor, imtype=np.uint8, normalize=True): 11 | if isinstance(image_tensor, list): 12 | image_numpy = [] 13 | for i in range(len(image_tensor)): 14 | image_numpy.append(tensor2im(image_tensor[i], imtype, normalize)) 15 | return image_numpy 16 | image_numpy = image_tensor.cpu().float().numpy() 17 | if normalize: 18 | image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0 19 | else: 20 | image_numpy = np.transpose(image_numpy, (1, 2, 0)) * 255.0 21 | image_numpy = np.clip(image_numpy, 0, 255) 22 | if image_numpy.shape[2] == 1 or image_numpy.shape[2] > 3: 23 | image_numpy = image_numpy[:,:,0] 24 | return image_numpy.astype(imtype) 25 | 26 | # Converts a one-hot tensor into a colorful label map 27 | def tensor2label(label_tensor, n_label, imtype=np.uint8): 28 | if n_label == 0: 29 | return tensor2im(label_tensor, imtype) 30 | label_tensor = label_tensor.cpu().float() 31 | if label_tensor.size()[0] > 1: 32 | label_tensor = label_tensor.max(0, keepdim=True)[1] 33 | label_tensor = Colorize(n_label)(label_tensor) 34 | label_numpy = np.transpose(label_tensor.numpy(), (1, 2, 0)) 35 | return label_numpy.astype(imtype) 36 | 37 | def save_image(image_numpy, image_path): 38 | image_pil = Image.fromarray(image_numpy) 39 | image_pil.save(image_path) 40 | 41 | def mkdirs(paths): 42 | if isinstance(paths, list) and not isinstance(paths, str): 43 | for path in paths: 44 | mkdir(path) 45 | else: 46 | mkdir(paths) 47 | 48 | def mkdir(path): 49 | if not os.path.exists(path): 50 | os.makedirs(path) 51 |
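# --- Editor's sketch (not repository code): round-tripping a generator-style
# --- output through the helpers above, as test.py and visualizer.py do.
# --- The random tensor below is a hypothetical stand-in for netG output.
import torch
from util.util import tensor2im, save_image

fake = torch.rand(3, 256, 256) * 2 - 1   # CHW tensor in [-1, 1], like a normalized generator output
arr = tensor2im(fake)                     # -> HxWx3 uint8 numpy array in [0, 255]
save_image(arr, 'fake.png')               # the filename here is arbitrary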
52 | ############################################################################### 53 | # Code from 54 | # https://github.com/ycszen/pytorch-seg/blob/master/transform.py 55 | # Modified so it complies with the Cityscapes label map colors 56 | ############################################################################### 57 | def uint82bin(n, count=8): 58 | """returns the binary of integer n, count refers to amount of bits""" 59 | return ''.join([str((n >> y) & 1) for y in range(count-1, -1, -1)]) 60 | 61 | def labelcolormap(N): 62 | if N == 35: # Cityscapes 63 | cmap = np.array([( 0, 0, 0), ( 0, 0, 0), ( 0, 0, 0), ( 0, 0, 0), ( 0, 0, 0), (111, 74, 0), ( 81, 0, 81), 64 | (128, 64,128), (244, 35,232), (250,170,160), (230,150,140), ( 70, 70, 70), (102,102,156), (190,153,153), 65 | (180,165,180), (150,100,100), (150,120, 90), (153,153,153), (153,153,153), (250,170, 30), (220,220, 0), 66 | (107,142, 35), (152,251,152), ( 70,130,180), (220, 20, 60), (255, 0, 0), ( 0, 0,142), ( 0, 0, 70), 67 | ( 0, 60,100), ( 0, 0, 90), ( 0, 0,110), ( 0, 80,100), ( 0, 0,230), (119, 11, 32), ( 0, 0,142)], 68 | dtype=np.uint8) 69 | else: 70 | cmap = np.zeros((N, 3), dtype=np.uint8) 71 | for i in range(N): 72 | r, g, b = 0, 0, 0 73 | id = i 74 | for j in range(7): 75 | str_id = uint82bin(id) 76 | r = r ^ (np.uint8(str_id[-1]) << (7-j)) 77 | g = g ^ (np.uint8(str_id[-2]) << (7-j)) 78 | b = b ^ (np.uint8(str_id[-3]) << (7-j)) 79 | id = id >> 3 80 | cmap[i, 0] = r 81 | cmap[i, 1] = g 82 | cmap[i, 2] = b 83 | return cmap 84 | 85 | class Colorize(object): 86 | def __init__(self, n=35): 87 | self.cmap = labelcolormap(n) 88 | self.cmap = torch.from_numpy(self.cmap[:n]) 89 | 90 | def __call__(self, gray_image): 91 | size = gray_image.size() 92 | color_image = torch.ByteTensor(3, size[1], size[2]).fill_(0) 93 | 94 | for label in range(0, len(self.cmap)): 95 | mask = (label == gray_image[0]).cpu() 96 | color_image[0][mask] = self.cmap[label][0] 97 | color_image[1][mask] = self.cmap[label][1] 98 | color_image[2][mask] = self.cmap[label][2] 99 | 100 | return color_image 101 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import time 2 | import os 3 | import numpy as np 4 | import torch 5 | from torch.autograd import Variable 6 | from collections import OrderedDict 7 | from subprocess import call 8 | import math 9 | def lcm(a,b): return abs(a * b) // math.gcd(a,b) if a and b else 0 10 | 11 | from options.train_options import TrainOptions 12 | from data.data_loader import CreateDataLoader 13 | from models.models import create_model 14 | import util.util as util 15 | from util.visualizer import Visualizer 16 | 17 | opt = TrainOptions().parse() 18 | iter_path = os.path.join(opt.checkpoints_dir, opt.name, 'iter.txt') 19 | if opt.continue_train: 20 | try: 21 | start_epoch, epoch_iter = np.loadtxt(iter_path, delimiter=',', dtype=int) 22 | except: 23 | start_epoch, epoch_iter = 1, 0 24 | print('Resuming from epoch %d at iteration %d' % (start_epoch, epoch_iter)) 25 | else: 26 | start_epoch, epoch_iter = 1, 0 27 | 28 | opt.print_freq = lcm(opt.print_freq, opt.batchSize) 29 | if opt.debug: 30 | opt.display_freq = 1 31 | opt.print_freq = 1 32 | opt.niter = 1 33 | opt.niter_decay = 0 34 | opt.max_dataset_size = 10 35 | 36 | data_loader = CreateDataLoader(opt) 37 | dataset = data_loader.load_data() 38 | dataset_size = len(data_loader) 39 | print('#training images = %d' % 
dataset_size) 40 | 41 | model = create_model(opt) 42 | visualizer = Visualizer(opt) 43 | if opt.fp16: 44 | from apex import amp 45 | model, [optimizer_G, optimizer_D] = amp.initialize(model, [model.optimizer_G, model.optimizer_D], opt_level='O1') 46 | model = torch.nn.DataParallel(model, device_ids=opt.gpu_ids) 47 | else: 48 | optimizer_G, optimizer_D = model.module.optimizer_G, model.module.optimizer_D 49 | 50 | total_steps = (start_epoch-1) * dataset_size + epoch_iter 51 | 52 | display_delta = total_steps % opt.display_freq 53 | print_delta = total_steps % opt.print_freq 54 | save_delta = total_steps % opt.save_latest_freq 55 | 56 | for epoch in range(start_epoch, opt.niter + opt.niter_decay + 1): 57 | epoch_start_time = time.time() 58 | if epoch != start_epoch: 59 | epoch_iter = epoch_iter % dataset_size 60 | for i, data in enumerate(dataset, start=epoch_iter): 61 | if total_steps % opt.print_freq == print_delta: 62 | iter_start_time = time.time() 63 | total_steps += opt.batchSize 64 | epoch_iter += opt.batchSize 65 | 66 | # whether to collect output images 67 | save_fake = total_steps % opt.display_freq == display_delta 68 | 69 | ############## Forward Pass ###################### 70 | losses, generated = model(Variable(data['label']), Variable(data['inst']), 71 | Variable(data['image']), Variable(data['feat']), infer=save_fake) 72 | 73 | # sum per device losses 74 | losses = [ torch.mean(x) if not isinstance(x, int) else x for x in losses ] 75 | loss_dict = dict(zip(model.module.loss_names, losses)) 76 | 77 | # calculate final loss scalar 78 | loss_D = (loss_dict['D_fake'] + loss_dict['D_real']) * 0.5 79 | loss_G = loss_dict['G_GAN'] + loss_dict.get('G_GAN_Feat',0) + loss_dict.get('G_VGG',0) 80 | 81 | ############### Backward Pass #################### 82 | # update generator weights 83 | optimizer_G.zero_grad() 84 | if opt.fp16: 85 | with amp.scale_loss(loss_G, optimizer_G) as scaled_loss: scaled_loss.backward() 86 | else: 87 | loss_G.backward() 88 | optimizer_G.step() 89 | 90 | # update discriminator weights 91 | optimizer_D.zero_grad() 92 | if opt.fp16: 93 | with amp.scale_loss(loss_D, optimizer_D) as scaled_loss: scaled_loss.backward() 94 | else: 95 | loss_D.backward() 96 | optimizer_D.step() 97 | 98 | ############## Display results and errors ########## 99 | ### print out errors 100 | if total_steps % opt.print_freq == print_delta: 101 | errors = {k: v.data.item() if not isinstance(v, int) else v for k, v in loss_dict.items()} 102 | t = (time.time() - iter_start_time) / opt.print_freq 103 | visualizer.print_current_errors(epoch, epoch_iter, errors, t) 104 | visualizer.plot_current_errors(errors, total_steps) 105 | 106 | ### display output images 107 | if save_fake: 108 | visuals = OrderedDict([('input_label', util.tensor2label(data['label'][0], opt.label_nc)), 109 | ('synthesized_image', util.tensor2im(generated.data[0])), 110 | ('real_image', util.tensor2im(data['image'][0]))]) 111 | visualizer.display_current_results(visuals, epoch, total_steps) 112 | 113 | ### save latest model 114 | if total_steps % opt.save_latest_freq == save_delta: 115 | print('saving the latest model (epoch %d, total_steps %d)' % (epoch, total_steps)) 116 | model.module.save('latest') 117 | np.savetxt(iter_path, (epoch, epoch_iter), delimiter=',', fmt='%d') 118 | 119 | if epoch_iter >= dataset_size: 120 | break 121 | 122 | # end of epoch 123 | iter_end_time = time.time() 124 | print('End of epoch %d / %d \t Time Taken: %d sec' % 125 | (epoch, opt.niter + opt.niter_decay, time.time() - epoch_start_time)) 126 
| 127 | ### save model for this epoch 128 | if epoch % opt.save_epoch_freq == 0: 129 | print('saving the model at the end of epoch %d, iters %d' % (epoch, total_steps)) 130 | model.module.save('latest') 131 | model.module.save(epoch) 132 | np.savetxt(iter_path, (epoch+1, 0), delimiter=',', fmt='%d') 133 | 134 | ### instead of only training the local enhancer, train the entire network after certain iterations 135 | if (opt.niter_fix_global != 0) and (epoch == opt.niter_fix_global): 136 | model.module.update_fixed_params() 137 | 138 | ### linearly decay learning rate after certain iterations 139 | if epoch > opt.niter: 140 | model.module.update_learning_rate() 141 | -------------------------------------------------------------------------------- /util/visualizer.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | import ntpath 4 | import time 5 | from . import util 6 | from . import html 7 | from PIL import Image 8 | try: 9 | from StringIO import StringIO # Python 2.7 10 | except ImportError: 11 | from io import BytesIO # Python 3.x 12 | 13 | class Visualizer(): 14 | def __init__(self, opt): 15 | # self.opt = opt 16 | self.tf_log = opt.tf_log 17 | self.use_html = opt.isTrain and not opt.no_html 18 | self.win_size = opt.display_winsize 19 | self.name = opt.name 20 | if self.tf_log: 21 | import tensorflow as tf 22 | self.tf = tf 23 | self.log_dir = os.path.join(opt.checkpoints_dir, opt.name, 'logs') 24 | self.writer = tf.summary.FileWriter(self.log_dir) 25 | 26 | if self.use_html: 27 | self.web_dir = os.path.join(opt.checkpoints_dir, opt.name, 'web') 28 | self.img_dir = os.path.join(self.web_dir, 'images') 29 | print('create web directory %s...' % self.web_dir) 30 | util.mkdirs([self.web_dir, self.img_dir]) 31 | self.log_name = os.path.join(opt.checkpoints_dir, opt.name, 'loss_log.txt') 32 | with open(self.log_name, "a") as log_file: 33 | now = time.strftime("%c") 34 | log_file.write('================ Training Loss (%s) ================\n' % now) 35 | 36 | # |visuals|: dictionary of images to display or save 37 | def display_current_results(self, visuals, epoch, step): 38 | if self.tf_log: # show images in tensorboard output 39 | img_summaries = [] 40 | for label, image_numpy in visuals.items(): 41 | # Write the image to a string 42 | try: 43 | s = StringIO() 44 | except: 45 | s = BytesIO() 46 | Image.fromarray(image_numpy).save(s, format="jpeg") 47 | # Create an Image object 48 | img_sum = self.tf.Summary.Image(encoded_image_string=s.getvalue(), height=image_numpy.shape[0], width=image_numpy.shape[1]) 49 | # Create a Summary value 50 | img_summaries.append(self.tf.Summary.Value(tag=label, image=img_sum)) 51 | 52 | # Create and write Summary 53 | summary = self.tf.Summary(value=img_summaries) 54 | self.writer.add_summary(summary, step) 55 | 56 | if self.use_html: # save images to a html file 57 | for label, image_numpy in visuals.items(): 58 | if isinstance(image_numpy, list): 59 | for i in range(len(image_numpy)): 60 | img_path = os.path.join(self.img_dir, 'epoch%.3d_%s_%d.jpg' % (epoch, label, i)) 61 | util.save_image(image_numpy[i], img_path) 62 | else: 63 | img_path = os.path.join(self.img_dir, 'epoch%.3d_%s.jpg' % (epoch, label)) 64 | util.save_image(image_numpy, img_path) 65 | 66 | # update website 67 | webpage = html.HTML(self.web_dir, 'Experiment name = %s' % self.name, refresh=30) 68 | for n in range(epoch, 0, -1): 69 | webpage.add_header('epoch [%d]' % n) 70 | ims = [] 71 | txts = [] 72 | links = [] 73 | 74 | 
for label, image_numpy in visuals.items(): 75 | if isinstance(image_numpy, list): 76 | for i in range(len(image_numpy)): 77 | img_path = 'epoch%.3d_%s_%d.jpg' % (n, label, i) 78 | ims.append(img_path) 79 | txts.append(label+str(i)) 80 | links.append(img_path) 81 | else: 82 | img_path = 'epoch%.3d_%s.jpg' % (n, label) 83 | ims.append(img_path) 84 | txts.append(label) 85 | links.append(img_path) 86 | if len(ims) < 10: 87 | webpage.add_images(ims, txts, links, width=self.win_size) 88 | else: 89 | num = int(round(len(ims)/2.0)) 90 | webpage.add_images(ims[:num], txts[:num], links[:num], width=self.win_size) 91 | webpage.add_images(ims[num:], txts[num:], links[num:], width=self.win_size) 92 | webpage.save() 93 | 94 | # errors: dictionary of error labels and values 95 | def plot_current_errors(self, errors, step): 96 | if self.tf_log: 97 | for tag, value in errors.items(): 98 | summary = self.tf.Summary(value=[self.tf.Summary.Value(tag=tag, simple_value=value)]) 99 | self.writer.add_summary(summary, step) 100 | 101 | # errors: same format as |errors| of plotCurrentErrors 102 | def print_current_errors(self, epoch, i, errors, t): 103 | message = '(epoch: %d, iters: %d, time: %.3f) ' % (epoch, i, t) 104 | for k, v in errors.items(): 105 | if v != 0: 106 | message += '%s: %.3f ' % (k, v) 107 | 108 | print(message) 109 | with open(self.log_name, "a") as log_file: 110 | log_file.write('%s\n' % message) 111 | 112 | # save image to the disk 113 | def save_images(self, webpage, visuals, image_path): 114 | image_dir = webpage.get_image_dir() 115 | short_path = ntpath.basename(image_path[0]) 116 | name = os.path.splitext(short_path)[0] 117 | 118 | webpage.add_header(name) 119 | ims = [] 120 | txts = [] 121 | links = [] 122 | 123 | for label, image_numpy in visuals.items(): 124 | image_name = '%s_%s.jpg' % (name, label) 125 | save_path = os.path.join(image_dir, image_name) 126 | util.save_image(image_numpy, save_path) 127 | 128 | ims.append(image_name) 129 | txts.append(label) 130 | links.append(image_name) 131 | webpage.add_images(ims, txts, links, width=self.win_size) 132 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Docuwarp 2 | [![Codacy Badge](https://app.codacy.com/project/badge/Grade/e8bf67a83de04872aecd3c09f11b6389)](https://www.codacy.com/gh/thomasjhuang/deep-learning-for-document-dewarping/dashboard?utm_source=github.com&utm_medium=referral&utm_content=thomasjhuang/deep-learning-for-document-dewarping&utm_campaign=Badge_Grade) 3 | ![Python version](https://img.shields.io/pypi/pyversions/dominate.svg?style=flat) 4 | 5 | This project is focused on dewarping document images using pix2pixHD, a GAN framework for general image-to-image translation. The objective is to take images of documents that are warped, folded, crumpled, etc., and convert them to a "dewarped" state by using [pix2pixHD](https://github.com/NVIDIA/pix2pixHD) to train and perform inference. All of the model code is borrowed directly from the official pix2pixHD repository. 6 | 7 | Some of the intuition behind doing this is inspired by these two papers: 8 | 1. [DocUNet: Document Image Unwarping via A Stacked U-Net (Ma et al.)](https://www.juew.org/publication/DocUNet.pdf) 9 | 2. 
[Document Image Dewarping using Deep Learning (Ramanna et al.)](https://www.insticc.org/Primoris/Resources/PaperPdf.ashx?idPaper=73684) 10 | 11 | ## Prerequisites 12 | 13 | This project requires **Python** and the following dependencies: 14 | 15 | - Linux or OSX 16 | - [scikit-learn](http://scikit-learn.org/stable/) 17 | - NVIDIA GPU (11G memory or larger) + CUDA cuDNN 18 | - [Pytorch](https://pytorch.org/get-started/locally/) 19 | - [Pillow](https://pillow.readthedocs.io/en/stable/installation.html) 20 | - [OpenCV](https://opencv-python-tutroals.readthedocs.io/en/latest/py_tutorials/py_setup/py_table_of_contents_setup/py_table_of_contents_setup.html) 21 | 22 | ## Getting Started 23 | ### Installation 24 | - Install PyTorch and dependencies from https://pytorch.org 25 | - Install the Python library [dominate](https://github.com/Knio/dominate): 26 | ```bash 27 | pip install dominate 28 | ``` 29 | - Clone this repo: 30 | ```bash 31 | git clone https://github.com/thomasjhuang/deep-learning-for-document-dewarping 32 | cd deep-learning-for-document-dewarping 33 | ``` 34 | 35 | ### Training 36 | - Train the kaggle model with 256x256 crops: 37 | ```bash 38 | python train.py --name kaggle --label_nc 0 --no_instance --no_flip --netG local --ngf 32 --fineSize 256 39 | ``` 40 | - To view training results, please check out intermediate results in `./checkpoints/kaggle/web/index.html`. 41 | If you have tensorflow installed, you can see tensorboard logs in `./checkpoints/kaggle/logs` by adding `--tf_log` to the training scripts. 42 | 43 | ### Training with your own dataset 44 | - If you want to train with your own dataset, please generate one-channel label maps whose pixel values correspond to the object labels (i.e. 0,1,...,N-1, where N is the number of labels), since we need to generate one-hot vectors from the label maps. Please also specify `--label_nc N` during both training and testing. 45 | - If your input is not a label map, please just specify `--label_nc 0`, which will directly use the RGB colors as input. The folders should then be named `train_A`, `train_B` instead of `train_label`, `train_img`, where the goal is to translate images from A to B. 46 | - If you don't have instance maps or don't want to use them, please specify `--no_instance`. 47 | - The default setting for preprocessing is `scale_width`, which will scale the width of all training images to `opt.loadSize` (1024) while keeping the aspect ratio. If you want a different setting, please change it by using the `--resize_or_crop` option. For example, `scale_width_and_crop` first resizes the image to have width `opt.loadSize` and then does random cropping of size `(opt.fineSize, opt.fineSize)`. `crop` skips the resizing step and only performs random cropping. If you don't want any preprocessing, please specify `none`, which will do nothing other than making sure the image is divisible by 32. 48 | 49 | ### Testing 50 | - Test the model: 51 | ```bash 52 | python test.py --name kaggle --label_nc 0 --netG local --ngf 32 --resize_or_crop crop --no_instance --no_flip --fineSize 256 53 | ``` 54 | The test results will be saved to a directory here: `./results/kaggle/test_latest/`. 55 | 56 | 57 | ### Dataset 58 | - I use the kaggle denoising dirty documents dataset. To train a model on the full dataset, please download it from the [official website](https://www.kaggle.com/c/denoising-dirty-documents/data). 59 | After downloading, please put it under the `datasets` folder with warped images under the directory name `train_A` and unwarped images under the directory `train_B`. Your test images are warped images, and should be under the name `test_A`. Below is an example dataset directory structure. 60 | 61 | . 62 | ├── ... 63 | ├── datasets 64 | │ ├── train_A # warped images 65 | │ ├── train_B # unwarped, "ground truth" images 66 | │ └── test_A # warped images used for testing 67 | └── ... 68 |
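For example, assuming the downloaded pages have been unpacked into hypothetical `raw_data/warped` and `raw_data/clean` folders (the folder names are placeholders; `raw_data` is already git-ignored in this repo), the layout above can be produced with:

```bash
mkdir -p datasets/train_A datasets/train_B datasets/test_A
cp raw_data/warped/*.png datasets/train_A/
cp raw_data/clean/*.png datasets/train_B/
cp raw_data/warped_test/*.png datasets/test_A/
```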
69 | ### Multi-GPU training 70 | - Train a model using multiple GPUs (`bash ./scripts/train_kaggle_256_multigpu.sh`): 71 | ```bash 72 | #!./scripts/train_kaggle_256_multigpu.sh 73 | python train.py --name kaggle_256_multigpu --label_nc 0 --netG local --ngf 32 --resize_or_crop crop --no_instance --no_flip --fineSize 256 --batchSize 32 --gpu_ids 0,1,2,3,4,5,6,7 74 | ``` 75 | 76 | ### Training with Automatic Mixed Precision (AMP) for faster speed 77 | - To train with mixed precision support, please first install apex from https://github.com/NVIDIA/apex. 78 | - You can then train the model by adding `--fp16`. For example, 79 | ```bash 80 | #!./scripts/train_512p_fp16.sh 81 | python -m torch.distributed.launch train.py --name label2city_512p --fp16 82 | ``` 83 | In my test case, it trains about 80% faster with AMP on a Volta machine. 84 | 85 | ## More Training/Test Details 86 | - Flags: see `options/train_options.py` and `options/base_options.py` for all the training flags; see `options/test_options.py` and `options/base_options.py` for all the test flags. 87 | - Instance map: we take in both label maps and instance maps as input. If you don't want to use instance maps, please specify the flag `--no_instance`. 88 |
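- ONNX/TensorRT: `test.py` can also export the generator or run it through TensorRT via `run_engine.py` (see the `--export_onnx`, `--engine`, and `--onnx` flags in `options/test_options.py`). A hypothetical export invocation, reusing the testing flags from above (the output path is a placeholder):
```bash
python test.py --name kaggle --label_nc 0 --netG local --ngf 32 --resize_or_crop crop --no_instance --no_flip --fineSize 256 --export_onnx checkpoints/kaggle/kaggle.onnx
```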
-------------------------------------------------------------------------------- /run_engine.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | from random import randint 4 | import numpy as np 5 | # tensorrt is imported below inside a guard so a missing install produces a helpful error 6 | 7 | try: 8 | from PIL import Image 9 | import pycuda.driver as cuda 10 | import pycuda.gpuarray as gpuarray 11 | import pycuda.autoinit 12 | import argparse 13 | except ImportError as err: 14 | sys.stderr.write("""ERROR: failed to import module ({}) 15 | Please make sure you have pycuda and the example dependencies installed. 16 | https://wiki.tiker.net/PyCuda/Installation/Linux 17 | pip(3) install tensorrt[examples] 18 | """.format(err)) 19 | exit(1) 20 | 21 | try: 22 | import tensorrt as trt 23 | from tensorrt.parsers import caffeparser 24 | from tensorrt.parsers import onnxparser 25 | except ImportError as err: 26 | sys.stderr.write("""ERROR: failed to import module ({}) 27 | Please make sure you have the TensorRT Library installed 28 | and accessible in your LD_LIBRARY_PATH 29 | """.format(err)) 30 | exit(1) 31 | 32 | 33 | G_LOGGER = trt.infer.ConsoleLogger(trt.infer.LogSeverity.INFO) 34 | 35 | class Profiler(trt.infer.Profiler): 36 | """ 37 | Example implementation of a Profiler. 38 | It is identical to the Profiler class in trt.infer, so it is possible 39 | to just use that instead of implementing this if further 40 | functionality is not needed. 41 | """ 42 | def __init__(self, timing_iter): 43 | trt.infer.Profiler.__init__(self) 44 | self.timing_iterations = timing_iter 45 | self.profile = [] 46 | 47 | def report_layer_time(self, layerName, ms): 48 | record = next((r for r in self.profile if r[0] == layerName), (None, None)) 49 | if record == (None, None): 50 | self.profile.append((layerName, ms)) 51 | else: 52 | self.profile[self.profile.index(record)] = (record[0], record[1] + ms) 53 | 54 | def print_layer_times(self): 55 | totalTime = 0 56 | for i in range(len(self.profile)): 57 | print("{:40.40} {:4.3f}ms".format(self.profile[i][0], self.profile[i][1] / self.timing_iterations)) 58 | totalTime += self.profile[i][1] 59 | print("Time over all layers: {:4.2f} ms per iteration".format(totalTime / self.timing_iterations)) 60 | 61 | 62 | def get_input_output_names(trt_engine): 63 | nbindings = trt_engine.get_nb_bindings() 64 | maps = [] 65 | 66 | for b in range(0, nbindings): 67 | dims = trt_engine.get_binding_dimensions(b).to_DimsCHW() 68 | name = trt_engine.get_binding_name(b) 69 | type = trt_engine.get_binding_data_type(b) 70 | 71 | if (trt_engine.binding_is_input(b)): 72 | maps.append(name) 73 | print("Found input: ", name) 74 | else: 75 | maps.append(name) 76 | print("Found output: ", name) 77 | 78 | print("shape=" + str(dims.C()) + " , " + str(dims.H()) + " , " + str(dims.W())) 79 | print("dtype=" + str(type)) 80 | return maps 81 | 82 | def create_memory(engine, name, buf, mem, batchsize, inp, inp_idx): 83 | binding_idx = engine.get_binding_index(name) 84 | if binding_idx == -1: 85 | raise AttributeError("Not a valid binding") 86 | print("Binding: name={}, bindingIndex={}".format(name, str(binding_idx))) 87 | dims = engine.get_binding_dimensions(binding_idx).to_DimsCHW() 88 | eltCount = dims.C() * dims.H() * dims.W() * batchsize 89 | 90 | if engine.binding_is_input(binding_idx): 91 | h_mem = inp[inp_idx] 92 | inp_idx = inp_idx + 1 93 | else: 94 | h_mem = np.random.uniform(0.0, 255.0, eltCount).astype(np.dtype('f4')) 95 | 96 | d_mem = cuda.mem_alloc(eltCount * 4) 97 | cuda.memcpy_htod(d_mem, h_mem) 98 | buf.insert(binding_idx, int(d_mem)) 99 | mem.append(d_mem) 100 | return inp_idx 101 | 102 | 103 | # Run inference on device 104 | def time_inference(engine, batch_size, inp): 105 | bindings = [] 106 | mem = [] 107 | inp_idx = 0 108 | for io in get_input_output_names(engine): 109 | inp_idx = create_memory(engine, io, bindings, mem, 110 | batch_size, inp, inp_idx) 111 | 112 | context = engine.create_execution_context() 113 | g_prof = Profiler(500) 114 | context.set_profiler(g_prof) 115 | for _ in range(g_prof.timing_iterations): # the original looped over the undefined name `iter` 116 | context.execute(batch_size, bindings) 117 | g_prof.print_layer_times() 118 | 119 | context.destroy() 120 | return 121 | 122 |
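# --- Editor's note (sketch, not repository code): these helpers are driven from
# --- test.py, shown earlier in this listing; condensed, the dispatch looks like
# --- this, with `opt` and `data` as defined there and run_trt_engine/run_onnx
# --- defined just below.
# minibatch = 1
# if opt.engine:
#     generated = run_trt_engine(opt.engine, minibatch, [data['label'], data['inst']])
# elif opt.onnx:
#     generated = run_onnx(opt.onnx, opt.data_type, minibatch, [data['label'], data['inst']])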
127 | def convert_to_datatype(v):
128 |     if v == 8:
129 |         return trt.infer.DataType.INT8
130 |     elif v == 16:
131 |         return trt.infer.DataType.HALF
132 |     elif v == 32:
133 |         return trt.infer.DataType.FLOAT
134 |     else:
135 |         print("ERROR: Invalid model data type bit depth: " + str(v))
136 |         return trt.infer.DataType.INT8
137 | 
138 | def run_trt_engine(engine_file, bs, inp):
139 |     engine = trt.utils.load_engine(G_LOGGER, engine_file)
140 |     time_inference(engine, bs, inp)
141 | 
142 | def run_onnx(onnx_file, data_type, bs, inp):
143 |     # Create onnx_config
144 |     apex = onnxparser.create_onnxconfig()
145 |     apex.set_model_file_name(onnx_file)
146 |     apex.set_model_dtype(convert_to_datatype(data_type))
147 | 
148 |     # create parser
149 |     trt_parser = onnxparser.create_onnxparser(apex)
150 |     assert(trt_parser)
151 |     data_type = apex.get_model_dtype()
152 |     onnx_filename = apex.get_model_file_name()
153 |     trt_parser.parse(onnx_filename, data_type)
154 |     trt_parser.report_parsing_info()
155 |     trt_parser.convert_to_trtnetwork()
156 |     trt_network = trt_parser.get_trtnetwork()
157 |     assert(trt_network)
158 | 
159 |     # create infer builder
160 |     trt_builder = trt.infer.create_infer_builder(G_LOGGER)
161 |     trt_builder.set_max_batch_size(bs)  # build for the requested batch size
162 |     trt_builder.set_max_workspace_size(1 << 30)  # 1 GiB build workspace (an assumed value)
163 | 
164 |     if (apex.get_model_dtype() == trt.infer.DataType_kHALF):
165 |         print("------------------- Running FP16 -----------------------------")
166 |         trt_builder.set_half2_mode(True)
167 |     elif (apex.get_model_dtype() == trt.infer.DataType_kINT8):
168 |         print("------------------- Running INT8 -----------------------------")
169 |         trt_builder.set_int8_mode(True)
170 |     else:
171 |         print("------------------- Running FP32 -----------------------------")
172 | 
173 |     print("----- Builder is Done -----")
174 |     print("----- Creating Engine -----")
175 |     trt_engine = trt_builder.build_cuda_engine(trt_network)
176 |     print("----- Engine is built -----")
177 |     time_inference(trt_engine, bs, inp)
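178 | 
179 | if __name__ == "__main__":
180 |     # Minimal driver sketch showing how the helpers above compose; the engine
181 |     # path, batch size, and input size here are placeholder assumptions.
182 |     batch_size = 1
183 |     # one float32 host array per engine input binding (see create_memory)
184 |     inputs = [np.random.uniform(0.0, 255.0, 3 * 256 * 256).astype(np.dtype('f4'))]
185 |     run_trt_engine("./engine.trt", batch_size, inputs)
186 |     # or build and time an ONNX model at a given bit depth (8, 16, or 32):
187 |     # run_onnx("./model.onnx", 32, batch_size, inputs)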
8, 16, 32 bit") 20 | self.parser.add_argument('--verbose', action='store_true', default=False, help='toggles verbose') 21 | self.parser.add_argument('--fp16', action='store_true', default=False, help='train with AMP') 22 | self.parser.add_argument('--local_rank', type=int, default=0, help='local rank for distributed training') 23 | 24 | # input/output sizes 25 | self.parser.add_argument('--batchSize', type=int, default=1, help='input batch size') 26 | self.parser.add_argument('--loadSize', type=int, default=1024, help='scale images to this size') 27 | self.parser.add_argument('--fineSize', type=int, default=512, help='then crop to this size') 28 | self.parser.add_argument('--label_nc', type=int, default=35, help='# of input label channels') 29 | self.parser.add_argument('--input_nc', type=int, default=3, help='# of input image channels') 30 | self.parser.add_argument('--output_nc', type=int, default=3, help='# of output image channels') 31 | 32 | # for setting inputs 33 | self.parser.add_argument('--dataroot', type=str, default='./datasets/cityscapes/') 34 | self.parser.add_argument('--resize_or_crop', type=str, default='scale_width', help='scaling and cropping of images at load time [resize_and_crop|crop|scale_width|scale_width_and_crop]') 35 | self.parser.add_argument('--serial_batches', action='store_true', help='if true, takes images in order to make batches, otherwise takes them randomly') 36 | self.parser.add_argument('--no_flip', action='store_true', help='if specified, do not flip the images for data argumentation') 37 | self.parser.add_argument('--nThreads', default=2, type=int, help='# threads for loading data') 38 | self.parser.add_argument('--max_dataset_size', type=int, default=float("inf"), help='Maximum number of samples allowed per dataset. If the dataset directory contains more than max_dataset_size, only a subset is loaded.') 39 | 40 | # for displays 41 | self.parser.add_argument('--display_winsize', type=int, default=512, help='display window size') 42 | self.parser.add_argument('--tf_log', action='store_true', help='if specified, use tensorboard logging. 
Requires tensorflow installed') 43 | 44 | # for generator 45 | self.parser.add_argument('--netG', type=str, default='global', help='selects model to use for netG') 46 | self.parser.add_argument('--ngf', type=int, default=64, help='# of gen filters in first conv layer') 47 | self.parser.add_argument('--n_downsample_global', type=int, default=4, help='number of downsampling layers in netG') 48 | self.parser.add_argument('--n_blocks_global', type=int, default=9, help='number of residual blocks in the global generator network') 49 | self.parser.add_argument('--n_blocks_local', type=int, default=3, help='number of residual blocks in the local enhancer network') 50 | self.parser.add_argument('--n_local_enhancers', type=int, default=1, help='number of local enhancers to use') 51 | self.parser.add_argument('--niter_fix_global', type=int, default=0, help='number of epochs that we only train the outmost local enhancer') 52 | 53 | # for instance-wise features 54 | self.parser.add_argument('--no_instance', action='store_true', help='if specified, do *not* add instance map as input') 55 | self.parser.add_argument('--instance_feat', action='store_true', help='if specified, add encoded instance features as input') 56 | self.parser.add_argument('--label_feat', action='store_true', help='if specified, add encoded label features as input') 57 | self.parser.add_argument('--feat_num', type=int, default=3, help='vector length for encoded features') 58 | self.parser.add_argument('--load_features', action='store_true', help='if specified, load precomputed feature maps') 59 | self.parser.add_argument('--n_downsample_E', type=int, default=4, help='# of downsampling layers in encoder') 60 | self.parser.add_argument('--nef', type=int, default=16, help='# of encoder filters in the first conv layer') 61 | self.parser.add_argument('--n_clusters', type=int, default=10, help='number of clusters for features') 62 | 63 | self.initialized = True 64 | 65 | def parse(self, save=True): 66 | if not self.initialized: 67 | self.initialize() 68 | self.opt = self.parser.parse_args() 69 | self.opt.isTrain = self.isTrain # train or test 70 | 71 | str_ids = self.opt.gpu_ids.split(',') 72 | self.opt.gpu_ids = [] 73 | for str_id in str_ids: 74 | id = int(str_id) 75 | if id >= 0: 76 | self.opt.gpu_ids.append(id) 77 | 78 | # set gpu ids 79 | if len(self.opt.gpu_ids) > 0: 80 | torch.cuda.set_device(self.opt.gpu_ids[0]) 81 | 82 | args = vars(self.opt) 83 | 84 | print('------------ Options -------------') 85 | for k, v in sorted(args.items()): 86 | print('%s: %s' % (str(k), str(v))) 87 | print('-------------- End ----------------') 88 | 89 | # save to the disk 90 | expr_dir = os.path.join(self.opt.checkpoints_dir, self.opt.name) 91 | util.mkdirs(expr_dir) 92 | if save and not self.opt.continue_train: 93 | file_name = os.path.join(expr_dir, 'opt.txt') 94 | with open(file_name, 'wt') as opt_file: 95 | opt_file.write('------------ Options -------------\n') 96 | for k, v in sorted(args.items()): 97 | opt_file.write('%s: %s\n' % (str(k), str(v))) 98 | opt_file.write('-------------- End ----------------\n') 99 | return self.opt 100 | -------------------------------------------------------------------------------- /models/pix2pixHD_model.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import os 4 | from torch.autograd import Variable 5 | from util.image_pool import ImagePool 6 | from .base_model import BaseModel 7 | from . 
import networks 8 | 9 | class Pix2PixHDModel(BaseModel): 10 | def name(self): 11 | return 'Pix2PixHDModel' 12 | 13 | def init_loss_filter(self, use_gan_feat_loss, use_vgg_loss): 14 | flags = (True, use_gan_feat_loss, use_vgg_loss, True, True) 15 | def loss_filter(g_gan, g_gan_feat, g_vgg, d_real, d_fake): 16 | return [l for (l,f) in zip((g_gan,g_gan_feat,g_vgg,d_real,d_fake),flags) if f] 17 | return loss_filter 18 | 19 | def initialize(self, opt): 20 | BaseModel.initialize(self, opt) 21 | if opt.resize_or_crop != 'none' or not opt.isTrain: # when training at full res this causes OOM 22 | torch.backends.cudnn.benchmark = True 23 | self.isTrain = opt.isTrain 24 | self.use_features = opt.instance_feat or opt.label_feat 25 | self.gen_features = self.use_features and not self.opt.load_features 26 | input_nc = opt.label_nc if opt.label_nc != 0 else opt.input_nc 27 | 28 | ##### define networks 29 | # Generator network 30 | netG_input_nc = input_nc 31 | if not opt.no_instance: 32 | netG_input_nc += 1 33 | if self.use_features: 34 | netG_input_nc += opt.feat_num 35 | self.netG = networks.define_G(netG_input_nc, opt.output_nc, opt.ngf, opt.netG, 36 | opt.n_downsample_global, opt.n_blocks_global, opt.n_local_enhancers, 37 | opt.n_blocks_local, opt.norm, gpu_ids=self.gpu_ids) 38 | 39 | # Discriminator network 40 | if self.isTrain: 41 | use_sigmoid = opt.no_lsgan 42 | netD_input_nc = input_nc + opt.output_nc 43 | if not opt.no_instance: 44 | netD_input_nc += 1 45 | self.netD = networks.define_D(netD_input_nc, opt.ndf, opt.n_layers_D, opt.norm, use_sigmoid, 46 | opt.num_D, not opt.no_ganFeat_loss, gpu_ids=self.gpu_ids) 47 | 48 | ### Encoder network 49 | if self.gen_features: 50 | self.netE = networks.define_G(opt.output_nc, opt.feat_num, opt.nef, 'encoder', 51 | opt.n_downsample_E, norm=opt.norm, gpu_ids=self.gpu_ids) 52 | if self.opt.verbose: 53 | print('---------- Networks initialized -------------') 54 | 55 | # load networks 56 | if not self.isTrain or opt.continue_train or opt.load_pretrain: 57 | pretrained_path = '' if not self.isTrain else opt.load_pretrain 58 | self.load_network(self.netG, 'G', opt.which_epoch, pretrained_path) 59 | if self.isTrain: 60 | self.load_network(self.netD, 'D', opt.which_epoch, pretrained_path) 61 | if self.gen_features: 62 | self.load_network(self.netE, 'E', opt.which_epoch, pretrained_path) 63 | 64 | # set loss functions and optimizers 65 | if self.isTrain: 66 | if opt.pool_size > 0 and (len(self.gpu_ids)) > 1: 67 | raise NotImplementedError("Fake Pool Not Implemented for MultiGPU") 68 | self.fake_pool = ImagePool(opt.pool_size) 69 | self.old_lr = opt.lr 70 | 71 | # define loss functions 72 | self.loss_filter = self.init_loss_filter(not opt.no_ganFeat_loss, not opt.no_vgg_loss) 73 | 74 | self.criterionGAN = networks.GANLoss(use_lsgan=not opt.no_lsgan, tensor=self.Tensor) 75 | self.criterionFeat = torch.nn.L1Loss() 76 | if not opt.no_vgg_loss: 77 | self.criterionVGG = networks.VGGLoss(self.gpu_ids) 78 | 79 | 80 | # Names so we can breakout loss 81 | self.loss_names = self.loss_filter('G_GAN','G_GAN_Feat','G_VGG','D_real', 'D_fake') 82 | 83 | # initialize optimizers 84 | # optimizer G 85 | if opt.niter_fix_global > 0: 86 | import sys 87 | if sys.version_info >= (3,0): 88 | finetune_list = set() 89 | else: 90 | from sets import Set 91 | finetune_list = Set() 92 | 93 | params_dict = dict(self.netG.named_parameters()) 94 | params = [] 95 | for key, value in params_dict.items(): 96 | if key.startswith('model' + str(opt.n_local_enhancers)): 97 | params += [value] 98 | 
finetune_list.add(key.split('.')[0]) 99 | print('------------- Only training the local enhancer network (for %d epochs) ------------' % opt.niter_fix_global) 100 | print('The layers that are finetuned are ', sorted(finetune_list)) 101 | else: 102 | params = list(self.netG.parameters()) 103 | if self.gen_features: 104 | params += list(self.netE.parameters()) 105 | self.optimizer_G = torch.optim.Adam(params, lr=opt.lr, betas=(opt.beta1, 0.999)) 106 | 107 | # optimizer D 108 | params = list(self.netD.parameters()) 109 | self.optimizer_D = torch.optim.Adam(params, lr=opt.lr, betas=(opt.beta1, 0.999)) 110 | 111 | def encode_input(self, label_map, inst_map=None, real_image=None, feat_map=None, infer=False): 112 | if self.opt.label_nc == 0: 113 | input_label = label_map.data.cuda() 114 | else: 115 | # create one-hot vector for label map 116 | size = label_map.size() 117 | oneHot_size = (size[0], self.opt.label_nc, size[2], size[3]) 118 | input_label = torch.cuda.FloatTensor(torch.Size(oneHot_size)).zero_() 119 | input_label = input_label.scatter_(1, label_map.data.long().cuda(), 1.0) 120 | if self.opt.data_type == 16: 121 | input_label = input_label.half() 122 | 123 | # get edges from instance map 124 | if not self.opt.no_instance: 125 | inst_map = inst_map.data.cuda() 126 | edge_map = self.get_edges(inst_map) 127 | input_label = torch.cat((input_label, edge_map), dim=1) 128 | input_label = Variable(input_label, volatile=infer) 129 | 130 | # real images for training 131 | if real_image is not None: 132 | real_image = Variable(real_image.data.cuda()) 133 | 134 | # instance map for feature encoding 135 | if self.use_features: 136 | # get precomputed feature maps 137 | if self.opt.load_features: 138 | feat_map = Variable(feat_map.data.cuda()) 139 | if self.opt.label_feat: 140 | inst_map = label_map.cuda() 141 | 142 | return input_label, inst_map, real_image, feat_map 143 | 144 | def discriminate(self, input_label, test_image, use_pool=False): 145 | input_concat = torch.cat((input_label, test_image.detach()), dim=1) 146 | if use_pool: 147 | fake_query = self.fake_pool.query(input_concat) 148 | return self.netD.forward(fake_query) 149 | else: 150 | return self.netD.forward(input_concat) 151 | 152 | def forward(self, label, inst, image, feat, infer=False): 153 | # Encode Inputs 154 | input_label, inst_map, real_image, feat_map = self.encode_input(label, inst, image, feat) 155 | 156 | # Fake Generation 157 | if self.use_features: 158 | if not self.opt.load_features: 159 | feat_map = self.netE.forward(real_image, inst_map) 160 | input_concat = torch.cat((input_label, feat_map), dim=1) 161 | else: 162 | input_concat = input_label 163 | fake_image = self.netG.forward(input_concat) 164 | 165 | # Fake Detection and Loss 166 | pred_fake_pool = self.discriminate(input_label, fake_image, use_pool=True) 167 | loss_D_fake = self.criterionGAN(pred_fake_pool, False) 168 | 169 | # Real Detection and Loss 170 | pred_real = self.discriminate(input_label, real_image) 171 | loss_D_real = self.criterionGAN(pred_real, True) 172 | 173 | # GAN loss (Fake Passability Loss) 174 | pred_fake = self.netD.forward(torch.cat((input_label, fake_image), dim=1)) 175 | loss_G_GAN = self.criterionGAN(pred_fake, True) 176 | 177 | # GAN feature matching loss 178 | loss_G_GAN_Feat = 0 179 | if not self.opt.no_ganFeat_loss: 180 | feat_weights = 4.0 / (self.opt.n_layers_D + 1) 181 | D_weights = 1.0 / self.opt.num_D 182 | for i in range(self.opt.num_D): 183 | for j in range(len(pred_fake[i])-1): 184 | loss_G_GAN_Feat += D_weights * 
feat_weights * \ 185 | self.criterionFeat(pred_fake[i][j], pred_real[i][j].detach()) * self.opt.lambda_feat 186 | 187 | # VGG feature matching loss 188 | loss_G_VGG = 0 189 | if not self.opt.no_vgg_loss: 190 | loss_G_VGG = self.criterionVGG(fake_image, real_image) * self.opt.lambda_feat 191 | 192 | # Only return the fake_B image if necessary to save BW 193 | return [ self.loss_filter( loss_G_GAN, loss_G_GAN_Feat, loss_G_VGG, loss_D_real, loss_D_fake ), None if not infer else fake_image ] 194 | 195 | def inference(self, label, inst, image=None): 196 | # Encode Inputs 197 | image = Variable(image) if image is not None else None 198 | input_label, inst_map, real_image, _ = self.encode_input(Variable(label), Variable(inst), image, infer=True) 199 | 200 | # Fake Generation 201 | if self.use_features: 202 | if self.opt.use_encoded_image: 203 | # encode the real image to get feature map 204 | feat_map = self.netE.forward(real_image, inst_map) 205 | else: 206 | # sample clusters from precomputed features 207 | feat_map = self.sample_features(inst_map) 208 | input_concat = torch.cat((input_label, feat_map), dim=1) 209 | else: 210 | input_concat = input_label 211 | 212 | if torch.__version__.startswith('0.4'): 213 | with torch.no_grad(): 214 | fake_image = self.netG.forward(input_concat) 215 | else: 216 | fake_image = self.netG.forward(input_concat) 217 | return fake_image 218 | 219 | def sample_features(self, inst): 220 | # read precomputed feature clusters 221 | cluster_path = os.path.join(self.opt.checkpoints_dir, self.opt.name, self.opt.cluster_path) 222 | features_clustered = np.load(cluster_path, encoding='latin1').item() 223 | 224 | # randomly sample from the feature clusters 225 | inst_np = inst.cpu().numpy().astype(int) 226 | feat_map = self.Tensor(inst.size()[0], self.opt.feat_num, inst.size()[2], inst.size()[3]) 227 | for i in np.unique(inst_np): 228 | label = i if i < 1000 else i//1000 229 | if label in features_clustered: 230 | feat = features_clustered[label] 231 | cluster_idx = np.random.randint(0, feat.shape[0]) 232 | 233 | idx = (inst == int(i)).nonzero() 234 | for k in range(self.opt.feat_num): 235 | feat_map[idx[:,0], idx[:,1] + k, idx[:,2], idx[:,3]] = feat[cluster_idx, k] 236 | if self.opt.data_type==16: 237 | feat_map = feat_map.half() 238 | return feat_map 239 | 240 | def encode_features(self, image, inst): 241 | image = Variable(image.cuda(), volatile=True) 242 | feat_num = self.opt.feat_num 243 | h, w = inst.size()[2], inst.size()[3] 244 | block_num = 32 245 | feat_map = self.netE.forward(image, inst.cuda()) 246 | inst_np = inst.cpu().numpy().astype(int) 247 | feature = {} 248 | for i in range(self.opt.label_nc): 249 | feature[i] = np.zeros((0, feat_num+1)) 250 | for i in np.unique(inst_np): 251 | label = i if i < 1000 else i//1000 252 | idx = (inst == int(i)).nonzero() 253 | num = idx.size()[0] 254 | idx = idx[num//2,:] 255 | val = np.zeros((1, feat_num+1)) 256 | for k in range(feat_num): 257 | val[0, k] = feat_map[idx[0], idx[1] + k, idx[2], idx[3]].data[0] 258 | val[0, feat_num] = float(num) / (h * w // block_num) 259 | feature[label] = np.append(feature[label], val, axis=0) 260 | return feature 261 | 262 | def get_edges(self, t): 263 | edge = torch.cuda.ByteTensor(t.size()).zero_() 264 | edge[:,:,:,1:] = edge[:,:,:,1:] | (t[:,:,:,1:] != t[:,:,:,:-1]) 265 | edge[:,:,:,:-1] = edge[:,:,:,:-1] | (t[:,:,:,1:] != t[:,:,:,:-1]) 266 | edge[:,:,1:,:] = edge[:,:,1:,:] | (t[:,:,1:,:] != t[:,:,:-1,:]) 267 | edge[:,:,:-1,:] = edge[:,:,:-1,:] | (t[:,:,1:,:] != t[:,:,:-1,:]) 268 | if 
self.opt.data_type==16:
269 |             return edge.half()
270 |         else:
271 |             return edge.float()
272 | 
273 |     def save(self, which_epoch):
274 |         self.save_network(self.netG, 'G', which_epoch, self.gpu_ids)
275 |         self.save_network(self.netD, 'D', which_epoch, self.gpu_ids)
276 |         if self.gen_features:
277 |             self.save_network(self.netE, 'E', which_epoch, self.gpu_ids)
278 | 
279 |     def update_fixed_params(self):
280 |         # after fixing the global generator for a number of iterations, also start finetuning it
281 |         params = list(self.netG.parameters())
282 |         if self.gen_features:
283 |             params += list(self.netE.parameters())
284 |         self.optimizer_G = torch.optim.Adam(params, lr=self.opt.lr, betas=(self.opt.beta1, 0.999))
285 |         if self.opt.verbose:
286 |             print('------------ Now also finetuning global generator -----------')
287 | 
288 |     def update_learning_rate(self):
289 |         lrd = self.opt.lr / self.opt.niter_decay
290 |         lr = self.old_lr - lrd
291 |         for param_group in self.optimizer_D.param_groups:
292 |             param_group['lr'] = lr
293 |         for param_group in self.optimizer_G.param_groups:
294 |             param_group['lr'] = lr
295 |         if self.opt.verbose:
296 |             print('update learning rate: %f -> %f' % (self.old_lr, lr))
297 |         self.old_lr = lr
298 | 
299 | class InferenceModel(Pix2PixHDModel):
300 |     def forward(self, inp):
301 |         label, inst = inp
302 |         return self.inference(label, inst)
303 | 
-------------------------------------------------------------------------------- /personal_scripts/preprocess.py: --------------------------------------------------------------------------------
1 | import numpy as np
2 | import os, sys, cv2, shutil, argparse, time, concurrent.futures, imageio
3 | import warnings
4 | import itertools as it
5 | from PIL import Image
6 | from random import shuffle
7 | from math import floor
8 | from skimage import data
9 | import imgaug as ia
10 | import imgaug.augmenters as iaa
11 | from imgaug.augmentables.batches import Batch
12 | 
13 | def quad_as_rect(quad):
14 |     # quad layout: (x0, y0, x1, y1, x2, y2, x3, y3), one (x, y) pair per corner
15 |     if quad[0] != quad[2]: return False
16 |     if quad[1] != quad[7]: return False
17 |     if quad[4] != quad[6]: return False
18 |     if quad[3] != quad[5]: return False
19 |     return True
20 | 
21 | def quad_to_rect(quad):
22 |     assert(len(quad) == 8)
23 |     assert(quad_as_rect(quad))
24 |     return (quad[0], quad[1], quad[4], quad[3])
25 | 
26 | def rect_to_quad(rect):
27 |     assert(len(rect) == 4)
28 |     return (rect[0], rect[1], rect[0], rect[3], rect[2], rect[3], rect[2], rect[1])
29 | 
30 | def shape_to_rect(shape):
31 |     assert(len(shape) == 2)
32 |     return (0, 0, shape[0], shape[1])
33 | 
34 | def griddify(rect, w_div, h_div):
35 |     w = rect[2] - rect[0]
36 |     h = rect[3] - rect[1]
37 |     x_step = w / float(w_div)
38 |     y_step = h / float(h_div)
39 |     y = rect[1]
40 |     grid_vertex_matrix = []
41 |     for _ in range(h_div + 1):
42 |         grid_vertex_matrix.append([])
43 |         x = rect[0]
44 |         for _ in range(w_div + 1):
45 |             grid_vertex_matrix[-1].append([int(x), int(y)])
46 |             x += x_step
47 |         y += y_step
48 |     grid = np.array(grid_vertex_matrix)
49 |     return grid
50 | 
51 | def distort_grid(org_grid, max_shift):
52 |     new_grid = np.copy(org_grid)
53 |     x_min = np.min(new_grid[:, :, 0])
54 |     y_min = np.min(new_grid[:, :, 1])
55 |     x_max = np.max(new_grid[:, :, 0])
56 |     y_max = np.max(new_grid[:, :, 1])
57 |     new_grid += np.random.randint(- max_shift, max_shift + 1, new_grid.shape)
58 |     new_grid[:, :, 0] = np.maximum(x_min, new_grid[:, :, 0])
59 |     new_grid[:, :, 1] = np.maximum(y_min, new_grid[:, :, 1])
60 |     new_grid[:, :, 0] = np.minimum(x_max, new_grid[:, :, 0])
61 |     new_grid[:, :, 1] = np.minimum(y_max, new_grid[:, :, 1])
62 |     return new_grid
63 | 
64 | def grid_to_mesh(src_grid, dst_grid):
65 |     assert(src_grid.shape == dst_grid.shape)
66 |     mesh = []
67 |     for i in range(src_grid.shape[0] - 1):
68 |         for j in range(src_grid.shape[1] - 1):
69 |             src_quad = [src_grid[i    , j    , 0], src_grid[i    , j    , 1],
70 |                         src_grid[i + 1, j    , 0], src_grid[i + 1, j    , 1],
71 |                         src_grid[i + 1, j + 1, 0], src_grid[i + 1, j + 1, 1],
72 |                         src_grid[i    , j + 1, 0], src_grid[i    , j + 1, 1]]
73 |             dst_quad = [dst_grid[i    , j    , 0], dst_grid[i    , j    , 1],
74 |                         dst_grid[i + 1, j    , 0], dst_grid[i + 1, j    , 1],
75 |                         dst_grid[i + 1, j + 1, 0], dst_grid[i + 1, j + 1, 1],
76 |                         dst_grid[i    , j + 1, 0], dst_grid[i    , j + 1, 1]]
77 |             dst_rect = quad_to_rect(dst_quad)
78 |             mesh.append([dst_rect, src_quad])
79 |     return mesh
80 | 
81 | def resize(args):
82 |     img_size, filepath = args
83 |     sq_img = cv2.imread(filepath)  # square image
84 |     scaled_sq_img = resizeAndPad(sq_img, (img_size, img_size), 127)
85 |     cv2.imwrite(filepath, scaled_sq_img)
86 | 
87 | def resizeAndPad(img, size, padColor=0):
88 | 
89 |     h, w = img.shape[:2]
90 |     sh, sw = size
91 | 
92 |     # interpolation method
93 |     if h > sh or w > sw:  # shrinking image
94 |         interp = cv2.INTER_AREA
95 |     else:  # stretching image
96 |         interp = cv2.INTER_CUBIC
97 |     # aspect ratio of image
98 |     aspect = w/h  # if on Python 2, you might need to cast as a float: float(w)/h
99 | 
100 |     # compute scaling and pad sizing
101 |     if aspect > 1:  # horizontal image
102 |         new_w = sw
103 |         new_h = np.round(new_w/aspect).astype(int)
104 |         pad_vert = (sh-new_h)/2
105 |         pad_top, pad_bot = np.floor(pad_vert).astype(int), np.ceil(pad_vert).astype(int)
106 |         pad_left, pad_right = 0, 0
107 |     elif aspect < 1:  # vertical image
108 |         new_h = sh
109 |         new_w = np.round(new_h*aspect).astype(int)
110 |         pad_horz = (sw-new_w)/2
111 |         pad_left, pad_right = np.floor(pad_horz).astype(int), np.ceil(pad_horz).astype(int)
112 |         pad_top, pad_bot = 0, 0
113 |     else:  # square image
114 |         new_h, new_w = sh, sw
115 |         pad_left, pad_right, pad_top, pad_bot = 0, 0, 0, 0
116 | 
117 |     # set pad color
118 |     if len(img.shape) == 3 and not isinstance(padColor, (list, tuple, np.ndarray)):  # color image but only one color provided
119 |         padColor = [padColor]*3
120 | 
121 |     # scale and pad
122 |     scaled_img = cv2.resize(img, (new_w, new_h), interpolation=interp)
123 |     scaled_img = cv2.copyMakeBorder(scaled_img, pad_top, pad_bot, pad_left, pad_right, borderType=cv2.BORDER_CONSTANT, value=padColor)
124 | 
125 |     return scaled_img
126 | 
127 | def get_file_list_from_dir(datadir):
128 |     all_files = os.listdir(os.path.relpath(datadir))
129 |     data_files = list(filter(lambda file: file.endswith('.png'), all_files))
130 |     return data_files
131 | 
132 | def randomize_files(file_list):
133 |     shuffle(file_list)
134 |     return file_list
135 | 
136 | def get_training_and_testing_sets(train, file_list):
137 |     # Initial split for training and test data
138 |     split = train
139 |     split_index = floor(len(file_list) * split)
140 |     training = file_list[:split_index]
141 |     final_testing = file_list[split_index:]
142 |     # Secondary split for validation
143 |     split_index = floor(len(training) * split)
144 |     final_training = training[:split_index]
145 |     return final_training, final_testing
146 | 
147 | def warp(args):
148 |     filename, root, fold_A = args
149 |     im = Image.open(os.path.join(root, filename))
150 |     dst_grid = griddify(shape_to_rect(im.size), 4, 4)
151 |     src_grid = distort_grid(dst_grid, 50)
152 |     mesh = grid_to_mesh(src_grid, dst_grid)
153 |     im = im.transform(im.size, Image.MESH, mesh)
154 |     im.save(os.path.join(fold_A, root.rsplit('/', 1)[-1], filename))
155 | 
156 | def png2jpg(args):
157 |     filepath = args
158 |     im = Image.open(filepath)
159 |     im = im.convert('RGB')
160 |     im.save(os.path.splitext(filepath)[0] + '.jpg', quality=100)
161 | 
162 | 
163 | def rotate_image(mat, angle):
164 | 
165 |     height, width = mat.shape[:2]  # image shape has 3 dimensions
166 |     image_center = (width/2, height/2)  # getRotationMatrix2D needs coordinates in reverse order (width, height) compared to shape
167 | 
168 |     rotation_mat = cv2.getRotationMatrix2D(image_center, angle, 1.)
169 | 
170 |     abs_cos = abs(rotation_mat[0,0])
171 |     abs_sin = abs(rotation_mat[0,1])
172 | 
173 |     bound_w = int(height * abs_sin + width * abs_cos)
174 |     bound_h = int(height * abs_cos + width * abs_sin)
175 | 
176 |     rotation_mat[0, 2] += bound_w/2 - image_center[0]
177 |     rotation_mat[1, 2] += bound_h/2 - image_center[1]
178 | 
179 |     rotated_mat = cv2.warpAffine(mat, rotation_mat, (bound_w, bound_h))
180 |     return rotated_mat
181 | 
182 | 
183 | def imgaug(args):
184 |     # Augment one image: write batch_size jpg copies, then warp them with the sequence below
185 |     filename, root, fold_A = args
186 |     img = cv2.imread(os.path.join(root, filename))
187 |     print('image opened ' + os.path.join(root, filename))
188 |     batch_size = 4
189 |     for i in range(0, batch_size):
190 |         imageio.imwrite(os.path.join(root, os.path.splitext(filename)[0] + '_' + str(i) + '.jpg'), img)  # convert the current image in B into a jpg from png
191 |     nb_batches = 1
192 | 
193 |     # Example augmentation sequence to run in the background
194 |     sometimes = lambda aug: iaa.Sometimes(0.4, aug)
195 |     augseq = iaa.Sequential(
196 |         [
197 |             iaa.PiecewiseAffine(scale=(0.01, 0.01005))
198 |         ]
199 |     )
200 | 
201 |     # Make batches out of the example image (here: one batch containing
202 |     # batch_size copies of the image)
203 |     batches = []
204 |     for _ in range(nb_batches):
205 |         batches.append(Batch(images=[img] * batch_size))
206 | 
207 |     # Save images
208 |     for batch in augseq.augment_batches(batches, background=False):
209 |         count = 0
210 |         for img in batch.images_aug:
211 |             path = os.path.join(fold_A, root.rsplit('/', 1)[-1], os.path.splitext(filename)[0] + '_' + str(count) + '.jpg')
212 |             cv2.imwrite(path, img)
213 |             print('image saved as: ' + path)
214 |             count += 1
215 | 
216 | def resize_and_rotate(filename):
217 |     img = cv2.imread(filename)[:, :, ::-1]
218 |     print('image opened: ' + filename)
219 |     img = rotate_image(img, 90)
220 |     img = ia.imresize_single_image(img, (1024, 2048))
221 | 
222 |     cv2.imwrite(filename, img)
223 |     print('image rotated and resized saved as: ' + filename)
224 | 
225 | 
226 | 
227 | def upsample(args):
228 |     filename, root = args
229 |     img = imageio.imread(os.path.join(root, filename))
230 |     img = rotate_image(img, -90)
231 |     img = ia.imresize_single_image(img, (4400, 3400))
232 |     cv2.imwrite(os.path.join(root, filename), img)
233 | 
234 | 
235 | def main():
236 |     # Note: at the moment, 'train', 'val', and 'test' folders must already exist under folders A and B before running the script
237 |     # arg parsing
238 |     parser = argparse.ArgumentParser('Completely preprocess all data for pix2pix')
239 |     parser.add_argument('--raw_data', dest='raw_data', help='input directory for all initial flatbed images', type=str, default='./raw_data')
240 |     parser.add_argument('--dest_dir', dest='dest_dir', help='destination directory for processed images', type=str, default='./datasets')
241 |     parser.add_argument('--train', dest='train', help='fraction of the data used for training (this also determines the validation split)', type=float, default=0.7)
242 |     parser.add_argument('--imgsize', dest='imgsize', help='side length in pixels of the (n x n) square that images are resized to', type=int, default=512)
243 |     # directories for the warped/unwarped image trees used by the augmentation and cleanup steps below (default locations are assumptions)
244 |     parser.add_argument('--fold_A', dest='fold_A', help='directory for warped (input) images', type=str, default='./datasets')
245 |     parser.add_argument('--fold_B', dest='fold_B', help='directory for unwarped (ground truth) images', type=str, default='./datasets')
246 |     parser.add_argument('--split', dest='split', help='split the raw data into train/test sets', action='store_true')
247 |     parser.add_argument('--resize', dest='resize', help='resize a set of images', action='store_true')
248 |     parser.add_argument('--resize_fold', dest='resize_fold', help='the folder whose jpgs you want resized', type=str, default='./')
249 |     parser.add_argument('--preprocess', dest='preprocess', help='run the complete preprocessing step', action='store_true')
250 |     parser.add_argument('--removepng', dest='removepng', help='remove pngs when done', action='store_true')
251 |     parser.add_argument('--png2jpg', dest='png2jpg', help='folder whose pngs should be converted to jpg', type=str, default="")
252 | 
253 | 
254 |     args = parser.parse_args()
255 |     warnings.simplefilter(action='ignore', category=FutureWarning)
256 |     # print args
257 |     for arg in vars(args):
258 |         print('[%s] = ' % arg, getattr(args, arg))
259 | 
260 |     start = time.perf_counter()  # wall-clock timer (time.clock was removed in Python 3.8)
261 |     total = time.perf_counter()
262 | 
263 |     # 1. Begin splitting data
264 |     if args.split:
265 |         print('Splitting data...')
266 |         datadir = args.raw_data
267 |         data_files = get_file_list_from_dir(datadir)
268 |         randomized = randomize_files(data_files)
269 |         training, testing = get_training_and_testing_sets(args.train, randomized)
270 | 
271 |         train_A_dir = os.path.join(args.dest_dir, 'train_A')
272 |         train_B_dir = os.path.join(args.dest_dir, 'train_B')
273 |         test_A_dir = os.path.join(args.dest_dir, 'test_A')
274 |         test_B_dir = os.path.join(args.dest_dir, 'test_B')
275 | 
276 |         os.mkdir(train_A_dir)
277 |         os.mkdir(train_B_dir)
278 |         os.mkdir(test_A_dir)
279 |         os.mkdir(test_B_dir)
280 | 
281 |         count = 0
282 |         for f in training:
283 |             shutil.copy(os.path.join(args.raw_data, f), os.path.join(train_B_dir, "IMG_" + str(count) + ".png"))
284 |             count += 1
285 |         for f in testing:
286 |             shutil.copy(os.path.join(args.raw_data, f), os.path.join(test_B_dir, "IMG_" + str(count) + ".png"))
287 |             count += 1
288 |         print('Splitting took {:.3f} seconds'.format(time.perf_counter() - start))
289 | 
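290 |     # Directory flow: --split copies the raw scans into train_B/test_B (the
291 |     # ground truth); the preprocessing step below then writes warped, augmented
292 |     # copies into the matching subfolders of fold_A, which become the pix2pixHD inputs.
293 | 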
294 |     if args.preprocess:
295 |         # 2. Perform image warping and save those images into folder A
296 |         print('Resizing, Warping, and Converting images...')
297 |         total2 = time.perf_counter()
298 | 
299 |         # imgaug
300 |         start = time.perf_counter()
301 |         with concurrent.futures.ProcessPoolExecutor() as executor:
302 |             for root, dirs, files in os.walk(args.fold_B):  # walk the unwarped ground-truth tree (fold_B is the assumed source here)
303 |                 conv_list = ((file, root, args.fold_A) for file in files)
304 |                 executor.map(imgaug, conv_list)
305 |             executor.shutdown(wait=True)
306 |         print('Augmentation took {:.3f} seconds'.format(time.perf_counter() - start))
307 | 
308 |     if args.png2jpg != "":
309 |         print("Converting all pngs to jpgs")
310 |         with concurrent.futures.ProcessPoolExecutor() as executor:
311 |             for root, dirs, files in os.walk(args.png2jpg):
312 |                 conv_list = (os.path.join(root, file) for file in files)
313 |                 executor.map(png2jpg, conv_list)
314 |             executor.shutdown(wait=True)
315 |         print("Conversion to jpgs done")
316 | 
317 |     if args.resize:
318 |         print('Resizing data')
319 |         start = time.perf_counter()
320 |         with concurrent.futures.ProcessPoolExecutor() as executor:
321 |             process_list = []
322 |             for filename in os.listdir(args.resize_fold):
323 |                 if filename.endswith(".jpg"):
324 |                     process_list.append(os.path.join(args.resize_fold, filename))
325 |             executor.map(resize_and_rotate, process_list)
326 |             executor.shutdown(wait=True)
327 |         print('Rotations took {:.3f} seconds'.format(time.perf_counter() - start))
328 | 
329 |     if args.removepng:
330 |         # 3. Removal of pngs
331 |         print("Removing png's")
332 |         for root, dirs, files in os.walk(args.fold_A):
333 |             for file in files:
334 |                 if file.endswith('.png'):
335 |                     os.remove(os.path.join(root, file))
336 |         for root, dirs, files in os.walk(args.fold_B):
337 |             for file in files:
338 |                 if file.endswith('.png'):
339 |                     os.remove(os.path.join(root, file))
340 |         print("Png's removed")
341 | 
342 | if __name__ == "__main__":
343 |     main()
-------------------------------------------------------------------------------- /models/ui_model.py: --------------------------------------------------------------------------------
1 | import torch
2 | from torch.autograd import Variable
3 | from collections import OrderedDict
4 | import numpy as np
5 | import os
6 | from PIL import Image
7 | import util.util as util
8 | from .base_model import BaseModel
9 | from .
import networks 10 | 11 | class UIModel(BaseModel): 12 | def name(self): 13 | return 'UIModel' 14 | 15 | def initialize(self, opt): 16 | assert(not opt.isTrain) 17 | BaseModel.initialize(self, opt) 18 | self.use_features = opt.instance_feat or opt.label_feat 19 | 20 | netG_input_nc = opt.label_nc 21 | if not opt.no_instance: 22 | netG_input_nc += 1 23 | if self.use_features: 24 | netG_input_nc += opt.feat_num 25 | 26 | self.netG = networks.define_G(netG_input_nc, opt.output_nc, opt.ngf, opt.netG, 27 | opt.n_downsample_global, opt.n_blocks_global, opt.n_local_enhancers, 28 | opt.n_blocks_local, opt.norm, gpu_ids=self.gpu_ids) 29 | self.load_network(self.netG, 'G', opt.which_epoch) 30 | 31 | print('---------- Networks initialized -------------') 32 | 33 | def toTensor(self, img, normalize=False): 34 | tensor = torch.from_numpy(np.array(img, np.int32, copy=False)) 35 | tensor = tensor.view(1, img.size[1], img.size[0], len(img.mode)) 36 | tensor = tensor.transpose(1, 2).transpose(1, 3).contiguous() 37 | if normalize: 38 | return (tensor.float()/255.0 - 0.5) / 0.5 39 | return tensor.float() 40 | 41 | def load_image(self, label_path, inst_path, feat_path): 42 | opt = self.opt 43 | # read label map 44 | label_img = Image.open(label_path) 45 | if label_path.find('face') != -1: 46 | label_img = label_img.convert('L') 47 | ow, oh = label_img.size 48 | w = opt.loadSize 49 | h = int(w * oh / ow) 50 | label_img = label_img.resize((w, h), Image.NEAREST) 51 | label_map = self.toTensor(label_img) 52 | 53 | # onehot vector input for label map 54 | self.label_map = label_map.cuda() 55 | oneHot_size = (1, opt.label_nc, h, w) 56 | input_label = self.Tensor(torch.Size(oneHot_size)).zero_() 57 | self.input_label = input_label.scatter_(1, label_map.long().cuda(), 1.0) 58 | 59 | # read instance map 60 | if not opt.no_instance: 61 | inst_img = Image.open(inst_path) 62 | inst_img = inst_img.resize((w, h), Image.NEAREST) 63 | self.inst_map = self.toTensor(inst_img).cuda() 64 | self.edge_map = self.get_edges(self.inst_map) 65 | self.net_input = Variable(torch.cat((self.input_label, self.edge_map), dim=1), volatile=True) 66 | else: 67 | self.net_input = Variable(self.input_label, volatile=True) 68 | 69 | self.features_clustered = np.load(feat_path).item() 70 | self.object_map = self.inst_map if opt.instance_feat else self.label_map 71 | 72 | object_np = self.object_map.cpu().numpy().astype(int) 73 | self.feat_map = self.Tensor(1, opt.feat_num, h, w).zero_() 74 | self.cluster_indices = np.zeros(self.opt.label_nc, np.uint8) 75 | for i in np.unique(object_np): 76 | label = i if i < 1000 else i//1000 77 | if label in self.features_clustered: 78 | feat = self.features_clustered[label] 79 | np.random.seed(i+1) 80 | cluster_idx = np.random.randint(0, feat.shape[0]) 81 | self.cluster_indices[label] = cluster_idx 82 | idx = (self.object_map == i).nonzero() 83 | self.set_features(idx, feat, cluster_idx) 84 | 85 | self.net_input_original = self.net_input.clone() 86 | self.label_map_original = self.label_map.clone() 87 | self.feat_map_original = self.feat_map.clone() 88 | if not opt.no_instance: 89 | self.inst_map_original = self.inst_map.clone() 90 | 91 | def reset(self): 92 | self.net_input = self.net_input_prev = self.net_input_original.clone() 93 | self.label_map = self.label_map_prev = self.label_map_original.clone() 94 | self.feat_map = self.feat_map_prev = self.feat_map_original.clone() 95 | if not self.opt.no_instance: 96 | self.inst_map = self.inst_map_prev = self.inst_map_original.clone() 97 | self.object_map = 
self.inst_map if self.opt.instance_feat else self.label_map 98 | 99 | def undo(self): 100 | self.net_input = self.net_input_prev 101 | self.label_map = self.label_map_prev 102 | self.feat_map = self.feat_map_prev 103 | if not self.opt.no_instance: 104 | self.inst_map = self.inst_map_prev 105 | self.object_map = self.inst_map if self.opt.instance_feat else self.label_map 106 | 107 | # get boundary map from instance map 108 | def get_edges(self, t): 109 | edge = torch.cuda.ByteTensor(t.size()).zero_() 110 | edge[:,:,:,1:] = edge[:,:,:,1:] | (t[:,:,:,1:] != t[:,:,:,:-1]) 111 | edge[:,:,:,:-1] = edge[:,:,:,:-1] | (t[:,:,:,1:] != t[:,:,:,:-1]) 112 | edge[:,:,1:,:] = edge[:,:,1:,:] | (t[:,:,1:,:] != t[:,:,:-1,:]) 113 | edge[:,:,:-1,:] = edge[:,:,:-1,:] | (t[:,:,1:,:] != t[:,:,:-1,:]) 114 | return edge.float() 115 | 116 | # change the label at the source position to the label at the target position 117 | def change_labels(self, click_src, click_tgt): 118 | y_src, x_src = click_src[0], click_src[1] 119 | y_tgt, x_tgt = click_tgt[0], click_tgt[1] 120 | label_src = int(self.label_map[0, 0, y_src, x_src]) 121 | inst_src = self.inst_map[0, 0, y_src, x_src] 122 | label_tgt = int(self.label_map[0, 0, y_tgt, x_tgt]) 123 | inst_tgt = self.inst_map[0, 0, y_tgt, x_tgt] 124 | 125 | idx_src = (self.inst_map == inst_src).nonzero() 126 | # need to change 3 things: label map, instance map, and feature map 127 | if idx_src.shape: 128 | # backup current maps 129 | self.backup_current_state() 130 | 131 | # change both the label map and the network input 132 | self.label_map[idx_src[:,0], idx_src[:,1], idx_src[:,2], idx_src[:,3]] = label_tgt 133 | self.net_input[idx_src[:,0], idx_src[:,1] + label_src, idx_src[:,2], idx_src[:,3]] = 0 134 | self.net_input[idx_src[:,0], idx_src[:,1] + label_tgt, idx_src[:,2], idx_src[:,3]] = 1 135 | 136 | # update the instance map (and the network input) 137 | if inst_tgt > 1000: 138 | # if different instances have different ids, give the new object a new id 139 | tgt_indices = (self.inst_map > label_tgt * 1000) & (self.inst_map < (label_tgt+1) * 1000) 140 | inst_tgt = self.inst_map[tgt_indices].max() + 1 141 | self.inst_map[idx_src[:,0], idx_src[:,1], idx_src[:,2], idx_src[:,3]] = inst_tgt 142 | self.net_input[:,-1,:,:] = self.get_edges(self.inst_map) 143 | 144 | # also copy the source features to the target position 145 | idx_tgt = (self.inst_map == inst_tgt).nonzero() 146 | if idx_tgt.shape: 147 | self.copy_features(idx_src, idx_tgt[0,:]) 148 | 149 | self.fake_image = util.tensor2im(self.single_forward(self.net_input, self.feat_map)) 150 | 151 | # add strokes of target label in the image 152 | def add_strokes(self, click_src, label_tgt, bw, save): 153 | # get the region of the new strokes (bw is the brush width) 154 | size = self.net_input.size() 155 | h, w = size[2], size[3] 156 | idx_src = torch.LongTensor(bw**2, 4).fill_(0) 157 | for i in range(bw): 158 | idx_src[i*bw:(i+1)*bw, 2] = min(h-1, max(0, click_src[0]-bw//2 + i)) 159 | for j in range(bw): 160 | idx_src[i*bw+j, 3] = min(w-1, max(0, click_src[1]-bw//2 + j)) 161 | idx_src = idx_src.cuda() 162 | 163 | # again, need to update 3 things 164 | if idx_src.shape: 165 | # backup current maps 166 | if save: 167 | self.backup_current_state() 168 | 169 | # update the label map (and the network input) in the stroke region 170 | self.label_map[idx_src[:,0], idx_src[:,1], idx_src[:,2], idx_src[:,3]] = label_tgt 171 | for k in range(self.opt.label_nc): 172 | self.net_input[idx_src[:,0], idx_src[:,1] + k, idx_src[:,2], idx_src[:,3]] = 0 
173 | self.net_input[idx_src[:,0], idx_src[:,1] + label_tgt, idx_src[:,2], idx_src[:,3]] = 1 174 | 175 | # update the instance map (and the network input) 176 | self.inst_map[idx_src[:,0], idx_src[:,1], idx_src[:,2], idx_src[:,3]] = label_tgt 177 | self.net_input[:,-1,:,:] = self.get_edges(self.inst_map) 178 | 179 | # also update the features if available 180 | if self.opt.instance_feat: 181 | feat = self.features_clustered[label_tgt] 182 | #np.random.seed(label_tgt+1) 183 | #cluster_idx = np.random.randint(0, feat.shape[0]) 184 | cluster_idx = self.cluster_indices[label_tgt] 185 | self.set_features(idx_src, feat, cluster_idx) 186 | 187 | self.fake_image = util.tensor2im(self.single_forward(self.net_input, self.feat_map)) 188 | 189 | # add an object to the clicked position with selected style 190 | def add_objects(self, click_src, label_tgt, mask, style_id=0): 191 | y, x = click_src[0], click_src[1] 192 | mask = np.transpose(mask, (2, 0, 1))[np.newaxis,...] 193 | idx_src = torch.from_numpy(mask).cuda().nonzero() 194 | idx_src[:,2] += y 195 | idx_src[:,3] += x 196 | 197 | # backup current maps 198 | self.backup_current_state() 199 | 200 | # update label map 201 | self.label_map[idx_src[:,0], idx_src[:,1], idx_src[:,2], idx_src[:,3]] = label_tgt 202 | for k in range(self.opt.label_nc): 203 | self.net_input[idx_src[:,0], idx_src[:,1] + k, idx_src[:,2], idx_src[:,3]] = 0 204 | self.net_input[idx_src[:,0], idx_src[:,1] + label_tgt, idx_src[:,2], idx_src[:,3]] = 1 205 | 206 | # update instance map 207 | self.inst_map[idx_src[:,0], idx_src[:,1], idx_src[:,2], idx_src[:,3]] = label_tgt 208 | self.net_input[:,-1,:,:] = self.get_edges(self.inst_map) 209 | 210 | # update feature map 211 | self.set_features(idx_src, self.feat, style_id) 212 | 213 | self.fake_image = util.tensor2im(self.single_forward(self.net_input, self.feat_map)) 214 | 215 | def single_forward(self, net_input, feat_map): 216 | net_input = torch.cat((net_input, feat_map), dim=1) 217 | fake_image = self.netG.forward(net_input) 218 | 219 | if fake_image.size()[0] == 1: 220 | return fake_image.data[0] 221 | return fake_image.data 222 | 223 | 224 | # generate all outputs for different styles 225 | def style_forward(self, click_pt, style_id=-1): 226 | if click_pt is None: 227 | self.fake_image = util.tensor2im(self.single_forward(self.net_input, self.feat_map)) 228 | self.crop = None 229 | self.mask = None 230 | else: 231 | instToChange = int(self.object_map[0, 0, click_pt[0], click_pt[1]]) 232 | self.instToChange = instToChange 233 | label = instToChange if instToChange < 1000 else instToChange//1000 234 | self.feat = self.features_clustered[label] 235 | self.fake_image = [] 236 | self.mask = self.object_map == instToChange 237 | idx = self.mask.nonzero() 238 | self.get_crop_region(idx) 239 | if idx.size(): 240 | if style_id == -1: 241 | (min_y, min_x, max_y, max_x) = self.crop 242 | ### original 243 | for cluster_idx in range(self.opt.multiple_output): 244 | self.set_features(idx, self.feat, cluster_idx) 245 | fake_image = self.single_forward(self.net_input, self.feat_map) 246 | fake_image = util.tensor2im(fake_image[:,min_y:max_y,min_x:max_x]) 247 | self.fake_image.append(fake_image) 248 | """### To speed up previewing different style results, either crop or downsample the label maps 249 | if instToChange > 1000: 250 | (min_y, min_x, max_y, max_x) = self.crop 251 | ### crop 252 | _, _, h, w = self.net_input.size() 253 | offset = 512 254 | y_start, x_start = max(0, min_y-offset), max(0, min_x-offset) 255 | y_end, x_end = min(h, (max_y + 
offset)), min(w, (max_x + offset)) 256 | y_region = slice(y_start, y_start+(y_end-y_start)//16*16) 257 | x_region = slice(x_start, x_start+(x_end-x_start)//16*16) 258 | net_input = self.net_input[:,:,y_region,x_region] 259 | for cluster_idx in range(self.opt.multiple_output): 260 | self.set_features(idx, self.feat, cluster_idx) 261 | fake_image = self.single_forward(net_input, self.feat_map[:,:,y_region,x_region]) 262 | fake_image = util.tensor2im(fake_image[:,min_y-y_start:max_y-y_start,min_x-x_start:max_x-x_start]) 263 | self.fake_image.append(fake_image) 264 | else: 265 | ### downsample 266 | (min_y, min_x, max_y, max_x) = [crop//2 for crop in self.crop] 267 | net_input = self.net_input[:,:,::2,::2] 268 | size = net_input.size() 269 | net_input_batch = net_input.expand(self.opt.multiple_output, size[1], size[2], size[3]) 270 | for cluster_idx in range(self.opt.multiple_output): 271 | self.set_features(idx, self.feat, cluster_idx) 272 | feat_map = self.feat_map[:,:,::2,::2] 273 | if cluster_idx == 0: 274 | feat_map_batch = feat_map 275 | else: 276 | feat_map_batch = torch.cat((feat_map_batch, feat_map), dim=0) 277 | fake_image_batch = self.single_forward(net_input_batch, feat_map_batch) 278 | for i in range(self.opt.multiple_output): 279 | self.fake_image.append(util.tensor2im(fake_image_batch[i,:,min_y:max_y,min_x:max_x]))""" 280 | 281 | else: 282 | self.set_features(idx, self.feat, style_id) 283 | self.cluster_indices[label] = style_id 284 | self.fake_image = util.tensor2im(self.single_forward(self.net_input, self.feat_map)) 285 | 286 | def backup_current_state(self): 287 | self.net_input_prev = self.net_input.clone() 288 | self.label_map_prev = self.label_map.clone() 289 | self.inst_map_prev = self.inst_map.clone() 290 | self.feat_map_prev = self.feat_map.clone() 291 | 292 | # crop the ROI and get the mask of the object 293 | def get_crop_region(self, idx): 294 | size = self.net_input.size() 295 | h, w = size[2], size[3] 296 | min_y, min_x = idx[:,2].min(), idx[:,3].min() 297 | max_y, max_x = idx[:,2].max(), idx[:,3].max() 298 | crop_min = 128 299 | if max_y - min_y < crop_min: 300 | min_y = max(0, (max_y + min_y) // 2 - crop_min // 2) 301 | max_y = min(h-1, min_y + crop_min) 302 | if max_x - min_x < crop_min: 303 | min_x = max(0, (max_x + min_x) // 2 - crop_min // 2) 304 | max_x = min(w-1, min_x + crop_min) 305 | self.crop = (min_y, min_x, max_y, max_x) 306 | self.mask = self.mask[:,:, min_y:max_y, min_x:max_x] 307 | 308 | # update the feature map once a new object is added or the label is changed 309 | def update_features(self, cluster_idx, mask=None, click_pt=None): 310 | self.feat_map_prev = self.feat_map.clone() 311 | # adding a new object 312 | if mask is not None: 313 | y, x = click_pt[0], click_pt[1] 314 | mask = np.transpose(mask, (2,0,1))[np.newaxis,...] 
315 |             idx = torch.from_numpy(mask).cuda().nonzero()
316 |             idx[:,2] += y
317 |             idx[:,3] += x
318 |         # changing the label of an existing object
319 |         else:
320 |             idx = (self.object_map == self.instToChange).nonzero()
321 | 
322 |         # update feature map
323 |         self.set_features(idx, self.feat, cluster_idx)
324 | 
325 |     # set the class features to the target feature
326 |     def set_features(self, idx, feat, cluster_idx):
327 |         for k in range(self.opt.feat_num):
328 |             self.feat_map[idx[:,0], idx[:,1] + k, idx[:,2], idx[:,3]] = feat[cluster_idx, k]
329 | 
330 |     # copy the features at the target position to the source position
331 |     def copy_features(self, idx_src, idx_tgt):
332 |         for k in range(self.opt.feat_num):
333 |             val = self.feat_map[idx_tgt[0], idx_tgt[1] + k, idx_tgt[2], idx_tgt[3]]
334 |             self.feat_map[idx_src[:,0], idx_src[:,1] + k, idx_src[:,2], idx_src[:,3]] = val
335 | 
336 |     def get_current_visuals(self, getLabel=False):
337 |         mask = self.mask
338 |         if self.mask is not None:
339 |             mask = np.transpose(self.mask[0].cpu().float().numpy(), (1,2,0)).astype(np.uint8)
340 | 
341 |         dict_list = [('fake_image', self.fake_image), ('mask', mask)]
342 | 
343 |         if getLabel:  # only output label map if needed to save bandwidth
344 |             label = util.tensor2label(self.net_input.data[0], self.opt.label_nc)
345 |             dict_list += [('label', label)]
346 | 
347 |         return OrderedDict(dict_list)
-------------------------------------------------------------------------------- /models/networks.py: --------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import functools
4 | from torch.autograd import Variable
5 | import numpy as np
6 | 
7 | ###############################################################################
8 | # Functions
9 | ###############################################################################
10 | def weights_init(m):
11 |     classname = m.__class__.__name__
12 |     if classname.find('Conv') != -1:
13 |         m.weight.data.normal_(0.0, 0.02)
14 |     elif classname.find('BatchNorm2d') != -1:
15 |         m.weight.data.normal_(1.0, 0.02)
16 |         m.bias.data.fill_(0)
17 | 
18 | def get_norm_layer(norm_type='instance'):
19 |     if norm_type == 'batch':
20 |         norm_layer = functools.partial(nn.BatchNorm2d, affine=True)
21 |     elif norm_type == 'instance':
22 |         norm_layer = functools.partial(nn.InstanceNorm2d, affine=False)
23 |     else:
24 |         raise NotImplementedError('normalization layer [%s] is not found' % norm_type)
25 |     return norm_layer
26 | 
27 | def define_G(input_nc, output_nc, ngf, netG, n_downsample_global=3, n_blocks_global=9, n_local_enhancers=1,
28 |              n_blocks_local=3, norm='instance', gpu_ids=[]):
29 |     norm_layer = get_norm_layer(norm_type=norm)
30 |     if netG == 'global':
31 |         netG = GlobalGenerator(input_nc, output_nc, ngf, n_downsample_global, n_blocks_global, norm_layer)
32 |     elif netG == 'local':
33 |         netG = LocalEnhancer(input_nc, output_nc, ngf, n_downsample_global, n_blocks_global,
34 |                              n_local_enhancers, n_blocks_local, norm_layer)
35 |     elif netG == 'encoder':
36 |         netG = Encoder(input_nc, output_nc, ngf, n_downsample_global, norm_layer)
37 |     else:
38 |         raise NotImplementedError('generator not implemented!')
39 |     print(netG)
40 |     if len(gpu_ids) > 0:
41 |         assert(torch.cuda.is_available())
42 |         netG.cuda(gpu_ids[0])
43 |     netG.apply(weights_init)
44 |     return netG
45 | 
46 | def define_D(input_nc, ndf, n_layers_D, norm='instance', use_sigmoid=False, num_D=1, getIntermFeat=False, gpu_ids=[]):
47 |     norm_layer = get_norm_layer(norm_type=norm)
48 |     netD =
MultiscaleDiscriminator(input_nc, ndf, n_layers_D, norm_layer, use_sigmoid, num_D, getIntermFeat) 49 | print(netD) 50 | if len(gpu_ids) > 0: 51 | assert(torch.cuda.is_available()) 52 | netD.cuda(gpu_ids[0]) 53 | netD.apply(weights_init) 54 | return netD 55 | 56 | def print_network(net): 57 | if isinstance(net, list): 58 | net = net[0] 59 | num_params = 0 60 | for param in net.parameters(): 61 | num_params += param.numel() 62 | print(net) 63 | print('Total number of parameters: %d' % num_params) 64 | 65 | ############################################################################## 66 | # Losses 67 | ############################################################################## 68 | class GANLoss(nn.Module): 69 | def __init__(self, use_lsgan=True, target_real_label=1.0, target_fake_label=0.0, 70 | tensor=torch.FloatTensor): 71 | super(GANLoss, self).__init__() 72 | self.real_label = target_real_label 73 | self.fake_label = target_fake_label 74 | self.real_label_var = None 75 | self.fake_label_var = None 76 | self.Tensor = tensor 77 | if use_lsgan: 78 | self.loss = nn.MSELoss() 79 | else: 80 | self.loss = nn.BCELoss() 81 | 82 | def get_target_tensor(self, _input, target_is_real): 83 | target_tensor = None 84 | if target_is_real: 85 | create_label = ((self.real_label_var is None) or 86 | (self.real_label_var.numel() != _input.numel())) 87 | if create_label: 88 | real_tensor = self.Tensor(_input.size()).fill_(self.real_label) 89 | self.real_label_var = Variable(real_tensor, requires_grad=False) 90 | target_tensor = self.real_label_var 91 | else: 92 | create_label = ((self.fake_label_var is None) or 93 | (self.fake_label_var.numel() != _input.numel())) 94 | if create_label: 95 | fake_tensor = self.Tensor(_input.size()).fill_(self.fake_label) 96 | self.fake_label_var = Variable(fake_tensor, requires_grad=False) 97 | target_tensor = self.fake_label_var 98 | return target_tensor 99 | 100 | def __call__(self, _input, target_is_real): 101 | if isinstance(_input[0], list): 102 | loss = 0 103 | for input_i in _input: 104 | pred = input_i[-1] 105 | target_tensor = self.get_target_tensor(pred, target_is_real) 106 | loss += self.loss(pred, target_tensor) 107 | return loss 108 | else: 109 | target_tensor = self.get_target_tensor(_input[-1], target_is_real) 110 | return self.loss(_input[-1], target_tensor) 111 | 112 | class VGGLoss(nn.Module): 113 | def __init__(self, gpu_ids): 114 | super(VGGLoss, self).__init__() 115 | self.vgg = Vgg19().cuda() 116 | self.criterion = nn.L1Loss() 117 | self.weights = [1.0/32, 1.0/16, 1.0/8, 1.0/4, 1.0] 118 | 119 | def forward(self, x, y): 120 | x_vgg, y_vgg = self.vgg(x), self.vgg(y) 121 | loss = 0 122 | for i in range(len(x_vgg)): 123 | loss += self.weights[i] * self.criterion(x_vgg[i], y_vgg[i].detach()) 124 | return loss 125 | 126 | ############################################################################## 127 | # Generator 128 | ############################################################################## 129 | class LocalEnhancer(nn.Module): 130 | def __init__(self, input_nc, output_nc, ngf=32, n_downsample_global=3, n_blocks_global=9, 131 | n_local_enhancers=1, n_blocks_local=3, norm_layer=nn.BatchNorm2d, padding_type='reflect'): 132 | super(LocalEnhancer, self).__init__() 133 | self.n_local_enhancers = n_local_enhancers 134 | 135 | ###### global generator model ##### 136 | ngf_global = ngf * (2**n_local_enhancers) 137 | model_global = GlobalGenerator(input_nc, output_nc, ngf_global, n_downsample_global, n_blocks_global, norm_layer).model 138 | model_global = 
[model_global[i] for i in range(len(model_global)-3)] # get rid of final convolution layers 139 | self.model = nn.Sequential(*model_global) 140 | 141 | ###### local enhancer layers ##### 142 | for n in range(1, n_local_enhancers+1): 143 | ### downsample 144 | ngf_global = ngf * (2**(n_local_enhancers-n)) 145 | model_downsample = [nn.ReflectionPad2d(3), nn.Conv2d(input_nc, ngf_global, kernel_size=7, padding=0), 146 | norm_layer(ngf_global), nn.ReLU(True), 147 | nn.Conv2d(ngf_global, ngf_global * 2, kernel_size=3, stride=2, padding=1), 148 | norm_layer(ngf_global * 2), nn.ReLU(True)] 149 | ### residual blocks 150 | model_upsample = [] 151 | for i in range(n_blocks_local): 152 | model_upsample += [ResnetBlock(ngf_global * 2, padding_type=padding_type, norm_layer=norm_layer)] 153 | 154 | ### upsample 155 | model_upsample += [nn.ConvTranspose2d(ngf_global * 2, ngf_global, kernel_size=3, stride=2, padding=1, output_padding=1), 156 | norm_layer(ngf_global), nn.ReLU(True)] 157 | 158 | ### final convolution 159 | if n == n_local_enhancers: 160 | model_upsample += [nn.ReflectionPad2d(3), nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0), nn.Tanh()] 161 | 162 | setattr(self, 'model'+str(n)+'_1', nn.Sequential(*model_downsample)) 163 | setattr(self, 'model'+str(n)+'_2', nn.Sequential(*model_upsample)) 164 | 165 | self.downsample = nn.AvgPool2d(3, stride=2, padding=[1, 1], count_include_pad=False) 166 | 167 | def forward(self, _input): 168 | ### create input pyramid 169 | input_downsampled = [_input] 170 | for i in range(self.n_local_enhancers): 171 | input_downsampled.append(self.downsample(input_downsampled[-1])) 172 | 173 | ### output at coarest level 174 | output_prev = self.model(input_downsampled[-1]) 175 | ### build up one layer at a time 176 | for n_local_enhancers in range(1, self.n_local_enhancers+1): 177 | model_downsample = getattr(self, 'model'+str(n_local_enhancers)+'_1') 178 | model_upsample = getattr(self, 'model'+str(n_local_enhancers)+'_2') 179 | input_i = input_downsampled[self.n_local_enhancers-n_local_enhancers] 180 | output_prev = model_upsample(model_downsample(input_i) + output_prev) 181 | return output_prev 182 | 183 | class GlobalGenerator(nn.Module): 184 | def __init__(self, input_nc, output_nc, ngf=64, n_downsampling=3, n_blocks=9, norm_layer=nn.BatchNorm2d, 185 | padding_type='reflect'): 186 | assert(n_blocks >= 0) 187 | super(GlobalGenerator, self).__init__() 188 | activation = nn.ReLU(True) 189 | 190 | model = [nn.ReflectionPad2d(3), nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0), norm_layer(ngf), activation] 191 | ### downsample 192 | for i in range(n_downsampling): 193 | mult = 2**i 194 | model += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3, stride=2, padding=1), 195 | norm_layer(ngf * mult * 2), activation] 196 | 197 | ### resnet blocks 198 | mult = 2**n_downsampling 199 | for i in range(n_blocks): 200 | model += [ResnetBlock(ngf * mult, padding_type=padding_type, activation=activation, norm_layer=norm_layer)] 201 | 202 | ### upsample 203 | for i in range(n_downsampling): 204 | mult = 2**(n_downsampling - i) 205 | model += [nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2), kernel_size=3, stride=2, padding=1, output_padding=1), 206 | norm_layer(int(ngf * mult / 2)), activation] 207 | model += [nn.ReflectionPad2d(3), nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0), nn.Tanh()] 208 | self.model = nn.Sequential(*model) 209 | 210 | def forward(self, _input): 211 | return self.model(_input) 212 | 213 | # Define a resnet block 214 | class 
212 | 
213 | # Define a resnet block
214 | class ResnetBlock(nn.Module):
215 |     def __init__(self, dim, padding_type, norm_layer, activation=nn.ReLU(True), use_dropout=False):
216 |         super(ResnetBlock, self).__init__()
217 |         self.conv_block = self.build_conv_block(dim, padding_type, norm_layer, activation, use_dropout)
218 | 
219 |     def build_conv_block(self, dim, padding_type, norm_layer, activation, use_dropout):
220 |         conv_block = []
221 |         p = 0
222 |         if padding_type == 'reflect':
223 |             conv_block += [nn.ReflectionPad2d(1)]
224 |         elif padding_type == 'replicate':
225 |             conv_block += [nn.ReplicationPad2d(1)]
226 |         elif padding_type == 'zero':
227 |             p = 1
228 |         else:
229 |             raise NotImplementedError('padding [%s] is not implemented' % padding_type)
230 | 
231 |         conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p),
232 |                        norm_layer(dim),
233 |                        activation]
234 |         if use_dropout:
235 |             conv_block += [nn.Dropout(0.5)]
236 | 
237 |         p = 0
238 |         if padding_type == 'reflect':
239 |             conv_block += [nn.ReflectionPad2d(1)]
240 |         elif padding_type == 'replicate':
241 |             conv_block += [nn.ReplicationPad2d(1)]
242 |         elif padding_type == 'zero':
243 |             p = 1
244 |         else:
245 |             raise NotImplementedError('padding [%s] is not implemented' % padding_type)
246 |         conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p),
247 |                        norm_layer(dim)]
248 | 
249 |         return nn.Sequential(*conv_block)
250 | 
251 |     def forward(self, x):
252 |         out = x + self.conv_block(x)
253 |         return out
254 | 
255 | class Encoder(nn.Module):
256 |     def __init__(self, input_nc, output_nc, ngf=32, n_downsampling=4, norm_layer=nn.BatchNorm2d):
257 |         super(Encoder, self).__init__()
258 |         self.output_nc = output_nc
259 | 
260 |         model = [nn.ReflectionPad2d(3), nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0),
261 |                  norm_layer(ngf), nn.ReLU(True)]
262 |         ### downsample
263 |         for i in range(n_downsampling):
264 |             mult = 2**i
265 |             model += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3, stride=2, padding=1),
266 |                       norm_layer(ngf * mult * 2), nn.ReLU(True)]
267 | 
268 |         ### upsample
269 |         for i in range(n_downsampling):
270 |             mult = 2**(n_downsampling - i)
271 |             model += [nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2), kernel_size=3, stride=2, padding=1, output_padding=1),
272 |                       norm_layer(int(ngf * mult / 2)), nn.ReLU(True)]
273 | 
274 |         model += [nn.ReflectionPad2d(3), nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0), nn.Tanh()]
275 |         self.model = nn.Sequential(*model)
276 | 
277 |     def forward(self, _input, inst):
278 |         outputs = self.model(_input)
279 | 
280 |         # instance-wise average pooling
281 |         outputs_mean = outputs.clone()
282 |         inst_list = np.unique(inst.cpu().numpy().astype(int))
283 |         for i in inst_list:
284 |             for b in range(_input.size()[0]):
285 |                 indices = (inst[b:b+1] == int(i)).nonzero()  # n x 4
286 |                 for j in range(self.output_nc):
287 |                     output_ins = outputs[indices[:,0] + b, indices[:,1] + j, indices[:,2], indices[:,3]]
288 |                     mean_feat = torch.mean(output_ins).expand_as(output_ins)
289 |                     outputs_mean[indices[:,0] + b, indices[:,1] + j, indices[:,2], indices[:,3]] = mean_feat
290 |         return outputs_mean
291 | 
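Editor's note: the instance-wise average pooling above replaces every pixel belonging to an object with that object's mean feature, so each instance ends up summarized by a single output_nc-dimensional vector (which is what allows per-object features to be precomputed, as encode_features.py suggests). A minimal sketch with a fabricated two-instance map; the shapes are hypothetical and must be divisible by 2**n_downsampling = 16.

    import torch

    netE = Encoder(input_nc=3, output_nc=3)
    img = torch.randn(1, 3, 128, 128)
    inst = torch.zeros(1, 1, 128, 128)
    inst[:, :, :, 64:] = 1         # toy instance map: id 0 on the left, id 1 on the right
    feat = netE(img, inst)         # (1, 3, 128, 128), constant within each instance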
292 | class MultiscaleDiscriminator(nn.Module):
293 |     def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d,
294 |                  use_sigmoid=False, num_D=3, getIntermFeat=False):
295 |         super(MultiscaleDiscriminator, self).__init__()
296 |         self.num_D = num_D
297 |         self.n_layers = n_layers
298 |         self.getIntermFeat = getIntermFeat
299 | 
300 |         for i in range(num_D):
301 |             netD = NLayerDiscriminator(input_nc, ndf, n_layers, norm_layer, use_sigmoid, getIntermFeat)
302 |             if getIntermFeat:
303 |                 for j in range(n_layers+2):
304 |                     setattr(self, 'scale'+str(i)+'_layer'+str(j), getattr(netD, 'model'+str(j)))
305 |             else:
306 |                 setattr(self, 'layer'+str(i), netD.model)
307 | 
308 |         self.downsample = nn.AvgPool2d(3, stride=2, padding=[1, 1], count_include_pad=False)
309 | 
310 |     def singleD_forward(self, model, _input):
311 |         if self.getIntermFeat:
312 |             result = [_input]
313 |             for i in range(len(model)):
314 |                 result.append(model[i](result[-1]))
315 |             return result[1:]
316 |         else:
317 |             return [model(_input)]
318 | 
319 |     def forward(self, _input):
320 |         num_D = self.num_D
321 |         result = []
322 |         input_downsampled = _input
323 |         for i in range(num_D):
324 |             if self.getIntermFeat:
325 |                 model = [getattr(self, 'scale'+str(num_D-1-i)+'_layer'+str(j)) for j in range(self.n_layers+2)]
326 |             else:
327 |                 model = getattr(self, 'layer'+str(num_D-1-i))
328 |             result.append(self.singleD_forward(model, input_downsampled))
329 |             if i != (num_D-1):
330 |                 input_downsampled = self.downsample(input_downsampled)
331 |         return result
332 | 
333 | # Defines the PatchGAN discriminator with the specified arguments.
334 | class NLayerDiscriminator(nn.Module):
335 |     def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d, use_sigmoid=False, getIntermFeat=False):
336 |         super(NLayerDiscriminator, self).__init__()
337 |         self.getIntermFeat = getIntermFeat
338 |         self.n_layers = n_layers
339 | 
340 |         kw = 4
341 |         padw = int(np.ceil((kw-1.0)/2))
342 |         sequence = [[nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)]]
343 | 
344 |         nf = ndf
345 |         for n in range(1, n_layers):
346 |             nf_prev = nf
347 |             nf = min(nf * 2, 512)
348 |             sequence += [[
349 |                 nn.Conv2d(nf_prev, nf, kernel_size=kw, stride=2, padding=padw),
350 |                 norm_layer(nf), nn.LeakyReLU(0.2, True)
351 |             ]]
352 | 
353 |         nf_prev = nf
354 |         nf = min(nf * 2, 512)
355 |         sequence += [[
356 |             nn.Conv2d(nf_prev, nf, kernel_size=kw, stride=1, padding=padw),
357 |             norm_layer(nf),
358 |             nn.LeakyReLU(0.2, True)
359 |         ]]
360 | 
361 |         sequence += [[nn.Conv2d(nf, 1, kernel_size=kw, stride=1, padding=padw)]]
362 | 
363 |         if use_sigmoid:
364 |             sequence += [[nn.Sigmoid()]]
365 | 
366 |         if getIntermFeat:
367 |             for n in range(len(sequence)):
368 |                 setattr(self, 'model'+str(n), nn.Sequential(*sequence[n]))
369 |         else:
370 |             sequence_stream = []
371 |             for n in range(len(sequence)):
372 |                 sequence_stream += sequence[n]
373 |             self.model = nn.Sequential(*sequence_stream)
374 | 
375 |     def forward(self, _input):
376 |         if self.getIntermFeat:
377 |             res = [_input]
378 |             for n in range(self.n_layers+2):
379 |                 model = getattr(self, 'model'+str(n))
380 |                 res.append(model(res[-1]))
381 |             return res[1:]
382 |         else:
383 |             return self.model(_input)
384 | 
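Editor's note on the output nesting, since GANLoss.__call__ keys off it (an illustrative sketch): MultiscaleDiscriminator.forward returns one entry per scale, computed on progressively average-pooled copies of the input, and each entry is itself a list holding either the single patch map or, with getIntermFeat=True, the n_layers+2 intermediate activations typically used for a feature-matching loss.

    import torch

    netD = MultiscaleDiscriminator(input_nc=3, num_D=3, getIntermFeat=True)
    out = netD(torch.randn(1, 3, 256, 256))
    print(len(out), len(out[0]))   # 3 scales, 5 tensors per scale (n_layers + 2)
    print(out[0][-1].shape)        # 1-channel patch predictions at the full-resolution scale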
385 | from torchvision import models
386 | class Vgg19(torch.nn.Module):
387 |     def __init__(self, requires_grad=False):
388 |         super(Vgg19, self).__init__()
389 |         vgg_pretrained_features = models.vgg19(pretrained=True).features
390 |         self.slice1 = torch.nn.Sequential()
391 |         self.slice2 = torch.nn.Sequential()
392 |         self.slice3 = torch.nn.Sequential()
393 |         self.slice4 = torch.nn.Sequential()
394 |         self.slice5 = torch.nn.Sequential()
395 |         for x in range(2):        # up to relu1_1
396 |             self.slice1.add_module(str(x), vgg_pretrained_features[x])
397 |         for x in range(2, 7):     # up to relu2_1
398 |             self.slice2.add_module(str(x), vgg_pretrained_features[x])
399 |         for x in range(7, 12):    # up to relu3_1
400 |             self.slice3.add_module(str(x), vgg_pretrained_features[x])
401 |         for x in range(12, 21):   # up to relu4_1
402 |             self.slice4.add_module(str(x), vgg_pretrained_features[x])
403 |         for x in range(21, 30):   # up to relu5_1
404 |             self.slice5.add_module(str(x), vgg_pretrained_features[x])
405 |         if not requires_grad:
406 |             for param in self.parameters():
407 |                 param.requires_grad = False
408 | 
409 |     def forward(self, X):
410 |         h_relu1 = self.slice1(X)
411 |         h_relu2 = self.slice2(h_relu1)
412 |         h_relu3 = self.slice3(h_relu2)
413 |         h_relu4 = self.slice4(h_relu3)
414 |         h_relu5 = self.slice5(h_relu4)
415 |         out = [h_relu1, h_relu2, h_relu3, h_relu4, h_relu5]
416 |         return out
417 | 
--------------------------------------------------------------------------------