├── .gitattributes
├── imgs
│   ├── teaser.png
│   ├── teaser_1.jpg
│   └── teaser_2.jpg
├── util
│   ├── __init__.py
│   ├── iter_counter.py
│   └── util.py
├── options
│   ├── __init__.py
│   ├── test_options.py
│   ├── train_options.py
│   └── base_options.py
├── trainers
│   ├── __init__.py
│   └── pix2pix_trainer.py
├── requirements.txt
├── CODE_OF_CONDUCT.md
├── data
│   ├── preprocess.py
│   ├── __init__.py
│   ├── base_dataset.py
│   ├── pix2pix_dataset.py
│   └── deepfashionHD_dataset.py
├── LICENSE
├── SUPPORT.md
├── models
│   ├── __init__.py
│   ├── networks
│   │   ├── ops.py
│   │   ├── __init__.py
│   │   ├── generator.py
│   │   ├── base_network.py
│   │   ├── ContextualLoss.py
│   │   ├── convgru.py
│   │   ├── loss.py
│   │   ├── normalization.py
│   │   ├── discriminator.py
│   │   ├── architecture.py
│   │   ├── patch_match.py
│   │   └── correspondence.py
│   └── pix2pix_model.py
├── test.py
├── .gitignore
├── SECURITY.md
├── train.py
└── README.md
/.gitattributes:
--------------------------------------------------------------------------------
1 | # Auto detect text files and perform LF normalization
2 | * text=auto
3 |
--------------------------------------------------------------------------------
/imgs/teaser.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/CoCosNet-v2/HEAD/imgs/teaser.png
--------------------------------------------------------------------------------
/imgs/teaser_1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/CoCosNet-v2/HEAD/imgs/teaser_1.jpg
--------------------------------------------------------------------------------
/imgs/teaser_2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/CoCosNet-v2/HEAD/imgs/teaser_2.jpg
--------------------------------------------------------------------------------
/util/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT License.
--------------------------------------------------------------------------------
/options/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT License.
--------------------------------------------------------------------------------
/trainers/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT License.
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch==1.7.0
2 | torchvision
3 | matplotlib
4 | pillow
5 | imageio
6 | numpy
7 | pandas
8 | scipy
9 | scikit-image
10 | opencv-python
--------------------------------------------------------------------------------
/CODE_OF_CONDUCT.md:
--------------------------------------------------------------------------------
1 | # Microsoft Open Source Code of Conduct
2 |
3 | This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/).
4 |
5 | Resources:
6 |
7 | - [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/)
8 | - [Microsoft Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/)
9 | - Contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with questions or concerns
10 |
--------------------------------------------------------------------------------
/options/test_options.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT License.
3 |
4 | from .base_options import BaseOptions
5 |
6 |
7 | class TestOptions(BaseOptions):
8 | def initialize(self, parser):
9 | BaseOptions.initialize(self, parser)
10 | parser.add_argument('--results_dir', type=str, default='./results/', help='saves results here.')
11 | parser.add_argument('--which_epoch', type=str, default='latest', help='which epoch to load? set to latest to use latest cached model')
12 | parser.add_argument('--how_many', type=int, default=float("inf"), help='how many test images to run')
13 | parser.add_argument('--save_per_img', action='store_true', help='if specified, save per image')
14 | parser.add_argument('--show_corr', action='store_true', help='if specified, save the bilinearly upsampled correspondence')
15 | parser.set_defaults(preprocess_mode='scale_width_and_crop', crop_size=256, load_size=256, display_winsize=256)
16 | parser.set_defaults(serial_batches=True)
17 | parser.set_defaults(no_flip=True)
18 | parser.set_defaults(phase='test')
19 | self.isTrain = False
20 | return parser
21 |
--------------------------------------------------------------------------------
/data/preprocess.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np  # skimage.util.pad was a thin wrapper around numpy.pad and is removed in newer scikit-image
3 | from skimage import io
4 | from skimage.transform import resize
5 |
6 |
7 | with open('train.txt', 'r') as fd:
8 | image_files = fd.readlines()
9 |
10 | total = len(image_files)
11 | cnt = 0
12 |
13 | # path/to/deepfashion directory
14 | root = '/path/to/deepfashion'
15 | # path/to/save directory
16 | save_root = 'path/to/save'
17 |
18 | for image_file in image_files:
19 | image_file = os.path.join(root, image_file).strip()
20 | image = io.imread(image_file)
21 | pad_width_1 = (1101-750) // 2
22 | pad_width_2 = (1101-750) // 2 + 1
23 | image_pad = np.pad(image, ((0,0),(pad_width_1, pad_width_2),(0,0)), constant_values=232)
24 | image_resize = resize(image_pad, (1024, 1024))
25 | image_resize = (image_resize * 255).astype('uint8')
26 | dst_file = os.path.dirname(image_file).replace(root, save_root)
27 | os.makedirs(dst_file, exist_ok=True)
28 | dst_file = os.path.join(dst_file, os.path.basename(image_file))
29 | # dst_file = dst_file.replace('.jpg', '.png')
30 | io.imsave(dst_file, image_resize)
31 | cnt += 1
32 | if cnt % 20 == 0:
33 | print('Processing: %d / %d' % (cnt, total))
34 |
--------------------------------------------------------------------------------
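
A quick sanity check of the padding arithmetic used above, assuming the standard DeepFashion high-resolution images of 750 (width) x 1101 (height) pixels: the width is padded by 175/176 columns so the frame becomes square before being resized to 1024 x 1024. A minimal sketch on synthetic data:

import numpy as np
from skimage.transform import resize

# stand-in for a 1101 (H) x 750 (W) x 3 DeepFashion image
image = np.zeros((1101, 750, 3), dtype=np.uint8)
pad_width_1 = (1101 - 750) // 2        # 175 columns on the left
pad_width_2 = (1101 - 750) // 2 + 1    # 176 columns on the right
image_pad = np.pad(image, ((0, 0), (pad_width_1, pad_width_2), (0, 0)), constant_values=232)
assert image_pad.shape[:2] == (1101, 1101)
image_resize = (resize(image_pad, (1024, 1024)) * 255).astype('uint8')
assert image_resize.shape == (1024, 1024, 3)
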
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) Microsoft Corporation.
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE
22 |
--------------------------------------------------------------------------------
/SUPPORT.md:
--------------------------------------------------------------------------------
1 | # TODO: The maintainer of this repo has not yet edited this file
2 |
3 | **REPO OWNER**: Do you want Customer Service & Support (CSS) support for this product/project?
4 |
5 | - **No CSS support:** Fill out this template with information about how to file issues and get help.
6 | - **Yes CSS support:** Fill out an intake form at [aka.ms/spot](https://aka.ms/spot). CSS will work with/help you to determine next steps. More details also available at [aka.ms/onboardsupport](https://aka.ms/onboardsupport).
7 | - **Not sure?** Fill out a SPOT intake as though the answer were "Yes". CSS will help you decide.
8 |
9 | *Then remove this first heading from this SUPPORT.MD file before publishing your repo.*
10 |
11 | # Support
12 |
13 | ## How to file issues and get help
14 |
15 | This project uses GitHub Issues to track bugs and feature requests. Please search the existing
16 | issues before filing new issues to avoid duplicates. For new issues, file your bug or
17 | feature request as a new Issue.
18 |
19 | For help and questions about using this project, please **REPO MAINTAINER: INSERT INSTRUCTIONS HERE
20 | FOR HOW TO ENGAGE REPO OWNERS OR COMMUNITY FOR HELP. COULD BE A STACK OVERFLOW TAG OR OTHER
21 | CHANNEL. WHERE WILL YOU HELP PEOPLE?**.
22 |
23 | ## Microsoft Support Policy
24 |
25 | Support for this **PROJECT or PRODUCT** is limited to the resources listed above.
26 |
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT License.
3 |
4 | import torch
5 | import importlib
6 |
7 |
8 | def find_model_using_name(model_name):
9 | # Given the option --model [modelname],
10 | # the file "models/modelname_model.py"
11 | # will be imported.
12 | model_filename = "models." + model_name + "_model"
13 | modellib = importlib.import_module(model_filename)
14 | # In the file, the class called ModelNameModel() will
15 | # be instantiated. It has to be a subclass of torch.nn.Module,
16 | # and the name lookup is case-insensitive.
17 | model = None
18 | target_model_name = model_name.replace('_', '') + 'model'
19 | for name, cls in modellib.__dict__.items():
20 | if name.lower() == target_model_name.lower() \
21 | and issubclass(cls, torch.nn.Module):
22 | model = cls
23 | if model is None:
24 | print("In %s.py, there should be a subclass of torch.nn.Module with class name that matches %s in lowercase." % (model_filename, target_model_name))
25 | exit(0)
26 | return model
27 |
28 |
29 | def get_option_setter(model_name):
30 | model_class = find_model_using_name(model_name)
31 | return model_class.modify_commandline_options
32 |
33 |
34 | def create_model(opt):
35 | model = find_model_using_name(opt.model)
36 | instance = model(opt)
37 | print("model [%s] was created" % (type(instance).__name__))
38 | return instance
39 |
--------------------------------------------------------------------------------
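
A minimal usage sketch of the reflection helpers above, assuming the repository root is on the Python path and the requirements are installed; the only model shipped here is pix2pix, so the name resolves to Pix2PixModel:

from models import find_model_using_name, get_option_setter

model_cls = find_model_using_name('pix2pix')   # imports models.pix2pix_model and returns the matching class
print(model_cls.__name__)                      # -> 'Pix2PixModel'
option_setter = get_option_setter('pix2pix')   # the class's modify_commandline_options hook, used by the option parser
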
/models/networks/ops.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT License.
3 |
4 |
5 | import torch
6 | import torch.nn as nn
7 | import torch.nn.functional as F
8 |
9 |
10 | def convert_1d_to_2d(index, base=64):
11 | x = index // base
12 | y = index % base
13 | return x,y
14 |
15 |
16 | def convert_2d_to_1d(x, y, base=64):
17 | return x*base+y
18 |
19 |
20 | def batch_meshgrid(shape, device):
21 | batch_size, _, height, width = shape
22 | x_range = torch.arange(0.0, width, device=device)
23 | y_range = torch.arange(0.0, height, device=device)
24 | x_coordinate, y_coordinate = torch.meshgrid(x_range, y_range)
25 | x_coordinate = x_coordinate.expand(batch_size, -1, -1).unsqueeze(1)
26 | y_coordinate = y_coordinate.expand(batch_size, -1, -1).unsqueeze(1)
27 | return x_coordinate, y_coordinate
28 |
29 |
30 | def inds_to_offset(inds):
31 | """
32 | inds: b x number x h x w
33 | """
34 | shape = inds.size()
35 | device = inds.device
36 | x_coordinate, y_coordinate = batch_meshgrid(shape, device)
37 | batch_size, _, height, width = shape
38 | x = inds // width
39 | y = inds % width
40 | return x - x_coordinate, y - y_coordinate
41 |
42 |
43 | def offset_to_inds(offset_x, offset_y):
44 | shape = offset_x.size()
45 | device = offset_x.device
46 | x_coordinate, y_coordinate = batch_meshgrid(shape, device)
47 | h, w = offset_x.size()[2:]
48 | x = torch.clamp(x_coordinate + offset_x, 0, h-1)
49 | y = torch.clamp(y_coordinate + offset_y, 0, w-1)
50 | return x * offset_x.size()[3] + y
51 |
--------------------------------------------------------------------------------
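
inds_to_offset and offset_to_inds are inverses of each other on valid index maps. A small round-trip sketch (the 1 x 1 x 4 x 4 input below is synthetic):

import torch
from models.networks.ops import inds_to_offset, offset_to_inds

# flattened row-major indices into a 4 x 4 grid, one per spatial location
inds = torch.arange(16, dtype=torch.float32).view(1, 1, 4, 4)
offset_x, offset_y = inds_to_offset(inds)        # all zeros for the identity mapping
inds_back = offset_to_inds(offset_x, offset_y)
assert torch.equal(inds_back, inds)
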
/data/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT License.
3 |
4 | import importlib
5 | import torch.utils.data
6 | from data.base_dataset import BaseDataset
7 |
8 |
9 | def find_dataset_using_name(dataset_name):
10 | dataset_filename = "data." + dataset_name + "_dataset"
11 | datasetlib = importlib.import_module(dataset_filename)
12 | dataset = None
13 | target_dataset_name = dataset_name.replace('_', '') + 'dataset'
14 | for name, cls in datasetlib.__dict__.items():
15 | if name.lower() == target_dataset_name.lower() \
16 | and issubclass(cls, BaseDataset):
17 | dataset = cls
18 | if dataset is None:
19 | raise ValueError("In %s.py, there should be a subclass of BaseDataset "
20 | "with class name that matches %s in lowercase." %
21 | (dataset_filename, target_dataset_name))
22 | return dataset
23 |
24 |
25 | def get_option_setter(dataset_name):
26 | dataset_class = find_dataset_using_name(dataset_name)
27 | return dataset_class.modify_commandline_options
28 |
29 |
30 | def create_dataloader(opt):
31 | dataset = find_dataset_using_name(opt.dataset_mode)
32 | instance = dataset()
33 | instance.initialize(opt)
34 | print("Dataset [%s] of size %d was created" % (type(instance).__name__, len(instance)))
35 | dataloader = torch.utils.data.DataLoader(
36 | instance,
37 | batch_size=opt.batchSize,
38 | shuffle=(opt.phase=='train'),
39 | num_workers=int(opt.nThreads),
40 | drop_last=(opt.phase=='train')
41 | )
42 | return dataloader
43 |
--------------------------------------------------------------------------------
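
A small sketch of the dataset reflection logic, assuming the shipped deepfashionHD dataset; building the actual DataLoader additionally needs opt.batchSize, opt.phase and opt.nThreads as populated by the options classes:

from data import find_dataset_using_name, get_option_setter

dataset_cls = find_dataset_using_name('deepfashionHD')   # resolves to the class defined in data/deepfashionHD_dataset.py
print(dataset_cls.__name__)
option_setter = get_option_setter('deepfashionHD')       # the dataset's modify_commandline_options hook
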
/test.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT License.
3 |
4 | import torch
5 | from torchvision.utils import save_image
6 | import os
7 | import imageio
8 | import numpy as np
9 | import data
10 | from util.util import mkdir
11 | from options.test_options import TestOptions
12 | from models.pix2pix_model import Pix2PixModel
13 |
14 |
15 | if __name__ == '__main__':
16 | opt = TestOptions().parse()
17 | dataloader = data.create_dataloader(opt)
18 | model = Pix2PixModel(opt)
19 | if len(opt.gpu_ids) > 1:
20 | model = torch.nn.DataParallel(model, device_ids=opt.gpu_ids)
21 | else:
22 | model.to(opt.gpu_ids[0])
23 | model.eval()
24 | save_root = os.path.join(opt.checkpoints_dir, opt.name, 'test')
25 | mkdir(save_root)
26 | for i, data_i in enumerate(dataloader):
27 | print('{} / {}'.format(i, len(dataloader)))
28 | if i * opt.batchSize >= opt.how_many:
29 | break
30 | imgs_num = data_i['label'].shape[0]
31 | out = model(data_i, mode='inference')
32 | if opt.save_per_img:
33 | try:
34 | for it in range(imgs_num):
35 | save_name = os.path.join(save_root, '%08d_%04d.png' % (i, it))
36 | save_image(out['fake_image'][it:it+1], save_name, padding=0, normalize=True)
37 | except OSError as err:
38 | print(err)
39 | else:
40 | label = data_i['label'][:,:3,:,:]
41 | imgs = torch.cat((label.cpu(), data_i['ref'].cpu(), out['fake_image'].data.cpu()), 0)
42 | try:
43 | save_name = os.path.join(save_root, '%08d.png' % i)
44 | save_image(imgs, save_name, nrow=imgs_num, padding=0, normalize=True)
45 | except OSError as err:
46 | print(err)
47 |
--------------------------------------------------------------------------------
/models/networks/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT License.
3 |
4 | import torch
5 | from models.networks.base_network import BaseNetwork
6 | from models.networks.loss import *
7 | from models.networks.discriminator import *
8 | from models.networks.generator import *
9 | from models.networks.ContextualLoss import *
10 | from models.networks.correspondence import *
11 | from models.networks.ops import *
12 | import util.util as util
13 |
14 |
15 | def find_network_using_name(target_network_name, filename, add=True):
16 | target_class_name = target_network_name + filename if add else target_network_name
17 | module_name = 'models.networks.' + filename
18 | network = util.find_class_in_module(target_class_name, module_name)
19 | assert issubclass(network, BaseNetwork), \
20 | "Class %s should be a subclass of BaseNetwork" % network
21 | return network
22 |
23 |
24 | def modify_commandline_options(parser, is_train):
25 | opt, _ = parser.parse_known_args()
26 | netG_cls = find_network_using_name(opt.netG, 'generator')
27 | parser = netG_cls.modify_commandline_options(parser, is_train)
28 | if is_train:
29 | netD_cls = find_network_using_name(opt.netD, 'discriminator')
30 | parser = netD_cls.modify_commandline_options(parser, is_train)
31 | return parser
32 |
33 |
34 | def create_network(cls, opt):
35 | net = cls(opt)
36 | net.print_network()
37 | if len(opt.gpu_ids) > 0:
38 | assert(torch.cuda.is_available())
39 | net.cuda()
40 | net.init_weights(opt.init_type, opt.init_variance)
41 | return net
42 |
43 |
44 | def define_G(opt):
45 | netG_cls = find_network_using_name(opt.netG, 'generator')
46 | return create_network(netG_cls, opt)
47 |
48 |
49 | def define_D(opt):
50 | netD_cls = find_network_using_name(opt.netD, 'discriminator')
51 | return create_network(netD_cls, opt)
52 |
53 | def define_Corr(opt):
54 | netCoor_cls = find_network_using_name(opt.netCorr, 'correspondence')
55 | return create_network(netCoor_cls, opt)
56 |
--------------------------------------------------------------------------------
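
A name-resolution sketch: find_network_using_name concatenates the network name with the file name and looks the class up case-insensitively, so 'spade' + 'generator' maps to SPADEGenerator from models/networks/generator.py:

from models.networks import find_network_using_name

netG_cls = find_network_using_name('spade', 'generator')   # case-insensitive lookup of 'spadegenerator'
print(netG_cls.__name__)                                    # -> 'SPADEGenerator'
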
/models/networks/generator.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT License.
3 |
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 | from torch.autograd import Function
8 |
9 | from models.networks.base_network import BaseNetwork
10 | from models.networks.architecture import SPADEResnetBlock
11 |
12 |
13 | class SPADEGenerator(BaseNetwork):
14 | @staticmethod
15 | def modify_commandline_options(parser, is_train):
16 | parser.set_defaults(norm_G='spectralspadesyncbatch3x3')
17 | return parser
18 |
19 | def __init__(self, opt):
20 | super().__init__()
21 | self.opt = opt
22 | nf = opt.ngf
23 | self.sw, self.sh = self.compute_latent_vector_size(opt)
24 | ic = 4*3+opt.label_nc
25 | self.fc = nn.Conv2d(ic, 8 * nf, 3, padding=1)
26 | self.head_0 = SPADEResnetBlock(8 * nf, 8 * nf, opt)
27 | self.G_middle_0 = SPADEResnetBlock(8 * nf, 8 * nf, opt)
28 | self.G_middle_1 = SPADEResnetBlock(8 * nf, 8 * nf, opt)
29 | self.up_0 = SPADEResnetBlock(8 * nf, 8 * nf, opt)
30 | self.up_1 = SPADEResnetBlock(8 * nf, 4 * nf, opt)
31 | self.up_2 = SPADEResnetBlock(4 * nf, 2 * nf, opt)
32 | self.up_3 = SPADEResnetBlock(2 * nf, 1 * nf, opt)
33 | final_nc = nf
34 | self.conv_img = nn.Conv2d(final_nc, 3, 3, padding=1)
35 | self.up = nn.Upsample(scale_factor=2)
36 |
37 | def compute_latent_vector_size(self, opt):
38 | num_up_layers = 5
39 | sw = opt.crop_size // (2**num_up_layers)
40 | sh = round(sw / opt.aspect_ratio)
41 | return sw, sh
42 |
43 | def forward(self, input, warp_out=None):
44 | seg = torch.cat((F.interpolate(warp_out[0], size=(512, 512)), F.interpolate(warp_out[1], size=(512, 512)), F.interpolate(warp_out[2], size=(512, 512)), warp_out[3], input), dim=1)
45 | x = F.interpolate(seg, size=(self.sh, self.sw))
46 | x = self.fc(x)
47 | x = self.head_0(x, seg)
48 | x = self.up(x)
49 | x = self.G_middle_0(x, seg)
50 | x = self.G_middle_1(x, seg)
51 | x = self.up(x)
52 | x = self.up_0(x, seg)
53 | x = self.up(x)
54 | x = self.up_1(x, seg)
55 | x = self.up(x)
56 | x = self.up_2(x, seg)
57 | x = self.up(x)
58 | x = self.up_3(x, seg)
59 | x = self.conv_img(F.leaky_relu(x, 2e-1))
60 | x = torch.tanh(x)
61 | return x
62 |
--------------------------------------------------------------------------------
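
The seed resolution of the generator follows from the five nn.Upsample(scale_factor=2) stages in forward: compute_latent_vector_size divides crop_size by 2**5 = 32. A worked example with hypothetical option values (crop_size=512, aspect_ratio=1.0):

# latent-size arithmetic used by compute_latent_vector_size (values are illustrative)
crop_size, aspect_ratio = 512, 1.0
num_up_layers = 5
sw = crop_size // (2 ** num_up_layers)    # 512 // 32 = 16
sh = round(sw / aspect_ratio)             # 16
print(sw, sh, sw * 2 ** num_up_layers)    # 16 16 512 -> a 16x16 seed is upsampled back to 512x512
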
/models/networks/base_network.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT License.
3 |
4 |
5 | import torch.nn as nn
6 | from torch.nn import init
7 |
8 |
9 | class BaseNetwork(nn.Module):
10 | def __init__(self):
11 | super(BaseNetwork, self).__init__()
12 |
13 | @staticmethod
14 | def modify_commandline_options(parser, is_train):
15 | return parser
16 |
17 | def print_network(self):
18 | if isinstance(self, list):
19 | self = self[0]
20 | num_params = 0
21 | for param in self.parameters():
22 | num_params += param.numel()
23 | print('Network [%s] was created. Total number of parameters: %.1f million. '
24 | 'To see the architecture, do print(network).'
25 | % (type(self).__name__, num_params / 1000000))
26 |
27 | def init_weights(self, init_type='normal', gain=0.02):
28 | def init_func(m):
29 | classname = m.__class__.__name__
30 | if classname.find('BatchNorm2d') != -1:
31 | if hasattr(m, 'weight') and m.weight is not None:
32 | init.normal_(m.weight.data, 1.0, gain)
33 | if hasattr(m, 'bias') and m.bias is not None:
34 | init.constant_(m.bias.data, 0.0)
35 | elif hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):
36 | if init_type == 'normal':
37 | init.normal_(m.weight.data, 0.0, gain)
38 | elif init_type == 'xavier':
39 | init.xavier_normal_(m.weight.data, gain=gain)
40 | elif init_type == 'xavier_uniform':
41 | init.xavier_uniform_(m.weight.data, gain=1.0)
42 | elif init_type == 'kaiming':
43 | init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
44 | elif init_type == 'orthogonal':
45 | init.orthogonal_(m.weight.data, gain=gain)
46 | elif init_type == 'none': # uses pytorch's default init method
47 | m.reset_parameters()
48 | else:
49 | raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
50 | if hasattr(m, 'bias') and m.bias is not None:
51 | init.constant_(m.bias.data, 0.0)
52 | self.apply(init_func)
53 | # propagate to children
54 | for m in self.children():
55 | if hasattr(m, 'init_weights'):
56 | m.init_weights(init_type, gain)
57 |
--------------------------------------------------------------------------------
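
A minimal sketch of the print_network / init_weights helpers on a toy subclass (ToyNetwork and its layers are made up for illustration):

import torch.nn as nn
from models.networks.base_network import BaseNetwork

class ToyNetwork(BaseNetwork):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 8, 3, padding=1)
        self.bn = nn.BatchNorm2d(8)

net = ToyNetwork()
net.print_network()                     # reports the total parameter count
net.init_weights('xavier', gain=0.02)   # Conv weights get xavier_normal_, BatchNorm2d weights N(1.0, 0.02), biases zero
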
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
98 | __pypackages__/
99 |
100 | # Celery stuff
101 | celerybeat-schedule
102 | celerybeat.pid
103 |
104 | # SageMath parsed files
105 | *.sage.py
106 |
107 | # Environments
108 | .env
109 | .venv
110 | env/
111 | venv/
112 | ENV/
113 | env.bak/
114 | venv.bak/
115 |
116 | # Spyder project settings
117 | .spyderproject
118 | .spyproject
119 |
120 | # Rope project settings
121 | .ropeproject
122 |
123 | # mkdocs documentation
124 | /site
125 |
126 | # mypy
127 | .mypy_cache/
128 | .dmypy.json
129 | dmypy.json
130 |
131 | # Pyre type checker
132 | .pyre/
133 |
134 | # pytype static type analyzer
135 | .pytype/
136 |
137 | # Cython debug symbols
138 | cython_debug/
139 |
--------------------------------------------------------------------------------
/SECURITY.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | ## Security
4 |
5 | Microsoft takes the security of our software products and services seriously, which includes all source code repositories managed through our GitHub organizations, which include [Microsoft](https://github.com/Microsoft), [Azure](https://github.com/Azure), [DotNet](https://github.com/dotnet), [AspNet](https://github.com/aspnet), [Xamarin](https://github.com/xamarin), and [our GitHub organizations](https://opensource.microsoft.com/).
6 |
7 | If you believe you have found a security vulnerability in any Microsoft-owned repository that meets [Microsoft's definition of a security vulnerability](https://docs.microsoft.com/en-us/previous-versions/tn-archive/cc751383(v=technet.10)), please report it to us as described below.
8 |
9 | ## Reporting Security Issues
10 |
11 | **Please do not report security vulnerabilities through public GitHub issues.**
12 |
13 | Instead, please report them to the Microsoft Security Response Center (MSRC) at [https://msrc.microsoft.com/create-report](https://msrc.microsoft.com/create-report).
14 |
15 | If you prefer to submit without logging in, send email to [secure@microsoft.com](mailto:secure@microsoft.com). If possible, encrypt your message with our PGP key; please download it from the [Microsoft Security Response Center PGP Key page](https://www.microsoft.com/en-us/msrc/pgp-key-msrc).
16 |
17 | You should receive a response within 24 hours. If for some reason you do not, please follow up via email to ensure we received your original message. Additional information can be found at [microsoft.com/msrc](https://www.microsoft.com/msrc).
18 |
19 | Please include the requested information listed below (as much as you can provide) to help us better understand the nature and scope of the possible issue:
20 |
21 | * Type of issue (e.g. buffer overflow, SQL injection, cross-site scripting, etc.)
22 | * Full paths of source file(s) related to the manifestation of the issue
23 | * The location of the affected source code (tag/branch/commit or direct URL)
24 | * Any special configuration required to reproduce the issue
25 | * Step-by-step instructions to reproduce the issue
26 | * Proof-of-concept or exploit code (if possible)
27 | * Impact of the issue, including how an attacker might exploit the issue
28 |
29 | This information will help us triage your report more quickly.
30 |
31 | If you are reporting for a bug bounty, more complete reports can contribute to a higher bounty award. Please visit our [Microsoft Bug Bounty Program](https://microsoft.com/msrc/bounty) page for more details about our active programs.
32 |
33 | ## Preferred Languages
34 |
35 | We prefer all communications to be in English.
36 |
37 | ## Policy
38 |
39 | Microsoft follows the principle of [Coordinated Vulnerability Disclosure](https://www.microsoft.com/en-us/msrc/cvd).
40 |
41 |
--------------------------------------------------------------------------------
/models/networks/ContextualLoss.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT License.
3 |
4 |
5 | import torch
6 | import torch.nn as nn
7 | import torch.nn.functional as F
8 |
9 | from util.util import feature_normalize, mse_loss
10 |
11 |
12 |
13 | class ContextualLoss_forward(nn.Module):
14 | '''
15 | input is Al, Bl, channel = 1, range ~ [0, 255]
16 | '''
17 |
18 | def __init__(self, opt):
19 | super(ContextualLoss_forward, self).__init__()
20 | self.opt = opt
21 | return None
22 |
23 | def forward(self, X_features, Y_features, h=0.1, feature_centering=True):
24 | '''
25 | X_features & Y_features are feature vectors or 2D feature maps
26 | h: bandwidth
27 | return the per-sample loss
28 | '''
29 | batch_size = X_features.shape[0]
30 | feature_depth = X_features.shape[1]
31 | feature_size = X_features.shape[2]
32 |
33 | # to normalized feature vectors
34 | if feature_centering:
35 | if self.opt.PONO:
36 | X_features = X_features - Y_features.mean(dim=1).unsqueeze(dim=1)
37 | Y_features = Y_features - Y_features.mean(dim=1).unsqueeze(dim=1)
38 | else:
39 | X_features = X_features - Y_features.view(batch_size, feature_depth, -1).mean(dim=-1).unsqueeze(dim=-1).unsqueeze(dim=-1)
40 | Y_features = Y_features - Y_features.view(batch_size, feature_depth, -1).mean(dim=-1).unsqueeze(dim=-1).unsqueeze(dim=-1)
41 |
42 | # X_features = X_features - Y_features.mean(dim=1).unsqueeze(dim=1)
43 | # Y_features = Y_features - Y_features.mean(dim=1).unsqueeze(dim=1)
44 |
45 | X_features = feature_normalize(X_features).view(batch_size, feature_depth, -1) # batch_size * feature_depth * feature_size * feature_size
46 | Y_features = feature_normalize(Y_features).view(batch_size, feature_depth, -1) # batch_size * feature_depth * feature_size * feature_size
47 |
48 | # X_features = F.unfold(
49 | # X_features, kernel_size=self.opt.match_kernel, stride=1, padding=int(self.opt.match_kernel // 2)) # batch_size * feature_depth_new * feature_size^2
50 | # Y_features = F.unfold(
51 | # Y_features, kernel_size=self.opt.match_kernel, stride=1, padding=int(self.opt.match_kernel // 2)) # batch_size * feature_depth_new * feature_size^2
52 |
53 | # cosine distance = 1 - similarity
54 | X_features_permute = X_features.permute(0, 2, 1) # batch_size * feature_size^2 * feature_depth
55 | d = 1 - torch.matmul(X_features_permute, Y_features) # batch_size * feature_size^2 * feature_size^2
56 |
57 | # normalized distance: dij_bar
58 | # d_norm = d
59 | d_norm = d / (torch.min(d, dim=-1, keepdim=True)[0] + 1e-3) # batch_size * feature_size^2 * feature_size^2
60 |
61 | # pairwise affinity
62 | w = torch.exp((1 - d_norm) / h)
63 | A_ij = w / torch.sum(w, dim=-1, keepdim=True)
64 |
65 | # contextual loss per sample
66 | CX = torch.mean(torch.max(A_ij, dim=-1)[0], dim=1)
67 | loss = -torch.log(CX)
68 |
69 | # contextual loss per batch
70 | # loss = torch.mean(loss)
71 | return loss
72 |
--------------------------------------------------------------------------------
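
A usage sketch with random feature maps; only the PONO flag of opt is read inside forward, and all shapes and values below are synthetic:

import torch
from argparse import Namespace
from models.networks.ContextualLoss import ContextualLoss_forward

opt = Namespace(PONO=False)
criterion = ContextualLoss_forward(opt)
X = torch.rand(2, 64, 8, 8)       # e.g. features of the generated image
Y = torch.rand(2, 64, 8, 8)       # features of the reference image
loss = criterion(X, Y, h=0.1)     # per-sample contextual loss, shape (2,)
print(loss.shape)
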
/util/iter_counter.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT License.
3 |
4 | import os
5 | import time
6 | import numpy as np
7 |
8 | # Helper class that keeps track of training iterations
9 | class IterationCounter():
10 | def __init__(self, opt, dataset_size):
11 | self.opt = opt
12 | self.dataset_size = dataset_size
13 | self.batch_size = opt.batchSize
14 | self.first_epoch = 1
15 | self.total_epochs = opt.niter + opt.niter_decay
16 | # iter number within each epoch
17 | self.epoch_iter = 0
18 | self.iter_record_path = os.path.join(self.opt.checkpoints_dir, self.opt.name, 'iter.txt')
19 | if opt.isTrain and opt.continue_train:
20 | try:
21 | self.first_epoch, self.epoch_iter = np.loadtxt(self.iter_record_path, delimiter=',', dtype=int)
22 | print('Resuming from epoch %d at iteration %d' % (self.first_epoch, self.epoch_iter))
23 | except (OSError, ValueError):
24 | print('Could not load iteration record at %s. Starting from beginning.' % self.iter_record_path)
25 | self.epoch_iter_num = dataset_size * self.batch_size
26 | self.total_steps_so_far = (self.first_epoch - 1) * self.epoch_iter_num + self.epoch_iter
27 | self.continue_train_flag = opt.continue_train
28 |
29 | # return the iterator of epochs for the training
30 | def training_epochs(self):
31 | return range(self.first_epoch, self.total_epochs + 1)
32 |
33 | def record_epoch_start(self, epoch):
34 | self.epoch_start_time = time.time()
35 | if not self.continue_train_flag:
36 | self.epoch_iter = 0
37 | else:
38 | self.continue_train_flag = False
39 | self.last_iter_time = time.time()
40 | self.current_epoch = epoch
41 |
42 | def record_one_iteration(self):
43 | current_time = time.time()
44 | # the last remaining batch is dropped (see data/__init__.py),
45 | # so we can assume batch size is always opt.batchSize
46 | self.time_per_iter = (current_time - self.last_iter_time) / self.opt.batchSize
47 | self.last_iter_time = current_time
48 | self.total_steps_so_far += self.opt.batchSize
49 | self.epoch_iter += self.opt.batchSize
50 |
51 | def record_epoch_end(self):
52 | current_time = time.time()
53 | self.time_per_epoch = current_time - self.epoch_start_time
54 | print('End of epoch %d / %d \t Time Taken: %d sec' %
55 | (self.current_epoch, self.total_epochs, self.time_per_epoch))
56 | if self.current_epoch % self.opt.save_epoch_freq == 0:
57 | np.savetxt(self.iter_record_path, (self.current_epoch + 1, 0),
58 | delimiter=',', fmt='%d')
59 | print('Saved current iteration count at %s.' % self.iter_record_path)
60 |
61 | def record_current_iter(self):
62 | np.savetxt(self.iter_record_path, (self.current_epoch, self.epoch_iter),
63 | delimiter=',', fmt='%d')
64 | print('Saved current iteration count at %s.' % self.iter_record_path)
65 |
66 | def needs_saving(self):
67 | return (self.total_steps_so_far % self.opt.save_latest_freq) < self.opt.batchSize
68 |
69 | def needs_printing(self):
70 | return (self.total_steps_so_far % self.opt.print_freq) < self.opt.batchSize
71 |
72 | def needs_displaying(self):
73 | return (self.total_steps_so_far % self.opt.display_freq) < self.opt.batchSize
74 |
--------------------------------------------------------------------------------
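
A self-contained sketch of the counter protocol that train.py follows (all option values are hypothetical; nothing is written to disk because save_epoch_freq is never reached in two epochs):

from argparse import Namespace
from util.iter_counter import IterationCounter

opt = Namespace(batchSize=4, niter=2, niter_decay=0, checkpoints_dir='./checkpoints', name='demo',
                isTrain=False, continue_train=False, save_epoch_freq=10,
                save_latest_freq=5000, print_freq=100, display_freq=2000)
counter = IterationCounter(opt, dataset_size=10)   # 10 batches per epoch
for epoch in counter.training_epochs():
    counter.record_epoch_start(epoch)
    for _ in range(10):
        counter.record_one_iteration()
    counter.record_epoch_end()
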
/models/networks/convgru.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT License.
3 |
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 |
8 |
9 | class FlowHead(nn.Module):
10 | def __init__(self, input_dim=32, hidden_dim=64):
11 | super(FlowHead, self).__init__()
12 | candidate_num = 16
13 | self.conv1 = nn.Conv2d(input_dim, hidden_dim, 3, padding=1)
14 | self.conv2 = nn.Conv2d(hidden_dim, 2*candidate_num, 3, padding=1)
15 | self.relu = nn.ReLU(inplace=True)
16 | def forward(self, x):
17 | x = self.conv1(x)
18 | x = self.relu(x)
19 | x = self.conv2(x)
20 | num = x.size()[1]
21 | delta_offset_x, delta_offset_y = torch.split(x, [num//2, num//2], dim=1)
22 | return delta_offset_x, delta_offset_y
23 |
24 |
25 | class SepConvGRU(nn.Module):
26 | def __init__(self, hidden_dim=32, input_dim=64):
27 | super(SepConvGRU, self).__init__()
28 | self.convz1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2))
29 | self.convr1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2))
30 | self.convq1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2))
31 | self.convz2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0))
32 | self.convr2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0))
33 | self.convq2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0))
34 |
35 | def forward(self, h, x):
36 | # horizontal
37 | hx = torch.cat([h, x], dim=1)
38 | z = torch.sigmoid(self.convz1(hx))
39 | r = torch.sigmoid(self.convr1(hx))
40 | q = torch.tanh(self.convq1(torch.cat([r*h, x], dim=1)))
41 | h = (1-z) * h + z * q
42 | # vertical
43 | hx = torch.cat([h, x], dim=1)
44 | z = torch.sigmoid(self.convz2(hx))
45 | r = torch.sigmoid(self.convr2(hx))
46 | q = torch.tanh(self.convq2(torch.cat([r*h, x], dim=1)))
47 | h = (1-z) * h + z * q
48 | return h
49 |
50 |
51 | class BasicMotionEncoder(nn.Module):
52 | def __init__(self):
53 | super(BasicMotionEncoder, self).__init__()
54 | candidate_num = 16
55 | self.convc1 = nn.Conv2d(candidate_num, 64, 1, padding=0)
56 | self.convc2 = nn.Conv2d(64, 64, 3, padding=1)
57 | self.convf1 = nn.Conv2d(2*candidate_num, 64, 7, padding=3)
58 | self.convf2 = nn.Conv2d(64, 64, 3, padding=1)
59 | self.conv = nn.Conv2d(64+64, 64-2*candidate_num, 3, padding=1)
60 |
61 | def forward(self, flow, corr):
62 | cor = F.relu(self.convc1(corr))
63 | cor = F.relu(self.convc2(cor))
64 | flo = F.relu(self.convf1(flow))
65 | flo = F.relu(self.convf2(flo))
66 | cor_flo = torch.cat([cor, flo], dim=1)
67 | out = F.relu(self.conv(cor_flo))
68 | return torch.cat([out, flow], dim=1)
69 |
70 |
71 | class BasicUpdateBlock(nn.Module):
72 | def __init__(self):
73 | super(BasicUpdateBlock, self).__init__()
74 | self.encoder = BasicMotionEncoder()
75 | self.gru = SepConvGRU()
76 | self.flow_head = FlowHead()
77 |
78 | def forward(self, net, corr, flow):
79 | motion_features = self.encoder(flow, corr)
80 | inp = motion_features
81 | net = self.gru(net, inp)
82 | delta_offset_x, delta_offset_y = self.flow_head(net)
83 | return net, delta_offset_x, delta_offset_y
84 |
--------------------------------------------------------------------------------
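
A shape-check sketch for the update block; the tensors are random and the 32/16 channel counts follow the hard-coded hidden_dim and candidate_num above:

import torch
from models.networks.convgru import BasicUpdateBlock

block = BasicUpdateBlock()
net  = torch.randn(1, 32, 16, 16)     # GRU hidden state (hidden_dim = 32)
corr = torch.randn(1, 16, 16, 16)     # correlation values, one per candidate
flow = torch.randn(1, 32, 16, 16)     # x/y offsets, 2 * candidate_num channels
net, delta_x, delta_y = block(net, corr, flow)
print(net.shape, delta_x.shape, delta_y.shape)   # (1, 32, 16, 16) (1, 16, 16, 16) (1, 16, 16, 16)
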
/train.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT License.
3 |
4 | import os
5 | import sys
6 | import torch
7 | from torchvision.utils import save_image
8 | from options.train_options import TrainOptions
9 | import data
10 | from util.iter_counter import IterationCounter
11 | from util.util import print_current_errors
12 | from util.util import mkdir
13 | from trainers.pix2pix_trainer import Pix2PixTrainer
14 |
15 |
16 | if __name__ == '__main__':
17 | # parse options
18 | opt = TrainOptions().parse()
19 | # print options to help debugging
20 | print(' '.join(sys.argv))
21 | dataloader = data.create_dataloader(opt)
22 | len_dataloader = len(dataloader)
23 | # create tool for counting iterations
24 | iter_counter = IterationCounter(opt, len(dataloader))
25 | # create trainer for our model
26 | trainer = Pix2PixTrainer(opt, resume_epoch=iter_counter.first_epoch)
27 | save_root = os.path.join('checkpoints', opt.name, 'train')
28 | mkdir(save_root)
29 |
30 | for epoch in iter_counter.training_epochs():
31 | opt.epoch = epoch
32 | iter_counter.record_epoch_start(epoch)
33 | for i, data_i in enumerate(dataloader, start=iter_counter.epoch_iter):
34 | iter_counter.record_one_iteration()
35 | # Training
36 | # train generator
37 | if i % opt.D_steps_per_G == 0:
38 | trainer.run_generator_one_step(data_i)
39 | # train discriminator
40 | trainer.run_discriminator_one_step(data_i)
41 | if iter_counter.needs_printing():
42 | losses = trainer.get_latest_losses()
43 | try:
44 | print_current_errors(opt, epoch, iter_counter.epoch_iter,
45 | iter_counter.epoch_iter_num, losses, iter_counter.time_per_iter)
46 | except OSError as err:
47 | print(err)
48 |
49 | if iter_counter.needs_displaying():
50 | imgs_num = data_i['label'].shape[0]
51 |
52 | if opt.dataset_mode == 'deepfashionHD':
53 | label = data_i['label'][:,:3,:,:]
54 |
55 | show_size = opt.display_winsize
56 |
57 | imgs = torch.cat((label.cpu(), data_i['ref'].cpu(), \
58 | trainer.get_latest_generated().data.cpu(), \
59 | data_i['image'].cpu()), 0)
60 |
61 | try:
62 | save_name = '%08d_%08d.png' % (epoch, iter_counter.total_steps_so_far)
63 | save_name = os.path.join(save_root, save_name)
64 | save_image(imgs, save_name, nrow=imgs_num, padding=0, normalize=True)
65 | except OSError as err:
66 | print(err)
67 |
68 | if iter_counter.needs_saving():
69 | print('saving the latest model (epoch %d, total_steps %d)' %
70 | (epoch, iter_counter.total_steps_so_far))
71 | try:
72 | trainer.save('latest')
73 | iter_counter.record_current_iter()
74 | except OSError as err:
75 | import pdb; pdb.set_trace()
76 | print(err)
77 |
78 | trainer.update_learning_rate(epoch)
79 | iter_counter.record_epoch_end()
80 |
81 | if epoch % opt.save_epoch_freq == 0 or epoch == iter_counter.total_epochs:
82 | print('saving the model at the end of epoch %d, iters %d' %
83 | (epoch, iter_counter.total_steps_so_far))
84 | try:
85 | trainer.save('latest')
86 | trainer.save(epoch)
87 | except OSError as err:
88 | print(err)
89 |
90 | print('Training was successfully finished.')
91 |
--------------------------------------------------------------------------------
/util/util.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT License.
3 |
4 | import os
5 | import re
6 | import argparse
7 | from argparse import Namespace
8 | import torch
9 | import numpy as np
10 | import importlib
11 | from PIL import Image
12 |
13 |
14 |
15 | def feature_normalize(feature_in, eps=1e-10):
16 | feature_in_norm = torch.norm(feature_in, 2, 1, keepdim=True) + eps
17 | feature_in_norm = torch.div(feature_in, feature_in_norm)
18 | return feature_in_norm
19 |
20 |
21 | def weighted_l1_loss(input, target, weights):
22 | out = torch.abs(input - target)
23 | out = out * weights.expand_as(out)
24 | loss = out.mean()
25 | return loss
26 |
27 |
28 | def mse_loss(input, target=0):
29 | return torch.mean((input - target)**2)
30 |
31 |
32 | def vgg_preprocess(tensor, vgg_normal_correct=False):
33 | if vgg_normal_correct:
34 | tensor = (tensor + 1) / 2
35 | # input is RGB tensor which ranges in [0,1]
36 | # output is BGR tensor which ranges in [0,255]
37 | tensor_bgr = torch.cat((tensor[:, 2:3, :, :], tensor[:, 1:2, :, :], tensor[:, 0:1, :, :]), dim=1)
38 | # tensor_bgr = tensor[:, [2, 1, 0], ...]
39 | tensor_bgr_ml = tensor_bgr - torch.Tensor([0.40760392, 0.45795686, 0.48501961]).type_as(tensor_bgr).view(1, 3, 1, 1)
40 | tensor_rst = tensor_bgr_ml * 255
41 | return tensor_rst
42 |
43 |
44 | def mkdirs(paths):
45 | if isinstance(paths, list) and not isinstance(paths, str):
46 | for path in paths:
47 | mkdir(path)
48 | else:
49 | mkdir(paths)
50 |
51 |
52 | def mkdir(path):
53 | if not os.path.exists(path):
54 | os.makedirs(path)
55 |
56 |
57 | def find_class_in_module(target_cls_name, module):
58 | target_cls_name = target_cls_name.replace('_', '').lower()
59 | clslib = importlib.import_module(module)
60 | cls = None
61 | for name, clsobj in clslib.__dict__.items():
62 | if name.lower() == target_cls_name:
63 | cls = clsobj
64 | if cls is None:
65 | print("In %s, there should be a class whose name matches %s in lowercase without underscore(_)" % (module, target_cls_name))
66 | exit(0)
67 | return cls
68 |
69 |
70 | def save_network(net, label, epoch, opt):
71 | save_filename = '%s_net_%s.pth' % (epoch, label)
72 | save_path = os.path.join(opt.checkpoints_dir, opt.name, save_filename)
73 | torch.save(net.cpu().state_dict(), save_path)
74 | if len(opt.gpu_ids) and torch.cuda.is_available():
75 | net.cuda()
76 |
77 |
78 | def load_network(net, label, epoch, opt):
79 | save_filename = '%s_net_%s.pth' % (epoch, label)
80 | save_dir = os.path.join(opt.checkpoints_dir, opt.name)
81 | save_path = os.path.join(save_dir, save_filename)
82 | if not os.path.exists(save_path):
83 | print('Could not find model ' + save_path + ', weights not loaded!')
84 | return net
85 | weights = torch.load(save_path)
86 | try:
87 | net.load_state_dict(weights)
88 | except KeyError:
89 | print('key error, weights not loaded!')
90 | except RuntimeError as err:
91 | print(err)
92 | net.load_state_dict(weights, strict=False)
93 | print('loaded with strict = False')
94 | print('Load from ' + save_path)
95 | return net
96 |
97 |
98 | def print_current_errors(opt, epoch, i, num, errors, t):
99 | message = '(epoch: %d, iters: %d, finish: %.2f%%, time: %.3f) ' % (epoch, i, (i/num)*100.0, t)
100 | for k, v in errors.items():
101 | v = v.mean().float()
102 | message += '%s: %.3f ' % (k, v)
103 | print(message)
104 | log_name = os.path.join(opt.checkpoints_dir, opt.name, 'loss_log.txt')
105 | with open(log_name, "a") as log_file:
106 | log_file.write('%s\n' % message)
107 |
--------------------------------------------------------------------------------
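
A small sketch of vgg_preprocess: with vgg_normal_correct=True it expects an RGB tensor in [-1, 1] and returns a channel-swapped BGR tensor, mean-shifted and scaled to the 0-255 range VGG was trained on (the input below is synthetic):

import torch
from util.util import vgg_preprocess

rgb = torch.rand(1, 3, 4, 4) * 2 - 1               # synthetic RGB image in [-1, 1]
bgr = vgg_preprocess(rgb, vgg_normal_correct=True)
print(bgr.shape)                                    # shape is unchanged: (1, 3, 4, 4)
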
/models/networks/loss.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT License.
3 |
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 |
8 |
9 | class GANLoss(nn.Module):
10 | def __init__(self, gan_mode, target_real_label=1.0, target_fake_label=0.0,
11 | tensor=torch.FloatTensor, opt=None):
12 | super(GANLoss, self).__init__()
13 | self.real_label = target_real_label
14 | self.fake_label = target_fake_label
15 | self.real_label_tensor = None
16 | self.fake_label_tensor = None
17 | self.zero_tensor = None
18 | self.Tensor = tensor
19 | self.gan_mode = gan_mode
20 | self.opt = opt
21 | if gan_mode == 'ls':
22 | pass
23 | elif gan_mode == 'original':
24 | pass
25 | elif gan_mode == 'w':
26 | pass
27 | elif gan_mode == 'hinge':
28 | pass
29 | else:
30 | raise ValueError('Unexpected gan_mode {}'.format(gan_mode))
31 |
32 | def get_target_tensor(self, input, target_is_real):
33 | if target_is_real:
34 | if self.real_label_tensor is None:
35 | self.real_label_tensor = self.Tensor(1).fill_(self.real_label)
36 | self.real_label_tensor.requires_grad_(False)
37 | return self.real_label_tensor.expand_as(input)
38 | else:
39 | if self.fake_label_tensor is None:
40 | self.fake_label_tensor = self.Tensor(1).fill_(self.fake_label)
41 | self.fake_label_tensor.requires_grad_(False)
42 | return self.fake_label_tensor.expand_as(input)
43 |
44 | def get_zero_tensor(self, input):
45 | if self.zero_tensor is None:
46 | self.zero_tensor = self.Tensor(1).fill_(0)
47 | self.zero_tensor.requires_grad_(False)
48 | return self.zero_tensor.expand_as(input).type_as(input)
49 |
50 | def loss(self, input, target_is_real, for_discriminator=True):
51 | if self.gan_mode == 'original': # cross entropy loss
52 | target_tensor = self.get_target_tensor(input, target_is_real)
53 | loss = F.binary_cross_entropy_with_logits(input, target_tensor)
54 | return loss
55 | elif self.gan_mode == 'ls':
56 | target_tensor = self.get_target_tensor(input, target_is_real)
57 | return F.mse_loss(input, target_tensor)
58 | elif self.gan_mode == 'hinge':
59 | if for_discriminator:
60 | if target_is_real:
61 | minval = torch.min(input - 1, self.get_zero_tensor(input))
62 | loss = -torch.mean(minval)
63 | else:
64 | minval = torch.min(-input - 1, self.get_zero_tensor(input))
65 | loss = -torch.mean(minval)
66 | else:
67 | assert target_is_real, "The generator's hinge loss must be aiming for real"
68 | loss = -torch.mean(input)
69 | return loss
70 | else:
71 | # wgan
72 | if target_is_real:
73 | return -input.mean()
74 | else:
75 | return input.mean()
76 |
77 | def __call__(self, input, target_is_real, for_discriminator=True):
78 | if isinstance(input, list):
79 | loss = 0
80 | for pred_i in input:
81 | if isinstance(pred_i, list):
82 | pred_i = pred_i[-1]
83 | loss_tensor = self.loss(pred_i, target_is_real, for_discriminator)
84 | bs = 1 if len(loss_tensor.size()) == 0 else loss_tensor.size(0)
85 | new_loss = torch.mean(loss_tensor.view(bs, -1), dim=1)
86 | loss += new_loss
87 | return loss / len(input)
88 | else:
89 | return self.loss(input, target_is_real, for_discriminator)
90 |
--------------------------------------------------------------------------------
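
A hinge-loss usage sketch with random discriminator logits (single tensor input; a list of multiscale predictions would be averaged by __call__):

import torch
from models.networks.loss import GANLoss

criterion = GANLoss('hinge', tensor=torch.FloatTensor)
pred = torch.randn(4, 1, 8, 8)                                # synthetic discriminator output
d_loss_fake = criterion(pred, False, for_discriminator=True)  # mean(relu(1 + pred))
g_loss = criterion(pred, True, for_discriminator=False)       # -mean(pred)
print(float(d_loss_fake), float(g_loss))
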
/options/train_options.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT License.
3 |
4 |
5 | from .base_options import BaseOptions
6 |
7 |
8 | class TrainOptions(BaseOptions):
9 | def initialize(self, parser):
10 | BaseOptions.initialize(self, parser)
11 | # for displays
12 | parser.add_argument('--display_freq', type=int, default=2000, help='frequency of showing training results on screen')
13 | parser.add_argument('--print_freq', type=int, default=100, help='frequency of showing training results on console')
14 | parser.add_argument('--save_latest_freq', type=int, default=5000, help='frequency of saving the latest results')
15 | parser.add_argument('--save_epoch_freq', type=int, default=10, help='frequency of saving checkpoints at the end of epochs')
16 | # for training
17 | parser.add_argument('--continue_train', action='store_true', help='continue training: load the latest model')
18 | parser.add_argument('--which_epoch', type=str, default='latest', help='which epoch to load? set to latest to use latest cached model')
19 | parser.add_argument('--niter', type=int, default=100, help='# of iter at starting learning rate. This is NOT the total #epochs. Total #epochs is niter + niter_decay')
20 | parser.add_argument('--niter_decay', type=int, default=0, help='# of iter to linearly decay learning rate to zero')
21 | parser.add_argument('--optimizer', type=str, default='adam')
22 | parser.add_argument('--beta1', type=float, default=0.5, help='momentum term of adam')
23 | parser.add_argument('--beta2', type=float, default=0.999, help='momentum term of adam')
24 | parser.add_argument('--lr', type=float, default=0.0002, help='initial learning rate for adam')
25 | parser.add_argument('--D_steps_per_G', type=int, default=1, help='number of discriminator iterations per generator iterations.')
26 | # for discriminators
27 | parser.add_argument('--ndf', type=int, default=64, help='# of discrim filters in first conv layer')
28 | parser.add_argument('--netD', type=str, default='multiscale', help='(n_layers|multiscale|image)')
29 | parser.add_argument('--no_TTUR', action='store_true', help='if specified, do *not* use the TTUR training scheme')
30 | parser.add_argument('--real_reference_probability', type=float, default=0.0, help='self-supervised training probability')
31 | parser.add_argument('--hard_reference_probability', type=float, default=0.0, help='hard reference training probability')
32 | # training loss weights
33 | parser.add_argument('--weight_warp_self', type=float, default=0.0, help='push warp self to ref')
34 | parser.add_argument('--weight_warp_cycle', type=float, default=0.0, help='push warp cycle to ref')
35 | parser.add_argument('--weight_novgg_featpair', type=float, default=10.0, help='in no vgg setting, use pair feat loss in domain adaptation')
36 | parser.add_argument('--gan_mode', type=str, default='hinge', help='(ls|original|hinge)')
37 | parser.add_argument('--weight_gan', type=float, default=10.0, help='weight of all loss in stage1')
38 | parser.add_argument('--no_ganFeat_loss', action='store_true', help='if specified, do *not* use discriminator feature matching loss')
39 | parser.add_argument('--weight_ganFeat', type=float, default=10.0, help='weight for feature matching loss')
40 | parser.add_argument('--which_perceptual', type=str, default='4_2', help='relu5_2 or relu4_2')
41 | parser.add_argument('--weight_perceptual', type=float, default=0.001)
42 | parser.add_argument('--weight_vgg', type=float, default=10.0, help='weight for vgg loss')
43 | parser.add_argument('--weight_contextual', type=float, default=1.0, help='ctx loss weight')
44 | parser.add_argument('--weight_fm_ratio', type=float, default=1.0, help='vgg fm loss weight comp with ctx loss')
45 | self.isTrain = True
46 | return parser
47 |
--------------------------------------------------------------------------------
/models/networks/normalization.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT License.
3 |
4 | import re
5 |
6 | import torch.nn as nn
7 | import torch.nn.functional as F
8 | import torch.nn.utils.spectral_norm as spectral_norm
9 |
10 |
11 | def get_nonspade_norm_layer(opt, norm_type='instance'):
12 | def get_out_channel(layer):
13 | if hasattr(layer, 'out_channels'):
14 | return getattr(layer, 'out_channels')
15 | return layer.weight.size(0)
16 | def add_norm_layer(layer):
17 | nonlocal norm_type
18 | if norm_type.startswith('spectral'):
19 | layer = spectral_norm(layer)
20 | subnorm_type = norm_type[len('spectral'):]
21 | else:
22 | subnorm_type = norm_type
23 | if subnorm_type == 'none' or len(subnorm_type) == 0:
24 | return layer
25 | if getattr(layer, 'bias', None) is not None:
26 | delattr(layer, 'bias')
27 | layer.register_parameter('bias', None)
28 | if subnorm_type == 'batch':
29 | norm_layer = nn.BatchNorm2d(get_out_channel(layer), affine=True)
30 | elif subnorm_type == 'sync_batch':
31 | norm_layer = nn.BatchNorm2d(get_out_channel(layer), affine=True)
32 | elif subnorm_type == 'instance':
33 | norm_layer = nn.InstanceNorm2d(get_out_channel(layer), affine=False)
34 | else:
35 | raise ValueError('normalization layer %s is not recognized' % subnorm_type)
36 | return nn.Sequential(layer, norm_layer)
37 | return add_norm_layer
38 |
39 |
40 | def PositionalNorm2d(x, epsilon=1e-8):
41 | # x: B*C*W*H normalize in C dim
42 | mean = x.mean(dim=1, keepdim=True)
43 | std = x.var(dim=1, keepdim=True).add(epsilon).sqrt()
44 | output = (x - mean) / std
45 | return output
46 |
47 |
48 | class SPADE(nn.Module):
49 | def __init__(self, config_text, norm_nc, label_nc, PONO=False):
50 | super().__init__()
51 | assert config_text.startswith('spade')
52 | parsed = re.search('spade(\D+)(\d)x\d', config_text)
53 | param_free_norm_type = str(parsed.group(1))
54 | ks = int(parsed.group(2))
55 | self.pad_type = 'nozero'
56 | if PONO:
57 | self.param_free_norm = PositionalNorm2d
58 | elif param_free_norm_type == 'instance':
59 | self.param_free_norm = nn.InstanceNorm2d(norm_nc, affine=False)
60 | elif param_free_norm_type == 'syncbatch':
61 | self.param_free_norm = nn.BatchNorm2d(norm_nc, affine=True)
62 | elif param_free_norm_type == 'batch':
63 | self.param_free_norm = nn.BatchNorm2d(norm_nc, affine=True)
64 | else:
65 | raise ValueError('%s is not a recognized param-free norm type in SPADE' % param_free_norm_type)
66 | nhidden = 128
67 | pw = ks // 2
68 | if self.pad_type != 'zero':
69 | self.mlp_shared = nn.Sequential(
70 | nn.ReflectionPad2d(pw),
71 | nn.Conv2d(label_nc, nhidden, kernel_size=ks, padding=0),
72 | nn.ReLU()
73 | )
74 | self.pad = nn.ReflectionPad2d(pw)
75 | self.mlp_gamma = nn.Conv2d(nhidden, norm_nc, kernel_size=ks, padding=0)
76 | self.mlp_beta = nn.Conv2d(nhidden, norm_nc, kernel_size=ks, padding=0)
77 |
78 | else:
79 | self.mlp_shared = nn.Sequential(
80 | nn.Conv2d(label_nc, nhidden, kernel_size=ks, padding=pw),
81 | nn.ReLU()
82 | )
83 | self.mlp_gamma = nn.Conv2d(nhidden, norm_nc, kernel_size=ks, padding=pw)
84 | self.mlp_beta = nn.Conv2d(nhidden, norm_nc, kernel_size=ks, padding=pw)
85 |
86 | def forward(self, x, segmap):
87 | normalized = self.param_free_norm(x)
88 | segmap = F.interpolate(segmap, size=x.size()[2:], mode='nearest')
89 | actv = self.mlp_shared(segmap)
90 | if self.pad_type != 'zero':
91 | gamma = self.mlp_gamma(self.pad(actv))
92 | beta = self.mlp_beta(self.pad(actv))
93 | else:
94 | gamma = self.mlp_gamma(actv)
95 | beta = self.mlp_beta(actv)
96 | out = normalized * (1 + gamma) + beta
97 | return out
98 |
--------------------------------------------------------------------------------
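
A SPADE usage sketch; the config string, channel counts and shapes are illustrative ('spadeinstance3x3' selects instance normalization with 3x3 modulation convolutions):

import torch
from models.networks.normalization import SPADE

spade = SPADE('spadeinstance3x3', norm_nc=64, label_nc=3)
x = torch.randn(2, 64, 32, 32)          # features to be normalized
segmap = torch.randn(2, 3, 256, 256)    # conditioning map, resized internally to match x
out = spade(x, segmap)
print(out.shape)                         # (2, 64, 32, 32)
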
/data/base_dataset.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT License.
3 |
4 | import torch.utils.data as data
5 | from PIL import Image
6 | import torchvision.transforms as transforms
7 | import numpy as np
8 | import random
9 |
10 |
11 | class BaseDataset(data.Dataset):
12 | def __init__(self):
13 | super(BaseDataset, self).__init__()
14 |
15 | @staticmethod
16 | def modify_commandline_options(parser, is_train):
17 | return parser
18 |
19 | def initialize(self, opt):
20 | pass
21 |
22 |
23 | def get_params(opt, size):
24 | w, h = size
25 | new_h = h
26 | new_w = w
27 | if opt.preprocess_mode == 'resize_and_crop':
28 | new_h = new_w = opt.load_size
29 | elif opt.preprocess_mode == 'scale_width_and_crop':
30 | new_w = opt.load_size
31 | new_h = opt.load_size * h // w
32 | elif opt.preprocess_mode == 'scale_shortside_and_crop':
33 | ss, ls = min(w, h), max(w, h) # shortside and longside
34 | width_is_shorter = w == ss
35 | ls = int(opt.load_size * ls / ss)
36 | new_w, new_h = (ss, ls) if width_is_shorter else (ls, ss)
37 |
38 | x = random.randint(0, np.maximum(0, new_w - opt.crop_size))
39 | y = random.randint(0, np.maximum(0, new_h - opt.crop_size))
40 |
41 | flip = random.random() > 0.5
42 | return {'crop_pos': (x, y), 'flip': flip}
43 |
44 |
45 | def get_transform(opt, params, method=Image.BICUBIC, normalize=True, toTensor=True):
46 | transform_list = []
47 | if opt.dataset_mode == 'flickr' and method == Image.NEAREST:
48 | transform_list.append(transforms.Lambda(lambda img: __add1(img)))
49 | if 'resize' in opt.preprocess_mode:
50 | osize = [opt.load_size, opt.load_size]
51 | transform_list.append(transforms.Resize(osize, interpolation=method))
52 | elif 'scale_width' in opt.preprocess_mode:
53 | transform_list.append(transforms.Lambda(lambda img: __scale_width(img, opt.load_size, method)))
54 | elif 'scale_shortside' in opt.preprocess_mode:
55 | transform_list.append(transforms.Lambda(lambda img: __scale_shortside(img, opt.load_size, method)))
56 |
57 | if 'crop' in opt.preprocess_mode:
58 | transform_list.append(transforms.Lambda(lambda img: __crop(img, params['crop_pos'], opt.crop_size)))
59 |
60 | if opt.preprocess_mode == 'none':
61 | base = 32
62 | transform_list.append(transforms.Lambda(lambda img: __make_power_2(img, base, method)))
63 |
64 | if opt.preprocess_mode == 'fixed':
65 | w = opt.crop_size
66 | h = round(opt.crop_size / opt.aspect_ratio)
67 | transform_list.append(transforms.Lambda(lambda img: __resize(img, w, h, method)))
68 |
69 | if opt.isTrain and not opt.no_flip:
70 | transform_list.append(transforms.Lambda(lambda img: __flip(img, params['flip'])))
71 |
72 | if opt.isTrain and 'rotate' in params.keys():
73 | transform_list.append(transforms.Lambda(lambda img: __rotate(img, params['rotate'], method)))
74 |
75 | if toTensor:
76 | transform_list += [transforms.ToTensor()]
77 |
78 | if normalize:
79 | transform_list += [transforms.Normalize((0.5, 0.5, 0.5),
80 | (0.5, 0.5, 0.5))]
81 | return transforms.Compose(transform_list)
82 |
83 |
84 | def __resize(img, w, h, method=Image.BICUBIC):
85 | return img.resize((w, h), method)
86 |
87 |
88 | def __make_power_2(img, base, method=Image.BICUBIC):
89 | ow, oh = img.size
90 | h = int(round(oh / base) * base)
91 | w = int(round(ow / base) * base)
92 | if (h == oh) and (w == ow):
93 | return img
94 | return img.resize((w, h), method)
95 |
96 |
97 | def __scale_width(img, target_width, method=Image.BICUBIC):
98 | ow, oh = img.size
99 | if (ow == target_width):
100 | return img
101 | w = target_width
102 | h = int(target_width * oh / ow)
103 | return img.resize((w, h), method)
104 |
105 |
106 | def __scale_shortside(img, target_width, method=Image.BICUBIC):
107 | ow, oh = img.size
108 | ss, ls = min(ow, oh), max(ow, oh) # shortside and longside
109 | width_is_shorter = ow == ss
110 | if (ss == target_width):
111 | return img
112 | ls = int(target_width * ls / ss)
113 | nw, nh = (ss, ls) if width_is_shorter else (ls, ss)
114 | return img.resize((nw, nh), method)
115 |
116 |
117 | def __crop(img, pos, size):
118 | ow, oh = img.size
119 | x1, y1 = pos
120 | tw = th = size
121 | return img.crop((x1, y1, x1 + tw, y1 + th))
122 |
123 |
124 | def __flip(img, flip):
125 | if flip:
126 | return img.transpose(Image.FLIP_LEFT_RIGHT)
127 | return img
128 |
129 |
130 | def __rotate(img, deg, method=Image.BICUBIC):
131 | return img.rotate(deg, resample=method)
132 |
133 |
134 | def __add1(img):
135 | return Image.fromarray(np.array(img) + 1)
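A short illustration of how `get_params` and `get_transform` compose (the option values mirror the DeepFashionHD defaults; `opt` is a hand-built stand-in and `example.jpg` is a placeholder path):

```python
from types import SimpleNamespace
from PIL import Image
from data.base_dataset import get_params, get_transform

opt = SimpleNamespace(preprocess_mode='resize_and_crop', load_size=550, crop_size=512,
                      aspect_ratio=1.0, dataset_mode='deepfashionHD',
                      isTrain=True, no_flip=False)
img = Image.open('example.jpg').convert('RGB')
params = get_params(opt, img.size)        # random crop position and flip flag
transform = get_transform(opt, params)    # resize -> crop -> flip -> ToTensor -> Normalize
tensor = transform(img)                   # float tensor in [-1, 1], shape (3, 512, 512)
```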
--------------------------------------------------------------------------------
/models/networks/discriminator.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT License.
3 |
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 |
8 | from models.networks.base_network import BaseNetwork
9 | from models.networks.normalization import get_nonspade_norm_layer
10 | import util.util as util
11 |
12 |
13 | class MultiscaleDiscriminator(BaseNetwork):
14 | @staticmethod
15 | def modify_commandline_options(parser, is_train):
16 | parser.add_argument('--netD_subarch', type=str, default='n_layer',
17 | help='architecture of each discriminator')
18 | parser.add_argument('--num_D', type=int, default=2,
19 | help='number of discriminators to be used in multiscale')
20 | opt, _ = parser.parse_known_args()
21 | # define properties of each discriminator of the multiscale discriminator
22 | subnetD = util.find_class_in_module(opt.netD_subarch + 'discriminator', \
23 | 'models.networks.discriminator')
24 | subnetD.modify_commandline_options(parser, is_train)
25 | return parser
26 |
27 | def __init__(self, opt):
28 | super().__init__()
29 | self.opt = opt
30 | for i in range(opt.num_D):
31 | subnetD = self.create_single_discriminator(opt)
32 | self.add_module('discriminator_%d' % i, subnetD)
33 |
34 | def create_single_discriminator(self, opt):
35 | subarch = opt.netD_subarch
36 | if subarch == 'n_layer':
37 | netD = NLayerDiscriminator(opt)
38 | else:
39 | raise ValueError('unrecognized discriminator subarchitecture %s' % subarch)
40 | return netD
41 |
42 | def downsample(self, input):
43 | return F.avg_pool2d(input, kernel_size=3, stride=2, padding=[1, 1], count_include_pad=False)
44 |
45 | def forward(self, input):
46 | result = []
47 | get_intermediate_features = not self.opt.no_ganFeat_loss
48 | for name, D in self.named_children():
49 | out = D(input)
50 | if not get_intermediate_features:
51 | out = [out]
52 | result.append(out)
53 | input = self.downsample(input)
54 | return result
55 |
56 |
57 | class NLayerDiscriminator(BaseNetwork):
58 | @staticmethod
59 | def modify_commandline_options(parser, is_train):
60 | parser.add_argument('--n_layers_D', type=int, default=4, help='# layers in each discriminator')
61 | return parser
62 |
63 | def __init__(self, opt):
64 | super().__init__()
65 | self.opt = opt
66 | kw = 4
67 | padw = int((kw - 1.0) / 2)
68 | nf = opt.ndf
69 | input_nc = self.compute_D_input_nc(opt)
70 | norm_layer = get_nonspade_norm_layer(opt, opt.norm_D)
71 | sequence = [[nn.Conv2d(input_nc, nf, kernel_size=kw, stride=2, padding=padw),
72 | nn.LeakyReLU(0.2, False)]]
73 | for n in range(1, opt.n_layers_D):
74 | nf_prev = nf
75 | nf = min(nf * 2, 512)
76 | stride = 1 if n == opt.n_layers_D - 1 else 2
77 | if n == opt.n_layers_D - 1:
78 | dec = []
79 | nc_dec = nf_prev
80 | for _ in range(opt.n_layers_D - 1):
81 | dec += [nn.Upsample(scale_factor=2),
82 | norm_layer(nn.Conv2d(nc_dec, int(nc_dec//2), kernel_size=3, stride=1, padding=1)),
83 | nn.LeakyReLU(0.2, False)]
84 | nc_dec = int(nc_dec // 2)
85 | dec += [nn.Conv2d(nc_dec, opt.semantic_nc, kernel_size=3, stride=1, padding=1)]
86 | self.dec = nn.Sequential(*dec)
87 | sequence += [[norm_layer(nn.Conv2d(nf_prev, nf, kernel_size=kw, stride=stride, padding=padw)),
88 | nn.LeakyReLU(0.2, False)]]
89 | sequence += [[nn.Conv2d(nf, 1, kernel_size=kw, stride=1, padding=padw)]]
90 | for n in range(len(sequence)):
91 | self.add_module('model' + str(n), nn.Sequential(*sequence[n]))
92 |
93 | def compute_D_input_nc(self, opt):
94 | input_nc = opt.label_nc + opt.output_nc
95 | if opt.contain_dontcare_label:
96 | input_nc += 1
97 | return input_nc
98 |
99 | def forward(self, input):
100 | results = [input]
101 | seg = None
102 | cam_logit = None
103 | for name, submodel in self.named_children():
104 | if 'model' not in name:
105 | continue
106 | x = results[-1]
107 | intermediate_output = submodel(x)
108 | results.append(intermediate_output)
109 | get_intermediate_features = not self.opt.no_ganFeat_loss
110 | if get_intermediate_features:
111 | retu = results[1:]
112 | else:
113 | retu = results[-1]
114 | return retu
115 |
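A hypothetical way to instantiate and run the multiscale discriminator with hand-built options (the values roughly mirror the project defaults; the real training code builds `opt` from the option parsers instead):

```python
import torch
from types import SimpleNamespace
from models.networks.discriminator import MultiscaleDiscriminator

opt = SimpleNamespace(num_D=2, netD_subarch='n_layer', n_layers_D=4, ndf=64,
                      norm_D='spectralinstance', no_ganFeat_loss=False,
                      label_nc=20, output_nc=3, contain_dontcare_label=False,
                      semantic_nc=20)
netD = MultiscaleDiscriminator(opt)
x = torch.randn(1, opt.label_nc + opt.output_nc, 256, 256)  # concatenated (label, image)
out = netD(x)  # one entry per scale; each entry is a list of intermediate feature maps
```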
--------------------------------------------------------------------------------
/trainers/pix2pix_trainer.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT License.
3 |
4 | import os
5 | import copy
6 | import sys
7 | import torch
8 | from models.pix2pix_model import Pix2PixModel
9 | try:
10 | from torch.cuda.amp import GradScaler
11 | except ImportError:
12 | # dummy GradScaler for PyTorch < 1.6
13 | class GradScaler:
14 | def __init__(self, enabled):
15 | pass
16 | def scale(self, loss):
17 | return loss
18 | def unscale_(self, optimizer):
19 | pass
20 | def step(self, optimizer):
21 | optimizer.step()
22 | def update(self):
23 | pass
24 |
25 |
26 | class Pix2PixTrainer():
27 | """
28 | Trainer creates the model and optimizers, and uses them to
29 | update the weights of the network while reporting losses
30 | and the latest visuals to visualize the progress in training.
31 | """
32 |
33 | def __init__(self, opt, resume_epoch=0):
34 | self.opt = opt
35 | self.pix2pix_model = Pix2PixModel(opt)
36 | if len(opt.gpu_ids) > 1:
37 | self.pix2pix_model = torch.nn.DataParallel(self.pix2pix_model, device_ids=opt.gpu_ids)
38 | self.pix2pix_model_on_one_gpu = self.pix2pix_model.module
39 | else:
40 | self.pix2pix_model.to(opt.gpu_ids[0])
41 | self.pix2pix_model_on_one_gpu = self.pix2pix_model
42 | self.generated = None
43 | if opt.isTrain:
44 | self.optimizer_G, self.optimizer_D = self.pix2pix_model_on_one_gpu.create_optimizers(opt)
45 | self.old_lr = opt.lr
46 | if opt.continue_train and opt.which_epoch == 'latest':
47 | try:
48 | load_path = os.path.join(opt.checkpoints_dir, opt.name, 'optimizer.pth')
49 | checkpoint = torch.load(load_path)
50 | self.optimizer_G.load_state_dict(checkpoint['G'])
51 | self.optimizer_D.load_state_dict(checkpoint['D'])
52 | except FileNotFoundError as err:
53 | print(err)
54 | print('Could not find optimizer state dict: ' + load_path + '. Optimizer not loaded!')
55 |
56 | self.last_data, self.last_netCorr, self.last_netG, self.last_optimizer_G = None, None, None, None
57 | self.g_losses = {}
58 | self.d_losses = {}
59 | self.scaler = GradScaler(enabled=self.opt.amp)
60 |
61 | def run_generator_one_step(self, data):
62 | self.optimizer_G.zero_grad()
63 | g_losses, out = self.pix2pix_model(data, mode='generator')
64 | g_loss = sum(g_losses.values()).mean()
65 | # g_loss.backward()
66 | self.scaler.scale(g_loss).backward()
67 | self.scaler.unscale_(self.optimizer_G)
68 | # self.optimizer_G.step()
69 | self.scaler.step(self.optimizer_G)
70 | self.scaler.update()
71 | self.g_losses = g_losses
72 | self.out = out
73 |
74 | def run_discriminator_one_step(self, data):
75 | self.optimizer_D.zero_grad()
76 | GforD = {}
77 | GforD['fake_image'] = self.out['fake_image']
78 | GforD['adaptive_feature_seg'] = self.out['adaptive_feature_seg']
79 | GforD['adaptive_feature_img'] = self.out['adaptive_feature_img']
80 | d_losses = self.pix2pix_model(data, mode='discriminator', GforD=GforD)
81 | d_loss = sum(d_losses.values()).mean()
82 | # d_loss.backward()
83 | self.scaler.scale(d_loss).backward()
84 | self.scaler.unscale_(self.optimizer_D)
85 | # self.optimizer_D.step()
86 | self.scaler.step(self.optimizer_D)
87 | self.scaler.update()
88 | self.d_losses = d_losses
89 |
90 | def get_latest_losses(self):
91 | return {**self.g_losses, **self.d_losses}
92 |
93 | def get_latest_generated(self):
94 | return self.out['fake_image']
95 |
99 | def save(self, epoch):
100 | self.pix2pix_model_on_one_gpu.save(epoch)
101 | if epoch == 'latest':
102 | torch.save({'G': self.optimizer_G.state_dict(), \
103 | 'D': self.optimizer_D.state_dict(), \
104 | 'lr': self.old_lr,}, \
105 | os.path.join(self.opt.checkpoints_dir, self.opt.name, 'optimizer.pth'))
106 |
107 | def update_learning_rate(self, epoch):
108 | if epoch > self.opt.niter:
109 | lrd = self.opt.lr / self.opt.niter_decay
110 | new_lr = self.old_lr - lrd
111 | else:
112 | new_lr = self.old_lr
113 | if new_lr != self.old_lr:
114 | new_lr_G = new_lr
115 | new_lr_D = new_lr
116 | else:
117 | new_lr_G = self.old_lr
118 | new_lr_D = self.old_lr
119 | for param_group in self.optimizer_D.param_groups:
120 | param_group['lr'] = new_lr_D
121 | for param_group in self.optimizer_G.param_groups:
122 | param_group['lr'] = new_lr_G
123 | print('update learning rate: %f -> %f' % (self.old_lr, new_lr))
124 | self.old_lr = new_lr
125 |
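The generator and discriminator steps above follow the standard `torch.cuda.amp` recipe: scale the loss, unscale before the optimizer step, then update the scaler. A self-contained toy sketch of that pattern (a dummy linear model stands in for `Pix2PixModel`; a CUDA device is required):

```python
import torch
import torch.nn as nn
from torch.cuda.amp import GradScaler, autocast

model = nn.Linear(8, 1).cuda()                       # stand-in for the real network
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
scaler = GradScaler(enabled=True)

x, target = torch.randn(4, 8).cuda(), torch.randn(4, 1).cuda()
optimizer.zero_grad()
with autocast():                                     # forward pass in mixed precision
    loss = nn.functional.mse_loss(model(x), target)
scaler.scale(loss).backward()                        # backward on the scaled loss
scaler.unscale_(optimizer)                           # unscale, e.g. to allow grad clipping
scaler.step(optimizer)                               # skips the step on inf/NaN gradients
scaler.update()
```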
--------------------------------------------------------------------------------
/data/pix2pix_dataset.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT License.
3 |
4 |
5 | import os
6 | import random
7 | from PIL import Image
8 |
9 | from data.base_dataset import BaseDataset, get_params, get_transform
10 |
11 |
12 | class Pix2pixDataset(BaseDataset):
13 | @staticmethod
14 | def modify_commandline_options(parser, is_train):
15 | parser.add_argument('--no_pairing_check', action='store_true', help='If specified, skip sanity check of correct label-image file pairing')
16 | return parser
17 |
18 | def initialize(self, opt):
19 | self.opt = opt
20 | label_paths, image_paths = self.get_paths(opt)
21 | label_paths = label_paths[:opt.max_dataset_size]
22 | image_paths = image_paths[:opt.max_dataset_size]
23 | if not opt.no_pairing_check:
24 | for path1, path2 in zip(label_paths, image_paths):
25 | assert self.paths_match(path1, path2), \
26 | "The label-image pair (%s, %s) do not look like the right pair because the filenames are quite different. Are you sure about the pairing? Please see data/pix2pix_dataset.py to see what is going on, and use --no_pairing_check to bypass this." % (path1, path2)
27 | self.label_paths = label_paths
28 | self.image_paths = image_paths
29 | size = len(self.label_paths)
30 | self.dataset_size = size
31 | self.real_reference_probability = 1 if opt.phase == 'test' else opt.real_reference_probability
32 | self.hard_reference_probability = 0 if opt.phase == 'test' else opt.hard_reference_probability
33 | self.ref_dict, self.train_test_folder = self.get_ref(opt)
34 |
35 | def get_paths(self, opt):
36 | label_paths = []
37 | image_paths = []
38 | assert False, "A subclass of Pix2pixDataset must override self.get_paths(self, opt)"
39 | return label_paths, image_paths
40 |
41 | def paths_match(self, path1, path2):
42 | filename1_without_ext = os.path.splitext(os.path.basename(path1))[0]
43 | filename2_without_ext = os.path.splitext(os.path.basename(path2))[0]
44 | return filename1_without_ext == filename2_without_ext
45 |
46 | def get_label_tensor(self, path):
47 | label = Image.open(path)
48 | params1 = get_params(self.opt, label.size)
49 | transform_label = get_transform(self.opt, params1, method=Image.NEAREST, normalize=False)
50 | label_tensor = transform_label(label) * 255.0
51 | label_tensor[label_tensor == 255] = self.opt.label_nc
52 | # 'unknown' is opt.label_nc
53 | return label_tensor, params1
54 |
55 | def __getitem__(self, index):
56 | # label Image
57 | label_path = self.label_paths[index]
58 | label_path = os.path.join(self.opt.dataroot, label_path)
59 | label_tensor, params1 = self.get_label_tensor(label_path)
60 | # input image (real images)
61 | image_path = self.image_paths[index]
62 | image_path = os.path.join(self.opt.dataroot, image_path)
63 | image = Image.open(image_path).convert('RGB')
64 | transform_image = get_transform(self.opt, params1)
65 | image_tensor = transform_image(image)
66 | ref_tensor = 0
67 | label_ref_tensor = 0
68 | random_p = random.random()
69 | if random_p < self.real_reference_probability or self.opt.phase == 'test':
70 | key = image_path.split('deepfashionHD/')[-1]
71 | val = self.ref_dict[key]
72 | if random_p < self.hard_reference_probability:
73 | #hard reference
74 | path_ref = val[1]
75 | else:
76 | #easy reference
77 | path_ref = val[0]
78 | if self.opt.dataset_mode == 'deepfashionHD':
79 | path_ref = os.path.join(self.opt.dataroot, path_ref)
80 | else:
81 | path_ref = os.path.dirname(image_path).replace(self.train_test_folder[1], self.train_test_folder[0]) + '/' + path_ref
82 | image_ref = Image.open(path_ref).convert('RGB')
83 | if self.opt.dataset_mode != 'deepfashionHD':
84 | path_ref_label = path_ref.replace('.jpg', '.png')
85 | path_ref_label = self.imgpath_to_labelpath(path_ref_label)
86 | else:
87 | path_ref_label = self.imgpath_to_labelpath(path_ref)
88 | label_ref_tensor, params = self.get_label_tensor(path_ref_label)
89 | transform_image = get_transform(self.opt, params)
90 | ref_tensor = transform_image(image_ref)
91 | self_ref_flag = 0.0
92 | else:
93 | pair = False
94 | if self.opt.dataset_mode == 'deepfashionHD' and self.opt.video_like:
95 | key = image_path.replace('\\', '/').split('deepfashionHD/')[-1]
96 | val = self.ref_dict[key]
97 | ref_name = val[0]
98 | key_name = key
99 | path_ref = os.path.join(self.opt.dataroot, ref_name)
100 | image_ref = Image.open(path_ref).convert('RGB')
101 | label_ref_path = self.imgpath_to_labelpath(path_ref)
102 | label_ref_tensor, params = self.get_label_tensor(label_ref_path)
103 | transform_image = get_transform(self.opt, params)
104 | ref_tensor = transform_image(image_ref)
105 | pair = True
106 | if not pair:
107 | label_ref_tensor, params = self.get_label_tensor(label_path)
108 | transform_image = get_transform(self.opt, params)
109 | ref_tensor = transform_image(image)
110 | self_ref_flag = 1.0
111 | input_dict = {'label': label_tensor,
112 | 'image': image_tensor,
113 | 'path': image_path,
114 | 'self_ref': self_ref_flag,
115 | 'ref': ref_tensor,
116 | 'label_ref': label_ref_tensor
117 | }
118 | return input_dict
119 |
120 | def __len__(self):
121 | return self.dataset_size
122 |
123 | def get_ref(self, opt):
124 | pass
125 |
126 | def imgpath_to_labelpath(self, path):
127 | return path
128 |
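Downstream, the dict returned by `__getitem__` is consumed through a standard DataLoader. A hypothetical sketch, assuming `dataset` is an initialized subclass (e.g. the DeepFashionHD dataset):

```python
from torch.utils.data import DataLoader

loader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=8)
batch = next(iter(loader))
# batch['label'], batch['image'], batch['ref'], batch['label_ref'] are batched tensors;
# batch['self_ref'] is 1.0 for samples whose reference is the image itself,
# and batch['path'] is the list of source image paths.
```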
--------------------------------------------------------------------------------
/data/deepfashionHD_dataset.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT License.
3 |
4 | import os
5 | import cv2
6 | import torch
7 | import numpy as np
8 | import math
9 | import random
10 | from PIL import Image
11 |
12 | from data.pix2pix_dataset import Pix2pixDataset
13 | from data.base_dataset import get_params, get_transform
14 |
15 |
16 | class DeepFashionHDDataset(Pix2pixDataset):
17 | @staticmethod
18 | def modify_commandline_options(parser, is_train):
19 | parser = Pix2pixDataset.modify_commandline_options(parser, is_train)
20 | parser.set_defaults(preprocess_mode='resize_and_crop')
21 | parser.set_defaults(no_pairing_check=True)
22 | parser.set_defaults(load_size=550)
23 | parser.set_defaults(crop_size=512)
24 | parser.set_defaults(label_nc=20)
25 | parser.set_defaults(contain_dontcare_label=False)
26 | parser.set_defaults(cache_filelist_read=False)
27 | parser.set_defaults(cache_filelist_write=False)
28 | return parser
29 |
30 | def get_paths(self, opt):
31 | root = opt.dataroot
32 | if opt.phase == 'train':
33 | fd = open(os.path.join('./data/train.txt'))
34 | lines = fd.readlines()
35 | fd.close()
36 | elif opt.phase == 'test':
37 | fd = open(os.path.join('./data/val.txt'))
38 | lines = fd.readlines()
39 | fd.close()
40 | image_paths = []
41 | label_paths = []
42 | for i in range(len(lines)):
43 | name = lines[i].strip()
44 | image_paths.append(name)
45 | label_path = name.replace('img', 'pose').replace('.jpg', '_{}.txt')
46 | label_paths.append(os.path.join(label_path))
47 | return label_paths, image_paths
48 |
49 | def get_ref_video_like(self, opt):
50 | pair_path = './data/deepfashion_self_pair.txt'
51 | with open(pair_path) as fd:
52 | self_pair = fd.readlines()
53 | self_pair = [it.strip() for it in self_pair]
54 | self_pair_dict = {}
55 | for it in self_pair:
56 | items = it.split(',')
57 | self_pair_dict[items[0]] = items[1:]
58 | ref_path = './data/deepfashion_ref_test.txt' if opt.phase == 'test' else './data/deepfashion_ref.txt'
59 | with open(ref_path) as fd:
60 | ref = fd.readlines()
61 | ref = [it.strip() for it in ref]
62 | ref_dict = {}
63 | for i in range(len(ref)):
64 | items = ref[i].strip().split(',')
65 | key = items[0]
66 | if key in self_pair_dict.keys():
67 | val = [it for it in self_pair_dict[items[0]]]
68 | else:
69 | val = [items[-1]]
70 | ref_dict[key.replace('\\',"/")] = [v.replace('\\',"/") for v in val]
71 | train_test_folder = ('', '')
72 | return ref_dict, train_test_folder
73 |
74 | def get_ref_vgg(self, opt):
75 | extra = ''
76 | if opt.phase == 'test':
77 | extra = '_test'
78 | with open('./data/deepfashion_ref{}.txt'.format(extra)) as fd:
79 | lines = fd.readlines()
80 | ref_dict = {}
81 | for i in range(len(lines)):
82 | items = lines[i].strip().split(',')
83 | key = items[0]
84 | if opt.phase == 'test':
85 | val = [it for it in items[1:]]
86 | else:
87 | val = [items[-1]]
88 | ref_dict[key.replace('\\',"/")] = [v.replace('\\',"/") for v in val]
89 | train_test_folder = ('', '')
90 | return ref_dict, train_test_folder
91 |
92 | def get_ref(self, opt):
93 | if opt.video_like:
94 | return self.get_ref_video_like(opt)
95 | else:
96 | return self.get_ref_vgg(opt)
97 |
98 | def get_label_tensor(self, path):
99 | candidate = np.loadtxt(path.format('candidate'))
100 | subset = np.loadtxt(path.format('subset'))
101 | stickwidth = 20
102 | limbSeq = [[2, 3], [2, 6], [3, 4], [4, 5], [6, 7], [7, 8], [2, 9], [9, 10], \
103 | [10, 11], [2, 12], [12, 13], [13, 14], [2, 1], [1, 15], [15, 17], \
104 | [1, 16], [16, 18], [3, 17], [6, 18]]
105 | colors = [[255, 0, 0], [255, 85, 0], [255, 170, 0], [255, 255, 0], [170, 255, 0], [85, 255, 0], [0, 255, 0], \
106 | [0, 255, 85], [0, 255, 170], [0, 255, 255], [0, 170, 255], [0, 85, 255], [0, 0, 255], [85, 0, 255], \
107 | [170, 0, 255], [255, 0, 255], [255, 0, 170], [255, 0, 85]]
108 | canvas = np.zeros((1024, 1024, 3), dtype=np.uint8)
109 | cycle_radius = 20
110 | for i in range(18):
111 | index = int(subset[i])
112 | if index == -1:
113 | continue
114 | x, y = candidate[index][0:2]
115 | cv2.circle(canvas, (int(x), int(y)), cycle_radius, colors[i], thickness=-1)
116 | joints = []
117 | for i in range(17):
118 | index = subset[np.array(limbSeq[i]) - 1]
119 | cur_canvas = canvas.copy()
120 | if -1 in index:
121 | joints.append(np.zeros_like(cur_canvas[:, :, 0]))
122 | continue
123 | Y = candidate[index.astype(int), 0]
124 | X = candidate[index.astype(int), 1]
125 | mX = np.mean(X)
126 | mY = np.mean(Y)
127 | length = ((X[0] - X[1]) ** 2 + (Y[0] - Y[1]) ** 2) ** 0.5
128 | angle = math.degrees(math.atan2(X[0] - X[1], Y[0] - Y[1]))
129 | polygon = cv2.ellipse2Poly((int(mY), int(mX)), (int(length / 2), stickwidth), int(angle), 0, 360, 1)
130 | cv2.fillConvexPoly(cur_canvas, polygon, colors[i])
131 | canvas = cv2.addWeighted(canvas, 0.4, cur_canvas, 0.6, 0)
132 | joint = np.zeros_like(cur_canvas[:, :, 0])
133 | cv2.fillConvexPoly(joint, polygon, 255)
134 | joint = cv2.addWeighted(joint, 0.4, joint, 0.6, 0)
135 | joints.append(joint)
136 | pose = Image.fromarray(cv2.cvtColor(canvas, cv2.COLOR_BGR2RGB)).resize((self.opt.load_size, self.opt.load_size), resample=Image.NEAREST)
137 | params = get_params(self.opt, pose.size)
138 | transform_label = get_transform(self.opt, params, method=Image.NEAREST, normalize=False)
139 | transform_img = get_transform(self.opt, params, method=Image.BILINEAR, normalize=False)
140 | tensors_dist = 0
141 | e = 1
142 | for i in range(len(joints)):
143 | im_dist = cv2.distanceTransform(255-joints[i], cv2.DIST_L1, 3)
144 | im_dist = np.clip((im_dist/3), 0, 255).astype(np.uint8)
145 | tensor_dist = transform_img(Image.fromarray(im_dist))
146 | tensors_dist = tensor_dist if e == 1 else torch.cat([tensors_dist, tensor_dist])
147 | e += 1
148 | tensor_pose = transform_label(pose)
149 | label_tensor = torch.cat((tensor_pose, tensors_dist), dim=0)
150 | return label_tensor, params
151 |
152 | def imgpath_to_labelpath(self, path):
153 | label_path = path.replace('/img/', '/pose/').replace('.jpg', '_{}.txt')
154 | return label_path
155 |
156 | def labelpath_to_imgpath(self, path):
157 | img_path = path.replace('/pose/', '/img/').replace('_{}.txt', '.jpg')
158 | return img_path
159 |
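The pose annotations live next to the images, with a `_{}.txt` placeholder that `get_label_tensor` later fills with `candidate` / `subset`. A hypothetical example of the path mapping (the file name below is made up):

```python
from data.deepfashionHD_dataset import DeepFashionHDDataset

ds = DeepFashionHDDataset.__new__(DeepFashionHDDataset)   # skip initialize() for illustration
label_path = ds.imgpath_to_labelpath('dataset/deepfashionHD/img/MEN/Denim/id_0001/01_1_front.jpg')
# -> 'dataset/deepfashionHD/pose/MEN/Denim/id_0001/01_1_front_{}.txt'
# get_label_tensor then loads ..._candidate.txt and ..._subset.txt via np.loadtxt
```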
--------------------------------------------------------------------------------
/models/networks/architecture.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT License.
3 |
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | import torch.nn.utils.spectral_norm as spectral_norm
7 | from models.networks.normalization import SPADE
8 | from util.util import vgg_preprocess
9 |
10 |
11 | class ResidualBlock(nn.Module):
12 | def __init__(self, in_channels, out_channels, kernel_size=3, padding=1, stride=1):
13 | super(ResidualBlock, self).__init__()
14 | self.relu = nn.PReLU()
15 | self.model = nn.Sequential(
16 | nn.ReflectionPad2d(padding),
17 | nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, padding=0, stride=stride),
18 | nn.InstanceNorm2d(out_channels),
19 | self.relu,
20 | nn.ReflectionPad2d(padding),
21 | nn.Conv2d(out_channels, out_channels, kernel_size=kernel_size, padding=0, stride=stride),
22 | nn.InstanceNorm2d(out_channels),
23 | )
24 |
25 | def forward(self, x):
26 | out = self.relu(x + self.model(x))
27 | return out
28 |
29 |
30 | class SPADEResnetBlock(nn.Module):
31 | def __init__(self, fin, fout, opt, use_se=False, dilation=1):
32 | super().__init__()
33 | # Attributes
34 | self.learned_shortcut = (fin != fout)
35 | fmiddle = min(fin, fout)
36 | self.opt = opt
37 | self.pad_type = 'nozero'
38 | self.use_se = use_se
39 | # create conv layers
40 | if self.pad_type != 'zero':
41 | self.pad = nn.ReflectionPad2d(dilation)
42 | self.conv_0 = nn.Conv2d(fin, fmiddle, kernel_size=3, padding=0, dilation=dilation)
43 | self.conv_1 = nn.Conv2d(fmiddle, fout, kernel_size=3, padding=0, dilation=dilation)
44 | else:
45 | self.conv_0 = nn.Conv2d(fin, fmiddle, kernel_size=3, padding=dilation, dilation=dilation)
46 | self.conv_1 = nn.Conv2d(fmiddle, fout, kernel_size=3, padding=dilation, dilation=dilation)
47 | if self.learned_shortcut:
48 | self.conv_s = nn.Conv2d(fin, fout, kernel_size=1, bias=False)
49 | # apply spectral norm if specified
50 | if 'spectral' in opt.norm_G:
51 | self.conv_0 = spectral_norm(self.conv_0)
52 | self.conv_1 = spectral_norm(self.conv_1)
53 | if self.learned_shortcut:
54 | self.conv_s = spectral_norm(self.conv_s)
55 | # define normalization layers
56 | spade_config_str = opt.norm_G.replace('spectral', '')
57 | if 'spade_ic' in opt:
58 | ic = opt.spade_ic
59 | else:
60 | ic = 4*3+opt.label_nc
61 | self.norm_0 = SPADE(spade_config_str, fin, ic, PONO=opt.PONO)
62 | self.norm_1 = SPADE(spade_config_str, fmiddle, ic, PONO=opt.PONO)
63 | if self.learned_shortcut:
64 | self.norm_s = SPADE(spade_config_str, fin, ic, PONO=opt.PONO)
65 |
66 | def forward(self, x, seg1):
67 | x_s = self.shortcut(x, seg1)
68 | if self.pad_type != 'zero':
69 | dx = self.conv_0(self.pad(self.actvn(self.norm_0(x, seg1))))
70 | dx = self.conv_1(self.pad(self.actvn(self.norm_1(dx, seg1))))
71 | else:
72 | dx = self.conv_0(self.actvn(self.norm_0(x, seg1)))
73 | dx = self.conv_1(self.actvn(self.norm_1(dx, seg1)))
74 | out = x_s + dx
75 | return out
76 |
77 | def shortcut(self, x, seg1):
78 | if self.learned_shortcut:
79 | x_s = self.conv_s(self.norm_s(x, seg1))
80 | else:
81 | x_s = x
82 | return x_s
83 |
84 | def actvn(self, x):
85 | return F.leaky_relu(x, 2e-1)
86 |
87 |
88 | class VGG19_feature_color_torchversion(nn.Module):
89 | """
90 | NOTE: there is no need to pre-process the input
91 | input tensor should range in [0,1]
92 | """
93 | def __init__(self, pool='max', vgg_normal_correct=False, ic=3):
94 | super(VGG19_feature_color_torchversion, self).__init__()
95 | self.vgg_normal_correct = vgg_normal_correct
96 |
97 | self.conv1_1 = nn.Conv2d(ic, 64, kernel_size=3, padding=1)
98 | self.conv1_2 = nn.Conv2d(64, 64, kernel_size=3, padding=1)
99 | self.conv2_1 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
100 | self.conv2_2 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
101 | self.conv3_1 = nn.Conv2d(128, 256, kernel_size=3, padding=1)
102 | self.conv3_2 = nn.Conv2d(256, 256, kernel_size=3, padding=1)
103 | self.conv3_3 = nn.Conv2d(256, 256, kernel_size=3, padding=1)
104 | self.conv3_4 = nn.Conv2d(256, 256, kernel_size=3, padding=1)
105 | self.conv4_1 = nn.Conv2d(256, 512, kernel_size=3, padding=1)
106 | self.conv4_2 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
107 | self.conv4_3 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
108 | self.conv4_4 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
109 | self.conv5_1 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
110 | self.conv5_2 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
111 | self.conv5_3 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
112 | self.conv5_4 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
113 | if pool == 'max':
114 | self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
115 | self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
116 | self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2)
117 | self.pool4 = nn.MaxPool2d(kernel_size=2, stride=2)
118 | self.pool5 = nn.MaxPool2d(kernel_size=2, stride=2)
119 | elif pool == 'avg':
120 | self.pool1 = nn.AvgPool2d(kernel_size=2, stride=2)
121 | self.pool2 = nn.AvgPool2d(kernel_size=2, stride=2)
122 | self.pool3 = nn.AvgPool2d(kernel_size=2, stride=2)
123 | self.pool4 = nn.AvgPool2d(kernel_size=2, stride=2)
124 | self.pool5 = nn.AvgPool2d(kernel_size=2, stride=2)
125 |
126 | def forward(self, x, out_keys, preprocess=True):
127 | '''
128 | NOTE: input tensor should range in [0,1]
129 | '''
130 | out = {}
131 | if preprocess:
132 | x = vgg_preprocess(x, vgg_normal_correct=self.vgg_normal_correct)
133 | out['r11'] = F.relu(self.conv1_1(x))
134 | out['r12'] = F.relu(self.conv1_2(out['r11']))
135 | out['p1'] = self.pool1(out['r12'])
136 | out['r21'] = F.relu(self.conv2_1(out['p1']))
137 | out['r22'] = F.relu(self.conv2_2(out['r21']))
138 | out['p2'] = self.pool2(out['r22'])
139 | out['r31'] = F.relu(self.conv3_1(out['p2']))
140 | out['r32'] = F.relu(self.conv3_2(out['r31']))
141 | out['r33'] = F.relu(self.conv3_3(out['r32']))
142 | out['r34'] = F.relu(self.conv3_4(out['r33']))
143 | out['p3'] = self.pool3(out['r34'])
144 | out['r41'] = F.relu(self.conv4_1(out['p3']))
145 | out['r42'] = F.relu(self.conv4_2(out['r41']))
146 | out['r43'] = F.relu(self.conv4_3(out['r42']))
147 | out['r44'] = F.relu(self.conv4_4(out['r43']))
148 | out['p4'] = self.pool4(out['r44'])
149 | out['r51'] = F.relu(self.conv5_1(out['p4']))
150 | out['r52'] = F.relu(self.conv5_2(out['r51']))
151 | out['r53'] = F.relu(self.conv5_3(out['r52']))
152 | out['r54'] = F.relu(self.conv5_4(out['r53']))
153 | out['p5'] = self.pool5(out['r54'])
154 | return [out[key] for key in out_keys]
155 |
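A hedged example of pulling multi-level features from the VGG network above; the checkpoint path is an assumption (the pretrained weights come from the link given in the README):

```python
import torch
from models.networks.architecture import VGG19_feature_color_torchversion

vgg = VGG19_feature_color_torchversion(pool='max', vgg_normal_correct=True)
# vgg.load_state_dict(torch.load('vgg/vgg19_conv.pth'))   # file name is an assumption
vgg.eval()
x = torch.rand(1, 3, 256, 256)                            # input must lie in [0, 1]
with torch.no_grad():
    r12, r22, r32, r42, r52 = vgg(x, ['r12', 'r22', 'r32', 'r42', 'r52'], preprocess=True)
```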
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # CoCosNet v2: Full-Resolution Correspondence Learning for Image Translation (CVPR 2021, oral presentation)
2 | 
3 |
4 | **CoCosNet v2: Full-Resolution Correspondence Learning for Image Translation**
5 | **CVPR 2021, oral presentation**
6 | [Xingran Zhou](http://xingranzh.github.io/), [Bo Zhang](https://bo-zhang.me/), [Ting Zhang](https://www.microsoft.com/en-us/research/people/tinzhan/), [Pan Zhang](https://panzhang0212.github.io/), [Jianmin Bao](https://jianminbao.github.io/), [Dong Chen](https://www.microsoft.com/en-us/research/people/doch/), [Zhongfei Zhang](https://www.cs.binghamton.edu/~zhongfei/), [Fang Wen](https://www.microsoft.com/en-us/research/people/fangwen/)
7 | ### [Paper](https://arxiv.org/pdf/2012.02047.pdf) | [Slides](https://github.com/xingranzh/xingranzh.github.io/blob/master/slides/cocosnet_v2_slides.pdf)
8 | ## Abstract
9 | > We present the full-resolution correspondence learning for cross-domain images, which aids image translation. We adopt a hierarchical strategy that uses the correspondence from coarse level to guide the fine levels. At each hierarchy, the correspondence can be efficiently computed via PatchMatch that iteratively leverages the matchings from the neighborhood. Within each PatchMatch iteration, the ConvGRU module is employed to refine the current correspondence considering not only the matchings of larger context but also the historic estimates. The proposed CoCosNet v2, a GRU-assisted PatchMatch approach, is fully differentiable and highly efficient. When jointly trained with image translation, full-resolution semantic correspondence can be established in an unsupervised manner, which in turn facilitates the exemplar-based image translation. Experiments on diverse translation tasks show that CoCosNet v2 performs considerably better than state-of-the-art literature on producing high-resolution images.
10 |
11 | ## :sparkles: News
12 | 2022.12 We propose [Paint by Example](https://github.com/Fantasy-Studio/Paint-by-Example), which enables in-the-wild image editing according to an exemplar, based on **Stable Diffusion**. You can try our [online demo](https://huggingface.co/spaces/Fantasy-Studio/Paint-by-Example).
13 |
14 | 2022.8 We recently proposed [PITI](https://github.com/PITI-Synthesis/PITI), a state-of-the-art image-to-image translation method based on a *pretrained diffusion model*.
15 |
16 | ## Installation
17 | First, install the dependencies for the experiments:
18 | ```bash
19 | pip install -r requirements.txt
20 | ```
21 | We recommend installing `PyTorch 1.6.0` or later, since we use [automatic mixed precision](https://pytorch.org/docs/stable/amp.html) to accelerate training. (We used `PyTorch 1.7.0` in our experiments.)
22 | ## Prepare the dataset
23 | First, download the DeepFashion dataset (high-resolution version) from [this link](https://drive.google.com/file/d/1bByKH1ciLXY70Bp8le_AVnjk-Hd4pe_i/view?usp=sharing). Note that the file name is `img_highres.zip`. Unzip the file and rename the folder to `img`.
24 | If a password is required, please request access to the dataset via [this link](http://mmlab.ie.cuhk.edu.hk/projects/DeepFashion.html).
25 | We use [OpenPose](https://github.com/Hzzone/pytorch-openpose) to estimate the poses for DeepFashionHD. We provide the keypoint detection results used in our experiments at [this link](https://drive.google.com/file/d/1wxrqyb67Xu_IPyZzftLgBPHDTKGQP7Pk/view?usp=sharing). Download and unzip the results file.
26 | Since the original resolution of DeepFashionHD is 750x1101, we use a Python script to resize the images to 512x512. You can find the script in [`data/preprocess.py`](https://github.com/microsoft/CoCosNet-v2/blob/main/data/preprocess.py). Note that you need to download our train-val split lists `train.txt` and `val.txt` from [this link](https://drive.google.com/drive/folders/15NBujOTLnO_cRoAufWPqtOWKIinCKi0z?usp=sharing) for this step.
27 | Download the train-val lists from [this link](https://drive.google.com/drive/folders/15NBujOTLnO_cRoAufWPqtOWKIinCKi0z?usp=sharing), and the retrieval pair lists from [this link](https://drive.google.com/drive/folders/1dJU8iq8kFiwq33nWtvj5Ql5rUh9fiXUi?usp=sharing). `train.txt` and `val.txt` are our train-val lists, while `deepfashion_ref.txt`, `deepfashion_ref_test.txt` and `deepfashion_self_pair.txt` are the pairing lists used in our experiments. Download them all and move them into the folder `data/`.
28 | Finally, create the root folder `deepfashionHD` and move the folders `img` and `pose` below it (see the shell snippet after the tree below). The directory structure should now look like:
29 | ```
30 | deepfashionHD
31 | │
32 | └─── img
33 | │ │
34 | │ └─── MEN
35 | │ │ │ ...
36 | │ │
37 | │ └─── WOMEN
38 | │ │ ...
39 | │
40 | └─── pose
41 | │ │
42 | │ └─── MEN
43 | │ │ │ ...
44 | │ │
45 | │ └─── WOMEN
46 | │ │ ...
47 |
48 | ```
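For example, assuming the prepared `img` and `pose` folders sit in your current directory, this layout can be created with:
```bash
mkdir -p deepfashionHD
mv img pose deepfashionHD/
```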
49 | ## Inference Using Pretrained Model
50 | Download the pretrained model from [this link](https://drive.google.com/file/d/1ehkrKlf5s1gfpDNXO6AC9SIZMtqs5L3N/view?usp=sharing).
51 | Move the model files below the folder `checkpoints/deepfashionHD`, then run the following command.
52 | ```bash
53 | python test.py --name deepfashionHD --dataset_mode deepfashionHD --dataroot dataset/deepfashionHD --PONO --PONO_C --no_flip --batchSize 8 --gpu_ids 0 --netCorr NoVGGHPM --nThreads 16 --nef 32 --amp --display_winsize 512 --iteration_count 5 --load_size 512 --crop_size 512
54 | ```
55 | The inference results are saved in the folder `checkpoints/deepfashionHD/test`.
56 | ## Training from scratch
57 | Make sure you have prepared the DeepFashionHD dataset as described in the instructions above.
58 | Download the **pretrained VGG model** from [this link](https://drive.google.com/file/d/1D-z73DOt63BrPTgIxffN6Q4_L9qma9y8/view?usp=sharing) and move it to the `vgg/` folder. We use this model to compute the training loss.
59 |
60 | Run the following command for training from scratch.
61 | ```bash
62 | python train.py --name deepfashionHD --dataset_mode deepfashionHD --dataroot dataset/deepfashionHD --niter 100 --niter_decay 0 --real_reference_probability 0.0 --hard_reference_probability 0.0 --which_perceptual 4_2 --weight_perceptual 0.001 --PONO --PONO_C --vgg_normal_correct --weight_fm_ratio 1.0 --no_flip --video_like --batchSize 16 --gpu_ids 0,1,2,3,4,5,6,7 --netCorr NoVGGHPM --match_kernel 1 --featEnc_kernel 3 --display_freq 500 --print_freq 50 --save_latest_freq 2500 --save_epoch_freq 5 --nThreads 16 --weight_warp_self 500.0 --lr 0.0001 --nef 32 --amp --weight_warp_cycle 1.0 --display_winsize 512 --iteration_count 5 --temperature 0.01 --continue_train --load_size 550 --crop_size 512 --which_epoch 15
63 | ```
64 | Note that the `--dataroot` parameter is the root of your DeepFashionHD dataset, e.g. `dataset/deepfashionHD`.
65 | We train the network on 8 Tesla V100 GPUs (32 GB each). With fewer GPUs, you can set `--batchSize` to 16, 8 or 4 and adjust `--gpu_ids` accordingly.
66 | ## Citation
67 | If you use this code for your research, please cite our papers.
68 | ```
69 | @InProceedings{Zhou_2021_CVPR,
70 | author={Zhou, Xingran and Zhang, Bo and Zhang, Ting and Zhang, Pan and Bao, Jianmin and Chen, Dong and Zhang, Zhongfei and Wen, Fang},
71 | title={CoCosNet v2: Full-Resolution Correspondence Learning for Image Translation},
72 | booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
73 | year={2021},
74 | pages={11465-11475}
75 | }
76 | ```
77 |
78 | Also, welcome to refer to our [CoCosNet v1](https://github.com/microsoft/CoCosNet):
79 | ```
80 | @InProceedings{Zhang_2020_CVPR,
81 | author={Zhang, Pan and Zhang, Bo and Chen, Dong and Yuan, Lu and Wen, Fang},
82 | title={Cross-Domain Correspondence Learning for Exemplar-Based Image Translation},
83 | booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
84 | year={2020},
85 | pages={5143-5153}
86 | }
87 | ```
88 |
89 | ## Acknowledgments
90 | *This code borrows heavily from [CoCosNet](https://github.com/microsoft/CoCosNet) and [DeepPruner](https://github.com/uber-research/DeepPruner).
91 | We also thank [SPADE](https://github.com/NVlabs/SPADE) and [RAFT](https://github.com/princeton-vl/RAFT).*
92 | ## License
93 | The codes and the pretrained model in this repository are under the MIT license as specified by the LICENSE file.
94 | This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/). For more information see the [Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) or contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with any additional questions or comments.
95 |
--------------------------------------------------------------------------------
/models/networks/patch_match.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT License.
3 |
4 |
5 | import torch
6 | import torch.nn as nn
7 | import torch.nn.functional as F
8 |
9 | import random
10 |
11 | from models.networks.convgru import BasicUpdateBlock
12 | from models.networks.ops import *
13 |
14 |
15 | """patch match"""
16 | class Evaluate(nn.Module):
17 | def __init__(self, temperature):
18 | super().__init__()
19 | self.filter_size = 3
20 | self.temperature = temperature
21 |
22 | def forward(self, left_features, right_features, offset_x, offset_y):
23 | device = left_features.get_device()
24 | batch_size, num, height, width = offset_x.size()
25 | channel = left_features.size()[1]
26 | matching_inds = offset_to_inds(offset_x, offset_y)
27 | matching_inds = matching_inds.view(batch_size, num, height * width).permute(0, 2, 1).long()
28 | base_batch = torch.arange(batch_size).to(device).long() * (height * width)
29 | base_batch = base_batch.view(-1, 1, 1)
30 | matching_inds_add_base = matching_inds + base_batch
31 | right_features_view = right_features
32 | match_cost = []
33 | # using A[:, idx]
34 | for i in range(matching_inds_add_base.size()[-1]):
35 | idx = matching_inds_add_base[:, :, i]
36 | idx = idx.contiguous().view(-1)
37 | right_features_select = right_features_view[:, idx]
38 | right_features_select = right_features_select.view(channel, batch_size, -1).transpose(0, 1)
39 | match_cost_i = torch.sum(left_features * right_features_select, dim=1, keepdim=True) / self.temperature
40 | match_cost.append(match_cost_i)
41 | match_cost = torch.cat(match_cost, dim=1).transpose(1, 2)
42 | match_cost = F.softmax(match_cost, dim=-1)
43 | match_cost_topk, match_cost_topk_indices = torch.topk(match_cost, num//self.filter_size, dim=-1)
44 | matching_inds = torch.gather(matching_inds, -1, match_cost_topk_indices)
45 | matching_inds = matching_inds.permute(0, 2, 1).view(batch_size, -1, height, width).float()
46 | offset_x, offset_y = inds_to_offset(matching_inds)
47 | corr = match_cost_topk.permute(0, 2, 1)
48 | return offset_x, offset_y, corr
49 |
50 |
51 | class PropagationFaster(nn.Module):
52 | def __init__(self):
53 | super().__init__()
54 |
55 | def forward(self, offset_x, offset_y, propagation_type="horizontal"):
56 | device = offset_x.get_device()
57 | self.horizontal_zeros = torch.zeros((offset_x.size()[0], offset_x.size()[1], offset_x.size()[2], 1)).to(device)
58 | self.vertical_zeros = torch.zeros((offset_x.size()[0], offset_x.size()[1], 1, offset_x.size()[3])).to(device)
59 | if propagation_type == "horizontal":
60 | offset_x = torch.cat((torch.cat((self.horizontal_zeros, offset_x[:, :, :, :-1]), dim=3),
61 | offset_x,
62 | torch.cat((offset_x[:, :, :, 1:], self.horizontal_zeros), dim=3)), dim=1)
63 |
64 | offset_y = torch.cat((torch.cat((self.horizontal_zeros, offset_y[:, :, :, :-1]), dim=3),
65 | offset_y,
66 | torch.cat((offset_y[:, :, :, 1:], self.horizontal_zeros), dim=3)), dim=1)
67 |
68 | else:
69 | offset_x = torch.cat((torch.cat((self.vertical_zeros, offset_x[:, :, :-1, :]), dim=2),
70 | offset_x,
71 | torch.cat((offset_x[:, :, 1:, :], self.vertical_zeros), dim=2)), dim=1)
72 |
73 | offset_y = torch.cat((torch.cat((self.vertical_zeros, offset_y[:, :, :-1, :]), dim=2),
74 | offset_y,
75 | torch.cat((offset_y[:, :, 1:, :], self.vertical_zeros), dim=2)), dim=1)
76 | return offset_x, offset_y
77 |
78 |
79 | class PatchMatchOnce(nn.Module):
80 | def __init__(self, opt):
81 | super().__init__()
82 | self.propagation = PropagationFaster()
83 | self.evaluate = Evaluate(opt.temperature)
84 |
85 | def forward(self, left_features, right_features, offset_x, offset_y):
86 | prob = random.random()
87 | if prob < 0.5:
88 | offset_x, offset_y = self.propagation(offset_x, offset_y, "horizontal")
89 | offset_x, offset_y, _ = self.evaluate(left_features, right_features, offset_x, offset_y)
90 | offset_x, offset_y = self.propagation(offset_x, offset_y, "vertical")
91 | offset_x, offset_y, corr = self.evaluate(left_features, right_features, offset_x, offset_y)
92 | else:
93 | offset_x, offset_y = self.propagation(offset_x, offset_y, "vertical")
94 | offset_x, offset_y, _ = self.evaluate(left_features, right_features, offset_x, offset_y)
95 | offset_x, offset_y = self.propagation(offset_x, offset_y, "horizontal")
96 | offset_x, offset_y, corr = self.evaluate(left_features, right_features, offset_x, offset_y)
97 | return offset_x, offset_y, corr
98 |
99 |
100 | class PatchMatchGRU(nn.Module):
101 | def __init__(self, opt):
102 | super().__init__()
103 | self.patch_match_one_step = PatchMatchOnce(opt)
104 | self.temperature = opt.temperature
105 | self.iters = opt.iteration_count
106 | input_dim = opt.nef
107 | hidden_dim = 32
108 | norm = nn.InstanceNorm2d(hidden_dim, affine=False)
109 | relu = nn.ReLU(inplace=True)
110 | """
111 | concat left and right features
112 | """
113 | self.initial_layer = nn.Sequential(
114 | nn.Conv2d(input_dim*2, hidden_dim, kernel_size=3, padding=1, stride=1),
115 | norm,
116 | relu,
117 | nn.Conv2d(hidden_dim, hidden_dim, kernel_size=3, padding=1),
118 | norm,
119 | relu,
120 | )
121 | self.refine_net = BasicUpdateBlock()
122 |
123 | def forward(self, left_features, right_features, right_input, initial_offset_x, initial_offset_y):
124 | device = left_features.get_device()
125 | batch_size, channel, height, width = left_features.size()
126 | num = initial_offset_x.size()[1]
127 | initial_input = torch.cat((left_features, right_features), dim=1)
128 | hidden = self.initial_layer(initial_input)
129 | left_features = left_features.view(batch_size, -1, height * width)
130 | right_features = right_features.view(batch_size, -1, height * width)
131 | right_features_view = right_features.transpose(0, 1).contiguous().view(channel, -1)
132 | with torch.no_grad():
133 | offset_x, offset_y = initial_offset_x, initial_offset_y
134 | for it in range(self.iters):
135 | with torch.no_grad():
136 | offset_x, offset_y, corr = self.patch_match_one_step(left_features, right_features_view, offset_x, offset_y)
137 | """GRU refinement"""
138 | flow = torch.cat((offset_x, offset_y), dim=1)
139 | corr = corr.view(batch_size, -1, height, width)
140 | hidden, delta_offset_x, delta_offset_y = self.refine_net(hidden, corr, flow)
141 | offset_x = offset_x + delta_offset_x
142 | offset_y = offset_y + delta_offset_y
143 | with torch.no_grad():
144 | matching_inds = offset_to_inds(offset_x, offset_y)
145 | matching_inds = matching_inds.view(batch_size, num, height * width).permute(0, 2, 1).long()
146 | base_batch = torch.arange(batch_size).to(device).long() * (height * width)
147 | base_batch = base_batch.view(-1, 1, 1)
148 | matching_inds_plus_base = matching_inds + base_batch
149 | match_cost = []
150 | # using A[:, idx]
151 | for i in range(matching_inds_plus_base.size()[-1]):
152 | idx = matching_inds_plus_base[:, :, i]
153 | idx = idx.contiguous().view(-1)
154 | right_features_select = right_features_view[:, idx]
155 | right_features_select = right_features_select.view(channel, batch_size, -1).transpose(0, 1)
156 | match_cost_i = torch.sum(left_features * right_features_select, dim=1, keepdim=True) / self.temperature
157 | match_cost.append(match_cost_i)
158 | match_cost = torch.cat(match_cost, dim=1).transpose(1, 2)
159 | match_cost = F.softmax(match_cost, dim=-1)
160 | right_input_view = right_input.transpose(0, 1).contiguous().view(right_input.size()[1], -1)
161 | warp = torch.zeros_like(right_input)
162 | # using A[:, idx]
163 | for i in range(match_cost.size()[-1]):
164 | idx = matching_inds_plus_base[:, :, i]
165 | idx = idx.contiguous().view(-1)
166 | right_input_select = right_input_view[:, idx]
167 | right_input_select = right_input_select.view(right_input.size()[1], batch_size, -1).transpose(0, 1)
168 | warp = warp + right_input_select * match_cost[:, :, i].unsqueeze(dim=1)
169 | return matching_inds, warp
170 |
--------------------------------------------------------------------------------
/options/base_options.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT License.
3 |
4 | import os
5 | import sys
6 | import random
7 | import argparse
8 | import pickle
9 | import numpy as np
10 | import torch
11 | import models
12 | import data
13 | from util import util
14 |
15 |
16 | class BaseOptions():
17 | def __init__(self):
18 | self.initialized = False
19 |
20 | def initialize(self, parser):
21 | # experiment specifics
22 | parser.add_argument('--name', type=str, default='deepfashionHD', help='name of the experiment. It decides where to store samples and models')
23 | parser.add_argument('--gpu_ids', type=str, default='0,1,2,3', help='gpu ids: e.g. 0 0,1,2, 0,2. use -1 for CPU')
24 | parser.add_argument('--checkpoints_dir', type=str, default='./checkpoints', help='models are saved here')
25 | parser.add_argument('--model', type=str, default='pix2pix', help='which model to use')
26 | parser.add_argument('--norm_G', type=str, default='spectralinstance', help='instance normalization or batch normalization')
27 | parser.add_argument('--norm_D', type=str, default='spectralinstance', help='instance normalization or batch normalization')
28 | parser.add_argument('--norm_E', type=str, default='spectralinstance', help='instance normalization or batch normalization')
29 | parser.add_argument('--phase', type=str, default='train', help='train, val, test, etc')
30 | # input/output sizes
31 | parser.add_argument('--batchSize', type=int, default=4, help='input batch size')
32 | parser.add_argument('--preprocess_mode', type=str, default='scale_width_and_crop', help='scaling and cropping of images at load time.', choices=("resize_and_crop", "crop", "scale_width", "scale_width_and_crop", "scale_shortside", "scale_shortside_and_crop", "fixed", "none"))
33 | parser.add_argument('--load_size', type=int, default=256, help='Scale images to this size. The final image will be cropped to --crop_size.')
34 | parser.add_argument('--crop_size', type=int, default=256, help='Crop to the width of crop_size (after initially scaling the images to load_size.)')
35 | parser.add_argument('--aspect_ratio', type=float, default=1.0, help='The ratio width/height. The final height of the load image will be crop_size/aspect_ratio')
36 | parser.add_argument('--label_nc', type=int, default=182, help='# of input label classes without unknown class. If you have unknown class as class label, specify --contain_dontcare_label.')
37 | parser.add_argument('--contain_dontcare_label', action='store_true', help='if the label map contains dontcare label (dontcare=255)')
38 | parser.add_argument('--output_nc', type=int, default=3, help='# of output image channels')
39 | # for setting inputs
40 | parser.add_argument('--dataroot', type=str, default='dataset/deepfashionHD')
41 | parser.add_argument('--dataset_mode', type=str, default='deepfashionHD')
42 | parser.add_argument('--serial_batches', action='store_true', help='if true, takes images in order to make batches, otherwise takes them randomly')
43 | parser.add_argument('--no_flip', action='store_true', help='if specified, do not flip the images for data augmentation')
44 | parser.add_argument('--nThreads', default=16, type=int, help='# threads for loading data')
45 | parser.add_argument('--max_dataset_size', type=int, default=sys.maxsize, help='Maximum number of samples allowed per dataset. If the dataset directory contains more than max_dataset_size, only a subset is loaded.')
46 | parser.add_argument('--load_from_opt_file', action='store_true', help='load the options from checkpoints and use that as default')
47 | parser.add_argument('--cache_filelist_write', action='store_true', help='saves the current filelist into a text file, so that it loads faster')
48 | parser.add_argument('--cache_filelist_read', action='store_true', help='reads from the file list cache')
49 | # for displays
50 | parser.add_argument('--display_winsize', type=int, default=512, help='display window size')
51 | # for generator
52 | parser.add_argument('--netG', type=str, default='spade', help='selects model to use for netG (pix2pixhd | spade)')
53 | parser.add_argument('--ngf', type=int, default=64, help='# of gen filters in first conv layer')
54 | parser.add_argument('--init_type', type=str, default='xavier', help='network initialization [normal|xavier|kaiming|orthogonal]')
55 | parser.add_argument('--init_variance', type=float, default=0.02, help='variance of the initialization distribution')
56 | # for feature encoder
57 | parser.add_argument('--netCorr', type=str, default='NoVGGHPM')
58 | parser.add_argument('--nef', type=int, default=32, help='# of gen filters in first conv layer')
59 | # for instance-wise features
60 | parser.add_argument('--CBN_intype', type=str, default='warp_mask', help='type of CBN input for framework, warp/mask/warp_mask')
61 | parser.add_argument('--match_kernel', type=int, default=1, help='correspondence matrix match kernel size')
62 | parser.add_argument('--featEnc_kernel', type=int, default=3, help='kernel size in domain adaptor')
63 | parser.add_argument('--PONO', action='store_true', help='use positional normalization ')
64 | parser.add_argument('--PONO_C', action='store_true', help='use C normalization in corr module')
65 | parser.add_argument('--vgg_normal_correct', action='store_true', help='if true, correct vgg normalization and replace vgg FM model with ctx model')
66 | parser.add_argument('--use_coordconv', action='store_true', help='if true, use coordconv in CorrNet')
67 | parser.add_argument('--video_like', action='store_true', help='useful in deepfashion')
68 | parser.add_argument('--amp', action='store_true', help='use torch.cuda.amp')
69 | parser.add_argument('--temperature', type=float, default=0.01)
70 | parser.add_argument('--iteration_count', type=int, default=5)
71 | self.initialized = True
72 | return parser
73 |
74 | def gather_options(self):
75 | # initialize parser with basic options
76 | if not self.initialized:
77 | parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
78 | parser = self.initialize(parser)
79 | # get the basic options
80 | opt, unknown = parser.parse_known_args()
81 | # modify model-related parser options
82 | model_name = opt.model
83 | model_option_setter = models.get_option_setter(model_name)
84 | parser = model_option_setter(parser, self.isTrain)
85 | # modify dataset-related parser options
86 | dataset_mode = opt.dataset_mode
87 | dataset_option_setter = data.get_option_setter(dataset_mode)
88 | parser = dataset_option_setter(parser, self.isTrain)
89 | opt, unknown = parser.parse_known_args()
90 | # if there is opt_file, load it.
91 | # The previous default options will be overwritten
92 | if opt.load_from_opt_file:
93 | parser = self.update_options_from_file(parser, opt)
94 | opt = parser.parse_args()
95 | self.parser = parser
96 | return opt
97 |
98 | def print_options(self, opt):
99 | message = ''
100 | message += '----------------- Options ---------------\n'
101 | for k, v in sorted(vars(opt).items()):
102 | comment = ''
103 | default = self.parser.get_default(k)
104 | if v != default:
105 | comment = '\t[default: %s]' % str(default)
106 | message += '{:>25}: {:<30}{}\n'.format(str(k), str(v), comment)
107 | message += '----------------- End -------------------'
108 | print(message)
109 |
110 | def option_file_path(self, opt, makedir=False):
111 | expr_dir = os.path.join(opt.checkpoints_dir, opt.name)
112 | if makedir:
113 | util.mkdirs(expr_dir)
114 | file_name = os.path.join(expr_dir, 'opt')
115 | return file_name
116 |
117 | def save_options(self, opt):
118 | file_name = self.option_file_path(opt, makedir=True)
119 | with open(file_name + '.txt', 'wt') as opt_file:
120 | for k, v in sorted(vars(opt).items()):
121 | comment = ''
122 | default = self.parser.get_default(k)
123 | if v != default:
124 | comment = '\t[default: %s]' % str(default)
125 | opt_file.write('{:>25}: {:<30}{}\n'.format(str(k), str(v), comment))
126 | with open(file_name + '.pkl', 'wb') as opt_file:
127 | pickle.dump(opt, opt_file)
128 |
129 | def update_options_from_file(self, parser, opt):
130 | new_opt = self.load_options(opt)
131 | for k, v in sorted(vars(opt).items()):
132 | if hasattr(new_opt, k) and v != getattr(new_opt, k):
133 | new_val = getattr(new_opt, k)
134 | parser.set_defaults(**{k: new_val})
135 | return parser
136 |
137 | def load_options(self, opt):
138 | file_name = self.option_file_path(opt, makedir=False)
139 | new_opt = pickle.load(open(file_name + '.pkl', 'rb'))
140 | return new_opt
141 |
142 | def parse(self, save=False):
143 | # gather options from base, train, dataset, model
144 | opt = self.gather_options()
145 | # train or test
146 | opt.isTrain = self.isTrain
147 | self.print_options(opt)
148 | if opt.isTrain:
149 | self.save_options(opt)
150 | # Set semantic_nc based on the option.
151 | # This will be convenient in many places
152 | opt.semantic_nc = opt.label_nc + (1 if opt.contain_dontcare_label else 0)
153 | os.environ['CUDA_VISIBLE_DEVICES'] = opt.gpu_ids
154 | str_ids = opt.gpu_ids.split(',')
155 | opt.gpu_ids = list(range(len(str_ids)))
156 | seed = 1234
157 | random.seed(seed)
158 | np.random.seed(seed)
159 | torch.manual_seed(seed)
160 | torch.random.manual_seed(seed)
161 | torch.cuda.manual_seed_all(seed)
162 | torch.backends.cudnn.benchmark = True
163 | if len(opt.gpu_ids) > 0:
164 | torch.cuda.set_device(opt.gpu_ids[0])
165 | self.opt = opt
166 | return self.opt
167 |
--------------------------------------------------------------------------------
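
The `parse()` flow above merges the base, model-specific, and dataset-specific options, optionally reloads a saved `opt` file, then derives `semantic_nc`, remaps `gpu_ids`, and seeds the RNGs. A minimal usage sketch, assuming the standard `TrainOptions` subclass in `options/train_options.py` sets `isTrain = True`:

    # Hedged sketch of how a BaseOptions subclass is typically consumed.
    from options.train_options import TrainOptions

    opt = TrainOptions().parse()    # gather, print, and (for training) save the options
    print(opt.semantic_nc)          # label_nc, +1 if contain_dontcare_label is set
    print(opt.gpu_ids)              # remapped to [0, 1, ...] after CUDA_VISIBLE_DEVICES is set
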
/models/networks/correspondence.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT License.
3 |
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 |
8 | import util.util as util
9 | from models.networks.base_network import BaseNetwork
10 | from models.networks.architecture import ResidualBlock
11 | from models.networks.normalization import get_nonspade_norm_layer
12 | from models.networks.architecture import SPADEResnetBlock
13 | """patch match"""
14 | from models.networks.patch_match import PatchMatchGRU
15 | from models.networks.ops import *
16 |
17 |
18 | def match_kernel_and_pono_c(feature, match_kernel, PONO_C, eps=1e-10):
19 | b, c, h, w = feature.size()
20 | if match_kernel == 1:
21 | feature = feature.view(b, c, -1)
22 | else:
23 | feature = F.unfold(feature, kernel_size=match_kernel, padding=int(match_kernel//2))
24 | dim_mean = 1 if PONO_C else -1
25 | feature = feature - feature.mean(dim=dim_mean, keepdim=True)
26 | feature_norm = torch.norm(feature, 2, 1, keepdim=True) + eps
27 | feature = torch.div(feature, feature_norm)
28 | return feature.view(b, -1, h, w)
29 |
30 |
31 | """512x512"""
32 | class AdaptiveFeatureGenerator(BaseNetwork):
33 | @staticmethod
34 | def modify_commandline_options(parser, is_train):
35 | return parser
36 |
37 | def __init__(self, opt):
38 | super().__init__()
39 | self.opt = opt
40 | kw = opt.featEnc_kernel
41 | pw = int((kw-1)//2)
42 | nf = opt.nef
43 | norm_layer = get_nonspade_norm_layer(opt, opt.norm_E)
44 | self.layer1 = norm_layer(nn.Conv2d(opt.spade_ic, nf, 3, stride=1, padding=pw))
45 | self.layer2 = nn.Sequential(
46 | norm_layer(nn.Conv2d(nf * 1, nf * 2, 3, 1, 1)),
47 | ResidualBlock(nf * 2, nf * 2),
48 | )
49 | self.layer3 = nn.Sequential(
50 | norm_layer(nn.Conv2d(nf * 2, nf * 4, kw, stride=2, padding=pw)),
51 | ResidualBlock(nf * 4, nf * 4),
52 | )
53 | self.layer4 = nn.Sequential(
54 | norm_layer(nn.Conv2d(nf * 4, nf * 4, kw, stride=2, padding=pw)),
55 | ResidualBlock(nf * 4, nf * 4),
56 | )
57 | self.layer5 = nn.Sequential(
58 | norm_layer(nn.Conv2d(nf * 4, nf * 4, kw, stride=2, padding=pw)),
59 | ResidualBlock(nf * 4, nf * 4),
60 | )
61 | self.head_0 = SPADEResnetBlock(nf * 4, nf * 4, opt)
62 | self.G_middle_0 = SPADEResnetBlock(nf * 4, nf * 4, opt)
63 | self.G_middle_1 = SPADEResnetBlock(nf * 4, nf * 2, opt)
64 | self.G_middle_2 = SPADEResnetBlock(nf * 2, nf * 1, opt)
65 | self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
66 |
67 | def forward(self, input, seg):
68 | # 512
69 | x1 = self.layer1(input)
70 | # 512
71 | x2 = self.layer2(self.actvn(x1))
72 | # 256
73 | x3 = self.layer3(self.actvn(x2))
74 | # 128
75 | x4 = self.layer4(self.actvn(x3))
76 | # 64
77 | x5 = self.layer5(self.actvn(x4))
78 | # bottleneck
79 | x6 = self.head_0(x5, seg)
80 | # 128
81 | x7 = self.G_middle_0(self.up(x6) + x4, seg)
82 | # 256
83 | x8 = self.G_middle_1(self.up(x7) + x3, seg)
84 | # 512
85 | x9 = self.G_middle_2(self.up(x8) + x2, seg)
86 | return [x6, x7, x8, x9]
87 |
88 | def actvn(self, x):
89 | return F.leaky_relu(x, 2e-1)
90 |
91 |
92 | class NoVGGHPMCorrespondence(BaseNetwork):
93 | def __init__(self, opt):
94 | self.opt = opt
95 | super().__init__()
96 | opt.spade_ic = opt.semantic_nc
97 | self.adaptive_model_seg = AdaptiveFeatureGenerator(opt)
98 | opt.spade_ic = 3 + opt.semantic_nc
99 | self.adaptive_model_img = AdaptiveFeatureGenerator(opt)
100 | del opt.spade_ic
101 | self.batch_size = opt.batchSize
102 | """512x512"""
103 | feature_channel = opt.nef
104 | self.phi_0 = nn.Conv2d(in_channels=feature_channel*4, out_channels=feature_channel, kernel_size=1, stride=1, padding=0)
105 | self.phi_1 = nn.Conv2d(in_channels=feature_channel*4, out_channels=feature_channel, kernel_size=1, stride=1, padding=0)
106 | self.phi_2 = nn.Conv2d(in_channels=feature_channel*2, out_channels=feature_channel, kernel_size=1, stride=1, padding=0)
107 | self.phi_3 = nn.Conv2d(in_channels=feature_channel, out_channels=feature_channel, kernel_size=1, stride=1, padding=0)
108 | self.theta_0 = nn.Conv2d(in_channels=feature_channel*4, out_channels=feature_channel, kernel_size=1, stride=1, padding=0)
109 | self.theta_1 = nn.Conv2d(in_channels=feature_channel*4, out_channels=feature_channel, kernel_size=1, stride=1, padding=0)
110 | self.theta_2 = nn.Conv2d(in_channels=feature_channel*2, out_channels=feature_channel, kernel_size=1, stride=1, padding=0)
111 | self.theta_3 = nn.Conv2d(in_channels=feature_channel, out_channels=feature_channel, kernel_size=1, stride=1, padding=0)
112 | self.patch_match = PatchMatchGRU(opt)
113 |
114 | """512x512"""
115 | def multi_scale_patch_match(self, f1, f2, ref, hierarchical_scale, pre=None, real_img=None):
116 | if hierarchical_scale == 0:
117 | y_cycle = None
118 | scale = 64
119 | batch_size, channel, feature_height, feature_width = f1.size()
120 | ref = F.avg_pool2d(ref, 8, stride=8)
121 | ref = ref.view(batch_size, 3, scale * scale)
122 | f1 = f1.view(batch_size, channel, scale * scale)
123 | f2 = f2.view(batch_size, channel, scale * scale)
124 | matmul_result = torch.matmul(f1.permute(0, 2, 1), f2)/self.opt.temperature
125 | mat = F.softmax(matmul_result, dim=-1)
126 | y = torch.matmul(mat, ref.permute(0, 2, 1))
127 | if self.opt.phase == 'train' and self.opt.weight_warp_cycle > 0:
128 | mat_cycle = F.softmax(matmul_result.transpose(1, 2), dim=-1)
129 | y_cycle = torch.matmul(mat_cycle, y)
130 | y_cycle = y_cycle.permute(0, 2, 1).view(batch_size, 3, scale, scale)
131 | y = y.permute(0, 2, 1).view(batch_size, 3, scale, scale)
132 | return mat, y, y_cycle
133 | if hierarchical_scale == 1:
134 | scale = 128
135 | with torch.no_grad():
136 | batch_size, channel, feature_height, feature_width = f1.size()
137 | topk_num = 1
138 | search_window = 4
139 | centering = 1
140 | dilation = 2
141 | total_candidate_num = topk_num * (search_window ** 2)
142 | topk_inds = torch.topk(pre, topk_num, dim=-1)[-1]
143 | inds = topk_inds.permute(0, 2, 1).view(batch_size, topk_num, (scale//2), (scale//2)).float()
144 | offset_x, offset_y = inds_to_offset(inds)
145 | dx = torch.arange(search_window, dtype=topk_inds.dtype, device=topk_inds.device).unsqueeze_(dim=1).expand(-1, search_window).contiguous().view(-1) - centering
146 | dy = torch.arange(search_window, dtype=topk_inds.dtype, device=topk_inds.device).unsqueeze_(dim=0).expand(search_window, -1).contiguous().view(-1) - centering
147 | dx = dx.view(1, search_window ** 2, 1, 1) * dilation
148 | dy = dy.view(1, search_window ** 2, 1, 1) * dilation
149 | offset_x_up = F.interpolate((2 * offset_x + dx), scale_factor=2)
150 | offset_y_up = F.interpolate((2 * offset_y + dy), scale_factor=2)
151 | ref = F.avg_pool2d(ref, 4, stride=4)
152 | ref = ref.view(batch_size, 3, scale * scale)
153 | mat, y = self.patch_match(f1, f2, ref, offset_x_up, offset_y_up)
154 | y = y.view(batch_size, 3, scale, scale)
155 | return mat, y
156 | if hierarchical_scale == 2:
157 | scale = 256
158 | with torch.no_grad():
159 | batch_size, channel, feature_height, feature_width = f1.size()
160 | topk_num = 1
161 | search_window = 4
162 | centering = 1
163 | dilation = 2
164 | total_candidate_num = topk_num * (search_window ** 2)
165 | topk_inds = pre[:, :, :topk_num]
166 | inds = topk_inds.permute(0, 2, 1).view(batch_size, topk_num, (scale//2), (scale//2)).float()
167 | offset_x, offset_y = inds_to_offset(inds)
168 | dx = torch.arange(search_window, dtype=topk_inds.dtype, device=topk_inds.device).unsqueeze_(dim=1).expand(-1, search_window).contiguous().view(-1) - centering
169 | dy = torch.arange(search_window, dtype=topk_inds.dtype, device=topk_inds.device).unsqueeze_(dim=0).expand(search_window, -1).contiguous().view(-1) - centering
170 | dx = dx.view(1, search_window ** 2, 1, 1) * dilation
171 | dy = dy.view(1, search_window ** 2, 1, 1) * dilation
172 | offset_x_up = F.interpolate((2 * offset_x + dx), scale_factor=2)
173 | offset_y_up = F.interpolate((2 * offset_y + dy), scale_factor=2)
174 | ref = F.avg_pool2d(ref, 2, stride=2)
175 | ref = ref.view(batch_size, 3, scale * scale)
176 | mat, y = self.patch_match(f1, f2, ref, offset_x_up, offset_y_up)
177 | y = y.view(batch_size, 3, scale, scale)
178 | return mat, y
179 | if hierarchical_scale == 3:
180 | scale = 512
181 | with torch.no_grad():
182 | batch_size, channel, feature_height, feature_width = f1.size()
183 | topk_num = 1
184 | search_window = 4
185 | centering = 1
186 | dilation = 2
187 | total_candidate_num = topk_num * (search_window ** 2)
188 | topk_inds = pre[:, :, :topk_num]
189 | inds = topk_inds.permute(0, 2, 1).view(batch_size, topk_num, (scale//2), (scale//2)).float()
190 | offset_x, offset_y = inds_to_offset(inds)
191 | dx = torch.arange(search_window, dtype=topk_inds.dtype, device=topk_inds.device).unsqueeze_(dim=1).expand(-1, search_window).contiguous().view(-1) - centering
192 | dy = torch.arange(search_window, dtype=topk_inds.dtype, device=topk_inds.device).unsqueeze_(dim=0).expand(search_window, -1).contiguous().view(-1) - centering
193 | dx = dx.view(1, search_window ** 2, 1, 1) * dilation
194 | dy = dy.view(1, search_window ** 2, 1, 1) * dilation
195 | offset_x_up = F.interpolate((2 * offset_x + dx), scale_factor=2)
196 | offset_y_up = F.interpolate((2 * offset_y + dy), scale_factor=2)
197 | ref = ref.view(batch_size, 3, scale * scale)
198 | mat, y = self.patch_match(f1, f2, ref, offset_x_up, offset_y_up)
199 | y = y.view(batch_size, 3, scale, scale)
200 | return mat, y
201 |
202 | def forward(self, ref_img, real_img, seg_map, ref_seg_map):
203 | corr_out = {}
204 | seg_input = seg_map
205 | adaptive_feature_seg = self.adaptive_model_seg(seg_input, seg_input)
206 | ref_input = torch.cat((ref_img, ref_seg_map), dim=1)
207 | adaptive_feature_img = self.adaptive_model_img(ref_input, ref_input)
208 | for i in range(len(adaptive_feature_seg)):
209 | adaptive_feature_seg[i] = util.feature_normalize(adaptive_feature_seg[i])
210 | adaptive_feature_img[i] = util.feature_normalize(adaptive_feature_img[i])
211 | if self.opt.isTrain and self.opt.weight_novgg_featpair > 0:
212 | real_input = torch.cat((real_img, seg_map), dim=1)
213 | adaptive_feature_img_pair = self.adaptive_model_img(real_input, real_input)
214 | loss_novgg_featpair = 0
215 | weights = [1.0, 1.0, 1.0, 1.0]
216 | for i in range(len(adaptive_feature_img_pair)):
217 | adaptive_feature_img_pair[i] = util.feature_normalize(adaptive_feature_img_pair[i])
218 | loss_novgg_featpair += F.l1_loss(adaptive_feature_seg[i], adaptive_feature_img_pair[i]) * weights[i]
219 | corr_out['loss_novgg_featpair'] = loss_novgg_featpair * self.opt.weight_novgg_featpair
220 | cont_features = adaptive_feature_seg
221 | ref_features = adaptive_feature_img
222 | theta = []
223 | phi = []
224 | """512x512"""
225 | theta.append(match_kernel_and_pono_c(self.theta_0(cont_features[0]), self.opt.match_kernel, self.opt.PONO_C))
226 | theta.append(match_kernel_and_pono_c(self.theta_1(cont_features[1]), self.opt.match_kernel, self.opt.PONO_C))
227 | theta.append(match_kernel_and_pono_c(self.theta_2(cont_features[2]), self.opt.match_kernel, self.opt.PONO_C))
228 | theta.append(match_kernel_and_pono_c(self.theta_3(cont_features[3]), self.opt.match_kernel, self.opt.PONO_C))
229 | phi.append(match_kernel_and_pono_c(self.phi_0(ref_features[0]), self.opt.match_kernel, self.opt.PONO_C))
230 | phi.append(match_kernel_and_pono_c(self.phi_1(ref_features[1]), self.opt.match_kernel, self.opt.PONO_C))
231 | phi.append(match_kernel_and_pono_c(self.phi_2(ref_features[2]), self.opt.match_kernel, self.opt.PONO_C))
232 | phi.append(match_kernel_and_pono_c(self.phi_3(ref_features[3]), self.opt.match_kernel, self.opt.PONO_C))
233 | ref = ref_img
234 | ys = []
235 | m = None
236 | for i in range(len(theta)):
237 | if i == 0:
238 | m, y, y_cycle = self.multi_scale_patch_match(theta[i], phi[i], ref, i, pre=m)
239 | if y_cycle is not None:
240 | corr_out['warp_cycle'] = y_cycle
241 | else:
242 | m, y = self.multi_scale_patch_match(theta[i], phi[i], ref, i, pre=m)
243 | ys.append(y)
244 | corr_out['warp_out'] = ys
245 | return corr_out
246 |
--------------------------------------------------------------------------------
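
At the coarsest level (`hierarchical_scale == 0`), `multi_scale_patch_match` computes a full 64x64-to-64x64 soft correspondence: cosine-normalized content and reference features are correlated, softmax-normalized with a temperature, and used to warp the downsampled reference image. The finer levels replace the dense matrix with `PatchMatchGRU` refinement over a small dilated search window around the previous level's matches. A standalone sketch of the dense step (shapes are illustrative; 0.01 matches the `--temperature` default):

    import torch
    import torch.nn.functional as F

    B, C, H, W = 1, 16, 64, 64
    temperature = 0.01

    f1 = F.normalize(torch.randn(B, C, H * W), dim=1)   # content features (cosine-normalized)
    f2 = F.normalize(torch.randn(B, C, H * W), dim=1)   # reference features
    ref = torch.randn(B, 3, H * W)                       # downsampled reference image

    corr = torch.matmul(f1.permute(0, 2, 1), f2) / temperature   # (B, HW, HW) similarities
    mat = F.softmax(corr, dim=-1)                                # soft correspondence matrix
    warped = torch.matmul(mat, ref.permute(0, 2, 1))             # (B, HW, 3) warped colors
    warped = warped.permute(0, 2, 1).reshape(B, 3, H, W)         # warped reference image
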
/models/pix2pix_model.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT License.
3 |
4 | import torch
5 | import torch.nn.functional as F
6 | import models.networks as networks
7 | import util.util as util
8 | import itertools
9 | try:
10 | from torch.cuda.amp import autocast
11 | except ImportError:
12 | # dummy autocast for PyTorch < 1.6
13 | class autocast:
14 | def __init__(self, enabled):
15 | pass
16 | def __enter__(self):
17 | pass
18 | def __exit__(self, *args):
19 | pass
20 |
21 |
22 | class Pix2PixModel(torch.nn.Module):
23 | @staticmethod
24 | def modify_commandline_options(parser, is_train):
25 | networks.modify_commandline_options(parser, is_train)
26 | return parser
27 |
28 | def __init__(self, opt):
29 | super().__init__()
30 | self.opt = opt
31 | self.FloatTensor = torch.cuda.FloatTensor if self.use_gpu() \
32 | else torch.FloatTensor
33 | self.ByteTensor = torch.cuda.ByteTensor if self.use_gpu() \
34 | else torch.ByteTensor
35 | self.net = torch.nn.ModuleDict(self.initialize_networks(opt))
36 | # set loss functions
37 | if opt.isTrain:
38 | # vgg network
39 | self.vggnet_fix = networks.architecture.VGG19_feature_color_torchversion(vgg_normal_correct=opt.vgg_normal_correct)
40 | self.vggnet_fix.load_state_dict(torch.load('vgg/vgg19_conv.pth'))
41 | self.vggnet_fix.eval()
42 | for param in self.vggnet_fix.parameters():
43 | param.requires_grad = False
44 | self.vggnet_fix.to(self.opt.gpu_ids[0])
45 | # contextual loss
46 | self.contextual_forward_loss = networks.ContextualLoss_forward(opt)
47 | # GAN loss
48 | self.criterionGAN = networks.GANLoss(opt.gan_mode, tensor=self.FloatTensor, opt=self.opt)
49 | # L1 loss
50 | self.criterionFeat = torch.nn.L1Loss()
51 | # L2 loss
52 | self.MSE_loss = torch.nn.MSELoss()
53 | # setting which layer is used in the perceptual loss
54 | if opt.which_perceptual == '5_2':
55 | self.perceptual_layer = -1
56 | elif opt.which_perceptual == '4_2':
57 | self.perceptual_layer = -2
58 |
59 | def forward(self, data, mode, GforD=None):
60 | input_label, input_semantics, real_image, self_ref, ref_image, ref_label, ref_semantics = self.preprocess_input(data)
61 | generated_out = {}
62 |
63 | if mode == 'generator':
64 | g_loss, generated_out = self.compute_generator_loss(input_label, \
65 | input_semantics, real_image, ref_label, \
66 | ref_semantics, ref_image, self_ref)
67 | out = {}
68 | out['fake_image'] = generated_out['fake_image']
69 | out['input_semantics'] = input_semantics
70 | out['ref_semantics'] = ref_semantics
71 | out['warp_out'] = None if 'warp_out' not in generated_out else generated_out['warp_out']
72 | out['adaptive_feature_seg'] = None if 'adaptive_feature_seg' not in generated_out else generated_out['adaptive_feature_seg']
73 | out['adaptive_feature_img'] = None if 'adaptive_feature_img' not in generated_out else generated_out['adaptive_feature_img']
74 | out['warp_cycle'] = None if 'warp_cycle' not in generated_out else generated_out['warp_cycle']
75 | return g_loss, out
76 |
77 | elif mode == 'discriminator':
78 | d_loss = self.compute_discriminator_loss(input_semantics, \
79 | real_image, GforD, label=input_label)
80 | return d_loss
81 |
82 | elif mode == 'inference':
83 | out = {}
84 | with torch.no_grad():
85 | out = self.inference(input_semantics, ref_semantics=ref_semantics, \
86 | ref_image=ref_image, self_ref=self_ref, \
87 | real_image=real_image)
88 | out['input_semantics'] = input_semantics
89 | out['ref_semantics'] = ref_semantics
90 | return out
91 |
92 | else:
93 | raise ValueError("|mode| is invalid")
94 |
95 | def create_optimizers(self, opt):
96 | if opt.no_TTUR:
97 | beta1, beta2 = opt.beta1, opt.beta2
98 | G_lr, D_lr = opt.lr, opt.lr
99 | else:
100 | beta1, beta2 = 0, 0.9
101 | G_lr, D_lr = opt.lr / 2, opt.lr * 2
102 | optimizer_G = torch.optim.Adam(itertools.chain(self.net['netG'].parameters(), \
103 | self.net['netCorr'].parameters()), lr=G_lr, betas=(beta1, beta2), eps=1e-3)
104 | optimizer_D = torch.optim.Adam(itertools.chain(self.net['netD'].parameters()), \
105 | lr=D_lr, betas=(beta1, beta2))
106 | return optimizer_G, optimizer_D
107 |
108 | def save(self, epoch):
109 | util.save_network(self.net['netG'], 'G', epoch, self.opt)
110 | util.save_network(self.net['netD'], 'D', epoch, self.opt)
111 | util.save_network(self.net['netCorr'], 'Corr', epoch, self.opt)
112 |
113 | def initialize_networks(self, opt):
114 | net = {}
115 | net['netG'] = networks.define_G(opt)
116 | net['netD'] = networks.define_D(opt) if opt.isTrain else None
117 | net['netCorr'] = networks.define_Corr(opt)
118 | if not opt.isTrain or opt.continue_train:
119 | net['netCorr'] = util.load_network(net['netCorr'], 'Corr', opt.which_epoch, opt)
120 | net['netG'] = util.load_network(net['netG'], 'G', opt.which_epoch, opt)
121 | if opt.isTrain:
122 | net['netD'] = util.load_network(net['netD'], 'D', opt.which_epoch, opt)
123 | return net
124 |
125 | def preprocess_input(self, data):
126 | if self.use_gpu():
127 | for k in data.keys():
128 | try:
129 | data[k] = data[k].cuda()
130 | except AttributeError:  # non-tensor entries are left as-is
131 | continue
132 | label = data['label'][:,:3,:,:].float()
133 | label_ref = data['label_ref'][:,:3,:,:].float()
134 | input_semantics = data['label'].float()
135 | ref_semantics = data['label_ref'].float()
136 | image = data['image']
137 | ref = data['ref']
138 | self_ref = data['self_ref']
139 | return label, input_semantics, image, self_ref, ref, label_ref, ref_semantics
140 |
141 | def get_ctx_loss(self, source, target):
142 | contextual_style5_1 = torch.mean(self.contextual_forward_loss(source[-1], target[-1].detach())) * 8
143 | contextual_style4_1 = torch.mean(self.contextual_forward_loss(source[-2], target[-2].detach())) * 4
144 | contextual_style3_1 = torch.mean(self.contextual_forward_loss(F.avg_pool2d(source[-3], 2), F.avg_pool2d(target[-3].detach(), 2))) * 2
145 | return contextual_style5_1 + contextual_style4_1 + contextual_style3_1
146 |
147 | def compute_generator_loss(self, input_label, input_semantics, real_image, ref_label=None, ref_semantics=None, ref_image=None, self_ref=None):
148 | G_losses = {}
149 | generate_out = self.generate_fake(input_semantics, real_image, ref_semantics=ref_semantics, ref_image=ref_image, self_ref=self_ref)
150 | generate_out['fake_image'] = generate_out['fake_image'].float()
151 | weights = [1.0 / 32, 1.0 / 16, 1.0 / 8, 1.0 / 4, 1.0]
152 | sample_weights = self_ref/(sum(self_ref)+1e-5)
153 | sample_weights = sample_weights.view(-1, 1, 1, 1)
154 | """domain align"""
155 | if 'loss_novgg_featpair' in generate_out and generate_out['loss_novgg_featpair'] is not None:
156 | G_losses['no_vgg_feat'] = generate_out['loss_novgg_featpair']
157 | """warping cycle"""
158 | if self.opt.weight_warp_cycle > 0:
159 | warp_cycle = generate_out['warp_cycle']
160 | scale_factor = ref_image.size()[-1] // warp_cycle.size()[-1]
161 | ref = F.avg_pool2d(ref_image, scale_factor, stride=scale_factor)
162 | G_losses['G_warp_cycle'] = F.l1_loss(warp_cycle, ref) * self.opt.weight_warp_cycle
163 | """warping loss"""
164 | if self.opt.weight_warp_self > 0:
165 | """512x512"""
166 | warp1, warp2, warp3, warp4 = generate_out['warp_out']
167 | G_losses['G_warp_self'] = \
168 | torch.mean(F.l1_loss(warp4, real_image, reduction='none') * sample_weights) * self.opt.weight_warp_self * 1.0 + \
169 | torch.mean(F.l1_loss(warp3, F.avg_pool2d(real_image, 2, stride=2), reduction='none') * sample_weights) * self.opt.weight_warp_self * 1.0 + \
170 | torch.mean(F.l1_loss(warp2, F.avg_pool2d(real_image, 4, stride=4), reduction='none') * sample_weights) * self.opt.weight_warp_self * 1.0 + \
171 | torch.mean(F.l1_loss(warp1, F.avg_pool2d(real_image, 8, stride=8), reduction='none') * sample_weights) * self.opt.weight_warp_self * 1.0
172 | """gan loss"""
173 | pred_fake, pred_real = self.discriminate(input_semantics, generate_out['fake_image'], real_image)
174 | G_losses['GAN'] = self.criterionGAN(pred_fake, True, for_discriminator=False) * self.opt.weight_gan
175 | if not self.opt.no_ganFeat_loss:
176 | num_D = len(pred_fake)
177 | GAN_Feat_loss = 0.0
178 | for i in range(num_D):
179 | # for each discriminator
180 | # last output is the final prediction, so we exclude it
181 | num_intermediate_outputs = len(pred_fake[i]) - 1
182 | for j in range(num_intermediate_outputs):
183 | # for each layer output
184 | unweighted_loss = self.criterionFeat(pred_fake[i][j], pred_real[i][j].detach())
185 | GAN_Feat_loss += unweighted_loss * self.opt.weight_ganFeat / num_D
186 | G_losses['GAN_Feat'] = GAN_Feat_loss
187 | """feature matching loss"""
188 | fake_features = self.vggnet_fix(generate_out['fake_image'], ['r12', 'r22', 'r32', 'r42', 'r52'], preprocess=True)
189 | loss = 0
190 | for i in range(len(generate_out['real_features'])):
191 | loss += weights[i] * util.weighted_l1_loss(fake_features[i], generate_out['real_features'][i].detach(), sample_weights)
192 | G_losses['fm'] = loss * self.opt.weight_vgg * self.opt.weight_fm_ratio
193 | """perceptual loss"""
194 | feat_loss = util.mse_loss(fake_features[self.perceptual_layer], generate_out['real_features'][self.perceptual_layer].detach())
195 | G_losses['perc'] = feat_loss * self.opt.weight_perceptual
196 | """contextual loss"""
197 | G_losses['contextual'] = self.get_ctx_loss(fake_features, generate_out['ref_features']) * self.opt.weight_vgg * self.opt.weight_contextual
198 | return G_losses, generate_out
199 |
200 | def compute_discriminator_loss(self, input_semantics, real_image, GforD, label=None):
201 | D_losses = {}
202 | with torch.no_grad():
203 | fake_image = GforD['fake_image'].detach()
204 | fake_image.requires_grad_()
205 | pred_fake, pred_real = self.discriminate(input_semantics, fake_image, real_image)
206 | D_losses['D_Fake'] = self.criterionGAN(pred_fake, False, for_discriminator=True) * self.opt.weight_gan
207 | D_losses['D_real'] = self.criterionGAN(pred_real, True, for_discriminator=True) * self.opt.weight_gan
208 | return D_losses
209 |
210 | def encode_z(self, real_image):
211 | mu, logvar = self.net['netE'](real_image)
212 | z = self.reparameterize(mu, logvar)
213 | return z, mu, logvar
214 |
215 | def generate_fake(self, input_semantics, real_image, ref_semantics=None, ref_image=None, self_ref=None):
216 | generate_out = {}
217 | generate_out['ref_features'] = self.vggnet_fix(ref_image, ['r12', 'r22', 'r32', 'r42', 'r52'], preprocess=True)
218 | generate_out['real_features'] = self.vggnet_fix(real_image, ['r12', 'r22', 'r32', 'r42', 'r52'], preprocess=True)
219 | with autocast(enabled=self.opt.amp):
220 | corr_out = self.net['netCorr'](ref_image, real_image, input_semantics, ref_semantics)
221 | generate_out['fake_image'] = self.net['netG'](input_semantics, warp_out=corr_out['warp_out'])
222 | generate_out = {**generate_out, **corr_out}
223 | return generate_out
224 |
225 | def inference(self, input_semantics, ref_semantics=None, ref_image=None, self_ref=None, real_image=None):
226 | generate_out = {}
227 | with autocast(enabled=self.opt.amp):
228 | corr_out = self.net['netCorr'](ref_image, real_image, input_semantics, ref_semantics)
229 | generate_out['fake_image'] = self.net['netG'](input_semantics, warp_out=corr_out['warp_out'])
230 | generate_out = {**generate_out, **corr_out}
231 | return generate_out
232 |
233 | def discriminate(self, input_semantics, fake_image, real_image):
234 | fake_concat = torch.cat([input_semantics, fake_image], dim=1)
235 | real_concat = torch.cat([input_semantics, real_image], dim=1)
236 | fake_and_real = torch.cat([fake_concat, real_concat], dim=0)
237 | with autocast(enabled=self.opt.amp):
238 | discriminator_out = self.net['netD'](fake_and_real)
239 | pred_fake, pred_real = self.divide_pred(discriminator_out)
240 | return pred_fake, pred_real
241 |
242 | def divide_pred(self, pred):
243 | if type(pred) == list:
244 | fake = []
245 | real = []
246 | for p in pred:
247 | fake.append([tensor[:tensor.size(0) // 2] for tensor in p])
248 | real.append([tensor[tensor.size(0) // 2:] for tensor in p])
249 | else:
250 | fake = pred[:pred.size(0) // 2]
251 | real = pred[pred.size(0) // 2:]
252 | return fake, real
253 |
254 | def use_gpu(self):
255 | return len(self.opt.gpu_ids) > 0
256 |
--------------------------------------------------------------------------------
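
`discriminate()` stacks the fake and real pairs along the batch dimension so the discriminator (and any batch-statistics layers inside it) processes both in a single forward pass, and `divide_pred()` splits the predictions back in half. A self-contained illustration, where the one-layer discriminator and channel counts are placeholders rather than the project's actual netD:

    import torch
    import torch.nn as nn

    netD = nn.Conv2d(3 + 3, 1, kernel_size=4, stride=2, padding=1)   # toy stand-in for netD

    semantics = torch.randn(2, 3, 64, 64)   # conditioning input (channel count is illustrative)
    fake_img = torch.randn(2, 3, 64, 64)
    real_img = torch.randn(2, 3, 64, 64)

    fake_concat = torch.cat([semantics, fake_img], dim=1)
    real_concat = torch.cat([semantics, real_img], dim=1)
    fake_and_real = torch.cat([fake_concat, real_concat], dim=0)     # one discriminator pass

    pred = netD(fake_and_real)
    pred_fake = pred[:pred.size(0) // 2]
    pred_real = pred[pred.size(0) // 2:]
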