├── nets
│   ├── a.py
│   ├── __init__.py
│   ├── mobilenet_v1.png
│   ├── inception.py
│   ├── nets_factory_test.py
│   ├── inception_utils.py
│   ├── mobilenet_v1.md
│   ├── lenet.py
│   ├── dcgan_test.py
│   ├── cyclegan_test.py
│   ├── cifarnet.py
│   ├── overfeat.py
│   ├── alexnet.py
│   ├── nets_factory.py
│   └── pix2pix_test.py
├── scripts
│   ├── a.py
│   ├── train_lenet_on_mnist.sh
│   ├── train_cifarnet_on_cifar10.sh
│   ├── finetune_inception_v1_on_flowers.sh
│   ├── finetune_resnet_v1_50_on_flowers.sh
│   ├── finetune_inception_v3_on_flowers.sh
│   ├── finetune_inception_resnet_v2_on_flowers.sh
│   └── export_mobilenet.sh
├── datasets
│   ├── a.py
│   ├── __init__.py
│   ├── dataset_factory.py
│   ├── preprocess_imagenet_validation_data.py
│   ├── flowers.py
│   ├── mnist.py
│   ├── cifar10.py
│   ├── download_imagenet.sh
│   ├── download_and_convert_imagenet.sh
│   └── dataset_utils.py
├── depoyment
│   ├── a.py
│   └── __init__.py
├── pig_vgg16
│   ├── a.py
│   ├── config.py
│   ├── resize.py
│   ├── pig_input.py
│   ├── read_tfrecords.py
│   ├── batch_recognize.py
│   ├── track.py
│   ├── single_picture_recognize.py
│   ├── single_picture_recognize_r+p.py
│   ├── pig_model_vgg16.py
│   ├── pig_model_restnet_v2.py
│   ├── pig_records.py
│   ├── pig_train.py
│   ├── write_tfrecords.py
│   ├── pig_records_padding.py
│   ├── pig_model_vgg.py
│   ├── pig_train_1.py
│   └── pig_model_dark.py
├── preprocessing
│   ├── a.py
│   ├── __init__.py
│   ├── lenet_preprocessing.py
│   ├── preprocessing_factory.py
│   └── cifarnet_preprocessing.py
├── 京东.PNG
├── process.py
├── README.md
├── setup.py
├── padding.py
├── pro_data.py
├── export_inference_graph_test.py
├── pro_data_1.py
├── pro_data_12.py
├── download_and_convert_data.py
├── export_inference_graph.py
└── pre_crop.py

/nets/a.py:
--------------------------------------------------------------------------------
1 | a
2 | 
--------------------------------------------------------------------------------
/scripts/a.py:
--------------------------------------------------------------------------------
1 | a
2 | 
--------------------------------------------------------------------------------
/datasets/a.py:
--------------------------------------------------------------------------------
1 | a
2 | 
--------------------------------------------------------------------------------
/depoyment/a.py:
--------------------------------------------------------------------------------
1 | 
2 | 
--------------------------------------------------------------------------------
/nets/__init__.py:
--------------------------------------------------------------------------------
1 | 
2 | 
--------------------------------------------------------------------------------
/pig_vgg16/a.py:
--------------------------------------------------------------------------------
1 | a
2 | 
--------------------------------------------------------------------------------
/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | 
2 | 
--------------------------------------------------------------------------------
/depoyment/__init__.py:
--------------------------------------------------------------------------------
1 | 
2 | 
--------------------------------------------------------------------------------
/preprocessing/a.py:
--------------------------------------------------------------------------------
1 | a
2 | 
--------------------------------------------------------------------------------
/preprocessing/__init__.py:
--------------------------------------------------------------------------------
1 | 
2 | 
--------------------------------------------------------------------------------
/京东.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ei1994/pig_recognize_JD/HEAD/京东.PNG
--------------------------------------------------------------------------------
/nets/mobilenet_v1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ei1994/pig_recognize_JD/HEAD/nets/mobilenet_v1.png
--------------------------------------------------------------------------------
/pig_vgg16/config.py:
--------------------------------------------------------------------------------
1 | # about pig image
2 | IMAGE_HEIGHT = 299
3 | IMAGE_WIDTH = 299
4 | 
5 | CLASSES_NUM = 30
6 | 
7 | # for train
8 | RECORD_DIR = 'pig_body'
9 | TRAIN_FILE = 'train.tfrecords'
10 | #TRAIN_FILE = 'D:\pig_recognize\record\train.tfrecords'
11 | VALID_FILE = 'valid.tfrecords'
12 | 
--------------------------------------------------------------------------------
/process.py:
--------------------------------------------------------------------------------
1 | 
2 | 
3 | 
4 | with open('out_v4_13.csv','r') as f:
5 |     content = [line.strip() for line in f.readlines()]
6 | 
7 | all_content = []
8 | for i in range(0, len(content), 30):
9 |     subs = content[i:i+30]
10 |     prob = []
11 |     for sub in subs:
12 |         temp = sub.split(',')[-1]
13 |         prob.append(temp)
14 | 
15 |     prob1 = list(map(float, prob))
16 | 
17 |     if max(prob1) > 0.8:
18 |         index = prob1.index(max(prob1))
19 | 
20 |         prob2 = [0 for i in prob1]
21 |         prob2[index] = 1
22 | 
23 |         temps = []
24 |         for i in range(len(subs)):
25 |             temp1 = subs[i].split(',')[0] + ',' + subs[i].split(',')[1] + ',' + str(prob2[i])
26 |             temps.append(temp1)
27 |         all_content.extend(temps)
28 | 
29 |     else:
30 |         all_content.extend(subs)
31 | 
32 | with open('new.csv','w') as f1:
33 |     for row in all_content:
34 |         f1.writelines(row+'\n')
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ## pig_recognize_JD
2 | [JD pig-face recognition contest, official page](http://jddjr.jd.com/item/4 "JD algorithm contest")
3 | 
4 | ## 1. Competition summary
5 | **Timeframe:** 2017.10 - 2017.12  
6 | **Environment:** Python 3.5, TensorFlow 1.2, CUDA 8.0, cuDNN 5.0, Linux 14.04
7 | 
8 | * (1) **Building the training set:** from video footage of 30 pigs, one frame is kept out of every 10 in each video; YOLOv2 then crops the **pig body and pig face** regions from every saved frame, and the crops are packed into tfrecord files. The resulting training set holds 8,700 body images (around 1000 * 800) and 4,100 face images (around 500 * 500), 12,800 images in total.
9 | * (2) **Data preprocessing:** data augmentation, i.e. casting to float32, cropping regions with a range of aspect ratios and area fractions, random horizontal flips, random color distortion, and widening the pixel value range (a sketch follows this README).
10 | * (3) **Network training:** **Inception-ResNet-v2** as the base network with a softmax cross-entropy loss (30 pigs to tell apart), batch size 28, Adam optimizer, learning rate 0.01; after 80 epochs the rate is lowered to 0.0001 for fine-tuning (second sketch below). This stage is mostly about tuning the optimizer and its parameters, watching for overfitting, and keeping the best checkpoint.
11 | * (4) **Ensembling and post-processing:** for each image the network outputs 30 class probabilities. First, several of the **best-trained models are ensembled** by averaging their outputs, or by fusing them with fixed weights. Second, the predicted probabilities are kept away from extremes: the submission should never be almost certain of one pig (a predicted probability of 0.99 is too risky), trading confidence for a lower overall log loss. Experiments showed this indeed improved the ranking (third sketch below).
12 | 
13 | ## 2. Final ranking
14 | ![](/京东.PNG "JD pig-face recognition")
15 | 
16 | 
17 | 
18 | 
--------------------------------------------------------------------------------
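The augmentation in point (2) of the README maps onto standard TF 1.x image ops. A minimal sketch; the function name, crop ranges, and jitter strengths are illustrative assumptions, not values taken from this repo's preprocessing code:

```python
import tensorflow as tf

def augment_for_training(image, out_h=299, out_w=299):
    """Illustrative pipeline: dtype cast, random crop, flip, color jitter."""
    image = tf.image.convert_image_dtype(image, tf.float32)   # uint8 -> [0, 1] floats
    # Crop a sub-region with a random aspect ratio and area fraction.
    bbox = tf.constant([0.0, 0.0, 1.0, 1.0], shape=[1, 1, 4])
    begin, size, _ = tf.image.sample_distorted_bounding_box(
        tf.shape(image), bounding_boxes=bbox,
        aspect_ratio_range=(0.75, 1.33), area_range=(0.5, 1.0))
    image = tf.slice(image, begin, size)
    image = tf.image.resize_images(image, [out_h, out_w])
    image = tf.image.random_flip_left_right(image)            # random mirror
    image = tf.image.random_brightness(image, max_delta=32. / 255.)
    image = tf.image.random_saturation(image, lower=0.5, upper=1.5)
    return image * 2.0 - 1.0                                  # widen range to (-1, 1)
```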
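Point (3)'s objective is a 30-way softmax cross-entropy on top of the slim Inception-ResNet-v2 bundled under nets/. A condensed sketch, with a placeholder-fed learning rate standing in for the 0.01-then-0.0001 schedule:

```python
import tensorflow as tf
from nets import inception

slim = tf.contrib.slim

images = tf.placeholder(tf.float32, [None, 299, 299, 3])
labels = tf.placeholder(tf.float32, [None, 30])        # one-hot over the 30 pigs

with slim.arg_scope(inception.inception_resnet_v2_arg_scope()):
    logits, _ = inception.inception_resnet_v2(images, num_classes=30,
                                              is_training=True)

# Softmax cross-entropy separates the 30 identities.
loss = tf.losses.softmax_cross_entropy(onehot_labels=labels, logits=logits)

# Fed as 0.01 for roughly the first 80 epochs, then 0.0001 to fine-tune.
learning_rate = tf.placeholder(tf.float32, [])
train_op = tf.train.AdamOptimizer(learning_rate).minimize(loss)
```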
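Point (4) is carried out on the submission CSVs by process.py above and the pro_data*.py scripts further down. Stripped of the CSV plumbing, it is a weighted average of per-model probability matrices plus a clip on over-confident rows; in this NumPy sketch the weights and the 0.98 cap are illustrative assumptions:

```python
import numpy as np

def ensemble(prob_list, weights=None):
    """Weighted average of per-model probability matrices, each of shape (N, 30)."""
    probs = np.stack(prob_list)                       # (num_models, N, 30)
    w = np.ones(len(prob_list)) if weights is None else np.asarray(weights, float)
    w = w / w.sum()
    return np.tensordot(w, probs, axes=1)             # back to (N, 30)

def soften(probs, cap=0.98):
    """Clip over-confident entries, then renormalize each row to sum to 1."""
    probs = np.minimum(probs, cap)
    return probs / probs.sum(axis=1, keepdims=True)
```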
/setup.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Setup script for slim."""
16 | 
17 | from setuptools import find_packages
18 | from setuptools import setup
19 | 
20 | 
21 | setup(
22 |     name='slim',
23 |     version='0.1',
24 |     include_package_data=True,
25 |     packages=find_packages(),
26 |     description='tf-slim',
27 | )
28 | 
--------------------------------------------------------------------------------
/padding.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | """
4 | Created on Sat Nov 25 10:52:47 2017
5 | 
6 | @author: no1
7 | """
8 | 
9 | import numpy as np
10 | import cv2
11 | import scipy.misc as misc
12 | import glob
13 | import os
14 | paths = glob.glob('pig_data_face/*.jpg')
15 | count = 0
16 | for path in paths:
17 |     basename = os.path.basename(path)
18 |     label = basename.split('_')[0]
19 | #    new_path = os.path.join('D:/pig_recognize/pig_slim1/pig_data_face_padding',label)
20 | #    if not os.path.exists(new_path):
21 | #        os.mkdir(new_path)
22 | 
23 |     img = cv2.imread(path)
24 |     height, width, _ = img.shape
25 |     if height < 200 or width < 200:
26 |         os.remove(path)
27 |         continue
28 |     # Zero-pad the short side to make the image square; splitting the
29 |     # difference as diff//2 and diff-diff//2 keeps odd differences exact.
30 |     diff = abs(height - width)
31 |     offset = diff // 2
32 |     if height >= width:
33 |         pad_image = np.pad(img, ((0,0), (offset, diff-offset), (0,0)), mode='constant', constant_values=0)
34 |     else:
35 |         pad_image = np.pad(img, ((offset, diff-offset), (0,0), (0,0)), mode='constant', constant_values=0)
36 | 
37 |     cv2.imwrite(os.path.join('D:/pig_recognize/pig_slim1/pig_data_face_padding', basename), pad_image)
38 | 
39 |     count += 1
40 |     if count % 500 == 0:
41 |         print('processed {}/{}'.format(count, len(paths)))
42 | 
--------------------------------------------------------------------------------
/pro_data.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | Created on Wed Dec  6 20:42:47 2017
4 | 
5 | @author: DELL
6 | """
7 | import copy
8 | import numpy as np
9 | 
10 | offset = 0.0001
11 | 
12 | with open('out_v4_13.csv','r') as f:
13 |     content = [line.strip() for line in f.readlines()]
14 | 
15 | all_content = []
16 | 
17 | for i in range(0, len(content), 30):
18 |     subs = content[i:i+30]
19 |     x = []
20 |     for sub in subs:
21 |         temp = sub.split(',')[-1]
22 |         x.append(temp)
23 | 
24 |     prob = list(map(float, x))
25 |     prob_max = max(prob)
26 |     index = prob.index(prob_max)
27 |     result = copy.copy(prob)
28 | 
29 |     if prob_max < 0.3:
30 |         # Average the 27 least likely classes and assign them that mean.
31 |         all_sum = 0
32 |         a = np.array(prob)
33 |         index_a = np.argsort(a)
34 |         index_b = index_a[:27]
35 | 
36 |         for m in index_b:
37 |             all_sum = all_sum + a[m]
38 |         abc = all_sum / 27
39 |         for j in index_b:
40 |             result[j] = abc
41 | 
42 |         temps = []
43 |         for i in range(len(subs)):
44 |             temp1 = subs[i].split(',')[0] + ',' + subs[i].split(',')[1] + ',' + str(result[i])
45 |             temps.append(temp1)
46 | 
47 |         all_content.extend(temps)
48 | 
49 |     else:
50 |         all_content.extend(subs)
51 | 
52 | with open('new1.csv','w') as f1:
53 |     for row in all_content:
54 |         f1.writelines(row+'\n')
55 | 
56 | 
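The redistribution in pro_data.py, like the 0.8 cut-off in process.py above, is driven by the competition's log-loss metric: a confident miss costs far more than a confident hit saves. A tiny worked example with illustrative numbers:

```python
import numpy as np

# One image, 30 classes: the model puts 0.99 on one pig and spreads
# the remaining 0.01 over the other 29.
confident = np.full(30, 0.01 / 29)
confident[0] = 0.99
# Softened copy: cap the winner at 0.9, give the rest to the other 29.
soft = np.full(30, 0.10 / 29)
soft[0] = 0.90

for true_class in (0, 1):            # prediction right vs. wrong
    print('right' if true_class == 0 else 'wrong',
          'confident: %.3f' % -np.log(confident[true_class]),
          'softened: %.3f' % -np.log(soft[true_class]))
# right confident: 0.010 softened: 0.105
# wrong confident: 7.972 softened: 5.670
```

Softening costs about 0.1 on every correctly identified image but saves over 2 on every miss, so unless accuracy is nearly perfect the softened submission scores a lower (better) log loss.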
-------------------------------------------------------------------------------- /export_inference_graph_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Tests for export_inference_graph.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | import os 23 | 24 | 25 | import tensorflow as tf 26 | 27 | from tensorflow.python.platform import gfile 28 | import export_inference_graph 29 | 30 | 31 | class ExportInferenceGraphTest(tf.test.TestCase): 32 | 33 | def testExportInferenceGraph(self): 34 | tmpdir = self.get_temp_dir() 35 | output_file = os.path.join(tmpdir, 'inception_v3.pb') 36 | flags = tf.app.flags.FLAGS 37 | flags.output_file = output_file 38 | flags.model_name = 'inception_v3' 39 | flags.dataset_dir = tmpdir 40 | export_inference_graph.main(None) 41 | self.assertTrue(gfile.Exists(output_file)) 42 | 43 | if __name__ == '__main__': 44 | tf.test.main() 45 | -------------------------------------------------------------------------------- /preprocessing/lenet_preprocessing.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Provides utilities for preprocessing.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import tensorflow as tf 22 | 23 | slim = tf.contrib.slim 24 | 25 | 26 | def preprocess_image(image, output_height, output_width, is_training): 27 | """Preprocesses the given image. 28 | 29 | Args: 30 | image: A `Tensor` representing an image of arbitrary size. 31 | output_height: The height of the image after preprocessing. 32 | output_width: The width of the image after preprocessing. 33 | is_training: `True` if we're preprocessing the image for training and 34 | `False` otherwise. 35 | 36 | Returns: 37 | A preprocessed image. 
38 | """ 39 | image = tf.to_float(image) 40 | image = tf.image.resize_image_with_crop_or_pad( 41 | image, output_width, output_height) 42 | image = tf.subtract(image, 128.0) 43 | image = tf.div(image, 128.0) 44 | return image 45 | -------------------------------------------------------------------------------- /pig_vgg16/resize.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Wed Nov 22 22:38:02 2017 4 | 5 | @author: DELL 6 | """ 7 | ''' 8 | resize images 9 | ''' 10 | 11 | import cv2 12 | import numpy as np 13 | import glob 14 | import os 15 | 16 | size = 320 17 | files = glob.glob('D:/pig_recognize_body/pig_body/train_data/*.jpg') 18 | result_dir = 'pig_body/train_data_resize' 19 | 20 | try: 21 | os.makedirs(result_dir) 22 | except os.error: 23 | pass 24 | for i in files: 25 | img = cv2.imread(i) 26 | # img = cv2.imread('D:/pig_recognize_body/pig_body/train_data/2_00001_pig.jpg') 27 | base_name = os.path.basename(i) 28 | 29 | h,w,c = img.shape 30 | fig = np.ones((size,size,3))*255 31 | 32 | if h>w: 33 | rate = h/size 34 | h_v = size 35 | w_v = int(w/rate) 36 | border = (size - w_v) 37 | up = 0 38 | down = size + 1 39 | left = int(border/2) 40 | right = int(size - border/2) 41 | else: 42 | rate = w/size 43 | w_v = size 44 | h_v = int(h/rate) 45 | border = (size - h_v) 46 | up = int(border/2) 47 | down = int(size - border/2) 48 | left = 0 49 | right = size + 1 50 | 51 | img_v = cv2.resize(img, (w_v, h_v)) 52 | fig[up:down, left:right] = img_v 53 | cv2.imwrite(os.path.join(result_dir,base_name), fig) 54 | 55 | #fig = fig.astype(img_v.dtype) 56 | #cv2.waitKey(0) 57 | #cv2.imshow('img.jpg', img_v) 58 | #cv2.imshow('img1.jpg', fig) 59 | #cv2.waitKey(0) 60 | #cv2.waitKey(0) 61 | #cv2.waitKey(0) 62 | #cv2.waitKey(0) 63 | 64 | -------------------------------------------------------------------------------- /nets/inception.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | """Brings all inception models under one namespace.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | # pylint: disable=unused-import 22 | from nets.inception_resnet_v2 import inception_resnet_v2 23 | from nets.inception_resnet_v2 import inception_resnet_v2_arg_scope 24 | from nets.inception_resnet_v2 import inception_resnet_v2_base 25 | from nets.inception_v1 import inception_v1 26 | from nets.inception_v1 import inception_v1_arg_scope 27 | from nets.inception_v1 import inception_v1_base 28 | from nets.inception_v2 import inception_v2 29 | from nets.inception_v2 import inception_v2_arg_scope 30 | from nets.inception_v2 import inception_v2_base 31 | from nets.inception_v3 import inception_v3 32 | from nets.inception_v3 import inception_v3_arg_scope 33 | from nets.inception_v3 import inception_v3_base 34 | from nets.inception_v4 import inception_v4 35 | from nets.inception_v4 import inception_v4_arg_scope 36 | from nets.inception_v4 import inception_v4_base 37 | # pylint: enable=unused-import 38 | -------------------------------------------------------------------------------- /datasets/dataset_factory.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """A factory-pattern class which returns classification image/label pairs.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | from datasets import cifar10 22 | from datasets import flowers 23 | from datasets import imagenet 24 | from datasets import mnist 25 | 26 | datasets_map = { 27 | 'cifar10': cifar10, 28 | 'flowers': flowers, 29 | 'imagenet': imagenet, 30 | 'mnist': mnist, 31 | } 32 | 33 | 34 | def get_dataset(name, split_name, dataset_dir, file_pattern=None, reader=None): 35 | """Given a dataset name and a split_name returns a Dataset. 36 | 37 | Args: 38 | name: String, the name of the dataset. 39 | split_name: A train/test split name. 40 | dataset_dir: The directory where the dataset files are stored. 41 | file_pattern: The file pattern to use for matching the dataset source files. 42 | reader: The subclass of tf.ReaderBase. If left as `None`, then the default 43 | reader defined by each dataset is used. 44 | 45 | Returns: 46 | A `Dataset` class. 47 | 48 | Raises: 49 | ValueError: If the dataset `name` is unknown. 
50 | """ 51 | if name not in datasets_map: 52 | raise ValueError('Name of dataset unknown %s' % name) 53 | return datasets_map[name].get_split( 54 | split_name, 55 | dataset_dir, 56 | file_pattern, 57 | reader) 58 | -------------------------------------------------------------------------------- /scripts/train_lenet_on_mnist.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ============================================================================== 16 | # 17 | # This script performs the following operations: 18 | # 1. Downloads the MNIST dataset 19 | # 2. Trains a LeNet model on the MNIST training set. 20 | # 3. Evaluates the model on the MNIST testing set. 21 | # 22 | # Usage: 23 | # cd slim 24 | # ./slim/scripts/train_lenet_on_mnist.sh 25 | set -e 26 | 27 | # Where the checkpoint and logs will be saved to. 28 | TRAIN_DIR=/tmp/lenet-model 29 | 30 | # Where the dataset is saved to. 31 | DATASET_DIR=/tmp/mnist 32 | 33 | # Download the dataset 34 | python download_and_convert_data.py \ 35 | --dataset_name=mnist \ 36 | --dataset_dir=${DATASET_DIR} 37 | 38 | # Run training. 39 | python train_image_classifier.py \ 40 | --train_dir=${TRAIN_DIR} \ 41 | --dataset_name=mnist \ 42 | --dataset_split_name=train \ 43 | --dataset_dir=${DATASET_DIR} \ 44 | --model_name=lenet \ 45 | --preprocessing_name=lenet \ 46 | --max_number_of_steps=20000 \ 47 | --batch_size=50 \ 48 | --learning_rate=0.01 \ 49 | --save_interval_secs=60 \ 50 | --save_summaries_secs=60 \ 51 | --log_every_n_steps=100 \ 52 | --optimizer=sgd \ 53 | --learning_rate_decay_type=fixed \ 54 | --weight_decay=0 55 | 56 | # Run evaluation. 57 | python eval_image_classifier.py \ 58 | --checkpoint_path=${TRAIN_DIR} \ 59 | --eval_dir=${TRAIN_DIR} \ 60 | --dataset_name=mnist \ 61 | --dataset_split_name=test \ 62 | --dataset_dir=${DATASET_DIR} \ 63 | --model_name=lenet 64 | -------------------------------------------------------------------------------- /scripts/train_cifarnet_on_cifar10.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | # ============================================================================== 16 | # 17 | # This script performs the following operations: 18 | # 1. Downloads the Cifar10 dataset 19 | # 2. Trains a CifarNet model on the Cifar10 training set. 20 | # 3. Evaluates the model on the Cifar10 testing set. 21 | # 22 | # Usage: 23 | # cd slim 24 | # ./scripts/train_cifarnet_on_cifar10.sh 25 | set -e 26 | 27 | # Where the checkpoint and logs will be saved to. 28 | TRAIN_DIR=/tmp/cifarnet-model 29 | 30 | # Where the dataset is saved to. 31 | DATASET_DIR=/tmp/cifar10 32 | 33 | # Download the dataset 34 | python download_and_convert_data.py \ 35 | --dataset_name=cifar10 \ 36 | --dataset_dir=${DATASET_DIR} 37 | 38 | # Run training. 39 | python train_image_classifier.py \ 40 | --train_dir=${TRAIN_DIR} \ 41 | --dataset_name=cifar10 \ 42 | --dataset_split_name=train \ 43 | --dataset_dir=${DATASET_DIR} \ 44 | --model_name=cifarnet \ 45 | --preprocessing_name=cifarnet \ 46 | --max_number_of_steps=100000 \ 47 | --batch_size=128 \ 48 | --save_interval_secs=120 \ 49 | --save_summaries_secs=120 \ 50 | --log_every_n_steps=100 \ 51 | --optimizer=sgd \ 52 | --learning_rate=0.1 \ 53 | --learning_rate_decay_factor=0.1 \ 54 | --num_epochs_per_decay=200 \ 55 | --weight_decay=0.004 56 | 57 | # Run evaluation. 58 | python eval_image_classifier.py \ 59 | --checkpoint_path=${TRAIN_DIR} \ 60 | --eval_dir=${TRAIN_DIR} \ 61 | --dataset_name=cifar10 \ 62 | --dataset_split_name=test \ 63 | --dataset_dir=${DATASET_DIR} \ 64 | --model_name=cifarnet 65 | -------------------------------------------------------------------------------- /pig_vgg16/pig_input.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import os.path 6 | import tensorflow as tf 7 | 8 | import config 9 | 10 | RECORD_DIR = config.RECORD_DIR 11 | TRAIN_FILE = config.TRAIN_FILE 12 | VALID_FILE = config.VALID_FILE 13 | 14 | IMAGE_WIDTH = config.IMAGE_WIDTH 15 | IMAGE_HEIGHT = config.IMAGE_HEIGHT 16 | CLASSES_NUM = config.CLASSES_NUM 17 | 18 | def read_and_decode(filename_queue): 19 | reader = tf.TFRecordReader() 20 | _, serialized_example = reader.read(filename_queue) 21 | features = tf.parse_single_example( 22 | serialized_example, 23 | features={ 24 | 'image_raw': tf.FixedLenFeature([], tf.string), 25 | 'label_raw': tf.FixedLenFeature([], tf.string), 26 | }) 27 | image = tf.decode_raw(features['image_raw'], tf.int16) 28 | image.set_shape([IMAGE_HEIGHT * IMAGE_WIDTH * 3]) 29 | 30 | image = tf.cast(image, tf.float32) * (1. 
/ 127.5) - 1 #(-1,1) 31 | reshape_image = tf.reshape(image, [IMAGE_HEIGHT, IMAGE_WIDTH, 3]) 32 | 33 | reshape_image = tf.image.random_flip_left_right(reshape_image) 34 | 35 | label = tf.decode_raw(features['label_raw'], tf.uint8) 36 | reshape_label = tf.reshape(label, [CLASSES_NUM])#(30,) 37 | return tf.cast(reshape_image, tf.float32), tf.cast(reshape_label, tf.float32) 38 | 39 | 40 | def inputs(train, batch_size): 41 | filename = os.path.join(RECORD_DIR, 42 | TRAIN_FILE if train else VALID_FILE) 43 | 44 | with tf.name_scope('input'): 45 | filename_queue = tf.train.string_input_producer([filename]) 46 | image, label = read_and_decode(filename_queue) 47 | if train: 48 | images, sparse_labels = tf.train.shuffle_batch([image, label], 49 | batch_size=batch_size, 50 | num_threads=6, 51 | capacity=2000 + 3 * batch_size, 52 | min_after_dequeue=2000) 53 | else: 54 | images, sparse_labels = tf.train.batch([image, label], 55 | batch_size=batch_size, 56 | num_threads=6, 57 | capacity=2000 + 3 * batch_size) 58 | 59 | return images, sparse_labels 60 | -------------------------------------------------------------------------------- /pro_data_1.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Sat Dec 9 19:59:58 2017 4 | 5 | @author: DELL 6 | """ 7 | 8 | import copy 9 | import numpy as np 10 | 11 | offset = 0.0001 12 | q = [] 13 | pq = 0 14 | 15 | with open('out_b_17.csv','r') as f: 16 | content =[line.strip() for line in f.readlines()] 17 | 18 | with open('out_b_17_face.csv','r') as f: 19 | content1 =[line.strip() for line in f.readlines()] 20 | 21 | all_content = [] 22 | 23 | for i in range(0,len(content),30): 24 | subs = content[i:i+30] 25 | x = [] 26 | for sub in subs: 27 | temp = sub.split(',')[-1] 28 | x.append(temp) 29 | 30 | prob = list(map(float,x)) 31 | prob_max = max(prob) 32 | # index = prob.index(prob_max) 33 | # result = copy.copy(prob) 34 | # a = np.array(prob) 35 | # index_a = np.argsort(-a) 36 | 37 | if prob_max > 0.6 : 38 | number = subs[1].split(',')[0] 39 | q.append(number) 40 | # pq = pq + 1 41 | for i1 in range(0,len(content1),30): 42 | subs1 = content1[i1:i1+30] 43 | y = [] 44 | number1 = subs1[1].split(',')[0] 45 | if (number1 == number): 46 | for sub in subs1: 47 | temp = sub.split(',')[-1] 48 | y.append(temp) 49 | prob1 = list(map(float,y)) 50 | prob_max1 = max(prob1) 51 | else: 52 | continue 53 | temps = [] 54 | 55 | # for j in range(len(subs)): 56 | # temp1 = subs[j].split(',')[0]+ ','+subs[j].split(',')[1]+',' + str(prob1[j]) 57 | # temps.append(temp1) 58 | 59 | if prob_max1 > prob_max: 60 | pq = pq + 1 61 | for i in range(len(subs)): 62 | # prob1[i] = prob1[i]/1.000001 63 | temp1 = subs[i].split(',')[0]+ ','+subs[i].split(',')[1]+',' + str(prob1[i]) 64 | temps.append(temp1) 65 | else: 66 | for i in range(len(subs)): 67 | temp1 = subs[i].split(',')[0]+ ','+subs[i].split(',')[1]+',' + str(prob[i]) 68 | temps.append(temp1) 69 | 70 | all_content.extend(temps) 71 | else: 72 | all_content.extend(subs) 73 | 74 | #with open('out_b_19_21.csv','w') as f1: 75 | # for row in all_content: 76 | # f1.writelines(row+'\n') 77 | 78 | 79 | 80 | 81 | 82 | -------------------------------------------------------------------------------- /pro_data_12.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Sun Dec 10 20:54:37 2017 4 | 5 | @author: DELL 6 | """ 7 | 8 | import copy 9 | import numpy as np 10 | 11 | offset = 0.0001 12 | q 
= [] 13 | pq = 0 14 | 15 | with open('out_b_19_21.csv','r') as f: 16 | content =[line.strip() for line in f.readlines()] 17 | 18 | with open('out_b_19_face.csv','r') as f: 19 | content1 =[line.strip() for line in f.readlines()] 20 | 21 | all_content = [] 22 | 23 | for i in range(0,len(content),30): 24 | subs = content[i:i+30] 25 | x = [] 26 | for sub in subs: 27 | temp = sub.split(',')[-1] 28 | x.append(temp) 29 | 30 | prob = list(map(float,x)) 31 | prob_max = max(prob) 32 | # index = prob.index(prob_max) 33 | # result = copy.copy(prob) 34 | # a = np.array(prob) 35 | # index_a = np.argsort(-a) 36 | 37 | if 0.3>prob_max > 0 : 38 | number = subs[1].split(',')[0] 39 | q.append(number) 40 | # pq = pq + 1 41 | for i1 in range(0,len(content1),30): 42 | subs1 = content1[i1:i1+30] 43 | y = [] 44 | number1 = subs1[1].split(',')[0] 45 | if (number1 == number): 46 | for sub in subs1: 47 | temp = sub.split(',')[-1] 48 | y.append(temp) 49 | prob1 = list(map(float,y)) 50 | prob_max1 = max(prob1) 51 | else: 52 | continue 53 | temps = [] 54 | 55 | # for j in range(len(subs)): 56 | # temp1 = subs[j].split(',')[0]+ ','+subs[j].split(',')[1]+',' + str(prob1[j]) 57 | # temps.append(temp1) 58 | 59 | if prob_max1 > prob_max and prob_max1 > 0.7 : 60 | pq = pq + 1 61 | for i in range(len(subs)): 62 | # prob1[i] = prob1[i]/1.000001 63 | temp1 = subs[i].split(',')[0]+ ','+subs[i].split(',')[1]+',' + str(prob1[i]) 64 | temps.append(temp1) 65 | else: 66 | for i in range(len(subs)): 67 | temp1 = subs[i].split(',')[0]+ ','+subs[i].split(',')[1]+',' + str(prob[i]) 68 | temps.append(temp1) 69 | 70 | all_content.extend(temps) 71 | else: 72 | all_content.extend(subs) 73 | 74 | with open('out_b_19_22.csv','w') as f1: 75 | for row in all_content: 76 | f1.writelines(row+'\n') 77 | 78 | 79 | 80 | 81 | 82 | -------------------------------------------------------------------------------- /pig_vgg16/read_tfrecords.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Fri Dec 1 22:01:56 2017 4 | 5 | @author: DELL 6 | """ 7 | 8 | from __future__ import absolute_import 9 | from __future__ import division 10 | from __future__ import print_function 11 | 12 | import os.path 13 | import tensorflow as tf 14 | 15 | import config 16 | 17 | RECORD_DIR = config.RECORD_DIR 18 | TRAIN_FILE = config.TRAIN_FILE 19 | VALID_FILE = config.VALID_FILE 20 | 21 | IMAGE_WIDTH = config.IMAGE_WIDTH 22 | IMAGE_HEIGHT = config.IMAGE_HEIGHT 23 | CLASSES_NUM = config.CLASSES_NUM 24 | 25 | def read_and_decode(filename_queue): 26 | reader = tf.TFRecordReader() 27 | _, serialized_example = reader.read(filename_queue) 28 | features = tf.parse_single_example( 29 | serialized_example, 30 | features={ 31 | 'image_raw': tf.FixedLenFeature([], tf.string), 32 | 'label_raw': tf.FixedLenFeature([], tf.string), 33 | }) 34 | image = tf.decode_raw(features['image_raw'], tf.int16) 35 | image.set_shape([IMAGE_HEIGHT * IMAGE_WIDTH * 3]) 36 | 37 | image = tf.cast(image, tf.float32) * (1. 
/ 127.5) - 1 #(-1,1) 38 | reshape_image = tf.reshape(image, [IMAGE_HEIGHT, IMAGE_WIDTH, 3]) 39 | 40 | reshape_image = tf.image.random_flip_left_right(reshape_image) 41 | 42 | label = tf.decode_raw(features['label_raw'], tf.uint8) 43 | reshape_label = tf.reshape(label, [CLASSES_NUM])#(30,) 44 | return tf.cast(reshape_image, tf.float32), tf.cast(reshape_label, tf.float32) 45 | 46 | 47 | def inputs(train, batch_size): 48 | filename = os.path.join(RECORD_DIR, 49 | TRAIN_FILE if train else VALID_FILE) 50 | 51 | with tf.name_scope('input'): 52 | filename_queue = tf.train.string_input_producer([filename]) 53 | image, label = read_and_decode(filename_queue) 54 | if train: 55 | images, sparse_labels = tf.train.shuffle_batch([image, label], 56 | batch_size=batch_size, 57 | num_threads=6, 58 | capacity=2000 + 3 * batch_size, 59 | min_after_dequeue=2000) 60 | else: 61 | images, sparse_labels = tf.train.batch([image, label], 62 | batch_size=batch_size, 63 | num_threads=6, 64 | capacity=2000 + 3 * batch_size) 65 | 66 | return images, sparse_labels 67 | -------------------------------------------------------------------------------- /nets/nets_factory_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ==============================================================================
15 | 
16 | """Tests for slim.inception."""
17 | 
18 | from __future__ import absolute_import
19 | from __future__ import division
20 | from __future__ import print_function
21 | 
22 | 
23 | import tensorflow as tf
24 | 
25 | from nets import nets_factory
26 | 
27 | 
28 | class NetworksTest(tf.test.TestCase):
29 | 
30 |   def testGetNetworkFnFirstHalf(self):
31 |     batch_size = 5
32 |     num_classes = 1000
33 |     # Wrap keys() in list() so slicing also works under Python 3.
34 |     for net in list(nets_factory.networks_map.keys())[:10]:
35 |       with tf.Graph().as_default() as g, self.test_session(g):
36 |         net_fn = nets_factory.get_network_fn(net, num_classes)
37 |         # Most networks use 224 as their default_image_size
38 |         image_size = getattr(net_fn, 'default_image_size', 224)
39 |         inputs = tf.random_uniform((batch_size, image_size, image_size, 3))
40 |         logits, end_points = net_fn(inputs)
41 |         self.assertTrue(isinstance(logits, tf.Tensor))
42 |         self.assertTrue(isinstance(end_points, dict))
43 |         self.assertEqual(logits.get_shape().as_list()[0], batch_size)
44 |         self.assertEqual(logits.get_shape().as_list()[-1], num_classes)
45 | 
46 |   def testGetNetworkFnSecondHalf(self):
47 |     batch_size = 5
48 |     num_classes = 1000
49 |     for net in list(nets_factory.networks_map.keys())[10:]:
50 |       with tf.Graph().as_default() as g, self.test_session(g):
51 |         net_fn = nets_factory.get_network_fn(net, num_classes)
52 |         # Most networks use 224 as their default_image_size
53 |         image_size = getattr(net_fn, 'default_image_size', 224)
54 |         inputs = tf.random_uniform((batch_size, image_size, image_size, 3))
55 |         logits, end_points = net_fn(inputs)
56 |         self.assertTrue(isinstance(logits, tf.Tensor))
57 |         self.assertTrue(isinstance(end_points, dict))
58 |         self.assertEqual(logits.get_shape().as_list()[0], batch_size)
59 |         self.assertEqual(logits.get_shape().as_list()[-1], num_classes)
60 | 
61 | if __name__ == '__main__':
62 |   tf.test.main()
63 | 
--------------------------------------------------------------------------------
/download_and_convert_data.py:
--------------------------------------------------------------------------------
1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | r"""Downloads and converts a particular dataset.
16 | 17 | Usage: 18 | ```shell 19 | 20 | $ python download_and_convert_data.py \ 21 | --dataset_name=mnist \ 22 | --dataset_dir=/tmp/mnist 23 | 24 | $ python download_and_convert_data.py \ 25 | --dataset_name=cifar10 \ 26 | --dataset_dir=/tmp/cifar10 27 | 28 | $ python download_and_convert_data.py \ 29 | --dataset_name=flowers \ 30 | --dataset_dir=/tmp/flowers 31 | ``` 32 | """ 33 | from __future__ import absolute_import 34 | from __future__ import division 35 | from __future__ import print_function 36 | 37 | import tensorflow as tf 38 | 39 | from datasets import download_and_convert_cifar10 40 | from datasets import download_and_convert_flowers 41 | from datasets import download_and_convert_mnist 42 | 43 | FLAGS = tf.app.flags.FLAGS 44 | 45 | tf.app.flags.DEFINE_string( 46 | 'dataset_name', 47 | None, 48 | 'The name of the dataset to convert, one of "cifar10", "flowers", "mnist".') 49 | 50 | tf.app.flags.DEFINE_string( 51 | 'dataset_dir', 52 | None, 53 | 'The directory where the output TFRecords and temporary files are saved.') 54 | 55 | 56 | def main(_): 57 | FLAGS.dataset_name='cifar10' 58 | FLAGS.dataset_dir='D:/pig_recognize/models/research/slim/cifar10' 59 | if not FLAGS.dataset_name: 60 | raise ValueError('You must supply the dataset name with --dataset_name') 61 | if not FLAGS.dataset_dir: 62 | raise ValueError('You must supply the dataset directory with --dataset_dir') 63 | 64 | if FLAGS.dataset_name == 'cifar10': 65 | download_and_convert_cifar10.run(FLAGS.dataset_dir) 66 | elif FLAGS.dataset_name == 'flowers': 67 | download_and_convert_flowers.run(FLAGS.dataset_dir) 68 | elif FLAGS.dataset_name == 'mnist': 69 | download_and_convert_mnist.run(FLAGS.dataset_dir) 70 | else: 71 | raise ValueError( 72 | 'dataset_name [%s] was not recognized.' % FLAGS.dataset_name) 73 | 74 | if __name__ == '__main__': 75 | tf.app.run() 76 | -------------------------------------------------------------------------------- /pig_vgg16/batch_recognize.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import argparse 6 | import sys 7 | import os.path 8 | from datetime import datetime 9 | from PIL import Image 10 | import numpy as np 11 | 12 | import tensorflow as tf 13 | from tensorflow.python.platform import gfile 14 | import pig_model 15 | import csv 16 | import config 17 | 18 | IMAGE_WIDTH = config.IMAGE_WIDTH 19 | IMAGE_HEIGHT = config.IMAGE_HEIGHT 20 | CLASSES_NUM = config.CLASSES_NUM 21 | 22 | FLAGS = None 23 | Batch_size = 1 24 | 25 | def input_data(image_dir): 26 | if not gfile.Exists(image_dir): 27 | print(">> Image director '" + image_dir + "' not found.") 28 | return None 29 | 30 | print(">> Looking for images in '" + image_dir + "'") 31 | 32 | 33 | file_glob = os.path.join(image_dir, '*.JPG' ) 34 | file_list = gfile.Glob(file_glob) 35 | if not file_list: 36 | print(">> No files found in '" + image_dir + "'") 37 | return None 38 | file_list = sorted(file_list) 39 | all_files = len(file_list) 40 | images = np.zeros([all_files, IMAGE_HEIGHT*IMAGE_WIDTH*3], dtype='float32') 41 | files = [] 42 | i = 0 43 | for file_name in file_list: 44 | image = Image.open(file_name) 45 | image_resize = image.resize(size=(IMAGE_WIDTH,IMAGE_HEIGHT)) 46 | input_img = np.array(image_resize, dtype='float32') 47 | input_img = input_img.flatten()/127.5 - 1 48 | images[i,:] = input_img 49 | base_name = os.path.basename(file_name) 50 | files.append(base_name) 51 
| i += 1
52 |     return images, files
53 | 
54 | parser = argparse.ArgumentParser()
55 | parser.add_argument(
56 |     '--checkpoint_dir',
57 |     type=str,
58 |     default='checkpoint',
59 |     help='Directory where to restore checkpoint.'
60 | )
61 | parser.add_argument(
62 |     '--test_dir',
63 |     type=str,
64 |     default='pig_body/body_test',
65 |     help='Directory where to get the test images.'
66 | )
67 | FLAGS, unparsed = parser.parse_known_args()
68 | 
69 | with tf.Graph().as_default():
70 |     input_images, input_filenames = input_data(FLAGS.test_dir)  # all test images and their file names
71 |     max_step = len(input_images)
72 |     images = tf.placeholder(tf.float32,[IMAGE_HEIGHT*IMAGE_WIDTH*3],name ='input')
73 |     logits = pig_model.inference(images, keep_prob=1,is_training=True)
74 |     output = pig_model.predict(logits)
75 |     saver = tf.train.Saver()
76 |     sess = tf.Session()
77 |     saver.restore(sess, tf.train.latest_checkpoint(FLAGS.checkpoint_dir))
78 |     tag = []
79 |     for each in range(max_step):
80 |         feed_dict = input_images[each]
81 |         recog_result = sess.run(output,feed_dict={images:feed_dict})
82 |         tag.append(list(recog_result))
83 |     np.save('predict_body',tag)
84 |     sess.close()
--------------------------------------------------------------------------------
/nets/inception_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Contains common code shared by all inception models.
16 | 
17 | Usage of arg scope:
18 |   with slim.arg_scope(inception_arg_scope()):
19 |     logits, end_points = inception.inception_v3(images, num_classes,
20 |                                                 is_training=is_training)
21 | 
22 | """
23 | from __future__ import absolute_import
24 | from __future__ import division
25 | from __future__ import print_function
26 | 
27 | import tensorflow as tf
28 | 
29 | slim = tf.contrib.slim
30 | 
31 | 
32 | def inception_arg_scope(weight_decay=0.00004,
33 |                         use_batch_norm=True,
34 |                         batch_norm_decay=0.9997,
35 |                         batch_norm_epsilon=0.001,
36 |                         activation_fn=tf.nn.relu):
37 |   """Defines the default arg scope for inception models.
38 | 
39 |   Args:
40 |     weight_decay: The weight decay to use for regularizing the model.
41 |     use_batch_norm: If `True`, batch_norm is applied after each convolution.
42 |     batch_norm_decay: Decay for batch norm moving average.
43 |     batch_norm_epsilon: Small float added to variance to avoid dividing by zero
44 |       in batch norm.
45 |     activation_fn: Activation function for conv2d.
46 | 
47 |   Returns:
48 |     An `arg_scope` to use for the inception models.
49 |   """
50 |   batch_norm_params = {
51 |       # Decay for the moving averages.
52 |       'decay': batch_norm_decay,
53 |       # epsilon to prevent 0s in variance.
54 |       'epsilon': batch_norm_epsilon,
55 |       # collection containing update_ops.
56 |       'updates_collections': tf.GraphKeys.UPDATE_OPS,
57 |       # use fused batch norm if possible.
58 | 'fused': None, 59 | } 60 | if use_batch_norm: 61 | normalizer_fn = slim.batch_norm 62 | normalizer_params = batch_norm_params 63 | else: 64 | normalizer_fn = None 65 | normalizer_params = {} 66 | # Set weight_decay for weights in Conv and FC layers. 67 | with slim.arg_scope([slim.conv2d, slim.fully_connected], 68 | weights_regularizer=slim.l2_regularizer(weight_decay)): 69 | with slim.arg_scope( 70 | [slim.conv2d], 71 | weights_initializer=slim.variance_scaling_initializer(), 72 | activation_fn=activation_fn, 73 | normalizer_fn=normalizer_fn, 74 | normalizer_params=normalizer_params) as sc: 75 | return sc 76 | -------------------------------------------------------------------------------- /preprocessing/preprocessing_factory.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Contains a factory for building various models.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import tensorflow as tf 22 | 23 | from preprocessing import cifarnet_preprocessing 24 | from preprocessing import inception_preprocessing 25 | from preprocessing import lenet_preprocessing 26 | from preprocessing import vgg_preprocessing 27 | 28 | slim = tf.contrib.slim 29 | 30 | 31 | def get_preprocessing(name, is_training=False): 32 | """Returns preprocessing_fn(image, height, width, **kwargs). 33 | 34 | Args: 35 | name: The name of the preprocessing function. 36 | is_training: `True` if the model is being used for training and `False` 37 | otherwise. 38 | 39 | Returns: 40 | preprocessing_fn: A function that preprocessing a single image (pre-batch). 41 | It has the following signature: 42 | image = preprocessing_fn(image, output_height, output_width, ...). 43 | 44 | Raises: 45 | ValueError: If Preprocessing `name` is not recognized. 
46 | """ 47 | preprocessing_fn_map = { 48 | 'cifarnet': cifarnet_preprocessing, 49 | 'inception': inception_preprocessing, 50 | 'inception_v1': inception_preprocessing, 51 | 'inception_v2': inception_preprocessing, 52 | 'inception_v3': inception_preprocessing, 53 | 'inception_v4': inception_preprocessing, 54 | 'inception_resnet_v2': inception_preprocessing, 55 | 'lenet': lenet_preprocessing, 56 | 'mobilenet_v1': inception_preprocessing, 57 | 'nasnet_mobile': inception_preprocessing, 58 | 'nasnet_large': inception_preprocessing, 59 | 'resnet_v1_50': vgg_preprocessing, 60 | 'resnet_v1_101': vgg_preprocessing, 61 | 'resnet_v1_152': vgg_preprocessing, 62 | 'resnet_v1_200': vgg_preprocessing, 63 | 'resnet_v2_50': vgg_preprocessing, 64 | 'resnet_v2_101': vgg_preprocessing, 65 | 'resnet_v2_152': vgg_preprocessing, 66 | 'resnet_v2_200': vgg_preprocessing, 67 | 'vgg': vgg_preprocessing, 68 | 'vgg_a': vgg_preprocessing, 69 | 'vgg_16': vgg_preprocessing, 70 | 'vgg_19': vgg_preprocessing, 71 | } 72 | 73 | if name not in preprocessing_fn_map: 74 | raise ValueError('Preprocessing name [%s] was not recognized' % name) 75 | 76 | def preprocessing_fn(image, output_height, output_width, **kwargs): 77 | return preprocessing_fn_map[name].preprocess_image( 78 | image, output_height, output_width, is_training=is_training, **kwargs) 79 | 80 | return preprocessing_fn 81 | -------------------------------------------------------------------------------- /datasets/preprocess_imagenet_validation_data.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # Copyright 2016 Google Inc. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ============================================================================== 16 | """Process the ImageNet Challenge bounding boxes for TensorFlow model training. 17 | 18 | Associate the ImageNet 2012 Challenge validation data set with labels. 19 | 20 | The raw ImageNet validation data set is expected to reside in JPEG files 21 | located in the following directory structure. 22 | 23 | data_dir/ILSVRC2012_val_00000001.JPEG 24 | data_dir/ILSVRC2012_val_00000002.JPEG 25 | ... 26 | data_dir/ILSVRC2012_val_00050000.JPEG 27 | 28 | This script moves the files into a directory structure like such: 29 | data_dir/n01440764/ILSVRC2012_val_00000293.JPEG 30 | data_dir/n01440764/ILSVRC2012_val_00000543.JPEG 31 | ... 32 | where 'n01440764' is the unique synset label associated with 33 | these images. 34 | 35 | This directory reorganization requires a mapping from validation image 36 | number (i.e. suffix of the original file) to the associated label. This 37 | is provided in the ImageNet development kit via a Matlab file. 38 | 39 | In order to make life easier and divorce ourselves from Matlab, we instead 40 | supply a custom text file that provides this mapping for us. 
41 | 
42 | Sample usage:
43 |   ./preprocess_imagenet_validation_data.py ILSVRC2012_img_val \
44 |   imagenet_2012_validation_synset_labels.txt
45 | """
46 | 
47 | from __future__ import absolute_import
48 | from __future__ import division
49 | from __future__ import print_function
50 | 
51 | import os
52 | import os.path
53 | import sys
54 | 
55 | 
56 | if __name__ == '__main__':
57 |   if len(sys.argv) < 3:
58 |     print('Invalid usage\n'
59 |           'usage: preprocess_imagenet_validation_data.py '
60 |           '<validation data dir> <validation labels file>')
61 |     sys.exit(-1)
62 |   data_dir = sys.argv[1]
63 |   validation_labels_file = sys.argv[2]
64 | 
65 |   # Read in the 50000 synsets associated with the validation data set.
66 |   labels = [l.strip() for l in open(validation_labels_file).readlines()]
67 |   unique_labels = set(labels)
68 | 
69 |   # Make all sub-directories in the validation data dir.
70 |   for label in unique_labels:
71 |     labeled_data_dir = os.path.join(data_dir, label)
72 |     os.makedirs(labeled_data_dir)
73 | 
74 |   # Move all of the images to the appropriate sub-directory.
75 |   for i in xrange(len(labels)):
76 |     basename = 'ILSVRC2012_val_000%.5d.JPEG' % (i + 1)
77 |     original_filename = os.path.join(data_dir, basename)
78 |     if not os.path.exists(original_filename):
79 |       print('Failed to find: %s' % original_filename)
80 |       sys.exit(-1)
81 |     new_filename = os.path.join(data_dir, labels[i], basename)
82 |     os.rename(original_filename, new_filename)
83 | 
--------------------------------------------------------------------------------
/pig_vgg16/track.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | """
4 | Created on Tue Nov 21 16:44:45 2017
5 | 
6 | @author: no1
7 | """
8 | 
9 | import cv2
10 | import sys
11 | 
12 | major_ver, minor_ver, subminor_ver = cv2.__version__.split('.')
13 | 
14 | if __name__ == '__main__' :
15 | 
16 |     # Set up tracker.
17 |     # Instead of MIL, you can also use
18 | 
19 |     tracker_types = ['BOOSTING', 'MIL','KCF', 'TLD', 'MEDIANFLOW', 'GOTURN']
20 |     tracker_type = tracker_types[2]
21 | 
22 |     if int(minor_ver) < 3:
23 |         tracker = cv2.Tracker_create(tracker_type)
24 |     else:
25 |         if tracker_type == 'BOOSTING':
26 |             tracker = cv2.TrackerBoosting_create()
27 |         if tracker_type == 'MIL':
28 |             tracker = cv2.TrackerMIL_create()
29 |         if tracker_type == 'KCF':
30 |             tracker = cv2.TrackerKCF_create()
31 |         if tracker_type == 'TLD':
32 |             tracker = cv2.TrackerTLD_create()
33 |         if tracker_type == 'MEDIANFLOW':
34 |             tracker = cv2.TrackerMedianFlow_create()
35 |         if tracker_type == 'GOTURN':
36 |             tracker = cv2.TrackerGOTURN_create()
37 | 
38 |     # Read video
39 |     video = cv2.VideoCapture(0)
40 | 
41 |     # Exit if video not opened.
42 |     if not video.isOpened():
43 |         print ("Could not open video")
44 |         sys.exit()
45 | 
46 |     # Read first frame.
47 | ok, frame = video.read() 48 | if not ok: 49 | print ('Cannot read video file') 50 | sys.exit() 51 | 52 | # Define an initial bounding box 53 | # bbox = (287, 23, 86, 320) 54 | 55 | # Uncomment the line below to select a different bounding box 56 | bbox = cv2.selectROI(frame, False) 57 | 58 | # Initialize tracker with first frame and bounding box 59 | ok = tracker.init(frame, bbox) 60 | 61 | while True: 62 | # Read a new frame 63 | ok, frame = video.read() 64 | if not ok: 65 | break 66 | 67 | # Start timer 68 | timer = cv2.getTickCount() 69 | 70 | # Update tracker 71 | ok, bbox = tracker.update(frame) 72 | 73 | # Calculate Frames per second (FPS) 74 | fps = cv2.getTickFrequency() / (cv2.getTickCount() - timer); 75 | 76 | # Draw bounding box 77 | if ok: 78 | # Tracking success 79 | p1 = (int(bbox[0]), int(bbox[1])) 80 | p2 = (int(bbox[0] + bbox[2]), int(bbox[1] + bbox[3])) 81 | cv2.rectangle(frame, p1, p2, (255,0,0), 2) 82 | print(p1,p2) 83 | else : 84 | # Tracking failure 85 | cv2.putText(frame, "Tracking failure detected", (100,80), cv2.FONT_HERSHEY_SIMPLEX, 0.75,(0,0,255),2) 86 | 87 | # Display tracker type on frame 88 | cv2.putText(frame, tracker_type + " Tracker", (100,20), cv2.FONT_HERSHEY_SIMPLEX, 0.75, (50,170,50),2); 89 | 90 | # Display FPS on frame 91 | cv2.putText(frame, "FPS : " + str(int(fps)), (100,50), cv2.FONT_HERSHEY_SIMPLEX, 0.75, (50,170,50), 2); 92 | 93 | # Display result 94 | cv2.imshow("Tracking", frame) 95 | 96 | # Exit if ESC pressed 97 | k = cv2.waitKey(1) & 0xff 98 | if k == 27 : break 99 | 100 | 101 | 102 | 103 | 104 | 105 | 106 | 107 | 108 | 109 | 110 | 111 | 112 | 113 | 114 | 115 | 116 | 117 | 118 | 119 | 120 | -------------------------------------------------------------------------------- /pig_vgg16/single_picture_recognize.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import argparse 6 | import sys 7 | import os.path 8 | from datetime import datetime 9 | from PIL import Image 10 | import numpy as np 11 | 12 | import tensorflow as tf 13 | from tensorflow.python.platform import gfile 14 | import pig_modela as pig_model 15 | import csv 16 | import config 17 | 18 | IMAGE_WIDTH = config.IMAGE_WIDTH 19 | IMAGE_HEIGHT = config.IMAGE_HEIGHT 20 | 21 | CLASSES_NUM = config.CLASSES_NUM 22 | 23 | 24 | FLAGS = None 25 | Batch_size = 1 26 | 27 | def input_data(image_dir): 28 | if not gfile.Exists(image_dir): 29 | print(">> Image director '" + image_dir + "' not found.") 30 | return None 31 | 32 | print(">> Looking for images in '" + image_dir + "'") 33 | 34 | 35 | file_glob = os.path.join(image_dir, '*.jpg' ) 36 | file_list = gfile.Glob(file_glob) 37 | if not file_list: 38 | print(">> No files found in '" + image_dir + "'") 39 | return None 40 | file_list = sorted(file_list) 41 | all_files = len(file_list) 42 | images = np.zeros([all_files, IMAGE_HEIGHT*IMAGE_WIDTH*3], dtype='float32') 43 | files = [] 44 | i = 0 45 | for file_name in file_list: 46 | image = Image.open(file_name) 47 | image_resize = image.resize(size=(IMAGE_WIDTH,IMAGE_HEIGHT)) 48 | input_img = np.array(image_resize, dtype='float32') 49 | input_img = input_img.flatten()/127.5 - 1 50 | images[i,:] = input_img 51 | base_name = os.path.basename(file_name) 52 | files.append(base_name) 53 | i += 1 54 | return images, files 55 | 56 | 57 | def run_predict(): 58 | with tf.Graph().as_default(): 59 | input_images, input_filenames = 
input_data(FLAGS.test_dir)  # all test images and their file names
60 |         max_step = len(input_images)
61 |         images = tf.placeholder(tf.float32,[IMAGE_HEIGHT*IMAGE_WIDTH*3],name ='input')
62 |         logits = pig_model.inference(images, keep_prob=1,is_training=True)
63 |         output = pig_model.output(logits)
64 |         saver = tf.train.Saver()
65 |         sess = tf.Session()
66 |         saver.restore(sess, tf.train.latest_checkpoint(FLAGS.checkpoint_dir))
67 | #        saver.restore(sess, 'checkpoint/model.ckpt-3084')
68 |         tag = []
69 |         for each in range(max_step):
70 |             feed_dict = input_images[each]
71 |             recog_result = sess.run(output,feed_dict={images:feed_dict})
72 |             recog_result = np.squeeze(recog_result)
73 |             current_name = str(int(input_filenames[each].split('.')[0]))
74 |             for i in range(len(recog_result)):
75 |                 tag.append([current_name,str(i+1),str('%.8f'%(recog_result[i]))])
76 |         with open('out.csv','w',newline='') as csvfile:
77 |             writer = csv.writer(csvfile)
78 |             for x in tag:
79 |                 writer.writerow(x)
80 |         sess.close()
81 | 
82 | def main(_):
83 |     run_predict()
84 | 
85 | if __name__ == '__main__':
86 |     parser = argparse.ArgumentParser()
87 |     parser.add_argument(
88 |         '--checkpoint_dir',
89 |         type=str,
90 |         default='checkpoint',
91 |         help='Directory where to restore checkpoint.'
92 |     )
93 |     parser.add_argument(
94 |         '--test_dir',
95 |         type=str,
96 |         default='pig_body/body_test',
97 |         help='Directory where to get the test images.'
98 |     )
99 |     FLAGS, unparsed = parser.parse_known_args()
100 |     tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
101 | 
--------------------------------------------------------------------------------
/datasets/flowers.py:
--------------------------------------------------------------------------------
1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | """Provides data for the flowers dataset.
16 | 
17 | The dataset scripts used to create the dataset can be found at:
18 | tensorflow/models/research/slim/datasets/download_and_convert_flowers.py
19 | """
20 | 
21 | from __future__ import absolute_import
22 | from __future__ import division
23 | from __future__ import print_function
24 | 
25 | import os
26 | import tensorflow as tf
27 | 
28 | from datasets import dataset_utils
29 | 
30 | slim = tf.contrib.slim
31 | 
32 | _FILE_PATTERN = 'flowers_%s_*.tfrecord'
33 | 
34 | SPLITS_TO_SIZES = {'train': 3320, 'validation': 350}
35 | 
36 | _NUM_CLASSES = 5
37 | 
38 | _ITEMS_TO_DESCRIPTIONS = {
39 |     'image': 'A color image of varying size.',
40 |     'label': 'A single integer between 0 and 4',
41 | }
42 | 
43 | 
44 | def get_split(split_name, dataset_dir, file_pattern=None, reader=None):
45 |   """Gets a dataset tuple with instructions for reading flowers.
46 | 
47 |   Args:
48 |     split_name: A train/validation split name.
49 |     dataset_dir: The base directory of the dataset sources.
50 |     file_pattern: The file pattern to use when matching the dataset sources.
51 | It is assumed that the pattern contains a '%s' string so that the split 52 | name can be inserted. 53 | reader: The TensorFlow reader type. 54 | 55 | Returns: 56 | A `Dataset` namedtuple. 57 | 58 | Raises: 59 | ValueError: if `split_name` is not a valid train/validation split. 60 | """ 61 | if split_name not in SPLITS_TO_SIZES: 62 | raise ValueError('split name %s was not recognized.' % split_name) 63 | 64 | if not file_pattern: 65 | file_pattern = _FILE_PATTERN 66 | file_pattern = os.path.join(dataset_dir, file_pattern % split_name) 67 | 68 | # Allowing None in the signature so that dataset_factory can use the default. 69 | if reader is None: 70 | reader = tf.TFRecordReader 71 | 72 | keys_to_features = { 73 | 'image/encoded': tf.FixedLenFeature((), tf.string, default_value=''), 74 | 'image/format': tf.FixedLenFeature((), tf.string, default_value='png'), 75 | 'image/class/label': tf.FixedLenFeature( 76 | [], tf.int64, default_value=tf.zeros([], dtype=tf.int64)), 77 | } 78 | 79 | items_to_handlers = { 80 | 'image': slim.tfexample_decoder.Image(), 81 | 'label': slim.tfexample_decoder.Tensor('image/class/label'), 82 | } 83 | 84 | decoder = slim.tfexample_decoder.TFExampleDecoder( 85 | keys_to_features, items_to_handlers) 86 | 87 | labels_to_names = None 88 | if dataset_utils.has_labels(dataset_dir): 89 | labels_to_names = dataset_utils.read_label_file(dataset_dir) 90 | 91 | return slim.dataset.Dataset( 92 | data_sources=file_pattern, 93 | reader=reader, 94 | decoder=decoder, 95 | num_samples=SPLITS_TO_SIZES[split_name], 96 | items_to_descriptions=_ITEMS_TO_DESCRIPTIONS, 97 | num_classes=_NUM_CLASSES, 98 | labels_to_names=labels_to_names) 99 | -------------------------------------------------------------------------------- /datasets/mnist.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Provides data for the MNIST dataset. 16 | 17 | The dataset scripts used to create the dataset can be found at: 18 | tensorflow/models/research/slim/datasets/download_and_convert_mnist.py 19 | """ 20 | 21 | from __future__ import absolute_import 22 | from __future__ import division 23 | from __future__ import print_function 24 | 25 | import os 26 | import tensorflow as tf 27 | 28 | from datasets import dataset_utils 29 | 30 | slim = tf.contrib.slim 31 | 32 | _FILE_PATTERN = 'mnist_%s.tfrecord' 33 | 34 | _SPLITS_TO_SIZES = {'train': 60000, 'test': 10000} 35 | 36 | _NUM_CLASSES = 10 37 | 38 | _ITEMS_TO_DESCRIPTIONS = { 39 | 'image': 'A [28 x 28 x 1] grayscale image.', 40 | 'label': 'A single integer between 0 and 9', 41 | } 42 | 43 | 44 | def get_split(split_name, dataset_dir, file_pattern=None, reader=None): 45 | """Gets a dataset tuple with instructions for reading MNIST. 46 | 47 | Args: 48 | split_name: A train/test split name. 
49 | dataset_dir: The base directory of the dataset sources. 50 | file_pattern: The file pattern to use when matching the dataset sources. 51 | It is assumed that the pattern contains a '%s' string so that the split 52 | name can be inserted. 53 | reader: The TensorFlow reader type. 54 | 55 | Returns: 56 | A `Dataset` namedtuple. 57 | 58 | Raises: 59 | ValueError: if `split_name` is not a valid train/test split. 60 | """ 61 | if split_name not in _SPLITS_TO_SIZES: 62 | raise ValueError('split name %s was not recognized.' % split_name) 63 | 64 | if not file_pattern: 65 | file_pattern = _FILE_PATTERN 66 | file_pattern = os.path.join(dataset_dir, file_pattern % split_name) 67 | 68 | # Allowing None in the signature so that dataset_factory can use the default. 69 | if reader is None: 70 | reader = tf.TFRecordReader 71 | 72 | keys_to_features = { 73 | 'image/encoded': tf.FixedLenFeature((), tf.string, default_value=''), 74 | 'image/format': tf.FixedLenFeature((), tf.string, default_value='raw'), 75 | 'image/class/label': tf.FixedLenFeature( 76 | [1], tf.int64, default_value=tf.zeros([1], dtype=tf.int64)), 77 | } 78 | 79 | items_to_handlers = { 80 | 'image': slim.tfexample_decoder.Image(shape=[28, 28, 1], channels=1), 81 | 'label': slim.tfexample_decoder.Tensor('image/class/label', shape=[]), 82 | } 83 | 84 | decoder = slim.tfexample_decoder.TFExampleDecoder( 85 | keys_to_features, items_to_handlers) 86 | 87 | labels_to_names = None 88 | if dataset_utils.has_labels(dataset_dir): 89 | labels_to_names = dataset_utils.read_label_file(dataset_dir) 90 | 91 | return slim.dataset.Dataset( 92 | data_sources=file_pattern, 93 | reader=reader, 94 | decoder=decoder, 95 | num_samples=_SPLITS_TO_SIZES[split_name], 96 | num_classes=_NUM_CLASSES, 97 | items_to_descriptions=_ITEMS_TO_DESCRIPTIONS, 98 | labels_to_names=labels_to_names) 99 | -------------------------------------------------------------------------------- /datasets/cifar10.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Provides data for the Cifar10 dataset. 
16 | 17 | The dataset scripts used to create the dataset can be found at: 18 | tensorflow/models/research/slim/datasets/download_and_convert_cifar10.py 19 | """ 20 | 21 | from __future__ import absolute_import 22 | from __future__ import division 23 | from __future__ import print_function 24 | 25 | import os 26 | import tensorflow as tf 27 | 28 | from datasets import dataset_utils 29 | 30 | slim = tf.contrib.slim 31 | 32 | _FILE_PATTERN = 'cifar10_%s.tfrecord' 33 | 34 | SPLITS_TO_SIZES = {'train': 12770, 'test': 200} 35 | 36 | _NUM_CLASSES = 30 37 | 38 | _ITEMS_TO_DESCRIPTIONS = { 39 | 'image': 'A color image of varying size.', 40 | 'label': 'A single integer between 0 and 29', 41 | } 42 | 43 | 44 | def get_split(split_name, dataset_dir, file_pattern=None, reader=None): 45 | """Gets a dataset tuple with instructions for reading cifar10. 46 | 47 | Args: 48 | split_name: A train/test split name. 49 | dataset_dir: The base directory of the dataset sources. 50 | file_pattern: The file pattern to use when matching the dataset sources. 51 | It is assumed that the pattern contains a '%s' string so that the split 52 | name can be inserted. 53 | reader: The TensorFlow reader type. 54 | 55 | Returns: 56 | A `Dataset` namedtuple. 57 | 58 | Raises: 59 | ValueError: if `split_name` is not a valid train/test split. 60 | """ 61 | if split_name not in SPLITS_TO_SIZES: 62 | raise ValueError('split name %s was not recognized.' % split_name) 63 | 64 | if not file_pattern: 65 | file_pattern = _FILE_PATTERN 66 | file_pattern = os.path.join(dataset_dir, file_pattern % split_name) 67 | 68 | # Allowing None in the signature so that dataset_factory can use the default. 69 | if not reader: 70 | reader = tf.TFRecordReader 71 | 72 | keys_to_features = { 73 | 'image/encoded': tf.FixedLenFeature((), tf.string, default_value=''), 74 | 'image/format': tf.FixedLenFeature((), tf.string, default_value='png'), 75 | 'image/class/label': tf.FixedLenFeature( 76 | [], tf.int64, default_value=tf.zeros([], dtype=tf.int64)), 77 | } 78 | 79 | items_to_handlers = { 80 | # 'image': slim.tfexample_decoder.Image(shape=[299, 299, 3]), 81 | 'image': slim.tfexample_decoder.Image(), 82 | 'label': slim.tfexample_decoder.Tensor('image/class/label'), 83 | } 84 | 85 | decoder = slim.tfexample_decoder.TFExampleDecoder( 86 | keys_to_features, items_to_handlers) 87 | 88 | labels_to_names = None 89 | if dataset_utils.has_labels(dataset_dir): 90 | labels_to_names = dataset_utils.read_label_file(dataset_dir) 91 | 92 | return slim.dataset.Dataset( 93 | data_sources=file_pattern, 94 | reader=reader, 95 | decoder=decoder, 96 | num_samples=SPLITS_TO_SIZES[split_name], 97 | items_to_descriptions=_ITEMS_TO_DESCRIPTIONS, 98 | num_classes=_NUM_CLASSES, 99 | labels_to_names=labels_to_names) 100 | -------------------------------------------------------------------------------- /pig_vgg16/single_picture_recognize_r+p.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import argparse 6 | import sys 7 | import os.path 8 | from datetime import datetime 9 | 10 | import numpy as np 11 | import cv2 12 | import tensorflow as tf 13 | from tensorflow.python.platform import gfile 14 | import pig_model_dark as pig_model 15 | import csv 16 | import config 17 | 18 | IMAGE_WIDTH = config.IMAGE_WIDTH 19 | IMAGE_HEIGHT = config.IMAGE_HEIGHT 20 | 21 | CLASSES_NUM = config.CLASSES_NUM 22 | 23 | FLAGS = None
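# (Illustrative sketch of the padding step inside input_data() below: a
#  non-square frame is zero-padded to a square canvas before resizing, so the
#  pig's aspect ratio is preserved. For example, h, w = 400, 300 gives
#  offset = (400 - 300) // 2 = 50, and
#      np.pad(img, ((0, 0), (50, 50), (0, 0)), mode='constant')
#  yields a 400 x 400 canvas. When |h - w| is odd the padded frame is one
#  pixel short of square; the cv2.resize() that follows absorbs the
#  difference.)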
24 | Batch_size = 1 25 | 26 | def input_data(image_dir): 27 | if not gfile.Exists(image_dir): 28 | print(">> Image directory '" + image_dir + "' not found.") 29 | return None 30 | 31 | print(">> Looking for images in '" + image_dir + "'") 32 | 33 | 34 | file_glob = os.path.join(image_dir, '*.jpg') 35 | file_list = gfile.Glob(file_glob) 36 | if not file_list: 37 | print(">> No files found in '" + image_dir + "'") 38 | return None 39 | file_list = sorted(file_list) 40 | all_files = len(file_list) 41 | images = np.zeros([all_files, IMAGE_HEIGHT*IMAGE_WIDTH*3], dtype='float32') 42 | files = [] 43 | i = 0 44 | for file_name in file_list: 45 | image = cv2.imread(file_name) 46 | input_img = np.array(image, dtype='float32') 47 | height, width, _ = input_img.shape 48 | offset = abs(height - width)//2 49 | if height >= width: 50 | pad_image = np.pad(input_img, ((0,0),(offset,offset),(0,0)), mode='constant', constant_values=0) 51 | else: 52 | pad_image = np.pad(input_img, ((offset,offset),(0,0),(0,0)), mode='constant', constant_values=0) 53 | 54 | image_resize = cv2.resize(pad_image, (IMAGE_WIDTH,IMAGE_HEIGHT)) 55 | # cv2.imwrite('a.jpg', pad_image) 56 | # cv2.imwrite('b.jpg', image_resize) 57 | 58 | image_resize = image_resize.flatten()/127.5 - 1 59 | images[i,:] = image_resize 60 | base_name = os.path.basename(file_name) 61 | files.append(base_name) 62 | i += 1 63 | return images, files 64 | 65 | 66 | def run_predict(): 67 | with tf.Graph().as_default(): 68 | input_images, input_filenames = input_data(FLAGS.test_dir)  # get every image in the folder and its filename 69 | max_step = len(input_images) 70 | images = tf.placeholder(tf.float32, [IMAGE_HEIGHT*IMAGE_WIDTH*3], name='input') 71 | logits = pig_model.inference(images, keep_prob=1, is_training=False) 72 | output = pig_model.output(logits) 73 | saver = tf.train.Saver() 74 | sess = tf.Session() 75 | saver.restore(sess, tf.train.latest_checkpoint(FLAGS.checkpoint_dir)) 76 | tag = [] 77 | for each in range(max_step): 78 | feed_dict = input_images[each] 79 | recog_result = sess.run(output, feed_dict={images: feed_dict}) 80 | recog_result = np.squeeze(recog_result) 81 | current_name = str(int(input_filenames[each].split('.')[0])) 82 | for i in range(len(recog_result)): 83 | tag.append([current_name, str(i+1), str('%.8f'%(recog_result[i]))]) 84 | with open('out.csv','w',newline='') as csvfile: 85 | writer = csv.writer(csvfile) 86 | for x in tag: 87 | writer.writerow(x) 88 | sess.close() 89 | return tag 90 | 91 | def main(_): 92 | tag = run_predict() 93 | 94 | if __name__ == '__main__': 95 | parser = argparse.ArgumentParser() 96 | parser.add_argument( 97 | '--checkpoint_dir', 98 | type=str, 99 | default='checkpoint', 100 | help='Directory where to restore checkpoint.' 101 | ) 102 | parser.add_argument( 103 | '--test_dir', 104 | type=str, 105 | default='pig_body/train_data1', 106 | help='Directory where to get test images.' 107 | ) 108 | FLAGS, unparsed = parser.parse_known_args() 109 | tf.app.run(main=main, argv=[sys.argv[0]] + unparsed) 110 | -------------------------------------------------------------------------------- /scripts/finetune_inception_v1_on_flowers.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ============================================================================== 16 | # 17 | # This script performs the following operations: 18 | # 1. Downloads the Flowers dataset 19 | # 2. Fine-tunes an InceptionV1 model on the Flowers training set. 20 | # 3. Evaluates the model on the Flowers validation set. 21 | # 22 | # Usage: 23 | # cd slim 24 | # ./slim/scripts/finetune_inception_v1_on_flowers.sh 25 | set -e 26 | 27 | # Where the pre-trained InceptionV1 checkpoint is saved to. 28 | PRETRAINED_CHECKPOINT_DIR=/tmp/checkpoints 29 | 30 | # Where the training (fine-tuned) checkpoint and logs will be saved to. 31 | TRAIN_DIR=/tmp/flowers-models/inception_v1 32 | 33 | # Where the dataset is saved to. 34 | DATASET_DIR=/tmp/flowers 35 | 36 | # Download the pre-trained checkpoint. 37 | if [ ! -d "$PRETRAINED_CHECKPOINT_DIR" ]; then 38 | mkdir ${PRETRAINED_CHECKPOINT_DIR} 39 | fi 40 | if [ ! -f ${PRETRAINED_CHECKPOINT_DIR}/inception_v1.ckpt ]; then 41 | wget http://download.tensorflow.org/models/inception_v1_2016_08_28.tar.gz 42 | tar -xvf inception_v1_2016_08_28.tar.gz 43 | mv inception_v1.ckpt ${PRETRAINED_CHECKPOINT_DIR}/inception_v1.ckpt 44 | rm inception_v1_2016_08_28.tar.gz 45 | fi 46 | 47 | # Download the dataset 48 | python download_and_convert_data.py \ 49 | --dataset_name=flowers \ 50 | --dataset_dir=${DATASET_DIR} 51 | 52 | # Fine-tune only the new layers for 3000 steps. 53 | python train_image_classifier.py \ 54 | --train_dir=${TRAIN_DIR} \ 55 | --dataset_name=flowers \ 56 | --dataset_split_name=train \ 57 | --dataset_dir=${DATASET_DIR} \ 58 | --model_name=inception_v1 \ 59 | --checkpoint_path=${PRETRAINED_CHECKPOINT_DIR}/inception_v1.ckpt \ 60 | --checkpoint_exclude_scopes=InceptionV1/Logits \ 61 | --trainable_scopes=InceptionV1/Logits \ 62 | --max_number_of_steps=3000 \ 63 | --batch_size=32 \ 64 | --learning_rate=0.01 \ 65 | --save_interval_secs=60 \ 66 | --save_summaries_secs=60 \ 67 | --log_every_n_steps=100 \ 68 | --optimizer=rmsprop \ 69 | --weight_decay=0.00004 70 | 71 | # Run evaluation. 72 | python eval_image_classifier.py \ 73 | --checkpoint_path=${TRAIN_DIR} \ 74 | --eval_dir=${TRAIN_DIR} \ 75 | --dataset_name=flowers \ 76 | --dataset_split_name=validation \ 77 | --dataset_dir=${DATASET_DIR} \ 78 | --model_name=inception_v1 79 | 80 | # Fine-tune all the layers for 1000 steps. 81 | python train_image_classifier.py \ 82 | --train_dir=${TRAIN_DIR}/all \ 83 | --dataset_name=flowers \ 84 | --dataset_split_name=train \ 85 | --dataset_dir=${DATASET_DIR} \ 86 | --checkpoint_path=${TRAIN_DIR} \ 87 | --model_name=inception_v1 \ 88 | --max_number_of_steps=1000 \ 89 | --batch_size=32 \ 90 | --learning_rate=0.001 \ 91 | --save_interval_secs=60 \ 92 | --save_summaries_secs=60 \ 93 | --log_every_n_steps=100 \ 94 | --optimizer=rmsprop \ 95 | --weight_decay=0.00004 96 | 97 | # Run evaluation.
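# (Note: eval_image_classifier.py restores the most recent checkpoint found
# under --checkpoint_path and reports metrics on the validation split; it
# only reads the weights, it never updates them.)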
98 | python eval_image_classifier.py \ 99 | --checkpoint_path=${TRAIN_DIR}/all \ 100 | --eval_dir=${TRAIN_DIR}/all \ 101 | --dataset_name=flowers \ 102 | --dataset_split_name=validation \ 103 | --dataset_dir=${DATASET_DIR} \ 104 | --model_name=inception_v1 105 | -------------------------------------------------------------------------------- /scripts/finetune_resnet_v1_50_on_flowers.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ============================================================================== 16 | # 17 | # This script performs the following operations: 18 | # 1. Downloads the Flowers dataset 19 | # 2. Fine-tunes a ResNetV1-50 model on the Flowers training set. 20 | # 3. Evaluates the model on the Flowers validation set. 21 | # 22 | # Usage: 23 | # cd slim 24 | # ./slim/scripts/finetune_resnet_v1_50_on_flowers.sh 25 | set -e 26 | 27 | # Where the pre-trained ResNetV1-50 checkpoint is saved to. 28 | PRETRAINED_CHECKPOINT_DIR=/tmp/checkpoints 29 | 30 | # Where the training (fine-tuned) checkpoint and logs will be saved to. 31 | TRAIN_DIR=/tmp/flowers-models/resnet_v1_50 32 | 33 | # Where the dataset is saved to. 34 | DATASET_DIR=/tmp/flowers 35 | 36 | # Download the pre-trained checkpoint. 37 | if [ ! -d "$PRETRAINED_CHECKPOINT_DIR" ]; then 38 | mkdir ${PRETRAINED_CHECKPOINT_DIR} 39 | fi 40 | if [ ! -f ${PRETRAINED_CHECKPOINT_DIR}/resnet_v1_50.ckpt ]; then 41 | wget http://download.tensorflow.org/models/resnet_v1_50_2016_08_28.tar.gz 42 | tar -xvf resnet_v1_50_2016_08_28.tar.gz 43 | mv resnet_v1_50.ckpt ${PRETRAINED_CHECKPOINT_DIR}/resnet_v1_50.ckpt 44 | rm resnet_v1_50_2016_08_28.tar.gz 45 | fi 46 | 47 | # Download the dataset 48 | python download_and_convert_data.py \ 49 | --dataset_name=flowers \ 50 | --dataset_dir=${DATASET_DIR} 51 | 52 | # Fine-tune only the new layers for 3000 steps. 53 | python train_image_classifier.py \ 54 | --train_dir=${TRAIN_DIR} \ 55 | --dataset_name=flowers \ 56 | --dataset_split_name=train \ 57 | --dataset_dir=${DATASET_DIR} \ 58 | --model_name=resnet_v1_50 \ 59 | --checkpoint_path=${PRETRAINED_CHECKPOINT_DIR}/resnet_v1_50.ckpt \ 60 | --checkpoint_exclude_scopes=resnet_v1_50/logits \ 61 | --trainable_scopes=resnet_v1_50/logits \ 62 | --max_number_of_steps=3000 \ 63 | --batch_size=32 \ 64 | --learning_rate=0.01 \ 65 | --save_interval_secs=60 \ 66 | --save_summaries_secs=60 \ 67 | --log_every_n_steps=100 \ 68 | --optimizer=rmsprop \ 69 | --weight_decay=0.00004 70 | 71 | # Run evaluation. 72 | python eval_image_classifier.py \ 73 | --checkpoint_path=${TRAIN_DIR} \ 74 | --eval_dir=${TRAIN_DIR} \ 75 | --dataset_name=flowers \ 76 | --dataset_split_name=validation \ 77 | --dataset_dir=${DATASET_DIR} \ 78 | --model_name=resnet_v1_50 79 | 80 | # Fine-tune all the new layers for 1000 steps. 
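# (Note: this second pass warm-starts from the checkpoint the first pass
# wrote to ${TRAIN_DIR}; with no --trainable_scopes restriction every layer
# is now updated, at the 10x smaller learning rate of 0.001.)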
81 | python train_image_classifier.py \ 82 | --train_dir=${TRAIN_DIR}/all \ 83 | --dataset_name=flowers \ 84 | --dataset_split_name=train \ 85 | --dataset_dir=${DATASET_DIR} \ 86 | --checkpoint_path=${TRAIN_DIR} \ 87 | --model_name=resnet_v1_50 \ 88 | --max_number_of_steps=1000 \ 89 | --batch_size=32 \ 90 | --learning_rate=0.001 \ 91 | --save_interval_secs=60 \ 92 | --save_summaries_secs=60 \ 93 | --log_every_n_steps=100 \ 94 | --optimizer=rmsprop \ 95 | --weight_decay=0.00004 96 | 97 | # Run evaluation. 98 | python eval_image_classifier.py \ 99 | --checkpoint_path=${TRAIN_DIR}/all \ 100 | --eval_dir=${TRAIN_DIR}/all \ 101 | --dataset_name=flowers \ 102 | --dataset_split_name=validation \ 103 | --dataset_dir=${DATASET_DIR} \ 104 | --model_name=resnet_v1_50 105 | -------------------------------------------------------------------------------- /scripts/finetune_inception_v3_on_flowers.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ============================================================================== 16 | # 17 | # This script performs the following operations: 18 | # 1. Downloads the Flowers dataset 19 | # 2. Fine-tunes an InceptionV3 model on the Flowers training set. 20 | # 3. Evaluates the model on the Flowers validation set. 21 | # 22 | # Usage: 23 | # cd slim 24 | # ./slim/scripts/finetune_inception_v3_on_flowers.sh 25 | set -e 26 | 27 | # Where the pre-trained InceptionV3 checkpoint is saved to. 28 | PRETRAINED_CHECKPOINT_DIR=/tmp/checkpoints 29 | 30 | # Where the training (fine-tuned) checkpoint and logs will be saved to. 31 | TRAIN_DIR=/tmp/flowers-models/inception_v3 32 | 33 | # Where the dataset is saved to. 34 | DATASET_DIR=/tmp/flowers 35 | 36 | # Download the pre-trained checkpoint. 37 | if [ ! -d "$PRETRAINED_CHECKPOINT_DIR" ]; then 38 | mkdir ${PRETRAINED_CHECKPOINT_DIR} 39 | fi 40 | if [ ! -f ${PRETRAINED_CHECKPOINT_DIR}/inception_v3.ckpt ]; then 41 | wget http://download.tensorflow.org/models/inception_v3_2016_08_28.tar.gz 42 | tar -xvf inception_v3_2016_08_28.tar.gz 43 | mv inception_v3.ckpt ${PRETRAINED_CHECKPOINT_DIR}/inception_v3.ckpt 44 | rm inception_v3_2016_08_28.tar.gz 45 | fi 46 | 47 | # Download the dataset 48 | python download_and_convert_data.py \ 49 | --dataset_name=flowers \ 50 | --dataset_dir=${DATASET_DIR} 51 | 52 | # Fine-tune only the new layers for 1000 steps. 
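# (Note: --checkpoint_exclude_scopes names the variables NOT restored from
# the pre-trained ImageNet checkpoint, while --trainable_scopes names the
# variables that receive gradient updates; limiting both to the
# Logits/AuxLogits layers keeps the backbone frozen during this first pass.)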
53 | python train_image_classifier.py \ 54 | --train_dir=${TRAIN_DIR} \ 55 | --dataset_name=flowers \ 56 | --dataset_split_name=train \ 57 | --dataset_dir=${DATASET_DIR} \ 58 | --model_name=inception_v3 \ 59 | --checkpoint_path=${PRETRAINED_CHECKPOINT_DIR}/inception_v3.ckpt \ 60 | --checkpoint_exclude_scopes=InceptionV3/Logits,InceptionV3/AuxLogits \ 61 | --trainable_scopes=InceptionV3/Logits,InceptionV3/AuxLogits \ 62 | --max_number_of_steps=1000 \ 63 | --batch_size=32 \ 64 | --learning_rate=0.01 \ 65 | --learning_rate_decay_type=fixed \ 66 | --save_interval_secs=60 \ 67 | --save_summaries_secs=60 \ 68 | --log_every_n_steps=100 \ 69 | --optimizer=rmsprop \ 70 | --weight_decay=0.00004 71 | 72 | # Run evaluation. 73 | python eval_image_classifier.py \ 74 | --checkpoint_path=${TRAIN_DIR} \ 75 | --eval_dir=${TRAIN_DIR} \ 76 | --dataset_name=flowers \ 77 | --dataset_split_name=validation \ 78 | --dataset_dir=${DATASET_DIR} \ 79 | --model_name=inception_v3 80 | 81 | # Fine-tune all the new layers for 500 steps. 82 | python train_image_classifier.py \ 83 | --train_dir=${TRAIN_DIR}/all \ 84 | --dataset_name=flowers \ 85 | --dataset_split_name=train \ 86 | --dataset_dir=${DATASET_DIR} \ 87 | --model_name=inception_v3 \ 88 | --checkpoint_path=${TRAIN_DIR} \ 89 | --max_number_of_steps=500 \ 90 | --batch_size=32 \ 91 | --learning_rate=0.0001 \ 92 | --learning_rate_decay_type=fixed \ 93 | --save_interval_secs=60 \ 94 | --save_summaries_secs=60 \ 95 | --log_every_n_steps=10 \ 96 | --optimizer=rmsprop \ 97 | --weight_decay=0.00004 98 | 99 | # Run evaluation. 100 | python eval_image_classifier.py \ 101 | --checkpoint_path=${TRAIN_DIR}/all \ 102 | --eval_dir=${TRAIN_DIR}/all \ 103 | --dataset_name=flowers \ 104 | --dataset_split_name=validation \ 105 | --dataset_dir=${DATASET_DIR} \ 106 | --model_name=inception_v3 107 | -------------------------------------------------------------------------------- /datasets/download_imagenet.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Copyright 2016 Google Inc. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ============================================================================== 16 | 17 | # Script to download ImageNet Challenge 2012 training and validation data set. 18 | # 19 | # Downloads and decompresses raw images and bounding boxes. 20 | # 21 | # **IMPORTANT** 22 | # To download the raw images, the user must create an account with image-net.org 23 | # and generate a username and access_key. The latter two are required for 24 | # downloading the raw images. 25 | # 26 | # usage: 27 | # ./download_imagenet.sh [dirname] 28 | set -e 29 | 30 | if [ "x$IMAGENET_ACCESS_KEY" == x -o "x$IMAGENET_USERNAME" == x ]; then 31 | cat <"${BOUNDING_BOX_FILE}" 88 | echo "Finished downloading and preprocessing the ImageNet data." 89 | 90 | # Build the TFRecords version of the ImageNet data. 
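# (Note: the step below hands the train/validation directories and the
# labels, metadata and bounding-box files prepared earlier in this script to
# the compiled build_imagenet_data tool, which writes the sharded TFRecord
# files into the output directory.)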
91 | BUILD_SCRIPT="${WORK_DIR}/build_imagenet_data" 92 | OUTPUT_DIRECTORY="${DATA_DIR}" 93 | IMAGENET_METADATA_FILE="${WORK_DIR}/datasets/imagenet_metadata.txt" 94 | 95 | "${BUILD_SCRIPT}" \ 96 | --train_directory="${TRAIN_DIRECTORY}" \ 97 | --validation_directory="${VALIDATION_DIRECTORY}" \ 98 | --output_directory="${OUTPUT_DIRECTORY}" \ 99 | --imagenet_metadata_file="${IMAGENET_METADATA_FILE}" \ 100 | --labels_file="${LABELS_FILE}" \ 101 | --bounding_box_file="${BOUNDING_BOX_FILE}" 102 | -------------------------------------------------------------------------------- /pig_vgg16/pig_model_vgg16.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Fri Nov 24 15:40:02 2017 4 | 5 | @author: DELL 6 | """ 7 | 8 | from __future__ import absolute_import 9 | from __future__ import division 10 | from __future__ import print_function 11 | 12 | import tensorflow as tf 13 | import pig_input 14 | import config 15 | import VGG16 16 | 17 | IMAGE_WIDTH = config.IMAGE_WIDTH 18 | IMAGE_HEIGHT = config.IMAGE_HEIGHT 19 | CLASSES_NUM = config.CLASSES_NUM 20 | rate = 0.01 21 | 22 | def inputs(train, batch_size): 23 | return pig_input.inputs(train, batch_size=batch_size) 24 | 25 | def _conv(name, input, size, input_channels, output_channels, is_training=True): 26 | with tf.variable_scope(name) as scope: 27 | if not is_training: 28 | scope.reuse_variables() 29 | kernel = _weight_variable('weights', shape=[size, size ,input_channels, output_channels]) 30 | biases = _bias_variable('biases',[output_channels]) 31 | pre_activation = tf.nn.bias_add(_conv2d(input, kernel),biases) 32 | conv = tf.maximum(rate*pre_activation,pre_activation, name=scope.name) 33 | return conv 34 | 35 | def _conv2d(value, weight): 36 | """conv2d returns a 2d convolution layer with full stride.""" 37 | return tf.nn.conv2d(value, weight, strides=[1, 1, 1, 1], padding='SAME') 38 | 39 | 40 | def _max_pool_2x2(value, name, is_training): 41 | """max_pool_2x2 downsamples a feature map by 2X.""" 42 | with tf.variable_scope(name) as scope1: 43 | if not is_training: 44 | scope1.reuse_variables() 45 | return tf.nn.max_pool(value, ksize=[1, 2, 2, 1], 46 | strides=[1, 2, 2, 1], padding='SAME', name=name) 47 | 48 | 49 | def _weight_variable(name, shape): 50 | """weight_variable generates a weight variable of a given shape.""" 51 | initializer = tf.truncated_normal_initializer(stddev=0.1) 52 | var = tf.get_variable(name,shape,initializer=initializer, dtype=tf.float32) 53 | return var 54 | 55 | 56 | def _bias_variable(name, shape): 57 | """bias_variable generates a bias variable of a given shape.""" 58 | initializer = tf.constant_initializer(0.1) 59 | var = tf.get_variable(name, shape, initializer=initializer,dtype=tf.float32) 60 | return var 61 | 62 | def _batch_norm(name, inputs, is_training): 63 | """ Batch Normalization 64 | """ 65 | with tf.variable_scope(name, reuse = not is_training): 66 | # return tf.layers.batch_normalization(input,training=is_training) 67 | return tf.contrib.layers.batch_norm(inputs, 68 | decay=0.9, 69 | scale=True, 70 | updates_collections=None, 71 | is_training=True) 72 | def inference(images, keep_prob, is_training): 73 | images = tf.reshape(images, [-1, IMAGE_HEIGHT, IMAGE_WIDTH, 3]) # 256,256,3 74 | 75 | softmax_linear = VGG16.VGG16(images) 76 | return tf.reshape(softmax_linear, [-1, CLASSES_NUM]) 77 | 78 | 79 | def loss(logits, labels): 80 | cross_entropy = tf.nn.softmax_cross_entropy_with_logits( 81 | labels=labels, logits=logits, 
name='cross_entropy_per_example') 82 | cross_entropy_mean = tf.reduce_mean(cross_entropy, name='cross_entropy') 83 | tf.add_to_collection('losses', cross_entropy_mean) 84 | return tf.add_n(tf.get_collection('losses'), name='total_loss') 85 | 86 | 87 | def training(loss): 88 | optimizer = tf.train.AdamOptimizer(1e-4) 89 | gen_grads_and_vars = optimizer.compute_gradients(loss) 90 | gen_train = optimizer.apply_gradients(gen_grads_and_vars) 91 | ema = tf.train.ExponentialMovingAverage(decay=0.99) 92 | update_losses = ema.apply([loss]) 93 | 94 | global_step = tf.contrib.framework.get_or_create_global_step() 95 | incr_global_step = tf.assign(global_step, global_step+1) 96 | 97 | return tf.group(update_losses, incr_global_step, gen_train) 98 | 99 | 100 | 101 | def evaluation(logits, labels): 102 | correct_prediction = tf.equal(tf.argmax(logits,1), tf.argmax(labels,1)) 103 | return tf.reduce_sum(tf.cast(correct_prediction, tf.float32)) 104 | 105 | 106 | def output(logits): 107 | return tf.nn.softmax(logits) 108 | 109 | def predict(logits): 110 | return tf.argmax(logits, 1) -------------------------------------------------------------------------------- /pig_vgg16/pig_model_restnet_v2.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Sat Nov 25 20:04:34 2017 4 | 5 | @author: DELL 6 | """ 7 | 8 | from __future__ import absolute_import 9 | from __future__ import division 10 | from __future__ import print_function 11 | 12 | import tensorflow as tf 13 | import pig_input 14 | import config 15 | from nets import inception 16 | 17 | IMAGE_WIDTH = config.IMAGE_WIDTH 18 | IMAGE_HEIGHT = config.IMAGE_HEIGHT 19 | CLASSES_NUM = config.CLASSES_NUM 20 | rate = 0.01  # slope of the leaky ReLU used in _conv() 21 | 22 | def inputs(train, batch_size): 23 | return pig_input.inputs(train, batch_size=batch_size) 24 | 25 | def _conv(name, input, size, input_channels, output_channels, is_training=True): 26 | with tf.variable_scope(name) as scope: 27 | if not is_training: 28 | scope.reuse_variables() 29 | kernel = _weight_variable('weights', shape=[size, size, input_channels, output_channels]) 30 | biases = _bias_variable('biases', [output_channels]) 31 | pre_activation = tf.nn.bias_add(_conv2d(input, kernel), biases) 32 | conv = tf.maximum(rate*pre_activation, pre_activation, name=scope.name)  # leaky ReLU 33 | return conv 34 | 35 | def _conv2d(value, weight): 36 | """conv2d returns a 2d convolution layer with full stride.""" 37 | return tf.nn.conv2d(value, weight, strides=[1, 1, 1, 1], padding='SAME') 38 | 39 | 40 | def _max_pool_2x2(value, name, is_training): 41 | """max_pool_2x2 downsamples a feature map by 2X.""" 42 | with tf.variable_scope(name) as scope1: 43 | if not is_training: 44 | scope1.reuse_variables() 45 | return tf.nn.max_pool(value, ksize=[1, 2, 2, 1], 46 | strides=[1, 2, 2, 1], padding='SAME', name=name) 47 | 48 | 49 | def _weight_variable(name, shape): 50 | """weight_variable generates a weight variable of a given shape.""" 51 | initializer = tf.truncated_normal_initializer(stddev=0.1) 52 | var = tf.get_variable(name, shape, initializer=initializer, dtype=tf.float32) 53 | return var 54 | 55 | 56 | def _bias_variable(name, shape): 57 | """bias_variable generates a bias variable of a given shape.""" 58 | initializer = tf.constant_initializer(0.1) 59 | var = tf.get_variable(name, shape, initializer=initializer, dtype=tf.float32) 60 | return var 61 | 62 | def _batch_norm(name, inputs, is_training): 63 | """ Batch Normalization 64 | """ 65 | with tf.variable_scope(name, reuse=not is_training): 66 | # return tf.layers.batch_normalization(input, training=is_training) 67 | return tf.contrib.layers.batch_norm(inputs, 68 | decay=0.9, 69 | scale=True, 70 | updates_collections=None, 71 | is_training=True) 72 | def inference(images, keep_prob, is_training): 73 | images = tf.reshape(images, [-1, IMAGE_HEIGHT, IMAGE_WIDTH, 3])  # 299,299,3 74 | 75 | softmax_linear, end_points = inception.inception_resnet_v2(images, CLASSES_NUM) 76 | # return softmax_linear 77 | return tf.reshape(softmax_linear, [-1, CLASSES_NUM]) 78 | 79 | 80 | def loss(logits, labels): 81 | cross_entropy = tf.nn.softmax_cross_entropy_with_logits( 82 | labels=labels, logits=logits, name='cross_entropy_per_example') 83 | cross_entropy_mean = tf.reduce_mean(cross_entropy, name='cross_entropy') 84 | tf.add_to_collection('losses', cross_entropy_mean) 85 | return tf.add_n(tf.get_collection('losses'), name='total_loss') 86 | 87 | 88 | def training(loss): 89 | optimizer = tf.train.AdamOptimizer(1e-4) 90 | gen_grads_and_vars = optimizer.compute_gradients(loss) 91 | gen_train = optimizer.apply_gradients(gen_grads_and_vars) 92 | ema = tf.train.ExponentialMovingAverage(decay=0.99) 93 | update_losses = ema.apply([loss]) 94 | 95 | global_step = tf.contrib.framework.get_or_create_global_step() 96 | incr_global_step = tf.assign(global_step, global_step+1) 97 | 98 | return tf.group(update_losses, incr_global_step, gen_train) 99 | 100 | 101 | 102 | def evaluation(logits, labels): 103 | correct_prediction = tf.equal(tf.argmax(logits,1), tf.argmax(labels,1)) 104 | return tf.reduce_sum(tf.cast(correct_prediction, tf.float32)) 105 | 106 | 107 | def output(logits): 108 | return tf.nn.softmax(logits) 109 | 110 | def predict(logits): 111 | return tf.argmax(logits, 1) -------------------------------------------------------------------------------- /pig_vgg16/pig_records.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import argparse 6 | import os.path 7 | import sys 8 | 9 | from PIL import Image 10 | 11 | import numpy as np 12 | import tensorflow as tf 13 | from tensorflow.python.platform import gfile 14 | import config 15 | 16 | IMAGE_HEIGHT = config.IMAGE_HEIGHT  # 299 17 | IMAGE_WIDTH = config.IMAGE_WIDTH  # 299 18 | CLASSES_NUM = config.CLASSES_NUM  # 30 19 | 20 | 21 | RECORD_DIR = config.RECORD_DIR 22 | TRAIN_FILE = config.TRAIN_FILE 23 | VALID_FILE = config.VALID_FILE 24 | 25 | FLAGS = None 26 | 27 | def _int64_feature(values): 28 | if not isinstance(values, (tuple, list)): 29 | values = [values] 30 | return tf.train.Feature(int64_list=tf.train.Int64List(value=values)) 31 | 32 | def _bytes_feature(values): 33 | return tf.train.Feature(bytes_list=tf.train.BytesList(value=[values])) 34 | 35 | def _float_feature(values): 36 | if not isinstance(values, (tuple, list)): 37 | values = [values] 38 | return tf.train.Feature(float_list=tf.train.FloatList(value=values)) 39 | 40 | def label_to_one_hot(label): 41 | one_hot_label = np.zeros([CLASSES_NUM]) 42 | one_hot_label[label] = 1.0 43 | return one_hot_label.astype(np.uint8)  # shape: (CLASSES_NUM,) 44 | 45 | 46 | def conver_to_tfrecords(data_set, name): 47 | """Converts a dataset to tfrecords.""" 48 | if not os.path.exists(RECORD_DIR): 49 | os.makedirs(RECORD_DIR) 50 | filename = os.path.join(RECORD_DIR, name) 51 | print('>> Writing', filename) 52 | writer = tf.python_io.TFRecordWriter(filename) 53 | data_set_list = list(data_set)
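# (Illustrative round-trip sketch: each Example written below can be parsed
#  back with a matching feature spec, e.g.
#      features = tf.parse_single_example(serialized_example, {
#          'image_raw': tf.FixedLenFeature([], tf.string),
#          'label_raw': tf.FixedLenFeature([], tf.string),
#          'height': tf.FixedLenFeature([], tf.int64),
#          'width': tf.FixedLenFeature([], tf.int64)})
#      image = tf.decode_raw(features['image_raw'], tf.int16)
#  The decode dtype must match the numpy dtype used in create_data_list().)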
54 | num_examples = len(data_set_list) 55 | count = 0 56 | for index in range(num_examples): 57 | count += 1 58 | image = data_set_list[index][0] 59 | height = image.shape[0] 60 | width = image.shape[1] 61 | image_raw = image.tostring() 62 | label = data_set_list[index][1] 63 | label_raw = label_to_one_hot(label).tostring() 64 | 65 | example = tf.train.Example(features=tf.train.Features(feature={ 66 | 'height': _int64_feature(height), 67 | 'width': _int64_feature(width), 68 | 'label_raw': _bytes_feature(label_raw), 69 | 'image_raw': _bytes_feature(image_raw)})) 70 | writer.write(example.SerializeToString()) 71 | if count % 500 == 0: 72 | print('processed {}/{}'.format(count, num_examples)) 73 | writer.close() 74 | print('>> Writing Done!') 75 | 76 | 77 | def create_data_list(image_dir): 78 | if not gfile.Exists(image_dir): 79 | print("Image directory '" + image_dir + "' not found.") 80 | return None 81 | extensions = ['*.jpg'] 82 | print("Looking for images in '" + image_dir + "'") 83 | file_list = [] 84 | for extension in extensions: 85 | file_glob = os.path.join(image_dir, extension) 86 | file_list.extend(gfile.Glob(file_glob)) 87 | if not file_list: 88 | print("No files found in '" + image_dir + "'") 89 | return None 90 | images = [] 91 | labels = [] 92 | all_list = len(file_list) 93 | count = 0 94 | for file_name in file_list: 95 | count += 1 96 | image = Image.open(file_name) 97 | image_resize = image.resize(size=(IMAGE_WIDTH,IMAGE_HEIGHT)) 98 | input_img = np.array(image_resize, dtype='int16') 99 | image.close() 100 | label_name = int(os.path.basename(file_name).split('_')[0]) - 1  # labels start at 0 101 | images.append(input_img) 102 | labels.append(label_name) 103 | if count % 500 == 0: 104 | print('processed :{}/{}'.format(count, all_list)) 105 | return zip(images, labels) 106 | 107 | 108 | def main(_): 109 | training_data = create_data_list(FLAGS.train_dir) 110 | conver_to_tfrecords(training_data, TRAIN_FILE) 111 | 112 | 113 | if __name__ == '__main__': 114 | parser = argparse.ArgumentParser() 115 | parser.add_argument( 116 | '--train_dir', 117 | type=str, 118 | default='pig_body/train_data', 119 | help='Directory with the training pig images to convert.' 120 | ) 121 | parser.add_argument( 122 | '--valid_dir', 123 | type=str, 124 | default='pig_body/valid_data', 125 | help='Directory with the validation pig images to convert.' 126 | ) 127 | FLAGS, unparsed = parser.parse_known_args() 128 | tf.app.run(main=main, argv=[sys.argv[0]] + unparsed) 129 | -------------------------------------------------------------------------------- /nets/dcgan_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License.
14 | # ============================================================================== 15 | """Tests for dcgan.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import tensorflow as tf 22 | from nets import dcgan 23 | 24 | 25 | class DCGANTest(tf.test.TestCase): 26 | 27 | def test_generator_run(self): 28 | tf.set_random_seed(1234) 29 | noise = tf.random_normal([100, 64]) 30 | image, _ = dcgan.generator(noise) 31 | with self.test_session() as sess: 32 | sess.run(tf.global_variables_initializer()) 33 | image.eval() 34 | 35 | def test_generator_graph(self): 36 | tf.set_random_seed(1234) 37 | # Check graph construction for a number of image size/depths and batch 38 | # sizes. 39 | for i, batch_size in zip(range(3, 7), range(3, 8)): 40 | tf.reset_default_graph() 41 | final_size = 2 ** i 42 | noise = tf.random_normal([batch_size, 64]) 43 | image, end_points = dcgan.generator( 44 | noise, 45 | depth=32, 46 | final_size=final_size) 47 | 48 | self.assertAllEqual([batch_size, final_size, final_size, 3], 49 | image.shape.as_list()) 50 | 51 | expected_names = ['deconv%i' % j for j in range(1, i)] + ['logits'] 52 | self.assertSetEqual(set(expected_names), set(end_points.keys())) 53 | 54 | # Check layer depths. 55 | for j in range(1, i): 56 | layer = end_points['deconv%i' % j] 57 | self.assertEqual(32 * 2**(i-j-1), layer.get_shape().as_list()[-1]) 58 | 59 | def test_generator_invalid_input(self): 60 | wrong_dim_input = tf.zeros([5, 32, 32]) 61 | with self.assertRaises(ValueError): 62 | dcgan.generator(wrong_dim_input) 63 | 64 | correct_input = tf.zeros([3, 2]) 65 | with self.assertRaisesRegexp(ValueError, 'must be a power of 2'): 66 | dcgan.generator(correct_input, final_size=30) 67 | 68 | with self.assertRaisesRegexp(ValueError, 'must be greater than 8'): 69 | dcgan.generator(correct_input, final_size=4) 70 | 71 | def test_discriminator_run(self): 72 | image = tf.random_uniform([5, 32, 32, 3], -1, 1) 73 | output, _ = dcgan.discriminator(image) 74 | with self.test_session() as sess: 75 | sess.run(tf.global_variables_initializer()) 76 | output.eval() 77 | 78 | def test_discriminator_graph(self): 79 | # Check graph construction for a number of image size/depths and batch 80 | # sizes. 81 | for i, batch_size in zip(range(1, 6), range(3, 8)): 82 | tf.reset_default_graph() 83 | img_w = 2 ** i 84 | image = tf.random_uniform([batch_size, img_w, img_w, 3], -1, 1) 85 | output, end_points = dcgan.discriminator( 86 | image, 87 | depth=32) 88 | 89 | self.assertAllEqual([batch_size, 1], output.get_shape().as_list()) 90 | 91 | expected_names = ['conv%i' % j for j in range(1, i+1)] + ['logits'] 92 | self.assertSetEqual(set(expected_names), set(end_points.keys())) 93 | 94 | # Check layer depths.
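# (Illustrative note: the discriminator doubles its depth at each conv
#  layer, so conv j carries 32 * 2**(j-1) channels -- for depth=32 that is
#  32, 64, 128, ... -- the mirror image of the generator above, whose
#  deconv j halves depth as 32 * 2**(i-j-1).)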
95 | for j in range(1, i+1): 96 | layer = end_points['conv%i' % j] 97 | self.assertEqual(32 * 2**(j-1), layer.get_shape().as_list()[-1]) 98 | 99 | def test_discriminator_invalid_input(self): 100 | wrong_dim_img = tf.zeros([5, 32, 32]) 101 | with self.assertRaises(ValueError): 102 | dcgan.discriminator(wrong_dim_img) 103 | 104 | spatially_undefined_shape = tf.placeholder(tf.float32, [5, 32, None, 3]) 105 | with self.assertRaises(ValueError): 106 | dcgan.discriminator(spatially_undefined_shape) 107 | 108 | not_square = tf.zeros([5, 32, 16, 3]) 109 | with self.assertRaisesRegexp(ValueError, 'not have equal width and height'): 110 | dcgan.discriminator(not_square) 111 | 112 | not_power_2 = tf.zeros([5, 30, 30, 3]) 113 | with self.assertRaisesRegexp(ValueError, 'not a power of 2'): 114 | dcgan.discriminator(not_power_2) 115 | 116 | 117 | if __name__ == '__main__': 118 | tf.test.main() 119 | -------------------------------------------------------------------------------- /nets/cyclegan_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Tests for tensorflow.contrib.slim.nets.cyclegan.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import tensorflow as tf 22 | 23 | from nets import cyclegan 24 | 25 | 26 | # TODO(joelshor): Add a test to check generator endpoints. 
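# (Illustrative usage sketch, inferred from the tests below:
#      img_batch = tf.zeros([2, 32, 32, 3])
#      out_imgs, end_points = cyclegan.cyclegan_generator_resnet(img_batch)
#  `out_imgs` keeps the input shape; height and width must each be a
#  multiple of 4, otherwise the generator raises ValueError.)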
27 | class CycleganTest(tf.test.TestCase): 28 | 29 | def test_generator_inference(self): 30 | """Check one inference step.""" 31 | img_batch = tf.zeros([2, 32, 32, 3]) 32 | model_output, _ = cyclegan.cyclegan_generator_resnet(img_batch) 33 | with self.test_session() as sess: 34 | sess.run(tf.global_variables_initializer()) 35 | sess.run(model_output) 36 | 37 | def _test_generator_graph_helper(self, shape): 38 | """Check that generator can take small and non-square inputs.""" 39 | output_imgs, _ = cyclegan.cyclegan_generator_resnet(tf.ones(shape)) 40 | self.assertAllEqual(shape, output_imgs.shape.as_list()) 41 | 42 | def test_generator_graph_small(self): 43 | self._test_generator_graph_helper([4, 32, 32, 3]) 44 | 45 | def test_generator_graph_medium(self): 46 | self._test_generator_graph_helper([3, 128, 128, 3]) 47 | 48 | def test_generator_graph_nonsquare(self): 49 | self._test_generator_graph_helper([2, 80, 400, 3]) 50 | 51 | def test_generator_unknown_batch_dim(self): 52 | """Check that generator can take unknown batch dimension inputs.""" 53 | img = tf.placeholder(tf.float32, shape=[None, 32, None, 3]) 54 | output_imgs, _ = cyclegan.cyclegan_generator_resnet(img) 55 | 56 | self.assertAllEqual([None, 32, None, 3], output_imgs.shape.as_list()) 57 | 58 | def _input_and_output_same_shape_helper(self, kernel_size): 59 | img_batch = tf.placeholder(tf.float32, shape=[None, 32, 32, 3]) 60 | output_img_batch, _ = cyclegan.cyclegan_generator_resnet( 61 | img_batch, kernel_size=kernel_size) 62 | 63 | self.assertAllEqual(img_batch.shape.as_list(), 64 | output_img_batch.shape.as_list()) 65 | 66 | def input_and_output_same_shape_kernel3(self): 67 | self._input_and_output_same_shape_helper(3) 68 | 69 | def input_and_output_same_shape_kernel4(self): 70 | self._input_and_output_same_shape_helper(4) 71 | 72 | def input_and_output_same_shape_kernel5(self): 73 | self._input_and_output_same_shape_helper(5) 74 | 75 | def input_and_output_same_shape_kernel6(self): 76 | self._input_and_output_same_shape_helper(6) 77 | 78 | def _error_if_height_not_multiple_of_four_helper(self, height): 79 | self.assertRaisesRegexp( 80 | ValueError, 81 | 'The input height must be a multiple of 4.', 82 | cyclegan.cyclegan_generator_resnet, 83 | tf.placeholder(tf.float32, shape=[None, height, 32, 3])) 84 | 85 | def test_error_if_height_not_multiple_of_four_height29(self): 86 | self._error_if_height_not_multiple_of_four_helper(29) 87 | 88 | def test_error_if_height_not_multiple_of_four_height30(self): 89 | self._error_if_height_not_multiple_of_four_helper(30) 90 | 91 | def test_error_if_height_not_multiple_of_four_height31(self): 92 | self._error_if_height_not_multiple_of_four_helper(31) 93 | 94 | def _error_if_width_not_multiple_of_four_helper(self, width): 95 | self.assertRaisesRegexp( 96 | ValueError, 97 | 'The input width must be a multiple of 4.', 98 | cyclegan.cyclegan_generator_resnet, 99 | tf.placeholder(tf.float32, shape=[None, 32, width, 3])) 100 | 101 | def test_error_if_width_not_multiple_of_four_width29(self): 102 | self._error_if_width_not_multiple_of_four_helper(29) 103 | 104 | def test_error_if_width_not_multiple_of_four_width30(self): 105 | self._error_if_width_not_multiple_of_four_helper(30) 106 | 107 | def test_error_if_width_not_multiple_of_four_width31(self): 108 | self._error_if_width_not_multiple_of_four_helper(31) 109 | 110 | 111 | if __name__ == '__main__': 112 | tf.test.main() 113 | -------------------------------------------------------------------------------- /pig_vgg16/pig_train.py: 
-------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import time 6 | from datetime import datetime 7 | import argparse 8 | import sys 9 | import tensorflow as tf 10 | import pig_model_dark as captcha 11 | #import pig_model_vgg16 as captcha 12 | #import pig_model_restnet_v2 as captcha 13 | import logging 14 | 15 | learning_rate = 2e-4 16 | FLAGS = None 17 | def initLogging(logFilename='record.log'): 18 | """Init for logging 19 | """ 20 | logging.basicConfig( 21 | level = logging.DEBUG, 22 | format='%(asctime)s-%(levelname)s-%(message)s', 23 | datefmt = '%y-%m-%d %H:%M', 24 | filename = logFilename, 25 | filemode = 'w'); 26 | console = logging.StreamHandler() 27 | console.setLevel(logging.INFO) 28 | formatter = logging.Formatter('%(asctime)s-%(levelname)s-%(message)s') 29 | console.setFormatter(formatter) 30 | logging.getLogger('').addHandler(console) 31 | initLogging() 32 | 33 | def run_train(): 34 | """Train CAPTCHA for a number of steps.""" 35 | 36 | with tf.Graph().as_default(): 37 | images, labels = captcha.inputs(train=True, batch_size=FLAGS.batch_size) 38 | 39 | logits = captcha.inference(images, keep_prob=0.9,is_training=True) 40 | loss = captcha.loss(logits, labels) 41 | correct = captcha.evaluation(logits, labels)#train 42 | tf.summary.scalar('loss', loss) 43 | summary = tf.summary.merge_all() 44 | 45 | # train_precision = correct/FLAGS.batch_size 46 | # train_op = captcha.training(loss) 47 | train_op = tf.train.AdamOptimizer(learning_rate).minimize(loss) 48 | 49 | saver = tf.train.Saver() 50 | init = tf.global_variables_initializer() 51 | 52 | sess = tf.Session() 53 | sess.run(init) 54 | summary_writer = tf.summary.FileWriter(FLAGS.train_dir, sess.graph) 55 | # saver.restore(sess, tf.train.latest_checkpoint(FLAGS.checkpoint_dir)) 56 | coord = tf.train.Coordinator() 57 | threads = tf.train.start_queue_runners(sess=sess, coord=coord) 58 | try: 59 | step = 0 60 | while not coord.should_stop(): 61 | start_time = time.time() 62 | _, loss_value, pre_value, logits_ ,labels_= sess.run([train_op, loss, correct,logits,labels]) 63 | 64 | duration = time.time() - start_time 65 | step += 1 66 | if step % 10 == 0: 67 | logging.info('>> Step %d run_train: loss = %.2f, train = %.2f (%.3f sec)' 68 | % (step, loss_value, pre_value, duration)) 69 | summary_str = sess.run(summary) 70 | summary_writer.add_summary(summary_str, step) 71 | summary_writer.flush() 72 | #------------------------------- 73 | 74 | if step % 500 == 0: 75 | logging.info('>> %s Saving in %s' % (datetime.now(), FLAGS.checkpoint)) 76 | saver.save(sess, FLAGS.checkpoint, global_step=step) 77 | print(images.shape.as_list(),labels.shape.as_list()) 78 | 79 | if step>200000: 80 | break 81 | except KeyboardInterrupt: 82 | print('INTERRUPTED') 83 | coord.request_stop() 84 | except Exception as e: 85 | 86 | coord.request_stop(e) 87 | finally: 88 | saver.save(sess, FLAGS.checkpoint, global_step=step) 89 | print('Model saved in file :%s'%FLAGS.checkpoint) 90 | 91 | coord.request_stop() 92 | coord.join(threads) 93 | sess.close() 94 | 95 | 96 | 97 | def main(_): 98 | # if tf.gfile.Exists(FLAGS.train_dir): 99 | # tf.gfile.DeleteRecursively(FLAGS.train_dir) 100 | # tf.gfile.MakeDirs(FLAGS.train_dir) 101 | run_train() 102 | 103 | 104 | if __name__ == '__main__': 105 | parser = argparse.ArgumentParser() 106 | parser.add_argument( 107 | '--batch_size', 108 | type=int, 109 | default=32, 110 | 
help='Batch size.' 111 | ) 112 | parser.add_argument( 113 | '--train_dir', 114 | type=str, 115 | default='pig_train', 116 | help='Directory where to write event logs.' 117 | ) 118 | parser.add_argument( 119 | '--checkpoint_dir', 120 | type=str, 121 | default='./checkpoint', 122 | help='Directory where to restore checkpoint.' 123 | ) 124 | parser.add_argument( 125 | '--checkpoint', 126 | type=str, 127 | default='checkpoint/model.ckpt', 128 | help='Path prefix where to write checkpoints.' 129 | ) 130 | FLAGS, unparsed = parser.parse_known_args() 131 | tf.app.run(main=main, argv=[sys.argv[0]] + unparsed) 132 | -------------------------------------------------------------------------------- /pig_vgg16/write_tfrecords.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Thu Nov 30 22:24:39 2017 4 | 5 | @author: DELL 6 | """ 7 | 8 | from __future__ import absolute_import 9 | from __future__ import division 10 | from __future__ import print_function 11 | 12 | import argparse 13 | import os.path 14 | import sys 15 | 16 | from PIL import Image 17 | 18 | import numpy as np 19 | import tensorflow as tf 20 | from tensorflow.python.platform import gfile 21 | import config 22 | import glob 23 | 24 | IMAGE_HEIGHT = 299 25 | IMAGE_WIDTH = 299 26 | CLASSES_NUM = 30 27 | 28 | # for train 29 | RECORD_DIR = 'pig_body' 30 | TRAIN_FILE = 'train.tfrecords' 31 | #TRAIN_FILE = 'D:\pig_recognize\record\train.tfrecords' 32 | VALID_FILE = 'valid.tfrecords' 33 | 34 | FLAGS = None 35 | 36 | def _int64_feature(values): 37 | if not isinstance(values, (tuple, list)): 38 | values = [values] 39 | return tf.train.Feature(int64_list=tf.train.Int64List(value=values)) 40 | 41 | def _bytes_feature(values): 42 | return tf.train.Feature(bytes_list=tf.train.BytesList(value=[values])) 43 | 44 | def _float_feature(values): 45 | if not isinstance(values, (tuple, list)): 46 | values = [values] 47 | return tf.train.Feature(float_list=tf.train.FloatList(value=values)) 48 | 49 | def label_to_one_hot(label): 50 | one_hot_label = np.zeros([CLASSES_NUM]) 51 | one_hot_label[label] = 1.0 52 | return one_hot_label.astype(np.uint8)  # shape: (CLASSES_NUM,) 53 | 54 | 55 | def image_to_tfexample(image_raw, label_raw, height, width): 56 | return tf.train.Example(features=tf.train.Features(feature={ 57 | 'height': _int64_feature(height), 58 | 'width': _int64_feature(width), 59 | 'label_raw': _bytes_feature(label_raw), 60 | 'image_raw': _bytes_feature(image_raw)})) 61 | 62 | 63 | def conver_to_tfrecords(data_set, name): 64 | """Converts a dataset to tfrecords.""" 65 | if not os.path.exists(RECORD_DIR): 66 | os.makedirs(RECORD_DIR) 67 | filename = os.path.join(RECORD_DIR, name) 68 | print('>> Writing', filename) 69 | writer = tf.python_io.TFRecordWriter(filename) 70 | data_set_list = list(data_set) 71 | num_examples = len(data_set_list) 72 | count = 0 73 | for index in range(num_examples): 74 | count += 1 75 | image = data_set_list[index][0] 76 | height = image.shape[0] 77 | width = image.shape[1] 78 | # serialize the image as raw bytes 79 | image_raw = image.tostring() 80 | label = data_set_list[index][1] 81 | label_raw = label_to_one_hot(label).tostring() 82 | 83 | example = image_to_tfexample(image_raw, label_raw, height, width) 84 | writer.write(example.SerializeToString()) 85 | if count % 500 == 0: 86 | print('processed {}/{}'.format(count, num_examples)) 87 | writer.close() 88 | print('>> Writing Done!') 89 | 90 | 91 | def create_data_list(image_dir): 92 | if not gfile.Exists(image_dir): 93 | print("Image directory '" + image_dir + "' not found.")
94 | return None 95 | extensions = ['*.jpg'] 96 | print("Looking for images in '" + image_dir + "'") 97 | 98 | file_list = [] 99 | for extension in extensions: 100 | file_glob = os.path.join(image_dir, extension) 101 | file_list.extend(gfile.Glob(file_glob)) 102 | 103 | # file_list = glob.glob('D:/pig_recognize/pig_recognize_body/pig_body/train_data1/*.jpg') 104 | 105 | if not file_list: 106 | print("No files found in '" + image_dir + "'") 107 | return None 108 | images = [] 109 | labels = [] 110 | all_list = len(file_list) 111 | count = 0 112 | for file_name in file_list: 113 | count += 1 114 | image = Image.open(file_name) 115 | image_resize = image.resize(size=(IMAGE_WIDTH,IMAGE_HEIGHT)) 116 | input_img = np.array(image_resize, dtype='uint8') 117 | image.close() 118 | label_name = int(os.path.basename(file_name).split('_')[0]) - 1  # labels start at 0 119 | images.append(input_img) 120 | labels.append(label_name) 121 | if count % 500 == 0: 122 | print('processed :{}/{}'.format(count, all_list)) 123 | return zip(images, labels) 124 | 125 | 126 | def main(_): 127 | training_data = create_data_list(FLAGS.train_dir) 128 | conver_to_tfrecords(training_data, TRAIN_FILE) 129 | 130 | 131 | if __name__ == '__main__': 132 | parser = argparse.ArgumentParser() 133 | parser.add_argument( 134 | '--train_dir', 135 | type=str, 136 | default='pig_body/train_data1', 137 | help='Directory with the training pig images to convert.' 138 | ) 139 | parser.add_argument( 140 | '--valid_dir', 141 | type=str, 142 | default='pig_body/valid_data', 143 | help='Directory with the validation pig images to convert.' 144 | ) 145 | FLAGS, unparsed = parser.parse_known_args() 146 | tf.app.run(main=main, argv=[sys.argv[0]] + unparsed) 147 | 148 | 149 | 150 | 151 | -------------------------------------------------------------------------------- /nets/cifarnet.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Contains a variant of the CIFAR-10 model definition.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import tensorflow as tf 22 | 23 | slim = tf.contrib.slim 24 | 25 | trunc_normal = lambda stddev: tf.truncated_normal_initializer(stddev=stddev) 26 | 27 | 28 | def cifarnet(images, num_classes=10, is_training=False, 29 | dropout_keep_prob=0.5, 30 | prediction_fn=slim.softmax, 31 | scope='CifarNet'): 32 | """Creates a variant of the CifarNet model. 33 | 34 | Note that since the output is a set of 'logits', the values fall in the 35 | interval of (-infinity, infinity).
Consequently, to convert the outputs to a 36 | probability distribution over the classes, one will need to convert them 37 | using the softmax function: 38 | 39 | logits = cifarnet.cifarnet(images, is_training=False) 40 | probabilities = tf.nn.softmax(logits) 41 | predictions = tf.argmax(logits, 1) 42 | 43 | Args: 44 | images: A batch of `Tensors` of size [batch_size, height, width, channels]. 45 | num_classes: the number of classes in the dataset. If 0 or None, the logits 46 | layer is omitted and the input features to the logits layer are returned 47 | instead. 48 | is_training: specifies whether or not we're currently training the model. 49 | This variable will determine the behaviour of the dropout layer. 50 | dropout_keep_prob: the probability that each activation value is retained. 51 | prediction_fn: a function to get predictions out of logits. 52 | scope: Optional variable_scope. 53 | 54 | Returns: 55 | net: a 2D Tensor with the logits (pre-softmax activations) if num_classes 56 | is a non-zero integer, or the input to the logits layer if num_classes 57 | is 0 or None. 58 | end_points: a dictionary from components of the network to the corresponding 59 | activation. 60 | """ 61 | end_points = {} 62 | 63 | with tf.variable_scope(scope, 'CifarNet', [images]): 64 | net = slim.conv2d(images, 64, [5, 5], scope='conv1') 65 | end_points['conv1'] = net 66 | net = slim.max_pool2d(net, [2, 2], 2, scope='pool1') 67 | end_points['pool1'] = net 68 | net = tf.nn.lrn(net, 4, bias=1.0, alpha=0.001/9.0, beta=0.75, name='norm1') 69 | net = slim.conv2d(net, 64, [5, 5], scope='conv2') 70 | end_points['conv2'] = net 71 | net = tf.nn.lrn(net, 4, bias=1.0, alpha=0.001/9.0, beta=0.75, name='norm2') 72 | net = slim.max_pool2d(net, [2, 2], 2, scope='pool2') 73 | end_points['pool2'] = net 74 | net = slim.flatten(net) 75 | end_points['Flatten'] = net 76 | net = slim.fully_connected(net, 384, scope='fc3') 77 | end_points['fc3'] = net 78 | net = slim.dropout(net, dropout_keep_prob, is_training=is_training, 79 | scope='dropout3') 80 | net = slim.fully_connected(net, 192, scope='fc4') 81 | end_points['fc4'] = net 82 | if not num_classes: 83 | return net, end_points 84 | logits = slim.fully_connected(net, num_classes, 85 | biases_initializer=tf.zeros_initializer(), 86 | weights_initializer=trunc_normal(1/192.0), 87 | weights_regularizer=None, 88 | activation_fn=None, 89 | scope='logits') 90 | 91 | end_points['Logits'] = logits 92 | end_points['Predictions'] = prediction_fn(logits, scope='Predictions') 93 | 94 | return logits, end_points 95 | cifarnet.default_image_size = 32 96 | 97 | 98 | def cifarnet_arg_scope(weight_decay=0.004): 99 | """Defines the default cifarnet argument scope. 100 | 101 | Args: 102 | weight_decay: The weight decay to use for regularizing the model. 103 | 104 | Returns: 105 | An `arg_scope` to use for the cifarnet model. 106 | """ 107 | with slim.arg_scope( 108 | [slim.conv2d], 109 | weights_initializer=tf.truncated_normal_initializer(stddev=5e-2), 110 | activation_fn=tf.nn.relu): 111 | with slim.arg_scope( 112 | [slim.fully_connected], 113 | biases_initializer=tf.constant_initializer(0.1), 114 | weights_initializer=trunc_normal(0.04), 115 | weights_regularizer=slim.l2_regularizer(weight_decay), 116 | activation_fn=tf.nn.relu) as sc: 117 | return sc 118 | -------------------------------------------------------------------------------- /export_inference_graph.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The TensorFlow Authors. 
All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | r"""Saves out a GraphDef containing the architecture of the model. 16 | 17 | To use it, run something like this, with a model name defined by slim: 18 | 19 | bazel build tensorflow_models/research/slim:export_inference_graph 20 | bazel-bin/tensorflow_models/research/slim/export_inference_graph \ 21 | --model_name=inception_v3 --output_file=/tmp/inception_v3_inf_graph.pb 22 | 23 | If you then want to use the resulting model with your own or pretrained 24 | checkpoints as part of a mobile model, you can run freeze_graph to get a graph 25 | def with the variables inlined as constants using: 26 | 27 | bazel build tensorflow/python/tools:freeze_graph 28 | bazel-bin/tensorflow/python/tools/freeze_graph \ 29 | --input_graph=/tmp/inception_v3_inf_graph.pb \ 30 | --input_checkpoint=/tmp/checkpoints/inception_v3.ckpt \ 31 | --input_binary=true --output_graph=/tmp/frozen_inception_v3.pb \ 32 | --output_node_names=InceptionV3/Predictions/Reshape_1 33 | 34 | The output node names will vary depending on the model, but you can inspect and 35 | estimate them using the summarize_graph tool: 36 | 37 | bazel build tensorflow/tools/graph_transforms:summarize_graph 38 | bazel-bin/tensorflow/tools/graph_transforms/summarize_graph \ 39 | --in_graph=/tmp/inception_v3_inf_graph.pb 40 | 41 | To run the resulting graph in C++, you can look at the label_image sample code: 42 | 43 | bazel build tensorflow/examples/label_image:label_image 44 | bazel-bin/tensorflow/examples/label_image/label_image \ 45 | --image=${HOME}/Pictures/flowers.jpg \ 46 | --input_layer=input \ 47 | --output_layer=InceptionV3/Predictions/Reshape_1 \ 48 | --graph=/tmp/frozen_inception_v3.pb \ 49 | --labels=/tmp/imagenet_slim_labels.txt \ 50 | --input_mean=0 \ 51 | --input_std=255 52 | 53 | """ 54 | 55 | from __future__ import absolute_import 56 | from __future__ import division 57 | from __future__ import print_function 58 | 59 | import tensorflow as tf 60 | 61 | from tensorflow.python.platform import gfile 62 | from datasets import dataset_factory 63 | from nets import nets_factory 64 | 65 | 66 | slim = tf.contrib.slim 67 | 68 | tf.app.flags.DEFINE_string( 69 | 'model_name', 'inception_v3', 'The name of the architecture to save.') 70 | 71 | tf.app.flags.DEFINE_boolean( 72 | 'is_training', False, 73 | 'Whether to save out a training-focused version of the model.') 74 | 75 | tf.app.flags.DEFINE_integer( 76 | 'image_size', None, 77 | 'The image size to use, otherwise use the model default_image_size.') 78 | 79 | tf.app.flags.DEFINE_integer( 80 | 'batch_size', None, 81 | 'Batch size for the exported model. 
Defaulted to "None" so batch size can ' 82 | 'be specified at model runtime.') 83 | 84 | tf.app.flags.DEFINE_string('dataset_name', 'imagenet', 85 | 'The name of the dataset to use with the model.') 86 | 87 | tf.app.flags.DEFINE_integer( 88 | 'labels_offset', 0, 89 | 'An offset for the labels in the dataset. This flag is primarily used to ' 90 | 'evaluate the VGG and ResNet architectures which do not use a background ' 91 | 'class for the ImageNet dataset.') 92 | 93 | tf.app.flags.DEFINE_string( 94 | 'output_file', '', 'Where to save the resulting file to.') 95 | 96 | tf.app.flags.DEFINE_string( 97 | 'dataset_dir', '', 'Directory to save intermediate dataset files to.') 98 | 99 | FLAGS = tf.app.flags.FLAGS 100 | 101 | 102 | def main(_): 103 | if not FLAGS.output_file: 104 | raise ValueError('You must supply the path to save to with --output_file') 105 | tf.logging.set_verbosity(tf.logging.INFO) 106 | with tf.Graph().as_default() as graph: 107 | dataset = dataset_factory.get_dataset(FLAGS.dataset_name, 'train', 108 | FLAGS.dataset_dir) 109 | network_fn = nets_factory.get_network_fn( 110 | FLAGS.model_name, 111 | num_classes=(dataset.num_classes - FLAGS.labels_offset), 112 | is_training=FLAGS.is_training) 113 | image_size = FLAGS.image_size or network_fn.default_image_size 114 | placeholder = tf.placeholder(name='input', dtype=tf.float32, 115 | shape=[FLAGS.batch_size, image_size, 116 | image_size, 3]) 117 | network_fn(placeholder) 118 | graph_def = graph.as_graph_def() 119 | with gfile.GFile(FLAGS.output_file, 'wb') as f: 120 | f.write(graph_def.SerializeToString()) 121 | 122 | 123 | if __name__ == '__main__': 124 | tf.app.run() 125 | -------------------------------------------------------------------------------- /pig_vgg16/pig_records_padding.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Thu Nov 23 23:21:19 2017 4 | 5 | @author: DELL 6 | """ 7 | 8 | from __future__ import absolute_import 9 | from __future__ import division 10 | from __future__ import print_function 11 | 12 | import argparse 13 | import os.path 14 | import sys 15 | 16 | #from PIL import Image 17 | import cv2 18 | 19 | import numpy as np 20 | import tensorflow as tf 21 | from tensorflow.python.platform import gfile 22 | import config 23 | 24 | IMAGE_HEIGHT = config.IMAGE_HEIGHT # 299 in config.py 25 | IMAGE_WIDTH = config.IMAGE_WIDTH # 299 in config.py 26 | CLASSES_NUM = config.CLASSES_NUM # 30 in config.py 27 | 28 | 29 | RECORD_DIR = config.RECORD_DIR 30 | TRAIN_FILE = config.TRAIN_FILE 31 | VALID_FILE = config.VALID_FILE 32 | 33 | FLAGS = None 34 | 35 | def _int64_feature(value): 36 | return tf.train.Feature(int64_list=tf.train.Int64List(value=[value])) 37 | 38 | 39 | def _bytes_feature(value): 40 | return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value])) 41 | 42 | 43 | def label_to_one_hot(label): 44 | one_hot_label = np.zeros([CLASSES_NUM]) 45 | one_hot_label[label] = 1.0 46 | return one_hot_label.astype(np.uint8) # one-hot vector of shape (CLASSES_NUM,) 47 | 48 | 49 | def conver_to_tfrecords(data_set, name): 50 | """Converts a dataset to tfrecords.""" 51 | if not os.path.exists(RECORD_DIR): 52 | os.makedirs(RECORD_DIR) 53 | filename = os.path.join(RECORD_DIR, name) 54 | print('>> Writing', filename) 55 | writer = tf.python_io.TFRecordWriter(filename) 56 | data_set_list=list(data_set) 57 | num_examples = len(data_set_list) 58 | count = 0 59 | for index in range(num_examples): 60 | count += 1 61 | image = data_set_list[index][0] 62 | height = image.shape[0] 63 | width = image.shape[1] 64 | 
image_raw = image.tostring() 65 | label = data_set_list[index][1] 66 | label_raw = label_to_one_hot(label).tostring() 67 | example = tf.train.Example(features=tf.train.Features(feature={ 68 | 'height': _int64_feature(height), 69 | 'width': _int64_feature(width), 70 | 'label_raw': _bytes_feature(label_raw), 71 | 'image_raw': _bytes_feature(image_raw)})) 72 | writer.write(example.SerializeToString()) 73 | if count % 500 == 0: 74 | print('processed {}/{}'.format(count,num_examples)) 75 | writer.close() 76 | print('>> Writing Done!') 77 | 78 | 79 | def create_data_list(image_dir): 80 | if not gfile.Exists(image_dir): 81 | print("Image directory '" + image_dir + "' not found.") 82 | return None 83 | extensions = [ '*.jpg'] 84 | print("Looking for images in '" + image_dir + "'") 85 | file_list = [] 86 | for extension in extensions: 87 | file_glob = os.path.join(image_dir, extension) 88 | file_list.extend(gfile.Glob(file_glob)) 89 | if not file_list: 90 | print("No files found in '" + image_dir + "'") 91 | return None 92 | images = [] 93 | labels = [] 94 | all_list = len(file_list) 95 | count = 0 96 | for file_name in file_list: 97 | count += 1 98 | image = cv2.imread(file_name) 99 | h,w,c = image.shape 100 | fig = np.ones((IMAGE_WIDTH,IMAGE_WIDTH,c))*255 # white square canvas to paste the resized image onto 101 | if h>w: # taller than wide: fit the height, pad the left and right 102 | rate = h/IMAGE_WIDTH 103 | h_v = IMAGE_WIDTH 104 | w_v = int(w/rate) 105 | border = (IMAGE_WIDTH - w_v) 106 | up = 0 107 | down = IMAGE_WIDTH 108 | left = int(border/2) 109 | right = int(IMAGE_WIDTH - border/2) 110 | else: # wider than tall (or square): fit the width, pad the top and bottom 111 | rate = w/IMAGE_WIDTH 112 | w_v = IMAGE_WIDTH 113 | h_v = int(h/rate) 114 | border = (IMAGE_WIDTH - h_v) 115 | up = int(border/2) 116 | down = int(IMAGE_WIDTH - border/2) 117 | left = 0 118 | right = IMAGE_WIDTH 119 | 120 | img_v = cv2.resize(image, (w_v, h_v)) 121 | fig[up:down, left:right] = img_v 122 | input_img = np.array(fig, dtype='int16') 123 | # fig = fig.astype(img_v.dtype) 124 | # cv2.imwrite('q.jpg', input_img) 125 | # cv2.imwrite('p.jpg', fig) 126 | label_name = int(os.path.basename(file_name).split('_')[0]) - 1 #start at 0 127 | images.append(input_img) 128 | labels.append(label_name) 129 | if count % 500 == 0: 130 | print('processed: {}/{}'.format(count,all_list)) 131 | return zip(images, labels) 132 | 133 | 134 | def main(_): 135 | training_data = create_data_list(FLAGS.train_dir) 136 | conver_to_tfrecords(training_data, TRAIN_FILE) 137 | 138 | 139 | if __name__ == '__main__': 140 | parser = argparse.ArgumentParser() 141 | parser.add_argument( 142 | '--train_dir', 143 | type=str, 144 | default='pig_body/train_data', 145 | help='Directory of training images to convert and write as TFRecords.' 146 | ) 147 | parser.add_argument( 148 | '--valid_dir', 149 | type=str, 150 | default='pig_body/valid_data', 151 | help='Directory of validation images to convert and write as TFRecords.' 152 | ) 153 | FLAGS, unparsed = parser.parse_known_args() 154 | tf.app.run(main=main, argv=[sys.argv[0]] + unparsed) 155 | 156 | 157 | -------------------------------------------------------------------------------- /preprocessing/cifarnet_preprocessing.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Provides utilities to preprocess images in CIFAR-10. 16 | 17 | """ 18 | 19 | from __future__ import absolute_import 20 | from __future__ import division 21 | from __future__ import print_function 22 | 23 | import tensorflow as tf 24 | 25 | _PADDING = 4 26 | 27 | slim = tf.contrib.slim 28 | 29 | 30 | def preprocess_for_train(image, 31 | output_height, 32 | output_width, 33 | padding=_PADDING, 34 | add_image_summaries=True): 35 | """Preprocesses the given image for training. 36 | 37 | Note that the image is first padded by `padding` pixels on each side and 38 | a random [`output_height`, `output_width`] crop is then taken. 39 | 40 | Args: 41 | image: A `Tensor` representing an image of arbitrary size. 42 | output_height: The height of the image after preprocessing. 43 | output_width: The width of the image after preprocessing. 44 | padding: The amount of padding before and after each dimension of the image. 45 | add_image_summaries: Enable image summaries. 46 | 47 | Returns: 48 | A preprocessed image. 49 | """ 50 | if add_image_summaries: 51 | tf.summary.image('image', tf.expand_dims(image, 0)) 52 | 53 | # Transform the image to floats. 54 | image = tf.to_float(image) 55 | if padding > 0: 56 | image = tf.pad(image, [[padding, padding], [padding, padding], [0, 0]]) 57 | # Randomly crop a [height, width] section of the image. 58 | distorted_image = tf.random_crop(image, 59 | [output_height, output_width, 3]) 60 | 61 | # Randomly flip the image horizontally. 62 | distorted_image = tf.image.random_flip_left_right(distorted_image) 63 | 64 | if add_image_summaries: 65 | tf.summary.image('distorted_image', tf.expand_dims(distorted_image, 0)) 66 | 67 | # Because these operations are not commutative, consider randomizing 68 | # the order of their operation. 69 | distorted_image = tf.image.random_brightness(distorted_image, 70 | max_delta=63) 71 | distorted_image = tf.image.random_contrast(distorted_image, 72 | lower=0.2, upper=1.8) 73 | # Subtract off the mean and divide by the standard deviation of the pixels. 74 | return tf.image.per_image_standardization(distorted_image) 75 | 76 | 77 | def preprocess_for_eval(image, output_height, output_width, 78 | add_image_summaries=True): 79 | """Preprocesses the given image for evaluation. 80 | 81 | Args: 82 | image: A `Tensor` representing an image of arbitrary size. 83 | output_height: The height of the image after preprocessing. 84 | output_width: The width of the image after preprocessing. 85 | add_image_summaries: Enable image summaries. 86 | 87 | Returns: 88 | A preprocessed image. 89 | """ 90 | if add_image_summaries: 91 | tf.summary.image('image', tf.expand_dims(image, 0)) 92 | # Transform the image to floats. 93 | image = tf.to_float(image) 94 | 95 | # Resize and crop if needed. 96 | resized_image = tf.image.resize_image_with_crop_or_pad(image, 97 | output_height, 98 | output_width) 99 | if add_image_summaries: 100 | tf.summary.image('resized_image', tf.expand_dims(resized_image, 0)) 101 | 102 | # Subtract off the mean and divide by the standard deviation of the pixels. 
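# (For reference: tf.image.per_image_standardization computes (x - mean) / adjusted_stddev, where adjusted_stddev = max(stddev, 1 / sqrt(num_pixels)), so the op stays well-defined even for constant images.)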
103 | return tf.image.per_image_standardization(resized_image) 104 | 105 | 106 | def preprocess_image(image, output_height, output_width, is_training=False, 107 | add_image_summaries=True): 108 | """Preprocesses the given image. 109 | 110 | Args: 111 | image: A `Tensor` representing an image of arbitrary size. 112 | output_height: The height of the image after preprocessing. 113 | output_width: The width of the image after preprocessing. 114 | is_training: `True` if we're preprocessing the image for training and 115 | `False` otherwise. 116 | add_image_summaries: Enable image summaries. 117 | 118 | Returns: 119 | A preprocessed image. 120 | """ 121 | if is_training: 122 | return preprocess_for_train( 123 | image, output_height, output_width, 124 | add_image_summaries=add_image_summaries) 125 | else: 126 | return preprocess_for_eval( 127 | image, output_height, output_width, 128 | add_image_summaries=add_image_summaries) 129 | -------------------------------------------------------------------------------- /datasets/dataset_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Contains utilities for downloading and converting datasets.""" 16 | from __future__ import absolute_import 17 | from __future__ import division 18 | from __future__ import print_function 19 | 20 | import os 21 | import sys 22 | import tarfile 23 | 24 | from six.moves import urllib 25 | import tensorflow as tf 26 | 27 | LABELS_FILENAME = 'labels.txt' 28 | 29 | 30 | def int64_feature(values): 31 | """Returns a TF-Feature of int64s. 32 | 33 | Args: 34 | values: A scalar or list of values. 35 | 36 | Returns: 37 | A TF-Feature. 38 | """ 39 | if not isinstance(values, (tuple, list)): 40 | values = [values] 41 | return tf.train.Feature(int64_list=tf.train.Int64List(value=values)) 42 | 43 | 44 | def bytes_feature(values): 45 | """Returns a TF-Feature of bytes. 46 | 47 | Args: 48 | values: A string. 49 | 50 | Returns: 51 | A TF-Feature. 52 | """ 53 | return tf.train.Feature(bytes_list=tf.train.BytesList(value=[values])) 54 | 55 | 56 | def float_feature(values): 57 | """Returns a TF-Feature of floats. 58 | 59 | Args: 60 | values: A scalar or list of values. 61 | 62 | Returns: 63 | A TF-Feature. 
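For example (illustrative), float_feature(0.5) and float_feature([0.5]) both produce a TF-Feature wrapping the FloatList [0.5].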
64 | """ 65 | if not isinstance(values, (tuple, list)): 66 | values = [values] 67 | return tf.train.Feature(float_list=tf.train.FloatList(value=values)) 68 | 69 | 70 | def image_to_tfexample(image_data, image_format, height, width, class_id): 71 | return tf.train.Example(features=tf.train.Features(feature={ 72 | 'image/encoded': bytes_feature(image_data), 73 | 'image/format': bytes_feature(image_format), 74 | 'image/class/label': int64_feature(class_id), 75 | 'image/height': int64_feature(height), 76 | 'image/width': int64_feature(width), 77 | })) 78 | 79 | 80 | def download_and_uncompress_tarball(tarball_url, dataset_dir): 81 | """Downloads the `tarball_url` and uncompresses it locally. 82 | 83 | Args: 84 | tarball_url: The URL of a tarball file. 85 | dataset_dir: The directory where the temporary files are stored. 86 | """ 87 | filename = tarball_url.split('/')[-1] 88 | filepath = os.path.join(dataset_dir, filename) 89 | 90 | def _progress(count, block_size, total_size): 91 | sys.stdout.write('\r>> Downloading %s %.1f%%' % ( 92 | filename, float(count * block_size) / float(total_size) * 100.0)) 93 | sys.stdout.flush() 94 | filepath, _ = urllib.request.urlretrieve(tarball_url, filepath, _progress) 95 | print() 96 | statinfo = os.stat(filepath) 97 | print('Successfully downloaded', filename, statinfo.st_size, 'bytes.') 98 | tarfile.open(filepath, 'r:gz').extractall(dataset_dir) 99 | 100 | 101 | def write_label_file(labels_to_class_names, dataset_dir, 102 | filename=LABELS_FILENAME): 103 | """Writes a file with the list of class names. 104 | 105 | Args: 106 | labels_to_class_names: A map of (integer) labels to class names. 107 | dataset_dir: The directory in which the labels file should be written. 108 | filename: The filename where the class names are written. 109 | """ 110 | labels_filename = os.path.join(dataset_dir, filename) 111 | with tf.gfile.Open(labels_filename, 'w') as f: 112 | for label in labels_to_class_names: 113 | class_name = labels_to_class_names[label] 114 | f.write('%d:%s\n' % (label, class_name)) 115 | 116 | 117 | def has_labels(dataset_dir, filename=LABELS_FILENAME): 118 | """Specifies whether or not the dataset directory contains a label map file. 119 | 120 | Args: 121 | dataset_dir: The directory in which the labels file is found. 122 | filename: The filename where the class names are written. 123 | 124 | Returns: 125 | `True` if the labels file exists and `False` otherwise. 126 | """ 127 | return tf.gfile.Exists(os.path.join(dataset_dir, filename)) 128 | 129 | 130 | def read_label_file(dataset_dir, filename=LABELS_FILENAME): 131 | """Reads the labels file and returns a mapping from ID to class name. 132 | 133 | Args: 134 | dataset_dir: The directory in which the labels file is found. 135 | filename: The filename where the class names are written. 136 | 137 | Returns: 138 | A map from a label (integer) to class name. 
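For example (illustrative class names), a labels file written by write_label_file above containing the lines '0:pig_01' and '1:pig_02' is returned as {0: 'pig_01', 1: 'pig_02'}.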
139 | """ 140 | labels_filename = os.path.join(dataset_dir, filename) 141 | with tf.gfile.Open(labels_filename, 'rb') as f: 142 | lines = f.read().decode() 143 | lines = lines.split('\n') 144 | lines = filter(None, lines) 145 | 146 | labels_to_class_names = {} 147 | for line in lines: 148 | index = line.index(':') 149 | labels_to_class_names[int(line[:index])] = line[index+1:] 150 | return labels_to_class_names 151 | -------------------------------------------------------------------------------- /scripts/export_mobilenet.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # ============================================================================== 16 | 17 | # This script prepares the various different versions of MobileNet models for 18 | # use in a mobile application. If you don't specify your own trained checkpoint 19 | # file, it will download pretrained checkpoints for ImageNet. You'll also need 20 | # to have a copy of the TensorFlow source code to run some of the commands, 21 | # by default it will be looked for in ./tensorflow, but you can set the 22 | # TENSORFLOW_PATH environment variable before calling the script if your source 23 | # is in a different location. 24 | # The main slim/nets/mobilenet_v1.md description has more details about the 25 | # model, but the main points are that it comes in four size versions, 1.0, 0.75, 26 | # 0.50, and 0.25, which controls the number of parameters and so the file size 27 | # of the model, and the input image size, which can be 224, 192, 160, or 128 28 | # pixels, and affects the amount of computation needed, and the latency. 29 | # Here's an example generating a frozen model from pretrained weights: 30 | # 31 | 32 | set -e 33 | 34 | print_usage () { 35 | echo "Creates a frozen mobilenet model suitable for mobile use" 36 | echo "Usage:" 37 | echo "$0 <mobilenet version> <image size> [checkpoint path]" 38 | } 39 | 40 | MOBILENET_VERSION=$1 41 | IMAGE_SIZE=$2 42 | CHECKPOINT=$3 43 | 44 | if [[ ${MOBILENET_VERSION} = "1.0" ]]; then 45 | SLIM_NAME=mobilenet_v1 46 | elif [[ ${MOBILENET_VERSION} = "0.75" ]]; then 47 | SLIM_NAME=mobilenet_v1_075 48 | elif [[ ${MOBILENET_VERSION} = "0.50" ]]; then 49 | SLIM_NAME=mobilenet_v1_050 50 | elif [[ ${MOBILENET_VERSION} = "0.25" ]]; then 51 | SLIM_NAME=mobilenet_v1_025 52 | else 53 | echo "Bad mobilenet version, should be one of 1.0, 0.75, 0.50, or 0.25" 54 | print_usage 55 | exit 1 56 | fi 57 | 58 | if [[ ${IMAGE_SIZE} -ne "224" ]] && [[ ${IMAGE_SIZE} -ne "192" ]] && [[ ${IMAGE_SIZE} -ne "160" ]] && [[ ${IMAGE_SIZE} -ne "128" ]]; then 59 | echo "Bad input image size, should be one of 224, 192, 160, or 128" 60 | print_usage 61 | exit 1 62 | fi 63 | 64 | if [[ "${TENSORFLOW_PATH}" = "" ]]; then # string comparison (was -eq, an arithmetic test that misbehaves on strings) 65 | TENSORFLOW_PATH=../tensorflow 66 | fi 67 | 68 | if [[ ! -d ${TENSORFLOW_PATH} ]]; then 69 | echo "TensorFlow source folder not found. 
You should download the source and then set" 70 | echo "the TENSORFLOW_PATH environment variable to point to it, like this:" 71 | echo "export TENSORFLOW_PATH=/my/path/to/tensorflow" 72 | print_usage 73 | exit 1 74 | fi 75 | 76 | MODEL_FOLDER=/tmp/mobilenet_v1_${MOBILENET_VERSION}_${IMAGE_SIZE} 77 | if [[ -d ${MODEL_FOLDER} ]]; then 78 | echo "Model folder ${MODEL_FOLDER} already exists!" 79 | echo "If you want to overwrite it, then 'rm -rf ${MODEL_FOLDER}' first." 80 | print_usage 81 | exit 1 82 | fi 83 | mkdir ${MODEL_FOLDER} 84 | 85 | if [[ ${CHECKPOINT} = "" ]]; then 86 | echo "*******" 87 | echo "Downloading pretrained weights" 88 | echo "*******" 89 | curl "http://download.tensorflow.org/models/mobilenet_v1_${MOBILENET_VERSION}_${IMAGE_SIZE}_2017_06_14.tar.gz" \ 90 | -o ${MODEL_FOLDER}/checkpoints.tar.gz 91 | tar xzf ${MODEL_FOLDER}/checkpoints.tar.gz --directory ${MODEL_FOLDER} 92 | CHECKPOINT=${MODEL_FOLDER}/mobilenet_v1_${MOBILENET_VERSION}_${IMAGE_SIZE}.ckpt 93 | fi 94 | 95 | echo "*******" 96 | echo "Exporting graph architecture to ${MODEL_FOLDER}/unfrozen_graph.pb" 97 | echo "*******" 98 | bazel run slim:export_inference_graph -- \ 99 | --model_name=${SLIM_NAME} --image_size=${IMAGE_SIZE} --logtostderr \ 100 | --output_file=${MODEL_FOLDER}/unfrozen_graph.pb --dataset_dir=${MODEL_FOLDER} 101 | 102 | cd ../tensorflow 103 | 104 | echo "*******" 105 | echo "Freezing graph to ${MODEL_FOLDER}/frozen_graph.pb" 106 | echo "*******" 107 | bazel run tensorflow/python/tools:freeze_graph -- \ 108 | --input_graph=${MODEL_FOLDER}/unfrozen_graph.pb \ 109 | --input_checkpoint=${CHECKPOINT} \ 110 | --input_binary=true --output_graph=${MODEL_FOLDER}/frozen_graph.pb \ 111 | --output_node_names=MobilenetV1/Predictions/Reshape_1 112 | 113 | echo "Quantizing weights to ${MODEL_FOLDER}/quantized_graph.pb" 114 | bazel run tensorflow/tools/graph_transforms:transform_graph -- \ 115 | --in_graph=${MODEL_FOLDER}/frozen_graph.pb \ 116 | --out_graph=${MODEL_FOLDER}/quantized_graph.pb \ 117 | --inputs=input --outputs=MobilenetV1/Predictions/Reshape_1 \ 118 | --transforms='fold_constants fold_batch_norms quantize_weights' 119 | 120 | echo "*******" 121 | echo "Running label_image using the graph" 122 | echo "*******" 123 | bazel build tensorflow/examples/label_image:label_image 124 | bazel-bin/tensorflow/examples/label_image/label_image \ 125 | --input_layer=input --output_layer=MobilenetV1/Predictions/Reshape_1 \ 126 | --graph=${MODEL_FOLDER}/quantized_graph.pb --input_mean=-127 --input_std=127 \ 127 | --image=tensorflow/examples/label_image/data/grace_hopper.jpg \ 128 | --input_width=${IMAGE_SIZE} --input_height=${IMAGE_SIZE} --labels=${MODEL_FOLDER}/labels.txt 129 | 130 | echo "*******" 131 | echo "Saved graphs to ${MODEL_FOLDER}/frozen_graph.pb and ${MODEL_FOLDER}/quantized_graph.pb" 132 | echo "*******" 133 | -------------------------------------------------------------------------------- /pre_crop.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Fri Dec 8 09:38:39 2017 4 | 5 | @author: DELL 6 | """ 7 | 8 | import tensorflow as tf 9 | import scipy.misc as misc 10 | import matplotlib.pylab as plt 11 | 12 | def preprocess_for_eval_beifen(image, height, width, 13 | central_fraction=0.875, scope=None): 14 | """Prepare one image for evaluation. 15 | 16 | If height and width are specified it would output an image with that size by 17 | applying resize_bilinear. 
18 | 19 | If central_fraction is specified it would crop the central fraction of the 20 | input image. 21 | 22 | Args: 23 | image: 3-D Tensor of image. If dtype is tf.float32 then the range should be 24 | [0, 1], otherwise it would converted to tf.float32 assuming that the range 25 | is [0, MAX], where MAX is largest positive representable number for 26 | int(8/16/32) data type (see `tf.image.convert_image_dtype` for details). 27 | height: integer 28 | width: integer 29 | central_fraction: Optional Float, fraction of the image to crop. 30 | scope: Optional scope for name_scope. 31 | Returns: 32 | 3-D float Tensor of prepared image. 33 | """ 34 | with tf.name_scope(scope, 'eval_image', [image, height, width]): 35 | if image.dtype != tf.float32: 36 | image = tf.image.convert_image_dtype(image, dtype=tf.float32) 37 | # Crop the central region of the image with an area containing 87.5% of 38 | # the original image. 39 | 40 | if central_fraction: 41 | image = tf.image.central_crop(image, central_fraction=central_fraction) 42 | 43 | # if height and width: 44 | # # Resize the image to the specified height and width. 45 | # image = tf.expand_dims(image, 0) 46 | # image = tf.image.resize_bilinear(image, [height, width], align_corners = False) 47 | # image = tf.squeeze(image, [0]) 48 | ## image = tf.subtract(image, 0.5) 49 | ## image = tf.multiply(image, 2.0) 50 | image_1 = tf.image.convert_image_dtype(image, dtype=tf.uint8) 51 | return image_1 52 | 53 | def distorted_bounding_box_crop(image, 54 | bbox=None, 55 | min_object_covered=0.1, 56 | aspect_ratio_range=(0.75, 1.33), 57 | area_range=(0.05, 1.0), 58 | max_attempts=100, 59 | scope=None): 60 | """Generates cropped_image using a one of the bboxes randomly distorted. 61 | 62 | """ 63 | with tf.name_scope(scope, 'distorted_bounding_box_crop', [image, bbox]): 64 | # Each bounding box has shape [1, num_boxes, box coords] and 65 | # the coordinates are ordered [ymin, xmin, ymax, xmax]. 66 | 67 | # A large fraction of image datasets contain a human-annotated bounding 68 | # box delineating the region of the image containing the object of interest. 69 | # We choose to create a new bounding box for the object which is a randomly 70 | # distorted version of the human-annotated bounding box that obeys an 71 | # allowed range of aspect ratios, sizes and overlap with the human-annotated 72 | # bounding box. If no box is supplied, then we assume the bounding box is 73 | # the entire image. 74 | sample_distorted_bounding_box = tf.image.sample_distorted_bounding_box( 75 | tf.shape(image), 76 | bounding_boxes=bbox, 77 | min_object_covered=min_object_covered, 78 | aspect_ratio_range=aspect_ratio_range, 79 | area_range=area_range, 80 | max_attempts=max_attempts, 81 | use_image_if_no_bounding_boxes=True) 82 | bbox_begin, bbox_size, distort_bbox = sample_distorted_bounding_box 83 | 84 | # Crop the image to the specified bounding box. 85 | cropped_image = tf.slice(image, bbox_begin, bbox_size) 86 | return cropped_image, distort_bbox 87 | 88 | def preprocess_for_train(image, height, width, bbox=None, 89 | fast_mode=True, 90 | scope=None, 91 | add_image_summaries=True): 92 | bbox = tf.constant([0.0, 0.0, 1.0, 1.0], 93 | dtype=tf.float32, 94 | shape=[1, 1, 4]) 95 | if image.dtype != tf.float32: 96 | image = tf.image.convert_image_dtype(image, dtype=tf.float32) 97 | # Each bounding box has shape [1, num_boxes, box coords] and 98 | # the coordinates are ordered [ymin, xmin, ymax, xmax]. 
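# (For reference: tf.image.draw_bounding_boxes expects a batch of images, hence the tf.expand_dims to [1, height, width, channels] below; the box coordinates are normalized to [0, 1] relative to the image size.)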
99 | image_with_box = tf.image.draw_bounding_boxes(tf.expand_dims(image, 0), 100 | bbox) 101 | distorted_image, distorted_bbox = distorted_bounding_box_crop(image, bbox) 102 | # Restore the shape since the dynamic slice based upon the bbox_size loses 103 | # the third dimension. 104 | distorted_image.set_shape([None, None, 3]) 105 | image_with_distorted_box = tf.image.draw_bounding_boxes( 106 | tf.expand_dims(image, 0), distorted_bbox) 107 | return image_with_box, image_with_distorted_box 108 | 109 | if __name__ == '__main__': 110 | graph_a = tf.Graph() 111 | with graph_a.as_default(): 112 | img = plt.imread('D:/pig_recognize/pig_slim1/pig_test/00031.JPG') 113 | 114 | img_input = tf.placeholder(dtype=tf.float32 ) 115 | # image = preprocess_for_eval_beifen(img_input, 299, 299) 116 | image1, image2 = preprocess_for_train(img_input, height=299, width=299) 117 | 118 | init = tf.global_variables_initializer() 119 | 120 | with tf.Session() as sess: 121 | sess.run(init) 122 | image1_, image2_ = sess.run([image1, image2], feed_dict={img_input:img}) 123 | image_1 = image1_[0] 124 | image_2 = image2_[0] 125 | plt.imsave('1.jpg',image_1) 126 | plt.imsave('2.jpg',image_2) 127 | 128 | 129 | 130 | 131 | 132 | 133 | -------------------------------------------------------------------------------- /nets/overfeat.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Contains the model definition for the OverFeat network. 
16 | 17 | The definition for the network was obtained from: 18 | OverFeat: Integrated Recognition, Localization and Detection using 19 | Convolutional Networks 20 | Pierre Sermanet, David Eigen, Xiang Zhang, Michael Mathieu, Rob Fergus and 21 | Yann LeCun, 2014 22 | http://arxiv.org/abs/1312.6229 23 | 24 | Usage: 25 | with slim.arg_scope(overfeat.overfeat_arg_scope()): 26 | outputs, end_points = overfeat.overfeat(inputs) 27 | 28 | @@overfeat 29 | """ 30 | from __future__ import absolute_import 31 | from __future__ import division 32 | from __future__ import print_function 33 | 34 | import tensorflow as tf 35 | 36 | slim = tf.contrib.slim 37 | trunc_normal = lambda stddev: tf.truncated_normal_initializer(0.0, stddev) 38 | 39 | 40 | def overfeat_arg_scope(weight_decay=0.0005): 41 | with slim.arg_scope([slim.conv2d, slim.fully_connected], 42 | activation_fn=tf.nn.relu, 43 | weights_regularizer=slim.l2_regularizer(weight_decay), 44 | biases_initializer=tf.zeros_initializer()): 45 | with slim.arg_scope([slim.conv2d], padding='SAME'): 46 | with slim.arg_scope([slim.max_pool2d], padding='VALID') as arg_sc: 47 | return arg_sc 48 | 49 | 50 | def overfeat(inputs, 51 | num_classes=1000, 52 | is_training=True, 53 | dropout_keep_prob=0.5, 54 | spatial_squeeze=True, 55 | scope='overfeat', 56 | global_pool=False): 57 | """Contains the model definition for the OverFeat network. 58 | 59 | The definition for the network was obtained from: 60 | OverFeat: Integrated Recognition, Localization and Detection using 61 | Convolutional Networks 62 | Pierre Sermanet, David Eigen, Xiang Zhang, Michael Mathieu, Rob Fergus and 63 | Yann LeCun, 2014 64 | http://arxiv.org/abs/1312.6229 65 | 66 | Note: All the fully_connected layers have been transformed to conv2d layers. 67 | To use in classification mode, resize input to 231x231. To use in fully 68 | convolutional mode, set spatial_squeeze to false. 69 | 70 | Args: 71 | inputs: a tensor of size [batch_size, height, width, channels]. 72 | num_classes: number of predicted classes. If 0 or None, the logits layer is 73 | omitted and the input features to the logits layer are returned instead. 74 | is_training: whether or not the model is being trained. 75 | dropout_keep_prob: the probability that activations are kept in the dropout 76 | layers during training. 77 | spatial_squeeze: whether or not should squeeze the spatial dimensions of the 78 | outputs. Useful to remove unnecessary dimensions for classification. 79 | scope: Optional scope for the variables. 80 | global_pool: Optional boolean flag. If True, the input to the classification 81 | layer is avgpooled to size 1x1, for any input size. (This is not part 82 | of the original OverFeat.) 83 | 84 | Returns: 85 | net: the output of the logits layer (if num_classes is a non-zero integer), 86 | or the non-dropped-out input to the logits layer (if num_classes is 0 or 87 | None). 88 | end_points: a dict of tensors with intermediate activations. 
89 | """ 90 | with tf.variable_scope(scope, 'overfeat', [inputs]) as sc: 91 | end_points_collection = sc.original_name_scope + '_end_points' 92 | # Collect outputs for conv2d, fully_connected and max_pool2d 93 | with slim.arg_scope([slim.conv2d, slim.fully_connected, slim.max_pool2d], 94 | outputs_collections=end_points_collection): 95 | net = slim.conv2d(inputs, 64, [11, 11], 4, padding='VALID', 96 | scope='conv1') 97 | net = slim.max_pool2d(net, [2, 2], scope='pool1') 98 | net = slim.conv2d(net, 256, [5, 5], padding='VALID', scope='conv2') 99 | net = slim.max_pool2d(net, [2, 2], scope='pool2') 100 | net = slim.conv2d(net, 512, [3, 3], scope='conv3') 101 | net = slim.conv2d(net, 1024, [3, 3], scope='conv4') 102 | net = slim.conv2d(net, 1024, [3, 3], scope='conv5') 103 | net = slim.max_pool2d(net, [2, 2], scope='pool5') 104 | 105 | # Use conv2d instead of fully_connected layers. 106 | with slim.arg_scope([slim.conv2d], 107 | weights_initializer=trunc_normal(0.005), 108 | biases_initializer=tf.constant_initializer(0.1)): 109 | net = slim.conv2d(net, 3072, [6, 6], padding='VALID', scope='fc6') 110 | net = slim.dropout(net, dropout_keep_prob, is_training=is_training, 111 | scope='dropout6') 112 | net = slim.conv2d(net, 4096, [1, 1], scope='fc7') 113 | # Convert end_points_collection into a end_point dict. 114 | end_points = slim.utils.convert_collection_to_dict( 115 | end_points_collection) 116 | if global_pool: 117 | net = tf.reduce_mean(net, [1, 2], keep_dims=True, name='global_pool') 118 | end_points['global_pool'] = net 119 | if num_classes: 120 | net = slim.dropout(net, dropout_keep_prob, is_training=is_training, 121 | scope='dropout7') 122 | net = slim.conv2d(net, num_classes, [1, 1], 123 | activation_fn=None, 124 | normalizer_fn=None, 125 | biases_initializer=tf.zeros_initializer(), 126 | scope='fc8') 127 | if spatial_squeeze: 128 | net = tf.squeeze(net, [1, 2], name='fc8/squeezed') 129 | end_points[sc.name + '/fc8'] = net 130 | return net, end_points 131 | overfeat.default_image_size = 231 132 | -------------------------------------------------------------------------------- /nets/alexnet.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Contains a model definition for AlexNet. 16 | 17 | This work was first described in: 18 | ImageNet Classification with Deep Convolutional Neural Networks 19 | Alex Krizhevsky, Ilya Sutskever and Geoffrey E. Hinton 20 | 21 | and later refined in: 22 | One weird trick for parallelizing convolutional neural networks 23 | Alex Krizhevsky, 2014 24 | 25 | Here we provide the implementation proposed in "One weird trick" and not 26 | "ImageNet Classification", as per the paper, the LRN layers have been removed. 
27 | 28 | Usage: 29 | with slim.arg_scope(alexnet.alexnet_v2_arg_scope()): 30 | outputs, end_points = alexnet.alexnet_v2(inputs) 31 | 32 | @@alexnet_v2 33 | """ 34 | 35 | from __future__ import absolute_import 36 | from __future__ import division 37 | from __future__ import print_function 38 | 39 | import tensorflow as tf 40 | 41 | slim = tf.contrib.slim 42 | trunc_normal = lambda stddev: tf.truncated_normal_initializer(0.0, stddev) 43 | 44 | 45 | def alexnet_v2_arg_scope(weight_decay=0.0005): 46 | with slim.arg_scope([slim.conv2d, slim.fully_connected], 47 | activation_fn=tf.nn.relu, 48 | biases_initializer=tf.constant_initializer(0.1), 49 | weights_regularizer=slim.l2_regularizer(weight_decay)): 50 | with slim.arg_scope([slim.conv2d], padding='SAME'): 51 | with slim.arg_scope([slim.max_pool2d], padding='VALID') as arg_sc: 52 | return arg_sc 53 | 54 | 55 | def alexnet_v2(inputs, 56 | num_classes=1000, 57 | is_training=True, 58 | dropout_keep_prob=0.5, 59 | spatial_squeeze=True, 60 | scope='alexnet_v2', 61 | global_pool=False): 62 | """AlexNet version 2. 63 | 64 | Described in: http://arxiv.org/pdf/1404.5997v2.pdf 65 | Parameters from: 66 | github.com/akrizhevsky/cuda-convnet2/blob/master/layers/ 67 | layers-imagenet-1gpu.cfg 68 | 69 | Note: All the fully_connected layers have been transformed to conv2d layers. 70 | To use in classification mode, resize input to 224x224 or set 71 | global_pool=True. To use in fully convolutional mode, set 72 | spatial_squeeze to false. 73 | The LRN layers have been removed and change the initializers from 74 | random_normal_initializer to xavier_initializer. 75 | 76 | Args: 77 | inputs: a tensor of size [batch_size, height, width, channels]. 78 | num_classes: the number of predicted classes. If 0 or None, the logits layer 79 | is omitted and the input features to the logits layer are returned instead. 80 | is_training: whether or not the model is being trained. 81 | dropout_keep_prob: the probability that activations are kept in the dropout 82 | layers during training. 83 | spatial_squeeze: whether or not should squeeze the spatial dimensions of the 84 | logits. Useful to remove unnecessary dimensions for classification. 85 | scope: Optional scope for the variables. 86 | global_pool: Optional boolean flag. If True, the input to the classification 87 | layer is avgpooled to size 1x1, for any input size. (This is not part 88 | of the original AlexNet.) 89 | 90 | Returns: 91 | net: the output of the logits layer (if num_classes is a non-zero integer), 92 | or the non-dropped-out input to the logits layer (if num_classes is 0 93 | or None). 94 | end_points: a dict of tensors with intermediate activations. 95 | """ 96 | with tf.variable_scope(scope, 'alexnet_v2', [inputs]) as sc: 97 | end_points_collection = sc.original_name_scope + '_end_points' 98 | # Collect outputs for conv2d, fully_connected and max_pool2d. 
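# (For reference: every slim layer created inside the arg_scope below appends its output tensor to this collection via outputs_collections; slim.utils.convert_collection_to_dict turns the collection into the end_points dict near the end of the function.)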
99 | with slim.arg_scope([slim.conv2d, slim.fully_connected, slim.max_pool2d], 100 | outputs_collections=[end_points_collection]): 101 | net = slim.conv2d(inputs, 64, [11, 11], 4, padding='VALID', 102 | scope='conv1') 103 | net = slim.max_pool2d(net, [3, 3], 2, scope='pool1') 104 | net = slim.conv2d(net, 192, [5, 5], scope='conv2') 105 | net = slim.max_pool2d(net, [3, 3], 2, scope='pool2') 106 | net = slim.conv2d(net, 384, [3, 3], scope='conv3') 107 | net = slim.conv2d(net, 384, [3, 3], scope='conv4') 108 | net = slim.conv2d(net, 256, [3, 3], scope='conv5') 109 | net = slim.max_pool2d(net, [3, 3], 2, scope='pool5') 110 | 111 | # Use conv2d instead of fully_connected layers. 112 | with slim.arg_scope([slim.conv2d], 113 | weights_initializer=trunc_normal(0.005), 114 | biases_initializer=tf.constant_initializer(0.1)): 115 | net = slim.conv2d(net, 4096, [5, 5], padding='VALID', 116 | scope='fc6') 117 | net = slim.dropout(net, dropout_keep_prob, is_training=is_training, 118 | scope='dropout6') 119 | net = slim.conv2d(net, 4096, [1, 1], scope='fc7') 120 | # Convert end_points_collection into a end_point dict. 121 | end_points = slim.utils.convert_collection_to_dict( 122 | end_points_collection) 123 | if global_pool: 124 | net = tf.reduce_mean(net, [1, 2], keep_dims=True, name='global_pool') 125 | end_points['global_pool'] = net 126 | if num_classes: 127 | net = slim.dropout(net, dropout_keep_prob, is_training=is_training, 128 | scope='dropout7') 129 | net = slim.conv2d(net, num_classes, [1, 1], 130 | activation_fn=None, 131 | normalizer_fn=None, 132 | biases_initializer=tf.zeros_initializer(), 133 | scope='fc8') 134 | if spatial_squeeze: 135 | net = tf.squeeze(net, [1, 2], name='fc8/squeezed') 136 | end_points[sc.name + '/fc8'] = net 137 | return net, end_points 138 | alexnet_v2.default_image_size = 224 139 | -------------------------------------------------------------------------------- /pig_vgg16/pig_model_vgg.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Thu Nov 23 15:07:52 2017 4 | 5 | @author: DELL 6 | """ 7 | 8 | #!/usr/bin/env python3 9 | # -*- coding: utf-8 -*- 10 | """ 11 | Created on Sat Sep 30 16:16:13 2017 12 | 13 | @author: no1 14 | """ 15 | 16 | from __future__ import absolute_import 17 | from __future__ import division 18 | from __future__ import print_function 19 | 20 | import tensorflow as tf 21 | import pig_input 22 | import config 23 | 24 | IMAGE_WIDTH = config.IMAGE_WIDTH 25 | IMAGE_HEIGHT = config.IMAGE_HEIGHT 26 | CLASSES_NUM = config.CLASSES_NUM 27 | rate = 0.01 28 | 29 | def inputs(train, batch_size): 30 | return pig_input.inputs(train, batch_size=batch_size) 31 | 32 | def _conv(name, input, size, input_channels, output_channels, is_training=True): 33 | with tf.variable_scope(name) as scope: 34 | if not is_training: 35 | scope.reuse_variables() 36 | kernel = _weight_variable('weights', shape=[size, size ,input_channels, output_channels]) 37 | biases = _bias_variable('biases',[output_channels]) 38 | pre_activation = tf.nn.bias_add(_conv2d(input, kernel),biases) 39 | conv = tf.maximum(rate*pre_activation,pre_activation, name=scope.name) 40 | return conv 41 | 42 | def _conv2d(value, weight): 43 | """conv2d returns a 2d convolution layer with full stride.""" 44 | return tf.nn.conv2d(value, weight, strides=[1, 1, 1, 1], padding='SAME') 45 | 46 | 47 | def _max_pool_2x2(value, name, is_training): 48 | """max_pool_2x2 downsamples a feature map by 2X.""" 49 | with tf.variable_scope(name) as 
scope1: 50 | if not is_training: 51 | scope1.reuse_variables() 52 | return tf.nn.max_pool(value, ksize=[1, 2, 2, 1], 53 | strides=[1, 2, 2, 1], padding='SAME', name=name) 54 | 55 | 56 | def _weight_variable(name, shape): 57 | """weight_variable generates a weight variable of a given shape.""" 58 | initializer = tf.truncated_normal_initializer(stddev=0.1) 59 | var = tf.get_variable(name,shape,initializer=initializer, dtype=tf.float32) 60 | return var 61 | 62 | 63 | def _bias_variable(name, shape): 64 | """bias_variable generates a bias variable of a given shape.""" 65 | initializer = tf.constant_initializer(0.1) 66 | var = tf.get_variable(name, shape, initializer=initializer,dtype=tf.float32) 67 | return var 68 | 69 | def _batch_norm(name, inputs, is_training): 70 | """ Batch Normalization 71 | """ 72 | with tf.variable_scope(name, reuse = not is_training): 73 | # return tf.layers.batch_normalization(input,training=is_training) 74 | return tf.contrib.layers.batch_norm(inputs, 75 | decay=0.9, 76 | scale=True, 77 | updates_collections=None, 78 | is_training=is_training) # was hard-coded to True; eval should use the moving statistics 79 | def inference(images, keep_prob, is_training): 80 | images = tf.reshape(images, [-1, IMAGE_HEIGHT, IMAGE_WIDTH, 3]) # 299,299,3 81 | #%% vgg16 82 | # conv1_1 83 | 84 | conv1_1 = _conv('conv1_1', images, 3, 3, 64,is_training) #(batch,299,299,64) 85 | conv1_2 = _conv('conv1_2', conv1_1,3, 64,64,is_training)#(batch,299,299,64) 86 | pool1 = _max_pool_2x2(conv1_2,'pool1',is_training) 87 | 88 | 89 | 90 | conv2_1 = _conv('conv2_1', pool1, 3,64,128,is_training) #(batch,150,150,128) 91 | conv2_2 = _conv('conv2_2', conv2_1,3,128,128,is_training) #(batch,150,150,128) 92 | pool2 = _max_pool_2x2(conv2_2, 'pool2',is_training) 93 | 94 | 95 | conv3_1 = _conv('conv3_1', pool2, 3, 128, 256,is_training) #(batch,75,75,256) 96 | conv3_2 = _conv('conv3_2',conv3_1, 3, 256, 256,is_training)#(batch,75,75,256) 97 | conv3_3 = _conv('conv3_3',conv3_2, 3, 256, 256,is_training)#(batch,75,75,256) 98 | pool3 = _max_pool_2x2(conv3_3, 'pool3',is_training) 99 | 100 | 101 | conv4_1 = _conv('conv4_1',pool3, 3, 256, 512,is_training) 102 | conv4_2 = _conv('conv4_2',conv4_1, 3, 512, 512,is_training) 103 | conv4_3 = _conv('conv4_3',conv4_2, 3, 512, 512,is_training) 104 | pool4 = _max_pool_2x2(conv4_3, 'pool4',is_training) 105 | 106 | conv5_1 = _conv('conv5_1',pool4, 3, 512, 512,is_training) 107 | conv5_2 = _conv('conv5_2',conv5_1, 3, 512, 512,is_training) 108 | conv5_3 = _conv('conv5_3',conv5_2, 3, 512, 128,is_training) 109 | pool5 = _max_pool_2x2(conv5_3, 'pool5',is_training) #(batch,10,10,128) 110 | norm = _batch_norm('norm', pool5, is_training) 111 | #%% 112 | with tf.variable_scope('local1') as scope14: 113 | if not is_training: 114 | scope14.reuse_variables() 115 | tensor_shape = norm.get_shape().as_list() 116 | reshape = tf.reshape(norm, [-1, tensor_shape[1]*tensor_shape[2]*tensor_shape[3]]) 117 | dim = reshape.get_shape()[1].value 118 | weights = _weight_variable('weights', shape=[dim,1024]) 119 | biases = _bias_variable('biases',[1024]) 120 | local1 = tf.nn.relu(tf.matmul(reshape,weights) + biases, name=scope14.name) 121 | 122 | local1_drop = tf.nn.dropout(local1, keep_prob) 123 | 124 | with tf.variable_scope('softmax_linear') as scope15: 125 | if not is_training: 126 | scope15.reuse_variables() 127 | weights = _weight_variable('weights',shape=[1024,CLASSES_NUM]) 128 | biases = _bias_variable('biases',[CLASSES_NUM]) 129 | softmax_linear = tf.add(tf.matmul(local1_drop,weights), biases, name=scope15.name) 130 | 131 | return tf.reshape(softmax_linear, [-1, 
CLASSES_NUM]) 132 | 133 | 134 | def loss(logits, labels): 135 | cross_entropy = tf.nn.softmax_cross_entropy_with_logits( 136 | labels=labels, logits=logits, name='cross_entropy_per_example') 137 | cross_entropy_mean = tf.reduce_mean(cross_entropy, name='cross_entropy') 138 | tf.add_to_collection('losses', cross_entropy_mean) 139 | return tf.add_n(tf.get_collection('losses'), name='total_loss') 140 | 141 | 142 | def training(loss): 143 | optimizer = tf.train.AdamOptimizer(1e-4) 144 | gen_grads_and_vars = optimizer.compute_gradients(loss) 145 | gen_train = optimizer.apply_gradients(gen_grads_and_vars) 146 | ema = tf.train.ExponentialMovingAverage(decay=0.99) 147 | update_losses = ema.apply([loss]) 148 | 149 | global_step = tf.contrib.framework.get_or_create_global_step() 150 | incr_global_step = tf.assign(global_step, global_step+1) 151 | 152 | return tf.group(update_losses, incr_global_step, gen_train) 153 | 154 | 155 | 156 | def evaluation(logits, labels): 157 | correct_prediction = tf.equal(tf.argmax(logits,1), tf.argmax(labels,1)) 158 | return tf.reduce_sum(tf.cast(correct_prediction, tf.float32)) 159 | 160 | 161 | def output(logits): 162 | return tf.nn.softmax(logits) 163 | 164 | def predict(logits): 165 | return tf.argmax(logits, 1) -------------------------------------------------------------------------------- /pig_vgg16/pig_train_1.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import argparse 6 | import sys 7 | import tensorflow as tf 8 | import pig_model as captcha 9 | import logging 10 | import numpy as np 11 | import glob 12 | import matplotlib.pylab as plt 13 | import math 14 | import config 15 | import os 16 | from datetime import datetime 17 | ''' 18 | Train on the data one pig at a time. 19 | ''' 20 | learning_rate = 2e-4 21 | epoch = 100 22 | batch = 16 23 | class_num = 30 24 | FLAGS = None 25 | IMAGE_WIDTH = 320 26 | IMAGE_HEIGHT = config.IMAGE_HEIGHT 27 | 28 | checkpoint_dir = 'ckpt' 29 | checkpoint_file = os.path.join(checkpoint_dir, 'model.ckpt') 30 | train_dir='summary2' 31 | 32 | #def initLogging(logFilename='record.log'):# """Init for logging 33 | # """ 34 | # logging.basicConfig( 35 | # level = logging.DEBUG, 36 | # format='%(asctime)s-%(levelname)s-%(message)s', 37 | # datefmt = '%y-%m-%d %H:%M', 38 | # filename = logFilename, 39 | # filemode = 'w'); 40 | # console = logging.StreamHandler() 41 | # console.setLevel(logging.INFO) 42 | # formatter = logging.Formatter('%(asctime)s-%(levelname)s-%(message)s') 43 | # console.setFormatter(formatter) 44 | # logging.getLogger('').addHandler(console) 45 | #initLogging() 46 | 47 | def inputs(epoch,batch,img_all,lab_all): 48 | for i in range(epoch): 49 | shuf_num = np.random.permutation(1) 50 | # shuf_num = 0 51 | for i in shuf_num: 52 | lab_all1 = lab_all[shuf_num[i]] 53 | img_all1 = img_all[shuf_num[i]] 54 | for imgs in img_all1: 55 | batch_per_epoch = math.ceil(imgs.shape[0] / batch) 56 | for b in range(batch_per_epoch): 57 | if (b*batch+batch)>imgs.shape[0]: 58 | m,n = b*batch, imgs.shape[0] 59 | else: 60 | m,n = b*batch, b*batch+batch 61 | 62 | x_batch, label_batch = img_all1[m:n,:], lab_all1[m:n,:] 63 | yield x_batch, label_batch 64 | 65 | def label_to_one_hot(label): 66 | one_hot_label = np.zeros([1,class_num]) 67 | one_hot_label[:,label] = 1.0 68 | return one_hot_label.astype(np.uint8) # one-hot row of shape (1, class_num) 69 | 70 | 71 | def run_train(): 72 | """Train the pig classifier for a number of steps."""
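# (Overview: the block below loads up to ~100 images for the first pig only (range(1)), flattens and normalizes them to [-1, 1], and then trains from feed_dict mini-batches produced by inputs() above.)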
73 | lab_all = [] 74 | img_all = [] 75 | for i in range(1): # NOTE: only the first pig is loaded here; extend the range to cover all 30 classes 76 | file = glob.glob('D:/pig_recognize_body/pig_body/train_data/' + str(i+1) + '_*.jpg') 77 | images = plt.imread(file[0]) 78 | images = images[(np.arange(IMAGE_WIDTH)*images.shape[0])//IMAGE_WIDTH][:, (np.arange(IMAGE_WIDTH)*images.shape[1])//IMAGE_WIDTH] # nearest-neighbour resize (plt.resize is np.resize, which tiles pixels rather than resizing) 79 | images = np.reshape(images * (1. / 127.5) - 1, [-1,IMAGE_WIDTH*IMAGE_WIDTH*3]) # normalize to (-1,1) like the images appended below 80 | # images = np.expand_dims(images,0) 81 | # num = len(file) 82 | m = 1 83 | for j in file: 84 | image = plt.imread(j) 85 | image = image[(np.arange(IMAGE_WIDTH)*image.shape[0])//IMAGE_WIDTH][:, (np.arange(IMAGE_WIDTH)*image.shape[1])//IMAGE_WIDTH] # nearest-neighbour resize, as above 86 | 87 | image = image * (1. / 127.5) - 1 #(-1,1) 88 | image = np.reshape(image, [-1,IMAGE_WIDTH*IMAGE_WIDTH*3]) 89 | # image = np.expand_dims(image,0) 90 | images = np.append(images,image,0) 91 | m = m+1 92 | if m>=100: 93 | break 94 | 95 | label = label_to_one_hot(i) 96 | labels = np.tile(label,(m,1)) 97 | img_all.append(images) 98 | lab_all.append(labels) 99 | 100 | current_time = datetime.now().strftime('%Y%m%d-%H%M') 101 | checkpoints_dir = 'checkpoints/{}'.format(current_time) 102 | try: 103 | os.makedirs(checkpoints_dir) 104 | except os.error: 105 | pass 106 | 107 | with tf.Graph().as_default(): 108 | images = tf.placeholder(tf.float32, [None, IMAGE_WIDTH*IMAGE_WIDTH*3], name='inputs') 109 | labels = tf.placeholder(tf.float32, [None, class_num], name='labels') 110 | 111 | logits = captcha.inference(images, keep_prob=0.75,is_training=True) 112 | loss = captcha.loss(logits, labels) 113 | # correct = captcha.evaluation(logits, labels)#train 114 | 115 | # train_precision = correct/batch 116 | tf.summary.scalar('loss', loss) 117 | # tf.summary.scalar('train_precision', train_precision) 118 | # tf.summary.image('images',images,10) 119 | summary = tf.summary.merge_all() 120 | train_op = tf.train.AdamOptimizer(learning_rate).minimize(loss) 121 | saver = tf.train.Saver() 122 | init = tf.global_variables_initializer() 123 | 124 | with tf.Session() as sess: 125 | # saver.restore(sess, tf.train.latest_checkpoint('ckpt_fm2')) 126 | sess.run(init) 127 | summary_writer = tf.summary.FileWriter(train_dir, sess.graph) 128 | try: 129 | shuffle_test = inputs(epoch,batch,img_all,lab_all) 130 | for step, (x_batch, l_batch) in enumerate(shuffle_test): 131 | 132 | feed_dict = {images:x_batch, labels:l_batch} 133 | _, loss_ = sess.run([train_op, loss], feed_dict=feed_dict) 134 | 135 | summary_str = sess.run(summary, feed_dict=feed_dict) 136 | summary_writer.add_summary(summary_str, step) 137 | summary_writer.flush() 138 | 139 | if step % 100 == 0: 140 | print('>> Step %d: loss = %.4f' 141 | % (step, loss_)) 142 | if step % 500 == 0 : 143 | saver.save(sess, checkpoint_file, global_step=step) 144 | except KeyboardInterrupt: 145 | print('INTERRUPTED') 146 | 147 | finally: 148 | saver.save(sess, checkpoint_file, global_step=step) 149 | print('Model saved in file: %s' % FLAGS.checkpoint) 150 | 151 | def main(_): 152 | # if tf.gfile.Exists(FLAGS.train_dir): 153 | # tf.gfile.DeleteRecursively(FLAGS.train_dir) 154 | # tf.gfile.MakeDirs(FLAGS.train_dir) 155 | run_train() 156 | 157 | 158 | if __name__ == '__main__': 159 | parser = argparse.ArgumentParser() 160 | parser.add_argument( 161 | '--batch_size', 162 | type=int, 163 | default=8, 164 | help='Batch size.' 165 | ) 166 | parser.add_argument( 167 | '--train_dir', 168 | type=str, 169 | default='pig_train', 170 | help='Directory where to write event logs.' 171 | ) 172 | parser.add_argument( 173 | '--checkpoint', 174 | type=str, 175 | default='checkpoint/model.ckpt', 176 | help='Directory where to write checkpoint.' 
177 | ) 178 | 179 | FLAGS, unparsed = parser.parse_known_args() 180 | tf.app.run(main=main, argv=[sys.argv[0]] + unparsed) 181 | -------------------------------------------------------------------------------- /nets/nets_factory.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Contains a factory for building various models.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | import functools 21 | 22 | import tensorflow as tf 23 | 24 | from nets import alexnet 25 | from nets import cifarnet 26 | from nets import inception 27 | from nets import lenet 28 | from nets import mobilenet_v1 29 | from nets import overfeat 30 | from nets import resnet_v1 31 | from nets import resnet_v2 32 | from nets import vgg 33 | from nets.nasnet import nasnet 34 | 35 | slim = tf.contrib.slim 36 | 37 | networks_map = {'alexnet_v2': alexnet.alexnet_v2, 38 | 'cifarnet': cifarnet.cifarnet, 39 | 'overfeat': overfeat.overfeat, 40 | 'vgg_a': vgg.vgg_a, 41 | 'vgg_16': vgg.vgg_16, 42 | 'vgg_19': vgg.vgg_19, 43 | 'inception_v1': inception.inception_v1, 44 | 'inception_v2': inception.inception_v2, 45 | 'inception_v3': inception.inception_v3, 46 | 'inception_v4': inception.inception_v4, 47 | 'inception_resnet_v2': inception.inception_resnet_v2, 48 | 'lenet': lenet.lenet, 49 | 'resnet_v1_50': resnet_v1.resnet_v1_50, 50 | 'resnet_v1_101': resnet_v1.resnet_v1_101, 51 | 'resnet_v1_152': resnet_v1.resnet_v1_152, 52 | 'resnet_v1_200': resnet_v1.resnet_v1_200, 53 | 'resnet_v2_50': resnet_v2.resnet_v2_50, 54 | 'resnet_v2_101': resnet_v2.resnet_v2_101, 55 | 'resnet_v2_152': resnet_v2.resnet_v2_152, 56 | 'resnet_v2_200': resnet_v2.resnet_v2_200, 57 | 'mobilenet_v1': mobilenet_v1.mobilenet_v1, 58 | 'mobilenet_v1_075': mobilenet_v1.mobilenet_v1_075, 59 | 'mobilenet_v1_050': mobilenet_v1.mobilenet_v1_050, 60 | 'mobilenet_v1_025': mobilenet_v1.mobilenet_v1_025, 61 | 'nasnet_cifar': nasnet.build_nasnet_cifar, 62 | 'nasnet_mobile': nasnet.build_nasnet_mobile, 63 | 'nasnet_large': nasnet.build_nasnet_large, 64 | } 65 | 66 | arg_scopes_map = {'alexnet_v2': alexnet.alexnet_v2_arg_scope, 67 | 'cifarnet': cifarnet.cifarnet_arg_scope, 68 | 'overfeat': overfeat.overfeat_arg_scope, 69 | 'vgg_a': vgg.vgg_arg_scope, 70 | 'vgg_16': vgg.vgg_arg_scope, 71 | 'vgg_19': vgg.vgg_arg_scope, 72 | 'inception_v1': inception.inception_v3_arg_scope, 73 | 'inception_v2': inception.inception_v3_arg_scope, 74 | 'inception_v3': inception.inception_v3_arg_scope, 75 | 'inception_v4': inception.inception_v4_arg_scope, 76 | 'inception_resnet_v2': 77 | inception.inception_resnet_v2_arg_scope, 78 | 'lenet': lenet.lenet_arg_scope, 79 | 'resnet_v1_50': resnet_v1.resnet_arg_scope, 80 | 'resnet_v1_101': 
resnet_v1.resnet_arg_scope, 81 | 'resnet_v1_152': resnet_v1.resnet_arg_scope, 82 | 'resnet_v1_200': resnet_v1.resnet_arg_scope, 83 | 'resnet_v2_50': resnet_v2.resnet_arg_scope, 84 | 'resnet_v2_101': resnet_v2.resnet_arg_scope, 85 | 'resnet_v2_152': resnet_v2.resnet_arg_scope, 86 | 'resnet_v2_200': resnet_v2.resnet_arg_scope, 87 | 'mobilenet_v1': mobilenet_v1.mobilenet_v1_arg_scope, 88 | 'mobilenet_v1_075': mobilenet_v1.mobilenet_v1_arg_scope, 89 | 'mobilenet_v1_050': mobilenet_v1.mobilenet_v1_arg_scope, 90 | 'mobilenet_v1_025': mobilenet_v1.mobilenet_v1_arg_scope, 91 | 'nasnet_cifar': nasnet.nasnet_cifar_arg_scope, 92 | 'nasnet_mobile': nasnet.nasnet_mobile_arg_scope, 93 | 'nasnet_large': nasnet.nasnet_large_arg_scope, 94 | } 95 | 96 | 97 | def get_network_fn(name, num_classes, weight_decay=0.0, is_training=False): 98 | """Returns a network_fn such as `logits, end_points = network_fn(images)`. 99 | 100 | Args: 101 | name: The name of the network. 102 | num_classes: The number of classes to use for classification. If 0 or None, 103 | the logits layer is omitted and its input features are returned instead. 104 | weight_decay: The l2 coefficient for the model weights. 105 | is_training: `True` if the model is being used for training and `False` 106 | otherwise. 107 | 108 | Returns: 109 | network_fn: A function that applies the model to a batch of images. It has 110 | the following signature: 111 | net, end_points = network_fn(images) 112 | The `images` input is a tensor of shape [batch_size, height, width, 3] 113 | with height = width = network_fn.default_image_size. (The permissibility 114 | and treatment of other sizes depend on the network_fn.) 115 | The returned `end_points` are a dictionary of intermediate activations. 116 | The returned `net` is the topmost layer, depending on `num_classes`: 117 | If `num_classes` was a non-zero integer, `net` is a logits tensor 118 | of shape [batch_size, num_classes]. 119 | If `num_classes` was 0 or `None`, `net` is a tensor with the input 120 | to the logits layer of shape [batch_size, 1, 1, num_features] or 121 | [batch_size, num_features]. Dropout has not been applied to this 122 | (even if the network's original classifier does); it remains for 123 | the caller to do this or not. 124 | 125 | Raises: 126 | ValueError: If network `name` is not recognized. 127 | """ 128 | if name not in networks_map: 129 | raise ValueError('Name of network unknown: %s' % name) 130 | func = networks_map[name] 131 | @functools.wraps(func) 132 | def network_fn(images, **kwargs): 133 | arg_scope = arg_scopes_map[name](weight_decay=weight_decay) 134 | with slim.arg_scope(arg_scope): 135 | return func(images, num_classes, is_training=is_training, **kwargs) 136 | if hasattr(func, 'default_image_size'): 137 | network_fn.default_image_size = func.default_image_size 138 | 139 | return network_fn 140 | -------------------------------------------------------------------------------- /nets/pix2pix_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================= 15 | """Tests for pix2pix.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import tensorflow as tf 22 | from nets import pix2pix 23 | 24 | 25 | class GeneratorTest(tf.test.TestCase): 26 | 27 | def test_nonsquare_inputs_raise_exception(self): 28 | batch_size = 2 29 | height, width = 240, 320 30 | num_outputs = 4 31 | 32 | images = tf.ones((batch_size, height, width, 3)) 33 | 34 | with self.assertRaises(ValueError): 35 | with tf.contrib.framework.arg_scope(pix2pix.pix2pix_arg_scope()): 36 | pix2pix.pix2pix_generator( 37 | images, num_outputs, upsample_method='nn_upsample_conv') 38 | 39 | def _reduced_default_blocks(self): 40 | """Returns the default blocks, scaled down to make test run faster.""" 41 | return [pix2pix.Block(b.num_filters // 32, b.decoder_keep_prob) 42 | for b in pix2pix._default_generator_blocks()] 43 | 44 | def test_output_size_nn_upsample_conv(self): 45 | batch_size = 2 46 | height, width = 256, 256 47 | num_outputs = 4 48 | 49 | images = tf.ones((batch_size, height, width, 3)) 50 | with tf.contrib.framework.arg_scope(pix2pix.pix2pix_arg_scope()): 51 | logits, _ = pix2pix.pix2pix_generator( 52 | images, num_outputs, blocks=self._reduced_default_blocks(), 53 | upsample_method='nn_upsample_conv') 54 | 55 | with self.test_session() as session: 56 | session.run(tf.global_variables_initializer()) 57 | np_outputs = session.run(logits) 58 | self.assertListEqual([batch_size, height, width, num_outputs], 59 | list(np_outputs.shape)) 60 | 61 | def test_output_size_conv2d_transpose(self): 62 | batch_size = 2 63 | height, width = 256, 256 64 | num_outputs = 4 65 | 66 | images = tf.ones((batch_size, height, width, 3)) 67 | with tf.contrib.framework.arg_scope(pix2pix.pix2pix_arg_scope()): 68 | logits, _ = pix2pix.pix2pix_generator( 69 | images, num_outputs, blocks=self._reduced_default_blocks(), 70 | upsample_method='conv2d_transpose') 71 | 72 | with self.test_session() as session: 73 | session.run(tf.global_variables_initializer()) 74 | np_outputs = session.run(logits) 75 | self.assertListEqual([batch_size, height, width, num_outputs], 76 | list(np_outputs.shape)) 77 | 78 | def test_block_number_dictates_number_of_layers(self): 79 | batch_size = 2 80 | height, width = 256, 256 81 | num_outputs = 4 82 | 83 | images = tf.ones((batch_size, height, width, 3)) 84 | blocks = [ 85 | pix2pix.Block(64, 0.5), 86 | pix2pix.Block(128, 0), 87 | ] 88 | with tf.contrib.framework.arg_scope(pix2pix.pix2pix_arg_scope()): 89 | _, end_points = pix2pix.pix2pix_generator( 90 | images, num_outputs, blocks) 91 | 92 | num_encoder_layers = 0 93 | num_decoder_layers = 0 94 | for end_point in end_points: 95 | if end_point.startswith('encoder'): 96 | num_encoder_layers += 1 97 | elif end_point.startswith('decoder'): 98 | num_decoder_layers += 1 99 | 100 | self.assertEqual(num_encoder_layers, len(blocks)) 101 | self.assertEqual(num_decoder_layers, len(blocks)) 102 | 103 | 104 | class 
DiscriminatorTest(tf.test.TestCase): 105 | 106 | def _layer_output_size(self, input_size, kernel_size=4, stride=2, pad=2): 107 | return (input_size + pad * 2 - kernel_size) // stride + 1 108 | 109 | def test_four_layers(self): 110 | batch_size = 2 111 | input_size = 256 112 | 113 | output_size = self._layer_output_size(input_size) 114 | output_size = self._layer_output_size(output_size) 115 | output_size = self._layer_output_size(output_size) 116 | output_size = self._layer_output_size(output_size, stride=1) 117 | output_size = self._layer_output_size(output_size, stride=1) 118 | 119 | images = tf.ones((batch_size, input_size, input_size, 3)) 120 | with tf.contrib.framework.arg_scope(pix2pix.pix2pix_arg_scope()): 121 | logits, end_points = pix2pix.pix2pix_discriminator( 122 | images, num_filters=[64, 128, 256, 512]) 123 | self.assertListEqual([batch_size, output_size, output_size, 1], 124 | logits.shape.as_list()) 125 | self.assertListEqual([batch_size, output_size, output_size, 1], 126 | end_points['predictions'].shape.as_list()) 127 | 128 | def test_four_layers_no_padding(self): 129 | batch_size = 2 130 | input_size = 256 131 | 132 | output_size = self._layer_output_size(input_size, pad=0) 133 | output_size = self._layer_output_size(output_size, pad=0) 134 | output_size = self._layer_output_size(output_size, pad=0) 135 | output_size = self._layer_output_size(output_size, stride=1, pad=0) 136 | output_size = self._layer_output_size(output_size, stride=1, pad=0) 137 | 138 | images = tf.ones((batch_size, input_size, input_size, 3)) 139 | with tf.contrib.framework.arg_scope(pix2pix.pix2pix_arg_scope()): 140 | logits, end_points = pix2pix.pix2pix_discriminator( 141 | images, num_filters=[64, 128, 256, 512], padding=0) 142 | self.assertListEqual([batch_size, output_size, output_size, 1], 143 | logits.shape.as_list()) 144 | self.assertListEqual([batch_size, output_size, output_size, 1], 145 | end_points['predictions'].shape.as_list()) 146 | 147 | def test_four_layers_wrong_padding(self): 148 | batch_size = 2 149 | input_size = 256 150 | 151 | images = tf.ones((batch_size, input_size, input_size, 3)) 152 | with tf.contrib.framework.arg_scope(pix2pix.pix2pix_arg_scope()): 153 | with self.assertRaises(TypeError): 154 | pix2pix.pix2pix_discriminator( 155 | images, num_filters=[64, 128, 256, 512], padding=1.5) 156 | 157 | def test_four_layers_negative_padding(self): 158 | batch_size = 2 159 | input_size = 256 160 | 161 | images = tf.ones((batch_size, input_size, input_size, 3)) 162 | with tf.contrib.framework.arg_scope(pix2pix.pix2pix_arg_scope()): 163 | with self.assertRaises(ValueError): 164 | pix2pix.pix2pix_discriminator( 165 | images, num_filters=[64, 128, 256, 512], padding=-1) 166 | 167 | if __name__ == '__main__': 168 | tf.test.main() 169 | -------------------------------------------------------------------------------- /pig_vgg16/pig_model_dark.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Wed Nov 22 12:27:34 2017 4 | 5 | @author: DELL 6 | """ 7 | ''' 8 | Uses the DarkNet-19 network as the classification backbone. 9 | ''' 10 | #from __future__ import absolute_import 11 | #from __future__ import division 12 | #from __future__ import print_function 13 | 14 | import tensorflow as tf 15 | import pig_input 16 | import config 17 | 18 | IMAGE_WIDTH = config.IMAGE_WIDTH 19 | IMAGE_HEIGHT = config.IMAGE_HEIGHT 20 | CLASSES_NUM = config.CLASSES_NUM 21 | rate = 0.01 # leaky-ReLU slope, used only by the commented-out activation in _conv 22 | 23 | def inputs(train, batch_size): 24 | return pig_input.inputs(train, 
batch_size=batch_size) 25 | 26 | def _conv(name, value, size, input_channels, output_channels, is_training=True): 27 | with tf.variable_scope(name) as scope: 28 | if not is_training: 29 | scope.reuse_variables() 30 | kernel = _weight_variable('weights', shape=[size, size, input_channels, output_channels]) 31 | biases = _bias_variable('biases',[output_channels]) 32 | pre_activation = tf.nn.bias_add(_conv2d(value, kernel),biases) 33 | conv = tf.nn.relu(pre_activation) 34 | # conv = tf.maximum(rate*pre_activation,pre_activation, name=scope.name) 35 | # conv = _batch_norm('norm', conv, is_training) 36 | return conv 37 | 38 | def _conv2d(value, weight): 39 | """conv2d returns a 2d convolution layer with full stride.""" 40 | return tf.nn.conv2d(value, weight, strides=[1, 1, 1, 1], padding='SAME') 41 | 42 | 43 | def _max_pool_2x2(value, name, is_training): 44 | """max_pool_2x2 downsamples a feature map by 2X.""" 45 | with tf.variable_scope(name) as scope1: 46 | if not is_training: 47 | scope1.reuse_variables() 48 | return tf.nn.max_pool(value, ksize=[1, 2, 2, 1], 49 | strides=[1, 2, 2, 1], padding='SAME', name=name) 50 | 51 | def _weight_variable(name, shape): 52 | """weight_variable generates a weight variable of a given shape.""" 53 | initializer = tf.truncated_normal_initializer(stddev=0.1) 54 | var = tf.get_variable(name,shape,initializer=initializer, dtype=tf.float32) 55 | return var 56 | 57 | 58 | def _bias_variable(name, shape): 59 | """bias_variable generates a bias variable of a given shape.""" 60 | initializer = tf.constant_initializer(0.1) 61 | var = tf.get_variable(name, shape, initializer=initializer,dtype=tf.float32) 62 | return var 63 | 64 | def _batch_norm(name, inputs, is_training): 65 | """ Batch Normalization 66 | """ 67 | with tf.variable_scope(name, reuse = not is_training): 68 | # return tf.layers.batch_normalization(input,training=is_training) 69 | return tf.contrib.layers.batch_norm(inputs, 70 | decay=0.9, 71 | scale=True, 72 | updates_collections=None, 73 | is_training=is_training) # respect the caller's mode instead of hardcoding True 74 | def inference(images, keep_prob, is_training): 75 | 76 | images = tf.reshape(images, [-1, IMAGE_HEIGHT, IMAGE_WIDTH, 3]) # e.g. 320*320*3 77 | 78 | conv1 = _conv('conv1', images, 3, 3, 32, is_training) 79 | # conv1 = _batch_norm('norm', conv1, is_training) 80 | pool1 = _max_pool_2x2(conv1, name='pool1', is_training=is_training) 81 | 82 | conv2 = _conv('conv2', pool1, 3, 32, 64, is_training) 83 | conv2 = _batch_norm('norm1', conv2, is_training) 84 | pool2 = _max_pool_2x2(conv2, name='pool2', is_training=is_training) 85 | 86 | conv3 = _conv('conv3', pool2, 3, 64, 128, is_training) # 80*80*128 87 | conv4 = _conv('conv4', conv3, 1, 128, 64, is_training) 88 | conv5 = _conv('conv5', conv4, 3, 64, 128, is_training) 89 | conv5 = _batch_norm('norm2', conv5, is_training) 90 | 91 | pool3 = _max_pool_2x2(conv5, name='pool3', is_training=is_training) 92 | 93 | conv6 = _conv('conv6', pool3, 3, 128, 256, is_training) # 40*40*256 94 | conv7 = _conv('conv7', conv6, 1, 256, 128, is_training) 95 | conv8 = _conv('conv8', conv7, 3, 128, 256, is_training) 96 | conv8 = _batch_norm('norm3', conv8, is_training) 97 | 98 | pool4 = _max_pool_2x2(conv8, name='pool4', is_training=is_training) # 20*20*256 99 | 100 | conv9 = _conv('conv9', pool4, 3, 256, 512, is_training) # 20*20*512 101 | conv10 = _conv('conv10', conv9, 1, 512, 256, is_training) 102 | conv11 = _conv('conv11', conv10, 3, 256, 512, is_training) 103 | conv12 = _conv('conv12', conv11, 1, 512, 256, is_training) 104 | conv13 = _conv('conv13', conv12, 3, 256, 512, is_training) 
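105 | # DarkNet-19-style stage: 3x3 expansion convolutions alternate with 1x1 106 | # bottlenecks, and a batch-norm layer closes each stage before downsampling. 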
107 | conv13 = _batch_norm('norm4', conv13, is_training) 108 | 109 | pool5 = _max_pool_2x2(conv13, name='pool5', is_training=is_training) # 10*10*512 110 | 111 | conv14 = _conv('conv14', pool5, 3, 512, 1024, is_training) # 10*10*1024 112 | conv15 = _conv('conv15', conv14, 1, 1024, 512, is_training) 113 | conv16 = _conv('conv16', conv15, 3, 512, 1024, is_training) 114 | conv17 = _conv('conv17', conv16, 1, 1024, 512, is_training) 115 | conv18 = _conv('conv18', conv17, 3, 512, 1024, is_training) 116 | conv18 = _batch_norm('norm5', conv18, is_training) 117 | 118 | pool6 = _max_pool_2x2(conv18, name='pool6', is_training=is_training) # 5*5*1024 119 | # Flatten pool6 (not pool5) so the dense head sees the full backbone output. 120 | dense = tf.contrib.layers.flatten(pool6) 121 | dense1 = tf.layers.dense(dense, 1024) 122 | dense1 = tf.nn.relu(dense1) 123 | dense1 = tf.nn.dropout(dense1, keep_prob) 124 | dense2 = tf.layers.dense(dense1, 256) 125 | dense2 = tf.nn.relu(dense2) 126 | dense2 = tf.nn.dropout(dense2, keep_prob) 127 | dense3 = tf.layers.dense(dense2, CLASSES_NUM) 128 | # output = tf.nn.sigmoid(dense3) 129 | # output = tf.nn.softmax(dense3) 130 | 131 | return dense3 132 | 133 | 134 | def loss(logits, labels): 135 | cross_entropy = tf.nn.softmax_cross_entropy_with_logits( 136 | labels=labels, logits=logits, name='cross_entropy_per_example') 137 | cross_entropy_mean = tf.reduce_mean(cross_entropy, name='cross_entropy') 138 | tf.add_to_collection('losses', cross_entropy_mean) 139 | return tf.add_n(tf.get_collection('losses'), name='total_loss') 140 | 141 | 142 | def training(loss): 143 | optimizer = tf.train.AdamOptimizer(1e-4) 144 | gen_grads_and_vars = optimizer.compute_gradients(loss) 145 | gen_train = optimizer.apply_gradients(gen_grads_and_vars) 146 | ema = tf.train.ExponentialMovingAverage(decay=0.99) 147 | update_losses = ema.apply([loss]) 148 | 149 | global_step = tf.contrib.framework.get_or_create_global_step() 150 | incr_global_step = tf.assign(global_step, global_step+1) 151 | 152 | return tf.group(update_losses, incr_global_step, gen_train) 153 | 154 | 155 | 156 | def evaluation(logits, labels): 157 | correct_prediction = tf.equal(tf.argmax(logits,1), tf.argmax(labels,1)) 158 | return tf.reduce_sum(tf.cast(correct_prediction, tf.float32)) 159 | 160 | 161 | def output(logits): 162 | return tf.nn.softmax(logits) 163 | 164 | def predict(logits): 165 | return tf.argmax(logits, 1) --------------------------------------------------------------------------------
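
For orientation, here is a minimal sketch (not part of the original repository) of how the get_network_fn factory in nets/nets_factory.py above is typically driven for this competition's 30-class setup. The placeholder shape, the weight-decay value, and the final softmax step are illustrative assumptions:

import tensorflow as tf

from nets import nets_factory

# Build the Inception-ResNet-v2 classifier for the 30 pig classes.
# is_training=False selects the inference graph (e.g. no dropout updates).
network_fn = nets_factory.get_network_fn(
    'inception_resnet_v2', num_classes=30, weight_decay=4e-5, is_training=False)
size = network_fn.default_image_size  # 299 for inception_resnet_v2

# Feed batches of preprocessed images and read off per-pig probabilities,
# which is the quantity the competition's logloss metric scores.
images = tf.placeholder(tf.float32, [None, size, size, 3], name='images')
logits, end_points = network_fn(images)
probabilities = tf.nn.softmax(logits)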