├── .gitattributes ├── imgs ├── teaser.png ├── teaser_1.jpg └── teaser_2.jpg ├── util ├── __init__.py ├── iter_counter.py └── util.py ├── options ├── __init__.py ├── test_options.py ├── train_options.py └── base_options.py ├── trainers ├── __init__.py └── pix2pix_trainer.py ├── requirements.txt ├── CODE_OF_CONDUCT.md ├── data ├── preprocess.py ├── __init__.py ├── base_dataset.py ├── pix2pix_dataset.py └── deepfashionHD_dataset.py ├── LICENSE ├── SUPPORT.md ├── models ├── __init__.py ├── networks │ ├── ops.py │ ├── __init__.py │ ├── generator.py │ ├── base_network.py │ ├── ContextualLoss.py │ ├── convgru.py │ ├── loss.py │ ├── normalization.py │ ├── discriminator.py │ ├── architecture.py │ ├── patch_match.py │ └── correspondence.py └── pix2pix_model.py ├── test.py ├── .gitignore ├── SECURITY.md ├── train.py └── README.md /.gitattributes: -------------------------------------------------------------------------------- 1 | # Auto detect text files and perform LF normalization 2 | * text=auto 3 | -------------------------------------------------------------------------------- /imgs/teaser.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/CoCosNet-v2/HEAD/imgs/teaser.png -------------------------------------------------------------------------------- /imgs/teaser_1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/CoCosNet-v2/HEAD/imgs/teaser_1.jpg -------------------------------------------------------------------------------- /imgs/teaser_2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/CoCosNet-v2/HEAD/imgs/teaser_2.jpg -------------------------------------------------------------------------------- /util/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. -------------------------------------------------------------------------------- /options/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. -------------------------------------------------------------------------------- /trainers/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch==1.7.0 2 | torchvision 3 | matplotlib 4 | pillow 5 | imageio 6 | numpy 7 | pandas 8 | scipy 9 | scikit-image 10 | opencv-python -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Microsoft Open Source Code of Conduct 2 | 3 | This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/). 
4 | 5 | Resources: 6 | 7 | - [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/) 8 | - [Microsoft Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) 9 | - Contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with questions or concerns 10 | -------------------------------------------------------------------------------- /options/test_options.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 3 | 4 | from .base_options import BaseOptions 5 | 6 | 7 | class TestOptions(BaseOptions): 8 | def initialize(self, parser): 9 | BaseOptions.initialize(self, parser) 10 | parser.add_argument('--results_dir', type=str, default='./results/', help='saves results here.') 11 | parser.add_argument('--which_epoch', type=str, default='latest', help='which epoch to load? set to latest to use latest cached model') 12 | parser.add_argument('--how_many', type=int, default=float("inf"), help='how many test images to run') 13 | parser.add_argument('--save_per_img', action='store_true', help='if specified, save per image') 14 | parser.add_argument('--show_corr', action='store_true', help='if specified, save bilinear upsample correspondence') 15 | parser.set_defaults(preprocess_mode='scale_width_and_crop', crop_size=256, load_size=256, display_winsize=256) 16 | parser.set_defaults(serial_batches=True) 17 | parser.set_defaults(no_flip=True) 18 | parser.set_defaults(phase='test') 19 | self.isTrain = False 20 | return parser 21 | -------------------------------------------------------------------------------- /data/preprocess.py: -------------------------------------------------------------------------------- 1 | import os 2 | import skimage.util as util 3 | from skimage import io 4 | from skimage.transform import resize 5 | 6 | 7 | with open('train.txt', 'r') as fd: 8 | image_files = fd.readlines() 9 | 10 | total = len(image_files) 11 | cnt = 0 12 | 13 | # path/to/deepfashion directory 14 | root = '/path/to/deepfashion' 15 | # path/to/save directory 16 | save_root = 'path/to/save' 17 | 18 | for image_file in image_files: 19 | image_file = os.path.join(root, image_file).strip() 20 | image = io.imread(image_file) 21 | pad_width_1 = (1101-750) // 2 22 | pad_width_2 = (1101-750) // 2 + 1 23 | image_pad = util.pad(image, ((0,0),(pad_width_1, pad_width_2),(0,0)), constant_values=232) 24 | image_resize = resize(image_pad, (1024, 1024)) 25 | image_resize = (image_resize * 255).astype('uint8') 26 | dst_file = os.path.dirname(image_file).replace(root, save_root) 27 | os.makedirs(dst_file, exist_ok=True) 28 | dst_file = os.path.join(dst_file, os.path.basename(image_file)) 29 | # dst_file = dst_file.replace('.jpg', '.png') 30 | io.imsave(dst_file, image_resize) 31 | cnt += 1 32 | if cnt % 20 == 0: 33 | print('Processing: %d / %d' % (cnt, total)) 34 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) Microsoft Corporation. 
4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE 22 | -------------------------------------------------------------------------------- /SUPPORT.md: -------------------------------------------------------------------------------- 1 | # TODO: The maintainer of this repo has not yet edited this file 2 | 3 | **REPO OWNER**: Do you want Customer Service & Support (CSS) support for this product/project? 4 | 5 | - **No CSS support:** Fill out this template with information about how to file issues and get help. 6 | - **Yes CSS support:** Fill out an intake form at [aka.ms/spot](https://aka.ms/spot). CSS will work with/help you to determine next steps. More details also available at [aka.ms/onboardsupport](https://aka.ms/onboardsupport). 7 | - **Not sure?** Fill out a SPOT intake as though the answer were "Yes". CSS will help you decide. 8 | 9 | *Then remove this first heading from this SUPPORT.MD file before publishing your repo.* 10 | 11 | # Support 12 | 13 | ## How to file issues and get help 14 | 15 | This project uses GitHub Issues to track bugs and feature requests. Please search the existing 16 | issues before filing new issues to avoid duplicates. For new issues, file your bug or 17 | feature request as a new Issue. 18 | 19 | For help and questions about using this project, please **REPO MAINTAINER: INSERT INSTRUCTIONS HERE 20 | FOR HOW TO ENGAGE REPO OWNERS OR COMMUNITY FOR HELP. COULD BE A STACK OVERFLOW TAG OR OTHER 21 | CHANNEL. WHERE WILL YOU HELP PEOPLE?**. 22 | 23 | ## Microsoft Support Policy 24 | 25 | Support for this **PROJECT or PRODUCT** is limited to the resources listed above. 26 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 3 | 4 | import torch 5 | import importlib 6 | 7 | 8 | def find_model_using_name(model_name): 9 | # Given the option --model [modelname], 10 | # the file "models/modelname_model.py" 11 | # will be imported. 12 | model_filename = "models." + model_name + "_model" 13 | modellib = importlib.import_module(model_filename) 14 | # In the file, the class called ModelNameModel() will 15 | # be instantiated. It has to be a subclass of torch.nn.Module, 16 | # and it is case-insensitive. 
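# For example, with the single model shipped in this repo, --model pix2pix imports
# "models.pix2pix_model" and returns the Pix2PixModel class defined in that file.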
17 | model = None 18 | target_model_name = model_name.replace('_', '') + 'model' 19 | for name, cls in modellib.__dict__.items(): 20 | if name.lower() == target_model_name.lower() \ 21 | and issubclass(cls, torch.nn.Module): 22 | model = cls 23 | if model is None: 24 | print("In %s.py, there should be a subclass of torch.nn.Module with class name that matches %s in lowercase." % (model_filename, target_model_name)) 25 | exit(0) 26 | return model 27 | 28 | 29 | def get_option_setter(model_name): 30 | model_class = find_model_using_name(model_name) 31 | return model_class.modify_commandline_options 32 | 33 | 34 | def create_model(opt): 35 | model = find_model_using_name(opt.model) 36 | instance = model(opt) 37 | print("model [%s] was created" % (type(instance).__name__)) 38 | return instance 39 | -------------------------------------------------------------------------------- /models/networks/ops.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 3 | 4 | 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | 9 | 10 | def convert_1d_to_2d(index, base=64): 11 | x = index // base 12 | y = index % base 13 | return x,y 14 | 15 | 16 | def convert_2d_to_1d(x, y, base=64): 17 | return x*base+y 18 | 19 | 20 | def batch_meshgrid(shape, device): 21 | batch_size, _, height, width = shape 22 | x_range = torch.arange(0.0, width, device=device) 23 | y_range = torch.arange(0.0, height, device=device) 24 | x_coordinate, y_coordinate = torch.meshgrid(x_range, y_range) 25 | x_coordinate = x_coordinate.expand(batch_size, -1, -1).unsqueeze(1) 26 | y_coordinate = y_coordinate.expand(batch_size, -1, -1).unsqueeze(1) 27 | return x_coordinate, y_coordinate 28 | 29 | 30 | def inds_to_offset(inds): 31 | """ 32 | inds: b x number x h x w 33 | """ 34 | shape = inds.size() 35 | device = inds.device 36 | x_coordinate, y_coordinate = batch_meshgrid(shape, device) 37 | batch_size, _, height, width = shape 38 | x = inds // width 39 | y = inds % width 40 | return x - x_coordinate, y - y_coordinate 41 | 42 | 43 | def offset_to_inds(offset_x, offset_y): 44 | shape = offset_x.size() 45 | device = offset_x.device 46 | x_coordinate, y_coordinate = batch_meshgrid(shape, device) 47 | h, w = offset_x.size()[2:] 48 | x = torch.clamp(x_coordinate + offset_x, 0, h-1) 49 | y = torch.clamp(y_coordinate + offset_y, 0, w-1) 50 | return x * offset_x.size()[3] + y 51 | -------------------------------------------------------------------------------- /data/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 3 | 4 | import importlib 5 | import torch.utils.data 6 | from data.base_dataset import BaseDataset 7 | 8 | 9 | def find_dataset_using_name(dataset_name): 10 | dataset_filename = "data." + dataset_name + "_dataset" 11 | datasetlib = importlib.import_module(dataset_filename) 12 | dataset = None 13 | target_dataset_name = dataset_name.replace('_', '') + 'dataset' 14 | for name, cls in datasetlib.__dict__.items(): 15 | if name.lower() == target_dataset_name.lower() \ 16 | and issubclass(cls, BaseDataset): 17 | dataset = cls 18 | if dataset is None: 19 | raise ValueError("In %s.py, there should be a subclass of BaseDataset " 20 | "with class name that matches %s in lowercase." 
% 21 | (dataset_filename, target_dataset_name)) 22 | return dataset 23 | 24 | 25 | def get_option_setter(dataset_name): 26 | dataset_class = find_dataset_using_name(dataset_name) 27 | return dataset_class.modify_commandline_options 28 | 29 | 30 | def create_dataloader(opt): 31 | dataset = find_dataset_using_name(opt.dataset_mode) 32 | instance = dataset() 33 | instance.initialize(opt) 34 | print("Dataset [%s] of size %d was created" % (type(instance).__name__, len(instance))) 35 | dataloader = torch.utils.data.DataLoader( 36 | instance, 37 | batch_size=opt.batchSize, 38 | shuffle=(opt.phase=='train'), 39 | num_workers=int(opt.nThreads), 40 | drop_last=(opt.phase=='train') 41 | ) 42 | return dataloader 43 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 3 | 4 | import torch 5 | from torchvision.utils import save_image 6 | import os 7 | import imageio 8 | import numpy as np 9 | import data 10 | from util.util import mkdir 11 | from options.test_options import TestOptions 12 | from models.pix2pix_model import Pix2PixModel 13 | 14 | 15 | if __name__ == '__main__': 16 | opt = TestOptions().parse() 17 | dataloader = data.create_dataloader(opt) 18 | model = Pix2PixModel(opt) 19 | if len(opt.gpu_ids) > 1: 20 | model = torch.nn.DataParallel(model, device_ids=opt.gpu_ids) 21 | else: 22 | model.to(opt.gpu_ids[0]) 23 | model.eval() 24 | save_root = os.path.join(opt.checkpoints_dir, opt.name, 'test') 25 | mkdir(save_root) 26 | for i, data_i in enumerate(dataloader): 27 | print('{} / {}'.format(i, len(dataloader))) 28 | if i * opt.batchSize >= opt.how_many: 29 | break 30 | imgs_num = data_i['label'].shape[0] 31 | out = model(data_i, mode='inference') 32 | if opt.save_per_img: 33 | try: 34 | for it in range(imgs_num): 35 | save_name = os.path.join(save_root, '%08d_%04d.png' % (i, it)) 36 | save_image(out['fake_image'][it:it+1], save_name, padding=0, normalize=True) 37 | except OSError as err: 38 | print(err) 39 | else: 40 | label = data_i['label'][:,:3,:,:] 41 | imgs = torch.cat((label.cpu(), data_i['ref'].cpu(), out['fake_image'].data.cpu()), 0) 42 | try: 43 | save_name = os.path.join(save_root, '%08d.png' % i) 44 | save_image(imgs, save_name, nrow=imgs_num, padding=0, normalize=True) 45 | except OSError as err: 46 | print(err) 47 | -------------------------------------------------------------------------------- /models/networks/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 3 | 4 | import torch 5 | from models.networks.base_network import BaseNetwork 6 | from models.networks.loss import * 7 | from models.networks.discriminator import * 8 | from models.networks.generator import * 9 | from models.networks.ContextualLoss import * 10 | from models.networks.correspondence import * 11 | from models.networks.ops import * 12 | import util.util as util 13 | 14 | 15 | def find_network_using_name(target_network_name, filename, add=True): 16 | target_class_name = target_network_name + filename if add else target_network_name 17 | module_name = 'models.networks.' 
+ filename 18 | network = util.find_class_in_module(target_class_name, module_name) 19 | assert issubclass(network, BaseNetwork), \ 20 | "Class %s should be a subclass of BaseNetwork" % network 21 | return network 22 | 23 | 24 | def modify_commandline_options(parser, is_train): 25 | opt, _ = parser.parse_known_args() 26 | netG_cls = find_network_using_name(opt.netG, 'generator') 27 | parser = netG_cls.modify_commandline_options(parser, is_train) 28 | if is_train: 29 | netD_cls = find_network_using_name(opt.netD, 'discriminator') 30 | parser = netD_cls.modify_commandline_options(parser, is_train) 31 | return parser 32 | 33 | 34 | def create_network(cls, opt): 35 | net = cls(opt) 36 | net.print_network() 37 | if len(opt.gpu_ids) > 0: 38 | assert(torch.cuda.is_available()) 39 | net.cuda() 40 | net.init_weights(opt.init_type, opt.init_variance) 41 | return net 42 | 43 | 44 | def define_G(opt): 45 | netG_cls = find_network_using_name(opt.netG, 'generator') 46 | return create_network(netG_cls, opt) 47 | 48 | 49 | def define_D(opt): 50 | netD_cls = find_network_using_name(opt.netD, 'discriminator') 51 | return create_network(netD_cls, opt) 52 | 53 | def define_Corr(opt): 54 | netCoor_cls = find_network_using_name(opt.netCorr, 'correspondence') 55 | return create_network(netCoor_cls, opt) 56 | -------------------------------------------------------------------------------- /models/networks/generator.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | from torch.autograd import Function 8 | 9 | from models.networks.base_network import BaseNetwork 10 | from models.networks.architecture import SPADEResnetBlock 11 | 12 | 13 | class SPADEGenerator(BaseNetwork): 14 | @staticmethod 15 | def modify_commandline_options(parser, is_train): 16 | parser.set_defaults(norm_G='spectralspadesyncbatch3x3') 17 | return parser 18 | 19 | def __init__(self, opt): 20 | super().__init__() 21 | self.opt = opt 22 | nf = opt.ngf 23 | self.sw, self.sh = self.compute_latent_vector_size(opt) 24 | ic = 4*3+opt.label_nc 25 | self.fc = nn.Conv2d(ic, 8 * nf, 3, padding=1) 26 | self.head_0 = SPADEResnetBlock(8 * nf, 8 * nf, opt) 27 | self.G_middle_0 = SPADEResnetBlock(8 * nf, 8 * nf, opt) 28 | self.G_middle_1 = SPADEResnetBlock(8 * nf, 8 * nf, opt) 29 | self.up_0 = SPADEResnetBlock(8 * nf, 8 * nf, opt) 30 | self.up_1 = SPADEResnetBlock(8 * nf, 4 * nf, opt) 31 | self.up_2 = SPADEResnetBlock(4 * nf, 2 * nf, opt) 32 | self.up_3 = SPADEResnetBlock(2 * nf, 1 * nf, opt) 33 | final_nc = nf 34 | self.conv_img = nn.Conv2d(final_nc, 3, 3, padding=1) 35 | self.up = nn.Upsample(scale_factor=2) 36 | 37 | def compute_latent_vector_size(self, opt): 38 | num_up_layers = 5 39 | sw = opt.crop_size // (2**num_up_layers) 40 | sh = round(sw / opt.aspect_ratio) 41 | return sw, sh 42 | 43 | def forward(self, input, warp_out=None): 44 | seg = torch.cat((F.interpolate(warp_out[0], size=(512, 512)), F.interpolate(warp_out[1], size=(512, 512)), F.interpolate(warp_out[2], size=(512, 512)), warp_out[3], input), dim=1) 45 | x = F.interpolate(seg, size=(self.sh, self.sw)) 46 | x = self.fc(x) 47 | x = self.head_0(x, seg) 48 | x = self.up(x) 49 | x = self.G_middle_0(x, seg) 50 | x = self.G_middle_1(x, seg) 51 | x = self.up(x) 52 | x = self.up_0(x, seg) 53 | x = self.up(x) 54 | x = self.up_1(x, seg) 55 | x = self.up(x) 56 | x = self.up_2(x, seg) 57 | x = self.up(x) 58 
| x = self.up_3(x, seg) 59 | x = self.conv_img(F.leaky_relu(x, 2e-1)) 60 | x = torch.tanh(x) 61 | return x 62 | -------------------------------------------------------------------------------- /models/networks/base_network.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 3 | 4 | 5 | import torch.nn as nn 6 | from torch.nn import init 7 | 8 | 9 | class BaseNetwork(nn.Module): 10 | def __init__(self): 11 | super(BaseNetwork, self).__init__() 12 | 13 | @staticmethod 14 | def modify_commandline_options(parser, is_train): 15 | return parser 16 | 17 | def print_network(self): 18 | if isinstance(self, list): 19 | self = self[0] 20 | num_params = 0 21 | for param in self.parameters(): 22 | num_params += param.numel() 23 | print('Network [%s] was created. Total number of parameters: %.1f million. ' 24 | 'To see the architecture, do print(network).' 25 | % (type(self).__name__, num_params / 1000000)) 26 | 27 | def init_weights(self, init_type='normal', gain=0.02): 28 | def init_func(m): 29 | classname = m.__class__.__name__ 30 | if classname.find('BatchNorm2d') != -1: 31 | if hasattr(m, 'weight') and m.weight is not None: 32 | init.normal_(m.weight.data, 1.0, gain) 33 | if hasattr(m, 'bias') and m.bias is not None: 34 | init.constant_(m.bias.data, 0.0) 35 | elif hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1): 36 | if init_type == 'normal': 37 | init.normal_(m.weight.data, 0.0, gain) 38 | elif init_type == 'xavier': 39 | init.xavier_normal_(m.weight.data, gain=gain) 40 | elif init_type == 'xavier_uniform': 41 | init.xavier_uniform_(m.weight.data, gain=1.0) 42 | elif init_type == 'kaiming': 43 | init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') 44 | elif init_type == 'orthogonal': 45 | init.orthogonal_(m.weight.data, gain=gain) 46 | elif init_type == 'none': # uses pytorch's default init method 47 | m.reset_parameters() 48 | else: 49 | raise NotImplementedError('initialization method [%s] is not implemented' % init_type) 50 | if hasattr(m, 'bias') and m.bias is not None: 51 | init.constant_(m.bias.data, 0.0) 52 | self.apply(init_func) 53 | # propagate to children 54 | for m in self.children(): 55 | if hasattr(m, 'init_weights'): 56 | m.init_weights(init_type, gain) 57 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 98 | __pypackages__/ 99 | 100 | # Celery stuff 101 | celerybeat-schedule 102 | celerybeat.pid 103 | 104 | # SageMath parsed files 105 | *.sage.py 106 | 107 | # Environments 108 | .env 109 | .venv 110 | env/ 111 | venv/ 112 | ENV/ 113 | env.bak/ 114 | venv.bak/ 115 | 116 | # Spyder project settings 117 | .spyderproject 118 | .spyproject 119 | 120 | # Rope project settings 121 | .ropeproject 122 | 123 | # mkdocs documentation 124 | /site 125 | 126 | # mypy 127 | .mypy_cache/ 128 | .dmypy.json 129 | dmypy.json 130 | 131 | # Pyre type checker 132 | .pyre/ 133 | 134 | # pytype static type analyzer 135 | .pytype/ 136 | 137 | # Cython debug symbols 138 | cython_debug/ 139 | -------------------------------------------------------------------------------- /SECURITY.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | ## Security 4 | 5 | Microsoft takes the security of our software products and services seriously, which includes all source code repositories managed through our GitHub organizations, which include [Microsoft](https://github.com/Microsoft), [Azure](https://github.com/Azure), [DotNet](https://github.com/dotnet), [AspNet](https://github.com/aspnet), [Xamarin](https://github.com/xamarin), and [our GitHub organizations](https://opensource.microsoft.com/). 6 | 7 | If you believe you have found a security vulnerability in any Microsoft-owned repository that meets [Microsoft's definition of a security vulnerability](https://docs.microsoft.com/en-us/previous-versions/tn-archive/cc751383(v=technet.10)), please report it to us as described below. 8 | 9 | ## Reporting Security Issues 10 | 11 | **Please do not report security vulnerabilities through public GitHub issues.** 12 | 13 | Instead, please report them to the Microsoft Security Response Center (MSRC) at [https://msrc.microsoft.com/create-report](https://msrc.microsoft.com/create-report). 14 | 15 | If you prefer to submit without logging in, send email to [secure@microsoft.com](mailto:secure@microsoft.com). 
If possible, encrypt your message with our PGP key; please download it from the [Microsoft Security Response Center PGP Key page](https://www.microsoft.com/en-us/msrc/pgp-key-msrc). 16 | 17 | You should receive a response within 24 hours. If for some reason you do not, please follow up via email to ensure we received your original message. Additional information can be found at [microsoft.com/msrc](https://www.microsoft.com/msrc). 18 | 19 | Please include the requested information listed below (as much as you can provide) to help us better understand the nature and scope of the possible issue: 20 | 21 | * Type of issue (e.g. buffer overflow, SQL injection, cross-site scripting, etc.) 22 | * Full paths of source file(s) related to the manifestation of the issue 23 | * The location of the affected source code (tag/branch/commit or direct URL) 24 | * Any special configuration required to reproduce the issue 25 | * Step-by-step instructions to reproduce the issue 26 | * Proof-of-concept or exploit code (if possible) 27 | * Impact of the issue, including how an attacker might exploit the issue 28 | 29 | This information will help us triage your report more quickly. 30 | 31 | If you are reporting for a bug bounty, more complete reports can contribute to a higher bounty award. Please visit our [Microsoft Bug Bounty Program](https://microsoft.com/msrc/bounty) page for more details about our active programs. 32 | 33 | ## Preferred Languages 34 | 35 | We prefer all communications to be in English. 36 | 37 | ## Policy 38 | 39 | Microsoft follows the principle of [Coordinated Vulnerability Disclosure](https://www.microsoft.com/en-us/msrc/cvd). 40 | 41 | -------------------------------------------------------------------------------- /models/networks/ContextualLoss.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 
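# Informal summary of the computation in forward() below (X and Y are the two feature maps,
# h is the bandwidth argument; shapes are noted inline in the code):
#   d_ij    = 1 - cos(x_i, y_j)                  cosine distance between normalized features
#   dbar_ij = d_ij / (min_k d_ik + 1e-3)         distance normalized per query feature
#   w_ij    = exp((1 - dbar_ij) / h)             pairwise affinity
#   A_ij    = w_ij / sum_k w_ik                  row-normalized affinity
#   CX      = mean_i max_j A_ij                  per-sample contextual similarity
#   loss    = -log(CX)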
3 | 4 | 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | 9 | from util.util import feature_normalize, mse_loss 10 | 11 | 12 | 13 | class ContextualLoss_forward(nn.Module): 14 | ''' 15 | input is Al, Bl, channel = 1, range ~ [0, 255] 16 | ''' 17 | 18 | def __init__(self, opt): 19 | super(ContextualLoss_forward, self).__init__() 20 | self.opt = opt 21 | return None 22 | 23 | def forward(self, X_features, Y_features, h=0.1, feature_centering=True): 24 | ''' 25 | X_features & Y_features are feature vectors or 2d feature maps 26 | h: bandwidth 27 | return the per-sample loss 28 | ''' 29 | batch_size = X_features.shape[0] 30 | feature_depth = X_features.shape[1] 31 | feature_size = X_features.shape[2] 32 | 33 | # to normalized feature vectors 34 | if feature_centering: 35 | if self.opt.PONO: 36 | X_features = X_features - Y_features.mean(dim=1).unsqueeze(dim=1) 37 | Y_features = Y_features - Y_features.mean(dim=1).unsqueeze(dim=1) 38 | else: 39 | X_features = X_features - Y_features.view(batch_size, feature_depth, -1).mean(dim=-1).unsqueeze(dim=-1).unsqueeze(dim=-1) 40 | Y_features = Y_features - Y_features.view(batch_size, feature_depth, -1).mean(dim=-1).unsqueeze(dim=-1).unsqueeze(dim=-1) 41 | 42 | # X_features = X_features - Y_features.mean(dim=1).unsqueeze(dim=1) 43 | # Y_features = Y_features - Y_features.mean(dim=1).unsqueeze(dim=1) 44 | 45 | X_features = feature_normalize(X_features).view(batch_size, feature_depth, -1) # batch_size * feature_depth * feature_size * feature_size 46 | Y_features = feature_normalize(Y_features).view(batch_size, feature_depth, -1) # batch_size * feature_depth * feature_size * feature_size 47 | 48 | # X_features = F.unfold( 49 | # X_features, kernel_size=self.opt.match_kernel, stride=1, padding=int(self.opt.match_kernel // 2)) # batch_size * feature_depth_new * feature_size^2 50 | # Y_features = F.unfold( 51 | # Y_features, kernel_size=self.opt.match_kernel, stride=1, padding=int(self.opt.match_kernel // 2)) # batch_size * feature_depth_new * feature_size^2 52 | 53 | # cosine distance = 1 - similarity 54 | X_features_permute = X_features.permute(0, 2, 1) # batch_size * feature_size^2 * feature_depth 55 | d = 1 - torch.matmul(X_features_permute, Y_features) # batch_size * feature_size^2 * feature_size^2 56 | 57 | # normalized distance: dij_bar 58 | # d_norm = d 59 | d_norm = d / (torch.min(d, dim=-1, keepdim=True)[0] + 1e-3) # batch_size * feature_size^2 * feature_size^2 60 | 61 | # pairwise affinity 62 | w = torch.exp((1 - d_norm) / h) 63 | A_ij = w / torch.sum(w, dim=-1, keepdim=True) 64 | 65 | # contextual loss per sample 66 | CX = torch.mean(torch.max(A_ij, dim=-1)[0], dim=1) 67 | loss = -torch.log(CX) 68 | 69 | # contextual loss per batch 70 | # loss = torch.mean(loss) 71 | return loss 72 | -------------------------------------------------------------------------------- /util/iter_counter.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License.
3 | 4 | import os 5 | import time 6 | import numpy as np 7 | 8 | # Helper class that keeps track of training iterations 9 | class IterationCounter(): 10 | def __init__(self, opt, dataset_size): 11 | self.opt = opt 12 | self.dataset_size = dataset_size 13 | self.batch_size = opt.batchSize 14 | self.first_epoch = 1 15 | self.total_epochs = opt.niter + opt.niter_decay 16 | # iter number within each epoch 17 | self.epoch_iter = 0 18 | self.iter_record_path = os.path.join(self.opt.checkpoints_dir, self.opt.name, 'iter.txt') 19 | if opt.isTrain and opt.continue_train: 20 | try: 21 | self.first_epoch, self.epoch_iter = np.loadtxt(self.iter_record_path, delimiter=',', dtype=int) 22 | print('Resuming from epoch %d at iteration %d' % (self.first_epoch, self.epoch_iter)) 23 | except: 24 | print('Could not load iteration record at %s. Starting from beginning.' % self.iter_record_path) 25 | self.epoch_iter_num = dataset_size * self.batch_size 26 | self.total_steps_so_far = (self.first_epoch - 1) * self.epoch_iter_num + self.epoch_iter 27 | self.continue_train_flag = opt.continue_train 28 | 29 | # return the iterator of epochs for the training 30 | def training_epochs(self): 31 | return range(self.first_epoch, self.total_epochs + 1) 32 | 33 | def record_epoch_start(self, epoch): 34 | self.epoch_start_time = time.time() 35 | if not self.continue_train_flag: 36 | self.epoch_iter = 0 37 | else: 38 | self.continue_train_flag = False 39 | self.last_iter_time = time.time() 40 | self.current_epoch = epoch 41 | 42 | def record_one_iteration(self): 43 | current_time = time.time() 44 | # the last remaining batch is dropped (see data/__init__.py), 45 | # so we can assume batch size is always opt.batchSize 46 | self.time_per_iter = (current_time - self.last_iter_time) / self.opt.batchSize 47 | self.last_iter_time = current_time 48 | self.total_steps_so_far += self.opt.batchSize 49 | self.epoch_iter += self.opt.batchSize 50 | 51 | def record_epoch_end(self): 52 | current_time = time.time() 53 | self.time_per_epoch = current_time - self.epoch_start_time 54 | print('End of epoch %d / %d \t Time Taken: %d sec' % 55 | (self.current_epoch, self.total_epochs, self.time_per_epoch)) 56 | if self.current_epoch % self.opt.save_epoch_freq == 0: 57 | np.savetxt(self.iter_record_path, (self.current_epoch + 1, 0), 58 | delimiter=',', fmt='%d') 59 | print('Saved current iteration count at %s.' % self.iter_record_path) 60 | 61 | def record_current_iter(self): 62 | np.savetxt(self.iter_record_path, (self.current_epoch, self.epoch_iter), 63 | delimiter=',', fmt='%d') 64 | print('Saved current iteration count at %s.' % self.iter_record_path) 65 | 66 | def needs_saving(self): 67 | return (self.total_steps_so_far % self.opt.save_latest_freq) < self.opt.batchSize 68 | 69 | def needs_printing(self): 70 | return (self.total_steps_so_far % self.opt.print_freq) < self.opt.batchSize 71 | 72 | def needs_displaying(self): 73 | return (self.total_steps_so_far % self.opt.display_freq) < self.opt.batchSize 74 | -------------------------------------------------------------------------------- /models/networks/convgru.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 
3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | 9 | class FlowHead(nn.Module): 10 | def __init__(self, input_dim=32, hidden_dim=64): 11 | super(FlowHead, self).__init__() 12 | candidate_num = 16 13 | self.conv1 = nn.Conv2d(input_dim, hidden_dim, 3, padding=1) 14 | self.conv2 = nn.Conv2d(hidden_dim, 2*candidate_num, 3, padding=1) 15 | self.relu = nn.ReLU(inplace=True) 16 | def forward(self, x): 17 | x = self.conv1(x) 18 | x = self.relu(x) 19 | x = self.conv2(x) 20 | num = x.size()[1] 21 | delta_offset_x, delta_offset_y = torch.split(x, [num//2, num//2], dim=1) 22 | return delta_offset_x, delta_offset_y 23 | 24 | 25 | class SepConvGRU(nn.Module): 26 | def __init__(self, hidden_dim=32, input_dim=64): 27 | super(SepConvGRU, self).__init__() 28 | self.convz1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2)) 29 | self.convr1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2)) 30 | self.convq1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2)) 31 | self.convz2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0)) 32 | self.convr2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0)) 33 | self.convq2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0)) 34 | 35 | def forward(self, h, x): 36 | # horizontal 37 | hx = torch.cat([h, x], dim=1) 38 | z = torch.sigmoid(self.convz1(hx)) 39 | r = torch.sigmoid(self.convr1(hx)) 40 | q = torch.tanh(self.convq1(torch.cat([r*h, x], dim=1))) 41 | h = (1-z) * h + z * q 42 | # vertical 43 | hx = torch.cat([h, x], dim=1) 44 | z = torch.sigmoid(self.convz2(hx)) 45 | r = torch.sigmoid(self.convr2(hx)) 46 | q = torch.tanh(self.convq2(torch.cat([r*h, x], dim=1))) 47 | h = (1-z) * h + z * q 48 | return h 49 | 50 | 51 | class BasicMotionEncoder(nn.Module): 52 | def __init__(self): 53 | super(BasicMotionEncoder, self).__init__() 54 | candidate_num = 16 55 | self.convc1 = nn.Conv2d(candidate_num, 64, 1, padding=0) 56 | self.convc2 = nn.Conv2d(64, 64, 3, padding=1) 57 | self.convf1 = nn.Conv2d(2*candidate_num, 64, 7, padding=3) 58 | self.convf2 = nn.Conv2d(64, 64, 3, padding=1) 59 | self.conv = nn.Conv2d(64+64, 64-2*candidate_num, 3, padding=1) 60 | 61 | def forward(self, flow, corr): 62 | cor = F.relu(self.convc1(corr)) 63 | cor = F.relu(self.convc2(cor)) 64 | flo = F.relu(self.convf1(flow)) 65 | flo = F.relu(self.convf2(flo)) 66 | cor_flo = torch.cat([cor, flo], dim=1) 67 | out = F.relu(self.conv(cor_flo)) 68 | return torch.cat([out, flow], dim=1) 69 | 70 | 71 | class BasicUpdateBlock(nn.Module): 72 | def __init__(self): 73 | super(BasicUpdateBlock, self).__init__() 74 | self.encoder = BasicMotionEncoder() 75 | self.gru = SepConvGRU() 76 | self.flow_head = FlowHead() 77 | 78 | def forward(self, net, corr, flow): 79 | motion_features = self.encoder(flow, corr) 80 | inp = motion_features 81 | net = self.gru(net, inp) 82 | delta_offset_x, delta_offset_y = self.flow_head(net) 83 | return net, delta_offset_x, delta_offset_y 84 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 
3 | 4 | import os 5 | import sys 6 | import torch 7 | from torchvision.utils import save_image 8 | from options.train_options import TrainOptions 9 | import data 10 | from util.iter_counter import IterationCounter 11 | from util.util import print_current_errors 12 | from util.util import mkdir 13 | from trainers.pix2pix_trainer import Pix2PixTrainer 14 | 15 | 16 | if __name__ == '__main__': 17 | # parse options 18 | opt = TrainOptions().parse() 19 | # print options to help debugging 20 | print(' '.join(sys.argv)) 21 | dataloader = data.create_dataloader(opt) 22 | len_dataloader = len(dataloader) 23 | # create tool for counting iterations 24 | iter_counter = IterationCounter(opt, len(dataloader)) 25 | # create trainer for our model 26 | trainer = Pix2PixTrainer(opt, resume_epoch=iter_counter.first_epoch) 27 | save_root = os.path.join('checkpoints', opt.name, 'train') 28 | mkdir(save_root) 29 | 30 | for epoch in iter_counter.training_epochs(): 31 | opt.epoch = epoch 32 | iter_counter.record_epoch_start(epoch) 33 | for i, data_i in enumerate(dataloader, start=iter_counter.epoch_iter): 34 | iter_counter.record_one_iteration() 35 | # Training 36 | # train generator 37 | if i % opt.D_steps_per_G == 0: 38 | trainer.run_generator_one_step(data_i) 39 | # train discriminator 40 | trainer.run_discriminator_one_step(data_i) 41 | if iter_counter.needs_printing(): 42 | losses = trainer.get_latest_losses() 43 | try: 44 | print_current_errors(opt, epoch, iter_counter.epoch_iter, 45 | iter_counter.epoch_iter_num, losses, iter_counter.time_per_iter) 46 | except OSError as err: 47 | print(err) 48 | 49 | if iter_counter.needs_displaying(): 50 | imgs_num = data_i['label'].shape[0] 51 | 52 | if opt.dataset_mode == 'deepfashionHD': 53 | label = data_i['label'][:,:3,:,:] 54 | 55 | show_size = opt.display_winsize 56 | 57 | imgs = torch.cat((label.cpu(), data_i['ref'].cpu(), \ 58 | trainer.get_latest_generated().data.cpu(), \ 59 | data_i['image'].cpu()), 0) 60 | 61 | try: 62 | save_name = '%08d_%08d.png' % (epoch, iter_counter.total_steps_so_far) 63 | save_name = os.path.join(save_root, save_name) 64 | save_image(imgs, save_name, nrow=imgs_num, padding=0, normalize=True) 65 | except OSError as err: 66 | print(err) 67 | 68 | if iter_counter.needs_saving(): 69 | print('saving the latest model (epoch %d, total_steps %d)' % 70 | (epoch, iter_counter.total_steps_so_far)) 71 | try: 72 | trainer.save('latest') 73 | iter_counter.record_current_iter() 74 | except OSError as err: 75 | import pdb; pdb.set_trace() 76 | print(err) 77 | 78 | trainer.update_learning_rate(epoch) 79 | iter_counter.record_epoch_end() 80 | 81 | if epoch % opt.save_epoch_freq == 0 or epoch == iter_counter.total_epochs: 82 | print('saving the model at the end of epoch %d, iters %d' % 83 | (epoch, iter_counter.total_steps_so_far)) 84 | try: 85 | trainer.save('latest') 86 | trainer.save(epoch) 87 | except OSError as err: 88 | print(err) 89 | 90 | print('Training was successfully finished.') 91 | -------------------------------------------------------------------------------- /util/util.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 
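# Note on checkpoints: save_network() and load_network() below use the file name pattern
# '<epoch>_net_<label>.pth' inside os.path.join(opt.checkpoints_dir, opt.name); for example a
# network saved with epoch='latest' and label='G' would be 'latest_net_G.pth' (the actual label
# strings are chosen by the trainer, which is not reproduced in this section).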
3 | 4 | import os 5 | import re 6 | import argparse 7 | from argparse import Namespace 8 | import torch 9 | import numpy as np 10 | import importlib 11 | from PIL import Image 12 | 13 | 14 | 15 | def feature_normalize(feature_in, eps=1e-10): 16 | feature_in_norm = torch.norm(feature_in, 2, 1, keepdim=True) + eps 17 | feature_in_norm = torch.div(feature_in, feature_in_norm) 18 | return feature_in_norm 19 | 20 | 21 | def weighted_l1_loss(input, target, weights): 22 | out = torch.abs(input - target) 23 | out = out * weights.expand_as(out) 24 | loss = out.mean() 25 | return loss 26 | 27 | 28 | def mse_loss(input, target=0): 29 | return torch.mean((input - target)**2) 30 | 31 | 32 | def vgg_preprocess(tensor, vgg_normal_correct=False): 33 | if vgg_normal_correct: 34 | tensor = (tensor + 1) / 2 35 | # input is RGB tensor which ranges in [0,1] 36 | # output is BGR tensor which ranges in [0,255] 37 | tensor_bgr = torch.cat((tensor[:, 2:3, :, :], tensor[:, 1:2, :, :], tensor[:, 0:1, :, :]), dim=1) 38 | # tensor_bgr = tensor[:, [2, 1, 0], ...] 39 | tensor_bgr_ml = tensor_bgr - torch.Tensor([0.40760392, 0.45795686, 0.48501961]).type_as(tensor_bgr).view(1, 3, 1, 1) 40 | tensor_rst = tensor_bgr_ml * 255 41 | return tensor_rst 42 | 43 | 44 | def mkdirs(paths): 45 | if isinstance(paths, list) and not isinstance(paths, str): 46 | for path in paths: 47 | mkdir(path) 48 | else: 49 | mkdir(paths) 50 | 51 | 52 | def mkdir(path): 53 | if not os.path.exists(path): 54 | os.makedirs(path) 55 | 56 | 57 | def find_class_in_module(target_cls_name, module): 58 | target_cls_name = target_cls_name.replace('_', '').lower() 59 | clslib = importlib.import_module(module) 60 | cls = None 61 | for name, clsobj in clslib.__dict__.items(): 62 | if name.lower() == target_cls_name: 63 | cls = clsobj 64 | if cls is None: 65 | print("In %s, there should be a class whose name matches %s in lowercase without underscore(_)" % (module, target_cls_name)) 66 | exit(0) 67 | return cls 68 | 69 | 70 | def save_network(net, label, epoch, opt): 71 | save_filename = '%s_net_%s.pth' % (epoch, label) 72 | save_path = os.path.join(opt.checkpoints_dir, opt.name, save_filename) 73 | torch.save(net.cpu().state_dict(), save_path) 74 | if len(opt.gpu_ids) and torch.cuda.is_available(): 75 | net.cuda() 76 | 77 | 78 | def load_network(net, label, epoch, opt): 79 | save_filename = '%s_net_%s.pth' % (epoch, label) 80 | save_dir = os.path.join(opt.checkpoints_dir, opt.name) 81 | save_path = os.path.join(save_dir, save_filename) 82 | if not os.path.exists(save_path): 83 | print('not find model :' + save_path + ', do not load model!') 84 | return net 85 | weights = torch.load(save_path) 86 | try: 87 | net.load_state_dict(weights) 88 | except KeyError: 89 | print('key error, not load!') 90 | except RuntimeError as err: 91 | print(err) 92 | net.load_state_dict(weights, strict=False) 93 | print('loaded with strict = False') 94 | print('Load from ' + save_path) 95 | return net 96 | 97 | 98 | def print_current_errors(opt, epoch, i, num, errors, t): 99 | message = '(epoch: %d, iters: %d, finish: %.2f%%, time: %.3f) ' % (epoch, i, (i/num)*100.0, t) 100 | for k, v in errors.items(): 101 | v = v.mean().float() 102 | message += '%s: %.3f ' % (k, v) 103 | print(message) 104 | log_name = os.path.join(opt.checkpoints_dir, opt.name, 'loss_log.txt') 105 | with open(log_name, "a") as log_file: 106 | log_file.write('%s\n' % message) 107 | -------------------------------------------------------------------------------- /models/networks/loss.py: 
-------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | 9 | class GANLoss(nn.Module): 10 | def __init__(self, gan_mode, target_real_label=1.0, target_fake_label=0.0, 11 | tensor=torch.FloatTensor, opt=None): 12 | super(GANLoss, self).__init__() 13 | self.real_label = target_real_label 14 | self.fake_label = target_fake_label 15 | self.real_label_tensor = None 16 | self.fake_label_tensor = None 17 | self.zero_tensor = None 18 | self.Tensor = tensor 19 | self.gan_mode = gan_mode 20 | self.opt = opt 21 | if gan_mode == 'ls': 22 | pass 23 | elif gan_mode == 'original': 24 | pass 25 | elif gan_mode == 'w': 26 | pass 27 | elif gan_mode == 'hinge': 28 | pass 29 | else: 30 | raise ValueError('Unexpected gan_mode {}'.format(gan_mode)) 31 | 32 | def get_target_tensor(self, input, target_is_real): 33 | if target_is_real: 34 | if self.real_label_tensor is None: 35 | self.real_label_tensor = self.Tensor(1).fill_(self.real_label) 36 | self.real_label_tensor.requires_grad_(False) 37 | return self.real_label_tensor.expand_as(input) 38 | else: 39 | if self.fake_label_tensor is None: 40 | self.fake_label_tensor = self.Tensor(1).fill_(self.fake_label) 41 | self.fake_label_tensor.requires_grad_(False) 42 | return self.fake_label_tensor.expand_as(input) 43 | 44 | def get_zero_tensor(self, input): 45 | if self.zero_tensor is None: 46 | self.zero_tensor = self.Tensor(1).fill_(0) 47 | self.zero_tensor.requires_grad_(False) 48 | return self.zero_tensor.expand_as(input).type_as(input) 49 | 50 | def loss(self, input, target_is_real, for_discriminator=True): 51 | if self.gan_mode == 'original': # cross entropy loss 52 | target_tensor = self.get_target_tensor(input, target_is_real) 53 | loss = F.binary_cross_entropy_with_logits(input, target_tensor) 54 | return loss 55 | elif self.gan_mode == 'ls': 56 | target_tensor = self.get_target_tensor(input, target_is_real) 57 | return F.mse_loss(input, target_tensor) 58 | elif self.gan_mode == 'hinge': 59 | if for_discriminator: 60 | if target_is_real: 61 | minval = torch.min(input - 1, self.get_zero_tensor(input)) 62 | loss = -torch.mean(minval) 63 | else: 64 | minval = torch.min(-input - 1, self.get_zero_tensor(input)) 65 | loss = -torch.mean(minval) 66 | else: 67 | assert target_is_real, "The generator's hinge loss must be aiming for real" 68 | loss = -torch.mean(input) 69 | return loss 70 | else: 71 | # wgan 72 | if target_is_real: 73 | return -input.mean() 74 | else: 75 | return input.mean() 76 | 77 | def __call__(self, input, target_is_real, for_discriminator=True): 78 | if isinstance(input, list): 79 | loss = 0 80 | for pred_i in input: 81 | if isinstance(pred_i, list): 82 | pred_i = pred_i[-1] 83 | loss_tensor = self.loss(pred_i, target_is_real, for_discriminator) 84 | bs = 1 if len(loss_tensor.size()) == 0 else loss_tensor.size(0) 85 | new_loss = torch.mean(loss_tensor.view(bs, -1), dim=1) 86 | loss += new_loss 87 | return loss / len(input) 88 | else: 89 | return self.loss(input, target_is_real, for_discriminator) 90 | -------------------------------------------------------------------------------- /options/train_options.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 
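# Note on the schedule options defined below: training runs for niter + niter_decay epochs in
# total; the learning rate stays at --lr for the first `niter` epochs and is then decayed
# linearly to zero over the following `niter_decay` epochs (the defaults 100 / 0 therefore mean
# 100 epochs at a constant learning rate).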
3 | 4 | 5 | from .base_options import BaseOptions 6 | 7 | 8 | class TrainOptions(BaseOptions): 9 | def initialize(self, parser): 10 | BaseOptions.initialize(self, parser) 11 | # for displays 12 | parser.add_argument('--display_freq', type=int, default=2000, help='frequency of showing training results on screen') 13 | parser.add_argument('--print_freq', type=int, default=100, help='frequency of showing training results on console') 14 | parser.add_argument('--save_latest_freq', type=int, default=5000, help='frequency of saving the latest results') 15 | parser.add_argument('--save_epoch_freq', type=int, default=10, help='frequency of saving checkpoints at the end of epochs') 16 | # for training 17 | parser.add_argument('--continue_train', action='store_true', help='continue training: load the latest model') 18 | parser.add_argument('--which_epoch', type=str, default='latest', help='which epoch to load? set to latest to use latest cached model') 19 | parser.add_argument('--niter', type=int, default=100, help='# of epochs at starting learning rate. This is NOT the total #epochs. Total #epochs is niter + niter_decay') 20 | parser.add_argument('--niter_decay', type=int, default=0, help='# of epochs to linearly decay learning rate to zero') 21 | parser.add_argument('--optimizer', type=str, default='adam') 22 | parser.add_argument('--beta1', type=float, default=0.5, help='momentum term of adam') 23 | parser.add_argument('--beta2', type=float, default=0.999, help='momentum term of adam') 24 | parser.add_argument('--lr', type=float, default=0.0002, help='initial learning rate for adam') 25 | parser.add_argument('--D_steps_per_G', type=int, default=1, help='number of discriminator iterations per generator iteration.') 26 | # for discriminators 27 | parser.add_argument('--ndf', type=int, default=64, help='# of discrim filters in first conv layer') 28 | parser.add_argument('--netD', type=str, default='multiscale', help='(n_layers|multiscale|image)') 29 | parser.add_argument('--no_TTUR', action='store_true', help='if specified, do *not* use the TTUR training scheme') 30 | parser.add_argument('--real_reference_probability', type=float, default=0.0, help='self-supervised training probability') 31 | parser.add_argument('--hard_reference_probability', type=float, default=0.0, help='hard reference training probability') 32 | # training loss weights 33 | parser.add_argument('--weight_warp_self', type=float, default=0.0, help='push warp self to ref') 34 | parser.add_argument('--weight_warp_cycle', type=float, default=0.0, help='push warp cycle to ref') 35 | parser.add_argument('--weight_novgg_featpair', type=float, default=10.0, help='in no vgg setting, use pair feat loss in domain adaptation') 36 | parser.add_argument('--gan_mode', type=str, default='hinge', help='(ls|original|hinge)') 37 | parser.add_argument('--weight_gan', type=float, default=10.0, help='weight of all loss in stage1') 38 | parser.add_argument('--no_ganFeat_loss', action='store_true', help='if specified, do *not* use discriminator feature matching loss') 39 | parser.add_argument('--weight_ganFeat', type=float, default=10.0, help='weight for feature matching loss') 40 | parser.add_argument('--which_perceptual', type=str, default='4_2', help='relu5_2 or relu4_2') 41 | parser.add_argument('--weight_perceptual', type=float, default=0.001) 42 | parser.add_argument('--weight_vgg', type=float, default=10.0, help='weight for vgg loss') 43 | parser.add_argument('--weight_contextual', type=float, default=1.0, help='contextual loss weight') 44 | parser.add_argument('--weight_fm_ratio',
type=float, default=1.0, help='vgg fm loss weight comp with ctx loss') 45 | self.isTrain = True 46 | return parser 47 | -------------------------------------------------------------------------------- /models/networks/normalization.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 3 | 4 | import re 5 | 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | import torch.nn.utils.spectral_norm as spectral_norm 9 | 10 | 11 | def get_nonspade_norm_layer(opt, norm_type='instance'): 12 | def get_out_channel(layer): 13 | if hasattr(layer, 'out_channels'): 14 | return getattr(layer, 'out_channels') 15 | return layer.weight.size(0) 16 | def add_norm_layer(layer): 17 | nonlocal norm_type 18 | if norm_type.startswith('spectral'): 19 | layer = spectral_norm(layer) 20 | subnorm_type = norm_type[len('spectral'):] 21 | else: 22 | subnorm_type =norm_type 23 | if subnorm_type == 'none' or len(subnorm_type) == 0: 24 | return layer 25 | if getattr(layer, 'bias', None) is not None: 26 | delattr(layer, 'bias') 27 | layer.register_parameter('bias', None) 28 | if subnorm_type == 'batch': 29 | norm_layer = nn.BatchNorm2d(get_out_channel(layer), affine=True) 30 | elif subnorm_type == 'sync_batch': 31 | norm_layer = nn.BatchNorm2d(get_out_channel(layer), affine=True) 32 | elif subnorm_type == 'instance': 33 | norm_layer = nn.InstanceNorm2d(get_out_channel(layer), affine=False) 34 | else: 35 | raise ValueError('normalization layer %s is not recognized' % subnorm_type) 36 | return nn.Sequential(layer, norm_layer) 37 | return add_norm_layer 38 | 39 | 40 | def PositionalNorm2d(x, epsilon=1e-8): 41 | # x: B*C*W*H normalize in C dim 42 | mean = x.mean(dim=1, keepdim=True) 43 | std = x.var(dim=1, keepdim=True).add(epsilon).sqrt() 44 | output = (x - mean) / std 45 | return output 46 | 47 | 48 | class SPADE(nn.Module): 49 | def __init__(self, config_text, norm_nc, label_nc, PONO=False): 50 | super().__init__() 51 | assert config_text.startswith('spade') 52 | parsed = re.search('spade(\D+)(\d)x\d', config_text) 53 | param_free_norm_type = str(parsed.group(1)) 54 | ks = int(parsed.group(2)) 55 | self.pad_type = 'nozero' 56 | if PONO: 57 | self.param_free_norm = PositionalNorm2d 58 | elif param_free_norm_type == 'instance': 59 | self.param_free_norm = nn.InstanceNorm2d(norm_nc, affine=False) 60 | elif param_free_norm_type == 'syncbatch': 61 | self.param_free_norm = nn.BatchNorm2d(norm_nc, affine=True) 62 | elif param_free_norm_type == 'batch': 63 | self.param_free_norm = nn.BatchNorm2d(norm_nc, affine=True) 64 | else: 65 | raise ValueError('%s is not a recognized param-free norm type in SPADE' % param_free_norm_type) 66 | nhidden = 128 67 | pw = ks // 2 68 | if self.pad_type != 'zero': 69 | self.mlp_shared = nn.Sequential( 70 | nn.ReflectionPad2d(pw), 71 | nn.Conv2d(label_nc, nhidden, kernel_size=ks, padding=0), 72 | nn.ReLU() 73 | ) 74 | self.pad = nn.ReflectionPad2d(pw) 75 | self.mlp_gamma = nn.Conv2d(nhidden, norm_nc, kernel_size=ks, padding=0) 76 | self.mlp_beta = nn.Conv2d(nhidden, norm_nc, kernel_size=ks, padding=0) 77 | 78 | else: 79 | self.mlp_shared = nn.Sequential( 80 | nn.Conv2d(label_nc, nhidden, kernel_size=ks, padding=pw), 81 | nn.ReLU() 82 | ) 83 | self.mlp_gamma = nn.Conv2d(nhidden, norm_nc, kernel_size=ks, padding=pw) 84 | self.mlp_beta = nn.Conv2d(nhidden, norm_nc, kernel_size=ks, padding=pw) 85 | 86 | def forward(self, x, segmap): 87 | normalized = self.param_free_norm(x) 88 | segmap 
= F.interpolate(segmap, size=x.size()[2:], mode='nearest') 89 | actv = self.mlp_shared(segmap) 90 | if self.pad_type != 'zero': 91 | gamma = self.mlp_gamma(self.pad(actv)) 92 | beta = self.mlp_beta(self.pad(actv)) 93 | else: 94 | gamma = self.mlp_gamma(actv) 95 | beta = self.mlp_beta(actv) 96 | out = normalized * (1 + gamma) + beta 97 | return out 98 | -------------------------------------------------------------------------------- /data/base_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 3 | 4 | import torch.utils.data as data 5 | from PIL import Image 6 | import torchvision.transforms as transforms 7 | import numpy as np 8 | import random 9 | 10 | 11 | class BaseDataset(data.Dataset): 12 | def __init__(self): 13 | super(BaseDataset, self).__init__() 14 | 15 | @staticmethod 16 | def modify_commandline_options(parser, is_train): 17 | return parser 18 | 19 | def initialize(self, opt): 20 | pass 21 | 22 | 23 | def get_params(opt, size): 24 | w, h = size 25 | new_h = h 26 | new_w = w 27 | if opt.preprocess_mode == 'resize_and_crop': 28 | new_h = new_w = opt.load_size 29 | elif opt.preprocess_mode == 'scale_width_and_crop': 30 | new_w = opt.load_size 31 | new_h = opt.load_size * h // w 32 | elif opt.preprocess_mode == 'scale_shortside_and_crop': 33 | ss, ls = min(w, h), max(w, h) # shortside and longside 34 | width_is_shorter = w == ss 35 | ls = int(opt.load_size * ls / ss) 36 | new_w, new_h = (ss, ls) if width_is_shorter else (ls, ss) 37 | 38 | x = random.randint(0, np.maximum(0, new_w - opt.crop_size)) 39 | y = random.randint(0, np.maximum(0, new_h - opt.crop_size)) 40 | 41 | flip = random.random() > 0.5 42 | return {'crop_pos': (x, y), 'flip': flip} 43 | 44 | 45 | def get_transform(opt, params, method=Image.BICUBIC, normalize=True, toTensor=True): 46 | transform_list = [] 47 | if opt.dataset_mode == 'flickr' and method == Image.NEAREST: 48 | transform_list.append(transforms.Lambda(lambda img: __add1(img))) 49 | if 'resize' in opt.preprocess_mode: 50 | osize = [opt.load_size, opt.load_size] 51 | transform_list.append(transforms.Resize(osize, interpolation=method)) 52 | elif 'scale_width' in opt.preprocess_mode: 53 | transform_list.append(transforms.Lambda(lambda img: __scale_width(img, opt.load_size, method))) 54 | elif 'scale_shortside' in opt.preprocess_mode: 55 | transform_list.append(transforms.Lambda(lambda img: __scale_shortside(img, opt.load_size, method))) 56 | 57 | if 'crop' in opt.preprocess_mode: 58 | transform_list.append(transforms.Lambda(lambda img: __crop(img, params['crop_pos'], opt.crop_size))) 59 | 60 | if opt.preprocess_mode == 'none': 61 | base = 32 62 | transform_list.append(transforms.Lambda(lambda img: __make_power_2(img, base, method))) 63 | 64 | if opt.preprocess_mode == 'fixed': 65 | w = opt.crop_size 66 | h = round(opt.crop_size / opt.aspect_ratio) 67 | transform_list.append(transforms.Lambda(lambda img: __resize(img, w, h, method))) 68 | 69 | if opt.isTrain and not opt.no_flip: 70 | transform_list.append(transforms.Lambda(lambda img: __flip(img, params['flip']))) 71 | 72 | if opt.isTrain and 'rotate' in params.keys(): 73 | transform_list.append(transforms.Lambda(lambda img: __rotate(img, params['rotate'], method))) 74 | 75 | if toTensor: 76 | transform_list += [transforms.ToTensor()] 77 | 78 | if normalize: 79 | transform_list += [transforms.Normalize((0.5, 0.5, 0.5), 80 | (0.5, 0.5, 0.5))] 81 | return transforms.Compose(transform_list) 82 
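# A minimal sketch of how get_params()/get_transform() above are typically combined inside a
# dataset's __getitem__; the concrete dataset classes (e.g. pix2pix_dataset.py) are not shown
# here, so this is illustrative rather than a copy of their code:
#   params = get_params(opt, image.size)                                   # sample crop position / flip once
#   transform_image = get_transform(opt, params)                           # resize + crop + flip + ToTensor + Normalize
#   transform_label = get_transform(opt, params, method=Image.NEAREST, normalize=False)
#   image_tensor = transform_image(image)                                  # reusing `params` keeps image and label aligned
#   label_tensor = transform_label(label)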
| 83 | 84 | def __resize(img, w, h, method=Image.BICUBIC): 85 | return img.resize((w, h), method) 86 | 87 | 88 | def __make_power_2(img, base, method=Image.BICUBIC): 89 | ow, oh = img.size 90 | h = int(round(oh / base) * base) 91 | w = int(round(ow / base) * base) 92 | if (h == oh) and (w == ow): 93 | return img 94 | return img.resize((w, h), method) 95 | 96 | 97 | def __scale_width(img, target_width, method=Image.BICUBIC): 98 | ow, oh = img.size 99 | if (ow == target_width): 100 | return img 101 | w = target_width 102 | h = int(target_width * oh / ow) 103 | return img.resize((w, h), method) 104 | 105 | 106 | def __scale_shortside(img, target_width, method=Image.BICUBIC): 107 | ow, oh = img.size 108 | ss, ls = min(ow, oh), max(ow, oh) # shortside and longside 109 | width_is_shorter = ow == ss 110 | if (ss == target_width): 111 | return img 112 | ls = int(target_width * ls / ss) 113 | nw, nh = (ss, ls) if width_is_shorter else (ls, ss) 114 | return img.resize((nw, nh), method) 115 | 116 | 117 | def __crop(img, pos, size): 118 | ow, oh = img.size 119 | x1, y1 = pos 120 | tw = th = size 121 | return img.crop((x1, y1, x1 + tw, y1 + th)) 122 | 123 | 124 | def __flip(img, flip): 125 | if flip: 126 | return img.transpose(Image.FLIP_LEFT_RIGHT) 127 | return img 128 | 129 | 130 | def __rotate(img, deg, method=Image.BICUBIC): 131 | return img.rotate(deg, resample=method) 132 | 133 | 134 | def __add1(img): 135 | return Image.fromarray(np.array(img) + 1) -------------------------------------------------------------------------------- /models/networks/discriminator.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | from models.networks.base_network import BaseNetwork 9 | from models.networks.normalization import get_nonspade_norm_layer 10 | import util.util as util 11 | 12 | 13 | class MultiscaleDiscriminator(BaseNetwork): 14 | @staticmethod 15 | def modify_commandline_options(parser, is_train): 16 | parser.add_argument('--netD_subarch', type=str, default='n_layer', 17 | help='architecture of each discriminator') 18 | parser.add_argument('--num_D', type=int, default=2, 19 | help='number of discriminators to be used in multiscale') 20 | opt, _ = parser.parse_known_args() 21 | # define properties of each discriminator of the multiscale discriminator 22 | subnetD = util.find_class_in_module(opt.netD_subarch + 'discriminator', \ 23 | 'models.networks.discriminator') 24 | subnetD.modify_commandline_options(parser, is_train) 25 | return parser 26 | 27 | def __init__(self, opt): 28 | super().__init__() 29 | self.opt = opt 30 | for i in range(opt.num_D): 31 | subnetD = self.create_single_discriminator(opt) 32 | self.add_module('discriminator_%d' % i, subnetD) 33 | 34 | def create_single_discriminator(self, opt): 35 | subarch = opt.netD_subarch 36 | if subarch == 'n_layer': 37 | netD = NLayerDiscriminator(opt) 38 | else: 39 | raise ValueError('unrecognized discriminator subarchitecture %s' % subarch) 40 | return netD 41 | 42 | def downsample(self, input): 43 | return F.avg_pool2d(input, kernel_size=3, stride=2, padding=[1, 1], count_include_pad=False) 44 | 45 | def forward(self, input): 46 | result = [] 47 | get_intermediate_features = not self.opt.no_ganFeat_loss 48 | for name, D in self.named_children(): 49 | out = D(input) 50 | if not get_intermediate_features: 51 | out = [out] 52 | result.append(out) 53 | 
input = self.downsample(input) 54 | return result 55 | 56 | 57 | class NLayerDiscriminator(BaseNetwork): 58 | @staticmethod 59 | def modify_commandline_options(parser, is_train): 60 | parser.add_argument('--n_layers_D', type=int, default=4, help='# layers in each discriminator') 61 | return parser 62 | 63 | def __init__(self, opt): 64 | super().__init__() 65 | self.opt = opt 66 | kw = 4 67 | padw = int((kw - 1.0) / 2) 68 | nf = opt.ndf 69 | input_nc = self.compute_D_input_nc(opt) 70 | norm_layer = get_nonspade_norm_layer(opt, opt.norm_D) 71 | sequence = [[nn.Conv2d(input_nc, nf, kernel_size=kw, stride=2, padding=padw), 72 | nn.LeakyReLU(0.2, False)]] 73 | for n in range(1, opt.n_layers_D): 74 | nf_prev = nf 75 | nf = min(nf * 2, 512) 76 | stride = 1 if n == opt.n_layers_D - 1 else 2 77 | if n == opt.n_layers_D - 1: 78 | dec = [] 79 | nc_dec = nf_prev 80 | for _ in range(opt.n_layers_D - 1): 81 | dec += [nn.Upsample(scale_factor=2), 82 | norm_layer(nn.Conv2d(nc_dec, int(nc_dec//2), kernel_size=3, stride=1, padding=1)), 83 | nn.LeakyReLU(0.2, False)] 84 | nc_dec = int(nc_dec // 2) 85 | dec += [nn.Conv2d(nc_dec, opt.semantic_nc, kernel_size=3, stride=1, padding=1)] 86 | self.dec = nn.Sequential(*dec) 87 | sequence += [[norm_layer(nn.Conv2d(nf_prev, nf, kernel_size=kw, stride=stride, padding=padw)), 88 | nn.LeakyReLU(0.2, False)]] 89 | sequence += [[nn.Conv2d(nf, 1, kernel_size=kw, stride=1, padding=padw)]] 90 | for n in range(len(sequence)): 91 | self.add_module('model' + str(n), nn.Sequential(*sequence[n])) 92 | 93 | def compute_D_input_nc(self, opt): 94 | input_nc = opt.label_nc + opt.output_nc 95 | if opt.contain_dontcare_label: 96 | input_nc += 1 97 | return input_nc 98 | 99 | def forward(self, input): 100 | results = [input] 101 | seg = None 102 | cam_logit = None 103 | for name, submodel in self.named_children(): 104 | if 'model' not in name: 105 | continue 106 | x = results[-1] 107 | intermediate_output = submodel(x) 108 | results.append(intermediate_output) 109 | get_intermediate_features = not self.opt.no_ganFeat_loss 110 | if get_intermediate_features: 111 | retu = results[1:] 112 | else: 113 | retu = results[-1] 114 | return retu 115 | -------------------------------------------------------------------------------- /trainers/pix2pix_trainer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 3 | 4 | import os 5 | import copy 6 | import sys 7 | import torch 8 | from models.pix2pix_model import Pix2PixModel 9 | try: 10 | from torch.cuda.amp import GradScaler 11 | except: 12 | # dummy GradScaler for PyTorch < 1.6 13 | class GradScaler: 14 | def __init__(self, enabled): 15 | pass 16 | def scale(self, loss): 17 | return loss 18 | def unscale_(self, optimizer): 19 | pass 20 | def step(self, optimizer): 21 | optimizer.step() 22 | def update(self): 23 | pass 24 | 25 | 26 | class Pix2PixTrainer(): 27 | """ 28 | Trainer creates the model and optimizers, and uses them to 29 | updates the weights of the network while reporting losses 30 | and the latest visuals to visualize the progress in training. 
31 | """ 32 | 33 | def __init__(self, opt, resume_epoch=0): 34 | self.opt = opt 35 | self.pix2pix_model = Pix2PixModel(opt) 36 | if len(opt.gpu_ids) > 1: 37 | self.pix2pix_model = torch.nn.DataParallel(self.pix2pix_model, device_ids=opt.gpu_ids) 38 | self.pix2pix_model_on_one_gpu = self.pix2pix_model.module 39 | else: 40 | self.pix2pix_model.to(opt.gpu_ids[0]) 41 | self.pix2pix_model_on_one_gpu = self.pix2pix_model 42 | self.generated = None 43 | if opt.isTrain: 44 | self.optimizer_G, self.optimizer_D = self.pix2pix_model_on_one_gpu.create_optimizers(opt) 45 | self.old_lr = opt.lr 46 | if opt.continue_train and opt.which_epoch == 'latest': 47 | try: 48 | load_path = os.path.join(opt.checkpoints_dir, opt.name, 'optimizer.pth') 49 | checkpoint = torch.load(load_path) 50 | self.optimizer_G.load_state_dict(checkpoint['G']) 51 | self.optimizer_D.load_state_dict(checkpoint['D']) 52 | except FileNotFoundError as err: 53 | print(err) 54 | print('Not find optimizer state dict: ' + load_path + '. Do not load optimizer!') 55 | 56 | self.last_data, self.last_netCorr, self.last_netG, self.last_optimizer_G = None, None, None, None 57 | self.g_losses = {} 58 | self.d_losses = {} 59 | self.scaler = GradScaler(enabled=self.opt.amp) 60 | 61 | def run_generator_one_step(self, data): 62 | self.optimizer_G.zero_grad() 63 | g_losses, out = self.pix2pix_model(data, mode='generator') 64 | g_loss = sum(g_losses.values()).mean() 65 | # g_loss.backward() 66 | self.scaler.scale(g_loss).backward() 67 | self.scaler.unscale_(self.optimizer_G) 68 | # self.optimizer_G.step() 69 | self.scaler.step(self.optimizer_G) 70 | self.scaler.update() 71 | self.g_losses = g_losses 72 | self.out = out 73 | 74 | def run_discriminator_one_step(self, data): 75 | self.optimizer_D.zero_grad() 76 | GforD = {} 77 | GforD['fake_image'] = self.out['fake_image'] 78 | GforD['adaptive_feature_seg'] = self.out['adaptive_feature_seg'] 79 | GforD['adaptive_feature_img'] = self.out['adaptive_feature_img'] 80 | d_losses = self.pix2pix_model(data, mode='discriminator', GforD=GforD) 81 | d_loss = sum(d_losses.values()).mean() 82 | # d_loss.backward() 83 | self.scaler.scale(d_loss).backward() 84 | self.scaler.unscale_(self.optimizer_D) 85 | # self.optimizer_D.step() 86 | self.scaler.step(self.optimizer_D) 87 | self.scaler.update() 88 | self.d_losses = d_losses 89 | 90 | def get_latest_losses(self): 91 | return {**self.g_losses, **self.d_losses} 92 | 93 | def get_latest_generated(self): 94 | return self.out['fake_image'] 95 | 96 | def update_learning_rate(self, epoch): 97 | self.update_learning_rate(epoch) 98 | 99 | def save(self, epoch): 100 | self.pix2pix_model_on_one_gpu.save(epoch) 101 | if epoch == 'latest': 102 | torch.save({'G': self.optimizer_G.state_dict(), \ 103 | 'D': self.optimizer_D.state_dict(), \ 104 | 'lr': self.old_lr,}, \ 105 | os.path.join(self.opt.checkpoints_dir, self.opt.name, 'optimizer.pth')) 106 | 107 | def update_learning_rate(self, epoch): 108 | if epoch > self.opt.niter: 109 | lrd = self.opt.lr / self.opt.niter_decay 110 | new_lr = self.old_lr - lrd 111 | else: 112 | new_lr = self.old_lr 113 | if new_lr != self.old_lr: 114 | new_lr_G = new_lr 115 | new_lr_D = new_lr 116 | else: 117 | new_lr_G = self.old_lr 118 | new_lr_D = self.old_lr 119 | for param_group in self.optimizer_D.param_groups: 120 | param_group['lr'] = new_lr_D 121 | for param_group in self.optimizer_G.param_groups: 122 | param_group['lr'] = new_lr_G 123 | print('update learning rate: %f -> %f' % (self.old_lr, new_lr)) 124 | self.old_lr = new_lr 125 | 
-------------------------------------------------------------------------------- /data/pix2pix_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 3 | 4 | 5 | import os 6 | import random 7 | from PIL import Image 8 | 9 | from data.base_dataset import BaseDataset, get_params, get_transform 10 | 11 | 12 | class Pix2pixDataset(BaseDataset): 13 | @staticmethod 14 | def modify_commandline_options(parser, is_train): 15 | parser.add_argument('--no_pairing_check', action='store_true', help='If specified, skip sanity check of correct label-image file pairing') 16 | return parser 17 | 18 | def initialize(self, opt): 19 | self.opt = opt 20 | label_paths, image_paths = self.get_paths(opt) 21 | label_paths = label_paths[:opt.max_dataset_size] 22 | image_paths = image_paths[:opt.max_dataset_size] 23 | if not opt.no_pairing_check: 24 | for path1, path2 in zip(label_paths, image_paths): 25 | assert self.paths_match(path1, path2), \ 26 | "The label-image pair (%s, %s) do not look like the right pair because the filenames are quite different. Are you sure about the pairing? Please see data/pix2pix_dataset.py to see what is going on, and use --no_pairing_check to bypass this." % (path1, path2) 27 | self.label_paths = label_paths 28 | self.image_paths = image_paths 29 | size = len(self.label_paths) 30 | self.dataset_size = size 31 | self.real_reference_probability = 1 if opt.phase == 'test' else opt.real_reference_probability 32 | self.hard_reference_probability = 0 if opt.phase == 'test' else opt.hard_reference_probability 33 | self.ref_dict, self.train_test_folder = self.get_ref(opt) 34 | 35 | def get_paths(self, opt): 36 | label_paths = [] 37 | image_paths = [] 38 | assert False, "A subclass of Pix2pixDataset must override self.get_paths(self, opt)" 39 | return label_paths, image_paths 40 | 41 | def paths_match(self, path1, path2): 42 | filename1_without_ext = os.path.splitext(os.path.basename(path1))[0] 43 | filename2_without_ext = os.path.splitext(os.path.basename(path2))[0] 44 | return filename1_without_ext == filename2_without_ext 45 | 46 | def get_label_tensor(self, path): 47 | label = Image.open(path) 48 | params1 = get_params(self.opt, label.size) 49 | transform_label = get_transform(self.opt, params1, method=Image.NEAREST, normalize=False) 50 | label_tensor = transform_label(label) * 255.0 51 | label_tensor[label_tensor == 255] = self.opt.label_nc 52 | # 'unknown' is opt.label_nc 53 | return label_tensor, params1 54 | 55 | def __getitem__(self, index): 56 | # label Image 57 | label_path = self.label_paths[index] 58 | label_path = os.path.join(self.opt.dataroot, label_path) 59 | label_tensor, params1 = self.get_label_tensor(label_path) 60 | # input image (real images) 61 | image_path = self.image_paths[index] 62 | image_path = os.path.join(self.opt.dataroot, image_path) 63 | image = Image.open(image_path).convert('RGB') 64 | transform_image = get_transform(self.opt, params1) 65 | image_tensor = transform_image(image) 66 | ref_tensor = 0 67 | label_ref_tensor = 0 68 | random_p = random.random() 69 | if random_p < self.real_reference_probability or self.opt.phase == 'test': 70 | key = image_path.split('deepfashionHD/')[-1] 71 | val = self.ref_dict[key] 72 | if random_p < self.hard_reference_probability: 73 | #hard reference 74 | path_ref = val[1] 75 | else: 76 | #easy reference 77 | path_ref = val[0] 78 | if self.opt.dataset_mode == 'deepfashionHD': 79 | path_ref = 
os.path.join(self.opt.dataroot, path_ref) 80 | else: 81 | path_ref = os.path.dirname(image_path).replace(self.train_test_folder[1], self.train_test_folder[0]) + '/' + path_ref 82 | image_ref = Image.open(path_ref).convert('RGB') 83 | if self.opt.dataset_mode != 'deepfashionHD': 84 | path_ref_label = path_ref.replace('.jpg', '.png') 85 | path_ref_label = self.imgpath_to_labelpath(path_ref_label) 86 | else: 87 | path_ref_label = self.imgpath_to_labelpath(path_ref) 88 | label_ref_tensor, params = self.get_label_tensor(path_ref_label) 89 | transform_image = get_transform(self.opt, params) 90 | ref_tensor = transform_image(image_ref) 91 | self_ref_flag = 0.0 92 | else: 93 | pair = False 94 | if self.opt.dataset_mode == 'deepfashionHD' and self.opt.video_like: 95 | key = image_path.replace('\\', '/').split('deepfashionHD/')[-1] 96 | val = self.ref_dict[key] 97 | ref_name = val[0] 98 | key_name = key 99 | path_ref = os.path.join(self.opt.dataroot, ref_name) 100 | image_ref = Image.open(path_ref).convert('RGB') 101 | label_ref_path = self.imgpath_to_labelpath(path_ref) 102 | label_ref_tensor, params = self.get_label_tensor(label_ref_path) 103 | transform_image = get_transform(self.opt, params) 104 | ref_tensor = transform_image(image_ref) 105 | pair = True 106 | if not pair: 107 | label_ref_tensor, params = self.get_label_tensor(label_path) 108 | transform_image = get_transform(self.opt, params) 109 | ref_tensor = transform_image(image) 110 | self_ref_flag = 1.0 111 | input_dict = {'label': label_tensor, 112 | 'image': image_tensor, 113 | 'path': image_path, 114 | 'self_ref': self_ref_flag, 115 | 'ref': ref_tensor, 116 | 'label_ref': label_ref_tensor 117 | } 118 | return input_dict 119 | 120 | def __len__(self): 121 | return self.dataset_size 122 | 123 | def get_ref(self, opt): 124 | pass 125 | 126 | def imgpath_to_labelpath(self, path): 127 | return path 128 | -------------------------------------------------------------------------------- /data/deepfashionHD_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 
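Each sample produced by `Pix2pixDataset.__getitem__` above (and therefore by the `DeepFashionHDDataset` subclass defined next) is a plain dict of tensors. A minimal sketch of inspecting one batch — the DataLoader settings and the `opt` object are illustrative assumptions, not the repo's actual data pipeline:
```python
# Illustrative only: peek at one batch from the dataset.
from torch.utils.data import DataLoader

dataset = DeepFashionHDDataset()
dataset.initialize(opt)             # 'opt' is assumed to be an already-parsed options object
loader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=4)

batch = next(iter(loader))
print(batch['label'].shape)         # pose label map, (B, label_nc, crop_size, crop_size)
print(batch['image'].shape)         # ground-truth image, (B, 3, crop_size, crop_size), scaled to [-1, 1]
print(batch['ref'].shape)           # exemplar image the correspondence module warps from
print(batch['label_ref'].shape)     # pose label map of the exemplar
print(batch['self_ref'])            # 1.0 where the exemplar is the image itself, else 0.0
print(batch['path'][:2])            # source image paths
```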
3 | 4 | import os 5 | import cv2 6 | import torch 7 | import numpy as np 8 | import math 9 | import random 10 | from PIL import Image 11 | 12 | from data.pix2pix_dataset import Pix2pixDataset 13 | from data.base_dataset import get_params, get_transform 14 | 15 | 16 | class DeepFashionHDDataset(Pix2pixDataset): 17 | @staticmethod 18 | def modify_commandline_options(parser, is_train): 19 | parser = Pix2pixDataset.modify_commandline_options(parser, is_train) 20 | parser.set_defaults(preprocess_mode='resize_and_crop') 21 | parser.set_defaults(no_pairing_check=True) 22 | parser.set_defaults(load_size=550) 23 | parser.set_defaults(crop_size=512) 24 | parser.set_defaults(label_nc=20) 25 | parser.set_defaults(contain_dontcare_label=False) 26 | parser.set_defaults(cache_filelist_read=False) 27 | parser.set_defaults(cache_filelist_write=False) 28 | return parser 29 | 30 | def get_paths(self, opt): 31 | root = opt.dataroot 32 | if opt.phase == 'train': 33 | fd = open(os.path.join('./data/train.txt')) 34 | lines = fd.readlines() 35 | fd.close() 36 | elif opt.phase == 'test': 37 | fd = open(os.path.join('./data/val.txt')) 38 | lines = fd.readlines() 39 | fd.close() 40 | image_paths = [] 41 | label_paths = [] 42 | for i in range(len(lines)): 43 | name = lines[i].strip() 44 | image_paths.append(name) 45 | label_path = name.replace('img', 'pose').replace('.jpg', '_{}.txt') 46 | label_paths.append(os.path.join(label_path)) 47 | return label_paths, image_paths 48 | 49 | def get_ref_video_like(self, opt): 50 | pair_path = './data/deepfashion_self_pair.txt' 51 | with open(pair_path) as fd: 52 | self_pair = fd.readlines() 53 | self_pair = [it.strip() for it in self_pair] 54 | self_pair_dict = {} 55 | for it in self_pair: 56 | items = it.split(',') 57 | self_pair_dict[items[0]] = items[1:] 58 | ref_path = './data/deepfashion_ref_test.txt' if opt.phase == 'test' else './data/deepfashion_ref.txt' 59 | with open(ref_path) as fd: 60 | ref = fd.readlines() 61 | ref = [it.strip() for it in ref] 62 | ref_dict = {} 63 | for i in range(len(ref)): 64 | items = ref[i].strip().split(',') 65 | key = items[0] 66 | if key in self_pair_dict.keys(): 67 | val = [it for it in self_pair_dict[items[0]]] 68 | else: 69 | val = [items[-1]] 70 | ref_dict[key.replace('\\',"/")] = [v.replace('\\',"/") for v in val] 71 | train_test_folder = ('', '') 72 | return ref_dict, train_test_folder 73 | 74 | def get_ref_vgg(self, opt): 75 | extra = '' 76 | if opt.phase == 'test': 77 | extra = '_test' 78 | with open('./data/deepfashion_ref{}.txt'.format(extra)) as fd: 79 | lines = fd.readlines() 80 | ref_dict = {} 81 | for i in range(len(lines)): 82 | items = lines[i].strip().split(',') 83 | key = items[0] 84 | if opt.phase == 'test': 85 | val = [it for it in items[1:]] 86 | else: 87 | val = [items[-1]] 88 | ref_dict[key.replace('\\',"/")] = [v.replace('\\',"/") for v in val] 89 | train_test_folder = ('', '') 90 | return ref_dict, train_test_folder 91 | 92 | def get_ref(self, opt): 93 | if opt.video_like: 94 | return self.get_ref_video_like(opt) 95 | else: 96 | return self.get_ref_vgg(opt) 97 | 98 | def get_label_tensor(self, path): 99 | candidate = np.loadtxt(path.format('candidate')) 100 | subset = np.loadtxt(path.format('subset')) 101 | stickwidth = 20 102 | limbSeq = [[2, 3], [2, 6], [3, 4], [4, 5], [6, 7], [7, 8], [2, 9], [9, 10], \ 103 | [10, 11], [2, 12], [12, 13], [13, 14], [2, 1], [1, 15], [15, 17], \ 104 | [1, 16], [16, 18], [3, 17], [6, 18]] 105 | colors = [[255, 0, 0], [255, 85, 0], [255, 170, 0], [255, 255, 0], [170, 255, 0], [85, 255, 
0], [0, 255, 0], \ 106 | [0, 255, 85], [0, 255, 170], [0, 255, 255], [0, 170, 255], [0, 85, 255], [0, 0, 255], [85, 0, 255], \ 107 | [170, 0, 255], [255, 0, 255], [255, 0, 170], [255, 0, 85]] 108 | canvas = np.zeros((1024, 1024, 3), dtype=np.uint8) 109 | cycle_radius = 20 110 | for i in range(18): 111 | index = int(subset[i]) 112 | if index == -1: 113 | continue 114 | x, y = candidate[index][0:2] 115 | cv2.circle(canvas, (int(x), int(y)), cycle_radius, colors[i], thickness=-1) 116 | joints = [] 117 | for i in range(17): 118 | index = subset[np.array(limbSeq[i]) - 1] 119 | cur_canvas = canvas.copy() 120 | if -1 in index: 121 | joints.append(np.zeros_like(cur_canvas[:, :, 0])) 122 | continue 123 | Y = candidate[index.astype(int), 0] 124 | X = candidate[index.astype(int), 1] 125 | mX = np.mean(X) 126 | mY = np.mean(Y) 127 | length = ((X[0] - X[1]) ** 2 + (Y[0] - Y[1]) ** 2) ** 0.5 128 | angle = math.degrees(math.atan2(X[0] - X[1], Y[0] - Y[1])) 129 | polygon = cv2.ellipse2Poly((int(mY), int(mX)), (int(length / 2), stickwidth), int(angle), 0, 360, 1) 130 | cv2.fillConvexPoly(cur_canvas, polygon, colors[i]) 131 | canvas = cv2.addWeighted(canvas, 0.4, cur_canvas, 0.6, 0) 132 | joint = np.zeros_like(cur_canvas[:, :, 0]) 133 | cv2.fillConvexPoly(joint, polygon, 255) 134 | joint = cv2.addWeighted(joint, 0.4, joint, 0.6, 0) 135 | joints.append(joint) 136 | pose = Image.fromarray(cv2.cvtColor(canvas, cv2.COLOR_BGR2RGB)).resize((self.opt.load_size, self.opt.load_size), resample=Image.NEAREST) 137 | params = get_params(self.opt, pose.size) 138 | transform_label = get_transform(self.opt, params, method=Image.NEAREST, normalize=False) 139 | transform_img = get_transform(self.opt, params, method=Image.BILINEAR, normalize=False) 140 | tensors_dist = 0 141 | e = 1 142 | for i in range(len(joints)): 143 | im_dist = cv2.distanceTransform(255-joints[i], cv2.DIST_L1, 3) 144 | im_dist = np.clip((im_dist/3), 0, 255).astype(np.uint8) 145 | tensor_dist = transform_img(Image.fromarray(im_dist)) 146 | tensors_dist = tensor_dist if e == 1 else torch.cat([tensors_dist, tensor_dist]) 147 | e += 1 148 | tensor_pose = transform_label(pose) 149 | label_tensor = torch.cat((tensor_pose, tensors_dist), dim=0) 150 | return label_tensor, params 151 | 152 | def imgpath_to_labelpath(self, path): 153 | label_path = path.replace('/img/', '/pose/').replace('.jpg', '_{}.txt') 154 | return label_path 155 | 156 | def labelpath_to_imgpath(self, path): 157 | img_path = path.replace('/pose/', '/img/').replace('_{}.txt', '.jpg') 158 | return img_path 159 | -------------------------------------------------------------------------------- /models/networks/architecture.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 
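The `get_label_tensor` method above turns the OpenPose keypoint files into the pose label map: a 3-channel colour rendering of the skeleton stacked with one distance-transform channel per limb, i.e. 3 + 17 = 20 channels, matching the `label_nc=20` default. A quick, illustrative shape check (the path is a hypothetical placeholder and `dataset` an already-initialized `DeepFashionHDDataset`):
```python
# Illustrative only: shape of the label tensor built by get_label_tensor above.
label_path = 'dataset/deepfashionHD/pose/MEN/example_{}.txt'   # hypothetical placeholder
label_tensor, params = dataset.get_label_tensor(label_path)
print(label_tensor.shape)   # torch.Size([20, 512, 512]) under the default load_size=550 / crop_size=512
```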
3 | 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | import torch.nn.utils.spectral_norm as spectral_norm 7 | from models.networks.normalization import SPADE 8 | from util.util import vgg_preprocess 9 | 10 | 11 | class ResidualBlock(nn.Module): 12 | def __init__(self, in_channels, out_channels, kernel_size=3, padding=1, stride=1): 13 | super(ResidualBlock, self).__init__() 14 | self.relu = nn.PReLU() 15 | self.model = nn.Sequential( 16 | nn.ReflectionPad2d(padding), 17 | nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, padding=0, stride=stride), 18 | nn.InstanceNorm2d(out_channels), 19 | self.relu, 20 | nn.ReflectionPad2d(padding), 21 | nn.Conv2d(out_channels, out_channels, kernel_size=kernel_size, padding=0, stride=stride), 22 | nn.InstanceNorm2d(out_channels), 23 | ) 24 | 25 | def forward(self, x): 26 | out = self.relu(x + self.model(x)) 27 | return out 28 | 29 | 30 | class SPADEResnetBlock(nn.Module): 31 | def __init__(self, fin, fout, opt, use_se=False, dilation=1): 32 | super().__init__() 33 | # Attributes 34 | self.learned_shortcut = (fin != fout) 35 | fmiddle = min(fin, fout) 36 | self.opt = opt 37 | self.pad_type = 'nozero' 38 | self.use_se = use_se 39 | # create conv layers 40 | if self.pad_type != 'zero': 41 | self.pad = nn.ReflectionPad2d(dilation) 42 | self.conv_0 = nn.Conv2d(fin, fmiddle, kernel_size=3, padding=0, dilation=dilation) 43 | self.conv_1 = nn.Conv2d(fmiddle, fout, kernel_size=3, padding=0, dilation=dilation) 44 | else: 45 | self.conv_0 = nn.Conv2d(fin, fmiddle, kernel_size=3, padding=dilation, dilation=dilation) 46 | self.conv_1 = nn.Conv2d(fmiddle, fout, kernel_size=3, padding=dilation, dilation=dilation) 47 | if self.learned_shortcut: 48 | self.conv_s = nn.Conv2d(fin, fout, kernel_size=1, bias=False) 49 | # apply spectral norm if specified 50 | if 'spectral' in opt.norm_G: 51 | self.conv_0 = spectral_norm(self.conv_0) 52 | self.conv_1 = spectral_norm(self.conv_1) 53 | if self.learned_shortcut: 54 | self.conv_s = spectral_norm(self.conv_s) 55 | # define normalization layers 56 | spade_config_str = opt.norm_G.replace('spectral', '') 57 | if 'spade_ic' in opt: 58 | ic = opt.spade_ic 59 | else: 60 | ic = 4*3+opt.label_nc 61 | self.norm_0 = SPADE(spade_config_str, fin, ic, PONO=opt.PONO) 62 | self.norm_1 = SPADE(spade_config_str, fmiddle, ic, PONO=opt.PONO) 63 | if self.learned_shortcut: 64 | self.norm_s = SPADE(spade_config_str, fin, ic, PONO=opt.PONO) 65 | 66 | def forward(self, x, seg1): 67 | x_s = self.shortcut(x, seg1) 68 | if self.pad_type != 'zero': 69 | dx = self.conv_0(self.pad(self.actvn(self.norm_0(x, seg1)))) 70 | dx = self.conv_1(self.pad(self.actvn(self.norm_1(dx, seg1)))) 71 | else: 72 | dx = self.conv_0(self.actvn(self.norm_0(x, seg1))) 73 | dx = self.conv_1(self.actvn(self.norm_1(dx, seg1))) 74 | out = x_s + dx 75 | return out 76 | 77 | def shortcut(self, x, seg1): 78 | if self.learned_shortcut: 79 | x_s = self.conv_s(self.norm_s(x, seg1)) 80 | else: 81 | x_s = x 82 | return x_s 83 | 84 | def actvn(self, x): 85 | return F.leaky_relu(x, 2e-1) 86 | 87 | 88 | class VGG19_feature_color_torchversion(nn.Module): 89 | """ 90 | NOTE: there is no need to pre-process the input 91 | input tensor should range in [0,1] 92 | """ 93 | def __init__(self, pool='max', vgg_normal_correct=False, ic=3): 94 | super(VGG19_feature_color_torchversion, self).__init__() 95 | self.vgg_normal_correct = vgg_normal_correct 96 | 97 | self.conv1_1 = nn.Conv2d(ic, 64, kernel_size=3, padding=1) 98 | self.conv1_2 = nn.Conv2d(64, 64, kernel_size=3, 
padding=1) 99 | self.conv2_1 = nn.Conv2d(64, 128, kernel_size=3, padding=1) 100 | self.conv2_2 = nn.Conv2d(128, 128, kernel_size=3, padding=1) 101 | self.conv3_1 = nn.Conv2d(128, 256, kernel_size=3, padding=1) 102 | self.conv3_2 = nn.Conv2d(256, 256, kernel_size=3, padding=1) 103 | self.conv3_3 = nn.Conv2d(256, 256, kernel_size=3, padding=1) 104 | self.conv3_4 = nn.Conv2d(256, 256, kernel_size=3, padding=1) 105 | self.conv4_1 = nn.Conv2d(256, 512, kernel_size=3, padding=1) 106 | self.conv4_2 = nn.Conv2d(512, 512, kernel_size=3, padding=1) 107 | self.conv4_3 = nn.Conv2d(512, 512, kernel_size=3, padding=1) 108 | self.conv4_4 = nn.Conv2d(512, 512, kernel_size=3, padding=1) 109 | self.conv5_1 = nn.Conv2d(512, 512, kernel_size=3, padding=1) 110 | self.conv5_2 = nn.Conv2d(512, 512, kernel_size=3, padding=1) 111 | self.conv5_3 = nn.Conv2d(512, 512, kernel_size=3, padding=1) 112 | self.conv5_4 = nn.Conv2d(512, 512, kernel_size=3, padding=1) 113 | if pool == 'max': 114 | self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2) 115 | self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2) 116 | self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2) 117 | self.pool4 = nn.MaxPool2d(kernel_size=2, stride=2) 118 | self.pool5 = nn.MaxPool2d(kernel_size=2, stride=2) 119 | elif pool == 'avg': 120 | self.pool1 = nn.AvgPool2d(kernel_size=2, stride=2) 121 | self.pool2 = nn.AvgPool2d(kernel_size=2, stride=2) 122 | self.pool3 = nn.AvgPool2d(kernel_size=2, stride=2) 123 | self.pool4 = nn.AvgPool2d(kernel_size=2, stride=2) 124 | self.pool5 = nn.AvgPool2d(kernel_size=2, stride=2) 125 | 126 | def forward(self, x, out_keys, preprocess=True): 127 | ''' 128 | NOTE: input tensor should range in [0,1] 129 | ''' 130 | out = {} 131 | if preprocess: 132 | x = vgg_preprocess(x, vgg_normal_correct=self.vgg_normal_correct) 133 | out['r11'] = F.relu(self.conv1_1(x)) 134 | out['r12'] = F.relu(self.conv1_2(out['r11'])) 135 | out['p1'] = self.pool1(out['r12']) 136 | out['r21'] = F.relu(self.conv2_1(out['p1'])) 137 | out['r22'] = F.relu(self.conv2_2(out['r21'])) 138 | out['p2'] = self.pool2(out['r22']) 139 | out['r31'] = F.relu(self.conv3_1(out['p2'])) 140 | out['r32'] = F.relu(self.conv3_2(out['r31'])) 141 | out['r33'] = F.relu(self.conv3_3(out['r32'])) 142 | out['r34'] = F.relu(self.conv3_4(out['r33'])) 143 | out['p3'] = self.pool3(out['r34']) 144 | out['r41'] = F.relu(self.conv4_1(out['p3'])) 145 | out['r42'] = F.relu(self.conv4_2(out['r41'])) 146 | out['r43'] = F.relu(self.conv4_3(out['r42'])) 147 | out['r44'] = F.relu(self.conv4_4(out['r43'])) 148 | out['p4'] = self.pool4(out['r44']) 149 | out['r51'] = F.relu(self.conv5_1(out['p4'])) 150 | out['r52'] = F.relu(self.conv5_2(out['r51'])) 151 | out['r53'] = F.relu(self.conv5_3(out['r52'])) 152 | out['r54'] = F.relu(self.conv5_4(out['r53'])) 153 | out['p5'] = self.pool5(out['r54']) 154 | return [out[key] for key in out_keys] 155 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # CoCosNet v2: Full-Resolution Correspondence Learning for Image Translation (CVPR 2021, oral presentation)
2 | ![teaser](imgs/teaser.png) 3 | 4 | **CoCosNet v2: Full-Resolution Correspondence Learning for Image Translation**
5 | **CVPR 2021, oral presentation**
6 | [Xingran Zhou](http://xingranzh.github.io/), [Bo Zhang](https://bo-zhang.me/), [Ting Zhang](https://www.microsoft.com/en-us/research/people/tinzhan/), [Pan Zhang](https://panzhang0212.github.io/), [Jianmin Bao](https://jianminbao.github.io/), [Dong Chen](https://www.microsoft.com/en-us/research/people/doch/), [Zhongfei Zhang](https://www.cs.binghamton.edu/~zhongfei/), [Fang Wen](https://www.microsoft.com/en-us/research/people/fangwen/)
7 | ### [Paper](https://arxiv.org/pdf/2012.02047.pdf) | [Slides](https://github.com/xingranzh/xingranzh.github.io/blob/master/slides/cocosnet_v2_slides.pdf)
8 | ## Abstract 9 | > We present the full-resolution correspondence learning for cross-domain images, which aids image translation. We adopt a hierarchical strategy that uses the correspondence from coarse level to guide the fine levels. At each hierarchy, the correspondence can be efficiently computed via PatchMatch that iteratively leverages the matchings from the neighborhood. Within each PatchMatch iteration, the ConvGRU module is employed to refine the current correspondence considering not only the matchings of larger context but also the historic estimates. The proposed CoCosNet v2, a GRU-assisted PatchMatch approach, is fully differentiable and highly efficient. When jointly trained with image translation, full-resolution semantic correspondence can be established in an unsupervised manner, which in turn facilitates the exemplar-based image translation. Experiments on diverse translation tasks show that CoCosNet v2 performs considerably better than state-of-the-art literature on producing high-resolution images. 10 | 11 | ## :sparkles: News 12 | 2022.12 We propose [Paint by Example](https://github.com/Fantasy-Studio/Paint-by-Example), which allows in-the-wild image editing according to an exemplar based on **stable diffusion**. You can try our [online demo](https://huggingface.co/spaces/Fantasy-Studio/Paint-by-Example). 13 | 14 | 2022.8 We recently proposed [PITI](https://github.com/PITI-Synthesis/PITI), a SOTA image-to-image translation method based on a *pretrained diffusion model*. 15 | 16 | ## Installation 17 | First, install the dependencies for the experiments: 18 | ```bash 19 | pip install -r requirements.txt 20 | ``` 21 | We recommend a PyTorch version later than `1.6.0`, since we use [automatic mixed precision](https://pytorch.org/docs/stable/amp.html) for acceleration (we used `PyTorch 1.7.0` in our experiments).
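A quick sanity check (not part of the repo) that the installed version supports the `--amp` flag used by the commands below:
```python
# The train/test commands below pass --amp, which relies on torch.cuda.amp (PyTorch >= 1.6).
import torch
from torch.cuda.amp import GradScaler, autocast   # ImportError here means the PyTorch version is too old

print(torch.__version__)            # e.g. 1.7.0
print(torch.cuda.is_available())    # mixed-precision speedups require a CUDA device
```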
22 | ## Prepare the dataset 23 | First, download the DeepFashion dataset (high-resolution version) from [this link](https://drive.google.com/file/d/1bByKH1ciLXY70Bp8le_AVnjk-Hd4pe_i/view?usp=sharing). Note that the file name is `img_highres.zip`. Unzip the file and rename the extracted folder to `img`.
24 | If a password is required, please refer to [this link](http://mmlab.ie.cuhk.edu.hk/projects/DeepFashion.html) to request access to the dataset.
25 | We use [OpenPose](https://github.com/Hzzone/pytorch-openpose) to estimate the poses for DeepFashionHD. We provide the keypoint detection results used in our experiments at [this link](https://drive.google.com/file/d/1wxrqyb67Xu_IPyZzftLgBPHDTKGQP7Pk/view?usp=sharing). Download and unzip the results file.
26 | Since the original resolution of DeepFashionHD is 750x1101, we use a Python script to preprocess the images to a resolution of 512x512. You can find the script at [`data/preprocess.py`](https://github.com/microsoft/CoCosNet-v2/blob/main/data/preprocess.py). Note that you need to download our train-val split lists `train.txt` and `val.txt` from [this link](https://drive.google.com/drive/folders/15NBujOTLnO_cRoAufWPqtOWKIinCKi0z?usp=sharing) for this step.
27 | Download the train-val lists from [this link](https://drive.google.com/drive/folders/15NBujOTLnO_cRoAufWPqtOWKIinCKi0z?usp=sharing), and the retrieval pair lists from [this link](https://drive.google.com/drive/folders/1dJU8iq8kFiwq33nWtvj5Ql5rUh9fiXUi?usp=sharing). Note that `train.txt` and `val.txt` are our train-val lists, while `deepfashion_ref.txt`, `deepfashion_ref_test.txt` and `deepfashion_self_pair.txt` are the pairing lists used in our experiments. Download them all and move them into the folder `data/`.
28 | Finally, create the root folder `deepfashionHD` and move the folders `img` and `pose` into it. The directory structure should now look like this:
29 | ``` 30 | deepfashionHD 31 | │ 32 | └─── img 33 | │ │ 34 | │ └─── MEN 35 | │ │ │ ... 36 | │ │ 37 | │ └─── WOMEN 38 | │ │ ... 39 | │ 40 | └─── pose 41 | │ │ 42 | │ └─── MEN 43 | │ │ │ ... 44 | │ │ 45 | │ └─── WOMEN 46 | │ │ ... 47 | 48 | ``` 49 | ## Inference Using Pretrained Model 50 | Download the pretrained model from [this link](https://drive.google.com/file/d/1ehkrKlf5s1gfpDNXO6AC9SIZMtqs5L3N/view?usp=sharing).
51 | Move the model files into the folder `checkpoints/deepfashionHD`. Then run the following command. 52 | ````bash 53 | python test.py --name deepfashionHD --dataset_mode deepfashionHD --dataroot dataset/deepfashionHD --PONO --PONO_C --no_flip --batchSize 8 --gpu_ids 0 --netCorr NoVGGHPM --nThreads 16 --nef 32 --amp --display_winsize 512 --iteration_count 5 --load_size 512 --crop_size 512 54 | ````
56 | ## Training from scratch 57 | Make sure you have prepared the DeepfashionHD dataset as the instruction.
58 | Download the **pretrained VGG model** from [this link](https://drive.google.com/file/d/1D-z73DOt63BrPTgIxffN6Q4_L9qma9y8/view?usp=sharing) and move it into the `vgg/` folder. We use this model to compute the training loss.
59 | 60 | Run the following command for training from scratch. 61 | ````bash 62 | python train.py --name deepfashionHD --dataset_mode deepfashionHD --dataroot dataset/deepfashionHD --niter 100 --niter_decay 0 --real_reference_probability 0.0 --hard_reference_probability 0.0 --which_perceptual 4_2 --weight_perceptual 0.001 --PONO --PONO_C --vgg_normal_correct --weight_fm_ratio 1.0 --no_flip --video_like --batchSize 16 --gpu_ids 0,1,2,3,4,5,6,7 --netCorr NoVGGHPM --match_kernel 1 --featEnc_kernel 3 --display_freq 500 --print_freq 50 --save_latest_freq 2500 --save_epoch_freq 5 --nThreads 16 --weight_warp_self 500.0 --lr 0.0001 --nef 32 --amp --weight_warp_cycle 1.0 --display_winsize 512 --iteration_count 5 --temperature 0.01 --continue_train --load_size 550 --crop_size 512 --which_epoch 15 63 | ```` 64 | Note that `--dataroot` parameter is your DeepFashionHD dataset root, e.g. `dataset/DeepFashionHD`.
65 | We use 8 32GB Tesla V100 GPUs to train the network. You can set `batchSize` to 16, 8 or 4 with fewer GPUs and change `gpu_ids`. 66 | ## Citation 67 | If you use this code for your research, please cite our papers. 68 | ``` 69 | @InProceedings{Zhou_2021_CVPR, 70 | author={Zhou, Xingran and Zhang, Bo and Zhang, Ting and Zhang, Pan and Bao, Jianmin and Chen, Dong and Zhang, Zhongfei and Wen, Fang}, 71 | title={CoCosNet v2: Full-Resolution Correspondence Learning for Image Translation}, 72 | booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)}, 73 | year={2021}, 74 | pages={11465-11475} 75 | } 76 | ``` 77 | 78 | Also, welcome to refer to our [CoCosNet v1](https://github.com/microsoft/CoCosNet): 79 | ``` 80 | @InProceedings{Zhang_2020_CVPR, 81 | author={Zhang, Pan and Zhang, Bo and Chen, Dong and Yuan, Lu and Wen, Fang}, 82 | title={Cross-Domain Correspondence Learning for Exemplar-Based Image Translation}, 83 | booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)}, 84 | year={2020}, 85 | pages={5143-5153} 86 | } 87 | ``` 88 | 89 | ## Acknowledgments 90 | *This code borrows heavily from [CocosNet](https://github.com/microsoft/CoCosNet) and [DeepPruner](https://github.com/uber-research/DeepPruner). 91 | We also thank [SPADE](https://github.com/NVlabs/SPADE) and [RAFT](https://github.com/princeton-vl/RAFT).* 92 | ## License 93 | The codes and the pretrained model in this repository are under the MIT license as specified by the LICENSE file.
94 | This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/). For more information see the [Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) or contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with any additional questions or comments. 95 | -------------------------------------------------------------------------------- /models/networks/patch_match.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 3 | 4 | 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | 9 | import random 10 | 11 | from models.networks.convgru import BasicUpdateBlock 12 | from models.networks.ops import * 13 | 14 | 15 | """patch match""" 16 | class Evaluate(nn.Module): 17 | def __init__(self, temperature): 18 | super().__init__() 19 | self.filter_size = 3 20 | self.temperature = temperature 21 | 22 | def forward(self, left_features, right_features, offset_x, offset_y): 23 | device = left_features.get_device() 24 | batch_size, num, height, width = offset_x.size() 25 | channel = left_features.size()[1] 26 | matching_inds = offset_to_inds(offset_x, offset_y) 27 | matching_inds = matching_inds.view(batch_size, num, height * width).permute(0, 2, 1).long() 28 | base_batch = torch.arange(batch_size).to(device).long() * (height * width) 29 | base_batch = base_batch.view(-1, 1, 1) 30 | matching_inds_add_base = matching_inds + base_batch 31 | right_features_view = right_features 32 | match_cost = [] 33 | # using A[:, idx] 34 | for i in range(matching_inds_add_base.size()[-1]): 35 | idx = matching_inds_add_base[:, :, i] 36 | idx = idx.contiguous().view(-1) 37 | right_features_select = right_features_view[:, idx] 38 | right_features_select = right_features_select.view(channel, batch_size, -1).transpose(0, 1) 39 | match_cost_i = torch.sum(left_features * right_features_select, dim=1, keepdim=True) / self.temperature 40 | match_cost.append(match_cost_i) 41 | match_cost = torch.cat(match_cost, dim=1).transpose(1, 2) 42 | match_cost = F.softmax(match_cost, dim=-1) 43 | match_cost_topk, match_cost_topk_indices = torch.topk(match_cost, num//self.filter_size, dim=-1) 44 | matching_inds = torch.gather(matching_inds, -1, match_cost_topk_indices) 45 | matching_inds = matching_inds.permute(0, 2, 1).view(batch_size, -1, height, width).float() 46 | offset_x, offset_y = inds_to_offset(matching_inds) 47 | corr = match_cost_topk.permute(0, 2, 1) 48 | return offset_x, offset_y, corr 49 | 50 | 51 | class PropagationFaster(nn.Module): 52 | def __init__(self): 53 | super().__init__() 54 | 55 | def forward(self, offset_x, offset_y, propagation_type="horizontal"): 56 | device = offset_x.get_device() 57 | self.horizontal_zeros = torch.zeros((offset_x.size()[0], offset_x.size()[1], offset_x.size()[2], 1)).to(device) 58 | self.vertical_zeros = torch.zeros((offset_x.size()[0], offset_x.size()[1], 1, offset_x.size()[3])).to(device) 59 | if propagation_type is "horizontal": 60 | offset_x = torch.cat((torch.cat((self.horizontal_zeros, offset_x[:, :, :, :-1]), dim=3), 61 | offset_x, 62 | torch.cat((offset_x[:, :, :, 1:], self.horizontal_zeros), dim=3)), dim=1) 63 | 64 | offset_y = torch.cat((torch.cat((self.horizontal_zeros, offset_y[:, :, :, :-1]), dim=3), 65 | offset_y, 66 | torch.cat((offset_y[:, :, :, 1:], self.horizontal_zeros), dim=3)), dim=1) 67 | 68 | else: 69 | offset_x = 
torch.cat((torch.cat((self.vertical_zeros, offset_x[:, :, :-1, :]), dim=2), 70 | offset_x, 71 | torch.cat((offset_x[:, :, 1:, :], self.vertical_zeros), dim=2)), dim=1) 72 | 73 | offset_y = torch.cat((torch.cat((self.vertical_zeros, offset_y[:, :, :-1, :]), dim=2), 74 | offset_y, 75 | torch.cat((offset_y[:, :, 1:, :], self.vertical_zeros), dim=2)), dim=1) 76 | return offset_x, offset_y 77 | 78 | 79 | class PatchMatchOnce(nn.Module): 80 | def __init__(self, opt): 81 | super().__init__() 82 | self.propagation = PropagationFaster() 83 | self.evaluate = Evaluate(opt.temperature) 84 | 85 | def forward(self, left_features, right_features, offset_x, offset_y): 86 | prob = random.random() 87 | if prob < 0.5: 88 | offset_x, offset_y = self.propagation(offset_x, offset_y, "horizontal") 89 | offset_x, offset_y, _ = self.evaluate(left_features, right_features, offset_x, offset_y) 90 | offset_x, offset_y = self.propagation(offset_x, offset_y, "vertical") 91 | offset_x, offset_y, corr = self.evaluate(left_features, right_features, offset_x, offset_y) 92 | else: 93 | offset_x, offset_y = self.propagation(offset_x, offset_y, "vertical") 94 | offset_x, offset_y, _ = self.evaluate(left_features, right_features, offset_x, offset_y) 95 | offset_x, offset_y = self.propagation(offset_x, offset_y, "horizontal") 96 | offset_x, offset_y, corr = self.evaluate(left_features, right_features, offset_x, offset_y) 97 | return offset_x, offset_y, corr 98 | 99 | 100 | class PatchMatchGRU(nn.Module): 101 | def __init__(self, opt): 102 | super().__init__() 103 | self.patch_match_one_step = PatchMatchOnce(opt) 104 | self.temperature = opt.temperature 105 | self.iters = opt.iteration_count 106 | input_dim = opt.nef 107 | hidden_dim = 32 108 | norm = nn.InstanceNorm2d(hidden_dim, affine=False) 109 | relu = nn.ReLU(inplace=True) 110 | """ 111 | concat left and right features 112 | """ 113 | self.initial_layer = nn.Sequential( 114 | nn.Conv2d(input_dim*2, hidden_dim, kernel_size=3, padding=1, stride=1), 115 | norm, 116 | relu, 117 | nn.Conv2d(hidden_dim, hidden_dim, kernel_size=3, padding=1), 118 | norm, 119 | relu, 120 | ) 121 | self.refine_net = BasicUpdateBlock() 122 | 123 | def forward(self, left_features, right_features, right_input, initial_offset_x, initial_offset_y): 124 | device = left_features.get_device() 125 | batch_size, channel, height, width = left_features.size() 126 | num = initial_offset_x.size()[1] 127 | initial_input = torch.cat((left_features, right_features), dim=1) 128 | hidden = self.initial_layer(initial_input) 129 | left_features = left_features.view(batch_size, -1, height * width) 130 | right_features = right_features.view(batch_size, -1, height * width) 131 | right_features_view = right_features.transpose(0, 1).contiguous().view(channel, -1) 132 | with torch.no_grad(): 133 | offset_x, offset_y = initial_offset_x, initial_offset_y 134 | for it in range(self.iters): 135 | with torch.no_grad(): 136 | offset_x, offset_y, corr = self.patch_match_one_step(left_features, right_features_view, offset_x, offset_y) 137 | """GRU refinement""" 138 | flow = torch.cat((offset_x, offset_y), dim=1) 139 | corr = corr.view(batch_size, -1, height, width) 140 | hidden, delta_offset_x, delta_offset_y = self.refine_net(hidden, corr, flow) 141 | offset_x = offset_x + delta_offset_x 142 | offset_y = offset_y + delta_offset_y 143 | with torch.no_grad(): 144 | matching_inds = offset_to_inds(offset_x, offset_y) 145 | matching_inds = matching_inds.view(batch_size, num, height * width).permute(0, 2, 1).long() 146 | base_batch = 
torch.arange(batch_size).to(device).long() * (height * width) 147 | base_batch = base_batch.view(-1, 1, 1) 148 | matching_inds_plus_base = matching_inds + base_batch 149 | match_cost = [] 150 | # using A[:, idx] 151 | for i in range(matching_inds_plus_base.size()[-1]): 152 | idx = matching_inds_plus_base[:, :, i] 153 | idx = idx.contiguous().view(-1) 154 | right_features_select = right_features_view[:, idx] 155 | right_features_select = right_features_select.view(channel, batch_size, -1).transpose(0, 1) 156 | match_cost_i = torch.sum(left_features * right_features_select, dim=1, keepdim=True) / self.temperature 157 | match_cost.append(match_cost_i) 158 | match_cost = torch.cat(match_cost, dim=1).transpose(1, 2) 159 | match_cost = F.softmax(match_cost, dim=-1) 160 | right_input_view = right_input.transpose(0, 1).contiguous().view(right_input.size()[1], -1) 161 | warp = torch.zeros_like(right_input) 162 | # using A[:, idx] 163 | for i in range(match_cost.size()[-1]): 164 | idx = matching_inds_plus_base[:, :, i] 165 | idx = idx.contiguous().view(-1) 166 | right_input_select = right_input_view[:, idx] 167 | right_input_select = right_input_select.view(right_input.size()[1], batch_size, -1).transpose(0, 1) 168 | warp = warp + right_input_select * match_cost[:, :, i].unsqueeze(dim=1) 169 | return matching_inds, warp 170 | -------------------------------------------------------------------------------- /options/base_options.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 3 | 4 | import os 5 | import sys 6 | import random 7 | import argparse 8 | import pickle 9 | import numpy as np 10 | import torch 11 | import models 12 | import data 13 | from util import util 14 | 15 | 16 | class BaseOptions(): 17 | def __init__(self): 18 | self.initialized = False 19 | 20 | def initialize(self, parser): 21 | # experiment specifics 22 | parser.add_argument('--name', type=str, default='deepfashionHD', help='name of the experiment. It decides where to store samples and models') 23 | parser.add_argument('--gpu_ids', type=str, default='0,1,2,3', help='gpu ids: e.g. 0 0,1,2, 0,2. use -1 for CPU') 24 | parser.add_argument('--checkpoints_dir', type=str, default='./checkpoints', help='models are saved here') 25 | parser.add_argument('--model', type=str, default='pix2pix', help='which model to use') 26 | parser.add_argument('--norm_G', type=str, default='spectralinstance', help='instance normalization or batch normalization') 27 | parser.add_argument('--norm_D', type=str, default='spectralinstance', help='instance normalization or batch normalization') 28 | parser.add_argument('--norm_E', type=str, default='spectralinstance', help='instance normalization or batch normalization') 29 | parser.add_argument('--phase', type=str, default='train', help='train, val, test, etc') 30 | # input/output sizes 31 | parser.add_argument('--batchSize', type=int, default=4, help='input batch size') 32 | parser.add_argument('--preprocess_mode', type=str, default='scale_width_and_crop', help='scaling and cropping of images at load time.', choices=("resize_and_crop", "crop", "scale_width", "scale_width_and_crop", "scale_shortside", "scale_shortside_and_crop", "fixed", "none")) 33 | parser.add_argument('--load_size', type=int, default=256, help='Scale images to this size. 
The final image will be cropped to --crop_size.') 34 | parser.add_argument('--crop_size', type=int, default=256, help='Crop to the width of crop_size (after initially scaling the images to load_size.)') 35 | parser.add_argument('--aspect_ratio', type=float, default=1.0, help='The ratio width/height. The final height of the load image will be crop_size/aspect_ratio') 36 | parser.add_argument('--label_nc', type=int, default=182, help='# of input label classes without unknown class. If you have unknown class as class label, specify --contain_dopntcare_label.') 37 | parser.add_argument('--contain_dontcare_label', action='store_true', help='if the label map contains dontcare label (dontcare=255)') 38 | parser.add_argument('--output_nc', type=int, default=3, help='# of output image channels') 39 | # for setting inputs 40 | parser.add_argument('--dataroot', type=str, default='dataset/deepfashionHD') 41 | parser.add_argument('--dataset_mode', type=str, default='deepfashionHD') 42 | parser.add_argument('--serial_batches', action='store_true', help='if true, takes images in order to make batches, otherwise takes them randomly') 43 | parser.add_argument('--no_flip', action='store_true', help='if specified, do not flip the images for data argumentation') 44 | parser.add_argument('--nThreads', default=16, type=int, help='# threads for loading data') 45 | parser.add_argument('--max_dataset_size', type=int, default=sys.maxsize, help='Maximum number of samples allowed per dataset. If the dataset directory contains more than max_dataset_size, only a subset is loaded.') 46 | parser.add_argument('--load_from_opt_file', action='store_true', help='load the options from checkpoints and use that as default') 47 | parser.add_argument('--cache_filelist_write', action='store_true', help='saves the current filelist into a text file, so that it loads faster') 48 | parser.add_argument('--cache_filelist_read', action='store_true', help='reads from the file list cache') 49 | # for displays 50 | parser.add_argument('--display_winsize', type=int, default=512, help='display window size') 51 | # for generator 52 | parser.add_argument('--netG', type=str, default='spade', help='selects model to use for netG (pix2pixhd | spade)') 53 | parser.add_argument('--ngf', type=int, default=64, help='# of gen filters in first conv layer') 54 | parser.add_argument('--init_type', type=str, default='xavier', help='network initialization [normal|xavier|kaiming|orthogonal]') 55 | parser.add_argument('--init_variance', type=float, default=0.02, help='variance of the initialization distribution') 56 | # for feature encoder 57 | parser.add_argument('--netCorr', type=str, default='NoVGGHPM') 58 | parser.add_argument('--nef', type=int, default=32, help='# of gen filters in first conv layer') 59 | # for instance-wise features 60 | parser.add_argument('--CBN_intype', type=str, default='warp_mask', help='type of CBN input for framework, warp/mask/warp_mask') 61 | parser.add_argument('--match_kernel', type=int, default=1, help='correspondence matrix match kernel size') 62 | parser.add_argument('--featEnc_kernel', type=int, default=3, help='kernel size in domain adaptor') 63 | parser.add_argument('--PONO', action='store_true', help='use positional normalization ') 64 | parser.add_argument('--PONO_C', action='store_true', help='use C normalization in corr module') 65 | parser.add_argument('--vgg_normal_correct', action='store_true', help='if true, correct vgg normalization and replace vgg FM model with ctx model') 66 | 
parser.add_argument('--use_coordconv', action='store_true', help='if true, use coordconv in CorrNet') 67 | parser.add_argument('--video_like', action='store_true', help='useful in deepfashion') 68 | parser.add_argument('--amp', action='store_true', help='use torch.cuda.amp') 69 | parser.add_argument('--temperature', type=float, default=0.01) 70 | parser.add_argument('--iteration_count', type=int, default=5) 71 | self.initialized = True 72 | return parser 73 | 74 | def gather_options(self): 75 | # initialize parser with basic options 76 | if not self.initialized: 77 | parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) 78 | parser = self.initialize(parser) 79 | # get the basic options 80 | opt, unknown = parser.parse_known_args() 81 | # modify model-related parser options 82 | model_name = opt.model 83 | model_option_setter = models.get_option_setter(model_name) 84 | parser = model_option_setter(parser, self.isTrain) 85 | # modify dataset-related parser options 86 | dataset_mode = opt.dataset_mode 87 | dataset_option_setter = data.get_option_setter(dataset_mode) 88 | parser = dataset_option_setter(parser, self.isTrain) 89 | opt, unknown = parser.parse_known_args() 90 | # if there is opt_file, load it. 91 | # The previous default options will be overwritten 92 | if opt.load_from_opt_file: 93 | parser = self.update_options_from_file(parser, opt) 94 | opt = parser.parse_args() 95 | self.parser = parser 96 | return opt 97 | 98 | def print_options(self, opt): 99 | message = '' 100 | message += '----------------- Options ---------------\n' 101 | for k, v in sorted(vars(opt).items()): 102 | comment = '' 103 | default = self.parser.get_default(k) 104 | if v != default: 105 | comment = '\t[default: %s]' % str(default) 106 | message += '{:>25}: {:<30}{}\n'.format(str(k), str(v), comment) 107 | message += '----------------- End -------------------' 108 | print(message) 109 | 110 | def option_file_path(self, opt, makedir=False): 111 | expr_dir = os.path.join(opt.checkpoints_dir, opt.name) 112 | if makedir: 113 | util.mkdirs(expr_dir) 114 | file_name = os.path.join(expr_dir, 'opt') 115 | return file_name 116 | 117 | def save_options(self, opt): 118 | file_name = self.option_file_path(opt, makedir=True) 119 | with open(file_name + '.txt', 'wt') as opt_file: 120 | for k, v in sorted(vars(opt).items()): 121 | comment = '' 122 | default = self.parser.get_default(k) 123 | if v != default: 124 | comment = '\t[default: %s]' % str(default) 125 | opt_file.write('{:>25}: {:<30}{}\n'.format(str(k), str(v), comment)) 126 | with open(file_name + '.pkl', 'wb') as opt_file: 127 | pickle.dump(opt, opt_file) 128 | 129 | def update_options_from_file(self, parser, opt): 130 | new_opt = self.load_options(opt) 131 | for k, v in sorted(vars(opt).items()): 132 | if hasattr(new_opt, k) and v != getattr(new_opt, k): 133 | new_val = getattr(new_opt, k) 134 | parser.set_defaults(**{k: new_val}) 135 | return parser 136 | 137 | def load_options(self, opt): 138 | file_name = self.option_file_path(opt, makedir=False) 139 | new_opt = pickle.load(open(file_name + '.pkl', 'rb')) 140 | return new_opt 141 | 142 | def parse(self, save=False): 143 | # gather options from base, train, dataset, model 144 | opt = self.gather_options() 145 | # train or test 146 | opt.isTrain = self.isTrain 147 | self.print_options(opt) 148 | if opt.isTrain: 149 | self.save_options(opt) 150 | # Set semantic_nc based on the option. 
151 | # This will be convenient in many places 152 | opt.semantic_nc = opt.label_nc + (1 if opt.contain_dontcare_label else 0) 153 | os.environ['CUDA_VISIBLE_DEVICES'] = opt.gpu_ids 154 | str_ids = opt.gpu_ids.split(',') 155 | opt.gpu_ids = list(range(len(str_ids))) 156 | seed = 1234 157 | random.seed(seed) 158 | np.random.seed(seed) 159 | torch.manual_seed(seed) 160 | torch.random.manual_seed(seed) 161 | torch.cuda.manual_seed_all(seed) 162 | torch.backends.cudnn.benchmark = True 163 | if len(opt.gpu_ids) > 0: 164 | torch.cuda.set_device(opt.gpu_ids[0]) 165 | self.opt = opt 166 | return self.opt 167 | -------------------------------------------------------------------------------- /models/networks/correspondence.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | import util.util as util 9 | from models.networks.base_network import BaseNetwork 10 | from models.networks.architecture import ResidualBlock 11 | from models.networks.normalization import get_nonspade_norm_layer 12 | from models.networks.architecture import SPADEResnetBlock 13 | """patch match""" 14 | from models.networks.patch_match import PatchMatchGRU 15 | from models.networks.ops import * 16 | 17 | 18 | def match_kernel_and_pono_c(feature, match_kernel, PONO_C, eps=1e-10): 19 | b, c, h, w = feature.size() 20 | if match_kernel == 1: 21 | feature = feature.view(b, c, -1) 22 | else: 23 | feature = F.unfold(feature, kernel_size=match_kernel, padding=int(match_kernel//2)) 24 | dim_mean = 1 if PONO_C else -1 25 | feature = feature - feature.mean(dim=dim_mean, keepdim=True) 26 | feature_norm = torch.norm(feature, 2, 1, keepdim=True) + eps 27 | feature = torch.div(feature, feature_norm) 28 | return feature.view(b, -1, h, w) 29 | 30 | 31 | """512x512""" 32 | class AdaptiveFeatureGenerator(BaseNetwork): 33 | @staticmethod 34 | def modify_commandline_options(parser, is_train): 35 | return parser 36 | 37 | def __init__(self, opt): 38 | super().__init__() 39 | self.opt = opt 40 | kw = opt.featEnc_kernel 41 | pw = int((kw-1)//2) 42 | nf = opt.nef 43 | norm_layer = get_nonspade_norm_layer(opt, opt.norm_E) 44 | self.layer1 = norm_layer(nn.Conv2d(opt.spade_ic, nf, 3, stride=1, padding=pw)) 45 | self.layer2 = nn.Sequential( 46 | norm_layer(nn.Conv2d(nf * 1, nf * 2, 3, 1, 1)), 47 | ResidualBlock(nf * 2, nf * 2), 48 | ) 49 | self.layer3 = nn.Sequential( 50 | norm_layer(nn.Conv2d(nf * 2, nf * 4, kw, stride=2, padding=pw)), 51 | ResidualBlock(nf * 4, nf * 4), 52 | ) 53 | self.layer4 = nn.Sequential( 54 | norm_layer(nn.Conv2d(nf * 4, nf * 4, kw, stride=2, padding=pw)), 55 | ResidualBlock(nf * 4, nf * 4), 56 | ) 57 | self.layer5 = nn.Sequential( 58 | norm_layer(nn.Conv2d(nf * 4, nf * 4, kw, stride=2, padding=pw)), 59 | ResidualBlock(nf * 4, nf * 4), 60 | ) 61 | self.head_0 = SPADEResnetBlock(nf * 4, nf * 4, opt) 62 | self.G_middle_0 = SPADEResnetBlock(nf * 4, nf * 4, opt) 63 | self.G_middle_1 = SPADEResnetBlock(nf * 4, nf * 2, opt) 64 | self.G_middle_2 = SPADEResnetBlock(nf * 2, nf * 1, opt) 65 | self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True) 66 | 67 | def forward(self, input, seg): 68 | # 512 69 | x1 = self.layer1(input) 70 | # 512 71 | x2 = self.layer2(self.actvn(x1)) 72 | # 256 73 | x3 = self.layer3(self.actvn(x2)) 74 | # 128 75 | x4 = self.layer4(self.actvn(x3)) 76 | # 64 77 | x5 = self.layer5(self.actvn(x4)) 78 | # 
bottleneck
79 |         x6 = self.head_0(x5, seg)
80 |         # 128
81 |         x7 = self.G_middle_0(self.up(x6) + x4, seg)
82 |         # 256
83 |         x8 = self.G_middle_1(self.up(x7) + x3, seg)
84 |         # 512
85 |         x9 = self.G_middle_2(self.up(x8) + x2, seg)
86 |         return [x6, x7, x8, x9]
87 | 
88 |     def actvn(self, x):
89 |         return F.leaky_relu(x, 2e-1)
90 | 
91 | 
92 | class NoVGGHPMCorrespondence(BaseNetwork):
93 |     def __init__(self, opt):
94 |         self.opt = opt
95 |         super().__init__()
96 |         opt.spade_ic = opt.semantic_nc
97 |         self.adaptive_model_seg = AdaptiveFeatureGenerator(opt)
98 |         opt.spade_ic = 3 + opt.semantic_nc
99 |         self.adaptive_model_img = AdaptiveFeatureGenerator(opt)
100 |         del opt.spade_ic
101 |         self.batch_size = opt.batchSize
102 |         """512x512"""
103 |         feature_channel = opt.nef
104 |         self.phi_0 = nn.Conv2d(in_channels=feature_channel*4, out_channels=feature_channel, kernel_size=1, stride=1, padding=0)
105 |         self.phi_1 = nn.Conv2d(in_channels=feature_channel*4, out_channels=feature_channel, kernel_size=1, stride=1, padding=0)
106 |         self.phi_2 = nn.Conv2d(in_channels=feature_channel*2, out_channels=feature_channel, kernel_size=1, stride=1, padding=0)
107 |         self.phi_3 = nn.Conv2d(in_channels=feature_channel, out_channels=feature_channel, kernel_size=1, stride=1, padding=0)
108 |         self.theta_0 = nn.Conv2d(in_channels=feature_channel*4, out_channels=feature_channel, kernel_size=1, stride=1, padding=0)
109 |         self.theta_1 = nn.Conv2d(in_channels=feature_channel*4, out_channels=feature_channel, kernel_size=1, stride=1, padding=0)
110 |         self.theta_2 = nn.Conv2d(in_channels=feature_channel*2, out_channels=feature_channel, kernel_size=1, stride=1, padding=0)
111 |         self.theta_3 = nn.Conv2d(in_channels=feature_channel, out_channels=feature_channel, kernel_size=1, stride=1, padding=0)
112 |         self.patch_match = PatchMatchGRU(opt)
113 | 
114 |     """512x512"""
115 |     def multi_scale_patch_match(self, f1, f2, ref, hierarchical_scale, pre=None, real_img=None):
116 |         if hierarchical_scale == 0:
117 |             y_cycle = None
118 |             scale = 64
119 |             batch_size, channel, feature_height, feature_width = f1.size()
120 |             ref = F.avg_pool2d(ref, 8, stride=8)
121 |             ref = ref.view(batch_size, 3, scale * scale)
122 |             f1 = f1.view(batch_size, channel, scale * scale)
123 |             f2 = f2.view(batch_size, channel, scale * scale)
124 |             matmul_result = torch.matmul(f1.permute(0, 2, 1), f2)/self.opt.temperature  # content-reference feature affinity
125 |             mat = F.softmax(matmul_result, dim=-1)
126 |             y = torch.matmul(mat, ref.permute(0, 2, 1))  # soft warp of the pooled reference
127 |             if self.opt.phase == 'train' and self.opt.weight_warp_cycle > 0:
128 |                 mat_cycle = F.softmax(matmul_result.transpose(1, 2), dim=-1)
129 |                 y_cycle = torch.matmul(mat_cycle, y)
130 |                 y_cycle = y_cycle.permute(0, 2, 1).view(batch_size, 3, scale, scale)
131 |             y = y.permute(0, 2, 1).view(batch_size, 3, scale, scale)
132 |             return mat, y, y_cycle
133 |         if hierarchical_scale == 1:
134 |             scale = 128
135 |             with torch.no_grad():
136 |                 batch_size, channel, feature_height, feature_width = f1.size()
137 |                 topk_num = 1
138 |                 search_window = 4
139 |                 centering = 1
140 |                 dilation = 2
141 |                 total_candidate_num = topk_num * (search_window ** 2)
142 |                 topk_inds = torch.topk(pre, topk_num, dim=-1)[-1]
143 |                 inds = topk_inds.permute(0, 2, 1).view(batch_size, topk_num, (scale//2), (scale//2)).float()
144 |                 offset_x, offset_y = inds_to_offset(inds)
145 |                 dx = torch.arange(search_window, dtype=topk_inds.dtype, device=topk_inds.device).unsqueeze_(dim=1).expand(-1, search_window).contiguous().view(-1) - centering
146 |                 dy = torch.arange(search_window, dtype=topk_inds.dtype,
device=topk_inds.device).unsqueeze_(dim=0).expand(search_window, -1).contiguous().view(-1) - centering 147 | dx = dx.view(1, search_window ** 2, 1, 1) * dilation 148 | dy = dy.view(1, search_window ** 2, 1, 1) * dilation 149 | offset_x_up = F.interpolate((2 * offset_x + dx), scale_factor=2) 150 | offset_y_up = F.interpolate((2 * offset_y + dy), scale_factor=2) 151 | ref = F.avg_pool2d(ref, 4, stride=4) 152 | ref = ref.view(batch_size, 3, scale * scale) 153 | mat, y = self.patch_match(f1, f2, ref, offset_x_up, offset_y_up) 154 | y = y.view(batch_size, 3, scale, scale) 155 | return mat, y 156 | if hierarchical_scale == 2: 157 | scale = 256 158 | with torch.no_grad(): 159 | batch_size, channel, feature_height, feature_width = f1.size() 160 | topk_num = 1 161 | search_window = 4 162 | centering = 1 163 | dilation = 2 164 | total_candidate_num = topk_num * (search_window ** 2) 165 | topk_inds = pre[:, :, :topk_num] 166 | inds = topk_inds.permute(0, 2, 1).view(batch_size, topk_num, (scale//2), (scale//2)).float() 167 | offset_x, offset_y = inds_to_offset(inds) 168 | dx = torch.arange(search_window, dtype=topk_inds.dtype, device=topk_inds.device).unsqueeze_(dim=1).expand(-1, search_window).contiguous().view(-1) - centering 169 | dy = torch.arange(search_window, dtype=topk_inds.dtype, device=topk_inds.device).unsqueeze_(dim=0).expand(search_window, -1).contiguous().view(-1) - centering 170 | dx = dx.view(1, search_window ** 2, 1, 1) * dilation 171 | dy = dy.view(1, search_window ** 2, 1, 1) * dilation 172 | offset_x_up = F.interpolate((2 * offset_x + dx), scale_factor=2) 173 | offset_y_up = F.interpolate((2 * offset_y + dy), scale_factor=2) 174 | ref = F.avg_pool2d(ref, 2, stride=2) 175 | ref = ref.view(batch_size, 3, scale * scale) 176 | mat, y = self.patch_match(f1, f2, ref, offset_x_up, offset_y_up) 177 | y = y.view(batch_size, 3, scale, scale) 178 | return mat, y 179 | if hierarchical_scale == 3: 180 | scale = 512 181 | with torch.no_grad(): 182 | batch_size, channel, feature_height, feature_width = f1.size() 183 | topk_num = 1 184 | search_window = 4 185 | centering = 1 186 | dilation = 2 187 | total_candidate_num = topk_num * (search_window ** 2) 188 | topk_inds = pre[:, :, :topk_num] 189 | inds = topk_inds.permute(0, 2, 1).view(batch_size, topk_num, (scale//2), (scale//2)).float() 190 | offset_x, offset_y = inds_to_offset(inds) 191 | dx = torch.arange(search_window, dtype=topk_inds.dtype, device=topk_inds.device).unsqueeze_(dim=1).expand(-1, search_window).contiguous().view(-1) - centering 192 | dy = torch.arange(search_window, dtype=topk_inds.dtype, device=topk_inds.device).unsqueeze_(dim=0).expand(search_window, -1).contiguous().view(-1) - centering 193 | dx = dx.view(1, search_window ** 2, 1, 1) * dilation 194 | dy = dy.view(1, search_window ** 2, 1, 1) * dilation 195 | offset_x_up = F.interpolate((2 * offset_x + dx), scale_factor=2) 196 | offset_y_up = F.interpolate((2 * offset_y + dy), scale_factor=2) 197 | ref = ref.view(batch_size, 3, scale * scale) 198 | mat, y = self.patch_match(f1, f2, ref, offset_x_up, offset_y_up) 199 | y = y.view(batch_size, 3, scale, scale) 200 | return mat, y 201 | 202 | def forward(self, ref_img, real_img, seg_map, ref_seg_map): 203 | corr_out = {} 204 | seg_input = seg_map 205 | adaptive_feature_seg = self.adaptive_model_seg(seg_input, seg_input) 206 | ref_input = torch.cat((ref_img, ref_seg_map), dim=1) 207 | adaptive_feature_img = self.adaptive_model_img(ref_input, ref_input) 208 | for i in range(len(adaptive_feature_seg)): 209 | adaptive_feature_seg[i] = 
util.feature_normalize(adaptive_feature_seg[i]) 210 | adaptive_feature_img[i] = util.feature_normalize(adaptive_feature_img[i]) 211 | if self.opt.isTrain and self.opt.weight_novgg_featpair > 0: 212 | real_input = torch.cat((real_img, seg_map), dim=1) 213 | adaptive_feature_img_pair = self.adaptive_model_img(real_input, real_input) 214 | loss_novgg_featpair = 0 215 | weights = [1.0, 1.0, 1.0, 1.0] 216 | for i in range(len(adaptive_feature_img_pair)): 217 | adaptive_feature_img_pair[i] = util.feature_normalize(adaptive_feature_img_pair[i]) 218 | loss_novgg_featpair += F.l1_loss(adaptive_feature_seg[i], adaptive_feature_img_pair[i]) * weights[i] 219 | corr_out['loss_novgg_featpair'] = loss_novgg_featpair * self.opt.weight_novgg_featpair 220 | cont_features = adaptive_feature_seg 221 | ref_features = adaptive_feature_img 222 | theta = [] 223 | phi = [] 224 | """512x512""" 225 | theta.append(match_kernel_and_pono_c(self.theta_0(cont_features[0]), self.opt.match_kernel, self.opt.PONO_C)) 226 | theta.append(match_kernel_and_pono_c(self.theta_1(cont_features[1]), self.opt.match_kernel, self.opt.PONO_C)) 227 | theta.append(match_kernel_and_pono_c(self.theta_2(cont_features[2]), self.opt.match_kernel, self.opt.PONO_C)) 228 | theta.append(match_kernel_and_pono_c(self.theta_3(cont_features[3]), self.opt.match_kernel, self.opt.PONO_C)) 229 | phi.append(match_kernel_and_pono_c(self.phi_0(ref_features[0]), self.opt.match_kernel, self.opt.PONO_C)) 230 | phi.append(match_kernel_and_pono_c(self.phi_1(ref_features[1]), self.opt.match_kernel, self.opt.PONO_C)) 231 | phi.append(match_kernel_and_pono_c(self.phi_2(ref_features[2]), self.opt.match_kernel, self.opt.PONO_C)) 232 | phi.append(match_kernel_and_pono_c(self.phi_3(ref_features[3]), self.opt.match_kernel, self.opt.PONO_C)) 233 | ref = ref_img 234 | ys = [] 235 | m = None 236 | for i in range(len(theta)): 237 | if i == 0: 238 | m, y, y_cycle = self.multi_scale_patch_match(theta[i], phi[i], ref, i, pre=m) 239 | if y_cycle is not None: 240 | corr_out['warp_cycle'] = y_cycle 241 | else: 242 | m, y = self.multi_scale_patch_match(theta[i], phi[i], ref, i, pre=m) 243 | ys.append(y) 244 | corr_out['warp_out'] = ys 245 | return corr_out 246 | -------------------------------------------------------------------------------- /models/pix2pix_model.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 2 | # Licensed under the MIT License. 
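# Usage sketch: Pix2PixModel is expected to be driven by trainers/pix2pix_trainer.py
# through the three forward modes defined below. The exact call sites live in the
# trainer and in train.py/test.py; the outline here is only an assumed illustration:
#
#     model = Pix2PixModel(opt)
#     optimizer_G, optimizer_D = model.create_optimizers(opt)
#     g_losses, g_out = model(data_i, mode='generator')             # generator update
#     d_losses = model(data_i, mode='discriminator', GforD=g_out)   # discriminator update
#     results = model(data_i, mode='inference')                     # test time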
3 | 
4 | import torch
5 | import torch.nn.functional as F
6 | import models.networks as networks
7 | import util.util as util
8 | import itertools
9 | try:
10 |     from torch.cuda.amp import autocast
11 | except ImportError:
12 |     # dummy autocast for PyTorch < 1.6
13 |     class autocast:
14 |         def __init__(self, enabled):
15 |             pass
16 |         def __enter__(self):
17 |             pass
18 |         def __exit__(self, *args):
19 |             pass
20 | 
21 | 
22 | class Pix2PixModel(torch.nn.Module):
23 |     @staticmethod
24 |     def modify_commandline_options(parser, is_train):
25 |         networks.modify_commandline_options(parser, is_train)
26 |         return parser
27 | 
28 |     def __init__(self, opt):
29 |         super().__init__()
30 |         self.opt = opt
31 |         self.FloatTensor = torch.cuda.FloatTensor if self.use_gpu() \
32 |             else torch.FloatTensor
33 |         self.ByteTensor = torch.cuda.ByteTensor if self.use_gpu() \
34 |             else torch.ByteTensor
35 |         self.net = torch.nn.ModuleDict(self.initialize_networks(opt))
36 |         # set loss functions
37 |         if opt.isTrain:
38 |             # vgg network
39 |             self.vggnet_fix = networks.architecture.VGG19_feature_color_torchversion(vgg_normal_correct=opt.vgg_normal_correct)
40 |             self.vggnet_fix.load_state_dict(torch.load('vgg/vgg19_conv.pth'))
41 |             self.vggnet_fix.eval()
42 |             for param in self.vggnet_fix.parameters():
43 |                 param.requires_grad = False
44 |             self.vggnet_fix.to(self.opt.gpu_ids[0])
45 |             # contextual loss
46 |             self.contextual_forward_loss = networks.ContextualLoss_forward(opt)
47 |             # GAN loss
48 |             self.criterionGAN = networks.GANLoss(opt.gan_mode, tensor=self.FloatTensor, opt=self.opt)
49 |             # L1 loss
50 |             self.criterionFeat = torch.nn.L1Loss()
51 |             # L2 loss
52 |             self.MSE_loss = torch.nn.MSELoss()
53 |             # setting which layer is used in the perceptual loss
54 |             if opt.which_perceptual == '5_2':
55 |                 self.perceptual_layer = -1
56 |             elif opt.which_perceptual == '4_2':
57 |                 self.perceptual_layer = -2
58 | 
59 |     def forward(self, data, mode, GforD=None):
60 |         input_label, input_semantics, real_image, self_ref, ref_image, ref_label, ref_semantics = self.preprocess_input(data)
61 |         generated_out = {}
62 | 
63 |         if mode == 'generator':
64 |             g_loss, generated_out = self.compute_generator_loss(input_label, \
65 |                 input_semantics, real_image, ref_label, \
66 |                 ref_semantics, ref_image, self_ref)
67 |             out = {}
68 |             out['fake_image'] = generated_out['fake_image']
69 |             out['input_semantics'] = input_semantics
70 |             out['ref_semantics'] = ref_semantics
71 |             out['warp_out'] = None if 'warp_out' not in generated_out else generated_out['warp_out']
72 |             out['adaptive_feature_seg'] = None if 'adaptive_feature_seg' not in generated_out else generated_out['adaptive_feature_seg']
73 |             out['adaptive_feature_img'] = None if 'adaptive_feature_img' not in generated_out else generated_out['adaptive_feature_img']
74 |             out['warp_cycle'] = None if 'warp_cycle' not in generated_out else generated_out['warp_cycle']
75 |             return g_loss, out
76 | 
77 |         elif mode == 'discriminator':
78 |             d_loss = self.compute_discriminator_loss(input_semantics, \
79 |                 real_image, GforD, label=input_label)
80 |             return d_loss
81 | 
82 |         elif mode == 'inference':
83 |             out = {}
84 |             with torch.no_grad():
85 |                 out = self.inference(input_semantics, ref_semantics=ref_semantics, \
86 |                     ref_image=ref_image, self_ref=self_ref, \
87 |                     real_image=real_image)
88 |             out['input_semantics'] = input_semantics
89 |             out['ref_semantics'] = ref_semantics
90 |             return out
91 | 
92 |         else:
93 |             raise ValueError("|mode| is invalid")
94 | 
95 |     def create_optimizers(self, opt):
96 |         if opt.no_TTUR:
97 |             beta1, beta2 = opt.beta1, opt.beta2
98 | G_lr, D_lr = opt.lr, opt.lr 99 | else: 100 | beta1, beta2 = 0, 0.9 101 | G_lr, D_lr = opt.lr / 2, opt.lr * 2 102 | optimizer_G = torch.optim.Adam(itertools.chain(self.net['netG'].parameters(), \ 103 | self.net['netCorr'].parameters()), lr=G_lr, betas=(beta1, beta2), eps=1e-3) 104 | optimizer_D = torch.optim.Adam(itertools.chain(self.net['netD'].parameters()), \ 105 | lr=D_lr, betas=(beta1, beta2)) 106 | return optimizer_G, optimizer_D 107 | 108 | def save(self, epoch): 109 | util.save_network(self.net['netG'], 'G', epoch, self.opt) 110 | util.save_network(self.net['netD'], 'D', epoch, self.opt) 111 | util.save_network(self.net['netCorr'], 'Corr', epoch, self.opt) 112 | 113 | def initialize_networks(self, opt): 114 | net = {} 115 | net['netG'] = networks.define_G(opt) 116 | net['netD'] = networks.define_D(opt) if opt.isTrain else None 117 | net['netCorr'] = networks.define_Corr(opt) 118 | if not opt.isTrain or opt.continue_train: 119 | net['netCorr'] = util.load_network(net['netCorr'], 'Corr', opt.which_epoch, opt) 120 | net['netG'] = util.load_network(net['netG'], 'G', opt.which_epoch, opt) 121 | if opt.isTrain: 122 | net['netD'] = util.load_network(net['netD'], 'D', opt.which_epoch, opt) 123 | return net 124 | 125 | def preprocess_input(self, data): 126 | if self.use_gpu(): 127 | for k in data.keys(): 128 | try: 129 | data[k] = data[k].cuda() 130 | except: 131 | continue 132 | label = data['label'][:,:3,:,:].float() 133 | label_ref = data['label_ref'][:,:3,:,:].float() 134 | input_semantics = data['label'].float() 135 | ref_semantics = data['label_ref'].float() 136 | image = data['image'] 137 | ref = data['ref'] 138 | self_ref = data['self_ref'] 139 | return label, input_semantics, image, self_ref, ref, label_ref, ref_semantics 140 | 141 | def get_ctx_loss(self, source, target): 142 | contextual_style5_1 = torch.mean(self.contextual_forward_loss(source[-1], target[-1].detach())) * 8 143 | contextual_style4_1 = torch.mean(self.contextual_forward_loss(source[-2], target[-2].detach())) * 4 144 | contextual_style3_1 = torch.mean(self.contextual_forward_loss(F.avg_pool2d(source[-3], 2), F.avg_pool2d(target[-3].detach(), 2))) * 2 145 | return contextual_style5_1 + contextual_style4_1 + contextual_style3_1 146 | 147 | def compute_generator_loss(self, input_label, input_semantics, real_image, ref_label=None, ref_semantics=None, ref_image=None, self_ref=None): 148 | G_losses = {} 149 | generate_out = self.generate_fake(input_semantics, real_image, ref_semantics=ref_semantics, ref_image=ref_image, self_ref=self_ref) 150 | generate_out['fake_image'] = generate_out['fake_image'].float() 151 | weights = [1.0 / 32, 1.0 / 16, 1.0 / 8, 1.0 / 4, 1.0] 152 | sample_weights = self_ref/(sum(self_ref)+1e-5) 153 | sample_weights = sample_weights.view(-1, 1, 1, 1) 154 | """domain align""" 155 | if 'loss_novgg_featpair' in generate_out and generate_out['loss_novgg_featpair'] is not None: 156 | G_losses['no_vgg_feat'] = generate_out['loss_novgg_featpair'] 157 | """warping cycle""" 158 | if self.opt.weight_warp_cycle > 0: 159 | warp_cycle = generate_out['warp_cycle'] 160 | scale_factor = ref_image.size()[-1] // warp_cycle.size()[-1] 161 | ref = F.avg_pool2d(ref_image, scale_factor, stride=scale_factor) 162 | G_losses['G_warp_cycle'] = F.l1_loss(warp_cycle, ref) * self.opt.weight_warp_cycle 163 | """warping loss""" 164 | if self.opt.weight_warp_self > 0: 165 | """512x512""" 166 | warp1, warp2, warp3, warp4 = generate_out['warp_out'] 167 | G_losses['G_warp_self'] = \ 168 | torch.mean(F.l1_loss(warp4, real_image, 
reduction='none') * sample_weights) * self.opt.weight_warp_self * 1.0 + \ 169 | torch.mean(F.l1_loss(warp3, F.avg_pool2d(real_image, 2, stride=2), reduction='none') * sample_weights) * self.opt.weight_warp_self * 1.0 + \ 170 | torch.mean(F.l1_loss(warp2, F.avg_pool2d(real_image, 4, stride=4), reduction='none') * sample_weights) * self.opt.weight_warp_self * 1.0 + \ 171 | torch.mean(F.l1_loss(warp1, F.avg_pool2d(real_image, 8, stride=8), reduction='none') * sample_weights) * self.opt.weight_warp_self * 1.0 172 | """gan loss""" 173 | pred_fake, pred_real = self.discriminate(input_semantics, generate_out['fake_image'], real_image) 174 | G_losses['GAN'] = self.criterionGAN(pred_fake, True, for_discriminator=False) * self.opt.weight_gan 175 | if not self.opt.no_ganFeat_loss: 176 | num_D = len(pred_fake) 177 | GAN_Feat_loss = 0.0 178 | for i in range(num_D): 179 | # for each discriminator 180 | # last output is the final prediction, so we exclude it 181 | num_intermediate_outputs = len(pred_fake[i]) - 1 182 | for j in range(num_intermediate_outputs): 183 | # for each layer output 184 | unweighted_loss = self.criterionFeat(pred_fake[i][j], pred_real[i][j].detach()) 185 | GAN_Feat_loss += unweighted_loss * self.opt.weight_ganFeat / num_D 186 | G_losses['GAN_Feat'] = GAN_Feat_loss 187 | """feature matching loss""" 188 | fake_features = self.vggnet_fix(generate_out['fake_image'], ['r12', 'r22', 'r32', 'r42', 'r52'], preprocess=True) 189 | loss = 0 190 | for i in range(len(generate_out['real_features'])): 191 | loss += weights[i] * util.weighted_l1_loss(fake_features[i], generate_out['real_features'][i].detach(), sample_weights) 192 | G_losses['fm'] = loss * self.opt.weight_vgg * self.opt.weight_fm_ratio 193 | """perceptual loss""" 194 | feat_loss = util.mse_loss(fake_features[self.perceptual_layer], generate_out['real_features'][self.perceptual_layer].detach()) 195 | G_losses['perc'] = feat_loss * self.opt.weight_perceptual 196 | """contextual loss""" 197 | G_losses['contextual'] = self.get_ctx_loss(fake_features, generate_out['ref_features']) * self.opt.weight_vgg * self.opt.weight_contextual 198 | return G_losses, generate_out 199 | 200 | def compute_discriminator_loss(self, input_semantics, real_image, GforD, label=None): 201 | D_losses = {} 202 | with torch.no_grad(): 203 | fake_image = GforD['fake_image'].detach() 204 | fake_image.requires_grad_() 205 | pred_fake, pred_real = self.discriminate(input_semantics, fake_image, real_image) 206 | D_losses['D_Fake'] = self.criterionGAN(pred_fake, False, for_discriminator=True) * self.opt.weight_gan 207 | D_losses['D_real'] = self.criterionGAN(pred_real, True, for_discriminator=True) * self.opt.weight_gan 208 | return D_losses 209 | 210 | def encode_z(self, real_image): 211 | mu, logvar = self.net['netE'](real_image) 212 | z = self.reparameterize(mu, logvar) 213 | return z, mu, logvar 214 | 215 | def generate_fake(self, input_semantics, real_image, ref_semantics=None, ref_image=None, self_ref=None): 216 | generate_out = {} 217 | generate_out['ref_features'] = self.vggnet_fix(ref_image, ['r12', 'r22', 'r32', 'r42', 'r52'], preprocess=True) 218 | generate_out['real_features'] = self.vggnet_fix(real_image, ['r12', 'r22', 'r32', 'r42', 'r52'], preprocess=True) 219 | with autocast(enabled=self.opt.amp): 220 | corr_out = self.net['netCorr'](ref_image, real_image, input_semantics, ref_semantics) 221 | generate_out['fake_image'] = self.net['netG'](input_semantics, warp_out=corr_out['warp_out']) 222 | generate_out = {**generate_out, **corr_out} 223 | return 
generate_out 224 | 225 | def inference(self, input_semantics, ref_semantics=None, ref_image=None, self_ref=None, real_image=None): 226 | generate_out = {} 227 | with autocast(enabled=self.opt.amp): 228 | corr_out = self.net['netCorr'](ref_image, real_image, input_semantics, ref_semantics) 229 | generate_out['fake_image'] = self.net['netG'](input_semantics, warp_out=corr_out['warp_out']) 230 | generate_out = {**generate_out, **corr_out} 231 | return generate_out 232 | 233 | def discriminate(self, input_semantics, fake_image, real_image): 234 | fake_concat = torch.cat([input_semantics, fake_image], dim=1) 235 | real_concat = torch.cat([input_semantics, real_image], dim=1) 236 | fake_and_real = torch.cat([fake_concat, real_concat], dim=0) 237 | with autocast(enabled=self.opt.amp): 238 | discriminator_out = self.net['netD'](fake_and_real) 239 | pred_fake, pred_real = self.divide_pred(discriminator_out) 240 | return pred_fake, pred_real 241 | 242 | def divide_pred(self, pred): 243 | if type(pred) == list: 244 | fake = [] 245 | real = [] 246 | for p in pred: 247 | fake.append([tensor[:tensor.size(0) // 2] for tensor in p]) 248 | real.append([tensor[tensor.size(0) // 2:] for tensor in p]) 249 | else: 250 | fake = pred[:pred.size(0) // 2] 251 | real = pred[pred.size(0) // 2:] 252 | return fake, real 253 | 254 | def use_gpu(self): 255 | return len(self.opt.gpu_ids) > 0 256 | --------------------------------------------------------------------------------
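For reference, the coarsest level of NoVGGHPMCorrespondence.multi_scale_patch_match (hierarchical_scale == 0) reduces to a dense softmax attention between two 64x64 feature maps followed by a soft warp of the pooled reference image. A minimal standalone sketch of that computation, using random tensors as stand-ins for the normalized theta/phi features and an assumed temperature of 0.01 in place of opt.temperature:

import torch
import torch.nn.functional as F

b, c, scale = 1, 64, 64        # batch, feature channels, 64x64 coarsest grid
temperature = 0.01             # stand-in for opt.temperature (assumed value)

f1 = torch.randn(b, c, scale, scale)           # stand-in for the segmentation-branch features (theta)
f2 = torch.randn(b, c, scale, scale)           # stand-in for the reference-branch features (phi)
ref = torch.randn(b, 3, scale * 8, scale * 8)  # 512x512 reference image

# pool the reference to the feature resolution and flatten everything to (b, *, scale*scale)
ref = F.avg_pool2d(ref, 8, stride=8).view(b, 3, scale * scale)
f1 = f1.view(b, c, scale * scale)
f2 = f2.view(b, c, scale * scale)

# correspondence matrix: each content position attends over all reference positions
mat = F.softmax(torch.matmul(f1.permute(0, 2, 1), f2) / temperature, dim=-1)

# soft warp: every output pixel is a weighted average of reference pixels
y = torch.matmul(mat, ref.permute(0, 2, 1)).permute(0, 2, 1).view(b, 3, scale, scale)
print(mat.shape, y.shape)  # torch.Size([1, 4096, 4096]) torch.Size([1, 3, 64, 64])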