├── hang.py ├── images │   ├── llama.jpeg │   ├── zebra.jpeg │   └── sealion.jpeg ├── resources │   └── optimization-objective-breakdown.png ├── metrics.py ├── LICENSE ├── checkpoint.py ├── README.md ├── datagenerator.py ├── utils.py ├── finetune.py ├── alexnet.py └── caffe_classes.py /hang.py: -------------------------------------------------------------------------------- 1 | print('This is a hanging script, ^C to quit') 2 | while True: pass 3 | -------------------------------------------------------------------------------- /images/llama.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shuheng-liu/siamese-optimization-for-neural-nets/HEAD/images/llama.jpeg -------------------------------------------------------------------------------- /images/zebra.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shuheng-liu/siamese-optimization-for-neural-nets/HEAD/images/zebra.jpeg -------------------------------------------------------------------------------- /images/sealion.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shuheng-liu/siamese-optimization-for-neural-nets/HEAD/images/sealion.jpeg -------------------------------------------------------------------------------- /resources/optimization-objective-breakdown.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/shuheng-liu/siamese-optimization-for-neural-nets/HEAD/resources/optimization-objective-breakdown.png -------------------------------------------------------------------------------- /metrics.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | 4 | class Metrics: 5 | def __init__(self, beta=0.9): 6 | self.metrics_dict = {} 7 | self.beta = beta 8 | 9 | def update_metric(self, metric, value): 10 | if value is None: 11 | value = 0. 12 | if metric in self.metrics_dict: 13 | self.metrics_dict[metric] = self.beta * self.metrics_dict[metric] + (1. - self.beta) * value  # exponential moving average 14 | else: 15 | self.metrics_dict[metric] = float(value) 16 | 17 | def update_metrics(self, metrics, values): 18 | if values is None: 19 | values = [0.] * len(metrics) 20 | for metric, value in zip(metrics, values): 21 | self.update_metric(metric, value) 22 | 23 | def write_metrics(self, metrics=None): 24 | """Write out the stored metrics in the format that the FloydHub parser expects 25 | for generating figures; see https://docs.floydhub.com/guides/jobs/metrics_dict/ for more 26 | information.""" 27 | if metrics is None: 28 | for metric in self.metrics_dict: 29 | sys.stdout.write('{"metric": "%s", "value": %f}\n' % (metric, self.metrics_dict[metric])) 30 | else: 31 | for metric in self.metrics_dict: 32 | if (metric in metrics) or (metric == metrics): 33 | sys.stdout.write('{"metric": "%s", "value": %f}\n' % (metric, self.metrics_dict[metric])) 34 | 35 | def get_metrics_dict(self): 36 | return self.metrics_dict 37 | 38 | def get_metrics_names(self): 39 | return list(self.metrics_dict.keys()) 40 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | BSD 3-Clause License 2 | 3 | Copyright (c) 2017, Frederik Kratzert 4 | All rights reserved.
5 | 6 | Redistribution and use in source and binary forms, with or without 7 | modification, are permitted provided that the following conditions are met: 8 | 9 | * Redistributions of source code must retain the above copyright notice, this 10 | list of conditions and the following disclaimer. 11 | 12 | * Redistributions in binary form must reproduce the above copyright notice, 13 | this list of conditions and the following disclaimer in the documentation 14 | and/or other materials provided with the distribution. 15 | 16 | * Neither the name of the copyright holder nor the names of its 17 | contributors may be used to endorse or promote products derived from 18 | this software without specific prior written permission. 19 | 20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 23 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 30 | -------------------------------------------------------------------------------- /checkpoint.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | import heapq 4 | from alexnet import AlexNet, SiameseAlexNet, Model 5 | from tensorflow.python.client.session import BaseSession 6 | 7 | 8 | class MemCache: 9 | def __init__(self, parameters, index, metric): 10 | self._parameters = parameters 11 | self._index = index 12 | self._metric = metric 13 | 14 | def get_parameters(self): 15 | return self._parameters 16 | 17 | def get_index(self): 18 | return self._index 19 | 20 | def get_metric(self): 21 | return self._metric 22 | 23 | def set_parameters(self, parameters): 24 | self._parameters = parameters 25 | 26 | def set_index(self, index): 27 | self._index = index 28 | 29 | def set_metric(self, metric): 30 | self._metric = metric 31 | 32 | 33 | class Checkpointer: 34 | def __init__(self, name: str, model: Model, save_path, higher_is_better=True, sess=None, mem_size=5): 35 | self.best = -1e10 if higher_is_better else 1e10 36 | self.best_type = "Highest" if higher_is_better else "Lowest" 37 | self.name = name 38 | self.model = model 39 | self.save_path = save_path 40 | self.higher_is_better = higher_is_better 41 | self.session = None 42 | # initialize self._mem_caches with a default placeholder entry 43 | self._mem_caches = [MemCache(model.get_model_vars(sess), -1, -1e10 if higher_is_better else 1e10)] 44 | self._mem_size = mem_size 45 | self.heaper_func = heapq.nlargest if higher_is_better else heapq.nsmallest 46 | if sess is not None: 47 | self.update_session(sess) 48 | 49 | def update_session(self, sess: BaseSession): 50 | assert isinstance(sess, BaseSession), "sess is not a TensorFlow Session" 51 | self.session = sess 52 | 53 | def add_memory_cache(self, mem_cache): 54 | if not isinstance(mem_cache, MemCache): 55 | print("Ignoring non-MemCache instance", mem_cache) 56 | return 57 |
self._mem_caches.append(mem_cache) 58 | self._mem_caches = self.heaper_func(self._mem_size, self._mem_caches, key=lambda cache: cache.get_metric()) 59 | 60 | def list_memory_caches(self): 61 | return self._mem_caches 62 | 63 | def update_best(self, value, epoch=0, checkpoint=True, mem_cache=False): 64 | if self._better_than_current_best(value): 65 | self._update_new_best(value, checkpoint=checkpoint) 66 | else: 67 | self._retain_current_best() 68 | 69 | if mem_cache: 70 | try: 71 | parameters = self.model.get_model_vars(self.session) 72 | self.add_memory_cache(MemCache(parameters, epoch, value)) 73 | except AttributeError as e: 74 | print(e) 75 | print("Default: not updating memory cache") 76 | 77 | def _better_than_current_best(self, value) -> bool: 78 | # a readable comparison that respects the improvement direction 79 | if self.higher_is_better: 80 | return value > self.best 81 | else: 82 | return value < self.best 83 | 84 | def _update_new_best(self, new_best, checkpoint=True): 85 | assert self.session is not None 86 | print(self.best_type, self.name, "updated {} ---> {}".format(self.best, new_best)) 87 | self.best = new_best 88 | 89 | if checkpoint: 90 | print("Saving checkpoint at", self.save_path) 91 | self.model.save_model_vars(self.save_path, self.session) 92 | print("Checkpoint Saved") 93 | else: 94 | print("Not Saving checkpoint due to configuration") 95 | 96 | def _retain_current_best(self): 97 | print(self.best_type, self.name, "remained {}".format(self.best)) 98 | 99 | 100 | if __name__ == "__main__": 101 | x = tf.placeholder(tf.float32, [None, 227, 227, 3], name="x") 102 | keep_prob = tf.placeholder(tf.float32, [], name="keep_prob") 103 | save_path = "/Users/liushuheng/Desktop/vars" 104 | name = "xent" 105 | net = AlexNet(x, keep_prob, 3, ['fc8']) 106 | checkpointer = Checkpointer(name, net, save_path, higher_is_better=False) 107 | with tf.Session() as sess: 108 | sess.run(tf.global_variables_initializer()) 109 | for epoch in range(20): 110 | checkpointer.update_session(sess) 111 | new_value = np.random.rand(1) 112 | print("\nnew value = {}".format(new_value)) 113 | checkpointer.update_best(new_value, epoch=epoch, checkpoint=False, mem_cache=True) 114 | 115 | print(checkpointer.best) 116 | mem_caches = checkpointer.list_memory_caches() 117 | # print(checkpointer.list_memory_caches()) 118 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | # Partial Siamese Training - An Example With AlexNet 4 | 5 | This project aims to partially replace pure gradient-based optimization with another technique. 6 | The technique is referred to as *Optimization Objective Breakdown*, as it breaks optimization down into two (or more) parts. 7 | 8 | ## How to Improve upon Gradient-Based Methods 9 | ### Illustration of the Proposed Method 10 | ![optimization objective breakdown](resources/optimization-objective-breakdown.png) 11 | ### Gradient-Based Methods 12 | Traditional methods of tuning a neural network rely heavily (and often solely) on gradients of the loss function w.r.t. the parameters, $\nabla_{\theta}l(y, \hat{y})$, where the loss function $l(\cdot, \cdot)$ is typically a cross-entropy. 13 | When the neural net is very deep, overfitting and vanishing gradients become major problems, especially when input samples are scarce or different labels share similar distributions.
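Concretely, every parameter update in such methods is a (stochastic) gradient step on the averaged loss, with learning rate $\eta$ (written here for concreteness; this repository uses TensorFlow's plain `GradientDescentOptimizer`):

$$\theta \leftarrow \theta - \eta \, \nabla_{\theta} \, \frac{1}{N} \sum_{i=1}^{N} l\big(y_i, \hat{y}_i\big)$$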
14 | 15 | ### Optimization Objective Breakdown 16 | With *Optimization Objective Breakdown*, we improve upon the gradient-based model by splitting the network into two parts, separated by a *latent layer*. 17 | We define a pair-wise loss function on representations in the *latent layer*; our choice is the Siamese Loss. 18 | - All layers prior to the *latent layer* are trained with the objective of minimizing the Siamese Loss on the *latent layer*. 19 | * In this phase, training does not involve the prediction layer; the latter part of the model is not used. 20 | - All layers after the *latent layer* are trained with the objective of minimizing cross-entropy on the *prediction layer*. 21 | * In this phase, only the second part of the model is trained (the first part is fixed). A minimal sketch of this two-phase flow is given below, just before the Usage section. 22 | 23 | ## Mathematical Intuition 24 | In our paper, *Classification of Citrus Canker on Small Datasets*, we mathematically proved that this technique can result in **linear separability** of representations on the *latent layer*, if the model is sufficiently trained. 25 | This intuition entails that the second part of the model only needs to draw a hyperplane for classification, a trivial task for any neural network. 26 | 27 | 28 | 29 | # Version Info 30 | The code is refactored for compatibility with the latest version of TensorFlow. 31 | You can find an explanation of the new input pipeline in a new [blog post](https://kratzert.github.io/2017/06/15/example-of-tensorflows-new-input-pipeline.html). 32 | You can use this code as before for finetuning AlexNet on your own dataset without OpenCV. 33 | 34 | This repository contains all the code needed to finetune [AlexNet](http://papers.nips.cc/paper/4824-imagenet-classification-with-deep-convolutional-neural-networks.pdf) on any arbitrary dataset. Besides the comments in the code itself, I also wrote an article, which you can find [here](https://kratzert.github.io/2017/02/24/finetuning-alexnet-with-tensorflow.html), with further explanation. 35 | 36 | All you need are the pretrained weights, which you can find [here](http://www.cs.toronto.edu/~guerzhoy/tf_alexnet/) or convert yourself from the caffe library using [caffe-to-tensorflow](https://github.com/ethereon/caffe-tensorflow). 37 | If you convert them on your own, take a look at the structure of the `.npy` weights file (dict of dicts or dict of lists). 38 | 39 | **Note**: I won't write too much of an explanation here, as I already wrote a long article about the entire code on my blog. 40 | 41 | ## Requirements 42 | 43 | - Python 3 44 | - TensorFlow >= 1.12rc0 45 | - Numpy 46 | 47 | 48 | ## TensorBoard support 49 | 50 | The code has TensorFlow summaries implemented so that you can follow the training progress in TensorBoard (see --logdir in the config section of `finetune.py`). 51 | 52 | ## Content 53 | 54 | - `alexnet.py`: Class with the graph definition of AlexNet. 55 | - `finetune.py`: Script to run the finetuning process. 56 | - `datagenerator.py`: Contains a wrapper class for the new input pipeline. 57 | - `caffe_classes.py`: List of the 1000 class names of ImageNet (copied from [here](http://www.cs.toronto.edu/~guerzhoy/tf_alexnet/)). 58 | - `validate_alexnet_on_imagenet.ipynb`: Notebook to test the correct implementation of AlexNet and the pretrained weights on some images from the ImageNet database. 59 | - `images/*`: contains three example images, needed for the notebook.
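In code, the two phases described above correspond to the two training ops in `finetune.py`. Abridged sketch (batching, summaries and checkpointing omitted; `sTrainOp`/`aTrainOp` each only receive gradients for their own set of trainable layers):

```python
# Phase 1 (Siamese): update only the layers before the latent layer,
# using pairs of batches and the pair-wise Siamese loss.
sess.run(sTrainOp, feed_dict={x1: img_batch1, x2: img_batch2,
                              y1: label_batch1, y2: label_batch2, keep_prob: 1.})

# Phase 2 (classification): update only the layers after the latent layer,
# starting from the parameters cached during phase 1, using cross-entropy.
sess.run(aTrainOp, feed_dict={x: img_batch, y: label_batch,
                              keep_prob: 1. - dropout_rate})
```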
60 | 61 | ## Usage 62 | 63 | All you need to touch is `finetune.py`, although I strongly recommend taking a look at the entire code of this repository. In the `finetune.py` script you will find a section of configuration settings you have to adapt to your problem. 64 | If you do not want to touch the code any further than necessary you have to provide two `.txt` files to the script (`train.txt` and `val.txt`). Each of them lists the complete paths to your train/val images together with the class numbers, in the following structure: 65 | 66 | ``` 67 | Example train.txt: 68 | /path/to/train/image1.png 0 69 | /path/to/train/image2.png 1 70 | /path/to/train/image3.png 2 71 | /path/to/train/image4.png 0 72 | . 73 | . 74 | ``` 75 | where the first column is the path and the second is the class label. 76 | 77 | The other option is that you bring your own method of loading images and providing batches of images and labels, but then you have to adapt the code in a few places. 78 | -------------------------------------------------------------------------------- /datagenerator.py: -------------------------------------------------------------------------------- 1 | # Created on Wed May 31 14:48:46 2017 2 | # 3 | # @author: Frederik Kratzert 4 | 5 | """Contains a helper class for image input pipelines in TensorFlow.""" 6 | 7 | import numpy as np 8 | import tensorflow as tf 9 | try: 10 | from tensorflow.contrib.data import Dataset 11 | except ImportError as e: 12 | print(e) 13 | print("importing Dataset from tf.data") 14 | Dataset = tf.data.Dataset 15 | from tensorflow.python.framework import dtypes 16 | from tensorflow.python.framework.ops import convert_to_tensor 17 | 18 | IMAGENET_MEAN = tf.constant([123.68, 116.779, 103.939], dtype=tf.float32) 19 | 20 | 21 | class ImageDataGenerator(object): 22 | """Wrapper class around TensorFlow's new dataset pipeline. 23 | 24 | Requires TensorFlow >= version 1.12rc0 25 | """ 26 | 27 | def __init__(self, txt_file, mode, batch_size, num_classes, shuffle=True, 28 | buffer_size=1000): 29 | """Create a new ImageDataGenerator. 30 | 31 | Receives the path of a text file in which each line contains the path 32 | of an image, followed by a space and an integer class number. Using 33 | this data, this class creates TensorFlow datasets that can be used to 34 | train 35 | e.g. a convolutional neural network. 36 | 37 | Args: 38 | txt_file: Path to the text file. 39 | mode: Either 'training' or 'inference'. Depending on this value, 40 | different parsing functions will be used. 41 | batch_size: Number of images per batch. 42 | num_classes: Number of classes in the dataset. 43 | shuffle: Whether or not to shuffle the data in the dataset and the 44 | initial file list. 45 | buffer_size: Number of images used as buffer for TensorFlow's 46 | shuffling of the dataset. 47 | 48 | Raises: 49 | ValueError: If an invalid mode is passed. 50 | 51 | """ 52 | self.txt_file = txt_file 53 | self.num_classes = num_classes 54 | self.BUFFER_SIZE = buffer_size 55 | self.BATCH_SIZE = batch_size 56 | 57 | # retrieve the data from the text file 58 | self._read_txt_file() 59 | # number of samples in the dataset 60 | self.data_size = len(self.labels) 61 | # initial shuffling of the file and label lists (together!)
62 | if shuffle: 63 | self._shuffle_lists() 64 | 65 | # convert lists to TF tensor 66 | self.img_paths = convert_to_tensor(self.img_paths, dtype=dtypes.string) 67 | self.labels = convert_to_tensor(self.labels, dtype=dtypes.int32) 68 | 69 | # create dataset 70 | data = Dataset.from_tensor_slices((self.img_paths, self.labels)) # type: Dataset 71 | 72 | # distinguish between training/inference when choosing the parsing function 73 | if mode == 'training': 74 | data = data.map(self._parse_function_train, num_threads=8, 75 | output_buffer_size=100 * batch_size)  # NOTE: tf.contrib.data keywords; plain tf.data expects num_parallel_calls instead 76 | 77 | elif mode == 'inference': 78 | data = data.map(self._parse_function_inference, num_threads=8, 79 | output_buffer_size=100 * batch_size) 80 | 81 | else: 82 | raise ValueError("Invalid mode '%s'." % (mode)) 83 | 84 | # shuffle the dataset, using a buffer of `buffer_size` elements 85 | if shuffle: 86 | data = data.shuffle(buffer_size=buffer_size) 87 | 88 | # TODO consider whether the following .repeat() method is necessary 89 | data = data.repeat() 90 | 91 | # create a new dataset with batches of images 92 | data = data.batch(batch_size) 93 | 94 | self.data = data 95 | 96 | def _read_txt_file(self): 97 | """Read the content of the text file and store it into lists.""" 98 | self.img_paths = [] 99 | self.labels = [] 100 | with open(self.txt_file, 'r') as f: 101 | lines = f.readlines() 102 | for line in lines: 103 | items = line.split(' ') 104 | self.img_paths.append(' '.join(items[:-1])) # in case the filename contains spaces 105 | self.labels.append(int(items[-1])) 106 | 107 | def _shuffle_lists(self): 108 | """Conjoined shuffling of the list of paths and labels.""" 109 | path = self.img_paths 110 | labels = self.labels 111 | permutation = np.random.permutation(self.data_size) 112 | self.img_paths = [] 113 | self.labels = [] 114 | for i in permutation: 115 | self.img_paths.append(path[i]) 116 | self.labels.append(labels[i]) 117 | 118 | def _parse_function_train(self, filename, label): 119 | """Input parser for samples of the training set.""" 120 | # convert label number into one-hot-encoding 121 | one_hot = tf.one_hot(label, self.num_classes) 122 | 123 | # load and preprocess the image 124 | img_string = tf.read_file(filename) 125 | img_decoded = tf.image.decode_jpeg(img_string, channels=3) 126 | img_resized = tf.image.resize_images(img_decoded, [227, 227]) 127 | """ 128 | Data augmentation comes here.
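An example of what could go here (illustrative only, not part of the original pipeline): img_resized = tf.image.random_flip_left_right(img_resized)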
129 | """ 130 | img_centered = tf.subtract(img_resized, IMAGENET_MEAN) 131 | 132 | # RGB -> BGR 133 | img_bgr = img_centered[:, :, ::-1] 134 | 135 | return img_bgr, one_hot 136 | 137 | def _parse_function_inference(self, filename, label): 138 | """Input parser for samples of the validation/test set.""" 139 | # convert label number into one-hot-encoding 140 | one_hot = tf.one_hot(label, self.num_classes) 141 | 142 | # load and preprocess the image 143 | img_string = tf.read_file(filename) 144 | img_decoded = tf.image.decode_jpeg(img_string, channels=3) 145 | img_resized = tf.image.resize_images(img_decoded, [227, 227]) 146 | img_centered = tf.subtract(img_resized, IMAGENET_MEAN) 147 | 148 | # RGB -> BGR 149 | img_bgr = img_centered[:, :, ::-1] 150 | 151 | return img_bgr, one_hot 152 | 153 | def reshuffle_data(self): 154 | self.data = self.data.shuffle(buffer_size=self.BUFFER_SIZE) # type: Dataset 155 | self.data = self.data.batch(self.BATCH_SIZE) # type: Dataset 156 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | from datagenerator import ImageDataGenerator 4 | 5 | 6 | # generate a txt file containing image paths and labels 7 | def make_list(folders, flags=None, ceils=None, mode='train', store_path='/output'): 8 | suffices = ('jpg', 'JPG', 'jpeg', 'JPEG', 'png', 'PNG') 9 | if ceils is None: ceils = [-1] * len(folders) # ceil constraint not imposed 10 | if flags is None: flags = list(range(len(folders))) # flags = [0, 1, ..., n-1] 11 | assert len(folders) == len(flags) == len(ceils), (len(folders), len(flags), len(ceils)) 12 | assert mode in ['train', 'val', 'test'] 13 | folders_flags_ceils = [tup for tup in zip(folders, flags, ceils) 14 | if isinstance(tup[0], str) and os.path.isdir(tup[0])] 15 | assert folders_flags_ceils 16 | 17 | print('Making %s list' % mode) 18 | for tup in folders_flags_ceils: 19 | print('Folder {}: flag = {}, ceil = {}'.format(*tup)) 20 | if not os.path.isdir(store_path): os.mkdir(store_path) 21 | out_list = os.path.join(store_path, mode + '.txt') 22 | list_length = 0 23 | with open(out_list, 'w') as fo: 24 | for (folder, flag, ceil) in folders_flags_ceils: 25 | count = 0 26 | for pic_name in os.listdir(folder): 27 | if pic_name.split('.')[-1] not in suffices: 28 | print('Ignoring non-image file {} in folder {}.'.format(pic_name, folder), 29 | 'Legal suffices are', suffices) 30 | continue 31 | count += 1 32 | list_length += 1 33 | fo.write("{} {}\n".format(os.path.join(folder, pic_name), flag)) 34 | # if ceil is imposed (ceil > 0) and count exceeds ceil, break and write next flag 35 | if 0 < ceil <= count: break 36 | print('%s list made\n' % mode) 37 | return out_list, list_length 38 | 39 | 40 | # find a suitable batchSize 41 | # TODO redefine auto_adatpt_batch() for suitable behaviour of Siamese Training 42 | def auto_adapt_batch(train_size, val_size, batch_count_multiple=1, max_size=256): 43 | """ 44 | returns a suitable batch size according to train and val dataset size, 45 | say max_size = 128, and val_size is smaller than train_size, 46 | if val_size < 128, the batch_size1 to be returned is val_size 47 | if 128 < val_size <= 256, the batch size is 1/2 of val_size, at most 1 validation sample cannot be used 48 | if 256 < val_size <= 384, the batch size is 1/3 of val_size, at most 2 validation samples cannot be used 49 | ... 
:param train_size: the number of training samples in the training set 51 | :param val_size: the number of validation samples in the validation set 52 | :param batch_count_multiple: force the batch count to be a multiple of this number, default = 1 53 | :param max_size: the maximum batch size that is allowed to be returned 54 | :return: a suitable batch size for the input 55 | """ 56 | print('Auto adapting batch size...') 57 | numerator = min(train_size, val_size) 58 | denominator = 0 59 | while True: 60 | denominator += batch_count_multiple 61 | batch_size = numerator // denominator 62 | if batch_size <= max_size: return batch_size 63 | 64 | 65 | # make train & val & test list, and do stats 66 | def determine_list(opt, sample_path): 67 | # paths 68 | train0, train1, train2 = opt.train0.split(), opt.train1.split(), opt.train2.split() 69 | val0, val1, val2 = opt.val0.split(), opt.val1.split(), opt.val2.split() 70 | train, val = train0 + train1 + train2, val0 + val1 + val2 71 | test = train + val 72 | # flags 73 | train_flags = [0] * len(train0) + [1] * len(train1) + [2] * len(train2) 74 | val_flags = [0] * len(val0) + [1] * len(val1) + [2] * len(val2) 75 | test_flags = train_flags + val_flags 76 | # ceils 77 | train_ceils = opt.trainCeils.split() if opt.trainCeils else [-1] * len(train_flags) 78 | train_ceils = [int(c) for c in train_ceils] 79 | val_ceils = opt.valCeils.split() if opt.valCeils else [-1] * len(val_flags) 80 | val_ceils = [int(c) for c in val_ceils] 81 | test_ceils = train_ceils + val_ceils 82 | # do list generating 83 | train_file, train_length = make_list(train, flags=train_flags, ceils=train_ceils, mode='train', 84 | store_path=sample_path) 85 | val_file, val_length = make_list(val, flags=val_flags, mode='val', ceils=val_ceils, store_path=sample_path) 86 | test_file, test_length = make_list(test, flags=test_flags, mode='test', ceils=test_ceils, store_path=sample_path) 87 | return train_file, train_length, val_file, val_length, test_file, test_length 88 | 89 | 90 | # print the info about training, validation and testing sets 91 | def print_info(batch_size, train_batches, val_batches, test_batches, train_length, val_length, 92 | test_length, model='AlexNet'): 93 | assert model in ['SiameseAlexNet', 'AlexNet'] 94 | 95 | print("********** In model %s **********" % model) 96 | print('%d samples in training set' % train_length) 97 | print('%d samples in validation set' % val_length) 98 | print('Train-Val ratio == %.1f%s : %.1f%s' % (100 * train_length / (train_length + val_length), '%', 99 | 100 * val_length / (train_length + val_length), '%')) 100 | print('Batch Size =', batch_size) 101 | print('Of all %d val samples, %d are utilized, percentage = %.1f%s' % (val_length, val_batches * batch_size, 102 | val_batches * batch_size / val_length 103 | * 100, '%')) 104 | print('Of all %d train samples, %d are utilized, percentage = %.1f%s' % (train_length, train_batches * batch_size, 105 | train_batches * batch_size / train_length 106 | * 100, '%')) 107 | print('Of all %d test samples, %d are utilized, percentage = %.1f%s' % (test_length, test_batches * batch_size, 108 | test_batches * batch_size / test_length 109 | * 100, '%')) 110 | 111 | 112 | def get_environment_parameters(): 113 | parser = argparse.ArgumentParser() 114 | parser.add_argument('--train0', required=True, help='paths to negative training dataset, separated by space') 115 | parser.add_argument('--train1', required=True, help='paths to positive training dataset, separated by space') 116 |
parser.add_argument('--train2', default='', help='paths to other disease training dataset, separated by space') 117 | parser.add_argument('--trainCeils', default=None, help='Ceils of Training') 118 | parser.add_argument('--val0', required=True, help='paths to negative validation dataset, separated by space') 119 | parser.add_argument('--val1', required=True, help='paths to positive validation dataset, separated by space') 120 | parser.add_argument('--val2', default='', help='paths to other disease validation dataset, separated by space') 121 | parser.add_argument('--valCeils', default=None, help='Ceils of validation') 122 | parser.add_argument('--lr1', type=float, default=3e-4, help='learning rate for supervised learning, default=3e-4') 123 | parser.add_argument('--lr2', type=float, default=5e-7, help='learning rate for siamese learning, default=5e-7') 124 | parser.add_argument('--nepochs1', type=int, default=100, help='number of supervised epochs, default = 100') 125 | parser.add_argument('--nepochs2', type=int, default=100, help='number of siamese epochs, default = 100') 126 | parser.add_argument('--batchSize1', type=int, default=0, help='default = automatic-adapting') 127 | parser.add_argument('--batchSize2', type=int, default=0, help='default = automatic-adapting') 128 | parser.add_argument('--dropout', type=float, default=0.5, help='dropout rate for alexnet, default = 0.5') 129 | parser.add_argument('--nclasses', type=int, default=2, help='number of classes, default = 2') 130 | parser.add_argument('--trainLayers1', type=str, default='fc7 fc8', help='default = fc7 fc8') 131 | parser.add_argument('--trainLayers2', type=str, default='fc6', help='default = fc6') 132 | parser.add_argument('--displayStep', type=int, default=200, help='How often to write tf.summary') 133 | parser.add_argument('--outf', type=str, default='/output', help='path for checkpoints & tf.summary & samplelist') 134 | parser.add_argument('--pretrained', type=str, default='/', help='path for pre-trained weights *.npy') 135 | parser.add_argument('--noCheck', action='store_true', help='don\'t save model checkpoints') 136 | parser.add_argument('--siamese', type=str, default='dropout6', help='siamese projection layers, default=dropout6') 137 | parser.add_argument('--checkStd', type=str, default='xent', help='Standard for checkpointing, acc or xent') 138 | parser.add_argument('--margin00', type=float, default=8.0, help='distance margin for neg-neg pair, default=8.0') 139 | parser.add_argument('--margin11', type=float, default=6.0, help='distance margin for pos-pos pair, default=6.0') 140 | parser.add_argument('--margin01', type=float, default=7.07, help='distance margin for neg-pos pair, default=7.07') 141 | parser.add_argument('--punish00', type=float, default=1.0, help='punishment for neg-neg pair, default=1.0') 142 | parser.add_argument('--punish11', type=float, default=1.0, help='punishment for pos-pos pair, default=1.0') 143 | parser.add_argument('--punish01', type=float, default=5.0, help='punishment for neg-pos pair, default=5.0') 144 | return parser.parse_args() 145 | 146 | 147 | def get_init_op(iterator, some_data: ImageDataGenerator): 148 | return iterator.make_initializer(some_data.data) 149 | 150 | 151 | def get_precision_recall_fscore(TP=0, TN=0, FP=0, FN=0, alpha=1.0): 152 | # convert to int instead of np.int32 153 | TP, TN, FP, FN = int(TP), int(TN), int(FP), int(FN) 154 | 155 | # get precision 156 | try: 157 | precision = TP / (TP + FP) 158 | except ZeroDivisionError: 159 | print('ZeroDivisionError in
calculating precision') 160 | precision = 0 161 | 162 | # get recall 163 | try: 164 | recall = TP / (TP + FN) 165 | except ZeroDivisionError: 166 | print('ZeroDivisionError in calculating recall') 167 | recall = 0 168 | 169 | # get F score 170 | alpha_squared = alpha ** 2 # calculate alpha^2 in advance 171 | try: 172 | fscore = (1 + alpha_squared) * TP / ((1 + alpha_squared) * TP + alpha_squared * FN + FP) 173 | except ZeroDivisionError: 174 | print('ZeroDivisionError in calculating F score') 175 | fscore = 0 176 | 177 | return precision, recall, fscore 178 | -------------------------------------------------------------------------------- /finetune.py: -------------------------------------------------------------------------------- 1 | """Script to finetune AlexNet using TensorFlow. 2 | 3 | With this script you can finetune AlexNet as provided in the alexnet.py 4 | class on any given dataset. Specify the configuration settings at the 5 | beginning according to your problem. 6 | This script was written for TensorFlow >= version 1.2rc0 and comes with a blog 7 | post, which you can find here: 8 | 9 | https://kratzert.github.io/2017/02/24/finetuning-alexnet-with-tensorflow.html 10 | 11 | Author: Frederik Kratzert 12 | contact: f.kratzert(at)gmail.com 13 | """ 14 | 15 | import math 16 | import os 17 | from datetime import datetime 18 | 19 | import numpy as np 20 | import tensorflow as tf 21 | try: 22 | from tensorflow.contrib.data import Iterator 23 | except ImportError as e: 24 | print(e) 25 | print("importing Iterator from tf.data instead") 26 | Iterator = tf.data.Iterator 27 | from tensorflow.python.framework.errors_impl import OutOfRangeError 28 | 29 | from alexnet import AlexNet, SiameseAlexNet 30 | from datagenerator import ImageDataGenerator 31 | from checkpoint import Checkpointer, MemCache 32 | from metrics import Metrics 33 | from utils import \ 34 | auto_adapt_batch, \ 35 | determine_list, \ 36 | print_info, \ 37 | get_environment_parameters, \ 38 | get_init_op, \ 39 | get_precision_recall_fscore 40 | 41 | """ 42 | Configuration Part. 43 | """ 44 | 45 | np.set_printoptions(threshold=np.inf) 46 | 47 | # Metrics objects whose values are written out for FloydHub to parse 48 | sMetrics = Metrics(beta=0.) 49 | aMetrics = Metrics(beta=0.)
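 # beta=0. disables the exponential smoothing in Metrics, so raw per-epoch values are written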
50 | 51 | opt = get_environment_parameters() 52 | print(opt) 53 | 54 | # Learning params 55 | aLR, sLR = opt.lr1, opt.lr2 56 | aEpochs, sEpochs = opt.nepochs1, opt.nepochs2 57 | 58 | # Network params 59 | dropout_rate = opt.dropout 60 | if opt.nclasses == 0: 61 | if opt.val2 and opt.train2: 62 | num_classes = 3 63 | else: 64 | num_classes = 2 65 | else: 66 | num_classes = opt.nclasses 67 | print('There are %d labels for classification' % num_classes) 68 | # train_layers = opt.trainLayers.split() 69 | aTrainLayers, sTrainLayers = set(opt.trainLayers1.split()), set(opt.trainLayers2.split()) 70 | # train_layers = aTrainLayers | sTrainLayers 71 | # How often we want to write the tf.summary data to disk 72 | display_step = opt.displayStep 73 | assert opt.checkStd in ['acc', 'xent'], 'Illegal check standard, %s' % opt.checkStd 74 | 75 | # Path for tf.summary.FileWriter and to store model checkpoints 76 | filewriter_path = opt.outf 77 | checkpoint_path = os.path.join(opt.outf, 'checkpoints') 78 | sample_path = os.path.join(opt.outf, 'samplelist') 79 | 80 | train_file, train_length, val_file, val_length, test_file, test_length = determine_list(opt, sample_path) 81 | 82 | batch_size1 = opt.batchSize1 if opt.batchSize1 else auto_adapt_batch(train_length, val_length) 83 | batch_size2 = opt.batchSize2 if opt.batchSize2 else auto_adapt_batch(train_length, val_length) // 2 84 | 85 | # TODO overall debugging, paying special attention to the instantiation of tf.Session 86 | """ 87 | Main Part of the finetuning Script. 88 | """ 89 | 90 | # Create parent path if it doesn't exist 91 | if not os.path.isdir(checkpoint_path): 92 | os.mkdir(checkpoint_path) 93 | 94 | # Place data loading and preprocessing on the CPU 95 | with tf.device('/cpu:0'): 96 | train_data = ImageDataGenerator(train_file, 97 | mode='training', 98 | batch_size=batch_size1, 99 | num_classes=num_classes, 100 | shuffle=True) 101 | val_data = ImageDataGenerator(val_file, 102 | mode='inference', 103 | batch_size=batch_size1, 104 | num_classes=num_classes, 105 | shuffle=True) 106 | test_data = ImageDataGenerator(test_file, 107 | mode='inference', 108 | batch_size=batch_size1, 109 | num_classes=num_classes, 110 | shuffle=True) 111 | 112 | # create a reinitializable iterator given the dataset structure 113 | iterator = Iterator.from_structure(train_data.data.output_types, 114 | train_data.data.output_shapes) 115 | next_batch = iterator.get_next() 116 | print('data loaded and preprocessed on the CPU') 117 | 118 | # Ops for initializing the two different iterators 119 | training_init_op = get_init_op(iterator, train_data) 120 | validation_init_op = get_init_op(iterator, val_data) 121 | testing_init_op = get_init_op(iterator, test_data) 122 | 123 | # TF placeholder for graph input and output 124 | # y = tf.placeholder(tf.float32, [None, num_classes]) 125 | keep_prob = tf.placeholder(tf.float32) 126 | 127 | # Initialize the FileWriter 128 | train_writer = tf.summary.FileWriter(os.path.join(filewriter_path, 'train')) 129 | val_writer = tf.summary.FileWriter(os.path.join(filewriter_path, 'val')) 130 | 131 | # Initialize a saver to store model checkpoints 132 | # saver = tf.train.Saver() 133 | 134 | # Get the number of training/validation steps per epoch 135 | aTrainBatches = int(train_data.data_size // batch_size1) 136 | aValBatches = int(val_data.data_size // batch_size1) 137 | aTestBatches = int(test_data.data_size // batch_size1) 138 | 139 | sTrainBatches = int(train_data.data_size // batch_size2) // 2 140 | sValBatches = int(val_data.data_size //
batch_size2) // 2 141 | sTestBatches = int(test_data.data_size // batch_size2) // 2 142 | 143 | print_info(batch_size1, aTrainBatches, aValBatches, aTestBatches, train_length, val_length, test_length, 144 | model='AlexNet') 145 | print_info(batch_size2, sTrainBatches * 2, sValBatches * 2, sTestBatches * 2, train_length, val_length, test_length, 146 | model='SiameseAlexNet') 147 | 148 | # instantiate an InteractiveSession, the only Session in the script 149 | sess = tf.InteractiveSession() 150 | 151 | # create and prepare a Siamese AlexNet 152 | # create 2 placeholders for the AlexNets nested in Siamese Net 153 | x1 = tf.placeholder(tf.float32, [None, 227, 227, 3], name='x1') 154 | x2 = tf.placeholder(tf.float32, [None, 227, 227, 3], name='x2') 155 | # Construct the Siamese Model 156 | sNet = SiameseAlexNet(x1, x2, keep_prob, num_classes, sTrainLayers, weights_path=opt.pretrained, 157 | margin00=opt.margin00, margin01=opt.margin01, margin11=opt.margin11, 158 | punish00=opt.punish00, punish01=opt.punish01, punish11=opt.punish11, 159 | proj=opt.siamese) 160 | y1, y2 = sNet.net1.y, sNet.net2.y # get the label placeholders of the Siamese Net 161 | sVars = [v for v in tf.trainable_variables() if v.name.split('/')[-2] in sTrainLayers] 162 | sLoss = sNet.loss # get loss tensor of the Siamese Net 163 | with tf.name_scope("siamese-train"): # define the train_op of Siamese Training 164 | sGradients = tf.gradients(sLoss, sVars) 165 | sGradients = list(zip(sGradients, sVars)) 166 | 167 | # drop entries whose gradient is None 168 | sGradients = [g_and_v for g_and_v in sGradients if g_and_v[0] is not None] 169 | sVars = [g_and_v[1] for g_and_v in sGradients] 170 | print('Siamese Variables are:', sVars) 171 | 172 | sOptimizer = tf.train.GradientDescentOptimizer(sLR) 173 | sTrainOp = sOptimizer.apply_gradients(grads_and_vars=sGradients) 174 | for gradient, var in sGradients: # write summary of gradients 175 | tf.summary.histogram(var.name + "/gradient-Siamese", gradient) 176 | for var in sVars: # write summary of variables 177 | tf.summary.histogram(var.name, var) 178 | sLoss_summ = tf.summary.scalar('siamese-loss', sLoss) # write summary of siamese loss 179 | 180 | # Siamese Training for `nepochs2` epochs 181 | sess.run(tf.global_variables_initializer()) 182 | train_writer.add_graph(sess.graph) 183 | sNet.load_model_pretrained(sess) 184 | # initialize a best-metrics update mechanism, i.e.
checkpointer 185 | ckpt_path = os.path.join(checkpoint_path, "siam.npy") 186 | checkpointer = Checkpointer("Siamese Loss(Val)", sNet, ckpt_path, higher_is_better=False, sess=sess) 187 | 188 | # batch sizes are halved in siamese training 189 | # dealt with NaN in siamese training 190 | # REVIEW Val-Loss is way larger than Train-Loss, possibly due to different dropout rate in train and val 191 | for epoch in range(sEpochs): 192 | print("------- Siamese Epoch number: {} ------- ".format(epoch + 1)) 193 | 194 | print("start training, %d batches in total" % sTrainBatches) 195 | sess.run(training_init_op) 196 | for step in range(sTrainBatches): 197 | try: 198 | img_batch1, label_batch1 = sess.run(next_batch) 199 | img_batch2, label_batch2 = sess.run(next_batch) 200 | sess.run([sTrainOp], feed_dict={x1: img_batch1, 201 | x2: img_batch2, 202 | y1: label_batch1, 203 | y2: label_batch2, 204 | keep_prob: 1.}) 205 | s, _loss = sess.run([sLoss_summ, sLoss], feed_dict={x1: img_batch1, 206 | x2: img_batch2, 207 | y1: label_batch1, 208 | y2: label_batch2, 209 | keep_prob: 1.}) 210 | print("step: %d, loss = %f" % (step, _loss)) 211 | count00, count11, count01, loss00, loss11, loss01 = \ 212 | sess.run([sNet.count00, sNet.count11, sNet.count01, sNet.loss00, sNet.loss11, sNet.loss01], 213 | feed_dict={x1: img_batch1, 214 | x2: img_batch2, 215 | y1: label_batch1, 216 | y2: label_batch2, 217 | keep_prob: 1.}) 218 | print('neg-neg-count =', count00, 'mean neg-neg loss =', loss00) 219 | print('pos-pos-count =', count11, 'mean pos-pos loss =', loss11) 220 | print('neg-pos-count =', count01, 'mean neg-pos loss =', loss01) 221 | if (epoch * sTrainBatches + step) % display_step == 0: 222 | train_writer.add_summary(s) 223 | except OutOfRangeError as e: 224 | print(e) 225 | print('ignoring residual batches in step %d' % step) 226 | except Exception as e: 227 | print(e) 228 | print('some other exception occurred in step %d' % step) 229 | print("start validation, %d batches in total" % sValBatches) 230 | sess.run(validation_init_op) 231 | val_loss = 0 232 | for step in range(sValBatches): 233 | try: 234 | img_batch1, label_batch1 = sess.run(next_batch) 235 | img_batch2, label_batch2 = sess.run(next_batch) 236 | count00, count11, count01, loss00, loss11, loss01 = \ 237 | sess.run([sNet.count00, sNet.count11, sNet.count01, sNet.loss00, sNet.loss11, sNet.loss01], 238 | feed_dict={x1: img_batch1, 239 | x2: img_batch2, 240 | y1: label_batch1, 241 | y2: label_batch2, 242 | keep_prob: 1.}) 243 | print('neg-neg-count =', count00, 'mean neg-neg loss =', loss00) 244 | print('pos-pos-count =', count11, 'mean pos-pos loss =', loss11) 245 | print('neg-pos-count =', count01, 'mean neg-pos loss =', loss01) 246 | s, step_loss = sess.run([sLoss_summ, sLoss], feed_dict={x1: img_batch1, 247 | x2: img_batch2, 248 | y1: label_batch1, 249 | y2: label_batch2, 250 | keep_prob: 1.}) 251 | print("step: %d, loss = %f" % (step, step_loss)) 252 | val_loss += step_loss 253 | if (epoch * sValBatches + step) % display_step == 0: 254 | val_writer.add_summary(s) 255 | except OutOfRangeError as e: 256 | print(e) 257 | print('ignoring residual batches in step %d' % step) 258 | except Exception as e: 259 | print(e) 260 | print('some other exception occurred in step %d' % step) 261 | val_loss /= sValBatches 262 | sMetrics.update_metric('siam-loss-val', val_loss) 263 | sMetrics.write_metrics() 264 | 265 | # do checkpointing 266 | checkpointer.update_best(val_loss, checkpoint=(not opt.noCheck), mem_cache=True, epoch=epoch) 267 | 268 | # re-shuffle the
training set to generate new pairs for siamese training 269 | train_data.reshuffle_data() 270 | 271 | # after training, save the parameters corresponding to the lowest losses 272 | mem_caches = checkpointer.list_memory_caches() 273 | 274 | # get and prepare the AlexNet of that Siamese Net 275 | # up to `mem_size` parameter sets with the lowest losses were kept; the commented code below would use only the first one 276 | # with tf.Session() as sess: 277 | # sNet.set_model_vars(mem_caches[0].get_parameters(), sess) 278 | 279 | x = tf.placeholder(tf.float32, [None, 227, 227, 3], name="x") # get the input placeholder 280 | # sess.run(tf.global_variables_initializer()) 281 | aNet = sNet.get_net_copy(sess, x=x, train_layers=aTrainLayers) # get an AlexNet with trained Variables 282 | # aNet = AlexNet(x, keep_prob, num_classes, aTrainLayers, weights_path=opt.pretrained) 283 | y = aNet.y # grab the label placeholder of the AlexNet 284 | aVars = [v for v in tf.trainable_variables() if v.name.split('/')[-2] in aTrainLayers] 285 | # get the metrics of the AlexNet 286 | aLoss, accuracy, precision, recall, F_alpha = aNet.loss, aNet.accuracy, aNet.precision, aNet.recall, aNet.F_alpha 287 | with tf.name_scope("classification-train"): 288 | aGradients = tf.gradients(aLoss, aVars) 289 | aGradients = list(zip(aGradients, aVars)) 290 | 291 | # drop entries whose gradient is None 292 | aGradients = [g_and_v for g_and_v in aGradients if g_and_v[0] is not None] 293 | aVars = [g_and_v[1] for g_and_v in aGradients] 294 | print('AlexNet Variables are:', aVars) 295 | 296 | aOptimizer = tf.train.GradientDescentOptimizer(aLR) 297 | aTrainOp = aOptimizer.apply_gradients(grads_and_vars=aGradients) 298 | for gradient, var in aGradients: 299 | tf.summary.histogram(var.name + "/gradients-AlexNet", gradient) 300 | for var in aVars: 301 | tf.summary.histogram(var.name, var) 302 | # get the metric summaries of the AlexNet 303 | aLoss_summ = tf.summary.scalar('xent-loss', aLoss) 304 | accuracy_summ = tf.summary.scalar('accuracy', accuracy) 305 | precision_summ = tf.summary.scalar('precision', precision) 306 | recall_summ = tf.summary.scalar('recall', recall) 307 | F_alpha_summ = tf.summary.scalar('F_alpha', F_alpha) 308 | alexnet_summ = tf.summary.merge([aLoss_summ, accuracy_summ, precision_summ, recall_summ, F_alpha_summ]) 309 | 310 | 311 | # Classification Training for `nepochs1` epochs 312 | def alexnet_training(aNet, param_set_id): 313 | global sess, ckpt_path, checkpointer, epoch, step, s, val_loss, step_loss 314 | # REVIEW maybe some variables should be initialized in this function 315 | # sess.run(tf.global_variables_initializer()) 316 | train_writer.add_graph(sess.graph) 317 | # since the aNet is obtained from sNet, pre-trained weights do not need to be loaded 318 | 319 | ckpt_path = os.path.join(checkpoint_path, "alex.npy") 320 | checkpointer = Checkpointer(opt.checkStd, aNet, ckpt_path, higher_is_better=(opt.checkStd == "acc"), sess=sess) 321 | 322 | for epoch in range(aEpochs): 323 | print("------- Parameter Set: {}, AlexNet Epoch number: {} ------- ".format(param_set_id, epoch + 1)) 324 | 325 | print("start training, %d batches in total" % aTrainBatches) 326 | sess.run(training_init_op) 327 | for step in range(aTrainBatches): 328 | img_batch, label_batch = sess.run(next_batch) 329 | sess.run(aTrainOp, feed_dict={x: img_batch, y: label_batch, keep_prob: 1. - dropout_rate}) 330 | # TP, TN, FP, FN = sess.run([aNet.TP, aNet.TN, aNet.FP, aNet.FN], 331 | # feed_dict={x: img_batch, y: label_batch, keep_prob: 1.
- dropout_rate}) 332 | # print("TP = %d, TN = %d, FP = %d, FN = %d" % (TP, TN, FP, FN)) 333 | if (epoch * aTrainBatches + step) % display_step == 0: 334 | # print("TP = %d, TN = %d, FP = %d, FN = %d" % (TP, TN, FP, FN)) 335 | s = sess.run(alexnet_summ, feed_dict={x: img_batch, y: label_batch, keep_prob: 1 - dropout_rate}) 336 | train_writer.add_summary(s) 337 | 338 | print("start validation, %d batches in total" % aValBatches) 339 | sess.run(validation_init_op) 340 | # calculate precision, recall and F_score only once in each epoch 341 | val_loss, val_acc, val_precision, val_recall, val_F_alpha = 0., 0., 0., 0., 0. 342 | val_TP, val_TN, val_FP, val_FN = 0, 0, 0, 0 343 | for step in range(aValBatches): 344 | img_batch, label_batch = sess.run(next_batch) 345 | step_TP, step_TN, step_FP, step_FN = sess.run([aNet.TP, aNet.TN, aNet.FP, aNet.FN], 346 | feed_dict={x: img_batch, y: label_batch, keep_prob: 1.}) 347 | 348 | s, step_loss, step_acc = sess.run([alexnet_summ, aLoss, accuracy], 349 | feed_dict={x: img_batch, y: label_batch, keep_prob: 1.}) 350 | val_loss += step_loss 351 | val_acc += step_acc 352 | val_TP += step_TP 353 | val_TN += step_TN 354 | val_FP += step_FP 355 | val_FN += step_FN 356 | if (epoch * aValBatches + step) % display_step == 0: 357 | val_writer.add_summary(s) 358 | print("TP = %d, TN = %d, FP = %d, FN = %d" % (val_TP, val_TN, val_FP, val_FN)) 359 | val_loss /= aValBatches 360 | val_acc /= aValBatches 361 | val_precision, val_recall, val_F_alpha = get_precision_recall_fscore(val_TP, val_TN, val_FP, val_FN, 1.0) 362 | metric_names = ['val_loss', 'val_acc', 'val_precision', 'val_recall', 'val_F_alpha'] 363 | aMetrics.update_metrics(metric_names, [val_loss, val_acc, val_precision, val_recall, val_F_alpha]) 364 | aMetrics.write_metrics() 365 | 366 | # do checkpointing 367 | checkpointer.update_best(val_loss, checkpoint=(not opt.noCheck)) 368 | 369 | # re-shuffle data to observe model behaviour 370 | train_data.reshuffle_data() 371 | 372 | 373 | for index, mem_cache in enumerate(mem_caches): # type: int, MemCache 374 | print("+" * 15, "proceeding with parameter set No. %d" % (index + 1), "+" * 15) 375 | aNet.set_model_vars(mem_cache.get_parameters(), sess) 376 | 377 | alexnet_training(aNet, index) 378 | -------------------------------------------------------------------------------- /alexnet.py: -------------------------------------------------------------------------------- 1 | """This is a TensorFlow implementation of AlexNet by Alex Krizhevsky et al. 2 | 3 | Paper: 4 | (http://papers.nips.cc/paper/4824-imagenet-classification-with-deep-convolutional-neural-networks.pdf) 5 | 6 | Explanation can be found in my blog post: 7 | https://kratzert.github.io/2017/02/24/finetuning-alexnet-with-tensorflow.html 8 | 9 | This script enables finetuning AlexNet on any given dataset with any number of 10 | classes.
The structure of this script is strongly inspired by the fast.ai 11 | Deep Learning class by Jeremy Howard and Rachel Thomas, especially their vgg16 12 | finetuning script: 13 | Link: 14 | - https://github.com/fastai/courses/blob/master/deeplearning1/nbs/vgg16.py 15 | 16 | 17 | The pretrained weights can be downloaded here and should be placed in the same 18 | folder as this file: 19 | - http://www.cs.toronto.edu/~guerzhoy/tf_alexnet/ 20 | 21 | @author: Frederik Kratzert (contact: f.kratzert(at)gmail.com) 22 | """ 23 | 24 | import numpy as np 25 | import tensorflow as tf 26 | from abc import abstractmethod 27 | 28 | 29 | class Model(object): 30 | @abstractmethod 31 | def __init__(self): 32 | pass 33 | 34 | @abstractmethod 35 | def set_model_vars(self, variable_dict, session): 36 | pass 37 | 38 | @abstractmethod 39 | def get_model_vars(self, session, init=False): 40 | return {} 41 | 42 | @abstractmethod 43 | def load_model_vars(self, path: str, session): 44 | pass 45 | 46 | @abstractmethod 47 | def save_model_vars(self, path: str, session, init=False): 48 | pass 49 | 50 | @abstractmethod 51 | def load_model_pretrained(self, session): 52 | pass 53 | 54 | @abstractmethod 55 | def _create_loss(self, *args): 56 | pass 57 | 58 | 59 | # noinspection PyCompatibility 60 | class AlexNet(Model): 61 | """Implementation of the AlexNet.""" 62 | TRAIN_LAYERS = ... # type: set 63 | y = ... # type: tf.placeholder 64 | 65 | # ATTENTION: loading pretrained weights is called outside the constructor 66 | def __init__(self, x, keep_prob, num_classes, train_layers, falpha=2.0, 67 | weights_path='/pretrained/bvlc_alexnet.npy'): 68 | """Create the graph of the AlexNet model. 69 | 70 | Args: 71 | x: Placeholder for the input tensor. 72 | keep_prob: Probability of keeping a unit (the dropout keep probability). 73 | num_classes: Number of classes in the dataset. 74 | train_layers: List of names of the layers that get trained from 75 | scratch 76 | weights_path: Complete path to the pretrained weight file, if it 77 | isn't in the same folder as this code 78 | """ 79 | # Parse input arguments into class variables 80 | super(AlexNet, self).__init__() 81 | self.X = x 82 | # self.X = tf.placeholder(tf.float32, [None, 227, 227, 3]) 83 | self.NUM_CLASSES = num_classes 84 | self.KEEP_PROB = keep_prob 85 | self.TRAIN_LAYERS = train_layers 86 | self.WEIGHTS_PATH = weights_path 87 | self.ALPHA = falpha 88 | 89 | # Call the create function to build the computational graph of AlexNet 90 | # with tf.variable_scope('') as scope: 91 | 92 | self._create_discriminator() 93 | 94 | # define metrics 95 | # (the second dimension of the label placeholder is self.NUM_CLASSES) 96 | self.y = tf.placeholder(tf.float32, [None, self.NUM_CLASSES], name='y') 97 | self.correct_pred = tf.equal(tf.argmax(self.fc8, 1), tf.argmax(self.y, 1)) 98 | self.accuracy = tf.reduce_mean(tf.cast(self.correct_pred, tf.float32), name='accuracy') 99 | 100 | self._create_loss() 101 | self._create_stats(falpha) 102 | 103 | def _create_discriminator(self): 104 | """Create the network graph.
The layer tensors (conv1 ... fc8) are stored as attributes.""" 105 | # 1st Layer: Conv (w ReLu) -> Lrn -> Pool 106 | conv1 = conv(self.X, 11, 11, 96, 4, 4, padding='VALID', name='conv1') 107 | norm1 = lrn(conv1, 2, 1e-05, 0.75, name='norm1') 108 | pool1 = max_pool(norm1, 3, 3, 2, 2, padding='VALID', name='pool1') 109 | self.conv1, self.norm1, self.pool1 = conv1, norm1, pool1 110 | 111 | # 2nd Layer: Conv (w ReLu) -> Lrn -> Pool with 2 groups 112 | conv2 = conv(pool1, 5, 5, 256, 1, 1, groups=2, name='conv2') 113 | norm2 = lrn(conv2, 2, 1e-05, 0.75, name='norm2') 114 | pool2 = max_pool(norm2, 3, 3, 2, 2, padding='VALID', name='pool2') 115 | self.conv2, self.norm2, self.pool2 = conv2, norm2, pool2 116 | 117 | # 3rd Layer: Conv (w ReLu) 118 | conv3 = conv(pool2, 3, 3, 384, 1, 1, name='conv3') 119 | self.conv3 = conv3 120 | 121 | # 4th Layer: Conv (w ReLu) split into two groups 122 | conv4 = conv(conv3, 3, 3, 384, 1, 1, groups=2, name='conv4') 123 | self.conv4 = conv4 124 | 125 | # 5th Layer: Conv (w ReLu) -> Pool split into two groups 126 | conv5 = conv(conv4, 3, 3, 256, 1, 1, groups=2, name='conv5') 127 | pool5 = max_pool(conv5, 3, 3, 2, 2, padding='VALID', name='pool5') 128 | self.conv5, self.pool5 = conv5, pool5 129 | 130 | # 6th Layer: Flatten -> FC (w ReLu) -> Dropout 131 | flattened = tf.reshape(pool5, [-1, 6 * 6 * 256]) 132 | fc6 = fc(flattened, 6 * 6 * 256, 4096, name='fc6') 133 | dropout6 = dropout(fc6, self.KEEP_PROB, name='dropout6') 134 | self.flattened, self.fc6, self.dropout6 = flattened, fc6, dropout6 135 | 136 | # 7th Layer: FC (w ReLu) -> Dropout 137 | fc7 = fc(dropout6, 4096, 4096, name='fc7') 138 | dropout7 = dropout(fc7, self.KEEP_PROB, name='dropout7') 139 | self.fc7, self.dropout7 = fc7, dropout7 140 | 141 | # 8th Layer: FC and return unscaled activations 142 | fc8 = fc(dropout7, 4096, self.NUM_CLASSES, relu=False, name='fc8') 143 | self.fc8 = fc8 144 | 145 | def _create_loss(self): 146 | with tf.name_scope("cross_ent"): 147 | self.loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=self.fc8, labels=self.y), 148 | name="loss") 149 | 150 | def load_model_pretrained(self, session): 151 | """Load weights from file into network. 152 | 153 | As the weights from http://www.cs.toronto.edu/~guerzhoy/tf_alexnet/ 154 | come as a dict of lists (e.g. weights['conv1'] is a list) and not as 155 | dict of dicts (e.g.
weights['conv1'] is a dict with keys 'weights' & 156 | 'biases') we need a special load function 157 | """ 158 | # Load the weights into memory 159 | variable_dict = np.load(self.WEIGHTS_PATH, encoding='bytes').item() # type: dict 160 | # Loop over all layer names stored in the weights dict 161 | for op_name in variable_dict: # type: str 162 | # Check if layer should be trained from scratch 163 | if op_name not in self.TRAIN_LAYERS: 164 | with tf.variable_scope(op_name, reuse=True): 165 | # Assign weights/biases to their corresponding tf variable 166 | for data in variable_dict[op_name]: 167 | var_name = "biases" if len(data.shape) == 1 else "weights" 168 | var = tf.get_variable(var_name, trainable=False) 169 | try: 170 | session.run(var.assign(data)) 171 | except Exception: 172 | print("Failed to assign value to", var.name) 173 | 174 | def _create_stats(self, alpha): 175 | """only works for binary classification""" 176 | prediction = tf.argmax(self.fc8, axis=1, name='alexnet-prediction') 177 | ground_truth = tf.argmax(self.y, axis=1, name='alexnet-ground-truth') 178 | self.prediction, self.ground_truth = prediction, ground_truth 179 | self.TP = tf.reduce_sum(prediction * ground_truth) # True Positive 180 | self.TN = tf.reduce_sum((1 - prediction) * (1 - ground_truth)) # True Negative 181 | self.FP = tf.reduce_sum(prediction * (1 - ground_truth)) # False Positive 182 | self.FN = tf.reduce_sum((1 - prediction) * ground_truth) # False Negative 183 | self.precision = self.TP / (self.TP + self.FP) 184 | self.recall = self.TP / (self.TP + self.FN) 185 | self.F_alpha = (1 + alpha) / (1 / self.precision + alpha / self.recall) 186 | if self.NUM_CLASSES != 2: 187 | print("Warning: precision, recall and F_alpha scores do not apply to multi-class classification") 188 | 189 | def get_model_vars(self, session, init=False): 190 | """returns a dict of variables in the model, with keys being layer names and values being lists of np.arrays""" 191 | if init: 192 | session.run(tf.global_variables_initializer()) 193 | layers = ['conv1', 'conv2', 'conv3', 'conv4', 'conv5', 'fc6', 'fc7', 'fc8'] 194 | variable_dict = {layer: [] for layer in layers} 195 | for layer in variable_dict: 196 | with tf.variable_scope(layer, reuse=True): 197 | for var_name in ["weights", "biases"]: 198 | var = tf.get_variable(var_name) 199 | variable_dict[layer].append(session.run(var)) 200 | return variable_dict 201 | 202 | def set_model_vars(self, variable_dict, session): 203 | """assign model variables with values from a dict passed""" 204 | for op_name in variable_dict: 205 | with tf.variable_scope(op_name, reuse=True): 206 | for data in variable_dict[op_name]: 207 | var_name = "biases" if len(data.shape) == 1 else "weights" 208 | # in case set_model_vars() is called before load_model_pretrained(), set trainable 209 | var = tf.get_variable(var_name, trainable=op_name in self.TRAIN_LAYERS) 210 | session.run(var.assign(data)) 211 | 212 | def save_model_vars(self, path: str, session, init=False): 213 | """save model var-value dict under passed path""" 214 | np.save(path, self.get_model_vars(session, init=init)) 215 | 216 | def load_model_vars(self, path: str, session): 217 | """load model var-value from passed path""" 218 | variable_dict = np.load(path, encoding="bytes").item() # type: dict 219 | self.set_model_vars(variable_dict, session) 220 | 221 | 222 | class SiameseAlexNet(Model): 223 | def __init__(self, x1, x2, keep_prob, num_classes, train_layers, name_scope="Siamese", proj="flattened", 224 | falpha=2.0, margin00=3.5, margin01=7.0,
    def get_model_vars(self, session, init=False):
        """Return a dict of variables in the model, with keys being layer
        names and values being lists of np.ndarrays."""
        if init:
            session.run(tf.global_variables_initializer())
        layers = ['conv1', 'conv2', 'conv3', 'conv4', 'conv5', 'fc6', 'fc7', 'fc8']
        variable_dict = {layer: [] for layer in layers}
        for layer in variable_dict:
            with tf.variable_scope(layer, reuse=True):
                for var_name in ["weights", "biases"]:
                    var = tf.get_variable(var_name)
                    variable_dict[layer].append(session.run(var))
        return variable_dict

    def set_model_vars(self, variable_dict, session):
        """Assign model variables with values from the dict passed in."""
        for op_name in variable_dict:
            with tf.variable_scope(op_name, reuse=True):
                for data in variable_dict[op_name]:
                    var_name = 'biases' if len(data.shape) == 1 else "weights"
                    # in case set_model_vars() is called before load_model_pretrained(), set trainable
                    var = tf.get_variable(var_name, trainable=op_name in self.TRAIN_LAYERS)
                    session.run(var.assign(data))

    def save_model_vars(self, path: str, session, init=False):
        """Save the model var-value dict under the passed path."""
        np.save(path, self.get_model_vars(session, init=init))

    def load_model_vars(self, path: str, session):
        """Load model var-values from the passed path."""
        variable_dict = np.load(path, encoding="bytes").item()  # type: dict
        self.set_model_vars(variable_dict, session)


class SiameseAlexNet(Model):
    def __init__(self, x1, x2, keep_prob, num_classes, train_layers, name_scope="Siamese", proj="flattened",
                 falpha=2.0, margin00=3.5, margin01=7.0, margin11=8.0, weights_path='/pretrained/bvlc_alexnet.npy',
                 punish00=1.0, punish11=1.0, punish01=5.0):
        super(SiameseAlexNet, self).__init__()
        self.name_scope = name_scope
        self.margin00 = margin00
        self.margin01 = margin01
        self.margin11 = margin11
        self.punish00 = punish00
        self.punish11 = punish11
        self.punish01 = punish01
        self.proj = proj
        with tf.variable_scope(self.name_scope) as scope:
            self.net1 = AlexNet(x1, keep_prob, num_classes, train_layers, falpha=falpha, weights_path=weights_path)
            scope.reuse_variables()
            self.net2 = AlexNet(x2, keep_prob, num_classes, train_layers, falpha=falpha, weights_path=weights_path)
        # define a loss for the Siamese network
        self._create_loss(proj)

    def _create_loss(self, proj):
        # XXX punish the pos-neg loss harder than the pos-pos and neg-neg losses to avoid underfitting
        proj1, proj2 = self._get_projections(proj)
        eucd2 = tf.reduce_mean((proj1 - proj2) ** 2, axis=1, name="euclidean_dist_squared")
        eucd = tf.sqrt(eucd2, name="euclidean_dist")
        print('euclidean distances tensor', eucd)
        # y1, y2 and y_cmp should arguably be class members
        y1 = tf.cast(tf.argmax(self.net1.y, axis=1), tf.float32, name='siam-y1')
        y2 = tf.cast(tf.argmax(self.net2.y, axis=1), tf.float32, name='siam-y2')
        self.y1_label, self.y2_label = y1, y2
        y_diff = tf.cast(y1 - y2, tf.bool, name="comparison_label_in_tf.bool")
        y_diff = tf.cast(y_diff, tf.float32, name="comparison_label_in_tf.float32")
        self.count01 = tf.reduce_sum(y_diff, name='count01')
        self.count00 = tf.reduce_sum((1 - y1) * (1 - y2), name='count00')
        self.count11 = tf.reduce_sum(y1 * y2, name='count11')

        # if label1 and label2 are the same, y_diff = 0; punish the part where eucd exceeds the margin
        loss00 = tf.reduce_mean(((1 - y1) * (1 - y2) * tf.nn.relu(eucd - self.margin00)) ** 2, axis=0, name='loss00')
        loss11 = tf.reduce_mean((y1 * y2 * tf.nn.relu(eucd - self.margin11)) ** 2, axis=0, name='loss11')
        self.mean_dist00 = tf.reduce_sum((1 - y1) * (1 - y2) * eucd) / self.count00
        self.mean_dist11 = tf.reduce_sum(y1 * y2 * eucd) / self.count11

        # if label1 and label2 are different, y_diff = 1; punish the part where eucd falls short of the margin
        loss01 = tf.reduce_mean((y_diff * tf.nn.relu(self.margin01 - eucd)) ** 2, axis=0, name='loss01')
        self.mean_dist01 = tf.reduce_sum(y_diff * eucd) / self.count01

        self.loss00 = loss00 * self.punish00
        self.loss01 = loss01 * self.punish01
        self.loss11 = loss11 * self.punish11
        self.loss = tf.add(self.loss00 + self.loss11, self.loss01, name="siamese-loss")
        print(self.loss)
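
    # Geometry of the loss above, briefly (an added remark): same-class pairs
    # are penalized only once their projection distance exceeds margin00
    # (neg-neg) or margin11 (pos-pos), while different-class pairs are
    # penalized when their distance falls short of margin01. With the default
    # margin01 = 7.0, a pos-neg pair at distance 5.0 contributes
    # relu(7.0 - 5.0)**2 = 4.0 to the batch mean in loss01, which is then
    # scaled by punish01. Two caveats: eucd2 is a *mean* over feature
    # dimensions, i.e. the squared Euclidean distance scaled by 1/dim, and the
    # mean_dist* statistics divide by the count* tensors, which are zero
    # (hence NaN) whenever a batch contains no pair of that kind.
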
print("ValueError: encountered in _get_predictions") 294 | print(e) 295 | finally: 296 | print("projections of %s are " % self.name_scope, projections[0].name, projections[1].name) 297 | print("dimensions of projection is", projections[0].shape, projections[1].shape) 298 | return projections 299 | 300 | def load_model_pretrained(self, session): 301 | with tf.variable_scope(self.name_scope, reuse=True): 302 | self.net1.load_model_pretrained(session) 303 | 304 | def load_model_vars(self, path: str, session): 305 | """load model var-value from passed path""" 306 | with tf.variable_scope(self.name_scope, reuse=True): 307 | self.net1.load_model_vars(path, session) 308 | 309 | def save_model_vars(self, path: str, session, init=False): 310 | """save model var-value dict under passed path""" 311 | with tf.variable_scope(self.name_scope): 312 | self.net1.save_model_vars(path, session, init=init) 313 | 314 | def get_model_vars(self, session, init=False): 315 | """returns a dict of variables in the model, with keys being layer names and values being list of np.arrays""" 316 | with tf.variable_scope(self.name_scope): 317 | return self.net1.get_model_vars(session, init=init) 318 | 319 | def set_model_vars(self, variable_dict, session): 320 | """assign model variables with values from a dict passed""" 321 | with tf.variable_scope(self.name_scope): 322 | self.net1.set_model_vars(variable_dict, session) 323 | 324 | # return a new instance of AlexNet with trainable variables 325 | def get_net_copy(self, session, x=None, keep_prob=None, num_classes=None, train_layers=None, falpha=None, 326 | weights_path=None) -> AlexNet: 327 | if x is None: 328 | x = self.net1.X 329 | print("Warning: x should be specified as a new placeholder") 330 | if keep_prob is None: 331 | keep_prob = self.net1.KEEP_PROB 332 | if num_classes is None: 333 | num_classes = self.net1.NUM_CLASSES 334 | if train_layers is None: 335 | train_layers = self.net1.TRAIN_LAYERS 336 | print("Warning: train_layers should be specified as a new list of layer names") 337 | if falpha is None: 338 | falpha = self.net1.ALPHA 339 | if weights_path is None: 340 | weights_path = self.net1.WEIGHTS_PATH 341 | new_net = AlexNet(x, keep_prob, num_classes, train_layers, falpha=falpha, weights_path=weights_path) 342 | new_net.set_model_vars(self.get_model_vars(session), session) 343 | return new_net 344 | 345 | 346 | def conv(x, filter_height, filter_width, num_filters, stride_y, stride_x, name, padding='SAME', groups=1): 347 | """Create a convolution layer. 
def conv(x, filter_height, filter_width, num_filters, stride_y, stride_x, name, padding='SAME', groups=1):
    """Create a convolution layer.

    Adapted from: https://github.com/ethereon/caffe-tensorflow
    """
    # Get the number of input channels
    input_channels = int(x.get_shape()[-1])

    # Create a lambda function for the convolution
    convolve = lambda i, k: tf.nn.conv2d(i, k,
                                         strides=[1, stride_y, stride_x, 1],
                                         padding=padding)

    with tf.variable_scope(name) as scope:
        # Create tf variables for the weights and biases of the conv layer;
        # integer division keeps the shape integral when groups > 1
        weights = tf.get_variable('weights', shape=[filter_height,
                                                    filter_width,
                                                    input_channels // groups,
                                                    num_filters])
        biases = tf.get_variable('biases', shape=[num_filters])

        if groups == 1:
            conv = convolve(x, weights)

        # In the case of multiple groups, split inputs & weights and convolve them separately
        else:
            input_groups = tf.split(axis=3, num_or_size_splits=groups, value=x)
            weight_groups = tf.split(axis=3, num_or_size_splits=groups,
                                     value=weights)
            output_groups = [convolve(i, k) for i, k in zip(input_groups, weight_groups)]

            # Concat the convolved outputs together again
            conv = tf.concat(axis=3, values=output_groups)

        # Add the biases
        bias = tf.reshape(tf.nn.bias_add(conv, biases), tf.shape(conv))

        # Apply the relu function
        relu = tf.nn.relu(bias, name=scope.name)

        return relu


def fc(x, num_in, num_out, name, relu=True):
    """Create a fully connected layer."""
    with tf.variable_scope(name) as scope:
        # Create tf variables for the weights and biases
        weights = tf.get_variable('weights', shape=[num_in, num_out],
                                  trainable=True)
        biases = tf.get_variable('biases', [num_out], trainable=True)

        # Matrix-multiply weights and inputs and add the bias
        act = tf.nn.xw_plus_b(x, weights, biases, name=scope.name)

    if relu:
        # Apply the ReLu non-linearity
        return tf.nn.relu(act)
    else:
        return act


def max_pool(x, filter_height, filter_width, stride_y, stride_x, name, padding='SAME'):
    """Create a max-pooling layer."""
    return tf.nn.max_pool(x, ksize=[1, filter_height, filter_width, 1],
                          strides=[1, stride_y, stride_x, 1],
                          padding=padding, name=name)


def lrn(x, radius, alpha, beta, name, bias=1.0):
    """Create a local response normalization layer."""
    return tf.nn.local_response_normalization(x, depth_radius=radius,
                                              alpha=alpha, beta=beta,
                                              bias=bias, name=name)


def dropout(x, keep_prob, name='dropout'):
    """Create a dropout layer."""
    return tf.nn.dropout(x, keep_prob, name=name)


def test1():
    keep_prob = tf.placeholder(tf.float32, [], name='keep_prob')
    x = tf.placeholder(tf.float32, [None, 227, 227, 3], name='x')
    x1 = tf.placeholder(tf.float32, [None, 227, 227, 3], name='x1')
    x2 = tf.placeholder(tf.float32, [None, 227, 227, 3], name='x2')
    num_classes = 2
    sTrainLayers = ['fc6']
    aTrainLayers = ['fc7', 'fc8']
    sNet = SiameseAlexNet(x1, x2, keep_prob, num_classes, sTrainLayers)
    sess = tf.InteractiveSession()
    # sNet.load_model_pretrained(session=sess)
    print(sNet.net1.conv1)

    # fixed: get_net_copy() requires a session so the copied variable values can be read
    aNet = sNet.get_net_copy(sess, train_layers=aTrainLayers)
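
# A hypothetical round-trip check for the var-dict helpers above, mirroring
# the commented-out experiment in __main__ below; the save path is an
# assumption.
def test2(path="/tmp/alexnet_vars.npy"):
    keep_prob = tf.placeholder(tf.float32, [], name='keep_prob')
    x = tf.placeholder(tf.float32, [None, 227, 227, 3], name='x')
    net = AlexNet(x, keep_prob, 2, ['fc7', 'fc8'])
    with tf.Session() as sess:
        net.save_model_vars(path, sess, init=True)  # initialize, then snapshot to disk
        net.load_model_vars(path, sess)             # restore the snapshot into the graph
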
if __name__ == "__main__":
    # How do the two nets in the Siamese network share the keep_prob placeholder?
    # In the commented-out example below, the keep_prob argument passed to the
    # constructor is a plain float (0.5) instead of a placeholder.
    keep_prob = tf.placeholder(tf.float32, [], name='keep_prob')
    x = tf.placeholder(tf.float32, [None, 227, 227, 3], name='x')
    x1 = tf.placeholder(tf.float32, [None, 227, 227, 3], name='x1')
    x2 = tf.placeholder(tf.float32, [None, 227, 227, 3], name='x2')
    # image_batch = np.random.rand(5, 227, 227, 3)
    # label_batch = np.random.rand(5, 1000)
    net = AlexNet(x, keep_prob, 2, ['fc6', 'fc7'])
    # netA = SiameseAlexNet(x1, x2, 0.5, 3, ['fc6', 'fc7', 'fc8'], name_scope="SiameseA", proj="flattened")
    # netB = SiameseAlexNet(x1, x2, 0.5, 3, ['fc6', 'fc7', 'fc8'], name_scope="SiameseB")
    # check_path = "/Users/liushuheng/Desktop/vars.npy"
    # with tf.Session() as sess:
    #     sess.run(tf.global_variables_initializer())
    #     net.load_model_pretrained(sess)
    #     y1 = sess.run(netA.net1.y, feed_dict={netA.net1.X: image_batch, netA.net1.y: label_batch})
    #     y2 = sess.run(netB.net1.y, feed_dict={netB.net1.X: image_batch, netB.net1.y: label_batch})
    #     netA.save_model_vars(check_path, sess)
    #     netB.load_model_vars(check_path, sess)
    #     y3 = sess.run(netB.net1.y, feed_dict={netB.net1.X: image_batch, netB.net1.y: label_batch})
    #     assert (y1 == y2).all(), "assertion1 failed"
    #     print("assertion1 passed")
    #     assert (y1 == y3).all(), "assertion2 failed"
    #     print("assertion2 passed")
    #     d = net.get_model_vars(sess)
    # init_weights = np.load("/pretrained/bvlc_alexnet.npy", encoding="bytes").item()

    # for var in tf.global_variables():
    # # for var in tf.get_default_graph().get_operations():
    #     print(var.name, end=" ")
--------------------------------------------------------------------------------
/caffe_classes.py:
--------------------------------------------------------------------------------
1 | class_names = '''tench, Tinca tinca 2 | goldfish, Carassius auratus 3 | great white shark, white shark, man-eater, man-eating shark, Carcharodon carcharias 4 | tiger shark, Galeocerdo cuvieri 5 | hammerhead, hammerhead shark 6 | electric ray, crampfish, numbfish, torpedo 7 | stingray 8 | cock 9 | hen 10 | ostrich, Struthio camelus 11 | brambling, Fringilla montifringilla 12 | goldfinch, Carduelis carduelis 13 | house finch, linnet, Carpodacus mexicanus 14 | junco, snowbird 15 | indigo bunting, indigo finch, indigo bird, Passerina cyanea 16 | robin, American robin, Turdus migratorius 17 | bulbul 18 | jay 19 | magpie 20 | chickadee 21 | water ouzel, dipper 22 | kite 23 | bald eagle, American eagle, Haliaeetus leucocephalus 24 | vulture 25 | great grey owl, great gray owl, Strix nebulosa 26 | European fire salamander, Salamandra salamandra 27 | common newt, Triturus vulgaris 28 | eft 29 | spotted salamander, Ambystoma maculatum 30 | axolotl, mud puppy, Ambystoma mexicanum 31 | bullfrog, Rana catesbeiana 32 | tree frog, tree-frog 33 | tailed frog, bell toad, ribbed toad, tailed toad, Ascaphus trui 34 | loggerhead, loggerhead turtle, Caretta caretta 35 | leatherback turtle, leatherback, leathery turtle, Dermochelys coriacea 36 | mud turtle 37 | terrapin 38 | box turtle, box tortoise 39 | banded gecko 40 | common iguana, iguana, Iguana iguana 41 | American chameleon, anole, Anolis carolinensis 42 | whiptail, whiptail lizard 43 | agama 44 | frilled lizard, Chlamydosaurus kingi 45 | alligator lizard 46 | Gila monster, Heloderma suspectum 47 | green lizard, Lacerta viridis 48 | African chameleon,
Chamaeleo chamaeleon 49 | Komodo dragon, Komodo lizard, dragon lizard, giant lizard, Varanus komodoensis 50 | African crocodile, Nile crocodile, Crocodylus niloticus 51 | American alligator, Alligator mississipiensis 52 | triceratops 53 | thunder snake, worm snake, Carphophis amoenus 54 | ringneck snake, ring-necked snake, ring snake 55 | hognose snake, puff adder, sand viper 56 | green snake, grass snake 57 | king snake, kingsnake 58 | garter snake, grass snake 59 | water snake 60 | vine snake 61 | night snake, Hypsiglena torquata 62 | boa constrictor, Constrictor constrictor 63 | rock python, rock snake, Python sebae 64 | Indian cobra, Naja naja 65 | green mamba 66 | sea snake 67 | horned viper, cerastes, sand viper, horned asp, Cerastes cornutus 68 | diamondback, diamondback rattlesnake, Crotalus adamanteus 69 | sidewinder, horned rattlesnake, Crotalus cerastes 70 | trilobite 71 | harvestman, daddy longlegs, Phalangium opilio 72 | scorpion 73 | black and gold garden spider, Argiope aurantia 74 | barn spider, Araneus cavaticus 75 | garden spider, Aranea diademata 76 | black widow, Latrodectus mactans 77 | tarantula 78 | wolf spider, hunting spider 79 | tick 80 | centipede 81 | black grouse 82 | ptarmigan 83 | ruffed grouse, partridge, Bonasa umbellus 84 | prairie chicken, prairie grouse, prairie fowl 85 | peacock 86 | quail 87 | partridge 88 | African grey, African gray, Psittacus erithacus 89 | macaw 90 | sulphur-crested cockatoo, Kakatoe galerita, Cacatua galerita 91 | lorikeet 92 | coucal 93 | bee eater 94 | hornbill 95 | hummingbird 96 | jacamar 97 | toucan 98 | drake 99 | red-breasted merganser, Mergus serrator 100 | goose 101 | black swan, Cygnus atratus 102 | tusker 103 | echidna, spiny anteater, anteater 104 | platypus, duckbill, duckbilled platypus, duck-billed platypus, Ornithorhynchus anatinus 105 | wallaby, brush kangaroo 106 | koala, koala bear, kangaroo bear, native bear, Phascolarctos cinereus 107 | wombat 108 | jellyfish 109 | sea anemone, anemone 110 | brain coral 111 | flatworm, platyhelminth 112 | nematode, nematode worm, roundworm 113 | conch 114 | snail 115 | slug 116 | sea slug, nudibranch 117 | chiton, coat-of-mail shell, sea cradle, polyplacophore 118 | chambered nautilus, pearly nautilus, nautilus 119 | Dungeness crab, Cancer magister 120 | rock crab, Cancer irroratus 121 | fiddler crab 122 | king crab, Alaska crab, Alaskan king crab, Alaska king crab, Paralithodes camtschatica 123 | American lobster, Northern lobster, Maine lobster, Homarus americanus 124 | spiny lobster, langouste, rock lobster, crawfish, crayfish, sea crawfish 125 | crayfish, crawfish, crawdad, crawdaddy 126 | hermit crab 127 | isopod 128 | white stork, Ciconia ciconia 129 | black stork, Ciconia nigra 130 | spoonbill 131 | flamingo 132 | little blue heron, Egretta caerulea 133 | American egret, great white heron, Egretta albus 134 | bittern 135 | crane 136 | limpkin, Aramus pictus 137 | European gallinule, Porphyrio porphyrio 138 | American coot, marsh hen, mud hen, water hen, Fulica americana 139 | bustard 140 | ruddy turnstone, Arenaria interpres 141 | red-backed sandpiper, dunlin, Erolia alpina 142 | redshank, Tringa totanus 143 | dowitcher 144 | oystercatcher, oyster catcher 145 | pelican 146 | king penguin, Aptenodytes patagonica 147 | albatross, mollymawk 148 | grey whale, gray whale, devilfish, Eschrichtius gibbosus, Eschrichtius robustus 149 | killer whale, killer, orca, grampus, sea wolf, Orcinus orca 150 | dugong, Dugong dugon 151 | sea lion 152 | Chihuahua 153 | Japanese spaniel 154 
| Maltese dog, Maltese terrier, Maltese 155 | Pekinese, Pekingese, Peke 156 | Shih-Tzu 157 | Blenheim spaniel 158 | papillon 159 | toy terrier 160 | Rhodesian ridgeback 161 | Afghan hound, Afghan 162 | basset, basset hound 163 | beagle 164 | bloodhound, sleuthhound 165 | bluetick 166 | black-and-tan coonhound 167 | Walker hound, Walker foxhound 168 | English foxhound 169 | redbone 170 | borzoi, Russian wolfhound 171 | Irish wolfhound 172 | Italian greyhound 173 | whippet 174 | Ibizan hound, Ibizan Podenco 175 | Norwegian elkhound, elkhound 176 | otterhound, otter hound 177 | Saluki, gazelle hound 178 | Scottish deerhound, deerhound 179 | Weimaraner 180 | Staffordshire bullterrier, Staffordshire bull terrier 181 | American Staffordshire terrier, Staffordshire terrier, American pit bull terrier, pit bull terrier 182 | Bedlington terrier 183 | Border terrier 184 | Kerry blue terrier 185 | Irish terrier 186 | Norfolk terrier 187 | Norwich terrier 188 | Yorkshire terrier 189 | wire-haired fox terrier 190 | Lakeland terrier 191 | Sealyham terrier, Sealyham 192 | Airedale, Airedale terrier 193 | cairn, cairn terrier 194 | Australian terrier 195 | Dandie Dinmont, Dandie Dinmont terrier 196 | Boston bull, Boston terrier 197 | miniature schnauzer 198 | giant schnauzer 199 | standard schnauzer 200 | Scotch terrier, Scottish terrier, Scottie 201 | Tibetan terrier, chrysanthemum dog 202 | silky terrier, Sydney silky 203 | soft-coated wheaten terrier 204 | West Highland white terrier 205 | Lhasa, Lhasa apso 206 | flat-coated retriever 207 | curly-coated retriever 208 | golden retriever 209 | Labrador retriever 210 | Chesapeake Bay retriever 211 | German short-haired pointer 212 | vizsla, Hungarian pointer 213 | English setter 214 | Irish setter, red setter 215 | Gordon setter 216 | Brittany spaniel 217 | clumber, clumber spaniel 218 | English springer, English springer spaniel 219 | Welsh springer spaniel 220 | cocker spaniel, English cocker spaniel, cocker 221 | Sussex spaniel 222 | Irish water spaniel 223 | kuvasz 224 | schipperke 225 | groenendael 226 | malinois 227 | briard 228 | kelpie 229 | komondor 230 | Old English sheepdog, bobtail 231 | Shetland sheepdog, Shetland sheep dog, Shetland 232 | collie 233 | Border collie 234 | Bouvier des Flandres, Bouviers des Flandres 235 | Rottweiler 236 | German shepherd, German shepherd dog, German police dog, alsatian 237 | Doberman, Doberman pinscher 238 | miniature pinscher 239 | Greater Swiss Mountain dog 240 | Bernese mountain dog 241 | Appenzeller 242 | EntleBucher 243 | boxer 244 | bull mastiff 245 | Tibetan mastiff 246 | French bulldog 247 | Great Dane 248 | Saint Bernard, St Bernard 249 | Eskimo dog, husky 250 | malamute, malemute, Alaskan malamute 251 | Siberian husky 252 | dalmatian, coach dog, carriage dog 253 | affenpinscher, monkey pinscher, monkey dog 254 | basenji 255 | pug, pug-dog 256 | Leonberg 257 | Newfoundland, Newfoundland dog 258 | Great Pyrenees 259 | Samoyed, Samoyede 260 | Pomeranian 261 | chow, chow chow 262 | keeshond 263 | Brabancon griffon 264 | Pembroke, Pembroke Welsh corgi 265 | Cardigan, Cardigan Welsh corgi 266 | toy poodle 267 | miniature poodle 268 | standard poodle 269 | Mexican hairless 270 | timber wolf, grey wolf, gray wolf, Canis lupus 271 | white wolf, Arctic wolf, Canis lupus tundrarum 272 | red wolf, maned wolf, Canis rufus, Canis niger 273 | coyote, prairie wolf, brush wolf, Canis latrans 274 | dingo, warrigal, warragal, Canis dingo 275 | dhole, Cuon alpinus 276 | African hunting dog, hyena dog, Cape hunting dog, 
Lycaon pictus 277 | hyena, hyaena 278 | red fox, Vulpes vulpes 279 | kit fox, Vulpes macrotis 280 | Arctic fox, white fox, Alopex lagopus 281 | grey fox, gray fox, Urocyon cinereoargenteus 282 | tabby, tabby cat 283 | tiger cat 284 | Persian cat 285 | Siamese cat, Siamese 286 | Egyptian cat 287 | cougar, puma, catamount, mountain lion, painter, panther, Felis concolor 288 | lynx, catamount 289 | leopard, Panthera pardus 290 | snow leopard, ounce, Panthera uncia 291 | jaguar, panther, Panthera onca, Felis onca 292 | lion, king of beasts, Panthera leo 293 | tiger, Panthera tigris 294 | cheetah, chetah, Acinonyx jubatus 295 | brown bear, bruin, Ursus arctos 296 | American black bear, black bear, Ursus americanus, Euarctos americanus 297 | ice bear, polar bear, Ursus Maritimus, Thalarctos maritimus 298 | sloth bear, Melursus ursinus, Ursus ursinus 299 | mongoose 300 | meerkat, mierkat 301 | tiger beetle 302 | ladybug, ladybeetle, lady beetle, ladybird, ladybird beetle 303 | ground beetle, carabid beetle 304 | long-horned beetle, longicorn, longicorn beetle 305 | leaf beetle, chrysomelid 306 | dung beetle 307 | rhinoceros beetle 308 | weevil 309 | fly 310 | bee 311 | ant, emmet, pismire 312 | grasshopper, hopper 313 | cricket 314 | walking stick, walkingstick, stick insect 315 | cockroach, roach 316 | mantis, mantid 317 | cicada, cicala 318 | leafhopper 319 | lacewing, lacewing fly 320 | dragonfly, darning needle, devil's darning needle, sewing needle, snake feeder, snake doctor, mosquito hawk, skeeter hawk 321 | damselfly 322 | admiral 323 | ringlet, ringlet butterfly 324 | monarch, monarch butterfly, milkweed butterfly, Danaus plexippus 325 | cabbage butterfly 326 | sulphur butterfly, sulfur butterfly 327 | lycaenid, lycaenid butterfly 328 | starfish, sea star 329 | sea urchin 330 | sea cucumber, holothurian 331 | wood rabbit, cottontail, cottontail rabbit 332 | hare 333 | Angora, Angora rabbit 334 | hamster 335 | porcupine, hedgehog 336 | fox squirrel, eastern fox squirrel, Sciurus niger 337 | marmot 338 | beaver 339 | guinea pig, Cavia cobaya 340 | sorrel 341 | zebra 342 | hog, pig, grunter, squealer, Sus scrofa 343 | wild boar, boar, Sus scrofa 344 | warthog 345 | hippopotamus, hippo, river horse, Hippopotamus amphibius 346 | ox 347 | water buffalo, water ox, Asiatic buffalo, Bubalus bubalis 348 | bison 349 | ram, tup 350 | bighorn, bighorn sheep, cimarron, Rocky Mountain bighorn, Rocky Mountain sheep, Ovis canadensis 351 | ibex, Capra ibex 352 | hartebeest 353 | impala, Aepyceros melampus 354 | gazelle 355 | Arabian camel, dromedary, Camelus dromedarius 356 | llama 357 | weasel 358 | mink 359 | polecat, fitch, foulmart, foumart, Mustela putorius 360 | black-footed ferret, ferret, Mustela nigripes 361 | otter 362 | skunk, polecat, wood pussy 363 | badger 364 | armadillo 365 | three-toed sloth, ai, Bradypus tridactylus 366 | orangutan, orang, orangutang, Pongo pygmaeus 367 | gorilla, Gorilla gorilla 368 | chimpanzee, chimp, Pan troglodytes 369 | gibbon, Hylobates lar 370 | siamang, Hylobates syndactylus, Symphalangus syndactylus 371 | guenon, guenon monkey 372 | patas, hussar monkey, Erythrocebus patas 373 | baboon 374 | macaque 375 | langur 376 | colobus, colobus monkey 377 | proboscis monkey, Nasalis larvatus 378 | marmoset 379 | capuchin, ringtail, Cebus capucinus 380 | howler monkey, howler 381 | titi, titi monkey 382 | spider monkey, Ateles geoffroyi 383 | squirrel monkey, Saimiri sciureus 384 | Madagascar cat, ring-tailed lemur, Lemur catta 385 | indri, indris, Indri indri, Indri 
brevicaudatus 386 | Indian elephant, Elephas maximus 387 | African elephant, Loxodonta africana 388 | lesser panda, red panda, panda, bear cat, cat bear, Ailurus fulgens 389 | giant panda, panda, panda bear, coon bear, Ailuropoda melanoleuca 390 | barracouta, snoek 391 | eel 392 | coho, cohoe, coho salmon, blue jack, silver salmon, Oncorhynchus kisutch 393 | rock beauty, Holocanthus tricolor 394 | anemone fish 395 | sturgeon 396 | gar, garfish, garpike, billfish, Lepisosteus osseus 397 | lionfish 398 | puffer, pufferfish, blowfish, globefish 399 | abacus 400 | abaya 401 | academic gown, academic robe, judge's robe 402 | accordion, piano accordion, squeeze box 403 | acoustic guitar 404 | aircraft carrier, carrier, flattop, attack aircraft carrier 405 | airliner 406 | airship, dirigible 407 | altar 408 | ambulance 409 | amphibian, amphibious vehicle 410 | analog clock 411 | apiary, bee house 412 | apron 413 | ashcan, trash can, garbage can, wastebin, ash bin, ash-bin, ashbin, dustbin, trash barrel, trash bin 414 | assault rifle, assault gun 415 | backpack, back pack, knapsack, packsack, rucksack, haversack 416 | bakery, bakeshop, bakehouse 417 | balance beam, beam 418 | balloon 419 | ballpoint, ballpoint pen, ballpen, Biro 420 | Band Aid 421 | banjo 422 | bannister, banister, balustrade, balusters, handrail 423 | barbell 424 | barber chair 425 | barbershop 426 | barn 427 | barometer 428 | barrel, cask 429 | barrow, garden cart, lawn cart, wheelbarrow 430 | baseball 431 | basketball 432 | bassinet 433 | bassoon 434 | bathing cap, swimming cap 435 | bath towel 436 | bathtub, bathing tub, bath, tub 437 | beach wagon, station wagon, wagon, estate car, beach waggon, station waggon, waggon 438 | beacon, lighthouse, beacon light, pharos 439 | beaker 440 | bearskin, busby, shako 441 | beer bottle 442 | beer glass 443 | bell cote, bell cot 444 | bib 445 | bicycle-built-for-two, tandem bicycle, tandem 446 | bikini, two-piece 447 | binder, ring-binder 448 | binoculars, field glasses, opera glasses 449 | birdhouse 450 | boathouse 451 | bobsled, bobsleigh, bob 452 | bolo tie, bolo, bola tie, bola 453 | bonnet, poke bonnet 454 | bookcase 455 | bookshop, bookstore, bookstall 456 | bottlecap 457 | bow 458 | bow tie, bow-tie, bowtie 459 | brass, memorial tablet, plaque 460 | brassiere, bra, bandeau 461 | breakwater, groin, groyne, mole, bulwark, seawall, jetty 462 | breastplate, aegis, egis 463 | broom 464 | bucket, pail 465 | buckle 466 | bulletproof vest 467 | bullet train, bullet 468 | butcher shop, meat market 469 | cab, hack, taxi, taxicab 470 | caldron, cauldron 471 | candle, taper, wax light 472 | cannon 473 | canoe 474 | can opener, tin opener 475 | cardigan 476 | car mirror 477 | carousel, carrousel, merry-go-round, roundabout, whirligig 478 | carpenter's kit, tool kit 479 | carton 480 | car wheel 481 | cash machine, cash dispenser, automated teller machine, automatic teller machine, automated teller, automatic teller, ATM 482 | cassette 483 | cassette player 484 | castle 485 | catamaran 486 | CD player 487 | cello, violoncello 488 | cellular telephone, cellular phone, cellphone, cell, mobile phone 489 | chain 490 | chainlink fence 491 | chain mail, ring mail, mail, chain armor, chain armour, ring armor, ring armour 492 | chain saw, chainsaw 493 | chest 494 | chiffonier, commode 495 | chime, bell, gong 496 | china cabinet, china closet 497 | Christmas stocking 498 | church, church building 499 | cinema, movie theater, movie theatre, movie house, picture palace 500 | cleaver, meat cleaver, chopper 
501 | cliff dwelling 502 | cloak 503 | clog, geta, patten, sabot 504 | cocktail shaker 505 | coffee mug 506 | coffeepot 507 | coil, spiral, volute, whorl, helix 508 | combination lock 509 | computer keyboard, keypad 510 | confectionery, confectionary, candy store 511 | container ship, containership, container vessel 512 | convertible 513 | corkscrew, bottle screw 514 | cornet, horn, trumpet, trump 515 | cowboy boot 516 | cowboy hat, ten-gallon hat 517 | cradle 518 | crane 519 | crash helmet 520 | crate 521 | crib, cot 522 | Crock Pot 523 | croquet ball 524 | crutch 525 | cuirass 526 | dam, dike, dyke 527 | desk 528 | desktop computer 529 | dial telephone, dial phone 530 | diaper, nappy, napkin 531 | digital clock 532 | digital watch 533 | dining table, board 534 | dishrag, dishcloth 535 | dishwasher, dish washer, dishwashing machine 536 | disk brake, disc brake 537 | dock, dockage, docking facility 538 | dogsled, dog sled, dog sleigh 539 | dome 540 | doormat, welcome mat 541 | drilling platform, offshore rig 542 | drum, membranophone, tympan 543 | drumstick 544 | dumbbell 545 | Dutch oven 546 | electric fan, blower 547 | electric guitar 548 | electric locomotive 549 | entertainment center 550 | envelope 551 | espresso maker 552 | face powder 553 | feather boa, boa 554 | file, file cabinet, filing cabinet 555 | fireboat 556 | fire engine, fire truck 557 | fire screen, fireguard 558 | flagpole, flagstaff 559 | flute, transverse flute 560 | folding chair 561 | football helmet 562 | forklift 563 | fountain 564 | fountain pen 565 | four-poster 566 | freight car 567 | French horn, horn 568 | frying pan, frypan, skillet 569 | fur coat 570 | garbage truck, dustcart 571 | gasmask, respirator, gas helmet 572 | gas pump, gasoline pump, petrol pump, island dispenser 573 | goblet 574 | go-kart 575 | golf ball 576 | golfcart, golf cart 577 | gondola 578 | gong, tam-tam 579 | gown 580 | grand piano, grand 581 | greenhouse, nursery, glasshouse 582 | grille, radiator grille 583 | grocery store, grocery, food market, market 584 | guillotine 585 | hair slide 586 | hair spray 587 | half track 588 | hammer 589 | hamper 590 | hand blower, blow dryer, blow drier, hair dryer, hair drier 591 | hand-held computer, hand-held microcomputer 592 | handkerchief, hankie, hanky, hankey 593 | hard disc, hard disk, fixed disk 594 | harmonica, mouth organ, harp, mouth harp 595 | harp 596 | harvester, reaper 597 | hatchet 598 | holster 599 | home theater, home theatre 600 | honeycomb 601 | hook, claw 602 | hoopskirt, crinoline 603 | horizontal bar, high bar 604 | horse cart, horse-cart 605 | hourglass 606 | iPod 607 | iron, smoothing iron 608 | jack-o'-lantern 609 | jean, blue jean, denim 610 | jeep, landrover 611 | jersey, T-shirt, tee shirt 612 | jigsaw puzzle 613 | jinrikisha, ricksha, rickshaw 614 | joystick 615 | kimono 616 | knee pad 617 | knot 618 | lab coat, laboratory coat 619 | ladle 620 | lampshade, lamp shade 621 | laptop, laptop computer 622 | lawn mower, mower 623 | lens cap, lens cover 624 | letter opener, paper knife, paperknife 625 | library 626 | lifeboat 627 | lighter, light, igniter, ignitor 628 | limousine, limo 629 | liner, ocean liner 630 | lipstick, lip rouge 631 | Loafer 632 | lotion 633 | loudspeaker, speaker, speaker unit, loudspeaker system, speaker system 634 | loupe, jeweler's loupe 635 | lumbermill, sawmill 636 | magnetic compass 637 | mailbag, postbag 638 | mailbox, letter box 639 | maillot 640 | maillot, tank suit 641 | manhole cover 642 | maraca 643 | marimba, xylophone 644 | mask 645 | 
matchstick 646 | maypole 647 | maze, labyrinth 648 | measuring cup 649 | medicine chest, medicine cabinet 650 | megalith, megalithic structure 651 | microphone, mike 652 | microwave, microwave oven 653 | military uniform 654 | milk can 655 | minibus 656 | miniskirt, mini 657 | minivan 658 | missile 659 | mitten 660 | mixing bowl 661 | mobile home, manufactured home 662 | Model T 663 | modem 664 | monastery 665 | monitor 666 | moped 667 | mortar 668 | mortarboard 669 | mosque 670 | mosquito net 671 | motor scooter, scooter 672 | mountain bike, all-terrain bike, off-roader 673 | mountain tent 674 | mouse, computer mouse 675 | mousetrap 676 | moving van 677 | muzzle 678 | nail 679 | neck brace 680 | necklace 681 | nipple 682 | notebook, notebook computer 683 | obelisk 684 | oboe, hautboy, hautbois 685 | ocarina, sweet potato 686 | odometer, hodometer, mileometer, milometer 687 | oil filter 688 | organ, pipe organ 689 | oscilloscope, scope, cathode-ray oscilloscope, CRO 690 | overskirt 691 | oxcart 692 | oxygen mask 693 | packet 694 | paddle, boat paddle 695 | paddlewheel, paddle wheel 696 | padlock 697 | paintbrush 698 | pajama, pyjama, pj's, jammies 699 | palace 700 | panpipe, pandean pipe, syrinx 701 | paper towel 702 | parachute, chute 703 | parallel bars, bars 704 | park bench 705 | parking meter 706 | passenger car, coach, carriage 707 | patio, terrace 708 | pay-phone, pay-station 709 | pedestal, plinth, footstall 710 | pencil box, pencil case 711 | pencil sharpener 712 | perfume, essence 713 | Petri dish 714 | photocopier 715 | pick, plectrum, plectron 716 | pickelhaube 717 | picket fence, paling 718 | pickup, pickup truck 719 | pier 720 | piggy bank, penny bank 721 | pill bottle 722 | pillow 723 | ping-pong ball 724 | pinwheel 725 | pirate, pirate ship 726 | pitcher, ewer 727 | plane, carpenter's plane, woodworking plane 728 | planetarium 729 | plastic bag 730 | plate rack 731 | plow, plough 732 | plunger, plumber's helper 733 | Polaroid camera, Polaroid Land camera 734 | pole 735 | police van, police wagon, paddy wagon, patrol wagon, wagon, black Maria 736 | poncho 737 | pool table, billiard table, snooker table 738 | pop bottle, soda bottle 739 | pot, flowerpot 740 | potter's wheel 741 | power drill 742 | prayer rug, prayer mat 743 | printer 744 | prison, prison house 745 | projectile, missile 746 | projector 747 | puck, hockey puck 748 | punching bag, punch bag, punching ball, punchball 749 | purse 750 | quill, quill pen 751 | quilt, comforter, comfort, puff 752 | racer, race car, racing car 753 | racket, racquet 754 | radiator 755 | radio, wireless 756 | radio telescope, radio reflector 757 | rain barrel 758 | recreational vehicle, RV, R.V. 
759 | reel 760 | reflex camera 761 | refrigerator, icebox 762 | remote control, remote 763 | restaurant, eating house, eating place, eatery 764 | revolver, six-gun, six-shooter 765 | rifle 766 | rocking chair, rocker 767 | rotisserie 768 | rubber eraser, rubber, pencil eraser 769 | rugby ball 770 | rule, ruler 771 | running shoe 772 | safe 773 | safety pin 774 | saltshaker, salt shaker 775 | sandal 776 | sarong 777 | sax, saxophone 778 | scabbard 779 | scale, weighing machine 780 | school bus 781 | schooner 782 | scoreboard 783 | screen, CRT screen 784 | screw 785 | screwdriver 786 | seat belt, seatbelt 787 | sewing machine 788 | shield, buckler 789 | shoe shop, shoe-shop, shoe store 790 | shoji 791 | shopping basket 792 | shopping cart 793 | shovel 794 | shower cap 795 | shower curtain 796 | ski 797 | ski mask 798 | sleeping bag 799 | slide rule, slipstick 800 | sliding door 801 | slot, one-armed bandit 802 | snorkel 803 | snowmobile 804 | snowplow, snowplough 805 | soap dispenser 806 | soccer ball 807 | sock 808 | solar dish, solar collector, solar furnace 809 | sombrero 810 | soup bowl 811 | space bar 812 | space heater 813 | space shuttle 814 | spatula 815 | speedboat 816 | spider web, spider's web 817 | spindle 818 | sports car, sport car 819 | spotlight, spot 820 | stage 821 | steam locomotive 822 | steel arch bridge 823 | steel drum 824 | stethoscope 825 | stole 826 | stone wall 827 | stopwatch, stop watch 828 | stove 829 | strainer 830 | streetcar, tram, tramcar, trolley, trolley car 831 | stretcher 832 | studio couch, day bed 833 | stupa, tope 834 | submarine, pigboat, sub, U-boat 835 | suit, suit of clothes 836 | sundial 837 | sunglass 838 | sunglasses, dark glasses, shades 839 | sunscreen, sunblock, sun blocker 840 | suspension bridge 841 | swab, swob, mop 842 | sweatshirt 843 | swimming trunks, bathing trunks 844 | swing 845 | switch, electric switch, electrical switch 846 | syringe 847 | table lamp 848 | tank, army tank, armored combat vehicle, armoured combat vehicle 849 | tape player 850 | teapot 851 | teddy, teddy bear 852 | television, television system 853 | tennis ball 854 | thatch, thatched roof 855 | theater curtain, theatre curtain 856 | thimble 857 | thresher, thrasher, threshing machine 858 | throne 859 | tile roof 860 | toaster 861 | tobacco shop, tobacconist shop, tobacconist 862 | toilet seat 863 | torch 864 | totem pole 865 | tow truck, tow car, wrecker 866 | toyshop 867 | tractor 868 | trailer truck, tractor trailer, trucking rig, rig, articulated lorry, semi 869 | tray 870 | trench coat 871 | tricycle, trike, velocipede 872 | trimaran 873 | tripod 874 | triumphal arch 875 | trolleybus, trolley coach, trackless trolley 876 | trombone 877 | tub, vat 878 | turnstile 879 | typewriter keyboard 880 | umbrella 881 | unicycle, monocycle 882 | upright, upright piano 883 | vacuum, vacuum cleaner 884 | vase 885 | vault 886 | velvet 887 | vending machine 888 | vestment 889 | viaduct 890 | violin, fiddle 891 | volleyball 892 | waffle iron 893 | wall clock 894 | wallet, billfold, notecase, pocketbook 895 | wardrobe, closet, press 896 | warplane, military plane 897 | washbasin, handbasin, washbowl, lavabo, wash-hand basin 898 | washer, automatic washer, washing machine 899 | water bottle 900 | water jug 901 | water tower 902 | whiskey jug 903 | whistle 904 | wig 905 | window screen 906 | window shade 907 | Windsor tie 908 | wine bottle 909 | wing 910 | wok 911 | wooden spoon 912 | wool, woolen, woollen 913 | worm fence, snake fence, snake-rail fence, Virginia fence 914 | 
wreck 915 | yawl 916 | yurt 917 | web site, website, internet site, site 918 | comic book 919 | crossword puzzle, crossword 920 | street sign 921 | traffic light, traffic signal, stoplight 922 | book jacket, dust cover, dust jacket, dust wrapper 923 | menu 924 | plate 925 | guacamole 926 | consomme 927 | hot pot, hotpot 928 | trifle 929 | ice cream, icecream 930 | ice lolly, lolly, lollipop, popsicle 931 | French loaf 932 | bagel, beigel 933 | pretzel 934 | cheeseburger 935 | hotdog, hot dog, red hot 936 | mashed potato 937 | head cabbage 938 | broccoli 939 | cauliflower 940 | zucchini, courgette 941 | spaghetti squash 942 | acorn squash 943 | butternut squash 944 | cucumber, cuke 945 | artichoke, globe artichoke 946 | bell pepper 947 | cardoon 948 | mushroom 949 | Granny Smith 950 | strawberry 951 | orange 952 | lemon 953 | fig 954 | pineapple, ananas 955 | banana 956 | jackfruit, jak, jack 957 | custard apple 958 | pomegranate 959 | hay 960 | carbonara 961 | chocolate sauce, chocolate syrup 962 | dough 963 | meat loaf, meatloaf 964 | pizza, pizza pie 965 | potpie 966 | burrito 967 | red wine 968 | espresso 969 | cup 970 | eggnog 971 | alp 972 | bubble 973 | cliff, drop, drop-off 974 | coral reef 975 | geyser 976 | lakeside, lakeshore 977 | promontory, headland, head, foreland 978 | sandbar, sand bar 979 | seashore, coast, seacoast, sea-coast 980 | valley, vale 981 | volcano 982 | ballplayer, baseball player 983 | groom, bridegroom 984 | scuba diver 985 | rapeseed 986 | daisy 987 | yellow lady's slipper, yellow lady-slipper, Cypripedium calceolus, Cypripedium parviflorum 988 | corn 989 | acorn 990 | hip, rose hip, rosehip 991 | buckeye, horse chestnut, conker 992 | coral fungus 993 | agaric 994 | gyromitra 995 | stinkhorn, carrion fungus 996 | earthstar 997 | hen-of-the-woods, hen of the woods, Polyporus frondosus, Grifola frondosa 998 | bolete 999 | ear, spike, capitulum 1000 | toilet tissue, toilet paper, bathroom tissue'''.split("\n") 1001 | --------------------------------------------------------------------------------