├── .gitignore
├── .DS_Store
├── util
│   ├── .DS_Store
│   ├── test_ff_mask.py
│   ├── generate_coco_flist.py
│   ├── generate_voc_flists.py
│   ├── test_voc_mask.py
│   └── generate_filelist.py
├── config
│   ├── .DS_Store
│   ├── inpaint.yml
│   ├── inpaint_voc_dog.yml
│   ├── inpaint_voc_bird.yml
│   ├── inpaint_voc.yml
│   ├── inpaint_voc_sheep.yml
│   ├── inapint_voc_person.yml
│   └── inpaint_all_person.yml
├── examples
│   ├── output.png
│   ├── center_mask_256.png
│   ├── places2
│   │   ├── canyon_mask.png
│   │   ├── grass_input.png
│   │   ├── grass_mask.png
│   │   ├── sunset_mask.png
│   │   ├── wooden_mask.png
│   │   ├── building_mask.png
│   │   ├── canyon_input.png
│   │   ├── sunset_input.png
│   │   ├── wooden_input.png
│   │   └── building_input.png
│   ├── style_transfer
│   │   ├── bike.jpg
│   │   ├── bike_style_out.png
│   │   └── bnw_butterfly.png
│   ├── celeba
│   │   ├── celebahr_patches_163050_input.png
│   │   ├── celebahr_patches_164036_input.png
│   │   ├── celebahr_patches_164787_input.png
│   │   ├── celebahr_patches_165118_input.png
│   │   └── celebahr_patches_165230_input.png
│   └── imagenet
│       ├── imagenet_patches_ILSVRC2012_val_00000827_input.png
│       ├── imagenet_patches_ILSVRC2012_val_00008210_input.png
│       ├── imagenet_patches_ILSVRC2012_val_00022355_input.png
│       ├── imagenet_patches_ILSVRC2012_val_00025892_input.png
│       └── imagenet_patches_ILSVRC2012_val_00045643_input.png
├── .ftpconfig
├── README.md
├── sn.py
├── test.py
├── train.py
├── data_from_fnames.py
├── trainer.py
├── mask_from_fnames.py
├── inpaint_model.py
├── inpaint_model_gc.py
├── LICENSE
└── inpaint_ops.py

/.gitignore:
--------------------------------------------------------------------------------
model_logs
neuralgym_logs
data
--------------------------------------------------------------------------------
/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/avalonstrel/GatedConvolution/HEAD/.DS_Store
--------------------------------------------------------------------------------
/util/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/avalonstrel/GatedConvolution/HEAD/util/.DS_Store
--------------------------------------------------------------------------------
/config/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/avalonstrel/GatedConvolution/HEAD/config/.DS_Store
--------------------------------------------------------------------------------
/examples/output.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/avalonstrel/GatedConvolution/HEAD/examples/output.png
--------------------------------------------------------------------------------
/examples/center_mask_256.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/avalonstrel/GatedConvolution/HEAD/examples/center_mask_256.png
--------------------------------------------------------------------------------
/examples/places2/canyon_mask.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/avalonstrel/GatedConvolution/HEAD/examples/places2/canyon_mask.png
--------------------------------------------------------------------------------
/examples/places2/grass_input.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/avalonstrel/GatedConvolution/HEAD/examples/places2/grass_input.png
-------------------------------------------------------------------------------- /examples/places2/grass_mask.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/avalonstrel/GatedConvolution/HEAD/examples/places2/grass_mask.png -------------------------------------------------------------------------------- /examples/places2/sunset_mask.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/avalonstrel/GatedConvolution/HEAD/examples/places2/sunset_mask.png -------------------------------------------------------------------------------- /examples/places2/wooden_mask.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/avalonstrel/GatedConvolution/HEAD/examples/places2/wooden_mask.png -------------------------------------------------------------------------------- /examples/style_transfer/bike.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/avalonstrel/GatedConvolution/HEAD/examples/style_transfer/bike.jpg -------------------------------------------------------------------------------- /examples/places2/building_mask.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/avalonstrel/GatedConvolution/HEAD/examples/places2/building_mask.png -------------------------------------------------------------------------------- /examples/places2/canyon_input.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/avalonstrel/GatedConvolution/HEAD/examples/places2/canyon_input.png -------------------------------------------------------------------------------- /examples/places2/sunset_input.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/avalonstrel/GatedConvolution/HEAD/examples/places2/sunset_input.png -------------------------------------------------------------------------------- /examples/places2/wooden_input.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/avalonstrel/GatedConvolution/HEAD/examples/places2/wooden_input.png -------------------------------------------------------------------------------- /examples/places2/building_input.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/avalonstrel/GatedConvolution/HEAD/examples/places2/building_input.png -------------------------------------------------------------------------------- /examples/style_transfer/bike_style_out.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/avalonstrel/GatedConvolution/HEAD/examples/style_transfer/bike_style_out.png -------------------------------------------------------------------------------- /examples/style_transfer/bnw_butterfly.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/avalonstrel/GatedConvolution/HEAD/examples/style_transfer/bnw_butterfly.png -------------------------------------------------------------------------------- /examples/celeba/celebahr_patches_163050_input.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/avalonstrel/GatedConvolution/HEAD/examples/celeba/celebahr_patches_163050_input.png -------------------------------------------------------------------------------- /examples/celeba/celebahr_patches_164036_input.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/avalonstrel/GatedConvolution/HEAD/examples/celeba/celebahr_patches_164036_input.png -------------------------------------------------------------------------------- /examples/celeba/celebahr_patches_164787_input.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/avalonstrel/GatedConvolution/HEAD/examples/celeba/celebahr_patches_164787_input.png -------------------------------------------------------------------------------- /examples/celeba/celebahr_patches_165118_input.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/avalonstrel/GatedConvolution/HEAD/examples/celeba/celebahr_patches_165118_input.png -------------------------------------------------------------------------------- /examples/celeba/celebahr_patches_165230_input.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/avalonstrel/GatedConvolution/HEAD/examples/celeba/celebahr_patches_165230_input.png -------------------------------------------------------------------------------- /examples/imagenet/imagenet_patches_ILSVRC2012_val_00000827_input.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/avalonstrel/GatedConvolution/HEAD/examples/imagenet/imagenet_patches_ILSVRC2012_val_00000827_input.png -------------------------------------------------------------------------------- /examples/imagenet/imagenet_patches_ILSVRC2012_val_00008210_input.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/avalonstrel/GatedConvolution/HEAD/examples/imagenet/imagenet_patches_ILSVRC2012_val_00008210_input.png -------------------------------------------------------------------------------- /examples/imagenet/imagenet_patches_ILSVRC2012_val_00022355_input.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/avalonstrel/GatedConvolution/HEAD/examples/imagenet/imagenet_patches_ILSVRC2012_val_00022355_input.png -------------------------------------------------------------------------------- /examples/imagenet/imagenet_patches_ILSVRC2012_val_00025892_input.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/avalonstrel/GatedConvolution/HEAD/examples/imagenet/imagenet_patches_ILSVRC2012_val_00025892_input.png -------------------------------------------------------------------------------- /examples/imagenet/imagenet_patches_ILSVRC2012_val_00045643_input.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/avalonstrel/GatedConvolution/HEAD/examples/imagenet/imagenet_patches_ILSVRC2012_val_00045643_input.png -------------------------------------------------------------------------------- /.ftpconfig: -------------------------------------------------------------------------------- 1 | { 2 | "protocol": "sftp", 3 | "host": "wh-a-internal.brainpp.cn", 4 | "port": 
22, 5 | "user": "Xiyu.linhangyu.brw", 6 | "pass": "", 7 | "promptForPass": false, 8 | "remote": "/unsullied/sharefs/linhangyu/Inpainting/GatedConvolution", 9 | "local": "", 10 | "agent": "", 11 | "privatekey": "/Users/linhangyu/.ssh/id_rsa", 12 | "passphrase": "ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAABAQCbfZGxkOx2Bm2SHbs25NlTv2DSg3EX6a/GBlsJQFCVH9BrvhqrIY7wDFI79u36b4wJG8lsTpcXGKGa2LV1h+Le1mRS7t1q93y+I9T0A4sT6NwNGlVztztAmF+RNQmgAlX/dqV0RNdO/GB5T0YLHhrK7DStwZo+fYzZ2T639xgA14Nu0qKHhgZfjcH/MUHzR0YNjzd6YJHs/m9J7PQnIN13S0ylOVsOz6St8ER5tVRKGnw77Zw9XOQdPNxBaHcWHeC7SVkB21OFPr6xRWhKgv1zEaIxB1K4Le513irbVon0UYCZQZ3mXwzdgL1A2+XnV8LY36s4NXlX4tohg2XdM+cj linhangyu@BJ102102-1.local", 13 | "hosthash": "", 14 | "ignorehost": false, 15 | "connTimeout": 10000, 16 | "keepalive": 10000, 17 | "keyboardInteractive": false, 18 | "keyboardInteractiveForPass": false, 19 | "remoteCommand": "", 20 | "remoteShell": "", 21 | "watch": [], 22 | "watchTimeout": 500 23 | } 24 | -------------------------------------------------------------------------------- /util/test_ff_mask.py: -------------------------------------------------------------------------------- 1 | 2 | import numpy as np 3 | import cv2 4 | from PIL import Image 5 | def random_ff_mask( name="ff_mask"): 6 | """Generate a random free form mask with configuration. 7 | 8 | Args: 9 | config: Config should have configuration including IMG_SHAPES, 10 | VERTICAL_MARGIN, HEIGHT, HORIZONTAL_MARGIN, WIDTH. 11 | 12 | Returns: 13 | tuple: (top, left, height, width) 14 | 15 | """ 16 | img_shape = (256,256,3) 17 | h,w,c = img_shape 18 | def npmask(): 19 | 20 | mask = np.zeros((h,w)) 21 | num_v = 10 + np.random.randint(6)#tf.random_uniform([], minval=0, maxval=config.MAXVERTEX, dtype=tf.int32) 22 | 23 | for i in range(num_v): 24 | start_x = np.random.randint(w) 25 | start_y = np.random.randint(h) 26 | for j in range(4): 27 | 28 | angle = 0.01+np.random.randint(4.0) 29 | if i % 2 == 0: 30 | angle = 2 * 3.1415926 - angle 31 | length = 10+np.random.randint(40) 32 | brush_w = 10+np.random.randint(3) 33 | end_x = (start_x + length * np.sin(angle)).astype(np.int32) 34 | end_y = (start_y + length * np.cos(angle)).astype(np.int32) 35 | 36 | cv2.line(mask, (start_y, start_x), (end_y, end_x), 1, brush_w) 37 | start_x, start_y = end_x, end_y 38 | return mask*255 39 | 40 | mask_img = Image.fromarray(npmask().astype(np.uint8)) 41 | mask_img.save("test.png") 42 | 43 | random_ff_mask(name="ff_mask") 44 | -------------------------------------------------------------------------------- /util/generate_coco_flist.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | #import pycocotools 4 | from pycocotools.coco import COCO 5 | import pickle as pkl 6 | # Load cls name 7 | cls_name = sys.argv[1] 8 | dataDir = '/unsullied/sharefs/linhangyu/Inpainting/Data/COCO/annotations_trainval2017' 9 | dataType = 'train2017' 10 | prefix = 'instances' 11 | annFile = '%s/annotations/%s_%s.json'%(dataDir,prefix,dataType) 12 | coco = COCO(annFile) 13 | 14 | 15 | catIds = coco.getCatIds(catNms=[cls_name]); 16 | #print(catIds, sep=' ', end='n', file=sys.stdout, flush=False) 17 | imgIds = coco.getImgIds(catIds=catIds[0] ); 18 | #imgs = coco.loadImgs(imgIds) 19 | annIds = coco.getAnnIds(imgIds=imgIds, catIds=catIds, iscrowd=None) 20 | anns = coco.loadAnns(annIds) 21 | 22 | print(len(anns)) 23 | imgId_dict = {} 24 | #imgIds = coco.getImgIds(imgIds = [324158]) 25 | with open("coco_{}_{}_flist.txt".format(cls_name, dataType), 'w') as f ,\ 26 | 
open("coco_{}_{}_bbox_flist.txt".format(cls_name, dataType), 'w') as fb: 27 | for ann in anns: 28 | imgId = ann['image_id'] 29 | if imgId in imgId_dict: 30 | continue 31 | else: 32 | imgId_dict[imgId] = 1 33 | img = coco.loadImgs(imgId)[0] 34 | print(img) 35 | file_path = '/unsullied/sharefs/linhangyu/Inpainting/Data/COCO/images/%s/%s/%s'%(dataType,dataType,img['file_name']) 36 | if not os.path.exists(file_path): 37 | print(file_path) 38 | f.write(file_path+'\n') 39 | bbox_path = '/unsullied/sharefs/linhangyu/Inpainting/Data/COCO/bboxes/%s/%s'%(dataType, img['file_name'].replace('jpg', 'pkl')) 40 | pkl.dump({"bbox":ann['bbox'], 'shape':(img['width'], img['height'])}, open(bbox_path, 'wb')) 41 | fb.write(bbox_path+'\n') 42 | -------------------------------------------------------------------------------- /util/generate_voc_flists.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | path = "/unsullied/sharefs/_research_detection/GeneralDetection/VOC/VOCdevkit/VOC2012" 5 | image_path = "/unsullied/sharefs/_research_detection/GeneralDetection/VOC/VOCdevkit/VOC2012/JPEGImages" 6 | annotation_path = "/unsullied/sharefs/_research_detection/GeneralDetection/VOC/VOCdevkit/VOC2012/Annotations" 7 | cls_name = sys.argv[1] 8 | train_file_name = os.path.join(path, "ImageSets/Main/{}_train.txt".format(cls_name)) 9 | val_file_name = os.path.join(path, "ImageSets/Main/{}_val.txt".format(cls_name)) 10 | #test_file_name = os.path.join(path, "ImageSets/Main/{}_val.txt".format(cls_name)) 11 | all_cls = ['bicycle', 'bird', 'boat', 'bottle', 'car', 'cat', 'chair', 'diningtable', 'dog', 'horse', 'person', 'pottedplant', 'motorbike', 'sheep', 'sofa', 'tvmonitor'] 12 | cls_dict = {term:i for i, term in enumerate(all_cls)} 13 | 14 | def generate_flist(file_name, cls_name, mode): 15 | with open(file_name, 'r') as f: 16 | fterms = f.read().splitlines() 17 | with open("/unsullied/sharefs/linhangyu/Inpainting/Data/VOCData/voc_all_{}_flist.txt".format(mode), 'a') as f1, \ 18 | open("/unsullied/sharefs/linhangyu/Inpainting/Data/VOCData/voc_all_{}_bbox_flist.txt".format(mode), "a") as f2: 19 | for term in fterms: 20 | terms = term.strip().split() 21 | file = terms[0] 22 | e = terms[1] 23 | print(terms, file, e) 24 | if e == '1': 25 | # f1.write(os.path.join(image_path, file+'.jpg') + '\t{}\n'.format(cls_dict[cls_name])) 26 | # f2.write(os.path.join(annotation_path, file+'.xml') + '\t{}\n'.format(cls_dict[cls_name])) 27 | f1.write(os.path.join(image_path, file+'.jpg') + '\n') 28 | f2.write(os.path.join(annotation_path, file+'.xml') + '\n') 29 | for cls_name in all_cls: 30 | train_file_name = os.path.join(path, "ImageSets/Main/{}_train.txt".format(cls_name)) 31 | val_file_name = os.path.join(path, "ImageSets/Main/{}_val.txt".format(cls_name)) 32 | generate_flist(train_file_name, cls_name, 'train') 33 | generate_flist(val_file_name, cls_name, 'val') 34 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | # An Reimplemented version of DeepFillv2. 3 | I reimplement this model in pytorch(which is more familiar to me) in https://github.com/avalonstrel/GatedConvolution_pytorch. I provide a pre-trained model on Places2 and some results. 4 | This version will not be updated. Sorry. 5 | **Update (Aug, 2018)**: 6 | The main files I modify is inpaint_ops.py, inpaint_model_gc.py, and train.py. 
## Run (From [DeepFillv1](https://github.com/JiahuiYu/generative_inpainting))
0. Requirements:
    * Install python3.
    * Install [tensorflow](https://www.tensorflow.org/install/) (tested on releases 1.3.0, 1.4.0, 1.5.0, 1.6.0, 1.7.0).
    * Install the tensorflow toolkit [neuralgym](https://github.com/JiahuiYu/neuralgym) (run `pip install git+https://github.com/JiahuiYu/neuralgym`).
1. Training:
    * Prepare a filelist of training images ([example](https://github.com/JiahuiYu/generative_inpainting/issues/15)).
    * Modify [inpaint.yml](/config/inpaint.yml) to set DATA_FLIST, LOG_DIR, IMG_SHAPES and other parameters.
    * Run `python train.py config/inpaint.yml`.
2. Resume training:
    * Modify the MODEL_RESTORE flag in [inpaint.yml](/config/inpaint.yml), e.g. MODEL_RESTORE: 20180115220926508503_places2_model.
    * Run `python train.py config/inpaint.yml`.
3. Testing:
    * Run `python test.py --image examples/input.png --mask examples/mask.png --output examples/output.png --checkpoint_dir model_logs/your_model_dir`. (I have not tested this.)

4. Still have questions?
    * If you still have questions (e.g. What does a filelist look like? How do I use multiple GPUs? How do I do batch testing?), please first search over closed issues. If the problem is not solved, please open a new issue. (Refer to [DeepFillv1](https://github.com/JiahuiYu/generative_inpainting).)

## Results (Still Testing)

## Pretrained Model (Still Testing)

## Acknowledgments
This project builds on the official DeepFillv1 code and on SNGAN. Special thanks to the authors of these amazing algorithms.
--------------------------------------------------------------------------------
/sn.py:
--------------------------------------------------------------------------------
import tensorflow as tf
import warnings


NO_OPS = 'NO_OPS'


def _l2normalize(v, eps=1e-12):
    return v / (tf.reduce_sum(v ** 2) ** 0.5 + eps)


def spectral_normed_weight(W, u=None, num_iters=1, update_collection=None, with_sigma=False, name="sn_w"):
    # Usually num_iters = 1 will be enough
    W_shape = W.shape.as_list()
    W_reshaped = tf.reshape(W, [-1, W_shape[-1]])
    if u is None:
        u = tf.get_variable(name+"_u", [1, W_shape[-1]], initializer=tf.truncated_normal_initializer(), trainable=False)
    def power_iteration(i, u_i, v_i):
        v_ip1 = _l2normalize(tf.matmul(u_i, tf.transpose(W_reshaped)))
        u_ip1 = _l2normalize(tf.matmul(v_ip1, W_reshaped))
        return i + 1, u_ip1, v_ip1
    _, u_final, v_final = tf.while_loop(
        cond=lambda i, _1, _2: i < num_iters,
        body=power_iteration,
        loop_vars=(tf.constant(0, dtype=tf.int32),
                   u, tf.zeros(dtype=tf.float32, shape=[1, W_reshaped.shape.as_list()[0]]))
    )
    if update_collection is None:
        warnings.warn('Setting update_collection to None will make u be updated on every W execution. '
                      'This may be undesirable. 
Please consider using an update collection instead.')
        # sigma is the largest singular value of W, estimated by power iteration.
        sigma = tf.matmul(tf.matmul(v_final, W_reshaped), tf.transpose(u_final))[0, 0]
        # sigma = tf.reduce_sum(tf.matmul(u_final, tf.transpose(W_reshaped)) * v_final)
        W_bar = W_reshaped / sigma
        with tf.control_dependencies([u.assign(u_final)]):
            W_bar = tf.reshape(W_bar, W_shape)
    else:
        sigma = tf.matmul(tf.matmul(v_final, W_reshaped), tf.transpose(u_final))[0, 0]
        # sigma = tf.reduce_sum(tf.matmul(u_final, tf.transpose(W_reshaped)) * v_final)
        W_bar = W_reshaped / sigma
        W_bar = tf.reshape(W_bar, W_shape)
        # Put NO_OPS to not update any collection. This is useful for the second call of the discriminator
        # if the update op has already been collected on the first call.
        if update_collection != NO_OPS:
            tf.add_to_collection(update_collection, u.assign(u_final))
    if with_sigma:
        return W_bar, sigma
    else:
        return W_bar
--------------------------------------------------------------------------------
/util/test_voc_mask.py:
--------------------------------------------------------------------------------
import random
import time
import numpy as np
import cv2
import tensorflow as tf
from bs4 import BeautifulSoup
from neuralgym.ops.image_ops import np_random_crop

def read_bbox_shapes(filename):
    #file_path = os.path.join(self.path, "Annotations", filename)
    with open(filename, 'r') as reader:
        xml = reader.read()
    soup = BeautifulSoup(xml, 'xml')
    size = {}
    for tag in soup.size:
        if tag.string != "\n":
            size[tag.name] = int(tag.string)
    objects = soup.find_all('object')
    bndboxs = []
    for obj in objects:
        bndbox = {}
        for tag in obj.bndbox:
            if tag.string != '\n':
                bndbox[tag.name] = int(tag.string)

        bbox = [bndbox['ymin'], bndbox['xmin'], bndbox['ymax']-bndbox['ymin'], bndbox['xmax']-bndbox['xmin']]
        bndboxs.append(bbox)
    #print(bndboxs, size)
    return bndboxs, (size['height'], size['width'])


def bbox2mask(bbox, height, width, delta_h, delta_w, name='mask'):
    """Generate a mask array from a bbox.

    The bbox is shrunk by a random margin (10%-30% of the full image
    height/width) on each side, so the mask covers the inner part of
    the object region.

    Args:
        bbox: bounding box tuple, (top, left, height, width).
        height, width: shape of the output mask.
        delta_h, delta_w: unused here; kept for interface compatibility.

    Returns:
        np.ndarray: output with shape [height, width, 1].

    """

    mask = np.zeros((height, width, 1), np.float32)
    h = int(0.1*height)+np.random.randint(int(height*0.2+1))
    w = int(0.1*width)+np.random.randint(int(width*0.2)+1)
    print(bbox[0]+h, bbox[0]+bbox[2]-h,
          bbox[1]+w, bbox[1]+bbox[3]-w)
    mask[bbox[0]+h:bbox[0]+bbox[2]-h,
         bbox[1]+w:bbox[1]+bbox[3]-w, :] = 1.
52 | return mask 53 | bbox, shape = read_bbox_shapes("/unsullied/sharefs/_research_detection/GeneralDetection/VOC/VOCdevkit/VOC2012/Annotations/2008_000008.xml") 54 | mask = bbox2mask(bbox[0], shape[0], shape[1], 0.3, 0.3, name='mask') 55 | mask = cv2.resize(mask, (256,256)) 56 | def show(mode, pil_numpy): 57 | print(mode, len(",".join([str(i) for i in pil_numpy.flatten() if i != 0]))) 58 | #print(show("mask", mask)) 59 | print(mask.shape) 60 | show("mask", ((mask>0).astype(np.int8)).reshape((256,256,1))) 61 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import cv2 4 | import numpy as np 5 | import tensorflow as tf 6 | import neuralgym as ng 7 | 8 | from inpaint_model import InpaintCAModel 9 | 10 | 11 | parser = argparse.ArgumentParser() 12 | parser.add_argument('--image', default='', type=str, 13 | help='The filename of image to be completed.') 14 | parser.add_argument('--mask', default='', type=str, 15 | help='The filename of mask, value 255 indicates mask.') 16 | parser.add_argument('--output', default='output.png', type=str, 17 | help='Where to write output.') 18 | parser.add_argument('--checkpoint_dir', default='', type=str, 19 | help='The directory of tensorflow checkpoint.') 20 | 21 | 22 | if __name__ == "__main__": 23 | ng.get_gpus(1) 24 | args = parser.parse_args() 25 | 26 | model = InpaintCAModel() 27 | image = cv2.imread(args.image) 28 | mask = cv2.imread(args.mask) 29 | 30 | assert image.shape == mask.shape 31 | 32 | h, w, _ = image.shape 33 | grid = 8 34 | image = image[:h//grid*grid, :w//grid*grid, :] 35 | mask = mask[:h//grid*grid, :w//grid*grid, :] 36 | print('Shape of image: {}'.format(image.shape)) 37 | 38 | image = np.expand_dims(image, 0) 39 | mask = np.expand_dims(mask, 0) 40 | input_image = np.concatenate([image, mask], axis=2) 41 | print(input_image.shape) 42 | sess_config = tf.ConfigProto() 43 | sess_config.gpu_options.allow_growth = True 44 | with tf.Session(config=sess_config) as sess: 45 | input_image = tf.constant(input_image, dtype=tf.float32) 46 | output = model.build_server_graph(input_image) 47 | output = (output + 1.) 
* 127.5 48 | output = tf.reverse(output, [-1]) 49 | output = tf.saturate_cast(output, tf.uint8) 50 | # load pretrained model 51 | vars_list = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES) 52 | assign_ops = [] 53 | for var in vars_list: 54 | vname = var.name 55 | from_name = vname 56 | var_value = tf.contrib.framework.load_variable(args.checkpoint_dir, 57 | from_name) 58 | assign_ops.append(tf.assign(var, var_value)) 59 | sess.run(assign_ops) 60 | print('Model loaded.') 61 | result = sess.run(output) 62 | cv2.imwrite(args.output, result[0][:, :, ::-1]) 63 | -------------------------------------------------------------------------------- /config/inpaint.yml: -------------------------------------------------------------------------------- 1 | # parameters 2 | DATASET: 'places2' # 'tmnist', 'dtd', 'places2', 'celeba', 'imagenet', 'cityscapes' 3 | MASKDATASET: 'irrmask' 4 | RANDOM_CROP: False 5 | MASKFROMFILE: False 6 | VAL: True 7 | LOG_DIR: full_model_places2_256 8 | MODEL_RESTORE: '' # '20180115220926508503_jyugpu0_places2_NORMAL_wgan_gp_full_model' 9 | 10 | GAN: 'sn_pgan' # 'dcgan', 'lsgan', 'wgan_gp', 'one_wgan_gp' 11 | PRETRAIN_COARSE_NETWORK: False 12 | GAN_LOSS_ALPHA: 0.001 # dcgan: 0.0008, wgan: 0.0005, onegan: 0.001 13 | WGAN_GP_LAMBDA: 10 14 | COARSE_L1_ALPHA: 1.2 15 | L1_LOSS_ALPHA: 1.2 16 | AE_LOSS_ALPHA: 1.2 17 | GAN_WITH_MASK: True 18 | GAN_WITH_GUIDE: False 19 | DISCOUNTED_MASK: True 20 | RANDOM_SEED: False 21 | PADDING: 'SAME' 22 | 23 | # training 24 | NUM_GPUS: 1 25 | GPU_ID: -1 # -1 indicate select any available one, otherwise select gpu ID, e.g. [0,1,3] 26 | TRAIN_SPE: 10000 27 | MAX_ITERS: 1000000 28 | VIZ_MAX_OUT: 10 29 | GRADS_SUMMARY: False 30 | GRADIENT_CLIP: False 31 | GRADIENT_CLIP_VALUE: 0.1 32 | VAL_PSTEPS: 1000 33 | 34 | # data 35 | DATA_FLIST: 36 | # https://github.com/JiahuiYu/progressive_growing_of_gans_tf 37 | celebahq: [ 38 | 'data/celeba_hq/train_shuffled.flist', 39 | 'data/celeba_hq/validation_static_view.flist' 40 | ] 41 | # http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html, please to use RANDOM_CROP: True 42 | celeba: [ 43 | 'data/celeba/train_shuffled.flist', 44 | 'data/celeba/validation_static_view.flist' 45 | ] 46 | # http://places2.csail.mit.edu/, please download the high-resolution dataset and use RANDOM_CROP: True 47 | places2: [ 48 | '../Data/PlacesData/train_list.txt', 49 | '../Data/PlacesData/val_list.txt' 50 | ] 51 | # http://www.image-net.org/, please use RANDOM_CROP: True 52 | imagenet: [ 53 | 'data/imagenet/train_shuffled.flist', 54 | 'data/imagenet/validation_static_view.flist', 55 | ] 56 | irrmask: [ 57 | '../Data/MaskData/irrmask_flist.txt', 58 | '../Data/MaskData/irrmask_flist.txt' 59 | ] 60 | horse: [ 61 | '/unsullied/sharefs/linhangyu/Inpainting/Data/VOCData/voc_horse_train_flist.txt', 62 | '/unsullied/sharefs/linhangyu/Inpainting/Data/VOCData/voc_horse_val_flist.txt' 63 | ] 64 | horse_mask: [ 65 | '/unsullied/sharefs/linhangyu/Inpainting/Data/VOCData/voc_horse_bbox_train_flist.txt', 66 | '/unsullied/sharefs/linhangyu/Inpainting/Data/VOCData/voc_horse_bbox_val_flist.txt' 67 | ] 68 | 69 | STATIC_VIEW_SIZE: 30 70 | IMG_SHAPES: [256, 256, 3] 71 | HEIGHT: 128 72 | WIDTH: 128 73 | MAX_DELTA_HEIGHT: 32 74 | MAX_DELTA_WIDTH: 32 75 | BATCH_SIZE: 16 76 | VERTICAL_MARGIN: 0 77 | HORIZONTAL_MARGIN: 0 78 | MAXVERTEX: 5 79 | MAXANGLE: 4.0 #pi 80 | MAXLENGTH: 40 81 | MAXBRUSHWIDTH: 10 82 | # loss 83 | AE_LOSS: True 84 | L1_LOSS: True 85 | GLOBAL_DCGAN_LOSS_ALPHA: 1. 86 | GLOBAL_WGAN_LOSS_ALPHA: 1. 
87 | 88 | # loss legacy 89 | LOAD_VGG_MODEL: False 90 | VGG_MODEL_FILE: data/model_zoo/vgg16.npz 91 | FEATURE_LOSS: False 92 | GRAMS_LOSS: False 93 | TV_LOSS: False 94 | TV_LOSS_ALPHA: 0. 95 | FEATURE_LOSS_ALPHA: 0.01 96 | GRAMS_LOSS_ALPHA: 50 97 | SPATIAL_DISCOUNTING_GAMMA: 0.9 98 | -------------------------------------------------------------------------------- /config/inpaint_voc_dog.yml: -------------------------------------------------------------------------------- 1 | # parameters 2 | DATASET: 'dog' # 'tmnist', 'dtd', 'places2', 'celeba', 'imagenet', 'cityscapes' 3 | MASKDATASET: 'dog_mask' 4 | RANDOM_CROP: False 5 | MASKFROMFILE: True 6 | VAL: False 7 | LOG_DIR: full_model_voc_dog_256 8 | MODEL_RESTORE: '' # '20180115220926508503_jyugpu0_places2_NORMAL_wgan_gp_full_model' 9 | 10 | GAN: 'sn_pgan' # 'dcgan', 'lsgan', 'wgan_gp', 'one_wgan_gp' 11 | PRETRAIN_COARSE_NETWORK: False 12 | GAN_LOSS_ALPHA: 0.001 # dcgan: 0.0008, wgan: 0.0005, onegan: 0.001 13 | WGAN_GP_LAMBDA: 10 14 | COARSE_L1_ALPHA: 1.2 15 | L1_LOSS_ALPHA: 1.2 16 | AE_LOSS_ALPHA: 1.2 17 | GAN_WITH_MASK: True 18 | GAN_WITH_GUIDE: False 19 | DISCOUNTED_MASK: True 20 | RANDOM_SEED: False 21 | PADDING: 'SAME' 22 | 23 | # training 24 | NUM_GPUS: 1 25 | GPU_ID: -1 # -1 indicate select any available one, otherwise select gpu ID, e.g. [0,1,3] 26 | TRAIN_SPE: 10000 27 | MAX_ITERS: 1000000 28 | VIZ_MAX_OUT: 10 29 | GRADS_SUMMARY: False 30 | GRADIENT_CLIP: False 31 | GRADIENT_CLIP_VALUE: 0.1 32 | VAL_PSTEPS: 1000 33 | 34 | # data 35 | DATA_FLIST: 36 | # https://github.com/JiahuiYu/progressive_growing_of_gans_tf 37 | celebahq: [ 38 | 'data/celeba_hq/train_shuffled.flist', 39 | 'data/celeba_hq/validation_static_view.flist' 40 | ] 41 | # http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html, please to use RANDOM_CROP: True 42 | celeba: [ 43 | 'data/celeba/train_shuffled.flist', 44 | 'data/celeba/validation_static_view.flist' 45 | ] 46 | # http://places2.csail.mit.edu/, please download the high-resolution dataset and use RANDOM_CROP: True 47 | places2: [ 48 | '../Data/PlacesData/train_list.txt', 49 | '../Data/PlacesData/val_list.txt' 50 | ] 51 | # http://www.image-net.org/, please use RANDOM_CROP: True 52 | imagenet: [ 53 | 'data/imagenet/train_shuffled.flist', 54 | 'data/imagenet/validation_static_view.flist', 55 | ] 56 | irrmask: [ 57 | '../Data/MaskData/irrmask_flist.txt', 58 | '../Data/MaskData/irrmask_flist.txt' 59 | ] 60 | dog: [ 61 | '/unsullied/sharefs/linhangyu/Inpainting/Data/VOCData/voc_dog_train_flist.txt', 62 | '/unsullied/sharefs/linhangyu/Inpainting/Data/VOCData/voc_dog_val_flist.txt' 63 | ] 64 | dog_mask: [ 65 | '/unsullied/sharefs/linhangyu/Inpainting/Data/VOCData/voc_dog_train_bbox_flist.txt', 66 | '/unsullied/sharefs/linhangyu/Inpainting/Data/VOCData/voc_dog_val_bbox_flist.txt' 67 | ] 68 | 69 | STATIC_VIEW_SIZE: 30 70 | IMG_SHAPES: [256, 256, 3] 71 | MASK_SHAPES: [256, 256, 1] 72 | HEIGHT: 128 73 | WIDTH: 128 74 | MAX_DELTA_HEIGHT: 32 75 | MAX_DELTA_WIDTH: 32 76 | BATCH_SIZE: 16 77 | VERTICAL_MARGIN: 0 78 | HORIZONTAL_MARGIN: 0 79 | MAXVERTEX: 5 80 | MAXANGLE: 4.0 #pi 81 | MAXLENGTH: 40 82 | MAXBRUSHWIDTH: 10 83 | # loss 84 | AE_LOSS: True 85 | L1_LOSS: True 86 | GLOBAL_DCGAN_LOSS_ALPHA: 1. 87 | GLOBAL_WGAN_LOSS_ALPHA: 1. 88 | 89 | # loss legacy 90 | LOAD_VGG_MODEL: False 91 | VGG_MODEL_FILE: data/model_zoo/vgg16.npz 92 | FEATURE_LOSS: False 93 | GRAMS_LOSS: False 94 | TV_LOSS: False 95 | TV_LOSS_ALPHA: 0. 
96 | FEATURE_LOSS_ALPHA: 0.01 97 | GRAMS_LOSS_ALPHA: 50 98 | SPATIAL_DISCOUNTING_GAMMA: 0.9 99 | -------------------------------------------------------------------------------- /config/inpaint_voc_bird.yml: -------------------------------------------------------------------------------- 1 | # parameters 2 | DATASET: 'bird' # 'tmnist', 'dtd', 'places2', 'celeba', 'imagenet', 'cityscapes' 3 | MASKDATASET: 'bird_mask' 4 | RANDOM_CROP: False 5 | MASKFROMFILE: True 6 | VAL: False 7 | LOG_DIR: full_model_voc_bird_256 8 | MODEL_RESTORE: '' # '20180115220926508503_jyugpu0_places2_NORMAL_wgan_gp_full_model' 9 | 10 | GAN: 'sn_pgan' # 'dcgan', 'lsgan', 'wgan_gp', 'one_wgan_gp' 11 | PRETRAIN_COARSE_NETWORK: False 12 | GAN_LOSS_ALPHA: 0.001 # dcgan: 0.0008, wgan: 0.0005, onegan: 0.001 13 | WGAN_GP_LAMBDA: 10 14 | COARSE_L1_ALPHA: 1.2 15 | L1_LOSS_ALPHA: 1.2 16 | AE_LOSS_ALPHA: 1.2 17 | GAN_WITH_MASK: True 18 | GAN_WITH_GUIDE: False 19 | DISCOUNTED_MASK: True 20 | RANDOM_SEED: False 21 | PADDING: 'SAME' 22 | 23 | # training 24 | NUM_GPUS: 1 25 | GPU_ID: -1 # -1 indicate select any available one, otherwise select gpu ID, e.g. [0,1,3] 26 | TRAIN_SPE: 10000 27 | MAX_ITERS: 1000000 28 | VIZ_MAX_OUT: 10 29 | GRADS_SUMMARY: False 30 | GRADIENT_CLIP: False 31 | GRADIENT_CLIP_VALUE: 0.1 32 | VAL_PSTEPS: 1000 33 | 34 | # data 35 | DATA_FLIST: 36 | # https://github.com/JiahuiYu/progressive_growing_of_gans_tf 37 | celebahq: [ 38 | 'data/celeba_hq/train_shuffled.flist', 39 | 'data/celeba_hq/validation_static_view.flist' 40 | ] 41 | # http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html, please to use RANDOM_CROP: True 42 | celeba: [ 43 | 'data/celeba/train_shuffled.flist', 44 | 'data/celeba/validation_static_view.flist' 45 | ] 46 | # http://places2.csail.mit.edu/, please download the high-resolution dataset and use RANDOM_CROP: True 47 | places2: [ 48 | '../Data/PlacesData/train_list.txt', 49 | '../Data/PlacesData/val_list.txt' 50 | ] 51 | # http://www.image-net.org/, please use RANDOM_CROP: True 52 | imagenet: [ 53 | 'data/imagenet/train_shuffled.flist', 54 | 'data/imagenet/validation_static_view.flist', 55 | ] 56 | irrmask: [ 57 | '../Data/MaskData/irrmask_flist.txt', 58 | '../Data/MaskData/irrmask_flist.txt' 59 | ] 60 | bird: [ 61 | '/unsullied/sharefs/linhangyu/Inpainting/Data/VOCData/voc_bird_train_flist.txt', 62 | '/unsullied/sharefs/linhangyu/Inpainting/Data/VOCData/voc_bird_val_flist.txt' 63 | ] 64 | bird_mask: [ 65 | '/unsullied/sharefs/linhangyu/Inpainting/Data/VOCData/voc_bird_train_bbox_flist.txt', 66 | '/unsullied/sharefs/linhangyu/Inpainting/Data/VOCData/voc_bird_val_bbox_flist.txt' 67 | ] 68 | 69 | STATIC_VIEW_SIZE: 30 70 | IMG_SHAPES: [256, 256, 3] 71 | MASK_SHAPES: [256, 256, 1] 72 | HEIGHT: 128 73 | WIDTH: 128 74 | MAX_DELTA_HEIGHT: 32 75 | MAX_DELTA_WIDTH: 32 76 | BATCH_SIZE: 16 77 | VERTICAL_MARGIN: 0 78 | HORIZONTAL_MARGIN: 0 79 | MAXVERTEX: 5 80 | MAXANGLE: 4.0 #pi 81 | MAXLENGTH: 40 82 | MAXBRUSHWIDTH: 10 83 | # loss 84 | AE_LOSS: True 85 | L1_LOSS: True 86 | GLOBAL_DCGAN_LOSS_ALPHA: 1. 87 | GLOBAL_WGAN_LOSS_ALPHA: 1. 88 | 89 | # loss legacy 90 | LOAD_VGG_MODEL: False 91 | VGG_MODEL_FILE: data/model_zoo/vgg16.npz 92 | FEATURE_LOSS: False 93 | GRAMS_LOSS: False 94 | TV_LOSS: False 95 | TV_LOSS_ALPHA: 0. 
96 | FEATURE_LOSS_ALPHA: 0.01 97 | GRAMS_LOSS_ALPHA: 50 98 | SPATIAL_DISCOUNTING_GAMMA: 0.9 99 | -------------------------------------------------------------------------------- /config/inpaint_voc.yml: -------------------------------------------------------------------------------- 1 | # parameters 2 | DATASET: 'horse' # 'tmnist', 'dtd', 'places2', 'celeba', 'imagenet', 'cityscapes' 3 | MASKDATASET: 'horse_mask' 4 | RANDOM_CROP: False 5 | MASKFROMFILE: True 6 | VAL: True 7 | LOG_DIR: full_model_voc_horse_256 8 | MODEL_RESTORE: '' # '20180115220926508503_jyugpu0_places2_NORMAL_wgan_gp_full_model' 9 | 10 | GAN: 'sn_pgan' # 'dcgan', 'lsgan', 'wgan_gp', 'one_wgan_gp' 11 | PRETRAIN_COARSE_NETWORK: False 12 | GAN_LOSS_ALPHA: 0.001 # dcgan: 0.0008, wgan: 0.0005, onegan: 0.001 13 | WGAN_GP_LAMBDA: 10 14 | COARSE_L1_ALPHA: 1.2 15 | L1_LOSS_ALPHA: 1.2 16 | AE_LOSS_ALPHA: 1.2 17 | GAN_WITH_MASK: True 18 | GAN_WITH_GUIDE: False 19 | DISCOUNTED_MASK: True 20 | RANDOM_SEED: False 21 | PADDING: 'SAME' 22 | 23 | # training 24 | NUM_GPUS: 1 25 | GPU_ID: -1 # -1 indicate select any available one, otherwise select gpu ID, e.g. [0,1,3] 26 | TRAIN_SPE: 10000 27 | MAX_ITERS: 1000000 28 | VIZ_MAX_OUT: 10 29 | GRADS_SUMMARY: False 30 | GRADIENT_CLIP: False 31 | GRADIENT_CLIP_VALUE: 0.1 32 | VAL_PSTEPS: 5000 33 | 34 | # data 35 | DATA_FLIST: 36 | # https://github.com/JiahuiYu/progressive_growing_of_gans_tf 37 | celebahq: [ 38 | 'data/celeba_hq/train_shuffled.flist', 39 | 'data/celeba_hq/validation_static_view.flist' 40 | ] 41 | # http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html, please to use RANDOM_CROP: True 42 | celeba: [ 43 | 'data/celeba/train_shuffled.flist', 44 | 'data/celeba/validation_static_view.flist' 45 | ] 46 | # http://places2.csail.mit.edu/, please download the high-resolution dataset and use RANDOM_CROP: True 47 | places2: [ 48 | '../Data/PlacesData/train_list.txt', 49 | '../Data/PlacesData/val_list.txt' 50 | ] 51 | # http://www.image-net.org/, please use RANDOM_CROP: True 52 | imagenet: [ 53 | 'data/imagenet/train_shuffled.flist', 54 | 'data/imagenet/validation_static_view.flist', 55 | ] 56 | irrmask: [ 57 | '../Data/MaskData/irrmask_flist.txt', 58 | '../Data/MaskData/irrmask_flist.txt' 59 | ] 60 | horse: [ 61 | '/unsullied/sharefs/linhangyu/Inpainting/Data/VOCData/voc_horse_train_flist.txt', 62 | '/unsullied/sharefs/linhangyu/Inpainting/Data/VOCData/voc_horse_val_flist.txt' 63 | ] 64 | horse_mask: [ 65 | '/unsullied/sharefs/linhangyu/Inpainting/Data/VOCData/voc_horse_train_bbox_flist.txt', 66 | '/unsullied/sharefs/linhangyu/Inpainting/Data/VOCData/voc_horse_val_bbox_flist.txt' 67 | ] 68 | 69 | STATIC_VIEW_SIZE: 30 70 | IMG_SHAPES: [256, 256, 3] 71 | MASK_SHAPES: [256, 256, 1] 72 | HEIGHT: 128 73 | WIDTH: 128 74 | MAX_DELTA_HEIGHT: 32 75 | MAX_DELTA_WIDTH: 32 76 | BATCH_SIZE: 16 77 | VERTICAL_MARGIN: 0 78 | HORIZONTAL_MARGIN: 0 79 | MAXVERTEX: 5 80 | MAXANGLE: 4.0 #pi 81 | MAXLENGTH: 40 82 | MAXBRUSHWIDTH: 10 83 | # loss 84 | AE_LOSS: True 85 | L1_LOSS: True 86 | GLOBAL_DCGAN_LOSS_ALPHA: 1. 87 | GLOBAL_WGAN_LOSS_ALPHA: 1. 88 | 89 | # loss legacy 90 | LOAD_VGG_MODEL: False 91 | VGG_MODEL_FILE: data/model_zoo/vgg16.npz 92 | FEATURE_LOSS: False 93 | GRAMS_LOSS: False 94 | TV_LOSS: False 95 | TV_LOSS_ALPHA: 0. 
96 | FEATURE_LOSS_ALPHA: 0.01 97 | GRAMS_LOSS_ALPHA: 50 98 | SPATIAL_DISCOUNTING_GAMMA: 0.9 99 | -------------------------------------------------------------------------------- /config/inpaint_voc_sheep.yml: -------------------------------------------------------------------------------- 1 | # parameters 2 | DATASET: 'sheep' # 'tmnist', 'dtd', 'places2', 'celeba', 'imagenet', 'cityscapes' 3 | MASKDATASET: 'sheep_mask' 4 | RANDOM_CROP: False 5 | MASKFROMFILE: True 6 | VAL: False 7 | LOG_DIR: full_model_voc_sheep_256 8 | MODEL_RESTORE: '' # '20180115220926508503_jyugpu0_places2_NORMAL_wgan_gp_full_model' 9 | 10 | GAN: 'sn_pgan' # 'dcgan', 'lsgan', 'wgan_gp', 'one_wgan_gp' 11 | PRETRAIN_COARSE_NETWORK: False 12 | GAN_LOSS_ALPHA: 0.001 # dcgan: 0.0008, wgan: 0.0005, onegan: 0.001 13 | WGAN_GP_LAMBDA: 10 14 | COARSE_L1_ALPHA: 1.2 15 | L1_LOSS_ALPHA: 1.2 16 | AE_LOSS_ALPHA: 1.2 17 | GAN_WITH_MASK: True 18 | GAN_WITH_GUIDE: False 19 | DISCOUNTED_MASK: True 20 | RANDOM_SEED: False 21 | PADDING: 'SAME' 22 | 23 | # training 24 | NUM_GPUS: 1 25 | GPU_ID: -1 # -1 indicate select any available one, otherwise select gpu ID, e.g. [0,1,3] 26 | TRAIN_SPE: 10000 27 | MAX_ITERS: 1000000 28 | VIZ_MAX_OUT: 10 29 | GRADS_SUMMARY: False 30 | GRADIENT_CLIP: False 31 | GRADIENT_CLIP_VALUE: 0.1 32 | VAL_PSTEPS: 1000 33 | 34 | # data 35 | DATA_FLIST: 36 | # https://github.com/JiahuiYu/progressive_growing_of_gans_tf 37 | celebahq: [ 38 | 'data/celeba_hq/train_shuffled.flist', 39 | 'data/celeba_hq/validation_static_view.flist' 40 | ] 41 | # http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html, please to use RANDOM_CROP: True 42 | celeba: [ 43 | 'data/celeba/train_shuffled.flist', 44 | 'data/celeba/validation_static_view.flist' 45 | ] 46 | # http://places2.csail.mit.edu/, please download the high-resolution dataset and use RANDOM_CROP: True 47 | places2: [ 48 | '../Data/PlacesData/train_list.txt', 49 | '../Data/PlacesData/val_list.txt' 50 | ] 51 | # http://www.image-net.org/, please use RANDOM_CROP: True 52 | imagenet: [ 53 | 'data/imagenet/train_shuffled.flist', 54 | 'data/imagenet/validation_static_view.flist', 55 | ] 56 | irrmask: [ 57 | '../Data/MaskData/irrmask_flist.txt', 58 | '../Data/MaskData/irrmask_flist.txt' 59 | ] 60 | sheep: [ 61 | '/unsullied/sharefs/linhangyu/Inpainting/Data/VOCData/voc_sheep_train_flist.txt', 62 | '/unsullied/sharefs/linhangyu/Inpainting/Data/VOCData/voc_sheep_val_flist.txt' 63 | ] 64 | sheep_mask: [ 65 | '/unsullied/sharefs/linhangyu/Inpainting/Data/VOCData/voc_sheep_train_bbox_flist.txt', 66 | '/unsullied/sharefs/linhangyu/Inpainting/Data/VOCData/voc_sheep_val_bbox_flist.txt' 67 | ] 68 | 69 | STATIC_VIEW_SIZE: 30 70 | IMG_SHAPES: [256, 256, 3] 71 | MASK_SHAPES: [256, 256, 1] 72 | HEIGHT: 128 73 | WIDTH: 128 74 | MAX_DELTA_HEIGHT: 32 75 | MAX_DELTA_WIDTH: 32 76 | BATCH_SIZE: 16 77 | VERTICAL_MARGIN: 0 78 | HORIZONTAL_MARGIN: 0 79 | MAXVERTEX: 5 80 | MAXANGLE: 4.0 #pi 81 | MAXLENGTH: 40 82 | MAXBRUSHWIDTH: 10 83 | # loss 84 | AE_LOSS: True 85 | L1_LOSS: True 86 | GLOBAL_DCGAN_LOSS_ALPHA: 1. 87 | GLOBAL_WGAN_LOSS_ALPHA: 1. 88 | 89 | # loss legacy 90 | LOAD_VGG_MODEL: False 91 | VGG_MODEL_FILE: data/model_zoo/vgg16.npz 92 | FEATURE_LOSS: False 93 | GRAMS_LOSS: False 94 | TV_LOSS: False 95 | TV_LOSS_ALPHA: 0. 
96 | FEATURE_LOSS_ALPHA: 0.01 97 | GRAMS_LOSS_ALPHA: 50 98 | SPATIAL_DISCOUNTING_GAMMA: 0.9 99 | -------------------------------------------------------------------------------- /config/inapint_voc_person.yml: -------------------------------------------------------------------------------- 1 | # parameters 2 | DATASET: 'person' # 'tmnist', 'dtd', 'places2', 'celeba', 'imagenet', 'cityscapes' 3 | MASKDATASET: 'person_mask' 4 | RANDOM_CROP: False 5 | MASKFROMFILE: True 6 | VAL: False 7 | LOG_DIR: full_model_voc_person_256 8 | MODEL_RESTORE: '' # '20180115220926508503_jyugpu0_places2_NORMAL_wgan_gp_full_model' 9 | 10 | GAN: 'sn_pgan' # 'dcgan', 'lsgan', 'wgan_gp', 'one_wgan_gp' 11 | PRETRAIN_COARSE_NETWORK: False 12 | GAN_LOSS_ALPHA: 0.001 # dcgan: 0.0008, wgan: 0.0005, onegan: 0.001 13 | WGAN_GP_LAMBDA: 10 14 | COARSE_L1_ALPHA: 1.2 15 | L1_LOSS_ALPHA: 1.2 16 | AE_LOSS_ALPHA: 1.2 17 | GAN_WITH_MASK: True 18 | GAN_WITH_GUIDE: False 19 | DISCOUNTED_MASK: True 20 | RANDOM_SEED: False 21 | PADDING: 'SAME' 22 | 23 | # training 24 | NUM_GPUS: 1 25 | GPU_ID: -1 # -1 indicate select any available one, otherwise select gpu ID, e.g. [0,1,3] 26 | TRAIN_SPE: 10000 27 | MAX_ITERS: 1000000 28 | VIZ_MAX_OUT: 10 29 | GRADS_SUMMARY: False 30 | GRADIENT_CLIP: False 31 | GRADIENT_CLIP_VALUE: 0.1 32 | VAL_PSTEPS: 1000 33 | 34 | # data 35 | DATA_FLIST: 36 | # https://github.com/JiahuiYu/progressive_growing_of_gans_tf 37 | celebahq: [ 38 | 'data/celeba_hq/train_shuffled.flist', 39 | 'data/celeba_hq/validation_static_view.flist' 40 | ] 41 | # http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html, please to use RANDOM_CROP: True 42 | celeba: [ 43 | 'data/celeba/train_shuffled.flist', 44 | 'data/celeba/validation_static_view.flist' 45 | ] 46 | # http://places2.csail.mit.edu/, please download the high-resolution dataset and use RANDOM_CROP: True 47 | places2: [ 48 | '../Data/PlacesData/train_list.txt', 49 | '../Data/PlacesData/val_list.txt' 50 | ] 51 | # http://www.image-net.org/, please use RANDOM_CROP: True 52 | imagenet: [ 53 | 'data/imagenet/train_shuffled.flist', 54 | 'data/imagenet/validation_static_view.flist', 55 | ] 56 | irrmask: [ 57 | '../Data/MaskData/irrmask_flist.txt', 58 | '../Data/MaskData/irrmask_flist.txt' 59 | ] 60 | person: [ 61 | '/unsullied/sharefs/linhangyu/Inpainting/Data/VOCData/voc_person_train_flist.txt', 62 | '/unsullied/sharefs/linhangyu/Inpainting/Data/VOCData/voc_person_val_flist.txt' 63 | ] 64 | person_mask: [ 65 | '/unsullied/sharefs/linhangyu/Inpainting/Data/VOCData/voc_person_train_bbox_flist.txt', 66 | '/unsullied/sharefs/linhangyu/Inpainting/Data/VOCData/voc_person_val_bbox_flist.txt' 67 | ] 68 | 69 | STATIC_VIEW_SIZE: 30 70 | IMG_SHAPES: [256, 256, 3] 71 | MASK_SHAPES: [256, 256, 1] 72 | HEIGHT: 128 73 | WIDTH: 128 74 | MAX_DELTA_HEIGHT: 32 75 | MAX_DELTA_WIDTH: 32 76 | BATCH_SIZE: 16 77 | VERTICAL_MARGIN: 0 78 | HORIZONTAL_MARGIN: 0 79 | MAXVERTEX: 5 80 | MAXANGLE: 4.0 #pi 81 | MAXLENGTH: 40 82 | MAXBRUSHWIDTH: 10 83 | # loss 84 | AE_LOSS: True 85 | L1_LOSS: True 86 | GLOBAL_DCGAN_LOSS_ALPHA: 1. 87 | GLOBAL_WGAN_LOSS_ALPHA: 1. 88 | 89 | # loss legacy 90 | LOAD_VGG_MODEL: False 91 | VGG_MODEL_FILE: data/model_zoo/vgg16.npz 92 | FEATURE_LOSS: False 93 | GRAMS_LOSS: False 94 | TV_LOSS: False 95 | TV_LOSS_ALPHA: 0. 
96 | FEATURE_LOSS_ALPHA: 0.01 97 | GRAMS_LOSS_ALPHA: 50 98 | SPATIAL_DISCOUNTING_GAMMA: 0.9 99 | -------------------------------------------------------------------------------- /config/inpaint_all_person.yml: -------------------------------------------------------------------------------- 1 | # parameters 2 | DATASET: 'all_person' # 'tmnist', 'dtd', 'places2', 'celeba', 'imagenet', 'cityscapes' 3 | MASKDATASET: 'all_person_mask' 4 | RANDOM_CROP: False 5 | MASKFROMFILE: True 6 | VAL: True 7 | LOG_DIR: full_model_voc_all_person_256 8 | MODEL_RESTORE: '' # '20180115220926508503_jyugpu0_places2_NORMAL_wgan_gp_full_model' 9 | 10 | GAN: 'sn_pgan' # 'dcgan', 'lsgan', 'wgan_gp', 'one_wgan_gp' 11 | PRETRAIN_COARSE_NETWORK: False 12 | GAN_LOSS_ALPHA: 0.001 # dcgan: 0.0008, wgan: 0.0005, onegan: 0.001 13 | WGAN_GP_LAMBDA: 10 14 | COARSE_L1_ALPHA: 1.2 15 | L1_LOSS_ALPHA: 1.2 16 | AE_LOSS_ALPHA: 1.2 17 | GAN_WITH_MASK: True 18 | GAN_WITH_GUIDE: False 19 | DISCOUNTED_MASK: True 20 | RANDOM_SEED: False 21 | PADDING: 'SAME' 22 | 23 | # training 24 | NUM_GPUS: 1 25 | GPU_ID: -1 # -1 indicate select any available one, otherwise select gpu ID, e.g. [0,1,3] 26 | TRAIN_SPE: 10000 27 | MAX_ITERS: 1000000 28 | VIZ_MAX_OUT: 10 29 | GRADS_SUMMARY: False 30 | GRADIENT_CLIP: False 31 | GRADIENT_CLIP_VALUE: 0.1 32 | VAL_PSTEPS: 50000 33 | 34 | # data 35 | DATA_FLIST: 36 | # https://github.com/JiahuiYu/progressive_growing_of_gans_tf 37 | celebahq: [ 38 | 'data/celeba_hq/train_shuffled.flist', 39 | 'data/celeba_hq/validation_static_view.flist' 40 | ] 41 | # http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html, please to use RANDOM_CROP: True 42 | celeba: [ 43 | 'data/celeba/train_shuffled.flist', 44 | 'data/celeba/validation_static_view.flist' 45 | ] 46 | # http://places2.csail.mit.edu/, please download the high-resolution dataset and use RANDOM_CROP: True 47 | places2: [ 48 | '../Data/PlacesData/train_list.txt', 49 | '../Data/PlacesData/val_list.txt' 50 | ] 51 | # http://www.image-net.org/, please use RANDOM_CROP: True 52 | imagenet: [ 53 | 'data/imagenet/train_shuffled.flist', 54 | 'data/imagenet/validation_static_view.flist', 55 | ] 56 | irrmask: [ 57 | '../Data/MaskData/irrmask_flist.txt', 58 | '../Data/MaskData/irrmask_flist.txt' 59 | ] 60 | all_person: [ 61 | '/unsullied/sharefs/linhangyu/Inpainting/Data/PersonData/all_person_train_flist.txt', 62 | '/unsullied/sharefs/linhangyu/Inpainting/Data/PersonData/all_person_val_flist.txt' 63 | # '/unsullied/sharefs/linhangyu/Inpainting/Data/Human/code/human_person_train_flist.txt', 64 | # #'/unsullied/sharefs/linhangyu/Inpainting/Data/COCO/flist/coco_person_train2017_flist.txt', 65 | # '/unsullied/sharefs/linhangyu/Inpainting/Data/PersonData/all_person_val_bbox_flist.txt' 66 | ] 67 | all_person_mask: [ 68 | #'/unsullied/sharefs/linhangyu/Inpainting/Data/COCO/flist/coco_person_train2017_bbox_flist.txt', 69 | '/unsullied/sharefs/linhangyu/Inpainting/Data/PersonData/all_person_train_bbox_flist.txt', 70 | '/unsullied/sharefs/linhangyu/Inpainting/Data/PersonData/all_person_val_bbox_flist.txt' 71 | ] 72 | 73 | STATIC_VIEW_SIZE: 30 74 | IMG_SHAPES: [256, 256, 3] 75 | MASK_SHAPES: [256, 256, 1] 76 | HEIGHT: 128 77 | WIDTH: 128 78 | MAX_DELTA_HEIGHT: 32 79 | MAX_DELTA_WIDTH: 32 80 | BATCH_SIZE: 16 81 | VERTICAL_MARGIN: 0 82 | HORIZONTAL_MARGIN: 0 83 | MAXVERTEX: 5 84 | MAXANGLE: 4.0 #pi 85 | MAXLENGTH: 40 86 | MAXBRUSHWIDTH: 10 87 | # loss 88 | AE_LOSS: True 89 | L1_LOSS: True 90 | GLOBAL_DCGAN_LOSS_ALPHA: 1. 91 | GLOBAL_WGAN_LOSS_ALPHA: 1. 
92 | 93 | # loss legacy 94 | LOAD_VGG_MODEL: False 95 | VGG_MODEL_FILE: data/model_zoo/vgg16.npz 96 | FEATURE_LOSS: False 97 | GRAMS_LOSS: False 98 | TV_LOSS: False 99 | TV_LOSS_ALPHA: 0. 100 | FEATURE_LOSS_ALPHA: 0.01 101 | GRAMS_LOSS_ALPHA: 50 102 | SPATIAL_DISCOUNTING_GAMMA: 0.9 103 | -------------------------------------------------------------------------------- /util/generate_filelist.py: -------------------------------------------------------------------------------- 1 | import os 2 | import tarfile 3 | 4 | def norilist2list(filename, noriname, prefix): 5 | with open(filename, 'w') as wf: 6 | with open(noriname, 'r') as rf: 7 | for line in rf: 8 | terms = line.split('\t') 9 | wf.write(prefix+terms[2]+'\n') 10 | 11 | 12 | def tar2list(filename, tarname, prefix): 13 | 14 | tar = tarfile.open(tarname) 15 | with open(filename, 'w') as f: 16 | for tarinfo in tar: 17 | if tarinfo.isfile(): #and os.path.exists(os.path.join(prefix, tarinfo.name)): 18 | f.write(os.path.join(prefix, tarinfo.name)+"\n") 19 | 20 | def places2list(filename, dataset_path='/unsullied/sharefs/linhangyu/Inpainting/Data/PlacesData/data_256/'): 21 | with open(filename, mode='w') as f: 22 | # a, b, c 23 | for d in os.listdir(dataset_path): 24 | #for d in dirs: 25 | if not os.path.isdir(os.path.join(dataset_path, d)): 26 | continue 27 | print(d) 28 | # airfield 29 | for dd in os.listdir(os.path.join(dataset_path, d)): 30 | if not os.path.isdir(os.path.join(dataset_path,d,dd)): 31 | continue 32 | print(dd) 33 | for file in os.listdir(os.path.join(dataset_path,d,dd)): 34 | print(file) 35 | if file[-3:] == 'jpg' and os.path.exists(os.path.join(dataset_path,d,dd,file)): 36 | #print(os.path.join(dataset_path,d,dd,file)) 37 | f.write(os.path.join(dataset_path,d,dd,file)+"\n") 38 | 39 | def placesval2list(filename, dataset_path='/unsullied/sharefs/linhangyu/Inpainting/Data/PlacesData/val_256/'): 40 | with open(filename, mode='w') as f: 41 | # a, b, c 42 | for root, dirs, files in os.walk(dataset_path): 43 | for d in dirs: 44 | # airfield 45 | for dr, dds, dfs in os.walk(os.path.join(dataset_path, d)): 46 | for dd in dds: 47 | for file in os.listdir(os.path.join(dataset_path,d,dd)): 48 | if file[-3:] == 'jpg': 49 | #print(os.path.join(dataset_path,d,dd,file)) 50 | f.write(os.path.join(dataset_path,d,dd,file)+"\n") 51 | horse_path = "/unsullied/sharefs/linhangyu/Inpainting/Data/ImagenetData/horse" 52 | #places_path = '/unsullied/sharefs/linhangyu/Inpainting/Data/PlacesData' 53 | #places2list("/unsullied/sharefs/linhangyu/Inpainting/Data/PlacesData/train_flist.txt", dataset_path='/unsullied/sharefs/linhangyu/Inpainting/Data/PlacesData/data_256/') 54 | tar2list("horse_imagenet_flist.txt","/Users/linhangyu/Downloads/n02389026/horse.tar", horse_path) 55 | #tar2list("/unsullied/sharefs/linhangyu/Inpainting/Data/PlacesData/train_large_list.txt","/unsullied/sharefs/linhangyu/Inpainting/Data/PlacesData/train_large_places365standard.tar") 56 | #tar2list("/unsullied/sharefs/linhangyu/Inpainting/Data/PlacesData/val_list.txt","/unsullied/sharefs/linhangyu/Inpainting/Data/PlacesData/val_256.tar") 57 | #tar2list("/unsullied/sharefs/linhangyu/Inpainting/Data/PlacesData/test_list.txt","/unsullied/sharefs/linhangyu/Inpainting/Data/PlacesData/test_256.tar") 58 | #tar2list("/unsullied/sharefs/linhangyu/Inpainting/Data/MaskData/irrmask_flist.txt", "/unsullied/sharefs/linhangyu/Inpainting/Data/BenchmarkData/ILSVRC2012/IRRMASK/irrmask.tar", prefix="/unsullied/sharefs/linhangyu/Inpainting/Data/BenchmarkData/ILSVRC2012/IRRMASK") 59 | 
#norilist2list("/unsullied/sharefs/linhangyu/Inpainting/Data/MaskData/mask_flist.txt", 60 | #"/unsullied/sharefs/linhangyu/Inpainting/Data/MaskData/mask.512.train.nori.list", prefix="/unsullied/sharefs/linhangyu/Inpainting/Data/MaskData") 61 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | import socket 4 | import logging 5 | import sys 6 | 7 | import tensorflow as tf 8 | import neuralgym as ng 9 | from data_from_fnames import DataFromFNames 10 | from mask_from_fnames import DataMaskFromFNames 11 | from inpaint_model import InpaintCAModel 12 | from inpaint_model_gc import InpaintGCModel 13 | from trainer import Trainer 14 | 15 | logger = logging.getLogger() 16 | 17 | def multigpu_graph_def(model, data_mask_data, guides, config, gpu_id=0, loss_type='g'): 18 | 19 | files = None 20 | with tf.device('/cpu:0'): 21 | if config.MASKFROMFILE: 22 | images, masks = data_mask_data.data_pipeline(config.BATCH_SIZE) 23 | else: 24 | images = data_mask_data.data_pipeline(config.BATCH_SIZE) 25 | masks = None 26 | # if config.RETURN_FILE: 27 | # images, files = data.data_pipeline(config.BATCH_SIZE) 28 | # else: 29 | # images = data.data_pipeline(config.BATCH_SIZE) 30 | # if mask_data is not None: 31 | # masks = mask_data.data_pipeline(config.BATCH_SIZE) 32 | # else: 33 | # masks = None 34 | if loss_type == 'g': 35 | _, _, losses = model.build_graph_with_losses( 36 | images, masks, guides, config, summary=True, reuse=True) 37 | else: 38 | _, _, losses = model.build_graph_with_losses( 39 | images, masks, guides, config, reuse=True) 40 | if loss_type == 'g': 41 | return losses['g_loss'] 42 | elif loss_type == 'd': 43 | return losses['d_loss'] 44 | else: 45 | raise ValueError('loss type is not supported.') 46 | 47 | 48 | if __name__ == "__main__": 49 | config = ng.Config(sys.argv[1]) 50 | if config.GPU_ID != -1: 51 | ng.set_gpus(config.GPU_ID) 52 | else: 53 | ng.get_gpus(config.NUM_GPUS) 54 | # training data 55 | # Image Data 56 | with open(config.DATA_FLIST[config.DATASET][0]) as f: 57 | fnames = f.read().splitlines() 58 | # # Mask Data 59 | 60 | if config.MASKFROMFILE: 61 | with open(config.DATA_FLIST[config.MASKDATASET][0]) as f: 62 | mask_fnames = f.read().splitlines() 63 | data_mask_data = DataMaskFromFNames( 64 | list(zip(fnames, mask_fnames)), [config.IMG_SHAPES, config.MASK_SHAPES], random_crop=config.RANDOM_CROP) 65 | images, masks = data_mask_data.data_pipeline(config.BATCH_SIZE) 66 | else: 67 | data_mask_data = DataFromFNames( 68 | fnames, config.IMG_SHAPES, random_crop=config.RANDOM_CROP) 69 | images = data_mask_data.data_pipeline(config.BATCH_SIZE) 70 | masks = None 71 | 72 | guides = None 73 | # main model 74 | model = InpaintGCModel() 75 | g_vars, d_vars, losses = model.build_graph_with_losses( 76 | images, masks, guides, config=config) 77 | # validation images 78 | if config.VAL: 79 | with open(config.DATA_FLIST[config.DATASET][1]) as f: 80 | val_fnames = f.read().splitlines() 81 | with open(config.DATA_FLIST[config.MASKDATASET][1]) as f: 82 | val_mask_fnames = f.read().splitlines() 83 | # progress monitor by visualizing static images 84 | for i in range(config.STATIC_VIEW_SIZE): 85 | static_fnames = val_fnames[i:i+1] 86 | 87 | if config.MASKFROMFILE: 88 | static_mask_fnames = val_mask_fnames[i:i+1] 89 | static_images, static_masks = DataMaskFromFNames( 90 | list(zip(static_fnames,static_mask_fnames)), [config.IMG_SHAPES, 
config.MASK_SHAPES], 91 | nthreads=1, random_crop=config.RANDOM_CROP).data_pipeline(1) 92 | else: 93 | static_images = DataFromFNames( 94 | static_fnames, config.IMG_SHAPES, nthreads=1, 95 | random_crop=config.RANDOM_CROP).data_pipeline(1) 96 | static_masks = None 97 | 98 | static_inpainted_images = model.build_static_infer_graph( 99 | static_images, static_masks, static_masks, config, name='static_view/%d' % i) 100 | # training settings 101 | lr = tf.get_variable( 102 | 'lr', shape=[], trainable=False, 103 | initializer=tf.constant_initializer(1e-4)) 104 | d_optimizer = tf.train.AdamOptimizer(lr, beta1=0.5, beta2=0.9) 105 | g_optimizer = d_optimizer 106 | # gradient processor 107 | if config.GRADIENT_CLIP: 108 | gradient_processor = lambda grad_var: ( 109 | tf.clip_by_average_norm(grad_var[0], config.GRADIENT_CLIP_VALUE), 110 | grad_var[1]) 111 | else: 112 | gradient_processor = None 113 | # log dir 114 | log_prefix = 'model_logs/' + '_'.join([ 115 | ng.date_uid(), socket.gethostname(), config.DATASET, 116 | 'MASKED' if config.GAN_WITH_MASK else 'NORMAL', 117 | config.GAN,config.LOG_DIR]) 118 | # train discriminator with secondary trainer, should initialize before 119 | # primary trainer. 120 | discriminator_training_callback = ng.callbacks.SecondaryTrainer( 121 | pstep=1, 122 | optimizer=d_optimizer, 123 | var_list=d_vars, 124 | max_iters=5, 125 | graph_def=multigpu_graph_def, 126 | graph_def_kwargs={ 127 | 'model': model, 'data_mask_data': data_mask_data, "guides":None, 'config': config, 'loss_type': 'd'}, 128 | ) 129 | # train generator with primary trainer 130 | trainer = Trainer( 131 | optimizer=g_optimizer, 132 | var_list=g_vars, 133 | max_iters=config.MAX_ITERS, 134 | graph_def=multigpu_graph_def, 135 | grads_summary=config.GRADS_SUMMARY, 136 | gradient_processor=gradient_processor, 137 | graph_def_kwargs={ 138 | 'model': model, 'data_mask_data': data_mask_data, "guides":None, 'config': config, 'loss_type': 'g'}, 139 | spe=config.TRAIN_SPE, 140 | log_dir=log_prefix, 141 | ) 142 | # add all callbacks 143 | if not config.PRETRAIN_COARSE_NETWORK: 144 | trainer.add_callbacks(discriminator_training_callback) 145 | trainer.add_callbacks([ 146 | ng.callbacks.WeightsViewer(), 147 | ng.callbacks.ModelRestorer(trainer.context['saver'], dump_prefix='model_logs/'+config.MODEL_RESTORE+'/snap', optimistic=True), 148 | ng.callbacks.ModelSaver(config.TRAIN_SPE, trainer.context['saver'], log_prefix+'/snap'), 149 | ng.callbacks.SummaryWriter((config.VAL_PSTEPS//1), trainer.context['summary_writer'], tf.summary.merge_all()), 150 | ]) 151 | # launch training 152 | trainer.train() 153 | -------------------------------------------------------------------------------- /data_from_fnames.py: -------------------------------------------------------------------------------- 1 | import random 2 | import threading 3 | import logging 4 | import time 5 | import numpy as np 6 | import cv2 7 | import tensorflow as tf 8 | 9 | from neuralgym.data import feeding_queue_runner as queue_runner 10 | from neuralgym.data.dataset import Dataset 11 | from neuralgym.ops.image_ops import np_random_crop 12 | 13 | 14 | logger = logging.getLogger() 15 | READER_LOCK = threading.Lock() 16 | 17 | 18 | class DataFromFNames(Dataset): 19 | """Data pipeline from list of filenames. 20 | Args: 21 | fnamelists (list): A list of filenames or tuple of filenames, e.g. 22 | ['image_001.png', ...] or 23 | [('pair_image_001_0.png', 'pair_image_001_1.png'), ...]. 24 | shapes (tuple): Shapes of data, e.g. [256, 256, 3] or 25 | [[256, 256, 3], [1]]. 
26 |         random (bool): Read from `fnamelists` randomly (default to False).
27 |         random_crop (bool): Whether to random-crop raw images to the shape,
28 |             or directly resize raw images to the shape.
29 |         dtypes (tf.Type): Data types, default to tf.float32.
30 |         enqueue_size (int): Enqueue size for pipeline.
31 |         queue_size (int): Capacity of the underlying FIFO queue.
32 |         nthreads (int): Parallel threads for reading from data.
33 |         return_fnames (bool): If True, data_pipeline will also return fnames
34 |             (last tensor).
35 |         filetype (str): Currently only 'image' is supported.
36 |     Examples:
37 |         >>> fnames = ['img001.png', 'img002.png', ..., 'img999.png']
38 |         >>> data = ng.data.DataFromFNames(fnames, [256, 256, 3])
39 |         >>> images = data.data_pipeline(128)
40 |         >>> sess = tf.Session(config=tf.ConfigProto())
41 |         >>> tf.train.start_queue_runners(sess)
42 |         >>> for i in range(5): sess.run(images)
43 |     To get file lists, you can either read a file::
44 |         with open('data/images.flist') as f:
45 |             fnames = f.read().splitlines()
46 |     or glob::
47 |         import glob
48 |         fnames = glob.glob('data/*.png')
49 |     You can also create fnames tuple::
50 |         with open('images.flist') as f:
51 |             image_fnames = f.read().splitlines()
52 |         with open('segmentation_annotation.flist') as f:
53 |             annotation_fnames = f.read().splitlines()
54 |         fnames = list(zip(image_fnames, annotation_fnames))
55 |     """
56 | 
57 |     def __init__(self, fnamelists, shapes, random=False, random_crop=False,
58 |                  fn_preprocess=None, dtypes=tf.float32,
59 |                  enqueue_size=32, queue_size=256, nthreads=8,
60 |                  return_fnames=False, filetype='image'):
61 |         self.fnamelists_ = self.process_fnamelists(fnamelists)
62 |         self.file_length = len(self.fnamelists_)
63 |         self.random = random
64 |         self.random_crop = random_crop
65 |         self.filetype = filetype
66 |         if isinstance(shapes[0], list):
67 |             self.shapes = shapes
68 |         else:
69 |             self.shapes = [shapes] * len(self.fnamelists_[0])
70 |         if isinstance(dtypes, list):
71 |             self.dtypes = dtypes
72 |         else:
73 |             self.dtypes = [dtypes] * len(self.fnamelists_[0])
74 |         self.return_fnames = return_fnames
75 |         self.batch_phs = [
76 |             tf.placeholder(dtype, [None] + shape)
77 |             for dtype, shape in zip(self.dtypes, self.shapes)]
78 |         if self.return_fnames:
79 |             self.shapes += [[]]
80 |             self.dtypes += [tf.string]
81 |             self.batch_phs.append(tf.placeholder(tf.string, [None]))
82 |         self.enqueue_size = enqueue_size
83 |         self.queue_size = queue_size
84 |         self.nthreads = nthreads
85 |         self.fn_preprocess = fn_preprocess
86 |         if not random:
87 |             self.index = 0
88 |         super().__init__()
89 |         self.create_queue()
90 | 
91 |     def process_fnamelists(self, fnamelist):
92 |         if isinstance(fnamelist, list):
93 |             if isinstance(fnamelist[0], str):
94 |                 return [(i,) for i in fnamelist]
95 |             elif isinstance(fnamelist[0], tuple):
96 |                 return fnamelist
97 |             else:
98 |                 raise ValueError('Type error for fnamelist.')
99 |         else:
100 |             raise ValueError('Type error for fnamelist.')
101 | 
102 |     def data_pipeline(self, batch_size):
103 |         """Batch data pipeline.
104 |         Args:
105 |             batch_size (int): Batch size.
106 |         Returns:
107 |             A list of tensors, each with leading dimension batch_size and the
108 |             corresponding shape from self.shapes, e.g. if self.shapes =
109 |             ([256, 256, 3], [1]), then return [[batch_size, 256, 256, 3], [batch_size, 1]].
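        A minimal usage sketch for the paired-filename case (the flist names
        here are placeholders; both streams are read as 3-channel images by
        cv2.imread)::
            data = DataFromFNames(
                list(zip(image_fnames, annotation_fnames)),
                [[256, 256, 3], [256, 256, 3]])
            images, annotations = data.data_pipeline(16)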
110 |         """
111 |         data = self._queue.dequeue_many(batch_size)
112 |         return data
113 | 
114 |     def create_queue(self, shared_name=None, name=None):
115 |         from tensorflow.python.ops import data_flow_ops, logging_ops, math_ops
116 |         from tensorflow.python.framework import dtypes
117 |         assert self.dtypes is not None and self.shapes is not None
118 |         assert len(self.dtypes) == len(self.shapes)
119 |         capacity = self.queue_size
120 |         self._queue = data_flow_ops.FIFOQueue(
121 |             capacity=capacity,
122 |             dtypes=self.dtypes,
123 |             shapes=self.shapes,
124 |             shared_name=shared_name,
125 |             name=name)
126 | 
127 |         enq = self._queue.enqueue_many(self.batch_phs)
128 |         # create a queue runner
129 |         queue_runner.add_queue_runner(queue_runner.QueueRunner(
130 |             self._queue, [enq]*self.nthreads,
131 |             feed_dict_op=[lambda: self.next_batch()],
132 |             feed_dict_key=self.batch_phs))
133 |         summary_name = 'fraction_of_%d_full' % capacity
134 |         logging_ops.scalar_summary("queue/%s/%s" % (
135 |             self._queue.name, summary_name), math_ops.cast(
136 |             self._queue.size(), dtypes.float32) * (1. / capacity))
137 | 
138 |     def read_img(self, filename):
139 |         #print(filename)
140 |         img = cv2.imread(filename)
141 | 
142 |         if img is None:
143 |             #logger.info('image is None, sleep this thread for 0.1s.{}'.format(filename))
144 | 
145 |             #time.sleep(0.1)
146 |             return img, True  # signal read failure; next_batch retries with another file
147 |         if self.fn_preprocess:
148 |             img = self.fn_preprocess(img)
149 |         return img, False
150 | 
151 | 
152 |     def next_batch(self):
153 |         batch_data = []
154 |         for _ in range(self.enqueue_size):
155 |             error = True
156 |             while error:
157 |                 error = False
158 |                 if self.random:  # note: self.random, not the random module
159 |                     filenames = random.choice(self.fnamelists_)
160 |                 else:
161 |                     with READER_LOCK:
162 |                         filenames = self.fnamelists_[self.index]
163 |                         self.index = (self.index + 1) % self.file_length
164 |                 imgs = []
165 |                 random_h = None
166 |                 random_w = None
167 |                 for i in range(len(filenames)):
168 |                     img, error = self.read_img(filenames[i])
169 | 
170 |                     if self.random_crop:
171 |                         img, random_h_, random_w_ = np_random_crop(
172 |                             img, tuple(self.shapes[i][:-1]),
173 |                             random_h, random_w, align=False)  # use last rand
174 |                     else:
175 |                         if img is None:
176 |                             continue
177 |                         #print(np.array(img).shape,tuple(self.shapes[i][:-1]))
178 |                         img = cv2.resize(img, tuple(self.shapes[i][:-1]))
179 |                     imgs.append(img)
180 | 
181 |                 if self.return_fnames:
182 |                     batch_data.append(imgs + list(filenames))
183 |                 else:
184 |                     batch_data.append(imgs)
185 |         return zip(*batch_data)
186 | 
187 |     def _maybe_download_and_extract(self):
188 |         pass
189 | 
--------------------------------------------------------------------------------
/trainer.py:
--------------------------------------------------------------------------------
1 | import time
2 | import logging
3 | 
4 | import numpy as np
5 | import tensorflow as tf
6 | 
7 | from neuralgym.utils.logger import ProgressBar
8 | from neuralgym.callbacks import CallbackLoc
9 | from neuralgym.callbacks import PeriodicCallback, OnceCallback, ScheduledCallback
10 | from neuralgym.ops.train_ops import process_gradients
11 | 
12 | 
13 | logger = logging.getLogger()
14 | 
15 | 
16 | class Trainer(object):
17 |     """Trainer class for training an iterative algorithm on a single GPU.
18 | 
19 |     There are two types of trainer in neuralgym: primary trainer and
20 |     secondary trainer. For the primary trainer, tensorflow related instances
21 |     and configurations will be initialized, e.g. init all variables, summary
22 |     writer, session, start_queue_runners and others.
23 |     For the secondary trainer, only train_ops and losses are iteratively updated/run.
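    A minimal pairing sketch (mirroring the wiring in train.py above;
    `d_kwargs`/`g_kwargs` stand in for the graph_def_kwargs dicts used there)::
        d_trainer = ng.callbacks.SecondaryTrainer(
            pstep=1, optimizer=d_optimizer, var_list=d_vars, max_iters=5,
            graph_def=multigpu_graph_def, graph_def_kwargs=d_kwargs)
        trainer = Trainer(
            optimizer=g_optimizer, var_list=g_vars, max_iters=100000,
            graph_def=multigpu_graph_def, graph_def_kwargs=g_kwargs,
            spe=1000, log_dir='model_logs/run')
        trainer.add_callbacks(d_trainer)
        trainer.train()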
24 |     """
25 | 
26 |     def __init__(self, primary=True, **context):
27 |         self.context = context
28 |         self.primary = primary
29 |         self.callbacks = self.context.pop('callbacks', [])
30 |         # contexts
31 |         self.context['feed_dict'] = self.context.pop('feed_dict', {})
32 |         self.context['max_iters'] = int(self.context.pop('max_iters', 999999))
33 |         self.context['log_dir'] = self.context.pop('log_dir', '/tmp/neuralgym')
34 |         self.context['spe'] = self.context.pop('spe', 1)
35 |         # grads summary
36 |         self.context['grads_summary'] = self.context.pop(
37 |             'grads_summary', True)
38 |         # train ops and losses
39 |         self._train_op = self.context.pop('train_op', None)
40 |         if self._train_op is None:
41 |             self._train_op, self._loss = self.train_ops_and_losses()
42 |         else:
43 |             self._loss = self.context.pop('loss', 0)
44 |         # global step
45 |         self.context['log_progress'] = self.context.pop('log_progress', True)
46 |         if self.context['log_progress']:
47 |             self._bar = ProgressBar()
48 |         # total loss, beginning timepoint
49 |         self._log_stats = [0, None]
50 |         # callbacks types
51 |         self._periodic_callbacks = None
52 |         self._once_callbacks = None
53 |         self._scheduled_callbacks = None
54 |         # init primary trainer
55 |         if self.primary:
56 |             self.init_primary_trainer()
57 |         # log context of trainer
58 |         if self.primary:
59 |             logger.info(' Context Of Primary Trainer '.center(80, '-'))
60 |         else:
61 |             logger.info(' Context Of Secondary Trainer '.center(80, '-'))
62 |         for k in self.context:
63 |             logger.info(k + ': ' + str(self.context[k]))
64 |         logger.info(''.center(80, '-'))
65 | 
66 |     def init_primary_trainer(self):
67 |         """Initialize primary trainer context including:
68 | 
69 |         * log_dir
70 |         * global_step
71 |         * sess_config
72 |         * allow_growth
73 |         * summary writer
74 |         * saver
75 |         * global_variables_initializer
76 |         * start_queue_runners
77 | 
78 |         """
79 |         self.context['global_step'] = self.context.pop(
80 |             'global_step', tf.get_variable(
81 |                 'global_step', [], dtype=tf.int32,
82 |                 initializer=tf.zeros_initializer(), trainable=False))
83 |         self.context['global_step_add_one'] = tf.assign_add(
84 |             self.context['global_step'], 1, name='add_one_to_global_step')
85 |         self.context['sess_config'] = self.context.pop(
86 |             'sess_config', tf.ConfigProto())
87 |         self.context['sess_config'].gpu_options.allow_growth = (
88 |             self.context.pop('allow_growth', True))
89 |         self.context['sess_config'].allow_soft_placement = self.context.pop(
90 |             'allow_soft_placement', True)
91 |         self.context['sess'] = tf.Session(config=self.context['sess_config'])
92 |         self.context['summary_writer'] = tf.summary.FileWriter(
93 |             self.context['log_dir'], self.context['sess'].graph)
94 |         self.context['saver'] = tf.train.Saver(tf.global_variables())
95 |         # queue runner
96 |         self.context['start_queue_runners'] = self.context.pop(
97 |             'start_queue_runner', True)
98 |         if self.context['start_queue_runners']:
99 |             tf.train.start_queue_runners(sess=self.context['sess'])
100 |         # initialization
101 |         self.context['global_variables_initializer'] = self.context.pop(
102 |             'global_variables_initializer', True)
103 |         if self.context['global_variables_initializer']:
104 |             self.context['sess'].run(tf.global_variables_initializer())
105 | 
106 |     def train(self):
107 |         """Start training with callbacks.
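        Per step, the loop below fires callbacks in a fixed order: periodic
        then scheduled callbacks at step_start, the train op and loss,
        progress logging, then periodic and scheduled callbacks at step_end.
        Once-callbacks run at train_start, train_end, and on exceptions.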
108 | 
109 |         """
110 |         sess = self.context['sess']
111 |         max_iters = self.context['max_iters']
112 |         self.update_callbacks()
113 |         if self.context.get('global_step') is None:
114 |             step = 0
115 |             global_step_add_one = None
116 |         else:
117 |             step = sess.run(self.context['global_step'])
118 |             global_step_add_one = self.context['global_step_add_one']
119 |         # once_callbacks at train start
120 |         for cb in self._once_callbacks:
121 |             if cb.cb_loc == CallbackLoc.train_start:
122 |                 cb.run(sess)
123 |         try:
124 |             while step < max_iters:
125 |                 # update and get current step
126 |                 step += 1
127 |                 if global_step_add_one is not None:
128 |                     sess.run(global_step_add_one)
129 |                 # periodic callbacks at step start
130 |                 for cb in self._periodic_callbacks:
131 |                     if (cb.cb_loc == CallbackLoc.step_start and
132 |                             step % cb.pstep == 0):
133 |                         cb.run(sess, step)
134 |                 # scheduled callbacks at step start
135 |                 for cb in self._scheduled_callbacks:
136 |                     if (cb.cb_loc == CallbackLoc.step_start and
137 |                             step in cb.schedule):
138 |                         cb.run(sess, step)
139 |                 # run train op
140 |                 _, loss_value = sess.run([self._train_op, self._loss],
141 |                                          feed_dict=self.context['feed_dict'])
142 |                 # exit if the loss becomes NaN
143 |                 assert not np.isnan(loss_value)
144 |                 # log one
145 |                 if self.context['log_progress']:
146 |                     self.progress_logger(step, loss_value)
147 |                 # periodic callbacks at step end
148 |                 for cb in self._periodic_callbacks:
149 |                     if (cb.cb_loc == CallbackLoc.step_end and
150 |                             step % cb.pstep == 0):
151 |                         cb.run(sess, step)
152 |                 # scheduled callbacks at step end
153 |                 for cb in self._scheduled_callbacks:
154 |                     if (cb.cb_loc == CallbackLoc.step_end and
155 |                             step in cb.schedule):
156 |                         cb.run(sess, step)
157 |         except (KeyboardInterrupt, SystemExit):
158 |             logger.info("Training is stopped.")
159 |         except:
160 |             raise
161 |         finally:
162 |             # once_callbacks at exception
163 |             for cb in self._once_callbacks:
164 |                 if cb.cb_loc == CallbackLoc.exception:
165 |                     cb.run(sess)
166 |             # once_callbacks at train end
167 |             for cb in self._once_callbacks:
168 |                 if cb.cb_loc == CallbackLoc.train_end:
169 |                     cb.run(sess)
170 | 
171 |     def progress_logger(self, step, loss):
172 |         """Progress bar for logging.
173 | 
174 |         **Note** all statistics are averaged over epoch.
175 |         """
176 |         # init
177 |         if self._log_stats[1] is None:
178 |             self._log_stats[1] = time.time()
179 |             self._log_stats[0] = loss
180 |             return
181 |         # update statistic
182 |         self._log_stats[0] += loss
183 |         # time
184 |         t_start = self._log_stats[1]
185 |         t_now = time.time()
186 |         spe = self.context['spe']
187 |         # after running the session, the step is actually increased.
188 |         step = step + 1
189 |         epoch_end = (step % spe == 0)
190 |         # log every 0.1% of an epoch, but no more often than every 10 iterations
191 |         log_per_iters = max(int(spe/1000), 10)
192 |         # update progress bar per log_per_iters
193 |         epoch_nums = (step - 1) // spe + 1
194 |         epoch_iters = (step - 1) % spe + 1
195 |         if epoch_iters % log_per_iters == 0 or epoch_end:
196 |             batches_per_sec = epoch_iters / (t_now - t_start)
197 |             texts = ''.join([
198 |                 'train epoch {},'.format(epoch_nums),
199 |                 ' iter {}/{},'.format(epoch_iters, spe),
200 |                 ' loss {:.6f}, {:.2f} batches/sec.'.format(
201 |                     self._log_stats[0]/epoch_iters, batches_per_sec),
202 |             ])
203 |             # progress, if at the end of epoch, 100%; else current progress
204 |             prog = 1 if epoch_end else (step / spe) % 1
205 |             self._bar.progress(prog, texts)
206 |         # reset
207 |         if epoch_end:
208 |             self._log_stats[1] = None
209 |             self._log_stats[0] = 0
210 |         return
211 | 
212 |     def add_callbacks(self, callbacks):
213 |         """Add callbacks.
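        A single callback or a list of callbacks is accepted, e.g.
        ``trainer.add_callbacks(ng.callbacks.WeightsViewer())``.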
214 | 
215 |         Args:
216 |             callbacks: a callback or a list of callbacks.
217 | 
218 |         """
219 |         if not isinstance(callbacks, list):
220 |             callbacks = [callbacks]
221 |         # keep order
222 |         self.callbacks = self.callbacks + callbacks
223 |         # after adding callbacks, update the callbacks list.
224 |         self.update_callbacks()
225 | 
226 |     def update_callbacks(self):
227 |         def _check_type(t, cb):
228 |             return t == cb.__class__ or t in cb.__class__.__bases__
229 | 
230 |         # clear
231 |         self._periodic_callbacks = []
232 |         self._once_callbacks = []
233 |         self._scheduled_callbacks = []
234 |         # add
235 |         for cb in self.callbacks:
236 |             if _check_type(PeriodicCallback, cb):
237 |                 self._periodic_callbacks.append(cb)
238 |             if _check_type(OnceCallback, cb):
239 |                 self._once_callbacks.append(cb)
240 |             if _check_type(ScheduledCallback, cb):
241 |                 self._scheduled_callbacks.append(cb)
242 | 
243 |     def train_ops_and_losses(self):
244 |         optimizer = self.context['optimizer']
245 |         loss = self.context.get('loss')
246 |         var_list = self.context.get('var_list')
247 |         graph_def_kwargs = self.context['graph_def_kwargs']
248 |         gradient_processor = self.context.get('gradient_processor')
249 |         if loss is None:
250 |             loss = self.context['graph_def'](**graph_def_kwargs)
251 |         # get gradients
252 |         grads = optimizer.compute_gradients(loss, var_list)
253 |         if self.context['grads_summary']:
254 |             for grad, var in grads:
255 |                 if grad is not None:
256 |                     tf.summary.histogram('gradients/' + var.name, grad)
257 |         grads = process_gradients(grads, gradient_processor)
258 |         # get operations
259 |         apply_gradient_op = optimizer.apply_gradients(grads)
260 |         return apply_gradient_op, loss
261 | 
--------------------------------------------------------------------------------
/mask_from_fnames.py:
--------------------------------------------------------------------------------
1 | import random
2 | import threading
3 | import logging
4 | import time
5 | import numpy as np
6 | import cv2
7 | import tensorflow as tf
8 | from bs4 import BeautifulSoup
9 | import pickle as pkl
10 | #from pycocotools.coco import COCO
11 | from neuralgym.data import feeding_queue_runner as queue_runner
12 | from neuralgym.data.dataset import Dataset
13 | from neuralgym.ops.image_ops import np_random_crop
14 | 
15 | #from pycocotools.coco import COCO
16 | logger = logging.getLogger()
17 | READER_LOCK = threading.Lock()
18 | 
19 | 
20 | class DataMaskFromFNames(Dataset):
21 |     """Data pipeline from a list of filenames. Reads the filenames and returns masks built from bboxes or segmentations.
22 |     Args:
23 |         fnamelists (list): A list of filenames or tuple of filenames, e.g.
24 |             ['image_001.png', ...] or
25 |             [('pair_image_001_0.png', 'pair_image_001_1.png'), ...].
26 |         shapes (tuple): Shapes of data, e.g. [256, 256, 3] or
27 |             [[256, 256, 3], [1]].
28 |         random (bool): Read from `fnamelists` randomly (default to False).
29 |         random_crop (bool): Whether to random-crop raw images to the shape,
30 |             or directly resize raw images to the shape.
31 |         dtypes (tf.Type): Data types, default to tf.float32.
32 |         enqueue_size (int): Enqueue size for pipeline.
33 |         queue_size (int): Capacity of the underlying FIFO queue.
34 |         nthreads (int): Parallel threads for reading from data.
35 |         return_fnames (bool): If True, data_pipeline will also return fnames
36 |             (last tensor).
37 |         filetype (str): Currently only 'image' is supported.
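        from_bbox (bool): If True, masks come from bounding boxes rather
            than segmentation annotations (note: not currently referenced
            by the implementation below).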
38 |     Examples:
39 |         >>> fnames = ['img001.png', 'img002.png', ..., 'img999.png']
40 |         >>> data = ng.data.DataFromFNames(fnames, [256, 256, 3])
41 |         >>> images = data.data_pipeline(128)
42 |         >>> sess = tf.Session(config=tf.ConfigProto())
43 |         >>> tf.train.start_queue_runners(sess)
44 |         >>> for i in range(5): sess.run(images)
45 |     To get file lists, you can either read a file::
46 |         with open('data/images.flist') as f:
47 |             fnames = f.read().splitlines()
48 |     or glob::
49 |         import glob
50 |         fnames = glob.glob('data/*.png')
51 |     You can also create fnames tuple::
52 |         with open('images.flist') as f:
53 |             image_fnames = f.read().splitlines()
54 |         with open('segmentation_annotation.flist') as f:
55 |             annotation_fnames = f.read().splitlines()
56 |         fnames = list(zip(image_fnames, annotation_fnames))
57 |     """
58 |     # shape = [[256,256,3],[256,256,1]]
59 |     def __init__(self, fnamelists, shapes, random=False, random_crop=False,
60 |                  fn_preprocess=None, dtypes=tf.float32,
61 |                  enqueue_size=32, queue_size=256, nthreads=8,
62 |                  return_fnames=False, from_bbox=True, filetype='image'):
63 |         self.fnamelists_ = self.process_fnamelists(fnamelists)
64 |         #print(self.fnamelists_)
65 |         self.file_length = len(self.fnamelists_)
66 |         self.random = random
67 |         self.random_crop = random_crop
68 |         self.filetype = filetype
69 |         if isinstance(shapes[0], list):
70 |             self.shapes = shapes
71 |         else:
72 |             self.shapes = [shapes] * len(self.fnamelists_[0])
73 |         if isinstance(dtypes, list):
74 |             self.dtypes = dtypes
75 |         else:
76 |             self.dtypes = [dtypes] * len(self.fnamelists_[0])
77 | 
78 |         self.return_fnames = return_fnames
79 |         self.batch_phs = [
80 |             tf.placeholder(dtype, [None] + shape)
81 |             for dtype, shape in zip(self.dtypes, self.shapes)]
82 | 
83 |         if self.return_fnames:
84 |             self.shapes += [[]]
85 |             self.dtypes += [tf.string]
86 |             self.batch_phs.append(tf.placeholder(tf.string, [None]))
87 |         self.enqueue_size = enqueue_size
88 |         self.queue_size = queue_size
89 |         self.nthreads = nthreads
90 |         self.fn_preprocess = fn_preprocess
91 |         if not random:
92 |             self.index = 0
93 |         super().__init__()
94 |         self.create_queue()
95 | 
96 |     def process_fnamelists(self, fnamelist):
97 |         if isinstance(fnamelist, list):
98 |             if isinstance(fnamelist[0], str):
99 |                 return [(i,) for i in fnamelist]
100 |             elif isinstance(fnamelist[0], tuple):
101 |                 return fnamelist
102 |             else:
103 |                 raise ValueError('Type error for fnamelist.')
104 |         else:
105 |             raise ValueError('Type error for fnamelist.')
106 | 
107 |     def data_pipeline(self, batch_size):
108 |         """Batch data pipeline.
109 |         Args:
110 |             batch_size (int): Batch size.
111 |         Returns:
112 |             A list of tensors, each with leading dimension batch_size and the
113 |             corresponding shape from self.shapes, e.g. if self.shapes =
114 |             ([256, 256, 3], [1]), then return [[batch_size, 256, 256, 3], [batch_size, 1]].
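        A minimal sketch for this class (image + annotation pairs; the flist
        names are placeholders, and shapes follow the comment above)::
            data = DataMaskFromFNames(
                list(zip(image_fnames, annotation_fnames)),
                [[256, 256, 3], [256, 256, 1]])
            images, masks = data.data_pipeline(8)  # masks: {0, 1}, 1 = hole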
115 |         """
116 |         data = self._queue.dequeue_many(batch_size)
117 |         return data
118 | 
119 |     def create_queue(self, shared_name=None, name=None):
120 |         from tensorflow.python.ops import data_flow_ops, logging_ops, math_ops
121 |         from tensorflow.python.framework import dtypes
122 |         assert self.dtypes is not None and self.shapes is not None
123 |         assert len(self.dtypes) == len(self.shapes)
124 |         capacity = self.queue_size
125 |         self._queue = data_flow_ops.FIFOQueue(
126 |             capacity=capacity,
127 |             dtypes=self.dtypes,
128 |             shapes=self.shapes,
129 |             shared_name=shared_name,
130 |             name=name)
131 | 
132 |         enq = self._queue.enqueue_many(self.batch_phs)
133 |         # create a queue runner
134 |         queue_runner.add_queue_runner(queue_runner.QueueRunner(
135 |             self._queue, [enq]*self.nthreads,
136 |             feed_dict_op=[lambda: self.next_batch()],
137 |             feed_dict_key=self.batch_phs))
138 |         summary_name = 'fraction_of_%d_full' % capacity
139 |         logging_ops.scalar_summary("queue/%s/%s" % (
140 |             self._queue.name, summary_name), math_ops.cast(
141 |             self._queue.size(), dtypes.float32) * (1. / capacity))
142 | 
143 |     def read_img(self, filename):
144 |         #print(filename)
145 |         img = cv2.imread(filename)
146 | 
147 |         if img is None:
148 |             #logger.info('image is None, sleep this thread for 0.1s.{}'.format(filename))
149 | 
150 |             #time.sleep(0.1)
151 |             return img, True  # signal read failure; next_batch retries with another sample
152 |         if self.fn_preprocess:
153 |             img = self.fn_preprocess(img)
154 |         return img, False
155 |     # CrowdHuman annotations
156 |     def read_ch_bbox(self, path):
157 |         aux_dict = pkl.load(open(path, 'rb'))
158 |         bboxs = aux_dict["bbox"]
159 |         bbox = random.choice(bboxs)
160 |         extra = bbox['extra']
161 |         shape = aux_dict["shape"]
162 |         while 'ignore' in extra and extra['ignore'] == 1 and bbox['fbox'][0] < 0 and bbox['fbox'][1] < 0:
163 |             bbox = random.choice(bboxs)
164 |             extra = bbox['extra']
165 |         fbox = bbox['fbox']
166 |         return [[fbox[1],fbox[0],fbox[3],fbox[2]]], (shape[1], shape[0])
167 | 
168 |     def read_coco_bbox(self, path):
169 |         aux_dict = pkl.load(open(path, 'rb'))
170 |         bbox = aux_dict["bbox"]
171 |         shape = aux_dict["shape"]
172 |         #bbox = random.choice(bbox)
173 |         #fbox = bbox['fbox']
174 |         return [[int(bbox[1]), int(bbox[0]), int(bbox[3]), int(bbox[2])]], (shape[1], shape[0])
175 | 
176 |     def read_bbox_shapes(self, filename):
177 |         if filename[-3:] == 'pkl' and 'Human' in filename:
178 |             return self.read_ch_bbox(filename)
179 |         elif filename[-3:] == 'pkl' and 'COCO' in filename:
180 |             return self.read_coco_bbox(filename)
181 | 
182 |         #file_path = os.path.join(self.path, "Annotations", filename)
183 |         with open(filename, 'r') as reader:
184 |             xml = reader.read()
185 |         soup = BeautifulSoup(xml, 'xml')
186 |         size = {}
187 |         for tag in soup.size:
188 |             if tag.string != "\n":
189 |                 size[tag.name] = int(tag.string)
190 |         objects = soup.find_all('object')
191 |         bndboxs = []
192 |         for obj in objects:
193 |             bndbox = {}
194 |             for tag in obj.bndbox:
195 |                 if tag.string != '\n':
196 |                     bndbox[tag.name] = int(tag.string)
197 | 
198 |             bbox = [bndbox['ymin'], bndbox['xmin'], bndbox['ymax']-bndbox['ymin'], bndbox['xmax']-bndbox['xmin']]
199 |             bndboxs.append(bbox)
200 |         #print(bndboxs, size)
201 |         return bndboxs, (size['height'], size['width'])
202 | 
203 | 
204 |     def bbox2mask(self, bbox, height, width, delta_h, delta_w, name='mask'):
205 |         """Generate a mask array from a bbox.
206 | 
207 |         Args:
208 |             bbox: tuple, (top, left, height, width).
209 |             height, width: spatial size of the mask to generate.
210 |             delta_h, delta_w: unused; margins are derived from the bbox size instead.
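        The hole is the bbox shrunk by a random margin on each side: for a
        bbox of height h, the top/bottom margin is sampled from roughly
        [0.1*h, 0.3*h], and likewise for the width, so e.g. a 100x100 bbox
        keeps a hole between 40x40 and 80x80.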
211 | 
212 |         Returns:
213 |             np.ndarray: mask of shape [height, width, 1], with 1 inside the hole.
214 | 
215 |         """
216 | 
217 |         mask = np.zeros((height, width, 1), np.float32)
218 |         h = int(0.1*bbox[2])+np.random.randint(int(bbox[2]*0.2)+1)  # top/bottom margin in [0.1h, 0.3h]
219 |         w = int(0.1*bbox[3])+np.random.randint(int(bbox[3]*0.2)+1)  # left/right margin in [0.1w, 0.3w]
220 |         mask[bbox[0]+h:bbox[0]+bbox[2]-h,
221 |              bbox[1]+w:bbox[1]+bbox[3]-w, :] = 1.
222 |         return mask
223 | 
224 |     def next_batch(self):
225 |         batch_data = []
226 |         for _ in range(self.enqueue_size):
227 |             error = True
228 |             while error:
229 |                 error = False
230 |                 if self.random:  # note: self.random, not the random module
231 |                     filenames = random.choice(self.fnamelists_)
232 |                 else:
233 |                     with READER_LOCK:
234 |                         filenames = self.fnamelists_[self.index]
235 |                         self.index = (self.index + 1) % self.file_length
236 |                 imgs = []
237 |                 masks = []
238 |                 random_h = None
239 |                 random_w = None
240 |                 #print(list(filenames))
241 |                 for i in range(1):
242 |                     #print(filenames[i])
243 |                     img, error = self.read_img(filenames[0])
244 |                     bboxs, shape = self.read_bbox_shapes(filenames[1])
245 |                     mask = self.bbox2mask(bboxs[0], shape[0], shape[1], 32, 32)
246 | 
247 |                     if self.random_crop:
248 |                         img, random_h_, random_w_ = np_random_crop(
249 |                             img, tuple(self.shapes[i][:-1]),
250 |                             random_h, random_w, align=False)  # use last rand
251 |                         mask, random_h, random_w = np_random_crop(
252 |                             mask, tuple(self.shapes[i][:-1]),
253 |                             random_h, random_w, align=False)  # use last rand
254 |                     else:
255 |                         if img is None or mask is None:
256 |                             continue
257 |                         img = cv2.resize(img, tuple(self.shapes[i][:-1]))
258 |                         mask = cv2.resize(mask, tuple(self.shapes[i][:-1]))
259 |                     #assert not np.isnan(((img>0).astype(np.int8)).reshape(self.shapes[i]))
260 | 
261 |                     mask = ((mask > 0).astype(np.int8)).reshape((self.shapes[i][:-1]+[1,]))
262 |                     if np.max(mask) == 0:
263 |                         print(bboxs[0], shape)  # degenerate mask; retry this sample
264 |                         error = True
265 |                     else:
266 |                         imgs.append(img)
267 |                         masks.append(mask)
268 | 
269 |                 if self.return_fnames:
270 |                     batch_data.append(imgs + masks + list(filenames))
271 |                 else:
272 |                     batch_data.append(imgs + masks)
273 |         return zip(*batch_data)
274 | 
275 |     def _maybe_download_and_extract(self):
276 |         pass
277 | 
--------------------------------------------------------------------------------
/inpaint_model.py:
--------------------------------------------------------------------------------
1 | """ common model for DCGAN """
2 | import logging
3 | 
4 | import cv2
5 | import neuralgym as ng
6 | import tensorflow as tf
7 | from tensorflow.contrib.framework.python.ops import arg_scope
8 | 
9 | from neuralgym.models import Model
10 | from neuralgym.ops.summary_ops import scalar_summary, images_summary
11 | from neuralgym.ops.summary_ops import gradients_summary
12 | from neuralgym.ops.layers import flatten, resize
13 | from neuralgym.ops.gan_ops import gan_wgan_loss, gradients_penalty
14 | from neuralgym.ops.gan_ops import random_interpolates
15 | 
16 | from inpaint_ops import gen_conv, gen_deconv, dis_conv
17 | from inpaint_ops import random_bbox, bbox2mask, local_patch
18 | from inpaint_ops import spatial_discounting_mask
19 | from inpaint_ops import resize_mask_like, contextual_attention
20 | 
21 | 
22 | logger = logging.getLogger()
23 | 
24 | 
25 | class InpaintCAModel(Model):
26 |     def __init__(self):
27 |         super().__init__('InpaintCAModel')
28 | 
29 |     def build_inpaint_net(self, x, mask, config=None, reuse=False,
30 |                           training=True, padding='SAME', name='inpaint_net'):
31 |         """Inpaint network.
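        Two-stage coarse-to-fine architecture: stage 1 fills the hole
        coarsely; stage 2 refines it with a dilated-convolution branch and a
        parallel contextual-attention branch, concatenated before the decoder.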
32 | 33 | Args: 34 | x: incomplete image, [-1, 1] 35 | mask: mask region {0, 1} 36 | Returns: 37 | [-1, 1] as predicted image 38 | """ 39 | xin = x 40 | offset_flow = None 41 | ones_x = tf.ones_like(x)[:, :, :, 0:1] 42 | x = tf.concat([x, ones_x, ones_x*mask], axis=3) 43 | 44 | # two stage network 45 | cnum = 32 46 | with tf.variable_scope(name, reuse=reuse), \ 47 | arg_scope([gen_conv, gen_deconv], 48 | training=training, padding=padding): 49 | # stage1 50 | x = gen_conv(x, cnum, 5, 1, name='conv1') 51 | x = gen_conv(x, 2*cnum, 3, 2, name='conv2_downsample') 52 | x = gen_conv(x, 2*cnum, 3, 1, name='conv3') 53 | x = gen_conv(x, 4*cnum, 3, 2, name='conv4_downsample') 54 | x = gen_conv(x, 4*cnum, 3, 1, name='conv5') 55 | x = gen_conv(x, 4*cnum, 3, 1, name='conv6') 56 | mask_s = resize_mask_like(mask, x) 57 | x = gen_conv(x, 4*cnum, 3, rate=2, name='conv7_atrous') 58 | x = gen_conv(x, 4*cnum, 3, rate=4, name='conv8_atrous') 59 | x = gen_conv(x, 4*cnum, 3, rate=8, name='conv9_atrous') 60 | x = gen_conv(x, 4*cnum, 3, rate=16, name='conv10_atrous') 61 | x = gen_conv(x, 4*cnum, 3, 1, name='conv11') 62 | x = gen_conv(x, 4*cnum, 3, 1, name='conv12') 63 | x = gen_deconv(x, 2*cnum, name='conv13_upsample') 64 | x = gen_conv(x, 2*cnum, 3, 1, name='conv14') 65 | x = gen_deconv(x, cnum, name='conv15_upsample') 66 | x = gen_conv(x, cnum//2, 3, 1, name='conv16') 67 | x = gen_conv(x, 3, 3, 1, activation=None, name='conv17') 68 | x = tf.clip_by_value(x, -1., 1.) 69 | x_stage1 = x 70 | # return x_stage1, None, None 71 | 72 | # stage2, paste result as input 73 | # x = tf.stop_gradient(x) 74 | x = x*mask + xin*(1.-mask) 75 | x.set_shape(xin.get_shape().as_list()) 76 | # conv branch 77 | xnow = tf.concat([x, ones_x, ones_x*mask], axis=3) 78 | x = gen_conv(xnow, cnum, 5, 1, name='xconv1') 79 | x = gen_conv(x, cnum, 3, 2, name='xconv2_downsample') 80 | x = gen_conv(x, 2*cnum, 3, 1, name='xconv3') 81 | x = gen_conv(x, 2*cnum, 3, 2, name='xconv4_downsample') 82 | x = gen_conv(x, 4*cnum, 3, 1, name='xconv5') 83 | x = gen_conv(x, 4*cnum, 3, 1, name='xconv6') 84 | x = gen_conv(x, 4*cnum, 3, rate=2, name='xconv7_atrous') 85 | x = gen_conv(x, 4*cnum, 3, rate=4, name='xconv8_atrous') 86 | x = gen_conv(x, 4*cnum, 3, rate=8, name='xconv9_atrous') 87 | x = gen_conv(x, 4*cnum, 3, rate=16, name='xconv10_atrous') 88 | x_hallu = x 89 | # attention branch 90 | x = gen_conv(xnow, cnum, 5, 1, name='pmconv1') 91 | x = gen_conv(x, cnum, 3, 2, name='pmconv2_downsample') 92 | x = gen_conv(x, 2*cnum, 3, 1, name='pmconv3') 93 | x = gen_conv(x, 4*cnum, 3, 2, name='pmconv4_downsample') 94 | x = gen_conv(x, 4*cnum, 3, 1, name='pmconv5') 95 | x = gen_conv(x, 4*cnum, 3, 1, name='pmconv6', 96 | activation=tf.nn.relu) 97 | x, offset_flow = contextual_attention(x, x, mask_s, 3, 1, rate=2) 98 | x = gen_conv(x, 4*cnum, 3, 1, name='pmconv9') 99 | x = gen_conv(x, 4*cnum, 3, 1, name='pmconv10') 100 | pm = x 101 | x = tf.concat([x_hallu, pm], axis=3) 102 | 103 | x = gen_conv(x, 4*cnum, 3, 1, name='allconv11') 104 | x = gen_conv(x, 4*cnum, 3, 1, name='allconv12') 105 | x = gen_deconv(x, 2*cnum, name='allconv13_upsample') 106 | x = gen_conv(x, 2*cnum, 3, 1, name='allconv14') 107 | x = gen_deconv(x, cnum, name='allconv15_upsample') 108 | x = gen_conv(x, cnum//2, 3, 1, name='allconv16') 109 | x = gen_conv(x, 3, 3, 1, activation=None, name='allconv17') 110 | x_stage2 = tf.clip_by_value(x, -1., 1.) 
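        # a sketch of the outputs at this point (assuming H x W inputs):
        #   x_stage1, x_stage2: [N, H, W, 3] coarse / refined predictions in [-1, 1]
        #   offset_flow: attention-offset visualization from contextual_attention;
        #                the summary code later upsamples it 4x for display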
111 | return x_stage1, x_stage2, offset_flow 112 | 113 | def build_wgan_local_discriminator(self, x, reuse=False, training=True): 114 | with tf.variable_scope('discriminator_local', reuse=reuse): 115 | cnum = 64 116 | x = dis_conv(x, cnum, name='conv1', training=training) 117 | x = dis_conv(x, cnum*2, name='conv2', training=training) 118 | x = dis_conv(x, cnum*4, name='conv3', training=training) 119 | x = dis_conv(x, cnum*8, name='conv4', training=training) 120 | x = flatten(x, name='flatten') 121 | return x 122 | 123 | def build_wgan_global_discriminator(self, x, reuse=False, training=True): 124 | with tf.variable_scope('discriminator_global', reuse=reuse): 125 | cnum = 64 126 | x = dis_conv(x, cnum, name='conv1', training=training) 127 | x = dis_conv(x, cnum*2, name='conv2', training=training) 128 | x = dis_conv(x, cnum*4, name='conv3', training=training) 129 | x = dis_conv(x, cnum*4, name='conv4', training=training) 130 | x = flatten(x, name='flatten') 131 | return x 132 | 133 | def build_wgan_discriminator(self, batch_local, batch_global, 134 | reuse=False, training=True): 135 | with tf.variable_scope('discriminator', reuse=reuse): 136 | dlocal = self.build_wgan_local_discriminator( 137 | batch_local, reuse=reuse, training=training) 138 | dglobal = self.build_wgan_global_discriminator( 139 | batch_global, reuse=reuse, training=training) 140 | dout_local = tf.layers.dense(dlocal, 1, name='dout_local_fc') 141 | dout_global = tf.layers.dense(dglobal, 1, name='dout_global_fc') 142 | return dout_local, dout_global 143 | 144 | def build_graph_with_losses(self, batch_data, config, training=True, 145 | summary=False, reuse=False): 146 | batch_pos = batch_data / 127.5 - 1. 147 | # generate mask, 1 represents masked point 148 | bbox = random_bbox(config) 149 | mask = bbox2mask(bbox, config, name='mask_c') 150 | batch_incomplete = batch_pos*(1.-mask) 151 | x1, x2, offset_flow = self.build_inpaint_net( 152 | batch_incomplete, mask, config, reuse=reuse, training=training, 153 | padding=config.PADDING) 154 | if config.PRETRAIN_COARSE_NETWORK: 155 | batch_predicted = x1 156 | logger.info('Set batch_predicted to x1.') 157 | else: 158 | batch_predicted = x2 159 | logger.info('Set batch_predicted to x2.') 160 | losses = {} 161 | # apply mask and complete image 162 | batch_complete = batch_predicted*mask + batch_incomplete*(1.-mask) 163 | # local patches 164 | local_patch_batch_pos = local_patch(batch_pos, bbox) 165 | local_patch_batch_predicted = local_patch(batch_predicted, bbox) 166 | local_patch_x1 = local_patch(x1, bbox) 167 | local_patch_x2 = local_patch(x2, bbox) 168 | local_patch_batch_complete = local_patch(batch_complete, bbox) 169 | local_patch_mask = local_patch(mask, bbox) 170 | 171 | l1_alpha = config.COARSE_L1_ALPHA 172 | losses['l1_loss'] = l1_alpha * tf.reduce_mean(tf.abs(local_patch_batch_pos - local_patch_x1)*spatial_discounting_mask(config)) 173 | if not config.PRETRAIN_COARSE_NETWORK: 174 | losses['l1_loss'] += tf.reduce_mean(tf.abs(local_patch_batch_pos - local_patch_x2)*spatial_discounting_mask(config)) 175 | losses['ae_loss'] = l1_alpha * tf.reduce_mean(tf.abs(batch_pos - x1) * (1.-mask)) 176 | if not config.PRETRAIN_COARSE_NETWORK: 177 | losses['ae_loss'] += tf.reduce_mean(tf.abs(batch_pos - x2) * (1.-mask)) 178 | losses['ae_loss'] /= tf.reduce_mean(1.-mask) 179 | if summary: 180 | scalar_summary('losses/l1_loss', losses['l1_loss']) 181 | scalar_summary('losses/ae_loss', losses['ae_loss']) 182 | viz_img = [batch_pos, batch_incomplete, batch_complete] 183 | if offset_flow is not 
None:
184 |             viz_img.append(
185 |                 resize(offset_flow, scale=4,
186 |                        func=tf.image.resize_nearest_neighbor))
187 |             images_summary(
188 |                 tf.concat(viz_img, axis=2),
189 |                 'raw_incomplete_predicted_complete', config.VIZ_MAX_OUT)
190 | 
191 |         # gan
192 |         batch_pos_neg = tf.concat([batch_pos, batch_complete], axis=0)
193 | 
194 |         # local deterministic patch
195 |         local_patch_batch_pos_neg = tf.concat([local_patch_batch_pos, local_patch_batch_complete], 0)
196 |         if config.GAN_WITH_MASK:
197 |             batch_pos_neg = tf.concat([batch_pos_neg, tf.tile(mask, [config.BATCH_SIZE*2, 1, 1, 1])], axis=3)
198 |         # wgan with gradient penalty
199 |         if config.GAN == 'wgan_gp':
200 |             # separate gan
201 |             pos_neg_local, pos_neg_global = self.build_wgan_discriminator(local_patch_batch_pos_neg, batch_pos_neg, training=training, reuse=reuse)
202 |             pos_local, neg_local = tf.split(pos_neg_local, 2)
203 |             pos_global, neg_global = tf.split(pos_neg_global, 2)
204 |             # wgan loss
205 |             g_loss_local, d_loss_local = gan_wgan_loss(pos_local, neg_local, name='gan/local_gan')
206 |             g_loss_global, d_loss_global = gan_wgan_loss(pos_global, neg_global, name='gan/global_gan')
207 |             losses['g_loss'] = config.GLOBAL_WGAN_LOSS_ALPHA * g_loss_global + g_loss_local
208 |             losses['d_loss'] = d_loss_global + d_loss_local
209 |             # gp
210 |             interpolates_local = random_interpolates(local_patch_batch_pos, local_patch_batch_complete)
211 |             interpolates_global = random_interpolates(batch_pos, batch_complete)
212 |             dout_local, dout_global = self.build_wgan_discriminator(
213 |                 interpolates_local, interpolates_global, reuse=True)
214 |             # apply penalty
215 |             penalty_local = gradients_penalty(interpolates_local, dout_local, mask=local_patch_mask)
216 |             penalty_global = gradients_penalty(interpolates_global, dout_global, mask=mask)
217 |             losses['gp_loss'] = config.WGAN_GP_LAMBDA * (penalty_local + penalty_global)
218 |             losses['d_loss'] = losses['d_loss'] + losses['gp_loss']
219 |             if summary and not config.PRETRAIN_COARSE_NETWORK:
220 |                 gradients_summary(g_loss_local, batch_predicted, name='g_loss_local')
221 |                 gradients_summary(g_loss_global, batch_predicted, name='g_loss_global')
222 |                 scalar_summary('convergence/d_loss', losses['d_loss'])
223 |                 scalar_summary('convergence/local_d_loss', d_loss_local)
224 |                 scalar_summary('convergence/global_d_loss', d_loss_global)
225 |                 scalar_summary('gan_wgan_loss/gp_loss', losses['gp_loss'])
226 |                 scalar_summary('gan_wgan_loss/gp_penalty_local', penalty_local)
227 |                 scalar_summary('gan_wgan_loss/gp_penalty_global', penalty_global)
228 | 
229 |         if summary and not config.PRETRAIN_COARSE_NETWORK:
230 |             # summarize the magnitude of gradients from different losses w.r.t.
predicted image 231 | gradients_summary(losses['g_loss'], batch_predicted, name='g_loss') 232 | gradients_summary(losses['g_loss'], x1, name='g_loss_to_x1') 233 | gradients_summary(losses['g_loss'], x2, name='g_loss_to_x2') 234 | gradients_summary(losses['l1_loss'], x1, name='l1_loss_to_x1') 235 | gradients_summary(losses['l1_loss'], x2, name='l1_loss_to_x2') 236 | gradients_summary(losses['ae_loss'], x1, name='ae_loss_to_x1') 237 | gradients_summary(losses['ae_loss'], x2, name='ae_loss_to_x2') 238 | if config.PRETRAIN_COARSE_NETWORK: 239 | losses['g_loss'] = 0 240 | else: 241 | losses['g_loss'] = config.GAN_LOSS_ALPHA * losses['g_loss'] 242 | losses['g_loss'] += config.L1_LOSS_ALPHA * losses['l1_loss'] 243 | logger.info('Set L1_LOSS_ALPHA to %f' % config.L1_LOSS_ALPHA) 244 | logger.info('Set GAN_LOSS_ALPHA to %f' % config.GAN_LOSS_ALPHA) 245 | if config.AE_LOSS: 246 | losses['g_loss'] += config.AE_LOSS_ALPHA * losses['ae_loss'] 247 | logger.info('Set AE_LOSS_ALPHA to %f' % config.AE_LOSS_ALPHA) 248 | g_vars = tf.get_collection( 249 | tf.GraphKeys.TRAINABLE_VARIABLES, 'inpaint_net') 250 | d_vars = tf.get_collection( 251 | tf.GraphKeys.TRAINABLE_VARIABLES, 'discriminator') 252 | return g_vars, d_vars, losses 253 | 254 | def build_infer_graph(self, batch_data, config, bbox=None, name='val'): 255 | """ 256 | """ 257 | config.MAX_DELTA_HEIGHT = 0 258 | config.MAX_DELTA_WIDTH = 0 259 | if bbox is None: 260 | bbox = random_bbox(config) 261 | mask = bbox2mask(bbox, config, name=name+'mask_c') 262 | batch_pos = batch_data / 127.5 - 1. 263 | edges = None 264 | batch_incomplete = batch_pos*(1.-mask) 265 | # inpaint 266 | x1, x2, offset_flow = self.build_inpaint_net( 267 | batch_incomplete, mask, config, reuse=True, 268 | training=False, padding=config.PADDING) 269 | if config.PRETRAIN_COARSE_NETWORK: 270 | batch_predicted = x1 271 | logger.info('Set batch_predicted to x1.') 272 | else: 273 | batch_predicted = x2 274 | logger.info('Set batch_predicted to x2.') 275 | # apply mask and reconstruct 276 | batch_complete = batch_predicted*mask + batch_incomplete*(1.-mask) 277 | # global image visualization 278 | viz_img = [batch_pos, batch_incomplete, batch_complete] 279 | if offset_flow is not None: 280 | viz_img.append( 281 | resize(offset_flow, scale=4, 282 | func=tf.image.resize_nearest_neighbor)) 283 | images_summary( 284 | tf.concat(viz_img, axis=2), 285 | name+'_raw_incomplete_complete', config.VIZ_MAX_OUT) 286 | return batch_complete 287 | 288 | def build_static_infer_graph(self, batch_data, config, name): 289 | """ 290 | """ 291 | # generate mask, 1 represents masked point 292 | bbox = (tf.constant(config.HEIGHT//2), tf.constant(config.WIDTH//2), 293 | tf.constant(config.HEIGHT), tf.constant(config.WIDTH)) 294 | return self.build_infer_graph(batch_data, config, bbox, name) 295 | 296 | 297 | def build_server_graph(self, batch_data, reuse=False, is_training=False): 298 | """ 299 | """ 300 | # generate mask, 1 represents masked point 301 | batch_raw, masks_raw = tf.split(batch_data, 2, axis=2) 302 | masks = tf.cast(masks_raw[0:1, :, :, 0:1] > 127.5, tf.float32) 303 | 304 | batch_pos = batch_raw / 127.5 - 1. 305 | batch_incomplete = batch_pos * (1. 
- masks)
306 |         # inpaint
307 |         x1, x2, flow = self.build_inpaint_net(
308 |             batch_incomplete, masks, reuse=reuse, training=is_training,
309 |             config=None)
310 |         batch_predict = x2
311 |         # apply mask and reconstruct
312 |         batch_complete = batch_predict*masks + batch_incomplete*(1-masks)
313 |         return batch_complete
314 | 
--------------------------------------------------------------------------------
/inpaint_model_gc.py:
--------------------------------------------------------------------------------
1 | """ Reimplemented Inpainting Gated Convolution Model """
2 | import logging
3 | 
4 | import cv2
5 | import neuralgym as ng
6 | import tensorflow as tf
7 | from tensorflow.contrib.framework.python.ops import arg_scope
8 | from neuralgym.models import Model
9 | from neuralgym.ops.summary_ops import scalar_summary, images_summary
10 | from neuralgym.ops.summary_ops import gradients_summary
11 | from neuralgym.ops.layers import flatten, resize
12 | from neuralgym.ops.gan_ops import gan_wgan_loss, gradients_penalty
13 | from neuralgym.ops.gan_ops import random_interpolates
14 | 
15 | from inpaint_ops import gated_conv, gated_deconv, gen_conv, dis_conv, gen_snconv, gen_deconv
16 | from inpaint_ops import random_ff_mask, mask_patch, mask_from_bbox_voc, mask_from_seg_voc
17 | from inpaint_ops import resize_mask_like, contextual_attention
18 | 
19 | 
20 | logger = logging.getLogger()
21 | 
22 | def gan_sn_pgan_loss(pos, neg, name='gan_sn_pgan_loss'):
23 |     """
24 |     SN-PatchGAN hinge loss for GANs.
25 | 
26 |     - Free-Form Image Inpainting with Gated Convolution: https://arxiv.org/abs/1806.03589
27 |     """
28 |     with tf.variable_scope(name):
29 |         d_loss = tf.reduce_mean(tf.nn.relu(1-pos)) + tf.reduce_mean(tf.nn.relu(1+neg))
30 |         g_loss = -tf.reduce_mean(neg)
31 |         scalar_summary('d_loss', d_loss)
32 |         scalar_summary('g_loss', g_loss)
33 |         scalar_summary('pos_value_avg', tf.reduce_mean(pos))
34 |         scalar_summary('neg_value_avg', tf.reduce_mean(neg))
35 |     return g_loss, d_loss
36 | 
37 | class InpaintGCModel(Model):
38 |     def __init__(self):
39 |         super().__init__('InpaintGCModel')
40 | 
41 |     def build_inpaint_net(self, x, mask, guide, config=None, reuse=False,
42 |                           training=True, padding='SAME', name='inpaint_net'):
43 |         """Inpaint network.
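        Same two-stage layout as InpaintCAModel.build_inpaint_net above, with
        plain convolutions swapped for gated ones and an additional guide
        input (its concat into the input stack is commented out below).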
44 | 45 | Args: 46 | x: incomplete image, [-1, 1] 47 | mask: mask region {0, 1} 48 | Returns: 49 | [-1, 1] as predicted image 50 | """ 51 | xin = x 52 | offset_flow = None 53 | ones_x = tf.ones_like(x)[:, :, :, 0:1] 54 | # 55 | #mask = ones_x*mask 56 | x = tf.concat([x, ones_x*mask, ones_x], axis=3) 57 | #x = tf.concat([x, ones_x*guide], axis=3) 58 | # two stage network 59 | cnum = 32 60 | with tf.variable_scope(name, reuse=reuse), \ 61 | arg_scope([gated_conv, gated_deconv], 62 | training=training, padding=padding): 63 | # stage1 64 | x = gated_conv(x, cnum, 5, 1, name='conv1') 65 | x = gated_conv(x, 2*cnum, 3, 2, name='conv2_downsample') 66 | x = gated_conv(x, 2*cnum, 3, 1, name='conv3') 67 | x = gated_conv(x, 4*cnum, 3, 2, name='conv4_downsample') 68 | x = gated_conv(x, 4*cnum, 3, 1, name='conv5') 69 | x = gated_conv(x, 4*cnum, 3, 1, name='conv6') 70 | mask_s = resize_mask_like(mask, x) 71 | #print(mask_s.shape) 72 | x = gated_conv(x, 4*cnum, 3, rate=2, name='conv7_atrous') 73 | x = gated_conv(x, 4*cnum, 3, rate=4, name='conv8_atrous') 74 | x = gated_conv(x, 4*cnum, 3, rate=8, name='conv9_atrous') 75 | x = gated_conv(x, 4*cnum, 3, rate=16, name='conv10_atrous') 76 | x = gated_conv(x, 4*cnum, 3, 1, name='conv11') 77 | x = gated_conv(x, 4*cnum, 3, 1, name='conv12') 78 | x = gated_deconv(x, 2*cnum, name='conv13_upsample') 79 | x = gated_conv(x, 2*cnum, 3, 1, name='conv14') 80 | x = gated_deconv(x, cnum, name='conv15_upsample') 81 | x = gated_conv(x, cnum//2, 3, 1, name='conv16') 82 | x = gated_conv(x, 3, 3, 1, activation=None, name='conv17') 83 | x = tf.clip_by_value(x, -1., 1.) 84 | x_stage1 = x 85 | # return x_stage1, None, None 86 | 87 | # stage2, paste result as input 88 | # x = tf.stop_gradient(x) 89 | x = x*mask + xin*(1.-mask) 90 | x.set_shape(xin.get_shape().as_list()) 91 | # conv branch 92 | xnow = tf.concat([x, ones_x, ones_x*mask], axis=3) 93 | x = gated_conv(xnow, cnum, 5, 1, name='xconv1') 94 | x = gated_conv(x, cnum, 3, 2, name='xconv2_downsample') 95 | x = gated_conv(x, 2*cnum, 3, 1, name='xconv3') 96 | x = gated_conv(x, 2*cnum, 3, 2, name='xconv4_downsample') 97 | x = gated_conv(x, 4*cnum, 3, 1, name='xconv5') 98 | x = gated_conv(x, 4*cnum, 3, 1, name='xconv6') 99 | x = gated_conv(x, 4*cnum, 3, rate=2, name='xconv7_atrous') 100 | x = gated_conv(x, 4*cnum, 3, rate=4, name='xconv8_atrous') 101 | x = gated_conv(x, 4*cnum, 3, rate=8, name='xconv9_atrous') 102 | x = gated_conv(x, 4*cnum, 3, rate=16, name='xconv10_atrous') 103 | x_hallu = x 104 | # attention branch 105 | x = gated_conv(xnow, cnum, 5, 1, name='pmconv1') 106 | x = gated_conv(x, cnum, 3, 2, name='pmconv2_downsample') 107 | x = gated_conv(x, 2*cnum, 3, 1, name='pmconv3') 108 | x = gated_conv(x, 4*cnum, 3, 2, name='pmconv4_downsample') 109 | x = gated_conv(x, 4*cnum, 3, 1, name='pmconv5') 110 | x = gated_conv(x, 4*cnum, 3, 1, name='pmconv6', 111 | activation=tf.nn.relu) 112 | # self.test_m_shape = tf.shape(mask_s) 113 | # self.test_x_shape = tf.shape([x) 114 | x, offset_flow = contextual_attention(x, x, mask_s, 3, 1, rate=2) 115 | x = gated_conv(x, 4*cnum, 3, 1, name='pmconv9') 116 | x = gated_conv(x, 4*cnum, 3, 1, name='pmconv10') 117 | pm = x 118 | x = tf.concat([x_hallu, pm], axis=3) 119 | 120 | x = gated_conv(x, 4*cnum, 3, 1, name='allconv11') 121 | x = gated_conv(x, 4*cnum, 3, 1, name='allconv12') 122 | x = gated_deconv(x, 2*cnum, name='allconv13_upsample') 123 | x = gated_conv(x, 2*cnum, 3, 1, name='allconv14') 124 | x = gated_deconv(x, cnum, name='allconv15_upsample') 125 | x = gated_conv(x, cnum//2, 3, 1, 
name='allconv16')
126 |             x = gated_conv(x, 3, 3, 1, activation=None, name='allconv17')
127 |             x_stage2 = tf.clip_by_value(x, -1., 1.)
128 |         return x_stage1, x_stage2, offset_flow
129 | 
130 | 
131 |     def build_sn_pgan_discriminator(self, x, reuse=False, training=True):
132 |         with tf.variable_scope('discriminator', reuse=reuse):
133 |             cnum = 64
134 |             x = gen_snconv(x, cnum, 5, 2, name='conv1', training=training)
135 |             x = gen_snconv(x, cnum*2, 5, 2, name='conv2', training=training)
136 |             x = gen_snconv(x, cnum*4, 5, 2, name='conv3', training=training)
137 |             x = gen_snconv(x, cnum*4, 5, 2, name='conv4', training=training)
138 |             x = gen_snconv(x, cnum*4, 5, 2, name='conv5', training=training)
139 |             x = gen_snconv(x, cnum*4, 5, 2, name='conv6', training=training)
140 |             return x
141 | 
142 | 
143 | 
144 |     def build_graph_with_losses(self, batch_data, batch_mask, batch_guide, config, training=True,
145 |                                 summary=False, reuse=False):
146 |         batch_pos = batch_data / 127.5 - 1.
147 |         # generate mask, 1 represents masked point
148 |         #print(batch_data, batch_mask)
149 |         if batch_mask is None:
150 |             batch_mask = random_ff_mask(config)
151 |         else:
152 |             pass
153 |             #batch_mask = tf.reshape(batch_mask[0], [1, *batch_mask.get_shape().as_list()[1:]])
154 |             #print(batch_mask.shape)
155 | 
156 |         batch_incomplete = batch_pos*(1.-batch_mask)
157 |         ones_x = tf.ones_like(batch_mask)[:, :, :, 0:1]
158 |         batch_mask = ones_x*batch_mask
159 |         batch_guide = ones_x
160 |         x1, x2, offset_flow = self.build_inpaint_net(
161 |             batch_incomplete, batch_mask, batch_mask, config, reuse=reuse, training=training, padding=config.PADDING)  # mask is also passed as the (unused) guide
162 |         if config.PRETRAIN_COARSE_NETWORK:
163 |             batch_predicted = x1
164 |             logger.info('Set batch_predicted to x1.')
165 |         else:
166 |             batch_predicted = x2
167 |             logger.info('Set batch_predicted to x2.')
168 |         losses = {}
169 |         # apply mask and complete image
170 |         batch_complete = batch_predicted*batch_mask + batch_incomplete*(1.-batch_mask)
171 | 
172 |         # local patches
173 |         local_patch_batch_pos = mask_patch(batch_pos, batch_mask)
174 |         local_patch_batch_predicted = mask_patch(batch_predicted, batch_mask)
175 |         local_patch_x1 = mask_patch(x1, batch_mask)
176 |         local_patch_x2 = mask_patch(x2, batch_mask)
177 |         local_patch_batch_complete = mask_patch(batch_complete, batch_mask)
178 |         #local_patch_mask = mask_patch(mask, bbox)
179 | 
180 |         # l1 loss on the hole and its surroundings, similar to the partial convolution loss
181 |         l1_alpha = config.COARSE_L1_ALPHA
182 |         losses['l1_loss'] = l1_alpha * tf.reduce_mean(tf.abs(local_patch_batch_pos - local_patch_x1))  # *spatial_discounting_mask(config))
183 |         if not config.PRETRAIN_COARSE_NETWORK:
184 |             losses['l1_loss'] += tf.reduce_mean(tf.abs(local_patch_batch_pos - local_patch_x2))  # *spatial_discounting_mask(config))
185 |         losses['ae_loss'] = l1_alpha * tf.reduce_mean(tf.abs(batch_pos - x1) * (1.-batch_mask))
186 |         if not config.PRETRAIN_COARSE_NETWORK:
187 |             losses['ae_loss'] += tf.reduce_mean(tf.abs(batch_pos - x2) * (1.-batch_mask))
188 |         losses['ae_loss'] /= tf.reduce_mean(1.-batch_mask)
189 | 
190 |         if summary:
191 |             scalar_summary('losses/l1_loss', losses['l1_loss'])
192 |             scalar_summary('losses/ae_loss', losses['ae_loss'])
193 |             viz_img = [batch_pos, batch_incomplete, batch_complete]
194 |             if offset_flow is not None:
195 |                 viz_img.append(
196 |                     resize(offset_flow, scale=4,
197 |                            func=tf.image.resize_nearest_neighbor))
198 |             images_summary(
199 |                 tf.concat(viz_img, axis=2),
200 |                 'raw_incomplete_predicted_complete', config.VIZ_MAX_OUT)
201 | 
202 |         # gan
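        # real (batch_pos) and completed (batch_complete) images are stacked
        # along the batch axis so a single discriminator pass scores both;
        # tf.split(pos_neg, 2) below recovers the real/fake halves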
203 |         batch_pos_neg = tf.concat([batch_pos, batch_complete], axis=0)
204 |         if config.MASKFROMFILE:
205 |             batch_mask_all = tf.tile(batch_mask, [2, 1, 1, 1])
206 |             #batch_mask = tf.tile(batch_mask, [config.BATCH_SIZE, 1, 1, 1])
207 |         else:
208 |             batch_mask_all = tf.tile(batch_mask, [config.BATCH_SIZE*2, 1, 1, 1])
209 |             batch_mask = tf.tile(batch_mask, [config.BATCH_SIZE, 1, 1, 1])
210 |         if config.GAN_WITH_MASK:
211 |             batch_pos_neg = tf.concat([batch_pos_neg, batch_mask_all], axis=3)
212 | 
213 |         if config.GAN_WITH_GUIDE:
214 |             batch_pos_neg = tf.concat([batch_pos_neg, tf.tile(batch_guide, [config.BATCH_SIZE*2, 1, 1, 1])], axis=3)
215 |         #batch_pos_, batch_complete_ = tf.split(axis, value, num_split, name=None)
216 |         # sn-pgan with gradient penalty
217 |         if config.GAN == 'sn_pgan':
218 |             # sn patch gan
219 |             pos_neg = self.build_sn_pgan_discriminator(batch_pos_neg, training=training, reuse=reuse)
220 |             pos_global, neg_global = tf.split(pos_neg, 2)
221 | 
222 |             # hinge loss (see gan_sn_pgan_loss above)
223 |             #g_loss_local, d_loss_local = gan_wgan_loss(pos_local, neg_local, name='gan/local_gan')
224 |             g_loss_global, d_loss_global = gan_sn_pgan_loss(pos_global, neg_global, name='gan/global_gan')
225 |             losses['g_loss'] = config.GLOBAL_WGAN_LOSS_ALPHA * g_loss_global
226 |             losses['d_loss'] = d_loss_global
227 |             # gp
228 | 
229 |             # random interpolate between real and completed samples
230 | 
231 |             interpolates_global = random_interpolates(
232 |                 tf.concat([batch_pos, batch_mask], axis=3),
233 |                 tf.concat([batch_complete, batch_mask], axis=3))
234 |             dout_global = self.build_sn_pgan_discriminator(interpolates_global, reuse=True)
235 | 
236 |             # apply penalty
237 |             penalty_global = gradients_penalty(interpolates_global, dout_global, mask=batch_mask)
238 |             losses['gp_loss'] = config.WGAN_GP_LAMBDA * penalty_global
239 |             #losses['d_loss'] = losses['d_loss'] + losses['gp_loss']  # gp term is summarized but not added to d_loss here
240 | 
241 |             if summary and not config.PRETRAIN_COARSE_NETWORK:
242 |                 gradients_summary(g_loss_global, batch_predicted, name='g_loss_global')
243 |                 scalar_summary('convergence/d_loss', losses['d_loss'])
244 |                 scalar_summary('convergence/global_d_loss', d_loss_global)
245 |                 scalar_summary('gan_sn_pgan_loss/gp_loss', losses['gp_loss'])
246 |                 scalar_summary('gan_sn_pgan_loss/gp_penalty_global', penalty_global)
247 | 
248 |         if summary and not config.PRETRAIN_COARSE_NETWORK:
249 |             # summarize the magnitude of gradients from different losses w.r.t.
predicted image 250 | gradients_summary(losses['g_loss'], batch_predicted, name='g_loss') 251 | gradients_summary(losses['g_loss'], x1, name='g_loss_to_x1') 252 | gradients_summary(losses['g_loss'], x2, name='g_loss_to_x2') 253 | gradients_summary(losses['l1_loss'], x1, name='l1_loss_to_x1') 254 | gradients_summary(losses['l1_loss'], x2, name='l1_loss_to_x2') 255 | gradients_summary(losses['ae_loss'], x1, name='ae_loss_to_x1') 256 | gradients_summary(losses['ae_loss'], x2, name='ae_loss_to_x2') 257 | if config.PRETRAIN_COARSE_NETWORK: 258 | losses['g_loss'] = 0 259 | else: 260 | losses['g_loss'] = config.GAN_LOSS_ALPHA * losses['g_loss'] 261 | losses['g_loss'] += config.L1_LOSS_ALPHA * losses['l1_loss'] 262 | logger.info('Set L1_LOSS_ALPHA to %f' % config.L1_LOSS_ALPHA) 263 | logger.info('Set GAN_LOSS_ALPHA to %f' % config.GAN_LOSS_ALPHA) 264 | if config.AE_LOSS: 265 | losses['g_loss'] += config.AE_LOSS_ALPHA * losses['ae_loss'] 266 | logger.info('Set AE_LOSS_ALPHA to %f' % config.AE_LOSS_ALPHA) 267 | g_vars = tf.get_collection( 268 | tf.GraphKeys.TRAINABLE_VARIABLES, 'inpaint_net') 269 | d_vars = tf.get_collection( 270 | tf.GraphKeys.TRAINABLE_VARIABLES, 'discriminator') 271 | return g_vars, d_vars, losses 272 | 273 | def build_infer_graph(self, batch_data, batch_mask, batch_guide, config, name='val'): 274 | """ 275 | validation 276 | """ 277 | config.MAX_DELTA_HEIGHT = 0 278 | config.MAX_DELTA_WIDTH = 0 279 | batch_pos = batch_data / 127.5 - 1. 280 | batch_incomplete = batch_pos*(1.-batch_mask) 281 | # inpaint 282 | x1, x2, offset_flow = self.build_inpaint_net( 283 | batch_incomplete, batch_mask, batch_guide, config, reuse=True, 284 | training=False, padding=config.PADDING) 285 | if config.PRETRAIN_COARSE_NETWORK: 286 | batch_predicted = x1 287 | logger.info('Set batch_predicted to x1.') 288 | else: 289 | batch_predicted = x2 290 | logger.info('Set batch_predicted to x2.') 291 | # apply mask and reconstruct 292 | batch_complete = batch_predicted*batch_mask + batch_incomplete*(1.-batch_mask) 293 | # global image visualization 294 | viz_img = [batch_pos, batch_incomplete, batch_complete] 295 | if offset_flow is not None: 296 | viz_img.append( 297 | resize(offset_flow, scale=4, 298 | func=tf.image.resize_nearest_neighbor)) 299 | images_summary( 300 | tf.concat(viz_img, axis=2), 301 | name+'_raw_incomplete_complete', config.VIZ_MAX_OUT) 302 | return batch_complete 303 | 304 | def build_static_infer_graph(self, batch_data, batch_mask, batch_guide, config, name): 305 | """ 306 | """ 307 | # generate mask, 1 represents masked point 308 | bbox = (tf.constant(config.HEIGHT//2), tf.constant(config.WIDTH//2), 309 | tf.constant(config.HEIGHT), tf.constant(config.WIDTH)) 310 | return self.build_infer_graph(batch_data, batch_mask, batch_guide, config, name) 311 | 312 | 313 | def build_server_graph(self, batch_data, batch_mask, batch_guide, reuse=False, is_training=False): 314 | """ 315 | """ 316 | # generate mask, 1 represents masked point 317 | # batch_raw, masks_raw = tf.split(batch_data, 2, axis=2) 318 | #masks = tf.cast(masks_raw[0:1, :, :, 0:1] > 127.5, tf.float32) 319 | #batch_mask = tf.cast(batch_mask[0:1, :, :, 0:1] > 127.5, tf.float32) 320 | batch_pos = batch_data / 127.5 - 1. 321 | batch_incomplete = batch_pos * (1. 
- batch_mask) 322 | # inpaint 323 | x1, x2, flow = self.build_inpaint_net( 324 | batch_incomplete, batch_mask, batch_guide, reuse=reuse, training=is_training, 325 | config=None) 326 | batch_predict = x2 327 | # apply mask and reconstruct 328 | batch_complete = batch_predict*batch_mask + batch_incomplete*(1-batch_mask) 329 | return batch_complete 330 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | ## creative commons 2 | 3 | # Attribution-NonCommercial 4.0 International 4 | 5 | Creative Commons Corporation (“Creative Commons”) is not a law firm and does not provide legal services or legal advice. Distribution of Creative Commons public licenses does not create a lawyer-client or other relationship. Creative Commons makes its licenses and related information available on an “as-is” basis. Creative Commons gives no warranties regarding its licenses, any material licensed under their terms and conditions, or any related information. Creative Commons disclaims all liability for damages resulting from their use to the fullest extent possible. 6 | 7 | ### Using Creative Commons Public Licenses 8 | 9 | Creative Commons public licenses provide a standard set of terms and conditions that creators and other rights holders may use to share original works of authorship and other material subject to copyright and certain other rights specified in the public license below. The following considerations are for informational purposes only, are not exhaustive, and do not form part of our licenses. 10 | 11 | * __Considerations for licensors:__ Our public licenses are intended for use by those authorized to give the public permission to use material in ways otherwise restricted by copyright and certain other rights. Our licenses are irrevocable. Licensors should read and understand the terms and conditions of the license they choose before applying it. Licensors should also secure all rights necessary before applying our licenses so that the public can reuse the material as expected. Licensors should clearly mark any material not subject to the license. This includes other CC-licensed material, or material used under an exception or limitation to copyright. [More considerations for licensors](http://wiki.creativecommons.org/Considerations_for_licensors_and_licensees#Considerations_for_licensors). 12 | 13 | * __Considerations for the public:__ By using one of our public licenses, a licensor grants the public permission to use the licensed material under specified terms and conditions. If the licensor’s permission is not necessary for any reason–for example, because of any applicable exception or limitation to copyright–then that use is not regulated by the license. Our licenses grant only permissions under copyright and certain other rights that a licensor has authority to grant. Use of the licensed material may still be restricted for other reasons, including because others have copyright or other rights in the material. A licensor may make special requests, such as asking that all changes be marked or described. Although not required by our licenses, you are encouraged to respect those requests where reasonable. [More considerations for the public](http://wiki.creativecommons.org/Considerations_for_licensors_and_licensees#Considerations_for_licensees). 
14 | 15 | ## Creative Commons Attribution-NonCommercial 4.0 International Public License 16 | 17 | By exercising the Licensed Rights (defined below), You accept and agree to be bound by the terms and conditions of this Creative Commons Attribution-NonCommercial 4.0 International Public License ("Public License"). To the extent this Public License may be interpreted as a contract, You are granted the Licensed Rights in consideration of Your acceptance of these terms and conditions, and the Licensor grants You such rights in consideration of benefits the Licensor receives from making the Licensed Material available under these terms and conditions. 18 | 19 | ### Section 1 – Definitions. 20 | 21 | a. __Adapted Material__ means material subject to Copyright and Similar Rights that is derived from or based upon the Licensed Material and in which the Licensed Material is translated, altered, arranged, transformed, or otherwise modified in a manner requiring permission under the Copyright and Similar Rights held by the Licensor. For purposes of this Public License, where the Licensed Material is a musical work, performance, or sound recording, Adapted Material is always produced where the Licensed Material is synched in timed relation with a moving image. 22 | 23 | b. __Adapter's License__ means the license You apply to Your Copyright and Similar Rights in Your contributions to Adapted Material in accordance with the terms and conditions of this Public License. 24 | 25 | c. __Copyright and Similar Rights__ means copyright and/or similar rights closely related to copyright including, without limitation, performance, broadcast, sound recording, and Sui Generis Database Rights, without regard to how the rights are labeled or categorized. For purposes of this Public License, the rights specified in Section 2(b)(1)-(2) are not Copyright and Similar Rights. 26 | 27 | d. __Effective Technological Measures__ means those measures that, in the absence of proper authority, may not be circumvented under laws fulfilling obligations under Article 11 of the WIPO Copyright Treaty adopted on December 20, 1996, and/or similar international agreements. 28 | 29 | e. __Exceptions and Limitations__ means fair use, fair dealing, and/or any other exception or limitation to Copyright and Similar Rights that applies to Your use of the Licensed Material. 30 | 31 | f. __Licensed Material__ means the artistic or literary work, database, or other material to which the Licensor applied this Public License. 32 | 33 | g. __Licensed Rights__ means the rights granted to You subject to the terms and conditions of this Public License, which are limited to all Copyright and Similar Rights that apply to Your use of the Licensed Material and that the Licensor has authority to license. 34 | 35 | h. __Licensor__ means the individual(s) or entity(ies) granting rights under this Public License. 36 | 37 | i. __NonCommercial__ means not primarily intended for or directed towards commercial advantage or monetary compensation. For purposes of this Public License, the exchange of the Licensed Material for other material subject to Copyright and Similar Rights by digital file-sharing or similar means is NonCommercial provided there is no payment of monetary compensation in connection with the exchange. 38 | 39 | j. 
__Share__ means to provide material to the public by any means or process that requires permission under the Licensed Rights, such as reproduction, public display, public performance, distribution, dissemination, communication, or importation, and to make material available to the public including in ways that members of the public may access the material from a place and at a time individually chosen by them. 40 | 41 | k. __Sui Generis Database Rights__ means rights other than copyright resulting from Directive 96/9/EC of the European Parliament and of the Council of 11 March 1996 on the legal protection of databases, as amended and/or succeeded, as well as other essentially equivalent rights anywhere in the world. 42 | 43 | l. __You__ means the individual or entity exercising the Licensed Rights under this Public License. Your has a corresponding meaning. 44 | 45 | ### Section 2 – Scope. 46 | 47 | a. ___License grant.___ 48 | 49 | 1. Subject to the terms and conditions of this Public License, the Licensor hereby grants You a worldwide, royalty-free, non-sublicensable, non-exclusive, irrevocable license to exercise the Licensed Rights in the Licensed Material to: 50 | 51 | A. reproduce and Share the Licensed Material, in whole or in part, for NonCommercial purposes only; and 52 | 53 | B. produce, reproduce, and Share Adapted Material for NonCommercial purposes only. 54 | 55 | 2. __Exceptions and Limitations.__ For the avoidance of doubt, where Exceptions and Limitations apply to Your use, this Public License does not apply, and You do not need to comply with its terms and conditions. 56 | 57 | 3. __Term.__ The term of this Public License is specified in Section 6(a). 58 | 59 | 4. __Media and formats; technical modifications allowed.__ The Licensor authorizes You to exercise the Licensed Rights in all media and formats whether now known or hereafter created, and to make technical modifications necessary to do so. The Licensor waives and/or agrees not to assert any right or authority to forbid You from making technical modifications necessary to exercise the Licensed Rights, including technical modifications necessary to circumvent Effective Technological Measures. For purposes of this Public License, simply making modifications authorized by this Section 2(a)(4) never produces Adapted Material. 60 | 61 | 5. __Downstream recipients.__ 62 | 63 | A. __Offer from the Licensor – Licensed Material.__ Every recipient of the Licensed Material automatically receives an offer from the Licensor to exercise the Licensed Rights under the terms and conditions of this Public License. 64 | 65 | B. __No downstream restrictions.__ You may not offer or impose any additional or different terms or conditions on, or apply any Effective Technological Measures to, the Licensed Material if doing so restricts exercise of the Licensed Rights by any recipient of the Licensed Material. 66 | 67 | 6. __No endorsement.__ Nothing in this Public License constitutes or may be construed as permission to assert or imply that You are, or that Your use of the Licensed Material is, connected with, or sponsored, endorsed, or granted official status by, the Licensor or others designated to receive attribution as provided in Section 3(a)(1)(A)(i). 68 | 69 | b. ___Other rights.___ 70 | 71 | 1. 
Moral rights, such as the right of integrity, are not licensed under this Public License, nor are publicity, privacy, and/or other similar personality rights; however, to the extent possible, the Licensor waives and/or agrees not to assert any such rights held by the Licensor to the limited extent necessary to allow You to exercise the Licensed Rights, but not otherwise. 72 | 73 | 2. Patent and trademark rights are not licensed under this Public License. 74 | 75 | 3. To the extent possible, the Licensor waives any right to collect royalties from You for the exercise of the Licensed Rights, whether directly or through a collecting society under any voluntary or waivable statutory or compulsory licensing scheme. In all other cases the Licensor expressly reserves any right to collect such royalties, including when the Licensed Material is used other than for NonCommercial purposes. 76 | 77 | ### Section 3 – License Conditions. 78 | 79 | Your exercise of the Licensed Rights is expressly made subject to the following conditions. 80 | 81 | a. ___Attribution.___ 82 | 83 | 1. If You Share the Licensed Material (including in modified form), You must: 84 | 85 | A. retain the following if it is supplied by the Licensor with the Licensed Material: 86 | 87 | i. identification of the creator(s) of the Licensed Material and any others designated to receive attribution, in any reasonable manner requested by the Licensor (including by pseudonym if designated); 88 | 89 | ii. a copyright notice; 90 | 91 | iii. a notice that refers to this Public License; 92 | 93 | iv. a notice that refers to the disclaimer of warranties; 94 | 95 | v. a URI or hyperlink to the Licensed Material to the extent reasonably practicable; 96 | 97 | B. indicate if You modified the Licensed Material and retain an indication of any previous modifications; and 98 | 99 | C. indicate the Licensed Material is licensed under this Public License, and include the text of, or the URI or hyperlink to, this Public License. 100 | 101 | 2. You may satisfy the conditions in Section 3(a)(1) in any reasonable manner based on the medium, means, and context in which You Share the Licensed Material. For example, it may be reasonable to satisfy the conditions by providing a URI or hyperlink to a resource that includes the required information. 102 | 103 | 3. If requested by the Licensor, You must remove any of the information required by Section 3(a)(1)(A) to the extent reasonably practicable. 104 | 105 | 4. If You Share Adapted Material You produce, the Adapter's License You apply must not prevent recipients of the Adapted Material from complying with this Public License. 106 | 107 | ### Section 4 – Sui Generis Database Rights. 108 | 109 | Where the Licensed Rights include Sui Generis Database Rights that apply to Your use of the Licensed Material: 110 | 111 | a. for the avoidance of doubt, Section 2(a)(1) grants You the right to extract, reuse, reproduce, and Share all or a substantial portion of the contents of the database for NonCommercial purposes only; 112 | 113 | b. if You include all or a substantial portion of the database contents in a database in which You have Sui Generis Database Rights, then the database in which You have Sui Generis Database Rights (but not its individual contents) is Adapted Material; and 114 | 115 | c. You must comply with the conditions in Section 3(a) if You Share all or a substantial portion of the contents of the database. 
116 | 117 | For the avoidance of doubt, this Section 4 supplements and does not replace Your obligations under this Public License where the Licensed Rights include other Copyright and Similar Rights. 118 | 119 | ### Section 5 – Disclaimer of Warranties and Limitation of Liability. 120 | 121 | a. __Unless otherwise separately undertaken by the Licensor, to the extent possible, the Licensor offers the Licensed Material as-is and as-available, and makes no representations or warranties of any kind concerning the Licensed Material, whether express, implied, statutory, or other. This includes, without limitation, warranties of title, merchantability, fitness for a particular purpose, non-infringement, absence of latent or other defects, accuracy, or the presence or absence of errors, whether or not known or discoverable. Where disclaimers of warranties are not allowed in full or in part, this disclaimer may not apply to You.__ 122 | 123 | b. __To the extent possible, in no event will the Licensor be liable to You on any legal theory (including, without limitation, negligence) or otherwise for any direct, special, indirect, incidental, consequential, punitive, exemplary, or other losses, costs, expenses, or damages arising out of this Public License or use of the Licensed Material, even if the Licensor has been advised of the possibility of such losses, costs, expenses, or damages. Where a limitation of liability is not allowed in full or in part, this limitation may not apply to You.__ 124 | 125 | c. The disclaimer of warranties and limitation of liability provided above shall be interpreted in a manner that, to the extent possible, most closely approximates an absolute disclaimer and waiver of all liability. 126 | 127 | ### Section 6 – Term and Termination. 128 | 129 | a. This Public License applies for the term of the Copyright and Similar Rights licensed here. However, if You fail to comply with this Public License, then Your rights under this Public License terminate automatically. 130 | 131 | b. Where Your right to use the Licensed Material has terminated under Section 6(a), it reinstates: 132 | 133 | 1. automatically as of the date the violation is cured, provided it is cured within 30 days of Your discovery of the violation; or 134 | 135 | 2. upon express reinstatement by the Licensor. 136 | 137 | For the avoidance of doubt, this Section 6(b) does not affect any right the Licensor may have to seek remedies for Your violations of this Public License. 138 | 139 | c. For the avoidance of doubt, the Licensor may also offer the Licensed Material under separate terms or conditions or stop distributing the Licensed Material at any time; however, doing so will not terminate this Public License. 140 | 141 | d. Sections 1, 5, 6, 7, and 8 survive termination of this Public License. 142 | 143 | ### Section 7 – Other Terms and Conditions. 144 | 145 | a. The Licensor shall not be bound by any additional or different terms or conditions communicated by You unless expressly agreed. 146 | 147 | b. Any arrangements, understandings, or agreements regarding the Licensed Material not stated herein are separate from and independent of the terms and conditions of this Public License. 148 | 149 | ### Section 8 – Interpretation. 150 | 151 | a. For the avoidance of doubt, this Public License does not, and shall not be interpreted to, reduce, limit, restrict, or impose conditions on any use of the Licensed Material that could lawfully be made without permission under this Public License. 152 | 153 | b. 
To the extent possible, if any provision of this Public License is deemed unenforceable, it shall be automatically reformed to the minimum extent necessary to make it enforceable. If the provision cannot be reformed, it shall be severed from this Public License without affecting the enforceability of the remaining terms and conditions.
154 |
155 | c. No term or condition of this Public License will be waived and no failure to comply consented to unless expressly agreed to by the Licensor.
156 |
157 | d. Nothing in this Public License constitutes or may be interpreted as a limitation upon, or waiver of, any privileges and immunities that apply to the Licensor or You, including from the legal processes of any jurisdiction or authority.
158 |
159 | > Creative Commons is not a party to its public licenses. Notwithstanding, Creative Commons may elect to apply one of its public licenses to material it publishes and in those instances will be considered the “Licensor.” Except for the limited purpose of indicating that material is shared under a Creative Commons public license or as otherwise permitted by the Creative Commons policies published at [creativecommons.org/policies](http://creativecommons.org/policies), Creative Commons does not authorize the use of the trademark “Creative Commons” or any other trademark or logo of Creative Commons without its prior written consent including, without limitation, in connection with any unauthorized modifications to any of its public licenses or any other arrangements, understandings, or agreements concerning use of licensed material. For the avoidance of doubt, this paragraph does not form part of the public licenses.
160 | >
161 | > Creative Commons may be contacted at creativecommons.org
162 |
--------------------------------------------------------------------------------
/inpaint_ops.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import os
3 | import cv2
4 | import numpy as np
5 | import tensorflow as tf
6 | from tensorflow.contrib.framework.python.ops import add_arg_scope
7 |
8 | from neuralgym.ops.layers import resize
9 | from neuralgym.ops.layers import *
10 | from neuralgym.ops.loss_ops import *
11 | from neuralgym.ops.summary_ops import *
12 | from sn import spectral_normed_weight
13 | from bs4 import BeautifulSoup  # used by VOCReader below
14 | logger = logging.getLogger()
15 | np.random.seed(2018)
16 |
17 |
18 | @add_arg_scope
19 | def gen_conv(x, cnum, ksize, stride=1, rate=1, name='conv',
20 |              padding='SAME', activation=tf.nn.relu, training=True):
21 |     """Define conv for generator.
22 |
23 |     Args:
24 |         x: Input.
25 |         cnum: Channel number.
26 |         ksize: Kernel size.
27 |         stride: Convolution stride.
28 |         rate: Rate for dilated conv.
29 |         name: Name of layers.
30 |         padding: Padding mode: 'SAME', 'SYMMETRIC' or 'REFLECT'.
31 |         activation: Activation function after convolution.
32 |         training: If current graph is for training or inference, used for bn.
33 |
34 |     Returns:
35 |         tf.Tensor: output
36 |
37 |     """
38 |     assert padding in ['SYMMETRIC', 'SAME', 'REFLECT']
39 |     if padding == 'SYMMETRIC' or padding == 'REFLECT':
40 |         p = int(rate*(ksize-1)/2)
41 |         x = tf.pad(x, [[0,0], [p, p], [p, p], [0,0]], mode=padding)
42 |         padding = 'VALID'
43 |     x = tf.layers.conv2d(
44 |         x, cnum, ksize, stride, dilation_rate=rate,
45 |         activation=activation, padding=padding, name=name)
46 |     return x
47 |
48 |
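# A minimal usage sketch for gen_conv (assuming a TF1 graph context; the
# placeholder shape is illustrative). With padding='SYMMETRIC' the input is
# mirror-padded by rate*(ksize-1)/2 on each side and then convolved with
# 'VALID' padding, so the spatial size is preserved exactly as with 'SAME':
x = tf.placeholder(tf.float32, [1, 256, 256, 3])
y = gen_conv(x, cnum=32, ksize=5, stride=1, name='conv1', padding='SYMMETRIC')
print(y.get_shape().as_list())  # [1, 256, 256, 32]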
49 | @add_arg_scope
50 | def gen_deconv(x, cnum, name='upsample', padding='SAME', training=True):
51 |     """Define deconv for generator.
52 |     The deconv is defined to be a x2 resize_nearest_neighbor operation with
53 |     additional gen_conv operation.
54 |
55 |     Args:
56 |         x: Input.
57 |         cnum: Channel number.
58 |         name: Name of layers.
59 |         training: If current graph is for training or inference, used for bn.
60 |
61 |     Returns:
62 |         tf.Tensor: output
63 |
64 |     """
65 |     with tf.variable_scope(name):
66 |         x = resize(x, func=tf.image.resize_nearest_neighbor)
67 |         x = gen_conv(
68 |             x, cnum, 3, 1, name=name+'_conv', padding=padding,
69 |             training=training)
70 |     return x
71 |
72 | @add_arg_scope
73 | def gen_snconv(x, cnum, ksize, stride=1, rate=1, name='conv',
74 |                padding='SAME', activation=tf.nn.relu, training=True):
75 |     """Define spectral normalization conv for discriminator.
76 |
77 |     Args:
78 |         x: Input.
79 |         cnum: Channel number.
80 |         ksize: Kernel size.
81 |         stride: Convolution stride.
82 |         rate: Rate for dilated conv.
83 |         name: Name of layers.
84 |         padding: Padding mode: 'SAME', 'SYMMETRIC' or 'REFLECT'.
85 |         activation: Activation function (note: not applied inside this function).
86 |         training: If current graph is for training or inference, used for bn.
87 |
88 |     Returns:
89 |         tf.Tensor: output
90 |
91 |     """
92 |     assert padding in ['SYMMETRIC', 'SAME', 'REFLECT']
93 |     if padding == 'SYMMETRIC' or padding == 'REFLECT':
94 |         p = int(rate*(ksize-1)/2)
95 |         x = tf.pad(x, [[0,0], [p, p], [p, p], [0,0]], mode=padding)
96 |         padding = 'VALID'
97 |
98 |     fan_in = ksize * ksize * x.get_shape().as_list()[-1]
99 |     fan_out = ksize * ksize * cnum
100 |     stddev = np.sqrt(2. / (fan_in))
101 |     # initializer for w used for spectral normalization
102 |     w = tf.get_variable(name+"_w", [ksize, ksize, x.get_shape()[-1], cnum],
103 |                         initializer=tf.truncated_normal_initializer(stddev=stddev))
104 |     x = tf.nn.conv2d(x, spectral_normed_weight(w, update_collection=tf.GraphKeys.UPDATE_OPS, name=name+"_sn_w"),
105 |                      strides=[1, stride, stride, 1], dilations=[1, rate, rate, 1], padding=padding, name=name)
106 |     return x
107 |
108 |
131 |
132 | @add_arg_scope
133 | def dis_conv(x, cnum, ksize=5, stride=2, name='conv', training=True):
134 |     """Define conv for discriminator.
135 |     Activation is set to leaky_relu.
136 |
137 |     Args:
138 |         x: Input.
139 |         cnum: Channel number.
140 |         ksize: Kernel size.
141 |         stride: Convolution stride.
142 |         name: Name of layers.
143 |         training: If current graph is for training or inference, used for bn.
144 |
145 |     Returns:
146 |         tf.Tensor: output
147 |
148 |     """
149 |     x = tf.layers.conv2d(x, cnum, ksize, stride, 'SAME', name=name)
150 |     x = tf.nn.leaky_relu(x)
151 |     return x
152 |
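# A hedged sketch of an SN-PatchGAN-style discriminator stack built from
# gen_snconv with external leaky_relu activations (layer count and channel
# widths here are illustrative, not this repo's actual configuration):
def sn_patchgan_discriminator(x, cnum=64):
    x = tf.nn.leaky_relu(gen_snconv(x, cnum, 5, 2, name='sn_conv1'))
    x = tf.nn.leaky_relu(gen_snconv(x, cnum*2, 5, 2, name='sn_conv2'))
    x = tf.nn.leaky_relu(gen_snconv(x, cnum*4, 5, 2, name='sn_conv3'))
    return x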
153 | @add_arg_scope
154 | def gated_conv(x, cnum, ksize, stride=1, rate=1, name='gated_conv',
155 |                padding='SAME', activation=tf.nn.relu, training=True):
156 |     """Define gated conv for generator, with an additional learned sigmoid gating filter.
157 |
158 |     Args:
159 |         x: Input.
160 |         cnum: Channel number.
161 |         ksize: Kernel size.
162 |         stride: Convolution stride.
163 |         rate: Rate for dilated conv.
164 |         name: Name of layers.
165 |         padding: Padding mode: 'SAME', 'SYMMETRIC' or 'REFLECT'.
166 |         activation: Activation function after convolution.
167 |         training: If current graph is for training or inference, used for bn.
168 |
169 |     Returns:
170 |         tf.Tensor: output
171 |
172 |     """
173 |     assert padding in ['SYMMETRIC', 'SAME', 'REFLECT']
174 |     if padding == 'SYMMETRIC' or padding == 'REFLECT':
175 |         p = int(rate*(ksize-1)/2)
176 |         x = tf.pad(x, [[0,0], [p, p], [p, p], [0,0]], mode=padding)
177 |         padding = 'VALID'
178 |     xin = x
179 |     x = tf.layers.conv2d(
180 |         xin, cnum, ksize, stride, dilation_rate=rate,
181 |         activation=activation, padding=padding, name=name)
182 |
183 |     gated_mask = tf.layers.conv2d(
184 |         xin, cnum, ksize, stride, dilation_rate=rate,
185 |         activation=tf.nn.sigmoid, padding=padding, name=name+"_mask")
186 |
187 |     return x * gated_mask  # features soft-gated by the learned mask
188 |
189 | @add_arg_scope
190 | def gated_deconv(x, cnum, name='upsample', padding='SAME', training=True):
191 |     """Define gated deconv for generator.
192 |     The deconv is defined to be a x2 resize_nearest_neighbor operation with
193 |     additional gated_conv operation.
194 |
195 |     Args:
196 |         x: Input.
197 |         cnum: Channel number.
198 |         name: Name of layers.
199 |         training: If current graph is for training or inference, used for bn.
200 |
201 |     Returns:
202 |         tf.Tensor: output
203 |
204 |     """
205 |     with tf.variable_scope(name):
206 |         x = resize(x, func=tf.image.resize_nearest_neighbor)
207 |         x = gated_conv(
208 |             x, cnum, 3, 1, name=name+'_conv', padding=padding,
209 |             training=training)
210 |     return x
211 |
212 |
213 |
214 |
215 | def random_ff_mask(config, name="ff_mask"):
216 |     """Generate a random free-form mask with configuration.
217 |
218 |     Args:
219 |         config: Config should have configuration including IMG_SHAPES,
220 |             MAXVERTEX, MAXANGLE, MAXLENGTH, MAXBRUSHWIDTH.
221 |
222 |     Returns:
223 |         tf.Tensor: mask with shape [1, H, W, 1]; 1 marks masked pixels.
224 |
225 |     """
226 |     img_shape = config.IMG_SHAPES
227 |     h,w,c = img_shape
228 |     def npmask():
229 |
230 |         mask = np.zeros((h,w))
231 |         num_v = 8+np.random.randint(config.MAXVERTEX)
232 |
233 |         for i in range(num_v):
234 |             start_x = np.random.randint(w)
235 |             start_y = np.random.randint(h)
236 |             for j in range(1+np.random.randint(5)):
237 |                 angle = 0.01+np.random.randint(config.MAXANGLE)
238 |                 if i % 2 == 0:
239 |                     angle = 2 * 3.1415926 - angle
240 |                 length = 10+np.random.randint(config.MAXLENGTH)
241 |                 brush_w = 10+np.random.randint(config.MAXBRUSHWIDTH)
242 |                 end_x = (start_x + length * np.sin(angle)).astype(np.int32)
243 |                 end_y = (start_y + length * np.cos(angle)).astype(np.int32)
244 |
245 |                 cv2.line(mask, (start_y, start_x), (end_y, end_x), 1.0, brush_w)  # points given as (y, x); the transposed strokes still yield a valid random mask
246 |                 start_x, start_y = end_x, end_y
247 |         return mask.reshape((1,)+mask.shape+(1,)).astype(np.float32)
248 |     with tf.variable_scope(name), tf.device('/cpu:0'):
249 |         mask = tf.py_func(
250 |             npmask,
251 |             [],
252 |             tf.float32, stateful=False)
253 |         mask.set_shape([1,] + [h, w] + [1,])
254 |     return mask
255 |
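# A minimal sketch of wiring a free-form mask into the input pipeline
# (mirrors the model code earlier in the repo; `config` and `batch_data`
# are assumed to be built elsewhere):
mask = random_ff_mask(config)               # [1, H, W, 1], 1 marks the hole
batch_pos = batch_data / 127.5 - 1.         # scale images to [-1, 1]
batch_incomplete = batch_pos * (1. - mask)  # zero out the masked region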
256 | def mask_from_bbox_voc(config, files, name='mask_from_bbox_voc'):
257 |     """
258 |     Use the data from the VOC dataset and generate masks from bounding box annotations.
259 |
260 |     Args:
261 |         config: Config should have configuration including IMG_SHAPES,
262 |             VERTICAL_MARGIN, HEIGHT, HORIZONTAL_MARGIN, WIDTH.
263 |
264 |         files: Filename list used to generate bboxes.
265 |     Returns:
266 |         a batch of masks
267 |     """
268 |     img_shape = config.IMG_SHAPES
269 |     img_height = img_shape[0]
270 |     img_width = img_shape[1]
271 |     fs = tf.shape(files)
272 |     files_group = tf.split(files, fs[0], axis=0)
273 |     with tf.variable_scope(name), tf.device('/cpu:0'):
274 |         # NOTE: unfinished stub; per-file bbox parsing is not implemented yet.
275 |         pass
276 |
277 |
278 | def mask_from_seg_voc(config, files):
279 |     """
280 |     Use the data from the VOC dataset and generate masks from segmentation data.
281 |
282 |     Args:
283 |         config: Config should have configuration including IMG_SHAPES,
284 |             VERTICAL_MARGIN, HEIGHT, HORIZONTAL_MARGIN, WIDTH.
285 |
286 |         files: Filename list used to generate bboxes.
287 |     Returns:
288 |         a batch of masks
289 |     """
290 |     pass
291 | def random_bbox(config):
292 |     """Generate a random tlhw with configuration.
293 |
294 |     Args:
295 |         config: Config should have configuration including IMG_SHAPES,
296 |             VERTICAL_MARGIN, HEIGHT, HORIZONTAL_MARGIN, WIDTH.
297 |
298 |     Returns:
299 |         tuple: (top, left, height, width)
300 |
301 |     """
302 |     img_shape = config.IMG_SHAPES
303 |     img_height = img_shape[0]
304 |     img_width = img_shape[1]
305 |     maxt = img_height - config.VERTICAL_MARGIN - config.HEIGHT
306 |     maxl = img_width - config.HORIZONTAL_MARGIN - config.WIDTH
307 |     t = tf.random_uniform(
308 |         [], minval=config.VERTICAL_MARGIN, maxval=maxt, dtype=tf.int32)
309 |     l = tf.random_uniform(
310 |         [], minval=config.HORIZONTAL_MARGIN, maxval=maxl, dtype=tf.int32)
311 |     h = tf.constant(config.HEIGHT)
312 |     w = tf.constant(config.WIDTH)
313 |     return (t, l, h, w)
314 |
315 |
316 | def bbox2mask(bbox, config, name='mask'):
317 |     """Generate mask tensor from bbox.
318 |
319 |     Args:
320 |         bbox: configuration tuple, (top, left, height, width)
321 |         config: Config should have configuration including IMG_SHAPES,
322 |             MAX_DELTA_HEIGHT, MAX_DELTA_WIDTH.
323 |
324 |     Returns:
325 |         tf.Tensor: output with shape [1, H, W, 1]
326 |
327 |     """
328 |     def npmask(bbox, height, width, delta_h, delta_w):
329 |         mask = np.zeros((1, height, width, 1), np.float32)
330 |         h = np.random.randint(delta_h//2+1)
331 |         w = np.random.randint(delta_w//2+1)
332 |         mask[:, bbox[0]+h:bbox[0]+bbox[2]-h,
333 |              bbox[1]+w:bbox[1]+bbox[3]-w, :] = 1.
334 |         return mask
335 |     with tf.variable_scope(name), tf.device('/cpu:0'):
336 |         img_shape = config.IMG_SHAPES
337 |         height = img_shape[0]
338 |         width = img_shape[1]
339 |         mask = tf.py_func(
340 |             npmask,
341 |             [bbox, height, width,
342 |              config.MAX_DELTA_HEIGHT, config.MAX_DELTA_WIDTH],
343 |             tf.float32, stateful=False)
344 |         mask.set_shape([1] + [height, width] + [1])
345 |     return mask
346 |
347 |
348 | def mask_patch(x, mask):
349 |     """Crop local patch according to mask.
350 |
351 |     Args:
352 |         x: input
353 |         mask: 0-1 mask of the same size as x
354 |
355 |     Returns:
356 |         tf.Tensor: local patch
357 |
358 |     """
359 |     return x*mask
360 | def local_patch(x, bbox):
361 |     """Crop local patch according to bbox.
362 |
363 |     Args:
364 |         x: input
365 |         bbox: (top, left, height, width)
366 |
367 |     Returns:
368 |         tf.Tensor: local patch
369 |
370 |     """
371 |     x = tf.image.crop_to_bounding_box(x, bbox[0], bbox[1], bbox[2], bbox[3])
372 |     return x
373 |
374 |
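# A minimal sketch of the rectangular-mask pipeline these helpers form
# (assuming `config` provides IMG_SHAPES, HEIGHT, WIDTH, the margins and
# the max deltas, and `batch_pos` is a [-1, 1] image batch):
bbox = random_bbox(config)                # (top, left, height, width) tensors
mask = bbox2mask(bbox, config)            # [1, H, W, 1], 1 inside the shrunk box
local_gt = local_patch(batch_pos, bbox)   # ground-truth crop around the box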
375 | def resize_mask_like(mask, x):
376 |     """Resize mask to the spatial shape of x.
377 |
378 |     Args:
379 |         mask: Original mask.
380 |         x: Tensor whose spatial shape the mask is resized to.
381 |
382 |     Returns:
383 |         tf.Tensor: resized mask
384 |
385 |     """
386 |     mask_resize = resize(
387 |         mask, to_shape=x.get_shape().as_list()[1:3],
388 |         func=tf.image.resize_nearest_neighbor)
389 |     return mask_resize
390 |
391 |
392 | def spatial_discounting_mask(config):
393 |     """Generate spatial discounting mask constant.
394 |
395 |     Spatial discounting mask is first introduced in publication:
396 |         Generative Image Inpainting with Contextual Attention, Yu et al.
397 |
398 |     Args:
399 |         config: Config should have configuration including HEIGHT, WIDTH,
400 |             DISCOUNTED_MASK, SPATIAL_DISCOUNTING_GAMMA.
401 |
402 |     Returns:
403 |         tf.Tensor: spatial discounting mask
404 |     """
405 |     gamma = config.SPATIAL_DISCOUNTING_GAMMA
406 |     shape = [1, config.HEIGHT, config.WIDTH, 1]
407 |     if config.DISCOUNTED_MASK:
408 |         logger.info('Use spatial discounting l1 loss.')
409 |         mask_values = np.ones((config.HEIGHT, config.WIDTH))
410 |         for i in range(config.HEIGHT):
411 |             for j in range(config.WIDTH):
412 |                 mask_values[i, j] = max(
413 |                     gamma**min(i, config.HEIGHT-i),
414 |                     gamma**min(j, config.WIDTH-j))
415 |         mask_values = np.expand_dims(mask_values, 0)
416 |         mask_values = np.expand_dims(mask_values, 3)
417 |
418 |     else:
419 |         mask_values = np.ones(shape)
420 |     return tf.constant(mask_values, dtype=tf.float32, shape=shape)
421 |
422 |
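# The discounting weight at pixel (i, j) is
#     max(gamma**min(i, H-i), gamma**min(j, W-j)),
# so it decays toward the mask center and stays near 1 at the border,
# focusing the L1 loss on regions close to the hole boundary. A quick
# numeric check (values here are illustrative):
gamma, H, W = 0.9, 64, 64
w = np.fromfunction(
    lambda i, j: np.maximum(gamma**np.minimum(i, H-i), gamma**np.minimum(j, W-j)),
    (H, W))
print(w[32, 32], w[0, 32])  # ~0.034 at the center, 1.0 on the border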
423 | def contextual_attention(f, b, mask=None, ksize=3, stride=1, rate=1,
424 |                          fuse_k=3, softmax_scale=10., training=True, fuse=True):
425 |     """ Contextual attention layer implementation.
426 |
427 |     Contextual attention is first introduced in publication:
428 |         Generative Image Inpainting with Contextual Attention, Yu et al.
429 |
430 |     Args:
431 |         f: Input feature to match (foreground).
432 |         b: Input feature to match from (background).
433 |         mask: Input mask for b, indicating patches not available.
434 |         ksize: Kernel size for contextual attention.
435 |         stride: Stride for extracting patches from b.
436 |         rate: Dilation for matching.
437 |         softmax_scale: Scaled softmax for attention.
438 |         training: Indicating if current graph is training or inference.
439 |
440 |     Returns:
441 |         tf.Tensor: output
442 |
443 |     """
444 |     # get shapes
445 |     raw_fs = tf.shape(f)
446 |     #batch_size = raw_fs[0]
447 |     raw_int_fs = f.get_shape().as_list()
448 |     raw_int_bs = b.get_shape().as_list()
449 |     # extract patches from background with stride and rate
450 |     kernel = 2*rate
451 |     raw_w = tf.extract_image_patches(
452 |         b, [1,kernel,kernel,1], [1,rate*stride,rate*stride,1], [1,1,1,1], padding='SAME')
453 |     raw_w = tf.reshape(raw_w, [raw_int_bs[0], -1, kernel, kernel, raw_int_bs[3]])
454 |     raw_w = tf.transpose(raw_w, [0, 2, 3, 4, 1])  # transpose to b*k*k*c*hw
455 |     # downscaling foreground option: downscaling both foreground and
456 |     # background for matching and use original background for reconstruction.
457 |     f = resize(f, scale=1./rate, func=tf.image.resize_nearest_neighbor)
458 |     b = resize(b, to_shape=[int(raw_int_bs[1]/rate), int(raw_int_bs[2]/rate)], func=tf.image.resize_nearest_neighbor)  # https://github.com/tensorflow/tensorflow/issues/11651
459 |     if mask is not None:
460 |         mask = resize(mask, scale=1./rate, func=tf.image.resize_nearest_neighbor)
461 |     fs = tf.shape(f)
462 |     int_fs = f.get_shape().as_list()
463 |     f_groups = tf.split(f, int_fs[0], axis=0)
464 |     # from b(H*W*C) to w(b*k*k*c*h*w)
465 |     bs = tf.shape(b)
466 |     int_bs = b.get_shape().as_list()
467 |     w = tf.extract_image_patches(
468 |         b, [1,ksize,ksize,1], [1,stride,stride,1], [1,1,1,1], padding='SAME')
469 |     w = tf.reshape(w, [int_fs[0], -1, ksize, ksize, int_fs[3]])
470 |     w = tf.transpose(w, [0, 2, 3, 4, 1])  # transpose to b*k*k*c*hw
471 |     # process mask
472 |     if mask is None:
473 |         mask = tf.zeros([1, bs[1], bs[2], 1])
474 |     ms = tf.shape(mask)
475 |     batch_size = ms[0]
476 |     m = tf.extract_image_patches(
477 |         mask, [1, ksize,ksize,1], [1,stride,stride,1], [1,1,1,1], padding='SAME')
478 |     m = tf.reshape(m, [batch_size, -1, ksize, ksize, 1])
479 |     m = tf.transpose(m, [0, 2, 3, 4, 1])  # transpose to b*k*k*c*hw
480 |     m = m[0]
481 |     mm = tf.cast(tf.equal(tf.reduce_mean(m, axis=[0,1,2], keep_dims=True), 0.), tf.float32)
482 |     w_groups = tf.split(w, int_bs[0], axis=0)
483 |     raw_w_groups = tf.split(raw_w, int_bs[0], axis=0)
484 |     y = []
485 |     offsets = []
486 |     k = fuse_k
487 |     scale = softmax_scale
488 |     fuse_weight = tf.reshape(tf.eye(k), [k, k, 1, 1])
489 |     for xi, wi, raw_wi in zip(f_groups, w_groups, raw_w_groups):
490 |         # conv for compare
491 |         wi = wi[0]
492 |         wi_normed = wi / tf.maximum(tf.sqrt(tf.reduce_sum(tf.square(wi), axis=[0,1,2])), 1e-4)
493 |         yi = tf.nn.conv2d(xi, wi_normed, strides=[1,1,1,1], padding="SAME")
494 |
495 |         # conv implementation for fuse scores to encourage large patches
496 |         if fuse:
497 |             yi = tf.reshape(yi, [1, fs[1]*fs[2], bs[1]*bs[2], 1])
498 |             yi = tf.nn.conv2d(yi, fuse_weight, strides=[1,1,1,1], padding='SAME')
499 |             yi = tf.reshape(yi, [1, fs[1], fs[2], bs[1], bs[2]])
500 |             yi = tf.transpose(yi, [0, 2, 1, 4, 3])
501 |             yi = tf.reshape(yi, [1, fs[1]*fs[2], bs[1]*bs[2], 1])
502 |             yi = tf.nn.conv2d(yi, fuse_weight, strides=[1,1,1,1], padding='SAME')
503 |             yi = tf.reshape(yi, [1, fs[2], fs[1], bs[2], bs[1]])
504 |             yi = tf.transpose(yi, [0, 2, 1, 4, 3])
505 |         yi = tf.reshape(yi, [1, fs[1], fs[2], bs[1]*bs[2]])
506 |
507 |         # softmax to match
508 |         yi *= mm  # mask
509 |         yi = tf.nn.softmax(yi*scale, 3)
510 |         yi *= mm  # mask
511 |
512 |         offset = tf.argmax(yi, axis=3, output_type=tf.int32)
513 |         offset = tf.stack([offset // fs[2], offset % fs[2]], axis=-1)
514 |         # deconv for patch pasting
515 |         # 3.1 paste center
516 |         wi_center = raw_wi[0]
517 |         yi = tf.nn.conv2d_transpose(yi, wi_center, tf.concat([[1], raw_fs[1:]], axis=0), strides=[1,rate,rate,1]) / 4.
518 |         y.append(yi)
519 |         offsets.append(offset)
520 |     y = tf.concat(y, axis=0)
521 |     y.set_shape(raw_int_fs)
522 |     offsets = tf.concat(offsets, axis=0)
523 |     offsets.set_shape(int_bs[:3] + [2])
524 |     # case1: visualize optical flow: minus current position
525 |     h_add = tf.tile(tf.reshape(tf.range(bs[1]), [1, bs[1], 1, 1]), [bs[0], 1, bs[2], 1])
526 |     w_add = tf.tile(tf.reshape(tf.range(bs[2]), [1, 1, bs[2], 1]), [bs[0], bs[1], 1, 1])
527 |     offsets = offsets - tf.concat([h_add, w_add], axis=3)
528 |     # to flow image
529 |     flow = flow_to_image_tf(offsets)
530 |     # # case2: visualize which pixels are attended
531 |     # flow = highlight_flow_tf(offsets * tf.cast(mask, tf.int32))
532 |     if rate != 1:
533 |         flow = resize(flow, scale=rate, func=tf.image.resize_nearest_neighbor)
534 |     return y, flow
535 |
536 |
537 | def test_contextual_attention(args):
538 |     """Test contextual attention layer with 3-channel image input
539 |     (instead of n-channel feature).
540 |
541 |     """
542 |     import cv2
543 |     import os
544 |     # restrict to the first gpu (set to '' to force cpu)
545 |     os.environ['CUDA_VISIBLE_DEVICES'] = '0'
546 |
547 |     rate = 2
548 |     stride = 1
549 |     grid = rate*stride
550 |
551 |     b = cv2.imread(args.imageA)
552 |     b = cv2.resize(b, None, fx=0.5, fy=0.5, interpolation=cv2.INTER_CUBIC)
553 |     h, w, _ = b.shape
554 |     b = b[:h//grid*grid, :w//grid*grid, :]
555 |     b = np.expand_dims(b, 0)
556 |     logger.info('Size of imageA: {}'.format(b.shape))
557 |
558 |     f = cv2.imread(args.imageB)
559 |     h, w, _ = f.shape
560 |     f = f[:h//grid*grid, :w//grid*grid, :]
561 |     f = np.expand_dims(f, 0)
562 |     logger.info('Size of imageB: {}'.format(f.shape))
563 |
564 |     with tf.Session() as sess:
565 |         bt = tf.constant(b, dtype=tf.float32)
566 |         ft = tf.constant(f, dtype=tf.float32)
567 |
568 |         yt, flow = contextual_attention(
569 |             ft, bt, stride=stride, rate=rate,
570 |             training=False, fuse=False)
571 |         y = sess.run(yt)
572 |         cv2.imwrite(args.imageOut, y[0])
573 |
574 |
575 | def make_color_wheel():
576 |     RY, YG, GC, CB, BM, MR = (15, 6, 4, 11, 13, 6)
577 |     ncols = RY + YG + GC + CB + BM + MR
578 |     colorwheel = np.zeros([ncols, 3])
579 |     col = 0
580 |     # RY
581 |     colorwheel[0:RY, 0] = 255
582 |     colorwheel[0:RY, 1] = np.transpose(np.floor(255*np.arange(0, RY) / RY))
583 |     col += RY
584 |     # YG
585 |     colorwheel[col:col+YG, 0] = 255 - np.transpose(np.floor(255*np.arange(0, YG) / YG))
586 |     colorwheel[col:col+YG, 1] = 255
587 |     col += YG
588 |     # GC
589 |     colorwheel[col:col+GC, 1] = 255
590 |     colorwheel[col:col+GC, 2] = np.transpose(np.floor(255*np.arange(0, GC) / GC))
591 |     col += GC
592 |     # CB
593 |     colorwheel[col:col+CB, 1] = 255 - np.transpose(np.floor(255*np.arange(0, CB) / CB))
594 |     colorwheel[col:col+CB, 2] = 255
595 |     col += CB
596 |     # BM
597 |     colorwheel[col:col+BM, 2] = 255
598 |     colorwheel[col:col+BM, 0] = np.transpose(np.floor(255*np.arange(0, BM) / BM))
599 |     col += BM
600 |     # MR
601 |     colorwheel[col:col+MR, 2] = 255 - np.transpose(np.floor(255 * np.arange(0, MR) / MR))
602 |     colorwheel[col:col+MR, 0] = 255
603 |     return colorwheel
604 |
605 |
606 | COLORWHEEL = make_color_wheel()
607 |
608 |
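# The wheel has RY+YG+GC+CB+BM+MR = 15+6+4+11+13+6 = 55 entries; a flow
# vector's angle indexes into it and its magnitude controls saturation:
print(COLORWHEEL.shape)  # (55, 3)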
609 | def compute_color(u,v):
610 |     h, w = u.shape
611 |     img = np.zeros([h, w, 3])
612 |     nanIdx = np.isnan(u) | np.isnan(v)
613 |     u[nanIdx] = 0
614 |     v[nanIdx] = 0
615 |     colorwheel = COLORWHEEL  # reuse the precomputed wheel
616 |
617 |     ncols = np.size(colorwheel, 0)
618 |     rad = np.sqrt(u**2+v**2)
619 |     a = np.arctan2(-v, -u) / np.pi
620 |     fk = (a+1) / 2 * (ncols - 1) + 1
621 |     k0 = np.floor(fk).astype(int)
622 |     k1 = k0 + 1
623 |     k1[k1 == ncols+1] = 1
624 |     f = fk - k0
625 |     for i in range(np.size(colorwheel,1)):
626 |         tmp = colorwheel[:, i]
627 |         col0 = tmp[k0-1] / 255
628 |         col1 = tmp[k1-1] / 255
629 |         col = (1-f) * col0 + f * col1
630 |         idx = rad <= 1
631 |         col[idx] = 1-rad[idx]*(1-col[idx])
632 |         notidx = np.logical_not(idx)
633 |         col[notidx] *= 0.75
634 |         img[:, :, i] = np.uint8(np.floor(255 * col*(1-nanIdx)))
635 |     return img
636 |
637 |
638 |
639 | def flow_to_image(flow):
640 |     """Transfer flow map to image.
641 |     Part of code forked from flownet.
642 |     """
643 |     out = []
644 |     maxu = -999.
645 |     maxv = -999.
646 |     minu = 999.
647 |     minv = 999.
648 |     maxrad = -1
649 |     for i in range(flow.shape[0]):
650 |         u = flow[i, :, :, 0]
651 |         v = flow[i, :, :, 1]
652 |         idx_unknown = (abs(u) > 1e7) | (abs(v) > 1e7)
653 |         u[idx_unknown] = 0
654 |         v[idx_unknown] = 0
655 |         maxu = max(maxu, np.max(u))
656 |         minu = min(minu, np.min(u))
657 |         maxv = max(maxv, np.max(v))
658 |         minv = min(minv, np.min(v))
659 |         rad = np.sqrt(u ** 2 + v ** 2)
660 |         maxrad = max(maxrad, np.max(rad))
661 |         u = u/(maxrad + np.finfo(float).eps)
662 |         v = v/(maxrad + np.finfo(float).eps)
663 |         img = compute_color(u, v)
664 |         out.append(img)
665 |     return np.float32(np.uint8(out))
666 |
667 |
668 | def flow_to_image_tf(flow, name='flow_to_image'):
669 |     """Tensorflow ops for computing flow to image.
670 |     """
671 |     with tf.variable_scope(name), tf.device('/cpu:0'):
672 |         img = tf.py_func(flow_to_image, [flow], tf.float32, stateful=False)
673 |         img.set_shape(flow.get_shape().as_list()[0:-1]+[3])
674 |         img = img / 127.5 - 1.
675 |     return img
676 |
677 |
678 | def highlight_flow(flow):
679 |     """Convert flow into middlebury color code image.
680 |     """
681 |     out = []
682 |     s = flow.shape
683 |     for i in range(flow.shape[0]):
684 |         img = np.ones((s[1], s[2], 3)) * 144.
685 |         u = flow[i, :, :, 0]
686 |         v = flow[i, :, :, 1]
687 |         for h in range(s[1]):
688 |             for w in range(s[2]):  # iterate over the width axis
689 |                 ui = u[h,w]
690 |                 vi = v[h,w]
691 |                 img[int(ui), int(vi), :] = 255.
692 |         out.append(img)
693 |     return np.float32(np.uint8(out))
694 |
695 |
696 | def highlight_flow_tf(flow, name='flow_to_image'):
697 |     """Tensorflow ops for highlight flow.
698 |     """
699 |     with tf.variable_scope(name), tf.device('/cpu:0'):
700 |         img = tf.py_func(highlight_flow, [flow], tf.float32, stateful=False)
701 |         img.set_shape(flow.get_shape().as_list()[0:-1]+[3])
702 |         img = img / 127.5 - 1.
703 |     return img
704 |
705 |
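# A sketch of how these visualizers are consumed upstream: the attention
# offsets returned by contextual_attention are rendered to an RGB flow image
# and upscaled before being written to the image summary, e.g.
#     viz = resize(offset_flow, scale=4, func=tf.image.resize_nearest_neighbor)
# (mirrors build_infer_graph in inpaint_model.py).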
708 | """ 709 | out = [] 710 | for i in range(image.shape[0]): 711 | img = cv2.Laplacian(image[i, :, :, :], cv2.CV_64F, ksize=3, scale=2) 712 | out.append(img) 713 | return np.float32(np.uint8(out)) 714 | 715 | class VOCReader(object): 716 | def __init__(self, path): 717 | self.path = path 718 | 719 | def load_seg(self, filename): 720 | file_path = os.path.join(self.path, "SegmentationClass", filename) 721 | seg = cv2.imread(file_path) 722 | return seg 723 | def seg2mask(self, seg): 724 | pass 725 | def load_bndbox(self, filename): 726 | file_path = os.path.join(self.path, "Annotations", filename) 727 | with open(filename, 'r') as reader: 728 | xml = reader.read() 729 | soup = BeautifulSoup(xml, 'xml') 730 | size = {} 731 | for tag in soup.size: 732 | if tag.string != "\n": 733 | size[tag.name] = int(tag.string) 734 | objects = soup.find_all('object') 735 | bndboxs = [] 736 | for obj in objects: 737 | bndbox = {} 738 | for tag in obj.bndbox: 739 | if tag.string != '\n': 740 | bndbox[tag.name] = int(tag.string) 741 | bndboxs.append(bndbox) 742 | return bndboxs 743 | 744 | 745 | 746 | if __name__ == "__main__": 747 | import argparse 748 | parser = argparse.ArgumentParser() 749 | parser.add_argument('--imageA', default='', type=str, help='Image A as background patches to reconstruct image B.') 750 | parser.add_argument('--imageB', default='', type=str, help='Image B is reconstructed with image A.') 751 | parser.add_argument('--imageOut', default='result.png', type=str, help='Image B is reconstructed with image A.') 752 | args = parser.parse_args() 753 | test_contextual_attention(args) 754 | --------------------------------------------------------------------------------