├── .gitignore ├── alignment.py ├── dataset.py ├── images ├── Meg Ryan │ ├── P00015.jpg │ └── P00015_gen.jpg └── examples.png ├── main_cal_warp_degree.py ├── main_generate.py ├── main_generate_single_image.py ├── networks ├── __init__.py ├── loss.py ├── modules.py ├── styler.py └── warper.py ├── readme.md ├── test_styler.py ├── test_warper.py ├── train_styler.py ├── train_warper.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Created by .ignore support plugin (hsz.mobi) 2 | ### Python template 3 | # Byte-compiled / optimized / DLL files 4 | __pycache__/ 5 | *.py[cod] 6 | *$py.class 7 | 8 | # C extensions 9 | *.so 10 | 11 | # Distribution / packaging 12 | .Python 13 | build/ 14 | develop-eggs/ 15 | dist/ 16 | downloads/ 17 | eggs/ 18 | .eggs/ 19 | lib/ 20 | lib64/ 21 | parts/ 22 | sdist/ 23 | var/ 24 | wheels/ 25 | pip-wheel-metadata/ 26 | share/python-wheels/ 27 | *.egg-info/ 28 | .installed.cfg 29 | *.egg 30 | MANIFEST 31 | 32 | # PyInstaller 33 | # Usually these files are written by a python script from a template 34 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 35 | *.manifest 36 | *.spec 37 | 38 | # Installer logs 39 | pip-log.txt 40 | pip-delete-this-directory.txt 41 | 42 | # Unit test / coverage reports 43 | htmlcov/ 44 | .tox/ 45 | .nox/ 46 | .coverage 47 | .coverage.* 48 | .cache 49 | nosetests.xml 50 | coverage.xml 51 | *.cover 52 | .hypothesis/ 53 | .pytest_cache/ 54 | 55 | # Translations 56 | *.mo 57 | *.pot 58 | 59 | # Django stuff: 60 | *.log 61 | local_settings.py 62 | db.sqlite3 63 | db.sqlite3-journal 64 | 65 | # Flask stuff: 66 | instance/ 67 | .webassets-cache 68 | 69 | # Scrapy stuff: 70 | .scrapy 71 | 72 | # Sphinx documentation 73 | docs/_build/ 74 | 75 | # PyBuilder 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | .python-version 87 | 88 | # pipenv 89 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 90 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 91 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 92 | # install all needed dependencies. 
93 | #Pipfile.lock 94 | 95 | # celery beat schedule file 96 | celerybeat-schedule 97 | 98 | # SageMath parsed files 99 | *.sage.py 100 | 101 | # Environments 102 | .env 103 | .venv 104 | env/ 105 | venv/ 106 | ENV/ 107 | env.bak/ 108 | venv.bak/ 109 | 110 | # Spyder project settings 111 | .spyderproject 112 | .spyproject 113 | 114 | # Rope project settings 115 | .ropeproject 116 | 117 | # mkdocs documentation 118 | /site 119 | 120 | # mypy 121 | .mypy_cache/ 122 | .dmypy.json 123 | dmypy.json 124 | 125 | # Pyre type checker 126 | .pyre/ 127 | 128 | ### JetBrains template 129 | # Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio and WebStorm 130 | # Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839 131 | 132 | # User-specific stuff 133 | .idea/**/workspace.xml 134 | .idea/**/tasks.xml 135 | .idea/**/usage.statistics.xml 136 | .idea/**/dictionaries 137 | .idea/**/shelf 138 | 139 | # Generated files 140 | .idea/**/contentModel.xml 141 | 142 | # Sensitive or high-churn files 143 | .idea/**/dataSources/ 144 | .idea/**/dataSources.ids 145 | .idea/**/dataSources.local.xml 146 | .idea/**/sqlDataSources.xml 147 | .idea/**/dynamic.xml 148 | .idea/**/uiDesigner.xml 149 | .idea/**/dbnavigator.xml 150 | 151 | # Gradle 152 | .idea/**/gradle.xml 153 | .idea/**/libraries 154 | 155 | # Gradle and Maven with auto-import 156 | # When using Gradle or Maven with auto-import, you should exclude module files, 157 | # since they will be recreated, and may cause churn. Uncomment if using 158 | # auto-import. 159 | # .idea/modules.xml 160 | # .idea/*.iml 161 | # .idea/modules 162 | # *.iml 163 | # *.ipr 164 | 165 | # CMake 166 | cmake-build-*/ 167 | 168 | # Mongo Explorer plugin 169 | .idea/**/mongoSettings.xml 170 | 171 | # File-based project format 172 | *.iws 173 | 174 | # IntelliJ 175 | out/ 176 | 177 | # mpeltonen/sbt-idea plugin 178 | .idea_modules/ 179 | 180 | # JIRA plugin 181 | atlassian-ide-plugin.xml 182 | 183 | # Cursive Clojure plugin 184 | .idea/replstate.xml 185 | 186 | # Crashlytics plugin (for Android Studio and IntelliJ) 187 | com_crashlytics_export_strings.xml 188 | crashlytics.properties 189 | crashlytics-build.properties 190 | fabric.properties 191 | 192 | # Editor-based Rest Client 193 | .idea/httpRequests 194 | 195 | # Android studio 3.1+ serialized cache file 196 | .idea/caches/build_file_checksums.ser 197 | 198 | .gitignore 199 | data/ 200 | pretrained/ 201 | -------------------------------------------------------------------------------- /alignment.py: -------------------------------------------------------------------------------- 1 | import os 2 | import math 3 | import linecache 4 | from PIL import Image 5 | import matplotlib.pyplot as plt 6 | import copy 7 | import threading 8 | 9 | num_of_points = 17 10 | time_enlarge = 1.3 11 | resize_w = 256 12 | resize_h = 256 13 | path_data = 'data/WebCaricature' 14 | path_output = 'data/WebCaricature_aligned' 15 | 16 | path_images = os.path.join(path_data, 'OriginalImages') 17 | path_points = os.path.join(path_data, 'FacialPoints') 18 | path_output_images = os.path.join(path_output, 'image') 19 | path_output_points = os.path.join(path_output, 'landmark') 20 | 21 | 22 | def main(): 23 | root = path_points 24 | image_list = list_all_images(root) 25 | start_threads(image_list) 26 | 27 | 28 | def get_point_from_line(path, x): 29 | line = linecache.getline(path, x) 30 | x, y = line.strip().split(' ') 31 | return float(x), float(y) 32 | 33 | 34 | def get_points_from_txt(path): 35 | 
points = [[0] * 2 for i in range(num_of_points)] 36 | for i in range(num_of_points): 37 | x, y = get_point_from_line(path, i+1) 38 | points[i][0] = x 39 | points[i][1] = y 40 | return points 41 | 42 | 43 | def load_landmark(path): 44 | result = [[0] * 2 for i in range(num_of_points)] 45 | with open(path, 'r', encoding='utf-8') as f: 46 | lines = f.readlines() 47 | f.close() 48 | for i in range(17): 49 | result[i][0] = int(lines[i].split('\t')[0]) 50 | result[i][1] = int(lines[i].split('\t')[1]) 51 | return result 52 | 53 | 54 | def get_rotate_angle(points): 55 | xl = (points[8][0] + points[9][0]) / 2 56 | yl = (points[8][1] + points[9][1]) / 2 57 | xr = (points[10][0] + points[11][0]) / 2 58 | yr = (points[10][1] + points[11][1]) / 2 59 | if xl == xr: 60 | if yr > yl: 61 | return 90 62 | elif yr < yl: 63 | return -90 64 | else: 65 | print(points) 66 | raise RuntimeError('x=x,y=y') 67 | tan_x = (yr - yl) / (xr - xl) 68 | x_rad = math.atan(tan_x) 69 | x_angle = (180 * x_rad) / math.pi 70 | return x_angle 71 | 72 | 73 | def rotate_image(name, filename, angle): 74 | path_image = os.path.join(path_images, name, filename) 75 | image = Image.open(path_image) 76 | image_rotated = image.rotate(angle, Image.BILINEAR) 77 | w, h = image.size 78 | return image_rotated, w, h 79 | 80 | 81 | def calculate_new_point(x0, y0, angle, w, h): 82 | if angle == 0: 83 | return x0, y0 84 | angle *= -(math.pi / 180) 85 | x1 = x0 - w / 2 86 | y1 = y0 - h / 2 87 | r_square = x1 * x1 + y1 * y1 88 | if x1 == 0: 89 | tanx = -(1 / math.tan(angle)) 90 | else: 91 | if y1 * math.tan(angle) == x1: 92 | x2 = 0 93 | y2 = math.sqrt(r_square) 94 | return x2, y2 95 | elif 1 - math.tan(angle) * (y1 / x1) == 0: 96 | x2 = 0 97 | y2 = math.sqrt(r_square) 98 | return x2, y2 99 | else: 100 | tanx = (y1 / x1 + math.tan(angle)) / (1 - math.tan(angle) * (y1 / x1)) 101 | x2_square = r_square / (1 + tanx * tanx) 102 | x2_1 = math.sqrt(x2_square) 103 | y2_1 = x2_1 * tanx 104 | x2_2 = -x2_1 105 | y2_2 = -y2_1 106 | d1_square = (x1 - x2_1) * (x1 - x2_1) + (y1 - y2_1) * (y1 - y2_1) 107 | d2_square = (x1 - x2_2) * (x1 - x2_2) + (y1 - y2_2) * (y1 - y2_2) 108 | if d1_square < d2_square: 109 | x2 = x2_1 110 | y2 = y2_1 111 | else: 112 | x2 = x2_2 113 | y2 = y2_2 114 | x2 += w / 2 115 | y2 += h / 2 116 | return x2, y2 117 | 118 | 119 | def calculate_new_points(points, angle, w, h): 120 | if angle == 0: 121 | return points 122 | else: 123 | result = [[0] * 2 for i in range(num_of_points)] 124 | for i in range(num_of_points): 125 | x0 = points[i][0] 126 | y0 = points[i][1] 127 | x2, y2 = calculate_new_point(x0, y0, angle, w, h) 128 | result[i][0] = x2 129 | result[i][1] = y2 130 | return result 131 | 132 | 133 | def calculate_boundingbox(points): 134 | max_list = [] 135 | min_list = [] 136 | for j in range(len(points[0])): 137 | list = [] 138 | for i in range(len(points)): 139 | list.append(points[i][j]) 140 | max_list.append(max(list)) 141 | min_list.append(min(list)) 142 | x_max = max_list[0] 143 | y_max = max_list[1] 144 | x_min = min_list[0] 145 | y_min = min_list[1] 146 | 147 | delta_x = x_max - x_min 148 | delta_y = y_max - y_min 149 | length = abs(delta_x - delta_y) / 2 150 | if delta_x > delta_y: 151 | y_min -= length 152 | y_max += length 153 | else: 154 | x_min -= length 155 | x_max += length 156 | 157 | return x_max, x_min, y_max, y_min 158 | 159 | 160 | def enlarge(x_max, x_min, y_max, y_min, time_enlarge, w, h): 161 | nx_max = x_max + (time_enlarge - 1) * (x_max - x_min) / 2 162 | ny_max = y_max + (time_enlarge - 1) * (y_max - y_min) / 2 
163 | nx_min = x_min - (time_enlarge - 1) * (x_max - x_min) / 2 164 | ny_min = y_min - (time_enlarge - 1) * (y_max - y_min) / 2 165 | return nx_max, nx_min, ny_max, ny_min 166 | 167 | 168 | def look(img, points, savepath): 169 | plt.clf() 170 | xs = [points[i][0] for i in range(17)] 171 | ys = [points[i][1] for i in range(17)] 172 | plt.imshow(img) 173 | plt.scatter(xs, ys, s=16) 174 | plt.savefig(savepath) 175 | plt.close() 176 | 177 | 178 | def update_landmark_cropped(landmark, x_min, y_min): 179 | result = copy.deepcopy(landmark) 180 | for i in range(17): 181 | result[i][0] = landmark[i][0] - x_min 182 | result[i][1] = landmark[i][1] - y_min 183 | return result 184 | 185 | 186 | def update_landmark_enlarged(landmark, w, h, resize_w, resize_h): 187 | time_w = resize_w / w 188 | time_h = resize_h / h 189 | for i in range(17): 190 | landmark[i][0] = landmark[i][0] * time_w 191 | landmark[i][1] = landmark[i][1] * time_h 192 | return landmark 193 | 194 | 195 | def save_landmark(landmark, dir, filename): 196 | if not os.path.exists(dir): 197 | os.makedirs(dir) 198 | path = os.path.join(dir, filename) 199 | with open(path, 'w', encoding='utf-8') as f: 200 | for i in range(17): 201 | f.write(str(int(landmark[i][0])) + '\t' + str(int(landmark[i][1])) + '\n') 202 | f.close() 203 | 204 | 205 | def list_all_images(root): 206 | result = [] 207 | for name in os.listdir(root): 208 | for file in os.listdir(os.path.join(root, name)): 209 | result.append(os.path.join(root, name, file)) 210 | return result 211 | 212 | 213 | def start_threads(image_list, n_threads=16): 214 | if n_threads > len(image_list): 215 | n_threads = len(image_list) 216 | n = int(math.ceil(len(image_list) / float(n_threads))) 217 | print('the thread num is {}'.format(n_threads)) 218 | print('each thread images num is {}'.format(n)) 219 | image_lists = [image_list[index:index + n] for index in range(0, len(image_list), n)] 220 | thread_list = {} 221 | for thread_id in range(n_threads): 222 | thread_list[thread_id] = MyThread(image_lists[thread_id], thread_id) 223 | thread_list[thread_id].start() 224 | 225 | for thread_id in range(n_threads): 226 | thread_list[thread_id].join() 227 | 228 | 229 | class MyThread(threading.Thread): 230 | def __init__(self, image_list, thread_id): 231 | threading.Thread.__init__(self) 232 | self.image_list = image_list 233 | self.thread_id = thread_id 234 | 235 | def run(self): 236 | print('thread {} begin'.format(self.thread_id)) 237 | 238 | image_len = len(self.image_list) 239 | print_interval = image_len // 100 240 | print_interval = print_interval if print_interval > 0 else 1 241 | 242 | for index, image_path in enumerate(self.image_list): 243 | name = image_path.split('/')[-2] 244 | filename = image_path.split('/')[-1][:-4] 245 | try: 246 | path_point_txt = image_path 247 | points = get_points_from_txt(path_point_txt) 248 | angle = get_rotate_angle(points) 249 | image_rotated, w, h = rotate_image(name, filename + '.jpg', angle) 250 | points_rotated = calculate_new_points(points, angle, w, h) 251 | x_max, x_min, y_max, y_min = calculate_boundingbox(points_rotated) 252 | x_max, x_min, y_max, y_min = enlarge(x_max, x_min, y_max, y_min, time_enlarge, w, h) 253 | image_cropped = image_rotated.crop((x_min, y_min, x_max, y_max)) 254 | points_cropped = update_landmark_cropped(points_rotated, x_min, y_min) 255 | image_result = image_cropped.resize((resize_w, resize_h), Image.BILINEAR) 256 | w_cropped, h_cropped = image_cropped.size 257 | points_result = update_landmark_enlarged(points_cropped, w_cropped, 
h_cropped, resize_w, resize_h) 258 | dir = os.path.join(path_output_points, name) 259 | save_landmark(points_result, dir, filename + '.txt') 260 | path_output_image = os.path.join(path_output_images, name) 261 | if not os.path.exists(path_output_image): 262 | os.makedirs(path_output_image) 263 | image_result.save(os.path.join(path_output_image, filename + '.jpg'), 'jpeg') 264 | except ZeroDivisionError: 265 | print(name) 266 | print(filename) 267 | if index % print_interval == 0 and index > 0: 268 | print('{}/{} in thread {} has been sloven' 269 | .format(index, image_len, self.thread_id)) 270 | 271 | print('thread {} end.'.format(self.thread_id)) 272 | 273 | 274 | if __name__ == '__main__': 275 | main() 276 | -------------------------------------------------------------------------------- /dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import numpy as np 4 | 5 | import torch 6 | import torch.utils.data as data 7 | import torchvision.transforms.functional as transF 8 | from PIL import Image 9 | import torchvision.transforms as transforms 10 | 11 | from utils import make_init_field, load_filenames, warp_position_map 12 | 13 | transform = transforms.Compose([ 14 | transforms.ToTensor(), 15 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), 16 | ] 17 | ) 18 | 19 | flip_index = [0, 3, 2, 1, 7, 6, 5, 4, 11, 10, 9, 8, 12, 13, 16, 15, 14] 20 | 21 | 22 | def make_dataset(args): 23 | list_total = os.listdir(os.path.join(args.data_root, 'image')) 24 | train_list = random.sample(list_total, len(list_total) // 2) 25 | if args.mode == 'train': 26 | return WCDataSet(args, train_list, 'train') 27 | else: 28 | test_list = list(set(list_total).difference(set(train_list))) 29 | return WCDataSet(args, test_list, 'test') 30 | 31 | 32 | def load_img(path): 33 | return Image.open(path).convert('RGB') 34 | 35 | 36 | def load_landmark(path): 37 | return torch.from_numpy(np.loadtxt(path, delimiter='\t')).float() 38 | 39 | 40 | def cal_field(src, dst): 41 | field_y, field_x = warp_position_map(src, dst) 42 | field = np.concatenate((field_y, field_x), axis=2) 43 | field = torch.from_numpy(field).float() 44 | return field 45 | 46 | 47 | def random_horizonal_flip(img, landmark): 48 | if random.random() < 0.5: 49 | img_hflip = transF.hflip(img) 50 | landmark[:, 0] = 256 - landmark[:, 0] 51 | landmark_hflip = torch.zeros_like(landmark) 52 | for i in range(landmark.shape[0]): 53 | landmark_hflip[i] = landmark[flip_index[i]] 54 | return img_hflip, landmark_hflip 55 | else: 56 | return img, landmark 57 | 58 | 59 | def random_enlarge_landmark(landmark_mean, landmark, p=0.5): 60 | if random.random() < p: 61 | rate = random.random() / 5 + 1 62 | else: 63 | rate = 1 64 | landmark = (landmark - landmark_mean) * (rate - 1) + landmark 65 | return landmark, rate 66 | 67 | 68 | def random_resize_crop(img, landmark, resize=288): 69 | if random.random() < 0.5: 70 | w, h = img.size 71 | time = resize / w 72 | img = img.resize((resize, resize), Image.BILINEAR) 73 | x = random.random() * (time - 1) * w 74 | y = random.random() * (time - 1) * h 75 | img_crop = img.crop((x, y, x + w, y + h)) 76 | landmark_crop = landmark * time 77 | landmark_crop[:, 0] = landmark_crop[:, 0] - x 78 | landmark_crop[:, 1] = landmark_crop[:, 1] - y 79 | return img_crop, landmark_crop 80 | else: 81 | return img, landmark 82 | 83 | 84 | class WCDataSet(data.Dataset): 85 | def __init__(self, args, name_list, mode='train', transform=transform): 86 | super(WCDataSet, 
self).__init__() 87 | self.data_root = args.data_root 88 | self.name_list = name_list 89 | self.mode = mode 90 | self.enlarge = args.enlarge 91 | self.hflip = args.hflip 92 | self.resize_crop = args.resize_crop 93 | self.same_id = args.same_id 94 | self.transform = transform 95 | self.const_map = make_init_field().squeeze(0) 96 | 97 | self.p_dir = {name: load_filenames(self.data_root, name, 'image', 'P') for name in name_list} 98 | self.c_dir = {name: load_filenames(self.data_root, name, 'image', 'C') for name in name_list} 99 | 100 | self.num_p = sum([len(files) for files in self.p_dir.values()]) 101 | self.num_c = sum([len(files) for files in self.c_dir.values()]) 102 | 103 | print('load dataset over') 104 | print('{} image: {}, {} caricature: {}'.format(self.mode, self.num_p, self.mode, self.num_c)) 105 | 106 | self.img_list = [] 107 | for name in name_list: 108 | self.img_list += self.p_dir[name] 109 | assert len(self.img_list) == self.num_p 110 | 111 | if self.mode == 'train': 112 | self.size = min(args.max_dataset_size, 100000) 113 | self.landmark_mean = torch.zeros(17, 2) 114 | for img_path in self.img_list: 115 | landmark_path = img_path.replace('image', 'landmark').replace('.jpg', '.txt') 116 | landmark = load_landmark(landmark_path) 117 | self.landmark_mean += landmark 118 | self.landmark_mean /= self.num_p 119 | else: 120 | self.size = self.num_p 121 | 122 | def sample_pair(self, same_id=True): 123 | name1 = random.choice(self.name_list) 124 | img_p_path = random.choice(self.p_dir[name1]) 125 | landmark_p_path = img_p_path.replace('image', 'landmark').replace('.jpg', '.txt') 126 | 127 | if same_id: 128 | img_c_path = random.choice(self.c_dir[name1]) 129 | else: 130 | name2 = random.choice(self.name_list) 131 | img_c_path = random.choice(self.c_dir[name2]) 132 | landmark_c_path = img_c_path.replace('image', 'landmark').replace('.jpg', '.txt') 133 | return img_p_path, img_c_path, landmark_p_path, landmark_c_path 134 | 135 | def __getitem__(self, index): 136 | if self.mode == 'train': 137 | img_p_path, img_c_path, landmark_p_path, landmark_c_path = self.sample_pair(same_id=self.same_id) 138 | 139 | img_p = load_img(img_p_path) 140 | img_c = load_img(img_c_path) 141 | landmark_p = load_landmark(landmark_p_path) 142 | landmark_c = load_landmark(landmark_c_path) 143 | 144 | if self.hflip: 145 | img_p, landmark_p = random_horizonal_flip(img_p, landmark_p) 146 | img_c, landmark_c = random_horizonal_flip(img_c, landmark_c) 147 | 148 | if self.resize_crop: 149 | img_p, landmark_p = random_resize_crop(img_p, landmark_p) 150 | img_c, landmark_c = random_resize_crop(img_c, landmark_c) 151 | 152 | if self.enlarge: 153 | landmark_c, _ = random_enlarge_landmark(self.landmark_mean, landmark_c) 154 | 155 | img_p = self.transform(img_p) 156 | img_c = self.transform(img_c) 157 | 158 | field_m2c = cal_field(self.landmark_mean, landmark_c) 159 | field_m2p = cal_field(self.landmark_mean, landmark_p) 160 | field_p2c = cal_field(landmark_p, landmark_c) 161 | 162 | item = { 163 | 'name': os.path.basename(os.path.dirname(img_p_path)), 164 | 'filename': os.path.basename(img_p_path)[:-4], 165 | 'img_p': img_p, 166 | 'img_c': img_c, 167 | 'field_p2c': field_p2c, 168 | 'field_m2c': field_m2c, 169 | 'field_m2p': field_m2p, 170 | } 171 | else: 172 | img_p_path = self.img_list[index] 173 | img_p = load_img(img_p_path) 174 | name = os.path.basename(os.path.dirname(img_p_path)) 175 | 176 | img_p = self.transform(img_p) 177 | 178 | item = { 179 | 'img_p': img_p, 180 | 'name': name, 181 | 'filename': 
os.path.basename(img_p_path)[:-4] 182 | } 183 | 184 | return item 185 | 186 | def __len__(self): 187 | return self.size 188 | 189 | -------------------------------------------------------------------------------- /images/Meg Ryan/P00015.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/edward3862/CariMe-pytorch/b20ddd6d772c3880bac40109790f86c2a1d7212d/images/Meg Ryan/P00015.jpg -------------------------------------------------------------------------------- /images/Meg Ryan/P00015_gen.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/edward3862/CariMe-pytorch/b20ddd6d772c3880bac40109790f86c2a1d7212d/images/Meg Ryan/P00015_gen.jpg -------------------------------------------------------------------------------- /images/examples.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/edward3862/CariMe-pytorch/b20ddd6d772c3880bac40109790f86c2a1d7212d/images/examples.png -------------------------------------------------------------------------------- /main_cal_warp_degree.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | import argparse 4 | import os 5 | 6 | import torch 7 | import numpy as np 8 | from tqdm import tqdm 9 | 10 | from networks import Warper 11 | from utils import str2bool 12 | from dataset import make_dataset 13 | from torch.utils.data import DataLoader 14 | 15 | os.environ['CUDA_VISIBLE_DEVICES'] = '0' 16 | 17 | parser = argparse.ArgumentParser() 18 | parser.add_argument('--data_root', type=str, default='data/WebCaricature_align_1.3_256') 19 | parser.add_argument('--name', type=str, default='results/warper') 20 | parser.add_argument('--model', type=str, default='warper_00020000.pt') 21 | 22 | parser.add_argument('--resize_crop', type=str2bool, default=False) 23 | parser.add_argument('--enlarge', type=str2bool, default=False) 24 | parser.add_argument('--same_id', type=str2bool, default=False) 25 | parser.add_argument('--hflip', type=str2bool, default=False) 26 | 27 | parser.add_argument('--mode', type=str, default='test') 28 | parser.add_argument('--batch_size', type=int, default=4) 29 | parser.add_argument('--num_workers', type=int, default=8) 30 | 31 | parser.add_argument('--img_size', type=int, default=256) 32 | parser.add_argument('--field_size', type=int, default=128) 33 | parser.add_argument('--embedding_dim', type=int, default=32) 34 | parser.add_argument('--warp_dim', type=int, default=64) 35 | parser.add_argument('--scale', type=float, default=1.0) 36 | 37 | args = parser.parse_args() 38 | 39 | 40 | def make_field(length): 41 | temp_height = np.linspace(-1.0, 1.0, num=length).reshape(length, 1, 1) 42 | temp_width = np.linspace(-1.0, 1.0, num=length).reshape(1, length, 1) 43 | 44 | pos_x = np.repeat(temp_height, length, axis=1) 45 | pos_y = np.repeat(temp_width, length, axis=0) 46 | 47 | return np.concatenate((pos_y, pos_x), axis=2) 48 | 49 | 50 | def cal_delta(map1, map2): 51 | y1, x1 = map1[:, :, 0], map1[:, :, 1] 52 | y2, x2 = map2[:, :, 0], map2[:, :, 1] 53 | return np.mean(np.sqrt((x1 - x2) ** 2 + (y1 - y2) ** 2)) 54 | 55 | 56 | if __name__ == '__main__': 57 | SEED = 0 58 | random.seed(SEED) 59 | np.random.seed(SEED) 60 | torch.manual_seed(SEED) 61 | torch.cuda.manual_seed(SEED) 62 | 63 | print(args.name) 64 | print(args.scale) 65 | 66 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 
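# Note on the metric computed below (descriptive comment, not in the original script):
# the loop loads the trained Warper, samples a random exaggeration code z for every test
# photo, and accumulates cal_delta(const, field) * 256 -- the mean offset between the
# identity grid from make_field() (in [-1, 1] coordinates) and the predicted warp field,
# scaled by the 256-pixel image size. The printed mean/std of these values is the
# "warp degree" that this file's name refers to.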
67 | model_path = os.path.join(args.name, 'checkpoints', args.model) 68 | print('load model: ', model_path) 69 | 70 | dataset = make_dataset(args) 71 | dataloader = DataLoader(dataset=dataset, batch_size=args.batch_size, shuffle=False, drop_last=False, 72 | num_workers=args.num_workers) 73 | 74 | warper = Warper(args) 75 | state_dict = torch.load(model_path) 76 | warper.load_state_dict(state_dict) 77 | warper.to(device) 78 | warper.eval() 79 | 80 | deltas = [] 81 | const = make_field(256) 82 | 83 | for batch, item in tqdm(enumerate(dataloader)): 84 | img_p = item['img_p'].to(device) 85 | names = item['name'] 86 | filenames = item['filename'] 87 | 88 | z = torch.randn(img_p.size()[0], args.warp_dim, 1, 1).cuda() 89 | _, fields, _ = warper(img_p, z, scale=args.scale) 90 | 91 | for i in range(img_p.size()[0]): 92 | field = fields[i].detach().cpu().numpy() 93 | deltas.append(cal_delta(const, field) * 256) 94 | print(np.mean(deltas)) 95 | print(np.std(deltas)) 96 | 97 | 98 | -------------------------------------------------------------------------------- /main_generate.py: -------------------------------------------------------------------------------- 1 | import random 2 | import shutil 3 | 4 | import argparse 5 | import os 6 | import torch 7 | import numpy as np 8 | from PIL import Image 9 | import torchvision.transforms as transforms 10 | from torch.utils.data import DataLoader 11 | from tqdm import tqdm 12 | 13 | from dataset import make_dataset 14 | from networks import Warper, Styler 15 | from utils import unload_img, str2bool 16 | 17 | os.environ['CUDA_VISIBLE_DEVICES'] = '1' 18 | 19 | parser = argparse.ArgumentParser() 20 | 21 | parser.add_argument('--data_root', type=str, default='data/WebCaricature_align_1.3_256') 22 | parser.add_argument('--model_path_warper', type=str, default='results/warper/checkpoints/warper_00020000.pt') 23 | parser.add_argument('--model_path_styler', type=str, default='results/styler/checkpoints/gen_00200000.pt') 24 | parser.add_argument('--output_path', type=str, default='results/generated') 25 | 26 | parser.add_argument('--mode', type=str, default='test') 27 | parser.add_argument('--hflip', type=str2bool, default=False) 28 | parser.add_argument('--enlarge', type=str2bool, default=False) 29 | parser.add_argument('--resize_crop', type=str2bool, default=False) 30 | parser.add_argument('--same_id', type=str2bool, default=True) 31 | parser.add_argument('--batch_size', type=int, default=4) 32 | parser.add_argument('--num_workers', type=int, default=8) 33 | 34 | parser.add_argument('--img_size', type=int, default=256) 35 | parser.add_argument('--field_size', type=int, default=128) 36 | parser.add_argument('--embedding_dim', type=int, default=32) 37 | parser.add_argument('--warp_dim', type=int, default=64) 38 | parser.add_argument('--style_dim', type=int, default=8) 39 | parser.add_argument('--scale', type=float, default=1) 40 | parser.add_argument('--generate_num', type=int, default=3) 41 | 42 | 43 | args = parser.parse_args() 44 | 45 | transform = transforms.Compose([ 46 | transforms.ToTensor(), 47 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), 48 | ] 49 | ) 50 | 51 | 52 | def load_img(path): 53 | img = Image.open(path) 54 | img = transform(img) 55 | img = img.unsqueeze(0) 56 | return img 57 | 58 | 59 | if __name__ == '__main__': 60 | 61 | SEED = 0 62 | random.seed(SEED) 63 | np.random.seed(SEED) 64 | torch.manual_seed(SEED) 65 | torch.cuda.manual_seed(SEED) 66 | 67 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 68 | print('load 
warper: ', args.model_path_warper) 69 | print('load styler: ', args.model_path_styler) 70 | output_path = args.output_path 71 | print('output path: ', output_path) 72 | if os.path.exists(output_path): 73 | shutil.rmtree(output_path) 74 | os.makedirs(output_path, exist_ok=True) 75 | 76 | dataset = make_dataset(args) 77 | dataloader = DataLoader(dataset=dataset, batch_size=args.batch_size, shuffle=False, drop_last=False, 78 | num_workers=args.num_workers) 79 | 80 | warper = Warper(args) 81 | warper.load(args.model_path_warper) 82 | warper.to(device) 83 | warper.eval() 84 | 85 | styler = Styler(args) 86 | styler.load(args.model_path_styler) 87 | styler.to(device) 88 | styler.eval() 89 | 90 | num = args.generate_num 91 | 92 | for batch, item in tqdm(enumerate(dataloader)): 93 | img_p = item['img_p'].to(device) 94 | names = item['name'] 95 | filenames = item['filename'] 96 | 97 | results = [] 98 | 99 | for i in range(num): 100 | z = torch.randn(img_p.size()[0], args.warp_dim, 1, 1).cuda() 101 | s = torch.randn(img_p.size()[0], args.style_dim, 1, 1).cuda() 102 | 103 | img_warp, psmap, _ = warper(img_p, z, scale=args.scale) 104 | img_style = styler(img_p, s) 105 | img_warp_style = styler(img_warp, s) 106 | 107 | results.append(img_warp_style.unsqueeze(0)) 108 | results = torch.cat(results, dim=0).detach().cpu() 109 | 110 | for i in range(img_p.size()[0]): 111 | input = img_p[i].detach().cpu() 112 | name = names[i] 113 | filename = filenames[i] 114 | result = results[:, i, :, :, :] 115 | 116 | result = result.permute(1, 2, 0, 3) 117 | result = result.reshape(3, 256, 256 * num) 118 | 119 | output = torch.cat((input, result), dim=2) 120 | unload_img(output).save(os.path.join(output_path, '{}_{}.jpg'.format(name, filename)), 'jpeg') 121 | 122 | -------------------------------------------------------------------------------- /main_generate_single_image.py: -------------------------------------------------------------------------------- 1 | import random 2 | import shutil 3 | 4 | import argparse 5 | import os 6 | import torch 7 | import numpy as np 8 | from networks import Warper, Styler 9 | from utils import load_img, unload_img 10 | 11 | os.environ['CUDA_VISIBLE_DEVICES'] = '1' 12 | 13 | parser = argparse.ArgumentParser() 14 | 15 | parser.add_argument('--input_path', type=str, default='images/Meg Ryan/P00015.jpg') 16 | parser.add_argument('--model_path_warper', type=str, default='results/warper/checkpoints/warper_00020000.pt') 17 | parser.add_argument('--model_path_styler', type=str, default='results/styler/checkpoints/gen_00200000.pt') 18 | 19 | parser.add_argument('--img_size', type=int, default=256) 20 | parser.add_argument('--field_size', type=int, default=128) 21 | parser.add_argument('--embedding_dim', type=int, default=32) 22 | parser.add_argument('--warp_dim', type=int, default=64) 23 | parser.add_argument('--style_dim', type=int, default=8) 24 | parser.add_argument('--scale', type=float, default=1) 25 | parser.add_argument('--generate_num', type=int, default=5) 26 | 27 | args = parser.parse_args() 28 | 29 | 30 | if __name__ == '__main__': 31 | output_path = os.path.join(args.input_path[:-4] + '_gen.jpg') 32 | 33 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 34 | print('load warper: ', args.model_path_warper) 35 | print('load styler: ', args.model_path_styler) 36 | 37 | warper = Warper(args) 38 | warper.load(args.model_path_warper) 39 | warper.to(device) 40 | warper.eval() 41 | 42 | styler = Styler(args) 43 | styler.load(args.model_path_styler) 44 | styler.to(device) 
45 | styler.eval() 46 | 47 | num = args.generate_num 48 | img_p = load_img(args.input_path).to(device) 49 | results = [] 50 | for i in range(num): 51 | z = torch.randn(img_p.size()[0], args.warp_dim, 1, 1).cuda() 52 | img_warp, psmap, _ = warper(img_p, z, scale=args.scale) 53 | 54 | s = torch.randn(img_p.size()[0], args.style_dim, 1, 1).cuda() 55 | img_style = styler(img_p, s) 56 | img_warp_style = styler(img_warp, s) 57 | 58 | results.append(img_warp_style) 59 | 60 | results = torch.cat(results, dim=3) 61 | output = torch.cat([img_p, results], dim=3).squeeze().detach().cpu() 62 | unload_img(output).save(output_path, 'jpeg') 63 | -------------------------------------------------------------------------------- /networks/__init__.py: -------------------------------------------------------------------------------- 1 | from .loss import * 2 | from .modules import * 3 | from .styler import * 4 | from .warper import * -------------------------------------------------------------------------------- /networks/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def l1_loss(input, target): 5 | return torch.mean(torch.abs(input - target)) 6 | 7 | 8 | def mse_loss(input, target): 9 | return torch.mean((input-target)**2) 10 | 11 | 12 | def tv_loss(img): 13 | w_variance = torch.sum(torch.pow(img[:, :, :, :-1] - img[:, :, :, 1:], 2)) 14 | h_variance = torch.sum(torch.pow(img[:, :, :-1, :] - img[:, :, 1:, :], 2)) 15 | loss = h_variance + w_variance 16 | return loss 17 | -------------------------------------------------------------------------------- /networks/modules.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | from torch.nn.parameter import Parameter 5 | 6 | 7 | class ResBlocks(nn.Module): 8 | def __init__(self, num_blocks, dim, norm, activation, pad_type): 9 | super(ResBlocks, self).__init__() 10 | self.model = [] 11 | for i in range(num_blocks): 12 | self.model += [ResBlock(dim, 13 | norm=norm, 14 | activation=activation, 15 | pad_type=pad_type)] 16 | self.model = nn.Sequential(*self.model) 17 | 18 | def forward(self, x): 19 | return self.model(x) 20 | 21 | 22 | class AdaResBlocks(nn.Module): 23 | def __init__(self, num_blocks, dim, restype='adain', dropout=0): 24 | super(AdaResBlocks, self).__init__() 25 | self.num_blocks = num_blocks 26 | for i in range(num_blocks): 27 | self.__setattr__('res' + str(i), AdaResBlock(dim, restype=restype, dropout=dropout)) 28 | 29 | def forward(self, x, gamma, beta): 30 | for i in range(self.num_blocks): 31 | res_block = self.__getattr__('res' + str(i)) 32 | x = res_block(x, gamma, beta) 33 | return x 34 | 35 | 36 | class MLP(nn.Module): 37 | def __init__(self, input_dim, output_dim, dim, n_blk, norm='none', activ='relu'): 38 | 39 | super(MLP, self).__init__() 40 | self.model = [] 41 | self.model += [LinearBlock(input_dim, dim, norm=norm, activation=activ)] 42 | for i in range(n_blk - 2): 43 | self.model += [LinearBlock(dim, dim, norm=norm, activation=activ)] 44 | self.model += [LinearBlock(dim, output_dim, norm='none', activation='none')] 45 | self.model = nn.Sequential(*self.model) 46 | 47 | def forward(self, x): 48 | return self.model(x.view(x.size(0), -1)) 49 | 50 | 51 | class ResBlock(nn.Module): 52 | def __init__(self, dim, norm='in', activation='relu', pad_type='zero'): 53 | super(ResBlock, self).__init__() 54 | model = [] 55 | model += [Conv2dBlock(dim, dim, 3, 1, 1, 56 | 
norm=norm, 57 | activation=activation, 58 | pad_type=pad_type)] 59 | model += [Conv2dBlock(dim, dim, 3, 1, 1, 60 | norm=norm, 61 | activation='none', 62 | pad_type=pad_type)] 63 | self.model = nn.Sequential(*model) 64 | 65 | def forward(self, x): 66 | residual = x 67 | out = self.model(x) 68 | out += residual 69 | return out 70 | 71 | 72 | class AdaResBlock(nn.Module): 73 | def __init__(self, dim, restype='adain', use_bias=False, dropout=0): 74 | super(AdaResBlock, self).__init__() 75 | self.model = [] 76 | self.dropout = dropout 77 | 78 | self.pad1 = nn.ReflectionPad2d(1) 79 | self.conv1 = nn.Conv2d(dim, dim, 3, 1, 0, bias=use_bias) 80 | self.norm1 = AdaIN() if restype == 'adain' else AdaLIN(dim) 81 | self.activ = nn.ReLU(inplace=True) 82 | 83 | self.pad2 = nn.ReflectionPad2d(1) 84 | self.conv2 = nn.Conv2d(dim, dim, 3, 1, 0, bias=use_bias) 85 | self.norm2 = AdaIN() if restype == 'adain' else AdaLIN(dim) 86 | 87 | def forward(self, x, gamma, beta): 88 | residual = x 89 | x = self.pad1(x) 90 | x = self.conv1(x) 91 | x = self.norm1(x, gamma, beta) 92 | x = self.activ(x) 93 | 94 | if self.dropout: 95 | x = F.dropout(x, p=self.dropout) 96 | 97 | x = self.pad2(x) 98 | x = self.conv2(x) 99 | x = self.norm2(x, gamma, beta) 100 | out = x + residual 101 | return out 102 | 103 | 104 | class LinearBlock(nn.Module): 105 | def __init__(self, input_dim, output_dim, norm='none', activation='relu', use_bias=False, dropout=0): 106 | super(LinearBlock, self).__init__() 107 | self.dropout = dropout 108 | # initialize fully connected layer 109 | if norm == 'sn': 110 | self.fc = SpectralNorm(nn.Linear(input_dim, output_dim, bias=use_bias)) 111 | else: 112 | self.fc = nn.Linear(input_dim, output_dim, bias=use_bias) 113 | 114 | # initialize normalization 115 | norm_dim = output_dim 116 | if norm == 'bn': 117 | self.norm = nn.BatchNorm1d(norm_dim) 118 | elif norm == 'in': 119 | self.norm = nn.InstanceNorm1d(norm_dim) 120 | elif norm == 'ln': 121 | self.norm = nn.LayerNorm(norm_dim) 122 | elif norm == 'none' or norm == 'sn': 123 | self.norm = None 124 | else: 125 | assert 0, "Unsupported normalization: {}".format(norm) 126 | 127 | # initialize activation 128 | if activation == 'relu': 129 | self.activation = nn.ReLU(inplace=True) 130 | elif activation == 'lrelu': 131 | self.activation = nn.LeakyReLU(0.2, inplace=True) 132 | elif activation == 'prelu': 133 | self.activation = nn.PReLU() 134 | elif activation == 'selu': 135 | self.activation = nn.SELU(inplace=True) 136 | elif activation == 'sigmoid': 137 | self.activation = nn.Sigmoid() 138 | elif activation == 'tanh': 139 | self.activation = nn.Tanh() 140 | elif activation == 'none': 141 | self.activation = None 142 | else: 143 | assert 0, "Unsupported activation: {}".format(activation) 144 | 145 | def forward(self, x): 146 | out = self.fc(x) 147 | if self.norm: 148 | out = self.norm(out) 149 | if self.activation: 150 | out = self.activation(out) 151 | return out 152 | 153 | 154 | class Conv2dBlock(nn.Module): 155 | def __init__(self, input_dim, output_dim, kernel_size, stride, padding=0, norm='bn', activation='relu', 156 | pad_type='reflect', use_bias=False, dropout=0): 157 | super(Conv2dBlock, self).__init__() 158 | self.use_bias = use_bias 159 | self.dropout = dropout 160 | # initialize padding 161 | if pad_type == 'reflect': 162 | self.pad = nn.ReflectionPad2d(padding) 163 | elif pad_type == 'replicate': 164 | self.pad = nn.ReplicationPad2d(padding) 165 | elif pad_type == 'zero': 166 | self.pad = nn.ZeroPad2d(padding) 167 | else: 168 | assert 0, "Unsupported 
padding type: {}".format(pad_type) 169 | 170 | # initialize normalization 171 | norm_dim = output_dim 172 | if norm == 'bn': 173 | self.norm = nn.BatchNorm2d(norm_dim) 174 | elif norm == 'in': 175 | self.norm = nn.InstanceNorm2d(norm_dim) 176 | elif norm == 'ln': 177 | self.norm = nn.LayerNorm(norm_dim) 178 | elif norm == 'adain': 179 | self.norm = AdaptiveInstanceNorm2d(norm_dim) 180 | elif norm == 'lin': 181 | self.norm = LIN(norm_dim) 182 | elif norm == 'none' or norm == 'sn': 183 | self.norm = None 184 | else: 185 | assert 0, "Unsupported normalization: {}".format(norm) 186 | 187 | # initialize activation 188 | if activation == 'relu': 189 | self.activation = nn.ReLU(inplace=True) 190 | elif activation == 'sigmoid': 191 | self.activation = nn.Sigmoid() 192 | elif activation == 'lrelu': 193 | self.activation = nn.LeakyReLU(0.2, inplace=True) 194 | elif activation == 'prelu': 195 | self.activation = nn.PReLU() 196 | elif activation == 'selu': 197 | self.activation = nn.SELU(inplace=True) 198 | elif activation == 'tanh': 199 | self.activation = nn.Tanh() 200 | elif activation == 'none': 201 | self.activation = None 202 | else: 203 | assert 0, "Unsupported activation: {}".format(activation) 204 | 205 | # initialize convolution 206 | if norm == 'sn': 207 | self.conv = nn.utils.spectral_norm(nn.Conv2d(input_dim, output_dim, kernel_size, stride, bias=self.use_bias)) 208 | else: 209 | self.conv = nn.Conv2d(input_dim, output_dim, kernel_size, stride, bias=self.use_bias) 210 | 211 | def forward(self, x): 212 | x = self.conv(self.pad(x)) 213 | if self.norm: 214 | x = self.norm(x) 215 | if self.activation: 216 | x = self.activation(x) 217 | if self.dropout: 218 | x = F.dropout(x, p=self.dropout) 219 | return x 220 | 221 | 222 | class AdaptiveInstanceNorm2d(nn.Module): 223 | def __init__(self, num_features, eps=1e-5, momentum=0.1): 224 | super(AdaptiveInstanceNorm2d, self).__init__() 225 | self.num_features = num_features 226 | self.eps = eps 227 | self.momentum = momentum 228 | self.weight = None 229 | self.bias = None 230 | self.register_buffer('running_mean', torch.zeros(num_features)) 231 | self.register_buffer('running_var', torch.ones(num_features)) 232 | 233 | def forward(self, x): 234 | assert self.weight is not None and \ 235 | self.bias is not None, "Please assign AdaIN weight first" 236 | b, c = x.size(0), x.size(1) 237 | running_mean = self.running_mean.repeat(b) 238 | running_var = self.running_var.repeat(b) 239 | x_reshaped = x.contiguous().view(1, b * c, *x.size()[2:]) 240 | out = F.batch_norm( 241 | x_reshaped, running_mean, running_var, self.weight, self.bias, 242 | True, self.momentum, self.eps) 243 | return out.view(b, c, *x.size()[2:]) 244 | 245 | def __repr__(self): 246 | return self.__class__.__name__ + '(' + str(self.num_features) + ')' 247 | 248 | 249 | def l2normalize(v, eps=1e-12): 250 | return v / (v.norm() + eps) 251 | 252 | 253 | class SpectralNorm(nn.Module): 254 | """ 255 | Based on the paper "Spectral Normalization for Generative Adversarial Networks" by Takeru Miyato, Toshiki Kataoka, Masanori Koyama, Yuichi Yoshida 256 | and the Pytorch implementation https://github.com/christiancosgrove/pytorch-spectral-normalization-gan 257 | """ 258 | def __init__(self, module, name='weight', power_iterations=1): 259 | super(SpectralNorm, self).__init__() 260 | self.module = module 261 | self.name = name 262 | self.power_iterations = power_iterations 263 | if not self._made_params(): 264 | self._make_params() 265 | 266 | def _update_u_v(self): 267 | u = getattr(self.module, 
self.name + "_u") 268 | v = getattr(self.module, self.name + "_v") 269 | w = getattr(self.module, self.name + "_bar") 270 | 271 | height = w.data.shape[0] 272 | for _ in range(self.power_iterations): 273 | v.data = l2normalize(torch.mv(torch.t(w.view(height,-1).data), u.data)) 274 | u.data = l2normalize(torch.mv(w.view(height,-1).data, v.data)) 275 | 276 | # sigma = torch.dot(u.data, torch.mv(w.view(height,-1).data, v.data)) 277 | sigma = u.dot(w.view(height, -1).mv(v)) 278 | setattr(self.module, self.name, w / sigma.expand_as(w)) 279 | 280 | def _made_params(self): 281 | try: 282 | u = getattr(self.module, self.name + "_u") 283 | v = getattr(self.module, self.name + "_v") 284 | w = getattr(self.module, self.name + "_bar") 285 | return True 286 | except AttributeError: 287 | return False 288 | 289 | def _make_params(self): 290 | w = getattr(self.module, self.name) 291 | 292 | height = w.data.shape[0] 293 | width = w.view(height, -1).data.shape[1] 294 | 295 | u = nn.Parameter(w.data.new(height).normal_(0, 1), requires_grad=False) 296 | v = nn.Parameter(w.data.new(width).normal_(0, 1), requires_grad=False) 297 | u.data = l2normalize(u.data) 298 | v.data = l2normalize(v.data) 299 | w_bar = nn.Parameter(w.data) 300 | 301 | del self.module._parameters[self.name] 302 | 303 | self.module.register_parameter(self.name + "_u", u) 304 | self.module.register_parameter(self.name + "_v", v) 305 | self.module.register_parameter(self.name + "_bar", w_bar) 306 | 307 | def forward(self, *args): 308 | self._update_u_v() 309 | return self.module.forward(*args) 310 | 311 | 312 | class LIN(nn.Module): 313 | def __init__(self, num_features, eps=1e-5): 314 | super(LIN, self).__init__() 315 | self.eps = eps 316 | self.rho = Parameter(torch.Tensor(1, num_features, 1, 1)) 317 | self.gamma = Parameter(torch.Tensor(1, num_features, 1, 1)) 318 | self.beta = Parameter(torch.Tensor(1, num_features, 1, 1)) 319 | self.rho.data.fill_(0.0) 320 | self.gamma.data.fill_(1.0) 321 | self.beta.data.fill_(0.0) 322 | 323 | def forward(self, input): 324 | in_mean, in_var = torch.mean(input, dim=[2, 3], keepdim=True), torch.var(input, dim=[2, 3], keepdim=True) 325 | out_in = (input - in_mean) / torch.sqrt(in_var + self.eps) 326 | ln_mean, ln_var = torch.mean(input, dim=[1, 2, 3], keepdim=True), torch.var(input, dim=[1, 2, 3], keepdim=True) 327 | out_ln = (input - ln_mean) / torch.sqrt(ln_var + self.eps) 328 | out = self.rho.expand(input.shape[0], -1, -1, -1) * out_in + (1-self.rho.expand(input.shape[0], -1, -1, -1)) * out_ln 329 | out = out * self.gamma.expand(input.shape[0], -1, -1, -1) + self.beta.expand(input.shape[0], -1, -1, -1) 330 | 331 | return out 332 | 333 | 334 | class AdaIN(nn.Module): 335 | def __init__(self, eps=1e-5): 336 | super(AdaIN, self).__init__() 337 | self.eps = eps 338 | 339 | def forward(self, x, gamma, beta): 340 | in_mean, in_var = torch.mean(x, dim=[2, 3], keepdim=True), torch.var(x, dim=[2, 3], keepdim=True) 341 | out_in = (x - in_mean) / torch.sqrt(in_var + self.eps) 342 | out = out_in * gamma + beta 343 | return out 344 | 345 | 346 | class AdaLIN(nn.Module): 347 | def __init__(self, dim, eps=1e-5): 348 | super(AdaLIN, self).__init__() 349 | 350 | self.eps = eps 351 | self.rho = Parameter(torch.Tensor(1, dim, 1, 1)) 352 | self.rho.data.fill_(0.9) 353 | 354 | def forward(self, x, gamma, beta): 355 | in_mean, in_var = torch.mean(x, dim=[2, 3], keepdim=True), torch.var(x, dim=[2, 3], keepdim=True) 356 | out_in = (x - in_mean) / torch.sqrt(in_var + self.eps) 357 | ln_mean, ln_var = torch.mean(x, dim=[1, 2, 
3], keepdim=True), torch.var(x, dim=[1, 2, 3], keepdim=True) 358 | out_ln = (x - ln_mean) / torch.sqrt(ln_var + self.eps) 359 | out = self.rho.expand(x.shape[0], -1, -1, -1) * out_in + (1 - self.rho.expand(x.shape[0], -1, -1, -1)) * out_ln 360 | out = out * gamma + beta 361 | return out 362 | 363 | 364 | class RhoClipper(object): 365 | 366 | def __init__(self, min, max): 367 | self.clip_min = min 368 | self.clip_max = max 369 | assert min < max 370 | 371 | def __call__(self, module): 372 | 373 | if hasattr(module, 'rho'): 374 | w = module.rho.data 375 | w = w.clamp(self.clip_min, self.clip_max) 376 | module.rho.data = w -------------------------------------------------------------------------------- /networks/styler.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | from networks.modules import Conv2dBlock, ResBlocks, LinearBlock, AdaResBlocks 8 | 9 | 10 | class ContentEncoder(nn.Module): 11 | def __init__(self, input_dim=3, dim=64, num_res=3, downs=2, norm='in', activation='relu', pad_type='reflect'): 12 | super(ContentEncoder, self).__init__() 13 | self.model = [] 14 | self.model += [Conv2dBlock(input_dim, dim, 7, 1, 3, norm=norm, activation=activation, pad_type=pad_type)] 15 | for i in range(downs): 16 | self.model += [Conv2dBlock(dim, 2 * dim, 4, 2, 1, norm=norm, activation=activation, pad_type=pad_type)] 17 | dim *= 2 18 | self.model += [ResBlocks(num_res, dim, norm=norm, activation='relu', pad_type=pad_type)] 19 | self.model = nn.Sequential(*self.model) 20 | self.output_dim = dim 21 | 22 | def forward(self, x): 23 | return self.model(x) 24 | 25 | 26 | class StyleEncoder(nn.Module): 27 | def __init__(self, input_dim=3, dim=64, downs=2, style_dim=8, norm='none', activation='relu', pad_type='reflect'): 28 | super(StyleEncoder, self).__init__() 29 | self.model = [] 30 | self.model += [Conv2dBlock(input_dim, dim, 7, 1, 3, norm=norm, activation=activation, pad_type=pad_type)] 31 | for i in range(2): 32 | self.model += [Conv2dBlock(dim, 2 * dim, 4, 2, 1, norm=norm, activation=activation, pad_type=pad_type)] 33 | dim *= 2 34 | for i in range(downs - 2): 35 | self.model += [Conv2dBlock(dim, dim, 4, 2, 1, norm=norm, activation=activation, pad_type=pad_type)] 36 | self.model += [nn.AdaptiveAvgPool2d(1)] 37 | self.model += [nn.Conv2d(dim, style_dim, 1, 1, 0)] 38 | self.model = nn.Sequential(*self.model) 39 | self.norm = nn.BatchNorm2d(style_dim) 40 | 41 | def forward(self, x): 42 | x = self.model(x) 43 | # if norm: 44 | # x = self.norm(x) 45 | return x 46 | 47 | 48 | class StyleController(nn.Module): 49 | def __init__(self, style_dim=8, dim=256, norm='ln', activation='relu'): 50 | super(StyleController, self).__init__() 51 | self.model = [] 52 | self.model += [LinearBlock(style_dim, dim, norm=norm, activation=activation)] 53 | self.model += [LinearBlock(dim, dim, norm=norm, activation=activation)] 54 | self.model = nn.Sequential(*self.model) 55 | self.fc_gamma = LinearBlock(256, 256, norm='none', activation='none') 56 | self.fc_beta = LinearBlock(256, 256, norm='none', activation='none') 57 | 58 | def forward(self, x): 59 | x = x.view(x.size(0), -1) 60 | x = self.model(x) 61 | gamma = self.fc_gamma(x).unsqueeze(2).unsqueeze(3) 62 | beta = self.fc_beta(x).unsqueeze(2).unsqueeze(3) 63 | return gamma, beta 64 | 65 | 66 | class Decoder(nn.Module): 67 | def __init__(self, output_dim=3, dim=256, num_res=3, ups=2, restype='adalin', norm='lin', activation='relu', 68 
| pad_type='reflect'): 69 | super(Decoder, self).__init__() 70 | self.res = AdaResBlocks(num_res, dim, restype=restype) 71 | self.model = [] 72 | for i in range(ups): 73 | self.model += [nn.Upsample(scale_factor=2), 74 | Conv2dBlock(dim, dim // 2, 5, 1, 2, norm=norm, activation=activation, pad_type=pad_type)] 75 | dim //= 2 76 | self.model += [Conv2dBlock(dim, output_dim, 7, 1, 3, norm='none', activation='tanh', pad_type=pad_type)] 77 | self.model = nn.Sequential(*self.model) 78 | 79 | def forward(self, x, gamma, beta): 80 | x = self.res(x, gamma, beta) 81 | return self.model(x) 82 | 83 | 84 | class Gen_Style(nn.Module): 85 | def __init__(self, args): 86 | super(Gen_Style, self).__init__() 87 | style_dim = args.style_dim 88 | self.encoder_c = ContentEncoder(3, 64, 4, 2, norm='in', activation='relu', pad_type='reflect') 89 | self.encoder_s = StyleEncoder(3, 64, 2, style_dim, norm='none', activation='relu', pad_type='reflect') 90 | latent_dim = self.encoder_c.output_dim 91 | self.decoder = Decoder(3, latent_dim, 4, 2, restype='adalin', norm='lin', activation='relu', pad_type='reflect') 92 | self.style_controller = StyleController(style_dim, latent_dim, norm='ln', activation='relu') 93 | 94 | def encode(self, x): 95 | content = self.encoder_c(x) 96 | style = self.encoder_s(x) 97 | return content, style 98 | 99 | def decode(self, content, style): 100 | gamma, beta = self.style_controller(style) 101 | output = self.decoder(content, gamma, beta) 102 | return output 103 | 104 | def forward(self, img_p, s): 105 | content, _ = self.encode(img_p) 106 | output = self.decode(content, s) 107 | return output 108 | 109 | 110 | class Dis(nn.Module): 111 | def __init__(self): 112 | super(Dis, self).__init__() 113 | dim = 64 114 | n_layers = 5 115 | model = [] 116 | model += [Conv2dBlock(3, 64, 4, 2, 1, norm='sn', activation='lrelu', use_bias=True, pad_type='reflect')] 117 | 118 | for i in range(1, n_layers - 2): 119 | model += [ 120 | Conv2dBlock(dim, dim * 2, 4, 2, 1, norm='sn', activation='lrelu', use_bias=True, pad_type='reflect')] 121 | dim = dim * 2 122 | 123 | model += [Conv2dBlock(dim, dim * 2, 4, 1, 1, norm='sn', activation='lrelu', use_bias=True, pad_type='reflect')] 124 | 125 | self.pad = nn.ReflectionPad2d(1) 126 | self.conv = nn.utils.spectral_norm( 127 | nn.Conv2d(dim * 2, 1, kernel_size=4, stride=1, padding=1, bias=False)) 128 | 129 | self.model = nn.Sequential(*model) 130 | 131 | def forward(self, x): 132 | x = self.model(x) 133 | x = self.pad(x) 134 | x = self.conv(x) 135 | return x 136 | 137 | def calc_dis_loss(self, img_real, img_fake): 138 | # calculate the loss to train D 139 | logit_real = self.forward(img_real) 140 | logit_fake = self.forward(img_fake) 141 | loss = F.mse_loss(logit_real, torch.ones_like(logit_real).cuda()) \ 142 | + F.mse_loss(logit_fake, torch.zeros_like(logit_fake).cuda()) 143 | return loss 144 | 145 | def calc_gen_loss(self, img_fake): 146 | # calculate the loss to train G 147 | logit_fake = self.forward(img_fake) 148 | loss = F.mse_loss(logit_fake, torch.ones_like(logit_fake).cuda()) 149 | return loss 150 | 151 | 152 | class Styler(nn.Module): 153 | def __init__(self, args): 154 | super(Styler, self).__init__() 155 | self.gen = Gen_Style(args) 156 | self.dis = Dis() 157 | 158 | def encode(self, x): 159 | return self.gen.encode(x) 160 | 161 | def decode(self, content, style): 162 | return self.gen.decode(content, style) 163 | 164 | def forward(self, img_p, s): 165 | output = self.gen(img_p, s) 166 | return output 167 | 168 | def save(self, dir, step): 169 | 
gen_name = os.path.join(dir, 'gen_%08d.pt' % (step + 1)) 170 | torch.save(self.gen.state_dict(), gen_name) 171 | dis_name = os.path.join(dir, 'dis_%08d.pt' % (step + 1)) 172 | torch.save(self.dis.state_dict(), dis_name) 173 | 174 | def load(self, path): 175 | state_dict = torch.load(path) 176 | self.gen.load_state_dict(state_dict) 177 | -------------------------------------------------------------------------------- /networks/warper.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import math 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | from utils import make_init_field 9 | from .modules import Conv2dBlock, LinearBlock 10 | 11 | 12 | class Encoder(nn.Module): 13 | def __init__(self, input_dim=3, dim=64, latent_dim=32, norm='bn', activation='relu', pad_type='reflect'): 14 | super(Encoder, self).__init__() 15 | self.model = [] 16 | self.model += [Conv2dBlock(input_dim, dim, 7, 1, 3, norm=norm, activation=activation, pad_type=pad_type)] 17 | for i in range(2): 18 | self.model += [Conv2dBlock(dim, 2 * dim, 4, 2, 1, norm=norm, activation=activation, pad_type=pad_type)] 19 | dim *= 2 20 | count = int(math.log2(dim / latent_dim)) 21 | for i in range(count): 22 | self.model += [Conv2dBlock(dim, dim // 2, 1, 1, 0, norm=norm, activation=activation, pad_type=pad_type)] 23 | dim //= 2 24 | self.model = nn.Sequential(*self.model) 25 | 26 | self.pool = [] 27 | self.pool += [ 28 | nn.AdaptiveAvgPool2d(1), 29 | nn.BatchNorm2d(latent_dim) 30 | ] 31 | self.pool = nn.Sequential(*self.pool) 32 | 33 | def forward(self, x, norm=False): 34 | x = self.model(x) 35 | embedding = self.pool(x) 36 | return x, embedding 37 | 38 | 39 | class Decoder(nn.Module): 40 | def __init__(self, dim=32, output_dim=3, norm='bn', activation='relu', pad_type='reflect'): 41 | super(Decoder, self).__init__() 42 | self.model = [] 43 | count = 8 - int(math.log2(dim)) 44 | for i in range(count): 45 | self.model += [Conv2dBlock(dim, dim * 2, 1, 1, 0, norm=norm, activation=activation, pad_type=pad_type)] 46 | dim *= 2 47 | for i in range(2): 48 | self.model += [nn.Upsample(scale_factor=2), 49 | Conv2dBlock(dim, dim // 2, 5, 1, 2, norm=norm, activation=activation, pad_type=pad_type)] 50 | dim //= 2 51 | self.model += [Conv2dBlock(dim, output_dim, 7, 1, 3, norm='none', activation='tanh', pad_type=pad_type)] 52 | self.model = nn.Sequential(*self.model) 53 | 54 | def forward(self, x): 55 | x = self.model(x) 56 | return x 57 | 58 | 59 | class WarpEncoder(nn.Module): 60 | def __init__(self, input_dim=2, dim=64, downs=4, code_dim=8, norm='bn', activation='relu', 61 | pad_type='reflect'): 62 | super(WarpEncoder, self).__init__() 63 | 64 | self.model = [] 65 | self.model += [Conv2dBlock(input_dim, dim, 7, 1, 3, norm=norm, activation=activation, pad_type=pad_type)] 66 | for i in range(2): 67 | self.model += [Conv2dBlock(dim, 2 * dim, 4, 2, 1, norm=norm, activation=activation, pad_type=pad_type)] 68 | dim *= 2 69 | for i in range(downs - 2): 70 | self.model += [Conv2dBlock(dim, dim, 4, 2, 1, norm=norm, activation=activation, pad_type=pad_type)] 71 | self.model += [nn.AdaptiveAvgPool2d(1), 72 | nn.Conv2d(dim, code_dim, 1, 1, 0)] 73 | self.model += [nn.BatchNorm2d(code_dim)] 74 | self.model = nn.Sequential(*self.model) 75 | 76 | def forward(self, x): 77 | x = self.model(x) 78 | return x 79 | 80 | 81 | class WarpDecoder(nn.Module): 82 | def __init__(self, latent_dim=96, output_dim=2, output_size=256, dim=256, ups=4, norm='bn', activation='relu', 83 | 
pad_type='reflect'): 84 | super(WarpDecoder, self).__init__() 85 | self.init_size = output_size // (2 ** ups) 86 | self.init_dim = dim 87 | self.linear = [] 88 | self.linear += [LinearBlock(latent_dim, dim * self.init_size ** 2, norm=norm, activation=activation)] 89 | self.linear = nn.Sequential(*self.linear) 90 | 91 | self.conv = [] 92 | 93 | for i in range(ups): 94 | self.conv += [nn.Upsample(scale_factor=2), 95 | Conv2dBlock(dim, dim // 2, 3, 1, 1, norm=norm, activation=activation, pad_type=pad_type), 96 | Conv2dBlock(dim // 2, dim // 2, 3, 1, 1, norm=norm, activation=activation, pad_type=pad_type) 97 | ] 98 | dim //= 2 99 | 100 | self.conv += [Conv2dBlock(dim, output_dim, 5, 1, 2, norm='none', activation='none', pad_type=pad_type)] 101 | self.conv = nn.Sequential(*self.conv) 102 | 103 | def forward(self, z): 104 | z = z.view(z.size(0), -1) 105 | out = self.linear(z) 106 | out = out.view(out.shape[0], self.init_dim, self.init_size, self.init_size) 107 | out = self.conv(out) 108 | return out 109 | 110 | 111 | class Warper(nn.Module): 112 | def __init__(self, args): 113 | super(Warper, self).__init__() 114 | self.encoder_p = Encoder(latent_dim=args.embedding_dim) 115 | self.decoder_p = Decoder(dim=args.embedding_dim) 116 | self.encoder_w = WarpEncoder(input_dim=2, dim=64, downs=4, code_dim=args.warp_dim) 117 | self.decoder_w = WarpDecoder(latent_dim=(args.embedding_dim + args.warp_dim), ups=4, output_size=args.field_size) 118 | self.const_field = make_init_field(args.img_size).cuda() 119 | self.factor = args.img_size // args.field_size 120 | 121 | def encode_p(self, img_p): 122 | return self.encoder_p(img_p) 123 | 124 | def decode_p(self, feat): 125 | return self.decoder_p(feat) 126 | 127 | def encode_f(self, field): 128 | field = F.interpolate(field.permute(0, 3, 1, 2), scale_factor=1 / self.factor, mode='bilinear', 129 | align_corners=True) 130 | z = self.encoder_w(field) 131 | return z 132 | 133 | def decode_f(self, embedding, z, scale=1.0): 134 | flow = torch.cat((embedding, z), dim=1) 135 | flow = self.decoder_w(flow) 136 | flow = F.interpolate(flow, scale_factor=self.factor, mode='bilinear', align_corners=True).permute(0, 2, 3, 1) 137 | field = self.const_field + scale * flow 138 | return flow, field 139 | 140 | def forward(self, img_p, z, scale=1.0): 141 | _, embedding = self.encoder_p(img_p) 142 | flow, field_pred = self.decode_f(embedding, z, scale) 143 | output = F.grid_sample(img_p, field_pred, align_corners=True) 144 | return output, field_pred, flow 145 | 146 | def save(self, dir, step): 147 | warper_name = os.path.join(dir, 'warper_%08d.pt' % (step + 1)) 148 | torch.save(self.state_dict(), warper_name) 149 | 150 | def load(self, path): 151 | state_dict = torch.load(path) 152 | self.load_state_dict(state_dict) 153 | -------------------------------------------------------------------------------- /readme.md: -------------------------------------------------------------------------------- 1 | # CariMe-pytorch 2 | 3 | 4 | The official pytorch implementation of our TMM paper CariMe: Unpaired Caricature Generation with Multiple Exaggerations. 
5 | 6 | ![examples](images/examples.png) 7 | 8 | >CariMe: Unpaired Caricature Generation with Multiple Exaggerations 9 | > 10 | >Zheng Gu, Chuanqi Dong, Jing Huo, Wenbin Li, and Yang Gao 11 | > 12 | >Paper: https://ieeexplore.ieee.org/abstract/document/9454341/ 13 | 14 | 15 | 16 | ## Prerequisites 17 | - Python 3.6 18 | - PyTorch 1.5.1 19 | - scikit-image 0.17.2 20 | 21 | ## Preparing Dataset 22 | - Get the [WebCaricature](https://cs.nju.edu.cn/rl/WebCaricature.htm) dataset, unzip it into the `data` folder, and align the images by running the following script: 23 | ```shell script 24 | python alignment.py 25 | ``` 26 | 27 | ## Training 28 | Train the Warper: 29 | ```shell script 30 | python train_warper.py 31 | ``` 32 | Train the Styler: 33 | ```shell script 34 | python train_styler.py 35 | ``` 36 | 37 | ## Testing 38 | - Test the Warper only: 39 | ```shell script 40 | python test_warper.py --scale 1.0 41 | ``` 42 | 43 | - Test the Styler only: 44 | ```shell script 45 | python test_styler.py 46 | ``` 47 | 48 | - Generate caricatures with both exaggeration and style transfer: 49 | ```shell script 50 | python main_generate.py --model_path_warper pretrained/warper.pt --model_path_styler pretrained/styler.pt 51 | ``` 52 | 53 | 54 | - Generate caricatures with both exaggeration and style transfer for a single image (the path contains a space, so it must be quoted): 55 | ```shell script 56 | python main_generate_single_image.py --model_path_warper pretrained/warper.pt --model_path_styler pretrained/styler.pt --input_path "images/Meg Ryan/P00015.jpg" --generate_num 5 --scale 1.0 57 | ``` 58 | 59 | The above command translates the input photo into 5 caricatures with different exaggerations and styles: 60 | 61 | ![examples](images/Meg%20Ryan/P00015_gen.jpg) 62 | 63 | 64 | ## Pretrained Models 65 | The pre-trained models are available [here](https://drive.google.com/drive/folders/1hBdCqWZ-kqvVLOCz-j9faLNkIbifBr3t?usp=sharing). A minimal programmatic usage sketch is given after the Reference section below. 66 | 67 | ## Citation 68 | If you use this code for your research, please cite our paper. 69 | 70 | @article{gu2021carime, 71 | title={CariMe: Unpaired Caricature Generation with Multiple Exaggerations}, 72 | author={Gu, Zheng and Dong, Chuanqi and Huo, Jing and Li, Wenbin and Gao, Yang}, 73 | journal={IEEE Transactions on Multimedia}, 74 | year={2021}, 75 | publisher={IEEE} 76 | } 77 | 78 | 79 | ## Reference 80 | Some of our code is based on [FUNIT](https://github.com/NVlabs/FUNIT) and [UGATIT](https://github.com/znxlwm/UGATIT-pytorch).
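## Programmatic Usage (Sketch)

The snippet below is an illustrative sketch of how the pretrained Warper and Styler can be driven from Python; it is not a copy of `main_generate_single_image.py`. The `Namespace` fields mirror the argparse defaults in `test_warper.py` and `test_styler.py` (the `Styler` constructor may read further fields from `args` than listed here), the warp-then-style order is inferred from the description of `main_generate.py`, and a CUDA device is assumed because `Warper` places its constant sampling field on the GPU.

```python
from argparse import Namespace

import torch

from networks import Warper
from networks.styler import Styler
from utils import load_img, unload_img

device = torch.device('cuda')

# Hyperparameters copied from the test scripts' argparse defaults (assumption, not canonical).
warper_args = Namespace(img_size=256, field_size=128, embedding_dim=32, warp_dim=64)
styler_args = Namespace(style_dim=8, down_es=2, restype='adalin')

warper = Warper(warper_args)
warper.load('pretrained/warper.pt')   # Warper.load restores the full state dict
warper.to(device).eval()

styler = Styler(styler_args)
styler.load('pretrained/styler.pt')   # Styler.load restores the generator weights
styler.to(device).eval()

# An aligned 256x256 face photo, normalized to [-1, 1] by load_img.
img = load_img('images/Meg Ryan/P00015.jpg').to(device)

with torch.no_grad():
    z = torch.randn(1, warper_args.warp_dim, 1, 1, device=device)   # random exaggeration code
    s = torch.randn(1, styler_args.style_dim, 1, 1, device=device)  # random style code
    warped, _, _ = warper(img, z, scale=1.0)  # geometric exaggeration via the predicted field
    caricature = styler(warped, s)            # style transfer on the warped photo

unload_img(caricature[0].detach().cpu()).save('P00015_gen_sketch.jpg', 'jpeg')
```

Sampling a fresh `z` and `s` for each output is presumably how `--generate_num` produces several caricatures with different exaggerations and styles from a single photo.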
81 | -------------------------------------------------------------------------------- /test_styler.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import argparse 3 | import os 4 | import shutil 5 | import torch 6 | import numpy as np 7 | import random 8 | 9 | from tqdm import tqdm 10 | from networks.styler import Styler 11 | from utils import unload_img, str2bool 12 | from dataset import make_dataset 13 | from torch.utils.data import DataLoader 14 | 15 | os.environ['CUDA_VISIBLE_DEVICES'] = '0' 16 | 17 | parser = argparse.ArgumentParser() 18 | parser.add_argument('--data_root', type=str, default='data/WebCaricature_align_1.3_256') 19 | parser.add_argument('--name', type=str, default='results/styler') 20 | parser.add_argument('--model', type=str, default='gen_00200000.pt') 21 | parser.add_argument('--output_dir', type=str, default='test') 22 | 23 | parser.add_argument('--resize_crop', type=str2bool, default=False) 24 | parser.add_argument('--hflip', type=str2bool, default=False) 25 | parser.add_argument('--enlarge', type=str2bool, default=False) 26 | parser.add_argument('--mode', type=str, default='test') 27 | parser.add_argument('--batch_size', type=int, default=4) 28 | parser.add_argument('--num_workers', type=int, default=8) 29 | parser.add_argument('--style_dim', type=int, default=8) 30 | parser.add_argument('--down_es', type=int, default=2) 31 | parser.add_argument('--restype', type=str, default='adalin') 32 | 33 | args = parser.parse_args() 34 | 35 | if __name__ == '__main__': 36 | SEED = 0 37 | random.seed(SEED) 38 | np.random.seed(SEED) 39 | torch.manual_seed(SEED) 40 | torch.cuda.manual_seed(SEED) 41 | 42 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 43 | model_path = os.path.join(args.name, 'checkpoints', args.model) 44 | print('load model: ', model_path) 45 | output_path = os.path.join(args.name, args.output_dir) 46 | print('output path: ', output_path) 47 | if os.path.exists(output_path): 48 | shutil.rmtree(output_path) 49 | if not os.path.exists(output_path): 50 | os.makedirs(output_path) 51 | 52 | dataset = make_dataset(args) 53 | print(len(dataset)) 54 | dataloader = DataLoader(dataset=dataset, batch_size=args.batch_size, shuffle=False, drop_last=False, 55 | num_workers=args.num_workers) 56 | 57 | model = Styler(args) 58 | model.load(model_path) 59 | model.to(device) 60 | model.eval() 61 | 62 | for batch, item in tqdm(enumerate(dataloader)): 63 | 64 | img_ps = item['img_p'].to(device) 65 | names = item['name'] 66 | filenames = item['filename'] 67 | 68 | s = torch.randn(img_ps.size(0), 8, 1, 1).cuda() 69 | outputs = model(img_ps, s) 70 | 71 | for i in range(img_ps.size()[0]): 72 | input = img_ps[i].detach().cpu() 73 | output = outputs[i].detach().cpu() 74 | name = names[i] 75 | filename = filenames[i] 76 | 77 | figure = torch.cat((input, output), dim=2) 78 | unload_img(figure).save(os.path.join(output_path, '{}_{}.jpg'.format(name, filename)), 'jpeg') -------------------------------------------------------------------------------- /test_warper.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | import argparse 4 | import os 5 | 6 | import torch 7 | import numpy as np 8 | from tqdm import tqdm 9 | 10 | from networks import Warper 11 | from utils import unload_img, str2bool, shutil 12 | from dataset import make_dataset 13 | from torch.utils.data import DataLoader 14 | 15 | os.environ['CUDA_VISIBLE_DEVICES'] = '0' 16 | 
17 | parser = argparse.ArgumentParser() 18 | parser.add_argument('--data_root', type=str, default='data/WebCaricature_align_1.3_256') 19 | parser.add_argument('--name', type=str, default='results/warper') 20 | parser.add_argument('--model', type=str, default='warper_00020000.pt') 21 | parser.add_argument('--output_dir', type=str, default='test') 22 | 23 | parser.add_argument('--resize_crop', type=str2bool, default=True) 24 | parser.add_argument('--enlarge', type=str2bool, default=False) 25 | parser.add_argument('--same_id', type=str2bool, default=True) 26 | parser.add_argument('--hflip', type=str2bool, default=False) 27 | 28 | parser.add_argument('--mode', type=str, default='test') 29 | parser.add_argument('--batch_size', type=int, default=4) 30 | parser.add_argument('--num_workers', type=int, default=8) 31 | 32 | parser.add_argument('--img_size', type=int, default=256) 33 | parser.add_argument('--field_size', type=int, default=128) 34 | parser.add_argument('--embedding_dim', type=int, default=32) 35 | parser.add_argument('--warp_dim', type=int, default=64) 36 | parser.add_argument('--scale', type=float, default=1.0) 37 | 38 | args = parser.parse_args() 39 | 40 | if __name__ == '__main__': 41 | SEED = 0 42 | random.seed(SEED) 43 | np.random.seed(SEED) 44 | torch.manual_seed(SEED) 45 | torch.cuda.manual_seed(SEED) 46 | 47 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 48 | model_path = os.path.join(args.name, 'checkpoints', args.model) 49 | print('load model: ', model_path) 50 | output_path = os.path.join(args.name, args.output_dir) 51 | print('output path: ', output_path) 52 | if os.path.exists(output_path): 53 | shutil.rmtree(output_path) 54 | os.makedirs(output_path, exist_ok=True) 55 | 56 | dataset = make_dataset(args) 57 | dataloader = DataLoader(dataset=dataset, batch_size=args.batch_size, shuffle=False, drop_last=False, 58 | num_workers=args.num_workers) 59 | 60 | warper = Warper(args) 61 | state_dict = torch.load(model_path) 62 | warper.load_state_dict(state_dict) 63 | warper.to(device) 64 | warper.eval() 65 | 66 | for batch, item in tqdm(enumerate(dataloader)): 67 | img_p = item['img_p'].to(device) 68 | names = item['name'] 69 | filenames = item['filename'] 70 | 71 | z = torch.randn(img_p.size()[0], args.warp_dim, 1, 1).cuda() 72 | img_warp, psmap, flows = warper(img_p, z, scale=args.scale) 73 | 74 | for i in range(img_p.size()[0]): 75 | input = img_p[i] 76 | result = img_warp[i] 77 | flow = flows[i] 78 | name = names[i] 79 | filename = filenames[i] 80 | 81 | output = torch.cat((input, result), dim=2) 82 | unload_img(output.detach().cpu()).save(os.path.join(output_path, '{}_{}.jpg'.format(name, filename)), 'jpeg') 83 | -------------------------------------------------------------------------------- /train_styler.py: -------------------------------------------------------------------------------- 1 | import random 2 | import time 3 | 4 | import argparse 5 | import os 6 | import torch 7 | import torch.optim as optim 8 | import numpy as np 9 | 10 | from networks import Styler, l1_loss, RhoClipper 11 | from utils import prepare_sub_folder, weights_init, str2bool, write_image 12 | from dataset import make_dataset 13 | from torch.utils.data import DataLoader 14 | 15 | os.environ['CUDA_VISIBLE_DEVICES'] = '0' 16 | 17 | parser = argparse.ArgumentParser() 18 | parser.add_argument('--data_root', type=str, default='data/WebCaricature_align_1.3_256') 19 | parser.add_argument('--output_path', type=str, default='results/styler/') 20 | 
parser.add_argument('--max_dataset_size', type=int, default=100000) 21 | 22 | parser.add_argument('--resize_crop', type=str2bool, default=True) 23 | parser.add_argument('--enlarge', type=str2bool, default=True) 24 | parser.add_argument('--same_id', type=str2bool, default=False) 25 | parser.add_argument('--hflip', type=str2bool, default=True) 26 | 27 | parser.add_argument('--mode', type=str, default='train') 28 | parser.add_argument('--iteration', type=int, default=500000) 29 | parser.add_argument('--snapshot_log', type=int, default=100) 30 | parser.add_argument('--snapshot_vis', type=int, default=1000) 31 | parser.add_argument('--snapshot_save', type=int, default=100000) 32 | 33 | parser.add_argument('--batch_size', type=int, default=4) 34 | parser.add_argument('--num_workers', type=int, default=8) 35 | parser.add_argument('--lr', type=float, default=0.0001) 36 | parser.add_argument('--style_dim', type=int, default=8) 37 | parser.add_argument('--w_recon_img', type=float, default=10) 38 | parser.add_argument('--w_cyc_s', type=float, default=1) 39 | parser.add_argument('--w_cyc_c', type=float, default=1) 40 | args = parser.parse_args() 41 | 42 | if __name__ == '__main__': 43 | SEED = 0 44 | random.seed(SEED) 45 | np.random.seed(SEED) 46 | torch.manual_seed(SEED) 47 | torch.cuda.manual_seed(SEED) 48 | 49 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 50 | checkpoint_dir, image_dir = prepare_sub_folder(args.output_path) 51 | 52 | dataset = make_dataset(args) 53 | dataloader = DataLoader(dataset=dataset, batch_size=args.batch_size, shuffle=True, drop_last=False, 54 | num_workers=args.num_workers) 55 | 56 | model = Styler(args) 57 | model.to(device) 58 | model.train() 59 | 60 | gen_para = list(model.gen.parameters()) 61 | dis_para = list(model.dis.parameters()) 62 | gen_opt = optim.Adam([p for p in gen_para if p.requires_grad], lr=args.lr, betas=(0.5, 0.999), weight_decay=1e-5) 63 | dis_opt = optim.Adam([p for p in dis_para if p.requires_grad], lr=args.lr, betas=(0.5, 0.999), weight_decay=1e-5) 64 | model.apply(weights_init('kaiming')) 65 | rho_clipper = RhoClipper(0, 1) 66 | 67 | train_iter = iter(dataloader) 68 | start = time.time() 69 | for step in range(0, args.iteration + 1): 70 | try: 71 | item = train_iter.next() 72 | except: 73 | train_iter = iter(dataloader) 74 | item = train_iter.next() 75 | 76 | if step > (args.iteration // 2): 77 | gen_opt.param_groups[0]['lr'] -= ((args.lr - 0) / (args.iteration // 2)) 78 | dis_opt.param_groups[0]['lr'] -= ((args.lr - 0) / (args.iteration // 2)) 79 | 80 | img_p = item['img_p'].to(device) 81 | img_c = item['img_c'].to(device) 82 | 83 | # update discriminator 84 | dis_opt.zero_grad() 85 | random_s = torch.randn(img_p.size(0), args.style_dim, 1, 1).cuda() 86 | 87 | content_p, _ = model.encode(img_p) 88 | fake_c = model.decode(content_p, random_s) 89 | 90 | loss_adv_dis = model.dis.calc_dis_loss(img_c, fake_c) 91 | loss_adv_dis.backward() 92 | dis_opt.step() 93 | 94 | # update generator 95 | gen_opt.zero_grad() 96 | random_s = torch.randn(img_p.size(0), 8, 1, 1).cuda() 97 | 98 | content_p, style_p = model.encode(img_p) 99 | content_c, style_c = model.encode(img_c) 100 | 101 | recon_p = model.decode(content_p, style_p) 102 | recon_c = model.decode(content_c, style_c) 103 | 104 | img_p2c = model.decode(content_p, random_s) 105 | recon_content_p, recon_s = model.encode(img_p2c) 106 | 107 | loss_recon_img = (l1_loss(recon_c, img_c) + l1_loss(recon_p, img_p)) * args.w_recon_img 108 | loss_cyc_style = l1_loss(recon_s, random_s) * 
args.w_cyc_s 109 | loss_cyc_content = l1_loss(recon_content_p, content_p) * args.w_cyc_c 110 | loss_adv_gen = model.dis.calc_gen_loss(img_p2c) 111 | 112 | loss_gen = loss_recon_img + loss_adv_gen + loss_cyc_content + loss_cyc_style 113 | loss_gen.backward() 114 | gen_opt.step() 115 | 116 | model.gen.apply(rho_clipper) 117 | 118 | # output log 119 | if (step + 1) % args.snapshot_log == 0: 120 | end = time.time() 121 | print( 122 | 'Step: {} ({:.0f}%) time: {} loss_adv_g:{:.4f} loss_adv_d:{:.4f} loss_recon_img:{:.4f} loss_cyc_c:{:.4f} loss_cyc_s:{:.4f} '.format( 123 | step + 1, 124 | 100.0 * step / args.iteration, 125 | int(end - start), 126 | loss_adv_gen, 127 | loss_adv_dis, 128 | loss_recon_img, 129 | loss_cyc_content, 130 | loss_cyc_style 131 | )) 132 | 133 | if (step + 1) % args.snapshot_vis == 0: 134 | # input photo, input caricature, recon photo, recon caricature, image translated 135 | vis = torch.stack((img_p, img_c, recon_p, recon_c, img_p2c), dim=1) 136 | write_image(step, image_dir, vis) 137 | 138 | # save checkpoint 139 | if (step + 1) % args.snapshot_save == 0: 140 | model.save(checkpoint_dir, step) 141 | -------------------------------------------------------------------------------- /train_warper.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | import argparse 4 | import os 5 | import time 6 | 7 | import torch 8 | import torch.optim as optim 9 | import torch.nn.functional as F 10 | import numpy as np 11 | from torch.utils.data import DataLoader 12 | 13 | from networks import Warper, l1_loss, tv_loss 14 | from dataset import make_dataset 15 | from utils import prepare_sub_folder, weights_init, str2bool, write_image 16 | 17 | os.environ['CUDA_VISIBLE_DEVICES'] = '1' 18 | 19 | parser = argparse.ArgumentParser() 20 | parser.add_argument('--data_root', type=str, default='data/WebCaricature_align_1.3_256') 21 | parser.add_argument('--output_path', type=str, default='results/warper/') 22 | parser.add_argument('--max_dataset_size', type=int, default=10000) 23 | 24 | parser.add_argument('--resize_crop', type=str2bool, default=True) 25 | parser.add_argument('--enlarge', type=str2bool, default=False) 26 | parser.add_argument('--same_id', type=str2bool, default=True) 27 | parser.add_argument('--hflip', type=str2bool, default=True) 28 | 29 | parser.add_argument('--mode', type=str, default='train') 30 | parser.add_argument('--iteration', type=int, default=20000) 31 | parser.add_argument('--snapshot_log', type=int, default=100) 32 | parser.add_argument('--snapshot_save', type=int, default=10000) 33 | 34 | parser.add_argument('--batch_size', type=int, default=4) 35 | parser.add_argument('--num_workers', type=int, default=8) 36 | parser.add_argument('--lr', type=float, default=0.0001) 37 | 38 | parser.add_argument('--img_size', type=int, default=256) 39 | parser.add_argument('--field_size', type=int, default=128) 40 | parser.add_argument('--embedding_dim', type=int, default=32) 41 | parser.add_argument('--warp_dim', type=int, default=64) 42 | parser.add_argument('--scale', type=float, default=1.0) 43 | 44 | parser.add_argument('--w_recon_img', type=float, default=10) 45 | parser.add_argument('--w_recon_field', type=float, default=10) 46 | parser.add_argument('--w_tv', type=float, default=0.000005) 47 | 48 | args = parser.parse_args() 49 | 50 | if __name__ == '__main__': 51 | SEED = 0 52 | random.seed(SEED) 53 | np.random.seed(SEED) 54 | torch.manual_seed(SEED) 55 | torch.cuda.manual_seed(SEED) 56 | 57 | device = torch.device("cuda" 
if torch.cuda.is_available() else "cpu") 58 | checkpoint_dir, image_dir = prepare_sub_folder(args.output_path, delete_first=True) 59 | 60 | dataset = make_dataset(args) 61 | dataloader = DataLoader(dataset=dataset, batch_size=args.batch_size, shuffle=True, drop_last=False, 62 | num_workers=args.num_workers) 63 | 64 | warper = Warper(args) 65 | warper.to(device) 66 | warper.train() 67 | 68 | paras = list(warper.parameters()) 69 | opt = optim.Adam([p for p in paras if p.requires_grad], lr=args.lr, betas=(0.5, 0.999), weight_decay=1e-5) 70 | warper.apply(weights_init('kaiming')) 71 | 72 | train_iter = iter(dataloader) 73 | start = time.time() 74 | for step in range(0, args.iteration + 1): 75 | try: 76 | item = train_iter.next() 77 | except: 78 | train_iter = iter(dataloader) 79 | item = train_iter.next() 80 | 81 | if step > args.iteration // 2: 82 | opt.param_groups[0]['lr'] -= ((args.lr - 0.) / (args.iteration // 2)) 83 | 84 | img_p = item['img_p'].to(device) 85 | img_c = item['img_c'].to(device) 86 | field_p2c = item['field_p2c'].to(device) 87 | field_m2c = item['field_m2c'].to(device) 88 | field_m2p = item['field_m2p'].to(device) 89 | 90 | opt.zero_grad() 91 | feat, embedding = warper.encode_p(img_p) 92 | img_recon = warper.decode_p(feat) 93 | loss_recon_p = l1_loss(img_p, img_recon) * args.w_recon_img 94 | 95 | z = warper.encode_f(field_m2c) 96 | _, field_recon = warper.decode_f(embedding, z, scale=args.scale) 97 | loss_recon_warp = l1_loss(field_recon, field_p2c) * args.w_recon_field 98 | 99 | random_z = torch.randn(img_p.size(0), args.warp_dim, 1, 1).cuda() 100 | _, field_gen = warper.decode_f(embedding, random_z, scale=args.scale) 101 | img_warp_gen = F.grid_sample(img_p, field_gen, align_corners=True) 102 | loss_tv = tv_loss(img_warp_gen) * args.w_tv 103 | 104 | loss_total = loss_recon_p + loss_recon_warp + loss_tv 105 | loss_total.backward() 106 | opt.step() 107 | 108 | # output log 109 | if (step + 1) % args.snapshot_log == 0: 110 | end = time.time() 111 | print('Step: {} ({:.0f}%) time:{} loss_rec_p:{:.4f} loss_rec_warp:{:.4f} loss_tv:{:.4f}'.format( 112 | step + 1, 113 | 100.0 * step / args.iteration, 114 | int(end - start), 115 | loss_recon_p, 116 | loss_recon_warp, 117 | loss_tv)) 118 | # input photo, input caricature, image_warp_p2c, image_warp_generated 119 | vis = torch.stack((img_p, img_c, F.grid_sample(img_p, field_p2c, align_corners=True), img_warp_gen), dim=1) 120 | write_image(step, image_dir, vis) 121 | 122 | # save checkpoint 123 | if (step + 1) % args.snapshot_save == 0: 124 | warper.save(checkpoint_dir, step) 125 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import math 3 | import os 4 | import shutil 5 | import numpy as np 6 | 7 | import torch 8 | import torch.nn.init as init 9 | import torchvision.utils as vutils 10 | import torchvision.transforms as transforms 11 | import skimage.transform as sktf 12 | from PIL import Image 13 | 14 | 15 | def make_field(length): 16 | temp_height = torch.linspace(start=-1.0, end=1.0, steps=length, requires_grad=False).view(length, 1) 17 | temp_width = torch.linspace(start=-1, end=1, steps=length, requires_grad=False).view(1, length) 18 | pos_x = temp_height.repeat(1, length).view(length, length, 1) 19 | pos_y = temp_width.repeat(length, 1).view(length, length, 1) 20 | return pos_y.numpy(), pos_x.numpy() 21 | 22 | 23 | def make_init_field(length=256): 24 | y, x = make_field(length) 
25 | map = np.concatenate((y, x), axis=2) 26 | map = torch.from_numpy(map).unsqueeze(0) 27 | return map 28 | 29 | 30 | def warp_image(image, src_points, dst_points): 31 | src_points = np.array( 32 | [ 33 | [0, 0], [0, image.shape[0]], 34 | [image.shape[0], 0], list(image.shape[:2]) 35 | ] + src_points.tolist() 36 | ) 37 | dst_points = np.array( 38 | [ 39 | [0, 0], [0, image.shape[0]], 40 | [image.shape[0], 0], list(image.shape[:2]) 41 | ] + dst_points.tolist() 42 | ) 43 | 44 | tform3 = sktf.PiecewiseAffineTransform() 45 | tform3.estimate(dst_points, src_points) 46 | 47 | warped = sktf.warp(image, tform3, output_shape=image.shape) 48 | return warped 49 | 50 | 51 | def load_filenames(data_root, name, token, file_type): 52 | files = os.listdir(os.path.join(data_root, token, name)) 53 | return [os.path.join(data_root, token, name, file) for file in files if file.startswith(file_type)] 54 | 55 | 56 | def warp_position_map(p, c, length=256): 57 | pos_ys, pos_xs = make_field(length) 58 | pos_ys_warped = warp_image(pos_ys, p, c) 59 | pos_xs_warped = warp_image(pos_xs, p, c) 60 | return pos_ys_warped, pos_xs_warped 61 | 62 | 63 | def load_img(path): 64 | transform = transforms.Compose([ 65 | transforms.ToTensor(), 66 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), 67 | ] 68 | ) 69 | img = Image.open(path).convert('RGB') 70 | img = transform(img) 71 | img = img.unsqueeze(0) 72 | return img 73 | 74 | 75 | def unload_img(img): 76 | img = (img + 1) / 2 77 | tf = transforms.Compose([ 78 | transforms.ToPILImage() 79 | ]) 80 | return tf(img) 81 | 82 | 83 | def prepare_sub_folder(output_path, delete_first=True): 84 | print('preparing sub folder for {}'.format(output_path)) 85 | if delete_first and os.path.exists(output_path): 86 | shutil.rmtree(output_path) 87 | os.makedirs(output_path, exist_ok=True) 88 | 89 | checkpoint_path = os.path.join(output_path, 'checkpoints') 90 | os.makedirs(checkpoint_path, exist_ok=True) 91 | 92 | images_path = os.path.join(output_path, 'images') 93 | os.makedirs(images_path, exist_ok=True) 94 | 95 | return checkpoint_path, images_path 96 | 97 | 98 | def weights_init(init_type='gaussian', mean=0.0, std=0.02): 99 | def init_fun(m): 100 | classname = m.__class__.__name__ 101 | if (classname.find('Conv') == 0 or classname.find('Linear') == 0) and hasattr(m, 'weight'): 102 | if init_type == 'gaussian': 103 | init.normal_(m.weight.data, mean, std) 104 | elif init_type == 'xavier': 105 | init.xavier_normal_(m.weight.data, gain=math.sqrt(2)) 106 | elif init_type == 'kaiming': 107 | init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') 108 | elif init_type == 'orthogonal': 109 | init.orthogonal_(m.weight.data, gain=math.sqrt(2)) 110 | elif init_type == 'default': 111 | pass 112 | else: 113 | assert 0, "Unsupported initialization: {}".format(init_type) 114 | if hasattr(m, 'bias') and m.bias is not None: 115 | init.constant_(m.bias.data, 0.0) 116 | 117 | return init_fun 118 | 119 | 120 | def str2bool(v): 121 | if v.lower() in ('yes', 'true', 't', 'y', '1'): 122 | return True 123 | elif v.lower() in ('no', 'false', 'f', 'n', '0'): 124 | return False 125 | else: 126 | raise argparse.ArgumentTypeError('Unsupported value encountered.') 127 | 128 | 129 | def write_image(iterations, dir, im_ins): 130 | B, K, C, H, W = im_ins.size() 131 | file_name = os.path.join(dir, '%08d' % (iterations + 1) + '.jpg') 132 | image_tensor = im_ins.view(B*K, C, H, W) 133 | image_grid = vutils.make_grid(image_tensor.data, nrow=K, padding=0, normalize=True) 134 | vutils.save_image(image_grid, 
file_name, nrow=1) 135 | 136 | --------------------------------------------------------------------------------
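A closing note on the sampling-field convention used by `utils.make_init_field` and `networks/warper.py`: `make_init_field(L)` builds the identity sampling grid expected by `torch.nn.functional.grid_sample` (shape `1 x L x L x 2`, values in `[-1, 1]`), and `Warper.decode_f` adds `scale * flow` to this constant grid before sampling, so `scale=0` reproduces the input photo and larger scales exaggerate more strongly. The snippet below is a minimal sanity check of that convention; it is not part of the repository.

```python
import torch
import torch.nn.functional as F

from utils import make_init_field

img = torch.rand(1, 3, 256, 256)
field = make_init_field(256)                          # identity grid, shape (1, 256, 256, 2)
out = F.grid_sample(img, field, align_corners=True)   # same align_corners=True as Warper uses
print(torch.allclose(out, img, atol=1e-4))            # expected: True (identity field leaves the image unchanged)
```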