├── conf
│   ├── candy.yml
│   ├── cubist.yml
│   ├── denoised_starry.yml
│   ├── feathers.yml
│   ├── mosaic.yml
│   ├── painting.yml
│   ├── picasso.yml
│   ├── scream.yml
│   ├── udnie.yml
│   └── wave.yml
├── img
│   ├── content
│   │   ├── test.jpg
│   │   ├── test1.jpg
│   │   ├── test2.jpg
│   │   ├── test3.jpg
│   │   ├── test4.jpg
│   │   ├── test5.jpg
│   │   ├── test6.jpg
│   │   ├── test7.jpg
│   │   ├── test8.jpg
│   │   └── test9.png
│   ├── favicon.ico
│   ├── generated
│   │   ├── cubist_res.jpg
│   │   ├── denoised_starry_res.jpg
│   │   ├── denoised_starry_res_test6.jpg
│   │   ├── denoised_starry_res_test7.jpg
│   │   ├── feathers_res_test7.jpg
│   │   ├── mosaic_res_test7.jpg
│   │   ├── painting_res.jpg
│   │   ├── painting_res_test6.jpg
│   │   ├── painting_res_test7.jpg
│   │   ├── scream_res_test6.jpg
│   │   ├── scream_res_test7.jpg
│   │   ├── target_style_painting.jpg
│   │   ├── test6.jpg
│   │   └── udnie_res_test7.jpg
│   ├── style
│   │   ├── candy.jpg
│   │   ├── cubist.jpg
│   │   ├── denoised_starry.jpg
│   │   ├── feathers.jpg
│   │   ├── gouache.jpg
│   │   ├── mosaic.jpg
│   │   ├── painting.jpg
│   │   ├── picasso.jpg
│   │   ├── scream.jpg
│   │   ├── starry.jpg
│   │   ├── udnie.jpg
│   │   └── wave.jpg
│   └── uploads
│       ├── test6.jpg
│       └── test7.jpg
├── losses.py
├── model.py
├── models
│   ├── denoised_starry.ckpt-done
│   ├── painting
│   │   ├── painting.ckpt-1000
│   │   ├── painting.ckpt-1000.meta
│   │   ├── painting.ckpt-2000
│   │   └── painting.ckpt-2000.meta
│   ├── picasso
│   │   ├── checkpoint
│   │   ├── picasso.ckpt-100
│   │   └── picasso.ckpt-100.meta
│   └── scream.ckpt-done
├── nets
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── __init__.cpython-35.pyc
│   │   ├── alexnet.cpython-35.pyc
│   │   ├── cifarnet.cpython-35.pyc
│   │   ├── inception.cpython-35.pyc
│   │   ├── inception_resnet_v2.cpython-35.pyc
│   │   ├── inception_utils.cpython-35.pyc
│   │   ├── inception_v1.cpython-35.pyc
│   │   ├── inception_v2.cpython-35.pyc
│   │   ├── inception_v3.cpython-35.pyc
│   │   ├── inception_v4.cpython-35.pyc
│   │   ├── lenet.cpython-35.pyc
│   │   ├── nets_factory.cpython-35.pyc
│   │   ├── overfeat.cpython-35.pyc
│   │   ├── resnet_utils.cpython-35.pyc
│   │   ├── resnet_v1.cpython-35.pyc
│   │   ├── resnet_v2.cpython-35.pyc
│   │   └── vgg.cpython-35.pyc
│   ├── alexnet.py
│   ├── alexnet_test.py
│   ├── cifarnet.py
│   ├── inception.py
│   ├── inception_resnet_v2.py
│   ├── inception_resnet_v2_test.py
│   ├── inception_utils.py
│   ├── inception_v1.py
│   ├── inception_v1_test.py
│   ├── inception_v2.py
│   ├── inception_v2_test.py
│   ├── inception_v3.py
│   ├── inception_v3_test.py
│   ├── inception_v4.py
│   ├── inception_v4_test.py
│   ├── lenet.py
│   ├── nets_factory.py
│   ├── nets_factory_test.py
│   ├── overfeat.py
│   ├── overfeat_test.py
│   ├── resnet_utils.py
│   ├── resnet_v1.py
│   ├── resnet_v1_test.py
│   ├── resnet_v2.py
│   ├── resnet_v2_test.py
│   ├── vgg.py
│   └── vgg_test.py
├── preprocessing
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── __init__.cpython-35.pyc
│   │   ├── cifarnet_preprocessing.cpython-35.pyc
│   │   ├── inception_preprocessing.cpython-35.pyc
│   │   ├── lenet_preprocessing.cpython-35.pyc
│   │   ├── preprocessing_factory.cpython-35.pyc
│   │   └── vgg_preprocessing.cpython-35.pyc
│   ├── cifarnet_preprocessing.py
│   ├── inception_preprocessing.py
│   ├── lenet_preprocessing.py
│   ├── preprocessing_factory.py
│   └── vgg_preprocessing.py
├── reader.py
├── static
│   ├── css
│   │   ├── bootstrap.min.css
│   │   ├── font-awesome.min.css
│   │   └── style.css
│   ├── img
│   │   ├── background.jpg
│   │   ├── content
│   │   │   ├── test.jpg
│   │   │   ├── test1.jpg
│   │   │   ├── test2.jpg
│   │   │   ├── test3.jpg
│   │   │   ├── test4.jpg
│   │   │   ├── test5.jpg
│   │   │   ├── test6.jpg
│   │   │   ├── test7.jpg
│   │   │   ├── test8.jpg
│   │   │   └── test9.png
│   │   ├── favicon.ico
│   │   ├── generated
│   │   │   ├── cubist_res.jpg
│   │   │   ├── cubist_res_test6.jpg
│   │   │   ├── denoised_starry_res.jpg
│   │   │   ├── denoised_starry_res_test.jpg
│   │   │   ├── denoised_starry_res_test1.jpg
│   │   │   ├── denoised_starry_res_test2.jpg
│   │   │   ├── denoised_starry_res_test3.jpg
│   │   │   ├── denoised_starry_res_test4.jpg
│   │   │   ├── denoised_starry_res_test5.jpg
│   │   │   ├── denoised_starry_res_test6.jpg
│   │   │   ├── denoised_starry_res_test7.jpg
│   │   │   ├── denoised_starry_res_test9.png
│   │   │   ├── feathers_res_test7.jpg
│   │   │   ├── mosaic_res_test6.jpg
│   │   │   ├── mosaic_res_test7.jpg
│   │   │   ├── painting_res.jpg
│   │   │   ├── painting_res_test6.jpg
│   │   │   ├── painting_res_test7.jpg
│   │   │   ├── scream_res_test.jpg
│   │   │   ├── scream_res_test4.jpg
│   │   │   ├── scream_res_test6.jpg
│   │   │   ├── scream_res_test7.jpg
│   │   │   ├── target_style_painting.jpg
│   │   │   ├── test6.jpg
│   │   │   ├── udnie_res_test7.jpg
│   │   │   ├── wave_res_test5.jpg
│   │   │   └── wave_res_test6.jpg
│   │   ├── loading1.gif
│   │   ├── loading2.gif
│   │   ├── loading3.gif
│   │   ├── loading4.gif
│   │   ├── style
│   │   │   ├── candy.jpg
│   │   │   ├── cubist.jpg
│   │   │   ├── denoised_starry.jpg
│   │   │   ├── feathers.jpg
│   │   │   ├── gouache.jpg
│   │   │   ├── mosaic.jpg
│   │   │   ├── painting.jpg
│   │   │   ├── picasso.jpg
│   │   │   ├── scream.jpg
│   │   │   ├── starry.jpg
│   │   │   ├── udnie.jpg
│   │   │   └── wave.jpg
│   │   └── uploads
│   │       ├── test.jpg
│   │       ├── test1.jpg
│   │       ├── test2.jpg
│   │       ├── test3.jpg
│   │       ├── test4.jpg
│   │       ├── test5.jpg
│   │       ├── test6.jpg
│   │       ├── test7.jpg
│   │       └── test9.png
│   └── js
│       ├── jquery.min.js
│       ├── maple.js
│       └── prefixfree.min.js
├── templates
│   ├── index.html
│   └── transformed.html
├── train.py
├── transform.py
├── utils.py
└── web.py
/conf/candy.yml:
--------------------------------------------------------------------------------
1 | ## Basic configuration
2 | style_image: static/img/style/candy.jpg # targeted style image
3 | naming: "candy" # the name of this model. Determine the path to save checkpoint and events file.
4 | model_path: models # root path to save checkpoint and events file. The final path would be /
5 |
6 | ## Weight of the loss
7 | content_weight: 1.0 # weight for content features loss
8 | style_weight: 50.0 # weight for style features loss
9 | tv_weight: 0.0 # weight for total variation loss
10 |
11 | ## The size, the iter number to run
12 | image_size: 256
13 | batch_size: 4
14 | epoch: 2
15 |
16 | ## Loss Network
17 | loss_model: "vgg_16"
18 | content_layers: # use these layers for content loss
19 |   - "vgg_16/conv3/conv3_3"
20 | style_layers: # use these layers for style loss
21 |   - "vgg_16/conv1/conv1_2"
22 |   - "vgg_16/conv2/conv2_2"
23 |   - "vgg_16/conv3/conv3_3"
24 |   - "vgg_16/conv4/conv4_3"
25 | checkpoint_exclude_scopes: "vgg_16/fc" # we only use the convolution layers, so ignore fc layers.
26 | loss_model_file: "pretrained/vgg_16.ckpt" # the path to the checkpoint
27 |
--------------------------------------------------------------------------------
/conf/cubist.yml:
--------------------------------------------------------------------------------
1 | ## Basic configuration
2 | style_image: static/img/style/cubist.jpg # targeted style image
3 | naming: "cubist" # the name of this model. Determine the path to save checkpoint and events file.
4 | model_path: models # root path to save checkpoint and events file.
The final path would be / 5 | 6 | ## Weight of the loss 7 | content_weight: 1.0 # weight for content features loss 8 | style_weight: 180.0 # weight for style features loss 9 | tv_weight: 0.0 # weight for total variation loss 10 | 11 | ## The size, the iter number to run 12 | image_size: 256 13 | batch_size: 4 14 | epoch: 2 15 | 16 | ## Loss Network 17 | loss_model: "vgg_16" 18 | content_layers: # use these layers for content loss 19 | - "vgg_16/conv3/conv3_3" 20 | style_layers: # use these layers for style loss 21 | - "vgg_16/conv1/conv1_2" 22 | - "vgg_16/conv2/conv2_2" 23 | - "vgg_16/conv3/conv3_3" 24 | - "vgg_16/conv4/conv4_3" 25 | checkpoint_exclude_scopes: "vgg_16/fc" # we only use the convolution layers, so ignore fc layers. 26 | loss_model_file: "pretrained/vgg_16.ckpt" # the path to the checkpoint -------------------------------------------------------------------------------- /conf/denoised_starry.yml: -------------------------------------------------------------------------------- 1 | ## Basic configuration 2 | style_image: static/img/style/denoised_starry.jpg # targeted style image 3 | naming: "denoised_starry" # the name of this model. Determine the path to save checkpoint and events file. 4 | model_path: models # root path to save checkpoint and events file. The final path would be / 5 | 6 | ## Weight of the loss 7 | content_weight: 1.0 # weight for content features loss 8 | style_weight: 250 # weight for style features loss 9 | tv_weight: 0.0 # weight for total variation loss 10 | 11 | ## The size, the iter number to run 12 | image_size: 256 13 | batch_size: 4 14 | epoch: 2 15 | 16 | ## Loss Network 17 | loss_model: "vgg_16" 18 | content_layers: # use these layers for content loss 19 | - "vgg_16/conv3/conv3_3" 20 | style_layers: # use these layers for style loss 21 | - "vgg_16/conv1/conv1_2" 22 | - "vgg_16/conv2/conv2_2" 23 | - "vgg_16/conv3/conv3_3" 24 | - "vgg_16/conv4/conv4_3" 25 | checkpoint_exclude_scopes: "vgg_16/fc" # we only use the convolution layers, so ignore fc layers. 26 | loss_model_file: "pretrained/vgg_16.ckpt" # the path to the checkpoint -------------------------------------------------------------------------------- /conf/feathers.yml: -------------------------------------------------------------------------------- 1 | ## Basic configuration 2 | style_image: static/img/style/feathers.jpg # targeted style image 3 | naming: "feathers" # the name of this model. Determine the path to save checkpoint and events file. 4 | model_path: models # root path to save checkpoint and events file. The final path would be / 5 | 6 | ## Weight of the loss 7 | content_weight: 1.0 # weight for content features loss 8 | style_weight: 220.0 # weight for style features loss 9 | tv_weight: 0.0 # weight for total variation loss 10 | 11 | ## The size, the iter number to run 12 | image_size: 256 13 | batch_size: 4 14 | epoch: 2 15 | 16 | ## Loss Network 17 | loss_model: "vgg_16" 18 | content_layers: # use these layers for content loss 19 | - "vgg_16/conv3/conv3_3" 20 | style_layers: # use these layers for style loss 21 | - "vgg_16/conv1/conv1_2" 22 | - "vgg_16/conv2/conv2_2" 23 | - "vgg_16/conv3/conv3_3" 24 | - "vgg_16/conv4/conv4_3" 25 | checkpoint_exclude_scopes: "vgg_16/fc" # we only use the convolution layers, so ignore fc layers. 
26 | loss_model_file: "pretrained/vgg_16.ckpt" # the path to the checkpoint -------------------------------------------------------------------------------- /conf/mosaic.yml: -------------------------------------------------------------------------------- 1 | ## Basic configuration 2 | style_image: static/img/style/mosaic.jpg # targeted style image 3 | naming: "mosaic" # the name of this model. Determine the path to save checkpoint and events file. 4 | model_path: models # root path to save checkpoint and events file. The final path would be / 5 | 6 | ## Weight of the loss 7 | content_weight: 1.0 # weight for content features loss 8 | style_weight: 100.0 # weight for style features loss 9 | tv_weight: 0.0 # weight for total variation loss 10 | 11 | ## The size, the iter number to run 12 | image_size: 256 13 | batch_size: 4 14 | epoch: 2 15 | 16 | ## Loss Network 17 | loss_model: "vgg_16" 18 | content_layers: # use these layers for content loss 19 | - "vgg_16/conv3/conv3_3" 20 | style_layers: # use these layers for style loss 21 | - "vgg_16/conv1/conv1_2" 22 | - "vgg_16/conv2/conv2_2" 23 | - "vgg_16/conv3/conv3_3" 24 | - "vgg_16/conv4/conv4_3" 25 | checkpoint_exclude_scopes: "vgg_16/fc" # we only use the convolution layers, so ignore fc layers. 26 | loss_model_file: "pretrained/vgg_16.ckpt" # the path to the checkpoint -------------------------------------------------------------------------------- /conf/painting.yml: -------------------------------------------------------------------------------- 1 | ## Basic configuration 2 | style_image: static/img/style/painting.jpg # targeted style image 3 | naming: "painting" # the name of this model. Determine the path to save checkpoint and events file. 4 | model_path: models # root path to save checkpoint and events file. The final path would be / 5 | 6 | ## Weight of the loss 7 | content_weight: 1.0 # weight for content features loss 8 | style_weight: 220.0 # weight for style features loss 9 | tv_weight: 0.0 # weight for total variation loss 10 | 11 | ## The size, the iter number to run 12 | image_size: 256 13 | batch_size: 4 14 | epoch: 2 15 | 16 | ## Loss Network 17 | loss_model: "vgg_16" 18 | content_layers: # use these layers for content loss 19 | - "vgg_16/conv3/conv3_3" 20 | style_layers: # use these layers for style loss 21 | - "vgg_16/conv1/conv1_2" 22 | - "vgg_16/conv2/conv2_2" 23 | - "vgg_16/conv3/conv3_3" 24 | - "vgg_16/conv4/conv4_3" 25 | checkpoint_exclude_scopes: "vgg_16/fc" # we only use the convolution layers, so ignore fc layers. 26 | loss_model_file: "pretrained/vgg_16.ckpt" # the path to the checkpoint 27 | -------------------------------------------------------------------------------- /conf/picasso.yml: -------------------------------------------------------------------------------- 1 | ## Basic configuration 2 | style_image: static/img/style/picasso.jpg # targeted style image 3 | naming: "picasso" # the name of this model. Determine the path to save checkpoint and events file. 4 | model_path: models # root path to save checkpoint and events file. 
The final path would be / 5 | 6 | ## Weight of the loss 7 | content_weight: 1.0 # weight for content features loss 8 | style_weight: 50.0 # weight for style features loss 9 | tv_weight: 0.0 # weight for total variation loss 10 | 11 | ## The size, the iter number to run 12 | image_size: 256 13 | batch_size: 4 14 | epoch: 2 15 | 16 | ## Loss Network 17 | loss_model: "vgg_16" 18 | content_layers: # use these layers for content loss 19 | - "vgg_16/conv3/conv3_3" 20 | style_layers: # use these layers for style loss 21 | - "vgg_16/conv1/conv1_2" 22 | - "vgg_16/conv2/conv2_2" 23 | - "vgg_16/conv3/conv3_3" 24 | - "vgg_16/conv4/conv4_3" 25 | checkpoint_exclude_scopes: "vgg_16/fc" # we only use the convolution layers, so ignore fc layers. 26 | loss_model_file: "pretrained/vgg_16.ckpt" # the path to the checkpoint 27 | -------------------------------------------------------------------------------- /conf/scream.yml: -------------------------------------------------------------------------------- 1 | ## Basic configuration 2 | style_image: static/img/style/scream.jpg # targeted style image 3 | naming: "scream" # the name of this model. Determine the path to save checkpoint and events file. 4 | model_path: models # root path to save checkpoint and events file. The final path would be / 5 | 6 | ## Weight of the loss 7 | content_weight: 1.0 # weight for content features loss 8 | style_weight: 250.0 # weight for style features loss 9 | tv_weight: 0.0 # weight for total variation loss 10 | 11 | ## The size, the iter number to run 12 | image_size: 256 13 | batch_size: 4 14 | epoch: 2 15 | 16 | ## Loss Network 17 | loss_model: "vgg_16" 18 | content_layers: # use these layers for content loss 19 | - "vgg_16/conv3/conv3_3" 20 | style_layers: # use these layers for style loss 21 | - "vgg_16/conv1/conv1_2" 22 | - "vgg_16/conv2/conv2_2" 23 | - "vgg_16/conv3/conv3_3" 24 | - "vgg_16/conv4/conv4_3" 25 | checkpoint_exclude_scopes: "vgg_16/fc" # we only use the convolution layers, so ignore fc layers. 26 | loss_model_file: "pretrained/vgg_16.ckpt" # the path to the checkpoint -------------------------------------------------------------------------------- /conf/udnie.yml: -------------------------------------------------------------------------------- 1 | ## Basic configuration 2 | style_image: static/img/style/udnie.jpg # targeted style image 3 | naming: "udnie" # the name of this model. Determine the path to save checkpoint and events file. 4 | model_path: models # root path to save checkpoint and events file. The final path would be / 5 | 6 | ## Weight of the loss 7 | content_weight: 1.0 # weight for content features loss 8 | style_weight: 200.0 # weight for style features loss 9 | tv_weight: 0.0 # weight for total variation loss 10 | 11 | ## The size, the iter number to run 12 | image_size: 256 13 | batch_size: 4 14 | epoch: 2 15 | 16 | ## Loss Network 17 | loss_model: "vgg_16" 18 | content_layers: # use these layers for content loss 19 | - "vgg_16/conv3/conv3_3" 20 | style_layers: # use these layers for style loss 21 | - "vgg_16/conv1/conv1_2" 22 | - "vgg_16/conv2/conv2_2" 23 | - "vgg_16/conv3/conv3_3" 24 | - "vgg_16/conv4/conv4_3" 25 | checkpoint_exclude_scopes: "vgg_16/fc" # we only use the convolution layers, so ignore fc layers. 
26 | loss_model_file: "pretrained/vgg_16.ckpt" # the path to the checkpoint -------------------------------------------------------------------------------- /conf/wave.yml: -------------------------------------------------------------------------------- 1 | ## Basic configuration 2 | style_image: static/img/style/wave.jpg # targeted style image 3 | naming: "wave" # the name of this model. Determine the path to save checkpoint and events file. 4 | model_path: models # root path to save checkpoint and events file. The final path would be / 5 | 6 | ## Weight of the loss 7 | content_weight: 1.0 # weight for content features loss 8 | style_weight: 220.0 # weight for style features loss 9 | tv_weight: 0.0 # weight for total variation loss 10 | 11 | ## The size, the iter number to run 12 | image_size: 256 13 | batch_size: 4 14 | epoch: 2 15 | 16 | ## Loss Network 17 | loss_model: "vgg_16" 18 | content_layers: # use these layers for content loss 19 | - "vgg_16/conv3/conv3_3" 20 | style_layers: # use these layers for style loss 21 | - "vgg_16/conv1/conv1_2" 22 | - "vgg_16/conv2/conv2_2" 23 | - "vgg_16/conv3/conv3_3" 24 | - "vgg_16/conv4/conv4_3" 25 | checkpoint_exclude_scopes: "vgg_16/fc" # we only use the convolution layers, so ignore fc layers. 26 | loss_model_file: "pretrained/vgg_16.ckpt" # the path to the checkpoint 27 | -------------------------------------------------------------------------------- /img/content/test.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wisewong/ImageStyleTransform/f2f7930cbf51f544c59a949855e2403aadede41b/img/content/test.jpg -------------------------------------------------------------------------------- /img/content/test1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wisewong/ImageStyleTransform/f2f7930cbf51f544c59a949855e2403aadede41b/img/content/test1.jpg -------------------------------------------------------------------------------- /img/content/test2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wisewong/ImageStyleTransform/f2f7930cbf51f544c59a949855e2403aadede41b/img/content/test2.jpg -------------------------------------------------------------------------------- /img/content/test3.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wisewong/ImageStyleTransform/f2f7930cbf51f544c59a949855e2403aadede41b/img/content/test3.jpg -------------------------------------------------------------------------------- /img/content/test4.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wisewong/ImageStyleTransform/f2f7930cbf51f544c59a949855e2403aadede41b/img/content/test4.jpg -------------------------------------------------------------------------------- /img/content/test5.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wisewong/ImageStyleTransform/f2f7930cbf51f544c59a949855e2403aadede41b/img/content/test5.jpg -------------------------------------------------------------------------------- /img/content/test6.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wisewong/ImageStyleTransform/f2f7930cbf51f544c59a949855e2403aadede41b/img/content/test6.jpg 
-------------------------------------------------------------------------------- /img/content/test7.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wisewong/ImageStyleTransform/f2f7930cbf51f544c59a949855e2403aadede41b/img/content/test7.jpg -------------------------------------------------------------------------------- /img/content/test8.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wisewong/ImageStyleTransform/f2f7930cbf51f544c59a949855e2403aadede41b/img/content/test8.jpg -------------------------------------------------------------------------------- /img/content/test9.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wisewong/ImageStyleTransform/f2f7930cbf51f544c59a949855e2403aadede41b/img/content/test9.png -------------------------------------------------------------------------------- /img/favicon.ico: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wisewong/ImageStyleTransform/f2f7930cbf51f544c59a949855e2403aadede41b/img/favicon.ico -------------------------------------------------------------------------------- /img/generated/cubist_res.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wisewong/ImageStyleTransform/f2f7930cbf51f544c59a949855e2403aadede41b/img/generated/cubist_res.jpg -------------------------------------------------------------------------------- /img/generated/denoised_starry_res.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wisewong/ImageStyleTransform/f2f7930cbf51f544c59a949855e2403aadede41b/img/generated/denoised_starry_res.jpg -------------------------------------------------------------------------------- /img/generated/denoised_starry_res_test6.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wisewong/ImageStyleTransform/f2f7930cbf51f544c59a949855e2403aadede41b/img/generated/denoised_starry_res_test6.jpg -------------------------------------------------------------------------------- /img/generated/denoised_starry_res_test7.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wisewong/ImageStyleTransform/f2f7930cbf51f544c59a949855e2403aadede41b/img/generated/denoised_starry_res_test7.jpg -------------------------------------------------------------------------------- /img/generated/feathers_res_test7.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wisewong/ImageStyleTransform/f2f7930cbf51f544c59a949855e2403aadede41b/img/generated/feathers_res_test7.jpg -------------------------------------------------------------------------------- /img/generated/mosaic_res_test7.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wisewong/ImageStyleTransform/f2f7930cbf51f544c59a949855e2403aadede41b/img/generated/mosaic_res_test7.jpg -------------------------------------------------------------------------------- /img/generated/painting_res.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/wisewong/ImageStyleTransform/f2f7930cbf51f544c59a949855e2403aadede41b/img/generated/painting_res.jpg -------------------------------------------------------------------------------- /img/generated/painting_res_test6.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wisewong/ImageStyleTransform/f2f7930cbf51f544c59a949855e2403aadede41b/img/generated/painting_res_test6.jpg -------------------------------------------------------------------------------- /img/generated/painting_res_test7.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wisewong/ImageStyleTransform/f2f7930cbf51f544c59a949855e2403aadede41b/img/generated/painting_res_test7.jpg -------------------------------------------------------------------------------- /img/generated/scream_res_test6.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wisewong/ImageStyleTransform/f2f7930cbf51f544c59a949855e2403aadede41b/img/generated/scream_res_test6.jpg -------------------------------------------------------------------------------- /img/generated/scream_res_test7.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wisewong/ImageStyleTransform/f2f7930cbf51f544c59a949855e2403aadede41b/img/generated/scream_res_test7.jpg -------------------------------------------------------------------------------- /img/generated/target_style_painting.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wisewong/ImageStyleTransform/f2f7930cbf51f544c59a949855e2403aadede41b/img/generated/target_style_painting.jpg -------------------------------------------------------------------------------- /img/generated/test6.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wisewong/ImageStyleTransform/f2f7930cbf51f544c59a949855e2403aadede41b/img/generated/test6.jpg -------------------------------------------------------------------------------- /img/generated/udnie_res_test7.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wisewong/ImageStyleTransform/f2f7930cbf51f544c59a949855e2403aadede41b/img/generated/udnie_res_test7.jpg -------------------------------------------------------------------------------- /img/style/candy.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wisewong/ImageStyleTransform/f2f7930cbf51f544c59a949855e2403aadede41b/img/style/candy.jpg -------------------------------------------------------------------------------- /img/style/cubist.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wisewong/ImageStyleTransform/f2f7930cbf51f544c59a949855e2403aadede41b/img/style/cubist.jpg -------------------------------------------------------------------------------- /img/style/denoised_starry.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wisewong/ImageStyleTransform/f2f7930cbf51f544c59a949855e2403aadede41b/img/style/denoised_starry.jpg -------------------------------------------------------------------------------- /img/style/feathers.jpg: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/wisewong/ImageStyleTransform/f2f7930cbf51f544c59a949855e2403aadede41b/img/style/feathers.jpg -------------------------------------------------------------------------------- /img/style/gouache.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wisewong/ImageStyleTransform/f2f7930cbf51f544c59a949855e2403aadede41b/img/style/gouache.jpg -------------------------------------------------------------------------------- /img/style/mosaic.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wisewong/ImageStyleTransform/f2f7930cbf51f544c59a949855e2403aadede41b/img/style/mosaic.jpg -------------------------------------------------------------------------------- /img/style/painting.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wisewong/ImageStyleTransform/f2f7930cbf51f544c59a949855e2403aadede41b/img/style/painting.jpg -------------------------------------------------------------------------------- /img/style/picasso.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wisewong/ImageStyleTransform/f2f7930cbf51f544c59a949855e2403aadede41b/img/style/picasso.jpg -------------------------------------------------------------------------------- /img/style/scream.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wisewong/ImageStyleTransform/f2f7930cbf51f544c59a949855e2403aadede41b/img/style/scream.jpg -------------------------------------------------------------------------------- /img/style/starry.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wisewong/ImageStyleTransform/f2f7930cbf51f544c59a949855e2403aadede41b/img/style/starry.jpg -------------------------------------------------------------------------------- /img/style/udnie.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wisewong/ImageStyleTransform/f2f7930cbf51f544c59a949855e2403aadede41b/img/style/udnie.jpg -------------------------------------------------------------------------------- /img/style/wave.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wisewong/ImageStyleTransform/f2f7930cbf51f544c59a949855e2403aadede41b/img/style/wave.jpg -------------------------------------------------------------------------------- /img/uploads/test6.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wisewong/ImageStyleTransform/f2f7930cbf51f544c59a949855e2403aadede41b/img/uploads/test6.jpg -------------------------------------------------------------------------------- /img/uploads/test7.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wisewong/ImageStyleTransform/f2f7930cbf51f544c59a949855e2403aadede41b/img/uploads/test7.jpg -------------------------------------------------------------------------------- /losses.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | from __future__ import print_function 3 | 
import tensorflow as tf
4 | from nets import nets_factory
5 | from preprocessing import preprocessing_factory
6 | import utils
7 | import os
8 |
9 | slim = tf.contrib.slim
10 |
11 |
12 | def gram(layer):
13 |     shape = tf.shape(layer)
14 |     num_images = shape[0]
15 |     width = shape[1]
16 |     height = shape[2]
17 |     num_filters = shape[3]
18 |     filters = tf.reshape(layer, tf.stack([num_images, -1, num_filters]))
19 |     grams = tf.matmul(filters, filters, transpose_a=True) / tf.to_float(width * height * num_filters)
20 |     return grams
21 |
22 |
23 | def get_style_features(FLAGS):
24 |     """
25 |     Preprocessing steps for the style image:
26 |     1. Resize the shorter side to FLAGS.image_size
27 |     2. Apply central crop
28 |     """
29 |     with tf.Graph().as_default():
30 |         network_fn = nets_factory.get_network_fn(
31 |             FLAGS.loss_model,
32 |             num_classes=1,
33 |             is_training=False)
34 |         image_preprocessing_fn, image_unprocessing_fn = preprocessing_factory.get_preprocessing(
35 |             FLAGS.loss_model,
36 |             is_training=False)
37 |
38 |         size = FLAGS.image_size
39 |         img_bytes = tf.read_file(FLAGS.style_image)
40 |         if FLAGS.style_image.lower().endswith('png'):
41 |             image = tf.image.decode_png(img_bytes)
42 |         else:
43 |             image = tf.image.decode_jpeg(img_bytes)
44 |         # image = _aspect_preserving_resize(image, size)
45 |         images = tf.stack([image_preprocessing_fn(image, size, size)])
46 |         _, endpoints_dict = network_fn(images, spatial_squeeze=False)
47 |         features = []
48 |         for layer in FLAGS.style_layers:
49 |             feature = endpoints_dict[layer]
50 |             feature = tf.squeeze(gram(feature), [0])  # remove the batch dimension
51 |             features.append(feature)
52 |
53 |         with tf.Session() as sess:
54 |             init_func = utils._get_init_fn(FLAGS)
55 |             init_func(sess)
56 |             if os.path.exists('static/img/generated') is False:
57 |                 os.makedirs('static/img/generated')
58 |             save_file = 'static/img/generated/target_style_' + FLAGS.naming + '.jpg'
59 |             with open(save_file, 'wb') as f:
60 |                 target_image = image_unprocessing_fn(images[0, :])
61 |                 value = tf.image.encode_jpeg(tf.cast(target_image, tf.uint8))
62 |                 f.write(sess.run(value))
63 |                 tf.logging.info('Target style pattern is saved to: %s.' % save_file)
64 |             return sess.run(features)
65 |
66 |
67 | def style_loss(endpoints_dict, style_features_t, style_layers):
68 |     style_loss = 0
69 |     style_loss_summary = {}
70 |     for style_gram, layer in zip(style_features_t, style_layers):
71 |         generated_images, _ = tf.split(endpoints_dict[layer], 2, 0)
72 |         size = tf.size(generated_images)
73 |         layer_style_loss = tf.nn.l2_loss(gram(generated_images) - style_gram) * 2 / tf.to_float(size)
74 |         style_loss_summary[layer] = layer_style_loss
75 |         style_loss += layer_style_loss
76 |     return style_loss, style_loss_summary
77 |
78 |
79 | def content_loss(endpoints_dict, content_layers):
80 |     content_loss = 0
81 |     for layer in content_layers:
82 |         generated_images, content_images = tf.split(endpoints_dict[layer], 2, 0)
83 |         size = tf.size(generated_images)
84 |         content_loss += tf.nn.l2_loss(generated_images - content_images) * 2 / tf.to_float(size)  # keep the same scaling as in the paper
85 |     return content_loss
86 |
87 |
88 | def total_variation_loss(layer):
89 |     shape = tf.shape(layer)
90 |     height = shape[1]
91 |     width = shape[2]
92 |     y = tf.slice(layer, [0, 0, 0, 0], tf.stack([-1, height - 1, -1, -1])) - tf.slice(layer, [0, 1, 0, 0], [-1, -1, -1, -1])
93 |     x = tf.slice(layer, [0, 0, 0, 0], tf.stack([-1, -1, width - 1, -1])) - tf.slice(layer, [0, 0, 1, 0], [-1, -1, -1, -1])
94 |     loss = tf.nn.l2_loss(x) / tf.to_float(tf.size(x)) + tf.nn.l2_loss(y) / tf.to_float(tf.size(y))
95 |     return loss
96 |
--------------------------------------------------------------------------------
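The three loss terms above are combined into a single training objective using the weights from the conf/*.yml files (content_weight, style_weight, tv_weight). The snippet below is a minimal sketch of that wiring, assuming PyYAML is available and using a hypothetical Flags wrapper and build_total_loss helper; the authoritative version lives in train.py, which is not shown in this listing and may differ.

# Hypothetical sketch -- train.py is the authoritative implementation.
import yaml
import losses


class Flags(object):
    """Exposes the keys of a conf/*.yml file as attributes (FLAGS.style_weight, ...)."""
    def __init__(self, **entries):
        self.__dict__.update(entries)


def load_flags(conf_path):
    with open(conf_path) as f:               # e.g. conf_path = 'conf/candy.yml'
        return Flags(**yaml.safe_load(f))


def build_total_loss(FLAGS, endpoints_dict, style_features_t, transformed_images):
    """Weighted sum of content, style and total-variation losses."""
    c_loss = losses.content_loss(endpoints_dict, FLAGS.content_layers)
    s_loss, _ = losses.style_loss(endpoints_dict, style_features_t, FLAGS.style_layers)
    tv_loss = losses.total_variation_loss(transformed_images)
    return (FLAGS.content_weight * c_loss
            + FLAGS.style_weight * s_loss
            + FLAGS.tv_weight * tv_loss)

Here endpoints_dict holds the loss-network activations and style_features_t the precomputed style Gram matrices returned by get_style_features above.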
/model.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 |
3 |
4 | def conv2d(x, input_filters, output_filters, kernel, strides, mode='REFLECT'):
5 |     with tf.variable_scope('conv') as scope:
6 |         shape = [kernel, kernel, input_filters, output_filters]
7 |         # convolution kernel / filter / weight matrix
8 |         weight = tf.Variable(tf.truncated_normal(shape, stddev=0.1), name='weight')
9 |         x_padded = tf.pad(x, [[0, 0], [kernel // 2, kernel // 2], [kernel // 2, kernel // 2], [0, 0]], mode=mode)
10 |         return tf.nn.conv2d(x_padded, weight, strides=[1, strides, strides, 1], padding='VALID', name='conv')
11 |
12 |
13 | def conv2d_transpose(x, input_filters, output_filters, kernel, strides):
14 |     with tf.variable_scope('conv_transpose') as scope:
15 |
16 |         shape = [kernel, kernel, output_filters, input_filters]
17 |         weight = tf.Variable(tf.truncated_normal(shape, stddev=0.1), name='weight')
18 |
19 |         batch_size = tf.shape(x)[0]
20 |         height = tf.shape(x)[1] * strides
21 |         width = tf.shape(x)[2] * strides
22 |         output_shape = tf.stack([batch_size, height, width, output_filters])
23 |         return tf.nn.conv2d_transpose(x, weight, output_shape, strides=[1, strides, strides, 1], name='conv_transpose')
24 |
25 |
26 | def resize_conv2d(x, input_filters, output_filters, kernel, strides, training):
27 |     """
28 |     Upsample by resizing the feature map first and then convolving, instead of using a transposed convolution.
29 |     """
30 |     with tf.variable_scope('conv_transpose') as scope:
31 |         height = x.get_shape()[1].value if training else tf.shape(x)[1]
32 |         width = x.get_shape()[2].value if training else tf.shape(x)[2]
33 |
34 |         new_height = height * strides * 2
35 |         new_width = width * strides * 2
36 |
37 |         x_resized = tf.image.resize_images(x, [new_height, new_width], tf.image.ResizeMethod.NEAREST_NEIGHBOR)
38 |
39 |         return conv2d(x_resized, input_filters, output_filters, kernel, strides)
40 |
41 |
42 | def instance_norm(x):
43 |     """
44 |     Instance normalization in place of batch normalization.
45 |     Regularization to help prevent overfitting.
46 |     """
47 |     epsilon = 1e-9
48 |
49 |     mean, var = tf.nn.moments(x, [1, 2], keep_dims=True)
50 |
51 |     return tf.div(tf.subtract(x, mean), tf.sqrt(tf.add(var, epsilon)))
52 |
53 |
54 | def batch_norm(x, size, training, decay=0.999):
55 |     beta = tf.Variable(tf.zeros([size]), name='beta')
56 |     scale = tf.Variable(tf.ones([size]), name='scale')
57 |     pop_mean = tf.Variable(tf.zeros([size]))
58 |     pop_var = tf.Variable(tf.ones([size]))
59 |     epsilon = 1e-3
60 |
61 |     batch_mean, batch_var = tf.nn.moments(x, [0, 1, 2])
62 |     train_mean = tf.assign(pop_mean, pop_mean * decay + batch_mean * (1 - decay))
63 |     train_var = tf.assign(pop_var, pop_var * decay + batch_var * (1 - decay))
64 |
65 |     def batch_statistics():
66 |         with tf.control_dependencies([train_mean, train_var]):
67 |             return tf.nn.batch_normalization(x, batch_mean, batch_var, beta, scale, epsilon, name='batch_norm')
68 |
69 |     def population_statistics():
70 |         return tf.nn.batch_normalization(x, pop_mean, pop_var, beta, scale, epsilon, name='batch_norm')
71 |
72 |     return tf.cond(training, batch_statistics, population_statistics)
73 |
74 |
75 | def residual(x, filters, kernel, strides):
76 |     with tf.variable_scope('residual') as scope:
77 |         conv1 = conv2d(x, filters, filters, kernel, strides)
78 |         conv2 = conv2d(tf.nn.relu(conv1), filters, filters, kernel, strides)
79 |
80 |         residual = x + conv2
81 |
82 |         return residual
83 |
84 |
85 | def transform_network(image, training):
86 |     """
87 |     The image transformation (generator) network.
88 |     :param image: content image
89 |     :param training: whether the network is being trained
90 |     :return: the style-transferred image
91 |     """
92 |     image = tf.pad(image, [[0, 0], [10, 10], [10, 10], [0, 0]], mode='REFLECT')
93 |
94 |     with tf.variable_scope('conv1'):
95 |         conv1 = tf.nn.relu(instance_norm(conv2d(image, 3, 32, 9, 1)))
96 |     with tf.variable_scope('conv2'):
97 |         conv2 = tf.nn.relu(instance_norm(conv2d(conv1, 32, 64, 3, 2)))
98 |     with tf.variable_scope('conv3'):
99 |         conv3 = tf.nn.relu(instance_norm(conv2d(conv2, 64, 128, 3, 2)))
100 |     with tf.variable_scope('res1'):
101 |         res1 = residual(conv3, 128, 3, 1)
102 |     with tf.variable_scope('res2'):
103 |         res2 = residual(res1, 128, 3, 1)
104 |     with tf.variable_scope('res3'):
105 |         res3 = residual(res2, 128, 3, 1)
106 |     with tf.variable_scope('res4'):
107 |         res4 = residual(res3, 128, 3, 1)
108 |     with tf.variable_scope('res5'):
109 |         res5 = residual(res4, 128, 3, 1)
110 |     with tf.variable_scope('deconv1'):
111 |         deconv1 = tf.nn.relu(instance_norm(resize_conv2d(res5, 128, 64, 3, 2, training)))
112 |     with tf.variable_scope('deconv2'):
113 |         deconv2 = tf.nn.relu(instance_norm(resize_conv2d(deconv1, 64, 32, 3, 2, training)))
114 |     with tf.variable_scope('deconv3'):
115 |         deconv3 = tf.nn.tanh(instance_norm(conv2d(deconv2, 32, 3, 9, 1)))
116 |
117 |     y = (deconv3 + 1) * 127.5
118 |
119 |     # Remove the border effect by reducing the padding.
120 |     height = tf.shape(y)[1]
121 |     width = tf.shape(y)[2]
122 |     y = tf.slice(y, [0, 10, 10, 0], tf.stack([-1, height - 20, width - 20, -1]))
123 |
124 |     return y
125 |
--------------------------------------------------------------------------------
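With a trained checkpoint such as models/denoised_starry.ckpt-done, transform_network can be run in inference mode to stylize a single content image. The following is an illustrative sketch against the TensorFlow 1.x API; the stylize helper name is made up for the example, and the repository's own transform.py and web.py perform this step and may differ in detail.

# Illustrative sketch only -- the repository's transform.py is the real implementation.
import tensorflow as tf
import model


def stylize(content_path, checkpoint_path, output_path):
    with tf.Graph().as_default():
        img_bytes = tf.read_file(content_path)
        image = tf.image.decode_jpeg(img_bytes, channels=3)
        image = tf.expand_dims(tf.to_float(image), 0)            # add a batch dimension
        generated = model.transform_network(image, training=False)
        generated = tf.cast(tf.clip_by_value(generated, 0, 255), tf.uint8)
        encoded = tf.image.encode_jpeg(tf.squeeze(generated, [0]))
        with tf.Session() as sess:
            tf.train.Saver().restore(sess, checkpoint_path)      # e.g. 'models/denoised_starry.ckpt-done'
            with open(output_path, 'wb') as f:
                f.write(sess.run(encoded))

For example, stylize('img/content/test6.jpg', 'models/denoised_starry.ckpt-done', 'img/generated/denoised_starry_res_test6.jpg') would produce an image analogous to the generated samples listed in the tree above.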
--------------------------------------------------------------------------------
/models/denoised_starry.ckpt-done:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wisewong/ImageStyleTransform/f2f7930cbf51f544c59a949855e2403aadede41b/models/denoised_starry.ckpt-done
--------------------------------------------------------------------------------
/models/painting/painting.ckpt-1000:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wisewong/ImageStyleTransform/f2f7930cbf51f544c59a949855e2403aadede41b/models/painting/painting.ckpt-1000
--------------------------------------------------------------------------------
/models/painting/painting.ckpt-1000.meta:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wisewong/ImageStyleTransform/f2f7930cbf51f544c59a949855e2403aadede41b/models/painting/painting.ckpt-1000.meta
--------------------------------------------------------------------------------
/models/painting/painting.ckpt-2000:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wisewong/ImageStyleTransform/f2f7930cbf51f544c59a949855e2403aadede41b/models/painting/painting.ckpt-2000
--------------------------------------------------------------------------------
/models/painting/painting.ckpt-2000.meta:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wisewong/ImageStyleTransform/f2f7930cbf51f544c59a949855e2403aadede41b/models/painting/painting.ckpt-2000.meta
--------------------------------------------------------------------------------
/models/picasso/checkpoint:
--------------------------------------------------------------------------------
1 | model_checkpoint_path: "picasso.ckpt-100"
2 | all_model_checkpoint_paths: "picasso.ckpt-100"
3 |
--------------------------------------------------------------------------------
/models/picasso/picasso.ckpt-100:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wisewong/ImageStyleTransform/f2f7930cbf51f544c59a949855e2403aadede41b/models/picasso/picasso.ckpt-100
--------------------------------------------------------------------------------
/models/picasso/picasso.ckpt-100.meta:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wisewong/ImageStyleTransform/f2f7930cbf51f544c59a949855e2403aadede41b/models/picasso/picasso.ckpt-100.meta
--------------------------------------------------------------------------------
/models/scream.ckpt-done:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wisewong/ImageStyleTransform/f2f7930cbf51f544c59a949855e2403aadede41b/models/scream.ckpt-done
--------------------------------------------------------------------------------
/nets/__init__.py:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/nets/__pycache__/__init__.cpython-35.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wisewong/ImageStyleTransform/f2f7930cbf51f544c59a949855e2403aadede41b/nets/__pycache__/__init__.cpython-35.pyc -------------------------------------------------------------------------------- /nets/__pycache__/alexnet.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wisewong/ImageStyleTransform/f2f7930cbf51f544c59a949855e2403aadede41b/nets/__pycache__/alexnet.cpython-35.pyc -------------------------------------------------------------------------------- /nets/__pycache__/cifarnet.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wisewong/ImageStyleTransform/f2f7930cbf51f544c59a949855e2403aadede41b/nets/__pycache__/cifarnet.cpython-35.pyc -------------------------------------------------------------------------------- /nets/__pycache__/inception.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wisewong/ImageStyleTransform/f2f7930cbf51f544c59a949855e2403aadede41b/nets/__pycache__/inception.cpython-35.pyc -------------------------------------------------------------------------------- /nets/__pycache__/inception_resnet_v2.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wisewong/ImageStyleTransform/f2f7930cbf51f544c59a949855e2403aadede41b/nets/__pycache__/inception_resnet_v2.cpython-35.pyc -------------------------------------------------------------------------------- /nets/__pycache__/inception_utils.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wisewong/ImageStyleTransform/f2f7930cbf51f544c59a949855e2403aadede41b/nets/__pycache__/inception_utils.cpython-35.pyc -------------------------------------------------------------------------------- /nets/__pycache__/inception_v1.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wisewong/ImageStyleTransform/f2f7930cbf51f544c59a949855e2403aadede41b/nets/__pycache__/inception_v1.cpython-35.pyc -------------------------------------------------------------------------------- /nets/__pycache__/inception_v2.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wisewong/ImageStyleTransform/f2f7930cbf51f544c59a949855e2403aadede41b/nets/__pycache__/inception_v2.cpython-35.pyc -------------------------------------------------------------------------------- /nets/__pycache__/inception_v3.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wisewong/ImageStyleTransform/f2f7930cbf51f544c59a949855e2403aadede41b/nets/__pycache__/inception_v3.cpython-35.pyc -------------------------------------------------------------------------------- /nets/__pycache__/inception_v4.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wisewong/ImageStyleTransform/f2f7930cbf51f544c59a949855e2403aadede41b/nets/__pycache__/inception_v4.cpython-35.pyc -------------------------------------------------------------------------------- /nets/__pycache__/lenet.cpython-35.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/wisewong/ImageStyleTransform/f2f7930cbf51f544c59a949855e2403aadede41b/nets/__pycache__/lenet.cpython-35.pyc -------------------------------------------------------------------------------- /nets/__pycache__/nets_factory.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wisewong/ImageStyleTransform/f2f7930cbf51f544c59a949855e2403aadede41b/nets/__pycache__/nets_factory.cpython-35.pyc -------------------------------------------------------------------------------- /nets/__pycache__/overfeat.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wisewong/ImageStyleTransform/f2f7930cbf51f544c59a949855e2403aadede41b/nets/__pycache__/overfeat.cpython-35.pyc -------------------------------------------------------------------------------- /nets/__pycache__/resnet_utils.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wisewong/ImageStyleTransform/f2f7930cbf51f544c59a949855e2403aadede41b/nets/__pycache__/resnet_utils.cpython-35.pyc -------------------------------------------------------------------------------- /nets/__pycache__/resnet_v1.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wisewong/ImageStyleTransform/f2f7930cbf51f544c59a949855e2403aadede41b/nets/__pycache__/resnet_v1.cpython-35.pyc -------------------------------------------------------------------------------- /nets/__pycache__/resnet_v2.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wisewong/ImageStyleTransform/f2f7930cbf51f544c59a949855e2403aadede41b/nets/__pycache__/resnet_v2.cpython-35.pyc -------------------------------------------------------------------------------- /nets/__pycache__/vgg.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wisewong/ImageStyleTransform/f2f7930cbf51f544c59a949855e2403aadede41b/nets/__pycache__/vgg.cpython-35.pyc -------------------------------------------------------------------------------- /nets/alexnet.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Contains a model definition for AlexNet. 16 | 17 | This work was first described in: 18 | ImageNet Classification with Deep Convolutional Neural Networks 19 | Alex Krizhevsky, Ilya Sutskever and Geoffrey E. 
Hinton 20 | 21 | and later refined in: 22 | One weird trick for parallelizing convolutional neural networks 23 | Alex Krizhevsky, 2014 24 | 25 | Here we provide the implementation proposed in "One weird trick" and not 26 | "ImageNet Classification", as per the paper, the LRN layers have been removed. 27 | 28 | Usage: 29 | with slim.arg_scope(alexnet.alexnet_v2_arg_scope()): 30 | outputs, end_points = alexnet.alexnet_v2(inputs) 31 | 32 | @@alexnet_v2 33 | """ 34 | 35 | from __future__ import absolute_import 36 | from __future__ import division 37 | from __future__ import print_function 38 | 39 | import tensorflow as tf 40 | 41 | slim = tf.contrib.slim 42 | trunc_normal = lambda stddev: tf.truncated_normal_initializer(0.0, stddev) 43 | 44 | 45 | def alexnet_v2_arg_scope(weight_decay=0.0005): 46 | with slim.arg_scope([slim.conv2d, slim.fully_connected], 47 | activation_fn=tf.nn.relu, 48 | biases_initializer=tf.constant_initializer(0.1), 49 | weights_regularizer=slim.l2_regularizer(weight_decay)): 50 | with slim.arg_scope([slim.conv2d], padding='SAME'): 51 | with slim.arg_scope([slim.max_pool2d], padding='VALID') as arg_sc: 52 | return arg_sc 53 | 54 | 55 | def alexnet_v2(inputs, 56 | num_classes=1000, 57 | is_training=True, 58 | dropout_keep_prob=0.5, 59 | spatial_squeeze=True, 60 | scope='alexnet_v2'): 61 | """AlexNet version 2. 62 | 63 | Described in: http://arxiv.org/pdf/1404.5997v2.pdf 64 | Parameters from: 65 | github.com/akrizhevsky/cuda-convnet2/blob/master/layers/ 66 | layers-imagenet-1gpu.cfg 67 | 68 | Note: All the fully_connected layers have been transformed to conv2d layers. 69 | To use in classification mode, resize input to 224x224. To use in fully 70 | convolutional mode, set spatial_squeeze to false. 71 | The LRN layers have been removed and change the initializers from 72 | random_normal_initializer to xavier_initializer. 73 | 74 | Args: 75 | inputs: a tensor of size [batch_size, height, width, channels]. 76 | num_classes: number of predicted classes. 77 | is_training: whether or not the model is being trained. 78 | dropout_keep_prob: the probability that activations are kept in the dropout 79 | layers during training. 80 | spatial_squeeze: whether or not should squeeze the spatial dimensions of the 81 | outputs. Useful to remove unnecessary dimensions for classification. 82 | scope: Optional scope for the variables. 83 | 84 | Returns: 85 | the last op containing the log predictions and end_points dict. 86 | """ 87 | with tf.variable_scope(scope, 'alexnet_v2', [inputs]) as sc: 88 | end_points_collection = sc.name + '_end_points' 89 | # Collect outputs for conv2d, fully_connected and max_pool2d. 90 | with slim.arg_scope([slim.conv2d, slim.fully_connected, slim.max_pool2d], 91 | outputs_collections=[end_points_collection]): 92 | net = slim.conv2d(inputs, 64, [11, 11], 4, padding='VALID', 93 | scope='conv1') 94 | net = slim.max_pool2d(net, [3, 3], 2, scope='pool1') 95 | net = slim.conv2d(net, 192, [5, 5], scope='conv2') 96 | net = slim.max_pool2d(net, [3, 3], 2, scope='pool2') 97 | net = slim.conv2d(net, 384, [3, 3], scope='conv3') 98 | net = slim.conv2d(net, 384, [3, 3], scope='conv4') 99 | net = slim.conv2d(net, 256, [3, 3], scope='conv5') 100 | net = slim.max_pool2d(net, [3, 3], 2, scope='pool5') 101 | 102 | # Use conv2d instead of fully_connected layers. 
103 | with slim.arg_scope([slim.conv2d], 104 | weights_initializer=trunc_normal(0.005), 105 | biases_initializer=tf.constant_initializer(0.1)): 106 | net = slim.conv2d(net, 4096, [5, 5], padding='VALID', 107 | scope='fc6') 108 | net = slim.dropout(net, dropout_keep_prob, is_training=is_training, 109 | scope='dropout6') 110 | net = slim.conv2d(net, 4096, [1, 1], scope='fc7') 111 | net = slim.dropout(net, dropout_keep_prob, is_training=is_training, 112 | scope='dropout7') 113 | net = slim.conv2d(net, num_classes, [1, 1], 114 | activation_fn=None, 115 | normalizer_fn=None, 116 | biases_initializer=tf.zeros_initializer, 117 | scope='fc8') 118 | 119 | # Convert end_points_collection into a end_point dict. 120 | end_points = slim.utils.convert_collection_to_dict(end_points_collection) 121 | if spatial_squeeze: 122 | net = tf.squeeze(net, [1, 2], name='fc8/squeezed') 123 | end_points[sc.name + '/fc8'] = net 124 | return net, end_points 125 | alexnet_v2.default_image_size = 224 126 | -------------------------------------------------------------------------------- /nets/alexnet_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | """Tests for slim.nets.alexnet.""" 16 | from __future__ import absolute_import 17 | from __future__ import division 18 | from __future__ import print_function 19 | 20 | import tensorflow as tf 21 | 22 | from nets import alexnet 23 | 24 | slim = tf.contrib.slim 25 | 26 | 27 | class AlexnetV2Test(tf.test.TestCase): 28 | 29 | def testBuild(self): 30 | batch_size = 5 31 | height, width = 224, 224 32 | num_classes = 1000 33 | with self.test_session(): 34 | inputs = tf.random_uniform((batch_size, height, width, 3)) 35 | logits, _ = alexnet.alexnet_v2(inputs, num_classes) 36 | self.assertEquals(logits.op.name, 'alexnet_v2/fc8/squeezed') 37 | self.assertListEqual(logits.get_shape().as_list(), 38 | [batch_size, num_classes]) 39 | 40 | def testFullyConvolutional(self): 41 | batch_size = 1 42 | height, width = 300, 400 43 | num_classes = 1000 44 | with self.test_session(): 45 | inputs = tf.random_uniform((batch_size, height, width, 3)) 46 | logits, _ = alexnet.alexnet_v2(inputs, num_classes, spatial_squeeze=False) 47 | self.assertEquals(logits.op.name, 'alexnet_v2/fc8/BiasAdd') 48 | self.assertListEqual(logits.get_shape().as_list(), 49 | [batch_size, 4, 7, num_classes]) 50 | 51 | def testEndPoints(self): 52 | batch_size = 5 53 | height, width = 224, 224 54 | num_classes = 1000 55 | with self.test_session(): 56 | inputs = tf.random_uniform((batch_size, height, width, 3)) 57 | _, end_points = alexnet.alexnet_v2(inputs, num_classes) 58 | expected_names = ['alexnet_v2/conv1', 59 | 'alexnet_v2/pool1', 60 | 'alexnet_v2/conv2', 61 | 'alexnet_v2/pool2', 62 | 'alexnet_v2/conv3', 63 | 'alexnet_v2/conv4', 64 | 'alexnet_v2/conv5', 65 | 'alexnet_v2/pool5', 66 | 'alexnet_v2/fc6', 67 | 'alexnet_v2/fc7', 68 | 'alexnet_v2/fc8' 69 | ] 70 | self.assertSetEqual(set(end_points.keys()), set(expected_names)) 71 | 72 | def testModelVariables(self): 73 | batch_size = 5 74 | height, width = 224, 224 75 | num_classes = 1000 76 | with self.test_session(): 77 | inputs = tf.random_uniform((batch_size, height, width, 3)) 78 | alexnet.alexnet_v2(inputs, num_classes) 79 | expected_names = ['alexnet_v2/conv1/weights', 80 | 'alexnet_v2/conv1/biases', 81 | 'alexnet_v2/conv2/weights', 82 | 'alexnet_v2/conv2/biases', 83 | 'alexnet_v2/conv3/weights', 84 | 'alexnet_v2/conv3/biases', 85 | 'alexnet_v2/conv4/weights', 86 | 'alexnet_v2/conv4/biases', 87 | 'alexnet_v2/conv5/weights', 88 | 'alexnet_v2/conv5/biases', 89 | 'alexnet_v2/fc6/weights', 90 | 'alexnet_v2/fc6/biases', 91 | 'alexnet_v2/fc7/weights', 92 | 'alexnet_v2/fc7/biases', 93 | 'alexnet_v2/fc8/weights', 94 | 'alexnet_v2/fc8/biases', 95 | ] 96 | model_variables = [v.op.name for v in slim.get_model_variables()] 97 | self.assertSetEqual(set(model_variables), set(expected_names)) 98 | 99 | def testEvaluation(self): 100 | batch_size = 2 101 | height, width = 224, 224 102 | num_classes = 1000 103 | with self.test_session(): 104 | eval_inputs = tf.random_uniform((batch_size, height, width, 3)) 105 | logits, _ = alexnet.alexnet_v2(eval_inputs, is_training=False) 106 | self.assertListEqual(logits.get_shape().as_list(), 107 | [batch_size, num_classes]) 108 | predictions = tf.argmax(logits, 1) 109 | self.assertListEqual(predictions.get_shape().as_list(), [batch_size]) 110 | 111 | def testTrainEvalWithReuse(self): 112 | train_batch_size = 2 113 | eval_batch_size = 1 114 | train_height, train_width = 224, 224 115 | eval_height, eval_width = 300, 400 116 | num_classes = 1000 117 | with self.test_session(): 
118 | train_inputs = tf.random_uniform( 119 | (train_batch_size, train_height, train_width, 3)) 120 | logits, _ = alexnet.alexnet_v2(train_inputs) 121 | self.assertListEqual(logits.get_shape().as_list(), 122 | [train_batch_size, num_classes]) 123 | tf.get_variable_scope().reuse_variables() 124 | eval_inputs = tf.random_uniform( 125 | (eval_batch_size, eval_height, eval_width, 3)) 126 | logits, _ = alexnet.alexnet_v2(eval_inputs, is_training=False, 127 | spatial_squeeze=False) 128 | self.assertListEqual(logits.get_shape().as_list(), 129 | [eval_batch_size, 4, 7, num_classes]) 130 | logits = tf.reduce_mean(logits, [1, 2]) 131 | predictions = tf.argmax(logits, 1) 132 | self.assertEquals(predictions.get_shape().as_list(), [eval_batch_size]) 133 | 134 | def testForward(self): 135 | batch_size = 1 136 | height, width = 224, 224 137 | with self.test_session() as sess: 138 | inputs = tf.random_uniform((batch_size, height, width, 3)) 139 | logits, _ = alexnet.alexnet_v2(inputs) 140 | sess.run(tf.initialize_all_variables()) 141 | output = sess.run(logits) 142 | self.assertTrue(output.any()) 143 | 144 | if __name__ == '__main__': 145 | tf.test.main() 146 | -------------------------------------------------------------------------------- /nets/cifarnet.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Contains a variant of the CIFAR-10 model definition.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import tensorflow as tf 22 | 23 | slim = tf.contrib.slim 24 | 25 | trunc_normal = lambda stddev: tf.truncated_normal_initializer(stddev=stddev) 26 | 27 | 28 | def cifarnet(images, num_classes=10, is_training=False, 29 | dropout_keep_prob=0.5, 30 | prediction_fn=slim.softmax, 31 | scope='CifarNet'): 32 | """Creates a variant of the CifarNet model. 33 | 34 | Note that since the output is a set of 'logits', the values fall in the 35 | interval of (-infinity, infinity). Consequently, to convert the outputs to a 36 | probability distribution over the characters, one will need to convert them 37 | using the softmax function: 38 | 39 | logits = cifarnet.cifarnet(images, is_training=False) 40 | probabilities = tf.nn.softmax(logits) 41 | predictions = tf.argmax(logits, 1) 42 | 43 | Args: 44 | images: A batch of `Tensors` of size [batch_size, height, width, channels]. 45 | num_classes: the number of classes in the dataset. 46 | is_training: specifies whether or not we're currently training the model. 47 | This variable will determine the behaviour of the dropout layer. 48 | dropout_keep_prob: the percentage of activation values that are retained. 49 | prediction_fn: a function to get predictions out of logits. 50 | scope: Optional variable_scope. 
51 | 52 | Returns: 53 | logits: the pre-softmax activations, a tensor of size 54 | [batch_size, `num_classes`] 55 | end_points: a dictionary from components of the network to the corresponding 56 | activation. 57 | """ 58 | end_points = {} 59 | 60 | with tf.variable_scope(scope, 'CifarNet', [images, num_classes]): 61 | net = slim.conv2d(images, 64, [5, 5], scope='conv1') 62 | end_points['conv1'] = net 63 | net = slim.max_pool2d(net, [2, 2], 2, scope='pool1') 64 | end_points['pool1'] = net 65 | net = tf.nn.lrn(net, 4, bias=1.0, alpha=0.001/9.0, beta=0.75, name='norm1') 66 | net = slim.conv2d(net, 64, [5, 5], scope='conv2') 67 | end_points['conv2'] = net 68 | net = tf.nn.lrn(net, 4, bias=1.0, alpha=0.001/9.0, beta=0.75, name='norm2') 69 | net = slim.max_pool2d(net, [2, 2], 2, scope='pool2') 70 | end_points['pool2'] = net 71 | net = slim.flatten(net) 72 | end_points['Flatten'] = net 73 | net = slim.fully_connected(net, 384, scope='fc3') 74 | end_points['fc3'] = net 75 | net = slim.dropout(net, dropout_keep_prob, is_training=is_training, 76 | scope='dropout3') 77 | net = slim.fully_connected(net, 192, scope='fc4') 78 | end_points['fc4'] = net 79 | logits = slim.fully_connected(net, num_classes, 80 | biases_initializer=tf.zeros_initializer, 81 | weights_initializer=trunc_normal(1/192.0), 82 | weights_regularizer=None, 83 | activation_fn=None, 84 | scope='logits') 85 | 86 | end_points['Logits'] = logits 87 | end_points['Predictions'] = prediction_fn(logits, scope='Predictions') 88 | 89 | return logits, end_points 90 | cifarnet.default_image_size = 32 91 | 92 | 93 | def cifarnet_arg_scope(weight_decay=0.004): 94 | """Defines the default cifarnet argument scope. 95 | 96 | Args: 97 | weight_decay: The weight decay to use for regularizing the model. 98 | 99 | Returns: 100 | An `arg_scope` to use for the inception v3 model. 101 | """ 102 | with slim.arg_scope( 103 | [slim.conv2d], 104 | weights_initializer=tf.truncated_normal_initializer(stddev=5e-2), 105 | activation_fn=tf.nn.relu): 106 | with slim.arg_scope( 107 | [slim.fully_connected], 108 | biases_initializer=tf.constant_initializer(0.1), 109 | weights_initializer=trunc_normal(0.04), 110 | weights_regularizer=slim.l2_regularizer(weight_decay), 111 | activation_fn=tf.nn.relu) as sc: 112 | return sc 113 | -------------------------------------------------------------------------------- /nets/inception.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | """Brings all inception models under one namespace.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | # pylint: disable=unused-import 22 | from nets.inception_resnet_v2 import inception_resnet_v2 23 | from nets.inception_resnet_v2 import inception_resnet_v2_arg_scope 24 | from nets.inception_v1 import inception_v1 25 | from nets.inception_v1 import inception_v1_arg_scope 26 | from nets.inception_v1 import inception_v1_base 27 | from nets.inception_v2 import inception_v2 28 | from nets.inception_v2 import inception_v2_arg_scope 29 | from nets.inception_v2 import inception_v2_base 30 | from nets.inception_v3 import inception_v3 31 | from nets.inception_v3 import inception_v3_arg_scope 32 | from nets.inception_v3 import inception_v3_base 33 | from nets.inception_v4 import inception_v4 34 | from nets.inception_v4 import inception_v4_arg_scope 35 | from nets.inception_v4 import inception_v4_base 36 | # pylint: enable=unused-import 37 | -------------------------------------------------------------------------------- /nets/inception_resnet_v2.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Contains the definition of the Inception Resnet V2 architecture. 16 | 17 | As described in http://arxiv.org/abs/1602.07261. 
18 | 19 | Inception-v4, Inception-ResNet and the Impact of Residual Connections 20 | on Learning 21 | Christian Szegedy, Sergey Ioffe, Vincent Vanhoucke, Alex Alemi 22 | """ 23 | from __future__ import absolute_import 24 | from __future__ import division 25 | from __future__ import print_function 26 | 27 | 28 | import tensorflow as tf 29 | 30 | slim = tf.contrib.slim 31 | 32 | 33 | def block35(net, scale=1.0, activation_fn=tf.nn.relu, scope=None, reuse=None): 34 | """Builds the 35x35 resnet block.""" 35 | with tf.variable_scope(scope, 'Block35', [net], reuse=reuse): 36 | with tf.variable_scope('Branch_0'): 37 | tower_conv = slim.conv2d(net, 32, 1, scope='Conv2d_1x1') 38 | with tf.variable_scope('Branch_1'): 39 | tower_conv1_0 = slim.conv2d(net, 32, 1, scope='Conv2d_0a_1x1') 40 | tower_conv1_1 = slim.conv2d(tower_conv1_0, 32, 3, scope='Conv2d_0b_3x3') 41 | with tf.variable_scope('Branch_2'): 42 | tower_conv2_0 = slim.conv2d(net, 32, 1, scope='Conv2d_0a_1x1') 43 | tower_conv2_1 = slim.conv2d(tower_conv2_0, 48, 3, scope='Conv2d_0b_3x3') 44 | tower_conv2_2 = slim.conv2d(tower_conv2_1, 64, 3, scope='Conv2d_0c_3x3') 45 | mixed = tf.concat(3, [tower_conv, tower_conv1_1, tower_conv2_2]) 46 | up = slim.conv2d(mixed, net.get_shape()[3], 1, normalizer_fn=None, 47 | activation_fn=None, scope='Conv2d_1x1') 48 | net += scale * up 49 | if activation_fn: 50 | net = activation_fn(net) 51 | return net 52 | 53 | 54 | def block17(net, scale=1.0, activation_fn=tf.nn.relu, scope=None, reuse=None): 55 | """Builds the 17x17 resnet block.""" 56 | with tf.variable_scope(scope, 'Block17', [net], reuse=reuse): 57 | with tf.variable_scope('Branch_0'): 58 | tower_conv = slim.conv2d(net, 192, 1, scope='Conv2d_1x1') 59 | with tf.variable_scope('Branch_1'): 60 | tower_conv1_0 = slim.conv2d(net, 128, 1, scope='Conv2d_0a_1x1') 61 | tower_conv1_1 = slim.conv2d(tower_conv1_0, 160, [1, 7], 62 | scope='Conv2d_0b_1x7') 63 | tower_conv1_2 = slim.conv2d(tower_conv1_1, 192, [7, 1], 64 | scope='Conv2d_0c_7x1') 65 | mixed = tf.concat(3, [tower_conv, tower_conv1_2]) 66 | up = slim.conv2d(mixed, net.get_shape()[3], 1, normalizer_fn=None, 67 | activation_fn=None, scope='Conv2d_1x1') 68 | net += scale * up 69 | if activation_fn: 70 | net = activation_fn(net) 71 | return net 72 | 73 | 74 | def block8(net, scale=1.0, activation_fn=tf.nn.relu, scope=None, reuse=None): 75 | """Builds the 8x8 resnet block.""" 76 | with tf.variable_scope(scope, 'Block8', [net], reuse=reuse): 77 | with tf.variable_scope('Branch_0'): 78 | tower_conv = slim.conv2d(net, 192, 1, scope='Conv2d_1x1') 79 | with tf.variable_scope('Branch_1'): 80 | tower_conv1_0 = slim.conv2d(net, 192, 1, scope='Conv2d_0a_1x1') 81 | tower_conv1_1 = slim.conv2d(tower_conv1_0, 224, [1, 3], 82 | scope='Conv2d_0b_1x3') 83 | tower_conv1_2 = slim.conv2d(tower_conv1_1, 256, [3, 1], 84 | scope='Conv2d_0c_3x1') 85 | mixed = tf.concat(3, [tower_conv, tower_conv1_2]) 86 | up = slim.conv2d(mixed, net.get_shape()[3], 1, normalizer_fn=None, 87 | activation_fn=None, scope='Conv2d_1x1') 88 | net += scale * up 89 | if activation_fn: 90 | net = activation_fn(net) 91 | return net 92 | 93 | 94 | def inception_resnet_v2(inputs, num_classes=1001, is_training=True, 95 | dropout_keep_prob=0.8, 96 | reuse=None, 97 | scope='InceptionResnetV2'): 98 | """Creates the Inception Resnet V2 model. 99 | 100 | Args: 101 | inputs: a 4-D tensor of size [batch_size, height, width, 3]. 102 | num_classes: number of predicted classes. 103 | is_training: whether is training or not. 
104 | dropout_keep_prob: float, the fraction to keep before final layer. 105 | reuse: whether or not the network and its variables should be reused. To be 106 | able to reuse 'scope' must be given. 107 | scope: Optional variable_scope. 108 | 109 | Returns: 110 | logits: the logits outputs of the model. 111 | end_points: the set of end_points from the inception model. 112 | """ 113 | end_points = {} 114 | 115 | with tf.variable_scope(scope, 'InceptionResnetV2', [inputs], reuse=reuse): 116 | with slim.arg_scope([slim.batch_norm, slim.dropout], 117 | is_training=is_training): 118 | with slim.arg_scope([slim.conv2d, slim.max_pool2d, slim.avg_pool2d], 119 | stride=1, padding='SAME'): 120 | 121 | # 149 x 149 x 32 122 | net = slim.conv2d(inputs, 32, 3, stride=2, padding='VALID', 123 | scope='Conv2d_1a_3x3') 124 | end_points['Conv2d_1a_3x3'] = net 125 | # 147 x 147 x 32 126 | net = slim.conv2d(net, 32, 3, padding='VALID', 127 | scope='Conv2d_2a_3x3') 128 | end_points['Conv2d_2a_3x3'] = net 129 | # 147 x 147 x 64 130 | net = slim.conv2d(net, 64, 3, scope='Conv2d_2b_3x3') 131 | end_points['Conv2d_2b_3x3'] = net 132 | # 73 x 73 x 64 133 | net = slim.max_pool2d(net, 3, stride=2, padding='VALID', 134 | scope='MaxPool_3a_3x3') 135 | end_points['MaxPool_3a_3x3'] = net 136 | # 73 x 73 x 80 137 | net = slim.conv2d(net, 80, 1, padding='VALID', 138 | scope='Conv2d_3b_1x1') 139 | end_points['Conv2d_3b_1x1'] = net 140 | # 71 x 71 x 192 141 | net = slim.conv2d(net, 192, 3, padding='VALID', 142 | scope='Conv2d_4a_3x3') 143 | end_points['Conv2d_4a_3x3'] = net 144 | # 35 x 35 x 192 145 | net = slim.max_pool2d(net, 3, stride=2, padding='VALID', 146 | scope='MaxPool_5a_3x3') 147 | end_points['MaxPool_5a_3x3'] = net 148 | 149 | # 35 x 35 x 320 150 | with tf.variable_scope('Mixed_5b'): 151 | with tf.variable_scope('Branch_0'): 152 | tower_conv = slim.conv2d(net, 96, 1, scope='Conv2d_1x1') 153 | with tf.variable_scope('Branch_1'): 154 | tower_conv1_0 = slim.conv2d(net, 48, 1, scope='Conv2d_0a_1x1') 155 | tower_conv1_1 = slim.conv2d(tower_conv1_0, 64, 5, 156 | scope='Conv2d_0b_5x5') 157 | with tf.variable_scope('Branch_2'): 158 | tower_conv2_0 = slim.conv2d(net, 64, 1, scope='Conv2d_0a_1x1') 159 | tower_conv2_1 = slim.conv2d(tower_conv2_0, 96, 3, 160 | scope='Conv2d_0b_3x3') 161 | tower_conv2_2 = slim.conv2d(tower_conv2_1, 96, 3, 162 | scope='Conv2d_0c_3x3') 163 | with tf.variable_scope('Branch_3'): 164 | tower_pool = slim.avg_pool2d(net, 3, stride=1, padding='SAME', 165 | scope='AvgPool_0a_3x3') 166 | tower_pool_1 = slim.conv2d(tower_pool, 64, 1, 167 | scope='Conv2d_0b_1x1') 168 | net = tf.concat(3, [tower_conv, tower_conv1_1, 169 | tower_conv2_2, tower_pool_1]) 170 | 171 | end_points['Mixed_5b'] = net 172 | net = slim.repeat(net, 10, block35, scale=0.17) 173 | 174 | # 17 x 17 x 1024 175 | with tf.variable_scope('Mixed_6a'): 176 | with tf.variable_scope('Branch_0'): 177 | tower_conv = slim.conv2d(net, 384, 3, stride=2, padding='VALID', 178 | scope='Conv2d_1a_3x3') 179 | with tf.variable_scope('Branch_1'): 180 | tower_conv1_0 = slim.conv2d(net, 256, 1, scope='Conv2d_0a_1x1') 181 | tower_conv1_1 = slim.conv2d(tower_conv1_0, 256, 3, 182 | scope='Conv2d_0b_3x3') 183 | tower_conv1_2 = slim.conv2d(tower_conv1_1, 384, 3, 184 | stride=2, padding='VALID', 185 | scope='Conv2d_1a_3x3') 186 | with tf.variable_scope('Branch_2'): 187 | tower_pool = slim.max_pool2d(net, 3, stride=2, padding='VALID', 188 | scope='MaxPool_1a_3x3') 189 | net = tf.concat(3, [tower_conv, tower_conv1_2, tower_pool]) 190 | 191 | end_points['Mixed_6a'] = net 
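              # slim.repeat below is shorthand for applying the same layer function
              # repeatedly; the call is roughly equivalent to
              #   for _ in range(20):
              #     net = block17(net, scale=0.10)
              # i.e. 20 Inception-ResNet-B blocks, each adding a residual scaled by
              # 0.10, with slim.repeat additionally giving every repetition its own
              # variable scope.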
192 | net = slim.repeat(net, 20, block17, scale=0.10) 193 | 194 | # Auxillary tower 195 | with tf.variable_scope('AuxLogits'): 196 | aux = slim.avg_pool2d(net, 5, stride=3, padding='VALID', 197 | scope='Conv2d_1a_3x3') 198 | aux = slim.conv2d(aux, 128, 1, scope='Conv2d_1b_1x1') 199 | aux = slim.conv2d(aux, 768, aux.get_shape()[1:3], 200 | padding='VALID', scope='Conv2d_2a_5x5') 201 | aux = slim.flatten(aux) 202 | aux = slim.fully_connected(aux, num_classes, activation_fn=None, 203 | scope='Logits') 204 | end_points['AuxLogits'] = aux 205 | 206 | with tf.variable_scope('Mixed_7a'): 207 | with tf.variable_scope('Branch_0'): 208 | tower_conv = slim.conv2d(net, 256, 1, scope='Conv2d_0a_1x1') 209 | tower_conv_1 = slim.conv2d(tower_conv, 384, 3, stride=2, 210 | padding='VALID', scope='Conv2d_1a_3x3') 211 | with tf.variable_scope('Branch_1'): 212 | tower_conv1 = slim.conv2d(net, 256, 1, scope='Conv2d_0a_1x1') 213 | tower_conv1_1 = slim.conv2d(tower_conv1, 288, 3, stride=2, 214 | padding='VALID', scope='Conv2d_1a_3x3') 215 | with tf.variable_scope('Branch_2'): 216 | tower_conv2 = slim.conv2d(net, 256, 1, scope='Conv2d_0a_1x1') 217 | tower_conv2_1 = slim.conv2d(tower_conv2, 288, 3, 218 | scope='Conv2d_0b_3x3') 219 | tower_conv2_2 = slim.conv2d(tower_conv2_1, 320, 3, stride=2, 220 | padding='VALID', scope='Conv2d_1a_3x3') 221 | with tf.variable_scope('Branch_3'): 222 | tower_pool = slim.max_pool2d(net, 3, stride=2, padding='VALID', 223 | scope='MaxPool_1a_3x3') 224 | net = tf.concat(3, [tower_conv_1, tower_conv1_1, 225 | tower_conv2_2, tower_pool]) 226 | 227 | end_points['Mixed_7a'] = net 228 | 229 | net = slim.repeat(net, 9, block8, scale=0.20) 230 | net = block8(net, activation_fn=None) 231 | 232 | net = slim.conv2d(net, 1536, 1, scope='Conv2d_7b_1x1') 233 | end_points['Conv2d_7b_1x1'] = net 234 | 235 | with tf.variable_scope('Logits'): 236 | end_points['PrePool'] = net 237 | net = slim.avg_pool2d(net, net.get_shape()[1:3], padding='VALID', 238 | scope='AvgPool_1a_8x8') 239 | net = slim.flatten(net) 240 | 241 | net = slim.dropout(net, dropout_keep_prob, is_training=is_training, 242 | scope='Dropout') 243 | 244 | end_points['PreLogitsFlatten'] = net 245 | logits = slim.fully_connected(net, num_classes, activation_fn=None, 246 | scope='Logits') 247 | end_points['Logits'] = logits 248 | end_points['Predictions'] = tf.nn.softmax(logits, name='Predictions') 249 | 250 | return logits, end_points 251 | inception_resnet_v2.default_image_size = 299 252 | 253 | 254 | def inception_resnet_v2_arg_scope(weight_decay=0.00004, 255 | batch_norm_decay=0.9997, 256 | batch_norm_epsilon=0.001): 257 | """Yields the scope with the default parameters for inception_resnet_v2. 258 | 259 | Args: 260 | weight_decay: the weight decay for weights variables. 261 | batch_norm_decay: decay for the moving average of batch_norm momentums. 262 | batch_norm_epsilon: small float added to variance to avoid dividing by zero. 263 | 264 | Returns: 265 | a arg_scope with the parameters needed for inception_resnet_v2. 266 | """ 267 | # Set weight_decay for weights in conv2d and fully_connected layers. 268 | with slim.arg_scope([slim.conv2d, slim.fully_connected], 269 | weights_regularizer=slim.l2_regularizer(weight_decay), 270 | biases_regularizer=slim.l2_regularizer(weight_decay)): 271 | 272 | batch_norm_params = { 273 | 'decay': batch_norm_decay, 274 | 'epsilon': batch_norm_epsilon, 275 | } 276 | # Set activation_fn and parameters for batch_norm. 
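    # Every conv2d opened under this scope therefore gets a ReLU activation
    # followed by batch normalization using the decay/epsilon collected above;
    # callers typically wrap model construction in
    # `with slim.arg_scope(inception_resnet_v2_arg_scope()):`.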
277 | with slim.arg_scope([slim.conv2d], activation_fn=tf.nn.relu, 278 | normalizer_fn=slim.batch_norm, 279 | normalizer_params=batch_norm_params) as scope: 280 | return scope 281 | -------------------------------------------------------------------------------- /nets/inception_resnet_v2_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Tests for slim.inception_resnet_v2.""" 16 | from __future__ import absolute_import 17 | from __future__ import division 18 | from __future__ import print_function 19 | 20 | import tensorflow as tf 21 | 22 | from nets import inception 23 | 24 | 25 | class InceptionTest(tf.test.TestCase): 26 | 27 | def testBuildLogits(self): 28 | batch_size = 5 29 | height, width = 299, 299 30 | num_classes = 1000 31 | with self.test_session(): 32 | inputs = tf.random_uniform((batch_size, height, width, 3)) 33 | logits, _ = inception.inception_resnet_v2(inputs, num_classes) 34 | self.assertTrue(logits.op.name.startswith('InceptionResnetV2/Logits')) 35 | self.assertListEqual(logits.get_shape().as_list(), 36 | [batch_size, num_classes]) 37 | 38 | def testBuildEndPoints(self): 39 | batch_size = 5 40 | height, width = 299, 299 41 | num_classes = 1000 42 | with self.test_session(): 43 | inputs = tf.random_uniform((batch_size, height, width, 3)) 44 | _, end_points = inception.inception_resnet_v2(inputs, num_classes) 45 | self.assertTrue('Logits' in end_points) 46 | logits = end_points['Logits'] 47 | self.assertListEqual(logits.get_shape().as_list(), 48 | [batch_size, num_classes]) 49 | self.assertTrue('AuxLogits' in end_points) 50 | aux_logits = end_points['AuxLogits'] 51 | self.assertListEqual(aux_logits.get_shape().as_list(), 52 | [batch_size, num_classes]) 53 | pre_pool = end_points['PrePool'] 54 | self.assertListEqual(pre_pool.get_shape().as_list(), 55 | [batch_size, 8, 8, 1536]) 56 | 57 | def testVariablesSetDevice(self): 58 | batch_size = 5 59 | height, width = 299, 299 60 | num_classes = 1000 61 | with self.test_session(): 62 | inputs = tf.random_uniform((batch_size, height, width, 3)) 63 | # Force all Variables to reside on the device. 
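      # Two copies of the network are built, one pinned to /cpu:0 and one to
      # /gpu:0; the loops below assert that every variable created under each
      # scope reports the requested device.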
64 | with tf.variable_scope('on_cpu'), tf.device('/cpu:0'): 65 | inception.inception_resnet_v2(inputs, num_classes) 66 | with tf.variable_scope('on_gpu'), tf.device('/gpu:0'): 67 | inception.inception_resnet_v2(inputs, num_classes) 68 | for v in tf.get_collection(tf.GraphKeys.VARIABLES, scope='on_cpu'): 69 | self.assertDeviceEqual(v.device, '/cpu:0') 70 | for v in tf.get_collection(tf.GraphKeys.VARIABLES, scope='on_gpu'): 71 | self.assertDeviceEqual(v.device, '/gpu:0') 72 | 73 | def testHalfSizeImages(self): 74 | batch_size = 5 75 | height, width = 150, 150 76 | num_classes = 1000 77 | with self.test_session(): 78 | inputs = tf.random_uniform((batch_size, height, width, 3)) 79 | logits, end_points = inception.inception_resnet_v2(inputs, num_classes) 80 | self.assertTrue(logits.op.name.startswith('InceptionResnetV2/Logits')) 81 | self.assertListEqual(logits.get_shape().as_list(), 82 | [batch_size, num_classes]) 83 | pre_pool = end_points['PrePool'] 84 | self.assertListEqual(pre_pool.get_shape().as_list(), 85 | [batch_size, 3, 3, 1536]) 86 | 87 | def testUnknownBatchSize(self): 88 | batch_size = 1 89 | height, width = 299, 299 90 | num_classes = 1000 91 | with self.test_session() as sess: 92 | inputs = tf.placeholder(tf.float32, (None, height, width, 3)) 93 | logits, _ = inception.inception_resnet_v2(inputs, num_classes) 94 | self.assertTrue(logits.op.name.startswith('InceptionResnetV2/Logits')) 95 | self.assertListEqual(logits.get_shape().as_list(), 96 | [None, num_classes]) 97 | images = tf.random_uniform((batch_size, height, width, 3)) 98 | sess.run(tf.initialize_all_variables()) 99 | output = sess.run(logits, {inputs: images.eval()}) 100 | self.assertEquals(output.shape, (batch_size, num_classes)) 101 | 102 | def testEvaluation(self): 103 | batch_size = 2 104 | height, width = 299, 299 105 | num_classes = 1000 106 | with self.test_session() as sess: 107 | eval_inputs = tf.random_uniform((batch_size, height, width, 3)) 108 | logits, _ = inception.inception_resnet_v2(eval_inputs, 109 | num_classes, 110 | is_training=False) 111 | predictions = tf.argmax(logits, 1) 112 | sess.run(tf.initialize_all_variables()) 113 | output = sess.run(predictions) 114 | self.assertEquals(output.shape, (batch_size,)) 115 | 116 | def testTrainEvalWithReuse(self): 117 | train_batch_size = 5 118 | eval_batch_size = 2 119 | height, width = 150, 150 120 | num_classes = 1000 121 | with self.test_session() as sess: 122 | train_inputs = tf.random_uniform((train_batch_size, height, width, 3)) 123 | inception.inception_resnet_v2(train_inputs, num_classes) 124 | eval_inputs = tf.random_uniform((eval_batch_size, height, width, 3)) 125 | logits, _ = inception.inception_resnet_v2(eval_inputs, 126 | num_classes, 127 | is_training=False, 128 | reuse=True) 129 | predictions = tf.argmax(logits, 1) 130 | sess.run(tf.initialize_all_variables()) 131 | output = sess.run(predictions) 132 | self.assertEquals(output.shape, (eval_batch_size,)) 133 | 134 | 135 | if __name__ == '__main__': 136 | tf.test.main() 137 | -------------------------------------------------------------------------------- /nets/inception_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Contains common code shared by all inception models. 16 | 17 | Usage of arg scope: 18 | with slim.arg_scope(inception_arg_scope()): 19 | logits, end_points = inception.inception_v3(images, num_classes, 20 | is_training=is_training) 21 | 22 | """ 23 | from __future__ import absolute_import 24 | from __future__ import division 25 | from __future__ import print_function 26 | 27 | import tensorflow as tf 28 | 29 | slim = tf.contrib.slim 30 | 31 | 32 | def inception_arg_scope(weight_decay=0.00004, 33 | use_batch_norm=True, 34 | batch_norm_decay=0.9997, 35 | batch_norm_epsilon=0.001): 36 | """Defines the default arg scope for inception models. 37 | 38 | Args: 39 | weight_decay: The weight decay to use for regularizing the model. 40 | use_batch_norm: "If `True`, batch_norm is applied after each convolution. 41 | batch_norm_decay: Decay for batch norm moving average. 42 | batch_norm_epsilon: Small float added to variance to avoid dividing by zero 43 | in batch norm. 44 | 45 | Returns: 46 | An `arg_scope` to use for the inception models. 47 | """ 48 | batch_norm_params = { 49 | # Decay for the moving averages. 50 | 'decay': batch_norm_decay, 51 | # epsilon to prevent 0s in variance. 52 | 'epsilon': batch_norm_epsilon, 53 | # collection containing update_ops. 54 | 'updates_collections': tf.GraphKeys.UPDATE_OPS, 55 | } 56 | if use_batch_norm: 57 | normalizer_fn = slim.batch_norm 58 | normalizer_params = batch_norm_params 59 | else: 60 | normalizer_fn = None 61 | normalizer_params = {} 62 | # Set weight_decay for weights in Conv and FC layers. 63 | with slim.arg_scope([slim.conv2d, slim.fully_connected], 64 | weights_regularizer=slim.l2_regularizer(weight_decay)): 65 | with slim.arg_scope( 66 | [slim.conv2d], 67 | weights_initializer=slim.variance_scaling_initializer(), 68 | activation_fn=tf.nn.relu, 69 | normalizer_fn=normalizer_fn, 70 | normalizer_params=normalizer_params) as sc: 71 | return sc 72 | -------------------------------------------------------------------------------- /nets/inception_v1_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | """Tests for nets.inception_v1.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import numpy as np 22 | import tensorflow as tf 23 | 24 | from nets import inception 25 | 26 | slim = tf.contrib.slim 27 | 28 | 29 | class InceptionV1Test(tf.test.TestCase): 30 | 31 | def testBuildClassificationNetwork(self): 32 | batch_size = 5 33 | height, width = 224, 224 34 | num_classes = 1000 35 | 36 | inputs = tf.random_uniform((batch_size, height, width, 3)) 37 | logits, end_points = inception.inception_v1(inputs, num_classes) 38 | self.assertTrue(logits.op.name.startswith('InceptionV1/Logits')) 39 | self.assertListEqual(logits.get_shape().as_list(), 40 | [batch_size, num_classes]) 41 | self.assertTrue('Predictions' in end_points) 42 | self.assertListEqual(end_points['Predictions'].get_shape().as_list(), 43 | [batch_size, num_classes]) 44 | 45 | def testBuildBaseNetwork(self): 46 | batch_size = 5 47 | height, width = 224, 224 48 | 49 | inputs = tf.random_uniform((batch_size, height, width, 3)) 50 | mixed_6c, end_points = inception.inception_v1_base(inputs) 51 | self.assertTrue(mixed_6c.op.name.startswith('InceptionV1/Mixed_5c')) 52 | self.assertListEqual(mixed_6c.get_shape().as_list(), 53 | [batch_size, 7, 7, 1024]) 54 | expected_endpoints = ['Conv2d_1a_7x7', 'MaxPool_2a_3x3', 'Conv2d_2b_1x1', 55 | 'Conv2d_2c_3x3', 'MaxPool_3a_3x3', 'Mixed_3b', 56 | 'Mixed_3c', 'MaxPool_4a_3x3', 'Mixed_4b', 'Mixed_4c', 57 | 'Mixed_4d', 'Mixed_4e', 'Mixed_4f', 'MaxPool_5a_2x2', 58 | 'Mixed_5b', 'Mixed_5c'] 59 | self.assertItemsEqual(end_points.keys(), expected_endpoints) 60 | 61 | def testBuildOnlyUptoFinalEndpoint(self): 62 | batch_size = 5 63 | height, width = 224, 224 64 | endpoints = ['Conv2d_1a_7x7', 'MaxPool_2a_3x3', 'Conv2d_2b_1x1', 65 | 'Conv2d_2c_3x3', 'MaxPool_3a_3x3', 'Mixed_3b', 'Mixed_3c', 66 | 'MaxPool_4a_3x3', 'Mixed_4b', 'Mixed_4c', 'Mixed_4d', 67 | 'Mixed_4e', 'Mixed_4f', 'MaxPool_5a_2x2', 'Mixed_5b', 68 | 'Mixed_5c'] 69 | for index, endpoint in enumerate(endpoints): 70 | with tf.Graph().as_default(): 71 | inputs = tf.random_uniform((batch_size, height, width, 3)) 72 | out_tensor, end_points = inception.inception_v1_base( 73 | inputs, final_endpoint=endpoint) 74 | self.assertTrue(out_tensor.op.name.startswith( 75 | 'InceptionV1/' + endpoint)) 76 | self.assertItemsEqual(endpoints[:index+1], end_points) 77 | 78 | def testBuildAndCheckAllEndPointsUptoMixed5c(self): 79 | batch_size = 5 80 | height, width = 224, 224 81 | 82 | inputs = tf.random_uniform((batch_size, height, width, 3)) 83 | _, end_points = inception.inception_v1_base(inputs, 84 | final_endpoint='Mixed_5c') 85 | endpoints_shapes = {'Conv2d_1a_7x7': [5, 112, 112, 64], 86 | 'MaxPool_2a_3x3': [5, 56, 56, 64], 87 | 'Conv2d_2b_1x1': [5, 56, 56, 64], 88 | 'Conv2d_2c_3x3': [5, 56, 56, 192], 89 | 'MaxPool_3a_3x3': [5, 28, 28, 192], 90 | 'Mixed_3b': [5, 28, 28, 256], 91 | 'Mixed_3c': [5, 28, 28, 480], 92 | 'MaxPool_4a_3x3': [5, 14, 14, 480], 93 | 'Mixed_4b': [5, 14, 14, 512], 94 | 'Mixed_4c': [5, 14, 14, 512], 95 | 'Mixed_4d': [5, 14, 14, 512], 96 | 'Mixed_4e': [5, 14, 14, 528], 97 | 'Mixed_4f': [5, 14, 14, 832], 98 | 'MaxPool_5a_2x2': [5, 7, 7, 832], 99 | 'Mixed_5b': [5, 7, 7, 832], 100 | 'Mixed_5c': [5, 7, 7, 1024]} 101 | 102 | self.assertItemsEqual(endpoints_shapes.keys(), end_points.keys()) 103 | for endpoint_name in endpoints_shapes: 104 | expected_shape = 
endpoints_shapes[endpoint_name] 105 | self.assertTrue(endpoint_name in end_points) 106 | self.assertListEqual(end_points[endpoint_name].get_shape().as_list(), 107 | expected_shape) 108 | 109 | def testModelHasExpectedNumberOfParameters(self): 110 | batch_size = 5 111 | height, width = 224, 224 112 | inputs = tf.random_uniform((batch_size, height, width, 3)) 113 | with slim.arg_scope(inception.inception_v1_arg_scope()): 114 | inception.inception_v1_base(inputs) 115 | total_params, _ = slim.model_analyzer.analyze_vars( 116 | slim.get_model_variables()) 117 | self.assertAlmostEqual(5607184, total_params) 118 | 119 | def testHalfSizeImages(self): 120 | batch_size = 5 121 | height, width = 112, 112 122 | 123 | inputs = tf.random_uniform((batch_size, height, width, 3)) 124 | mixed_5c, _ = inception.inception_v1_base(inputs) 125 | self.assertTrue(mixed_5c.op.name.startswith('InceptionV1/Mixed_5c')) 126 | self.assertListEqual(mixed_5c.get_shape().as_list(), 127 | [batch_size, 4, 4, 1024]) 128 | 129 | def testUnknownImageShape(self): 130 | tf.reset_default_graph() 131 | batch_size = 2 132 | height, width = 224, 224 133 | num_classes = 1000 134 | input_np = np.random.uniform(0, 1, (batch_size, height, width, 3)) 135 | with self.test_session() as sess: 136 | inputs = tf.placeholder(tf.float32, shape=(batch_size, None, None, 3)) 137 | logits, end_points = inception.inception_v1(inputs, num_classes) 138 | self.assertTrue(logits.op.name.startswith('InceptionV1/Logits')) 139 | self.assertListEqual(logits.get_shape().as_list(), 140 | [batch_size, num_classes]) 141 | pre_pool = end_points['Mixed_5c'] 142 | feed_dict = {inputs: input_np} 143 | tf.initialize_all_variables().run() 144 | pre_pool_out = sess.run(pre_pool, feed_dict=feed_dict) 145 | self.assertListEqual(list(pre_pool_out.shape), [batch_size, 7, 7, 1024]) 146 | 147 | def testUnknowBatchSize(self): 148 | batch_size = 1 149 | height, width = 224, 224 150 | num_classes = 1000 151 | 152 | inputs = tf.placeholder(tf.float32, (None, height, width, 3)) 153 | logits, _ = inception.inception_v1(inputs, num_classes) 154 | self.assertTrue(logits.op.name.startswith('InceptionV1/Logits')) 155 | self.assertListEqual(logits.get_shape().as_list(), 156 | [None, num_classes]) 157 | images = tf.random_uniform((batch_size, height, width, 3)) 158 | 159 | with self.test_session() as sess: 160 | sess.run(tf.initialize_all_variables()) 161 | output = sess.run(logits, {inputs: images.eval()}) 162 | self.assertEquals(output.shape, (batch_size, num_classes)) 163 | 164 | def testEvaluation(self): 165 | batch_size = 2 166 | height, width = 224, 224 167 | num_classes = 1000 168 | 169 | eval_inputs = tf.random_uniform((batch_size, height, width, 3)) 170 | logits, _ = inception.inception_v1(eval_inputs, num_classes, 171 | is_training=False) 172 | predictions = tf.argmax(logits, 1) 173 | 174 | with self.test_session() as sess: 175 | sess.run(tf.initialize_all_variables()) 176 | output = sess.run(predictions) 177 | self.assertEquals(output.shape, (batch_size,)) 178 | 179 | def testTrainEvalWithReuse(self): 180 | train_batch_size = 5 181 | eval_batch_size = 2 182 | height, width = 224, 224 183 | num_classes = 1000 184 | 185 | train_inputs = tf.random_uniform((train_batch_size, height, width, 3)) 186 | inception.inception_v1(train_inputs, num_classes) 187 | eval_inputs = tf.random_uniform((eval_batch_size, height, width, 3)) 188 | logits, _ = inception.inception_v1(eval_inputs, num_classes, reuse=True) 189 | predictions = tf.argmax(logits, 1) 190 | 191 | with self.test_session() as 
sess: 192 | sess.run(tf.initialize_all_variables()) 193 | output = sess.run(predictions) 194 | self.assertEquals(output.shape, (eval_batch_size,)) 195 | 196 | def testLogitsNotSqueezed(self): 197 | num_classes = 25 198 | images = tf.random_uniform([1, 224, 224, 3]) 199 | logits, _ = inception.inception_v1(images, 200 | num_classes=num_classes, 201 | spatial_squeeze=False) 202 | 203 | with self.test_session() as sess: 204 | tf.initialize_all_variables().run() 205 | logits_out = sess.run(logits) 206 | self.assertListEqual(list(logits_out.shape), [1, 1, 1, num_classes]) 207 | 208 | 209 | if __name__ == '__main__': 210 | tf.test.main() 211 | -------------------------------------------------------------------------------- /nets/inception_v2_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Tests for nets.inception_v2.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import numpy as np 22 | import tensorflow as tf 23 | 24 | from nets import inception 25 | 26 | slim = tf.contrib.slim 27 | 28 | 29 | class InceptionV2Test(tf.test.TestCase): 30 | 31 | def testBuildClassificationNetwork(self): 32 | batch_size = 5 33 | height, width = 224, 224 34 | num_classes = 1000 35 | 36 | inputs = tf.random_uniform((batch_size, height, width, 3)) 37 | logits, end_points = inception.inception_v2(inputs, num_classes) 38 | self.assertTrue(logits.op.name.startswith('InceptionV2/Logits')) 39 | self.assertListEqual(logits.get_shape().as_list(), 40 | [batch_size, num_classes]) 41 | self.assertTrue('Predictions' in end_points) 42 | self.assertListEqual(end_points['Predictions'].get_shape().as_list(), 43 | [batch_size, num_classes]) 44 | 45 | def testBuildBaseNetwork(self): 46 | batch_size = 5 47 | height, width = 224, 224 48 | 49 | inputs = tf.random_uniform((batch_size, height, width, 3)) 50 | mixed_5c, end_points = inception.inception_v2_base(inputs) 51 | self.assertTrue(mixed_5c.op.name.startswith('InceptionV2/Mixed_5c')) 52 | self.assertListEqual(mixed_5c.get_shape().as_list(), 53 | [batch_size, 7, 7, 1024]) 54 | expected_endpoints = ['Mixed_3b', 'Mixed_3c', 'Mixed_4a', 'Mixed_4b', 55 | 'Mixed_4c', 'Mixed_4d', 'Mixed_4e', 'Mixed_5a', 56 | 'Mixed_5b', 'Mixed_5c', 'Conv2d_1a_7x7', 57 | 'MaxPool_2a_3x3', 'Conv2d_2b_1x1', 'Conv2d_2c_3x3', 58 | 'MaxPool_3a_3x3'] 59 | self.assertItemsEqual(end_points.keys(), expected_endpoints) 60 | 61 | def testBuildOnlyUptoFinalEndpoint(self): 62 | batch_size = 5 63 | height, width = 224, 224 64 | endpoints = ['Conv2d_1a_7x7', 'MaxPool_2a_3x3', 'Conv2d_2b_1x1', 65 | 'Conv2d_2c_3x3', 'MaxPool_3a_3x3', 'Mixed_3b', 'Mixed_3c', 66 | 'Mixed_4a', 'Mixed_4b', 'Mixed_4c', 'Mixed_4d', 'Mixed_4e', 67 | 'Mixed_5a', 'Mixed_5b', 
'Mixed_5c'] 68 | for index, endpoint in enumerate(endpoints): 69 | with tf.Graph().as_default(): 70 | inputs = tf.random_uniform((batch_size, height, width, 3)) 71 | out_tensor, end_points = inception.inception_v2_base( 72 | inputs, final_endpoint=endpoint) 73 | self.assertTrue(out_tensor.op.name.startswith( 74 | 'InceptionV2/' + endpoint)) 75 | self.assertItemsEqual(endpoints[:index+1], end_points) 76 | 77 | def testBuildAndCheckAllEndPointsUptoMixed5c(self): 78 | batch_size = 5 79 | height, width = 224, 224 80 | 81 | inputs = tf.random_uniform((batch_size, height, width, 3)) 82 | _, end_points = inception.inception_v2_base(inputs, 83 | final_endpoint='Mixed_5c') 84 | endpoints_shapes = {'Mixed_3b': [batch_size, 28, 28, 256], 85 | 'Mixed_3c': [batch_size, 28, 28, 320], 86 | 'Mixed_4a': [batch_size, 14, 14, 576], 87 | 'Mixed_4b': [batch_size, 14, 14, 576], 88 | 'Mixed_4c': [batch_size, 14, 14, 576], 89 | 'Mixed_4d': [batch_size, 14, 14, 576], 90 | 'Mixed_4e': [batch_size, 14, 14, 576], 91 | 'Mixed_5a': [batch_size, 7, 7, 1024], 92 | 'Mixed_5b': [batch_size, 7, 7, 1024], 93 | 'Mixed_5c': [batch_size, 7, 7, 1024], 94 | 'Conv2d_1a_7x7': [batch_size, 112, 112, 64], 95 | 'MaxPool_2a_3x3': [batch_size, 56, 56, 64], 96 | 'Conv2d_2b_1x1': [batch_size, 56, 56, 64], 97 | 'Conv2d_2c_3x3': [batch_size, 56, 56, 192], 98 | 'MaxPool_3a_3x3': [batch_size, 28, 28, 192]} 99 | self.assertItemsEqual(endpoints_shapes.keys(), end_points.keys()) 100 | for endpoint_name in endpoints_shapes: 101 | expected_shape = endpoints_shapes[endpoint_name] 102 | self.assertTrue(endpoint_name in end_points) 103 | self.assertListEqual(end_points[endpoint_name].get_shape().as_list(), 104 | expected_shape) 105 | 106 | def testModelHasExpectedNumberOfParameters(self): 107 | batch_size = 5 108 | height, width = 224, 224 109 | inputs = tf.random_uniform((batch_size, height, width, 3)) 110 | with slim.arg_scope(inception.inception_v2_arg_scope()): 111 | inception.inception_v2_base(inputs) 112 | total_params, _ = slim.model_analyzer.analyze_vars( 113 | slim.get_model_variables()) 114 | self.assertAlmostEqual(10173112, total_params) 115 | 116 | def testBuildEndPointsWithDepthMultiplierLessThanOne(self): 117 | batch_size = 5 118 | height, width = 224, 224 119 | num_classes = 1000 120 | 121 | inputs = tf.random_uniform((batch_size, height, width, 3)) 122 | _, end_points = inception.inception_v2(inputs, num_classes) 123 | 124 | endpoint_keys = [key for key in end_points.keys() 125 | if key.startswith('Mixed') or key.startswith('Conv')] 126 | 127 | _, end_points_with_multiplier = inception.inception_v2( 128 | inputs, num_classes, scope='depth_multiplied_net', 129 | depth_multiplier=0.5) 130 | 131 | for key in endpoint_keys: 132 | original_depth = end_points[key].get_shape().as_list()[3] 133 | new_depth = end_points_with_multiplier[key].get_shape().as_list()[3] 134 | self.assertEqual(0.5 * original_depth, new_depth) 135 | 136 | def testBuildEndPointsWithDepthMultiplierGreaterThanOne(self): 137 | batch_size = 5 138 | height, width = 224, 224 139 | num_classes = 1000 140 | 141 | inputs = tf.random_uniform((batch_size, height, width, 3)) 142 | _, end_points = inception.inception_v2(inputs, num_classes) 143 | 144 | endpoint_keys = [key for key in end_points.keys() 145 | if key.startswith('Mixed') or key.startswith('Conv')] 146 | 147 | _, end_points_with_multiplier = inception.inception_v2( 148 | inputs, num_classes, scope='depth_multiplied_net', 149 | depth_multiplier=2.0) 150 | 151 | for key in endpoint_keys: 152 | original_depth = 
end_points[key].get_shape().as_list()[3] 153 | new_depth = end_points_with_multiplier[key].get_shape().as_list()[3] 154 | self.assertEqual(2.0 * original_depth, new_depth) 155 | 156 | def testRaiseValueErrorWithInvalidDepthMultiplier(self): 157 | batch_size = 5 158 | height, width = 224, 224 159 | num_classes = 1000 160 | 161 | inputs = tf.random_uniform((batch_size, height, width, 3)) 162 | with self.assertRaises(ValueError): 163 | _ = inception.inception_v2(inputs, num_classes, depth_multiplier=-0.1) 164 | with self.assertRaises(ValueError): 165 | _ = inception.inception_v2(inputs, num_classes, depth_multiplier=0.0) 166 | 167 | def testHalfSizeImages(self): 168 | batch_size = 5 169 | height, width = 112, 112 170 | num_classes = 1000 171 | 172 | inputs = tf.random_uniform((batch_size, height, width, 3)) 173 | logits, end_points = inception.inception_v2(inputs, num_classes) 174 | self.assertTrue(logits.op.name.startswith('InceptionV2/Logits')) 175 | self.assertListEqual(logits.get_shape().as_list(), 176 | [batch_size, num_classes]) 177 | pre_pool = end_points['Mixed_5c'] 178 | self.assertListEqual(pre_pool.get_shape().as_list(), 179 | [batch_size, 4, 4, 1024]) 180 | 181 | def testUnknownImageShape(self): 182 | tf.reset_default_graph() 183 | batch_size = 2 184 | height, width = 224, 224 185 | num_classes = 1000 186 | input_np = np.random.uniform(0, 1, (batch_size, height, width, 3)) 187 | with self.test_session() as sess: 188 | inputs = tf.placeholder(tf.float32, shape=(batch_size, None, None, 3)) 189 | logits, end_points = inception.inception_v2(inputs, num_classes) 190 | self.assertTrue(logits.op.name.startswith('InceptionV2/Logits')) 191 | self.assertListEqual(logits.get_shape().as_list(), 192 | [batch_size, num_classes]) 193 | pre_pool = end_points['Mixed_5c'] 194 | feed_dict = {inputs: input_np} 195 | tf.initialize_all_variables().run() 196 | pre_pool_out = sess.run(pre_pool, feed_dict=feed_dict) 197 | self.assertListEqual(list(pre_pool_out.shape), [batch_size, 7, 7, 1024]) 198 | 199 | def testUnknowBatchSize(self): 200 | batch_size = 1 201 | height, width = 224, 224 202 | num_classes = 1000 203 | 204 | inputs = tf.placeholder(tf.float32, (None, height, width, 3)) 205 | logits, _ = inception.inception_v2(inputs, num_classes) 206 | self.assertTrue(logits.op.name.startswith('InceptionV2/Logits')) 207 | self.assertListEqual(logits.get_shape().as_list(), 208 | [None, num_classes]) 209 | images = tf.random_uniform((batch_size, height, width, 3)) 210 | 211 | with self.test_session() as sess: 212 | sess.run(tf.initialize_all_variables()) 213 | output = sess.run(logits, {inputs: images.eval()}) 214 | self.assertEquals(output.shape, (batch_size, num_classes)) 215 | 216 | def testEvaluation(self): 217 | batch_size = 2 218 | height, width = 224, 224 219 | num_classes = 1000 220 | 221 | eval_inputs = tf.random_uniform((batch_size, height, width, 3)) 222 | logits, _ = inception.inception_v2(eval_inputs, num_classes, 223 | is_training=False) 224 | predictions = tf.argmax(logits, 1) 225 | 226 | with self.test_session() as sess: 227 | sess.run(tf.initialize_all_variables()) 228 | output = sess.run(predictions) 229 | self.assertEquals(output.shape, (batch_size,)) 230 | 231 | def testTrainEvalWithReuse(self): 232 | train_batch_size = 5 233 | eval_batch_size = 2 234 | height, width = 150, 150 235 | num_classes = 1000 236 | 237 | train_inputs = tf.random_uniform((train_batch_size, height, width, 3)) 238 | inception.inception_v2(train_inputs, num_classes) 239 | eval_inputs = 
tf.random_uniform((eval_batch_size, height, width, 3)) 240 | logits, _ = inception.inception_v2(eval_inputs, num_classes, reuse=True) 241 | predictions = tf.argmax(logits, 1) 242 | 243 | with self.test_session() as sess: 244 | sess.run(tf.initialize_all_variables()) 245 | output = sess.run(predictions) 246 | self.assertEquals(output.shape, (eval_batch_size,)) 247 | 248 | def testLogitsNotSqueezed(self): 249 | num_classes = 25 250 | images = tf.random_uniform([1, 224, 224, 3]) 251 | logits, _ = inception.inception_v2(images, 252 | num_classes=num_classes, 253 | spatial_squeeze=False) 254 | 255 | with self.test_session() as sess: 256 | tf.initialize_all_variables().run() 257 | logits_out = sess.run(logits) 258 | self.assertListEqual(list(logits_out.shape), [1, 1, 1, num_classes]) 259 | 260 | 261 | if __name__ == '__main__': 262 | tf.test.main() 263 | -------------------------------------------------------------------------------- /nets/inception_v3_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Tests for nets.inception_v1.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import numpy as np 22 | import tensorflow as tf 23 | 24 | from nets import inception 25 | 26 | slim = tf.contrib.slim 27 | 28 | 29 | class InceptionV3Test(tf.test.TestCase): 30 | 31 | def testBuildClassificationNetwork(self): 32 | batch_size = 5 33 | height, width = 299, 299 34 | num_classes = 1000 35 | 36 | inputs = tf.random_uniform((batch_size, height, width, 3)) 37 | logits, end_points = inception.inception_v3(inputs, num_classes) 38 | self.assertTrue(logits.op.name.startswith('InceptionV3/Logits')) 39 | self.assertListEqual(logits.get_shape().as_list(), 40 | [batch_size, num_classes]) 41 | self.assertTrue('Predictions' in end_points) 42 | self.assertListEqual(end_points['Predictions'].get_shape().as_list(), 43 | [batch_size, num_classes]) 44 | 45 | def testBuildBaseNetwork(self): 46 | batch_size = 5 47 | height, width = 299, 299 48 | 49 | inputs = tf.random_uniform((batch_size, height, width, 3)) 50 | final_endpoint, end_points = inception.inception_v3_base(inputs) 51 | self.assertTrue(final_endpoint.op.name.startswith( 52 | 'InceptionV3/Mixed_7c')) 53 | self.assertListEqual(final_endpoint.get_shape().as_list(), 54 | [batch_size, 8, 8, 2048]) 55 | expected_endpoints = ['Conv2d_1a_3x3', 'Conv2d_2a_3x3', 'Conv2d_2b_3x3', 56 | 'MaxPool_3a_3x3', 'Conv2d_3b_1x1', 'Conv2d_4a_3x3', 57 | 'MaxPool_5a_3x3', 'Mixed_5b', 'Mixed_5c', 'Mixed_5d', 58 | 'Mixed_6a', 'Mixed_6b', 'Mixed_6c', 'Mixed_6d', 59 | 'Mixed_6e', 'Mixed_7a', 'Mixed_7b', 'Mixed_7c'] 60 | self.assertItemsEqual(end_points.keys(), expected_endpoints) 61 | 62 | def 
testBuildOnlyUptoFinalEndpoint(self): 63 | batch_size = 5 64 | height, width = 299, 299 65 | endpoints = ['Conv2d_1a_3x3', 'Conv2d_2a_3x3', 'Conv2d_2b_3x3', 66 | 'MaxPool_3a_3x3', 'Conv2d_3b_1x1', 'Conv2d_4a_3x3', 67 | 'MaxPool_5a_3x3', 'Mixed_5b', 'Mixed_5c', 'Mixed_5d', 68 | 'Mixed_6a', 'Mixed_6b', 'Mixed_6c', 'Mixed_6d', 69 | 'Mixed_6e', 'Mixed_7a', 'Mixed_7b', 'Mixed_7c'] 70 | 71 | for index, endpoint in enumerate(endpoints): 72 | with tf.Graph().as_default(): 73 | inputs = tf.random_uniform((batch_size, height, width, 3)) 74 | out_tensor, end_points = inception.inception_v3_base( 75 | inputs, final_endpoint=endpoint) 76 | self.assertTrue(out_tensor.op.name.startswith( 77 | 'InceptionV3/' + endpoint)) 78 | self.assertItemsEqual(endpoints[:index+1], end_points) 79 | 80 | def testBuildAndCheckAllEndPointsUptoMixed7c(self): 81 | batch_size = 5 82 | height, width = 299, 299 83 | 84 | inputs = tf.random_uniform((batch_size, height, width, 3)) 85 | _, end_points = inception.inception_v3_base( 86 | inputs, final_endpoint='Mixed_7c') 87 | endpoints_shapes = {'Conv2d_1a_3x3': [batch_size, 149, 149, 32], 88 | 'Conv2d_2a_3x3': [batch_size, 147, 147, 32], 89 | 'Conv2d_2b_3x3': [batch_size, 147, 147, 64], 90 | 'MaxPool_3a_3x3': [batch_size, 73, 73, 64], 91 | 'Conv2d_3b_1x1': [batch_size, 73, 73, 80], 92 | 'Conv2d_4a_3x3': [batch_size, 71, 71, 192], 93 | 'MaxPool_5a_3x3': [batch_size, 35, 35, 192], 94 | 'Mixed_5b': [batch_size, 35, 35, 256], 95 | 'Mixed_5c': [batch_size, 35, 35, 288], 96 | 'Mixed_5d': [batch_size, 35, 35, 288], 97 | 'Mixed_6a': [batch_size, 17, 17, 768], 98 | 'Mixed_6b': [batch_size, 17, 17, 768], 99 | 'Mixed_6c': [batch_size, 17, 17, 768], 100 | 'Mixed_6d': [batch_size, 17, 17, 768], 101 | 'Mixed_6e': [batch_size, 17, 17, 768], 102 | 'Mixed_7a': [batch_size, 8, 8, 1280], 103 | 'Mixed_7b': [batch_size, 8, 8, 2048], 104 | 'Mixed_7c': [batch_size, 8, 8, 2048]} 105 | self.assertItemsEqual(endpoints_shapes.keys(), end_points.keys()) 106 | for endpoint_name in endpoints_shapes: 107 | expected_shape = endpoints_shapes[endpoint_name] 108 | self.assertTrue(endpoint_name in end_points) 109 | self.assertListEqual(end_points[endpoint_name].get_shape().as_list(), 110 | expected_shape) 111 | 112 | def testModelHasExpectedNumberOfParameters(self): 113 | batch_size = 5 114 | height, width = 299, 299 115 | inputs = tf.random_uniform((batch_size, height, width, 3)) 116 | with slim.arg_scope(inception.inception_v3_arg_scope()): 117 | inception.inception_v3_base(inputs) 118 | total_params, _ = slim.model_analyzer.analyze_vars( 119 | slim.get_model_variables()) 120 | self.assertAlmostEqual(21802784, total_params) 121 | 122 | def testBuildEndPoints(self): 123 | batch_size = 5 124 | height, width = 299, 299 125 | num_classes = 1000 126 | 127 | inputs = tf.random_uniform((batch_size, height, width, 3)) 128 | _, end_points = inception.inception_v3(inputs, num_classes) 129 | self.assertTrue('Logits' in end_points) 130 | logits = end_points['Logits'] 131 | self.assertListEqual(logits.get_shape().as_list(), 132 | [batch_size, num_classes]) 133 | self.assertTrue('AuxLogits' in end_points) 134 | aux_logits = end_points['AuxLogits'] 135 | self.assertListEqual(aux_logits.get_shape().as_list(), 136 | [batch_size, num_classes]) 137 | self.assertTrue('Mixed_7c' in end_points) 138 | pre_pool = end_points['Mixed_7c'] 139 | self.assertListEqual(pre_pool.get_shape().as_list(), 140 | [batch_size, 8, 8, 2048]) 141 | self.assertTrue('PreLogits' in end_points) 142 | pre_logits = end_points['PreLogits'] 143 | 
self.assertListEqual(pre_logits.get_shape().as_list(), 144 | [batch_size, 1, 1, 2048]) 145 | 146 | def testBuildEndPointsWithDepthMultiplierLessThanOne(self): 147 | batch_size = 5 148 | height, width = 299, 299 149 | num_classes = 1000 150 | 151 | inputs = tf.random_uniform((batch_size, height, width, 3)) 152 | _, end_points = inception.inception_v3(inputs, num_classes) 153 | 154 | endpoint_keys = [key for key in end_points.keys() 155 | if key.startswith('Mixed') or key.startswith('Conv')] 156 | 157 | _, end_points_with_multiplier = inception.inception_v3( 158 | inputs, num_classes, scope='depth_multiplied_net', 159 | depth_multiplier=0.5) 160 | 161 | for key in endpoint_keys: 162 | original_depth = end_points[key].get_shape().as_list()[3] 163 | new_depth = end_points_with_multiplier[key].get_shape().as_list()[3] 164 | self.assertEqual(0.5 * original_depth, new_depth) 165 | 166 | def testBuildEndPointsWithDepthMultiplierGreaterThanOne(self): 167 | batch_size = 5 168 | height, width = 299, 299 169 | num_classes = 1000 170 | 171 | inputs = tf.random_uniform((batch_size, height, width, 3)) 172 | _, end_points = inception.inception_v3(inputs, num_classes) 173 | 174 | endpoint_keys = [key for key in end_points.keys() 175 | if key.startswith('Mixed') or key.startswith('Conv')] 176 | 177 | _, end_points_with_multiplier = inception.inception_v3( 178 | inputs, num_classes, scope='depth_multiplied_net', 179 | depth_multiplier=2.0) 180 | 181 | for key in endpoint_keys: 182 | original_depth = end_points[key].get_shape().as_list()[3] 183 | new_depth = end_points_with_multiplier[key].get_shape().as_list()[3] 184 | self.assertEqual(2.0 * original_depth, new_depth) 185 | 186 | def testRaiseValueErrorWithInvalidDepthMultiplier(self): 187 | batch_size = 5 188 | height, width = 299, 299 189 | num_classes = 1000 190 | 191 | inputs = tf.random_uniform((batch_size, height, width, 3)) 192 | with self.assertRaises(ValueError): 193 | _ = inception.inception_v3(inputs, num_classes, depth_multiplier=-0.1) 194 | with self.assertRaises(ValueError): 195 | _ = inception.inception_v3(inputs, num_classes, depth_multiplier=0.0) 196 | 197 | def testHalfSizeImages(self): 198 | batch_size = 5 199 | height, width = 150, 150 200 | num_classes = 1000 201 | 202 | inputs = tf.random_uniform((batch_size, height, width, 3)) 203 | logits, end_points = inception.inception_v3(inputs, num_classes) 204 | self.assertTrue(logits.op.name.startswith('InceptionV3/Logits')) 205 | self.assertListEqual(logits.get_shape().as_list(), 206 | [batch_size, num_classes]) 207 | pre_pool = end_points['Mixed_7c'] 208 | self.assertListEqual(pre_pool.get_shape().as_list(), 209 | [batch_size, 3, 3, 2048]) 210 | 211 | def testUnknownImageShape(self): 212 | tf.reset_default_graph() 213 | batch_size = 2 214 | height, width = 299, 299 215 | num_classes = 1000 216 | input_np = np.random.uniform(0, 1, (batch_size, height, width, 3)) 217 | with self.test_session() as sess: 218 | inputs = tf.placeholder(tf.float32, shape=(batch_size, None, None, 3)) 219 | logits, end_points = inception.inception_v3(inputs, num_classes) 220 | self.assertListEqual(logits.get_shape().as_list(), 221 | [batch_size, num_classes]) 222 | pre_pool = end_points['Mixed_7c'] 223 | feed_dict = {inputs: input_np} 224 | tf.initialize_all_variables().run() 225 | pre_pool_out = sess.run(pre_pool, feed_dict=feed_dict) 226 | self.assertListEqual(list(pre_pool_out.shape), [batch_size, 8, 8, 2048]) 227 | 228 | def testUnknowBatchSize(self): 229 | batch_size = 1 230 | height, width = 299, 299 231 | 
num_classes = 1000 232 | 233 | inputs = tf.placeholder(tf.float32, (None, height, width, 3)) 234 | logits, _ = inception.inception_v3(inputs, num_classes) 235 | self.assertTrue(logits.op.name.startswith('InceptionV3/Logits')) 236 | self.assertListEqual(logits.get_shape().as_list(), 237 | [None, num_classes]) 238 | images = tf.random_uniform((batch_size, height, width, 3)) 239 | 240 | with self.test_session() as sess: 241 | sess.run(tf.initialize_all_variables()) 242 | output = sess.run(logits, {inputs: images.eval()}) 243 | self.assertEquals(output.shape, (batch_size, num_classes)) 244 | 245 | def testEvaluation(self): 246 | batch_size = 2 247 | height, width = 299, 299 248 | num_classes = 1000 249 | 250 | eval_inputs = tf.random_uniform((batch_size, height, width, 3)) 251 | logits, _ = inception.inception_v3(eval_inputs, num_classes, 252 | is_training=False) 253 | predictions = tf.argmax(logits, 1) 254 | 255 | with self.test_session() as sess: 256 | sess.run(tf.initialize_all_variables()) 257 | output = sess.run(predictions) 258 | self.assertEquals(output.shape, (batch_size,)) 259 | 260 | def testTrainEvalWithReuse(self): 261 | train_batch_size = 5 262 | eval_batch_size = 2 263 | height, width = 150, 150 264 | num_classes = 1000 265 | 266 | train_inputs = tf.random_uniform((train_batch_size, height, width, 3)) 267 | inception.inception_v3(train_inputs, num_classes) 268 | eval_inputs = tf.random_uniform((eval_batch_size, height, width, 3)) 269 | logits, _ = inception.inception_v3(eval_inputs, num_classes, 270 | is_training=False, reuse=True) 271 | predictions = tf.argmax(logits, 1) 272 | 273 | with self.test_session() as sess: 274 | sess.run(tf.initialize_all_variables()) 275 | output = sess.run(predictions) 276 | self.assertEquals(output.shape, (eval_batch_size,)) 277 | 278 | def testLogitsNotSqueezed(self): 279 | num_classes = 25 280 | images = tf.random_uniform([1, 299, 299, 3]) 281 | logits, _ = inception.inception_v3(images, 282 | num_classes=num_classes, 283 | spatial_squeeze=False) 284 | 285 | with self.test_session() as sess: 286 | tf.initialize_all_variables().run() 287 | logits_out = sess.run(logits) 288 | self.assertListEqual(list(logits_out.shape), [1, 1, 1, num_classes]) 289 | 290 | 291 | if __name__ == '__main__': 292 | tf.test.main() 293 | -------------------------------------------------------------------------------- /nets/inception_v4_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | """Tests for slim.inception_v4.""" 16 | from __future__ import absolute_import 17 | from __future__ import division 18 | from __future__ import print_function 19 | 20 | import tensorflow as tf 21 | 22 | from nets import inception 23 | 24 | 25 | class InceptionTest(tf.test.TestCase): 26 | 27 | def testBuildLogits(self): 28 | batch_size = 5 29 | height, width = 299, 299 30 | num_classes = 1000 31 | inputs = tf.random_uniform((batch_size, height, width, 3)) 32 | logits, end_points = inception.inception_v4(inputs, num_classes) 33 | auxlogits = end_points['AuxLogits'] 34 | predictions = end_points['Predictions'] 35 | self.assertTrue(auxlogits.op.name.startswith('InceptionV4/AuxLogits')) 36 | self.assertListEqual(auxlogits.get_shape().as_list(), 37 | [batch_size, num_classes]) 38 | self.assertTrue(logits.op.name.startswith('InceptionV4/Logits')) 39 | self.assertListEqual(logits.get_shape().as_list(), 40 | [batch_size, num_classes]) 41 | self.assertTrue(predictions.op.name.startswith( 42 | 'InceptionV4/Logits/Predictions')) 43 | self.assertListEqual(predictions.get_shape().as_list(), 44 | [batch_size, num_classes]) 45 | 46 | def testBuildWithoutAuxLogits(self): 47 | batch_size = 5 48 | height, width = 299, 299 49 | num_classes = 1000 50 | inputs = tf.random_uniform((batch_size, height, width, 3)) 51 | logits, endpoints = inception.inception_v4(inputs, num_classes, 52 | create_aux_logits=False) 53 | self.assertFalse('AuxLogits' in endpoints) 54 | self.assertTrue(logits.op.name.startswith('InceptionV4/Logits')) 55 | self.assertListEqual(logits.get_shape().as_list(), 56 | [batch_size, num_classes]) 57 | 58 | def testAllEndPointsShapes(self): 59 | batch_size = 5 60 | height, width = 299, 299 61 | num_classes = 1000 62 | inputs = tf.random_uniform((batch_size, height, width, 3)) 63 | _, end_points = inception.inception_v4(inputs, num_classes) 64 | endpoints_shapes = {'Conv2d_1a_3x3': [batch_size, 149, 149, 32], 65 | 'Conv2d_2a_3x3': [batch_size, 147, 147, 32], 66 | 'Conv2d_2b_3x3': [batch_size, 147, 147, 64], 67 | 'Mixed_3a': [batch_size, 73, 73, 160], 68 | 'Mixed_4a': [batch_size, 71, 71, 192], 69 | 'Mixed_5a': [batch_size, 35, 35, 384], 70 | # 4 x Inception-A blocks 71 | 'Mixed_5b': [batch_size, 35, 35, 384], 72 | 'Mixed_5c': [batch_size, 35, 35, 384], 73 | 'Mixed_5d': [batch_size, 35, 35, 384], 74 | 'Mixed_5e': [batch_size, 35, 35, 384], 75 | # Reduction-A block 76 | 'Mixed_6a': [batch_size, 17, 17, 1024], 77 | # 7 x Inception-B blocks 78 | 'Mixed_6b': [batch_size, 17, 17, 1024], 79 | 'Mixed_6c': [batch_size, 17, 17, 1024], 80 | 'Mixed_6d': [batch_size, 17, 17, 1024], 81 | 'Mixed_6e': [batch_size, 17, 17, 1024], 82 | 'Mixed_6f': [batch_size, 17, 17, 1024], 83 | 'Mixed_6g': [batch_size, 17, 17, 1024], 84 | 'Mixed_6h': [batch_size, 17, 17, 1024], 85 | # Reduction-A block 86 | 'Mixed_7a': [batch_size, 8, 8, 1536], 87 | # 3 x Inception-C blocks 88 | 'Mixed_7b': [batch_size, 8, 8, 1536], 89 | 'Mixed_7c': [batch_size, 8, 8, 1536], 90 | 'Mixed_7d': [batch_size, 8, 8, 1536], 91 | # Logits and predictions 92 | 'AuxLogits': [batch_size, num_classes], 93 | 'PreLogitsFlatten': [batch_size, 1536], 94 | 'Logits': [batch_size, num_classes], 95 | 'Predictions': [batch_size, num_classes]} 96 | self.assertItemsEqual(endpoints_shapes.keys(), end_points.keys()) 97 | for endpoint_name in endpoints_shapes: 98 | expected_shape = endpoints_shapes[endpoint_name] 99 | self.assertTrue(endpoint_name in end_points) 100 | 
self.assertListEqual(end_points[endpoint_name].get_shape().as_list(), 101 | expected_shape) 102 | 103 | def testBuildBaseNetwork(self): 104 | batch_size = 5 105 | height, width = 299, 299 106 | inputs = tf.random_uniform((batch_size, height, width, 3)) 107 | net, end_points = inception.inception_v4_base(inputs) 108 | self.assertTrue(net.op.name.startswith( 109 | 'InceptionV4/Mixed_7d')) 110 | self.assertListEqual(net.get_shape().as_list(), [batch_size, 8, 8, 1536]) 111 | expected_endpoints = [ 112 | 'Conv2d_1a_3x3', 'Conv2d_2a_3x3', 'Conv2d_2b_3x3', 'Mixed_3a', 113 | 'Mixed_4a', 'Mixed_5a', 'Mixed_5b', 'Mixed_5c', 'Mixed_5d', 114 | 'Mixed_5e', 'Mixed_6a', 'Mixed_6b', 'Mixed_6c', 'Mixed_6d', 115 | 'Mixed_6e', 'Mixed_6f', 'Mixed_6g', 'Mixed_6h', 'Mixed_7a', 116 | 'Mixed_7b', 'Mixed_7c', 'Mixed_7d'] 117 | self.assertItemsEqual(end_points.keys(), expected_endpoints) 118 | for name, op in end_points.iteritems(): 119 | self.assertTrue(op.name.startswith('InceptionV4/' + name)) 120 | 121 | def testBuildOnlyUpToFinalEndpoint(self): 122 | batch_size = 5 123 | height, width = 299, 299 124 | all_endpoints = [ 125 | 'Conv2d_1a_3x3', 'Conv2d_2a_3x3', 'Conv2d_2b_3x3', 'Mixed_3a', 126 | 'Mixed_4a', 'Mixed_5a', 'Mixed_5b', 'Mixed_5c', 'Mixed_5d', 127 | 'Mixed_5e', 'Mixed_6a', 'Mixed_6b', 'Mixed_6c', 'Mixed_6d', 128 | 'Mixed_6e', 'Mixed_6f', 'Mixed_6g', 'Mixed_6h', 'Mixed_7a', 129 | 'Mixed_7b', 'Mixed_7c', 'Mixed_7d'] 130 | for index, endpoint in enumerate(all_endpoints): 131 | with tf.Graph().as_default(): 132 | inputs = tf.random_uniform((batch_size, height, width, 3)) 133 | out_tensor, end_points = inception.inception_v4_base( 134 | inputs, final_endpoint=endpoint) 135 | self.assertTrue(out_tensor.op.name.startswith( 136 | 'InceptionV4/' + endpoint)) 137 | self.assertItemsEqual(all_endpoints[:index+1], end_points) 138 | 139 | def testVariablesSetDevice(self): 140 | batch_size = 5 141 | height, width = 299, 299 142 | num_classes = 1000 143 | inputs = tf.random_uniform((batch_size, height, width, 3)) 144 | # Force all Variables to reside on the device. 
145 | with tf.variable_scope('on_cpu'), tf.device('/cpu:0'): 146 | inception.inception_v4(inputs, num_classes) 147 | with tf.variable_scope('on_gpu'), tf.device('/gpu:0'): 148 | inception.inception_v4(inputs, num_classes) 149 | for v in tf.get_collection(tf.GraphKeys.VARIABLES, scope='on_cpu'): 150 | self.assertDeviceEqual(v.device, '/cpu:0') 151 | for v in tf.get_collection(tf.GraphKeys.VARIABLES, scope='on_gpu'): 152 | self.assertDeviceEqual(v.device, '/gpu:0') 153 | 154 | def testHalfSizeImages(self): 155 | batch_size = 5 156 | height, width = 150, 150 157 | num_classes = 1000 158 | inputs = tf.random_uniform((batch_size, height, width, 3)) 159 | logits, end_points = inception.inception_v4(inputs, num_classes) 160 | self.assertTrue(logits.op.name.startswith('InceptionV4/Logits')) 161 | self.assertListEqual(logits.get_shape().as_list(), 162 | [batch_size, num_classes]) 163 | pre_pool = end_points['Mixed_7d'] 164 | self.assertListEqual(pre_pool.get_shape().as_list(), 165 | [batch_size, 3, 3, 1536]) 166 | 167 | def testUnknownBatchSize(self): 168 | batch_size = 1 169 | height, width = 299, 299 170 | num_classes = 1000 171 | with self.test_session() as sess: 172 | inputs = tf.placeholder(tf.float32, (None, height, width, 3)) 173 | logits, _ = inception.inception_v4(inputs, num_classes) 174 | self.assertTrue(logits.op.name.startswith('InceptionV4/Logits')) 175 | self.assertListEqual(logits.get_shape().as_list(), 176 | [None, num_classes]) 177 | images = tf.random_uniform((batch_size, height, width, 3)) 178 | sess.run(tf.initialize_all_variables()) 179 | output = sess.run(logits, {inputs: images.eval()}) 180 | self.assertEquals(output.shape, (batch_size, num_classes)) 181 | 182 | def testEvaluation(self): 183 | batch_size = 2 184 | height, width = 299, 299 185 | num_classes = 1000 186 | with self.test_session() as sess: 187 | eval_inputs = tf.random_uniform((batch_size, height, width, 3)) 188 | logits, _ = inception.inception_v4(eval_inputs, 189 | num_classes, 190 | is_training=False) 191 | predictions = tf.argmax(logits, 1) 192 | sess.run(tf.initialize_all_variables()) 193 | output = sess.run(predictions) 194 | self.assertEquals(output.shape, (batch_size,)) 195 | 196 | def testTrainEvalWithReuse(self): 197 | train_batch_size = 5 198 | eval_batch_size = 2 199 | height, width = 150, 150 200 | num_classes = 1000 201 | with self.test_session() as sess: 202 | train_inputs = tf.random_uniform((train_batch_size, height, width, 3)) 203 | inception.inception_v4(train_inputs, num_classes) 204 | eval_inputs = tf.random_uniform((eval_batch_size, height, width, 3)) 205 | logits, _ = inception.inception_v4(eval_inputs, 206 | num_classes, 207 | is_training=False, 208 | reuse=True) 209 | predictions = tf.argmax(logits, 1) 210 | sess.run(tf.initialize_all_variables()) 211 | output = sess.run(predictions) 212 | self.assertEquals(output.shape, (eval_batch_size,)) 213 | 214 | 215 | if __name__ == '__main__': 216 | tf.test.main() 217 | -------------------------------------------------------------------------------- /nets/lenet.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Contains a variant of the LeNet model definition.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import tensorflow as tf 22 | 23 | slim = tf.contrib.slim 24 | 25 | 26 | def lenet(images, num_classes=10, is_training=False, 27 | dropout_keep_prob=0.5, 28 | prediction_fn=slim.softmax, 29 | scope='LeNet'): 30 | """Creates a variant of the LeNet model. 31 | 32 | Note that since the output is a set of 'logits', the values fall in the 33 | interval of (-infinity, infinity). Consequently, to convert the outputs to a 34 | probability distribution over the characters, one will need to convert them 35 | using the softmax function: 36 | 37 | logits, end_points = lenet.lenet(images, is_training=False) 38 | probabilities = tf.nn.softmax(logits) 39 | predictions = tf.argmax(logits, 1) 40 | 41 | Args: 42 | images: A batch of `Tensors` of size [batch_size, height, width, channels]. 43 | num_classes: the number of classes in the dataset. 44 | is_training: specifies whether or not we're currently training the model. 45 | This variable will determine the behaviour of the dropout layer. 46 | dropout_keep_prob: the percentage of activation values that are retained. 47 | prediction_fn: a function to get predictions out of logits. 48 | scope: Optional variable_scope. 49 | 50 | Returns: 51 | logits: the pre-softmax activations, a tensor of size 52 | [batch_size, `num_classes`] 53 | end_points: a dictionary from components of the network to the corresponding 54 | activation. 55 | """ 56 | end_points = {} 57 | 58 | with tf.variable_scope(scope, 'LeNet', [images, num_classes]): 59 | net = slim.conv2d(images, 32, [5, 5], scope='conv1') 60 | net = slim.max_pool2d(net, [2, 2], 2, scope='pool1') 61 | net = slim.conv2d(net, 64, [5, 5], scope='conv2') 62 | net = slim.max_pool2d(net, [2, 2], 2, scope='pool2') 63 | net = slim.flatten(net) 64 | end_points['Flatten'] = net 65 | 66 | net = slim.fully_connected(net, 1024, scope='fc3') 67 | net = slim.dropout(net, dropout_keep_prob, is_training=is_training, 68 | scope='dropout3') 69 | logits = slim.fully_connected(net, num_classes, activation_fn=None, 70 | scope='fc4') 71 | 72 | end_points['Logits'] = logits 73 | end_points['Predictions'] = prediction_fn(logits, scope='Predictions') 74 | 75 | return logits, end_points 76 | lenet.default_image_size = 28 77 | 78 | 79 | def lenet_arg_scope(weight_decay=0.0): 80 | """Defines the default lenet argument scope. 81 | 82 | Args: 83 | weight_decay: The weight decay to use for regularizing the model. 84 | 85 | Returns: 86 | An `arg_scope` to use for the lenet model.
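    Example usage (a minimal sketch; assumes an MNIST-style batch of 28x28
    single-channel images, matching lenet.default_image_size above):

      images = tf.placeholder(tf.float32, [None, 28, 28, 1])
      with slim.arg_scope(lenet.lenet_arg_scope()):
        logits, end_points = lenet.lenet(images, num_classes=10,
                                         is_training=True)
      probabilities = end_points['Predictions']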
87 | """ 88 | with slim.arg_scope( 89 | [slim.conv2d, slim.fully_connected], 90 | weights_regularizer=slim.l2_regularizer(weight_decay), 91 | weights_initializer=tf.truncated_normal_initializer(stddev=0.1), 92 | activation_fn=tf.nn.relu) as sc: 93 | return sc 94 | -------------------------------------------------------------------------------- /nets/nets_factory.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Contains a factory for building various models.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | import functools 21 | 22 | import tensorflow as tf 23 | 24 | from nets import alexnet 25 | from nets import cifarnet 26 | from nets import inception 27 | from nets import lenet 28 | from nets import overfeat 29 | from nets import resnet_v1 30 | from nets import resnet_v2 31 | from nets import vgg 32 | 33 | slim = tf.contrib.slim 34 | 35 | networks_map = {'alexnet_v2': alexnet.alexnet_v2, 36 | 'cifarnet': cifarnet.cifarnet, 37 | 'overfeat': overfeat.overfeat, 38 | 'vgg_a': vgg.vgg_a, 39 | 'vgg_16': vgg.vgg_16, 40 | 'vgg_19': vgg.vgg_19, 41 | 'inception_v1': inception.inception_v1, 42 | 'inception_v2': inception.inception_v2, 43 | 'inception_v3': inception.inception_v3, 44 | 'inception_v4': inception.inception_v4, 45 | 'inception_resnet_v2': inception.inception_resnet_v2, 46 | 'lenet': lenet.lenet, 47 | 'resnet_v1_50': resnet_v1.resnet_v1_50, 48 | 'resnet_v1_101': resnet_v1.resnet_v1_101, 49 | 'resnet_v1_152': resnet_v1.resnet_v1_152, 50 | 'resnet_v1_200': resnet_v1.resnet_v1_200, 51 | 'resnet_v2_50': resnet_v2.resnet_v2_50, 52 | 'resnet_v2_101': resnet_v2.resnet_v2_101, 53 | 'resnet_v2_152': resnet_v2.resnet_v2_152, 54 | 'resnet_v2_200': resnet_v2.resnet_v2_200, 55 | } 56 | 57 | arg_scopes_map = {'alexnet_v2': alexnet.alexnet_v2_arg_scope, 58 | 'cifarnet': cifarnet.cifarnet_arg_scope, 59 | 'overfeat': overfeat.overfeat_arg_scope, 60 | 'vgg_a': vgg.vgg_arg_scope, 61 | 'vgg_16': vgg.vgg_arg_scope, 62 | 'vgg_19': vgg.vgg_arg_scope, 63 | 'inception_v1': inception.inception_v3_arg_scope, 64 | 'inception_v2': inception.inception_v3_arg_scope, 65 | 'inception_v3': inception.inception_v3_arg_scope, 66 | 'inception_v4': inception.inception_v4_arg_scope, 67 | 'inception_resnet_v2': 68 | inception.inception_resnet_v2_arg_scope, 69 | 'lenet': lenet.lenet_arg_scope, 70 | 'resnet_v1_50': resnet_v1.resnet_arg_scope, 71 | 'resnet_v1_101': resnet_v1.resnet_arg_scope, 72 | 'resnet_v1_152': resnet_v1.resnet_arg_scope, 73 | 'resnet_v1_200': resnet_v1.resnet_arg_scope, 74 | 'resnet_v2_50': resnet_v2.resnet_arg_scope, 75 | 'resnet_v2_101': resnet_v2.resnet_arg_scope, 76 | 'resnet_v2_152': resnet_v2.resnet_arg_scope, 77 | 'resnet_v2_200': 
resnet_v2.resnet_arg_scope, 78 | } 79 | 80 | 81 | def get_network_fn(name, num_classes, weight_decay=0.0, is_training=False): 82 | """Returns a network_fn such as `logits, end_points = network_fn(images)`. 83 | 84 | Args: 85 | name: The name of the network. 86 | num_classes: The number of classes to use for classification. 87 | weight_decay: The l2 coefficient for the model weights. 88 | is_training: `True` if the model is being used for training and `False` 89 | otherwise. 90 | 91 | Returns: 92 | network_fn: A function that applies the model to a batch of images. It has 93 | the following signature: 94 | logits, end_points = network_fn(images) 95 | Raises: 96 | ValueError: If network `name` is not recognized. 97 | """ 98 | if name not in networks_map: 99 | raise ValueError('Name of network unknown %s' % name) 100 | arg_scope = arg_scopes_map[name](weight_decay=weight_decay) 101 | func = networks_map[name] 102 | @functools.wraps(func) 103 | def network_fn(images, **kwargs): 104 | with slim.arg_scope(arg_scope): 105 | return func(images, num_classes, is_training=is_training, **kwargs) 106 | if hasattr(func, 'default_image_size'): 107 | network_fn.default_image_size = func.default_image_size 108 | 109 | return network_fn 110 | -------------------------------------------------------------------------------- /nets/nets_factory_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Tests for slim.inception.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | 23 | import tensorflow as tf 24 | 25 | from nets import nets_factory 26 | 27 | 28 | class NetworksTest(tf.test.TestCase): 29 | 30 | def testGetNetworkFn(self): 31 | batch_size = 5 32 | num_classes = 1000 33 | for net in nets_factory.networks_map: 34 | with self.test_session(): 35 | net_fn = nets_factory.get_network_fn(net, num_classes) 36 | # Most networks use 224 as their default_image_size 37 | image_size = getattr(net_fn, 'default_image_size', 224) 38 | inputs = tf.random_uniform((batch_size, image_size, image_size, 3)) 39 | logits, end_points = net_fn(inputs) 40 | self.assertTrue(isinstance(logits, tf.Tensor)) 41 | self.assertTrue(isinstance(end_points, dict)) 42 | self.assertEqual(logits.get_shape().as_list()[0], batch_size) 43 | self.assertEqual(logits.get_shape().as_list()[-1], num_classes) 44 | 45 | if __name__ == '__main__': 46 | tf.test.main() 47 | -------------------------------------------------------------------------------- /nets/overfeat.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 
2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Contains the model definition for the OverFeat network. 16 | 17 | The definition for the network was obtained from: 18 | OverFeat: Integrated Recognition, Localization and Detection using 19 | Convolutional Networks 20 | Pierre Sermanet, David Eigen, Xiang Zhang, Michael Mathieu, Rob Fergus and 21 | Yann LeCun, 2014 22 | http://arxiv.org/abs/1312.6229 23 | 24 | Usage: 25 | with slim.arg_scope(overfeat.overfeat_arg_scope()): 26 | outputs, end_points = overfeat.overfeat(inputs) 27 | 28 | @@overfeat 29 | """ 30 | from __future__ import absolute_import 31 | from __future__ import division 32 | from __future__ import print_function 33 | 34 | import tensorflow as tf 35 | 36 | slim = tf.contrib.slim 37 | trunc_normal = lambda stddev: tf.truncated_normal_initializer(0.0, stddev) 38 | 39 | 40 | def overfeat_arg_scope(weight_decay=0.0005): 41 | with slim.arg_scope([slim.conv2d, slim.fully_connected], 42 | activation_fn=tf.nn.relu, 43 | weights_regularizer=slim.l2_regularizer(weight_decay), 44 | biases_initializer=tf.zeros_initializer): 45 | with slim.arg_scope([slim.conv2d], padding='SAME'): 46 | with slim.arg_scope([slim.max_pool2d], padding='VALID') as arg_sc: 47 | return arg_sc 48 | 49 | 50 | def overfeat(inputs, 51 | num_classes=1000, 52 | is_training=True, 53 | dropout_keep_prob=0.5, 54 | spatial_squeeze=True, 55 | scope='overfeat'): 56 | """Contains the model definition for the OverFeat network. 57 | 58 | The definition for the network was obtained from: 59 | OverFeat: Integrated Recognition, Localization and Detection using 60 | Convolutional Networks 61 | Pierre Sermanet, David Eigen, Xiang Zhang, Michael Mathieu, Rob Fergus and 62 | Yann LeCun, 2014 63 | http://arxiv.org/abs/1312.6229 64 | 65 | Note: All the fully_connected layers have been transformed to conv2d layers. 66 | To use in classification mode, resize input to 231x231. To use in fully 67 | convolutional mode, set spatial_squeeze to false. 68 | 69 | Args: 70 | inputs: a tensor of size [batch_size, height, width, channels]. 71 | num_classes: number of predicted classes. 72 | is_training: whether or not the model is being trained. 73 | dropout_keep_prob: the probability that activations are kept in the dropout 74 | layers during training. 75 | spatial_squeeze: whether or not should squeeze the spatial dimensions of the 76 | outputs. Useful to remove unnecessary dimensions for classification. 77 | scope: Optional scope for the variables. 78 | 79 | Returns: 80 | the last op containing the log predictions and end_points dict. 
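    Example usage (a minimal sketch; assumes a batch of 231x231 RGB images,
    matching overfeat.default_image_size below):

      inputs = tf.random_uniform((4, 231, 231, 3))
      with slim.arg_scope(overfeat.overfeat_arg_scope()):
        logits, end_points = overfeat.overfeat(inputs, num_classes=1000,
                                                is_training=False)
      predictions = tf.argmax(logits, 1)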
81 | 82 | """ 83 | with tf.variable_scope(scope, 'overfeat', [inputs]) as sc: 84 | end_points_collection = sc.name + '_end_points' 85 | # Collect outputs for conv2d, fully_connected and max_pool2d 86 | with slim.arg_scope([slim.conv2d, slim.fully_connected, slim.max_pool2d], 87 | outputs_collections=end_points_collection): 88 | net = slim.conv2d(inputs, 64, [11, 11], 4, padding='VALID', 89 | scope='conv1') 90 | net = slim.max_pool2d(net, [2, 2], scope='pool1') 91 | net = slim.conv2d(net, 256, [5, 5], padding='VALID', scope='conv2') 92 | net = slim.max_pool2d(net, [2, 2], scope='pool2') 93 | net = slim.conv2d(net, 512, [3, 3], scope='conv3') 94 | net = slim.conv2d(net, 1024, [3, 3], scope='conv4') 95 | net = slim.conv2d(net, 1024, [3, 3], scope='conv5') 96 | net = slim.max_pool2d(net, [2, 2], scope='pool5') 97 | with slim.arg_scope([slim.conv2d], 98 | weights_initializer=trunc_normal(0.005), 99 | biases_initializer=tf.constant_initializer(0.1)): 100 | # Use conv2d instead of fully_connected layers. 101 | net = slim.conv2d(net, 3072, [6, 6], padding='VALID', scope='fc6') 102 | net = slim.dropout(net, dropout_keep_prob, is_training=is_training, 103 | scope='dropout6') 104 | net = slim.conv2d(net, 4096, [1, 1], scope='fc7') 105 | net = slim.dropout(net, dropout_keep_prob, is_training=is_training, 106 | scope='dropout7') 107 | net = slim.conv2d(net, num_classes, [1, 1], 108 | activation_fn=None, 109 | normalizer_fn=None, 110 | biases_initializer=tf.zeros_initializer, 111 | scope='fc8') 112 | # Convert end_points_collection into a end_point dict. 113 | end_points = slim.utils.convert_collection_to_dict(end_points_collection) 114 | if spatial_squeeze: 115 | net = tf.squeeze(net, [1, 2], name='fc8/squeezed') 116 | end_points[sc.name + '/fc8'] = net 117 | return net, end_points 118 | overfeat.default_image_size = 231 119 | -------------------------------------------------------------------------------- /nets/overfeat_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | """Tests for slim.nets.overfeat.""" 16 | from __future__ import absolute_import 17 | from __future__ import division 18 | from __future__ import print_function 19 | 20 | import tensorflow as tf 21 | 22 | from nets import overfeat 23 | 24 | slim = tf.contrib.slim 25 | 26 | 27 | class OverFeatTest(tf.test.TestCase): 28 | 29 | def testBuild(self): 30 | batch_size = 5 31 | height, width = 231, 231 32 | num_classes = 1000 33 | with self.test_session(): 34 | inputs = tf.random_uniform((batch_size, height, width, 3)) 35 | logits, _ = overfeat.overfeat(inputs, num_classes) 36 | self.assertEquals(logits.op.name, 'overfeat/fc8/squeezed') 37 | self.assertListEqual(logits.get_shape().as_list(), 38 | [batch_size, num_classes]) 39 | 40 | def testFullyConvolutional(self): 41 | batch_size = 1 42 | height, width = 281, 281 43 | num_classes = 1000 44 | with self.test_session(): 45 | inputs = tf.random_uniform((batch_size, height, width, 3)) 46 | logits, _ = overfeat.overfeat(inputs, num_classes, spatial_squeeze=False) 47 | self.assertEquals(logits.op.name, 'overfeat/fc8/BiasAdd') 48 | self.assertListEqual(logits.get_shape().as_list(), 49 | [batch_size, 2, 2, num_classes]) 50 | 51 | def testEndPoints(self): 52 | batch_size = 5 53 | height, width = 231, 231 54 | num_classes = 1000 55 | with self.test_session(): 56 | inputs = tf.random_uniform((batch_size, height, width, 3)) 57 | _, end_points = overfeat.overfeat(inputs, num_classes) 58 | expected_names = ['overfeat/conv1', 59 | 'overfeat/pool1', 60 | 'overfeat/conv2', 61 | 'overfeat/pool2', 62 | 'overfeat/conv3', 63 | 'overfeat/conv4', 64 | 'overfeat/conv5', 65 | 'overfeat/pool5', 66 | 'overfeat/fc6', 67 | 'overfeat/fc7', 68 | 'overfeat/fc8' 69 | ] 70 | self.assertSetEqual(set(end_points.keys()), set(expected_names)) 71 | 72 | def testModelVariables(self): 73 | batch_size = 5 74 | height, width = 231, 231 75 | num_classes = 1000 76 | with self.test_session(): 77 | inputs = tf.random_uniform((batch_size, height, width, 3)) 78 | overfeat.overfeat(inputs, num_classes) 79 | expected_names = ['overfeat/conv1/weights', 80 | 'overfeat/conv1/biases', 81 | 'overfeat/conv2/weights', 82 | 'overfeat/conv2/biases', 83 | 'overfeat/conv3/weights', 84 | 'overfeat/conv3/biases', 85 | 'overfeat/conv4/weights', 86 | 'overfeat/conv4/biases', 87 | 'overfeat/conv5/weights', 88 | 'overfeat/conv5/biases', 89 | 'overfeat/fc6/weights', 90 | 'overfeat/fc6/biases', 91 | 'overfeat/fc7/weights', 92 | 'overfeat/fc7/biases', 93 | 'overfeat/fc8/weights', 94 | 'overfeat/fc8/biases', 95 | ] 96 | model_variables = [v.op.name for v in slim.get_model_variables()] 97 | self.assertSetEqual(set(model_variables), set(expected_names)) 98 | 99 | def testEvaluation(self): 100 | batch_size = 2 101 | height, width = 231, 231 102 | num_classes = 1000 103 | with self.test_session(): 104 | eval_inputs = tf.random_uniform((batch_size, height, width, 3)) 105 | logits, _ = overfeat.overfeat(eval_inputs, is_training=False) 106 | self.assertListEqual(logits.get_shape().as_list(), 107 | [batch_size, num_classes]) 108 | predictions = tf.argmax(logits, 1) 109 | self.assertListEqual(predictions.get_shape().as_list(), [batch_size]) 110 | 111 | def testTrainEvalWithReuse(self): 112 | train_batch_size = 2 113 | eval_batch_size = 1 114 | train_height, train_width = 231, 231 115 | eval_height, eval_width = 281, 281 116 | num_classes = 1000 117 | with self.test_session(): 118 | train_inputs = tf.random_uniform( 119 | 
(train_batch_size, train_height, train_width, 3)) 120 | logits, _ = overfeat.overfeat(train_inputs) 121 | self.assertListEqual(logits.get_shape().as_list(), 122 | [train_batch_size, num_classes]) 123 | tf.get_variable_scope().reuse_variables() 124 | eval_inputs = tf.random_uniform( 125 | (eval_batch_size, eval_height, eval_width, 3)) 126 | logits, _ = overfeat.overfeat(eval_inputs, is_training=False, 127 | spatial_squeeze=False) 128 | self.assertListEqual(logits.get_shape().as_list(), 129 | [eval_batch_size, 2, 2, num_classes]) 130 | logits = tf.reduce_mean(logits, [1, 2]) 131 | predictions = tf.argmax(logits, 1) 132 | self.assertEquals(predictions.get_shape().as_list(), [eval_batch_size]) 133 | 134 | def testForward(self): 135 | batch_size = 1 136 | height, width = 231, 231 137 | with self.test_session() as sess: 138 | inputs = tf.random_uniform((batch_size, height, width, 3)) 139 | logits, _ = overfeat.overfeat(inputs) 140 | sess.run(tf.initialize_all_variables()) 141 | output = sess.run(logits) 142 | self.assertTrue(output.any()) 143 | 144 | if __name__ == '__main__': 145 | tf.test.main() 146 | -------------------------------------------------------------------------------- /nets/resnet_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Contains building blocks for various versions of Residual Networks. 16 | 17 | Residual networks (ResNets) were proposed in: 18 | Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun 19 | Deep Residual Learning for Image Recognition. arXiv:1512.03385, 2015 20 | 21 | More variants were introduced in: 22 | Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun 23 | Identity Mappings in Deep Residual Networks. arXiv: 1603.05027, 2016 24 | 25 | We can obtain different ResNet variants by changing the network depth, width, 26 | and form of residual unit. This module implements the infrastructure for 27 | building them. Concrete ResNet units and full ResNet networks are implemented in 28 | the accompanying resnet_v1.py and resnet_v2.py modules. 29 | 30 | Compared to https://github.com/KaimingHe/deep-residual-networks, in the current 31 | implementation we subsample the output activations in the last residual unit of 32 | each block, instead of subsampling the input activations in the first residual 33 | unit of each block. The two implementations give identical results but our 34 | implementation is more memory efficient. 35 | """ 36 | from __future__ import absolute_import 37 | from __future__ import division 38 | from __future__ import print_function 39 | 40 | import collections 41 | import tensorflow as tf 42 | 43 | slim = tf.contrib.slim 44 | 45 | 46 | class Block(collections.namedtuple('Block', ['scope', 'unit_fn', 'args'])): 47 | """A named tuple describing a ResNet block. 
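  For example (a sketch; the `bottleneck` unit function lives in the
  accompanying resnet_v1.py / resnet_v2.py modules, not in this file):

    Block('block1', bottleneck, [(256, 64, 1)] * 2 + [(256, 64, 2)])

  describes a block of three bottleneck units whose last unit has stride 2.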
48 | 49 | Its parts are: 50 | scope: The scope of the `Block`. 51 | unit_fn: The ResNet unit function which takes as input a `Tensor` and 52 | returns another `Tensor` with the output of the ResNet unit. 53 | args: A list of length equal to the number of units in the `Block`. The list 54 | contains one (depth, depth_bottleneck, stride) tuple for each unit in the 55 | block to serve as argument to unit_fn. 56 | """ 57 | 58 | 59 | def subsample(inputs, factor, scope=None): 60 | """Subsamples the input along the spatial dimensions. 61 | 62 | Args: 63 | inputs: A `Tensor` of size [batch, height_in, width_in, channels]. 64 | factor: The subsampling factor. 65 | scope: Optional variable_scope. 66 | 67 | Returns: 68 | output: A `Tensor` of size [batch, height_out, width_out, channels] with the 69 | input, either intact (if factor == 1) or subsampled (if factor > 1). 70 | """ 71 | if factor == 1: 72 | return inputs 73 | else: 74 | return slim.max_pool2d(inputs, [1, 1], stride=factor, scope=scope) 75 | 76 | 77 | def conv2d_same(inputs, num_outputs, kernel_size, stride, rate=1, scope=None): 78 | """Strided 2-D convolution with 'SAME' padding. 79 | 80 | When stride > 1, then we do explicit zero-padding, followed by conv2d with 81 | 'VALID' padding. 82 | 83 | Note that 84 | 85 | net = conv2d_same(inputs, num_outputs, 3, stride=stride) 86 | 87 | is equivalent to 88 | 89 | net = slim.conv2d(inputs, num_outputs, 3, stride=1, padding='SAME') 90 | net = subsample(net, factor=stride) 91 | 92 | whereas 93 | 94 | net = slim.conv2d(inputs, num_outputs, 3, stride=stride, padding='SAME') 95 | 96 | is different when the input's height or width is even, which is why we add the 97 | current function. For more details, see ResnetUtilsTest.testConv2DSameEven(). 98 | 99 | Args: 100 | inputs: A 4-D tensor of size [batch, height_in, width_in, channels]. 101 | num_outputs: An integer, the number of output filters. 102 | kernel_size: An int with the kernel_size of the filters. 103 | stride: An integer, the output stride. 104 | rate: An integer, rate for atrous convolution. 105 | scope: Scope. 106 | 107 | Returns: 108 | output: A 4-D tensor of size [batch, height_out, width_out, channels] with 109 | the convolution output. 110 | """ 111 | if stride == 1: 112 | return slim.conv2d(inputs, num_outputs, kernel_size, stride=1, rate=rate, 113 | padding='SAME', scope=scope) 114 | else: 115 | kernel_size_effective = kernel_size + (kernel_size - 1) * (rate - 1) 116 | pad_total = kernel_size_effective - 1 117 | pad_beg = pad_total // 2 118 | pad_end = pad_total - pad_beg 119 | inputs = tf.pad(inputs, 120 | [[0, 0], [pad_beg, pad_end], [pad_beg, pad_end], [0, 0]]) 121 | return slim.conv2d(inputs, num_outputs, kernel_size, stride=stride, 122 | rate=rate, padding='VALID', scope=scope) 123 | 124 | 125 | @slim.add_arg_scope 126 | def stack_blocks_dense(net, blocks, output_stride=None, 127 | outputs_collections=None): 128 | """Stacks ResNet `Blocks` and controls output feature density. 129 | 130 | First, this function creates scopes for the ResNet in the form of 131 | 'block_name/unit_1', 'block_name/unit_2', etc. 132 | 133 | Second, this function allows the user to explicitly control the ResNet 134 | output_stride, which is the ratio of the input to output spatial resolution. 135 | This is useful for dense prediction tasks such as semantic segmentation or 136 | object detection. 137 | 138 | Most ResNets consist of 4 ResNet blocks and subsample the activations by a 139 | factor of 2 when transitioning between consecutive ResNet blocks. 
This results 140 | to a nominal ResNet output_stride equal to 8. If we set the output_stride to 141 | half the nominal network stride (e.g., output_stride=4), then we compute 142 | responses twice. 143 | 144 | Control of the output feature density is implemented by atrous convolution. 145 | 146 | Args: 147 | net: A `Tensor` of size [batch, height, width, channels]. 148 | blocks: A list of length equal to the number of ResNet `Blocks`. Each 149 | element is a ResNet `Block` object describing the units in the `Block`. 150 | output_stride: If `None`, then the output will be computed at the nominal 151 | network stride. If output_stride is not `None`, it specifies the requested 152 | ratio of input to output spatial resolution, which needs to be equal to 153 | the product of unit strides from the start up to some level of the ResNet. 154 | For example, if the ResNet employs units with strides 1, 2, 1, 3, 4, 1, 155 | then valid values for the output_stride are 1, 2, 6, 24 or None (which 156 | is equivalent to output_stride=24). 157 | outputs_collections: Collection to add the ResNet block outputs. 158 | 159 | Returns: 160 | net: Output tensor with stride equal to the specified output_stride. 161 | 162 | Raises: 163 | ValueError: If the target output_stride is not valid. 164 | """ 165 | # The current_stride variable keeps track of the effective stride of the 166 | # activations. This allows us to invoke atrous convolution whenever applying 167 | # the next residual unit would result in the activations having stride larger 168 | # than the target output_stride. 169 | current_stride = 1 170 | 171 | # The atrous convolution rate parameter. 172 | rate = 1 173 | 174 | for block in blocks: 175 | with tf.variable_scope(block.scope, 'block', [net]) as sc: 176 | for i, unit in enumerate(block.args): 177 | if output_stride is not None and current_stride > output_stride: 178 | raise ValueError('The target output_stride cannot be reached.') 179 | 180 | with tf.variable_scope('unit_%d' % (i + 1), values=[net]): 181 | unit_depth, unit_depth_bottleneck, unit_stride = unit 182 | 183 | # If we have reached the target output_stride, then we need to employ 184 | # atrous convolution with stride=1 and multiply the atrous rate by the 185 | # current unit's stride for use in subsequent layers. 186 | if output_stride is not None and current_stride == output_stride: 187 | net = block.unit_fn(net, 188 | depth=unit_depth, 189 | depth_bottleneck=unit_depth_bottleneck, 190 | stride=1, 191 | rate=rate) 192 | rate *= unit_stride 193 | 194 | else: 195 | net = block.unit_fn(net, 196 | depth=unit_depth, 197 | depth_bottleneck=unit_depth_bottleneck, 198 | stride=unit_stride, 199 | rate=1) 200 | current_stride *= unit_stride 201 | net = slim.utils.collect_named_outputs(outputs_collections, sc.name, net) 202 | 203 | if output_stride is not None and current_stride != output_stride: 204 | raise ValueError('The target output_stride cannot be reached.') 205 | 206 | return net 207 | 208 | 209 | def resnet_arg_scope(weight_decay=0.0001, 210 | batch_norm_decay=0.997, 211 | batch_norm_epsilon=1e-5, 212 | batch_norm_scale=True): 213 | """Defines the default ResNet arg scope. 214 | 215 | TODO(gpapan): The batch-normalization related default values above are 216 | appropriate for use in conjunction with the reference ResNet models 217 | released at https://github.com/KaimingHe/deep-residual-networks. When 218 | training ResNets from scratch, they might need to be tuned. 
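  Example usage (a sketch; assumes `inputs` is a 4-D image batch and that the
  resnet_v1_50 builder from the accompanying resnet_v1.py module is imported):

    with slim.arg_scope(resnet_arg_scope()):
      net, end_points = resnet_v1.resnet_v1_50(inputs, num_classes=1000)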
219 | 220 | Args: 221 | weight_decay: The weight decay to use for regularizing the model. 222 | batch_norm_decay: The moving average decay when estimating layer activation 223 | statistics in batch normalization. 224 | batch_norm_epsilon: Small constant to prevent division by zero when 225 | normalizing activations by their variance in batch normalization. 226 | batch_norm_scale: If True, uses an explicit `gamma` multiplier to scale the 227 | activations in the batch normalization layer. 228 | 229 | Returns: 230 | An `arg_scope` to use for the resnet models. 231 | """ 232 | batch_norm_params = { 233 | 'decay': batch_norm_decay, 234 | 'epsilon': batch_norm_epsilon, 235 | 'scale': batch_norm_scale, 236 | 'updates_collections': tf.GraphKeys.UPDATE_OPS, 237 | } 238 | 239 | with slim.arg_scope( 240 | [slim.conv2d], 241 | weights_regularizer=slim.l2_regularizer(weight_decay), 242 | weights_initializer=slim.variance_scaling_initializer(), 243 | activation_fn=tf.nn.relu, 244 | normalizer_fn=slim.batch_norm, 245 | normalizer_params=batch_norm_params): 246 | with slim.arg_scope([slim.batch_norm], **batch_norm_params): 247 | # The following implies padding='SAME' for pool1, which makes feature 248 | # alignment easier for dense prediction tasks. This is also used in 249 | # https://github.com/facebook/fb.resnet.torch. However the accompanying 250 | # code of 'Deep Residual Learning for Image Recognition' uses 251 | # padding='VALID' for pool1. You can switch to that choice by setting 252 | # slim.arg_scope([slim.max_pool2d], padding='VALID'). 253 | with slim.arg_scope([slim.max_pool2d], padding='SAME') as arg_sc: 254 | return arg_sc 255 | -------------------------------------------------------------------------------- /nets/vgg.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Contains model definitions for versions of the Oxford VGG network. 
16 | These model definitions were introduced in the following technical report: 17 | Very Deep Convolutional Networks For Large-Scale Image Recognition 18 | Karen Simonyan and Andrew Zisserman 19 | arXiv technical report, 2015 20 | PDF: http://arxiv.org/pdf/1409.1556.pdf 21 | ILSVRC 2014 Slides: http://www.robots.ox.ac.uk/~karen/pdf/ILSVRC_2014.pdf 22 | CC-BY-4.0 23 | More information can be obtained from the VGG website: 24 | www.robots.ox.ac.uk/~vgg/research/very_deep/ 25 | Usage: 26 | with slim.arg_scope(vgg.vgg_arg_scope()): 27 | outputs, end_points = vgg.vgg_a(inputs) 28 | with slim.arg_scope(vgg.vgg_arg_scope()): 29 | outputs, end_points = vgg.vgg_16(inputs) 30 | @@vgg_a 31 | @@vgg_16 32 | @@vgg_19 33 | """ 34 | from __future__ import absolute_import 35 | from __future__ import division 36 | from __future__ import print_function 37 | 38 | import tensorflow as tf 39 | 40 | slim = tf.contrib.slim 41 | 42 | 43 | def vgg_arg_scope(weight_decay=0.0005): 44 | """Defines the VGG arg scope. 45 | Args: 46 | weight_decay: The l2 regularization coefficient. 47 | Returns: 48 | An arg_scope. 49 | """ 50 | with slim.arg_scope([slim.conv2d, slim.fully_connected], 51 | activation_fn=tf.nn.relu, 52 | weights_regularizer=slim.l2_regularizer(weight_decay), 53 | biases_initializer=tf.zeros_initializer()): 54 | with slim.arg_scope([slim.conv2d], padding='SAME') as arg_sc: 55 | return arg_sc 56 | 57 | 58 | def vgg_a(inputs, 59 | num_classes=1000, 60 | is_training=True, 61 | dropout_keep_prob=0.5, 62 | spatial_squeeze=True, 63 | scope='vgg_a'): 64 | """Oxford Net VGG 11-Layers version A Example. 65 | Note: All the fully_connected layers have been transformed to conv2d layers. 66 | To use in classification mode, resize input to 224x224. 67 | Args: 68 | inputs: a tensor of size [batch_size, height, width, channels]. 69 | num_classes: number of predicted classes. 70 | is_training: whether or not the model is being trained. 71 | dropout_keep_prob: the probability that activations are kept in the dropout 72 | layers during training. 73 | spatial_squeeze: whether or not should squeeze the spatial dimensions of the 74 | outputs. Useful to remove unnecessary dimensions for classification. 75 | scope: Optional scope for the variables. 76 | Returns: 77 | the last op containing the log predictions and end_points dict. 78 | """ 79 | with tf.variable_scope(scope, 'vgg_a', [inputs]) as sc: 80 | end_points_collection = sc.name + '_end_points' 81 | # Collect outputs for conv2d, fully_connected and max_pool2d. 82 | with slim.arg_scope([slim.conv2d, slim.max_pool2d], 83 | outputs_collections=end_points_collection): 84 | net = slim.repeat(inputs, 1, slim.conv2d, 64, [3, 3], scope='conv1') 85 | net = slim.max_pool2d(net, [2, 2], scope='pool1') 86 | net = slim.repeat(net, 1, slim.conv2d, 128, [3, 3], scope='conv2') 87 | net = slim.max_pool2d(net, [2, 2], scope='pool2') 88 | net = slim.repeat(net, 2, slim.conv2d, 256, [3, 3], scope='conv3') 89 | net = slim.max_pool2d(net, [2, 2], scope='pool3') 90 | net = slim.repeat(net, 2, slim.conv2d, 512, [3, 3], scope='conv4') 91 | net = slim.max_pool2d(net, [2, 2], scope='pool4') 92 | net = slim.repeat(net, 2, slim.conv2d, 512, [3, 3], scope='conv5') 93 | net = slim.max_pool2d(net, [2, 2], scope='pool5') 94 | # Use conv2d instead of fully_connected layers. 
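      # For a 224x224 input, the five stride-2 pools above leave a 7x7 spatial
      # map, so the 7x7 VALID convolution below covers every remaining position
      # and is numerically equivalent to a dense (fully connected) layer;
      # larger inputs instead produce a spatial map of class scores.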
95 | net = slim.conv2d(net, 4096, [7, 7], padding='VALID', scope='fc6') 96 | net = slim.dropout(net, dropout_keep_prob, is_training=is_training, 97 | scope='dropout6') 98 | net = slim.conv2d(net, 4096, [1, 1], scope='fc7') 99 | net = slim.dropout(net, dropout_keep_prob, is_training=is_training, 100 | scope='dropout7') 101 | net = slim.conv2d(net, num_classes, [1, 1], 102 | activation_fn=None, 103 | normalizer_fn=None, 104 | scope='fc8') 105 | # Convert end_points_collection into a end_point dict. 106 | end_points = slim.utils.convert_collection_to_dict(end_points_collection) 107 | if spatial_squeeze: 108 | net = tf.squeeze(net, [1, 2], name='fc8/squeezed') 109 | end_points[sc.name + '/fc8'] = net 110 | return net, end_points 111 | vgg_a.default_image_size = 224 112 | 113 | 114 | def vgg_16(inputs, 115 | num_classes=1000, 116 | is_training=True, 117 | dropout_keep_prob=0.5, 118 | spatial_squeeze=True, 119 | scope='vgg_16'): 120 | """Oxford Net VGG 16-Layers version D Example. 121 | Note: All the fully_connected layers have been transformed to conv2d layers. 122 | To use in classification mode, resize input to 224x224. 123 | Args: 124 | inputs: a tensor of size [batch_size, height, width, channels]. 125 | num_classes: number of predicted classes. 126 | is_training: whether or not the model is being trained. 127 | dropout_keep_prob: the probability that activations are kept in the dropout 128 | layers during training. 129 | spatial_squeeze: whether or not should squeeze the spatial dimensions of the 130 | outputs. Useful to remove unnecessary dimensions for classification. 131 | scope: Optional scope for the variables. 132 | Returns: 133 | the last op containing the log predictions and end_points dict. 134 | """ 135 | with tf.variable_scope(scope, 'vgg_16', [inputs]) as sc: 136 | end_points_collection = sc.name + '_end_points' 137 | # Collect outputs for conv2d, fully_connected and max_pool2d. 138 | with slim.arg_scope([slim.conv2d, slim.fully_connected, slim.max_pool2d], 139 | outputs_collections=end_points_collection): 140 | net = slim.repeat(inputs, 2, slim.conv2d, 64, [3, 3], scope='conv1') 141 | net = slim.max_pool2d(net, [2, 2], scope='pool1') 142 | net = slim.repeat(net, 2, slim.conv2d, 128, [3, 3], scope='conv2') 143 | net = slim.max_pool2d(net, [2, 2], scope='pool2') 144 | net = slim.repeat(net, 3, slim.conv2d, 256, [3, 3], scope='conv3') 145 | net = slim.max_pool2d(net, [2, 2], scope='pool3') 146 | net = slim.repeat(net, 3, slim.conv2d, 512, [3, 3], scope='conv4') 147 | net = slim.max_pool2d(net, [2, 2], scope='pool4') 148 | net = slim.repeat(net, 3, slim.conv2d, 512, [3, 3], scope='conv5') 149 | net = slim.max_pool2d(net, [2, 2], scope='pool5') 150 | # Use conv2d instead of fully_connected layers. 151 | net = slim.conv2d(net, 4096, [7, 7], padding='VALID', scope='fc6') 152 | net = slim.dropout(net, dropout_keep_prob, is_training=is_training, 153 | scope='dropout6') 154 | net = slim.conv2d(net, 4096, [1, 1], scope='fc7') 155 | net = slim.dropout(net, dropout_keep_prob, is_training=is_training, 156 | scope='dropout7') 157 | net = slim.conv2d(net, num_classes, [1, 1], 158 | activation_fn=None, 159 | normalizer_fn=None, 160 | scope='fc8') 161 | # Convert end_points_collection into a end_point dict. 
162 | end_points = slim.utils.convert_collection_to_dict(end_points_collection) 163 | if spatial_squeeze: 164 | net = tf.squeeze(net, [1, 2], name='fc8/squeezed') 165 | end_points[sc.name + '/fc8'] = net 166 | return net, end_points 167 | vgg_16.default_image_size = 224 168 | 169 | 170 | def vgg_19(inputs, 171 | num_classes=1000, 172 | is_training=True, 173 | dropout_keep_prob=0.5, 174 | spatial_squeeze=True, 175 | scope='vgg_19'): 176 | """Oxford Net VGG 19-Layers version E Example. 177 | Note: All the fully_connected layers have been transformed to conv2d layers. 178 | To use in classification mode, resize input to 224x224. 179 | Args: 180 | inputs: a tensor of size [batch_size, height, width, channels]. 181 | num_classes: number of predicted classes. 182 | is_training: whether or not the model is being trained. 183 | dropout_keep_prob: the probability that activations are kept in the dropout 184 | layers during training. 185 | spatial_squeeze: whether or not should squeeze the spatial dimensions of the 186 | outputs. Useful to remove unnecessary dimensions for classification. 187 | scope: Optional scope for the variables. 188 | Returns: 189 | the last op containing the log predictions and end_points dict. 190 | """ 191 | with tf.variable_scope(scope, 'vgg_19', [inputs]) as sc: 192 | end_points_collection = sc.name + '_end_points' 193 | # Collect outputs for conv2d, fully_connected and max_pool2d. 194 | with slim.arg_scope([slim.conv2d, slim.fully_connected, slim.max_pool2d], 195 | outputs_collections=end_points_collection): 196 | net = slim.repeat(inputs, 2, slim.conv2d, 64, [3, 3], scope='conv1') 197 | net = slim.max_pool2d(net, [2, 2], scope='pool1') 198 | net = slim.repeat(net, 2, slim.conv2d, 128, [3, 3], scope='conv2') 199 | net = slim.max_pool2d(net, [2, 2], scope='pool2') 200 | net = slim.repeat(net, 4, slim.conv2d, 256, [3, 3], scope='conv3') 201 | net = slim.max_pool2d(net, [2, 2], scope='pool3') 202 | net = slim.repeat(net, 4, slim.conv2d, 512, [3, 3], scope='conv4') 203 | net = slim.max_pool2d(net, [2, 2], scope='pool4') 204 | net = slim.repeat(net, 4, slim.conv2d, 512, [3, 3], scope='conv5') 205 | net = slim.max_pool2d(net, [2, 2], scope='pool5') 206 | # Use conv2d instead of fully_connected layers. 207 | net = slim.conv2d(net, 4096, [7, 7], padding='VALID', scope='fc6') 208 | net = slim.dropout(net, dropout_keep_prob, is_training=is_training, 209 | scope='dropout6') 210 | net = slim.conv2d(net, 4096, [1, 1], scope='fc7') 211 | net = slim.dropout(net, dropout_keep_prob, is_training=is_training, 212 | scope='dropout7') 213 | net = slim.conv2d(net, num_classes, [1, 1], 214 | activation_fn=None, 215 | normalizer_fn=None, 216 | scope='fc8') 217 | # Convert end_points_collection into a end_point dict. 
218 | end_points = slim.utils.convert_collection_to_dict(end_points_collection) 219 | if spatial_squeeze: 220 | net = tf.squeeze(net, [1, 2], name='fc8/squeezed') 221 | end_points[sc.name + '/fc8'] = net 222 | return net, end_points 223 | vgg_19.default_image_size = 224 224 | 225 | # Alias 226 | vgg_d = vgg_16 227 | vgg_e = vgg_19 -------------------------------------------------------------------------------- /preprocessing/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /preprocessing/__pycache__/__init__.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wisewong/ImageStyleTransform/f2f7930cbf51f544c59a949855e2403aadede41b/preprocessing/__pycache__/__init__.cpython-35.pyc -------------------------------------------------------------------------------- /preprocessing/__pycache__/cifarnet_preprocessing.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wisewong/ImageStyleTransform/f2f7930cbf51f544c59a949855e2403aadede41b/preprocessing/__pycache__/cifarnet_preprocessing.cpython-35.pyc -------------------------------------------------------------------------------- /preprocessing/__pycache__/inception_preprocessing.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wisewong/ImageStyleTransform/f2f7930cbf51f544c59a949855e2403aadede41b/preprocessing/__pycache__/inception_preprocessing.cpython-35.pyc -------------------------------------------------------------------------------- /preprocessing/__pycache__/lenet_preprocessing.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wisewong/ImageStyleTransform/f2f7930cbf51f544c59a949855e2403aadede41b/preprocessing/__pycache__/lenet_preprocessing.cpython-35.pyc -------------------------------------------------------------------------------- /preprocessing/__pycache__/preprocessing_factory.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wisewong/ImageStyleTransform/f2f7930cbf51f544c59a949855e2403aadede41b/preprocessing/__pycache__/preprocessing_factory.cpython-35.pyc -------------------------------------------------------------------------------- /preprocessing/__pycache__/vgg_preprocessing.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wisewong/ImageStyleTransform/f2f7930cbf51f544c59a949855e2403aadede41b/preprocessing/__pycache__/vgg_preprocessing.cpython-35.pyc -------------------------------------------------------------------------------- /preprocessing/cifarnet_preprocessing.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Provides utilities to preprocess images in CIFAR-10. 16 | 17 | """ 18 | 19 | from __future__ import absolute_import 20 | from __future__ import division 21 | from __future__ import print_function 22 | 23 | import tensorflow as tf 24 | 25 | _PADDING = 4 26 | 27 | slim = tf.contrib.slim 28 | 29 | 30 | def preprocess_for_train(image, 31 | output_height, 32 | output_width, 33 | padding=_PADDING): 34 | """Preprocesses the given image for training. 35 | 36 | Note that the actual resizing scale is sampled from 37 | [`resize_size_min`, `resize_size_max`]. 38 | 39 | Args: 40 | image: A `Tensor` representing an image of arbitrary size. 41 | output_height: The height of the image after preprocessing. 42 | output_width: The width of the image after preprocessing. 43 | padding: The amound of padding before and after each dimension of the image. 44 | 45 | Returns: 46 | A preprocessed image. 47 | """ 48 | tf.image_summary('image', tf.expand_dims(image, 0)) 49 | 50 | # Transform the image to floats. 51 | image = tf.to_float(image) 52 | if padding > 0: 53 | image = tf.pad(image, [[padding, padding], [padding, padding], [0, 0]]) 54 | # Randomly crop a [height, width] section of the image. 55 | distorted_image = tf.random_crop(image, 56 | [output_height, output_width, 3]) 57 | 58 | # Randomly flip the image horizontally. 59 | distorted_image = tf.image.random_flip_left_right(distorted_image) 60 | 61 | tf.image_summary('distorted_image', tf.expand_dims(distorted_image, 0)) 62 | 63 | # Because these operations are not commutative, consider randomizing 64 | # the order their operation. 65 | distorted_image = tf.image.random_brightness(distorted_image, 66 | max_delta=63) 67 | distorted_image = tf.image.random_contrast(distorted_image, 68 | lower=0.2, upper=1.8) 69 | # Subtract off the mean and divide by the variance of the pixels. 70 | return tf.image.per_image_whitening(distorted_image) 71 | 72 | 73 | def preprocess_for_eval(image, output_height, output_width): 74 | """Preprocesses the given image for evaluation. 75 | 76 | Args: 77 | image: A `Tensor` representing an image of arbitrary size. 78 | output_height: The height of the image after preprocessing. 79 | output_width: The width of the image after preprocessing. 80 | 81 | Returns: 82 | A preprocessed image. 83 | """ 84 | tf.image_summary('image', tf.expand_dims(image, 0)) 85 | # Transform the image to floats. 86 | image = tf.to_float(image) 87 | 88 | # Resize and crop if needed. 89 | resized_image = tf.image.resize_image_with_crop_or_pad(image, 90 | output_width, 91 | output_height) 92 | tf.image_summary('resized_image', tf.expand_dims(resized_image, 0)) 93 | 94 | # Subtract off the mean and divide by the variance of the pixels. 95 | return tf.image.per_image_whitening(resized_image) 96 | 97 | 98 | def preprocess_image(image, output_height, output_width, is_training=False): 99 | """Preprocesses the given image. 100 | 101 | Args: 102 | image: A `Tensor` representing an image of arbitrary size. 
103 | output_height: The height of the image after preprocessing. 104 | output_width: The width of the image after preprocessing. 105 | is_training: `True` if we're preprocessing the image for training and 106 | `False` otherwise. 107 | 108 | Returns: 109 | A preprocessed image. 110 | """ 111 | if is_training: 112 | return preprocess_for_train(image, output_height, output_width) 113 | else: 114 | return preprocess_for_eval(image, output_height, output_width) 115 | -------------------------------------------------------------------------------- /preprocessing/lenet_preprocessing.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Provides utilities for preprocessing.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import tensorflow as tf 22 | 23 | slim = tf.contrib.slim 24 | 25 | 26 | def preprocess_image(image, output_height, output_width, is_training): 27 | """Preprocesses the given image. 28 | 29 | Args: 30 | image: A `Tensor` representing an image of arbitrary size. 31 | output_height: The height of the image after preprocessing. 32 | output_width: The width of the image after preprocessing. 33 | is_training: `True` if we're preprocessing the image for training and 34 | `False` otherwise. 35 | 36 | Returns: 37 | A preprocessed image. 38 | """ 39 | image = tf.to_float(image) 40 | image = tf.image.resize_image_with_crop_or_pad( 41 | image, output_width, output_height) 42 | image = tf.sub(image, 128.0) 43 | image = tf.div(image, 128.0) 44 | return image 45 | -------------------------------------------------------------------------------- /preprocessing/preprocessing_factory.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | """Contains a factory for building various models.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import tensorflow as tf 22 | 23 | from preprocessing import cifarnet_preprocessing 24 | from preprocessing import inception_preprocessing 25 | from preprocessing import lenet_preprocessing 26 | from preprocessing import vgg_preprocessing 27 | 28 | slim = tf.contrib.slim 29 | 30 | 31 | def get_preprocessing(name, is_training=False): 32 | """Returns preprocessing_fn(image, height, width, **kwargs). 33 | 34 | Args: 35 | name: The name of the preprocessing function. 36 | is_training: `True` if the model is being used for training and `False` 37 | otherwise. 38 | 39 | Returns: 40 | preprocessing_fn: A function that preprocessing a single image (pre-batch). 41 | It has the following signature: 42 | image = preprocessing_fn(image, output_height, output_width, ...). 43 | 44 | Raises: 45 | ValueError: If Preprocessing `name` is not recognized. 46 | """ 47 | preprocessing_fn_map = { 48 | 'cifarnet': cifarnet_preprocessing, 49 | 'inception': inception_preprocessing, 50 | 'inception_v1': inception_preprocessing, 51 | 'inception_v2': inception_preprocessing, 52 | 'inception_v3': inception_preprocessing, 53 | 'inception_v4': inception_preprocessing, 54 | 'inception_resnet_v2': inception_preprocessing, 55 | 'lenet': lenet_preprocessing, 56 | 'resnet_v1_50': vgg_preprocessing, 57 | 'resnet_v1_101': vgg_preprocessing, 58 | 'resnet_v1_152': vgg_preprocessing, 59 | 'vgg': vgg_preprocessing, 60 | 'vgg_a': vgg_preprocessing, 61 | 'vgg_16': vgg_preprocessing, 62 | 'vgg_19': vgg_preprocessing, 63 | } 64 | 65 | if name not in preprocessing_fn_map: 66 | raise ValueError('Preprocessing name [%s] was not recognized' % name) 67 | 68 | def preprocessing_fn(image, output_height, output_width, **kwargs): 69 | return preprocessing_fn_map[name].preprocess_image( 70 | image, output_height, output_width, is_training=is_training, **kwargs) 71 | 72 | def unprocessing_fn(image, **kwargs): 73 | return preprocessing_fn_map[name].unprocess_image( 74 | image, **kwargs) 75 | 76 | return preprocessing_fn, unprocessing_fn 77 | -------------------------------------------------------------------------------- /reader.py: -------------------------------------------------------------------------------- 1 | from os import listdir 2 | from os.path import isfile, join 3 | 4 | import tensorflow as tf 5 | 6 | 7 | def get_image(path, height, width, preprocess_fn): 8 | png = path.lower().endswith('png') 9 | img_bytes = tf.read_file(path) 10 | image = tf.image.decode_png(img_bytes, channels=3) if png else tf.image.decode_jpeg(img_bytes, channels=3) 11 | return preprocess_fn(image, height, width) 12 | 13 | 14 | def batch_image(batch_size, height, width, path, preprocess_fn, epochs=2, shuffle=True): 15 | file_names = [join(path, f) for f in listdir(path) if isfile(join(path, f))] 16 | if not shuffle: 17 | file_names = sorted(file_names) 18 | 19 | png = file_names[0].lower().endswith('png') 20 | 21 | filename_queue = tf.train.string_input_producer(file_names, shuffle=shuffle, num_epochs=epochs) 22 | reader = tf.WholeFileReader() 23 | _, img_bytes = reader.read(filename_queue) 24 | image = tf.image.decode_png(img_bytes, channels=3) if png else tf.image.decode_jpeg(img_bytes, channels=3) 25 | 26 | processed_image = preprocess_fn(image, height, width) 27 | return 
tf.train.batch([processed_image], batch_size, dynamic_pad=True) 28 | -------------------------------------------------------------------------------- /static/img/background.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wisewong/ImageStyleTransform/f2f7930cbf51f544c59a949855e2403aadede41b/static/img/background.jpg -------------------------------------------------------------------------------- /static/img/content/test.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wisewong/ImageStyleTransform/f2f7930cbf51f544c59a949855e2403aadede41b/static/img/content/test.jpg -------------------------------------------------------------------------------- /static/img/content/test1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wisewong/ImageStyleTransform/f2f7930cbf51f544c59a949855e2403aadede41b/static/img/content/test1.jpg -------------------------------------------------------------------------------- /static/img/content/test2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wisewong/ImageStyleTransform/f2f7930cbf51f544c59a949855e2403aadede41b/static/img/content/test2.jpg -------------------------------------------------------------------------------- /static/img/content/test3.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wisewong/ImageStyleTransform/f2f7930cbf51f544c59a949855e2403aadede41b/static/img/content/test3.jpg -------------------------------------------------------------------------------- /static/img/content/test4.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wisewong/ImageStyleTransform/f2f7930cbf51f544c59a949855e2403aadede41b/static/img/content/test4.jpg -------------------------------------------------------------------------------- /static/img/content/test5.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wisewong/ImageStyleTransform/f2f7930cbf51f544c59a949855e2403aadede41b/static/img/content/test5.jpg -------------------------------------------------------------------------------- /static/img/content/test6.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wisewong/ImageStyleTransform/f2f7930cbf51f544c59a949855e2403aadede41b/static/img/content/test6.jpg -------------------------------------------------------------------------------- /static/img/content/test7.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wisewong/ImageStyleTransform/f2f7930cbf51f544c59a949855e2403aadede41b/static/img/content/test7.jpg -------------------------------------------------------------------------------- /static/img/content/test8.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wisewong/ImageStyleTransform/f2f7930cbf51f544c59a949855e2403aadede41b/static/img/content/test8.jpg -------------------------------------------------------------------------------- /static/img/content/test9.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/wisewong/ImageStyleTransform/f2f7930cbf51f544c59a949855e2403aadede41b/static/img/content/test9.png -------------------------------------------------------------------------------- /static/img/favicon.ico: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wisewong/ImageStyleTransform/f2f7930cbf51f544c59a949855e2403aadede41b/static/img/favicon.ico -------------------------------------------------------------------------------- /static/img/generated/cubist_res.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wisewong/ImageStyleTransform/f2f7930cbf51f544c59a949855e2403aadede41b/static/img/generated/cubist_res.jpg -------------------------------------------------------------------------------- /static/img/generated/cubist_res_test6.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wisewong/ImageStyleTransform/f2f7930cbf51f544c59a949855e2403aadede41b/static/img/generated/cubist_res_test6.jpg -------------------------------------------------------------------------------- /static/img/generated/denoised_starry_res.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wisewong/ImageStyleTransform/f2f7930cbf51f544c59a949855e2403aadede41b/static/img/generated/denoised_starry_res.jpg -------------------------------------------------------------------------------- /static/img/generated/denoised_starry_res_test.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wisewong/ImageStyleTransform/f2f7930cbf51f544c59a949855e2403aadede41b/static/img/generated/denoised_starry_res_test.jpg -------------------------------------------------------------------------------- /static/img/generated/denoised_starry_res_test1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wisewong/ImageStyleTransform/f2f7930cbf51f544c59a949855e2403aadede41b/static/img/generated/denoised_starry_res_test1.jpg -------------------------------------------------------------------------------- /static/img/generated/denoised_starry_res_test2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wisewong/ImageStyleTransform/f2f7930cbf51f544c59a949855e2403aadede41b/static/img/generated/denoised_starry_res_test2.jpg -------------------------------------------------------------------------------- /static/img/generated/denoised_starry_res_test3.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wisewong/ImageStyleTransform/f2f7930cbf51f544c59a949855e2403aadede41b/static/img/generated/denoised_starry_res_test3.jpg -------------------------------------------------------------------------------- /static/img/generated/denoised_starry_res_test4.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wisewong/ImageStyleTransform/f2f7930cbf51f544c59a949855e2403aadede41b/static/img/generated/denoised_starry_res_test4.jpg -------------------------------------------------------------------------------- /static/img/generated/denoised_starry_res_test5.jpg: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/wisewong/ImageStyleTransform/f2f7930cbf51f544c59a949855e2403aadede41b/static/img/generated/denoised_starry_res_test5.jpg -------------------------------------------------------------------------------- /static/img/generated/denoised_starry_res_test6.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wisewong/ImageStyleTransform/f2f7930cbf51f544c59a949855e2403aadede41b/static/img/generated/denoised_starry_res_test6.jpg -------------------------------------------------------------------------------- /static/img/generated/denoised_starry_res_test7.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wisewong/ImageStyleTransform/f2f7930cbf51f544c59a949855e2403aadede41b/static/img/generated/denoised_starry_res_test7.jpg -------------------------------------------------------------------------------- /static/img/generated/denoised_starry_res_test9.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wisewong/ImageStyleTransform/f2f7930cbf51f544c59a949855e2403aadede41b/static/img/generated/denoised_starry_res_test9.png -------------------------------------------------------------------------------- /static/img/generated/feathers_res_test7.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wisewong/ImageStyleTransform/f2f7930cbf51f544c59a949855e2403aadede41b/static/img/generated/feathers_res_test7.jpg -------------------------------------------------------------------------------- /static/img/generated/mosaic_res_test6.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wisewong/ImageStyleTransform/f2f7930cbf51f544c59a949855e2403aadede41b/static/img/generated/mosaic_res_test6.jpg -------------------------------------------------------------------------------- /static/img/generated/mosaic_res_test7.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wisewong/ImageStyleTransform/f2f7930cbf51f544c59a949855e2403aadede41b/static/img/generated/mosaic_res_test7.jpg -------------------------------------------------------------------------------- /static/img/generated/painting_res.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wisewong/ImageStyleTransform/f2f7930cbf51f544c59a949855e2403aadede41b/static/img/generated/painting_res.jpg -------------------------------------------------------------------------------- /static/img/generated/painting_res_test6.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wisewong/ImageStyleTransform/f2f7930cbf51f544c59a949855e2403aadede41b/static/img/generated/painting_res_test6.jpg -------------------------------------------------------------------------------- /static/img/generated/painting_res_test7.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wisewong/ImageStyleTransform/f2f7930cbf51f544c59a949855e2403aadede41b/static/img/generated/painting_res_test7.jpg -------------------------------------------------------------------------------- 
/static/img/generated/scream_res_test.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wisewong/ImageStyleTransform/f2f7930cbf51f544c59a949855e2403aadede41b/static/img/generated/scream_res_test.jpg -------------------------------------------------------------------------------- /static/img/generated/scream_res_test4.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wisewong/ImageStyleTransform/f2f7930cbf51f544c59a949855e2403aadede41b/static/img/generated/scream_res_test4.jpg -------------------------------------------------------------------------------- /static/img/generated/scream_res_test6.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wisewong/ImageStyleTransform/f2f7930cbf51f544c59a949855e2403aadede41b/static/img/generated/scream_res_test6.jpg -------------------------------------------------------------------------------- /static/img/generated/scream_res_test7.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wisewong/ImageStyleTransform/f2f7930cbf51f544c59a949855e2403aadede41b/static/img/generated/scream_res_test7.jpg -------------------------------------------------------------------------------- /static/img/generated/target_style_painting.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wisewong/ImageStyleTransform/f2f7930cbf51f544c59a949855e2403aadede41b/static/img/generated/target_style_painting.jpg -------------------------------------------------------------------------------- /static/img/generated/test6.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wisewong/ImageStyleTransform/f2f7930cbf51f544c59a949855e2403aadede41b/static/img/generated/test6.jpg -------------------------------------------------------------------------------- /static/img/generated/udnie_res_test7.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wisewong/ImageStyleTransform/f2f7930cbf51f544c59a949855e2403aadede41b/static/img/generated/udnie_res_test7.jpg -------------------------------------------------------------------------------- /static/img/generated/wave_res_test5.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wisewong/ImageStyleTransform/f2f7930cbf51f544c59a949855e2403aadede41b/static/img/generated/wave_res_test5.jpg -------------------------------------------------------------------------------- /static/img/generated/wave_res_test6.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wisewong/ImageStyleTransform/f2f7930cbf51f544c59a949855e2403aadede41b/static/img/generated/wave_res_test6.jpg -------------------------------------------------------------------------------- /static/img/loading1.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wisewong/ImageStyleTransform/f2f7930cbf51f544c59a949855e2403aadede41b/static/img/loading1.gif -------------------------------------------------------------------------------- /static/img/loading2.gif: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/wisewong/ImageStyleTransform/f2f7930cbf51f544c59a949855e2403aadede41b/static/img/loading2.gif -------------------------------------------------------------------------------- /static/img/loading3.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wisewong/ImageStyleTransform/f2f7930cbf51f544c59a949855e2403aadede41b/static/img/loading3.gif -------------------------------------------------------------------------------- /static/img/loading4.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wisewong/ImageStyleTransform/f2f7930cbf51f544c59a949855e2403aadede41b/static/img/loading4.gif -------------------------------------------------------------------------------- /static/img/style/candy.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wisewong/ImageStyleTransform/f2f7930cbf51f544c59a949855e2403aadede41b/static/img/style/candy.jpg -------------------------------------------------------------------------------- /static/img/style/cubist.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wisewong/ImageStyleTransform/f2f7930cbf51f544c59a949855e2403aadede41b/static/img/style/cubist.jpg -------------------------------------------------------------------------------- /static/img/style/denoised_starry.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wisewong/ImageStyleTransform/f2f7930cbf51f544c59a949855e2403aadede41b/static/img/style/denoised_starry.jpg -------------------------------------------------------------------------------- /static/img/style/feathers.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wisewong/ImageStyleTransform/f2f7930cbf51f544c59a949855e2403aadede41b/static/img/style/feathers.jpg -------------------------------------------------------------------------------- /static/img/style/gouache.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wisewong/ImageStyleTransform/f2f7930cbf51f544c59a949855e2403aadede41b/static/img/style/gouache.jpg -------------------------------------------------------------------------------- /static/img/style/mosaic.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wisewong/ImageStyleTransform/f2f7930cbf51f544c59a949855e2403aadede41b/static/img/style/mosaic.jpg -------------------------------------------------------------------------------- /static/img/style/painting.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wisewong/ImageStyleTransform/f2f7930cbf51f544c59a949855e2403aadede41b/static/img/style/painting.jpg -------------------------------------------------------------------------------- /static/img/style/picasso.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wisewong/ImageStyleTransform/f2f7930cbf51f544c59a949855e2403aadede41b/static/img/style/picasso.jpg -------------------------------------------------------------------------------- 
/static/img/style/scream.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wisewong/ImageStyleTransform/f2f7930cbf51f544c59a949855e2403aadede41b/static/img/style/scream.jpg -------------------------------------------------------------------------------- /static/img/style/starry.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wisewong/ImageStyleTransform/f2f7930cbf51f544c59a949855e2403aadede41b/static/img/style/starry.jpg -------------------------------------------------------------------------------- /static/img/style/udnie.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wisewong/ImageStyleTransform/f2f7930cbf51f544c59a949855e2403aadede41b/static/img/style/udnie.jpg -------------------------------------------------------------------------------- /static/img/style/wave.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wisewong/ImageStyleTransform/f2f7930cbf51f544c59a949855e2403aadede41b/static/img/style/wave.jpg -------------------------------------------------------------------------------- /static/img/uploads/test.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wisewong/ImageStyleTransform/f2f7930cbf51f544c59a949855e2403aadede41b/static/img/uploads/test.jpg -------------------------------------------------------------------------------- /static/img/uploads/test1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wisewong/ImageStyleTransform/f2f7930cbf51f544c59a949855e2403aadede41b/static/img/uploads/test1.jpg -------------------------------------------------------------------------------- /static/img/uploads/test2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wisewong/ImageStyleTransform/f2f7930cbf51f544c59a949855e2403aadede41b/static/img/uploads/test2.jpg -------------------------------------------------------------------------------- /static/img/uploads/test3.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wisewong/ImageStyleTransform/f2f7930cbf51f544c59a949855e2403aadede41b/static/img/uploads/test3.jpg -------------------------------------------------------------------------------- /static/img/uploads/test4.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wisewong/ImageStyleTransform/f2f7930cbf51f544c59a949855e2403aadede41b/static/img/uploads/test4.jpg -------------------------------------------------------------------------------- /static/img/uploads/test5.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wisewong/ImageStyleTransform/f2f7930cbf51f544c59a949855e2403aadede41b/static/img/uploads/test5.jpg -------------------------------------------------------------------------------- /static/img/uploads/test6.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wisewong/ImageStyleTransform/f2f7930cbf51f544c59a949855e2403aadede41b/static/img/uploads/test6.jpg 
-------------------------------------------------------------------------------- /static/img/uploads/test7.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wisewong/ImageStyleTransform/f2f7930cbf51f544c59a949855e2403aadede41b/static/img/uploads/test7.jpg -------------------------------------------------------------------------------- /static/img/uploads/test9.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wisewong/ImageStyleTransform/f2f7930cbf51f544c59a949855e2403aadede41b/static/img/uploads/test9.png -------------------------------------------------------------------------------- /static/js/maple.js: -------------------------------------------------------------------------------- 1 | /** 2 | * Created by wz on 17/5/19. 3 | */ 4 | var d = "
🍁
"; 5 | setInterval(function () { 6 | var f = $(document).width(); 7 | var e = Math.random() * f - 300; // 枫叶的定位left值 8 | var o = 0.3; // 枫叶的透明度 9 | var fon = 25 + Math.random() * 10; // 枫叶大小 10 | var l = e - 100 + 200 * Math.random(); // 枫叶的横向位移 11 | var k = 8000 + 5000 * Math.random(); 12 | var deg = Math.random() * 360; // 枫叶的方向 13 | $(d).clone().appendTo(".maplebg").css({ 14 | left: e + "px", 15 | opacity: o, 16 | transform: "rotate(" + deg + "deg)", 17 | "font-size": fon, 18 | }).animate({ 19 | top: "550px", 20 | left: l + "px", 21 | opacity: 0.1, 22 | }, k, "linear", function () { 23 | $(this).remove() 24 | }) 25 | }, 500) 26 | 27 | 28 | 29 | function changeImage() { 30 | var styleFileImage = 'img/style/' + document.getElementById('style').value + '.jpg'; 31 | console.log(document.getElementById('style').value); 32 | console.log(styleFileImage); 33 | document.getElementById('imageShow').src = "../static/" + styleFileImage; 34 | } 35 | 36 | function showStatus() { 37 | // document.getElementById('status').src = '../static/img/loading1.gif' 38 | var status = document.createElement('img'); 39 | status.setAttribute('src','../static/img/loading1.gif') 40 | document.getElementById('status').appendChild(status) 41 | } 42 | 43 | $(".clickUpload").on("change","input[type='file']",function(){ 44 | var filePath=$(this).val(); 45 | if(filePath.indexOf("jpg")!=-1 || filePath.indexOf("png")!=-1){ 46 | $(".fileerrorTip").html("").hide(); 47 | var arr=filePath.split('\\'); 48 | var fileName=arr[arr.length-1]; 49 | $(".showFileName").html(fileName); 50 | }else{ 51 | $(".showFileName").html(""); 52 | $(".fileerrorTip").html("您未上传文件,或者您上传文件类型有误!").show(); 53 | return false 54 | } 55 | }) 56 | 57 | $(".picurlbtn").on("change","input[type='file']",function(){ 58 | var filePath=$(this).val(); 59 | if(filePath.indexOf("jpg")!=-1 || filePath.indexOf("jpeg")!=-1 || filePath.indexOf("png")!=-1){ 60 | $("#picture_name").attr("value",filePath); 61 | }else{ 62 | $("#picture_name").attr("value","仅支持jpg,jpeg,png格式!"); 63 | return false 64 | } 65 | }) -------------------------------------------------------------------------------- /templates/index.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 图像风格转换 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 |
16 |

Image Style Transfer

17 | 18 |
19 |
20 | Select image style: 21 |

31 |
32 | 34 |
35 |

36 | Upload image: 37 | 38 |
40 | 41 |
42 |
43 |
44 | 45 | 46 | -------------------------------------------------------------------------------- /templates/transformed.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | Transformed Image 6 | 7 | 8 | 9 | 10 | 11 | 12 |
13 |

Transformed Image

14 | 15 |
16 | 17 | 18 | 19 |
20 | 21 |
22 |

Back 23 |
24 | 25 | 26 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | from __future__ import print_function 3 | from __future__ import division 4 | import tensorflow as tf 5 | from nets import nets_factory 6 | from preprocessing import preprocessing_factory 7 | import reader 8 | import model 9 | import time 10 | import losses 11 | import utils 12 | import os 13 | import argparse 14 | 15 | slim = tf.contrib.slim 16 | 17 | 18 | def parse_args(): 19 | parser = argparse.ArgumentParser() 20 | parser.add_argument('-c', '--conf', default='conf/mosaic.yml', help='path to the config file') 21 | return parser.parse_args() 22 | 23 | 24 | def main(FLAGS): 25 | style_features_t = losses.get_style_features(FLAGS) 26 | training_path = os.path.join(FLAGS.model_path, FLAGS.naming) 27 | if not (os.path.exists(training_path)): 28 | os.makedirs(training_path) 29 | 30 | with tf.Graph().as_default(): 31 | with tf.Session() as sess: 32 | """Create the loss network""" 33 | network_fn = nets_factory.get_network_fn( 34 | FLAGS.loss_model, 35 | num_classes=1, 36 | is_training=False) 37 | 38 | image_preprocessing_fn, image_unprocessing_fn = preprocessing_factory.get_preprocessing( 39 | FLAGS.loss_model, 40 | is_training=False) 41 | 42 | """Preprocess the training images""" 43 | processed_images = reader.batch_image(FLAGS.batch_size, FLAGS.image_size, FLAGS.image_size, 44 | 'train2014/', image_preprocessing_fn, epochs=FLAGS.epoch) 45 | generated = model.transform_network(processed_images, training=True) 46 | processed_generated = [image_preprocessing_fn(image, FLAGS.image_size, FLAGS.image_size) 47 | for image in tf.unstack(generated, axis=0, num=FLAGS.batch_size) 48 | ] 49 | processed_generated = tf.stack(processed_generated) 50 | _, endpoints_dict = network_fn(tf.concat([processed_generated, processed_images], 0), spatial_squeeze=False) 51 | tf.logging.info('Loss network layers(You can define them in "content_layers" and "style_layers"):') 52 | for key in endpoints_dict: 53 | tf.logging.info(key) 54 | 55 | """Create the losses""" 56 | content_loss = losses.content_loss(endpoints_dict, FLAGS.content_layers) 57 | style_loss, style_loss_summary = losses.style_loss(endpoints_dict, style_features_t, FLAGS.style_layers) 58 | tv_loss = losses.total_variation_loss(generated) # use the unprocessed image 59 | 60 | loss = FLAGS.style_weight * style_loss + FLAGS.content_weight * content_loss + FLAGS.tv_weight * tv_loss 61 | 62 | """Prepare for training""" 63 | global_step = tf.Variable(0, name="global_step", trainable=False) 64 | variable_to_train = [] 65 | for variable in tf.trainable_variables(): 66 | # Only train and save the variables of the transform (generator) network 67 | if not (variable.name.startswith(FLAGS.loss_model)): 68 | variable_to_train.append(variable) 69 | 70 | """Optimization""" 71 | train_op = tf.train.AdamOptimizer(1e-3).minimize(loss, global_step=global_step, var_list=variable_to_train) 72 | 73 | variables_to_restore = [] 74 | for v in tf.global_variables(): 75 | if not (v.name.startswith(FLAGS.loss_model)): 76 | variables_to_restore.append(v) 77 | saver = tf.train.Saver(variables_to_restore, write_version=tf.train.SaverDef.V1) 78 | sess.run([tf.global_variables_initializer(), tf.local_variables_initializer()]) 79 | init_func = utils._get_init_fn(FLAGS) 80 | init_func(sess) 81 | last_file = tf.train.latest_checkpoint(training_path) 82 | if last_file: 83 | tf.logging.info('Restoring model from {}'.format(last_file)) 84 | saver.restore(sess, last_file) 85 | 86 | """Start training""" 87 | coord = tf.train.Coordinator()
88 | threads = tf.train.start_queue_runners(coord=coord) 89 | start_time = time.time() 90 | try: 91 | while not coord.should_stop(): 92 | _, loss_t, step = sess.run([train_op, loss, global_step]) 93 | elapsed_time = time.time() - start_time 94 | start_time = time.time() 95 | if step % 10 == 0: 96 | tf.logging.info( 97 | 'step: %d, total Loss %f, secs/step: %f,%s' % (step, loss_t, elapsed_time, time.asctime())) 98 | """checkpoint""" 99 | if step % 50 == 0: 100 | tf.logging.info('saving check point...') 101 | saver.save(sess, os.path.join(training_path, FLAGS.naming + '.ckpt'), global_step=step) 102 | except tf.errors.OutOfRangeError: 103 | saver.save(sess, os.path.join(training_path, 'fast-style-model.ckpt-done')) 104 | tf.logging.info('Done training -- epoch limit reached') 105 | finally: 106 | coord.request_stop() 107 | tf.logging.info('coordinator stop') 108 | coord.join(threads) 109 | 110 | 111 | if __name__ == '__main__': 112 | tf.logging.set_verbosity(tf.logging.INFO) 113 | args = parse_args() 114 | FLAGS = utils.read_conf_file(args.conf) 115 | main(FLAGS) 116 | -------------------------------------------------------------------------------- /transform.py: -------------------------------------------------------------------------------- 1 | # coding: utf-8 2 | from __future__ import print_function 3 | 4 | import os 5 | import time 6 | 7 | import tensorflow as tf 8 | 9 | import model 10 | import reader 11 | from preprocessing import preprocessing_factory 12 | 13 | tf.app.flags.DEFINE_string('loss_model', 'vgg_16', 'The name of the architecture to evaluate. ' 14 | 'You can view all the support models in nets/nets_factory.py') 15 | tf.app.flags.DEFINE_integer('image_size', 256, 'Image size to train.') 16 | tf.app.flags.DEFINE_string("model_file", "models.ckpt", "style model checkpoint") 17 | tf.app.flags.DEFINE_string("image_file", "content.jpg", "input image") 18 | tf.app.flags.DEFINE_string('target_file', 'res.jpg', 'style-transferred output image') 19 | 20 | FLAGS = tf.app.flags.FLAGS 21 | 22 | 23 | def main(_): 24 | height = 0 25 | width = 0 26 | with open(FLAGS.image_file, 'rb') as img: 27 | with tf.Session().as_default() as sess: 28 | if FLAGS.image_file.lower().endswith('png'): 29 | image = sess.run(tf.image.decode_png(img.read())) 30 | else: 31 | image = sess.run(tf.image.decode_jpeg(img.read())) 32 | height = image.shape[0] 33 | width = image.shape[1] 34 | tf.logging.info('Image size: %dx%d' % (width, height)) 35 | 36 | with tf.Graph().as_default(): 37 | with tf.Session().as_default() as sess: 38 | image_preprocessing_fn, _ = preprocessing_factory.get_preprocessing( 39 | FLAGS.loss_model, 40 | is_training=False) 41 | """Get the preprocessed input image, which is used later to obtain the image content""" 42 | image = reader.get_image(FLAGS.image_file, height, width, image_preprocessing_fn) 43 | image = tf.expand_dims(image, 0) 44 | generated = model.transform_network(image, training=False) 45 | generated = tf.squeeze(generated, [0]) 46 | saver = tf.train.Saver(tf.global_variables(), write_version=tf.train.SaverDef.V1) 47 | sess.run([tf.global_variables_initializer(), tf.local_variables_initializer()]) 48 | 49 | """Load the trained model""" 50 | FLAGS.model_file = os.path.abspath(FLAGS.model_file) 51 | saver.restore(sess, FLAGS.model_file) 52 | 53 | """Generate the style-transferred image""" 54 | start_time = time.time() 55 | generated = sess.run(generated) 56 | generated = tf.cast(generated, tf.uint8) 57 | end_time = time.time() 58 | tf.logging.info('Elapsed time: %fs' % (end_time - start_time)) 59 | 60 | generated_file = FLAGS.target_file 61 | if os.path.exists('static/img/generated') is False: 62 |
os.makedirs('static/img/generated') 63 | with open(generated_file, 'wb') as img: 64 | img.write(sess.run(tf.image.encode_jpeg(generated))) 65 | tf.logging.info('Done. Please check %s.' % generated_file) 66 | 67 | 68 | if __name__ == '__main__': 69 | tf.logging.set_verbosity(tf.logging.INFO) 70 | tf.app.run() 71 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import yaml 3 | 4 | slim = tf.contrib.slim 5 | 6 | 7 | def _get_init_fn(FLAGS): 8 | """ 9 | This function is copied from TF slim. 10 | 11 | Returns a function run by the chief worker to warm-start the training. 12 | 13 | Note that the init_fn is only run when initializing the model during the very 14 | first global step. 15 | 16 | Returns: 17 | An init function run by the supervisor. 18 | """ 19 | tf.logging.info('Use pretrained model %s' % FLAGS.loss_model_file) 20 | 21 | exclusions = [] 22 | if FLAGS.checkpoint_exclude_scopes: 23 | exclusions = [scope.strip() 24 | for scope in FLAGS.checkpoint_exclude_scopes.split(',')] 25 | 26 | variables_to_restore = [] 27 | for var in slim.get_model_variables(): 28 | excluded = False 29 | for exclusion in exclusions: 30 | if var.op.name.startswith(exclusion): 31 | excluded = True 32 | break 33 | if not excluded: 34 | variables_to_restore.append(var) 35 | 36 | return slim.assign_from_checkpoint_fn( 37 | FLAGS.loss_model_file, 38 | variables_to_restore, 39 | ignore_missing_vars=True) 40 | 41 | 42 | class Flag(object): 43 | def __init__(self, **entries): 44 | self.__dict__.update(entries) 45 | 46 | 47 | def read_conf_file(conf_file): 48 | with open(conf_file) as f: 49 | FLAGS = Flag(**yaml.load(f)) 50 | return FLAGS 51 | 52 | 53 | if __name__ == '__main__': 54 | f = read_conf_file('conf/mosaic.yml') 55 | print(f.loss_model_file) 56 | -------------------------------------------------------------------------------- /web.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import os 4 | import time 5 | 6 | import tensorflow as tf 7 | from flask import Flask, render_template, request, send_from_directory 8 | 9 | import model 10 | import reader 11 | from preprocessing import preprocessing_factory 12 | 13 | app = Flask(__name__) 14 | app.config['SECRET_KEY'] = '123456' 15 | app.static_folder = 'static' 16 | 17 | UPLOAD_FOLDER = 'static/img/uploads/' 18 | ALLOWED_EXTENSIONS = {'png', 'jpg', 'jpeg'} 19 | 20 | app.config['UPLOAD_FOLDER'] = UPLOAD_FOLDER 21 | 22 | tf.app.flags.DEFINE_string('loss_model', 'vgg_16', 'The name of the architecture to evaluate. ' 23 | 'You can view all the support models in nets/nets_factory.py') 24 | tf.app.flags.DEFINE_integer('image_size', 256, 'Image size to train.') 25 | tf.app.flags.DEFINE_string("model_file", "models.ckpt", "") 26 | tf.app.flags.DEFINE_string("image_file", "a.jpg", "") 27 | FLAGS = tf.app.flags.FLAGS 28 | 29 | 30 | def allowed_file(filename): 31 | return '.' 
in filename and filename.rsplit('.', 1)[1] in ALLOWED_EXTENSIONS 32 | 33 | 34 | @app.route('/') 35 | def index(): 36 | return render_template('index.html') 37 | 38 | 39 | @app.route('/transform', methods=['GET', 'POST']) 40 | def deal_image(): 41 | models_dict = {'cubist': 'cubist.ckpt-done', 42 | 'denoised_starry': 'denoised_starry.ckpt-done', 43 | 'feathers': 'feathers.ckpt-done', 44 | 'mosaic': 'mosaic.ckpt-done', 45 | 'scream': 'scream.ckpt-done', 46 | 'udnie': 'udnie.ckpt-done', 47 | 'wave': 'wave.ckpt-done', 48 | 'painting': 'painting.ckpt-done', 49 | } 50 | if request.method == 'POST': 51 | file = request.files['pic'] 52 | 53 | style = request.form['style'] 54 | if file and allowed_file(file.filename): 55 | if os.path.exists(app.config['UPLOAD_FOLDER']) is False: 56 | os.makedirs(app.config['UPLOAD_FOLDER']) 57 | file.save(os.path.join(app.config['UPLOAD_FOLDER'], file.filename)) 58 | model_file = 'wave.ckpt-done' 59 | if style != '': 60 | if models_dict[style] != '': 61 | model_file = models_dict[style] 62 | style_transform(style, 'models/' + model_file, os.path.join(app.config['UPLOAD_FOLDER']) + file.filename, 63 | style + '_res_' + file.filename) 64 | return render_template('transformed.html', style='img/style/' + style + '.jpg', 65 | upload='img/uploads/' + file.filename, 66 | transformed='img/generated/' + style + '_res_' + file.filename) 67 | return 'transform error:file format error' 68 | return 'transform error:method not post' 69 | 70 | 71 | @app.route('/uploads/') 72 | def uploaded_file(filename): 73 | return send_from_directory('static/img/generated/', filename) 74 | 75 | 76 | def style_transform(style, model_file, img_file, result_file): 77 | height = 0 78 | width = 0 79 | with open(img_file, 'rb') as img: 80 | with tf.Session().as_default() as sess: 81 | if img_file.lower().endswith('png'): 82 | image = sess.run(tf.image.decode_png(img.read())) 83 | else: 84 | image = sess.run(tf.image.decode_jpeg(img.read())) 85 | height = image.shape[0] 86 | width = image.shape[1] 87 | print('Image size: %dx%d' % (width, height)) 88 | 89 | with tf.Graph().as_default(): 90 | with tf.Session().as_default() as sess: 91 | image_preprocessing_fn, _ = preprocessing_factory.get_preprocessing( 92 | FLAGS.loss_model, 93 | is_training=False) 94 | image = reader.get_image(img_file, height, width, image_preprocessing_fn) 95 | image = tf.expand_dims(image, 0) 96 | generated = model.transform_network(image, training=False) 97 | generated = tf.squeeze(generated, [0]) 98 | saver = tf.train.Saver(tf.global_variables()) 99 | sess.run([tf.global_variables_initializer(), tf.local_variables_initializer()]) 100 | FLAGS.model_file = os.path.abspath(model_file) 101 | saver.restore(sess, FLAGS.model_file) 102 | 103 | start_time = time.time() 104 | generated = sess.run(generated) 105 | generated = tf.cast(generated, tf.uint8) 106 | end_time = time.time() 107 | print('Elapsed time: %fs' % (end_time - start_time)) 108 | generated_file = 'static/img/generated/' + result_file 109 | if os.path.exists('static/img/generated') is False: 110 | os.makedirs('static/img/generated') 111 | with open(generated_file, 'wb') as img: 112 | img.write(sess.run(tf.image.encode_jpeg(generated))) 113 | print('Done. Please check %s.' % generated_file) 114 | 115 | 116 | if __name__ == '__main__': 117 | app.run(debug=True) 118 | 119 | 120 | 121 | --------------------------------------------------------------------------------
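A minimal usage sketch, not part of the repository itself (paths and the checkpoint file name are illustrative): assuming the MS-COCO train2014/ folder and the pretrained/vgg_16.ckpt loss-network weights referenced by the conf files are in place, the scripts above are typically driven along these lines:

    # Train a transform network for one style; the style image, loss weights and loss layers come from the YAML file.
    python train.py -c conf/mosaic.yml
    # Apply a trained checkpoint to a single content image (the exact checkpoint name depends on when training stopped).
    python transform.py --model_file models/mosaic/fast-style-model.ckpt-done --image_file img/content/test6.jpg --target_file static/img/generated/mosaic_res_test6.jpg
    # Start the Flask demo; it serves templates/index.html and writes results to static/img/generated/.
    python web.py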