├── .gitignore
├── Dockerfile
├── README.md
├── imagenet
│   ├── get_checkpoint.sh
│   └── labels.txt
├── images
│   ├── cat.jpg
│   ├── cat_heatmap.png
│   └── scarjo.png
├── main.py
├── main.sh
└── model
    ├── __init__.py
    ├── nets
    │   ├── __init__.py
    │   ├── nets_factory.py
    │   ├── resnet_utils.py
    │   └── resnet_v2.py
    └── preprocessing
        ├── __init__.py
        ├── inception_preprocessing.py
        └── preprocessing_factory.py
/.gitignore:
--------------------------------------------------------------------------------
1 | imagenet/*.ckpt
2 | imagenet/*.graph
3 | output.png
4 | *.pyc
5 | model/nets/*.pyc
6 | model/preprocessing/*.pyc
7 |
--------------------------------------------------------------------------------
/Dockerfile:
--------------------------------------------------------------------------------
1 | FROM tensorflow/tensorflow:1.5.0-rc0-devel
2 |
3 | RUN pip install -U pip
4 |
5 | RUN apt-get update && \
6 | apt-get install -y \
7 | build-essential \
8 | cmake \
9 | git \
10 | libgtk2.0-dev \
11 | pkg-config \
12 | libavcodec-dev \
13 | libavformat-dev \
14 | libswscale-dev \
15 | python-dev \
16 | python-numpy \
17 | python-skimage \
18 | python-tk \
19 | libtbb2 \
20 | libtbb-dev \
21 | libjpeg-dev \
22 | libpng-dev \
23 | libtiff-dev \
24 | libjasper-dev \
25 | libdc1394-22-dev \
26 | qt5-default \
27 | wget \
28 | vim
29 |
30 | RUN git clone https://github.com/opencv/opencv.git /root/opencv && \
31 | cd /root/opencv && \
32 | git checkout 2.4 && \
33 | mkdir build && \
34 | cd build && \
35 | cmake -DWITH_QT=ON -DWITH_OPENGL=ON -DFORCE_VTK=ON -DWITH_TBB=ON -DWITH_GDAL=ON -DWITH_XINE=ON -DBUILD_EXAMPLES=ON .. && \
36 | make -j"$(nproc)" && \
37 | make install && \
38 | ldconfig && \
39 | echo 'ln /dev/null /dev/raw1394' >> ~/.bashrc
40 |
41 | RUN ln /dev/null /dev/raw1394
42 |
43 | RUN cd /root && git clone https://github.com/hiveml/tensorflow-grad-cam
44 |
45 | WORKDIR /root/tensorflow-grad-cam
46 |
47 | RUN cd imagenet && ./get_checkpoint.sh
48 |
49 | CMD /bin/bash
50 |
51 |
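52 | # A minimal sketch (not part of the original Dockerfile) of how the image
53 | # might be built and used; the tag "grad-cam" below is illustrative only.
54 | #
55 | #   docker build -t grad-cam .
56 | #   docker run -it grad-cam    # drops into a shell (CMD /bin/bash); then run ./main.sh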
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | # Grad-CAM - TensorFlow Slim
4 |
5 |
6 |
7 |
8 | ### Features:
9 |
10 | Modular, built on TensorFlow-Slim. Easy to drop in other [Slim Models](https://github.com/tensorflow/models/tree/master/research/slim#pre-trained-models).
11 |
12 | Updated to work with TensorFlow 1.5
13 |
14 | Includes various output options: heatmap, shading, blur
15 |
16 | #### More examples and explanation [here](https://thehive.ai/blog/inside-a-neural-networks-mind)
17 |
18 |
19 | ### Dependencies
20 |
21 | * Python 2.7 and pip
22 | * scikit-image: `sudo apt-get install python-skimage`
23 | * Tkinter: `sudo apt-get install python-tk`
24 | * TensorFlow >= 1.5: `pip install tensorflow==1.5.0rc0`
25 | * OpenCV - see https://docs.opencv.org/2.4/doc/tutorials/introduction/linux_install/linux_install.html
26 |
27 |
28 | ### Installation
29 |
30 | Clone the repo:
31 | ```
32 | git clone https://github.com/hiveml/tensorflow-grad-cam.git
33 | cd tensorflow-grad-cam
34 | ```
35 | Download the ResNet-50 weights:
36 | ```
37 | ./imagenet/get_checkpoint.sh
38 | ```
39 | ### Usage
40 | ```
41 | ./main.sh
42 | ```
43 |
44 | ### Changing the Class
45 |
46 | By default, this code shows the Grad-CAM result for the top predicted class. You can
47 | change the `predicted_class` argument passed to the `grad_cam` function to see where
48 | the network would look for other classes (see the example at the end of this README).
49 |
50 | ### How to load another resnet\_v2 model
51 |
52 | First, download the new checkpoint from the [Slim Models](https://github.com/tensorflow/models/tree/master/research/slim#pre-trained-models) page.
53 |
54 | Then modify the arguments in `main.sh`, for example:
55 | ```
56 | python main.py --model_name=resnet_v2_101 --dataset_dir=./imagenet/ --checkpoint_path=./imagenet/resnet_v2_101.ckpt --input=./images/cat.jpg --eval_image_size=299
57 |
58 | ```
59 |
60 |
61 | This repo is based on this [code](https://github.com/Ankush96/grad-cam.tensorflow).
62 |
63 |
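64 | ### Example: targeting a specific class
65 |
66 | A minimal sketch (not part of the original `main.py`) of how the call in `main()` could
67 | target a chosen class instead of the top prediction. The network has 1001 outputs
68 | (`label_offset=1`), so the index passed to `grad_cam` is the `labels.txt` index plus one;
69 | for example, "tiger cat" is 282 in `labels.txt`, so 283 here.
70 |
71 | ```python
72 | # Hypothetical: replace `predicted_class = preds[0]` with an explicit target class.
73 | predicted_class = 283  # "tiger cat" (labels.txt index 282) + label_offset
74 | cam3 = grad_cam(img, imgs0, end_points, sess, predicted_class, layer_name, nb_classes, eval_image_size)
75 | ```
76 |
77 | If you load a model that is not listed in the `_layer_names` dictionary in `main.py`
78 | (for example `resnet_v2_200`, which `nets_factory.py` already supports), add an entry
79 | for it, assuming it exposes the same `PrePool` and `predictions` end points:
80 |
81 | ```python
82 | _layer_names["resnet_v2_200"] = ["PrePool", "predictions"]
83 | ```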
--------------------------------------------------------------------------------
/imagenet/get_checkpoint.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | wget http://download.tensorflow.org/models/resnet_v2_50_2017_04_14.tar.gz
3 | tar xzvf resnet_v2_50_2017_04_14.tar.gz
4 | rm resnet_v2_50_2017_04_14.tar.gz
5 |
--------------------------------------------------------------------------------
/imagenet/labels.txt:
--------------------------------------------------------------------------------
1 | 0: tench, Tinca tinca
2 | 1: goldfish, Carassius auratus
3 | 2: great white shark, white shark, man-eater, man-eating shark, Carcharodon carcharias
4 | 3: tiger shark, Galeocerdo cuvieri
5 | 4: hammerhead, hammerhead shark
6 | 5: electric ray, crampfish, numbfish, torpedo
7 | 6: stingray
8 | 7: cock
9 | 8: hen
10 | 9: ostrich, Struthio camelus
11 | 10: brambling, Fringilla montifringilla
12 | 11: goldfinch, Carduelis carduelis
13 | 12: house finch, linnet, Carpodacus mexicanus
14 | 13: junco, snowbird
15 | 14: indigo bunting, indigo finch, indigo bird, Passerina cyanea
16 | 15: robin, American robin, Turdus migratorius
17 | 16: bulbul
18 | 17: jay
19 | 18: magpie
20 | 19: chickadee
21 | 20: water ouzel, dipper
22 | 21: kite
23 | 22: bald eagle, American eagle, Haliaeetus leucocephalus
24 | 23: vulture
25 | 24: great grey owl, great gray owl, Strix nebulosa
26 | 25: European fire salamander, Salamandra salamandra
27 | 26: common newt, Triturus vulgaris
28 | 27: eft
29 | 28: spotted salamander, Ambystoma maculatum
30 | 29: axolotl, mud puppy, Ambystoma mexicanum
31 | 30: bullfrog, Rana catesbeiana
32 | 31: tree frog, tree-frog
33 | 32: tailed frog, bell toad, ribbed toad, tailed toad, Ascaphus trui
34 | 33: loggerhead, loggerhead turtle, Caretta caretta
35 | 34: leatherback turtle, leatherback, leathery turtle, Dermochelys coriacea
36 | 35: mud turtle
37 | 36: terrapin
38 | 37: box turtle, box tortoise
39 | 38: banded gecko
40 | 39: common iguana, iguana, Iguana iguana
41 | 40: American chameleon, anole, Anolis carolinensis
42 | 41: whiptail, whiptail lizard
43 | 42: agama
44 | 43: frilled lizard, Chlamydosaurus kingi
45 | 44: alligator lizard
46 | 45: Gila monster, Heloderma suspectum
47 | 46: green lizard, Lacerta viridis
48 | 47: African chameleon, Chamaeleo chamaeleon
49 | 48: Komodo dragon, Komodo lizard, dragon lizard, giant lizard, Varanus komodoensis
50 | 49: African crocodile, Nile crocodile, Crocodylus niloticus
51 | 50: American alligator, Alligator mississipiensis
52 | 51: triceratops
53 | 52: thunder snake, worm snake, Carphophis amoenus
54 | 53: ringneck snake, ring-necked snake, ring snake
55 | 54: hognose snake, puff adder, sand viper
56 | 55: green snake, grass snake
57 | 56: king snake, kingsnake
58 | 57: garter snake, grass snake
59 | 58: water snake
60 | 59: vine snake
61 | 60: night snake, Hypsiglena torquata
62 | 61: boa constrictor, Constrictor constrictor
63 | 62: rock python, rock snake, Python sebae
64 | 63: Indian cobra, Naja naja
65 | 64: green mamba
66 | 65: sea snake
67 | 66: horned viper, cerastes, sand viper, horned asp, Cerastes cornutus
68 | 67: diamondback, diamondback rattlesnake, Crotalus adamanteus
69 | 68: sidewinder, horned rattlesnake, Crotalus cerastes
70 | 69: trilobite
71 | 70: harvestman, daddy longlegs, Phalangium opilio
72 | 71: scorpion
73 | 72: black and gold garden spider, Argiope aurantia
74 | 73: barn spider, Araneus cavaticus
75 | 74: garden spider, Aranea diademata
76 | 75: black widow, Latrodectus mactans
77 | 76: tarantula
78 | 77: wolf spider, hunting spider
79 | 78: tick
80 | 79: centipede
81 | 80: black grouse
82 | 81: ptarmigan
83 | 82: ruffed grouse, partridge, Bonasa umbellus
84 | 83: prairie chicken, prairie grouse, prairie fowl
85 | 84: peacock
86 | 85: quail
87 | 86: partridge
88 | 87: African grey, African gray, Psittacus erithacus
89 | 88: macaw
90 | 89: sulphur-crested cockatoo, Kakatoe galerita, Cacatua galerita
91 | 90: lorikeet
92 | 91: coucal
93 | 92: bee eater
94 | 93: hornbill
95 | 94: hummingbird
96 | 95: jacamar
97 | 96: toucan
98 | 97: drake
99 | 98: red-breasted merganser, Mergus serrator
100 | 99: goose
101 | 100: black swan, Cygnus atratus
102 | 101: tusker
103 | 102: echidna, spiny anteater, anteater
104 | 103: platypus, duckbill, duckbilled platypus, duck-billed platypus, Ornithorhynchus anatinus
105 | 104: wallaby, brush kangaroo
106 | 105: koala, koala bear, kangaroo bear, native bear, Phascolarctos cinereus
107 | 106: wombat
108 | 107: jellyfish
109 | 108: sea anemone, anemone
110 | 109: brain coral
111 | 110: flatworm, platyhelminth
112 | 111: nematode, nematode worm, roundworm
113 | 112: conch
114 | 113: snail
115 | 114: slug
116 | 115: sea slug, nudibranch
117 | 116: chiton, coat-of-mail shell, sea cradle, polyplacophore
118 | 117: chambered nautilus, pearly nautilus, nautilus
119 | 118: Dungeness crab, Cancer magister
120 | 119: rock crab, Cancer irroratus
121 | 120: fiddler crab
122 | 121: king crab, Alaska crab, Alaskan king crab, Alaska king crab, Paralithodes camtschatica
123 | 122: American lobster, Northern lobster, Maine lobster, Homarus americanus
124 | 123: spiny lobster, langouste, rock lobster, crawfish, crayfish, sea crawfish
125 | 124: crayfish, crawfish, crawdad, crawdaddy
126 | 125: hermit crab
127 | 126: isopod
128 | 127: white stork, Ciconia ciconia
129 | 128: black stork, Ciconia nigra
130 | 129: spoonbill
131 | 130: flamingo
132 | 131: little blue heron, Egretta caerulea
133 | 132: American egret, great white heron, Egretta albus
134 | 133: bittern
135 | 134: crane
136 | 135: limpkin, Aramus pictus
137 | 136: European gallinule, Porphyrio porphyrio
138 | 137: American coot, marsh hen, mud hen, water hen, Fulica americana
139 | 138: bustard
140 | 139: ruddy turnstone, Arenaria interpres
141 | 140: red-backed sandpiper, dunlin, Erolia alpina
142 | 141: redshank, Tringa totanus
143 | 142: dowitcher
144 | 143: oystercatcher, oyster catcher
145 | 144: pelican
146 | 145: king penguin, Aptenodytes patagonica
147 | 146: albatross, mollymawk
148 | 147: grey whale, gray whale, devilfish, Eschrichtius gibbosus, Eschrichtius robustus
149 | 148: killer whale, killer, orca, grampus, sea wolf, Orcinus orca
150 | 149: dugong, Dugong dugon
151 | 150: sea lion
152 | 151: Chihuahua
153 | 152: Japanese spaniel
154 | 153: Maltese dog, Maltese terrier, Maltese
155 | 154: Pekinese, Pekingese, Peke
156 | 155: Shih-Tzu
157 | 156: Blenheim spaniel
158 | 157: papillon
159 | 158: toy terrier
160 | 159: Rhodesian ridgeback
161 | 160: Afghan hound, Afghan
162 | 161: basset, basset hound
163 | 162: beagle
164 | 163: bloodhound, sleuthhound
165 | 164: bluetick
166 | 165: black-and-tan coonhound
167 | 166: Walker hound, Walker foxhound
168 | 167: English foxhound
169 | 168: redbone
170 | 169: borzoi, Russian wolfhound
171 | 170: Irish wolfhound
172 | 171: Italian greyhound
173 | 172: whippet
174 | 173: Ibizan hound, Ibizan Podenco
175 | 174: Norwegian elkhound, elkhound
176 | 175: otterhound, otter hound
177 | 176: Saluki, gazelle hound
178 | 177: Scottish deerhound, deerhound
179 | 178: Weimaraner
180 | 179: Staffordshire bullterrier, Staffordshire bull terrier
181 | 180: American Staffordshire terrier, Staffordshire terrier, American pit bull terrier, pit bull terrier
182 | 181: Bedlington terrier
183 | 182: Border terrier
184 | 183: Kerry blue terrier
185 | 184: Irish terrier
186 | 185: Norfolk terrier
187 | 186: Norwich terrier
188 | 187: Yorkshire terrier
189 | 188: wire-haired fox terrier
190 | 189: Lakeland terrier
191 | 190: Sealyham terrier, Sealyham
192 | 191: Airedale, Airedale terrier
193 | 192: cairn, cairn terrier
194 | 193: Australian terrier
195 | 194: Dandie Dinmont, Dandie Dinmont terrier
196 | 195: Boston bull, Boston terrier
197 | 196: miniature schnauzer
198 | 197: giant schnauzer
199 | 198: standard schnauzer
200 | 199: Scotch terrier, Scottish terrier, Scottie
201 | 200: Tibetan terrier, chrysanthemum dog
202 | 201: silky terrier, Sydney silky
203 | 202: soft-coated wheaten terrier
204 | 203: West Highland white terrier
205 | 204: Lhasa, Lhasa apso
206 | 205: flat-coated retriever
207 | 206: curly-coated retriever
208 | 207: golden retriever
209 | 208: Labrador retriever
210 | 209: Chesapeake Bay retriever
211 | 210: German short-haired pointer
212 | 211: vizsla, Hungarian pointer
213 | 212: English setter
214 | 213: Irish setter, red setter
215 | 214: Gordon setter
216 | 215: Brittany spaniel
217 | 216: clumber, clumber spaniel
218 | 217: English springer, English springer spaniel
219 | 218: Welsh springer spaniel
220 | 219: cocker spaniel, English cocker spaniel, cocker
221 | 220: Sussex spaniel
222 | 221: Irish water spaniel
223 | 222: kuvasz
224 | 223: schipperke
225 | 224: groenendael
226 | 225: malinois
227 | 226: briard
228 | 227: kelpie
229 | 228: komondor
230 | 229: Old English sheepdog, bobtail
231 | 230: Shetland sheepdog, Shetland sheep dog, Shetland
232 | 231: collie
233 | 232: Border collie
234 | 233: Bouvier des Flandres, Bouviers des Flandres
235 | 234: Rottweiler
236 | 235: German shepherd, German shepherd dog, German police dog, alsatian
237 | 236: Doberman, Doberman pinscher
238 | 237: miniature pinscher
239 | 238: Greater Swiss Mountain dog
240 | 239: Bernese mountain dog
241 | 240: Appenzeller
242 | 241: EntleBucher
243 | 242: boxer
244 | 243: bull mastiff
245 | 244: Tibetan mastiff
246 | 245: French bulldog
247 | 246: Great Dane
248 | 247: Saint Bernard, St Bernard
249 | 248: Eskimo dog, husky
250 | 249: malamute, malemute, Alaskan malamute
251 | 250: Siberian husky
252 | 251: dalmatian, coach dog, carriage dog
253 | 252: affenpinscher, monkey pinscher, monkey dog
254 | 253: basenji
255 | 254: pug, pug-dog
256 | 255: Leonberg
257 | 256: Newfoundland, Newfoundland dog
258 | 257: Great Pyrenees
259 | 258: Samoyed, Samoyede
260 | 259: Pomeranian
261 | 260: chow, chow chow
262 | 261: keeshond
263 | 262: Brabancon griffon
264 | 263: Pembroke, Pembroke Welsh corgi
265 | 264: Cardigan, Cardigan Welsh corgi
266 | 265: toy poodle
267 | 266: miniature poodle
268 | 267: standard poodle
269 | 268: Mexican hairless
270 | 269: timber wolf, grey wolf, gray wolf, Canis lupus
271 | 270: white wolf, Arctic wolf, Canis lupus tundrarum
272 | 271: red wolf, maned wolf, Canis rufus, Canis niger
273 | 272: coyote, prairie wolf, brush wolf, Canis latrans
274 | 273: dingo, warrigal, warragal, Canis dingo
275 | 274: dhole, Cuon alpinus
276 | 275: African hunting dog, hyena dog, Cape hunting dog, Lycaon pictus
277 | 276: hyena, hyaena
278 | 277: red fox, Vulpes vulpes
279 | 278: kit fox, Vulpes macrotis
280 | 279: Arctic fox, white fox, Alopex lagopus
281 | 280: grey fox, gray fox, Urocyon cinereoargenteus
282 | 281: tabby, tabby cat
283 | 282: tiger cat
284 | 283: Persian cat
285 | 284: Siamese cat, Siamese
286 | 285: Egyptian cat
287 | 286: cougar, puma, catamount, mountain lion, painter, panther, Felis concolor
288 | 287: lynx, catamount
289 | 288: leopard, Panthera pardus
290 | 289: snow leopard, ounce, Panthera uncia
291 | 290: jaguar, panther, Panthera onca, Felis onca
292 | 291: lion, king of beasts, Panthera leo
293 | 292: tiger, Panthera tigris
294 | 293: cheetah, chetah, Acinonyx jubatus
295 | 294: brown bear, bruin, Ursus arctos
296 | 295: American black bear, black bear, Ursus americanus, Euarctos americanus
297 | 296: ice bear, polar bear, Ursus Maritimus, Thalarctos maritimus
298 | 297: sloth bear, Melursus ursinus, Ursus ursinus
299 | 298: mongoose
300 | 299: meerkat, mierkat
301 | 300: tiger beetle
302 | 301: ladybug, ladybeetle, lady beetle, ladybird, ladybird beetle
303 | 302: ground beetle, carabid beetle
304 | 303: long-horned beetle, longicorn, longicorn beetle
305 | 304: leaf beetle, chrysomelid
306 | 305: dung beetle
307 | 306: rhinoceros beetle
308 | 307: weevil
309 | 308: fly
310 | 309: bee
311 | 310: ant, emmet, pismire
312 | 311: grasshopper, hopper
313 | 312: cricket
314 | 313: walking stick, walkingstick, stick insect
315 | 314: cockroach, roach
316 | 315: mantis, mantid
317 | 316: cicada, cicala
318 | 317: leafhopper
319 | 318: lacewing, lacewing fly
320 | 319: dragonfly, darning needle, devil's darning needle, sewing needle, snake feeder, snake doctor, mosquito hawk, skeeter hawk
321 | 320: damselfly
322 | 321: admiral
323 | 322: ringlet, ringlet butterfly
324 | 323: monarch, monarch butterfly, milkweed butterfly, Danaus plexippus
325 | 324: cabbage butterfly
326 | 325: sulphur butterfly, sulfur butterfly
327 | 326: lycaenid, lycaenid butterfly
328 | 327: starfish, sea star
329 | 328: sea urchin
330 | 329: sea cucumber, holothurian
331 | 330: wood rabbit, cottontail, cottontail rabbit
332 | 331: hare
333 | 332: Angora, Angora rabbit
334 | 333: hamster
335 | 334: porcupine, hedgehog
336 | 335: fox squirrel, eastern fox squirrel, Sciurus niger
337 | 336: marmot
338 | 337: beaver
339 | 338: guinea pig, Cavia cobaya
340 | 339: sorrel
341 | 340: zebra
342 | 341: hog, pig, grunter, squealer, Sus scrofa
343 | 342: wild boar, boar, Sus scrofa
344 | 343: warthog
345 | 344: hippopotamus, hippo, river horse, Hippopotamus amphibius
346 | 345: ox
347 | 346: water buffalo, water ox, Asiatic buffalo, Bubalus bubalis
348 | 347: bison
349 | 348: ram, tup
350 | 349: bighorn, bighorn sheep, cimarron, Rocky Mountain bighorn, Rocky Mountain sheep, Ovis canadensis
351 | 350: ibex, Capra ibex
352 | 351: hartebeest
353 | 352: impala, Aepyceros melampus
354 | 353: gazelle
355 | 354: Arabian camel, dromedary, Camelus dromedarius
356 | 355: llama
357 | 356: weasel
358 | 357: mink
359 | 358: polecat, fitch, foulmart, foumart, Mustela putorius
360 | 359: black-footed ferret, ferret, Mustela nigripes
361 | 360: otter
362 | 361: skunk, polecat, wood pussy
363 | 362: badger
364 | 363: armadillo
365 | 364: three-toed sloth, ai, Bradypus tridactylus
366 | 365: orangutan, orang, orangutang, Pongo pygmaeus
367 | 366: gorilla, Gorilla gorilla
368 | 367: chimpanzee, chimp, Pan troglodytes
369 | 368: gibbon, Hylobates lar
370 | 369: siamang, Hylobates syndactylus, Symphalangus syndactylus
371 | 370: guenon, guenon monkey
372 | 371: patas, hussar monkey, Erythrocebus patas
373 | 372: baboon
374 | 373: macaque
375 | 374: langur
376 | 375: colobus, colobus monkey
377 | 376: proboscis monkey, Nasalis larvatus
378 | 377: marmoset
379 | 378: capuchin, ringtail, Cebus capucinus
380 | 379: howler monkey, howler
381 | 380: titi, titi monkey
382 | 381: spider monkey, Ateles geoffroyi
383 | 382: squirrel monkey, Saimiri sciureus
384 | 383: Madagascar cat, ring-tailed lemur, Lemur catta
385 | 384: indri, indris, Indri indri, Indri brevicaudatus
386 | 385: Indian elephant, Elephas maximus
387 | 386: African elephant, Loxodonta africana
388 | 387: lesser panda, red panda, panda, bear cat, cat bear, Ailurus fulgens
389 | 388: giant panda, panda, panda bear, coon bear, Ailuropoda melanoleuca
390 | 389: barracouta, snoek
391 | 390: eel
392 | 391: coho, cohoe, coho salmon, blue jack, silver salmon, Oncorhynchus kisutch
393 | 392: rock beauty, Holocanthus tricolor
394 | 393: anemone fish
395 | 394: sturgeon
396 | 395: gar, garfish, garpike, billfish, Lepisosteus osseus
397 | 396: lionfish
398 | 397: puffer, pufferfish, blowfish, globefish
399 | 398: abacus
400 | 399: abaya
401 | 400: academic gown, academic robe, judge's robe
402 | 401: accordion, piano accordion, squeeze box
403 | 402: acoustic guitar
404 | 403: aircraft carrier, carrier, flattop, attack aircraft carrier
405 | 404: airliner
406 | 405: airship, dirigible
407 | 406: altar
408 | 407: ambulance
409 | 408: amphibian, amphibious vehicle
410 | 409: analog clock
411 | 410: apiary, bee house
412 | 411: apron
413 | 412: ashcan, trash can, garbage can, wastebin, ash bin, ash-bin, ashbin, dustbin, trash barrel, trash bin
414 | 413: assault rifle, assault gun
415 | 414: backpack, back pack, knapsack, packsack, rucksack, haversack
416 | 415: bakery, bakeshop, bakehouse
417 | 416: balance beam, beam
418 | 417: balloon
419 | 418: ballpoint, ballpoint pen, ballpen, Biro
420 | 419: Band Aid
421 | 420: banjo
422 | 421: bannister, banister, balustrade, balusters, handrail
423 | 422: barbell
424 | 423: barber chair
425 | 424: barbershop
426 | 425: barn
427 | 426: barometer
428 | 427: barrel, cask
429 | 428: barrow, garden cart, lawn cart, wheelbarrow
430 | 429: baseball
431 | 430: basketball
432 | 431: bassinet
433 | 432: bassoon
434 | 433: bathing cap, swimming cap
435 | 434: bath towel
436 | 435: bathtub, bathing tub, bath, tub
437 | 436: beach wagon, station wagon, wagon, estate car, beach waggon, station waggon, waggon
438 | 437: beacon, lighthouse, beacon light, pharos
439 | 438: beaker
440 | 439: bearskin, busby, shako
441 | 440: beer bottle
442 | 441: beer glass
443 | 442: bell cote, bell cot
444 | 443: bib
445 | 444: bicycle-built-for-two, tandem bicycle, tandem
446 | 445: bikini, two-piece
447 | 446: binder, ring-binder
448 | 447: binoculars, field glasses, opera glasses
449 | 448: birdhouse
450 | 449: boathouse
451 | 450: bobsled, bobsleigh, bob
452 | 451: bolo tie, bolo, bola tie, bola
453 | 452: bonnet, poke bonnet
454 | 453: bookcase
455 | 454: bookshop, bookstore, bookstall
456 | 455: bottlecap
457 | 456: bow
458 | 457: bow tie, bow-tie, bowtie
459 | 458: brass, memorial tablet, plaque
460 | 459: brassiere, bra, bandeau
461 | 460: breakwater, groin, groyne, mole, bulwark, seawall, jetty
462 | 461: breastplate, aegis, egis
463 | 462: broom
464 | 463: bucket, pail
465 | 464: buckle
466 | 465: bulletproof vest
467 | 466: bullet train, bullet
468 | 467: butcher shop, meat market
469 | 468: cab, hack, taxi, taxicab
470 | 469: caldron, cauldron
471 | 470: candle, taper, wax light
472 | 471: cannon
473 | 472: canoe
474 | 473: can opener, tin opener
475 | 474: cardigan
476 | 475: car mirror
477 | 476: carousel, carrousel, merry-go-round, roundabout, whirligig
478 | 477: carpenter's kit, tool kit
479 | 478: carton
480 | 479: car wheel
481 | 480: cash machine, cash dispenser, automated teller machine, automatic teller machine, automated teller, automatic teller, ATM
482 | 481: cassette
483 | 482: cassette player
484 | 483: castle
485 | 484: catamaran
486 | 485: CD player
487 | 486: cello, violoncello
488 | 487: cellular telephone, cellular phone, cellphone, cell, mobile phone
489 | 488: chain
490 | 489: chainlink fence
491 | 490: chain mail, ring mail, mail, chain armor, chain armour, ring armor, ring armour
492 | 491: chain saw, chainsaw
493 | 492: chest
494 | 493: chiffonier, commode
495 | 494: chime, bell, gong
496 | 495: china cabinet, china closet
497 | 496: Christmas stocking
498 | 497: church, church building
499 | 498: cinema, movie theater, movie theatre, movie house, picture palace
500 | 499: cleaver, meat cleaver, chopper
501 | 500: cliff dwelling
502 | 501: cloak
503 | 502: clog, geta, patten, sabot
504 | 503: cocktail shaker
505 | 504: coffee mug
506 | 505: coffeepot
507 | 506: coil, spiral, volute, whorl, helix
508 | 507: combination lock
509 | 508: computer keyboard, keypad
510 | 509: confectionery, confectionary, candy store
511 | 510: container ship, containership, container vessel
512 | 511: convertible
513 | 512: corkscrew, bottle screw
514 | 513: cornet, horn, trumpet, trump
515 | 514: cowboy boot
516 | 515: cowboy hat, ten-gallon hat
517 | 516: cradle
518 | 517: crane
519 | 518: crash helmet
520 | 519: crate
521 | 520: crib, cot
522 | 521: Crock Pot
523 | 522: croquet ball
524 | 523: crutch
525 | 524: cuirass
526 | 525: dam, dike, dyke
527 | 526: desk
528 | 527: desktop computer
529 | 528: dial telephone, dial phone
530 | 529: diaper, nappy, napkin
531 | 530: digital clock
532 | 531: digital watch
533 | 532: dining table, board
534 | 533: dishrag, dishcloth
535 | 534: dishwasher, dish washer, dishwashing machine
536 | 535: disk brake, disc brake
537 | 536: dock, dockage, docking facility
538 | 537: dogsled, dog sled, dog sleigh
539 | 538: dome
540 | 539: doormat, welcome mat
541 | 540: drilling platform, offshore rig
542 | 541: drum, membranophone, tympan
543 | 542: drumstick
544 | 543: dumbbell
545 | 544: Dutch oven
546 | 545: electric fan, blower
547 | 546: electric guitar
548 | 547: electric locomotive
549 | 548: entertainment center
550 | 549: envelope
551 | 550: espresso maker
552 | 551: face powder
553 | 552: feather boa, boa
554 | 553: file, file cabinet, filing cabinet
555 | 554: fireboat
556 | 555: fire engine, fire truck
557 | 556: fire screen, fireguard
558 | 557: flagpole, flagstaff
559 | 558: flute, transverse flute
560 | 559: folding chair
561 | 560: football helmet
562 | 561: forklift
563 | 562: fountain
564 | 563: fountain pen
565 | 564: four-poster
566 | 565: freight car
567 | 566: French horn, horn
568 | 567: frying pan, frypan, skillet
569 | 568: fur coat
570 | 569: garbage truck, dustcart
571 | 570: gasmask, respirator, gas helmet
572 | 571: gas pump, gasoline pump, petrol pump, island dispenser
573 | 572: goblet
574 | 573: go-kart
575 | 574: golf ball
576 | 575: golfcart, golf cart
577 | 576: gondola
578 | 577: gong, tam-tam
579 | 578: gown
580 | 579: grand piano, grand
581 | 580: greenhouse, nursery, glasshouse
582 | 581: grille, radiator grille
583 | 582: grocery store, grocery, food market, market
584 | 583: guillotine
585 | 584: hair slide
586 | 585: hair spray
587 | 586: half track
588 | 587: hammer
589 | 588: hamper
590 | 589: hand blower, blow dryer, blow drier, hair dryer, hair drier
591 | 590: hand-held computer, hand-held microcomputer
592 | 591: handkerchief, hankie, hanky, hankey
593 | 592: hard disc, hard disk, fixed disk
594 | 593: harmonica, mouth organ, harp, mouth harp
595 | 594: harp
596 | 595: harvester, reaper
597 | 596: hatchet
598 | 597: holster
599 | 598: home theater, home theatre
600 | 599: honeycomb
601 | 600: hook, claw
602 | 601: hoopskirt, crinoline
603 | 602: horizontal bar, high bar
604 | 603: horse cart, horse-cart
605 | 604: hourglass
606 | 605: iPod
607 | 606: iron, smoothing iron
608 | 607: jack-o'-lantern
609 | 608: jean, blue jean, denim
610 | 609: jeep, landrover
611 | 610: jersey, T-shirt, tee shirt
612 | 611: jigsaw puzzle
613 | 612: jinrikisha, ricksha, rickshaw
614 | 613: joystick
615 | 614: kimono
616 | 615: knee pad
617 | 616: knot
618 | 617: lab coat, laboratory coat
619 | 618: ladle
620 | 619: lampshade, lamp shade
621 | 620: laptop, laptop computer
622 | 621: lawn mower, mower
623 | 622: lens cap, lens cover
624 | 623: letter opener, paper knife, paperknife
625 | 624: library
626 | 625: lifeboat
627 | 626: lighter, light, igniter, ignitor
628 | 627: limousine, limo
629 | 628: liner, ocean liner
630 | 629: lipstick, lip rouge
631 | 630: Loafer
632 | 631: lotion
633 | 632: loudspeaker, speaker, speaker unit, loudspeaker system, speaker system
634 | 633: loupe, jeweler's loupe
635 | 634: lumbermill, sawmill
636 | 635: magnetic compass
637 | 636: mailbag, postbag
638 | 637: mailbox, letter box
639 | 638: maillot
640 | 639: maillot, tank suit
641 | 640: manhole cover
642 | 641: maraca
643 | 642: marimba, xylophone
644 | 643: mask
645 | 644: matchstick
646 | 645: maypole
647 | 646: maze, labyrinth
648 | 647: measuring cup
649 | 648: medicine chest, medicine cabinet
650 | 649: megalith, megalithic structure
651 | 650: microphone, mike
652 | 651: microwave, microwave oven
653 | 652: military uniform
654 | 653: milk can
655 | 654: minibus
656 | 655: miniskirt, mini
657 | 656: minivan
658 | 657: missile
659 | 658: mitten
660 | 659: mixing bowl
661 | 660: mobile home, manufactured home
662 | 661: Model T
663 | 662: modem
664 | 663: monastery
665 | 664: monitor
666 | 665: moped
667 | 666: mortar
668 | 667: mortarboard
669 | 668: mosque
670 | 669: mosquito net
671 | 670: motor scooter, scooter
672 | 671: mountain bike, all-terrain bike, off-roader
673 | 672: mountain tent
674 | 673: mouse, computer mouse
675 | 674: mousetrap
676 | 675: moving van
677 | 676: muzzle
678 | 677: nail
679 | 678: neck brace
680 | 679: necklace
681 | 680: nipple
682 | 681: notebook, notebook computer
683 | 682: obelisk
684 | 683: oboe, hautboy, hautbois
685 | 684: ocarina, sweet potato
686 | 685: odometer, hodometer, mileometer, milometer
687 | 686: oil filter
688 | 687: organ, pipe organ
689 | 688: oscilloscope, scope, cathode-ray oscilloscope, CRO
690 | 689: overskirt
691 | 690: oxcart
692 | 691: oxygen mask
693 | 692: packet
694 | 693: paddle, boat paddle
695 | 694: paddlewheel, paddle wheel
696 | 695: padlock
697 | 696: paintbrush
698 | 697: pajama, pyjama, pj's, jammies
699 | 698: palace
700 | 699: panpipe, pandean pipe, syrinx
701 | 700: paper towel
702 | 701: parachute, chute
703 | 702: parallel bars, bars
704 | 703: park bench
705 | 704: parking meter
706 | 705: passenger car, coach, carriage
707 | 706: patio, terrace
708 | 707: pay-phone, pay-station
709 | 708: pedestal, plinth, footstall
710 | 709: pencil box, pencil case
711 | 710: pencil sharpener
712 | 711: perfume, essence
713 | 712: Petri dish
714 | 713: photocopier
715 | 714: pick, plectrum, plectron
716 | 715: pickelhaube
717 | 716: picket fence, paling
718 | 717: pickup, pickup truck
719 | 718: pier
720 | 719: piggy bank, penny bank
721 | 720: pill bottle
722 | 721: pillow
723 | 722: ping-pong ball
724 | 723: pinwheel
725 | 724: pirate, pirate ship
726 | 725: pitcher, ewer
727 | 726: plane, carpenter's plane, woodworking plane
728 | 727: planetarium
729 | 728: plastic bag
730 | 729: plate rack
731 | 730: plow, plough
732 | 731: plunger, plumber's helper
733 | 732: Polaroid camera, Polaroid Land camera
734 | 733: pole
735 | 734: police van, police wagon, paddy wagon, patrol wagon, wagon, black Maria
736 | 735: poncho
737 | 736: pool table, billiard table, snooker table
738 | 737: pop bottle, soda bottle
739 | 738: pot, flowerpot
740 | 739: potter's wheel
741 | 740: power drill
742 | 741: prayer rug, prayer mat
743 | 742: printer
744 | 743: prison, prison house
745 | 744: projectile, missile
746 | 745: projector
747 | 746: puck, hockey puck
748 | 747: punching bag, punch bag, punching ball, punchball
749 | 748: purse
750 | 749: quill, quill pen
751 | 750: quilt, comforter, comfort, puff
752 | 751: racer, race car, racing car
753 | 752: racket, racquet
754 | 753: radiator
755 | 754: radio, wireless
756 | 755: radio telescope, radio reflector
757 | 756: rain barrel
758 | 757: recreational vehicle, RV, R.V.
759 | 758: reel
760 | 759: reflex camera
761 | 760: refrigerator, icebox
762 | 761: remote control, remote
763 | 762: restaurant, eating house, eating place, eatery
764 | 763: revolver, six-gun, six-shooter
765 | 764: rifle
766 | 765: rocking chair, rocker
767 | 766: rotisserie
768 | 767: rubber eraser, rubber, pencil eraser
769 | 768: rugby ball
770 | 769: rule, ruler
771 | 770: running shoe
772 | 771: safe
773 | 772: safety pin
774 | 773: saltshaker, salt shaker
775 | 774: sandal
776 | 775: sarong
777 | 776: sax, saxophone
778 | 777: scabbard
779 | 778: scale, weighing machine
780 | 779: school bus
781 | 780: schooner
782 | 781: scoreboard
783 | 782: screen, CRT screen
784 | 783: screw
785 | 784: screwdriver
786 | 785: seat belt, seatbelt
787 | 786: sewing machine
788 | 787: shield, buckler
789 | 788: shoe shop, shoe-shop, shoe store
790 | 789: shoji
791 | 790: shopping basket
792 | 791: shopping cart
793 | 792: shovel
794 | 793: shower cap
795 | 794: shower curtain
796 | 795: ski
797 | 796: ski mask
798 | 797: sleeping bag
799 | 798: slide rule, slipstick
800 | 799: sliding door
801 | 800: slot, one-armed bandit
802 | 801: snorkel
803 | 802: snowmobile
804 | 803: snowplow, snowplough
805 | 804: soap dispenser
806 | 805: soccer ball
807 | 806: sock
808 | 807: solar dish, solar collector, solar furnace
809 | 808: sombrero
810 | 809: soup bowl
811 | 810: space bar
812 | 811: space heater
813 | 812: space shuttle
814 | 813: spatula
815 | 814: speedboat
816 | 815: spider web, spider's web
817 | 816: spindle
818 | 817: sports car, sport car
819 | 818: spotlight, spot
820 | 819: stage
821 | 820: steam locomotive
822 | 821: steel arch bridge
823 | 822: steel drum
824 | 823: stethoscope
825 | 824: stole
826 | 825: stone wall
827 | 826: stopwatch, stop watch
828 | 827: stove
829 | 828: strainer
830 | 829: streetcar, tram, tramcar, trolley, trolley car
831 | 830: stretcher
832 | 831: studio couch, day bed
833 | 832: stupa, tope
834 | 833: submarine, pigboat, sub, U-boat
835 | 834: suit, suit of clothes
836 | 835: sundial
837 | 836: sunglass
838 | 837: sunglasses, dark glasses, shades
839 | 838: sunscreen, sunblock, sun blocker
840 | 839: suspension bridge
841 | 840: swab, swob, mop
842 | 841: sweatshirt
843 | 842: swimming trunks, bathing trunks
844 | 843: swing
845 | 844: switch, electric switch, electrical switch
846 | 845: syringe
847 | 846: table lamp
848 | 847: tank, army tank, armored combat vehicle, armoured combat vehicle
849 | 848: tape player
850 | 849: teapot
851 | 850: teddy, teddy bear
852 | 851: television, television system
853 | 852: tennis ball
854 | 853: thatch, thatched roof
855 | 854: theater curtain, theatre curtain
856 | 855: thimble
857 | 856: thresher, thrasher, threshing machine
858 | 857: throne
859 | 858: tile roof
860 | 859: toaster
861 | 860: tobacco shop, tobacconist shop, tobacconist
862 | 861: toilet seat
863 | 862: torch
864 | 863: totem pole
865 | 864: tow truck, tow car, wrecker
866 | 865: toyshop
867 | 866: tractor
868 | 867: trailer truck, tractor trailer, trucking rig, rig, articulated lorry, semi
869 | 868: tray
870 | 869: trench coat
871 | 870: tricycle, trike, velocipede
872 | 871: trimaran
873 | 872: tripod
874 | 873: triumphal arch
875 | 874: trolleybus, trolley coach, trackless trolley
876 | 875: trombone
877 | 876: tub, vat
878 | 877: turnstile
879 | 878: typewriter keyboard
880 | 879: umbrella
881 | 880: unicycle, monocycle
882 | 881: upright, upright piano
883 | 882: vacuum, vacuum cleaner
884 | 883: vase
885 | 884: vault
886 | 885: velvet
887 | 886: vending machine
888 | 887: vestment
889 | 888: viaduct
890 | 889: violin, fiddle
891 | 890: volleyball
892 | 891: waffle iron
893 | 892: wall clock
894 | 893: wallet, billfold, notecase, pocketbook
895 | 894: wardrobe, closet, press
896 | 895: warplane, military plane
897 | 896: washbasin, handbasin, washbowl, lavabo, wash-hand basin
898 | 897: washer, automatic washer, washing machine
899 | 898: water bottle
900 | 899: water jug
901 | 900: water tower
902 | 901: whiskey jug
903 | 902: whistle
904 | 903: wig
905 | 904: window screen
906 | 905: window shade
907 | 906: Windsor tie
908 | 907: wine bottle
909 | 908: wing
910 | 909: wok
911 | 910: wooden spoon
912 | 911: wool, woolen, woollen
913 | 912: worm fence, snake fence, snake-rail fence, Virginia fence
914 | 913: wreck
915 | 914: yawl
916 | 915: yurt
917 | 916: web site, website, internet site, site
918 | 917: comic book
919 | 918: crossword puzzle, crossword
920 | 919: street sign
921 | 920: traffic light, traffic signal, stoplight
922 | 921: book jacket, dust cover, dust jacket, dust wrapper
923 | 922: menu
924 | 923: plate
925 | 924: guacamole
926 | 925: consomme
927 | 926: hot pot, hotpot
928 | 927: trifle
929 | 928: ice cream, icecream
930 | 929: ice lolly, lolly, lollipop, popsicle
931 | 930: French loaf
932 | 931: bagel, beigel
933 | 932: pretzel
934 | 933: cheeseburger
935 | 934: hotdog, hot dog, red hot
936 | 935: mashed potato
937 | 936: head cabbage
938 | 937: broccoli
939 | 938: cauliflower
940 | 939: zucchini, courgette
941 | 940: spaghetti squash
942 | 941: acorn squash
943 | 942: butternut squash
944 | 943: cucumber, cuke
945 | 944: artichoke, globe artichoke
946 | 945: bell pepper
947 | 946: cardoon
948 | 947: mushroom
949 | 948: Granny Smith
950 | 949: strawberry
951 | 950: orange
952 | 951: lemon
953 | 952: fig
954 | 953: pineapple, ananas
955 | 954: banana
956 | 955: jackfruit, jak, jack
957 | 956: custard apple
958 | 957: pomegranate
959 | 958: hay
960 | 959: carbonara
961 | 960: chocolate sauce, chocolate syrup
962 | 961: dough
963 | 962: meat loaf, meatloaf
964 | 963: pizza, pizza pie
965 | 964: potpie
966 | 965: burrito
967 | 966: red wine
968 | 967: espresso
969 | 968: cup
970 | 969: eggnog
971 | 970: alp
972 | 971: bubble
973 | 972: cliff, drop, drop-off
974 | 973: coral reef
975 | 974: geyser
976 | 975: lakeside, lakeshore
977 | 976: promontory, headland, head, foreland
978 | 977: sandbar, sand bar
979 | 978: seashore, coast, seacoast, sea-coast
980 | 979: valley, vale
981 | 980: volcano
982 | 981: ballplayer, baseball player
983 | 982: groom, bridegroom
984 | 983: scuba diver
985 | 984: rapeseed
986 | 985: daisy
987 | 986: yellow lady's slipper, yellow lady-slipper, Cypripedium calceolus, Cypripedium parviflorum
988 | 987: corn
989 | 988: acorn
990 | 989: hip, rose hip, rosehip
991 | 990: buckeye, horse chestnut, conker
992 | 991: coral fungus
993 | 992: agaric
994 | 993: gyromitra
995 | 994: stinkhorn, carrion fungus
996 | 995: earthstar
997 | 996: hen-of-the-woods, hen of the woods, Polyporus frondosus, Grifola frondosa
998 | 997: bolete
999 | 998: ear, spike, capitulum
1000 | 999: toilet tissue, toilet paper, bathroom tissue
1001 |
--------------------------------------------------------------------------------
/images/cat.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hiveml/tensorflow-grad-cam/9f4c9b7a5f9c94a0490b143282abc965ce1cfb0f/images/cat.jpg
--------------------------------------------------------------------------------
/images/cat_heatmap.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hiveml/tensorflow-grad-cam/9f4c9b7a5f9c94a0490b143282abc965ce1cfb0f/images/cat_heatmap.png
--------------------------------------------------------------------------------
/images/scarjo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hiveml/tensorflow-grad-cam/9f4c9b7a5f9c94a0490b143282abc965ce1cfb0f/images/scarjo.png
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | import os,sys
2 | import tensorflow as tf
3 | import numpy as np
4 | from skimage import io
5 | from matplotlib import pyplot as plt
6 | import cv2
7 |
8 | from model.nets import nets_factory
9 | from model.preprocessing import preprocessing_factory
10 |
11 | flags = tf.app.flags
12 | flags.DEFINE_string("input", "images/cat.jpg", "Path to input image ['images/cat.jpg']")
13 | flags.DEFINE_string("output", "output.png", "Path to output image ['output.png']")
14 | flags.DEFINE_string("layer_name", None, "Layer till which to backpropagate")
15 | flags.DEFINE_string("model_name", "resnet_v2_50", "Name of the model")
16 | flags.DEFINE_string("preprocessing_name", None, "Name of the image preprocessor")
17 | flags.DEFINE_integer("eval_image_size", None, "Resize images to this size before eval")
18 | flags.DEFINE_string("dataset_dir", "./imagenet", "Location of the labels.txt")
19 | flags.DEFINE_string("checkpoint_path", "./imagenet/resnet_v2_50.ckpt", "saved weights for model")
20 | flags.DEFINE_integer("label_offset", 1, "Used for imagenet with 1001 classes for background class")
21 |
22 | FLAGS = flags.FLAGS
23 |
24 | slim = tf.contrib.slim
25 |
26 | _layer_names = { "resnet_v2_50": ["PrePool","predictions"],
27 | "resnet_v2_101": ["PrePool","predictions"],
28 | "resnet_v2_152": ["PrePool","predictions"],
29 | }
30 |
31 | _logits_name = "Logits"
32 |
33 | def load_labels_from_file(dataset_dir):
34 | labels = {}
35 | labels_name = os.path.join(dataset_dir,'labels.txt')
36 | with open(labels_name) as label_file:
37 | for line in label_file:
38 | idx,label = line.rstrip('\n').split(':')
39 | labels[int(idx)] = label
40 | assert len(labels) > 1
41 | return labels
42 |
43 |
44 | def load_image(img_path):
45 | print("Loading image")
46 | img = cv2.imread(img_path)
47 | if img is None:
48 | sys.stderr.write('Unable to load img: %s\n' % img_path)
49 | sys.exit(1)
50 | img = cv2.cvtColor(img,cv2.COLOR_BGR2RGB)
51 | return img
52 |
53 |
54 | def preprocess_image(image,eval_image_size):
55 | preprocessing_name = FLAGS.preprocessing_name or FLAGS.model_name
56 | image_preprocessing_fn = preprocessing_factory.get_preprocessing(
57 | preprocessing_name, is_training=False)
58 | image = image_preprocessing_fn(image, eval_image_size, eval_image_size)
59 | return image
60 |
61 | def grad_cam(img, imgs0, end_points, sess, predicted_class, layer_name, nb_classes, eval_image_size):
62 | # Conv layer tensor [?,10,10,2048]
63 | conv_layer = end_points[layer_name]
64 | 	# [nb_classes]-D one-hot tensor with the target class index set to 1 and the rest 0
65 | one_hot = tf.sparse_to_dense(predicted_class, [nb_classes], 1.0)
66 | signal = tf.multiply(end_points[_logits_name], one_hot)
67 | loss = tf.reduce_mean(signal)
68 |
69 | grads = tf.gradients(loss, conv_layer)[0]
70 | # Normalizing the gradients
71 | norm_grads = tf.divide(grads, tf.sqrt(tf.reduce_mean(tf.square(grads))) + tf.constant(1e-5))
72 |
73 | output, grads_val = sess.run([conv_layer, norm_grads], feed_dict={imgs0: img})
74 | output = output[0] # [10,10,2048]
75 | grads_val = grads_val[0] # [10,10,2048]
76 |
77 | weights = np.mean(grads_val, axis = (0, 1)) # [2048]
78 | cam = np.ones(output.shape[0 : 2], dtype = np.float32) # [10,10]
79 |
80 | # Taking a weighted average
81 | for i, w in enumerate(weights):
82 | cam += w * output[:, :, i]
83 |
84 | # Passing through ReLU
85 | cam = np.maximum(cam, 0)
86 | cam = cam / np.max(cam)
87 | cam3 = cv2.resize(cam, (eval_image_size,eval_image_size))
88 |
89 | return cam3
90 |
91 |
92 | def main(_):
93 | checkpoint_path=FLAGS.checkpoint_path
94 | img = load_image(FLAGS.input)
95 |
96 | labels = load_labels_from_file(FLAGS.dataset_dir)
97 | num_classes = len(labels) + FLAGS.label_offset
98 |
99 | network_fn = nets_factory.get_network_fn(
100 | FLAGS.model_name,
101 | num_classes=num_classes,
102 | is_training=False)
103 |
104 | eval_image_size = FLAGS.eval_image_size or network_fn.default_image_size
105 |
106 | print("\nLoading Model")
107 | imgs0 = tf.placeholder(tf.uint8, [None,None, 3])
108 | imgs = preprocess_image(imgs0,eval_image_size)
109 | imgs = tf.expand_dims(imgs,0)
110 |
111 | _,end_points = network_fn(imgs)
112 |
113 | init_fn = slim.assign_from_checkpoint_fn(checkpoint_path, slim.get_variables_to_restore())
114 |
115 | print("\nFeedforwarding")
116 |
117 | with tf.Session() as sess:
118 | init_fn(sess)
119 |
120 | ep = sess.run(end_points, feed_dict={imgs0: img})
121 | pred_layer_name = _layer_names[FLAGS.model_name][1]
122 | probs = ep[pred_layer_name][0]
123 |
124 | preds = (np.argsort(probs)[::-1])[0:5]
125 | print('\nTop 5 classes are')
126 | for p in preds:
127 | print(labels[p-FLAGS.label_offset], probs[p])
128 |
129 | # Target class
130 | predicted_class = preds[0]
131 | # Target layer for visualization
132 | layer_name = FLAGS.layer_name or _layer_names[FLAGS.model_name][0]
133 | # Number of output classes of model being used
134 | nb_classes = num_classes
135 |
136 | cam3 = grad_cam(img, imgs0, end_points, sess, predicted_class, layer_name, nb_classes, eval_image_size)
137 |
138 | img = cv2.resize(img,(eval_image_size,eval_image_size))
139 | img = img.astype(float)
140 | img /= img.max()
141 |
142 |
143 | cam3 = cv2.applyColorMap(np.uint8(255*cam3), cv2.COLORMAP_JET)
144 | cam3 = cv2.cvtColor(cam3, cv2.COLOR_BGR2RGB)
145 |
146 | # Superimposing the visualization with the image.
147 | alpha = 0.0025
148 | new_img = img+alpha*cam3
149 | new_img /= new_img.max()
150 |
151 | # Display and save
152 | io.imshow(new_img)
153 | plt.axis('off')
154 | plt.savefig(FLAGS.output,bbox_inches='tight')
155 | plt.show()
156 |
157 | if __name__ == '__main__':
158 | tf.app.run()
159 |
160 |
--------------------------------------------------------------------------------
/main.sh:
--------------------------------------------------------------------------------
1 | python main.py --model_name=resnet_v2_50 --dataset_dir=./imagenet/ --checkpoint_path=./imagenet/resnet_v2_50.ckpt --input=./images/cat.jpg --eval_image_size=299
2 |
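3 | # Illustrative alternatives (not part of the original script); main.py also defines
4 | # --input, --output, --layer_name, --label_offset and --preprocessing_name flags.
5 | #
6 | # python main.py --input=./images/scarjo.png --output=scarjo_heatmap.png
7 | # python main.py --model_name=resnet_v2_101 --checkpoint_path=./imagenet/resnet_v2_101.ckpt --eval_image_size=299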
--------------------------------------------------------------------------------
/model/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hiveml/tensorflow-grad-cam/9f4c9b7a5f9c94a0490b143282abc965ce1cfb0f/model/__init__.py
--------------------------------------------------------------------------------
/model/nets/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hiveml/tensorflow-grad-cam/9f4c9b7a5f9c94a0490b143282abc965ce1cfb0f/model/nets/__init__.py
--------------------------------------------------------------------------------
/model/nets/nets_factory.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Contains a factory for building various models."""
16 |
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 | import functools
21 |
22 | import tensorflow as tf
23 |
24 | #from nets import alexnet
25 | #from nets import cifarnet
26 | #from nets import inception
27 | #from nets import lenet
28 | #from nets import mobilenet_v1
29 | #from nets import overfeat
30 | #from nets import resnet_v1
31 | from model.nets import resnet_v2
32 | #from nets import vgg
33 | #from nets.nasnet import nasnet
34 |
35 | slim = tf.contrib.slim
36 |
37 | networks_map = {
38 | 'resnet_v2_50': resnet_v2.resnet_v2_50,
39 | 'resnet_v2_101': resnet_v2.resnet_v2_101,
40 | 'resnet_v2_152': resnet_v2.resnet_v2_152,
41 | 'resnet_v2_200': resnet_v2.resnet_v2_200,
42 | }
43 |
44 | arg_scopes_map = {
45 | 'resnet_v2_50': resnet_v2.resnet_arg_scope,
46 | 'resnet_v2_101': resnet_v2.resnet_arg_scope,
47 | 'resnet_v2_152': resnet_v2.resnet_arg_scope,
48 | 'resnet_v2_200': resnet_v2.resnet_arg_scope,
49 | }
50 |
51 |
52 | def get_network_fn(name, num_classes, weight_decay=0.0, is_training=False):
53 | """Returns a network_fn such as `logits, end_points = network_fn(images)`.
54 |
55 | Args:
56 | name: The name of the network.
57 | num_classes: The number of classes to use for classification. If 0 or None,
58 | the logits layer is omitted and its input features are returned instead.
59 | weight_decay: The l2 coefficient for the model weights.
60 | is_training: `True` if the model is being used for training and `False`
61 | otherwise.
62 |
63 | Returns:
64 | network_fn: A function that applies the model to a batch of images. It has
65 | the following signature:
66 | net, end_points = network_fn(images)
67 | The `images` input is a tensor of shape [batch_size, height, width, 3]
68 | with height = width = network_fn.default_image_size. (The permissibility
69 | and treatment of other sizes depends on the network_fn.)
70 | The returned `end_points` are a dictionary of intermediate activations.
71 | The returned `net` is the topmost layer, depending on `num_classes`:
72 | If `num_classes` was a non-zero integer, `net` is a logits tensor
73 | of shape [batch_size, num_classes].
74 | If `num_classes` was 0 or `None`, `net` is a tensor with the input
75 | to the logits layer of shape [batch_size, 1, 1, num_features] or
76 | [batch_size, num_features]. Dropout has not been applied to this
77 | (even if the network's original classification does); it remains for
78 | the caller to do this or not.
79 |
80 | Raises:
81 | ValueError: If network `name` is not recognized.
82 | """
83 | if name not in networks_map:
84 | raise ValueError('Name of network unknown %s' % name)
85 | func = networks_map[name]
86 | @functools.wraps(func)
87 | def network_fn(images, **kwargs):
88 | arg_scope = arg_scopes_map[name](weight_decay=weight_decay)
89 | with slim.arg_scope(arg_scope):
90 | return func(images, num_classes, is_training=is_training, **kwargs)
91 | if hasattr(func, 'default_image_size'):
92 | network_fn.default_image_size = func.default_image_size
93 |
94 | return network_fn
95 |
96 |
--------------------------------------------------------------------------------
/model/nets/resnet_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Contains building blocks for various versions of Residual Networks.
16 |
17 | Residual networks (ResNets) were proposed in:
18 | Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun
19 | Deep Residual Learning for Image Recognition. arXiv:1512.03385, 2015
20 |
21 | More variants were introduced in:
22 | Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun
23 | Identity Mappings in Deep Residual Networks. arXiv: 1603.05027, 2016
24 |
25 | We can obtain different ResNet variants by changing the network depth, width,
26 | and form of residual unit. This module implements the infrastructure for
27 | building them. Concrete ResNet units and full ResNet networks are implemented in
28 | the accompanying resnet_v1.py and resnet_v2.py modules.
29 |
30 | Compared to https://github.com/KaimingHe/deep-residual-networks, in the current
31 | implementation we subsample the output activations in the last residual unit of
32 | each block, instead of subsampling the input activations in the first residual
33 | unit of each block. The two implementations give identical results but our
34 | implementation is more memory efficient.
35 | """
36 | from __future__ import absolute_import
37 | from __future__ import division
38 | from __future__ import print_function
39 |
40 | import collections
41 | import tensorflow as tf
42 |
43 | slim = tf.contrib.slim
44 |
45 |
46 | class Block(collections.namedtuple('Block', ['scope', 'unit_fn', 'args'])):
47 | """A named tuple describing a ResNet block.
48 |
49 | Its parts are:
50 | scope: The scope of the `Block`.
51 | unit_fn: The ResNet unit function which takes as input a `Tensor` and
52 | returns another `Tensor` with the output of the ResNet unit.
53 | args: A list of length equal to the number of units in the `Block`. The list
54 | contains one (depth, depth_bottleneck, stride) tuple for each unit in the
55 | block to serve as argument to unit_fn.
56 | """
57 |
58 |
59 | def subsample(inputs, factor, scope=None):
60 | """Subsamples the input along the spatial dimensions.
61 |
62 | Args:
63 | inputs: A `Tensor` of size [batch, height_in, width_in, channels].
64 | factor: The subsampling factor.
65 | scope: Optional variable_scope.
66 |
67 | Returns:
68 | output: A `Tensor` of size [batch, height_out, width_out, channels] with the
69 | input, either intact (if factor == 1) or subsampled (if factor > 1).
70 | """
71 | if factor == 1:
72 | return inputs
73 | else:
74 | return slim.max_pool2d(inputs, [1, 1], stride=factor, scope=scope)
75 |
76 |
77 | def conv2d_same(inputs, num_outputs, kernel_size, stride, rate=1, scope=None):
78 | """Strided 2-D convolution with 'SAME' padding.
79 |
80 | When stride > 1, then we do explicit zero-padding, followed by conv2d with
81 | 'VALID' padding.
82 |
83 | Note that
84 |
85 | net = conv2d_same(inputs, num_outputs, 3, stride=stride)
86 |
87 | is equivalent to
88 |
89 | net = slim.conv2d(inputs, num_outputs, 3, stride=1, padding='SAME')
90 | net = subsample(net, factor=stride)
91 |
92 | whereas
93 |
94 | net = slim.conv2d(inputs, num_outputs, 3, stride=stride, padding='SAME')
95 |
96 | is different when the input's height or width is even, which is why we add the
97 | current function. For more details, see ResnetUtilsTest.testConv2DSameEven().
98 |
99 | Args:
100 | inputs: A 4-D tensor of size [batch, height_in, width_in, channels].
101 | num_outputs: An integer, the number of output filters.
102 | kernel_size: An int with the kernel_size of the filters.
103 | stride: An integer, the output stride.
104 | rate: An integer, rate for atrous convolution.
105 | scope: Scope.
106 |
107 | Returns:
108 | output: A 4-D tensor of size [batch, height_out, width_out, channels] with
109 | the convolution output.
110 | """
111 | if stride == 1:
112 | return slim.conv2d(inputs, num_outputs, kernel_size, stride=1, rate=rate,
113 | padding='SAME', scope=scope)
114 | else:
115 | kernel_size_effective = kernel_size + (kernel_size - 1) * (rate - 1)
116 | pad_total = kernel_size_effective - 1
117 | pad_beg = pad_total // 2
118 | pad_end = pad_total - pad_beg
119 | inputs = tf.pad(inputs,
120 | [[0, 0], [pad_beg, pad_end], [pad_beg, pad_end], [0, 0]])
121 | return slim.conv2d(inputs, num_outputs, kernel_size, stride=stride,
122 | rate=rate, padding='VALID', scope=scope)
123 |
124 |
125 | @slim.add_arg_scope
126 | def stack_blocks_dense(net, blocks, output_stride=None,
127 | outputs_collections=None):
128 | """Stacks ResNet `Blocks` and controls output feature density.
129 |
130 | First, this function creates scopes for the ResNet in the form of
131 | 'block_name/unit_1', 'block_name/unit_2', etc.
132 |
133 | Second, this function allows the user to explicitly control the ResNet
134 | output_stride, which is the ratio of the input to output spatial resolution.
135 | This is useful for dense prediction tasks such as semantic segmentation or
136 | object detection.
137 |
138 | Most ResNets consist of 4 ResNet blocks and subsample the activations by a
139 | factor of 2 when transitioning between consecutive ResNet blocks. This results
140 | in a nominal ResNet output_stride equal to 8. If we set the output_stride to
141 | half the nominal network stride (e.g., output_stride=4), then we compute
142 | responses twice.
143 |
144 | Control of the output feature density is implemented by atrous convolution.
145 |
146 | Args:
147 | net: A `Tensor` of size [batch, height, width, channels].
148 | blocks: A list of length equal to the number of ResNet `Blocks`. Each
149 | element is a ResNet `Block` object describing the units in the `Block`.
150 | output_stride: If `None`, then the output will be computed at the nominal
151 | network stride. If output_stride is not `None`, it specifies the requested
152 | ratio of input to output spatial resolution, which needs to be equal to
153 | the product of unit strides from the start up to some level of the ResNet.
154 | For example, if the ResNet employs units with strides 1, 2, 1, 3, 4, 1,
155 | then valid values for the output_stride are 1, 2, 6, 24 or None (which
156 | is equivalent to output_stride=24).
157 | outputs_collections: Collection to add the ResNet block outputs.
158 |
159 | Returns:
160 | net: Output tensor with stride equal to the specified output_stride.
161 |
162 | Raises:
163 | ValueError: If the target output_stride is not valid.
164 | """
165 | # The current_stride variable keeps track of the effective stride of the
166 | # activations. This allows us to invoke atrous convolution whenever applying
167 | # the next residual unit would result in the activations having stride larger
168 | # than the target output_stride.
169 | current_stride = 1
170 |
171 | # The atrous convolution rate parameter.
172 | rate = 1
173 |
174 | for block in blocks:
175 | with tf.variable_scope(block.scope, 'block', [net]) as sc:
176 | for i, unit in enumerate(block.args):
177 | if output_stride is not None and current_stride > output_stride:
178 | raise ValueError('The target output_stride cannot be reached.')
179 |
180 | with tf.variable_scope('unit_%d' % (i + 1), values=[net]):
181 | # If we have reached the target output_stride, then we need to employ
182 | # atrous convolution with stride=1 and multiply the atrous rate by the
183 | # current unit's stride for use in subsequent layers.
184 | if output_stride is not None and current_stride == output_stride:
185 | net = block.unit_fn(net, rate=rate, **dict(unit, stride=1))
186 | rate *= unit.get('stride', 1)
187 |
188 | else:
189 | net = block.unit_fn(net, rate=1, **unit)
190 | current_stride *= unit.get('stride', 1)
191 | net = slim.utils.collect_named_outputs(outputs_collections, sc.name, net)
192 |
193 | if output_stride is not None and current_stride != output_stride:
194 | raise ValueError('The target output_stride cannot be reached.')
195 |
196 | return net
197 |
198 |
199 | def resnet_arg_scope(weight_decay=0.0001,
200 | batch_norm_decay=0.997,
201 | batch_norm_epsilon=1e-5,
202 | batch_norm_scale=True,
203 | activation_fn=tf.nn.relu,
204 | use_batch_norm=True):
205 | """Defines the default ResNet arg scope.
206 |
207 | TODO(gpapan): The batch-normalization related default values above are
208 | appropriate for use in conjunction with the reference ResNet models
209 | released at https://github.com/KaimingHe/deep-residual-networks. When
210 | training ResNets from scratch, they might need to be tuned.
211 |
212 | Args:
213 | weight_decay: The weight decay to use for regularizing the model.
214 | batch_norm_decay: The moving average decay when estimating layer activation
215 | statistics in batch normalization.
216 | batch_norm_epsilon: Small constant to prevent division by zero when
217 | normalizing activations by their variance in batch normalization.
218 | batch_norm_scale: If True, uses an explicit `gamma` multiplier to scale the
219 | activations in the batch normalization layer.
220 | activation_fn: The activation function which is used in ResNet.
221 | use_batch_norm: Whether or not to use batch normalization.
222 |
223 | Returns:
224 | An `arg_scope` to use for the resnet models.
225 | """
226 | batch_norm_params = {
227 | 'decay': batch_norm_decay,
228 | 'epsilon': batch_norm_epsilon,
229 | 'scale': batch_norm_scale,
230 | 'updates_collections': tf.GraphKeys.UPDATE_OPS,
231 | 'fused': None, # Use fused batch norm if possible.
232 | }
233 |
234 | with slim.arg_scope(
235 | [slim.conv2d],
236 | weights_regularizer=slim.l2_regularizer(weight_decay),
237 | weights_initializer=slim.variance_scaling_initializer(),
238 | activation_fn=activation_fn,
239 | normalizer_fn=slim.batch_norm if use_batch_norm else None,
240 | normalizer_params=batch_norm_params):
241 | with slim.arg_scope([slim.batch_norm], **batch_norm_params):
242 | # The following implies padding='SAME' for pool1, which makes feature
243 | # alignment easier for dense prediction tasks. This is also used in
244 | # https://github.com/facebook/fb.resnet.torch. However the accompanying
245 | # code of 'Deep Residual Learning for Image Recognition' uses
246 | # padding='VALID' for pool1. You can switch to that choice by setting
247 | # slim.arg_scope([slim.max_pool2d], padding='VALID').
248 | with slim.arg_scope([slim.max_pool2d], padding='SAME') as arg_sc:
249 | return arg_sc
250 |
251 |
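
For clarity, here is a minimal sketch (plain Python, values chosen to mirror the 7x7 root convolution used by resnet_v2; illustrative only) of the explicit-padding arithmetic `conv2d_same` applies when `stride > 1`, which is what keeps its output aligned with 'SAME' padding regardless of the input size:

```
# Illustrative only: the padding computed by conv2d_same for stride > 1,
# shown for a 7x7 kernel with rate=1 (the resnet_v2 root convolution).
kernel_size, rate = 7, 1
kernel_size_effective = kernel_size + (kernel_size - 1) * (rate - 1)  # 7
pad_total = kernel_size_effective - 1                                 # 6
pad_beg = pad_total // 2                                              # 3
pad_end = pad_total - pad_beg                                         # 3
# The input is padded by 3 rows/columns on each side and then convolved with
# padding='VALID', which reproduces 'SAME' behaviour independently of whether
# the input height/width is even or odd.
```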
--------------------------------------------------------------------------------
/model/nets/resnet_v2.py:
--------------------------------------------------------------------------------
1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Contains definitions for the preactivation form of Residual Networks.
16 |
17 | Residual networks (ResNets) were originally proposed in:
18 | [1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun
19 | Deep Residual Learning for Image Recognition. arXiv:1512.03385
20 |
21 | The full preactivation 'v2' ResNet variant implemented in this module was
22 | introduced by:
23 | [2] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun
24 | Identity Mappings in Deep Residual Networks. arXiv: 1603.05027
25 |
26 | The key difference of the full preactivation 'v2' variant compared to the
27 | 'v1' variant in [1] is the use of batch normalization before every weight layer.
28 |
29 | Typical use:
30 |
31 | from tensorflow.contrib.slim.nets import resnet_v2
32 |
33 | ResNet-101 for image classification into 1000 classes:
34 |
35 | # inputs has shape [batch, 224, 224, 3]
36 | with slim.arg_scope(resnet_v2.resnet_arg_scope()):
37 | net, end_points = resnet_v2.resnet_v2_101(inputs, 1000, is_training=False)
38 |
39 | ResNet-101 for semantic segmentation into 21 classes:
40 |
41 | # inputs has shape [batch, 513, 513, 3]
42 | with slim.arg_scope(resnet_v2.resnet_arg_scope()):
43 | net, end_points = resnet_v2.resnet_v2_101(inputs,
44 | 21,
45 | is_training=False,
46 | global_pool=False,
47 | output_stride=16)
48 | """
49 | from __future__ import absolute_import
50 | from __future__ import division
51 | from __future__ import print_function
52 |
53 | import tensorflow as tf
54 |
55 | from model.nets import resnet_utils
56 |
57 | slim = tf.contrib.slim
58 | resnet_arg_scope = resnet_utils.resnet_arg_scope
59 |
60 |
61 | @slim.add_arg_scope
62 | def bottleneck(inputs, depth, depth_bottleneck, stride, rate=1,
63 | outputs_collections=None, scope=None):
64 | """Bottleneck residual unit variant with BN before convolutions.
65 |
66 | This is the full preactivation residual unit variant proposed in [2]. See
67 | Fig. 1(b) of [2] for its definition. Note that we use here the bottleneck
68 | variant which has an extra bottleneck layer.
69 |
70 | When putting together two consecutive ResNet blocks that use this unit, one
71 | should use stride = 2 in the last unit of the first block.
72 |
73 | Args:
74 | inputs: A tensor of size [batch, height, width, channels].
75 | depth: The depth of the ResNet unit output.
76 | depth_bottleneck: The depth of the bottleneck layers.
77 | stride: The ResNet unit's stride. Determines the amount of downsampling of
78 | the units output compared to its input.
79 | rate: An integer, rate for atrous convolution.
80 | outputs_collections: Collection to add the ResNet unit output.
81 | scope: Optional variable_scope.
82 |
83 | Returns:
84 | The ResNet unit's output.
85 | """
86 | with tf.variable_scope(scope, 'bottleneck_v2', [inputs]) as sc:
87 | depth_in = slim.utils.last_dimension(inputs.get_shape(), min_rank=4)
88 | preact = slim.batch_norm(inputs, activation_fn=tf.nn.relu, scope='preact')
89 | if depth == depth_in:
90 | shortcut = resnet_utils.subsample(inputs, stride, 'shortcut')
91 | else:
92 | shortcut = slim.conv2d(preact, depth, [1, 1], stride=stride,
93 | normalizer_fn=None, activation_fn=None,
94 | scope='shortcut')
95 |
96 | residual = slim.conv2d(preact, depth_bottleneck, [1, 1], stride=1,
97 | scope='conv1')
98 | residual = resnet_utils.conv2d_same(residual, depth_bottleneck, 3, stride,
99 | rate=rate, scope='conv2')
100 | residual = slim.conv2d(residual, depth, [1, 1], stride=1,
101 | normalizer_fn=None, activation_fn=None,
102 | scope='conv3')
103 |
104 | output = shortcut + residual
105 |
106 | return slim.utils.collect_named_outputs(outputs_collections,
107 | sc.name,
108 | output)
109 |
110 |
111 | def resnet_v2(inputs,
112 | blocks,
113 | num_classes=None,
114 | is_training=True,
115 | global_pool=True,
116 | output_stride=None,
117 | include_root_block=True,
118 | spatial_squeeze=True,
119 | reuse=None,
120 | scope=None):
121 | """Generator for v2 (preactivation) ResNet models.
122 |
123 | This function generates a family of ResNet v2 models. See the resnet_v2_*()
124 | methods for specific model instantiations, obtained by selecting different
125 | block instantiations that produce ResNets of various depths.
126 |
127 | Training for image classification on Imagenet is usually done with [224, 224]
128 | inputs, resulting in [7, 7] feature maps at the output of the last ResNet
129 | block for the ResNets defined in [1] that have nominal stride equal to 32.
130 | However, for dense prediction tasks we advise that one uses inputs with
131 | spatial dimensions that are multiples of 32 plus 1, e.g., [321, 321]. In
132 | this case the feature maps at the ResNet output will have spatial shape
133 | [(height - 1) / output_stride + 1, (width - 1) / output_stride + 1]
134 | and corners exactly aligned with the input image corners, which greatly
135 | facilitates alignment of the features to the image. Using as input [225, 225]
136 | images results in [8, 8] feature maps at the output of the last ResNet block.
137 |
138 | For dense prediction tasks, the ResNet needs to run in fully-convolutional
139 | (FCN) mode and global_pool needs to be set to False. The ResNets in [1, 2] all
140 | have nominal stride equal to 32 and a good choice in FCN mode is to use
141 | output_stride=16 in order to increase the density of the computed features at
142 | small computational and memory overhead, cf. http://arxiv.org/abs/1606.00915.
143 |
144 | Args:
145 | inputs: A tensor of size [batch, height_in, width_in, channels].
146 | blocks: A list of length equal to the number of ResNet blocks. Each element
147 | is a resnet_utils.Block object describing the units in the block.
148 | num_classes: Number of predicted classes for classification tasks.
149 | If 0 or None, we return the features before the logit layer.
150 | is_training: whether batch_norm layers are in training mode.
151 | global_pool: If True, we perform global average pooling before computing the
152 | logits. Set to True for image classification, False for dense prediction.
153 | output_stride: If None, then the output will be computed at the nominal
154 | network stride. If output_stride is not None, it specifies the requested
155 | ratio of input to output spatial resolution.
156 | include_root_block: If True, include the initial convolution followed by
157 | max-pooling, if False excludes it. If excluded, `inputs` should be the
158 | results of an activation-less convolution.
159 |     spatial_squeeze: If True, logits is of shape [B, C]; if False, logits is
160 | of shape [B, 1, 1, C], where B is batch_size and C is number of classes.
161 | To use this parameter, the input images must be smaller than 300x300
162 | pixels, in which case the output logit layer does not contain spatial
163 | information and can be removed.
164 | reuse: whether or not the network and its variables should be reused. To be
165 |       able to reuse, 'scope' must be given.
166 | scope: Optional variable_scope.
167 |
168 |
169 | Returns:
170 | net: A rank-4 tensor of size [batch, height_out, width_out, channels_out].
171 | If global_pool is False, then height_out and width_out are reduced by a
172 | factor of output_stride compared to the respective height_in and width_in,
173 | else both height_out and width_out equal one. If num_classes is 0 or None,
174 | then net is the output of the last ResNet block, potentially after global
175 | average pooling. If num_classes is a non-zero integer, net contains the
176 | pre-softmax activations.
177 | end_points: A dictionary from components of the network to the corresponding
178 | activation.
179 |
180 | Raises:
181 | ValueError: If the target output_stride is not valid.
182 | """
183 | with tf.variable_scope(scope, 'resnet_v2', [inputs], reuse=reuse) as sc:
184 | end_points_collection = sc.original_name_scope + '_end_points'
185 | with slim.arg_scope([slim.conv2d, bottleneck,
186 | resnet_utils.stack_blocks_dense],
187 | outputs_collections=end_points_collection):
188 | with slim.arg_scope([slim.batch_norm], is_training=is_training):
189 | net = inputs
190 | if include_root_block:
191 | if output_stride is not None:
192 | if output_stride % 4 != 0:
193 | raise ValueError('The output_stride needs to be a multiple of 4.')
194 | output_stride /= 4
195 | # We do not include batch normalization or activation functions in
196 | # conv1 because the first ResNet unit will perform these. Cf.
197 | # Appendix of [2].
198 | with slim.arg_scope([slim.conv2d],
199 | activation_fn=None, normalizer_fn=None):
200 | net = resnet_utils.conv2d_same(net, 64, 7, stride=2, scope='conv1')
201 | net = slim.max_pool2d(net, [3, 3], stride=2, scope='pool1')
202 | net = resnet_utils.stack_blocks_dense(net, blocks, output_stride)
203 | # This is needed because the pre-activation variant does not have batch
204 | # normalization or activation functions in the residual unit output. See
205 | # Appendix of [2].
206 | net = slim.batch_norm(net, activation_fn=tf.nn.relu, scope='postnorm')
207 | # Convert end_points_collection into a dictionary of end_points.
208 | end_points = slim.utils.convert_collection_to_dict(
209 | end_points_collection)
210 |
211 | end_points['PrePool'] = net
212 | if global_pool:
213 | # Global average pooling.
214 | net = tf.reduce_mean(net, [1, 2], name='pool5', keepdims=True)
215 | end_points['global_pool'] = net
216 | if num_classes is not None:
217 | net = slim.conv2d(net, num_classes, [1, 1], activation_fn=None,
218 | normalizer_fn=None, scope='logits')
219 | end_points[sc.name + '/logits'] = net
220 | if spatial_squeeze:
221 | net = tf.squeeze(net, [1, 2], name='SpatialSqueeze')
222 | end_points[sc.name + '/spatial_squeeze'] = net
223 | end_points['Logits'] = net
224 | end_points['predictions'] = slim.softmax(net, scope='predictions')
225 | return net, end_points
226 | resnet_v2.default_image_size = 224
227 |
228 |
229 | def resnet_v2_block(scope, base_depth, num_units, stride):
230 | """Helper function for creating a resnet_v2 bottleneck block.
231 |
232 | Args:
233 | scope: The scope of the block.
234 | base_depth: The depth of the bottleneck layer for each unit.
235 | num_units: The number of units in the block.
236 | stride: The stride of the block, implemented as a stride in the last unit.
237 | All other units have stride=1.
238 |
239 | Returns:
240 | A resnet_v2 bottleneck block.
241 | """
242 | return resnet_utils.Block(scope, bottleneck, [{
243 | 'depth': base_depth * 4,
244 | 'depth_bottleneck': base_depth,
245 | 'stride': 1
246 | }] * (num_units - 1) + [{
247 | 'depth': base_depth * 4,
248 | 'depth_bottleneck': base_depth,
249 | 'stride': stride
250 | }])
251 | resnet_v2.default_image_size = 224
252 |
253 |
254 | def resnet_v2_50(inputs,
255 | num_classes=None,
256 | is_training=True,
257 | global_pool=True,
258 | output_stride=None,
259 | spatial_squeeze=True,
260 | reuse=None,
261 | scope='resnet_v2_50'):
262 | """ResNet-50 model of [1]. See resnet_v2() for arg and return description."""
263 | blocks = [
264 | resnet_v2_block('block1', base_depth=64, num_units=3, stride=2),
265 | resnet_v2_block('block2', base_depth=128, num_units=4, stride=2),
266 | resnet_v2_block('block3', base_depth=256, num_units=6, stride=2),
267 | resnet_v2_block('block4', base_depth=512, num_units=3, stride=1),
268 | ]
269 | return resnet_v2(inputs, blocks, num_classes, is_training=is_training,
270 | global_pool=global_pool, output_stride=output_stride,
271 | include_root_block=True, spatial_squeeze=spatial_squeeze,
272 | reuse=reuse, scope=scope)
273 | resnet_v2_50.default_image_size = resnet_v2.default_image_size
274 |
275 |
276 | def resnet_v2_101(inputs,
277 | num_classes=None,
278 | is_training=True,
279 | global_pool=True,
280 | output_stride=None,
281 | spatial_squeeze=True,
282 | reuse=None,
283 | scope='resnet_v2_101'):
284 | """ResNet-101 model of [1]. See resnet_v2() for arg and return description."""
285 | blocks = [
286 | resnet_v2_block('block1', base_depth=64, num_units=3, stride=2),
287 | resnet_v2_block('block2', base_depth=128, num_units=4, stride=2),
288 | resnet_v2_block('block3', base_depth=256, num_units=23, stride=2),
289 | resnet_v2_block('block4', base_depth=512, num_units=3, stride=1),
290 | ]
291 | return resnet_v2(inputs, blocks, num_classes, is_training=is_training,
292 | global_pool=global_pool, output_stride=output_stride,
293 | include_root_block=True, spatial_squeeze=spatial_squeeze,
294 | reuse=reuse, scope=scope)
295 | resnet_v2_101.default_image_size = resnet_v2.default_image_size
296 |
297 |
298 | def resnet_v2_152(inputs,
299 | num_classes=None,
300 | is_training=True,
301 | global_pool=True,
302 | output_stride=None,
303 | spatial_squeeze=True,
304 | reuse=None,
305 | scope='resnet_v2_152'):
306 | """ResNet-152 model of [1]. See resnet_v2() for arg and return description."""
307 | blocks = [
308 | resnet_v2_block('block1', base_depth=64, num_units=3, stride=2),
309 | resnet_v2_block('block2', base_depth=128, num_units=8, stride=2),
310 | resnet_v2_block('block3', base_depth=256, num_units=36, stride=2),
311 | resnet_v2_block('block4', base_depth=512, num_units=3, stride=1),
312 | ]
313 | return resnet_v2(inputs, blocks, num_classes, is_training=is_training,
314 | global_pool=global_pool, output_stride=output_stride,
315 | include_root_block=True, spatial_squeeze=spatial_squeeze,
316 | reuse=reuse, scope=scope)
317 | resnet_v2_152.default_image_size = resnet_v2.default_image_size
318 |
319 |
320 | def resnet_v2_200(inputs,
321 | num_classes=None,
322 | is_training=True,
323 | global_pool=True,
324 | output_stride=None,
325 | spatial_squeeze=True,
326 | reuse=None,
327 | scope='resnet_v2_200'):
328 | """ResNet-200 model of [2]. See resnet_v2() for arg and return description."""
329 | blocks = [
330 | resnet_v2_block('block1', base_depth=64, num_units=3, stride=2),
331 | resnet_v2_block('block2', base_depth=128, num_units=24, stride=2),
332 | resnet_v2_block('block3', base_depth=256, num_units=36, stride=2),
333 | resnet_v2_block('block4', base_depth=512, num_units=3, stride=1),
334 | ]
335 | return resnet_v2(inputs, blocks, num_classes, is_training=is_training,
336 | global_pool=global_pool, output_stride=output_stride,
337 | include_root_block=True, spatial_squeeze=spatial_squeeze,
338 | reuse=reuse, scope=scope)
339 | resnet_v2_200.default_image_size = resnet_v2.default_image_size
340 |
341 |
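
As a usage reference, here is a hedged sketch (TF 1.x / slim, assuming this repo's `model.nets` package layout) that builds `resnet_v2_50` under `resnet_arg_scope` and reads the `PrePool` end point, i.e. the last convolutional feature map before global pooling, which is the kind of activation Grad-CAM weights by the gradient of the class score:

```
# Sketch only: build ResNet-50 v2 for the 1000 ImageNet classes and fetch the
# pre-pooling feature map and softmax probabilities exposed via end_points.
import tensorflow as tf
from model.nets import resnet_v2

slim = tf.contrib.slim

images = tf.placeholder(tf.float32, [None, 224, 224, 3])
with slim.arg_scope(resnet_v2.resnet_arg_scope()):
    logits, end_points = resnet_v2.resnet_v2_50(images, 1000, is_training=False)

conv_features = end_points['PrePool']   # last conv activations, [N, 7, 7, 2048]
probs = end_points['predictions']       # softmax over the 1000 classes
```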
--------------------------------------------------------------------------------
/model/preprocessing/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hiveml/tensorflow-grad-cam/9f4c9b7a5f9c94a0490b143282abc965ce1cfb0f/model/preprocessing/__init__.py
--------------------------------------------------------------------------------
/model/preprocessing/inception_preprocessing.py:
--------------------------------------------------------------------------------
1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Provides utilities to preprocess images for the Inception networks."""
16 |
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | import tensorflow as tf
22 |
23 | from tensorflow.python.ops import control_flow_ops
24 |
25 |
26 | def apply_with_random_selector(x, func, num_cases):
27 | """Computes func(x, sel), with sel sampled from [0...num_cases-1].
28 |
29 | Args:
30 | x: input Tensor.
31 | func: Python function to apply.
32 | num_cases: Python int32, number of cases to sample sel from.
33 |
34 | Returns:
35 | The result of func(x, sel), where func receives the value of the
36 | selector as a python integer, but sel is sampled dynamically.
37 | """
38 | sel = tf.random_uniform([], maxval=num_cases, dtype=tf.int32)
39 | # Pass the real x only to one of the func calls.
40 | return control_flow_ops.merge([
41 | func(control_flow_ops.switch(x, tf.equal(sel, case))[1], case)
42 | for case in range(num_cases)])[0]
43 |
44 |
45 | def distort_color(image, color_ordering=0, fast_mode=True, scope=None):
46 | """Distort the color of a Tensor image.
47 |
48 | Each color distortion is non-commutative and thus ordering of the color ops
49 | matters. Ideally we would randomly permute the ordering of the color ops.
50 | Rather than adding that level of complication, we select a distinct ordering
51 | of color ops for each preprocessing thread.
52 |
53 | Args:
54 | image: 3-D Tensor containing single image in [0, 1].
55 | color_ordering: Python int, a type of distortion (valid values: 0-3).
56 | fast_mode: Avoids slower ops (random_hue and random_contrast)
57 | scope: Optional scope for name_scope.
58 | Returns:
59 | 3-D Tensor color-distorted image on range [0, 1]
60 | Raises:
61 | ValueError: if color_ordering not in [0, 3]
62 | """
63 | with tf.name_scope(scope, 'distort_color', [image]):
64 | if fast_mode:
65 | if color_ordering == 0:
66 | image = tf.image.random_brightness(image, max_delta=32. / 255.)
67 | image = tf.image.random_saturation(image, lower=0.5, upper=1.5)
68 | else:
69 | image = tf.image.random_saturation(image, lower=0.5, upper=1.5)
70 | image = tf.image.random_brightness(image, max_delta=32. / 255.)
71 | else:
72 | if color_ordering == 0:
73 | image = tf.image.random_brightness(image, max_delta=32. / 255.)
74 | image = tf.image.random_saturation(image, lower=0.5, upper=1.5)
75 | image = tf.image.random_hue(image, max_delta=0.2)
76 | image = tf.image.random_contrast(image, lower=0.5, upper=1.5)
77 | elif color_ordering == 1:
78 | image = tf.image.random_saturation(image, lower=0.5, upper=1.5)
79 | image = tf.image.random_brightness(image, max_delta=32. / 255.)
80 | image = tf.image.random_contrast(image, lower=0.5, upper=1.5)
81 | image = tf.image.random_hue(image, max_delta=0.2)
82 | elif color_ordering == 2:
83 | image = tf.image.random_contrast(image, lower=0.5, upper=1.5)
84 | image = tf.image.random_hue(image, max_delta=0.2)
85 | image = tf.image.random_brightness(image, max_delta=32. / 255.)
86 | image = tf.image.random_saturation(image, lower=0.5, upper=1.5)
87 | elif color_ordering == 3:
88 | image = tf.image.random_hue(image, max_delta=0.2)
89 | image = tf.image.random_saturation(image, lower=0.5, upper=1.5)
90 | image = tf.image.random_contrast(image, lower=0.5, upper=1.5)
91 | image = tf.image.random_brightness(image, max_delta=32. / 255.)
92 | else:
93 | raise ValueError('color_ordering must be in [0, 3]')
94 |
95 | # The random_* ops do not necessarily clamp.
96 | return tf.clip_by_value(image, 0.0, 1.0)
97 |
98 |
99 | def distorted_bounding_box_crop(image,
100 | bbox,
101 | min_object_covered=0.1,
102 | aspect_ratio_range=(0.75, 1.33),
103 | area_range=(0.05, 1.0),
104 | max_attempts=100,
105 | scope=None):
106 | """Generates cropped_image using a one of the bboxes randomly distorted.
107 |
108 | See `tf.image.sample_distorted_bounding_box` for more documentation.
109 |
110 | Args:
111 | image: 3-D Tensor of image (it will be converted to floats in [0, 1]).
112 | bbox: 3-D float Tensor of bounding boxes arranged [1, num_boxes, coords]
113 | where each coordinate is [0, 1) and the coordinates are arranged
114 |       as [ymin, xmin, ymax, xmax]. If num_boxes is 0, then the whole image is
115 |       used.
116 | min_object_covered: An optional `float`. Defaults to `0.1`. The cropped
117 | area of the image must contain at least this fraction of any bounding box
118 | supplied.
119 | aspect_ratio_range: An optional list of `floats`. The cropped area of the
120 | image must have an aspect ratio = width / height within this range.
121 | area_range: An optional list of `floats`. The cropped area of the image
122 |       must contain a fraction of the supplied image within this range.
123 | max_attempts: An optional `int`. Number of attempts at generating a cropped
124 | region of the image of the specified constraints. After `max_attempts`
125 | failures, return the entire image.
126 | scope: Optional scope for name_scope.
127 | Returns:
128 | A tuple, a 3-D Tensor cropped_image and the distorted bbox
129 | """
130 | with tf.name_scope(scope, 'distorted_bounding_box_crop', [image, bbox]):
131 | # Each bounding box has shape [1, num_boxes, box coords] and
132 | # the coordinates are ordered [ymin, xmin, ymax, xmax].
133 |
134 | # A large fraction of image datasets contain a human-annotated bounding
135 | # box delineating the region of the image containing the object of interest.
136 | # We choose to create a new bounding box for the object which is a randomly
137 | # distorted version of the human-annotated bounding box that obeys an
138 | # allowed range of aspect ratios, sizes and overlap with the human-annotated
139 | # bounding box. If no box is supplied, then we assume the bounding box is
140 | # the entire image.
141 | sample_distorted_bounding_box = tf.image.sample_distorted_bounding_box(
142 | tf.shape(image),
143 | bounding_boxes=bbox,
144 | min_object_covered=min_object_covered,
145 | aspect_ratio_range=aspect_ratio_range,
146 | area_range=area_range,
147 | max_attempts=max_attempts,
148 | use_image_if_no_bounding_boxes=True)
149 | bbox_begin, bbox_size, distort_bbox = sample_distorted_bounding_box
150 |
151 | # Crop the image to the specified bounding box.
152 | cropped_image = tf.slice(image, bbox_begin, bbox_size)
153 | return cropped_image, distort_bbox
154 |
155 |
156 | def preprocess_for_train(image, height, width, bbox,
157 | fast_mode=True,
158 | scope=None,
159 | add_image_summaries=True):
160 | """Distort one image for training a network.
161 |
162 | Distorting images provides a useful technique for augmenting the data
163 | set during training in order to make the network invariant to aspects
164 |   of the image that do not affect the label.
165 |
166 |   Additionally, it creates image_summaries to display the different
167 | transformations applied to the image.
168 |
169 | Args:
170 | image: 3-D Tensor of image. If dtype is tf.float32 then the range should be
171 |     [0, 1], otherwise it would be converted to tf.float32 assuming that the range
172 |     is [0, MAX], where MAX is the largest positive representable number for
173 | int(8/16/32) data type (see `tf.image.convert_image_dtype` for details).
174 | height: integer
175 | width: integer
176 | bbox: 3-D float Tensor of bounding boxes arranged [1, num_boxes, coords]
177 | where each coordinate is [0, 1) and the coordinates are arranged
178 | as [ymin, xmin, ymax, xmax].
179 | fast_mode: Optional boolean, if True avoids slower transformations (i.e.
180 | bi-cubic resizing, random_hue or random_contrast).
181 | scope: Optional scope for name_scope.
182 | add_image_summaries: Enable image summaries.
183 | Returns:
184 | 3-D float Tensor of distorted image used for training with range [-1, 1].
185 | """
186 | with tf.name_scope(scope, 'distort_image', [image, height, width, bbox]):
187 | if bbox is None:
188 | bbox = tf.constant([0.0, 0.0, 1.0, 1.0],
189 | dtype=tf.float32,
190 | shape=[1, 1, 4])
191 | if image.dtype != tf.float32:
192 | image = tf.image.convert_image_dtype(image, dtype=tf.float32)
193 | # Each bounding box has shape [1, num_boxes, box coords] and
194 | # the coordinates are ordered [ymin, xmin, ymax, xmax].
195 | image_with_box = tf.image.draw_bounding_boxes(tf.expand_dims(image, 0),
196 | bbox)
197 | if add_image_summaries:
198 | tf.summary.image('image_with_bounding_boxes', image_with_box)
199 |
200 | distorted_image, distorted_bbox = distorted_bounding_box_crop(image, bbox)
201 | # Restore the shape since the dynamic slice based upon the bbox_size loses
202 | # the third dimension.
203 | distorted_image.set_shape([None, None, 3])
204 | image_with_distorted_box = tf.image.draw_bounding_boxes(
205 | tf.expand_dims(image, 0), distorted_bbox)
206 | if add_image_summaries:
207 | tf.summary.image('images_with_distorted_bounding_box',
208 | image_with_distorted_box)
209 |
210 | # This resizing operation may distort the images because the aspect
211 | # ratio is not respected. We select a resize method in a round robin
212 | # fashion based on the thread number.
213 | # Note that ResizeMethod contains 4 enumerated resizing methods.
214 |
215 | # We select only 1 case for fast_mode bilinear.
216 | num_resize_cases = 1 if fast_mode else 4
217 | distorted_image = apply_with_random_selector(
218 | distorted_image,
219 | lambda x, method: tf.image.resize_images(x, [height, width], method),
220 | num_cases=num_resize_cases)
221 |
222 | if add_image_summaries:
223 | tf.summary.image('cropped_resized_image',
224 | tf.expand_dims(distorted_image, 0))
225 |
226 | # Randomly flip the image horizontally.
227 | distorted_image = tf.image.random_flip_left_right(distorted_image)
228 |
229 | # Randomly distort the colors. There are 4 ways to do it.
230 | distorted_image = apply_with_random_selector(
231 | distorted_image,
232 | lambda x, ordering: distort_color(x, ordering, fast_mode),
233 | num_cases=4)
234 |
235 | if add_image_summaries:
236 | tf.summary.image('final_distorted_image',
237 | tf.expand_dims(distorted_image, 0))
238 | distorted_image = tf.subtract(distorted_image, 0.5)
239 | distorted_image = tf.multiply(distorted_image, 2.0)
240 | return distorted_image
241 |
242 |
243 | def preprocess_for_eval(image, height, width,
244 | central_fraction=0.875, scope=None):
245 | """Prepare one image for evaluation.
246 |
247 |   If height and width are specified, it outputs an image of that size by
248 | applying resize_bilinear.
249 |
250 |   If central_fraction is specified, it crops the central fraction of the
251 | input image.
252 |
253 | Args:
254 | image: 3-D Tensor of image. If dtype is tf.float32 then the range should be
255 |     [0, 1], otherwise it would be converted to tf.float32 assuming that the range
256 |     is [0, MAX], where MAX is the largest positive representable number for
257 | int(8/16/32) data type (see `tf.image.convert_image_dtype` for details).
258 | height: integer
259 | width: integer
260 | central_fraction: Optional Float, fraction of the image to crop.
261 | scope: Optional scope for name_scope.
262 | Returns:
263 | 3-D float Tensor of prepared image.
264 | """
265 | with tf.name_scope(scope, 'eval_image', [image, height, width]):
266 | if image.dtype != tf.float32:
267 | image = tf.image.convert_image_dtype(image, dtype=tf.float32)
268 | # Crop the central region of the image with an area containing 87.5% of
269 | # the original image.
270 | if central_fraction:
271 | image = tf.image.central_crop(image, central_fraction=central_fraction)
272 |
273 | if height and width:
274 | # Resize the image to the specified height and width.
275 | image = tf.expand_dims(image, 0)
276 | image = tf.image.resize_bilinear(image, [height, width],
277 | align_corners=False)
278 | image = tf.squeeze(image, [0])
279 | image = tf.subtract(image, 0.5)
280 | image = tf.multiply(image, 2.0)
281 | return image
282 |
283 |
284 | def preprocess_image(image, height, width,
285 | is_training=False,
286 | bbox=None,
287 | fast_mode=True,
288 | add_image_summaries=True):
289 | """Pre-process one image for training or evaluation.
290 |
291 | Args:
292 | image: 3-D Tensor [height, width, channels] with the image. If dtype is
293 |     tf.float32 then the range should be [0, 1], otherwise it would be converted
294 |     to tf.float32 assuming that the range is [0, MAX], where MAX is the largest
295 | positive representable number for int(8/16/32) data type (see
296 | `tf.image.convert_image_dtype` for details).
297 | height: integer, image expected height.
298 | width: integer, image expected width.
299 |     is_training: Boolean. If true, the image is transformed for training;
300 |       otherwise it is transformed for evaluation.
301 | bbox: 3-D float Tensor of bounding boxes arranged [1, num_boxes, coords]
302 | where each coordinate is [0, 1) and the coordinates are arranged as
303 | [ymin, xmin, ymax, xmax].
304 | fast_mode: Optional boolean, if True avoids slower transformations.
305 | add_image_summaries: Enable image summaries.
306 |
307 | Returns:
308 | 3-D float Tensor containing an appropriately scaled image
309 |
310 | Raises:
311 | ValueError: if user does not provide bounding box
312 | """
313 | if is_training:
314 | return preprocess_for_train(image, height, width, bbox, fast_mode,
315 | add_image_summaries=add_image_summaries)
316 | else:
317 | return preprocess_for_eval(image, height, width)
318 |
319 |
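
For orientation, a minimal sketch (TF 1.x; `images/cat.jpg` taken from this repo's tree, everything else an assumption) of the eval-time path through this module, which maps a decoded uint8 image to a 224x224 float tensor scaled to [-1, 1]:

```
# Sketch only: decode an image and run the evaluation preprocessing
# (central crop to 87.5%, bilinear resize, then scale to [-1, 1]).
import tensorflow as tf
from model.preprocessing import inception_preprocessing

raw = tf.read_file('images/cat.jpg')
image = tf.image.decode_jpeg(raw, channels=3)
processed = inception_preprocessing.preprocess_image(
    image, 224, 224, is_training=False)
batched = tf.expand_dims(processed, 0)   # shape [1, 224, 224, 3]
```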
--------------------------------------------------------------------------------
/model/preprocessing/preprocessing_factory.py:
--------------------------------------------------------------------------------
1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Contains a factory for building various models."""
16 |
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | import tensorflow as tf
22 |
23 | #from preprocessing import cifarnet_preprocessing
24 | from model.preprocessing import inception_preprocessing
25 | #from preprocessing import lenet_preprocessing
26 | #from preprocessing import vgg_preprocessing
27 |
28 | slim = tf.contrib.slim
29 |
30 |
31 | def get_preprocessing(name, is_training=False):
32 | """Returns preprocessing_fn(image, height, width, **kwargs).
33 |
34 | Args:
35 | name: The name of the preprocessing function.
36 | is_training: `True` if the model is being used for training and `False`
37 | otherwise.
38 |
39 | Returns:
40 |     preprocessing_fn: A function that preprocesses a single image (pre-batch).
41 | It has the following signature:
42 | image = preprocessing_fn(image, output_height, output_width, ...).
43 |
44 | Raises:
45 | ValueError: If Preprocessing `name` is not recognized.
46 | """
47 | preprocessing_fn_map = {
48 | 'resnet_v2_50': inception_preprocessing,
49 | 'resnet_v2_101': inception_preprocessing,
50 | 'resnet_v2_152': inception_preprocessing,
51 | 'resnet_v2_200': inception_preprocessing,
52 | }
53 |
54 | if name not in preprocessing_fn_map:
55 | raise ValueError('Preprocessing name [%s] was not recognized' % name)
56 |
57 | def preprocessing_fn(image, output_height, output_width, **kwargs):
58 | return preprocessing_fn_map[name].preprocess_image(
59 | image, output_height, output_width, is_training=is_training, **kwargs)
60 |
61 | return preprocessing_fn
62 |
63 |
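
A short, hedged example of selecting preprocessing by model name through this factory (mirroring the resnet_v2 names registered in the map above):

```
# Sketch only: look up the preprocessing function for resnet_v2_50 and apply
# it to a decoded 3-D image tensor.
from model.preprocessing import preprocessing_factory

preprocess_fn = preprocessing_factory.get_preprocessing(
    'resnet_v2_50', is_training=False)
# processed = preprocess_fn(image, 224, 224)   # image: 3-D uint8/float Tensor
```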
--------------------------------------------------------------------------------