├── .gitignore
├── .idea
│   ├── codeStyles
│   │   └── codeStyleConfig.xml
│   └── vcs.xml
├── README.md
├── _config.yml
├── config
│   ├── __init__.py
│   └── global_config.py
├── data
│   ├── images
│   │   ├── avg_train_loss.png
│   │   ├── avg_train_top1_error.png
│   │   ├── avg_val_loss.png
│   │   ├── avg_val_top1_error.png
│   │   ├── confusion_matrix.png
│   │   ├── evaluation_nsfw.png
│   │   ├── online_demo.png
│   │   └── precision_recall.png
│   ├── nsfw_dataset_example
│   │   ├── drawing
│   │   │   ├── drawing_136.jpg
│   │   │   ├── drawing_147.jpg
│   │   │   ├── drawing_192.jpg
│   │   │   ├── drawing_207.jpg
│   │   │   ├── drawing_219.jpg
│   │   │   ├── drawing_227.jpg
│   │   │   ├── drawing_240.jpg
│   │   │   ├── drawing_253.jpg
│   │   │   ├── drawing_28.jpg
│   │   │   ├── drawing_346.jpg
│   │   │   ├── drawing_348.jpg
│   │   │   ├── drawing_361.jpg
│   │   │   ├── drawing_384.jpg
│   │   │   ├── drawing_4.jpg
│   │   │   ├── drawing_463.jpg
│   │   │   ├── drawing_483.jpg
│   │   │   ├── drawing_52.jpg
│   │   │   └── drawing_77.jpg
│   │   ├── hentai
│   │   │   ├── hentai_106.jpg
│   │   │   ├── hentai_112.jpg
│   │   │   ├── hentai_12.jpg
│   │   │   ├── hentai_194.jpg
│   │   │   ├── hentai_216.jpg
│   │   │   ├── hentai_291.jpg
│   │   │   ├── hentai_311.jpg
│   │   │   ├── hentai_318.jpg
│   │   │   ├── hentai_331.jpg
│   │   │   ├── hentai_377.jpg
│   │   │   ├── hentai_407.jpg
│   │   │   ├── hentai_424.jpg
│   │   │   ├── hentai_437.jpg
│   │   │   ├── hentai_458.jpg
│   │   │   ├── hentai_489.jpg
│   │   │   ├── hentai_499.jpg
│   │   │   ├── hentai_65.jpg
│   │   │   └── hentai_677.jpg
│   │   ├── neural
│   │   │   ├── neural_352.jpg
│   │   │   ├── neural_353.jpg
│   │   │   ├── neural_358.jpg
│   │   │   ├── neural_360.jpg
│   │   │   ├── neural_374.jpg
│   │   │   ├── neural_377.jpg
│   │   │   ├── neural_438.jpg
│   │   │   ├── neural_446.jpg
│   │   │   ├── neural_474.jpg
│   │   │   ├── neural_493.jpg
│   │   │   ├── neural_518.jpg
│   │   │   ├── neural_523.jpg
│   │   │   ├── neural_536.jpg
│   │   │   ├── neural_539.jpg
│   │   │   ├── neural_569.jpg
│   │   │   ├── neural_596.jpg
│   │   │   ├── neural_619.jpg
│   │   │   └── neural_683.jpg
│   │   ├── porn
│   │   │   ├── porn_1016.jpg
│   │   │   ├── porn_1198.jpg
│   │   │   ├── porn_1270.jpg
│   │   │   ├── porn_1364.jpg
│   │   │   ├── porn_1470.jpg
│   │   │   ├── porn_1488.jpg
│   │   │   ├── porn_1515.jpg
│   │   │   ├── porn_1551.jpg
│   │   │   ├── porn_1556.jpg
│   │   │   ├── porn_1570.jpg
│   │   │   ├── porn_1592.jpg
│   │   │   ├── porn_1677.jpg
│   │   │   ├── porn_1723.jpg
│   │   │   ├── porn_179.jpg
│   │   │   ├── porn_1794.jpg
│   │   │   ├── porn_1866.jpg
│   │   │   ├── porn_1899.jpg
│   │   │   ├── porn_254.jpg
│   │   │   ├── porn_364.jpg
│   │   │   ├── porn_400.jpg
│   │   │   ├── porn_679.jpg
│   │   │   ├── porn_749.jpg
│   │   │   ├── porn_755.jpg
│   │   │   └── porn_924.jpg
│   │   └── sexy
│   │       ├── sexy_167.jpg
│   │       ├── sexy_172.jpg
│   │       ├── sexy_180.jpg
│   │       ├── sexy_189.jpg
│   │       ├── sexy_227.jpg
│   │       ├── sexy_267.jpg
│   │       ├── sexy_296.jpg
│   │       ├── sexy_301.jpg
│   │       ├── sexy_312.jpg
│   │       ├── sexy_397.jpg
│   │       ├── sexy_474.jpg
│   │       ├── sexy_488.jpg
│   │       ├── sexy_508.jpg
│   │       ├── sexy_520.jpg
│   │       ├── sexy_524.jpg
│   │       ├── sexy_554.jpg
│   │       ├── sexy_556.jpg
│   │       ├── sexy_567.jpg
│   │       ├── sexy_599.jpg
│   │       ├── sexy_611.jpg
│   │       ├── sexy_629.jpg
│   │       └── sexy_712.jpg
│   └── test_data
│       ├── drawing_16715.jpg
│       ├── hentai_16814.jpg
│       ├── neural_2619.jpg
│       ├── porn_19625.jpg
│       └── sexy_6182.jpg
├── data_provider
│   ├── __init__.py
│   ├── nsfw_data_feed_pipline.py
│   └── tf_io_pipline_tools.py
├── docker_container
│   ├── __init__.py
│   └── python_client.py
├── model
│   └── nsfw_export_saved_model
│       └── 1
│           ├── saved_model.pb
│           └── variables
│               ├── variables.data-00000-of-00001
│               └── variables.index
├── nsfw_model
│   ├── __init__.py
│   ├── cnn_basenet.py
│   └── nsfw_classification_net.py
├── requirements.txt
├── tboard
│   └── nsfw_cls
│       └── events.out.tfevents.1551264389.baidu-pc
└── tools
    ├── __init__.py
    ├── convert_tfjs_model.sh
    ├── evaluate_nsfw.py
    ├── export_nsfw_saved_model.sh
    ├── export_saved_model.py
    ├── make_nsfw_dataset.sh
    ├── test_nsfw.py
    └── train_nsfw.py

/.gitignore:
--------------------------------------------------------------------------------
# Created by .ignore support plugin (hsz.mobi)
### Python template
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# pyenv
.python-version

# celery beat schedule file
celerybeat-schedule

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/

.idea/dictionaries/
.idea/inspectionProfiles/
.idea/misc.xml
.idea/modules.xml
.idea/nsfw_classification.iml
.idea/workspace.xml

model/
tboard/

tmp*.py
--------------------------------------------------------------------------------
/.idea/codeStyles/codeStyleConfig.xml:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/.idea/vcs.xml:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Nsfw-Classify-Tensorflow
NSFW classification model implemented in TensorFlow, trained on the nsfw dataset
provided at https://github.com/alexkimxyz/nsfw_data_scraper. Thanks for sharing
the dataset with us. You can find all the model details here. Do not hesitate to
raise an issue if you're confused by the model.

## Installation
This software has only been tested on Ubuntu 16.04 (x64).
Here is the test environment info:

**OS**: Ubuntu 16.04 LTS

**GPU**: Two GTX 1070TI

**CUDA**: cuda 9.0

**Tensorflow**: tensorflow 1.12.0

**OPENCV**: opencv 3.4.1

**NUMPY**: numpy 1.15.1

The other required packages can be installed with

```
pip3 install -r requirements.txt
```

## Test model
This repo ships with a model pretrained on the dataset introduced above. You need
to download the pretrained [weights_file](https://www.dropbox.com/sh/yfc3hy7enopdxj2/AADek2FlqCW1_j8Ax5x_VQy8a?dl=0)
into the folder REPO_ROOT_DIR/model/.

You can test a single image with the trained model as follows

```
cd REPO_ROOT_DIR
python tools/test_nsfw.py --weights_path model/new_model/nsfw_cls.ckpt-100000
--image_path data/test_data/drawing_16715.jpg
```

The following part shows how the dataset is prepared

## Train your own model

#### Data Preparation
First you need to download all the original nsfw data; see
[how_to_download_source_data](https://github.com/alexkimxyz/nsfw_data_scraper).
The training examples should be organized like what you can see in
REPO_ROOT_DIR/data/nsfw_dataset_example. Then modify
REPO_ROOT_DIR/tools/make_nsfw_dataset.sh to point at your local nsfw dataset and
execute the dataset preparation script. That takes about one hour on my local
machine. You may enlarge __C.TRAIN.CPU_MULTI_PROCESS_NUMS in config/global_config.py
if you have a powerful CPU, to accelerate the preparation process.

```
cd REPO_ROOT_DIR
bash tools/make_nsfw_dataset.sh
```

The images of each subclass are split into three parts according to the ratio
training : validation : test = 0.75 : 0.1 : 0.15. All images are converted into
tensorflow record files for an efficient data import pipeline.

#### Train Model
The model supports multi-GPU training. If you want to train the model on
multiple GPUs, first adjust __C.TRAIN.GPU_NUM in the config/global_config.py
file. Then execute the multi-GPU training procedure as follows:

```
cd REPO_ROOT_DIR
python tools/train_nsfw.py --dataset_dir PATH/TO/PREPARED_NSFW_DATASET --use_multi_gpu True
```

If you want to resume training from the last snapshot, execute the following command:

```
cd REPO_ROOT_DIR
python tools/train_nsfw.py --dataset_dir PATH/TO/PREPARED_NSFW_DATASET
--use_multi_gpu True --weights_path PATH/TO/YOUR/LAST_CKPT_FILE_PATH
```

You may set --use_multi_gpu False, in which case the whole training process runs
on a single GPU.

The model's main hyperparameters are as follows:

**iteration nums**: 160010

**learning rate**: 0.1

**batch size**: 32

**origin image size**: 256

**cropped image size**: 224

**training example nums**: 159477

**testing example nums**: 31895

**validation example nums**: 21266

The rest of the hyperparameters can be found [here](https://github.com/MaybeShewill-CV/nsfw-classification-tensorflow/blob/master/config/global_config.py).

If you want to convert the downloaded ckpt model into a tensorflow saved model,
you can simply modify the file paths in ROOT_DIR/tools/export_nsfw_saved_model.sh
and run it.
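Under the hood the shell script presumably wraps tools/export_saved_model.py. As
a rough illustration, a minimal checkpoint-to-SavedModel export in TensorFlow 1.x
looks like the sketch below; the paths and the bare meta-graph export are
illustrative assumptions, not the script's exact behavior.

```
import tensorflow as tf

# Illustrative paths: point these at your own checkpoint and export dir.
CKPT_PATH = 'model/new_model/nsfw_cls.ckpt-160000'
EXPORT_DIR = 'model/nsfw_export_saved_model/1'

with tf.Session(graph=tf.Graph()) as sess:
    # Rebuild the graph from the checkpoint's meta file and restore the weights.
    saver = tf.train.import_meta_graph(CKPT_PATH + '.meta')
    saver.restore(sess, CKPT_PATH)

    # Write the graph and variables out as a SavedModel under the SERVING tag.
    # NOTE: the actual export script presumably also registers input/output
    # serving signatures; they are omitted here for brevity.
    builder = tf.saved_model.builder.SavedModelBuilder(EXPORT_DIR)
    builder.add_meta_graph_and_variables(sess, [tf.saved_model.tag_constants.SERVING])
    builder.save()
```

With the file paths set, run the wrapper script: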
```
bash tools/export_nsfw_saved_model.sh
```

You may monitor the training process using the tensorboard tools.

During my experiment the `train loss` drops as follows:
![train_loss](./data/images/avg_train_loss.png)

The `train_top_1_error` drops as follows:
![train_top_1_error](./data/images/avg_train_top1_error.png)

The `validation loss` drops as follows:
![validation_loss](./data/images/avg_val_loss.png)

The `validation_top_1_error` drops as follows:
![validation_top_1_error](./data/images/avg_val_top1_error.png)

#### Model Evaluation

You can evaluate the model's performance on the nsfw dataset prepared in
advance as follows

```
cd REPO_ROOT_DIR
python tools/evaluate_nsfw.py --weights_path model/new_model/nsfw_cls.ckpt-160000
--dataset_dir PATH/TO/YOUR/NSFW_DATASET
```

After you run the script you should see something like this
![evaluation_result](./data/images/evaluation_nsfw.png)

The model's main evaluation metrics are as follows:

**Precision**: 0.92406, weighted average over the classes

**Recall**: 0.92364, weighted average over the classes

**F1 score**: 0.92344, weighted average over the classes

The `Confusion_Matrix` is as follows:
![confusion_matrix](./data/images/confusion_matrix.png)

The `Precision_Recall` curve is as follows:
![precision_recall](./data/images/precision_recall.png)


#### Online demo

##### URL: https://maybeshewill-cv.github.io/nsfw_classification

Since tensorflow.js is well supported, deep learning in the browser is easy to
deploy. Here I have made an online demo that does nsfw classification locally in
the browser. The whole js work can be found at
https://github.com/MaybeShewill-CV/MaybeShewill-CV.github.io/tree/master/nsfw_classification.
I have supplied a tool to convert the trained tensorflow saved model file into a
tensorflow.js model file. To generate the saved model, read the description
above.
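Before converting, it can help to sanity-check the exported SavedModel from plain
Python. Below is a minimal sketch; the `serving_default` signature and the
`input_tensor`/`prediction` keys are assumptions, so inspect the real names with
`saved_model_cli show --dir model/nsfw_export_saved_model/1 --all` first.

```
import cv2
import numpy as np
import tensorflow as tf

EXPORT_DIR = 'model/nsfw_export_saved_model/1'

with tf.Session(graph=tf.Graph()) as sess:
    # Load the SavedModel exported above under the SERVING tag.
    meta_graph = tf.saved_model.loader.load(
        sess, [tf.saved_model.tag_constants.SERVING], EXPORT_DIR)

    # Assumed signature and tensor keys; replace them with the names
    # reported by saved_model_cli for your export.
    signature = meta_graph.signature_def['serving_default']
    input_name = signature.inputs['input_tensor'].name
    output_name = signature.outputs['prediction'].name

    # Feed one resized test image and print the per-class scores.
    image = cv2.imread('data/test_data/drawing_16715.jpg')
    image = cv2.resize(image, (224, 224)).astype(np.float32)
    print(sess.run(output_name, feed_dict={input_name: [image]}))
```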
After you generate the tensorflow saved model, you can simply modify the file
path and run the following script

```
cd ROOT_DIR
bash tools/convert_tfjs_model.sh
```

An example of the online demo is shown below:
![online_demo](./data/images/online_demo.png)

## TODO
- [ ] Add tensorflow serving script

--------------------------------------------------------------------------------
/_config.yml:
--------------------------------------------------------------------------------
theme: jekyll-theme-cayman
--------------------------------------------------------------------------------
/config/__init__.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# @Time : 19-2-14 5:38 PM
# @Author : Luo Yao
# @Site : http://icode.baidu.com/repos/baidu/personal-code/Luoyao
# @File : __init__.py
# @IDE: PyCharm
--------------------------------------------------------------------------------
/config/global_config.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# @Time : 18-1-31 11:21 AM
# @Author : MaybeShewill-CV
# @Site : https://github.com/MaybeShewill-CV/CRNN_Tensorflow
# @File : global_config.py
# @IDE: PyCharm Community Edition
"""
Set global configuration variables
"""
import easydict

__C = easydict.EasyDict()
# Consumers can get config by: from config import cfg

cfg = __C

# Train options
__C.TRAIN = easydict.EasyDict()

# Set the nsfw net training epochs
__C.TRAIN.EPOCHS = 160010
# Set the display step
__C.TRAIN.DISPLAY_STEP = 1
# Set the test display step during training process
__C.TRAIN.VAL_DISPLAY_STEP = 1000
# Set the momentum parameter of the optimizer
__C.TRAIN.MOMENTUM = 0.9
# Set the initial learning rate
__C.TRAIN.LEARNING_RATE = 0.1
# Set the GPU resource used during training process
__C.TRAIN.GPU_MEMORY_FRACTION = 0.75
# Set the GPU allow growth parameter during tensorflow training process
__C.TRAIN.TF_ALLOW_GROWTH = False
# Set the nsfw net training batch size
__C.TRAIN.BATCH_SIZE = 32
# Set the nsfw net validation batch size
__C.TRAIN.VAL_BATCH_SIZE = 32
# Set the first learning rate decay step
__C.TRAIN.LR_DECAY_STEPS_1 = 60000
# Set the second learning rate decay step
__C.TRAIN.LR_DECAY_STEPS_2 = 120000
# Set the learning rate decay rate
__C.TRAIN.LR_DECAY_RATE = 0.1
# Set the weight decay
__C.TRAIN.WEIGHT_DECAY = 0.0001
# Set the train moving average decay
__C.TRAIN.MOVING_AVERAGE_DECAY = 0.9999
# Set the class numbers
__C.TRAIN.CLASSES_NUMS = 5
# Set the image height
__C.TRAIN.IMG_HEIGHT = 256
# Set the image width
__C.TRAIN.IMG_WIDTH = 256
# Set the cropped image height
__C.TRAIN.CROP_IMG_HEIGHT = 224
# Set the cropped image width
__C.TRAIN.CROP_IMG_WIDTH = 224
# Set the GPU nums
__C.TRAIN.GPU_NUM = 2
# Set cpu multi process thread nums
__C.TRAIN.CPU_MULTI_PROCESS_NUMS = 6

# Test options
__C.TEST = easydict.EasyDict()

# Set the GPU resource used during testing process
__C.TEST.GPU_MEMORY_FRACTION = 0.8
# Set the GPU allow growth parameter during tensorflow testing process
__C.TEST.TF_ALLOW_GROWTH = True
# Set the test batch size
__C.TEST.BATCH_SIZE = 64

__C.NET = easydict.EasyDict()
# Set net residual_blocks_nums
__C.NET.RESNET_SIZE = 50
# Set feats summary flag
__C.NET.NEED_SUMMARY_FEATS_MAP = False

# Set nsfw dataset label map
NSFW_LABEL_MAP = {'drawing': 0, 'hentai': 1, 'neural': 2, 'porn': 3, 'sexy': 4}
# Set nsfw dataset prediction map
NSFW_PREDICT_MAP = {0: 'drawing', 1: 'hentai', 2: 'neural', 3: 'porn', 4: 'sexy'}
--------------------------------------------------------------------------------
/data/images/avg_train_loss.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MaybeShewill-CV/nsfw-classification-tensorflow/6dfcb16fd655e66b9dd83237bbe89e84aa5322b9/data/images/avg_train_loss.png
--------------------------------------------------------------------------------
/data/images/avg_train_top1_error.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MaybeShewill-CV/nsfw-classification-tensorflow/6dfcb16fd655e66b9dd83237bbe89e84aa5322b9/data/images/avg_train_top1_error.png
--------------------------------------------------------------------------------
/data/images/avg_val_loss.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MaybeShewill-CV/nsfw-classification-tensorflow/6dfcb16fd655e66b9dd83237bbe89e84aa5322b9/data/images/avg_val_loss.png
--------------------------------------------------------------------------------
/data/images/avg_val_top1_error.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MaybeShewill-CV/nsfw-classification-tensorflow/6dfcb16fd655e66b9dd83237bbe89e84aa5322b9/data/images/avg_val_top1_error.png
--------------------------------------------------------------------------------
/data/images/confusion_matrix.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MaybeShewill-CV/nsfw-classification-tensorflow/6dfcb16fd655e66b9dd83237bbe89e84aa5322b9/data/images/confusion_matrix.png
--------------------------------------------------------------------------------
/data/images/evaluation_nsfw.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MaybeShewill-CV/nsfw-classification-tensorflow/6dfcb16fd655e66b9dd83237bbe89e84aa5322b9/data/images/evaluation_nsfw.png
--------------------------------------------------------------------------------
/data/images/online_demo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MaybeShewill-CV/nsfw-classification-tensorflow/6dfcb16fd655e66b9dd83237bbe89e84aa5322b9/data/images/online_demo.png
--------------------------------------------------------------------------------
/data/images/precision_recall.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MaybeShewill-CV/nsfw-classification-tensorflow/6dfcb16fd655e66b9dd83237bbe89e84aa5322b9/data/images/precision_recall.png
--------------------------------------------------------------------------------
/data/nsfw_dataset_example/drawing/drawing_136.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MaybeShewill-CV/nsfw-classification-tensorflow/6dfcb16fd655e66b9dd83237bbe89e84aa5322b9/data/nsfw_dataset_example/drawing/drawing_136.jpg -------------------------------------------------------------------------------- /data/nsfw_dataset_example/drawing/drawing_147.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MaybeShewill-CV/nsfw-classification-tensorflow/6dfcb16fd655e66b9dd83237bbe89e84aa5322b9/data/nsfw_dataset_example/drawing/drawing_147.jpg -------------------------------------------------------------------------------- /data/nsfw_dataset_example/drawing/drawing_192.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MaybeShewill-CV/nsfw-classification-tensorflow/6dfcb16fd655e66b9dd83237bbe89e84aa5322b9/data/nsfw_dataset_example/drawing/drawing_192.jpg -------------------------------------------------------------------------------- /data/nsfw_dataset_example/drawing/drawing_207.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MaybeShewill-CV/nsfw-classification-tensorflow/6dfcb16fd655e66b9dd83237bbe89e84aa5322b9/data/nsfw_dataset_example/drawing/drawing_207.jpg -------------------------------------------------------------------------------- /data/nsfw_dataset_example/drawing/drawing_219.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MaybeShewill-CV/nsfw-classification-tensorflow/6dfcb16fd655e66b9dd83237bbe89e84aa5322b9/data/nsfw_dataset_example/drawing/drawing_219.jpg -------------------------------------------------------------------------------- /data/nsfw_dataset_example/drawing/drawing_227.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MaybeShewill-CV/nsfw-classification-tensorflow/6dfcb16fd655e66b9dd83237bbe89e84aa5322b9/data/nsfw_dataset_example/drawing/drawing_227.jpg -------------------------------------------------------------------------------- /data/nsfw_dataset_example/drawing/drawing_240.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MaybeShewill-CV/nsfw-classification-tensorflow/6dfcb16fd655e66b9dd83237bbe89e84aa5322b9/data/nsfw_dataset_example/drawing/drawing_240.jpg -------------------------------------------------------------------------------- /data/nsfw_dataset_example/drawing/drawing_253.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MaybeShewill-CV/nsfw-classification-tensorflow/6dfcb16fd655e66b9dd83237bbe89e84aa5322b9/data/nsfw_dataset_example/drawing/drawing_253.jpg -------------------------------------------------------------------------------- /data/nsfw_dataset_example/drawing/drawing_28.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MaybeShewill-CV/nsfw-classification-tensorflow/6dfcb16fd655e66b9dd83237bbe89e84aa5322b9/data/nsfw_dataset_example/drawing/drawing_28.jpg -------------------------------------------------------------------------------- /data/nsfw_dataset_example/drawing/drawing_346.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/MaybeShewill-CV/nsfw-classification-tensorflow/6dfcb16fd655e66b9dd83237bbe89e84aa5322b9/data/nsfw_dataset_example/drawing/drawing_346.jpg -------------------------------------------------------------------------------- /data/nsfw_dataset_example/drawing/drawing_348.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MaybeShewill-CV/nsfw-classification-tensorflow/6dfcb16fd655e66b9dd83237bbe89e84aa5322b9/data/nsfw_dataset_example/drawing/drawing_348.jpg -------------------------------------------------------------------------------- /data/nsfw_dataset_example/drawing/drawing_361.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MaybeShewill-CV/nsfw-classification-tensorflow/6dfcb16fd655e66b9dd83237bbe89e84aa5322b9/data/nsfw_dataset_example/drawing/drawing_361.jpg -------------------------------------------------------------------------------- /data/nsfw_dataset_example/drawing/drawing_384.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MaybeShewill-CV/nsfw-classification-tensorflow/6dfcb16fd655e66b9dd83237bbe89e84aa5322b9/data/nsfw_dataset_example/drawing/drawing_384.jpg -------------------------------------------------------------------------------- /data/nsfw_dataset_example/drawing/drawing_4.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MaybeShewill-CV/nsfw-classification-tensorflow/6dfcb16fd655e66b9dd83237bbe89e84aa5322b9/data/nsfw_dataset_example/drawing/drawing_4.jpg -------------------------------------------------------------------------------- /data/nsfw_dataset_example/drawing/drawing_463.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MaybeShewill-CV/nsfw-classification-tensorflow/6dfcb16fd655e66b9dd83237bbe89e84aa5322b9/data/nsfw_dataset_example/drawing/drawing_463.jpg -------------------------------------------------------------------------------- /data/nsfw_dataset_example/drawing/drawing_483.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MaybeShewill-CV/nsfw-classification-tensorflow/6dfcb16fd655e66b9dd83237bbe89e84aa5322b9/data/nsfw_dataset_example/drawing/drawing_483.jpg -------------------------------------------------------------------------------- /data/nsfw_dataset_example/drawing/drawing_52.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MaybeShewill-CV/nsfw-classification-tensorflow/6dfcb16fd655e66b9dd83237bbe89e84aa5322b9/data/nsfw_dataset_example/drawing/drawing_52.jpg -------------------------------------------------------------------------------- /data/nsfw_dataset_example/drawing/drawing_77.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MaybeShewill-CV/nsfw-classification-tensorflow/6dfcb16fd655e66b9dd83237bbe89e84aa5322b9/data/nsfw_dataset_example/drawing/drawing_77.jpg -------------------------------------------------------------------------------- /data/nsfw_dataset_example/hentai/hentai_106.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/MaybeShewill-CV/nsfw-classification-tensorflow/6dfcb16fd655e66b9dd83237bbe89e84aa5322b9/data/nsfw_dataset_example/hentai/hentai_106.jpg -------------------------------------------------------------------------------- /data/nsfw_dataset_example/hentai/hentai_112.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MaybeShewill-CV/nsfw-classification-tensorflow/6dfcb16fd655e66b9dd83237bbe89e84aa5322b9/data/nsfw_dataset_example/hentai/hentai_112.jpg -------------------------------------------------------------------------------- /data/nsfw_dataset_example/hentai/hentai_12.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MaybeShewill-CV/nsfw-classification-tensorflow/6dfcb16fd655e66b9dd83237bbe89e84aa5322b9/data/nsfw_dataset_example/hentai/hentai_12.jpg -------------------------------------------------------------------------------- /data/nsfw_dataset_example/hentai/hentai_194.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MaybeShewill-CV/nsfw-classification-tensorflow/6dfcb16fd655e66b9dd83237bbe89e84aa5322b9/data/nsfw_dataset_example/hentai/hentai_194.jpg -------------------------------------------------------------------------------- /data/nsfw_dataset_example/hentai/hentai_216.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MaybeShewill-CV/nsfw-classification-tensorflow/6dfcb16fd655e66b9dd83237bbe89e84aa5322b9/data/nsfw_dataset_example/hentai/hentai_216.jpg -------------------------------------------------------------------------------- /data/nsfw_dataset_example/hentai/hentai_291.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MaybeShewill-CV/nsfw-classification-tensorflow/6dfcb16fd655e66b9dd83237bbe89e84aa5322b9/data/nsfw_dataset_example/hentai/hentai_291.jpg -------------------------------------------------------------------------------- /data/nsfw_dataset_example/hentai/hentai_311.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MaybeShewill-CV/nsfw-classification-tensorflow/6dfcb16fd655e66b9dd83237bbe89e84aa5322b9/data/nsfw_dataset_example/hentai/hentai_311.jpg -------------------------------------------------------------------------------- /data/nsfw_dataset_example/hentai/hentai_318.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MaybeShewill-CV/nsfw-classification-tensorflow/6dfcb16fd655e66b9dd83237bbe89e84aa5322b9/data/nsfw_dataset_example/hentai/hentai_318.jpg -------------------------------------------------------------------------------- /data/nsfw_dataset_example/hentai/hentai_331.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MaybeShewill-CV/nsfw-classification-tensorflow/6dfcb16fd655e66b9dd83237bbe89e84aa5322b9/data/nsfw_dataset_example/hentai/hentai_331.jpg -------------------------------------------------------------------------------- /data/nsfw_dataset_example/hentai/hentai_377.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/MaybeShewill-CV/nsfw-classification-tensorflow/6dfcb16fd655e66b9dd83237bbe89e84aa5322b9/data/nsfw_dataset_example/hentai/hentai_377.jpg -------------------------------------------------------------------------------- /data/nsfw_dataset_example/hentai/hentai_407.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MaybeShewill-CV/nsfw-classification-tensorflow/6dfcb16fd655e66b9dd83237bbe89e84aa5322b9/data/nsfw_dataset_example/hentai/hentai_407.jpg -------------------------------------------------------------------------------- /data/nsfw_dataset_example/hentai/hentai_424.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MaybeShewill-CV/nsfw-classification-tensorflow/6dfcb16fd655e66b9dd83237bbe89e84aa5322b9/data/nsfw_dataset_example/hentai/hentai_424.jpg -------------------------------------------------------------------------------- /data/nsfw_dataset_example/hentai/hentai_437.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MaybeShewill-CV/nsfw-classification-tensorflow/6dfcb16fd655e66b9dd83237bbe89e84aa5322b9/data/nsfw_dataset_example/hentai/hentai_437.jpg -------------------------------------------------------------------------------- /data/nsfw_dataset_example/hentai/hentai_458.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MaybeShewill-CV/nsfw-classification-tensorflow/6dfcb16fd655e66b9dd83237bbe89e84aa5322b9/data/nsfw_dataset_example/hentai/hentai_458.jpg -------------------------------------------------------------------------------- /data/nsfw_dataset_example/hentai/hentai_489.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MaybeShewill-CV/nsfw-classification-tensorflow/6dfcb16fd655e66b9dd83237bbe89e84aa5322b9/data/nsfw_dataset_example/hentai/hentai_489.jpg -------------------------------------------------------------------------------- /data/nsfw_dataset_example/hentai/hentai_499.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MaybeShewill-CV/nsfw-classification-tensorflow/6dfcb16fd655e66b9dd83237bbe89e84aa5322b9/data/nsfw_dataset_example/hentai/hentai_499.jpg -------------------------------------------------------------------------------- /data/nsfw_dataset_example/hentai/hentai_65.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MaybeShewill-CV/nsfw-classification-tensorflow/6dfcb16fd655e66b9dd83237bbe89e84aa5322b9/data/nsfw_dataset_example/hentai/hentai_65.jpg -------------------------------------------------------------------------------- /data/nsfw_dataset_example/hentai/hentai_677.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MaybeShewill-CV/nsfw-classification-tensorflow/6dfcb16fd655e66b9dd83237bbe89e84aa5322b9/data/nsfw_dataset_example/hentai/hentai_677.jpg -------------------------------------------------------------------------------- /data/nsfw_dataset_example/neural/neural_352.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/MaybeShewill-CV/nsfw-classification-tensorflow/6dfcb16fd655e66b9dd83237bbe89e84aa5322b9/data/nsfw_dataset_example/neural/neural_352.jpg -------------------------------------------------------------------------------- /data/nsfw_dataset_example/neural/neural_353.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MaybeShewill-CV/nsfw-classification-tensorflow/6dfcb16fd655e66b9dd83237bbe89e84aa5322b9/data/nsfw_dataset_example/neural/neural_353.jpg -------------------------------------------------------------------------------- /data/nsfw_dataset_example/neural/neural_358.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MaybeShewill-CV/nsfw-classification-tensorflow/6dfcb16fd655e66b9dd83237bbe89e84aa5322b9/data/nsfw_dataset_example/neural/neural_358.jpg -------------------------------------------------------------------------------- /data/nsfw_dataset_example/neural/neural_360.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MaybeShewill-CV/nsfw-classification-tensorflow/6dfcb16fd655e66b9dd83237bbe89e84aa5322b9/data/nsfw_dataset_example/neural/neural_360.jpg -------------------------------------------------------------------------------- /data/nsfw_dataset_example/neural/neural_374.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MaybeShewill-CV/nsfw-classification-tensorflow/6dfcb16fd655e66b9dd83237bbe89e84aa5322b9/data/nsfw_dataset_example/neural/neural_374.jpg -------------------------------------------------------------------------------- /data/nsfw_dataset_example/neural/neural_377.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MaybeShewill-CV/nsfw-classification-tensorflow/6dfcb16fd655e66b9dd83237bbe89e84aa5322b9/data/nsfw_dataset_example/neural/neural_377.jpg -------------------------------------------------------------------------------- /data/nsfw_dataset_example/neural/neural_438.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MaybeShewill-CV/nsfw-classification-tensorflow/6dfcb16fd655e66b9dd83237bbe89e84aa5322b9/data/nsfw_dataset_example/neural/neural_438.jpg -------------------------------------------------------------------------------- /data/nsfw_dataset_example/neural/neural_446.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MaybeShewill-CV/nsfw-classification-tensorflow/6dfcb16fd655e66b9dd83237bbe89e84aa5322b9/data/nsfw_dataset_example/neural/neural_446.jpg -------------------------------------------------------------------------------- /data/nsfw_dataset_example/neural/neural_474.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MaybeShewill-CV/nsfw-classification-tensorflow/6dfcb16fd655e66b9dd83237bbe89e84aa5322b9/data/nsfw_dataset_example/neural/neural_474.jpg -------------------------------------------------------------------------------- /data/nsfw_dataset_example/neural/neural_493.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/MaybeShewill-CV/nsfw-classification-tensorflow/6dfcb16fd655e66b9dd83237bbe89e84aa5322b9/data/nsfw_dataset_example/neural/neural_493.jpg -------------------------------------------------------------------------------- /data/nsfw_dataset_example/neural/neural_518.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MaybeShewill-CV/nsfw-classification-tensorflow/6dfcb16fd655e66b9dd83237bbe89e84aa5322b9/data/nsfw_dataset_example/neural/neural_518.jpg -------------------------------------------------------------------------------- /data/nsfw_dataset_example/neural/neural_523.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MaybeShewill-CV/nsfw-classification-tensorflow/6dfcb16fd655e66b9dd83237bbe89e84aa5322b9/data/nsfw_dataset_example/neural/neural_523.jpg -------------------------------------------------------------------------------- /data/nsfw_dataset_example/neural/neural_536.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MaybeShewill-CV/nsfw-classification-tensorflow/6dfcb16fd655e66b9dd83237bbe89e84aa5322b9/data/nsfw_dataset_example/neural/neural_536.jpg -------------------------------------------------------------------------------- /data/nsfw_dataset_example/neural/neural_539.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MaybeShewill-CV/nsfw-classification-tensorflow/6dfcb16fd655e66b9dd83237bbe89e84aa5322b9/data/nsfw_dataset_example/neural/neural_539.jpg -------------------------------------------------------------------------------- /data/nsfw_dataset_example/neural/neural_569.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MaybeShewill-CV/nsfw-classification-tensorflow/6dfcb16fd655e66b9dd83237bbe89e84aa5322b9/data/nsfw_dataset_example/neural/neural_569.jpg -------------------------------------------------------------------------------- /data/nsfw_dataset_example/neural/neural_596.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MaybeShewill-CV/nsfw-classification-tensorflow/6dfcb16fd655e66b9dd83237bbe89e84aa5322b9/data/nsfw_dataset_example/neural/neural_596.jpg -------------------------------------------------------------------------------- /data/nsfw_dataset_example/neural/neural_619.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MaybeShewill-CV/nsfw-classification-tensorflow/6dfcb16fd655e66b9dd83237bbe89e84aa5322b9/data/nsfw_dataset_example/neural/neural_619.jpg -------------------------------------------------------------------------------- /data/nsfw_dataset_example/neural/neural_683.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MaybeShewill-CV/nsfw-classification-tensorflow/6dfcb16fd655e66b9dd83237bbe89e84aa5322b9/data/nsfw_dataset_example/neural/neural_683.jpg -------------------------------------------------------------------------------- /data/nsfw_dataset_example/porn/porn_1016.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/MaybeShewill-CV/nsfw-classification-tensorflow/6dfcb16fd655e66b9dd83237bbe89e84aa5322b9/data/nsfw_dataset_example/porn/porn_1016.jpg -------------------------------------------------------------------------------- /data/nsfw_dataset_example/porn/porn_1198.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MaybeShewill-CV/nsfw-classification-tensorflow/6dfcb16fd655e66b9dd83237bbe89e84aa5322b9/data/nsfw_dataset_example/porn/porn_1198.jpg -------------------------------------------------------------------------------- /data/nsfw_dataset_example/porn/porn_1270.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MaybeShewill-CV/nsfw-classification-tensorflow/6dfcb16fd655e66b9dd83237bbe89e84aa5322b9/data/nsfw_dataset_example/porn/porn_1270.jpg -------------------------------------------------------------------------------- /data/nsfw_dataset_example/porn/porn_1364.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MaybeShewill-CV/nsfw-classification-tensorflow/6dfcb16fd655e66b9dd83237bbe89e84aa5322b9/data/nsfw_dataset_example/porn/porn_1364.jpg -------------------------------------------------------------------------------- /data/nsfw_dataset_example/porn/porn_1470.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MaybeShewill-CV/nsfw-classification-tensorflow/6dfcb16fd655e66b9dd83237bbe89e84aa5322b9/data/nsfw_dataset_example/porn/porn_1470.jpg -------------------------------------------------------------------------------- /data/nsfw_dataset_example/porn/porn_1488.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MaybeShewill-CV/nsfw-classification-tensorflow/6dfcb16fd655e66b9dd83237bbe89e84aa5322b9/data/nsfw_dataset_example/porn/porn_1488.jpg -------------------------------------------------------------------------------- /data/nsfw_dataset_example/porn/porn_1515.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MaybeShewill-CV/nsfw-classification-tensorflow/6dfcb16fd655e66b9dd83237bbe89e84aa5322b9/data/nsfw_dataset_example/porn/porn_1515.jpg -------------------------------------------------------------------------------- /data/nsfw_dataset_example/porn/porn_1551.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MaybeShewill-CV/nsfw-classification-tensorflow/6dfcb16fd655e66b9dd83237bbe89e84aa5322b9/data/nsfw_dataset_example/porn/porn_1551.jpg -------------------------------------------------------------------------------- /data/nsfw_dataset_example/porn/porn_1556.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MaybeShewill-CV/nsfw-classification-tensorflow/6dfcb16fd655e66b9dd83237bbe89e84aa5322b9/data/nsfw_dataset_example/porn/porn_1556.jpg -------------------------------------------------------------------------------- /data/nsfw_dataset_example/porn/porn_1570.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MaybeShewill-CV/nsfw-classification-tensorflow/6dfcb16fd655e66b9dd83237bbe89e84aa5322b9/data/nsfw_dataset_example/porn/porn_1570.jpg 
-------------------------------------------------------------------------------- /data/nsfw_dataset_example/porn/porn_1592.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MaybeShewill-CV/nsfw-classification-tensorflow/6dfcb16fd655e66b9dd83237bbe89e84aa5322b9/data/nsfw_dataset_example/porn/porn_1592.jpg -------------------------------------------------------------------------------- /data/nsfw_dataset_example/porn/porn_1677.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MaybeShewill-CV/nsfw-classification-tensorflow/6dfcb16fd655e66b9dd83237bbe89e84aa5322b9/data/nsfw_dataset_example/porn/porn_1677.jpg -------------------------------------------------------------------------------- /data/nsfw_dataset_example/porn/porn_1723.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MaybeShewill-CV/nsfw-classification-tensorflow/6dfcb16fd655e66b9dd83237bbe89e84aa5322b9/data/nsfw_dataset_example/porn/porn_1723.jpg -------------------------------------------------------------------------------- /data/nsfw_dataset_example/porn/porn_179.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MaybeShewill-CV/nsfw-classification-tensorflow/6dfcb16fd655e66b9dd83237bbe89e84aa5322b9/data/nsfw_dataset_example/porn/porn_179.jpg -------------------------------------------------------------------------------- /data/nsfw_dataset_example/porn/porn_1794.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MaybeShewill-CV/nsfw-classification-tensorflow/6dfcb16fd655e66b9dd83237bbe89e84aa5322b9/data/nsfw_dataset_example/porn/porn_1794.jpg -------------------------------------------------------------------------------- /data/nsfw_dataset_example/porn/porn_1866.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MaybeShewill-CV/nsfw-classification-tensorflow/6dfcb16fd655e66b9dd83237bbe89e84aa5322b9/data/nsfw_dataset_example/porn/porn_1866.jpg -------------------------------------------------------------------------------- /data/nsfw_dataset_example/porn/porn_1899.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MaybeShewill-CV/nsfw-classification-tensorflow/6dfcb16fd655e66b9dd83237bbe89e84aa5322b9/data/nsfw_dataset_example/porn/porn_1899.jpg -------------------------------------------------------------------------------- /data/nsfw_dataset_example/porn/porn_254.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MaybeShewill-CV/nsfw-classification-tensorflow/6dfcb16fd655e66b9dd83237bbe89e84aa5322b9/data/nsfw_dataset_example/porn/porn_254.jpg -------------------------------------------------------------------------------- /data/nsfw_dataset_example/porn/porn_364.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MaybeShewill-CV/nsfw-classification-tensorflow/6dfcb16fd655e66b9dd83237bbe89e84aa5322b9/data/nsfw_dataset_example/porn/porn_364.jpg -------------------------------------------------------------------------------- /data/nsfw_dataset_example/porn/porn_400.jpg: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/MaybeShewill-CV/nsfw-classification-tensorflow/6dfcb16fd655e66b9dd83237bbe89e84aa5322b9/data/nsfw_dataset_example/porn/porn_400.jpg -------------------------------------------------------------------------------- /data/nsfw_dataset_example/porn/porn_679.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MaybeShewill-CV/nsfw-classification-tensorflow/6dfcb16fd655e66b9dd83237bbe89e84aa5322b9/data/nsfw_dataset_example/porn/porn_679.jpg -------------------------------------------------------------------------------- /data/nsfw_dataset_example/porn/porn_749.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MaybeShewill-CV/nsfw-classification-tensorflow/6dfcb16fd655e66b9dd83237bbe89e84aa5322b9/data/nsfw_dataset_example/porn/porn_749.jpg -------------------------------------------------------------------------------- /data/nsfw_dataset_example/porn/porn_755.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MaybeShewill-CV/nsfw-classification-tensorflow/6dfcb16fd655e66b9dd83237bbe89e84aa5322b9/data/nsfw_dataset_example/porn/porn_755.jpg -------------------------------------------------------------------------------- /data/nsfw_dataset_example/porn/porn_924.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MaybeShewill-CV/nsfw-classification-tensorflow/6dfcb16fd655e66b9dd83237bbe89e84aa5322b9/data/nsfw_dataset_example/porn/porn_924.jpg -------------------------------------------------------------------------------- /data/nsfw_dataset_example/sexy/sexy_167.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MaybeShewill-CV/nsfw-classification-tensorflow/6dfcb16fd655e66b9dd83237bbe89e84aa5322b9/data/nsfw_dataset_example/sexy/sexy_167.jpg -------------------------------------------------------------------------------- /data/nsfw_dataset_example/sexy/sexy_172.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MaybeShewill-CV/nsfw-classification-tensorflow/6dfcb16fd655e66b9dd83237bbe89e84aa5322b9/data/nsfw_dataset_example/sexy/sexy_172.jpg -------------------------------------------------------------------------------- /data/nsfw_dataset_example/sexy/sexy_180.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MaybeShewill-CV/nsfw-classification-tensorflow/6dfcb16fd655e66b9dd83237bbe89e84aa5322b9/data/nsfw_dataset_example/sexy/sexy_180.jpg -------------------------------------------------------------------------------- /data/nsfw_dataset_example/sexy/sexy_189.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MaybeShewill-CV/nsfw-classification-tensorflow/6dfcb16fd655e66b9dd83237bbe89e84aa5322b9/data/nsfw_dataset_example/sexy/sexy_189.jpg -------------------------------------------------------------------------------- /data/nsfw_dataset_example/sexy/sexy_227.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/MaybeShewill-CV/nsfw-classification-tensorflow/6dfcb16fd655e66b9dd83237bbe89e84aa5322b9/data/nsfw_dataset_example/sexy/sexy_227.jpg -------------------------------------------------------------------------------- /data/nsfw_dataset_example/sexy/sexy_267.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MaybeShewill-CV/nsfw-classification-tensorflow/6dfcb16fd655e66b9dd83237bbe89e84aa5322b9/data/nsfw_dataset_example/sexy/sexy_267.jpg -------------------------------------------------------------------------------- /data/nsfw_dataset_example/sexy/sexy_296.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MaybeShewill-CV/nsfw-classification-tensorflow/6dfcb16fd655e66b9dd83237bbe89e84aa5322b9/data/nsfw_dataset_example/sexy/sexy_296.jpg -------------------------------------------------------------------------------- /data/nsfw_dataset_example/sexy/sexy_301.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MaybeShewill-CV/nsfw-classification-tensorflow/6dfcb16fd655e66b9dd83237bbe89e84aa5322b9/data/nsfw_dataset_example/sexy/sexy_301.jpg -------------------------------------------------------------------------------- /data/nsfw_dataset_example/sexy/sexy_312.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MaybeShewill-CV/nsfw-classification-tensorflow/6dfcb16fd655e66b9dd83237bbe89e84aa5322b9/data/nsfw_dataset_example/sexy/sexy_312.jpg -------------------------------------------------------------------------------- /data/nsfw_dataset_example/sexy/sexy_397.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MaybeShewill-CV/nsfw-classification-tensorflow/6dfcb16fd655e66b9dd83237bbe89e84aa5322b9/data/nsfw_dataset_example/sexy/sexy_397.jpg -------------------------------------------------------------------------------- /data/nsfw_dataset_example/sexy/sexy_474.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MaybeShewill-CV/nsfw-classification-tensorflow/6dfcb16fd655e66b9dd83237bbe89e84aa5322b9/data/nsfw_dataset_example/sexy/sexy_474.jpg -------------------------------------------------------------------------------- /data/nsfw_dataset_example/sexy/sexy_488.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MaybeShewill-CV/nsfw-classification-tensorflow/6dfcb16fd655e66b9dd83237bbe89e84aa5322b9/data/nsfw_dataset_example/sexy/sexy_488.jpg -------------------------------------------------------------------------------- /data/nsfw_dataset_example/sexy/sexy_508.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MaybeShewill-CV/nsfw-classification-tensorflow/6dfcb16fd655e66b9dd83237bbe89e84aa5322b9/data/nsfw_dataset_example/sexy/sexy_508.jpg -------------------------------------------------------------------------------- /data/nsfw_dataset_example/sexy/sexy_520.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MaybeShewill-CV/nsfw-classification-tensorflow/6dfcb16fd655e66b9dd83237bbe89e84aa5322b9/data/nsfw_dataset_example/sexy/sexy_520.jpg 
-------------------------------------------------------------------------------- /data/nsfw_dataset_example/sexy/sexy_524.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MaybeShewill-CV/nsfw-classification-tensorflow/6dfcb16fd655e66b9dd83237bbe89e84aa5322b9/data/nsfw_dataset_example/sexy/sexy_524.jpg -------------------------------------------------------------------------------- /data/nsfw_dataset_example/sexy/sexy_554.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MaybeShewill-CV/nsfw-classification-tensorflow/6dfcb16fd655e66b9dd83237bbe89e84aa5322b9/data/nsfw_dataset_example/sexy/sexy_554.jpg -------------------------------------------------------------------------------- /data/nsfw_dataset_example/sexy/sexy_556.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MaybeShewill-CV/nsfw-classification-tensorflow/6dfcb16fd655e66b9dd83237bbe89e84aa5322b9/data/nsfw_dataset_example/sexy/sexy_556.jpg -------------------------------------------------------------------------------- /data/nsfw_dataset_example/sexy/sexy_567.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MaybeShewill-CV/nsfw-classification-tensorflow/6dfcb16fd655e66b9dd83237bbe89e84aa5322b9/data/nsfw_dataset_example/sexy/sexy_567.jpg -------------------------------------------------------------------------------- /data/nsfw_dataset_example/sexy/sexy_599.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MaybeShewill-CV/nsfw-classification-tensorflow/6dfcb16fd655e66b9dd83237bbe89e84aa5322b9/data/nsfw_dataset_example/sexy/sexy_599.jpg -------------------------------------------------------------------------------- /data/nsfw_dataset_example/sexy/sexy_611.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MaybeShewill-CV/nsfw-classification-tensorflow/6dfcb16fd655e66b9dd83237bbe89e84aa5322b9/data/nsfw_dataset_example/sexy/sexy_611.jpg -------------------------------------------------------------------------------- /data/nsfw_dataset_example/sexy/sexy_629.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MaybeShewill-CV/nsfw-classification-tensorflow/6dfcb16fd655e66b9dd83237bbe89e84aa5322b9/data/nsfw_dataset_example/sexy/sexy_629.jpg -------------------------------------------------------------------------------- /data/nsfw_dataset_example/sexy/sexy_712.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MaybeShewill-CV/nsfw-classification-tensorflow/6dfcb16fd655e66b9dd83237bbe89e84aa5322b9/data/nsfw_dataset_example/sexy/sexy_712.jpg -------------------------------------------------------------------------------- /data/test_data/drawing_16715.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MaybeShewill-CV/nsfw-classification-tensorflow/6dfcb16fd655e66b9dd83237bbe89e84aa5322b9/data/test_data/drawing_16715.jpg -------------------------------------------------------------------------------- /data/test_data/hentai_16814.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/MaybeShewill-CV/nsfw-classification-tensorflow/6dfcb16fd655e66b9dd83237bbe89e84aa5322b9/data/test_data/hentai_16814.jpg -------------------------------------------------------------------------------- /data/test_data/neural_2619.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MaybeShewill-CV/nsfw-classification-tensorflow/6dfcb16fd655e66b9dd83237bbe89e84aa5322b9/data/test_data/neural_2619.jpg -------------------------------------------------------------------------------- /data/test_data/porn_19625.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MaybeShewill-CV/nsfw-classification-tensorflow/6dfcb16fd655e66b9dd83237bbe89e84aa5322b9/data/test_data/porn_19625.jpg -------------------------------------------------------------------------------- /data/test_data/sexy_6182.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MaybeShewill-CV/nsfw-classification-tensorflow/6dfcb16fd655e66b9dd83237bbe89e84aa5322b9/data/test_data/sexy_6182.jpg -------------------------------------------------------------------------------- /data_provider/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | # @Time : 19-2-14 5:42 PM 4 | # @Author : Luo Yao 5 | # @Site : http://icode.baidu.com/repos/baidu/personal-code/Luoyao 6 | # @File : __init__.py 7 | # @IDE: PyCharm -------------------------------------------------------------------------------- /data_provider/nsfw_data_feed_pipline.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | # @Time : 19-2-14 5:43 PM 4 | # @Author : MaybeShewill-CV 5 | # @Site : https://github.com/MaybeShewill-CV/CRNN_Tensorflow 6 | # @File : nsfw_data_feed_pipline.py 7 | # @IDE: PyCharm 8 | """ 9 | NSFW data feed pipeline 10 | """ 11 | import argparse 12 | import os 13 | import os.path as ops 14 | import random 15 | import multiprocessing 16 | 17 | import glob 18 | import glog as log 19 | import tensorflow as tf 20 | import pprint 21 | 22 | from config import global_config 23 | from data_provider import tf_io_pipline_tools 24 | 25 | CFG = global_config.cfg 26 | 27 | 28 | def init_args(): 29 | """ 30 | 31 | :return: 32 | """ 33 | parser = argparse.ArgumentParser() 34 | parser.add_argument('--dataset_dir', type=str, help='The source nsfw data dir path') 35 | parser.add_argument('--tfrecords_dir', type=str, help='The dir path to save converted tfrecords') 36 | 37 | return parser.parse_args() 38 | 39 | 40 | class NsfwDataProducer(object): 41 | """ 42 | Convert raw image files into tfrecords 43 | """ 44 | def __init__(self, dataset_dir): 45 | """ 46 | 47 | :param dataset_dir: 48 | """ 49 | self._label_map = global_config.NSFW_LABEL_MAP 50 | 51 | self._dataset_dir = dataset_dir 52 | 53 | self._drawing_image_dir = ops.join(dataset_dir, 'drawing') 54 | self._hentai_image_dir = ops.join(dataset_dir, 'hentai') 55 | self._neural_image_dir = ops.join(dataset_dir, 'neural') 56 | self._porn_image_dir = ops.join(dataset_dir, 'porn') 57 | self._sexy_image_dir = ops.join(dataset_dir, 'sexy') 58 | 59 | self._train_example_index_file_path = ops.join(self._dataset_dir, 'train.txt') 60 | self._test_example_index_file_path = ops.join(self._dataset_dir, 'test.txt') 
self._val_example_index_file_path = ops.join(self._dataset_dir, 'val.txt') 61 | 62 | if not self._is_source_data_complete(): 63 | raise ValueError('Source image data is not complete, ' 64 | 'please check whether any of the image folders is missing') 65 | 66 | if not self._is_training_example_index_file_complete(): 67 | self._generate_training_example_index_file() 68 | 69 | def print_label_map(self): 70 | """ 71 | Print the label map which shows the label information 72 | :return: 73 | """ 74 | pprint.pprint(self._label_map) 75 | 76 | def generate_tfrecords(self, save_dir, step_size=10000): 77 | """ 78 | Generate tensorflow records file 79 | :param save_dir: 80 | :param step_size: generate a tfrecord every step_size examples 81 | :return: 82 | """ 83 | def _read_training_example_index_file(_index_file_path): 84 | 85 | assert ops.exists(_index_file_path) 86 | 87 | _example_path_info = [] 88 | _example_label_info = [] 89 | 90 | with open(_index_file_path, 'r') as _file: 91 | for _line in _file: 92 | _example_info = _line.rstrip('\r\n').split(' ') 93 | _example_path_info.append(_example_info[0]) 94 | _example_label_info.append(int(_example_info[1])) 95 | 96 | return _example_path_info, _example_label_info 97 | 98 | def _split_writing_tfrecords_task(_example_paths, _example_labels, _flags='train'): 99 | 100 | _split_example_paths = [] 101 | _split_example_labels = [] 102 | _split_tfrecords_save_paths = [] 103 | 104 | for i in range(0, len(_example_paths), step_size): 105 | _split_example_paths.append(_example_paths[i:i + step_size]) 106 | _split_example_labels.append(_example_labels[i:i + step_size]) 107 | 108 | if i + step_size > len(_example_paths): 109 | _split_tfrecords_save_paths.append( 110 | ops.join(save_dir, '{:s}_{:d}_{:d}.tfrecords'.format(_flags, i, len(_example_paths)))) 111 | else: 112 | _split_tfrecords_save_paths.append( 113 | ops.join(save_dir, '{:s}_{:d}_{:d}.tfrecords'.format(_flags, i, i + step_size))) 114 | 115 | return _split_example_paths, _split_example_labels, _split_tfrecords_save_paths 116 | 117 | # make save dirs 118 | os.makedirs(save_dir, exist_ok=True) 119 | 120 | # set process pool 121 | process_pool = multiprocessing.Pool(processes=CFG.TRAIN.CPU_MULTI_PROCESS_NUMS) 122 | 123 | # generate training example tfrecords 124 | log.info('Generating training example tfrecords') 125 | 126 | train_example_paths, train_example_labels = _read_training_example_index_file( 127 | self._train_example_index_file_path) 128 | train_example_paths_split, train_example_labels_split, train_tfrecords_save_paths = \ 129 | _split_writing_tfrecords_task(train_example_paths, train_example_labels, _flags='train') 130 | 131 | for index, example_paths in enumerate(train_example_paths_split): 132 | process_pool.apply_async(func=tf_io_pipline_tools.write_example_tfrecords, 133 | args=(example_paths, 134 | train_example_labels_split[index], 135 | train_tfrecords_save_paths[index],)) 136 | 137 | process_pool.close() 138 | process_pool.join() 139 | 140 | log.info('Generate training example tfrecords complete') 141 | 142 | # set process pool 143 | process_pool = multiprocessing.Pool(processes=CFG.TRAIN.CPU_MULTI_PROCESS_NUMS) 144 | 145 | # generate val example tfrecords 146 | log.info('Generating validation example tfrecords') 147 | 148 | val_example_paths, val_example_labels = _read_training_example_index_file( 149 | self._val_example_index_file_path) 150 | val_example_paths_split, val_example_labels_split, val_tfrecords_save_paths = \ 151 | _split_writing_tfrecords_task(val_example_paths, 
val_example_labels, _flags='val') 153 | 154 | for index, example_paths in enumerate(val_example_paths_split): 155 | process_pool.apply_async(func=tf_io_pipline_tools.write_example_tfrecords, 156 | args=(example_paths, 157 | val_example_labels_split[index], 158 | val_tfrecords_save_paths[index],)) 159 | 160 | process_pool.close() 161 | process_pool.join() 162 | 163 | log.info('Generate validation example tfrecords complete') 164 | 165 | # set process pool 166 | process_pool = multiprocessing.Pool(processes=CFG.TRAIN.CPU_MULTI_PROCESS_NUMS) 167 | 168 | # generate test example tfrecords 169 | log.info('Generating testing example tfrecords') 170 | 171 | test_example_paths, test_example_labels = _read_training_example_index_file( 172 | self._test_example_index_file_path) 173 | test_example_paths_split, test_example_labels_split, test_tfrecords_save_paths = \ 174 | _split_writing_tfrecords_task(test_example_paths, test_example_labels, _flags='test') 175 | 176 | for index, example_paths in enumerate(test_example_paths_split): 177 | process_pool.apply_async(func=tf_io_pipline_tools.write_example_tfrecords, 178 | args=(example_paths, 179 | test_example_labels_split[index], 180 | test_tfrecords_save_paths[index],)) 181 | 182 | process_pool.close() 183 | process_pool.join() 184 | 185 | log.info('Generate testing example tfrecords complete') 186 | 187 | return 188 | 189 | def _is_source_data_complete(self): 190 | """ 191 | Check if source data complete 192 | :return: 193 | """ 194 | return \ 195 | ops.exists(self._drawing_image_dir) and ops.exists(self._hentai_image_dir) \ 196 | and ops.exists(self._neural_image_dir) and ops.exists(self._porn_image_dir) \ 197 | and ops.exists(self._sexy_image_dir) 198 | 199 | def _is_training_example_index_file_complete(self): 200 | """ 201 | Check if the training example index file is complete 202 | :return: 203 | """ 204 | return \ 205 | ops.exists(self._train_example_index_file_path) and \ 206 | ops.exists(self._test_example_index_file_path) and \ 207 | ops.exists(self._val_example_index_file_path) 208 | 209 | def _generate_training_example_index_file(self): 210 | """ 211 | Generate training example index file, split source file into 0.75, 0.15, 0.1 for training, 212 | testing and validation. 
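For example, a class folder holding 1,000 images contributes roughly 750 training, 150 testing and 100 validation entries.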
Each image folder is processed separately 213 | :return: 214 | """ 215 | def _process_single_training_folder(_folder_dir): 216 | 217 | _folder_label_name = ops.split(_folder_dir)[1] 218 | _folder_label_index = self._label_map[_folder_label_name] 219 | 220 | _source_image_paths = glob.glob('{:s}/*'.format(_folder_dir)) 221 | 222 | return ['{:s} {:d}\n'.format(s, _folder_label_index) for s in _source_image_paths] 223 | 224 | def _split_training_examples(_example_info): 225 | 226 | random.shuffle(_example_info) 227 | 228 | _example_nums = len(_example_info) 229 | 230 | _train_example_info = _example_info[:int(_example_nums * 0.75)] 231 | _test_example_info = _example_info[int(_example_nums * 0.75):int(_example_nums * 0.9)] 232 | _val_example_info = _example_info[int(_example_nums * 0.9):] 233 | 234 | return _train_example_info, _test_example_info, _val_example_info 235 | 236 | train_example_info = [] 237 | test_example_info = [] 238 | val_example_info = [] 239 | 240 | for example_dir in [self._drawing_image_dir, self._hentai_image_dir, 241 | self._neural_image_dir, self._porn_image_dir, 242 | self._sexy_image_dir]: 243 | _train_tmp_info, _test_tmp_info, _val_tmp_info = \ 244 | _split_training_examples(_process_single_training_folder(example_dir)) 245 | 246 | train_example_info.extend(_train_tmp_info) 247 | test_example_info.extend(_test_tmp_info) 248 | val_example_info.extend(_val_tmp_info) 249 | 250 | random.shuffle(train_example_info) 251 | random.shuffle(test_example_info) 252 | random.shuffle(val_example_info) 253 | 254 | with open(ops.join(self._dataset_dir, 'train.txt'), 'w') as file: 255 | file.write(''.join(train_example_info)) 256 | 257 | with open(ops.join(self._dataset_dir, 'test.txt'), 'w') as file: 258 | file.write(''.join(test_example_info)) 259 | 260 | with open(ops.join(self._dataset_dir, 'val.txt'), 'w') as file: 261 | file.write(''.join(val_example_info)) 262 | 263 | log.info('Generate training example index file complete') 264 | 265 | return 266 | 267 | 268 | class NsfwDataFeeder(object): 269 | """ 270 | Read training examples from tfrecords for nsfw model 271 | """ 272 | def __init__(self, dataset_dir, flags='train'): 273 | """ 274 | 275 | :param dataset_dir: 276 | :param flags: 277 | """ 278 | self._dataset_dir = dataset_dir 279 | 280 | self._tfrecords_dir = ops.join(dataset_dir, 'tfrecords') 281 | if not ops.exists(self._tfrecords_dir): 282 | raise ValueError('{:s} does not exist, please check again'.format(self._tfrecords_dir)) 283 | 284 | self._dataset_flags = flags.lower() 285 | if self._dataset_flags not in ['train', 'test', 'val']: 286 | raise ValueError('flags of the data feeder should be \'train\', \'test\', \'val\'') 287 | 288 | self._label_map = global_config.NSFW_LABEL_MAP 289 | 290 | self._prediction_map = global_config.NSFW_PREDICT_MAP 291 | 292 | @property 293 | def label_map(self): 294 | """ 295 | 296 | :return: 297 | """ 298 | return self._label_map 299 | 300 | @property 301 | def prediction_map(self): 302 | """ 303 | 304 | :return: 305 | """ 306 | return self._prediction_map 307 | 308 | def inputs(self, batch_size, num_epochs): 309 | """ 310 | dataset feed pipeline input 311 | :param batch_size: 312 | :param num_epochs: 313 | :return: A tuple (images, labels), where: 314 | * images is a float32 tensor with shape [batch_size, H, W, C] 315 | with the ImageNet channel means subtracted. 316 | * labels is an int32 tensor with shape [batch_size] with the true label, 317 | a number in the range [0, CLASS_NUMS). 
318 | """ 319 | if not num_epochs: 320 | num_epochs = None 321 | 322 | tfrecords_file_paths = glob.glob('{:s}/{:s}*.tfrecords'.format(self._tfrecords_dir, self._dataset_flags)) 323 | random.shuffle(tfrecords_file_paths) 324 | 325 | with tf.name_scope('input_tensor'): 326 | 327 | # TFRecordDataset opens a binary file and reads one record at a time. 328 | # `tfrecords_file_paths` could also be a list of filenames, which will be read in order. 329 | dataset = tf.data.TFRecordDataset(tfrecords_file_paths) 330 | 331 | # The map transformation takes a function and applies it to every element 332 | # of the dataset. 333 | dataset = dataset.map(map_func=tf_io_pipline_tools.decode, 334 | num_parallel_calls=CFG.TRAIN.CPU_MULTI_PROCESS_NUMS) 335 | if self._dataset_flags == 'train': 336 | dataset = dataset.map(map_func=tf_io_pipline_tools.augment_for_train, 337 | num_parallel_calls=CFG.TRAIN.CPU_MULTI_PROCESS_NUMS) 338 | else: 339 | dataset = dataset.map(map_func=tf_io_pipline_tools.augment_for_validation, 340 | num_parallel_calls=CFG.TRAIN.CPU_MULTI_PROCESS_NUMS) 341 | dataset = dataset.map(map_func=tf_io_pipline_tools.normalize, 342 | num_parallel_calls=CFG.TRAIN.CPU_MULTI_PROCESS_NUMS) 343 | 344 | # The shuffle transformation uses a finite-sized buffer to shuffle elements 345 | # in memory. The parameter is the number of elements in the buffer. For 346 | # completely uniform shuffling, set the parameter to be the same as the 347 | # number of elements in the dataset. 348 | if self._dataset_flags != 'test': 349 | dataset = dataset.shuffle(buffer_size=5000) 350 | # repeat num epochs 351 | dataset = dataset.repeat() 352 | 353 | dataset = dataset.batch(batch_size) 354 | 355 | iterator = dataset.make_one_shot_iterator() 356 | 357 | return iterator.get_next(name='{:s}_IteratorGetNext'.format(self._dataset_flags)) 358 | 359 | 360 | if __name__ == '__main__': 361 | 362 | # init args 363 | args = init_args() 364 | 365 | assert ops.exists(args.dataset_dir), '{:s} not exist'.format(args.dataset_dir) 366 | 367 | producer = NsfwDataProducer(dataset_dir=args.dataset_dir) 368 | producer.print_label_map() 369 | producer.generate_tfrecords(save_dir=args.tfrecords_dir, step_size=10000) 370 | -------------------------------------------------------------------------------- /data_provider/tf_io_pipline_tools.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | # @Time : 19-2-15 下午2:13 4 | # @Author : MaybeShewill-CV 5 | # @Site : https://github.com/MaybeShewill-CV/CRNN_Tensorflow 6 | # @File : tf_io_pipline_tools.py 7 | # @IDE: PyCharm 8 | """ 9 | Some tensorflow records io tools 10 | """ 11 | import os 12 | import os.path as ops 13 | 14 | import cv2 15 | import tensorflow as tf 16 | import glog as log 17 | 18 | from config import global_config 19 | 20 | 21 | CFG = global_config.cfg 22 | 23 | _R_MEAN = 123.68 24 | _G_MEAN = 116.78 25 | _B_MEAN = 103.94 26 | _CHANNEL_MEANS = [_B_MEAN, _G_MEAN, _R_MEAN] 27 | 28 | 29 | def int64_feature(value): 30 | """ 31 | 32 | :return: 33 | """ 34 | return tf.train.Feature(int64_list=tf.train.Int64List(value=[value])) 35 | 36 | 37 | def bytes_feature(value): 38 | """ 39 | 40 | :param value: 41 | :return: 42 | """ 43 | return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value])) 44 | 45 | 46 | def write_example_tfrecords(example_paths, example_labels, tfrecords_path): 47 | """ 48 | write tfrecords 49 | :param example_paths: 50 | :param example_labels: 51 | :param tfrecords_path: 52 | 
:return: 53 | """ 54 | _tfrecords_dir = ops.split(tfrecords_path)[0] 55 | os.makedirs(_tfrecords_dir, exist_ok=True) 56 | 57 | log.info('Writing {:s}....'.format(tfrecords_path)) 58 | 59 | with tf.python_io.TFRecordWriter(tfrecords_path) as _writer: 60 | for _index, _example_path in enumerate(example_paths): 61 | 62 | with open(_example_path, 'rb') as f: 63 | check_chars = f.read()[-2:] 64 | if check_chars != b'\xff\xd9': 65 | log.error('Image file {:s} is not complete'.format(_example_path)) 66 | continue 67 | else: 68 | _example_image = cv2.imread(_example_path, cv2.IMREAD_COLOR) 69 | _example_image = cv2.resize(_example_image, 70 | dsize=(CFG.TRAIN.IMG_WIDTH, CFG.TRAIN.IMG_HEIGHT), 71 | interpolation=cv2.INTER_CUBIC) 72 | _example_image_raw = _example_image.tostring() 73 | 74 | _example = tf.train.Example( 75 | features=tf.train.Features( 76 | feature={ 77 | 'height': int64_feature(CFG.TRAIN.IMG_HEIGHT), 78 | 'width': int64_feature(CFG.TRAIN.IMG_WIDTH), 79 | 'depth': int64_feature(3), 80 | 'label': int64_feature(example_labels[_index]), 81 | 'image_raw': bytes_feature(_example_image_raw) 82 | })) 83 | _writer.write(_example.SerializeToString()) 84 | 85 | log.info('Writing {:s} complete'.format(tfrecords_path)) 86 | 87 | return 88 | 89 | 90 | def decode(serialized_example): 91 | """ 92 | Parses an image and label from the given `serialized_example` 93 | :param serialized_example: 94 | :return: 95 | """ 96 | features = tf.parse_single_example( 97 | serialized_example, 98 | # Defaults are not specified since both keys are required. 99 | features={ 100 | 'image_raw': tf.FixedLenFeature([], tf.string), 101 | 'label': tf.FixedLenFeature([], tf.int64), 102 | 'height': tf.FixedLenFeature([], tf.int64), 103 | 'width': tf.FixedLenFeature([], tf.int64), 104 | 'depth': tf.FixedLenFeature([], tf.int64) 105 | }) 106 | 107 | # decode image 108 | image = tf.decode_raw(features['image_raw'], tf.uint8) 109 | image_shape = tf.stack([CFG.TRAIN.IMG_HEIGHT, CFG.TRAIN.IMG_WIDTH, 3]) 110 | image = tf.reshape(image, image_shape) 111 | 112 | # Convert label from a scalar int64 tensor to an int32 scalar. 
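# (Note: no JPEG decoding happens at read time -- the bytes above were serialized with ndarray.tostring() after a cv2 resize when the tfrecords were built, so tf.decode_raw plus the reshape restores the original HxWx3 uint8 BGR array directly.)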
113 | label = tf.cast(features['label'], tf.int32) 114 | 115 | return image, label 116 | 117 | 118 | def augment_for_train(image, label): 119 | """ 120 | 121 | :param image: 122 | :param label: 123 | :return: 124 | """ 125 | # first apply random crop 126 | image = tf.image.random_crop(value=image, 127 | size=[CFG.TRAIN.CROP_IMG_HEIGHT, CFG.TRAIN.CROP_IMG_WIDTH, 3], 128 | seed=1234, 129 | name='crop_image') 130 | # apply random flip 131 | image = tf.image.random_flip_left_right(image=image, seed=1234) 132 | 133 | return image, label 134 | 135 | 136 | def augment_for_validation(image, label): 137 | """ 138 | 139 | :param image: 140 | :param label: 141 | :return: 142 | """ 143 | assert CFG.TRAIN.IMG_HEIGHT == CFG.TRAIN.IMG_WIDTH 144 | assert CFG.TRAIN.CROP_IMG_HEIGHT == CFG.TRAIN.CROP_IMG_WIDTH 145 | 146 | # apply central crop 147 | central_fraction = CFG.TRAIN.CROP_IMG_HEIGHT / CFG.TRAIN.IMG_HEIGHT 148 | image = tf.image.central_crop(image=image, central_fraction=central_fraction) 149 | 150 | return image, label 151 | 152 | 153 | def normalize(image, label): 154 | """ 155 | Normalize the image data by subtracting the ImageNet channel mean values (BGR order, to match cv2.imread) 156 | :param image: 157 | :param label: 158 | :return: 159 | """ 160 | 161 | if image.get_shape().ndims != 3: 162 | raise ValueError('Input must be of size [height, width, C>0]') 163 | 164 | image_fp = tf.cast(image, dtype=tf.float32) 165 | means = tf.expand_dims(tf.expand_dims(_CHANNEL_MEANS, 0), 0) 166 | 167 | return image_fp - means, label 168 | -------------------------------------------------------------------------------- /docker_container/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | # @Time : 19-3-14 9:48 PM 4 | # @Author : Luo Yao 5 | # @Site : http://icode.baidu.com/repos/baidu/personal-code/Luoyao 6 | # @File : __init__.py 7 | # @IDE: PyCharm -------------------------------------------------------------------------------- /docker_container/python_client.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Send a JPEG image to tensorflow_model_server loaded with the NSFW classification model. 3 | 4 | Hint: the code has been compiled together with TensorFlow Serving 5 | and not locally.
The client is called in the TensorFlow Docker container 6 | ''' 7 | 8 | import time 9 | from argparse import ArgumentParser 10 | 11 | import grpc 12 | import numpy as np 13 | import cv2 14 | from tensorflow.contrib.util import make_tensor_proto 15 | 16 | from tensorflow_serving.apis import predict_pb2 17 | from tensorflow_serving.apis import prediction_service_pb2_grpc 18 | 19 | 20 | def parse_args(): 21 | parser = ArgumentParser(description='Request a TensorFlow server for a prediction on the image') 22 | parser.add_argument('-s', '--server', 23 | dest='server', 24 | default='0.0.0.0:9000', 25 | help='prediction service host:port') 26 | parser.add_argument("-i", "--image", 27 | dest="image", 28 | default='', 29 | help="path to image in JPEG format", ) 30 | parser.add_argument('-p', '--image_path', 31 | dest='image_path', 32 | default='/home/baidu/Pictures/dota2_1.jpg', 33 | help='path to images folder', ) 34 | parser.add_argument('-b', '--batch_mode', 35 | dest='batch_mode', 36 | default='true', 37 | help='send image as batch or one-by-one') 38 | args = parser.parse_args() 39 | 40 | return args.server, args.image, args.image_path, args.batch_mode == 'true' 41 | 42 | 43 | def main(): 44 | """ 45 | 46 | :return: 47 | """ 48 | server, image, image_path, batch_mode = parse_args() 49 | 50 | channel = grpc.insecure_channel(server) 51 | stub = prediction_service_pb2_grpc.PredictionServiceStub(channel) 52 | 53 | image = cv2.imread(image_path, cv2.IMREAD_COLOR) 54 | image = cv2.resize(image, (224, 224)) 55 | image = np.array(image, np.float32) 56 | 57 | image_list = [] 58 | 59 | for i in range(128): 60 | image_list.append(image) 61 | 62 | image_list = np.array(image_list, dtype=np.float32) 63 | 64 | start = time.time() 65 | 66 | if batch_mode: 67 | print('In batch mode') 68 | request = predict_pb2.PredictRequest() 69 | request.model_spec.name = 'nsfw' 70 | request.model_spec.signature_name = 'classify_result' 71 | 72 | request.inputs['input_tensor'].CopyFrom(make_tensor_proto( 73 | image_list, shape=[128, 224, 224, 3])) 74 | 75 | try: 76 | result = stub.Predict(request, 60.0) 77 | except Exception as err: 78 | print(err) 79 | return 80 | print(result) 81 | else: 82 | return 83 | 84 | end = time.time() 85 | time_diff = end - start 86 | print('time elapsed: {}'.format(time_diff)) 87 | 88 | 89 | if __name__ == '__main__': 90 | main() 91 | -------------------------------------------------------------------------------- /model/nsfw_export_saved_model/1/saved_model.pb: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MaybeShewill-CV/nsfw-classification-tensorflow/6dfcb16fd655e66b9dd83237bbe89e84aa5322b9/model/nsfw_export_saved_model/1/saved_model.pb -------------------------------------------------------------------------------- /model/nsfw_export_saved_model/1/variables/variables.data-00000-of-00001: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MaybeShewill-CV/nsfw-classification-tensorflow/6dfcb16fd655e66b9dd83237bbe89e84aa5322b9/model/nsfw_export_saved_model/1/variables/variables.data-00000-of-00001 -------------------------------------------------------------------------------- /model/nsfw_export_saved_model/1/variables/variables.index: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/MaybeShewill-CV/nsfw-classification-tensorflow/6dfcb16fd655e66b9dd83237bbe89e84aa5322b9/model/nsfw_export_saved_model/1/variables/variables.index -------------------------------------------------------------------------------- /nsfw_model/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | # @Time : 19-2-14 5:38 PM 4 | # @Author : Luo Yao 5 | # @Site : http://icode.baidu.com/repos/baidu/personal-code/Luoyao 6 | # @File : __init__.py 7 | # @IDE: PyCharm 8 | -------------------------------------------------------------------------------- /nsfw_model/cnn_basenet.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # @Time : 17-9-18 3:59 PM 4 | # @Author : MaybeShewill-CV 5 | # @Site : https://github.com/MaybeShewill-CV/CRNN_Tensorflow 6 | # @File : cnn_basenet.py 7 | # @IDE: PyCharm Community Edition 8 | """ 9 | The base convolutional neural network class, mainly implementing some useful cnn functions 10 | """ 11 | import tensorflow as tf 12 | from tensorflow.python.training import moving_averages 13 | from tensorflow.contrib.framework import add_model_variable 14 | import numpy as np 15 | 16 | 17 | class CNNBaseModel(object): 18 | """ 19 | Base model for other specific cnn models 20 | """ 21 | 22 | def __init__(self): 23 | pass 24 | 25 | @staticmethod 26 | def conv2d(inputdata, out_channel, kernel_size, padding='SAME', 27 | stride=1, w_init=None, b_init=None, 28 | split=1, use_bias=True, data_format='NHWC', name=None): 29 | """ 30 | Packing the tensorflow conv2d function. 31 | :param name: op name 32 | :param inputdata: A 4D tensorflow tensor which must have a known number of channels, but can have other 33 | unknown dimensions. 34 | :param out_channel: number of output channel. 35 | :param kernel_size: int so only support square kernel convolution 36 | :param padding: 'VALID' or 'SAME' 37 | :param stride: int so only support square stride 38 | :param w_init: initializer for convolution weights 39 | :param b_init: initializer for bias 40 | :param split: split channels into groups as used in AlexNet, mainly to save GPU memory. 41 | :param use_bias: whether to use bias. 42 | :param data_format: default set to NHWC according to tensorflow 43 | :return: tf.Tensor named ``output`` 44 | """ 45 | with tf.variable_scope(name): 46 | in_shape = inputdata.get_shape().as_list() 47 | channel_axis = 3 if data_format == 'NHWC' else 1 48 | in_channel = in_shape[channel_axis] 49 | 50 | assert in_channel is not None, "[Conv2D] Input cannot have unknown channel!" 
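# Grouped (split) convolution as in AlexNet: the input channels are partitioned into `split` groups, each group is convolved with its own slice of the kernel, and the per-group outputs are concatenated back along the channel axis, which is why both channel counts must divide evenly by `split`.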
51 | assert in_channel % split == 0 52 | assert out_channel % split == 0 53 | 54 | padding = padding.upper() 55 | 56 | if isinstance(kernel_size, list): 57 | filter_shape = [kernel_size[0], kernel_size[1]] + [in_channel // split, out_channel] 58 | else: 59 | filter_shape = [kernel_size, kernel_size] + [in_channel // split, out_channel] 60 | 61 | if isinstance(stride, list): 62 | strides = [1, stride[0], stride[1], 1] if data_format == 'NHWC' \ 63 | else [1, 1, stride[0], stride[1]] 64 | else: 65 | strides = [1, stride, stride, 1] if data_format == 'NHWC' \ 66 | else [1, 1, stride, stride] 67 | 68 | if w_init is None: 69 | w_init = tf.contrib.layers.variance_scaling_initializer() 70 | if b_init is None: 71 | b_init = tf.constant_initializer() 72 | 73 | w = tf.get_variable('W', filter_shape, initializer=w_init) 74 | b = None 75 | 76 | if use_bias: 77 | b = tf.get_variable('b', [out_channel], initializer=b_init) 78 | 79 | if split == 1: 80 | conv = tf.nn.conv2d(inputdata, w, strides, padding, data_format=data_format) 81 | else: 82 | inputs = tf.split(inputdata, split, channel_axis) 83 | kernels = tf.split(w, split, 3) 84 | outputs = [tf.nn.conv2d(i, k, strides, padding, data_format=data_format) 85 | for i, k in zip(inputs, kernels)] 86 | conv = tf.concat(outputs, channel_axis) 87 | 88 | ret = tf.identity(tf.nn.bias_add(conv, b, data_format=data_format) 89 | if use_bias else conv, name=name) 90 | 91 | return ret 92 | 93 | @staticmethod 94 | def relu(inputdata, name=None): 95 | """ 96 | 97 | :param name: 98 | :param inputdata: 99 | :return: 100 | """ 101 | return tf.nn.relu(features=inputdata, name=name) 102 | 103 | @staticmethod 104 | def sigmoid(inputdata, name=None): 105 | """ 106 | 107 | :param name: 108 | :param inputdata: 109 | :return: 110 | """ 111 | return tf.nn.sigmoid(x=inputdata, name=name) 112 | 113 | @staticmethod 114 | def maxpooling(inputdata, kernel_size, stride=None, padding='VALID', 115 | data_format='NHWC', name=None): 116 | """ 117 | 118 | :param name: 119 | :param inputdata: 120 | :param kernel_size: 121 | :param stride: 122 | :param padding: 123 | :param data_format: 124 | :return: 125 | """ 126 | padding = padding.upper() 127 | 128 | if stride is None: 129 | stride = kernel_size 130 | 131 | if isinstance(kernel_size, list): 132 | kernel = [1, kernel_size[0], kernel_size[1], 1] if data_format == 'NHWC' else \ 133 | [1, 1, kernel_size[0], kernel_size[1]] 134 | else: 135 | kernel = [1, kernel_size, kernel_size, 1] if data_format == 'NHWC' \ 136 | else [1, 1, kernel_size, kernel_size] 137 | 138 | if isinstance(stride, list): 139 | strides = [1, stride[0], stride[1], 1] if data_format == 'NHWC' \ 140 | else [1, 1, stride[0], stride[1]] 141 | else: 142 | strides = [1, stride, stride, 1] if data_format == 'NHWC' \ 143 | else [1, 1, stride, stride] 144 | 145 | return tf.nn.max_pool(value=inputdata, ksize=kernel, strides=strides, padding=padding, 146 | data_format=data_format, name=name) 147 | 148 | @staticmethod 149 | def avgpooling(inputdata, kernel_size, stride=None, padding='VALID', 150 | data_format='NHWC', name=None): 151 | """ 152 | 153 | :param name: 154 | :param inputdata: 155 | :param kernel_size: 156 | :param stride: 157 | :param padding: 158 | :param data_format: 159 | :return: 160 | """ 161 | if stride is None: 162 | stride = kernel_size 163 | 164 | kernel = [1, kernel_size, kernel_size, 1] if data_format == 'NHWC' \ 165 | else [1, 1, kernel_size, kernel_size] 166 | 167 | strides = [1, stride, stride, 1] if data_format == 'NHWC' else [1, 1, stride, stride] 168 | 169 | 
return tf.nn.avg_pool(value=inputdata, ksize=kernel, strides=strides, padding=padding, 170 | data_format=data_format, name=name) 171 | 172 | @staticmethod 173 | def globalavgpooling(inputdata, data_format='NHWC', name=None): 174 | """ 175 | 176 | :param name: 177 | :param inputdata: 178 | :param data_format: 179 | :return: 180 | """ 181 | assert inputdata.shape.ndims == 4 182 | assert data_format in ['NHWC', 'NCHW'] 183 | 184 | axis = [1, 2] if data_format == 'NHWC' else [2, 3] 185 | 186 | return tf.reduce_mean(input_tensor=inputdata, axis=axis, name=name) 187 | 188 | @staticmethod 189 | def layernorm(inputdata, epsilon=1e-5, use_bias=True, use_scale=True, 190 | data_format='NHWC', name=None): 191 | """ 192 | :param name: 193 | :param inputdata: 194 | :param epsilon: epsilon to avoid divide-by-zero. 195 | :param use_bias: whether to use the extra affine transformation or not. 196 | :param use_scale: whether to use the extra affine transformation or not. 197 | :param data_format: 198 | :return: 199 | """ 200 | shape = inputdata.get_shape().as_list() 201 | ndims = len(shape) 202 | assert ndims in [2, 4] 203 | 204 | mean, var = tf.nn.moments(inputdata, list(range(1, len(shape))), keep_dims=True) 205 | 206 | if data_format == 'NCHW': 207 | channnel = shape[1] 208 | new_shape = [1, channnel, 1, 1] 209 | else: 210 | channnel = shape[-1] 211 | new_shape = [1, 1, 1, channnel] 212 | if ndims == 2: 213 | new_shape = [1, channnel] 214 | 215 | if use_bias: 216 | beta = tf.get_variable('beta', [channnel], initializer=tf.constant_initializer()) 217 | beta = tf.reshape(beta, new_shape) 218 | else: 219 | beta = tf.zeros([1] * ndims, name='beta') 220 | if use_scale: 221 | gamma = tf.get_variable('gamma', [channnel], initializer=tf.constant_initializer(1.0)) 222 | gamma = tf.reshape(gamma, new_shape) 223 | else: 224 | gamma = tf.ones([1] * ndims, name='gamma') 225 | 226 | return tf.nn.batch_normalization(inputdata, mean, var, beta, gamma, epsilon, name=name) 227 | 228 | @staticmethod 229 | def instancenorm(inputdata, epsilon=1e-5, data_format='NHWC', use_affine=True, name=None): 230 | """ 231 | 232 | :param name: 233 | :param inputdata: 234 | :param epsilon: 235 | :param data_format: 236 | :param use_affine: 237 | :return: 238 | """ 239 | shape = inputdata.get_shape().as_list() 240 | if len(shape) != 4: 241 | raise ValueError("Input data of instancebn layer has to be 4D tensor") 242 | 243 | if data_format == 'NHWC': 244 | axis = [1, 2] 245 | ch = shape[3] 246 | new_shape = [1, 1, 1, ch] 247 | else: 248 | axis = [2, 3] 249 | ch = shape[1] 250 | new_shape = [1, ch, 1, 1] 251 | if ch is None: 252 | raise ValueError("Input of instancebn require known channel!") 253 | 254 | mean, var = tf.nn.moments(inputdata, axis, keep_dims=True) 255 | 256 | if not use_affine: 257 | return tf.divide(inputdata - mean, tf.sqrt(var + epsilon), name='output') 258 | 259 | beta = tf.get_variable('beta', [ch], initializer=tf.constant_initializer()) 260 | beta = tf.reshape(beta, new_shape) 261 | gamma = tf.get_variable('gamma', [ch], initializer=tf.constant_initializer(1.0)) 262 | gamma = tf.reshape(gamma, new_shape) 263 | return tf.nn.batch_normalization(inputdata, mean, var, beta, gamma, epsilon, name=name) 264 | 265 | @staticmethod 266 | def dropout(inputdata, keep_prob, noise_shape=None, name=None): 267 | """ 268 | 269 | :param name: 270 | :param inputdata: 271 | :param keep_prob: 272 | :param noise_shape: 273 | :return: 274 | """ 275 | return tf.nn.dropout(inputdata, keep_prob=keep_prob, noise_shape=noise_shape, name=name) 276 | 
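# --- A minimal usage sketch (editor's illustration, not part of the original file; the
# placeholder shape and layer sizes are arbitrary assumptions) showing how the static
# helpers above compose into a conv -> relu -> pool stage:
#
#   import tensorflow as tf
#   x = tf.placeholder(dtype=tf.float32, shape=[None, 224, 224, 3])
#   feats = CNNBaseModel.conv2d(inputdata=x, out_channel=32, kernel_size=3, stride=2, name='demo_conv')
#   feats = CNNBaseModel.relu(inputdata=feats, name='demo_relu')
#   feats = CNNBaseModel.maxpooling(inputdata=feats, kernel_size=2, stride=2, name='demo_pool')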
277 | @staticmethod 278 | def fullyconnect(inputdata, out_dim, w_init=None, b_init=None, 279 | use_bias=True, name=None): 280 | """ 281 | Fully-Connected layer, takes a N>1D tensor and returns a 2D tensor. 282 | It is an equivalent of `tf.layers.dense` except for naming conventions. 283 | 284 | :param inputdata: a tensor to be flattened except for the first dimension. 285 | :param out_dim: output dimension 286 | :param w_init: initializer for w. Defaults to `variance_scaling_initializer`. 287 | :param b_init: initializer for b. Defaults to zero 288 | :param use_bias: whether to use bias. 289 | :param name: 290 | :return: tf.Tensor: a NC tensor named ``output`` with attribute `variables`. 291 | """ 292 | shape = inputdata.get_shape().as_list()[1:] 293 | if None not in shape: 294 | inputdata = tf.reshape(inputdata, [-1, int(np.prod(shape))]) 295 | else: 296 | inputdata = tf.reshape(inputdata, tf.stack([tf.shape(inputdata)[0], -1])) 297 | 298 | if w_init is None: 299 | w_init = tf.contrib.layers.variance_scaling_initializer() 300 | if b_init is None: 301 | b_init = tf.constant_initializer() 302 | 303 | ret = tf.layers.dense(inputs=inputdata, activation=lambda x: tf.identity(x, name='output'), 304 | use_bias=use_bias, name=name, 305 | kernel_initializer=w_init, 306 | bias_initializer=b_init, 307 | trainable=True, units=out_dim) 308 | return ret 309 | 310 | @staticmethod 311 | def layerbn(inputdata, is_training, name, momentum=0.999, eps=1e-3): 312 | """ 313 | 314 | :param inputdata: 315 | :param is_training: 316 | :param name: 317 | :param momentum: 318 | :param eps: 319 | :return: 320 | """ 321 | 322 | return tf.layers.batch_normalization( 323 | inputs=inputdata, training=is_training, name=name, momentum=momentum, epsilon=eps) 324 | 325 | @staticmethod 326 | def layerbn_distributed(list_input, stats_mode, data_format='NHWC', 327 | float_type=tf.float32, trainable=True, 328 | use_gamma=True, use_beta=True, bn_epsilon=1e-5, 329 | bn_ema=0.9, name='BatchNorm'): 330 | """ 331 | Batch norm for distributed training process 332 | :param list_input: 333 | :param stats_mode: 334 | :param data_format: 335 | :param float_type: 336 | :param trainable: 337 | :param use_gamma: 338 | :param use_beta: 339 | :param bn_epsilon: 340 | :param bn_ema: 341 | :param name: 342 | :return: 343 | """ 344 | 345 | def _get_bn_variables(_n_out, _use_scale, _use_bias, _trainable, _float_type): 346 | 347 | if _use_bias: 348 | _beta = tf.get_variable('beta', [_n_out], 349 | initializer=tf.constant_initializer(), 350 | trainable=_trainable, 351 | dtype=_float_type) 352 | else: 353 | _beta = tf.zeros([_n_out], name='beta') 354 | if _use_scale: 355 | _gamma = tf.get_variable('gamma', [_n_out], 356 | initializer=tf.constant_initializer(1.0), 357 | trainable=_trainable, 358 | dtype=_float_type) 359 | else: 360 | _gamma = tf.ones([_n_out], name='gamma') 361 | 362 | _moving_mean = tf.get_variable('moving_mean', [_n_out], 363 | initializer=tf.constant_initializer(), 364 | trainable=False, 365 | dtype=_float_type) 366 | _moving_var = tf.get_variable('moving_variance', [_n_out], 367 | initializer=tf.constant_initializer(1), 368 | trainable=False, 369 | dtype=_float_type) 370 | return _beta, _gamma, _moving_mean, _moving_var 371 | 372 | def _update_bn_ema(_xn, _batch_mean, _batch_var, _moving_mean, _moving_var, _decay): 373 | 374 | _update_op1 = moving_averages.assign_moving_average( 375 | _moving_mean, _batch_mean, _decay, zero_debias=False, 376 | name='mean_ema_op') 377 | _update_op2 = moving_averages.assign_moving_average( 378 | 
_moving_var, _batch_var, _decay, zero_debias=False, 379 | name='var_ema_op') 380 | add_model_variable(_moving_mean) 381 | add_model_variable(_moving_var) 382 | 383 | # seems faster than delayed update, but might behave otherwise in distributed settings. 384 | with tf.control_dependencies([_update_op1, _update_op2]): 385 | return tf.identity(_xn, name='output') 386 | 387 | # ======================== Checking valid values ========================= 388 | if data_format not in ['NHWC', 'NCHW']: 389 | raise TypeError( 390 | "Only two data formats are supported at this moment: 'NHWC' or 'NCHW', " 391 | "%s is an unknown data format." % data_format) 392 | assert type(list_input) == list 393 | 394 | # ======================== Setting default values ========================= 395 | shape = list_input[0].get_shape().as_list() 396 | assert len(shape) in [2, 4] 397 | n_out = shape[-1] 398 | if data_format == 'NCHW': 399 | n_out = shape[1] 400 | 401 | # ======================== Main operations ============================= 402 | means = [] 403 | square_means = [] 404 | for i in range(len(list_input)): 405 | with tf.device('/gpu:%d' % i): 406 | batch_mean = tf.reduce_mean(list_input[i], [0, 1, 2]) 407 | batch_square_mean = tf.reduce_mean(tf.square(list_input[i]), [0, 1, 2]) 408 | means.append(batch_mean) 409 | square_means.append(batch_square_mean) 410 | 411 | # if your GPUs have NVLink and you've installed NCCL2, you can change `/cpu:0` to `/gpu:0` 412 | with tf.device('/cpu:0'): 413 | shape = tf.shape(list_input[0]) 414 | num = shape[0] * shape[1] * shape[2] * len(list_input) 415 | mean = tf.reduce_mean(means, axis=0) 416 | var = tf.reduce_mean(square_means, axis=0) - tf.square(mean) 417 | var *= tf.cast(num, float_type) / tf.cast(num - 1, float_type) # unbiased variance 418 | 419 | list_output = [] 420 | for i in range(len(list_input)): 421 | with tf.device('/gpu:%d' % i): 422 | with tf.variable_scope(name, reuse=i > 0): 423 | beta, gamma, moving_mean, moving_var = _get_bn_variables( 424 | n_out, use_gamma, use_beta, trainable, float_type) 425 | 426 | if 'train' in stats_mode: 427 | xn = tf.nn.batch_normalization( 428 | list_input[i], mean, var, beta, gamma, bn_epsilon) 429 | if tf.get_variable_scope().reuse or 'gather' not in stats_mode: 430 | list_output.append(xn) 431 | else: 432 | # gather stats and update the moving averages on the main gpu device. 
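# Only the first tower (whose variable scope is not yet reused) folds the freshly
# computed cross-GPU batch statistics into the moving averages below, and only when
# stats_mode asks to 'gather'; the remaining towers simply reuse the normalized
# output, so the EMA update runs at most once per step.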
433 | xn = _update_bn_ema(xn, mean, var, moving_mean, moving_var, bn_ema) 434 | list_output.append(xn) 435 | else: 436 | xn = tf.nn.batch_normalization( 437 | list_input[i], moving_mean, moving_var, beta, gamma, bn_epsilon) 438 | list_output.append(xn) 439 | 440 | return list_output 441 | 442 | @staticmethod 443 | def layergn(inputdata, name, group_size=32, esp=1e-5): 444 | """ 445 | 446 | :param inputdata: 447 | :param name: 448 | :param group_size: 449 | :param esp: 450 | :return: 451 | """ 452 | with tf.variable_scope(name): 453 | inputdata = tf.transpose(inputdata, [0, 3, 1, 2]) 454 | n, c, h, w = inputdata.get_shape().as_list() 455 | group_size = min(group_size, c) 456 | inputdata = tf.reshape(inputdata, [-1, group_size, c // group_size, h, w]) 457 | mean, var = tf.nn.moments(inputdata, [2, 3, 4], keep_dims=True) 458 | inputdata = (inputdata - mean) / tf.sqrt(var + esp) 459 | 460 | # per-channel gamma and beta 461 | gamma = tf.Variable(tf.constant(1.0, shape=[c]), dtype=tf.float32, name='gamma') 462 | beta = tf.Variable(tf.constant(0.0, shape=[c]), dtype=tf.float32, name='beta') 463 | gamma = tf.reshape(gamma, [1, c, 1, 1]) 464 | beta = tf.reshape(beta, [1, c, 1, 1]) 465 | 466 | # reshape back to [n, c, h, w], then transpose to [n, h, w, c] as described in the paper 467 | output = tf.reshape(inputdata, [-1, c, h, w]) 468 | output = output * gamma + beta 469 | output = tf.transpose(output, [0, 2, 3, 1]) 470 | 471 | return output 472 | 473 | @staticmethod 474 | def squeeze(inputdata, axis=None, name=None): 475 | """ 476 | 477 | :param inputdata: 478 | :param axis: 479 | :param name: 480 | :return: 481 | """ 482 | return tf.squeeze(input=inputdata, axis=axis, name=name) 483 | 484 | @staticmethod 485 | def deconv2d(inputdata, out_channel, kernel_size, padding='SAME', 486 | stride=1, w_init=None, b_init=None, 487 | use_bias=True, activation=None, data_format='channels_last', 488 | trainable=True, name=None): 489 | """ 490 | Packing the tensorflow conv2d_transpose function. 491 | :param name: op name 492 | :param inputdata: A 4D tensorflow tensor which must have a known number of channels, but can have other 493 | unknown dimensions. 494 | :param out_channel: number of output channel. 495 | :param kernel_size: int so only support square kernel convolution 496 | :param padding: 'VALID' or 'SAME' 497 | :param stride: int so only support square stride 498 | :param w_init: initializer for convolution weights 499 | :param b_init: initializer for bias 500 | :param activation: whether to apply an activation function to the deconv result 501 | :param use_bias: whether to use bias. 502 | :param data_format: default set to NHWC according to tensorflow 503 | :param trainable: 504 | :return: tf.Tensor named ``output`` 505 | """ 506 | with tf.variable_scope(name): 507 | in_shape = inputdata.get_shape().as_list() 508 | channel_axis = 3 if data_format == 'channels_last' else 1 509 | in_channel = in_shape[channel_axis] 510 | assert in_channel is not None, "[Deconv2D] Input cannot have unknown channel!" 
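# With padding='SAME', conv2d_transpose scales each spatial dimension up by the stride, e.g. a 14x14 feature map becomes 28x28 with stride=2.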
511 | 512 | padding = padding.upper() 513 | 514 | if w_init is None: 515 | w_init = tf.contrib.layers.variance_scaling_initializer() 516 | if b_init is None: 517 | b_init = tf.constant_initializer() 518 | 519 | ret = tf.layers.conv2d_transpose(inputs=inputdata, filters=out_channel, 520 | kernel_size=kernel_size, 521 | strides=stride, padding=padding, 522 | data_format=data_format, 523 | activation=activation, use_bias=use_bias, 524 | kernel_initializer=w_init, 525 | bias_initializer=b_init, trainable=trainable, 526 | name=name) 527 | return ret 528 | 529 | @staticmethod 530 | def dilation_conv(input_tensor, k_size, out_dims, rate, padding='SAME', 531 | w_init=None, b_init=None, use_bias=False, name=None): 532 | """ 533 | 534 | :param input_tensor: 535 | :param k_size: 536 | :param out_dims: 537 | :param rate: 538 | :param padding: 539 | :param w_init: 540 | :param b_init: 541 | :param use_bias: 542 | :param name: 543 | :return: 544 | """ 545 | with tf.variable_scope(name): 546 | in_shape = input_tensor.get_shape().as_list() 547 | in_channel = in_shape[3] 548 | 549 | assert in_channel is not None, "[Conv2D] Input cannot have unknown channel!" 550 | 551 | padding = padding.upper() 552 | 553 | if isinstance(k_size, list): 554 | filter_shape = [k_size[0], k_size[1]] + [in_channel, out_dims] 555 | else: 556 | filter_shape = [k_size, k_size] + [in_channel, out_dims] 557 | 558 | if w_init is None: 559 | w_init = tf.contrib.layers.variance_scaling_initializer() 560 | if b_init is None: 561 | b_init = tf.constant_initializer() 562 | 563 | w = tf.get_variable('W', filter_shape, initializer=w_init) 564 | b = None 565 | 566 | if use_bias: 567 | b = tf.get_variable('b', [out_dims], initializer=b_init) 568 | 569 | conv = tf.nn.atrous_conv2d(value=input_tensor, filters=w, rate=rate, 570 | padding=padding, name='dilation_conv') 571 | 572 | if use_bias: 573 | ret = tf.add(conv, b) 574 | else: 575 | ret = conv 576 | 577 | return ret 578 | 579 | @staticmethod 580 | def spatial_dropout(input_tensor, keep_prob, is_training, name, seed=1234): 581 | """ 582 | Spatial dropout implementation 583 | :param input_tensor: 584 | :param keep_prob: 585 | :param is_training: 586 | :param name: 587 | :param seed: 588 | :return: 589 | """ 590 | 591 | def f1(): 592 | input_shape = input_tensor.get_shape().as_list() 593 | noise_shape = tf.constant(value=[input_shape[0], 1, 1, input_shape[3]]) 594 | return tf.nn.dropout(input_tensor, keep_prob, noise_shape, seed=seed, name="spatial_dropout") 595 | 596 | def f2(): 597 | return input_tensor 598 | 599 | with tf.variable_scope(name_or_scope=name): 600 | 601 | output = tf.cond(is_training, f1, f2) 602 | 603 | return output 604 | 605 | @staticmethod 606 | def lrelu(inputdata, name, alpha=0.2): 607 | """ 608 | 609 | :param inputdata: 610 | :param alpha: 611 | :param name: 612 | :return: 613 | """ 614 | with tf.variable_scope(name): 615 | return tf.nn.relu(inputdata) - alpha * tf.nn.relu(-inputdata) 616 | 617 | @staticmethod 618 | def pad(inputdata, paddings, name): 619 | """ 620 | 621 | :param inputdata: 622 | :param paddings: 623 | :return: 624 | """ 625 | with tf.variable_scope(name_or_scope=name): 626 | return tf.pad(tensor=inputdata, paddings=paddings) 627 | -------------------------------------------------------------------------------- /nsfw_model/nsfw_classification_net.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | # @Time : 19-2-15 6:56 PM 4 | # @Author : MaybeShewill-CV 5 | # @Site : 
https://github.com/MaybeShewill-CV/CRNN_Tensorflow 6 | # @File : nsfw_classification_net.py 7 | # @IDE: PyCharm 8 | """ 9 | NSFW Classification Net Model 10 | """ 11 | import tensorflow as tf 12 | 13 | from nsfw_model import cnn_basenet 14 | from config import global_config 15 | 16 | 17 | CFG = global_config.cfg 18 | 19 | 20 | class NSFWNet(cnn_basenet.CNNBaseModel): 21 | """ 22 | nsfw classification net 23 | """ 24 | def __init__(self, phase, resnet_size=CFG.NET.RESNET_SIZE): 25 | """ 26 | 27 | :param phase: 28 | """ 29 | super(NSFWNet, self).__init__() 30 | self._train_phase = tf.constant('train', dtype=tf.string) 31 | self._test_phase = tf.constant('test', dtype=tf.string) 32 | self._phase = phase 33 | self._is_training = self._init_phase() 34 | self._need_summary_feats_map = CFG.NET.NEED_SUMMARY_FEATS_MAP 35 | self._resnet_size = resnet_size 36 | self._block_sizes = self._get_block_sizes(self._resnet_size) 37 | self._block_strides = [1, 2, 2, 2] 38 | 39 | def _init_phase(self): 40 | """ 41 | 42 | :return: 43 | """ 44 | return tf.equal(self._phase, self._train_phase) 45 | 46 | @staticmethod 47 | def _get_block_sizes(resnet_size): 48 | """ 49 | Retrieve the size of each block_layer in the ResNet model. 50 | The number of block layers used for the Resnet model varies according 51 | to the size of the model. This helper grabs the layer set we want, throwing 52 | an error if a non-standard size has been selected. 53 | Args: 54 | resnet_size: The number of convolutional layers needed in the model. 55 | Returns: 56 | A list of block sizes to use in building the model. 57 | Raises: 58 | KeyError: if invalid resnet_size is received. 59 | """ 60 | choices = { 61 | 32: [3, 4, 6, 3], 62 | 50: [3, 4, 6, 3] 63 | } 64 | 65 | try: 66 | return choices[resnet_size] 67 | except KeyError: 68 | err = ('Could not find layers for selected Resnet size.\n' 69 | 'Size received: {}; sizes allowed: {}.'.format(resnet_size, choices.keys())) 70 | raise ValueError(err) 71 | 72 | @staticmethod 73 | def _feature_map_summary(input_tensor, slice_nums, axis): 74 | """ 75 | summary feature map 76 | :param input_tensor: A Tensor 77 | :return: Add histogram summary and scalar summary of the sparsity of the tensor 78 | """ 79 | tensor_name = input_tensor.op.name 80 | 81 | split = tf.split(input_tensor, num_or_size_splits=slice_nums, axis=axis) 82 | for i in range(slice_nums): 83 | tf.summary.image(tensor_name + "/feature_maps_" + str(i), split[i]) 84 | 85 | def _fixed_padding(self, inputs, kernel_size, name): 86 | """Pads the input along the spatial dimensions independently of input size. 87 | Args: 88 | inputs: A tensor of size [batch, channels, height_in, width_in] or 89 | [batch, height_in, width_in, channels] depending on data_format. 90 | kernel_size: The kernel to be used in the conv2d or max_pool2d operation. 91 | Should be a positive integer. 92 | name: 93 | Returns: 94 | A tensor with the same format as the input with the data either intact 95 | (if kernel_size == 1) or padded (if kernel_size > 1). 
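For example, with kernel_size = 7 the total padding is 6, applied as 3 rows/columns of zeros on each side.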
96 | """ 97 | with tf.variable_scope(name_or_scope=name): 98 | pad_total = kernel_size - 1 99 | pad_beg = pad_total // 2 100 | pad_end = pad_total - pad_beg 101 | 102 | padded_inputs = self.pad(inputdata=inputs, 103 | paddings=[[0, 0], [pad_beg, pad_end], 104 | [pad_beg, pad_end], [0, 0]], 105 | name='pad') 106 | return padded_inputs 107 | 108 | def _conv2d_fixed_padding(self, inputs, kernel_size, output_dims, strides, name): 109 | """ 110 | 111 | :param inputs: 112 | :param kernel_size: 113 | :param output_dims: 114 | :param strides: 115 | :param name: 116 | :return: 117 | """ 118 | with tf.variable_scope(name_or_scope=name): 119 | if strides > 1: 120 | inputs = self._fixed_padding(inputs, kernel_size, name='fix_padding') 121 | 122 | result = self.conv2d(inputdata=inputs, out_channel=output_dims, kernel_size=kernel_size, 123 | stride=strides, padding=('SAME' if strides == 1 else 'VALID'), 124 | use_bias=False, name='conv') 125 | 126 | return result 127 | 128 | def _process_image_input_tensor(self, input_image_tensor, kernel_size, 129 | conv_stride, output_dims, pool_size, pool_stride): 130 | """ 131 | Resnet entry 132 | :param input_image_tensor: 133 | :param kernel_size: 134 | :param conv_stride: 135 | :param output_dims: 136 | :param pool_size: 137 | :param pool_stride: 138 | :return: 139 | """ 140 | inputs = self._conv2d_fixed_padding( 141 | inputs=input_image_tensor, kernel_size=kernel_size, 142 | strides=conv_stride, output_dims=output_dims, name='initial_conv_pad') 143 | inputs = tf.identity(inputs, 'initial_conv') 144 | 145 | inputs = self.maxpooling(inputdata=inputs, kernel_size=pool_size, 146 | stride=pool_stride, padding='SAME', 147 | name='initial_max_pool') 148 | 149 | return inputs 150 | 151 | def _resnet_block_fn(self, input_tensor, kernel_size, stride, 152 | output_dims, name, projection_shortcut=None): 153 | """ 154 | A single block for ResNet v2, without a bottleneck. 155 | Batch normalization then ReLu then convolution as described by: 156 | Identity Mappings in Deep Residual Networks 157 | https://arxiv.org/pdf/1603.05027.pdf 158 | by Kaiming He, Xiangyu Zhang, Shaoqing Ren, and Jian Sun, Jul 2016. 
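Note the pre-activation ordering (BN -> ReLU -> conv); the original ResNet v1 applies conv -> BN -> ReLU instead.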
159 | :param input_tensor: 160 | :param kernel_size: 161 | :param stride: 162 | :param output_dims: 163 | :param name: 164 | :param projection_shortcut: 165 | :return: 166 | """ 167 | with tf.variable_scope(name_or_scope=name): 168 | shortcut = input_tensor 169 | inputs = self.layerbn(inputdata=input_tensor, is_training=self._is_training, name='bn_1') 170 | inputs = self.relu(inputdata=inputs, name='relu_1') 171 | 172 | if projection_shortcut is not None: 173 | shortcut = projection_shortcut(inputs) 174 | 175 | inputs = self._conv2d_fixed_padding( 176 | inputs=inputs, output_dims=output_dims, 177 | kernel_size=kernel_size, strides=stride, name='conv_pad_1') 178 | 179 | inputs = self.layerbn(inputdata=inputs, is_training=self._is_training, name='bn_2') 180 | inputs = self.relu(inputdata=inputs, name='relu_2') 181 | inputs = self._conv2d_fixed_padding( 182 | inputs=inputs, output_dims=output_dims, kernel_size=kernel_size, strides=1, name='conv_pad_2') 183 | 184 | return inputs + shortcut 185 | 186 | def _resnet_block_layer(self, input_tensor, kernel_size, stride, block_nums, output_dims, name): 187 | """ 188 | 189 | :param input_tensor: 190 | :param kernel_size: 191 | :param stride: 192 | :param block_nums: 193 | :param name: 194 | :return: 195 | """ 196 | def projection_shortcut(_inputs): 197 | return self._conv2d_fixed_padding( 198 | inputs=_inputs, output_dims=output_dims, kernel_size=1, 199 | strides=stride, name='projection_shortcut') 200 | 201 | with tf.variable_scope(name): 202 | inputs = self._resnet_block_fn(input_tensor=input_tensor, 203 | kernel_size=kernel_size, 204 | output_dims=output_dims, 205 | projection_shortcut=projection_shortcut, 206 | stride=stride, 207 | name='init_block_fn') 208 | 209 | for index in range(1, block_nums): 210 | inputs = self._resnet_block_fn(input_tensor=inputs, 211 | kernel_size=kernel_size, 212 | output_dims=output_dims, 213 | projection_shortcut=None, 214 | stride=1, 215 | name='block_fn_{:d}'.format(index)) 216 | return inputs 217 | 218 | def inference(self, input_tensor, name, reuse=False): 219 | """ 220 | The main function that defines the ResNet. total layers = 1 + 2n + 2n + 2n + 1 = 6n + 2 221 | :param input_tensor: 4D tensor 222 | :param name: net name 223 | :param reuse: To build train graph, reuse=False. To build validation graph and share weights 224 | with train graph, reuse=True 225 | :return: last layer in the network. 
Not softmax-ed 226 | """ 227 | with tf.variable_scope(name_or_scope=name, reuse=reuse): 228 | 229 | if self._need_summary_feats_map: 230 | self._feature_map_summary(input_tensor=input_tensor, slice_nums=1, axis=-1) 231 | 232 | # first layer process 233 | inputs = self._process_image_input_tensor(input_image_tensor=input_tensor, 234 | kernel_size=7, 235 | conv_stride=2, 236 | output_dims=64, 237 | pool_size=3, 238 | pool_stride=2) 239 | if self._need_summary_feats_map: 240 | self._feature_map_summary(input_tensor=inputs, slice_nums=64, axis=-1) 241 | 242 | for index, block_nums in enumerate(self._block_sizes): 243 | output_dims = 64 * (2 ** index) 244 | 245 | inputs = self._resnet_block_layer(input_tensor=inputs, 246 | kernel_size=3, 247 | output_dims=output_dims, 248 | block_nums=block_nums, 249 | stride=self._block_strides[index], 250 | name='resnet_block_layer_{:d}'.format(index + 1)) 251 | 252 | if self._need_summary_feats_map: 253 | self._feature_map_summary(input_tensor=inputs, slice_nums=output_dims, axis=-1) 254 | 255 | inputs = self.layerbn(inputdata=inputs, is_training=self._is_training, name='bn_after_block_layer') 256 | inputs = self.relu(inputdata=inputs, name='relu_after_block_layer') 257 | 258 | inputs = tf.reduce_mean(input_tensor=inputs, axis=[1, 2], keepdims=True, name='final_reduce_mean') 259 | inputs = tf.squeeze(input=inputs, axis=[1, 2], name='final_squeeze') 260 | 261 | final_logits = self.fullyconnect(inputdata=inputs, out_dim=CFG.TRAIN.CLASSES_NUMS, 262 | use_bias=False, name='final_logits') 263 | 264 | return final_logits 265 | 266 | def compute_loss(self, input_tensor, labels, name, reuse=False): 267 | """ 268 | 269 | :param input_tensor: 270 | :param labels: 271 | :param name: 272 | :param reuse: 273 | :return: 274 | """ 275 | labels = tf.cast(labels, tf.int64) 276 | 277 | inference_logits = self.inference(input_tensor=input_tensor, 278 | name=name, 279 | reuse=reuse) 280 | 281 | cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=inference_logits, 282 | labels=labels, 283 | name='cross_entropy_per_example') 284 | cross_entropy_loss = tf.reduce_mean(cross_entropy, name='cross_entropy') 285 | 286 | l2_loss = CFG.TRAIN.WEIGHT_DECAY * tf.add_n( 287 | [tf.nn.l2_loss(tf.cast(vv, tf.float32)) for vv in tf.trainable_variables() 288 | if 'bn' not in vv.name]) 289 | 290 | total_loss = cross_entropy_loss + l2_loss 291 | 292 | return total_loss 293 | 294 | 295 | if __name__ == '__main__': 296 | """ 297 | test code 298 | """ 299 | image_tensor = tf.placeholder(shape=[16, 256, 256, 3], dtype=tf.float32) 300 | label_tensor = tf.placeholder(shape=[16], dtype=tf.int32) 301 | 302 | net = NSFWNet(phase=tf.constant('train', dtype=tf.string)) 303 | 304 | loss = net.compute_loss(input_tensor=image_tensor, 305 | labels=label_tensor, 306 | name='net', 307 | reuse=False) 308 | loss_val = net.compute_loss(input_tensor=image_tensor, 309 | labels=label_tensor, 310 | name='net', 311 | reuse=True) 312 | 313 | logits = net.inference(input_tensor=image_tensor, 314 | name='net', 315 | reuse=True) 316 | 317 | print(loss.get_shape().as_list()) 318 | print(loss_val.get_shape().as_list()) 319 | print(logits.get_shape().as_list()) 320 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | easydict==1.9 2 | matplotlib==3.0.2 3 | numpy==1.15.1 4 | opencv_python==3.4.1.15 5 | tensorflow_gpu==1.15.4 6 | glog==0.3.1 7 | scikit_learn==0.20.2 8 | 
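A minimal smoke-test sketch for the exported SavedModel shipped under model/nsfw_export_saved_model (an editor's illustration, not a script from this repo: the 'serve' tag comes from tools/convert_tfjs_model.sh, the 'classify_result' signature and the 224x224 input size from docker_container/python_client.py; the test image path is just an example, and if the exported signature fixes the batch dimension you may need to tile the image, as python_client.py does with batches of 128):

import cv2
import numpy as np
import tensorflow as tf

with tf.Session(graph=tf.Graph()) as sess:
    # load the SavedModel and look up the tensor names recorded in its signature
    meta_graph = tf.saved_model.loader.load(sess, ['serve'], 'model/nsfw_export_saved_model/1')
    signature = meta_graph.signature_def['classify_result']
    input_name = signature.inputs['input_tensor'].name
    output_name = list(signature.outputs.values())[0].name

    # preprocess one test image the same way python_client.py does
    image = cv2.imread('data/test_data/drawing_16715.jpg', cv2.IMREAD_COLOR)
    image = cv2.resize(image, (224, 224)).astype(np.float32)

    print(sess.run(output_name, feed_dict={input_name: [image]}))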
-------------------------------------------------------------------------------- /tboard/nsfw_cls/events.out.tfevents.1551264389.baidu-pc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MaybeShewill-CV/nsfw-classification-tensorflow/6dfcb16fd655e66b9dd83237bbe89e84aa5322b9/tboard/nsfw_cls/events.out.tfevents.1551264389.baidu-pc -------------------------------------------------------------------------------- /tools/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | # @Time : 19-2-14 下午5:39 4 | # @Author : Luo Yao 5 | # @Site : http://icode.baidu.com/repos/baidu/personal-code/Luoyao 6 | # @File : __init__.py.py 7 | # @IDE: PyCharm -------------------------------------------------------------------------------- /tools/convert_tfjs_model.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | tensorflowjs_converter --input_format=tf_saved_model --output_node_name=nsfw_cls_model/final_prediction \ 3 | --saved_model_tags=serve ./model/nsfw_export_saved_model ./model/nsfw_web_model -------------------------------------------------------------------------------- /tools/evaluate_nsfw.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | # @Time : 19-2-28 上午11:20 4 | # @Author : MaybeShewill-CV 5 | # @Site : https://github.com/MaybeShewill-CV/CRNN_Tensorflow 6 | # @File : evaluate_nsfw.py 7 | # @IDE: PyCharm 8 | """ 9 | Evaluate nsfw model 10 | """ 11 | import itertools 12 | import os.path as ops 13 | import argparse 14 | import glog as log 15 | 16 | import tensorflow as tf 17 | import numpy as np 18 | from sklearn.metrics import (confusion_matrix, precision_score, recall_score, 19 | precision_recall_curve, average_precision_score, f1_score) 20 | from sklearn.metrics import classification_report 21 | from sklearn.preprocessing import label_binarize 22 | from sklearn.utils.fixes import signature 23 | import matplotlib.pyplot as plt 24 | 25 | from config import global_config 26 | from nsfw_model import nsfw_classification_net 27 | from data_provider import nsfw_data_feed_pipline 28 | 29 | CFG = global_config.cfg 30 | 31 | _R_MEAN = 123.68 32 | _G_MEAN = 116.78 33 | _B_MEAN = 103.94 34 | _CHANNEL_MEANS = [_B_MEAN, _G_MEAN, _R_MEAN] 35 | 36 | 37 | def init_args(): 38 | """ 39 | 40 | :return: 41 | """ 42 | parser = argparse.ArgumentParser() 43 | parser.add_argument('--dataset_dir', type=str, default=None, help='The dataset dir') 44 | parser.add_argument('--weights_path', type=str, help='The model weights file path') 45 | parser.add_argument('--top_k', type=int, default=1, help='Evaluate top k error') 46 | 47 | return parser.parse_args() 48 | 49 | 50 | def calculate_top_k_error(predictions, labels, k=1): 51 | """ 52 | Calculate the top-k error 53 | :param predictions: 2D tensor with shape [batch_size, num_labels] 54 | :param labels: 1D tensor with shape [batch_size, 1] 55 | :param k: int 56 | :return: tensor with shape [1] 57 | """ 58 | batch_size = CFG.TEST.BATCH_SIZE 59 | in_top_k = tf.to_float(tf.nn.in_top_k(predictions, labels, k=k)) 60 | num_correct = tf.reduce_sum(in_top_k) 61 | 62 | return (batch_size - num_correct) / float(batch_size) 63 | 64 | 65 | def plot_confusion_matrix(cm, classes, 66 | normalize=False, 67 | title='Confusion matrix', 68 | cmap=plt.cm.Blues): 69 | """ 70 | This 
function prints and plots the confusion matrix.
71 |     Normalization can be applied by setting `normalize=True`.
72 |     """
73 |     if normalize:
74 |         cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
75 |         log.info("Normalized confusion matrix")
76 |     else:
77 |         log.info('Confusion matrix, without normalization')
78 | 
79 |     print(cm)
80 | 
81 |     plt.imshow(cm, interpolation='nearest', cmap=cmap)
82 |     plt.title(title)
83 |     plt.colorbar()
84 |     tick_marks = np.arange(len(classes))
85 |     plt.xticks(tick_marks, classes, rotation=45)
86 |     plt.yticks(tick_marks, classes)
87 | 
88 |     fmt = '.2f' if normalize else 'd'
89 |     thresh = cm.max() / 2.
90 |     for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
91 |         plt.text(j, i, format(cm[i, j], fmt),
92 |                  horizontalalignment="center",
93 |                  color="white" if cm[i, j] > thresh else "black")
94 | 
95 |     plt.ylabel('True label')
96 |     plt.xlabel('Predicted label')
97 |     plt.tight_layout()
98 | 
99 | 
100 | def plot_precision_recall_curve(labels, predictions_prob, class_nums, average_function='weighted'):
101 |     """
102 |     Plot precision recall curve
103 |     :param labels: ground truth label list
104 |     :param predictions_prob: per-class prediction probabilities
105 |     :param class_nums: number of classes
106 |     :param average_function: averaging method used for the summary AP score
107 |     :return:
108 |     """
109 |     labels = label_binarize(labels, classes=np.linspace(0, class_nums - 1, num=class_nums).tolist())
110 |     predictions_prob = np.array(predictions_prob, dtype=np.float32)
111 | 
112 |     precision = dict()
113 |     recall = dict()
114 |     average_precision = dict()
115 | 
116 |     for i in range(class_nums):
117 |         precision[i], recall[i], _ = precision_recall_curve(labels[:, i],
118 |                                                             predictions_prob[:, i])
119 |         average_precision[i] = average_precision_score(labels[:, i], predictions_prob[:, i])
120 | 
121 |     precision[average_function], recall[average_function], _ = precision_recall_curve(
122 |         labels.ravel(), predictions_prob.ravel())
123 |     average_precision[average_function] = average_precision_score(
124 |         labels, predictions_prob, average=average_function)
125 |     log.info('Average precision score, {:s}-averaged '
126 |              'over all classes: {:.5f}'.format(average_function, average_precision[average_function]))
127 | 
128 |     plt.figure()
129 |     plt.step(recall[average_function], precision[average_function], color='b', alpha=0.2,
130 |              where='post')
131 |     step_kwargs = ({'step': 'post'} if 'step' in signature(plt.fill_between).parameters else {})
132 |     plt.fill_between(recall[average_function], precision[average_function], alpha=0.2, color='b',
133 |                      **step_kwargs)
134 | 
135 |     plt.xlabel('Recall')
136 |     plt.ylabel('Precision')
137 |     plt.ylim([0.0, 1.05])
138 |     plt.xlim([0.0, 1.0])
139 |     plt.title(
140 |         'Average precision score, {:s}-averaged over '
141 |         'all classes: AP={:.5f}'.format(average_function, average_precision[average_function]))
142 | 
143 | 
144 | def calculate_evaluate_statics(labels, predictions, model_name='Nsfw', average_method='weighted'):
145 |     """
146 |     Calculate Precision, Recall and F1 score
147 |     :param labels: ground truth label list
148 |     :param predictions: predicted label list
149 |     :param model_name: model name used in the log output
150 |     :param average_method: averaging method passed to the sklearn metrics
151 |     :return:
152 |     """
153 |     log.info('Model name: {:s}:'.format(model_name))
154 |     log.info('\tPrecision: {:.5f}'.format(precision_score(y_true=labels,
155 |                                                           y_pred=predictions,
156 |                                                           average=average_method)))
157 |     log.info('\tRecall: {:.5f}'.format(recall_score(y_true=labels,
158 |                                                     y_pred=predictions,
159 |                                                     average=average_method)))
160 |     log.info('\tF1: {:.5f}\n'.format(f1_score(y_true=labels,
161 |                                               y_pred=predictions,
162 |                                               average=average_method)))
163 | 
164 | 
165 | def
nsfw_eval_dataset(dataset_dir, weights_path): 166 | """ 167 | Evaluate the nsfw dataset 168 | :param dataset_dir: The nsfw dataset dir which contains tensorflow records file 169 | :param weights_path: The pretrained nsfw model weights file path 170 | :return: 171 | """ 172 | assert ops.exists(dataset_dir) 173 | 174 | # set nsfw data feed pipline 175 | test_dataset = nsfw_data_feed_pipline.NsfwDataFeeder(dataset_dir=dataset_dir, 176 | flags='test') 177 | prediciton_map = test_dataset.prediction_map 178 | class_names = ['drawing', 'hentai', 'neural', 'porn', 'sexy'] 179 | 180 | with tf.device('/gpu:1'): 181 | # set nsfw classification model 182 | phase = tf.constant('test', dtype=tf.string) 183 | 184 | # set nsfw net 185 | nsfw_net = nsfw_classification_net.NSFWNet(phase=phase, 186 | resnet_size=CFG.NET.RESNET_SIZE) 187 | 188 | # compute train loss 189 | images, labels = test_dataset.inputs(batch_size=CFG.TEST.BATCH_SIZE, 190 | num_epochs=1) 191 | 192 | logits = nsfw_net.inference(input_tensor=images, 193 | name='nsfw_cls_model', 194 | reuse=False) 195 | 196 | predictions = tf.nn.softmax(logits) 197 | 198 | # Restore the moving average version of the learned variables for eval. 199 | variable_averages = tf.train.ExponentialMovingAverage( 200 | CFG.TRAIN.MOVING_AVERAGE_DECAY) 201 | variables_to_restore = variable_averages.variables_to_restore() 202 | 203 | # set tensorflow saver 204 | saver = tf.train.Saver(variables_to_restore) 205 | 206 | # Set sess configuration 207 | sess_config = tf.ConfigProto(allow_soft_placement=True) 208 | sess_config.gpu_options.per_process_gpu_memory_fraction = CFG.TEST.GPU_MEMORY_FRACTION 209 | sess_config.gpu_options.allow_growth = CFG.TEST.TF_ALLOW_GROWTH 210 | sess_config.gpu_options.allocator_type = 'BFC' 211 | 212 | sess = tf.Session(config=sess_config) 213 | 214 | # labels overall test dataset 215 | labels_total = [] 216 | # prediction result overall test dataset 217 | predictions_total = [] 218 | # prediction score overall test dataset of all subclass 219 | predictions_prob_total = [] 220 | 221 | with sess.as_default(): 222 | 223 | saver.restore(sess=sess, save_path=weights_path) 224 | 225 | while True: 226 | try: 227 | predictions_vals, labels_vals = sess.run( 228 | fetches=[predictions, 229 | labels]) 230 | 231 | log.info('**************') 232 | log.info('Test dataset batch size: {:d}'.format(predictions_vals.shape[0])) 233 | log.info('---- Sample Id ---- Gt label ---- Prediction ----') 234 | 235 | for index, predictions_val in enumerate(predictions_vals): 236 | 237 | label_gt = prediciton_map[labels_vals[index]] 238 | 239 | prediction_score = dict() 240 | 241 | for score_index, score in enumerate(predictions_val): 242 | prediction_score[prediciton_map[score_index]] = format(score, '.5f') 243 | 244 | log.info('---- {:d} ---- {:s} ---- {}'.format(index, label_gt, prediction_score)) 245 | 246 | # record predicts prob map 247 | predictions_prob_total.append(predictions_val.tolist()) 248 | 249 | # record total label and prediction results 250 | labels_total.extend(labels_vals.tolist()) 251 | predictions_total.extend(np.argmax(predictions_vals, axis=1).tolist()) 252 | 253 | except tf.errors.OutOfRangeError as err: 254 | log.info('Loop overall the test dataset') 255 | break 256 | except Exception as err: 257 | log.error(err) 258 | break 259 | 260 | # print prediction report 261 | print('Nsfw classification_report(left: labels):') 262 | print(classification_report(labels_total, predictions_total)) 263 | 264 | # calculate confusion matrix 265 | cnf_matrix = 
confusion_matrix(labels_total, predictions_total) 266 | np.set_printoptions(precision=2) 267 | plot_confusion_matrix(cnf_matrix, classes=class_names, normalize=True, 268 | title='Normalized confusion matrix') 269 | 270 | # calculate evaluate statics 271 | calculate_evaluate_statics(labels=labels_total, predictions=predictions_total) 272 | 273 | # plot precision recall curve 274 | plot_precision_recall_curve(labels=labels_total, 275 | predictions_prob=predictions_prob_total, 276 | class_nums=5) 277 | plt.show() 278 | 279 | return 280 | 281 | 282 | if __name__ == '__main__': 283 | # init args 284 | args = init_args() 285 | 286 | # test net 287 | assert ops.exists(args.dataset_dir) 288 | 289 | nsfw_eval_dataset(args.dataset_dir, args.weights_path) 290 | -------------------------------------------------------------------------------- /tools/export_nsfw_saved_model.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | python tools/export_saved_model.py --ckpt_path model/nsfw_cls/nsfw_cls_2019-02-27-18-46-28.ckpt-160000 \ 4 | --export_dir ./model/nsfw_export_saved_model -------------------------------------------------------------------------------- /tools/export_saved_model.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | # @Time : 19-2-22 上午11:10 4 | # @Author : MaybeShewill-CV 5 | # @Site : https://github.com/MaybeShewill-CV/CRNN_Tensorflow 6 | # @File : export_saved_model.py 7 | # @IDE: PyCharm 8 | """ 9 | Build tensorflow saved model for tensorflowjs converter to use 10 | """ 11 | import os.path as ops 12 | import argparse 13 | import glog as log 14 | 15 | import cv2 16 | import numpy as np 17 | import matplotlib.pyplot as plt 18 | import tensorflow as tf 19 | from tensorflow import saved_model as sm 20 | 21 | from config import global_config 22 | from nsfw_model import nsfw_classification_net 23 | 24 | CFG = global_config.cfg 25 | 26 | _R_MEAN = 123.68 27 | _G_MEAN = 116.78 28 | _B_MEAN = 103.94 29 | _CHANNEL_MEANS = [_B_MEAN, _G_MEAN, _R_MEAN] 30 | 31 | 32 | def init_args(): 33 | """ 34 | 35 | :return: 36 | """ 37 | parser = argparse.ArgumentParser() 38 | parser.add_argument('--export_dir', type=str, help='The model export dir') 39 | parser.add_argument('--ckpt_path', type=str, help='The pretrained ckpt model weights file path') 40 | 41 | return parser.parse_args() 42 | 43 | 44 | def central_crop(image, central_fraction): 45 | """ 46 | 47 | :param image: 48 | :param central_fraction: 49 | :return: 50 | """ 51 | 52 | image_shape = np.shape(image) 53 | image_height = image_shape[0] 54 | image_width = image_shape[1] 55 | 56 | if central_fraction >= 1 or central_fraction < 0: 57 | raise ValueError('Central fraction should be in [0, 1)') 58 | 59 | top = int((1 - central_fraction) * image_height / 2) 60 | bottom = image_height - top 61 | left = int((1 - central_fraction) * image_width / 2) 62 | right = image_width - left 63 | 64 | if not image_height or not image_width: 65 | raise ValueError('Image shape with zero') 66 | 67 | if len(image_shape) == 2: 68 | crop_image = image[top:bottom, left:right] 69 | return crop_image 70 | elif len(image_shape) == 3: 71 | crop_image = image[top:bottom, left:right, :] 72 | return crop_image 73 | else: 74 | raise ValueError('Wrong image shape') 75 | 76 | 77 | def build_saved_model(ckpt_path, export_dir): 78 | """ 79 | Convert source ckpt weights file into tensorflow saved model 80 | :param ckpt_path: 81 | 
:param export_dir: 82 | :return: 83 | """ 84 | 85 | if ops.exists(export_dir): 86 | raise ValueError('Export dir must be a dir path that does not exist') 87 | 88 | assert ops.exists(ops.split(ckpt_path)[0]) 89 | 90 | # build inference tensorflow graph 91 | image_tensor = tf.placeholder(dtype=tf.float32, 92 | shape=[1, CFG.TRAIN.CROP_IMG_HEIGHT, CFG.TRAIN.CROP_IMG_WIDTH, 3], 93 | name='input_tensor') 94 | # set nsfw net 95 | phase = tf.constant('test', dtype=tf.string) 96 | nsfw_net = nsfw_classification_net.NSFWNet(phase=phase, 97 | resnet_size=CFG.NET.RESNET_SIZE) 98 | 99 | # compute inference logits 100 | logits = nsfw_net.inference(input_tensor=image_tensor, 101 | name='nsfw_cls_model', 102 | reuse=False) 103 | 104 | predictions = tf.nn.softmax(logits, name='nsfw_cls_model/final_prediction') 105 | 106 | # Restore the moving average version of the learned variables for eval. 107 | variable_averages = tf.train.ExponentialMovingAverage( 108 | CFG.TRAIN.MOVING_AVERAGE_DECAY) 109 | variables_to_restore = variable_averages.variables_to_restore() 110 | 111 | # set tensorflow saver 112 | saver = tf.train.Saver(variables_to_restore) 113 | 114 | # Set sess configuration 115 | sess_config = tf.ConfigProto(allow_soft_placement=True) 116 | sess_config.gpu_options.per_process_gpu_memory_fraction = CFG.TRAIN.GPU_MEMORY_FRACTION 117 | sess_config.gpu_options.allow_growth = CFG.TRAIN.TF_ALLOW_GROWTH 118 | sess_config.gpu_options.allocator_type = 'BFC' 119 | 120 | sess = tf.Session(config=sess_config) 121 | 122 | with sess.as_default(): 123 | 124 | saver.restore(sess=sess, save_path=ckpt_path) 125 | 126 | # set model save builder 127 | saved_builder = sm.builder.SavedModelBuilder(export_dir) 128 | 129 | # add tensor need to be saved 130 | saved_input_tensor = sm.utils.build_tensor_info(image_tensor) 131 | saved_prediction_tensor = sm.utils.build_tensor_info(predictions) 132 | 133 | # build SignatureDef protobuf 134 | signatur_def = sm.signature_def_utils.build_signature_def( 135 | inputs={'input_tensor': saved_input_tensor}, 136 | outputs={'prediction': saved_prediction_tensor}, 137 | method_name=tf.saved_model.signature_constants.CLASSIFY_METHOD_NAME 138 | ) 139 | 140 | # add graph into MetaGraphDef protobuf 141 | saved_builder.add_meta_graph_and_variables( 142 | sess, 143 | tags=[sm.tag_constants.SERVING], 144 | signature_def_map={'classify_result': signatur_def} 145 | ) 146 | 147 | # save model 148 | saved_builder.save() 149 | 150 | return 151 | 152 | 153 | def test_load_saved_model(saved_model_dir): 154 | """ 155 | 156 | :param saved_model_dir: 157 | :return: 158 | """ 159 | 160 | prediciton_map = global_config.NSFW_PREDICT_MAP 161 | 162 | image = cv2.imread('data/test_data/drawing_16715.jpg', cv2.IMREAD_COLOR) 163 | image_vis = image 164 | image = cv2.resize(src=image, 165 | dsize=(CFG.TRAIN.IMG_WIDTH, CFG.TRAIN.IMG_HEIGHT), 166 | interpolation=cv2.INTER_CUBIC) 167 | image = central_crop(image=image, 168 | central_fraction=CFG.TRAIN.CROP_IMG_HEIGHT / CFG.TRAIN.IMG_HEIGHT) 169 | image = np.array(image, dtype=np.float32) - np.array(_CHANNEL_MEANS, np.float32) 170 | image = np.expand_dims(image, 0) 171 | 172 | # Set sess configuration 173 | sess_config = tf.ConfigProto(allow_soft_placement=True) 174 | sess_config.gpu_options.per_process_gpu_memory_fraction = CFG.TRAIN.GPU_MEMORY_FRACTION 175 | sess_config.gpu_options.allow_growth = CFG.TRAIN.TF_ALLOW_GROWTH 176 | sess_config.gpu_options.allocator_type = 'BFC' 177 | 178 | sess = tf.Session(config=sess_config) 179 | 180 | with sess.as_default(): 181 | 
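        # The block below is a round-trip sanity check on the export: load the
        # SavedModel under the SERVING tag, look up the 'classify_result'
        # SignatureDef written by build_saved_model() above, and recover the
        # concrete input/output tensors from the restored graph.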
182 | meta_graphdef = sm.loader.load( 183 | sess, 184 | tags=[sm.tag_constants.SERVING], 185 | export_dir=saved_model_dir) 186 | 187 | signature_def_d = meta_graphdef.signature_def 188 | signature_def_d = signature_def_d['classify_result'] 189 | 190 | image_input_tensor = signature_def_d.inputs['input_tensor'] 191 | prediction_tensor = signature_def_d.outputs['prediction'] 192 | 193 | input_tensor = sm.utils.get_tensor_from_tensor_info(image_input_tensor, sess.graph) 194 | predictions = sm.utils.get_tensor_from_tensor_info(prediction_tensor, sess.graph) 195 | 196 | prediction_val = sess.run(predictions, feed_dict={input_tensor: image}) 197 | 198 | prediction_score = dict() 199 | 200 | for score_index, score in enumerate(prediction_val[0]): 201 | prediction_score[prediciton_map[score_index]] = format(score, '.5f') 202 | 203 | log.info('Predict result: {}'.format(prediction_score)) 204 | 205 | plt.figure('source image') 206 | plt.imshow(image_vis[:, :, (2, 1, 0)]) 207 | 208 | 209 | if __name__ == '__main__': 210 | """ 211 | build saved model 212 | """ 213 | # init args 214 | args = init_args() 215 | 216 | # build saved model 217 | build_saved_model(args.ckpt_path, args.export_dir) 218 | 219 | # test build saved model 220 | test_load_saved_model(args.export_dir) 221 | -------------------------------------------------------------------------------- /tools/make_nsfw_dataset.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # split the nsfw dataset and convert them into tfrecords 4 | python nsfw_classification/data_provider/nsfw_data_feed_pipline.py \ 5 | --dataset_dir nsfw_classification/data/nsfw_dataset_example \ 6 | --tfrecords_dir nsfw_classification/data/nsfw_dataset_example/tfrecords 7 | -------------------------------------------------------------------------------- /tools/test_nsfw.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | # @Time : 19-2-20 上午10:48 4 | # @Author : MaybeShewill-CV 5 | # @Site : https://github.com/MaybeShewill-CV/CRNN_Tensorflow 6 | # @File : test_nsfw.py.py 7 | # @IDE: PyCharm 8 | """ 9 | Test nsfw model script 10 | """ 11 | import os.path as ops 12 | import argparse 13 | import glog as log 14 | 15 | import tensorflow as tf 16 | import numpy as np 17 | import cv2 18 | import matplotlib.pyplot as plt 19 | 20 | from config import global_config 21 | from nsfw_model import nsfw_classification_net 22 | from data_provider import nsfw_data_feed_pipline 23 | 24 | 25 | CFG = global_config.cfg 26 | 27 | _R_MEAN = 123.68 28 | _G_MEAN = 116.78 29 | _B_MEAN = 103.94 30 | _CHANNEL_MEANS = [_B_MEAN, _G_MEAN, _R_MEAN] 31 | 32 | 33 | def init_args(): 34 | """ 35 | 36 | :return: 37 | """ 38 | parser = argparse.ArgumentParser() 39 | parser.add_argument('--image_path', type=str, default=None, help='The image path') 40 | parser.add_argument('--weights_path', type=str, help='The model weights file path') 41 | 42 | return parser.parse_args() 43 | 44 | 45 | def central_crop(image, central_fraction): 46 | """ 47 | 48 | :param image: 49 | :param central_fraction: 50 | :return: 51 | """ 52 | 53 | image_shape = np.shape(image) 54 | image_height = image_shape[0] 55 | image_width = image_shape[1] 56 | 57 | if central_fraction >= 1 or central_fraction < 0: 58 | raise ValueError('Central fraction should be in [0, 1)') 59 | 60 | top = int((1 - central_fraction) * image_height / 2) 61 | bottom = image_height - top 62 | left = int((1 - 
central_fraction) * image_width / 2) 63 | right = image_width - left 64 | 65 | if not image_height or not image_width: 66 | raise ValueError('Image shape with zero') 67 | 68 | if len(image_shape) == 2: 69 | crop_image = image[top:bottom, left:right] 70 | return crop_image 71 | elif len(image_shape) == 3: 72 | crop_image = image[top:bottom, left:right, :] 73 | return crop_image 74 | else: 75 | raise ValueError('Wrong image shape') 76 | 77 | 78 | def calculate_top_k_error(predictions, labels, k=1): 79 | """ 80 | Calculate the top-k error 81 | :param predictions: 2D tensor with shape [batch_size, num_labels] 82 | :param labels: 1D tensor with shape [batch_size, 1] 83 | :param k: int 84 | :return: tensor with shape [1] 85 | """ 86 | batch_size = CFG.TEST.BATCH_SIZE 87 | in_top_k = tf.to_float(tf.nn.in_top_k(predictions, labels, k=k)) 88 | num_correct = tf.reduce_sum(in_top_k) 89 | 90 | return (batch_size - num_correct) / float(batch_size) 91 | 92 | 93 | def nsfw_classify_image(image_path, weights_path): 94 | """ 95 | Use nsfw model to classify a single image 96 | :param image_path: The image file path 97 | :param weights_path: The pretrained weights file path 98 | :return: 99 | """ 100 | assert ops.exists(image_path) 101 | 102 | prediciton_map = global_config.NSFW_PREDICT_MAP 103 | 104 | with tf.device('/gpu:1'): 105 | # set nsfw classification model 106 | 107 | image_tensor = tf.placeholder(dtype=tf.float32, 108 | shape=[1, CFG.TRAIN.CROP_IMG_HEIGHT, CFG.TRAIN.CROP_IMG_WIDTH, 3], 109 | name='input_tensor') 110 | # set nsfw net 111 | phase = tf.constant('test', dtype=tf.string) 112 | nsfw_net = nsfw_classification_net.NSFWNet(phase=phase, 113 | resnet_size=CFG.NET.RESNET_SIZE) 114 | 115 | # compute inference logits 116 | logits = nsfw_net.inference(input_tensor=image_tensor, 117 | name='nsfw_cls_model', 118 | reuse=False) 119 | 120 | predictions = tf.nn.softmax(logits) 121 | 122 | # Restore the moving average version of the learned variables for eval. 
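        # variables_to_restore() maps each variable to its exponentially
        # averaged shadow copy, so the Saver below loads the smoothed weights
        # tracked during training rather than the raw final values.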
123 | variable_averages = tf.train.ExponentialMovingAverage( 124 | CFG.TRAIN.MOVING_AVERAGE_DECAY) 125 | variables_to_restore = variable_averages.variables_to_restore() 126 | 127 | # set tensorflow saver 128 | saver = tf.train.Saver(variables_to_restore) 129 | 130 | # Set sess configuration 131 | sess_config = tf.ConfigProto(allow_soft_placement=True) 132 | sess_config.gpu_options.per_process_gpu_memory_fraction = CFG.TRAIN.GPU_MEMORY_FRACTION 133 | sess_config.gpu_options.allow_growth = CFG.TRAIN.TF_ALLOW_GROWTH 134 | sess_config.gpu_options.allocator_type = 'BFC' 135 | 136 | sess = tf.Session(config=sess_config) 137 | 138 | with sess.as_default(): 139 | 140 | saver.restore(sess=sess, save_path=weights_path) 141 | 142 | image = cv2.imread(image_path, cv2.IMREAD_COLOR) 143 | image_vis = image 144 | image = cv2.resize(src=image, 145 | dsize=(CFG.TRAIN.IMG_WIDTH, CFG.TRAIN.IMG_HEIGHT), 146 | interpolation=cv2.INTER_CUBIC) 147 | image = central_crop(image=image, 148 | central_fraction=CFG.TRAIN.CROP_IMG_HEIGHT / CFG.TRAIN.IMG_HEIGHT) 149 | image = np.array(image, dtype=np.float32) - np.array(_CHANNEL_MEANS, np.float32) 150 | 151 | predictions_vals = sess.run( 152 | fetches=predictions, 153 | feed_dict={image_tensor: [image]}) 154 | 155 | prediction_score = dict() 156 | 157 | for score_index, score in enumerate(predictions_vals[0]): 158 | prediction_score[prediciton_map[score_index]] = format(score, '.5f') 159 | 160 | log.info('Predict result is: {}'.format(prediction_score)) 161 | 162 | plt.figure('source image') 163 | plt.imshow(image_vis[:, :, (2, 1, 0)]) 164 | plt.show() 165 | 166 | return 167 | 168 | 169 | if __name__ == '__main__': 170 | # init args 171 | args = init_args() 172 | 173 | # test net 174 | assert ops.exists(args.image_path) 175 | 176 | nsfw_classify_image(args.image_path, args.weights_path) 177 | -------------------------------------------------------------------------------- /tools/train_nsfw.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | # @Time : 19-2-15 下午8:53 4 | # @Author : MaybeShewill-CV 5 | # @Site : https://github.com/MaybeShewill-CV/CRNN_Tensorflow 6 | # @File : train_nsfw.py 7 | # @IDE: PyCharm 8 | """ 9 | Train nsfw model script 10 | """ 11 | import argparse 12 | import os 13 | import os.path as ops 14 | import time 15 | import math 16 | 17 | import numpy as np 18 | import tensorflow as tf 19 | import glog as log 20 | 21 | from config import global_config 22 | from data_provider import nsfw_data_feed_pipline 23 | from nsfw_model import nsfw_classification_net 24 | 25 | 26 | CFG = global_config.cfg 27 | 28 | 29 | def init_args(): 30 | """ 31 | 32 | :return: 33 | """ 34 | parser = argparse.ArgumentParser() 35 | 36 | parser.add_argument('--dataset_dir', type=str, help='The dataset_dir') 37 | parser.add_argument('--use_multi_gpu', type=bool, default=False, help='If use multiple gpu devices') 38 | parser.add_argument('--weights_path', type=str, default=None, help='The pretrained weights path') 39 | 40 | return parser.parse_args() 41 | 42 | 43 | def calculate_top_k_error(predictions, labels, k=1): 44 | """ 45 | Calculate the top-k error 46 | :param predictions: 2D tensor with shape [batch_size, num_labels] 47 | :param labels: 1D tensor with shape [batch_size, 1] 48 | :param k: int 49 | :return: tensor with shape [1] 50 | """ 51 | batch_size = CFG.TRAIN.BATCH_SIZE 52 | in_top1 = tf.to_float(tf.nn.in_top_k(predictions, labels, k=k)) 53 | num_correct = tf.reduce_sum(in_top1) 
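    # in_top_k flags each sample whose ground-truth label falls inside the k
    # highest-scoring predictions; the top-k error is simply the complement:
    # (batch_size - num_correct) / batch_size.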
54 | 55 | return (batch_size - num_correct) / float(batch_size) 56 | 57 | 58 | def average_gradients(tower_grads): 59 | """Calculate the average gradient for each shared variable across all towers. 60 | Note that this function provides a synchronization point across all towers. 61 | Args: 62 | tower_grads: List of lists of (gradient, variable) tuples. The outer list 63 | is over individual gradients. The inner list is over the gradient 64 | calculation for each tower. 65 | Returns: 66 | List of pairs of (gradient, variable) where the gradient has been averaged 67 | across all towers. 68 | """ 69 | average_grads = [] 70 | for grad_and_vars in zip(*tower_grads): 71 | # Note that each grad_and_vars looks like the following: 72 | # ((grad0_gpu0, var0_gpu0), ... , (grad0_gpuN, var0_gpuN)) 73 | grads = [] 74 | for g, _ in grad_and_vars: 75 | # Add 0 dimension to the gradients to represent the tower. 76 | expanded_g = tf.expand_dims(g, 0) 77 | 78 | # Append on a 'tower' dimension which we will average over below. 79 | grads.append(expanded_g) 80 | 81 | # Average over the 'tower' dimension. 82 | grad = tf.concat(grads, 0) 83 | grad = tf.reduce_mean(grad, 0) 84 | 85 | # Keep in mind that the Variables are redundant because they are shared 86 | # across towers. So .. we will just return the first tower's pointer to 87 | # the Variable. 88 | v = grad_and_vars[0][1] 89 | grad_and_var = (grad, v) 90 | average_grads.append(grad_and_var) 91 | 92 | return average_grads 93 | 94 | 95 | def compute_net_gradients(images, labels, net, optimizer=None, is_net_first_initialized=False): 96 | """ 97 | Calculate gradients for single GPU 98 | :param images: images for training 99 | :param labels: labels corresponding to images 100 | :param net: classification model 101 | :param optimizer: network optimizer 102 | :param is_net_first_initialized: if the network is initialized 103 | :return: 104 | """ 105 | net_loss = net.compute_loss(input_tensor=images, 106 | labels=labels, 107 | name='nsfw_cls_model', 108 | reuse=is_net_first_initialized) 109 | net_logits = net.inference(input_tensor=images, 110 | name='nsfw_cls_model', 111 | reuse=True) 112 | 113 | net_predictions = tf.nn.softmax(net_logits) 114 | net_top1_error = calculate_top_k_error(net_predictions, labels, 1) 115 | 116 | if optimizer is not None: 117 | grads = optimizer.compute_gradients(net_loss) 118 | else: 119 | grads = None 120 | 121 | return net_loss, net_top1_error, grads 122 | 123 | 124 | def train_net(dataset_dir, weights_path=None): 125 | """ 126 | 127 | :param dataset_dir: 128 | :param weights_path: 129 | :return: 130 | """ 131 | 132 | # set nsfw data feed pipline 133 | train_dataset = nsfw_data_feed_pipline.NsfwDataFeeder(dataset_dir=dataset_dir, 134 | flags='train') 135 | val_dataset = nsfw_data_feed_pipline.NsfwDataFeeder(dataset_dir=dataset_dir, 136 | flags='val') 137 | 138 | with tf.device('/gpu:1'): 139 | # set nsfw net 140 | nsfw_net = nsfw_classification_net.NSFWNet(phase=tf.constant('train', dtype=tf.string), 141 | resnet_size=CFG.NET.RESNET_SIZE) 142 | nsfw_net_val = nsfw_classification_net.NSFWNet(phase=tf.constant('test', dtype=tf.string), 143 | resnet_size=CFG.NET.RESNET_SIZE) 144 | 145 | # compute train loss 146 | train_images, train_labels = train_dataset.inputs(batch_size=CFG.TRAIN.BATCH_SIZE, 147 | num_epochs=1) 148 | train_loss = nsfw_net.compute_loss(input_tensor=train_images, 149 | labels=train_labels, 150 | name='nsfw_cls_model', 151 | reuse=False) 152 | 153 | train_logits = nsfw_net.inference(input_tensor=train_images, 154 | 
name='nsfw_cls_model', 155 | reuse=True) 156 | 157 | train_predictions = tf.nn.softmax(train_logits) 158 | train_top1_error = calculate_top_k_error(train_predictions, train_labels, 1) 159 | 160 | # compute val loss 161 | val_images, val_labels = val_dataset.inputs(batch_size=CFG.TRAIN.VAL_BATCH_SIZE, 162 | num_epochs=1) 163 | # val_images = tf.reshape(val_images, example_tensor_shape) 164 | val_loss = nsfw_net_val.compute_loss(input_tensor=val_images, 165 | labels=val_labels, 166 | name='nsfw_cls_model', 167 | reuse=True) 168 | 169 | val_logits = nsfw_net_val.inference(input_tensor=val_images, 170 | name='nsfw_cls_model', 171 | reuse=True) 172 | 173 | val_predictions = tf.nn.softmax(val_logits) 174 | val_top1_error = calculate_top_k_error(val_predictions, val_labels, 1) 175 | 176 | # set tensorflow summary 177 | tboard_save_path = 'tboard/nsfw_cls' 178 | os.makedirs(tboard_save_path, exist_ok=True) 179 | 180 | summary_writer = tf.summary.FileWriter(tboard_save_path) 181 | 182 | train_loss_scalar = tf.summary.scalar(name='train_loss', 183 | tensor=train_loss) 184 | train_top1_err_scalar = tf.summary.scalar(name='train_top1_error', 185 | tensor=train_top1_error) 186 | val_loss_scalar = tf.summary.scalar(name='val_loss', 187 | tensor=val_loss) 188 | val_top1_err_scalar = tf.summary.scalar(name='val_top1_error', 189 | tensor=val_top1_error) 190 | 191 | train_merge_summary_op = tf.summary.merge([train_loss_scalar, train_top1_err_scalar]) 192 | 193 | val_merge_summary_op = tf.summary.merge([val_loss_scalar, val_top1_err_scalar]) 194 | 195 | # Set tf saver 196 | saver = tf.train.Saver() 197 | model_save_dir = 'model/nsfw_cls' 198 | os.makedirs(model_save_dir, exist_ok=True) 199 | train_start_time = time.strftime('%Y-%m-%d-%H-%M-%S', time.localtime(time.time())) 200 | model_name = 'nsfw_cls_{:s}.ckpt'.format(str(train_start_time)) 201 | model_save_path = ops.join(model_save_dir, model_name) 202 | 203 | # set optimizer 204 | with tf.device('/gpu:1'): 205 | # set learning rate 206 | global_step = tf.Variable(0, trainable=False) 207 | decay_steps = [CFG.TRAIN.LR_DECAY_STEPS_1, CFG.TRAIN.LR_DECAY_STEPS_2] 208 | decay_values = [] 209 | init_lr = CFG.TRAIN.LEARNING_RATE 210 | for step in range(len(decay_steps) + 1): 211 | decay_values.append(init_lr) 212 | init_lr = init_lr * CFG.TRAIN.LR_DECAY_RATE 213 | 214 | learning_rate = tf.train.piecewise_constant( 215 | x=global_step, 216 | boundaries=decay_steps, 217 | values=decay_values, 218 | name='learning_rate' 219 | ) 220 | 221 | update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS) 222 | with tf.control_dependencies(update_ops): 223 | optimizer = tf.train.MomentumOptimizer( 224 | learning_rate=learning_rate, momentum=0.9).minimize( 225 | loss=train_loss, 226 | var_list=tf.trainable_variables(), 227 | global_step=global_step) 228 | 229 | # Set sess configuration 230 | sess_config = tf.ConfigProto(allow_soft_placement=True) 231 | sess_config.gpu_options.per_process_gpu_memory_fraction = CFG.TRAIN.GPU_MEMORY_FRACTION 232 | sess_config.gpu_options.allow_growth = CFG.TRAIN.TF_ALLOW_GROWTH 233 | sess_config.gpu_options.allocator_type = 'BFC' 234 | 235 | sess = tf.Session(config=sess_config) 236 | 237 | summary_writer.add_graph(sess.graph) 238 | 239 | # Set the training parameters 240 | train_epochs = CFG.TRAIN.EPOCHS 241 | 242 | log.info('Global configuration is as follows:') 243 | log.info(CFG) 244 | 245 | with sess.as_default(): 246 | 247 | tf.train.write_graph(graph_or_graph_def=sess.graph, logdir='', 248 | 
name='{:s}/nsfw_cls_model.pb'.format(model_save_dir)) 249 | 250 | if weights_path is None: 251 | log.info('Training from scratch') 252 | init = tf.global_variables_initializer() 253 | sess.run(init) 254 | else: 255 | log.info('Restore model from last model checkpoint {:s}'.format(weights_path)) 256 | saver.restore(sess=sess, save_path=weights_path) 257 | 258 | train_cost_time_mean = [] 259 | val_cost_time_mean = [] 260 | 261 | for epoch in range(train_epochs): 262 | 263 | # training part 264 | t_start = time.time() 265 | 266 | _, train_loss_value, train_top1_err_value, train_summary, lr = \ 267 | sess.run(fetches=[optimizer, 268 | train_loss, 269 | train_top1_error, 270 | train_merge_summary_op, 271 | learning_rate]) 272 | 273 | if math.isnan(train_loss_value): 274 | log.error('Train loss is nan') 275 | return 276 | 277 | cost_time = time.time() - t_start 278 | train_cost_time_mean.append(cost_time) 279 | 280 | summary_writer.add_summary(summary=train_summary, 281 | global_step=epoch) 282 | 283 | # validation part 284 | t_start_val = time.time() 285 | 286 | val_loss_value, val_top1_err_value, val_summary = \ 287 | sess.run(fetches=[val_loss, 288 | val_top1_error, 289 | val_merge_summary_op]) 290 | 291 | summary_writer.add_summary(val_summary, global_step=epoch) 292 | 293 | cost_time_val = time.time() - t_start_val 294 | val_cost_time_mean.append(cost_time_val) 295 | 296 | if epoch % CFG.TRAIN.DISPLAY_STEP == 0: 297 | log.info('Epoch_Train: {:d} total_loss= {:6f} top1_error= {:6f} ' 298 | 'lr= {:6f} mean_cost_time= {:5f}s '. 299 | format(epoch + 1, 300 | train_loss_value, 301 | train_top1_err_value, 302 | lr, 303 | np.mean(train_cost_time_mean))) 304 | train_cost_time_mean.clear() 305 | 306 | if epoch % CFG.TRAIN.VAL_DISPLAY_STEP == 0: 307 | log.info('Epoch_Val: {:d} total_loss= {:6f} top1_error= {:6f}' 308 | ' mean_cost_time= {:5f}s '. 
309 | format(epoch + 1, 310 | val_loss_value, 311 | val_top1_err_value, 312 | np.mean(val_cost_time_mean))) 313 | val_cost_time_mean.clear() 314 | 315 | if epoch % 2000 == 0: 316 | saver.save(sess=sess, save_path=model_save_path, global_step=epoch) 317 | sess.close() 318 | 319 | return 320 | 321 | 322 | def train_net_multi_gpu(dataset_dir, weights_path=None): 323 | """ 324 | 325 | :param dataset_dir: 326 | :param weights_path: 327 | :return: 328 | """ 329 | # set nsfw data feed pipline 330 | train_dataset = nsfw_data_feed_pipline.NsfwDataFeeder(dataset_dir=dataset_dir, 331 | flags='train') 332 | val_dataset = nsfw_data_feed_pipline.NsfwDataFeeder(dataset_dir=dataset_dir, 333 | flags='val') 334 | 335 | # set nsfw net 336 | nsfw_net = nsfw_classification_net.NSFWNet(phase=tf.constant('train', dtype=tf.string), 337 | resnet_size=CFG.NET.RESNET_SIZE) 338 | nsfw_net_val = nsfw_classification_net.NSFWNet(phase=tf.constant('test', dtype=tf.string), 339 | resnet_size=CFG.NET.RESNET_SIZE) 340 | 341 | # fetch train and validation data 342 | train_images, train_labels = train_dataset.inputs( 343 | batch_size=CFG.TRAIN.BATCH_SIZE, num_epochs=1) 344 | val_images, val_labels = val_dataset.inputs( 345 | batch_size=CFG.TRAIN.BATCH_SIZE, num_epochs=1) 346 | 347 | # set average container 348 | tower_grads = [] 349 | train_tower_loss = [] 350 | train_tower_top1_error = [] 351 | val_tower_loss = [] 352 | val_tower_top1_error = [] 353 | batchnorm_updates = None 354 | train_summary_op_updates = None 355 | 356 | # set learning rate 357 | global_step = tf.Variable(0, trainable=False) 358 | decay_steps = [CFG.TRAIN.LR_DECAY_STEPS_1, CFG.TRAIN.LR_DECAY_STEPS_2] 359 | decay_values = [] 360 | init_lr = CFG.TRAIN.LEARNING_RATE 361 | for step in range(len(decay_steps) + 1): 362 | decay_values.append(init_lr) 363 | init_lr = init_lr * CFG.TRAIN.LR_DECAY_RATE 364 | 365 | learning_rate = tf.train.piecewise_constant( 366 | x=global_step, 367 | boundaries=decay_steps, 368 | values=decay_values, 369 | name='learning_rate' 370 | ) 371 | 372 | # set optimizer 373 | optimizer = tf.train.MomentumOptimizer(learning_rate=learning_rate, momentum=0.9) 374 | 375 | # set distributed train op 376 | with tf.variable_scope(tf.get_variable_scope()): 377 | is_network_initialized = False 378 | for i in range(CFG.TRAIN.GPU_NUM): 379 | with tf.device('/gpu:{:d}'.format(i)): 380 | with tf.name_scope('tower_{:d}'.format(i)) as scope: 381 | train_loss, train_top1_error, grads = compute_net_gradients( 382 | train_images, train_labels, nsfw_net, optimizer, 383 | is_net_first_initialized=is_network_initialized) 384 | 385 | is_network_initialized = True 386 | 387 | # Only use the mean and var in the first gpu tower to update the parameter 388 | # TODO implement batch normalization for distributed device (luoyao@baidu.com) 389 | if i == 0: 390 | batchnorm_updates = tf.get_collection(tf.GraphKeys.UPDATE_OPS) 391 | train_summary_op_updates = tf.get_collection(tf.GraphKeys.SUMMARIES) 392 | 393 | tower_grads.append(grads) 394 | train_tower_loss.append(train_loss) 395 | train_tower_top1_error.append(train_top1_error) 396 | with tf.name_scope('validation_{:d}'.format(i)) as scope: 397 | val_loss, val_top1_error, _ = compute_net_gradients( 398 | val_images, val_labels, nsfw_net_val, optimizer, 399 | is_net_first_initialized=is_network_initialized) 400 | val_tower_loss.append(val_loss) 401 | val_tower_top1_error.append(val_top1_error) 402 | 403 | grads = average_gradients(tower_grads) 404 | avg_train_loss = tf.reduce_mean(train_tower_loss) 405 | 
avg_train_top1_error = tf.reduce_mean(train_tower_top1_error) 406 | avg_val_loss = tf.reduce_mean(val_tower_loss) 407 | avg_val_top1_error = tf.reduce_mean(val_tower_top1_error) 408 | 409 | # Track the moving averages of all trainable variables 410 | variable_averages = tf.train.ExponentialMovingAverage( 411 | CFG.TRAIN.MOVING_AVERAGE_DECAY, num_updates=global_step) 412 | variables_to_average = tf.trainable_variables() + tf.moving_average_variables() 413 | variables_averages_op = variable_averages.apply(variables_to_average) 414 | 415 | # Group all the op needed for training 416 | batchnorm_updates_op = tf.group(*batchnorm_updates) 417 | apply_gradient_op = optimizer.apply_gradients(grads, global_step=global_step) 418 | train_op = tf.group(apply_gradient_op, variables_averages_op, 419 | batchnorm_updates_op) 420 | 421 | # set tensorflow summary 422 | tboard_save_path = 'tboard/nsfw_cls' 423 | os.makedirs(tboard_save_path, exist_ok=True) 424 | 425 | summary_writer = tf.summary.FileWriter(tboard_save_path) 426 | 427 | avg_train_loss_scalar = tf.summary.scalar(name='average_train_loss', 428 | tensor=avg_train_loss) 429 | avg_train_top1_err_scalar = tf.summary.scalar(name='average_train_top1_error', 430 | tensor=avg_train_top1_error) 431 | avg_val_loss_scalar = tf.summary.scalar(name='average_val_loss', 432 | tensor=avg_val_loss) 433 | avg_val_top1_err_scalar = tf.summary.scalar(name='average_val_top1_error', 434 | tensor=avg_val_top1_error) 435 | learning_rate_scalar = tf.summary.scalar(name='learning_rate_scalar', 436 | tensor=learning_rate) 437 | 438 | train_merge_summary_op = tf.summary.merge([avg_train_loss_scalar, 439 | avg_train_top1_err_scalar, 440 | learning_rate_scalar] + train_summary_op_updates) 441 | 442 | val_merge_summary_op = tf.summary.merge([avg_val_loss_scalar, avg_val_top1_err_scalar]) 443 | 444 | # set tensorflow saver 445 | saver = tf.train.Saver() 446 | model_save_dir = 'model/nsfw_cls' 447 | os.makedirs(model_save_dir, exist_ok=True) 448 | train_start_time = time.strftime('%Y-%m-%d-%H-%M-%S', time.localtime(time.time())) 449 | model_name = 'nsfw_cls_{:s}.ckpt'.format(str(train_start_time)) 450 | model_save_path = ops.join(model_save_dir, model_name) 451 | 452 | # set sess config 453 | sess_config = tf.ConfigProto(device_count={'GPU': CFG.TRAIN.GPU_NUM}, allow_soft_placement=True) 454 | sess_config.gpu_options.per_process_gpu_memory_fraction = CFG.TRAIN.GPU_MEMORY_FRACTION 455 | sess_config.gpu_options.allow_growth = CFG.TRAIN.TF_ALLOW_GROWTH 456 | sess_config.gpu_options.allocator_type = 'BFC' 457 | 458 | # Set the training parameters 459 | train_epochs = CFG.TRAIN.EPOCHS 460 | 461 | log.info('Global configuration is as follows:') 462 | log.info(CFG) 463 | 464 | sess = tf.Session(config=sess_config) 465 | 466 | summary_writer.add_graph(sess.graph) 467 | 468 | with sess.as_default(): 469 | 470 | tf.train.write_graph(graph_or_graph_def=sess.graph, logdir='', 471 | name='{:s}/nsfw_cls_model.pb'.format(model_save_dir)) 472 | 473 | if weights_path is None: 474 | log.info('Training from scratch') 475 | init = tf.global_variables_initializer() 476 | sess.run(init) 477 | else: 478 | log.info('Restore model from last model checkpoint {:s}'.format(weights_path)) 479 | saver.restore(sess=sess, save_path=weights_path) 480 | 481 | train_cost_time_mean = [] 482 | val_cost_time_mean = [] 483 | 484 | for epoch in range(train_epochs): 485 | 486 | # training part 487 | t_start = time.time() 488 | 489 | _, train_loss_value, train_top1_err_value, train_summary, lr = \ 490 | 
sess.run(fetches=[train_op,
491 |                                   avg_train_loss,
492 |                                   avg_train_top1_error,
493 |                                   train_merge_summary_op,
494 |                                   learning_rate])
495 | 
496 |             if math.isnan(train_loss_value):
497 |                 log.error('Train loss is nan')
498 |                 return
499 | 
500 |             cost_time = time.time() - t_start
501 |             train_cost_time_mean.append(cost_time)
502 | 
503 |             summary_writer.add_summary(summary=train_summary,
504 |                                        global_step=epoch)
505 | 
506 |             # validation part
507 |             t_start_val = time.time()
508 | 
509 |             val_loss_value, val_top1_err_value, val_summary = \
510 |                 sess.run(fetches=[avg_val_loss,
511 |                                   avg_val_top1_error,
512 |                                   val_merge_summary_op])
513 | 
514 |             summary_writer.add_summary(val_summary, global_step=epoch)
515 | 
516 |             cost_time_val = time.time() - t_start_val
517 |             val_cost_time_mean.append(cost_time_val)
518 | 
519 |             if epoch % CFG.TRAIN.DISPLAY_STEP == 0:
520 |                 log.info('Epoch_Train: {:d} total_loss= {:6f} top1_error= {:6f} '
521 |                          'lr= {:6f} mean_cost_time= {:5f}s '.
522 |                          format(epoch + 1,
523 |                                 train_loss_value,
524 |                                 train_top1_err_value,
525 |                                 lr,
526 |                                 np.mean(train_cost_time_mean)))
527 |                 train_cost_time_mean.clear()
528 | 
529 |             if epoch % CFG.TRAIN.VAL_DISPLAY_STEP == 0:
530 |                 log.info('Epoch_Val: {:d} total_loss= {:6f} top1_error= {:6f}'
531 |                          ' mean_cost_time= {:5f}s '.
532 |                          format(epoch + 1,
533 |                                 val_loss_value,
534 |                                 val_top1_err_value,
535 |                                 np.mean(val_cost_time_mean)))
536 |                 val_cost_time_mean.clear()
537 | 
538 |             if epoch % 2000 == 0:
539 |                 saver.save(sess=sess, save_path=model_save_path, global_step=epoch)
540 |         sess.close()
541 | 
542 |     return
543 | 
544 | 
545 | if __name__ == '__main__':
546 |     # init args
547 |     args = init_args()
548 | 
549 |     if CFG.TRAIN.GPU_NUM < 2:
550 |         args.use_multi_gpu = False
551 | 
552 |     # train the nsfw classification net
553 |     if not args.use_multi_gpu:
554 |         train_net(args.dataset_dir, args.weights_path)
555 |     else:
556 |         train_net_multi_gpu(args.dataset_dir, args.weights_path)
557 | 
558 | 
--------------------------------------------------------------------------------
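
One closing note on the learning-rate schedule shared by `train_net` and `train_net_multi_gpu`: `decay_values` is built by repeatedly multiplying `CFG.TRAIN.LEARNING_RATE` by `CFG.TRAIN.LR_DECAY_RATE`, and `tf.train.piecewise_constant` then picks the value whose interval contains the current global step. A pure-Python sketch of that selection rule (the boundary and rate numbers below are made up for illustration, not taken from the config):

```python
def piecewise_lr(global_step, boundaries, init_lr, decay_rate):
    """Mirror of tf.train.piecewise_constant over geometrically decayed values."""
    values = [init_lr * decay_rate ** i for i in range(len(boundaries) + 1)]
    # piecewise_constant keeps the earlier value while step <= boundary
    interval = sum(global_step > b for b in boundaries)
    return values[interval]

# Hypothetical schedule: 0.1 until step 100000, 0.01 until 140000, then 0.001.
assert piecewise_lr(50000, [100000, 140000], 0.1, 0.1) == 0.1
assert piecewise_lr(120000, [100000, 140000], 0.1, 0.1) == 0.1 * 0.1
```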