├── README.md
├── assets
│   ├── cvpr2018.mp4
│   ├── cvpr2018_large.gif
│   ├── cvpr2018_small.gif
│   └── teaser.png
├── batches.py
├── coco.yaml
├── coco_retrain.yaml
├── deepfashion.yaml
├── deepfashion_retrain.yaml
├── deeploss.py
├── download_data.sh
├── main.py
├── market.yaml
├── market_retrain.yaml
├── models.py
├── nn.py
├── nn_compat.py
└── requirements.txt

/README.md:
--------------------------------------------------------------------------------
1 | # A Variational U-Net for Conditional Appearance and Shape Generation
2 | 
3 | This repository contains training code for the CVPR 2018 spotlight
4 | 
5 | [**A Variational U-Net for Conditional Appearance and Shape Generation**](https://compvis.github.io/vunet/images/vunet.pdf)
6 | 
7 | The model learns to infer appearance from a single image and can synthesize
8 | images with that appearance in different poses.
9 | 
10 | ![teaser](assets/cvpr2018_large.gif)
11 | 
12 | [Project page with more results](https://compvis.github.io/vunet/)
13 | 
14 | ## Notes
15 | 
16 | This is a slightly modified version of the code that was used to produce the
17 | results in the paper. The original code was cleaned up, the data dependent
18 | weight initialization was made compatible with `tensorflow >= 1.3.0` and a
19 | unified model between the datasets is used. You can find [the original code and
20 | checkpoints online (`vunet/runs`)](https://heibox.uni-heidelberg.de/d/71842715a8/) but if you want to use
21 | them, please keep in mind that:
22 | 
23 | - the original checkpoints are not compatible with the graphs defined in this
24 |   repository. You must use the original code distributed with the checkpoints.
25 | - the original code uses a data dependent weight initialization scheme which
26 |   does not work with `tensorflow >= 1.3.0`. You should use `tensorflow==1.2.1`.
27 | - the original code became a bit of a mess and we can no longer provide support for
28 |   it.
29 | 
30 | ## Requirements
31 | 
32 | The code was developed with Python 3. Dependencies can be installed with
33 | 
34 |     pip install -r requirements.txt
35 | 
36 | These requirements correspond to the dependency versions used to generate the
37 | pretrained models but other versions might work as well.
38 | 
39 | ## Training
40 | 
41 | [Download](https://heibox.uni-heidelberg.de/d/71842715a8/) and unpack the desired dataset.
42 | This results in a folder containing an `index.p` file. Either add a symbolic
43 | link named `data` pointing to the download directory or adjust the path to
44 | the `index.p` file in the `<dataset>.yaml` config file.
45 | 
46 | For convenience, you can also run
47 | 
48 |     ./download_data.sh <dataset> <store_dir>
49 | 
50 | which will perform the above steps automatically. `<dataset>` can be one of
51 | `coco`, `deepfashion` or `market`. To train the model, run
52 | 
53 |     python main.py --config <dataset>.yaml
54 | 
55 | By default, images and checkpoints are saved to `log/<current date>`. To
56 | change the log directory and other options, see
57 | 
58 |     python main.py -h
59 | 
60 | and the corresponding configuration file. To obtain images of optimal
61 | quality, it is recommended to train for a second round with a loss based on
62 | Gram matrices. To do so, run
63 | 
64 |     python main.py --config <dataset>_retrain.yaml --retrain --checkpoint <path to checkpoint of first round>
65 | 
66 | 
67 | ## Pretrained models
68 | 
69 | You can find [pretrained models
70 | online (`vunet/pretrained_checkpoints`)](https://heibox.uni-heidelberg.de/d/71842715a8/).
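
As a rough usage sketch (assuming the downloaded checkpoint matches the unified
graph defined in this repository and the `<dataset>` you configured), appearance
transfer on the validation split can be run with

    python main.py --config <dataset>.yaml --mode transfer --checkpoint <path to checkpoint>

which restores the weights and writes `transfer_*.png` image grids to a new
time-stamped directory under `log/`.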
71 | 
72 | 
73 | ## Other Datasets
74 | 
75 | To be able to train the model on your own dataset, you must provide a pickled
76 | dictionary with the following keys:
77 | 
78 | - `joint_order`: list indicating the order of joints.
79 | - `imgs`: list of paths to images (relative to pickle file).
80 | - `train`: list of booleans indicating whether this image belongs to the training split.
81 | - `joints`: list of `[0,1]` normalized xy joint coordinates of shape `(len(joint_order), 2)`. Use negative values for occluded joints.
82 | 
83 | `joint_order` should contain
84 | 
85 |     'rankle', 'rknee', 'rhip', 'rshoulder', 'relbow', 'rwrist', 'reye', 'lankle', 'lknee', 'lhip', 'lshoulder', 'lelbow', 'lwrist', 'leye', 'cnose'
86 | 
87 | and images without valid values for `rhip, rshoulder, lhip, lshoulder` are
88 | ignored.
89 | 
--------------------------------------------------------------------------------
/assets/cvpr2018.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CompVis/vunet/e66450ca4a8beb0598d9bd0754bb1d00d481c6b1/assets/cvpr2018.mp4
--------------------------------------------------------------------------------
/assets/cvpr2018_large.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CompVis/vunet/e66450ca4a8beb0598d9bd0754bb1d00d481c6b1/assets/cvpr2018_large.gif
--------------------------------------------------------------------------------
/assets/cvpr2018_small.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CompVis/vunet/e66450ca4a8beb0598d9bd0754bb1d00d481c6b1/assets/cvpr2018_small.gif
--------------------------------------------------------------------------------
/assets/teaser.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/CompVis/vunet/e66450ca4a8beb0598d9bd0754bb1d00d481c6b1/assets/teaser.png
--------------------------------------------------------------------------------
/batches.py:
--------------------------------------------------------------------------------
1 | import PIL.Image
2 | from multiprocessing.pool import ThreadPool
3 | import numpy as np
4 | import pickle
5 | import os
6 | import cv2
7 | import math
8 | 
9 | 
10 | n_boxes = 8
11 | 
12 | 
13 | class BufferedWrapper(object):
14 |     """Fetch next batch asynchronously to avoid bottleneck during GPU
15 |     training."""
16 |     def __init__(self, gen):
17 |         self.gen = gen
18 |         self.n = gen.n
19 |         self.pool = ThreadPool(1)
20 |         self._async_next()
21 | 
22 | 
23 |     def _async_next(self):
24 |         self.buffer_ = self.pool.apply_async(next, (self.gen,))
25 | 
26 | 
27 |     def __next__(self):
28 |         result = self.buffer_.get()
29 |         self._async_next()
30 |         return result
31 | 
32 | 
33 | def load_img(path, target_size):
34 |     """Load image. target_size is specified as (height, width, channels)
35 |     where channels == 1 means grayscale.
uint8 image returned.""" 36 | img = PIL.Image.open(path) 37 | grayscale = target_size[2] == 1 38 | if grayscale: 39 | if img.mode != 'L': 40 | img = img.convert('L') 41 | else: 42 | if img.mode != 'RGB': 43 | img = img.convert('RGB') 44 | wh_tuple = (target_size[1], target_size[0]) 45 | if img.size != wh_tuple: 46 | img = img.resize(wh_tuple, resample = PIL.Image.BILINEAR) 47 | 48 | x = np.asarray(img, dtype = "uint8") 49 | if len(x.shape) == 2: 50 | x = np.expand_dims(x, -1) 51 | 52 | return x 53 | 54 | 55 | def preprocess(x): 56 | """From uint8 image to [-1,1].""" 57 | return np.cast[np.float32](x / 127.5 - 1.0) 58 | 59 | 60 | def postprocess(x): 61 | """[-1,1] to uint8.""" 62 | x = (x + 1.0) / 2.0 63 | x = np.clip(255 * x, 0, 255) 64 | x = np.cast[np.uint8](x) 65 | return x 66 | 67 | 68 | def tile(X, rows, cols): 69 | """Tile images for display.""" 70 | tiling = np.zeros((rows * X.shape[1], cols * X.shape[2], X.shape[3]), dtype = X.dtype) 71 | for i in range(rows): 72 | for j in range(cols): 73 | idx = i * cols + j 74 | if idx < X.shape[0]: 75 | img = X[idx,...] 76 | tiling[ 77 | i*X.shape[1]:(i+1)*X.shape[1], 78 | j*X.shape[2]:(j+1)*X.shape[2], 79 | :] = img 80 | return tiling 81 | 82 | 83 | def plot_batch(X, out_path): 84 | """Save batch of images tiled.""" 85 | n_channels = X.shape[3] 86 | if n_channels > 3: 87 | X = X[:,:,:,np.random.choice(n_channels, size = 3)] 88 | X = postprocess(X) 89 | rc = math.sqrt(X.shape[0]) 90 | rows = cols = math.ceil(rc) 91 | canvas = tile(X, rows, cols) 92 | canvas = np.squeeze(canvas) 93 | PIL.Image.fromarray(canvas).save(out_path) 94 | 95 | 96 | def make_joint_img(img_shape, jo, joints): 97 | # three channels: left, right, center 98 | scale_factor = img_shape[1] / 128 99 | thickness = int(3 * scale_factor) 100 | imgs = list() 101 | for i in range(3): 102 | imgs.append(np.zeros(img_shape[:2], dtype = "uint8")) 103 | 104 | body = ["lhip", "lshoulder", "rshoulder", "rhip"] 105 | body_pts = np.array([[joints[jo.index(part),:] for part in body]]) 106 | if np.min(body_pts) >= 0: 107 | body_pts = np.int_(body_pts) 108 | cv2.fillPoly(imgs[2], body_pts, 255) 109 | 110 | right_lines = [ 111 | ("rankle", "rknee"), 112 | ("rknee", "rhip"), 113 | ("rhip", "rshoulder"), 114 | ("rshoulder", "relbow"), 115 | ("relbow", "rwrist")] 116 | for line in right_lines: 117 | l = [jo.index(line[0]), jo.index(line[1])] 118 | if np.min(joints[l]) >= 0: 119 | a = tuple(np.int_(joints[l[0]])) 120 | b = tuple(np.int_(joints[l[1]])) 121 | cv2.line(imgs[0], a, b, color = 255, thickness = thickness) 122 | 123 | left_lines = [ 124 | ("lankle", "lknee"), 125 | ("lknee", "lhip"), 126 | ("lhip", "lshoulder"), 127 | ("lshoulder", "lelbow"), 128 | ("lelbow", "lwrist")] 129 | for line in left_lines: 130 | l = [jo.index(line[0]), jo.index(line[1])] 131 | if np.min(joints[l]) >= 0: 132 | a = tuple(np.int_(joints[l[0]])) 133 | b = tuple(np.int_(joints[l[1]])) 134 | cv2.line(imgs[1], a, b, color = 255, thickness = thickness) 135 | 136 | rs = joints[jo.index("rshoulder")] 137 | ls = joints[jo.index("lshoulder")] 138 | cn = joints[jo.index("cnose")] 139 | neck = 0.5*(rs+ls) 140 | a = tuple(np.int_(neck)) 141 | b = tuple(np.int_(cn)) 142 | if np.min(a) >= 0 and np.min(b) >= 0: 143 | cv2.line(imgs[0], a, b, color = 127, thickness = thickness) 144 | cv2.line(imgs[1], a, b, color = 127, thickness = thickness) 145 | 146 | cn = tuple(np.int_(cn)) 147 | leye = tuple(np.int_(joints[jo.index("leye")])) 148 | reye = tuple(np.int_(joints[jo.index("reye")])) 149 | if np.min(reye) >= 0 and np.min(leye) >= 0 
and np.min(cn) >= 0: 150 | cv2.line(imgs[0], cn, reye, color = 255, thickness = thickness) 151 | cv2.line(imgs[1], cn, leye, color = 255, thickness = thickness) 152 | 153 | img = np.stack(imgs, axis = -1) 154 | if img_shape[-1] == 1: 155 | img = np.mean(img, axis = -1)[:,:,None] 156 | return img 157 | 158 | 159 | def valid_joints(*joints): 160 | j = np.stack(joints) 161 | return (j >= 0).all() 162 | 163 | 164 | def get_crop(bpart, joints, jo, wh, o_w, o_h, ar = 1.0): 165 | bpart_indices = [jo.index(b) for b in bpart] 166 | part_src = np.float32(joints[bpart_indices]) 167 | 168 | # fall backs 169 | if not valid_joints(part_src): 170 | if bpart[0] == "lhip" and bpart[1] == "lknee": 171 | bpart = ["lhip"] 172 | bpart_indices = [jo.index(b) for b in bpart] 173 | part_src = np.float32(joints[bpart_indices]) 174 | elif bpart[0] == "rhip" and bpart[1] == "rknee": 175 | bpart = ["rhip"] 176 | bpart_indices = [jo.index(b) for b in bpart] 177 | part_src = np.float32(joints[bpart_indices]) 178 | elif bpart[0] == "lshoulder" and bpart[1] == "rshoulder" and bpart[2] == "cnose": 179 | bpart = ["lshoulder", "rshoulder", "rshoulder"] 180 | bpart_indices = [jo.index(b) for b in bpart] 181 | part_src = np.float32(joints[bpart_indices]) 182 | 183 | 184 | if not valid_joints(part_src): 185 | return None 186 | 187 | if part_src.shape[0] == 1: 188 | # leg fallback 189 | a = part_src[0] 190 | b = np.float32([a[0],o_h - 1]) 191 | part_src = np.float32([a,b]) 192 | 193 | if part_src.shape[0] == 4: 194 | pass 195 | elif part_src.shape[0] == 3: 196 | # lshoulder, rshoulder, cnose 197 | if bpart == ["lshoulder", "rshoulder", "rshoulder"]: 198 | segment = part_src[1] - part_src[0] 199 | normal = np.array([-segment[1],segment[0]]) 200 | if normal[1] > 0.0: 201 | normal = -normal 202 | 203 | a = part_src[0] + normal 204 | b = part_src[0] 205 | c = part_src[1] 206 | d = part_src[1] + normal 207 | part_src = np.float32([a,b,c,d]) 208 | else: 209 | assert bpart == ["lshoulder", "rshoulder", "cnose"] 210 | neck = 0.5*(part_src[0] + part_src[1]) 211 | neck_to_nose = part_src[2] - neck 212 | part_src = np.float32([neck + 2*neck_to_nose, neck]) 213 | 214 | # segment box 215 | segment = part_src[1] - part_src[0] 216 | normal = np.array([-segment[1],segment[0]]) 217 | alpha = 1.0 / 2.0 218 | a = part_src[0] + alpha*normal 219 | b = part_src[0] - alpha*normal 220 | c = part_src[1] - alpha*normal 221 | d = part_src[1] + alpha*normal 222 | #part_src = np.float32([a,b,c,d]) 223 | part_src = np.float32([b,c,d,a]) 224 | else: 225 | assert part_src.shape[0] == 2 226 | 227 | segment = part_src[1] - part_src[0] 228 | normal = np.array([-segment[1],segment[0]]) 229 | alpha = ar / 2.0 230 | a = part_src[0] + alpha*normal 231 | b = part_src[0] - alpha*normal 232 | c = part_src[1] - alpha*normal 233 | d = part_src[1] + alpha*normal 234 | part_src = np.float32([a,b,c,d]) 235 | 236 | dst = np.float32([[0.0,0.0],[0.0,1.0],[1.0,1.0],[1.0,0.0]]) 237 | part_dst = np.float32(wh * dst) 238 | 239 | M = cv2.getPerspectiveTransform(part_src, part_dst) 240 | return M 241 | 242 | 243 | def normalize(imgs, coords, stickmen, jo, box_factor): 244 | out_imgs = list() 245 | out_stickmen = list() 246 | 247 | bs = len(imgs) 248 | for i in range(bs): 249 | img = imgs[i] 250 | joints = coords[i] 251 | stickman = stickmen[i] 252 | 253 | h,w = img.shape[:2] 254 | o_h = h 255 | o_w = w 256 | h = h // 2**box_factor 257 | w = w // 2**box_factor 258 | wh = np.array([w,h]) 259 | wh = np.expand_dims(wh, 0) 260 | 261 | bparts = [ 262 | 
["lshoulder","lhip","rhip","rshoulder"], 263 | ["lshoulder", "rshoulder", "cnose"], 264 | ["lshoulder","lelbow"], 265 | ["lelbow", "lwrist"], 266 | ["rshoulder","relbow"], 267 | ["relbow", "rwrist"], 268 | ["lhip", "lknee"], 269 | ["rhip", "rknee"]] 270 | ar = 0.5 271 | 272 | part_imgs = list() 273 | part_stickmen = list() 274 | for bpart in bparts: 275 | part_img = np.zeros((h,w,3)) 276 | part_stickman = np.zeros((h,w,3)) 277 | M = get_crop(bpart, joints, jo, wh, o_w, o_h, ar) 278 | 279 | if M is not None: 280 | part_img = cv2.warpPerspective(img, M, (h,w), borderMode = cv2.BORDER_REPLICATE) 281 | part_stickman = cv2.warpPerspective(stickman, M, (h,w), borderMode = cv2.BORDER_REPLICATE) 282 | 283 | part_imgs.append(part_img) 284 | part_stickmen.append(part_stickman) 285 | img = np.concatenate(part_imgs, axis = 2) 286 | stickman = np.concatenate(part_stickmen, axis = 2) 287 | 288 | out_imgs.append(img) 289 | out_stickmen.append(stickman) 290 | out_imgs = np.stack(out_imgs) 291 | out_stickmen = np.stack(out_stickmen) 292 | return out_imgs, out_stickmen 293 | 294 | 295 | class IndexFlow(object): 296 | """Batches from index file.""" 297 | def __init__( 298 | self, 299 | shape, 300 | index_path, 301 | train, 302 | box_factor, 303 | fill_batches = True, 304 | shuffle = True, 305 | return_keys = ["imgs", "joints", "norm_imgs", "norm_joints"]): 306 | self.shape = shape 307 | self.batch_size = self.shape[0] 308 | self.img_shape = self.shape[1:] 309 | self.box_factor = box_factor 310 | with open(index_path, "rb") as f: 311 | self.index = pickle.load(f) 312 | self.basepath = os.path.dirname(index_path) 313 | self.train = train 314 | self.fill_batches = fill_batches 315 | self.shuffle_ = shuffle 316 | self.return_keys = return_keys 317 | 318 | self.jo = self.index["joint_order"] 319 | # rescale joint coordinates to image shape 320 | h,w = self.img_shape[:2] 321 | wh = np.array([[[w,h]]]) 322 | self.index["joints"] = self.index["joints"] * wh 323 | 324 | self.indices = np.array( 325 | [i for i in range(len(self.index["train"])) 326 | if self._filter(i)]) 327 | 328 | self.n = self.indices.shape[0] 329 | self.shuffle() 330 | 331 | 332 | def _filter(self, i): 333 | good = True 334 | good = good and (self.index["train"][i] == self.train) 335 | joints = self.index["joints"][i] 336 | required_joints = ["lshoulder","rshoulder","lhip","rhip"] 337 | joint_indices = [self.jo.index(b) for b in required_joints] 338 | joints = np.float32(joints[joint_indices]) 339 | good = good and valid_joints(joints) 340 | return good 341 | 342 | 343 | def __next__(self): 344 | batch = dict() 345 | 346 | # get indices for batch 347 | batch_start, batch_end = self.batch_start, self.batch_start + self.batch_size 348 | batch_indices = self.indices[batch_start:batch_end] 349 | if self.fill_batches and batch_indices.shape[0] != self.batch_size: 350 | n_missing = self.batch_size - batch_indices.shape[0] 351 | batch_indices = np.concatenate([batch_indices, self.indices[:n_missing]], axis = 0) 352 | assert batch_indices.shape[0] == self.batch_size 353 | batch_indices = np.array(batch_indices) 354 | batch["indices"] = batch_indices 355 | 356 | # prepare next batch 357 | if batch_end >= self.n: 358 | self.shuffle() 359 | else: 360 | self.batch_start = batch_end 361 | 362 | # prepare batch data 363 | # load images 364 | batch["imgs"] = list() 365 | for i in batch_indices: 366 | relpath = self.index["imgs"][i] 367 | path = os.path.join(self.basepath, relpath) 368 | batch["imgs"].append(load_img(path, target_size = self.img_shape)) 369 | 
batch["imgs"] = np.stack(batch["imgs"]) 370 | batch["imgs"] = preprocess(batch["imgs"]) 371 | 372 | # load joint coordinates 373 | batch["joints_coordinates"] = [self.index["joints"][i] for i in batch_indices] 374 | 375 | # generate stickmen images from coordinates 376 | batch["joints"] = list() 377 | for joints in batch["joints_coordinates"]: 378 | img = make_joint_img(self.img_shape, self.jo, joints) 379 | batch["joints"].append(img) 380 | batch["joints"] = np.stack(batch["joints"]) 381 | batch["joints"] = preprocess(batch["joints"]) 382 | 383 | imgs, joints = normalize(batch["imgs"], batch["joints_coordinates"], batch["joints"], self.jo, self.box_factor) 384 | batch["norm_imgs"] = imgs 385 | batch["norm_joints"] = joints 386 | 387 | batch_list = [batch[k] for k in self.return_keys] 388 | return batch_list 389 | 390 | 391 | def shuffle(self): 392 | self.batch_start = 0 393 | if self.shuffle_: 394 | np.random.shuffle(self.indices) 395 | 396 | 397 | def get_batches( 398 | shape, 399 | index_path, 400 | train, 401 | box_factor, 402 | fill_batches = True, 403 | shuffle = True, 404 | return_keys = ["imgs", "joints", "norm_imgs", "norm_joints"]): 405 | """Buffered IndexFlow.""" 406 | flow = IndexFlow(shape, index_path, train, box_factor, fill_batches, shuffle, return_keys) 407 | return BufferedWrapper(flow) 408 | 409 | 410 | if __name__ == "__main__": 411 | import sys 412 | if not len(sys.argv) == 2: 413 | print("Useage: {} ".format(sys.argv[0])) 414 | exit(1) 415 | 416 | batches = get_batches( 417 | shape = (16, 128, 128, 3), 418 | index_path = sys.argv[1], 419 | train = True, 420 | box_factor = 2, 421 | shuffle = True) 422 | X, C, XN, CN = next(batches) 423 | plot_batch(X, "images.png") 424 | plot_batch(C, "joints.png") 425 | -------------------------------------------------------------------------------- /coco.yaml: -------------------------------------------------------------------------------- 1 | data_index: data/coco/index.p 2 | batch_size: 8 3 | init_batches: 4 4 | spatial_size: 256 5 | box_factor: 2 6 | bottleneck_factor: 2 7 | 8 | lr: 1.0e-3 9 | lr_decay_begin: 1000 10 | lr_decay_end: 100000 11 | log_freq: 250 12 | ckpt_freq: 1000 13 | test_freq: 1000 14 | drop_prob: 0.0 15 | 16 | feature_layers: [ 17 | "input_1", 18 | "block1_conv2", 19 | "block2_conv2", 20 | "block3_conv2", 21 | "block4_conv2", 22 | "block5_conv2"] 23 | 24 | feature_weights: [ 25 | 1.0, 26 | 1.0, 27 | 1.0, 28 | 1.0, 29 | 1.0, 30 | 1.0] 31 | 32 | gram_weights: [ 33 | 0.0, 34 | 0.01, 35 | 0.01, 36 | 0.01, 37 | 0.01, 38 | 0.01] 39 | -------------------------------------------------------------------------------- /coco_retrain.yaml: -------------------------------------------------------------------------------- 1 | data_index: data/coco/index.p 2 | batch_size: 8 3 | init_batches: 4 4 | spatial_size: 256 5 | box_factor: 2 6 | bottleneck_factor: 2 7 | 8 | lr: 1.0e-3 9 | lr_decay_begin: 1000 10 | lr_decay_end: 100000 11 | log_freq: 250 12 | ckpt_freq: 1000 13 | test_freq: 1000 14 | drop_prob: 0.0 15 | 16 | feature_layers: [ 17 | "input_1", 18 | "block1_conv2", 19 | "block2_conv2", 20 | "block3_conv2", 21 | "block4_conv2", 22 | "block5_conv2"] 23 | 24 | feature_weights: [ 25 | 2.0, 26 | 2.0, 27 | 2.0, 28 | 2.0, 29 | 2.0, 30 | 2.0] 31 | 32 | gram_weights: [ 33 | 0.0, 34 | 0.02, 35 | 0.02, 36 | 0.02, 37 | 0.02, 38 | 0.02] 39 | -------------------------------------------------------------------------------- /deepfashion.yaml: -------------------------------------------------------------------------------- 1 | data_index: 
data/deepfashion/index.p 2 | batch_size: 8 3 | init_batches: 4 4 | spatial_size: 256 5 | box_factor: 2 6 | bottleneck_factor: 2 7 | 8 | lr: 1.0e-3 9 | lr_decay_begin: 1000 10 | lr_decay_end: 100000 11 | log_freq: 250 12 | ckpt_freq: 1000 13 | test_freq: 1000 14 | drop_prob: 0.1 15 | 16 | feature_layers: [ 17 | "input_1", 18 | "block1_conv2", 19 | "block2_conv2", 20 | "block3_conv2", 21 | "block4_conv2", 22 | "block5_conv2"] 23 | 24 | feature_weights: [ 25 | 1.0, 26 | 1.0, 27 | 1.0, 28 | 1.0, 29 | 1.0, 30 | 1.0] 31 | 32 | gram_weights: [ 33 | 0.0, 34 | 0.0, 35 | 0.0, 36 | 0.0, 37 | 0.0, 38 | 0.0] 39 | -------------------------------------------------------------------------------- /deepfashion_retrain.yaml: -------------------------------------------------------------------------------- 1 | data_index: data/deepfashion/index.p 2 | batch_size: 8 3 | init_batches: 4 4 | spatial_size: 256 5 | box_factor: 2 6 | bottleneck_factor: 2 7 | 8 | lr: 1.0e-3 9 | lr_decay_begin: 1000 10 | lr_decay_end: 100000 11 | log_freq: 250 12 | ckpt_freq: 1000 13 | test_freq: 1000 14 | drop_prob: 0.1 15 | 16 | feature_layers: [ 17 | "input_1", 18 | "block1_conv2", 19 | "block2_conv2", 20 | "block3_conv2", 21 | "block4_conv2", 22 | "block5_conv2"] 23 | 24 | feature_weights: [ 25 | 1.0, 26 | 1.0, 27 | 1.0, 28 | 1.0, 29 | 1.0, 30 | 1.0] 31 | 32 | gram_weights: [ 33 | 0.1, 34 | 0.1, 35 | 0.1, 36 | 0.1, 37 | 0.1, 38 | 0.1] 39 | -------------------------------------------------------------------------------- /deeploss.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy as np 3 | # vgg19 from keras 4 | from tensorflow.contrib.keras.api.keras.models import Model 5 | from tensorflow.contrib.keras.api.keras.applications.vgg19 import VGG19 6 | from tensorflow.contrib.keras.api.keras import backend as K 7 | 8 | 9 | def preprocess_input(x): 10 | """Preprocesses a tensor encoding a batch of images. 11 | # Arguments 12 | x: input tensor, 4D in [-1,1] 13 | # Returns 14 | Preprocessed tensor. 15 | """ 16 | # from [-1, 1] to [0,255.0] 17 | x = (x + 1.0) / 2.0 * 255.0 18 | # 'RGB'->'BGR' 19 | x = x[:, :, :, ::-1] 20 | # Zero-center by mean pixel 21 | x = x - np.array([103.939, 116.779, 123.68]).reshape((1,1,1,3)) 22 | return x 23 | 24 | 25 | class VGG19Features(object): 26 | def __init__(self, session, 27 | feature_layers = None, feature_weights = None, gram_weights = None): 28 | K.set_session(session) 29 | self.base_model = VGG19( 30 | include_top = False, 31 | weights='imagenet') 32 | if feature_layers is None: 33 | feature_layers = [ 34 | "input_1", 35 | "block1_conv2", "block2_conv2", 36 | "block3_conv2", "block4_conv2", 37 | "block5_conv2"] 38 | self.layer_names = [l.name for l in self.base_model.layers] 39 | for k in feature_layers: 40 | if not k in self.layer_names: 41 | raise KeyError( 42 | "Invalid layer {}. 
Available layers: {}".format( 43 | k, self.layer_names)) 44 | features = [self.base_model.get_layer(k).output for k in feature_layers] 45 | self.model = Model( 46 | inputs = self.base_model.input, 47 | outputs = features) 48 | if feature_weights is None: 49 | feature_weights = len(feature_layers) * [1.0] 50 | if gram_weights is None: 51 | gram_weights = len(feature_layers) * [0.1] 52 | self.feature_weights = feature_weights 53 | self.gram_weights = gram_weights 54 | assert len(self.feature_weights) == len(features) 55 | self.use_gram = np.max(self.gram_weights) > 0.0 56 | 57 | self.variables = self.base_model.weights 58 | 59 | 60 | def extract_features(self, x): 61 | """x should be rgb in [-1,1].""" 62 | x = preprocess_input(x) 63 | features = self.model.predict(x) 64 | return features 65 | 66 | 67 | def make_feature_ops(self, x): 68 | """x should be rgb tensor in [-1,1].""" 69 | x = preprocess_input(x) 70 | features = self.model(x) 71 | return features 72 | 73 | 74 | def grams(self, fs): 75 | gs = list() 76 | for f in fs: 77 | bs, h, w, c = f.shape.as_list() 78 | f = tf.reshape(f, [bs, h*w, c]) 79 | ft = tf.transpose(f, [0,2,1]) 80 | g = tf.matmul(ft, f) 81 | g = g / (4.0*h*w) 82 | gs.append(g) 83 | return gs 84 | 85 | 86 | def make_loss_op(self, x, y): 87 | """x, y should be rgb tensors in [-1,1].""" 88 | x = preprocess_input(x) 89 | x_features = self.model(x) 90 | 91 | y = preprocess_input(y) 92 | y_features = self.model(y) 93 | 94 | x_grams = self.grams(x_features) 95 | y_grams = self.grams(y_features) 96 | 97 | losses = [ 98 | tf.reduce_mean(tf.abs(xf - yf)) for xf, yf in zip( 99 | x_features, y_features)] 100 | gram_losses = [ 101 | tf.reduce_mean(tf.abs(xg - yg)) for xg, yg in zip( 102 | x_grams, y_grams)] 103 | 104 | for i in range(len(losses)): 105 | losses[i] = self.feature_weights[i] * losses[i] 106 | gram_losses[i] = self.gram_weights[i] * gram_losses[i] 107 | loss = tf.add_n(losses) 108 | if self.use_gram: 109 | loss = loss + tf.add_n(gram_losses) 110 | 111 | self.losses = losses 112 | self.gram_losses = gram_losses 113 | 114 | return loss 115 | -------------------------------------------------------------------------------- /download_data.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | if [ $# -lt 2 ] 4 | then 5 | echo "Usage: $0 " 6 | echo "where is one of coco, deepfashion, market" 7 | echo "and is the destination folder" 8 | exit 1 9 | fi 10 | 11 | data=$1 12 | dest=$2 13 | mkdir -p "$dest" 14 | src="http://129.206.117.181:8080/${data}.tar.gz" 15 | wget -P "$dest" -c $src 16 | 17 | tar --skip-old-files -xzf "${dest}/${data}.tar.gz" -C "${dest}" 18 | 19 | srcdir=$(dirname $(realpath $0)) 20 | mkdir -p "${srcdir}/data" 21 | ln -s "${dest}/${data}" "${srcdir}/data/" 22 | 23 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | config = tf.ConfigProto() 3 | config.gpu_options.allow_growth = False 4 | session = tf.Session(config = config) 5 | 6 | import os, logging, shutil, datetime 7 | import glob 8 | import argparse 9 | import yaml 10 | import numpy as np 11 | from tqdm import tqdm, trange 12 | 13 | import nn 14 | import models 15 | from batches import get_batches, plot_batch, postprocess, n_boxes 16 | import deeploss 17 | 18 | 19 | def init_logging(out_base_dir): 20 | # get unique output directory based on current time 21 | os.makedirs(out_base_dir, exist_ok 
= True) 22 | now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S") 23 | out_dir = os.path.join(out_base_dir, now) 24 | os.makedirs(out_dir, exist_ok = False) 25 | # copy source code to logging dir to have an idea what the run was about 26 | this_file = os.path.realpath(__file__) 27 | assert(this_file.endswith(".py")) 28 | shutil.copy(this_file, out_dir) 29 | # copy all py files to logging dir 30 | src_dir = os.path.dirname(this_file) 31 | py_files = glob.glob(os.path.join(src_dir, "*.py")) 32 | for py_file in py_files: 33 | shutil.copy(py_file, out_dir) 34 | # init logging 35 | logging.basicConfig(filename = os.path.join(out_dir, 'log.txt')) 36 | logger = logging.getLogger(__name__) 37 | logger.setLevel(logging.DEBUG) 38 | return out_dir, logger 39 | 40 | 41 | class Model(object): 42 | def __init__(self, config, out_dir, logger): 43 | self.config = config 44 | self.batch_size = config["batch_size"] 45 | self.img_shape = 2*[config["spatial_size"]] + [3] 46 | self.bottleneck_factor = config["bottleneck_factor"] 47 | self.box_factor = config["box_factor"] 48 | self.imgn_shape = 2*[config["spatial_size"]//(2**self.box_factor)] + [n_boxes*3] 49 | self.init_batches = config["init_batches"] 50 | 51 | self.initial_lr = config["lr"] 52 | self.lr_decay_begin = config["lr_decay_begin"] 53 | self.lr_decay_end = config["lr_decay_end"] 54 | 55 | self.out_dir = out_dir 56 | self.logger = logger 57 | self.log_frequency = config["log_freq"] 58 | self.ckpt_frequency = config["ckpt_freq"] 59 | self.test_frequency = config["test_freq"] 60 | self.checkpoint_best = False 61 | 62 | self.dropout_p = config["drop_prob"] 63 | 64 | self.best_loss = float("inf") 65 | self.checkpoint_dir = os.path.join(self.out_dir, "checkpoints") 66 | os.makedirs(self.checkpoint_dir, exist_ok = True) 67 | 68 | self.define_models() 69 | self.define_graph() 70 | 71 | 72 | def define_models(self): 73 | n_latent_scales = 2 74 | n_scales = 1 + int(np.round(np.log2(self.img_shape[0]))) - self.bottleneck_factor 75 | n_filters = 32 76 | self.enc_up_pass = models.make_model( 77 | "enc_up", models.enc_up, 78 | n_scales = n_scales - self.box_factor, 79 | n_filters = n_filters*2**self.box_factor) 80 | self.enc_down_pass = models.make_model( 81 | "enc_down", models.enc_down, 82 | n_scales = n_scales - self.box_factor, 83 | n_latent_scales = n_latent_scales) 84 | self.dec_up_pass = models.make_model( 85 | "dec_up", models.dec_up, 86 | n_scales = n_scales, 87 | n_filters = n_filters) 88 | self.dec_down_pass = models.make_model( 89 | "dec_down", models.dec_down, 90 | n_scales = n_scales, 91 | n_latent_scales = n_latent_scales) 92 | self.dec_params = models.make_model( 93 | "dec_params", models.dec_parameters) 94 | 95 | 96 | def train_forward_pass(self, x, c, xn, cn, dropout_p, init = False): 97 | kwargs = {"init": init, "dropout_p": dropout_p} 98 | # encoder 99 | hs = self.enc_up_pass(xn, cn, **kwargs) 100 | es, qs, zs_posterior = self.enc_down_pass(hs, **kwargs) 101 | # decoder 102 | gs = self.dec_up_pass(c, **kwargs) 103 | ds, ps, zs_prior = self.dec_down_pass(gs, zs_posterior, training = True, **kwargs) 104 | params = self.dec_params(ds[-1], **kwargs) 105 | activations = hs + es + gs + ds 106 | return params, qs, ps, activations 107 | 108 | 109 | def test_forward_pass(self, c): 110 | kwargs = {"init": False, "dropout_p": 0.0} 111 | # decoder 112 | gs = self.dec_up_pass(c, **kwargs) 113 | ds, ps, zs_prior = self.dec_down_pass(gs, [], training = False, **kwargs) 114 | params = self.dec_params(ds[-1], **kwargs) 115 | return params 116 | 117 | 
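    # Note: transfer_pass infers the posterior latents for an appearance image
    # (infer_x, infer_c) with the encoder and decodes them under a different
    # conditioning generate_c. With use_mean = True the posterior parameters
    # are fed to the decoder in place of samples drawn from them.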
118 | def transfer_pass(self, infer_x, infer_c, generate_c): 119 | kwargs = {"init": False, "dropout_p": 0.0} 120 | # infer latent code 121 | hs = self.enc_up_pass(infer_x, infer_c, **kwargs) 122 | es, qs, zs_posterior = self.enc_down_pass(hs, **kwargs) 123 | zs_mean = list(qs) 124 | # generate from inferred latent code and conditioning 125 | gs = self.dec_up_pass(generate_c, **kwargs) 126 | use_mean = True 127 | if use_mean: 128 | ds, ps, zs_prior = self.dec_down_pass(gs, zs_mean, training = True, **kwargs) 129 | else: 130 | ds, ps, zs_prior = self.dec_down_pass(gs, zs_posterior, training = True, **kwargs) 131 | params = self.dec_params(ds[-1], **kwargs) 132 | return params 133 | 134 | 135 | def sample(self, params, **kwargs): 136 | return params 137 | 138 | 139 | def likelihood_loss(self, x, params): 140 | return 5.0*self.vgg19.make_loss_op(x, params) 141 | 142 | 143 | def define_graph(self): 144 | # pretrained net for perceptual loss 145 | self.vgg19 = deeploss.VGG19Features(session, 146 | feature_layers = self.config["feature_layers"], 147 | feature_weights = self.config["feature_weights"], 148 | gram_weights = self.config["gram_weights"]) 149 | 150 | global_step = tf.Variable(0, trainable = False, name = "global_step") 151 | lr = nn.make_linear_var( 152 | global_step, 153 | self.lr_decay_begin, self.lr_decay_end, 154 | self.initial_lr, 0.0, 155 | 0.0, self.initial_lr) 156 | kl_weight = nn.make_linear_var( 157 | global_step, 158 | self.lr_decay_end // 2, 3 * self.lr_decay_end // 4, 159 | 1e-6, 1.0, 160 | 1e-6, 1.0) 161 | 162 | # initialization 163 | self.x_init = tf.placeholder( 164 | tf.float32, 165 | shape = [self.init_batches * self.batch_size] + self.img_shape) 166 | self.c_init = tf.placeholder( 167 | tf.float32, 168 | shape = [self.init_batches * self.batch_size] + self.img_shape) 169 | self.xn_init = tf.placeholder( 170 | tf.float32, 171 | shape = [self.init_batches * self.batch_size] + self.imgn_shape) 172 | self.cn_init = tf.placeholder( 173 | tf.float32, 174 | shape = [self.init_batches * self.batch_size] + self.imgn_shape) 175 | self.dd_init_op = self.train_forward_pass( 176 | self.x_init, self.c_init, 177 | self.xn_init, self.cn_init, 178 | dropout_p = self.dropout_p, init = True) 179 | 180 | # training 181 | self.x = tf.placeholder( 182 | tf.float32, 183 | shape = [self.batch_size] + self.img_shape) 184 | self.c = tf.placeholder( 185 | tf.float32, 186 | shape = [self.batch_size] + self.img_shape) 187 | self.xn = tf.placeholder( 188 | tf.float32, 189 | shape = [self.batch_size] + self.imgn_shape) 190 | self.cn = tf.placeholder( 191 | tf.float32, 192 | shape = [self.batch_size] + self.imgn_shape) 193 | # compute parameters of model distribution 194 | params, qs, ps, activations = self.train_forward_pass( 195 | self.x, self.c, 196 | self.xn, self.cn, 197 | dropout_p = self.dropout_p) 198 | # sample from model distribution 199 | sample = self.sample(params) 200 | # maximize likelihood 201 | likelihood_loss = self.likelihood_loss(self.x, params) 202 | kl_loss = tf.to_float(0.0) 203 | for q, p in zip(qs, ps): 204 | self.logger.info("Latent shape: {}".format(q.shape.as_list())) 205 | kl_loss += models.latent_kl(q, p) 206 | loss = likelihood_loss + kl_weight * kl_loss 207 | 208 | # testing 209 | test_forward = self.test_forward_pass(self.c) 210 | test_sample = self.sample(test_forward) 211 | 212 | # reconstruction 213 | reconstruction_params, _, _, _ = self.train_forward_pass( 214 | self.x, self.c, 215 | self.xn, self.cn, 216 | dropout_p = 0.0) 217 | self.reconstruction = 
self.sample(reconstruction_params) 218 | 219 | # optimization 220 | self.trainable_variables = [v for v in tf.trainable_variables() 221 | if not v in self.vgg19.variables] 222 | optimizer = tf.train.AdamOptimizer(learning_rate = lr, beta1 = 0.5, beta2 = 0.9) 223 | opt_op = optimizer.minimize(loss, var_list = self.trainable_variables) 224 | with tf.control_dependencies([opt_op]): 225 | self.train_op = tf.assign(global_step, global_step + 1) 226 | 227 | 228 | # logging and visualization 229 | self.log_ops = dict() 230 | self.log_ops["global_step"] = global_step 231 | self.log_ops["likelihood_loss"] = likelihood_loss 232 | self.log_ops["kl_loss"] = kl_loss 233 | self.log_ops["kl_weight"] = kl_weight 234 | self.log_ops["loss"] = loss 235 | self.img_ops = dict() 236 | self.img_ops["sample"] = sample 237 | self.img_ops["test_sample"] = test_sample 238 | self.img_ops["x"] = self.x 239 | self.img_ops["c"] = self.c 240 | for i, l in enumerate(self.vgg19.losses): 241 | self.log_ops["vgg_loss_{}".format(i)] = l 242 | 243 | # keep seperate train and validation summaries 244 | # only training summary contains histograms 245 | train_summaries = list() 246 | for k, v in self.log_ops.items(): 247 | train_summaries.append(tf.summary.scalar(k, v)) 248 | self.train_summary_op = tf.summary.merge_all() 249 | 250 | valid_summaries = list() 251 | for k, v in self.log_ops.items(): 252 | valid_summaries.append(tf.summary.scalar(k+"_valid", v)) 253 | self.valid_summary_op = tf.summary.merge(valid_summaries) 254 | 255 | # all variables for initialization 256 | self.variables = [v for v in tf.global_variables() 257 | if not v in self.vgg19.variables] 258 | 259 | self.logger.info("Defined graph") 260 | 261 | 262 | def init_graph(self, init_batch): 263 | self.writer = tf.summary.FileWriter( 264 | self.out_dir, 265 | session.graph) 266 | self.saver = tf.train.Saver(self.variables) 267 | initializer_op = tf.variables_initializer(self.variables) 268 | feed = { 269 | self.xn_init: init_batch[2], 270 | self.cn_init: init_batch[3], 271 | self.x_init: init_batch[0], 272 | self.c_init: init_batch[1]} 273 | session.run(initializer_op, feed) 274 | session.run(self.dd_init_op, feed) 275 | self.logger.info("Initialized model from scratch") 276 | 277 | 278 | def restore_graph(self, restore_path): 279 | self.writer = tf.summary.FileWriter( 280 | self.out_dir, 281 | session.graph) 282 | self.saver = tf.train.Saver(self.variables) 283 | self.saver.restore(session, restore_path) 284 | self.logger.info("Restored model from {}".format(restore_path)) 285 | 286 | 287 | def reset_global_step(self): 288 | session.run(tf.assign(self.log_ops["global_step"], 0)) 289 | self.logger.info("Reset global_step") 290 | 291 | 292 | def fit(self, batches, valid_batches = None): 293 | start_step = self.log_ops["global_step"].eval(session) 294 | self.valid_batches = valid_batches 295 | for batch in trange(start_step, self.lr_decay_end): 296 | X_batch, C_batch, XN_batch, CN_batch = next(batches) 297 | feed_dict = { 298 | self.xn: XN_batch, 299 | self.cn: CN_batch, 300 | self.x: X_batch, 301 | self.c: C_batch} 302 | fetch_dict = {"train": self.train_op} 303 | if self.log_ops["global_step"].eval(session) % self.log_frequency == 0: 304 | fetch_dict["log"] = self.log_ops 305 | fetch_dict["img"] = self.img_ops 306 | fetch_dict["summary"] = self.train_summary_op 307 | result = session.run(fetch_dict, feed_dict) 308 | self.log_result(result) 309 | 310 | 311 | def log_result(self, result, **kwargs): 312 | global_step = self.log_ops["global_step"].eval(session) 313 
| if "summary" in result: 314 | self.writer.add_summary(result["summary"], global_step) 315 | self.writer.flush() 316 | if "log" in result: 317 | for k in sorted(result["log"]): 318 | v = result["log"][k] 319 | self.logger.info("{}: {}".format(k, v)) 320 | if "img" in result: 321 | for k, v in result["img"].items(): 322 | plot_batch(v, os.path.join( 323 | self.out_dir, 324 | k + "_{:07}.png".format(global_step))) 325 | 326 | if self.valid_batches is not None: 327 | # validation run 328 | X_batch, C_batch, XN_batch, CN_batch = next(self.valid_batches) 329 | feed_dict = { 330 | self.xn: XN_batch, 331 | self.cn: CN_batch, 332 | self.x: X_batch, 333 | self.c: C_batch} 334 | fetch_dict = dict() 335 | fetch_dict["imgs"] = self.img_ops 336 | fetch_dict["summary"] = self.valid_summary_op 337 | fetch_dict["validation_loss"] = self.log_ops["loss"] 338 | result = session.run(fetch_dict, feed_dict) 339 | self.writer.add_summary(result["summary"], global_step) 340 | self.writer.flush() 341 | # display samples 342 | imgs = result["imgs"] 343 | for k, v in imgs.items(): 344 | plot_batch(v, os.path.join( 345 | self.out_dir, 346 | "valid_" + k + "_{:07}.png".format(global_step))) 347 | # log validation loss 348 | validation_loss = result["validation_loss"] 349 | self.logger.info("{}: {}".format("validation_loss", validation_loss)) 350 | if self.checkpoint_best and validation_loss < self.best_loss: 351 | # checkpoint if validation loss improved 352 | self.logger.info("step {}: Validation loss improved from {:.4e} to {:.4e}".format(global_step, self.best_loss, validation_loss)) 353 | self.best_loss = validation_loss 354 | self.make_checkpoint(global_step, prefix = "best_") 355 | if global_step % self.test_frequency == 0: 356 | if self.valid_batches is not None: 357 | # testing 358 | X_batch, C_batch, XN_batch, CN_batch = next(self.valid_batches) 359 | x_gen = self.test(C_batch) 360 | for k in x_gen: 361 | plot_batch(x_gen[k], os.path.join( 362 | self.out_dir, 363 | "testing_{}_{:07}.png".format(k, global_step))) 364 | # transfer 365 | bs = X_batch.shape[0] 366 | imgs = list() 367 | imgs.append(np.zeros_like(X_batch[0,...])) 368 | for r in range(bs): 369 | imgs.append(C_batch[r,...]) 370 | for i in range(bs): 371 | x_infer = XN_batch[i,...] 372 | c_infer = CN_batch[i,...] 
373 | imgs.append(X_batch[i,...]) 374 | 375 | x_infer_batch = x_infer[None,...].repeat(bs, axis = 0) 376 | c_infer_batch = c_infer[None,...].repeat(bs, axis = 0) 377 | c_generate_batch = C_batch 378 | results = model.transfer(x_infer_batch, c_infer_batch, c_generate_batch) 379 | for j in range(bs): 380 | imgs.append(results[j,...]) 381 | imgs = np.stack(imgs, axis = 0) 382 | plot_batch(imgs, os.path.join( 383 | out_dir, 384 | "transfer_{:07}.png".format(global_step))) 385 | if global_step % self.ckpt_frequency == 0: 386 | self.make_checkpoint(global_step) 387 | 388 | 389 | def make_checkpoint(self, global_step, prefix = ""): 390 | fname = os.path.join(self.checkpoint_dir, prefix + "model.ckpt") 391 | self.saver.save( 392 | session, 393 | fname, 394 | global_step = global_step) 395 | self.logger.info("Saved model to {}".format(fname)) 396 | 397 | 398 | def test(self, c_batch): 399 | results = dict() 400 | results["cond"] = c_batch 401 | sample = session.run(self.img_ops["test_sample"], 402 | {self.c: c_batch}) 403 | results["test_sample"] = sample 404 | return results 405 | 406 | 407 | def reconstruct(self, x_batch, c_batch): 408 | return session.run( 409 | self.reconstruction, 410 | {self.x: x_batch, self.c: c_batch}) 411 | 412 | 413 | def transfer(self, x_encode, c_encode, c_decode): 414 | initialized = getattr(self, "_init_transfer", False) 415 | if not initialized: 416 | # transfer 417 | self.c_generator = tf.placeholder( 418 | tf.float32, 419 | shape = [self.batch_size] + self.img_shape) 420 | infer_x = self.xn 421 | infer_c = self.cn 422 | generate_c = self.c_generator 423 | transfer_params = self.transfer_pass(infer_x, infer_c, generate_c) 424 | self.transfer_mean_sample = self.sample(transfer_params) 425 | self._init_transfer = True 426 | 427 | return session.run( 428 | self.transfer_mean_sample, { 429 | self.xn: x_encode, 430 | self.cn: c_encode, 431 | self.c_generator: c_decode}) 432 | 433 | 434 | if __name__ == "__main__": 435 | default_log_dir = os.path.join(os.getcwd(), "log") 436 | 437 | parser = argparse.ArgumentParser() 438 | parser.add_argument("--config", required = True, help = "path to config") 439 | parser.add_argument("--mode", default = "train", 440 | choices=["train", "test", "add_reconstructions", "transfer"]) 441 | parser.add_argument("--log_dir", default = default_log_dir, help = "path to log into") 442 | parser.add_argument("--checkpoint", help = "path to checkpoint to restore") 443 | parser.add_argument("--retrain", dest = "retrain", action = "store_true", help = "reset global_step to zero") 444 | parser.set_defaults(retrain = False) 445 | 446 | opt = parser.parse_args() 447 | 448 | with open(opt.config) as f: 449 | config = yaml.load(f) 450 | 451 | out_dir, logger = init_logging(opt.log_dir) 452 | logger.info(opt) 453 | logger.info(yaml.dump(config)) 454 | 455 | if opt.mode == "train": 456 | batch_size = config["batch_size"] 457 | img_shape = 2*[config["spatial_size"]] + [3] 458 | data_shape = [batch_size] + img_shape 459 | init_shape = [config["init_batches"] * batch_size] + img_shape 460 | box_factor = config["box_factor"] 461 | 462 | data_index = config["data_index"] 463 | batches = get_batches(data_shape, data_index, train = True, box_factor = box_factor) 464 | init_batches = get_batches(init_shape, data_index, train = True, box_factor = box_factor) 465 | valid_batches = get_batches(data_shape, data_index, train = False, box_factor = box_factor) 466 | logger.info("Number of training samples: {}".format(batches.n)) 467 | logger.info("Number of validation 
samples: {}".format(valid_batches.n)) 468 | 469 | model = Model(config, out_dir, logger) 470 | if opt.checkpoint is not None: 471 | model.restore_graph(opt.checkpoint) 472 | else: 473 | model.init_graph(next(init_batches)) 474 | if opt.retrain: 475 | model.reset_global_step() 476 | model.fit(batches, valid_batches) 477 | elif opt.mode == "transfer": 478 | batch_size = config["batch_size"] 479 | img_shape = 2*[config["spatial_size"]] + [3] 480 | data_shape = [batch_size] + img_shape 481 | box_factor = config["box_factor"] 482 | data_index = config["data_index"] 483 | 484 | valid_batches = get_batches(data_shape, data_index, 485 | box_factor = box_factor, train = False) 486 | 487 | model = Model(config, out_dir, logger) 488 | assert opt.checkpoint is not None 489 | model.restore_graph(opt.checkpoint) 490 | 491 | for step in trange(10): 492 | X_batch, C_batch, XN_batch, CN_batch = next(valid_batches) 493 | bs = X_batch.shape[0] 494 | imgs = list() 495 | imgs.append(np.zeros_like(X_batch[0,...])) 496 | for r in range(bs): 497 | imgs.append(C_batch[r,...]) 498 | for i in range(bs): 499 | x_infer = XN_batch[i,...] 500 | c_infer = CN_batch[i,...] 501 | imgs.append(X_batch[i,...]) 502 | 503 | x_infer_batch = x_infer[None,...].repeat(bs, axis = 0) 504 | c_infer_batch = c_infer[None,...].repeat(bs, axis = 0) 505 | c_generate_batch = C_batch 506 | results = model.transfer(x_infer_batch, c_infer_batch, c_generate_batch) 507 | for j in range(bs): 508 | imgs.append(results[j,...]) 509 | imgs = np.stack(imgs, axis = 0) 510 | plot_batch(imgs, os.path.join( 511 | out_dir, 512 | "transfer_{}.png".format(step))) 513 | else: 514 | raise NotImplemented() 515 | -------------------------------------------------------------------------------- /market.yaml: -------------------------------------------------------------------------------- 1 | data_index: data/market/index.p 2 | batch_size: 16 3 | init_batches: 4 4 | spatial_size: 128 5 | box_factor: 1 6 | bottleneck_factor: 1 7 | 8 | lr: 1.0e-3 9 | lr_decay_begin: 1000 10 | lr_decay_end: 100000 11 | log_freq: 250 12 | ckpt_freq: 1000 13 | test_freq: 1000 14 | drop_prob: 0.1 15 | 16 | feature_layers: [ 17 | "input_1", 18 | "block1_conv2", 19 | "block2_conv2", 20 | "block3_conv2", 21 | "block4_conv2", 22 | "block5_conv2"] 23 | 24 | feature_weights: [ 25 | 1.0, 26 | 1.0, 27 | 1.0, 28 | 1.0, 29 | 1.0, 30 | 1.0] 31 | 32 | gram_weights: [ 33 | 0.0, 34 | 0.0, 35 | 0.0, 36 | 0.0, 37 | 0.0, 38 | 0.0] 39 | -------------------------------------------------------------------------------- /market_retrain.yaml: -------------------------------------------------------------------------------- 1 | data_index: data/market/index.p 2 | batch_size: 16 3 | init_batches: 4 4 | spatial_size: 128 5 | box_factor: 1 6 | bottleneck_factor: 1 7 | 8 | lr: 1.0e-3 9 | lr_decay_begin: 1000 10 | lr_decay_end: 100000 11 | log_freq: 250 12 | ckpt_freq: 1000 13 | test_freq: 1000 14 | drop_prob: 0.1 15 | 16 | feature_layers: [ 17 | "input_1", 18 | "block1_conv2", 19 | "block2_conv2", 20 | "block3_conv2", 21 | "block4_conv2", 22 | "block5_conv2"] 23 | 24 | feature_weights: [ 25 | 1.0, 26 | 1.0, 27 | 1.0, 28 | 1.0, 29 | 1.0, 30 | 1.0] 31 | 32 | gram_weights: [ 33 | 0.0, 34 | 0.05, 35 | 0.05, 36 | 0.0, 37 | 0.0, 38 | 0.0] 39 | -------------------------------------------------------------------------------- /models.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from tensorflow.contrib.framework.python.ops import arg_scope 3 | import nn 4 
| 5 | 6 | def model_arg_scope(**kwargs): 7 | """Create new counter and apply arg scope to all arg scoped nn 8 | operations.""" 9 | counters = {} 10 | return arg_scope( 11 | [nn.conv2d, nn.deconv2d, nn.residual_block, nn.dense, nn.activate], 12 | counters = counters, **kwargs) 13 | 14 | 15 | def make_model(name, template, **kwargs): 16 | """Create model with fixed kwargs.""" 17 | run = lambda *args, **kw: template(*args, **dict((k, v) for kws in (kw, kwargs) for k, v in kws.items())) 18 | return tf.make_template(name, run, unique_name_ = name) 19 | 20 | 21 | def dec_up( 22 | c, init = False, dropout_p = 0.5, 23 | n_scales = 1, n_residual_blocks = 2, activation = "elu", n_filters = 64, max_filters = 128): 24 | with model_arg_scope( 25 | init = init, dropout_p = dropout_p, activation = activation): 26 | # outputs 27 | hs = [] 28 | # prepare input 29 | h = nn.nin(c, n_filters) 30 | for l in range(n_scales): 31 | # level module 32 | for i in range(n_residual_blocks): 33 | h = nn.residual_block(h) 34 | hs.append(h) 35 | # prepare input to next level 36 | if l + 1 < n_scales: 37 | n_filters = min(2*n_filters, max_filters) 38 | h = nn.downsample(h, n_filters) 39 | return hs 40 | 41 | 42 | def dec_down( 43 | gs, zs_posterior, training, init = False, dropout_p = 0.5, 44 | n_scales = 1, n_residual_blocks = 2, activation = "elu", 45 | n_latent_scales = 2): 46 | assert n_residual_blocks % 2 == 0 47 | gs = list(gs) 48 | zs_posterior = list(zs_posterior) 49 | with model_arg_scope( 50 | init = init, dropout_p = dropout_p, activation = activation): 51 | # outputs 52 | hs = [] # hidden units 53 | ps = [] # priors 54 | zs = [] # prior samples 55 | # prepare input 56 | n_filters = gs[-1].shape.as_list()[-1] 57 | h = nn.nin(gs[-1], n_filters) 58 | for l in range(n_scales): 59 | # level module 60 | ## hidden units 61 | for i in range(n_residual_blocks // 2): 62 | h = nn.residual_block(h, gs.pop()) 63 | hs.append(h) 64 | if l < n_latent_scales: 65 | ## prior 66 | spatial_shape = h.shape.as_list()[1] 67 | n_h_channels = h.shape.as_list()[-1] 68 | if spatial_shape == 1: 69 | ### no spatial correlations 70 | p = latent_parameters(h) 71 | ps.append(p) 72 | z_prior = latent_sample(p) 73 | zs.append(z_prior) 74 | else: 75 | ### four autoregressively modeled groups 76 | if training: 77 | z_posterior_groups = nn.split_groups(zs_posterior[0]) 78 | p_groups = [] 79 | z_groups = [] 80 | p_features = tf.space_to_depth(nn.residual_block(h), 2) 81 | for i in range(4): 82 | p_group = latent_parameters(p_features, num_filters = n_h_channels) 83 | p_groups.append(p_group) 84 | z_group = latent_sample(p_group) 85 | z_groups.append(z_group) 86 | # ar feedback sampled from 87 | if training: 88 | feedback = z_posterior_groups.pop(0) 89 | else: 90 | feedback = z_group 91 | # prepare input for next group 92 | if i + 1 < 4: 93 | p_features = nn.residual_block(p_features, feedback) 94 | if training: 95 | assert not z_posterior_groups 96 | # complete prior parameters 97 | p = nn.merge_groups(p_groups) 98 | ps.append(p) 99 | # complete prior sample 100 | z_prior = nn.merge_groups(z_groups) 101 | zs.append(z_prior) 102 | ## vae feedback sampled from 103 | if training: 104 | ## posterior 105 | z = zs_posterior.pop(0) 106 | else: 107 | ## prior 108 | z = z_prior 109 | for i in range(n_residual_blocks // 2): 110 | n_h_channels = h.shape.as_list()[-1] 111 | h = tf.concat([h, z], axis = -1) 112 | h = nn.nin(h, n_h_channels) 113 | h = nn.residual_block(h, gs.pop()) 114 | hs.append(h) 115 | else: 116 | for i in range(n_residual_blocks // 2): 117 
| h = nn.residual_block(h, gs.pop()) 118 | hs.append(h) 119 | # prepare input to next level 120 | if l + 1 < n_scales: 121 | n_filters = gs[-1].shape.as_list()[-1] 122 | h = nn.upsample(h, n_filters) 123 | 124 | assert not gs 125 | if training: 126 | assert not zs_posterior 127 | 128 | return hs, ps, zs 129 | 130 | 131 | def enc_up( 132 | x, c, init = False, dropout_p = 0.5, 133 | n_scales = 1, n_residual_blocks = 2, activation = "elu", n_filters = 64, max_filters = 128): 134 | with model_arg_scope( 135 | init = init, dropout_p = dropout_p, activation = activation): 136 | # outputs 137 | hs = [] 138 | # prepare input 139 | #xc = tf.concat([x,c], axis = -1) 140 | xc = x 141 | h = nn.nin(xc, n_filters) 142 | for l in range(n_scales): 143 | # level module 144 | for i in range(n_residual_blocks): 145 | h = nn.residual_block(h) 146 | hs.append(h) 147 | # prepare input to next level 148 | if l + 1 < n_scales: 149 | n_filters = min(2*n_filters, max_filters) 150 | h = nn.downsample(h, n_filters) 151 | return hs 152 | 153 | 154 | def enc_down( 155 | gs, init = False, dropout_p = 0.5, 156 | n_scales = 1, n_residual_blocks = 2, activation = "elu", 157 | n_latent_scales = 2): 158 | assert n_residual_blocks % 2 == 0 159 | gs = list(gs) 160 | with model_arg_scope( 161 | init = init, dropout_p = dropout_p, activation = activation): 162 | # outputs 163 | hs = [] # hidden units 164 | qs = [] # posteriors 165 | zs = [] # samples from posterior 166 | # prepare input 167 | n_filters = gs[-1].shape.as_list()[-1] 168 | h = nn.nin(gs[-1], n_filters) 169 | for l in range(n_scales): 170 | # level module 171 | ## hidden units 172 | for i in range(n_residual_blocks // 2): 173 | h = nn.residual_block(h, gs.pop()) 174 | hs.append(h) 175 | if l < n_latent_scales: 176 | ## posterior parameters 177 | q = latent_parameters(h) 178 | qs.append(q) 179 | ## posterior sample 180 | z = latent_sample(q) 181 | zs.append(z) 182 | ## sample feedback 183 | for i in range(n_residual_blocks // 2): 184 | gz = tf.concat([gs.pop(), z], axis = -1) 185 | h = nn.residual_block(h, gz) 186 | hs.append(h) 187 | else: 188 | break 189 | # prepare input to next level 190 | if l + 1 < n_scales: 191 | n_filters = gs[-1].shape.as_list()[-1] 192 | h = nn.upsample(h, n_filters) 193 | 194 | return hs, qs, zs 195 | 196 | 197 | def dec_parameters( 198 | h, init = False, **kwargs): 199 | with model_arg_scope(init = init): 200 | num_filters = 3 201 | return nn.conv2d(h, num_filters) 202 | 203 | 204 | def latent_parameters( 205 | h, init = False, **kwargs): 206 | num_filters = kwargs.get("num_filters", h.shape.as_list()[-1]) 207 | return nn.conv2d(h, num_filters) 208 | 209 | 210 | def latent_sample(p): 211 | mean = p 212 | stddev = 1.0 213 | eps = tf.random_normal(mean.shape, mean = 0.0, stddev = 1.0) 214 | return mean + stddev * eps 215 | 216 | 217 | def latent_kl(q, p): 218 | mean1 = q 219 | mean2 = p 220 | 221 | kl = 0.5 * tf.square(mean2 - mean1) 222 | kl = tf.reduce_sum(kl, axis = [1,2,3]) 223 | kl = tf.reduce_mean(kl) 224 | return kl 225 | -------------------------------------------------------------------------------- /nn.py: -------------------------------------------------------------------------------- 1 | """ 2 | modified from pixelcnn++ 3 | Various tensorflow utilities 4 | """ 5 | 6 | import numpy as np 7 | import tensorflow as tf 8 | from tensorflow.contrib.framework.python.ops import add_arg_scope 9 | 10 | 11 | def int_shape(x): 12 | return x.shape.as_list() 13 | 14 | 15 | def get_name(layer_name, counters): 16 | ''' utlity for keeping track 
of layer names ''' 17 | if not layer_name in counters: 18 | counters[layer_name] = 0 19 | name = layer_name + '_' + str(counters[layer_name]) 20 | counters[layer_name] += 1 21 | return name 22 | 23 | 24 | @add_arg_scope 25 | def dense(x, num_units, init_scale=1., counters={}, init=False, **kwargs): 26 | ''' fully connected layer ''' 27 | name = get_name('dense', counters) 28 | with tf.variable_scope(name): 29 | xs = x.shape.as_list() 30 | V = tf.get_variable('V', [xs[1], num_units], tf.float32, tf.random_normal_initializer(0, 0.05)) 31 | g = tf.get_variable('g', [num_units], dtype=tf.float32, initializer=tf.constant_initializer(1.)) 32 | b = tf.get_variable('b', [num_units], dtype=tf.float32, initializer=tf.constant_initializer(0.)) 33 | 34 | V_norm = tf.nn.l2_normalize(V, [0]) 35 | x = tf.matmul(x, V_norm) 36 | if init: 37 | mean, var = tf.nn.moments(x, [0]) 38 | g = tf.assign(g, init_scale / tf.sqrt(var + 1e-10)) 39 | b = tf.assign(b, -mean * g) 40 | x = tf.reshape(g, [1, num_units])*x + tf.reshape(b, [1, num_units]) 41 | 42 | return x 43 | 44 | 45 | @add_arg_scope 46 | def conv2d(x, num_filters, filter_size=[3, 3], stride=[1, 1], pad='SAME', init_scale=1., counters={}, init=False, **kwargs): 47 | ''' convolutional layer ''' 48 | num_filters = int(num_filters) 49 | strides = [1] + stride + [1] 50 | name = get_name('conv2d', counters) 51 | with tf.variable_scope(name): 52 | xs = x.shape.as_list() 53 | V = tf.get_variable('V', filter_size + [xs[-1], num_filters], 54 | tf.float32, tf.random_normal_initializer(0, 0.05)) 55 | g = tf.get_variable('g', [num_filters], dtype=tf.float32, initializer=tf.constant_initializer(1.)) 56 | b = tf.get_variable('b', [num_filters], dtype=tf.float32, initializer=tf.constant_initializer(0.)) 57 | 58 | V_norm = tf.nn.l2_normalize(V, [0,1,2]) 59 | x = tf.nn.conv2d(x, V_norm, [1] + stride + [1], pad) 60 | if init: 61 | mean, var = tf.nn.moments(x, [0,1,2]) 62 | g = tf.assign(g, init_scale / tf.sqrt(var + 1e-10)) 63 | b = tf.assign(b, -mean * g) 64 | x = tf.reshape(g, [1, 1, 1, num_filters])*x + tf.reshape(b, [1, 1, 1, num_filters]) 65 | 66 | return x 67 | 68 | 69 | @add_arg_scope 70 | def deconv2d(x, num_filters, filter_size=[3, 3], stride=[1, 1], pad='SAME', init_scale=1., counters={}, init=False, **kwargs): 71 | ''' transposed convolutional layer ''' 72 | num_filters = int(num_filters) 73 | name = get_name('deconv2d', counters) 74 | xs = int_shape(x) 75 | strides = [1] + stride + [1] 76 | if pad == 'SAME': 77 | target_shape = [xs[0], xs[1] * stride[0], 78 | xs[2] * stride[1], num_filters] 79 | else: 80 | target_shape = [xs[0], xs[1] * stride[0] + filter_size[0] - 81 | 1, xs[2] * stride[1] + filter_size[1] - 1, num_filters] 82 | with tf.variable_scope(name): 83 | V = tf.get_variable('V', 84 | filter_size + [num_filters, xs[-1]], 85 | tf.float32, 86 | tf.random_normal_initializer(0, 0.05)) 87 | g = tf.get_variable('g', [num_filters], dtype=tf.float32, initializer=tf.constant_initializer(1.)) 88 | b = tf.get_variable('b', [num_filters], dtype=tf.float32, initializer=tf.constant_initializer(0.)) 89 | 90 | V_norm = tf.nn.l2_normalize(V, [0,1,3]) 91 | x = tf.nn.conv2d_transpose(x, V_norm, target_shape, [1] + stride + [1], pad) 92 | if init: 93 | mean, var = tf.nn.moments(x, [0,1,2]) 94 | g = tf.assign(g, init_scale / tf.sqrt(var + 1e-10)) 95 | b = tf.assign(b, -mean * g) 96 | x = tf.reshape(g, [1, 1, 1, num_filters])*x + tf.reshape(b, [1, 1, 1, num_filters]) 97 | 98 | return x 99 | 100 | 101 | @add_arg_scope 102 | def activate(x, activation, **kwargs): 103 | if 
activation == None: 104 | return x 105 | elif activation == "elu": 106 | return tf.nn.elu(x) 107 | else: 108 | raise NotImplemented(activation) 109 | 110 | 111 | def nin(x, num_units): 112 | """ a network in network layer (1x1 CONV) """ 113 | s = int_shape(x) 114 | x = tf.reshape(x, [np.prod(s[:-1]), s[-1]]) 115 | x = dense(x, num_units) 116 | return tf.reshape(x, s[:-1] + [num_units]) 117 | 118 | 119 | def downsample(x, num_units): 120 | return conv2d(x, num_units, stride = [2, 2]) 121 | 122 | 123 | def upsample(x, num_units, method = "subpixel"): 124 | if method == "conv_transposed": 125 | return deconv2d(x, num_units, stride = [2, 2]) 126 | elif method == "subpixel": 127 | x = conv2d(x, 4*num_units) 128 | x = tf.depth_to_space(x, 2) 129 | return x 130 | 131 | 132 | @add_arg_scope 133 | def residual_block(x, a = None, conv=conv2d, init=False, dropout_p=0.0, gated = False, **kwargs): 134 | """Slight variation of original.""" 135 | xs = int_shape(x) 136 | num_filters = xs[-1] 137 | 138 | residual = x 139 | if a is not None: 140 | a = nin(activate(a), num_filters) 141 | residual = tf.concat([residual, a], axis = -1) 142 | residual = activate(residual) 143 | residual = tf.nn.dropout(residual, keep_prob = 1.0 - dropout_p) 144 | residual = conv(residual, num_filters) 145 | if gated: 146 | residual = activate(residual) 147 | residual = tf.nn.dropout(residual, keep_prob = 1.0 - dropout_p) 148 | residual = conv(residual, 2*num_filters) 149 | a, b = tf.split(residual, 2, 3) 150 | residual = a * tf.nn.sigmoid(b) 151 | 152 | return x + residual 153 | 154 | 155 | def make_linear_var( 156 | step, 157 | start, end, 158 | start_value, end_value, 159 | clip_min = 0.0, clip_max = 1.0): 160 | """linear from (a, alpha) to (b, beta), i.e. 161 | (beta - alpha)/(b - a) * (x - a) + alpha""" 162 | linear = ( 163 | (end_value - start_value) / 164 | (end - start) * 165 | (tf.cast(step, tf.float32) - start) + start_value) 166 | return tf.clip_by_value(linear, clip_min, clip_max) 167 | 168 | 169 | def split_groups(x, bs = 2): 170 | return tf.split(tf.space_to_depth(x, bs), bs**2, axis = 3) 171 | 172 | 173 | def merge_groups(xs, bs = 2): 174 | return tf.depth_to_space(tf.concat(xs, axis = 3), bs) 175 | -------------------------------------------------------------------------------- /nn_compat.py: -------------------------------------------------------------------------------- 1 | """ 2 | modified from pixelcnn++ 3 | Various tensorflow utilities 4 | """ 5 | 6 | import numpy as np 7 | import tensorflow as tf 8 | from tensorflow.contrib.framework.python.ops import add_arg_scope 9 | 10 | 11 | def int_shape(x): 12 | return x.shape.as_list() 13 | 14 | 15 | def get_name(layer_name, counters): 16 | ''' utlity for keeping track of layer names ''' 17 | if not layer_name in counters: 18 | counters[layer_name] = 0 19 | name = layer_name + '_' + str(counters[layer_name]) 20 | counters[layer_name] += 1 21 | return name 22 | 23 | 24 | @add_arg_scope 25 | def dense(x, num_units, init_scale=1., counters={}, init=False, **kwargs): 26 | ''' fully connected layer ''' 27 | name = get_name('dense', counters) 28 | with tf.variable_scope(name): 29 | if init: 30 | xs = x.shape.as_list() 31 | # data based initialization of parameters 32 | V = tf.get_variable('V', [xs[1], num_units], tf.float32, tf.random_normal_initializer(0, 0.05)) 33 | V_norm = tf.nn.l2_normalize(V.initialized_value(), [0]) 34 | x_init = tf.matmul(x, V_norm) 35 | m_init, v_init = tf.nn.moments(x_init, [0]) 36 | scale_init = init_scale / tf.sqrt(v_init + 1e-10) 37 | g = 
tf.get_variable('g', dtype=tf.float32, initializer=scale_init) 38 | b = tf.get_variable('b', dtype=tf.float32, initializer=-m_init * scale_init) 39 | x_init = tf.reshape(scale_init, [1, num_units]) * (x_init - tf.reshape(m_init, [1, num_units])) 40 | 41 | return x_init 42 | else: 43 | V = tf.get_variable("V") 44 | g = tf.get_variable("g") 45 | b = tf.get_variable("b") 46 | with tf.control_dependencies([tf.assert_variables_initialized([V, g, b])]): 47 | # use weight normalization (Salimans & Kingma, 2016) 48 | x = tf.matmul(x, V) 49 | scaler = g / tf.sqrt(tf.reduce_sum(tf.square(V), [0])) 50 | x = tf.reshape(scaler, [1, num_units]) * x + tf.reshape(b, [1, num_units]) 51 | 52 | return x 53 | 54 | 55 | @add_arg_scope 56 | def conv2d(x, num_filters, filter_size=[3, 3], stride=[1, 1], pad='SAME', init_scale=1., counters={}, init=False, **kwargs): 57 | ''' convolutional layer ''' 58 | num_filters = int(num_filters) 59 | strides = [1] + stride + [1] 60 | name = get_name('conv2d', counters) 61 | with tf.variable_scope(name): 62 | if init: 63 | xs = x.shape.as_list() 64 | # data based initialization of parameters 65 | V = tf.get_variable('V', filter_size + [xs[-1], num_filters], 66 | tf.float32, tf.random_normal_initializer(0, 0.05)) 67 | V_norm = tf.nn.l2_normalize(V.initialized_value(), [0, 1, 2]) 68 | x_init = tf.nn.conv2d(x, V_norm, strides, pad) 69 | m_init, v_init = tf.nn.moments(x_init, [0, 1, 2]) 70 | scale_init = init_scale / tf.sqrt(v_init + 1e-8) 71 | g = tf.get_variable('g', dtype=tf.float32, initializer = scale_init) 72 | b = tf.get_variable('b', dtype=tf.float32, initializer = -m_init * scale_init) 73 | x_init = tf.reshape(scale_init, [1, 1, 1, num_filters]) * (x_init - tf.reshape(m_init, [1, 1, 1, num_filters])) 74 | 75 | return x_init 76 | else: 77 | V = tf.get_variable("V") 78 | g = tf.get_variable("g") 79 | b = tf.get_variable("b") 80 | with tf.control_dependencies([tf.assert_variables_initialized([V, g, b])]): 81 | # use weight normalization (Salimans & Kingma, 2016) 82 | W = tf.reshape(g, [1, 1, 1, num_filters]) * tf.nn.l2_normalize(V, [0, 1, 2]) 83 | 84 | # calculate convolutional layer output 85 | x = tf.nn.bias_add(tf.nn.conv2d(x, W, strides, pad), b) 86 | 87 | return x 88 | 89 | 90 | @add_arg_scope 91 | def deconv2d(x, num_filters, filter_size=[3, 3], stride=[1, 1], pad='SAME', init_scale=1., counters={}, init=False, **kwargs): 92 | ''' transposed convolutional layer ''' 93 | num_filters = int(num_filters) 94 | name = get_name('deconv2d', counters) 95 | xs = int_shape(x) 96 | strides = [1] + stride + [1] 97 | if pad == 'SAME': 98 | target_shape = [xs[0], xs[1] * stride[0], 99 | xs[2] * stride[1], num_filters] 100 | else: 101 | target_shape = [xs[0], xs[1] * stride[0] + filter_size[0] - 102 | 1, xs[2] * stride[1] + filter_size[1] - 1, num_filters] 103 | with tf.variable_scope(name): 104 | if init: 105 | # data based initialization of parameters 106 | V = tf.get_variable('V', filter_size + [num_filters, xs[-1]], tf.float32, tf.random_normal_initializer(0, 0.05)) 107 | V_norm = tf.nn.l2_normalize(V.initialized_value(), [0, 1, 3]) 108 | x_init = tf.nn.conv2d_transpose(x, V_norm, target_shape, strides, padding=pad) 109 | m_init, v_init = tf.nn.moments(x_init, [0, 1, 2]) 110 | scale_init = init_scale / tf.sqrt(v_init + 1e-8) 111 | g = tf.get_variable('g', dtype=tf.float32, initializer=scale_init) 112 | b = tf.get_variable('b', dtype=tf.float32, initializer=-m_init * scale_init) 113 | x_init = tf.reshape(scale_init, [1, 1, 1, num_filters]) * (x_init - tf.reshape(m_init, [1, 1, 1, 
num_filters])) 114 | 115 | return x_init 116 | else: 117 | V = tf.get_variable("V") 118 | g = tf.get_variable("g") 119 | b = tf.get_variable("b") 120 | with tf.control_dependencies([tf.assert_variables_initialized([V, g, b])]): 121 | # use weight normalization (Salimans & Kingma, 2016) 122 | W = tf.reshape(g, [1, 1, num_filters, 1]) * tf.nn.l2_normalize(V, [0, 1, 3]) 123 | 124 | # calculate convolutional layer output 125 | x = tf.nn.conv2d_transpose(x, W, target_shape, strides, padding=pad) 126 | x = tf.nn.bias_add(x, b) 127 | 128 | return x 129 | 130 | 131 | @add_arg_scope 132 | def activate(x, activation, **kwargs): 133 | if activation == None: 134 | return x 135 | elif activation == "elu": 136 | return tf.nn.elu(x) 137 | else: 138 | raise NotImplemented(activation) 139 | 140 | 141 | def nin(x, num_units): 142 | """ a network in network layer (1x1 CONV) """ 143 | s = int_shape(x) 144 | x = tf.reshape(x, [np.prod(s[:-1]), s[-1]]) 145 | x = dense(x, num_units) 146 | return tf.reshape(x, s[:-1] + [num_units]) 147 | 148 | 149 | def downsample(x, num_units): 150 | return conv2d(x, num_units, stride = [2, 2]) 151 | 152 | 153 | def upsample(x, num_units, method = "subpixel"): 154 | if method == "conv_transposed": 155 | return deconv2d(x, num_units, stride = [2, 2]) 156 | elif method == "subpixel": 157 | x = conv2d(x, 4*num_units) 158 | x = tf.depth_to_space(x, 2) 159 | return x 160 | 161 | 162 | @add_arg_scope 163 | def residual_block(x, a = None, conv=conv2d, init=False, dropout_p=0.0, gated = False, **kwargs): 164 | """Slight variation of original.""" 165 | xs = int_shape(x) 166 | num_filters = xs[-1] 167 | 168 | residual = x 169 | if a is not None: 170 | a = nin(activate(a), num_filters) 171 | residual = tf.concat([residual, a], axis = -1) 172 | residual = activate(residual) 173 | residual = tf.nn.dropout(residual, keep_prob = 1.0 - dropout_p) 174 | residual = conv(residual, num_filters) 175 | if gated: 176 | residual = activate(residual) 177 | residual = tf.nn.dropout(residual, keep_prob = 1.0 - dropout_p) 178 | residual = conv(residual, 2*num_filters) 179 | a, b = tf.split(residual, 2, 3) 180 | residual = a * tf.nn.sigmoid(b) 181 | 182 | return x + residual 183 | 184 | 185 | def make_linear_var( 186 | step, 187 | start, end, 188 | start_value, end_value, 189 | clip_min = 0.0, clip_max = 1.0): 190 | """linear from (a, alpha) to (b, beta), i.e. 191 | (beta - alpha)/(b - a) * (x - a) + alpha""" 192 | linear = ( 193 | (end_value - start_value) / 194 | (end - start) * 195 | (tf.cast(step, tf.float32) - start) + start_value) 196 | return tf.clip_by_value(linear, clip_min, clip_max) 197 | 198 | 199 | def split_groups(x, bs = 2): 200 | return tf.split(tf.space_to_depth(x, bs), bs**2, axis = 3) 201 | 202 | 203 | def merge_groups(xs, bs = 2): 204 | return tf.depth_to_space(tf.concat(xs, axis = 3), bs) 205 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | tensorflow-gpu==1.10.1 2 | numpy==1.14.5 3 | opencv-python==3.4.3.18 4 | Pillow==5.2.0 5 | tqdm==4.26.0 6 | PyYAML==3.13 7 | h5py==2.8.0 8 | --------------------------------------------------------------------------------
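
## Note on the latent terms in `models.py`

Both the posterior and the prior latents in `models.py` are parameterized only by their means with a fixed unit variance (`latent_sample` adds standard normal noise to the mean, and `latent_kl` computes `0.5 * (mean2 - mean1)^2` summed over the spatial and channel axes). In that special case the KL divergence KL(N(mu_q, I) || N(mu_p, I)) collapses to one half of the squared distance between the two means. The following is a small, self-contained numpy sketch (not part of the repository) that mirrors what those two functions compute, for readers who want to check the numbers outside of a TensorFlow session:

    import numpy as np

    def latent_sample_np(mean):
        # reparameterized sample from N(mean, I), mirroring latent_sample in models.py
        eps = np.random.standard_normal(mean.shape)
        return mean + eps

    def latent_kl_np(q_mean, p_mean):
        # KL( N(q_mean, I) || N(p_mean, I) ) summed per example over height, width
        # and channels, then averaged over the batch, mirroring latent_kl in models.py
        kl = 0.5 * np.square(p_mean - q_mean)
        kl = kl.sum(axis=(1, 2, 3))
        return kl.mean()

    # toy check on a (batch, height, width, channels) latent map
    q = np.zeros((2, 4, 4, 8))
    p = np.ones((2, 4, 4, 8))
    print(latent_kl_np(q, p))  # 0.5 * 1 * 4*4*8 = 64.0

Function names ending in `_np` are hypothetical and only used here for illustration; the repository itself computes these quantities with `tf.random_normal`, `tf.reduce_sum` and `tf.reduce_mean` as shown in `models.py`.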