├── README.md
├── data
│   ├── __init__.py
│   ├── base_data_loader.py
│   ├── base_dataset.py
│   ├── custom_dataset_data_loader.py
│   ├── data_loader.py
│   ├── image_folder.py
│   └── keypoint.py
├── head_img3_00.png
├── losses
│   ├── CX_style_loss.py
│   ├── L1_plus_perceptualLoss.py
│   ├── __init__.py
│   ├── gan.py
│   └── lpips
│       ├── __init__.py
│       ├── __pycache__
│       │   ├── __init__.cpython-36.pyc
│       │   ├── lpips.cpython-36.pyc
│       │   ├── networks.cpython-36.pyc
│       │   └── utils.cpython-36.pyc
│       ├── lpips.py
│       ├── networks.py
│       └── utils.py
├── models
│   ├── CASD.py
│   ├── __init__.py
│   ├── adgan.py
│   ├── base_model.py
│   ├── models.py
│   ├── networks.py
│   ├── test_model.py
│   ├── vgg.py
│   └── vgg_SC.py
├── options
│   ├── __init__.py
│   ├── base_options.py
│   ├── test_options.py
│   └── train_options.py
├── requirements.txt
├── test.py
├── tool
│   ├── generate_fashion_datasets.py
│   ├── generate_pose_map_fashion.py
│   └── resize_fashion.py
├── train.py
└── util
    ├── __init__.py
    ├── get_data.py
    ├── html.py
    ├── image_pool.py
    ├── png.py
    ├── pose_utils.py
    ├── util.py
    └── visualizer.py
/README.md:
--------------------------------------------------------------------------------
1 |
2 | # Cross-Attention-Based-Style-Distribution
3 |
4 | The source code for our paper "[Cross Attention Based Style Distribution for Controllable Person Image Synthesis](https://arxiv.org/abs/2208.00712)" (**ECCV2022**).
5 |
6 |
7 |
8 |
9 |
10 |
11 | ## Installation
12 |
13 | #### Requirements
14 |
15 | - Python 3
16 | - PyTorch 1.7.0
17 | - CUDA 10.2
18 |
19 | #### Conda Installation
20 |
21 | ``` bash
22 | # 1. Create a conda virtual environment.
23 | conda create -n CASD python=3.6
24 | conda activate CASD
25 | conda install -c pytorch pytorch=1.7.0 torchvision=0.8.0 cudatoolkit=10.2
26 |
27 | # 2. Install other dependencies.
28 | pip install -r requirements.txt
29 | ```
30 |
31 |
32 | ### Data Preparation
33 |
34 | The recommended dataset structure is:
35 | ```
36 | +--dataset
37 | |  +--fashion
38 | |     +--train (person images in 'train.lst')
39 | |     +--test (person images in 'test.lst')
40 | |     +--train_resize (resized person images in 'train.lst')
41 | |     +--test_resize (resized person images in 'test.lst')
42 | |     +--trainK (keypoints of person images)
43 | |     +--testK (keypoints of person images)
44 | |     +--semantic_merge3 (semantic masks of person images)
45 | |     +--fashion-resize-pairs-train.csv
46 | |     +--fashion-resize-pairs-test.csv
47 | |     +--fasion-resize-annotation-pairs-train.csv
48 | |     +--fasion-resize-annotation-pairs-test.csv
49 | |     +--train.lst
50 | |     +--test.lst
51 | |     +--vgg19-dcbb9e9d.pth
52 | |     +--vgg_conv.pth
53 | |     +--vgg.pth
54 | ...
55 | ```
56 |
57 |
58 | 1. Person images
59 |
60 | - Download `img_highres.zip` of the DeepFashion Dataset from [In-shop Clothes Retrieval Benchmark](https://drive.google.com/drive/folders/0B7EVK8r0v71pYkd5TzBiclMzR00).
61 |
62 | - Unzip `img_highres.zip`. You will need to ask the [dataset maintainers](http://mmlab.ie.cuhk.edu.hk/projects/DeepFashion/InShopRetrieval.html) for the password. Then put the obtained **img_highres** folder under the `./dataset/fashion` directory.
63 |
64 | - Download the train/test keypoint annotations and the train/test pairs from [Google Drive](https://drive.google.com/drive/folders/1qGRZUJY7QipLRDNQ0lhCubDPsJxmX2jK?usp=sharing), including **fashion-resize-pairs-train.csv**, **fashion-resize-pairs-test.csv**, **fashion-resize-annotation-train.csv**, **fashion-resize-annotation-test.csv**, **train.lst**, **test.lst**. Put these files under the `./dataset/fashion` directory.
65 |
66 | - Run the following code to split the train/test dataset.
67 |
68 | ```bash
69 | python tool/generate_fashion_datasets.py
70 | ```
71 |
72 | - Run the following code to resize the train/test dataset.
73 |
74 | ```bash
75 | python tool/resize_fashion.py
76 | ```
77 |
78 |
79 | 2. Keypoints files
80 |
81 | - Generate the pose heatmaps. Launch
82 | ```bash
83 | python tool/generate_pose_map_fashion.py
84 | ```
85 |
86 | 3. Segmentation files
87 | - Extract human segmentation results with an existing human parser (e.g. LIP_JPPNet). Our segmentation results `semantic_merge3` are provided on [Google Drive](https://drive.google.com/drive/folders/1qGRZUJY7QipLRDNQ0lhCubDPsJxmX2jK?usp=sharing). Put the folder under the `./dataset/fashion` directory. A sketch of how these masks are merged by the dataloader is shown below.
88 |
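The dataloader (`data/keypoint.py`) expects each file in `semantic_merge3` to be a single-channel PNG whose pixel values are the 20 LIP parsing labels, which it then merges into 8 coarse regions. A minimal sketch of that merging (mirroring `data/keypoint.py`; the file path below is illustrative):

```python
import numpy as np
from PIL import Image

# One LIP-style parsing map: pixel values are the 20 label indices (0 = background).
parsing = np.array(Image.open('./dataset/fashion/semantic_merge3/example.png'))  # illustrative path

# One-hot encode the 20 labels, then merge them into the 8 regions used by data/keypoint.py.
onehot = np.stack([(parsing == i).astype('float32') for i in range(20)])
merged = np.zeros((8,) + parsing.shape, dtype='float32')
merged[0] = onehot[0]
merged[1] = onehot[9] + onehot[12]
merged[2] = onehot[2] + onehot[1]
merged[3] = onehot[3]
merged[4] = onehot[13] + onehot[4]
merged[5] = onehot[5] + onehot[6] + onehot[7] + onehot[10] + onehot[11]
merged[6] = onehot[14] + onehot[15]
merged[7] = onehot[8] + onehot[16] + onehot[17] + onehot[18] + onehot[19]
```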
89 |
90 | ### Training
91 |
92 | ```bash
93 | python train.py --dataroot ./dataset/fashion --dirSem ./dataset/fashion --pairLst ./dataset/fashion/fashion-resize-pairs-train.csv --name CASD_test --batchSize 16 --gpu_ids 0,1 --which_model_netG CASD --checkpoints_dir ./checkpoints
94 | ```
95 | The trained models are saved in `./checkpoints`.
96 |
97 | ### Testing
98 | Download our pretrained model from [Google Drive](https://drive.google.com/drive/folders/1qGRZUJY7QipLRDNQ0lhCubDPsJxmX2jK?usp=sharing). Put the obtained checkpoints under `./checkpoints/CASD_test`. Modify your data path and launch
99 | ```bash
100 | python test.py --dataroot ./dataset/fashion --dirSem ./dataset/fashion --pairLst ./dataset/fashion/fashion-resize-pairs-test.csv --checkpoints_dir ./checkpoints --results_dir ./results --name CASD_test --phase test --batchSize 1 --gpu_ids 0,0 --which_model_netG CASD --which_epoch 1000
101 | ```
102 | The generated images are saved in `./results`.
103 |
104 | ## Citation
105 | If you use this code for your research, please cite
106 | ```
107 | @article{zhou2022casd,
108 | title={Cross Attention Based Style Distribution for Controllable Person Image Synthesis},
109 | author={Zhou, Xinyue and Yin, Mingyu and Chen, Xinyuan and Sun, Li and Gao, Changxin and Li, Qingli},
110 | journal={arXiv preprint arXiv:2208.00712},
111 | year={2022}
112 | }
113 | ```
114 |
115 |
116 |
--------------------------------------------------------------------------------
/data/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xyzhouo/CASD/cb1aabb64b5d8ae712b626a1eec045d08f90933e/data/__init__.py
--------------------------------------------------------------------------------
/data/base_data_loader.py:
--------------------------------------------------------------------------------
1 |
2 | class BaseDataLoader():
3 | def __init__(self):
4 | pass
5 |
6 | def initialize(self, opt):
7 | self.opt = opt
8 | pass
9 |
10 |     def load_data(self):
11 | return None
--------------------------------------------------------------------------------
/data/base_dataset.py:
--------------------------------------------------------------------------------
1 | import torch.utils.data as data
2 | from PIL import Image
3 | import torchvision.transforms as transforms
4 |
5 | class BaseDataset(data.Dataset):
6 | def __init__(self):
7 | super(BaseDataset, self).__init__()
8 |
9 | def name(self):
10 | return 'BaseDataset'
11 |
12 | def initialize(self, opt):
13 | pass
14 |
15 |
16 | def get_transform(opt):
17 | transform_list = []
18 | if opt.resize_or_crop == 'resize_and_crop':
19 | osize = [opt.loadSize, opt.loadSize]
20 | transform_list.append(transforms.Scale(osize, Image.BICUBIC))
21 |         transform_list.append(transforms.RandomCrop(opt.fineSize))
22 | elif opt.resize_or_crop == 'crop':
23 | transform_list.append(transforms.RandomCrop(opt.fineSize))
24 | elif opt.resize_or_crop == 'scale_width':
25 | transform_list.append(transforms.Lambda(
26 | lambda img: __scale_width(img, opt.fineSize)))
27 | elif opt.resize_or_crop == 'scale_width_and_crop':
28 | transform_list.append(transforms.Lambda(
29 | lambda img: __scale_width(img, opt.loadSize)))
30 | transform_list.append(transforms.RandomCrop(opt.fineSize))
31 |
32 | transform_list += [transforms.ToTensor(),
33 | transforms.Normalize((0.5, 0.5, 0.5),
34 | (0.5, 0.5, 0.5))]
35 | return transforms.Compose(transform_list)
36 |
37 | def __scale_width(img, target_width):
38 | ow, oh = img.size
39 | if (ow == target_width):
40 | return img
41 | w = target_width
42 | h = int(target_width * oh / ow)
43 | return img.resize((w, h), Image.BICUBIC)
44 |
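# --- Usage sketch (illustrative addition, not part of the upstream file) ---
# get_transform() only reads opt.resize_or_crop, opt.loadSize and opt.fineSize, so a
# bare Namespace is enough to try it on a dummy image; the values below are examples.
if __name__ == '__main__':
    from argparse import Namespace

    example_opt = Namespace(resize_or_crop='crop', loadSize=286, fineSize=128)
    transform = get_transform(example_opt)
    out = transform(Image.new('RGB', (176, 256)))  # PIL image -> 3 x 128 x 128 tensor in [-1, 1]
    print(out.shape, out.min().item(), out.max().item())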
--------------------------------------------------------------------------------
/data/custom_dataset_data_loader.py:
--------------------------------------------------------------------------------
1 | import torch.utils.data
2 | from data.base_data_loader import BaseDataLoader
3 |
4 |
5 | def CreateDataset(opt):
6 | dataset = None
7 |
8 | if opt.dataset_mode == 'keypoint':
9 | from data.keypoint import KeyDataset
10 | dataset = KeyDataset()
11 | elif opt.dataset_mode == 'keypoint_mix':
12 | from data.keypoint_mix import KeyDataset
13 | dataset = KeyDataset()
14 | else:
15 | raise ValueError("Dataset [%s] not recognized." % opt.dataset_mode)
16 |
17 | print("dataset [%s] was created" % (dataset.name()))
18 | dataset.initialize(opt)
19 | return dataset
20 |
21 |
22 | class CustomDatasetDataLoader(BaseDataLoader):
23 | def name(self):
24 | return 'CustomDatasetDataLoader'
25 |
26 | def initialize(self, opt):
27 | BaseDataLoader.initialize(self, opt)
28 | self.dataset = CreateDataset(opt)
29 | self.dataloader = torch.utils.data.DataLoader(
30 | self.dataset,
31 | batch_size=opt.batchSize,
32 | shuffle=not opt.serial_batches,
33 | num_workers=int(opt.nThreads))
34 |
35 | def load_data(self):
36 | return self
37 |
38 | def __len__(self):
39 | return min(len(self.dataset), self.opt.max_dataset_size)
40 |
41 | def __iter__(self):
42 | for i, data in enumerate(self.dataloader):
43 | if i >= self.opt.max_dataset_size:
44 | break
45 | yield data
46 |
--------------------------------------------------------------------------------
/data/data_loader.py:
--------------------------------------------------------------------------------
1 |
2 | def CreateDataLoader(opt):
3 | from data.custom_dataset_data_loader import CustomDatasetDataLoader
4 | data_loader = CustomDatasetDataLoader()
5 | print(data_loader.name())
6 | data_loader.initialize(opt)
7 | return data_loader
8 |
--------------------------------------------------------------------------------
/data/image_folder.py:
--------------------------------------------------------------------------------
1 | ###############################################################################
2 | # Code from
3 | # https://github.com/pytorch/vision/blob/master/torchvision/datasets/folder.py
4 | # Modified the original code so that it also loads images from the current
5 | # directory as well as the subdirectories
6 | ###############################################################################
7 |
8 | import torch.utils.data as data
9 |
10 | from PIL import Image
11 | import os
12 | import os.path
13 |
14 | IMG_EXTENSIONS = [
15 | '.jpg', '.JPG', '.jpeg', '.JPEG',
16 | '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
17 | ]
18 |
19 |
20 | def is_image_file(filename):
21 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
22 |
23 |
24 | def make_dataset(dir):
25 | images = []
26 | assert os.path.isdir(dir), '%s is not a valid directory' % dir
27 |
28 | for root, _, fnames in sorted(os.walk(dir)):
29 | for fname in fnames:
30 | if is_image_file(fname):
31 | path = os.path.join(root, fname)
32 | images.append(path)
33 |
34 | return images
35 |
36 |
37 | def default_loader(path):
38 | return Image.open(path).convert('RGB')
39 |
40 |
41 | class ImageFolder(data.Dataset):
42 |
43 | def __init__(self, root, transform=None, return_paths=False,
44 | loader=default_loader):
45 | imgs = make_dataset(root)
46 | if len(imgs) == 0:
47 | raise(RuntimeError("Found 0 images in: " + root + "\n"
48 | "Supported image extensions are: " +
49 | ",".join(IMG_EXTENSIONS)))
50 |
51 | self.root = root
52 | self.imgs = imgs
53 | self.transform = transform
54 | self.return_paths = return_paths
55 | self.loader = loader
56 |
57 | def __getitem__(self, index):
58 | path = self.imgs[index]
59 | img = self.loader(path)
60 | if self.transform is not None:
61 | img = self.transform(img)
62 | if self.return_paths:
63 | return img, path
64 | else:
65 | return img
66 |
67 | def __len__(self):
68 | return len(self.imgs)
69 |
--------------------------------------------------------------------------------
/data/keypoint.py:
--------------------------------------------------------------------------------
1 | import os.path
2 | from data.base_dataset import BaseDataset, get_transform
3 | from PIL import Image
4 | import random
5 | import pandas as pd
6 | import torch
7 | import util.util as util
8 | import numpy as np
9 | import torchvision.transforms.functional as F
10 |
11 | class KeyDataset(BaseDataset):
12 | def initialize(self, opt):
13 | self.opt = opt
14 | self.root = opt.dataroot
15 | self.dir_P = os.path.join(opt.dataroot, opt.phase + '_resize')
16 | self.dir_K = os.path.join(opt.dataroot, opt.phase + 'K')
17 | self.dir_SP = opt.dirSem
18 | self.SP_input_nc = opt.SP_input_nc
19 |
20 | self.init_categories(opt.pairLst)
21 | self.transform = get_transform(opt)
22 | self.use_BPD = self.opt.use_BPD
23 |
24 | self.finesize = opt.fineSize
25 |
26 | def init_categories(self, pairLst):
27 | pairs_file_train = pd.read_csv(pairLst)
28 | self.size = len(pairs_file_train)
29 | self.pairs = []
30 | print('Loading data pairs ...')
31 | for i in range(self.size):
32 | pair = [pairs_file_train.iloc[i]['from'], pairs_file_train.iloc[i]['to']]
33 | self.pairs.append(pair)
34 |
35 | print('Loading data pairs finished ...')
36 |
37 | def __getitem__(self, index):
38 | if self.opt.phase == 'train':
39 | index = random.randint(0, self.size - 1)
40 |
41 | P1_name, P2_name = self.pairs[index]
42 | P1_path = os.path.join(self.dir_P, P1_name)
43 | BP1_path = os.path.join(self.dir_K, P1_name + '.npy')
44 |
45 | P2_path = os.path.join(self.dir_P, P2_name)
46 | BP2_path = os.path.join(self.dir_K, P2_name + '.npy')
47 |
48 | P1_img = Image.open(P1_path).convert('RGB')
49 | P2_img = Image.open(P2_path).convert('RGB')
50 |
51 | BP1_img = np.load(BP1_path)
52 | BP2_img = np.load(BP2_path)
53 |
54 | if self.use_BPD:
55 | BPD1_img = util.draw_dis_from_map(BP1_img)[0]
56 | BPD2_img = util.draw_dis_from_map(BP2_img)[0]
57 |
58 | # use flip
59 | if self.opt.phase == 'train' and self.opt.use_flip:
60 | # print ('use_flip ...')
61 | flip_random = random.uniform(0, 1)
62 |
63 | if flip_random > 0.5:
64 | # print('fliped ...')
65 | P1_img = P1_img.transpose(Image.FLIP_LEFT_RIGHT)
66 | P2_img = P2_img.transpose(Image.FLIP_LEFT_RIGHT)
67 |
68 | BP1_img = np.array(BP1_img[:, ::-1, :])
69 | BP2_img = np.array(BP2_img[:, ::-1, :])
70 |
71 | BP1 = torch.from_numpy(BP1_img).float()
72 | BP1 = BP1.transpose(2, 0)
73 | BP1 = BP1.transpose(2, 1)
74 |
75 | BP2 = torch.from_numpy(BP2_img).float()
76 | BP2 = BP2.transpose(2, 0)
77 | BP2 = BP2.transpose(2, 1)
78 |
79 | P1 = self.transform(P1_img)
80 | P2 = self.transform(P2_img)
81 | else:
82 | BP1 = torch.from_numpy(BP1_img).float()
83 | BP1 = BP1.transpose(2, 0)
84 | BP1 = BP1.transpose(2, 1)
85 |
86 | BP2 = torch.from_numpy(BP2_img).float()
87 | BP2 = BP2.transpose(2, 0)
88 | BP2 = BP2.transpose(2, 1)
89 |
90 | P1 = self.transform(P1_img)
91 | P2 = self.transform(P2_img)
92 | if self.use_BPD:
93 | BPD1 = torch.from_numpy(BPD1_img).float()
94 | BPD1 = BPD1.transpose(2, 0)
95 | BPD1 = BPD1.transpose(2, 1)
96 |
97 | BPD2 = torch.from_numpy(BPD2_img).float()
98 | BPD2 = BPD2.transpose(2, 0)
99 | BPD2 = BPD2.transpose(2, 1)
100 |
101 |
102 | SP1_name = self.split_name_sementic3(P1_name, 'semantic_merge3')
103 | SP2_name = self.split_name_sementic3(P2_name, 'semantic_merge3')
104 | SP1_path = os.path.join(self.dir_SP, SP1_name)
105 | SP1_path = SP1_path[:-4] + '.png'
106 | SP1_data = Image.open(SP1_path)
107 | SP1_data = np.array(SP1_data)
108 | SP2_path = os.path.join(self.dir_SP, SP2_name)
109 | SP2_path = SP2_path[:-4] + '.png'
110 | SP2_data = Image.open(SP2_path)
111 | SP2_data = np.array(SP2_data)
112 | SP1 = np.zeros((self.SP_input_nc, self.finesize[0], self.finesize[1]), dtype='float32')
113 | SP2 = np.zeros((self.SP_input_nc, self.finesize[0], self.finesize[1]), dtype='float32')
114 | SP1_20 = np.zeros((20, self.finesize[0], self.finesize[1]), dtype='float32')
115 | SP2_20 = np.zeros((20, self.finesize[0], self.finesize[1]), dtype='float32')
116 | nc = 20
117 | for id in range(nc):
118 | SP1_20[id] = (SP1_data == id).astype('float32')
119 | SP2_20[id] = (SP2_data == id).astype('float32')
120 | SP1[0] = SP1_20[0]
121 | SP1[1] = SP1_20[9] + SP1_20[12]
122 | SP1[2] = SP1_20[2] + SP1_20[1]
123 | SP1[3] = SP1_20[3]
124 | SP1[4] = SP1_20[13] + SP1_20[4]
125 | SP1[5] = SP1_20[5] + SP1_20[6] + SP1_20[7] + SP1_20[10] + SP1_20[11]
126 | SP1[6] = SP1_20[14] + SP1_20[15]
127 | SP1[7] = SP1_20[8] + SP1_20[16] + SP1_20[17] + SP1_20[18] + SP1_20[19]
128 |
129 | SP2[0] = SP2_20[0]
130 | SP2[1] = SP2_20[9] + SP2_20[12]
131 | SP2[2] = SP2_20[2] + SP2_20[1]
132 | SP2[3] = SP2_20[3]
133 | SP2[4] = SP2_20[13] + SP2_20[4]
134 | SP2[5] = SP2_20[5] + SP2_20[6] + SP2_20[7] + SP2_20[10] + SP2_20[11]
135 | SP2[6] = SP2_20[14] + SP2_20[15]
136 | SP2[7] = SP2_20[8] + SP2_20[16] + SP2_20[17] + SP2_20[18] + SP2_20[19]
137 |
138 |
139 | if self.use_BPD:
140 | return {'P1': P1, 'BP1': BP1, 'SP1': SP1, 'BPD1': BPD1,
141 | 'P2': P2, 'BP2': BP2, 'SP2': SP2, 'BPD2': BPD2,
142 | 'P1_path': P1_name, 'P2_path': P2_name}
143 | else:
144 | return {'P1': P1, 'BP1': BP1, 'SP1': SP1,
145 | 'P2': P2, 'BP2': BP2, 'SP2': SP2,
146 | 'P1_path': P1_name, 'P2_path': P2_name}
147 |
148 | def __len__(self):
149 | if self.opt.phase == 'train':
150 | return 4000
151 | elif self.opt.phase == 'test':
152 | return self.size
153 |
154 | def name(self):
155 | return 'KeyDataset'
156 |
157 |
158 | def split_name_sementic3(self, str, type):
159 | list = []
160 | list.append(type)
161 | list.append(str)
162 |
163 | head = ''
164 | for path in list:
165 | head = os.path.join(head, path)
166 | return head
167 |
168 |
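# --- Usage sketch (illustrative addition, not part of the upstream file) ---
# KeyDataset is normally constructed through data.data_loader.CreateDataLoader, but it can
# also be driven directly. The Namespace below lists the option fields this class reads;
# the values are examples and assume the DeepFashion data was prepared as in the README.
if __name__ == '__main__':
    from argparse import Namespace

    example_opt = Namespace(
        dataroot='./dataset/fashion', dirSem='./dataset/fashion', phase='test',
        pairLst='./dataset/fashion/fashion-resize-pairs-test.csv',
        SP_input_nc=8, use_BPD=False, use_flip=False,
        resize_or_crop='no_resize', fineSize=(256, 176))
    dataset = KeyDataset()
    dataset.initialize(example_opt)
    sample = dataset[0]
    print(sample['P1'].shape, sample['BP1'].shape, sample['SP1'].shape)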
--------------------------------------------------------------------------------
/head_img3_00.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xyzhouo/CASD/cb1aabb64b5d8ae712b626a1eec045d08f90933e/head_img3_00.png
--------------------------------------------------------------------------------
/losses/CX_style_loss.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 |
3 | import torch
4 | from torch import nn
5 | import torch.nn.functional as F
6 |
7 | class CXLoss(nn.Module):
8 |
9 |     def __init__(self, sigma=0.1, b=1.0, similarity="cosine"):
10 | super(CXLoss, self).__init__()
11 | self.similarity = similarity
12 | self.sigma = sigma
13 | self.b = b
14 |
15 | def center_by_T(self, featureI, featureT):
16 | # Calculate mean channel vector for feature map.
17 | meanT = featureT.mean(0, keepdim=True).mean(2, keepdim=True).mean(3, keepdim=True)
18 | return featureI - meanT, featureT - meanT
19 |
20 | def l2_normalize_channelwise(self, features):
21 | # Normalize on channel dimension (axis=1)
22 | norms = features.norm(p=2, dim=1, keepdim=True)
23 | features = features.div(norms)
24 | return features
25 |
26 | def patch_decomposition(self, features):
27 | N, C, H, W = features.shape
28 | assert N == 1
29 | P = H * W
30 | # NCHW --> 1x1xCxHW --> HWxCx1x1
31 | patches = features.view(1, 1, C, P).permute((3, 2, 0, 1))
32 | return patches
33 |
34 | def calc_relative_distances(self, raw_dist, axis=1):
35 | epsilon = 1e-5
36 | div = torch.min(raw_dist, dim=axis, keepdim=True)[0]
37 | relative_dist = raw_dist / (div + epsilon)
38 | return relative_dist
39 |
40 | def calc_CX(self, dist, axis=1):
41 | W = torch.exp((self.b - dist) / self.sigma)
42 | W_sum = W.sum(dim=axis, keepdim=True)
43 | return W.div(W_sum)
44 |
45 | def forward(self, featureT, featureI):
46 | '''
47 | :param featureT: target
48 | :param featureI: inference
49 | :return:
50 | '''
51 | # NCHW
52 | # print(featureI.shape)
53 |
54 | featureI, featureT = self.center_by_T(featureI, featureT)
55 |
56 | featureI = self.l2_normalize_channelwise(featureI)
57 | featureT = self.l2_normalize_channelwise(featureT)
58 |
59 | dist = []
60 | N = featureT.size()[0]
61 | for i in range(N):
62 | # NCHW
63 | featureT_i = featureT[i, :, :, :].unsqueeze(0)
64 | # NCHW
65 | featureI_i = featureI[i, :, :, :].unsqueeze(0)
66 | featureT_patch = self.patch_decomposition(featureT_i)
67 | # Calculate cosine similarity
68 | # See the torch document for functional.conv2d
69 | dist_i = F.conv2d(featureI_i, featureT_patch)
70 | dist.append(dist_i)
71 |
72 | # NCHW
73 | dist = torch.cat(dist, dim=0)
74 |
75 | raw_dist = (1. - dist) / 2.
76 |
77 | relative_dist = self.calc_relative_distances(raw_dist)
78 |
79 | CX = self.calc_CX(relative_dist)
80 |
81 | CX = CX.max(dim=3)[0].max(dim=2)[0]
82 | CX = CX.mean(1)
83 | CX = -torch.log(CX)
84 | CX = torch.mean(CX)
85 | return CX
86 |
87 |
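# --- Usage sketch (illustrative addition, not part of the upstream file) ---
# CXLoss compares two feature maps of shape [N, C, H, W] (e.g. VGG activations of the
# target and the generated image); random tensors are used here only to show the shapes.
if __name__ == '__main__':
    cx_loss = CXLoss()
    feat_target = torch.rand(2, 256, 32, 32)
    feat_generated = torch.rand(2, 256, 32, 32)
    print(cx_loss(feat_target, feat_generated).item())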
--------------------------------------------------------------------------------
/losses/L1_plus_perceptualLoss.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 |
3 | import torch
4 | from torch import nn
5 | from torch.autograd import Variable
6 | import numpy as np
7 | import torch.nn.functional as F
8 | import torchvision.models as models
9 |
10 | class L1_plus_perceptualLoss(nn.Module):
11 | def __init__(self, lambda_L1, lambda_perceptual, perceptual_layers, gpu_ids, percep_is_l1):
12 | super(L1_plus_perceptualLoss, self).__init__()
13 |
14 | self.lambda_L1 = lambda_L1
15 | self.lambda_perceptual = lambda_perceptual
16 | self.gpu_ids = gpu_ids
17 |
18 | self.percep_is_l1 = percep_is_l1
19 |
20 | # vgg = models.vgg19(pretrained=True).features
21 | vgg19 = models.vgg19(pretrained=False)
22 |         vgg19.load_state_dict(torch.load('./dataset/fashion/vgg19-dcbb9e9d.pth'))  # VGG19 weights live under ./dataset/fashion (see README)
23 | vgg = vgg19.features
24 |
25 |
26 | self.vgg_submodel = nn.Sequential()
27 | for i,layer in enumerate(list(vgg)):
28 | self.vgg_submodel.add_module(str(i),layer)
29 | if i == perceptual_layers:
30 | break
31 | self.vgg_submodel = self.vgg_submodel.cuda()
32 | #self.vgg_submodel = torch.nn.DataParallel(self.vgg_submodel, device_ids=gpu_ids).cuda()
33 |
34 | print(self.vgg_submodel)
35 |
36 | def forward(self, inputs, targets):
37 | if self.lambda_L1 == 0 and self.lambda_perceptual == 0:
38 | return Variable(torch.zeros(1)).cuda(), Variable(torch.zeros(1)), Variable(torch.zeros(1))
39 | # normal L1
40 | loss_l1 = F.l1_loss(inputs, targets) * self.lambda_L1
41 |
42 | # perceptual L1
43 | mean = torch.FloatTensor(3)
44 | mean[0] = 0.485
45 | mean[1] = 0.456
46 | mean[2] = 0.406
47 | mean = Variable(mean)
48 | mean = mean.resize(1, 3, 1, 1).cuda()
49 |
50 | std = torch.FloatTensor(3)
51 | std[0] = 0.229
52 | std[1] = 0.224
53 | std[2] = 0.225
54 | std = Variable(std)
55 | std = std.resize(1, 3, 1, 1).cuda()
56 |
57 | fake_p2_norm = (inputs + 1)/2 # [-1, 1] => [0, 1]
58 | fake_p2_norm = (fake_p2_norm - mean)/std
59 |
60 | input_p2_norm = (targets + 1)/2 # [-1, 1] => [0, 1]
61 | input_p2_norm = (input_p2_norm - mean)/std
62 |
63 |
64 | fake_p2_norm = self.vgg_submodel(fake_p2_norm)
65 | input_p2_norm = self.vgg_submodel(input_p2_norm)
66 | input_p2_norm_no_grad = input_p2_norm.detach()
67 |
68 | if self.percep_is_l1 == 1:
69 | # use l1 for perceptual loss
70 | loss_perceptual = F.l1_loss(fake_p2_norm, input_p2_norm_no_grad) * self.lambda_perceptual
71 | else:
72 | # use l2 for perceptual loss
73 | loss_perceptual = F.mse_loss(fake_p2_norm, input_p2_norm_no_grad) * self.lambda_perceptual
74 |
75 | loss = loss_l1 + loss_perceptual
76 |
77 | return loss, loss_l1, loss_perceptual
78 |
79 |
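# --- Usage sketch (illustrative addition, not part of the upstream file) ---
# The loss takes generated and target images in [-1, 1]. It needs a CUDA device and the
# VGG19 weights referenced above; the hyper-parameter values below are only examples.
if __name__ == '__main__':
    criterion = L1_plus_perceptualLoss(lambda_L1=10.0, lambda_perceptual=10.0,
                                       perceptual_layers=3, gpu_ids=[0], percep_is_l1=1)
    fake = torch.rand(1, 3, 256, 176).cuda() * 2 - 1
    real = torch.rand(1, 3, 256, 176).cuda() * 2 - 1
    loss, loss_l1, loss_perceptual = criterion(fake, real)
    print(loss.item(), loss_l1.item(), loss_perceptual.item())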
--------------------------------------------------------------------------------
/losses/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xyzhouo/CASD/cb1aabb64b5d8ae712b626a1eec045d08f90933e/losses/__init__.py
--------------------------------------------------------------------------------
/losses/gan.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | try: from util.distributed import master_only_print as print
6 | except ImportError: pass  # util/distributed.py is not shipped with this repository; fall back to the builtin print
7 |
8 | @torch.jit.script
9 | def fuse_math_min_mean_pos(x):
10 | r"""Fuse operation min mean for hinge loss computation of positive
11 | samples"""
12 | minval = torch.min(x - 1, x * 0)
13 | loss = -torch.mean(minval)
14 | return loss
15 |
16 |
17 | @torch.jit.script
18 | def fuse_math_min_mean_neg(x):
19 | r"""Fuse operation min mean for hinge loss computation of negative
20 | samples"""
21 | minval = torch.min(-x - 1, x * 0)
22 | loss = -torch.mean(minval)
23 | return loss
24 |
25 |
26 | class GANLoss(nn.Module):
27 | r"""GAN loss constructor.
28 |
29 | Args:
30 | gan_mode (str): Type of GAN loss. ``'hinge'``, ``'least_square'``,
31 | ``'non_saturated'``, ``'wasserstein'``.
32 | target_real_label (float): The desired output label for real images.
33 | target_fake_label (float): The desired output label for fake images.
34 | """
35 |
36 | def __init__(self, gan_mode, target_real_label=1.0, target_fake_label=0.0):
37 | super(GANLoss, self).__init__()
38 | self.real_label = target_real_label
39 | self.fake_label = target_fake_label
40 | self.real_label_tensor = None
41 | self.fake_label_tensor = None
42 | self.gan_mode = gan_mode
43 | print('GAN mode: %s' % gan_mode)
44 |
45 | def forward(self, dis_output, t_real, dis_update=True):
46 | r"""GAN loss computation.
47 |
48 | Args:
49 | dis_output (tensor or list of tensors): Discriminator outputs.
50 | t_real (bool): If ``True``, uses the real label as target, otherwise
51 | uses the fake label as target.
52 | dis_update (bool): If ``True``, the loss will be used to update the
53 | discriminator, otherwise the generator.
54 | Returns:
55 | loss (tensor): Loss value.
56 | """
57 | if isinstance(dis_output, list):
58 | # For multi-scale discriminators.
59 | # In this implementation, the loss is first averaged for each scale
60 | # (batch size and number of locations) then averaged across scales,
61 | # so that the gradient is not dominated by the discriminator that
62 | # has the most output values (highest resolution).
63 | loss = 0
64 | for dis_output_i in dis_output:
65 | assert isinstance(dis_output_i, torch.Tensor)
66 | loss += self.loss(dis_output_i, t_real, dis_update)
67 | return loss / len(dis_output)
68 | else:
69 | return self.loss(dis_output, t_real, dis_update)
70 |
71 | def loss(self, dis_output, t_real, dis_update=True):
72 | r"""GAN loss computation.
73 |
74 | Args:
75 | dis_output (tensor): Discriminator outputs.
76 | t_real (bool): If ``True``, uses the real label as target, otherwise
77 | uses the fake label as target.
78 | dis_update (bool): Updating the discriminator or the generator.
79 | Returns:
80 | loss (tensor): Loss value.
81 | """
82 | if not dis_update:
83 | assert t_real, \
84 | "The target should be real when updating the generator."
85 |
86 | if self.gan_mode == 'non_saturated':
87 | target_tensor = self.get_target_tensor(dis_output, t_real)
88 | loss = F.binary_cross_entropy_with_logits(dis_output,
89 | target_tensor)
90 | elif self.gan_mode == 'least_square':
91 | target_tensor = self.get_target_tensor(dis_output, t_real)
92 | loss = 0.5 * F.mse_loss(dis_output, target_tensor)
93 | elif self.gan_mode == 'hinge':
94 | if dis_update:
95 | if t_real:
96 | loss = fuse_math_min_mean_pos(dis_output)
97 | else:
98 | loss = fuse_math_min_mean_neg(dis_output)
99 | else:
100 | loss = -torch.mean(dis_output)
101 | elif self.gan_mode == 'wasserstein':
102 | if t_real:
103 | loss = -torch.mean(dis_output)
104 | else:
105 | loss = torch.mean(dis_output)
106 | elif self.gan_mode == 'style_gan2':
107 | if t_real:
108 | loss = F.softplus(-dis_output).mean()
109 | else:
110 | loss = F.softplus(dis_output).mean()
111 | else:
112 | raise ValueError('Unexpected gan_mode {}'.format(self.gan_mode))
113 | return loss
114 |
115 |
116 | def get_target_tensor(self, dis_output, t_real):
117 | r"""Return the target vector for the binary cross entropy loss
118 | computation.
119 |
120 | Args:
121 | dis_output (tensor): Discriminator outputs.
122 | t_real (bool): If ``True``, uses the real label as target, otherwise
123 | uses the fake label as target.
124 | Returns:
125 | target (tensor): Target tensor vector.
126 | """
127 | if t_real:
128 | if self.real_label_tensor is None:
129 | self.real_label_tensor = dis_output.new_tensor(self.real_label)
130 | return self.real_label_tensor.expand_as(dis_output)
131 | else:
132 | if self.fake_label_tensor is None:
133 | self.fake_label_tensor = dis_output.new_tensor(self.fake_label)
134 | return self.fake_label_tensor.expand_as(dis_output)
135 |
--------------------------------------------------------------------------------
/losses/lpips/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xyzhouo/CASD/cb1aabb64b5d8ae712b626a1eec045d08f90933e/losses/lpips/__init__.py
--------------------------------------------------------------------------------
/losses/lpips/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xyzhouo/CASD/cb1aabb64b5d8ae712b626a1eec045d08f90933e/losses/lpips/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/losses/lpips/__pycache__/lpips.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xyzhouo/CASD/cb1aabb64b5d8ae712b626a1eec045d08f90933e/losses/lpips/__pycache__/lpips.cpython-36.pyc
--------------------------------------------------------------------------------
/losses/lpips/__pycache__/networks.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xyzhouo/CASD/cb1aabb64b5d8ae712b626a1eec045d08f90933e/losses/lpips/__pycache__/networks.cpython-36.pyc
--------------------------------------------------------------------------------
/losses/lpips/__pycache__/utils.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xyzhouo/CASD/cb1aabb64b5d8ae712b626a1eec045d08f90933e/losses/lpips/__pycache__/utils.cpython-36.pyc
--------------------------------------------------------------------------------
/losses/lpips/lpips.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | from losses.lpips.networks import get_network, LinLayers
5 | from losses.lpips.utils import get_state_dict
6 |
7 |
8 | class LPIPS(nn.Module):
9 | r"""Creates a criterion that measures
10 | Learned Perceptual Image Patch Similarity (LPIPS).
11 | Arguments:
12 | net_type (str): the network type to compare the features:
13 | 'alex' | 'squeeze' | 'vgg'. Default: 'alex'.
14 | version (str): the version of LPIPS. Default: 0.1.
15 | """
16 | def __init__(self, net_type: str = 'alex', version: str = '0.1'):
17 |
18 | assert version in ['0.1'], 'v0.1 is only supported now'
19 |
20 | super(LPIPS, self).__init__()
21 |
22 | # pretrained network
23 | self.net = get_network(net_type).to("cuda")
24 |
25 | # linear layers
26 | self.lin = LinLayers(self.net.n_channels_list).to("cuda")
27 | self.lin.load_state_dict(get_state_dict(net_type, version))
28 |
29 | def forward(self, x: torch.Tensor, y: torch.Tensor):
30 | feat_x, feat_y = self.net(x), self.net(y)
31 |
32 | diff = [(fx - fy) ** 2 for fx, fy in zip(feat_x, feat_y)]
33 | res = [l(d).mean((2, 3), True) for d, l in zip(diff, self.lin)]
34 |
35 | return torch.sum(torch.cat(res, 0)) / x.shape[0]
36 |
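# --- Usage sketch (illustrative addition, not part of the upstream file) ---
# LPIPS moves its sub-networks to CUDA in __init__ and downloads the linear-layer weights
# on first use; it expects image batches of shape [N, 3, H, W], roughly in [-1, 1].
if __name__ == '__main__':
    lpips_metric = LPIPS(net_type='vgg')
    x = torch.rand(1, 3, 256, 176).cuda() * 2 - 1
    y = torch.rand(1, 3, 256, 176).cuda() * 2 - 1
    print(lpips_metric(x, y).item())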
--------------------------------------------------------------------------------
/losses/lpips/networks.py:
--------------------------------------------------------------------------------
1 | from typing import Sequence
2 |
3 | from itertools import chain
4 |
5 | import torch
6 | import torch.nn as nn
7 | from torchvision import models
8 |
9 | from losses.lpips.utils import normalize_activation
10 |
11 |
12 | def get_network(net_type: str):
13 | if net_type == 'alex':
14 | return AlexNet()
15 | elif net_type == 'squeeze':
16 | return SqueezeNet()
17 | elif net_type == 'vgg':
18 | return VGG16()
19 | else:
20 | raise NotImplementedError('choose net_type from [alex, squeeze, vgg].')
21 |
22 |
23 | class LinLayers(nn.ModuleList):
24 | def __init__(self, n_channels_list: Sequence[int]):
25 | super(LinLayers, self).__init__([
26 | nn.Sequential(
27 | nn.Identity(),
28 | nn.Conv2d(nc, 1, 1, 1, 0, bias=False)
29 | ) for nc in n_channels_list
30 | ])
31 |
32 | for param in self.parameters():
33 | param.requires_grad = False
34 |
35 |
36 | class BaseNet(nn.Module):
37 | def __init__(self):
38 | super(BaseNet, self).__init__()
39 |
40 | # register buffer
41 | self.register_buffer(
42 | 'mean', torch.Tensor([-.030, -.088, -.188])[None, :, None, None])
43 | self.register_buffer(
44 | 'std', torch.Tensor([.458, .448, .450])[None, :, None, None])
45 |
46 | def set_requires_grad(self, state: bool):
47 | for param in chain(self.parameters(), self.buffers()):
48 | param.requires_grad = state
49 |
50 | def z_score(self, x: torch.Tensor):
51 | return (x - self.mean) / self.std
52 |
53 | def forward(self, x: torch.Tensor):
54 | x = self.z_score(x)
55 |
56 | output = []
57 | for i, (_, layer) in enumerate(self.layers._modules.items(), 1):
58 | x = layer(x)
59 | if i in self.target_layers:
60 | output.append(normalize_activation(x))
61 | if len(output) == len(self.target_layers):
62 | break
63 | return output
64 |
65 |
66 | class SqueezeNet(BaseNet):
67 | def __init__(self):
68 | super(SqueezeNet, self).__init__()
69 |
70 | self.layers = models.squeezenet1_1(True).features
71 | self.target_layers = [2, 5, 8, 10, 11, 12, 13]
72 | self.n_channels_list = [64, 128, 256, 384, 384, 512, 512]
73 |
74 | self.set_requires_grad(False)
75 |
76 |
77 | class AlexNet(BaseNet):
78 | def __init__(self):
79 | super(AlexNet, self).__init__()
80 |
81 | self.layers = models.alexnet(True).features
82 | self.target_layers = [2, 5, 8, 10, 12]
83 | self.n_channels_list = [64, 192, 384, 256, 256]
84 |
85 | self.set_requires_grad(False)
86 |
87 |
88 | class VGG16(BaseNet):
89 | def __init__(self):
90 | super(VGG16, self).__init__()
91 |
92 | self.layers = models.vgg16(True).features
93 | self.target_layers = [4, 9, 16, 23, 30]
94 | self.n_channels_list = [64, 128, 256, 512, 512]
95 |
96 | self.set_requires_grad(False)
--------------------------------------------------------------------------------
/losses/lpips/utils.py:
--------------------------------------------------------------------------------
1 | from collections import OrderedDict
2 |
3 | import torch
4 |
5 |
6 | def normalize_activation(x, eps=1e-10):
7 | norm_factor = torch.sqrt(torch.sum(x ** 2, dim=1, keepdim=True))
8 | return x / (norm_factor + eps)
9 |
10 |
11 | def get_state_dict(net_type: str = 'alex', version: str = '0.1'):
12 | # build url
13 | url = 'https://raw.githubusercontent.com/richzhang/PerceptualSimilarity/' \
14 | + f'master/lpips/weights/v{version}/{net_type}.pth'
15 |
16 | # download
17 | old_state_dict = torch.hub.load_state_dict_from_url(
18 | url, progress=True,
19 | map_location=None if torch.cuda.is_available() else torch.device('cpu')
20 | )
21 |
22 | # rename keys
23 | new_state_dict = OrderedDict()
24 | for key, val in old_state_dict.items():
25 | new_key = key
26 | new_key = new_key.replace('lin', '')
27 | new_key = new_key.replace('model.', '')
28 | new_state_dict[new_key] = val
29 |
30 | return new_state_dict
31 |
--------------------------------------------------------------------------------
/models/CASD.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import functools
3 | import torch
4 | import torch.nn.functional as F
5 |
6 | import os
7 | import torchvision.models.vgg as models
8 | from torch.nn.parameter import Parameter
9 |
10 | from torch.nn.utils.spectral_norm import spectral_norm as SpectralNorm
11 | import functools
12 |
13 |
14 | # Modified from AdaINGen
15 | class ADGen(nn.Module):
16 | # AdaIN auto-encoder architecture
17 | def __init__(self, input_dim, dim, style_dim, n_downsample, n_res, mlp_dim, activ='relu', pad_type='reflect'):
18 | super(ADGen, self).__init__()
19 |
20 | # style encoder
21 | input_dim = 3
22 | self.SP_input_nc = 8
23 | self.enc_style = VggStyleEncoder(3, input_dim, dim, int(style_dim / self.SP_input_nc), norm='none', activ=activ,
24 | pad_type=pad_type)
25 |
26 | # content encoder
27 | self.enc_content = ContentEncoder(layers=2, ngf=64, img_f=512)
28 |
29 | input_dim = 3
30 | self.dec = Decoder(style_dim, mlp_dim, n_downsample, n_res, 256, input_dim,
31 | self.SP_input_nc, res_norm='adain', activ=activ, pad_type=pad_type)
32 |
33 | def forward(self, img_A, img_B, sem_B):
34 | content = self.enc_content(img_A)
35 | style = self.enc_style(img_B, sem_B)
36 | images_recon = self.dec(content, style)
37 | return images_recon
38 |
39 |
40 | def calc_mean_std(feat, eps=1e-5):
41 | # eps is a small value added to the variance to avoid divide-by-zero.
42 | size = feat.size()
43 | assert (len(size) == 4)
44 | N, C = size[:2]
45 | feat_var = feat.view(N, C, -1).var(dim=2) + eps
46 | feat_std = feat_var.sqrt().view(N, C, 1, 1)
47 | feat_mean = feat.view(N, C, -1).mean(dim=2).view(N, C, 1, 1)
48 | return feat_mean, feat_std
49 |
50 |
51 | class VggStyleEncoder(nn.Module):
52 | def __init__(self, n_downsample, input_dim, dim, style_dim, norm, activ, pad_type):
53 | super(VggStyleEncoder, self).__init__()
54 | # self.vgg = models.vgg19(pretrained=True).features
55 | vgg19 = models.vgg19(pretrained=False)
56 |         vgg19.load_state_dict(torch.load('./dataset/fashion/vgg19-dcbb9e9d.pth'))  # VGG19 weights live under ./dataset/fashion (see README)
57 | self.vgg = vgg19.features
58 |
59 | for param in self.vgg.parameters():
60 | param.requires_grad_(False)
61 |
62 | self.conv1 = Conv2dBlock(input_dim, dim, 7, 1, 3, norm=norm, activation=activ, pad_type=pad_type) # 3->64
63 | dim = dim * 2
64 | self.conv2 = Conv2dBlock(dim, dim, 4, 2, 1, norm=norm, activation=activ, pad_type=pad_type) # 128->128
65 | dim = dim * 2
66 | self.conv3 = Conv2dBlock(dim, dim, 4, 2, 1, norm=norm, activation=activ, pad_type=pad_type) # 256->256
67 | dim = dim * 2
68 | self.conv4 = Conv2dBlock(dim, dim, 4, 2, 1, norm=norm, activation=activ, pad_type=pad_type) # 512->512
69 | dim = dim * 2
70 |
71 | self.model0 = []
72 | self.model0 += [nn.Conv2d(dim, style_dim, 1, 1, 0)]
73 | self.model0 = nn.Sequential(*self.model0)
74 |
75 | self.AP = []
76 | self.AP += [nn.AdaptiveAvgPool2d(1)]
77 | self.AP = nn.Sequential(*self.AP)
78 | self.output_dim = dim
79 |
80 | def get_features(self, image, model, layers=None):
81 | if layers is None:
82 | layers = {'0': 'conv1_1', '5': 'conv2_1', '10': 'conv3_1', '19': 'conv4_1'}
83 | features = {}
84 | x = image
85 | # model._modules is a dictionary holding each module in the model
86 | for name, layer in model._modules.items():
87 | x = layer(x)
88 | if name in layers:
89 | features[layers[name]] = x
90 | return features
91 |
92 | def texture_enc(self, x):
93 | sty_fea = self.get_features(x, self.vgg)
94 | x = self.conv1(x)
95 | x = torch.cat([x, sty_fea['conv1_1']], dim=1)
96 | x = self.conv2(x)
97 | x = torch.cat([x, sty_fea['conv2_1']], dim=1)
98 | x = self.conv3(x)
99 | x = torch.cat([x, sty_fea['conv3_1']], dim=1)
100 | x = self.conv4(x)
101 | x = torch.cat([x, sty_fea['conv4_1']], dim=1)
102 | x0 = self.model0(x)
103 | return x0
104 |
105 | def forward(self, x, sem):
106 |
107 | codes = self.texture_enc(x)
108 | segmap = F.interpolate(sem, size=codes.size()[2:], mode='nearest')
109 |
110 | bs = codes.shape[0]
111 | hs = codes.shape[2]
112 | ws = codes.shape[3]
113 | cs = codes.shape[1]
114 | f_size = cs
115 |
116 | s_size = segmap.shape[1]
117 | codes_vector = torch.zeros((bs, s_size, cs), dtype=codes.dtype, device=codes.device)
118 |
119 | for i in range(bs):
120 | for j in range(s_size):
121 | component_mask_area = torch.sum(segmap.bool()[i, j])
122 | if component_mask_area > 0:
123 | codes_component_feature = codes[i].masked_select(segmap.bool()[i, j]).reshape(f_size,
124 | component_mask_area).mean(1)
125 | codes_vector[i][j] = codes_component_feature
126 | else:
127 | tmpmean, tmpstd = calc_mean_std(
128 | codes[i].reshape(1, codes[i].shape[0], codes[i].shape[1], codes[i].shape[2]))
129 | codes_vector[i][j] = tmpmean.squeeze()
130 |
131 |
132 | return codes_vector.view(bs, -1).unsqueeze(2).unsqueeze(3)
133 |
134 |
135 | class ContentEncoder(nn.Module):
136 | def __init__(self, layers=2, ngf=64, img_f=512, use_spect = False, use_coord = False):
137 | super(ContentEncoder, self).__init__()
138 |
139 | self.layers = layers
140 | norm_layer = get_norm_layer(norm_type='instance')
141 | nonlinearity = get_nonlinearity_layer(activation_type='LeakyReLU')
142 | self.ngf = ngf
143 | self.img_f = img_f
144 | self.block0 = EncoderBlock(30, ngf, norm_layer,
145 | nonlinearity, use_spect, use_coord)
146 | mult = 1
147 | for i in range(self.layers-1):
148 | mult_prev = mult
149 | mult = min(2 ** (i + 1), self.img_f//self.ngf)
150 | block = EncoderBlock(self.ngf*mult_prev, self.ngf*mult, norm_layer,
151 | nonlinearity, use_spect, use_coord)
152 | setattr(self, 'encoder' + str(i), block)
153 |
154 | self.model0 = []
155 | self.model0 += [norm_layer(128)]
156 | self.model0 += [nonlinearity]
157 | self.model0 += [nn.Conv2d(128, 256, 1, 1, 0)]
158 | self.model0 = nn.Sequential(*self.model0)
159 |
160 | def forward(self, x):
161 | out = self.block0(x)
162 | for i in range(self.layers-1):
163 | model = getattr(self, 'encoder' + str(i))
164 | out = model(out)
165 | out = self.model0(out)
166 | return out
167 |
168 |
169 | class FFN(nn.Module):
170 | def __init__(self, in_features, hidden_features=None, out_features=None, drop=0.):
171 | super().__init__()
172 | out_features = out_features or in_features
173 | hidden_features = hidden_features or in_features
174 | self.fc1 = nn.Conv2d(in_features, hidden_features, 1)
175 | self.fc2 = nn.Conv2d(hidden_features, out_features, 1)
176 | self.drop = nn.Dropout(drop)
177 |
178 | def forward(self, x):
179 | b, c, h, w = x.size()
180 | x = self.fc1(x)
181 | x = F.gelu(x)
182 | x = self.drop(x)
183 | x = self.fc2(x)
184 | x = self.drop(x)
185 | x = torch.reshape(x, (b, c, h, w))
186 | return x
187 |
188 |
189 |
190 | class Decoder(nn.Module):
191 | def __init__(self, style_dim, mlp_dim, n_upsample, n_res, dim, output_dim, SP_input_nc, res_norm='adain',
192 | activ='relu', pad_type='zero'):
193 | super(Decoder, self).__init__()
194 | self.softmax = nn.Softmax(dim=1)
195 | self.softmax_style = nn.Softmax(dim=2)
196 | self.SP_input_nc = SP_input_nc
197 | self.model0 = []
198 | self.model1 = []
199 | self.model2 = []
200 | self.n_res = n_res
201 |
202 | self.mlp = MLP(style_dim, n_res * dim * 4, mlp_dim, 3, norm='none', activ=activ)
203 | self.fc = LinearBlock(style_dim, style_dim, norm='none', activation=activ)
204 |
205 | # AdaIN residual blocks
206 | self.model0_0 = [ResBlock_my(dim, res_norm, activ, pad_type=pad_type)]
207 | self.model0_0 = nn.Sequential(*self.model0_0)
208 | self.model0_1 = [ResBlock_my(dim, res_norm, activ, pad_type=pad_type)]
209 | self.model0_1 = nn.Sequential(*self.model0_1)
210 | self.model0_2 = [ResBlock_my(dim, res_norm, activ, pad_type=pad_type)]
211 | self.model0_2 = nn.Sequential(*self.model0_2)
212 | self.model0_3 = [ResBlock_my(dim, res_norm, activ, pad_type=pad_type)]
213 | self.model0_3 = nn.Sequential(*self.model0_3)
214 | self.model0_4 = [ResBlock_myDFNM(dim, 'spade', activ, pad_type=pad_type)]
215 | self.model0_4 = nn.Sequential(*self.model0_4)
216 | self.model0_5 = [ResBlock_myDFNM(dim, 'spade', activ, pad_type=pad_type)]
217 | self.model0_5 = nn.Sequential(*self.model0_5)
218 | self.model0_6 = [ResBlock_myDFNM(dim, 'spade', activ, pad_type=pad_type)]
219 | self.model0_6 = nn.Sequential(*self.model0_6)
220 | self.model0_7 = [ResBlock_myDFNM(dim, 'spade', activ, pad_type=pad_type)]
221 | self.model0_7 = nn.Sequential(*self.model0_7)
222 | # upsampling blocks
223 | for i in range(n_upsample):
224 | self.model1 += [nn.Upsample(scale_factor=2),
225 | Conv2dBlock(dim, dim // 2, 5, 1, 2, norm='ln', activation=activ, pad_type=pad_type)]
226 | dim //= 2
227 | self.model1 = nn.Sequential(*self.model1)
228 | # use reflection padding in the last conv layer
229 | self.model2 += [Conv2dBlock(dim, output_dim, 7, 1, 3, norm='none', activation='tanh', pad_type=pad_type)]
230 | self.model2 = nn.Sequential(*self.model2)
231 | # attention parameter
232 |
233 | self.gamma3_1 = nn.Parameter(torch.zeros(1))
234 | self.gamma3_2 = nn.Parameter(torch.zeros(1))
235 | self.gamma3_3 = nn.Parameter(torch.zeros(1))
236 | self.gamma3_style_sa = nn.Parameter(torch.zeros(1))
237 | in_dim = int(style_dim / self.SP_input_nc)
238 | self.value3_conv_sa = nn.Conv2d(in_channels=in_dim, out_channels=in_dim, kernel_size=1)
239 | self.LN_3_style = ILNKVT(256)
240 | self.LN_3_pose = ILNQT(256)
241 | self.LN_3_pose_0 = ILNQT(256)
242 | self.query3_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim, kernel_size=1)
243 | self.key3_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim, kernel_size=1)
244 | self.value3_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim, kernel_size=1)
245 | self.query3_conv_0 = nn.Conv2d(in_channels=in_dim, out_channels=self.SP_input_nc, kernel_size=1)
246 |
247 | self.gamma4_1 = nn.Parameter(torch.zeros(1))
248 | self.gamma4_2 = nn.Parameter(torch.zeros(1))
249 | self.gamma4_3 = nn.Parameter(torch.zeros(1))
250 | self.gamma4_style_sa = nn.Parameter(torch.zeros(1))
251 | in_dim = int(style_dim / self.SP_input_nc)
252 | self.value4_conv_sa = nn.Conv2d(in_channels=in_dim, out_channels=in_dim, kernel_size=1)
253 | self.LN_4_style = ILNKVT(256)
254 | self.LN_4_pose = ILNQT(256)
255 | self.LN_4_pose_0 = ILNQT(256)
256 | self.query4_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim, kernel_size=1)
257 | self.key4_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim, kernel_size=1)
258 | self.value4_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim, kernel_size=1)
259 | self.query4_conv_0 = nn.Conv2d(in_channels=in_dim, out_channels=self.SP_input_nc, kernel_size=1)
260 |
261 | self.FFN3_1 = FFN(256)
262 | self.FFN4_1 = FFN(256)
263 | self.up = nn.Upsample(scale_factor=2)
264 |
265 | def forward(self, x, style):
266 | # fusion module
267 | style_fusion = self.fc(style.view(style.size(0), -1))
268 | adain_params = self.mlp(style_fusion)
269 | adain_params = torch.split(adain_params, int(adain_params.shape[1] / self.n_res), 1)
270 |
271 | x_0 = x
272 | x = self.model0_0([x, adain_params[0]])
273 | x = self.model0_1([x, adain_params[1]])
274 | x = self.model0_2([x, adain_params[2]])
275 | x = self.model0_3([x, adain_params[3]])
276 |
277 | x3, enerrgy_sum3 = self.styleatt(x, x_0, style, self.gamma3_1, self.gamma3_2, self.gamma3_3, \
278 | self.gamma3_style_sa, self.value3_conv_sa, \
279 | self.LN_3_style, self.LN_3_pose, self.LN_3_pose_0, \
280 | self.query3_conv, self.key3_conv, self.value3_conv, self.query3_conv_0, \
281 | self.FFN3_1)
282 |
283 | x_, enerrgy_sum4 = self.styleatt(x3, x_0, style, self.gamma4_1, self.gamma4_2, self.gamma4_3, \
284 | self.gamma4_style_sa, self.value4_conv_sa, \
285 | self.LN_4_style, self.LN_4_pose, self.LN_4_pose_0, \
286 | self.query4_conv, self.key4_conv, self.value4_conv, self.query4_conv_0, \
287 | self.FFN4_1)
288 |
289 | x = self.model0_4([x_0, x_])
290 | x = self.model0_5([x, x_])
291 | x = self.model0_6([x, x_])
292 | x = self.model0_7([x, x_])
293 | x = self.model1(x)
294 | return self.model2(x), [enerrgy_sum3, enerrgy_sum4]
295 |
296 | def styleatt(self, x, x_0, style, gamma1, gamma2, gamma3, gamma_style_sa, value_conv_sa, ln_style, ln_pose,
297 | ln_pose_0, query_conv, key_conv, value_conv, query_conv_0, ffn1):
298 | B, C, H, W = x.size()
299 | B, Cs, _, _ = style.size()
300 | K = self.SP_input_nc
301 | style = style.view((B, K, int(Cs / K))) # [B,K,C]
302 |
303 | x = ln_pose(x) # [B,C,H,W]
304 | style = ln_style(style.permute(0, 2, 1)) # [B,C,K]
305 | x_0 = ln_pose_0(x_0)
306 |
307 | style = style.permute(0, 2, 1) # [B,K,C]
308 | style_sa_value = torch.squeeze(value_conv_sa(torch.unsqueeze(style.permute(0, 2, 1), 3)), 3) # [B,C,K]
309 | self_att = self.softmax(torch.bmm(style, style.permute(0, 2, 1))) + 1e-8 # [B,K,K]
310 | self_att = self_att / torch.sum(self_att, dim=2, keepdim=True)
311 | style_ = torch.bmm(self_att, style_sa_value.permute(0, 2, 1))
312 | style = style + gamma_style_sa * style_ # [B,K,C]
313 |
314 | style = style.permute(0, 2, 1) #[B,C,K]
315 | x_query = query_conv(x)
316 | style_key = torch.squeeze(key_conv(torch.unsqueeze(style, 3)).permute(0, 2, 1, 3), 3)
317 | style_value = torch.squeeze(value_conv(torch.unsqueeze(style, 3)), 3)
318 |
319 | energy_0 = query_conv_0(x_0).view((B, K, H * W))
320 | energy = torch.bmm(style_key.detach(), x_query.view(B, C, -1))
321 | enerrgy_sum = energy_0 + energy
322 | attention = self.softmax_style(enerrgy_sum) + 1e-8
323 | attention = attention / torch.sum(attention, dim=1, keepdim=True)
324 |
325 | out = torch.bmm(style_value, attention)
326 | out = out.view(B, C, H, W)
327 | out = gamma1 * out + x
328 | out = out + gamma3 * ffn1(out)
329 |
330 | return out, torch.reshape(enerrgy_sum, (B, K, H, W))
331 |
332 |
333 | class ILNKVT(nn.Module):
334 | def __init__(self, num_features, eps=1e-5):
335 | super().__init__()
336 | self.eps = eps
337 | self.rho = Parameter(torch.Tensor(1, num_features, 1))
338 | self.gamma = Parameter(torch.Tensor(1, num_features, 1))
339 | self.beta = Parameter(torch.Tensor(1, num_features, 1))
340 | self.rho.data.fill_(0.0)
341 | self.gamma.data.fill_(1.0)
342 | self.beta.data.fill_(0.0)
343 |
344 | def forward(self, input):
345 | in_mean, in_var = torch.mean(input, dim=[2], keepdim=True), torch.var(input, dim=[2], keepdim=True)
346 | out_in = (input - in_mean) / torch.sqrt(in_var + self.eps)
347 | ln_mean, ln_var = torch.mean(input, dim=[1], keepdim=True), torch.var(input, dim=[1], keepdim=True)
348 | out_ln = (input - ln_mean) / torch.sqrt(ln_var + self.eps)
349 | out = self.rho.expand(input.shape[0], -1, -1) * out_in + (1-self.rho.expand(input.shape[0], -1, -1)) * out_ln
350 | out = out * self.gamma.expand(input.shape[0], -1, -1) + self.beta.expand(input.shape[0], -1, -1)
351 |
352 | return out
353 |
354 | class ILNQT(nn.Module):
355 | def __init__(self, num_features, eps=1e-5):
356 | super().__init__()
357 | self.eps = eps
358 | self.rho = Parameter(torch.Tensor(1, num_features, 1, 1))
359 | self.gamma = Parameter(torch.Tensor(1, num_features, 1, 1))
360 | self.beta = Parameter(torch.Tensor(1, num_features, 1, 1))
361 | self.rho.data.fill_(0.0)
362 | self.gamma.data.fill_(1.0)
363 | self.beta.data.fill_(0.0)
364 |
365 | def forward(self, input):
366 | in_mean, in_var = torch.mean(input, dim=[2, 3], keepdim=True), torch.var(input, dim=[2, 3], keepdim=True)
367 | out_in = (input - in_mean) / torch.sqrt(in_var + self.eps)
368 | ln_mean, ln_var = torch.mean(input, dim=[1], keepdim=True), torch.var(input, dim=[1], keepdim=True)
369 | out_ln = (input - ln_mean) / torch.sqrt(ln_var + self.eps)
370 | out = self.rho.expand(input.shape[0], -1, -1, -1) * out_in + (1-self.rho.expand(input.shape[0], -1, -1, -1)) * out_ln
371 | out = out * self.gamma.expand(input.shape[0], -1, -1, -1) + self.beta.expand(input.shape[0], -1, -1, -1)
372 |
373 | return out
374 |
375 |
376 | ##################################################################################
377 | # Sequential Models
378 | ##################################################################################
379 | class ResBlocks(nn.Module):
380 | def __init__(self, num_blocks, dim, norm='in', activation='relu', pad_type='zero'):
381 | super(ResBlocks, self).__init__()
382 | self.model = []
383 | for i in range(num_blocks):
384 | self.model += [ResBlock(dim, norm=norm, activation=activation, pad_type=pad_type)]
385 | self.model = nn.Sequential(*self.model)
386 |
387 | def forward(self, x):
388 | return self.model(x)
389 |
390 |
391 | class ResBlock_myDFNM(nn.Module):
392 | def __init__(self, dim, norm='in', activation='relu', pad_type='zero'):
393 | super(ResBlock_myDFNM, self).__init__()
394 |
395 | model1 = []
396 | model2 = []
397 | model1 += [Conv2dBlock_my(dim, dim, 3, 1, 1, norm=norm, activation=activation, pad_type=pad_type)]
398 | model2 += [Conv2dBlock_my(dim, dim, 3, 1, 1, norm=norm, activation='none', pad_type=pad_type)]
399 | models1 = []
400 | models1 += [Conv2dBlock(dim, dim, 3, 1, 1, norm='in', activation='relu', pad_type=pad_type)]
401 | models1 += [Conv2dBlock(dim, 2 * dim, 3, 1, 1, norm='none', activation='none', pad_type=pad_type)]
402 | models2 = []
403 | models2 += [Conv2dBlock(dim, dim, 3, 1, 1, norm='in', activation='relu', pad_type=pad_type)]
404 | models2 += [Conv2dBlock(dim, 2 * dim, 3, 1, 1, norm='none', activation='none', pad_type=pad_type)]
405 | self.model1 = nn.Sequential(*model1)
406 | self.model2 = nn.Sequential(*model2)
407 | self.models1 = nn.Sequential(*models1)
408 | self.models2 = nn.Sequential(*models2)
409 |
410 | def forward(self, x):
411 | style = x[1]
412 | style1 = self.models1(style)
413 | style2 = self.models2(style)
414 | residual = x[0]
415 | out = self.model1([x[0], style1])
416 | out = self.model2([out, style2])
417 | out += residual
418 |
419 | return out
420 |
421 |
422 | class ResBlock_my(nn.Module):
423 | def __init__(self, dim, norm='in', activation='relu', pad_type='zero'):
424 | super(ResBlock_my, self).__init__()
425 |
426 | model1 = []
427 | model2 = []
428 | model1 += [Conv2dBlock_my(dim, dim, 3, 1, 1, norm=norm, activation=activation, pad_type=pad_type)]
429 | model2 += [Conv2dBlock_my(dim, dim, 3, 1, 1, norm=norm, activation='none', pad_type=pad_type)]
430 | self.model1 = nn.Sequential(*model1)
431 | self.model2 = nn.Sequential(*model2)
432 |
433 | def forward(self, x):
434 | style = x[1]
435 | style1, style2 = torch.split(style, int(style.shape[1] / 2), 1)
436 | residual = x[0]
437 | out = self.model1([x[0], style1])
438 | out = self.model2([out, style2])
439 | out += residual
440 | return out
441 |
442 |
443 | class MLP(nn.Module):
444 | def __init__(self, input_dim, output_dim, dim, n_blk, norm='none', activ='relu'):
445 |
446 | super(MLP, self).__init__()
447 | self.model = []
448 | self.model += [LinearBlock(input_dim, dim, norm=norm, activation=activ)]
449 | for i in range(n_blk - 2):
450 | self.model += [LinearBlock(dim, dim, norm=norm, activation=activ)]
451 | self.model += [LinearBlock(dim, output_dim, norm='none', activation='none')] # no output activations
452 | self.model = nn.Sequential(*self.model)
453 |
454 | def forward(self, x):
455 | return self.model(x)
456 |
457 |
458 | ##################################################################################
459 | # Basic Blocks
460 | ##################################################################################
461 | class ResBlock(nn.Module):
462 | def __init__(self, dim, norm='in', activation='relu', pad_type='zero'):
463 | super(ResBlock, self).__init__()
464 |
465 | model = []
466 | model += [Conv2dBlock(dim, dim, 3, 1, 1, norm=norm, activation=activation, pad_type=pad_type)]
467 | model += [Conv2dBlock(dim, dim, 3, 1, 1, norm=norm, activation='none', pad_type=pad_type)]
468 | self.model = nn.Sequential(*model)
469 |
470 | def forward(self, x):
471 | residual = x
472 | out = self.model(x)
473 | out += residual
474 | return out
475 |
476 |
477 | class Conv2dBlock_my(nn.Module):
478 | def __init__(self, input_dim, output_dim, kernel_size, stride,
479 | padding=0, norm='none', activation='relu', pad_type='zero'):
480 | super(Conv2dBlock_my, self).__init__()
481 | self.use_bias = True
482 | # initialize padding
483 | if pad_type == 'reflect':
484 | self.pad = nn.ReflectionPad2d(padding)
485 | elif pad_type == 'replicate':
486 | self.pad = nn.ReplicationPad2d(padding)
487 | elif pad_type == 'zero':
488 | self.pad = nn.ZeroPad2d(padding)
489 | else:
490 | assert 0, "Unsupported padding type: {}".format(pad_type)
491 |
492 | # initialize normalization
493 | norm_dim = output_dim
494 | if norm == 'bn':
495 | self.norm = nn.BatchNorm2d(norm_dim)
496 | elif norm == 'in':
497 | # self.norm = nn.InstanceNorm2d(norm_dim, track_running_stats=True)
498 | self.norm = nn.InstanceNorm2d(norm_dim)
499 | elif norm == 'ln':
500 | self.norm = LayerNorm(norm_dim)
501 | elif norm == 'adain':
502 | self.norm = AdaptiveInstanceNorm2d(norm_dim)
503 | elif norm == 'spade':
504 | self.norm = SPADE()
505 | elif norm == 'none' or norm == 'sn':
506 | self.norm = None
507 | else:
508 | assert 0, "Unsupported normalization: {}".format(norm)
509 |
510 | # initialize activation
511 | if activation == 'relu':
512 | self.activation = nn.ReLU(inplace=True)
513 | elif activation == 'lrelu':
514 | self.activation = nn.LeakyReLU(0.2, inplace=True)
515 | elif activation == 'prelu':
516 | self.activation = nn.PReLU()
517 | elif activation == 'selu':
518 | self.activation = nn.SELU(inplace=True)
519 | elif activation == 'tanh':
520 | self.activation = nn.Tanh()
521 | elif activation == 'none':
522 | self.activation = None
523 | else:
524 | assert 0, "Unsupported activation: {}".format(activation)
525 |
526 | # initialize convolution
527 | if norm == 'sn':
528 | self.conv = SpectralNorm(nn.Conv2d(input_dim, output_dim, kernel_size, stride, bias=self.use_bias))
529 | else:
530 | self.conv = nn.Conv2d(input_dim, output_dim, kernel_size, stride, bias=self.use_bias)
531 |
532 | def forward(self, x):
533 | style = x[1]
534 | x = x[0]
535 | x = self.conv(self.pad(x))
536 | if self.norm:
537 | x = self.norm([x, style])
538 | if self.activation:
539 | x = self.activation(x)
540 | return x
541 |
542 |
543 | class Conv2dBlock(nn.Module):
544 | def __init__(self, input_dim, output_dim, kernel_size, stride,
545 | padding=0, norm='none', activation='relu', pad_type='zero'):
546 | super(Conv2dBlock, self).__init__()
547 | self.use_bias = True
548 | # initialize padding
549 | if pad_type == 'reflect':
550 | self.pad = nn.ReflectionPad2d(padding)
551 | elif pad_type == 'replicate':
552 | self.pad = nn.ReplicationPad2d(padding)
553 | elif pad_type == 'zero':
554 | self.pad = nn.ZeroPad2d(padding)
555 | else:
556 | assert 0, "Unsupported padding type: {}".format(pad_type)
557 |
558 | # initialize normalization
559 | norm_dim = output_dim
560 | if norm == 'bn':
561 | self.norm = nn.BatchNorm2d(norm_dim)
562 | elif norm == 'in':
563 | # self.norm = nn.InstanceNorm2d(norm_dim, track_running_stats=True)
564 | self.norm = nn.InstanceNorm2d(norm_dim)
565 | elif norm == 'ln':
566 | self.norm = LayerNorm(norm_dim)
567 | elif norm == 'adain':
568 | self.norm = AdaptiveInstanceNorm2d(norm_dim)
569 | elif norm == 'none' or norm == 'sn':
570 | self.norm = None
571 | else:
572 | assert 0, "Unsupported normalization: {}".format(norm)
573 |
574 | # initialize activation
575 | if activation == 'relu':
576 | self.activation = nn.ReLU(inplace=True)
577 | elif activation == 'lrelu':
578 | self.activation = nn.LeakyReLU(0.2, inplace=True)
579 | elif activation == 'prelu':
580 | self.activation = nn.PReLU()
581 | elif activation == 'selu':
582 | self.activation = nn.SELU(inplace=True)
583 | elif activation == 'tanh':
584 | self.activation = nn.Tanh()
585 | elif activation == 'none':
586 | self.activation = None
587 | else:
588 | assert 0, "Unsupported activation: {}".format(activation)
589 |
590 | # initialize convolution
591 | if norm == 'sn':
592 | self.conv = SpectralNorm(nn.Conv2d(input_dim, output_dim, kernel_size, stride, bias=self.use_bias))
593 | else:
594 | self.conv = nn.Conv2d(input_dim, output_dim, kernel_size, stride, bias=self.use_bias)
595 |
596 | def forward(self, x):
597 | x = self.conv(self.pad(x))
598 | if self.norm:
599 | x = self.norm(x)
600 | if self.activation:
601 | x = self.activation(x)
602 | return x
603 |
604 |
605 | class LinearBlock(nn.Module):
606 | def __init__(self, input_dim, output_dim, norm='none', activation='relu'):
607 | super(LinearBlock, self).__init__()
608 | use_bias = True
609 | # initialize fully connected layer
610 | if norm == 'sn':
611 | self.fc = SpectralNorm(nn.Linear(input_dim, output_dim, bias=use_bias))
612 | else:
613 | self.fc = nn.Linear(input_dim, output_dim, bias=use_bias)
614 |
615 | # initialize normalization
616 | norm_dim = output_dim
617 | if norm == 'bn':
618 | self.norm = nn.BatchNorm1d(norm_dim)
619 | elif norm == 'in':
620 | self.norm = nn.InstanceNorm1d(norm_dim)
621 | elif norm == 'ln':
622 | self.norm = LayerNorm(norm_dim)
623 | elif norm == 'none' or norm == 'sn':
624 | self.norm = None
625 | else:
626 | assert 0, "Unsupported normalization: {}".format(norm)
627 |
628 | # initialize activation
629 | if activation == 'relu':
630 | self.activation = nn.ReLU(inplace=True)
631 | elif activation == 'lrelu':
632 | self.activation = nn.LeakyReLU(0.2, inplace=True)
633 | elif activation == 'prelu':
634 | self.activation = nn.PReLU()
635 | elif activation == 'selu':
636 | self.activation = nn.SELU(inplace=True)
637 | elif activation == 'tanh':
638 | self.activation = nn.Tanh()
639 | elif activation == 'none':
640 | self.activation = None
641 | else:
642 | assert 0, "Unsupported activation: {}".format(activation)
643 |
644 | def forward(self, x):
645 | out = self.fc(x)
646 | if self.norm:
647 | out = self.norm(out)
648 | if self.activation:
649 | out = self.activation(out)
650 | return out
651 |
652 |
653 | ##################################################################################
654 | # Normalization layers
655 | ##################################################################################
656 | class SPADE(nn.Module):
657 | def __init__(self):
658 | super().__init__()
659 |
660 | def forward(self, x):
661 | style = x[1]
662 | x = x[0]
663 | # Part 1. generate parameter-free normalized activations
664 | x_mean = torch.mean(x, (0, 2, 3), keepdim=True)
665 | x_var = torch.var(x, (0, 2, 3), keepdim=True)
666 | normalized = (x - x_mean) / (x_var + 1e-6)
667 |
668 | # Part 2. produce scaling and bias conditioned on semantic map
669 | gamma, beta = torch.split(style, int(style.size(1) / 2), 1)
670 | # apply scale and bias
671 | out = normalized * (1 + gamma) + beta
672 |
673 | return out
674 |
675 |
676 | class AdaptiveInstanceNorm2d(nn.Module):
677 | def __init__(self, num_features, eps=1e-5, momentum=0.1):
678 | super(AdaptiveInstanceNorm2d, self).__init__()
679 | self.num_features = num_features
680 | self.eps = eps
681 | self.momentum = momentum
682 | # weight and bias are dynamically assigned
683 | self.weight = None
684 | self.bias = None
685 | # just dummy buffers, not used
686 | self.register_buffer('running_mean', torch.zeros(num_features))
687 | self.register_buffer('running_var', torch.ones(num_features))
688 |
689 | def forward(self, x):
690 | style = x[1]
691 | self.weight, self.bias = torch.split(style, int(style.shape[1] / 2), 1)
692 | x = x[0]
693 | b, c = x.size(0), x.size(1)
694 | running_mean = self.running_mean.repeat(b)
695 | running_var = self.running_var.repeat(b)
696 |
697 | # Apply instance norm
698 | x_reshaped = x.contiguous().view(1, b * c, *x.size()[2:])
699 |
700 | out = F.batch_norm(
701 | x_reshaped, running_mean, running_var, self.weight, self.bias,
702 | True, self.momentum, self.eps)
703 |
704 | return out.view(b, c, *x.size()[2:])
705 |
706 | def __repr__(self):
707 | return self.__class__.__name__ + '(' + str(self.num_features) + ')'
708 |
709 |
710 | class LayerNorm(nn.Module):
711 | def __init__(self, num_features, eps=1e-5, affine=True):
712 | super(LayerNorm, self).__init__()
713 | self.num_features = num_features
714 | self.affine = affine
715 | self.eps = eps
716 |
717 | if self.affine:
718 | self.gamma = nn.Parameter(torch.Tensor(num_features).uniform_())
719 | self.beta = nn.Parameter(torch.zeros(num_features))
720 |
721 | def forward(self, x):
722 | shape = [-1] + [1] * (x.dim() - 1)
723 | # print(x.size())
724 | if x.size(0) == 1:
725 | # These two lines run much faster in pytorch 0.4 than the two lines listed below.
726 | mean = x.view(-1).mean().view(*shape)
727 | std = x.view(-1).std().view(*shape)
728 | else:
729 | mean = x.view(x.size(0), -1).mean(1).view(*shape)
730 | std = x.view(x.size(0), -1).std(1).view(*shape)
731 |
732 | x = (x - mean) / (std + self.eps)
733 |
734 | if self.affine:
735 | shape = [1, -1] + [1] * (x.dim() - 2)
736 | x = x * self.gamma.view(*shape) + self.beta.view(*shape)
737 | return x
738 |
739 |
740 | def l2normalize(v, eps=1e-12):
741 | return v / (v.norm() + eps)
742 |
743 |
744 | class SpectralNorm(nn.Module):
745 | """
746 | Based on the paper "Spectral Normalization for Generative Adversarial Networks" by Takeru Miyato, Toshiki Kataoka, Masanori Koyama, Yuichi Yoshida
747 | and the Pytorch implementation https://github.com/christiancosgrove/pytorch-spectral-normalization-gan
748 | """
749 |
750 | def __init__(self, module, name='weight', power_iterations=1):
751 | super(SpectralNorm, self).__init__()
752 | self.module = module
753 | self.name = name
754 | self.power_iterations = power_iterations
755 | if not self._made_params():
756 | self._make_params()
757 |
758 | def _update_u_v(self):
759 | u = getattr(self.module, self.name + "_u")
760 | v = getattr(self.module, self.name + "_v")
761 | w = getattr(self.module, self.name + "_bar")
762 |
763 | height = w.data.shape[0]
764 | for _ in range(self.power_iterations):
765 | v.data = l2normalize(torch.mv(torch.t(w.view(height, -1).data), u.data))
766 | u.data = l2normalize(torch.mv(w.view(height, -1).data, v.data))
767 |
768 | # sigma = torch.dot(u.data, torch.mv(w.view(height,-1).data, v.data))
769 | sigma = u.dot(w.view(height, -1).mv(v))
770 | setattr(self.module, self.name, w / sigma.expand_as(w))
771 |
772 | def _made_params(self):
773 | try:
774 | u = getattr(self.module, self.name + "_u")
775 | v = getattr(self.module, self.name + "_v")
776 | w = getattr(self.module, self.name + "_bar")
777 | return True
778 | except AttributeError:
779 | return False
780 |
781 | def _make_params(self):
782 | w = getattr(self.module, self.name)
783 |
784 | height = w.data.shape[0]
785 | width = w.view(height, -1).data.shape[1]
786 |
787 | u = nn.Parameter(w.data.new(height).normal_(0, 1), requires_grad=False)
788 | v = nn.Parameter(w.data.new(width).normal_(0, 1), requires_grad=False)
789 | u.data = l2normalize(u.data)
790 | v.data = l2normalize(v.data)
791 | w_bar = nn.Parameter(w.data)
792 |
793 | del self.module._parameters[self.name]
794 |
795 | self.module.register_parameter(self.name + "_u", u)
796 | self.module.register_parameter(self.name + "_v", v)
797 | self.module.register_parameter(self.name + "_bar", w_bar)
798 |
799 | def forward(self, *args):
800 | self._update_u_v()
801 | return self.module.forward(*args)
802 |
803 |
804 | def get_norm_layer(norm_type='batch'):
805 | """Get the normalization layer for the networks"""
806 | if norm_type == 'batch':
807 | norm_layer = functools.partial(nn.BatchNorm2d, momentum=0.1, affine=True)
808 | elif norm_type == 'instance':
809 | norm_layer = functools.partial(nn.InstanceNorm2d, affine=True)
810 | elif norm_type == 'adain':
811 | norm_layer = functools.partial(ADAIN)
812 | elif norm_type == 'spade':
813 | norm_layer = functools.partial(SPADE, config_text='spadeinstance3x3')
814 | elif norm_type == 'none':
815 | norm_layer = None
816 | else:
817 | raise NotImplementedError('normalization layer [%s] is not found' % norm_type)
818 |
819 | if norm_type != 'none':
820 | norm_layer.__name__ = norm_type
821 |
822 | return norm_layer
823 |
824 | def get_nonlinearity_layer(activation_type='PReLU'):
825 | """Get the activation layer for the networks"""
826 | if activation_type == 'ReLU':
827 | nonlinearity_layer = nn.ReLU()
828 | elif activation_type == 'SELU':
829 | nonlinearity_layer = nn.SELU()
830 | elif activation_type == 'LeakyReLU':
831 | nonlinearity_layer = nn.LeakyReLU(0.1)
832 | elif activation_type == 'PReLU':
833 | nonlinearity_layer = nn.PReLU()
834 | else:
835 | raise NotImplementedError('activation layer [%s] is not found' % activation_type)
836 | return nonlinearity_layer
837 |
838 |
839 | class AddCoords(nn.Module):
840 | """
841 | Add Coords to a tensor
842 | """
843 | def __init__(self, with_r=False):
844 | super(AddCoords, self).__init__()
845 | self.with_r = with_r
846 |
847 | def forward(self, x):
848 | """
849 | :param x: shape (batch, channel, x_dim, y_dim)
850 | :return: shape (batch, channel+2, x_dim, y_dim)
851 | """
852 | B, _, x_dim, y_dim = x.size()
853 |
854 | # coord calculate
855 | xx_channel = torch.arange(x_dim).repeat(B, 1, y_dim, 1).type_as(x)
856 | yy_channel = torch.arange(y_dim).repeat(B, 1, x_dim, 1).permute(0, 1, 3, 2).type_as(x)
857 | # normalization
858 | xx_channel = xx_channel.float() / (x_dim-1)
859 | yy_channel = yy_channel.float() / (y_dim-1)
860 | xx_channel = xx_channel * 2 - 1
861 | yy_channel = yy_channel * 2 - 1
862 |
863 | ret = torch.cat([x, xx_channel, yy_channel], dim=1)
864 |
865 | if self.with_r:
866 | rr = torch.sqrt(xx_channel ** 2 + yy_channel ** 2)
867 | ret = torch.cat([ret, rr], dim=1)
868 |
869 | return ret
870 |
871 |
872 | def spectral_norm(module, use_spect=True):
873 | """use a spectral norm layer to stabilize the training process"""
874 | if use_spect:
875 | return SpectralNorm(module)
876 | else:
877 | return module
878 |
879 |
880 |
881 | class CoordConv(nn.Module):
882 | """
883 | CoordConv operation
884 | """
885 | def __init__(self, input_nc, output_nc, with_r=False, use_spect=False, **kwargs):
886 | super(CoordConv, self).__init__()
887 | self.addcoords = AddCoords(with_r=with_r)
888 | input_nc = input_nc + 2
889 | if with_r:
890 | input_nc = input_nc + 1
891 | self.conv = spectral_norm(nn.Conv2d(input_nc, output_nc, **kwargs), use_spect)
892 |
893 | def forward(self, x):
894 | ret = self.addcoords(x)
895 | ret = self.conv(ret)
896 |
897 | return ret
898 |
899 |
900 | def coord_conv(input_nc, output_nc, use_spect=False, use_coord=False, with_r=False, **kwargs):
901 | """use coord convolution layer to add position information"""
902 | if use_coord:
903 | return CoordConv(input_nc, output_nc, with_r, use_spect, **kwargs)
904 | else:
905 | return spectral_norm(nn.Conv2d(input_nc, output_nc, **kwargs), use_spect)
906 |
907 |
908 | class EncoderBlock(nn.Module):
909 | def __init__(self, input_nc, output_nc, norm_layer=nn.BatchNorm2d, nonlinearity=nn.LeakyReLU(),
910 | use_spect=False, use_coord=False):
911 | super(EncoderBlock, self).__init__()
912 |
913 |
914 | kwargs_down = {'kernel_size': 4, 'stride': 2, 'padding': 1}
915 | kwargs_fine = {'kernel_size': 3, 'stride': 1, 'padding': 1}
916 |
917 | conv1 = coord_conv(input_nc, output_nc, use_spect, use_coord, **kwargs_down)
918 | conv2 = coord_conv(output_nc, output_nc, use_spect, use_coord, **kwargs_fine)
919 |
920 | if type(norm_layer) == type(None):
921 | self.model = nn.Sequential(nonlinearity, conv1, nonlinearity, conv2,)
922 | else:
923 | self.model = nn.Sequential(norm_layer(input_nc), nonlinearity, conv1,
924 | norm_layer(output_nc), nonlinearity, conv2,)
925 |
926 | def forward(self, x):
927 | out = self.model(x)
928 | return out
929 |
--------------------------------------------------------------------------------
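The basic blocks above come in two flavors: the plain `Conv2dBlock`/`ResBlock`, which take a single tensor, and the `*_my` variants, which take a `[feature, style]` pair and route the style tensor into the normalization layer. A minimal sketch of the plain variants, assuming they are importable from `models/CASD.py` (the file this listing appears to belong to):

```python
import torch
from models.CASD import Conv2dBlock, ResBlock   # assumed module path

x = torch.randn(2, 64, 32, 32)

# 3x3 conv, stride 1, padding 1: spatial size is preserved
conv = Conv2dBlock(64, 64, 3, 1, 1, norm='in', activation='relu', pad_type='reflect')
print(conv(x).shape)   # torch.Size([2, 64, 32, 32])

# ResBlock applies two such convolutions and adds the identity shortcut,
# so input and output channel counts must match
res = ResBlock(64, norm='in', activation='relu', pad_type='zero')
print(res(x).shape)    # torch.Size([2, 64, 32, 32])
```

The `_my` variants follow the same pattern but expect `x = [features, style]`; `ResBlock_my` splits the style vector in half along the channel dimension and feeds one half to each of its two convolutions.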
/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xyzhouo/CASD/cb1aabb64b5d8ae712b626a1eec045d08f90933e/models/__init__.py
--------------------------------------------------------------------------------
/models/adgan.py:
--------------------------------------------------------------------------------
1 | from torch.autograd import Variable
2 | import numpy as np
3 | import torch
4 | import os
5 | from collections import OrderedDict
6 | import util.util as util
7 | from util.image_pool import ImagePool
8 | from .base_model import BaseModel
9 | from . import networks
10 | # losses
11 | from losses.L1_plus_perceptualLoss import L1_plus_perceptualLoss
12 | from losses.CX_style_loss import CXLoss
13 | from .vgg_SC import VGG, VGGLoss
14 | from losses.lpips.lpips import LPIPS
15 |
16 |
17 |
18 | class TransferModel(BaseModel):
19 | def name(self):
20 | return 'TransferModel'
21 |
22 | def initialize(self, opt):
23 | BaseModel.initialize(self, opt)
24 | nb = opt.batchSize
25 | size = opt.fineSize
26 | self.use_AMCE = opt.use_AMCE
27 | self.use_BPD = opt.use_BPD
28 | self.SP_input_nc = opt.SP_input_nc
29 | self.input_P1_set = self.Tensor(nb, opt.P_input_nc, size[0], size[1])
30 | self.input_BP1_set = self.Tensor(nb, opt.BP_input_nc, size[0], size[1])
31 | self.input_P2_set = self.Tensor(nb, opt.P_input_nc, size[0], size[1])
32 | self.input_BP2_set = self.Tensor(nb, opt.BP_input_nc, size[0], size[1])
33 | self.input_SP1_set = self.Tensor(nb, opt.SP_input_nc, size[0], size[1])
34 | self.input_SP2_set = self.Tensor(nb, opt.SP_input_nc, size[0], size[1])
35 | if self.use_BPD:
36 | self.input_BPD1_set = self.Tensor(nb, opt.BPD_input_nc, size[0], size[1])
37 | self.input_BPD2_set = self.Tensor(nb, opt.BPD_input_nc, size[0], size[1])
38 |
39 |
40 | input_nc = [opt.P_input_nc, opt.BP_input_nc+opt.BP_input_nc + (opt.BPD_input_nc+opt.BPD_input_nc if self.use_BPD else 0)]
41 | self.netG = networks.define_G(input_nc, opt.P_input_nc,
42 | opt.ngf, opt.which_model_netG, opt.norm, not opt.no_dropout, opt.init_type, self.gpu_ids,
43 | n_downsampling=opt.G_n_downsampling)
44 |
45 | if self.isTrain:
46 | use_sigmoid = opt.no_lsgan
47 | if opt.with_D_PB:
48 | self.netD_PB = networks.define_D(opt.P_input_nc+opt.BP_input_nc + (opt.BPD_input_nc if self.use_BPD else 0), opt.ndf,
49 | opt.which_model_netD,
50 | opt.n_layers_D, opt.norm, use_sigmoid, opt.init_type, self.gpu_ids,
51 | not opt.no_dropout_D,
52 | n_downsampling = opt.D_n_downsampling)
53 |
54 | if opt.with_D_PP:
55 | self.netD_PP = networks.define_D(opt.P_input_nc+opt.P_input_nc, opt.ndf,
56 | opt.which_model_netD,
57 | opt.n_layers_D, opt.norm, use_sigmoid, opt.init_type, self.gpu_ids,
58 | not opt.no_dropout_D,
59 | n_downsampling = opt.D_n_downsampling)
60 |
61 | if len(opt.gpu_ids) > 1:
62 | self.load_VGG(self.netG.module.enc_style.vgg)
63 | else:
64 | self.load_VGG(self.netG.enc_style.vgg)
65 |
66 | if not self.isTrain or opt.continue_train:
67 | which_epoch = opt.which_epoch
68 | self.load_network(self.netG, 'netG', which_epoch)
69 | if self.isTrain:
70 | if opt.with_D_PB:
71 | self.load_network(self.netD_PB, 'netD_PB', which_epoch)
72 | if opt.with_D_PP:
73 | self.load_network(self.netD_PP, 'netD_PP', which_epoch)
74 |
75 | if self.isTrain:
76 | self.old_lr = opt.lr
77 | self.fake_PP_pool = ImagePool(opt.pool_size)
78 | self.fake_PB_pool = ImagePool(opt.pool_size)
79 | # define loss functions
80 | self.criterionGAN = networks.GANLoss(use_lsgan=not opt.no_lsgan, tensor=self.Tensor)
81 |
82 |
83 | if opt.L1_type == 'origin':
84 | self.criterionL1 = torch.nn.L1Loss()
85 | elif opt.L1_type == 'l1_plus_perL1':
86 | self.criterionL1 = L1_plus_perceptualLoss(opt.lambda_A, opt.lambda_B, opt.perceptual_layers, self.gpu_ids, opt.percep_is_l1)
87 | else:
88 | raise Exception('Unsupported type of L1!')
89 |
90 | if opt.use_cxloss:
91 | self.CX_loss = CXLoss(sigma=0.5)
92 | if torch.cuda.is_available():
93 | self.CX_loss.cuda()
94 | self.vgg = VGG()
95 | self.vgg.load_state_dict(torch.load(os.path.abspath(opt.dataroot) + '/vgg_conv.pth'))
96 | for param in self.vgg.parameters():
97 | param.requires_grad = False
98 | if torch.cuda.is_available():
99 | self.vgg.cuda()
100 |
101 | if opt.use_lpips:
102 | self.lpips_loss = LPIPS(net_type='vgg').cuda().eval()
103 |
104 | if self.use_AMCE:
105 | self.AM_CE_loss = torch.nn.CrossEntropyLoss()
106 | if torch.cuda.is_available():
107 | self.AM_CE_loss.cuda()
108 |
109 |
110 | self.Vggloss = VGGLoss().cuda().eval()
111 |
112 |
113 | # initialize optimizers
114 | self.optimizer_G = torch.optim.Adam(filter(lambda p: p.requires_grad, self.netG.parameters()), lr=opt.lr, betas=(opt.beta1, 0.999))
115 |
116 | if opt.with_D_PB:
117 | self.optimizer_D_PB = torch.optim.Adam(self.netD_PB.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
118 | if opt.with_D_PP:
119 | self.optimizer_D_PP = torch.optim.Adam(self.netD_PP.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
120 |
121 | self.optimizers = []
122 | self.schedulers = []
123 | self.optimizers.append(self.optimizer_G)
124 | if opt.with_D_PB:
125 | self.optimizers.append(self.optimizer_D_PB)
126 | if opt.with_D_PP:
127 | self.optimizers.append(self.optimizer_D_PP)
128 | for optimizer in self.optimizers:
129 | self.schedulers.append(networks.get_scheduler(optimizer, opt))
130 |
131 | print('---------- Networks initialized -------------')
132 | networks.print_network(self.netG)
133 | if self.isTrain:
134 | if opt.with_D_PB:
135 | networks.print_network(self.netD_PB)
136 | if opt.with_D_PP:
137 | networks.print_network(self.netD_PP)
138 | print('-----------------------------------------------')
139 |
140 |
141 | def set_input(self, input):
142 | input_P1, input_BP1 = input['P1'], input['BP1']
143 | input_P2, input_BP2 = input['P2'], input['BP2']
144 |
145 | self.input_P1_set.resize_(input_P1.size()).copy_(input_P1)
146 | self.input_BP1_set.resize_(input_BP1.size()).copy_(input_BP1)
147 | self.input_P2_set.resize_(input_P2.size()).copy_(input_P2)
148 | self.input_BP2_set.resize_(input_BP2.size()).copy_(input_BP2)
149 |
150 | if self.use_BPD:
151 | input_BPD1, input_BPD2 = input['BPD1'], input['BPD2']
152 | self.input_BPD1_set.resize_(input_BPD1.size()).copy_(input_BPD1)
153 | self.input_BPD2_set.resize_(input_BPD2.size()).copy_(input_BPD2)
154 |
155 | input_SP1 = input['SP1']
156 | self.input_SP1_set.resize_(input_SP1.size()).copy_(input_SP1)
157 | if self.use_AMCE:
158 | input_SP2 = input['SP2']
159 | self.input_SP2_set.resize_(input_SP2.size()).copy_(input_SP2)
160 |
161 | self.image_paths = input['P1_path'][0] + '___' + input['P2_path'][0]
162 | self.person_paths = input['P1_path'][0]
163 |
164 |
165 | def forward(self):
166 |
167 | self.input_P1 = Variable(self.input_P1_set)
168 | self.input_BP1 = Variable(self.input_BP1_set)
169 |
170 | self.input_P2 = Variable(self.input_P2_set)
171 | self.input_BP2 = Variable(self.input_BP2_set)
172 |
173 | if self.use_BPD:
174 | self.input_BPD1 = Variable(self.input_BPD1_set)
175 | self.input_BPD2 = Variable(self.input_BPD2_set)
176 |
177 | self.input_SP1 = Variable(self.input_SP1_set)
178 | self.input_SP2 = Variable(self.input_SP2_set)
179 |
180 | if self.use_BPD:
181 | self.fake_p2, self.fake_sp2 = self.netG(torch.cat([self.input_BP2, self.input_BPD2], 1), self.input_P1, self.input_SP1)
182 | else:
183 | self.fake_p2, self.fake_sp2 = self.netG(self.input_BP2, self.input_P1, self.input_SP1)
184 |
185 |
186 | def test(self):
187 | self.input_P1 = Variable(self.input_P1_set)
188 | self.input_BP1 = Variable(self.input_BP1_set)
189 |
190 | self.input_P2 = Variable(self.input_P2_set)
191 | self.input_BP2 = Variable(self.input_BP2_set)
192 |
193 | if self.use_BPD:
194 | self.input_BPD1 = Variable(self.input_BPD1_set)
195 | self.input_BPD2 = Variable(self.input_BPD2_set)
196 |
197 | self.input_SP1 = Variable(self.input_SP1_set)
198 | self.input_SP2 = Variable(self.input_SP2_set)
199 |
200 |
201 | if self.use_BPD:
202 | self.fake_p2, self.fake_sp2 = self.netG(torch.cat([self.input_BP2, self.input_BPD2], 1), self.input_P1, self.input_SP1)
203 | else:
204 | self.fake_p2, self.fake_sp2 = self.netG(self.input_BP2, self.input_P1, self.input_SP1)
205 |
206 |
207 | # get image paths
208 | def get_image_paths(self):
209 | return self.image_paths
210 |
211 | def get_person_paths(self):
212 | return self.person_paths
213 |
214 |
215 | def backward_G(self):
216 | if self.opt.with_D_PB:
217 | if self.use_BPD:
218 | pred_fake_PB = self.netD_PB(torch.cat((self.fake_p2, self.input_BP2, self.input_BPD2), 1))
219 | else:
220 | pred_fake_PB = self.netD_PB(torch.cat((self.fake_p2, self.input_BP2), 1))
221 | self.loss_G_GAN_PB = self.criterionGAN(pred_fake_PB, True)
222 |
223 | if self.opt.with_D_PP:
224 | pred_fake_PP = self.netD_PP(torch.cat((self.fake_p2, self.input_P1), 1))
225 | self.loss_G_GAN_PP = self.criterionGAN(pred_fake_PP, True)
226 |
227 | # CX loss
228 | if self.opt.use_cxloss:
229 | style_layer = ['r32', 'r42']
230 | vgg_style = self.vgg(self.input_P2, style_layer)
231 | vgg_fake = self.vgg(self.fake_p2, style_layer)
232 | cx_style_loss = 0
233 |
234 | for i, val in enumerate(vgg_fake):
235 | cx_style_loss += self.CX_loss(vgg_style[i], vgg_fake[i])
236 | cx_style_loss *= self.opt.lambda_cx
237 |
238 | pair_cxloss = cx_style_loss
239 |
240 | if self.opt.use_lpips:
241 | lpips_loss = self.lpips_loss(self.fake_p2, self.input_P2)
242 | lpips_loss *= self.opt.lambda_lpips
243 | pair_lpips_loss = lpips_loss
244 |
245 | # Attention Map Cross Entropy loss
246 | if self.use_AMCE:
247 | up_ = torch.nn.Upsample(scale_factor=4, mode='bilinear')
248 | if isinstance(self.fake_sp2,list):
249 | AMCE_loss = 0
250 | B, C, H, W = self.input_SP2.shape
251 | for i in range(len(self.fake_sp2)):
252 | logits = up_(self.fake_sp2[i])
253 | logits = torch.reshape(logits.permute(0,2,3,1), (B*H*W, C))
254 | labels = torch.argmax(torch.reshape(self.input_SP2.permute(0,2,3,1), (B*H*W, C)), 1)
255 | AMCE_loss += self.AM_CE_loss(logits, labels)
256 |
257 | AMCE_loss *= self.opt.lambda_AMCE
258 | pair_AMCE_loss = AMCE_loss
259 | else:
260 | logits = up_(self.fake_sp2)
261 | B, C, H, W = self.input_SP2.shape
262 | logits = torch.reshape(logits.permute(0,2,3,1), (B*H*W, C))
263 | labels = torch.argmax(torch.reshape(self.input_SP2.permute(0,2,3,1), (B*H*W, C)), 1)
264 | AMCE_loss = self.AM_CE_loss(logits, labels)
265 | AMCE_loss *= self.opt.lambda_AMCE
266 | pair_AMCE_loss = AMCE_loss
267 |
268 | self.opt.lambda_style = 200
269 | self.opt.lambda_content = 0.5
270 | loss_content_gen, loss_style_gen = self.Vggloss(self.fake_p2, self.input_P2)
271 | pair_style_loss = loss_style_gen*self.opt.lambda_style
272 | pair_content_loss = loss_content_gen*self.opt.lambda_content
273 |
274 |
275 |
276 | # L1 loss
277 | if self.opt.L1_type == 'l1_plus_perL1' :
278 | losses = self.criterionL1(self.fake_p2, self.input_P2)
279 | self.loss_G_L1 = losses[0]
280 | self.loss_originL1 = losses[1].data
281 | self.loss_perceptual = losses[2].data
282 |
283 | else:
284 | self.loss_G_L1 = self.criterionL1(self.fake_p2, self.input_P2) * self.opt.lambda_A
285 |
286 | pair_L1loss = self.loss_G_L1
287 |
288 | if self.opt.with_D_PB:
289 | pair_GANloss = self.loss_G_GAN_PB * self.opt.lambda_GAN
290 | if self.opt.with_D_PP:
291 | pair_GANloss += self.loss_G_GAN_PP * self.opt.lambda_GAN
292 | pair_GANloss = pair_GANloss / 2
293 | else:
294 | if self.opt.with_D_PP:
295 | pair_GANloss = self.loss_G_GAN_PP * self.opt.lambda_GAN
296 |
297 |
298 | if self.opt.with_D_PB or self.opt.with_D_PP:
299 | pair_loss = pair_L1loss + pair_GANloss
300 | else:
301 | pair_loss = pair_L1loss
302 |
303 | if self.opt.use_cxloss:
304 | pair_loss = pair_loss + pair_cxloss
305 | if self.opt.use_AMCE:
306 | pair_loss = pair_loss + pair_AMCE_loss
307 | if self.opt.use_lpips:
308 | pair_loss = pair_loss + pair_lpips_loss
309 |
310 | pair_loss = pair_loss + pair_content_loss
311 | pair_loss = pair_loss + pair_style_loss
312 |
313 | pair_loss.backward()
314 |
315 | self.pair_L1loss = pair_L1loss.data
316 | if self.opt.with_D_PB or self.opt.with_D_PP:
317 | self.pair_GANloss = pair_GANloss.data
318 |
319 | if self.opt.use_cxloss:
320 | self.pair_cxloss = pair_cxloss.data
321 |
322 | if self.opt.use_lpips:
323 | self.pair_lpips_loss = pair_lpips_loss.data
324 | if self.opt.use_AMCE:
325 | self.pair_AMCE_loss = pair_AMCE_loss.data
326 |
327 | self.pair_content_loss = pair_content_loss.data
328 | self.pair_style_loss = pair_style_loss.data
329 |
330 |
331 | def backward_D_basic(self, netD, real, fake):
332 | # Real
333 | pred_real = netD(real)
334 | loss_D_real = self.criterionGAN(pred_real, True) * self.opt.lambda_GAN
335 | # Fake
336 | pred_fake = netD(fake.detach())
337 | loss_D_fake = self.criterionGAN(pred_fake, False) * self.opt.lambda_GAN
338 | # Combined loss
339 | loss_D = (loss_D_real + loss_D_fake) * 0.5
340 | # backward
341 | loss_D.backward()
342 | return loss_D
343 |
344 | # D: takes (P, B) as input
345 | def backward_D_PB(self):
346 | if self.use_BPD:
347 | real_PB = torch.cat((self.input_P2, self.input_BP2, self.input_BPD2), 1)
348 | fake_PB = self.fake_PB_pool.query( torch.cat((self.fake_p2, self.input_BP2, self.input_BPD2), 1).data )
349 | else:
350 | real_PB = torch.cat((self.input_P2, self.input_BP2), 1)
351 | fake_PB = self.fake_PB_pool.query( torch.cat((self.fake_p2, self.input_BP2), 1).data )
352 | loss_D_PB = self.backward_D_basic(self.netD_PB, real_PB, fake_PB)
353 |
354 | self.loss_D_PB = loss_D_PB.data
355 |
356 | # D: takes (P, P') as input
357 | def backward_D_PP(self):
358 | real_PP = torch.cat((self.input_P2, self.input_P1), 1)
359 | fake_PP = self.fake_PP_pool.query( torch.cat((self.fake_p2, self.input_P1), 1).data )
360 | loss_D_PP = self.backward_D_basic(self.netD_PP, real_PP, fake_PP)
361 |
362 | self.loss_D_PP = loss_D_PP.data
363 |
364 |
365 | def optimize_parameters(self):
366 | # forward
367 | self.forward()
368 |
369 | self.optimizer_G.zero_grad()
370 | self.backward_G()
371 | self.optimizer_G.step()
372 |
373 | # D_P
374 | if self.opt.with_D_PP:
375 | for i in range(self.opt.DG_ratio):
376 | self.optimizer_D_PP.zero_grad()
377 | self.backward_D_PP()
378 | self.optimizer_D_PP.step()
379 |
380 | # D_BP
381 | if self.opt.with_D_PB:
382 | for i in range(self.opt.DG_ratio):
383 | self.optimizer_D_PB.zero_grad()
384 | self.backward_D_PB()
385 | self.optimizer_D_PB.step()
386 |
387 | def get_current_errors(self):
388 | ret_errors = OrderedDict([ ('pair_L1loss', self.pair_L1loss)])
389 | if self.opt.with_D_PP:
390 | ret_errors['D_PP'] = self.loss_D_PP
391 | if self.opt.with_D_PB:
392 | ret_errors['D_PB'] = self.loss_D_PB
393 | if self.opt.with_D_PB or self.opt.with_D_PP or self.opt.with_D_PS:
394 | ret_errors['pair_GANloss'] = self.pair_GANloss
395 |
396 | if self.opt.L1_type == 'l1_plus_perL1':
397 | ret_errors['origin_L1'] = self.loss_originL1
398 | ret_errors['perceptual'] = self.loss_perceptual
399 |
400 | if self.opt.use_cxloss:
401 | ret_errors['CXLoss'] = self.pair_cxloss
402 | if self.opt.use_lpips:
403 | ret_errors['lpips'] = self.pair_lpips_loss
404 | if self.opt.use_AMCE:
405 | ret_errors['AMCE'] = self.pair_AMCE_loss
406 |
407 | ret_errors['content'] = self.pair_content_loss
408 | ret_errors['style'] = self.pair_style_loss
409 |
410 | return ret_errors
411 |
412 | def get_current_visuals(self):
413 | height, width = self.input_P1.size(2), self.input_P1.size(3)
414 | input_P1 = util.tensor2im(self.input_P1.data)
415 | input_P2 = util.tensor2im(self.input_P2.data)
416 |
417 | input_BP1 = util.draw_pose_from_map(self.input_BP1.data)[0]
418 | input_BP2 = util.draw_pose_from_map(self.input_BP2.data)[0]
419 |
420 |
421 | if self.use_BPD:
422 | input_BPD1 = util.draw_dis_from_map(self.input_BP1.data)[1]
423 | input_BPD1 = (np.repeat(np.expand_dims(input_BPD1, -1), 3, -1)*255).astype('uint8')
424 | input_BPD2 = util.draw_dis_from_map(self.input_BP2.data)[1]
425 | input_BPD2 = (np.repeat(np.expand_dims(input_BPD2, -1), 3, -1)*255).astype('uint8')
426 |
427 |
428 | fake_p2 = util.tensor2im(self.fake_p2.data)
429 |
430 | if self.use_BPD:
431 | vis = np.zeros((height, width*7, 3)).astype(np.uint8) #h, w, c
432 | vis[:, :width, :] = input_P1
433 | vis[:, width:width*2, :] = input_BP1
434 | vis[:, width*2:width*3, :] = input_BPD1
435 | vis[:, width*3:width*4, :] = input_P2
436 | vis[:, width*4:width*5, :] = input_BP2
437 | vis[:, width*5:width*6, :] = input_BPD2
438 | vis[:, width*6:width*7, :] = fake_p2
439 | else:
440 | vis = np.zeros((height, width*5, 3)).astype(np.uint8) #h, w, c
441 | vis[:, :width, :] = input_P1
442 | vis[:, width:width*2, :] = input_BP1
443 | vis[:, width*2:width*3, :] = input_P2
444 | vis[:, width*3:width*4, :] = input_BP2
445 | vis[:, width*4:, :] = fake_p2
446 |
447 | ret_visuals = OrderedDict([('vis', vis)])
448 |
449 | return ret_visuals
450 |
451 |
452 | def save(self, label):
453 | self.save_network(self.netG, 'netG', label, self.gpu_ids)
454 | if self.opt.with_D_PB:
455 | self.save_network(self.netD_PB, 'netD_PB', label, self.gpu_ids)
456 | if self.opt.with_D_PP:
457 | self.save_network(self.netD_PP, 'netD_PP', label, self.gpu_ids)
458 |
459 |
460 |
461 |
462 |
--------------------------------------------------------------------------------
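For reference, the Attention Map Cross Entropy (AMCE) term in `backward_G` upsamples the predicted attention maps, flattens them to per-pixel class logits, and compares them against the argmax of the target semantic map `input_SP2`. A standalone sketch with hypothetical shapes (the real channel count comes from `opt.SP_input_nc`):

```python
import torch
import torch.nn as nn

# hypothetical shapes: 8 semantic classes, 256x176 target, attention logits at 1/4 resolution
fake_sp2 = torch.randn(2, 8, 64, 44)      # predicted attention map (logits)
input_SP2 = torch.rand(2, 8, 256, 176)    # target semantic map (one-hot-like channels)

up_ = nn.Upsample(scale_factor=4, mode='bilinear')
B, C, H, W = input_SP2.shape
logits = up_(fake_sp2)                                       # (B, C, 256, 176)
logits = logits.permute(0, 2, 3, 1).reshape(B * H * W, C)    # one row of class logits per pixel
labels = input_SP2.permute(0, 2, 3, 1).reshape(B * H * W, C).argmax(1)
loss = nn.CrossEntropyLoss()(logits, labels)
print(loss.item())
```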
/models/base_model.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import torch.nn as nn
4 | import torchvision.models.vgg as models
5 |
6 | class BaseModel(nn.Module):
7 |
8 | def __init__(self):
9 | super(BaseModel, self).__init__()
10 |
11 | def name(self):
12 | return 'BaseModel'
13 |
14 | def initialize(self, opt):
15 | self.opt = opt
16 | self.gpu_ids = opt.gpu_ids
17 | self.isTrain = opt.isTrain
18 | self.Tensor = torch.cuda.FloatTensor if self.gpu_ids else torch.Tensor
19 | self.save_dir = os.path.join(opt.checkpoints_dir, opt.name)
20 | self.vgg_path = os.path.join(os.path.abspath(opt.dataroot), 'vgg19-dcbb9e9d.pth')
21 |
22 | def set_input(self, input):
23 | self.input = input
24 |
25 | def forward(self):
26 | pass
27 |
28 | # used in test time, no backprop
29 | def test(self):
30 | pass
31 |
32 | def get_image_paths(self):
33 | pass
34 |
35 | def optimize_parameters(self):
36 | pass
37 |
38 | def get_current_visuals(self):
39 | return self.input
40 |
41 | def get_current_errors(self):
42 | return {}
43 |
44 | def save(self, label):
45 | pass
46 |
47 | # helper saving function that can be used by subclasses
48 | def save_network(self, network, network_label, epoch_label, gpu_ids):
49 | save_filename = '%s_net_%s.pth' % (epoch_label, network_label)
50 | save_path = os.path.join(self.save_dir, save_filename)
51 | torch.save(network.cpu().state_dict(), save_path)
52 | if len(gpu_ids) and torch.cuda.is_available():
53 | network.cuda(gpu_ids[0])
54 |
55 | # helper loading function that can be used by subclasses
56 | def load_network(self, network, network_label, epoch_label):
57 | save_filename = '%s_net_%s.pth' % (epoch_label, network_label)
58 | save_path = os.path.join(self.save_dir, save_filename)
59 | # network.load_state_dict(torch.load(save_path))
60 |
61 | model_dict = torch.load(save_path)
62 | model_dict_clone = model_dict.copy() # We can't mutate while iterating
63 | for key, value in model_dict_clone.items():
64 | if key.endswith(('running_mean', 'running_var')):
65 | del model_dict[key]
66 | ### Next cell
67 | network.load_state_dict(model_dict, False)
68 |
69 | def load_VGG(self, network):
70 | # pretrained_dict = torch.load(self.vgg_path)
71 |
72 | # pretrained_model = models.vgg19(pretrained=True).features
73 | vgg19 = models.vgg19(pretrained=False)
74 | vgg19.load_state_dict(torch.load(self.vgg_path))
75 | pretrained_model = vgg19.features
76 |
77 | pretrained_dict = pretrained_model.state_dict()
78 |
79 | model_dict = network.state_dict()
80 |
81 | # filter out unnecessary keys
82 | pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
83 | # overwrite entries in the existing state dict
84 | model_dict.update(pretrained_dict)
85 | # load the new state dict
86 | network.load_state_dict(model_dict)
87 |
88 | # update learning rate (called once every epoch)
89 | def update_learning_rate(self):
90 | for scheduler in self.schedulers:
91 | scheduler.step()
92 | lr = self.optimizers[0].param_groups[0]['lr']
93 | print('learning rate = %.7f' % lr)
94 |
--------------------------------------------------------------------------------
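`load_network` above deliberately drops the `running_mean`/`running_var` buffers from the checkpoint and loads the remaining keys non-strictly. A small sketch of the same idea as a standalone helper (the function name is hypothetical, not from the repo):

```python
import torch
import torch.nn as nn

def load_without_running_stats(network: nn.Module, path: str) -> None:
    """Load a checkpoint the way base_model.load_network does: drop the
    BatchNorm/InstanceNorm running statistics and load the rest non-strictly."""
    state = torch.load(path, map_location='cpu')
    state = {k: v for k, v in state.items()
             if not k.endswith(('running_mean', 'running_var'))}
    network.load_state_dict(state, strict=False)
```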
/models/models.py:
--------------------------------------------------------------------------------
1 |
2 | def create_model(opt):
3 | model = None
4 | print(opt.model)
5 | if opt.model == 'adgan':
6 | assert opt.dataset_mode == 'keypoint'
7 | from .adgan import TransferModel
8 | model = TransferModel()
9 | elif opt.model == 'adgan_mix':
10 | assert opt.dataset_mode == 'keypoint_mix'
11 | from .adgan_mix import TransferModel
12 | model = TransferModel()
13 | else:
14 | raise ValueError("Model [%s] not recognized." % opt.model)
15 | model.initialize(opt)
16 | print("model [%s] was created" % (model.name()))
17 | return model
18 |
--------------------------------------------------------------------------------
/models/networks.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.nn import init
4 | import functools
5 | from torch.autograd import Variable
6 | from torch.optim import lr_scheduler
7 |
8 | import math
9 |
10 | # added
11 | def weights_init_ada(init_type='gaussian'):
12 | def init_fun(m):
13 | classname = m.__class__.__name__
14 | if (classname.find('Conv') == 0 or classname.find('Linear') == 0) and hasattr(m, 'weight'):
15 | # print m.__class__.__name__
16 | if init_type == 'gaussian':
17 | init.normal_(m.weight.data, 0.0, 0.02)
18 | elif init_type == 'xavier':
19 | init.xavier_normal_(m.weight.data, gain=math.sqrt(2))
20 | elif init_type == 'kaiming':
21 | init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
22 | elif init_type == 'orthogonal':
23 | init.orthogonal_(m.weight.data, gain=math.sqrt(2))
24 | elif init_type == 'default':
25 | pass
26 | else:
27 | assert 0, "Unsupported initialization: {}".format(init_type)
28 | if hasattr(m, 'bias') and m.bias is not None:
29 | init.constant_(m.bias.data, 0.0)
30 | return init_fun
31 |
32 |
33 | def weights_init_normal(m):
34 | classname = m.__class__.__name__
35 | if classname.find('Conv') != -1 and hasattr(m, 'weight'):
36 | init.normal(m.weight.data, 0.0, 0.02)
37 | elif classname.find('Linear') != -1 and hasattr(m, 'weight'):
38 | init.normal(m.weight.data, 0.0, 0.02)
39 | elif classname.find('BatchNorm2d') != -1:
40 | init.normal(m.weight.data, 1.0, 0.02)
41 | init.constant(m.bias.data, 0.0)
42 |
43 |
44 | def weights_init_xavier(m):
45 | classname = m.__class__.__name__
46 | # print(classname)
47 | if classname.find('Conv') != -1:
48 | init.xavier_normal(m.weight.data, gain=0.02)
49 | elif classname.find('Linear') != -1:
50 | init.xavier_normal(m.weight.data, gain=0.02)
51 | elif classname.find('BatchNorm2d') != -1:
52 | init.normal(m.weight.data, 1.0, 0.02)
53 | init.constant(m.bias.data, 0.0)
54 |
55 |
56 | def weights_init_kaiming(m):
57 | classname = m.__class__.__name__
58 | # print(classname)
59 | if classname.find('Conv') != -1:
60 | init.kaiming_normal(m.weight.data, a=0, mode='fan_in')
61 | elif classname.find('Linear') != -1:
62 | init.kaiming_normal(m.weight.data, a=0, mode='fan_in')
63 | elif classname.find('BatchNorm2d') != -1:
64 | init.normal(m.weight.data, 1.0, 0.02)
65 | init.constant(m.bias.data, 0.0)
66 |
67 |
68 | def weights_init_orthogonal(m):
69 | classname = m.__class__.__name__
70 | print(classname)
71 | if classname.find('Conv') != -1:
72 | init.orthogonal(m.weight.data, gain=1)
73 | elif classname.find('Linear') != -1:
74 | init.orthogonal(m.weight.data, gain=1)
75 | elif classname.find('BatchNorm2d') != -1:
76 | init.normal(m.weight.data, 1.0, 0.02)
77 | init.constant(m.bias.data, 0.0)
78 |
79 |
80 | def init_weights(net, init_type='normal'):
81 | print('initialization method [%s]' % init_type)
82 | if init_type == 'normal':
83 | net.apply(weights_init_normal)
84 | elif init_type == 'xavier':
85 | net.apply(weights_init_xavier)
86 | elif init_type == 'kaiming':
87 | net.apply(weights_init_kaiming)
88 | elif init_type == 'orthogonal':
89 | net.apply(weights_init_orthogonal)
90 | else:
91 | raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
92 |
93 |
94 | def get_norm_layer(norm_type='instance'):
95 | if norm_type == 'batch':
96 | norm_layer = functools.partial(nn.BatchNorm2d, affine=True)
97 | elif norm_type == 'batch_sync':
98 | norm_layer = BatchNorm2d
99 | elif norm_type == 'instance':
100 | norm_layer = functools.partial(nn.InstanceNorm2d, affine=False)
101 | elif norm_type == 'none':
102 | norm_layer = None
103 | else:
104 | raise NotImplementedError('normalization layer [%s] is not found' % norm_type)
105 | return norm_layer
106 |
107 |
108 | def get_scheduler(optimizer, opt):
109 | if opt.lr_policy == 'lambda':
110 | def lambda_rule(epoch):
111 | lr_l = 1.0 - max(0, epoch + 1 + opt.epoch_count - opt.niter) / float(opt.niter_decay + 1)
112 | return lr_l
113 |
114 | scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule)
115 | elif opt.lr_policy == 'step':
116 | scheduler = lr_scheduler.StepLR(optimizer, step_size=opt.lr_decay_iters, gamma=0.1)
117 | elif opt.lr_policy == 'plateau':
118 | scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.2, threshold=0.01, patience=5)
119 | else:
120 | raise NotImplementedError('learning rate policy [%s] is not implemented' % opt.lr_policy)
121 | return scheduler
122 |
123 |
124 | def define_G(input_nc, output_nc, ngf, which_model_netG, norm='batch', use_dropout=False, init_type='normal',
125 | gpu_ids=[], n_downsampling=2):
126 | netG = None
127 | use_gpu = len(gpu_ids) > 0
128 | norm_layer = get_norm_layer(norm_type=norm)
129 |
130 | if use_gpu:
131 | assert (torch.cuda.is_available())
132 |
133 | if which_model_netG == 'CASD':
134 | style_dim = 2048
135 | n_res = 8
136 | mlp_dim = 256
137 | from models.CASD import ADGen
138 | netG = ADGen(input_nc, ngf, style_dim, n_downsampling, n_res, mlp_dim)
139 | else:
140 | raise NotImplementedError('Generator model name [%s] is not recognized' % which_model_netG)
141 | if len(gpu_ids) > 1:
142 | netG = torch.nn.DataParallel(netG, device_ids=gpu_ids)
143 | netG.cuda()
144 | init_weights(netG, init_type=init_type)
145 | return netG
146 |
147 |
148 | class AttrDict(dict):
149 |
150 | def __init__(self,*args,**kwargs):
151 | super().__init__(*args,**kwargs)
152 |
153 | def operation_list(self,value):
154 | new_value = []
155 | for v in value:
156 | if isinstance(v, dict):
157 | new_value.append(AttrDict(v))
158 | elif isinstance(v,list):
159 | new_value.append(self.operation_list(v))
160 | else:
161 | new_value.append(v)
162 | return new_value
163 |
164 | def __getattr__(self, item):
165 | value=self[item]
166 | if isinstance(value,dict):
167 | value=AttrDict(value)
168 | elif isinstance(value,list):
169 | value=self.operation_list(value)
170 | return value
171 |
172 |
173 | def define_D(input_nc, ndf, which_model_netD,
174 | n_layers_D=3, norm='batch', use_sigmoid=False, init_type='normal', gpu_ids=[], use_dropout=False,
175 | n_downsampling=2):
176 | netD = None
177 | use_gpu = len(gpu_ids) > 0
178 | norm_layer = get_norm_layer(norm_type=norm)
179 |
180 | if use_gpu:
181 | assert (torch.cuda.is_available())
182 |
183 | if which_model_netD == 'resnet':
184 | netD = ResnetDiscriminator(input_nc, ndf, norm_layer=norm_layer, use_dropout=use_dropout, n_blocks=n_layers_D,
185 | gpu_ids=[], padding_type='reflect', use_sigmoid=use_sigmoid,
186 | n_downsampling=n_downsampling)
187 | else:
188 | raise NotImplementedError('Discriminator model name [%s] is not recognized' %
189 | which_model_netD)
190 | if len(gpu_ids) > 1:
191 | netD = torch.nn.DataParallel(netD, device_ids=gpu_ids)
192 | netD.cuda()
193 | return netD
194 |
195 |
196 | def print_network(net):
197 | num_params = 0
198 | for param in net.parameters():
199 | num_params += param.numel()
200 | print(net)
201 | print('Total number of parameters: %d' % num_params)
202 |
203 |
204 | ##############################################################################
205 | # Classes
206 | ##############################################################################
207 |
208 | # Defines the GAN loss which uses either LSGAN or the regular GAN.
209 | # When LSGAN is used, it is basically the same as MSELoss,
210 | # but it abstracts away the need to create the target label tensor
211 | # that has the same size as the input
212 | class GANLoss(nn.Module):
213 | def __init__(self, use_lsgan=True, target_real_label=1.0, target_fake_label=0.0,
214 | tensor=torch.FloatTensor):
215 | super(GANLoss, self).__init__()
216 | self.real_label = target_real_label
217 | self.fake_label = target_fake_label
218 | self.real_label_var = None
219 | self.fake_label_var = None
220 | self.Tensor = tensor
221 | if use_lsgan:
222 | self.loss = nn.MSELoss()
223 | else:
224 | self.loss = nn.BCELoss()
225 |
226 | def get_target_tensor(self, input, target_is_real):
227 | target_tensor = None
228 | if target_is_real:
229 | create_label = ((self.real_label_var is None) or
230 | (self.real_label_var.numel() != input.numel()))
231 | if create_label:
232 | real_tensor = self.Tensor(input.size()).fill_(self.real_label)
233 | self.real_label_var = Variable(real_tensor, requires_grad=False)
234 | target_tensor = self.real_label_var
235 | else:
236 | create_label = ((self.fake_label_var is None) or
237 | (self.fake_label_var.numel() != input.numel()))
238 | if create_label:
239 | fake_tensor = self.Tensor(input.size()).fill_(self.fake_label)
240 | self.fake_label_var = Variable(fake_tensor, requires_grad=False)
241 | target_tensor = self.fake_label_var
242 | return target_tensor
243 |
244 | def __call__(self, input, target_is_real):
245 | target_tensor = self.get_target_tensor(input, target_is_real)
246 | return self.loss(input, target_tensor)
247 |
248 |
249 |
250 | # Define a resnet block
251 | class ResnetBlock(nn.Module):
252 | def __init__(self, dim, padding_type, norm_layer, use_dropout, use_bias):
253 | super(ResnetBlock, self).__init__()
254 | self.conv_block = self.build_conv_block(dim, padding_type, norm_layer, use_dropout, use_bias)
255 |
256 | def build_conv_block(self, dim, padding_type, norm_layer, use_dropout, use_bias):
257 | conv_block = []
258 | p = 0
259 | if padding_type == 'reflect':
260 | conv_block += [nn.ReflectionPad2d(1)]
261 | elif padding_type == 'replicate':
262 | conv_block += [nn.ReplicationPad2d(1)]
263 | elif padding_type == 'zero':
264 | p = 1
265 | else:
266 | raise NotImplementedError('padding [%s] is not implemented' % padding_type)
267 |
268 | conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias),
269 | norm_layer(dim),
270 | nn.ReLU(True)]
271 | if use_dropout:
272 | conv_block += [nn.Dropout(0.5)]
273 |
274 | p = 0
275 | if padding_type == 'reflect':
276 | conv_block += [nn.ReflectionPad2d(1)]
277 | elif padding_type == 'replicate':
278 | conv_block += [nn.ReplicationPad2d(1)]
279 | elif padding_type == 'zero':
280 | p = 1
281 | else:
282 | raise NotImplementedError('padding [%s] is not implemented' % padding_type)
283 | conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias),
284 | norm_layer(dim)]
285 |
286 | return nn.Sequential(*conv_block)
287 |
288 | def forward(self, x):
289 | out = x + self.conv_block(x)
290 | return out
291 |
292 | class ResnetDiscriminator(nn.Module):
293 | def __init__(self, input_nc, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False, n_blocks=6, gpu_ids=[],
294 | padding_type='reflect', use_sigmoid=False, n_downsampling=2):
295 | assert (n_blocks >= 0)
296 | super(ResnetDiscriminator, self).__init__()
297 | self.input_nc = input_nc
298 | self.ngf = ngf
299 | self.gpu_ids = gpu_ids
300 | if type(norm_layer) == functools.partial:
301 | use_bias = norm_layer.func == nn.InstanceNorm2d
302 | else:
303 | use_bias = norm_layer == nn.InstanceNorm2d
304 |
305 | model = [nn.ReflectionPad2d(3),
306 | nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0,
307 | bias=use_bias),
308 | norm_layer(ngf),
309 | nn.ReLU(True)]
310 |
311 | # n_downsampling = 2
312 | if n_downsampling <= 2:
313 | for i in range(n_downsampling):
314 | mult = 2 ** i
315 | model += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3,
316 | stride=2, padding=1, bias=use_bias),
317 | norm_layer(ngf * mult * 2),
318 | nn.ReLU(True)]
319 | elif n_downsampling == 3:
320 | mult = 2 ** 0
321 | model += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3,
322 | stride=2, padding=1, bias=use_bias),
323 | norm_layer(ngf * mult * 2),
324 | nn.ReLU(True)]
325 | mult = 2 ** 1
326 | model += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3,
327 | stride=2, padding=1, bias=use_bias),
328 | norm_layer(ngf * mult * 2),
329 | nn.ReLU(True)]
330 | mult = 2 ** 2
331 | model += [nn.Conv2d(ngf * mult, ngf * mult, kernel_size=3,
332 | stride=2, padding=1, bias=use_bias),
333 | norm_layer(ngf * mult),
334 | nn.ReLU(True)]
335 |
336 | if n_downsampling <= 2:
337 | mult = 2 ** n_downsampling
338 | else:
339 | mult = 4
340 | for i in range(n_blocks):
341 | model += [ResnetBlock(ngf * mult, padding_type=padding_type, norm_layer=norm_layer, use_dropout=use_dropout,
342 | use_bias=use_bias)]
343 |
344 | if use_sigmoid:
345 | model += [nn.Sigmoid()]
346 |
347 | self.model = nn.Sequential(*model)
348 |
349 | def forward(self, input):
350 | if self.gpu_ids and isinstance(input.data, torch.cuda.FloatTensor):
351 | return nn.parallel.data_parallel(self.model, input, self.gpu_ids)
352 | else:
353 | return self.model(input)
354 |
355 |
356 |
--------------------------------------------------------------------------------
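Note that `ResnetDiscriminator` as defined above returns a feature map rather than a single logit; `GANLoss` handles this by building a real/fake target tensor with the same shape as the prediction. A standalone usage sketch (assumed usage, not from the repo):

```python
import torch
from models.networks import GANLoss, ResnetDiscriminator

# e.g. the 6-channel input used by netD_PP: two person images concatenated
netD = ResnetDiscriminator(input_nc=6, ngf=64, n_blocks=3, use_sigmoid=False)
criterionGAN = GANLoss(use_lsgan=True, tensor=torch.FloatTensor)

pair = torch.randn(2, 6, 256, 176)
pred = netD(pair)                       # a feature map, not a scalar score
loss_real = criterionGAN(pred, True)    # LSGAN: MSE against an all-ones target of pred's shape
loss_fake = criterionGAN(pred, False)   # LSGAN: MSE against an all-zeros target
```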
/models/test_model.py:
--------------------------------------------------------------------------------
1 | from torch.autograd import Variable
2 | from collections import OrderedDict
3 | import util.util as util
4 | from .base_model import BaseModel
5 | from . import networks
6 |
7 |
8 | class TestModel(BaseModel):
9 | def name(self):
10 | return 'TestModel'
11 |
12 | def initialize(self, opt):
13 | assert(not opt.isTrain)
14 | BaseModel.initialize(self, opt)
15 | self.input_A = self.Tensor(opt.batchSize, opt.input_nc, opt.fineSize, opt.fineSize)
16 |
17 | self.netG = networks.define_G(opt.input_nc, opt.output_nc,
18 | opt.ngf, opt.which_model_netG,
19 | opt.norm, not opt.no_dropout,
20 | opt.init_type,
21 | self.gpu_ids)
22 | which_epoch = opt.which_epoch
23 | self.load_network(self.netG, 'G', which_epoch)
24 |
25 | print('---------- Networks initialized -------------')
26 | networks.print_network(self.netG)
27 | print('-----------------------------------------------')
28 |
29 | def set_input(self, input):
30 | # we need to use single_dataset mode
31 | input_A = input['A']
32 | self.input_A.resize_(input_A.size()).copy_(input_A)
33 | self.image_paths = input['A_paths']
34 |
35 | def test(self):
36 | self.real_A = Variable(self.input_A)
37 | self.fake_B = self.netG(self.real_A)
38 |
39 | # get image paths
40 | def get_image_paths(self):
41 | return self.image_paths
42 |
43 | def get_current_visuals(self):
44 | real_A = util.tensor2im(self.real_A.data)
45 | fake_B = util.tensor2im(self.fake_B.data)
46 | return OrderedDict([('real_A', real_A), ('fake_B', fake_B)])
47 |
--------------------------------------------------------------------------------
/models/vgg.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 |
6 | # gram matrix and loss
7 | class GramMatrix(nn.Module):
8 | def forward(self, input):
9 | b, c, h, w = input.size()
10 | F = input.view(b, c, h * w)
11 | G = torch.bmm(F, F.transpose(1, 2))
12 | G.div_(h * w)
13 | return G
14 |
15 |
16 | class GramMSELoss(nn.Module):
17 | def forward(self, input, target):
18 | out = nn.MSELoss()(GramMatrix()(input), target)
19 | return (out)
20 |
21 |
22 | # vgg definition that conveniently lets you grab the outputs from any layer
23 | class VGG(nn.Module):
24 | def __init__(self, pool='max'):
25 | super(VGG, self).__init__()
26 | # vgg modules
27 | self.conv1_1 = nn.Conv2d(3, 64, kernel_size=3, padding=1)
28 | self.conv1_2 = nn.Conv2d(64, 64, kernel_size=3, padding=1)
29 | self.conv2_1 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
30 | self.conv2_2 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
31 | self.conv3_1 = nn.Conv2d(128, 256, kernel_size=3, padding=1)
32 | self.conv3_2 = nn.Conv2d(256, 256, kernel_size=3, padding=1)
33 | self.conv3_3 = nn.Conv2d(256, 256, kernel_size=3, padding=1)
34 | self.conv3_4 = nn.Conv2d(256, 256, kernel_size=3, padding=1)
35 | self.conv4_1 = nn.Conv2d(256, 512, kernel_size=3, padding=1)
36 | self.conv4_2 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
37 | self.conv4_3 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
38 | self.conv4_4 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
39 | self.conv5_1 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
40 | self.conv5_2 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
41 | self.conv5_3 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
42 | self.conv5_4 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
43 | if pool == 'max':
44 | self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
45 | self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
46 | self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2)
47 | self.pool4 = nn.MaxPool2d(kernel_size=2, stride=2)
48 | self.pool5 = nn.MaxPool2d(kernel_size=2, stride=2)
49 | elif pool == 'avg':
50 | self.pool1 = nn.AvgPool2d(kernel_size=2, stride=2)
51 | self.pool2 = nn.AvgPool2d(kernel_size=2, stride=2)
52 | self.pool3 = nn.AvgPool2d(kernel_size=2, stride=2)
53 | self.pool4 = nn.AvgPool2d(kernel_size=2, stride=2)
54 | self.pool5 = nn.AvgPool2d(kernel_size=2, stride=2)
55 |
56 | def forward(self, x, out_keys):
57 | out = {}
58 | out['r11'] = F.relu(self.conv1_1(x))
59 | out['r12'] = F.relu(self.conv1_2(out['r11']))
60 | out['p1'] = self.pool1(out['r12'])
61 | out['r21'] = F.relu(self.conv2_1(out['p1']))
62 | out['r22'] = F.relu(self.conv2_2(out['r21']))
63 | out['p2'] = self.pool2(out['r22'])
64 | out['r31'] = F.relu(self.conv3_1(out['p2']))
65 | out['r32'] = F.relu(self.conv3_2(out['r31']))
66 | out['r33'] = F.relu(self.conv3_3(out['r32']))
67 | out['r34'] = F.relu(self.conv3_4(out['r33']))
68 | out['p3'] = self.pool3(out['r34'])
69 | out['r41'] = F.relu(self.conv4_1(out['p3']))
70 | out['r42'] = F.relu(self.conv4_2(out['r41']))
71 | out['r43'] = F.relu(self.conv4_3(out['r42']))
72 | out['r44'] = F.relu(self.conv4_4(out['r43']))
73 | out['p4'] = self.pool4(out['r44'])
74 | out['r51'] = F.relu(self.conv5_1(out['p4']))
75 | out['r52'] = F.relu(self.conv5_2(out['r51']))
76 | out['r53'] = F.relu(self.conv5_3(out['r52']))
77 | out['r54'] = F.relu(self.conv5_4(out['r53']))
78 | out['p5'] = self.pool5(out['r54'])
79 | return [out[key] for key in out_keys]
80 |
81 |
--------------------------------------------------------------------------------
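This VGG variant returns the activations for whatever layer keys are requested, which is how `adgan.py` extracts `['r32', 'r42']` for the contextual (CX) style loss (it actually imports the identical class from `models/vgg_SC.py`). A minimal shape sketch; in real use the weights would come from the `vgg_conv.pth` file referenced in `adgan.py`:

```python
import torch
from models.vgg import VGG

vgg = VGG(pool='max')
# vgg.load_state_dict(torch.load('dataset/fashion/vgg_conv.pth'))  # as loaded in adgan.py
feats = vgg(torch.randn(1, 3, 256, 176), ['r32', 'r42'])
print([f.shape for f in feats])
# [torch.Size([1, 256, 64, 44]), torch.Size([1, 512, 32, 22])]
```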
/models/vgg_SC.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.autograd import Variable
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | import torchvision.models.vgg as models
6 |
7 |
8 |
9 | # gram matrix and loss
10 | class GramMatrix(nn.Module):
11 | def forward(self, input):
12 | b, c, h, w = input.size()
13 | F = input.view(b, c, h * w)
14 | G = torch.bmm(F, F.transpose(1, 2))
15 | G.div_(h * w)
16 | return G
17 |
18 |
19 | class GramMSELoss(nn.Module):
20 | def forward(self, input, target):
21 | out = nn.MSELoss()(GramMatrix()(input), target)
22 | return (out)
23 |
24 |
25 | # vgg definition that conveniently lets you grab the outputs from any layer
26 | class VGG(nn.Module):
27 | def __init__(self, pool='max'):
28 | super(VGG, self).__init__()
29 | # vgg modules
30 | self.conv1_1 = nn.Conv2d(3, 64, kernel_size=3, padding=1)
31 | self.conv1_2 = nn.Conv2d(64, 64, kernel_size=3, padding=1)
32 | self.conv2_1 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
33 | self.conv2_2 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
34 | self.conv3_1 = nn.Conv2d(128, 256, kernel_size=3, padding=1)
35 | self.conv3_2 = nn.Conv2d(256, 256, kernel_size=3, padding=1)
36 | self.conv3_3 = nn.Conv2d(256, 256, kernel_size=3, padding=1)
37 | self.conv3_4 = nn.Conv2d(256, 256, kernel_size=3, padding=1)
38 | self.conv4_1 = nn.Conv2d(256, 512, kernel_size=3, padding=1)
39 | self.conv4_2 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
40 | self.conv4_3 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
41 | self.conv4_4 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
42 | self.conv5_1 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
43 | self.conv5_2 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
44 | self.conv5_3 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
45 | self.conv5_4 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
46 | if pool == 'max':
47 | self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
48 | self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
49 | self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2)
50 | self.pool4 = nn.MaxPool2d(kernel_size=2, stride=2)
51 | self.pool5 = nn.MaxPool2d(kernel_size=2, stride=2)
52 | elif pool == 'avg':
53 | self.pool1 = nn.AvgPool2d(kernel_size=2, stride=2)
54 | self.pool2 = nn.AvgPool2d(kernel_size=2, stride=2)
55 | self.pool3 = nn.AvgPool2d(kernel_size=2, stride=2)
56 | self.pool4 = nn.AvgPool2d(kernel_size=2, stride=2)
57 | self.pool5 = nn.AvgPool2d(kernel_size=2, stride=2)
58 |
59 | def forward(self, x, out_keys):
60 | out = {}
61 | out['r11'] = F.relu(self.conv1_1(x))
62 | out['r12'] = F.relu(self.conv1_2(out['r11']))
63 | out['p1'] = self.pool1(out['r12'])
64 | out['r21'] = F.relu(self.conv2_1(out['p1']))
65 | out['r22'] = F.relu(self.conv2_2(out['r21']))
66 | out['p2'] = self.pool2(out['r22'])
67 | out['r31'] = F.relu(self.conv3_1(out['p2']))
68 | out['r32'] = F.relu(self.conv3_2(out['r31']))
69 | out['r33'] = F.relu(self.conv3_3(out['r32']))
70 | out['r34'] = F.relu(self.conv3_4(out['r33']))
71 | out['p3'] = self.pool3(out['r34'])
72 | out['r41'] = F.relu(self.conv4_1(out['p3']))
73 | out['r42'] = F.relu(self.conv4_2(out['r41']))
74 | out['r43'] = F.relu(self.conv4_3(out['r42']))
75 | out['r44'] = F.relu(self.conv4_4(out['r43']))
76 | out['p4'] = self.pool4(out['r44'])
77 | out['r51'] = F.relu(self.conv5_1(out['p4']))
78 | out['r52'] = F.relu(self.conv5_2(out['r51']))
79 | out['r53'] = F.relu(self.conv5_3(out['r52']))
80 | out['r54'] = F.relu(self.conv5_4(out['r53']))
81 | out['p5'] = self.pool5(out['r54'])
82 | return [out[key] for key in out_keys]
83 |
84 |
85 | class VGGLoss(nn.Module):
86 | r"""
87 | Perceptual loss, VGG-based
88 | https://arxiv.org/abs/1603.08155
89 | https://github.com/dxyang/StyleTransfer/blob/master/utils.py
90 | """
91 |
92 | def __init__(self, weights=[1.0, 1.0, 1.0, 1.0, 1.0]):
93 | super(VGGLoss, self).__init__()
94 | self.add_module('vgg', VGG19())
95 | self.criterion = torch.nn.L1Loss()
96 | self.weights = weights
97 |
98 | def compute_gram(self, x):
99 | b, ch, h, w = x.size()
100 | f = x.view(b, ch, w * h)
101 | f_T = f.transpose(1, 2)
102 | G = f.bmm(f_T) / (h * w * ch)
103 | return G
104 |
105 | def __call__(self, x, y):
106 | # Compute features
107 | x_vgg, y_vgg = self.vgg(x), self.vgg(y)
108 |
109 | content_loss = 0.0
110 | content_loss += self.weights[0] * self.criterion(x_vgg['relu1_1'], y_vgg['relu1_1'])
111 | content_loss += self.weights[1] * self.criterion(x_vgg['relu2_1'], y_vgg['relu2_1'])
112 | content_loss += self.weights[2] * self.criterion(x_vgg['relu3_1'], y_vgg['relu3_1'])
113 | content_loss += self.weights[3] * self.criterion(x_vgg['relu4_1'], y_vgg['relu4_1'])
114 | content_loss += self.weights[4] * self.criterion(x_vgg['relu5_1'], y_vgg['relu5_1'])
115 |
116 | # Compute loss
117 | style_loss = 0.0
118 | style_loss += self.criterion(self.compute_gram(x_vgg['relu2_2']), self.compute_gram(y_vgg['relu2_2']))
119 | style_loss += self.criterion(self.compute_gram(x_vgg['relu3_4']), self.compute_gram(y_vgg['relu3_4']))
120 | style_loss += self.criterion(self.compute_gram(x_vgg['relu4_4']), self.compute_gram(y_vgg['relu4_4']))
121 | style_loss += self.criterion(self.compute_gram(x_vgg['relu5_2']), self.compute_gram(y_vgg['relu5_2']))
122 |
123 | return content_loss, style_loss
124 |
125 |
126 | class StyleLoss(nn.Module):
127 | r"""
128 | Style loss on VGG features (Gram matrices)
129 | https://arxiv.org/abs/1603.08155
130 | https://github.com/dxyang/StyleTransfer/blob/master/utils.py
131 | """
132 |
133 | def __init__(self):
134 | super(StyleLoss, self).__init__()
135 | self.add_module('vgg', VGG19())
136 | self.criterion = torch.nn.L1Loss()
137 |
138 | def compute_gram(self, x):
139 | b, ch, h, w = x.size()
140 | f = x.view(b, ch, w * h)
141 | f_T = f.transpose(1, 2)
142 | G = f.bmm(f_T) / (h * w * ch)
143 |
144 | return G
145 |
146 | def __call__(self, x, y):
147 | # Compute features
148 | x_vgg, y_vgg = self.vgg(x), self.vgg(y)
149 |
150 | # Compute loss
151 | style_loss = 0.0
152 | style_loss += self.criterion(self.compute_gram(x_vgg['relu2_2']), self.compute_gram(y_vgg['relu2_2']))
153 | style_loss += self.criterion(self.compute_gram(x_vgg['relu3_4']), self.compute_gram(y_vgg['relu3_4']))
154 | style_loss += self.criterion(self.compute_gram(x_vgg['relu4_4']), self.compute_gram(y_vgg['relu4_4']))
155 | style_loss += self.criterion(self.compute_gram(x_vgg['relu5_2']), self.compute_gram(y_vgg['relu5_2']))
156 |
157 | return style_loss
158 |
159 |
160 | class PerceptualLoss(nn.Module):
161 | r"""
162 | Perceptual loss, VGG-based
163 | https://arxiv.org/abs/1603.08155
164 | https://github.com/dxyang/StyleTransfer/blob/master/utils.py
165 | """
166 |
167 | def __init__(self, weights=[1.0, 1.0, 1.0, 1.0, 1.0]):
168 | super(PerceptualLoss, self).__init__()
169 | self.add_module('vgg', VGG19())
170 | self.criterion = torch.nn.L1Loss()
171 | self.weights = weights
172 |
173 | def __call__(self, x, y):
174 | # Compute features
175 | x_vgg, y_vgg = self.vgg(x), self.vgg(y)
176 | content_loss = 0.0
177 | content_loss += self.weights[0] * self.criterion(x_vgg['relu1_1'], y_vgg['relu1_1'])
178 | content_loss += self.weights[1] * self.criterion(x_vgg['relu2_1'], y_vgg['relu2_1'])
179 | content_loss += self.weights[2] * self.criterion(x_vgg['relu3_1'], y_vgg['relu3_1'])
180 | content_loss += self.weights[3] * self.criterion(x_vgg['relu4_1'], y_vgg['relu4_1'])
181 | content_loss += self.weights[4] * self.criterion(x_vgg['relu5_1'], y_vgg['relu5_1'])
182 |
183 | return content_loss
184 |
185 | class VGG19(torch.nn.Module):
186 | def __init__(self):
187 | super(VGG19, self).__init__()
188 | # features = models.vgg19(pretrained=True).features
189 |
190 | vgg19 = models.vgg19(pretrained=False)
191 | vgg19.load_state_dict(torch.load('./dataset/fashion/vgg19-dcbb9e9d.pth'))
192 | self.vgg = vgg19.features
193 | features = vgg19.features
194 |
195 | for param in self.vgg.parameters():
196 | param.requires_grad_(False)
197 |
198 |
199 | self.relu1_1 = torch.nn.Sequential()
200 | self.relu1_2 = torch.nn.Sequential()
201 |
202 | self.relu2_1 = torch.nn.Sequential()
203 | self.relu2_2 = torch.nn.Sequential()
204 |
205 | self.relu3_1 = torch.nn.Sequential()
206 | self.relu3_2 = torch.nn.Sequential()
207 | self.relu3_3 = torch.nn.Sequential()
208 | self.relu3_4 = torch.nn.Sequential()
209 |
210 | self.relu4_1 = torch.nn.Sequential()
211 | self.relu4_2 = torch.nn.Sequential()
212 | self.relu4_3 = torch.nn.Sequential()
213 | self.relu4_4 = torch.nn.Sequential()
214 |
215 | self.relu5_1 = torch.nn.Sequential()
216 | self.relu5_2 = torch.nn.Sequential()
217 | self.relu5_3 = torch.nn.Sequential()
218 | self.relu5_4 = torch.nn.Sequential()
219 |
220 | for x in range(2):
221 | self.relu1_1.add_module(str(x), features[x])
222 |
223 | for x in range(2, 4):
224 | self.relu1_2.add_module(str(x), features[x])
225 |
226 | for x in range(4, 7):
227 | self.relu2_1.add_module(str(x), features[x])
228 |
229 | for x in range(7, 9):
230 | self.relu2_2.add_module(str(x), features[x])
231 |
232 | for x in range(9, 12):
233 | self.relu3_1.add_module(str(x), features[x])
234 |
235 | for x in range(12, 14):
236 | self.relu3_2.add_module(str(x), features[x])
237 |
238 | for x in range(14, 16):
239 | self.relu3_3.add_module(str(x), features[x])
240 |
241 | for x in range(16, 18):
242 | self.relu3_4.add_module(str(x), features[x])
243 |
244 | for x in range(18, 21):
245 | self.relu4_1.add_module(str(x), features[x])
246 |
247 | for x in range(21, 23):
248 | self.relu4_2.add_module(str(x), features[x])
249 |
250 | for x in range(23, 25):
251 | self.relu4_3.add_module(str(x), features[x])
252 |
253 | for x in range(25, 27):
254 | self.relu4_4.add_module(str(x), features[x])
255 |
256 | for x in range(27, 30):
257 | self.relu5_1.add_module(str(x), features[x])
258 |
259 | for x in range(30, 32):
260 | self.relu5_2.add_module(str(x), features[x])
261 |
262 | for x in range(32, 34):
263 | self.relu5_3.add_module(str(x), features[x])
264 |
265 | for x in range(34, 36):
266 | self.relu5_4.add_module(str(x), features[x])
267 |
268 | # don't need the gradients, just want the features
269 | for param in self.parameters():
270 | param.requires_grad = False
271 |
272 | def forward(self, x):
273 | relu1_1 = self.relu1_1(x)
274 | relu1_2 = self.relu1_2(relu1_1)
275 |
276 | relu2_1 = self.relu2_1(relu1_2)
277 | relu2_2 = self.relu2_2(relu2_1)
278 |
279 | relu3_1 = self.relu3_1(relu2_2)
280 | relu3_2 = self.relu3_2(relu3_1)
281 | relu3_3 = self.relu3_3(relu3_2)
282 | relu3_4 = self.relu3_4(relu3_3)
283 |
284 | relu4_1 = self.relu4_1(relu3_4)
285 | relu4_2 = self.relu4_2(relu4_1)
286 | relu4_3 = self.relu4_3(relu4_2)
287 | relu4_4 = self.relu4_4(relu4_3)
288 |
289 | relu5_1 = self.relu5_1(relu4_4)
290 | relu5_2 = self.relu5_2(relu5_1)
291 | relu5_3 = self.relu5_3(relu5_2)
292 | relu5_4 = self.relu5_4(relu5_3)
293 |
294 | out = {
295 | 'relu1_1': relu1_1,
296 | 'relu1_2': relu1_2,
297 |
298 | 'relu2_1': relu2_1,
299 | 'relu2_2': relu2_2,
300 |
301 | 'relu3_1': relu3_1,
302 | 'relu3_2': relu3_2,
303 | 'relu3_3': relu3_3,
304 | 'relu3_4': relu3_4,
305 |
306 | 'relu4_1': relu4_1,
307 | 'relu4_2': relu4_2,
308 | 'relu4_3': relu4_3,
309 | 'relu4_4': relu4_4,
310 |
311 | 'relu5_1': relu5_1,
312 | 'relu5_2': relu5_2,
313 | 'relu5_3': relu5_3,
314 | 'relu5_4': relu5_4,
315 | }
316 | return out
--------------------------------------------------------------------------------
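The three loss modules above share the same frozen VGG19 backbone. Below is a minimal usage sketch (not part of the repository), assuming the repository root is the working directory and the `vgg19-dcbb9e9d.pth` weights referenced in `VGG19.__init__` are already in place under `./dataset/fashion`.

```python
# Minimal usage sketch (illustrative only): compare a generated batch
# against a target batch with the VGG-based losses defined above.
import torch
from models.vgg_SC import VGGLoss

fake = torch.rand(1, 3, 256, 256)   # stand-in for a generator output
real = torch.rand(1, 3, 256, 256)   # stand-in for the ground-truth image

criterion = VGGLoss()                        # loads the VGG19 weights referenced in VGG19.__init__
content_loss, style_loss = criterion(fake, real)
total = content_loss + style_loss            # the relative weighting here is arbitrary
print(content_loss.item(), style_loss.item())
```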
/options/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xyzhouo/CASD/cb1aabb64b5d8ae712b626a1eec045d08f90933e/options/__init__.py
--------------------------------------------------------------------------------
/options/base_options.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | from util import util
4 | import torch
5 |
6 |
7 | class BaseOptions():
8 | def __init__(self):
9 | self.parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
10 | self.initialized = False
11 |
12 | def initialize(self):
13 | self.parser.add_argument('--dataroot', default='./dataset/fashion',\
14 | help='path to images ')
15 | self.parser.add_argument('--dirSem', default='./dataset/fashion',\
16 | help='path to semantic images')
17 |
18 | self.parser.add_argument('--batchSize', type=int, default=1, help='input batch size')
19 | self.parser.add_argument('--which_model_netG', type=str, default='CASD', help='selects model to use for netG')
20 | self.parser.add_argument('--name', type=str,
21 | default='CASD_test',
22 | help='name of the experiment. It decides where to store samples and models')
23 | self.parser.add_argument('--fineSize', type=int, default=[256,256], help='input image size')
24 | self.parser.add_argument('--pairLst', type=str, default='./dataset/fashion/fashion-resize-pairs-train.csv', help='fashion pairs')
25 | # self.parser.add_argument('--pairLst', type=str, default='./dataset/fashion/fashion-resize-pairs-test.csv', help='fashion pairs')
26 |
27 | self.parser.add_argument('--input_nc', type=int, default=3, help='# of input image channels')
28 | self.parser.add_argument('--output_nc', type=int, default=3, help='# of output image channels')
29 | self.parser.add_argument('--ngf', type=int, default=64, help='# of gen filters in first conv layer')
30 | self.parser.add_argument('--ndf', type=int, default=64, help='# of discrim filters in first conv layer')
31 | self.parser.add_argument('--which_model_netD', type=str, default='resnet', help='selects model to use for netD')
32 | self.parser.add_argument('--n_layers_D', type=int, default=0, help='blocks used in D')
33 | self.parser.add_argument('--gpu_ids', type=str, default='0', help='gpu ids, e.g. 0 or 0,1,2; use -1 for CPU')
34 | self.parser.add_argument('--dataset_mode', type=str, default='keypoint', help='chooses how datasets are loaded. [unaligned | aligned | single]')
35 | self.parser.add_argument('--model', type=str, default='adgan',
36 | help='chooses which model to use. cycle_gan, pix2pix, test')
37 | self.parser.add_argument('--which_direction', type=str, default='AtoB', help='AtoB or BtoA')
38 | self.parser.add_argument('--nThreads', default=2, type=int, help='# threads for loading data')
39 | self.parser.add_argument('--checkpoints_dir', type=str, default='./checkpoints', help='models are saved here')
40 | self.parser.add_argument('--norm', type=str, default='instance', help='instance normalization or batch normalization')
41 | self.parser.add_argument('--serial_batches', action='store_true', help='if true, takes images in order to make batches, otherwise takes them randomly')
42 | self.parser.add_argument('--display_winsize', type=int, default=256, help='display window size')
43 | self.parser.add_argument('--display_id', type=int, default=0, help='window id of the web display')
44 | self.parser.add_argument('--display_port', type=int, default=8097, help='visdom port of the web display')
45 | self.parser.add_argument('--no_dropout', action='store_true', help='no dropout for the generator')
46 | self.parser.add_argument('--max_dataset_size', type=int, default=float("inf"), help='Maximum number of samples allowed per dataset. If the dataset directory contains more than max_dataset_size, only a subset is loaded.')
47 | self.parser.add_argument('--resize_or_crop', type=str, default='no', help='scaling and cropping of images at load time [resize_and_crop|crop|scale_width|scale_width_and_crop]')
48 | self.parser.add_argument('--no_flip', default=False, help='if specified, do not flip the images for data augmentation')
49 | self.parser.add_argument('--init_type', type=str, default='normal', help='network initialization [normal|xavier|kaiming|orthogonal]')
50 |
51 | self.parser.add_argument('--P_input_nc', type=int, default=3, help='# of input image channels')
52 | self.parser.add_argument('--BP_input_nc', type=int, default=18, help='# of input image channels')
53 | self.parser.add_argument('--BPD_input_nc', type=int, default=12, help='# of input image channels')
54 | self.parser.add_argument('--SP_input_nc', type=int, default=8, help='# of input image channels')
55 | self.parser.add_argument('--with_D_PP', type=int, default=1, help='use D to judge P and P is pair or not')
56 | self.parser.add_argument('--with_D_PB', type=int, default=1, help='use D to judge P and B is pair or not')
57 | self.parser.add_argument('--without_concat_SBP', type=int, default=0, help='do not concat source BP')
58 | self.parser.add_argument('--use_flip', type=int, default=0, help='flip or not')
59 |
60 | # down-sampling times
61 | self.parser.add_argument('--G_n_downsampling', type=int, default=2, help='down-sampling blocks for generator')
62 | self.parser.add_argument('--D_n_downsampling', type=int, default=2, help='down-sampling blocks for discriminator')
63 | self.parser.add_argument('--use_AMCE', type=int, default=1, help='use the AMCE loss or not')
64 | self.parser.add_argument('--use_BPD', type=int, default=1, help='use the BPD (pose distance) input or not')
65 | self.parser.add_argument('--use_lpips', type=int, default=1, help='use the LPIPS loss or not')
66 |
67 | self.initialized = True
68 |
69 | def parse(self):
70 | if not self.initialized:
71 | self.initialize()
72 | self.opt = self.parser.parse_args()
73 | self.opt.isTrain = self.isTrain # train or test
74 |
75 | str_ids = self.opt.gpu_ids.split(',')
76 | self.opt.gpu_ids = []
77 | for str_id in str_ids:
78 | id = int(str_id)
79 | if id >= 0:
80 | self.opt.gpu_ids.append(id)
81 |
82 | # set gpu ids
83 | # num_gpus = os.environ['CUDA_VISIBLE_DEVICES'].split(',').__len__()
84 | # os.environ['CUDA_VISIBLE_DEVICES'] = '3'
85 | if len(self.opt.gpu_ids) > 0:
86 | torch.cuda.set_device(self.opt.gpu_ids[0])
87 |
88 | args = vars(self.opt)
89 |
90 | print('------------ Options -------------')
91 | for k, v in sorted(args.items()):
92 | print('%s: %s' % (str(k), str(v)))
93 | print('-------------- End ----------------')
94 |
95 | # save to the disk
96 | expr_dir = os.path.join(self.opt.checkpoints_dir, self.opt.name)
97 | util.mkdirs(expr_dir)
98 | file_name = os.path.join(expr_dir, 'opt.txt')
99 | with open(file_name, 'wt') as opt_file:
100 | opt_file.write('------------ Options -------------\n')
101 | for k, v in sorted(args.items()):
102 | opt_file.write('%s: %s\n' % (str(k), str(v)))
103 | opt_file.write('-------------- End ----------------\n')
104 | return self.opt
105 |
--------------------------------------------------------------------------------
/options/test_options.py:
--------------------------------------------------------------------------------
1 | from .base_options import BaseOptions
2 |
3 |
4 | class TestOptions(BaseOptions):
5 | def initialize(self):
6 | BaseOptions.initialize(self)
7 | self.parser.add_argument('--ntest', type=int, default=float("inf"), help='# of test examples.')
8 | self.parser.add_argument('--results_dir', type=str, default='./results/', help='saves results here.')
9 | self.parser.add_argument('--aspect_ratio', type=float, default=1.0, help='aspect ratio of result images')
10 | self.parser.add_argument('--phase', type=str, default='test', help='train, val, test, etc')
11 | self.parser.add_argument('--which_epoch', type=str, default='1000', help='which epoch to load? set to latest to use latest cached model')
12 | self.parser.add_argument('--how_many', type=int, default=200, help='how many test images to run')
13 |
14 | self.isTrain = False
15 |
--------------------------------------------------------------------------------
/options/train_options.py:
--------------------------------------------------------------------------------
1 | from .base_options import BaseOptions
2 |
3 |
4 | class TrainOptions(BaseOptions):
5 | def initialize(self):
6 | BaseOptions.initialize(self)
7 | self.parser.add_argument('--display_freq', type=int, default=100, help='frequency of showing training results on screen')
8 | self.parser.add_argument('--display_single_pane_ncols', type=int, default=0, help='if positive, display all images in a single visdom web panel with certain number of images per row.')
9 | self.parser.add_argument('--update_html_freq', type=int, default=1000, help='frequency of saving training results to html')
10 | self.parser.add_argument('--print_freq', type=int, default=100, help='frequency of showing training results on console')
11 | self.parser.add_argument('--save_latest_freq', type=int, default=5000, help='frequency of saving the latest results')
12 | self.parser.add_argument('--save_epoch_freq', type=int, default=20, help='frequency of saving checkpoints at the end of epochs')
13 | self.parser.add_argument('--continue_train', default=False, help='continue training: load the latest model')
14 | self.parser.add_argument('--epoch_count', type=int, default=1, help='the starting epoch count; we save the model by <epoch_count>, <epoch_count>+<save_latest_freq>, ...')
15 | self.parser.add_argument('--phase', type=str, default='train', help='train, val, test, etc')
16 | self.parser.add_argument('--which_epoch', type=str, default='latest', help='which epoch to load? set to latest to use latest cached model')
17 | self.parser.add_argument('--niter', type=int, default=500, help='# of iter at starting learning rate')
18 | self.parser.add_argument('--niter_decay', type=int, default=500, help='# of iter to linearly decay learning rate to zero')
19 | self.parser.add_argument('--beta1', type=float, default=0.5, help='momentum term of adam')
20 | self.parser.add_argument('--lr', type=float, default=0.001, help='initial learning rate for adam')
21 | self.parser.add_argument('--no_lsgan', action='store_true', help='do *not* use least square GAN; if set, use vanilla GAN')
22 |
23 | self.parser.add_argument('--lambda_A', type=float, default=1.0, help='weight for L1 loss')
24 | self.parser.add_argument('--lambda_B', type=float, default=1.0, help='weight for perceptual L1 loss')
25 | self.parser.add_argument('--lambda_GAN', type=float, default=5.0, help='weight of GAN loss')
26 | self.parser.add_argument('--lambda_cx', type=float, default=0.1, help='weight of CX loss')
27 | self.parser.add_argument('--lambda_AMCE', type=float, default=0.1, help='weight of AMCE loss')
28 | self.parser.add_argument('--lambda_lpips', type=float, default=1.0, help='weight of LPIPS loss')
29 |
30 | self.parser.add_argument('--pool_size', type=int, default=0, help='the size of image buffer that stores previously generated images')
31 | self.parser.add_argument('--no_html', action='store_true', help='do not save intermediate training results to [opt.checkpoints_dir]/[opt.name]/web/')
32 | self.parser.add_argument('--lr_policy', type=str, default='lambda', help='learning rate policy: lambda|step|plateau')
33 | self.parser.add_argument('--lr_decay_iters', type=int, default=50, help='multiply by a gamma every lr_decay_iters iterations')
34 | self.parser.add_argument('--L1_type', type=str, default='l1_plus_perL1', help='use which kind of L1 loss. (origin|l1_plus_perL1)')
35 | self.parser.add_argument('--perceptual_layers', type=int, default=3, help='index of vgg layer for extracting perceptual features.')
36 | self.parser.add_argument('--percep_is_l1', type=int, default=1, help='type of perceptual loss: l1 or l2')
37 | self.parser.add_argument('--no_dropout_D', action='store_true', help='no dropout for the discriminator')
38 | self.parser.add_argument('--DG_ratio', type=int, default=1, help='how many times for D training after training G once')
39 | self.parser.add_argument('--use_cxloss', type=int, default=1, help='use cxloss or not')
40 |
41 |
42 | self.isTrain = True
43 |
--------------------------------------------------------------------------------
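The option classes above are thin argparse wrappers consumed by `train.py` and `test.py`. As a quick, hypothetical sketch (the run name `CASD_demo` and flag values are illustrative), parsing can be exercised directly on the CPU:

```python
# Illustrative sketch: exercise the option parser without a GPU.
# parse() prints every option and writes it to ./checkpoints/CASD_demo/opt.txt.
import sys
from options.train_options import TrainOptions

sys.argv = ['train.py', '--name', 'CASD_demo', '--gpu_ids', '-1', '--batchSize', '2']
opt = TrainOptions().parse()
print(opt.isTrain, opt.batchSize, opt.gpu_ids, opt.pairLst)
```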
/requirements.txt:
--------------------------------------------------------------------------------
1 | dominate
2 | scikit-image
3 | pandas
4 |
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
1 | import os
2 | from options.test_options import TestOptions
3 | from data.data_loader import CreateDataLoader
4 | from models.models import create_model
5 | from util.visualizer import Visualizer
6 | from util import html
7 | import time
8 |
9 | opt = TestOptions().parse()
10 | opt.nThreads = 1 # test code only supports nThreads = 1
11 | opt.batchSize = 1 # test code only supports batchSize = 1
12 | opt.serial_batches = True # no shuffle
13 | opt.no_flip = True # no flip
14 |
15 | data_loader = CreateDataLoader(opt)
16 | dataset = data_loader.load_data()
17 | model = create_model(opt)
18 | visualizer = Visualizer(opt)
19 | # create website
20 | web_dir = os.path.join(opt.results_dir, opt.name, '%s_%s' % (opt.phase, opt.which_epoch))
21 |
22 | webpage = html.HTML(web_dir, 'Experiment = %s, Phase = %s, Epoch = %s' % (opt.name, opt.phase, opt.which_epoch))
23 |
24 | print(opt.how_many)
25 | print(len(dataset))
26 |
27 | model = model.eval()
28 | print(model.training)
29 |
30 | opt.how_many = 999999  # override --how_many so the whole test split is processed
31 | # test
32 | for i, data in enumerate(dataset):
33 | print(' process %d/%d img ..'%(i,opt.how_many))
34 | if i >= opt.how_many:
35 | break
36 | model.set_input(data)
37 | startTime = time.time()
38 | model.test()
39 | endTime = time.time()
40 | print(endTime-startTime)
41 | visuals = model.get_current_visuals()
42 | img_path = model.get_image_paths()
43 | img_path = [img_path]
44 | print(img_path)
45 | visualizer.save_images(webpage, visuals, img_path)
46 |
47 | webpage.save()
48 |
49 |
50 |
51 |
52 |
--------------------------------------------------------------------------------
/tool/generate_fashion_datasets.py:
--------------------------------------------------------------------------------
1 | import os
2 | import shutil
3 | from PIL import Image
4 |
5 | IMG_EXTENSIONS = [
6 | '.jpg', '.JPG', '.jpeg', '.JPEG',
7 | '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
8 | ]
9 |
10 | def is_image_file(filename):
11 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
12 |
13 | def make_dataset(dir):
14 | images = []
15 | assert os.path.isdir(dir), '%s is not a valid directory' % dir
16 | # new_root = './fashion'
17 | # if not os.path.exists(new_root):
18 | # os.mkdir(new_root)
19 |
20 | train_root = os.path.join(dir, 'train')
21 | if not os.path.exists(train_root):
22 | os.mkdir(train_root)
23 |
24 | test_root = os.path.join(dir, 'test')
25 | if not os.path.exists(test_root):
26 | os.mkdir(test_root)
27 |
28 | train_images = []
29 | train_f = open(os.path.join(dir, 'train.lst'), 'r')
30 | for lines in train_f:
31 | lines = lines.strip()
32 | if lines.endswith('.jpg'):
33 | train_images.append(lines)
34 |
35 | test_images = []
36 | test_f = open(os.path.join(dir, 'test.lst'), 'r')
37 | for lines in test_f:
38 | lines = lines.strip()
39 | if lines.endswith('.jpg'):
40 | test_images.append(lines)
41 |
42 | # print(train_images, test_images)
43 |
44 | for root, _, fnames in sorted(os.walk(os.path.join(dir, 'img_highres'))):
45 | for fname in fnames:
46 | if is_image_file(fname):
47 | path = os.path.join(root, fname)
48 | path_names = path.split('/')
49 | print(path_names)
50 |
51 | path_names = path_names[2:]
52 | del path_names[1]
53 | path_names[3] = path_names[3].replace('_', '')
54 | path_names[4] = path_names[4].split('_')[0] + "_" + "".join(path_names[4].split('_')[1:])
55 | path_names = "".join(path_names)
56 | # img = Image.open(path)
57 | if path_names in train_images:
58 | shutil.copy(path, os.path.join(train_root, path_names))
59 | print(os.path.join(train_root, path_names))
60 | # pass
61 | elif path_names in test_images:
62 | shutil.copy(path, os.path.join(test_root, path_names))
63 | print(os.path.join(test_root, path_names))
64 | # pass
65 |
66 | make_dataset('../dataset/fashion/')
--------------------------------------------------------------------------------
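The path flattening inside `make_dataset` is easiest to follow on a concrete (hypothetical) DeepFashion path. The split indices assume the script is launched from `tool/`, so paths begin with `../dataset/...`:

```python
# Worked example (illustrative) of the filename flattening performed above.
path = '../dataset/fashion/img_highres/MEN/Denim/id_00000080/01_7_additional.jpg'
path_names = path.split('/')   # ['..', 'dataset', 'fashion', 'img_highres', 'MEN', ...]
path_names = path_names[2:]    # -> ['fashion', 'img_highres', 'MEN', 'Denim', 'id_00000080', '01_7_additional.jpg']
del path_names[1]              # drop 'img_highres'
path_names[3] = path_names[3].replace('_', '')   # 'id_00000080' -> 'id00000080'
path_names[4] = path_names[4].split('_')[0] + "_" + "".join(path_names[4].split('_')[1:])
print("".join(path_names))     # fashionMENDenimid0000008001_7additional.jpg
```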
/tool/generate_pose_map_fashion.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pandas as pd
3 | import json
4 | import os
5 |
6 | MISSING_VALUE = -1
7 | # fix PATH
8 | img_dir = '../dataset/fashion/'
9 | annotations_file = os.path.join(img_dir, 'fashion-resize-annotation-test.csv') #pose annotation path
10 | save_path = os.path.join(img_dir, 'testK')
11 | if not os.path.exists(save_path):
12 | os.makedirs(save_path)
13 |
14 | def load_pose_cords_from_strings(y_str, x_str):
15 | y_cords = json.loads(y_str)
16 | x_cords = json.loads(x_str)
17 | return np.concatenate([np.expand_dims(y_cords, -1), np.expand_dims(x_cords, -1)], axis=1)
18 |
19 | def cords_to_map(cords, img_size, sigma=6):
20 | result = np.zeros(img_size + cords.shape[0:1], dtype='float32')  # float so the Gaussian values are not truncated
21 | for i, point in enumerate(cords):
22 | if point[0] == MISSING_VALUE or point[1] == MISSING_VALUE:
23 | continue
24 | xx, yy = np.meshgrid(np.arange(img_size[1]), np.arange(img_size[0]))
25 | result[..., i] = np.exp(-((yy - point[0]) ** 2 + (xx - point[1]) ** 2) / (2 * sigma ** 2))
26 | # result[..., i] = np.where(((yy - point[0]) ** 2 + (xx - point[1]) ** 2) < (sigma ** 2), 1, 0)
27 | return result
28 |
29 | def compute_pose(image_dir, annotations_file, savePath, sigma):
30 | annotations_file = pd.read_csv(annotations_file, sep=':')
31 | annotations_file = annotations_file.set_index('name')
32 | image_size = (256, 256)
33 | cnt = len(annotations_file)
34 | for i in range(cnt):
35 | print('processing %d / %d ...' %(i, cnt))
36 | row = annotations_file.iloc[i]
37 | name = row.name
38 | print(savePath, name)
39 | file_name = os.path.join(savePath, name + '.npy')
40 | kp_array = load_pose_cords_from_strings(row.keypoints_y, row.keypoints_x)
41 | pose = cords_to_map(kp_array, image_size, sigma)
42 | np.save(file_name, pose)
43 | # input()
44 |
45 | compute_pose(img_dir, annotations_file, save_path,6)
46 |
47 |
48 |
49 |
--------------------------------------------------------------------------------
/tool/resize_fashion.py:
--------------------------------------------------------------------------------
1 | import os
2 | from PIL import Image
3 | from PIL import ImageFile
4 |
5 | ImageFile.LOAD_TRUNCATED_IMAGES=True
6 | def resize_dataset(folder, new_folder, new_size = (256, 256), crop_bord=0):
7 | if not os.path.exists(new_folder):
8 | os.makedirs(new_folder)
9 | for name in os.listdir(folder):
10 | old_name = os.path.join(folder, name)
11 | new_name = os.path.join(new_folder, name)
12 |
13 |
14 | img = Image.open(old_name)
15 | w, h =img.size
16 | if crop_bord == 0:
17 | pass
18 | else:
19 | img = img.crop((crop_bord, 0, w-crop_bord, h))
20 | img = img.resize([new_size[1],new_size[0]])
21 | img.save(new_name)
22 | print('resized %s successfully' % old_name)
23 |
24 |
25 | old_dir = '../dataset/fashion/train'
26 | root_dir = '../dataset/fashion/train_resize'
27 | resize_dataset(old_dir, root_dir)
28 |
29 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import time
2 | from options.train_options import TrainOptions
3 | from data.data_loader import CreateDataLoader
4 | from models.models import create_model
5 | from util.visualizer import Visualizer
6 |
7 |
8 | opt = TrainOptions().parse()
9 | data_loader = CreateDataLoader(opt)
10 | dataset = data_loader.load_data()
11 | dataset_size = len(data_loader)
12 | print('#training images = %d' % dataset_size)
13 |
14 | model = create_model(opt)
15 |
16 |
17 | visualizer = Visualizer(opt)
18 | total_steps = 0
19 |
20 | for epoch in range(opt.epoch_count, opt.niter + opt.niter_decay + 1):
21 | epoch_start_time = time.time()
22 | epoch_iter = 0
23 |
24 |
25 | for i, data in enumerate(dataset):
26 | iter_start_time = time.time()
27 | visualizer.reset()
28 | total_steps += opt.batchSize
29 | epoch_iter += opt.batchSize
30 | model.set_input(data)
31 |
32 | # model.optimize_parameters()
33 | model.optimize_parameters()
34 |
35 | if total_steps % opt.display_freq == 0:
36 | save_result = total_steps % opt.update_html_freq == 0
37 | visualizer.display_current_results(model.get_current_visuals(), epoch, save_result)
38 |
39 | if total_steps % opt.print_freq == 0:
40 | errors = model.get_current_errors()
41 | t = (time.time() - iter_start_time) / opt.batchSize
42 | visualizer.print_current_errors(epoch, epoch_iter, errors, t)
43 | if opt.display_id > 0:
44 | visualizer.plot_current_errors(epoch, float(epoch_iter)/dataset_size, opt, errors)
45 |
46 | if total_steps % opt.save_latest_freq == 0:
47 | print('saving the latest model (epoch %d, total_steps %d)' %
48 | (epoch, total_steps))
49 | model.save('latest')
50 |
51 | if epoch % opt.save_epoch_freq == 0:
52 | print('saving the model at the end of epoch %d, iters %d' %
53 | (epoch, total_steps))
54 | model.save('latest')
55 | model.save(epoch)
56 |
57 | print('End of epoch %d / %d \t Time Taken: %d sec' %
58 | (epoch, opt.niter + opt.niter_decay, time.time() - epoch_start_time))
59 | model.update_learning_rate()
60 |
61 |
62 |
--------------------------------------------------------------------------------
/util/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xyzhouo/CASD/cb1aabb64b5d8ae712b626a1eec045d08f90933e/util/__init__.py
--------------------------------------------------------------------------------
/util/get_data.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 | import os
3 | import tarfile
4 | import requests
5 | from warnings import warn
6 | from zipfile import ZipFile
7 | from bs4 import BeautifulSoup
8 | from os.path import abspath, isdir, join, basename
9 |
10 |
11 | class GetData(object):
12 | """
13 |
14 | Download CycleGAN or Pix2Pix Data.
15 |
16 | Args:
17 | technique : str
18 | One of: 'cyclegan' or 'pix2pix'.
19 | verbose : bool
20 | If True, print additional information.
21 |
22 | Examples:
23 | >>> from util.get_data import GetData
24 | >>> gd = GetData(technique='cyclegan')
25 | >>> new_data_path = gd.get(save_path='./datasets') # options will be displayed.
26 |
27 | """
28 |
29 | def __init__(self, technique='cyclegan', verbose=True):
30 | url_dict = {
31 | 'pix2pix': 'https://people.eecs.berkeley.edu/~tinghuiz/projects/pix2pix/datasets',
32 | 'cyclegan': 'https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets'
33 | }
34 | self.url = url_dict.get(technique.lower())
35 | self._verbose = verbose
36 |
37 | def _print(self, text):
38 | if self._verbose:
39 | print(text)
40 |
41 | @staticmethod
42 | def _get_options(r):
43 | soup = BeautifulSoup(r.text, 'lxml')
44 | options = [h.text for h in soup.find_all('a', href=True)
45 | if h.text.endswith(('.zip', 'tar.gz'))]
46 | return options
47 |
48 | def _present_options(self):
49 | r = requests.get(self.url)
50 | options = self._get_options(r)
51 | print('Options:\n')
52 | for i, o in enumerate(options):
53 | print("{0}: {1}".format(i, o))
54 | choice = input("\nPlease enter the number of the "
55 | "dataset above you wish to download:")
56 | return options[int(choice)]
57 |
58 | def _download_data(self, dataset_url, save_path):
59 | if not isdir(save_path):
60 | os.makedirs(save_path)
61 |
62 | base = basename(dataset_url)
63 | temp_save_path = join(save_path, base)
64 |
65 | with open(temp_save_path, "wb") as f:
66 | r = requests.get(dataset_url)
67 | f.write(r.content)
68 |
69 | if base.endswith('.tar.gz'):
70 | obj = tarfile.open(temp_save_path)
71 | elif base.endswith('.zip'):
72 | obj = ZipFile(temp_save_path, 'r')
73 | else:
74 | raise ValueError("Unknown File Type: {0}.".format(base))
75 |
76 | self._print("Unpacking Data...")
77 | obj.extractall(save_path)
78 | obj.close()
79 | os.remove(temp_save_path)
80 |
81 | def get(self, save_path, dataset=None):
82 | """
83 |
84 | Download a dataset.
85 |
86 | Args:
87 | save_path : str
88 | A directory to save the data to.
89 | dataset : str, optional
90 | A specific dataset to download.
91 | Note: this must include the file extension.
92 | If None, options will be presented for you
93 | to choose from.
94 |
95 | Returns:
96 | save_path_full : str
97 | The absolute path to the downloaded data.
98 |
99 | """
100 | if dataset is None:
101 | selected_dataset = self._present_options()
102 | else:
103 | selected_dataset = dataset
104 |
105 | save_path_full = join(save_path, selected_dataset.split('.')[0])
106 |
107 | if isdir(save_path_full):
108 | warn("\n'{0}' already exists. Voiding Download.".format(
109 | save_path_full))
110 | else:
111 | self._print('Downloading Data...')
112 | url = "{0}/{1}".format(self.url, selected_dataset)
113 | self._download_data(url, save_path=save_path)
114 |
115 | return abspath(save_path_full)
116 |
--------------------------------------------------------------------------------
/util/html.py:
--------------------------------------------------------------------------------
1 | import dominate
2 | from dominate.tags import *
3 | import os
4 |
5 |
6 | class HTML:
7 | def __init__(self, web_dir, title, reflesh=0):
8 | self.title = title
9 | self.web_dir = web_dir
10 | self.img_dir = os.path.join(self.web_dir, 'images')
11 | if not os.path.exists(self.web_dir):
12 | os.makedirs(self.web_dir)
13 | if not os.path.exists(self.img_dir):
14 | os.makedirs(self.img_dir)
15 | # print(self.img_dir)
16 |
17 | self.doc = dominate.document(title=title)
18 | if reflesh > 0:
19 | with self.doc.head:
20 | meta(http_equiv="refresh", content=str(reflesh))
21 |
22 | def get_image_dir(self):
23 | return self.img_dir
24 |
25 | def add_header(self, str):
26 | with self.doc:
27 | h3(str)
28 |
29 | def add_table(self, border=1):
30 | self.t = table(border=border, style="table-layout: fixed;")
31 | self.doc.add(self.t)
32 |
33 | def add_images(self, ims, txts, links, width=400):
34 | self.add_table()
35 | with self.t:
36 | with tr():
37 | for im, txt, link in zip(ims, txts, links):
38 | with td(style="word-wrap: break-word;", halign="center", valign="top"):
39 | with p():
40 | with a(href=os.path.join('images', link)):
41 | img(style="width:%dpx" % width, src=os.path.join('images', im))
42 | br()
43 | p(txt)
44 |
45 | def save(self):
46 | html_file = '%s/index.html' % self.web_dir
47 | f = open(html_file, 'wt')
48 | f.write(self.doc.render())
49 | f.close()
50 |
51 |
52 | if __name__ == '__main__':
53 | html = HTML('web/', 'test_html')
54 | html.add_header('hello world')
55 |
56 | ims = []
57 | txts = []
58 | links = []
59 | for n in range(4):
60 | ims.append('image_%d.png' % n)
61 | txts.append('text_%d' % n)
62 | links.append('image_%d.png' % n)
63 | html.add_images(ims, txts, links)
64 | html.save()
65 |
--------------------------------------------------------------------------------
/util/image_pool.py:
--------------------------------------------------------------------------------
1 | import random
2 | import numpy as np
3 | import torch
4 | from torch.autograd import Variable
5 |
6 |
7 | class ImagePool():
8 | def __init__(self, pool_size):
9 | self.pool_size = pool_size
10 | if self.pool_size > 0:
11 | self.num_imgs = 0
12 | self.images = []
13 |
14 | def query(self, images):
15 | if self.pool_size == 0:
16 | return Variable(images)
17 | return_images = []
18 | for image in images:
19 | image = torch.unsqueeze(image, 0)
20 | if self.num_imgs < self.pool_size:
21 | self.num_imgs = self.num_imgs + 1
22 | self.images.append(image)
23 | return_images.append(image)
24 | else:
25 | p = random.uniform(0, 1)
26 | if p > 0.5:
27 | random_id = random.randint(0, self.pool_size-1)
28 | tmp = self.images[random_id].clone()
29 | self.images[random_id] = image
30 | return_images.append(tmp)
31 | else:
32 | return_images.append(image)
33 | return_images = Variable(torch.cat(return_images, 0))
34 | return return_images
35 |
--------------------------------------------------------------------------------
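When `pool_size > 0`, `query` sometimes swaps the current fake image for an older one from its buffer (a common discriminator-stabilisation trick); with `pool_size = 0`, as in the default training options, it is a pass-through. A minimal sketch:

```python
# Illustrative sketch: querying the image pool with a small fake batch.
import torch
from util.image_pool import ImagePool

pool = ImagePool(pool_size=50)
fakes = torch.rand(4, 3, 256, 256)      # stand-in for generator outputs
mixed = pool.query(fakes)               # same shape; samples may come from the history
print(mixed.shape)
```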
/util/png.py:
--------------------------------------------------------------------------------
1 | import struct
2 | import zlib
3 |
4 | def encode(buf, width, height):
5 | """ buf: must be bytes or a bytearray in py3, a regular string in py2. formatted RGBRGB... """
6 | assert (width * height * 3 == len(buf))
7 | bpp = 3
8 |
9 | def raw_data():
10 | # reverse the vertical line order and add null bytes at the start
11 | row_bytes = width * bpp
12 | for row_start in range((height - 1) * width * bpp, -1, -row_bytes):
13 | yield b'\x00'
14 | yield buf[row_start:row_start + row_bytes]
15 |
16 | def chunk(tag, data):
17 | return [
18 | struct.pack("!I", len(data)),
19 | tag,
20 | data,
21 | struct.pack("!I", 0xFFFFFFFF & zlib.crc32(data, zlib.crc32(tag)))
22 | ]
23 |
24 | SIGNATURE = b'\x89PNG\r\n\x1a\n'
25 | COLOR_TYPE_RGB = 2
26 | COLOR_TYPE_RGBA = 6
27 | bit_depth = 8
28 | return b''.join(
29 | [ SIGNATURE ] +
30 | chunk(b'IHDR', struct.pack("!2I5B", width, height, bit_depth, COLOR_TYPE_RGB, 0, 0, 0)) +
31 | chunk(b'IDAT', zlib.compress(b''.join(raw_data()), 9)) +
32 | chunk(b'IEND', b'')
33 | )
34 |
--------------------------------------------------------------------------------
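A minimal sketch of using `encode` (the output filename is arbitrary): it expects a row-major RGB byte buffer of length `width * height * 3` and returns the complete PNG byte string.

```python
# Minimal usage sketch for util/png.py: write a solid 2x2 red PNG.
from util.png import encode

width, height = 2, 2
rgb = bytes([255, 0, 0] * width * height)   # RGBRGB..., row-major
with open('red_2x2.png', 'wb') as f:        # arbitrary output name
    f.write(encode(rgb, width, height))
```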
/util/pose_utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from skimage.draw import circle, line_aa, polygon
3 | import json
4 |
5 | import matplotlib
6 | matplotlib.use('Agg')
7 | import matplotlib.pyplot as plt
8 | import matplotlib.patches as mpatches
9 |
10 | LIMB_SEQ = [[1,2], [1,5], [2,3], [3,4], [5,6], [6,7], [1,8], [8,9],
11 | [9,10], [1,11], [11,12], [12,13], [1,0], [0,14], [14,16],
12 | [0,15], [15,17], [2,16], [5,17]]
13 |
14 | COLORS = [[255, 0, 0], [255, 85, 0], [255, 170, 0], [255, 255, 0], [170, 255, 0], [85, 255, 0], [0, 255, 0],
15 | [0, 255, 85], [0, 255, 170], [0, 255, 255], [0, 170, 255], [0, 85, 255], [0, 0, 255], [85, 0, 255],
16 | [170, 0, 255], [255, 0, 255], [255, 0, 170], [255, 0, 85]]
17 |
18 |
19 | LABELS = ['nose', 'neck', 'Rsho', 'Relb', 'Rwri', 'Lsho', 'Lelb', 'Lwri',
20 | 'Rhip', 'Rkne', 'Rank', 'Lhip', 'Lkne', 'Lank', 'Leye', 'Reye', 'Lear', 'Rear']
21 |
22 | MISSING_VALUE = -1
23 |
24 |
25 | def map_to_cord(pose_map, threshold=0.1):
26 | all_peaks = [[] for i in range(18)]
27 | pose_map = pose_map[..., :18]
28 |
29 | y, x, z = np.where(np.logical_and(pose_map == pose_map.max(axis = (0, 1)),
30 | pose_map > threshold))
31 | for x_i, y_i, z_i in zip(x, y, z):
32 | all_peaks[z_i].append([x_i, y_i])
33 |
34 | x_values = []
35 | y_values = []
36 |
37 | for i in range(18):
38 | if len(all_peaks[i]) != 0:
39 | x_values.append(all_peaks[i][0][0])
40 | y_values.append(all_peaks[i][0][1])
41 | else:
42 | x_values.append(MISSING_VALUE)
43 | y_values.append(MISSING_VALUE)
44 |
45 | return np.concatenate([np.expand_dims(y_values, -1), np.expand_dims(x_values, -1)], axis=1)
46 |
47 |
48 | def cords_to_map(cords, img_size, old_size=None, affine_matrix=None, sigma=6):
49 | old_size = img_size if old_size is None else old_size
50 | cords = cords.astype(float)
51 | result = np.zeros(img_size + cords.shape[0:1], dtype='float32')
52 | for i, point in enumerate(cords):
53 | if point[0] == MISSING_VALUE or point[1] == MISSING_VALUE:
54 | continue
55 | point[0] = point[0]/old_size[0] * img_size[0]
56 | point[1] = point[1]/old_size[1] * img_size[1]
57 | if affine_matrix is not None:
58 | point_ =np.dot(affine_matrix, np.matrix([point[1], point[0], 1]).reshape(3,1))
59 | point_0 = int(point_[1])
60 | point_1 = int(point_[0])
61 | else:
62 | point_0 = int(point[0])
63 | point_1 = int(point[1])
64 | xx, yy = np.meshgrid(np.arange(img_size[1]), np.arange(img_size[0]))
65 | result[..., i] = np.exp(-((yy - point_0) ** 2 + (xx - point_1) ** 2) / (2 * sigma ** 2))
66 | return result
67 |
68 |
69 | def draw_pose_from_cords(pose_joints, img_size, radius=2, draw_joints=True):
70 | colors = np.zeros(shape=img_size + (3, ), dtype=np.uint8)
71 | mask = np.zeros(shape=img_size, dtype=bool)
72 |
73 | if draw_joints:
74 | for f, t in LIMB_SEQ:
75 | from_missing = pose_joints[f][0] == MISSING_VALUE or pose_joints[f][1] == MISSING_VALUE
76 | to_missing = pose_joints[t][0] == MISSING_VALUE or pose_joints[t][1] == MISSING_VALUE
77 | if from_missing or to_missing:
78 | continue
79 | yy, xx, val = line_aa(pose_joints[f][0], pose_joints[f][1], pose_joints[t][0], pose_joints[t][1])
80 | colors[yy, xx] = np.expand_dims(val, 1) * 255
81 | mask[yy, xx] = True
82 |
83 | for i, joint in enumerate(pose_joints):
84 | if pose_joints[i][0] == MISSING_VALUE or pose_joints[i][1] == MISSING_VALUE:
85 | continue
86 | yy, xx = circle(joint[0], joint[1], radius=radius, shape=img_size)
87 | colors[yy, xx] = COLORS[i]
88 | mask[yy, xx] = True
89 |
90 | return colors, mask
91 |
92 |
93 | def draw_pose_from_map(pose_map, threshold=0.1, **kwargs):
94 | cords = map_to_cord(pose_map, threshold=threshold)
95 | return draw_pose_from_cords(cords, pose_map.shape[:2], **kwargs)
96 |
97 |
98 | def load_pose_cords_from_strings(y_str, x_str):
99 | y_cords = json.loads(y_str)
100 | x_cords = json.loads(x_str)
101 | return np.concatenate([np.expand_dims(y_cords, -1), np.expand_dims(x_cords, -1)], axis=1)
102 |
103 | def mean_inputation(X):
104 | X = X.copy()
105 | for i in range(X.shape[1]):
106 | for j in range(X.shape[2]):
107 | val = np.mean(X[:, i, j][X[:, i, j] != -1])
108 | X[:, i, j][X[:, i, j] == -1] = val
109 | return X
110 |
111 | def draw_legend():
112 | handles = [mpatches.Patch(color=np.array(color) / 255.0, label=name) for color, name in zip(COLORS, LABELS)]
113 | plt.legend(handles=handles, bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.)
114 |
115 | def produce_ma_mask(kp_array, img_size, point_radius=4):
116 | from skimage.morphology import dilation, erosion, square
117 | mask = np.zeros(shape=img_size, dtype=bool)
118 | limbs = [[2,3], [2,6], [3,4], [4,5], [6,7], [7,8], [2,9], [9,10],
119 | [10,11], [2,12], [12,13], [13,14], [2,1], [1,15], [15,17],
120 | [1,16], [16,18], [2,17], [2,18], [9,12], [12,6], [9,3], [17,18]]
121 | limbs = np.array(limbs) - 1
122 | for f, t in limbs:
123 | from_missing = kp_array[f][0] == MISSING_VALUE or kp_array[f][1] == MISSING_VALUE
124 | to_missing = kp_array[t][0] == MISSING_VALUE or kp_array[t][1] == MISSING_VALUE
125 | if from_missing or to_missing:
126 | continue
127 |
128 | norm_vec = kp_array[f] - kp_array[t]
129 | norm_vec = np.array([-norm_vec[1], norm_vec[0]])
130 | norm_vec = point_radius * norm_vec / np.linalg.norm(norm_vec)
131 |
132 |
133 | vetexes = np.array([
134 | kp_array[f] + norm_vec,
135 | kp_array[f] - norm_vec,
136 | kp_array[t] - norm_vec,
137 | kp_array[t] + norm_vec
138 | ])
139 | yy, xx = polygon(vetexes[:, 0], vetexes[:, 1], shape=img_size)
140 | mask[yy, xx] = True
141 |
142 | for i, joint in enumerate(kp_array):
143 | if kp_array[i][0] == MISSING_VALUE or kp_array[i][1] == MISSING_VALUE:
144 | continue
145 | yy, xx = circle(joint[0], joint[1], radius=point_radius, shape=img_size)
146 | mask[yy, xx] = True
147 |
148 | mask = dilation(mask, square(5))
149 | mask = erosion(mask, square(5))
150 | return mask
151 |
152 | if __name__ == "__main__":
153 | import pandas as pd
154 | from skimage.io import imread
155 | import pylab as plt
156 | import os
157 | i = 5
158 | df = pd.read_csv('data/market-annotation-train.csv', sep=':')
159 |
160 | for index, row in df.iterrows():
161 | pose_cords = load_pose_cords_from_strings(row['keypoints_y'], row['keypoints_x'])
162 |
163 | colors, mask = draw_pose_from_cords(pose_cords, (128, 64))
164 |
165 | mmm = produce_ma_mask(pose_cords, (128, 64)).astype(float)[..., np.newaxis].repeat(3, axis=-1)
166 | print(mmm.shape)
167 | img = imread('data/market-dataset/train/' + row['name'])
168 |
169 | mmm[mask] = colors[mask]
170 |
171 | print (mmm)
172 | plt.subplot(1, 1, 1)
173 | plt.imshow(mmm)
174 | plt.show()
175 |
--------------------------------------------------------------------------------
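A small sketch of the two central helpers above (Gaussian keypoint heatmaps and stick-figure rendering) with a single hypothetical joint; it assumes the pinned environment, where `skimage.draw.circle` still exists.

```python
# Illustrative sketch: one visible joint (the nose), all others marked missing.
import numpy as np
from util.pose_utils import cords_to_map, draw_pose_from_cords, MISSING_VALUE

cords = np.full((18, 2), MISSING_VALUE)   # 18 joints, (y, x) order
cords[0] = [40, 120]                      # hypothetical nose position
heatmap = cords_to_map(cords, (256, 256), sigma=6)       # (256, 256, 18) float32
colors, mask = draw_pose_from_cords(cords, (256, 256))   # uint8 image + bool mask
print(heatmap.shape, colors.shape, mask.sum())
```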
/util/util.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 | import torch
3 | import numpy as np
4 | from PIL import Image
5 | import inspect, re
6 | import os
7 | import collections
8 | from skimage.draw import circle, line_aa
9 |
10 |
11 |
12 | # Converts a Tensor into a Numpy array
13 | # |imtype|: the desired type of the converted numpy array
14 | def tensor2im(image_tensor, imtype=np.uint8):
15 | image_numpy = image_tensor[0].cpu().float().numpy()
16 | if image_numpy.shape[0] == 1:
17 | image_numpy = np.tile(image_numpy, (3, 1, 1))
18 | image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0
19 | return image_numpy.astype(imtype)
20 |
21 |
22 | LIMB_SEQ = [[1, 2], [1, 5], [2, 3], [3, 4], [5, 6], [6, 7], [1, 8], [8, 9],
23 | [9, 10], [1, 11], [11, 12], [12, 13], [1, 0], [0, 14], [14, 16],
24 | [0, 15], [15, 17]]
25 |
26 | # draw dis img
27 | LIMB_SEQ_DIS = [[1, 2], [1, 5], [2, 3], [3, 4], [5, 6], [6, 7], [1, 8], [8, 9],
28 | [9, 10], [1, 11], [11, 12], [12, 13]]
29 |
30 | COLORS = [[255, 0, 0], [255, 85, 0], [255, 170, 0], [255, 255, 0], [170, 255, 0], [85, 255, 0], [0, 255, 0],
31 | [0, 255, 85], [0, 255, 170], [0, 255, 255], [0, 170, 255], [0, 85, 255], [0, 0, 255], [85, 0, 255],
32 | [170, 0, 255], [255, 0, 255], [255, 0, 170], [255, 0, 85]]
33 |
34 | LABELS = ['nose', 'neck', 'Rsho', 'Relb', 'Rwri', 'Lsho', 'Lelb', 'Lwri',
35 | 'Rhip', 'Rkne', 'Rank', 'Lhip', 'Lkne', 'Lank', 'Leye', 'Reye', 'Lear', 'Rear']
36 |
37 | MISSING_VALUE = -1
38 |
39 |
40 | def map_to_cord(pose_map, threshold=0.1):
41 | all_peaks = [[] for i in range(18)]
42 | pose_map = pose_map[..., :18]
43 |
44 | if torch.is_tensor(pose_map):
45 | pose_map = pose_map.cpu()
46 | try:
47 | y, x, z = np.where(np.logical_and(pose_map == 1.0, pose_map > threshold))
48 | except:
49 | print(np.where(np.logical_and(pose_map == 1.0, pose_map > threshold)))
50 | print(pose_map.shape)
51 | for x_i, y_i, z_i in zip(x, y, z):
52 | all_peaks[z_i].append([x_i, y_i])
53 |
54 | x_values = []
55 | y_values = []
56 |
57 | for i in range(18):
58 | if len(all_peaks[i]) != 0:
59 | x_values.append(all_peaks[i][0][0])
60 | y_values.append(all_peaks[i][0][1])
61 | else:
62 | x_values.append(MISSING_VALUE)
63 | y_values.append(MISSING_VALUE)
64 |
65 | return np.concatenate([np.expand_dims(y_values, -1), np.expand_dims(x_values, -1)], axis=1)
66 |
67 |
68 | def draw_pose_from_map(pose_map, threshold=0.1, **kwargs):
69 | # CHW -> HCW -> HWC
70 | pose_map = pose_map[0].cpu().transpose(1, 0).transpose(2, 1).numpy()
71 |
72 | cords = map_to_cord(pose_map, threshold=threshold)
73 | return draw_pose_from_cords(cords, pose_map.shape[:2], **kwargs)
74 |
75 |
76 | def draw_dis_from_map(pose_map, threshold=0.1, **kwargs):
77 | # CHW -> HCW -> HWC
78 | # print(pose_map.shape)
79 | if torch.is_tensor(pose_map):
80 | pose_map = pose_map[0].cpu().transpose(1, 0).transpose(2, 1).numpy()
81 | print(pose_map.shape)
82 | cords = map_to_cord(pose_map, threshold=threshold)
83 | return draw_dis_from_cords(cords, pose_map.shape[:2], **kwargs)
84 |
85 |
86 |
87 | # draw pose from map
88 | def draw_pose_from_cords(pose_joints, img_size, radius=2, draw_joints=True):
89 | colors = np.zeros(shape=img_size + (3,), dtype=np.uint8)
90 | mask = np.zeros(shape=img_size, dtype=bool)
91 |
92 | if draw_joints:
93 | for f, t in LIMB_SEQ:
94 | from_missing = pose_joints[f][0] == MISSING_VALUE or pose_joints[f][1] == MISSING_VALUE
95 | to_missing = pose_joints[t][0] == MISSING_VALUE or pose_joints[t][1] == MISSING_VALUE
96 | if from_missing or to_missing:
97 | continue
98 | yy, xx, val = line_aa(pose_joints[f][0], pose_joints[f][1], pose_joints[t][0], pose_joints[t][1])
99 | colors[yy, xx] = np.expand_dims(val, 1) * 255
100 | mask[yy, xx] = True
101 |
102 | for i, joint in enumerate(pose_joints):
103 | if pose_joints[i][0] == MISSING_VALUE or pose_joints[i][1] == MISSING_VALUE:
104 | continue
105 | yy, xx = circle(joint[0], joint[1], radius=radius, shape=img_size)
106 | colors[yy, xx] = COLORS[i]
107 | mask[yy, xx] = True
108 |
109 | return colors, mask
110 |
111 |
112 | # point to line distance
113 | def get_distance_from_point_to_line(point, line_point1, line_point2):
114 |
115 | if line_point1 == line_point2:
116 | point_array = np.array(point)
117 | point1_array = np.array(line_point1)
118 | aa = np.expand_dims(np.expand_dims(point1_array, -1), -1)
119 | aa = np.repeat(aa, point.shape[1], 1)
120 | aa = np.repeat(aa, point.shape[2], 2)
121 | return np.linalg.norm(point_array - aa)
122 | A = line_point2[0] - line_point1[0]
123 | B = line_point1[1] - line_point2[1]
124 | C = (line_point1[0] - line_point2[0]) * line_point1[1] + \
125 | (line_point2[1] - line_point1[1]) * line_point1[0]
126 | distance = np.abs(A * point[1] + B * point[0] + C) / (np.sqrt(A ** 2 + B ** 2))
127 | distance = np.exp(-0.1 * distance)
128 | return distance
129 |
130 |
131 |
132 |
133 | # draw dis from map
134 | def draw_dis_from_cords(pose_joints, img_size, radius=2, draw_joints=True):
135 | dis = np.zeros(shape=img_size + (12,), dtype=np.float64)
136 | y = np.linspace(0, img_size[0] - 1, img_size[0])
137 | x = np.linspace(0, img_size[1] - 1, img_size[1])
138 | xv, yv = np.meshgrid(x, y)
139 | point = np.concatenate([np.expand_dims(yv, 0), np.expand_dims(xv, 0)], 0)
140 |
141 | for i, (f, t) in enumerate(LIMB_SEQ_DIS):
142 | from_missing = pose_joints[f][0] == MISSING_VALUE or pose_joints[f][1] == MISSING_VALUE
143 | to_missing = pose_joints[t][0] == MISSING_VALUE or pose_joints[t][1] == MISSING_VALUE
144 | if from_missing or to_missing:
145 | continue
146 | dis[:, :, i] = get_distance_from_point_to_line(point, [pose_joints[f][0], pose_joints[f][1]],
147 | [pose_joints[t][0], pose_joints[t][1]])
148 | return dis, np.mean(dis, -1)
149 |
150 |
151 |
152 | def diagnose_network(net, name='network'):
153 | mean = 0.0
154 | count = 0
155 | for param in net.parameters():
156 | if param.grad is not None:
157 | mean += torch.mean(torch.abs(param.grad.data))
158 | count += 1
159 | if count > 0:
160 | mean = mean / count
161 | print(name)
162 | print(mean)
163 |
164 |
165 | def save_image(image_numpy, image_path):
166 | image_pil = Image.fromarray(image_numpy)
167 | image_pil.save(image_path)
168 |
169 |
170 | def info(object, spacing=10, collapse=1):
171 | """Print methods and doc strings.
172 | Takes module, class, list, dictionary, or string."""
173 | methodList = [e for e in dir(object) if isinstance(getattr(object, e), collections.Callable)]
174 | processFunc = collapse and (lambda s: " ".join(s.split())) or (lambda s: s)
175 | print("\n".join(["%s %s" %
176 | (method.ljust(spacing),
177 | processFunc(str(getattr(object, method).__doc__)))
178 | for method in methodList]))
179 |
180 |
181 | def varname(p):
182 | for line in inspect.getframeinfo(inspect.currentframe().f_back)[3]:
183 | m = re.search(r'\bvarname\s*\(\s*([A-Za-z_][A-Za-z0-9_]*)\s*\)', line)
184 | if m:
185 | return m.group(1)
186 |
187 |
188 | def print_numpy(x, val=True, shp=False):
189 | x = x.astype(np.float64)
190 | if shp:
191 | print('shape,', x.shape)
192 | if val:
193 | x = x.flatten()
194 | print('mean = %3.3f, min = %3.3f, max = %3.3f, median = %3.3f, std=%3.3f' % (
195 | np.mean(x), np.min(x), np.max(x), np.median(x), np.std(x)))
196 |
197 |
198 | def mkdirs(paths):
199 | if isinstance(paths, list) and not isinstance(paths, str):
200 | for path in paths:
201 | mkdir(path)
202 | else:
203 | mkdir(paths)
204 |
205 |
206 | def mkdir(path):
207 | if not os.path.exists(path):
208 | os.makedirs(path)
209 |
--------------------------------------------------------------------------------
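For reference, a minimal sketch of the tensor-to-image path used throughout the visualizer (output paths below are arbitrary); it also assumes the pinned scikit-image version, since `util/util.py` imports `skimage.draw.circle` at module level.

```python
# Illustrative sketch: convert a [-1, 1] CHW tensor to a uint8 HWC array and save it.
import torch
from util.util import tensor2im, save_image, mkdirs

fake = torch.rand(1, 3, 256, 256) * 2 - 1   # stand-in for a generator output in [-1, 1]
img = tensor2im(fake)                       # (256, 256, 3) uint8
mkdirs('./results/demo')                    # arbitrary output directory
save_image(img, './results/demo/fake.png')
```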
/util/visualizer.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import os
3 | import ntpath
4 | import time
5 | from . import util
6 | #from . import html
7 |
8 |
9 | class Visualizer():
10 | def __init__(self, opt):
11 | # self.opt = opt
12 | self.display_id = opt.display_id
13 | self.use_html = opt.isTrain and not opt.no_html
14 | self.win_size = opt.display_winsize
15 | self.name = opt.name
16 | self.opt = opt
17 | self.saved = False
18 | if self.display_id > 0:
19 | import visdom
20 | self.vis = visdom.Visdom(port=opt.display_port)
21 |
22 | if self.use_html:
23 | self.web_dir = os.path.join(opt.checkpoints_dir, opt.name, 'web')
24 | self.img_dir = os.path.join(self.web_dir, 'images')
25 | print('create web directory %s...' % self.web_dir)
26 | util.mkdirs([self.web_dir, self.img_dir])
27 | self.log_name = os.path.join(opt.checkpoints_dir, opt.name, 'loss_log.txt')
28 | with open(self.log_name, "a") as log_file:
29 | now = time.strftime("%c")
30 | log_file.write('================ Training Loss (%s) ================\n' % now)
31 |
32 | def reset(self):
33 | self.saved = False
34 |
35 | # |visuals|: dictionary of images to display or save
36 | def display_current_results(self, visuals, epoch, save_result):
37 | if self.display_id > 0: # show images in the browser
38 | ncols = self.opt.display_single_pane_ncols
39 | if ncols > 0:
40 | h, w = next(iter(visuals.values())).shape[:2]
41 | table_css = """<style>
42 | table {border-collapse: separate; border-spacing:4px; white-space:nowrap; text-align:center}
43 | table td {width: %dpx; height: %dpx; padding: 4px; outline: 4px solid black}
44 | </style>""" % (w, h)
45 | title = self.name
46 | label_html = ''
47 | label_html_row = ''
48 | nrows = int(np.ceil(len(visuals.items()) / ncols))
49 | images = []
50 | idx = 0
51 | for label, image_numpy in visuals.items():
52 | label_html_row += '<td>%s</td>' % label
53 | images.append(image_numpy.transpose([2, 0, 1]))
54 | idx += 1
55 | if idx % ncols == 0:
56 | label_html += '<tr>%s</tr>' % label_html_row
57 | label_html_row = ''
58 | white_image = np.ones_like(image_numpy.transpose([2, 0, 1]))*255
59 | while idx % ncols != 0:
60 | images.append(white_image)
61 | label_html_row += '<td></td>'
62 | idx += 1
63 | if label_html_row != '':
64 | label_html += '<tr>%s</tr>' % label_html_row
65 | # pane col = image row
66 | self.vis.images(images, nrow=ncols, win=self.display_id + 1,
67 | padding=2, opts=dict(title=title + ' images'))
68 | label_html = '<table>%s</table>' % label_html
69 | self.vis.text(table_css + label_html, win=self.display_id + 2,
70 | opts=dict(title=title + ' labels'))
71 | else:
72 | idx = 1
73 | for label, image_numpy in visuals.items():
74 | self.vis.image(image_numpy.transpose([2, 0, 1]), opts=dict(title=label),
75 | win=self.display_id + idx)
76 | idx += 1
77 |
78 | if self.use_html and (save_result or not self.saved): # save images to a html file
79 | self.saved = True
80 | for label, image_numpy in visuals.items():
81 | img_path = os.path.join(self.img_dir, 'epoch%.3d_%s.png' % (epoch, label))
82 | util.save_image(image_numpy, img_path)
83 | # update website
84 | # webpage = html.HTML(self.web_dir, 'Experiment name = %s' % self.name, reflesh=1)
85 | # for n in range(epoch, 0, -1):
86 | # webpage.add_header('epoch [%d]' % n)
87 | # ims = []
88 | # txts = []
89 | # links = []
90 |
91 | # for label, image_numpy in visuals.items():
92 | # img_path = 'epoch%.3d_%s.png' % (n, label)
93 | # ims.append(img_path)
94 | # txts.append(label)
95 | # links.append(img_path)
96 | # webpage.add_images(ims, txts, links, width=self.win_size)
97 | # webpage.save()
98 |
99 | # errors: dictionary of error labels and values
100 | def plot_current_errors(self, epoch, counter_ratio, opt, errors):
101 | if not hasattr(self, 'plot_data'):
102 | self.plot_data = {'X': [], 'Y': [], 'legend': list(errors.keys())}
103 | self.plot_data['X'].append(epoch + counter_ratio)
104 | self.plot_data['Y'].append([errors[k] for k in self.plot_data['legend']])
105 | self.vis.line(
106 | X=np.stack([np.array(self.plot_data['X'])] * len(self.plot_data['legend']), 1),
107 | Y=np.array(self.plot_data['Y']),
108 | opts={
109 | 'title': self.name + ' loss over time',
110 | 'legend': self.plot_data['legend'],
111 | 'xlabel': 'epoch',
112 | 'ylabel': 'loss'},
113 | win=self.display_id)
114 |
115 | # errors: same format as |errors| of plotCurrentErrors
116 | def print_current_errors(self, epoch, i, errors, t):
117 | message = '(epoch: %d, iters: %d, time: %.3f) ' % (epoch, i, t)
118 | for k, v in errors.items():
119 | message += '%s: %.3f ' % (k, v)
120 |
121 | print(message)
122 | with open(self.log_name, "a") as log_file:
123 | log_file.write('%s\n' % message)
124 |
125 | # save image to the disk
126 | def save_images(self, webpage, visuals, image_path):
127 | image_dir = webpage.get_image_dir()
128 | short_path = ntpath.basename(image_path[0])
129 | name = os.path.splitext(short_path)[0]
130 |
131 | webpage.add_header(name)
132 | ims = []
133 | txts = []
134 | links = []
135 |
136 | for label, image_numpy in visuals.items():
137 | image_name = '%s_%s.jpg' % (image_path[0], label)
138 | save_path = os.path.join(image_dir, image_name)
139 | print(save_path)
140 | util.save_image(image_numpy, save_path)
141 |
142 | ims.append(image_name)
143 | txts.append(label)
144 | links.append(image_name)
145 | webpage.add_images(ims, txts, links, width=self.win_size)
146 |
--------------------------------------------------------------------------------