├── README.md
├── Videos
│   ├── README.md
│   ├── c3d_model.py
│   ├── models
│   │   ├── c3d_model.png
│   │   └── train01_16_128_171_mean.npy
│   ├── sports1m
│   │   └── get_labels.sh
│   └── test_model.py
├── classify.py
├── images
│   ├── architecture.png
│   ├── collies.JPG
│   ├── grad-cam-faults.png
│   ├── grad-campp.png
│   ├── multiple_dogs.jpg
│   ├── snake.JPEG
│   └── water-bird.JPEG
├── misc
│   ├── GuideReLU.py
│   ├── GuideReLU.pyc
│   ├── __init__.py
│   ├── __init__.pyc
│   ├── utils.py
│   ├── utils.pyc
│   └── utils.py~
├── models
│   ├── __init__.py
│   ├── __init__.pyc
│   ├── vgg16.py
│   ├── vgg16.pyc
│   ├── vgg19.py
│   ├── vgg_utils.py
│   └── vgg_utils.pyc
├── output
│   ├── multiple_dogs.jpg
│   ├── output.jpeg
│   └── water-bird.JPEG
└── synset.txt
/README.md:
--------------------------------------------------------------------------------
# Grad-CAM++

A generalized gradient-based CNN visualization technique. Code for the paper:
### Grad-CAM++: Generalized Gradient-based Visual Explanations for Deep Convolutional Networks

To be presented at **WACV 2018**.

Authors:

Aditya Chattopadhyay\*,
Anirban Sarkar\*,
Prantik Howlader\* and
Vineeth N Balasubramanian

(\* equal contribution)

##### Grad-CAM++ Architecture

##### Performance of Grad-CAM++ with respect to Grad-CAM

Above, we compare the performance of Grad-CAM++ with that of Grad-CAM. Grad-CAM++ localizes objects better than Grad-CAM, not only when an image contains more than one object of the same class, but also when it contains a single object.
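The two methods differ in how each feature map's weight is computed. The sketch below is a minimal NumPy illustration, not the implementation in `misc/utils.py`. It assumes the class score is passed through an exponential, so that the second and third derivatives with respect to the activations reduce to powers of the first-order gradient; the same assumption underlies `first_derivative`, `second_derivative` and `triple_derivative` in `Videos/test_model.py`. The function names are hypothetical. Unlike the scripts in this repository, it does not rescale the alphas by their per-channel sum.

```python
import numpy as np


def gradcam_weights(grads):
    """Grad-CAM: one weight per channel, the gradient averaged over all spatial axes.

    grads has shape (..., K): gradients of the class score w.r.t. K feature maps.
    """
    spatial_axes = tuple(range(grads.ndim - 1))
    return grads.mean(axis=spatial_axes)


def gradcam_pp_weights(activations, grads):
    """Grad-CAM++ weights, assuming higher-order derivatives are powers of grads.

    alpha = g^2 / (2 g^2 + sum_ab(A_ab) * g^3), weight = sum over positions of alpha * relu(g).
    """
    spatial_axes = tuple(range(grads.ndim - 1))
    g2, g3 = grads ** 2, grads ** 3
    # sum of each feature map's activations over all spatial positions, shape (K,)
    global_sum = activations.sum(axis=spatial_axes)
    denom = 2.0 * g2 + global_sum * g3
    # avoid division by zero where the gradient vanishes
    denom = np.where(denom != 0.0, denom, 1.0)
    alphas = g2 / denom
    # only positive gradients contribute to the channel weight
    return (alphas * np.maximum(grads, 0.0)).sum(axis=spatial_axes)


def class_activation_map(activations, weights):
    """Weighted sum of feature maps, passed through ReLU and scaled to [0, 1]."""
    cam = np.maximum((activations * weights).sum(axis=-1), 0.0)
    peak = cam.max()
    return cam / peak if peak > 0 else cam


# Toy example: a 2x2 map with two channels; channel 1 has only negative gradients.
A = np.ones((2, 2, 2))
G = np.stack([np.ones((2, 2)), -np.ones((2, 2))], axis=-1)
print(gradcam_weights(G))
print(gradcam_pp_weights(A, G))
print(class_activation_map(A, gradcam_pp_weights(A, G)))
```

Because the spatial axes are taken as every axis except the last, the same functions apply to a 2D image map of shape (h, w, K) and to a C3D video map of shape (l, h, w, K).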

The implementation is in Python, using TensorFlow 1.3.
Download the pretrained weights of the VGG16 network (`vgg16.npy`) from the following link and copy the file to the `models/` subdirectory:
https://drive.google.com/drive/folders/0BzS5KZjihEdyUjBHcGFNRnk4bFU?usp=sharing

### USAGE:
```python classify.py -f images/water-bird.JPEG -gpu 3 -o output.jpeg ```


#### Arguments:
- `-f`, `--file_name`: path to the input image
- `-gpu`, `--gpu_device`: ID of the GPU to use, 0-indexed (default `0`)
- `-l`, `--class_label`: ImageNet class index to visualize; the default, -1, uses the class predicted by the model
- `-o`, `--output_filename`: output file name for the Grad-CAM++ visualization (default `output.jpeg`). All results are saved in the `output/` subdirectory.
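With `-l -1`, the class to visualize is the arg-max of the network's 1000 ImageNet probabilities, which is then encoded as a one-hot selector over the class scores, as `guided_BP` in `misc/utils.py` does. A minimal NumPy sketch of that step; `target_one_hot` is a hypothetical helper, not part of this repository:

```python
import numpy as np


def target_one_hot(probs, label_id=-1):
    """Return (one_hot, index) selecting the class to visualize.

    label_id == -1 picks the class with the highest predicted probability;
    any other value is used directly as the class index.
    """
    probs = np.asarray(probs, dtype=np.float32).reshape(-1)
    index = int(np.argmax(probs)) if label_id == -1 else int(label_id)
    if not 0 <= index < probs.size:
        raise ValueError("class index {} out of range [0, {})".format(index, probs.size))
    # shape (1, num_classes), matching a [None, num_classes] label placeholder
    one_hot = np.zeros((1, probs.size), dtype=np.float32)
    one_hot[0, index] = 1.0
    return one_hot, index


probs = [0.1, 0.7, 0.2]
print(target_one_hot(probs))     # predicted class (index 1)
print(target_one_hot(probs, 2))  # user-supplied class
```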

#### For Help:
```python classify.py -h ```

The code above uses the VGG16 network pre-trained on ImageNet.
We tested our code on TensorFlow 1.3; compatibility with other versions is not guaranteed.

#### Acknowledgements
Parts of the code have been borrowed and modified from:
##### For the Grad-CAM TensorFlow implementation
https://github.com/Ankush96/grad-cam.tensorflow
https://github.com/insikk/Grad-CAM-tensorflow
##### For porting the pre-trained VGG16 model from the Caffe Model Zoo to TensorFlow
https://github.com/ry/tensorflow-vgg16

#### If you use this code, please cite our work:
```
@article{chattopadhyay2017grad,
title={Grad-CAM++: Generalized Gradient-based Visual Explanations for Deep Convolutional Networks},
author={Chattopadhyay, Aditya and Sarkar, Anirban and Howlader, Prantik and Balasubramanian, Vineeth N},
journal={arXiv preprint arXiv:1710.11063},
year={2017}
}
```

--------------------------------------------------------------------------------
/Videos/README.md:
--------------------------------------------------------------------------------
C3D Model for Grad-CAM and Grad-CAM++ with Keras + TensorFlow
=============================================================

The scripts here are inspired by the [`C3D Model for Keras`](https://gist.github.com/albertomontesg/d8b21a179c1e6cca0480ebdf292c34d2) gist, but target Keras with the TensorFlow backend (not the Theano backend).

To reproduce the results, run each of these steps:

1. Download the pretrained weights of the C3D model (`sports1M_weights_tf.h5` and `sports1M_weights_tf.json`) from the following link and copy both files to the `models/` subdirectory: `https://drive.google.com/drive/folders/11fi8rtYnPmiyMngUVeOF4vBzUjnISYiH?usp=sharing`
2. Download the Sports-1M labels: `bash sports1m/get_labels.sh`
3. Make sure the default Keras config (`~/.keras/keras.json`) selects the `tensorflow` backend and `tf` (channels-last) image ordering. In Keras 2 the ordering is stored as `"image_data_format": "channels_last"`, which `K.image_dim_ordering()` reports as `tf`.
4. Download some test videos and copy them to the `videos/` subdirectory. Each video needs at least 1016 frames, because `test_model.py` analyzes the 16-frame clip starting at frame 1000.
5. Run the test: `python test_model.py`
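Step 3 can also be checked programmatically. The sketch below is a hypothetical helper (`check_keras_config` is not part of this repository); it assumes the Keras 2 config format, where the ordering key is `image_data_format` and `channels_last` corresponds to the legacy `tf` ordering:

```python
import json
import os


def check_keras_config(path=os.path.join(os.path.expanduser("~"), ".keras", "keras.json")):
    """Return a list of problems found in a Keras 2 config file; an empty list means it looks fine."""
    with open(path) as fh:
        cfg = json.load(fh)
    problems = []
    # this check expects the backend to be set explicitly
    if cfg.get("backend") != "tensorflow":
        problems.append("backend is {!r}, expected 'tensorflow'".format(cfg.get("backend")))
    # Keras 2 falls back to channels_last when the key is absent
    data_format = cfg.get("image_data_format", "channels_last")
    if data_format != "channels_last":
        problems.append("image_data_format is {!r}, expected 'channels_last'".format(data_format))
    return problems
```

Calling `check_keras_config()` with no argument inspects `~/.keras/keras.json`. Each returned string describes one setting that differs from what the scripts expect.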
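
Prerequisites
=============
Known to work with the following Python packages:
- Keras==2.0.6
- tensorflow==1.4.0
- h5py==2.8.0
- numpy==1.14.2

`test_model.py` also imports OpenCV (`cv2`), scikit-image, SciPy and matplotlib.

Results
=======
For every video in the `videos/` subdirectory, this code generates Grad-CAM and Grad-CAM++ saliency maps with respect to the convolutional layer selected in the code (`conv3a` by default). The masked Grad-CAM frames are written to the working directory, the Grad-CAM++ frames to `output/`, and for each video the predicted class's score on the original, Grad-CAM-masked and Grad-CAM++-masked clips is appended to `scores.txt`.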
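The upsampling factors hard-coded in `test_model.py` (`scipy.ndimage.zoom(cam, (2, 4, 4))`) follow from the pooling layers in `c3d_model.py`. For a 16x112x112 clip, `pool1` halves only height and width, and `pool2` halves all three axes, so `conv3a` produces an 8x28x28 map. The sketch below tracks those shapes for any convolutional layer and upsamples a map back to clip size. It is illustrative only: the names are hypothetical, and it uses nearest-neighbour repetition (`np.repeat`) instead of the spline interpolation of `scipy.ndimage.zoom`, so the result is blockier. Selecting a different layer in the script would also require changing the zoom factors and the 256-channel reshapes.

```python
import numpy as np

# Pooling strides (l, h, w) of pool1..pool4 in c3d_model.py, in network order.
POOL_STRIDES = [(1, 2, 2), (2, 2, 2), (2, 2, 2), (2, 2, 2)]
# Number of pooling layers that precede each convolutional layer.
POOLS_BEFORE = {"conv1": 0, "conv2": 1, "conv3a": 2, "conv3b": 2,
                "conv4a": 3, "conv4b": 3, "conv5a": 4, "conv5b": 4}


def conv_output_shape(layer, clip_shape=(16, 112, 112)):
    """(l, h, w) of a conv layer's output: 'same' padding keeps the size, pooling divides it."""
    l, h, w = clip_shape
    for pl, ph, pw in POOL_STRIDES[:POOLS_BEFORE[layer]]:
        l, h, w = l // pl, h // ph, w // pw
    return l, h, w


def upsample_cam(cam, clip_shape=(16, 112, 112)):
    """Nearest-neighbour upsampling of an (l, h, w) map, then ReLU and scaling to [0, 1]."""
    for axis, (target, size) in enumerate(zip(clip_shape, cam.shape)):
        cam = np.repeat(cam, target // size, axis=axis)
    cam = np.maximum(cam, 0.0)
    peak = cam.max()
    return cam / peak if peak > 0 else cam


shape = conv_output_shape("conv3a")                             # (8, 28, 28)
factors = tuple(t // s for t, s in zip((16, 112, 112), shape))  # (2, 4, 4), as in test_model.py
heatmap = upsample_cam(np.random.rand(*shape))
print(shape, factors, heatmap.shape)
```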

The top 5 labels are also reported and should look something like:

```
Top 5 probabilities and labels:
basketball: 0.71422
streetball: 0.10293
volleyball: 0.04900
greco-roman wrestling: 0.02638
freestyle wrestling: 0.02408
```
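The report above comes from reverse-sorting the softmax output with `argsort`, as at the end of `test_model.py`. A minimal sketch of the same ranking as a reusable function; `top_k` is a hypothetical name, and the toy probabilities below are just three values taken from the sample report:

```python
import numpy as np


def top_k(probs, labels, k=5):
    """Return the k most probable (label, probability) pairs, highest first."""
    probs = np.asarray(probs).reshape(-1)
    order = probs.argsort()[::-1][:k]  # indices sorted by decreasing probability
    return [(labels[i], float(probs[i])) for i in order]


labels = ["volleyball", "basketball", "streetball"]
for name, p in top_k([0.04900, 0.71422, 0.10293], labels, k=2):
    print("{}: {:.5f}".format(name, p))  # basketball: 0.71422, then streetball: 0.10293
```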

References
==========

1. [C3D Model for Keras](https://gist.github.com/albertomontesg/d8b21a179c1e6cca0480ebdf292c34d2)
2. [Original C3D implementation in Caffe](https://github.com/facebook/C3D)
3. [C3D paper](https://arxiv.org/abs/1412.0767)

--------------------------------------------------------------------------------
/Videos/c3d_model.py:
--------------------------------------------------------------------------------
from keras.models import Sequential
from keras.layers.core import Dense, Dropout, Flatten
from keras.layers.convolutional import Convolution3D, MaxPooling3D, ZeroPadding3D
from keras.optimizers import SGD

'''
dim_ordering issue:
- 'th'-style dim_ordering: [batch, channels, depth, height, width]
- 'tf'-style dim_ordering: [batch, depth, height, width, channels]
'''

def get_model(summary=False, backend='tf'):
    """ Return the Keras model of the network
    """
    model = Sequential()
    if backend == 'tf':
        input_shape=(16, 112, 112, 3) # l, h, w, c
    else:
        input_shape=(3, 16, 112, 112) # c, l, h, w
    model.add(Convolution3D(64, 3, 3, 3, activation='relu',
                            border_mode='same', name='conv1',
                            input_shape=input_shape))
    model.add(MaxPooling3D(pool_size=(1, 2, 2), strides=(1, 2, 2),
                           border_mode='valid', name='pool1'))
    # 2nd layer group
    model.add(Convolution3D(128, 3, 3, 3, activation='relu',
                            border_mode='same', name='conv2'))
    model.add(MaxPooling3D(pool_size=(2, 2, 2), strides=(2, 2, 2),
                           border_mode='valid', name='pool2'))
    # 3rd layer group
    model.add(Convolution3D(256, 3, 3, 3, activation='relu',
                            border_mode='same', name='conv3a'))
    model.add(Convolution3D(256, 3, 3, 3, activation='relu',
                            border_mode='same', name='conv3b'))
    model.add(MaxPooling3D(pool_size=(2, 2, 2), strides=(2, 2, 2),
                           border_mode='valid', name='pool3'))
    # 4th layer group
    model.add(Convolution3D(512, 3, 3, 3, activation='relu',
                            border_mode='same', name='conv4a'))
    model.add(Convolution3D(512, 3, 3, 3, activation='relu',
                            border_mode='same', name='conv4b'))
    model.add(MaxPooling3D(pool_size=(2, 2, 2), strides=(2, 2, 2),
                           border_mode='valid', name='pool4'))
    # 5th layer group
    model.add(Convolution3D(512, 3, 3, 3, activation='relu',
                            border_mode='same', name='conv5a'))
    model.add(Convolution3D(512, 3, 3, 3, activation='relu',
                            border_mode='same', name='conv5b'))
    model.add(ZeroPadding3D(padding=(0, 1, 1), name='zeropad5'))
    model.add(MaxPooling3D(pool_size=(2, 2, 2), strides=(2, 2, 2),
                           border_mode='valid', name='pool5'))
    model.add(Flatten())
    # FC layers group
    model.add(Dense(4096, activation='relu', name='fc6'))
    model.add(Dropout(.5))
    model.add(Dense(4096, activation='relu', name='fc7'))
    model.add(Dropout(.5))
    model.add(Dense(487, activation='softmax', name='fc8'))

    if summary:
        # summary() prints the table itself and returns None
        model.summary()

    return model

def get_int_model(model, layer, backend='tf'):
    """ Return a copy of `model` truncated after `layer` (e.g. 'conv3a'),
    reusing its trained weights. Returns None for an unknown layer name.
    """
    # Indices into model.layers follow get_model(): conv1=0, conv2=2, conv3a=4,
    # conv3b=5, conv4a=7, conv4b=8, conv5a=10, conv5b=11, fc6=15, fc7=17, fc8=19.

    if backend == 'tf':
        input_shape=(16, 112, 112, 3) # l, h, w, c
    else:
        input_shape=(3, 16, 112, 112) # c, l, h, w

    int_model = Sequential()

    int_model.add(Convolution3D(64, 3, 3, 3, activation='relu',
                                border_mode='same', name='conv1',
                                input_shape=input_shape,
                                weights=model.layers[0].get_weights()))
    if layer == 'conv1':
        return int_model
    int_model.add(MaxPooling3D(pool_size=(1, 2, 2), strides=(1, 2, 2),
                               border_mode='valid', name='pool1'))
    if layer == 'pool1':
        return int_model

    # 2nd layer group
    int_model.add(Convolution3D(128, 3, 3, 3, activation='relu',
                                border_mode='same', name='conv2',
                                weights=model.layers[2].get_weights()))
    if layer == 'conv2':
        return int_model
    int_model.add(MaxPooling3D(pool_size=(2, 2, 2), strides=(2, 2, 2),
                               border_mode='valid', name='pool2'))
    if layer == 'pool2':
        return int_model

    # 3rd layer group
    int_model.add(Convolution3D(256, 3, 3, 3, activation='relu',
                                border_mode='same', name='conv3a',
                                weights=model.layers[4].get_weights()))
    if layer == 'conv3a':
        return int_model
    int_model.add(Convolution3D(256, 3, 3, 3, activation='relu',
                                border_mode='same', name='conv3b',
                                weights=model.layers[5].get_weights()))
    if layer == 'conv3b':
        return int_model
    int_model.add(MaxPooling3D(pool_size=(2, 2, 2), strides=(2, 2, 2),
                               border_mode='valid', name='pool3'))
    if layer == 'pool3':
        return int_model

    # 4th layer group
    int_model.add(Convolution3D(512, 3, 3, 3, activation='relu',
                                border_mode='same', name='conv4a',
                                weights=model.layers[7].get_weights()))
    if layer == 'conv4a':
        return int_model
    int_model.add(Convolution3D(512, 3, 3, 3, activation='relu',
                                border_mode='same', name='conv4b',
                                weights=model.layers[8].get_weights()))
    if layer == 'conv4b':
        return int_model
    int_model.add(MaxPooling3D(pool_size=(2, 2, 2), strides=(2, 2, 2),
                               border_mode='valid', name='pool4'))
    if layer == 'pool4':
        return int_model

    # 5th layer group
    int_model.add(Convolution3D(512, 3, 3, 3, activation='relu',
                                border_mode='same', name='conv5a',
                                weights=model.layers[10].get_weights()))
    if layer == 'conv5a':
        return int_model
    int_model.add(Convolution3D(512, 3, 3, 3, activation='relu',
                                border_mode='same', name='conv5b',
                                weights=model.layers[11].get_weights()))
    if layer == 'conv5b':
        return int_model
    int_model.add(ZeroPadding3D(padding=(0, 1, 1), name='zeropad5'))
    int_model.add(MaxPooling3D(pool_size=(2, 2, 2), strides=(2, 2, 2),
                               border_mode='valid', name='pool5'))
    if layer == 'pool5':
        return int_model

    int_model.add(Flatten())
    # FC layers group
    int_model.add(Dense(4096, activation='relu', name='fc6',
                        weights=model.layers[15].get_weights()))
    if layer == 'fc6':
        return int_model
    int_model.add(Dropout(.5))
    int_model.add(Dense(4096, activation='relu', name='fc7',
                        weights=model.layers[17].get_weights()))
    if layer == 'fc7':
        return int_model
    int_model.add(Dropout(.5))
    int_model.add(Dense(487, activation='softmax', name='fc8',
                        weights=model.layers[19].get_weights()))
    if layer == 'fc8':
        return int_model

    return None

if __name__ == '__main__':
    model = get_model(summary=True)

--------------------------------------------------------------------------------
/Videos/models/c3d_model.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/adityac94/Grad_CAM_plus_plus/4a9faf6ac61ef0c56e19b88d8560b81cd62c5017/Videos/models/c3d_model.png
--------------------------------------------------------------------------------
/Videos/models/train01_16_128_171_mean.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/adityac94/Grad_CAM_plus_plus/4a9faf6ac61ef0c56e19b88d8560b81cd62c5017/Videos/models/train01_16_128_171_mean.npy
--------------------------------------------------------------------------------
/Videos/sports1m/get_labels.sh:
--------------------------------------------------------------------------------
#!/usr/bin/env bash

DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )"

echo ---------------------------------------------
echo Downloading sports-1m-dataset labels...
wget \
    -N \
    https://raw.githubusercontent.com/gtoderici/sports-1m-dataset/master/labels.txt \
    --directory-prefix="${DIR}"

echo ---------------------------------------------
echo Done!

--------------------------------------------------------------------------------
/Videos/test_model.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python

import matplotlib
matplotlib.use('Agg')
#import setGPU
from keras.models import model_from_json
from keras.layers.core import Lambda
import tensorflow as tf
import os
import cv2
import numpy as np
from skimage.transform import resize
import scipy.ndimage
import matplotlib.pyplot as plt
import c3d_model
import sys
import keras.backend as K
# K.set_image_dim_ordering('th')
os.environ["CUDA_VISIBLE_DEVICES"]="2,3"  # adjust to the GPUs available on your machine
dim_ordering = K.image_dim_ordering()  # call it: returns 'tf' (channels last) or 'th' (channels first)
print "[Info] image_dim_order (from default ~/.keras/keras.json)={}".format(
    dim_ordering)
backend = dim_ordering

def target_category_loss(x, category_index, nb_classes):
    # keep only the score of category_index, zero out all other classes
    return tf.multiply(x, K.one_hot([category_index], nb_classes))

def target_category_loss_output_shape(input_shape):
    return input_shape

def normalize(x):
    # utility function to normalize a tensor by its L2 norm
    return x / (K.sqrt(K.mean(K.square(x))) + 1e-5)

def diagnose(data, verbose=True, label='input', plots=False, backend='tf'):
    # Convolution3D?
    if data.ndim > 2:
        if backend == 'th':
            data = np.transpose(data, (1, 2, 3, 0))
        #else:
        #    data = np.transpose(data, (0, 2, 1, 3))
        min_num_spatial_axes = 10
        max_outputs_to_show = 3
        ndim = data.ndim
        print "[Info] {}.ndim={}".format(label, ndim)
        print "[Info] {}.shape={}".format(label, data.shape)
        for d in range(ndim):
            num_this_dim = data.shape[d]
            if num_this_dim >= min_num_spatial_axes: # check for spatial axes
                # just first, center, last indices
                range_this_dim = [0, num_this_dim/2, num_this_dim - 1]
            else:
                # sweep all indices for non-spatial axes
                range_this_dim = range(num_this_dim)
            for i in range_this_dim:
                new_dim = tuple([d] + range(d) + range(d + 1, ndim))
                sliced = np.transpose(data, new_dim)[i, ...]
                print("[Info] {}, dim:{} {}-th slice: "
                      "(min, max, mean, std)=({}, {}, {}, {})".format(
                          label,
                          d, i,
                          np.min(sliced),
                          np.max(sliced),
                          np.mean(sliced),
                          np.std(sliced)))
        if plots:
            # assume (l, h, w, c)-shaped input
            if data.ndim != 4:
                print("[Error] data (shape={}) is not 4-dim. Check data".format(
                    data.shape))
                return
            l, h, w, c = data.shape
            if l >= min_num_spatial_axes or \
               h < min_num_spatial_axes or \
               w < min_num_spatial_axes:
                print("[Error] data (shape={}) does not look like in (l,h,w,c) "
                      "format. Do reshape/transpose.".format(data.shape))
                return
            nrows = int(np.ceil(np.sqrt(data.shape[0])))
            # BGR
            if c == 3:
                for i in range(l):
                    mng = plt.get_current_fig_manager()
                    mng.resize(*mng.window.maxsize())
                    plt.subplot(nrows, nrows, i + 1) # doh, one-based!
                    im = np.squeeze(data[i, ...]).astype(np.float32)
                    im = im[:, :, ::-1] # BGR to RGB
                    # force it to range [0,1]
                    im_min, im_max = im.min(), im.max()
                    if im_max > im_min:
                        im_std = (im - im_min) / (im_max - im_min)
                    else:
                        print "[Warning] image is constant!"
                        im_std = np.zeros_like(im)
                    plt.imshow(im_std)
                    plt.axis('off')
                    plt.title("{}: t={}".format(label, i))
                plt.show()
                #plt.waitforbuttonpress()
            else:
                for j in range(min(c, max_outputs_to_show)):
                    for i in range(l):
                        mng = plt.get_current_fig_manager()
                        mng.resize(*mng.window.maxsize())
                        plt.subplot(nrows, nrows, i + 1) # doh, one-based!
                        im = np.squeeze(data[i, ...]).astype(np.float32)
                        im = im[:, :, j]
                        # force it to range [0,1]
                        im_min, im_max = im.min(), im.max()
                        if im_max > im_min:
                            im_std = (im - im_min) / (im_max - im_min)
                        else:
                            print "[Warning] image is constant!"
                            im_std = np.zeros_like(im)
                        plt.imshow(im_std)
                        plt.axis('off')
                        plt.title("{}: o={}, t={}".format(label, j, i))
                    plt.show()
                    #plt.waitforbuttonpress()
    elif data.ndim == 1:
        print("[Info] {} (min, max, mean, std)=({}, {}, {}, {})".format(
            label,
            np.min(data),
            np.max(data),
            np.mean(data),
            np.std(data)))
        print("[Info] data[:10]={}".format(data[:10]))

    return

def main():
    show_images = False
    diagnose_plots = False
    model_dir = './models'
    global backend

    # override backend if provided as an input arg
    if len(sys.argv) > 1:
        if 'tf' in sys.argv[1].lower():
            backend = 'tf'
        else:
            backend = 'th'
    print "[Info] Using backend={}".format(backend)

    if backend == 'th':
        model_weight_filename = os.path.join(model_dir, 'sports1M_weights_th.h5')
        model_json_filename = os.path.join(model_dir, 'sports1M_weights_th.json')
    else:
        model_weight_filename = os.path.join(model_dir, 'sports1M_weights_tf.h5')
        model_json_filename = os.path.join(model_dir, 'sports1M_weights_tf.json')

    print("[Info] Reading model architecture...")
    model = model_from_json(open(model_json_filename, 'r').read())
    # model = c3d_model.get_model(backend=backend)

    # visualize model
    model_img_filename = os.path.join(model_dir, 'c3d_model.png')
    if not os.path.exists(model_img_filename):
        # Keras 2 API (keras.utils.visualize_util was removed)
        from keras.utils import plot_model
        plot_model(model, to_file=model_img_filename)

    print("[Info] Loading model weights...")
    model.load_weights(model_weight_filename)
    print("[Info] Loading model weights -- DONE!")
    model.compile(loss='mean_squared_error', optimizer='sgd')

    print("[Info] Loading labels...")
    with open('sports1m/labels.txt', 'r') as f:
        labels = [line.strip() for line in f.readlines()]
    print('Total labels: {}'.format(len(labels)))

    print("[Info] Loading a sample video...")

    # cv2.imwrite fails silently if the target directory does not exist
    if not os.path.isdir("./output"):
        os.makedirs("./output")

    scores_file = open("scores.txt", "w")

    for filename in sorted(os.listdir("videos/")):
        try:
            cap = cv2.VideoCapture("videos/" + filename)
            print filename
            vid = []
            while True:
                ret, img = cap.read()
                if not ret:
                    break
                vid.append(cv2.resize(img, (171, 128)))
            vid = np.array(vid, dtype=np.float32)

            # the video needs at least start_frame + 16 frames
            start_frame = 1000

            X = vid[start_frame:(start_frame + 16), :, :, :]

            # NOTE: the mean clip is loaded here but is not subtracted from X;
            # X is later reused to render the heatmap-masked frames.
            mean_cube = np.load('models/train01_16_128_171_mean.npy')
            mean_cube = np.transpose(mean_cube, (1, 2, 3, 0))

            # center crop
            X = X[:, 8:120, 30:142, :] # (l, h, w, c)

            if backend == 'th':
                X = np.transpose(X, (3, 0, 1, 2)) # input_shape = (3,16,112,112)
            else:
                pass # input_shape = (16,112,112,3)

            # remove the class-selection Lambda layer added for the previous video
            if 'lambda' in model.layers[-1].name:
                model.layers.pop()
                model.outputs = [model.layers[-1].output]
                model.output_layers = [model.layers[-1]]
                model.layers[-1].outbound_nodes = []

            # inference
            output = model.predict_on_batch(np.array([X]))

            #################################################
            print X.shape
            predicted_class = np.argmax(output)
            print predicted_class
            print output[0][predicted_class], labels[predicted_class]

            nb_classes = len(labels)#487
            target_layer = lambda x: target_category_loss(x, predicted_class, nb_classes)
            model.add(Lambda(target_layer, output_shape = target_category_loss_output_shape))
            temp_label = np.zeros(output.shape)
            temp_label[0][int(np.argmax(output))] = 1.0
            loss = K.sum(model.layers[-1].output*(temp_label))

            for i in range(14):
                ###########Choose a conv layer to generate saliency maps##########
                if model.layers[i].name == "conv3a":
                    conv_output = model.layers[i].output

            grads = normalize(K.gradients(loss, conv_output)[0])

            # Grad-CAM++ uses Y = exp(S): its derivatives w.r.t. the activations are
            # exp(S) times powers of dS/dA (higher-order terms of S vanish for ReLU networks)
            first_derivative = tf.exp(loss)*grads
            print first_derivative[0]
            print tf.exp(loss)

            #second_derivative
            second_derivative = tf.exp(loss)*grads*grads
            print second_derivative[0]

            #triple_derivative
            triple_derivative = tf.exp(loss)*grads*grads*grads
            print triple_derivative[0]

            gradient_function = K.function([model.layers[0].input, K.learning_phase()], [conv_output, grads, first_derivative, second_derivative, triple_derivative])
            # learning phase 0 = test mode (dropout disabled)
            grads_output, grads_val, conv_first_grad, conv_second_grad, conv_third_grad = gradient_function([np.array([X]), 0])
            # drop the batch axis; grads_output holds the conv activations, shape (l, h, w, c)
            grads_output, grads_val, conv_first_grad, conv_second_grad, conv_third_grad = grads_output[0, :, :], grads_val[0, :, :, :], conv_first_grad[0, :, :, :], conv_second_grad[0, :, :, :], conv_third_grad[0, :, :, :]
            print grads_output.shape, np.max(grads_output), np.min(grads_output)
            print grads_val.shape, np.max(grads_val), np.min(grads_val)
            print conv_first_grad.shape,np.max(conv_first_grad), np.min(conv_first_grad)
            print conv_second_grad.shape,np.max(conv_second_grad), np.min(conv_second_grad)
            print conv_third_grad.shape,np.max(conv_third_grad), np.min(conv_third_grad)

            ############## FOR GRAD-CAM #########################################

            # one weight per channel: the gradient averaged over (l, h, w)
            weights = np.mean(grads_val, axis = (0, 1, 2))
            print weights.shape

            cam = np.sum(weights*grads_output, axis=3)
            print cam.shape
            print np.max(cam),np.min(cam)

            cam = np.maximum(cam, 0)
            # (8, 28, 28) conv3a map -> (16, 112, 112) clip size
            cam = scipy.ndimage.zoom(cam, (2, 4, 4))
            heatmap = cam / np.max(cam)
            print np.max(heatmap),np.min(heatmap)
            print heatmap.shape

            vid_mod = X*heatmap.reshape((16,112,112,1))
            print vid_mod.shape
            output_mod = model.predict_on_batch(np.array([vid_mod]))

            predicted_class_mod = output_mod[0].argsort()[::-1][0]
            print output_mod[0][predicted_class_mod], labels[predicted_class_mod]
            print output_mod[0][predicted_class], labels[predicted_class]

            ################SAVE THE VIDEO AS FRAMES###############
            # note: frames are overwritten by the next video

            for i in range(heatmap.shape[0]):
                cam_mod = heatmap[i].reshape((112,112,1))

                gd_img_mod = X[i]*cam_mod

                gd_img_mod = cv2.resize(gd_img_mod, (640,480))
                cv2.imwrite("image-%05d.jpg" %i, gd_img_mod)

            ####################### GRAD-CAM UPTO THIS ############################

            ############## FOR GradCAM++ #################################
            # per-channel sum of the activations A over all positions, as in the
            # alpha denominator 2*d2 + sum_ab(A_ab)*d3 of Grad-CAM++
            global_sum = np.sum(grads_output.reshape((-1,256)), axis=0)
            #print global_sum

            alpha_num = conv_second_grad
            alpha_denom = conv_second_grad*2.0 + conv_third_grad*global_sum.reshape((-1,))
            alpha_denom = np.where(alpha_denom != 0.0, alpha_denom, np.ones(alpha_denom.shape))
            alphas = alpha_num/alpha_denom

            weights = np.maximum(conv_first_grad, 0.0)
            #normalizing the alphas
            alphas_thresholding = np.where(weights, alphas, 0.0)

            alpha_normalization_constant = np.sum(np.sum(np.sum(alphas_thresholding, axis=0),axis=0),axis=0)
            alpha_normalization_constant_processed = np.where(alpha_normalization_constant != 0.0, alpha_normalization_constant, np.ones(alpha_normalization_constant.shape))

            alphas /= alpha_normalization_constant_processed.reshape((1,1,1,256))
            #print alphas

            deep_linearization_weights = np.sum((weights*alphas).reshape((-1,256)),axis=0)
            #print deep_linearization_weights
            grad_CAM_map = np.sum(deep_linearization_weights*grads_output, axis=3)
            print np.max(grad_CAM_map),np.min(grad_CAM_map)

            grad_CAM_map = scipy.ndimage.zoom(grad_CAM_map, (2, 4, 4))
            print np.max(grad_CAM_map),np.min(grad_CAM_map)
            # Passing through ReLU
            vid_cam = np.maximum(grad_CAM_map, 0)
            vid_heatmap = vid_cam / np.max(vid_cam) # scale 0 to 1.0
            print vid_heatmap.shape

            vid_mod_plus = X*vid_heatmap.reshape((16,112,112,1))
            print vid_mod_plus.shape
            output_mod_plus = model.predict_on_batch(np.array([vid_mod_plus]))
            predicted_class_mod_plus = output_mod_plus[0].argsort()[::-1][0]
            print output_mod_plus[0][predicted_class_mod_plus], labels[predicted_class_mod_plus]
            print output_mod_plus[0][predicted_class], labels[predicted_class]

            ################SAVE THE VIDEO AS FRAMES###############

            for i in range(vid_heatmap.shape[0]):
                vid_cam_mod = vid_heatmap[i].reshape((112,112,1))

                vid_gd_img_mod = X[i]*vid_cam_mod
                vid_gd_img_mod = cv2.resize(vid_gd_img_mod, (640,480))
                cv2.imwrite(os.path.join("./output", "image-%05d.jpg" %i), vid_gd_img_mod)

                X_mod = cv2.resize(X[i], (640,480))
                cv2.imwrite("original-image-%05d.jpg" %i, X_mod)

            #############GRAD-CAM++ UPTO THIS ####################################

            #############Write the scores into a file#############################
            scores_file.write(str(output[0][predicted_class]) + " " + str(output_mod[0][predicted_class]) + " " + str(output_mod_plus[0][predicted_class]) + "\n")

        except Exception as e:
            print "[Warning] skipping {}: {}".format(filename, e)
            continue

    scores_file.close()

    # sort top five predictions from softmax output (of the last processed video)
    top_inds = output[0].argsort()[::-1][:5] # reverse sort and take five largest items
    print('\nTop 5 probabilities and labels:')
    for i in top_inds:
        print('{1}: {0:.5f}'.format(output[0][i], labels[i]))

if __name__ == '__main__':
    main()

--------------------------------------------------------------------------------
/classify.py:
--------------------------------------------------------------------------------
import argparse
import os

import misc.utils as utils


def run(image_filename, class_label, output_filename):
    # computes the Grad-CAM++ map and saves the visualization (see misc/utils.py)
    utils.grad_CAM_plus(image_filename, class_label, output_filename)

def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('-l', '--class_label', default=-1, type=int, help='if -1 (default) choose predicted class, else user supplied int')
    parser.add_argument('-gpu', '--gpu_device', default="0", type=str, help='GPU id to use, 0-indexed (default: 0)')
    parser.add_argument('-f', '--file_name', type=str, required=True, help="Specify image path for Grad-CAM++ visualization")
    parser.add_argument('-o', '--output_filename', default="output.jpeg", type=str, help="Specify output file name for Grad-CAM++ visualization, default name 'output.jpeg'")
    args = parser.parse_args()
    gpu_id = args.gpu_device
    print "[Info] CUDA_VISIBLE_DEVICES={}".format(gpu_id)
    # must be set before the first TensorFlow session is created
    os.environ["CUDA_VISIBLE_DEVICES"]=gpu_id
    run(args.file_name, args.class_label, args.output_filename)

if __name__ == '__main__':
    main()

--------------------------------------------------------------------------------
/images/architecture.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/adityac94/Grad_CAM_plus_plus/4a9faf6ac61ef0c56e19b88d8560b81cd62c5017/images/architecture.png
--------------------------------------------------------------------------------
/images/collies.JPG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/adityac94/Grad_CAM_plus_plus/4a9faf6ac61ef0c56e19b88d8560b81cd62c5017/images/collies.JPG
--------------------------------------------------------------------------------
/images/grad-cam-faults.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/adityac94/Grad_CAM_plus_plus/4a9faf6ac61ef0c56e19b88d8560b81cd62c5017/images/grad-cam-faults.png
--------------------------------------------------------------------------------
/images/grad-campp.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/adityac94/Grad_CAM_plus_plus/4a9faf6ac61ef0c56e19b88d8560b81cd62c5017/images/grad-campp.png
--------------------------------------------------------------------------------
/images/multiple_dogs.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/adityac94/Grad_CAM_plus_plus/4a9faf6ac61ef0c56e19b88d8560b81cd62c5017/images/multiple_dogs.jpg
--------------------------------------------------------------------------------
/images/snake.JPEG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/adityac94/Grad_CAM_plus_plus/4a9faf6ac61ef0c56e19b88d8560b81cd62c5017/images/snake.JPEG
--------------------------------------------------------------------------------
/images/water-bird.JPEG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/adityac94/Grad_CAM_plus_plus/4a9faf6ac61ef0c56e19b88d8560b81cd62c5017/images/water-bird.JPEG
--------------------------------------------------------------------------------
/misc/GuideReLU.py:
--------------------------------------------------------------------------------
# Replace vanilla ReLU with guided ReLU to get guided backpropagation.
import tensorflow as tf

from tensorflow.python.framework import ops
from tensorflow.python.ops import gen_nn_ops

@ops.RegisterGradient("GuidedRelu")
def _GuidedReluGrad(op, grad):
    # pass the gradient back only where both the incoming gradient
    # and the forward ReLU output are positive
    return tf.where(0. < grad, gen_nn_ops._relu_grad(grad, op.outputs[0]), tf.zeros_like(grad))

--------------------------------------------------------------------------------
/misc/GuideReLU.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/adityac94/Grad_CAM_plus_plus/4a9faf6ac61ef0c56e19b88d8560b81cd62c5017/misc/GuideReLU.pyc
--------------------------------------------------------------------------------
/misc/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/adityac94/Grad_CAM_plus_plus/4a9faf6ac61ef0c56e19b88d8560b81cd62c5017/misc/__init__.py
--------------------------------------------------------------------------------
/misc/__init__.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/adityac94/Grad_CAM_plus_plus/4a9faf6ac61ef0c56e19b88d8560b81cd62c5017/misc/__init__.pyc
--------------------------------------------------------------------------------
/misc/utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import tensorflow as tf
3 | import GuideReLU as GReLU  # imported for its side effect: registers the "GuidedRelu" gradient
4 | import models.vgg16 as vgg16
5 | import models.vgg_utils as vgg_utils
6 | from skimage.transform import resize
7 | import matplotlib.pyplot as plt
8 | import os, cv2
9 | from tensorflow.python.framework import graph_util
10 |
11 | def guided_BP(image, label_id = -1):
12 |     g = tf.get_default_graph()
13 |     with g.gradient_override_map({'Relu': 'GuidedRelu'}):
14 |         label_vector = tf.placeholder("float", [None, 1000])
15 |         input_image = tf.placeholder("float", [None, 224, 224, 3])
16 |
17 |         vgg = vgg16.Vgg16()
18 |         with tf.name_scope("content_vgg"):
19 |             vgg.build(input_image)
20 |
21 |         cost = vgg.fc8*label_vector
22 |
23 |         # Guided backpropagation back to input layer
24 |         gb_grad = tf.gradients(cost, input_image)[0]
25 |
26 |         init = tf.global_variables_initializer()
27 |
28 |     # Run tensorflow
29 |     with tf.Session(graph=g) as sess:
30 |         sess.run(init)
31 |         output = [0.0]*vgg.prob.get_shape().as_list()[1] #one-hot embedding for desired class activations
32 |         if label_id == -1:
33 |             prob = sess.run(vgg.prob, feed_dict={input_image:image})
34 |
35 |             vgg_utils.print_prob(prob[0], './synset.txt')
36 |
37 |             #creating the output vector for the respective class
38 |             index = np.argmax(prob)
39 |             print "Predicted_class: ", index
40 |             output[index] = 1.0
41 |
42 |         else:
43 |             output[label_id] = 1.0
44 |         output = np.array(output)
45 |         gb_grad_value = sess.run(gb_grad, feed_dict={input_image:image, label_vector: output.reshape((1,-1))})
46 |
47 |     return gb_grad_value[0]
48 |
49 | def grad_CAM_plus(filename, label_id, output_filename):
50 |     g = tf.get_default_graph()
51 |     init = tf.global_variables_initializer()
52 |
53 |     # Run tensorflow
54 |     sess = tf.Session()
55 |
56 |     # define the tensor placeholders for labels and images
57 |     label_vector = tf.placeholder("float", [None, 1000])
58 |     input_image = tf.placeholder("float", [1, 224, 224, 3])
59 |     label_index = tf.placeholder("int64", ())
60 |
61 |     #load vgg16 model
62 |     vgg = vgg16.Vgg16()
63 |     with tf.name_scope("content_vgg"):
64 |         vgg.build(input_image)
65 |     #prob = tf.placeholder("float", [None, 1000])
66 |
67 |     #get the output neuron corresponding to the class of interest (label_id)
68 |     cost = vgg.fc8*label_vector
69 |
70 |     # Get last convolutional layer gradients for generating gradCAM++ visualization
71 |     target_conv_layer = vgg.conv5_3
72 |     target_conv_layer_grad = tf.gradients(cost, target_conv_layer)[0]
73 |
74 |     #first_derivative
75 |     first_derivative = tf.exp(cost)[0][label_index]*target_conv_layer_grad
76 |
77 |     #second_derivative
78 |     second_derivative = tf.exp(cost)[0][label_index]*target_conv_layer_grad*target_conv_layer_grad
79 |
80 |     #triple_derivative
81 |     triple_derivative = tf.exp(cost)[0][label_index]*target_conv_layer_grad*target_conv_layer_grad*target_conv_layer_grad
82 |
83 |     sess.run(init)
84 |
85 |     img1 = vgg_utils.load_image(filename)
86 |
87 |     output = [0.0]*vgg.prob.get_shape().as_list()[1] #one-hot embedding for desired class activations
88 |     #creating the output vector for the respective class
89 |
90 |     if label_id == -1:
91 |         prob_val = sess.run(vgg.prob, feed_dict={input_image:[img1]})
92 |         vgg_utils.print_prob(prob_val[0], './synset.txt')
93 |         #creating the output vector for the respective class
94 |         index = np.argmax(prob_val)
95 |         orig_score = prob_val[0][index]
96 |         print "Predicted_class: ", index
97 |         output[index] = 1.0
98 |         label_id = index
99 |     else:
100 |         output[label_id] = 1.0
101 |     output = np.array(output)
102 |     print label_id
103 |     conv_output, conv_first_grad, conv_second_grad, conv_third_grad = sess.run([target_conv_layer, first_derivative, second_derivative, triple_derivative], feed_dict={input_image:[img1], label_index:label_id, label_vector: output.reshape((1,-1))})
104 |
105 |     global_sum = np.sum(conv_output[0].reshape((-1,conv_first_grad[0].shape[2])), axis=0)
106 |
107 |     alpha_num = conv_second_grad[0]
108 |     alpha_denom = conv_second_grad[0]*2.0 + conv_third_grad[0]*global_sum.reshape((1,1,conv_first_grad[0].shape[2]))
109 |     alpha_denom = np.where(alpha_denom != 0.0, alpha_denom, np.ones(alpha_denom.shape))
110 |     alphas = alpha_num/alpha_denom
111 |
112 |     weights = np.maximum(conv_first_grad[0], 0.0)
113 |     #normalizing the alphas
114 |     """
115 |     alpha_normalization_constant = np.sum(np.sum(alphas, axis=0),axis=0)
116 |
117 |     alphas /= alpha_normalization_constant.reshape((1,1,conv_first_grad[0].shape[2]))
118 |     """
119 |
120 |     alphas_thresholding = np.where(weights, alphas, 0.0)
121 |
122 |     alpha_normalization_constant = np.sum(np.sum(alphas_thresholding, axis=0),axis=0)
123 |     alpha_normalization_constant_processed = np.where(alpha_normalization_constant != 0.0, alpha_normalization_constant, np.ones(alpha_normalization_constant.shape))
124 |
125 |
126 |     alphas /= alpha_normalization_constant_processed.reshape((1,1,conv_first_grad[0].shape[2]))
127 |
128 |
129 |
130 |     deep_linearization_weights = np.sum((weights*alphas).reshape((-1,conv_first_grad[0].shape[2])),axis=0)
131 |     #print deep_linearization_weights
132 |     grad_CAM_map = np.sum(deep_linearization_weights*conv_output[0], axis=2)
133 |
134 |     # Passing through ReLU
135 |     cam = np.maximum(grad_CAM_map, 0)
136 |     cam = cam / np.max(cam) # scale 0 to 1.0
137 |     cam = resize(cam, (224,224))
138 |
139 |
140 |     gb = guided_BP([img1], label_id)
141 |     visualize(img1, cam, output_filename, gb)
142 |     return cam
143 |
144 | def visualize(img, cam, filename,gb_viz):
145 |     gb_viz = np.dstack((
146 |         gb_viz[:, :, 2],
147 |         gb_viz[:, :, 1],
148 |         gb_viz[:, :, 0],
149 |     ))
150 |
151 |     gb_viz -= np.min(gb_viz)
152 |     gb_viz /= gb_viz.max()
153 |
154 |     fig, ax = plt.subplots(nrows=1,ncols=4)
155 |
156 |     plt.subplot(141)
157 |     plt.axis("off")
158 |     imgplot = plt.imshow(img)
159 |
160 |     plt.subplot(142)
161 |     gd_img = gb_viz*np.minimum(0.25,cam).reshape(224,224,1)
162 |     x = gd_img
163 |     x = np.squeeze(x)
164 |
165 |     #normalize tensor
166 |     x -= x.mean()
167 |     x /= (x.std() + 1e-5)
168 |     x *= 0.1
169 |
170 |     # clip to [0, 1]
171 |     x += 0.5
172 |     x = np.clip(x, 0, 1)
173 |
174 |     # convert to RGB array
175 |     x *= 255
176 |
177 |     x = np.clip(x, 0, 255).astype('uint8')
178 |
179 |     plt.axis("off")
180 |     imgplot = plt.imshow(x, vmin = 0, vmax = 20)
181 |
182 |     cam = (cam*-1.0) + 1.0
183 |     cam_heatmap = np.array(cv2.applyColorMap(np.uint8(255*cam), cv2.COLORMAP_JET))
184 |     plt.subplot(143)
185 |     plt.axis("off")
186 |
187 |     imgplot = plt.imshow(cam_heatmap)
188 |
189 |     plt.subplot(144)
190 |     plt.axis("off")
191 |
192 |     cam_heatmap = cam_heatmap/255.0
193 |
194 |     fin = (img*0.7) + (cam_heatmap*0.3)
195 |     imgplot = plt.imshow(fin)
196 |
197 |     plt.savefig("output/" + filename, dpi=600)
198 |     plt.close(fig)
199 |
200 |
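`grad_CAM_plus` passes the class score S (the `fc8` logit selected by `label_vector`) through an exponential. Assuming, as the code does, that higher-order derivatives of S itself vanish for a piecewise-linear (ReLU) network, the first, second and third derivatives of exp(S) with respect to a `conv5_3` activation reduce to exp(S) times g, g² and g³, where g = dS/dA.

From these, the code computes:

- the per-position coefficients alpha = g² / (2g² + g³ · Σ_ab A_ab);
- a per-channel renormalization of the alphas over positions with positive gradient;
- the channel weight w_k = Σ_ij alpha_ij · ReLU(g_ij).

The NumPy sketch below reproduces that post-processing on plain arrays. The function name `grad_cam_pp_map` is hypothetical. The sketch drops the exp(S) factor: it cancels in alpha and only rescales the final map, which is then divided by its maximum anyway.

```python
import numpy as np


def grad_cam_pp_map(activations, grads):
    """Grad-CAM++ map for one image from a conv layer (hypothetical NumPy helper).

    activations: (H, W, K) feature maps A^k, e.g. conv5_3 of VGG16.
    grads:       (H, W, K) gradient dS/dA^k of the pre-softmax class score S.
    Returns an (H, W) map scaled to [0, 1].
    """
    A = np.asarray(activations, dtype=float)
    g = np.asarray(grads, dtype=float)
    K = A.shape[2]

    # Closed-form alphas: g^2 / (2 g^2 + g^3 * sum_ab A_ab), guarding zero denominators.
    global_sum = A.reshape(-1, K).sum(axis=0)
    alpha_num = g ** 2
    alpha_denom = 2.0 * g ** 2 + g ** 3 * global_sum.reshape(1, 1, K)
    alpha_denom = np.where(alpha_denom != 0.0, alpha_denom, 1.0)
    alphas = alpha_num / alpha_denom

    # Only positive gradients contribute (ReLU on the gradient).
    pos_grads = np.maximum(g, 0.0)

    # As in grad_CAM_plus: renormalize alphas per channel over positive-gradient positions.
    norm = np.where(pos_grads > 0, alphas, 0.0).sum(axis=(0, 1))
    alphas = alphas / np.where(norm != 0.0, norm, 1.0).reshape(1, 1, K)

    # Channel weights w_k, then ReLU of the weighted sum of feature maps.
    weights = (pos_grads * alphas).sum(axis=(0, 1))
    cam = np.maximum((A * weights.reshape(1, 1, K)).sum(axis=2), 0.0)
    peak = cam.max()
    return cam / peak if peak > 0 else cam


# Random stand-ins with the conv5_3 shape for a 224x224 input.
rng = np.random.default_rng(0)
cam = grad_cam_pp_map(rng.random((14, 14, 512)), rng.standard_normal((14, 14, 512)))
```

For VGG16 with a 224×224 input, `conv5_3` is 14×14×512. Passing `conv_output[0]` and the class-score gradient at that layer would give a 14×14 map, which still needs the `resize` to 224×224 that the repository performs.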
--------------------------------------------------------------------------------
/misc/utils.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/adityac94/Grad_CAM_plus_plus/4a9faf6ac61ef0c56e19b88d8560b81cd62c5017/misc/utils.pyc
--------------------------------------------------------------------------------
/misc/utils.py~:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import tensorflow as tf
3 | import GuideReLU as GReLU
4 | import models.vgg16 as vgg16
5 | import models.vgg_utils as vgg_utils
6 | from skimage.transform import resize
7 | import matplotlib.pyplot as plt
8 | import os
9 | from tensorflow.python.framework import graph_util
10 |
11 | def guided_BP(image, label_id = -1):
12 |     g = tf.get_default_graph()
13 |     with g.gradient_override_map({'Relu': 'GuidedRelu'}):
14 |         label_vector = tf.placeholder("float", [None, 1000])
15 |         input_image = tf.placeholder("float", [None, 224, 224, 3])
16 |
17 |         vgg = vgg16.Vgg16()
18 |         with tf.name_scope("content_vgg"):
19 |             vgg.build(input_image)
20 |
21 |         cost = vgg.fc8*label_vector
22 |
23 |         # Guided backpropagation back to input layer
24 |         gb_grad = tf.gradients(cost, input_image)[0]
25 |
26 |         init = tf.global_variables_initializer()
27 |
28 |     # Run tensorflow
29 |     with tf.Session(graph=g) as sess:
30 |         sess.run(init)
31 |         output = [0.0]*vgg.prob.get_shape().as_list()[1] #one-hot embedding for desired class activations
32 |         if label_id == -1:
33 |             prob = sess.run(vgg.prob, feed_dict={input_image:image})
34 |
35 |             vgg_utils.print_prob(prob[0], './synset.txt')
36 |
37 |             #creating the output vector for the respective class
38 |             index = np.argmax(prob)
39 |             print "Predicted_class: ", index
40 |             output[index] = 1.0
41 |
42 |         else:
43 |             output[label_id] = 1.0
44 |         output = np.array(output)
45 |         gb_grad_value = sess.run(gb_grad, feed_dict={input_image:image, label_vector: output.reshape((1,-1))})
46 |
47 |     return gb_grad_value[0]
48 |
49 | def grad_CAM(image, label_id = -1):
50 |     g = tf.get_default_graph()
51 |     init = tf.global_variables_initializer()
52 |     # Run tensorflow
53 |     sess = tf.Session()
54 |
55 |     label_vector = tf.placeholder("float", [None, 1000])
56 |     input_image = tf.placeholder("float", [1, 224, 224, 3])
57 |     label_index = tf.placeholder("int64", ())
58 |
59 |     vgg = vgg16.Vgg16()
60 |     with tf.name_scope("content_vgg"):
61 |         vgg.build(input_image)
62 |     #prob = tf.placeholder("float", [None, 1000])
63 |
64 |     cost = vgg.prob*label_vector
65 |
66 |     # Get last convolutional layer gradients for generating gradCAM visualization
67 |     target_conv_layer = vgg.conv5_3
68 |     target_conv_layer_grad = tf.gradients(cost, target_conv_layer)[0]
69 |
70 |     #second_derivative
71 |     temp_1 = target_conv_layer_grad*(target_conv_layer_grad/vgg.prob[0][label_index])
72 |
73 |     #Calculating the gradient of every node in the softmax layer with the feature map
74 |     total = tf.zeros_like(target_conv_layer)
75 |     S_double_derivatives_j = []
76 |     y_gradients_list = []
77 |     for j in range(1000):
78 |         j_output = np.zeros([1000,])
79 |         j_output[j] = 1.0
80 |         output_tensor_j = tf.constant(j_output, dtype=tf.float32)
81 |
82 |         y_gradients_j = tf.gradients(vgg.fc8*output_tensor_j, target_conv_layer)[0]
83 |         S_gradients_j = tf.gradients(vgg.prob*output_tensor_j, target_conv_layer)[0]
84 |
85 |         y_gradients_list.append(y_gradients_j)
86 |
87 |         total = total + S_gradients_j*y_gradients_j
88 |
89 |         prob_output_j = vgg.prob[0][j]
90 |         temp_1_mini = S_gradients_j*(S_gradients_j/prob_output_j)
91 |
92 |         S_double_derivatives_j.append(temp_1_mini)
93 |
94 |     S_double_derivatives_j_part_one = tf.stack(S_double_derivatives_j)
95 |     temp_2 = total*vgg.prob[0][label_index]
96 |     second_derivative = temp_1 - temp_2
97 |
98 |     #third_derivative
99 |     double_temp_1 = second_derivative*(target_conv_layer_grad/vgg.prob[0][label_index])
100 |     double_temp_2 = total*target_conv_layer_grad
101 |
102 |     temp_array = []
103 |     for j in range(1000):
104 |         prob_output_j = vgg.prob[0][j]
105 |         temp_2_mini = total*prob_output_j
106 |         temp_array.append(temp_2_mini)
107 |
108 |     S_double_derivatives_j_part_two = tf.stack(temp_array)
109 |     S_double_derivatives = S_double_derivatives_j_part_one - S_double_derivatives_j_part_two
110 |
111 |     total_2 = tf.reduce_sum(S_double_derivatives*tf.stack(y_gradients_list),axis=0)
112 |
113 |     double_temp_3 = total_2*vgg.prob[0][label_index]
114 |
115 |     triple_derivative = double_temp_1 - double_temp_2*2.0 - double_temp_3
116 |     sess.run(init)
117 |     f = open("scores.txt","w")
118 |     for filename in sorted(os.listdir("../../imagenet_val/")):
119 |         print filename
120 |
121 |         img1 = vgg_utils.load_image("../../imagenet_val/" + filename)
122 |
123 |         output = [0.0]*vgg.prob.get_shape().as_list()[1] #one-hot embedding for desired class activations
124 |         #creating the output vector for the respective class
125 |
126 |         try:
127 |             prob_val = sess.run(vgg.prob, feed_dict={input_image:[img1]})
128 |         except:
129 |             print "oops"
130 |             continue
131 |
132 |         vgg_utils.print_prob(prob_val[0], './synset.txt')
133 |
134 |         #creating the output vector for the respective class
135 |         index = np.argmax(prob_val)
136 |         orig_score = prob_val[0][index]
137 |         print "Predicted_class: ", index
138 |         output[index] = 1.0
139 |         label_id = index
140 |         output = np.array(output)
141 |
142 |
143 |         conv_output, conv_first_grad, conv_second_grad, conv_third_grad = sess.run([target_conv_layer, target_conv_layer_grad, second_derivative, triple_derivative], feed_dict={input_image:[img1], label_index:label_id, label_vector: output.reshape((1,-1))})
144 |
145 |         global_sum = np.sum(conv_output[0].reshape((-1,512)), axis=0)
146 |
147 |         alpha_num = conv_second_grad[0]
148 |         alpha_denom = conv_second_grad[0]*2.0 + conv_third_grad[0]*global_sum.reshape((1,1,512))
149 |         alpha_denom = np.where(alpha_denom != 0.0, alpha_denom, np.ones(alpha_denom.shape))
150 |         alphas = alpha_num/alpha_denom
151 |
152 |         weights = np.maximum(conv_first_grad[0], 0.0)
153 |         #normalizing the alphas
154 |
155 |         alpha_normalization_constant = np.sum(np.sum(alphas, axis=0),axis=0)
156 |
157 |         alphas /= alpha_normalization_constant.reshape((1,1,512))
158 |
159 |
160 |         deep_linearization_weights = np.sum((weights*alphas).reshape((-1,512)),axis=0)
161 |         #print deep_linearization_weights
162 |         grad_CAM_map = np.sum(deep_linearization_weights*conv_output[0], axis=2)
163 |
164 |         # Passing through ReLU
165 |         cam = np.maximum(grad_CAM_map, 0)
166 |         cam = cam / np.max(cam) # scale 0 to 1.0
167 |         cam = resize(cam, (224,224))
168 |         visualize(img1, cam, filename)
169 |
170 |
171 |         #keeping the important saliency parts in original image
172 |         gd_img = img1*cam.reshape((224,224,1))
173 |         gd_img -= np.min(gd_img)
174 |         gd_img /= gd_img.max()
175 |
176 |         new_prob = sess.run(vgg.prob, feed_dict={input_image:[gd_img]})
177 |         f.write(str(orig_score) + " " + str(new_prob[0][index]) + "\n")
178 |
179 |     # print(cam)
180 |     return cam
181 |
182 | def visualize(img, cam, filename):
183 |     """gb_viz = np.dstack((
184 |         gb_viz[:, :, 2],
185 |         gb_viz[:, :, 1],
186 |         gb_viz[:, :, 0],
187 |     ))
188 |
189 |     gb_viz -= np.min(gb_viz)
190 |     gb_viz /= gb_viz.max()"""
191 |
192 |     fig, ax = plt.subplots(nrows=1,ncols=2)
193 |
194 |     plt.subplot(121)
195 |     imgplot = plt.imshow(img)
196 |
197 |
198 |     plt.subplot(122)
199 |     gd_img = img*cam.reshape((224,224,1))
200 |     gd_img -= np.min(gd_img)
201 |     gd_img /= gd_img.max()
202 |     imgplot = plt.imshow(gd_img)
203 |
204 |
205 |     """gd_gb = gb_viz*cam.reshape((224,224,1))
206 |     gd_gb -= np.min(gd_gb)
207 |     gd_gb /= gd_gb.max()
208 |
209 |     plt.subplot(224)
210 |     imgplot = plt.imshow(gd_gb)"""
211 |
212 |     plt.savefig("output/" + filename + ".png")
213 |     plt.close(fig)
214 |
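The loop in this backup version writes one line per validation image to `scores.txt`. Each line holds two probabilities for the predicted class: on the original image, and on the input after it is multiplied by the CAM.

The sketch below shows one way to aggregate those pairs. It assumes the metrics of interest are the paper's Average Drop % and % Increase in Confidence, as we understand them:

- Average Drop %: the mean of max(0, Y − O) / Y;
- % Increase in Confidence: the share of images with O > Y.

Here Y is the original score and O the masked score. `confidence_metrics` is a hypothetical helper, not part of the repository.

```python
def confidence_metrics(lines):
    """Aggregate 'orig_score new_score' lines as written to scores.txt (hypothetical helper).

    Y = original class probability, O = probability of the same class on the
    CAM-masked image. Returns (average_drop_pct, increase_in_confidence_pct).
    """
    pairs = []
    for line in lines:
        if line.strip():
            y, o = (float(v) for v in line.split())
            pairs.append((y, o))
    if not pairs:
        raise ValueError("no score pairs found")
    n = float(len(pairs))
    # Relative drop in confidence, counting only images where the score fell.
    average_drop = 100.0 * sum(max(0.0, y - o) / y for y, o in pairs) / n
    # Share of images whose score increased after masking.
    increase = 100.0 * sum(1 for y, o in pairs if o > y) / n
    return average_drop, increase


# Usage with the file produced by grad_CAM:
# with open("scores.txt") as f:
#     print(confidence_metrics(f))
```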
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/models/__init__.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/adityac94/Grad_CAM_plus_plus/4a9faf6ac61ef0c56e19b88d8560b81cd62c5017/models/__init__.pyc
--------------------------------------------------------------------------------
/models/vgg16.py:
--------------------------------------------------------------------------------
1 | import inspect
2 | import os
3 |
4 | import numpy as np
5 | import tensorflow as tf
6 | import time
7 |
8 | VGG_MEAN = [103.939, 116.779, 123.68]
9 |
10 |
11 | class Vgg16:
12 |     def __init__(self, vgg16_npy_path=None):
13 |         if vgg16_npy_path is None:
14 |             path = inspect.getfile(Vgg16)
15 |             path = os.path.abspath(os.path.join(path, os.pardir))
16 |             path = os.path.join(path, "vgg16.npy")
17 |             vgg16_npy_path = path
18 |             print(path)
19 |
20 |         self.data_dict = np.load(vgg16_npy_path, encoding='latin1', allow_pickle=True).item()
21 |         print("npy file loaded")
22 |
23 |     def build(self, rgb):
24 |         """
25 |         load variable from npy to build the VGG
26 |
27 |         :param rgb: rgb image [batch, height, width, 3] values scaled [0, 1]
28 |         """
29 |
30 |         start_time = time.time()
31 |         print("build model started")
32 |         rgb_scaled = rgb * 255.0
33 |
34 |         # Convert RGB to BGR
35 |         red, green, blue = tf.split(axis=3, num_or_size_splits=3, value=rgb_scaled)
36 |         assert red.get_shape().as_list()[1:] == [224, 224, 1]
37 |         assert green.get_shape().as_list()[1:] == [224, 224, 1]
38 |         assert blue.get_shape().as_list()[1:] == [224, 224, 1]
39 |         bgr = tf.concat(axis=3, values=[
40 |             blue - VGG_MEAN[0],
41 |             green - VGG_MEAN[1],
42 |             red - VGG_MEAN[2],
43 |         ])
44 |         assert bgr.get_shape().as_list()[1:] == [224, 224, 3]
45 |
46 |         self.conv1_1 = self.conv_layer(bgr, "conv1_1")
47 |         self.conv1_2 = self.conv_layer(self.conv1_1, "conv1_2")
48 |         self.pool1 = self.max_pool(self.conv1_2, 'pool1')
49 |
50 |         self.conv2_1 = self.conv_layer(self.pool1, "conv2_1")
51 |         self.conv2_2 = self.conv_layer(self.conv2_1, "conv2_2")
52 |         self.pool2 = self.max_pool(self.conv2_2, 'pool2')
53 |
54 |         self.conv3_1 = self.conv_layer(self.pool2, "conv3_1")
55 |         self.conv3_2 = self.conv_layer(self.conv3_1, "conv3_2")
56 |         self.conv3_3 = self.conv_layer(self.conv3_2, "conv3_3")
57 |         self.pool3 = self.max_pool(self.conv3_3, 'pool3')
58 |
59 |         self.conv4_1 = self.conv_layer(self.pool3, "conv4_1")
60 |         self.conv4_2 = self.conv_layer(self.conv4_1, "conv4_2")
61 |         self.conv4_3 = self.conv_layer(self.conv4_2, "conv4_3")
62 |         self.pool4 = self.max_pool(self.conv4_3, 'pool4')
63 |
64 |         self.conv5_1 = self.conv_layer(self.pool4, "conv5_1")
65 |         self.conv5_2 = self.conv_layer(self.conv5_1, "conv5_2")
66 |         self.conv5_3 = self.conv_layer(self.conv5_2, "conv5_3")
67 |         self.pool5 = self.max_pool(self.conv5_3, 'pool5')
68 |
69 |         self.fc6 = self.fc_layer(self.pool5, "fc6")
70 |         assert self.fc6.get_shape().as_list()[1:] == [4096]
71 |         self.relu6 = tf.nn.relu(self.fc6)
72 |
73 |         self.fc7 = self.fc_layer(self.relu6, "fc7")
74 |         self.relu7 = tf.nn.relu(self.fc7)
75 |
76 |         self.fc8 = self.fc_layer(self.relu7, "fc8")
77 |
78 |         self.prob = tf.nn.softmax(self.fc8, name="prob")
79 |
80 |         self.data_dict = None
81 |         print(("build model finished: %ds" % (time.time() - start_time)))
82 |
83 |     def avg_pool(self, bottom, name):
84 |         return tf.nn.avg_pool(bottom, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME', name=name)
85 |
86 |     def max_pool(self, bottom, name):
87 |         return tf.nn.max_pool(bottom, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME', name=name)
88 |
89 |     def conv_layer(self, bottom, name):
90 |         with tf.variable_scope(name):
91 |             filt = self.get_conv_filter(name)
92 |
93 |             conv = tf.nn.conv2d(bottom, filt, [1, 1, 1, 1], padding='SAME')
94 |
95 |             conv_biases = self.get_bias(name)
96 |             bias = tf.nn.bias_add(conv, conv_biases)
97 |
98 |             relu = tf.nn.relu(bias)
99 |             return relu
100 |
101 |     def fc_layer(self, bottom, name):
102 |         with tf.variable_scope(name):
103 |             shape = bottom.get_shape().as_list()
104 |             dim = 1
105 |             for d in shape[1:]:
106 |                 dim *= d
107 |             x = tf.reshape(bottom, [-1, dim])
108 |
109 |             weights = self.get_fc_weight(name)
110 |             biases = self.get_bias(name)
111 |
112 |             # Fully connected layer. tf.nn.bias_add broadcasts the biases
113 |             # across the batch dimension.
114 |             fc = tf.nn.bias_add(tf.matmul(x, weights), biases)
115 |
116 |             return fc
117 |
118 |     def get_conv_filter(self, name):
119 |         return tf.constant(self.data_dict[name][0], name="filter")
120 |
121 |     def get_bias(self, name):
122 |         return tf.constant(self.data_dict[name][1], name="biases")
123 |
124 |     def get_fc_weight(self, name):
125 |         return tf.constant(self.data_dict[name][0], name="weights")
126 |
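`build` expects RGB input scaled to [0, 1]. It rescales the input to [0, 255], reorders the channels to BGR and subtracts the per-channel means in `VGG_MEAN`, which are listed in B, G, R order. The NumPy equivalent below is useful for checking inputs outside the graph. The function name `vgg_preprocess` is hypothetical.

```python
import numpy as np

VGG_MEAN = [103.939, 116.779, 123.68]  # B, G, R means, copied from vgg16.py


def vgg_preprocess(rgb):
    """NumPy mirror of the preprocessing at the top of Vgg16.build (hypothetical helper).

    rgb: array of shape [..., 3], RGB values in [0, 1].
    Returns BGR values on the [0, 255] scale with the channel means subtracted.
    """
    scaled = np.asarray(rgb, dtype=np.float32) * 255.0
    bgr = scaled[..., ::-1]  # reverse the channel axis: RGB -> BGR
    return bgr - np.asarray(VGG_MEAN, dtype=np.float32)


batch = np.full((1, 224, 224, 3), 0.5, dtype=np.float32)
x = vgg_preprocess(batch)  # shape (1, 224, 224, 3)
```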
--------------------------------------------------------------------------------
/models/vgg16.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/adityac94/Grad_CAM_plus_plus/4a9faf6ac61ef0c56e19b88d8560b81cd62c5017/models/vgg16.pyc
--------------------------------------------------------------------------------
/models/vgg19.py:
--------------------------------------------------------------------------------
1 | import os
2 | import tensorflow as tf
3 |
4 | import numpy as np
5 | import time
6 | import inspect
7 |
8 | VGG_MEAN = [103.939, 116.779, 123.68]
9 |
10 |
11 | class Vgg19:
12 |     def __init__(self, vgg19_npy_path=None):
13 |         if vgg19_npy_path is None:
14 |             path = inspect.getfile(Vgg19)
15 |             path = os.path.abspath(os.path.join(path, os.pardir))
16 |             path = os.path.join(path, "vgg19.npy")
17 |             vgg19_npy_path = path
18 |             print(vgg19_npy_path)
19 |
20 |         self.data_dict = np.load(vgg19_npy_path, encoding='latin1', allow_pickle=True).item()
21 |         print("npy file loaded")
22 |
23 |     def build(self, rgb):
24 |         """
25 |         load variable from npy to build the VGG
26 |
27 |         :param rgb: rgb image [batch, height, width, 3] values scaled [0, 1]
28 |         """
29 |
30 |         start_time = time.time()
31 |         print("build model started")
32 |         rgb_scaled = rgb * 255.0
33 |
34 |         # Convert RGB to BGR
35 |         red, green, blue = tf.split(axis=3, num_or_size_splits=3, value=rgb_scaled)
36 |         assert red.get_shape().as_list()[1:] == [224, 224, 1]
37 |         assert green.get_shape().as_list()[1:] == [224, 224, 1]
38 |         assert blue.get_shape().as_list()[1:] == [224, 224, 1]
39 |         bgr = tf.concat(axis=3, values=[
40 |             blue - VGG_MEAN[0],
41 |             green - VGG_MEAN[1],
42 |             red - VGG_MEAN[2],
43 |         ])
44 |         assert bgr.get_shape().as_list()[1:] == [224, 224, 3]
45 |
46 |         self.conv1_1 = self.conv_layer(bgr, "conv1_1")
47 |         self.conv1_2 = self.conv_layer(self.conv1_1, "conv1_2")
48 |         self.pool1 = self.max_pool(self.conv1_2, 'pool1')
49 |
50 |         self.conv2_1 = self.conv_layer(self.pool1, "conv2_1")
51 |         self.conv2_2 = self.conv_layer(self.conv2_1, "conv2_2")
52 |         self.pool2 = self.max_pool(self.conv2_2, 'pool2')
53 |
54 |         self.conv3_1 = self.conv_layer(self.pool2, "conv3_1")
55 |         self.conv3_2 = self.conv_layer(self.conv3_1, "conv3_2")
56 |         self.conv3_3 = self.conv_layer(self.conv3_2, "conv3_3")
57 |         self.conv3_4 = self.conv_layer(self.conv3_3, "conv3_4")
58 |         self.pool3 = self.max_pool(self.conv3_4, 'pool3')
59 |
60 |         self.conv4_1 = self.conv_layer(self.pool3, "conv4_1")
61 |         self.conv4_2 = self.conv_layer(self.conv4_1, "conv4_2")
62 |         self.conv4_3 = self.conv_layer(self.conv4_2, "conv4_3")
63 |         self.conv4_4 = self.conv_layer(self.conv4_3, "conv4_4")
64 |         self.pool4 = self.max_pool(self.conv4_4, 'pool4')
65 |
66 |         self.conv5_1 = self.conv_layer(self.pool4, "conv5_1")
67 |         self.conv5_2 = self.conv_layer(self.conv5_1, "conv5_2")
68 |         self.conv5_3 = self.conv_layer(self.conv5_2, "conv5_3")
69 |         self.conv5_4 = self.conv_layer(self.conv5_3, "conv5_4")
70 |         self.pool5 = self.max_pool(self.conv5_4, 'pool5')
71 |
72 |         self.fc6 = self.fc_layer(self.pool5, "fc6")
73 |         assert self.fc6.get_shape().as_list()[1:] == [4096]
74 |         self.relu6 = tf.nn.relu(self.fc6)
75 |
76 |         self.fc7 = self.fc_layer(self.relu6, "fc7")
77 |         self.relu7 = tf.nn.relu(self.fc7)
78 |
79 |         self.fc8 = self.fc_layer(self.relu7, "fc8")
80 |
81 |         self.prob = tf.nn.softmax(self.fc8, name="prob")
82 |
83 |         self.data_dict = None
84 |         print(("build model finished: %ds" % (time.time() - start_time)))
85 |
86 |     def avg_pool(self, bottom, name):
87 |         return tf.nn.avg_pool(bottom, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME', name=name)
88 |
89 |     def max_pool(self, bottom, name):
90 |         return tf.nn.max_pool(bottom, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME', name=name)
91 |
92 |     def conv_layer(self, bottom, name):
93 |         with tf.variable_scope(name):
94 |             filt = self.get_conv_filter(name)
95 |
96 |             conv = tf.nn.conv2d(bottom, filt, [1, 1, 1, 1], padding='SAME')
97 |
98 |             conv_biases = self.get_bias(name)
99 |             bias = tf.nn.bias_add(conv, conv_biases)
100 |
101 |             relu = tf.nn.relu(bias)
102 |             return relu
103 |
104 |     def fc_layer(self, bottom, name):
105 |         with tf.variable_scope(name):
106 |             shape = bottom.get_shape().as_list()
107 |             dim = 1
108 |             for d in shape[1:]:
109 |                 dim *= d
110 |             x = tf.reshape(bottom, [-1, dim])
111 |
112 |             weights = self.get_fc_weight(name)
113 |             biases = self.get_bias(name)
114 |
115 |             # Fully connected layer. tf.nn.bias_add broadcasts the biases
116 |             # across the batch dimension.
117 |             fc = tf.nn.bias_add(tf.matmul(x, weights), biases)
118 |
119 |             return fc
120 |
121 |     def get_conv_filter(self, name):
122 |         return tf.constant(self.data_dict[name][0], name="filter")
123 |
124 |     def get_bias(self, name):
125 |         return tf.constant(self.data_dict[name][1], name="biases")
126 |
127 |     def get_fc_weight(self, name):
128 |         return tf.constant(self.data_dict[name][0], name="weights")
129 |
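Every `max_pool` here uses a 2×2 window, stride 2 and 'SAME' padding, so each pooling step maps a spatial size n to ceil(n / 2). The sketch below computes the resulting shapes for a 224×224 input. The helper `pooled_sizes` is hypothetical. The same arithmetic holds for `Vgg16`, which has the same five pooling stages.

```python
import math


def pooled_sizes(size=224, num_pools=5, stride=2):
    """Spatial size after each 2x2, stride-2, 'SAME'-padded pool (hypothetical helper).

    With 'SAME' padding the output size is ceil(input / stride).
    """
    sizes = []
    for _ in range(num_pools):
        size = int(math.ceil(size / float(stride)))
        sizes.append(size)
    return sizes


sizes = pooled_sizes()                # [112, 56, 28, 14, 7]
fc6_input_dim = sizes[-1] ** 2 * 512  # 25088 values flattened by fc_layer("fc6")
```

The 7×7×512 output of `pool5` is what `fc_layer` reshapes into a 25088-dimensional vector before `fc6`. The layer used by `grad_CAM_plus`, `conv5_3`, sits before `pool5`, so its feature maps are 14×14.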
--------------------------------------------------------------------------------
/models/vgg_utils.py:
--------------------------------------------------------------------------------
1 | import skimage
2 | import skimage.io
3 | import skimage.transform
4 | import numpy as np
5 |
6 |
7 | # synset = [l.strip() for l in open('synset.txt').readlines()]
8 |
9 |
10 | # returns image of shape [224, 224, 3]
11 | # [height, width, depth]
12 | def load_image(path):
13 |     # load image
14 |     img = skimage.io.imread(path)
15 |     img = img / 255.0
16 |     assert (0 <= img).all() and (img <= 1.0).all()
17 |     # print "Original Image Shape: ", img.shape
18 |     # we crop image from center
19 |     short_edge = min(img.shape[:2])
20 |     yy = int((img.shape[0] - short_edge) / 2)
21 |     xx = int((img.shape[1] - short_edge) / 2)
22 |     crop_img = img[yy: yy + short_edge, xx: xx + short_edge]
23 |     # resize to 224, 224
24 |     resized_img = skimage.transform.resize(crop_img, (224, 224))
25 |     return resized_img
26 |
27 |
28 | # returns the top1 string
29 | def print_prob(prob, file_path):
30 |     synset = [l.strip() for l in open(file_path).readlines()]
31 |
32 |     # print prob
33 |     pred = np.argsort(prob)[::-1]
34 |
35 |     # Get top1 label
36 |     top1 = synset[pred[0]]
37 |     print(("Top1: ", top1, prob[pred[0]]))
38 |     # Get top5 label
39 |     top5 = [(synset[pred[i]], prob[pred[i]]) for i in range(5)]
40 |     print(("Top5: ", top5))
41 |     return top1
42 |
43 |
44 | def load_image2(path, height=None, width=None):
45 |     # load image
46 |     img = skimage.io.imread(path)
47 |     img = img / 255.0
48 |     if height is not None and width is not None:
49 |         ny = height
50 |         nx = width
51 |     elif height is not None:
52 |         ny = height
53 |         nx = img.shape[1] * ny // img.shape[0]
54 |     elif width is not None:
55 |         nx = width
56 |         ny = img.shape[0] * nx // img.shape[1]
57 |     else:
58 |         ny = img.shape[0]
59 |         nx = img.shape[1]
60 |     return skimage.transform.resize(img, (ny, nx))
61 |
62 |
63 | def test():
64 |     img = skimage.io.imread("./test_data/starry_night.jpg")
65 |     ny = 300
66 |     nx = img.shape[1] * ny // img.shape[0]
67 |     img = skimage.transform.resize(img, (ny, nx))
68 |     skimage.io.imsave("./test_data/test/output.jpg", img)
69 |
70 |
71 | if __name__ == "__main__":
72 |     test()
73 |
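`load_image` takes the largest centred square before resizing to 224×224, and `print_prob` ranks classes with a reversed `np.argsort`. Both steps are sketched below on plain NumPy arrays, without file I/O or skimage. `center_crop_square` and `top_k` are hypothetical helper names.

```python
import numpy as np


def center_crop_square(img):
    """Largest centred square crop, as in load_image (hypothetical helper)."""
    short_edge = min(img.shape[:2])
    yy = (img.shape[0] - short_edge) // 2
    xx = (img.shape[1] - short_edge) // 2
    return img[yy:yy + short_edge, xx:xx + short_edge]


def top_k(prob, labels, k=5):
    """(label, probability) pairs for the k most probable classes, ranked as in print_prob."""
    order = np.argsort(prob)[::-1][:k]
    return [(labels[i], float(prob[i])) for i in order]


image = np.zeros((480, 640, 3))
square = center_crop_square(image)  # shape (480, 480, 3): 80 columns trimmed on each side
```

Integer division (`//`) gives the same offsets as the `int(... / 2)` in `load_image` for non-negative sizes, under both Python 2 and Python 3.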
--------------------------------------------------------------------------------
/models/vgg_utils.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/adityac94/Grad_CAM_plus_plus/4a9faf6ac61ef0c56e19b88d8560b81cd62c5017/models/vgg_utils.pyc
--------------------------------------------------------------------------------
/output/multiple_dogs.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/adityac94/Grad_CAM_plus_plus/4a9faf6ac61ef0c56e19b88d8560b81cd62c5017/output/multiple_dogs.jpg
--------------------------------------------------------------------------------
/output/output.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/adityac94/Grad_CAM_plus_plus/4a9faf6ac61ef0c56e19b88d8560b81cd62c5017/output/output.jpeg
--------------------------------------------------------------------------------
/output/water-bird.JPEG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/adityac94/Grad_CAM_plus_plus/4a9faf6ac61ef0c56e19b88d8560b81cd62c5017/output/water-bird.JPEG
--------------------------------------------------------------------------------
/synset.txt:
--------------------------------------------------------------------------------
1 | n01440764 tench, Tinca tinca
2 | n01443537 goldfish, Carassius auratus
3 | n01484850 great white shark, white shark, man-eater, man-eating shark, Carcharodon carcharias
4 | n01491361 tiger shark, Galeocerdo cuvieri
5 | n01494475 hammerhead, hammerhead shark
6 | n01496331 electric ray, crampfish, numbfish, torpedo
7 | n01498041 stingray
8 | n01514668 cock
9 | n01514859 hen
10 | n01518878 ostrich, Struthio camelus
11 | n01530575 brambling, Fringilla montifringilla
12 | n01531178 goldfinch, Carduelis carduelis
13 | n01532829 house finch, linnet, Carpodacus mexicanus
14 | n01534433 junco, snowbird
15 | n01537544 indigo bunting, indigo finch, indigo bird, Passerina cyanea
16 | n01558993 robin, American robin, Turdus migratorius
17 | n01560419 bulbul
18 | n01580077 jay
19 | n01582220 magpie
20 | n01592084 chickadee
21 | n01601694 water ouzel, dipper
22 | n01608432 kite
23 | n01614925 bald eagle, American eagle, Haliaeetus leucocephalus
24 | n01616318 vulture
25 | n01622779 great grey owl, great gray owl, Strix nebulosa
26 | n01629819 European fire salamander, Salamandra salamandra
27 | n01630670 common newt, Triturus vulgaris
28 | n01631663 eft
29 | n01632458 spotted salamander, Ambystoma maculatum
30 | n01632777 axolotl, mud puppy, Ambystoma mexicanum
31 | n01641577 bullfrog, Rana catesbeiana
32 | n01644373 tree frog, tree-frog
33 | n01644900 tailed frog, bell toad, ribbed toad, tailed toad, Ascaphus trui
34 | n01664065 loggerhead, loggerhead turtle, Caretta caretta
35 | n01665541 leatherback turtle, leatherback, leathery turtle, Dermochelys coriacea
36 | n01667114 mud turtle
37 | n01667778 terrapin
38 | n01669191 box turtle, box tortoise
39 | n01675722 banded gecko
40 | n01677366 common iguana, iguana, Iguana iguana
41 | n01682714 American chameleon, anole, Anolis carolinensis
42 | n01685808 whiptail, whiptail lizard
43 | n01687978 agama
44 | n01688243 frilled lizard, Chlamydosaurus kingi
45 | n01689811 alligator lizard
46 | n01692333 Gila monster, Heloderma suspectum
47 | n01693334 green lizard, Lacerta viridis
48 | n01694178 African chameleon, Chamaeleo chamaeleon
49 | n01695060 Komodo dragon, Komodo lizard, dragon lizard, giant lizard, Varanus komodoensis
50 | n01697457 African crocodile, Nile crocodile, Crocodylus niloticus
51 | n01698640 American alligator, Alligator mississipiensis
52 | n01704323 triceratops
53 | n01728572 thunder snake, worm snake, Carphophis amoenus
54 | n01728920 ringneck snake, ring-necked snake, ring snake
55 | n01729322 hognose snake, puff adder, sand viper
56 | n01729977 green snake, grass snake
57 | n01734418 king snake, kingsnake
58 | n01735189 garter snake, grass snake
59 | n01737021 water snake
60 | n01739381 vine snake
61 | n01740131 night snake, Hypsiglena torquata
62 | n01742172 boa constrictor, Constrictor constrictor
63 | n01744401 rock python, rock snake, Python sebae
64 | n01748264 Indian cobra, Naja naja
65 | n01749939 green mamba
66 | n01751748 sea snake
67 | n01753488 horned viper, cerastes, sand viper, horned asp, Cerastes cornutus
68 | n01755581 diamondback, diamondback rattlesnake, Crotalus adamanteus
69 | n01756291 sidewinder, horned rattlesnake, Crotalus cerastes
70 | n01768244 trilobite
71 | n01770081 harvestman, daddy longlegs, Phalangium opilio
72 | n01770393 scorpion
73 | n01773157 black and gold garden spider, Argiope aurantia
74 | n01773549 barn spider, Araneus cavaticus
75 | n01773797 garden spider, Aranea diademata
76 | n01774384 black widow, Latrodectus mactans
77 | n01774750 tarantula
78 | n01775062 wolf spider, hunting spider
79 | n01776313 tick
80 | n01784675 centipede
81 | n01795545 black grouse
82 | n01796340 ptarmigan
83 | n01797886 ruffed grouse, partridge, Bonasa umbellus
84 | n01798484 prairie chicken, prairie grouse, prairie fowl
85 | n01806143 peacock
86 | n01806567 quail
87 | n01807496 partridge
88 | n01817953 African grey, African gray, Psittacus erithacus
89 | n01818515 macaw
90 | n01819313 sulphur-crested cockatoo, Kakatoe galerita, Cacatua galerita
91 | n01820546 lorikeet
92 | n01824575 coucal
93 | n01828970 bee eater
94 | n01829413 hornbill
95 | n01833805 hummingbird
96 | n01843065 jacamar
97 | n01843383 toucan
98 | n01847000 drake
99 | n01855032 red-breasted merganser, Mergus serrator
100 | n01855672 goose
101 | n01860187 black swan, Cygnus atratus
102 | n01871265 tusker
103 | n01872401 echidna, spiny anteater, anteater
104 | n01873310 platypus, duckbill, duckbilled platypus, duck-billed platypus, Ornithorhynchus anatinus
105 | n01877812 wallaby, brush kangaroo
106 | n01882714 koala, koala bear, kangaroo bear, native bear, Phascolarctos cinereus
107 | n01883070 wombat
108 | n01910747 jellyfish
109 | n01914609 sea anemone, anemone
110 | n01917289 brain coral
111 | n01924916 flatworm, platyhelminth
112 | n01930112 nematode, nematode worm, roundworm
113 | n01943899 conch
114 | n01944390 snail
115 | n01945685 slug
116 | n01950731 sea slug, nudibranch
117 | n01955084 chiton, coat-of-mail shell, sea cradle, polyplacophore
118 | n01968897 chambered nautilus, pearly nautilus, nautilus
119 | n01978287 Dungeness crab, Cancer magister
120 | n01978455 rock crab, Cancer irroratus
121 | n01980166 fiddler crab
122 | n01981276 king crab, Alaska crab, Alaskan king crab, Alaska king crab, Paralithodes camtschatica
123 | n01983481 American lobster, Northern lobster, Maine lobster, Homarus americanus
124 | n01984695 spiny lobster, langouste, rock lobster, crawfish, crayfish, sea crawfish
125 | n01985128 crayfish, crawfish, crawdad, crawdaddy
126 | n01986214 hermit crab
127 | n01990800 isopod
128 | n02002556 white stork, Ciconia ciconia
129 | n02002724 black stork, Ciconia nigra
130 | n02006656 spoonbill
131 | n02007558 flamingo
132 | n02009229 little blue heron, Egretta caerulea
133 | n02009912 American egret, great white heron, Egretta albus
134 | n02011460 bittern
135 | n02012849 crane
136 | n02013706 limpkin, Aramus pictus
137 | n02017213 European gallinule, Porphyrio porphyrio
138 | n02018207 American coot, marsh hen, mud hen, water hen, Fulica americana
139 | n02018795 bustard
140 | n02025239 ruddy turnstone, Arenaria interpres
141 | n02027492 red-backed sandpiper, dunlin, Erolia alpina
142 | n02028035 redshank, Tringa totanus
143 | n02033041 dowitcher
144 | n02037110 oystercatcher, oyster catcher
145 | n02051845 pelican
146 | n02056570 king penguin, Aptenodytes patagonica
147 | n02058221 albatross, mollymawk
148 | n02066245 grey whale, gray whale, devilfish, Eschrichtius gibbosus, Eschrichtius robustus
149 | n02071294 killer whale, killer, orca, grampus, sea wolf, Orcinus orca
150 | n02074367 dugong, Dugong dugon
151 | n02077923 sea lion
152 | n02085620 Chihuahua
153 | n02085782 Japanese spaniel
154 | n02085936 Maltese dog, Maltese terrier, Maltese
155 | n02086079 Pekinese, Pekingese, Peke
156 | n02086240 Shih-Tzu
157 | n02086646 Blenheim spaniel
158 | n02086910 papillon
159 | n02087046 toy terrier
160 | n02087394 Rhodesian ridgeback
161 | n02088094 Afghan hound, Afghan
162 | n02088238 basset, basset hound
163 | n02088364 beagle
164 | n02088466 bloodhound, sleuthhound
165 | n02088632 bluetick
166 | n02089078 black-and-tan coonhound
167 | n02089867 Walker hound, Walker foxhound
168 | n02089973 English foxhound
169 | n02090379 redbone
170 | n02090622 borzoi, Russian wolfhound
171 | n02090721 Irish wolfhound
172 | n02091032 Italian greyhound
173 | n02091134 whippet
174 | n02091244 Ibizan hound, Ibizan Podenco
175 | n02091467 Norwegian elkhound, elkhound
176 | n02091635 otterhound, otter hound
177 | n02091831 Saluki, gazelle hound
178 | n02092002 Scottish deerhound, deerhound
179 | n02092339 Weimaraner
180 | n02093256 Staffordshire bullterrier, Staffordshire bull terrier
181 | n02093428 American Staffordshire terrier, Staffordshire terrier, American pit bull terrier, pit bull terrier
182 | n02093647 Bedlington terrier
183 | n02093754 Border terrier
184 | n02093859 Kerry blue terrier
185 | n02093991 Irish terrier
186 | n02094114 Norfolk terrier
187 | n02094258 Norwich terrier
188 | n02094433 Yorkshire terrier
189 | n02095314 wire-haired fox terrier
190 | n02095570 Lakeland terrier
191 | n02095889 Sealyham terrier, Sealyham
192 | n02096051 Airedale, Airedale terrier
193 | n02096177 cairn, cairn terrier
194 | n02096294 Australian terrier
195 | n02096437 Dandie Dinmont, Dandie Dinmont terrier
196 | n02096585 Boston bull, Boston terrier
197 | n02097047 miniature schnauzer
198 | n02097130 giant schnauzer
199 | n02097209 standard schnauzer
200 | n02097298 Scotch terrier, Scottish terrier, Scottie
201 | n02097474 Tibetan terrier, chrysanthemum dog
202 | n02097658 silky terrier, Sydney silky
203 | n02098105 soft-coated wheaten terrier
204 | n02098286 West Highland white terrier
205 | n02098413 Lhasa, Lhasa apso
206 | n02099267 flat-coated retriever
207 | n02099429 curly-coated retriever
208 | n02099601 golden retriever
209 | n02099712 Labrador retriever
210 | n02099849 Chesapeake Bay retriever
211 | n02100236 German short-haired pointer
212 | n02100583 vizsla, Hungarian pointer
213 | n02100735 English setter
214 | n02100877 Irish setter, red setter
215 | n02101006 Gordon setter
216 | n02101388 Brittany spaniel
217 | n02101556 clumber, clumber spaniel
218 | n02102040 English springer, English springer spaniel
219 | n02102177 Welsh springer spaniel
220 | n02102318 cocker spaniel, English cocker spaniel, cocker
221 | n02102480 Sussex spaniel
222 | n02102973 Irish water spaniel
223 | n02104029 kuvasz
224 | n02104365 schipperke
225 | n02105056 groenendael
226 | n02105162 malinois
227 | n02105251 briard
228 | n02105412 kelpie
229 | n02105505 komondor
230 | n02105641 Old English sheepdog, bobtail
231 | n02105855 Shetland sheepdog, Shetland sheep dog, Shetland
232 | n02106030 collie
233 | n02106166 Border collie
234 | n02106382 Bouvier des Flandres, Bouviers des Flandres
235 | n02106550 Rottweiler
236 | n02106662 German shepherd, German shepherd dog, German police dog, alsatian
237 | n02107142 Doberman, Doberman pinscher
238 | n02107312 miniature pinscher
239 | n02107574 Greater Swiss Mountain dog
240 | n02107683 Bernese mountain dog
241 | n02107908 Appenzeller
242 | n02108000 EntleBucher
243 | n02108089 boxer
244 | n02108422 bull mastiff
245 | n02108551 Tibetan mastiff
246 | n02108915 French bulldog
247 | n02109047 Great Dane
248 | n02109525 Saint Bernard, St Bernard
249 | n02109961 Eskimo dog, husky
250 | n02110063 malamute, malemute, Alaskan malamute
251 | n02110185 Siberian husky
252 | n02110341 dalmatian, coach dog, carriage dog
253 | n02110627 affenpinscher, monkey pinscher, monkey dog
254 | n02110806 basenji
255 | n02110958 pug, pug-dog
256 | n02111129 Leonberg
257 | n02111277 Newfoundland, Newfoundland dog
258 | n02111500 Great Pyrenees
259 | n02111889 Samoyed, Samoyede
260 | n02112018 Pomeranian
261 | n02112137 chow, chow chow
262 | n02112350 keeshond
263 | n02112706 Brabancon griffon
264 | n02113023 Pembroke, Pembroke Welsh corgi
265 | n02113186 Cardigan, Cardigan Welsh corgi
266 | n02113624 toy poodle
267 | n02113712 miniature poodle
268 | n02113799 standard poodle
269 | n02113978 Mexican hairless
270 | n02114367 timber wolf, grey wolf, gray wolf, Canis lupus
271 | n02114548 white wolf, Arctic wolf, Canis lupus tundrarum
272 | n02114712 red wolf, maned wolf, Canis rufus, Canis niger
273 | n02114855 coyote, prairie wolf, brush wolf, Canis latrans
274 | n02115641 dingo, warrigal, warragal, Canis dingo
275 | n02115913 dhole, Cuon alpinus
276 | n02116738 African hunting dog, hyena dog, Cape hunting dog, Lycaon pictus
277 | n02117135 hyena, hyaena
278 | n02119022 red fox, Vulpes vulpes
279 | n02119789 kit fox, Vulpes macrotis
280 | n02120079 Arctic fox, white fox, Alopex lagopus
281 | n02120505 grey fox, gray fox, Urocyon cinereoargenteus
282 | n02123045 tabby, tabby cat
283 | n02123159 tiger cat
284 | n02123394 Persian cat
285 | n02123597 Siamese cat, Siamese
286 | n02124075 Egyptian cat
287 | n02125311 cougar, puma, catamount, mountain lion, painter, panther, Felis concolor
288 | n02127052 lynx, catamount
289 | n02128385 leopard, Panthera pardus
290 | n02128757 snow leopard, ounce, Panthera uncia
291 | n02128925 jaguar, panther, Panthera onca, Felis onca
292 | n02129165 lion, king of beasts, Panthera leo
293 | n02129604 tiger, Panthera tigris
294 | n02130308 cheetah, chetah, Acinonyx jubatus
295 | n02132136 brown bear, bruin, Ursus arctos
296 | n02133161 American black bear, black bear, Ursus americanus, Euarctos americanus
297 | n02134084 ice bear, polar bear, Ursus Maritimus, Thalarctos maritimus
298 | n02134418 sloth bear, Melursus ursinus, Ursus ursinus
299 | n02137549 mongoose
300 | n02138441 meerkat, mierkat
301 | n02165105 tiger beetle
302 | n02165456 ladybug, ladybeetle, lady beetle, ladybird, ladybird beetle
303 | n02167151 ground beetle, carabid beetle
304 | n02168699 long-horned beetle, longicorn, longicorn beetle
305 | n02169497 leaf beetle, chrysomelid
306 | n02172182 dung beetle
307 | n02174001 rhinoceros beetle
308 | n02177972 weevil
309 | n02190166 fly
310 | n02206856 bee
311 | n02219486 ant, emmet, pismire
312 | n02226429 grasshopper, hopper
313 | n02229544 cricket
314 | n02231487 walking stick, walkingstick, stick insect
315 | n02233338 cockroach, roach
316 | n02236044 mantis, mantid
317 | n02256656 cicada, cicala
318 | n02259212 leafhopper
319 | n02264363 lacewing, lacewing fly
320 | n02268443 dragonfly, darning needle, devil's darning needle, sewing needle, snake feeder, snake doctor, mosquito hawk, skeeter hawk
321 | n02268853 damselfly
322 | n02276258 admiral
323 | n02277742 ringlet, ringlet butterfly
324 | n02279972 monarch, monarch butterfly, milkweed butterfly, Danaus plexippus
325 | n02280649 cabbage butterfly
326 | n02281406 sulphur butterfly, sulfur butterfly
327 | n02281787 lycaenid, lycaenid butterfly
328 | n02317335 starfish, sea star
329 | n02319095 sea urchin
330 | n02321529 sea cucumber, holothurian
331 | n02325366 wood rabbit, cottontail, cottontail rabbit
332 | n02326432 hare
333 | n02328150 Angora, Angora rabbit
334 | n02342885 hamster
335 | n02346627 porcupine, hedgehog
336 | n02356798 fox squirrel, eastern fox squirrel, Sciurus niger
337 | n02361337 marmot
338 | n02363005 beaver
339 | n02364673 guinea pig, Cavia cobaya
340 | n02389026 sorrel
341 | n02391049 zebra
342 | n02395406 hog, pig, grunter, squealer, Sus scrofa
343 | n02396427 wild boar, boar, Sus scrofa
344 | n02397096 warthog
345 | n02398521 hippopotamus, hippo, river horse, Hippopotamus amphibius
346 | n02403003 ox
347 | n02408429 water buffalo, water ox, Asiatic buffalo, Bubalus bubalis
348 | n02410509 bison
349 | n02412080 ram, tup
350 | n02415577 bighorn, bighorn sheep, cimarron, Rocky Mountain bighorn, Rocky Mountain sheep, Ovis canadensis
351 | n02417914 ibex, Capra ibex
352 | n02422106 hartebeest
353 | n02422699 impala, Aepyceros melampus
354 | n02423022 gazelle
355 | n02437312 Arabian camel, dromedary, Camelus dromedarius
356 | n02437616 llama
357 | n02441942 weasel
358 | n02442845 mink
359 | n02443114 polecat, fitch, foulmart, foumart, Mustela putorius
360 | n02443484 black-footed ferret, ferret, Mustela nigripes
361 | n02444819 otter
362 | n02445715 skunk, polecat, wood pussy
363 | n02447366 badger
364 | n02454379 armadillo
365 | n02457408 three-toed sloth, ai, Bradypus tridactylus
366 | n02480495 orangutan, orang, orangutang, Pongo pygmaeus
367 | n02480855 gorilla, Gorilla gorilla
368 | n02481823 chimpanzee, chimp, Pan troglodytes
369 | n02483362 gibbon, Hylobates lar
370 | n02483708 siamang, Hylobates syndactylus, Symphalangus syndactylus
371 | n02484975 guenon, guenon monkey
372 | n02486261 patas, hussar monkey, Erythrocebus patas
373 | n02486410 baboon
374 | n02487347 macaque
375 | n02488291 langur
376 | n02488702 colobus, colobus monkey
377 | n02489166 proboscis monkey, Nasalis larvatus
378 | n02490219 marmoset
379 | n02492035 capuchin, ringtail, Cebus capucinus
380 | n02492660 howler monkey, howler
381 | n02493509 titi, titi monkey
382 | n02493793 spider monkey, Ateles geoffroyi
383 | n02494079 squirrel monkey, Saimiri sciureus
384 | n02497673 Madagascar cat, ring-tailed lemur, Lemur catta
385 | n02500267 indri, indris, Indri indri, Indri brevicaudatus
386 | n02504013 Indian elephant, Elephas maximus
387 | n02504458 African elephant, Loxodonta africana
388 | n02509815 lesser panda, red panda, panda, bear cat, cat bear, Ailurus fulgens
389 | n02510455 giant panda, panda, panda bear, coon bear, Ailuropoda melanoleuca
390 | n02514041 barracouta, snoek
391 | n02526121 eel
392 | n02536864 coho, cohoe, coho salmon, blue jack, silver salmon, Oncorhynchus kisutch
393 | n02606052 rock beauty, Holocanthus tricolor
394 | n02607072 anemone fish
395 | n02640242 sturgeon
396 | n02641379 gar, garfish, garpike, billfish, Lepisosteus osseus
397 | n02643566 lionfish
398 | n02655020 puffer, pufferfish, blowfish, globefish
399 | n02666196 abacus
400 | n02667093 abaya
401 | n02669723 academic gown, academic robe, judge's robe
402 | n02672831 accordion, piano accordion, squeeze box
403 | n02676566 acoustic guitar
404 | n02687172 aircraft carrier, carrier, flattop, attack aircraft carrier
405 | n02690373 airliner
406 | n02692877 airship, dirigible
407 | n02699494 altar
408 | n02701002 ambulance
409 | n02704792 amphibian, amphibious vehicle
410 | n02708093 analog clock
411 | n02727426 apiary, bee house
412 | n02730930 apron
413 | n02747177 ashcan, trash can, garbage can, wastebin, ash bin, ash-bin, ashbin, dustbin, trash barrel, trash bin
414 | n02749479 assault rifle, assault gun
415 | n02769748 backpack, back pack, knapsack, packsack, rucksack, haversack
416 | n02776631 bakery, bakeshop, bakehouse
417 | n02777292 balance beam, beam
418 | n02782093 balloon
419 | n02783161 ballpoint, ballpoint pen, ballpen, Biro
420 | n02786058 Band Aid
421 | n02787622 banjo
422 | n02788148 bannister, banister, balustrade, balusters, handrail
423 | n02790996 barbell
424 | n02791124 barber chair
425 | n02791270 barbershop
426 | n02793495 barn
427 | n02794156 barometer
428 | n02795169 barrel, cask
429 | n02797295 barrow, garden cart, lawn cart, wheelbarrow
430 | n02799071 baseball
431 | n02802426 basketball
432 | n02804414 bassinet
433 | n02804610 bassoon
434 | n02807133 bathing cap, swimming cap
435 | n02808304 bath towel
436 | n02808440 bathtub, bathing tub, bath, tub
437 | n02814533 beach wagon, station wagon, wagon, estate car, beach waggon, station waggon, waggon
438 | n02814860 beacon, lighthouse, beacon light, pharos
439 | n02815834 beaker
440 | n02817516 bearskin, busby, shako
441 | n02823428 beer bottle
442 | n02823750 beer glass
443 | n02825657 bell cote, bell cot
444 | n02834397 bib
445 | n02835271 bicycle-built-for-two, tandem bicycle, tandem
446 | n02837789 bikini, two-piece
447 | n02840245 binder, ring-binder
448 | n02841315 binoculars, field glasses, opera glasses
449 | n02843684 birdhouse
450 | n02859443 boathouse
451 | n02860847 bobsled, bobsleigh, bob
452 | n02865351 bolo tie, bolo, bola tie, bola
453 | n02869837 bonnet, poke bonnet
454 | n02870880 bookcase
455 | n02871525 bookshop, bookstore, bookstall
456 | n02877765 bottlecap
457 | n02879718 bow
458 | n02883205 bow tie, bow-tie, bowtie
459 | n02892201 brass, memorial tablet, plaque
460 | n02892767 brassiere, bra, bandeau
461 | n02894605 breakwater, groin, groyne, mole, bulwark, seawall, jetty
462 | n02895154 breastplate, aegis, egis
463 | n02906734 broom
464 | n02909870 bucket, pail
465 | n02910353 buckle
466 | n02916936 bulletproof vest
467 | n02917067 bullet train, bullet
468 | n02927161 butcher shop, meat market
469 | n02930766 cab, hack, taxi, taxicab
470 | n02939185 caldron, cauldron
471 | n02948072 candle, taper, wax light
472 | n02950826 cannon
473 | n02951358 canoe
474 | n02951585 can opener, tin opener
475 | n02963159 cardigan
476 | n02965783 car mirror
477 | n02966193 carousel, carrousel, merry-go-round, roundabout, whirligig
478 | n02966687 carpenter's kit, tool kit
479 | n02971356 carton
480 | n02974003 car wheel
481 | n02977058 cash machine, cash dispenser, automated teller machine, automatic teller machine, automated teller, automatic teller, ATM
482 | n02978881 cassette
483 | n02979186 cassette player
484 | n02980441 castle
485 | n02981792 catamaran
486 | n02988304 CD player
487 | n02992211 cello, violoncello
488 | n02992529 cellular telephone, cellular phone, cellphone, cell, mobile phone
489 | n02999410 chain
490 | n03000134 chainlink fence
491 | n03000247 chain mail, ring mail, mail, chain armor, chain armour, ring armor, ring armour
492 | n03000684 chain saw, chainsaw
493 | n03014705 chest
494 | n03016953 chiffonier, commode
495 | n03017168 chime, bell, gong
496 | n03018349 china cabinet, china closet
497 | n03026506 Christmas stocking
498 | n03028079 church, church building
499 | n03032252 cinema, movie theater, movie theatre, movie house, picture palace
500 | n03041632 cleaver, meat cleaver, chopper
501 | n03042490 cliff dwelling
502 | n03045698 cloak
503 | n03047690 clog, geta, patten, sabot
504 | n03062245 cocktail shaker
505 | n03063599 coffee mug
506 | n03063689 coffeepot
507 | n03065424 coil, spiral, volute, whorl, helix
508 | n03075370 combination lock
509 | n03085013 computer keyboard, keypad
510 | n03089624 confectionery, confectionary, candy store
511 | n03095699 container ship, containership, container vessel
512 | n03100240 convertible
513 | n03109150 corkscrew, bottle screw
514 | n03110669 cornet, horn, trumpet, trump
515 | n03124043 cowboy boot
516 | n03124170 cowboy hat, ten-gallon hat
517 | n03125729 cradle
518 | n03126707 crane
519 | n03127747 crash helmet
520 | n03127925 crate
521 | n03131574 crib, cot
522 | n03133878 Crock Pot
523 | n03134739 croquet ball
524 | n03141823 crutch
525 | n03146219 cuirass
526 | n03160309 dam, dike, dyke
527 | n03179701 desk
528 | n03180011 desktop computer
529 | n03187595 dial telephone, dial phone
530 | n03188531 diaper, nappy, napkin
531 | n03196217 digital clock
532 | n03197337 digital watch
533 | n03201208 dining table, board
534 | n03207743 dishrag, dishcloth
535 | n03207941 dishwasher, dish washer, dishwashing machine
536 | n03208938 disk brake, disc brake
537 | n03216828 dock, dockage, docking facility
538 | n03218198 dogsled, dog sled, dog sleigh
539 | n03220513 dome
540 | n03223299 doormat, welcome mat
541 | n03240683 drilling platform, offshore rig
542 | n03249569 drum, membranophone, tympan
543 | n03250847 drumstick
544 | n03255030 dumbbell
545 | n03259280 Dutch oven
546 | n03271574 electric fan, blower
547 | n03272010 electric guitar
548 | n03272562 electric locomotive
549 | n03290653 entertainment center
550 | n03291819 envelope
551 | n03297495 espresso maker
552 | n03314780 face powder
553 | n03325584 feather boa, boa
554 | n03337140 file, file cabinet, filing cabinet
555 | n03344393 fireboat
556 | n03345487 fire engine, fire truck
557 | n03347037 fire screen, fireguard
558 | n03355925 flagpole, flagstaff
559 | n03372029 flute, transverse flute
560 | n03376595 folding chair
561 | n03379051 football helmet
562 | n03384352 forklift
563 | n03388043 fountain
564 | n03388183 fountain pen
565 | n03388549 four-poster
566 | n03393912 freight car
567 | n03394916 French horn, horn
568 | n03400231 frying pan, frypan, skillet
569 | n03404251 fur coat
570 | n03417042 garbage truck, dustcart
571 | n03424325 gasmask, respirator, gas helmet
572 | n03425413 gas pump, gasoline pump, petrol pump, island dispenser
573 | n03443371 goblet
574 | n03444034 go-kart
575 | n03445777 golf ball
576 | n03445924 golfcart, golf cart
577 | n03447447 gondola
578 | n03447721 gong, tam-tam
579 | n03450230 gown
580 | n03452741 grand piano, grand
581 | n03457902 greenhouse, nursery, glasshouse
582 | n03459775 grille, radiator grille
583 | n03461385 grocery store, grocery, food market, market
584 | n03467068 guillotine
585 | n03476684 hair slide
586 | n03476991 hair spray
587 | n03478589 half track
588 | n03481172 hammer
589 | n03482405 hamper
590 | n03483316 hand blower, blow dryer, blow drier, hair dryer, hair drier
591 | n03485407 hand-held computer, hand-held microcomputer
592 | n03485794 handkerchief, hankie, hanky, hankey
593 | n03492542 hard disc, hard disk, fixed disk
594 | n03494278 harmonica, mouth organ, harp, mouth harp
595 | n03495258 harp
596 | n03496892 harvester, reaper
597 | n03498962 hatchet
598 | n03527444 holster
599 | n03529860 home theater, home theatre
600 | n03530642 honeycomb
601 | n03532672 hook, claw
602 | n03534580 hoopskirt, crinoline
603 | n03535780 horizontal bar, high bar
604 | n03538406 horse cart, horse-cart
605 | n03544143 hourglass
606 | n03584254 iPod
607 | n03584829 iron, smoothing iron
608 | n03590841 jack-o'-lantern
609 | n03594734 jean, blue jean, denim
610 | n03594945 jeep, landrover
611 | n03595614 jersey, T-shirt, tee shirt
612 | n03598930 jigsaw puzzle
613 | n03599486 jinrikisha, ricksha, rickshaw
614 | n03602883 joystick
615 | n03617480 kimono
616 | n03623198 knee pad
617 | n03627232 knot
618 | n03630383 lab coat, laboratory coat
619 | n03633091 ladle
620 | n03637318 lampshade, lamp shade
621 | n03642806 laptop, laptop computer
622 | n03649909 lawn mower, mower
623 | n03657121 lens cap, lens cover
624 | n03658185 letter opener, paper knife, paperknife
625 | n03661043 library
626 | n03662601 lifeboat
627 | n03666591 lighter, light, igniter, ignitor
628 | n03670208 limousine, limo
629 | n03673027 liner, ocean liner
630 | n03676483 lipstick, lip rouge
631 | n03680355 Loafer
632 | n03690938 lotion
633 | n03691459 loudspeaker, speaker, speaker unit, loudspeaker system, speaker system
634 | n03692522 loupe, jeweler's loupe
635 | n03697007 lumbermill, sawmill
636 | n03706229 magnetic compass
637 | n03709823 mailbag, postbag
638 | n03710193 mailbox, letter box
639 | n03710637 maillot
640 | n03710721 maillot, tank suit
641 | n03717622 manhole cover
642 | n03720891 maraca
643 | n03721384 marimba, xylophone
644 | n03724870 mask
645 | n03729826 matchstick
646 | n03733131 maypole
647 | n03733281 maze, labyrinth
648 | n03733805 measuring cup
649 | n03742115 medicine chest, medicine cabinet
650 | n03743016 megalith, megalithic structure
651 | n03759954 microphone, mike
652 | n03761084 microwave, microwave oven
653 | n03763968 military uniform
654 | n03764736 milk can
655 | n03769881 minibus
656 | n03770439 miniskirt, mini
657 | n03770679 minivan
658 | n03773504 missile
659 | n03775071 mitten
660 | n03775546 mixing bowl
661 | n03776460 mobile home, manufactured home
662 | n03777568 Model T
663 | n03777754 modem
664 | n03781244 monastery
665 | n03782006 monitor
666 | n03785016 moped
667 | n03786901 mortar
668 | n03787032 mortarboard
669 | n03788195 mosque
670 | n03788365 mosquito net
671 | n03791053 motor scooter, scooter
672 | n03792782 mountain bike, all-terrain bike, off-roader
673 | n03792972 mountain tent
674 | n03793489 mouse, computer mouse
675 | n03794056 mousetrap
676 | n03796401 moving van
677 | n03803284 muzzle
678 | n03804744 nail
679 | n03814639 neck brace
680 | n03814906 necklace
681 | n03825788 nipple
682 | n03832673 notebook, notebook computer
683 | n03837869 obelisk
684 | n03838899 oboe, hautboy, hautbois
685 | n03840681 ocarina, sweet potato
686 | n03841143 odometer, hodometer, mileometer, milometer
687 | n03843555 oil filter
688 | n03854065 organ, pipe organ
689 | n03857828 oscilloscope, scope, cathode-ray oscilloscope, CRO
690 | n03866082 overskirt
691 | n03868242 oxcart
692 | n03868863 oxygen mask
693 | n03871628 packet
694 | n03873416 paddle, boat paddle
695 | n03874293 paddlewheel, paddle wheel
696 | n03874599 padlock
697 | n03876231 paintbrush
698 | n03877472 pajama, pyjama, pj's, jammies
699 | n03877845 palace
700 | n03884397 panpipe, pandean pipe, syrinx
701 | n03887697 paper towel
702 | n03888257 parachute, chute
703 | n03888605 parallel bars, bars
704 | n03891251 park bench
705 | n03891332 parking meter
706 | n03895866 passenger car, coach, carriage
707 | n03899768 patio, terrace
708 | n03902125 pay-phone, pay-station
709 | n03903868 pedestal, plinth, footstall
710 | n03908618 pencil box, pencil case
711 | n03908714 pencil sharpener
712 | n03916031 perfume, essence
713 | n03920288 Petri dish
714 | n03924679 photocopier
715 | n03929660 pick, plectrum, plectron
716 | n03929855 pickelhaube
717 | n03930313 picket fence, paling
718 | n03930630 pickup, pickup truck
719 | n03933933 pier
720 | n03935335 piggy bank, penny bank
721 | n03937543 pill bottle
722 | n03938244 pillow
723 | n03942813 ping-pong ball
724 | n03944341 pinwheel
725 | n03947888 pirate, pirate ship
726 | n03950228 pitcher, ewer
727 | n03954731 plane, carpenter's plane, woodworking plane
728 | n03956157 planetarium
729 | n03958227 plastic bag
730 | n03961711 plate rack
731 | n03967562 plow, plough
732 | n03970156 plunger, plumber's helper
733 | n03976467 Polaroid camera, Polaroid Land camera
734 | n03976657 pole
735 | n03977966 police van, police wagon, paddy wagon, patrol wagon, wagon, black Maria
736 | n03980874 poncho
737 | n03982430 pool table, billiard table, snooker table
738 | n03983396 pop bottle, soda bottle
739 | n03991062 pot, flowerpot
740 | n03992509 potter's wheel
741 | n03995372 power drill
742 | n03998194 prayer rug, prayer mat
743 | n04004767 printer
744 | n04005630 prison, prison house
745 | n04008634 projectile, missile
746 | n04009552 projector
747 | n04019541 puck, hockey puck
748 | n04023962 punching bag, punch bag, punching ball, punchball
749 | n04026417 purse
750 | n04033901 quill, quill pen
751 | n04033995 quilt, comforter, comfort, puff
752 | n04037443 racer, race car, racing car
753 | n04039381 racket, racquet
754 | n04040759 radiator
755 | n04041544 radio, wireless
756 | n04044716 radio telescope, radio reflector
757 | n04049303 rain barrel
758 | n04065272 recreational vehicle, RV, R.V.
759 | n04067472 reel
760 | n04069434 reflex camera
761 | n04070727 refrigerator, icebox
762 | n04074963 remote control, remote
763 | n04081281 restaurant, eating house, eating place, eatery
764 | n04086273 revolver, six-gun, six-shooter
765 | n04090263 rifle
766 | n04099969 rocking chair, rocker
767 | n04111531 rotisserie
768 | n04116512 rubber eraser, rubber, pencil eraser
769 | n04118538 rugby ball
770 | n04118776 rule, ruler
771 | n04120489 running shoe
772 | n04125021 safe
773 | n04127249 safety pin
774 | n04131690 saltshaker, salt shaker
775 | n04133789 sandal
776 | n04136333 sarong
777 | n04141076 sax, saxophone
778 | n04141327 scabbard
779 | n04141975 scale, weighing machine
780 | n04146614 school bus
781 | n04147183 schooner
782 | n04149813 scoreboard
783 | n04152593 screen, CRT screen
784 | n04153751 screw
785 | n04154565 screwdriver
786 | n04162706 seat belt, seatbelt
787 | n04179913 sewing machine
788 | n04192698 shield, buckler
789 | n04200800 shoe shop, shoe-shop, shoe store
790 | n04201297 shoji
791 | n04204238 shopping basket
792 | n04204347 shopping cart
793 | n04208210 shovel
794 | n04209133 shower cap
795 | n04209239 shower curtain
796 | n04228054 ski
797 | n04229816 ski mask
798 | n04235860 sleeping bag
799 | n04238763 slide rule, slipstick
800 | n04239074 sliding door
801 | n04243546 slot, one-armed bandit
802 | n04251144 snorkel
803 | n04252077 snowmobile
804 | n04252225 snowplow, snowplough
805 | n04254120 soap dispenser
806 | n04254680 soccer ball
807 | n04254777 sock
808 | n04258138 solar dish, solar collector, solar furnace
809 | n04259630 sombrero
810 | n04263257 soup bowl
811 | n04264628 space bar
812 | n04265275 space heater
813 | n04266014 space shuttle
814 | n04270147 spatula
815 | n04273569 speedboat
816 | n04275548 spider web, spider's web
817 | n04277352 spindle
818 | n04285008 sports car, sport car
819 | n04286575 spotlight, spot
820 | n04296562 stage
821 | n04310018 steam locomotive
822 | n04311004 steel arch bridge
823 | n04311174 steel drum
824 | n04317175 stethoscope
825 | n04325704 stole
826 | n04326547 stone wall
827 | n04328186 stopwatch, stop watch
828 | n04330267 stove
829 | n04332243 strainer
830 | n04335435 streetcar, tram, tramcar, trolley, trolley car
831 | n04336792 stretcher
832 | n04344873 studio couch, day bed
833 | n04346328 stupa, tope
834 | n04347754 submarine, pigboat, sub, U-boat
835 | n04350905 suit, suit of clothes
836 | n04355338 sundial
837 | n04355933 sunglass
838 | n04356056 sunglasses, dark glasses, shades
839 | n04357314 sunscreen, sunblock, sun blocker
840 | n04366367 suspension bridge
841 | n04367480 swab, swob, mop
842 | n04370456 sweatshirt
843 | n04371430 swimming trunks, bathing trunks
844 | n04371774 swing
845 | n04372370 switch, electric switch, electrical switch
846 | n04376876 syringe
847 | n04380533 table lamp
848 | n04389033 tank, army tank, armored combat vehicle, armoured combat vehicle
849 | n04392985 tape player
850 | n04398044 teapot
851 | n04399382 teddy, teddy bear
852 | n04404412 television, television system
853 | n04409515 tennis ball
854 | n04417672 thatch, thatched roof
855 | n04418357 theater curtain, theatre curtain
856 | n04423845 thimble
857 | n04428191 thresher, thrasher, threshing machine
858 | n04429376 throne
859 | n04435653 tile roof
860 | n04442312 toaster
861 | n04443257 tobacco shop, tobacconist shop, tobacconist
862 | n04447861 toilet seat
863 | n04456115 torch
864 | n04458633 totem pole
865 | n04461696 tow truck, tow car, wrecker
866 | n04462240 toyshop
867 | n04465501 tractor
868 | n04467665 trailer truck, tractor trailer, trucking rig, rig, articulated lorry, semi
869 | n04476259 tray
870 | n04479046 trench coat
871 | n04482393 tricycle, trike, velocipede
872 | n04483307 trimaran
873 | n04485082 tripod
874 | n04486054 triumphal arch
875 | n04487081 trolleybus, trolley coach, trackless trolley
876 | n04487394 trombone
877 | n04493381 tub, vat
878 | n04501370 turnstile
879 | n04505470 typewriter keyboard
880 | n04507155 umbrella
881 | n04509417 unicycle, monocycle
882 | n04515003 upright, upright piano
883 | n04517823 vacuum, vacuum cleaner
884 | n04522168 vase
885 | n04523525 vault
886 | n04525038 velvet
887 | n04525305 vending machine
888 | n04532106 vestment
889 | n04532670 viaduct
890 | n04536866 violin, fiddle
891 | n04540053 volleyball
892 | n04542943 waffle iron
893 | n04548280 wall clock
894 | n04548362 wallet, billfold, notecase, pocketbook
895 | n04550184 wardrobe, closet, press
896 | n04552348 warplane, military plane
897 | n04553703 washbasin, handbasin, washbowl, lavabo, wash-hand basin
898 | n04554684 washer, automatic washer, washing machine
899 | n04557648 water bottle
900 | n04560804 water jug
901 | n04562935 water tower
902 | n04579145 whiskey jug
903 | n04579432 whistle
904 | n04584207 wig
905 | n04589890 window screen
906 | n04590129 window shade
907 | n04591157 Windsor tie
908 | n04591713 wine bottle
909 | n04592741 wing
910 | n04596742 wok
911 | n04597913 wooden spoon
912 | n04599235 wool, woolen, woollen
913 | n04604644 worm fence, snake fence, snake-rail fence, Virginia fence
914 | n04606251 wreck
915 | n04612504 yawl
916 | n04613696 yurt
917 | n06359193 web site, website, internet site, site
918 | n06596364 comic book
919 | n06785654 crossword puzzle, crossword
920 | n06794110 street sign
921 | n06874185 traffic light, traffic signal, stoplight
922 | n07248320 book jacket, dust cover, dust jacket, dust wrapper
923 | n07565083 menu
924 | n07579787 plate
925 | n07583066 guacamole
926 | n07584110 consomme
927 | n07590611 hot pot, hotpot
928 | n07613480 trifle
929 | n07614500 ice cream, icecream
930 | n07615774 ice lolly, lolly, lollipop, popsicle
931 | n07684084 French loaf
932 | n07693725 bagel, beigel
933 | n07695742 pretzel
934 | n07697313 cheeseburger
935 | n07697537 hotdog, hot dog, red hot
936 | n07711569 mashed potato
937 | n07714571 head cabbage
938 | n07714990 broccoli
939 | n07715103 cauliflower
940 | n07716358 zucchini, courgette
941 | n07716906 spaghetti squash
942 | n07717410 acorn squash
943 | n07717556 butternut squash
944 | n07718472 cucumber, cuke
945 | n07718747 artichoke, globe artichoke
946 | n07720875 bell pepper
947 | n07730033 cardoon
948 | n07734744 mushroom
949 | n07742313 Granny Smith
950 | n07745940 strawberry
951 | n07747607 orange
952 | n07749582 lemon
953 | n07753113 fig
954 | n07753275 pineapple, ananas
955 | n07753592 banana
956 | n07754684 jackfruit, jak, jack
957 | n07760859 custard apple
958 | n07768694 pomegranate
959 | n07802026 hay
960 | n07831146 carbonara
961 | n07836838 chocolate sauce, chocolate syrup
962 | n07860988 dough
963 | n07871810 meat loaf, meatloaf
964 | n07873807 pizza, pizza pie
965 | n07875152 potpie
966 | n07880968 burrito
967 | n07892512 red wine
968 | n07920052 espresso
969 | n07930864 cup
970 | n07932039 eggnog
971 | n09193705 alp
972 | n09229709 bubble
973 | n09246464 cliff, drop, drop-off
974 | n09256479 coral reef
975 | n09288635 geyser
976 | n09332890 lakeside, lakeshore
977 | n09399592 promontory, headland, head, foreland
978 | n09421951 sandbar, sand bar
979 | n09428293 seashore, coast, seacoast, sea-coast
980 | n09468604 valley, vale
981 | n09472597 volcano
982 | n09835506 ballplayer, baseball player
983 | n10148035 groom, bridegroom
984 | n10565667 scuba diver
985 | n11879895 rapeseed
986 | n11939491 daisy
987 | n12057211 yellow lady's slipper, yellow lady-slipper, Cypripedium calceolus, Cypripedium parviflorum
988 | n12144580 corn
989 | n12267677 acorn
990 | n12620546 hip, rose hip, rosehip
991 | n12768682 buckeye, horse chestnut, conker
992 | n12985857 coral fungus
993 | n12998815 agaric
994 | n13037406 gyromitra
995 | n13040303 stinkhorn, carrion fungus
996 | n13044778 earthstar
997 | n13052670 hen-of-the-woods, hen of the woods, Polyporus frondosus, Grifola frondosa
998 | n13054560 bolete
999 | n13133613 ear, spike, capitulum
1000 | n15075141 toilet tissue, toilet paper, bathroom tissue
1001 |
--------------------------------------------------------------------------------