├── .gitignore ├── Dockerfile ├── README.md ├── imagenet ├── get_checkpoint.sh └── labels.txt ├── images ├── cat.jpg ├── cat_heatmap.png └── scarjo.png ├── main.py ├── main.sh └── model ├── __init__.py ├── nets ├── __init__.py ├── nets_factory.py ├── resnet_utils.py └── resnet_v2.py └── preprocessing ├── __init__.py ├── inception_preprocessing.py └── preprocessing_factory.py /.gitignore: -------------------------------------------------------------------------------- 1 | imagenet/*.ckpt 2 | imagenet/*.graph 3 | output.png 4 | *.pyc 5 | model/nets/*.pyc 6 | model/preprocessing/*.pyc 7 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | FROM tensorflow/tensorflow:1.5.0-rc0-devel 2 | 3 | RUN pip install -U pip 4 | 5 | RUN apt-get update && \ 6 | apt-get install -y \ 7 | build-essential \ 8 | cmake \ 9 | git \ 10 | libgtk2.0-dev \ 11 | pkg-config \ 12 | libavcodec-dev \ 13 | libavformat-dev \ 14 | libswscale-dev \ 15 | python-dev \ 16 | python-numpy \ 17 | python-skimage \ 18 | python-tk \ 19 | libtbb2 \ 20 | libtbb-dev \ 21 | libjpeg-dev \ 22 | libpng-dev \ 23 | libtiff-dev \ 24 | libjasper-dev \ 25 | libdc1394-22-dev \ 26 | qt5-default \ 27 | wget \ 28 | vim 29 | 30 | RUN git clone https://github.com/opencv/opencv.git /root/opencv && \ 31 | cd /root/opencv && \ 32 | git checkout 2.4 && \ 33 | mkdir build && \ 34 | cd build && \ 35 | cmake -DWITH_QT=ON -DWITH_OPENGL=ON -DFORCE_VTK=ON -DWITH_TBB=ON -DWITH_GDAL=ON -DWITH_XINE=ON -DBUILD_EXAMPLES=ON .. && \ 36 | make -j"$(nproc)" && \ 37 | make install && \ 38 | ldconfig && \ 39 | echo 'ln /dev/null /dev/raw1394' >> ~/.bashrc 40 | 41 | RUN ln /dev/null /dev/raw1394 42 | 43 | RUN cd /root && git clone https://github.com/hiveml/tensorflow-grad-cam 44 | 45 | WORKDIR /root/tensorflow-grad-cam 46 | 47 | RUN cd imagenet && ./get_checkpoint.sh 48 | 49 | CMD /bin/bash 50 | 51 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | # Grad-CAM - TensorFlow Slim 4 | 5 | 6 | 7 | 8 | ### Features: 9 | 10 | Modular, built on TensorFlow-Slim. Easy to drop in other [Slim Models](https://github.com/tensorflow/models/tree/master/research/slim#pre-trained-models) 11 | 12 | Updated to work with TensorFlow 1.5 13 | 14 | Includes various output options: heatmap, shading, blur 15 | 16 | #### More examples and explanation [here](https://thehive.ai/blog/inside-a-neural-networks-mind) 17 | 18 | 19 | ### Dependencies 20 | 21 | * Python 2.7 and pip 22 | * scikit-image: `sudo apt-get install python-skimage` 23 | * Tkinter: `sudo apt-get install python-tk` 24 | * TensorFlow >= 1.5: `pip install tensorflow==1.5.0rc0` 25 | * OpenCV - see https://docs.opencv.org/2.4/doc/tutorials/introduction/linux_install/linux_install.html 26 | 27 | 28 | ### Installation 29 | 30 | Clone the repo: 31 | ``` 32 | git clone https://github.com/hiveml/tensorflow-grad-cam.git 33 | cd tensorflow-grad-cam 34 | ``` 35 | Download the ResNet-50 weights: 36 | ``` 37 | ./imagenet/get_checkpoint.sh 38 | ``` 39 | ### Usage 40 | ``` 41 | ./main.sh 42 | ``` 43 | 44 | ### Changing the Class 45 | 46 | By default, this code shows the Grad-CAM result for the top predicted class. You can 47 | change the `predicted_class` argument passed to the `grad_cam` function to see where the network 48 | looks for other classes.
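For example, to force the heatmap for the "tabby, tabby cat" class rather than whatever the network ranks first, you could hard-code the target index inside `main()` before calling `grad_cam` (a rough sketch against main.py; index 281 comes from imagenet/labels.txt and is shifted by the `label_offset` flag so it lines up with the checkpoint's 1001-way logits):
```
# Entry 281 in imagenet/labels.txt is "tabby, tabby cat"; the checkpoint
# reserves logit 0 for a background class, hence the label offset.
predicted_class = 281 + FLAGS.label_offset

cam3 = grad_cam(img, imgs0, end_points, sess, predicted_class,
                layer_name, nb_classes, eval_image_size)
```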
49 | 50 | ### How to load another resnet\_v2 model 51 | 52 | First download the new model from here: [Slim Models](https://github.com/tensorflow/models/tree/master/research/slim#pre-trained-models) 53 | 54 | Then modify the input arguments in main.sh: 55 | ``` 56 | python main.py --model_name=resnet_v2_101 --dataset_dir=./imagenet/ --checkpoint_path=./imagenet/resnet_v2_101.ckpt --input=./images/cat.jpg --eval_image_size=299 57 | 58 | ``` 59 | 60 | 61 | Repo is based off this [code](https://github.com/Ankush96/grad-cam.tensorflow). 62 | 63 | -------------------------------------------------------------------------------- /imagenet/get_checkpoint.sh: -------------------------------------------------------------------------------- 1 | wget http://download.tensorflow.org/models/resnet_v2_50_2017_04_14.tar.gz 2 | tar xzvf resnet_v2_50_2017_04_14.tar.gz 3 | rm resnet_v2_50_2017_04_14.tar.gz 4 | -------------------------------------------------------------------------------- /imagenet/labels.txt: -------------------------------------------------------------------------------- 1 | 0: tench, Tinca tinca 2 | 1: goldfish, Carassius auratus 3 | 2: great white shark, white shark, man-eater, man-eating shark, Carcharodon carcharias 4 | 3: tiger shark, Galeocerdo cuvieri 5 | 4: hammerhead, hammerhead shark 6 | 5: electric ray, crampfish, numbfish, torpedo 7 | 6: stingray 8 | 7: cock 9 | 8: hen 10 | 9: ostrich, Struthio camelus 11 | 10: brambling, Fringilla montifringilla 12 | 11: goldfinch, Carduelis carduelis 13 | 12: house finch, linnet, Carpodacus mexicanus 14 | 13: junco, snowbird 15 | 14: indigo bunting, indigo finch, indigo bird, Passerina cyanea 16 | 15: robin, American robin, Turdus migratorius 17 | 16: bulbul 18 | 17: jay 19 | 18: magpie 20 | 19: chickadee 21 | 20: water ouzel, dipper 22 | 21: kite 23 | 22: bald eagle, American eagle, Haliaeetus leucocephalus 24 | 23: vulture 25 | 24: great grey owl, great gray owl, Strix nebulosa 26 | 25: European fire salamander, Salamandra salamandra 27 | 26: common newt, Triturus vulgaris 28 | 27: eft 29 | 28: spotted salamander, Ambystoma maculatum 30 | 29: axolotl, mud puppy, Ambystoma mexicanum 31 | 30: bullfrog, Rana catesbeiana 32 | 31: tree frog, tree-frog 33 | 32: tailed frog, bell toad, ribbed toad, tailed toad, Ascaphus trui 34 | 33: loggerhead, loggerhead turtle, Caretta caretta 35 | 34: leatherback turtle, leatherback, leathery turtle, Dermochelys coriacea 36 | 35: mud turtle 37 | 36: terrapin 38 | 37: box turtle, box tortoise 39 | 38: banded gecko 40 | 39: common iguana, iguana, Iguana iguana 41 | 40: American chameleon, anole, Anolis carolinensis 42 | 41: whiptail, whiptail lizard 43 | 42: agama 44 | 43: frilled lizard, Chlamydosaurus kingi 45 | 44: alligator lizard 46 | 45: Gila monster, Heloderma suspectum 47 | 46: green lizard, Lacerta viridis 48 | 47: African chameleon, Chamaeleo chamaeleon 49 | 48: Komodo dragon, Komodo lizard, dragon lizard, giant lizard, Varanus komodoensis 50 | 49: African crocodile, Nile crocodile, Crocodylus niloticus 51 | 50: American alligator, Alligator mississipiensis 52 | 51: triceratops 53 | 52: thunder snake, worm snake, Carphophis amoenus 54 | 53: ringneck snake, ring-necked snake, ring snake 55 | 54: hognose snake, puff adder, sand viper 56 | 55: green snake, grass snake 57 | 56: king snake, kingsnake 58 | 57: garter snake, grass snake 59 | 58: water snake 60 | 59: vine snake 61 | 60: night snake, Hypsiglena torquata 62 | 61: boa constrictor, Constrictor constrictor 63 | 62: rock python, rock snake, Python 
sebae 64 | 63: Indian cobra, Naja naja 65 | 64: green mamba 66 | 65: sea snake 67 | 66: horned viper, cerastes, sand viper, horned asp, Cerastes cornutus 68 | 67: diamondback, diamondback rattlesnake, Crotalus adamanteus 69 | 68: sidewinder, horned rattlesnake, Crotalus cerastes 70 | 69: trilobite 71 | 70: harvestman, daddy longlegs, Phalangium opilio 72 | 71: scorpion 73 | 72: black and gold garden spider, Argiope aurantia 74 | 73: barn spider, Araneus cavaticus 75 | 74: garden spider, Aranea diademata 76 | 75: black widow, Latrodectus mactans 77 | 76: tarantula 78 | 77: wolf spider, hunting spider 79 | 78: tick 80 | 79: centipede 81 | 80: black grouse 82 | 81: ptarmigan 83 | 82: ruffed grouse, partridge, Bonasa umbellus 84 | 83: prairie chicken, prairie grouse, prairie fowl 85 | 84: peacock 86 | 85: quail 87 | 86: partridge 88 | 87: African grey, African gray, Psittacus erithacus 89 | 88: macaw 90 | 89: sulphur-crested cockatoo, Kakatoe galerita, Cacatua galerita 91 | 90: lorikeet 92 | 91: coucal 93 | 92: bee eater 94 | 93: hornbill 95 | 94: hummingbird 96 | 95: jacamar 97 | 96: toucan 98 | 97: drake 99 | 98: red-breasted merganser, Mergus serrator 100 | 99: goose 101 | 100: black swan, Cygnus atratus 102 | 101: tusker 103 | 102: echidna, spiny anteater, anteater 104 | 103: platypus, duckbill, duckbilled platypus, duck-billed platypus, Ornithorhynchus anatinus 105 | 104: wallaby, brush kangaroo 106 | 105: koala, koala bear, kangaroo bear, native bear, Phascolarctos cinereus 107 | 106: wombat 108 | 107: jellyfish 109 | 108: sea anemone, anemone 110 | 109: brain coral 111 | 110: flatworm, platyhelminth 112 | 111: nematode, nematode worm, roundworm 113 | 112: conch 114 | 113: snail 115 | 114: slug 116 | 115: sea slug, nudibranch 117 | 116: chiton, coat-of-mail shell, sea cradle, polyplacophore 118 | 117: chambered nautilus, pearly nautilus, nautilus 119 | 118: Dungeness crab, Cancer magister 120 | 119: rock crab, Cancer irroratus 121 | 120: fiddler crab 122 | 121: king crab, Alaska crab, Alaskan king crab, Alaska king crab, Paralithodes camtschatica 123 | 122: American lobster, Northern lobster, Maine lobster, Homarus americanus 124 | 123: spiny lobster, langouste, rock lobster, crawfish, crayfish, sea crawfish 125 | 124: crayfish, crawfish, crawdad, crawdaddy 126 | 125: hermit crab 127 | 126: isopod 128 | 127: white stork, Ciconia ciconia 129 | 128: black stork, Ciconia nigra 130 | 129: spoonbill 131 | 130: flamingo 132 | 131: little blue heron, Egretta caerulea 133 | 132: American egret, great white heron, Egretta albus 134 | 133: bittern 135 | 134: crane 136 | 135: limpkin, Aramus pictus 137 | 136: European gallinule, Porphyrio porphyrio 138 | 137: American coot, marsh hen, mud hen, water hen, Fulica americana 139 | 138: bustard 140 | 139: ruddy turnstone, Arenaria interpres 141 | 140: red-backed sandpiper, dunlin, Erolia alpina 142 | 141: redshank, Tringa totanus 143 | 142: dowitcher 144 | 143: oystercatcher, oyster catcher 145 | 144: pelican 146 | 145: king penguin, Aptenodytes patagonica 147 | 146: albatross, mollymawk 148 | 147: grey whale, gray whale, devilfish, Eschrichtius gibbosus, Eschrichtius robustus 149 | 148: killer whale, killer, orca, grampus, sea wolf, Orcinus orca 150 | 149: dugong, Dugong dugon 151 | 150: sea lion 152 | 151: Chihuahua 153 | 152: Japanese spaniel 154 | 153: Maltese dog, Maltese terrier, Maltese 155 | 154: Pekinese, Pekingese, Peke 156 | 155: Shih-Tzu 157 | 156: Blenheim spaniel 158 | 157: papillon 159 | 158: toy terrier 160 | 159: Rhodesian ridgeback 161 
| 160: Afghan hound, Afghan 162 | 161: basset, basset hound 163 | 162: beagle 164 | 163: bloodhound, sleuthhound 165 | 164: bluetick 166 | 165: black-and-tan coonhound 167 | 166: Walker hound, Walker foxhound 168 | 167: English foxhound 169 | 168: redbone 170 | 169: borzoi, Russian wolfhound 171 | 170: Irish wolfhound 172 | 171: Italian greyhound 173 | 172: whippet 174 | 173: Ibizan hound, Ibizan Podenco 175 | 174: Norwegian elkhound, elkhound 176 | 175: otterhound, otter hound 177 | 176: Saluki, gazelle hound 178 | 177: Scottish deerhound, deerhound 179 | 178: Weimaraner 180 | 179: Staffordshire bullterrier, Staffordshire bull terrier 181 | 180: American Staffordshire terrier, Staffordshire terrier, American pit bull terrier, pit bull terrier 182 | 181: Bedlington terrier 183 | 182: Border terrier 184 | 183: Kerry blue terrier 185 | 184: Irish terrier 186 | 185: Norfolk terrier 187 | 186: Norwich terrier 188 | 187: Yorkshire terrier 189 | 188: wire-haired fox terrier 190 | 189: Lakeland terrier 191 | 190: Sealyham terrier, Sealyham 192 | 191: Airedale, Airedale terrier 193 | 192: cairn, cairn terrier 194 | 193: Australian terrier 195 | 194: Dandie Dinmont, Dandie Dinmont terrier 196 | 195: Boston bull, Boston terrier 197 | 196: miniature schnauzer 198 | 197: giant schnauzer 199 | 198: standard schnauzer 200 | 199: Scotch terrier, Scottish terrier, Scottie 201 | 200: Tibetan terrier, chrysanthemum dog 202 | 201: silky terrier, Sydney silky 203 | 202: soft-coated wheaten terrier 204 | 203: West Highland white terrier 205 | 204: Lhasa, Lhasa apso 206 | 205: flat-coated retriever 207 | 206: curly-coated retriever 208 | 207: golden retriever 209 | 208: Labrador retriever 210 | 209: Chesapeake Bay retriever 211 | 210: German short-haired pointer 212 | 211: vizsla, Hungarian pointer 213 | 212: English setter 214 | 213: Irish setter, red setter 215 | 214: Gordon setter 216 | 215: Brittany spaniel 217 | 216: clumber, clumber spaniel 218 | 217: English springer, English springer spaniel 219 | 218: Welsh springer spaniel 220 | 219: cocker spaniel, English cocker spaniel, cocker 221 | 220: Sussex spaniel 222 | 221: Irish water spaniel 223 | 222: kuvasz 224 | 223: schipperke 225 | 224: groenendael 226 | 225: malinois 227 | 226: briard 228 | 227: kelpie 229 | 228: komondor 230 | 229: Old English sheepdog, bobtail 231 | 230: Shetland sheepdog, Shetland sheep dog, Shetland 232 | 231: collie 233 | 232: Border collie 234 | 233: Bouvier des Flandres, Bouviers des Flandres 235 | 234: Rottweiler 236 | 235: German shepherd, German shepherd dog, German police dog, alsatian 237 | 236: Doberman, Doberman pinscher 238 | 237: miniature pinscher 239 | 238: Greater Swiss Mountain dog 240 | 239: Bernese mountain dog 241 | 240: Appenzeller 242 | 241: EntleBucher 243 | 242: boxer 244 | 243: bull mastiff 245 | 244: Tibetan mastiff 246 | 245: French bulldog 247 | 246: Great Dane 248 | 247: Saint Bernard, St Bernard 249 | 248: Eskimo dog, husky 250 | 249: malamute, malemute, Alaskan malamute 251 | 250: Siberian husky 252 | 251: dalmatian, coach dog, carriage dog 253 | 252: affenpinscher, monkey pinscher, monkey dog 254 | 253: basenji 255 | 254: pug, pug-dog 256 | 255: Leonberg 257 | 256: Newfoundland, Newfoundland dog 258 | 257: Great Pyrenees 259 | 258: Samoyed, Samoyede 260 | 259: Pomeranian 261 | 260: chow, chow chow 262 | 261: keeshond 263 | 262: Brabancon griffon 264 | 263: Pembroke, Pembroke Welsh corgi 265 | 264: Cardigan, Cardigan Welsh corgi 266 | 265: toy poodle 267 | 266: miniature poodle 268 | 267: standard 
poodle 269 | 268: Mexican hairless 270 | 269: timber wolf, grey wolf, gray wolf, Canis lupus 271 | 270: white wolf, Arctic wolf, Canis lupus tundrarum 272 | 271: red wolf, maned wolf, Canis rufus, Canis niger 273 | 272: coyote, prairie wolf, brush wolf, Canis latrans 274 | 273: dingo, warrigal, warragal, Canis dingo 275 | 274: dhole, Cuon alpinus 276 | 275: African hunting dog, hyena dog, Cape hunting dog, Lycaon pictus 277 | 276: hyena, hyaena 278 | 277: red fox, Vulpes vulpes 279 | 278: kit fox, Vulpes macrotis 280 | 279: Arctic fox, white fox, Alopex lagopus 281 | 280: grey fox, gray fox, Urocyon cinereoargenteus 282 | 281: tabby, tabby cat 283 | 282: tiger cat 284 | 283: Persian cat 285 | 284: Siamese cat, Siamese 286 | 285: Egyptian cat 287 | 286: cougar, puma, catamount, mountain lion, painter, panther, Felis concolor 288 | 287: lynx, catamount 289 | 288: leopard, Panthera pardus 290 | 289: snow leopard, ounce, Panthera uncia 291 | 290: jaguar, panther, Panthera onca, Felis onca 292 | 291: lion, king of beasts, Panthera leo 293 | 292: tiger, Panthera tigris 294 | 293: cheetah, chetah, Acinonyx jubatus 295 | 294: brown bear, bruin, Ursus arctos 296 | 295: American black bear, black bear, Ursus americanus, Euarctos americanus 297 | 296: ice bear, polar bear, Ursus Maritimus, Thalarctos maritimus 298 | 297: sloth bear, Melursus ursinus, Ursus ursinus 299 | 298: mongoose 300 | 299: meerkat, mierkat 301 | 300: tiger beetle 302 | 301: ladybug, ladybeetle, lady beetle, ladybird, ladybird beetle 303 | 302: ground beetle, carabid beetle 304 | 303: long-horned beetle, longicorn, longicorn beetle 305 | 304: leaf beetle, chrysomelid 306 | 305: dung beetle 307 | 306: rhinoceros beetle 308 | 307: weevil 309 | 308: fly 310 | 309: bee 311 | 310: ant, emmet, pismire 312 | 311: grasshopper, hopper 313 | 312: cricket 314 | 313: walking stick, walkingstick, stick insect 315 | 314: cockroach, roach 316 | 315: mantis, mantid 317 | 316: cicada, cicala 318 | 317: leafhopper 319 | 318: lacewing, lacewing fly 320 | 319: dragonfly, darning needle, devil's darning needle, sewing needle, snake feeder, snake doctor, mosquito hawk, skeeter hawk 321 | 320: damselfly 322 | 321: admiral 323 | 322: ringlet, ringlet butterfly 324 | 323: monarch, monarch butterfly, milkweed butterfly, Danaus plexippus 325 | 324: cabbage butterfly 326 | 325: sulphur butterfly, sulfur butterfly 327 | 326: lycaenid, lycaenid butterfly 328 | 327: starfish, sea star 329 | 328: sea urchin 330 | 329: sea cucumber, holothurian 331 | 330: wood rabbit, cottontail, cottontail rabbit 332 | 331: hare 333 | 332: Angora, Angora rabbit 334 | 333: hamster 335 | 334: porcupine, hedgehog 336 | 335: fox squirrel, eastern fox squirrel, Sciurus niger 337 | 336: marmot 338 | 337: beaver 339 | 338: guinea pig, Cavia cobaya 340 | 339: sorrel 341 | 340: zebra 342 | 341: hog, pig, grunter, squealer, Sus scrofa 343 | 342: wild boar, boar, Sus scrofa 344 | 343: warthog 345 | 344: hippopotamus, hippo, river horse, Hippopotamus amphibius 346 | 345: ox 347 | 346: water buffalo, water ox, Asiatic buffalo, Bubalus bubalis 348 | 347: bison 349 | 348: ram, tup 350 | 349: bighorn, bighorn sheep, cimarron, Rocky Mountain bighorn, Rocky Mountain sheep, Ovis canadensis 351 | 350: ibex, Capra ibex 352 | 351: hartebeest 353 | 352: impala, Aepyceros melampus 354 | 353: gazelle 355 | 354: Arabian camel, dromedary, Camelus dromedarius 356 | 355: llama 357 | 356: weasel 358 | 357: mink 359 | 358: polecat, fitch, foulmart, foumart, Mustela putorius 360 | 359: black-footed ferret, 
ferret, Mustela nigripes 361 | 360: otter 362 | 361: skunk, polecat, wood pussy 363 | 362: badger 364 | 363: armadillo 365 | 364: three-toed sloth, ai, Bradypus tridactylus 366 | 365: orangutan, orang, orangutang, Pongo pygmaeus 367 | 366: gorilla, Gorilla gorilla 368 | 367: chimpanzee, chimp, Pan troglodytes 369 | 368: gibbon, Hylobates lar 370 | 369: siamang, Hylobates syndactylus, Symphalangus syndactylus 371 | 370: guenon, guenon monkey 372 | 371: patas, hussar monkey, Erythrocebus patas 373 | 372: baboon 374 | 373: macaque 375 | 374: langur 376 | 375: colobus, colobus monkey 377 | 376: proboscis monkey, Nasalis larvatus 378 | 377: marmoset 379 | 378: capuchin, ringtail, Cebus capucinus 380 | 379: howler monkey, howler 381 | 380: titi, titi monkey 382 | 381: spider monkey, Ateles geoffroyi 383 | 382: squirrel monkey, Saimiri sciureus 384 | 383: Madagascar cat, ring-tailed lemur, Lemur catta 385 | 384: indri, indris, Indri indri, Indri brevicaudatus 386 | 385: Indian elephant, Elephas maximus 387 | 386: African elephant, Loxodonta africana 388 | 387: lesser panda, red panda, panda, bear cat, cat bear, Ailurus fulgens 389 | 388: giant panda, panda, panda bear, coon bear, Ailuropoda melanoleuca 390 | 389: barracouta, snoek 391 | 390: eel 392 | 391: coho, cohoe, coho salmon, blue jack, silver salmon, Oncorhynchus kisutch 393 | 392: rock beauty, Holocanthus tricolor 394 | 393: anemone fish 395 | 394: sturgeon 396 | 395: gar, garfish, garpike, billfish, Lepisosteus osseus 397 | 396: lionfish 398 | 397: puffer, pufferfish, blowfish, globefish 399 | 398: abacus 400 | 399: abaya 401 | 400: academic gown, academic robe, judge's robe 402 | 401: accordion, piano accordion, squeeze box 403 | 402: acoustic guitar 404 | 403: aircraft carrier, carrier, flattop, attack aircraft carrier 405 | 404: airliner 406 | 405: airship, dirigible 407 | 406: altar 408 | 407: ambulance 409 | 408: amphibian, amphibious vehicle 410 | 409: analog clock 411 | 410: apiary, bee house 412 | 411: apron 413 | 412: ashcan, trash can, garbage can, wastebin, ash bin, ash-bin, ashbin, dustbin, trash barrel, trash bin 414 | 413: assault rifle, assault gun 415 | 414: backpack, back pack, knapsack, packsack, rucksack, haversack 416 | 415: bakery, bakeshop, bakehouse 417 | 416: balance beam, beam 418 | 417: balloon 419 | 418: ballpoint, ballpoint pen, ballpen, Biro 420 | 419: Band Aid 421 | 420: banjo 422 | 421: bannister, banister, balustrade, balusters, handrail 423 | 422: barbell 424 | 423: barber chair 425 | 424: barbershop 426 | 425: barn 427 | 426: barometer 428 | 427: barrel, cask 429 | 428: barrow, garden cart, lawn cart, wheelbarrow 430 | 429: baseball 431 | 430: basketball 432 | 431: bassinet 433 | 432: bassoon 434 | 433: bathing cap, swimming cap 435 | 434: bath towel 436 | 435: bathtub, bathing tub, bath, tub 437 | 436: beach wagon, station wagon, wagon, estate car, beach waggon, station waggon, waggon 438 | 437: beacon, lighthouse, beacon light, pharos 439 | 438: beaker 440 | 439: bearskin, busby, shako 441 | 440: beer bottle 442 | 441: beer glass 443 | 442: bell cote, bell cot 444 | 443: bib 445 | 444: bicycle-built-for-two, tandem bicycle, tandem 446 | 445: bikini, two-piece 447 | 446: binder, ring-binder 448 | 447: binoculars, field glasses, opera glasses 449 | 448: birdhouse 450 | 449: boathouse 451 | 450: bobsled, bobsleigh, bob 452 | 451: bolo tie, bolo, bola tie, bola 453 | 452: bonnet, poke bonnet 454 | 453: bookcase 455 | 454: bookshop, bookstore, bookstall 456 | 455: bottlecap 457 | 456: bow 458 | 457: bow 
tie, bow-tie, bowtie 459 | 458: brass, memorial tablet, plaque 460 | 459: brassiere, bra, bandeau 461 | 460: breakwater, groin, groyne, mole, bulwark, seawall, jetty 462 | 461: breastplate, aegis, egis 463 | 462: broom 464 | 463: bucket, pail 465 | 464: buckle 466 | 465: bulletproof vest 467 | 466: bullet train, bullet 468 | 467: butcher shop, meat market 469 | 468: cab, hack, taxi, taxicab 470 | 469: caldron, cauldron 471 | 470: candle, taper, wax light 472 | 471: cannon 473 | 472: canoe 474 | 473: can opener, tin opener 475 | 474: cardigan 476 | 475: car mirror 477 | 476: carousel, carrousel, merry-go-round, roundabout, whirligig 478 | 477: carpenter's kit, tool kit 479 | 478: carton 480 | 479: car wheel 481 | 480: cash machine, cash dispenser, automated teller machine, automatic teller machine, automated teller, automatic teller, ATM 482 | 481: cassette 483 | 482: cassette player 484 | 483: castle 485 | 484: catamaran 486 | 485: CD player 487 | 486: cello, violoncello 488 | 487: cellular telephone, cellular phone, cellphone, cell, mobile phone 489 | 488: chain 490 | 489: chainlink fence 491 | 490: chain mail, ring mail, mail, chain armor, chain armour, ring armor, ring armour 492 | 491: chain saw, chainsaw 493 | 492: chest 494 | 493: chiffonier, commode 495 | 494: chime, bell, gong 496 | 495: china cabinet, china closet 497 | 496: Christmas stocking 498 | 497: church, church building 499 | 498: cinema, movie theater, movie theatre, movie house, picture palace 500 | 499: cleaver, meat cleaver, chopper 501 | 500: cliff dwelling 502 | 501: cloak 503 | 502: clog, geta, patten, sabot 504 | 503: cocktail shaker 505 | 504: coffee mug 506 | 505: coffeepot 507 | 506: coil, spiral, volute, whorl, helix 508 | 507: combination lock 509 | 508: computer keyboard, keypad 510 | 509: confectionery, confectionary, candy store 511 | 510: container ship, containership, container vessel 512 | 511: convertible 513 | 512: corkscrew, bottle screw 514 | 513: cornet, horn, trumpet, trump 515 | 514: cowboy boot 516 | 515: cowboy hat, ten-gallon hat 517 | 516: cradle 518 | 517: crane 519 | 518: crash helmet 520 | 519: crate 521 | 520: crib, cot 522 | 521: Crock Pot 523 | 522: croquet ball 524 | 523: crutch 525 | 524: cuirass 526 | 525: dam, dike, dyke 527 | 526: desk 528 | 527: desktop computer 529 | 528: dial telephone, dial phone 530 | 529: diaper, nappy, napkin 531 | 530: digital clock 532 | 531: digital watch 533 | 532: dining table, board 534 | 533: dishrag, dishcloth 535 | 534: dishwasher, dish washer, dishwashing machine 536 | 535: disk brake, disc brake 537 | 536: dock, dockage, docking facility 538 | 537: dogsled, dog sled, dog sleigh 539 | 538: dome 540 | 539: doormat, welcome mat 541 | 540: drilling platform, offshore rig 542 | 541: drum, membranophone, tympan 543 | 542: drumstick 544 | 543: dumbbell 545 | 544: Dutch oven 546 | 545: electric fan, blower 547 | 546: electric guitar 548 | 547: electric locomotive 549 | 548: entertainment center 550 | 549: envelope 551 | 550: espresso maker 552 | 551: face powder 553 | 552: feather boa, boa 554 | 553: file, file cabinet, filing cabinet 555 | 554: fireboat 556 | 555: fire engine, fire truck 557 | 556: fire screen, fireguard 558 | 557: flagpole, flagstaff 559 | 558: flute, transverse flute 560 | 559: folding chair 561 | 560: football helmet 562 | 561: forklift 563 | 562: fountain 564 | 563: fountain pen 565 | 564: four-poster 566 | 565: freight car 567 | 566: French horn, horn 568 | 567: frying pan, frypan, skillet 569 | 568: fur coat 570 | 569: garbage 
truck, dustcart 571 | 570: gasmask, respirator, gas helmet 572 | 571: gas pump, gasoline pump, petrol pump, island dispenser 573 | 572: goblet 574 | 573: go-kart 575 | 574: golf ball 576 | 575: golfcart, golf cart 577 | 576: gondola 578 | 577: gong, tam-tam 579 | 578: gown 580 | 579: grand piano, grand 581 | 580: greenhouse, nursery, glasshouse 582 | 581: grille, radiator grille 583 | 582: grocery store, grocery, food market, market 584 | 583: guillotine 585 | 584: hair slide 586 | 585: hair spray 587 | 586: half track 588 | 587: hammer 589 | 588: hamper 590 | 589: hand blower, blow dryer, blow drier, hair dryer, hair drier 591 | 590: hand-held computer, hand-held microcomputer 592 | 591: handkerchief, hankie, hanky, hankey 593 | 592: hard disc, hard disk, fixed disk 594 | 593: harmonica, mouth organ, harp, mouth harp 595 | 594: harp 596 | 595: harvester, reaper 597 | 596: hatchet 598 | 597: holster 599 | 598: home theater, home theatre 600 | 599: honeycomb 601 | 600: hook, claw 602 | 601: hoopskirt, crinoline 603 | 602: horizontal bar, high bar 604 | 603: horse cart, horse-cart 605 | 604: hourglass 606 | 605: iPod 607 | 606: iron, smoothing iron 608 | 607: jack-o'-lantern 609 | 608: jean, blue jean, denim 610 | 609: jeep, landrover 611 | 610: jersey, T-shirt, tee shirt 612 | 611: jigsaw puzzle 613 | 612: jinrikisha, ricksha, rickshaw 614 | 613: joystick 615 | 614: kimono 616 | 615: knee pad 617 | 616: knot 618 | 617: lab coat, laboratory coat 619 | 618: ladle 620 | 619: lampshade, lamp shade 621 | 620: laptop, laptop computer 622 | 621: lawn mower, mower 623 | 622: lens cap, lens cover 624 | 623: letter opener, paper knife, paperknife 625 | 624: library 626 | 625: lifeboat 627 | 626: lighter, light, igniter, ignitor 628 | 627: limousine, limo 629 | 628: liner, ocean liner 630 | 629: lipstick, lip rouge 631 | 630: Loafer 632 | 631: lotion 633 | 632: loudspeaker, speaker, speaker unit, loudspeaker system, speaker system 634 | 633: loupe, jeweler's loupe 635 | 634: lumbermill, sawmill 636 | 635: magnetic compass 637 | 636: mailbag, postbag 638 | 637: mailbox, letter box 639 | 638: maillot 640 | 639: maillot, tank suit 641 | 640: manhole cover 642 | 641: maraca 643 | 642: marimba, xylophone 644 | 643: mask 645 | 644: matchstick 646 | 645: maypole 647 | 646: maze, labyrinth 648 | 647: measuring cup 649 | 648: medicine chest, medicine cabinet 650 | 649: megalith, megalithic structure 651 | 650: microphone, mike 652 | 651: microwave, microwave oven 653 | 652: military uniform 654 | 653: milk can 655 | 654: minibus 656 | 655: miniskirt, mini 657 | 656: minivan 658 | 657: missile 659 | 658: mitten 660 | 659: mixing bowl 661 | 660: mobile home, manufactured home 662 | 661: Model T 663 | 662: modem 664 | 663: monastery 665 | 664: monitor 666 | 665: moped 667 | 666: mortar 668 | 667: mortarboard 669 | 668: mosque 670 | 669: mosquito net 671 | 670: motor scooter, scooter 672 | 671: mountain bike, all-terrain bike, off-roader 673 | 672: mountain tent 674 | 673: mouse, computer mouse 675 | 674: mousetrap 676 | 675: moving van 677 | 676: muzzle 678 | 677: nail 679 | 678: neck brace 680 | 679: necklace 681 | 680: nipple 682 | 681: notebook, notebook computer 683 | 682: obelisk 684 | 683: oboe, hautboy, hautbois 685 | 684: ocarina, sweet potato 686 | 685: odometer, hodometer, mileometer, milometer 687 | 686: oil filter 688 | 687: organ, pipe organ 689 | 688: oscilloscope, scope, cathode-ray oscilloscope, CRO 690 | 689: overskirt 691 | 690: oxcart 692 | 691: oxygen mask 693 | 692: packet 694 | 693: paddle, 
boat paddle 695 | 694: paddlewheel, paddle wheel 696 | 695: padlock 697 | 696: paintbrush 698 | 697: pajama, pyjama, pj's, jammies 699 | 698: palace 700 | 699: panpipe, pandean pipe, syrinx 701 | 700: paper towel 702 | 701: parachute, chute 703 | 702: parallel bars, bars 704 | 703: park bench 705 | 704: parking meter 706 | 705: passenger car, coach, carriage 707 | 706: patio, terrace 708 | 707: pay-phone, pay-station 709 | 708: pedestal, plinth, footstall 710 | 709: pencil box, pencil case 711 | 710: pencil sharpener 712 | 711: perfume, essence 713 | 712: Petri dish 714 | 713: photocopier 715 | 714: pick, plectrum, plectron 716 | 715: pickelhaube 717 | 716: picket fence, paling 718 | 717: pickup, pickup truck 719 | 718: pier 720 | 719: piggy bank, penny bank 721 | 720: pill bottle 722 | 721: pillow 723 | 722: ping-pong ball 724 | 723: pinwheel 725 | 724: pirate, pirate ship 726 | 725: pitcher, ewer 727 | 726: plane, carpenter's plane, woodworking plane 728 | 727: planetarium 729 | 728: plastic bag 730 | 729: plate rack 731 | 730: plow, plough 732 | 731: plunger, plumber's helper 733 | 732: Polaroid camera, Polaroid Land camera 734 | 733: pole 735 | 734: police van, police wagon, paddy wagon, patrol wagon, wagon, black Maria 736 | 735: poncho 737 | 736: pool table, billiard table, snooker table 738 | 737: pop bottle, soda bottle 739 | 738: pot, flowerpot 740 | 739: potter's wheel 741 | 740: power drill 742 | 741: prayer rug, prayer mat 743 | 742: printer 744 | 743: prison, prison house 745 | 744: projectile, missile 746 | 745: projector 747 | 746: puck, hockey puck 748 | 747: punching bag, punch bag, punching ball, punchball 749 | 748: purse 750 | 749: quill, quill pen 751 | 750: quilt, comforter, comfort, puff 752 | 751: racer, race car, racing car 753 | 752: racket, racquet 754 | 753: radiator 755 | 754: radio, wireless 756 | 755: radio telescope, radio reflector 757 | 756: rain barrel 758 | 757: recreational vehicle, RV, R.V. 
759 | 758: reel 760 | 759: reflex camera 761 | 760: refrigerator, icebox 762 | 761: remote control, remote 763 | 762: restaurant, eating house, eating place, eatery 764 | 763: revolver, six-gun, six-shooter 765 | 764: rifle 766 | 765: rocking chair, rocker 767 | 766: rotisserie 768 | 767: rubber eraser, rubber, pencil eraser 769 | 768: rugby ball 770 | 769: rule, ruler 771 | 770: running shoe 772 | 771: safe 773 | 772: safety pin 774 | 773: saltshaker, salt shaker 775 | 774: sandal 776 | 775: sarong 777 | 776: sax, saxophone 778 | 777: scabbard 779 | 778: scale, weighing machine 780 | 779: school bus 781 | 780: schooner 782 | 781: scoreboard 783 | 782: screen, CRT screen 784 | 783: screw 785 | 784: screwdriver 786 | 785: seat belt, seatbelt 787 | 786: sewing machine 788 | 787: shield, buckler 789 | 788: shoe shop, shoe-shop, shoe store 790 | 789: shoji 791 | 790: shopping basket 792 | 791: shopping cart 793 | 792: shovel 794 | 793: shower cap 795 | 794: shower curtain 796 | 795: ski 797 | 796: ski mask 798 | 797: sleeping bag 799 | 798: slide rule, slipstick 800 | 799: sliding door 801 | 800: slot, one-armed bandit 802 | 801: snorkel 803 | 802: snowmobile 804 | 803: snowplow, snowplough 805 | 804: soap dispenser 806 | 805: soccer ball 807 | 806: sock 808 | 807: solar dish, solar collector, solar furnace 809 | 808: sombrero 810 | 809: soup bowl 811 | 810: space bar 812 | 811: space heater 813 | 812: space shuttle 814 | 813: spatula 815 | 814: speedboat 816 | 815: spider web, spider's web 817 | 816: spindle 818 | 817: sports car, sport car 819 | 818: spotlight, spot 820 | 819: stage 821 | 820: steam locomotive 822 | 821: steel arch bridge 823 | 822: steel drum 824 | 823: stethoscope 825 | 824: stole 826 | 825: stone wall 827 | 826: stopwatch, stop watch 828 | 827: stove 829 | 828: strainer 830 | 829: streetcar, tram, tramcar, trolley, trolley car 831 | 830: stretcher 832 | 831: studio couch, day bed 833 | 832: stupa, tope 834 | 833: submarine, pigboat, sub, U-boat 835 | 834: suit, suit of clothes 836 | 835: sundial 837 | 836: sunglass 838 | 837: sunglasses, dark glasses, shades 839 | 838: sunscreen, sunblock, sun blocker 840 | 839: suspension bridge 841 | 840: swab, swob, mop 842 | 841: sweatshirt 843 | 842: swimming trunks, bathing trunks 844 | 843: swing 845 | 844: switch, electric switch, electrical switch 846 | 845: syringe 847 | 846: table lamp 848 | 847: tank, army tank, armored combat vehicle, armoured combat vehicle 849 | 848: tape player 850 | 849: teapot 851 | 850: teddy, teddy bear 852 | 851: television, television system 853 | 852: tennis ball 854 | 853: thatch, thatched roof 855 | 854: theater curtain, theatre curtain 856 | 855: thimble 857 | 856: thresher, thrasher, threshing machine 858 | 857: throne 859 | 858: tile roof 860 | 859: toaster 861 | 860: tobacco shop, tobacconist shop, tobacconist 862 | 861: toilet seat 863 | 862: torch 864 | 863: totem pole 865 | 864: tow truck, tow car, wrecker 866 | 865: toyshop 867 | 866: tractor 868 | 867: trailer truck, tractor trailer, trucking rig, rig, articulated lorry, semi 869 | 868: tray 870 | 869: trench coat 871 | 870: tricycle, trike, velocipede 872 | 871: trimaran 873 | 872: tripod 874 | 873: triumphal arch 875 | 874: trolleybus, trolley coach, trackless trolley 876 | 875: trombone 877 | 876: tub, vat 878 | 877: turnstile 879 | 878: typewriter keyboard 880 | 879: umbrella 881 | 880: unicycle, monocycle 882 | 881: upright, upright piano 883 | 882: vacuum, vacuum cleaner 884 | 883: vase 885 | 884: vault 886 | 885: velvet 887 | 886: 
vending machine 888 | 887: vestment 889 | 888: viaduct 890 | 889: violin, fiddle 891 | 890: volleyball 892 | 891: waffle iron 893 | 892: wall clock 894 | 893: wallet, billfold, notecase, pocketbook 895 | 894: wardrobe, closet, press 896 | 895: warplane, military plane 897 | 896: washbasin, handbasin, washbowl, lavabo, wash-hand basin 898 | 897: washer, automatic washer, washing machine 899 | 898: water bottle 900 | 899: water jug 901 | 900: water tower 902 | 901: whiskey jug 903 | 902: whistle 904 | 903: wig 905 | 904: window screen 906 | 905: window shade 907 | 906: Windsor tie 908 | 907: wine bottle 909 | 908: wing 910 | 909: wok 911 | 910: wooden spoon 912 | 911: wool, woolen, woollen 913 | 912: worm fence, snake fence, snake-rail fence, Virginia fence 914 | 913: wreck 915 | 914: yawl 916 | 915: yurt 917 | 916: web site, website, internet site, site 918 | 917: comic book 919 | 918: crossword puzzle, crossword 920 | 919: street sign 921 | 920: traffic light, traffic signal, stoplight 922 | 921: book jacket, dust cover, dust jacket, dust wrapper 923 | 922: menu 924 | 923: plate 925 | 924: guacamole 926 | 925: consomme 927 | 926: hot pot, hotpot 928 | 927: trifle 929 | 928: ice cream, icecream 930 | 929: ice lolly, lolly, lollipop, popsicle 931 | 930: French loaf 932 | 931: bagel, beigel 933 | 932: pretzel 934 | 933: cheeseburger 935 | 934: hotdog, hot dog, red hot 936 | 935: mashed potato 937 | 936: head cabbage 938 | 937: broccoli 939 | 938: cauliflower 940 | 939: zucchini, courgette 941 | 940: spaghetti squash 942 | 941: acorn squash 943 | 942: butternut squash 944 | 943: cucumber, cuke 945 | 944: artichoke, globe artichoke 946 | 945: bell pepper 947 | 946: cardoon 948 | 947: mushroom 949 | 948: Granny Smith 950 | 949: strawberry 951 | 950: orange 952 | 951: lemon 953 | 952: fig 954 | 953: pineapple, ananas 955 | 954: banana 956 | 955: jackfruit, jak, jack 957 | 956: custard apple 958 | 957: pomegranate 959 | 958: hay 960 | 959: carbonara 961 | 960: chocolate sauce, chocolate syrup 962 | 961: dough 963 | 962: meat loaf, meatloaf 964 | 963: pizza, pizza pie 965 | 964: potpie 966 | 965: burrito 967 | 966: red wine 968 | 967: espresso 969 | 968: cup 970 | 969: eggnog 971 | 970: alp 972 | 971: bubble 973 | 972: cliff, drop, drop-off 974 | 973: coral reef 975 | 974: geyser 976 | 975: lakeside, lakeshore 977 | 976: promontory, headland, head, foreland 978 | 977: sandbar, sand bar 979 | 978: seashore, coast, seacoast, sea-coast 980 | 979: valley, vale 981 | 980: volcano 982 | 981: ballplayer, baseball player 983 | 982: groom, bridegroom 984 | 983: scuba diver 985 | 984: rapeseed 986 | 985: daisy 987 | 986: yellow lady's slipper, yellow lady-slipper, Cypripedium calceolus, Cypripedium parviflorum 988 | 987: corn 989 | 988: acorn 990 | 989: hip, rose hip, rosehip 991 | 990: buckeye, horse chestnut, conker 992 | 991: coral fungus 993 | 992: agaric 994 | 993: gyromitra 995 | 994: stinkhorn, carrion fungus 996 | 995: earthstar 997 | 996: hen-of-the-woods, hen of the woods, Polyporus frondosus, Grifola frondosa 998 | 997: bolete 999 | 998: ear, spike, capitulum 1000 | 999: toilet tissue, toilet paper, bathroom tissue 1001 | -------------------------------------------------------------------------------- /images/cat.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hiveml/tensorflow-grad-cam/9f4c9b7a5f9c94a0490b143282abc965ce1cfb0f/images/cat.jpg -------------------------------------------------------------------------------- 
/images/cat_heatmap.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hiveml/tensorflow-grad-cam/9f4c9b7a5f9c94a0490b143282abc965ce1cfb0f/images/cat_heatmap.png -------------------------------------------------------------------------------- /images/scarjo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hiveml/tensorflow-grad-cam/9f4c9b7a5f9c94a0490b143282abc965ce1cfb0f/images/scarjo.png -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import os,sys 2 | import tensorflow as tf 3 | import numpy as np 4 | from skimage import io 5 | from matplotlib import pyplot as plt 6 | import cv2 7 | 8 | from model.nets import nets_factory 9 | from model.preprocessing import preprocessing_factory 10 | 11 | flags = tf.app.flags 12 | flags.DEFINE_string("input", "images/cat.jpg", "Path to input image ['images/cat.jpg']") 13 | flags.DEFINE_string("output", "output.png", "Path to output image ['output.png']") 14 | flags.DEFINE_string("layer_name", None, "Layer till which to backpropagate") 15 | flags.DEFINE_string("model_name", "resnet_v2_50", "Name of the model") 16 | flags.DEFINE_string("preprocessing_name", None, "Name of the image preprocessor") 17 | flags.DEFINE_integer("eval_image_size", None, "Resize images to this size before eval") 18 | flags.DEFINE_string("dataset_dir", "./imagenet", "Location of the labels.txt") 19 | flags.DEFINE_string("checkpoint_path", "./imagenet/resnet_v2_50.ckpt", "saved weights for model") 20 | flags.DEFINE_integer("label_offset", 1, "Used for imagenet with 1001 classes for background class") 21 | 22 | FLAGS = flags.FLAGS 23 | 24 | slim = tf.contrib.slim 25 | 26 | _layer_names = { "resnet_v2_50": ["PrePool","predictions"], 27 | "resnet_v2_101": ["PrePool","predictions"], 28 | "resnet_v2_152": ["PrePool","predictions"], 29 | } 30 | 31 | _logits_name = "Logits" 32 | 33 | def load_labels_from_file(dataset_dir): 34 | labels = {} 35 | labels_name = os.path.join(dataset_dir,'labels.txt') 36 | with open(labels_name) as label_file: 37 | for line in label_file: 38 | idx,label = line.rstrip('\n').split(':') 39 | labels[int(idx)] = label 40 | assert len(labels) > 1 41 | return labels 42 | 43 | 44 | def load_image(img_path): 45 | print("Loading image") 46 | img = cv2.imread(img_path) 47 | if img is None: 48 | sys.stderr.write('Unable to load img: %s\n' % img_path) 49 | sys.exit(1) 50 | img = cv2.cvtColor(img,cv2.COLOR_BGR2RGB) 51 | return img 52 | 53 | 54 | def preprocess_image(image,eval_image_size): 55 | preprocessing_name = FLAGS.preprocessing_name or FLAGS.model_name 56 | image_preprocessing_fn = preprocessing_factory.get_preprocessing( 57 | preprocessing_name, is_training=False) 58 | image = image_preprocessing_fn(image, eval_image_size, eval_image_size) 59 | return image 60 | 61 | def grad_cam(img, imgs0, end_points, sess, predicted_class, layer_name, nb_classes, eval_image_size): 62 | # Conv layer tensor [?,10,10,2048] 63 | conv_layer = end_points[layer_name] 64 | # [1000]-D tensor with target class index set to 1 and rest as 0 65 | one_hot = tf.sparse_to_dense(predicted_class, [nb_classes], 1.0) 66 | signal = tf.multiply(end_points[_logits_name], one_hot) 67 | loss = tf.reduce_mean(signal) 68 | 69 | grads = tf.gradients(loss, conv_layer)[0] 70 | # Normalizing the gradients 71 | norm_grads = tf.divide(grads, 
tf.sqrt(tf.reduce_mean(tf.square(grads))) + tf.constant(1e-5)) 72 | 73 | output, grads_val = sess.run([conv_layer, norm_grads], feed_dict={imgs0: img}) 74 | output = output[0] # [10,10,2048] 75 | grads_val = grads_val[0] # [10,10,2048] 76 | 77 | weights = np.mean(grads_val, axis = (0, 1)) # [2048] 78 | cam = np.ones(output.shape[0 : 2], dtype = np.float32) # [10,10] 79 | 80 | # Taking a weighted average 81 | for i, w in enumerate(weights): 82 | cam += w * output[:, :, i] 83 | 84 | # Passing through ReLU 85 | cam = np.maximum(cam, 0) 86 | cam = cam / np.max(cam) 87 | cam3 = cv2.resize(cam, (eval_image_size,eval_image_size)) 88 | 89 | return cam3 90 | 91 | 92 | def main(_): 93 | checkpoint_path=FLAGS.checkpoint_path 94 | img = load_image(FLAGS.input) 95 | 96 | labels = load_labels_from_file(FLAGS.dataset_dir) 97 | num_classes = len(labels) + FLAGS.label_offset 98 | 99 | network_fn = nets_factory.get_network_fn( 100 | FLAGS.model_name, 101 | num_classes=num_classes, 102 | is_training=False) 103 | 104 | eval_image_size = FLAGS.eval_image_size or network_fn.default_image_size 105 | 106 | print("\nLoading Model") 107 | imgs0 = tf.placeholder(tf.uint8, [None,None, 3]) 108 | imgs = preprocess_image(imgs0,eval_image_size) 109 | imgs = tf.expand_dims(imgs,0) 110 | 111 | _,end_points = network_fn(imgs) 112 | 113 | init_fn = slim.assign_from_checkpoint_fn(checkpoint_path, slim.get_variables_to_restore()) 114 | 115 | print("\nFeedforwarding") 116 | 117 | with tf.Session() as sess: 118 | init_fn(sess) 119 | 120 | ep = sess.run(end_points, feed_dict={imgs0: img}) 121 | pred_layer_name = _layer_names[FLAGS.model_name][1] 122 | probs = ep[pred_layer_name][0] 123 | 124 | preds = (np.argsort(probs)[::-1])[0:5] 125 | print('\nTop 5 classes are') 126 | for p in preds: 127 | print(labels[p-FLAGS.label_offset], probs[p]) 128 | 129 | # Target class 130 | predicted_class = preds[0] 131 | # Target layer for visualization 132 | layer_name = FLAGS.layer_name or _layer_names[FLAGS.model_name][0] 133 | # Number of output classes of model being used 134 | nb_classes = num_classes 135 | 136 | cam3 = grad_cam(img, imgs0, end_points, sess, predicted_class, layer_name, nb_classes, eval_image_size) 137 | 138 | img = cv2.resize(img,(eval_image_size,eval_image_size)) 139 | img = img.astype(float) 140 | img /= img.max() 141 | 142 | 143 | cam3 = cv2.applyColorMap(np.uint8(255*cam3), cv2.COLORMAP_JET) 144 | cam3 = cv2.cvtColor(cam3, cv2.COLOR_BGR2RGB) 145 | 146 | # Superimposing the visualization with the image. 
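# Note: cam3 here holds 8-bit colormap values in [0, 255], while img has just
# been normalized to [0, 1]; presumably the very small alpha below is what keeps
# the heatmap overlay from overwhelming the original image before the final
# re-normalization.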
147 | alpha = 0.0025 148 | new_img = img+alpha*cam3 149 | new_img /= new_img.max() 150 | 151 | # Display and save 152 | io.imshow(new_img) 153 | plt.axis('off') 154 | plt.savefig(FLAGS.output,bbox_inches='tight') 155 | plt.show() 156 | 157 | if __name__ == '__main__': 158 | tf.app.run() 159 | 160 | -------------------------------------------------------------------------------- /main.sh: -------------------------------------------------------------------------------- 1 | python main.py --model_name=resnet_v2_50 --dataset_dir=./imagenet/ --checkpoint_path=./imagenet/resnet_v2_50.ckpt --input=./images/cat.jpg --eval_image_size=299 2 | -------------------------------------------------------------------------------- /model/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hiveml/tensorflow-grad-cam/9f4c9b7a5f9c94a0490b143282abc965ce1cfb0f/model/__init__.py -------------------------------------------------------------------------------- /model/nets/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hiveml/tensorflow-grad-cam/9f4c9b7a5f9c94a0490b143282abc965ce1cfb0f/model/nets/__init__.py -------------------------------------------------------------------------------- /model/nets/nets_factory.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Contains a factory for building various models.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | import functools 21 | 22 | import tensorflow as tf 23 | 24 | #from nets import alexnet 25 | #from nets import cifarnet 26 | #from nets import inception 27 | #from nets import lenet 28 | #from nets import mobilenet_v1 29 | #from nets import overfeat 30 | #from nets import resnet_v1 31 | from model.nets import resnet_v2 32 | #from nets import vgg 33 | #from nets.nasnet import nasnet 34 | 35 | slim = tf.contrib.slim 36 | 37 | networks_map = { 38 | 'resnet_v2_50': resnet_v2.resnet_v2_50, 39 | 'resnet_v2_101': resnet_v2.resnet_v2_101, 40 | 'resnet_v2_152': resnet_v2.resnet_v2_152, 41 | 'resnet_v2_200': resnet_v2.resnet_v2_200, 42 | } 43 | 44 | arg_scopes_map = { 45 | 'resnet_v2_50': resnet_v2.resnet_arg_scope, 46 | 'resnet_v2_101': resnet_v2.resnet_arg_scope, 47 | 'resnet_v2_152': resnet_v2.resnet_arg_scope, 48 | 'resnet_v2_200': resnet_v2.resnet_arg_scope, 49 | } 50 | 51 | 52 | def get_network_fn(name, num_classes, weight_decay=0.0, is_training=False): 53 | """Returns a network_fn such as `logits, end_points = network_fn(images)`. 54 | 55 | Args: 56 | name: The name of the network. 57 | num_classes: The number of classes to use for classification. 
If 0 or None, 58 | the logits layer is omitted and its input features are returned instead. 59 | weight_decay: The l2 coefficient for the model weights. 60 | is_training: `True` if the model is being used for training and `False` 61 | otherwise. 62 | 63 | Returns: 64 | network_fn: A function that applies the model to a batch of images. It has 65 | the following signature: 66 | net, end_points = network_fn(images) 67 | The `images` input is a tensor of shape [batch_size, height, width, 3] 68 | with height = width = network_fn.default_image_size. (The permissibility 69 | and treatment of other sizes depends on the network_fn.) 70 | The returned `end_points` are a dictionary of intermediate activations. 71 | The returned `net` is the topmost layer, depending on `num_classes`: 72 | If `num_classes` was a non-zero integer, `net` is a logits tensor 73 | of shape [batch_size, num_classes]. 74 | If `num_classes` was 0 or `None`, `net` is a tensor with the input 75 | to the logits layer of shape [batch_size, 1, 1, num_features] or 76 | [batch_size, num_features]. Dropout has not been applied to this 77 | (even if the network's original classification does); it remains for 78 | the caller to do this or not. 79 | 80 | Raises: 81 | ValueError: If network `name` is not recognized. 82 | """ 83 | if name not in networks_map: 84 | raise ValueError('Name of network unknown %s' % name) 85 | func = networks_map[name] 86 | @functools.wraps(func) 87 | def network_fn(images, **kwargs): 88 | arg_scope = arg_scopes_map[name](weight_decay=weight_decay) 89 | with slim.arg_scope(arg_scope): 90 | return func(images, num_classes, is_training=is_training, **kwargs) 91 | if hasattr(func, 'default_image_size'): 92 | network_fn.default_image_size = func.default_image_size 93 | 94 | return network_fn 95 | 96 | -------------------------------------------------------------------------------- /model/nets/resnet_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Contains building blocks for various versions of Residual Networks. 16 | 17 | Residual networks (ResNets) were proposed in: 18 | Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun 19 | Deep Residual Learning for Image Recognition. arXiv:1512.03385, 2015 20 | 21 | More variants were introduced in: 22 | Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun 23 | Identity Mappings in Deep Residual Networks. arXiv: 1603.05027, 2016 24 | 25 | We can obtain different ResNet variants by changing the network depth, width, 26 | and form of residual unit. This module implements the infrastructure for 27 | building them. Concrete ResNet units and full ResNet networks are implemented in 28 | the accompanying resnet_v1.py and resnet_v2.py modules. 
29 | 30 | Compared to https://github.com/KaimingHe/deep-residual-networks, in the current 31 | implementation we subsample the output activations in the last residual unit of 32 | each block, instead of subsampling the input activations in the first residual 33 | unit of each block. The two implementations give identical results but our 34 | implementation is more memory efficient. 35 | """ 36 | from __future__ import absolute_import 37 | from __future__ import division 38 | from __future__ import print_function 39 | 40 | import collections 41 | import tensorflow as tf 42 | 43 | slim = tf.contrib.slim 44 | 45 | 46 | class Block(collections.namedtuple('Block', ['scope', 'unit_fn', 'args'])): 47 | """A named tuple describing a ResNet block. 48 | 49 | Its parts are: 50 | scope: The scope of the `Block`. 51 | unit_fn: The ResNet unit function which takes as input a `Tensor` and 52 | returns another `Tensor` with the output of the ResNet unit. 53 | args: A list of length equal to the number of units in the `Block`. The list 54 | contains one (depth, depth_bottleneck, stride) tuple for each unit in the 55 | block to serve as argument to unit_fn. 56 | """ 57 | 58 | 59 | def subsample(inputs, factor, scope=None): 60 | """Subsamples the input along the spatial dimensions. 61 | 62 | Args: 63 | inputs: A `Tensor` of size [batch, height_in, width_in, channels]. 64 | factor: The subsampling factor. 65 | scope: Optional variable_scope. 66 | 67 | Returns: 68 | output: A `Tensor` of size [batch, height_out, width_out, channels] with the 69 | input, either intact (if factor == 1) or subsampled (if factor > 1). 70 | """ 71 | if factor == 1: 72 | return inputs 73 | else: 74 | return slim.max_pool2d(inputs, [1, 1], stride=factor, scope=scope) 75 | 76 | 77 | def conv2d_same(inputs, num_outputs, kernel_size, stride, rate=1, scope=None): 78 | """Strided 2-D convolution with 'SAME' padding. 79 | 80 | When stride > 1, then we do explicit zero-padding, followed by conv2d with 81 | 'VALID' padding. 82 | 83 | Note that 84 | 85 | net = conv2d_same(inputs, num_outputs, 3, stride=stride) 86 | 87 | is equivalent to 88 | 89 | net = slim.conv2d(inputs, num_outputs, 3, stride=1, padding='SAME') 90 | net = subsample(net, factor=stride) 91 | 92 | whereas 93 | 94 | net = slim.conv2d(inputs, num_outputs, 3, stride=stride, padding='SAME') 95 | 96 | is different when the input's height or width is even, which is why we add the 97 | current function. For more details, see ResnetUtilsTest.testConv2DSameEven(). 98 | 99 | Args: 100 | inputs: A 4-D tensor of size [batch, height_in, width_in, channels]. 101 | num_outputs: An integer, the number of output filters. 102 | kernel_size: An int with the kernel_size of the filters. 103 | stride: An integer, the output stride. 104 | rate: An integer, rate for atrous convolution. 105 | scope: Scope. 106 | 107 | Returns: 108 | output: A 4-D tensor of size [batch, height_out, width_out, channels] with 109 | the convolution output. 
110 | """ 111 | if stride == 1: 112 | return slim.conv2d(inputs, num_outputs, kernel_size, stride=1, rate=rate, 113 | padding='SAME', scope=scope) 114 | else: 115 | kernel_size_effective = kernel_size + (kernel_size - 1) * (rate - 1) 116 | pad_total = kernel_size_effective - 1 117 | pad_beg = pad_total // 2 118 | pad_end = pad_total - pad_beg 119 | inputs = tf.pad(inputs, 120 | [[0, 0], [pad_beg, pad_end], [pad_beg, pad_end], [0, 0]]) 121 | return slim.conv2d(inputs, num_outputs, kernel_size, stride=stride, 122 | rate=rate, padding='VALID', scope=scope) 123 | 124 | 125 | @slim.add_arg_scope 126 | def stack_blocks_dense(net, blocks, output_stride=None, 127 | outputs_collections=None): 128 | """Stacks ResNet `Blocks` and controls output feature density. 129 | 130 | First, this function creates scopes for the ResNet in the form of 131 | 'block_name/unit_1', 'block_name/unit_2', etc. 132 | 133 | Second, this function allows the user to explicitly control the ResNet 134 | output_stride, which is the ratio of the input to output spatial resolution. 135 | This is useful for dense prediction tasks such as semantic segmentation or 136 | object detection. 137 | 138 | Most ResNets consist of 4 ResNet blocks and subsample the activations by a 139 | factor of 2 when transitioning between consecutive ResNet blocks. This results 140 | to a nominal ResNet output_stride equal to 8. If we set the output_stride to 141 | half the nominal network stride (e.g., output_stride=4), then we compute 142 | responses twice. 143 | 144 | Control of the output feature density is implemented by atrous convolution. 145 | 146 | Args: 147 | net: A `Tensor` of size [batch, height, width, channels]. 148 | blocks: A list of length equal to the number of ResNet `Blocks`. Each 149 | element is a ResNet `Block` object describing the units in the `Block`. 150 | output_stride: If `None`, then the output will be computed at the nominal 151 | network stride. If output_stride is not `None`, it specifies the requested 152 | ratio of input to output spatial resolution, which needs to be equal to 153 | the product of unit strides from the start up to some level of the ResNet. 154 | For example, if the ResNet employs units with strides 1, 2, 1, 3, 4, 1, 155 | then valid values for the output_stride are 1, 2, 6, 24 or None (which 156 | is equivalent to output_stride=24). 157 | outputs_collections: Collection to add the ResNet block outputs. 158 | 159 | Returns: 160 | net: Output tensor with stride equal to the specified output_stride. 161 | 162 | Raises: 163 | ValueError: If the target output_stride is not valid. 164 | """ 165 | # The current_stride variable keeps track of the effective stride of the 166 | # activations. This allows us to invoke atrous convolution whenever applying 167 | # the next residual unit would result in the activations having stride larger 168 | # than the target output_stride. 169 | current_stride = 1 170 | 171 | # The atrous convolution rate parameter. 
172 | rate = 1 173 | 174 | for block in blocks: 175 | with tf.variable_scope(block.scope, 'block', [net]) as sc: 176 | for i, unit in enumerate(block.args): 177 | if output_stride is not None and current_stride > output_stride: 178 | raise ValueError('The target output_stride cannot be reached.') 179 | 180 | with tf.variable_scope('unit_%d' % (i + 1), values=[net]): 181 | # If we have reached the target output_stride, then we need to employ 182 | # atrous convolution with stride=1 and multiply the atrous rate by the 183 | # current unit's stride for use in subsequent layers. 184 | if output_stride is not None and current_stride == output_stride: 185 | net = block.unit_fn(net, rate=rate, **dict(unit, stride=1)) 186 | rate *= unit.get('stride', 1) 187 | 188 | else: 189 | net = block.unit_fn(net, rate=1, **unit) 190 | current_stride *= unit.get('stride', 1) 191 | net = slim.utils.collect_named_outputs(outputs_collections, sc.name, net) 192 | 193 | if output_stride is not None and current_stride != output_stride: 194 | raise ValueError('The target output_stride cannot be reached.') 195 | 196 | return net 197 | 198 | 199 | def resnet_arg_scope(weight_decay=0.0001, 200 | batch_norm_decay=0.997, 201 | batch_norm_epsilon=1e-5, 202 | batch_norm_scale=True, 203 | activation_fn=tf.nn.relu, 204 | use_batch_norm=True): 205 | """Defines the default ResNet arg scope. 206 | 207 | TODO(gpapan): The batch-normalization related default values above are 208 | appropriate for use in conjunction with the reference ResNet models 209 | released at https://github.com/KaimingHe/deep-residual-networks. When 210 | training ResNets from scratch, they might need to be tuned. 211 | 212 | Args: 213 | weight_decay: The weight decay to use for regularizing the model. 214 | batch_norm_decay: The moving average decay when estimating layer activation 215 | statistics in batch normalization. 216 | batch_norm_epsilon: Small constant to prevent division by zero when 217 | normalizing activations by their variance in batch normalization. 218 | batch_norm_scale: If True, uses an explicit `gamma` multiplier to scale the 219 | activations in the batch normalization layer. 220 | activation_fn: The activation function which is used in ResNet. 221 | use_batch_norm: Whether or not to use batch normalization. 222 | 223 | Returns: 224 | An `arg_scope` to use for the resnet models. 225 | """ 226 | batch_norm_params = { 227 | 'decay': batch_norm_decay, 228 | 'epsilon': batch_norm_epsilon, 229 | 'scale': batch_norm_scale, 230 | 'updates_collections': tf.GraphKeys.UPDATE_OPS, 231 | 'fused': None, # Use fused batch norm if possible. 232 | } 233 | 234 | with slim.arg_scope( 235 | [slim.conv2d], 236 | weights_regularizer=slim.l2_regularizer(weight_decay), 237 | weights_initializer=slim.variance_scaling_initializer(), 238 | activation_fn=activation_fn, 239 | normalizer_fn=slim.batch_norm if use_batch_norm else None, 240 | normalizer_params=batch_norm_params): 241 | with slim.arg_scope([slim.batch_norm], **batch_norm_params): 242 | # The following implies padding='SAME' for pool1, which makes feature 243 | # alignment easier for dense prediction tasks. This is also used in 244 | # https://github.com/facebook/fb.resnet.torch. However the accompanying 245 | # code of 'Deep Residual Learning for Image Recognition' uses 246 | # padding='VALID' for pool1. You can switch to that choice by setting 247 | # slim.arg_scope([slim.max_pool2d], padding='VALID'). 
248 | with slim.arg_scope([slim.max_pool2d], padding='SAME') as arg_sc: 249 | return arg_sc 250 | 251 | -------------------------------------------------------------------------------- /model/nets/resnet_v2.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Contains definitions for the preactivation form of Residual Networks. 16 | 17 | Residual networks (ResNets) were originally proposed in: 18 | [1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun 19 | Deep Residual Learning for Image Recognition. arXiv:1512.03385 20 | 21 | The full preactivation 'v2' ResNet variant implemented in this module was 22 | introduced by: 23 | [2] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun 24 | Identity Mappings in Deep Residual Networks. arXiv: 1603.05027 25 | 26 | The key difference of the full preactivation 'v2' variant compared to the 27 | 'v1' variant in [1] is the use of batch normalization before every weight layer. 28 | 29 | Typical use: 30 | 31 | from tensorflow.contrib.slim.nets import resnet_v2 32 | 33 | ResNet-101 for image classification into 1000 classes: 34 | 35 | # inputs has shape [batch, 224, 224, 3] 36 | with slim.arg_scope(resnet_v2.resnet_arg_scope()): 37 | net, end_points = resnet_v2.resnet_v2_101(inputs, 1000, is_training=False) 38 | 39 | ResNet-101 for semantic segmentation into 21 classes: 40 | 41 | # inputs has shape [batch, 513, 513, 3] 42 | with slim.arg_scope(resnet_v2.resnet_arg_scope()): 43 | net, end_points = resnet_v2.resnet_v2_101(inputs, 44 | 21, 45 | is_training=False, 46 | global_pool=False, 47 | output_stride=16) 48 | """ 49 | from __future__ import absolute_import 50 | from __future__ import division 51 | from __future__ import print_function 52 | 53 | import tensorflow as tf 54 | 55 | from model.nets import resnet_utils 56 | 57 | slim = tf.contrib.slim 58 | resnet_arg_scope = resnet_utils.resnet_arg_scope 59 | 60 | 61 | @slim.add_arg_scope 62 | def bottleneck(inputs, depth, depth_bottleneck, stride, rate=1, 63 | outputs_collections=None, scope=None): 64 | """Bottleneck residual unit variant with BN before convolutions. 65 | 66 | This is the full preactivation residual unit variant proposed in [2]. See 67 | Fig. 1(b) of [2] for its definition. Note that we use here the bottleneck 68 | variant which has an extra bottleneck layer. 69 | 70 | When putting together two consecutive ResNet blocks that use this unit, one 71 | should use stride = 2 in the last unit of the first block. 72 | 73 | Args: 74 | inputs: A tensor of size [batch, height, width, channels]. 75 | depth: The depth of the ResNet unit output. 76 | depth_bottleneck: The depth of the bottleneck layers. 77 | stride: The ResNet unit's stride. Determines the amount of downsampling of 78 | the units output compared to its input. 
79 | rate: An integer, rate for atrous convolution. 80 | outputs_collections: Collection to add the ResNet unit output. 81 | scope: Optional variable_scope. 82 | 83 | Returns: 84 | The ResNet unit's output. 85 | """ 86 | with tf.variable_scope(scope, 'bottleneck_v2', [inputs]) as sc: 87 | depth_in = slim.utils.last_dimension(inputs.get_shape(), min_rank=4) 88 | preact = slim.batch_norm(inputs, activation_fn=tf.nn.relu, scope='preact') 89 | if depth == depth_in: 90 | shortcut = resnet_utils.subsample(inputs, stride, 'shortcut') 91 | else: 92 | shortcut = slim.conv2d(preact, depth, [1, 1], stride=stride, 93 | normalizer_fn=None, activation_fn=None, 94 | scope='shortcut') 95 | 96 | residual = slim.conv2d(preact, depth_bottleneck, [1, 1], stride=1, 97 | scope='conv1') 98 | residual = resnet_utils.conv2d_same(residual, depth_bottleneck, 3, stride, 99 | rate=rate, scope='conv2') 100 | residual = slim.conv2d(residual, depth, [1, 1], stride=1, 101 | normalizer_fn=None, activation_fn=None, 102 | scope='conv3') 103 | 104 | output = shortcut + residual 105 | 106 | return slim.utils.collect_named_outputs(outputs_collections, 107 | sc.name, 108 | output) 109 | 110 | 111 | def resnet_v2(inputs, 112 | blocks, 113 | num_classes=None, 114 | is_training=True, 115 | global_pool=True, 116 | output_stride=None, 117 | include_root_block=True, 118 | spatial_squeeze=True, 119 | reuse=None, 120 | scope=None): 121 | """Generator for v2 (preactivation) ResNet models. 122 | 123 | This function generates a family of ResNet v2 models. See the resnet_v2_*() 124 | methods for specific model instantiations, obtained by selecting different 125 | block instantiations that produce ResNets of various depths. 126 | 127 | Training for image classification on Imagenet is usually done with [224, 224] 128 | inputs, resulting in [7, 7] feature maps at the output of the last ResNet 129 | block for the ResNets defined in [1] that have nominal stride equal to 32. 130 | However, for dense prediction tasks we advise that one uses inputs with 131 | spatial dimensions that are multiples of 32 plus 1, e.g., [321, 321]. In 132 | this case the feature maps at the ResNet output will have spatial shape 133 | [(height - 1) / output_stride + 1, (width - 1) / output_stride + 1] 134 | and corners exactly aligned with the input image corners, which greatly 135 | facilitates alignment of the features to the image. Using as input [225, 225] 136 | images results in [8, 8] feature maps at the output of the last ResNet block. 137 | 138 | For dense prediction tasks, the ResNet needs to run in fully-convolutional 139 | (FCN) mode and global_pool needs to be set to False. The ResNets in [1, 2] all 140 | have nominal stride equal to 32 and a good choice in FCN mode is to use 141 | output_stride=16 in order to increase the density of the computed features at 142 | small computational and memory overhead, cf. http://arxiv.org/abs/1606.00915. 143 | 144 | Args: 145 | inputs: A tensor of size [batch, height_in, width_in, channels]. 146 | blocks: A list of length equal to the number of ResNet blocks. Each element 147 | is a resnet_utils.Block object describing the units in the block. 148 | num_classes: Number of predicted classes for classification tasks. 149 | If 0 or None, we return the features before the logit layer. 150 | is_training: whether batch_norm layers are in training mode. 151 | global_pool: If True, we perform global average pooling before computing the 152 | logits. Set to True for image classification, False for dense prediction. 
153 | output_stride: If None, then the output will be computed at the nominal 154 | network stride. If output_stride is not None, it specifies the requested 155 | ratio of input to output spatial resolution. 156 | include_root_block: If True, include the initial convolution followed by 157 | max-pooling, if False excludes it. If excluded, `inputs` should be the 158 | results of an activation-less convolution. 159 | spatial_squeeze: if True, logits is of shape [B, C], if false logits is 160 | of shape [B, 1, 1, C], where B is batch_size and C is number of classes. 161 | To use this parameter, the input images must be smaller than 300x300 162 | pixels, in which case the output logit layer does not contain spatial 163 | information and can be removed. 164 | reuse: whether or not the network and its variables should be reused. To be 165 | able to reuse 'scope' must be given. 166 | scope: Optional variable_scope. 167 | 168 | 169 | Returns: 170 | net: A rank-4 tensor of size [batch, height_out, width_out, channels_out]. 171 | If global_pool is False, then height_out and width_out are reduced by a 172 | factor of output_stride compared to the respective height_in and width_in, 173 | else both height_out and width_out equal one. If num_classes is 0 or None, 174 | then net is the output of the last ResNet block, potentially after global 175 | average pooling. If num_classes is a non-zero integer, net contains the 176 | pre-softmax activations. 177 | end_points: A dictionary from components of the network to the corresponding 178 | activation. 179 | 180 | Raises: 181 | ValueError: If the target output_stride is not valid. 182 | """ 183 | with tf.variable_scope(scope, 'resnet_v2', [inputs], reuse=reuse) as sc: 184 | end_points_collection = sc.original_name_scope + '_end_points' 185 | with slim.arg_scope([slim.conv2d, bottleneck, 186 | resnet_utils.stack_blocks_dense], 187 | outputs_collections=end_points_collection): 188 | with slim.arg_scope([slim.batch_norm], is_training=is_training): 189 | net = inputs 190 | if include_root_block: 191 | if output_stride is not None: 192 | if output_stride % 4 != 0: 193 | raise ValueError('The output_stride needs to be a multiple of 4.') 194 | output_stride /= 4 195 | # We do not include batch normalization or activation functions in 196 | # conv1 because the first ResNet unit will perform these. Cf. 197 | # Appendix of [2]. 198 | with slim.arg_scope([slim.conv2d], 199 | activation_fn=None, normalizer_fn=None): 200 | net = resnet_utils.conv2d_same(net, 64, 7, stride=2, scope='conv1') 201 | net = slim.max_pool2d(net, [3, 3], stride=2, scope='pool1') 202 | net = resnet_utils.stack_blocks_dense(net, blocks, output_stride) 203 | # This is needed because the pre-activation variant does not have batch 204 | # normalization or activation functions in the residual unit output. See 205 | # Appendix of [2]. 206 | net = slim.batch_norm(net, activation_fn=tf.nn.relu, scope='postnorm') 207 | # Convert end_points_collection into a dictionary of end_points. 208 | end_points = slim.utils.convert_collection_to_dict( 209 | end_points_collection) 210 | 211 | end_points['PrePool'] = net 212 | if global_pool: 213 | # Global average pooling. 
214 | net = tf.reduce_mean(net, [1, 2], name='pool5', keepdims=True) 215 | end_points['global_pool'] = net 216 | if num_classes is not None: 217 | net = slim.conv2d(net, num_classes, [1, 1], activation_fn=None, 218 | normalizer_fn=None, scope='logits') 219 | end_points[sc.name + '/logits'] = net 220 | if spatial_squeeze: 221 | net = tf.squeeze(net, [1, 2], name='SpatialSqueeze') 222 | end_points[sc.name + '/spatial_squeeze'] = net 223 | end_points['Logits'] = net 224 | end_points['predictions'] = slim.softmax(net, scope='predictions') 225 | return net, end_points 226 | resnet_v2.default_image_size = 224 227 | 228 | 229 | def resnet_v2_block(scope, base_depth, num_units, stride): 230 | """Helper function for creating a resnet_v2 bottleneck block. 231 | 232 | Args: 233 | scope: The scope of the block. 234 | base_depth: The depth of the bottleneck layer for each unit. 235 | num_units: The number of units in the block. 236 | stride: The stride of the block, implemented as a stride in the last unit. 237 | All other units have stride=1. 238 | 239 | Returns: 240 | A resnet_v2 bottleneck block. 241 | """ 242 | return resnet_utils.Block(scope, bottleneck, [{ 243 | 'depth': base_depth * 4, 244 | 'depth_bottleneck': base_depth, 245 | 'stride': 1 246 | }] * (num_units - 1) + [{ 247 | 'depth': base_depth * 4, 248 | 'depth_bottleneck': base_depth, 249 | 'stride': stride 250 | }]) 251 | resnet_v2.default_image_size = 224 252 | 253 | 254 | def resnet_v2_50(inputs, 255 | num_classes=None, 256 | is_training=True, 257 | global_pool=True, 258 | output_stride=None, 259 | spatial_squeeze=True, 260 | reuse=None, 261 | scope='resnet_v2_50'): 262 | """ResNet-50 model of [1]. See resnet_v2() for arg and return description.""" 263 | blocks = [ 264 | resnet_v2_block('block1', base_depth=64, num_units=3, stride=2), 265 | resnet_v2_block('block2', base_depth=128, num_units=4, stride=2), 266 | resnet_v2_block('block3', base_depth=256, num_units=6, stride=2), 267 | resnet_v2_block('block4', base_depth=512, num_units=3, stride=1), 268 | ] 269 | return resnet_v2(inputs, blocks, num_classes, is_training=is_training, 270 | global_pool=global_pool, output_stride=output_stride, 271 | include_root_block=True, spatial_squeeze=spatial_squeeze, 272 | reuse=reuse, scope=scope) 273 | resnet_v2_50.default_image_size = resnet_v2.default_image_size 274 | 275 | 276 | def resnet_v2_101(inputs, 277 | num_classes=None, 278 | is_training=True, 279 | global_pool=True, 280 | output_stride=None, 281 | spatial_squeeze=True, 282 | reuse=None, 283 | scope='resnet_v2_101'): 284 | """ResNet-101 model of [1]. See resnet_v2() for arg and return description.""" 285 | blocks = [ 286 | resnet_v2_block('block1', base_depth=64, num_units=3, stride=2), 287 | resnet_v2_block('block2', base_depth=128, num_units=4, stride=2), 288 | resnet_v2_block('block3', base_depth=256, num_units=23, stride=2), 289 | resnet_v2_block('block4', base_depth=512, num_units=3, stride=1), 290 | ] 291 | return resnet_v2(inputs, blocks, num_classes, is_training=is_training, 292 | global_pool=global_pool, output_stride=output_stride, 293 | include_root_block=True, spatial_squeeze=spatial_squeeze, 294 | reuse=reuse, scope=scope) 295 | resnet_v2_101.default_image_size = resnet_v2.default_image_size 296 | 297 | 298 | def resnet_v2_152(inputs, 299 | num_classes=None, 300 | is_training=True, 301 | global_pool=True, 302 | output_stride=None, 303 | spatial_squeeze=True, 304 | reuse=None, 305 | scope='resnet_v2_152'): 306 | """ResNet-152 model of [1]. 
See resnet_v2() for arg and return description.""" 307 | blocks = [ 308 | resnet_v2_block('block1', base_depth=64, num_units=3, stride=2), 309 | resnet_v2_block('block2', base_depth=128, num_units=8, stride=2), 310 | resnet_v2_block('block3', base_depth=256, num_units=36, stride=2), 311 | resnet_v2_block('block4', base_depth=512, num_units=3, stride=1), 312 | ] 313 | return resnet_v2(inputs, blocks, num_classes, is_training=is_training, 314 | global_pool=global_pool, output_stride=output_stride, 315 | include_root_block=True, spatial_squeeze=spatial_squeeze, 316 | reuse=reuse, scope=scope) 317 | resnet_v2_152.default_image_size = resnet_v2.default_image_size 318 | 319 | 320 | def resnet_v2_200(inputs, 321 | num_classes=None, 322 | is_training=True, 323 | global_pool=True, 324 | output_stride=None, 325 | spatial_squeeze=True, 326 | reuse=None, 327 | scope='resnet_v2_200'): 328 | """ResNet-200 model of [2]. See resnet_v2() for arg and return description.""" 329 | blocks = [ 330 | resnet_v2_block('block1', base_depth=64, num_units=3, stride=2), 331 | resnet_v2_block('block2', base_depth=128, num_units=24, stride=2), 332 | resnet_v2_block('block3', base_depth=256, num_units=36, stride=2), 333 | resnet_v2_block('block4', base_depth=512, num_units=3, stride=1), 334 | ] 335 | return resnet_v2(inputs, blocks, num_classes, is_training=is_training, 336 | global_pool=global_pool, output_stride=output_stride, 337 | include_root_block=True, spatial_squeeze=spatial_squeeze, 338 | reuse=reuse, scope=scope) 339 | resnet_v2_200.default_image_size = resnet_v2.default_image_size 340 | 341 | -------------------------------------------------------------------------------- /model/preprocessing/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hiveml/tensorflow-grad-cam/9f4c9b7a5f9c94a0490b143282abc965ce1cfb0f/model/preprocessing/__init__.py -------------------------------------------------------------------------------- /model/preprocessing/inception_preprocessing.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Provides utilities to preprocess images for the Inception networks.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import tensorflow as tf 22 | 23 | from tensorflow.python.ops import control_flow_ops 24 | 25 | 26 | def apply_with_random_selector(x, func, num_cases): 27 | """Computes func(x, sel), with sel sampled from [0...num_cases-1]. 28 | 29 | Args: 30 | x: input Tensor. 31 | func: Python function to apply. 32 | num_cases: Python int32, number of cases to sample sel from. 
33 | 34 | Returns: 35 | The result of func(x, sel), where func receives the value of the 36 | selector as a python integer, but sel is sampled dynamically. 37 | """ 38 | sel = tf.random_uniform([], maxval=num_cases, dtype=tf.int32) 39 | # Pass the real x only to one of the func calls. 40 | return control_flow_ops.merge([ 41 | func(control_flow_ops.switch(x, tf.equal(sel, case))[1], case) 42 | for case in range(num_cases)])[0] 43 | 44 | 45 | def distort_color(image, color_ordering=0, fast_mode=True, scope=None): 46 | """Distort the color of a Tensor image. 47 | 48 | Each color distortion is non-commutative and thus ordering of the color ops 49 | matters. Ideally we would randomly permute the ordering of the color ops. 50 | Rather than adding that level of complication, we select a distinct ordering 51 | of color ops for each preprocessing thread. 52 | 53 | Args: 54 | image: 3-D Tensor containing single image in [0, 1]. 55 | color_ordering: Python int, a type of distortion (valid values: 0-3). 56 | fast_mode: Avoids slower ops (random_hue and random_contrast) 57 | scope: Optional scope for name_scope. 58 | Returns: 59 | 3-D Tensor color-distorted image on range [0, 1] 60 | Raises: 61 | ValueError: if color_ordering not in [0, 3] 62 | """ 63 | with tf.name_scope(scope, 'distort_color', [image]): 64 | if fast_mode: 65 | if color_ordering == 0: 66 | image = tf.image.random_brightness(image, max_delta=32. / 255.) 67 | image = tf.image.random_saturation(image, lower=0.5, upper=1.5) 68 | else: 69 | image = tf.image.random_saturation(image, lower=0.5, upper=1.5) 70 | image = tf.image.random_brightness(image, max_delta=32. / 255.) 71 | else: 72 | if color_ordering == 0: 73 | image = tf.image.random_brightness(image, max_delta=32. / 255.) 74 | image = tf.image.random_saturation(image, lower=0.5, upper=1.5) 75 | image = tf.image.random_hue(image, max_delta=0.2) 76 | image = tf.image.random_contrast(image, lower=0.5, upper=1.5) 77 | elif color_ordering == 1: 78 | image = tf.image.random_saturation(image, lower=0.5, upper=1.5) 79 | image = tf.image.random_brightness(image, max_delta=32. / 255.) 80 | image = tf.image.random_contrast(image, lower=0.5, upper=1.5) 81 | image = tf.image.random_hue(image, max_delta=0.2) 82 | elif color_ordering == 2: 83 | image = tf.image.random_contrast(image, lower=0.5, upper=1.5) 84 | image = tf.image.random_hue(image, max_delta=0.2) 85 | image = tf.image.random_brightness(image, max_delta=32. / 255.) 86 | image = tf.image.random_saturation(image, lower=0.5, upper=1.5) 87 | elif color_ordering == 3: 88 | image = tf.image.random_hue(image, max_delta=0.2) 89 | image = tf.image.random_saturation(image, lower=0.5, upper=1.5) 90 | image = tf.image.random_contrast(image, lower=0.5, upper=1.5) 91 | image = tf.image.random_brightness(image, max_delta=32. / 255.) 92 | else: 93 | raise ValueError('color_ordering must be in [0, 3]') 94 | 95 | # The random_* ops do not necessarily clamp. 96 | return tf.clip_by_value(image, 0.0, 1.0) 97 | 98 | 99 | def distorted_bounding_box_crop(image, 100 | bbox, 101 | min_object_covered=0.1, 102 | aspect_ratio_range=(0.75, 1.33), 103 | area_range=(0.05, 1.0), 104 | max_attempts=100, 105 | scope=None): 106 | """Generates cropped_image using one of the bboxes randomly distorted. 107 | 108 | See `tf.image.sample_distorted_bounding_box` for more documentation. 109 | 110 | Args: 111 | image: 3-D Tensor of image (it will be converted to floats in [0, 1]).
112 | bbox: 3-D float Tensor of bounding boxes arranged [1, num_boxes, coords] 113 | where each coordinate is [0, 1) and the coordinates are arranged 114 | as [ymin, xmin, ymax, xmax]. If num_boxes is 0 then it would use the whole 115 | image. 116 | min_object_covered: An optional `float`. Defaults to `0.1`. The cropped 117 | area of the image must contain at least this fraction of any bounding box 118 | supplied. 119 | aspect_ratio_range: An optional list of `floats`. The cropped area of the 120 | image must have an aspect ratio = width / height within this range. 121 | area_range: An optional list of `floats`. The cropped area of the image 122 | must contain a fraction of the supplied image within this range. 123 | max_attempts: An optional `int`. Number of attempts at generating a cropped 124 | region of the image of the specified constraints. After `max_attempts` 125 | failures, return the entire image. 126 | scope: Optional scope for name_scope. 127 | Returns: 128 | A tuple, a 3-D Tensor cropped_image and the distorted bbox 129 | """ 130 | with tf.name_scope(scope, 'distorted_bounding_box_crop', [image, bbox]): 131 | # Each bounding box has shape [1, num_boxes, box coords] and 132 | # the coordinates are ordered [ymin, xmin, ymax, xmax]. 133 | 134 | # A large fraction of image datasets contain a human-annotated bounding 135 | # box delineating the region of the image containing the object of interest. 136 | # We choose to create a new bounding box for the object which is a randomly 137 | # distorted version of the human-annotated bounding box that obeys an 138 | # allowed range of aspect ratios, sizes and overlap with the human-annotated 139 | # bounding box. If no box is supplied, then we assume the bounding box is 140 | # the entire image. 141 | sample_distorted_bounding_box = tf.image.sample_distorted_bounding_box( 142 | tf.shape(image), 143 | bounding_boxes=bbox, 144 | min_object_covered=min_object_covered, 145 | aspect_ratio_range=aspect_ratio_range, 146 | area_range=area_range, 147 | max_attempts=max_attempts, 148 | use_image_if_no_bounding_boxes=True) 149 | bbox_begin, bbox_size, distort_bbox = sample_distorted_bounding_box 150 | 151 | # Crop the image to the specified bounding box. 152 | cropped_image = tf.slice(image, bbox_begin, bbox_size) 153 | return cropped_image, distort_bbox 154 | 155 | 156 | def preprocess_for_train(image, height, width, bbox, 157 | fast_mode=True, 158 | scope=None, 159 | add_image_summaries=True): 160 | """Distort one image for training a network. 161 | 162 | Distorting images provides a useful technique for augmenting the data 163 | set during training in order to make the network invariant to aspects 164 | of the image that do not affect the label. 165 | 166 | Additionally it would create image_summaries to display the different 167 | transformations applied to the image. 168 | 169 | Args: 170 | image: 3-D Tensor of image. If dtype is tf.float32 then the range should be 171 | [0, 1], otherwise it would be converted to tf.float32 assuming that the range 172 | is [0, MAX], where MAX is largest positive representable number for 173 | int(8/16/32) data type (see `tf.image.convert_image_dtype` for details). 174 | height: integer 175 | width: integer 176 | bbox: 3-D float Tensor of bounding boxes arranged [1, num_boxes, coords] 177 | where each coordinate is [0, 1) and the coordinates are arranged 178 | as [ymin, xmin, ymax, xmax]. 179 | fast_mode: Optional boolean, if True avoids slower transformations (i.e.
180 | bi-cubic resizing, random_hue or random_contrast). 181 | scope: Optional scope for name_scope. 182 | add_image_summaries: Enable image summaries. 183 | Returns: 184 | 3-D float Tensor of distorted image used for training with range [-1, 1]. 185 | """ 186 | with tf.name_scope(scope, 'distort_image', [image, height, width, bbox]): 187 | if bbox is None: 188 | bbox = tf.constant([0.0, 0.0, 1.0, 1.0], 189 | dtype=tf.float32, 190 | shape=[1, 1, 4]) 191 | if image.dtype != tf.float32: 192 | image = tf.image.convert_image_dtype(image, dtype=tf.float32) 193 | # Each bounding box has shape [1, num_boxes, box coords] and 194 | # the coordinates are ordered [ymin, xmin, ymax, xmax]. 195 | image_with_box = tf.image.draw_bounding_boxes(tf.expand_dims(image, 0), 196 | bbox) 197 | if add_image_summaries: 198 | tf.summary.image('image_with_bounding_boxes', image_with_box) 199 | 200 | distorted_image, distorted_bbox = distorted_bounding_box_crop(image, bbox) 201 | # Restore the shape since the dynamic slice based upon the bbox_size loses 202 | # the third dimension. 203 | distorted_image.set_shape([None, None, 3]) 204 | image_with_distorted_box = tf.image.draw_bounding_boxes( 205 | tf.expand_dims(image, 0), distorted_bbox) 206 | if add_image_summaries: 207 | tf.summary.image('images_with_distorted_bounding_box', 208 | image_with_distorted_box) 209 | 210 | # This resizing operation may distort the images because the aspect 211 | # ratio is not respected. We select a resize method in a round robin 212 | # fashion based on the thread number. 213 | # Note that ResizeMethod contains 4 enumerated resizing methods. 214 | 215 | # We select only 1 case for fast_mode bilinear. 216 | num_resize_cases = 1 if fast_mode else 4 217 | distorted_image = apply_with_random_selector( 218 | distorted_image, 219 | lambda x, method: tf.image.resize_images(x, [height, width], method), 220 | num_cases=num_resize_cases) 221 | 222 | if add_image_summaries: 223 | tf.summary.image('cropped_resized_image', 224 | tf.expand_dims(distorted_image, 0)) 225 | 226 | # Randomly flip the image horizontally. 227 | distorted_image = tf.image.random_flip_left_right(distorted_image) 228 | 229 | # Randomly distort the colors. There are 4 ways to do it. 230 | distorted_image = apply_with_random_selector( 231 | distorted_image, 232 | lambda x, ordering: distort_color(x, ordering, fast_mode), 233 | num_cases=4) 234 | 235 | if add_image_summaries: 236 | tf.summary.image('final_distorted_image', 237 | tf.expand_dims(distorted_image, 0)) 238 | distorted_image = tf.subtract(distorted_image, 0.5) 239 | distorted_image = tf.multiply(distorted_image, 2.0) 240 | return distorted_image 241 | 242 | 243 | def preprocess_for_eval(image, height, width, 244 | central_fraction=0.875, scope=None): 245 | """Prepare one image for evaluation. 246 | 247 | If height and width are specified it would output an image with that size by 248 | applying resize_bilinear. 249 | 250 | If central_fraction is specified it would crop the central fraction of the 251 | input image. 252 | 253 | Args: 254 | image: 3-D Tensor of image. If dtype is tf.float32 then the range should be 255 | [0, 1], otherwise it would be converted to tf.float32 assuming that the range 256 | is [0, MAX], where MAX is largest positive representable number for 257 | int(8/16/32) data type (see `tf.image.convert_image_dtype` for details). 258 | height: integer 259 | width: integer 260 | central_fraction: Optional Float, fraction of the image to crop. 261 | scope: Optional scope for name_scope.
262 | Returns: 263 | 3-D float Tensor of prepared image. 264 | """ 265 | with tf.name_scope(scope, 'eval_image', [image, height, width]): 266 | if image.dtype != tf.float32: 267 | image = tf.image.convert_image_dtype(image, dtype=tf.float32) 268 | # Crop the central region of the image with an area containing 87.5% of 269 | # the original image. 270 | if central_fraction: 271 | image = tf.image.central_crop(image, central_fraction=central_fraction) 272 | 273 | if height and width: 274 | # Resize the image to the specified height and width. 275 | image = tf.expand_dims(image, 0) 276 | image = tf.image.resize_bilinear(image, [height, width], 277 | align_corners=False) 278 | image = tf.squeeze(image, [0]) 279 | image = tf.subtract(image, 0.5) 280 | image = tf.multiply(image, 2.0) 281 | return image 282 | 283 | 284 | def preprocess_image(image, height, width, 285 | is_training=False, 286 | bbox=None, 287 | fast_mode=True, 288 | add_image_summaries=True): 289 | """Pre-process one image for training or evaluation. 290 | 291 | Args: 292 | image: 3-D Tensor [height, width, channels] with the image. If dtype is 293 | tf.float32 then the range should be [0, 1], otherwise it would be converted 294 | to tf.float32 assuming that the range is [0, MAX], where MAX is largest 295 | positive representable number for int(8/16/32) data type (see 296 | `tf.image.convert_image_dtype` for details). 297 | height: integer, image expected height. 298 | width: integer, image expected width. 299 | is_training: Boolean. If true it would transform an image for training, 300 | otherwise it would transform it for evaluation. 301 | bbox: 3-D float Tensor of bounding boxes arranged [1, num_boxes, coords] 302 | where each coordinate is [0, 1) and the coordinates are arranged as 303 | [ymin, xmin, ymax, xmax]. 304 | fast_mode: Optional boolean, if True avoids slower transformations. 305 | add_image_summaries: Enable image summaries. 306 | 307 | Returns: 308 | 3-D float Tensor containing an appropriately scaled image 309 | 310 | Raises: 311 | ValueError: if user does not provide bounding box 312 | """ 313 | if is_training: 314 | return preprocess_for_train(image, height, width, bbox, fast_mode, 315 | add_image_summaries=add_image_summaries) 316 | else: 317 | return preprocess_for_eval(image, height, width) 318 | 319 | -------------------------------------------------------------------------------- /model/preprocessing/preprocessing_factory.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License.
14 | # ============================================================================== 15 | """Contains a factory for building preprocessing functions for various models.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import tensorflow as tf 22 | 23 | #from preprocessing import cifarnet_preprocessing 24 | from model.preprocessing import inception_preprocessing 25 | #from preprocessing import lenet_preprocessing 26 | #from preprocessing import vgg_preprocessing 27 | 28 | slim = tf.contrib.slim 29 | 30 | 31 | def get_preprocessing(name, is_training=False): 32 | """Returns preprocessing_fn(image, height, width, **kwargs). 33 | 34 | Args: 35 | name: The name of the preprocessing function. 36 | is_training: `True` if the model is being used for training and `False` 37 | otherwise. 38 | 39 | Returns: 40 | preprocessing_fn: A function that preprocesses a single image (pre-batch). 41 | It has the following signature: 42 | image = preprocessing_fn(image, output_height, output_width, ...). 43 | 44 | Raises: 45 | ValueError: If Preprocessing `name` is not recognized. 46 | """ 47 | preprocessing_fn_map = { 48 | 'resnet_v2_50': inception_preprocessing, 49 | 'resnet_v2_101': inception_preprocessing, 50 | 'resnet_v2_152': inception_preprocessing, 51 | 'resnet_v2_200': inception_preprocessing, 52 | } 53 | 54 | if name not in preprocessing_fn_map: 55 | raise ValueError('Preprocessing name [%s] was not recognized' % name) 56 | 57 | def preprocessing_fn(image, output_height, output_width, **kwargs): 58 | return preprocessing_fn_map[name].preprocess_image( 59 | image, output_height, output_width, is_training=is_training, **kwargs) 60 | 61 | return preprocessing_fn 62 | 63 | --------------------------------------------------------------------------------
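
Putting the pieces together: the vendored modules above only define the graph parts, and main.py is what actually wires them into the Grad-CAM pipeline. As a rough orientation, the sketch below shows one way resnet_v2_50, resnet_arg_scope, and get_preprocessing can be combined to expose both the class probabilities and the 'PrePool' end point that this repo's resnet_v2.py registers (the last convolutional feature map, which is what a Grad-CAM visualization re-weights). This is a minimal illustrative sketch, not the repository's exact code: the raw_image placeholder, the 224 eval size, the 1001-class head, and the checkpoint path are assumptions; see main.py and main.sh for the authoritative wiring.

```
# Illustrative sketch only -- names, sizes, and the checkpoint path are
# assumptions, not the repo's exact code (see main.py for the real pipeline).
import tensorflow as tf

from model.nets import resnet_v2
from model.preprocessing import preprocessing_factory

slim = tf.contrib.slim

# A single RGB image of arbitrary size, fed as uint8 (assumed input handling).
raw_image = tf.placeholder(tf.uint8, shape=[None, None, 3], name='raw_image')

# Eval-mode inception preprocessing: central crop, bilinear resize, scale to [-1, 1].
preprocess_fn = preprocessing_factory.get_preprocessing('resnet_v2_50',
                                                        is_training=False)
image = preprocess_fn(raw_image, 224, 224)
images = tf.expand_dims(image, 0)  # batch of one: [1, 224, 224, 3]

with slim.arg_scope(resnet_v2.resnet_arg_scope()):
    # The released slim checkpoint uses 1001 outputs (class 0 is background).
    logits, end_points = resnet_v2.resnet_v2_50(images, num_classes=1001,
                                                is_training=False)

# 'PrePool' is registered by resnet_v2() above: the last block's activations
# before global average pooling -- the feature map Grad-CAM re-weights.
conv_layer = end_points['PrePool']
probs = end_points['predictions']

saver = tf.train.Saver(slim.get_model_variables('resnet_v2_50'))
with tf.Session() as sess:
    saver.restore(sess, './imagenet/resnet_v2_50.ckpt')  # assumed path from get_checkpoint.sh
    # p, fmap = sess.run([probs, conv_layer], feed_dict={raw_image: some_uint8_array})
```

From there, a Grad-CAM style heatmap is obtained by differentiating the chosen class score with respect to conv_layer, spatially averaging those gradients to get per-channel weights, and summing the weighted channels through a ReLU; the repo's grad_cam function in main.py implements that step for the top (or a user-chosen) class.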