├── README.md
├── Videos
│   ├── README.md
│   ├── c3d_model.py
│   ├── models
│   │   ├── c3d_model.png
│   │   └── train01_16_128_171_mean.npy
│   ├── sports1m
│   │   └── get_labels.sh
│   └── test_model.py
├── classify.py
├── images
│   ├── architecture.png
│   ├── collies.JPG
│   ├── grad-cam-faults.png
│   ├── grad-campp.png
│   ├── multiple_dogs.jpg
│   ├── snake.JPEG
│   └── water-bird.JPEG
├── misc
│   ├── GuideReLU.py
│   ├── GuideReLU.pyc
│   ├── __init__.py
│   ├── __init__.pyc
│   ├── utils.py
│   ├── utils.pyc
│   └── utils.py~
├── models
│   ├── __init__.py
│   ├── __init__.pyc
│   ├── vgg16.py
│   ├── vgg16.pyc
│   ├── vgg19.py
│   ├── vgg_utils.py
│   └── vgg_utils.pyc
├── output
│   ├── multiple_dogs.jpg
│   ├── output.jpeg
│   └── water-bird.JPEG
└── synset.txt


/README.md:
--------------------------------------------------------------------------------
# Grad-CAM++

A generalized gradient-based CNN visualization technique.
Code for the paper:
### Grad-CAM++: Generalized Gradient-based Visual Explanations for Deep Convolutional Networks

To be presented at **WACV 2018**.

Authors:

Aditya Chattopadhyay\*,
Anirban Sarkar\*,
Prantik Howlader\* and
Vineeth N Balasubramanian

(\* equal contribution)

##### Grad-CAM++ Architecture
![alt text](images/architecture.png)
##### Performance of Grad-CAM++ compared with Grad-CAM
![alt text](images/grad-campp.png)
Above, we compare the performance of Grad-CAM++ with that of Grad-CAM. Grad-CAM++ localizes objects better than Grad-CAM, not only when an image contains more than one object of the same class, but also when it contains a single object.

Implementation in Python using TensorFlow 1.3.
Kindly download the pretrained weights of the VGG16 network (vgg16.npy) from the following link, and copy the file to the `models/` subdirectory:
https://drive.google.com/drive/folders/0BzS5KZjihEdyUjBHcGFNRnk4bFU?usp=sharing

### USAGE:
```python classify.py -f images/water-bird.JPEG -gpu 3 -o output.jpeg ```

#### Arguments:
- `-f`: path to the input image
- `-gpu`: the GPU id to use, 0-indexed (default `0`)
- `-l`: class label; the default, -1, chooses the class predicted by the model
- `-o`: output file name for the Grad-CAM++ visualization, default is `output.jpeg`. All results are saved in the `output/` subdirectory.

#### For Help:
```python classify.py -h ```

The above code is for the VGG16 network, pre-trained on ImageNet.
We tested our code on TensorFlow 1.3; compatibility with other versions is not guaranteed.
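
For reference, the Grad-CAM++ weighting for a single image can be sketched in NumPy. This is a minimal, non-authoritative sketch, assuming you already have the activations `A` of the chosen conv layer and the gradients of the pre-softmax class score with respect to them, both shaped `(H, W, K)`. It relies on the paper's exponential trick, under which the second- and third-order derivatives reduce to powers of the first-order gradients. The function name `grad_cam_plus_plus` is hypothetical and not part of this repository; `Videos/test_model.py` additionally renormalizes the alphas per channel, which this sketch omits.

```python
import numpy as np


def grad_cam_plus_plus(activations, grads):
    """Grad-CAM++ map for one image (sketch, not the repository's implementation).

    activations: (H, W, K) feature maps A of the chosen conv layer
    grads:       (H, W, K) gradients dS/dA of the pre-softmax class score S
    Under the exponential trick the 2nd/3rd derivatives are exp(S) * grads**2
    and exp(S) * grads**3; exp(S) cancels inside alpha and only rescales the
    final map, so it is left out here.
    """
    grads_2 = grads ** 2
    grads_3 = grads ** 3
    # per-channel sum of activations over all spatial positions (a, b)
    sum_activations = activations.sum(axis=(0, 1))                # (K,)
    alpha_denom = 2.0 * grads_2 + sum_activations * grads_3       # broadcasts over K
    alpha_denom = np.where(alpha_denom != 0.0, alpha_denom, 1.0)  # avoid division by zero
    alphas = grads_2 / alpha_denom                                 # (H, W, K)
    # channel weights: alpha-weighted sum of the positive gradients
    weights = (alphas * np.maximum(grads, 0.0)).sum(axis=(0, 1))  # (K,)
    cam = np.maximum((weights * activations).sum(axis=2), 0.0)    # ReLU, (H, W)
    peak = cam.max()
    return cam / peak if peak > 0 else cam                         # scale to [0, 1]


# Usage with random stand-in data (shapes only; not real VGG16 activations)
rng = np.random.default_rng(0)
A = np.maximum(rng.normal(size=(14, 14, 512)), 0.0)  # ReLU-like conv5_3-sized maps
dS_dA = rng.normal(size=(14, 14, 512))
cam = grad_cam_plus_plus(A, dS_dA)
print(cam.shape)  # (14, 14)
```

The map has the spatial size of the conv layer, so it still has to be upsampled to the input resolution (224x224 for VGG16) before it is overlaid on the image.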

#### Acknowledgements
Parts of the code have been borrowed and modified from:
##### For the Grad-CAM TensorFlow implementation
https://github.com/Ankush96/grad-cam.tensorflow
https://github.com/insikk/Grad-CAM-tensorflow
##### For porting the pre-trained VGG16 model from the Caffe Model Zoo to TensorFlow
https://github.com/ry/tensorflow-vgg16

#### If using this code, please cite our work:
```
@article{chattopadhyay2017grad,
  title={Grad-CAM++: Generalized Gradient-based Visual Explanations for Deep Convolutional Networks},
  author={Chattopadhyay, Aditya and Sarkar, Anirban and Howlader, Prantik and Balasubramanian, Vineeth N},
  journal={arXiv preprint arXiv:1710.11063},
  year={2017}
}
```

--------------------------------------------------------------------------------
/Videos/README.md:
--------------------------------------------------------------------------------
C3D Model for Grad-CAM and Grad-CAM++ with Keras + TensorFlow
=============================================================

The scripts here are inspired by the [`C3D Model for Keras`](https://gist.github.com/albertomontesg/d8b21a179c1e6cca0480ebdf292c34d2) gist, but are written specifically for Keras + TensorFlow (not the Theano backend).

To reproduce the results, run each of these steps:

1. Download the pretrained weights of the C3D model (`sports1M_weights_tf.h5` and `sports1M_weights_tf.json`) from the following link, and copy both files to the `models/` subdirectory: `https://drive.google.com/drive/folders/11fi8rtYnPmiyMngUVeOF4vBzUjnISYiH?usp=sharing`
2. Download the Sports-1M labels: `bash sports1m/get_labels.sh`
3. Make sure the default Keras config (in `~/.keras/keras.json`) has `tf` image_dim_ordering and the `tensorflow` backend.
4. Download some test videos and copy them to the `videos/` subdirectory. The script takes 16 frames starting at frame 1000, so each video needs at least 1016 frames.
5. Run the test: `python test_model.py`

Prerequisites
=============
Known to work with the following Python packages:
- Keras==2.0.6
- tensorflow==1.4.0
- h5py==2.8.0
- numpy==1.14.2

Results
=======
For every video in the `videos/` subdirectory, this code generates Grad-CAM and Grad-CAM++ saliency maps with respect to the convolution layer selected in the code (`conv3a` by default).

The top 5 labels are also reported, and should look something like:

```
Top 5 probabilities and labels:
basketball: 0.71422
streetball: 0.10293
volleyball: 0.04900
greco-roman wrestling: 0.02638
freestyle wrestling: 0.02408
```

References
==========

1. [C3D Model for Keras](https://gist.github.com/albertomontesg/d8b21a179c1e6cca0480ebdf292c34d2)
2. [Original C3D implementation in Caffe](https://github.com/facebook/C3D)
3. [C3D paper](https://arxiv.org/abs/1412.0767)

--------------------------------------------------------------------------------
/Videos/c3d_model.py:
--------------------------------------------------------------------------------
from keras.models import Sequential
from keras.layers.core import Dense, Dropout, Flatten
from keras.layers.convolutional import Convolution3D, MaxPooling3D, ZeroPadding3D
from keras.optimizers import SGD

'''
dim_ordering issue:
- 'th'-style dim_ordering: [batch, channels, depth, height, width]
- 'tf'-style dim_ordering: [batch, depth, height, width, channels]
'''

def get_model(summary=False, backend='tf'):
    """ Return the Keras model of the network
    """
    model = Sequential()
    if backend == 'tf':
        input_shape=(16, 112, 112, 3) # l, h, w, c
    else:
        input_shape=(3, 16, 112, 112) # c, l, h, w
    model.add(Convolution3D(64, 3, 3, 3, activation='relu',
                            border_mode='same', name='conv1',
                            input_shape=input_shape))
    model.add(MaxPooling3D(pool_size=(1, 2, 2), strides=(1, 2, 2),
                           border_mode='valid', name='pool1'))
    # 2nd layer group
    model.add(Convolution3D(128, 3, 3, 3, activation='relu',
                            border_mode='same', name='conv2'))
    model.add(MaxPooling3D(pool_size=(2, 2, 2), strides=(2, 2, 2),
                           border_mode='valid', name='pool2'))
    # 3rd layer group
    model.add(Convolution3D(256, 3, 3, 3, activation='relu',
                            border_mode='same', name='conv3a'))
    model.add(Convolution3D(256, 3, 3, 3, activation='relu',
                            border_mode='same', name='conv3b'))
    model.add(MaxPooling3D(pool_size=(2, 2, 2), strides=(2, 2, 2),
                           border_mode='valid', name='pool3'))
    # 4th layer group
    model.add(Convolution3D(512, 3, 3, 3, activation='relu',
                            border_mode='same', name='conv4a'))
    model.add(Convolution3D(512, 3, 3, 3, activation='relu',
                            border_mode='same', name='conv4b'))
    model.add(MaxPooling3D(pool_size=(2, 2, 2), strides=(2, 2, 2),
                           border_mode='valid', name='pool4'))
    # 5th layer group
    model.add(Convolution3D(512, 3, 3, 3, activation='relu',
                            border_mode='same', name='conv5a'))
    model.add(Convolution3D(512, 3, 3, 3, activation='relu',
                            border_mode='same', name='conv5b'))
    model.add(ZeroPadding3D(padding=(0, 1, 1), name='zeropad5'))
    model.add(MaxPooling3D(pool_size=(2, 2, 2), strides=(2, 2, 2),
                           border_mode='valid', name='pool5'))
    model.add(Flatten())
    # FC layers group
    model.add(Dense(4096, activation='relu', name='fc6'))
    model.add(Dropout(.5))
    model.add(Dense(4096, activation='relu', name='fc7'))
    model.add(Dropout(.5))
    model.add(Dense(487, activation='softmax', name='fc8'))

    if summary:
        model.summary()  # summary() prints the table itself and returns None

    return model

def get_int_model(model, layer, backend='tf'):

    if backend == 'tf':
        input_shape=(16, 112, 112, 3) # l, h, w, c
    else:
        input_shape=(3, 16, 112, 112) # c, l, h, w

    int_model = Sequential()

    int_model.add(Convolution3D(64, 3, 3, 3, activation='relu',
                                border_mode='same', name='conv1',
                                input_shape=input_shape,
                                weights=model.layers[0].get_weights()))
    if layer == 'conv1':
        return int_model
    int_model.add(MaxPooling3D(pool_size=(1, 2, 2), strides=(1, 2, 2),
                               border_mode='valid', name='pool1'))
    if layer == 'pool1':
        return int_model

    # 2nd layer group
    int_model.add(Convolution3D(128, 3, 3, 3, activation='relu',
                                border_mode='same', name='conv2',
                                weights=model.layers[2].get_weights()))
    if layer == 'conv2':
        return int_model
    int_model.add(MaxPooling3D(pool_size=(2, 2, 2), strides=(2, 2, 2),
                               border_mode='valid', name='pool2'))
    if layer == 'pool2':
        return int_model

    # 3rd layer group
    int_model.add(Convolution3D(256, 3, 3, 3, activation='relu',
                                border_mode='same', name='conv3a',
                                weights=model.layers[4].get_weights()))
    if layer == 'conv3a':
        return int_model
    int_model.add(Convolution3D(256, 3, 3, 3, activation='relu',
                                border_mode='same', name='conv3b',
                                weights=model.layers[5].get_weights()))
    if layer == 'conv3b':
        return int_model
    int_model.add(MaxPooling3D(pool_size=(2, 2, 2), strides=(2, 2, 2),
                               border_mode='valid', name='pool3'))
    if layer == 'pool3':
        return int_model

    # 4th layer group
    int_model.add(Convolution3D(512, 3, 3, 3, activation='relu',
                                border_mode='same', name='conv4a',
                                weights=model.layers[7].get_weights()))
    if layer == 'conv4a':
        return int_model
    int_model.add(Convolution3D(512, 3, 3, 3, activation='relu',
                                border_mode='same', name='conv4b',
                                weights=model.layers[8].get_weights()))
    if layer == 'conv4b':
        return int_model
    int_model.add(MaxPooling3D(pool_size=(2, 2, 2), strides=(2, 2, 2),
                               border_mode='valid', name='pool4'))
    if layer == 'pool4':
        return int_model

    # 5th layer group
    int_model.add(Convolution3D(512, 3, 3, 3, activation='relu',
                                border_mode='same', name='conv5a',
                                weights=model.layers[10].get_weights()))
    if layer == 'conv5a':
        return int_model
    int_model.add(Convolution3D(512, 3, 3, 3, activation='relu',
                                border_mode='same', name='conv5b',
                                weights=model.layers[11].get_weights()))
    if layer == 'conv5b':
        return int_model
    int_model.add(ZeroPadding3D(padding=(0, 1, 1), name='zeropad'))
    int_model.add(MaxPooling3D(pool_size=(2, 2, 2), strides=(2, 2, 2),
                               border_mode='valid', name='pool5'))
    if layer == 'pool5':
        return int_model

    int_model.add(Flatten())
    # FC layers group
    int_model.add(Dense(4096, activation='relu', name='fc6',
                        weights=model.layers[15].get_weights()))
    if layer == 'fc6':
        return int_model
    int_model.add(Dropout(.5))
    int_model.add(Dense(4096, activation='relu', name='fc7',
                        weights=model.layers[17].get_weights()))
    if layer == 'fc7':
        return int_model
    int_model.add(Dropout(.5))
    int_model.add(Dense(487, activation='softmax', name='fc8',
                        weights=model.layers[19].get_weights()))
    if layer == 'fc8':
        return int_model

    return None

if __name__ == '__main__':
    model = get_model(summary=True)

--------------------------------------------------------------------------------
/Videos/models/c3d_model.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/adityac94/Grad_CAM_plus_plus/4a9faf6ac61ef0c56e19b88d8560b81cd62c5017/Videos/models/c3d_model.png
--------------------------------------------------------------------------------
/Videos/models/train01_16_128_171_mean.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/adityac94/Grad_CAM_plus_plus/4a9faf6ac61ef0c56e19b88d8560b81cd62c5017/Videos/models/train01_16_128_171_mean.npy
--------------------------------------------------------------------------------
/Videos/sports1m/get_labels.sh:
--------------------------------------------------------------------------------
#!/usr/bin/env bash

DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )"

echo ---------------------------------------------
echo Downloading sports-1m-dataset labels...
wget \
    -N \
    https://raw.githubusercontent.com/gtoderici/sports-1m-dataset/master/labels.txt \
    --directory-prefix=${DIR}

echo ---------------------------------------------
echo Done!

--------------------------------------------------------------------------------
/Videos/test_model.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python

import matplotlib
matplotlib.use('Agg')
#import setGPU
from keras.models import model_from_json
from keras.layers.core import Lambda
import tensorflow as tf
import os
import cv2
import numpy as np
from skimage.transform import resize
import scipy.ndimage
import matplotlib.pyplot as plt
import c3d_model
import sys
import keras.backend as K
# K.set_image_dim_ordering('th')
os.environ["CUDA_VISIBLE_DEVICES"]="2,3"
dim_ordering = K.image_dim_ordering()  # call it: returns 'tf' or 'th'
print "[Info] image_dim_order (from default ~/.keras/keras.json)={}".format(
        dim_ordering)
backend = dim_ordering

def target_category_loss(x, category_index, nb_classes):
    return tf.multiply(x, K.one_hot([category_index], nb_classes))

def target_category_loss_output_shape(input_shape):
    return input_shape

def normalize(x):
    # utility function to normalize a tensor by its root-mean-square (RMS) value
    return x / (K.sqrt(K.mean(K.square(x))) + 1e-5)

def diagnose(data, verbose=True, label='input', plots=False, backend='tf'):
    # Convolution3D?
    if data.ndim > 2:
        if backend == 'th':
            data = np.transpose(data, (1, 2, 3, 0))
        #else:
        #    data = np.transpose(data, (0, 2, 1, 3))
        min_num_spatial_axes = 10
        max_outputs_to_show = 3
        ndim = data.ndim
        print "[Info] {}.ndim={}".format(label, ndim)
        print "[Info] {}.shape={}".format(label, data.shape)
        for d in range(ndim):
            num_this_dim = data.shape[d]
            if num_this_dim >= min_num_spatial_axes: # check for spatial axes
                # just first, center, last indices
                range_this_dim = [0, num_this_dim/2, num_this_dim - 1]
            else:
                # sweep all indices for non-spatial axes
                range_this_dim = range(num_this_dim)
            for i in range_this_dim:
                new_dim = tuple([d] + range(d) + range(d + 1, ndim))
                sliced = np.transpose(data, new_dim)[i, ...]
                print("[Info] {}, dim:{} {}-th slice: "
                      "(min, max, mean, std)=({}, {}, {}, {})".format(
                              label,
                              d, i,
                              np.min(sliced),
                              np.max(sliced),
                              np.mean(sliced),
                              np.std(sliced)))
        if plots:
            # assume (l, h, w, c)-shaped input
            if data.ndim != 4:
                print("[Error] data (shape={}) is not 4-dim. Check data".format(
                        data.shape))
                return
            l, h, w, c = data.shape
            if l >= min_num_spatial_axes or \
                h < min_num_spatial_axes or \
                w < min_num_spatial_axes:
                print("[Error] data (shape={}) does not look like in (l,h,w,c) "
                      "format. Do reshape/transpose.".format(data.shape))
                return
            nrows = int(np.ceil(np.sqrt(data.shape[0])))
            # BGR
            if c == 3:
                for i in range(l):
                    mng = plt.get_current_fig_manager()
                    mng.resize(*mng.window.maxsize())
                    plt.subplot(nrows, nrows, i + 1) # doh, one-based!
                    im = np.squeeze(data[i, ...]).astype(np.float32)
                    im = im[:, :, ::-1] # BGR to RGB
                    # force it to range [0,1]
                    im_min, im_max = im.min(), im.max()
                    if im_max > im_min:
                        im_std = (im - im_min) / (im_max - im_min)
                    else:
                        print "[Warning] image is constant!"
                        im_std = np.zeros_like(im)
                    plt.imshow(im_std)
                    plt.axis('off')
                    plt.title("{}: t={}".format(label, i))
                plt.show()
                #plt.waitforbuttonpress()
            else:
                for j in range(min(c, max_outputs_to_show)):
                    for i in range(l):
                        mng = plt.get_current_fig_manager()
                        mng.resize(*mng.window.maxsize())
                        plt.subplot(nrows, nrows, i + 1) # doh, one-based!
                        im = np.squeeze(data[i, ...]).astype(np.float32)
                        im = im[:, :, j]
                        # force it to range [0,1]
                        im_min, im_max = im.min(), im.max()
                        if im_max > im_min:
                            im_std = (im - im_min) / (im_max - im_min)
                        else:
                            print "[Warning] image is constant!"
                            im_std = np.zeros_like(im)
                        plt.imshow(im_std)
                        plt.axis('off')
                        plt.title("{}: o={}, t={}".format(label, j, i))
                    plt.show()
                    #plt.waitforbuttonpress()
    elif data.ndim == 1:
        print("[Info] {} (min, max, mean, std)=({}, {}, {}, {})".format(
                label,
                np.min(data),
                np.max(data),
                np.mean(data),
                np.std(data)))
        print("[Info] data[:10]={}".format(data[:10]))

    return

def main():
    show_images = False
    diagnose_plots = False
    model_dir = './models'
    global backend

    # override backend if provided as an input arg
    if len(sys.argv) > 1:
        if 'tf' in sys.argv[1].lower():
            backend = 'tf'
        else:
            backend = 'th'
    print "[Info] Using backend={}".format(backend)

    if backend == 'th':
        print "hi"
        model_weight_filename = os.path.join(model_dir, 'sports1M_weights_th.h5')
        model_json_filename = os.path.join(model_dir, 'sports1M_weights_th.json')
    else:
        print "hello"
        model_weight_filename = os.path.join(model_dir, 'sports1M_weights_tf.h5')
        model_json_filename = os.path.join(model_dir, 'sports1M_weights_tf.json')

    print("[Info] Reading model architecture...")
    model = model_from_json(open(model_json_filename, 'r').read())
    # model = c3d_model.get_model(backend=backend)

    # visualize model (Keras 2 moved this helper to keras.utils.plot_model)
    model_img_filename = os.path.join(model_dir, 'c3d_model.png')
    if not os.path.exists(model_img_filename):
        from keras.utils import plot_model
        plot_model(model, to_file=model_img_filename)

    print("[Info] Loading model weights...")
    model.load_weights(model_weight_filename)
    print("[Info] Loading model weights -- DONE!")
    model.compile(loss='mean_squared_error', optimizer='sgd')

    print("[Info] Loading labels...")
    with open('sports1m/labels.txt', 'r') as f:
        labels = [line.strip() for line in f.readlines()]
    print('Total labels: {}'.format(len(labels)))

    print("[Info] Loading a sample video...")

    f = open("scores.txt","w")

    for filename in sorted(os.listdir("videos/")):
        try:
            cap = cv2.VideoCapture("videos/" + filename)
            print filename
            vid = []
            while True:
                ret, img = cap.read()
                if not ret:
                    break
                vid.append(cv2.resize(img, (171, 128)))
            vid = np.array(vid, dtype=np.float32)

            start_frame = 1000  # the video must have at least start_frame + 16 frames

            X = vid[start_frame:(start_frame + 16), :, :, :]

            # load the training mean cube (note: it is loaded but never subtracted from X)
            mean_cube = np.load('models/train01_16_128_171_mean.npy')
            mean_cube = np.transpose(mean_cube, (1, 2, 3, 0))

            # center crop
            X = X[:, 8:120, 30:142, :] # (l, h, w, c)

            if backend == 'th':
                X = np.transpose(X, (3, 0, 1, 2)) # input_shape = (3,16,112,112)
            else:
                pass                               # input_shape = (16,112,112,3)

            if 'lambda' in model.layers[-1].name:
                model.layers.pop()
                model.outputs = [model.layers[-1].output]
                model.output_layers = [model.layers[-1]]
                model.layers[-1].outbound_nodes = []

            # inference
            output = model.predict_on_batch(np.array([X]))

            #################################################
            print X.shape
            predicted_class = np.argmax(output)
            print predicted_class
            print output[0][predicted_class], labels[predicted_class]

            nb_classes = len(labels)#487
            target_layer = lambda x: target_category_loss(x, predicted_class, nb_classes)
            model.add(Lambda(target_layer, output_shape = target_category_loss_output_shape))
            temp_label = np.zeros(output.shape)
            temp_label[0][int(np.argmax(output))] = 1.0
            loss = K.sum(model.layers[-1].output*(temp_label))

            for i in range(14):
                ###########Choose a conv layer to generate saliency maps##########
                if model.layers[i].name == "conv3a":
                    conv_output = model.layers[i].output

            grads = normalize(K.gradients(loss, conv_output)[0])

            first_derivative = tf.exp(loss)*grads
            print first_derivative[0]
            print tf.exp(loss)

            #second_derivative
            second_derivative = tf.exp(loss)*grads*grads
            print second_derivative[0]

            #triple_derivative
            triple_derivative = tf.exp(loss)*grads*grads*grads
            print triple_derivative[0]

            gradient_function = K.function([model.layers[0].input, K.learning_phase()], [conv_output, grads, first_derivative, second_derivative, triple_derivative])
            grads_output, grads_val, conv_first_grad, conv_second_grad, conv_third_grad = gradient_function([np.array([X]), 0])
            grads_output, grads_val, conv_first_grad, conv_second_grad, conv_third_grad = grads_output[0, :, :], grads_val[0, :, :, :], conv_first_grad[0, :, :, :], conv_second_grad[0, :, :, :], conv_third_grad[0, :, :, :]
            print grads_output.shape, np.max(grads_output), np.min(grads_output)
            print grads_val.shape, np.max(grads_val), np.min(grads_val)
            print conv_first_grad.shape,np.max(conv_first_grad), np.min(conv_first_grad)
            print conv_second_grad.shape,np.max(conv_second_grad), np.min(conv_second_grad)
            print conv_third_grad.shape,np.max(conv_third_grad), np.min(conv_third_grad)

            ############## FOR GRAD-CAM #########################################

            weights = np.mean(grads_val, axis = (0, 1, 2))
            print weights.shape
            cam = np.zeros(grads_output.shape[0 : 3], dtype = np.float32)
            print cam.shape

            cam = np.sum(weights*grads_output, axis=3)
            print np.max(cam),np.min(cam)

            cam = np.maximum(cam, 0)
            cam = scipy.ndimage.zoom(cam, (2, 4, 4))
            heatmap = cam / np.max(cam)
            print np.max(heatmap),np.min(heatmap)
            print heatmap.shape

            vid_mod = X*heatmap.reshape((16,112,112,1))
            print vid_mod.shape
            output_mod = model.predict_on_batch(np.array([vid_mod]))

            predicted_class_mod = output_mod[0].argsort()[::-1][0]
            print output_mod[0][predicted_class_mod], labels[predicted_class_mod]
            print output_mod[0][predicted_class], labels[predicted_class]

            ################SAVE THE VIDEO AS FRAMES###############

            for i in range(heatmap.shape[0]):
                cam_mod = heatmap[i].reshape((112,112,1))

                gd_img_mod = X[i]*cam_mod

                gd_img_mod = cv2.resize(gd_img_mod, (640,480))
                cv2.imwrite("image-%05d.jpg" %i, gd_img_mod)

            ####################### GRAD-CAM UPTO THIS ############################

            ############## FOR GradCAM++ #################################
            # per-channel sum of the conv activations A over all (t, h, w) positions,
            # as in the Grad-CAM++ alpha denominator (grads_output holds the conv3a activations)
            global_sum = np.sum(grads_output.reshape((-1,256)), axis=0)
            #print global_sum

            alpha_num = conv_second_grad
            alpha_denom = conv_second_grad*2.0 + conv_third_grad*global_sum.reshape((-1,))
            alpha_denom = np.where(alpha_denom != 0.0, alpha_denom, np.ones(alpha_denom.shape))
            alphas = alpha_num/alpha_denom

            weights = np.maximum(conv_first_grad, 0.0)
            #normalizing the alphas
            alphas_thresholding = np.where(weights, alphas, 0.0)

            alpha_normalization_constant = np.sum(np.sum(np.sum(alphas_thresholding, axis=0),axis=0),axis=0)
            alpha_normalization_constant_processed = np.where(alpha_normalization_constant != 0.0, alpha_normalization_constant, np.ones(alpha_normalization_constant.shape))

            alphas /= alpha_normalization_constant_processed.reshape((1,1,1,256))
            #print alphas

            deep_linearization_weights = np.sum((weights*alphas).reshape((-1,256)),axis=0)
            #print deep_linearization_weights
            grad_CAM_map = np.sum(deep_linearization_weights*grads_output, axis=3)
            print np.max(grad_CAM_map),np.min(grad_CAM_map)

            grad_CAM_map = scipy.ndimage.zoom(grad_CAM_map, (2, 4, 4))
            print np.max(grad_CAM_map),np.min(grad_CAM_map)
            # Passing through ReLU
            vid_cam = np.maximum(grad_CAM_map, 0)
            vid_heatmap = vid_cam / np.max(vid_cam) # scale 0 to 1.0
            print vid_heatmap.shape

            vid_mod_plus = X*vid_heatmap.reshape((16,112,112,1))
            print vid_mod_plus.shape
            output_mod_plus = model.predict_on_batch(np.array([vid_mod_plus]))
            predicted_class_mod_plus = output_mod_plus[0].argsort()[::-1][0]
            print output_mod_plus[0][predicted_class_mod_plus], labels[predicted_class_mod_plus]
            print output_mod_plus[0][predicted_class], labels[predicted_class]

            ################SAVE THE VIDEO AS FRAMES###############

            for i in range(vid_heatmap.shape[0]):
                vid_cam_mod = vid_heatmap[i].reshape((112,112,1))

                vid_gd_img_mod = X[i]*vid_cam_mod
                vid_gd_img_mod = cv2.resize(vid_gd_img_mod, (640,480))
                cv2.imwrite(os.path.join("./output", "image-%05d.jpg" %i), vid_gd_img_mod)

                X_mod = cv2.resize(X[i], (640,480))
                cv2.imwrite("original-image-%05d.jpg" %i, X_mod)

            #############GRAD-CAM++ UPTO THIS ####################################

            #############Write the scores into a file#############################
            f.write(str(output[0][predicted_class]) + " " + str(output_mod[0][predicted_class]) + " " + str(output_mod_plus[0][predicted_class]) + "\n")

        except:
            print filename
            continue

    f.close()

    # sort top five predictions from softmax output
    top_inds = output[0].argsort()[::-1][:5]  # reverse sort and take five largest items
    print('\nTop 5 probabilities and labels:')
    for i in top_inds:
        print('{1}: {0:.5f}'.format(output[0][i], labels[i]))

if __name__ == '__main__':
    main()

--------------------------------------------------------------------------------
/classify.py:
--------------------------------------------------------------------------------
import models.vgg16 as vgg16
import numpy as np
import tensorflow as tf
import argparse
import misc.utils as utils
import models.vgg_utils as vgg_utils
import os

def run(image_filename, class_label, output_filename):
    grad_CAM_map= utils.grad_CAM_plus(image_filename, class_label, output_filename)

def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('-l', '--class_label', default=-1, type=int, help='if -1 (default) choose predicted class, else user supplied int')
    parser.add_argument('-gpu', '--gpu_device', default="0", type=str, help='if 0 (default) choose gpu 0, else user supplied int')
    parser.add_argument('-f', '--file_name', type=str, help="Specify image path for Grad-CAM++ visualization")
    parser.add_argument('-o', '--output_filename', default="output.jpeg", type=str, help="Specify output file name for Grad-CAM++ visualization, default name 'output.jpeg'")
    args = parser.parse_args()
    gpu_id = args.gpu_device
    print gpu_id
    os.environ["CUDA_VISIBLE_DEVICES"]=gpu_id
    run(args.file_name, args.class_label, args.output_filename)

if __name__ == '__main__':
    main()

--------------------------------------------------------------------------------
/images/architecture.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/adityac94/Grad_CAM_plus_plus/4a9faf6ac61ef0c56e19b88d8560b81cd62c5017/images/architecture.png
--------------------------------------------------------------------------------
/images/collies.JPG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/adityac94/Grad_CAM_plus_plus/4a9faf6ac61ef0c56e19b88d8560b81cd62c5017/images/collies.JPG
--------------------------------------------------------------------------------
/images/grad-cam-faults.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/adityac94/Grad_CAM_plus_plus/4a9faf6ac61ef0c56e19b88d8560b81cd62c5017/images/grad-cam-faults.png
--------------------------------------------------------------------------------
/images/grad-campp.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/adityac94/Grad_CAM_plus_plus/4a9faf6ac61ef0c56e19b88d8560b81cd62c5017/images/grad-campp.png
--------------------------------------------------------------------------------
/images/multiple_dogs.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/adityac94/Grad_CAM_plus_plus/4a9faf6ac61ef0c56e19b88d8560b81cd62c5017/images/multiple_dogs.jpg
--------------------------------------------------------------------------------
/images/snake.JPEG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/adityac94/Grad_CAM_plus_plus/4a9faf6ac61ef0c56e19b88d8560b81cd62c5017/images/snake.JPEG
--------------------------------------------------------------------------------
/images/water-bird.JPEG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/adityac94/Grad_CAM_plus_plus/4a9faf6ac61ef0c56e19b88d8560b81cd62c5017/images/water-bird.JPEG
--------------------------------------------------------------------------------
/misc/GuideReLU.py:
--------------------------------------------------------------------------------
# Replace the vanilla ReLU gradient with the guided ReLU gradient to get guided backpropagation.
import tensorflow as tf

from tensorflow.python.framework import ops
from tensorflow.python.ops import gen_nn_ops

@ops.RegisterGradient("GuidedRelu")
def _GuidedReluGrad(op, grad):
    return tf.where(0. < grad, gen_nn_ops._relu_grad(grad, op.outputs[0]), tf.zeros_like(grad))

--------------------------------------------------------------------------------
/misc/GuideReLU.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/adityac94/Grad_CAM_plus_plus/4a9faf6ac61ef0c56e19b88d8560b81cd62c5017/misc/GuideReLU.pyc
--------------------------------------------------------------------------------
/misc/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/adityac94/Grad_CAM_plus_plus/4a9faf6ac61ef0c56e19b88d8560b81cd62c5017/misc/__init__.py
--------------------------------------------------------------------------------
/misc/__init__.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/adityac94/Grad_CAM_plus_plus/4a9faf6ac61ef0c56e19b88d8560b81cd62c5017/misc/__init__.pyc
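
`misc/GuideReLU.py` registers a `GuidedRelu` gradient that keeps an upstream gradient entry only where the forward ReLU output is positive and the incoming gradient is also positive. The NumPy sketch below restates that rule on plain arrays for illustration; `guided_relu_backward` is a hypothetical name, not a function in this repository.

```python
import numpy as np


def guided_relu_backward(relu_output, upstream_grad):
    """Guided-ReLU backward rule on plain arrays (illustrative sketch).

    relu_output:   forward output of the ReLU (op.outputs[0] in GuideReLU.py)
    upstream_grad: gradient arriving from the layer above (`grad` there)
    """
    # ordinary ReLU backprop: pass the gradient only where the ReLU was active
    relu_grad = upstream_grad * (relu_output > 0)
    # guided condition: additionally drop entries whose incoming gradient is <= 0
    return np.where(upstream_grad > 0, relu_grad, 0.0)


out = np.array([0.0, 2.0, 3.0, 0.5])
g = np.array([1.0, -1.0, 4.0, 2.0])
result = guided_relu_backward(out, g)  # expected values: 0, 0, 4, 2
```

Ordinary ReLU backpropagation would keep the -1 at index 1, because that unit was active in the forward pass; the guided rule zeroes it because the incoming gradient is negative.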
--------------------------------------------------------------------------------
/misc/utils.py:
--------------------------------------------------------------------------------
import numpy as np
import tensorflow as tf
import GuideReLU as GReLU  # imported for its side effect: registers the "GuidedRelu" gradient
import models.vgg16 as vgg16
import models.vgg_utils as vgg_utils
from skimage.transform import resize
import matplotlib.pyplot as plt
import os,cv2
from tensorflow.python.framework import graph_util

def guided_BP(image, label_id = -1):
    g = tf.get_default_graph()
    with g.gradient_override_map({'Relu': 'GuidedRelu'}):
        label_vector = tf.placeholder("float", [None, 1000])
        input_image = tf.placeholder("float", [None, 224, 224, 3])

        vgg = vgg16.Vgg16()
        with tf.name_scope("content_vgg"):
            vgg.build(input_image)

        cost = vgg.fc8*label_vector

        # Guided backpropagation back to the input layer
        gb_grad = tf.gradients(cost, input_image)[0]

        init = tf.global_variables_initializer()

    # Run tensorflow
    with tf.Session(graph=g) as sess:
        sess.run(init)
        output = [0.0]*vgg.prob.get_shape().as_list()[1] #one-hot embedding for desired class activations
        if label_id == -1:
            prob = sess.run(vgg.prob, feed_dict={input_image:image})

            vgg_utils.print_prob(prob[0], './synset.txt')

            #creating the output vector for the respective class
            index = np.argmax(prob)
            print "Predicted_class: ", index
            output[index] = 1.0

        else:
            output[label_id] = 1.0
        output = np.array(output)
        gb_grad_value = sess.run(gb_grad, feed_dict={input_image:image, label_vector: output.reshape((1,-1))})

    return gb_grad_value[0]

def grad_CAM_plus(filename, label_id, output_filename):
    g = tf.get_default_graph()
    init = tf.global_variables_initializer()

    # Run tensorflow
    sess = tf.Session()

    # define tensor placeholders for labels and images
    label_vector = tf.placeholder("float", [None, 1000])
    input_image = tf.placeholder("float", [1, 224, 224, 3])
    label_index = tf.placeholder("int64", ())

    #load vgg16 model
    vgg = vgg16.Vgg16()
    with tf.name_scope("content_vgg"):
        vgg.build(input_image)
    #prob = tf.placeholder("float", [None, 1000])

    #get the output neuron corresponding to the class of interest (label_id)
    cost = vgg.fc8*label_vector

    # Get last convolutional layer gradients for generating gradCAM++ visualization
    target_conv_layer = vgg.conv5_3
    target_conv_layer_grad = tf.gradients(cost, target_conv_layer)[0]

    # Derivatives of exp(S^c) w.r.t. the conv5_3 activations, written as
    # exp(S^c) times powers of dS^c/dA (higher derivatives of S^c itself are taken as zero)

    #first_derivative
    first_derivative = tf.exp(cost)[0][label_index]*target_conv_layer_grad

    #second_derivative
    second_derivative = tf.exp(cost)[0][label_index]*target_conv_layer_grad*target_conv_layer_grad

    #third_derivative
    triple_derivative = tf.exp(cost)[0][label_index]*target_conv_layer_grad*target_conv_layer_grad*target_conv_layer_grad

    sess.run(init)

    img1 = vgg_utils.load_image(filename)

    output = [0.0]*vgg.prob.get_shape().as_list()[1] #one-hot embedding for desired class activations
    #creating the output vector for the respective class

    if label_id == -1:
        prob_val = sess.run(vgg.prob, feed_dict={input_image:[img1]})
        vgg_utils.print_prob(prob_val[0], './synset.txt')
        #creating the output vector for the respective class
        index = np.argmax(prob_val)
        orig_score = prob_val[0][index]
        print "Predicted_class: ", index
        output[index] = 1.0
        label_id = index
    else:
        output[label_id] = 1.0
    output = np.array(output)
    print label_id
    conv_output, conv_first_grad, conv_second_grad, conv_third_grad = sess.run([target_conv_layer, first_derivative, second_derivative, triple_derivative],
        feed_dict={input_image:[img1], label_index:label_id, label_vector: output.reshape((1,-1))})

    global_sum = np.sum(conv_output[0].reshape((-1,conv_first_grad[0].shape[2])), axis=0)

    alpha_num = conv_second_grad[0]
    alpha_denom = conv_second_grad[0]*2.0 + conv_third_grad[0]*global_sum.reshape((1,1,conv_first_grad[0].shape[2]))
    alpha_denom = np.where(alpha_denom != 0.0, alpha_denom, np.ones(alpha_denom.shape))
    alphas = alpha_num/alpha_denom

    weights = np.maximum(conv_first_grad[0], 0.0)
    #normalizing the alphas
    """
    alpha_normalization_constant = np.sum(np.sum(alphas, axis=0),axis=0)

    alphas /= alpha_normalization_constant.reshape((1,1,conv_first_grad[0].shape[2]))
    """

    alphas_thresholding = np.where(weights, alphas, 0.0)

    alpha_normalization_constant = np.sum(np.sum(alphas_thresholding, axis=0),axis=0)
    alpha_normalization_constant_processed = np.where(alpha_normalization_constant != 0.0, alpha_normalization_constant, np.ones(alpha_normalization_constant.shape))

    alphas /= alpha_normalization_constant_processed.reshape((1,1,conv_first_grad[0].shape[2]))

    deep_linearization_weights = np.sum((weights*alphas).reshape((-1,conv_first_grad[0].shape[2])),axis=0)
    #print deep_linearization_weights
    grad_CAM_map = np.sum(deep_linearization_weights*conv_output[0], axis=2)

    # Passing through ReLU
    cam = np.maximum(grad_CAM_map, 0)
    cam = cam / np.max(cam) # scale 0 to 1.0
    cam = resize(cam, (224,224))

    gb = guided_BP([img1], label_id)
    visualize(img1, cam, output_filename, gb)
    return cam

def visualize(img, cam, filename,gb_viz):
    gb_viz = np.dstack((
        gb_viz[:, :, 2],
        gb_viz[:, :, 1],
        gb_viz[:, :, 0],
    ))

    gb_viz -= np.min(gb_viz)
    gb_viz /= gb_viz.max()

    # one row of four panels: input, guided Grad-CAM++, heatmap, overlay
    fig, ax = plt.subplots(nrows=1,ncols=4)

    plt.subplot(141)
    plt.axis("off")
    imgplot = plt.imshow(img)

    plt.subplot(142)
    gd_img = gb_viz*np.minimum(0.25,cam).reshape(224,224,1)
    x = gd_img
    x = np.squeeze(x)

    #normalize tensor
    x -= x.mean()
    x /= (x.std() + 1e-5)
    x *= 0.1

    # clip to [0, 1]
    x += 0.5
    x = np.clip(x, 0, 1)

    # convert to RGB array
    x *= 255

    x = np.clip(x, 0, 255).astype('uint8')

    plt.axis("off")
    imgplot = plt.imshow(x, vmin = 0, vmax = 20)

    cam = (cam*-1.0) + 1.0
    cam_heatmap = np.array(cv2.applyColorMap(np.uint8(255*cam), cv2.COLORMAP_JET))
    plt.subplot(143)
    plt.axis("off")

    imgplot = plt.imshow(cam_heatmap)

    plt.subplot(144)
    plt.axis("off")

    cam_heatmap = cam_heatmap/255.0

    fin = (img*0.7) + (cam_heatmap*0.3)
    imgplot = plt.imshow(fin)

    plt.savefig("output/" + filename, dpi=600)
    plt.close(fig)

--------------------------------------------------------------------------------
/misc/utils.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/adityac94/Grad_CAM_plus_plus/4a9faf6ac61ef0c56e19b88d8560b81cd62c5017/misc/utils.pyc
--------------------------------------------------------------------------------
/misc/utils.py~:
--------------------------------------------------------------------------------
import numpy as np
import tensorflow as tf
import GuideReLU as GReLU
import models.vgg16 as vgg16
import models.vgg_utils as vgg_utils
from skimage.transform import resize
import matplotlib.pyplot as plt
import os
from scipy.misc import imread, imresize
import tensorflow as tf
from tensorflow.python.framework import graph_util

def guided_BP(image, label_id = -1):
    g = tf.get_default_graph()
    with g.gradient_override_map({'Relu': 'GuidedRelu'}):
        label_vector = tf.placeholder("float", [None, 1000])
        input_image = tf.placeholder("float", [None, 224, 224, 3])

        vgg = vgg16.Vgg16()
        with tf.name_scope("content_vgg"):
            vgg.build(input_image)

        cost = vgg.fc8*label_vector

        # Guided backpropagation back to the input layer
        gb_grad = tf.gradients(cost, input_image)[0]

        init = tf.global_variables_initializer()

    # Run tensorflow
    with tf.Session(graph=g) as sess:
        sess.run(init)
        output = [0.0]*vgg.prob.get_shape().as_list()[1] #one-hot embedding for desired class activations
        if label_id == -1:
            prob = sess.run(vgg.prob, feed_dict={input_image:image})

            vgg_utils.print_prob(prob[0], './synset.txt')

            #creating the output vector for the respective class
            index = np.argmax(prob)
            print "Predicted_class: ", index
            output[index] = 1.0

        else:
            output[label_id] = 1.0
        output = np.array(output)
        gb_grad_value = sess.run(gb_grad, feed_dict={input_image:image, label_vector: output.reshape((1,-1))})

    return gb_grad_value[0]

def grad_CAM(image, label_id = -1):
    g = tf.get_default_graph()
    init = tf.global_variables_initializer()
    # Run tensorflow
    sess = tf.Session()

    label_vector = tf.placeholder("float", [None, 1000])
    input_image = tf.placeholder("float", [1, 224, 224, 3])
    label_index = tf.placeholder("int64", ())

    vgg = vgg16.Vgg16()
    with tf.name_scope("content_vgg"):
        vgg.build(input_image)
    #prob = tf.placeholder("float", [None, 1000])

    cost = vgg.prob*label_vector

    # Get last convolutional layer gradients for generating gradCAM visualization
    target_conv_layer = vgg.conv5_3
    target_conv_layer_grad = tf.gradients(cost, target_conv_layer)[0]

    #second_derivative
    temp_1 = target_conv_layer_grad*(target_conv_layer_grad/vgg.prob[0][label_index])

    #Calculating the gradient of every node in the softmax layer with the feature map
    total = tf.zeros_like(target_conv_layer)
    S_double_derivatives_j = []
    y_gradients_list = []
    for j in range(1000):
        j_output = np.zeros([1000,])
        j_output[j] = 1.0
        output_tensor_j = tf.constant(j_output, dtype=tf.float32)

        y_gradients_j = tf.gradients(vgg.fc8*output_tensor_j, target_conv_layer)[0]
        S_gradients_j = tf.gradients(vgg.prob*output_tensor_j, target_conv_layer)[0]

        y_gradients_list.append(y_gradients_j)

        total = total + S_gradients_j*y_gradients_j

        prob_output_j = vgg.prob[0][j]
        temp_1_mini = S_gradients_j*(S_gradients_j/prob_output_j)

        S_double_derivatives_j.append(temp_1_mini)

    S_double_derivatives_j_part_one = tf.stack(S_double_derivatives_j)
    temp_2 = total*vgg.prob[0][label_index]
    second_derivative = temp_1 - temp_2

    #third_derivative
    double_temp_1 = second_derivative*(target_conv_layer_grad/vgg.prob[0][label_index])
    double_temp_2 = total*target_conv_layer_grad

    temp_array = []
    for j in range(1000):
        prob_output_j = vgg.prob[0][j]
        temp_2_mini = total*prob_output_j
        temp_array.append(temp_2_mini)

    S_double_derivatives_j_part_two = tf.stack(temp_array)
    S_double_derivatives = S_double_derivatives_j_part_one - S_double_derivatives_j_part_two

    total_2 = tf.reduce_sum(S_double_derivatives*tf.stack(y_gradients_list),axis=0)

    double_temp_3 = total_2*vgg.prob[0][label_index]

    triple_derivative = double_temp_1 - double_temp_2*2.0 - double_temp_3
    sess.run(init)
    f = open("scores.txt","w")
    for filename in sorted(os.listdir("../../imagenet_val/")):
        print filename

        img1 = vgg_utils.load_image("../../imagenet_val/" + filename)

        output = [0.0]*vgg.prob.get_shape().as_list()[1] #one-hot embedding for desired class activations
        #creating the output vector for the respective class

        try:
            prob_val = sess.run(vgg.prob, feed_dict={input_image:[img1]})
        except:
            print "oops"
            continue

        vgg_utils.print_prob(prob_val[0], './synset.txt')

        #creating the output vector for the respective class
        index = np.argmax(prob_val)
        orig_score = prob_val[0][index]
        print "Predicted_class: ", index
        output[index] = 1.0
        label_id = index
        output = np.array(output)


        conv_output, conv_first_grad, conv_second_grad, conv_third_grad = sess.run([target_conv_layer, target_conv_layer_grad, second_derivative, triple_derivative], feed_dict={input_image:[img1], label_index:label_id, label_vector: output.reshape((1,-1))})

        global_sum = np.sum(conv_output[0].reshape((-1,512)), axis=0)

        alpha_num = conv_second_grad[0]
        alpha_denom = conv_second_grad[0]*2.0 + conv_third_grad[0]*global_sum.reshape((1,1,512))
        alpha_denom = np.where(alpha_denom != 0.0, alpha_denom, np.ones(alpha_denom.shape))
        alphas = alpha_num/alpha_denom

        weights = np.maximum(conv_first_grad[0], 0.0)
        #normalizing the alphas

        alpha_normalization_constant = np.sum(np.sum(alphas, axis=0),axis=0)

        alphas /= alpha_normalization_constant.reshape((1,1,512))


        deep_linearization_weights = np.sum((weights*alphas).reshape((-1,512)),axis=0)
        #print deep_linearization_weights
        grad_CAM_map = np.sum(deep_linearization_weights*conv_output[0], axis=2)

        # Passing through ReLU
        cam = np.maximum(grad_CAM_map, 0)
        cam = cam / np.max(cam) # scale 0 to 1.0

        cam = resize(cam, (224,224))
        # Passing through ReLU
        cam = np.maximum(grad_CAM_map, 0)
        cam = cam / np.max(cam) # scale 0 to 1.0
        cam = resize(cam, (224,224))
        visualize(img1, cam, filename)


        #keeping the important saliency parts in original image
        gd_img = img1*cam.reshape((224,224,1))
        gd_img -= np.min(gd_img)
        gd_img /= gd_img.max()

        new_prob = sess.run(vgg.prob, feed_dict={input_image:[gd_img]})
        f.write(str(orig_score) + " " + str(new_prob[0][index]) + "\n")

    # print(cam)
    return cam

def visualize(img, cam, filename):
    """gb_viz = np.dstack((
        gb_viz[:, :, 2],
        gb_viz[:, :, 1],
        gb_viz[:, :, 0],
    ))

    gb_viz -= np.min(gb_viz)
    gb_viz /= gb_viz.max()"""

    fig, ax = plt.subplots(nrows=1,ncols=2)

    plt.subplot(121)
    imgplot = plt.imshow(img)


    plt.subplot(122)
    gd_img = img*cam.reshape((224,224,1))
    gd_img -= np.min(gd_img)
    gd_img /= gd_img.max()
    imgplot = plt.imshow(gd_img)


    """gd_gb = gb_viz*cam.reshape((224,224,1))
    gd_gb -= np.min(gd_gb)
    gd_gb /= gd_gb.max()

    plt.subplot(224)
    imgplot = plt.imshow(gd_gb)"""

    plt.savefig("output/" + filename + ".png")
    plt.close(fig)

--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/models/__init__.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/adityac94/Grad_CAM_plus_plus/4a9faf6ac61ef0c56e19b88d8560b81cd62c5017/models/__init__.pyc
--------------------------------------------------------------------------------
/models/vgg16.py:
--------------------------------------------------------------------------------
import inspect
import os

import numpy as np
import tensorflow as tf
import time

VGG_MEAN = [103.939, 116.779, 123.68]


class Vgg16:
    def __init__(self, vgg16_npy_path=None):
        if vgg16_npy_path is None:
            path = inspect.getfile(Vgg16)
            path = os.path.abspath(os.path.join(path, os.pardir))
            path = os.path.join(path, "vgg16.npy")
            vgg16_npy_path = path
            print(path)

        # The weights file stores a pickled dict; NumPy >= 1.16.3 requires allow_pickle=True to load it
        self.data_dict = np.load(vgg16_npy_path, encoding='latin1', allow_pickle=True).item()
        print("npy file loaded")

    def build(self, rgb):
        """
        load variable from npy to build the VGG

        :param rgb: rgb image [batch, height, width, 3] values scaled [0, 1]
        """

        start_time = time.time()
        print("build model started")
        rgb_scaled = rgb * 255.0

        # Convert RGB to BGR
        red, green, blue = tf.split(axis=3, num_or_size_splits=3, value=rgb_scaled)
        assert red.get_shape().as_list()[1:] == [224, 224, 1]
        assert green.get_shape().as_list()[1:] == [224, 224, 1]
        assert blue.get_shape().as_list()[1:] == [224, 224, 1]
        bgr = tf.concat(axis=3, values=[
            blue - VGG_MEAN[0],
            green - VGG_MEAN[1],
            red - VGG_MEAN[2],
        ])
        assert bgr.get_shape().as_list()[1:] == [224, 224, 3]

        self.conv1_1 = self.conv_layer(bgr, "conv1_1")
        self.conv1_2 = self.conv_layer(self.conv1_1, "conv1_2")
        self.pool1 = self.max_pool(self.conv1_2, 'pool1')

        self.conv2_1 = self.conv_layer(self.pool1, "conv2_1")
        self.conv2_2 = self.conv_layer(self.conv2_1, "conv2_2")
        self.pool2 = self.max_pool(self.conv2_2, 'pool2')

        self.conv3_1 = self.conv_layer(self.pool2, "conv3_1")
        self.conv3_2 = self.conv_layer(self.conv3_1, "conv3_2")
        self.conv3_3 = self.conv_layer(self.conv3_2, "conv3_3")
        self.pool3 = self.max_pool(self.conv3_3, 'pool3')

        self.conv4_1 = self.conv_layer(self.pool3, "conv4_1")
        self.conv4_2 = self.conv_layer(self.conv4_1, "conv4_2")
        self.conv4_3 = self.conv_layer(self.conv4_2, "conv4_3")
        self.pool4 = self.max_pool(self.conv4_3, 'pool4')

        self.conv5_1 = self.conv_layer(self.pool4, "conv5_1")
        self.conv5_2 = self.conv_layer(self.conv5_1, "conv5_2")
        self.conv5_3 = self.conv_layer(self.conv5_2, "conv5_3")
        self.pool5 = self.max_pool(self.conv5_3, 'pool5')

        self.fc6 = self.fc_layer(self.pool5, "fc6")
        assert self.fc6.get_shape().as_list()[1:] == [4096]
        self.relu6 = tf.nn.relu(self.fc6)

        self.fc7 = self.fc_layer(self.relu6, "fc7")
        self.relu7 = tf.nn.relu(self.fc7)

        self.fc8 = self.fc_layer(self.relu7, "fc8")

        self.prob = tf.nn.softmax(self.fc8, name="prob")

        self.data_dict = None
        print(("build model finished: %ds" % (time.time() - start_time)))

    def avg_pool(self, bottom, name):
        return tf.nn.avg_pool(bottom, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME', name=name)

    def max_pool(self, bottom, name):
        return tf.nn.max_pool(bottom, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME', name=name)

    def conv_layer(self, bottom, name):
        with tf.variable_scope(name):
            filt = self.get_conv_filter(name)

            conv = tf.nn.conv2d(bottom, filt, [1, 1, 1, 1], padding='SAME')

            conv_biases = self.get_bias(name)
            bias = tf.nn.bias_add(conv, conv_biases)

            relu = tf.nn.relu(bias)
            return relu

    def fc_layer(self, bottom, name):
        with tf.variable_scope(name):
            shape = bottom.get_shape().as_list()
            dim = 1
            for d in shape[1:]:
                dim *= d
            x = tf.reshape(bottom, [-1, dim])

            weights = self.get_fc_weight(name)
            biases = self.get_bias(name)

            # Fully connected layer; tf.nn.bias_add broadcasts the biases
            # across the batch dimension.
            fc = tf.nn.bias_add(tf.matmul(x, weights), biases)

            return fc

    def get_conv_filter(self, name):
        return tf.constant(self.data_dict[name][0], name="filter")

    def get_bias(self, name):
        return tf.constant(self.data_dict[name][1], name="biases")

    def get_fc_weight(self, name):
        return tf.constant(self.data_dict[name][0], name="weights")

--------------------------------------------------------------------------------
/models/vgg16.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/adityac94/Grad_CAM_plus_plus/4a9faf6ac61ef0c56e19b88d8560b81cd62c5017/models/vgg16.pyc
--------------------------------------------------------------------------------
/models/vgg19.py:
--------------------------------------------------------------------------------
import os
import tensorflow as tf

import numpy as np
import time
import inspect

VGG_MEAN = [103.939, 116.779, 123.68]


class Vgg19:
    def __init__(self, vgg19_npy_path=None):
        if vgg19_npy_path is None:
            path = inspect.getfile(Vgg19)
            path = os.path.abspath(os.path.join(path, os.pardir))
            path = os.path.join(path, "vgg19.npy")
            vgg19_npy_path = path
            print(vgg19_npy_path)

        # The weights file stores a pickled dict; NumPy >= 1.16.3 requires allow_pickle=True to load it
        self.data_dict = np.load(vgg19_npy_path, encoding='latin1', allow_pickle=True).item()
        print("npy file loaded")

    def build(self, rgb):
        """
        load variable from npy to build the VGG

        :param rgb: rgb image [batch, height, width, 3] values scaled [0, 1]
        """

        start_time = time.time()
        print("build model started")
        rgb_scaled = rgb * 255.0

        # Convert RGB to BGR
        red, green, blue = tf.split(axis=3, num_or_size_splits=3, value=rgb_scaled)
        assert red.get_shape().as_list()[1:] == [224, 224, 1]
        assert green.get_shape().as_list()[1:] == [224, 224, 1]
        assert blue.get_shape().as_list()[1:] == [224, 224, 1]
        bgr = tf.concat(axis=3, values=[
            blue - VGG_MEAN[0],
            green - VGG_MEAN[1],
            red - VGG_MEAN[2],
        ])
        assert bgr.get_shape().as_list()[1:] == [224, 224, 3]

        self.conv1_1 = self.conv_layer(bgr, "conv1_1")
        self.conv1_2 = self.conv_layer(self.conv1_1, "conv1_2")
        self.pool1 = self.max_pool(self.conv1_2, 'pool1')

        self.conv2_1 = self.conv_layer(self.pool1, "conv2_1")
        self.conv2_2 = self.conv_layer(self.conv2_1, "conv2_2")
        self.pool2 = self.max_pool(self.conv2_2, 'pool2')

        self.conv3_1 = self.conv_layer(self.pool2, "conv3_1")
        self.conv3_2 = self.conv_layer(self.conv3_1, "conv3_2")
        self.conv3_3 = self.conv_layer(self.conv3_2, "conv3_3")
        self.conv3_4 = self.conv_layer(self.conv3_3, "conv3_4")
        self.pool3 = self.max_pool(self.conv3_4, 'pool3')

        self.conv4_1 = self.conv_layer(self.pool3, "conv4_1")
        self.conv4_2 = self.conv_layer(self.conv4_1, "conv4_2")
        self.conv4_3 = self.conv_layer(self.conv4_2, "conv4_3")
        self.conv4_4 = self.conv_layer(self.conv4_3, "conv4_4")
        self.pool4 = self.max_pool(self.conv4_4, 'pool4')

        self.conv5_1 = self.conv_layer(self.pool4, "conv5_1")
        self.conv5_2 = self.conv_layer(self.conv5_1, "conv5_2")
        self.conv5_3 = self.conv_layer(self.conv5_2, "conv5_3")
        self.conv5_4 = self.conv_layer(self.conv5_3, "conv5_4")
        self.pool5 = self.max_pool(self.conv5_4, 'pool5')

        self.fc6 = self.fc_layer(self.pool5, "fc6")
        assert self.fc6.get_shape().as_list()[1:] == [4096]
        self.relu6 = tf.nn.relu(self.fc6)

        self.fc7 = self.fc_layer(self.relu6, "fc7")
        self.relu7 = tf.nn.relu(self.fc7)

        self.fc8 = self.fc_layer(self.relu7, "fc8")

        self.prob = tf.nn.softmax(self.fc8, name="prob")

        self.data_dict = None
        print(("build model finished: %ds" % (time.time() - start_time)))

    def avg_pool(self, bottom, name):
        return tf.nn.avg_pool(bottom, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME', name=name)

    def max_pool(self, bottom, name):
        return tf.nn.max_pool(bottom, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME', name=name)

    def conv_layer(self, bottom, name):
        with tf.variable_scope(name):
            filt = self.get_conv_filter(name)

            conv = tf.nn.conv2d(bottom, filt, [1, 1, 1, 1], padding='SAME')

            conv_biases = self.get_bias(name)
            bias = tf.nn.bias_add(conv, conv_biases)

            relu = tf.nn.relu(bias)
            return relu

    def fc_layer(self, bottom, name):
        with tf.variable_scope(name):
            shape = bottom.get_shape().as_list()
            dim = 1
            for d in shape[1:]:
                dim *= d
            x = tf.reshape(bottom, [-1, dim])

            weights = self.get_fc_weight(name)
            biases = self.get_bias(name)

            # Fully connected layer; tf.nn.bias_add broadcasts the biases
            # across the batch dimension.
            fc = tf.nn.bias_add(tf.matmul(x, weights), biases)

            return fc

    def get_conv_filter(self, name):
        return tf.constant(self.data_dict[name][0], name="filter")

    def get_bias(self, name):
        return tf.constant(self.data_dict[name][1], name="biases")

    def get_fc_weight(self, name):
        return tf.constant(self.data_dict[name][0], name="weights")

--------------------------------------------------------------------------------
/models/vgg_utils.py:
--------------------------------------------------------------------------------
import skimage
import skimage.io
import skimage.transform
import numpy as np


# synset = [l.strip() for l in open('synset.txt').readlines()]


# returns image of shape [224, 224, 3]
# [height, width, depth]
def load_image(path):
    # load image
    img = skimage.io.imread(path)
    img = img / 255.0
    assert (0 <= img).all() and (img <= 1.0).all()
    # print "Original Image Shape: ", img.shape
    # we crop image from center
    short_edge = min(img.shape[:2])
    yy = int((img.shape[0] - short_edge) / 2)
    xx = int((img.shape[1] - short_edge) / 2)
    crop_img = img[yy: yy + short_edge, xx: xx + short_edge]
    # resize to 224, 224
    resized_img = skimage.transform.resize(crop_img, (224, 224))
    return resized_img


# returns the top1 string
def print_prob(prob, file_path):
    synset = [l.strip() for l in open(file_path).readlines()]

    # print prob
    pred = np.argsort(prob)[::-1]

    # Get top1 label
    top1 = synset[pred[0]]
    print(("Top1: ", top1, prob[pred[0]]))
    # Get top5 label
    top5 = [(synset[pred[i]], prob[pred[i]]) for i in range(5)]
    print(("Top5: ", top5))
    return top1


def load_image2(path, height=None, width=None):
    # load image
    img = skimage.io.imread(path)
    img = img / 255.0
    if height is not None and width is not None:
        ny = height
        nx = width
    elif height is not None:
        ny = height
        nx = img.shape[1] * ny // img.shape[0]  # integer division keeps the output shape integral
    elif width is not None:
        nx = width
        ny = img.shape[0] * nx // img.shape[1]
    else:
        ny = img.shape[0]
        nx = img.shape[1]
    return skimage.transform.resize(img, (ny, nx))


def test():
    img = skimage.io.imread("./test_data/starry_night.jpg")
    ny = 300
    nx = img.shape[1] * ny // img.shape[0]
    img = skimage.transform.resize(img, (ny, nx))
    skimage.io.imsave("./test_data/test/output.jpg", img)


if __name__ == "__main__":
    test()

--------------------------------------------------------------------------------
/models/vgg_utils.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/adityac94/Grad_CAM_plus_plus/4a9faf6ac61ef0c56e19b88d8560b81cd62c5017/models/vgg_utils.pyc
--------------------------------------------------------------------------------
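The channel weighting in `grad_CAM_plus` (`misc/utils.py`) is spread across TensorFlow graph code and NumPy post-processing. Below is a minimal NumPy sketch of the same arithmetic, assuming the target-layer activations and the gradients of the class score have already been extracted as `(H, W, K)` arrays with the batch index removed; the function name `grad_cam_pp_map` is hypothetical and not part of this repository:

```python
import numpy as np


def grad_cam_pp_map(activations, grads):
    """Hypothetical helper mirroring the NumPy part of grad_CAM_plus.

    activations: feature maps A of the target conv layer, shape (H, W, K).
    grads: gradients dS_c/dA of the class score S_c, same shape.
    Returns a non-negative (H, W) map scaled so its maximum is 1 (or all zeros).
    """
    A = np.asarray(activations, dtype=float)
    g = np.asarray(grads, dtype=float)
    k = A.shape[2]

    # Per-channel sum of activations over all spatial positions, shape (K,)
    global_sum = A.reshape(-1, k).sum(axis=0)

    # Closed-form alphas; the exp(S_c) factor used in the repository cancels here
    alpha_num = g ** 2
    alpha_denom = 2.0 * g ** 2 + global_sum.reshape(1, 1, k) * g ** 3
    alpha_denom = np.where(alpha_denom != 0.0, alpha_denom, 1.0)  # avoid 0/0
    alphas = alpha_num / alpha_denom

    # Only positive gradients contribute (ReLU on the gradients)
    weights = np.maximum(g, 0.0)

    # Normalize alphas per channel over the positions with a positive gradient
    norm = np.where(weights > 0, alphas, 0.0).sum(axis=(0, 1))
    norm = np.where(norm != 0.0, norm, 1.0)
    alphas = alphas / norm.reshape(1, 1, k)

    # Channel weights w_k, then a weighted sum of feature maps followed by ReLU
    channel_weights = (weights * alphas).reshape(-1, k).sum(axis=0)
    cam = np.maximum((channel_weights * A).sum(axis=2), 0.0)

    peak = cam.max()
    return cam / peak if peak > 0 else cam
```

Because the map is divided by its own maximum, multiplying every channel weight by the positive constant exp(S_c), as the TensorFlow graph does, leaves the result unchanged. Unlike the repository code, this sketch returns zeros instead of dividing by zero when no location is positive, and it does not upsample the map to 224x224.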
/output/multiple_dogs.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/adityac94/Grad_CAM_plus_plus/4a9faf6ac61ef0c56e19b88d8560b81cd62c5017/output/multiple_dogs.jpg
--------------------------------------------------------------------------------
/output/output.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/adityac94/Grad_CAM_plus_plus/4a9faf6ac61ef0c56e19b88d8560b81cd62c5017/output/output.jpeg
--------------------------------------------------------------------------------
/output/water-bird.JPEG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/adityac94/Grad_CAM_plus_plus/4a9faf6ac61ef0c56e19b88d8560b81cd62c5017/output/water-bird.JPEG
--------------------------------------------------------------------------------
/synset.txt:
--------------------------------------------------------------------------------
n01440764 tench, Tinca tinca
n01443537 goldfish, Carassius auratus
n01484850 great white shark, white shark, man-eater, man-eating shark, Carcharodon carcharias
n01491361 tiger shark, Galeocerdo cuvieri
n01494475 hammerhead, hammerhead shark
n01496331 electric ray, crampfish, numbfish, torpedo
n01498041 stingray
n01514668 cock
n01514859 hen
n01518878 ostrich, Struthio camelus
n01530575 brambling, Fringilla montifringilla
n01531178 goldfinch, Carduelis carduelis
n01532829 house finch, linnet, Carpodacus mexicanus
n01534433 junco, snowbird
n01537544 indigo bunting, indigo finch, indigo bird, Passerina cyanea
n01558993 robin, American robin, Turdus migratorius
n01560419 bulbul
n01580077 jay
n01582220 magpie
n01592084 chickadee
n01601694 water ouzel, dipper
n01608432 kite
n01614925 bald eagle, American eagle, Haliaeetus leucocephalus
n01616318 vulture
n01622779 great grey owl, great gray owl, Strix nebulosa
n01629819 European fire salamander, Salamandra salamandra
n01630670 common newt, Triturus vulgaris
n01631663 eft
n01632458 spotted salamander, Ambystoma maculatum
n01632777 axolotl, mud puppy, Ambystoma mexicanum
n01641577 bullfrog, Rana catesbeiana
n01644373 tree frog, tree-frog
n01644900 tailed frog, bell toad, ribbed toad, tailed toad, Ascaphus trui
n01664065 loggerhead, loggerhead turtle, Caretta caretta
n01665541 leatherback turtle, leatherback, leathery turtle, Dermochelys coriacea
n01667114 mud turtle
n01667778 terrapin
n01669191 box turtle, box tortoise
n01675722 banded gecko
n01677366 common iguana, iguana, Iguana iguana
n01682714 American chameleon, anole, Anolis carolinensis
n01685808 whiptail, whiptail lizard
n01687978 agama
n01688243 frilled lizard, Chlamydosaurus kingi
n01689811 alligator lizard
n01692333 Gila monster, Heloderma suspectum
n01693334 green lizard, Lacerta viridis
n01694178 African chameleon, Chamaeleo chamaeleon
n01695060 Komodo dragon, Komodo lizard, dragon lizard, giant lizard, Varanus komodoensis
n01697457 African crocodile, Nile crocodile, Crocodylus niloticus
n01698640 American alligator, Alligator mississipiensis
n01704323 triceratops
n01728572 thunder snake, worm snake, Carphophis amoenus
n01728920 ringneck snake, ring-necked snake, ring snake
n01729322 hognose snake, puff adder, sand viper
n01729977 green snake, grass snake
n01734418 king snake, kingsnake
n01735189 garter snake, grass snake
n01737021 water snake
n01739381 vine snake
n01740131 night snake, Hypsiglena torquata
n01742172 boa constrictor, Constrictor constrictor
n01744401 rock python, rock snake, Python sebae
n01748264 Indian cobra, Naja naja
n01749939 green mamba
n01751748 sea snake
n01753488 horned viper, cerastes, sand viper, horned asp, Cerastes cornutus
n01755581 diamondback, diamondback rattlesnake, Crotalus adamanteus
n01756291 sidewinder, horned rattlesnake, Crotalus cerastes
n01768244 trilobite
n01770081 harvestman, daddy longlegs, Phalangium opilio
n01770393 scorpion
n01773157 black and gold garden spider, Argiope aurantia
n01773549 barn spider, Araneus cavaticus
n01773797 garden spider, Aranea diademata
n01774384 black widow, Latrodectus mactans
n01774750 tarantula
n01775062 wolf spider, hunting spider
n01776313 tick
n01784675 centipede
n01795545 black grouse
n01796340 ptarmigan
n01797886 ruffed grouse, partridge, Bonasa umbellus
n01798484 prairie chicken, prairie grouse, prairie fowl
n01806143 peacock
n01806567 quail
n01807496 partridge
n01817953 African grey, African gray, Psittacus erithacus
n01818515 macaw
n01819313 sulphur-crested cockatoo, Kakatoe galerita, Cacatua galerita
n01820546 lorikeet
n01824575 coucal
n01828970 bee eater
n01829413 hornbill
n01833805 hummingbird
n01843065 jacamar
n01843383 toucan
n01847000 drake
n01855032 red-breasted merganser, Mergus serrator
n01855672 goose
n01860187 black swan, Cygnus atratus
n01871265 tusker
n01872401 echidna, spiny anteater, anteater
n01873310 platypus, duckbill, duckbilled platypus, duck-billed platypus, Ornithorhynchus anatinus
n01877812 wallaby, brush kangaroo
n01882714 koala, koala bear, kangaroo bear, native bear, Phascolarctos cinereus
n01883070 wombat
n01910747 jellyfish
n01914609 sea anemone, anemone
n01917289 brain coral
n01924916 flatworm, platyhelminth
n01930112 nematode, nematode worm, roundworm
n01943899 conch
n01944390 snail
n01945685 slug
n01950731 sea slug, nudibranch
n01955084 chiton, coat-of-mail shell, sea cradle, polyplacophore
n01968897 chambered nautilus, pearly nautilus, nautilus
n01978287 Dungeness crab, Cancer magister
n01978455 rock crab, Cancer irroratus
n01980166 fiddler crab
n01981276 king crab, Alaska crab, Alaskan king crab, Alaska king crab, Paralithodes camtschatica
n01983481 American lobster, Northern lobster, Maine lobster, Homarus americanus
n01984695 spiny lobster, langouste, rock lobster, crawfish, crayfish, sea crawfish
n01985128 crayfish, crawfish, crawdad, crawdaddy
n01986214 hermit crab
n01990800 isopod
n02002556 white stork, Ciconia ciconia
n02002724 black stork, Ciconia nigra
n02006656 spoonbill
n02007558 flamingo
n02009229 little blue heron, Egretta caerulea
n02009912 American egret, great white heron, Egretta albus
n02011460 bittern
n02012849 crane
n02013706 limpkin, Aramus pictus
n02017213 European gallinule, Porphyrio porphyrio
n02018207 American coot, marsh hen, mud hen, water hen, Fulica americana
n02018795 bustard
n02025239 ruddy turnstone, Arenaria interpres
n02027492 red-backed sandpiper, dunlin, Erolia alpina
n02028035 redshank, Tringa totanus
n02033041 dowitcher
n02037110 oystercatcher, oyster catcher
n02051845 pelican
n02056570 king penguin, Aptenodytes patagonica
n02058221 albatross, mollymawk
n02066245 grey whale, gray whale, devilfish, Eschrichtius gibbosus, Eschrichtius robustus
n02071294 killer whale, killer, orca, grampus, sea wolf, Orcinus orca
n02074367 dugong, Dugong dugon
n02077923 sea lion
n02085620 Chihuahua
n02085782 Japanese spaniel
n02085936 Maltese dog, Maltese terrier, Maltese
n02086079 Pekinese, Pekingese, Peke
n02086240 Shih-Tzu
n02086646 Blenheim spaniel
n02086910 papillon
n02087046 toy terrier
n02087394 Rhodesian ridgeback
n02088094 Afghan
hound, Afghan 162 | n02088238 basset, basset hound 163 | n02088364 beagle 164 | n02088466 bloodhound, sleuthhound 165 | n02088632 bluetick 166 | n02089078 black-and-tan coonhound 167 | n02089867 Walker hound, Walker foxhound 168 | n02089973 English foxhound 169 | n02090379 redbone 170 | n02090622 borzoi, Russian wolfhound 171 | n02090721 Irish wolfhound 172 | n02091032 Italian greyhound 173 | n02091134 whippet 174 | n02091244 Ibizan hound, Ibizan Podenco 175 | n02091467 Norwegian elkhound, elkhound 176 | n02091635 otterhound, otter hound 177 | n02091831 Saluki, gazelle hound 178 | n02092002 Scottish deerhound, deerhound 179 | n02092339 Weimaraner 180 | n02093256 Staffordshire bullterrier, Staffordshire bull terrier 181 | n02093428 American Staffordshire terrier, Staffordshire terrier, American pit bull terrier, pit bull terrier 182 | n02093647 Bedlington terrier 183 | n02093754 Border terrier 184 | n02093859 Kerry blue terrier 185 | n02093991 Irish terrier 186 | n02094114 Norfolk terrier 187 | n02094258 Norwich terrier 188 | n02094433 Yorkshire terrier 189 | n02095314 wire-haired fox terrier 190 | n02095570 Lakeland terrier 191 | n02095889 Sealyham terrier, Sealyham 192 | n02096051 Airedale, Airedale terrier 193 | n02096177 cairn, cairn terrier 194 | n02096294 Australian terrier 195 | n02096437 Dandie Dinmont, Dandie Dinmont terrier 196 | n02096585 Boston bull, Boston terrier 197 | n02097047 miniature schnauzer 198 | n02097130 giant schnauzer 199 | n02097209 standard schnauzer 200 | n02097298 Scotch terrier, Scottish terrier, Scottie 201 | n02097474 Tibetan terrier, chrysanthemum dog 202 | n02097658 silky terrier, Sydney silky 203 | n02098105 soft-coated wheaten terrier 204 | n02098286 West Highland white terrier 205 | n02098413 Lhasa, Lhasa apso 206 | n02099267 flat-coated retriever 207 | n02099429 curly-coated retriever 208 | n02099601 golden retriever 209 | n02099712 Labrador retriever 210 | n02099849 Chesapeake Bay retriever 211 | n02100236 German short-haired 
pointer 212 | n02100583 vizsla, Hungarian pointer 213 | n02100735 English setter 214 | n02100877 Irish setter, red setter 215 | n02101006 Gordon setter 216 | n02101388 Brittany spaniel 217 | n02101556 clumber, clumber spaniel 218 | n02102040 English springer, English springer spaniel 219 | n02102177 Welsh springer spaniel 220 | n02102318 cocker spaniel, English cocker spaniel, cocker 221 | n02102480 Sussex spaniel 222 | n02102973 Irish water spaniel 223 | n02104029 kuvasz 224 | n02104365 schipperke 225 | n02105056 groenendael 226 | n02105162 malinois 227 | n02105251 briard 228 | n02105412 kelpie 229 | n02105505 komondor 230 | n02105641 Old English sheepdog, bobtail 231 | n02105855 Shetland sheepdog, Shetland sheep dog, Shetland 232 | n02106030 collie 233 | n02106166 Border collie 234 | n02106382 Bouvier des Flandres, Bouviers des Flandres 235 | n02106550 Rottweiler 236 | n02106662 German shepherd, German shepherd dog, German police dog, alsatian 237 | n02107142 Doberman, Doberman pinscher 238 | n02107312 miniature pinscher 239 | n02107574 Greater Swiss Mountain dog 240 | n02107683 Bernese mountain dog 241 | n02107908 Appenzeller 242 | n02108000 EntleBucher 243 | n02108089 boxer 244 | n02108422 bull mastiff 245 | n02108551 Tibetan mastiff 246 | n02108915 French bulldog 247 | n02109047 Great Dane 248 | n02109525 Saint Bernard, St Bernard 249 | n02109961 Eskimo dog, husky 250 | n02110063 malamute, malemute, Alaskan malamute 251 | n02110185 Siberian husky 252 | n02110341 dalmatian, coach dog, carriage dog 253 | n02110627 affenpinscher, monkey pinscher, monkey dog 254 | n02110806 basenji 255 | n02110958 pug, pug-dog 256 | n02111129 Leonberg 257 | n02111277 Newfoundland, Newfoundland dog 258 | n02111500 Great Pyrenees 259 | n02111889 Samoyed, Samoyede 260 | n02112018 Pomeranian 261 | n02112137 chow, chow chow 262 | n02112350 keeshond 263 | n02112706 Brabancon griffon 264 | n02113023 Pembroke, Pembroke Welsh corgi 265 | n02113186 Cardigan, Cardigan Welsh corgi 266 | 
n02113624 toy poodle 267 | n02113712 miniature poodle 268 | n02113799 standard poodle 269 | n02113978 Mexican hairless 270 | n02114367 timber wolf, grey wolf, gray wolf, Canis lupus 271 | n02114548 white wolf, Arctic wolf, Canis lupus tundrarum 272 | n02114712 red wolf, maned wolf, Canis rufus, Canis niger 273 | n02114855 coyote, prairie wolf, brush wolf, Canis latrans 274 | n02115641 dingo, warrigal, warragal, Canis dingo 275 | n02115913 dhole, Cuon alpinus 276 | n02116738 African hunting dog, hyena dog, Cape hunting dog, Lycaon pictus 277 | n02117135 hyena, hyaena 278 | n02119022 red fox, Vulpes vulpes 279 | n02119789 kit fox, Vulpes macrotis 280 | n02120079 Arctic fox, white fox, Alopex lagopus 281 | n02120505 grey fox, gray fox, Urocyon cinereoargenteus 282 | n02123045 tabby, tabby cat 283 | n02123159 tiger cat 284 | n02123394 Persian cat 285 | n02123597 Siamese cat, Siamese 286 | n02124075 Egyptian cat 287 | n02125311 cougar, puma, catamount, mountain lion, painter, panther, Felis concolor 288 | n02127052 lynx, catamount 289 | n02128385 leopard, Panthera pardus 290 | n02128757 snow leopard, ounce, Panthera uncia 291 | n02128925 jaguar, panther, Panthera onca, Felis onca 292 | n02129165 lion, king of beasts, Panthera leo 293 | n02129604 tiger, Panthera tigris 294 | n02130308 cheetah, chetah, Acinonyx jubatus 295 | n02132136 brown bear, bruin, Ursus arctos 296 | n02133161 American black bear, black bear, Ursus americanus, Euarctos americanus 297 | n02134084 ice bear, polar bear, Ursus Maritimus, Thalarctos maritimus 298 | n02134418 sloth bear, Melursus ursinus, Ursus ursinus 299 | n02137549 mongoose 300 | n02138441 meerkat, mierkat 301 | n02165105 tiger beetle 302 | n02165456 ladybug, ladybeetle, lady beetle, ladybird, ladybird beetle 303 | n02167151 ground beetle, carabid beetle 304 | n02168699 long-horned beetle, longicorn, longicorn beetle 305 | n02169497 leaf beetle, chrysomelid 306 | n02172182 dung beetle 307 | n02174001 rhinoceros beetle 308 | n02177972 
weevil 309 | n02190166 fly 310 | n02206856 bee 311 | n02219486 ant, emmet, pismire 312 | n02226429 grasshopper, hopper 313 | n02229544 cricket 314 | n02231487 walking stick, walkingstick, stick insect 315 | n02233338 cockroach, roach 316 | n02236044 mantis, mantid 317 | n02256656 cicada, cicala 318 | n02259212 leafhopper 319 | n02264363 lacewing, lacewing fly 320 | n02268443 dragonfly, darning needle, devil's darning needle, sewing needle, snake feeder, snake doctor, mosquito hawk, skeeter hawk 321 | n02268853 damselfly 322 | n02276258 admiral 323 | n02277742 ringlet, ringlet butterfly 324 | n02279972 monarch, monarch butterfly, milkweed butterfly, Danaus plexippus 325 | n02280649 cabbage butterfly 326 | n02281406 sulphur butterfly, sulfur butterfly 327 | n02281787 lycaenid, lycaenid butterfly 328 | n02317335 starfish, sea star 329 | n02319095 sea urchin 330 | n02321529 sea cucumber, holothurian 331 | n02325366 wood rabbit, cottontail, cottontail rabbit 332 | n02326432 hare 333 | n02328150 Angora, Angora rabbit 334 | n02342885 hamster 335 | n02346627 porcupine, hedgehog 336 | n02356798 fox squirrel, eastern fox squirrel, Sciurus niger 337 | n02361337 marmot 338 | n02363005 beaver 339 | n02364673 guinea pig, Cavia cobaya 340 | n02389026 sorrel 341 | n02391049 zebra 342 | n02395406 hog, pig, grunter, squealer, Sus scrofa 343 | n02396427 wild boar, boar, Sus scrofa 344 | n02397096 warthog 345 | n02398521 hippopotamus, hippo, river horse, Hippopotamus amphibius 346 | n02403003 ox 347 | n02408429 water buffalo, water ox, Asiatic buffalo, Bubalus bubalis 348 | n02410509 bison 349 | n02412080 ram, tup 350 | n02415577 bighorn, bighorn sheep, cimarron, Rocky Mountain bighorn, Rocky Mountain sheep, Ovis canadensis 351 | n02417914 ibex, Capra ibex 352 | n02422106 hartebeest 353 | n02422699 impala, Aepyceros melampus 354 | n02423022 gazelle 355 | n02437312 Arabian camel, dromedary, Camelus dromedarius 356 | n02437616 llama 357 | n02441942 weasel 358 | n02442845 mink 359 | 
n02443114 polecat, fitch, foulmart, foumart, Mustela putorius 360 | n02443484 black-footed ferret, ferret, Mustela nigripes 361 | n02444819 otter 362 | n02445715 skunk, polecat, wood pussy 363 | n02447366 badger 364 | n02454379 armadillo 365 | n02457408 three-toed sloth, ai, Bradypus tridactylus 366 | n02480495 orangutan, orang, orangutang, Pongo pygmaeus 367 | n02480855 gorilla, Gorilla gorilla 368 | n02481823 chimpanzee, chimp, Pan troglodytes 369 | n02483362 gibbon, Hylobates lar 370 | n02483708 siamang, Hylobates syndactylus, Symphalangus syndactylus 371 | n02484975 guenon, guenon monkey 372 | n02486261 patas, hussar monkey, Erythrocebus patas 373 | n02486410 baboon 374 | n02487347 macaque 375 | n02488291 langur 376 | n02488702 colobus, colobus monkey 377 | n02489166 proboscis monkey, Nasalis larvatus 378 | n02490219 marmoset 379 | n02492035 capuchin, ringtail, Cebus capucinus 380 | n02492660 howler monkey, howler 381 | n02493509 titi, titi monkey 382 | n02493793 spider monkey, Ateles geoffroyi 383 | n02494079 squirrel monkey, Saimiri sciureus 384 | n02497673 Madagascar cat, ring-tailed lemur, Lemur catta 385 | n02500267 indri, indris, Indri indri, Indri brevicaudatus 386 | n02504013 Indian elephant, Elephas maximus 387 | n02504458 African elephant, Loxodonta africana 388 | n02509815 lesser panda, red panda, panda, bear cat, cat bear, Ailurus fulgens 389 | n02510455 giant panda, panda, panda bear, coon bear, Ailuropoda melanoleuca 390 | n02514041 barracouta, snoek 391 | n02526121 eel 392 | n02536864 coho, cohoe, coho salmon, blue jack, silver salmon, Oncorhynchus kisutch 393 | n02606052 rock beauty, Holocanthus tricolor 394 | n02607072 anemone fish 395 | n02640242 sturgeon 396 | n02641379 gar, garfish, garpike, billfish, Lepisosteus osseus 397 | n02643566 lionfish 398 | n02655020 puffer, pufferfish, blowfish, globefish 399 | n02666196 abacus 400 | n02667093 abaya 401 | n02669723 academic gown, academic robe, judge's robe 402 | n02672831 accordion, piano 
accordion, squeeze box 403 | n02676566 acoustic guitar 404 | n02687172 aircraft carrier, carrier, flattop, attack aircraft carrier 405 | n02690373 airliner 406 | n02692877 airship, dirigible 407 | n02699494 altar 408 | n02701002 ambulance 409 | n02704792 amphibian, amphibious vehicle 410 | n02708093 analog clock 411 | n02727426 apiary, bee house 412 | n02730930 apron 413 | n02747177 ashcan, trash can, garbage can, wastebin, ash bin, ash-bin, ashbin, dustbin, trash barrel, trash bin 414 | n02749479 assault rifle, assault gun 415 | n02769748 backpack, back pack, knapsack, packsack, rucksack, haversack 416 | n02776631 bakery, bakeshop, bakehouse 417 | n02777292 balance beam, beam 418 | n02782093 balloon 419 | n02783161 ballpoint, ballpoint pen, ballpen, Biro 420 | n02786058 Band Aid 421 | n02787622 banjo 422 | n02788148 bannister, banister, balustrade, balusters, handrail 423 | n02790996 barbell 424 | n02791124 barber chair 425 | n02791270 barbershop 426 | n02793495 barn 427 | n02794156 barometer 428 | n02795169 barrel, cask 429 | n02797295 barrow, garden cart, lawn cart, wheelbarrow 430 | n02799071 baseball 431 | n02802426 basketball 432 | n02804414 bassinet 433 | n02804610 bassoon 434 | n02807133 bathing cap, swimming cap 435 | n02808304 bath towel 436 | n02808440 bathtub, bathing tub, bath, tub 437 | n02814533 beach wagon, station wagon, wagon, estate car, beach waggon, station waggon, waggon 438 | n02814860 beacon, lighthouse, beacon light, pharos 439 | n02815834 beaker 440 | n02817516 bearskin, busby, shako 441 | n02823428 beer bottle 442 | n02823750 beer glass 443 | n02825657 bell cote, bell cot 444 | n02834397 bib 445 | n02835271 bicycle-built-for-two, tandem bicycle, tandem 446 | n02837789 bikini, two-piece 447 | n02840245 binder, ring-binder 448 | n02841315 binoculars, field glasses, opera glasses 449 | n02843684 birdhouse 450 | n02859443 boathouse 451 | n02860847 bobsled, bobsleigh, bob 452 | n02865351 bolo tie, bolo, bola tie, bola 453 | n02869837 bonnet, 
poke bonnet 454 | n02870880 bookcase 455 | n02871525 bookshop, bookstore, bookstall 456 | n02877765 bottlecap 457 | n02879718 bow 458 | n02883205 bow tie, bow-tie, bowtie 459 | n02892201 brass, memorial tablet, plaque 460 | n02892767 brassiere, bra, bandeau 461 | n02894605 breakwater, groin, groyne, mole, bulwark, seawall, jetty 462 | n02895154 breastplate, aegis, egis 463 | n02906734 broom 464 | n02909870 bucket, pail 465 | n02910353 buckle 466 | n02916936 bulletproof vest 467 | n02917067 bullet train, bullet 468 | n02927161 butcher shop, meat market 469 | n02930766 cab, hack, taxi, taxicab 470 | n02939185 caldron, cauldron 471 | n02948072 candle, taper, wax light 472 | n02950826 cannon 473 | n02951358 canoe 474 | n02951585 can opener, tin opener 475 | n02963159 cardigan 476 | n02965783 car mirror 477 | n02966193 carousel, carrousel, merry-go-round, roundabout, whirligig 478 | n02966687 carpenter's kit, tool kit 479 | n02971356 carton 480 | n02974003 car wheel 481 | n02977058 cash machine, cash dispenser, automated teller machine, automatic teller machine, automated teller, automatic teller, ATM 482 | n02978881 cassette 483 | n02979186 cassette player 484 | n02980441 castle 485 | n02981792 catamaran 486 | n02988304 CD player 487 | n02992211 cello, violoncello 488 | n02992529 cellular telephone, cellular phone, cellphone, cell, mobile phone 489 | n02999410 chain 490 | n03000134 chainlink fence 491 | n03000247 chain mail, ring mail, mail, chain armor, chain armour, ring armor, ring armour 492 | n03000684 chain saw, chainsaw 493 | n03014705 chest 494 | n03016953 chiffonier, commode 495 | n03017168 chime, bell, gong 496 | n03018349 china cabinet, china closet 497 | n03026506 Christmas stocking 498 | n03028079 church, church building 499 | n03032252 cinema, movie theater, movie theatre, movie house, picture palace 500 | n03041632 cleaver, meat cleaver, chopper 501 | n03042490 cliff dwelling 502 | n03045698 cloak 503 | n03047690 clog, geta, patten, sabot 504 | n03062245 
cocktail shaker 505 | n03063599 coffee mug 506 | n03063689 coffeepot 507 | n03065424 coil, spiral, volute, whorl, helix 508 | n03075370 combination lock 509 | n03085013 computer keyboard, keypad 510 | n03089624 confectionery, confectionary, candy store 511 | n03095699 container ship, containership, container vessel 512 | n03100240 convertible 513 | n03109150 corkscrew, bottle screw 514 | n03110669 cornet, horn, trumpet, trump 515 | n03124043 cowboy boot 516 | n03124170 cowboy hat, ten-gallon hat 517 | n03125729 cradle 518 | n03126707 crane 519 | n03127747 crash helmet 520 | n03127925 crate 521 | n03131574 crib, cot 522 | n03133878 Crock Pot 523 | n03134739 croquet ball 524 | n03141823 crutch 525 | n03146219 cuirass 526 | n03160309 dam, dike, dyke 527 | n03179701 desk 528 | n03180011 desktop computer 529 | n03187595 dial telephone, dial phone 530 | n03188531 diaper, nappy, napkin 531 | n03196217 digital clock 532 | n03197337 digital watch 533 | n03201208 dining table, board 534 | n03207743 dishrag, dishcloth 535 | n03207941 dishwasher, dish washer, dishwashing machine 536 | n03208938 disk brake, disc brake 537 | n03216828 dock, dockage, docking facility 538 | n03218198 dogsled, dog sled, dog sleigh 539 | n03220513 dome 540 | n03223299 doormat, welcome mat 541 | n03240683 drilling platform, offshore rig 542 | n03249569 drum, membranophone, tympan 543 | n03250847 drumstick 544 | n03255030 dumbbell 545 | n03259280 Dutch oven 546 | n03271574 electric fan, blower 547 | n03272010 electric guitar 548 | n03272562 electric locomotive 549 | n03290653 entertainment center 550 | n03291819 envelope 551 | n03297495 espresso maker 552 | n03314780 face powder 553 | n03325584 feather boa, boa 554 | n03337140 file, file cabinet, filing cabinet 555 | n03344393 fireboat 556 | n03345487 fire engine, fire truck 557 | n03347037 fire screen, fireguard 558 | n03355925 flagpole, flagstaff 559 | n03372029 flute, transverse flute 560 | n03376595 folding chair 561 | n03379051 football helmet 
562 | n03384352 forklift 563 | n03388043 fountain 564 | n03388183 fountain pen 565 | n03388549 four-poster 566 | n03393912 freight car 567 | n03394916 French horn, horn 568 | n03400231 frying pan, frypan, skillet 569 | n03404251 fur coat 570 | n03417042 garbage truck, dustcart 571 | n03424325 gasmask, respirator, gas helmet 572 | n03425413 gas pump, gasoline pump, petrol pump, island dispenser 573 | n03443371 goblet 574 | n03444034 go-kart 575 | n03445777 golf ball 576 | n03445924 golfcart, golf cart 577 | n03447447 gondola 578 | n03447721 gong, tam-tam 579 | n03450230 gown 580 | n03452741 grand piano, grand 581 | n03457902 greenhouse, nursery, glasshouse 582 | n03459775 grille, radiator grille 583 | n03461385 grocery store, grocery, food market, market 584 | n03467068 guillotine 585 | n03476684 hair slide 586 | n03476991 hair spray 587 | n03478589 half track 588 | n03481172 hammer 589 | n03482405 hamper 590 | n03483316 hand blower, blow dryer, blow drier, hair dryer, hair drier 591 | n03485407 hand-held computer, hand-held microcomputer 592 | n03485794 handkerchief, hankie, hanky, hankey 593 | n03492542 hard disc, hard disk, fixed disk 594 | n03494278 harmonica, mouth organ, harp, mouth harp 595 | n03495258 harp 596 | n03496892 harvester, reaper 597 | n03498962 hatchet 598 | n03527444 holster 599 | n03529860 home theater, home theatre 600 | n03530642 honeycomb 601 | n03532672 hook, claw 602 | n03534580 hoopskirt, crinoline 603 | n03535780 horizontal bar, high bar 604 | n03538406 horse cart, horse-cart 605 | n03544143 hourglass 606 | n03584254 iPod 607 | n03584829 iron, smoothing iron 608 | n03590841 jack-o'-lantern 609 | n03594734 jean, blue jean, denim 610 | n03594945 jeep, landrover 611 | n03595614 jersey, T-shirt, tee shirt 612 | n03598930 jigsaw puzzle 613 | n03599486 jinrikisha, ricksha, rickshaw 614 | n03602883 joystick 615 | n03617480 kimono 616 | n03623198 knee pad 617 | n03627232 knot 618 | n03630383 lab coat, laboratory coat 619 | n03633091 ladle 620 | 
n03637318 lampshade, lamp shade 621 | n03642806 laptop, laptop computer 622 | n03649909 lawn mower, mower 623 | n03657121 lens cap, lens cover 624 | n03658185 letter opener, paper knife, paperknife 625 | n03661043 library 626 | n03662601 lifeboat 627 | n03666591 lighter, light, igniter, ignitor 628 | n03670208 limousine, limo 629 | n03673027 liner, ocean liner 630 | n03676483 lipstick, lip rouge 631 | n03680355 Loafer 632 | n03690938 lotion 633 | n03691459 loudspeaker, speaker, speaker unit, loudspeaker system, speaker system 634 | n03692522 loupe, jeweler's loupe 635 | n03697007 lumbermill, sawmill 636 | n03706229 magnetic compass 637 | n03709823 mailbag, postbag 638 | n03710193 mailbox, letter box 639 | n03710637 maillot 640 | n03710721 maillot, tank suit 641 | n03717622 manhole cover 642 | n03720891 maraca 643 | n03721384 marimba, xylophone 644 | n03724870 mask 645 | n03729826 matchstick 646 | n03733131 maypole 647 | n03733281 maze, labyrinth 648 | n03733805 measuring cup 649 | n03742115 medicine chest, medicine cabinet 650 | n03743016 megalith, megalithic structure 651 | n03759954 microphone, mike 652 | n03761084 microwave, microwave oven 653 | n03763968 military uniform 654 | n03764736 milk can 655 | n03769881 minibus 656 | n03770439 miniskirt, mini 657 | n03770679 minivan 658 | n03773504 missile 659 | n03775071 mitten 660 | n03775546 mixing bowl 661 | n03776460 mobile home, manufactured home 662 | n03777568 Model T 663 | n03777754 modem 664 | n03781244 monastery 665 | n03782006 monitor 666 | n03785016 moped 667 | n03786901 mortar 668 | n03787032 mortarboard 669 | n03788195 mosque 670 | n03788365 mosquito net 671 | n03791053 motor scooter, scooter 672 | n03792782 mountain bike, all-terrain bike, off-roader 673 | n03792972 mountain tent 674 | n03793489 mouse, computer mouse 675 | n03794056 mousetrap 676 | n03796401 moving van 677 | n03803284 muzzle 678 | n03804744 nail 679 | n03814639 neck brace 680 | n03814906 necklace 681 | n03825788 nipple 682 | n03832673 
notebook, notebook computer 683 | n03837869 obelisk 684 | n03838899 oboe, hautboy, hautbois 685 | n03840681 ocarina, sweet potato 686 | n03841143 odometer, hodometer, mileometer, milometer 687 | n03843555 oil filter 688 | n03854065 organ, pipe organ 689 | n03857828 oscilloscope, scope, cathode-ray oscilloscope, CRO 690 | n03866082 overskirt 691 | n03868242 oxcart 692 | n03868863 oxygen mask 693 | n03871628 packet 694 | n03873416 paddle, boat paddle 695 | n03874293 paddlewheel, paddle wheel 696 | n03874599 padlock 697 | n03876231 paintbrush 698 | n03877472 pajama, pyjama, pj's, jammies 699 | n03877845 palace 700 | n03884397 panpipe, pandean pipe, syrinx 701 | n03887697 paper towel 702 | n03888257 parachute, chute 703 | n03888605 parallel bars, bars 704 | n03891251 park bench 705 | n03891332 parking meter 706 | n03895866 passenger car, coach, carriage 707 | n03899768 patio, terrace 708 | n03902125 pay-phone, pay-station 709 | n03903868 pedestal, plinth, footstall 710 | n03908618 pencil box, pencil case 711 | n03908714 pencil sharpener 712 | n03916031 perfume, essence 713 | n03920288 Petri dish 714 | n03924679 photocopier 715 | n03929660 pick, plectrum, plectron 716 | n03929855 pickelhaube 717 | n03930313 picket fence, paling 718 | n03930630 pickup, pickup truck 719 | n03933933 pier 720 | n03935335 piggy bank, penny bank 721 | n03937543 pill bottle 722 | n03938244 pillow 723 | n03942813 ping-pong ball 724 | n03944341 pinwheel 725 | n03947888 pirate, pirate ship 726 | n03950228 pitcher, ewer 727 | n03954731 plane, carpenter's plane, woodworking plane 728 | n03956157 planetarium 729 | n03958227 plastic bag 730 | n03961711 plate rack 731 | n03967562 plow, plough 732 | n03970156 plunger, plumber's helper 733 | n03976467 Polaroid camera, Polaroid Land camera 734 | n03976657 pole 735 | n03977966 police van, police wagon, paddy wagon, patrol wagon, wagon, black Maria 736 | n03980874 poncho 737 | n03982430 pool table, billiard table, snooker table 738 | n03983396 pop bottle, 
soda bottle 739 | n03991062 pot, flowerpot 740 | n03992509 potter's wheel 741 | n03995372 power drill 742 | n03998194 prayer rug, prayer mat 743 | n04004767 printer 744 | n04005630 prison, prison house 745 | n04008634 projectile, missile 746 | n04009552 projector 747 | n04019541 puck, hockey puck 748 | n04023962 punching bag, punch bag, punching ball, punchball 749 | n04026417 purse 750 | n04033901 quill, quill pen 751 | n04033995 quilt, comforter, comfort, puff 752 | n04037443 racer, race car, racing car 753 | n04039381 racket, racquet 754 | n04040759 radiator 755 | n04041544 radio, wireless 756 | n04044716 radio telescope, radio reflector 757 | n04049303 rain barrel 758 | n04065272 recreational vehicle, RV, R.V. 759 | n04067472 reel 760 | n04069434 reflex camera 761 | n04070727 refrigerator, icebox 762 | n04074963 remote control, remote 763 | n04081281 restaurant, eating house, eating place, eatery 764 | n04086273 revolver, six-gun, six-shooter 765 | n04090263 rifle 766 | n04099969 rocking chair, rocker 767 | n04111531 rotisserie 768 | n04116512 rubber eraser, rubber, pencil eraser 769 | n04118538 rugby ball 770 | n04118776 rule, ruler 771 | n04120489 running shoe 772 | n04125021 safe 773 | n04127249 safety pin 774 | n04131690 saltshaker, salt shaker 775 | n04133789 sandal 776 | n04136333 sarong 777 | n04141076 sax, saxophone 778 | n04141327 scabbard 779 | n04141975 scale, weighing machine 780 | n04146614 school bus 781 | n04147183 schooner 782 | n04149813 scoreboard 783 | n04152593 screen, CRT screen 784 | n04153751 screw 785 | n04154565 screwdriver 786 | n04162706 seat belt, seatbelt 787 | n04179913 sewing machine 788 | n04192698 shield, buckler 789 | n04200800 shoe shop, shoe-shop, shoe store 790 | n04201297 shoji 791 | n04204238 shopping basket 792 | n04204347 shopping cart 793 | n04208210 shovel 794 | n04209133 shower cap 795 | n04209239 shower curtain 796 | n04228054 ski 797 | n04229816 ski mask 798 | n04235860 sleeping bag 799 | n04238763 slide rule, 
slipstick 800 | n04239074 sliding door 801 | n04243546 slot, one-armed bandit 802 | n04251144 snorkel 803 | n04252077 snowmobile 804 | n04252225 snowplow, snowplough 805 | n04254120 soap dispenser 806 | n04254680 soccer ball 807 | n04254777 sock 808 | n04258138 solar dish, solar collector, solar furnace 809 | n04259630 sombrero 810 | n04263257 soup bowl 811 | n04264628 space bar 812 | n04265275 space heater 813 | n04266014 space shuttle 814 | n04270147 spatula 815 | n04273569 speedboat 816 | n04275548 spider web, spider's web 817 | n04277352 spindle 818 | n04285008 sports car, sport car 819 | n04286575 spotlight, spot 820 | n04296562 stage 821 | n04310018 steam locomotive 822 | n04311004 steel arch bridge 823 | n04311174 steel drum 824 | n04317175 stethoscope 825 | n04325704 stole 826 | n04326547 stone wall 827 | n04328186 stopwatch, stop watch 828 | n04330267 stove 829 | n04332243 strainer 830 | n04335435 streetcar, tram, tramcar, trolley, trolley car 831 | n04336792 stretcher 832 | n04344873 studio couch, day bed 833 | n04346328 stupa, tope 834 | n04347754 submarine, pigboat, sub, U-boat 835 | n04350905 suit, suit of clothes 836 | n04355338 sundial 837 | n04355933 sunglass 838 | n04356056 sunglasses, dark glasses, shades 839 | n04357314 sunscreen, sunblock, sun blocker 840 | n04366367 suspension bridge 841 | n04367480 swab, swob, mop 842 | n04370456 sweatshirt 843 | n04371430 swimming trunks, bathing trunks 844 | n04371774 swing 845 | n04372370 switch, electric switch, electrical switch 846 | n04376876 syringe 847 | n04380533 table lamp 848 | n04389033 tank, army tank, armored combat vehicle, armoured combat vehicle 849 | n04392985 tape player 850 | n04398044 teapot 851 | n04399382 teddy, teddy bear 852 | n04404412 television, television system 853 | n04409515 tennis ball 854 | n04417672 thatch, thatched roof 855 | n04418357 theater curtain, theatre curtain 856 | n04423845 thimble 857 | n04428191 thresher, thrasher, threshing machine 858 | n04429376 throne 859 | 
n04435653 tile roof 860 | n04442312 toaster 861 | n04443257 tobacco shop, tobacconist shop, tobacconist 862 | n04447861 toilet seat 863 | n04456115 torch 864 | n04458633 totem pole 865 | n04461696 tow truck, tow car, wrecker 866 | n04462240 toyshop 867 | n04465501 tractor 868 | n04467665 trailer truck, tractor trailer, trucking rig, rig, articulated lorry, semi 869 | n04476259 tray 870 | n04479046 trench coat 871 | n04482393 tricycle, trike, velocipede 872 | n04483307 trimaran 873 | n04485082 tripod 874 | n04486054 triumphal arch 875 | n04487081 trolleybus, trolley coach, trackless trolley 876 | n04487394 trombone 877 | n04493381 tub, vat 878 | n04501370 turnstile 879 | n04505470 typewriter keyboard 880 | n04507155 umbrella 881 | n04509417 unicycle, monocycle 882 | n04515003 upright, upright piano 883 | n04517823 vacuum, vacuum cleaner 884 | n04522168 vase 885 | n04523525 vault 886 | n04525038 velvet 887 | n04525305 vending machine 888 | n04532106 vestment 889 | n04532670 viaduct 890 | n04536866 violin, fiddle 891 | n04540053 volleyball 892 | n04542943 waffle iron 893 | n04548280 wall clock 894 | n04548362 wallet, billfold, notecase, pocketbook 895 | n04550184 wardrobe, closet, press 896 | n04552348 warplane, military plane 897 | n04553703 washbasin, handbasin, washbowl, lavabo, wash-hand basin 898 | n04554684 washer, automatic washer, washing machine 899 | n04557648 water bottle 900 | n04560804 water jug 901 | n04562935 water tower 902 | n04579145 whiskey jug 903 | n04579432 whistle 904 | n04584207 wig 905 | n04589890 window screen 906 | n04590129 window shade 907 | n04591157 Windsor tie 908 | n04591713 wine bottle 909 | n04592741 wing 910 | n04596742 wok 911 | n04597913 wooden spoon 912 | n04599235 wool, woolen, woollen 913 | n04604644 worm fence, snake fence, snake-rail fence, Virginia fence 914 | n04606251 wreck 915 | n04612504 yawl 916 | n04613696 yurt 917 | n06359193 web site, website, internet site, site 918 | n06596364 comic book 919 | n06785654 crossword 
puzzle, crossword 920 | n06794110 street sign 921 | n06874185 traffic light, traffic signal, stoplight 922 | n07248320 book jacket, dust cover, dust jacket, dust wrapper 923 | n07565083 menu 924 | n07579787 plate 925 | n07583066 guacamole 926 | n07584110 consomme 927 | n07590611 hot pot, hotpot 928 | n07613480 trifle 929 | n07614500 ice cream, icecream 930 | n07615774 ice lolly, lolly, lollipop, popsicle 931 | n07684084 French loaf 932 | n07693725 bagel, beigel 933 | n07695742 pretzel 934 | n07697313 cheeseburger 935 | n07697537 hotdog, hot dog, red hot 936 | n07711569 mashed potato 937 | n07714571 head cabbage 938 | n07714990 broccoli 939 | n07715103 cauliflower 940 | n07716358 zucchini, courgette 941 | n07716906 spaghetti squash 942 | n07717410 acorn squash 943 | n07717556 butternut squash 944 | n07718472 cucumber, cuke 945 | n07718747 artichoke, globe artichoke 946 | n07720875 bell pepper 947 | n07730033 cardoon 948 | n07734744 mushroom 949 | n07742313 Granny Smith 950 | n07745940 strawberry 951 | n07747607 orange 952 | n07749582 lemon 953 | n07753113 fig 954 | n07753275 pineapple, ananas 955 | n07753592 banana 956 | n07754684 jackfruit, jak, jack 957 | n07760859 custard apple 958 | n07768694 pomegranate 959 | n07802026 hay 960 | n07831146 carbonara 961 | n07836838 chocolate sauce, chocolate syrup 962 | n07860988 dough 963 | n07871810 meat loaf, meatloaf 964 | n07873807 pizza, pizza pie 965 | n07875152 potpie 966 | n07880968 burrito 967 | n07892512 red wine 968 | n07920052 espresso 969 | n07930864 cup 970 | n07932039 eggnog 971 | n09193705 alp 972 | n09229709 bubble 973 | n09246464 cliff, drop, drop-off 974 | n09256479 coral reef 975 | n09288635 geyser 976 | n09332890 lakeside, lakeshore 977 | n09399592 promontory, headland, head, foreland 978 | n09421951 sandbar, sand bar 979 | n09428293 seashore, coast, seacoast, sea-coast 980 | n09468604 valley, vale 981 | n09472597 volcano 982 | n09835506 ballplayer, baseball player 983 | n10148035 groom, bridegroom 984 | 
n10565667 scuba diver 985 | n11879895 rapeseed 986 | n11939491 daisy 987 | n12057211 yellow lady's slipper, yellow lady-slipper, Cypripedium calceolus, Cypripedium parviflorum 988 | n12144580 corn 989 | n12267677 acorn 990 | n12620546 hip, rose hip, rosehip 991 | n12768682 buckeye, horse chestnut, conker 992 | n12985857 coral fungus 993 | n12998815 agaric 994 | n13037406 gyromitra 995 | n13040303 stinkhorn, carrion fungus 996 | n13044778 earthstar 997 | n13052670 hen-of-the-woods, hen of the woods, Polyporus frondosus, Grifola frondosa 998 | n13054560 bolete 999 | n13133613 ear, spike, capitulum 1000 | n15075141 toilet tissue, toilet paper, bathroom tissue 1001 | --------------------------------------------------------------------------------