├── .gitignore
├── .idea
├── codeStyles
│ └── codeStyleConfig.xml
└── vcs.xml
├── README.md
├── _config.yml
├── config
├── __init__.py
└── global_config.py
├── data
├── images
│ ├── avg_train_loss.png
│ ├── avg_train_top1_error.png
│ ├── avg_val_loss.png
│ ├── avg_val_top1_error.png
│ ├── confusion_matrix.png
│ ├── evaluation_nsfw.png
│ ├── online_demo.png
│ └── precision_recall.png
├── nsfw_dataset_example
│ ├── drawing
│ │ ├── drawing_136.jpg
│ │ ├── drawing_147.jpg
│ │ ├── drawing_192.jpg
│ │ ├── drawing_207.jpg
│ │ ├── drawing_219.jpg
│ │ ├── drawing_227.jpg
│ │ ├── drawing_240.jpg
│ │ ├── drawing_253.jpg
│ │ ├── drawing_28.jpg
│ │ ├── drawing_346.jpg
│ │ ├── drawing_348.jpg
│ │ ├── drawing_361.jpg
│ │ ├── drawing_384.jpg
│ │ ├── drawing_4.jpg
│ │ ├── drawing_463.jpg
│ │ ├── drawing_483.jpg
│ │ ├── drawing_52.jpg
│ │ └── drawing_77.jpg
│ ├── hentai
│ │ ├── hentai_106.jpg
│ │ ├── hentai_112.jpg
│ │ ├── hentai_12.jpg
│ │ ├── hentai_194.jpg
│ │ ├── hentai_216.jpg
│ │ ├── hentai_291.jpg
│ │ ├── hentai_311.jpg
│ │ ├── hentai_318.jpg
│ │ ├── hentai_331.jpg
│ │ ├── hentai_377.jpg
│ │ ├── hentai_407.jpg
│ │ ├── hentai_424.jpg
│ │ ├── hentai_437.jpg
│ │ ├── hentai_458.jpg
│ │ ├── hentai_489.jpg
│ │ ├── hentai_499.jpg
│ │ ├── hentai_65.jpg
│ │ └── hentai_677.jpg
│ ├── neural
│ │ ├── neural_352.jpg
│ │ ├── neural_353.jpg
│ │ ├── neural_358.jpg
│ │ ├── neural_360.jpg
│ │ ├── neural_374.jpg
│ │ ├── neural_377.jpg
│ │ ├── neural_438.jpg
│ │ ├── neural_446.jpg
│ │ ├── neural_474.jpg
│ │ ├── neural_493.jpg
│ │ ├── neural_518.jpg
│ │ ├── neural_523.jpg
│ │ ├── neural_536.jpg
│ │ ├── neural_539.jpg
│ │ ├── neural_569.jpg
│ │ ├── neural_596.jpg
│ │ ├── neural_619.jpg
│ │ └── neural_683.jpg
│ ├── porn
│ │ ├── porn_1016.jpg
│ │ ├── porn_1198.jpg
│ │ ├── porn_1270.jpg
│ │ ├── porn_1364.jpg
│ │ ├── porn_1470.jpg
│ │ ├── porn_1488.jpg
│ │ ├── porn_1515.jpg
│ │ ├── porn_1551.jpg
│ │ ├── porn_1556.jpg
│ │ ├── porn_1570.jpg
│ │ ├── porn_1592.jpg
│ │ ├── porn_1677.jpg
│ │ ├── porn_1723.jpg
│ │ ├── porn_179.jpg
│ │ ├── porn_1794.jpg
│ │ ├── porn_1866.jpg
│ │ ├── porn_1899.jpg
│ │ ├── porn_254.jpg
│ │ ├── porn_364.jpg
│ │ ├── porn_400.jpg
│ │ ├── porn_679.jpg
│ │ ├── porn_749.jpg
│ │ ├── porn_755.jpg
│ │ └── porn_924.jpg
│ └── sexy
│ │ ├── sexy_167.jpg
│ │ ├── sexy_172.jpg
│ │ ├── sexy_180.jpg
│ │ ├── sexy_189.jpg
│ │ ├── sexy_227.jpg
│ │ ├── sexy_267.jpg
│ │ ├── sexy_296.jpg
│ │ ├── sexy_301.jpg
│ │ ├── sexy_312.jpg
│ │ ├── sexy_397.jpg
│ │ ├── sexy_474.jpg
│ │ ├── sexy_488.jpg
│ │ ├── sexy_508.jpg
│ │ ├── sexy_520.jpg
│ │ ├── sexy_524.jpg
│ │ ├── sexy_554.jpg
│ │ ├── sexy_556.jpg
│ │ ├── sexy_567.jpg
│ │ ├── sexy_599.jpg
│ │ ├── sexy_611.jpg
│ │ ├── sexy_629.jpg
│ │ └── sexy_712.jpg
└── test_data
│ ├── drawing_16715.jpg
│ ├── hentai_16814.jpg
│ ├── neural_2619.jpg
│ ├── porn_19625.jpg
│ └── sexy_6182.jpg
├── data_provider
├── __init__.py
├── nsfw_data_feed_pipline.py
└── tf_io_pipline_tools.py
├── docker_container
├── __init__.py
└── python_client.py
├── model
└── nsfw_export_saved_model
│ └── 1
│ ├── saved_model.pb
│ └── variables
│ ├── variables.data-00000-of-00001
│ └── variables.index
├── nsfw_model
├── __init__.py
├── cnn_basenet.py
└── nsfw_classification_net.py
├── requirements.txt
├── tboard
└── nsfw_cls
│ └── events.out.tfevents.1551264389.baidu-pc
└── tools
├── __init__.py
├── convert_tfjs_model.sh
├── evaluate_nsfw.py
├── export_nsfw_saved_model.sh
├── export_saved_model.py
├── make_nsfw_dataset.sh
├── test_nsfw.py
└── train_nsfw.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Created by .ignore support plugin (hsz.mobi)
2 | ### Python template
3 | # Byte-compiled / optimized / DLL files
4 | __pycache__/
5 | *.py[cod]
6 | *$py.class
7 |
8 | # C extensions
9 | *.so
10 |
11 | # Distribution / packaging
12 | .Python
13 | build/
14 | develop-eggs/
15 | dist/
16 | downloads/
17 | eggs/
18 | .eggs/
19 | lib/
20 | lib64/
21 | parts/
22 | sdist/
23 | var/
24 | wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | .hypothesis/
50 | .pytest_cache/
51 |
52 | # Translations
53 | *.mo
54 | *.pot
55 |
56 | # Django stuff:
57 | *.log
58 | local_settings.py
59 | db.sqlite3
60 |
61 | # Flask stuff:
62 | instance/
63 | .webassets-cache
64 |
65 | # Scrapy stuff:
66 | .scrapy
67 |
68 | # Sphinx documentation
69 | docs/_build/
70 |
71 | # PyBuilder
72 | target/
73 |
74 | # Jupyter Notebook
75 | .ipynb_checkpoints
76 |
77 | # pyenv
78 | .python-version
79 |
80 | # celery beat schedule file
81 | celerybeat-schedule
82 |
83 | # SageMath parsed files
84 | *.sage.py
85 |
86 | # Environments
87 | .env
88 | .venv
89 | env/
90 | venv/
91 | ENV/
92 | env.bak/
93 | venv.bak/
94 |
95 | # Spyder project settings
96 | .spyderproject
97 | .spyproject
98 |
99 | # Rope project settings
100 | .ropeproject
101 |
102 | # mkdocs documentation
103 | /site
104 |
105 | # mypy
106 | .mypy_cache/
107 |
108 | .idea/dictionaries/
109 | .idea/inspectionProfiles/
110 | .idea/misc.xml
111 | .idea/modules.xml
112 | .idea/nsfw_classification.iml
113 | .idea/workspace.xml
114 |
115 | model/
116 | tboard/
117 |
118 | tmp*.py
--------------------------------------------------------------------------------
/.idea/codeStyles/codeStyleConfig.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
--------------------------------------------------------------------------------
/.idea/vcs.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Nsfw-Classify-Tensorflow
2 | NSFW classification model implemented with TensorFlow, using the nsfw dataset provided at
3 | https://github.com/alexkimxyz/nsfw_data_scraper (thanks for sharing the dataset
4 | with us). You can find all the model details here. Do not hesitate to raise an
5 | issue if you're confused by the model.
6 |
7 | ## Installation
8 | This software has only been tested on Ubuntu 16.04 (x64). Here is the test environment
9 | info:
10 |
11 | **OS**: Ubuntu 16.04 LTS
12 |
13 | **GPU**: Two GTX 1070TI
14 |
15 | **CUDA**: cuda 9.0
16 |
17 | **Tensorflow**: tensorflow 1.12.0
18 |
19 | **OPENCV**: opencv 3.4.1
20 |
21 | **NUMPY**: numpy 1.15.1
22 |
23 | You can install the other required packages with
24 |
25 | ```
26 | pip3 install -r requirements.txt
27 | ```
28 |
29 | ## Test model
30 | This repo provides a model pretrained on the dataset introduced above. You need to
31 | download the pretrained [weights_file](https://www.dropbox.com/sh/yfc3hy7enopdxj2/AADek2FlqCW1_j8Ax5x_VQy8a?dl=0)
32 | into the folder REPO_ROOT_DIR/model/.
33 |
34 | You can test the trained model on a single image as follows:
35 |
36 | ```
37 | cd REPO_ROOT_DIR
38 | python tools/test_nsfw.py --weights_path model/new_model/nsfw_cls.ckpt-100000 \
39 | --image_path data/test_data/drawing_16715.jpg
40 | ```
41 |
42 | The following sections show how the dataset is prepared and how to train your own model.
43 |
44 | ## Train your own model
45 |
46 | #### Data Preparation
47 | First you need to download all the original nsfw data; see
48 | [how_to_download_source_data](https://github.com/alexkimxyz/nsfw_data_scraper).
49 | The training examples should be organized like the ones in
50 | REPO_ROOT_DIR/data/nsfw_dataset_example. Then modify
51 | REPO_ROOT_DIR/tools/make_nsfw_dataset.sh to point to your local nsfw dataset and execute
52 | the dataset preparation script. It takes about one hour on my local machine.
53 | You may increase __C.TRAIN.CPU_MULTI_PROCESS_NUMS in config/global_config.py
54 | to speed up the preparation process if you have a powerful CPU.
55 |
56 | ```
57 | cd REPO_ROOT_DIR
58 | bash tools/make_nsfw_dataset.sh
59 | ```
60 |
61 | The images of each subclass will be split into three parts according to the ratio
62 | training : validation : test = 0.75 : 0.1 : 0.15. All the images will be converted
63 | into TensorFlow record format for an efficient data import pipeline.
64 |
65 | #### Train Model
66 | The model supports multi-GPU training. If you want to train the model on
67 | multiple GPUs, you first need to adjust __C.TRAIN.GPU_NUM in the config/global_config.py
68 | file. Then execute the multi-GPU training procedure as follows:
69 |
70 | ```
71 | cd REPO_ROOT_DIR
72 | python tools/train_nsfw.py --dataset_dir PATH/TO/PREPARED_NSFW_DATASET --use_multi_gpu True
73 | ```
74 |
75 | If you want to resume training from the last snapshot, execute the following command:
76 |
77 | ```
78 | cd REPO_ROOT_DIR
79 | python tools/train_nsfw.py --dataset_dir PATH/TO/PREPARED_NSFW_DATASET \
80 | --use_multi_gpu True --weights_path PATH/TO/YOUR/LAST_CKPT_FILE_PATH
81 | ```
82 |
83 | If you set --use_multi_gpu False, the whole training process will be executed
84 | on a single GPU.
85 |
86 | The model's main hyperparameters are as follows:
87 |
88 | **iterations nums**: 160010
89 |
90 | **learning rate**: 0.1
91 |
92 | **batch size**: 32
93 |
94 | **origin image size**: 256
95 |
96 | **cropped image size**: 224
97 |
98 | **training example nums**: 159477
99 |
100 | **testing example nums**: 31895
101 |
102 | **validation example nums**: 21266
103 |
104 | The rest of the hyperparameters can be found [here](https://github.com/MaybeShewill-CV/nsfw-classification-tensorflow/blob/master/config/global_config.py).
105 |
106 | If you want to convert the downloaded ckpt model into a TensorFlow saved model,
107 | simply modify the file paths in REPO_ROOT_DIR/tools/export_nsfw_saved_model.sh
108 | and run it.
109 |
110 | ```
111 | bash tools/export_nsfw_saved_model.sh
112 | ```
113 |
114 | You can monitor the training process with TensorBoard.
115 |
116 | During my experiment the `train loss` drops as follows:
117 | ![avg_train_loss](data/images/avg_train_loss.png)
118 |
119 | The `train_top_1_error` drops as follows:
120 | ![avg_train_top1_error](data/images/avg_train_top1_error.png)
121 |
122 | The `validation loss` drops as follows:
123 | ![avg_val_loss](data/images/avg_val_loss.png)
124 |
125 | The `validation_top_1_error` drops as follows:
126 | ![avg_val_top1_error](data/images/avg_val_top1_error.png)
127 |
128 | #### The Model Evaluation
129 |
130 | You can evaluate the model's performance on the nsfw dataset prepared in
131 | advance as follows
132 |
133 | ```
134 | cd REPO_ROOT_DIR
135 | python tools/evaluate_nsfw.py --weights_path model/new_model/nsfw_cls.ckpt-160000 \
136 | --dataset_dir PATH/TO/YOUR/NSFW_DATASET
137 | ```
138 |
139 | After running the script you should see something like this:
140 | ![evaluation_nsfw](data/images/evaluation_nsfw.png)
141 |
142 | The model's main evaluation metrics are as follows:
143 |
144 | **Precision**: 0.92406 (weighted average over classes)
145 |
146 | **Recall**: 0.92364 (weighted average over classes)
147 |
148 | **F1 score**: 0.92344 (weighted average over classes)
149 |
150 | The `Confusion_Matrix` is as follows:
151 | ![confusion_matrix](data/images/confusion_matrix.png)
152 |
153 | The `Precision_Recall` is as follows:
154 | ![precision_recall](data/images/precision_recall.png)
155 |
156 |
157 | #### Online demo
158 |
159 | ##### URL: https://maybeshewill-cv.github.io/nsfw_classification
160 |
161 | Since TensorFlow.js is well supported, deploying deep learning models online is easy.
162 | I have made an online demo that performs nsfw classification locally. The whole JS project
163 | can be found here https://github.com/MaybeShewill-CV/MaybeShewill-CV.github.io/tree/master/nsfw_classification
164 | I have supplied a tool to convert the trained TensorFlow saved model into a
165 | TensorFlow.js model. To generate the saved model, see the
166 | description above. After generating the TensorFlow saved model you
167 | can simply modify the file paths and run the following script:
169 | ```
170 | cd REPO_ROOT_DIR
171 | bash tools/convert_tfjs_model.sh
172 | ```
173 | An example from the online demo is shown below:
174 | ![online_demo](data/images/online_demo.png)
175 |
176 | ## TODO
177 | - [ ] Add tensorflow serving script
178 |
--------------------------------------------------------------------------------
/_config.yml:
--------------------------------------------------------------------------------
1 | theme: jekyll-theme-cayman
--------------------------------------------------------------------------------
/config/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | # @Time : 19-2-14 下午5:38
4 | # @Author : Luo Yao
5 | # @Site : http://icode.baidu.com/repos/baidu/personal-code/Luoyao
6 | # @File : __init__.py
7 | # @IDE: PyCharm
--------------------------------------------------------------------------------
/config/global_config.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | # @Time : 18-1-31 11:21 AM
4 | # @Author : MaybeShewill-CV
5 | # @Site : https://github.com/MaybeShewill-CV/CRNN_Tensorflow
6 | # @File : global_config.py
7 | # @IDE: PyCharm Community Edition
8 | """
9 | Global configuration (train, test and network options) for the NSFW classifier.
10 | """
11 | import easydict
12 |
13 | __C = easydict.EasyDict()
14 | # Consumers can get config by: from config import global_config; cfg = global_config.cfg
15 |
16 | cfg = __C
17 |
18 | # Train options
19 | __C.TRAIN = easydict.EasyDict()
20 |
21 | # Set the total number of training steps (named EPOCHS, but the README lists 160010 as the iteration count)
22 | __C.TRAIN.EPOCHS = 160010
23 | # Set the training display step
24 | __C.TRAIN.DISPLAY_STEP = 1
25 | # Set the validation display step during the training process
26 | __C.TRAIN.VAL_DISPLAY_STEP = 1000
27 | # Set the momentum parameter of the optimizer
28 | __C.TRAIN.MOMENTUM = 0.9
29 | # Set the initial learning rate
30 | __C.TRAIN.LEARNING_RATE = 0.1
31 | # Set the fraction of GPU memory used during the training process
32 | __C.TRAIN.GPU_MEMORY_FRACTION = 0.75
33 | # Set the GPU allow growth parameter during tensorflow training process
34 | __C.TRAIN.TF_ALLOW_GROWTH = False
35 | # Set the training batch size
36 | __C.TRAIN.BATCH_SIZE = 32
37 | # Set the validation batch size
38 | __C.TRAIN.VAL_BATCH_SIZE = 32
39 | # Set the first learning rate decay step
40 | __C.TRAIN.LR_DECAY_STEPS_1 = 60000
41 | # Set the second learning rate decay step
42 | __C.TRAIN.LR_DECAY_STEPS_2 = 120000
43 | # Set the learning rate decay rate
44 | __C.TRAIN.LR_DECAY_RATE = 0.1
45 | # Set the weight decay coefficient
46 | __C.TRAIN.WEIGHT_DECAY = 0.0001
47 | # Set the train moving average decay
48 | __C.TRAIN.MOVING_AVERAGE_DECAY = 0.9999
49 | # Set the number of classes (must match the 5 entries of NSFW_LABEL_MAP below)
50 | __C.TRAIN.CLASSES_NUMS = 5
51 | # Set the image height before cropping
52 | __C.TRAIN.IMG_HEIGHT = 256
53 | # Set the image width before cropping
54 | __C.TRAIN.IMG_WIDTH = 256
55 | # Set the cropped image height
56 | __C.TRAIN.CROP_IMG_HEIGHT = 224
57 | # Set the cropped image width
58 | __C.TRAIN.CROP_IMG_WIDTH = 224
59 | # Set the number of GPUs used for multi-GPU training
60 | __C.TRAIN.GPU_NUM = 2
61 | # Set the number of CPU worker processes (the README suggests raising it to speed up dataset preparation)
62 | __C.TRAIN.CPU_MULTI_PROCESS_NUMS = 6
63 |
64 | # Test options
65 | __C.TEST = easydict.EasyDict()
66 |
67 | # Set the fraction of GPU memory used during the testing process
68 | __C.TEST.GPU_MEMORY_FRACTION = 0.8
69 | # Set the GPU allow growth parameter during tensorflow testing process
70 | __C.TEST.TF_ALLOW_GROWTH = True
71 | # Set the test batch size
72 | __C.TEST.BATCH_SIZE = 64
73 |
74 | __C.NET = easydict.EasyDict()
75 | # Set the ResNet size (presumably the network depth, e.g. 50 -> ResNet-50; confirm in nsfw_model)
76 | __C.NET.RESNET_SIZE = 50
77 | # Set the feature-map summary flag (whether feature maps are summarized)
78 | __C.NET.NEED_SUMMARY_FEATS_MAP = False
79 |
80 | # Set nsfw dataset label map: class name -> integer label
81 | NSFW_LABEL_MAP = {'drawing': 0, 'hentai': 1, 'neural': 2, 'porn': 3, 'sexy': 4}
82 | # Set nsfw dataset prediction map: integer label -> class name (inverse of NSFW_LABEL_MAP)
83 | NSFW_PREDICT_MAP = {0: 'drawing', 1: 'hentai', 2: 'neural', 3: 'porn', 4: 'sexy'}
84 |
--------------------------------------------------------------------------------
/data/images/avg_train_loss.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MaybeShewill-CV/nsfw-classification-tensorflow/6dfcb16fd655e66b9dd83237bbe89e84aa5322b9/data/images/avg_train_loss.png
--------------------------------------------------------------------------------
/data/images/avg_train_top1_error.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MaybeShewill-CV/nsfw-classification-tensorflow/6dfcb16fd655e66b9dd83237bbe89e84aa5322b9/data/images/avg_train_top1_error.png
--------------------------------------------------------------------------------
/data/images/avg_val_loss.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MaybeShewill-CV/nsfw-classification-tensorflow/6dfcb16fd655e66b9dd83237bbe89e84aa5322b9/data/images/avg_val_loss.png
--------------------------------------------------------------------------------
/data/images/avg_val_top1_error.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MaybeShewill-CV/nsfw-classification-tensorflow/6dfcb16fd655e66b9dd83237bbe89e84aa5322b9/data/images/avg_val_top1_error.png
--------------------------------------------------------------------------------
/data/images/confusion_matrix.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MaybeShewill-CV/nsfw-classification-tensorflow/6dfcb16fd655e66b9dd83237bbe89e84aa5322b9/data/images/confusion_matrix.png
--------------------------------------------------------------------------------
/data/images/evaluation_nsfw.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MaybeShewill-CV/nsfw-classification-tensorflow/6dfcb16fd655e66b9dd83237bbe89e84aa5322b9/data/images/evaluation_nsfw.png
--------------------------------------------------------------------------------
/data/images/online_demo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MaybeShewill-CV/nsfw-classification-tensorflow/6dfcb16fd655e66b9dd83237bbe89e84aa5322b9/data/images/online_demo.png
--------------------------------------------------------------------------------
/data/images/precision_recall.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MaybeShewill-CV/nsfw-classification-tensorflow/6dfcb16fd655e66b9dd83237bbe89e84aa5322b9/data/images/precision_recall.png
--------------------------------------------------------------------------------
/data/nsfw_dataset_example/drawing/drawing_136.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MaybeShewill-CV/nsfw-classification-tensorflow/6dfcb16fd655e66b9dd83237bbe89e84aa5322b9/data/nsfw_dataset_example/drawing/drawing_136.jpg
--------------------------------------------------------------------------------
/data/nsfw_dataset_example/drawing/drawing_147.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MaybeShewill-CV/nsfw-classification-tensorflow/6dfcb16fd655e66b9dd83237bbe89e84aa5322b9/data/nsfw_dataset_example/drawing/drawing_147.jpg
--------------------------------------------------------------------------------
/data/nsfw_dataset_example/drawing/drawing_192.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MaybeShewill-CV/nsfw-classification-tensorflow/6dfcb16fd655e66b9dd83237bbe89e84aa5322b9/data/nsfw_dataset_example/drawing/drawing_192.jpg
--------------------------------------------------------------------------------
/data/nsfw_dataset_example/drawing/drawing_207.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MaybeShewill-CV/nsfw-classification-tensorflow/6dfcb16fd655e66b9dd83237bbe89e84aa5322b9/data/nsfw_dataset_example/drawing/drawing_207.jpg
--------------------------------------------------------------------------------
/data/nsfw_dataset_example/drawing/drawing_219.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MaybeShewill-CV/nsfw-classification-tensorflow/6dfcb16fd655e66b9dd83237bbe89e84aa5322b9/data/nsfw_dataset_example/drawing/drawing_219.jpg
--------------------------------------------------------------------------------
/data/nsfw_dataset_example/drawing/drawing_227.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MaybeShewill-CV/nsfw-classification-tensorflow/6dfcb16fd655e66b9dd83237bbe89e84aa5322b9/data/nsfw_dataset_example/drawing/drawing_227.jpg
--------------------------------------------------------------------------------
/data/nsfw_dataset_example/drawing/drawing_240.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MaybeShewill-CV/nsfw-classification-tensorflow/6dfcb16fd655e66b9dd83237bbe89e84aa5322b9/data/nsfw_dataset_example/drawing/drawing_240.jpg
--------------------------------------------------------------------------------
/data/nsfw_dataset_example/drawing/drawing_253.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MaybeShewill-CV/nsfw-classification-tensorflow/6dfcb16fd655e66b9dd83237bbe89e84aa5322b9/data/nsfw_dataset_example/drawing/drawing_253.jpg
--------------------------------------------------------------------------------
/data/nsfw_dataset_example/drawing/drawing_28.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MaybeShewill-CV/nsfw-classification-tensorflow/6dfcb16fd655e66b9dd83237bbe89e84aa5322b9/data/nsfw_dataset_example/drawing/drawing_28.jpg
--------------------------------------------------------------------------------
/data/nsfw_dataset_example/drawing/drawing_346.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MaybeShewill-CV/nsfw-classification-tensorflow/6dfcb16fd655e66b9dd83237bbe89e84aa5322b9/data/nsfw_dataset_example/drawing/drawing_346.jpg
--------------------------------------------------------------------------------
/data/nsfw_dataset_example/drawing/drawing_348.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MaybeShewill-CV/nsfw-classification-tensorflow/6dfcb16fd655e66b9dd83237bbe89e84aa5322b9/data/nsfw_dataset_example/drawing/drawing_348.jpg
--------------------------------------------------------------------------------
/data/nsfw_dataset_example/drawing/drawing_361.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MaybeShewill-CV/nsfw-classification-tensorflow/6dfcb16fd655e66b9dd83237bbe89e84aa5322b9/data/nsfw_dataset_example/drawing/drawing_361.jpg
--------------------------------------------------------------------------------
/data/nsfw_dataset_example/drawing/drawing_384.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MaybeShewill-CV/nsfw-classification-tensorflow/6dfcb16fd655e66b9dd83237bbe89e84aa5322b9/data/nsfw_dataset_example/drawing/drawing_384.jpg
--------------------------------------------------------------------------------
/data/nsfw_dataset_example/drawing/drawing_4.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MaybeShewill-CV/nsfw-classification-tensorflow/6dfcb16fd655e66b9dd83237bbe89e84aa5322b9/data/nsfw_dataset_example/drawing/drawing_4.jpg
--------------------------------------------------------------------------------
/data/nsfw_dataset_example/drawing/drawing_463.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MaybeShewill-CV/nsfw-classification-tensorflow/6dfcb16fd655e66b9dd83237bbe89e84aa5322b9/data/nsfw_dataset_example/drawing/drawing_463.jpg
--------------------------------------------------------------------------------
/data/nsfw_dataset_example/drawing/drawing_483.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MaybeShewill-CV/nsfw-classification-tensorflow/6dfcb16fd655e66b9dd83237bbe89e84aa5322b9/data/nsfw_dataset_example/drawing/drawing_483.jpg
--------------------------------------------------------------------------------
/data/nsfw_dataset_example/drawing/drawing_52.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MaybeShewill-CV/nsfw-classification-tensorflow/6dfcb16fd655e66b9dd83237bbe89e84aa5322b9/data/nsfw_dataset_example/drawing/drawing_52.jpg
--------------------------------------------------------------------------------
/data/nsfw_dataset_example/drawing/drawing_77.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MaybeShewill-CV/nsfw-classification-tensorflow/6dfcb16fd655e66b9dd83237bbe89e84aa5322b9/data/nsfw_dataset_example/drawing/drawing_77.jpg
--------------------------------------------------------------------------------
/data/nsfw_dataset_example/hentai/hentai_106.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MaybeShewill-CV/nsfw-classification-tensorflow/6dfcb16fd655e66b9dd83237bbe89e84aa5322b9/data/nsfw_dataset_example/hentai/hentai_106.jpg
--------------------------------------------------------------------------------
/data/nsfw_dataset_example/hentai/hentai_112.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MaybeShewill-CV/nsfw-classification-tensorflow/6dfcb16fd655e66b9dd83237bbe89e84aa5322b9/data/nsfw_dataset_example/hentai/hentai_112.jpg
--------------------------------------------------------------------------------
/data/nsfw_dataset_example/hentai/hentai_12.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MaybeShewill-CV/nsfw-classification-tensorflow/6dfcb16fd655e66b9dd83237bbe89e84aa5322b9/data/nsfw_dataset_example/hentai/hentai_12.jpg
--------------------------------------------------------------------------------
/data/nsfw_dataset_example/hentai/hentai_194.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MaybeShewill-CV/nsfw-classification-tensorflow/6dfcb16fd655e66b9dd83237bbe89e84aa5322b9/data/nsfw_dataset_example/hentai/hentai_194.jpg
--------------------------------------------------------------------------------
/data/nsfw_dataset_example/hentai/hentai_216.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MaybeShewill-CV/nsfw-classification-tensorflow/6dfcb16fd655e66b9dd83237bbe89e84aa5322b9/data/nsfw_dataset_example/hentai/hentai_216.jpg
--------------------------------------------------------------------------------
/data/nsfw_dataset_example/hentai/hentai_291.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MaybeShewill-CV/nsfw-classification-tensorflow/6dfcb16fd655e66b9dd83237bbe89e84aa5322b9/data/nsfw_dataset_example/hentai/hentai_291.jpg
--------------------------------------------------------------------------------
/data/nsfw_dataset_example/hentai/hentai_311.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MaybeShewill-CV/nsfw-classification-tensorflow/6dfcb16fd655e66b9dd83237bbe89e84aa5322b9/data/nsfw_dataset_example/hentai/hentai_311.jpg
--------------------------------------------------------------------------------
/data/nsfw_dataset_example/hentai/hentai_318.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MaybeShewill-CV/nsfw-classification-tensorflow/6dfcb16fd655e66b9dd83237bbe89e84aa5322b9/data/nsfw_dataset_example/hentai/hentai_318.jpg
--------------------------------------------------------------------------------
/data/nsfw_dataset_example/hentai/hentai_331.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MaybeShewill-CV/nsfw-classification-tensorflow/6dfcb16fd655e66b9dd83237bbe89e84aa5322b9/data/nsfw_dataset_example/hentai/hentai_331.jpg
--------------------------------------------------------------------------------
/data/nsfw_dataset_example/hentai/hentai_377.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MaybeShewill-CV/nsfw-classification-tensorflow/6dfcb16fd655e66b9dd83237bbe89e84aa5322b9/data/nsfw_dataset_example/hentai/hentai_377.jpg
--------------------------------------------------------------------------------
/data/nsfw_dataset_example/hentai/hentai_407.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MaybeShewill-CV/nsfw-classification-tensorflow/6dfcb16fd655e66b9dd83237bbe89e84aa5322b9/data/nsfw_dataset_example/hentai/hentai_407.jpg
--------------------------------------------------------------------------------
/data/nsfw_dataset_example/hentai/hentai_424.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MaybeShewill-CV/nsfw-classification-tensorflow/6dfcb16fd655e66b9dd83237bbe89e84aa5322b9/data/nsfw_dataset_example/hentai/hentai_424.jpg
--------------------------------------------------------------------------------
/data/nsfw_dataset_example/hentai/hentai_437.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MaybeShewill-CV/nsfw-classification-tensorflow/6dfcb16fd655e66b9dd83237bbe89e84aa5322b9/data/nsfw_dataset_example/hentai/hentai_437.jpg
--------------------------------------------------------------------------------
/data/nsfw_dataset_example/hentai/hentai_458.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MaybeShewill-CV/nsfw-classification-tensorflow/6dfcb16fd655e66b9dd83237bbe89e84aa5322b9/data/nsfw_dataset_example/hentai/hentai_458.jpg
--------------------------------------------------------------------------------
/data/nsfw_dataset_example/hentai/hentai_489.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MaybeShewill-CV/nsfw-classification-tensorflow/6dfcb16fd655e66b9dd83237bbe89e84aa5322b9/data/nsfw_dataset_example/hentai/hentai_489.jpg
--------------------------------------------------------------------------------
/data/nsfw_dataset_example/hentai/hentai_499.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MaybeShewill-CV/nsfw-classification-tensorflow/6dfcb16fd655e66b9dd83237bbe89e84aa5322b9/data/nsfw_dataset_example/hentai/hentai_499.jpg
--------------------------------------------------------------------------------
/data/nsfw_dataset_example/hentai/hentai_65.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MaybeShewill-CV/nsfw-classification-tensorflow/6dfcb16fd655e66b9dd83237bbe89e84aa5322b9/data/nsfw_dataset_example/hentai/hentai_65.jpg
--------------------------------------------------------------------------------
/data/nsfw_dataset_example/hentai/hentai_677.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MaybeShewill-CV/nsfw-classification-tensorflow/6dfcb16fd655e66b9dd83237bbe89e84aa5322b9/data/nsfw_dataset_example/hentai/hentai_677.jpg
--------------------------------------------------------------------------------
/data/nsfw_dataset_example/neural/neural_352.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MaybeShewill-CV/nsfw-classification-tensorflow/6dfcb16fd655e66b9dd83237bbe89e84aa5322b9/data/nsfw_dataset_example/neural/neural_352.jpg
--------------------------------------------------------------------------------
/data/nsfw_dataset_example/neural/neural_353.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MaybeShewill-CV/nsfw-classification-tensorflow/6dfcb16fd655e66b9dd83237bbe89e84aa5322b9/data/nsfw_dataset_example/neural/neural_353.jpg
--------------------------------------------------------------------------------
/data/nsfw_dataset_example/neural/neural_358.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MaybeShewill-CV/nsfw-classification-tensorflow/6dfcb16fd655e66b9dd83237bbe89e84aa5322b9/data/nsfw_dataset_example/neural/neural_358.jpg
--------------------------------------------------------------------------------
/data/nsfw_dataset_example/neural/neural_360.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MaybeShewill-CV/nsfw-classification-tensorflow/6dfcb16fd655e66b9dd83237bbe89e84aa5322b9/data/nsfw_dataset_example/neural/neural_360.jpg
--------------------------------------------------------------------------------
/data/nsfw_dataset_example/neural/neural_374.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MaybeShewill-CV/nsfw-classification-tensorflow/6dfcb16fd655e66b9dd83237bbe89e84aa5322b9/data/nsfw_dataset_example/neural/neural_374.jpg
--------------------------------------------------------------------------------
/data/nsfw_dataset_example/neural/neural_377.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MaybeShewill-CV/nsfw-classification-tensorflow/6dfcb16fd655e66b9dd83237bbe89e84aa5322b9/data/nsfw_dataset_example/neural/neural_377.jpg
--------------------------------------------------------------------------------
/data/nsfw_dataset_example/neural/neural_438.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MaybeShewill-CV/nsfw-classification-tensorflow/6dfcb16fd655e66b9dd83237bbe89e84aa5322b9/data/nsfw_dataset_example/neural/neural_438.jpg
--------------------------------------------------------------------------------
/data/nsfw_dataset_example/neural/neural_446.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MaybeShewill-CV/nsfw-classification-tensorflow/6dfcb16fd655e66b9dd83237bbe89e84aa5322b9/data/nsfw_dataset_example/neural/neural_446.jpg
--------------------------------------------------------------------------------
/data/nsfw_dataset_example/neural/neural_474.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MaybeShewill-CV/nsfw-classification-tensorflow/6dfcb16fd655e66b9dd83237bbe89e84aa5322b9/data/nsfw_dataset_example/neural/neural_474.jpg
--------------------------------------------------------------------------------
/data/nsfw_dataset_example/neural/neural_493.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MaybeShewill-CV/nsfw-classification-tensorflow/6dfcb16fd655e66b9dd83237bbe89e84aa5322b9/data/nsfw_dataset_example/neural/neural_493.jpg
--------------------------------------------------------------------------------
/data/nsfw_dataset_example/neural/neural_518.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MaybeShewill-CV/nsfw-classification-tensorflow/6dfcb16fd655e66b9dd83237bbe89e84aa5322b9/data/nsfw_dataset_example/neural/neural_518.jpg
--------------------------------------------------------------------------------
/data/nsfw_dataset_example/neural/neural_523.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MaybeShewill-CV/nsfw-classification-tensorflow/6dfcb16fd655e66b9dd83237bbe89e84aa5322b9/data/nsfw_dataset_example/neural/neural_523.jpg
--------------------------------------------------------------------------------
/data/nsfw_dataset_example/neural/neural_536.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MaybeShewill-CV/nsfw-classification-tensorflow/6dfcb16fd655e66b9dd83237bbe89e84aa5322b9/data/nsfw_dataset_example/neural/neural_536.jpg
--------------------------------------------------------------------------------
/data/nsfw_dataset_example/neural/neural_539.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MaybeShewill-CV/nsfw-classification-tensorflow/6dfcb16fd655e66b9dd83237bbe89e84aa5322b9/data/nsfw_dataset_example/neural/neural_539.jpg
--------------------------------------------------------------------------------
/data/nsfw_dataset_example/neural/neural_569.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MaybeShewill-CV/nsfw-classification-tensorflow/6dfcb16fd655e66b9dd83237bbe89e84aa5322b9/data/nsfw_dataset_example/neural/neural_569.jpg
--------------------------------------------------------------------------------
/data/nsfw_dataset_example/neural/neural_596.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MaybeShewill-CV/nsfw-classification-tensorflow/6dfcb16fd655e66b9dd83237bbe89e84aa5322b9/data/nsfw_dataset_example/neural/neural_596.jpg
--------------------------------------------------------------------------------
/data/nsfw_dataset_example/neural/neural_619.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MaybeShewill-CV/nsfw-classification-tensorflow/6dfcb16fd655e66b9dd83237bbe89e84aa5322b9/data/nsfw_dataset_example/neural/neural_619.jpg
--------------------------------------------------------------------------------
/data/nsfw_dataset_example/neural/neural_683.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MaybeShewill-CV/nsfw-classification-tensorflow/6dfcb16fd655e66b9dd83237bbe89e84aa5322b9/data/nsfw_dataset_example/neural/neural_683.jpg
--------------------------------------------------------------------------------
/data/nsfw_dataset_example/porn/porn_1016.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MaybeShewill-CV/nsfw-classification-tensorflow/6dfcb16fd655e66b9dd83237bbe89e84aa5322b9/data/nsfw_dataset_example/porn/porn_1016.jpg
--------------------------------------------------------------------------------
/data/nsfw_dataset_example/porn/porn_1198.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MaybeShewill-CV/nsfw-classification-tensorflow/6dfcb16fd655e66b9dd83237bbe89e84aa5322b9/data/nsfw_dataset_example/porn/porn_1198.jpg
--------------------------------------------------------------------------------
/data/nsfw_dataset_example/porn/porn_1270.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MaybeShewill-CV/nsfw-classification-tensorflow/6dfcb16fd655e66b9dd83237bbe89e84aa5322b9/data/nsfw_dataset_example/porn/porn_1270.jpg
--------------------------------------------------------------------------------
/data/nsfw_dataset_example/porn/porn_1364.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MaybeShewill-CV/nsfw-classification-tensorflow/6dfcb16fd655e66b9dd83237bbe89e84aa5322b9/data/nsfw_dataset_example/porn/porn_1364.jpg
--------------------------------------------------------------------------------
/data/nsfw_dataset_example/porn/porn_1470.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MaybeShewill-CV/nsfw-classification-tensorflow/6dfcb16fd655e66b9dd83237bbe89e84aa5322b9/data/nsfw_dataset_example/porn/porn_1470.jpg
--------------------------------------------------------------------------------
/data/nsfw_dataset_example/porn/porn_1488.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MaybeShewill-CV/nsfw-classification-tensorflow/6dfcb16fd655e66b9dd83237bbe89e84aa5322b9/data/nsfw_dataset_example/porn/porn_1488.jpg
--------------------------------------------------------------------------------
/data/nsfw_dataset_example/porn/porn_1515.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MaybeShewill-CV/nsfw-classification-tensorflow/6dfcb16fd655e66b9dd83237bbe89e84aa5322b9/data/nsfw_dataset_example/porn/porn_1515.jpg
--------------------------------------------------------------------------------
/data/nsfw_dataset_example/porn/porn_1551.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MaybeShewill-CV/nsfw-classification-tensorflow/6dfcb16fd655e66b9dd83237bbe89e84aa5322b9/data/nsfw_dataset_example/porn/porn_1551.jpg
--------------------------------------------------------------------------------
/data/nsfw_dataset_example/porn/porn_1556.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MaybeShewill-CV/nsfw-classification-tensorflow/6dfcb16fd655e66b9dd83237bbe89e84aa5322b9/data/nsfw_dataset_example/porn/porn_1556.jpg
--------------------------------------------------------------------------------
/data/nsfw_dataset_example/porn/porn_1570.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MaybeShewill-CV/nsfw-classification-tensorflow/6dfcb16fd655e66b9dd83237bbe89e84aa5322b9/data/nsfw_dataset_example/porn/porn_1570.jpg
--------------------------------------------------------------------------------
/data/nsfw_dataset_example/porn/porn_1592.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MaybeShewill-CV/nsfw-classification-tensorflow/6dfcb16fd655e66b9dd83237bbe89e84aa5322b9/data/nsfw_dataset_example/porn/porn_1592.jpg
--------------------------------------------------------------------------------
/data/nsfw_dataset_example/porn/porn_1677.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MaybeShewill-CV/nsfw-classification-tensorflow/6dfcb16fd655e66b9dd83237bbe89e84aa5322b9/data/nsfw_dataset_example/porn/porn_1677.jpg
--------------------------------------------------------------------------------
/data/nsfw_dataset_example/porn/porn_1723.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MaybeShewill-CV/nsfw-classification-tensorflow/6dfcb16fd655e66b9dd83237bbe89e84aa5322b9/data/nsfw_dataset_example/porn/porn_1723.jpg
--------------------------------------------------------------------------------
/data/nsfw_dataset_example/porn/porn_179.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MaybeShewill-CV/nsfw-classification-tensorflow/6dfcb16fd655e66b9dd83237bbe89e84aa5322b9/data/nsfw_dataset_example/porn/porn_179.jpg
--------------------------------------------------------------------------------
/data/nsfw_dataset_example/porn/porn_1794.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MaybeShewill-CV/nsfw-classification-tensorflow/6dfcb16fd655e66b9dd83237bbe89e84aa5322b9/data/nsfw_dataset_example/porn/porn_1794.jpg
--------------------------------------------------------------------------------
/data/nsfw_dataset_example/porn/porn_1866.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MaybeShewill-CV/nsfw-classification-tensorflow/6dfcb16fd655e66b9dd83237bbe89e84aa5322b9/data/nsfw_dataset_example/porn/porn_1866.jpg
--------------------------------------------------------------------------------
/data/nsfw_dataset_example/porn/porn_1899.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MaybeShewill-CV/nsfw-classification-tensorflow/6dfcb16fd655e66b9dd83237bbe89e84aa5322b9/data/nsfw_dataset_example/porn/porn_1899.jpg
--------------------------------------------------------------------------------
/data/nsfw_dataset_example/porn/porn_254.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MaybeShewill-CV/nsfw-classification-tensorflow/6dfcb16fd655e66b9dd83237bbe89e84aa5322b9/data/nsfw_dataset_example/porn/porn_254.jpg
--------------------------------------------------------------------------------
/data/nsfw_dataset_example/porn/porn_364.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MaybeShewill-CV/nsfw-classification-tensorflow/6dfcb16fd655e66b9dd83237bbe89e84aa5322b9/data/nsfw_dataset_example/porn/porn_364.jpg
--------------------------------------------------------------------------------
/data/nsfw_dataset_example/porn/porn_400.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MaybeShewill-CV/nsfw-classification-tensorflow/6dfcb16fd655e66b9dd83237bbe89e84aa5322b9/data/nsfw_dataset_example/porn/porn_400.jpg
--------------------------------------------------------------------------------
/data/nsfw_dataset_example/porn/porn_679.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MaybeShewill-CV/nsfw-classification-tensorflow/6dfcb16fd655e66b9dd83237bbe89e84aa5322b9/data/nsfw_dataset_example/porn/porn_679.jpg
--------------------------------------------------------------------------------
/data/nsfw_dataset_example/porn/porn_749.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MaybeShewill-CV/nsfw-classification-tensorflow/6dfcb16fd655e66b9dd83237bbe89e84aa5322b9/data/nsfw_dataset_example/porn/porn_749.jpg
--------------------------------------------------------------------------------
/data/nsfw_dataset_example/porn/porn_755.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MaybeShewill-CV/nsfw-classification-tensorflow/6dfcb16fd655e66b9dd83237bbe89e84aa5322b9/data/nsfw_dataset_example/porn/porn_755.jpg
--------------------------------------------------------------------------------
/data/nsfw_dataset_example/porn/porn_924.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MaybeShewill-CV/nsfw-classification-tensorflow/6dfcb16fd655e66b9dd83237bbe89e84aa5322b9/data/nsfw_dataset_example/porn/porn_924.jpg
--------------------------------------------------------------------------------
/data/nsfw_dataset_example/sexy/sexy_167.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MaybeShewill-CV/nsfw-classification-tensorflow/6dfcb16fd655e66b9dd83237bbe89e84aa5322b9/data/nsfw_dataset_example/sexy/sexy_167.jpg
--------------------------------------------------------------------------------
/data/nsfw_dataset_example/sexy/sexy_172.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MaybeShewill-CV/nsfw-classification-tensorflow/6dfcb16fd655e66b9dd83237bbe89e84aa5322b9/data/nsfw_dataset_example/sexy/sexy_172.jpg
--------------------------------------------------------------------------------
/data/nsfw_dataset_example/sexy/sexy_180.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MaybeShewill-CV/nsfw-classification-tensorflow/6dfcb16fd655e66b9dd83237bbe89e84aa5322b9/data/nsfw_dataset_example/sexy/sexy_180.jpg
--------------------------------------------------------------------------------
/data/nsfw_dataset_example/sexy/sexy_189.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MaybeShewill-CV/nsfw-classification-tensorflow/6dfcb16fd655e66b9dd83237bbe89e84aa5322b9/data/nsfw_dataset_example/sexy/sexy_189.jpg
--------------------------------------------------------------------------------
/data/nsfw_dataset_example/sexy/sexy_227.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MaybeShewill-CV/nsfw-classification-tensorflow/6dfcb16fd655e66b9dd83237bbe89e84aa5322b9/data/nsfw_dataset_example/sexy/sexy_227.jpg
--------------------------------------------------------------------------------
/data/nsfw_dataset_example/sexy/sexy_267.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MaybeShewill-CV/nsfw-classification-tensorflow/6dfcb16fd655e66b9dd83237bbe89e84aa5322b9/data/nsfw_dataset_example/sexy/sexy_267.jpg
--------------------------------------------------------------------------------
/data/nsfw_dataset_example/sexy/sexy_296.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MaybeShewill-CV/nsfw-classification-tensorflow/6dfcb16fd655e66b9dd83237bbe89e84aa5322b9/data/nsfw_dataset_example/sexy/sexy_296.jpg
--------------------------------------------------------------------------------
/data/nsfw_dataset_example/sexy/sexy_301.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MaybeShewill-CV/nsfw-classification-tensorflow/6dfcb16fd655e66b9dd83237bbe89e84aa5322b9/data/nsfw_dataset_example/sexy/sexy_301.jpg
--------------------------------------------------------------------------------
/data/nsfw_dataset_example/sexy/sexy_312.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MaybeShewill-CV/nsfw-classification-tensorflow/6dfcb16fd655e66b9dd83237bbe89e84aa5322b9/data/nsfw_dataset_example/sexy/sexy_312.jpg
--------------------------------------------------------------------------------
/data/nsfw_dataset_example/sexy/sexy_397.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MaybeShewill-CV/nsfw-classification-tensorflow/6dfcb16fd655e66b9dd83237bbe89e84aa5322b9/data/nsfw_dataset_example/sexy/sexy_397.jpg
--------------------------------------------------------------------------------
/data/nsfw_dataset_example/sexy/sexy_474.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MaybeShewill-CV/nsfw-classification-tensorflow/6dfcb16fd655e66b9dd83237bbe89e84aa5322b9/data/nsfw_dataset_example/sexy/sexy_474.jpg
--------------------------------------------------------------------------------
/data/nsfw_dataset_example/sexy/sexy_488.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MaybeShewill-CV/nsfw-classification-tensorflow/6dfcb16fd655e66b9dd83237bbe89e84aa5322b9/data/nsfw_dataset_example/sexy/sexy_488.jpg
--------------------------------------------------------------------------------
/data/nsfw_dataset_example/sexy/sexy_508.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MaybeShewill-CV/nsfw-classification-tensorflow/6dfcb16fd655e66b9dd83237bbe89e84aa5322b9/data/nsfw_dataset_example/sexy/sexy_508.jpg
--------------------------------------------------------------------------------
/data/nsfw_dataset_example/sexy/sexy_520.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MaybeShewill-CV/nsfw-classification-tensorflow/6dfcb16fd655e66b9dd83237bbe89e84aa5322b9/data/nsfw_dataset_example/sexy/sexy_520.jpg
--------------------------------------------------------------------------------
/data/nsfw_dataset_example/sexy/sexy_524.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MaybeShewill-CV/nsfw-classification-tensorflow/6dfcb16fd655e66b9dd83237bbe89e84aa5322b9/data/nsfw_dataset_example/sexy/sexy_524.jpg
--------------------------------------------------------------------------------
/data/nsfw_dataset_example/sexy/sexy_554.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MaybeShewill-CV/nsfw-classification-tensorflow/6dfcb16fd655e66b9dd83237bbe89e84aa5322b9/data/nsfw_dataset_example/sexy/sexy_554.jpg
--------------------------------------------------------------------------------
/data/nsfw_dataset_example/sexy/sexy_556.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MaybeShewill-CV/nsfw-classification-tensorflow/6dfcb16fd655e66b9dd83237bbe89e84aa5322b9/data/nsfw_dataset_example/sexy/sexy_556.jpg
--------------------------------------------------------------------------------
/data/nsfw_dataset_example/sexy/sexy_567.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MaybeShewill-CV/nsfw-classification-tensorflow/6dfcb16fd655e66b9dd83237bbe89e84aa5322b9/data/nsfw_dataset_example/sexy/sexy_567.jpg
--------------------------------------------------------------------------------
/data/nsfw_dataset_example/sexy/sexy_599.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MaybeShewill-CV/nsfw-classification-tensorflow/6dfcb16fd655e66b9dd83237bbe89e84aa5322b9/data/nsfw_dataset_example/sexy/sexy_599.jpg
--------------------------------------------------------------------------------
/data/nsfw_dataset_example/sexy/sexy_611.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MaybeShewill-CV/nsfw-classification-tensorflow/6dfcb16fd655e66b9dd83237bbe89e84aa5322b9/data/nsfw_dataset_example/sexy/sexy_611.jpg
--------------------------------------------------------------------------------
/data/nsfw_dataset_example/sexy/sexy_629.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MaybeShewill-CV/nsfw-classification-tensorflow/6dfcb16fd655e66b9dd83237bbe89e84aa5322b9/data/nsfw_dataset_example/sexy/sexy_629.jpg
--------------------------------------------------------------------------------
/data/nsfw_dataset_example/sexy/sexy_712.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MaybeShewill-CV/nsfw-classification-tensorflow/6dfcb16fd655e66b9dd83237bbe89e84aa5322b9/data/nsfw_dataset_example/sexy/sexy_712.jpg
--------------------------------------------------------------------------------
/data/test_data/drawing_16715.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MaybeShewill-CV/nsfw-classification-tensorflow/6dfcb16fd655e66b9dd83237bbe89e84aa5322b9/data/test_data/drawing_16715.jpg
--------------------------------------------------------------------------------
/data/test_data/hentai_16814.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MaybeShewill-CV/nsfw-classification-tensorflow/6dfcb16fd655e66b9dd83237bbe89e84aa5322b9/data/test_data/hentai_16814.jpg
--------------------------------------------------------------------------------
/data/test_data/neural_2619.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MaybeShewill-CV/nsfw-classification-tensorflow/6dfcb16fd655e66b9dd83237bbe89e84aa5322b9/data/test_data/neural_2619.jpg
--------------------------------------------------------------------------------
/data/test_data/porn_19625.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MaybeShewill-CV/nsfw-classification-tensorflow/6dfcb16fd655e66b9dd83237bbe89e84aa5322b9/data/test_data/porn_19625.jpg
--------------------------------------------------------------------------------
/data/test_data/sexy_6182.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MaybeShewill-CV/nsfw-classification-tensorflow/6dfcb16fd655e66b9dd83237bbe89e84aa5322b9/data/test_data/sexy_6182.jpg
--------------------------------------------------------------------------------
/data_provider/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | # @Time : 19-2-14 下午5:42
4 | # @Author : Luo Yao
5 | # @Site : http://icode.baidu.com/repos/baidu/personal-code/Luoyao
6 | # @File : __init__.py.py
7 | # @IDE: PyCharm
--------------------------------------------------------------------------------
/data_provider/nsfw_data_feed_pipline.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | # @Time : 19-2-14 下午5:43
4 | # @Author : MaybeShewill-CV
5 | # @Site : https://github.com/MaybeShewill-CV/CRNN_Tensorflow
6 | # @File : nsfw_data_feed_pipline.py
7 | # @IDE: PyCharm
8 | """
9 | nsfw数据feed pipline
10 | """
11 | import argparse
12 | import os
13 | import os.path as ops
14 | import random
15 | import multiprocessing
16 |
17 | import glob
18 | import glog as log
19 | import tensorflow as tf
20 | import pprint
21 |
22 | from config import global_config
23 | from data_provider import tf_io_pipline_tools
24 |
25 | CFG = global_config.cfg
26 |
27 |
def init_args():
    """
    Parse the command line options of the nsfw -> tfrecords conversion script.

    :return: argparse.Namespace exposing ``dataset_dir`` and ``tfrecords_dir``
             (each is None when the corresponding flag is not given)
    """
    arg_parser = argparse.ArgumentParser()

    # (flag, help text) pairs; every option is a plain string path
    cli_options = (
        ('--dataset_dir', 'The source nsfw data dir path'),
        ('--tfrecords_dir', 'The dir path to save converted tfrecords'),
    )
    for flag_name, help_text in cli_options:
        arg_parser.add_argument(flag_name, type=str, help=help_text)

    return arg_parser.parse_args()
38 |
39 |
40 | class NsfwDataProducer(object):
41 | """
42 | Convert raw image file into tfrecords
43 | """
def __init__(self, dataset_dir):
    """
    Bind the producer to a dataset root folder.

    The root is expected to hold one image sub-folder per class (``drawing``,
    ``hentai``, ``neural``, ``porn``, ``sexy``) plus the split index files
    ``train.txt``, ``test.txt`` and ``val.txt``. When the index files are not
    complete they are (re)generated via ``_generate_training_example_index_file``.

    :param dataset_dir: path of the dataset root folder
    :raises ValueError: if any of the five class image folders is missing
    """
    self._label_map = global_config.NSFW_LABEL_MAP
    self._dataset_dir = dataset_dir

    # one image folder per class, directly under the dataset root
    class_dirs = {
        class_name: ops.join(dataset_dir, class_name)
        for class_name in ('drawing', 'hentai', 'neural', 'porn', 'sexy')
    }
    self._drawing_image_dir = class_dirs['drawing']
    self._hentai_image_dir = class_dirs['hentai']
    self._neural_image_dir = class_dirs['neural']
    self._porn_image_dir = class_dirs['porn']
    self._sexy_image_dir = class_dirs['sexy']

    # split index files also live directly under the dataset root
    root_dir = self._dataset_dir
    self._train_example_index_file_path = ops.join(root_dir, 'train.txt')
    self._test_example_index_file_path = ops.join(root_dir, 'test.txt')
    self._val_example_index_file_path = ops.join(root_dir, 'val.txt')

    source_ready = self._is_source_data_complete()
    if not source_ready:
        raise ValueError('Source image data is not complete, '
                         'please check if one of the image folder is not exist')

    index_ready = self._is_training_example_index_file_complete()
    if not index_ready:
        self._generate_training_example_index_file()
69 |
def print_label_map(self):
    """
    Pretty-print the label map held by this producer to standard output.

    :return: None
    """
    printer = pprint.PrettyPrinter()
    printer.pprint(self._label_map)
76 |
def generate_tfrecords(self, save_dir, step_size=10000):
    """
    Convert the train / val / test splits listed in the index files into tfrecords.

    Each split's example list is cut into chunks of ``step_size`` examples. Every
    chunk is written to its own ``{split}_{start}_{end}.tfrecords`` file under
    ``save_dir`` by ``tf_io_pipline_tools.write_example_tfrecords``, executed in a
    process pool of ``CFG.TRAIN.CPU_MULTI_PROCESS_NUMS`` workers. Splits are
    processed one after another (train, then val, then test).

    :param save_dir: directory the tfrecords files are written to (created if missing)
    :param step_size: generate a tfrecord every step_size examples; must be positive
    :return: None
    :raises ValueError: if step_size is not positive
    :raises FileNotFoundError: if one of the split index files does not exist
    :raises Exception: an exception raised by a worker while writing a tfrecords
        file is re-raised here (after all chunks of that split have finished)
    """
    if step_size <= 0:
        # range() would raise for 0 and silently produce no chunks for negatives
        raise ValueError('step_size must be a positive integer, got {}'.format(step_size))

    def _read_training_example_index_file(_index_file_path):
        """
        Read a split index file with one ``<image path> <integer label>`` entry per line
        (single-space separated; first field is the path, second the label).
        Blank lines are ignored.
        """
        # explicit check instead of ``assert`` so it still runs under ``python -O``
        if not ops.exists(_index_file_path):
            raise FileNotFoundError('Example index file {} does not exist'.format(_index_file_path))

        _example_path_info = []
        _example_label_info = []

        with open(_index_file_path, 'r') as _file:
            for _line in _file:
                _line = _line.rstrip('\r\n')
                if not _line.strip():
                    # tolerate empty lines such as a trailing newline at end of file
                    continue
                _example_info = _line.split(' ')
                _example_path_info.append(_example_info[0])
                _example_label_info.append(int(_example_info[1]))

        return _example_path_info, _example_label_info

    def _split_writing_tfrecords_task(_example_paths, _example_labels, _flags='train'):
        """
        Cut one split into step_size-sized chunks and name a tfrecords file for each.
        """
        _split_example_paths = []
        _split_example_labels = []
        _split_tfrecords_save_paths = []

        _total = len(_example_paths)
        for _start in range(0, _total, step_size):
            # the last chunk may be shorter than step_size
            _end = min(_start + step_size, _total)
            _split_example_paths.append(_example_paths[_start:_end])
            _split_example_labels.append(_example_labels[_start:_end])
            _split_tfrecords_save_paths.append(
                ops.join(save_dir, '{:s}_{:d}_{:d}.tfrecords'.format(_flags, _start, _end)))

        return _split_example_paths, _split_example_labels, _split_tfrecords_save_paths

    # make save dirs
    os.makedirs(save_dir, exist_ok=True)

    # (file name flag, index file, split name used in the log messages)
    split_tasks = (
        ('train', self._train_example_index_file_path, 'training'),
        ('val', self._val_example_index_file_path, 'validation'),
        ('test', self._test_example_index_file_path, 'testing'),
    )

    # A single pool serves all splits. Leaving the ``with`` block terminates the
    # workers: that is harmless on success (every task has been waited on) and
    # guarantees the processes are torn down when reading an index file or a
    # worker task fails.
    with multiprocessing.Pool(processes=CFG.TRAIN.CPU_MULTI_PROCESS_NUMS) as process_pool:
        for flags, index_file_path, split_name in split_tasks:
            log.info('Generating {:s} example tfrecords'.format(split_name))

            example_paths, example_labels = _read_training_example_index_file(index_file_path)
            paths_split, labels_split, save_paths = _split_writing_tfrecords_task(
                example_paths, example_labels, _flags=flags)

            async_results = [
                process_pool.apply_async(func=tf_io_pipline_tools.write_example_tfrecords,
                                         args=(chunk_paths, chunk_labels, chunk_save_path,))
                for chunk_paths, chunk_labels, chunk_save_path
                in zip(paths_split, labels_split, save_paths)
            ]

            # Let every chunk of this split finish first, then surface the first
            # worker failure; previously the results were discarded, so errors
            # raised inside write_example_tfrecords were silently lost.
            for async_result in async_results:
                async_result.wait()
            for async_result in async_results:
                async_result.get()

            log.info('Generate {:s} example tfrecords complete'.format(split_name))

    return
188 |
def _is_source_data_complete(self):
    """
    Tell whether every class image directory of the source dataset is present on disk.
    :return: True when the drawing, hentai, neural, porn and sexy image directories all exist,
             False as soon as one of them is missing
    """
    # Attribute names are resolved lazily so the check stops at the first missing directory
    required_dir_attrs = (
        '_drawing_image_dir',
        '_hentai_image_dir',
        '_neural_image_dir',
        '_porn_image_dir',
        '_sexy_image_dir',
    )
    return all(ops.exists(getattr(self, attr_name)) for attr_name in required_dir_attrs)
198 |
def _is_training_example_index_file_complete(self):
    """
    Tell whether the train / test / val example index files have all been generated.
    :return: True when all three index files exist, False otherwise
    """
    # Checked in train, test, val order; evaluation stops at the first missing file
    index_path_attrs = (
        '_train_example_index_file_path',
        '_test_example_index_file_path',
        '_val_example_index_file_path',
    )
    return all(ops.exists(getattr(self, attr_name)) for attr_name in index_path_attrs)
208 |
def _generate_training_example_index_file(self):
    """
    Generate the training example index files.

    Each class image folder is processed separately: its image paths are shuffled and split
    0.75 / 0.15 / 0.1 into training, testing and validation examples, so every class keeps
    those proportions in each subset. The merged subsets are shuffled again and written to
    ``train.txt``, ``test.txt`` and ``val.txt`` under ``self._dataset_dir``, one
    ``"<image path> <label index>"`` line per example.
    :return:
    """
    def _process_single_training_folder(_folder_dir):
        # The label name is the folder's base name; normpath keeps a trailing path separator
        # from turning it into an empty string (which would raise KeyError below)
        _folder_label_name = ops.basename(ops.normpath(_folder_dir))
        _folder_label_index = self._label_map[_folder_label_name]

        # glob() returns entries in arbitrary, filesystem-dependent order: sort them so the
        # split is reproducible when `random` is seeded. Sub-directories are not images,
        # so only regular files are indexed.
        _source_image_paths = sorted(
            _path for _path in glob.glob('{:s}/*'.format(_folder_dir)) if ops.isfile(_path))

        return ['{:s} {:d}\n'.format(s, _folder_label_index) for s in _source_image_paths]

    def _split_training_examples(_example_info):
        # Shuffle in place, then cut into the first 75% / next 15% / last 10%
        random.shuffle(_example_info)

        _example_nums = len(_example_info)
        _train_end = int(_example_nums * 0.75)
        _test_end = int(_example_nums * 0.9)

        return (_example_info[:_train_end],
                _example_info[_train_end:_test_end],
                _example_info[_test_end:])

    train_example_info = []
    test_example_info = []
    val_example_info = []

    for example_dir in [self._drawing_image_dir, self._hentai_image_dir,
                        self._neural_image_dir, self._porn_image_dir,
                        self._sexy_image_dir]:
        _train_tmp_info, _test_tmp_info, _val_tmp_info = \
            _split_training_examples(_process_single_training_folder(example_dir))

        train_example_info.extend(_train_tmp_info)
        test_example_info.extend(_test_tmp_info)
        val_example_info.extend(_val_tmp_info)

    for _file_name, _example_info in (('train.txt', train_example_info),
                                      ('test.txt', test_example_info),
                                      ('val.txt', val_example_info)):
        # Interleave the classes before writing each subset
        random.shuffle(_example_info)
        with open(ops.join(self._dataset_dir, _file_name), 'w') as _file:
            _file.write(''.join(_example_info))

    log.info('Generate training example index file complete')

    return
266 |
267 |
class NsfwDataFeeder(object):
    """
    Read training examples from tfrecords for nsfw model
    """
    def __init__(self, dataset_dir, flags='train'):
        """
        Set up a feeder for one dataset split.

        :param dataset_dir: dataset root directory; records are read from its
                            ``tfrecords`` sub-directory
        :param flags: split to read: 'train', 'test' or 'val' (case-insensitive)
        :raises ValueError: if ``<dataset_dir>/tfrecords`` does not exist or ``flags``
                            is not one of the supported splits
        """
        self._dataset_dir = dataset_dir

        self._tfrecords_dir = ops.join(dataset_dir, 'tfrecords')
        if not ops.exists(self._tfrecords_dir):
            raise ValueError('{:s} not exist, please check again'.format(self._tfrecords_dir))

        self._dataset_flags = flags.lower()
        if self._dataset_flags not in ['train', 'test', 'val']:
            raise ValueError('flags of the data feeder should be \'train\', \'test\', \'val\'')

        self._label_map = global_config.NSFW_LABEL_MAP

        self._prediction_map = global_config.NSFW_PREDICT_MAP

    @property
    def label_map(self):
        """
        :return: the label map taken from global_config.NSFW_LABEL_MAP
        """
        return self._label_map

    @property
    def prediction_map(self):
        """
        :return: the prediction map taken from global_config.NSFW_PREDICT_MAP
        """
        return self._prediction_map

    def inputs(self, batch_size, num_epochs):
        """
        Build the tf.data input pipeline for this split and return the next-batch tensors.

        Records are decoded, cropped (random crop + flip for 'train', center crop otherwise)
        and mean-normalized with the tf_io_pipline_tools helpers. The 'train' and 'val'
        splits are shuffled and repeated indefinitely; the 'test' split is read once (its
        tfrecords files are still visited in shuffled order).
        :param batch_size: number of examples per batch
        :param num_epochs: currently unused -- 'train' / 'val' repeat indefinitely and 'test'
                           is read once regardless of this value. Kept for interface
                           compatibility.
        :return: A tuple (images, labels), where:
            * images is a float32 tensor with shape [batch_size, H, W, 3] (H, W being the
              configured crop size) holding pixel values with the per-channel means of
              tf_io_pipline_tools subtracted -- they are NOT rescaled to [-0.5, 0.5].
            * labels is an int32 tensor with shape [batch_size] with the true label,
              a number in the range [0, CLASS_NUMS). The final 'test' batch may be smaller.
        :raises ValueError: if no tfrecords file exists for this split
        """
        tfrecords_file_paths = glob.glob('{:s}/{:s}*.tfrecords'.format(self._tfrecords_dir, self._dataset_flags))
        if not tfrecords_file_paths:
            # Fail while building the graph instead of yielding an empty dataset at run time
            raise ValueError('No {:s} tfrecords found in {:s}'.format(
                self._dataset_flags, self._tfrecords_dir))
        random.shuffle(tfrecords_file_paths)

        with tf.name_scope('input_tensor'):

            # TFRecordDataset opens a binary file and reads one record at a time.
            # `tfrecords_file_paths` could also be a list of filenames, which will be read in order.
            dataset = tf.data.TFRecordDataset(tfrecords_file_paths)

            # The map transformation takes a function and applies it to every element
            # of the dataset.
            dataset = dataset.map(map_func=tf_io_pipline_tools.decode,
                                  num_parallel_calls=CFG.TRAIN.CPU_MULTI_PROCESS_NUMS)
            if self._dataset_flags == 'train':
                dataset = dataset.map(map_func=tf_io_pipline_tools.augment_for_train,
                                      num_parallel_calls=CFG.TRAIN.CPU_MULTI_PROCESS_NUMS)
            else:
                dataset = dataset.map(map_func=tf_io_pipline_tools.augment_for_validation,
                                      num_parallel_calls=CFG.TRAIN.CPU_MULTI_PROCESS_NUMS)
            dataset = dataset.map(map_func=tf_io_pipline_tools.normalize,
                                  num_parallel_calls=CFG.TRAIN.CPU_MULTI_PROCESS_NUMS)

            # The shuffle transformation uses a finite-sized buffer to shuffle elements
            # in memory. The parameter is the number of elements in the buffer. For
            # completely uniform shuffling, set the parameter to be the same as the
            # number of elements in the dataset.
            if self._dataset_flags != 'test':
                dataset = dataset.shuffle(buffer_size=5000)
                # repeat indefinitely (num_epochs is not used)
                dataset = dataset.repeat()

            dataset = dataset.batch(batch_size)

            iterator = dataset.make_one_shot_iterator()

        return iterator.get_next(name='{:s}_IteratorGetNext'.format(self._dataset_flags))
358 |
359 |
if __name__ == '__main__':

    # init args
    args = init_args()

    # Explicit check rather than `assert`, which is removed when Python runs with -O
    if not ops.exists(args.dataset_dir):
        raise ValueError('{:s} not exist'.format(args.dataset_dir))

    producer = NsfwDataProducer(dataset_dir=args.dataset_dir)
    producer.print_label_map()
    producer.generate_tfrecords(save_dir=args.tfrecords_dir, step_size=10000)
370 |
--------------------------------------------------------------------------------
/data_provider/tf_io_pipline_tools.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | # @Time : 19-2-15 下午2:13
4 | # @Author : MaybeShewill-CV
5 | # @Site : https://github.com/MaybeShewill-CV/CRNN_Tensorflow
6 | # @File : tf_io_pipline_tools.py
7 | # @IDE: PyCharm
8 | """
9 | Some tensorflow records io tools
10 | """
11 | import os
12 | import os.path as ops
13 |
14 | import cv2
15 | import tensorflow as tf
16 | import glog as log
17 |
18 | from config import global_config
19 |
20 |
# Global configuration object; CFG.TRAIN provides the image and crop sizes used below
CFG = global_config.cfg

# Per-channel ImageNet mean pixel values, subtracted from images in normalize()
_R_MEAN = 123.68
_G_MEAN = 116.78
_B_MEAN = 103.94
# Ordered B, G, R -- presumably to match the BGR channel order of images decoded with
# cv2.imread in write_example_tfrecords; TODO confirm
_CHANNEL_MEANS = [_B_MEAN, _G_MEAN, _R_MEAN]
27 |
28 |
def int64_feature(value):
    """
    Wrap a single integer into a tf.train.Feature backed by an Int64List.
    :param value: the integer to store
    :return: tf.train.Feature holding ``[value]``
    """
    int64_list = tf.train.Int64List(value=[value])
    return tf.train.Feature(int64_list=int64_list)
35 |
36 |
def bytes_feature(value):
    """
    Wrap a single bytes value into a tf.train.Feature backed by a BytesList.
    :param value: the bytes to store
    :return: tf.train.Feature holding ``[value]``
    """
    bytes_list = tf.train.BytesList(value=[value])
    return tf.train.Feature(bytes_list=bytes_list)
44 |
45 |
def write_example_tfrecords(example_paths, example_labels, tfrecords_path):
    """
    Write a list of JPEG images and their labels into one tfrecords file.

    Each image is read with cv2 as a 3-channel image, resized to
    (CFG.TRAIN.IMG_WIDTH, CFG.TRAIN.IMG_HEIGHT) and stored as raw bytes together with
    its height, width, depth (3) and label. Files that do not end with the JPEG
    end-of-image marker (0xFF 0xD9), or that cv2 cannot decode, are logged and skipped.
    :param example_paths: image file paths
    :param example_labels: integer labels, index-aligned with ``example_paths``
    :param tfrecords_path: output tfrecords file; its parent directory is created if needed
    :return:
    """
    def _has_jpeg_end_marker(_path):
        # Only the last two bytes matter, so seek to them instead of reading the whole file
        with open(_path, 'rb') as _f:
            _f.seek(0, os.SEEK_END)
            if _f.tell() < 2:
                return False
            _f.seek(-2, os.SEEK_END)
            return _f.read(2) == b'\xff\xd9'

    _tfrecords_dir = ops.split(tfrecords_path)[0]
    # A bare file name has no directory part, and os.makedirs('') would raise
    if _tfrecords_dir:
        os.makedirs(_tfrecords_dir, exist_ok=True)

    log.info('Writing {:s}....'.format(tfrecords_path))

    with tf.python_io.TFRecordWriter(tfrecords_path) as _writer:
        for _index, _example_path in enumerate(example_paths):

            if not _has_jpeg_end_marker(_example_path):
                log.error('Image file {:s} is not complete'.format(_example_path))
                continue

            _example_image = cv2.imread(_example_path, cv2.IMREAD_COLOR)
            if _example_image is None:
                # cv2.imread reports unreadable files by returning None, not by raising
                log.error('Image file {:s} can not be decoded'.format(_example_path))
                continue

            _example_image = cv2.resize(_example_image,
                                        dsize=(CFG.TRAIN.IMG_WIDTH, CFG.TRAIN.IMG_HEIGHT),
                                        interpolation=cv2.INTER_CUBIC)
            # tobytes() is the non-deprecated equivalent of ndarray.tostring()
            _example_image_raw = _example_image.tobytes()

            _example = tf.train.Example(
                features=tf.train.Features(
                    feature={
                        'height': int64_feature(CFG.TRAIN.IMG_HEIGHT),
                        'width': int64_feature(CFG.TRAIN.IMG_WIDTH),
                        'depth': int64_feature(3),
                        'label': int64_feature(example_labels[_index]),
                        'image_raw': bytes_feature(_example_image_raw)
                    }))
            _writer.write(_example.SerializeToString())

    log.info('Writing {:s} complete'.format(tfrecords_path))

    return
88 |
89 |
def decode(serialized_example):
    """
    Parse one serialized tf.train.Example into an (image, label) pair.
    :param serialized_example: string tensor holding one serialized tf.train.Example
    :return: image -- uint8 tensor reshaped to [CFG.TRAIN.IMG_HEIGHT, CFG.TRAIN.IMG_WIDTH, 3];
             label -- int32 scalar tensor
    """
    # All five keys are written by write_example_tfrecords, so none gets a default value
    feature_spec = {
        'image_raw': tf.FixedLenFeature([], tf.string),
        'label': tf.FixedLenFeature([], tf.int64),
        'height': tf.FixedLenFeature([], tf.int64),
        'width': tf.FixedLenFeature([], tf.int64),
        'depth': tf.FixedLenFeature([], tf.int64)
    }
    parsed = tf.parse_single_example(serialized_example, features=feature_spec)

    # The stored height/width/depth are parsed but the reshape uses the configured size
    flat_pixels = tf.decode_raw(parsed['image_raw'], tf.uint8)
    target_shape = tf.stack([CFG.TRAIN.IMG_HEIGHT, CFG.TRAIN.IMG_WIDTH, 3])
    image = tf.reshape(flat_pixels, target_shape)

    # Narrow the int64 label to int32
    label = tf.cast(parsed['label'], tf.int32)

    return image, label
116 |
117 |
def augment_for_train(image, label):
    """
    Training-time augmentation: random crop to the configured crop size, then a random
    horizontal flip. The label is passed through unchanged.
    :param image: decoded image tensor, as produced by decode
    :param label: label tensor, returned unchanged
    :return: (image of shape [CFG.TRAIN.CROP_IMG_HEIGHT, CFG.TRAIN.CROP_IMG_WIDTH, 3], label)
    """
    # tf.set_random_seed() returns None, so it cannot serve as an op-level `seed` argument.
    # Set the graph-level seed explicitly and leave the op seeds unset; TensorFlow then
    # derives deterministic op seeds from the graph seed.
    tf.set_random_seed(1234)

    # first apply random crop
    image = tf.image.random_crop(value=image,
                                 size=[CFG.TRAIN.CROP_IMG_HEIGHT, CFG.TRAIN.CROP_IMG_WIDTH, 3],
                                 name='crop_image')
    # apply random flip
    image = tf.image.random_flip_left_right(image=image)

    return image, label
134 |
135 |
def augment_for_validation(image, label):
    """
    Evaluation-time preprocessing: a deterministic center crop to the configured crop size.
    :param image: decoded image tensor of shape [CFG.TRAIN.IMG_HEIGHT, CFG.TRAIN.IMG_WIDTH, 3]
                  (as produced by decode)
    :param label: label tensor, returned unchanged
    :return: (image of shape [CFG.TRAIN.CROP_IMG_HEIGHT, CFG.TRAIN.CROP_IMG_WIDTH, 3], label)
    """
    crop_height = CFG.TRAIN.CROP_IMG_HEIGHT
    crop_width = CFG.TRAIN.CROP_IMG_WIDTH

    # Integer offsets instead of tf.image.central_crop: central_crop derives the box from a
    # float fraction and keeps `size - 2 * offset` pixels, which misses the crop size by one
    # whenever (IMG - CROP) is odd or the fraction rounds badly. Explicit offsets always give
    # exactly the training crop size and work for non-square sizes too.
    offset_height = (CFG.TRAIN.IMG_HEIGHT - crop_height) // 2
    offset_width = (CFG.TRAIN.IMG_WIDTH - crop_width) // 2
    image = tf.image.crop_to_bounding_box(image, offset_height, offset_width,
                                          crop_height, crop_width)

    return image, label
151 |
152 |
def normalize(image, label):
    """
    Normalize the image by subtracting the per-channel ImageNet mean (_CHANNEL_MEANS).
    :param image: image tensor of static rank 3, [height, width, channels]
    :param label: label tensor, returned unchanged
    :return: (float32 mean-subtracted image, label)
    :raises ValueError: if the image's static rank is not 3
    """
    rank = image.get_shape().ndims
    if rank != 3:
        raise ValueError('Input must be of size [height, width, C>0]')

    pixels = tf.cast(image, dtype=tf.float32)
    # Shape the means as [1, 1, C] so they broadcast over height and width
    channel_means = tf.expand_dims(tf.expand_dims(_CHANNEL_MEANS, 0), 0)

    normalized = pixels - channel_means
    return normalized, label
168 |
--------------------------------------------------------------------------------
/docker_container/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | # @Time : 19-3-14 下午9:48
4 | # @Author : Luo Yao
5 | # @Site : http://icode.baidu.com/repos/baidu/personal-code/Luoyao
6 | # @File : __init__.py.py
7 | # @IDE: PyCharm
--------------------------------------------------------------------------------
/docker_container/python_client.py:
--------------------------------------------------------------------------------
1 | '''
2 | Send JPEG image to tensorflow_model_server loaded with the NSFW classification model.
3 |
4 | Hint: the code has been compiled together with TensorFlow serving
5 | and not locally. The client is called in the TensorFlow Docker container
6 | '''
7 |
8 | import time
9 | from argparse import ArgumentParser
10 |
11 | import grpc
12 | import numpy as np
13 | import cv2
14 | from tensorflow.contrib.util import make_tensor_proto
15 |
16 | from tensorflow_serving.apis import predict_pb2
17 | from tensorflow_serving.apis import prediction_service_pb2_grpc
18 |
19 |
def parse_args(argv=None):
    """
    Parse the command line of the prediction client.
    :param argv: argument list to parse; None (the default) parses sys.argv[1:]
    :return: tuple (server, image, image_path, batch_mode); batch_mode is True when
             --batch_mode is 'true' (case-insensitive)
    """
    parser = ArgumentParser(description='Request a TensorFlow server for a prediction on the image')
    parser.add_argument('-s', '--server',
                        dest='server',
                        default='0.0.0.0:9000',
                        help='prediction service host:port')
    parser.add_argument("-i", "--image",
                        dest="image",
                        default='',
                        help="path to image in JPEG format", )
    parser.add_argument('-p', '--image_path',
                        dest='image_path',
                        default='/home/baidu/Pictures/dota2_1.jpg',
                        help='path to the image file to send', )
    parser.add_argument('-b', '--batch_mode',
                        dest='batch_mode',
                        default='true',
                        help='send image as batch or one-by-one')
    args = parser.parse_args(argv)

    # Accept 'True' / 'TRUE' as well as 'true'
    return args.server, args.image, args.image_path, args.batch_mode.lower() == 'true'
41 |
42 |
def main():
    """
    Send one image, replicated into a batch, to the 'nsfw' model served by
    tensorflow_model_server (signature 'classify_result'), then print the response and
    the elapsed request time.

    Only batch mode is implemented; when batch mode is off the function returns without
    sending a request.
    :return:
    """
    batch_size = 128
    input_size = 224

    server, image, image_path, batch_mode = parse_args()

    channel = grpc.insecure_channel(server)
    stub = prediction_service_pb2_grpc.PredictionServiceStub(channel)

    # -i/--image, when given, takes precedence over -p/--image_path
    source_path = image if image else image_path
    image = cv2.imread(source_path, cv2.IMREAD_COLOR)
    if image is None:
        # cv2.imread returns None rather than raising for a missing or undecodable file
        print('Can not read image {:s}'.format(source_path))
        return
    image = cv2.resize(image, (input_size, input_size))
    image = np.array(image, np.float32)

    # batch_size copies of the same image: [batch_size, input_size, input_size, 3]
    image_list = np.repeat(image[np.newaxis, ...], batch_size, axis=0)

    start = time.time()

    if batch_mode:
        print('In batch mode')
        request = predict_pb2.PredictRequest()
        request.model_spec.name = 'nsfw'
        request.model_spec.signature_name = 'classify_result'

        request.inputs['input_tensor'].CopyFrom(make_tensor_proto(
            image_list, shape=[batch_size, input_size, input_size, 3]))

        try:
            result = stub.Predict(request, 60.0)
        except Exception as err:
            # Top-level client: report the RPC failure instead of dumping a traceback
            print(err)
            return
        print(result)
    else:
        return

    end = time.time()
    time_diff = end - start
    print('time elapased: {}'.format(time_diff))
86 | print('time elapased: {}'.format(time_diff))
87 |
88 |
# Run the client only when this file is executed as a script, not when it is imported
if __name__ == '__main__':
    main()
91 |
--------------------------------------------------------------------------------
/model/nsfw_export_saved_model/1/saved_model.pb:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MaybeShewill-CV/nsfw-classification-tensorflow/6dfcb16fd655e66b9dd83237bbe89e84aa5322b9/model/nsfw_export_saved_model/1/saved_model.pb
--------------------------------------------------------------------------------
/model/nsfw_export_saved_model/1/variables/variables.data-00000-of-00001:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MaybeShewill-CV/nsfw-classification-tensorflow/6dfcb16fd655e66b9dd83237bbe89e84aa5322b9/model/nsfw_export_saved_model/1/variables/variables.data-00000-of-00001
--------------------------------------------------------------------------------
/model/nsfw_export_saved_model/1/variables/variables.index:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MaybeShewill-CV/nsfw-classification-tensorflow/6dfcb16fd655e66b9dd83237bbe89e84aa5322b9/model/nsfw_export_saved_model/1/variables/variables.index
--------------------------------------------------------------------------------
/nsfw_model/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | # @Time : 19-2-14 下午5:38
4 | # @Author : Luo Yao
5 | # @Site : http://icode.baidu.com/repos/baidu/personal-code/Luoyao
6 | # @File : __init__.py.py
7 | # @IDE: PyCharm
8 |
--------------------------------------------------------------------------------
/nsfw_model/cnn_basenet.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | # @Time : 17-9-18 下午3:59
4 | # @Author : MaybeShewill-CV
5 | # @Site : https://github.com/MaybeShewill-CV/CRNN_Tensorflow
6 | # @File : cnn_basenet.py
7 | # @IDE: PyCharm Community Edition
8 | """
9 | The base convolution neural networks mainly implement some useful cnn functions
10 | """
11 | import tensorflow as tf
12 | from tensorflow.python.training import moving_averages
13 | from tensorflow.contrib.framework import add_model_variable
14 | import numpy as np
15 |
16 |
17 | class CNNBaseModel(object):
18 | """
19 | Base model for other specific cnn ctpn_models
20 | """
21 |
def __init__(self):
    """
    Nothing to initialise: the base model keeps no instance state.
    """
24 |
@staticmethod
def conv2d(inputdata, out_channel, kernel_size, padding='SAME',
           stride=1, w_init=None, b_init=None,
           split=1, use_bias=True, data_format='NHWC', name=None):
    """
    Packing the tensorflow conv2d function.
    :param name: name of the variable scope holding 'W' / 'b' and of the output identity op;
                 it is passed to tf.variable_scope, so it should not be None
    :param inputdata: A 4D tensorflow tensor which must have a known number of channels, but can
                      have other unknown dimensions.
    :param out_channel: number of output channels; must be divisible by ``split``
    :param kernel_size: int for a square kernel, or a [height, width] list / tuple
    :param padding: 'VALID' or 'SAME' (case-insensitive)
    :param stride: int for a square stride, or a [height, width] list / tuple
    :param w_init: initializer for convolution weights (default: variance scaling)
    :param b_init: initializer for bias (default: zeros)
    :param split: number of channel groups, as used in Alexnet mainly to save GPU memory;
                  the input channel count must be divisible by it
    :param use_bias: whether to use bias.
    :param data_format: 'NHWC' (default, tensorflow convention) or 'NCHW'
    :return: tf.Tensor named ``name``
    """
    with tf.variable_scope(name):
        in_shape = inputdata.get_shape().as_list()
        channel_axis = 3 if data_format == 'NHWC' else 1
        in_channel = in_shape[channel_axis]

        assert in_channel is not None, "[Conv2D] Input cannot have unknown channel!"
        assert in_channel % split == 0
        assert out_channel % split == 0

        padding = padding.upper()

        if isinstance(kernel_size, (list, tuple)):
            kernel_h, kernel_w = kernel_size[0], kernel_size[1]
        else:
            kernel_h = kernel_w = kernel_size
        # Integer division: `in_channel / split` would be a float dimension under Python 3
        filter_shape = [kernel_h, kernel_w, in_channel // split, out_channel]

        if isinstance(stride, (list, tuple)):
            stride_h, stride_w = stride[0], stride[1]
        else:
            stride_h = stride_w = stride
        strides = [1, stride_h, stride_w, 1] if data_format == 'NHWC' \
            else [1, 1, stride_h, stride_w]

        if w_init is None:
            w_init = tf.contrib.layers.variance_scaling_initializer()
        if b_init is None:
            b_init = tf.constant_initializer()

        w = tf.get_variable('W', filter_shape, initializer=w_init)
        b = None

        if use_bias:
            b = tf.get_variable('b', [out_channel], initializer=b_init)

        if split == 1:
            conv = tf.nn.conv2d(inputdata, w, strides, padding, data_format=data_format)
        else:
            # Grouped convolution: each input channel group is convolved with its own
            # slice of the kernel's output channels, then the results are concatenated
            inputs = tf.split(inputdata, split, channel_axis)
            kernels = tf.split(w, split, 3)
            outputs = [tf.nn.conv2d(i, k, strides, padding, data_format=data_format)
                       for i, k in zip(inputs, kernels)]
            conv = tf.concat(outputs, channel_axis)

        ret = tf.identity(tf.nn.bias_add(conv, b, data_format=data_format)
                          if use_bias else conv, name=name)

    return ret
92 |
@staticmethod
def relu(inputdata, name=None):
    """
    Rectified linear activation, max(inputdata, 0), element-wise.
    :param inputdata: input tensor
    :param name: optional op name
    :return: activated tensor
    """
    activated = tf.nn.relu(inputdata, name=name)
    return activated
102 |
@staticmethod
def sigmoid(inputdata, name=None):
    """
    Element-wise logistic sigmoid, 1 / (1 + exp(-inputdata)).
    :param inputdata: input tensor
    :param name: optional op name
    :return: activated tensor
    """
    activated = tf.nn.sigmoid(inputdata, name=name)
    return activated
112 |
@staticmethod
def maxpooling(inputdata, kernel_size, stride=None, padding='VALID',
               data_format='NHWC', name=None):
    """
    Max pooling wrapper around tf.nn.max_pool.
    :param inputdata: 4D input tensor
    :param kernel_size: int for a square window, or a [height, width] list
    :param stride: int or [height, width] list; defaults to ``kernel_size``
    :param padding: 'VALID' or 'SAME' (case-insensitive)
    :param data_format: 'NHWC' or 'NCHW'
    :param name: optional op name
    :return: pooled tensor
    """
    def _as_4d(value):
        # Expand an int or an [h, w] list into a 4-element window laid out per data_format
        if isinstance(value, list):
            size_h, size_w = value[0], value[1]
        else:
            size_h = size_w = value
        if data_format == 'NHWC':
            return [1, size_h, size_w, 1]
        return [1, 1, size_h, size_w]

    padding = padding.upper()

    if stride is None:
        stride = kernel_size

    return tf.nn.max_pool(value=inputdata, ksize=_as_4d(kernel_size), strides=_as_4d(stride),
                          padding=padding, data_format=data_format, name=name)
147 |
@staticmethod
def avgpooling(inputdata, kernel_size, stride=None, padding='VALID',
               data_format='NHWC', name=None):
    """
    Average pooling wrapper around tf.nn.avg_pool.
    :param inputdata: 4D input tensor
    :param kernel_size: int for a square window, or a [height, width] list / tuple
    :param stride: int or [height, width] list / tuple; defaults to ``kernel_size``
    :param padding: 'VALID' or 'SAME' (case-insensitive)
    :param data_format: 'NHWC' or 'NCHW'
    :param name: optional op name
    :return: pooled tensor
    """
    # Normalise the padding string the same way maxpooling does, so 'same' / 'valid' work
    padding = padding.upper()

    if stride is None:
        stride = kernel_size

    if isinstance(kernel_size, (list, tuple)):
        kernel_h, kernel_w = kernel_size[0], kernel_size[1]
    else:
        kernel_h = kernel_w = kernel_size

    if isinstance(stride, (list, tuple)):
        stride_h, stride_w = stride[0], stride[1]
    else:
        stride_h = stride_w = stride

    if data_format == 'NHWC':
        kernel = [1, kernel_h, kernel_w, 1]
        strides = [1, stride_h, stride_w, 1]
    else:
        kernel = [1, 1, kernel_h, kernel_w]
        strides = [1, 1, stride_h, stride_w]

    return tf.nn.avg_pool(value=inputdata, ksize=kernel, strides=strides, padding=padding,
                          data_format=data_format, name=name)
171 |
@staticmethod
def globalavgpooling(inputdata, data_format='NHWC', name=None):
    """
    Global average pooling over the spatial axes of a 4D tensor.
    :param inputdata: 4D input tensor
    :param data_format: 'NHWC' (averages axes 1 and 2) or 'NCHW' (averages axes 2 and 3)
    :param name: optional op name
    :return: tensor of shape [batch, channels]
    """
    assert inputdata.shape.ndims == 4
    assert data_format in ['NHWC', 'NCHW']

    if data_format == 'NHWC':
        spatial_axes = [1, 2]
    else:
        spatial_axes = [2, 3]

    return tf.reduce_mean(input_tensor=inputdata, axis=spatial_axes, name=name)
187 |
@staticmethod
def layernorm(inputdata, epsilon=1e-5, use_bias=True, use_scale=True,
              data_format='NHWC', name=None):
    """
    Layer normalization: normalise each example over all of its non-batch axes, then apply
    an optional per-channel affine transform. beta / gamma are created with tf.get_variable
    in the caller's current variable scope.
    :param name: name of the output op
    :param inputdata: 2D [batch, channels] or 4D tensor
    :param epsilon: epsilon to avoid divide-by-zero.
    :param use_bias: whether to learn a per-channel offset (beta); otherwise zero.
    :param use_scale: whether to learn a per-channel scale (gamma); otherwise one.
    :param data_format: channel layout of 4D inputs, 'NHWC' or 'NCHW'; ignored for 2D inputs
    :return: normalised tensor with the same shape as ``inputdata``
    """
    shape = inputdata.get_shape().as_list()
    ndims = len(shape)
    assert ndims in [2, 4]

    mean, var = tf.nn.moments(inputdata, list(range(1, len(shape))), keep_dims=True)

    # The 2D case is checked first: a [batch, channels] input must get a [1, channels]
    # parameter shape for either data_format, otherwise the NCHW rank-4 shape would
    # broadcast the output up to rank 4.
    if ndims == 2:
        channel = shape[1]
        new_shape = [1, channel]
    elif data_format == 'NCHW':
        channel = shape[1]
        new_shape = [1, channel, 1, 1]
    else:
        channel = shape[-1]
        new_shape = [1, 1, 1, channel]

    if use_bias:
        beta = tf.get_variable('beta', [channel], initializer=tf.constant_initializer())
        beta = tf.reshape(beta, new_shape)
    else:
        beta = tf.zeros([1] * ndims, name='beta')
    if use_scale:
        gamma = tf.get_variable('gamma', [channel], initializer=tf.constant_initializer(1.0))
        gamma = tf.reshape(gamma, new_shape)
    else:
        gamma = tf.ones([1] * ndims, name='gamma')

    return tf.nn.batch_normalization(inputdata, mean, var, beta, gamma, epsilon, name=name)
227 |
@staticmethod
def instancenorm(inputdata, epsilon=1e-5, data_format='NHWC', use_affine=True, name=None):
    """
    Instance normalization: normalise every (example, channel) slice over its spatial axes.
    :param name: name of the output op when ``use_affine`` is True
    :param inputdata: 4D input tensor with a statically known channel dimension
    :param epsilon: variance epsilon
    :param data_format: 'NHWC' or 'NCHW'
    :param use_affine: learn per-channel beta / gamma variables (created with tf.get_variable)
    :return: normalised tensor
    :raises ValueError: if the input is not 4D or its channel dimension is unknown
    """
    shape = inputdata.get_shape().as_list()
    if len(shape) != 4:
        raise ValueError("Input data of instancebn layer has to be 4D tensor")

    channels_last = data_format == 'NHWC'
    axis = [1, 2] if channels_last else [2, 3]
    ch = shape[3] if channels_last else shape[1]
    new_shape = [1, 1, 1, ch] if channels_last else [1, ch, 1, 1]
    if ch is None:
        raise ValueError("Input of instancebn require known channel!")

    mean, var = tf.nn.moments(inputdata, axis, keep_dims=True)

    if not use_affine:
        # NOTE: this path names its op 'output' and does not use ``name``
        return tf.divide(inputdata - mean, tf.sqrt(var + epsilon), name='output')

    beta = tf.reshape(tf.get_variable('beta', [ch], initializer=tf.constant_initializer()),
                      new_shape)
    gamma = tf.reshape(tf.get_variable('gamma', [ch], initializer=tf.constant_initializer(1.0)),
                       new_shape)
    return tf.nn.batch_normalization(inputdata, mean, var, beta, gamma, epsilon, name=name)
264 |
@staticmethod
def dropout(inputdata, keep_prob, noise_shape=None, name=None):
    """
    Dropout through tf.nn.dropout, using its TF1 ``keep_prob`` argument.
    :param inputdata: input tensor
    :param keep_prob: probability of keeping each element
    :param noise_shape: optional shape of the random keep/drop mask
    :param name: optional op name
    :return: tensor with the same shape as ``inputdata``
    """
    dropped = tf.nn.dropout(inputdata, keep_prob=keep_prob,
                            noise_shape=noise_shape, name=name)
    return dropped
276 |
@staticmethod
def fullyconnect(inputdata, out_dim, w_init=None, b_init=None,
                 use_bias=True, name=None):
    """
    Fully-Connected layer built on `tf.layers.dense`. Every dimension after the first is
    flattened, so the input becomes [batch, features] before the dense layer.

    :param inputdata: a tensor to be flattened except for the first dimension.
    :param out_dim: output dimension
    :param w_init: initializer for w. Defaults to `variance_scaling_initializer`.
    :param b_init: initializer for b. Defaults to zero
    :param use_bias: whether to use bias.
    :param name: layer name passed to tf.layers.dense
    :return: tf.Tensor: a 2D [batch, out_dim] tensor, passed through an identity op named 'output'
    """
    feature_dims = inputdata.get_shape().as_list()[1:]
    if None in feature_dims:
        # Some feature dimension is only known at run time: flatten with a dynamic shape
        flat_shape = tf.stack([tf.shape(inputdata)[0], -1])
    else:
        flat_shape = [-1, int(np.prod(feature_dims))]
    flattened = tf.reshape(inputdata, flat_shape)

    kernel_init = tf.contrib.layers.variance_scaling_initializer() if w_init is None else w_init
    bias_init = tf.constant_initializer() if b_init is None else b_init

    return tf.layers.dense(inputs=flattened,
                           activation=lambda x: tf.identity(x, name='output'),
                           use_bias=use_bias, name=name,
                           kernel_initializer=kernel_init,
                           bias_initializer=bias_init,
                           trainable=True, units=out_dim)
309 |
@staticmethod
def layerbn(inputdata, is_training, name, momentum=0.999, eps=1e-3):
    """
    Batch normalization through tf.layers.batch_normalization.
    :param inputdata: input tensor
    :param is_training: passed as ``training``: whether to normalise with batch statistics
    :param name: layer name
    :param momentum: momentum of the moving mean / variance
    :param eps: variance epsilon
    :return: normalised tensor
    """
    bn_kwargs = dict(training=is_training, name=name, momentum=momentum, epsilon=eps)
    return tf.layers.batch_normalization(inputs=inputdata, **bn_kwargs)
324 |
@staticmethod
def layerbn_distributed(list_input, stats_mode, data_format='NHWC',
                        float_type=tf.float32, trainable=True,
                        use_gamma=True, use_beta=True, bn_epsilon=1e-5,
                        bn_ema=0.9, name='BatchNorm'):
    """
    Batch norm for distributed (multi-GPU) training: batch statistics are pooled over the
    inputs of all devices and every device normalises its own input with them.

    :param list_input: list of 2D [batch, channels] or 4D tensors, one per GPU; tensor ``i``
        is processed on '/gpu:i'. Pooling averages per-device means, which assumes every
        device holds the same number of examples.
    :param stats_mode: string; if it contains 'train' the pooled batch statistics are used,
        otherwise the moving mean / variance. If it also contains 'gather', the device whose
        variable scope is not reusing (device 0) updates the moving averages.
    :param data_format: 'NHWC' or 'NCHW' (channel axis of 4D inputs)
    :param float_type: dtype of the created variables
    :param trainable: whether beta / gamma are trainable
    :param use_gamma: learn a per-channel scale (otherwise a constant 1)
    :param use_beta: learn a per-channel offset (otherwise a constant 0)
    :param bn_epsilon: variance epsilon
    :param bn_ema: decay of the moving averages
    :param name: variable scope shared by all devices
    :return: list of normalised tensors, one per input
    :raises TypeError: for an unsupported data_format
    """

    def _get_bn_variables(_n_out, _use_scale, _use_bias, _trainable, _float_type):

        if _use_bias:
            _beta = tf.get_variable('beta', [_n_out],
                                    initializer=tf.constant_initializer(),
                                    trainable=_trainable,
                                    dtype=_float_type)
        else:
            _beta = tf.zeros([_n_out], name='beta')
        if _use_scale:
            _gamma = tf.get_variable('gamma', [_n_out],
                                     initializer=tf.constant_initializer(1.0),
                                     trainable=_trainable,
                                     dtype=_float_type)
        else:
            _gamma = tf.ones([_n_out], name='gamma')

        _moving_mean = tf.get_variable('moving_mean', [_n_out],
                                       initializer=tf.constant_initializer(),
                                       trainable=False,
                                       dtype=_float_type)
        _moving_var = tf.get_variable('moving_variance', [_n_out],
                                      initializer=tf.constant_initializer(1),
                                      trainable=False,
                                      dtype=_float_type)
        return _beta, _gamma, _moving_mean, _moving_var

    def _update_bn_ema(_xn, _batch_mean, _batch_var, _moving_mean, _moving_var, _decay):
        # Works only on its own arguments (no closure over the caller's loop variables)
        _update_op1 = moving_averages.assign_moving_average(
            _moving_mean, _batch_mean, _decay, zero_debias=False,
            name='mean_ema_op')
        _update_op2 = moving_averages.assign_moving_average(
            _moving_var, _batch_var, _decay, zero_debias=False,
            name='var_ema_op')
        add_model_variable(_moving_mean)
        add_model_variable(_moving_var)

        # seems faster than delayed update, but might behave otherwise in distributed settings.
        with tf.control_dependencies([_update_op1, _update_op2]):
            return tf.identity(_xn, name='output')

    # ======================== Checking valid values =========================
    if data_format not in ['NHWC', 'NCHW']:
        raise TypeError(
            "Only two data formats are supported at this moment: 'NHWC' or 'NCHW', "
            "%s is an unknown data format." % data_format)
    assert isinstance(list_input, list)

    # ======================== Setting default values =========================
    shape = list_input[0].get_shape().as_list()
    ndims = len(shape)
    assert ndims in [2, 4]
    n_out = shape[-1]
    if data_format == 'NCHW':
        n_out = shape[1]

    # Statistics are pooled over every axis except the channel axis
    if ndims == 2:
        reduce_axes = [0]
    elif data_format == 'NHWC':
        reduce_axes = [0, 1, 2]
    else:
        reduce_axes = [0, 2, 3]

    # [C]-shaped statistics / parameters broadcast against the last axis; NCHW 4D inputs
    # need them reshaped to [1, C, 1, 1] instead
    needs_reshape = ndims == 4 and data_format == 'NCHW'

    def _to_broadcast_shape(_t):
        return tf.reshape(_t, [1, -1, 1, 1]) if needs_reshape else _t

    # ======================== Main operations =============================
    means = []
    square_means = []
    for i, _device_input in enumerate(list_input):
        with tf.device('/gpu:%d' % i):
            batch_mean = tf.reduce_mean(_device_input, reduce_axes)
            batch_square_mean = tf.reduce_mean(tf.square(_device_input), reduce_axes)
            means.append(batch_mean)
            square_means.append(batch_square_mean)

    # if your GPUs have NVLinks and you've install NCCL2, you can change `/cpu:0` to `/gpu:0`
    with tf.device('/cpu:0'):
        input_shape = tf.shape(list_input[0])
        # Number of values each channel statistic is computed over, across all devices
        num = input_shape[reduce_axes[0]]
        for axis in reduce_axes[1:]:
            num = num * input_shape[axis]
        num = num * len(list_input)
        mean = tf.reduce_mean(means, axis=0)
        var = tf.reduce_mean(square_means, axis=0) - tf.square(mean)
        var *= tf.cast(num, float_type) / tf.cast(num - 1, float_type)  # unbiased variance

    list_output = []
    for i, _device_input in enumerate(list_input):
        with tf.device('/gpu:%d' % i):
            with tf.variable_scope(name, reuse=i > 0):
                beta, gamma, moving_mean, moving_var = _get_bn_variables(
                    n_out, use_gamma, use_beta, trainable, float_type)

                if 'train' in stats_mode:
                    xn = tf.nn.batch_normalization(
                        _device_input, _to_broadcast_shape(mean), _to_broadcast_shape(var),
                        _to_broadcast_shape(beta), _to_broadcast_shape(gamma), bn_epsilon)
                    if tf.get_variable_scope().reuse or 'gather' not in stats_mode:
                        list_output.append(xn)
                    else:
                        # gather stats and it is the main gpu device.
                        xn = _update_bn_ema(xn, mean, var, moving_mean, moving_var, bn_ema)
                        list_output.append(xn)
                else:
                    xn = tf.nn.batch_normalization(
                        _device_input,
                        _to_broadcast_shape(moving_mean), _to_broadcast_shape(moving_var),
                        _to_broadcast_shape(beta), _to_broadcast_shape(gamma), bn_epsilon)
                    list_output.append(xn)

    return list_output
441 |
442 | @staticmethod
def layergn(inputdata, name, group_size=32, esp=1e-5):
    """
    Group normalization for an NHWC tensor.

    The channels are split into ``group_size`` groups (capped at the channel count).
    Each group is normalized with its own mean/variance over (channels in the group,
    H, W), then a learned per-channel scale ``gamma`` and shift ``beta`` are applied.

    :param inputdata: 4D NHWC tensor whose channel count is statically known
    :param name: variable scope name
    :param group_size: number of groups (not channels per group); after being capped
        at the channel count it must divide the channel count
    :param esp: epsilon added to the variance for numerical stability
    :return: normalized tensor in NHWC layout, same shape as ``inputdata``
    :raises ValueError: if the channel count is unknown or not divisible by the
        number of groups
    """
    with tf.variable_scope(name):
        # Work in NCHW so the channel axis can be split into [groups, channels_per_group].
        inputdata = tf.transpose(inputdata, [0, 3, 1, 2])
        _, c, h, w = inputdata.get_shape().as_list()
        if c is None:
            raise ValueError('[GroupNorm] Input cannot have unknown channel!')
        group_size = min(group_size, c)
        if c % group_size != 0:
            raise ValueError('[GroupNorm] {:d} channels cannot be split into {:d} groups'.format(
                c, group_size))

        # Spatial sizes that are unknown at graph-construction time are read at run time.
        if h is None or w is None:
            dynamic_shape = tf.shape(inputdata)
            h = dynamic_shape[2] if h is None else h
            w = dynamic_shape[3] if w is None else w

        inputdata = tf.reshape(inputdata, [-1, group_size, c // group_size, h, w])
        mean, var = tf.nn.moments(inputdata, [2, 3, 4], keep_dims=True)
        inputdata = (inputdata - mean) / tf.sqrt(var + esp)

        # Per-channel gamma and beta. tf.get_variable (unlike tf.Variable) honours
        # variable_scope reuse, so the weights are shared when the scope is reused.
        gamma = tf.get_variable('gamma', shape=[c], dtype=tf.float32,
                                initializer=tf.constant_initializer(1.0))
        beta = tf.get_variable('beta', shape=[c], dtype=tf.float32,
                               initializer=tf.constant_initializer(0.0))
        gamma = tf.reshape(gamma, [1, c, 1, 1])
        beta = tf.reshape(beta, [1, c, 1, 1])

        # [n, groups, c // groups, h, w] -> [n, c, h, w] -> back to [n, h, w, c]
        output = tf.reshape(inputdata, [-1, c, h, w])
        output = output * gamma + beta
        output = tf.transpose(output, [0, 2, 3, 1])

    return output
472 |
473 | @staticmethod
def squeeze(inputdata, axis=None, name=None):
    """
    Remove size-1 dimensions from a tensor (thin wrapper around tf.squeeze).

    :param inputdata: input tensor
    :param axis: optional dimension indices to remove; forwarded to tf.squeeze
    :param name: optional op name
    :return: the squeezed tensor
    """
    squeezed = tf.squeeze(inputdata, axis=axis, name=name)
    return squeezed
483 |
484 | @staticmethod
def deconv2d(inputdata, out_channel, kernel_size, padding='SAME',
             stride=1, w_init=None, b_init=None,
             use_bias=True, activation=None, data_format='channels_last',
             trainable=True, name=None):
    """
    Transposed 2D convolution: a thin wrapper around tf.layers.conv2d_transpose.

    :param inputdata: 4D tensor that must have a known number of channels; other
        dimensions may be unknown
    :param out_channel: number of output channels
    :param kernel_size: kernel size forwarded to tf.layers.conv2d_transpose
    :param padding: 'VALID' or 'SAME' (case-insensitive)
    :param stride: stride forwarded to tf.layers.conv2d_transpose
    :param w_init: kernel initializer; variance scaling when None
    :param b_init: bias initializer; tf.constant_initializer() when None
    :param use_bias: whether to add a bias
    :param activation: optional activation applied to the result
    :param data_format: 'channels_last' (NHWC, the default) or 'channels_first' (NCHW)
    :param trainable: whether the created variables are trainable
    :param name: used both as the variable scope and as the layer name
    :return: the tf.Tensor produced by tf.layers.conv2d_transpose
    """
    with tf.variable_scope(name):
        channel_index = 1 if data_format != 'channels_last' else 3
        known_channels = inputdata.get_shape().as_list()[channel_index]
        assert known_channels is not None, "[Deconv2D] Input cannot have unknown channel!"

        kernel_initializer = (w_init if w_init is not None
                              else tf.contrib.layers.variance_scaling_initializer())
        bias_initializer = b_init if b_init is not None else tf.constant_initializer()

        # NOTE(review): the layer is also named `name` inside scope `name`; kept as-is
        # so existing checkpoints keep their variable names.
        return tf.layers.conv2d_transpose(inputs=inputdata,
                                          filters=out_channel,
                                          kernel_size=kernel_size,
                                          strides=stride,
                                          padding=padding.upper(),
                                          data_format=data_format,
                                          activation=activation,
                                          use_bias=use_bias,
                                          kernel_initializer=kernel_initializer,
                                          bias_initializer=bias_initializer,
                                          trainable=trainable,
                                          name=name)
528 |
529 | @staticmethod
def dilation_conv(input_tensor, k_size, out_dims, rate, padding='SAME',
                  w_init=None, b_init=None, use_bias=False, name=None):
    """
    Atrous (dilated) 2D convolution on an NHWC tensor, wrapping tf.nn.atrous_conv2d.

    :param input_tensor: 4D tensor whose channel count (axis 3) is statically known
    :param k_size: kernel size; an int for a square kernel or a [height, width]
        list/tuple
    :param out_dims: number of output channels
    :param rate: dilation rate
    :param padding: 'SAME' or 'VALID' (case-insensitive)
    :param w_init: weight initializer; variance scaling when None
    :param b_init: bias initializer; tf.constant_initializer() when None
    :param use_bias: whether to add a bias of shape [out_dims]
    :param name: variable scope name
    :return: convolution output tensor
    """
    with tf.variable_scope(name):
        in_channel = input_tensor.get_shape().as_list()[3]
        assert in_channel is not None, "[Conv2D] Input cannot have unknown channel!"

        padding = padding.upper()

        # Accept tuples as well as lists for a non-square kernel; previously a tuple
        # was treated as a scalar and produced a malformed filter shape.
        if isinstance(k_size, (list, tuple)):
            filter_shape = [k_size[0], k_size[1], in_channel, out_dims]
        else:
            filter_shape = [k_size, k_size, in_channel, out_dims]

        if w_init is None:
            w_init = tf.contrib.layers.variance_scaling_initializer()
        if b_init is None:
            b_init = tf.constant_initializer()

        w = tf.get_variable('W', filter_shape, initializer=w_init)
        b = tf.get_variable('b', [out_dims], initializer=b_init) if use_bias else None

        conv = tf.nn.atrous_conv2d(value=input_tensor, filters=w, rate=rate,
                                   padding=padding, name='dilation_conv')

        return tf.add(conv, b) if use_bias else conv
578 |
579 | @staticmethod
def spatial_dropout(input_tensor, keep_prob, is_training, name, seed=1234):
    """
    Spatial dropout: drop whole feature-map channels instead of single activations.

    A single keep/drop decision is drawn per (sample, channel) and broadcast over the
    spatial axes via a noise_shape of [batch, 1, 1, channels], which assumes NHWC input.

    :param input_tensor: 4D NHWC tensor; the batch size may be unknown
    :param keep_prob: probability of keeping a channel, forwarded to tf.nn.dropout
    :param is_training: scalar boolean tensor; dropout is applied only when True,
        otherwise the input is returned unchanged
    :param name: variable scope name
    :param seed: seed forwarded to tf.nn.dropout
    :return: tensor with the same shape as ``input_tensor``
    """

    def f1():
        # Read batch and channel sizes at run time: a tf.constant built from the static
        # shape fails when the batch dimension is unknown (None).
        dynamic_shape = tf.shape(input_tensor)
        noise_shape = tf.stack([dynamic_shape[0], 1, 1, dynamic_shape[3]])
        return tf.nn.dropout(input_tensor, keep_prob, noise_shape, seed=seed, name="spatial_dropout")

    def f2():
        return input_tensor

    with tf.variable_scope(name_or_scope=name):

        output = tf.cond(is_training, f1, f2)

        return output
604 |
605 | @staticmethod
def lrelu(inputdata, name, alpha=0.2):
    """
    Leaky ReLU built from two ReLUs: relu(x) - alpha * relu(-x).

    :param inputdata: input tensor
    :param name: variable scope name
    :param alpha: slope applied to the negative part of the input
    :return: tensor with the same shape as ``inputdata``
    """
    with tf.variable_scope(name):
        positive_part = tf.nn.relu(inputdata)
        negative_part = tf.nn.relu(-inputdata)
        return positive_part - alpha * negative_part
616 |
617 | @staticmethod
def pad(inputdata, paddings, name):
    """
    Pad a tensor with tf.pad inside its own variable scope.

    :param inputdata: tensor to pad
    :param paddings: per-dimension [before, after] padding amounts, forwarded to tf.pad
    :param name: variable scope name
    :return: the padded tensor
    """
    with tf.variable_scope(name_or_scope=name):
        padded = tf.pad(inputdata, paddings)
    return padded
627 |
--------------------------------------------------------------------------------
/nsfw_model/nsfw_classification_net.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | # @Time : 19-2-15 下午6:56
4 | # @Author : MaybeShewill-CV
5 | # @Site : https://github.com/MaybeShewill-CV/CRNN_Tensorflow
6 | # @File : nsfw_classification_net.py
7 | # @IDE: PyCharm
8 | """
9 | NSFW Classification Net Model
10 | """
11 | import tensorflow as tf
12 |
13 | from nsfw_model import cnn_basenet
14 | from config import global_config
15 |
16 |
# Project-wide configuration; NET.* and TRAIN.* settings are read from it below.
CFG = global_config.cfg
18 |
19 |
class NSFWNet(cnn_basenet.CNNBaseModel):
    """
    nsfw classification net

    A ResNet v2 style classifier with these stages:
      - a strided conv + max pool stem;
      - four block layers of pre-activation (bn -> relu -> conv) non-bottleneck blocks;
      - a final bn + relu;
      - global average pooling;
      - a bias-free fully connected layer producing the class logits.
    """
    def __init__(self, phase, resnet_size=CFG.NET.RESNET_SIZE):
        """

        :param phase: scalar tf.string tensor; the net is in training mode when it
            equals 'train' (the resulting flag is passed to the batch norm layers as
            ``is_training``)
        :param resnet_size: selects the block layout, see _get_block_sizes
        """
        super(NSFWNet, self).__init__()
        self._train_phase = tf.constant('train', dtype=tf.string)
        self._test_phase = tf.constant('test', dtype=tf.string)
        self._phase = phase
        self._is_training = self._init_phase()
        self._need_summary_feats_map = CFG.NET.NEED_SUMMARY_FEATS_MAP
        self._resnet_size = resnet_size
        self._block_sizes = self._get_block_sizes(self._resnet_size)
        # stride of the first block in each of the four block layers
        self._block_strides = [1, 2, 2, 2]

    def _init_phase(self):
        """
        Build the training flag.

        :return: scalar boolean tensor, True when the phase tensor equals 'train'
        """
        return tf.equal(self._phase, self._train_phase)

    @staticmethod
    def _get_block_sizes(resnet_size):
        """
        Retrieve the size of each block_layer in the ResNet model.
        The number of block layers used for the Resnet model varies according
        to the size of the model. This helper grabs the layer set we want, throwing
        an error if a non-standard size has been selected.
        Args:
            resnet_size: The number of convolutional layers needed in the model.
        Returns:
            A list of block sizes to use in building the model.
        Raises:
            ValueError: if invalid resnet_size is received.
        """
        # NOTE: both sizes map to the same layout, and _resnet_block_fn only builds
        # non-bottleneck blocks, so 32 and 50 currently produce the same network.
        choices = {
            32: [3, 4, 6, 3],
            50: [3, 4, 6, 3]
        }

        try:
            return choices[resnet_size]
        except KeyError:
            err = ('Could not find layers for selected Resnet size.\n'
                   'Size received: {}; sizes allowed: {}.'.format(resnet_size, choices.keys()))
            # `from None` keeps the internal KeyError out of the user-facing traceback
            raise ValueError(err) from None

    @staticmethod
    def _feature_map_summary(input_tensor, slice_nums, axis):
        """
        Add image summaries of a feature map.

        The tensor is split into ``slice_nums`` equal slices along ``axis``. One
        tf.summary.image is added per slice, named after the tensor's op.

        :param input_tensor: A Tensor; presumably 4D NHWC as tf.summary.image expects
            (TODO confirm)
        :param slice_nums: number of slices; must evenly divide the size of ``axis``
        :param axis: axis to split along
        :return: None
        """
        tensor_name = input_tensor.op.name

        split = tf.split(input_tensor, num_or_size_splits=slice_nums, axis=axis)
        for i in range(slice_nums):
            tf.summary.image(tensor_name + "/feature_maps_" + str(i), split[i])

    def _fixed_padding(self, inputs, kernel_size, name):
        """Pads the input along the spatial dimensions independently of input size.
        Args:
            inputs: A 4D NHWC tensor [batch, height_in, width_in, channels]. Only
                axes 1 and 2 are padded, so NCHW inputs are not supported.
            kernel_size: The kernel to be used in the conv2d or max_pool2d operation.
                Should be a positive integer.
            name: variable scope name
        Returns:
            A tensor with the same format as the input with the data either intact
            (if kernel_size == 1) or padded (if kernel_size > 1).
        """
        with tf.variable_scope(name_or_scope=name):
            pad_total = kernel_size - 1
            pad_beg = pad_total // 2
            pad_end = pad_total - pad_beg

            padded_inputs = self.pad(inputdata=inputs,
                                     paddings=[[0, 0], [pad_beg, pad_end],
                                               [pad_beg, pad_end], [0, 0]],
                                     name='pad')
        return padded_inputs

    def _conv2d_fixed_padding(self, inputs, kernel_size, output_dims, strides, name):
        """
        Bias-free convolution whose padding does not depend on the input size.

        For strides > 1 the input is padded explicitly (_fixed_padding) and convolved
        with 'VALID' padding; for stride 1 'SAME' padding is used.

        :param inputs: 4D NHWC input tensor
        :param kernel_size: square kernel size
        :param output_dims: number of output channels
        :param strides: square stride
        :param name: variable scope name
        :return: convolution output tensor
        """
        with tf.variable_scope(name_or_scope=name):
            if strides > 1:
                inputs = self._fixed_padding(inputs, kernel_size, name='fix_padding')

            result = self.conv2d(inputdata=inputs, out_channel=output_dims, kernel_size=kernel_size,
                                 stride=strides, padding=('SAME' if strides == 1 else 'VALID'),
                                 use_bias=False, name='conv')

        return result

    def _process_image_input_tensor(self, input_image_tensor, kernel_size,
                                    conv_stride, output_dims, pool_size, pool_stride):
        """
        Resnet entry: a fixed-padding convolution followed by a 'SAME' max pool.

        :param input_image_tensor: 4D NHWC image tensor
        :param kernel_size: kernel size of the initial convolution
        :param conv_stride: stride of the initial convolution
        :param output_dims: number of channels of the initial convolution
        :param pool_size: max pool window size
        :param pool_stride: max pool stride
        :return: the pooled feature map
        """
        inputs = self._conv2d_fixed_padding(
            inputs=input_image_tensor, kernel_size=kernel_size,
            strides=conv_stride, output_dims=output_dims, name='initial_conv_pad')
        inputs = tf.identity(inputs, 'initial_conv')

        inputs = self.maxpooling(inputdata=inputs, kernel_size=pool_size,
                                 stride=pool_stride, padding='SAME',
                                 name='initial_max_pool')

        return inputs

    def _resnet_block_fn(self, input_tensor, kernel_size, stride,
                         output_dims, name, projection_shortcut=None):
        """
        A single block for ResNet v2, without a bottleneck.
        Batch normalization then ReLu then convolution as described by:
          Identity Mappings in Deep Residual Networks
          https://arxiv.org/pdf/1603.05027.pdf
          by Kaiming He, Xiangyu Zhang, Shaoqing Ren, and Jian Sun, Jul 2016.
        :param input_tensor: 4D NHWC input tensor
        :param kernel_size: kernel size of both convolutions
        :param stride: stride of the first convolution (and of the projection, if any)
        :param output_dims: number of output channels
        :param name: variable scope name
        :param projection_shortcut: optional callable applied to the pre-activated
            input to build the shortcut; when None the raw input is the shortcut
        :return: block output (residual + shortcut)
        """
        with tf.variable_scope(name_or_scope=name):
            shortcut = input_tensor
            inputs = self.layerbn(inputdata=input_tensor, is_training=self._is_training, name='bn_1')
            inputs = self.relu(inputdata=inputs, name='relu_1')

            # v2 blocks project the pre-activated input, not the raw input
            if projection_shortcut is not None:
                shortcut = projection_shortcut(inputs)

            inputs = self._conv2d_fixed_padding(
                inputs=inputs, output_dims=output_dims,
                kernel_size=kernel_size, strides=stride, name='conv_pad_1')

            inputs = self.layerbn(inputdata=inputs, is_training=self._is_training, name='bn_2')
            inputs = self.relu(inputdata=inputs, name='relu_2')
            inputs = self._conv2d_fixed_padding(
                inputs=inputs, output_dims=output_dims, kernel_size=kernel_size, strides=1, name='conv_pad_2')

        return inputs + shortcut

    def _resnet_block_layer(self, input_tensor, kernel_size, stride, block_nums, output_dims, name):
        """
        Stack ``block_nums`` ResNet v2 blocks.

        The first block uses ``stride`` and a 1x1 projection shortcut. The remaining
        blocks use stride 1 and identity shortcuts.

        :param input_tensor: 4D NHWC input tensor
        :param kernel_size: kernel size of the block convolutions
        :param stride: stride of the first block
        :param block_nums: number of blocks in this layer
        :param output_dims: number of output channels of every block
        :param name: variable scope name
        :return: output of the last block
        """
        def projection_shortcut(_inputs):
            return self._conv2d_fixed_padding(
                inputs=_inputs, output_dims=output_dims, kernel_size=1,
                strides=stride, name='projection_shortcut')

        with tf.variable_scope(name):
            inputs = self._resnet_block_fn(input_tensor=input_tensor,
                                           kernel_size=kernel_size,
                                           output_dims=output_dims,
                                           projection_shortcut=projection_shortcut,
                                           stride=stride,
                                           name='init_block_fn')

            for index in range(1, block_nums):
                inputs = self._resnet_block_fn(input_tensor=inputs,
                                               kernel_size=kernel_size,
                                               output_dims=output_dims,
                                               projection_shortcut=None,
                                               stride=1,
                                               name='block_fn_{:d}'.format(index))
        return inputs

    def inference(self, input_tensor, name, reuse=False):
        """
        The main function that defines the ResNet.

        The stages are:
          - stem: a 7x7 stride-2 conv with 64 channels, then a 3x3 stride-2 max pool;
          - one block layer per entry of the block sizes, with 64 * 2**i channels;
          - bn + relu;
          - global average pooling;
          - a bias-free fully connected layer with CFG.TRAIN.CLASSES_NUMS outputs.
        :param input_tensor: 4D NHWC image tensor
        :param name: net name
        :param reuse: To build train graph, reuse=False. To build validation graph and share weights
        with train graph, reuse=True
        :return: last layer in the network. Not softmax-ed
        """
        with tf.variable_scope(name_or_scope=name, reuse=reuse):

            if self._need_summary_feats_map:
                self._feature_map_summary(input_tensor=input_tensor, slice_nums=1, axis=-1)

            # first layer process
            inputs = self._process_image_input_tensor(input_image_tensor=input_tensor,
                                                      kernel_size=7,
                                                      conv_stride=2,
                                                      output_dims=64,
                                                      pool_size=3,
                                                      pool_stride=2)
            if self._need_summary_feats_map:
                self._feature_map_summary(input_tensor=inputs, slice_nums=64, axis=-1)

            for index, block_nums in enumerate(self._block_sizes):
                output_dims = 64 * (2 ** index)

                inputs = self._resnet_block_layer(input_tensor=inputs,
                                                  kernel_size=3,
                                                  output_dims=output_dims,
                                                  block_nums=block_nums,
                                                  stride=self._block_strides[index],
                                                  name='resnet_block_layer_{:d}'.format(index + 1))

                if self._need_summary_feats_map:
                    self._feature_map_summary(input_tensor=inputs, slice_nums=output_dims, axis=-1)

            inputs = self.layerbn(inputdata=inputs, is_training=self._is_training, name='bn_after_block_layer')
            inputs = self.relu(inputdata=inputs, name='relu_after_block_layer')

            # global average pooling over H and W
            inputs = tf.reduce_mean(input_tensor=inputs, axis=[1, 2], keepdims=True, name='final_reduce_mean')
            inputs = tf.squeeze(input=inputs, axis=[1, 2], name='final_squeeze')

            final_logits = self.fullyconnect(inputdata=inputs, out_dim=CFG.TRAIN.CLASSES_NUMS,
                                             use_bias=False, name='final_logits')

        return final_logits

    def compute_loss(self, input_tensor, labels, name, reuse=False):
        """
        Build the training loss.

        The loss is the mean sparse softmax cross entropy plus
        CFG.TRAIN.WEIGHT_DECAY * the sum of tf.nn.l2_loss over trainable variables
        whose name does not contain 'bn'.
        NOTE(review): the L2 term covers every trainable variable in the default graph,
        not only those under ``name``.

        :param input_tensor: 4D NHWC image tensor
        :param labels: integer class ids, one per image
        :param name: net name, forwarded to inference
        :param reuse: forwarded to inference
        :return: scalar total loss tensor
        """
        labels = tf.cast(labels, tf.int64)

        inference_logits = self.inference(input_tensor=input_tensor,
                                          name=name,
                                          reuse=reuse)

        cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=inference_logits,
                                                                       labels=labels,
                                                                       name='cross_entropy_per_example')
        cross_entropy_loss = tf.reduce_mean(cross_entropy, name='cross_entropy')

        l2_loss = CFG.TRAIN.WEIGHT_DECAY * tf.add_n(
            [tf.nn.l2_loss(tf.cast(vv, tf.float32)) for vv in tf.trainable_variables()
             if 'bn' not in vv.name])

        total_loss = cross_entropy_loss + l2_loss

        return total_loss
293 |
294 |
if __name__ == '__main__':
    # Smoke test: build the training loss once, rebuild the loss and the raw logits
    # with shared weights (reuse=True), and print their static shapes.
    image_tensor = tf.placeholder(shape=[16, 256, 256, 3], dtype=tf.float32)
    label_tensor = tf.placeholder(shape=[16], dtype=tf.int32)

    net = NSFWNet(phase=tf.constant('train', dtype=tf.string))

    loss, loss_val = [
        net.compute_loss(input_tensor=image_tensor, labels=label_tensor, name='net', reuse=share)
        for share in (False, True)
    ]
    logits = net.inference(input_tensor=image_tensor, name='net', reuse=True)

    for built_tensor in (loss, loss_val, logits):
        print(built_tensor.get_shape().as_list())
320 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | easydict==1.9
2 | matplotlib==3.0.2
3 | numpy==1.15.1
4 | opencv_python==3.4.1.15
5 | tensorflow_gpu==1.15.4
6 | glog==0.3.1
7 | scikit_learn==0.20.2
8 |
--------------------------------------------------------------------------------
/tboard/nsfw_cls/events.out.tfevents.1551264389.baidu-pc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MaybeShewill-CV/nsfw-classification-tensorflow/6dfcb16fd655e66b9dd83237bbe89e84aa5322b9/tboard/nsfw_cls/events.out.tfevents.1551264389.baidu-pc
--------------------------------------------------------------------------------
/tools/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | # @Time : 19-2-14 下午5:39
4 | # @Author : Luo Yao
5 | # @Site : http://icode.baidu.com/repos/baidu/personal-code/Luoyao
6 | # @File : __init__.py.py
7 | # @IDE: PyCharm
--------------------------------------------------------------------------------
/tools/convert_tfjs_model.sh:
--------------------------------------------------------------------------------
#!/usr/bin/env bash
# Convert the exported SavedModel (./model/nsfw_export_saved_model) into a
# TensorFlow.js web model. The output node is the softmax named
# 'nsfw_cls_model/final_prediction' in tools/export_saved_model.py.
# Use the full flag name instead of relying on argparse prefix abbreviation.
tensorflowjs_converter --input_format=tf_saved_model \
    --output_node_names=nsfw_cls_model/final_prediction \
    --saved_model_tags=serve \
    ./model/nsfw_export_saved_model ./model/nsfw_web_model
--------------------------------------------------------------------------------
/tools/evaluate_nsfw.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | # @Time : 19-2-28 上午11:20
4 | # @Author : MaybeShewill-CV
5 | # @Site : https://github.com/MaybeShewill-CV/CRNN_Tensorflow
6 | # @File : evaluate_nsfw.py
7 | # @IDE: PyCharm
8 | """
9 | Evaluate nsfw model
10 | """
11 | import itertools
12 | import os.path as ops
13 | import argparse
14 | import glog as log
15 |
16 | import tensorflow as tf
17 | import numpy as np
18 | from sklearn.metrics import (confusion_matrix, precision_score, recall_score,
19 | precision_recall_curve, average_precision_score, f1_score)
20 | from sklearn.metrics import classification_report
21 | from sklearn.preprocessing import label_binarize
22 | from sklearn.utils.fixes import signature
23 | import matplotlib.pyplot as plt
24 |
25 | from config import global_config
26 | from nsfw_model import nsfw_classification_net
27 | from data_provider import nsfw_data_feed_pipline
28 |
CFG = global_config.cfg

# Per-channel pixel means; _CHANNEL_MEANS lists them in B, G, R order.
# NOTE(review): none of these are referenced elsewhere in this script.
_R_MEAN, _G_MEAN, _B_MEAN = 123.68, 116.78, 103.94
_CHANNEL_MEANS = [_B_MEAN, _G_MEAN, _R_MEAN]
35 |
36 |
def init_args():
    """
    Parse the command line arguments of the evaluation script.

    ``--dataset_dir`` and ``--weights_path`` are required. Both are always used (the
    dataset dir is passed to os.path.exists and the weights are restored), so leaving
    them out previously failed later with an obscure error.

    :return: argparse.Namespace with ``dataset_dir``, ``weights_path`` and ``top_k``
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--dataset_dir', type=str, required=True, help='The dataset dir')
    parser.add_argument('--weights_path', type=str, required=True, help='The model weights file path')
    parser.add_argument('--top_k', type=int, default=1, help='Evaluate top k error')

    return parser.parse_args()
48 |
49 |
def calculate_top_k_error(predictions, labels, k=1):
    """
    Calculate the top-k error rate of a batch.

    The batch size is taken from ``predictions`` at run time. A final partial batch
    (smaller than CFG.TEST.BATCH_SIZE) is therefore scored against its real size.

    :param predictions: 2D tensor with shape [batch_size, num_labels]
    :param labels: 1D integer tensor with shape [batch_size] holding class ids
    :param k: int, a sample counts as correct when its label is among the top k
    :return: scalar float32 tensor, fraction of samples whose label is not in the top k
    """
    in_top_k = tf.cast(tf.nn.in_top_k(predictions, labels, k=k), tf.float32)
    num_correct = tf.reduce_sum(in_top_k)
    batch_size = tf.cast(tf.shape(predictions)[0], tf.float32)

    return (batch_size - num_correct) / batch_size
63 |
64 |
def plot_confusion_matrix(cm, classes,
                          normalize=False,
                          title='Confusion matrix',
                          cmap=plt.cm.Blues):
    """
    This function prints and plots the confusion matrix onto the current figure.
    Normalization can be applied by setting `normalize=True`; each row is then divided
    by its sum, and rows with no samples are left as zeros instead of NaN.

    :param cm: 2D confusion matrix, rows are true labels and columns predictions
    :param classes: tick labels for the classes, in matrix order
    :param normalize: whether to row-normalize the matrix before plotting
    :param title: figure title
    :param cmap: matplotlib colormap
    :return: None
    """
    if normalize:
        row_sums = cm.sum(axis=1, keepdims=True)
        # A class absent from the ground truth has an all-zero row; dividing by it
        # would produce NaN, so such rows stay zero.
        cm = np.divide(cm.astype('float'), row_sums,
                       out=np.zeros(cm.shape, dtype=float), where=row_sums != 0)
        log.info("Normalized confusion matrix")
    else:
        log.info('Confusion matrix, without normalization')

    print(cm)

    plt.imshow(cm, interpolation='nearest', cmap=cmap)
    plt.title(title)
    plt.colorbar()
    tick_marks = np.arange(len(classes))
    plt.xticks(tick_marks, classes, rotation=45)
    plt.yticks(tick_marks, classes)

    fmt = '.2f' if normalize else 'd'
    thresh = cm.max() / 2.
    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        plt.text(j, i, format(cm[i, j], fmt),
                 horizontalalignment="center",
                 color="white" if cm[i, j] > thresh else "black")

    plt.ylabel('True label')
    plt.xlabel('Predicted label')
    plt.tight_layout()
97 | plt.tight_layout()
98 |
99 |
def plot_precision_recall_curve(labels, predictions_prob, class_nums, average_function='weighted'):
    """
    Plot a precision-recall curve in a new matplotlib figure.

    Per-class curves and average precision are computed first. The plotted curve pools
    every (sample, class) score into one binary problem (ravel of the one-hot labels and
    scores). The AP in the log line and title is averaged with ``average_function``.

    :param labels: ground-truth class ids
    :param predictions_prob: per-sample class scores, shape [num_samples, class_nums]
    :param class_nums: number of classes
    :param average_function: ``average`` argument of average_precision_score
    :return: None
    """
    class_ids = np.linspace(0, class_nums - 1, num=class_nums).tolist()
    one_hot = label_binarize(labels, classes=class_ids)
    scores = np.array(predictions_prob, dtype=np.float32)

    precision, recall, average_precision = {}, {}, {}

    for class_id in range(class_nums):
        column_labels = one_hot[:, class_id]
        column_scores = scores[:, class_id]
        precision[class_id], recall[class_id], _ = precision_recall_curve(column_labels,
                                                                          column_scores)
        average_precision[class_id] = average_precision_score(column_labels, column_scores)

    key = average_function
    precision[key], recall[key], _ = precision_recall_curve(one_hot.ravel(), scores.ravel())
    average_precision[key] = average_precision_score(one_hot, scores, average=key)
    log.info('Average precision score, {:s}-averaged '
             'over all classes: {:.5f}'.format(key, average_precision[key]))

    plt.figure()
    plt.step(recall[key], precision[key], color='b', alpha=0.2, where='post')
    if 'step' in signature(plt.fill_between).parameters:
        fill_kwargs = {'step': 'post'}
    else:
        fill_kwargs = {}
    plt.fill_between(recall[key], precision[key], alpha=0.2, color='b', **fill_kwargs)

    plt.xlabel('Recall')
    plt.ylabel('Precision')
    plt.ylim([0.0, 1.05])
    plt.xlim([0.0, 1.0])
    plt.title(
        'Average precision score, {:s}-averaged over '
        'all classes: AP={:.5f}'.format(key, average_precision[key]))
142 |
143 |
def calculate_evaluate_statics(labels, predictions, model_name='Nsfw', avgerage_method='weighted'):
    """
    Log precision, recall and F1 score of a set of predictions.

    :param labels: ground-truth class ids
    :param predictions: predicted class ids, aligned with ``labels``
    :param model_name: name shown in the log header
    :param avgerage_method: ``average`` argument forwarded to the sklearn metrics
    :return: None
    """
    log.info('Model name: {:s}:'.format(model_name))

    metric_lines = (
        ('\tPrecision: {:.5f}', precision_score),
        ('\tRecall: {:.5f}', recall_score),
        ('\tF1: {:.5f}\n', f1_score),
    )
    for template, metric_fn in metric_lines:
        score = metric_fn(y_true=labels, y_pred=predictions, average=avgerage_method)
        log.info(template.format(score))
163 |
164 |
def nsfw_eval_dataset(dataset_dir, weights_path):
    """
    Evaluate the nsfw model on the test split of a dataset.

    Steps:
      - run the model once over the test split and log the per-sample class scores;
      - print a classification report;
      - log precision / recall / F1;
      - plot a normalized confusion matrix and a precision-recall curve, then show them.

    :param dataset_dir: The nsfw dataset dir which contains tensorflow records file
    :param weights_path: The pretrained nsfw model weights file path
    :return: None
    """
    assert ops.exists(dataset_dir)

    # set nsfw data feed pipeline
    test_dataset = nsfw_data_feed_pipline.NsfwDataFeeder(dataset_dir=dataset_dir,
                                                         flags='test')
    prediction_map = test_dataset.prediction_map
    class_names = ['drawing', 'hentai', 'neural', 'porn', 'sexy']

    # Pinned to '/gpu:1'; allow_soft_placement below lets TF place ops elsewhere
    # when that device is unavailable.
    with tf.device('/gpu:1'):
        # set nsfw classification model
        phase = tf.constant('test', dtype=tf.string)

        # set nsfw net
        nsfw_net = nsfw_classification_net.NSFWNet(phase=phase,
                                                   resnet_size=CFG.NET.RESNET_SIZE)

        # read a single pass (one epoch) over the test split
        images, labels = test_dataset.inputs(batch_size=CFG.TEST.BATCH_SIZE,
                                             num_epochs=1)

        logits = nsfw_net.inference(input_tensor=images,
                                    name='nsfw_cls_model',
                                    reuse=False)

        predictions = tf.nn.softmax(logits)

        # Restore the moving average version of the learned variables for eval.
        variable_averages = tf.train.ExponentialMovingAverage(
            CFG.TRAIN.MOVING_AVERAGE_DECAY)
        variables_to_restore = variable_averages.variables_to_restore()

        # set tensorflow saver
        saver = tf.train.Saver(variables_to_restore)

    # Set sess configuration
    sess_config = tf.ConfigProto(allow_soft_placement=True)
    sess_config.gpu_options.per_process_gpu_memory_fraction = CFG.TEST.GPU_MEMORY_FRACTION
    sess_config.gpu_options.allow_growth = CFG.TEST.TF_ALLOW_GROWTH
    sess_config.gpu_options.allocator_type = 'BFC'

    # labels overall test dataset
    labels_total = []
    # prediction result overall test dataset
    predictions_total = []
    # prediction score overall test dataset of all subclass
    predictions_prob_total = []

    # The context manager closes the session (previously it was never closed) and
    # makes it the default session inside the block.
    with tf.Session(config=sess_config) as sess:

        saver.restore(sess=sess, save_path=weights_path)

        while True:
            try:
                predictions_vals, labels_vals = sess.run(
                    fetches=[predictions,
                             labels])

                log.info('**************')
                log.info('Test dataset batch size: {:d}'.format(predictions_vals.shape[0]))
                log.info('---- Sample Id ---- Gt label ---- Prediction ----')

                for index, predictions_val in enumerate(predictions_vals):

                    label_gt = prediction_map[labels_vals[index]]

                    prediction_score = dict()

                    for score_index, score in enumerate(predictions_val):
                        prediction_score[prediction_map[score_index]] = format(score, '.5f')

                    log.info('---- {:d} ---- {:s} ---- {}'.format(index, label_gt, prediction_score))

                    # record predicts prob map
                    predictions_prob_total.append(predictions_val.tolist())

                # record total label and prediction results
                labels_total.extend(labels_vals.tolist())
                predictions_total.extend(np.argmax(predictions_vals, axis=1).tolist())

            except tf.errors.OutOfRangeError:
                log.info('Loop overall the test dataset')
                break
            except Exception as err:
                # best effort: log the failure and report on what was collected so far
                log.error(err)
                break

    if not labels_total:
        log.error('No test samples were evaluated, skip computing evaluation statistics')
        return

    # print prediction report
    print('Nsfw classification_report(left: labels):')
    print(classification_report(labels_total, predictions_total))

    # calculate confusion matrix
    cnf_matrix = confusion_matrix(labels_total, predictions_total)
    np.set_printoptions(precision=2)
    plot_confusion_matrix(cnf_matrix, classes=class_names, normalize=True,
                          title='Normalized confusion matrix')

    # calculate evaluate statics
    calculate_evaluate_statics(labels=labels_total, predictions=predictions_total)

    # plot precision recall curve; the class count follows class_names
    plot_precision_recall_curve(labels=labels_total,
                                predictions_prob=predictions_prob_total,
                                class_nums=len(class_names))
    plt.show()
280 |
281 |
if __name__ == '__main__':
    # init args
    args = init_args()

    # Validate explicitly: `assert` is stripped under -O, and os.path.exists(None)
    # raises TypeError when --dataset_dir is omitted.
    if args.dataset_dir is None or not ops.exists(args.dataset_dir):
        raise ValueError('Dataset dir {} does not exist'.format(args.dataset_dir))

    nsfw_eval_dataset(args.dataset_dir, args.weights_path)
290 |
--------------------------------------------------------------------------------
/tools/export_nsfw_saved_model.sh:
--------------------------------------------------------------------------------
#!/usr/bin/env bash

# Export a trained checkpoint as a TensorFlow SavedModel; the export dir is the
# input of convert_tfjs_model.sh.
CKPT_PATH=model/nsfw_cls/nsfw_cls_2019-02-27-18-46-28.ckpt-160000
EXPORT_DIR=./model/nsfw_export_saved_model

python tools/export_saved_model.py --ckpt_path "${CKPT_PATH}" \
    --export_dir "${EXPORT_DIR}"
--------------------------------------------------------------------------------
/tools/export_saved_model.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | # @Time : 19-2-22 上午11:10
4 | # @Author : MaybeShewill-CV
5 | # @Site : https://github.com/MaybeShewill-CV/CRNN_Tensorflow
6 | # @File : export_saved_model.py
7 | # @IDE: PyCharm
8 | """
9 | Build tensorflow saved model for tensorflowjs converter to use
10 | """
11 | import os.path as ops
12 | import argparse
13 | import glog as log
14 |
15 | import cv2
16 | import numpy as np
17 | import matplotlib.pyplot as plt
18 | import tensorflow as tf
19 | from tensorflow import saved_model as sm
20 |
21 | from config import global_config
22 | from nsfw_model import nsfw_classification_net
23 |
CFG = global_config.cfg

# Per-channel pixel means; _CHANNEL_MEANS lists them in B, G, R order.
_R_MEAN, _G_MEAN, _B_MEAN = 123.68, 116.78, 103.94
_CHANNEL_MEANS = [_B_MEAN, _G_MEAN, _R_MEAN]
30 |
31 |
def init_args():
    """
    Parse the command line arguments of the export script.

    Both arguments are required. The export dir and the checkpoint dir are checked
    with os.path.exists, which raises TypeError on None, so leaving either out
    previously failed with an obscure error.

    :return: argparse.Namespace with ``export_dir`` and ``ckpt_path``
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--export_dir', type=str, required=True, help='The model export dir')
    parser.add_argument('--ckpt_path', type=str, required=True,
                        help='The pretrained ckpt model weights file path')

    return parser.parse_args()
42 |
43 |
def central_crop(image, central_fraction):
    """
    Crop the central region of an image.

    ``int((1 - central_fraction) * size / 2)`` rows are removed from both the top and
    the bottom. The same formula, applied to the width, gives the columns removed from
    the left and the right.

    :param image: 2D (H, W) or 3D (H, W, C) array supporting numpy-style slicing
    :param central_fraction: fraction of each spatial dimension to keep, in [0, 1)
    :return: the cropped image (a view when ``image`` is a numpy array)
    :raises ValueError: if the image is not 2D/3D, has a zero-sized spatial
        dimension, or ``central_fraction`` is outside [0, 1)
    """
    image_shape = np.shape(image)

    # Check the rank first so 0D/1D inputs raise ValueError instead of IndexError.
    if len(image_shape) not in (2, 3):
        raise ValueError('Wrong image shape')

    if central_fraction >= 1 or central_fraction < 0:
        raise ValueError('Central fraction should be in [0, 1)')

    image_height, image_width = image_shape[0], image_shape[1]
    if not image_height or not image_width:
        raise ValueError('Image shape with zero')

    top = int((1 - central_fraction) * image_height / 2)
    left = int((1 - central_fraction) * image_width / 2)

    # Slicing only the first two axes keeps every channel of a 3D image.
    return image[top:image_height - top, left:image_width - left]
75 |
76 |
def build_saved_model(ckpt_path, export_dir):
    """
    Convert a trained checkpoint into a TensorFlow SavedModel.

    The inference graph is rebuilt inside a private ``tf.Graph`` so the
    process-wide default graph is left untouched; ``test_load_saved_model``
    may load the export in the same process afterwards. The learned variables
    are restored from their exponential-moving-average shadow copies, and the
    session is closed when the export is done.

    :param ckpt_path: checkpoint prefix to restore, e.g. ``dir/model.ckpt-1000``
    :param export_dir: destination directory; it must not exist yet
    :return: None
    :raises ValueError: if ``export_dir`` already exists or the checkpoint
        directory does not exist
    """
    if ops.exists(export_dir):
        raise ValueError('Export dir must be a dir path that does not exist')

    # ckpt_path is a checkpoint prefix rather than one file, so only its
    # directory can be checked; an empty dirname means the current directory.
    ckpt_dir = ops.dirname(ckpt_path) or '.'
    if not ops.exists(ckpt_dir):
        raise ValueError('Checkpoint dir {:s} does not exist'.format(ckpt_dir))

    # Set sess configuration
    sess_config = tf.ConfigProto(allow_soft_placement=True)
    sess_config.gpu_options.per_process_gpu_memory_fraction = CFG.TRAIN.GPU_MEMORY_FRACTION
    sess_config.gpu_options.allow_growth = CFG.TRAIN.TF_ALLOW_GROWTH
    sess_config.gpu_options.allocator_type = 'BFC'

    with tf.Graph().as_default() as graph:
        # build inference tensorflow graph
        image_tensor = tf.placeholder(dtype=tf.float32,
                                      shape=[1, CFG.TRAIN.CROP_IMG_HEIGHT, CFG.TRAIN.CROP_IMG_WIDTH, 3],
                                      name='input_tensor')
        # set nsfw net
        phase = tf.constant('test', dtype=tf.string)
        nsfw_net = nsfw_classification_net.NSFWNet(phase=phase,
                                                   resnet_size=CFG.NET.RESNET_SIZE)

        # compute inference logits
        logits = nsfw_net.inference(input_tensor=image_tensor,
                                    name='nsfw_cls_model',
                                    reuse=False)

        predictions = tf.nn.softmax(logits, name='nsfw_cls_model/final_prediction')

        # Restore the moving average version of the learned variables for eval.
        variable_averages = tf.train.ExponentialMovingAverage(
            CFG.TRAIN.MOVING_AVERAGE_DECAY)
        saver = tf.train.Saver(variable_averages.variables_to_restore())

        with tf.Session(graph=graph, config=sess_config) as sess:
            saver.restore(sess=sess, save_path=ckpt_path)

            # set model save builder
            saved_builder = sm.builder.SavedModelBuilder(export_dir)

            # Build the SignatureDef that exposes the input and prediction tensors.
            # NOTE(review): CLASSIFY_METHOD_NAME is normally tied to the
            # tf.Example-based classification API; PREDICT_METHOD_NAME may fit
            # raw tensor inputs better. Confirm with the model's consumers.
            signature_def = sm.signature_def_utils.build_signature_def(
                inputs={'input_tensor': sm.utils.build_tensor_info(image_tensor)},
                outputs={'prediction': sm.utils.build_tensor_info(predictions)},
                method_name=tf.saved_model.signature_constants.CLASSIFY_METHOD_NAME
            )

            # add graph into MetaGraphDef protobuf
            saved_builder.add_meta_graph_and_variables(
                sess,
                tags=[sm.tag_constants.SERVING],
                signature_def_map={'classify_result': signature_def}
            )

            # save model
            saved_builder.save()

    return
151 |
152 |
def test_load_saved_model(saved_model_dir, image_path='data/test_data/drawing_16715.jpg'):
    """
    Reload an exported SavedModel and classify one image with it.

    The model is loaded into a fresh ``tf.Graph``. When this runs in the same
    process as a graph-building step, the default graph may already contain
    nodes with the same names (e.g. 'input_tensor'). The signature tensors
    could then resolve to those nodes instead of the loaded ones.

    :param saved_model_dir: directory written by ``build_saved_model``
    :param image_path: image to classify; defaults to the sample image the
        script has always used
    :return: None
    :raises ValueError: if the image cannot be read
    """
    prediction_map = global_config.NSFW_PREDICT_MAP

    image = cv2.imread(image_path, cv2.IMREAD_COLOR)
    if image is None:
        # cv2.imread reports failure by returning None rather than raising.
        raise ValueError('Can not read image {:s}'.format(image_path))
    image_vis = image
    # Resize, central crop, then subtract the per-channel means.
    image = cv2.resize(src=image,
                       dsize=(CFG.TRAIN.IMG_WIDTH, CFG.TRAIN.IMG_HEIGHT),
                       interpolation=cv2.INTER_CUBIC)
    image = central_crop(image=image,
                         central_fraction=CFG.TRAIN.CROP_IMG_HEIGHT / CFG.TRAIN.IMG_HEIGHT)
    image = np.array(image, dtype=np.float32) - np.array(_CHANNEL_MEANS, np.float32)
    # Add the leading batch axis (the exported input placeholder has batch size 1).
    image = np.expand_dims(image, 0)

    # Set sess configuration
    sess_config = tf.ConfigProto(allow_soft_placement=True)
    sess_config.gpu_options.per_process_gpu_memory_fraction = CFG.TRAIN.GPU_MEMORY_FRACTION
    sess_config.gpu_options.allow_growth = CFG.TRAIN.TF_ALLOW_GROWTH
    sess_config.gpu_options.allocator_type = 'BFC'

    with tf.Session(graph=tf.Graph(), config=sess_config) as sess:
        meta_graphdef = sm.loader.load(
            sess,
            tags=[sm.tag_constants.SERVING],
            export_dir=saved_model_dir)

        signature_def = meta_graphdef.signature_def['classify_result']

        input_tensor = sm.utils.get_tensor_from_tensor_info(
            signature_def.inputs['input_tensor'], sess.graph)
        predictions = sm.utils.get_tensor_from_tensor_info(
            signature_def.outputs['prediction'], sess.graph)

        prediction_val = sess.run(predictions, feed_dict={input_tensor: image})

    prediction_score = dict()

    for score_index, score in enumerate(prediction_val[0]):
        prediction_score[prediction_map[score_index]] = format(score, '.5f')

    log.info('Predict result: {}'.format(prediction_score))

    # NOTE(review): there is no plt.show() here, so unless matplotlib runs in
    # interactive mode this figure is created but never displayed.
    plt.figure('source image')
    plt.imshow(image_vis[:, :, (2, 1, 0)])
207 |
208 |
if __name__ == '__main__':
    # Parse the command line, export the checkpoint as a SavedModel, then
    # reload the export and classify one image as a sanity check.
    args = init_args()
    build_saved_model(args.ckpt_path, args.export_dir)
    test_load_saved_model(args.export_dir)
221 |
--------------------------------------------------------------------------------
/tools/make_nsfw_dataset.sh:
--------------------------------------------------------------------------------
#!/usr/bin/env bash
#
# Split the nsfw dataset and convert the resulting splits into tfrecords.
# NOTE(review): these paths carry a "nsfw_classification/" prefix, while
# tools/export_nsfw_saved_model.sh uses paths such as
# "tools/export_saved_model.py". Confirm which directory this is run from.

DATASET_DIR=nsfw_classification/data/nsfw_dataset_example
TFRECORDS_DIR="${DATASET_DIR}/tfrecords"

python nsfw_classification/data_provider/nsfw_data_feed_pipline.py \
    --dataset_dir "${DATASET_DIR}" \
    --tfrecords_dir "${TFRECORDS_DIR}"
7 |
--------------------------------------------------------------------------------
/tools/test_nsfw.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | # @Time : 19-2-20 上午10:48
4 | # @Author : MaybeShewill-CV
5 | # @Site : https://github.com/MaybeShewill-CV/CRNN_Tensorflow
6 | # @File : test_nsfw.py
7 | # @IDE: PyCharm
8 | """
9 | Test nsfw model script
10 | """
11 | import os.path as ops
12 | import argparse
13 | import glog as log
14 |
15 | import tensorflow as tf
16 | import numpy as np
17 | import cv2
18 | import matplotlib.pyplot as plt
19 |
20 | from config import global_config
21 | from nsfw_model import nsfw_classification_net
22 | from data_provider import nsfw_data_feed_pipline
23 |
24 |
CFG = global_config.cfg

# Per-channel pixel means subtracted from the input image before inference.
# _CHANNEL_MEANS is ordered B, G, R because images are read with cv2.imread,
# which yields BGR channel order (see the (2, 1, 0) swap used for display).
_R_MEAN, _G_MEAN, _B_MEAN = 123.68, 116.78, 103.94
_CHANNEL_MEANS = [_B_MEAN, _G_MEAN, _R_MEAN]
31 |
32 |
def init_args():
    """
    Parse the command line arguments of the single-image test script.

    Both arguments are required. The script cannot classify without an image
    or restore without weights, and a missing ``--image_path`` used to crash
    with a TypeError from ``os.path.exists(None)``.

    :return: argparse.Namespace with ``image_path`` and ``weights_path``
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--image_path', type=str, required=True, help='The image path')
    parser.add_argument('--weights_path', type=str, required=True, help='The model weights file path')

    return parser.parse_args()
43 |
44 |
def central_crop(image, central_fraction):
    """
    Crop the central region of an image.

    The same fraction is applied to the height and to the width. Like
    ``tf.image.central_crop``, the fraction must lie in (0, 1]; a fraction of
    1 returns the whole image (e.g. when the crop size equals the resize size).

    :param image: image array of shape (H, W) or (H, W, C)
    :param central_fraction: fraction of each spatial dimension to keep, in (0, 1]
    :return: the cropped image
    :raises ValueError: if the fraction is out of range, the image is not 2-D
        or 3-D, or a spatial dimension is zero
    """
    if not 0 < central_fraction <= 1:
        raise ValueError('Central fraction should be in (0, 1]')

    image_shape = np.shape(image)
    # Validate the rank before indexing image_shape[1].
    if len(image_shape) not in (2, 3):
        raise ValueError('Wrong image shape')

    image_height, image_width = image_shape[0], image_shape[1]
    if not image_height or not image_width:
        raise ValueError('Image shape with zero')

    # Rows / columns trimmed from each side; int() truncates, so an odd
    # remainder stays inside the crop.
    top = int((1 - central_fraction) * image_height / 2)
    left = int((1 - central_fraction) * image_width / 2)
    bottom = image_height - top
    right = image_width - left

    # Slicing only the two leading axes covers both (H, W) and (H, W, C).
    return image[top:bottom, left:right]
76 |
77 |
def calculate_top_k_error(predictions, labels, k=1):
    """
    Calculate the top-k error of a batch of predictions.

    The batch size is read from ``predictions`` at run time rather than from
    ``CFG.TEST.BATCH_SIZE``. The result is therefore correct for any batch
    size, including a final batch that is smaller than the configured one.

    :param predictions: 2D float tensor with shape [batch_size, num_labels]
    :param labels: 1D integer tensor with shape [batch_size] holding class ids
        (the layout required by ``tf.nn.in_top_k``)
    :param k: number of top-scoring classes that count as a hit
    :return: scalar float32 tensor, the fraction of samples whose label is not
        among the top-k predictions
    """
    in_top_k = tf.cast(tf.nn.in_top_k(predictions, labels, k=k), tf.float32)
    num_correct = tf.reduce_sum(in_top_k)
    batch_size = tf.cast(tf.shape(predictions)[0], tf.float32)

    return (batch_size - num_correct) / batch_size
91 |
92 |
def nsfw_classify_image(image_path, weights_path):
    """
    Use the nsfw model to classify a single image.

    Steps:
    1. Read and preprocess the image first, so a bad path fails before the
       model is built.
    2. Build the inference graph and restore the moving-average weights from
       ``weights_path``.
    3. Log the per-class scores and show the source image with matplotlib.

    :param image_path: The image file path
    :param weights_path: The pretrained weights file path (checkpoint prefix)
    :return: None
    :raises ValueError: if ``image_path`` does not exist or cannot be decoded
    """
    # Explicit check instead of assert, which is stripped under python -O.
    if image_path is None or not ops.exists(image_path):
        raise ValueError('Image path {} does not exist'.format(image_path))

    prediction_map = global_config.NSFW_PREDICT_MAP

    image = cv2.imread(image_path, cv2.IMREAD_COLOR)
    if image is None:
        # cv2.imread reports undecodable files by returning None.
        raise ValueError('Can not decode image {:s}'.format(image_path))
    image_vis = image
    image = cv2.resize(src=image,
                       dsize=(CFG.TRAIN.IMG_WIDTH, CFG.TRAIN.IMG_HEIGHT),
                       interpolation=cv2.INTER_CUBIC)
    image = central_crop(image=image,
                         central_fraction=CFG.TRAIN.CROP_IMG_HEIGHT / CFG.TRAIN.IMG_HEIGHT)
    image = np.array(image, dtype=np.float32) - np.array(_CHANNEL_MEANS, np.float32)

    # The device is hard-coded; allow_soft_placement=True below lets TensorFlow
    # place ops elsewhere when '/gpu:1' is not available.
    with tf.device('/gpu:1'):
        image_tensor = tf.placeholder(dtype=tf.float32,
                                      shape=[1, CFG.TRAIN.CROP_IMG_HEIGHT, CFG.TRAIN.CROP_IMG_WIDTH, 3],
                                      name='input_tensor')
        # set nsfw net
        phase = tf.constant('test', dtype=tf.string)
        nsfw_net = nsfw_classification_net.NSFWNet(phase=phase,
                                                   resnet_size=CFG.NET.RESNET_SIZE)

        # compute inference logits
        logits = nsfw_net.inference(input_tensor=image_tensor,
                                    name='nsfw_cls_model',
                                    reuse=False)

        predictions = tf.nn.softmax(logits)

        # Restore the moving average version of the learned variables for eval.
        variable_averages = tf.train.ExponentialMovingAverage(
            CFG.TRAIN.MOVING_AVERAGE_DECAY)
        variables_to_restore = variable_averages.variables_to_restore()

        # set tensorflow saver
        saver = tf.train.Saver(variables_to_restore)

    # Set sess configuration
    sess_config = tf.ConfigProto(allow_soft_placement=True)
    sess_config.gpu_options.per_process_gpu_memory_fraction = CFG.TRAIN.GPU_MEMORY_FRACTION
    sess_config.gpu_options.allow_growth = CFG.TRAIN.TF_ALLOW_GROWTH
    sess_config.gpu_options.allocator_type = 'BFC'

    # The context manager closes the session once the prediction is computed.
    with tf.Session(config=sess_config) as sess:
        saver.restore(sess=sess, save_path=weights_path)

        predictions_vals = sess.run(
            fetches=predictions,
            feed_dict={image_tensor: [image]})

    prediction_score = {
        prediction_map[score_index]: format(score, '.5f')
        for score_index, score in enumerate(predictions_vals[0])
    }

    log.info('Predict result is: {}'.format(prediction_score))

    plt.figure('source image')
    # OpenCV images are BGR; reorder to RGB for matplotlib.
    plt.imshow(image_vis[:, :, (2, 1, 0)])
    plt.show()

    return
167 |
168 |
if __name__ == '__main__':
    # init args
    args = init_args()

    # Fail with a readable message; an assert would vanish under python -O and
    # os.path.exists(None) raises TypeError.
    if args.image_path is None or not ops.exists(args.image_path):
        raise SystemExit('Image path {} does not exist'.format(args.image_path))

    # test net
    nsfw_classify_image(args.image_path, args.weights_path)
177 |
--------------------------------------------------------------------------------
/tools/train_nsfw.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | # @Time : 19-2-15 下午8:53
4 | # @Author : MaybeShewill-CV
5 | # @Site : https://github.com/MaybeShewill-CV/CRNN_Tensorflow
6 | # @File : train_nsfw.py
7 | # @IDE: PyCharm
8 | """
9 | Train nsfw model script
10 | """
11 | import argparse
12 | import os
13 | import os.path as ops
14 | import time
15 | import math
16 |
17 | import numpy as np
18 | import tensorflow as tf
19 | import glog as log
20 |
21 | from config import global_config
22 | from data_provider import nsfw_data_feed_pipline
23 | from nsfw_model import nsfw_classification_net
24 |
25 |
# Global configuration object (global_config.cfg) read by the functions in
# this script, e.g. the CFG.TRAIN.* hyper-parameters and CFG.NET.RESNET_SIZE.
CFG = global_config.cfg
27 |
28 |
29 | def init_args():
30 | """
31 |
32 | :return:
33 | """
34 | parser = argparse.ArgumentParser()
35 |
36 | parser.add_argument('--dataset_dir', type=str, help='The dataset_dir')
37 | parser.add_argument('--use_multi_gpu', type=bool, default=False, help='If use multiple gpu devices')
38 | parser.add_argument('--weights_path', type=str, default=None, help='The pretrained weights path')
39 |
40 | return parser.parse_args()
41 |
42 |
def calculate_top_k_error(predictions, labels, k=1):
    """
    Calculate the top-k error of a batch of predictions.

    The batch size is read from ``predictions`` at run time rather than from
    ``CFG.TRAIN.BATCH_SIZE``. The validation batches in this script are built
    with ``CFG.TRAIN.VAL_BATCH_SIZE``; dividing their hit count by the training
    batch size gave a wrong (possibly negative) validation error.

    :param predictions: 2D float tensor with shape [batch_size, num_labels]
    :param labels: 1D integer tensor with shape [batch_size] holding class ids
        (the layout required by ``tf.nn.in_top_k``)
    :param k: number of top-scoring classes that count as a hit
    :return: scalar float32 tensor, the fraction of samples whose label is not
        among the top-k predictions
    """
    in_top_k = tf.cast(tf.nn.in_top_k(predictions, labels, k=k), tf.float32)
    num_correct = tf.reduce_sum(in_top_k)
    batch_size = tf.cast(tf.shape(predictions)[0], tf.float32)

    return (batch_size - num_correct) / batch_size
56 |
57 |
def average_gradients(tower_grads):
    """Calculate the average gradient for each shared variable across all towers.

    Note that this function provides a synchronization point across all towers.

    Args:
      tower_grads: List of lists of (gradient, variable) tuples. The outer list
        has one entry per tower; each inner list is that tower's output of
        ``optimizer.compute_gradients``, so all inner lists enumerate the same
        variables in the same order.
    Returns:
      List of pairs of (gradient, variable) where the gradient has been averaged
      across all towers. A variable whose gradient is None in every tower (it
      does not influence the loss) is returned as (None, variable), which
      ``Optimizer.apply_gradients`` skips instead of crashing here.
    """
    average_grads = []
    for grad_and_vars in zip(*tower_grads):
        # grad_and_vars looks like ((grad0_gpu0, var0), ..., (grad0_gpuN, var0)).
        # Variables are shared across towers, so the first tower's pointer is
        # the one returned.
        var = grad_and_vars[0][1]
        grads = [g for g, _ in grad_and_vars if g is not None]
        if not grads:
            average_grads.append((None, var))
            continue

        # Stack along a new leading 'tower' axis and average over it.
        grad = tf.reduce_mean(tf.stack(grads, axis=0), axis=0)
        average_grads.append((grad, var))

    return average_grads
93 |
94 |
def compute_net_gradients(images, labels, net, optimizer=None, is_net_first_initialized=False):
    """
    Build the loss, top-1 error and (optionally) the gradients of one tower.

    :param images: batch of input images for ``net``
    :param labels: class labels matching ``images``
    :param net: classification model exposing ``compute_loss`` and ``inference``
    :param optimizer: optimizer used for ``compute_gradients``; when None, no
        gradient ops are built and None is returned in their place
    :param is_net_first_initialized: forwarded as ``reuse`` to the loss
        computation, i.e. whether the 'nsfw_cls_model' variables already exist
    :return: tuple (loss, top1_error, grads_and_vars or None)
    """
    loss = net.compute_loss(input_tensor=images,
                            labels=labels,
                            name='nsfw_cls_model',
                            reuse=is_net_first_initialized)
    # The inference pass always reuses the variables touched by compute_loss.
    logits = net.inference(input_tensor=images, name='nsfw_cls_model', reuse=True)
    top1_error = calculate_top_k_error(tf.nn.softmax(logits), labels, 1)

    grads = optimizer.compute_gradients(loss) if optimizer is not None else None

    return loss, top1_error, grads
122 |
123 |
def train_net(dataset_dir, weights_path=None):
    """
    Train the nsfw classification model on a single device.

    - A 'train' phase network creates the 'nsfw_cls_model' variables; a 'test'
      phase network reuses them for validation.
    - Optimisation is momentum SGD on a piecewise-constant learning rate.
    - Exponential moving averages of the weights are maintained, as in
      ``train_net_multi_gpu``. Checkpoints therefore contain the averaged
      variables that tools/test_nsfw.py and tools/export_saved_model.py
      restore through ``variables_to_restore()``.

    :param dataset_dir: dataset directory handed to ``NsfwDataFeeder``
    :param weights_path: checkpoint to resume from, or None to train from scratch
    :return: None (also returns early, after logging, if the training loss is NaN)
    """
    # set nsfw data feed pipline
    train_dataset = nsfw_data_feed_pipline.NsfwDataFeeder(dataset_dir=dataset_dir,
                                                          flags='train')
    val_dataset = nsfw_data_feed_pipline.NsfwDataFeeder(dataset_dir=dataset_dir,
                                                        flags='val')

    with tf.device('/gpu:1'):
        # set nsfw net
        nsfw_net = nsfw_classification_net.NSFWNet(phase=tf.constant('train', dtype=tf.string),
                                                   resnet_size=CFG.NET.RESNET_SIZE)
        nsfw_net_val = nsfw_classification_net.NSFWNet(phase=tf.constant('test', dtype=tf.string),
                                                       resnet_size=CFG.NET.RESNET_SIZE)

        # compute train loss (reuse=False: this call creates the variables)
        train_images, train_labels = train_dataset.inputs(batch_size=CFG.TRAIN.BATCH_SIZE,
                                                          num_epochs=1)
        train_loss = nsfw_net.compute_loss(input_tensor=train_images,
                                           labels=train_labels,
                                           name='nsfw_cls_model',
                                           reuse=False)

        train_logits = nsfw_net.inference(input_tensor=train_images,
                                          name='nsfw_cls_model',
                                          reuse=True)

        train_predictions = tf.nn.softmax(train_logits)
        train_top1_error = calculate_top_k_error(train_predictions, train_labels, 1)

        # compute val loss
        val_images, val_labels = val_dataset.inputs(batch_size=CFG.TRAIN.VAL_BATCH_SIZE,
                                                    num_epochs=1)
        val_loss = nsfw_net_val.compute_loss(input_tensor=val_images,
                                             labels=val_labels,
                                             name='nsfw_cls_model',
                                             reuse=True)

        val_logits = nsfw_net_val.inference(input_tensor=val_images,
                                            name='nsfw_cls_model',
                                            reuse=True)

        val_predictions = tf.nn.softmax(val_logits)
        val_top1_error = calculate_top_k_error(val_predictions, val_labels, 1)

    # set tensorflow summary
    tboard_save_path = 'tboard/nsfw_cls'
    os.makedirs(tboard_save_path, exist_ok=True)

    summary_writer = tf.summary.FileWriter(tboard_save_path)

    train_loss_scalar = tf.summary.scalar(name='train_loss',
                                          tensor=train_loss)
    train_top1_err_scalar = tf.summary.scalar(name='train_top1_error',
                                              tensor=train_top1_error)
    val_loss_scalar = tf.summary.scalar(name='val_loss',
                                        tensor=val_loss)
    val_top1_err_scalar = tf.summary.scalar(name='val_top1_error',
                                            tensor=val_top1_error)

    train_merge_summary_op = tf.summary.merge([train_loss_scalar, train_top1_err_scalar])

    val_merge_summary_op = tf.summary.merge([val_loss_scalar, val_top1_err_scalar])

    # set optimizer
    with tf.device('/gpu:1'):
        # set learning rate: LEARNING_RATE, multiplied by LR_DECAY_RATE at each
        # boundary listed in decay_steps.
        global_step = tf.Variable(0, trainable=False)
        decay_steps = [CFG.TRAIN.LR_DECAY_STEPS_1, CFG.TRAIN.LR_DECAY_STEPS_2]
        decay_values = []
        init_lr = CFG.TRAIN.LEARNING_RATE
        for _ in range(len(decay_steps) + 1):
            decay_values.append(init_lr)
            init_lr = init_lr * CFG.TRAIN.LR_DECAY_RATE

        learning_rate = tf.train.piecewise_constant(
            x=global_step,
            boundaries=decay_steps,
            values=decay_values,
            name='learning_rate'
        )

        # Run the ops registered in UPDATE_OPS (typically batch-norm statistic
        # updates) together with every optimizer step.
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        with tf.control_dependencies(update_ops):
            apply_gradient_op = tf.train.MomentumOptimizer(
                learning_rate=learning_rate, momentum=0.9).minimize(
                loss=train_loss,
                var_list=tf.trainable_variables(),
                global_step=global_step)

        # Track moving averages of the weights after each update, mirroring
        # train_net_multi_gpu, so evaluation/export can restore them.
        variable_averages = tf.train.ExponentialMovingAverage(
            CFG.TRAIN.MOVING_AVERAGE_DECAY, num_updates=global_step)
        with tf.control_dependencies([apply_gradient_op]):
            variables_averages_op = variable_averages.apply(
                tf.trainable_variables() + tf.moving_average_variables())
        train_op = tf.group(apply_gradient_op, variables_averages_op)

    # Set tf saver. A Saver only covers the variables that exist when it is
    # constructed. It is therefore created after the global step, the momentum
    # slots and the moving averages, so a resumed run finds all of them in the
    # checkpoint instead of hitting uninitialized variables.
    saver = tf.train.Saver()
    model_save_dir = 'model/nsfw_cls'
    os.makedirs(model_save_dir, exist_ok=True)
    train_start_time = time.strftime('%Y-%m-%d-%H-%M-%S', time.localtime(time.time()))
    model_name = 'nsfw_cls_{:s}.ckpt'.format(str(train_start_time))
    model_save_path = ops.join(model_save_dir, model_name)

    # Set sess configuration
    sess_config = tf.ConfigProto(allow_soft_placement=True)
    sess_config.gpu_options.per_process_gpu_memory_fraction = CFG.TRAIN.GPU_MEMORY_FRACTION
    sess_config.gpu_options.allow_growth = CFG.TRAIN.TF_ALLOW_GROWTH
    sess_config.gpu_options.allocator_type = 'BFC'

    # Set the training parameters
    train_epochs = CFG.TRAIN.EPOCHS

    log.info('Global configuration is as follows:')
    log.info(CFG)

    sess = tf.Session(config=sess_config)

    # try/finally releases the session and flushes the summary writer on every
    # exit path, including the early return on a NaN loss.
    try:
        summary_writer.add_graph(sess.graph)

        with sess.as_default():

            tf.train.write_graph(graph_or_graph_def=sess.graph, logdir='',
                                 name='{:s}/nsfw_cls_model.pb'.format(model_save_dir))

            if weights_path is None:
                log.info('Training from scratch')
                init = tf.global_variables_initializer()
                sess.run(init)
            else:
                log.info('Restore model from last model checkpoint {:s}'.format(weights_path))
                saver.restore(sess=sess, save_path=weights_path)

            train_cost_time_mean = []
            val_cost_time_mean = []

            for epoch in range(train_epochs):

                # training part
                t_start = time.time()

                _, train_loss_value, train_top1_err_value, train_summary, lr = \
                    sess.run(fetches=[train_op,
                                      train_loss,
                                      train_top1_error,
                                      train_merge_summary_op,
                                      learning_rate])

                if math.isnan(train_loss_value):
                    log.error('Train loss is nan')
                    return

                cost_time = time.time() - t_start
                train_cost_time_mean.append(cost_time)

                summary_writer.add_summary(summary=train_summary,
                                           global_step=epoch)

                # validation part
                t_start_val = time.time()

                val_loss_value, val_top1_err_value, val_summary = \
                    sess.run(fetches=[val_loss,
                                      val_top1_error,
                                      val_merge_summary_op])

                summary_writer.add_summary(val_summary, global_step=epoch)

                cost_time_val = time.time() - t_start_val
                val_cost_time_mean.append(cost_time_val)

                if epoch % CFG.TRAIN.DISPLAY_STEP == 0:
                    log.info('Epoch_Train: {:d} total_loss= {:6f} top1_error= {:6f} '
                             'lr= {:6f} mean_cost_time= {:5f}s '.
                             format(epoch + 1,
                                    train_loss_value,
                                    train_top1_err_value,
                                    lr,
                                    np.mean(train_cost_time_mean)))
                    train_cost_time_mean.clear()

                if epoch % CFG.TRAIN.VAL_DISPLAY_STEP == 0:
                    log.info('Epoch_Val: {:d} total_loss= {:6f} top1_error= {:6f}'
                             ' mean_cost_time= {:5f}s '.
                             format(epoch + 1,
                                    val_loss_value,
                                    val_top1_err_value,
                                    np.mean(val_cost_time_mean)))
                    val_cost_time_mean.clear()

                if epoch % 2000 == 0:
                    saver.save(sess=sess, save_path=model_save_path, global_step=epoch)

            # The periodic save above only fires every 2000 steps; also keep
            # the final weights unless that step was just saved.
            last_epoch = train_epochs - 1
            if last_epoch >= 0 and last_epoch % 2000 != 0:
                saver.save(sess=sess, save_path=model_save_path, global_step=last_epoch)
    finally:
        summary_writer.close()
        sess.close()

    return
320 |
321 |
def train_net_multi_gpu(dataset_dir, weights_path=None):
    """
    Train the nsfw classification model with one tower per configured GPU.

    - All towers share the 'nsfw_cls_model' variables.
    - Tower gradients are averaged before a single momentum-SGD update.
    - Exponential moving averages of the weights are maintained alongside.
    - Batch-norm updates and the extra training summaries are taken from the
      first tower only.

    :param dataset_dir: dataset directory handed to ``NsfwDataFeeder``
    :param weights_path: checkpoint to resume from, or None to train from scratch
    :return: None (also returns early, after logging, if the training loss is NaN)
    :raises ValueError: if ``CFG.TRAIN.GPU_NUM`` is smaller than 1
    """
    if CFG.TRAIN.GPU_NUM < 1:
        raise ValueError('Multi gpu training needs CFG.TRAIN.GPU_NUM >= 1')

    # set nsfw data feed pipline
    train_dataset = nsfw_data_feed_pipline.NsfwDataFeeder(dataset_dir=dataset_dir,
                                                          flags='train')
    val_dataset = nsfw_data_feed_pipline.NsfwDataFeeder(dataset_dir=dataset_dir,
                                                        flags='val')

    # set nsfw net
    nsfw_net = nsfw_classification_net.NSFWNet(phase=tf.constant('train', dtype=tf.string),
                                               resnet_size=CFG.NET.RESNET_SIZE)
    nsfw_net_val = nsfw_classification_net.NSFWNet(phase=tf.constant('test', dtype=tf.string),
                                                   resnet_size=CFG.NET.RESNET_SIZE)

    # fetch train and validation data.
    # NOTE: every tower below consumes these same tensors. TensorFlow evaluates
    # them once per sess.run, so all GPUs process the same batch in a step.
    train_images, train_labels = train_dataset.inputs(
        batch_size=CFG.TRAIN.BATCH_SIZE, num_epochs=1)
    val_images, val_labels = val_dataset.inputs(
        batch_size=CFG.TRAIN.BATCH_SIZE, num_epochs=1)

    # set average container
    tower_grads = []
    train_tower_loss = []
    train_tower_top1_error = []
    val_tower_loss = []
    val_tower_top1_error = []
    batchnorm_updates = None
    train_summary_op_updates = None

    # set learning rate: LEARNING_RATE, multiplied by LR_DECAY_RATE at each
    # boundary listed in decay_steps.
    global_step = tf.Variable(0, trainable=False)
    decay_steps = [CFG.TRAIN.LR_DECAY_STEPS_1, CFG.TRAIN.LR_DECAY_STEPS_2]
    decay_values = []
    init_lr = CFG.TRAIN.LEARNING_RATE
    for _ in range(len(decay_steps) + 1):
        decay_values.append(init_lr)
        init_lr = init_lr * CFG.TRAIN.LR_DECAY_RATE

    learning_rate = tf.train.piecewise_constant(
        x=global_step,
        boundaries=decay_steps,
        values=decay_values,
        name='learning_rate'
    )

    # set optimizer
    optimizer = tf.train.MomentumOptimizer(learning_rate=learning_rate, momentum=0.9)

    # set distributed train op
    with tf.variable_scope(tf.get_variable_scope()):
        is_network_initialized = False
        for i in range(CFG.TRAIN.GPU_NUM):
            with tf.device('/gpu:{:d}'.format(i)):
                with tf.name_scope('tower_{:d}'.format(i)):
                    train_loss, train_top1_error, grads = compute_net_gradients(
                        train_images, train_labels, nsfw_net, optimizer,
                        is_net_first_initialized=is_network_initialized)

                    is_network_initialized = True

                    # Only use the mean and var in the first gpu tower to update the parameter
                    # TODO implement batch normalization for distributed device (luoyao@baidu.com)
                    if i == 0:
                        batchnorm_updates = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
                        train_summary_op_updates = tf.get_collection(tf.GraphKeys.SUMMARIES)

                    tower_grads.append(grads)
                    train_tower_loss.append(train_loss)
                    train_tower_top1_error.append(train_top1_error)
                with tf.name_scope('validation_{:d}'.format(i)):
                    # Validation only needs loss and error; passing no optimizer
                    # avoids building gradient ops that would never be run.
                    val_loss, val_top1_error, _ = compute_net_gradients(
                        val_images, val_labels, nsfw_net_val, None,
                        is_net_first_initialized=is_network_initialized)
                    val_tower_loss.append(val_loss)
                    val_tower_top1_error.append(val_top1_error)

    grads = average_gradients(tower_grads)
    avg_train_loss = tf.reduce_mean(train_tower_loss)
    avg_train_top1_error = tf.reduce_mean(train_tower_top1_error)
    avg_val_loss = tf.reduce_mean(val_tower_loss)
    avg_val_top1_error = tf.reduce_mean(val_tower_top1_error)

    # Track the moving averages of all trainable variables
    variable_averages = tf.train.ExponentialMovingAverage(
        CFG.TRAIN.MOVING_AVERAGE_DECAY, num_updates=global_step)
    variables_to_average = tf.trainable_variables() + tf.moving_average_variables()
    variables_averages_op = variable_averages.apply(variables_to_average)

    # Group all the op needed for training
    batchnorm_updates_op = tf.group(*batchnorm_updates)
    apply_gradient_op = optimizer.apply_gradients(grads, global_step=global_step)
    train_op = tf.group(apply_gradient_op, variables_averages_op,
                        batchnorm_updates_op)

    # set tensorflow summary
    tboard_save_path = 'tboard/nsfw_cls'
    os.makedirs(tboard_save_path, exist_ok=True)

    summary_writer = tf.summary.FileWriter(tboard_save_path)

    avg_train_loss_scalar = tf.summary.scalar(name='average_train_loss',
                                              tensor=avg_train_loss)
    avg_train_top1_err_scalar = tf.summary.scalar(name='average_train_top1_error',
                                                  tensor=avg_train_top1_error)
    avg_val_loss_scalar = tf.summary.scalar(name='average_val_loss',
                                            tensor=avg_val_loss)
    avg_val_top1_err_scalar = tf.summary.scalar(name='average_val_top1_error',
                                                tensor=avg_val_top1_error)
    learning_rate_scalar = tf.summary.scalar(name='learning_rate_scalar',
                                             tensor=learning_rate)

    train_merge_summary_op = tf.summary.merge([avg_train_loss_scalar,
                                               avg_train_top1_err_scalar,
                                               learning_rate_scalar] + train_summary_op_updates)

    val_merge_summary_op = tf.summary.merge([avg_val_loss_scalar, avg_val_top1_err_scalar])

    # set tensorflow saver (created after every variable exists, so checkpoints
    # include the global step, momentum slots and moving averages)
    saver = tf.train.Saver()
    model_save_dir = 'model/nsfw_cls'
    os.makedirs(model_save_dir, exist_ok=True)
    train_start_time = time.strftime('%Y-%m-%d-%H-%M-%S', time.localtime(time.time()))
    model_name = 'nsfw_cls_{:s}.ckpt'.format(str(train_start_time))
    model_save_path = ops.join(model_save_dir, model_name)

    # set sess config
    sess_config = tf.ConfigProto(device_count={'GPU': CFG.TRAIN.GPU_NUM}, allow_soft_placement=True)
    sess_config.gpu_options.per_process_gpu_memory_fraction = CFG.TRAIN.GPU_MEMORY_FRACTION
    sess_config.gpu_options.allow_growth = CFG.TRAIN.TF_ALLOW_GROWTH
    sess_config.gpu_options.allocator_type = 'BFC'

    # Set the training parameters
    train_epochs = CFG.TRAIN.EPOCHS

    log.info('Global configuration is as follows:')
    log.info(CFG)

    sess = tf.Session(config=sess_config)

    # try/finally releases the session and flushes the summary writer on every
    # exit path, including the early return on a NaN loss.
    try:
        summary_writer.add_graph(sess.graph)

        with sess.as_default():

            tf.train.write_graph(graph_or_graph_def=sess.graph, logdir='',
                                 name='{:s}/nsfw_cls_model.pb'.format(model_save_dir))

            if weights_path is None:
                log.info('Training from scratch')
                init = tf.global_variables_initializer()
                sess.run(init)
            else:
                log.info('Restore model from last model checkpoint {:s}'.format(weights_path))
                saver.restore(sess=sess, save_path=weights_path)

            train_cost_time_mean = []
            val_cost_time_mean = []

            for epoch in range(train_epochs):

                # training part
                t_start = time.time()

                _, train_loss_value, train_top1_err_value, train_summary, lr = \
                    sess.run(fetches=[train_op,
                                      avg_train_loss,
                                      avg_train_top1_error,
                                      train_merge_summary_op,
                                      learning_rate])

                if math.isnan(train_loss_value):
                    log.error('Train loss is nan')
                    return

                cost_time = time.time() - t_start
                train_cost_time_mean.append(cost_time)

                summary_writer.add_summary(summary=train_summary,
                                           global_step=epoch)

                # validation part
                t_start_val = time.time()

                val_loss_value, val_top1_err_value, val_summary = \
                    sess.run(fetches=[avg_val_loss,
                                      avg_val_top1_error,
                                      val_merge_summary_op])

                summary_writer.add_summary(val_summary, global_step=epoch)

                cost_time_val = time.time() - t_start_val
                val_cost_time_mean.append(cost_time_val)

                if epoch % CFG.TRAIN.DISPLAY_STEP == 0:
                    log.info('Epoch_Train: {:d} total_loss= {:6f} top1_error= {:6f} '
                             'lr= {:6f} mean_cost_time= {:5f}s '.
                             format(epoch + 1,
                                    train_loss_value,
                                    train_top1_err_value,
                                    lr,
                                    np.mean(train_cost_time_mean)))
                    train_cost_time_mean.clear()

                if epoch % CFG.TRAIN.VAL_DISPLAY_STEP == 0:
                    log.info('Epoch_Val: {:d} total_loss= {:6f} top1_error= {:6f}'
                             ' mean_cost_time= {:5f}s '.
                             format(epoch + 1,
                                    val_loss_value,
                                    val_top1_err_value,
                                    np.mean(val_cost_time_mean)))
                    val_cost_time_mean.clear()

                if epoch % 2000 == 0:
                    saver.save(sess=sess, save_path=model_save_path, global_step=epoch)

            # The periodic save above only fires every 2000 steps; also keep
            # the final weights unless that step was just saved.
            last_epoch = train_epochs - 1
            if last_epoch >= 0 and last_epoch % 2000 != 0:
                saver.save(sess=sess, save_path=model_save_path, global_step=last_epoch)
    finally:
        summary_writer.close()
        sess.close()

    return
543 |
544 |
if __name__ == '__main__':
    args = init_args()

    # Multi-gpu training is only used when at least two GPUs are configured;
    # otherwise the flag is forced off.
    args.use_multi_gpu = CFG.TRAIN.GPU_NUM >= 2 and args.use_multi_gpu

    # train the nsfw classification model
    train_fn = train_net_multi_gpu if args.use_multi_gpu else train_net
    train_fn(args.dataset_dir, args.weights_path)
557 |
558 |
--------------------------------------------------------------------------------