├── LICENSE ├── Readme.md ├── data_processing ├── classes.csv ├── coco_data_provider.py ├── coco_share_classes.txt ├── edge_detection │ ├── PostprocessHED.m │ └── batch_hed.py ├── extract_images.py ├── flickr_crawler.py ├── flickr_filter.py ├── flickr_to_tfrecord.py ├── imagenet_classes.txt ├── imagenet_lsvrc_2015_synsets.txt ├── imagenet_metadata.txt ├── imagenet_share_classes.txt ├── pycocotools │ ├── __init__.py │ ├── _mask.pyx │ ├── coco.py │ ├── cocoeval.py │ └── mask.py ├── sketchy_to_tfrecord.py └── tfrecord.py ├── inception_v4_model └── put_inception_v4.ckpt_here ├── main_single.py └── src_single ├── config.py ├── graph_single.py ├── inception_score.py ├── inception_utils.py ├── inception_v4.py ├── input_pipeline.py ├── models_mru.py ├── mru.py ├── sn.py ├── train_single.py └── vgg.py /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 Wengling Chen 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /Readme.md: -------------------------------------------------------------------------------- 1 | SketchyGAN: Towards Diverse and Realistic Sketch to Image Synthesis 2 | ===================================== 3 | 4 | Code for ["SketchyGAN: Towards Diverse and Realistic Sketch to Image Synthesis"](https://arxiv.org/abs/1801.02753). 5 | 6 | 7 | ## Prerequisites 8 | 9 | - Python 3, NumPy, SciPy, OpenCV 3 10 | - Tensorflow(>=1.7.0). Tensorflow 2.0 is not supported. 11 | - A recent NVIDIA GPU 12 | 13 | 14 | ## Preparations 15 | 16 | - The path to data files needs to be specified in `input_pipeline.py`. See below for detailed information on data files. 17 | - You need to download ["Inception-V4 model"](http://download.tensorflow.org/models/inception_v4_2016_09_09.tar.gz), unzip it and put the checkpoint under `inception_v4_model`. 18 | 19 | 20 | ## Dataset 21 | ~~Pre-built tfrecord files are available for out of the box training.~~ 22 | - ~~Files for the Sketchy Database can be found [here](https://gtvault-my.sharepoint.com/:f:/g/personal/wchen342_gatech_edu/EtKmg1alDNdIl09WcvtJp_cBFs_7td3wKnb5FUcWZswEmw?e=eBGO6G).~~ 23 | - ~~Files for Augmented Sketchy(i.e. 
flickr images+edge maps), resized to 256x256 regardless of original aspect ratios, can be found [here](https://gtvault-my.sharepoint.com/:f:/g/personal/wchen342_gatech_edu/EmF7KlhqZ8ZPnpzbTIMDKBoBcjMrezh3X2eS1P_KtWiGCQ?e=BJhFPF).~~ 24 | 25 | **Note**: The website hosting the dataset is no longer available. Please use the scripts under the `data_processing` folder to crawl your own images. 26 | 27 | If you want to build tfrecord files from images, run `flickr_to_tfrecord.py` or `sketchy_to_tfrecord.py` for the respective dataset. 28 | 29 | If you wish to get the image files: 30 | - The Sketchy Database can be found [here](http://sketchy.eye.gatech.edu/). 31 | - Use `extract_images.py` under `data_processing` to extract images from tfrecord files. You need to specify the input and output paths. The extracted images will be sorted by class name. 32 | - The dataset I used is no longer available due to its large size. You can crawl your own images and run them through `edge_detection/batch_hed.py` -> `edge_detection/PostprocessHED.m` -> `flickr_to_tfrecord.py` to create your own dataset. 33 | - ~~Please contact me if you need the original (not resized) Flickr images, since they are too large to upload to any online space.~~ 34 | 35 | 36 | ## Configurations 37 | 38 | The model can be trained out of the box by running `main_single.py`, but there are several places where you can change the configuration: 39 | 40 | - Command-line options in `main_single.py` 41 | - Some global options in `config.py` 42 | - Activation/normalization functions in `models_mru.py` 43 | 44 | 45 | ## Model 46 | 47 | - The model will be saved periodically. If you wish to resume training, just use the command-line switch `resume_from`. 48 | - If you wish to test the model, change `mode` from `train` to `test` and fill in `resume_from`. 49 | 50 | 51 | ## Citation 52 | 53 | If you use our work in your research, please cite our paper: 54 | ``` 55 | @InProceedings{Chen_2018_CVPR, 56 | author = {Chen, Wengling and Hays, James}, 57 | title = {SketchyGAN: Towards Diverse and Realistic Sketch to Image Synthesis}, 58 | booktitle = {The IEEE Conference on Computer Vision and Pattern Recognition (CVPR)}, 59 | month = {June}, 60 | year = {2018} 61 | } 62 | ``` 63 | 64 | 65 | ## Credits 66 | - Inception-V4 and VGG16 code by the Tensorflow Authors. 
67 | - Tensorflow implementation of Spectral Normalization by [minhnhat93](https://github.com/minhnhat93/tf-SNDCGAN) 68 | - [Improved WGAN](https://github.com/igul222/improved_wgan_training) 69 | -------------------------------------------------------------------------------- /data_processing/classes.csv: -------------------------------------------------------------------------------- 1 | Name,ID,Image_num,Used 2 | airplane,n02691156,1434,1 3 | alarm_clock,n02694662,1442,1 4 | ant,n02219486,1656,1 5 | ape,n02470325,1332,1 6 | apple,n07739125,1319,1 7 | armor,n02895154,942,0 8 | axe,n02764044,1254,1 9 | banana,n07753592,1409,1 10 | bat,n02139199,1304,1 11 | bear,n02131653,1688,1 12 | bee,n02206856,1672,1 13 | beetle,n02167151,1077,1 14 | bell,n02824448,1227,1 15 | bench,n02828884,1355,1 16 | bicycle,n02834778,1344,1 17 | blimp,n02850950,1110,1 18 | bread,n07679356,1249,1 19 | butterfly,n02274259,2115,1 20 | cabin,n02932400,1250,1 21 | camel,n02437136,1428,1 22 | candle,n02948072,1789,1 23 | cannon,n02950826,1220,1 24 | car_(sedan),n02958343,1307,1 25 | castle,n02980441,1379,1 26 | cat,n02121620,1485,1 27 | chair,n02738535,1592,1 28 | chicken,n01791625,1210,1 29 | church,n03028079,1329,1 30 | couch,n04256520,1686,1 31 | cow,n01887787,1588,1 32 | crab,n01978287,1214,1 33 | crocodilian,n01697457,1188,1 34 | cup,n03063073,1260,1 35 | deer,n02430045,1680,1 36 | dog,n02103406,1340,1 37 | dolphin,n02068974,1460,1 38 | door,n03222176,1384,1 39 | duck,n01846331,1642,1 40 | elephant,n02503517,1387,1 41 | eyeglasses,n04272054,1280,0 42 | fan,n03271574,1223,1 43 | fish,n01440764,1133,1 44 | flower,n11669921,1924,1 45 | frog,n01639765,1222,1 46 | geyser,n09288635,1585,1 47 | giraffe,n02439033,1256,1 48 | guitar,n02676566,2017,1 49 | hamburger,n07697100,1373,1 50 | hammer,n03481172,1390,1 51 | harp,n03495258,1822,1 52 | hat,n02859184,1319,0 53 | hedgehog,n02346627,1147,1 54 | helicopter,n03512147,1247,1 55 | hermit_crab,n01986214,1430,1 56 | horse,n02374451,1402,1 57 | hot-air_balloon,n02782093,1240,1 58 | hotdog,n07697537,1257,1 59 | hourglass,n03544143,1176,1 60 | jack-o-lantern,n03590841,1712,1 61 | jellyfish,n01910747,1635,1 62 | kangaroo,n01877134,1556,1 63 | knife,n02973904,999,0 64 | lion,n02129165,1795,1 65 | lizard,n01674464,1202,1 66 | lobster,n01983481,1123,1 67 | motorcycle,n03790512,1380,1 68 | mouse,n02330245,1252,1 69 | mushroom,n12997919,1406,1 70 | owl,n01621127,1296,1 71 | parrot,n01816887,1149,1 72 | pear,n07767847,1279,1 73 | penguin,n02055803,1281,1 74 | piano,n03452741,1318,1 75 | pickup_truck,n03930630,1443,1 76 | pig,n02395406,1463,1 77 | pineapple,n07753275,1209,1 78 | pistol,n03948459,1338,1 79 | pizza,n07873807,1289,1 80 | pretzel,n07695742,1263,1 81 | rabbit,n02325366,1280,1 82 | raccoon,n02508021,1722,1 83 | racket,n02772700,1402,0 84 | ray,n01496331,967,1 85 | rhinoceros,n02391994,1496,1 86 | rifle,n02749479,1172,0 87 | rocket,n03773504,1220,1 88 | sailboat,n04128499,1191,1 89 | saw,n02770585,371,0 90 | saxophone,n04141076,1673,0 91 | scissors,n03044934,562,0 92 | scorpion,n01770393,1197,1 93 | sea_turtle,n01663401,1485,1 94 | seagull,n02041246,1413,1 95 | seal,n02076196,1832,1 96 | shark,n01482330,1219,1 97 | sheep,n02411705,1273,1 98 | shoe,n02882894,1381,0 99 | skyscraper,n04233124,1546,1 100 | snail,n01944390,1837,1 101 | snake,n01726692,1289,1 102 | songbird,n01527347,1536,1 103 | spider,n01772222,1249,1 104 | spoon,n03633091,1817,1 105 | squirrel,n02355227,1192,1 106 | starfish,n02317335,1396,1 107 | strawberry,n07745940,1478,1 108 | swan,n01858441,1314,1 109 | 
sword,n03039493,533,0 110 | table,n03201208,2217,1 111 | tank,n04389033,1488,1 112 | teapot,n04398044,1590,1 113 | teddy_bear,n04399382,0,0 114 | tiger,n02129604,2086,1 115 | tree,n11608250,1148,1 116 | trumpet,n03110669,2092,1 117 | turtle,n01669191,1328,1 118 | umbrella,n04507155,1341,1 119 | violin,n04536866,1330,0 120 | volcano,n09472597,1404,1 121 | wading_bird,n02000954,1178,1 122 | wheelchair,n04576002,1245,0 123 | windmill,n04587404,1899,1 124 | window,n03227184,493,0 125 | wine_bottle,n04591713,1258,1 126 | zebra,n02391049,1474,1 127 | -------------------------------------------------------------------------------- /data_processing/coco_data_provider.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | sys.path.append('.') 5 | sys.path.append('./pycocotools') 6 | import pycocotools.coco as coco 7 | 8 | pascal_classes = ['person', 'bird', 'cat', 'cow', 'dog', 'horse', 'sheep', 'aeroplane', 9 | 'bicycle', 'boat', 'bus', 'car', 'motorbike', 'train', 'bottle', 'chair', 10 | 'dining table', 'potted plant', 'sofa', 'tv/monitor'] 11 | pascal_classes_mapped = ['person', 'bird', 'cat', 'cow', 'dog', 'horse', 'sheep', 'airplane', 12 | 'bicycle', 'boat', 'bus', 'car (sedan)', 'motorcycle', 'train', 'bottle', 'chair', 13 | 'dining table', 'potted plant', 'sofa', 'tv'] 14 | 15 | images_dir = '../Datasets/COCO/coco-master/images' 16 | anno_dir = '../Datasets/COCO/coco-master/annotations' 17 | itrain_2014 = 'instances_train2014.json' 18 | ival_2014 = 'instances_val2014.json' 19 | 20 | 21 | def img_info_list_to_dict(input_list): 22 | dic = {} 23 | for i in input_list: 24 | image_id = i['id'] 25 | assert image_id not in dic.keys() 26 | dic[image_id] = i 27 | return dic 28 | 29 | 30 | def get_all_images_data_categories(split, catIds=[]): 31 | if split == 'train': 32 | COCO = coco.COCO(annotation_file=os.path.join(anno_dir, itrain_2014)) 33 | elif split == 'test': 34 | COCO = coco.COCO(annotation_file=os.path.join(anno_dir, ival_2014)) 35 | if len(catIds) == 0: 36 | return COCO.loadImgs(COCO.getImgIds()), COCO.loadAnns(COCO.getAnnIds()), COCO.loadCats(COCO.getCatIds()) 37 | else: 38 | return COCO.loadImgs(ids=COCO.getImgIds(catIds=catIds)), COCO.loadAnns( 39 | ids=COCO.getAnnIds(catIds=catIds)), COCO.loadCats(COCO.getCatIds(catIds=catIds)) 40 | 41 | 42 | def expand_bbox(bbox, max_height, max_width, frac): 43 | assert len(bbox) == 4 44 | half_width = round(bbox[2] / 2) 45 | half_height = round(bbox[3] / 2) 46 | mid_x = bbox[0] + half_width 47 | mid_y = bbox[1] + half_height 48 | 49 | x_min = max(0, mid_x - half_width * frac) 50 | y_min = max(0, mid_y - half_height * frac) 51 | x_max = min(max_width, mid_x + half_width * frac) 52 | y_max = min(max_height, mid_y + half_height * frac) 53 | return [round(x_min), round(y_min), round(x_max), round(y_max)] 54 | 55 | 56 | def get_shared_classes(input=None, print_out=False, output_file=True): 57 | if input is None: 58 | ret = get_all_images_data_categories('train')[2] 59 | else: 60 | ret = input 61 | coco_classes = [cls['name'] for cls in ret] 62 | class_dict = {item['name']: item for item in ret} 63 | # Convert 'car' to 'car (sedan)' for comparison with Sketchy 64 | coco_classes[coco_classes.index('car')] = 'car (sedan)' 65 | with open('../../shared_classes', 'r') as f: 66 | shared_classes = [cls[:-1].replace('_', ' ') for cls in f.readlines()] 67 | 68 | if print_out: 69 | print(len(shared_classes)) 70 | print(shared_classes) 71 | if output_file: 72 | shared_classes2 = [(cls + '\n') for cls 
in shared_classes if cls in coco_classes] 73 | with open('../../shared_classes2', 'w') as f: 74 | f.writelines(shared_classes2) 75 | 76 | shared_classes = [cls for cls in shared_classes if cls in coco_classes] 77 | print([cls for cls in shared_classes if cls in coco_classes and cls in pascal_classes_mapped]) 78 | shared_classes[shared_classes.index('car (sedan)')] = 'car' # Convert 'car' back 79 | output_dict = {class_dict[cls]['id']: class_dict[cls] for cls in shared_classes} 80 | return shared_classes, output_dict 81 | 82 | 83 | def get_bbox(): 84 | img_info, seg_info, cat_info = get_all_images_data_categories('train') 85 | img_info = img_info_list_to_dict(img_info) 86 | shared_classes, cls_info = get_shared_classes(input=cat_info, print_out=False, output_file=False) 87 | bbox_list = {cls: [] for cls in shared_classes} 88 | for object in seg_info: 89 | category_id = object['category_id'] 90 | if category_id not in cls_info: 91 | continue 92 | 93 | category_name = cls_info[category_id]['name'] 94 | bbox = object['bbox'] 95 | image_id = object['image_id'] 96 | iscrowd = object['iscrowd'] 97 | 98 | this_img_info = img_info[image_id] 99 | file_name = this_img_info['file_name'] 100 | height = this_img_info['height'] 101 | width = this_img_info['width'] 102 | bbox = expand_bbox(bbox, height, width, 1.5) 103 | 104 | bbox_list[category_name].append({ 105 | 'image_id': image_id, 106 | 'category_name': category_name, 107 | 'category_id': category_id, 108 | 'iscrowd': iscrowd, 109 | 'file_name': file_name, 110 | 'height': height, 111 | 'width': width, 112 | 'bbox': bbox, 113 | }) 114 | 115 | return bbox_list 116 | 117 | 118 | if __name__ == '__main__': 119 | get_bbox() 120 | -------------------------------------------------------------------------------- /data_processing/coco_share_classes.txt: -------------------------------------------------------------------------------- 1 | airplane 2 | apple 3 | bear 4 | bench 5 | bicycle 6 | cat 7 | car 8 | chair 9 | couch 10 | cow 11 | dog 12 | elephant 13 | giraffe 14 | horse 15 | knife 16 | motorcycle 17 | scissors 18 | sheep 19 | spoon 20 | -------------------------------------------------------------------------------- /data_processing/edge_detection/PostprocessHED.m: -------------------------------------------------------------------------------- 1 | % Modified from pix2pix scripts: https://github.com/phillipi/pix2pix/blob/master/scripts/edges 2 | 3 | %%% Prerequisites 4 | % You need to get the cpp file edgesNmsMex.cpp from https://raw.githubusercontent.com/pdollar/edges/master/private/edgesNmsMex.cpp 5 | % and compile it in Matlab: mex edgesNmsMex.cpp 6 | % You also need to download and install Piotr's Computer Vision Matlab Toolbox: https://pdollar.github.io/toolbox/ 7 | 8 | %%% parameters 9 | % hed_mat_dir: the hed mat file directory (the output of 'batch_hed.py') 10 | % edge_dir: the output HED edges directory 11 | % image_width: resize the edge map to [image_width, image_width] 12 | % threshold: threshold for image binarization (default 25.0/255.0) 13 | % small_edge: remove small edges (default 5) 14 | 15 | function [] = PostprocessHED(hed_mat_dir, edge_dir, image_width, threshold, small_edge) 16 | 17 | if ~exist(edge_dir, 'dir') 18 | mkdir(edge_dir); 19 | end 20 | fileList = dir(fullfile(hed_mat_dir, '*.mat')); 21 | nFiles = numel(fileList); 22 | fprintf('find %d mat files\n', nFiles); 23 | 24 | parfor n = 1 : nFiles 25 | if mod(n, 1000) == 0 26 | fprintf('process %d/%d images\n', n, nFiles); 27 | end 28 | fileName = fileList(n).name; 29 | 
filePath = fullfile(hed_mat_dir, fileName); 30 | jpgName = strrep(fileName, '.mat', '.png'); 31 | edge_path = fullfile(edge_dir, jpgName); 32 | 33 | if ~exist(edge_path, 'file') 34 | E = GetEdge(filePath); 35 | E = imresize(E,[image_width,image_width]); 36 | E_simple = SimpleEdge(E, threshold, small_edge); 37 | E_simple = uint8(E_simple*255); 38 | imwrite(E_simple, edge_path); 39 | end 40 | end 41 | end 42 | 43 | 44 | 45 | 46 | function [E] = GetEdge(filePath) 47 | matdata = load(filePath); 48 | E = 1-matdata.predict; 49 | end 50 | 51 | function [E4] = SimpleEdge(E, threshold, small_edge) 52 | if nargin <= 1 53 | threshold = 25.0/255.0; 54 | end 55 | 56 | if nargin <= 2 57 | small_edge = 5; 58 | end 59 | 60 | if ndims(E) == 3 61 | E = E(:,:,1); 62 | end 63 | 64 | E1 = 1 - E; 65 | E2 = EdgeNMS(E1); 66 | E3 = double(E2>=max(eps,threshold)); 67 | E3 = bwmorph(bwmorph(E3,'thin',inf), 'clean'); 68 | E4 = bwareaopen(E3, small_edge); 69 | E4 = bwmorph(bwmorph(E4,'shrink'), 'spur'); 70 | E4=1-E4; 71 | end 72 | 73 | function [E_nms] = EdgeNMS( E ) 74 | E=single(E); 75 | [Ox,Oy] = gradient2(convTri(E,4)); 76 | [Oxx,~] = gradient2(Ox); 77 | [Oxy,Oyy] = gradient2(Oy); 78 | O = mod(atan(Oyy.*sign(-Oxy)./(Oxx+1e-5)),pi); 79 | E_nms = edgesNmsMex(E,O,1,5,1.01,1); 80 | end 81 | -------------------------------------------------------------------------------- /data_processing/edge_detection/batch_hed.py: -------------------------------------------------------------------------------- 1 | # HED batch processing script 2 | # Modified from pix2pix scripts: https://github.com/phillipi/pix2pix/tree/master/scripts/edges 3 | # Notice that this script processes images in batches instead of one-by-one, which is slightly faster. 4 | # However, it will discard the last n images where 0 < n < batch_size. 5 | # Originally modified from https://github.com/s9xie/hed/blob/master/examples/hed/HED-tutorial.ipynb 6 | # Step 1: download the hed repo: https://github.com/s9xie/hed 7 | # Step 2: download the models and protoxt, and put them under {caffe_root}/examples/hed/ 8 | # Step 3: put this script under {caffe_root}/examples/hed/ 9 | # Step 4: run the following script: 10 | # python batch_hed.py --images_dir=/data/to/path/photos/ --hed_mat_dir=/data/to/path/hed_mat_files/ 11 | # Step 5: run the MATLAB post-processing script "PostprocessHED.m" 12 | # The code sometimes crashes after computation is done. Error looks like "Check failed: ... driver shutting down". You can just kill the job. 13 | # For large images, it will produce gpu memory issue. Therefore, you better resize the images before running this script. 
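# Descriptive note on what the script below does: each image is read with OpenCV,
# resized to 256x256, padded with a reflective border of `--border` pixels, and
# mean-subtracted before being fed to HED in batches of `batch_size` (hard-coded to 8
# below). Because a batch is only flushed once it is full, the trailing images that do
# not fill a batch are skipped, as noted above. For example, assuming the default
# directory layout from the arguments below:
#   python batch_hed.py --caffe_root=../hed-master \
#       --images_dir=../flickr_coco/airplane \
#       --hed_mat_dir=../flickr_hed_coco/mat/airplane \
#       --border=128 --gpu_id=0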
14 | 15 | import numpy as np 16 | import scipy.misc 17 | from PIL import Image 18 | import matplotlib.pyplot as plt 19 | import matplotlib.pylab as pylab 20 | import matplotlib.cm as cm 21 | import scipy.io 22 | import os 23 | import cv2 24 | import argparse 25 | from time import time 26 | 27 | 28 | def parse_args(): 29 | parser = argparse.ArgumentParser(description='batch proccesing: photos->edges') 30 | parser.add_argument('--caffe_root', dest='caffe_root', help='caffe root', 31 | default='../hed-master', type=str) 32 | parser.add_argument('--caffemodel', dest='caffemodel', help='caffemodel', 33 | default='../hed-master/examples/hed/hed_pretrained_bsds.caffemodel', 34 | type=str) 35 | parser.add_argument('--prototxt', dest='prototxt', help='caffe prototxt file', 36 | default='../hed-master/examples/hed/deploy.prototxt', type=str) 37 | parser.add_argument('--images_dir', dest='images_dir', 38 | default='../flickr_coco/airplane', 39 | help='directory to store input photos', type=str) 40 | parser.add_argument('--hed_mat_dir', dest='hed_mat_dir', default='../flickr_hed_coco/mat/airplane', 41 | help='directory to store output hed edges in mat file', 42 | type=str) 43 | parser.add_argument('--border', dest='border', help='padding border', type=int, default=128) 44 | parser.add_argument('--gpu_id', dest='gpu_id', help='gpu id', type=int, default=0) 45 | args = parser.parse_args() 46 | return args 47 | 48 | 49 | args = parse_args() 50 | for arg in vars(args): 51 | print('[%s] =' % arg, getattr(args, arg)) 52 | # Make sure that caffe is on the python path: 53 | caffe_root = args.caffe_root # this file is expected to be in {caffe_root}/examples/hed/ 54 | import sys 55 | sys.path.insert(0, os.path.join(caffe_root, 'python')) 56 | 57 | # Suppress output 58 | os.environ['GLOG_minloglevel'] = '2' 59 | 60 | import caffe 61 | import scipy.io as sio 62 | 63 | if not os.path.exists(args.hed_mat_dir): 64 | print('create output directory %s' % args.hed_mat_dir) 65 | os.makedirs(args.hed_mat_dir) 66 | 67 | imgList = os.listdir(args.images_dir) 68 | nImgs = len(imgList) 69 | print('#images = %d' % nImgs) 70 | 71 | caffe.set_mode_gpu() 72 | caffe.set_device(args.gpu_id) 73 | # load net 74 | net = caffe.Net(args.prototxt, args.caffemodel, caffe.TEST) 75 | # pad border 76 | border = args.border 77 | 78 | # Time counter 79 | prev_time = float("-inf") 80 | curr_time = float("-inf") 81 | image_list = [] 82 | batch_size = 8 83 | 84 | for i in range(nImgs): 85 | if i % 500 == 0: 86 | print('processing image %d/%d' % (i, nImgs)) 87 | curr_time = time() 88 | elapsed = curr_time - prev_time 89 | print( 90 | "Now at iteration %d. Elapsed time: %.5fs. Average time: %.5fs/iter" % (i, elapsed, elapsed / 500.)) 91 | prev_time = curr_time 92 | 93 | if i % batch_size == 0 and i > 0: 94 | in_ = np.concatenate(image_list, axis=0) 95 | 96 | # shape for input (data blob is N x C x H x W), set data 97 | net.blobs['data'].reshape(*in_.shape) 98 | net.blobs['data'].data[...] 
= in_ 99 | # run net and take argmax for prediction 100 | net.forward() 101 | fuse = np.squeeze(net.blobs['sigmoid-fuse'].data) 102 | # get rid of the border 103 | fuse = fuse[:, border:-border, border:-border] 104 | 105 | for j in range(batch_size): 106 | # save hed file to the disk 107 | name, ext = os.path.splitext(imgList[i - batch_size + j]) 108 | sio.savemat(os.path.join(args.hed_mat_dir, name + '.mat'), {'predict': fuse[j]}) 109 | 110 | image_list = [] 111 | 112 | im = cv2.imread(os.path.join(args.images_dir, imgList[i]), cv2.IMREAD_COLOR) 113 | im = cv2.resize(im, (256, 256), interpolation=cv2.INTER_AREA) 114 | 115 | in_ = im.astype(np.float32) 116 | in_ = np.pad(in_, ((border, border), (border, border), (0, 0)), 'reflect') 117 | 118 | in_ = in_[:, :, ::-1] 119 | in_ -= np.array((104.00698793, 116.66876762, 122.67891434)) 120 | in_ = np.expand_dims(in_.transpose((2, 0, 1)), axis=0) 121 | 122 | image_list.append(in_) 123 | -------------------------------------------------------------------------------- /data_processing/extract_images.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import numpy as np 4 | import tensorflow as tf 5 | 6 | datafile_path = "../flickr_output" 7 | image_output_path = "../extract_output/images" 8 | edgemap_output_path = "../extract_output/edges" 9 | 10 | 11 | def get_paired_input(filenames): 12 | filename_queue = tf.train.string_input_producer(filenames, capacity=512, shuffle=False, num_epochs=1) 13 | reader = tf.TFRecordReader() 14 | 15 | _, serialized_example = reader.read(filename_queue) 16 | 17 | features = tf.parse_single_example( 18 | serialized_example, 19 | features={ 20 | 'ImageNetID': tf.FixedLenFeature([], tf.string), 21 | 'SketchID': tf.FixedLenFeature([], tf.int64), 22 | 'Category': tf.FixedLenFeature([], tf.string), 23 | 'CategoryID': tf.FixedLenFeature([], tf.int64), 24 | 'Difficulty': tf.FixedLenFeature([], tf.int64), 25 | 'Stroke_Count': tf.FixedLenFeature([], tf.int64), 26 | 'WrongPose': tf.FixedLenFeature([], tf.int64), 27 | 'Context': tf.FixedLenFeature([], tf.int64), 28 | 'Ambiguous': tf.FixedLenFeature([], tf.int64), 29 | 'Error': tf.FixedLenFeature([], tf.int64), 30 | 'class_id': tf.FixedLenFeature([], tf.int64), 31 | 'is_test': tf.FixedLenFeature([], tf.int64), 32 | 'image_jpeg': tf.FixedLenFeature([], tf.string), 33 | 'sketch_png': tf.FixedLenFeature([], tf.string), 34 | } 35 | ) 36 | 37 | image = features['image_jpeg'] 38 | sketch = features['sketch_png'] 39 | 40 | # Attributes 41 | category = features['Category'] 42 | # Not used 43 | # imagenet_id = features['ImageNetID'] 44 | # sketch_id = features['SketchID'] 45 | # class_id = features['class_id'] 46 | # is_test = features['is_test'] 47 | # Stroke_Count = features['Stroke_Count'] 48 | # Difficulty = features['Difficulty'] 49 | # CategoryID = features['CategoryID'] 50 | # WrongPose = features['WrongPose'] 51 | # Context = features['Context'] 52 | # Ambiguous = features['Ambiguous'] 53 | # Error = features['Error'] 54 | 55 | return image, sketch, category 56 | 57 | 58 | def build_queue(filenames, batch_size, capacity=1024): 59 | image, sketch, category = get_paired_input(filenames) 60 | 61 | images, sketchs, categories = tf.train.batch( 62 | [image, sketch, category], 63 | batch_size=1, capacity=capacity, num_threads=2, allow_smaller_final_batch=True) 64 | 65 | return images, sketchs, categories 66 | 67 | 68 | def extract_images(class_name): 69 | filenames = sorted([os.path.join(datafile_path, f) for f in 
os.listdir(datafile_path) 70 | if os.path.isfile(os.path.join(datafile_path, f)) and f.startswith(class_name)]) 71 | 72 | # Make dirs 73 | this_image_path = os.path.join(image_output_path, class_name) 74 | this_edgemap_path = os.path.join(edgemap_output_path, class_name) 75 | if not os.path.isdir(image_output_path) and not os.path.exists(image_output_path): 76 | os.makedirs(image_output_path) 77 | if not os.path.isdir(this_image_path) and not os.path.exists(this_image_path): 78 | os.makedirs(this_image_path) 79 | if not os.path.isdir(edgemap_output_path) and not os.path.exists(edgemap_output_path): 80 | os.makedirs(edgemap_output_path) 81 | if not os.path.isdir(this_edgemap_path) and not os.path.exists(this_edgemap_path): 82 | os.makedirs(this_edgemap_path) 83 | 84 | # Read tfrecords 85 | images, sketchs, categories = build_queue(filenames, 64) 86 | 87 | with tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) as sess: 88 | sess.run(tf.global_variables_initializer()) 89 | sess.run(tf.local_variables_initializer()) 90 | 91 | counter = 0 92 | 93 | coord = tf.train.Coordinator() 94 | threads = tf.train.start_queue_runners(sess=sess, coord=coord) 95 | 96 | while True: 97 | try: 98 | raw_jpeg_data, raw_png_data, category_names = sess.run( 99 | [images, sketchs, categories]) 100 | filename_appendix = "_%08d" % counter 101 | with open(os.path.join(this_image_path, class_name + filename_appendix + '.jpg'), 'wb') as f: 102 | f.write(raw_jpeg_data[0]) 103 | with open(os.path.join(this_edgemap_path, class_name + filename_appendix + '.png'), 'wb') as f: 104 | f.write(raw_png_data[0]) 105 | 106 | counter += 1 107 | except Exception as e: 108 | print(e.args) 109 | break 110 | 111 | if counter % 100 == 0: 112 | print("Now at iteration %d." % counter) 113 | 114 | coord.request_stop() 115 | coord.join(threads) 116 | print() 117 | 118 | 119 | if __name__ == "__main__": 120 | # class_name = "airplane" 121 | filenames = sorted([f for f in os.listdir(datafile_path) if os.path.isfile(os.path.join(datafile_path, f))]) 122 | class_names = sorted(list({f.replace('_', '.').split('.', 1)[0] for f in filenames})) 123 | print('Num of classes found: %d' % len(class_names)) 124 | 125 | for cls in class_names: 126 | extract_images(cls) 127 | -------------------------------------------------------------------------------- /data_processing/flickr_crawler.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import json 3 | import os 4 | from time import time 5 | from six.moves.urllib.parse import urlparse 6 | from datetime import date, timedelta 7 | 8 | from icrawler.downloader import ImageDownloader 9 | from icrawler.builtin import FlickrImageCrawler 10 | from icrawler.builtin.flickr import FlickrParser 11 | 12 | 13 | # You need a Flickr API key to make this script work. Also note that Flickr has a rate restriction 14 | # so you cannot crawl too fast. 
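# The API key is supplied as the first positional argument of FlickrImageCrawler
# inside crawl() below (it is currently an empty string). Each class is fetched in 28
# upload-date windows of `delta` days counting back from TODAY, so a single class
# already triggers many paged API requests; keep `max_num`, `per_page` and the thread
# counts modest if you run into Flickr's rate limit.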
15 | 16 | 17 | # Override default icrawler classes 18 | class MyFlickrParser(FlickrParser): 19 | 20 | def parse(self, response, apikey, size_preference=None): 21 | content = json.loads(response.content.decode()) 22 | if content['stat'] != 'ok': 23 | raise ValueError("Status: %s" % content['stat']) 24 | photos = content['photos']['photo'] 25 | print('Num photos: %d' % len(photos)) 26 | for photo in photos: 27 | photo_id = photo['id'] 28 | 29 | if 'url_z' in photo.keys(): 30 | url = photo['url_z'] 31 | elif 'url_n' in photo.keys(): 32 | url = photo['url_n'] 33 | else: 34 | # print('Empty URL!') 35 | continue 36 | 37 | yield dict(file_url=url, meta=photo) 38 | 39 | 40 | class MyImageDownloader(ImageDownloader): 41 | def get_filename(self, task, default_ext): 42 | """Set the path where the image will be saved. 43 | 44 | The default strategy is to use an increasing 6-digit number as 45 | the filename. You can override this method if you want to set custom 46 | naming rules. The file extension is kept if it can be obtained from 47 | the url, otherwise ``default_ext`` is used as extension. 48 | 49 | Args:(i + 1) 50 | task (dict): The task dict got from ``task_queue``. 51 | 52 | Output: 53 | Filename with extension. 54 | """ 55 | url_path = urlparse(task['file_url'])[2] 56 | extension = url_path.split('.')[-1] if '.' in url_path else default_ext 57 | filename = url_path.split('.')[0].split('/')[-1].split('_')[0] 58 | # file_idx = self.fetched_num + self.file_idx_offset 59 | return '{}.{}'.format(filename, extension) 60 | 61 | 62 | TODAY = date(2018, 4, 21) 63 | delta = timedelta(days=5 * 365/12) # go back 5 years 64 | output_path = '../flickr_output' 65 | 66 | 67 | # Main method 68 | def crawl(crawl_list, work_list): 69 | for i in range(len(crawl_list)): 70 | class_name = crawl_list[i] 71 | print("Now fetching class: %s" % class_name) 72 | output_dir = os.path.join(output_path, class_name) 73 | if not os.path.exists(output_dir): 74 | os.mkdir(output_dir) 75 | flickr_crawler = FlickrImageCrawler('', # put your Flickr API key here 76 | feeder_threads=2, parser_threads=10, downloader_threads=5, 77 | parser_cls=MyFlickrParser, 78 | downloader_cls=MyImageDownloader, 79 | storage={'root_dir': output_dir}, 80 | log_level=logging.ERROR) 81 | 82 | # Time counter 83 | prev_time = float("-inf") 84 | curr_time = float("-inf") 85 | for i in range(28): 86 | curr_time = time() 87 | elapsed = curr_time - prev_time 88 | print( 89 | "Now at iteration %d. Elapsed time: %.5fs." 
% (i, elapsed)) 90 | prev_time = curr_time 91 | flickr_crawler.crawl(max_num=4000, text=class_name, sort='relevance', per_page=500, 92 | min_upload_date=TODAY - (i+1) * delta, max_upload_date=TODAY - i * delta, 93 | extras='url_n,url_z,original_format,path_alias') 94 | 95 | work_list.append(class_name) 96 | if i >= len(crawl_list) - 1: 97 | work_list.append('end') 98 | 99 | 100 | if __name__ == '__main__': 101 | 102 | print(TODAY - 28 * delta) 103 | 104 | crawl_list = [] 105 | for class_name in ['car', 'clock']: 106 | crawl_list.append(class_name) 107 | 108 | crawl(crawl_list, []) 109 | -------------------------------------------------------------------------------- /data_processing/flickr_filter.py: -------------------------------------------------------------------------------- 1 | import imghdr 2 | import itertools 3 | import os 4 | import sys 5 | from time import time 6 | import csv 7 | import PIL.Image as im 8 | import numpy as np 9 | import scipy.io 10 | import scipy.misc as spm 11 | 12 | sys.path.append('..') 13 | sys.path.append('../slim') 14 | sys.path.append('../object_detection') 15 | # Notice: you need to clone TF-slim and Tensorflow Object Detection API 16 | # into data_processing: 17 | # https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/slim 18 | # https://github.com/tensorflow/models/tree/master/research/object_detection 19 | 20 | import cv2 21 | import coco_data_provider as coco 22 | import tensorflow as tf 23 | from slim.nets import nets_factory 24 | from slim.preprocessing import preprocessing_factory 25 | 26 | from object_detection.utils import label_map_util 27 | 28 | 29 | inception_ckpt_path = '../../inception_resnet_v2/inception_resnet_v2_2016_08_30.ckpt' 30 | 31 | slim = tf.contrib.slim 32 | tf.logging.set_verbosity(tf.logging.INFO) 33 | 34 | 35 | def load_image_into_numpy_array(image): 36 | (im_width, im_height) = image.size 37 | return np.array(image.getdata()).reshape( 38 | (im_height, im_width, 3)).astype(np.uint8), im_width, im_height 39 | 40 | 41 | def get_imagenet_class_labels(): 42 | synset_list = [s.strip() for s in open('./imagenet_lsvrc_2015_synsets.txt', 'r').readlines()] 43 | num_synsets_in_ilsvrc = len(synset_list) 44 | assert num_synsets_in_ilsvrc == 1000 45 | 46 | synset_to_human_list = open('./imagenet_metadata.txt', 'r').readlines() 47 | num_synsets_in_all_imagenet = len(synset_to_human_list) 48 | assert num_synsets_in_all_imagenet == 21842 49 | 50 | synset_to_human = {} 51 | for s in synset_to_human_list: 52 | parts = s.strip().split('\t') 53 | assert len(parts) == 2 54 | synset = parts[0] 55 | human = parts[1] 56 | synset_to_human[synset] = human 57 | 58 | label_index = 1 59 | labels_to_names = {0: 'background'} 60 | for synset in synset_list: 61 | name = synset_to_human[synset] 62 | labels_to_names[label_index] = name 63 | label_index += 1 64 | 65 | return labels_to_names 66 | 67 | 68 | def check_jpg_vadility_single(path): 69 | if imghdr.what(path) == 'jpg': 70 | return True 71 | return False 72 | 73 | 74 | def check_jpg_vadility(path): 75 | file_list = [f for f in os.listdir(path) if os.path.isfile(os.path.join(path, f))] 76 | invalid_file_list = [] 77 | # Time counter 78 | prev_time = float("-inf") 79 | curr_time = float("-inf") 80 | for i in range(len(file_list)): 81 | if i % 5000 == 0: 82 | curr_time = time() 83 | elapsed = curr_time - prev_time 84 | print( 85 | "Now at iteration %d. Elapsed time: %.5fs." 
% (i, elapsed)) 86 | prev_time = curr_time 87 | print(len(invalid_file_list)) 88 | try: 89 | img = im.open(os.path.join(path, file_list[i])) 90 | format = img.format.lower() 91 | if format != 'jpg' and format != 'jpeg': 92 | raise ValueError 93 | # img.load() 94 | except: 95 | invalid_file_list.append(file_list[i]) 96 | return invalid_file_list 97 | 98 | 99 | def build_imagenet_graph(path): 100 | tf.reset_default_graph() 101 | print(path) 102 | 103 | filename_queue = tf.train.string_input_producer(tf.train.match_filenames_once(path + "/*.jpg"), 104 | num_epochs=1, shuffle=False, capacity=100) 105 | image_reader = tf.WholeFileReader() 106 | image_file_name, image_file = image_reader.read(filename_queue) 107 | 108 | image = tf.image.decode_jpeg(image_file, channels=3, fancy_upscaling=True) 109 | 110 | model_name = 'inception_resnet_v2' 111 | network_fn = nets_factory.get_network_fn(model_name, is_training=False, num_classes=1001) 112 | 113 | preprocessing_name = model_name 114 | image_preprocessing_fn = preprocessing_factory.get_preprocessing(preprocessing_name, is_training=False) 115 | 116 | eval_image_size = network_fn.default_image_size 117 | 118 | image = image_preprocessing_fn(image, eval_image_size, eval_image_size) 119 | 120 | filenames, images = tf.train.batch([image_file_name, image], batch_size=100, num_threads=2, capacity=500) 121 | logits, _ = network_fn(images) 122 | 123 | variables_to_restore = slim.get_variables_to_restore() 124 | predictions = tf.argmax(logits, 1) 125 | 126 | return filenames, logits, predictions, variables_to_restore 127 | 128 | 129 | def filter_by_imagenet(path, cls_name): 130 | labels_dict = get_imagenet_class_labels() 131 | filenames, logits, predictions, variables_to_restore = build_imagenet_graph(path) 132 | saver = tf.train.Saver(variables_to_restore) 133 | output_filename_list = [] 134 | counter = 0 135 | 136 | with tf.Session(config=config) as sess: 137 | sess.run(tf.global_variables_initializer()) 138 | sess.run(tf.local_variables_initializer()) 139 | 140 | saver.restore(sess, inception_ckpt_path) 141 | 142 | coord = tf.train.Coordinator() 143 | threads = tf.train.start_queue_runners(sess=sess, coord=coord) 144 | 145 | while True: 146 | try: 147 | filename_list, logit_array, prediction_list = sess.run([filenames, logits, predictions]) 148 | except Exception as e: 149 | break 150 | 151 | if counter % 5000 == 0: 152 | print("Evaluated %d files" % counter) 153 | print(len(output_filename_list)) 154 | 155 | prediction_dict = {os.path.split(filename)[1]: labels_dict[prediction] for filename, prediction in 156 | zip(filename_list, prediction_list)} 157 | for i, j in prediction_dict.items(): 158 | j = [p.strip() for p in j.lower().split(',')] 159 | if cls_name.lower() in j and len(j) == 1: 160 | output_filename_list.append(i.decode('ascii')) 161 | 162 | counter += 100 163 | 164 | coord.request_stop() 165 | coord.join(threads) 166 | 167 | return output_filename_list 168 | 169 | 170 | # SSD filter for COCO classes in Tensorflow instead of Caffe. 171 | # Not fully functional yet. It will not output filtered filenames. 
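# In its current state, filter_by_coco() below runs the frozen SSD Inception-V2 COCO
# graph over every image, keeps detections of `cls_name` whose score exceeds 0.9, and
# estimates how much of the image those boxes cover, but it never fills or returns
# `output_filename_list`. filter_images() therefore cannot yet use its result to
# remove off-class images for COCO-only classes, unlike filter_by_imagenet(), which
# does return the list of filenames to keep.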
172 | def filter_by_coco(path, cls_name): 173 | TEST_IMAGE_PATHS = [os.path.join(path, f) for f in os.listdir(path) if os.path.isfile(os.path.join(path, f))] 174 | 175 | counter = 0 176 | output_filename_list = [] 177 | 178 | tf.reset_default_graph() 179 | print(path) 180 | 181 | PATH_TO_CKPT = '../ssd_inception_v2/frozen_inference_graph.pb' 182 | PATH_TO_LABELS = os.path.join('../../object_detection/data', 'mscoco_label_map.pbtxt') 183 | NUM_CLASSES = 90 184 | 185 | # Label map 186 | label_map = label_map_util.load_labelmap(PATH_TO_LABELS) 187 | categories = label_map_util.convert_label_map_to_categories(label_map, max_num_classes=NUM_CLASSES, 188 | use_display_name=True) 189 | category_index = label_map_util.create_category_index(categories) 190 | for i in range(80): 191 | if categories[i]['name'] == cls_name: 192 | cls_index = categories[i]['id'] 193 | 194 | # Load graph 195 | detection_graph = tf.Graph() 196 | with detection_graph.as_default(): 197 | # Input queue 198 | filename_queue = tf.train.string_input_producer(tf.train.match_filenames_once(path + "/*.jpg"), 199 | num_epochs=1, shuffle=False, capacity=100) 200 | image_reader = tf.WholeFileReader() 201 | image_file_name, image_file = image_reader.read(filename_queue) 202 | 203 | image = tf.image.decode_jpeg(image_file, channels=3, fancy_upscaling=True) 204 | image0 = tf.image.resize_image_with_crop_or_pad(image, 500, 500) 205 | image = tf.image.resize_images(image0, [250, 250], method=tf.image.ResizeMethod.BILINEAR) 206 | image = tf.cast(image, tf.uint8) 207 | 208 | filenames, images = tf.train.batch([image_file_name, image], batch_size=20, num_threads=2, capacity=500) 209 | 210 | # Graph Def 211 | od_graph_def = tf.GraphDef() 212 | with tf.gfile.GFile(PATH_TO_CKPT, 'rb') as fid: 213 | serialized_graph = fid.read() 214 | od_graph_def.ParseFromString(serialized_graph) 215 | tf.import_graph_def(od_graph_def, name='', input_map={'image_tensor:0': images}) 216 | 217 | # Time counter 218 | prev_time = float("-inf") 219 | curr_time = float("-inf") 220 | 221 | with detection_graph.as_default(): 222 | with tf.Session(graph=detection_graph, config=config) as sess: 223 | sess.run(tf.global_variables_initializer()) 224 | sess.run(tf.local_variables_initializer()) 225 | 226 | coord = tf.train.Coordinator() 227 | threads = tf.train.start_queue_runners(sess=sess, coord=coord) 228 | 229 | # Definite input and output Tensors for detection_graph 230 | image_tensor = detection_graph.get_tensor_by_name('image_tensor:0') 231 | # Each box represents a part of the image where a particular object was detected. 232 | detection_boxes = detection_graph.get_tensor_by_name('detection_boxes:0') 233 | # Each score represent how level of confidence for each of the objects. 234 | # Score is shown on the result image, together with the class label. 235 | detection_scores = detection_graph.get_tensor_by_name('detection_scores:0') 236 | detection_classes = detection_graph.get_tensor_by_name('detection_classes:0') 237 | num_detections = detection_graph.get_tensor_by_name('num_detections:0') 238 | for image_path in TEST_IMAGE_PATHS: 239 | if counter % 5 == 0: 240 | curr_time = time() 241 | elapsed = curr_time - prev_time 242 | print( 243 | "Now at iteration %d. Elapsed time: %.5fs." % (counter, elapsed)) 244 | prev_time = curr_time 245 | 246 | image = im.open(image_path) 247 | # the array based representation of the image will be used later in order to prepare the 248 | # result image with boxes and labels on it. 
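                # The PIL image is converted to a numpy array, downscaled by 50% with
                # scipy.misc.imresize, and expanded to a [1, H, W, 3] batch before
                # being fed to the detector through `image_tensor`.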
249 | image_np, im_width, im_height = load_image_into_numpy_array(image) 250 | image_np = scipy.misc.imresize(image_np, 0.5, 'bilinear') 251 | # Expand dimensions since the model expects images to have shape: [1, None, None, 3] 252 | image_np_expanded = np.expand_dims(image_np, axis=0) 253 | # Actual detection. 254 | (boxes, scores, classes, num) = sess.run( 255 | [detection_boxes, detection_scores, detection_classes, num_detections], 256 | feed_dict={image_tensor: image_np_expanded}) 257 | 258 | # Filter results 259 | boxes = np.squeeze(boxes) 260 | classes = np.squeeze(classes).astype(np.int32), 261 | scores = np.squeeze(scores) 262 | idx = np.logical_and(scores > 0.9, classes == cls_index) 263 | portion = np.prod(boxes[idx], axis=1) / (im_width * im_height) 264 | 265 | if portion.size > 0: 266 | print() 267 | 268 | counter += 1 269 | 270 | 271 | def filter_images(flickr_dir, cls_name): 272 | imagenet_classes = [i[:-1] for i in open('./imagenet_share_classes.txt').readlines()] 273 | coco_classes = [i[:-1] for i in open('./coco_share_classes.txt').readlines()] 274 | 275 | this_dir = os.path.join(flickr_dir, cls_name) 276 | 277 | invalid_file_list = [] 278 | print("Invalid file number: %d" % len(invalid_file_list)) 279 | for file_name in invalid_file_list: 280 | os.remove(os.path.join(this_dir, file_name)) 281 | 282 | if cls_name in imagenet_classes: 283 | output_filename_list = filter_by_imagenet(this_dir, cls_name) 284 | elif cls_name in coco_classes: 285 | output_filename_list = filter_by_coco(this_dir, cls_name) 286 | else: 287 | raise NotImplementedError 288 | 289 | file_list = [f for f in os.listdir(this_dir) if os.path.isfile(os.path.join(this_dir, f))] 290 | print(len(file_list) - len(output_filename_list)) 291 | for file_name in file_list: 292 | if file_name not in output_filename_list: 293 | os.remove(os.path.join(this_dir, file_name)) 294 | 295 | 296 | config = tf.ConfigProto(allow_soft_placement=True, log_device_placement=False, 297 | intra_op_parallelism_threads=4) 298 | config.gpu_options.allow_growth = True 299 | config.gpu_options.per_process_gpu_memory_fraction = 0.9 300 | 301 | with open('./imagenet_share_classes.txt', 'r') as f: 302 | classes_list = [i[:-1] for i in f.readlines()] 303 | 304 | if __name__ == '__main__': 305 | # # Imagenet 306 | # labels_dict = get_imagenet_class_labels() 307 | # labels_list = [label + '\n' for i, label in labels_dict.items()] 308 | # with open('./imagenet_classes.txt', 'w') as f: 309 | # f.writelines(labels_list) 310 | 311 | # COCO 312 | labels_list = [cls['name'] for cls in coco.get_all_images_data_categories(split='train')[2]] 313 | 314 | # with open('./all_classes', 'r') as f: 315 | # sketchy_class_list = [i[:-1] for i in f.readlines()] 316 | # 317 | # filtered_classes = [] 318 | # for cls in sketchy_class_list: 319 | # for large_cls in labels_list: 320 | # large_cls_names = [i.strip() for i in large_cls.split(',')] 321 | # for name in large_cls_names: 322 | # if cls.lower() == name.lower() and cls.lower() not in filtered_classes: 323 | # filtered_classes.append(cls + '\n') 324 | # 325 | # with open('./coco_share_classes.txt', 'w') as f: 326 | # f.writelines(filtered_classes) 327 | 328 | # Inference 329 | filter_range = (8, 12) 330 | 331 | class_list = ['airplane'] 332 | for class_name in class_list: 333 | filter_images('../flickr_output', class_name) 334 | -------------------------------------------------------------------------------- /data_processing/flickr_to_tfrecord.py: 
-------------------------------------------------------------------------------- 1 | import multiprocessing as mp 2 | import os 3 | import sys 4 | import csv 5 | import numpy as np 6 | # import scipy.io 7 | # import scipy.misc as spm 8 | 9 | import cv2 10 | from scipy import ndimage 11 | import tensorflow as tf 12 | # from tensorflow.python.framework import ops 13 | 14 | 15 | np.seterr(all='raise') 16 | 17 | 18 | def showImg(img): 19 | cv2.imshow("test", img) 20 | cv2.waitKey(-1) 21 | 22 | 23 | def dense_to_one_hot(labels_dense, num_classes): 24 | """Convert class labels from scalars to one-hot vectors.""" 25 | num_labels = labels_dense.shape[0] 26 | index_offset = np.arange(num_labels) * num_classes 27 | labels_one_hot = np.zeros((num_labels, num_classes), dtype=np.int32) 28 | labels_one_hot.flat[index_offset + labels_dense.ravel()] = 1 29 | return labels_one_hot 30 | 31 | 32 | def _bytes_feature(value): 33 | return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value])) 34 | 35 | 36 | def _int64_feature(value): 37 | return tf.train.Feature(int64_list=tf.train.Int64List(value=[value])) 38 | 39 | 40 | valid_class_names = ['car_(sedan)'] # Class to convert 41 | 42 | classes_info = '../data_processing/classes.csv' 43 | photo_folder = '../flickr_coco' 44 | sketch_folder = '../flickr_hed/jpg' 45 | data_dir = '../flickr_output' 46 | 47 | config = tf.ConfigProto(allow_soft_placement=True, log_device_placement=False, 48 | intra_op_parallelism_threads=8) 49 | 50 | 51 | def check_repeat(seq): 52 | seen = set() 53 | seen_add = seen.add 54 | seen_twice = set(x for x in seq if x in seen or seen_add(x)) 55 | return list(seen_twice) 56 | 57 | 58 | def read_csv(filename): 59 | with open(filename) as csvfile: 60 | reader = csv.DictReader(csvfile) 61 | l = list(reader) 62 | 63 | return l 64 | 65 | 66 | def read_txt(filename): 67 | with open(filename) as txtfile: 68 | lines = txtfile.readlines() 69 | return [l[:-1] for l in lines] 70 | 71 | 72 | def build_graph(): 73 | photo_filename = tf.placeholder(dtype=tf.string, shape=()) 74 | label_filename = tf.placeholder(dtype=tf.string, shape=()) 75 | photo = tf.read_file(photo_filename) 76 | label = tf.read_file(label_filename) 77 | photo_decoded = tf.image.decode_jpeg(photo, channels=3, fancy_upscaling=True) 78 | label_decoded = tf.image.decode_png(label) 79 | 80 | # Encode 64x64 81 | photo_input = tf.placeholder(dtype=tf.uint8, shape=(256, 256, 3)) 82 | photo_small_input = tf.placeholder(dtype=tf.uint8, shape=(64, 64, 3)) 83 | label_input = tf.placeholder(dtype=tf.uint8, shape=(256, 256, 1)) 84 | label_small_input = tf.placeholder(dtype=tf.uint8, shape=(64, 64, 1)) 85 | 86 | photo_stream = tf.image.encode_jpeg(photo_input, quality=95, progressive=False, 87 | optimize_size=False, chroma_downsampling=False) 88 | photo_small_stream = tf.image.encode_jpeg(photo_small_input, quality=95, progressive=False, 89 | optimize_size=False, chroma_downsampling=False) 90 | label_stream = tf.image.encode_png(label_input, compression=7) 91 | label_small_stream = tf.image.encode_png(label_small_input, compression=7) 92 | 93 | return photo_filename, label_filename, photo, label, photo_decoded, label_decoded,\ 94 | photo_input, photo_small_input, label_input, label_small_input, photo_stream, photo_small_stream,\ 95 | label_stream, label_small_stream 96 | 97 | 98 | def split_csvlist(stat_info): 99 | cat = list(set([item['Category'] for item in stat_info])) 100 | l = [] 101 | for c in cat: 102 | li = [item for item in stat_info if item['Category'] == c] 103 | l.append(li) 104 
| 105 | return cat, l 106 | 107 | 108 | def binarize(sketch, threshold=245): 109 | sketch[sketch < threshold] = 0 110 | sketch[sketch >= threshold] = 255 111 | return sketch 112 | 113 | 114 | def write_image_data(): 115 | dir_list = [d for d in os.listdir(sketch_folder) if os.path.isdir(os.path.join(sketch_folder, d))] 116 | classes = read_csv(classes_info) 117 | classes_ids = [item['Name'] for item in classes] 118 | work_list = [] 119 | 120 | for dir in dir_list: 121 | if dir not in valid_class_names: 122 | continue 123 | this_sketch_folder = os.path.join(sketch_folder, dir) 124 | this_photo_folder = os.path.join(photo_folder, dir) 125 | sketch_files_list = [f for f in os.listdir(this_sketch_folder) 126 | if os.path.isfile(os.path.join(this_sketch_folder, f))] 127 | photo_files_list = [f for f in os.listdir(this_photo_folder) 128 | if os.path.isfile(os.path.join(this_photo_folder, f)) and 129 | os.path.isfile(os.path.join(this_sketch_folder, os.path.splitext(f)[0] + '.png'))] 130 | assert len(photo_files_list) == len(sketch_files_list) 131 | 132 | class_id = classes_ids.index(dir) 133 | work_list.append((class_id, photo_files_list, sketch_files_list)) 134 | 135 | num_processes = 8 136 | 137 | # launch processes 138 | pool = mp.Pool(processes=num_processes) 139 | results = [] 140 | for i in range(len(work_list)): 141 | result = pool.apply_async(write_dir_photo, args=(work_list[i], classes_ids, i)) 142 | results.append(result) 143 | for i in range(len(results)): 144 | results[i].get() 145 | 146 | pool.close() 147 | pool.join() 148 | 149 | 150 | def write_dir_photo(object_item, classes_ids, process_id): 151 | 152 | class_id_num = object_item[0] 153 | class_id = str(class_id_num) 154 | Category = classes_ids[class_id_num] 155 | if Category not in valid_class_names: 156 | return 157 | 158 | photo_files = object_item[1] 159 | sketch_files = object_item[2] 160 | 161 | processed_num = -1 162 | writer = None 163 | file_contain_photo_num = 2048 164 | 165 | path_image = os.path.join(photo_folder, Category) 166 | path_label = os.path.join(sketch_folder, Category) 167 | 168 | with tf.device('/cpu:0'): 169 | photo_filename, label_filename, photo, label, photo_decoded, label_decoded, \ 170 | photo_input, photo_small_input, label_input, label_small_input, photo_stream, photo_small_stream, \ 171 | label_stream, label_small_stream = build_graph() 172 | 173 | with tf.Session(config=config) as sess: 174 | sess.run(tf.global_variables_initializer()) 175 | print('ID %s with num photos: %d' % (class_id, len(photo_files))) 176 | 177 | for i in range(len(photo_files)): 178 | 179 | processed_num += 1 180 | 181 | if processed_num % file_contain_photo_num == 0: 182 | if writer is not None: 183 | writer.close() 184 | print('ID %s current at: %d' % (class_id, i)) 185 | else: 186 | print('Init first writer') 187 | writer = tf.python_io.TFRecordWriter( 188 | os.path.join(data_dir, Category + '_coco_seg_%d.tfrecord' % (processed_num // file_contain_photo_num))) 189 | 190 | cur_photo_path = os.path.join(path_image, photo_files[i]) 191 | label_name = os.path.splitext(photo_files[i])[0] + '.png' 192 | if label_name not in sketch_files: 193 | print('Wrong filename: %s' % photo_files[i]) 194 | cur_label_path = os.path.join(path_label, label_name) 195 | 196 | try: 197 | out_image, out_image_decoded = sess.run([photo, photo_decoded], feed_dict={ 198 | photo_filename: os.path.join(cur_photo_path)}) 199 | out_label, out_label_decoded = sess.run([label, label_decoded], feed_dict={ 200 | label_filename: 
os.path.join(cur_label_path)}) 201 | except: 202 | print('Invalid file') 203 | continue 204 | 205 | # Resize 206 | channel_num = 3. if len(out_label_decoded.shape) == 3 and out_label_decoded.shape[2] == 3 else 1. 207 | out_image_decoded = cv2.resize(out_image_decoded, (256, 256), interpolation=cv2.INTER_AREA) 208 | out_image_decoded_small = cv2.resize(out_image_decoded, (64, 64), interpolation=cv2.INTER_AREA) 209 | out_label_decoded = (np.sum(out_label_decoded.astype(np.float64), axis=2)/channel_num).astype(np.uint8) 210 | out_label_decoded_small = cv2.resize(out_label_decoded, (64, 64), interpolation=cv2.INTER_AREA) 211 | if (out_label_decoded_small == 0).all() or (out_label_decoded_small == 255).all(): 212 | print('Warning: blank sketch from resize') 213 | continue 214 | 215 | # Distance map 216 | out_dist_map = ndimage.distance_transform_edt(binarize(out_label_decoded)) 217 | out_dist_map = (out_dist_map / out_dist_map.max() * 255.).astype(np.uint8) 218 | 219 | out_dist_map_small = ndimage.distance_transform_edt(binarize(out_label_decoded_small)) 220 | out_dist_map_small = (out_dist_map_small / out_dist_map_small.max() * 255.).astype(np.uint8) 221 | 222 | # Stream 223 | image_string, label_string = sess.run([photo_stream, label_stream], feed_dict={ 224 | photo_input: out_image_decoded, label_input: out_label_decoded.reshape((256, 256, 1)) 225 | }) 226 | image_string_small, label_string_small = sess.run([photo_small_stream, label_small_stream], feed_dict={ 227 | photo_small_input: out_image_decoded_small, label_small_input: out_label_decoded_small.reshape((64, 64, 1)) 228 | }) 229 | dist_map_string = sess.run(label_stream, feed_dict={label_input: out_dist_map.reshape((256, 256, 1))}) 230 | dist_map_string_small = sess.run(label_small_stream, feed_dict={ 231 | label_small_input: out_dist_map_small.reshape((64, 64, 1))}) 232 | 233 | example = tf.train.Example(features=tf.train.Features(feature={ 234 | 'ImageNetID': _bytes_feature(''.encode('utf-8')), 235 | 'SketchID': _int64_feature(0), 236 | 'Category': _bytes_feature(Category.encode('utf-8')), 237 | 'CategoryID': _int64_feature(class_id_num), 238 | 'Difficulty': _int64_feature(0), 239 | 'Stroke_Count': _int64_feature(0), 240 | 'WrongPose': _int64_feature(0), 241 | 'Context': _int64_feature(0), 242 | 'Ambiguous': _int64_feature(0), 243 | 'Error': _int64_feature(0), 244 | 'is_test': _int64_feature(0), 245 | 'class_id': _int64_feature(class_id_num), 246 | 'image_jpeg': _bytes_feature(image_string), 247 | 'image_small_jpeg': _bytes_feature(image_string_small), 248 | 'sketch_png': _bytes_feature(label_string), 249 | 'sketch_small_png': _bytes_feature(label_string_small), 250 | 'dist_map_png': _bytes_feature(dist_map_string), 251 | 'dist_map_small_png': _bytes_feature(dist_map_string_small), 252 | })) 253 | writer.write(example.SerializeToString()) 254 | 255 | writer.close() 256 | 257 | 258 | write_image_data() 259 | -------------------------------------------------------------------------------- /data_processing/imagenet_lsvrc_2015_synsets.txt: -------------------------------------------------------------------------------- 1 | n01440764 2 | n01443537 3 | n01484850 4 | n01491361 5 | n01494475 6 | n01496331 7 | n01498041 8 | n01514668 9 | n01514859 10 | n01518878 11 | n01530575 12 | n01531178 13 | n01532829 14 | n01534433 15 | n01537544 16 | n01558993 17 | n01560419 18 | n01580077 19 | n01582220 20 | n01592084 21 | n01601694 22 | n01608432 23 | n01614925 24 | n01616318 25 | n01622779 26 | n01629819 27 | n01630670 28 | n01631663 29 | n01632458 
30 | n01632777 31 | n01641577 32 | n01644373 33 | n01644900 34 | n01664065 35 | n01665541 36 | n01667114 37 | n01667778 38 | n01669191 39 | n01675722 40 | n01677366 41 | n01682714 42 | n01685808 43 | n01687978 44 | n01688243 45 | n01689811 46 | n01692333 47 | n01693334 48 | n01694178 49 | n01695060 50 | n01697457 51 | n01698640 52 | n01704323 53 | n01728572 54 | n01728920 55 | n01729322 56 | n01729977 57 | n01734418 58 | n01735189 59 | n01737021 60 | n01739381 61 | n01740131 62 | n01742172 63 | n01744401 64 | n01748264 65 | n01749939 66 | n01751748 67 | n01753488 68 | n01755581 69 | n01756291 70 | n01768244 71 | n01770081 72 | n01770393 73 | n01773157 74 | n01773549 75 | n01773797 76 | n01774384 77 | n01774750 78 | n01775062 79 | n01776313 80 | n01784675 81 | n01795545 82 | n01796340 83 | n01797886 84 | n01798484 85 | n01806143 86 | n01806567 87 | n01807496 88 | n01817953 89 | n01818515 90 | n01819313 91 | n01820546 92 | n01824575 93 | n01828970 94 | n01829413 95 | n01833805 96 | n01843065 97 | n01843383 98 | n01847000 99 | n01855032 100 | n01855672 101 | n01860187 102 | n01871265 103 | n01872401 104 | n01873310 105 | n01877812 106 | n01882714 107 | n01883070 108 | n01910747 109 | n01914609 110 | n01917289 111 | n01924916 112 | n01930112 113 | n01943899 114 | n01944390 115 | n01945685 116 | n01950731 117 | n01955084 118 | n01968897 119 | n01978287 120 | n01978455 121 | n01980166 122 | n01981276 123 | n01983481 124 | n01984695 125 | n01985128 126 | n01986214 127 | n01990800 128 | n02002556 129 | n02002724 130 | n02006656 131 | n02007558 132 | n02009229 133 | n02009912 134 | n02011460 135 | n02012849 136 | n02013706 137 | n02017213 138 | n02018207 139 | n02018795 140 | n02025239 141 | n02027492 142 | n02028035 143 | n02033041 144 | n02037110 145 | n02051845 146 | n02056570 147 | n02058221 148 | n02066245 149 | n02071294 150 | n02074367 151 | n02077923 152 | n02085620 153 | n02085782 154 | n02085936 155 | n02086079 156 | n02086240 157 | n02086646 158 | n02086910 159 | n02087046 160 | n02087394 161 | n02088094 162 | n02088238 163 | n02088364 164 | n02088466 165 | n02088632 166 | n02089078 167 | n02089867 168 | n02089973 169 | n02090379 170 | n02090622 171 | n02090721 172 | n02091032 173 | n02091134 174 | n02091244 175 | n02091467 176 | n02091635 177 | n02091831 178 | n02092002 179 | n02092339 180 | n02093256 181 | n02093428 182 | n02093647 183 | n02093754 184 | n02093859 185 | n02093991 186 | n02094114 187 | n02094258 188 | n02094433 189 | n02095314 190 | n02095570 191 | n02095889 192 | n02096051 193 | n02096177 194 | n02096294 195 | n02096437 196 | n02096585 197 | n02097047 198 | n02097130 199 | n02097209 200 | n02097298 201 | n02097474 202 | n02097658 203 | n02098105 204 | n02098286 205 | n02098413 206 | n02099267 207 | n02099429 208 | n02099601 209 | n02099712 210 | n02099849 211 | n02100236 212 | n02100583 213 | n02100735 214 | n02100877 215 | n02101006 216 | n02101388 217 | n02101556 218 | n02102040 219 | n02102177 220 | n02102318 221 | n02102480 222 | n02102973 223 | n02104029 224 | n02104365 225 | n02105056 226 | n02105162 227 | n02105251 228 | n02105412 229 | n02105505 230 | n02105641 231 | n02105855 232 | n02106030 233 | n02106166 234 | n02106382 235 | n02106550 236 | n02106662 237 | n02107142 238 | n02107312 239 | n02107574 240 | n02107683 241 | n02107908 242 | n02108000 243 | n02108089 244 | n02108422 245 | n02108551 246 | n02108915 247 | n02109047 248 | n02109525 249 | n02109961 250 | n02110063 251 | n02110185 252 | n02110341 253 | n02110627 254 | n02110806 255 | n02110958 256 | 
n02111129 257 | n02111277 258 | n02111500 259 | n02111889 260 | n02112018 261 | n02112137 262 | n02112350 263 | n02112706 264 | n02113023 265 | n02113186 266 | n02113624 267 | n02113712 268 | n02113799 269 | n02113978 270 | n02114367 271 | n02114548 272 | n02114712 273 | n02114855 274 | n02115641 275 | n02115913 276 | n02116738 277 | n02117135 278 | n02119022 279 | n02119789 280 | n02120079 281 | n02120505 282 | n02123045 283 | n02123159 284 | n02123394 285 | n02123597 286 | n02124075 287 | n02125311 288 | n02127052 289 | n02128385 290 | n02128757 291 | n02128925 292 | n02129165 293 | n02129604 294 | n02130308 295 | n02132136 296 | n02133161 297 | n02134084 298 | n02134418 299 | n02137549 300 | n02138441 301 | n02165105 302 | n02165456 303 | n02167151 304 | n02168699 305 | n02169497 306 | n02172182 307 | n02174001 308 | n02177972 309 | n02190166 310 | n02206856 311 | n02219486 312 | n02226429 313 | n02229544 314 | n02231487 315 | n02233338 316 | n02236044 317 | n02256656 318 | n02259212 319 | n02264363 320 | n02268443 321 | n02268853 322 | n02276258 323 | n02277742 324 | n02279972 325 | n02280649 326 | n02281406 327 | n02281787 328 | n02317335 329 | n02319095 330 | n02321529 331 | n02325366 332 | n02326432 333 | n02328150 334 | n02342885 335 | n02346627 336 | n02356798 337 | n02361337 338 | n02363005 339 | n02364673 340 | n02389026 341 | n02391049 342 | n02395406 343 | n02396427 344 | n02397096 345 | n02398521 346 | n02403003 347 | n02408429 348 | n02410509 349 | n02412080 350 | n02415577 351 | n02417914 352 | n02422106 353 | n02422699 354 | n02423022 355 | n02437312 356 | n02437616 357 | n02441942 358 | n02442845 359 | n02443114 360 | n02443484 361 | n02444819 362 | n02445715 363 | n02447366 364 | n02454379 365 | n02457408 366 | n02480495 367 | n02480855 368 | n02481823 369 | n02483362 370 | n02483708 371 | n02484975 372 | n02486261 373 | n02486410 374 | n02487347 375 | n02488291 376 | n02488702 377 | n02489166 378 | n02490219 379 | n02492035 380 | n02492660 381 | n02493509 382 | n02493793 383 | n02494079 384 | n02497673 385 | n02500267 386 | n02504013 387 | n02504458 388 | n02509815 389 | n02510455 390 | n02514041 391 | n02526121 392 | n02536864 393 | n02606052 394 | n02607072 395 | n02640242 396 | n02641379 397 | n02643566 398 | n02655020 399 | n02666196 400 | n02667093 401 | n02669723 402 | n02672831 403 | n02676566 404 | n02687172 405 | n02690373 406 | n02692877 407 | n02699494 408 | n02701002 409 | n02704792 410 | n02708093 411 | n02727426 412 | n02730930 413 | n02747177 414 | n02749479 415 | n02769748 416 | n02776631 417 | n02777292 418 | n02782093 419 | n02783161 420 | n02786058 421 | n02787622 422 | n02788148 423 | n02790996 424 | n02791124 425 | n02791270 426 | n02793495 427 | n02794156 428 | n02795169 429 | n02797295 430 | n02799071 431 | n02802426 432 | n02804414 433 | n02804610 434 | n02807133 435 | n02808304 436 | n02808440 437 | n02814533 438 | n02814860 439 | n02815834 440 | n02817516 441 | n02823428 442 | n02823750 443 | n02825657 444 | n02834397 445 | n02835271 446 | n02837789 447 | n02840245 448 | n02841315 449 | n02843684 450 | n02859443 451 | n02860847 452 | n02865351 453 | n02869837 454 | n02870880 455 | n02871525 456 | n02877765 457 | n02879718 458 | n02883205 459 | n02892201 460 | n02892767 461 | n02894605 462 | n02895154 463 | n02906734 464 | n02909870 465 | n02910353 466 | n02916936 467 | n02917067 468 | n02927161 469 | n02930766 470 | n02939185 471 | n02948072 472 | n02950826 473 | n02951358 474 | n02951585 475 | n02963159 476 | n02965783 477 | n02966193 478 | 
n02966687 479 | n02971356 480 | n02974003 481 | n02977058 482 | n02978881 483 | n02979186 484 | n02980441 485 | n02981792 486 | n02988304 487 | n02992211 488 | n02992529 489 | n02999410 490 | n03000134 491 | n03000247 492 | n03000684 493 | n03014705 494 | n03016953 495 | n03017168 496 | n03018349 497 | n03026506 498 | n03028079 499 | n03032252 500 | n03041632 501 | n03042490 502 | n03045698 503 | n03047690 504 | n03062245 505 | n03063599 506 | n03063689 507 | n03065424 508 | n03075370 509 | n03085013 510 | n03089624 511 | n03095699 512 | n03100240 513 | n03109150 514 | n03110669 515 | n03124043 516 | n03124170 517 | n03125729 518 | n03126707 519 | n03127747 520 | n03127925 521 | n03131574 522 | n03133878 523 | n03134739 524 | n03141823 525 | n03146219 526 | n03160309 527 | n03179701 528 | n03180011 529 | n03187595 530 | n03188531 531 | n03196217 532 | n03197337 533 | n03201208 534 | n03207743 535 | n03207941 536 | n03208938 537 | n03216828 538 | n03218198 539 | n03220513 540 | n03223299 541 | n03240683 542 | n03249569 543 | n03250847 544 | n03255030 545 | n03259280 546 | n03271574 547 | n03272010 548 | n03272562 549 | n03290653 550 | n03291819 551 | n03297495 552 | n03314780 553 | n03325584 554 | n03337140 555 | n03344393 556 | n03345487 557 | n03347037 558 | n03355925 559 | n03372029 560 | n03376595 561 | n03379051 562 | n03384352 563 | n03388043 564 | n03388183 565 | n03388549 566 | n03393912 567 | n03394916 568 | n03400231 569 | n03404251 570 | n03417042 571 | n03424325 572 | n03425413 573 | n03443371 574 | n03444034 575 | n03445777 576 | n03445924 577 | n03447447 578 | n03447721 579 | n03450230 580 | n03452741 581 | n03457902 582 | n03459775 583 | n03461385 584 | n03467068 585 | n03476684 586 | n03476991 587 | n03478589 588 | n03481172 589 | n03482405 590 | n03483316 591 | n03485407 592 | n03485794 593 | n03492542 594 | n03494278 595 | n03495258 596 | n03496892 597 | n03498962 598 | n03527444 599 | n03529860 600 | n03530642 601 | n03532672 602 | n03534580 603 | n03535780 604 | n03538406 605 | n03544143 606 | n03584254 607 | n03584829 608 | n03590841 609 | n03594734 610 | n03594945 611 | n03595614 612 | n03598930 613 | n03599486 614 | n03602883 615 | n03617480 616 | n03623198 617 | n03627232 618 | n03630383 619 | n03633091 620 | n03637318 621 | n03642806 622 | n03649909 623 | n03657121 624 | n03658185 625 | n03661043 626 | n03662601 627 | n03666591 628 | n03670208 629 | n03673027 630 | n03676483 631 | n03680355 632 | n03690938 633 | n03691459 634 | n03692522 635 | n03697007 636 | n03706229 637 | n03709823 638 | n03710193 639 | n03710637 640 | n03710721 641 | n03717622 642 | n03720891 643 | n03721384 644 | n03724870 645 | n03729826 646 | n03733131 647 | n03733281 648 | n03733805 649 | n03742115 650 | n03743016 651 | n03759954 652 | n03761084 653 | n03763968 654 | n03764736 655 | n03769881 656 | n03770439 657 | n03770679 658 | n03773504 659 | n03775071 660 | n03775546 661 | n03776460 662 | n03777568 663 | n03777754 664 | n03781244 665 | n03782006 666 | n03785016 667 | n03786901 668 | n03787032 669 | n03788195 670 | n03788365 671 | n03791053 672 | n03792782 673 | n03792972 674 | n03793489 675 | n03794056 676 | n03796401 677 | n03803284 678 | n03804744 679 | n03814639 680 | n03814906 681 | n03825788 682 | n03832673 683 | n03837869 684 | n03838899 685 | n03840681 686 | n03841143 687 | n03843555 688 | n03854065 689 | n03857828 690 | n03866082 691 | n03868242 692 | n03868863 693 | n03871628 694 | n03873416 695 | n03874293 696 | n03874599 697 | n03876231 698 | n03877472 699 | n03877845 700 | 
n03884397 701 | n03887697 702 | n03888257 703 | n03888605 704 | n03891251 705 | n03891332 706 | n03895866 707 | n03899768 708 | n03902125 709 | n03903868 710 | n03908618 711 | n03908714 712 | n03916031 713 | n03920288 714 | n03924679 715 | n03929660 716 | n03929855 717 | n03930313 718 | n03930630 719 | n03933933 720 | n03935335 721 | n03937543 722 | n03938244 723 | n03942813 724 | n03944341 725 | n03947888 726 | n03950228 727 | n03954731 728 | n03956157 729 | n03958227 730 | n03961711 731 | n03967562 732 | n03970156 733 | n03976467 734 | n03976657 735 | n03977966 736 | n03980874 737 | n03982430 738 | n03983396 739 | n03991062 740 | n03992509 741 | n03995372 742 | n03998194 743 | n04004767 744 | n04005630 745 | n04008634 746 | n04009552 747 | n04019541 748 | n04023962 749 | n04026417 750 | n04033901 751 | n04033995 752 | n04037443 753 | n04039381 754 | n04040759 755 | n04041544 756 | n04044716 757 | n04049303 758 | n04065272 759 | n04067472 760 | n04069434 761 | n04070727 762 | n04074963 763 | n04081281 764 | n04086273 765 | n04090263 766 | n04099969 767 | n04111531 768 | n04116512 769 | n04118538 770 | n04118776 771 | n04120489 772 | n04125021 773 | n04127249 774 | n04131690 775 | n04133789 776 | n04136333 777 | n04141076 778 | n04141327 779 | n04141975 780 | n04146614 781 | n04147183 782 | n04149813 783 | n04152593 784 | n04153751 785 | n04154565 786 | n04162706 787 | n04179913 788 | n04192698 789 | n04200800 790 | n04201297 791 | n04204238 792 | n04204347 793 | n04208210 794 | n04209133 795 | n04209239 796 | n04228054 797 | n04229816 798 | n04235860 799 | n04238763 800 | n04239074 801 | n04243546 802 | n04251144 803 | n04252077 804 | n04252225 805 | n04254120 806 | n04254680 807 | n04254777 808 | n04258138 809 | n04259630 810 | n04263257 811 | n04264628 812 | n04265275 813 | n04266014 814 | n04270147 815 | n04273569 816 | n04275548 817 | n04277352 818 | n04285008 819 | n04286575 820 | n04296562 821 | n04310018 822 | n04311004 823 | n04311174 824 | n04317175 825 | n04325704 826 | n04326547 827 | n04328186 828 | n04330267 829 | n04332243 830 | n04335435 831 | n04336792 832 | n04344873 833 | n04346328 834 | n04347754 835 | n04350905 836 | n04355338 837 | n04355933 838 | n04356056 839 | n04357314 840 | n04366367 841 | n04367480 842 | n04370456 843 | n04371430 844 | n04371774 845 | n04372370 846 | n04376876 847 | n04380533 848 | n04389033 849 | n04392985 850 | n04398044 851 | n04399382 852 | n04404412 853 | n04409515 854 | n04417672 855 | n04418357 856 | n04423845 857 | n04428191 858 | n04429376 859 | n04435653 860 | n04442312 861 | n04443257 862 | n04447861 863 | n04456115 864 | n04458633 865 | n04461696 866 | n04462240 867 | n04465501 868 | n04467665 869 | n04476259 870 | n04479046 871 | n04482393 872 | n04483307 873 | n04485082 874 | n04486054 875 | n04487081 876 | n04487394 877 | n04493381 878 | n04501370 879 | n04505470 880 | n04507155 881 | n04509417 882 | n04515003 883 | n04517823 884 | n04522168 885 | n04523525 886 | n04525038 887 | n04525305 888 | n04532106 889 | n04532670 890 | n04536866 891 | n04540053 892 | n04542943 893 | n04548280 894 | n04548362 895 | n04550184 896 | n04552348 897 | n04553703 898 | n04554684 899 | n04557648 900 | n04560804 901 | n04562935 902 | n04579145 903 | n04579432 904 | n04584207 905 | n04589890 906 | n04590129 907 | n04591157 908 | n04591713 909 | n04592741 910 | n04596742 911 | n04597913 912 | n04599235 913 | n04604644 914 | n04606251 915 | n04612504 916 | n04613696 917 | n06359193 918 | n06596364 919 | n06785654 920 | n06794110 921 | n06874185 922 | 
n07248320 923 | n07565083 924 | n07579787 925 | n07583066 926 | n07584110 927 | n07590611 928 | n07613480 929 | n07614500 930 | n07615774 931 | n07684084 932 | n07693725 933 | n07695742 934 | n07697313 935 | n07697537 936 | n07711569 937 | n07714571 938 | n07714990 939 | n07715103 940 | n07716358 941 | n07716906 942 | n07717410 943 | n07717556 944 | n07718472 945 | n07718747 946 | n07720875 947 | n07730033 948 | n07734744 949 | n07742313 950 | n07745940 951 | n07747607 952 | n07749582 953 | n07753113 954 | n07753275 955 | n07753592 956 | n07754684 957 | n07760859 958 | n07768694 959 | n07802026 960 | n07831146 961 | n07836838 962 | n07860988 963 | n07871810 964 | n07873807 965 | n07875152 966 | n07880968 967 | n07892512 968 | n07920052 969 | n07930864 970 | n07932039 971 | n09193705 972 | n09229709 973 | n09246464 974 | n09256479 975 | n09288635 976 | n09332890 977 | n09399592 978 | n09421951 979 | n09428293 980 | n09468604 981 | n09472597 982 | n09835506 983 | n10148035 984 | n10565667 985 | n11879895 986 | n11939491 987 | n12057211 988 | n12144580 989 | n12267677 990 | n12620546 991 | n12768682 992 | n12985857 993 | n12998815 994 | n13037406 995 | n13040303 996 | n13044778 997 | n13052670 998 | n13054560 999 | n13133613 1000 | n15075141 1001 | -------------------------------------------------------------------------------- /data_processing/imagenet_share_classes.txt: -------------------------------------------------------------------------------- 1 | ant 2 | banana 3 | bee 4 | bell 5 | candle 6 | cannon 7 | castle 8 | church 9 | cup 10 | geyser 11 | hammer 12 | hedgehog 13 | hotdog 14 | hourglass 15 | jellyfish 16 | lion 17 | mushroom 18 | pig 19 | pineapple 20 | pizza 21 | pretzel 22 | rifle 23 | scorpion 24 | snail 25 | starfish 26 | strawberry 27 | tank 28 | teapot 29 | tiger 30 | volcano 31 | zebra 32 | -------------------------------------------------------------------------------- /data_processing/pycocotools/__init__.py: -------------------------------------------------------------------------------- 1 | __author__ = 'tylin' 2 | -------------------------------------------------------------------------------- /data_processing/pycocotools/_mask.pyx: -------------------------------------------------------------------------------- 1 | # distutils: language = c 2 | # distutils: sources = ../common/maskApi.c 3 | 4 | #************************************************************************** 5 | # Microsoft COCO Toolbox. version 2.0 6 | # Data, paper, and tutorials available at: http://mscoco.org/ 7 | # Code written by Piotr Dollar and Tsung-Yi Lin, 2015. 8 | # Licensed under the Simplified BSD License [see coco/license.txt] 9 | #************************************************************************** 10 | 11 | __author__ = 'tsungyi' 12 | 13 | # import both Python-level and C-level symbols of Numpy 14 | # the API uses Numpy to interface C and Python 15 | import numpy as np 16 | cimport numpy as np 17 | from libc.stdlib cimport malloc, free 18 | 19 | # intialized Numpy. must do. 
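# np.import_array() initializes the NumPy C API; it must run before any of the
# PyArray_* calls used further down in this module.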
20 | np.import_array() 21 | 22 | # import numpy C function 23 | # we use PyArray_ENABLEFLAGS to make Numpy ndarray responsible to memoery management 24 | cdef extern from "numpy/arrayobject.h": 25 | void PyArray_ENABLEFLAGS(np.ndarray arr, int flags) 26 | 27 | # Declare the prototype of the C functions in MaskApi.h 28 | cdef extern from "maskApi.h": 29 | ctypedef unsigned int uint 30 | ctypedef unsigned long siz 31 | ctypedef unsigned char byte 32 | ctypedef double* BB 33 | ctypedef struct RLE: 34 | siz h, 35 | siz w, 36 | siz m, 37 | uint* cnts, 38 | void rlesInit( RLE **R, siz n ) 39 | void rleEncode( RLE *R, const byte *M, siz h, siz w, siz n ) 40 | void rleDecode( const RLE *R, byte *mask, siz n ) 41 | void rleMerge( const RLE *R, RLE *M, siz n, bint intersect ) 42 | void rleArea( const RLE *R, siz n, uint *a ) 43 | void rleIou( RLE *dt, RLE *gt, siz m, siz n, byte *iscrowd, double *o ) 44 | void bbIou( BB dt, BB gt, siz m, siz n, byte *iscrowd, double *o ) 45 | void rleToBbox( const RLE *R, BB bb, siz n ) 46 | void rleFrBbox( RLE *R, const BB bb, siz h, siz w, siz n ) 47 | void rleFrPoly( RLE *R, const double *xy, siz k, siz h, siz w ) 48 | char* rleToString( const RLE *R ) 49 | void rleFrString( RLE *R, char *s, siz h, siz w ) 50 | 51 | # python class to wrap RLE array in C 52 | # the class handles the memory allocation and deallocation 53 | cdef class RLEs: 54 | cdef RLE *_R 55 | cdef siz _n 56 | 57 | def __cinit__(self, siz n =0): 58 | rlesInit(&self._R, n) 59 | self._n = n 60 | 61 | # free the RLE array here 62 | def __dealloc__(self): 63 | if self._R is not NULL: 64 | for i in range(self._n): 65 | free(self._R[i].cnts) 66 | free(self._R) 67 | def __getattr__(self, key): 68 | if key == 'n': 69 | return self._n 70 | raise AttributeError(key) 71 | 72 | # python class to wrap Mask array in C 73 | # the class handles the memory allocation and deallocation 74 | cdef class Masks: 75 | cdef byte *_mask 76 | cdef siz _h 77 | cdef siz _w 78 | cdef siz _n 79 | 80 | def __cinit__(self, h, w, n): 81 | self._mask = malloc(h*w*n* sizeof(byte)) 82 | self._h = h 83 | self._w = w 84 | self._n = n 85 | # def __dealloc__(self): 86 | # the memory management of _mask has been passed to np.ndarray 87 | # it doesn't need to be freed here 88 | 89 | # called when passing into np.array() and return an np.ndarray in column-major order 90 | def __array__(self): 91 | cdef np.npy_intp shape[1] 92 | shape[0] = self._h*self._w*self._n 93 | # Create a 1D array, and reshape it to fortran/Matlab column-major array 94 | ndarray = np.PyArray_SimpleNewFromData(1, shape, np.NPY_UINT8, self._mask).reshape((self._h, self._w, self._n), order='F') 95 | # The _mask allocated by Masks is now handled by ndarray 96 | PyArray_ENABLEFLAGS(ndarray, np.NPY_OWNDATA) 97 | return ndarray 98 | 99 | # internal conversion from Python RLEs object to compressed RLE format 100 | def _toString(RLEs Rs): 101 | cdef siz n = Rs.n 102 | cdef bytes py_string 103 | cdef char* c_string 104 | objs = [] 105 | for i in range(n): 106 | c_string = rleToString( &Rs._R[i] ) 107 | py_string = c_string 108 | objs.append({ 109 | 'size': [Rs._R[i].h, Rs._R[i].w], 110 | 'counts': py_string 111 | }) 112 | free(c_string) 113 | return objs 114 | 115 | # internal conversion from compressed RLE format to Python RLEs object 116 | def _frString(rleObjs): 117 | cdef siz n = len(rleObjs) 118 | Rs = RLEs(n) 119 | cdef bytes py_string 120 | cdef char* c_string 121 | for i, obj in enumerate(rleObjs): 122 | py_string = str(obj['counts']).encode('utf8') 123 | c_string = 
py_string 124 | rleFrString( &Rs._R[i], c_string, obj['size'][0], obj['size'][1] ) 125 | return Rs 126 | 127 | # encode mask to RLEs objects 128 | # list of RLE string can be generated by RLEs member function 129 | def encode(np.ndarray[np.uint8_t, ndim=3, mode='fortran'] mask): 130 | h, w, n = mask.shape[0], mask.shape[1], mask.shape[2] 131 | cdef RLEs Rs = RLEs(n) 132 | rleEncode(Rs._R,mask.data,h,w,n) 133 | objs = _toString(Rs) 134 | return objs 135 | 136 | # decode mask from compressed list of RLE string or RLEs object 137 | def decode(rleObjs): 138 | cdef RLEs Rs = _frString(rleObjs) 139 | h, w, n = Rs._R[0].h, Rs._R[0].w, Rs._n 140 | masks = Masks(h, w, n) 141 | rleDecode( Rs._R, masks._mask, n ); 142 | return np.array(masks) 143 | 144 | def merge(rleObjs, bint intersect=0): 145 | cdef RLEs Rs = _frString(rleObjs) 146 | cdef RLEs R = RLEs(1) 147 | rleMerge(Rs._R, R._R, Rs._n, intersect) 148 | obj = _toString(R)[0] 149 | return obj 150 | 151 | def area(rleObjs): 152 | cdef RLEs Rs = _frString(rleObjs) 153 | cdef uint* _a = malloc(Rs._n* sizeof(uint)) 154 | rleArea(Rs._R, Rs._n, _a) 155 | cdef np.npy_intp shape[1] 156 | shape[0] = Rs._n 157 | a = np.array((Rs._n, ), dtype=np.uint8) 158 | a = np.PyArray_SimpleNewFromData(1, shape, np.NPY_UINT32, _a) 159 | PyArray_ENABLEFLAGS(a, np.NPY_OWNDATA) 160 | return a 161 | 162 | # iou computation. support function overload (RLEs-RLEs and bbox-bbox). 163 | def iou( dt, gt, pyiscrowd ): 164 | def _preproc(objs): 165 | if len(objs) == 0: 166 | return objs 167 | if type(objs) == np.ndarray: 168 | if len(objs.shape) == 1: 169 | objs = objs.reshape((objs[0], 1)) 170 | # check if it's Nx4 bbox 171 | if not len(objs.shape) == 2 or not objs.shape[1] == 4: 172 | raise Exception('numpy ndarray input is only for *bounding boxes* and should have Nx4 dimension') 173 | objs = objs.astype(np.double) 174 | elif type(objs) == list: 175 | # check if list is in box format and convert it to np.ndarray 176 | isbox = np.all(np.array([(len(obj)==4) and ((type(obj)==list) or (type(obj)==np.ndarray)) for obj in objs])) 177 | isrle = np.all(np.array([type(obj) == dict for obj in objs])) 178 | if isbox: 179 | objs = np.array(objs, dtype=np.double) 180 | if len(objs.shape) == 1: 181 | objs = objs.reshape((1,objs.shape[0])) 182 | elif isrle: 183 | objs = _frString(objs) 184 | else: 185 | raise Exception('list input can be bounding box (Nx4) or RLEs ([RLE])') 186 | else: 187 | raise Exception('unrecognized type. 
The following type: RLEs (rle), np.ndarray (box), and list (box) are supported.') 188 | return objs 189 | def _rleIou(RLEs dt, RLEs gt, np.ndarray[np.uint8_t, ndim=1] iscrowd, siz m, siz n, np.ndarray[np.double_t, ndim=1] _iou): 190 | rleIou( dt._R, gt._R, m, n, iscrowd.data, _iou.data ) 191 | def _bbIou(np.ndarray[np.double_t, ndim=2] dt, np.ndarray[np.double_t, ndim=2] gt, np.ndarray[np.uint8_t, ndim=1] iscrowd, siz m, siz n, np.ndarray[np.double_t, ndim=1] _iou): 192 | bbIou( dt.data, gt.data, m, n, iscrowd.data, _iou.data ) 193 | def _len(obj): 194 | cdef siz N = 0 195 | if type(obj) == RLEs: 196 | N = obj.n 197 | elif len(obj)==0: 198 | pass 199 | elif type(obj) == np.ndarray: 200 | N = obj.shape[0] 201 | return N 202 | # convert iscrowd to numpy array 203 | cdef np.ndarray[np.uint8_t, ndim=1] iscrowd = np.array(pyiscrowd, dtype=np.uint8) 204 | # simple type checking 205 | cdef siz m, n 206 | dt = _preproc(dt) 207 | gt = _preproc(gt) 208 | m = _len(dt) 209 | n = _len(gt) 210 | if m == 0 or n == 0: 211 | return [] 212 | if not type(dt) == type(gt): 213 | raise Exception('The dt and gt should have the same data type, either RLEs, list or np.ndarray') 214 | 215 | # define local variables 216 | cdef double* _iou = 0 217 | cdef np.npy_intp shape[1] 218 | # check type and assign iou function 219 | if type(dt) == RLEs: 220 | _iouFun = _rleIou 221 | elif type(dt) == np.ndarray: 222 | _iouFun = _bbIou 223 | else: 224 | raise Exception('input data type not allowed.') 225 | _iou = malloc(m*n* sizeof(double)) 226 | iou = np.zeros((m*n, ), dtype=np.double) 227 | shape[0] = m*n 228 | iou = np.PyArray_SimpleNewFromData(1, shape, np.NPY_DOUBLE, _iou) 229 | PyArray_ENABLEFLAGS(iou, np.NPY_OWNDATA) 230 | _iouFun(dt, gt, iscrowd, m, n, iou) 231 | return iou.reshape((m,n), order='F') 232 | 233 | def toBbox( rleObjs ): 234 | cdef RLEs Rs = _frString(rleObjs) 235 | cdef siz n = Rs.n 236 | cdef BB _bb = malloc(4*n* sizeof(double)) 237 | rleToBbox( Rs._R, _bb, n ) 238 | cdef np.npy_intp shape[1] 239 | shape[0] = 4*n 240 | bb = np.array((1,4*n), dtype=np.double) 241 | bb = np.PyArray_SimpleNewFromData(1, shape, np.NPY_DOUBLE, _bb).reshape((n, 4)) 242 | PyArray_ENABLEFLAGS(bb, np.NPY_OWNDATA) 243 | return bb 244 | 245 | def frBbox(np.ndarray[np.double_t, ndim=2] bb, siz h, siz w ): 246 | cdef siz n = bb.shape[0] 247 | Rs = RLEs(n) 248 | rleFrBbox( Rs._R, bb.data, h, w, n ) 249 | objs = _toString(Rs) 250 | return objs 251 | 252 | def frPoly( poly, siz h, siz w ): 253 | cdef np.ndarray[np.double_t, ndim=1] np_poly 254 | n = len(poly) 255 | Rs = RLEs(n) 256 | for i, p in enumerate(poly): 257 | np_poly = np.array(p, dtype=np.double, order='F') 258 | rleFrPoly( &Rs._R[i], np_poly.data, int(len(p)/2), h, w ) 259 | objs = _toString(Rs) 260 | return objs 261 | 262 | def frUncompressedRLE(ucRles, siz h, siz w): 263 | cdef np.ndarray[np.uint32_t, ndim=1] cnts 264 | cdef RLE R 265 | cdef uint *data 266 | n = len(ucRles) 267 | objs = [] 268 | for i in range(n): 269 | Rs = RLEs(1) 270 | cnts = np.array(ucRles[i]['counts'], dtype=np.uint32) 271 | # time for malloc can be saved here but it's fine 272 | data = malloc(len(cnts)* sizeof(uint)) 273 | for j in range(len(cnts)): 274 | data[j] = cnts[j] 275 | R = RLE(ucRles[i]['size'][0], ucRles[i]['size'][1], len(cnts), data) 276 | Rs._R[0] = R 277 | objs.append(_toString(Rs)[0]) 278 | return objs 279 | 280 | def frPyObjects(pyobj, siz h, w): 281 | if type(pyobj) == np.ndarray: 282 | objs = frBbox(pyobj, h, w ) 283 | elif type(pyobj) == list and len(pyobj[0]) == 4: 284 | objs = 
frBbox(pyobj, h, w ) 285 | elif type(pyobj) == list and len(pyobj[0]) > 4: 286 | objs = frPoly(pyobj, h, w ) 287 | elif type(pyobj) == list and type(pyobj[0]) == dict: 288 | objs = frUncompressedRLE(pyobj, h, w) 289 | else: 290 | raise Exception('input type is not supported.') 291 | return objs 292 | -------------------------------------------------------------------------------- /data_processing/pycocotools/coco.py: -------------------------------------------------------------------------------- 1 | __author__ = 'tylin' 2 | __version__ = '2.0' 3 | # Interface for accessing the Microsoft COCO dataset. 4 | 5 | # Microsoft COCO is a large image dataset designed for object detection, 6 | # segmentation, and caption generation. pycocotools is a Python API that 7 | # assists in loading, parsing and visualizing the annotations in COCO. 8 | # Please visit http://mscoco.org/ for more information on COCO, including 9 | # for the data, paper, and tutorials. The exact format of the annotations 10 | # is also described on the COCO website. For example usage of the pycocotools 11 | # please see pycocotools_demo.ipynb. In addition to this API, please download both 12 | # the COCO images and annotations in order to run the demo. 13 | 14 | # An alternative to using the API is to load the annotations directly 15 | # into Python dictionary 16 | # Using the API provides additional utility functions. Note that this API 17 | # supports both *instance* and *caption* annotations. In the case of 18 | # captions not all functions are defined (e.g. categories are undefined). 19 | 20 | # The following API functions are defined: 21 | # COCO - COCO api class that loads COCO annotation file and prepare data structures. 22 | # decodeMask - Decode binary mask M encoded via run-length encoding. 23 | # encodeMask - Encode binary mask M using run-length encoding. 24 | # getAnnIds - Get ann ids that satisfy given filter conditions. 25 | # getCatIds - Get cat ids that satisfy given filter conditions. 26 | # getImgIds - Get img ids that satisfy given filter conditions. 27 | # loadAnns - Load anns with the specified ids. 28 | # loadCats - Load cats with the specified ids. 29 | # loadImgs - Load imgs with the specified ids. 30 | # segToMask - Convert polygon segmentation to binary mask. 31 | # showAnns - Display the specified annotations. 32 | # loadRes - Load algorithm results and create API for accessing them. 33 | # download - Download COCO images from mscoco.org server. 34 | # Throughout the API "ann"=annotation, "cat"=category, and "img"=image. 35 | # Help on each functions can be accessed by: "help COCO>function". 36 | 37 | # See also COCO>decodeMask, 38 | # COCO>encodeMask, COCO>getAnnIds, COCO>getCatIds, 39 | # COCO>getImgIds, COCO>loadAnns, COCO>loadCats, 40 | # COCO>loadImgs, COCO>segToMask, COCO>showAnns 41 | 42 | # Microsoft COCO Toolbox. version 2.0 43 | # Data, paper, and tutorials available at: http://mscoco.org/ 44 | # Code written by Piotr Dollar and Tsung-Yi Lin, 2014. 
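#
# A minimal usage sketch of the class defined below (illustration only; the
# annotation path and category name are placeholders, not files shipped with
# this repository):
#   coco = COCO('annotations/instances_train2014.json')
#   catIds = coco.getCatIds(catNms=['zebra'])
#   imgIds = coco.getImgIds(catIds=catIds)
#   anns = coco.loadAnns(coco.getAnnIds(imgIds=imgIds[:1], catIds=catIds, iscrowd=None))
#   coco.showAnns(anns)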
45 | # Licensed under the Simplified BSD License [see bsd.txt] 46 | 47 | import json 48 | import time 49 | import matplotlib.pyplot as plt 50 | from matplotlib.collections import PatchCollection 51 | from matplotlib.patches import Polygon 52 | import numpy as np 53 | import urllib 54 | import copy 55 | import itertools 56 | import mask 57 | import os 58 | from collections import defaultdict 59 | 60 | class COCO: 61 | def __init__(self, annotation_file=None): 62 | """ 63 | Constructor of Microsoft COCO helper class for reading and visualizing annotations. 64 | :param annotation_file (str): location of annotation file 65 | :param image_folder (str): location to the folder that hosts images. 66 | :return: 67 | """ 68 | # load dataset 69 | self.dataset,self.anns,self.cats,self.imgs = dict(),dict(),dict(),dict() 70 | self.imgToAnns, self.catToImgs = defaultdict(list), defaultdict(list) 71 | if not annotation_file == None: 72 | print('loading annotations into memory...') 73 | tic = time.time() 74 | dataset = json.load(open(annotation_file, 'r')) 75 | assert type(dataset)==dict, "annotation file format %s not supported"%(type(dataset)) 76 | print('Done (t=%0.2fs)'%(time.time()- tic)) 77 | self.dataset = dataset 78 | self.createIndex() 79 | 80 | def createIndex(self): 81 | # create index 82 | print('creating index...') 83 | anns,cats,imgs = dict(),dict(),dict() 84 | imgToAnns,catToImgs = defaultdict(list),defaultdict(list) 85 | if 'annotations' in self.dataset: 86 | for ann in self.dataset['annotations']: 87 | imgToAnns[ann['image_id']].append(ann) 88 | anns[ann['id']] = ann 89 | 90 | if 'images' in self.dataset: 91 | for img in self.dataset['images']: 92 | imgs[img['id']] = img 93 | 94 | if 'categories' in self.dataset: 95 | for cat in self.dataset['categories']: 96 | cats[cat['id']] = cat 97 | for ann in self.dataset['annotations']: 98 | catToImgs[ann['category_id']].append(ann['image_id']) 99 | 100 | print('index created!') 101 | 102 | # create class members 103 | self.anns = anns 104 | self.imgToAnns = imgToAnns 105 | self.catToImgs = catToImgs 106 | self.imgs = imgs 107 | self.cats = cats 108 | 109 | def info(self): 110 | """ 111 | Print information about the annotation file. 112 | :return: 113 | """ 114 | for key, value in self.dataset['info'].items(): 115 | print('%s: %s'%(key, value)) 116 | 117 | def getAnnIds(self, imgIds=[], catIds=[], areaRng=[], iscrowd=None): 118 | """ 119 | Get ann ids that satisfy given filter conditions. default skips that filter 120 | :param imgIds (int array) : get anns for given imgs 121 | catIds (int array) : get anns for given cats 122 | areaRng (float array) : get anns for given area range (e.g. 
[0 inf]) 123 | iscrowd (boolean) : get anns for given crowd label (False or True) 124 | :return: ids (int array) : integer array of ann ids 125 | """ 126 | imgIds = imgIds if type(imgIds) == list else [imgIds] 127 | catIds = catIds if type(catIds) == list else [catIds] 128 | 129 | if len(imgIds) == len(catIds) == len(areaRng) == 0: 130 | anns = self.dataset['annotations'] 131 | else: 132 | if not len(imgIds) == 0: 133 | lists = [self.imgToAnns[imgId] for imgId in imgIds if imgId in self.imgToAnns] 134 | anns = list(itertools.chain.from_iterable(lists)) 135 | else: 136 | anns = self.dataset['annotations'] 137 | anns = anns if len(catIds) == 0 else [ann for ann in anns if ann['category_id'] in catIds] 138 | anns = anns if len(areaRng) == 0 else [ann for ann in anns if ann['area'] > areaRng[0] and ann['area'] < areaRng[1]] 139 | if not iscrowd == None: 140 | ids = [ann['id'] for ann in anns if ann['iscrowd'] == iscrowd] 141 | else: 142 | ids = [ann['id'] for ann in anns] 143 | return ids 144 | 145 | def getCatIds(self, catNms=[], supNms=[], catIds=[]): 146 | """ 147 | filtering parameters. default skips that filter. 148 | :param catNms (str array) : get cats for given cat names 149 | :param supNms (str array) : get cats for given supercategory names 150 | :param catIds (int array) : get cats for given cat ids 151 | :return: ids (int array) : integer array of cat ids 152 | """ 153 | catNms = catNms if type(catNms) == list else [catNms] 154 | supNms = supNms if type(supNms) == list else [supNms] 155 | catIds = catIds if type(catIds) == list else [catIds] 156 | 157 | if len(catNms) == len(supNms) == len(catIds) == 0: 158 | cats = self.dataset['categories'] 159 | else: 160 | cats = self.dataset['categories'] 161 | cats = cats if len(catNms) == 0 else [cat for cat in cats if cat['name'] in catNms] 162 | cats = cats if len(supNms) == 0 else [cat for cat in cats if cat['supercategory'] in supNms] 163 | cats = cats if len(catIds) == 0 else [cat for cat in cats if cat['id'] in catIds] 164 | ids = [cat['id'] for cat in cats] 165 | return ids 166 | 167 | def getImgIds(self, imgIds=[], catIds=[]): 168 | ''' 169 | Get img ids that satisfy given filter conditions. 170 | :param imgIds (int array) : get imgs for given ids 171 | :param catIds (int array) : get imgs with all given cats 172 | :return: ids (int array) : integer array of img ids 173 | ''' 174 | imgIds = imgIds if type(imgIds) == list else [imgIds] 175 | catIds = catIds if type(catIds) == list else [catIds] 176 | 177 | if len(imgIds) == len(catIds) == 0: 178 | ids = self.imgs.keys() 179 | else: 180 | ids = set(imgIds) 181 | for i, catId in enumerate(catIds): 182 | if i == 0 and len(ids) == 0: 183 | ids = set(self.catToImgs[catId]) 184 | else: 185 | ids &= set(self.catToImgs[catId]) 186 | return list(ids) 187 | 188 | def loadAnns(self, ids=[]): 189 | """ 190 | Load anns with the specified ids. 191 | :param ids (int array) : integer ids specifying anns 192 | :return: anns (object array) : loaded ann objects 193 | """ 194 | if type(ids) == list: 195 | return [self.anns[id] for id in ids] 196 | elif type(ids) == int: 197 | return [self.anns[ids]] 198 | 199 | def loadCats(self, ids=[]): 200 | """ 201 | Load cats with the specified ids. 
202 | :param ids (int array) : integer ids specifying cats 203 | :return: cats (object array) : loaded cat objects 204 | """ 205 | if type(ids) == list: 206 | return [self.cats[id] for id in ids] 207 | elif type(ids) == int: 208 | return [self.cats[ids]] 209 | 210 | def loadImgs(self, ids=[]): 211 | """ 212 | Load anns with the specified ids. 213 | :param ids (int array) : integer ids specifying img 214 | :return: imgs (object array) : loaded img objects 215 | """ 216 | if type(ids) == list: 217 | return [self.imgs[id] for id in ids] 218 | elif type(ids) == int: 219 | return [self.imgs[ids]] 220 | 221 | def showAnns(self, anns): 222 | """ 223 | Display the specified annotations. 224 | :param anns (array of object): annotations to display 225 | :return: None 226 | """ 227 | if len(anns) == 0: 228 | return 0 229 | if 'segmentation' in anns[0] or 'keypoints' in anns[0]: 230 | datasetType = 'instances' 231 | elif 'caption' in anns[0]: 232 | datasetType = 'captions' 233 | else: 234 | raise Exception("datasetType not supported") 235 | if datasetType == 'instances': 236 | ax = plt.gca() 237 | ax.set_autoscale_on(False) 238 | polygons = [] 239 | color = [] 240 | for ann in anns: 241 | c = (np.random.random((1, 3))*0.6+0.4).tolist()[0] 242 | if 'segmentation' in ann: 243 | if type(ann['segmentation']) == list: 244 | # polygon 245 | for seg in ann['segmentation']: 246 | poly = np.array(seg).reshape((len(seg)/2, 2)) 247 | polygons.append(Polygon(poly)) 248 | color.append(c) 249 | else: 250 | # mask 251 | t = self.imgs[ann['image_id']] 252 | if type(ann['segmentation']['counts']) == list: 253 | rle = mask.frPyObjects([ann['segmentation']], t['height'], t['width']) 254 | else: 255 | rle = [ann['segmentation']] 256 | m = mask.decode(rle) 257 | img = np.ones( (m.shape[0], m.shape[1], 3) ) 258 | if ann['iscrowd'] == 1: 259 | color_mask = np.array([2.0,166.0,101.0])/255 260 | if ann['iscrowd'] == 0: 261 | color_mask = np.random.random((1, 3)).tolist()[0] 262 | for i in range(3): 263 | img[:,:,i] = color_mask[i] 264 | ax.imshow(np.dstack( (img, m*0.5) )) 265 | if 'keypoints' in ann and type(ann['keypoints']) == list: 266 | # turn skeleton into zero-based index 267 | sks = np.array(self.loadCats(ann['category_id'])[0]['skeleton'])-1 268 | kp = np.array(ann['keypoints']) 269 | x = kp[0::3] 270 | y = kp[1::3] 271 | v = kp[2::3] 272 | for sk in sks: 273 | if np.all(v[sk]>0): 274 | plt.plot(x[sk],y[sk], linewidth=3, color=c) 275 | plt.plot(x[v>0], y[v>0],'o',markersize=8, markerfacecolor=c, markeredgecolor='k',markeredgewidth=2) 276 | plt.plot(x[v>1], y[v>1],'o',markersize=8, markerfacecolor=c, markeredgecolor=c, markeredgewidth=2) 277 | p = PatchCollection(polygons, facecolor=color, linewidths=0, alpha=0.4) 278 | ax.add_collection(p) 279 | p = PatchCollection(polygons, facecolor="none", edgecolors=color, linewidths=2) 280 | ax.add_collection(p) 281 | elif datasetType == 'captions': 282 | for ann in anns: 283 | print(ann['caption']) 284 | 285 | def loadRes(self, resFile): 286 | """ 287 | Load result file and return a result api object. 288 | :param resFile (str) : file name of result file 289 | :return: res (obj) : result api object 290 | """ 291 | res = COCO() 292 | res.dataset['images'] = [img for img in self.dataset['images']] 293 | 294 | print('Loading and preparing results... 
') 295 | tic = time.time() 296 | if type(resFile) == str or type(resFile) == unicode: 297 | anns = json.load(open(resFile)) 298 | elif type(resFile) == np.ndarray: 299 | anns = self.loadNumpyAnnotations(resFile) 300 | else: 301 | anns = resFile 302 | assert type(anns) == list, 'results in not an array of objects' 303 | annsImgIds = [ann['image_id'] for ann in anns] 304 | assert set(annsImgIds) == (set(annsImgIds) & set(self.getImgIds())), \ 305 | 'Results do not correspond to current coco set' 306 | if 'caption' in anns[0]: 307 | imgIds = set([img['id'] for img in res.dataset['images']]) & set([ann['image_id'] for ann in anns]) 308 | res.dataset['images'] = [img for img in res.dataset['images'] if img['id'] in imgIds] 309 | for id, ann in enumerate(anns): 310 | ann['id'] = id+1 311 | elif 'bbox' in anns[0] and not anns[0]['bbox'] == []: 312 | res.dataset['categories'] = copy.deepcopy(self.dataset['categories']) 313 | for id, ann in enumerate(anns): 314 | bb = ann['bbox'] 315 | x1, x2, y1, y2 = [bb[0], bb[0]+bb[2], bb[1], bb[1]+bb[3]] 316 | if not 'segmentation' in ann: 317 | ann['segmentation'] = [[x1, y1, x1, y2, x2, y2, x2, y1]] 318 | ann['area'] = bb[2]*bb[3] 319 | ann['id'] = id+1 320 | ann['iscrowd'] = 0 321 | elif 'segmentation' in anns[0]: 322 | res.dataset['categories'] = copy.deepcopy(self.dataset['categories']) 323 | for id, ann in enumerate(anns): 324 | # now only support compressed RLE format as segmentation results 325 | ann['area'] = mask.area([ann['segmentation']])[0] 326 | if not 'bbox' in ann: 327 | ann['bbox'] = mask.toBbox([ann['segmentation']])[0] 328 | ann['id'] = id+1 329 | ann['iscrowd'] = 0 330 | elif 'keypoints' in anns[0]: 331 | res.dataset['categories'] = copy.deepcopy(self.dataset['categories']) 332 | for id, ann in enumerate(anns): 333 | s = ann['keypoints'] 334 | x = s[0::3] 335 | y = s[1::3] 336 | x0,x1,y0,y1 = np.min(x), np.max(x), np.min(y), np.max(y) 337 | ann['area'] = (x1-x0)*(y1-y0) 338 | ann['id'] = id + 1 339 | ann['bbox'] = [x0,y0,x1-x0,y1-y0] 340 | print('DONE (t=%0.2fs)'%(time.time()- tic)) 341 | 342 | res.dataset['annotations'] = anns 343 | res.createIndex() 344 | return res 345 | 346 | def download( self, tarDir = None, imgIds = [] ): 347 | ''' 348 | Download COCO images from mscoco.org server. 
349 | :param tarDir (str): COCO results directory name 350 | imgIds (list): images to be downloaded 351 | :return: 352 | ''' 353 | if tarDir is None: 354 | print('Please specify target directory') 355 | return -1 356 | if len(imgIds) == 0: 357 | imgs = self.imgs.values() 358 | else: 359 | imgs = self.loadImgs(imgIds) 360 | N = len(imgs) 361 | if not os.path.exists(tarDir): 362 | os.makedirs(tarDir) 363 | for i, img in enumerate(imgs): 364 | tic = time.time() 365 | fname = os.path.join(tarDir, img['file_name']) 366 | if not os.path.exists(fname): 367 | urllib.urlretrieve(img['coco_url'], fname) 368 | print('downloaded %d/%d images (t=%.1fs)'%(i, N, time.time()- tic)) 369 | 370 | def loadNumpyAnnotations(self, data): 371 | """ 372 | Convert result data from a numpy array [Nx7] where each row contains {imageID,x1,y1,w,h,score,class} 373 | :param data (numpy.ndarray) 374 | :return: annotations (python nested list) 375 | """ 376 | print("Converting ndarray to lists...") 377 | assert(type(data) == np.ndarray) 378 | print(data.shape) 379 | assert(data.shape[1] == 7) 380 | N = data.shape[0] 381 | ann = [] 382 | for i in range(N): 383 | if i % 1000000 == 0: 384 | print("%d/%d" % (i,N)) 385 | ann += [{ 386 | 'image_id' : int(data[i, 0]), 387 | 'bbox' : [ data[i, 1], data[i, 2], data[i, 3], data[i, 4] ], 388 | 'score' : data[i, 5], 389 | 'category_id': int(data[i, 6]), 390 | }] 391 | return ann 392 | -------------------------------------------------------------------------------- /data_processing/pycocotools/mask.py: -------------------------------------------------------------------------------- 1 | __author__ = 'tsungyi' 2 | 3 | import pycocotools._mask as _mask 4 | 5 | # Interface for manipulating masks stored in RLE format. 6 | # 7 | # RLE is a simple yet efficient format for storing binary masks. RLE 8 | # first divides a vector (or vectorized image) into a series of piecewise 9 | # constant regions and then for each piece simply stores the length of 10 | # that piece. For example, given M=[0 0 1 1 1 0 1] the RLE counts would 11 | # be [2 3 1 1], or for M=[1 1 1 1 1 1 0] the counts would be [0 6 1] 12 | # (note that the odd counts are always the numbers of zeros). Instead of 13 | # storing the counts directly, additional compression is achieved with a 14 | # variable bitrate representation based on a common scheme called LEB128. 15 | # 16 | # Compression is greatest given large piecewise constant regions. 17 | # Specifically, the size of the RLE is proportional to the number of 18 | # *boundaries* in M (or for an image the number of boundaries in the y 19 | # direction). Assuming fairly simple shapes, the RLE representation is 20 | # O(sqrt(n)) where n is number of pixels in the object. Hence space usage 21 | # is substantially lower, especially for large simple objects (large n). 22 | # 23 | # Many common operations on masks can be computed directly using the RLE 24 | # (without need for decoding). This includes computations such as area, 25 | # union, intersection, etc. All of these operations are linear in the 26 | # size of the RLE, in other words they are O(sqrt(n)) where n is the area 27 | # of the object. Computing these operations on the original mask is O(n). 28 | # Thus, using the RLE can result in substantial computational savings. 29 | # 30 | # The following API functions are defined: 31 | # encode - Encode binary masks using RLE. 32 | # decode - Decode binary masks encoded via RLE. 33 | # merge - Compute union or intersection of encoded masks. 
34 | # iou - Compute intersection over union between masks. 35 | # area - Compute area of encoded masks. 36 | # toBbox - Get bounding boxes surrounding encoded masks. 37 | # frPyObjects - Convert polygon, bbox, and uncompressed RLE to encoded RLE mask. 38 | # 39 | # Usage: 40 | # Rs = encode( masks ) 41 | # masks = decode( Rs ) 42 | # R = merge( Rs, intersect=false ) 43 | # o = iou( dt, gt, iscrowd ) 44 | # a = area( Rs ) 45 | # bbs = toBbox( Rs ) 46 | # Rs = frPyObjects( [pyObjects], h, w ) 47 | # 48 | # In the API the following formats are used: 49 | # Rs - [dict] Run-length encoding of binary masks 50 | # R - dict Run-length encoding of binary mask 51 | # masks - [hxwxn] Binary mask(s) (must have type np.ndarray(dtype=uint8) in column-major order) 52 | # iscrowd - [nx1] list of np.ndarray. 1 indicates corresponding gt image has crowd region to ignore 53 | # bbs - [nx4] Bounding box(es) stored as [x y w h] 54 | # poly - Polygon stored as [[x1 y1 x2 y2...],[x1 y1 ...],...] (2D list) 55 | # dt,gt - May be either bounding boxes or encoded masks 56 | # Both poly and bbs are 0-indexed (bbox=[0 0 1 1] encloses first pixel). 57 | # 58 | # Finally, a note about the intersection over union (iou) computation. 59 | # The standard iou of a ground truth (gt) and detected (dt) object is 60 | # iou(gt,dt) = area(intersect(gt,dt)) / area(union(gt,dt)) 61 | # For "crowd" regions, we use a modified criteria. If a gt object is 62 | # marked as "iscrowd", we allow a dt to match any subregion of the gt. 63 | # Choosing gt' in the crowd gt that best matches the dt can be done using 64 | # gt'=intersect(dt,gt). Since by definition union(gt',dt)=dt, computing 65 | # iou(gt,dt,iscrowd) = iou(gt',dt) = area(intersect(gt,dt)) / area(dt) 66 | # For crowd gt regions we use this modified criteria above for the iou. 67 | # 68 | # To compile run "python setup.py build_ext --inplace" 69 | # Please do not contact us for help with compiling. 70 | # 71 | # Microsoft COCO Toolbox. version 2.0 72 | # Data, paper, and tutorials available at: http://mscoco.org/ 73 | # Code written by Piotr Dollar and Tsung-Yi Lin, 2015. 
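#
# A small round-trip sketch (illustration only, assuming the Cython extension
# has already been built with "python setup.py build_ext --inplace"):
#   import numpy as np
#   m = np.zeros((256, 256, 1), dtype=np.uint8, order='F')   # h x w x n, column-major
#   m[64:192, 64:192, 0] = 1
#   rles = encode(m)          # -> [{'size': [256, 256], 'counts': ...}]
#   recovered = decode(rles)  # -> 256x256x1 uint8 array equal to m
#   print(area(rles))         # -> pixel count of the square region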
74 | # Licensed under the Simplified BSD License [see coco/license.txt] 75 | 76 | encode = _mask.encode 77 | decode = _mask.decode 78 | iou = _mask.iou 79 | merge = _mask.merge 80 | area = _mask.area 81 | toBbox = _mask.toBbox 82 | frPyObjects = _mask.frPyObjects -------------------------------------------------------------------------------- /data_processing/sketchy_to_tfrecord.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import csv 4 | import numpy as np 5 | import scipy.io 6 | import scipy.misc as spm 7 | 8 | import cv2 9 | from scipy import ndimage 10 | import tensorflow as tf 11 | from tensorflow.python.framework import ops 12 | 13 | 14 | def showImg(img): 15 | cv2.imshow("test", img) 16 | cv2.waitKey(-1) 17 | 18 | 19 | def dense_to_one_hot(labels_dense, num_classes): 20 | """Convert class labels from scalars to one-hot vectors.""" 21 | num_labels = labels_dense.shape[0] 22 | index_offset = np.arange(num_labels) * num_classes 23 | labels_one_hot = np.zeros((num_labels, num_classes), dtype=np.int32) 24 | labels_one_hot.flat[index_offset + labels_dense.ravel()] = 1 25 | return labels_one_hot 26 | 27 | 28 | def _bytes_feature(value): 29 | return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value])) 30 | 31 | 32 | def _int64_feature(value): 33 | return tf.train.Feature(int64_list=tf.train.Int64List(value=[value])) 34 | 35 | 36 | classes_info = '../data_processing/classes.csv' 37 | photo_folder = '../Datasets/Sketchy/rendered_256x256/256x256/photo/tx_000000000000' 38 | sketch_folder = '../Datasets/Sketchy/rendered_256x256/256x256/sketch/tx_000000000000' 39 | info_dir = '../Datasets/Sketchy/info' 40 | data_dir = '../tfrecords/sketchy' 41 | 42 | config = tf.ConfigProto(allow_soft_placement=True, log_device_placement=False, 43 | intra_op_parallelism_threads=4) 44 | 45 | 46 | def check_repeat(seq): 47 | seen = set() 48 | seen_add = seen.add 49 | seen_twice = set(x for x in seq if x in seen or seen_add(x)) 50 | return list(seen_twice) 51 | 52 | 53 | def build_graph(): 54 | photo_filename = tf.placeholder(dtype=tf.string, shape=()) 55 | label_filename = tf.placeholder(dtype=tf.string, shape=()) 56 | photo = tf.read_file(photo_filename) 57 | label = tf.read_file(label_filename) 58 | photo_decoded = tf.image.decode_jpeg(photo, fancy_upscaling=True) 59 | label_decoded = tf.image.decode_png(label) 60 | 61 | # Encode 64x64 62 | photo_input = tf.placeholder(dtype=tf.uint8, shape=(64, 64, 3)) 63 | label_input = tf.placeholder(dtype=tf.uint8, shape=(256, 256, 1)) 64 | label_small_input = tf.placeholder(dtype=tf.uint8, shape=(64, 64, 1)) 65 | 66 | photo_stream = tf.image.encode_jpeg(photo_input, quality=95, progressive=False, 67 | optimize_size=False, chroma_downsampling=False) 68 | label_stream = tf.image.encode_png(label_input, compression=7) 69 | label_small_stream = tf.image.encode_png(label_small_input, compression=7) 70 | 71 | return photo_filename, label_filename, photo, label, photo_decoded, label_decoded, photo_input, label_input,\ 72 | label_small_input, photo_stream, label_stream, label_small_stream 73 | 74 | 75 | def read_csv(filename): 76 | with open(filename) as csvfile: 77 | reader = csv.DictReader(csvfile) 78 | l = list(reader) 79 | 80 | return l 81 | 82 | 83 | def read_txt(filename): 84 | with open(filename) as txtfile: 85 | lines = txtfile.readlines() 86 | return [l[:-1] for l in lines] 87 | 88 | 89 | def split_csvlist(stat_info): 90 | cat = list(set([item['Category'] for item in stat_info])) 91 | l = [] 
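    # Group the per-sketch csv rows by their Category so that write_image_data()
    # can emit one tfrecord file per class.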
92 | for c in cat: 93 | li = [item for item in stat_info if item['Category'] == c] 94 | l.append(li) 95 | 96 | return cat, l 97 | 98 | 99 | def binarize(sketch, threshold=245): 100 | sketch[sketch < threshold] = 0 101 | sketch[sketch >= threshold] = 255 102 | return sketch 103 | 104 | 105 | def write_image_data(): 106 | 107 | csv_file = os.path.join(info_dir, 'stats.csv') 108 | stat_info = read_csv(csv_file) 109 | classes = read_csv(classes_info) 110 | classes_ids = [item['Name'] for item in classes] 111 | 112 | test_list = read_txt(os.path.join(info_dir, 'testset.txt')) 113 | 114 | invalid_notations = ['invalid-ambiguous.txt', 'invalid-context.txt', 'invalid-error.txt', 'invalid-pose.txt'] 115 | invalid_files = [] 116 | for txtfile in invalid_notations: 117 | cur_path = os.path.join(info_dir, txtfile) 118 | files = read_txt(cur_path) 119 | files = [f[:-1] for f in files] 120 | invalid_files.extend(files) 121 | 122 | path_image = photo_folder 123 | path_label = sketch_folder 124 | 125 | dirs, stats = split_csvlist(stat_info) 126 | photo_filename, label_filename, photo, label, photo_decoded, label_decoded, photo_input, label_input, \ 127 | label_small_input, photo_stream, label_stream, label_small_stream = build_graph() 128 | assert len(dirs) == len(stats) 129 | 130 | with tf.Session(config=config) as sess: 131 | sess.run(tf.global_variables_initializer()) 132 | # coord = tf.train.Coordinator() 133 | # threads = tf.train.start_queue_runners(sess=sess, coord=coord) 134 | 135 | for i in range(len(dirs)): 136 | dir = dirs[i].replace(' ', '_') 137 | print(dir) 138 | class_id = classes_ids.index(dir) 139 | stat = stats[i] 140 | writer = tf.python_io.TFRecordWriter(os.path.join(data_dir, dir + '.tfrecord')) 141 | 142 | cur_photo_path = os.path.join(path_image, dir) 143 | cur_label_path = os.path.join(path_label, dir) 144 | num_label = len(stat) 145 | # photo_files = [f for f in os.listdir(cur_photo_path) if os.path.isfile(os.path.join(cur_photo_path, f))] 146 | # label_files = [f for f in os.listdir(cur_label_path) if os.path.isfile(os.path.join(cur_label_path, f))] 147 | 148 | for j in range(num_label): 149 | if j % 500 == 499: 150 | print(j) 151 | item = stat[j] 152 | 153 | ImageNetID = item['ImageNetID'] 154 | SketchID = int(item['SketchID']) 155 | Category = item['Category'] 156 | CategoryID = int(item['CategoryID']) 157 | Difficulty = int(item['Difficulty']) 158 | Stroke_Count = int(item['Stroke_Count']) 159 | 160 | WrongPose = int(item['WrongPose?']) 161 | Context = int(item['Context?']) 162 | Ambiguous = int(item['Ambiguous?']) 163 | Error = int(item['Error?']) 164 | 165 | if os.path.join(dir, ImageNetID + '.jpg') in test_list: 166 | IsTest = 1 167 | else: 168 | IsTest = 0 169 | 170 | # print(os.path.join(cur_photo_path, ImageNetID + '.jpg')) 171 | # print(os.path.join(cur_label_path, ImageNetID + '-' + str(SketchID) + '.png')) 172 | out_image, out_image_decoded = sess.run([photo, photo_decoded], feed_dict={ 173 | photo_filename: os.path.join(cur_photo_path, ImageNetID + '.jpg')}) 174 | out_label, out_label_decoded = sess.run([label, label_decoded], feed_dict={ 175 | label_filename: os.path.join(cur_label_path, ImageNetID + '-' + str(SketchID) + '.png')}) 176 | 177 | # Resize 178 | out_image_decoded_small = cv2.resize(out_image_decoded, (64, 64), interpolation=cv2.INTER_AREA) 179 | out_label_decoded = (np.sum(out_label_decoded.astype(np.float64), axis=2)/3).astype(np.uint8) 180 | out_label_decoded_small = cv2.resize(out_label_decoded, (64, 64), interpolation=cv2.INTER_AREA) 181 | 182 | # 
Distance map 183 | out_dist_map = ndimage.distance_transform_edt(binarize(out_label_decoded)) 184 | out_dist_map = (out_dist_map / out_dist_map.max() * 255.).astype(np.uint8) 185 | 186 | out_dist_map_small = ndimage.distance_transform_edt(binarize(out_label_decoded_small)) 187 | out_dist_map_small = (out_dist_map_small / out_dist_map_small.max() * 255.).astype(np.uint8) 188 | 189 | # Stream 190 | image_string_small, label_string_small = sess.run([photo_stream, label_small_stream], feed_dict={ 191 | photo_input: out_image_decoded_small, label_small_input: out_label_decoded_small.reshape((64, 64, 1)) 192 | }) 193 | dist_map_string = sess.run(label_stream, feed_dict={label_input: out_dist_map.reshape((256, 256, 1))}) 194 | dist_map_string_small = sess.run(label_small_stream, feed_dict={ 195 | label_small_input: out_dist_map_small.reshape((64, 64, 1))}) 196 | 197 | example = tf.train.Example(features=tf.train.Features(feature={ 198 | 'ImageNetID': _bytes_feature(ImageNetID.encode('utf-8')), 199 | 'SketchID': _int64_feature(SketchID), 200 | 'Category': _bytes_feature(Category.encode('utf-8')), 201 | 'CategoryID': _int64_feature(CategoryID), 202 | 'Difficulty': _int64_feature(Difficulty), 203 | 'Stroke_Count': _int64_feature(Stroke_Count), 204 | 'WrongPose': _int64_feature(WrongPose), 205 | 'Context': _int64_feature(Context), 206 | 'Ambiguous': _int64_feature(Ambiguous), 207 | 'Error': _int64_feature(Error), 208 | 'is_test': _int64_feature(IsTest), 209 | 'class_id': _int64_feature(class_id), 210 | 'image_jpeg': _bytes_feature(out_image), 211 | 'image_small_jpeg': _bytes_feature(image_string_small), 212 | 'sketch_png': _bytes_feature(out_label), 213 | 'sketch_small_png': _bytes_feature(label_string_small), 214 | 'dist_map_png': _bytes_feature(dist_map_string), 215 | 'dist_map_small_png': _bytes_feature(dist_map_string_small), 216 | })) 217 | writer.write(example.SerializeToString()) 218 | 219 | # coord.request_stop() 220 | # coord.join(threads) 221 | 222 | writer.close() 223 | 224 | 225 | write_image_data() 226 | -------------------------------------------------------------------------------- /data_processing/tfrecord.py: -------------------------------------------------------------------------------- 1 | import csv 2 | 3 | import numpy as np 4 | import tensorflow as tf 5 | import cv2 6 | 7 | 8 | def check_repeat(seq): 9 | seen = set() 10 | seen_add = seen.add 11 | seen_twice = set(x for x in seq if x in seen or seen_add(x)) 12 | return list(seen_twice) 13 | 14 | 15 | def binarize(sketch, threshold=245): 16 | sketch[sketch < threshold] = 0 17 | sketch[sketch >= threshold] = 255 18 | return sketch 19 | 20 | 21 | def showImg(img): 22 | cv2.imshow("test", img) 23 | cv2.waitKey(-1) 24 | 25 | 26 | def dense_to_one_hot(labels_dense, num_classes): 27 | """Convert class labels from scalars to one-hot vectors.""" 28 | num_labels = labels_dense.shape[0] 29 | index_offset = np.arange(num_labels) * num_classes 30 | labels_one_hot = np.zeros((num_labels, num_classes), dtype=np.int32) 31 | labels_one_hot.flat[index_offset + labels_dense.ravel()] = 1 32 | return labels_one_hot 33 | 34 | 35 | def bytes_feature(value): 36 | return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value])) 37 | 38 | 39 | def int64_feature(value): 40 | return tf.train.Feature(int64_list=tf.train.Int64List(value=[value])) 41 | 42 | 43 | def read_csv(filename): 44 | with open(filename) as csvfile: 45 | reader = csv.DictReader(csvfile) 46 | l = list(reader) 47 | 48 | return l 49 | 50 | 51 | def read_txt(filename): 52 | with 
open(filename) as txtfile: 53 | lines = txtfile.readlines() 54 | return [l[:-1] for l in lines] 55 | 56 | 57 | def split_csvlist(stat_info): 58 | cat = list(set([item['Category'] for item in stat_info])) 59 | l = [] 60 | for c in cat: 61 | li = [item for item in stat_info if item['Category'] == c] 62 | l.append(li) 63 | 64 | return cat, l 65 | -------------------------------------------------------------------------------- /inception_v4_model/put_inception_v4.ckpt_here: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wchen342/SketchyGAN/1828860ec77d017bdd4217b8a90ea15baea06578/inception_v4_model/put_inception_v4.ckpt_here -------------------------------------------------------------------------------- /main_single.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import importlib 3 | import os 4 | import sys 5 | import shutil 6 | import json 7 | import tensorflow as tf 8 | from time import gmtime, strftime 9 | 10 | src_dir = './src_single' 11 | 12 | 13 | def launch_training(**kwargs): 14 | 15 | # Deal with file and paths 16 | appendix = kwargs["resume_from"] 17 | if appendix is None or appendix == '': 18 | cur_time = strftime("%Y-%m-%d-%H-%M-%S", gmtime()) 19 | log_dir = './log_skgan_' + cur_time 20 | ckpt_dir = './ckpt_skgan_' + cur_time 21 | if not os.path.isdir(log_dir) and not os.path.exists(log_dir): 22 | os.makedirs(log_dir) 23 | if not os.path.isdir(ckpt_dir) and not os.path.exists(ckpt_dir): 24 | os.makedirs(ckpt_dir) 25 | 26 | # copy current script in src folder to log dir for record 27 | if not os.path.exists(src_dir) or not os.path.isdir(src_dir): 28 | print("src folder does not exist.") 29 | return 30 | else: 31 | for file in os.listdir(src_dir): 32 | if file.endswith(".py"): 33 | shutil.copy(os.path.join(src_dir, file), log_dir) 34 | 35 | kwargs['log_dir'] = log_dir 36 | kwargs['ckpt_dir'] = ckpt_dir 37 | appendix = cur_time 38 | kwargs["resume_from"] = appendix 39 | kwargs["iter_from"] = 0 40 | 41 | # Save parameters 42 | with open(os.path.join(log_dir, 'param_%d.json' % 0), 'w') as fp: 43 | json.dump(kwargs, fp, indent=4) 44 | 45 | sys.path.append(src_dir) 46 | entry_point_module = kwargs['entry_point'] 47 | from config import Config 48 | Config.set_from_dict(kwargs) 49 | 50 | print("Launching new train: %s" % cur_time) 51 | else: 52 | if len(appendix.split('-')) != 6: 53 | print("Invalid resume folder") 54 | return 55 | 56 | log_dir = './log_skgan_' + appendix 57 | ckpt_dir = './ckpt_skgan_' + appendix 58 | 59 | # Get last parameters (recover entry point module name) 60 | json_files = [f for f in os.listdir(log_dir) if 61 | os.path.isfile(os.path.join(log_dir, f)) and os.path.splitext(f)[1] == '.json'] 62 | iter_starts = max([int(os.path.splitext(filename)[0].split('_')[1]) for filename in json_files]) 63 | with open(os.path.join(log_dir, 'param_%d.json' % iter_starts), 'r') as fp: 64 | params = json.load(fp) 65 | entry_point_module = params['entry_point'] 66 | 67 | # Recover parameters 68 | _ignored = ['num_gpu', 'iter_from'] 69 | for k, v in params.items(): 70 | if k not in _ignored: 71 | kwargs[k] = v 72 | 73 | sys.path.append(log_dir) 74 | 75 | # Get latest checkpoint filename 76 | # if stage == 1: 77 | # ckpt_file = tf.train.latest_checkpoint(stage_1_ckpt_dir) 78 | # elif stage == 2: 79 | ckpt_file = tf.train.latest_checkpoint(ckpt_dir) 80 | if ckpt_file is None: 81 | raise RuntimeError 82 | else: 83 | iter_from = 
int(os.path.split(ckpt_file)[1].split('-')[1]) + 1 84 | kwargs['log_dir'] = log_dir 85 | kwargs['ckpt_dir'] = ckpt_dir 86 | kwargs['iter_from'] = iter_from 87 | 88 | # Save new set of parameters 89 | with open(os.path.join(log_dir, 'param_%d.json' % iter_from), 'w') as fp: 90 | kwargs['entry_point'] = entry_point_module 91 | json.dump(kwargs, fp, indent=4) 92 | 93 | from config import Config 94 | Config.set_from_dict(kwargs) 95 | print("Launching train from checkpoint: %s" % appendix) 96 | 97 | # Launch train 98 | train_module = importlib.import_module(entry_point_module) 99 | # from train_paired_aug_multi_gpu import train 100 | status = train_module.train(**kwargs) 101 | 102 | return status, appendix 103 | 104 | 105 | def launch_test(**kwargs): 106 | # Deal with file and paths 107 | appendix = kwargs["resume_from"] 108 | if appendix is None or appendix == '' or len(appendix.split('-')) != 6: 109 | print("Invalid resume folder") 110 | return 111 | 112 | log_dir = './log_skgan_' + appendix 113 | ckpt_dir = './ckpt_skgan_' + appendix 114 | 115 | sys.path.append(log_dir) 116 | 117 | # Get latest checkpoint filename 118 | kwargs['log_dir'] = log_dir 119 | kwargs['ckpt_dir'] = ckpt_dir 120 | 121 | # Get last parameters (recover entry point module name) 122 | # Assuming last json file 123 | json_files = [f for f in os.listdir(log_dir) if 124 | os.path.isfile(os.path.join(log_dir, f)) and os.path.splitext(f)[1] == '.json'] 125 | iter_starts = max([int(os.path.splitext(filename)[0].split('_')[1]) for filename in json_files]) 126 | with open(os.path.join(log_dir, 'param_%d.json' % iter_starts), 'r') as fp: 127 | params = json.load(fp) 128 | entry_point_module = params['entry_point'] 129 | 130 | # Recover parameters 131 | _ignored = ["num_gpu", 'iter_from'] 132 | for k, v in params.items(): 133 | if k not in _ignored: 134 | kwargs[k] = v 135 | 136 | from config import Config 137 | Config.set_from_dict(kwargs) 138 | print("Launching test from checkpoint: %s" % appendix) 139 | 140 | # Launch test 141 | train_module = importlib.import_module(entry_point_module) 142 | train_module.test(**kwargs) 143 | 144 | 145 | if __name__ == "__main__": 146 | 147 | parser = argparse.ArgumentParser(description='Train or Test model') 148 | parser.add_argument('--mode', type=str, default="train", help="train or test") 149 | parser.add_argument('--resume_from', type=str, default='', help="Whether resume last checkpoint from a past run. Notice: you only need to fill in the string after skgan_, i.e. 
the part with yyyy-mm-dd-hr-min-sec") 150 | parser.add_argument('--entry_point', type=str, default='train_single', help="name of the training .py file") 151 | parser.add_argument('--batch_size', default=16, type=int, help='Batch size per gpu') 152 | parser.add_argument('--max_iter_step', default=300000, type=int, help="Max number of iterations") 153 | parser.add_argument('--disc_iterations', default=1, type=int, help="Number of discriminator iterations") 154 | parser.add_argument('--ld', default=10, type=float, help="Gradient penalty lambda hyperparameter") 155 | parser.add_argument('--optimizer', type=str, default="Adam", help="Optimizer for the graph") 156 | parser.add_argument('--lr_G', type=float, default=2e-4, help="learning rate for the generator") 157 | parser.add_argument('--lr_D', type=float, default=4e-4, help="learning rate for the discriminator") 158 | parser.add_argument('--num_gpu', default=2, type=int, help="Number of GPUs to use") 159 | parser.add_argument('--distance_map', default=1, type=int, help="Whether using distance maps for sketches") 160 | parser.add_argument('--small_img', default=1, type=int, help="Whether using 64x64 instead of 256x256") 161 | parser.add_argument('--extra_info', default="", type=str, help="Extra information saved for record") 162 | 163 | args = parser.parse_args() 164 | 165 | assert args.optimizer in ["RMSprop", "Adam", "AdaDelta", "AdaGrad"], "Unsupported optimizer" 166 | 167 | # Set default params 168 | d_params = {"resume_from": args.resume_from, 169 | "entry_point": args.entry_point, 170 | "batch_size": args.batch_size, 171 | "max_iter_step": args.max_iter_step, 172 | "disc_iterations": args.disc_iterations, 173 | "ld": args.ld, 174 | "optimizer": args.optimizer, 175 | "lr_G": args.lr_G, 176 | "lr_D": args.lr_D, 177 | "num_gpu": args.num_gpu, 178 | "distance_map": args.distance_map, 179 | "small_img": args.small_img, 180 | "extra_info": args.extra_info, 181 | } 182 | 183 | if args.mode == 'train': 184 | # Launch training 185 | status, appendix = launch_training(**d_params) 186 | while status == -1: # NaN during training 187 | print("Training ended with status -1. 
Restarting..") 188 | d_params["resume_from"] = appendix 189 | status = launch_training(**d_params) 190 | elif args.mode == 'test': 191 | launch_test(**d_params) 192 | -------------------------------------------------------------------------------- /src_single/config.py: -------------------------------------------------------------------------------- 1 | # Global config 2 | 3 | 4 | class Config(object): 5 | # global options 6 | data_format = 'NCHW' # DO NOT CHANGE THIS 7 | SPECTRAL_NORM_UPDATE_OPS = "spectral_norm_update_ops" 8 | sn = True # Whether uses Spectral Normalization(https://arxiv.org/abs/1802.05957) 9 | proj_d = False # Whether uses projection discriminator(https://arxiv.org/abs/1802.05637) 10 | wgan = False # WGAN or DRAGAN(only effective if sn is False) 11 | pre_calculated_dist_map = True # Whether calculate distance maps on the fly 12 | 13 | @staticmethod 14 | def set_from_dict(d): 15 | assert type(d) is dict 16 | for k, v in d.items(): 17 | setattr(Config, k, v) 18 | -------------------------------------------------------------------------------- /src_single/inception_score.py: -------------------------------------------------------------------------------- 1 | # From https://github.com/openai/improved-gan/blob/master/inception_score/model.py 2 | # Code derived from tensorflow/tensorflow/models/image/imagenet/classify_image.py 3 | import os.path 4 | import sys 5 | import tarfile 6 | 7 | import numpy as np 8 | from six.moves import urllib 9 | import tensorflow as tf 10 | import glob 11 | import scipy.misc 12 | import math 13 | import sys 14 | 15 | MODEL_DIR = './inception_model/imagenet' 16 | DATA_URL = 'http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz' 17 | softmax = None 18 | prefix = 'Inception' 19 | config = tf.ConfigProto(allow_soft_placement=True, log_device_placement=True) 20 | config.gpu_options.allow_growth = True 21 | config.gpu_options.per_process_gpu_memory_fraction = 0.9 22 | 23 | 24 | # Call this function with list of images. Each of elements should be a 25 | # numpy array with values ranging from 0 to 255. 26 | def get_inception_score(images, sess, splits=10): 27 | assert (type(images) == list) 28 | assert (type(images[0]) == np.ndarray) 29 | assert (len(images[0].shape) == 3) 30 | # assert (np.max(images[0]) > 10) 31 | # assert (np.min(images[0]) >= 0.0) 32 | inps = [] 33 | for img in images: 34 | img = img.astype(np.float32) 35 | inps.append(np.expand_dims(img, 0)) 36 | bs = 100 37 | 38 | preds = [] 39 | n_batches = int(math.ceil(float(len(inps)) / float(bs))) 40 | for i in range(n_batches): 41 | # sys.stdout.write(".") 42 | # sys.stdout.flush() 43 | inp = inps[(i * bs):min((i + 1) * bs, len(inps))] 44 | inp = np.concatenate(inp, 0) 45 | pred = sess.run(softmax, {prefix + 'ExpandDims:0': inp}) 46 | preds.append(pred) 47 | preds = np.concatenate(preds, 0) 48 | scores = [] 49 | for i in range(splits): 50 | part = preds[(i * preds.shape[0] // splits):((i + 1) * preds.shape[0] // splits), :] 51 | kl = part * (np.log(part) - np.log(np.expand_dims(np.mean(part, 0), 0))) 52 | kl = np.mean(np.sum(kl, 1)) 53 | scores.append(np.exp(kl)) 54 | return np.mean(scores), np.std(scores) 55 | 56 | 57 | # This function is called automatically. 
58 | def _init_inception(): 59 | global softmax, prefix 60 | if not os.path.exists(MODEL_DIR): 61 | os.makedirs(MODEL_DIR) 62 | filename = DATA_URL.split('/')[-1] 63 | filepath = os.path.join(MODEL_DIR, filename) 64 | if not os.path.exists(filepath): 65 | def _progress(count, block_size, total_size): 66 | sys.stdout.write('\r>> Downloading %s %.1f%%' % ( 67 | filename, float(count * block_size) / float(total_size) * 100.0)) 68 | sys.stdout.flush() 69 | 70 | filepath, _ = urllib.request.urlretrieve(DATA_URL, filepath, _progress) 71 | print() 72 | statinfo = os.stat(filepath) 73 | print('Succesfully downloaded', filename, statinfo.st_size, 'bytes.') 74 | tarfile.open(filepath, 'r:gz').extractall(MODEL_DIR) 75 | with tf.gfile.FastGFile(os.path.join( 76 | MODEL_DIR, 'classify_image_graph_def.pb'), 'rb') as f: 77 | graph_def = tf.GraphDef() 78 | graph_def.ParseFromString(f.read()) 79 | for node in graph_def.node: 80 | node.device = "/gpu:2" 81 | _ = tf.import_graph_def(graph_def, name=prefix) 82 | if prefix[-1] != '/': 83 | prefix += '/' 84 | # Works with an arbitrary minibatch size. 85 | with tf.Session(config=config) as sess: 86 | pool3 = sess.graph.get_tensor_by_name(prefix + 'pool_3:0') 87 | ops = pool3.graph.get_operations() 88 | for op_idx, op in enumerate(ops): 89 | for o in op.outputs: 90 | shape = o.get_shape() 91 | shape = [s.value for s in shape] 92 | # new_shape = [] 93 | for j, s in enumerate(shape): 94 | if s == 1 and j == 0: 95 | o.shape.dims[0] = tf.Dimension(None) 96 | # new_shape.append(None) 97 | # else: 98 | # new_shape.append(s) 99 | # o.set_shape(tf.TensorShape(new_shape)) 100 | w = sess.graph.get_operation_by_name(prefix + "softmax/logits/MatMul").inputs[1] 101 | logits = tf.matmul(tf.squeeze(pool3, (1, 2)), w) 102 | softmax = tf.nn.softmax(logits) 103 | 104 | 105 | if softmax is None: 106 | _init_inception() 107 | -------------------------------------------------------------------------------- /src_single/inception_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Contains common code shared by all inception models. 16 | 17 | Usage of arg scope: 18 | with slim.arg_scope(inception_arg_scope()): 19 | logits, end_points = inception.inception_v3(images, num_classes, 20 | is_training=is_training) 21 | 22 | """ 23 | from __future__ import absolute_import 24 | from __future__ import division 25 | from __future__ import print_function 26 | 27 | import tensorflow as tf 28 | 29 | slim = tf.contrib.slim 30 | 31 | 32 | def inception_arg_scope(weight_decay=0.00004, 33 | use_batch_norm=True, 34 | batch_norm_decay=0.9997, 35 | batch_norm_epsilon=0.001): 36 | """Defines the default arg scope for inception models. 37 | 38 | Args: 39 | weight_decay: The weight decay to use for regularizing the model. 
40 | use_batch_norm: "If `True`, batch_norm is applied after each convolution. 41 | batch_norm_decay: Decay for batch norm moving average. 42 | batch_norm_epsilon: Small float added to variance to avoid dividing by zero 43 | in batch norm. 44 | 45 | Returns: 46 | An `arg_scope` to use for the inception models. 47 | """ 48 | batch_norm_params = { 49 | # Decay for the moving averages. 50 | 'decay': batch_norm_decay, 51 | # epsilon to prevent 0s in variance. 52 | 'epsilon': batch_norm_epsilon, 53 | # collection containing update_ops. 54 | 'updates_collections': tf.GraphKeys.UPDATE_OPS, 55 | } 56 | if use_batch_norm: 57 | normalizer_fn = slim.batch_norm 58 | normalizer_params = batch_norm_params 59 | else: 60 | normalizer_fn = None 61 | normalizer_params = {} 62 | # Set weight_decay for weights in Conv and FC layers. 63 | with slim.arg_scope([slim.conv2d, slim.fully_connected], 64 | weights_regularizer=slim.l2_regularizer(weight_decay)): 65 | with slim.arg_scope( 66 | [slim.conv2d], 67 | weights_initializer=slim.variance_scaling_initializer(), 68 | activation_fn=tf.nn.relu, 69 | normalizer_fn=normalizer_fn, 70 | normalizer_params=normalizer_params) as sc: 71 | return sc 72 | -------------------------------------------------------------------------------- /src_single/inception_v4.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Contains the definition of the Inception V4 architecture. 16 | 17 | As described in http://arxiv.org/abs/1602.07261. 
18 | 19 | Inception-v4, Inception-ResNet and the Impact of Residual Connections 20 | on Learning 21 | Christian Szegedy, Sergey Ioffe, Vincent Vanhoucke, Alex Alemi 22 | """ 23 | from __future__ import absolute_import 24 | from __future__ import division 25 | from __future__ import print_function 26 | 27 | import tensorflow as tf 28 | 29 | import inception_utils 30 | 31 | slim = tf.contrib.slim 32 | 33 | 34 | def block_inception_a(inputs, scope=None, reuse=None): 35 | """Builds Inception-A block for Inception v4 network.""" 36 | # By default use stride=1 and SAME padding 37 | with slim.arg_scope([slim.conv2d, slim.avg_pool2d, slim.max_pool2d], 38 | stride=1, padding='SAME'): 39 | with tf.variable_scope(scope, 'BlockInceptionA', [inputs], reuse=reuse): 40 | with tf.variable_scope('Branch_0'): 41 | branch_0 = slim.conv2d(inputs, 96, [1, 1], scope='Conv2d_0a_1x1') 42 | with tf.variable_scope('Branch_1'): 43 | branch_1 = slim.conv2d(inputs, 64, [1, 1], scope='Conv2d_0a_1x1') 44 | branch_1 = slim.conv2d(branch_1, 96, [3, 3], scope='Conv2d_0b_3x3') 45 | with tf.variable_scope('Branch_2'): 46 | branch_2 = slim.conv2d(inputs, 64, [1, 1], scope='Conv2d_0a_1x1') 47 | branch_2 = slim.conv2d(branch_2, 96, [3, 3], scope='Conv2d_0b_3x3') 48 | branch_2 = slim.conv2d(branch_2, 96, [3, 3], scope='Conv2d_0c_3x3') 49 | with tf.variable_scope('Branch_3'): 50 | branch_3 = slim.avg_pool2d(inputs, [3, 3], scope='AvgPool_0a_3x3') 51 | branch_3 = slim.conv2d(branch_3, 96, [1, 1], scope='Conv2d_0b_1x1') 52 | return tf.concat(axis=3, values=[branch_0, branch_1, branch_2, branch_3]) 53 | 54 | 55 | def block_reduction_a(inputs, scope=None, reuse=None): 56 | """Builds Reduction-A block for Inception v4 network.""" 57 | # By default use stride=1 and SAME padding 58 | with slim.arg_scope([slim.conv2d, slim.avg_pool2d, slim.max_pool2d], 59 | stride=1, padding='SAME'): 60 | with tf.variable_scope(scope, 'BlockReductionA', [inputs], reuse=reuse): 61 | with tf.variable_scope('Branch_0'): 62 | branch_0 = slim.conv2d(inputs, 384, [3, 3], stride=2, padding='VALID', 63 | scope='Conv2d_1a_3x3') 64 | with tf.variable_scope('Branch_1'): 65 | branch_1 = slim.conv2d(inputs, 192, [1, 1], scope='Conv2d_0a_1x1') 66 | branch_1 = slim.conv2d(branch_1, 224, [3, 3], scope='Conv2d_0b_3x3') 67 | branch_1 = slim.conv2d(branch_1, 256, [3, 3], stride=2, 68 | padding='VALID', scope='Conv2d_1a_3x3') 69 | with tf.variable_scope('Branch_2'): 70 | branch_2 = slim.max_pool2d(inputs, [3, 3], stride=2, padding='VALID', 71 | scope='MaxPool_1a_3x3') 72 | return tf.concat(axis=3, values=[branch_0, branch_1, branch_2]) 73 | 74 | 75 | def block_inception_b(inputs, scope=None, reuse=None): 76 | """Builds Inception-B block for Inception v4 network.""" 77 | # By default use stride=1 and SAME padding 78 | with slim.arg_scope([slim.conv2d, slim.avg_pool2d, slim.max_pool2d], 79 | stride=1, padding='SAME'): 80 | with tf.variable_scope(scope, 'BlockInceptionB', [inputs], reuse=reuse): 81 | with tf.variable_scope('Branch_0'): 82 | branch_0 = slim.conv2d(inputs, 384, [1, 1], scope='Conv2d_0a_1x1') 83 | with tf.variable_scope('Branch_1'): 84 | branch_1 = slim.conv2d(inputs, 192, [1, 1], scope='Conv2d_0a_1x1') 85 | branch_1 = slim.conv2d(branch_1, 224, [1, 7], scope='Conv2d_0b_1x7') 86 | branch_1 = slim.conv2d(branch_1, 256, [7, 1], scope='Conv2d_0c_7x1') 87 | with tf.variable_scope('Branch_2'): 88 | branch_2 = slim.conv2d(inputs, 192, [1, 1], scope='Conv2d_0a_1x1') 89 | branch_2 = slim.conv2d(branch_2, 192, [7, 1], scope='Conv2d_0b_7x1') 90 | branch_2 = 
slim.conv2d(branch_2, 224, [1, 7], scope='Conv2d_0c_1x7') 91 | branch_2 = slim.conv2d(branch_2, 224, [7, 1], scope='Conv2d_0d_7x1') 92 | branch_2 = slim.conv2d(branch_2, 256, [1, 7], scope='Conv2d_0e_1x7') 93 | with tf.variable_scope('Branch_3'): 94 | branch_3 = slim.avg_pool2d(inputs, [3, 3], scope='AvgPool_0a_3x3') 95 | branch_3 = slim.conv2d(branch_3, 128, [1, 1], scope='Conv2d_0b_1x1') 96 | return tf.concat(axis=3, values=[branch_0, branch_1, branch_2, branch_3]) 97 | 98 | 99 | def block_reduction_b(inputs, scope=None, reuse=None): 100 | """Builds Reduction-B block for Inception v4 network.""" 101 | # By default use stride=1 and SAME padding 102 | with slim.arg_scope([slim.conv2d, slim.avg_pool2d, slim.max_pool2d], 103 | stride=1, padding='SAME'): 104 | with tf.variable_scope(scope, 'BlockReductionB', [inputs], reuse=reuse): 105 | with tf.variable_scope('Branch_0'): 106 | branch_0 = slim.conv2d(inputs, 192, [1, 1], scope='Conv2d_0a_1x1') 107 | branch_0 = slim.conv2d(branch_0, 192, [3, 3], stride=2, 108 | padding='VALID', scope='Conv2d_1a_3x3') 109 | with tf.variable_scope('Branch_1'): 110 | branch_1 = slim.conv2d(inputs, 256, [1, 1], scope='Conv2d_0a_1x1') 111 | branch_1 = slim.conv2d(branch_1, 256, [1, 7], scope='Conv2d_0b_1x7') 112 | branch_1 = slim.conv2d(branch_1, 320, [7, 1], scope='Conv2d_0c_7x1') 113 | branch_1 = slim.conv2d(branch_1, 320, [3, 3], stride=2, 114 | padding='VALID', scope='Conv2d_1a_3x3') 115 | with tf.variable_scope('Branch_2'): 116 | branch_2 = slim.max_pool2d(inputs, [3, 3], stride=2, padding='VALID', 117 | scope='MaxPool_1a_3x3') 118 | return tf.concat(axis=3, values=[branch_0, branch_1, branch_2]) 119 | 120 | 121 | def block_inception_c(inputs, scope=None, reuse=None): 122 | """Builds Inception-C block for Inception v4 network.""" 123 | # By default use stride=1 and SAME padding 124 | with slim.arg_scope([slim.conv2d, slim.avg_pool2d, slim.max_pool2d], 125 | stride=1, padding='SAME'): 126 | with tf.variable_scope(scope, 'BlockInceptionC', [inputs], reuse=reuse): 127 | with tf.variable_scope('Branch_0'): 128 | branch_0 = slim.conv2d(inputs, 256, [1, 1], scope='Conv2d_0a_1x1') 129 | with tf.variable_scope('Branch_1'): 130 | branch_1 = slim.conv2d(inputs, 384, [1, 1], scope='Conv2d_0a_1x1') 131 | branch_1 = tf.concat(axis=3, values=[ 132 | slim.conv2d(branch_1, 256, [1, 3], scope='Conv2d_0b_1x3'), 133 | slim.conv2d(branch_1, 256, [3, 1], scope='Conv2d_0c_3x1')]) 134 | with tf.variable_scope('Branch_2'): 135 | branch_2 = slim.conv2d(inputs, 384, [1, 1], scope='Conv2d_0a_1x1') 136 | branch_2 = slim.conv2d(branch_2, 448, [3, 1], scope='Conv2d_0b_3x1') 137 | branch_2 = slim.conv2d(branch_2, 512, [1, 3], scope='Conv2d_0c_1x3') 138 | branch_2 = tf.concat(axis=3, values=[ 139 | slim.conv2d(branch_2, 256, [1, 3], scope='Conv2d_0d_1x3'), 140 | slim.conv2d(branch_2, 256, [3, 1], scope='Conv2d_0e_3x1')]) 141 | with tf.variable_scope('Branch_3'): 142 | branch_3 = slim.avg_pool2d(inputs, [3, 3], scope='AvgPool_0a_3x3') 143 | branch_3 = slim.conv2d(branch_3, 256, [1, 1], scope='Conv2d_0b_1x1') 144 | return tf.concat(axis=3, values=[branch_0, branch_1, branch_2, branch_3]) 145 | 146 | 147 | def inception_v4_base(inputs, final_endpoint='Mixed_7d', scope=None): 148 | """Creates the Inception V4 network up to the given final endpoint. 149 | 150 | Args: 151 | inputs: a 4-D tensor of size [batch_size, height, width, 3]. 152 | final_endpoint: specifies the endpoint to construct the network up to. 
153 | It can be one of [ 'Conv2d_1a_3x3', 'Conv2d_2a_3x3', 'Conv2d_2b_3x3', 154 | 'Mixed_3a', 'Mixed_4a', 'Mixed_5a', 'Mixed_5b', 'Mixed_5c', 'Mixed_5d', 155 | 'Mixed_5e', 'Mixed_6a', 'Mixed_6b', 'Mixed_6c', 'Mixed_6d', 'Mixed_6e', 156 | 'Mixed_6f', 'Mixed_6g', 'Mixed_6h', 'Mixed_7a', 'Mixed_7b', 'Mixed_7c', 157 | 'Mixed_7d'] 158 | scope: Optional variable_scope. 159 | 160 | Returns: 161 | logits: the logits outputs of the model. 162 | end_points: the set of end_points from the inception model. 163 | 164 | Raises: 165 | ValueError: if final_endpoint is not set to one of the predefined values, 166 | """ 167 | end_points = {} 168 | 169 | def add_and_check_final(name, net): 170 | end_points[name] = net 171 | return name == final_endpoint 172 | 173 | with tf.variable_scope(scope, 'InceptionV4', [inputs]): 174 | with slim.arg_scope([slim.conv2d, slim.max_pool2d, slim.avg_pool2d], 175 | stride=1, padding='SAME'): 176 | # 299 x 299 x 3 177 | net = slim.conv2d(inputs, 32, [3, 3], stride=2, 178 | padding='VALID', scope='Conv2d_1a_3x3') 179 | if add_and_check_final('Conv2d_1a_3x3', net): return net, end_points 180 | # 149 x 149 x 32 181 | net = slim.conv2d(net, 32, [3, 3], padding='VALID', 182 | scope='Conv2d_2a_3x3') 183 | if add_and_check_final('Conv2d_2a_3x3', net): return net, end_points 184 | # 147 x 147 x 32 185 | net = slim.conv2d(net, 64, [3, 3], scope='Conv2d_2b_3x3') 186 | if add_and_check_final('Conv2d_2b_3x3', net): return net, end_points 187 | # 147 x 147 x 64 188 | with tf.variable_scope('Mixed_3a'): 189 | with tf.variable_scope('Branch_0'): 190 | branch_0 = slim.max_pool2d(net, [3, 3], stride=2, padding='VALID', 191 | scope='MaxPool_0a_3x3') 192 | with tf.variable_scope('Branch_1'): 193 | branch_1 = slim.conv2d(net, 96, [3, 3], stride=2, padding='VALID', 194 | scope='Conv2d_0a_3x3') 195 | net = tf.concat(axis=3, values=[branch_0, branch_1]) 196 | if add_and_check_final('Mixed_3a', net): return net, end_points 197 | 198 | # 73 x 73 x 160 199 | with tf.variable_scope('Mixed_4a'): 200 | with tf.variable_scope('Branch_0'): 201 | branch_0 = slim.conv2d(net, 64, [1, 1], scope='Conv2d_0a_1x1') 202 | branch_0 = slim.conv2d(branch_0, 96, [3, 3], padding='VALID', 203 | scope='Conv2d_1a_3x3') 204 | with tf.variable_scope('Branch_1'): 205 | branch_1 = slim.conv2d(net, 64, [1, 1], scope='Conv2d_0a_1x1') 206 | branch_1 = slim.conv2d(branch_1, 64, [1, 7], scope='Conv2d_0b_1x7') 207 | branch_1 = slim.conv2d(branch_1, 64, [7, 1], scope='Conv2d_0c_7x1') 208 | branch_1 = slim.conv2d(branch_1, 96, [3, 3], padding='VALID', 209 | scope='Conv2d_1a_3x3') 210 | net = tf.concat(axis=3, values=[branch_0, branch_1]) 211 | if add_and_check_final('Mixed_4a', net): return net, end_points 212 | 213 | # 71 x 71 x 192 214 | with tf.variable_scope('Mixed_5a'): 215 | with tf.variable_scope('Branch_0'): 216 | branch_0 = slim.conv2d(net, 192, [3, 3], stride=2, padding='VALID', 217 | scope='Conv2d_1a_3x3') 218 | with tf.variable_scope('Branch_1'): 219 | branch_1 = slim.max_pool2d(net, [3, 3], stride=2, padding='VALID', 220 | scope='MaxPool_1a_3x3') 221 | net = tf.concat(axis=3, values=[branch_0, branch_1]) 222 | if add_and_check_final('Mixed_5a', net): return net, end_points 223 | 224 | # 35 x 35 x 384 225 | # 4 x Inception-A blocks 226 | for idx in range(4): 227 | block_scope = 'Mixed_5' + chr(ord('b') + idx) 228 | net = block_inception_a(net, block_scope) 229 | if add_and_check_final(block_scope, net): return net, end_points 230 | 231 | # 35 x 35 x 384 232 | # Reduction-A block 233 | net = block_reduction_a(net, 
'Mixed_6a') 234 | if add_and_check_final('Mixed_6a', net): return net, end_points 235 | 236 | # 17 x 17 x 1024 237 | # 7 x Inception-B blocks 238 | for idx in range(7): 239 | block_scope = 'Mixed_6' + chr(ord('b') + idx) 240 | net = block_inception_b(net, block_scope) 241 | if add_and_check_final(block_scope, net): return net, end_points 242 | 243 | # 17 x 17 x 1024 244 | # Reduction-B block 245 | net = block_reduction_b(net, 'Mixed_7a') 246 | if add_and_check_final('Mixed_7a', net): return net, end_points 247 | 248 | # 8 x 8 x 1536 249 | # 3 x Inception-C blocks 250 | for idx in range(3): 251 | block_scope = 'Mixed_7' + chr(ord('b') + idx) 252 | net = block_inception_c(net, block_scope) 253 | if add_and_check_final(block_scope, net): return net, end_points 254 | raise ValueError('Unknown final endpoint %s' % final_endpoint) 255 | 256 | 257 | def inception_v4(inputs, num_classes=1001, is_training=True, 258 | dropout_keep_prob=0.8, 259 | reuse=None, 260 | scope='InceptionV4', 261 | create_aux_logits=True): 262 | """Creates the Inception V4 model. 263 | 264 | Args: 265 | inputs: a 4-D tensor of size [batch_size, height, width, 3]. 266 | num_classes: number of predicted classes. 267 | is_training: whether is training or not. 268 | dropout_keep_prob: float, the fraction to keep before final layer. 269 | reuse: whether or not the network and its variables should be reused. To be 270 | able to reuse 'scope' must be given. 271 | scope: Optional variable_scope. 272 | create_aux_logits: Whether to include the auxiliary logits. 273 | 274 | Returns: 275 | logits: the logits outputs of the model. 276 | end_points: the set of end_points from the inception model. 277 | """ 278 | end_points = {} 279 | with tf.variable_scope(scope, 'InceptionV4', [inputs], reuse=reuse) as scope: 280 | with slim.arg_scope([slim.batch_norm, slim.dropout], 281 | is_training=is_training): 282 | net, end_points = inception_v4_base(inputs, scope=scope) 283 | 284 | with slim.arg_scope([slim.conv2d, slim.max_pool2d, slim.avg_pool2d], 285 | stride=1, padding='SAME'): 286 | # Auxiliary Head logits 287 | if create_aux_logits: 288 | with tf.variable_scope('AuxLogits'): 289 | # 17 x 17 x 1024 290 | aux_logits = end_points['Mixed_6h'] 291 | aux_logits = slim.avg_pool2d(aux_logits, [5, 5], stride=3, 292 | padding='VALID', 293 | scope='AvgPool_1a_5x5') 294 | aux_logits = slim.conv2d(aux_logits, 128, [1, 1], 295 | scope='Conv2d_1b_1x1') 296 | aux_logits = slim.conv2d(aux_logits, 768, 297 | aux_logits.get_shape()[1:3], 298 | padding='VALID', scope='Conv2d_2a') 299 | aux_logits = slim.flatten(aux_logits) 300 | aux_logits = slim.fully_connected(aux_logits, num_classes, 301 | activation_fn=None, 302 | scope='Aux_logits') 303 | end_points['AuxLogits'] = aux_logits 304 | 305 | # Final pooling and prediction 306 | with tf.variable_scope('Logits'): 307 | # 8 x 8 x 1536 308 | net = slim.avg_pool2d(net, net.get_shape()[1:3], padding='VALID', 309 | scope='AvgPool_1a') 310 | # 1 x 1 x 1536 311 | net = slim.dropout(net, dropout_keep_prob, scope='Dropout_1b') 312 | net = slim.flatten(net, scope='PreLogitsFlatten') 313 | end_points['PreLogitsFlatten'] = net 314 | # 1536 315 | logits = slim.fully_connected(net, num_classes, activation_fn=None, 316 | scope='Logits') 317 | end_points['Logits'] = logits 318 | end_points['Predictions'] = tf.nn.softmax(logits, name='Predictions') 319 | return logits, end_points 320 | 321 | 322 | inception_v4.default_image_size = 299 323 | 324 | inception_v4_arg_scope = inception_utils.inception_arg_scope 325 | 
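# Usage sketch (illustrative only, mirroring how train_single.py restores the
# perceptual model): build the network, restore the checkpoint the Readme asks
# you to place under inception_v4_model/, and run one dummy image through it.
# The placeholder name and the dummy input are assumptions for this example.
if __name__ == '__main__':
    import numpy as np
    images = tf.placeholder(tf.float32, [None, 299, 299, 3], name='images')
    with slim.arg_scope(inception_v4_arg_scope()):
        logits, end_points = inception_v4(images, num_classes=1001, is_training=False)
    restorer = tf.train.Saver(tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='InceptionV4'))
    with tf.Session() as sess:
        restorer.restore(sess, './inception_v4_model/inception_v4.ckpt')
        probs = sess.run(end_points['Predictions'],
                         feed_dict={images: np.zeros((1, 299, 299, 3), dtype=np.float32)})
        print(probs.shape)  # (1, 1001)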
-------------------------------------------------------------------------------- /src_single/input_pipeline.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import numpy as np 4 | import tensorflow as tf 5 | from data_processing.tfrecord import * 6 | 7 | from scipy import ndimage 8 | from config import Config 9 | 10 | 11 | # TODO Change to Dataset API 12 | sketchy_dir = '../training_data/sketchy' 13 | flickr_dir = '../training_data/flickr_output' 14 | 15 | 16 | paired_filenames_1 = [os.path.join(sketchy_dir, f) for f in os.listdir(sketchy_dir) 17 | if os.path.isfile(os.path.join(sketchy_dir, f))] 18 | paired_filenames_2 = [os.path.join(flickr_dir, f) for f in os.listdir(flickr_dir) 19 | if os.path.isfile(os.path.join(flickr_dir, f))] 20 | 21 | print("paired file sketchy num: %d" % len(paired_filenames_1)) 22 | print("paired file flickr num: %d" % len(paired_filenames_2)) 23 | 24 | # build class map 25 | class_mapping = [] 26 | classes_info = './data_processing/classes.csv' 27 | classes = read_csv(classes_info) 28 | classes_id = [item['Name'] for item in classes] 29 | for name in paired_filenames_1: 30 | name = os.path.splitext(os.path.split(name)[1])[0].split('_coco_')[0] 31 | class_id = classes_id.index(name) 32 | if class_id not in class_mapping: 33 | class_mapping.append(class_id) 34 | class_mapping = sorted(class_mapping) 35 | for name in paired_filenames_2: 36 | name = os.path.splitext(os.path.split(name)[1])[0].split('_coco_')[0] 37 | class_id = classes_id.index(name) 38 | if class_id not in class_mapping: 39 | print(name) 40 | raise RuntimeError 41 | num_classes = len(class_mapping) 42 | 43 | 44 | def get_num_classes(): 45 | return num_classes 46 | 47 | 48 | def one_hot_to_dense(labels): 49 | # Assume on value is 1 50 | batch_size = int(labels.get_shape()[0]) 51 | return tf.reshape(tf.where(tf.equal(labels, 1))[:, 1], (batch_size,)) 52 | 53 | 54 | def map_class_id_to_labels(batch_class_id, class_mapping=class_mapping): 55 | batch_class_id_backup = tf.identity(batch_class_id) 56 | 57 | for i in range(num_classes): 58 | comparison = tf.equal(batch_class_id_backup, tf.constant(class_mapping[i], dtype=tf.int64)) 59 | batch_class_id = tf.where(comparison, tf.ones_like(batch_class_id) * i, batch_class_id) 60 | ret_tensor = tf.squeeze(tf.one_hot(tf.cast(batch_class_id, dtype=tf.int32), num_classes, 61 | on_value=1, off_value=0, axis=1)) 62 | return ret_tensor 63 | 64 | 65 | def binarize(sketch, threshold=250): 66 | return tf.where(sketch < threshold, x=tf.zeros_like(sketch), y=tf.ones_like(sketch) * 255.) 
67 | 68 | 69 | # SKETCH_CHANNEL = 3 70 | SIZE = {True: (64, 64), 71 | False: (256, 256)} 72 | 73 | 74 | # Distance map first, then resize 75 | def get_paired_input(paired_filenames, test_mode, distance_map=True, img_dim=(256, 256), 76 | fancy_upscaling=False, data_format='NCHW'): 77 | if test_mode: 78 | num_epochs = 1 79 | shuffle = False 80 | else: 81 | num_epochs = None 82 | shuffle = True 83 | filename_queue = tf.train.string_input_producer( 84 | paired_filenames, capacity=512, shuffle=shuffle, num_epochs=num_epochs) 85 | reader = tf.TFRecordReader() 86 | 87 | _, serialized_example = reader.read(filename_queue) 88 | 89 | features = tf.parse_single_example( 90 | serialized_example, 91 | features={ 92 | 'ImageNetID': tf.FixedLenFeature([], tf.string), 93 | 'SketchID': tf.FixedLenFeature([], tf.int64), 94 | 'Category': tf.FixedLenFeature([], tf.string), 95 | 'CategoryID': tf.FixedLenFeature([], tf.int64), 96 | 'Difficulty': tf.FixedLenFeature([], tf.int64), 97 | 'Stroke_Count': tf.FixedLenFeature([], tf.int64), 98 | 'WrongPose': tf.FixedLenFeature([], tf.int64), 99 | 'Context': tf.FixedLenFeature([], tf.int64), 100 | 'Ambiguous': tf.FixedLenFeature([], tf.int64), 101 | 'Error': tf.FixedLenFeature([], tf.int64), 102 | 'class_id': tf.FixedLenFeature([], tf.int64), 103 | 'is_test': tf.FixedLenFeature([], tf.int64), 104 | 'image_jpeg': tf.FixedLenFeature([], tf.string), 105 | 'image_small_jpeg': tf.FixedLenFeature([], tf.string), 106 | 'sketch_png': tf.FixedLenFeature([], tf.string), 107 | 'sketch_small_png': tf.FixedLenFeature([], tf.string), 108 | 'dist_map_png': tf.FixedLenFeature([], tf.string), 109 | 'dist_map_small_png': tf.FixedLenFeature([], tf.string), 110 | } 111 | ) 112 | 113 | if img_dim[0] < 64: 114 | image = tf.image.decode_jpeg(features['image_small_jpeg'], fancy_upscaling=fancy_upscaling) 115 | image = tf.cast(image, tf.float32) 116 | image = tf.reshape(image, (64, 64, 3)) 117 | else: 118 | image = tf.image.decode_jpeg(features['image_jpeg'], fancy_upscaling=fancy_upscaling) 119 | image = tf.cast(image, tf.float32) 120 | image = tf.reshape(image, (256, 256, 3)) 121 | 122 | if img_dim[0] < 64: 123 | if Config.pre_calculated_dist_map: 124 | sketch = tf.image.decode_png(features['dist_map_small_png'], channels=3) if distance_map \ 125 | else tf.image.decode_png(features['sketch_small_png'], channels=3) 126 | else: 127 | sketch = tf.image.decode_png(features['sketch_small_png'], channels=3) 128 | sketch = tf.cast(sketch, tf.float32) 129 | sketch = tf.reshape(sketch, (64, 64, 3)) 130 | else: 131 | if Config.pre_calculated_dist_map: 132 | sketch = tf.image.decode_png(features['dist_map_png'], channels=3) if distance_map \ 133 | else tf.image.decode_png(features['sketch_png'], channels=3) 134 | else: 135 | sketch = tf.image.decode_png(features['sketch_png'], channels=3) 136 | sketch = tf.cast(sketch, tf.float32) 137 | sketch = tf.reshape(sketch, (256, 256, 3)) 138 | 139 | # Distance map 140 | if not Config.pre_calculated_dist_map and distance_map: 141 | # Binarize 142 | sketch = binarize(sketch) 143 | sketch_shape = sketch.shape 144 | 145 | sketch = tf.py_func(lambda x: ndimage.distance_transform_edt(x).astype(np.float32), 146 | [sketch], tf.float32, stateful=False) 147 | sketch = tf.reshape(sketch, sketch_shape) 148 | # Normalize 149 | sketch = sketch / tf.reduce_max(sketch) * 255. 
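        # Note on the distance-map branch above: after binarize(), stroke pixels
        # are 0 and background pixels are 255, and distance_transform_edt gives
        # every non-zero pixel its Euclidean distance to the nearest stroke pixel.
        # For a single row [0, 255, 255] the distances are [0, 1, 2], which the
        # normalization rescales to [0, 127.5, 255] (dark near strokes).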
150 | 151 | # Resize 152 | if img_dim[0] != 256: 153 | image = tf.image.resize_images(image, img_dim, method=tf.image.ResizeMethod.BILINEAR) 154 | sketch = tf.image.resize_images(sketch, img_dim, method=tf.image.ResizeMethod.BILINEAR) 155 | # if img_dim[0] > 256: 156 | # image = tf.image.resize_images(image, img_dim, method=tf.image.ResizeMethod.BILINEAR) 157 | # sketch = tf.image.resize_images(sketch, img_dim, method=tf.image.ResizeMethod.BILINEAR) 158 | # elif img_dim[0] < 256: 159 | # image = tf.image.resize_images(image, img_dim, method=tf.image.ResizeMethod.AREA) 160 | # sketch = tf.image.resize_images(sketch, img_dim, method=tf.image.ResizeMethod.AREA) 161 | 162 | # Augmentation 163 | # Image 164 | image = tf.image.random_brightness(image, max_delta=0.3) 165 | image = tf.image.random_contrast(image, lower=0.8, upper=1.2) 166 | # image_large = tf.image.random_hue(image_large, max_delta=0.05) 167 | 168 | # Normalization 169 | image = (image - tf.reduce_min(image)) / (tf.reduce_max(image) - tf.reduce_min(image) + 1) 170 | image += tf.random_uniform(shape=image.shape, minval=0., maxval=1. / 256) # dequantize 171 | sketch = sketch / 255. 172 | 173 | image = image * 2. - 1 174 | sketch = sketch * 2. - 1 175 | 176 | # Transpose for data format 177 | if data_format == 'NCHW': 178 | image = tf.transpose(image, [2, 0, 1]) 179 | sketch = tf.transpose(sketch, [2, 0, 1]) 180 | 181 | # Attributes 182 | category = features['Category'] 183 | imagenet_id = features['ImageNetID'] 184 | sketch_id = features['SketchID'] 185 | class_id = features['class_id'] 186 | is_test = features['is_test'] 187 | WrongPose = features['WrongPose'] 188 | Context = features['Context'] 189 | Ambiguous = features['Ambiguous'] 190 | Error = features['Error'] 191 | 192 | if not test_mode: 193 | is_valid = WrongPose + Context + Ambiguous + Error + is_test 194 | else: 195 | is_valid = 1 - is_test 196 | 197 | is_valid = tf.equal(is_valid, 0) 198 | 199 | return image, sketch, class_id, is_valid, category, imagenet_id, sketch_id 200 | 201 | 202 | def build_input_queue_paired_sketchy(batch_size, data_format='NCHW', distance_map=True, small=True, one_hot=False, 203 | capacity=8192): 204 | image, sketch, class_id, is_valid, _, _, _ = get_paired_input( 205 | paired_filenames_1, test_mode=False, distance_map=distance_map, img_dim=SIZE[small], data_format=data_format) 206 | 207 | images, sketches, class_ids = tf.train.maybe_shuffle_batch( 208 | [image, sketch, class_id], 209 | batch_size=batch_size, capacity=capacity, 210 | keep_input=is_valid, min_after_dequeue=32, 211 | num_threads=4) 212 | 213 | if one_hot: 214 | labels = map_class_id_to_labels(class_ids) 215 | else: 216 | labels = one_hot_to_dense(map_class_id_to_labels(class_ids)) 217 | return images, sketches, labels 218 | 219 | 220 | def build_input_queue_paired_sketchy_test(batch_size, data_format='NCHW', distance_map=True, small=True, one_hot=False, 221 | capacity=8192): 222 | image, sketch, class_id, is_valid, category, imagenet_id, sketch_id = get_paired_input( 223 | paired_filenames_1, test_mode=True, distance_map=distance_map, img_dim=SIZE[small], data_format=data_format) 224 | 225 | images, sketches, class_ids, categories, imagenet_ids, sketch_ids = tf.train.maybe_batch( 226 | [image, sketch, class_id, category, imagenet_id, sketch_id], 227 | batch_size=batch_size, capacity=capacity, 228 | keep_input=is_valid, num_threads=2) 229 | 230 | if one_hot: 231 | labels = map_class_id_to_labels(class_ids) 232 | else: 233 | labels = 
one_hot_to_dense(map_class_id_to_labels(class_ids)) 234 | 235 | return images, sketches, labels, categories, imagenet_ids, sketch_ids 236 | 237 | 238 | def build_input_queue_paired_flickr(batch_size, data_format='NCHW', distance_map=True, small=True, one_hot=False, 239 | capacity=int(1.5 * 2 ** 15)): 240 | image, sketch, class_id, is_valid, _, _, _ = get_paired_input( 241 | paired_filenames_2, test_mode=False, distance_map=distance_map, img_dim=SIZE[small], data_format=data_format) 242 | 243 | images, sketches, class_ids = tf.train.maybe_shuffle_batch( 244 | [image, sketch, class_id], 245 | batch_size=batch_size, capacity=capacity, 246 | keep_input=is_valid, min_after_dequeue=512, 247 | num_threads=4) 248 | 249 | if one_hot: 250 | labels = map_class_id_to_labels(class_ids) 251 | else: 252 | labels = one_hot_to_dense(map_class_id_to_labels(class_ids)) 253 | 254 | return images, sketches, labels 255 | 256 | 257 | def build_input_queue_paired_mixed(batch_size, proportion=None, data_format='NCHW', distance_map=True, small=True, 258 | one_hot=False, capacity=int(1.5 * 2 ** 15)): 259 | def _sk_list(): 260 | image_sk, sketch_sk, class_id_sk, is_valid_sk, _, _, _ = get_paired_input( 261 | paired_filenames_1, test_mode=False, distance_map=distance_map, img_dim=SIZE[small], data_format=data_format) 262 | return image_sk, sketch_sk, class_id_sk, is_valid_sk 263 | 264 | def _f_list(): 265 | image_f, sketch_f, class_id_f, is_valid_f, _, _, _ = get_paired_input( 266 | paired_filenames_2, test_mode=False, distance_map=distance_map, img_dim=SIZE[small], data_format=data_format) 267 | return image_f, sketch_f, class_id_f, is_valid_f 268 | 269 | idx = tf.floor(tf.random_uniform(shape=(), minval=0., maxval=1., dtype=tf.float32) + proportion) 270 | sk_list = _sk_list() 271 | f_list = _f_list() 272 | image, sketch, class_id, is_valid = [ 273 | tf.cast(a, tf.float32) * idx + tf.cast(b, tf.float32) * (1 - idx) for a, b in zip(sk_list, f_list) 274 | ] 275 | class_id = tf.cast(class_id, tf.int64) 276 | is_valid = tf.cast(is_valid, tf.bool) 277 | # is_valid = tf.Print(is_valid, [idx, sk_list[4], f_list[4], class_id, sk_list[5], f_list[5], is_valid]) 278 | 279 | images, sketches, class_ids = tf.train.maybe_shuffle_batch( 280 | [image, sketch, class_id], 281 | batch_size=batch_size, capacity=capacity, 282 | keep_input=is_valid, min_after_dequeue=512, 283 | num_threads=4) 284 | 285 | if one_hot: 286 | labels = map_class_id_to_labels(class_ids) 287 | else: 288 | labels = one_hot_to_dense(map_class_id_to_labels(class_ids)) 289 | 290 | return images, sketches, labels 291 | 292 | 293 | def split_inputs(input_data, batch_size, batch_portion, num_gpu): 294 | input_data_list = [] 295 | dim = len(input_data.get_shape()) 296 | start = 0 297 | for i in range(num_gpu): 298 | idx = [start] 299 | size = [batch_size * batch_portion[i]] 300 | idx.extend([0] * (dim - 1)) 301 | size.extend([-1] * (dim - 1)) 302 | input_data_list.append(tf.slice(input_data, idx, size)) 303 | 304 | start += batch_size * batch_portion[i] 305 | return input_data_list 306 | -------------------------------------------------------------------------------- /src_single/models_mru.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | import tensorflow.contrib.layers as ly 4 | 5 | from mru import embed_labels, fully_connected, conv2d, mean_pool, upsample, mru_conv, mru_deconv 6 | from config import Config 7 | 8 | SIZE = 64 9 | NUM_BLOCKS = 1 10 | 11 | 12 | def 
image_resize(inputs, size, method, data_format): 13 | if data_format == 'NCHW': 14 | inputs = tf.transpose(inputs, [0, 2, 3, 1]) 15 | out = tf.image.resize_images(inputs, size, method) 16 | if data_format == 'NCHW': 17 | out = tf.transpose(out, [0, 3, 1, 2]) 18 | return out 19 | 20 | 21 | def batchnorm(inputs, data_format=None, activation_fn=None, labels=None, n_labels=None): 22 | """conditional batchnorm (dumoulin et al 2016) for BCHW conv filtermaps""" 23 | if data_format != 'NCHW': 24 | raise Exception('unsupported') 25 | mean, var = tf.nn.moments(inputs, (0, 2, 3) if len(inputs.shape) == 4 else (0,), keep_dims=True) 26 | shape = mean.get_shape().as_list() # shape is [1,n,1,1] 27 | offset_m = tf.get_variable('offset', initializer=np.zeros([n_labels, shape[1]], dtype='float32')) 28 | scale_m = tf.get_variable('scale', initializer=np.ones([n_labels, shape[1]], dtype='float32')) 29 | offset = tf.nn.embedding_lookup(offset_m, labels) 30 | scale = tf.nn.embedding_lookup(scale_m, labels) 31 | result = tf.nn.batch_normalization(inputs, mean, var, 32 | offset[:, :, None, None] if len(inputs.shape) == 4 else offset[:, :], 33 | scale[:, :, None, None] if len(inputs.shape) == 4 else scale[:, :], 34 | 1e-5) 35 | return result 36 | 37 | 38 | def conditional(inputs, data_format=None, activation_fn=None, labels=None, n_labels=10): 39 | if data_format != 'NCHW': 40 | raise Exception('unsupported') 41 | with tf.variable_scope(None, 'conditional_shift'): 42 | depth = inputs.get_shape().as_list()[1] 43 | offset_m = tf.get_variable('offset', initializer=np.zeros([n_labels, depth], dtype='float32')) 44 | scale_m = tf.get_variable('scale', initializer=np.ones([n_labels, depth], dtype='float32')) 45 | offset = tf.nn.embedding_lookup(offset_m, labels) 46 | scale = tf.nn.embedding_lookup(scale_m, labels) 47 | result = inputs * scale[:, :, None, None] + offset[:, :, None, None] 48 | return result 49 | 50 | 51 | def lrelu(x, leak=0.3, name="lrelu"): 52 | with tf.variable_scope(name): 53 | return tf.maximum(leak * x, x) 54 | 55 | 56 | def prelu(x, name="prelu"): 57 | with tf.variable_scope(name): 58 | leak = tf.get_variable("param", shape=None, initializer=0.2, regularizer=None, 59 | trainable=True, caching_device=None) 60 | return tf.maximum(leak * x, x) 61 | 62 | 63 | def miu_relu(x, miu=0.7, name="miu_relu"): 64 | with tf.variable_scope(name): 65 | return (x + tf.sqrt((1 - miu) ** 2 + x ** 2)) / 2. 
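# A note on miu_relu above: it is a smooth, strictly positive ReLU-like curve,
# f(x) = (x + sqrt((1 - miu)^2 + x^2)) / 2, so f(0) = (1 - miu) / 2, f(x) -> x
# for large positive x and f(x) -> 0 for large negative x. With the default
# miu = 0.7, f([-5, 0, 5]) is approximately [0.0045, 0.15, 5.0045].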
66 | 67 | 68 | def image_encoder_mru(x, num_classes, reuse=False, data_format='NCHW', labels=None, scope_name=None): 69 | assert data_format == 'NCHW' 70 | size = SIZE 71 | num_blocks = NUM_BLOCKS 72 | resize_func = tf.image.resize_bilinear 73 | sn = False 74 | 75 | if normalizer_params_e is not None and normalizer_fn_e != ly.batch_norm and normalizer_fn_e != ly.layer_norm: 76 | normalizer_params_e['labels'] = labels 77 | normalizer_params_e['n_labels'] = num_classes 78 | 79 | if data_format == 'NCHW': 80 | x_list = [] 81 | resized_ = x 82 | x_list.append(resized_) 83 | 84 | for i in range(4): 85 | resized_ = mean_pool(resized_, data_format=data_format) 86 | x_list.append(resized_) 87 | x_list = x_list[::-1] 88 | else: 89 | raise NotImplementedError 90 | 91 | output_list = [] 92 | 93 | h0 = conv2d(x_list[-1], 8, kernel_size=7, sn=sn, stride=2, data_format=data_format, 94 | activation_fn=None, 95 | normalizer_fn=None, 96 | normalizer_params=None, 97 | weights_initializer=weight_initializer) 98 | 99 | output_list.append(h0) 100 | 101 | # Initial memory state 102 | hidden_state_shape = h0.get_shape().as_list() 103 | hidden_state_shape[0] = 1 104 | hts_0 = [h0] 105 | 106 | hts_1 = mru_conv(x_list[-2], hts_0, 107 | size * 1, sn=sn, stride=2, dilate_rate=1, 108 | data_format=data_format, num_blocks=num_blocks, 109 | last_unit=False, 110 | activation_fn=activation_fn_e, 111 | normalizer_fn=normalizer_fn_e, 112 | normalizer_params=normalizer_params_e, 113 | weights_initializer=weight_initializer, 114 | unit_num=1) 115 | output_list.append(hts_1[-1]) 116 | hts_2 = mru_conv(x_list[-3], hts_1, 117 | size * 2, sn=sn, stride=2, dilate_rate=1, 118 | data_format=data_format, num_blocks=num_blocks, 119 | last_unit=False, 120 | activation_fn=activation_fn_e, 121 | normalizer_fn=normalizer_fn_e, 122 | normalizer_params=normalizer_params_e, 123 | weights_initializer=weight_initializer, 124 | unit_num=2) 125 | output_list.append(hts_2[-1]) 126 | hts_3 = mru_conv(x_list[-4], hts_2, 127 | size * 4, sn=sn, stride=2, dilate_rate=1, 128 | data_format=data_format, num_blocks=num_blocks, 129 | last_unit=False, 130 | activation_fn=activation_fn_e, 131 | normalizer_fn=normalizer_fn_e, 132 | normalizer_params=normalizer_params_e, 133 | weights_initializer=weight_initializer, 134 | unit_num=3) 135 | output_list.append(hts_3[-1]) 136 | hts_4 = mru_conv(x_list[-5], hts_3, 137 | size * 8, sn=sn, stride=2, dilate_rate=1, 138 | data_format=data_format, num_blocks=num_blocks, 139 | last_unit=True, 140 | activation_fn=activation_fn_e, 141 | normalizer_fn=normalizer_fn_e, 142 | normalizer_params=normalizer_params_e, 143 | weights_initializer=weight_initializer, 144 | unit_num=4) 145 | output_list.append(hts_4[-1]) 146 | 147 | return output_list 148 | 149 | 150 | def generator_skip(z, output_channel, num_classes, reuse=False, data_format='NCHW', 151 | labels=None, scope_name=None): 152 | print("G") 153 | size = SIZE 154 | num_blocks = NUM_BLOCKS 155 | sn = False 156 | 157 | input_dims = z.get_shape().as_list() 158 | resize_method = tf.image.ResizeMethod.AREA 159 | 160 | if data_format == 'NCHW': 161 | height = input_dims[2] 162 | width = input_dims[3] 163 | else: 164 | height = input_dims[1] 165 | width = input_dims[2] 166 | resized_z = [tf.identity(z)] 167 | for i in range(5): 168 | resized_z.append(image_resize(z, [int(height / 2 ** (i + 1)), int(width / 2 ** (i + 1))], 169 | resize_method, data_format)) 170 | resized_z = resized_z[::-1] 171 | 172 | if data_format == 'NCHW': 173 | concat_axis = 1 174 | else: 175 | concat_axis 
= 3 176 | 177 | if normalizer_params_g is not None and normalizer_fn_g != ly.batch_norm and normalizer_fn_g != ly.layer_norm: 178 | normalizer_params_g['labels'] = labels 179 | normalizer_params_g['n_labels'] = num_classes 180 | 181 | with tf.variable_scope(scope_name) as scope: 182 | if reuse: 183 | scope.reuse_variables() 184 | 185 | z_encoded = image_encoder_mru(z, num_classes=num_classes, reuse=reuse, data_format=data_format, 186 | labels=labels, scope_name=scope_name) 187 | 188 | input_e_dims = z_encoded[-1].get_shape().as_list() 189 | batch_size = input_e_dims[0] 190 | channel_depth = int(input_e_dims[concat_axis] / 8.) 191 | if data_format == 'NCHW': 192 | noise_dims = [batch_size, channel_depth, int(input_e_dims[2] * 2), int(input_e_dims[3] * 2)] 193 | else: 194 | noise_dims = [batch_size, int(input_e_dims[1] * 2), int(input_e_dims[2] * 2), channel_depth] 195 | 196 | noise_vec = tf.random_normal(shape=(batch_size, 256), dtype=tf.float32) 197 | noise = fully_connected(noise_vec, int(np.prod(noise_dims[1:])), sn=sn, 198 | activation_fn=activation_fn_g, 199 | normalizer_fn=normalizer_fn_g, 200 | normalizer_params=normalizer_params_g) 201 | noise = tf.reshape(noise, shape=noise_dims) 202 | 203 | # Initial memory state 204 | hidden_state_shape = z_encoded[-1].get_shape().as_list() 205 | hidden_state_shape[0] = 1 206 | hts_0 = [z_encoded[-1]] 207 | 208 | input_0 = tf.concat([resized_z[1], noise], axis=concat_axis) 209 | hts_1 = mru_deconv(input_0, hts_0, 210 | size * 6, sn=sn, stride=2, data_format=data_format, 211 | num_blocks=num_blocks, 212 | last_unit=False, 213 | activation_fn=activation_fn_g, 214 | normalizer_fn=normalizer_fn_g, 215 | normalizer_params=normalizer_params_g, 216 | weights_initializer=weight_initializer, 217 | unit_num=0) 218 | input_1 = tf.concat([resized_z[2], z_encoded[-3]], axis=concat_axis) 219 | hts_2 = mru_deconv(input_1, hts_1, 220 | size * 4, sn=sn, stride=2, data_format=data_format, 221 | num_blocks=num_blocks, 222 | last_unit=False, 223 | activation_fn=activation_fn_g, 224 | normalizer_fn=normalizer_fn_g, 225 | normalizer_params=normalizer_params_g, 226 | weights_initializer=weight_initializer, 227 | unit_num=2) 228 | input_2 = tf.concat([resized_z[3], z_encoded[-4]], axis=concat_axis) 229 | hts_3 = mru_deconv(input_2, hts_2, 230 | size * 2, sn=sn, stride=2, data_format=data_format, 231 | num_blocks=num_blocks, 232 | last_unit=False, 233 | activation_fn=activation_fn_g, 234 | normalizer_fn=normalizer_fn_g, 235 | normalizer_params=normalizer_params_g, 236 | weights_initializer=weight_initializer, 237 | unit_num=4) 238 | input_3 = tf.concat([resized_z[4], z_encoded[-5]], axis=concat_axis) 239 | hts_4 = mru_deconv(input_3, hts_3, 240 | size * 2, sn=sn, stride=2, data_format=data_format, 241 | num_blocks=num_blocks, 242 | last_unit=False, 243 | activation_fn=activation_fn_g, 244 | normalizer_fn=normalizer_fn_g, 245 | normalizer_params=normalizer_params_g, 246 | weights_initializer=weight_initializer, 247 | unit_num=6) 248 | hts_5 = mru_deconv(resized_z[5], hts_4, 249 | size * 1, sn=sn, stride=2, data_format=data_format, 250 | num_blocks=num_blocks, 251 | last_unit=True, 252 | activation_fn=activation_fn_g, 253 | normalizer_fn=normalizer_fn_g, 254 | normalizer_params=normalizer_params_g, 255 | weights_initializer=weight_initializer, 256 | unit_num=8) 257 | out = conv2d(hts_5[-1], 3, 7, sn=sn, stride=1, data_format=data_format, 258 | normalizer_fn=None, activation_fn=tf.nn.tanh, 259 | weights_initializer=weight_initializer) 260 | assert 
out.get_shape().as_list()[2] == 64 261 | return out, noise_vec 262 | 263 | 264 | # MRU 265 | def critic_multiple_proj(x, num_classes, labels=None, reuse=False, data_format='NCHW', scope_name=None): 266 | print("D") 267 | assert data_format == 'NCHW' 268 | size = SIZE 269 | num_blocks = NUM_BLOCKS 270 | resize_func = tf.image.resize_bilinear 271 | sn = Config.sn 272 | 273 | if data_format == 'NCHW': 274 | channel_axis = 1 275 | else: 276 | channel_axis = 3 277 | if type(x) is list: 278 | x = x[-1] 279 | 280 | if data_format == 'NCHW': 281 | x_list = [] 282 | resized_ = x 283 | x_list.append(resized_) 284 | 285 | for i in range(5): 286 | resized_ = mean_pool(resized_, data_format=data_format) 287 | x_list.append(resized_) 288 | x_list = x_list[::-1] 289 | else: 290 | raise NotImplementedError 291 | 292 | output_dim = 1 293 | 294 | with tf.variable_scope(scope_name) as scope: 295 | if reuse: 296 | scope.reuse_variables() 297 | 298 | h0 = conv2d(x_list[-1], 8, kernel_size=7, sn=sn, stride=1, data_format=data_format, 299 | activation_fn=activation_fn_d, 300 | normalizer_fn=normalizer_fn_d, 301 | normalizer_params=normalizer_params_d, 302 | weights_initializer=weight_initializer) 303 | 304 | # Initial memory state 305 | hidden_state_shape = h0.get_shape().as_list() 306 | batch_size = hidden_state_shape[0] 307 | hidden_state_shape[0] = 1 308 | hts_0 = [h0] 309 | for i in range(1, num_blocks): 310 | h0 = tf.tile(tf.get_variable("initial_hidden_state_%d" % i, shape=hidden_state_shape, dtype=tf.float32, 311 | initializer=tf.zeros_initializer()), [batch_size, 1, 1, 1]) 312 | hts_0.append(h0) 313 | 314 | hts_1 = mru_conv(x_list[-1], hts_0, 315 | size * 2, sn=sn, stride=2, dilate_rate=1, 316 | data_format=data_format, num_blocks=num_blocks, 317 | last_unit=False, 318 | activation_fn=activation_fn_d, 319 | normalizer_fn=normalizer_fn_d, 320 | normalizer_params=normalizer_params_d, 321 | weights_initializer=weight_initializer, 322 | unit_num=1) 323 | hts_2 = mru_conv(x_list[-2], hts_1, 324 | size * 4, sn=sn, stride=2, dilate_rate=1, 325 | data_format=data_format, num_blocks=num_blocks, 326 | last_unit=False, 327 | activation_fn=activation_fn_d, 328 | normalizer_fn=normalizer_fn_d, 329 | normalizer_params=normalizer_params_d, 330 | weights_initializer=weight_initializer, 331 | unit_num=2) 332 | hts_3 = mru_conv(x_list[-3], hts_2, 333 | size * 8, sn=sn, stride=2, dilate_rate=1, 334 | data_format=data_format, num_blocks=num_blocks, 335 | last_unit=False, 336 | activation_fn=activation_fn_d, 337 | normalizer_fn=normalizer_fn_d, 338 | normalizer_params=normalizer_params_d, 339 | weights_initializer=weight_initializer, 340 | unit_num=3) 341 | hts_4 = mru_conv(x_list[-4], hts_3, 342 | size * 12, sn=sn, stride=2, dilate_rate=1, 343 | data_format=data_format, num_blocks=num_blocks, 344 | last_unit=True, 345 | activation_fn=activation_fn_d, 346 | normalizer_fn=normalizer_fn_d, 347 | normalizer_params=normalizer_params_d, 348 | weights_initializer=weight_initializer, 349 | unit_num=4) 350 | 351 | img = hts_4[-1] 352 | img_shape = img.get_shape().as_list() 353 | 354 | # discriminator end 355 | disc = conv2d(img, output_dim, kernel_size=1, sn=sn, stride=1, data_format=data_format, 356 | activation_fn=None, normalizer_fn=None, 357 | weights_initializer=weight_initializer) 358 | 359 | if Config.proj_d: 360 | # Projection discriminator 361 | assert labels is not None and (len(labels.get_shape()) == 1 or labels.get_shape().as_list()[-1] == 1) 362 | 363 | class_embeddings = embed_labels(labels, num_classes, 
img_shape[channel_axis], sn=sn) 364 | class_embeddings = tf.reshape(class_embeddings, (img_shape[0], img_shape[channel_axis], 1, 1)) # NCHW 365 | 366 | disc += tf.reduce_sum(img * class_embeddings, axis=1, keep_dims=True) 367 | 368 | logits = None 369 | else: 370 | # classification end 371 | img = tf.reduce_mean(img, axis=(2, 3) if data_format == 'NCHW' else (1, 2)) 372 | logits = fully_connected(img, num_classes, sn=sn, activation_fn=None, normalizer_fn=None) 373 | 374 | return disc, logits 375 | 376 | 377 | weight_initializer = tf.random_normal_initializer(0, 0.02) 378 | # weight_initializer = ly.xavier_initializer_conv2d() 379 | 380 | 381 | def set_param(data_format='NCHW'): 382 | global model_data_format, normalizer_fn_e, normalizer_fn_g, normalizer_fn_d, \ 383 | normalizer_params_e, normalizer_params_g, normalizer_params_d 384 | model_data_format = data_format 385 | normalizer_fn_e = batchnorm 386 | normalizer_params_e = {'data_format': model_data_format} 387 | normalizer_fn_g = batchnorm 388 | normalizer_params_g = {'data_format': model_data_format} 389 | normalizer_fn_d = None 390 | normalizer_params_d = None 391 | 392 | 393 | model_data_format = None 394 | 395 | normalizer_fn_e = None 396 | normalizer_params_e = None 397 | normalizer_fn_g = None 398 | normalizer_params_g = None 399 | normalizer_fn_d = None 400 | normalizer_params_d = None 401 | 402 | activation_fn_e = miu_relu 403 | activation_fn_g = miu_relu 404 | activation_fn_d = prelu 405 | 406 | generator = generator_skip 407 | critic = critic_multiple_proj 408 | -------------------------------------------------------------------------------- /src_single/sn.py: -------------------------------------------------------------------------------- 1 | # Code by minhnhat93 2 | import tensorflow as tf 3 | import warnings 4 | 5 | NO_OPS = 'NO_OPS' 6 | 7 | 8 | def _l2normalize(v, eps=1e-12): 9 | return v / (tf.reduce_sum(v ** 2) ** 0.5 + eps) 10 | 11 | 12 | def spectral_normed_weight(W, u=None, num_iters=1, update_collection=None, with_sigma=False): 13 | # Usually num_iters = 1 will be enough 14 | W_shape = W.shape.as_list() 15 | W_reshaped = tf.reshape(W, [-1, W_shape[-1]]) 16 | if u is None: 17 | with tf.variable_scope(W.name.rsplit('/', 1)[0]) as sc: 18 | u = tf.get_variable("u", [1, W_shape[-1]], initializer=tf.truncated_normal_initializer(), trainable=False) 19 | 20 | def power_iteration(i, u_i, v_i): 21 | v_ip1 = _l2normalize(tf.matmul(u_i, tf.transpose(W_reshaped))) 22 | u_ip1 = _l2normalize(tf.matmul(v_ip1, W_reshaped)) 23 | return i + 1, u_ip1, v_ip1 24 | 25 | _, u_final, v_final = tf.while_loop( 26 | cond=lambda i, _1, _2: i < num_iters, 27 | body=power_iteration, 28 | loop_vars=(tf.constant(0, dtype=tf.int32), 29 | u, tf.zeros(dtype=tf.float32, shape=[1, W_reshaped.shape.as_list()[0]])) 30 | ) 31 | if update_collection is None: 32 | warnings.warn( 33 | 'Setting update_collection to None will make u being updated every W execution. This maybe undesirable' 34 | '. 
Please consider using a update collection instead.') 35 | sigma = tf.matmul(tf.matmul(v_final, W_reshaped), tf.transpose(u_final))[0, 0] 36 | # sigma = tf.reduce_sum(tf.matmul(u_final, tf.transpose(W_reshaped)) * v_final) 37 | W_bar = W_reshaped / sigma 38 | with tf.control_dependencies([u.assign(u_final)]): 39 | W_bar = tf.reshape(W_bar, W_shape) 40 | else: 41 | sigma = tf.matmul(tf.matmul(v_final, W_reshaped), tf.transpose(u_final))[0, 0] 42 | # sigma = tf.reduce_sum(tf.matmul(u_final, tf.transpose(W_reshaped)) * v_final) 43 | W_bar = W_reshaped / sigma 44 | W_bar = tf.reshape(W_bar, W_shape) 45 | # Put NO_OPS to not update any collection. This is useful for the second call of discriminator if the update_op 46 | # has already been collected on the first call. 47 | if update_collection != NO_OPS: 48 | tf.add_to_collection(update_collection, u.assign(u_final)) 49 | if with_sigma: 50 | return W_bar, sigma 51 | else: 52 | return W_bar 53 | -------------------------------------------------------------------------------- /src_single/train_single.py: -------------------------------------------------------------------------------- 1 | import os 2 | from time import time 3 | import pickle 4 | 5 | import cv2 6 | import numpy as np 7 | import tensorflow as tf 8 | from tensorflow.python.client import timeline 9 | 10 | from graph_single import build_multi_tower_graph, build_single_graph 11 | from input_pipeline import build_input_queue_paired_sketchy, build_input_queue_paired_sketchy_test, build_input_queue_paired_flickr, build_input_queue_paired_mixed 12 | import inception_score 13 | from config import Config 14 | 15 | tf.logging.set_verbosity(tf.logging.INFO) 16 | inception_v4_ckpt_path = './inception_v4_model/inception_v4.ckpt' 17 | vgg_16_ckpt_path = './vgg_16_model/vgg_16.ckpt' 18 | 19 | 20 | def print_parameter_count(verbose=False): 21 | total_parameters = 0 22 | for variable in tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='generator'): 23 | # shape is an array of tf.Dimension 24 | shape = variable.get_shape() 25 | # print(len(shape)) 26 | variable_parametes = 1 27 | for dim in shape: 28 | # print(dim) 29 | variable_parametes *= dim.value 30 | if verbose and len(shape) > 1: 31 | print(shape) 32 | print(variable_parametes) 33 | total_parameters += variable_parametes 34 | print('generator') 35 | print(total_parameters) 36 | 37 | total_parameters = 0 38 | for variable in tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='discriminator'): 39 | # shape is an array of tf.Dimension 40 | shape = variable.get_shape() 41 | # print(len(shape)) 42 | variable_parametes = 1 43 | for dim in shape: 44 | # print(dim) 45 | variable_parametes *= dim.value 46 | if verbose and len(shape) > 1: 47 | print(shape) 48 | print(variable_parametes) 49 | total_parameters += variable_parametes 50 | print('critic') 51 | print(total_parameters) 52 | 53 | 54 | def train(**kwargs): 55 | 56 | def get_inception_score_origin(generator_out, data_format, session, n): 57 | all_samples = [] 58 | img_dim = 64 59 | for i in range(n // 100): 60 | all_samples.append(session.run(generator_out)) 61 | all_samples = np.concatenate(all_samples, axis=0) 62 | all_samples = ((all_samples + 1.) * (255. 
/ 2)).astype('int32') 63 | all_samples = all_samples.reshape((-1, 3, img_dim, img_dim)) 64 | if data_format == 'NCHW': 65 | all_samples = all_samples.transpose(0, 2, 3, 1) 66 | return inception_score.get_inception_score(list(all_samples), session) 67 | 68 | status = 0 69 | 70 | # Roll out the parameters 71 | appendix = Config.resume_from 72 | batch_size = Config.batch_size 73 | max_iter_step = Config.max_iter_step 74 | Diters = Config.disc_iterations 75 | ld = Config.ld 76 | optimizer = Config.optimizer 77 | lr_G = Config.lr_G 78 | lr_D = Config.lr_D 79 | num_gpu = Config.num_gpu 80 | log_dir = Config.log_dir 81 | ckpt_dir = Config.ckpt_dir 82 | data_format = Config.data_format 83 | distance_map = Config.distance_map 84 | small_img = Config.small_img 85 | 86 | distance_map = distance_map != 0 87 | small = small_img != 0 88 | batch_portion = np.array([1, 1, 1, 1], dtype=np.int32) 89 | 90 | iter_from = kwargs['iter_from'] 91 | 92 | # Time counter 93 | prev_time = float("-inf") 94 | curr_time = float("-inf") 95 | 96 | tf.reset_default_graph() 97 | print('Iteration starts from: %d' % iter_from) 98 | 99 | assert inception_score.softmax.graph != tf.get_default_graph() 100 | inception_score._init_inception() 101 | 102 | counter = tf.Variable(initial_value=iter_from, dtype=tf.int32, trainable=False) 103 | counter_addition_op = tf.assign_add(counter, 1, use_locking=True) 104 | 105 | # proportion = tf.round(tf.cast(counter, tf.float32) / max_iter_step) 106 | proportion = 0.2 + tf.minimum(0.6, tf.cast(counter, tf.float32) / max_iter_step * 0.6) 107 | 108 | # Construct data queue 109 | with tf.device('/cpu:0'): 110 | images, sketches, image_paired_class_ids = build_input_queue_paired_mixed( 111 | batch_size=batch_size * num_gpu, 112 | proportion=proportion, 113 | data_format=data_format, 114 | distance_map=distance_map, 115 | small=small, capacity=2 ** 11) 116 | with tf.device('/cpu:0'): 117 | images_d, _, image_paired_class_ids_d = build_input_queue_paired_mixed( 118 | batch_size=batch_size * num_gpu, 119 | proportion=tf.constant(0.1, dtype=tf.float32), 120 | data_format=data_format, 121 | distance_map=distance_map, 122 | small=small, capacity=2 ** 11) 123 | with tf.device('/cpu:0'): 124 | _, sketches_100, image_paired_class_ids_100 = build_input_queue_paired_sketchy( 125 | batch_size=100, 126 | data_format=data_format, 127 | distance_map=distance_map, 128 | small=small, capacity=1024) 129 | 130 | opt_g, opt_d, loss_g, loss_d, merged_all, gen_out = build_multi_tower_graph( 131 | images, sketches, images_d, 132 | sketches_100, 133 | image_paired_class_ids, image_paired_class_ids_d, image_paired_class_ids_100, 134 | batch_size=batch_size, num_gpu=num_gpu, batch_portion=batch_portion, training=True, 135 | learning_rates={ 136 | "generator": lr_G, 137 | "discriminator": lr_D, 138 | }, 139 | counter=counter, proportion=proportion, max_iter_step=max_iter_step, 140 | ld=ld, data_format=data_format, 141 | distance_map=distance_map, 142 | optimizer=optimizer) 143 | 144 | inception_score_mean = tf.placeholder(dtype=tf.float32, shape=()) 145 | inception_score_std = tf.placeholder(dtype=tf.float32, shape=()) 146 | inception_score_mean_summary = tf.summary.scalar("inception_score/mean", inception_score_mean) 147 | inception_score_std_summary = tf.summary.scalar("inception_score/std", inception_score_std) 148 | inception_score_summary = tf.summary.merge((inception_score_mean_summary, inception_score_std_summary)) 149 | 150 | saver = tf.train.Saver() 151 | try: 152 | saver2 = 
tf.train.Saver(tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='InceptionV4')) 153 | perceptual_model_path = inception_v4_ckpt_path 154 | except: 155 | try: 156 | saver2 = tf.train.Saver(tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='vgg_16')) 157 | perceptual_model_path = vgg_16_ckpt_path 158 | except: 159 | saver2 = None 160 | 161 | config = tf.ConfigProto(allow_soft_placement=True, log_device_placement=False, 162 | intra_op_parallelism_threads=4, inter_op_parallelism_threads=4, 163 | # device_count={"CPU": 8}, 164 | ) 165 | # config.graph_options.optimizer_options.global_jit_level = tf.OptimizerOptions.ON_1 # JIT XLA 166 | config.gpu_options.allow_growth = True 167 | # config.gpu_options.per_process_gpu_memory_fraction = 0.9 168 | 169 | with tf.Session(config=config) as sess: 170 | sess.run(tf.global_variables_initializer()) 171 | sess.run(tf.local_variables_initializer()) 172 | 173 | if saver2 is not None: 174 | saver2.restore(sess, perceptual_model_path) 175 | 176 | summary_writer = tf.summary.FileWriter(log_dir, sess.graph) 177 | if iter_from > 0: 178 | saver.restore(sess, tf.train.latest_checkpoint(ckpt_dir)) 179 | summary_writer.reopen() 180 | 181 | run_options = tf.RunOptions(trace_level=tf.RunOptions.NO_TRACE) 182 | run_metadata = tf.RunMetadata() 183 | 184 | print_parameter_count(verbose=False) 185 | 186 | coord = tf.train.Coordinator() 187 | threads = tf.train.start_queue_runners(sess=sess, coord=coord) 188 | 189 | sess.run([counter.assign(iter_from)]) 190 | 191 | for i in range(iter_from, max_iter_step): 192 | if status == -1: 193 | break 194 | 195 | if i % 100 == 0: 196 | curr_time = time() 197 | elapsed = curr_time - prev_time 198 | print( 199 | "Now at iteration %d. Elapsed time: %.5fs. Average time: %.5fs/iter" % (i, elapsed, elapsed / 100.)) 200 | prev_time = curr_time 201 | 202 | diters = Diters 203 | 204 | # Train Discriminator 205 | for j in range(diters): 206 | # print(j) 207 | if i % 100 == 0 and j == 0: 208 | _, merged, loss_d_out = sess.run([opt_d, merged_all, loss_d], 209 | options=run_options, 210 | run_metadata=run_metadata) 211 | summary_writer.add_summary(merged, i) 212 | summary_writer.add_run_metadata( 213 | run_metadata, 'discriminator_metadata {}'.format(i), i) 214 | else: 215 | _, loss_d_out = sess.run([opt_d, loss_d]) 216 | if np.isnan(np.sum(loss_d_out)): 217 | status = -1 218 | print("NaN occurred during training D") 219 | return status 220 | 221 | # Train Generator 222 | if i % 100 == 0: 223 | _, merged, loss_g_out, counter_out, _ = sess.run( 224 | [opt_g, merged_all, loss_g, counter, counter_addition_op], 225 | options=run_options, 226 | run_metadata=run_metadata) 227 | summary_writer.add_summary(merged, i) 228 | summary_writer.add_run_metadata( 229 | run_metadata, 'generator_metadata {}'.format(i), i) 230 | else: 231 | _, loss_g_out, counter_out, _ = sess.run([opt_g, loss_g, counter, counter_addition_op]) 232 | if np.isnan(np.sum(loss_g_out)): 233 | status = -1 234 | print("NaN occurred during training G") 235 | return status 236 | 237 | if i % 5000 == 4999: 238 | saver.save(sess, os.path.join( 239 | ckpt_dir, "model.ckpt"), global_step=i) 240 | 241 | if i % 1000 == 999: 242 | this_score = get_inception_score_origin(gen_out, data_format=data_format, 243 | session=sess, n=10000) 244 | merged_sum = sess.run(inception_score_summary, feed_dict={ 245 | inception_score_mean: this_score[0], 246 | inception_score_std: this_score[1], 247 | }) 248 | summary_writer.add_summary(merged_sum, i) 249 | 250 | coord.request_stop() 251 | 
coord.join(threads) 252 | 253 | return status 254 | 255 | 256 | def test(**kwargs): 257 | 258 | def binarize(sketch, threshold=245): 259 | sketch[sketch < threshold] = 0 260 | sketch[sketch >= threshold] = 255 261 | return sketch 262 | 263 | # Roll out the parameters 264 | appendix = Config.resume_from 265 | batch_size = Config.batch_size 266 | log_dir = Config.log_dir 267 | ckpt_dir = Config.ckpt_dir 268 | data_format = Config.data_format 269 | distance_map = Config.distance_map 270 | small_img = Config.small_img 271 | 272 | build_func = build_single_graph 273 | channel = 3 274 | distance_map = distance_map != 0 275 | small = small_img != 0 276 | if small: 277 | img_dim = 64 278 | else: 279 | img_dim = 256 280 | 281 | output_folder = os.path.join(log_dir, 'out') 282 | print(output_folder) 283 | if not os.path.exists(output_folder): 284 | os.mkdir(output_folder) 285 | 286 | # Time counter 287 | prev_time = float("-inf") 288 | curr_time = float("-inf") 289 | # Construct data queue 290 | with tf.device('/cpu:0'): 291 | images, sketches, class_ids, categories, imagenet_ids, sketch_ids = build_input_queue_paired_sketchy_test( 292 | batch_size=batch_size, data_format=data_format, 293 | distance_map=distance_map, small=small, capacity=512) 294 | 295 | with tf.device('/gpu:0'): 296 | ret_list = build_func(images, sketches, None, None, 297 | class_ids, None, None, 298 | batch_size=batch_size, training=False, 299 | data_format=data_format, 300 | distance_map=distance_map) 301 | 302 | saver = tf.train.Saver() 303 | 304 | with tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) as sess: 305 | sess.run(tf.global_variables_initializer()) 306 | sess.run(tf.local_variables_initializer()) 307 | 308 | saver.restore(sess, tf.train.latest_checkpoint(ckpt_dir)) 309 | counter = 0 310 | 311 | coord = tf.train.Coordinator() 312 | threads = tf.train.start_queue_runners(sess=sess, coord=coord) 313 | 314 | while True: 315 | try: 316 | generated_img, gt_image, input_sketch, category, imagenet_id, sketch_id = sess.run( 317 | [ret_list[0], ret_list[1], ret_list[2], categories, imagenet_ids, sketch_ids]) 318 | except Exception as e: 319 | print(e.args) 320 | break 321 | 322 | if counter % 100 == 0: 323 | curr_time = time() 324 | elapsed = curr_time - prev_time 325 | print( 326 | "Now at iteration %d. Elapsed time: %.5fs." % (counter, elapsed)) 327 | prev_time = curr_time 328 | 329 | if data_format == 'NCHW': 330 | generated_img = np.transpose(generated_img, (0, 2, 3, 1)) 331 | gt_image = np.transpose(gt_image, (0, 2, 3, 1)) 332 | input_sketch = np.transpose(input_sketch, (0, 2, 3, 1)) 333 | generated_img = ((generated_img + 1) / 2.) * 255 334 | gt_image = ((gt_image + 1) / 2.) * 255 335 | input_sketch = ((input_sketch + 1) / 2.) 
* 255 336 | generated_img = generated_img[:, :, :, ::-1].astype(np.uint8) 337 | gt_image = gt_image[:, :, :, ::-1].astype(np.uint8) 338 | input_sketch = input_sketch.astype(np.uint8) 339 | 340 | for i in range(batch_size): 341 | this_prefix = '%s_%d_%d' % (category[i].decode('ascii'), 342 | int(imagenet_id[i].decode('ascii').split('_')[1]), 343 | sketch_id[i]) 344 | img_out_filename = this_prefix + '_fake_B.png' 345 | img_gt_filename = this_prefix + '_real_B.png' 346 | sketch_in_filename = this_prefix + '_real_A.png' 347 | 348 | # Save file 349 | # file_path = os.path.join(output_folder, 'output_%d.jpg' % int(counter / batch_size)) 350 | cv2.imwrite(os.path.join(output_folder, img_out_filename), generated_img[i]) 351 | cv2.imwrite(os.path.join(output_folder, img_gt_filename), gt_image[i]) 352 | cv2.imwrite(os.path.join(output_folder, sketch_in_filename), input_sketch[i]) 353 | # output_img = np.zeros((img_dim * 2, img_dim * batch_size, channel)) 354 | 355 | print('Saved file %s' % this_prefix) 356 | 357 | counter += 1 358 | 359 | coord.request_stop() 360 | coord.join(threads) -------------------------------------------------------------------------------- /src_single/vgg.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Contains model definitions for versions of the Oxford VGG network. 16 | 17 | These model definitions were introduced in the following technical report: 18 | 19 | Very Deep Convolutional Networks For Large-Scale Image Recognition 20 | Karen Simonyan and Andrew Zisserman 21 | arXiv technical report, 2015 22 | PDF: http://arxiv.org/pdf/1409.1556.pdf 23 | ILSVRC 2014 Slides: http://www.robots.ox.ac.uk/~karen/pdf/ILSVRC_2014.pdf 24 | CC-BY-4.0 25 | 26 | More information can be obtained from the VGG website: 27 | www.robots.ox.ac.uk/~vgg/research/very_deep/ 28 | 29 | Usage: 30 | with slim.arg_scope(vgg.vgg_arg_scope()): 31 | outputs, end_points = vgg.vgg_a(inputs) 32 | 33 | with slim.arg_scope(vgg.vgg_arg_scope()): 34 | outputs, end_points = vgg.vgg_16(inputs) 35 | 36 | @@vgg_a 37 | @@vgg_16 38 | @@vgg_19 39 | """ 40 | from __future__ import absolute_import 41 | from __future__ import division 42 | from __future__ import print_function 43 | 44 | import tensorflow as tf 45 | 46 | slim = tf.contrib.slim 47 | 48 | 49 | def vgg_arg_scope(weight_decay=0.0005): 50 | """Defines the VGG arg scope. 51 | 52 | Args: 53 | weight_decay: The l2 regularization coefficient. 54 | 55 | Returns: 56 | An arg_scope. 
57 | """ 58 | with slim.arg_scope([slim.conv2d, slim.fully_connected], 59 | activation_fn=tf.nn.relu, 60 | weights_regularizer=slim.l2_regularizer(weight_decay), 61 | biases_initializer=tf.zeros_initializer()): 62 | with slim.arg_scope([slim.conv2d], padding='SAME') as arg_sc: 63 | return arg_sc 64 | 65 | 66 | def vgg_a(inputs, 67 | num_classes=1000, 68 | is_training=True, 69 | dropout_keep_prob=0.5, 70 | spatial_squeeze=True, 71 | scope='vgg_a', 72 | fc_conv_padding='VALID'): 73 | """Oxford Net VGG 11-Layers version A Example. 74 | 75 | Note: All the fully_connected layers have been transformed to conv2d layers. 76 | To use in classification mode, resize input to 224x224. 77 | 78 | Args: 79 | inputs: a tensor of size [batch_size, height, width, channels]. 80 | num_classes: number of predicted classes. 81 | is_training: whether or not the model is being trained. 82 | dropout_keep_prob: the probability that activations are kept in the dropout 83 | layers during training. 84 | spatial_squeeze: whether or not should squeeze the spatial dimensions of the 85 | outputs. Useful to remove unnecessary dimensions for classification. 86 | scope: Optional scope for the variables. 87 | fc_conv_padding: the type of padding to use for the fully connected layer 88 | that is implemented as a convolutional layer. Use 'SAME' padding if you 89 | are applying the network in a fully convolutional manner and want to 90 | get a prediction map downsampled by a factor of 32 as an output. 91 | Otherwise, the output prediction map will be (input / 32) - 6 in case of 92 | 'VALID' padding. 93 | 94 | Returns: 95 | the last op containing the log predictions and end_points dict. 96 | """ 97 | with tf.variable_scope(scope, 'vgg_a', [inputs]) as sc: 98 | end_points_collection = sc.name + '_end_points' 99 | # Collect outputs for conv2d, fully_connected and max_pool2d. 100 | with slim.arg_scope([slim.conv2d, slim.max_pool2d], 101 | outputs_collections=end_points_collection): 102 | net = slim.repeat(inputs, 1, slim.conv2d, 64, [3, 3], scope='conv1') 103 | net = slim.max_pool2d(net, [2, 2], scope='pool1') 104 | net = slim.repeat(net, 1, slim.conv2d, 128, [3, 3], scope='conv2') 105 | net = slim.max_pool2d(net, [2, 2], scope='pool2') 106 | net = slim.repeat(net, 2, slim.conv2d, 256, [3, 3], scope='conv3') 107 | net = slim.max_pool2d(net, [2, 2], scope='pool3') 108 | net = slim.repeat(net, 2, slim.conv2d, 512, [3, 3], scope='conv4') 109 | net = slim.max_pool2d(net, [2, 2], scope='pool4') 110 | net = slim.repeat(net, 2, slim.conv2d, 512, [3, 3], scope='conv5') 111 | net = slim.max_pool2d(net, [2, 2], scope='pool5') 112 | # Use conv2d instead of fully_connected layers. 113 | net = slim.conv2d(net, 4096, [7, 7], padding=fc_conv_padding, scope='fc6') 114 | net = slim.dropout(net, dropout_keep_prob, is_training=is_training, 115 | scope='dropout6') 116 | net = slim.conv2d(net, 4096, [1, 1], scope='fc7') 117 | net = slim.dropout(net, dropout_keep_prob, is_training=is_training, 118 | scope='dropout7') 119 | net = slim.conv2d(net, num_classes, [1, 1], 120 | activation_fn=None, 121 | normalizer_fn=None, 122 | scope='fc8') 123 | # Convert end_points_collection into a end_point dict. 
124 | end_points = slim.utils.convert_collection_to_dict(end_points_collection) 125 | if spatial_squeeze: 126 | net = tf.squeeze(net, [1, 2], name='fc8/squeezed') 127 | end_points[sc.name + '/fc8'] = net 128 | return net, end_points 129 | vgg_a.default_image_size = 224 130 | 131 | 132 | def vgg_16(inputs, 133 | num_classes=1000, 134 | is_training=True, 135 | dropout_keep_prob=0.5, 136 | spatial_squeeze=True, 137 | reuse=False, 138 | num=0, 139 | scope='vgg_16', 140 | fc_conv_padding='VALID'): 141 | """Oxford Net VGG 16-Layers version D Example. 142 | 143 | Note: All the fully_connected layers have been transformed to conv2d layers. 144 | To use in classification mode, resize input to 224x224. 145 | 146 | Args: 147 | inputs: a tensor of size [batch_size, height, width, channels]. 148 | num_classes: number of predicted classes. 149 | is_training: whether or not the model is being trained. 150 | dropout_keep_prob: the probability that activations are kept in the dropout 151 | layers during training. 152 | spatial_squeeze: whether or not should squeeze the spatial dimensions of the 153 | outputs. Useful to remove unnecessary dimensions for classification. 154 | scope: Optional scope for the variables. 155 | fc_conv_padding: the type of padding to use for the fully connected layer 156 | that is implemented as a convolutional layer. Use 'SAME' padding if you 157 | are applying the network in a fully convolutional manner and want to 158 | get a prediction map downsampled by a factor of 32 as an output. 159 | Otherwise, the output prediction map will be (input / 32) - 6 in case of 160 | 'VALID' padding. 161 | 162 | Returns: 163 | the last op containing the log predictions and end_points dict. 164 | """ 165 | my_end_points = {} 166 | with tf.variable_scope(scope, 'vgg_16', [inputs], reuse=reuse) as sc: 167 | end_points_collection = sc.name + ('_end_points_%d' % num) 168 | # Collect outputs for conv2d, fully_connected and max_pool2d. 169 | with slim.arg_scope([slim.conv2d, slim.fully_connected, slim.max_pool2d], 170 | outputs_collections=end_points_collection): 171 | net = slim.repeat(inputs, 2, slim.conv2d, 64, [3, 3], scope='conv1') 172 | my_end_points[net.aliases[0]] = net 173 | net = slim.max_pool2d(net, [2, 2], scope='pool1') 174 | net = slim.repeat(net, 2, slim.conv2d, 128, [3, 3], scope='conv2') 175 | my_end_points[net.aliases[0]] = net 176 | net = slim.max_pool2d(net, [2, 2], scope='pool2') 177 | net = slim.repeat(net, 3, slim.conv2d, 256, [3, 3], scope='conv3') 178 | my_end_points[net.aliases[0]] = net 179 | net = slim.max_pool2d(net, [2, 2], scope='pool3') 180 | net = slim.repeat(net, 3, slim.conv2d, 512, [3, 3], scope='conv4') 181 | my_end_points[net.aliases[0]] = net 182 | net = slim.max_pool2d(net, [2, 2], scope='pool4') 183 | net = slim.repeat(net, 3, slim.conv2d, 512, [3, 3], scope='conv5') 184 | my_end_points[net.aliases[0]] = net 185 | net = slim.max_pool2d(net, [2, 2], scope='pool5') 186 | # Use conv2d instead of fully_connected layers. 187 | net = slim.conv2d(net, 4096, [7, 7], padding=fc_conv_padding, scope='fc6') 188 | net = slim.dropout(net, dropout_keep_prob, is_training=is_training, 189 | scope='dropout6') 190 | net = slim.conv2d(net, 4096, [1, 1], scope='fc7') 191 | net = slim.dropout(net, dropout_keep_prob, is_training=is_training, 192 | scope='dropout7') 193 | net = slim.conv2d(net, num_classes, [1, 1], 194 | activation_fn=None, 195 | normalizer_fn=None, 196 | scope='fc8') 197 | # Convert end_points_collection into a end_point dict. 
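      # Note: unlike the upstream slim release, this vgg_16 additionally returns
      # my_end_points, the conv1-conv5 block activations collected above;
      # train_single.py restores the 'vgg_16' (or 'InceptionV4') scope from a
      # pre-trained checkpoint and uses it as the perceptual model.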
198 | end_points = slim.utils.convert_collection_to_dict(end_points_collection) 199 | if spatial_squeeze: 200 | net = tf.squeeze(net, [1, 2], name='fc8/squeezed') 201 | end_points[sc.name + '/fc8'] = net 202 | return net, end_points, my_end_points 203 | vgg_16.default_image_size = 224 204 | 205 | 206 | def vgg_19(inputs, 207 | num_classes=1000, 208 | is_training=True, 209 | dropout_keep_prob=0.5, 210 | spatial_squeeze=True, 211 | reuse=False, 212 | num=0, 213 | scope='vgg_19', 214 | fc_conv_padding='VALID'): 215 | """Oxford Net VGG 19-Layers version E Example. 216 | 217 | Note: All the fully_connected layers have been transformed to conv2d layers. 218 | To use in classification mode, resize input to 224x224. 219 | 220 | Args: 221 | inputs: a tensor of size [batch_size, height, width, channels]. 222 | num_classes: number of predicted classes. 223 | is_training: whether or not the model is being trained. 224 | dropout_keep_prob: the probability that activations are kept in the dropout 225 | layers during training. 226 | spatial_squeeze: whether or not should squeeze the spatial dimensions of the 227 | outputs. Useful to remove unnecessary dimensions for classification. 228 | scope: Optional scope for the variables. 229 | fc_conv_padding: the type of padding to use for the fully connected layer 230 | that is implemented as a convolutional layer. Use 'SAME' padding if you 231 | are applying the network in a fully convolutional manner and want to 232 | get a prediction map downsampled by a factor of 32 as an output. 233 | Otherwise, the output prediction map will be (input / 32) - 6 in case of 234 | 'VALID' padding. 235 | 236 | 237 | Returns: 238 | the last op containing the log predictions and end_points dict. 239 | """ 240 | with tf.variable_scope(scope, 'vgg_19', [inputs], reuse=reuse) as sc: 241 | end_points_collection = sc.name + ('_end_points_%d' % num) 242 | # Collect outputs for conv2d, fully_connected and max_pool2d. 243 | with slim.arg_scope([slim.conv2d, slim.fully_connected, slim.max_pool2d], 244 | outputs_collections=end_points_collection): 245 | net = slim.repeat(inputs, 2, slim.conv2d, 64, [3, 3], scope='conv1') 246 | net = slim.max_pool2d(net, [2, 2], scope='pool1') 247 | net = slim.repeat(net, 2, slim.conv2d, 128, [3, 3], scope='conv2') 248 | net = slim.max_pool2d(net, [2, 2], scope='pool2') 249 | net = slim.repeat(net, 4, slim.conv2d, 256, [3, 3], scope='conv3') 250 | net = slim.max_pool2d(net, [2, 2], scope='pool3') 251 | net = slim.repeat(net, 4, slim.conv2d, 512, [3, 3], scope='conv4') 252 | net = slim.max_pool2d(net, [2, 2], scope='pool4') 253 | net = slim.repeat(net, 4, slim.conv2d, 512, [3, 3], scope='conv5') 254 | net = slim.max_pool2d(net, [2, 2], scope='pool5') 255 | # Use conv2d instead of fully_connected layers. 256 | net = slim.conv2d(net, 4096, [7, 7], padding=fc_conv_padding, scope='fc6') 257 | net = slim.dropout(net, dropout_keep_prob, is_training=is_training, 258 | scope='dropout6') 259 | net = slim.conv2d(net, 4096, [1, 1], scope='fc7') 260 | net = slim.dropout(net, dropout_keep_prob, is_training=is_training, 261 | scope='dropout7') 262 | net = slim.conv2d(net, num_classes, [1, 1], 263 | activation_fn=None, 264 | normalizer_fn=None, 265 | scope='fc8') 266 | # Convert end_points_collection into a end_point dict. 
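      # Note: vgg_19 carries the same reuse/num additions as vgg_16 above, but it does
      # not collect my_end_points; only the standard end_points dict is returned.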
267 |       end_points = slim.utils.convert_collection_to_dict(end_points_collection)
268 |       if spatial_squeeze:
269 |         net = tf.squeeze(net, [1, 2], name='fc8/squeezed')
270 |         end_points[sc.name + '/fc8'] = net
271 |       return net, end_points
272 | vgg_19.default_image_size = 224
273 | 
274 | # Alias
275 | vgg_d = vgg_16
276 | vgg_e = vgg_19
277 | 
--------------------------------------------------------------------------------
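The two sketches below are illustrative additions, not part of the repository. The first shows one common way the `spectral_normed_weight` / `NO_OPS` pattern documented in `sn.py` is wired into a discriminator under TensorFlow 1.x; the `sn_conv2d` helper and the `SN_UPDATE_OPS` collection name are hypothetical, and the project's real wrappers live in `models_mru.py` / `mru.py` and differ in detail.

```python
# Illustrative only: a hypothetical spectrally-normalized conv wrapper around sn.py.
import tensorflow as tf
from sn import spectral_normed_weight, NO_OPS

SN_UPDATE_OPS = 'spectral_norm_update_ops'  # hypothetical collection name


def sn_conv2d(x, out_channels, name, update_collection=SN_UPDATE_OPS):
    """3x3 convolution whose kernel is spectrally normalized (NHWC input assumed)."""
    in_channels = x.get_shape().as_list()[-1]
    with tf.variable_scope(name):
        w = tf.get_variable('w', [3, 3, in_channels, out_channels],
                            initializer=tf.random_normal_initializer(0, 0.02))
        # Power iteration runs inside spectral_normed_weight; the u-update op is
        # appended to `update_collection` instead of running on every forward pass.
        w_bar = spectral_normed_weight(w, update_collection=update_collection)
        return tf.nn.conv2d(x, w_bar, strides=[1, 1, 1, 1], padding='SAME')


real_images = tf.placeholder(tf.float32, [None, 64, 64, 3])
fake_images = tf.placeholder(tf.float32, [None, 64, 64, 3])

# First discriminator pass: collect one u-update op per weight.
d_real = sn_conv2d(real_images, 64, 'critic_conv1')
# Second, variable-sharing pass: pass NO_OPS so the same op is not collected twice.
with tf.variable_scope(tf.get_variable_scope(), reuse=True):
    d_fake = sn_conv2d(fake_images, 64, 'critic_conv1', update_collection=NO_OPS)

# Run these together with the discriminator's train op.
sn_update_ops = tf.get_collection(SN_UPDATE_OPS)
```

Likewise, the modified `vgg_16` above returns `my_end_points` with the conv-block activations, and `train_single.py` restores the `vgg_16` (or `InceptionV4`) scope as a perceptual model; the actual feature comparison is implemented in `graph_single.py`, which is not shown here. The sketch below only illustrates how such activations could feed a feature-matching loss; the choice of the conv3/conv4 blocks and the L1 distance are assumptions.

```python
# Illustrative only: a feature-matching loss built from the modified vgg_16 above.
import tensorflow as tf
from vgg import vgg_16, vgg_arg_scope

slim = tf.contrib.slim


def vgg_feature_loss(real_images, fake_images):
    """L1 distance between VGG-16 conv3/conv4 activations of real and generated images.

    Inputs are assumed NHWC, 224x224, in the value range the checkpoint expects.
    """
    with slim.arg_scope(vgg_arg_scope()):
        # `num` gives each call its own end_points collection; `reuse` shares the weights.
        _, _, feats_real = vgg_16(real_images, is_training=False, num=0)
        _, _, feats_fake = vgg_16(fake_images, is_training=False, reuse=True, num=1)
    # my_end_points holds the conv1..conv5 block outputs in insertion order (Python 3.7+),
    # so indices 2:4 pick out the conv3 and conv4 blocks without relying on alias names.
    real_feats = list(feats_real.values())[2:4]
    fake_feats = list(feats_fake.values())[2:4]
    loss = 0.0
    for fr, ff in zip(real_feats, fake_feats):
        loss += tf.reduce_mean(tf.abs(fr - ff))
    return loss


real = tf.placeholder(tf.float32, [None, 224, 224, 3])
fake = tf.placeholder(tf.float32, [None, 224, 224, 3])
feat_loss = vgg_feature_loss(real, fake)

# The 'vgg_16' variables would then be restored the way train_single.py does, e.g.:
#   saver = tf.train.Saver(tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='vgg_16'))
#   saver.restore(sess, './vgg_16_model/vgg_16.ckpt')
```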