├── .gitignore
├── README.md
├── config.py
├── dataset_miscs
│   ├── build_tfrs.py
│   └── build_tfrs_places.py
├── exp_configs
│   ├── la_final.json
│   ├── la_plc_trans_final.json
│   └── la_trans_final.json
├── framework.py
├── model
│   ├── __init__.py
│   ├── alexnet_model.py
│   ├── cluster_km.py
│   ├── dataset_utils.py
│   ├── instance_model.py
│   ├── memory_bank.py
│   ├── prep_utils.py
│   ├── preprocessing.py
│   ├── resnet_model.py
│   └── vggnet_model.py
├── param_setter.py
├── train.py
├── train_tfutils.py
├── train_transfer.py
├── train_transfer_tfutils.py
├── trans_param_setter.py
└── utils.py

/.gitignore:
--------------------------------------------------------------------------------
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
env/
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
*.egg-info/
.installed.cfg
*.egg

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/
.idea

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# IPython Notebook
.ipynb_checkpoints

# pyenv
.python-version

# celery beat schedule file
celerybeat-schedule

# dotenv
.env

# virtualenv
venv/
ENV/

# Spyder project settings
.spyderproject

# Rope project settings
.ropeproject

data/coco
data/pretrained_models
tags
output

*.swp
*.gz

# emacs autosave files
*~
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Local Aggregation for Unsupervised Learning of Visual Embeddings
This repo implements the Local Aggregation (LA) algorithm on ImageNet, as well as the related transfer learning pipelines for both ImageNet and Places205.
A Pytorch implementation of this algorithm is available at [LocalAggregation-Pytorch](https://github.com/neuroailab/LocalAggregation-Pytorch).
This repo also includes a tensorflow implementation of the Instance Recognition (IR) task introduced in the paper "Unsupervised Feature Learning via Non-Parametric Instance Discrimination".

# Pretrained Model
A Local-Aggregation pretrained ResNet-18 model can be found at [link](http://visualmaster-models.s3.amazonaws.com/la_orig/checkpoint-1901710.tar). Note that this checkpoint is slightly earlier than the final one, so it may not be quite as good as a model fully trained with this repo.

# Instructions for training

## Prerequisites
We have tested this repo under Ubuntu 16.04 with tensorflow version 1.9.0.
Training the LA model additionally requires `faiss==1.6.1`.
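To quickly check that the environment is set up correctly, the following snippet should run without errors (a minimal sketch; it assumes the GPU build of faiss, which the LA clustering code uses):
```
import tensorflow as tf
import faiss

print(tf.__version__)        # we tested with 1.9.0
print(faiss.get_num_gpus())  # should be >= 1, since LA runs k-means on GPU through faiss
```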
## Data preparation
Prepare the ImageNet data in the raw JPEG format used in pytorch ImageNet training (see [link](https://github.com/pytorch/vision/blob/master/torchvision/datasets/imagenet.py)).
Then run the following command:
```
python dataset_miscs/build_tfrs.py --save_dir /path/to/imagenet/tfrs --img_folder /path/to/imagenet/raw/folder
```

## Model training
We provide implementations for LA-trained AlexNet, VggNet, ResNet-18, and ResNet-50.
The commands below are for ResNet-18 training; commands for the other networks can be obtained by slightly modifying these commands after inspecting `exp_configs/la_final.json`.
As the LA algorithm requires a warm start from a model trained with the IR algorithm for 10 epochs, we first run the IR training using the following command:
```
python train.py --config exp_configs/la_final.json:res18_IR --image_dir /path/to/imagenet/tfrs --gpu [your gpu number] --cache_dir /path/to/model/save/folder
```
Then run the following command to do the LA training:
```
python train.py --config exp_configs/la_final.json:res18_LA --image_dir /path/to/imagenet/tfrs --gpu [your gpu number] --cache_dir /path/to/model/save/folder
```

### Code reading

For your convenience, the most important function to look at is `build_targets` in `model/instance_model.py`.

## Transfer learning to ImageNet
After finishing the LA training, run the following command to do the transfer learning to ImageNet:
```
python train_transfer.py --config exp_configs/la_trans_final.json:trans_res18_LA --image_dir /path/to/imagenet/tfrs --gpu [your gpu number] --cache_dir /path/to/model/save/folder
```

## Transfer learning to Places205
Generate the tfrecords for Places205 using the following command:
```
python dataset_miscs/build_tfrs_places.py --out_dir /path/to/places205/tfrs --csv_folder /path/to/places205/csvs --base_dir /path/to/places205/raw/folder --run
```
`/path/to/places205/csvs` should include `train_places205.csv` and `val_places205.csv` for Places205.
`/path/to/places205/raw/folder` should include the raw Places205 images, such as `/path/to/places205/raw/folder/data/vision/torralba/deeplearning/images256/a/abbey/gsun_0003586c3eedd97457b2d729ebfe18b5.jpg`.

Then, run this command for transfer learning:
```
python train_transfer.py --config exp_configs/la_plc_trans_final.json:plc_trans_res18_LA --image_dir /path/to/imagenet/tfrs --gpu [your gpu number] --cache_dir /path/to/model/save/folder
```

## Multi-GPU training
Unfortunately, this implementation does not support efficient multi-GPU training, which is non-trivial in tensorflow.
Instead, we provide another implementation using [TFUtils](https://github.com/neuroailab/tfutils), which supports multi-GPU training but requires installing TFUtils.
After installing TFUtils, run the same training commands using `train_tfutils.py` and `train_transfer_tfutils.py` with a multi-GPU argument such as `--gpu a,b,c,d`, where `a,b,c,d` are the GPU numbers to use.
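For example, with four GPUs, the ResNet-18 LA training command above would become the following (the GPU ids `0,1,2,3` are only illustrative, and we assume the remaining flags stay the same):
```
python train_tfutils.py --config exp_configs/la_final.json:res18_LA --image_dir /path/to/imagenet/tfrs --gpu 0,1,2,3 --cache_dir /path/to/model/save/folder
```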
61 | -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | from __future__ import division, print_function, absolute_import 2 | import argparse 3 | import json 4 | 5 | 6 | def named_choices(choices): 7 | def convert(val): 8 | if val not in choices: 9 | raise Exception('%s is not a recognized ' 10 | 'choice (choices are %s)' 11 | % (val, ', '.join(choices.keys()))) 12 | return choices[val] 13 | return convert 14 | 15 | 16 | class Config(object): 17 | def __init__(self): 18 | self.parser = argparse.ArgumentParser( 19 | description="Train instance task using dataset interface") 20 | self.parser.add_argument('--config', default=None, 21 | type=str, 22 | help="Path to a JSON file containing configuration info. Any " \ 23 | "configurations loaded from this file are superseded by " \ 24 | "configurations passed from the command line.") 25 | self.fields = [] 26 | self.required_fields = [] 27 | 28 | self._reserved = ['config', 'description'] 29 | self._default_values = {} 30 | self._types = {} 31 | 32 | def add(self, field, type, help, 33 | default=None, required=False, 34 | action='store'): 35 | def _assert(cond, mesg): 36 | if not cond: 37 | raise Exception("Error in defining flag %s: %s" % (field, mesg)) 38 | _assert(field not in self._reserved, "flag name reserved!") 39 | _assert(field not in self.fields, "already defined!") 40 | 41 | if type is bool: 42 | if default is None: 43 | default = False 44 | self.parser.add_argument( 45 | '--' + field, default=None, 46 | help=help, action='store_true') 47 | else: 48 | self.parser.add_argument( 49 | '--' + field, default=None, type=type, 50 | help=help, action=action) 51 | 52 | self.fields.append(field) 53 | self._types[field] = type 54 | 55 | if default is not None: 56 | _assert(not required, "default doesn't make sense " \ 57 | "when flag is required!") 58 | self._default_values[field] = type(default) 59 | if required: 60 | self.required_fields.append(field) 61 | 62 | def parse_config_file(self, config_str): 63 | if config_str is None: 64 | return {} 65 | 66 | parts = config_str.split(':') 67 | assert len(parts) <= 2 68 | if len(parts) < 2: 69 | parts.append(None) 70 | path, config_name = parts 71 | 72 | def strip_comments(s): 73 | # Quick-and-dirty way to strip comments. Should work for our 74 | # purposes. 75 | lines = s.split('\n') 76 | lines = filter(lambda x: not x.strip().startswith('//'), lines) 77 | return '\n'.join(lines) 78 | 79 | f = open(path) 80 | json_str = strip_comments(f.read()) 81 | json_dict = json.loads(json_str) 82 | if config_name is not None: 83 | if config_name not in json_dict: 84 | raise Exception("Could not find configuration called '%s' " 85 | "in file '%s'" % (config_name, path)) 86 | json_dict = json_dict[config_name] 87 | return json_dict 88 | 89 | def parse_args(self): 90 | args = self.parser.parse_args() 91 | file_cfg = self.parse_config_file(args.config) 92 | 93 | # Configuration priority: 94 | # 1. Explicit command line values 95 | # 2. Config file values 96 | # 3. 
Default values 97 | for field in self.fields: 98 | cmd_val = getattr(args, field) 99 | if cmd_val is not None: 100 | continue 101 | 102 | if field in file_cfg: 103 | value = self._types[field](file_cfg[field]) 104 | setattr(args, field, value) 105 | elif field in self._default_values: 106 | setattr(args, field, self._default_values[field]) 107 | 108 | for field in self.required_fields: 109 | if getattr(args, field) is None: 110 | raise Exception("Missing required argument %s" % field) 111 | return args 112 | -------------------------------------------------------------------------------- /dataset_miscs/build_tfrs.py: -------------------------------------------------------------------------------- 1 | import sys, os 2 | import numpy as np 3 | import argparse 4 | import tensorflow as tf 5 | from PIL import Image 6 | from tqdm import tqdm 7 | 8 | 9 | def _bytes_feature(value): 10 | return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value])) 11 | 12 | 13 | def _int64_feature(value): 14 | return tf.train.Feature(int64_list=tf.train.Int64List(value=[value])) 15 | 16 | 17 | def get_parser(): 18 | parser = argparse.ArgumentParser( 19 | description='Generate tfrecords from jpeg images, ' \ 20 | + 'default parameters are for node7') 21 | parser.add_argument( 22 | '--save_dir', 23 | default='/data5/chengxuz/Dataset/test_la_imagenet_tfr', type=str, 24 | action='store', help='Directory to save the tfrecords') 25 | parser.add_argument( 26 | '--img_folder', 27 | default='/data5/chengxuz/Dataset/imagenet_raw', type=str, 28 | action='store', help='Directory storing the original images') 29 | parser.add_argument( 30 | '--random_seed', 31 | default=0, type=int, 32 | action='store', help='Random seed for numpy') 33 | return parser 34 | 35 | 36 | def get_label_dict(folder): 37 | all_nouns = os.listdir(folder) 38 | all_nouns.sort() 39 | label_dict = {noun:idx for idx, noun in enumerate(all_nouns)} 40 | return label_dict, all_nouns 41 | 42 | 43 | def get_imgs_from_dir(synset_dir): 44 | curr_imgs = os.listdir(synset_dir) 45 | curr_imgs = [os.path.join(synset_dir, each_img) 46 | for each_img in curr_imgs] 47 | curr_imgs.sort() 48 | return curr_imgs 49 | 50 | 51 | def get_path_and_label(img_folder): 52 | all_path_labels = [] 53 | label_dict, all_nouns = get_label_dict(img_folder) 54 | print('Getting all image paths') 55 | for each_noun in tqdm(all_nouns): 56 | curr_paths = get_imgs_from_dir( 57 | os.path.join(img_folder, each_noun)) 58 | curr_path_labels = [(each_path, label_dict[each_noun]) \ 59 | for each_path in curr_paths] 60 | all_path_labels.extend(curr_path_labels) 61 | return all_path_labels 62 | 63 | 64 | class ImageCoder(object): 65 | """ 66 | Helper class that provides TensorFlow image coding utilities. 67 | from https://github.com/tensorflow/models/blob/a156e20367c2a8195ba11da2e1d8589e93afdf40/research/inception/inception/data/build_imagenet_data.py 68 | """ 69 | 70 | def __init__(self): 71 | # Create a single Session to run all image coding calls. 72 | os.environ['CUDA_VISIBLE_DEVICES'] = '' 73 | self._sess = tf.Session() 74 | 75 | # Initializes function that converts PNG to JPEG data. 76 | self._png_data = tf.placeholder(dtype=tf.string) 77 | image = tf.image.decode_png(self._png_data, channels=3) 78 | self._png_to_jpeg = tf.image.encode_jpeg(image, format='rgb', quality=100) 79 | 80 | # Initializes function that converts CMYK JPEG data to RGB JPEG data. 
81 | self._cmyk_data = tf.placeholder(dtype=tf.string) 82 | image = tf.image.decode_jpeg(self._cmyk_data, channels=0) 83 | self._cmyk_to_rgb = tf.image.encode_jpeg(image, format='rgb', quality=100) 84 | 85 | def png_to_jpeg(self, image_data): 86 | return self._sess.run(self._png_to_jpeg, 87 | feed_dict={self._png_data: image_data}) 88 | 89 | def cmyk_to_rgb(self, image_data): 90 | return self._sess.run(self._cmyk_to_rgb, 91 | feed_dict={self._cmyk_data: image_data}) 92 | 93 | coder = ImageCoder() 94 | 95 | 96 | def _is_png(filename): 97 | """Determine if a file contains a PNG format image. 98 | Args: 99 | filename: string, path of the image file. 100 | Returns: 101 | boolean indicating if the image is a PNG. 102 | """ 103 | # File list from: 104 | # https://groups.google.com/forum/embed/?place=forum/torch7#!topic/torch7/fOSTXHIESSU 105 | return 'n02105855_2933.JPEG' in filename 106 | 107 | 108 | def _is_cmyk(filename): 109 | """Determine if file contains a CMYK JPEG format image. 110 | Args: 111 | filename: string, path of the image file. 112 | Returns: 113 | boolean indicating if the image is a JPEG encoded with CMYK color space. 114 | """ 115 | # File list from: 116 | # https://github.com/cytsai/ilsvrc-cmyk-image-list 117 | # add one validation file which is also CMYK 118 | blacklist = ['n01739381_1309.JPEG', 'n02077923_14822.JPEG', 119 | 'n02447366_23489.JPEG', 'n02492035_15739.JPEG', 120 | 'n02747177_10752.JPEG', 'n03018349_4028.JPEG', 121 | 'n03062245_4620.JPEG', 'n03347037_9675.JPEG', 122 | 'n03467068_12171.JPEG', 'n03529860_11437.JPEG', 123 | 'n03544143_17228.JPEG', 'n03633091_5218.JPEG', 124 | 'n03710637_5125.JPEG', 'n03961711_5286.JPEG', 125 | 'n04033995_2932.JPEG', 'n04258138_17003.JPEG', 126 | 'n04264628_27969.JPEG', 'n04336792_7448.JPEG', 127 | 'n04371774_5854.JPEG', 'n04596742_4225.JPEG', 128 | 'n07583066_647.JPEG', 'n13037406_4650.JPEG', 129 | 'ILSVRC2012_val_00019877.JPEG'] 130 | return filename.split('/')[-1] in blacklist 131 | 132 | 133 | def get_img_raw_str(jpg_path): 134 | with tf.gfile.FastGFile(jpg_path, 'rb') as f: 135 | img_raw_str = f.read() 136 | if _is_png(jpg_path): 137 | # 1 image is a PNG. 138 | print('Converting PNG to JPEG for %s' % jpg_path) 139 | img_raw_str = coder.png_to_jpeg(img_raw_str) 140 | elif _is_cmyk(jpg_path): 141 | # 23 JPEG images are in CMYK colorspace. 
142 | print('Converting CMYK to RGB for %s' % jpg_path) 143 | img_raw_str = coder.cmyk_to_rgb(img_raw_str) 144 | return img_raw_str 145 | 146 | 147 | def write_to_tfrs(tfrs_path, curr_file_list): 148 | # Write each image and label 149 | writer = tf.python_io.TFRecordWriter(tfrs_path) 150 | for idx, jpg_path, lbl in curr_file_list: 151 | img_raw_str = get_img_raw_str(jpg_path) 152 | example = tf.train.Example(features=tf.train.Features(feature={ 153 | 'images': _bytes_feature(img_raw_str), 154 | 'labels': _int64_feature(lbl), 155 | 'index': _int64_feature(idx), 156 | })) 157 | writer.write(example.SerializeToString()) 158 | writer.close() 159 | 160 | 161 | def build_all_tfrs_from_folder( 162 | folder_path, num_tfrs, tfr_pat, 163 | random_seed=None): 164 | # get all path and labels, shuffle them if needed 165 | all_path_labels = get_path_and_label(folder_path) 166 | if random_seed is not None: 167 | np.random.seed(random_seed) 168 | all_path_labels = np.random.permutation(all_path_labels) 169 | overall_num_imgs = len(all_path_labels) 170 | all_path_lbl_idx = [(idx, path, int(lbl)) \ 171 | for idx, (path, lbl) in enumerate(all_path_labels)] 172 | print('%i images in total' % overall_num_imgs) 173 | 174 | # Cut into num_tfr tfrecords and write each of them 175 | num_img_per = int(np.ceil(overall_num_imgs*1.0/num_tfrs)) 176 | print('Writing into tfrecords') 177 | for curr_tfr in tqdm(range(num_tfrs)): 178 | tfrs_path = tfr_pat % (curr_tfr, num_tfrs) 179 | start_num = curr_tfr * num_img_per 180 | end_num = min((curr_tfr+1) * num_img_per, overall_num_imgs) 181 | write_to_tfrs(tfrs_path, all_path_lbl_idx[start_num:end_num]) 182 | 183 | 184 | def main(): 185 | parser = get_parser() 186 | args = parser.parse_args() 187 | os.system('mkdir -p %s' % args.save_dir) 188 | 189 | build_all_tfrs_from_folder( 190 | os.path.join(args.img_folder, 'train'), 191 | 1024, 192 | os.path.join(args.save_dir, 'train-%05i-of-%05i'), 193 | args.random_seed) 194 | build_all_tfrs_from_folder( 195 | os.path.join(args.img_folder, 'val'), 196 | 128, 197 | os.path.join(args.save_dir, 'validation-%05i-of-%05i')) 198 | 199 | 200 | if __name__=="__main__": 201 | main() 202 | -------------------------------------------------------------------------------- /dataset_miscs/build_tfrs_places.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import argparse 3 | import os 4 | import numpy as np 5 | import sys 6 | import pdb 7 | from tqdm import tqdm 8 | 9 | NUM_TRAIN_TFR = 1024 10 | NUM_VAL_TFR = 128 11 | SEED = 0 12 | TRAIN_PAT = 'train-%05i-%05i' 13 | VAL_PAT = 'validation-%05i-%05i' 14 | 15 | 16 | def _bytes_feature(value): 17 | return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value])) 18 | 19 | 20 | def _int64_feature(value): 21 | return tf.train.Feature(int64_list=tf.train.Int64List(value=[value])) 22 | 23 | 24 | def get_parser(): 25 | parser = argparse.ArgumentParser( 26 | description='The script to write to tfrecords for places') 27 | parser.add_argument( 28 | '--out_dir', 29 | default='/data5/chengxuz/Dataset/places/tfrs_205', 30 | type=str, action='store', 31 | help='Output directory') 32 | parser.add_argument( 33 | '--csv_folder', 34 | default='/data5/chengxuz/Dataset/places/split/trainvalsplit_places205', 35 | type=str, action='store', 36 | help='Csv folder') 37 | parser.add_argument( 38 | '--base_dir', 39 | default='/data5/chengxuz/Dataset/places/images', 40 | type=str, action='store', 41 | help='Image base folder') 42 | parser.add_argument( 43 | 
'--run', 44 | action='store_true', 45 | help='Whether actually run') 46 | return parser 47 | 48 | 49 | def get_jpg_list(csv_path): 50 | fin = open(csv_path, 'r') 51 | all_lines = fin.readlines() 52 | 53 | all_jpg_lbls = [] 54 | for each_line in all_lines: 55 | try: 56 | line_splits = each_line.split() 57 | jpg_path = line_splits[0] 58 | curr_label = int(line_splits[1]) 59 | 60 | all_jpg_lbls.append((jpg_path, curr_label)) 61 | except: 62 | print(each_line) 63 | return all_jpg_lbls 64 | 65 | 66 | def get_train_val_list(args): 67 | train_csv_path = os.path.join(args.csv_folder, 'train_places205.csv') 68 | val_csv_path = os.path.join(args.csv_folder, 'val_places205.csv') 69 | 70 | train_jpg_lbls = get_jpg_list(train_csv_path) 71 | val_jpg_lbls = get_jpg_list(val_csv_path) 72 | return train_jpg_lbls, val_jpg_lbls 73 | 74 | 75 | def write_one_rec(writer, idx, img_path, lbl, args): 76 | img_path = os.path.join(args.base_dir, img_path) 77 | img_jpg_str = open(img_path, 'rb').read() 78 | feature_dict = { 79 | 'images': _bytes_feature(img_jpg_str), 80 | 'labels': _int64_feature(lbl)} 81 | if idx is not None: 82 | feature_dict['index'] = _int64_feature(idx) 83 | example = tf.train.Example(features=tf.train.Features(feature=feature_dict)) 84 | writer.write(example.SerializeToString()) 85 | 86 | 87 | def write_tfrs(jpg_lbls, num_files, file_pat, args): 88 | no_imgs_each_tfr = int(len(jpg_lbls) / num_files) 89 | curr_tfr_idx = 0 90 | writer = None 91 | for idx, (img_path, lbl) in enumerate(tqdm(jpg_lbls)): 92 | if idx % no_imgs_each_tfr == 0: 93 | if writer is not None: 94 | writer.close() 95 | tfr_path = os.path.join( 96 | args.out_dir, file_pat % (curr_tfr_idx, num_files)) 97 | writer = tf.python_io.TFRecordWriter(tfr_path) 98 | curr_tfr_idx += 1 99 | write_one_rec(writer, idx, img_path, lbl, args) 100 | writer.close() 101 | 102 | 103 | def main(): 104 | parser = get_parser() 105 | args = parser.parse_args() 106 | args.base_dir = os.path.join( 107 | args.base_dir, 108 | 'data/vision/torralba/deeplearning/images256/') 109 | 110 | os.system('mkdir -p {path}'.format(path=args.out_dir)) 111 | train_jpg_lbls, val_jpg_lbls = get_train_val_list(args) 112 | 113 | np.random.seed(SEED) 114 | np.random.shuffle(train_jpg_lbls) 115 | if args.run: 116 | write_tfrs(train_jpg_lbls, NUM_TRAIN_TFR, TRAIN_PAT, args) 117 | write_tfrs(val_jpg_lbls, NUM_VAL_TFR, VAL_PAT, args) 118 | 119 | 120 | if __name__=="__main__": 121 | main() 122 | -------------------------------------------------------------------------------- /exp_configs/la_final.json: -------------------------------------------------------------------------------- 1 | { 2 | "alexnet_IR": { 3 | "exp_id": "alexnet_IR", 4 | "port": 27009, 5 | "task": "IR", 6 | "db_name": "la_pub", 7 | "col_name": "test", 8 | "model_type": "alexnet_bn_no_drop", 9 | "train_num_steps": 100090, 10 | "fre_filter": 100090, 11 | "fre_cache_filter": 10009 12 | }, 13 | "alexnet_LA": { 14 | "exp_id": "alexnet_LA", 15 | "port": 27009, 16 | "db_name": "la_pub", 17 | "col_name": "test", 18 | "model_type": "alexnet_bn_no_drop", 19 | "load_exp": "la_pub/test/alexnet_IR", 20 | "load_step": 100090, 21 | "lr_boundaries": "1190000,1423100", 22 | "train_num_steps": 1601440, 23 | "kmeans_k": "30000,30000,30000", 24 | "fre_filter": 100090, 25 | "fre_cache_filter": 10009 26 | }, 27 | "vggnet_IR": { 28 | "exp_id": "vggnet_IR", 29 | "port": 27009, 30 | "task": "IR", 31 | "db_name": "la_pub", 32 | "col_name": "test", 33 | "model_type": "vggnet_fx", 34 | "train_num_steps": 100090, 35 | "fre_filter": 100090, 
36 | "fre_cache_filter": 10009 37 | }, 38 | "vggnet_LA": { 39 | "exp_id": "vggnet_LA", 40 | "port": 27009, 41 | "db_name": "la_pub", 42 | "col_name": "test", 43 | "model_type": "vggnet_fx", 44 | "load_exp": "la_pub/test/vggnet_IR", 45 | "load_step": 100090, 46 | "lr_boundaries": "680681,790791", 47 | "kmeans_k": "10000,10000,10000,10000,10000,10000", 48 | "train_num_steps": 890712, 49 | "batch_size": 256, 50 | "fre_valid": 5004, 51 | "fre_filter": 50040, 52 | "fre_cache_filter": 5004 53 | }, 54 | "res18_IR": { 55 | "exp_id": "res18_IR", 56 | "port": 27009, 57 | "task": "IR", 58 | "db_name": "la_pub", 59 | "col_name": "test", 60 | "train_num_steps": 100090, 61 | "fre_filter": 100090, 62 | "fre_cache_filter": 10009 63 | }, 64 | "res18_LA": { 65 | "exp_id": "res18_LA", 66 | "port": 27009, 67 | "task": "LA", 68 | "db_name": "la_pub", 69 | "col_name": "test", 70 | "load_exp": "la_pub/test/res18_IR", 71 | "load_step": 100090, 72 | "kmeans_k": "30000,30000,30000,30000,30000,30000,30000,30000,30000,30000", 73 | "lr_boundaries": "1600000,1873411", 74 | "train_num_steps": 2001800, 75 | "fre_filter": 100090, 76 | "fre_cache_filter": 10009 77 | }, 78 | "res50_IR": { 79 | "exp_id": "res50_IR", 80 | "port": 27009, 81 | "task": "IR", 82 | "db_name": "la_pub", 83 | "col_name": "test", 84 | "model_type": "resnet50", 85 | "train_num_steps": 100090, 86 | "fre_filter": 100090, 87 | "fre_cache_filter": 10009 88 | }, 89 | "res50_LA": { 90 | "exp_id": "res50_LA", 91 | "port": 27009, 92 | "db_name": "la_pub", 93 | "col_name": "test", 94 | "model_type": "resnet50", 95 | "load_exp": "la_pub/test/res50_IR", 96 | "load_step": 100090, 97 | "lr_boundaries": "1522000,1895857", 98 | "train_num_steps": 2001800, 99 | "kmeans_k": "30000,30000,30000,30000,30000,30000,30000,30000,30000,30000", 100 | "fre_filter": 100090, 101 | "fre_cache_filter": 10009 102 | } 103 | } 104 | -------------------------------------------------------------------------------- /exp_configs/la_plc_trans_final.json: -------------------------------------------------------------------------------- 1 | { 2 | "plc_trans_alexnet_LA": { 3 | "save_exp": "la_pub/plc_trans_test/alexnet_LA", 4 | "load_exp": "la_pub/test/alexnet_LA", 5 | "train_crop": "alexnet_crop_flip", 6 | "num_classes": 205, 7 | "port": 27009, 8 | "ten_crop": true, 9 | "model_type": "alexnet_bn_no_drop", 10 | "get_all_layers": "conv1,conv2,conv3,conv4,conv5", 11 | "weight_decay": 0, 12 | "load_step": 1601440, 13 | "lr_boundaries": "2051845,2151935", 14 | "train_num_steps": 2252025, 15 | "fre_filter": 50045, 16 | "fre_cache_filter": 10009 17 | }, 18 | "plc_trans_vggnet_LA": { 19 | "save_exp": "la_pub/plc_trans_test/vggnet_LA", 20 | "load_exp": "la_pub/test/vggnet_LA", 21 | "train_crop": "alexnet_crop_flip", 22 | "num_classes": 205, 23 | "port": 27009, 24 | "ten_crop": true, 25 | "model_type": "vggnet_fx", 26 | "get_all_layers": "0,1,2,3,4", 27 | "load_step": 890712, 28 | "lr_boundaries": "1000900,1301170", 29 | "train_num_steps": 1401260, 30 | "fre_filter": 50045, 31 | "fre_cache_filter": 10009 32 | }, 33 | "plc_trans_res18_LA": { 34 | "save_exp": "la_pub/plc_trans_test/res18_LA", 35 | "load_exp": "la_pub/test/res18_LA", 36 | "train_crop": "alexnet_crop_flip", 37 | "num_classes": 205, 38 | "port": 27009, 39 | "ten_crop": true, 40 | "get_all_layers": "1,3,5,7,9", 41 | "load_step": 2001800, 42 | "lr_boundaries": "2302070,2602340", 43 | "train_num_steps": 2702430, 44 | "fre_filter": 50045, 45 | "fre_cache_filter": 10009 46 | }, 47 | "plc_trans_res50_LA": { 48 | "save_exp": 
"la_pub/plc_trans_test/res50_LA", 49 | "load_exp": "la_pub/test/res50_LA", 50 | "train_crop": "alexnet_crop_flip", 51 | "num_classes": 205, 52 | "port": 27009, 53 | "ten_crop": true, 54 | "model_type": "resnet50", 55 | "get_all_layers": "1,4,8,14,17", 56 | "load_step": 2001800, 57 | "lr_boundaries": "2302070,2602340", 58 | "train_num_steps": 2702430, 59 | "fre_filter": 50045, 60 | "fre_cache_filter": 10009 61 | } 62 | } 63 | -------------------------------------------------------------------------------- /exp_configs/la_trans_final.json: -------------------------------------------------------------------------------- 1 | { 2 | "trans_alexnet_LA": { 3 | "save_exp": "la_pub/trans_test/alexnet_LA", 4 | "load_exp": "la_pub/test/alexnet_LA", 5 | "port": 27009, 6 | "ten_crop": true, 7 | "model_type": "alexnet_bn_no_drop", 8 | "get_all_layers": "conv1,conv2,conv3,conv4,conv5", 9 | "train_crop": "alexnet_crop_flip", 10 | "weight_decay": 0, 11 | "load_step": 1601440, 12 | "lr_boundaries": "2051845,2151935", 13 | "train_num_steps": 2252025, 14 | "fre_filter": 50045, 15 | "fre_cache_filter": 10009 16 | }, 17 | "trans_vggnet_LA": { 18 | "save_exp": "la_pub/trans_test/vggnet_LA", 19 | "load_exp": "la_pub/test/vggnet_LA", 20 | "port": 27009, 21 | "ten_crop": true, 22 | "model_type": "vggnet_fx", 23 | "get_all_layers": "0,1,2,3,4", 24 | "load_step": 890712, 25 | "lr_boundaries": "1000900,1301170", 26 | "train_num_steps": 1401260, 27 | "fre_filter": 50045, 28 | "fre_cache_filter": 10009 29 | }, 30 | "trans_res18_LA": { 31 | "save_exp": "la_pub/trans_test/res18_LA", 32 | "load_exp": "la_pub/test/res18_LA", 33 | "port": 27009, 34 | "ten_crop": true, 35 | "get_all_layers": "1,3,5,7,9", 36 | "load_step": 2001800, 37 | "lr_boundaries": "2302070,2602340", 38 | "train_num_steps": 2702430, 39 | "fre_filter": 50045, 40 | "fre_cache_filter": 10009 41 | }, 42 | "trans_res50_LA": { 43 | "save_exp": "la_pub/trans_test/res50_LA", 44 | "load_exp": "la_pub/test/res50_LA", 45 | "port": 27009, 46 | "ten_crop": true, 47 | "model_type": "resnet50", 48 | "get_all_layers": "1,4,8,14,17", 49 | "load_step": 2001800, 50 | "lr_boundaries": "2302070,2602340", 51 | "train_num_steps": 2702430, 52 | "fre_filter": 50045, 53 | "fre_cache_filter": 10009 54 | } 55 | } 56 | -------------------------------------------------------------------------------- /framework.py: -------------------------------------------------------------------------------- 1 | import os 2 | import tensorflow as tf 3 | import time 4 | import tqdm 5 | import pdb 6 | import copy 7 | 8 | 9 | class TrainFramework(object): 10 | def __init__(self, params): 11 | self.params = params 12 | self.save_params = params['save_params'] 13 | self.load_params = params['load_params'] 14 | self.train_params = params['train_params'] 15 | self.model_params = params['model_params'] 16 | self.loss_params = params['loss_params'] 17 | self.validation_params = params['validation_params'] 18 | 19 | # Set cache directory 20 | self.cache_dir = self.save_params['cache_dir'] 21 | os.system('mkdir -p %s' % self.cache_dir) 22 | 23 | self.log_file_path = os.path.join(self.cache_dir, 'log.txt') 24 | self.val_log_file_path = os.path.join(self.cache_dir, 'val_log.txt') 25 | 26 | self.load_from_curr_exp = tf.train.latest_checkpoint(self.cache_dir) 27 | 28 | if not self.load_from_curr_exp: 29 | self.log_writer = open(self.log_file_path, 'w') 30 | self.val_log_writer = open(self.val_log_file_path, 'w') 31 | else: 32 | self.log_writer = open(self.log_file_path, 'a+') 33 | self.val_log_writer = 
open(self.val_log_file_path, 'a+') 34 | 35 | def build_inputs(self): 36 | data_params = self.train_params['data_params'] 37 | func = data_params.pop('func') 38 | self.inputs = func(**data_params) 39 | 40 | def build_network(self, inputs, train): 41 | model_params = self.model_params 42 | func = model_params.pop('func') 43 | outputs, _ = func( 44 | inputs=inputs, 45 | train=train, 46 | **model_params) 47 | model_params['func'] = func 48 | 49 | if 'trainable_scopes' in model_params: 50 | trainable_scopes = model_params['trainable_scopes'] 51 | all_train_ref = tf.get_collection_ref( 52 | tf.GraphKeys.TRAINABLE_VARIABLES) 53 | cp_all_train_ref = copy.copy(all_train_ref) 54 | for each_v in cp_all_train_ref: 55 | should_be_trainable = False 56 | for each_trainable_scope in trainable_scopes: 57 | if each_v.op.name.startswith(each_trainable_scope): 58 | should_be_trainable = True 59 | if not should_be_trainable: 60 | all_train_ref.remove(each_v) 61 | return outputs 62 | 63 | def build_train_op(self): 64 | loss_params = self.loss_params 65 | 66 | input_targets = [self.inputs[key] \ 67 | for key in loss_params['pred_targets']] 68 | func = loss_params['loss_func'] 69 | self.loss_retval = func( 70 | self.outputs, 71 | *input_targets, 72 | **loss_params.get('loss_func_kwargs', {})) 73 | self.loss_retval = loss_params['agg_func']( 74 | self.loss_retval, 75 | **loss_params.get('agg_func_kwargs', {})) 76 | 77 | self.global_step = tf.get_variable( 78 | 'global_step', [], 79 | dtype=tf.int64, trainable=False, 80 | initializer=tf.constant_initializer(0)) 81 | lr_rate_params = self.params['learning_rate_params'] 82 | func = lr_rate_params.pop('func') 83 | learning_rate = func(self.global_step, **lr_rate_params) 84 | self.learning_rate = learning_rate 85 | 86 | opt_params = self.params['optimizer_params'] 87 | func = opt_params.pop('optimizer') 88 | opt = func(learning_rate=learning_rate, **opt_params) 89 | 90 | with tf.control_dependencies( 91 | tf.get_collection(tf.GraphKeys.UPDATE_OPS)): 92 | self.train_op = opt.minimize( 93 | self.loss_retval, 94 | global_step=self.global_step) 95 | 96 | def build_train_targets(self): 97 | 98 | extra_targets_params = self.train_params['targets'] 99 | func = extra_targets_params.pop('func') 100 | train_targets = func(self.inputs, self.outputs, **extra_targets_params) 101 | 102 | train_targets['train_op'] = self.train_op 103 | train_targets['loss'] = self.loss_retval 104 | train_targets['learning_rate'] = self.learning_rate 105 | 106 | self.train_targets = train_targets 107 | 108 | def build_sess_and_saver(self): 109 | gpu_options = tf.GPUOptions(allow_growth=True) 110 | sess = tf.Session(config=tf.ConfigProto( 111 | allow_soft_placement=True, 112 | gpu_options=gpu_options, 113 | )) 114 | self.sess = sess 115 | self.saver = tf.train.Saver() 116 | 117 | def load_from_ckpt(self, ckpt_path): 118 | print('Restore from %s' % ckpt_path) 119 | self.saver.restore(self.sess, ckpt_path) 120 | 121 | def init_and_restore(self): 122 | init_op_global = tf.global_variables_initializer() 123 | self.sess.run(init_op_global) 124 | init_op_local = tf.local_variables_initializer() 125 | self.sess.run(init_op_local) 126 | 127 | if self.load_from_curr_exp: 128 | self.load_from_ckpt(self.load_from_curr_exp) 129 | else: 130 | split_cache_path = self.cache_dir.split('/') 131 | split_cache_path[-1] = self.load_params['exp_id'] 132 | split_cache_path[-2] = self.load_params['collname'] 133 | split_cache_path[-3] = self.load_params['dbname'] 134 | load_dir = '/'.join(split_cache_path) 135 | if 
self.load_params['query']: 136 | ckpt_path = os.path.join( 137 | load_dir, 138 | 'model.ckpt-%i' % self.load_params['query']['step']) 139 | else: 140 | ckpt_path = tf.train.latest_checkpoint(load_dir) 141 | if ckpt_path: 142 | print('Restore from %s' % ckpt_path) 143 | #self.load_from_ckpt(ckpt_path) 144 | reader = tf.train.NewCheckpointReader(ckpt_path) 145 | saved_var_shapes = reader.get_variable_to_shape_map() 146 | 147 | all_vars = tf.global_variables() 148 | all_var_list = {v.op.name: v for v in all_vars} 149 | filtered_var_list = {} 150 | for name, var in all_var_list.items(): 151 | if name in saved_var_shapes: 152 | curr_shape = var.get_shape().as_list() 153 | saved_shape = saved_var_shapes[name] 154 | if (curr_shape == saved_shape): 155 | filtered_var_list[name] = var 156 | else: 157 | print('Shape mismatch for %s: ' % name \ 158 | + str(curr_shape) \ 159 | + str(saved_shape)) 160 | _load_saver = tf.train.Saver(var_list=filtered_var_list) 161 | _load_saver.restore(self.sess, ckpt_path) 162 | 163 | def run_each_validation(self, val_key): 164 | agg_res = None 165 | num_steps = self.validation_params[val_key]['num_steps'] 166 | for _step in tqdm.trange(num_steps, desc=val_key): 167 | res = self.sess.run(self.all_val_targets[val_key]) 168 | online_func = self.validation_params[val_key]['online_agg_func'] 169 | agg_res = online_func(agg_res, res, _step) 170 | agg_func = self.validation_params[val_key]['agg_func'] 171 | val_result = agg_func(agg_res) 172 | return val_result 173 | 174 | def run_train_loop(self): 175 | start_step = self.sess.run(self.global_step) 176 | train_loop = self.train_params.get('train_loop', None) 177 | 178 | for curr_step in xrange(start_step, self.train_params['num_steps']+1): 179 | self.start_time = time.time() 180 | if train_loop is None: 181 | train_res = self.sess.run(self.train_targets) 182 | else: 183 | train_res = train_loop['func'](self.sess, self.train_targets) 184 | 185 | duration = time.time() - self.start_time 186 | 187 | message = 'Step {} ({:.0f} ms) -- '\ 188 | .format(curr_step, 1000 * duration) 189 | rep_msg = ['{}: {:.4f}'.format(k, v) \ 190 | for k, v in train_res.items() 191 | if k != 'train_op'] 192 | message += ', '.join(rep_msg) 193 | print(message) 194 | 195 | if curr_step % self.save_params['cache_filters_freq'] == 0 \ 196 | and curr_step > 0: 197 | print('Saving model...') 198 | self.saver.save( 199 | self.sess, 200 | os.path.join( 201 | self.cache_dir, 202 | 'model.ckpt'), 203 | global_step=curr_step) 204 | 205 | self.log_writer.write(message + '\n') 206 | if curr_step % self.save_params['save_metrics_freq'] == 0: 207 | self.log_writer.close() 208 | self.log_writer = open(self.log_file_path, 'a+') 209 | 210 | if curr_step % self.save_params['save_valid_freq'] == 0: 211 | for each_val_key in self.validation_params: 212 | val_result = self.run_each_validation(each_val_key) 213 | self.val_log_writer.write( 214 | '%s: %s\n' % (each_val_key, str(val_result))) 215 | print(val_result) 216 | self.val_log_writer.close() 217 | self.val_log_writer = open(self.val_log_file_path, 'a+') 218 | 219 | def build_train(self): 220 | self.build_inputs() 221 | self.outputs = self.build_network(self.inputs, True) 222 | self.build_train_op() 223 | self.build_train_targets() 224 | 225 | def build_val_inputs(self, val_key): 226 | data_params = self.validation_params[val_key]['data_params'] 227 | func = data_params.pop('func') 228 | val_inputs = func(**data_params) 229 | return val_inputs 230 | 231 | def build_val_network(self, val_key, val_inputs): 232 | with 
tf.name_scope('validation/' + val_key): 233 | val_outputs = self.build_network(val_inputs, False) 234 | return val_outputs 235 | 236 | def build_val_targets(self, val_key, val_inputs, val_outputs): 237 | target_params = self.validation_params[val_key]['targets'] 238 | func = target_params.pop('func') 239 | val_targets = func(val_inputs, val_outputs, **target_params) 240 | return val_targets 241 | 242 | def build_val(self): 243 | tf.get_variable_scope().reuse_variables() 244 | self.all_val_targets = {} 245 | for each_val_key in self.validation_params: 246 | val_inputs = self.build_val_inputs(each_val_key) 247 | val_outputs = self.build_val_network(each_val_key, val_inputs) 248 | val_targets = self.build_val_targets( 249 | each_val_key, val_inputs, val_outputs) 250 | self.all_val_targets[each_val_key] = val_targets 251 | 252 | def train(self): 253 | self.build_train() 254 | self.build_val() 255 | 256 | self.build_sess_and_saver() 257 | self.init_and_restore() 258 | 259 | self.run_train_loop() 260 | -------------------------------------------------------------------------------- /model/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neuroailab/LocalAggregation/87ac59350eca9d6d79d4a29eec3aa8eb34c5ad5c/model/__init__.py -------------------------------------------------------------------------------- /model/alexnet_model.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Contains a model definition for AlexNet. 16 | This work was first described in: 17 | ImageNet Classification with Deep Convolutional Neural Networks 18 | Alex Krizhevsky, Ilya Sutskever and Geoffrey E. Hinton 19 | and later refined in: 20 | One weird trick for parallelizing convolutional neural networks 21 | Alex Krizhevsky, 2014 22 | Here we provide the implementation proposed in "One weird trick" and not 23 | "ImageNet Classification", as per the paper, the LRN layers have been removed. 
24 | Usage: 25 | with slim.arg_scope(alexnet.alexnet_v2_arg_scope()): 26 | outputs, end_points = alexnet.alexnet_v2(inputs) 27 | @@alexnet_v2 28 | """ 29 | 30 | from __future__ import absolute_import 31 | from __future__ import division 32 | from __future__ import print_function 33 | 34 | import tensorflow as tf 35 | from collections import OrderedDict 36 | 37 | slim = tf.contrib.slim 38 | trunc_normal = lambda stddev: tf.truncated_normal_initializer(0.0, stddev) 39 | 40 | 41 | def alexnet_v2_arg_scope(weight_decay=0.0005): 42 | with slim.arg_scope([slim.conv2d, slim.fully_connected], 43 | activation_fn=tf.nn.relu, 44 | biases_initializer=tf.constant_initializer(0.1), 45 | weights_regularizer=slim.l2_regularizer(weight_decay)): 46 | with slim.arg_scope([slim.conv2d], padding='SAME'): 47 | with slim.arg_scope([slim.max_pool2d], padding='VALID') as arg_sc: 48 | return arg_sc 49 | 50 | 51 | def alexnet_v2(inputs, 52 | num_classes=1000, 53 | is_training=True, 54 | dropout_keep_prob=0.5, 55 | spatial_squeeze=True, 56 | scope='alexnet_v2', 57 | global_pool=False, 58 | with_bn=False): 59 | """AlexNet version 2. 60 | Described in: http://arxiv.org/pdf/1404.5997v2.pdf 61 | Parameters from: 62 | github.com/akrizhevsky/cuda-convnet2/blob/master/layers/ 63 | layers-imagenet-1gpu.cfg 64 | Note: All the fully_connected layers have been transformed to conv2d layers. 65 | To use in classification mode, resize input to 224x224 or set 66 | global_pool=True. To use in fully convolutional mode, set 67 | spatial_squeeze to false. 68 | The LRN layers have been removed and change the initializers from 69 | random_normal_initializer to xavier_initializer. 70 | Args: 71 | inputs: a tensor of size [batch_size, height, width, channels]. 72 | num_classes: the number of predicted classes. If 0 or None, the logits layer 73 | is omitted and the input features to the logits layer are returned instead. 74 | is_training: whether or not the model is being trained. 75 | dropout_keep_prob: the probability that activations are kept in the dropout 76 | layers during training. 77 | spatial_squeeze: whether or not should squeeze the spatial dimensions of the 78 | logits. Useful to remove unnecessary dimensions for classification. 79 | scope: Optional scope for the variables. 80 | global_pool: Optional boolean flag. If True, the input to the classification 81 | layer is avgpooled to size 1x1, for any input size. (This is not part 82 | of the original AlexNet.) 83 | Returns: 84 | net: the output of the logits layer (if num_classes is a non-zero integer), 85 | or the non-dropped-out input to the logits layer (if num_classes is 0 86 | or None). 87 | end_points: a dict of tensors with intermediate activations. 88 | """ 89 | end_points = OrderedDict() 90 | with tf.variable_scope(scope, 'alexnet_v2', [inputs]) as sc: 91 | end_points_collection = sc.original_name_scope + '_end_points' 92 | # Collect outputs for conv2d, fully_connected and max_pool2d. 
93 | with slim.arg_scope([slim.conv2d, slim.fully_connected, slim.max_pool2d], 94 | outputs_collections=[end_points_collection]): 95 | if not with_bn: 96 | net = slim.conv2d(inputs, 64, [11, 11], 4, padding='VALID', 97 | scope='conv1') 98 | net = slim.max_pool2d(net, [3, 3], 2, scope='pool1') 99 | net = slim.conv2d(net, 192, [5, 5], scope='conv2') 100 | net = slim.max_pool2d(net, [3, 3], 2, scope='pool2') 101 | net = slim.conv2d(net, 384, [3, 3], scope='conv3') 102 | net = slim.conv2d(net, 384, [3, 3], scope='conv4') 103 | net = slim.conv2d(net, 256, [3, 3], scope='conv5') 104 | net = slim.max_pool2d(net, [3, 3], 2, scope='pool5') 105 | else: 106 | net = slim.conv2d(inputs, 64, [11, 11], 4, padding='VALID', 107 | scope='conv1', activation_fn=None) 108 | net = tf.layers.batch_normalization( 109 | inputs=net, 110 | momentum=0.997, epsilon=1e-5, 111 | training=is_training, fused=True, 112 | name='conv1') 113 | net = tf.nn.relu(net) 114 | end_points[sc.name + '/conv1'] = net 115 | net = slim.max_pool2d(net, [3, 3], 2, scope='pool1') 116 | end_points[sc.name + '/pool1'] = net 117 | 118 | net = slim.conv2d(net, 192, [5, 5], scope='conv2', activation_fn=None) 119 | net = tf.layers.batch_normalization( 120 | inputs=net, 121 | momentum=0.997, epsilon=1e-5, 122 | training=is_training, fused=True, 123 | name='conv2') 124 | net = tf.nn.relu(net) 125 | end_points[sc.name + '/conv2'] = net 126 | net = slim.max_pool2d(net, [3, 3], 2, scope='pool2') 127 | end_points[sc.name + '/pool2'] = net 128 | 129 | net = slim.conv2d(net, 384, [3, 3], scope='conv3', activation_fn=None) 130 | net = tf.layers.batch_normalization( 131 | inputs=net, 132 | momentum=0.997, epsilon=1e-5, 133 | training=is_training, fused=True, 134 | name='conv3') 135 | net = tf.nn.relu(net) 136 | end_points[sc.name + '/conv3'] = net 137 | 138 | net = slim.conv2d(net, 384, [3, 3], scope='conv4', activation_fn=None) 139 | net = tf.layers.batch_normalization( 140 | inputs=net, 141 | momentum=0.997, epsilon=1e-5, 142 | training=is_training, fused=True, 143 | name='conv4') 144 | net = tf.nn.relu(net) 145 | end_points[sc.name + '/conv4'] = net 146 | 147 | net = slim.conv2d(net, 256, [3, 3], scope='conv5', activation_fn=None) 148 | net = tf.layers.batch_normalization( 149 | inputs=net, 150 | momentum=0.997, epsilon=1e-5, 151 | training=is_training, fused=True, 152 | name='conv5') 153 | net = tf.nn.relu(net) 154 | end_points[sc.name + '/conv5'] = net 155 | net = slim.max_pool2d(net, [3, 3], 2, scope='pool5') 156 | end_points[sc.name + '/pool5'] = net 157 | 158 | # Use conv2d instead of fully_connected layers. 
159 | with slim.arg_scope([slim.conv2d], 160 | weights_initializer=trunc_normal(0.005), 161 | biases_initializer=tf.constant_initializer(0.1)): 162 | if not with_bn: 163 | net = slim.conv2d(net, 4096, [5, 5], padding='VALID', 164 | scope='fc6') 165 | if dropout_keep_prob > 0: 166 | net = slim.dropout(net, dropout_keep_prob, is_training=is_training, 167 | scope='dropout6') 168 | net = slim.conv2d(net, 4096, [1, 1], scope='fc7') 169 | else: 170 | net = slim.conv2d(net, 4096, [5, 5], padding='VALID', 171 | scope='fc6', activation_fn=None) 172 | net = tf.layers.batch_normalization( 173 | inputs=net, 174 | momentum=0.997, epsilon=1e-5, 175 | training=is_training, fused=True, 176 | name='fc6') 177 | net = tf.nn.relu(net) 178 | end_points[sc.name + '/fc6'] = net 179 | if dropout_keep_prob > 0: 180 | net = slim.dropout(net, dropout_keep_prob, is_training=is_training, 181 | scope='dropout6') 182 | 183 | net = slim.conv2d(net, 4096, [1, 1], scope='fc7', activation_fn=None) 184 | net = tf.layers.batch_normalization( 185 | inputs=net, 186 | momentum=0.997, epsilon=1e-5, 187 | training=is_training, fused=True, 188 | name='fc7') 189 | net = tf.nn.relu(net) 190 | end_points[sc.name + '/fc7'] = net 191 | 192 | # Convert end_points_collection into a end_point dict. 193 | #end_points = slim.utils.convert_collection_to_dict( 194 | # end_points_collection) 195 | if global_pool: 196 | net = tf.reduce_mean(net, [1, 2], keep_dims=True, name='global_pool') 197 | end_points['global_pool'] = net 198 | if num_classes: 199 | if dropout_keep_prob > 0: 200 | net = slim.dropout(net, dropout_keep_prob, is_training=is_training, 201 | scope='dropout7') 202 | net = slim.conv2d(net, num_classes, [1, 1], 203 | activation_fn=None, 204 | normalizer_fn=None, 205 | biases_initializer=tf.zeros_initializer(), 206 | scope='fc8') 207 | if spatial_squeeze: 208 | net = tf.squeeze(net, [1, 2], name='fc8/squeezed') 209 | end_points[sc.name + '/fc8'] = net 210 | return net, end_points 211 | 212 | 213 | def alexnet_v2_with_bn_no_drop(*args, **kwargs): 214 | kwargs['with_bn'] = True 215 | kwargs['dropout_keep_prob'] = 0 216 | return alexnet_v2( 217 | *args, **kwargs) 218 | 219 | 220 | def alexnet_v2_with_bn(*args, **kwargs): 221 | kwargs['with_bn'] = True 222 | return alexnet_v2( 223 | *args, **kwargs) 224 | 225 | 226 | alexnet_v2.default_image_size = 224 227 | -------------------------------------------------------------------------------- /model/cluster_km.py: -------------------------------------------------------------------------------- 1 | import time 2 | import faiss 3 | import numpy as np 4 | import tensorflow as tf 5 | 6 | DEFAULT_SEED = 1234 7 | 8 | 9 | def run_kmeans(x, nmb_clusters, verbose=False, seed=DEFAULT_SEED): 10 | """Runs kmeans on 1 GPU. 
11 | Args: 12 | x: data 13 | nmb_clusters (int): number of clusters 14 | Returns: 15 | list: ids of data in each cluster 16 | """ 17 | n_data, d = x.shape 18 | 19 | # faiss implementation of k-means 20 | clus = faiss.Clustering(d, nmb_clusters) 21 | clus.niter = 20 22 | clus.max_points_per_centroid = 10000000 23 | clus.seed = seed 24 | res = faiss.StandardGpuResources() 25 | flat_config = faiss.GpuIndexFlatConfig() 26 | flat_config.useFloat16 = False 27 | flat_config.device = 0 28 | index = faiss.GpuIndexFlatL2(res, d, flat_config) 29 | 30 | # perform the training 31 | clus.train(x, index) 32 | _, I = index.search(x, 1) 33 | losses = faiss.vector_to_array(clus.obj) 34 | if verbose: 35 | print('k-means loss evolution: {0}'.format(losses)) 36 | 37 | return [int(n[0]) for n in I], losses[-1] 38 | 39 | 40 | class Kmeans: 41 | def __init__(self, k, memory_bank, cluster_labels): 42 | self.k = k 43 | self.memory_bank = memory_bank 44 | self.cluster_labels = cluster_labels 45 | 46 | self.new_cluster_feed = tf.placeholder( 47 | tf.int64, shape=self.cluster_labels.get_shape().as_list()) 48 | self.update_clusters_op = tf.assign( 49 | self.cluster_labels, self.new_cluster_feed) 50 | 51 | def recompute_clusters(self, sess, verbose=True): 52 | """Performs k-means clustering. 53 | Args: 54 | x_data (np.array N * dim): data to cluster 55 | """ 56 | end = time.time() 57 | 58 | data = sess.run(self.memory_bank.as_tensor()) 59 | 60 | all_lables = [] 61 | for k_idx, each_k in enumerate(self.k): 62 | # cluster the data 63 | I, _ = run_kmeans(data, each_k, 64 | verbose, seed = k_idx + DEFAULT_SEED) 65 | new_clust_labels = np.asarray(I) 66 | all_lables.append(new_clust_labels) 67 | new_clust_labels = np.stack(all_lables, axis=0) 68 | 69 | if verbose: 70 | print('k-means time: {0:.0f} s'.format(time.time() - end)) 71 | return new_clust_labels 72 | 73 | def apply_clusters(self, sess, new_clust_labels): 74 | sess.run(self.update_clusters_op, feed_dict={ 75 | self.new_cluster_feed: new_clust_labels 76 | }) 77 | -------------------------------------------------------------------------------- /model/dataset_utils.py: -------------------------------------------------------------------------------- 1 | from __future__ import division, print_function, absolute_import 2 | import os, sys 3 | import functools 4 | import numpy as np 5 | import tensorflow as tf 6 | 7 | 8 | def image_dir_to_tfrecords_dataset(image_dir, is_train): 9 | pattern = 'train-*' if is_train else 'validation-*' 10 | pattern = os.path.join(image_dir, pattern) 11 | datasource = tf.gfile.Glob(pattern) 12 | datasource.sort() 13 | tfr_list = np.asarray(datasource) 14 | dataset = tf.data.Dataset.list_files(tfr_list) 15 | 16 | if is_train: 17 | dataset = dataset.apply( 18 | tf.contrib.data.shuffle_and_repeat(len(tfr_list)) 19 | ) 20 | else: 21 | dataset = dataset.repeat() 22 | 23 | def fetch(filename): 24 | buffer_size = 32 * 1024 * 1024 # 32 MiB per file 25 | return tf.data.TFRecordDataset(filename, buffer_size=buffer_size) 26 | 27 | dataset = dataset.apply( 28 | tf.contrib.data.parallel_interleave( 29 | fetch, cycle_length=8, sloppy=True)) 30 | return dataset 31 | 32 | 33 | def data_parser(record_str_tensor, process_img_func, 34 | is_train=True, with_indx=True): 35 | ''' 36 | Takes a TFRecord string and outputs a dictionary ready to use 37 | as input to the model. 
38 | ''' 39 | 40 | # Parse the TFRecord 41 | keys_to_features = { 42 | 'images': tf.FixedLenFeature((), tf.string, ''), 43 | 'labels': tf.FixedLenFeature([], tf.int64, -1)} 44 | if with_indx: 45 | keys_to_features['index'] = tf.FixedLenFeature([], tf.int64, -1) 46 | parsed = tf.parse_single_example(record_str_tensor, keys_to_features) 47 | image_string = parsed['images'] 48 | image_label = parsed['labels'] 49 | image_index = parsed.get('index', None) 50 | 51 | # Process the image 52 | image = process_img_func(image_string) 53 | ret_dict = {'image': image, 'label': image_label} 54 | if with_indx: 55 | ret_dict['index'] = image_index 56 | return ret_dict 57 | 58 | 59 | def dataset_func(image_dir, process_img_func, is_train, batch_size, q_cap): 60 | dataset = image_dir_to_tfrecords_dataset(image_dir, is_train=is_train) 61 | if is_train: 62 | dataset = dataset.shuffle(buffer_size=q_cap) 63 | dataset = dataset.prefetch(batch_size * 4) 64 | dataset = dataset.map(functools.partial( 65 | data_parser, process_img_func=process_img_func, 66 | is_train=is_train, with_indx=is_train 67 | ), num_parallel_calls=64) 68 | dataset = dataset.apply( 69 | tf.contrib.data.batch_and_drop_remainder(batch_size)) 70 | dataset = dataset.prefetch(4) 71 | next_element = dataset.make_one_shot_iterator().get_next() 72 | return next_element 73 | -------------------------------------------------------------------------------- /model/instance_model.py: -------------------------------------------------------------------------------- 1 | from __future__ import division, print_function, absolute_import 2 | import os, sys 3 | import json 4 | import numpy as np 5 | import tensorflow as tf 6 | import copy 7 | import pdb 8 | from collections import OrderedDict 9 | 10 | from . import resnet_model 11 | from . import alexnet_model 12 | from . 
import vggnet_model 13 | from .memory_bank import MemoryBank 14 | from .prep_utils import ColorNormalize 15 | 16 | DATA_LEN_IMAGENET_FULL = 1281167 17 | 18 | 19 | def get_global_step_var(): 20 | global_step_vars = [v for v in tf.global_variables() \ 21 | if 'global_step' in v.name] 22 | assert len(global_step_vars) == 1 23 | return global_step_vars[0] 24 | 25 | 26 | def assert_shape(t, shape): 27 | assert t.get_shape().as_list() == shape, \ 28 | "Got shape %r, expected %r" % (t.get_shape().as_list(), shape) 29 | 30 | 31 | def flatten(layer_out): 32 | curr_shape = layer_out.get_shape().as_list() 33 | if len(curr_shape) > 2: 34 | layer_out = tf.reshape(layer_out, [curr_shape[0], -1]) 35 | return layer_out 36 | 37 | 38 | def get_alexnet_all_layers(all_layers, get_all_layers): 39 | if get_all_layers == 'default' or get_all_layers is None: 40 | keys = ['pool1', 'pool2', 'conv3', 'conv4', 'pool5', 'fc6', 'fc7'] 41 | elif get_all_layers == 'conv_all': 42 | keys = ['conv1', 'conv2', 'conv3', 'conv4', 'conv5', 'fc6', 'fc7'] 43 | elif get_all_layers == 'conv5-avg': 44 | keys = ['conv5'] 45 | else: 46 | keys = get_all_layers.split(',') 47 | 48 | output_dict = OrderedDict() 49 | for each_key in keys: 50 | for layer_name, layer_out in all_layers.items(): 51 | if each_key in layer_name: 52 | if get_all_layers == 'conv5-avg': 53 | layer_out = tf.nn.avg_pool( 54 | layer_out, 55 | ksize=[1,2,2,1], 56 | strides=[1,2,2,1], 57 | padding='SAME') 58 | layer_out = flatten(layer_out) 59 | output_dict[each_key] = layer_out 60 | break 61 | return output_dict 62 | 63 | 64 | def get_resnet_all_layers(ending_points, get_all_layers): 65 | ending_dict = {} 66 | get_all_layers = get_all_layers.split(',') 67 | for idx, layer_out in enumerate(ending_points): 68 | if str(idx) in get_all_layers: 69 | feat_size = np.prod(layer_out.get_shape().as_list()[1:]) 70 | if feat_size > 200000: 71 | pool_size = 2 72 | if feat_size / 4 > 200000: 73 | pool_size = 4 74 | layer_out = tf.transpose(layer_out, [0, 2, 3, 1]) 75 | layer_out = tf.nn.avg_pool( 76 | layer_out, 77 | ksize=[1,pool_size,pool_size,1], 78 | strides=[1,pool_size,pool_size,1], 79 | padding='SAME') 80 | layer_out = flatten(layer_out) 81 | ending_dict[str(idx)] = layer_out 82 | return ending_dict 83 | 84 | 85 | def network_embedding( 86 | img_batch, 87 | model_type, 88 | dtype=tf.float32, 89 | data_format=None, train=False, 90 | resnet_version=resnet_model.DEFAULT_VERSION, 91 | get_all_layers=None, 92 | skip_final_dense=False): 93 | image = tf.cast(img_batch, tf.float32) 94 | image = tf.div(image, tf.constant(255, dtype=tf.float32)) 95 | image = tf.map_fn(ColorNormalize, image) 96 | 97 | if model_type.startswith('resnet'): 98 | resnet_size = int(model_type[6:]) 99 | model = resnet_model.ImagenetModel( 100 | resnet_size, data_format, 101 | resnet_version=resnet_version, 102 | dtype=dtype) 103 | 104 | if skip_final_dense and get_all_layers is None: 105 | return model(image, train, skip_final_dense=True) 106 | 107 | if get_all_layers: 108 | _, ending_points = model( 109 | image, train, get_all_layers=get_all_layers) 110 | all_layers = get_resnet_all_layers(ending_points, get_all_layers) 111 | return all_layers 112 | 113 | model_out = model(image, train, skip_final_dense=False) 114 | elif model_type == 'alexnet_bn_no_drop': 115 | model_out, all_layers = alexnet_model.alexnet_v2_with_bn_no_drop( 116 | image, is_training=train, 117 | num_classes=128) 118 | if get_all_layers or skip_final_dense: 119 | all_layers = get_alexnet_all_layers(all_layers, get_all_layers) 120 | if 
get_all_layers: 121 | return all_layers 122 | else: 123 | return all_layers['fc6'] 124 | elif model_type == 'vggnet_fx': 125 | model_out, ending_points = vggnet_model.vgg_16( 126 | image, is_training=train, 127 | num_classes=128, with_bn=True, 128 | fix_bug=True, 129 | dropout_keep_prob=0) 130 | if get_all_layers: 131 | all_layers = get_resnet_all_layers(ending_points, get_all_layers) 132 | return all_layers 133 | else: 134 | raise ValueError('Model type not supported!') 135 | 136 | return tf.nn.l2_normalize(model_out, axis=1) # [bs, out_dim] 137 | 138 | 139 | def repeat_1d_tensor(t, num_reps): 140 | ret = tf.tile(tf.expand_dims(t, axis=1), (1, num_reps)) 141 | return ret 142 | 143 | 144 | class LossBuilder(object): 145 | def __init__(self, 146 | inputs, output, 147 | memory_bank, 148 | instance_k=4096, 149 | instance_t=0.07, 150 | instance_m=0.5, 151 | **kwargs): 152 | self.inputs = inputs 153 | self.embed_output = output 154 | self.batch_size, self.out_dim = self.embed_output.get_shape().as_list() 155 | self.memory_bank = memory_bank 156 | 157 | self.data_len = memory_bank.size 158 | self.instance_k = instance_k 159 | self.instance_t = instance_t 160 | self.instance_m = instance_m 161 | 162 | def _softmax(self, dot_prods): 163 | instance_Z = tf.constant( 164 | 2876934.2 / 1281167 * self.data_len, 165 | dtype=tf.float32) 166 | return tf.exp(dot_prods / self.instance_t) / instance_Z 167 | 168 | def compute_data_prob(self): 169 | logits = self.memory_bank.get_dot_products( 170 | self.embed_output, 171 | self.inputs['index']) 172 | return self._softmax(logits) 173 | 174 | def compute_noise_prob(self): 175 | noise_indx = tf.random_uniform( 176 | shape=(self.batch_size, self.instance_k), 177 | minval=0, 178 | maxval=self.data_len, 179 | dtype=tf.int64) 180 | noise_probs = self._softmax( 181 | self.memory_bank.get_dot_products(self.embed_output, noise_indx)) 182 | return noise_probs 183 | 184 | def updated_new_data_memory(self): 185 | data_indx = self.inputs['index'] # [bs] 186 | data_memory = self.memory_bank.at_idxs(data_indx) 187 | new_data_memory = (data_memory * self.instance_m 188 | + (1 - self.instance_m) * self.embed_output) 189 | return tf.nn.l2_normalize(new_data_memory, axis=1) 190 | 191 | def __get_close_nei_in_back( 192 | self, each_k_idx, cluster_labels, 193 | back_nei_idxs, k): 194 | batch_labels = tf.gather( 195 | cluster_labels[each_k_idx], 196 | self.inputs['index']) 197 | 198 | top_cluster_labels = tf.gather( 199 | cluster_labels[each_k_idx], back_nei_idxs) 200 | batch_labels = repeat_1d_tensor(batch_labels, k) 201 | curr_close_nei = tf.equal(batch_labels, top_cluster_labels) 202 | return curr_close_nei 203 | 204 | def __get_relative_prob(self, all_close_nei, back_nei_probs): 205 | relative_probs = tf.reduce_sum( 206 | tf.where( 207 | all_close_nei, 208 | x=back_nei_probs, y=tf.zeros_like(back_nei_probs), 209 | ), axis=1) 210 | relative_probs /= tf.reduce_sum(back_nei_probs, axis=1) 211 | return relative_probs 212 | 213 | def get_LA_loss( 214 | self, cluster_labels, 215 | k=None): 216 | if not k: 217 | k = self.instance_k 218 | # use the top k nearest examples as background neighbors 219 | all_dps = self.memory_bank.get_all_dot_products(self.embed_output) 220 | back_nei_dps, back_nei_idxs = tf.nn.top_k(all_dps, k=k, sorted=False) 221 | back_nei_probs = self._softmax(back_nei_dps) 222 | 223 | no_kmeans = cluster_labels.get_shape().as_list()[0] 224 | all_close_nei = None 225 | for each_k_idx in range(no_kmeans): 226 | curr_close_nei = self.__get_close_nei_in_back( 227 | 
each_k_idx, cluster_labels, back_nei_idxs, k) 228 | 229 | if all_close_nei is None: 230 | all_close_nei = curr_close_nei 231 | else: 232 | all_close_nei = tf.logical_or(all_close_nei, curr_close_nei) 233 | relative_probs = self.__get_relative_prob( 234 | all_close_nei, back_nei_probs) 235 | 236 | assert_shape(relative_probs, [self.batch_size]) 237 | loss = -tf.reduce_mean(tf.log(relative_probs + 1e-7)) 238 | return loss 239 | 240 | def get_IR_losses(self): 241 | data_prob = self.compute_data_prob() 242 | noise_prob = self.compute_noise_prob() 243 | assert_shape(data_prob, [self.batch_size]) 244 | assert_shape(noise_prob, [self.batch_size, self.instance_k]) 245 | 246 | base_prob = 1.0 / self.data_len 247 | eps = 1e-7 248 | ## Pmt 249 | data_div = data_prob + (self.instance_k*base_prob + eps) 250 | ln_data = tf.log(data_prob / data_div) 251 | ## Pon 252 | noise_div = noise_prob + (self.instance_k*base_prob + eps) 253 | ln_noise = tf.log((self.instance_k*base_prob) / noise_div) 254 | 255 | curr_loss = -(tf.reduce_sum(ln_data) \ 256 | + tf.reduce_sum(ln_noise)) / self.batch_size 257 | return curr_loss, \ 258 | -tf.reduce_sum(ln_data)/self.batch_size, \ 259 | -tf.reduce_sum(ln_noise)/self.batch_size 260 | 261 | 262 | def build_targets( 263 | inputs, train, 264 | model_type, 265 | kmeans_k, 266 | task, 267 | **kwargs): 268 | # This will be stored in the db 269 | logged_cfg = {'kwargs': kwargs} 270 | 271 | data_len = kwargs.get('data_len', DATA_LEN_IMAGENET_FULL) 272 | with tf.variable_scope('instance', reuse=tf.AUTO_REUSE): 273 | all_labels = tf.get_variable( 274 | 'all_labels', 275 | initializer=tf.zeros_initializer, 276 | shape=(data_len,), 277 | trainable=False, 278 | dtype=tf.int64, 279 | ) 280 | # TODO: hard-coded output dimension 128 281 | memory_bank = MemoryBank(data_len, 128) 282 | 283 | lbl_init_values = tf.range(data_len, dtype=tf.int64) 284 | no_kmeans_k = len(kmeans_k) 285 | lbl_init_values = tf.tile( 286 | tf.expand_dims(lbl_init_values, axis=0), 287 | [no_kmeans_k, 1]) 288 | cluster_labels = tf.get_variable( 289 | 'cluster_labels', 290 | initializer=lbl_init_values, 291 | trainable=False, dtype=tf.int64, 292 | ) 293 | 294 | output = network_embedding( 295 | inputs['image'], train=train, 296 | model_type=model_type) 297 | 298 | if not train: 299 | all_dist = memory_bank.get_all_dot_products(output) 300 | return [all_dist, all_labels], logged_cfg 301 | 302 | loss_builder = LossBuilder( 303 | inputs=inputs, output=output, 304 | memory_bank=memory_bank, 305 | **kwargs) 306 | 307 | if task == 'IR': 308 | loss, loss_model, loss_noise = loss_builder.get_IR_losses() 309 | clustering = None 310 | ret_loss = [loss, loss_model, loss_noise] 311 | elif task == 'LA': 312 | from .cluster_km import Kmeans 313 | clustering = Kmeans(kmeans_k, memory_bank, cluster_labels) 314 | loss = loss_builder.get_LA_loss(cluster_labels) 315 | ret_loss = [loss] 316 | else: 317 | raise NotImplementedError('Task not supported!') 318 | 319 | new_data_memory = loss_builder.updated_new_data_memory() 320 | return { 321 | "losses": ret_loss, 322 | "data_indx": inputs['index'], 323 | "memory_bank": memory_bank.as_tensor(), 324 | "new_data_memory": new_data_memory, 325 | "all_labels": all_labels, 326 | }, logged_cfg, clustering 327 | 328 | 329 | def build_transfer_targets( 330 | inputs, train, 331 | model_type='resnet18', 332 | get_all_layers=None, 333 | num_classes=1000, 334 | **kwargs): 335 | # This will be stored in the db 336 | logged_cfg = {'kwargs': kwargs} 337 | input_image = inputs['image'] 338 | num_crop = None 
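    # Transfer inputs may carry an extra crop axis (e.g. from the 10-crop
    # validation preprocessing), giving images of shape [bs, num_crop, H, W, C].
    # In that case the crops are folded into the batch dimension just below so
    # the frozen backbone sees a plain [bs * num_crop, H, W, C] batch; the
    # per-crop softmax outputs are averaged back together in __get_loss_accuracy.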
339 | curr_image_shape = input_image.get_shape().as_list() 340 | batch_size = curr_image_shape[0] 341 | if len(curr_image_shape) > 4: 342 | num_crop = curr_image_shape[1] 343 | input_image = tf.reshape(input_image, [-1] + curr_image_shape[2:]) 344 | 345 | resnet_output = network_embedding( 346 | input_image, 347 | train=False, 348 | model_type=model_type, 349 | skip_final_dense=True, 350 | get_all_layers=get_all_layers) 351 | 352 | with tf.variable_scope('instance', reuse=tf.AUTO_REUSE): 353 | init_builder = tf.contrib.layers.variance_scaling_initializer() 354 | if not get_all_layers: 355 | class_output = tf.layers.dense( 356 | inputs=resnet_output, units=num_classes, 357 | kernel_initializer=init_builder, 358 | trainable=True, 359 | name='transfer_dense') 360 | else: 361 | class_output = OrderedDict() 362 | for key, curr_out in resnet_output.items(): 363 | class_output[key] = tf.layers.dense( 364 | inputs=curr_out, units=num_classes, 365 | kernel_initializer=init_builder, 366 | trainable=True, 367 | name='transfer_dense_{name}'.format(name=key)) 368 | 369 | def __get_loss_accuracy(curr_output): 370 | if num_crop: 371 | curr_output = tf.nn.softmax(curr_output) 372 | curr_output = tf.reshape(curr_output, [batch_size, num_crop, -1]) 373 | curr_output = tf.reduce_mean(curr_output, axis=1) 374 | _, pred = tf.nn.top_k(curr_output, k=1) 375 | pred = tf.cast(tf.squeeze(pred), tf.int64) 376 | accuracy = tf.reduce_mean( 377 | tf.cast(tf.equal(pred, inputs['label']), tf.float32) 378 | ) 379 | 380 | one_hot_labels = tf.one_hot(inputs['label'], num_classes) 381 | loss = tf.losses.softmax_cross_entropy(one_hot_labels, curr_output) 382 | return loss, accuracy 383 | if not get_all_layers: 384 | loss, accuracy = __get_loss_accuracy(class_output) 385 | else: 386 | loss = [] 387 | accuracy = OrderedDict() 388 | for key, curr_out in class_output.items(): 389 | curr_loss, curr_acc = __get_loss_accuracy(curr_out) 390 | loss.append(curr_loss) 391 | accuracy[key] = curr_acc 392 | loss = tf.reduce_sum(loss) 393 | 394 | if not train: 395 | return accuracy, logged_cfg 396 | return [loss, accuracy], logged_cfg 397 | -------------------------------------------------------------------------------- /model/memory_bank.py: -------------------------------------------------------------------------------- 1 | from __future__ import division, print_function, absolute_import 2 | import numpy as np 3 | import tensorflow as tf 4 | 5 | 6 | class MemoryBank(object): 7 | def __init__(self, size, dim, seed=None): 8 | self.size = size 9 | self.dim = dim 10 | self.seed = seed or 0 11 | self._bank = self._create() 12 | 13 | def _create(self): 14 | mb_init = tf.random_uniform( 15 | shape=(self.size, self.dim), 16 | seed=self.seed, 17 | ) 18 | std_dev = 1. / np.sqrt(self.dim/3) 19 | mb_init = mb_init * (2*std_dev) - std_dev 20 | return tf.get_variable( 21 | 'memory_bank', 22 | initializer=mb_init, 23 | dtype=tf.float32, 24 | trainable=False, 25 | ) 26 | 27 | def as_tensor(self): 28 | return self._bank 29 | 30 | def at_idxs(self, idxs): 31 | return tf.gather(self._bank, idxs, axis=0) 32 | 33 | def get_all_dot_products(self, vec): 34 | vec_shape = vec.get_shape().as_list() 35 | # [bs, dim] 36 | assert len(vec_shape) == 2 37 | return tf.matmul(vec, tf.transpose(self._bank, [1, 0])) 38 | 39 | def get_dot_products(self, vec, idxs): 40 | vec_shape = vec.get_shape().as_list() 41 | # [bs, dim] 42 | idxs_shape = idxs.get_shape().as_list() 43 | # [bs, ...] 
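        # Worked shape example: with vec of shape [bs, dim] and idxs of shape
        # [bs, k], the tf.gather below yields memory_vecs of shape [bs, k, dim];
        # vec is then reshaped to [bs, 1, dim] so the elementwise multiply
        # broadcasts, and the final reduce_sum over the last axis returns one
        # dot product per (example, sampled index). For instance, vec = [[1., 0.]]
        # and idxs = [[0, 2]] against the bank [[1, 0], [0, 1], [0.6, 0.8]]
        # give [[1.0, 0.6]].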
44 | assert len(vec_shape) == 2 45 | assert vec_shape[0] == idxs_shape[0] 46 | 47 | memory_vecs = tf.gather(self._bank, idxs, axis=0) 48 | memory_vecs_shape = memory_vecs.get_shape().as_list() 49 | # [bs, ..., dim] 50 | assert memory_vecs_shape[:-1] == idxs_shape 51 | 52 | vec_shape[1:1] = [1] * (len(idxs_shape) - 1) 53 | vec = tf.reshape(vec, vec_shape) 54 | # [bs, 1,...,1, dim] 55 | 56 | prods = tf.multiply(memory_vecs, vec) 57 | assert prods.get_shape().as_list() == memory_vecs_shape 58 | return tf.reduce_sum(prods, axis=-1) 59 | -------------------------------------------------------------------------------- /model/prep_utils.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import tensorflow as tf 6 | 7 | import os, sys 8 | import numpy as np 9 | import pdb 10 | 11 | EPS = 1e-6 12 | 13 | 14 | def _at_least_x_are_true(a, b, x): 15 | """At least `x` of `a` and `b` `Tensors` are true.""" 16 | match = tf.equal(a, b) 17 | match = tf.cast(match, tf.int32) 18 | return tf.greater_equal(tf.reduce_sum(match), x) 19 | 20 | 21 | def image_resize( 22 | crop_image, out_height, out_width, 23 | ): 24 | resize_func = tf.image.resize_area 25 | image = tf.cast( 26 | resize_func( 27 | [crop_image], 28 | [out_height, out_width])[0], 29 | dtype=tf.uint8) 30 | return image 31 | 32 | 33 | def RandomSizedCrop_from_jpeg( 34 | image_str, 35 | out_height, 36 | out_width, 37 | size_minval=0.08, 38 | ): 39 | shape = tf.image.extract_jpeg_shape(image_str) 40 | bbox = tf.constant([0.0, 0.0, 1.0, 1.0], dtype=tf.float32, shape=[1, 1, 4]) 41 | crop_max_attempts = 100 42 | sample_distorted_bounding_box = tf.image.sample_distorted_bounding_box( 43 | shape, 44 | bounding_boxes=bbox, 45 | min_object_covered=0.1, 46 | aspect_ratio_range=(3. / 4, 4. 
/ 3.), 47 | area_range=(size_minval, 1.0), 48 | max_attempts=crop_max_attempts, 49 | use_image_if_no_bounding_boxes=True) 50 | bbox_begin, bbox_size, bbox = sample_distorted_bounding_box 51 | random_image = tf.image.decode_and_crop_jpeg( 52 | image_str, 53 | tf.stack([bbox_begin[0], bbox_begin[1], \ 54 | bbox_size[0], bbox_size[1]]), 55 | channels=3) 56 | bad = _at_least_x_are_true(shape, tf.shape(random_image), 3) 57 | # central crop if bad 58 | min_size = tf.minimum(shape[0], shape[1]) 59 | offset_height = tf.random_uniform( 60 | shape=[], 61 | minval=0, maxval=shape[0] - min_size + 1, 62 | dtype=tf.int32 63 | ) 64 | offset_width = tf.random_uniform( 65 | shape=[], 66 | minval=0, maxval=shape[1] - min_size + 1, 67 | dtype=tf.int32 68 | ) 69 | bad_image = tf.image.decode_and_crop_jpeg( 70 | image_str, 71 | tf.stack([offset_height, offset_width, \ 72 | min_size, min_size]), 73 | channels=3) 74 | image = tf.cond( 75 | bad, 76 | lambda: bad_image, 77 | lambda: random_image, 78 | ) 79 | # if use py_func, will do resize elsewhere 80 | image = image_resize( 81 | image, 82 | out_height, out_width, 83 | ) 84 | image.set_shape([out_height, out_width, 3]) 85 | return image 86 | 87 | 88 | def RandomBrightness(image, low, high): 89 | rnd_bright = tf.random_uniform( 90 | shape=[], 91 | minval=low, maxval=high, 92 | dtype=tf.float32) 93 | #rnd_bright = tf.Print(rnd_bright, [rnd_bright], message='Brigh') 94 | flt_image = tf.cast(image, tf.float32) 95 | blend_image = flt_image * rnd_bright 96 | blend_image = tf.maximum(blend_image, 0) 97 | blend_image = tf.minimum(blend_image, 255) 98 | image_after = tf.cast(blend_image + EPS, tf.uint8) 99 | return image_after 100 | 101 | 102 | def RGBtoGray(flt_image): 103 | flt_image = tf.cast(flt_image, tf.float32) 104 | gry_image = flt_image[:,:,0] * 0.299 \ 105 | + flt_image[:,:,1] * 0.587 \ 106 | + flt_image[:,:,2] * 0.114 107 | gry_image = tf.expand_dims(gry_image, axis=2) 108 | gry_image = tf.cast(gry_image + EPS, tf.uint8) 109 | gry_image = tf.cast(gry_image, tf.float32) 110 | return gry_image 111 | 112 | 113 | def RandomSaturation(image, low, high): 114 | rnd_saturt = tf.random_uniform( 115 | shape=[], 116 | minval=low, maxval=high, 117 | dtype=tf.float32) 118 | #rnd_saturt = tf.Print(rnd_saturt, [rnd_saturt], message='Satu') 119 | flt_image = tf.cast(image, tf.float32) 120 | gry_image = RGBtoGray(flt_image) 121 | blend_image = flt_image * rnd_saturt + gry_image * (1-rnd_saturt) 122 | blend_image = tf.maximum(blend_image, 0) 123 | blend_image = tf.minimum(blend_image, 255) 124 | image_after = tf.cast(blend_image + EPS, tf.uint8) 125 | return image_after 126 | 127 | 128 | def RandomContrast(image, low, high): 129 | rnd_contr = tf.random_uniform( 130 | shape=[], 131 | minval=low, maxval=high, 132 | dtype=tf.float32) 133 | #rnd_contr = tf.Print(rnd_contr, [rnd_contr], message='Contr') 134 | flt_image = tf.cast(image, tf.float32) 135 | mean_gray = tf.cast( 136 | tf.cast( 137 | tf.reduce_mean(RGBtoGray(flt_image)) + EPS, 138 | tf.uint8), 139 | tf.float32) 140 | blend_image = flt_image * rnd_contr + mean_gray * (1-rnd_contr) 141 | blend_image = tf.maximum(blend_image, 0) 142 | blend_image = tf.minimum(blend_image, 255) 143 | image_after = tf.cast(blend_image + EPS, tf.uint8) 144 | return image_after 145 | 146 | 147 | def ColorJitter(image, seed_random=0, 148 | as_batch=False, shape_undefined=1, 149 | ): 150 | order_temp = tf.constant([0,1,2,3], dtype=tf.int32) 151 | order_rand = tf.random_shuffle(order_temp) 152 | #order_rand = tf.Print(order_rand, [order_rand], 
message='Order') 153 | 154 | random_hue_func = tf.image.random_hue 155 | 156 | fn_pred_fn_pairs = lambda x, image: [ 157 | (tf.equal(x, order_temp[0]), \ 158 | lambda :RandomSaturation(image, 0.6, 1.4)), 159 | (tf.equal(x, order_temp[1]), \ 160 | lambda :RandomBrightness(image, 0.6, 1.4)), 161 | (tf.equal(x, order_temp[2]), \ 162 | lambda :random_hue_func(image, 0.4)), 163 | ] 164 | #default_fn = lambda image: tf.image.random_contrast(image, 0.6, 1.4) 165 | default_fn = lambda image: RandomContrast(image, 0.6, 1.4) 166 | 167 | def _color_jitter_one(_norm): 168 | orig_shape = tf.shape(_norm) 169 | for curr_idx in range(order_temp.get_shape().as_list()[0]): 170 | _norm = tf.case( 171 | fn_pred_fn_pairs(order_rand[curr_idx], _norm), 172 | default=lambda : default_fn(_norm)) 173 | if shape_undefined==0: 174 | _norm.set_shape(orig_shape) 175 | return _norm 176 | if as_batch: 177 | image = tf.map_fn(_color_jitter_one, image) 178 | else: 179 | image = _color_jitter_one(image) 180 | return image 181 | 182 | 183 | def ColorNormalize(image): 184 | transpose_flag = image.get_shape().as_list()[-1] != 3 185 | if transpose_flag: 186 | image = tf.transpose(image, [1,2,0]) 187 | imagenet_mean = np.array([0.485, 0.456, 0.406], dtype=np.float32) 188 | imagenet_std = np.array([0.229, 0.224, 0.225], dtype=np.float32) 189 | image = (image - imagenet_mean) / imagenet_std 190 | if transpose_flag: 191 | image = tf.transpose(image, [2,0,1]) 192 | 193 | return image 194 | 195 | 196 | def ApplyGray(norm, prob_gray, as_batch=False): 197 | def _postprocess_gray(im): 198 | do_gray = tf.random_uniform( 199 | shape=[], 200 | minval=0, 201 | maxval=1, 202 | dtype=tf.float32) 203 | def __gray(im): 204 | gray_im = tf.cast(RGBtoGray(im), tf.uint8) 205 | gray_im = tf.tile(gray_im, [1,1,3]) 206 | return gray_im 207 | return tf.cond( 208 | tf.less(do_gray, prob_gray), 209 | lambda: __gray(im), 210 | lambda: im) 211 | if as_batch: 212 | norm = tf.map_fn(_postprocess_gray, norm, dtype=norm.dtype) 213 | else: 214 | norm = _postprocess_gray(norm) 215 | return norm 216 | 217 | 218 | def get_resize_scale(height, width, smallest_side): 219 | smallest_side = tf.convert_to_tensor(smallest_side, dtype=tf.int32) 220 | 221 | height = tf.to_float(height) 222 | width = tf.to_float(width) 223 | smallest_side = tf.to_float(smallest_side) 224 | 225 | scale = tf.cond( 226 | tf.greater(height, width), 227 | lambda: smallest_side / width, 228 | lambda: smallest_side / height) 229 | return scale 230 | 231 | 232 | def alexnet_crop_from_jpg(image_string): 233 | """ 234 | Resize the image to make its smallest side to be 256; 235 | then randomly get a 224 crop 236 | """ 237 | crop_size = 224 238 | shape = tf.image.extract_jpeg_shape(image_string) 239 | scale = get_resize_scale(shape[0], shape[1], 256) 240 | cp_height = tf.cast(crop_size / scale, tf.int32) 241 | cp_width = tf.cast(crop_size / scale, tf.int32) 242 | 243 | # Randomly sample begin x and y 244 | # Original AlexNet preprocessing uses center 256*256 to crop 245 | min_shape = tf.minimum(shape[0], shape[1]) 246 | x_range = [ 247 | tf.cast((shape[0] - min_shape) / 2, tf.int32), 248 | shape[0] - cp_height + 1 - \ 249 | tf.cast( 250 | (shape[0] - min_shape) / 2, 251 | tf.int32), 252 | ] 253 | y_range = [ 254 | tf.cast((shape[1] - min_shape) / 2, tf.int32), 255 | shape[1] - cp_width + 1 - \ 256 | tf.cast( 257 | (shape[1] - min_shape) / 2, 258 | tf.int32), 259 | ] 260 | 261 | cp_begin_x = tf.random_uniform( 262 | shape=[], 263 | minval=x_range[0], maxval=x_range[1], 264 | dtype=tf.int32 265 | ) 266 
| cp_begin_y = tf.random_uniform( 267 | shape=[], 268 | minval=y_range[0], maxval=y_range[1], 269 | dtype=tf.int32 270 | ) 271 | 272 | bbox = tf.stack([ 273 | cp_begin_x, cp_begin_y, \ 274 | cp_height, cp_width]) 275 | crop_image = tf.image.decode_and_crop_jpeg( 276 | image_string, 277 | bbox, 278 | channels=3) 279 | image = image_resize(crop_image, crop_size, crop_size) 280 | return image 281 | 282 | 283 | def jpeg_crop_at_xy(image_string, bbox, out_height, out_width): 284 | crop_image = tf.image.decode_and_crop_jpeg( 285 | image_string, 286 | bbox, 287 | channels=3) 288 | image = image_resize( 289 | crop_image, 290 | out_height, out_width, 291 | ) 292 | image.set_shape([out_height, out_width, 3]) 293 | return image 294 | 295 | 296 | def prep_10crops_validate( 297 | image_string, 298 | out_height, 299 | out_width, 300 | ): 301 | shape = tf.image.extract_jpeg_shape(image_string) 302 | scale = get_resize_scale(shape[0], shape[1], 256) 303 | cp_height = tf.cast(out_height / scale, tf.int32) 304 | cp_width = tf.cast(out_width / scale, tf.int32) 305 | 306 | all_images = [] 307 | # center crop 308 | cp_begin_x = tf.cast((shape[0] - cp_height) / 2, tf.int32) 309 | cp_begin_y = tf.cast((shape[1] - cp_width) / 2, tf.int32) 310 | bbox = tf.stack([ 311 | cp_begin_x, cp_begin_y, \ 312 | cp_height, cp_width]) 313 | image = jpeg_crop_at_xy(image_string, bbox, out_height, out_width) 314 | all_images.append(image) 315 | 316 | # Conrners 317 | for x_pos in ['up', 'down']: 318 | for y_pos in ['left', 'right']: 319 | if x_pos == 'up': 320 | cp_begin_x = tf.cast(0, tf.int32) 321 | else: 322 | cp_begin_x = tf.cast(shape[0] - cp_height, tf.int32) 323 | 324 | if y_pos == 'left': 325 | cp_begin_y = tf.cast(0, tf.int32) 326 | else: 327 | cp_begin_y = tf.cast(shape[1] - cp_width, tf.int32) 328 | 329 | bbox = tf.stack([ 330 | cp_begin_x, cp_begin_y, \ 331 | cp_height, cp_width]) 332 | image = jpeg_crop_at_xy(image_string, bbox, out_height, out_width) 333 | all_images.append(image) 334 | 335 | flipped_images = [] 336 | for each_image in all_images: 337 | flipped_images.append(tf.image.flip_left_right(each_image)) 338 | all_images.extend(flipped_images) 339 | images = tf.stack(all_images, axis=0) 340 | return images 341 | 342 | 343 | def preprocessing_inst( 344 | image_string, 345 | out_height, 346 | out_width, 347 | is_train, 348 | size_minval=0.2, 349 | ): 350 | def _val_func(image_string): 351 | shape = tf.image.extract_jpeg_shape(image_string) 352 | scale = get_resize_scale(shape[0], shape[1], 256) 353 | cp_height = tf.cast(out_height / scale, tf.int32) 354 | cp_width = tf.cast(out_width / scale, tf.int32) 355 | cp_begin_x = tf.cast((shape[0] - cp_height) / 2, tf.int32) 356 | cp_begin_y = tf.cast((shape[1] - cp_width) / 2, tf.int32) 357 | bbox = tf.stack([ 358 | cp_begin_x, cp_begin_y, \ 359 | cp_height, cp_width]) 360 | image = jpeg_crop_at_xy(image_string, bbox, out_height, out_width) 361 | return image 362 | 363 | def _rand_crop(image_string): 364 | image = RandomSizedCrop_from_jpeg( 365 | image_string, 366 | out_height=out_height, 367 | out_width=out_width, 368 | size_minval=size_minval, 369 | ) 370 | return image 371 | 372 | if is_train: 373 | image = _rand_crop(image_string) 374 | image = ApplyGray(image, 0.2) 375 | image = ColorJitter(image) 376 | image = tf.image.random_flip_left_right(image) 377 | 378 | else: 379 | image = _val_func(image_string) 380 | 381 | return image 382 | -------------------------------------------------------------------------------- /model/preprocessing.py: 
-------------------------------------------------------------------------------- 1 | from __future__ import division, print_function 2 | import os, sys 3 | import numpy as np 4 | import tensorflow as tf 5 | 6 | from .prep_utils import ( 7 | preprocessing_inst, RandomSizedCrop_from_jpeg, 8 | ApplyGray, ColorJitter, alexnet_crop_from_jpg, prep_10crops_validate 9 | ) 10 | 11 | # This file contains various preprocessing ops for images (typically 12 | # used for data augmentation). 13 | 14 | def resnet_train(img_str): 15 | return preprocessing_inst(img_str, 224, 224, is_train=True) 16 | 17 | 18 | def resnet_validate(img_str): 19 | return preprocessing_inst(img_str, 224, 224, is_train=False) 20 | 21 | 22 | def resnet_10crop_validate(img_str): 23 | return prep_10crops_validate(img_str, 224, 224) 24 | 25 | 26 | def resnet_crop_flip(img_str): 27 | img = RandomSizedCrop_from_jpeg( 28 | img_str, out_height=224, out_width=224, size_minval=0.2) 29 | img = tf.image.random_flip_left_right(img) 30 | return img 31 | 32 | 33 | def alexnet_crop_flip(img_str): 34 | img = alexnet_crop_from_jpg(img_str) 35 | img = tf.image.random_flip_left_right(img) 36 | return img 37 | -------------------------------------------------------------------------------- /model/resnet_model.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Contains definitions for Residual Networks. 16 | 17 | Residual networks ('v1' ResNets) were originally proposed in: 18 | [1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun 19 | Deep Residual Learning for Image Recognition. arXiv:1512.03385 20 | 21 | The full preactivation 'v2' ResNet variant was introduced by: 22 | [2] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun 23 | Identity Mappings in Deep Residual Networks. arXiv: 1603.05027 24 | 25 | The key difference of the full preactivation 'v2' variant compared to the 26 | 'v1' variant in [1] is the use of batch normalization before every weight layer 27 | rather than after. 28 | """ 29 | 30 | from __future__ import absolute_import 31 | from __future__ import division 32 | from __future__ import print_function 33 | 34 | import tensorflow as tf 35 | 36 | _BATCH_NORM_DECAY = 0.997 37 | _BATCH_NORM_EPSILON = 1e-5 38 | DEFAULT_VERSION = 2 39 | DEFAULT_DTYPE = tf.float32 40 | CASTABLE_TYPES = (tf.float16,) 41 | ALLOWED_TYPES = (DEFAULT_DTYPE,) + CASTABLE_TYPES 42 | 43 | ENDING_POINTS = [] 44 | 45 | 46 | ################################################################################ 47 | # Convenience functions for building the ResNet model. 
48 | ################################################################################ 49 | def batch_norm(inputs, training, data_format): 50 | """Performs a batch normalization using a standard set of parameters.""" 51 | # We set fused=True for a significant performance boost. See 52 | # https://www.tensorflow.org/performance/performance_guide#common_fused_ops 53 | return tf.layers.batch_normalization( 54 | inputs=inputs, axis=1 if data_format == 'channels_first' else 3, 55 | momentum=_BATCH_NORM_DECAY, epsilon=_BATCH_NORM_EPSILON, center=True, 56 | scale=True, training=training, fused=True) 57 | 58 | 59 | def fixed_padding(inputs, kernel_size, data_format): 60 | """Pads the input along the spatial dimensions independently of input size. 61 | 62 | Args: 63 | inputs: A tensor of size [batch, channels, height_in, width_in] or 64 | [batch, height_in, width_in, channels] depending on data_format. 65 | kernel_size: The kernel to be used in the conv2d or max_pool2d operation. 66 | Should be a positive integer. 67 | data_format: The input format ('channels_last' or 'channels_first'). 68 | 69 | Returns: 70 | A tensor with the same format as the input with the data either intact 71 | (if kernel_size == 1) or padded (if kernel_size > 1). 72 | """ 73 | pad_total = kernel_size - 1 74 | pad_beg = pad_total // 2 75 | pad_end = pad_total - pad_beg 76 | 77 | if data_format == 'channels_first': 78 | padded_inputs = tf.pad(inputs, [[0, 0], [0, 0], 79 | [pad_beg, pad_end], [pad_beg, pad_end]]) 80 | else: 81 | padded_inputs = tf.pad(inputs, [[0, 0], [pad_beg, pad_end], 82 | [pad_beg, pad_end], [0, 0]]) 83 | return padded_inputs 84 | 85 | 86 | def conv2d_fixed_padding(inputs, filters, kernel_size, strides, data_format): 87 | """Strided 2-D convolution with explicit padding.""" 88 | # The padding is consistent and is based only on `kernel_size`, not on the 89 | # dimensions of `inputs` (as opposed to using `tf.layers.conv2d` alone). 90 | if strides > 1: 91 | inputs = fixed_padding(inputs, kernel_size, data_format) 92 | 93 | return tf.layers.conv2d( 94 | inputs=inputs, filters=filters, kernel_size=kernel_size, strides=strides, 95 | padding=('SAME' if strides == 1 else 'VALID'), use_bias=False, 96 | kernel_initializer=tf.variance_scaling_initializer(), 97 | data_format=data_format) 98 | 99 | 100 | ################################################################################ 101 | # ResNet block definitions. 102 | ################################################################################ 103 | def _building_block_v1(inputs, filters, training, projection_shortcut, strides, 104 | data_format): 105 | """A single block for ResNet v1, without a bottleneck. 106 | 107 | Convolution then batch normalization then ReLU as described by: 108 | Deep Residual Learning for Image Recognition 109 | https://arxiv.org/pdf/1512.03385.pdf 110 | by Kaiming He, Xiangyu Zhang, Shaoqing Ren, and Jian Sun, Dec 2015. 111 | 112 | Args: 113 | inputs: A tensor of size [batch, channels, height_in, width_in] or 114 | [batch, height_in, width_in, channels] depending on data_format. 115 | filters: The number of filters for the convolutions. 116 | training: A Boolean for whether the model is in training or inference 117 | mode. Needed for batch normalization. 118 | projection_shortcut: The function to use for projection shortcuts 119 | (typically a 1x1 convolution when downsampling the input). 120 | strides: The block's stride. If greater than 1, this block will ultimately 121 | downsample the input. 
122 | data_format: The input format ('channels_last' or 'channels_first'). 123 | 124 | Returns: 125 | The output tensor of the block; shape should match inputs. 126 | """ 127 | shortcut = inputs 128 | 129 | if projection_shortcut is not None: 130 | shortcut = projection_shortcut(inputs) 131 | shortcut = batch_norm(inputs=shortcut, training=training, 132 | data_format=data_format) 133 | 134 | inputs = conv2d_fixed_padding( 135 | inputs=inputs, filters=filters, kernel_size=3, strides=strides, 136 | data_format=data_format) 137 | inputs = batch_norm(inputs, training, data_format) 138 | inputs = tf.nn.relu(inputs) 139 | 140 | inputs = conv2d_fixed_padding( 141 | inputs=inputs, filters=filters, kernel_size=3, strides=1, 142 | data_format=data_format) 143 | inputs = batch_norm(inputs, training, data_format) 144 | inputs += shortcut 145 | inputs = tf.nn.relu(inputs) 146 | 147 | return inputs 148 | 149 | 150 | def _building_block_v2(inputs, filters, training, projection_shortcut, strides, 151 | data_format): 152 | """A single block for ResNet v2, without a bottleneck. 153 | 154 | Batch normalization then ReLu then convolution as described by: 155 | Identity Mappings in Deep Residual Networks 156 | https://arxiv.org/pdf/1603.05027.pdf 157 | by Kaiming He, Xiangyu Zhang, Shaoqing Ren, and Jian Sun, Jul 2016. 158 | 159 | Args: 160 | inputs: A tensor of size [batch, channels, height_in, width_in] or 161 | [batch, height_in, width_in, channels] depending on data_format. 162 | filters: The number of filters for the convolutions. 163 | training: A Boolean for whether the model is in training or inference 164 | mode. Needed for batch normalization. 165 | projection_shortcut: The function to use for projection shortcuts 166 | (typically a 1x1 convolution when downsampling the input). 167 | strides: The block's stride. If greater than 1, this block will ultimately 168 | downsample the input. 169 | data_format: The input format ('channels_last' or 'channels_first'). 170 | 171 | Returns: 172 | The output tensor of the block; shape should match inputs. 173 | """ 174 | shortcut = inputs 175 | inputs = batch_norm(inputs, training, data_format) 176 | inputs = tf.nn.relu(inputs) 177 | ENDING_POINTS.append(inputs) 178 | 179 | # The projection shortcut should come after the first batch norm and ReLU 180 | # since it performs a 1x1 convolution. 181 | if projection_shortcut is not None: 182 | shortcut = projection_shortcut(inputs) 183 | 184 | inputs = conv2d_fixed_padding( 185 | inputs=inputs, filters=filters, kernel_size=3, strides=strides, 186 | data_format=data_format) 187 | 188 | inputs = batch_norm(inputs, training, data_format) 189 | inputs = tf.nn.relu(inputs) 190 | inputs = conv2d_fixed_padding( 191 | inputs=inputs, filters=filters, kernel_size=3, strides=1, 192 | data_format=data_format) 193 | 194 | return inputs + shortcut 195 | 196 | 197 | def _bottleneck_block_v1(inputs, filters, training, projection_shortcut, 198 | strides, data_format): 199 | """A single block for ResNet v1, with a bottleneck. 200 | 201 | Similar to _building_block_v1(), except using the "bottleneck" blocks 202 | described in: 203 | Convolution then batch normalization then ReLU as described by: 204 | Deep Residual Learning for Image Recognition 205 | https://arxiv.org/pdf/1512.03385.pdf 206 | by Kaiming He, Xiangyu Zhang, Shaoqing Ren, and Jian Sun, Dec 2015. 207 | 208 | Args: 209 | inputs: A tensor of size [batch, channels, height_in, width_in] or 210 | [batch, height_in, width_in, channels] depending on data_format. 
211 | filters: The number of filters for the convolutions. 212 | training: A Boolean for whether the model is in training or inference 213 | mode. Needed for batch normalization. 214 | projection_shortcut: The function to use for projection shortcuts 215 | (typically a 1x1 convolution when downsampling the input). 216 | strides: The block's stride. If greater than 1, this block will ultimately 217 | downsample the input. 218 | data_format: The input format ('channels_last' or 'channels_first'). 219 | 220 | Returns: 221 | The output tensor of the block; shape should match inputs. 222 | """ 223 | shortcut = inputs 224 | 225 | if projection_shortcut is not None: 226 | shortcut = projection_shortcut(inputs) 227 | shortcut = batch_norm(inputs=shortcut, training=training, 228 | data_format=data_format) 229 | 230 | inputs = conv2d_fixed_padding( 231 | inputs=inputs, filters=filters, kernel_size=1, strides=1, 232 | data_format=data_format) 233 | inputs = batch_norm(inputs, training, data_format) 234 | inputs = tf.nn.relu(inputs) 235 | 236 | inputs = conv2d_fixed_padding( 237 | inputs=inputs, filters=filters, kernel_size=3, strides=strides, 238 | data_format=data_format) 239 | inputs = batch_norm(inputs, training, data_format) 240 | inputs = tf.nn.relu(inputs) 241 | 242 | inputs = conv2d_fixed_padding( 243 | inputs=inputs, filters=4 * filters, kernel_size=1, strides=1, 244 | data_format=data_format) 245 | inputs = batch_norm(inputs, training, data_format) 246 | inputs += shortcut 247 | inputs = tf.nn.relu(inputs) 248 | 249 | return inputs 250 | 251 | 252 | def _bottleneck_block_v2(inputs, filters, training, projection_shortcut, 253 | strides, data_format): 254 | """A single block for ResNet v2, with a bottleneck. 255 | 256 | Similar to _building_block_v2(), except using the "bottleneck" blocks 257 | described in: 258 | Convolution then batch normalization then ReLU as described by: 259 | Deep Residual Learning for Image Recognition 260 | https://arxiv.org/pdf/1512.03385.pdf 261 | by Kaiming He, Xiangyu Zhang, Shaoqing Ren, and Jian Sun, Dec 2015. 262 | 263 | Adapted to the ordering conventions of: 264 | Batch normalization then ReLu then convolution as described by: 265 | Identity Mappings in Deep Residual Networks 266 | https://arxiv.org/pdf/1603.05027.pdf 267 | by Kaiming He, Xiangyu Zhang, Shaoqing Ren, and Jian Sun, Jul 2016. 268 | 269 | Args: 270 | inputs: A tensor of size [batch, channels, height_in, width_in] or 271 | [batch, height_in, width_in, channels] depending on data_format. 272 | filters: The number of filters for the convolutions. 273 | training: A Boolean for whether the model is in training or inference 274 | mode. Needed for batch normalization. 275 | projection_shortcut: The function to use for projection shortcuts 276 | (typically a 1x1 convolution when downsampling the input). 277 | strides: The block's stride. If greater than 1, this block will ultimately 278 | downsample the input. 279 | data_format: The input format ('channels_last' or 'channels_first'). 280 | 281 | Returns: 282 | The output tensor of the block; shape should match inputs. 283 | """ 284 | shortcut = inputs 285 | inputs = batch_norm(inputs, training, data_format) 286 | inputs = tf.nn.relu(inputs) 287 | ENDING_POINTS.append(inputs) 288 | 289 | # The projection shortcut should come after the first batch norm and ReLU 290 | # since it performs a 1x1 convolution. 
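  # The three convolutions below follow the pre-activation bottleneck pattern:
  # a 1x1 conv down to `filters`, a (possibly strided) 3x3 conv, and a final
  # 1x1 conv expanding to 4 * `filters`, with batch norm and ReLU applied
  # before each weight layer.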
291 | if projection_shortcut is not None: 292 | shortcut = projection_shortcut(inputs) 293 | 294 | inputs = conv2d_fixed_padding( 295 | inputs=inputs, filters=filters, kernel_size=1, strides=1, 296 | data_format=data_format) 297 | 298 | inputs = batch_norm(inputs, training, data_format) 299 | inputs = tf.nn.relu(inputs) 300 | inputs = conv2d_fixed_padding( 301 | inputs=inputs, filters=filters, kernel_size=3, strides=strides, 302 | data_format=data_format) 303 | 304 | inputs = batch_norm(inputs, training, data_format) 305 | inputs = tf.nn.relu(inputs) 306 | inputs = conv2d_fixed_padding( 307 | inputs=inputs, filters=4 * filters, kernel_size=1, strides=1, 308 | data_format=data_format) 309 | 310 | return inputs + shortcut 311 | 312 | 313 | def block_layer(inputs, filters, bottleneck, block_fn, blocks, strides, 314 | training, name, data_format): 315 | """Creates one layer of blocks for the ResNet model. 316 | 317 | Args: 318 | inputs: A tensor of size [batch, channels, height_in, width_in] or 319 | [batch, height_in, width_in, channels] depending on data_format. 320 | filters: The number of filters for the first convolution of the layer. 321 | bottleneck: Is the block created a bottleneck block. 322 | block_fn: The block to use within the model, either `building_block` or 323 | `bottleneck_block`. 324 | blocks: The number of blocks contained in the layer. 325 | strides: The stride to use for the first convolution of the layer. If 326 | greater than 1, this layer will ultimately downsample the input. 327 | training: Either True or False, whether we are currently training the 328 | model. Needed for batch norm. 329 | name: A string name for the tensor output of the block layer. 330 | data_format: The input format ('channels_last' or 'channels_first'). 331 | 332 | Returns: 333 | The output tensor of the block layer. 334 | """ 335 | 336 | # Bottleneck blocks end with 4x the number of filters as they start with 337 | filters_out = filters * 4 if bottleneck else filters 338 | 339 | def projection_shortcut(inputs): 340 | return conv2d_fixed_padding( 341 | inputs=inputs, filters=filters_out, kernel_size=1, strides=strides, 342 | data_format=data_format) 343 | 344 | # Only the first block per block_layer uses projection_shortcut and strides 345 | inputs = block_fn(inputs, filters, training, projection_shortcut, strides, 346 | data_format) 347 | 348 | for _ in range(1, blocks): 349 | inputs = block_fn(inputs, filters, training, None, 1, data_format) 350 | 351 | return tf.identity(inputs, name) 352 | 353 | 354 | class Model(object): 355 | """Base class for building the Resnet Model.""" 356 | 357 | def __init__(self, resnet_size, bottleneck, num_classes, num_filters, 358 | kernel_size, 359 | conv_stride, first_pool_size, first_pool_stride, 360 | block_sizes, block_strides, 361 | final_size, resnet_version=DEFAULT_VERSION, data_format=None, 362 | dtype=DEFAULT_DTYPE): 363 | """Creates a model for classifying an image. 364 | 365 | Args: 366 | resnet_size: A single integer for the size of the ResNet model. 367 | bottleneck: Use regular blocks or bottleneck blocks. 368 | num_classes: The number of classes used as labels. 369 | num_filters: The number of filters to use for the first block layer 370 | of the model. This number is then doubled for each subsequent block 371 | layer. 372 | kernel_size: The kernel size to use for convolution. 373 | conv_stride: stride size for the initial convolutional layer 374 | first_pool_size: Pool size to be used for the first pooling layer. 
375 | If none, the first pooling layer is skipped. 376 | first_pool_stride: stride size for the first pooling layer. Not used 377 | if first_pool_size is None. 378 | block_sizes: A list containing n values, where n is the number of sets of 379 | block layers desired. Each value should be the number of blocks in the 380 | i-th set. 381 | block_strides: List of integers representing the desired stride size for 382 | each of the sets of block layers. Should be same length as block_sizes. 383 | final_size: The expected size of the model after the second pooling. 384 | resnet_version: Integer representing which version of the ResNet network 385 | to use. See README for details. Valid values: [1, 2] 386 | data_format: Input format ('channels_last', 'channels_first', or None). 387 | If set to None, the format is dependent on whether a GPU is available. 388 | dtype: The TensorFlow dtype to use for calculations. If not specified 389 | tf.float32 is used. 390 | 391 | Raises: 392 | ValueError: if invalid version is selected. 393 | """ 394 | self.resnet_size = resnet_size 395 | 396 | if not data_format: 397 | data_format = ( 398 | 'channels_first' if tf.test.is_built_with_cuda() else 'channels_last') 399 | 400 | self.resnet_version = resnet_version 401 | if resnet_version not in (1, 2): 402 | raise ValueError( 403 | 'Resnet version should be 1 or 2. See README for citations.') 404 | 405 | self.bottleneck = bottleneck 406 | if bottleneck: 407 | if resnet_version == 1: 408 | self.block_fn = _bottleneck_block_v1 409 | else: 410 | self.block_fn = _bottleneck_block_v2 411 | else: 412 | if resnet_version == 1: 413 | self.block_fn = _building_block_v1 414 | else: 415 | self.block_fn = _building_block_v2 416 | 417 | if dtype not in ALLOWED_TYPES: 418 | raise ValueError('dtype must be one of: {}'.format(ALLOWED_TYPES)) 419 | 420 | self.data_format = data_format 421 | self.num_classes = num_classes 422 | self.num_filters = num_filters 423 | self.kernel_size = kernel_size 424 | self.conv_stride = conv_stride 425 | self.first_pool_size = first_pool_size 426 | self.first_pool_stride = first_pool_stride 427 | self.block_sizes = block_sizes 428 | self.block_strides = block_strides 429 | self.final_size = final_size 430 | self.dtype = dtype 431 | self.pre_activation = resnet_version == 2 432 | 433 | def _custom_dtype_getter(self, getter, name, shape=None, dtype=DEFAULT_DTYPE, 434 | *args, **kwargs): 435 | """Creates variables in fp32, then casts to fp16 if necessary. 436 | 437 | This function is a custom getter. A custom getter is a function with the 438 | same signature as tf.get_variable, except it has an additional getter 439 | parameter. Custom getters can be passed as the `custom_getter` parameter of 440 | tf.variable_scope. Then, tf.get_variable will call the custom getter, 441 | instead of directly getting a variable itself. This can be used to change 442 | the types of variables that are retrieved with tf.get_variable. 443 | The `getter` parameter is the underlying variable getter, that would have 444 | been called if no custom getter was used. Custom getters typically get a 445 | variable with `getter`, then modify it in some way. 446 | 447 | This custom getter will create an fp32 variable. If a low precision 448 | (e.g. float16) variable was requested it will then cast the variable to the 449 | requested dtype. The reason we do not directly create variables in low 450 | precision dtypes is that applying small gradients to such variables may 451 | cause the variable not to change. 
452 | 453 | Args: 454 | getter: The underlying variable getter, that has the same signature as 455 | tf.get_variable and returns a variable. 456 | name: The name of the variable to get. 457 | shape: The shape of the variable to get. 458 | dtype: The dtype of the variable to get. Note that if this is a low 459 | precision dtype, the variable will be created as a tf.float32 variable, 460 | then cast to the appropriate dtype 461 | *args: Additional arguments to pass unmodified to getter. 462 | **kwargs: Additional keyword arguments to pass unmodified to getter. 463 | 464 | Returns: 465 | A variable which is cast to fp16 if necessary. 466 | """ 467 | 468 | if dtype in CASTABLE_TYPES: 469 | var = getter(name, shape, tf.float32, *args, **kwargs) 470 | return tf.cast(var, dtype=dtype, name=name + '_cast') 471 | else: 472 | return getter(name, shape, dtype, *args, **kwargs) 473 | 474 | def _model_variable_scope(self): 475 | """Returns a variable scope that the model should be created under. 476 | 477 | If self.dtype is a castable type, model variable will be created in fp32 478 | then cast to self.dtype before being used. 479 | 480 | Returns: 481 | A variable scope for the model. 482 | """ 483 | 484 | return tf.variable_scope('resnet_model', 485 | custom_getter=self._custom_dtype_getter) 486 | 487 | def __call__( 488 | self, inputs, training, get_all_layers=None, skip_final_dense=False): 489 | """Add operations to classify a batch of input images. 490 | 491 | Args: 492 | inputs: A Tensor representing a batch of input images. 493 | training: A boolean. Set to True to add operations required only when 494 | training the classifier. 495 | 496 | Returns: 497 | A logits Tensor with shape [, self.num_classes]. 498 | """ 499 | global ENDING_POINTS 500 | ENDING_POINTS = [] 501 | 502 | with self._model_variable_scope(): 503 | if self.data_format == 'channels_first': 504 | # Convert the inputs from channels_last (NHWC) to channels_first (NCHW). 505 | # This provides a large performance boost on GPU. See 506 | # https://www.tensorflow.org/performance/performance_guide#data_formats 507 | inputs = tf.transpose(inputs, [0, 3, 1, 2]) 508 | 509 | inputs = conv2d_fixed_padding( 510 | inputs=inputs, filters=self.num_filters, kernel_size=self.kernel_size, 511 | strides=self.conv_stride, data_format=self.data_format) 512 | inputs = tf.identity(inputs, 'initial_conv') 513 | 514 | # We do not include batch normalization or activation functions in V2 515 | # for the initial conv1 because the first ResNet unit will perform these 516 | # for both the shortcut and non-shortcut paths as part of the first 517 | # block's projection. Cf. Appendix of [2]. 
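      # For v1 ResNets the initial conv gets its own batch norm and ReLU here;
      # v2 skips them because every v2 block starts with batch norm and ReLU
      # (see _building_block_v2 and _bottleneck_block_v2 above).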
518 | if self.resnet_version == 1: 519 | inputs = batch_norm(inputs, training, self.data_format) 520 | inputs = tf.nn.relu(inputs) 521 | 522 | if self.first_pool_size: 523 | inputs = tf.layers.max_pooling2d( 524 | inputs=inputs, pool_size=self.first_pool_size, 525 | strides=self.first_pool_stride, padding='SAME', 526 | data_format=self.data_format) 527 | inputs = tf.identity(inputs, 'initial_max_pool') 528 | ENDING_POINTS.append(inputs) 529 | 530 | for i, num_blocks in enumerate(self.block_sizes): 531 | num_filters = self.num_filters * (2**i) 532 | inputs = block_layer( 533 | inputs=inputs, filters=num_filters, bottleneck=self.bottleneck, 534 | block_fn=self.block_fn, blocks=num_blocks, 535 | strides=self.block_strides[i], training=training, 536 | name='block_layer{}'.format(i + 1), data_format=self.data_format) 537 | 538 | # Only apply the BN and ReLU for model that does pre_activation in each 539 | # building/bottleneck block, eg resnet V2. 540 | if self.pre_activation: 541 | inputs = batch_norm(inputs, training, self.data_format) 542 | inputs = tf.nn.relu(inputs) 543 | ENDING_POINTS.append(inputs) 544 | 545 | if skip_final_dense: 546 | return tf.reshape(inputs, [-1, 7 * 7 * self.final_size]) 547 | 548 | # The current top layer has shape 549 | # `batch_size x pool_size x pool_size x final_size`. 550 | # ResNet does an Average Pooling layer over pool_size, 551 | # but that is the same as doing a reduce_mean. We do a reduce_mean 552 | # here because it performs better than AveragePooling2D. 553 | axes = [2, 3] if self.data_format == 'channels_first' else [1, 2] 554 | inputs = tf.reduce_mean(inputs, axes, keepdims=True) 555 | inputs = tf.identity(inputs, 'final_reduce_mean') 556 | 557 | inputs = tf.reshape(inputs, [-1, self.final_size]) 558 | inputs = tf.layers.dense(inputs=inputs, units=self.num_classes) 559 | inputs = tf.identity(inputs, 'final_dense') 560 | if not get_all_layers: 561 | return inputs 562 | else: 563 | return inputs, ENDING_POINTS 564 | 565 | 566 | ############################################################################### 567 | # Running the model 568 | ############################################################################### 569 | _NUM_CLASSES = 128 570 | 571 | 572 | def _get_block_sizes(resnet_size): 573 | """Retrieve the size of each block_layer in the ResNet model. 574 | 575 | The number of block layers used for the Resnet model varies according 576 | to the size of the model. This helper grabs the layer set we want, throwing 577 | an error if a non-standard size has been selected. 578 | 579 | Args: 580 | resnet_size: The number of convolutional layers needed in the model. 581 | 582 | Returns: 583 | A list of block sizes to use in building the model. 584 | 585 | Raises: 586 | KeyError: if invalid resnet_size is received. 
587 | """ 588 | choices = { 589 | 18: [2, 2, 2, 2], 590 | 34: [3, 4, 6, 3], 591 | 50: [3, 4, 6, 3], 592 | 101: [3, 4, 23, 3], 593 | 152: [3, 8, 36, 3], 594 | 200: [3, 24, 36, 3] 595 | } 596 | 597 | try: 598 | return choices[resnet_size] 599 | except KeyError: 600 | err = ('Could not find layers for selected Resnet size.\n' 601 | 'Size received: {}; sizes allowed: {}.'.format( 602 | resnet_size, choices.keys())) 603 | raise ValueError(err) 604 | 605 | 606 | class ImagenetModel(Model): 607 | """Model class with appropriate defaults for Imagenet data.""" 608 | 609 | def __init__(self, resnet_size, data_format=None, num_classes=_NUM_CLASSES, 610 | resnet_version=DEFAULT_VERSION, 611 | dtype=DEFAULT_DTYPE): 612 | """These are the parameters that work for Imagenet data. 613 | 614 | Args: 615 | resnet_size: The number of convolutional layers needed in the model. 616 | data_format: Either 'channels_first' or 'channels_last', specifying which 617 | data format to use when setting up the model. 618 | num_classes: The number of output classes needed from the model. This 619 | enables users to extend the same model to their own datasets. 620 | resnet_version: Integer representing which version of the ResNet network 621 | to use. See README for details. Valid values: [1, 2] 622 | dtype: The TensorFlow dtype to use for calculations. 623 | """ 624 | 625 | # For bigger models, we want to use "bottleneck" layers 626 | if resnet_size < 50: 627 | bottleneck = False 628 | final_size = 512 629 | else: 630 | bottleneck = True 631 | final_size = 2048 632 | 633 | super(ImagenetModel, self).__init__( 634 | resnet_size=resnet_size, 635 | bottleneck=bottleneck, 636 | num_classes=num_classes, 637 | num_filters=64, 638 | kernel_size=7, 639 | conv_stride=2, 640 | first_pool_size=3, 641 | first_pool_stride=2, 642 | block_sizes=_get_block_sizes(resnet_size), 643 | block_strides=[1, 2, 2, 2], 644 | final_size=final_size, 645 | resnet_version=resnet_version, 646 | data_format=data_format, 647 | dtype=dtype 648 | ) 649 | -------------------------------------------------------------------------------- /model/vggnet_model.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Contains model definitions for versions of the Oxford VGG network. 
16 | These model definitions were introduced in the following technical report: 17 | Very Deep Convolutional Networks For Large-Scale Image Recognition 18 | Karen Simonyan and Andrew Zisserman 19 | arXiv technical report, 2015 20 | PDF: http://arxiv.org/pdf/1409.1556.pdf 21 | ILSVRC 2014 Slides: http://www.robots.ox.ac.uk/~karen/pdf/ILSVRC_2014.pdf 22 | CC-BY-4.0 23 | More information can be obtained from the VGG website: 24 | www.robots.ox.ac.uk/~vgg/research/very_deep/ 25 | Usage: 26 | with slim.arg_scope(vgg.vgg_arg_scope()): 27 | outputs, end_points = vgg.vgg_a(inputs) 28 | with slim.arg_scope(vgg.vgg_arg_scope()): 29 | outputs, end_points = vgg.vgg_16(inputs) 30 | @@vgg_a 31 | @@vgg_16 32 | @@vgg_19 33 | """ 34 | from __future__ import absolute_import 35 | from __future__ import division 36 | from __future__ import print_function 37 | 38 | import tensorflow as tf 39 | 40 | slim = tf.contrib.slim 41 | 42 | 43 | def vgg_arg_scope(weight_decay=0.0005): 44 | """Defines the VGG arg scope. 45 | Args: 46 | weight_decay: The l2 regularization coefficient. 47 | Returns: 48 | An arg_scope. 49 | """ 50 | with slim.arg_scope([slim.conv2d, slim.fully_connected], 51 | activation_fn=tf.nn.relu, 52 | weights_regularizer=slim.l2_regularizer(weight_decay), 53 | biases_initializer=tf.zeros_initializer()): 54 | with slim.arg_scope([slim.conv2d], padding='SAME') as arg_sc: 55 | return arg_sc 56 | 57 | 58 | def vgg_a(inputs, 59 | num_classes=1000, 60 | is_training=True, 61 | dropout_keep_prob=0.5, 62 | spatial_squeeze=True, 63 | scope='vgg_a', 64 | fc_conv_padding='VALID', 65 | global_pool=False): 66 | """Oxford Net VGG 11-Layers version A Example. 67 | Note: All the fully_connected layers have been transformed to conv2d layers. 68 | To use in classification mode, resize input to 224x224. 69 | Args: 70 | inputs: a tensor of size [batch_size, height, width, channels]. 71 | num_classes: number of predicted classes. If 0 or None, the logits layer is 72 | omitted and the input features to the logits layer are returned instead. 73 | is_training: whether or not the model is being trained. 74 | dropout_keep_prob: the probability that activations are kept in the dropout 75 | layers during training. 76 | spatial_squeeze: whether or not should squeeze the spatial dimensions of the 77 | outputs. Useful to remove unnecessary dimensions for classification. 78 | scope: Optional scope for the variables. 79 | fc_conv_padding: the type of padding to use for the fully connected layer 80 | that is implemented as a convolutional layer. Use 'SAME' padding if you 81 | are applying the network in a fully convolutional manner and want to 82 | get a prediction map downsampled by a factor of 32 as an output. 83 | Otherwise, the output prediction map will be (input / 32) - 6 in case of 84 | 'VALID' padding. 85 | global_pool: Optional boolean flag. If True, the input to the classification 86 | layer is avgpooled to size 1x1, for any input size. (This is not part 87 | of the original VGG architecture.) 88 | Returns: 89 | net: the output of the logits layer (if num_classes is a non-zero integer), 90 | or the input to the logits layer (if num_classes is 0 or None). 91 | end_points: a dict of tensors with intermediate activations. 92 | """ 93 | with tf.variable_scope(scope, 'vgg_a', [inputs]) as sc: 94 | end_points_collection = sc.original_name_scope + '_end_points' 95 | # Collect outputs for conv2d, fully_connected and max_pool2d. 
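    # Only conv2d and max_pool2d need to be in the collection scope here: the
    # fully connected layers below are themselves expressed as conv2d ops
    # (a 7x7 conv for fc6 and 1x1 convs for fc7/fc8), so their outputs are
    # captured in end_points the same way.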
96 | with slim.arg_scope([slim.conv2d, slim.max_pool2d], 97 | outputs_collections=end_points_collection): 98 | net = slim.repeat(inputs, 1, slim.conv2d, 64, [3, 3], scope='conv1') 99 | net = slim.max_pool2d(net, [2, 2], scope='pool1') 100 | net = slim.repeat(net, 1, slim.conv2d, 128, [3, 3], scope='conv2') 101 | net = slim.max_pool2d(net, [2, 2], scope='pool2') 102 | net = slim.repeat(net, 2, slim.conv2d, 256, [3, 3], scope='conv3') 103 | net = slim.max_pool2d(net, [2, 2], scope='pool3') 104 | net = slim.repeat(net, 2, slim.conv2d, 512, [3, 3], scope='conv4') 105 | net = slim.max_pool2d(net, [2, 2], scope='pool4') 106 | net = slim.repeat(net, 2, slim.conv2d, 512, [3, 3], scope='conv5') 107 | net = slim.max_pool2d(net, [2, 2], scope='pool5') 108 | 109 | # Use conv2d instead of fully_connected layers. 110 | net = slim.conv2d(net, 4096, [7, 7], padding=fc_conv_padding, scope='fc6') 111 | net = slim.dropout(net, dropout_keep_prob, is_training=is_training, 112 | scope='dropout6') 113 | net = slim.conv2d(net, 4096, [1, 1], scope='fc7') 114 | # Convert end_points_collection into a end_point dict. 115 | end_points = slim.utils.convert_collection_to_dict(end_points_collection) 116 | if global_pool: 117 | net = tf.reduce_mean(net, [1, 2], keep_dims=True, name='global_pool') 118 | end_points['global_pool'] = net 119 | if num_classes: 120 | net = slim.dropout(net, dropout_keep_prob, is_training=is_training, 121 | scope='dropout7') 122 | net = slim.conv2d(net, num_classes, [1, 1], 123 | activation_fn=None, 124 | normalizer_fn=None, 125 | scope='fc8') 126 | if spatial_squeeze: 127 | net = tf.squeeze(net, [1, 2], name='fc8/squeezed') 128 | end_points[sc.name + '/fc8'] = net 129 | return net, end_points 130 | vgg_a.default_image_size = 224 131 | 132 | 133 | def vgg_16(inputs, 134 | num_classes=1000, 135 | is_training=True, 136 | dropout_keep_prob=0.5, 137 | spatial_squeeze=True, 138 | scope='vgg_16', 139 | fc_conv_padding='VALID', 140 | fix_bug=False, # For previous behaviours 141 | with_bn=False, 142 | global_pool=False): 143 | """Oxford Net VGG 16-Layers version D Example. 144 | Note: All the fully_connected layers have been transformed to conv2d layers. 145 | To use in classification mode, resize input to 224x224. 146 | Args: 147 | inputs: a tensor of size [batch_size, height, width, channels]. 148 | num_classes: number of predicted classes. If 0 or None, the logits layer is 149 | omitted and the input features to the logits layer are returned instead. 150 | is_training: whether or not the model is being trained. 151 | dropout_keep_prob: the probability that activations are kept in the dropout 152 | layers during training. 153 | spatial_squeeze: whether or not should squeeze the spatial dimensions of the 154 | outputs. Useful to remove unnecessary dimensions for classification. 155 | scope: Optional scope for the variables. 156 | fc_conv_padding: the type of padding to use for the fully connected layer 157 | that is implemented as a convolutional layer. Use 'SAME' padding if you 158 | are applying the network in a fully convolutional manner and want to 159 | get a prediction map downsampled by a factor of 32 as an output. 160 | Otherwise, the output prediction map will be (input / 32) - 6 in case of 161 | 'VALID' padding. 162 | global_pool: Optional boolean flag. If True, the input to the classification 163 | layer is avgpooled to size 1x1, for any input size. (This is not part 164 | of the original VGG architecture.) 
165 | Returns: 166 | net: the output of the logits layer (if num_classes is a non-zero integer), 167 | or the input to the logits layer (if num_classes is 0 or None). 168 | end_points: a dict of tensors with intermediate activations. 169 | """ 170 | if with_bn: 171 | normalizer_fn = tf.layers.batch_normalization 172 | normalizer_params = { 173 | 'momentum': 0.997, 174 | 'epsilon': 1e-5, 175 | 'training': is_training, 176 | 'fused': True} 177 | kwargs = { 178 | 'normalizer_fn': normalizer_fn, 179 | 'normalizer_params': normalizer_params} 180 | else: 181 | kwargs = {} 182 | end_points = [] 183 | with tf.variable_scope(scope, 'vgg_16', [inputs]) as sc: 184 | end_points_collection = sc.original_name_scope + '_end_points' 185 | # Collect outputs for conv2d, fully_connected and max_pool2d. 186 | with slim.arg_scope([slim.conv2d, slim.fully_connected, slim.max_pool2d], 187 | outputs_collections=end_points_collection): 188 | if fix_bug: 189 | net = slim.repeat( 190 | inputs, 1, slim.conv2d, 64, [3, 3], scope='conv1') 191 | inputs = net 192 | net = slim.repeat( 193 | inputs, 1, slim.conv2d, 64, [3, 3], scope='conv1_final', **kwargs) 194 | end_points.append(net) 195 | net = slim.max_pool2d(net, [2, 2], scope='pool1') 196 | net = slim.repeat( 197 | net, 1, slim.conv2d, 128, [3, 3], scope='conv2') 198 | net = slim.repeat( 199 | net, 1, slim.conv2d, 128, [3, 3], scope='conv2_final', **kwargs) 200 | end_points.append(net) 201 | net = slim.max_pool2d(net, [2, 2], scope='pool2') 202 | net = slim.repeat( 203 | net, 2, slim.conv2d, 256, [3, 3], scope='conv3') 204 | net = slim.repeat( 205 | net, 1, slim.conv2d, 256, [3, 3], scope='conv3_final', **kwargs) 206 | end_points.append(net) 207 | net = slim.max_pool2d(net, [2, 2], scope='pool3') 208 | net = slim.repeat( 209 | net, 2, slim.conv2d, 512, [3, 3], scope='conv4') 210 | net = slim.repeat( 211 | net, 1, slim.conv2d, 512, [3, 3], scope='conv4_final', **kwargs) 212 | end_points.append(net) 213 | net = slim.max_pool2d(net, [2, 2], scope='pool4') 214 | net = slim.repeat( 215 | net, 2, slim.conv2d, 512, [3, 3], scope='conv5') 216 | net = slim.repeat( 217 | net, 1, slim.conv2d, 512, [3, 3], scope='conv5_final', **kwargs) 218 | end_points.append(net) 219 | net = slim.max_pool2d(net, [2, 2], scope='pool5') 220 | 221 | # Use conv2d instead of fully_connected layers. 222 | net = slim.conv2d(net, 4096, [7, 7], padding=fc_conv_padding, scope='fc6', **kwargs) 223 | if dropout_keep_prob > 0: 224 | net = slim.dropout(net, dropout_keep_prob, is_training=is_training, 225 | scope='dropout6') 226 | net = slim.conv2d(net, 4096, [1, 1], scope='fc7', **kwargs) 227 | # Convert end_points_collection into a end_point dict. 
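      # In this variant the slim collection is not converted back into a dict;
      # `end_points` stays the Python list built above (one tensor appended per
      # conv stage), which is what get_resnet_all_layers in
      # model/instance_model.py indexes by position.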
228 | #end_points = slim.utils.convert_collection_to_dict(end_points_collection) 229 | if global_pool: 230 | net = tf.reduce_mean(net, [1, 2], keep_dims=True, name='global_pool') 231 | #end_points['global_pool'] = net 232 | if num_classes: 233 | if dropout_keep_prob > 0: 234 | net = slim.dropout(net, dropout_keep_prob, is_training=is_training, 235 | scope='dropout7') 236 | net = slim.conv2d(net, num_classes, [1, 1], 237 | activation_fn=None, 238 | normalizer_fn=None, 239 | scope='fc8') 240 | if spatial_squeeze: 241 | net = tf.squeeze(net, [1, 2], name='fc8/squeezed') 242 | #end_points[sc.name + '/fc8'] = net 243 | return net, end_points 244 | vgg_16.default_image_size = 224 245 | 246 | 247 | def vgg_19(inputs, 248 | num_classes=1000, 249 | is_training=True, 250 | dropout_keep_prob=0.5, 251 | spatial_squeeze=True, 252 | scope='vgg_19', 253 | fc_conv_padding='VALID', 254 | global_pool=False): 255 | """Oxford Net VGG 19-Layers version E Example. 256 | Note: All the fully_connected layers have been transformed to conv2d layers. 257 | To use in classification mode, resize input to 224x224. 258 | Args: 259 | inputs: a tensor of size [batch_size, height, width, channels]. 260 | num_classes: number of predicted classes. If 0 or None, the logits layer is 261 | omitted and the input features to the logits layer are returned instead. 262 | is_training: whether or not the model is being trained. 263 | dropout_keep_prob: the probability that activations are kept in the dropout 264 | layers during training. 265 | spatial_squeeze: whether or not should squeeze the spatial dimensions of the 266 | outputs. Useful to remove unnecessary dimensions for classification. 267 | scope: Optional scope for the variables. 268 | fc_conv_padding: the type of padding to use for the fully connected layer 269 | that is implemented as a convolutional layer. Use 'SAME' padding if you 270 | are applying the network in a fully convolutional manner and want to 271 | get a prediction map downsampled by a factor of 32 as an output. 272 | Otherwise, the output prediction map will be (input / 32) - 6 in case of 273 | 'VALID' padding. 274 | global_pool: Optional boolean flag. If True, the input to the classification 275 | layer is avgpooled to size 1x1, for any input size. (This is not part 276 | of the original VGG architecture.) 277 | Returns: 278 | net: the output of the logits layer (if num_classes is a non-zero integer), 279 | or the non-dropped-out input to the logits layer (if num_classes is 0 or 280 | None). 281 | end_points: a dict of tensors with intermediate activations. 282 | """ 283 | with tf.variable_scope(scope, 'vgg_19', [inputs]) as sc: 284 | end_points_collection = sc.original_name_scope + '_end_points' 285 | # Collect outputs for conv2d, fully_connected and max_pool2d. 
286 | with slim.arg_scope([slim.conv2d, slim.fully_connected, slim.max_pool2d], 287 | outputs_collections=end_points_collection): 288 | net = slim.repeat(inputs, 2, slim.conv2d, 64, [3, 3], scope='conv1') 289 | net = slim.max_pool2d(net, [2, 2], scope='pool1') 290 | net = slim.repeat(net, 2, slim.conv2d, 128, [3, 3], scope='conv2') 291 | net = slim.max_pool2d(net, [2, 2], scope='pool2') 292 | net = slim.repeat(net, 4, slim.conv2d, 256, [3, 3], scope='conv3') 293 | net = slim.max_pool2d(net, [2, 2], scope='pool3') 294 | net = slim.repeat(net, 4, slim.conv2d, 512, [3, 3], scope='conv4') 295 | net = slim.max_pool2d(net, [2, 2], scope='pool4') 296 | net = slim.repeat(net, 4, slim.conv2d, 512, [3, 3], scope='conv5') 297 | net = slim.max_pool2d(net, [2, 2], scope='pool5') 298 | 299 | # Use conv2d instead of fully_connected layers. 300 | net = slim.conv2d(net, 4096, [7, 7], padding=fc_conv_padding, scope='fc6') 301 | net = slim.dropout(net, dropout_keep_prob, is_training=is_training, 302 | scope='dropout6') 303 | net = slim.conv2d(net, 4096, [1, 1], scope='fc7') 304 | # Convert end_points_collection into a end_point dict. 305 | end_points = slim.utils.convert_collection_to_dict(end_points_collection) 306 | if global_pool: 307 | net = tf.reduce_mean(net, [1, 2], keep_dims=True, name='global_pool') 308 | end_points['global_pool'] = net 309 | if num_classes: 310 | net = slim.dropout(net, dropout_keep_prob, is_training=is_training, 311 | scope='dropout7') 312 | net = slim.conv2d(net, num_classes, [1, 1], 313 | activation_fn=None, 314 | normalizer_fn=None, 315 | scope='fc8') 316 | if spatial_squeeze: 317 | net = tf.squeeze(net, [1, 2], name='fc8/squeezed') 318 | end_points[sc.name + '/fc8'] = net 319 | return net, end_points 320 | vgg_19.default_image_size = 224 321 | 322 | # Alias 323 | vgg_d = vgg_16 324 | vgg_e = vgg_19 325 | -------------------------------------------------------------------------------- /param_setter.py: -------------------------------------------------------------------------------- 1 | from __future__ import division, print_function, absolute_import 2 | import os, sys 3 | import numpy as np 4 | import tensorflow as tf 5 | 6 | import json 7 | import copy 8 | import argparse 9 | import time 10 | import functools 11 | import inspect 12 | 13 | from model import preprocessing as prep 14 | from model import instance_model 15 | from model.memory_bank import MemoryBank 16 | from model.dataset_utils import dataset_func 17 | from model.instance_model import get_global_step_var 18 | 19 | from utils import DATA_LEN_IMAGENET_FULL, tuple_get_one 20 | import config 21 | import pdb 22 | 23 | 24 | def get_config(): 25 | cfg = config.Config() 26 | cfg.add('exp_id', type=str, required=True, 27 | help='Name of experiment ID') 28 | cfg.add('batch_size', type=int, default=128, 29 | help='Training batch size') 30 | cfg.add('test_batch_size', type=int, default=64, 31 | help='Testing batch size') 32 | cfg.add('init_lr', type=float, default=0.03, 33 | help='Initial learning rate') 34 | cfg.add('gpu', type=str, required=True, 35 | help='Value for CUDA_VISIBLE_DEVICES') 36 | cfg.add('gpu_offset', type=int, default=0, 37 | help='GPU offset, useful for KMeans') 38 | cfg.add('image_dir', type=str, required=True, 39 | help='Directory containing dataset') 40 | cfg.add('q_cap', type=int, default=102400, 41 | help='Shuffle queue capacity of tfr data') 42 | cfg.add('data_len', type=int, default=DATA_LEN_IMAGENET_FULL, 43 | help='Total number of images in the input dataset') 44 | 45 | # Training parameters 46 | 
cfg.add('weight_decay', type=float, default=1e-4, 47 | help='Weight decay') 48 | cfg.add('instance_t', type=float, default=0.07, 49 | help='Temperature in softmax.') 50 | cfg.add('instance_k', type=int, default=4096, 51 | help='Closest neighbors to sample.') 52 | cfg.add('lr_boundaries', type=str, default=None, 53 | help='Learning rate boundaries for 10x drops') 54 | cfg.add('train_num_steps', type=int, default=None, 55 | help='Number of overall steps for training') 56 | 57 | cfg.add('kmeans_k', type=str, default='10000', 58 | help='K for Kmeans') 59 | cfg.add('model_type', type=str, default='resnet18', 60 | help='Model type, resnet or alexnet') 61 | cfg.add('task', type=str, default='LA', 62 | help='IR for instance recognition or LA for local aggregation') 63 | 64 | # Saving parameters 65 | cfg.add('port', type=int, required=True, 66 | help='Port number for mongodb') 67 | cfg.add('db_name', type=str, required=True, 68 | help='Name of database') 69 | cfg.add('col_name', type=str, required=True, 70 | help='Name of collection') 71 | cfg.add('cache_dir', type=str, required=True, 72 | help='Prefix of saving directory') 73 | cfg.add('fre_valid', type=int, default=10009, 74 | help='Frequency of validation') 75 | cfg.add('fre_filter', type=int, default=10009, 76 | help='Frequency of saving filters') 77 | cfg.add('fre_cache_filter', type=int, 78 | help='Frequency of caching filters') 79 | 80 | # Loading parameters 81 | cfg.add('load_exp', type=str, default=None, 82 | help='The experiment to load from, in the format ' 83 | '[dbname]/[collname]/[exp_id]') 84 | cfg.add('load_port', type=int, 85 | help='Port number of mongodb for loading (defaults to saving port)') 86 | cfg.add('load_step', type=int, 87 | help='Step number for loading') 88 | 89 | return cfg 90 | 91 | 92 | def loss_func(output, *args, **kwargs): 93 | loss_pure = output['losses'][0] 94 | return loss_pure 95 | 96 | 97 | def reg_loss(loss, weight_decay): 98 | # Add weight decay to the loss. 99 | def exclude_batch_norm_and_other_device(name): 100 | return 'batch_normalization' not in name 101 | l2_loss = weight_decay * tf.add_n( 102 | [tf.nn.l2_loss(tf.cast(v, tf.float32)) 103 | for v in tf.trainable_variables() 104 | if exclude_batch_norm_and_other_device(v.name)]) 105 | loss_all = tf.add(loss, l2_loss) 106 | return loss_all 107 | 108 | 109 | def rep_loss_func( 110 | inputs, 111 | output, 112 | gpu_offset=0, 113 | **kwargs 114 | ): 115 | data_indx = output['data_indx'] 116 | new_data_memory = output['new_data_memory'] 117 | 118 | memory_bank_list = output['memory_bank'] 119 | all_labels_list = output['all_labels'] 120 | if isinstance(memory_bank_list, tf.Variable): 121 | memory_bank_list = [memory_bank_list] 122 | all_labels_list = [all_labels_list] 123 | 124 | devices = ['/gpu:%i' \ 125 | % (idx + gpu_offset) for idx in range(len(memory_bank_list))] 126 | update_ops = [] 127 | for device, memory_bank, all_labels \ 128 | in zip(devices, memory_bank_list, all_labels_list): 129 | with tf.device(device): 130 | mb_update_op = tf.scatter_update( 131 | memory_bank, data_indx, new_data_memory) 132 | update_ops.append(mb_update_op) 133 | lb_update_op = tf.scatter_update( 134 | all_labels, data_indx, 135 | inputs['label']) 136 | update_ops.append(lb_update_op) 137 | 138 | with tf.control_dependencies(update_ops): 139 | # Force the updates to happen before the next batch.
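        # Depending on the task, the model reports either a single pure loss
        # or separate model/noise terms; log whichever is returned.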
140 | if len(output['losses']) == 3: 141 | _, loss_model, loss_noise = output['losses'] 142 | loss_model = tf.identity(loss_model) 143 | loss_noise = tf.identity(loss_noise) 144 | ret_dict = { 145 | 'loss_model': loss_model, 146 | 'loss_noise': loss_noise} 147 | else: 148 | loss_pure = output['losses'][0] 149 | loss_pure = tf.identity(loss_pure) 150 | ret_dict = {'loss_pure': loss_pure} 151 | return ret_dict 152 | 153 | 154 | def online_agg(agg_res, res, step): 155 | if agg_res is None: 156 | agg_res = {k: [] for k in res} 157 | for k, v in res.items(): 158 | agg_res[k].append(np.mean(v)) 159 | return agg_res 160 | 161 | 162 | def valid_perf_func( 163 | inputs, 164 | output, 165 | ): 166 | curr_dist, all_labels = output 167 | all_labels = tuple_get_one(all_labels) 168 | _, top_indices = tf.nn.top_k(curr_dist, k=1) 169 | curr_pred = tf.gather(all_labels, tf.squeeze(top_indices, axis=1)) 170 | imagenet_top1 = tf.reduce_mean( 171 | tf.cast( 172 | tf.equal(curr_pred, inputs['label']), 173 | tf.float32)) 174 | return {'top1': imagenet_top1} 175 | 176 | 177 | def get_model_func_params(args): 178 | model_params = { 179 | "data_len": args.data_len, 180 | "instance_t": args.instance_t, 181 | "instance_k": args.instance_k, 182 | "kmeans_k": args.kmeans_k, 183 | "model_type": args.model_type, 184 | "task": args.task, 185 | } 186 | return model_params 187 | 188 | 189 | def get_lr_from_boundary( 190 | global_step, boundaries, 191 | init_lr, 192 | NUM_BATCHES_PER_EPOCH): 193 | if boundaries is not None: 194 | boundaries = boundaries.split(',') 195 | boundaries = [int(each_boundary) for each_boundary in boundaries] 196 | 197 | all_lrs = [ 198 | init_lr * (0.1 ** drop_level) \ 199 | for drop_level in range(len(boundaries) + 1)] 200 | curr_lr = tf.train.piecewise_constant( 201 | x=global_step, 202 | boundaries=boundaries, values=all_lrs) 203 | else: 204 | curr_lr = tf.constant(init_lr) 205 | return curr_lr 206 | 207 | 208 | def get_params_from_arg(args): 209 | ''' 210 | This function gets parameters needed for training 211 | ''' 212 | multi_gpu = len(args.gpu.split(',')) - args.gpu_offset 213 | data_len = args.data_len 214 | val_len = 50000 215 | NUM_BATCHES_PER_EPOCH = data_len // args.batch_size 216 | 217 | # save_params: defining where to save the models 218 | args.fre_cache_filter = args.fre_cache_filter or args.fre_filter 219 | cache_dir = os.path.join( 220 | args.cache_dir, 'models', 221 | args.db_name, args.col_name, args.exp_id) 222 | save_params = { 223 | 'host': 'localhost', # used for tfutils 224 | 'port': args.port, # used for tfutils 225 | 'dbname': args.db_name, 226 | 'collname': args.col_name, 227 | 'exp_id': args.exp_id, 228 | 'do_save': True, 229 | 'save_initial_filters': True, 230 | 'save_metrics_freq': 1000, 231 | 'save_valid_freq': args.fre_valid, 232 | 'save_filters_freq': args.fre_filter, 233 | 'cache_filters_freq': args.fre_cache_filter, 234 | 'cache_dir': cache_dir, 235 | } 236 | 237 | # load_params: defining where to load, if needed 238 | load_port = args.load_port or args.port 239 | load_dbname = args.db_name 240 | load_collname = args.col_name 241 | load_exp_id = args.exp_id 242 | load_query = None 243 | 244 | if args.load_exp is not None: 245 | load_dbname, load_collname, load_exp_id = args.load_exp.split('/') 246 | if args.load_step: 247 | load_query = {'exp_id': load_exp_id, 248 | 'saved_filters': True, 249 | 'step': args.load_step} 250 | print('Load query', load_query) 251 | 252 | load_params = { 253 | 'host': 'localhost', # used for tfutils 254 | 'port': load_port, # used 
for tfutils 255 | 'dbname': load_dbname, 256 | 'collname': load_collname, 257 | 'exp_id': load_exp_id, 258 | 'do_restore': True, 259 | 'query': load_query, 260 | } 261 | 262 | 263 | # XXX: hack to set up training loop properly 264 | if args.kmeans_k.isdigit(): 265 | args.kmeans_k = [int(args.kmeans_k)] 266 | else: 267 | args.kmeans_k = [int(each_k) for each_k in args.kmeans_k.split(',')] 268 | clusterings = [] 269 | first_step = [] 270 | # model_params: a function that will build the model 271 | model_func_params = get_model_func_params(args) 272 | def build_output(inputs, train, **kwargs): 273 | targets = instance_model.build_targets( 274 | inputs, train, 275 | **model_func_params) 276 | if not train: 277 | return targets 278 | outputs, logged_cfg, clustering = targets 279 | clusterings.append(clustering) 280 | return outputs, logged_cfg 281 | 282 | def train_loop(sess, train_targets, **params): 283 | assert len(clusterings) == multi_gpu 284 | 285 | global_step_var = get_global_step_var() 286 | global_step = sess.run(global_step_var) 287 | 288 | # TODO: consider making this reclustering frequency a flag 289 | first_flag = len(first_step) == 0 290 | update_fre = NUM_BATCHES_PER_EPOCH 291 | if (global_step % update_fre == 0 or first_flag) \ 292 | and clusterings[0] is not None: 293 | if first_flag: 294 | first_step.append(1) 295 | print("Recomputing clusters...") 296 | new_clust_labels = clusterings[0].recompute_clusters(sess) 297 | for clustering in clusterings: 298 | clustering.apply_clusters(sess, new_clust_labels) 299 | 300 | return sess.run(train_targets) 301 | 302 | model_params = {'func': build_output} 303 | if multi_gpu > 1: 304 | model_params['num_gpus'] = multi_gpu 305 | model_params['devices'] = ['/gpu:%i' \ 306 | % (idx + args.gpu_offset) \ 307 | for idx in range(multi_gpu)] 308 | 309 | # train_params: parameters about training data 310 | train_data_param = { 311 | 'func': dataset_func, 312 | 'image_dir': args.image_dir, 313 | 'process_img_func': prep.resnet_train, 314 | 'is_train': True, 315 | 'q_cap': args.q_cap, 316 | 'batch_size': args.batch_size} 317 | train_num_steps = args.train_num_steps or float('Inf') 318 | train_params = { 319 | 'validate_first': False, 320 | 'data_params': train_data_param, 321 | 'thres_loss': float('Inf'), 322 | 'num_steps': train_num_steps, 323 | 'train_loop': {'func': train_loop}, 324 | } 325 | 326 | ## Add other loss reports (loss_model, loss_noise) 327 | train_params['targets'] = { 328 | 'func': rep_loss_func, 329 | 'gpu_offset': args.gpu_offset, 330 | } 331 | 332 | # loss_params: parameters to build the loss 333 | loss_params = { 334 | 'pred_targets': [], 335 | 'agg_func': reg_loss, 336 | 'agg_func_kwargs': {'weight_decay': args.weight_decay}, 337 | 'loss_func': loss_func, 338 | } 339 | 340 | # learning_rate_params: build the learning rate 341 | # For now, just stay the same 342 | learning_rate_params = { 343 | 'func': get_lr_from_boundary, 344 | 'init_lr': args.init_lr, 345 | 'NUM_BATCHES_PER_EPOCH': NUM_BATCHES_PER_EPOCH, 346 | 'boundaries': args.lr_boundaries, 347 | } 348 | 349 | optimizer_params = { 350 | 'optimizer': tf.train.MomentumOptimizer, 351 | 'momentum': .9, 352 | } 353 | 354 | # validation_params: control the validation 355 | topn_val_data_param = { 356 | 'func': dataset_func, 357 | 'image_dir': args.image_dir, 358 | 'process_img_func': prep.resnet_validate, 359 | 'is_train': False, 360 | 'q_cap': args.test_batch_size, 361 | 'batch_size': args.test_batch_size} 362 | val_step_num = int(val_len/args.test_batch_size) 363 | 
val_targets = {'func': valid_perf_func} 364 | topn_val_param = { 365 | 'data_params': topn_val_data_param, 366 | 'targets': val_targets, 367 | 'num_steps': val_step_num, 368 | 'agg_func': lambda x: {k: np.mean(v) for k, v in x.items()}, 369 | 'online_agg_func': online_agg, 370 | } 371 | 372 | validation_params = { 373 | 'topn': topn_val_param, 374 | } 375 | 376 | # Put all parameters together 377 | params = { 378 | 'save_params': save_params, 379 | 'load_params': load_params, 380 | 'model_params': model_params, 381 | 'train_params': train_params, 382 | 'loss_params': loss_params, 383 | 'learning_rate_params': learning_rate_params, 384 | 'optimizer_params': optimizer_params, 385 | 'log_device_placement': False, 386 | 'validation_params': validation_params, 387 | 'skip_check': True, 388 | } 389 | return params 390 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | from __future__ import division, print_function, absolute_import 2 | import os 3 | 4 | from param_setter import get_config, get_params_from_arg 5 | from framework import TrainFramework 6 | 7 | 8 | def main(): 9 | # Parse arguments 10 | cfg = get_config() 11 | args = cfg.parse_args() 12 | os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu 13 | 14 | # Get params needed, start training 15 | params = get_params_from_arg(args) 16 | train_framework = TrainFramework(params) 17 | train_framework.train() 18 | 19 | 20 | if __name__ == "__main__": 21 | main() 22 | -------------------------------------------------------------------------------- /train_tfutils.py: -------------------------------------------------------------------------------- 1 | from __future__ import division, print_function, absolute_import 2 | import os 3 | from tfutils import base 4 | 5 | import tensorflow as tf 6 | from param_setter import get_config, get_params_from_arg 7 | 8 | 9 | def reg_loss_in_faster(loss, which_device, weight_decay): 10 | from tfutils.multi_gpu.easy_variable_mgr import COPY_NAME_SCOPE 11 | curr_scope_name = '%s%i' % (COPY_NAME_SCOPE, which_device) 12 | # Add weight decay to the loss. 
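    # Each GPU tower regularizes only the variables under its own copy scope
    # (batch-norm parameters excluded), so every device adds the weight decay
    # for its own copy of the model.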
13 | def exclude_batch_norm_and_other_device(name): 14 | return 'batch_normalization' not in name and curr_scope_name in name 15 | l2_loss = weight_decay * tf.add_n( 16 | [tf.nn.l2_loss(tf.cast(v, tf.float32)) 17 | for v in tf.trainable_variables() 18 | if exclude_batch_norm_and_other_device(v.name)]) 19 | loss_all = tf.add(loss, l2_loss) 20 | return loss_all 21 | 22 | 23 | def main(): 24 | # Parse arguments 25 | cfg = get_config() 26 | args = cfg.parse_args() 27 | os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu 28 | 29 | # Get params needed, start training 30 | params = get_params_from_arg(args) 31 | params['loss_params']['agg_func'] = reg_loss_in_faster 32 | cache_dir = os.path.join( 33 | args.cache_dir, 'models_tfutils', 34 | args.db_name, args.col_name, args.exp_id) 35 | params['save_params']['cache_dir'] = cache_dir 36 | base.train_from_params(**params) 37 | 38 | 39 | if __name__ == "__main__": 40 | main() 41 | -------------------------------------------------------------------------------- /train_transfer.py: -------------------------------------------------------------------------------- 1 | from __future__ import division, print_function, absolute_import 2 | import os 3 | from trans_param_setter import get_config, get_params_from_args 4 | from framework import TrainFramework 5 | 6 | 7 | def main(): 8 | # Parse arguments 9 | cfg = get_config() 10 | args = cfg.parse_args() 11 | os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu 12 | 13 | params = get_params_from_args(args) 14 | train_framework = TrainFramework(params) 15 | train_framework.train() 16 | 17 | 18 | if __name__ == "__main__": 19 | main() 20 | -------------------------------------------------------------------------------- /train_transfer_tfutils.py: -------------------------------------------------------------------------------- 1 | from __future__ import division, print_function, absolute_import 2 | import os 3 | from trans_param_setter import get_config, get_params_from_args 4 | from tfutils import base 5 | import tensorflow as tf 6 | 7 | 8 | def reg_loss_in_faster(loss, which_device, weight_decay): 9 | from tfutils.multi_gpu.easy_variable_mgr import COPY_NAME_SCOPE 10 | curr_scope_name = '%s%i' % (COPY_NAME_SCOPE, which_device) 11 | # Add weight decay to the loss. 
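    # During transfer learning, only the trainable readout variables on this
    # device (those under the 'instance' scope, cf. trainable_scopes) are
    # regularized; batch-norm parameters are excluded.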
12 | def exclude_batch_norm_and_other_device(name): 13 | return 'batch_normalization' not in name \ 14 | and curr_scope_name in name \ 15 | and 'instance' in name 16 | l2_loss = weight_decay * tf.add_n( 17 | [tf.nn.l2_loss(tf.cast(v, tf.float32)) 18 | for v in tf.trainable_variables() 19 | if exclude_batch_norm_and_other_device(v.name)]) 20 | loss_all = tf.add(loss, l2_loss) 21 | return loss_all 22 | 23 | 24 | def main(): 25 | # Parse arguments 26 | cfg = get_config() 27 | args = cfg.parse_args() 28 | os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu 29 | 30 | params = get_params_from_args(args) 31 | params['loss_params']['agg_func'] = reg_loss_in_faster 32 | cache_dir = os.path.join( 33 | args.cache_dir, 'models_tfutils', args.save_exp) 34 | params['save_params']['cache_dir'] = cache_dir 35 | base.train_from_params(**params) 36 | 37 | 38 | if __name__ == "__main__": 39 | main() 40 | -------------------------------------------------------------------------------- /trans_param_setter.py: -------------------------------------------------------------------------------- 1 | from __future__ import division, print_function, absolute_import 2 | import os, sys 3 | import numpy as np 4 | import tensorflow as tf 5 | 6 | import json 7 | import copy 8 | import argparse 9 | import time 10 | import functools 11 | import inspect 12 | 13 | from model import preprocessing as prep 14 | from model import instance_model 15 | from model.dataset_utils import dataset_func 16 | from utils import DATA_LEN_IMAGENET_FULL, tuple_get_one 17 | from param_setter import get_lr_from_boundary 18 | import config 19 | 20 | 21 | def get_config(): 22 | cfg = config.Config() 23 | cfg.add('batch_size', type=int, default=128, 24 | help='Training batch size') 25 | cfg.add('test_batch_size', type=int, default=64, 26 | help='Testing batch size') 27 | cfg.add('init_lr', type=float, default=0.01, 28 | help='Initial learning rate') 29 | cfg.add('gpu', type=str, required=True, 30 | help='Value for CUDA_VISIBLE_DEVICES') 31 | cfg.add('weight_decay', type=float, default=1e-4, 32 | help='Weight decay') 33 | cfg.add('image_dir', type=str, required=True, 34 | help='Directory containing dataset') 35 | cfg.add('q_cap', type=int, default=102400, 36 | help='Shuffle queue capacity of tfr data') 37 | cfg.add('ten_crop', type=bool, 38 | help='Whether to do ten-crop validation') 39 | 40 | # Loading parameters 41 | cfg.add('load_exp', type=str, required=True, 42 | help='The experiment to load from, in the format ' 43 | '[dbname]/[collname]/[exp_id]') 44 | cfg.add('load_step', type=int, default=None, 45 | help='Step number for loading') 46 | cfg.add('load_port', type=int, 47 | help='Port number of mongodb for loading (defaults to saving port)') 48 | cfg.add('resume', type=bool, 49 | help='Flag for loading from last step of this exp_id, will override' 50 | ' all other loading options.') 51 | 52 | # Saving parameters 53 | cfg.add('port', type=int, required=True, 54 | help='Port number for mongodb') 55 | cfg.add('host', type=str, default='localhost', 56 | help='Host for mongodb') 57 | cfg.add('save_exp', type=str, required=True, 58 | help='The [dbname]/[collname]/[exp_id] of this experiment.') 59 | cfg.add('cache_dir', type=str, required=True, 60 | help='Prefix of saving directory') 61 | cfg.add('fre_valid', type=int, default=10009, 62 | help='Frequency of validation') 63 | cfg.add('fre_metric', type=int, default=1000, 64 | help='Frequency of saving metrics') 65 | cfg.add('fre_filter', type=int, default=10009, 66 | help='Frequency of saving filters') 67 |
cfg.add('fre_cache_filter', type=int, 68 | help='Frequency of caching filters') 69 | 70 | # Training parameters 71 | cfg.add('model_type', type=str, default='resnet18', 72 | help='Model type, resnet or alexnet') 73 | cfg.add('get_all_layers', type=str, default=None, 74 | help='Whether get outputs for all layers') 75 | cfg.add('lr_boundaries', type=str, default=None, 76 | help='Learning rate boundaries for 10x drops') 77 | cfg.add('train_crop', type=str, default='default', 78 | help='Train crop style') 79 | cfg.add('num_classes', type=int, default=1000, 80 | help='Number of classes') 81 | return cfg 82 | 83 | 84 | def reg_loss(loss, weight_decay): 85 | # Add weight decay to the loss. 86 | def exclude_batch_norm_and_other_device(name): 87 | return 'batch_normalization' not in name \ 88 | and 'instance' in name 89 | l2_loss = weight_decay * tf.add_n( 90 | [tf.nn.l2_loss(tf.cast(v, tf.float32)) 91 | for v in tf.trainable_variables() 92 | if exclude_batch_norm_and_other_device(v.name)]) 93 | loss_all = tf.add(loss, l2_loss) 94 | return loss_all 95 | 96 | 97 | def add_training_params(params, args): 98 | NUM_BATCHES_PER_EPOCH = DATA_LEN_IMAGENET_FULL / args.batch_size 99 | 100 | # model_params: a function that will build the model 101 | model_params = { 102 | 'func': instance_model.build_transfer_targets, 103 | 'trainable_scopes': ['instance'], 104 | 'get_all_layers': args.get_all_layers, 105 | "model_type": args.model_type, 106 | "num_classes": args.num_classes, 107 | } 108 | multi_gpu = len(args.gpu.split(',')) 109 | if multi_gpu > 1: 110 | model_params['num_gpus'] = multi_gpu 111 | model_params['devices'] = ['/gpu:%i' % idx for idx in range(multi_gpu)] 112 | params['model_params'] = model_params 113 | 114 | # train_params: parameters about training data 115 | process_img_func = prep.resnet_train 116 | if args.train_crop == 'resnet_crop_flip': 117 | process_img_func = prep.resnet_crop_flip 118 | elif args.train_crop == 'alexnet_crop_flip': 119 | process_img_func = prep.alexnet_crop_flip 120 | elif args.train_crop == 'validate_crop': 121 | process_img_func = prep.resnet_validate 122 | 123 | train_data_param = { 124 | 'func': dataset_func, 125 | 'image_dir': args.image_dir, 126 | 'process_img_func': process_img_func, 127 | 'is_train': True, 128 | 'q_cap': args.q_cap, 129 | 'batch_size': args.batch_size} 130 | 131 | def _train_target_func( 132 | inputs, 133 | output, 134 | get_all_layers=None, 135 | *args, 136 | **kwargs): 137 | if not get_all_layers: 138 | return {'accuracy': output[1]} 139 | else: 140 | return {'accuracy': tf.reduce_mean(output[1].values())} 141 | 142 | params['train_params'] = { 143 | 'validate_first': False, 144 | 'data_params': train_data_param, 145 | 'queue_params': None, 146 | 'thres_loss': float('Inf'), 147 | 'num_steps': int(2000 * NUM_BATCHES_PER_EPOCH), 148 | 'targets': { 149 | 'func': _train_target_func, 150 | 'get_all_layers': args.get_all_layers, 151 | }, 152 | } 153 | 154 | # loss_params: parameters to build the loss 155 | def loss_func(output, *args, **kwargs): 156 | #print('loss_output', output) 157 | return output[0] 158 | params['loss_params'] = { 159 | 'pred_targets': [], 160 | # we don't want GPUs to calculate l2 loss separately 161 | 'agg_func': reg_loss, 162 | 'agg_func_kwargs': {'weight_decay': args.weight_decay}, 163 | 'loss_func': loss_func, 164 | } 165 | 166 | 167 | def add_validation_params(params, args): 168 | # validation_params: control the validation 169 | val_len = 50000 170 | valid_prep_func = prep.resnet_validate 171 | if args.ten_crop: 172 | 
valid_prep_func = prep.resnet_10crop_validate 173 | 174 | topn_val_data_param = { 175 | 'func': dataset_func, 176 | 'image_dir': args.image_dir, 177 | 'process_img_func': valid_prep_func, 178 | 'is_train': False, 179 | 'q_cap': args.test_batch_size, 180 | 'batch_size': args.test_batch_size} 181 | 182 | def online_agg(agg_res, res, step): 183 | if agg_res is None: 184 | agg_res = {k: [] for k in res} 185 | for k, v in res.items(): 186 | agg_res[k].append(np.mean(v)) 187 | return agg_res 188 | def valid_perf_func(inputs, output): 189 | if not args.get_all_layers: 190 | return {'top1': output} 191 | else: 192 | ret_dict = {} 193 | for key, each_out in output.items(): 194 | ret_dict['top1_{name}'.format(name=key)] = each_out 195 | return ret_dict 196 | 197 | topn_val_param = { 198 | 'data_params': topn_val_data_param, 199 | 'queue_params': None, 200 | 'targets': {'func': valid_perf_func}, 201 | # TODO: slight rounding error? 202 | 'num_steps': int(val_len/args.test_batch_size), 203 | 'agg_func': lambda x: {k: np.mean(v) for k, v in x.items()}, 204 | 'online_agg_func': online_agg, 205 | } 206 | params['validation_params'] = { 207 | 'topn': topn_val_param, 208 | } 209 | 210 | 211 | def add_save_and_load_params(params, args): 212 | # save_params: defining where to save the models 213 | db_name, col_name, exp_id = args.save_exp.split('/') 214 | cache_dir = os.path.join( 215 | args.cache_dir, 'models', 216 | db_name, col_name, exp_id) 217 | params['save_params'] = { 218 | 'host': 'localhost', # used for tfutils 219 | 'port': args.port, # used for tfutils 220 | 'dbname': db_name, 221 | 'collname': col_name, 222 | 'exp_id': exp_id, 223 | 'do_save': True, 224 | 'save_initial_filters': True, 225 | 'save_metrics_freq': args.fre_metric, 226 | 'save_valid_freq': args.fre_valid, 227 | 'save_filters_freq': args.fre_filter, 228 | 'cache_filters_freq': args.fre_cache_filter or args.fre_filter, 229 | 'cache_dir': cache_dir, 230 | } 231 | 232 | # load_params: defining where to load, if needed 233 | if args.resume or args.load_exp is None: 234 | load_exp = args.save_exp 235 | else: 236 | load_exp = args.load_exp 237 | load_dbname, load_collname, load_exp_id = load_exp.split('/') 238 | if args.resume or args.load_step is None: 239 | load_query = None 240 | else: 241 | load_query = { 242 | 'exp_id': load_exp_id, 243 | 'saved_filters': True, 244 | 'step': args.load_step 245 | } 246 | params['load_params'] = { 247 | 'host': 'localhost', # used for tfutils 248 | 'port': args.load_port or args.port, # used for tfutils 249 | 'dbname': load_dbname, 250 | 'collname': load_collname, 251 | 'exp_id': load_exp_id, 252 | 'do_restore': True, 253 | 'query': load_query, 254 | } 255 | 256 | 257 | def add_optimization_params(params, args): 258 | # learning_rate_params: build the learning rate 259 | # For now, just stay the same 260 | NUM_BATCHES_PER_EPOCH = DATA_LEN_IMAGENET_FULL / args.batch_size 261 | params['learning_rate_params'] = { 262 | 'func': get_lr_from_boundary, 263 | 'init_lr': args.init_lr, 264 | 'NUM_BATCHES_PER_EPOCH': NUM_BATCHES_PER_EPOCH, 265 | 'boundaries': args.lr_boundaries, 266 | } 267 | 268 | # optimizer_params 269 | params['optimizer_params'] = { 270 | 'optimizer': tf.train.MomentumOptimizer, 271 | 'momentum': .9, 272 | } 273 | 274 | 275 | def get_params_from_args(args): 276 | params = { 277 | 'skip_check': True, 278 | 'log_device_placement': False 279 | } 280 | 281 | add_training_params(params, args) 282 | add_save_and_load_params(params, args) 283 | add_optimization_params(params, args) 284 | 
add_validation_params(params, args) 285 | return params 286 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | from __future__ import division, print_function, absolute_import 2 | import os, sys, datetime 3 | import numpy as np 4 | import tensorflow as tf 5 | import copy 6 | import pdb 7 | from model.instance_model import DATA_LEN_IMAGENET_FULL 8 | 9 | 10 | def tuple_get_one(x): 11 | if isinstance(x, tuple) or isinstance(x, list): 12 | return x[0] 13 | return x 14 | 15 | 16 | class Logger(object): 17 | def __init__(self, path_prefix, exp_id): 18 | self.path_prefix = path_prefix 19 | self.exp_id = exp_id 20 | self.exp_path = exp_path = os.path.join(path_prefix, exp_id) 21 | os.system('mkdir -p %s' % exp_path) 22 | 23 | # Record info about this experiment 24 | config_path = os.path.join(exp_path, 'config_%s.txt' % exp_id) 25 | print('Storing experiment configs at %s' % config_path) 26 | # TODO: prompt if config already exists (indicating experiment has been run previously) 27 | with open(config_path, 'w') as f: 28 | f.write('Started running at: %s\n' % str(datetime.datetime.now())) 29 | # Record command used to run the training 30 | f.write(" ".join(sys.argv) + '\n') 31 | 32 | self._log_path = os.path.join(exp_path, 'log_performance_%s.txt' % exp_id) 33 | print('Logging results to %s' % self._log_path) 34 | self._writer = open(self._log_path, 'a+') 35 | 36 | def reopen(self): 37 | self._writer.close() 38 | self._writer = open(self._log_path, 'a+') 39 | 40 | def log(self, s, also_print=True): 41 | if also_print: 42 | print(s) 43 | sys.stdout.flush() 44 | self._writer.write(s + '\n') 45 | 46 | def close(self): 47 | self._writer.close() 48 | --------------------------------------------------------------------------------