├── scripts
│   ├── train_fashion.sh
│   ├── eval_fashion.sh
│   ├── eval_market.sh
│   ├── test_market.sh
│   ├── test_fashion.sh
│   └── train_market.sh
├── requirements.txt
├── metrics
│   ├── README.md
│   ├── inception.py
│   └── metrics.py
├── data
│   ├── data_loader.py
│   ├── base_data_loader.py
│   ├── custom_dataset_data_loader.py
│   ├── __init__.py
│   ├── generate_fashion_datasets.py
│   ├── image_folder.py
│   ├── base_dataset.py
│   ├── market_dataset.py
│   └── fashion_dataset.py
├── Poster.md
├── models
│   ├── models.py
│   ├── __init__.py
│   ├── PTM.py
│   ├── base_model.py
│   ├── DPTN_model.py
│   ├── networks.py
│   ├── external_function.py
│   ├── base_function.py
│   └── ui_model.py
├── test.py
├── util
│   ├── image_pool.py
│   ├── html.py
│   ├── pose_utils.py
│   ├── visualizer.py
│   └── util.py
├── options
│   ├── test_options.py
│   ├── train_options.py
│   └── base_options.py
├── train.py
├── README.md
└── LICENSE.md
/scripts/train_fashion.sh:
--------------------------------------------------------------------------------
1 | python train.py --name=DPTN_fashion --model=DPTN --dataset_mode=fashion --dataroot=./dataset/fashion --batchSize 32 --gpu_id=0
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch==1.7.1
2 | torchvision==0.8.2
3 | imageio
4 | natsort
5 | scipy
6 | scikit-image
7 | pandas
8 | dominate
9 | opencv-python
10 | visdom
11 |
--------------------------------------------------------------------------------
/scripts/eval_fashion.sh:
--------------------------------------------------------------------------------
1 | python -m metrics.metrics --gt_path=./dataset/fashion/test --distorated_path=./results/DPTN_fashion --fid_real_path=./dataset/fashion/train --name=./fashion
--------------------------------------------------------------------------------
/scripts/eval_market.sh:
--------------------------------------------------------------------------------
1 | python -m metrics.metrics --gt_path=./dataset/market/test --distorated_path=./results/DPTN_market --fid_real_path=./dataset/market/train --name=./market --market
--------------------------------------------------------------------------------
/scripts/test_market.sh:
--------------------------------------------------------------------------------
1 | python test.py --name=DPTN_market --model=DPTN --dataset_mode=market --dataroot=./dataset/market --which_epoch latest --results_dir=./results/DPTN_market --batchSize 1 --gpu_id=0
--------------------------------------------------------------------------------
/scripts/test_fashion.sh:
--------------------------------------------------------------------------------
1 | python test.py --name=DPTN_fashion --model=DPTN --dataset_mode=fashion --dataroot=./dataset/fashion --which_epoch latest --results_dir ./results/DPTN_fashion --batchSize 1 --gpu_id=0
--------------------------------------------------------------------------------
/metrics/README.md:
--------------------------------------------------------------------------------
1 | Please clone the official repository **[PerceptualSimilarity](https://github.com/richzhang/PerceptualSimilarity/tree/future)** for the LPIPS score, and place the PerceptualSimilarity folder here.
2 |
--------------------------------------------------------------------------------
/scripts/train_market.sh:
--------------------------------------------------------------------------------
1 | python train.py --name=DPTN_market --model=DPTN --dataset_mode=market --dataroot=./dataset/market --dis_layer=3 --lambda_g=5 --lambda_rec 2 --t_s_ratio=0.8 --save_latest_freq=10400 --batchSize 32 --gpu_id=0
--------------------------------------------------------------------------------
/data/data_loader.py:
--------------------------------------------------------------------------------
1 |
2 | def CreateDataLoader(opt):
3 | from data.custom_dataset_data_loader import CustomDatasetDataLoader
4 | data_loader = CustomDatasetDataLoader()
5 | print(data_loader.name())
6 | data_loader.initialize(opt)
7 | return data_loader
8 |
--------------------------------------------------------------------------------
/data/base_data_loader.py:
--------------------------------------------------------------------------------
1 |
2 | class BaseDataLoader():
3 | def __init__(self):
4 | pass
5 |
6 | def initialize(self, opt):
7 | self.opt = opt
8 | pass
9 |
10 |     def load_data(self):
11 | return None
12 |
13 |
14 |
15 |
--------------------------------------------------------------------------------
/Poster.md:
--------------------------------------------------------------------------------
1 |
2 | Our poster template can be downloaded from [Google Drive](https://docs.google.com/presentation/d/1i02V0JZCw2mRZF99szitaOEkfVNeKR1q/edit?usp=sharing&ouid=111594135598063931892&rtpof=true&sd=true).
3 |
--------------------------------------------------------------------------------
/models/models.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import models
3 |
4 | def create_model(opt):
5 | '''
6 | if opt.model == 'pix2pixHD':
7 | from .pix2pixHD_model import Pix2PixHDModel, InferenceModel
8 | if opt.isTrain:
9 | model = Pix2PixHDModel()
10 | else:
11 | model = InferenceModel()
12 | elif opt.model == 'basic':
13 | from .basic_model import BasicModel
14 | model = BasicModel(opt)
15 | else:
16 | from .ui_model import UIModel
17 | model = UIModel()
18 | '''
19 | model = models.find_model_using_name(opt.model)(opt)
20 | if opt.verbose:
21 | print("model [%s] was created" % (model.name()))
22 | return model
23 |
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
1 | from data.data_loader import CreateDataLoader
2 | from options.test_options import TestOptions
3 | from models.models import create_model
4 | import numpy as np
5 | import torch
6 |
7 | if __name__=='__main__':
8 | # get testing options
9 | opt = TestOptions().parse()
10 |     # create a dataset
11 | data_loader = CreateDataLoader(opt)
12 | dataset = data_loader.load_data()
13 |
14 |
15 | print(len(dataset))
16 |
17 | dataset_size = len(dataset) * opt.batchSize
18 | print('testing images = %d' % dataset_size)
19 | # create a model
20 | model = create_model(opt)
21 |
22 | with torch.no_grad():
23 | for i, data in enumerate(dataset):
24 | model.set_input(data)
25 | model.test()
26 |
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
1 | """This package contains modules related to function, network architectures, and models"""
2 |
3 | import importlib
4 | from .base_model import BaseModel
5 |
6 |
7 | def find_model_using_name(model_name):
8 |     """Import the module "models/[model_name]_model.py"."""
9 | model_file_name = "models." + model_name + "_model"
10 | modellib = importlib.import_module(model_file_name)
11 | model = None
12 | for name, cls in modellib.__dict__.items():
13 | if name.lower() == (model_name+'model').lower() and issubclass(cls, BaseModel):
14 | model = cls
15 |
16 | if model is None:
17 | print("In %s.py, there should be a subclass of BaseModel with class name that matches %s in lowercase." % (model_file_name, model_name))
18 | exit(0)
19 |
20 | return model
21 |
22 |
23 | def get_option_setter(model_name):
24 | """Return the static method of the model class."""
25 | model = find_model_using_name(model_name)
26 | return model.modify_options
27 |
--------------------------------------------------------------------------------
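A minimal sketch of the naming convention implemented above, assuming the repository layout shown in the tree; the class name printed is whichever `BaseModel` subclass in `models/DPTN_model.py` matches `dptnmodel` case-insensitively (e.g. `DPTNModel`):

```python
# Minimal sketch: how '--model=DPTN' is resolved, assuming the layout in the tree above.
# find_model_using_name imports models/DPTN_model.py and returns the BaseModel subclass
# whose lowercased name equals 'dptnmodel'.
from models import find_model_using_name

model_cls = find_model_using_name('DPTN')
print(model_cls.__name__)        # the matched BaseModel subclass, e.g. DPTNModel
# model = model_cls(opt)         # this is what models.models.create_model(opt) does
```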
/util/image_pool.py:
--------------------------------------------------------------------------------
1 | import random
2 | import torch
3 | from torch.autograd import Variable
4 | class ImagePool():
5 | def __init__(self, pool_size):
6 | self.pool_size = pool_size
7 | if self.pool_size > 0:
8 | self.num_imgs = 0
9 | self.images = []
10 |
11 | def query(self, images):
12 | if self.pool_size == 0:
13 | return images
14 | return_images = []
15 | for image in images.data:
16 | image = torch.unsqueeze(image, 0)
17 | if self.num_imgs < self.pool_size:
18 | self.num_imgs = self.num_imgs + 1
19 | self.images.append(image)
20 | return_images.append(image)
21 | else:
22 | p = random.uniform(0, 1)
23 | if p > 0.5:
24 | random_id = random.randint(0, self.pool_size-1)
25 | tmp = self.images[random_id].clone()
26 | self.images[random_id] = image
27 | return_images.append(tmp)
28 | else:
29 | return_images.append(image)
30 | return_images = Variable(torch.cat(return_images, 0))
31 | return return_images
32 |
--------------------------------------------------------------------------------
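A minimal usage sketch for the pool above. Note that train_options.py defaults `--pool_size` to 0, in which case `query()` simply returns its input batch unchanged:

```python
# Minimal usage sketch for ImagePool with a nonzero pool size; with the default
# --pool_size=0 from train_options.py, query() just passes the batch through.
import torch
from util.image_pool import ImagePool

pool = ImagePool(pool_size=50)
fake_batch = torch.randn(4, 3, 256, 256)   # stand-in for generator output
mixed_batch = pool.query(fake_batch)       # mix of fresh fakes and previously stored ones
print(mixed_batch.shape)                   # torch.Size([4, 3, 256, 256])
```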
/data/custom_dataset_data_loader.py:
--------------------------------------------------------------------------------
1 | import torch.utils.data
2 | from data.base_data_loader import BaseDataLoader
3 | import data
4 |
5 | def CreateDataset(opt):
6 | '''
7 | dataset = None
8 | if opt.dataset_mode == 'fashion':
9 | from data.fashion_dataset import FashionDataset
10 | dataset = FashionDataset()
11 | else:
12 | from data.aligned_dataset import AlignedDataset
13 | dataset = AlignedDataset()
14 | '''
15 | dataset = data.find_dataset_using_name(opt.dataset_mode)()
16 | print("dataset [%s] was created" % (dataset.name()))
17 | dataset.initialize(opt)
18 | return dataset
19 |
20 | class CustomDatasetDataLoader(BaseDataLoader):
21 | def name(self):
22 | return 'CustomDatasetDataLoader'
23 |
24 | def initialize(self, opt):
25 | BaseDataLoader.initialize(self, opt)
26 | self.dataset = CreateDataset(opt)
27 | self.dataloader = torch.utils.data.DataLoader(
28 | self.dataset,
29 | batch_size=opt.batchSize,
30 | shuffle=(not opt.serial_batches) and opt.isTrain,
31 | num_workers=int(opt.nThreads))
32 |
33 | def load_data(self):
34 | return self.dataloader
35 |
36 | def __len__(self):
37 | return min(len(self.dataset), self.opt.max_dataset_size)
38 |
--------------------------------------------------------------------------------
/options/test_options.py:
--------------------------------------------------------------------------------
1 | from .base_options import BaseOptions
2 |
3 | class TestOptions(BaseOptions):
4 | def initialize(self):
5 | BaseOptions.initialize(self)
6 | self.parser.add_argument('--ntest', type=int, default=float("inf"), help='# of test examples.')
7 | self.parser.add_argument('--results_dir', type=str, default='./results/', help='saves results here.')
8 | self.parser.add_argument('--aspect_ratio', type=float, default=1.0, help='aspect ratio of result images')
9 | self.parser.add_argument('--phase', type=str, default='test', help='train, val, test, etc')
10 | self.parser.add_argument('--which_epoch', type=str, default='latest', help='which epoch to load? set to latest to use latest cached model')
11 | self.parser.add_argument('--how_many', type=int, default=50, help='how many test images to run')
12 | self.parser.add_argument('--cluster_path', type=str, default='features_clustered_010.npy', help='the path for clustered results of encoded features')
13 | self.parser.add_argument('--use_encoded_image', action='store_true', help='if specified, encode the real image to get the feature map')
14 | self.parser.add_argument("--export_onnx", type=str, help="export ONNX model to a given file")
15 | self.parser.add_argument("--engine", type=str, help="run serialized TRT engine")
16 | self.parser.add_argument("--onnx", type=str, help="run ONNX model via TRT")
17 | self.isTrain = False
18 |
--------------------------------------------------------------------------------
/data/__init__.py:
--------------------------------------------------------------------------------
1 | import importlib
2 | import torch.utils.data
3 | from data.base_dataset import BaseDataset
4 |
5 |
6 | def find_dataset_using_name(dataset_name):
7 |     # Given the option --dataset_mode [dataset_name],
8 |     # the file "data/[dataset_name]_dataset.py"
9 |     # will be imported.
10 | dataset_filename = "data." + dataset_name + "_dataset"
11 | datasetlib = importlib.import_module(dataset_filename)
12 |
13 | # In the file, the class called DatasetNameDataset() will
14 | # be instantiated. It has to be a subclass of BaseDataset,
15 | # and it is case-insensitive.
16 | dataset = None
17 | target_dataset_name = dataset_name.replace('_', '') + 'dataset'
18 | for name, cls in datasetlib.__dict__.items():
19 | if name.lower() == target_dataset_name.lower() \
20 | and issubclass(cls, BaseDataset):
21 | dataset = cls
22 |
23 | if dataset is None:
24 | raise ValueError("In %s.py, there should be a subclass of BaseDataset "
25 | "with class name that matches %s in lowercase." %
26 | (dataset_filename, target_dataset_name))
27 |
28 | return dataset
29 |
30 |
31 | def get_option_setter(dataset_name):
32 | dataset_class = find_dataset_using_name(dataset_name)
33 | return dataset_class.modify_commandline_options
34 |
35 | '''
36 | def create_dataloader(opt):
37 | dataset = find_dataset_using_name(opt.dataset_mode)
38 | instance = dataset()
39 | instance.initialize(opt)
40 | print("dataset [%s] of size %d was created" %
41 | (type(instance).__name__, len(instance)))
42 | dataloader = torch.utils.data.DataLoader(
43 | instance,
44 | batch_size=opt.batchSize,
45 | shuffle=not opt.serial_batches,
46 | num_workers=int(opt.nThreads),
47 | drop_last=opt.isTrain
48 | )
49 | return dataloader
50 | '''
--------------------------------------------------------------------------------
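The same convention applies to datasets; a minimal sketch, using the `fashion` mode whose class appears later in this dump:

```python
# Minimal sketch: how '--dataset_mode=fashion' is resolved.
# find_dataset_using_name imports data/fashion_dataset.py and returns FashionDataset,
# matched case-insensitively against 'fashiondataset'.
from data import find_dataset_using_name

dataset_cls = find_dataset_using_name('fashion')
print(dataset_cls.__name__)   # FashionDataset
```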
/data/generate_fashion_datasets.py:
--------------------------------------------------------------------------------
1 | import os
2 | import shutil
3 | from PIL import Image
4 |
5 | IMG_EXTENSIONS = [
6 | '.jpg', '.JPG', '.jpeg', '.JPEG',
7 | '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
8 | ]
9 |
10 | def is_image_file(filename):
11 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
12 |
13 | def make_dataset(dir):
14 | images = []
15 | assert os.path.isdir(dir), '%s is not a valid directory' % dir
16 | new_root = './fashion'
17 | if not os.path.exists(new_root):
18 | os.mkdir(new_root)
19 |
20 | train_root = './fashion/train'
21 | if not os.path.exists(train_root):
22 | os.mkdir(train_root)
23 |
24 | test_root = './fashion/test'
25 | if not os.path.exists(test_root):
26 | os.mkdir(test_root)
27 |
28 | train_images = []
29 | train_f = open('./fashion/train.lst', 'r')
30 | for lines in train_f:
31 | lines = lines.strip()
32 | if lines.endswith('.jpg'):
33 | train_images.append(lines)
34 |
35 | test_images = []
36 | test_f = open('./fashion/test.lst', 'r')
37 | for lines in test_f:
38 | lines = lines.strip()
39 | if lines.endswith('.jpg'):
40 | test_images.append(lines)
41 |
42 | print(train_images, test_images)
43 |
44 |
45 | for root, _, fnames in sorted(os.walk(dir)):
46 | for fname in fnames:
47 | if is_image_file(fname):
48 | path = os.path.join(root, fname)
49 | path_names = path.split('/')
50 | # path_names[2] = path_names[2].replace('_', '')
51 | path_names[3] = path_names[3].replace('_', '')
52 | path_names[4] = path_names[4].split('_')[0] + "_" + "".join(path_names[4].split('_')[1:])
53 | path_names = "".join(path_names)
54 |                 new_path = path_names  # flattened file name, compared against train.lst / test.lst entries
55 |                 img = Image.open(path)
56 |                 imgcrop = img.crop((40, 0, 216, 256))
57 |                 if new_path in train_images:
58 |                     imgcrop.save(os.path.join(train_root, path_names))
59 |                 elif new_path in test_images:
60 |                     imgcrop.save(os.path.join(test_root, path_names))
61 |
62 | make_dataset('./fashion')
--------------------------------------------------------------------------------
/util/html.py:
--------------------------------------------------------------------------------
1 | import dominate
2 | from dominate.tags import *
3 | import os
4 |
5 |
6 | class HTML:
7 | def __init__(self, web_dir, title, refresh=0):
8 | self.title = title
9 | self.web_dir = web_dir
10 | self.img_dir = os.path.join(self.web_dir, 'images')
11 | if not os.path.exists(self.web_dir):
12 | os.makedirs(self.web_dir)
13 | if not os.path.exists(self.img_dir):
14 | os.makedirs(self.img_dir)
15 |
16 | self.doc = dominate.document(title=title)
17 | if refresh > 0:
18 | with self.doc.head:
19 | meta(http_equiv="refresh", content=str(refresh))
20 |
21 | def get_image_dir(self):
22 | return self.img_dir
23 |
24 | def add_header(self, str):
25 | with self.doc:
26 | h3(str)
27 |
28 | def add_table(self, border=1):
29 | self.t = table(border=border, style="table-layout: fixed;")
30 | self.doc.add(self.t)
31 |
32 | def add_images(self, ims, txts, links, width=512):
33 | self.add_table()
34 | with self.t:
35 | with tr():
36 | for im, txt, link in zip(ims, txts, links):
37 | with td(style="word-wrap: break-word;", halign="center", valign="top"):
38 | with p():
39 | with a(href=os.path.join('images', link)):
40 | img(style="width:%dpx" % (width), src=os.path.join('images', im))
41 | br()
42 | p(txt)
43 |
44 | def save(self):
45 | html_file = '%s/index.html' % self.web_dir
46 | f = open(html_file, 'wt')
47 | f.write(self.doc.render())
48 | f.close()
49 |
50 |
51 | if __name__ == '__main__':
52 | html = HTML('web/', 'test_html')
53 | html.add_header('hello world')
54 |
55 | ims = []
56 | txts = []
57 | links = []
58 | for n in range(4):
59 | ims.append('image_%d.jpg' % n)
60 | txts.append('text_%d' % n)
61 | links.append('image_%d.jpg' % n)
62 | html.add_images(ims, txts, links)
63 | html.save()
64 |
--------------------------------------------------------------------------------
/data/image_folder.py:
--------------------------------------------------------------------------------
1 | ###############################################################################
2 | # Code from
3 | # https://github.com/pytorch/vision/blob/master/torchvision/datasets/folder.py
4 | # Modified the original code so that it also loads images from the current
5 | # directory as well as the subdirectories
6 | ###############################################################################
7 | import torch.utils.data as data
8 | from PIL import Image
9 | import os
10 |
11 | IMG_EXTENSIONS = [
12 | '.jpg', '.JPG', '.jpeg', '.JPEG',
13 | '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', '.tiff'
14 | ]
15 |
16 |
17 | def is_image_file(filename):
18 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
19 |
20 |
21 | def make_dataset(dir):
22 | images = []
23 | assert os.path.isdir(dir), '%s is not a valid directory' % dir
24 |
25 | for root, _, fnames in sorted(os.walk(dir)):
26 | for fname in fnames:
27 | if is_image_file(fname):
28 | path = os.path.join(root, fname)
29 | images.append(path)
30 |
31 | return images
32 |
33 |
34 | def default_loader(path):
35 | return Image.open(path).convert('RGB')
36 |
37 |
38 | class ImageFolder(data.Dataset):
39 |
40 | def __init__(self, root, transform=None, return_paths=False,
41 | loader=default_loader):
42 | imgs = make_dataset(root)
43 | if len(imgs) == 0:
44 | raise(RuntimeError("Found 0 images in: " + root + "\n"
45 | "Supported image extensions are: " +
46 | ",".join(IMG_EXTENSIONS)))
47 |
48 | self.root = root
49 | self.imgs = imgs
50 | self.transform = transform
51 | self.return_paths = return_paths
52 | self.loader = loader
53 |
54 | def __getitem__(self, index):
55 | path = self.imgs[index]
56 | img = self.loader(path)
57 | if self.transform is not None:
58 | img = self.transform(img)
59 | if self.return_paths:
60 | return img, path
61 | else:
62 | return img
63 |
64 | def __len__(self):
65 | return len(self.imgs)
66 |
--------------------------------------------------------------------------------
/data/base_dataset.py:
--------------------------------------------------------------------------------
1 | import torch.utils.data as data
2 | from PIL import Image
3 | import torchvision.transforms as transforms
4 | import numpy as np
5 | import random
6 |
7 |
8 | class BaseDataset(data.Dataset):
9 | def __init__(self):
10 | super(BaseDataset, self).__init__()
11 |
12 | def name(self):
13 | return 'BaseDataset'
14 |
15 | def initialize(self, opt):
16 | pass
17 |
18 | def get_params(opt, size):
19 | w, h = size
20 | new_h = h
21 | new_w = w
22 | if opt.resize_or_crop == 'resize_and_crop':
23 | new_h = new_w = opt.loadSize
24 | elif opt.resize_or_crop == 'scale_width_and_crop':
25 | new_w = opt.loadSize
26 | new_h = opt.loadSize * h // w
27 |
28 | x = random.randint(0, np.maximum(0, new_w - opt.fineSize))
29 | y = random.randint(0, np.maximum(0, new_h - opt.fineSize))
30 |
31 | flip = random.random() > 0.5
32 | return {'crop_pos': (x, y), 'flip': flip}
33 |
34 | def get_transform(opt, params, method=Image.BICUBIC, normalize=True):
35 | transform_list = []
36 | if 'resize' in opt.resize_or_crop:
37 | osize = [opt.loadSize, opt.loadSize]
38 | transform_list.append(transforms.Scale(osize, method))
39 | elif 'scale_width' in opt.resize_or_crop:
40 | transform_list.append(transforms.Lambda(lambda img: __scale_width(img, opt.loadSize, method)))
41 |
42 | if 'crop' in opt.resize_or_crop:
43 | transform_list.append(transforms.Lambda(lambda img: __crop(img, params['crop_pos'], opt.fineSize)))
44 |
45 | if opt.resize_or_crop == 'none':
46 | base = float(2 ** opt.n_downsample_global)
47 | if opt.netG == 'local':
48 | base *= (2 ** opt.n_local_enhancers)
49 | transform_list.append(transforms.Lambda(lambda img: __make_power_2(img, base, method)))
50 |
51 | if opt.isTrain and not opt.no_flip:
52 | transform_list.append(transforms.Lambda(lambda img: __flip(img, params['flip'])))
53 |
54 | transform_list += [transforms.ToTensor()]
55 |
56 | if normalize:
57 | transform_list += [transforms.Normalize((0.5, 0.5, 0.5),
58 | (0.5, 0.5, 0.5))]
59 | return transforms.Compose(transform_list)
60 |
61 |
62 | def normalize():
63 | return transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
64 |
65 |
66 | def __make_power_2(img, base, method=Image.BICUBIC):
67 | ow, oh = img.size
68 | h = int(round(oh / base) * base)
69 | w = int(round(ow / base) * base)
70 | if (h == oh) and (w == ow):
71 | return img
72 | return img.resize((w, h), method)
73 |
74 |
75 | def __scale_width(img, target_width, method=Image.BICUBIC):
76 | ow, oh = img.size
77 | if (ow == target_width):
78 | return img
79 | w = target_width
80 | h = int(target_width * oh / ow)
81 | return img.resize((w, h), method)
82 |
83 |
84 | def __crop(img, pos, size):
85 | ow, oh = img.size
86 | x1, y1 = pos
87 | tw = th = size
88 | if (ow > tw or oh > th):
89 | return img.crop((x1, y1, x1 + tw, y1 + th))
90 | return img
91 |
92 |
93 | def __flip(img, flip):
94 | if flip:
95 | return img.transpose(Image.FLIP_LEFT_RIGHT)
96 | return img
97 |
--------------------------------------------------------------------------------
/options/train_options.py:
--------------------------------------------------------------------------------
1 | from .base_options import BaseOptions
2 |
3 | class TrainOptions(BaseOptions):
4 | def initialize(self):
5 | BaseOptions.initialize(self)
6 | # for displays
7 | self.parser.add_argument('--display_freq', type=int, default=200, help='frequency of showing training results on screen')
8 | self.parser.add_argument('--print_freq', type=int, default=200, help='frequency of showing training results on console')
9 | self.parser.add_argument('--save_latest_freq', type=int, default=1000, help='frequency of saving the latest results')
10 | self.parser.add_argument('--save_epoch_freq', type=int, default=1, help='frequency of saving checkpoints at the end of epochs')
11 | self.parser.add_argument('--no_html', action='store_true', help='do not save intermediate training results to [opt.checkpoints_dir]/[opt.name]/web/')
12 | self.parser.add_argument('--debug', action='store_true', help='only do one epoch and displays at each iteration')
13 |
14 | # for training
15 | self.parser.add_argument('--continue_train', action='store_true', help='continue training: load the latest model')
16 | self.parser.add_argument('--load_pretrain', type=str, default='', help='load the pretrained model from the specified location')
17 | self.parser.add_argument('--which_epoch', type=str, default='latest', help='which epoch to load? set to latest to use latest cached model')
18 | self.parser.add_argument('--phase', type=str, default='train', help='train, val, test, etc')
19 | self.parser.add_argument('--niter', type=int, default=100, help='# of iter at starting learning rate')
20 | self.parser.add_argument('--niter_decay', type=int, default=100, help='# of iter to linearly decay learning rate to zero')
21 |         self.parser.add_argument('--iter_start', type=int, default=0, help='epoch count to start from; set from the saved iteration file when resuming training')
22 | self.parser.add_argument('--beta1', type=float, default=0.5, help='momentum term of adam')
23 | self.parser.add_argument('--lr', type=float, default=0.0002, help='initial learning rate for adam')
24 | self.parser.add_argument('--lr_policy', type=str, default='lambda', help='learning rate policy[lambda|step|plateau]')
25 | self.parser.add_argument('--gan_mode', type=str, default='lsgan', choices=['wgan-gp', 'hinge', 'lsgan'])
26 | # for discriminators
27 | self.parser.add_argument('--num_D', type=int, default=1, help='number of discriminators to use')
28 | self.parser.add_argument('--n_layers_D', type=int, default=3, help='only used if which_model_netD==n_layers')
29 | self.parser.add_argument('--ndf', type=int, default=64, help='# of discrim filters in first conv layer')
30 | self.parser.add_argument('--lambda_feat', type=float, default=10.0, help='weight for feature matching loss')
31 | self.parser.add_argument('--no_ganFeat_loss', action='store_true', help='if specified, do *not* use discriminator feature matching loss')
32 | self.parser.add_argument('--no_vgg_loss', action='store_true', help='if specified, do *not* use VGG feature matching loss')
33 | self.parser.add_argument('--pool_size', type=int, default=0, help='the size of image buffer that stores previously generated images')
34 |
35 | self.isTrain = True
36 |
--------------------------------------------------------------------------------
/data/market_dataset.py:
--------------------------------------------------------------------------------
1 | import os.path
2 | from data.base_dataset import BaseDataset, get_params, get_transform, normalize
3 | from data.image_folder import make_dataset
4 | import torchvision.transforms.functional as F
5 | import torchvision.transforms as transforms
6 | from PIL import Image
7 | from util import pose_utils
8 | import pandas as pd
9 | import numpy as np
10 | import torch
11 |
12 | class MarketDataset(BaseDataset):
13 | @staticmethod
14 | def modify_commandline_options(parser, is_train):
15 | if is_train:
16 | parser.set_defaults(load_size=128)
17 | else:
18 | parser.set_defaults(load_size=128)
19 | parser.set_defaults(old_size=(128, 64))
20 | parser.set_defaults(structure_nc=18)
21 | parser.set_defaults(image_nc=3)
22 | return parser
23 |
24 | def initialize(self, opt):
25 | self.opt = opt
26 | self.root = opt.dataroot
27 | self.phase = opt.phase
28 |
29 | # prepare for image (image_dir), image_pair (name_pairs) and bone annotation (annotation_file)
30 | self.image_dir = os.path.join(self.root, self.phase)
31 | self.bone_file = os.path.join(self.root, 'market-annotation-%s.csv' % self.phase)
32 | pairLst = os.path.join(self.root, 'market-pairs-%s.csv' % self.phase)
33 | self.name_pairs = self.init_categories(pairLst)
34 | self.annotation_file = pd.read_csv(self.bone_file, sep=':')
35 | self.annotation_file = self.annotation_file.set_index('name')
36 |
37 | # load image size
38 | if isinstance(opt.loadSize, int):
39 | self.load_size = (128, 64)
40 | else:
41 | self.load_size = opt.loadSize
42 |
43 | # prepare for transformation
44 | transform_list=[]
45 | transform_list.append(transforms.ToTensor())
46 | transform_list.append(transforms.Normalize((0.5, 0.5, 0.5),(0.5, 0.5, 0.5)))
47 | self.trans = transforms.Compose(transform_list)
48 |
49 | def __getitem__(self, index):
50 | # prepare for source image Xs and target image Xt
51 | Xs_name, Xt_name = self.name_pairs[index]
52 | Xs_path = os.path.join(self.image_dir, Xs_name)
53 | Xt_path = os.path.join(self.image_dir, Xt_name)
54 |
55 | Xs = Image.open(Xs_path).convert('RGB')
56 | Xt = Image.open(Xt_path).convert('RGB')
57 |
58 | Xs = F.resize(Xs, self.load_size)
59 | Xt = F.resize(Xt, self.load_size)
60 |
61 | Ps = self.obtain_bone(Xs_name)
62 | Xs = self.trans(Xs)
63 | Pt = self.obtain_bone(Xt_name)
64 | Xt = self.trans(Xt)
65 |
66 | return {'Xs': Xs, 'Ps': Ps, 'Xt': Xt, 'Pt': Pt,
67 | 'Xs_path': Xs_name, 'Xt_path': Xt_name}
68 |
69 | def init_categories(self, pairLst):
70 | pairs_file_train = pd.read_csv(pairLst)
71 | size = len(pairs_file_train)
72 | pairs = []
73 | print('Loading data pairs ...')
74 | for i in range(size):
75 | pair = [pairs_file_train.iloc[i]['from'], pairs_file_train.iloc[i]['to']]
76 | pairs.append(pair)
77 |
78 | print('Loading data pairs finished ...')
79 | return pairs
80 |
81 | def getRandomAffineParam(self):
82 | if self.opt.angle is not False:
83 | angle = np.random.uniform(low=self.opt.angle[0], high=self.opt.angle[1])
84 | else:
85 | angle = 0
86 | if self.opt.scale is not False:
87 | scale = np.random.uniform(low=self.opt.scale[0], high=self.opt.scale[1])
88 | else:
89 | scale = 1
90 | if self.opt.shift is not False:
91 | shift_x = np.random.uniform(low=self.opt.shift[0], high=self.opt.shift[1])
92 | shift_y = np.random.uniform(low=self.opt.shift[0], high=self.opt.shift[1])
93 | else:
94 | shift_x = 0
95 | shift_y = 0
96 | return angle, (shift_x, shift_y), scale
97 |
98 | def obtain_bone(self, name):
99 | string = self.annotation_file.loc[name]
100 | array = pose_utils.load_pose_cords_from_strings(string['keypoints_y'], string['keypoints_x'])
101 | pose = pose_utils.cords_to_map(array, self.load_size, self.opt.old_size)
102 | pose = np.transpose(pose,(2, 0, 1))
103 | pose = torch.Tensor(pose)
104 | return pose
105 |
106 | def obtain_bone_affine(self, name, affine_matrix):
107 | string = self.annotation_file.loc[name]
108 | array = pose_utils.load_pose_cords_from_strings(string['keypoints_y'], string['keypoints_x'])
109 | pose = pose_utils.cords_to_map(array, self.load_size, self.opt.old_size, affine_matrix)
110 | pose = np.transpose(pose,(2, 0, 1))
111 | pose = torch.Tensor(pose)
112 | return pose
113 |
114 | def __len__(self):
115 | return len(self.name_pairs) // self.opt.batchSize * self.opt.batchSize
116 |
117 | def name(self):
118 | return 'MarketDataset'
--------------------------------------------------------------------------------
/data/fashion_dataset.py:
--------------------------------------------------------------------------------
1 | import os.path
2 | from data.base_dataset import BaseDataset, get_params, get_transform, normalize
3 | from data.image_folder import make_dataset
4 | import torchvision.transforms.functional as F
5 | import torchvision.transforms as transforms
6 | from PIL import Image
7 | from util import pose_utils
8 | import pandas as pd
9 | import numpy as np
10 | import torch
11 |
12 | class FashionDataset(BaseDataset):
13 | @staticmethod
14 | def modify_commandline_options(parser, is_train):
15 | if is_train:
16 | parser.set_defaults(load_size=256)
17 | else:
18 | parser.set_defaults(load_size=256)
19 | parser.set_defaults(old_size=(256, 176))
20 | parser.set_defaults(structure_nc=18)
21 | parser.set_defaults(image_nc=3)
22 | return parser
23 |
24 | def initialize(self, opt):
25 | self.opt = opt
26 | self.root = opt.dataroot
27 | self.phase = opt.phase
28 |
29 | # prepare for image (image_dir), image_pair (name_pairs) and bone annotation (annotation_file)
30 | self.image_dir = os.path.join(self.root, self.phase)
31 | self.bone_file = os.path.join(self.root, 'fasion-resize-annotation-%s.csv' % self.phase)
32 | pairLst = os.path.join(self.root, 'fasion-resize-pairs-%s.csv' % self.phase)
33 | self.name_pairs = self.init_categories(pairLst)
34 | self.annotation_file = pd.read_csv(self.bone_file, sep=':')
35 | self.annotation_file = self.annotation_file.set_index('name')
36 |
37 | # load image size
38 | if isinstance(opt.loadSize, int):
39 | self.load_size = (opt.loadSize, opt.loadSize)
40 | else:
41 | self.load_size = opt.loadSize
42 |
43 | # prepare for transformation
44 | transform_list=[]
45 | transform_list.append(transforms.ToTensor())
46 | transform_list.append(transforms.Normalize((0.5, 0.5, 0.5),(0.5, 0.5, 0.5)))
47 | self.trans = transforms.Compose(transform_list)
48 |
49 | def __getitem__(self, index):
50 | # prepare for source image Xs and target image Xt
51 | Xs_name, Xt_name = self.name_pairs[index]
52 | Xs_path = os.path.join(self.image_dir, Xs_name)
53 | Xt_path = os.path.join(self.image_dir, Xt_name)
54 |
55 | Xs = Image.open(Xs_path).convert('RGB')
56 | Xt = Image.open(Xt_path).convert('RGB')
57 |
58 | Xs = F.resize(Xs, self.load_size)
59 | Xt = F.resize(Xt, self.load_size)
60 |
61 | Ps = self.obtain_bone(Xs_name)
62 | Xs = self.trans(Xs)
63 | Pt = self.obtain_bone(Xt_name)
64 | Xt = self.trans(Xt)
65 |
66 | return {'Xs': Xs, 'Ps': Ps, 'Xt': Xt, 'Pt': Pt,
67 | 'Xs_path': Xs_name, 'Xt_path': Xt_name}
68 |
69 | def init_categories(self, pairLst):
70 | pairs_file_train = pd.read_csv(pairLst)
71 | size = len(pairs_file_train)
72 | pairs = []
73 | print('Loading data pairs ...')
74 | for i in range(size):
75 | pair = [pairs_file_train.iloc[i]['from'], pairs_file_train.iloc[i]['to']]
76 | pairs.append(pair)
77 |
78 | print('Loading data pairs finished ...')
79 | return pairs
80 |
81 | def getRandomAffineParam(self):
82 | if self.opt.angle is not False:
83 | angle = np.random.uniform(low=self.opt.angle[0], high=self.opt.angle[1])
84 | else:
85 | angle = 0
86 | if self.opt.scale is not False:
87 | scale = np.random.uniform(low=self.opt.scale[0], high=self.opt.scale[1])
88 | else:
89 | scale = 1
90 | if self.opt.shift is not False:
91 | shift_x = np.random.uniform(low=self.opt.shift[0], high=self.opt.shift[1])
92 | shift_y = np.random.uniform(low=self.opt.shift[0], high=self.opt.shift[1])
93 | else:
94 | shift_x = 0
95 | shift_y = 0
96 | return angle, (shift_x, shift_y), scale
97 |
98 | def obtain_bone(self, name):
99 | string = self.annotation_file.loc[name]
100 | array = pose_utils.load_pose_cords_from_strings(string['keypoints_y'], string['keypoints_x'])
101 | pose = pose_utils.cords_to_map(array, self.load_size, self.opt.old_size)
102 | pose = np.transpose(pose,(2, 0, 1))
103 | pose = torch.Tensor(pose)
104 | return pose
105 |
106 | def obtain_bone_affine(self, name, affine_matrix):
107 | string = self.annotation_file.loc[name]
108 | array = pose_utils.load_pose_cords_from_strings(string['keypoints_y'], string['keypoints_x'])
109 | pose = pose_utils.cords_to_map(array, self.load_size, self.opt.old_size, affine_matrix)
110 | pose = np.transpose(pose,(2, 0, 1))
111 | pose = torch.Tensor(pose)
112 | return pose
113 |
114 | def __len__(self):
115 | return len(self.name_pairs) // self.opt.batchSize * self.opt.batchSize
116 |
117 | def name(self):
118 | return 'FashionDataset'
--------------------------------------------------------------------------------
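A minimal sketch of what one sample from this dataset looks like. It assumes `./dataset/fashion` has been prepared as expected by scripts/test_fashion.sh; the `argparse.Namespace` below is a stand-in for the parsed options and carries only the fields this class actually reads:

```python
# Minimal sketch: inspect one FashionDataset sample, assuming ./dataset/fashion exists
# with the pair/annotation CSVs. The Namespace is a stand-in for the parsed options.
import argparse
from data.fashion_dataset import FashionDataset

opt = argparse.Namespace(dataroot='./dataset/fashion', phase='test',
                         loadSize=256, old_size=(256, 176), batchSize=1,
                         angle=False, scale=False, shift=False)
dataset = FashionDataset()
dataset.initialize(opt)
sample = dataset[0]
print(sample['Xs'].shape, sample['Ps'].shape)      # torch.Size([3, 256, 256]) torch.Size([18, 256, 256])
print(sample['Xs_path'], '->', sample['Xt_path'])  # source/target file names from the pair list
```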
/metrics/inception.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch.nn.functional as F
3 | from torchvision import models
4 |
5 |
6 | class InceptionV3(nn.Module):
7 | """Pretrained InceptionV3 network returning feature maps"""
8 |
9 | # Index of default block of inception to return,
10 | # corresponds to output of final average pooling
11 | DEFAULT_BLOCK_INDEX = 3
12 |
13 | # Maps feature dimensionality to their output blocks indices
14 | BLOCK_INDEX_BY_DIM = {
15 | 64: 0, # First max pooling features
16 |         192: 1,   # Second max pooling features
17 | 768: 2, # Pre-aux classifier features
18 | 2048: 3 # Final average pooling features
19 | }
20 |
21 | def __init__(self,
22 | output_blocks=[DEFAULT_BLOCK_INDEX],
23 | resize_input=True,
24 | normalize_input=True,
25 | requires_grad=False):
26 | """Build pretrained InceptionV3
27 | Parameters
28 | ----------
29 | output_blocks : list of int
30 | Indices of blocks to return features of. Possible values are:
31 | - 0: corresponds to output of first max pooling
32 | - 1: corresponds to output of second max pooling
33 | - 2: corresponds to output which is fed to aux classifier
34 | - 3: corresponds to output of final average pooling
35 | resize_input : bool
36 | If true, bilinearly resizes input to width and height 299 before
37 | feeding input to model. As the network without fully connected
38 | layers is fully convolutional, it should be able to handle inputs
39 | of arbitrary size, so resizing might not be strictly needed
40 | normalize_input : bool
41 | If true, normalizes the input to the statistics the pretrained
42 | Inception network expects
43 | requires_grad : bool
44 | If true, parameters of the model require gradient. Possibly useful
45 | for finetuning the network
46 | """
47 | super(InceptionV3, self).__init__()
48 |
49 | self.resize_input = resize_input
50 | self.normalize_input = normalize_input
51 | self.output_blocks = sorted(output_blocks)
52 | self.last_needed_block = max(output_blocks)
53 |
54 | assert self.last_needed_block <= 3, \
55 | 'Last possible output block index is 3'
56 |
57 | self.blocks = nn.ModuleList()
58 |
59 | inception = models.inception_v3(pretrained=True)
60 |
61 | # Block 0: input to maxpool1
62 | block0 = [
63 | inception.Conv2d_1a_3x3,
64 | inception.Conv2d_2a_3x3,
65 | inception.Conv2d_2b_3x3,
66 | nn.MaxPool2d(kernel_size=3, stride=2)
67 | ]
68 | self.blocks.append(nn.Sequential(*block0))
69 |
70 | # Block 1: maxpool1 to maxpool2
71 | if self.last_needed_block >= 1:
72 | block1 = [
73 | inception.Conv2d_3b_1x1,
74 | inception.Conv2d_4a_3x3,
75 | nn.MaxPool2d(kernel_size=3, stride=2)
76 | ]
77 | self.blocks.append(nn.Sequential(*block1))
78 |
79 | # Block 2: maxpool2 to aux classifier
80 | if self.last_needed_block >= 2:
81 | block2 = [
82 | inception.Mixed_5b,
83 | inception.Mixed_5c,
84 | inception.Mixed_5d,
85 | inception.Mixed_6a,
86 | inception.Mixed_6b,
87 | inception.Mixed_6c,
88 | inception.Mixed_6d,
89 | inception.Mixed_6e,
90 | ]
91 | self.blocks.append(nn.Sequential(*block2))
92 |
93 | # Block 3: aux classifier to final avgpool
94 | if self.last_needed_block >= 3:
95 | block3 = [
96 | inception.Mixed_7a,
97 | inception.Mixed_7b,
98 | inception.Mixed_7c,
99 | nn.AdaptiveAvgPool2d(output_size=(1, 1))
100 | ]
101 | self.blocks.append(nn.Sequential(*block3))
102 |
103 | for param in self.parameters():
104 | param.requires_grad = requires_grad
105 |
106 | def forward(self, inp):
107 | """Get Inception feature maps
108 | Parameters
109 | ----------
110 | inp : torch.autograd.Variable
111 | Input tensor of shape Bx3xHxW. Values are expected to be in
112 | range (0, 1)
113 | Returns
114 | -------
115 | List of torch.autograd.Variable, corresponding to the selected output
116 | block, sorted ascending by index
117 | """
118 | outp = []
119 | x = inp
120 |
121 | if self.resize_input:
122 | x = F.upsample(x, size=(299, 299), mode='bilinear')
123 |
124 | if self.normalize_input:
125 | x = x.clone()
126 | x[:, 0] = x[:, 0] * (0.229 / 0.5) + (0.485 - 0.5) / 0.5
127 | x[:, 1] = x[:, 1] * (0.224 / 0.5) + (0.456 - 0.5) / 0.5
128 | x[:, 2] = x[:, 2] * (0.225 / 0.5) + (0.406 - 0.5) / 0.5
129 |
130 | for idx, block in enumerate(self.blocks):
131 | x = block(x)
132 | if idx in self.output_blocks:
133 | outp.append(x)
134 |
135 | if idx == self.last_needed_block:
136 | break
137 |
138 | return outp
139 |
--------------------------------------------------------------------------------
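A minimal usage sketch for extracting the 2048-dimensional pooling features commonly used for FID; per the docstring above, inputs are expected as Bx3xHxW with values in (0, 1):

```python
# Minimal usage sketch: pull the final average-pooling features (block index 3).
import torch
from metrics.inception import InceptionV3

net = InceptionV3(output_blocks=[3]).eval()
images = torch.rand(2, 3, 256, 256)          # stand-in batch in (0, 1)
with torch.no_grad():
    feats = net(images)[0]                   # shape (2, 2048, 1, 1)
print(feats.view(feats.size(0), -1).shape)   # torch.Size([2, 2048])
```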
/train.py:
--------------------------------------------------------------------------------
1 | import time
2 | import os
3 | import numpy as np
4 | import torch
5 | from torch.autograd import Variable
6 | from collections import OrderedDict
7 | from subprocess import call
8 | import math
9 | def lcm(a, b): return abs(a * b) // math.gcd(a, b) if a and b else 0
10 |
11 | from options.train_options import TrainOptions
12 | from data.data_loader import CreateDataLoader
13 | from models.models import create_model
14 | import util.util as util
15 | from util.visualizer import Visualizer
16 | from torch.utils.tensorboard import SummaryWriter
17 |
18 | opt = TrainOptions().parse()
19 | iter_path = os.path.join(opt.checkpoints_dir, opt.name, 'iter.txt')
20 | if opt.continue_train:
21 | try:
22 | start_epoch, epoch_iter = np.loadtxt(iter_path , delimiter=',', dtype=int)
23 | except:
24 | start_epoch, epoch_iter = 1, 0
25 | print('Resuming from epoch %d at iteration %d' % (start_epoch, epoch_iter))
26 | else:
27 | start_epoch, epoch_iter = 1, 0
28 |
29 | opt.iter_start = start_epoch
30 |
31 | opt.print_freq = lcm(opt.print_freq, opt.batchSize)
32 | if opt.debug:
33 | opt.display_freq = 1
34 | opt.print_freq = 1
35 | opt.niter = 1
36 | opt.niter_decay = 0
37 | opt.max_dataset_size = 10
38 |
39 | data_loader = CreateDataLoader(opt)
40 | dataset = data_loader.load_data()
41 | dataset_size = len(data_loader)
42 | print('#training images = %d' % dataset_size)
43 | writer = SummaryWriter(comment=opt.name)
44 |
45 | model = create_model(opt)
46 | visualizer = Visualizer(opt)
47 |
48 | total_steps = (start_epoch-1) * dataset_size + epoch_iter
49 |
50 | display_delta = total_steps % opt.display_freq
51 | print_delta = total_steps % opt.print_freq
52 | save_delta = total_steps % opt.save_latest_freq
53 |
54 | for epoch in range(start_epoch, opt.niter + opt.niter_decay + 1):
55 | epoch_start_time = time.time()
56 | if epoch != start_epoch:
57 | epoch_iter = epoch_iter % dataset_size
58 | for i, data in enumerate(dataset, start=epoch_iter):
59 | print("epoch: ", epoch, "iter: ", epoch_iter, "total_iteration: ", total_steps, end=" ")
60 | if total_steps % opt.print_freq == print_delta:
61 | iter_start_time = time.time()
62 | total_steps += opt.batchSize
63 | epoch_iter += opt.batchSize
64 |
65 | save_fake = total_steps % opt.display_freq == display_delta
66 |
67 | model.set_input(data)
68 | model.optimize_parameters()
69 |
70 | losses = model.get_current_errors()
71 | for k, v in losses.items():
72 | print(k, ": ", '%.2f' % v, end=" ")
73 | lr_G, lr_D = model.get_current_learning_rate()
74 | print("learning rate G: %.7f" % lr_G, end=" ")
75 | print("learning rate D: %.7f" % lr_D, end=" ")
76 | print('\n')
77 |
78 |
79 | writer.add_scalar('Loss/app_gen_s', losses['app_gen_s'], total_steps)
80 | writer.add_scalar('Loss/content_gen_s', losses['content_gen_s'], total_steps)
81 | writer.add_scalar('Loss/style_gen_s', losses['style_gen_s'], total_steps)
82 | writer.add_scalar('Loss/app_gen_t', losses['app_gen_t'], total_steps)
83 | writer.add_scalar('Loss/ad_gen_t', losses['ad_gen_t'], total_steps)
84 | writer.add_scalar('Loss/dis_img_gen_t', losses['dis_img_gen_t'], total_steps)
85 | writer.add_scalar('Loss/content_gen_t', losses['content_gen_t'], total_steps)
86 | writer.add_scalar('Loss/style_gen_t', losses['style_gen_t'], total_steps)
87 | writer.add_scalar('LR/G', lr_G, total_steps)
88 | writer.add_scalar('LR/D', lr_D, total_steps)
89 |
90 |
91 | ############## Display results and errors ##########
92 | if total_steps % opt.print_freq == print_delta:
93 | losses = model.get_current_errors()
94 | t = (time.time() - iter_start_time) / opt.batchSize
95 | visualizer.print_current_errors(epoch, epoch_iter, total_steps, losses, lr_G, lr_D, t)
96 | if opt.display_id > 0:
97 | visualizer.plot_current_errors(total_steps, losses)
98 |
99 | if total_steps % opt.display_freq == display_delta:
100 | visualizer.display_current_results(model.get_current_visuals(), epoch)
101 | if hasattr(model, 'distribution'):
102 | visualizer.plot_current_distribution(model.get_current_dis())
103 |
104 | if total_steps % opt.save_latest_freq == save_delta:
105 | print('saving the latest model (epoch %d, total_steps %d)' % (epoch, total_steps))
106 | model.save_networks('latest')
107 | if opt.dataset_mode == 'market':
108 | model.save_networks(total_steps)
109 | np.savetxt(iter_path, (epoch, epoch_iter), delimiter=',', fmt='%d')
110 |
111 | if epoch_iter >= dataset_size:
112 | break
113 |
114 | # end of epoch
115 | iter_end_time = time.time()
116 | print('End of epoch %d / %d \t Time Taken: %d sec' %
117 | (epoch, opt.niter + opt.niter_decay, time.time() - epoch_start_time))
118 |
119 | ### save model for this epoch
120 | if epoch % opt.save_epoch_freq == 0 or (epoch > opt.niter and epoch % (opt.save_epoch_freq//2) == 0):
121 | print('saving the model at the end of epoch %d, iters %d' % (epoch, total_steps))
122 | model.save_networks('latest')
123 | model.save_networks(epoch)
124 | np.savetxt(iter_path, (epoch+1, 0), delimiter=',', fmt='%d')
125 |
126 | ### linearly decay learning rate after certain iterations
127 | model.update_learning_rate()
128 |
--------------------------------------------------------------------------------
/util/pose_utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from scipy.ndimage.filters import gaussian_filter
3 | from skimage.draw import circle, line_aa, polygon
4 | import json
5 |
6 | import matplotlib
7 | matplotlib.use('Agg')
8 | import matplotlib.pyplot as plt
9 | import matplotlib.patches as mpatches
10 | from collections import defaultdict
11 | import skimage.measure, skimage.transform
12 | import sys
13 |
14 | LIMB_SEQ = [[1,2], [1,5], [2,3], [3,4], [5,6], [6,7], [1,8], [8,9],
15 | [9,10], [1,11], [11,12], [12,13], [1,0], [0,14], [14,16],
16 | [0,15], [15,17], [2,16], [5,17]]
17 |
18 | COLORS = [[255, 0, 0], [255, 85, 0], [255, 170, 0], [255, 255, 0], [170, 255, 0], [85, 255, 0], [0, 255, 0],
19 | [0, 255, 85], [0, 255, 170], [0, 255, 255], [0, 170, 255], [0, 85, 255], [0, 0, 255], [85, 0, 255],
20 | [170, 0, 255], [255, 0, 255], [255, 0, 170], [255, 0, 85]]
21 |
22 |
23 | LABELS = ['nose', 'neck', 'Rsho', 'Relb', 'Rwri', 'Lsho', 'Lelb', 'Lwri',
24 | 'Rhip', 'Rkne', 'Rank', 'Lhip', 'Lkne', 'Lank', 'Leye', 'Reye', 'Lear', 'Rear']
25 |
26 | MISSING_VALUE = -1
27 |
28 |
29 | def map_to_cord(pose_map, threshold=0.1):
30 | all_peaks = [[] for i in range(18)]
31 | pose_map = pose_map[..., :18]
32 |
33 | y, x, z = np.where(np.logical_and(pose_map == pose_map.max(axis = (0, 1)),
34 | pose_map > threshold))
35 | for x_i, y_i, z_i in zip(x, y, z):
36 | all_peaks[z_i].append([x_i, y_i])
37 |
38 | x_values = []
39 | y_values = []
40 |
41 | for i in range(18):
42 | if len(all_peaks[i]) != 0:
43 | x_values.append(all_peaks[i][0][0])
44 | y_values.append(all_peaks[i][0][1])
45 | else:
46 | x_values.append(MISSING_VALUE)
47 | y_values.append(MISSING_VALUE)
48 |
49 | return np.concatenate([np.expand_dims(y_values, -1), np.expand_dims(x_values, -1)], axis=1)
50 |
51 |
52 | def cords_to_map(cords, img_size, old_size=None, affine_matrix=None, sigma=6):
53 | old_size = img_size if old_size is None else old_size
54 | cords = cords.astype(float)
55 | result = np.zeros(img_size + cords.shape[0:1], dtype='float32')
56 | for i, point in enumerate(cords):
57 | if point[0] == MISSING_VALUE or point[1] == MISSING_VALUE:
58 | continue
59 | point[0] = point[0]/old_size[0] * img_size[0]
60 | point[1] = point[1]/old_size[1] * img_size[1]
61 | if affine_matrix is not None:
62 | point_ =np.dot(affine_matrix, np.matrix([point[1], point[0], 1]).reshape(3,1))
63 | point_0 = int(point_[1])
64 | point_1 = int(point_[0])
65 | else:
66 | point_0 = int(point[0])
67 | point_1 = int(point[1])
68 | xx, yy = np.meshgrid(np.arange(img_size[1]), np.arange(img_size[0]))
69 | result[..., i] = np.exp(-((yy - point_0) ** 2 + (xx - point_1) ** 2) / (2 * sigma ** 2))
70 | return result
71 |
72 |
73 | def draw_pose_from_cords(pose_joints, img_size, radius=2, draw_joints=True):
74 | colors = np.zeros(shape=img_size + (3, ), dtype=np.uint8)
75 | mask = np.zeros(shape=img_size, dtype=bool)
76 |
77 | if draw_joints:
78 | for f, t in LIMB_SEQ:
79 | from_missing = pose_joints[f][0] == MISSING_VALUE or pose_joints[f][1] == MISSING_VALUE
80 | to_missing = pose_joints[t][0] == MISSING_VALUE or pose_joints[t][1] == MISSING_VALUE
81 | if from_missing or to_missing:
82 | continue
83 | yy, xx, val = line_aa(pose_joints[f][0], pose_joints[f][1], pose_joints[t][0], pose_joints[t][1])
84 | colors[yy, xx] = np.expand_dims(val, 1) * 255
85 | mask[yy, xx] = True
86 |
87 | for i, joint in enumerate(pose_joints):
88 | if pose_joints[i][0] == MISSING_VALUE or pose_joints[i][1] == MISSING_VALUE:
89 | continue
90 | yy, xx = circle(joint[0], joint[1], radius=radius, shape=img_size)
91 | colors[yy, xx] = COLORS[i]
92 | mask[yy, xx] = True
93 |
94 | return colors, mask
95 |
96 |
97 | def draw_pose_from_map(pose_map, threshold=0.1, **kwargs):
98 | cords = map_to_cord(pose_map, threshold=threshold)
99 | return draw_pose_from_cords(cords, pose_map.shape[:2], **kwargs)
100 |
101 |
102 | def load_pose_cords_from_strings(y_str, x_str):
103 | y_cords = json.loads(y_str)
104 | x_cords = json.loads(x_str)
105 | return np.concatenate([np.expand_dims(y_cords, -1), np.expand_dims(x_cords, -1)], axis=1)
106 |
107 | def mean_inputation(X):
108 | X = X.copy()
109 | for i in range(X.shape[1]):
110 | for j in range(X.shape[2]):
111 | val = np.mean(X[:, i, j][X[:, i, j] != -1])
112 | X[:, i, j][X[:, i, j] == -1] = val
113 | return X
114 |
115 | def draw_legend():
116 | handles = [mpatches.Patch(color=np.array(color) / 255.0, label=name) for color, name in zip(COLORS, LABELS)]
117 | plt.legend(handles=handles, bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.)
118 |
119 | def produce_ma_mask(kp_array, img_size, point_radius=4):
120 | from skimage.morphology import dilation, erosion, square
121 | mask = np.zeros(shape=img_size, dtype=bool)
122 | limbs = [[2,3], [2,6], [3,4], [4,5], [6,7], [7,8], [2,9], [9,10],
123 | [10,11], [2,12], [12,13], [13,14], [2,1], [1,15], [15,17],
124 | [1,16], [16,18], [2,17], [2,18], [9,12], [12,6], [9,3], [17,18]]
125 | limbs = np.array(limbs) - 1
126 | for f, t in limbs:
127 | from_missing = kp_array[f][0] == MISSING_VALUE or kp_array[f][1] == MISSING_VALUE
128 | to_missing = kp_array[t][0] == MISSING_VALUE or kp_array[t][1] == MISSING_VALUE
129 | if from_missing or to_missing:
130 | continue
131 |
132 | norm_vec = kp_array[f] - kp_array[t]
133 | norm_vec = np.array([-norm_vec[1], norm_vec[0]])
134 | norm_vec = point_radius * norm_vec / np.linalg.norm(norm_vec)
135 |
136 |
137 | vetexes = np.array([
138 | kp_array[f] + norm_vec,
139 | kp_array[f] - norm_vec,
140 | kp_array[t] - norm_vec,
141 | kp_array[t] + norm_vec
142 | ])
143 | yy, xx = polygon(vetexes[:, 0], vetexes[:, 1], shape=img_size)
144 | mask[yy, xx] = True
145 |
146 | for i, joint in enumerate(kp_array):
147 | if kp_array[i][0] == MISSING_VALUE or kp_array[i][1] == MISSING_VALUE:
148 | continue
149 | yy, xx = circle(joint[0], joint[1], radius=point_radius, shape=img_size)
150 | mask[yy, xx] = True
151 |
152 | mask = dilation(mask, square(5))
153 | mask = erosion(mask, square(5))
154 | return mask
155 |
156 | if __name__ == "__main__":
157 | import pandas as pd
158 | from skimage.io import imread
159 | import pylab as plt
160 | import os
161 | i = 5
162 | df = pd.read_csv('data/market-annotation-train.csv', sep=':')
163 |
164 | for index, row in df.iterrows():
165 | pose_cords = load_pose_cords_from_strings(row['keypoints_y'], row['keypoints_x'])
166 |
167 | colors, mask = draw_pose_from_cords(pose_cords, (128, 64))
168 |
169 | mmm = produce_ma_mask(pose_cords, (128, 64)).astype(float)[..., np.newaxis].repeat(3, axis=-1)
170 | print(mmm.shape)
171 | img = imread('data/market-dataset/train/' + row['name'])
172 |
173 | mmm[mask] = colors[mask]
174 |
175 | print (mmm)
176 | plt.subplot(1, 1, 1)
177 | plt.imshow(mmm)
178 | plt.show()
179 |
--------------------------------------------------------------------------------
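A minimal sketch of the two helpers the dataset classes rely on, with hypothetical keypoint strings in the same format as the annotation CSVs (JSON-like lists, -1 marking a missing joint):

```python
# Minimal sketch of load_pose_cords_from_strings + cords_to_map; the keypoint strings
# below are hypothetical but follow the annotation CSV format (-1 = missing joint).
from util import pose_utils

y_str, x_str = '[10, 20, -1]', '[30, 40, -1]'
cords = pose_utils.load_pose_cords_from_strings(y_str, x_str)  # (3, 2) array of (y, x) pairs
heatmap = pose_utils.cords_to_map(cords, (128, 64))            # (128, 64, 3): one Gaussian per joint
print(heatmap.shape, round(float(heatmap.max()), 2))           # (128, 64, 3) 1.0
```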
/options/base_options.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | from util import util
4 | import torch
5 | import data
6 | import models
7 |
8 | class BaseOptions():
9 | def __init__(self):
10 | self.parser = argparse.ArgumentParser()
11 | self.initialized = False
12 |
13 | def initialize(self):
14 | # experiment specifics
15 | self.parser.add_argument('--name', type=str, default='label2city', help='name of the experiment. It decides where to store samples and models')
16 | self.parser.add_argument('--gpu_ids', type=str, default='0', help='gpu ids: e.g. 0 0,1,2, 0,2. use -1 for CPU')
17 | self.parser.add_argument('--checkpoints_dir', type=str, default='./checkpoints', help='models are saved here')
18 | self.parser.add_argument('--model', type=str, default='pix2pixHD', help='which model to use')
19 | self.parser.add_argument('--norm', type=str, default='instance', help='instance normalization or batch normalization')
20 | self.parser.add_argument('--use_dropout', action='store_true', help='use dropout for the generator')
21 | self.parser.add_argument('--data_type', default=32, type=int, choices=[8, 16, 32], help="Supported data type i.e. 8, 16, 32 bit")
22 | self.parser.add_argument('--verbose', action='store_true', default=False, help='toggles verbose')
23 | self.parser.add_argument('--fp16', action='store_true', default=False, help='train with AMP')
24 | self.parser.add_argument('--local_rank', type=int, default=0, help='local rank for distributed training')
25 |
26 | # input/output sizes
27 | self.parser.add_argument('--image_nc', type=int, default=3)
28 | self.parser.add_argument('--pose_nc', type=int, default=18)
29 | self.parser.add_argument('--batchSize', type=int, default=1, help='input batch size')
30 | self.parser.add_argument('--old_size', type=int, default=(256, 176), help='Scale images to this size. The final image will be cropped to --crop_size.')
31 | self.parser.add_argument('--loadSize', type=int, default=256, help='scale images to this size')
32 | self.parser.add_argument('--fineSize', type=int, default=512, help='then crop to this size')
33 | self.parser.add_argument('--label_nc', type=int, default=35, help='# of input label channels')
34 | self.parser.add_argument('--input_nc', type=int, default=3, help='# of input image channels')
35 | self.parser.add_argument('--output_nc', type=int, default=3, help='# of output image channels')
36 |
37 | # for setting inputs
38 | self.parser.add_argument('--dataset_mode', type=str, default='fashion')
39 | self.parser.add_argument('--dataroot', type=str, default='/media/data2/zhangpz/DataSet/Fashion')
40 | self.parser.add_argument('--resize_or_crop', type=str, default='scale_width', help='scaling and cropping of images at load time [resize_and_crop|crop|scale_width|scale_width_and_crop]')
41 | self.parser.add_argument('--serial_batches', action='store_true', help='if true, takes images in order to make batches, otherwise takes them randomly')
42 |         self.parser.add_argument('--no_flip', action='store_true', help='if specified, do not flip the images for data augmentation')
43 | self.parser.add_argument('--nThreads', default=2, type=int, help='# threads for loading data')
44 | self.parser.add_argument('--max_dataset_size', type=int, default=float("inf"), help='Maximum number of samples allowed per dataset. If the dataset directory contains more than max_dataset_size, only a subset is loaded.')
45 |
46 | # for displays
47 | self.parser.add_argument('--display_winsize', type=int, default=512, help='display window size')
48 | self.parser.add_argument('--tf_log', action='store_true', help='if specified, use tensorboard logging. Requires tensorflow installed')
49 |         self.parser.add_argument('--display_id', type=int, default=0, help='window id of the web display') # 1
50 |         self.parser.add_argument('--display_port', type=int, default=8096, help='visdom port of the web display')
51 |         self.parser.add_argument('--display_single_pane_ncols', type=int, default=0,
52 |                                  help='if positive, display all images in a single visdom web panel')
53 |         self.parser.add_argument('--display_env', type=str, default=self.parser.parse_known_args()[0].name.replace('_', ''),
54 |                                  help='the environment of the visdom display')
55 | # for instance-wise features
56 | self.parser.add_argument('--no_instance', action='store_true', help='if specified, do *not* add instance map as input')
57 | self.parser.add_argument('--instance_feat', action='store_true', help='if specified, add encoded instance features as input')
58 | self.parser.add_argument('--label_feat', action='store_true', help='if specified, add encoded label features as input')
59 | self.parser.add_argument('--feat_num', type=int, default=3, help='vector length for encoded features')
60 | self.parser.add_argument('--load_features', action='store_true', help='if specified, load precomputed feature maps')
61 | self.parser.add_argument('--n_downsample_E', type=int, default=4, help='# of downsampling layers in encoder')
62 | self.parser.add_argument('--nef', type=int, default=16, help='# of encoder filters in the first conv layer')
63 | self.parser.add_argument('--n_clusters', type=int, default=10, help='number of clusters for features')
64 |
65 | self.initialized = True
66 |
67 | def parse(self, save=True):
68 | if not self.initialized:
69 | self.initialize()
70 | opt, _ = self.parser.parse_known_args()
71 | # modify the options for different models
72 | model_option_set = models.get_option_setter(opt.model)
73 | self.parser = model_option_set(self.parser, self.isTrain)
74 |
75 | data_option_set = data.get_option_setter(opt.dataset_mode)
76 | self.parser = data_option_set(self.parser, self.isTrain)
77 |
78 | self.opt = self.parser.parse_args()
79 | self.opt.isTrain = self.isTrain # train or test
80 |
81 | if torch.cuda.is_available():
82 | self.opt.device = torch.device("cuda")
83 | torch.backends.cudnn.benchmark = True # cudnn auto-tuner
84 | else:
85 | self.opt.device = torch.device("cpu")
86 |
87 | str_ids = self.opt.gpu_ids.split(',')
88 | self.opt.gpu_ids = []
89 | for str_id in str_ids:
90 | id = int(str_id)
91 | if id >= 0:
92 | self.opt.gpu_ids.append(id)
93 |
94 | # set gpu ids
95 | if len(self.opt.gpu_ids) > 0:
96 | torch.cuda.set_device(self.opt.gpu_ids[0])
97 |
98 | args = vars(self.opt)
99 |
100 | print('------------ Options -------------')
101 | for k, v in sorted(args.items()):
102 | print('%s: %s' % (str(k), str(v)))
103 | print('-------------- End ----------------')
104 |
105 | # save to the disk
106 | expr_dir = os.path.join(self.opt.checkpoints_dir, self.opt.name)
107 | util.mkdirs(expr_dir)
108 | if save and not (self.isTrain and self.opt.continue_train):
109 | name = 'train' if self.isTrain else 'test'
110 | file_name = os.path.join(expr_dir, name+'_opt.txt')
111 | with open(file_name, 'wt') as opt_file:
112 | opt_file.write('------------ Options -------------\n')
113 | for k, v in sorted(args.items()):
114 | opt_file.write('%s: %s\n' % (str(k), str(v)))
115 | opt_file.write('-------------- End ----------------\n')
116 | return self.opt
117 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Dual-task Pose Transformer Network
2 | The source code for our paper "[Exploring Dual-task Correlation for Pose Guided Person Image Generation](https://openaccess.thecvf.com/content/CVPR2022/papers/Zhang_Exploring_Dual-Task_Correlation_for_Pose_Guided_Person_Image_Generation_CVPR_2022_paper.pdf)", Pengze Zhang, Lingxiao Yang, Jianhuang Lai, and Xiaohua Xie, CVPR 2022. Video: [[Chinese](https://www.koushare.com/video/videodetail/35887)] [[English](https://www.youtube.com/watch?v=p9o3lOlZBSE)]
3 |
4 |
5 | ## Abstract
6 |
7 | Pose Guided Person Image Generation (PGPIG) is the task of transforming a person image from the source pose to a given target pose. Most of the existing methods only focus on the ill-posed source-to-target task and fail to capture reasonable texture mapping. To address this problem, we propose a novel Dual-task Pose Transformer Network (DPTN), which introduces an auxiliary task (i.e., source-to-source task) and exploits the dual-task correlation to promote the performance of PGPIG. The DPTN is of a Siamese structure, containing a source-to-source self-reconstruction branch, and a transformation branch for source-to-target generation. By sharing partial weights between them, the knowledge learned by the source-to-source task can effectively assist the source-to-target learning. Furthermore, we bridge the two branches with a proposed Pose Transformer Module (PTM) to adaptively explore the correlation between features from dual tasks. Such correlation can establish the fine-grained mapping of all the pixels between the sources and the targets, and promote the source texture transmission to enhance the details of the generated target images. Extensive experiments show that our DPTN outperforms state-of-the-art methods in terms of both PSNR and LPIPS. In addition, our DPTN only contains 9.79 million parameters, which is significantly smaller than other approaches.
8 |
9 |
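In short, the generator consumes a source image together with the source and target pose maps and returns both the source-to-target result and the auxiliary source-to-source reconstruction. Below is a minimal, illustrative sketch of this interface (assuming the repository root is on `PYTHONPATH`; the parameter values here are chosen for illustration only, see `models/DPTN_model.py` and the options files for the settings actually used in training):

``` python
import torch
from models.networks import DPTNGenerator

# Illustrative settings: 3-channel images and 18-channel pose heatmaps, as in the options.
# Training builds the generator through networks.define_G with the configured options.
net_G = DPTNGenerator(image_nc=3, pose_nc=18, ngf=64, img_f=512, layers=3,
                      norm='instance', activation='LeakyReLU',
                      use_spect=False, use_coord=False, output_nc=3,
                      num_blocks=3, affine=True, nhead=2, num_CABs=2, num_TTBs=2)

source_image = torch.randn(1, 3, 256, 176)   # Xs: source person image
source_pose  = torch.randn(1, 18, 256, 176)  # Ps: source keypoint heatmaps
target_pose  = torch.randn(1, 18, 256, 176)  # Pt: target keypoint heatmaps

# fake_t comes from the source-to-target transformation branch, fake_s from the
# source-to-source self-reconstruction branch; the PTM bridges their features.
fake_t, fake_s = net_G(source_image, source_pose, target_pose)
```

During training the two tasks are weighted with `--t_s_ratio` (see `backward_G` in `models/DPTN_model.py`); at test time only the source-to-target output is kept.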
10 | ## Getting Started
11 |
12 | ### 1) Requirement
13 |
14 | * Python 3.7.9
15 | * PyTorch 1.7.1
16 | * torchvision 0.8.2
17 | * CUDA 11.1
18 | * NVIDIA A100 40GB PCIe
19 |
21 | ### 2) Data Preparation
21 |
22 | Following **[PATN](https://github.com/tengteng95/Pose-Transfer)**, the dataset split files and extracted keypoints files can be obtained as follows:
23 |
24 | **DeepFashion**
25 |
26 |
27 | * Download the DeepFashion dataset **[in-shop clothes retrieval benchmark](http://mmlab.ie.cuhk.edu.hk/projects/DeepFashion/InShopRetrieval.html)**, and put the images under the `./dataset/fashion` directory.
28 |
29 | * Download train/test pairs and train/test keypoint annotations from **[Google Drive](https://drive.google.com/drive/folders/1qZDod3QDD7PaBxnNyHCuLBR7ftTSkSE1?usp=sharing)**, including **fasion-resize-pairs-train.csv, fasion-resize-pairs-test.csv, fasion-resize-annotation-train.csv, fasion-resize-annotation-test.csv, train.lst, test.lst**, and put them under the `./dataset/fashion` directory.
30 |
31 | * Split the raw images into the training set (`./dataset/fashion/train`) and test set (`./dataset/fashion/test`):
32 | ``` bash
33 | python data/generate_fashion_datasets.py
34 | ```
35 |
36 | **Market1501**
37 |
38 | * Download the Market1501 dataset from **[here](http://zheng-lab.cecs.anu.edu.au/Project/project_reid.html)**. Rename **bounding_box_train** and **bounding_box_test** to **train** and **test**, and put them under the `./dataset/market` directory.
39 |
40 | * Download train/test keypoint annotations from **[Google Drive](https://drive.google.com/drive/folders/1zzkimhX_D5gR1G8txTQkPXwdZPRcnrAx?usp=sharing)**, including **market-pairs-train.csv, market-pairs-test.csv, market-annotation-train.csv, market-annotation-test.csv**. Put these files under the `./dataset/market` directory (the expected layout of `./dataset` is sketched below).
41 |
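After these steps, the `./dataset` directory is expected to look roughly as follows (file names as listed above):

```
dataset
├── fashion
│   ├── train/                               # created by data/generate_fashion_datasets.py
│   ├── test/
│   ├── fasion-resize-pairs-train.csv
│   ├── fasion-resize-pairs-test.csv
│   ├── fasion-resize-annotation-train.csv
│   ├── fasion-resize-annotation-test.csv
│   ├── train.lst
│   └── test.lst
└── market
    ├── train/                               # renamed from bounding_box_train
    ├── test/                                # renamed from bounding_box_test
    ├── market-pairs-train.csv
    ├── market-pairs-test.csv
    ├── market-annotation-train.csv
    └── market-annotation-test.csv
```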
42 | ### 3) Train a model
43 |
44 | **DeepFashion**
45 | ``` bash
46 | python train.py --name=DPTN_fashion --model=DPTN --dataset_mode=fashion --dataroot=./dataset/fashion --batchSize 32 --gpu_id=0
47 | ```
48 | **Market1501**
49 |
50 | ``` bash
51 | python train.py --name=DPTN_market --model=DPTN --dataset_mode=market --dataroot=./dataset/market --dis_layers=3 --lambda_g=5 --lambda_rec 2 --t_s_ratio=0.8 --save_latest_freq=10400 --batchSize 32 --gpu_id=0
52 | ```
53 |
54 | ### 4) Test the model
55 |
56 | You can directly download our test results from Google Drive: **[Deepfashion](https://drive.google.com/drive/folders/1Y_Ar7w_CAYRgG2gzBg2vfxTCCen7q7k2?usp=sharing)**, **[Market1501](https://drive.google.com/drive/folders/15UBWEtGAqYaoEREIIeIuD-P4dRgsys19?usp=sharing)**.
57 |
58 | **DeepFashion**
59 | ``` bash
60 | python test.py --name=DPTN_fashion --model=DPTN --dataset_mode=fashion --dataroot=./dataset/fashion --which_epoch latest --results_dir ./results/DPTN_fashion --batchSize 1 --gpu_id=0
61 | ```
62 |
63 | **Market1501**
64 |
65 | ``` bash
66 | python test.py --name=DPTN_market --model=DPTN --dataset_mode=market --dataroot=./dataset/market --which_epoch latest --results_dir=./results/DPTN_market --batchSize 1 --gpu_id=0
67 | ```
68 |
69 | ### 5) Evaluation
70 |
71 | We adopt SSIM, PSNR, FID, LPIPS and a person re-identification (re-id) system for evaluation. Please clone the official repository **[PerceptualSimilarity](https://github.com/richzhang/PerceptualSimilarity/tree/future)** of the LPIPS score, and put the folder PerceptualSimilarity into the folder **[metrics](https://github.com/PangzeCheung/Dual-task-Pose-Transformer-Network/tree/main/metrics)**.
72 |
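For example, the LPIPS dependency can be placed as follows (the `future` branch is the one linked above):

``` bash
git clone -b future https://github.com/richzhang/PerceptualSimilarity.git ./metrics/PerceptualSimilarity
```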
73 | * For SSIM, PSNR, FID and LPIPS:
74 |
75 | **DeepFashion**
76 | ``` bash
77 | python -m metrics.metrics --gt_path=./dataset/fashion/test --distorated_path=./results/DPTN_fashion --fid_real_path=./dataset/fashion/train --name=./fashion
78 | ```
79 |
80 | **Market1501**
81 |
82 | ``` bash
83 | python -m metrics.metrics --gt_path=./dataset/market/test --distorated_path=./results/DPTN_market --fid_real_path=./dataset/market/train --name=./market --market
84 | ```
85 |
86 | * For person re-id system:
87 |
88 | Clone the code of **[fast-reid](https://github.com/JDAI-CV/fast-reid)** into this project (`./fast-reid-master`). Move the **[config](https://drive.google.com/file/d/1xWCnNpcNrgjEMDKuK29Gre3sYEE1yWTV/view?usp=sharing)** and **[loader](https://drive.google.com/file/d/1axMKB7QlYQgo7f1ZWigTh3uLIDvXRxro/view?usp=sharing)** of the DeepFashion dataset to (`./fast-reid-master/configs/Fashion/bagtricks_R50.yml`) and (`./fast-reid-master/fastreid/data/datasets/fashion.py`) respectively. Download the **[pre-trained network](https://drive.google.com/file/d/1Co6NVWN6OSqPVUd7ut8xCwsQQDIOcypV/view?usp=sharing)** and put it under the `./fast-reid-master/logs/Fashion/bagtricks_R50-ibn/` directory. Then launch:
89 |
90 | ``` bash
91 | python ./tools/train_net.py --config-file ./configs/Fashion/bagtricks_R50.yml --eval-only MODEL.WEIGHTS ./logs/Fashion/bagtricks_R50-ibn/model_final.pth MODEL.DEVICE "cuda:0"
92 | ```
93 |
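For reference, the relative paths in the command above assume it is launched from inside `./fast-reid-master`, with roughly this layout:

```
fast-reid-master
├── configs/Fashion/bagtricks_R50.yml                      # downloaded config
├── fastreid/data/datasets/fashion.py                      # downloaded loader
└── logs/Fashion/bagtricks_R50-ibn/model_final.pth         # pre-trained re-id network
```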
94 | ### 6) Pre-trained Model
95 |
96 | Our pre-trained models and logs can be downloaded from Google Drive: **[Deepfashion](https://drive.google.com/drive/folders/12Ufr8jkOwAIGVEamDedJy_ZWPvJZn8WG?usp=sharing)**[**[log](https://drive.google.com/drive/folders/16ZYYl_jVdK8E9FtnQi6oi6JGfBuD2jCt?usp=sharing)**], **[Market1501](https://drive.google.com/drive/folders/1YY_U2pMzLrZMTKoK8oBkMylR6KXnZJKP?usp=sharing)**[**[log](https://drive.google.com/drive/folders/1ujlvhz7JILULRVRJsLruT9ZAz2JCT74G?usp=sharing)**].
97 |
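To test with a downloaded model, place the generator weights where `load_networks` in `models/base_model.py` looks for them, i.e. `<checkpoints_dir>/<name>/<which_epoch>_net_G.pth`. For example, with `--checkpoints_dir=./checkpoints` and the commands above:

```
./checkpoints/DPTN_fashion/latest_net_G.pth
./checkpoints/DPTN_market/latest_net_G.pth
```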
98 | ## Citation
99 |
100 | ```tex
101 | @InProceedings{Zhang_2022_CVPR,
102 | author = {Zhang, Pengze and Yang, Lingxiao and Lai, Jian-Huang and Xie, Xiaohua},
103 | title = {Exploring Dual-Task Correlation for Pose Guided Person Image Generation},
104 | booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
105 | month = {June},
106 | year = {2022},
107 | pages = {7713-7722}
108 | }
109 | ```
110 | ## Acknowledgement
111 |
112 | We build our project based on **[pix2pix](https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix)**. Some dataset preprocessing methods are derived from **[PATN](https://github.com/tengteng95/Pose-Transfer)**.
113 |
114 |
--------------------------------------------------------------------------------
/models/PTM.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import torch
3 | from torch import nn
4 | from .base_function import *
5 |
6 | class PTM(nn.Module):
7 | """
8 | Pose Transformer Module (PTM)
9 | :param d_model: number of channels in input
10 | :param nhead: number of heads in attention module
11 | :param num_CABs: number of CABs
12 | :param num_TTBs: number of TTBs
13 | :param dim_feedforward: dimension in feedforward
14 | :param activation: activation function 'ReLU, SELU, LeakyReLU, PReLU'
15 | :param affine: affine in normalization
16 | :param norm: normalization function 'instance, batch'
17 | """
18 | def __init__(self, d_model=512, nhead=8, num_CABs=6,
19 | num_TTBs=6, dim_feedforward=2048,
20 | activation="LeakyReLU",
21 | affine=True, norm='instance'):
22 | super().__init__()
23 | encoder_layer = CAB(d_model, nhead, dim_feedforward,
24 | activation, affine, norm)
25 | if norm == 'batch':
26 | encoder_norm = None
27 | decoder_norm = nn.BatchNorm1d(d_model, affine=affine)
28 | elif norm == 'instance':
29 | encoder_norm = None
30 | decoder_norm = nn.InstanceNorm1d(d_model, affine=affine)
31 |
32 | self.encoder = CABs(encoder_layer, num_CABs, encoder_norm)
33 |
34 | decoder_layer = TTB(d_model, nhead, dim_feedforward,
35 | activation, affine, norm)
36 |
37 | self.decoder = TTBs(decoder_layer, num_TTBs, decoder_norm)
38 |
39 | self._reset_parameters()
40 |
41 | self.d_model = d_model
42 | self.nhead = nhead
43 |
44 | def _reset_parameters(self):
45 | for p in self.parameters():
46 | if p.dim() > 1:
47 | nn.init.xavier_uniform_(p)
48 |
49 | def forward(self, src, tgt, val, pos_embed=None):
50 | bs, c, h, w = src.shape
51 | src = src.flatten(2).permute(2, 0, 1)  # (B, C, H, W) -> (HW, B, C) for nn.MultiheadAttention
52 | tgt = tgt.flatten(2).permute(2, 0, 1)
53 | val = val.flatten(2).permute(2, 0, 1)
54 | if pos_embed is not None:
55 | pos_embed = pos_embed.flatten(2).permute(2, 0, 1)
56 | memory = self.encoder(src, pos=pos_embed)
57 | hs = self.decoder(tgt, memory, val, pos=pos_embed)
58 | return hs.view(bs, c, h, w)
59 |
60 |
61 | class CABs(nn.Module):
62 | """
63 | Context Augment Blocks (CABs)
64 | :param encoder_layer: CAB
65 | :param num_CABs: number of CABs
66 | :param norm: normalization function 'instance, batch'
67 | """
68 | def __init__(self, encoder_layer, num_CABs, norm=None):
69 | super().__init__()
70 | self.layers = _get_clones(encoder_layer, num_CABs)
71 | self.norm = norm
72 |
73 | def forward(self, src, pos = None):
74 | output = src
75 |
76 | for layer in self.layers:
77 | output = layer(output, pos=pos)
78 |
79 | if self.norm is not None:
80 | output = self.norm(output.permute(1, 2, 0)).permute(2, 0, 1)
81 |
82 | return output
83 |
84 |
85 | class TTBs(nn.Module):
86 | """
87 | Texture Transfer Blocks (TTBs)
88 | :param decoder_layer: TTB
89 | :param num_TTBs: number of TTBs
90 | :param norm: normalization function 'instance, batch'
91 | """
92 | def __init__(self, decoder_layer, num_TTBs, norm=None):
93 | super().__init__()
94 | self.layers = _get_clones(decoder_layer, num_TTBs)
95 | self.norm = norm
96 |
97 | def forward(self, tgt, memory, val, pos = None):
98 | output = tgt
99 |
100 | for layer in self.layers:
101 | output = layer(output, memory, val, pos=pos)
102 |
103 | if self.norm is not None:
104 | output = self.norm(output.permute(1, 2, 0))  # (HW, B, C) -> (B, C, HW); reshaped to (B, C, H, W) in PTM.forward
105 | return output
106 |
107 |
108 | class CAB(nn.Module):
109 | """
110 | Context Augment Block (CAB)
111 | :param d_model: number of channels in input
112 | :param nhead: number of heads in attention module
113 | :param dim_feedforward: dimension in feedforward
114 | :param activation: activation function 'ReLU, SELU, LeakyReLU, PReLU'
115 | :param affine: affine in normalization
116 | :param norm: normalization function 'instance, batch'
117 | """
118 | def __init__(self, d_model, nhead, dim_feedforward=2048,
119 | activation="LeakyReLU", affine=True, norm='instance'):
120 | super().__init__()
121 | self.self_attn = nn.MultiheadAttention(d_model, nhead)
122 | self.linear1 = nn.Linear(d_model, dim_feedforward)
123 | self.linear2 = nn.Linear(dim_feedforward, d_model)
124 |
125 | if norm == 'batch':
126 | self.norm1 = nn.BatchNorm1d(d_model, affine=affine)
127 | self.norm2 = nn.BatchNorm1d(d_model, affine=affine)
128 | else:
129 | self.norm1 = nn.InstanceNorm1d(d_model, affine=affine)
130 | self.norm2 = nn.InstanceNorm1d(d_model, affine=affine)
131 |
132 | self.activation = get_nonlinearity_layer(activation)
133 |
134 | def with_pos_embed(self, tensor, pos):
135 | return tensor if pos is None else tensor + pos
136 |
137 | def forward(self, src, pos = None):
138 | q = k = self.with_pos_embed(src, pos)
139 | src2 = self.self_attn(q, k, value=src)[0]
140 | src = src + src2
141 | src = self.norm1(src.permute(1, 2, 0)).permute(2, 0, 1)
142 | src2 = self.linear2(self.activation(self.linear1(src)))
143 | src = src + src2
144 | src = self.norm2(src.permute(1, 2, 0)).permute(2, 0, 1)
145 | return src
146 |
147 |
148 | class TTB(nn.Module):
149 | """
150 | Texture Transfer Block (TTB)
151 | :param d_model: number of channels in input
152 | :param nhead: number of heads in attention module
153 | :param dim_feedforward: dimension in feedforward
154 | :param activation: activation function 'ReLU, SELU, LeakyReLU, PReLU'
155 | :param affine: affine in normalization
156 | :param norm: normalization function 'instance, batch'
157 | """
158 | def __init__(self, d_model, nhead, dim_feedforward=2048,
159 | activation="LeakyReLU", affine=True, norm='instance'):
160 | super().__init__()
161 | self.self_attn = nn.MultiheadAttention(d_model, nhead)
162 | self.multihead_attn = nn.MultiheadAttention(d_model, nhead)
163 | self.linear1 = nn.Linear(d_model, dim_feedforward)
164 | self.linear2 = nn.Linear(dim_feedforward, d_model)
165 |
166 | if norm == 'batch':
167 | self.norm1 = nn.BatchNorm1d(d_model, affine=affine)
168 | self.norm2 = nn.BatchNorm1d(d_model, affine=affine)
169 | self.norm3 = nn.BatchNorm1d(d_model, affine=affine)
170 | else:
171 | self.norm1 = nn.InstanceNorm1d(d_model, affine=affine)
172 | self.norm2 = nn.InstanceNorm1d(d_model, affine=affine)
173 | self.norm3 = nn.InstanceNorm1d(d_model, affine=affine)
174 |
175 | self.activation = get_nonlinearity_layer(activation)
176 |
177 | def with_pos_embed(self, tensor, pos):
178 | return tensor if pos is None else tensor + pos
179 |
180 | def forward(self, tgt, memory, val, pos = None):
181 | q = k = self.with_pos_embed(tgt, pos)
182 | tgt2 = self.self_attn(q, k, value=tgt)[0]
183 | tgt = tgt + tgt2
184 | tgt = self.norm1(tgt.permute(1, 2, 0)).permute(2, 0, 1)
185 | tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt, pos),
186 | key=self.with_pos_embed(memory, pos),
187 | value=val)[0]
188 | tgt = tgt + tgt2
189 | tgt = self.norm2(tgt.permute(1, 2, 0)).permute(2, 0, 1)
190 | tgt2 = self.linear2(self.activation(self.linear1(tgt)))
191 | tgt = tgt + tgt2
192 | tgt = self.norm3(tgt.permute(1, 2, 0)).permute(2, 0, 1)
193 | return tgt
194 |
195 |
196 | def _get_clones(module, N):
197 | return nn.ModuleList([copy.deepcopy(module) for i in range(N)])
198 |
199 |
200 |
--------------------------------------------------------------------------------
/util/visualizer.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import os
3 | import ntpath
4 | import time
5 | from . import util
6 | from . import html
7 | import scipy.misc
8 | try:
9 | from StringIO import StringIO # Python 2.7
10 | except ImportError:
11 | from io import BytesIO # Python 3.x
12 |
13 | class Visualizer():
14 | def __init__(self, opt):
15 | # self.opt = opt
16 | self.display_id = opt.display_id
17 | self.use_html = opt.isTrain and not opt.no_html
18 | self.win_size = opt.display_winsize
19 | self.name = opt.name
20 | if self.display_id > 0:
21 | import visdom
22 | self.vis = visdom.Visdom(port=opt.display_port, env=opt.display_env)
23 | self.display_single_pane_ncols = opt.display_single_pane_ncols
24 | self.use_html = 1
25 | if self.use_html:
26 | self.web_dir = os.path.join(opt.checkpoints_dir, opt.name, 'web')
27 | self.img_dir = os.path.join(self.web_dir, 'images')
28 | print('create web directory %s...' % self.web_dir)
29 | util.mkdirs([self.web_dir, self.img_dir])
30 | self.log_name = os.path.join(opt.checkpoints_dir, opt.name, 'loss_log.txt')
31 | self.eval_log_name = os.path.join(opt.checkpoints_dir, opt.name, 'eval_log.txt')
32 | with open(self.log_name, "a") as log_file:
33 | now = time.strftime("%c")
34 | log_file.write('================ Training Loss (%s) ================\n' % now)
35 |
36 | # |visuals|: dictionary of images to display or save
37 | def display_current_results(self, visuals, epoch):
38 | if self.display_id > 0: # show images in the browser
39 | if self.display_single_pane_ncols > 0:
40 | h, w = next(iter(visuals.values())).shape[:2]
41 | table_css = """""" % (w, h)
45 | ncols = self.display_single_pane_ncols
46 | title = self.name
47 | label_html = ''
48 | label_html_row = ''
49 | nrows = int(np.ceil(len(visuals.items()) / ncols))
50 | images = []
51 | idx = 0
52 | for label, image_numpy in visuals.items():
53 | label_html_row += '<td>%s</td>' % label
54 | images.append(image_numpy.transpose([2, 0, 1]))
55 | idx += 1
56 | if idx % ncols == 0:
57 | label_html += '<tr>%s</tr>' % label_html_row
58 | label_html_row = ''
59 | white_image = np.ones_like(image_numpy.transpose([2, 0, 1]))*255
60 | while idx % ncols != 0:
61 | images.append(white_image)
62 | label_html_row += '<td></td>'
63 | idx += 1
64 | if label_html_row != '':
65 | label_html += '<tr>%s</tr>' % label_html_row
66 | # pane col = image row
67 | self.vis.images(images, nrow=ncols, win=self.display_id + 1,
68 | padding=2, opts=dict(title=title + ' images'))
69 | label_html = '<table>%s</table>' % label_html
70 | self.vis.text(table_css + label_html, win = self.display_id + 2,
71 | opts=dict(title=title + ' labels'))
72 | else:
73 | idx = 1
74 | for label, image_numpy in visuals.items():
75 | #image_numpy = np.flipud(image_numpy)
76 | self.vis.image(image_numpy.transpose([2,0,1]), opts=dict(title=label),
77 | win=self.display_id + idx)
78 | idx += 1
79 | if self.use_html: # save images to a html file
80 | for label, image_numpy in visuals.items():
81 | img_path = os.path.join(self.img_dir, 'epoch%.3d_%s.png' % (epoch, label))
82 | util.save_image(image_numpy, img_path)
83 | # update website
84 | webpage = html.HTML(self.web_dir, 'Experiment name = %s' % self.name, refresh=1)
85 | for n in range(epoch, 0, -1):
86 | webpage.add_header('epoch [%d]' % n)
87 | ims = []
88 | txts = []
89 | links = []
90 |
91 | for label, image_numpy in visuals.items():
92 | img_path = 'epoch%.3d_%s.png' % (n, label)
93 | ims.append(img_path)
94 | txts.append(label)
95 | links.append(img_path)
96 | webpage.add_images(ims, txts, links, width=self.win_size)
97 | webpage.save()
98 |
99 | # errors: dictionary of error labels and values
100 | def plot_current_errors(self, iters, errors):
101 | if not hasattr(self, 'plot_data'):
102 | self.plot_data = {'X': [], 'Y': [], 'legend': list(errors.keys())}
103 | self.plot_data['X'].append(iters)
104 | self.plot_data['Y'].append([errors[k] for k in self.plot_data['legend']])
105 | '''
106 | self.vis.line(
107 | X=np.stack([np.array(self.plot_data['X'])] * len(self.plot_data['legend']), 1),
108 | Y=np.array(self.plot_data['Y']),
109 | opts={'title': self.name + ' loss over time',
110 | 'legend': self.plot_data['legend'],
111 | 'xlabel': 'iterations',
112 | 'ylabel': 'loss'},
113 | win=self.display_id)
114 | '''
115 |
116 | def plot_current_score(self, iters, scores):
117 | if not hasattr(self, 'plot_score'):
118 | self.plot_score = {'X':[],'Y':[], 'legend':list(scores.keys())}
119 | self.plot_score['X'].append(iters)
120 | self.plot_score['Y'].append([scores[k] for k in self.plot_score['legend']])
121 | '''
122 | self.vis.line(
123 | X=np.stack([np.array(self.plot_score['X'])] * len(self.plot_score['legend']), 1),
124 | Y=np.array(self.plot_score['Y']),
125 | opts={
126 | 'title': self.name + ' Evaluation Score over time',
127 | 'legend': self.plot_score['legend'],
128 | 'xlabel': 'iters',
129 | 'ylabel': 'score'},
130 | win=self.display_id + 29
131 | )
132 | '''
133 |
134 | # statistics distribution: draw data histogram
135 | def plot_current_distribution(self, distribution):
136 | name = list(distribution.keys())
137 | value = np.array(list(distribution.values())).swapaxes(1, 0)
138 | self.vis.boxplot(
139 | X=value,
140 | opts=dict(legend=name),
141 | win=self.display_id+30
142 | )
143 |
144 | # errors: same format as |errors| of plotCurrentErrors
145 | def print_current_errors(self, epoch, iter, i, errors, lr_G, lr_D, t):
146 | message = '(epoch: %d, iters: %d, total iters: %d, time: %.3f) ' % (epoch, iter, i, t)
147 | for k, v in errors.items():
148 | message += '%s: %.3f ' % (k, v)
149 | message += 'learning_rate_g: %.10f' % lr_G
150 | message += ' learning_rate_d: %.10f' % lr_D
151 | print(message)
152 | with open(self.log_name, "a") as log_file:
153 | log_file.write('%s\n' % message)
154 |
155 | def print_current_eval(self, epoch, i, score):
156 | message = '(epoch: %d, iters: %d)' % (epoch, i)
157 | for k, v in score.items():
158 | message += '%s: %.3f ' % (k, v)
159 |
160 | print(message)
161 | with open(self.eval_log_name, "a") as log_file:
162 | log_file.write('%s\n' % message)
163 |
164 | # save image to the disk
165 | def save_images(self, webpage, visuals, image_path):
166 | image_dir = webpage.get_image_dir()
167 | short_path = ntpath.basename(image_path[0])
168 | name = os.path.splitext(short_path)[0]
169 |
170 | webpage.add_header(name)
171 | ims = []
172 | txts = []
173 | links = []
174 |
175 | for label, image_numpy in visuals.items():
176 | image_name = '%s_%s.png' % (name, label)
177 | save_path = os.path.join(image_dir, image_name)
178 | util.save_image(image_numpy, save_path)
179 |
180 | ims.append(image_name)
181 | txts.append(label)
182 | links.append(image_name)
183 | webpage.add_images(ims, txts, links, width=self.win_size)
--------------------------------------------------------------------------------
/models/base_model.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import sys
4 | from collections import OrderedDict
5 | from util import util
6 | from util import pose_utils
7 | import numpy as np
8 | import ntpath
9 | import cv2
10 |
11 | class BaseModel():
12 | def name(self):
13 | return 'BaseModel'
14 |
15 | def __init__(self, opt):
16 | self.opt = opt
17 | self.gpu_ids = opt.gpu_ids
18 | self.isTrain = opt.isTrain
19 | self.Tensor = torch.cuda.FloatTensor if self.gpu_ids else torch.Tensor
20 | self.save_dir = os.path.join(opt.checkpoints_dir, opt.name)
21 |
22 | def set_input(self, input):
23 | self.input = input
24 |
25 | def forward(self):
26 | pass
27 |
28 | # used in test time, no backprop
29 | def test(self):
30 | pass
31 |
32 | def get_image_paths(self):
33 | return self.image_paths
34 |
35 | def optimize_parameters(self):
36 | pass
37 |
38 | def get_current_visuals(self):
39 | """Return visualization images"""
40 | visual_ret = OrderedDict()
41 | for name in self.visual_names:
42 | if isinstance(name, str):
43 | value = getattr(self, name)
44 | if isinstance(value, list):
45 | # visualize multi-scale outputs
46 | for i in range(len(value)):
47 | visual_ret[name + str(i)] = self.convert2im(value[i], name)
48 | else:
49 | visual_ret[name] =self.convert2im(value, name)
50 | return visual_ret
51 |
52 | def convert2im(self, value, name):
53 | if 'label' in name:
54 | convert = getattr(self, 'label2color')
55 | value = convert(value)
56 |
57 | if 'flow' in name: # flow_field
58 | convert = getattr(self, 'flow2color')
59 | value = convert(value)
60 |
61 | if value.size(1) == 18: # bone_map
62 | value = np.transpose(value[0].detach().cpu().numpy(),(1,2,0))
63 | value = pose_utils.draw_pose_from_map(value)[0]
64 | result = value
65 |
66 | elif value.size(1) == 21: # bone_map + color image
67 | value = np.transpose(value[0,-3:,...].detach().cpu().numpy(),(1,2,0))
68 | # value = pose_utils.draw_pose_from_map(value)[0]
69 | result = value.astype(np.uint8)
70 |
71 | else:
72 | result = util.tensor2im(value.data)
73 | return result
74 |
75 | def get_current_errors(self):
76 | """Return training loss"""
77 | errors_ret = OrderedDict()
78 | for name in self.loss_names:
79 | if isinstance(name, str):
80 | errors_ret[name] = getattr(self, 'loss_' + name).item()
81 | return errors_ret
82 |
83 | def save(self, label):
84 | pass
85 |
86 | # save model
87 | def save_networks(self, which_epoch):
88 | """Save all the networks to the disk"""
89 | for name in self.model_names:
90 | if isinstance(name, str):
91 | save_filename = '%s_net_%s.pth' % (which_epoch, name)
92 | save_path = os.path.join(self.save_dir, save_filename)
93 | net = getattr(self, 'net_' + name)
94 | torch.save(net.cpu().state_dict(), save_path)
95 | if len(self.gpu_ids) > 0 and torch.cuda.is_available():
96 | net.cuda()
97 |
98 | # load models
99 | def load_networks(self, which_epoch):
100 | """Load all the networks from the disk"""
101 | for name in self.model_names:
102 | if isinstance(name, str):
103 | filename = '%s_net_%s.pth' % (which_epoch, name)
104 | path = os.path.join(self.save_dir, filename)
105 | net = getattr(self, 'net_' + name)
106 | try:
107 | '''
108 | new_dict = {}
109 | pretrained_dict = torch.load(path)
110 | for k, v in pretrained_dict.items():
111 | if 'transformer' in k:
112 | new_dict[k.replace('transformer', 'PTM')] = v
113 | else:
114 | new_dict[k] = v
115 |
116 | net.load_state_dict(new_dict)
117 | '''
118 | net.load_state_dict(torch.load(path))
119 | print('load %s from %s' % (name, filename))
120 | except FileNotFoundError:
121 | print('could not find a checkpoint for network %s' % name)
122 | continue
123 | except:
124 | pretrained_dict = torch.load(path)
125 | model_dict = net.state_dict()
126 | try:
127 | pretrained_dict_ = {k: v for k, v in pretrained_dict.items() if k in model_dict}
128 | if len(pretrained_dict_) == 0:
129 | pretrained_dict_ = {k.replace('module.', ''): v for k, v in pretrained_dict.items() if
130 | k.replace('module.', '') in model_dict}
131 | if len(pretrained_dict_) == 0:
132 | pretrained_dict_ = {('module.' + k): v for k, v in pretrained_dict.items() if
133 | 'module.' + k in model_dict}
134 |
135 | pretrained_dict = pretrained_dict_
136 | net.load_state_dict(pretrained_dict)
137 | print('Pretrained network %s has extra layers; only loading the layers that are used' % name)
138 | except:
139 | print('Pretrained network %s has fewer layers; The following are not initialized:' % name)
140 | not_initialized = set()
141 | for k, v in pretrained_dict.items():
142 | if v.size() == model_dict[k].size():
143 | model_dict[k] = v
144 |
145 | for k, v in model_dict.items():
146 | if k not in pretrained_dict or v.size() != pretrained_dict[k].size():
147 | # not_initialized.add(k)
148 | not_initialized.add(k.split('.')[0])
149 | print(sorted(not_initialized))
150 | net.load_state_dict(model_dict)
151 | if len(self.gpu_ids) > 0 and torch.cuda.is_available():
152 | net.cuda()
153 | if not self.isTrain:
154 | net.eval()
155 |
156 | def update_learning_rate(self, epoch=None):
157 | """Update learning rate"""
158 | for scheduler in self.schedulers:
159 | if epoch is None:
160 | scheduler.step()
161 | else:
162 | scheduler.step(epoch)
163 | lr = self.optimizers[0].param_groups[0]['lr']
164 | print('learning rate=%.7f' % lr)
165 |
166 | def get_current_learning_rate(self):
167 | lr_G = self.optimizers[0].param_groups[0]['lr']
168 | lr_D = self.optimizers[1].param_groups[0]['lr']
169 | return lr_G, lr_D
170 |
171 | def save_results(self, save_data, old_size, data_name='none', data_ext='jpg'):
172 | """Save the training or testing results to disk"""
173 | img_paths = self.get_image_paths()
174 |
175 | for i in range(save_data.size(0)):
176 | print('process image ...... %s' % img_paths[i])
177 | short_path = ntpath.basename(img_paths[i]) # get image path
178 | name = os.path.splitext(short_path)[0]
179 | img_name = '%s_%s.%s' % (name, data_name, data_ext)
180 |
181 | util.mkdir(self.opt.results_dir)
182 | img_path = os.path.join(self.opt.results_dir, img_name)
183 | img_numpy = util.tensor2im(save_data[i].data)
184 | img_numpy = cv2.resize(img_numpy, (old_size[1], old_size[0]))
185 | util.save_image(img_numpy, img_path)
186 |
187 | def save_chair_results(self, save_data, old_size, img_path, data_name='none', data_ext='jpg'):
188 | """Save the training or testing results to disk"""
189 | img_paths = self.get_image_paths()
190 | print(save_data.shape)
191 | for i in range(save_data.size(0)):
192 | print('process image ...... %s' % img_paths[i])
193 | short_path = ntpath.basename(img_paths[i]) # get image path
194 | name = os.path.splitext(short_path)[0]
195 | img_name = '%s_%s.%s' % (name, data_name, data_ext)
196 |
197 | util.mkdir(self.opt.results_dir)
198 | img_numpy = util.tensor2im(save_data[i].data)
199 | img_numpy = cv2.resize(img_numpy, (old_size[1], old_size[0]))
200 | util.save_image(img_numpy, img_path)
201 |
--------------------------------------------------------------------------------
/models/DPTN_model.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import os
4 | import itertools
5 | from torch.autograd import Variable
6 | from util.image_pool import ImagePool
7 | from .base_model import BaseModel
8 | from . import networks
9 | from . import external_function
10 | from . import base_function
11 |
12 |
13 | class DPTNModel(BaseModel):
14 | def name(self):
15 | return 'DPTNModel'
16 |
17 | @staticmethod
18 | def modify_options(parser, is_train=True):
19 | """Add new options and rewrite default values for existing options"""
20 | parser.add_argument('--init_type', type=str, default='orthogonal', help='initial type')
21 | parser.add_argument('--use_spect_g', action='store_false', help='use spectral normalization in the generator')
22 | parser.add_argument('--use_spect_d', action='store_false', help='use spectral normalization in the discriminator')
23 | parser.add_argument('--use_coord', action='store_true', help='use coordconv')
24 | parser.add_argument('--lambda_style', type=float, default=500, help='weight for the VGG19 style loss')
25 | parser.add_argument('--lambda_content', type=float, default=0.5, help='weight for the VGG19 content loss')
26 | parser.add_argument('--layers_g', type=int, default=3, help='number of layers in G')
27 | parser.add_argument('--save_input', action='store_true', help="whether save the input images when testing")
28 | parser.add_argument('--num_blocks', type=int, default=3, help="number of resblocks")
29 | parser.add_argument('--affine', action='store_true', default=True, help="affine in PTM")
30 | parser.add_argument('--nhead', type=int, default=2, help="number of heads in PTM")
31 | parser.add_argument('--num_CABs', type=int, default=2, help="number of CABs in PTM")
32 | parser.add_argument('--num_TTBs', type=int, default=2, help="number of TTBs in PTM")
33 |
34 | # if is_train:
35 | parser.add_argument('--ratio_g2d', type=float, default=0.1, help='learning rate ratio G to D')
36 | parser.add_argument('--lambda_rec', type=float, default=5.0, help='weight for image reconstruction loss')
37 | parser.add_argument('--lambda_g', type=float, default=2.0, help='weight for generation loss')
38 | parser.add_argument('--t_s_ratio', type=float, default=0.5, help='loss ratio between dual tasks')
39 | parser.add_argument('--dis_layers', type=int, default=4, help='number of layers in D')
40 | parser.set_defaults(use_spect_g=False)
41 | parser.set_defaults(use_spect_d=True)
42 | return parser
43 |
44 | def __init__(self, opt):
45 | BaseModel.__init__(self, opt)
46 | self.old_size = opt.old_size
47 | self.t_s_ratio = opt.t_s_ratio
48 | self.loss_names = ['app_gen_s', 'content_gen_s', 'style_gen_s', 'app_gen_t', 'ad_gen_t', 'dis_img_gen_t', 'content_gen_t', 'style_gen_t']
49 | self.model_names = ['G']
50 | self.visual_names = ['source_image', 'source_pose', 'target_image', 'target_pose', 'fake_image_s', 'fake_image_t']
51 |
52 | self.net_G = networks.define_G(opt, image_nc=opt.image_nc, pose_nc=opt.structure_nc, ngf=64, img_f=512,
53 | encoder_layer=3, norm=opt.norm, activation='LeakyReLU',
54 | use_spect=opt.use_spect_g, use_coord=opt.use_coord, output_nc=3, num_blocks=3, affine=True, nhead=opt.nhead, num_CABs=opt.num_CABs, num_TTBs=opt.num_TTBs)
55 |
56 | # Discriminator network
57 | if self.isTrain:
58 | self.model_names = ['G', 'D']
59 | self.net_D = networks.define_D(opt, ndf=32, img_f=128, layers=opt.dis_layers, use_spect=opt.use_spect_d)
60 |
61 | if self.opt.verbose:
62 | print('---------- Networks initialized -------------')
63 | # set loss functions and optimizers
64 | if self.isTrain:
65 | if opt.pool_size > 0 and (len(self.gpu_ids)) > 1:
66 | raise NotImplementedError("Fake Pool Not Implemented for MultiGPU")
67 | #self.fake_pool = ImagePool(opt.pool_size)
68 | self.old_lr = opt.lr
69 |
70 | self.GANloss = external_function.GANLoss(opt.gan_mode).to(opt.device)
71 | self.L1loss = torch.nn.L1Loss()
72 | self.Vggloss = external_function.VGGLoss().to(opt.device)
73 |
74 | # define the optimizer
75 | self.optimizer_G = torch.optim.Adam(itertools.chain(
76 | filter(lambda p: p.requires_grad, self.net_G.parameters())),
77 | lr=opt.lr, betas=(opt.beta1, 0.999))
78 | self.optimizers = []
79 | self.optimizers.append(self.optimizer_G)
80 | self.optimizer_D = torch.optim.Adam(itertools.chain(
81 | filter(lambda p: p.requires_grad, self.net_D.parameters())),
82 | lr=opt.lr * opt.ratio_g2d, betas=(opt.beta1, 0.999))
83 | self.optimizers.append(self.optimizer_D)
84 |
85 | self.schedulers = [base_function.get_scheduler(optimizer, opt) for optimizer in self.optimizers]
86 | else:
87 | self.net_G.eval()
88 |
89 | if not self.isTrain or opt.continue_train:
90 | print('model resumed from latest')
91 | self.load_networks(opt.which_epoch)
92 |
93 | def set_input(self, input):
94 | self.input = input
95 | source_image, source_pose = input['Xs'], input['Ps']
96 | target_image, target_pose = input['Xt'], input['Pt']
97 | if len(self.gpu_ids) > 0:
98 | self.source_image = source_image.cuda()
99 | self.source_pose = source_pose.cuda()
100 | self.target_image = target_image.cuda()
101 | self.target_pose = target_pose.cuda()
102 |
103 | self.image_paths = []
104 | for i in range(self.source_image.size(0)):
105 | self.image_paths.append(os.path.splitext(input['Xs_path'][i])[0] + '_2_' + input['Xt_path'][i])
106 |
107 | def forward(self):
108 | # Encode Inputs
109 | self.fake_image_t, self.fake_image_s = self.net_G(self.source_image, self.source_pose, self.target_pose)
110 |
111 | def test(self):
112 | """Forward function used in test time"""
113 | fake_image_t, fake_image_s = self.net_G(self.source_image, self.source_pose, self.target_pose, False)
114 | self.save_results(fake_image_t, self.old_size, data_name='vis')
115 |
116 | def backward_D_basic(self, netD, real, fake):
117 | # Real
118 | D_real = netD(real)
119 | D_real_loss = self.GANloss(D_real, True, True)
120 | # fake
121 | D_fake = netD(fake.detach())
122 | D_fake_loss = self.GANloss(D_fake, False, True)
123 | # loss for discriminator
124 | D_loss = (D_real_loss + D_fake_loss) * 0.5
125 | # gradient penalty for wgan-gp
126 | if self.opt.gan_mode == 'wgangp':
127 | gradient_penalty, gradients = external_function.cal_gradient_penalty(netD, real, fake.detach())
128 | D_loss += gradient_penalty
129 |
130 | return D_loss
131 |
132 | def backward_D(self):
133 | base_function._unfreeze(self.net_D)
134 | self.loss_dis_img_gen_t = self.backward_D_basic(self.net_D, self.target_image, self.fake_image_t)
135 | D_loss = self.loss_dis_img_gen_t
136 | D_loss.backward()
137 |
138 | def backward_G_basic(self, fake_image, target_image, use_d):
139 | # Calculate reconstruction loss
140 | loss_app_gen = self.L1loss(fake_image, target_image)
141 | loss_app_gen = loss_app_gen * self.opt.lambda_rec
142 |
143 | # Calculate GAN loss
144 | loss_ad_gen = None
145 | if use_d:
146 | base_function._freeze(self.net_D)
147 | D_fake = self.net_D(fake_image)
148 | loss_ad_gen = self.GANloss(D_fake, True, False) * self.opt.lambda_g
149 |
150 | # Calculate perceptual loss
151 | loss_content_gen, loss_style_gen = self.Vggloss(fake_image, target_image)
152 | loss_style_gen = loss_style_gen * self.opt.lambda_style
153 | loss_content_gen = loss_content_gen * self.opt.lambda_content
154 |
155 | return loss_app_gen, loss_ad_gen, loss_style_gen, loss_content_gen
156 |
157 | def backward_G(self):
158 | base_function._unfreeze(self.net_D)
159 |
160 | self.loss_app_gen_t, self.loss_ad_gen_t, self.loss_style_gen_t, self.loss_content_gen_t = self.backward_G_basic(self.fake_image_t, self.target_image, use_d = True)
161 |
162 | self.loss_app_gen_s, self.loss_ad_gen_s, self.loss_style_gen_s, self.loss_content_gen_s = self.backward_G_basic(self.fake_image_s, self.source_image, use_d = False)
163 | G_loss = self.t_s_ratio*(self.loss_app_gen_t+self.loss_style_gen_t+self.loss_content_gen_t) + (1-self.t_s_ratio)*(self.loss_app_gen_s+self.loss_style_gen_s+self.loss_content_gen_s)+self.loss_ad_gen_t
164 | G_loss.backward()
165 |
166 | def optimize_parameters(self):
167 | self.forward()
168 |
169 | self.optimizer_D.zero_grad()
170 | self.backward_D()
171 | self.optimizer_D.step()
172 |
173 | self.optimizer_G.zero_grad()
174 | self.backward_G()
175 | self.optimizer_G.step()
176 |
177 |
--------------------------------------------------------------------------------
/util/util.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 | import torch
3 | import numpy as np
4 | from PIL import Image
5 | import numpy as np
6 | import os
7 | import imageio
8 |
9 | # Converts a Tensor into a Numpy array
10 | # |imtype|: the desired type of the converted numpy array
11 | def tensor2im(image_tensor, imtype=np.uint8, normalize=True):
12 | if isinstance(image_tensor, list):
13 | image_numpy = []
14 | for i in range(len(image_tensor)):
15 | image_numpy.append(tensor2im(image_tensor[i], imtype, normalize))
16 | return image_numpy
17 | if image_tensor.dim() == 3:
18 | image_numpy = image_tensor.cpu().float().numpy()
19 | else:
20 | image_numpy = image_tensor[0].cpu().float().numpy()
21 | if normalize:
22 | image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0
23 | else:
24 | image_numpy = np.transpose(image_numpy, (1, 2, 0)) * 255.0
25 | image_numpy = np.clip(image_numpy, 0, 255)
26 | if image_numpy.shape[2] == 1 or image_numpy.shape[2] > 3:
27 | image_numpy = image_numpy[:,:,0]
28 | return image_numpy.astype(imtype)
29 |
30 | # Converts a one-hot tensor into a colorful label map
31 | def tensor2label(label_tensor, n_label, imtype=np.uint8):
32 | if n_label == 0:
33 | return tensor2im(label_tensor, imtype)
34 | label_tensor = label_tensor.cpu().float()
35 | if label_tensor.size()[0] > 1:
36 | label_tensor = label_tensor.max(0, keepdim=True)[1]
37 | label_tensor = Colorize(n_label)(label_tensor)
38 | label_numpy = np.transpose(label_tensor.numpy(), (1, 2, 0))
39 | return label_numpy.astype(imtype)
40 |
41 | def save_image(image_numpy, image_path):
42 | image_pil = Image.fromarray(image_numpy)
43 | image_pil.save(image_path)
44 |
45 | def mkdirs(paths):
46 | if isinstance(paths, list) and not isinstance(paths, str):
47 | for path in paths:
48 | mkdir(path)
49 | else:
50 | mkdir(paths)
51 |
52 | def mkdir(path):
53 | if not os.path.exists(path):
54 | os.makedirs(path)
55 |
56 | ###############################################################################
57 | # Code from
58 | # https://github.com/ycszen/pytorch-seg/blob/master/transform.py
59 | # Modified so it complies with the Cityscapes label map colors
60 | ###############################################################################
61 | def uint82bin(n, count=8):
62 | """returns the binary of integer n, count refers to amount of bits"""
63 | return ''.join([str((n >> y) & 1) for y in range(count-1, -1, -1)])
64 |
65 | def labelcolormap(N):
66 | if N == 35: # cityscape
67 | cmap = np.array([( 0, 0, 0), ( 0, 0, 0), ( 0, 0, 0), ( 0, 0, 0), ( 0, 0, 0), (111, 74, 0), ( 81, 0, 81),
68 | (128, 64,128), (244, 35,232), (250,170,160), (230,150,140), ( 70, 70, 70), (102,102,156), (190,153,153),
69 | (180,165,180), (150,100,100), (150,120, 90), (153,153,153), (153,153,153), (250,170, 30), (220,220, 0),
70 | (107,142, 35), (152,251,152), ( 70,130,180), (220, 20, 60), (255, 0, 0), ( 0, 0,142), ( 0, 0, 70),
71 | ( 0, 60,100), ( 0, 0, 90), ( 0, 0,110), ( 0, 80,100), ( 0, 0,230), (119, 11, 32), ( 0, 0,142)],
72 | dtype=np.uint8)
73 | else:
74 | cmap = np.zeros((N, 3), dtype=np.uint8)
75 | for i in range(N):
76 | r, g, b = 0, 0, 0
77 | id = i
78 | for j in range(7):
79 | str_id = uint82bin(id)
80 | r = r ^ (np.uint8(str_id[-1]) << (7-j))
81 | g = g ^ (np.uint8(str_id[-2]) << (7-j))
82 | b = b ^ (np.uint8(str_id[-3]) << (7-j))
83 | id = id >> 3
84 | cmap[i, 0] = r
85 | cmap[i, 1] = g
86 | cmap[i, 2] = b
87 | return cmap
88 |
89 | class Colorize(object):
90 | def __init__(self, n=35):
91 | self.cmap = labelcolormap(n)
92 | self.cmap = torch.from_numpy(self.cmap[:n])
93 |
94 | def __call__(self, gray_image):
95 | size = gray_image.size()
96 | color_image = torch.ByteTensor(3, size[1], size[2]).fill_(0)
97 |
98 | for label in range(0, len(self.cmap)):
99 | mask = (label == gray_image[0]).cpu()
100 | color_image[0][mask] = self.cmap[label][0]
101 | color_image[1][mask] = self.cmap[label][1]
102 | color_image[2][mask] = self.cmap[label][2]
103 |
104 | return color_image
105 |
106 | def make_colorwheel():
107 | '''
108 | Generates a color wheel for optical flow visualization as presented in:
109 | Baker et al. "A Database and Evaluation Methodology for Optical Flow" (ICCV, 2007)
110 | URL: http://vision.middlebury.edu/flow/flowEval-iccv07.pdf
111 | According to the C++ source code of Daniel Scharstein
112 | According to the Matlab source code of Deqing Sun
113 | '''
114 | RY = 15
115 | YG = 6
116 | GC = 4
117 | CB = 11
118 | BM = 13
119 | MR = 6
120 |
121 | ncols = RY + YG + GC + CB + BM + MR
122 | colorwheel = np.zeros((ncols, 3))
123 | col = 0
124 |
125 | # RY
126 | colorwheel[0:RY, 0] = 255
127 | colorwheel[0:RY, 1] = np.floor(255 * np.arange(0, RY) / RY)
128 | col = col + RY
129 | # YG
130 | colorwheel[col:col + YG, 0] = 255 - np.floor(255 * np.arange(0, YG) / YG)
131 | colorwheel[col:col + YG, 1] = 255
132 | col = col + YG
133 | # GC
134 | colorwheel[col:col + GC, 1] = 255
135 | colorwheel[col:col + GC, 2] = np.floor(255 * np.arange(0, GC) / GC)
136 | col = col + GC
137 | # CB
138 | colorwheel[col:col + CB, 1] = 255 - np.floor(255 * np.arange(CB) / CB)
139 | colorwheel[col:col + CB, 2] = 255
140 | col = col + CB
141 | # BM
142 | colorwheel[col:col + BM, 2] = 255
143 | colorwheel[col:col + BM, 0] = np.floor(255 * np.arange(0, BM) / BM)
144 | col = col + BM
145 | # MR
146 | colorwheel[col:col + MR, 2] = 255 - np.floor(255 * np.arange(MR) / MR)
147 | colorwheel[col:col + MR, 0] = 255
148 | return colorwheel
149 |
150 |
151 | class flow2color():
152 | # code from: https://github.com/tomrunia/OpticalFlow_Visualization
153 | # MIT License
154 | #
155 | # Copyright (c) 2018 Tom Runia
156 | #
157 | # Permission is hereby granted, free of charge, to any person obtaining a copy
158 | # of this software and associated documentation files (the "Software"), to deal
159 | # in the Software without restriction, including without limitation the rights
160 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
161 | # copies of the Software, and to permit persons to whom the Software is
162 | # furnished to do so, subject to conditions.
163 | #
164 | # Author: Tom Runia
165 | # Date Created: 2018-08-03
166 | def __init__(self):
167 | self.colorwheel = make_colorwheel()
168 |
169 | def flow_compute_color(self, u, v, convert_to_bgr=False):
170 | '''
171 | Applies the flow color wheel to (possibly clipped) flow components u and v.
172 | According to the C++ source code of Daniel Scharstein
173 | According to the Matlab source code of Deqing Sun
174 | :param u: np.ndarray, input horizontal flow
175 | :param v: np.ndarray, input vertical flow
176 | :param convert_to_bgr: bool, whether to change ordering and output BGR instead of RGB
177 | :return:
178 | '''
179 | flow_image = np.zeros((u.shape[0], u.shape[1], 3), np.uint8)
180 | ncols = self.colorwheel.shape[0]
181 |
182 | rad = np.sqrt(np.square(u) + np.square(v))
183 | a = np.arctan2(-v, -u) / np.pi
184 | fk = (a + 1) / 2 * (ncols - 1)
185 | k0 = np.floor(fk).astype(np.int32)
186 | k1 = k0 + 1
187 | k1[k1 == ncols] = 0
188 | f = fk - k0
189 |
190 | for i in range(self.colorwheel.shape[1]):
191 | tmp = self.colorwheel[:, i]
192 | col0 = tmp[k0] / 255.0
193 | col1 = tmp[k1] / 255.0
194 | col = (1 - f) * col0 + f * col1
195 |
196 | idx = (rad <= 1)
197 | col[idx] = 1 - rad[idx] * (1 - col[idx])
198 | col[~idx] = col[~idx] * 0.75 # out of range?
199 |
200 | # Note the 2-i => BGR instead of RGB
201 | ch_idx = 2 - i if convert_to_bgr else i
202 | flow_image[:, :, ch_idx] = np.floor(255 * col)
203 |
204 | return flow_image
205 |
206 | def __call__(self, flow_uv, clip_flow=None, convert_to_bgr=False):
207 | '''
208 | Expects a two dimensional flow image of shape [H,W,2]
209 | According to the C++ source code of Daniel Scharstein
210 | According to the Matlab source code of Deqing Sun
211 | :param flow_uv: np.ndarray of shape [H,W,2]
212 | :param clip_flow: float, maximum clipping value for flow
213 | :return:
214 | '''
215 | if len(flow_uv.size()) != 3:
216 | flow_uv = flow_uv[0]
217 | flow_uv = flow_uv.permute(1, 2, 0).cpu().detach().numpy()
218 |
219 | assert flow_uv.ndim == 3, 'input flow must have three dimensions'
220 | assert flow_uv.shape[2] == 2, 'input flow must have shape [H,W,2]'
221 |
222 | if clip_flow is not None:
223 | flow_uv = np.clip(flow_uv, 0, clip_flow)
224 |
225 | u = flow_uv[:, :, 1]
226 | v = flow_uv[:, :, 0]
227 |
228 | rad = np.sqrt(np.square(u) + np.square(v))
229 | rad_max = np.max(rad)
230 |
231 | epsilon = 1e-5
232 | u = u / (rad_max + epsilon)
233 | v = v / (rad_max + epsilon)
234 | image = self.flow_compute_color(u, v, convert_to_bgr)
235 | image = torch.tensor(image).float().permute(2, 0, 1) / 255.0 * 2 - 1
236 | return image
237 |
238 |
239 | def save_image(image_numpy, image_path):  # overrides the PIL-based save_image defined above; saves with imageio
240 | if image_numpy.shape[2] == 1:
241 | image_numpy = image_numpy.reshape(image_numpy.shape[0], image_numpy.shape[1])
242 | #image_numpy = cv2.resize(image_numpy, (176,256))
243 |
244 | imageio.imwrite(image_path, image_numpy)
--------------------------------------------------------------------------------
/models/networks.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import functools
4 | from torch.autograd import Variable
5 | import numpy as np
6 | from .base_function import *
7 | from .PTM import PTM
8 |
9 |
10 | ###############################################################################
11 | # Functions
12 | ###############################################################################
13 | def define_G(opt, image_nc, pose_nc, ngf=64, img_f=1024, encoder_layer=3, norm='batch',
14 | activation='ReLU', use_spect=True, use_coord=False, output_nc=3, num_blocks=3, affine=True, nhead=2, num_CABs=2, num_TTBs=2):
15 | print(opt.model)
16 | if opt.model == 'DPTN':
17 | netG = DPTNGenerator(image_nc, pose_nc, ngf, img_f, encoder_layer, norm, activation, use_spect, use_coord, output_nc, num_blocks, affine, nhead, num_CABs, num_TTBs)
18 | else:
19 | raise NotImplementedError('generator not implemented!')
20 | return init_net(netG, opt.init_type, opt.gpu_ids)
21 |
22 |
23 | def define_D(opt, input_nc=3, ndf=64, img_f=1024, layers=3, norm='none', activation='LeakyReLU', use_spect=True):
24 | netD = ResDiscriminator(input_nc, ndf, img_f, layers, norm, activation, use_spect)
25 | return init_net(netD, opt.init_type, opt.gpu_ids)
26 |
27 |
28 | def print_network(net):
29 | if isinstance(net, list):
30 | net = net[0]
31 | num_params = 0
32 | for param in net.parameters():
33 | num_params += param.numel()
34 | print(net)
35 | print('Total number of parameters: %d' % num_params)
36 |
37 |
38 | ##############################################################################
39 | # Generator
40 | ##############################################################################
41 | class SourceEncoder(nn.Module):
42 | """
43 | Source Image Encoder (En_s)
44 | :param image_nc: number of channels in input image
45 | :param ngf: base filter channel
46 | :param img_f: the largest feature channels
47 | :param encoder_layer: encoder layers
48 | :param norm: normalization function 'instance, batch, group'
49 | :param activation: activation function 'ReLU, SELU, LeakyReLU, PReLU'
50 | :param use_spect: use spectral normalization
51 | :param use_coord: use coordConv operation
52 | """
53 | def __init__(self, image_nc, ngf=64, img_f=1024, encoder_layer=3, norm='batch',
54 | activation='ReLU', use_spect=True, use_coord=False):
55 | super(SourceEncoder, self).__init__()
56 |
57 | self.encoder_layer = encoder_layer
58 |
59 | norm_layer = get_norm_layer(norm_type=norm)
60 | nonlinearity = get_nonlinearity_layer(activation_type=activation)
61 | input_nc = image_nc
62 |
63 | self.block0 = EncoderBlockOptimized(input_nc, ngf, norm_layer,
64 | nonlinearity, use_spect, use_coord)
65 | mult = 1
66 | for i in range(encoder_layer - 1):
67 | mult_prev = mult
68 | mult = min(2 ** (i + 1), img_f // ngf)
69 | block = EncoderBlock(ngf * mult_prev, ngf * mult, norm_layer,
70 | nonlinearity, use_spect, use_coord)
71 | setattr(self, 'encoder' + str(i), block)
72 |
73 | def forward(self, source):
74 | inputs = source
75 | out = self.block0(inputs)
76 | for i in range(self.encoder_layer - 1):
77 | model = getattr(self, 'encoder' + str(i))
78 | out = model(out)
79 | return out
80 |
81 |
82 | class DPTNGenerator(nn.Module):
83 | """
84 | Dual-task Pose Transformer Network (DPTN)
85 | :param image_nc: number of channels in input image
86 | :param pose_nc: number of channels in input pose
87 | :param ngf: base filter channel
88 | :param img_f: the largest feature channels
89 | :param layers: down and up sample layers
90 | :param norm: normalization function 'instance, batch, group'
91 | :param activation: activation function 'ReLU, SELU, LeakyReLU, PReLU'
92 | :param use_spect: use spectral normalization
93 | :param use_coord: use coordConv operation
94 | :param output_nc: number of channels in output image
95 | :param num_blocks: number of ResBlocks
96 | :param affine: affine in Pose Transformer Module
97 | :param nhead: number of heads in attention module
98 | :param num_CABs: number of CABs
99 | :param num_TTBs: number of TTBs
100 | """
101 | def __init__(self, image_nc, pose_nc, ngf=64, img_f=256, layers=3, norm='batch',
102 | activation='ReLU', use_spect=True, use_coord=False, output_nc=3, num_blocks=3, affine=True, nhead=2, num_CABs=2, num_TTBs=2):
103 | super(DPTNGenerator, self).__init__()
104 |
105 | self.layers = layers
106 | norm_layer = get_norm_layer(norm_type=norm)
107 | nonlinearity = get_nonlinearity_layer(activation_type=activation)
108 | input_nc = 2 * pose_nc + image_nc
109 |
110 | # Encoder En_c
111 | self.block0 = EncoderBlockOptimized(input_nc, ngf, norm_layer,
112 | nonlinearity, use_spect, use_coord)
113 | mult = 1
114 | for i in range(self.layers - 1):
115 | mult_prev = mult
116 | mult = min(2 ** (i + 1), img_f // ngf)
117 | block = EncoderBlock(ngf * mult_prev, ngf * mult, norm_layer,
118 | nonlinearity, use_spect, use_coord)
119 | setattr(self, 'encoder' + str(i), block)
120 |
121 | # ResBlocks
122 | self.num_blocks = num_blocks
123 | for i in range(num_blocks):
124 | block = ResBlock(ngf * mult, ngf * mult, norm_layer=norm_layer,
125 | nonlinearity=nonlinearity, use_spect=use_spect, use_coord=use_coord)
126 | setattr(self, 'mblock' + str(i), block)
127 |
128 | # Pose Transformer Module (PTM)
129 | self.PTM = PTM(d_model=ngf * mult, nhead=nhead, num_CABs=num_CABs,
130 | num_TTBs=num_TTBs, dim_feedforward=ngf * mult,
131 | activation="LeakyReLU", affine=affine, norm=norm)
132 |
133 | # Encoder En_s
134 | self.source_encoder = SourceEncoder(image_nc, ngf, img_f, layers, norm, activation, use_spect, use_coord)
135 |
136 | # Decoder
137 | for i in range(self.layers):
138 | mult_prev = mult
139 | mult = min(2 ** (self.layers - i - 2), img_f // ngf) if i != self.layers - 1 else 1
140 | up = ResBlockDecoder(ngf * mult_prev, ngf * mult, ngf * mult, norm_layer,
141 | nonlinearity, use_spect, use_coord)
142 | setattr(self, 'decoder' + str(i), up)
143 | self.outconv = Output(ngf, output_nc, 3, None, nonlinearity, use_spect, use_coord)
144 |
145 | def forward(self, source, source_B, target_B, is_train=True):
146 | # Self-reconstruction Branch
147 | # Source-to-source Inputs
148 | input_s_s = torch.cat((source, source_B, source_B), 1)
149 | # Source-to-source Encoder
150 | F_s_s = self.block0(input_s_s)
151 | for i in range(self.layers - 1):
152 | model = getattr(self, 'encoder' + str(i))
153 | F_s_s = model(F_s_s)
154 | # Source-to-source Resblocks
155 | for i in range(self.num_blocks):
156 | model = getattr(self, 'mblock' + str(i))
157 | F_s_s = model(F_s_s)
158 |
159 | # Transformation Branch
160 | # Source-to-target Inputs
161 | input_s_t = torch.cat((source, source_B, target_B), 1)
162 | # Source-to-target Encoder
163 | F_s_t = self.block0(input_s_t)
164 | for i in range(self.layers - 1):
165 | model = getattr(self, 'encoder' + str(i))
166 | F_s_t = model(F_s_t)
167 | # Source-to-target Resblocks
168 | for i in range(self.num_blocks):
169 | model = getattr(self, 'mblock' + str(i))
170 | F_s_t = model(F_s_t)
171 |
172 | # Source Image Encoding
173 | F_s = self.source_encoder(source)
174 |
175 | # Pose Transformer Module for Dual-task Correlation
176 | F_s_t = self.PTM(F_s_s, F_s_t, F_s)
177 |
178 | # Source-to-source Decoder (only for training)
179 | out_image_s = None
180 | if is_train:
181 | for i in range(self.layers):
182 | model = getattr(self, 'decoder' + str(i))
183 | F_s_s = model(F_s_s)
184 | out_image_s = self.outconv(F_s_s)
185 |
186 | # Source-to-target Decoder
187 | for i in range(self.layers):
188 | model = getattr(self, 'decoder' + str(i))
189 | F_s_t = model(F_s_t)
190 | out_image_t = self.outconv(F_s_t)
191 |
192 | return out_image_t, out_image_s
193 |
194 |
195 | ##############################################################################
196 | # Discriminator
197 | ##############################################################################
198 | class ResDiscriminator(nn.Module):
199 | """
200 | ResNet Discriminator Network
201 | :param input_nc: number of channels in input
202 | :param ndf: base filter channel
203 |     :param layers: number of down-sampling layers
204 |     :param img_f: the largest number of feature channels
205 |     :param norm: normalization function 'instance, batch, none'
206 |     :param activation: activation function 'ReLU, SELU, LeakyReLU, PReLU'
207 |     :param use_spect: use spectral normalization
208 | :param use_coord: use coordConv operation
209 | """
210 | def __init__(self, input_nc=3, ndf=64, img_f=1024, layers=3, norm='none', activation='LeakyReLU', use_spect=True,
211 | use_coord=False):
212 | super(ResDiscriminator, self).__init__()
213 |
214 | self.layers = layers
215 |
216 | norm_layer = get_norm_layer(norm_type=norm)
217 | nonlinearity = get_nonlinearity_layer(activation_type=activation)
218 | self.nonlinearity = nonlinearity
219 |
220 | # encoder part
221 | self.block0 = ResBlockEncoderOptimized(input_nc, ndf, ndf, norm_layer, nonlinearity, use_spect, use_coord)
222 |
223 | mult = 1
224 | for i in range(layers - 1):
225 | mult_prev = mult
226 | mult = min(2 ** (i + 1), img_f//ndf)
227 | block = ResBlockEncoder(ndf*mult_prev, ndf*mult, ndf*mult_prev, norm_layer, nonlinearity, use_spect, use_coord)
228 | setattr(self, 'encoder' + str(i), block)
229 | self.conv = SpectralNorm(nn.Conv2d(ndf*mult, 1, 1))
230 |
231 | def forward(self, x):
232 | out = self.block0(x)
233 | for i in range(self.layers - 1):
234 | model = getattr(self, 'encoder' + str(i))
235 | out = model(out)
236 | out = self.conv(self.nonlinearity(out))
237 | return out
--------------------------------------------------------------------------------
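Below is a minimal usage sketch for the generator and discriminator defined above; the import path, the 18-channel pose heatmaps, and the 256x176 resolution are assumptions for illustration only, not values taken from the repository's options.

import torch
from models.networks import DPTNGenerator, ResDiscriminator  # assumed module path

G = DPTNGenerator(image_nc=3, pose_nc=18, ngf=64, img_f=256, layers=3)
D = ResDiscriminator(input_nc=3)

source = torch.randn(2, 3, 256, 176)     # source person images
source_B = torch.randn(2, 18, 256, 176)  # source pose heatmaps (assumed 18 channels)
target_B = torch.randn(2, 18, 256, 176)  # target pose heatmaps

# Dual-task forward pass: transformed image plus self-reconstructed source image.
out_t, out_s = G(source, source_B, target_B, is_train=True)
print(out_t.shape, out_s.shape)          # both torch.Size([2, 3, 256, 176])
print(D(out_t).shape)                    # patch-level real/fake logits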
/models/external_function.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | import torchvision.models as models
4 | from torch.nn import Parameter
5 | import torch.nn.functional as F
6 | import copy
7 |
8 |
9 | ####################################################################################################
10 | # adversarial loss for different gan mode
11 | ####################################################################################################
12 |
13 |
14 | class GANLoss(nn.Module):
15 | """Define different GAN objectives.
16 | The GANLoss class abstracts away the need to create the target label tensor
17 | that has the same size as the input.
18 | """
19 |
20 | def __init__(self, gan_mode, target_real_label=1.0, target_fake_label=0.0):
21 | """ Initialize the GANLoss class.
22 | Parameters:
23 |             gan_mode (str) - - the type of GAN objective. It currently supports vanilla, lsgan, hinge, and wgangp.
24 |             target_real_label (float) - - label for a real image
25 |             target_fake_label (float) - - label for a fake image
26 | Note: Do not use sigmoid as the last layer of Discriminator.
27 | LSGAN needs no sigmoid. vanilla GANs will handle it with BCEWithLogitsLoss.
28 | """
29 | super(GANLoss, self).__init__()
30 | self.register_buffer('real_label', torch.tensor(target_real_label))
31 | self.register_buffer('fake_label', torch.tensor(target_fake_label))
32 | self.gan_mode = gan_mode
33 | if gan_mode == 'lsgan':
34 | self.loss = nn.MSELoss()
35 | elif gan_mode == 'vanilla':
36 | self.loss = nn.BCEWithLogitsLoss()
37 | elif gan_mode == 'hinge':
38 | self.loss = nn.ReLU()
39 | elif gan_mode == 'wgangp':
40 | self.loss = None
41 | else:
42 | raise NotImplementedError('gan mode %s not implemented' % gan_mode)
43 |
44 | def __call__(self, prediction, target_is_real, is_disc=False):
45 |         """Calculate loss given the Discriminator's output and ground truth labels.
46 |         Parameters:
47 |             prediction (tensor) - - typically the prediction output from a discriminator
48 | target_is_real (bool) - - if the ground truth label is for real images or fake images
49 | Returns:
50 | the calculated loss.
51 | """
52 | if self.gan_mode in ['lsgan', 'vanilla']:
53 | labels = (self.real_label if target_is_real else self.fake_label).expand_as(prediction).type_as(prediction)
54 | loss = self.loss(prediction, labels)
55 | elif self.gan_mode in ['hinge', 'wgangp']:
56 | if is_disc:
57 | if target_is_real:
58 | prediction = -prediction
59 | if self.gan_mode == 'hinge':
60 | loss = self.loss(1 + prediction).mean()
61 | elif self.gan_mode == 'wgangp':
62 | loss = prediction.mean()
63 | else:
64 | loss = -prediction.mean()
65 | return loss
66 |
67 |
68 | def cal_gradient_penalty(netD, real_data, fake_data, type='mixed', constant=1.0, lambda_gp=10.0):
69 | """Calculate the gradient penalty loss, used in WGAN-GP paper https://arxiv.org/abs/1704.00028
70 | Arguments:
71 | netD (network) -- discriminator network
72 | real_data (tensor array) -- real images
73 | fake_data (tensor array) -- generated images from the generator
74 |         type (str)                  -- how to mix real and fake data [real | fake | mixed]
75 |         constant (float)            -- the constant used in formula (||gradient||_2 - constant)^2
76 | lambda_gp (float) -- weight for this loss
77 | Returns the gradient penalty loss
78 | """
79 | if lambda_gp > 0.0:
80 | if type == 'real': # either use real images, fake images, or a linear interpolation of two.
81 | interpolatesv = real_data
82 | elif type == 'fake':
83 | interpolatesv = fake_data
84 | elif type == 'mixed':
85 | alpha = torch.rand(real_data.shape[0], 1)
86 | alpha = alpha.expand(real_data.shape[0], real_data.nelement() // real_data.shape[0]).contiguous().view(*real_data.shape)
87 | alpha = alpha.type_as(real_data)
88 | interpolatesv = alpha * real_data + ((1 - alpha) * fake_data)
89 | else:
90 | raise NotImplementedError('{} not implemented'.format(type))
91 | interpolatesv.requires_grad_(True)
92 | disc_interpolates = netD(interpolatesv)
93 | gradients = torch.autograd.grad(outputs=disc_interpolates, inputs=interpolatesv,
94 | grad_outputs=torch.ones(disc_interpolates.size()).type_as(real_data),
95 | create_graph=True, retain_graph=True, only_inputs=True)
96 |         gradients = gradients[0].view(real_data.size(0), -1)  # flatten the data
97 | gradient_penalty = (((gradients + 1e-16).norm(2, dim=1) - constant) ** 2).mean() * lambda_gp # added eps
98 | return gradient_penalty, gradients
99 | else:
100 | return 0.0, None
101 |
102 |
103 | class VGGLoss(nn.Module):
104 | r"""
105 |     Combined perceptual (content) and style loss, VGG-based
106 | https://arxiv.org/abs/1603.08155
107 | https://github.com/dxyang/StyleTransfer/blob/master/utils.py
108 | """
109 |
110 | def __init__(self, weights=[1.0, 1.0, 1.0, 1.0, 1.0]):
111 | super(VGGLoss, self).__init__()
112 | self.add_module('vgg', VGG19())
113 | self.criterion = torch.nn.L1Loss()
114 | self.weights = weights
115 |
116 | def compute_gram(self, x):
117 | b, ch, h, w = x.size()
118 | f = x.view(b, ch, w * h)
119 | f_T = f.transpose(1, 2)
120 | G = f.bmm(f_T) / (h * w * ch)
121 | return G
122 |
123 | def __call__(self, x, y):
124 | # Compute features
125 | x_vgg, y_vgg = self.vgg(x), self.vgg(y)
126 |
127 | content_loss = 0.0
128 | content_loss += self.weights[0] * self.criterion(x_vgg['relu1_1'], y_vgg['relu1_1'])
129 | content_loss += self.weights[1] * self.criterion(x_vgg['relu2_1'], y_vgg['relu2_1'])
130 | content_loss += self.weights[2] * self.criterion(x_vgg['relu3_1'], y_vgg['relu3_1'])
131 | content_loss += self.weights[3] * self.criterion(x_vgg['relu4_1'], y_vgg['relu4_1'])
132 | content_loss += self.weights[4] * self.criterion(x_vgg['relu5_1'], y_vgg['relu5_1'])
133 |
134 |         # Compute style loss (Gram matrices)
135 | style_loss = 0.0
136 | style_loss += self.criterion(self.compute_gram(x_vgg['relu2_2']), self.compute_gram(y_vgg['relu2_2']))
137 | style_loss += self.criterion(self.compute_gram(x_vgg['relu3_4']), self.compute_gram(y_vgg['relu3_4']))
138 | style_loss += self.criterion(self.compute_gram(x_vgg['relu4_4']), self.compute_gram(y_vgg['relu4_4']))
139 | style_loss += self.criterion(self.compute_gram(x_vgg['relu5_2']), self.compute_gram(y_vgg['relu5_2']))
140 |
141 | return content_loss, style_loss
142 |
143 |
144 | def reduce_sum(x, axis=None, keepdim=False):
145 | if not axis:
146 | axis = range(len(x.shape))
147 | for i in sorted(axis, reverse=True):
148 | x = torch.sum(x, dim=i, keepdim=keepdim)
149 | return x
150 |
151 |
152 | ####################################################################################################
153 | # neural style transfer loss from the PyTorch neural_style tutorial
154 | ####################################################################################################
155 |
156 |
157 | class StyleLoss(nn.Module):
158 | r"""
159 |     Style loss (Gram-matrix matching), VGG-based
160 | https://arxiv.org/abs/1603.08155
161 | https://github.com/dxyang/StyleTransfer/blob/master/utils.py
162 | """
163 |
164 | def __init__(self):
165 | super(StyleLoss, self).__init__()
166 | self.add_module('vgg', VGG19())
167 | self.criterion = torch.nn.L1Loss()
168 |
169 | def compute_gram(self, x):
170 | b, ch, h, w = x.size()
171 | f = x.view(b, ch, w * h)
172 | f_T = f.transpose(1, 2)
173 | G = f.bmm(f_T) / (h * w * ch)
174 |
175 | return G
176 |
177 | def __call__(self, x, y):
178 | # Compute features
179 | x_vgg, y_vgg = self.vgg(x), self.vgg(y)
180 |
181 | # Compute loss
182 | style_loss = 0.0
183 | style_loss += self.criterion(self.compute_gram(x_vgg['relu2_2']), self.compute_gram(y_vgg['relu2_2']))
184 | style_loss += self.criterion(self.compute_gram(x_vgg['relu3_4']), self.compute_gram(y_vgg['relu3_4']))
185 | style_loss += self.criterion(self.compute_gram(x_vgg['relu4_4']), self.compute_gram(y_vgg['relu4_4']))
186 | style_loss += self.criterion(self.compute_gram(x_vgg['relu5_2']), self.compute_gram(y_vgg['relu5_2']))
187 |
188 | return style_loss
189 |
190 |
191 |
192 | class PerceptualLoss(nn.Module):
193 | r"""
194 | Perceptual loss, VGG-based
195 | https://arxiv.org/abs/1603.08155
196 | https://github.com/dxyang/StyleTransfer/blob/master/utils.py
197 | """
198 |
199 | def __init__(self, weights=[1.0, 1.0, 1.0, 1.0, 1.0]):
200 | super(PerceptualLoss, self).__init__()
201 | self.add_module('vgg', VGG19())
202 | self.criterion = torch.nn.L1Loss()
203 | self.weights = weights
204 |
205 | def __call__(self, x, y):
206 | # Compute features
207 | x_vgg, y_vgg = self.vgg(x), self.vgg(y)
208 |
209 | content_loss = 0.0
210 | content_loss += self.weights[0] * self.criterion(x_vgg['relu1_1'], y_vgg['relu1_1'])
211 | content_loss += self.weights[1] * self.criterion(x_vgg['relu2_1'], y_vgg['relu2_1'])
212 | content_loss += self.weights[2] * self.criterion(x_vgg['relu3_1'], y_vgg['relu3_1'])
213 | content_loss += self.weights[3] * self.criterion(x_vgg['relu4_1'], y_vgg['relu4_1'])
214 | content_loss += self.weights[4] * self.criterion(x_vgg['relu5_1'], y_vgg['relu5_1'])
215 |
216 |
217 | return content_loss
218 |
219 |
220 |
221 | class VGG19(torch.nn.Module):
222 | def __init__(self):
223 | super(VGG19, self).__init__()
224 | features = models.vgg19(pretrained=True).features
225 | self.relu1_1 = torch.nn.Sequential()
226 | self.relu1_2 = torch.nn.Sequential()
227 |
228 | self.relu2_1 = torch.nn.Sequential()
229 | self.relu2_2 = torch.nn.Sequential()
230 |
231 | self.relu3_1 = torch.nn.Sequential()
232 | self.relu3_2 = torch.nn.Sequential()
233 | self.relu3_3 = torch.nn.Sequential()
234 | self.relu3_4 = torch.nn.Sequential()
235 |
236 | self.relu4_1 = torch.nn.Sequential()
237 | self.relu4_2 = torch.nn.Sequential()
238 | self.relu4_3 = torch.nn.Sequential()
239 | self.relu4_4 = torch.nn.Sequential()
240 |
241 | self.relu5_1 = torch.nn.Sequential()
242 | self.relu5_2 = torch.nn.Sequential()
243 | self.relu5_3 = torch.nn.Sequential()
244 | self.relu5_4 = torch.nn.Sequential()
245 |
246 | for x in range(2):
247 | self.relu1_1.add_module(str(x), features[x])
248 |
249 | for x in range(2, 4):
250 | self.relu1_2.add_module(str(x), features[x])
251 |
252 | for x in range(4, 7):
253 | self.relu2_1.add_module(str(x), features[x])
254 |
255 | for x in range(7, 9):
256 | self.relu2_2.add_module(str(x), features[x])
257 |
258 | for x in range(9, 12):
259 | self.relu3_1.add_module(str(x), features[x])
260 |
261 | for x in range(12, 14):
262 | self.relu3_2.add_module(str(x), features[x])
263 |
264 | for x in range(14, 16):
265 | self.relu3_3.add_module(str(x), features[x])
266 |
267 | for x in range(16, 18):
268 | self.relu3_4.add_module(str(x), features[x])
269 |
270 | for x in range(18, 21):
271 | self.relu4_1.add_module(str(x), features[x])
272 |
273 | for x in range(21, 23):
274 | self.relu4_2.add_module(str(x), features[x])
275 |
276 | for x in range(23, 25):
277 | self.relu4_3.add_module(str(x), features[x])
278 |
279 | for x in range(25, 27):
280 | self.relu4_4.add_module(str(x), features[x])
281 |
282 | for x in range(27, 30):
283 | self.relu5_1.add_module(str(x), features[x])
284 |
285 | for x in range(30, 32):
286 | self.relu5_2.add_module(str(x), features[x])
287 |
288 | for x in range(32, 34):
289 | self.relu5_3.add_module(str(x), features[x])
290 |
291 | for x in range(34, 36):
292 | self.relu5_4.add_module(str(x), features[x])
293 |
294 | # don't need the gradients, just want the features
295 | for param in self.parameters():
296 | param.requires_grad = False
297 |
298 | def forward(self, x):
299 | relu1_1 = self.relu1_1(x)
300 | relu1_2 = self.relu1_2(relu1_1)
301 |
302 | relu2_1 = self.relu2_1(relu1_2)
303 | relu2_2 = self.relu2_2(relu2_1)
304 |
305 | relu3_1 = self.relu3_1(relu2_2)
306 | relu3_2 = self.relu3_2(relu3_1)
307 | relu3_3 = self.relu3_3(relu3_2)
308 | relu3_4 = self.relu3_4(relu3_3)
309 |
310 | relu4_1 = self.relu4_1(relu3_4)
311 | relu4_2 = self.relu4_2(relu4_1)
312 | relu4_3 = self.relu4_3(relu4_2)
313 | relu4_4 = self.relu4_4(relu4_3)
314 |
315 | relu5_1 = self.relu5_1(relu4_4)
316 | relu5_2 = self.relu5_2(relu5_1)
317 | relu5_3 = self.relu5_3(relu5_2)
318 | relu5_4 = self.relu5_4(relu5_3)
319 |
320 | out = {
321 | 'relu1_1': relu1_1,
322 | 'relu1_2': relu1_2,
323 |
324 | 'relu2_1': relu2_1,
325 | 'relu2_2': relu2_2,
326 |
327 | 'relu3_1': relu3_1,
328 | 'relu3_2': relu3_2,
329 | 'relu3_3': relu3_3,
330 | 'relu3_4': relu3_4,
331 |
332 | 'relu4_1': relu4_1,
333 | 'relu4_2': relu4_2,
334 | 'relu4_3': relu4_3,
335 | 'relu4_4': relu4_4,
336 |
337 | 'relu5_1': relu5_1,
338 | 'relu5_2': relu5_2,
339 | 'relu5_3': relu5_3,
340 | 'relu5_4': relu5_4,
341 | }
342 | return out
343 |
--------------------------------------------------------------------------------
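The following is a hedged sketch of how the loss modules above can be wired into one generator/discriminator step; the lsgan mode, the discriminator choice, and the unit loss weights are illustrative assumptions, not the repository's actual training configuration.

import torch
from models.networks import ResDiscriminator  # assumed module path
from models.external_function import GANLoss, VGGLoss, cal_gradient_penalty

netD = ResDiscriminator(input_nc=3)
criterion_gan = GANLoss('lsgan')   # also accepts 'vanilla', 'hinge', 'wgangp'
criterion_vgg = VGGLoss()          # builds a frozen, pretrained VGG19 feature extractor

real = torch.rand(1, 3, 256, 176)
fake = torch.rand(1, 3, 256, 176)  # stand-in for a generator output

# Generator side: fool the discriminator and match VGG content/style statistics.
loss_adv = criterion_gan(netD(fake), True)
content, style = criterion_vgg(fake, real)
loss_G = loss_adv + content + style

# Discriminator side: is_disc=True selects the discriminator form of hinge/wgangp.
loss_D = criterion_gan(netD(real), True, is_disc=True) \
       + criterion_gan(netD(fake.detach()), False, is_disc=True)

# Gradient penalty is only meaningful for the 'wgangp' mode; shown for completeness.
gp, _ = cal_gradient_penalty(netD, real, fake.detach())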
/LICENSE.md:
--------------------------------------------------------------------------------
1 | ## creative commons
2 |
3 | # Attribution-NonCommercial 4.0 International
4 |
5 | Creative Commons Corporation (“Creative Commons”) is not a law firm and does not provide legal services or legal advice. Distribution of Creative Commons public licenses does not create a lawyer-client or other relationship. Creative Commons makes its licenses and related information available on an “as-is” basis. Creative Commons gives no warranties regarding its licenses, any material licensed under their terms and conditions, or any related information. Creative Commons disclaims all liability for damages resulting from their use to the fullest extent possible.
6 |
7 | ### Using Creative Commons Public Licenses
8 |
9 | Creative Commons public licenses provide a standard set of terms and conditions that creators and other rights holders may use to share original works of authorship and other material subject to copyright and certain other rights specified in the public license below. The following considerations are for informational purposes only, are not exhaustive, and do not form part of our licenses.
10 |
11 | * __Considerations for licensors:__ Our public licenses are intended for use by those authorized to give the public permission to use material in ways otherwise restricted by copyright and certain other rights. Our licenses are irrevocable. Licensors should read and understand the terms and conditions of the license they choose before applying it. Licensors should also secure all rights necessary before applying our licenses so that the public can reuse the material as expected. Licensors should clearly mark any material not subject to the license. This includes other CC-licensed material, or material used under an exception or limitation to copyright. [More considerations for licensors](http://wiki.creativecommons.org/Considerations_for_licensors_and_licensees#Considerations_for_licensors).
12 |
13 | * __Considerations for the public:__ By using one of our public licenses, a licensor grants the public permission to use the licensed material under specified terms and conditions. If the licensor’s permission is not necessary for any reason–for example, because of any applicable exception or limitation to copyright–then that use is not regulated by the license. Our licenses grant only permissions under copyright and certain other rights that a licensor has authority to grant. Use of the licensed material may still be restricted for other reasons, including because others have copyright or other rights in the material. A licensor may make special requests, such as asking that all changes be marked or described. Although not required by our licenses, you are encouraged to respect those requests where reasonable. [More considerations for the public](http://wiki.creativecommons.org/Considerations_for_licensors_and_licensees#Considerations_for_licensees).
14 |
15 | ## Creative Commons Attribution-NonCommercial 4.0 International Public License
16 |
17 | By exercising the Licensed Rights (defined below), You accept and agree to be bound by the terms and conditions of this Creative Commons Attribution-NonCommercial 4.0 International Public License ("Public License"). To the extent this Public License may be interpreted as a contract, You are granted the Licensed Rights in consideration of Your acceptance of these terms and conditions, and the Licensor grants You such rights in consideration of benefits the Licensor receives from making the Licensed Material available under these terms and conditions.
18 |
19 | ### Section 1 – Definitions.
20 |
21 | a. __Adapted Material__ means material subject to Copyright and Similar Rights that is derived from or based upon the Licensed Material and in which the Licensed Material is translated, altered, arranged, transformed, or otherwise modified in a manner requiring permission under the Copyright and Similar Rights held by the Licensor. For purposes of this Public License, where the Licensed Material is a musical work, performance, or sound recording, Adapted Material is always produced where the Licensed Material is synched in timed relation with a moving image.
22 |
23 | b. __Adapter's License__ means the license You apply to Your Copyright and Similar Rights in Your contributions to Adapted Material in accordance with the terms and conditions of this Public License.
24 |
25 | c. __Copyright and Similar Rights__ means copyright and/or similar rights closely related to copyright including, without limitation, performance, broadcast, sound recording, and Sui Generis Database Rights, without regard to how the rights are labeled or categorized. For purposes of this Public License, the rights specified in Section 2(b)(1)-(2) are not Copyright and Similar Rights.
26 |
27 | d. __Effective Technological Measures__ means those measures that, in the absence of proper authority, may not be circumvented under laws fulfilling obligations under Article 11 of the WIPO Copyright Treaty adopted on December 20, 1996, and/or similar international agreements.
28 |
29 | e. __Exceptions and Limitations__ means fair use, fair dealing, and/or any other exception or limitation to Copyright and Similar Rights that applies to Your use of the Licensed Material.
30 |
31 | f. __Licensed Material__ means the artistic or literary work, database, or other material to which the Licensor applied this Public License.
32 |
33 | g. __Licensed Rights__ means the rights granted to You subject to the terms and conditions of this Public License, which are limited to all Copyright and Similar Rights that apply to Your use of the Licensed Material and that the Licensor has authority to license.
34 |
35 | h. __Licensor__ means the individual(s) or entity(ies) granting rights under this Public License.
36 |
37 | i. __NonCommercial__ means not primarily intended for or directed towards commercial advantage or monetary compensation. For purposes of this Public License, the exchange of the Licensed Material for other material subject to Copyright and Similar Rights by digital file-sharing or similar means is NonCommercial provided there is no payment of monetary compensation in connection with the exchange.
38 |
39 | j. __Share__ means to provide material to the public by any means or process that requires permission under the Licensed Rights, such as reproduction, public display, public performance, distribution, dissemination, communication, or importation, and to make material available to the public including in ways that members of the public may access the material from a place and at a time individually chosen by them.
40 |
41 | k. __Sui Generis Database Rights__ means rights other than copyright resulting from Directive 96/9/EC of the European Parliament and of the Council of 11 March 1996 on the legal protection of databases, as amended and/or succeeded, as well as other essentially equivalent rights anywhere in the world.
42 |
43 | l. __You__ means the individual or entity exercising the Licensed Rights under this Public License. Your has a corresponding meaning.
44 |
45 | ### Section 2 – Scope.
46 |
47 | a. ___License grant.___
48 |
49 | 1. Subject to the terms and conditions of this Public License, the Licensor hereby grants You a worldwide, royalty-free, non-sublicensable, non-exclusive, irrevocable license to exercise the Licensed Rights in the Licensed Material to:
50 |
51 | A. reproduce and Share the Licensed Material, in whole or in part, for NonCommercial purposes only; and
52 |
53 | B. produce, reproduce, and Share Adapted Material for NonCommercial purposes only.
54 |
55 | 2. __Exceptions and Limitations.__ For the avoidance of doubt, where Exceptions and Limitations apply to Your use, this Public License does not apply, and You do not need to comply with its terms and conditions.
56 |
57 | 3. __Term.__ The term of this Public License is specified in Section 6(a).
58 |
59 | 4. __Media and formats; technical modifications allowed.__ The Licensor authorizes You to exercise the Licensed Rights in all media and formats whether now known or hereafter created, and to make technical modifications necessary to do so. The Licensor waives and/or agrees not to assert any right or authority to forbid You from making technical modifications necessary to exercise the Licensed Rights, including technical modifications necessary to circumvent Effective Technological Measures. For purposes of this Public License, simply making modifications authorized by this Section 2(a)(4) never produces Adapted Material.
60 |
61 | 5. __Downstream recipients.__
62 |
63 | A. __Offer from the Licensor – Licensed Material.__ Every recipient of the Licensed Material automatically receives an offer from the Licensor to exercise the Licensed Rights under the terms and conditions of this Public License.
64 |
65 | B. __No downstream restrictions.__ You may not offer or impose any additional or different terms or conditions on, or apply any Effective Technological Measures to, the Licensed Material if doing so restricts exercise of the Licensed Rights by any recipient of the Licensed Material.
66 |
67 | 6. __No endorsement.__ Nothing in this Public License constitutes or may be construed as permission to assert or imply that You are, or that Your use of the Licensed Material is, connected with, or sponsored, endorsed, or granted official status by, the Licensor or others designated to receive attribution as provided in Section 3(a)(1)(A)(i).
68 |
69 | b. ___Other rights.___
70 |
71 | 1. Moral rights, such as the right of integrity, are not licensed under this Public License, nor are publicity, privacy, and/or other similar personality rights; however, to the extent possible, the Licensor waives and/or agrees not to assert any such rights held by the Licensor to the limited extent necessary to allow You to exercise the Licensed Rights, but not otherwise.
72 |
73 | 2. Patent and trademark rights are not licensed under this Public License.
74 |
75 | 3. To the extent possible, the Licensor waives any right to collect royalties from You for the exercise of the Licensed Rights, whether directly or through a collecting society under any voluntary or waivable statutory or compulsory licensing scheme. In all other cases the Licensor expressly reserves any right to collect such royalties, including when the Licensed Material is used other than for NonCommercial purposes.
76 |
77 | ### Section 3 – License Conditions.
78 |
79 | Your exercise of the Licensed Rights is expressly made subject to the following conditions.
80 |
81 | a. ___Attribution.___
82 |
83 | 1. If You Share the Licensed Material (including in modified form), You must:
84 |
85 | A. retain the following if it is supplied by the Licensor with the Licensed Material:
86 |
87 | i. identification of the creator(s) of the Licensed Material and any others designated to receive attribution, in any reasonable manner requested by the Licensor (including by pseudonym if designated);
88 |
89 | ii. a copyright notice;
90 |
91 | iii. a notice that refers to this Public License;
92 |
93 | iv. a notice that refers to the disclaimer of warranties;
94 |
95 | v. a URI or hyperlink to the Licensed Material to the extent reasonably practicable;
96 |
97 | B. indicate if You modified the Licensed Material and retain an indication of any previous modifications; and
98 |
99 | C. indicate the Licensed Material is licensed under this Public License, and include the text of, or the URI or hyperlink to, this Public License.
100 |
101 | 2. You may satisfy the conditions in Section 3(a)(1) in any reasonable manner based on the medium, means, and context in which You Share the Licensed Material. For example, it may be reasonable to satisfy the conditions by providing a URI or hyperlink to a resource that includes the required information.
102 |
103 | 3. If requested by the Licensor, You must remove any of the information required by Section 3(a)(1)(A) to the extent reasonably practicable.
104 |
105 | 4. If You Share Adapted Material You produce, the Adapter's License You apply must not prevent recipients of the Adapted Material from complying with this Public License.
106 |
107 | ### Section 4 – Sui Generis Database Rights.
108 |
109 | Where the Licensed Rights include Sui Generis Database Rights that apply to Your use of the Licensed Material:
110 |
111 | a. for the avoidance of doubt, Section 2(a)(1) grants You the right to extract, reuse, reproduce, and Share all or a substantial portion of the contents of the database for NonCommercial purposes only;
112 |
113 | b. if You include all or a substantial portion of the database contents in a database in which You have Sui Generis Database Rights, then the database in which You have Sui Generis Database Rights (but not its individual contents) is Adapted Material; and
114 |
115 | c. You must comply with the conditions in Section 3(a) if You Share all or a substantial portion of the contents of the database.
116 |
117 | For the avoidance of doubt, this Section 4 supplements and does not replace Your obligations under this Public License where the Licensed Rights include other Copyright and Similar Rights.
118 |
119 | ### Section 5 – Disclaimer of Warranties and Limitation of Liability.
120 |
121 | a. __Unless otherwise separately undertaken by the Licensor, to the extent possible, the Licensor offers the Licensed Material as-is and as-available, and makes no representations or warranties of any kind concerning the Licensed Material, whether express, implied, statutory, or other. This includes, without limitation, warranties of title, merchantability, fitness for a particular purpose, non-infringement, absence of latent or other defects, accuracy, or the presence or absence of errors, whether or not known or discoverable. Where disclaimers of warranties are not allowed in full or in part, this disclaimer may not apply to You.__
122 |
123 | b. __To the extent possible, in no event will the Licensor be liable to You on any legal theory (including, without limitation, negligence) or otherwise for any direct, special, indirect, incidental, consequential, punitive, exemplary, or other losses, costs, expenses, or damages arising out of this Public License or use of the Licensed Material, even if the Licensor has been advised of the possibility of such losses, costs, expenses, or damages. Where a limitation of liability is not allowed in full or in part, this limitation may not apply to You.__
124 |
125 | c. The disclaimer of warranties and limitation of liability provided above shall be interpreted in a manner that, to the extent possible, most closely approximates an absolute disclaimer and waiver of all liability.
126 |
127 | ### Section 6 – Term and Termination.
128 |
129 | a. This Public License applies for the term of the Copyright and Similar Rights licensed here. However, if You fail to comply with this Public License, then Your rights under this Public License terminate automatically.
130 |
131 | b. Where Your right to use the Licensed Material has terminated under Section 6(a), it reinstates:
132 |
133 | 1. automatically as of the date the violation is cured, provided it is cured within 30 days of Your discovery of the violation; or
134 |
135 | 2. upon express reinstatement by the Licensor.
136 |
137 | For the avoidance of doubt, this Section 6(b) does not affect any right the Licensor may have to seek remedies for Your violations of this Public License.
138 |
139 | c. For the avoidance of doubt, the Licensor may also offer the Licensed Material under separate terms or conditions or stop distributing the Licensed Material at any time; however, doing so will not terminate this Public License.
140 |
141 | d. Sections 1, 5, 6, 7, and 8 survive termination of this Public License.
142 |
143 | ### Section 7 – Other Terms and Conditions.
144 |
145 | a. The Licensor shall not be bound by any additional or different terms or conditions communicated by You unless expressly agreed.
146 |
147 | b. Any arrangements, understandings, or agreements regarding the Licensed Material not stated herein are separate from and independent of the terms and conditions of this Public License.
148 |
149 | ### Section 8 – Interpretation.
150 |
151 | a. For the avoidance of doubt, this Public License does not, and shall not be interpreted to, reduce, limit, restrict, or impose conditions on any use of the Licensed Material that could lawfully be made without permission under this Public License.
152 |
153 | b. To the extent possible, if any provision of this Public License is deemed unenforceable, it shall be automatically reformed to the minimum extent necessary to make it enforceable. If the provision cannot be reformed, it shall be severed from this Public License without affecting the enforceability of the remaining terms and conditions.
154 |
155 | c. No term or condition of this Public License will be waived and no failure to comply consented to unless expressly agreed to by the Licensor.
156 |
157 | d. Nothing in this Public License constitutes or may be interpreted as a limitation upon, or waiver of, any privileges and immunities that apply to the Licensor or You, including from the legal processes of any jurisdiction or authority.
158 |
159 | > Creative Commons is not a party to its public licenses. Notwithstanding, Creative Commons may elect to apply one of its public licenses to material it publishes and in those instances will be considered the “Licensor.” Except for the limited purpose of indicating that material is shared under a Creative Commons public license or as otherwise permitted by the Creative Commons policies published at [creativecommons.org/policies](http://creativecommons.org/policies), Creative Commons does not authorize the use of the trademark “Creative Commons” or any other trademark or logo of Creative Commons without its prior written consent including, without limitation, in connection with any unauthorized modifications to any of its public licenses or any other arrangements, understandings, or agreements concerning use of licensed material. For the avoidance of doubt, this paragraph does not form part of the public licenses.
160 | >
161 | > Creative Commons may be contacted at creativecommons.org
162 |
--------------------------------------------------------------------------------
/models/base_function.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.nn import init
4 | import functools
5 | from torch.optim import lr_scheduler
6 | from torch.nn.utils.spectral_norm import spectral_norm as SpectralNorm
7 |
8 | ######################################################################################
9 | # base function for network structure
10 | ######################################################################################
11 |
12 |
13 | def init_weights(net, init_type='normal', gain=0.02):
14 |     """Initialize the network weights with the selected method"""
15 | def init_func(m):
16 | classname = m.__class__.__name__
17 | if hasattr(m, 'weight') and (classname.find('Conv')!=-1 or classname.find('Linear')!=-1):
18 | if init_type == 'normal':
19 | init.normal_(m.weight.data, 0.0, gain)
20 | elif init_type == 'xavier':
21 | init.xavier_normal_(m.weight.data, gain=gain)
22 | elif init_type == 'kaiming':
23 | init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
24 | elif init_type == 'orthogonal':
25 | init.orthogonal_(m.weight.data, gain=gain)
26 | else:
27 | raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
28 | if hasattr(m, 'bias') and m.bias is not None:
29 | init.constant_(m.bias.data, 0.0)
30 | elif classname.find('BatchNorm2d') != -1:
31 | init.normal_(m.weight.data, 1.0, 0.02)
32 | init.constant_(m.bias.data, 0.0)
33 |
34 | print('initialize network with %s' % init_type)
35 | net.apply(init_func)
36 |
37 |
38 | def get_norm_layer(norm_type='batch'):
39 | """Get the normalization layer for the networks"""
40 | if norm_type == 'batch':
41 | norm_layer = functools.partial(nn.BatchNorm2d, momentum=0.1, affine=True)
42 | elif norm_type == 'instance':
43 | norm_layer = functools.partial(nn.InstanceNorm2d, affine=True)
44 | elif norm_type == 'none':
45 | norm_layer = None
46 | else:
47 | raise NotImplementedError('normalization layer [%s] is not found' % norm_type)
48 | return norm_layer
49 |
50 |
51 | def get_nonlinearity_layer(activation_type='PReLU'):
52 | """Get the activation layer for the networks"""
53 | if activation_type == 'ReLU':
54 | nonlinearity_layer = nn.ReLU()
55 | elif activation_type == 'SELU':
56 | nonlinearity_layer = nn.SELU()
57 | elif activation_type == 'LeakyReLU':
58 | nonlinearity_layer = nn.LeakyReLU(0.1)
59 | elif activation_type == 'PReLU':
60 | nonlinearity_layer = nn.PReLU()
61 | else:
62 | raise NotImplementedError('activation layer [%s] is not found' % activation_type)
63 | return nonlinearity_layer
64 |
65 |
66 | def get_scheduler(optimizer, opt):
67 |     """Get the learning rate scheduler for training"""
68 | if opt.lr_policy == 'lambda':
69 | def lambda_rule(epoch):
70 | lr_l = 1.0 - max(0, epoch+opt.iter_start-opt.niter) / float(opt.niter_decay+1)
71 | return lr_l
72 | scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule)
73 | elif opt.lr_policy == 'step':
74 | scheduler = lr_scheduler.StepLR(optimizer, step_size=opt.lr_decay_iters, gamma=0.1)
75 | elif opt.lr_policy == 'exponent':
76 | scheduler = lr_scheduler.ExponentialLR(optimizer, gamma=0.95)
77 | elif opt.lr_policy == 'cosine':
78 | scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=32, eta_min=0)
79 | else:
80 |         raise NotImplementedError('learning rate policy [%s] is not implemented' % opt.lr_policy)
81 | return scheduler
82 |
83 |
84 | def print_network(net):
85 | """print the network"""
86 | num_params = 0
87 | for param in net.parameters():
88 | num_params += param.numel()
89 | print(net)
90 | print('total number of parameters: %.3f M' % (num_params/1e6))
91 |
92 |
93 | def init_net(net, init_type='normal', gpu_ids=[]):
94 |     """print the network structure and initialize the network"""
95 | print_network(net)
96 |
97 | if len(gpu_ids) > 0:
98 | assert(torch.cuda.is_available())
99 | net.cuda()
100 | net = torch.nn.DataParallel(net, gpu_ids)
101 | init_weights(net, init_type)
102 | return net
103 |
104 |
105 | def _freeze(*args):
106 | """freeze the network for forward process"""
107 | for module in args:
108 | if module:
109 | for p in module.parameters():
110 | p.requires_grad = False
111 |
112 |
113 | def _unfreeze(*args):
114 | """ unfreeze the network for parameter update"""
115 | for module in args:
116 | if module:
117 | for p in module.parameters():
118 | p.requires_grad = True
119 |
120 |
121 | def spectral_norm(module, use_spect=True):
122 |     """apply spectral normalization to stabilize the training process"""
123 | if use_spect:
124 | return SpectralNorm(module)
125 | else:
126 | return module
127 |
128 |
129 | def coord_conv(input_nc, output_nc, use_spect=False, use_coord=False, with_r=False, **kwargs):
130 | """use coord convolution layer to add position information"""
131 | if use_coord:
132 | print("ERROR! #### ERROR! #### ERROR! #### ERROR! #### ERROR! #### ERROR! #### ERROR! #### ")
133 | return CoordConv(input_nc, output_nc, with_r, use_spect, **kwargs)
134 | else:
135 | return spectral_norm(nn.Conv2d(input_nc, output_nc, **kwargs), use_spect)
136 |
137 |
138 | ######################################################################################
139 | # Network basic function
140 | ######################################################################################
141 | class AddCoords(nn.Module):
142 | """
143 | Add Coords to a tensor
144 | """
145 | def __init__(self, with_r=False):
146 | super(AddCoords, self).__init__()
147 | self.with_r = with_r
148 |
149 | def forward(self, x):
150 | """
151 | :param x: shape (batch, channel, x_dim, y_dim)
152 | :return: shape (batch, channel+2, x_dim, y_dim)
153 | """
154 | B, _, x_dim, y_dim = x.size()
155 |
156 | # coord calculate
157 | xx_channel = torch.arange(x_dim).repeat(B, 1, y_dim, 1).type_as(x)
158 |         yy_channel = torch.arange(y_dim).repeat(B, 1, x_dim, 1).permute(0, 1, 3, 2).type_as(x)
159 |         # normalization
160 |         xx_channel = xx_channel.float() / (x_dim-1)
161 |         yy_channel = yy_channel.float() / (y_dim-1)
162 |         xx_channel = xx_channel * 2 - 1
163 |         yy_channel = yy_channel * 2 - 1
164 |
165 |         ret = torch.cat([x, xx_channel, yy_channel], dim=1)
166 |
167 |         if self.with_r:
168 |             rr = torch.sqrt(xx_channel ** 2 + yy_channel ** 2)
169 | ret = torch.cat([ret, rr], dim=1)
170 |
171 | return ret
172 |
173 |
174 | class CoordConv(nn.Module):
175 | """
176 | CoordConv operation
177 | """
178 | def __init__(self, input_nc, output_nc, with_r=False, use_spect=False, **kwargs):
179 | super(CoordConv, self).__init__()
180 | self.addcoords = AddCoords(with_r=with_r)
181 | input_nc = input_nc + 2
182 | if with_r:
183 | input_nc = input_nc + 1
184 | self.conv = spectral_norm(nn.Conv2d(input_nc, output_nc, **kwargs), use_spect)
185 |
186 | def forward(self, x):
187 | ret = self.addcoords(x)
188 | ret = self.conv(ret)
189 |
190 | return ret
191 |
192 |
193 | class ResBlock(nn.Module):
194 | """
195 |     Define a Residual block with optional up/down sampling
196 | """
197 | def __init__(self, input_nc, output_nc, hidden_nc=None, norm_layer=nn.BatchNorm2d, nonlinearity= nn.LeakyReLU(),
198 | sample_type='none', use_spect=False, use_coord=False):
199 | super(ResBlock, self).__init__()
200 |
201 | hidden_nc = output_nc if hidden_nc is None else hidden_nc
202 | self.sample = True
203 | if sample_type == 'none':
204 | self.sample = False
205 | elif sample_type == 'up':
206 | output_nc = output_nc * 4
207 | self.pool = nn.PixelShuffle(upscale_factor=2)
208 | elif sample_type == 'down':
209 | self.pool = nn.AvgPool2d(kernel_size=2, stride=2)
210 | else:
211 | raise NotImplementedError('sample type [%s] is not found' % sample_type)
212 |
213 | kwargs = {'kernel_size': 3, 'stride': 1, 'padding': 1}
214 | kwargs_short = {'kernel_size': 1, 'stride': 1, 'padding': 0}
215 |
216 | self.conv1 = coord_conv(input_nc, hidden_nc, use_spect, use_coord, **kwargs)
217 | self.conv2 = coord_conv(hidden_nc, output_nc, use_spect, use_coord, **kwargs)
218 | self.bypass = coord_conv(input_nc, output_nc, use_spect, use_coord, **kwargs_short)
219 |
220 | if type(norm_layer) == type(None):
221 | self.model = nn.Sequential(nonlinearity, self.conv1, nonlinearity, self.conv2,)
222 | else:
223 | self.model = nn.Sequential(norm_layer(input_nc), nonlinearity, self.conv1, norm_layer(hidden_nc), nonlinearity, self.conv2,)
224 |
225 | self.shortcut = nn.Sequential(self.bypass,)
226 |
227 | def forward(self, x):
228 | if self.sample:
229 | out = self.pool(self.model(x)) + self.pool(self.shortcut(x))
230 | else:
231 | out = self.model(x) + self.shortcut(x)
232 |
233 | return out
234 |
235 |
236 | class EncoderBlockOptimized(nn.Module):
237 | """
238 | Define an Encoder block for the first layer of the generator
239 | """
240 | def __init__(self, input_nc, output_nc, norm_layer=nn.BatchNorm2d, nonlinearity=nn.LeakyReLU(),
241 | use_spect=False, use_coord=False):
242 | super(EncoderBlockOptimized, self).__init__()
243 |
244 | kwargs_down = {'kernel_size': 4, 'stride': 2, 'padding': 1}
245 | kwargs_fine = {'kernel_size': 3, 'stride': 1, 'padding': 1}
246 |
247 | conv1 = coord_conv(input_nc, output_nc, use_spect, use_coord, **kwargs_down)
248 | conv2 = coord_conv(output_nc, output_nc, use_spect, use_coord, **kwargs_fine)
249 |
250 | if type(norm_layer) == type(None):
251 | self.model = nn.Sequential(conv1, nonlinearity, conv2)
252 | else:
253 | self.model = nn.Sequential(conv1, norm_layer(output_nc), nonlinearity, conv2)
254 |
255 | def forward(self, x):
256 | out = self.model(x)
257 | return out
258 |
259 |
260 | class EncoderBlock(nn.Module):
261 | """
262 |     Define an Encoder block for the intermediate layers of the generator
263 | """
264 | def __init__(self, input_nc, output_nc, norm_layer=nn.BatchNorm2d, nonlinearity=nn.LeakyReLU(),
265 | use_spect=False, use_coord=False):
266 | super(EncoderBlock, self).__init__()
267 |
268 |
269 | kwargs_down = {'kernel_size': 4, 'stride': 2, 'padding': 1}
270 | kwargs_fine = {'kernel_size': 3, 'stride': 1, 'padding': 1}
271 |
272 | conv1 = coord_conv(input_nc, output_nc, use_spect, use_coord, **kwargs_down)
273 | conv2 = coord_conv(output_nc, output_nc, use_spect, use_coord, **kwargs_fine)
274 |
275 | if type(norm_layer) == type(None):
276 | self.model = nn.Sequential(conv1, nonlinearity, conv2, nonlinearity)
277 | else:
278 | self.model = nn.Sequential(norm_layer(input_nc), nonlinearity, conv1,
279 | norm_layer(output_nc), nonlinearity, conv2)
280 |
281 | def forward(self, x):
282 | out = self.model(x)
283 | return out
284 |
285 |
286 | class ResBlockDecoder(nn.Module):
287 | """
288 | Define a decoder block
289 | """
290 | def __init__(self, input_nc, output_nc, hidden_nc=None, norm_layer=nn.BatchNorm2d, nonlinearity= nn.LeakyReLU(),
291 | use_spect=False, use_coord=False):
292 | super(ResBlockDecoder, self).__init__()
293 |
294 | hidden_nc = output_nc if hidden_nc is None else hidden_nc
295 |
296 | conv1 = spectral_norm(nn.Conv2d(input_nc, hidden_nc, kernel_size=3, stride=1, padding=1), use_spect)
297 | conv2 = spectral_norm(nn.ConvTranspose2d(hidden_nc, output_nc, kernel_size=3, stride=2, padding=1, output_padding=1), use_spect)
298 | bypass = spectral_norm(nn.ConvTranspose2d(input_nc, output_nc, kernel_size=3, stride=2, padding=1, output_padding=1), use_spect)
299 |
300 | if type(norm_layer) == type(None):
301 | self.model = nn.Sequential(nonlinearity, conv1, nonlinearity, conv2,)
302 | else:
303 | self.model = nn.Sequential(norm_layer(input_nc), nonlinearity, conv1, norm_layer(hidden_nc), nonlinearity, conv2,)
304 |
305 | self.shortcut = nn.Sequential(bypass)
306 |
307 | def forward(self, x):
308 | out = self.model(x) + self.shortcut(x)
309 |
310 | return out
311 |
312 |
313 | class ResBlockEncoderOptimized(nn.Module):
314 | """
315 | Define an Encoder block for the first layer of the discriminator
316 | """
317 | def __init__(self, input_nc, output_nc, hidden_nc=None, norm_layer=nn.BatchNorm2d, nonlinearity= nn.LeakyReLU(),
318 | use_spect=False, use_coord=False):
319 | super(ResBlockEncoderOptimized, self).__init__()
320 |
321 | hidden_nc = input_nc if hidden_nc is None else hidden_nc
322 |
323 | conv1 = spectral_norm(nn.Conv2d(input_nc, hidden_nc, kernel_size=3, stride=1, padding=1), use_spect)
324 | conv2 = spectral_norm(nn.Conv2d(hidden_nc, output_nc, kernel_size=4, stride=2, padding=1), use_spect)
325 | bypass = spectral_norm(nn.Conv2d(input_nc, output_nc, kernel_size=1, stride=1, padding=0), use_spect)
326 |
327 | if type(norm_layer) == type(None):
328 | self.model = nn.Sequential(conv1, nonlinearity, conv2,)
329 | else:
330 | self.model = nn.Sequential(conv1, norm_layer(hidden_nc), nonlinearity, conv2,)
331 | self.shortcut = nn.Sequential(nn.AvgPool2d(kernel_size=2, stride=2), bypass)
332 |
333 | def forward(self, x):
334 | out = self.model(x) + self.shortcut(x)
335 | return out
336 |
337 |
338 | class ResBlockEncoder(nn.Module):
339 | """
340 |     Define an Encoder block for the intermediate layers of the discriminator
341 | """
342 | def __init__(self, input_nc, output_nc, hidden_nc=None, norm_layer=nn.BatchNorm2d, nonlinearity= nn.LeakyReLU(),
343 | use_spect=False, use_coord=False):
344 | super(ResBlockEncoder, self).__init__()
345 |
346 | hidden_nc = input_nc if hidden_nc is None else hidden_nc
347 |
348 | conv1 = spectral_norm(nn.Conv2d(input_nc, hidden_nc, kernel_size=3, stride=1, padding=1), use_spect)
349 | conv2 = spectral_norm(nn.Conv2d(hidden_nc, output_nc, kernel_size=4, stride=2, padding=1), use_spect)
350 | bypass = spectral_norm(nn.Conv2d(input_nc, output_nc, kernel_size=1, stride=1, padding=0), use_spect)
351 |
352 | if type(norm_layer) == type(None):
353 | self.model = nn.Sequential(nonlinearity, conv1, nonlinearity, conv2,)
354 | else:
355 | self.model = nn.Sequential(norm_layer(input_nc), nonlinearity, conv1,
356 | norm_layer(hidden_nc), nonlinearity, conv2,)
357 | self.shortcut = nn.Sequential(nn.AvgPool2d(kernel_size=2, stride=2), bypass)
358 |
359 | def forward(self, x):
360 | out = self.model(x) + self.shortcut(x)
361 | return out
362 |
363 |
364 | class Output(nn.Module):
365 | """
366 | Define the output layer
367 | """
368 | def __init__(self, input_nc, output_nc, kernel_size = 3, norm_layer=nn.BatchNorm2d, nonlinearity= nn.LeakyReLU(),
369 | use_spect=False, use_coord=False):
370 | super(Output, self).__init__()
371 |
372 | kwargs = {'kernel_size': kernel_size, 'padding':0, 'bias': True}
373 |
374 | self.conv1 = coord_conv(input_nc, output_nc, use_spect, use_coord, **kwargs)
375 |
376 | if type(norm_layer) == type(None):
377 | self.model = nn.Sequential(nonlinearity, nn.ReflectionPad2d(int(kernel_size/2)), self.conv1, nn.Tanh())
378 | else:
379 | self.model = nn.Sequential(norm_layer(input_nc), nonlinearity, nn.ReflectionPad2d(int(kernel_size / 2)), self.conv1, nn.Tanh())
380 |
381 | def forward(self, x):
382 | out = self.model(x)
383 |
384 | return out
385 |
386 |
387 | class Auto_Attn(nn.Module):
388 | """ Short+Long attention Layer"""
389 |
390 | def __init__(self, input_nc, norm_layer=nn.BatchNorm2d):
391 | super(Auto_Attn, self).__init__()
392 | self.input_nc = input_nc
393 |
394 | self.query_conv = nn.Conv2d(input_nc, input_nc // 4, kernel_size=1)
395 | self.gamma = nn.Parameter(torch.zeros(1))
396 | self.alpha = nn.Parameter(torch.zeros(1))
397 |
398 | self.softmax = nn.Softmax(dim=-1)
399 |
400 | self.model = ResBlock(int(input_nc*2), input_nc, input_nc, norm_layer=norm_layer, use_spect=True)
401 |
402 | def forward(self, x, pre=None, mask=None):
403 | """
404 | inputs :
405 | x : input feature maps( B X C X W X H)
406 | returns :
407 | out : self attention value + input feature
408 | attention: B X N X N (N is Width*Height)
409 | """
410 | B, C, W, H = x.size()
411 |         proj_query = self.query_conv(x).view(B, -1, W * H)  # B X (C/4) X N
412 |         proj_key = proj_query  # B X (C/4) X N
413 |
414 | energy = torch.bmm(proj_query.permute(0, 2, 1), proj_key) # transpose check
415 | attention = self.softmax(energy) # BX (N) X (N)
416 | proj_value = x.view(B, -1, W * H) # B X C X N
417 |
418 | out = torch.bmm(proj_value, attention.permute(0, 2, 1))
419 | out = out.view(B, C, W, H)
420 |
421 | out = self.gamma * out + x
422 |
423 | if type(pre) != type(None):
424 | # using long distance attention layer to copy information from valid regions
425 | context_flow = torch.bmm(pre.view(B, -1, W*H), attention.permute(0, 2, 1)).view(B, -1, W, H)
426 | context_flow = self.alpha * (1-mask) * context_flow + (mask) * pre
427 | out = self.model(torch.cat([out, context_flow], dim=1))
428 |
429 | return out, attention
430 |
--------------------------------------------------------------------------------
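As a quick sanity check of the building blocks above, the sketch below pairs an EncoderBlock with a ResBlockDecoder so the spatial down/up-sampling cancels out; the channel counts and feature-map size are arbitrary assumptions chosen only for illustration.

import torch
import torch.nn as nn
from models.base_function import (get_norm_layer, get_nonlinearity_layer,
                                  EncoderBlock, ResBlockDecoder, init_net)

norm_layer = get_norm_layer('instance')
nonlinearity = get_nonlinearity_layer('LeakyReLU')

down = EncoderBlock(64, 128, norm_layer, nonlinearity, use_spect=True)       # H x W -> H/2 x W/2
up = ResBlockDecoder(128, 64, 64, norm_layer, nonlinearity, use_spect=True)  # H/2 x W/2 -> H x W

x = torch.randn(1, 64, 64, 64)
print(up(down(x)).shape)   # torch.Size([1, 64, 64, 64])

# init_net prints the parameter count, optionally moves the model to GPU, and applies init_weights.
model = init_net(nn.Sequential(down, up), init_type='orthogonal', gpu_ids=[])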
/models/ui_model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.autograd import Variable
3 | from collections import OrderedDict
4 | import numpy as np
5 | import os
6 | from PIL import Image
7 | import util.util as util
8 | from .base_model import BaseModel
9 | from . import networks
10 |
11 | class UIModel(BaseModel):
12 | def name(self):
13 | return 'UIModel'
14 |
15 | def initialize(self, opt):
16 | assert(not opt.isTrain)
17 | BaseModel.initialize(self, opt)
18 | self.use_features = opt.instance_feat or opt.label_feat
19 |
20 | netG_input_nc = opt.label_nc
21 | if not opt.no_instance:
22 | netG_input_nc += 1
23 | if self.use_features:
24 | netG_input_nc += opt.feat_num
25 |
26 | self.netG = networks.define_G(netG_input_nc, opt.output_nc, opt.ngf, opt.netG,
27 | opt.n_downsample_global, opt.n_blocks_global, opt.n_local_enhancers,
28 | opt.n_blocks_local, opt.norm, gpu_ids=self.gpu_ids)
29 | self.load_network(self.netG, 'G', opt.which_epoch)
30 |
31 | print('---------- Networks initialized -------------')
32 |
33 | def toTensor(self, img, normalize=False):
34 | tensor = torch.from_numpy(np.array(img, np.int32, copy=False))
35 | tensor = tensor.view(1, img.size[1], img.size[0], len(img.mode))
36 | tensor = tensor.transpose(1, 2).transpose(1, 3).contiguous()
37 | if normalize:
38 | return (tensor.float()/255.0 - 0.5) / 0.5
39 | return tensor.float()
40 |
41 | def load_image(self, label_path, inst_path, feat_path):
42 | opt = self.opt
43 | # read label map
44 | label_img = Image.open(label_path)
45 | if label_path.find('face') != -1:
46 | label_img = label_img.convert('L')
47 | ow, oh = label_img.size
48 | w = opt.loadSize
49 | h = int(w * oh / ow)
50 | label_img = label_img.resize((w, h), Image.NEAREST)
51 | label_map = self.toTensor(label_img)
52 |
53 | # onehot vector input for label map
54 | self.label_map = label_map.cuda()
55 | oneHot_size = (1, opt.label_nc, h, w)
56 | input_label = self.Tensor(torch.Size(oneHot_size)).zero_()
57 | self.input_label = input_label.scatter_(1, label_map.long().cuda(), 1.0)
58 |
59 | # read instance map
60 | if not opt.no_instance:
61 | inst_img = Image.open(inst_path)
62 | inst_img = inst_img.resize((w, h), Image.NEAREST)
63 | self.inst_map = self.toTensor(inst_img).cuda()
64 | self.edge_map = self.get_edges(self.inst_map)
65 | self.net_input = Variable(torch.cat((self.input_label, self.edge_map), dim=1), volatile=True)
66 | else:
67 | self.net_input = Variable(self.input_label, volatile=True)
68 |
69 | self.features_clustered = np.load(feat_path).item()
70 | self.object_map = self.inst_map if opt.instance_feat else self.label_map
71 |
72 | object_np = self.object_map.cpu().numpy().astype(int)
73 | self.feat_map = self.Tensor(1, opt.feat_num, h, w).zero_()
74 | self.cluster_indices = np.zeros(self.opt.label_nc, np.uint8)
75 | for i in np.unique(object_np):
76 | label = i if i < 1000 else i//1000
77 | if label in self.features_clustered:
78 | feat = self.features_clustered[label]
79 | np.random.seed(i+1)
80 | cluster_idx = np.random.randint(0, feat.shape[0])
81 | self.cluster_indices[label] = cluster_idx
82 | idx = (self.object_map == i).nonzero()
83 | self.set_features(idx, feat, cluster_idx)
84 |
85 | self.net_input_original = self.net_input.clone()
86 | self.label_map_original = self.label_map.clone()
87 | self.feat_map_original = self.feat_map.clone()
88 | if not opt.no_instance:
89 | self.inst_map_original = self.inst_map.clone()
90 |
91 | def reset(self):
92 | self.net_input = self.net_input_prev = self.net_input_original.clone()
93 | self.label_map = self.label_map_prev = self.label_map_original.clone()
94 | self.feat_map = self.feat_map_prev = self.feat_map_original.clone()
95 | if not self.opt.no_instance:
96 | self.inst_map = self.inst_map_prev = self.inst_map_original.clone()
97 | self.object_map = self.inst_map if self.opt.instance_feat else self.label_map
98 |
99 | def undo(self):
100 | self.net_input = self.net_input_prev
101 | self.label_map = self.label_map_prev
102 | self.feat_map = self.feat_map_prev
103 | if not self.opt.no_instance:
104 | self.inst_map = self.inst_map_prev
105 | self.object_map = self.inst_map if self.opt.instance_feat else self.label_map
106 |
107 | # get boundary map from instance map
108 | def get_edges(self, t):
109 | edge = torch.cuda.ByteTensor(t.size()).zero_()
110 | edge[:,:,:,1:] = edge[:,:,:,1:] | (t[:,:,:,1:] != t[:,:,:,:-1])
111 | edge[:,:,:,:-1] = edge[:,:,:,:-1] | (t[:,:,:,1:] != t[:,:,:,:-1])
112 | edge[:,:,1:,:] = edge[:,:,1:,:] | (t[:,:,1:,:] != t[:,:,:-1,:])
113 | edge[:,:,:-1,:] = edge[:,:,:-1,:] | (t[:,:,1:,:] != t[:,:,:-1,:])
114 | return edge.float()
115 |
116 | # change the label at the source position to the label at the target position
117 | def change_labels(self, click_src, click_tgt):
118 | y_src, x_src = click_src[0], click_src[1]
119 | y_tgt, x_tgt = click_tgt[0], click_tgt[1]
120 | label_src = int(self.label_map[0, 0, y_src, x_src])
121 | inst_src = self.inst_map[0, 0, y_src, x_src]
122 | label_tgt = int(self.label_map[0, 0, y_tgt, x_tgt])
123 | inst_tgt = self.inst_map[0, 0, y_tgt, x_tgt]
124 |
125 | idx_src = (self.inst_map == inst_src).nonzero()
126 | # need to change 3 things: label map, instance map, and feature map
127 | if idx_src.shape:
128 | # backup current maps
129 | self.backup_current_state()
130 |
131 | # change both the label map and the network input
132 | self.label_map[idx_src[:,0], idx_src[:,1], idx_src[:,2], idx_src[:,3]] = label_tgt
133 | self.net_input[idx_src[:,0], idx_src[:,1] + label_src, idx_src[:,2], idx_src[:,3]] = 0
134 | self.net_input[idx_src[:,0], idx_src[:,1] + label_tgt, idx_src[:,2], idx_src[:,3]] = 1
135 |
136 | # update the instance map (and the network input)
137 | if inst_tgt > 1000:
138 | # if different instances have different ids, give the new object a new id
139 | tgt_indices = (self.inst_map > label_tgt * 1000) & (self.inst_map < (label_tgt+1) * 1000)
140 | inst_tgt = self.inst_map[tgt_indices].max() + 1
141 | self.inst_map[idx_src[:,0], idx_src[:,1], idx_src[:,2], idx_src[:,3]] = inst_tgt
142 | self.net_input[:,-1,:,:] = self.get_edges(self.inst_map)
143 |
144 | # also copy the source features to the target position
145 | idx_tgt = (self.inst_map == inst_tgt).nonzero()
146 | if idx_tgt.shape:
147 | self.copy_features(idx_src, idx_tgt[0,:])
148 |
149 | self.fake_image = util.tensor2im(self.single_forward(self.net_input, self.feat_map))
150 |
151 | # add strokes of target label in the image
152 | def add_strokes(self, click_src, label_tgt, bw, save):
153 | # get the region of the new strokes (bw is the brush width)
154 | size = self.net_input.size()
155 | h, w = size[2], size[3]
156 | idx_src = torch.LongTensor(bw**2, 4).fill_(0)
157 | for i in range(bw):
158 | idx_src[i*bw:(i+1)*bw, 2] = min(h-1, max(0, click_src[0]-bw//2 + i))
159 | for j in range(bw):
160 | idx_src[i*bw+j, 3] = min(w-1, max(0, click_src[1]-bw//2 + j))
161 | idx_src = idx_src.cuda()
162 |
163 | # again, need to update 3 things
164 | if idx_src.shape:
165 | # backup current maps
166 | if save:
167 | self.backup_current_state()
168 |
169 | # update the label map (and the network input) in the stroke region
170 | self.label_map[idx_src[:,0], idx_src[:,1], idx_src[:,2], idx_src[:,3]] = label_tgt
171 | for k in range(self.opt.label_nc):
172 | self.net_input[idx_src[:,0], idx_src[:,1] + k, idx_src[:,2], idx_src[:,3]] = 0
173 | self.net_input[idx_src[:,0], idx_src[:,1] + label_tgt, idx_src[:,2], idx_src[:,3]] = 1
174 |
175 | # update the instance map (and the network input)
176 | self.inst_map[idx_src[:,0], idx_src[:,1], idx_src[:,2], idx_src[:,3]] = label_tgt
177 | self.net_input[:,-1,:,:] = self.get_edges(self.inst_map)
178 |
179 | # also update the features if available
180 | if self.opt.instance_feat:
181 | feat = self.features_clustered[label_tgt]
182 | #np.random.seed(label_tgt+1)
183 | #cluster_idx = np.random.randint(0, feat.shape[0])
184 | cluster_idx = self.cluster_indices[label_tgt]
185 | self.set_features(idx_src, feat, cluster_idx)
186 |
187 | self.fake_image = util.tensor2im(self.single_forward(self.net_input, self.feat_map))
188 |
189 | # add an object to the clicked position with selected style
190 | def add_objects(self, click_src, label_tgt, mask, style_id=0):
191 | y, x = click_src[0], click_src[1]
192 | mask = np.transpose(mask, (2, 0, 1))[np.newaxis,...]
193 | idx_src = torch.from_numpy(mask).cuda().nonzero()
194 | idx_src[:,2] += y
195 | idx_src[:,3] += x
196 |
197 | # backup current maps
198 | self.backup_current_state()
199 |
200 | # update label map
201 | self.label_map[idx_src[:,0], idx_src[:,1], idx_src[:,2], idx_src[:,3]] = label_tgt
202 | for k in range(self.opt.label_nc):
203 | self.net_input[idx_src[:,0], idx_src[:,1] + k, idx_src[:,2], idx_src[:,3]] = 0
204 | self.net_input[idx_src[:,0], idx_src[:,1] + label_tgt, idx_src[:,2], idx_src[:,3]] = 1
205 |
206 | # update instance map
207 | self.inst_map[idx_src[:,0], idx_src[:,1], idx_src[:,2], idx_src[:,3]] = label_tgt
208 | self.net_input[:,-1,:,:] = self.get_edges(self.inst_map)
209 |
210 | # update feature map
211 | self.set_features(idx_src, self.feat, style_id)
212 |
213 | self.fake_image = util.tensor2im(self.single_forward(self.net_input, self.feat_map))
214 |
215 | def single_forward(self, net_input, feat_map):
216 | net_input = torch.cat((net_input, feat_map), dim=1)
217 | fake_image = self.netG.forward(net_input)
218 |
219 | if fake_image.size()[0] == 1:
220 | return fake_image.data[0]
221 | return fake_image.data
222 |
223 |
224 | # generate all outputs for different styles
225 | def style_forward(self, click_pt, style_id=-1):
226 | if click_pt is None:
227 | self.fake_image = util.tensor2im(self.single_forward(self.net_input, self.feat_map))
228 | self.crop = None
229 | self.mask = None
230 | else:
231 | instToChange = int(self.object_map[0, 0, click_pt[0], click_pt[1]])
232 | self.instToChange = instToChange
233 | label = instToChange if instToChange < 1000 else instToChange//1000
234 | self.feat = self.features_clustered[label]
235 | self.fake_image = []
236 | self.mask = self.object_map == instToChange
237 | idx = self.mask.nonzero()
238 | self.get_crop_region(idx)
239 | if idx.numel():
240 | if style_id == -1:
241 | (min_y, min_x, max_y, max_x) = self.crop
242 | ### original
243 | for cluster_idx in range(self.opt.multiple_output):
244 | self.set_features(idx, self.feat, cluster_idx)
245 | fake_image = self.single_forward(self.net_input, self.feat_map)
246 | fake_image = util.tensor2im(fake_image[:,min_y:max_y,min_x:max_x])
247 | self.fake_image.append(fake_image)
248 | """### To speed up previewing different style results, either crop or downsample the label maps
249 | if instToChange > 1000:
250 | (min_y, min_x, max_y, max_x) = self.crop
251 | ### crop
252 | _, _, h, w = self.net_input.size()
253 | offset = 512
254 | y_start, x_start = max(0, min_y-offset), max(0, min_x-offset)
255 | y_end, x_end = min(h, (max_y + offset)), min(w, (max_x + offset))
256 | y_region = slice(y_start, y_start+(y_end-y_start)//16*16)
257 | x_region = slice(x_start, x_start+(x_end-x_start)//16*16)
258 | net_input = self.net_input[:,:,y_region,x_region]
259 | for cluster_idx in range(self.opt.multiple_output):
260 | self.set_features(idx, self.feat, cluster_idx)
261 | fake_image = self.single_forward(net_input, self.feat_map[:,:,y_region,x_region])
262 | fake_image = util.tensor2im(fake_image[:,min_y-y_start:max_y-y_start,min_x-x_start:max_x-x_start])
263 | self.fake_image.append(fake_image)
264 | else:
265 | ### downsample
266 | (min_y, min_x, max_y, max_x) = [crop//2 for crop in self.crop]
267 | net_input = self.net_input[:,:,::2,::2]
268 | size = net_input.size()
269 | net_input_batch = net_input.expand(self.opt.multiple_output, size[1], size[2], size[3])
270 | for cluster_idx in range(self.opt.multiple_output):
271 | self.set_features(idx, self.feat, cluster_idx)
272 | feat_map = self.feat_map[:,:,::2,::2]
273 | if cluster_idx == 0:
274 | feat_map_batch = feat_map
275 | else:
276 | feat_map_batch = torch.cat((feat_map_batch, feat_map), dim=0)
277 | fake_image_batch = self.single_forward(net_input_batch, feat_map_batch)
278 | for i in range(self.opt.multiple_output):
279 | self.fake_image.append(util.tensor2im(fake_image_batch[i,:,min_y:max_y,min_x:max_x]))"""
280 |
281 | else:
282 | self.set_features(idx, self.feat, style_id)
283 | self.cluster_indices[label] = style_id
284 | self.fake_image = util.tensor2im(self.single_forward(self.net_input, self.feat_map))
285 |
286 | def backup_current_state(self):
287 | self.net_input_prev = self.net_input.clone()
288 | self.label_map_prev = self.label_map.clone()
289 | self.inst_map_prev = self.inst_map.clone()
290 | self.feat_map_prev = self.feat_map.clone()
291 |
292 | # crop the ROI and get the mask of the object
293 | def get_crop_region(self, idx):
294 | size = self.net_input.size()
295 | h, w = size[2], size[3]
296 | min_y, min_x = idx[:,2].min(), idx[:,3].min()
297 | max_y, max_x = idx[:,2].max(), idx[:,3].max()
298 | crop_min = 128
299 | if max_y - min_y < crop_min:
300 | min_y = max(0, (max_y + min_y) // 2 - crop_min // 2)
301 | max_y = min(h-1, min_y + crop_min)
302 | if max_x - min_x < crop_min:
303 | min_x = max(0, (max_x + min_x) // 2 - crop_min // 2)
304 | max_x = min(w-1, min_x + crop_min)
305 | self.crop = (min_y, min_x, max_y, max_x)
306 | self.mask = self.mask[:,:, min_y:max_y, min_x:max_x]
307 |
308 | # update the feature map once a new object is added or the label is changed
309 | def update_features(self, cluster_idx, mask=None, click_pt=None):
310 | self.feat_map_prev = self.feat_map.clone()
311 | # adding a new object
312 | if mask is not None:
313 | y, x = click_pt[0], click_pt[1]
314 | mask = np.transpose(mask, (2,0,1))[np.newaxis,...]
315 | idx = torch.from_numpy(mask).cuda().nonzero()
316 | idx[:,2] += y
317 | idx[:,3] += x
318 | # changing the label of an existing object
319 | else:
320 | idx = (self.object_map == self.instToChange).nonzero()
321 |
322 | # update feature map
323 | self.set_features(idx, self.feat, cluster_idx)
324 |
325 | # set the class features to the target feature
326 | def set_features(self, idx, feat, cluster_idx):
327 | for k in range(self.opt.feat_num):
328 | self.feat_map[idx[:,0], idx[:,1] + k, idx[:,2], idx[:,3]] = feat[cluster_idx, k]
329 |
330 | # copy the features at the target position to the source position
331 | def copy_features(self, idx_src, idx_tgt):
332 | for k in range(self.opt.feat_num):
333 | val = self.feat_map[idx_tgt[0], idx_tgt[1] + k, idx_tgt[2], idx_tgt[3]]
334 | self.feat_map[idx_src[:,0], idx_src[:,1] + k, idx_src[:,2], idx_src[:,3]] = val
335 |
336 | def get_current_visuals(self, getLabel=False):
337 | mask = self.mask
338 | if self.mask is not None:
339 | mask = np.transpose(self.mask[0].cpu().float().numpy(), (1,2,0)).astype(np.uint8)
340 |
341 | dict_list = [('fake_image', self.fake_image), ('mask', mask)]
342 |
343 | if getLabel: # only output label map if needed to save bandwidth
344 | label = util.tensor2label(self.net_input.data[0], self.opt.label_nc)
345 | dict_list += [('label', label)]
346 |
347 | return OrderedDict(dict_list)
--------------------------------------------------------------------------------
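For reference, the boundary map that the `get_edges` method above derives from the instance map can be reproduced with the short, self-contained CPU sketch below. `get_edges_cpu` is a hypothetical helper used only for illustration; the method in `ui_model.py` works on a CUDA ByteTensor of shape (B, 1, H, W).

```python
import torch

# Self-contained CPU sketch of the boundary-map logic in the get_edges method above.
# get_edges_cpu is a hypothetical helper for illustration only.
def get_edges_cpu(t):
    edge = torch.zeros_like(t, dtype=torch.bool)
    edge[:, :, :, 1:]  |= t[:, :, :, 1:] != t[:, :, :, :-1]   # left/right neighbours differ
    edge[:, :, :, :-1] |= t[:, :, :, 1:] != t[:, :, :, :-1]
    edge[:, :, 1:, :]  |= t[:, :, 1:, :] != t[:, :, :-1, :]   # up/down neighbours differ
    edge[:, :, :-1, :] |= t[:, :, 1:, :] != t[:, :, :-1, :]
    return edge.float()

# A 1x1x4x4 instance map with two instances (ids 1 and 2): the returned map is 1
# exactly on the two columns that touch the instance boundary.
inst = torch.tensor([[[[1, 1, 2, 2],
                       [1, 1, 2, 2],
                       [1, 1, 2, 2],
                       [1, 1, 2, 2]]]])
print(get_edges_cpu(inst))
```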
/metrics/metrics.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pathlib
3 | import torch
4 | import numpy as np
5 | from imageio import imread
6 | from scipy import linalg
7 | from torch.nn.functional import adaptive_avg_pool2d
8 | from skimage.measure import compare_ssim
9 | from skimage.measure import compare_psnr
10 | import glob
11 | import argparse
12 | import matplotlib.pyplot as plt
13 | from metrics.inception import InceptionV3
14 | from metrics.PerceptualSimilarity.models import dist_model as dm
15 | import pandas as pd
16 | import json
17 | import imageio
18 | from skimage.draw import circle, line_aa, polygon
19 |
20 |
21 | def pad_256(img):
22 | result = np.ones((256, 256, 3), dtype=float) * 255
23 | result[:,40:216,:] = img
24 | return result
25 |
26 |
27 | class FID():
28 | """docstring for FID
29 | Calculates the Frechet Inception Distance (FID) to evaluate GANs
30 | The FID metric calculates the distance between two distributions of images.
31 | Typically, we have summary statistics (mean & covariance matrix) of one
32 | of these distributions, while the 2nd distribution is given by a GAN.
33 | When run as a stand-alone program, it compares the distribution of
34 | images that are stored as PNG/JPEG at a specified location with a
35 | distribution given by summary statistics (in pickle format).
36 | The FID is calculated by assuming that X_1 and X_2 are the activations of
37 | the pool_3 layer of the inception net for generated samples and real world
38 | samples respectivly.
39 | See --help to see further details.
40 | Code adapted from https://github.com/bioinf-jku/TTUR to use PyTorch instead
41 | of Tensorflow
42 | Copyright 2018 Institute of Bioinformatics, JKU Linz
43 | Licensed under the Apache License, Version 2.0 (the "License");
44 | you may not use this file except in compliance with the License.
45 | You may obtain a copy of the License at
46 | http://www.apache.org/licenses/LICENSE-2.0
47 | Unless required by applicable law or agreed to in writing, software
48 | distributed under the License is distributed on an "AS IS" BASIS,
49 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
50 | See the License for the specific language governing permissions and
51 | limitations under the License.
52 | """
53 | def __init__(self):
54 | self.dims = 2048
55 | self.batch_size = 64
56 | self.cuda = True
57 | self.verbose=False
58 |
59 | block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[self.dims]
60 | self.model = InceptionV3([block_idx])
61 | if self.cuda:
62 | # TODO: put model into specific GPU
63 | self.model.cuda()
64 |
65 | def __call__(self, images, gt_path):
66 | """ images: list of the generated image. The values must lie between 0 and 1.
67 | gt_path: the path of the ground truth images. The values must lie between 0 and 1.
68 | """
69 | if not os.path.exists(gt_path):
70 | raise RuntimeError('Invalid path: %s' % gt_path)
71 |
72 | print('calculate gt_path statistics...')
73 | m1, s1 = self.compute_statistics_of_path(gt_path, self.verbose)
74 | print('calculate generated_images statistics...')
75 | m2, s2 = self.calculate_activation_statistics(images, self.verbose)
76 | fid_value = self.calculate_frechet_distance(m1, s1, m2, s2)
77 | return fid_value
78 |
79 | def calculate_from_disk(self, generated_path, gt_path):
80 | """
81 | """
82 | if not os.path.exists(gt_path):
83 | raise RuntimeError('Invalid path: %s' % gt_path)
84 | if not os.path.exists(generated_path):
85 | raise RuntimeError('Invalid path: %s' % generated_path)
86 |
87 | print('calculate gt_path statistics...')
88 | m1, s1 = self.compute_statistics_of_path(gt_path, self.verbose)
89 | print('calculate generated_path statistics...')
90 | m2, s2 = self.compute_statistics_of_path(generated_path, self.verbose)
91 | print('calculate frechet distance...')
92 | fid_value = self.calculate_frechet_distance(m1, s1, m2, s2)
93 | print('fid_distance %f' % (fid_value))
94 | return fid_value
95 |
96 | def compute_statistics_of_path(self, path, verbose):
97 | npz_file = os.path.join(path, 'statistics.npz')
98 | if os.path.exists(npz_file):
99 | f = np.load(npz_file)
100 | m, s = f['mu'][:], f['sigma'][:]
101 | f.close()
102 | else:
103 | path = pathlib.Path(path)
104 | files = list(path.glob('*.jpg')) + list(path.glob('*.png'))
105 |
106 | imgs = np.array([imread(str(fn)).astype(np.float32) for fn in files])
107 |
108 | # Bring images to shape (B, 3, H, W)
109 | imgs = imgs.transpose((0, 3, 1, 2))
110 |
111 | # Rescale images to be between 0 and 1
112 | imgs /= 255
113 |
114 | m, s = self.calculate_activation_statistics(imgs, verbose)
115 | np.savez(npz_file, mu=m, sigma=s)
116 |
117 | return m, s
118 |
119 | def calculate_activation_statistics(self, images, verbose):
120 | """Calculation of the statistics used by the FID.
121 | Params:
122 | -- images : Numpy array of dimension (n_images, 3, hi, wi). The values
123 | must lie between 0 and 1.
124 | -- model : Instance of inception model
125 | -- batch_size : The images numpy array is split into batches with
126 | batch size batch_size. A reasonable batch size
127 | depends on the hardware.
128 | -- dims : Dimensionality of features returned by Inception
129 | -- cuda : If set to True, use GPU
130 | -- verbose : If set to True and parameter out_step is given, the
131 | number of calculated batches is reported.
132 | Returns:
133 | -- mu : The mean over samples of the activations of the pool_3 layer of
134 | the inception model.
135 | -- sigma : The covariance matrix of the activations of the pool_3 layer of
136 | the inception model.
137 | """
138 | act = self.get_activations(images, verbose)
139 | mu = np.mean(act, axis=0)
140 | sigma = np.cov(act, rowvar=False)
141 | return mu, sigma
142 |
143 | def get_activations(self, images, verbose=False):
144 | """Calculates the activations of the pool_3 layer for all images.
145 | Params:
146 | -- images : Numpy array of dimension (n_images, 3, hi, wi). The values
147 | must lie between 0 and 1.
148 | -- model : Instance of inception model
149 | -- batch_size : the images numpy array is split into batches with
150 | batch size batch_size. A reasonable batch size depends
151 | on the hardware.
152 | -- dims : Dimensionality of features returned by Inception
153 | -- cuda : If set to True, use GPU
154 | -- verbose : If set to True and parameter out_step is given, the number
155 | of calculated batches is reported.
156 | Returns:
157 | -- A numpy array of dimension (num images, dims) that contains the
158 | activations of the given tensor when feeding inception with the
159 | query tensor.
160 | """
161 | self.model.eval()
162 |
163 | d0 = images.shape[0]
164 | if self.batch_size > d0:
165 | print(('Warning: batch size is bigger than the data size. '
166 | 'Setting batch size to data size'))
167 | self.batch_size = d0
168 |
169 | n_batches = d0 // self.batch_size
170 | n_used_imgs = n_batches * self.batch_size
171 |
172 | pred_arr = np.empty((n_used_imgs, self.dims))
173 | for i in range(n_batches):
174 | if verbose:
175 | print('\rPropagating batch %d/%d' % (i + 1, n_batches))
176 | start = i * self.batch_size
177 | end = start + self.batch_size
178 |
179 | batch = torch.from_numpy(images[start:end]).type(torch.FloatTensor)
180 | if self.cuda:
181 | batch = batch.cuda()
182 |
183 | pred = self.model(batch)[0]
184 |
185 | # If model output is not scalar, apply global spatial average pooling.
186 | # This happens if you choose a dimensionality not equal 2048.
187 | if pred.shape[2] != 1 or pred.shape[3] != 1:
188 | pred = adaptive_avg_pool2d(pred, output_size=(1, 1))
189 |
190 | pred_arr[start:end] = pred.cpu().data.numpy().reshape(self.batch_size, -1)
191 |
192 | if verbose:
193 | print(' done')
194 |
195 | return pred_arr
196 |
197 | def calculate_frechet_distance(self, mu1, sigma1, mu2, sigma2, eps=1e-6):
198 | """Numpy implementation of the Frechet Distance.
199 | The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1)
200 | and X_2 ~ N(mu_2, C_2) is
201 | d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)).
202 | Stable version by Dougal J. Sutherland.
203 | Params:
204 | -- mu1 : Numpy array containing the activations of a layer of the
205 | inception net (like returned by the function 'get_predictions')
206 | for generated samples.
207 | -- mu2 : The sample mean over activations, precalculated on a
208 | representative data set.
209 | -- sigma1: The covariance matrix over activations for generated samples.
210 | -- sigma2: The covariance matrix over activations, precalculated on a
211 | representative data set.
212 | Returns:
213 | -- : The Frechet Distance.
214 | """
215 |
216 | mu1 = np.atleast_1d(mu1)
217 | mu2 = np.atleast_1d(mu2)
218 |
219 | sigma1 = np.atleast_2d(sigma1)
220 | sigma2 = np.atleast_2d(sigma2)
221 |
222 | assert mu1.shape == mu2.shape, \
223 | 'Training and test mean vectors have different lengths'
224 | assert sigma1.shape == sigma2.shape, \
225 | 'Training and test covariances have different dimensions'
226 |
227 | diff = mu1 - mu2
228 |
229 | # Product might be almost singular
230 | covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
231 | if not np.isfinite(covmean).all():
232 | msg = ('fid calculation produces singular product; '
233 | 'adding %s to diagonal of cov estimates') % eps
234 | print(msg)
235 | offset = np.eye(sigma1.shape[0]) * eps
236 | covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))
237 |
238 | # Numerical error might give slight imaginary component
239 | if np.iscomplexobj(covmean):
240 | if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
241 | m = np.max(np.abs(covmean.imag))
242 | raise ValueError('Imaginary component {}'.format(m))
243 | covmean = covmean.real
244 |
245 | tr_covmean = np.trace(covmean)
246 |
247 | return (diff.dot(diff) + np.trace(sigma1) +
248 | np.trace(sigma2) - 2 * tr_covmean)
249 |
250 |
251 | class Reconstruction_Metrics():
252 | def __init__(self, metric_list=['ssim', 'psnr', 'l1', 'mae'], data_range=1, win_size=51, multichannel=True):
253 | self.data_range = data_range
254 | self.win_size = win_size
255 | self.multichannel = multichannel
256 | for metric in metric_list:
257 | if metric in ['ssim', 'psnr', 'l1', 'mae']:
258 | setattr(self, metric, True)
259 | else:
260 | print('unsupported reconstruction metric: %s' % metric)
261 |
262 | def __call__(self, inputs, gts):
263 | """
264 | inputs: the generated image, size (b,c,w,h), data range(0, data_range)
265 | gts: the ground-truth image, size (b,c,w,h), data range(0, data_range)
266 | """
267 | result = dict()
268 | [b,n,w,h] = inputs.size()
269 | inputs = inputs.view(b*n, w, h).detach().cpu().numpy().astype(np.float32).transpose(1,2,0)
270 | gts = gts.view(b*n, w, h).detach().cpu().numpy().astype(np.float32).transpose(1,2,0)
271 |
272 | if hasattr(self, 'ssim'):
273 | ssim_value = compare_ssim(inputs, gts, data_range=self.data_range,
274 | win_size=self.win_size, multichannel=self.multichannel)
275 | result['ssim'] = ssim_value
276 |
277 |
278 | if hasattr(self, 'psnr'):
279 | psnr_value = compare_psnr(inputs, gts, self.data_range)
280 | result['psnr'] = psnr_value
281 |
282 | if hasattr(self, 'l1'):
283 | l1_value = compare_l1(inputs, gts)
284 | result['l1'] = l1_value
285 |
286 | if hasattr(self, 'mae'):
287 | mae_value = compare_mae(inputs, gts)
288 | result['mae'] = mae_value
289 | return result
290 |
291 | def calculate_from_disk(self, inputs, gts, save_path=None, sort=True, debug=0):
292 | """
293 | inputs: a .txt file list, a folder, an image file path (string), or a list of image files
294 | gts: a .txt file list, a folder, an image file path (string), or a list of image files
295 | """
296 | if sort:
297 | input_image_list = sorted(get_image_list(inputs))
298 | gt_image_list = sorted(get_image_list(gts))
299 | else:
300 | input_image_list = get_image_list(inputs)
301 | gt_image_list = get_image_list(gts)
302 | npz_file = os.path.join(save_path, 'metrics.npz')
303 | if os.path.exists(npz_file):
304 | f = np.load(npz_file)
305 | psnr,ssim,ssim_256,mae,l1=f['psnr'],f['ssim'],f['ssim_256'],f['mae'],f['l1']
306 | else:
307 | psnr = []
308 | ssim = []
309 | ssim_256 = []
310 | mae = []
311 | l1 = []
312 | names = []
313 |
314 | for index in range(len(input_image_list)):
315 | name = os.path.basename(input_image_list[index])
316 | names.append(name)
317 |
318 | img_gt = pad_256(imread(str(gt_image_list[index]))).astype(np.float32) / 255.0
319 | img_pred = pad_256(imread(str(input_image_list[index]))).astype(np.float32) / 255.0
320 |
321 |
322 | if debug != 0:
323 | plt.subplot(121)
324 | plt.imshow(img_gt)
325 | plt.title('Ground truth')
326 | plt.subplot(122)
327 | plt.imshow(img_pred)
328 | plt.title('Output')
329 | plt.show()
330 |
331 | psnr.append(compare_psnr(img_gt, img_pred, data_range=self.data_range))
332 | ssim.append(compare_ssim(img_gt, img_pred, data_range=self.data_range,
333 | win_size=self.win_size,multichannel=self.multichannel))
334 | mae.append(compare_mae(img_gt, img_pred))
335 | l1.append(compare_l1(img_gt, img_pred))
336 |
337 | img_gt_256 = img_gt*255.0
338 | img_pred_256 = img_pred*255.0
339 | ssim_256.append(compare_ssim(img_gt_256, img_pred_256, gaussian_weights=True, sigma=1.5,
340 | use_sample_covariance=False, multichannel=True,
341 | data_range=img_pred_256.max() - img_pred_256.min()))
342 | if np.mod(index, 200) == 0:
343 | print(
344 | str(index) + ' images processed',
345 | "PSNR: %.4f" % round(np.mean(psnr), 4),
346 | "SSIM: %.4f" % round(np.mean(ssim), 4),
347 | "SSIM_256: %.4f" % round(np.mean(ssim_256), 4),
348 | "MAE: %.4f" % round(np.mean(mae), 4),
349 | "l1: %.4f" % round(np.mean(l1), 4),
350 | )
351 |
352 | if save_path:
353 | np.savez(save_path + '/metrics.npz', psnr=psnr, ssim=ssim, ssim_256=ssim_256, mae=mae, l1=l1, names=names)
354 |
355 | print(
356 | "PSNR: %.4f" % round(np.mean(psnr), 4),
357 | "PSNR Variance: %.4f" % round(np.var(psnr), 4),
358 | "SSIM: %.4f" % round(np.mean(ssim), 4),
359 | "SSIM Variance: %.4f" % round(np.var(ssim), 4),
360 | "SSIM_256: %.4f" % round(np.mean(ssim_256), 4),
361 | "SSIM_256 Variance: %.4f" % round(np.var(ssim_256), 4),
362 | "MAE: %.4f" % round(np.mean(mae), 4),
363 | "MAE Variance: %.4f" % round(np.var(mae), 4),
364 | "l1: %.4f" % round(np.mean(l1), 4),
365 | "l1 Variance: %.4f" % round(np.var(l1), 4)
366 | )
367 |
368 | dic = {"psnr":[round(np.mean(psnr), 6)],
369 | "psnr_variance": [round(np.var(psnr), 6)],
370 | "ssim": [round(np.mean(ssim), 6)],
371 | "ssim_variance": [round(np.var(ssim), 6)],
372 | "ssim_256": [round(np.mean(ssim_256), 6)],
373 | "ssim_256_variance": [round(np.var(ssim_256), 6)],
374 | "mae": [round(np.mean(mae), 6)],
375 | "mae_variance": [round(np.var(mae), 6)],
376 | "l1": [round(np.mean(l1), 6)],
377 | "l1_variance": [round(np.var(l1), 6)] }
378 |
379 | return dic
380 |
381 |
382 | class Reconstruction_Market_Metrics():
383 | def __init__(self, metric_list=['ssim', 'psnr', 'l1', 'mae'], data_range=1, win_size=51, multichannel=True):
384 | self.data_range = data_range
385 | self.win_size = win_size
386 | self.multichannel = multichannel
387 | for metric in metric_list:
388 | if metric in ['ssim', 'psnr', 'l1', 'mae']:
389 | setattr(self, metric, True)
390 | else:
391 | print('unsupported reconstruction metric: %s' % metric)
392 |
393 | def __call__(self, inputs, gts):
394 | """
395 | inputs: the generated image, size (b,c,w,h), data range(0, data_range)
396 | gts: the ground-truth image, size (b,c,w,h), data range(0, data_range)
397 | """
398 | result = dict()
399 | [b, n, w, h] = inputs.size()
400 | inputs = inputs.view(b * n, w, h).detach().cpu().numpy().astype(np.float32).transpose(1, 2, 0)
401 | gts = gts.view(b * n, w, h).detach().cpu().numpy().astype(np.float32).transpose(1, 2, 0)
402 |
403 | if hasattr(self, 'ssim'):
404 | ssim_value = compare_ssim(inputs, gts, data_range=self.data_range,
405 | win_size=self.win_size, multichannel=self.multichannel)
406 | result['ssim'] = ssim_value
407 |
408 | if hasattr(self, 'psnr'):
409 | psnr_value = compare_psnr(inputs, gts, self.data_range)
410 | result['psnr'] = psnr_value
411 |
412 | if hasattr(self, 'l1'):
413 | l1_value = compare_l1(inputs, gts)
414 | result['l1'] = l1_value
415 |
416 | if hasattr(self, 'mae'):
417 | mae_value = compare_mae(inputs, gts)
418 | result['mae'] = mae_value
419 | return result
420 |
421 | def calculate_from_disk(self, inputs, gts, save_path=None, sort=True, debug=0):
422 | """
423 | inputs: a .txt file list, a folder, an image file path (string), or a list of image files
424 | gts: a .txt file list, a folder, an image file path (string), or a list of image files
425 | """
426 | if sort:
427 | input_image_list = sorted(get_image_list(inputs))
428 | gt_image_list = sorted(get_image_list(gts))
429 | else:
430 | input_image_list = get_image_list(inputs)
431 | gt_image_list = get_image_list(gts)
432 | npz_file = os.path.join(save_path, 'metrics.npz')
433 | if os.path.exists(npz_file):
434 | f = np.load(npz_file)
435 | psnr, ssim, ssim_256, mae, l1 = f['psnr'], f['ssim'], f['ssim_256'], f['mae'], f['l1']
436 | else:
437 | psnr = []
438 | ssim = []
439 | ssim_256 = []
440 | mae = []
441 | l1 = []
442 | names = []
443 |
444 | for index in range(len(input_image_list)):
445 | name = os.path.basename(input_image_list[index])
446 | names.append(name)
447 |
448 | img_gt = imread(str(gt_image_list[index])).astype(np.float32) / 255.0
449 | img_pred = imread(str(input_image_list[index])).astype(np.float32) / 255.0
450 |
451 | if debug != 0:
452 | plt.subplot(121)
453 | plt.imshow(img_gt)
454 | plt.title('Ground truth')
455 | plt.subplot(122)
456 | plt.imshow(img_pred)
457 | plt.title('Output')
458 | plt.show()
459 |
460 | psnr.append(compare_psnr(img_gt, img_pred, data_range=self.data_range))
461 | ssim.append(compare_ssim(img_gt, img_pred, data_range=self.data_range,
462 | win_size=self.win_size, multichannel=self.multichannel))
463 | mae.append(compare_mae(img_gt, img_pred))
464 | l1.append(compare_l1(img_gt, img_pred))
465 |
466 | img_gt_256 = img_gt * 255.0
467 | img_pred_256 = img_pred * 255.0
468 | ssim_256.append(compare_ssim(img_gt_256, img_pred_256, gaussian_weights=True, sigma=1.5,
469 | use_sample_covariance=False, multichannel=True,
470 | data_range=img_pred_256.max() - img_pred_256.min()))
471 | if np.mod(index, 200) == 0:
472 | print(
473 | str(index) + ' images processed',
474 | "PSNR: %.4f" % round(np.mean(psnr), 4),
475 | "SSIM: %.4f" % round(np.mean(ssim), 4),
476 | "SSIM_256: %.4f" % round(np.mean(ssim_256), 4),
477 | "MAE: %.4f" % round(np.mean(mae), 4),
478 | "l1: %.4f" % round(np.mean(l1), 4),
479 | )
480 |
481 | if save_path:
482 | np.savez(save_path + '/metrics.npz', psnr=psnr, ssim=ssim, ssim_256=ssim_256, mae=mae, l1=l1,
483 | names=names)
484 |
485 | print(
486 | "PSNR: %.4f" % round(np.mean(psnr), 4),
487 | "PSNR Variance: %.4f" % round(np.var(psnr), 4),
488 | "SSIM: %.4f" % round(np.mean(ssim), 4),
489 | "SSIM Variance: %.4f" % round(np.var(ssim), 4),
490 | "SSIM_256: %.4f" % round(np.mean(ssim_256), 4),
491 | "SSIM_256 Variance: %.4f" % round(np.var(ssim_256), 4),
492 | "MAE: %.4f" % round(np.mean(mae), 4),
493 | "MAE Variance: %.4f" % round(np.var(mae), 4),
494 | "l1: %.4f" % round(np.mean(l1), 4),
495 | "l1 Variance: %.4f" % round(np.var(l1), 4)
496 | )
497 |
498 | dic = {"psnr": [round(np.mean(psnr), 6)],
499 | "psnr_variance": [round(np.var(psnr), 6)],
500 | "ssim": [round(np.mean(ssim), 6)],
501 | "ssim_variance": [round(np.var(ssim), 6)],
502 | "ssim_256": [round(np.mean(ssim_256), 6)],
503 | "ssim_256_variance": [round(np.var(ssim_256), 6)],
504 | "mae": [round(np.mean(mae), 6)],
505 | "mae_variance": [round(np.var(mae), 6)],
506 | "l1": [round(np.mean(l1), 6)],
507 | "l1_variance": [round(np.var(l1), 6)]}
508 |
509 | return dic
510 |
511 |
512 | def get_image_list(flist):
513 | if isinstance(flist, list):
514 | return flist
515 |
516 | # flist: image file path, image directory path, text file flist path
517 | if isinstance(flist, str):
518 | if os.path.isdir(flist):
519 | flist = list(glob.glob(flist + '/*.jpg')) + list(glob.glob(flist + '/*.png'))
520 | flist.sort()
521 | return flist
522 |
523 | if os.path.isfile(flist):
524 | try:
525 | return np.genfromtxt(flist, dtype=str)
526 | except Exception:
527 | return [flist]
528 | print('cannot read files from %s, returning an empty list' % flist)
529 | return []
530 |
531 |
532 | def compare_l1(img_true, img_test):
533 | img_true = img_true.astype(np.float32)
534 | img_test = img_test.astype(np.float32)
535 | return np.mean(np.abs(img_true - img_test))
536 |
537 |
538 | def compare_mae(img_true, img_test):
539 | img_true = img_true.astype(np.float32)
540 | img_test = img_test.astype(np.float32)
541 | return np.sum(np.abs(img_true - img_test)) / np.sum(img_true + img_test)
542 |
543 |
544 | def preprocess_path_for_deform_task(gt_path, distorted_path):
545 | distorted_image_list = sorted(get_image_list(distorted_path))
546 | gt_list=[]
547 | distorated_list=[]
548 |
549 | for distorted_image in distorted_image_list:
550 | image = os.path.basename(distorted_image)
551 | image = image.split('_2_')[-1]
552 | image = image.split('_vis')[0] +'.jpg'
553 | gt_image = os.path.join(gt_path, image)
554 | if not os.path.isfile(gt_image):
555 | print("hhhhhhhhh")
556 | print(gt_image)
557 | continue
558 | gt_list.append(gt_image)
559 | distorated_list.append(distorted_image)
560 |
561 | return gt_list, distorated_list
562 |
563 |
564 |
565 | class LPIPS():
566 | def __init__(self, use_gpu=True):
567 | self.model = dm.DistModel()
568 | self.model.initialize(model='net-lin', net='alex',use_gpu=use_gpu)
569 | self.use_gpu=use_gpu
570 |
571 | def __call__(self, image_1, image_2):
572 | """
573 | image_1: images with size (n, 3, w, h) with value [-1, 1]
574 | image_2: images with size (n, 3, w, h) with value [-1, 1]
575 | """
576 | result = self.model.forward(image_1, image_2)
577 | return result
578 |
579 | def calculate_from_disk(self, path_1, path_2, batch_size=1, verbose=False, sort=True):
580 | if sort:
581 | files_1 = sorted(get_image_list(path_1))
582 | files_2 = sorted(get_image_list(path_2))
583 | else:
584 | files_1 = get_image_list(path_1)
585 | files_2 = get_image_list(path_2)
586 |
587 |
588 | imgs_1 = np.array([imread(str(fn)).astype(np.float32)/127.5-1 for fn in files_1])
589 | imgs_2 = np.array([imread(str(fn)).astype(np.float32)/127.5-1 for fn in files_2])
590 |
591 | # Bring images to shape (B, 3, H, W)
592 | imgs_1 = imgs_1.transpose((0, 3, 1, 2))
593 | imgs_2 = imgs_2.transpose((0, 3, 1, 2))
594 |
595 | result=[]
596 |
597 |
598 | d0 = imgs_1.shape[0]
599 | if batch_size > d0:
600 | print(('Warning: batch size is bigger than the data size. '
601 | 'Setting batch size to data size'))
602 | batch_size = d0
603 |
604 | n_batches = d0 // batch_size
605 |
606 | for i in range(n_batches):
607 | if verbose:
608 | print('\rPropagating batch %d/%d' % (i + 1, n_batches))
609 | start = i * batch_size
610 | end = start + batch_size
611 |
612 | img_1_batch = torch.from_numpy(imgs_1[start:end]).type(torch.FloatTensor)
613 | img_2_batch = torch.from_numpy(imgs_2[start:end]).type(torch.FloatTensor)
614 |
615 | if self.use_gpu:
616 | img_1_batch = img_1_batch.cuda()
617 | img_2_batch = img_2_batch.cuda()
618 |
619 | a = self.model.forward(img_1_batch, img_2_batch).item()
620 | result.append(a)
621 |
622 |
623 | distance = np.average(result)
624 | print('lpips: ', distance)
625 | return distance
626 |
627 | def calculate_mask_lpips(self, path_1, path_2, batch_size=64, verbose=False, sort=True):
628 | if sort:
629 | files_1 = sorted(get_image_list(path_1))
630 | files_2 = sorted(get_image_list(path_2))
631 | else:
632 | files_1 = get_image_list(path_1)
633 | files_2 = get_image_list(path_2)
634 |
635 | imgs_1=[]
636 | imgs_2=[]
637 | bonesLst = '/media/data1/zhangpz/DataSet/Market/market-annotation-test.csv'
638 | annotation_file = pd.read_csv(bonesLst, sep=':')
639 | annotation_file = annotation_file.set_index('name')
640 |
641 | for i in range(len(files_1)):
642 | string = annotation_file.loc[os.path.basename(files_2[i])]
643 | mask = np.tile(np.expand_dims(create_masked_image(string).astype(np.float32), -1), (1,1,3))#.repeat(1,1,3)
644 | imgs_1.append((imread(str(files_1[i])).astype(np.float32)/127.5-1)*mask)
645 | imgs_2.append((imread(str(files_2[i])).astype(np.float32)/127.5-1)*mask)
646 |
647 | # Bring images to shape (B, 3, H, W)
648 | imgs_1 = np.array(imgs_1)
649 | imgs_2 = np.array(imgs_2)
650 | imgs_1 = imgs_1.transpose((0, 3, 1, 2))
651 | imgs_2 = imgs_2.transpose((0, 3, 1, 2))
652 |
653 | result=[]
654 |
655 |
656 | d0 = imgs_1.shape[0]
657 | if batch_size > d0:
658 | print(('Warning: batch size is bigger than the data size. '
659 | 'Setting batch size to data size'))
660 | batch_size = d0
661 |
662 | n_batches = d0 // batch_size
663 |
664 | for i in range(n_batches):
665 | if verbose:
666 | print('\rPropagating batch %d/%d' % (i + 1, n_batches))
667 | start = i * batch_size
668 | end = start + batch_size
669 |
670 | img_1_batch = torch.from_numpy(imgs_1[start:end]).type(torch.FloatTensor)
671 | img_2_batch = torch.from_numpy(imgs_2[start:end]).type(torch.FloatTensor)
672 |
673 | if self.use_gpu:
674 | img_1_batch = img_1_batch.cuda()
675 | img_2_batch = img_2_batch.cuda()
676 |
677 |
678 | result.append(self.model.forward(img_1_batch, img_2_batch))
679 |
680 |
681 | distance = torch.mean(torch.stack(result))
682 | print('lpips_mask: ', distance)
683 | return distance
684 |
685 |
686 | def produce_ma_mask(kp_array, img_size=(128, 64), point_radius=4):
687 | MISSING_VALUE = -1
688 | from skimage.morphology import dilation, erosion, square
689 | mask = np.zeros(shape=img_size, dtype=bool)
690 | limbs = [[2,3], [2,6], [3,4], [4,5], [6,7], [7,8], [2,9], [9,10],
691 | [10,11], [2,12], [12,13], [13,14], [2,1], [1,15], [15,17],
692 | [1,16], [16,18], [2,17], [2,18], [9,12], [12,6], [9,3], [17,18]]
693 | limbs = np.array(limbs) - 1
694 | for f, t in limbs:
695 | from_missing = kp_array[f][0] == MISSING_VALUE or kp_array[f][1] == MISSING_VALUE
696 | to_missing = kp_array[t][0] == MISSING_VALUE or kp_array[t][1] == MISSING_VALUE
697 | if from_missing or to_missing:
698 | continue
699 |
700 | norm_vec = kp_array[f] - kp_array[t]
701 | norm_vec = np.array([-norm_vec[1], norm_vec[0]])
702 | norm_vec = point_radius * norm_vec / np.linalg.norm(norm_vec)
703 |
704 |
705 | vetexes = np.array([
706 | kp_array[f] + norm_vec,
707 | kp_array[f] - norm_vec,
708 | kp_array[t] - norm_vec,
709 | kp_array[t] + norm_vec
710 | ])
711 | yy, xx = polygon(vetexes[:, 0], vetexes[:, 1], shape=img_size)
712 | mask[yy, xx] = True
713 |
714 | for i, joint in enumerate(kp_array):
715 | if kp_array[i][0] == MISSING_VALUE or kp_array[i][1] == MISSING_VALUE:
716 | continue
717 | yy, xx = circle(joint[0], joint[1], radius=point_radius, shape=img_size)
718 | mask[yy, xx] = True
719 |
720 | mask = dilation(mask, square(5))
721 | mask = erosion(mask, square(5))
722 | return mask
723 |
724 |
725 | def load_pose_cords_from_strings(y_str, x_str):
726 | y_cords = json.loads(y_str)
727 | x_cords = json.loads(x_str)
728 | return np.concatenate([np.expand_dims(y_cords, -1), np.expand_dims(x_cords, -1)], axis=1)
729 |
730 |
731 | def create_masked_image(ano_to):
732 | kp_to = load_pose_cords_from_strings(ano_to['keypoints_y'], ano_to['keypoints_x'])
733 | mask = produce_ma_mask(kp_to)
734 | return mask
735 |
736 |
737 | if __name__ == "__main__":
738 | parser = argparse.ArgumentParser(description='script to compute all statistics')
739 | parser.add_argument('--gt_path', help='Path to ground truth data', type=str)
740 | parser.add_argument('--distorated_path', help='Path to output data', type=str)
741 | parser.add_argument('--fid_real_path', help='Path to real images used when calculating FID', type=str)
742 | parser.add_argument('--name', help='name of the experiment', type=str)
743 | parser.add_argument('--calculate_mask', action='store_true')
744 | parser.add_argument('--market', action='store_true')
745 | args = parser.parse_args()
746 |
747 | print('load start')
748 |
749 | fid = FID()
750 | print('load FID')
751 |
752 | if args.market:
753 | rec = Reconstruction_Market_Metrics()
754 | print('load market rec')
755 | else:
756 | rec = Reconstruction_Metrics()
757 | print('load rec')
758 |
759 | lpips = LPIPS()
760 | print('load LPIPS')
761 |
762 | for arg in vars(args):
763 | print('[%s] =' % arg, getattr(args, arg))
764 |
765 | print('calculate LPIPS...')
766 | gt_list, distorated_list = preprocess_path_for_deform_task(args.gt_path, args.distorated_path)
767 | lpips_score = lpips.calculate_from_disk(distorated_list, gt_list, sort=False)
768 |
769 | print('calculate fid metric...')
770 | fid_score = fid.calculate_from_disk(args.distorated_path, args.fid_real_path)
771 |
772 | print('calculate reconstruction metric...')
773 | rec_dic = rec.calculate_from_disk(distorated_list, gt_list, save_path=args.distorated_path, sort=False, debug=False)
774 |
775 | if args.calculate_mask:
776 | mask_lpips_score = lpips.calculate_mask_lpips(distorated_list, gt_list, sort=False)
777 |
778 | dic = {}
779 | dic['name'] = [args.name]
780 | for key in rec_dic:
781 | dic[key] = rec_dic[key]
782 | dic['fid'] = [fid_score]
783 |
784 | print('fid', fid_score)
785 |
786 | dic['lpips']=[lpips_score]
787 | print('lpips_score', lpips_score)
788 |
789 | if args.calculate_mask:
790 | dic['mask_lpips']=[mask_lpips_score]
791 |
792 |
793 |
794 |
795 |
796 |
797 |
798 |
799 |
--------------------------------------------------------------------------------
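Beyond the command-line entry point above, the metric classes can also be driven programmatically. The snippet below is a minimal sketch under stated assumptions: the three directory paths are placeholders, and LPIPS assumes the PerceptualSimilarity code imported at the top of metrics.py is available.

```python
# Minimal programmatic sketch of the metrics defined in metrics/metrics.py.
# The directory paths are placeholders; LPIPS additionally assumes the
# PerceptualSimilarity dependency imported at the top of that file is in place.
from metrics.metrics import (FID, LPIPS, Reconstruction_Metrics,
                             preprocess_path_for_deform_task)

gt_dir = './path/to/test_ground_truth'        # placeholder
gen_dir = './path/to/generated_results'       # placeholder
fid_real_dir = './path/to/real_train_images'  # placeholder

# Pair each generated image with its ground-truth target using the filename
# convention handled by preprocess_path_for_deform_task.
gt_list, gen_list = preprocess_path_for_deform_task(gt_dir, gen_dir)

fid_score = FID().calculate_from_disk(gen_dir, fid_real_dir)
rec_scores = Reconstruction_Metrics().calculate_from_disk(gen_list, gt_list,
                                                           save_path=gen_dir, sort=False)
lpips_score = LPIPS().calculate_from_disk(gen_list, gt_list, sort=False)

print('FID:', fid_score)
print('PSNR/SSIM/L1/MAE:', rec_scores)
print('LPIPS:', lpips_score)
```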