├── .gitignore
├── .idea
│   ├── Defect_Detection.iml
│   ├── misc.xml
│   ├── modules.xml
│   ├── other.xml
│   ├── vcs.xml
│   └── workspace.xml
├── Detect
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── __init__.cpython-36.pyc
│   │   ├── config.cpython-36.pyc
│   │   ├── inference.cpython-36.pyc
│   │   ├── mynet.cpython-36.pyc
│   │   ├── submit.cpython-36.pyc
│   │   ├── train.cpython-36.pyc
│   │   └── utils.cpython-36.pyc
│   ├── config.py
│   ├── evaluate.py
│   ├── inference.py
│   ├── model_merge.py
│   ├── mynet.py
│   ├── phone_seg.py
│   ├── predict.py
│   ├── submit.py
│   ├── train.py
│   └── utils.py
├── Image_preprocessing
│   ├── __init__.py
│   ├── create_data.py
│   ├── image2npy.py
│   ├── image_augmentation.py
│   ├── image_crop.py
│   ├── image_segmentation.py
│   └── json2npy.py
├── Siamese
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── __init__.cpython-36.pyc
│   │   ├── inference.cpython-36.pyc
│   │   └── utils.cpython-36.pyc
│   ├── bin_classify.py
│   ├── config.py
│   ├── data_cache.py
│   ├── evaluate.py
│   ├── inference.py
│   ├── inference_2.py
│   ├── predict.py
│   ├── train.py
│   ├── train_2.py
│   └── utils.py
├── Siamese_multi
│   └── __pycache__
│       ├── __init__.cpython-36.pyc
│       └── inference.cpython-36.pyc
├── camera
│   ├── DllTest.dll
│   ├── __init__.py
│   └── test.py
├── mv
│   ├── __init__.py
│   ├── base.py
│   ├── cig_detection.py
│   └── evaluate.py
├── nets
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── __init__.cpython-36.pyc
│   │   ├── alexnet.cpython-36.pyc
│   │   ├── inception_resnet_v2.cpython-36.pyc
│   │   ├── inception_utils.cpython-36.pyc
│   │   ├── inception_v4.cpython-36.pyc
│   │   ├── mobilenet_v1.cpython-36.pyc
│   │   ├── resnet_utils.cpython-36.pyc
│   │   ├── resnet_v2.cpython-36.pyc
│   │   └── vgg.cpython-36.pyc
│   ├── alexnet.py
│   ├── alexnet_test.py
│   ├── cifarnet.py
│   ├── cyclegan.py
│   ├── cyclegan_test.py
│   ├── dcgan.py
│   ├── dcgan_test.py
│   ├── inception.py
│   ├── inception_resnet_v2.py
│   ├── inception_resnet_v2_test.py
│   ├── inception_utils.py
│   ├── inception_v1.py
│   ├── inception_v1_test.py
│   ├── inception_v2.py
│   ├── inception_v2_test.py
│   ├── inception_v3.py
│   ├── inception_v3_test.py
│   ├── inception_v4.py
│   ├── inception_v4_test.py
│   ├── lenet.py
│   ├── mobilenet
│   │   ├── README.md
│   │   ├── __init__.py
│   │   ├── conv_blocks.py
│   │   ├── madds_top1_accuracy.png
│   │   ├── mnet_v1_vs_v2_pixel1_latency.png
│   │   ├── mobilenet.py
│   │   ├── mobilenet_example.ipynb
│   │   ├── mobilenet_v2.py
│   │   └── mobilenet_v2_test.py
│   ├── mobilenet_v1.md
│   ├── mobilenet_v1.png
│   ├── mobilenet_v1.py
│   ├── mobilenet_v1_eval.py
│   ├── mobilenet_v1_test.py
│   ├── mobilenet_v1_train.py
│   ├── nasnet
│   │   ├── README.md
│   │   ├── __init__.py
│   │   ├── nasnet.py
│   │   ├── nasnet_test.py
│   │   ├── nasnet_utils.py
│   │   ├── nasnet_utils_test.py
│   │   ├── pnasnet.py
│   │   └── pnasnet_test.py
│   ├── nets_factory.py
│   ├── nets_factory_test.py
│   ├── overfeat.py
│   ├── overfeat_test.py
│   ├── pix2pix.py
│   ├── pix2pix_test.py
│   ├── resnet_utils.py
│   ├── resnet_v1.py
│   ├── resnet_v1_test.py
│   ├── resnet_v2.py
│   ├── resnet_v2_test.py
│   ├── vgg.py
│   └── vgg_test.py
├── ocr
│   ├── __init__.py
│   └── ocr.py
├── qt
│   ├── 1.jpg
│   ├── LogDialog.py
│   ├── MainWindow.py
│   ├── __pycache__
│   │   ├── LogDialog.cpython-36.pyc
│   │   ├── MainWindow.cpython-36.pyc
│   │   ├── predict.cpython-36.pyc
│   │   ├── predict_dl.cpython-36.pyc
│   │   └── predict_ip.cpython-36.pyc
│   ├── cv.py
│   ├── hello-world-master
│   │   ├── 1.txt
│   │   └── README.md
│   ├── log
│   │   ├── 1
│   │   └── detect_20181016_134111.csv
│   ├── main.py
│   ├── nets
│   │   ├── __init__.py
│   │   ├── __pycache__
│   │   │   ├── __init__.cpython-36.pyc
│   │   │   ├── alexnet.cpython-36.pyc
│   │   │   ├── mobilenet_v1.cpython-36.pyc
│   │   │   └── vgg.cpython-36.pyc
│   │   ├── alexnet.py
│   │   └── mobilenet_v1.py
│   ├── predict_dl.py
│   ├── predict_ip.py
│   ├── src
│   │   ├── 校标.png
│   │   └── 欢迎.jpg
│   ├── test.py
│   ├── test1.jpg
│   ├── ui_database_set.ui
│   ├── ui_defect_log.ui
│   ├── ui_detect_log.ui
│   └── ui_detection.ui
└── test.py

/.gitignore:
--------------------------------------------------------------------------------
1 | /data/
2 | /log/
3 | /qt/weight/
--------------------------------------------------------------------------------
/.idea/Defect_Detection.iml:
--------------------------------------------------------------------------------
/.idea/misc.xml:
--------------------------------------------------------------------------------
/.idea/modules.xml:
--------------------------------------------------------------------------------
/.idea/other.xml:
--------------------------------------------------------------------------------
/.idea/vcs.xml:
--------------------------------------------------------------------------------
/Detect/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wang-jiankun/Defect-Detection/e650be31b06115d3754bf9a636baed0c6a7ba037/Detect/__init__.py
--------------------------------------------------------------------------------
/Detect/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wang-jiankun/Defect-Detection/e650be31b06115d3754bf9a636baed0c6a7ba037/Detect/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/Detect/__pycache__/config.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wang-jiankun/Defect-Detection/e650be31b06115d3754bf9a636baed0c6a7ba037/Detect/__pycache__/config.cpython-36.pyc
--------------------------------------------------------------------------------
/Detect/__pycache__/inference.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wang-jiankun/Defect-Detection/e650be31b06115d3754bf9a636baed0c6a7ba037/Detect/__pycache__/inference.cpython-36.pyc
--------------------------------------------------------------------------------
/Detect/__pycache__/mynet.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wang-jiankun/Defect-Detection/e650be31b06115d3754bf9a636baed0c6a7ba037/Detect/__pycache__/mynet.cpython-36.pyc
--------------------------------------------------------------------------------
/Detect/__pycache__/submit.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wang-jiankun/Defect-Detection/e650be31b06115d3754bf9a636baed0c6a7ba037/Detect/__pycache__/submit.cpython-36.pyc
--------------------------------------------------------------------------------
/Detect/__pycache__/train.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wang-jiankun/Defect-Detection/e650be31b06115d3754bf9a636baed0c6a7ba037/Detect/__pycache__/train.cpython-36.pyc
--------------------------------------------------------------------------------
/Detect/__pycache__/utils.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wang-jiankun/Defect-Detection/e650be31b06115d3754bf9a636baed0c6a7ba037/Detect/__pycache__/utils.cpython-36.pyc
--------------------------------------------------------------------------------
/Detect/config.py:
--------------------------------------------------------------------------------
1 | """
2 | Detect configuration file: global parameters
3 | author: 王建坤
4 | date: 2018-12-3
5 | """
6 | import numpy as np
7 | import tensorflow as tf
8 | from nets import alexnet, mobilenet_v1, vgg, inception_v4, resnet_v2, inception_resnet_v2
9 | from Detect import mynet
10 | import os
11 | 
12 | # Number of classes
13 | CLASSES = 5
14 | # Input image size
15 | IMG_SIZE = 224
16 | # Number of image channels
17 | CHANNEL = 2
18 | # Whether the input is a non-standard size (enables global pooling)
19 | GLOBAL_POOL = True
20 | # Whether to use the GPU
21 | USE_GPU = True
22 | if not USE_GPU:
23 |     os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
24 |     os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
25 | # Model name: 'Mobile' or 'My'
26 | MODEL_NAME = 'My'
27 | # Data set paths
28 | data_set_name = 'sia_cig_'  # b_cig_, sia_cig_, s_sia_cig_
29 | images_path = '../data/' + data_set_name + 'data_' + str(IMG_SIZE) + '.npy'
30 | labels_path = '../data/' + data_set_name + 'label_' + str(IMG_SIZE) + '.npy'
31 | 
32 | 
--------------------------------------------------------------------------------
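The module above only defines globals; nothing validates them. A minimal sanity check (not part of the repository) for the npy pair that `images_path`/`labels_path` point to, assuming it was generated by `Image_preprocessing/image2npy.py`:

# Hypothetical sanity check for the paths defined in Detect/config.py.
import numpy as np
from Detect.config import images_path, labels_path, CLASSES, IMG_SIZE

data = np.load(images_path)
labels = np.load(labels_path)
# Expect one label per sample and IMG_SIZE x IMG_SIZE images.
assert data.shape[0] == labels.shape[0], 'data/label count mismatch'
assert data.shape[1:3] == (IMG_SIZE, IMG_SIZE), 'unexpected image size'
assert labels.max() < CLASSES, 'label id outside CLASSES range'
print('samples:', data.shape[0], 'per-class counts:', np.bincount(labels))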
/Detect/evaluate.py:
--------------------------------------------------------------------------------
1 | """
2 | Detect: model evaluation
3 | author: 王建坤
4 | date: 2018-9-10
5 | """
6 | from Detect.config import *
7 | import math
8 | from sklearn.model_selection import train_test_split
9 | 
10 | BATCH_SIZE = 1
11 | IS_TRAINING = False
12 | 
13 | 
14 | def evaluate(model=MODEL_NAME):
15 |     """
16 |     Evaluate a trained model.
17 |     :param model: model name
18 |     :return: none
19 |     """
20 |     # Per-class counts: ground-truth samples, predicted samples, correctly predicted samples
21 |     sample_labels = np.zeros(CLASSES, dtype=np.uint16)
22 |     pre_pos = np.zeros(CLASSES, dtype=np.uint16)
23 |     true_pos = np.zeros(CLASSES, dtype=np.uint16)
24 | 
25 |     # Load the test set
26 |     images = np.load(images_path)
27 |     labels = np.load(labels_path)
28 | 
29 |     _, val_data, _, val_label = train_test_split(images, labels, test_size=0.2, random_state=222)
30 |     # If the input is grayscale, add a channel dimension
31 |     if CHANNEL == 1:
32 |         val_data = np.expand_dims(val_data, axis=3)
33 | 
34 |     test_num = val_data.shape[0]
35 |     # Holds the predictions
36 |     res = np.zeros(test_num, dtype=np.uint16)
37 | 
38 |     # Placeholders
39 |     x = tf.placeholder(tf.float32, [None, IMG_SIZE, IMG_SIZE, CHANNEL], name="x_input")
40 |     y_ = tf.placeholder(tf.uint8, [None], name="y_input")
41 | 
42 |     # Checkpoint directory, model name, pretrained weights, forward pass
43 |     if model == 'Alex':
44 |         log_dir = "../log/Alex"
45 |         y, _ = alexnet.alexnet_v2(x,
46 |                                   num_classes=CLASSES,      # number of classes
47 |                                   is_training=IS_TRAINING,  # training mode flag
48 |                                   dropout_keep_prob=1.0,    # dropout keep probability
49 |                                   spatial_squeeze=True,     # squeeze singleton spatial dims
50 |                                   global_pool=GLOBAL_POOL)  # needed when the input is not the canonical size
51 |     elif model == 'My':
52 |         log_dir = "../log/My"
53 |         # y = mynet.mynet_v1(x, is_training=IS_TRAINING, num_classes=CLASSES)
54 |         y, _ = mobilenet_v1.mobilenet_v1(x,
55 |                                          num_classes=CLASSES,
56 |                                          dropout_keep_prob=1.0,
57 |                                          is_training=IS_TRAINING,
58 |                                          min_depth=8,
59 |                                          depth_multiplier=1.0,
60 |                                          conv_defs=None,
61 |                                          prediction_fn=None,
62 |                                          spatial_squeeze=True,
63 |                                          reuse=None,
64 |                                          scope='MobilenetV1',
65 |                                          global_pool=GLOBAL_POOL)
66 |     elif model == 'Mobile':
67 |         log_dir = "../log/Mobile"
68 |         y, _ = mobilenet_v1.mobilenet_v1(x,
69 |                                          num_classes=CLASSES,
70 |                                          dropout_keep_prob=1.0,
71 |                                          is_training=IS_TRAINING,
72 |                                          min_depth=8,
73 |                                          depth_multiplier=1.0,
74 |                                          conv_defs=None,
75 |                                          prediction_fn=None,
76 |                                          spatial_squeeze=True,
77 |                                          reuse=None,
78 |                                          scope='MobilenetV1',
79 |                                          global_pool=GLOBAL_POOL)
80 |     elif model == 'VGG':
81 |         log_dir = "../log/VGG"
82 |         y, _ = vgg.vgg_16(x,
83 |                           num_classes=CLASSES,
84 |                           is_training=IS_TRAINING,
85 |                           dropout_keep_prob=1.0,
86 |                           spatial_squeeze=True,
87 |                           global_pool=GLOBAL_POOL)
88 |     elif model == 'Incep4':
89 |         log_dir = "E:/alum/log/Incep4"
90 |         y, _ = inception_v4.inception_v4(x, num_classes=CLASSES,
91 |                                          is_training=IS_TRAINING,
92 |                                          dropout_keep_prob=1.0,
93 |                                          reuse=None,
94 |                                          scope='InceptionV4',
95 |                                          create_aux_logits=True)
96 |     elif model == 'Res':
97 |         log_dir = "E:/alum/log/Res"
98 |         y, _ = resnet_v2.resnet_v2_50(x,
99 |                                       num_classes=CLASSES,
100 |                                       is_training=IS_TRAINING,
101 |                                       global_pool=GLOBAL_POOL,
102 |                                       output_stride=None,
103 |                                       spatial_squeeze=True,
104 |                                       reuse=None,
105 |                                       scope='resnet_v2_50')
106 |     else:
107 |         print('Error: model name not exist')
108 |         return
109 | 
110 |     # Predicted class
111 |     y_pre = tf.argmax(y, 1)
112 |     saver = tf.train.Saver(tf.global_variables())
113 | 
114 |     with tf.Session() as sess:
115 |         # Restore model weights
116 |         print('Reading checkpoints of', model)
117 |         ckpt = tf.train.get_checkpoint_state(log_dir)
118 |         if ckpt and ckpt.model_checkpoint_path:
119 |             global_step = ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1]
120 |             saver.restore(sess, ckpt.model_checkpoint_path)
121 |             print('Loading success, step is %s' % global_step)
122 |         else:
123 |             print('Error: no checkpoint file found')
124 |             return
125 | 
126 |         # Number of iterations for one pass over the test set
127 |         num_iter = int(math.ceil(test_num / BATCH_SIZE))
128 | 
129 |         step = 0
130 |         start = 0
131 |         while step < num_iter:
132 |             # Fetch one batch
133 |             if step == num_iter-1:
134 |                 end = test_num
135 |             else:
136 |                 end = start+BATCH_SIZE
137 |             image_batch = val_data[start:end]
138 |             label_batch = val_label[start:end]
139 |             # Run prediction for the batch
140 |             pres, pre = sess.run([y, y_pre], feed_dict={x: image_batch, y_: label_batch})
141 |             res[start:end] = pre
142 |             start += BATCH_SIZE
143 |             step += 1
144 |         # Accumulate the per-class statistics
145 |         normal_num = 0
146 |         for i in range(test_num):
147 |             pre_pos[res[i]] += 1
148 |             sample_labels[val_label[i]] += 1
149 |             if res[i] == val_label[i]:
150 |                 true_pos[res[i]] += 1
151 |             if val_label[i] == 0:
152 |                 normal_num += 1
153 | 
154 |         precision = true_pos/pre_pos
155 |         recall = true_pos/sample_labels
156 |         f1 = 2*precision*recall/(precision+recall)
157 |         error = (pre_pos - true_pos) / (test_num - sample_labels)
158 |         print('Total test samples:', test_num)
159 |         print('Samples per class:', sample_labels)
160 |         print('Predictions per class:', pre_pos)
161 |         print('Correct per class:', true_pos)
162 |         print('Precision:', precision)
163 |         print('Recall:', recall)
164 |         print('F1:', f1)
165 |         print('False detection rate:', error)
166 |         print('Accuracy:', np.sum(true_pos)/test_num)
167 |         print('Overall miss rate:', (pre_pos[0]-true_pos[0])/(test_num-normal_num))
168 |         print('Overall overkill rate:', (normal_num-true_pos[0])/normal_num)
169 | 
170 | 
171 | if __name__ == '__main__':
172 |     evaluate()
173 | 
--------------------------------------------------------------------------------
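The per-class counters in `evaluate()` can be cross-checked with scikit-learn, which is already a dependency here. A minimal sketch (not part of the repository); `res` and `val_label` stand for the prediction and ground-truth arrays built inside `evaluate()`, shown with placeholder values:

# Hypothetical cross-check of the manual precision/recall/F1 computation.
import numpy as np
from sklearn.metrics import confusion_matrix, classification_report

res = np.array([0, 1, 2, 2, 0])        # placeholder predictions
val_label = np.array([0, 1, 2, 1, 0])  # placeholder ground truth
print(confusion_matrix(val_label, res))
print(classification_report(val_label, res, digits=4))  # same per-class quantities as the manual loop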
/Detect/inference.py:
--------------------------------------------------------------------------------
1 | """
2 | Detect: network architecture (hand-built)
3 | author: 王建坤
4 | date: 2018-8-10
5 | """
6 | import tensorflow as tf
7 | import numpy as np
8 | 
9 | 
10 | # Convolution layer
11 | def convolution(input_tensor, conv_height, conv_width, conv_deep, x_stride, y_stride, name, padding='SAME'):
12 |     with tf.variable_scope(name):
13 |         channel = int(input_tensor.get_shape()[-1])
14 |         weights = tf.get_variable("weights", shape=[conv_height, conv_width, channel, conv_deep],
15 |                                   initializer=tf.truncated_normal_initializer(stddev=0.025))
16 |         bias = tf.get_variable("bias", shape=[conv_deep], initializer=tf.constant_initializer(0.025))
17 |         conv = tf.nn.conv2d(input_tensor, weights, strides=[1, x_stride, y_stride, 1], padding=padding)
18 |         relu = tf.nn.relu(tf.nn.bias_add(conv, bias))
19 |         return relu
20 | 
21 | 
22 | # Max-pooling layer
23 | def max_pool(input_tensor, height, width, x_stride, y_stride, name, padding="SAME"):
24 |     return tf.nn.max_pool(input_tensor, ksize=[1, height, width, 1], strides=[1, x_stride, y_stride, 1],
25 |                           padding=padding, name=name)
26 | 
27 | 
28 | # Local response normalization layer
29 | def LRN(input_tensor, R, alpha, beta, name=None, bias=1.0):
30 |     return tf.nn.local_response_normalization(input_tensor, depth_radius=R, bias=bias, alpha=alpha, beta=beta,
31 |                                               name=name)
32 | 
33 | 
34 | # Dropout layer
35 | def dropout(input_tensor, prob, name):
36 |     return tf.nn.dropout(input_tensor, prob, name=name)
37 | 
38 | 
39 | # Fully connected layer
40 | def full_connect(input_tensor, in_dimension, out_dimension, relu_flag, name):
41 |     with tf.variable_scope(name):
42 |         weights = tf.get_variable("weights", [in_dimension, out_dimension],
43 |                                   initializer=tf.truncated_normal_initializer(stddev=np.sqrt(2 / in_dimension)))
44 |         bias = tf.get_variable("bias", [out_dimension], initializer=tf.constant_initializer(0.05))
45 |         fc = tf.matmul(input_tensor, weights) + bias
46 |         if relu_flag:
47 |             fc = tf.nn.relu(fc)
48 |         return fc
49 | 
50 | 
51 | def inference(input_tensor, train):
52 |     conv1 = convolution(input_tensor, 11, 11, 96, 4, 4, "conv1", padding="VALID")
53 |     pool1 = max_pool(conv1, 3, 3, 2, 2, "pool1", "VALID")
54 |     lrn1 = LRN(pool1, 2, 2e-05, 0.75, "lrn1")
55 |     conv2 = convolution(lrn1, 5, 5, 256, 1, 1, "conv2")
56 |     pool2 = max_pool(conv2, 3, 3, 2, 2, "pool2", "VALID")
57 |     lrn2 = LRN(pool2, 2, 2e-05, 0.75, "lrn2")
58 |     conv3 = convolution(lrn2, 3, 3, 384, 1, 1, "conv3")
59 |     conv4 = convolution(conv3, 3, 3, 384, 1, 1, "conv4")
60 |     conv5 = convolution(conv4, 3, 3, 256, 1, 1, "conv5")
61 |     pool5 = max_pool(conv5, 3, 3, 2, 2, "pool5", "VALID")
62 |     fcin = tf.reshape(pool5, [-1, 256 * 6 * 6])
63 |     fc1 = full_connect(fcin, 256 * 6 * 6, 4096, True, "fc6")
64 |     if train:
65 |         fc1 = dropout(fc1, 0.8, "drop6")
66 |     fc2 = full_connect(fc1, 4096, 2, True, "fc7")  # 2048
67 |     # if train:
68 |     #     fc2 = dropout(fc2, 1, "drop7")
69 |     fc3 = full_connect(fc2, 2, 2, True, "fc8")
70 |     return fc3
--------------------------------------------------------------------------------
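The `tf.reshape(pool5, [-1, 256 * 6 * 6])` step implicitly assumes AlexNet's canonical 227×227 input: 227 → conv1 (11×11, stride 4, VALID) → 55 → pool1 → 27 → conv2/pool2 → 13 → conv3-5/pool5 → 6. A quick shape probe (not from the repository) under that assumption:

# Hypothetical shape probe for inference(); assumes a 227x227x3 input so that
# pool5 comes out as [N, 6, 6, 256], matching the 256*6*6 fc reshape.
import tensorflow as tf
from Detect.inference import inference

x = tf.placeholder(tf.float32, [None, 227, 227, 3])
logits = inference(x, train=False)
print(logits.shape)  # expected: (?, 2)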
/Detect/model_merge.py:
--------------------------------------------------------------------------------
1 | """
2 | Model ensembling -- merge two models based on their prediction confidences
3 | author: 王建坤
4 | date: 2018-9-29
5 | """
6 | import math
7 | import os
8 | import pandas as pd
9 | import numpy as np
10 | import tensorflow as tf
11 | from PIL import Image
12 | from nets import alexnet, vgg, inception_v4, resnet_v2
13 | 
14 | 
15 | BATCH_SIZE = 32
16 | IMG_SIZE1 = 224
17 | IMG_SIZE2 = 299
18 | CLASSES = 12
19 | GLOBAL_POOL = True
20 | test_dir = 'E:/dataset/alum/guangdong_round1_test_a_20180916'
21 | word_class = {'norm': 0, 'defect1': 1, 'defect2': 2, 'defect3': 3, 'defect4': 4, 'defect5': 5, 'defect6': 6,
22 |               'defect7': 7, 'defect8': 8, 'defect9': 9, 'defect10': 10, 'defect11': 11}
23 | class_label = ['norm', 'defect1', 'defect2', 'defect3', 'defect4', 'defect5',
24 |                'defect6', 'defect7', 'defect8', 'defect9', 'defect10', 'defect11']
25 | 
26 | 
27 | def choose_model(x, model):
28 |     """
29 |     Select a model.
30 |     :param x: input placeholder
31 |     :param model: model name
32 |     :return: logits and checkpoint directory
33 |     """
34 |     # Checkpoint directory, model name, pretrained weights, forward pass
35 |     if model == 'Alex':
36 |         log_dir = "E:/alum/log/Alex"
37 |         y, _ = alexnet.alexnet_v2(x,
38 |                                   num_classes=CLASSES,      # number of classes
39 |                                   is_training=True,         # training mode flag
40 |                                   dropout_keep_prob=1.0,    # dropout keep probability
41 |                                   spatial_squeeze=True,     # squeeze singleton spatial dims
42 |                                   global_pool=GLOBAL_POOL)  # needed when the input is not the canonical size
43 |     elif model == 'VGG':
44 |         log_dir = "E:/alum/log/VGG"
45 |         y, _ = vgg.vgg_16(x,
46 |                           num_classes=CLASSES,
47 |                           is_training=True,
48 |                           dropout_keep_prob=1.0,
49 |                           spatial_squeeze=True,
50 |                           global_pool=GLOBAL_POOL)
51 |     elif model == 'VGG2':
52 |         log_dir = "E:/alum/log/VGG2"
53 |         y, _ = vgg.vgg_16(x,
54 |                           num_classes=CLASSES,
55 |                           is_training=True,
56 |                           dropout_keep_prob=1.0,
57 |                           spatial_squeeze=True,
58 |                           global_pool=GLOBAL_POOL)
59 |     elif model == 'Incep4':
60 |         log_dir = "E:/alum/log/Incep4"
61 |         y, _ = inception_v4.inception_v4(x, num_classes=CLASSES,
62 |                                          is_training=True,
63 |                                          dropout_keep_prob=1.0,
64 |                                          reuse=None,
65 |                                          scope='InceptionV4',
66 |                                          create_aux_logits=True)
67 |     elif model == 'Res':
68 |         log_dir = "E:/alum/log/Res"
69 |         y, _ = resnet_v2.resnet_v2_50(x,
70 |                                       num_classes=CLASSES,
71 |                                       is_training=True,
72 |                                       global_pool=GLOBAL_POOL,
73 |                                       output_stride=None,
74 |                                       spatial_squeeze=True,
75 |                                       reuse=None,
76 |                                       scope='resnet_v2_50')
77 |     else:
78 |         print('Error: model name not exist')
79 |         return
80 | 
81 |     return y, log_dir
82 | 
83 | 
84 | def evaluate_model_merge(model1='VGG2', model2='VGG', submit=True):
85 |     """
86 |     Merge the two models' predictions and evaluate the result.
87 |     :param model1: first model name
88 |     :param model2: second model name
89 |     :param submit: whether to save the result as a csv file
90 |     :return: none
91 |     """
92 |     # Holds the predictions of model 1 and model 2
93 |     image_name_list = os.listdir(test_dir)
94 |     image_labels = pd.read_csv('E:/WJK_File/Python_File/Defect_Detection/Detect/test_label.csv')
95 |     res, res1, res2, res_name, labels = [], [], [], [], []
96 | 
97 |     # model 1
98 |     g1 = tf.Graph()
99 |     with g1.as_default():
100 |         x1 = tf.placeholder(tf.float32, [None, IMG_SIZE1, IMG_SIZE1, 3], name="x_input_1")
101 |         y1, log_dir1 = choose_model(x1, model1)
102 |         y1 = tf.nn.softmax(y1)
103 |         saver1 = tf.train.Saver()
104 |     sess1 = tf.Session(graph=g1)
105 |     print("Reading checkpoints of model1")
106 |     ckpt = tf.train.get_checkpoint_state(log_dir1)
107 |     if ckpt and ckpt.model_checkpoint_path:
108 |         global_step = ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1]
109 |         saver1.restore(sess1, ckpt.model_checkpoint_path)
110 |         print('Loading success, global_step is %s' % global_step)
111 |     else:
112 |         print('Error: no checkpoint file found')
113 |         return
114 | 
115 |     # model 2
116 |     g2 = tf.Graph()
117 |     with g2.as_default():
118 |         x2 = tf.placeholder(tf.float32, [None, IMG_SIZE2, IMG_SIZE2, 3], name="x_input_2")
119 |         y2, log_dir2 = choose_model(x2, model2)
120 |         y2 = tf.nn.softmax(y2)
121 |         saver2 = tf.train.Saver()
122 |     sess2 = tf.Session(graph=g2)
123 |     print("Reading checkpoints of model2")
124 |     ckpt = tf.train.get_checkpoint_state(log_dir2)
125 |     if ckpt and ckpt.model_checkpoint_path:
126 |         global_step = ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1]
127 |         saver2.restore(sess2, ckpt.model_checkpoint_path)
128 |         print('Loading success, global_step is %s' % global_step)
129 |     else:
130 |         print('Error: model2 checkpoint file not found')
131 |         return
132 | 
133 |     for image_name in image_name_list:
134 |         label =
image_labels[image_labels.iloc[:, 0].isin([image_name])].iloc[0, 1] 135 | labels.append(word_class[label]) 136 | image_path = os.path.join(test_dir, image_name) 137 | img = Image.open(image_path) 138 | img1 = img.resize((IMG_SIZE1, IMG_SIZE1)) 139 | img1 = np.array(img1, np.float32) 140 | img1 = np.expand_dims(img1, axis=0) 141 | pre1 = sess1.run(y1, feed_dict={x1: img1}) 142 | max_index1 = np.argmax(pre1) 143 | max_value1 = np.max(pre1) 144 | res1.append([max_index1, max_value1]) 145 | 146 | img2 = img.resize((IMG_SIZE2, IMG_SIZE2)) 147 | img2 = np.array(img2, np.float32) 148 | img2 = np.expand_dims(img2, axis=0) 149 | pre2 = sess2.run(y2, feed_dict={x2: img2}) 150 | max_index2 = np.argmax(pre2) 151 | max_value2 = np.max(pre2) 152 | res2.append([max_index2, max_value2]) 153 | 154 | sess1.close() 155 | sess2.close() 156 | # print(res1, '\n', res2) 157 | img_num = len(res1) 158 | for i in range(img_num): 159 | if res1[i][0] == res2[i][0]: 160 | res.append(res1[i][0]) 161 | else: 162 | if res1[i][1] > res2[i][1]: 163 | res.append(res1[i][0]) 164 | else: 165 | res.append(res2[i][0]) 166 | print(res1[i], res2[i], res[i], labels[i]) 167 | print(res, '\n', labels) 168 | 169 | # 评估结果 170 | pre_pos = np.zeros(12) 171 | true_pos = np.zeros(12) 172 | for i in range(img_num): 173 | pre_pos[res[i]] += 1 174 | if res[i] == labels[i]: 175 | true_pos[res[i]] += 1 176 | precision = true_pos/pre_pos 177 | print(pre_pos, '\n', true_pos, '\n', precision) 178 | print('准确率:', np.sum(true_pos)/img_num, '\n', '平均准确率:', np.mean(precision)) 179 | 180 | # 是否要保存结果为csv 181 | if submit: 182 | for i in res: 183 | res_name.append(class_label[i]) 184 | print('image: ', image_name_list, '\n', 'class:', res_name) 185 | # 存为csv文件。两列:图片名,预测结果 186 | data = pd.DataFrame({'0': image_name_list, '1': res_name}) 187 | data.to_csv("merge.csv", index=False, header=False) 188 | 189 | 190 | if __name__ == '__main__': 191 | evaluate_model_merge() 192 | -------------------------------------------------------------------------------- /Detect/mynet.py: -------------------------------------------------------------------------------- 1 | """ 2 | 缺陷检测,自定义网络结构 3 | author: 王建坤 4 | date: 2018-12-5 5 | """ 6 | import tensorflow as tf 7 | import tensorflow.contrib.slim as slim 8 | import numpy as np 9 | import time 10 | 11 | IMG_SIZE = 128 12 | 13 | 14 | def mynet_v1(inputs, num_classes=10, is_training=True, scope=None): 15 | with tf.variable_scope(scope, 'Mynet_v1', [inputs]): 16 | with slim.arg_scope([slim.conv2d, slim.separable_conv2d], normalizer_fn=slim.batch_norm): 17 | with slim.arg_scope([slim.batch_norm], is_training=is_training): 18 | net = slim.conv2d(inputs, 16, [11, 11], stride=4, padding='VALID', scope='Conv_1') 19 | # net = slim.batch_norm(net, scope='Conv_1_bn') 20 | net = slim.conv2d(net, 32, [1, 1], stride=1, scope='Conv_2') 21 | net = slim.separable_conv2d(net, None, [5, 5], depth_multiplier=1.0, stride=3, scope='Conv_3_dw') 22 | net = slim.conv2d(net, 64, [1, 1], stride=1, scope='Conv_3_pw') 23 | net = slim.separable_conv2d(net, None, [3, 3], depth_multiplier=1.0, stride=2, scope='Conv_4_dw') 24 | net = slim.conv2d(net, 64, [1, 1], stride=1, scope='Conv_4_pw') 25 | net = slim.conv2d(net, 64, [5, 5], padding='VALID', scope='Conv_5') 26 | pre = slim.conv2d(net, num_classes, [1, 1], scope='Conv_6') 27 | # pre = tf.squeeze(pre, [1, 2], name='squeezed') 28 | # pre = slim.softmax(pre, scope='Predictions') 29 | 30 | return pre 31 | 32 | 33 | def train(): 34 | x = tf.placeholder(tf.float32, [None, IMG_SIZE, IMG_SIZE, 1], 
name="x_input") 35 | y_ = tf.placeholder(tf.uint8, [None, 3], name="y_input") 36 | 37 | images = np.random.randint(255, size=(100, IMG_SIZE, IMG_SIZE, 1)) 38 | labels = np.random.randint(3, size=(100, 3)) 39 | 40 | y = mynet_v1(x, num_classes=3, scope=None) 41 | cross_entropy = tf.nn.softmax_cross_entropy_with_logits(logits=y, labels=y_, name='entropy') 42 | loss = tf.reduce_mean(cross_entropy, name='loss') 43 | optimizer = tf.train.AdamOptimizer(0.01) 44 | train_op = optimizer.minimize(loss) 45 | 46 | init = tf.global_variables_initializer() 47 | 48 | with tf.Session() as sess: 49 | sess.run(init) 50 | step = 0 51 | 52 | while step < 10: 53 | sess.run(train_op, feed_dict={x: images, y_: labels}) 54 | step += 1 55 | print(step) 56 | 57 | 58 | def predict(): 59 | x = tf.placeholder(tf.float32, [None, 2000, 1000, 1], name="x_input") 60 | y = mynet_v1(x, num_classes=3, scope=None) 61 | 62 | image = np.random.randint(255, size=(1, 2000, 1000, 1)) 63 | 64 | init = tf.global_variables_initializer() 65 | 66 | with tf.Session() as sess: 67 | sess.run(init) 68 | sess.run(y, feed_dict={x: image}) 69 | 70 | start_time = time.clock() 71 | res = sess.run(y, feed_dict={x: image}) 72 | end_time = time.clock() 73 | print('run time: ', end_time - start_time) 74 | print(res.shape) 75 | 76 | 77 | if __name__ == '__main__': 78 | images = tf.placeholder(tf.float32, [None, IMG_SIZE, IMG_SIZE, 1], name="x_input") 79 | res = mynet_v1(images, 2) 80 | print(res) 81 | # train() 82 | # predict() 83 | 84 | 85 | -------------------------------------------------------------------------------- /Detect/phone_seg.py: -------------------------------------------------------------------------------- 1 | """ 2 | 手机图像分割,学习目标矩形框 3 | author: 王建坤 4 | date: 2018-10-25 5 | """ 6 | import tensorflow as tf 7 | import numpy as np 8 | import os 9 | import time 10 | import tensorflow.contrib.slim as slim 11 | from sklearn.model_selection import train_test_split 12 | from nets import alexnet 13 | from Detect import utils 14 | from PIL import Image 15 | import cv2 16 | 17 | 18 | def lenet5(inputs): 19 | inputs = tf.reshape(inputs, [-1, 300, 300, 3]) 20 | net = slim.conv2d(inputs, 32, [5, 5], padding='SAME', scope='conv1') 21 | net = slim.max_pool2d(net, 2, stride=2, scope='pool1') 22 | net = slim.conv2d(net, 64, [5, 5], padding='SAME', scope='conv2') 23 | net = slim.max_pool2d(net, 2, stride=2, scope='pool2') 24 | net = slim.flatten(net, scope='flatten') 25 | net = slim.fully_connected(net, 500, scope='fully4') 26 | net = slim.fully_connected(net, 10, scope='fully5') 27 | return net 28 | 29 | 30 | MAX_STEP = 100 31 | LEARNING_RATE_BASE = 0.01 32 | LEARNING_RATE_DECAY = 0.96 33 | IMG_SIZE = 224 34 | SHOW_SIZE = 800 35 | CLASSES = 8 36 | BATCH_SIZE = 32 37 | 38 | INFO_STEP = 20 39 | SAVE_STEP = 200 40 | GLOBAL_POOL = False 41 | 42 | 43 | def train(inherit=False, model='Alex'): 44 | # 加载数据集 45 | images = np.load('../data/card_data.npy') 46 | labels = np.load('../data/card_label.npy') 47 | train_data, val_data, train_label, val_label = train_test_split(images, labels, test_size=0.2, random_state=222) 48 | 49 | # 占位符 50 | x = tf.placeholder(tf.float32, [None, IMG_SIZE, IMG_SIZE, 1], name='x-input') 51 | y_ = tf.placeholder(tf.float32, [None, CLASSES], name='y-input') 52 | my_global_step = tf.Variable(0, name='global_step', trainable=False, dtype=tf.int64) 53 | 54 | # 前向传播 55 | if model == 'Alex': 56 | log_path = "../log/Alex" 57 | model_name = 'alex.ckpt' 58 | y, _ = alexnet.alexnet_v2(x, 59 | num_classes=CLASSES, # 分类的类别 60 | is_training=True, 
# 是否在训练 61 | dropout_keep_prob=1.0, # 保留比率 62 | spatial_squeeze=True, # 压缩掉1维的维度 63 | global_pool=GLOBAL_POOL) # 输入不是规定的尺寸时,需要global_pool 64 | else: 65 | log_path = '../log/My' 66 | model_name = 'my.ckpt' 67 | y, _ = lenet5(x) 68 | 69 | # 交叉熵、损失值、优化器、准确率 70 | loss = tf.reduce_mean(tf.sqrt(tf.square(y_ - y)+0.0000001)) 71 | learning_rate = tf.train.exponential_decay(LEARNING_RATE_BASE, my_global_step, 100, LEARNING_RATE_DECAY) 72 | train_op = tf.train.GradientDescentOptimizer(learning_rate).minimize(loss, global_step=my_global_step) 73 | 74 | # 模型保存器、初始化 75 | saver = tf.train.Saver() 76 | init = tf.global_variables_initializer() 77 | 78 | tf.summary.scalar("loss", loss) 79 | # tf.summary.histogram("W1", W) 80 | merged_summary_op = tf.summary.merge_all() 81 | 82 | # 训练迭代 83 | with tf.Session() as sess: 84 | summary_writer1 = tf.summary.FileWriter('../log/curve/train', sess.graph) 85 | summary_writer2 = tf.summary.FileWriter('../log/curve/test') 86 | step = 0 87 | if inherit: 88 | ckpt = tf.train.get_checkpoint_state(log_path) 89 | if ckpt and ckpt.model_checkpoint_path: 90 | global_step = ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1] 91 | saver.restore(sess, ckpt.model_checkpoint_path) 92 | print(model, 'continue train from %s:' % global_step) 93 | step = int(global_step) 94 | else: 95 | print('Error: no checkpoint file found') 96 | return 97 | else: 98 | print(model, 'restart train:') 99 | sess.run(init) 100 | 101 | # 迭代 102 | while step < MAX_STEP: 103 | start_time = time.clock() 104 | image_batch, label_batch = utils.get_batch(train_data, train_label, BATCH_SIZE) 105 | 106 | # 训练,损失值和准确率 107 | _, train_loss = sess.run([train_op, loss], feed_dict={x: image_batch, y_: label_batch}) 108 | end_time = time.clock() 109 | runtime = end_time - start_time 110 | 111 | step += 1 112 | # 训练信息、曲线和保存模型 113 | if step % INFO_STEP == 0 or step == MAX_STEP: 114 | summary_str = sess.run(merged_summary_op, feed_dict={x: image_batch, y_: label_batch}) 115 | summary_writer1.add_summary(summary_str, step) 116 | test_loss, summary_str = sess.run([loss, merged_summary_op], feed_dict={x: val_data, y_: val_label}) 117 | summary_writer2.add_summary(summary_str, step) 118 | print('step: %d, runtime: %.2f, train loss: %.4f, test loss: %.4f' % 119 | (step, runtime, train_loss, test_loss)) 120 | 121 | if step % SAVE_STEP == 0: 122 | checkpoint_path = os.path.join(log_path, model_name) 123 | saver.save(sess, checkpoint_path, global_step=step) 124 | 125 | 126 | def predict(root_path, model='Alex'): 127 | # 占位符 128 | x = tf.placeholder(tf.float32, [None, IMG_SIZE, IMG_SIZE, 1]) 129 | 130 | # 模型保存路径,前向传播 131 | if model == 'Alex': 132 | log_path = "../log/Alex" 133 | y, _ = alexnet.alexnet_v2(x, 134 | num_classes=CLASSES, # 分类的类别 135 | is_training=True, # 是否在训练 136 | dropout_keep_prob=1.0, # 保留比率 137 | spatial_squeeze=True, # 压缩掉1维的维度 138 | global_pool=GLOBAL_POOL) # 输入不是规定的尺寸时,需要global_pool 139 | else: 140 | log_path = '../log/My' 141 | y, _ = lenet5(x) 142 | 143 | saver = tf.train.Saver() 144 | 145 | with tf.Session() as sess: 146 | # 恢复模型权重 147 | print('Reading checkpoints: ', model) 148 | ckpt = tf.train.get_checkpoint_state(log_path) 149 | if ckpt and ckpt.model_checkpoint_path: 150 | global_step = ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1] 151 | saver.restore(sess, ckpt.model_checkpoint_path) 152 | print('Loading success, global_step is %s' % global_step) 153 | else: 154 | print('Error: no checkpoint file found') 155 | return 156 | 157 | file_list = os.listdir(root_path) 158 | for file_name in 
file_list: 159 | if file_name.split('.')[-1] == 'jpg': 160 | img = Image.open(os.path.join(root_path, file_name)) 161 | img_show = img.resize((SHOW_SIZE, SHOW_SIZE)) 162 | img_show = np.array(img_show, np.uint8) 163 | img = img.resize((IMG_SIZE, IMG_SIZE)) 164 | img = np.array(img, np.uint8) 165 | img = np.expand_dims(img, axis=0) 166 | img = np.expand_dims(img, axis=3) 167 | pre = np.squeeze(sess.run(y, feed_dict={x: img})) 168 | print('predict:', pre) 169 | 170 | points = [[pre[0] * SHOW_SIZE, pre[1] * SHOW_SIZE], [pre[2] * SHOW_SIZE, pre[3] * SHOW_SIZE], 171 | [pre[4] * SHOW_SIZE, pre[5] * SHOW_SIZE], [pre[6] * SHOW_SIZE, pre[7] * SHOW_SIZE]] 172 | points = np.int0(points) 173 | cv2.polylines(img_show, [points], True, (0, 0, 255), 1) 174 | cv2.imshow('img_show', img_show) 175 | cv2.waitKey() 176 | cv2.destroyAllWindows() 177 | 178 | 179 | if __name__ == '__main__': 180 | train() 181 | # predict('../data/card') 182 | -------------------------------------------------------------------------------- /Detect/predict.py: -------------------------------------------------------------------------------- 1 | """ 2 | Detect 预测 3 | author: 王建坤 4 | date: 2018-8-10 5 | """ 6 | import sys 7 | sys.path.append("..") 8 | from Detect.config import * 9 | from PIL import Image 10 | import time 11 | 12 | # 图像目录路径 13 | # IMG_DIR = '../data/crop/pos/' 14 | # IMG_DIR = '../data/phone/' 15 | IMG_DIR = '../data/backup/cigarette_ys/wire_fail/' # normal wire_fail 16 | IS_TRAINING = False 17 | TS = [] 18 | 19 | 20 | def predict(img_path, model=MODEL_NAME): 21 | """ 22 | 预测图片 23 | :param img_path: 图片路径 24 | :param model: 模型名 25 | :return: none 26 | """ 27 | # 占位符 28 | x = tf.placeholder(tf.float32, [None, IMG_SIZE, IMG_SIZE, CHANNEL]) 29 | 30 | # 模型保存路径,前向传播 31 | if model == 'Alex': 32 | log_dir = "../log/Alex" 33 | y, _ = alexnet.alexnet_v2(x, 34 | num_classes=CLASSES, # 分类的类别 35 | is_training=IS_TRAINING, # 是否在训练 36 | dropout_keep_prob=1.0, # 保留比率 37 | spatial_squeeze=True, # 压缩掉1维的维度 38 | global_pool=GLOBAL_POOL) # 输入不是规定的尺寸时,需要global_pool 39 | elif model == 'My': 40 | log_dir = "../log/My" 41 | y, _ = mobilenet_v1.mobilenet_v1(x, 42 | num_classes=CLASSES, 43 | dropout_keep_prob=1.0, 44 | is_training=IS_TRAINING, 45 | min_depth=8, 46 | depth_multiplier=0.1, 47 | conv_defs=None, 48 | prediction_fn=None, 49 | spatial_squeeze=True, 50 | reuse=None, 51 | scope='MobilenetV1', 52 | global_pool=GLOBAL_POOL) 53 | elif model == 'Mobile': 54 | log_dir = "../log/Mobile" 55 | y, _ = mobilenet_v1.mobilenet_v1(x, 56 | num_classes=CLASSES, 57 | dropout_keep_prob=1.0, 58 | is_training=IS_TRAINING, 59 | min_depth=8, 60 | depth_multiplier=1.0, 61 | conv_defs=None, 62 | prediction_fn=None, 63 | spatial_squeeze=True, 64 | reuse=None, 65 | scope='MobilenetV1', 66 | global_pool=GLOBAL_POOL) 67 | elif model == 'VGG': 68 | log_dir = "../log/VGG" 69 | y, _ = vgg.vgg_16(x, 70 | num_classes=CLASSES, 71 | is_training=IS_TRAINING, 72 | dropout_keep_prob=1.0, 73 | spatial_squeeze=True, 74 | global_pool=GLOBAL_POOL) 75 | elif model == 'Incep4': 76 | log_dir = "../log/Incep4" 77 | y, _ = inception_v4.inception_v4(x, num_classes=CLASSES, 78 | is_training=IS_TRAINING, 79 | dropout_keep_prob=1.0, 80 | reuse=None, 81 | scope='InceptionV4', 82 | create_aux_logits=True) 83 | elif model == 'Res': 84 | log_dir = "../log/Res" 85 | y, _ = resnet_v2.resnet_v2_50(x, 86 | num_classes=CLASSES, 87 | is_training=IS_TRAINING, 88 | global_pool=GLOBAL_POOL, 89 | output_stride=None, 90 | spatial_squeeze=True, 91 | reuse=None, 92 | scope='resnet_v2_50') 93 | else: 
94 | print('Error: model name not exist') 95 | return 96 | 97 | y = tf.nn.softmax(y) 98 | saver = tf.train.Saver() 99 | 100 | with tf.Session() as sess: 101 | # 恢复模型权重 102 | print('Reading checkpoints: ', model) 103 | ckpt = tf.train.get_checkpoint_state(log_dir) 104 | if ckpt and ckpt.model_checkpoint_path: 105 | global_step = ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1] 106 | saver.restore(sess, ckpt.model_checkpoint_path) 107 | # check_tensor_name(sess) 108 | print('Loading success, global_step is %s' % global_step) 109 | else: 110 | print('Error: no checkpoint file found') 111 | return 112 | 113 | # img = read_img(IMG_DIR + img_path[0]) 114 | # # 第一次运行时间会较长 115 | # sess.run(y, feed_dict={x: np.expand_dims(img, 3)}) 116 | if img_path: 117 | img_list = img_path 118 | else: 119 | img_list = os.listdir(IMG_DIR) 120 | for img_name in img_list: 121 | # img_name = '1.jpg' 122 | print(img_name) 123 | 124 | start_time = time.clock() 125 | img_path = IMG_DIR + img_name 126 | img = read_img(img_path) 127 | # 如果输入是灰色图,要增加一维 128 | if CHANNEL == 1: 129 | img = np.expand_dims(img, axis=3) 130 | if MODEL_NAME == 'My': 131 | std_img_path = '../data/' + 'std.jpg' 132 | std_img = read_img(std_img_path) 133 | img = np.stack((std_img, img), axis=-1) 134 | predictions = sess.run(y, feed_dict={x: img}) 135 | pre = np.argmax(predictions, 1) 136 | end_time = time.clock() 137 | runtime = end_time - start_time 138 | TS.append(round(runtime*1000, 3)) 139 | print('prediction is:', predictions) 140 | print('predict class is:', pre) 141 | print('run time:', runtime) 142 | print(min(TS), max(TS[1:]), sum(TS[1:]) / (len(TS)-1)) 143 | 144 | 145 | def check_tensor_name(sess): 146 | """ 147 | 查看模型的 tensor,在调用模型之后使用 148 | :return: none 149 | """ 150 | import tensorflow.contrib.slim as slim 151 | variables_to_restore = slim.get_variables_to_restore() 152 | for var in variables_to_restore: 153 | # tensor 的名 154 | print(var.name) 155 | # tensor 的值 156 | print(sess.run(var)) 157 | 158 | 159 | def read_img(img_path): 160 | """ 161 | 读取指定路径的图片 162 | :param img_path: 图片的路径 163 | :return: numpy array of image 164 | """ 165 | img = Image.open(img_path).convert('L') 166 | img = img.resize((IMG_SIZE, IMG_SIZE)) 167 | # img = np.array(img) 168 | img = np.expand_dims(img, axis=0) 169 | return img 170 | 171 | 172 | if __name__ == '__main__': 173 | # normal, nothing, lack_cotton, lack_piece, wire_fail 174 | img_list = os.listdir('../data/backup/cigarette_ys/wire_fail') 175 | # predict(['1.jpg', '1.jpg']) 176 | predict(img_list) 177 | -------------------------------------------------------------------------------- /Detect/submit.py: -------------------------------------------------------------------------------- 1 | """ 2 | Detect 铝型材表面缺陷检测--提交结果 3 | author: 王建坤 4 | date: 2018-9-17 5 | """ 6 | import numpy as np 7 | import tensorflow as tf 8 | import pandas as pd 9 | import os 10 | from PIL import Image, ImageFile 11 | from nets import alexnet, vgg, inception_v4, resnet_v2 12 | ImageFile.LOAD_TRUNCATED_IMAGES = True 13 | 14 | CLASSES = 12 15 | IMG_SIZE = 299 16 | GLOBAL_POOL = True 17 | 18 | 19 | def submit(test_dir, model='VGG'): 20 | """ 21 | 测试集预测结果保存为 csv 文件 22 | :param test_dir: 测试集根路径 23 | :param model: 模型名称 24 | :return: none 25 | """ 26 | # 类别名称 27 | print('running submit') 28 | tf.reset_default_graph() 29 | class_label = ['norm', 'defect1', 'defect2', 'defect3', 'defect4', 'defect5', 30 | 'defect6', 'defect7', 'defect8', 'defect9', 'defect10', 'defect11'] 31 | img_index, res = [], [] 32 | 33 | # 占位符 34 | x = 
tf.placeholder(tf.float32, [None, IMG_SIZE, IMG_SIZE, 3]) 35 | 36 | # 模型保存路径,前向传播 37 | if model == 'Alex': 38 | log_dir = "../log/Alex" 39 | y, _ = alexnet.alexnet_v2(x, 40 | num_classes=CLASSES, # 分类的类别 41 | is_training=True, # 是否在训练 42 | dropout_keep_prob=1.0, # 保留比率 43 | spatial_squeeze=True, # 压缩掉1维的维度 44 | global_pool=GLOBAL_POOL) # 输入不是规定的尺寸时,需要global_pool 45 | elif model == 'VGG': 46 | log_dir = "../log/VGG" 47 | y, _ = vgg.vgg_16(x, 48 | num_classes=CLASSES, 49 | is_training=True, 50 | dropout_keep_prob=1.0, 51 | spatial_squeeze=True, 52 | global_pool=GLOBAL_POOL) 53 | elif model == 'Incep4': 54 | log_dir = "E:/alum/log/Incep4" 55 | y, _ = inception_v4.inception_v4(x, num_classes=CLASSES, 56 | is_training=True, 57 | dropout_keep_prob=1.0, 58 | reuse=None, 59 | scope='InceptionV4', 60 | create_aux_logits=True) 61 | elif model == 'Res': 62 | log_dir = "E:/alum/log/Res" 63 | y, _ = resnet_v2.resnet_v2_50(x, 64 | num_classes=CLASSES, 65 | is_training=True, 66 | global_pool=GLOBAL_POOL, 67 | output_stride=None, 68 | spatial_squeeze=True, 69 | reuse=None, 70 | scope='resnet_v2_50') 71 | else: 72 | print('Error: model name not exist') 73 | return 74 | 75 | saver = tf.train.Saver() 76 | 77 | with tf.Session() as sess: 78 | # 恢复模型权重 79 | print("Reading checkpoints...") 80 | # ckpt 有 model_checkpoint_path 和 all_model_checkpoint_paths 两个属性 81 | ckpt = tf.train.get_checkpoint_state(log_dir) 82 | if ckpt and ckpt.model_checkpoint_path: 83 | global_step = ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1] 84 | saver.restore(sess, ckpt.model_checkpoint_path) 85 | print('Loading success, global_step is %s' % global_step) 86 | else: 87 | print('Error: no checkpoint file found') 88 | return 89 | 90 | img_list = os.listdir(test_dir) 91 | for img_name in img_list: 92 | # print(img_name) 93 | # 读取图片并缩放 94 | img = Image.open(os.path.join(test_dir, img_name)) 95 | img = img.resize((IMG_SIZE, IMG_SIZE)) 96 | 97 | # 图片转为numpy的数组 98 | img = np.array(img, np.float32) 99 | img = np.expand_dims(img, axis=0) 100 | 101 | predictions = sess.run(y, feed_dict={x: img}) 102 | pre = np.argmax(predictions) 103 | img_index.append(img_name) 104 | res.append(class_label[int(pre)]) 105 | 106 | print('image: ', img_index, '\n', 'class:', res) 107 | # 存为csv文件。两列:图片名,预测结果 108 | data = pd.DataFrame({'0': img_index, '1': res}) 109 | # data.sort_index(axis=0) 110 | data.to_csv("../submit/vgg_2.csv", index=False, header=False) 111 | 112 | 113 | if __name__ == '__main__': 114 | submit('../data/guangdong_round1_test_b_20181009') 115 | -------------------------------------------------------------------------------- /Detect/utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | 工具函数 3 | author: 王建坤 4 | date: 2018-10-15 5 | """ 6 | import numpy as np 7 | import os 8 | import pandas as pd 9 | 10 | 11 | def load_npy(data_path, label_path): 12 | """ 13 | Load npy file 14 | :param data_path: the path of data directory 15 | :param label_path: the path of label directory 16 | :return: numpy format of data and label 17 | """ 18 | data = np.load(data_path) 19 | label = np.load(label_path) 20 | print('data_set shape: ', np.shape(data)) 21 | return data, label 22 | 23 | 24 | def get_batch(data, label, batch_size): 25 | """ 26 | Get a batch from numpy array. 
27 |     :param data: numpy array of data
28 |     :param label: numpy array of label
29 |     :param batch_size: the number of samples in a batch
30 |     :param total: unused; only the commented-out sequential variant below needed it
31 |     :return: a batch of the data set
32 |     """
33 |     # Sequential pass over the data set (kept for reference)
34 |     # start = start % total
35 |     # end = start + batch_size
36 |     # if end > total:
37 |     #     res = end % total
38 |     #     end = total
39 |     #     data1 = data[start:end]
40 |     #     data2 = data[0: res]
41 |     #     return np.vstack((data1, data2))
42 |     # return data[start:end]
43 | 
44 |     # Randomly sample a batch (with replacement)
45 |     total = data.shape[0]
46 |     index_list = np.random.randint(total, size=batch_size)
47 |     return data[index_list, :], label[index_list]
48 | 
49 | 
50 | def shuffle_data(data, label):
51 |     """
52 |     Shuffle the data to assure that the training model is valid.
53 |     """
54 |     permutation = np.random.permutation(label.shape[0])
55 |     batch_data = data[permutation, :]
56 |     batch_label = label[permutation]
57 |     return batch_data, batch_label
58 | 
59 | 
60 | def normalize_data(data):
61 |     """
62 |     Normalize the data by subtracting the per-channel (R, G, B) mean to accelerate training.
63 |     """
64 |     r = data[:, :, :, 2]
65 |     r_mean = np.average(r)
66 |     b = data[:, :, :, 0]
67 |     b_mean = np.average(b)
68 |     g = data[:, :, :, 1]
69 |     g_mean = np.average(g)
70 |     r = r - r_mean
71 |     b = b - b_mean
72 |     g = g - g_mean
73 |     data = np.zeros([r.shape[0], r.shape[1], r.shape[2], 3])
74 |     data[:, :, :, 0] = b
75 |     data[:, :, :, 1] = g
76 |     data[:, :, :, 2] = r
77 |     return data
78 | 
79 | 
80 | def check_tensor():
81 |     """
82 |     Inspect the tensors stored in a saved ckpt file.
83 |     """
84 |     from tensorflow.python import pywrap_tensorflow
85 | 
86 |     # Path to the ckpt file
87 |     checkpoint_path = 'E:/alum/weight/inception_v4_2016_09_09/inception_v4.ckpt'
88 | 
89 |     # Read data from checkpoint file
90 |     reader = pywrap_tensorflow.NewCheckpointReader(checkpoint_path)
91 |     var_to_shape_map = reader.get_variable_to_shape_map()
92 | 
93 |     # Print tensor name and values
94 |     for key in var_to_shape_map:
95 |         print("tensor_name: ", key)
96 |         print(reader.get_tensor(key).dtype)
97 | 
98 | 
99 | def image_to_csv(data_path, save_path):
100 |     """
101 |     Save image paths and labels as a csv file.
102 |     :param data_path: image root path
103 |     :param save_path: save path of csv file
104 |     :return: none
105 |     """
106 |     img_path, label = [], []
107 |     per_class_num = {}
108 |     # Image classes (the keys are the actual Chinese folder names on disk)
109 |     class_dic = {'正常': 0, '不导电': 1, '擦花': 2, '角位漏底': 3, '桔皮': 4, '漏底': 5, '起坑': 6, '脏点': 7}
110 |     # Iterate over the class folders
111 |     folder_list = os.listdir(data_path)
112 |     for folder in folder_list:
113 |         if folder not in class_dic.keys():
114 |             continue
115 | 
116 |         num = 0
117 |         class_path = os.path.join(data_path, folder)
118 |         img_list = os.listdir(class_path)
119 |         for img_name in img_list:
120 |             img_path.append(os.path.join(class_path, img_name))
121 |             label.append(class_dic[folder])
122 |             num += 1
123 |         per_class_num[folder] = num
124 | 
125 |     img_path_label = pd.DataFrame({'img_path': img_path, 'label': label})
126 |     img_path_label.to_csv(save_path+'/data_label.csv', index=False)
127 |     print('per class number: ', per_class_num)
128 | 
129 | 
130 | def csv_show_image():
131 |     """
132 |     Read an image path from the csv file and display the image.
133 |     :return: none
134 |     """
135 |     from PIL import Image
136 |     table = pd.read_csv('../data/data_label.csv')
137 |     img_path = table.iloc[0, 0]
138 |     img = Image.open(img_path)
139 |     img.show()
140 | 
141 | 
142 | if __name__ == "__main__":
143 |     print('running utils:')
144 |     # train_data, train_label = load_npy('../data/data_299.npy', '../data/label_299.npy')
145 |     # image_to_csv('../data/alum', '../data')
146 | 
147 | 
148 | 
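A minimal usage sketch for the helpers above (not part of the repository), assuming a `b_cig_` npy pair produced by `Image_preprocessing/image2npy.py`:

# Hypothetical driver for load_npy/shuffle_data/get_batch; file names assumed.
from Detect.utils import load_npy, shuffle_data, get_batch

data, label = load_npy('../data/b_cig_data_224.npy', '../data/b_cig_label_224.npy')
data, label = shuffle_data(data, label)            # one-off shuffle of the whole set
for step in range(10):
    x_batch, y_batch = get_batch(data, label, 32)  # random batch, sampled with replacement
    print(step, x_batch.shape, y_batch.shape)

Because `get_batch` draws indices with replacement, a "pass" over the data is statistical rather than exact; the commented-out sequential variant inside `get_batch` is the deterministic alternative.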
-------------------------------------------------------------------------------- /Image_preprocessing/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wang-jiankun/Defect-Detection/e650be31b06115d3754bf9a636baed0c6a7ba037/Image_preprocessing/__init__.py -------------------------------------------------------------------------------- /Image_preprocessing/create_data.py: -------------------------------------------------------------------------------- 1 | """ 2 | 生成手机后壳的虚假数据集--只是用来测试算法的效果 3 | author: 王建坤 4 | date: 2018-8-15 5 | """ 6 | import cv2 7 | import numpy as np 8 | import os 9 | from PIL import ImageEnhance, Image 10 | 11 | 12 | def img_to_gray(img_path): 13 | """ 14 | 彩色图片转为灰度图存储 15 | """ 16 | # src_img = cv2.imread(img_path, 0) 17 | # print(src_img.shape) 18 | # cv2.imwrite('F:/DefectDetection/img_data/zc.jpg', src_img) 19 | # # cv2.cvtColor(src_img, gray_img, cv2.COLOR_BGR2GRAY); 20 | 21 | src_img = Image.open(img_path) 22 | # 显示图片的尺寸,图片的模式,图片的格式 23 | print(src_img.size, src_img.mode, src_img.format) 24 | dst_img = src_img.convert("L") 25 | # dst_img.save(os.path.join(os.path.dirname(img_path), 'zc.jpg')) 26 | dst_img.save(os.path.join(os.path.dirname(img_path), 'wz.jpg')) 27 | 28 | 29 | def img_enhance(img_dir, save_path, num): 30 | """ 31 | 随机扰动图像的亮度和对比度 32 | """ 33 | # img_path_list = [img_dir] 34 | img_path_list = os.listdir(img_dir) 35 | for img_path in img_path_list: 36 | # src_img = Image.open(img_path) 37 | # filename_pre = img_path.split('/')[-1].split('.')[0] 38 | 39 | src_img = Image.open(os.path.join(img_dir, img_path)) 40 | # print(src_img.mode, src_img.size) 41 | filename_pre = img_path.split('.')[0] 42 | 43 | for i in range(num): 44 | brightness_factor = np.random.randint(97, 103) / 100.0 45 | brightness_image = ImageEnhance.Brightness(src_img).enhance(brightness_factor) 46 | contrast_factor = np.random.randint(98, 102) / 100.0 47 | contrast_image = ImageEnhance.Contrast(brightness_image).enhance(contrast_factor) 48 | save_name = save_path + '/' + filename_pre + '_' + str(i) + '.jpg' 49 | contrast_image.save(save_name) 50 | 51 | 52 | def draw_line(img_path, save_path, num): 53 | """ 54 | 随机画线 55 | """ 56 | src_img = cv2.imread(img_path, -1) 57 | filename_pre = save_path.split('/')[-1] 58 | 59 | for i in range(num): 60 | dst_img = src_img.copy() 61 | x1, y1 = np.random.randint(80, dst_img.shape[1]-120), np.random.randint(70, dst_img.shape[0]-70) 62 | x2, y2 = x1, y1 63 | while x1-10 < x2 < x1+10: 64 | x2 = np.random.randint(x1-50, x1+50) 65 | while y1-10 < y2 < y1+10: 66 | y2 = np.random.randint(y1-50, y1+50) 67 | # print(x1, y1, x2, y2) 68 | thickness = np.random.randint(1, 2) 69 | pixel = np.random.randint(160, 200) 70 | cv2.line(dst_img, (x1, y1), (x2, y2), pixel, thickness) 71 | 72 | save_name = save_path + '/' + filename_pre + '_' + str(i) + '.jpg' 73 | cv2.imwrite(save_name, dst_img) 74 | 75 | 76 | def draw_wz(img_path, save_path, num): 77 | """ 78 | 随机贴上污渍 79 | """ 80 | src_img = cv2.imread(img_path, -1) 81 | filename_pre = save_path.split('/')[-1] 82 | 83 | for i in range(num): 84 | dst_img = src_img.copy() 85 | x1, y1 = np.random.randint(50, dst_img.shape[1] - 60), np.random.randint(40, dst_img.shape[0] - 40) 86 | # print(x1, y1) 87 | radius = np.random.randint(5, 20) 88 | pixel = np.random.randint(180, 200) 89 | cv2.circle(dst_img, (x1, y1), radius, pixel, -1) 90 | 91 | save_name = save_path + '/' + filename_pre + '_' + str(i) + '.jpg' 92 | cv2.imwrite(save_name, dst_img) 93 | 94 
| 95 | if __name__ == '__main__': 96 | print('create_data.py is running') 97 | # img_to_gray('F:/DefectDetection/img_data/src_zc.jpg') 98 | # img_enhance('F:/DefectDetection/img_data/zc.jpg', 'F:/DefectDetection/img_data/zc', 2000) 99 | # draw_line('F:/DefectDetection/img_data/zc.jpg', 'F:/DefectDetection/img_data/hh', 500) 100 | # img_enhance('F:/DefectDetection/img_data/hh', 'F:/DefectDetection/img_data/hh', 3) 101 | # img_to_gray('F:/DefectDetection/img_data/src_wz.jpg') 102 | # draw_wz('F:/DefectDetection/img_data/zc.jpg', 'F:/DefectDetection/img_data/wz', 500) 103 | # img_enhance('F:/DefectDetection/img_data/wz', 'F:/DefectDetection/img_data/wz', 3) 104 | -------------------------------------------------------------------------------- /Image_preprocessing/image2npy.py: -------------------------------------------------------------------------------- 1 | """ 2 | 图片数据集转 npy 文件 3 | author: 王建坤 4 | date: 2018-10-15 5 | """ 6 | import numpy as np 7 | import os 8 | from PIL import Image, ImageFile 9 | from matplotlib import pyplot as plt 10 | import random as rd 11 | 12 | # 读图出错的解决方法 13 | ImageFile.LOAD_TRUNCATED_IMAGES = True 14 | IMG_SIZE = 224 # 299 15 | 16 | 17 | def multi_class_to_npy(data_path, save_path): 18 | """ 19 | 多种类别的图片数据集保存为 npy 文件 20 | :param data_path: the root path of data set 21 | :param save_path: the save path of npy 22 | :return: none 23 | """ 24 | data, label = [], [] 25 | per_class_num = {} 26 | save_name = 'b_cig_' 27 | 28 | # 图片类型 29 | # 铝型材类别 30 | # class_dic = {'正常': 0, '不导电': 1, '擦花': 2, '角位漏底': 3, '桔皮': 4, '漏底': 5, '起坑': 6, '脏点': 7} 31 | # 电子烟装配类别 32 | class_dic = {'normal': 0, 'nothing': 1, 'lack_cotton': 2, 'lack_piece': 3, 'wire_fail': 4} 33 | # 遍历文件夹 34 | folder_list = os.listdir(data_path) 35 | for folder in folder_list: 36 | # 过滤不需要的文件夹 37 | if folder not in class_dic.keys(): 38 | continue 39 | # 遍历文件夹中的文件 40 | num = 0 41 | class_path = os.path.join(data_path, folder) 42 | img_list = os.listdir(class_path) 43 | for img_name in img_list: 44 | print(img_name) 45 | # 读取图片并缩放 46 | img = Image.open(os.path.join(class_path, img_name)) 47 | # img.convert('L') 48 | # print(img.mode) 49 | img = img.resize((IMG_SIZE, IMG_SIZE)) 50 | # 添加数据和标签,根据文件夹名确定样本的标签 51 | img = np.array(img) 52 | data.append(img) 53 | label.append(class_dic[folder]) 54 | num += 1 55 | per_class_num[folder] = num 56 | 57 | # 数据集转化为numpy数组 58 | data = np.array(data, np.uint8) 59 | label = np.array(label, np.uint8) 60 | print('Per class number is: ', per_class_num) 61 | print('Data set shape is: ', np.shape(data), np.shape(label)) 62 | 63 | # 数组保存为npy 64 | data_save_path = save_path+'/'+save_name+'data_'+str(IMG_SIZE)+'.npy' 65 | label_save_path = save_path+'/'+save_name+'label_'+str(IMG_SIZE)+'.npy' 66 | np.save(data_save_path, data) 67 | np.save(label_save_path, label) 68 | 69 | 70 | def siamese_sample_to_npy(data_path, save_path): 71 | """ 72 | 构造 Siamese 数据集 73 | :param data_path: the root path of data set 74 | :param save_path: the save path of npy 75 | :return: none 76 | """ 77 | data, label = [], [] 78 | per_class_num = {'normal': 0, 'nothing': 0, 'lack_cotton': 0, 'lack_piece': 0, 'wire_fail': 0} 79 | save_name = 's_sia_cig_' 80 | 81 | # 电子烟装配类别 82 | class_dic = {'normal': 0, 'nothing': 1, 'lack_cotton': 2, 'lack_piece': 3, 'wire_fail': 4} 83 | 84 | # 标准图片 85 | std_class_path = os.path.join(data_path, 'std') 86 | std_img_list = os.listdir(std_class_path) 87 | # std_img_list = rd.sample(std_img_list, 300) 88 | for std_img_name in std_img_list: 89 | std_img = 
Image.open(os.path.join(std_class_path, std_img_name)) 90 | # img.convert('L') 91 | # print(img.mode) 92 | std_img = std_img.resize((IMG_SIZE, IMG_SIZE)) 93 | # 添加数据和标签,根据文件夹名确定样本的标签 94 | std_img = np.array(std_img) 95 | 96 | # 遍历文件夹 97 | folder_list = os.listdir(data_path) 98 | for folder in folder_list: 99 | # 过滤不需要的文件夹 100 | if folder not in class_dic.keys(): 101 | continue 102 | # 遍历文件夹中的文件 103 | class_path = os.path.join(data_path, folder) 104 | img_list = os.listdir(class_path) 105 | # 对正常类别进行降采样 106 | # if folder == 'normal': 107 | # img_list = rd.sample(img_list, 200) 108 | for img_name in img_list: 109 | print(img_name) 110 | # 读取图片并缩放 111 | img = Image.open(os.path.join(class_path, img_name)) 112 | # img.convert('L') 113 | # print(img.mode) 114 | img = img.resize((IMG_SIZE, IMG_SIZE)) 115 | # 添加数据和标签,根据文件夹名确定样本的标签 116 | sample = np.stack((std_img, img), axis=-1) 117 | data.append(sample) 118 | label.append(class_dic[folder]) 119 | per_class_num[folder] += 1 120 | 121 | # 数据集转化为numpy数组 122 | data = np.array(data, np.uint8) 123 | label = np.array(label, np.uint8) 124 | print('Per class number is: ', per_class_num) 125 | print('Data set shape is: ', np.shape(data), np.shape(label)) 126 | 127 | # 数组保存为npy 128 | data_save_path = save_path+'/'+save_name+'data_'+str(IMG_SIZE)+'.npy' 129 | label_save_path = save_path+'/'+save_name+'label_'+str(IMG_SIZE)+'.npy' 130 | np.save(data_save_path, data) 131 | np.save(label_save_path, label) 132 | 133 | 134 | def load_npy(data_path, label_path): 135 | """ 136 | 读取 npy 文件 137 | :param data_path: 138 | :param label_path: 139 | :return: 140 | """ 141 | data = np.load(data_path) 142 | label = np.load(label_path) 143 | print('data set shape is: ', np.shape(data)) 144 | return data, label 145 | 146 | 147 | def array_to_image(sample, sample_label): 148 | """ 149 | 显示一个样本的信息,图片和标签 150 | :param sample: 151 | :param sample_label: 152 | :return: 153 | """ 154 | print('sample shape: ', np.shape(sample)) 155 | sample = sample/255 156 | # plt.imshow 的像素值为 [0, 1] 157 | plt.imshow(sample) 158 | plt.show() 159 | print('sample label: ', sample_label) 160 | 161 | 162 | if __name__ == '__main__': 163 | print('running image2npy:') 164 | # alum_to_npy('../data/alum', '../data') 165 | multi_class_to_npy('../data/cigarette', '../data') 166 | # siamese_sample_to_npy('../data/cigarette', '../data') 167 | # siamese_sample_to_npy('E:/backup/cigarette', '../data') 168 | 169 | 170 | 171 | -------------------------------------------------------------------------------- /Image_preprocessing/image_crop.py: -------------------------------------------------------------------------------- 1 | """ 2 | 图像裁剪,从大图片中裁剪出小图片 3 | author: 王建坤 4 | date: 2018-12-6 5 | """ 6 | import cv2 7 | import os 8 | 9 | IMG_H_SIZE = 2000 10 | IMG_V_SIZE = 1000 11 | 12 | IMG_DIR = '../data/phone' 13 | SAVE_DIR = '../data/crop' 14 | 15 | 16 | def open_img(img_path): 17 | """ 18 | 打开指定路径的单张图像(灰度图),并缩放 19 | :param img_path: 图像路径 20 | :return: 21 | """ 22 | img = cv2.imread(img_path, 0) 23 | img = cv2.resize(img, (IMG_H_SIZE, IMG_V_SIZE)) 24 | return img 25 | 26 | 27 | def crop_img(img): 28 | for x in range(20, 101, 8): 29 | for y in range(60, 701, 8): 30 | print(x, '-', y) 31 | small_image = img[x:x+128, y:y+128] 32 | cv2.imwrite(os.path.join(SAVE_DIR, str(x)+'_'+str(y)+'.jpg'), small_image) 33 | 34 | 35 | def main(): 36 | img_list = os.listdir(IMG_DIR) 37 | img_name = img_list[0] 38 | print('image name:', img_name) 39 | img = open_img(os.path.join(IMG_DIR, img_name)) 40 | crop_img(img) 41 | 42 | 43 | if 
__name__ == '__main__': 44 | main() 45 | -------------------------------------------------------------------------------- /Image_preprocessing/json2npy.py: -------------------------------------------------------------------------------- 1 | """ 2 | 把 json文件转为数据集,并保存为npy 3 | author: 王建坤 4 | date: 2018-10-26 5 | """ 6 | # labelme_json_to_dataset E:/1.json 7 | # labelme_draw_json E:/1.json 8 | import json 9 | import os 10 | import numpy as np 11 | from PIL import Image 12 | 13 | IMG_SIZE = 224 14 | col = 2592 15 | row = 1944 16 | 17 | 18 | def json_to_npy(root_path, save_path='../data/'): 19 | """ 20 | 把 json文件转为数据集,并保存为npy 21 | :param root_path: 22 | :param save_path: 23 | :return: 24 | """ 25 | data, labels = [], [] 26 | f = open('../data/label.txt', 'a') 27 | file_list = os.listdir(root_path) 28 | for file_name in file_list: 29 | if file_name.split('.')[-1] == 'jpg': 30 | img = Image.open(os.path.join(root_path, file_name)) 31 | img = img.resize((IMG_SIZE, IMG_SIZE)) 32 | img = np.array(img) 33 | img = np.expand_dims(img, 3) 34 | data.append(img) 35 | json_path = os.path.join(root_path, file_name.split('.')[0]) + '.json' 36 | label = json_to_label(json_path) 37 | f.write(str(label) + '\n') 38 | temp = [label[0][0]/col, label[0][1]/row, label[1][0]/col, label[1][1]/row, 39 | label[2][0]/col, label[2][1]/row, label[3][0]/col, label[3][1]/row] 40 | labels.append(temp) 41 | 42 | f.close() 43 | # 数据集转化为numpy数组 44 | data = np.array(data, np.uint8) 45 | labels = np.array(labels, np.float16) 46 | print('data set shape is: ', np.shape(data), np.shape(labels)) 47 | 48 | # 数组保存为npy 49 | data_save_path = save_path + 'card_data.npy' 50 | label_save_path = save_path + 'card_label.npy' 51 | np.save(data_save_path, data) 52 | np.save(label_save_path, labels) 53 | 54 | 55 | def json_to_label(json_path): 56 | """ 57 | 抽取 json 文件中的 label 58 | :param json_path: 59 | :return: 60 | """ 61 | with open(json_path, 'r', encoding='utf-8') as load_f: 62 | data = json.load(load_f) 63 | label = data['shapes'][0]['points'] 64 | # print(label) 65 | return label 66 | 67 | 68 | if __name__ == '__main__': 69 | print('running json2npy: ') 70 | json_to_npy('../data/card') 71 | # json_to_label('E:/1.json') 72 | 73 | 74 | -------------------------------------------------------------------------------- /Siamese/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wang-jiankun/Defect-Detection/e650be31b06115d3754bf9a636baed0c6a7ba037/Siamese/__init__.py -------------------------------------------------------------------------------- /Siamese/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wang-jiankun/Defect-Detection/e650be31b06115d3754bf9a636baed0c6a7ba037/Siamese/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /Siamese/__pycache__/inference.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wang-jiankun/Defect-Detection/e650be31b06115d3754bf9a636baed0c6a7ba037/Siamese/__pycache__/inference.cpython-36.pyc -------------------------------------------------------------------------------- /Siamese/__pycache__/utils.cpython-36.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/wang-jiankun/Defect-Detection/e650be31b06115d3754bf9a636baed0c6a7ba037/Siamese/__pycache__/utils.cpython-36.pyc -------------------------------------------------------------------------------- /Siamese/bin_classify.py: -------------------------------------------------------------------------------- 1 | """ 2 | Siamese 训练 -- 二分类 3 | author: 王建坤 4 | date: 2018-10-15 5 | """ 6 | import tensorflow as tf 7 | import numpy as np 8 | import os 9 | from Siamese import inference, utils 10 | 11 | MAX_STEP = 600 12 | LEARNING_RATE = 0.01 13 | # 训练信息和保存权重的gap 14 | INFO_STEP = 20 15 | SAVE_STEP = 200 16 | 17 | BATCH_SIZE = 128 18 | IMG_SIZE = 299 19 | 20 | 21 | def train(inherit=False): 22 | """ 23 | 用二分类来训练 Siamese 24 | """ 25 | # 加载数据集 26 | images = np.load('../data/data_'+str(IMG_SIZE)+'.npy') 27 | labels = np.load('../data/label_'+str(IMG_SIZE)+'.npy') 28 | images, labels = utils.shuffle_data(images, labels) 29 | 30 | # 占位符 31 | with tf.variable_scope('bin_input_x1') as scope: 32 | x1 = tf.placeholder(tf.float32, shape=[None, IMG_SIZE, IMG_SIZE, 3]) 33 | with tf.variable_scope('bin_input_x2') as scope: 34 | x2 = tf.placeholder(tf.float32, shape=[None, IMG_SIZE, IMG_SIZE, 3]) 35 | with tf.variable_scope('bin_y') as scope: 36 | y = tf.placeholder(tf.float32, shape=[None, 1]) 37 | with tf.name_scope('bin_keep_prob') as scope: 38 | keep_prob = tf.placeholder(tf.float32) 39 | 40 | # 前向传播 41 | with tf.variable_scope('bin_siamese') as scope: 42 | out1 = inference.inference(x1, keep_prob) 43 | # 参数共享,不会生成两套参数 44 | scope.reuse_variables() 45 | out2 = inference.inference(x2, keep_prob) 46 | 47 | # 增加二分类层 48 | with tf.name_scope('bin_c') as scope: 49 | w_bc = tf.Variable(tf.truncated_normal(shape=[64, 1], stddev=0.05, mean=0), name='w_bc') 50 | b_bc = tf.Variable(tf.zeros(1), name='b_bc') 51 | out12 = tf.concat((out1, out2), 1, name='out12') 52 | pre = tf.add(tf.matmul(out12, w_bc), b_bc) 53 | 54 | # 损失函数和优化器 55 | with tf.variable_scope('metrics') as scope: 56 | cross_entropy = tf.nn.sigmoid_cross_entropy_with_logits(labels=y, logits=pre, name='entropy') 57 | loss = tf.reduce_mean(cross_entropy, name='loss') 58 | train_op = tf.train.AdamOptimizer(LEARNING_RATE).minimize(loss) 59 | 60 | init = tf.global_variables_initializer() 61 | saver = tf.train.Saver(tf.global_variables()) 62 | 63 | # 模型保存路径 64 | log_path = '../log/Siamese' 65 | 66 | with tf.Session() as sess: 67 | step = 0 68 | if inherit: 69 | ckpt = tf.train.get_checkpoint_state(log_path) 70 | if ckpt and ckpt.model_checkpoint_path: 71 | global_step = ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1] 72 | saver.restore(sess, ckpt.model_checkpoint_path) 73 | print('Siamese continue train from %s:' % global_step) 74 | step = int(global_step) 75 | else: 76 | print('Error: no checkpoint file found') 77 | return 78 | else: 79 | print('Siamese restart train:') 80 | sess.run(init) 81 | 82 | while step < MAX_STEP: 83 | # 获取一对batch的数据集 84 | x_1, y_1 = utils.get_batch(images, labels, BATCH_SIZE) 85 | x_2, y_2 = utils.get_batch(images, labels, BATCH_SIZE) 86 | # 判断对应的两个标签是否相等 87 | y_s = np.array(y_1 == y_2, dtype=np.uint8) 88 | y_s = np.expand_dims(y_s, axis=1) 89 | 90 | _, train_loss = sess.run([train_op, loss], feed_dict={x1: x_1, x2: x_2, y: y_s, keep_prob: 0.8}) 91 | 92 | step += 1 93 | # 训练信息和保存模型 94 | if step % INFO_STEP == 0 or step == MAX_STEP: 95 | print('step: %d, loss: %.4f' % (step, train_loss)) 96 | 97 | if step % SAVE_STEP == 0: 98 | checkpoint_path = os.path.join(log_path, 'sia_bin.ckpt') 99 | saver.save(sess, 
checkpoint_path, global_step=step) 100 | 101 | 102 | if __name__ == '__main__': 103 | train() 104 | -------------------------------------------------------------------------------- /Siamese/config.py: -------------------------------------------------------------------------------- 1 | """ 2 | Siamese configuration file, global parameters 3 | author: 王建坤 4 | date: 2018-12-20 5 | """ 6 | import numpy as np 7 | import tensorflow as tf 8 | from nets import alexnet, mobilenet_v1, vgg, inception_v4, resnet_v2, inception_resnet_v2 9 | from Detect import mynet 10 | import os 11 | 12 | # number of classes 13 | CLASSES = 5 14 | # image size 15 | IMG_SIZE = 224 16 | # number of image channels 17 | CHANNEL = 1 18 | # whether the input is a non-standard image size (use global pooling) 19 | GLOBAL_POOL = True 20 | # whether to use the GPU 21 | USE_GPU = True 22 | if not USE_GPU: 23 | os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" 24 | os.environ["CUDA_VISIBLE_DEVICES"] = "-1" 25 | # model name 26 | MODEL_NAME = 'Mobile' 27 | # data set paths 28 | images_path = '../data/cig_data_' + str(IMG_SIZE) + '.npy' 29 | labels_path = '../data/cig_label_' + str(IMG_SIZE) + '.npy' 30 | 31 | 32 | 33 | -------------------------------------------------------------------------------- /Siamese/data_cache.py: -------------------------------------------------------------------------------- 1 | """ 2 | Siamese data caching -- cache the feature vector that the Siamese network extracts for every sample 3 | author: 王建坤 4 | date: 2018-8-14 5 | """ 6 | import tensorflow as tf 7 | import numpy as np 8 | from Siamese import inference 9 | 10 | 11 | def cache_data(): 12 | """ 13 | Cache the feature vector that the Siamese network extracts for every sample 14 | """ 15 | # load the data set 16 | images = np.load('F:/DefectDetection/npy_data/train_data.npy') 17 | # change the shape to match your own image size 18 | images_input = tf.reshape(images, [-1, 28, 28, 1]) 19 | 20 | # forward pass 21 | result = inference.inference(images_input, 1.0) 22 | 23 | saver = tf.train.Saver() 24 | 25 | with tf.Session() as sess: 26 | saver.restore(sess, tf.train.latest_checkpoint('F:/DefectDetection/log/Siamese')) 27 | feature_vectors = sess.run(result) 28 | 29 | # save as an npy file 30 | np.save('F:/DefectDetection/npy_data/feature_vectors.npy', feature_vectors) 31 | 32 | 33 | if __name__ == '__main__': 34 | cache_data() 35 | -------------------------------------------------------------------------------- /Siamese/evaluate.py: -------------------------------------------------------------------------------- 1 | """ 2 | Siamese evaluation 3 | author: 王建坤 4 | date: 2018-8-16 5 | """ 6 | import tensorflow as tf 7 | import numpy as np 8 | from Siamese import inference, utils 9 | 10 | BATCH_SIZE = 132 11 | CLASSES = 3 12 | CODE_LEN = 32 13 | 14 | 15 | def evaluate(): 16 | # load the data set 17 | images = np.load('E:/dataset/npy/train_data.npy') 18 | labels = np.load('E:/dataset/npy/train_label.npy') 19 | total_test = images.shape[0] 20 | 21 | # placeholders 22 | with tf.variable_scope('input_x1') as scope: 23 | x1 = tf.placeholder(tf.float32, shape=[None, 227, 227, 1]) 24 | with tf.variable_scope('input_x2') as scope: 25 | x2 = tf.placeholder(tf.float32, shape=[None, 227, 227, 1]) 26 | with tf.variable_scope('y') as scope: 27 | y = tf.placeholder(tf.float32, shape=[None]) 28 | 29 | with tf.name_scope('keep_prob') as scope: 30 | keep_prob = tf.placeholder(tf.float32) 31 | 32 | # forward pass 33 | with tf.variable_scope('siamese') as scope: 34 | out1 = inference.inference(x1, keep_prob) 35 | # weights are shared, so a second set of parameters is not created 36 | scope.reuse_variables() 37 | out2 = inference.inference(x2, keep_prob) 38 | 39 | # loss function and optimizer 40 | with tf.variable_scope('metrics') as scope: 41 | loss = inference.loss_spring(out1, out2, y) 42 | 43 | saver = tf.train.Saver(tf.global_variables()) 44 | 45 | # model save path 46 | log_dir = "E:/alum/log/Siamese" 47 | 48 | with tf.Session() as
sess: 49 | print("Reading checkpoints...") 50 | ckpt = tf.train.get_checkpoint_state(log_dir) 51 | if ckpt and ckpt.model_checkpoint_path: 52 | global_step = ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1] 53 | saver.restore(sess, ckpt.model_checkpoint_path) 54 | print('Loading success, global_step is %s' % global_step) 55 | else: 56 | print('No checkpoint file found') 57 | return 58 | 59 | # 各个类别的样本数和平均编码 60 | sample_num = np.zeros(CLASSES) 61 | code_sum = np.zeros((CLASSES, CODE_LEN)) 62 | for i in range(total_test): 63 | xs_1 = images[i:i+1] 64 | xs_1 = np.expand_dims(xs_1, axis=3) 65 | l_1 = np.argmax(labels[i]) 66 | y1 = sess.run(out1, feed_dict={x1: xs_1, keep_prob: 1}) 67 | sample_num[l_1] += 1 68 | code_sum[l_1] += np.squeeze(y1) 69 | code_mean = code_sum/np.expand_dims(sample_num, axis=1) 70 | # print('sample_num: ', '\n', sample_num, '\n', 'code_mean: ', code_mean) 71 | 72 | # 类别间的距离 73 | class_diff = np.zeros((CLASSES, CLASSES)) 74 | for i in range(CLASSES): 75 | for j in range(i, CLASSES): 76 | class_diff[i][j] = class_diff[j][i] = np.mean(np.square(code_mean[i]-code_mean[j])) 77 | print('code diff between class: ', '\n', class_diff) 78 | 79 | 80 | if __name__ == '__main__': 81 | evaluate() 82 | -------------------------------------------------------------------------------- /Siamese/inference.py: -------------------------------------------------------------------------------- 1 | """ 2 | Siamese 模型网络结构--基于LeNet5 3 | author: 王建坤 4 | date: 2018-8-14 5 | """ 6 | import tensorflow as tf 7 | 8 | 9 | def inference(inputs, keep_prob): 10 | initer = tf.truncated_normal_initializer(stddev=0.1) 11 | 12 | with tf.name_scope('conv1') as scope: 13 | w1 = tf.get_variable('w1', dtype=tf.float32, shape=[11, 11, 3, 4], initializer=initer) 14 | b1 = tf.get_variable('b1', dtype=tf.float32, initializer=tf.constant(0.01, shape=[4], dtype=tf.float32)) 15 | conv1 = tf.nn.conv2d(inputs, w1, strides=[1, 1, 1, 1], padding='VALID', name='conv1') 16 | with tf.name_scope('relu1') as scope: 17 | relu1 = tf.nn.relu(tf.add(conv1, b1), name='relu1') 18 | with tf.name_scope('max_pool1') as scope: 19 | pool1 = tf.nn.max_pool(relu1, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='VALID', name='max_pool1') 20 | 21 | with tf.name_scope('conv2') as scope: 22 | w2 = tf.get_variable('w2', dtype=tf.float32, shape=[5, 5, 4, 8], initializer=initer) 23 | b2 = tf.get_variable('b2', dtype=tf.float32, initializer=tf.constant(0.01, shape=[8], dtype=tf.float32)) 24 | conv2 = tf.nn.conv2d(pool1, w2, strides=[1, 1, 1, 1], padding='VALID', name='conv2') 25 | with tf.name_scope('relu2') as scope: 26 | relu2 = tf.nn.relu(conv2 + b2, name='relu2') 27 | with tf.name_scope('max_pool2') as scope: 28 | pool2 = tf.nn.max_pool(relu2, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='VALID', name='max_pool2') 29 | 30 | with tf.name_scope('conv3') as scope: 31 | w3 = tf.get_variable('w3', dtype=tf.float32, shape=[3, 3, 8, 8], initializer=initer) 32 | b3 = tf.get_variable('b3', dtype=tf.float32, initializer=tf.constant(0.01, shape=[8], dtype=tf.float32)) 33 | conv3 = tf.nn.conv2d(pool2, w3, strides=[1, 1, 1, 1], padding='VALID', name='conv3') 34 | with tf.name_scope('relu3') as scope: 35 | relu3 = tf.nn.relu(conv3 + b3, name='relu3') 36 | with tf.name_scope('max_pool3') as scope: 37 | pool3 = tf.nn.max_pool(relu3, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='VALID', name='max_pool3') 38 | 39 | dim = pool3.get_shape()[1]*pool3.get_shape()[2]*pool3.get_shape()[3] 40 | # print(pool3.get_shape()[1], pool3.get_shape()[2], 
pool3.get_shape()[3]) 41 | 42 | with tf.name_scope('fc1') as scope: 43 | x_flat = tf.reshape(pool3, shape=[-1, dim]) 44 | w_fc1 = tf.get_variable('w_fc1', dtype=tf.float32, shape=[int(dim), 128], initializer=initer) 45 | b_fc1 = tf.get_variable('b_fc1', dtype=tf.float32, initializer=tf.constant(0.01, shape=[128], dtype=tf.float32)) 46 | fc1 = tf.add(tf.matmul(x_flat, w_fc1), b_fc1) 47 | with tf.name_scope('relu_fc1') as scope: 48 | relu_fc1 = tf.nn.relu(fc1, name='relu_fc1') 49 | # with tf.name_scope('bn_fc1') as scope: 50 | # bn_fc1 = tf.layers.batch_normalization(relu_fc1, name='bn_fc1') 51 | with tf.name_scope('drop_1') as scope: 52 | drop_1 = tf.nn.dropout(relu_fc1, keep_prob=keep_prob, name='drop_1') 53 | 54 | with tf.name_scope('fc2') as scope: 55 | w_fc2 = tf.get_variable('w_fc2', dtype=tf.float32, shape=[128, 32], initializer=initer) 56 | b_fc2 = tf.get_variable('b_fc2', dtype=tf.float32, initializer=tf.constant(0.01, shape=[32], dtype=tf.float32)) 57 | fc2 = tf.add(tf.matmul(drop_1, w_fc2), b_fc2) 58 | 59 | return fc2 60 | 61 | 62 | # Loss function 1 63 | # y=1 means the pair belongs to the same class 64 | def siamese_loss(out1, out2, y, margin=5.0): 65 | Q = tf.constant(margin, name="Q", dtype=tf.float32) 66 | E_w = tf.sqrt(tf.reduce_sum(tf.square(out1-out2), 1)) 67 | # same class 68 | pos = tf.multiply(tf.multiply(y, 2/Q), tf.square(E_w)) 69 | # different class 70 | neg = tf.multiply(tf.multiply(1-y, 2*Q), tf.exp(-2.77/Q*E_w)) 71 | loss = pos + neg 72 | loss = tf.reduce_mean(loss) 73 | return loss 74 | 75 | 76 | # Loss function 2 77 | def loss_spring(out1, out2, y, margin=5.0): 78 | eucd2 = tf.reduce_sum(tf.square(out1-out2), 1) 79 | eucd = tf.sqrt(eucd2+1e-6, name="eucd") 80 | C = tf.constant(margin, name="C") 81 | # yi*||CNN(p1i)-CNN(p2i)|| + (1-yi)*max(0, C-||CNN(p1i)-CNN(p2i)||); the squared variant is kept below, commented out 82 | # same class 83 | pos = tf.multiply(y, eucd, name="pos_loss") 84 | # different class 85 | # neg = tf.multiply(1 - y, tf.pow(tf.maximum(C - eucd, 0.0), 2), name="neg_loss") 86 | neg = tf.multiply(1-y, tf.maximum(C-eucd, 0.0), name="neg_loss") 87 | losses = tf.add(pos, neg, name="losses") 88 | loss = tf.reduce_mean(losses, name="loss") 89 | return loss 90 | -------------------------------------------------------------------------------- /Siamese/inference_2.py: -------------------------------------------------------------------------------- 1 | """ 2 | Siamese training -- multi-class classification network 3 | author: 王建坤 4 | date: 2018-8-14 5 | """ 6 | import tensorflow as tf 7 | 8 | 9 | def inference(inputs, keep_prob): 10 | """ 11 | Two-layer neural network for multi-class classification 12 | """ 13 | with tf.name_scope('mul_fc1') as scope: 14 | w_fc1 = tf.Variable(tf.truncated_normal(shape=[int(inputs.shape[1]), 64], stddev=0.05, mean=0), name='w_fc1') 15 | b_fc1 = tf.Variable(tf.zeros(64), name='b_fc1') 16 | fc1 = tf.add(tf.matmul(inputs, w_fc1), b_fc1) 17 | with tf.name_scope('mul_relu_fc1') as scope: 18 | relu_fc1 = tf.nn.relu(fc1, name='relu_fc1') 19 | with tf.name_scope('mul_drop_1') as scope: 20 | drop_1 = tf.nn.dropout(relu_fc1, keep_prob=keep_prob, name='drop_1') 21 | with tf.name_scope('mul_bn_fc1') as scope: 22 | bn_fc1 = tf.layers.batch_normalization(drop_1, name='bn_fc1') 23 | 24 | with tf.name_scope('mul_fc2') as scope: 25 | w_fc2 = tf.Variable(tf.truncated_normal(shape=[64, 10], stddev=0.05, mean=0), name='w_fc2') 26 | b_fc2 = tf.Variable(tf.zeros(10), name='b_fc2') 27 | fc2 = tf.add(tf.matmul(bn_fc1, w_fc2), b_fc2) 28 | 29 | return fc2 30 | -------------------------------------------------------------------------------- /Siamese/predict.py: -------------------------------------------------------------------------------- 1 | """ 2 | Siamese prediction 3 |
author: 王建坤 4 | date: 2018-8-16 5 | """ 6 | import tensorflow as tf 7 | import numpy as np 8 | from Siamese import inference 9 | 10 | # 搭建图和恢复模型 11 | # 占位符 12 | with tf.variable_scope('input_x1') as scope: 13 | x1 = tf.placeholder(tf.float32, shape=[None, 227, 227, 1]) 14 | with tf.variable_scope('input_x2') as scope: 15 | x2 = tf.placeholder(tf.float32, shape=[None, 227, 227, 1]) 16 | 17 | # 前向传播 18 | with tf.variable_scope('siamese') as scope: 19 | out1 = inference.inference(x1, 1) 20 | # 参数共享,不会生成两套参数 21 | scope.reuse_variables() 22 | out2 = inference.inference(x2, 1) 23 | 24 | diff = tf.sqrt(tf.reduce_sum(tf.square(out1 - out2))) 25 | 26 | saver = tf.train.Saver(tf.global_variables()) 27 | # 模型保存路径 28 | log_dir = "E:/alum/log/Siamese" 29 | 30 | sess = tf.Session() 31 | 32 | # 重复使用变量 33 | # tf.get_variable_scope().reuse_variables() 34 | 35 | print("Reading checkpoints...") 36 | ckpt = tf.train.get_checkpoint_state(log_dir) 37 | if ckpt and ckpt.model_checkpoint_path: 38 | global_step = ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1] 39 | saver.restore(sess, ckpt.model_checkpoint_path) 40 | print('Loading success, global_step is %s' % global_step) 41 | else: 42 | print('No checkpoint file found') 43 | 44 | 45 | # 单独预测,不用重复加载模型,sess是全局变量 46 | def predict(img1, img2): 47 | x_1 = np.expand_dims(img1, axis=3) 48 | x_2 = np.expand_dims(img2, axis=3) 49 | y1, y2, diff_loss = sess.run([out1, out2, diff], feed_dict={x1: x_1, x2: x_2}) 50 | print(y1, '\n', y2) 51 | print(diff_loss) 52 | 53 | # 加载模型和预测在一个函数中 54 | # def predict(img1, img2): 55 | # x_1 = np.expand_dims(img1, axis=3) 56 | # x_2 = np.expand_dims(img2, axis=3) 57 | # 58 | # # 占位符 59 | # with tf.variable_scope('input_x1') as scope: 60 | # x1 = tf.placeholder(tf.float32, shape=[None, 227, 227, 1]) 61 | # with tf.variable_scope('input_x2') as scope: 62 | # x2 = tf.placeholder(tf.float32, shape=[None, 227, 227, 1]) 63 | # 64 | # # 前向传播 65 | # with tf.variable_scope('siamese') as scope: 66 | # out1 = inference.inference(x1, 1) 67 | # # 参数共享,不会生成两套参数 68 | # scope.reuse_variables() 69 | # out2 = inference.inference(x2, 1) 70 | # 71 | # diff = tf.sqrt(tf.reduce_sum(tf.square(out1 - out2))) 72 | # 73 | # saver = tf.train.Saver(tf.global_variables()) 74 | # # 模型保存路径 75 | # log_dir = "E:/alum/log/Siamese" 76 | # 77 | # with tf.Session() as sess: 78 | # tf.get_variable_scope().reuse_variables() 79 | # print("Reading checkpoints...") 80 | # ckpt = tf.train.get_checkpoint_state(log_dir) 81 | # if ckpt and ckpt.model_checkpoint_path: 82 | # global_step = ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1] 83 | # saver.restore(sess, ckpt.model_checkpoint_path) 84 | # print('Loading success, global_step is %s' % global_step) 85 | # else: 86 | # print('No checkpoint file found') 87 | # return 88 | # 89 | # # 查看图的 tensor 90 | # # for var in tf.global_variables(): 91 | # # print(var.name) 92 | # 93 | # # 查看 tensor 的值 94 | # # print(sess.run(tf.get_default_graph().get_tensor_by_name("siamese/b2:0"))) 95 | # 96 | # y1, y2, diff_loss = sess.run([out1, out2, diff], feed_dict={x1: x_1, x2: x_2}) 97 | # # print(y1, '\n', y2) 98 | # print(diff_loss) 99 | 100 | 101 | if __name__ == '__main__': 102 | images = np.load('E:/dataset/npy/train_data.npy') 103 | labels = np.load('E:/dataset/npy/train_label.npy') 104 | index1 = 0 105 | img1 = images[index1:index1+1] 106 | index2 = 1 107 | img2 = images[index2:index2+1] 108 | print(img1.shape) 109 | predict(img1, img2) 110 | 111 | # for index2 in range(labels.shape[0]): 112 | # img2 = images[index2:index2 + 
1] 113 | # predict(img1, img2) 114 | 115 | # 两个样本间的关系 116 | # correct = np.equal(np.argmax(labels[index1]), np.argmax(labels[index2])) 117 | # print('true relationship is: ', correct) 118 | 119 | 120 | # first = 0 121 | # second = 0 122 | # third = 0 123 | # for i in range(120): 124 | # for j in range(240, 400): 125 | # ou1, ou2 = predict(images[i:i+1], images[j:j+1]) 126 | # # ou2 = predict(images[301:302], images[301:302]) 127 | # diffe = np.sqrt(np.sum(np.square(ou1 - ou2), 1)) 128 | # # print('difference is: ', diffe) 129 | # if diffe < 1: 130 | # first+=1 131 | # elif 1<= diffe <5: 132 | # second +=1 133 | # else: 134 | # third += 1 135 | # print(first, second, third) 136 | -------------------------------------------------------------------------------- /Siamese/train.py: -------------------------------------------------------------------------------- 1 | """ 2 | Siamese 训练 -- 相似度损失函数 3 | author: 王建坤 4 | date: 2018-9-30 5 | """ 6 | from sklearn.model_selection import train_test_split 7 | from Siamese import inference, utils 8 | from Siamese.config import * 9 | 10 | MAX_STEP = 2000 11 | LEARNING_RATE_BASE = 0.001 12 | LEARNING_RATE_DECAY = 0.98 13 | # 训练信息和保存权重的gap 14 | INFO_STEP = 20 15 | SAVE_STEP = 200 16 | # 图像尺寸 17 | BATCH_SIZE = 16 18 | 19 | 20 | def train(inherit=False): 21 | # 加载数据集 22 | images = np.load(images_path) 23 | labels = np.load(labels_path) 24 | train_data, val_data, train_label, val_label = train_test_split(images, labels, test_size=0.2, random_state=222) 25 | # 如果输入是灰色图,要增加一维 26 | if CHANNEL == 1: 27 | train_data = np.expand_dims(train_data, axis=3) 28 | val_data = np.expand_dims(val_data, axis=3) 29 | 30 | # 占位符 31 | x1 = tf.placeholder(tf.float32, [None, IMG_SIZE, IMG_SIZE, CHANNEL], name="x_input1") 32 | x2 = tf.placeholder(tf.float32, [None, IMG_SIZE, IMG_SIZE, CHANNEL], name="x_input2") 33 | y = tf.placeholder(tf.float32, [None], name="y_input") 34 | # keep_prob = tf.placeholder(tf.float32, name='keep_prob') 35 | my_global_step = tf.Variable(0, name='global_step', trainable=False, dtype=tf.int64) 36 | 37 | # 前向传播 38 | with tf.variable_scope('siamese') as scope: 39 | out1, _ = mobilenet_v1.mobilenet_v1(x1, 40 | num_classes=CLASSES, 41 | dropout_keep_prob=1.0, 42 | is_training=True, 43 | min_depth=8, 44 | depth_multiplier=1.0, 45 | conv_defs=None, 46 | prediction_fn=None, 47 | spatial_squeeze=True, 48 | reuse=tf.AUTO_REUSE, 49 | scope='MobilenetV1', 50 | global_pool=GLOBAL_POOL) 51 | # 参数共享,不会生成两套参数。注意定义variable时要使用get_variable() 52 | # scope.reuse_variables() 53 | out2, _ = mobilenet_v1.mobilenet_v1(x2, 54 | num_classes=CLASSES, 55 | dropout_keep_prob=1.0, 56 | is_training=True, 57 | min_depth=8, 58 | depth_multiplier=1.0, 59 | conv_defs=None, 60 | prediction_fn=None, 61 | spatial_squeeze=True, 62 | reuse=tf.AUTO_REUSE, 63 | scope='MobilenetV1', 64 | global_pool=GLOBAL_POOL) 65 | 66 | # 损失函数和优化器 67 | with tf.variable_scope('metrics') as scope: 68 | loss = inference.loss_spring(out1, out2, y) 69 | learning_rate = tf.train.exponential_decay(LEARNING_RATE_BASE, my_global_step, 100, LEARNING_RATE_DECAY) 70 | train_op = tf.train.GradientDescentOptimizer(learning_rate).minimize(loss, global_step=my_global_step) 71 | 72 | saver = tf.train.Saver(tf.global_variables()) 73 | 74 | # 模型保存路径和名称 75 | log_dir = "../log/Siamese" 76 | model_name = 'siamese.ckpt' 77 | 78 | with tf.Session() as sess: 79 | step = 0 80 | if inherit: 81 | ckpt = tf.train.get_checkpoint_state(log_dir) 82 | if ckpt and ckpt.model_checkpoint_path: 83 | global_step = 
ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1] 84 | saver.restore(sess, ckpt.model_checkpoint_path) 85 | print('Siamese continue train from %s' % global_step) 86 | step = int(global_step) 87 | else: 88 | print('No checkpoint file found') 89 | return 90 | else: 91 | print('restart train') 92 | sess.run(tf.global_variables_initializer()) 93 | 94 | while step < MAX_STEP: 95 | # 获取一对batch的数据集 96 | xs_1, ys_1 = utils.get_batch(train_data, train_label, BATCH_SIZE) 97 | xs_2, ys_2 = utils.get_batch(train_data, train_label, BATCH_SIZE) 98 | # 判断对应的两个标签是否相等 99 | y_s = np.array(ys_1 == ys_2, dtype=np.float32) 100 | 101 | _, y1, y2, train_loss = sess.run([train_op, out1, out2, loss], 102 | feed_dict={x1: xs_1, x2: xs_2, y: y_s}) 103 | 104 | # 训练信息和保存模型 105 | step += 1 106 | if step % INFO_STEP == 0 or step == MAX_STEP: 107 | print('step: %d, loss: %.4f' % (step, train_loss)) 108 | 109 | if step % SAVE_STEP == 0 or step == MAX_STEP: 110 | checkpoint_path = os.path.join(log_dir, model_name) 111 | saver.save(sess, checkpoint_path, global_step=step) 112 | 113 | 114 | if __name__ == '__main__': 115 | train() 116 | -------------------------------------------------------------------------------- /Siamese/train_2.py: -------------------------------------------------------------------------------- 1 | """ 2 | Siamese 训练 -- 多分类训练 3 | author: 王建坤 4 | date: 2018-8-14 5 | """ 6 | import tensorflow as tf 7 | import numpy as np 8 | import os 9 | from Siamese_multi import utils 10 | from Siamese import inference_2 11 | 12 | learning_rate = 0.01 13 | iterations = 100 14 | batch_size = 64 15 | 16 | 17 | def multi_classify_train(): 18 | """ 19 | 用Siamese提取的特征进行多分类 20 | """ 21 | # 加载数据集 22 | feature_vectors = np.load('F:/DefectDetection/npy_data/feature_vectors.npy') 23 | labels = np.load('F:/DefectDetection/npy_data/train_label.npy') 24 | 25 | total_train = feature_vectors.shape[0] 26 | 27 | # 占位符 28 | input_vectors = tf.placeholder(tf.float32, shape=[None, 64], name='input_vectors') 29 | y = tf.placeholder(tf.float32, shape=[None, 3]) 30 | keep_prob = tf.placeholder(tf.float32, name='multi_keep_prob') 31 | 32 | # 前向传播 33 | result = inference_2.inference(input_vectors, keep_prob) 34 | 35 | # 损失函数和优化器 36 | cross_entropy = tf.nn.softmax_cross_entropy_with_logits(labels=y, logits=result, name='multi_entropy') 37 | loss = tf.reduce_mean(cross_entropy, name='multi_loss') 38 | train_op = tf.train.AdamOptimizer(learning_rate).minimize(loss) 39 | 40 | # 模型保存路径 41 | log_dir = "F:/DefectDetection/log/Siamese_multi" 42 | saver = tf.train.Saver(tf.global_variables()) 43 | 44 | with tf.Session() as sess: 45 | sess.run(tf.global_variables_initializer()) 46 | 47 | start = 0 48 | for iter in range(iterations): 49 | # 获取一个batch的数据集 50 | vectors_batch = utils.get_batch(feature_vectors, start, batch_size, total_train) 51 | label_batch = utils.get_batch(labels, start, batch_size, total_train) 52 | _, train_loss = sess.run([train_op, loss], feed_dict={input_vectors: vectors_batch, 53 | y: label_batch, keep_prob: 0.6}) 54 | start += batch_size 55 | # if iter % 100 == 1: 56 | # print('iter {},train loss {}'.format(iter, train_loss)) 57 | 58 | # 训练信息和保存模型 59 | if (iter + 11) % 10 == 0 or (iter + 1) == iterations: 60 | print('iter: %d, loss: %.4f' % (iter, train_loss)) 61 | checkpoint_path = os.path.join(log_dir, 'Siamese_multi_model.ckpt') 62 | saver.save(sess, checkpoint_path, global_step=iter) 63 | 64 | 65 | if __name__ == '__main__': 66 | multi_classify_train() 67 | 
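A note on how the two stages above fit together: data_cache.py embeds every sample once with the trained Siamese network and caches the feature vectors, and train_2.py then fits the small fully-connected head from inference_2.py on those cached vectors. What follows is a minimal sketch of how the stages would meet at prediction time; the paths, the 64-d feature size and the checkpoint location are assumptions carried over from the scripts above, not verified against the repository.

import tensorflow as tf
import numpy as np
from Siamese import inference_2

# Cached feature vectors produced by data_cache.py (assumed path).
vectors = np.load('F:/DefectDetection/npy_data/feature_vectors.npy')

input_vectors = tf.placeholder(tf.float32, shape=[None, 64], name='input_vectors')
keep_prob = tf.placeholder(tf.float32, name='multi_keep_prob')
logits = inference_2.inference(input_vectors, keep_prob)
pred = tf.argmax(logits, axis=1)

saver = tf.train.Saver()
with tf.Session() as sess:
    # Restore the classifier head saved by multi_classify_train().
    ckpt = tf.train.latest_checkpoint('F:/DefectDetection/log/Siamese_multi')
    saver.restore(sess, ckpt)
    # Dropout is disabled at test time by feeding keep_prob=1.0.
    print(sess.run(pred, feed_dict={input_vectors: vectors[:1], keep_prob: 1.0}))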
-------------------------------------------------------------------------------- /Siamese/utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Utility functions 3 | author: 王建坤 4 | date: 2018-10-15 5 | """ 6 | import numpy as np 7 | import os 8 | import pandas as pd 9 | 10 | 11 | def load_npy(data_path, label_path): 12 | """ 13 | Load npy file 14 | :param data_path: the path of data directory 15 | :param label_path: the path of label directory 16 | :return: numpy format of data and label 17 | """ 18 | data = np.load(data_path) 19 | label = np.load(label_path) 20 | print('data_set shape: ', np.shape(data)) 21 | return data, label 22 | 23 | 24 | def get_batch(data, label, batch_size): 25 | """ 26 | Get a batch from numpy array. 27 | :param data: numpy array of data 28 | :param label: numpy array of label 29 | :param batch_size: the number of samples in a batch 30 | :note: indices are drawn at random, so a sample may appear in several batches 31 | :return: a batch of data set 32 | """ 33 | # sequential traversal of the data set (kept for reference) 34 | # start = start % total 35 | # end = start + batch_size 36 | # if end > total: 37 | # res = end % total 38 | # end = total 39 | # data1 = data[start:end] 40 | # data2 = data[0: res] 41 | # return np.vstack((data1, data2)) 42 | # return data[start:end] 43 | 44 | # randomly draw one batch 45 | total = data.shape[0] 46 | index_list = np.random.randint(total, size=batch_size) 47 | return data[index_list, :], label[index_list] 48 | 49 | 50 | def shuffle_data(data, label): 51 | """ 52 | Shuffle the data to assure that the training model is valid. 53 | """ 54 | permutation = np.random.permutation(label.shape[0]) 55 | batch_data = data[permutation, :] 56 | batch_label = label[permutation] 57 | return batch_data, batch_label 58 | 59 | 60 | def normalize_data(data): 61 | """ 62 | Normalize the data by subtracting the mean value of R, G, B to accelerate training.
63 | """ 64 | r = data[:, :, :, 2] 65 | r_mean = np.average(r) 66 | b = data[:, :, :, 0] 67 | b_mean = np.average(b) 68 | g = data[:, :, :, 1] 69 | g_mean = np.average(g) 70 | r = r - r_mean 71 | b = b - b_mean 72 | g = g - g_mean 73 | data = np.zeros([r.shape[0], r.shape[1], r.shape[2], 3]) 74 | data[:, :, :, 0] = b 75 | data[:, :, :, 1] = g 76 | data[:, :, :, 2] = r 77 | return data 78 | 79 | 80 | def check_tensor(): 81 | """ 82 | 查看保存的 ckpt文件中的 tensor 83 | """ 84 | from tensorflow.python import pywrap_tensorflow 85 | 86 | # ckpt 文件路径 87 | checkpoint_path = 'E:/alum/weight/inception_v4_2016_09_09/inception_v4.ckpt' 88 | 89 | # Read data from checkpoint file 90 | reader = pywrap_tensorflow.NewCheckpointReader(checkpoint_path) 91 | var_to_shape_map = reader.get_variable_to_shape_map() 92 | 93 | # Print tensor name and values 94 | for key in var_to_shape_map: 95 | print("tensor_name: ", key) 96 | print(reader.get_tensor(key).dtype) 97 | 98 | 99 | def image_to_csv(data_path, save_path): 100 | """ 101 | 图片路径和标签保存为 csv 102 | :param data_path: image root path 103 | :param save_path: save path of csv file 104 | :return: none 105 | """ 106 | img_path, label = [], [] 107 | per_class_num = {} 108 | # 图片类型 109 | class_dic = {'正常': 0, '不导电': 1, '擦花': 2, '角位漏底': 3, '桔皮': 4, '漏底': 5, '起坑': 6, '脏点': 7} 110 | # 遍历文件夹 111 | folder_list = os.listdir(data_path) 112 | for folder in folder_list: 113 | if folder not in class_dic.keys(): 114 | continue 115 | 116 | num = 0 117 | class_path = os.path.join(data_path, folder) 118 | img_list = os.listdir(class_path) 119 | for img_name in img_list: 120 | img_path.append(os.path.join(class_path, img_name)) 121 | label.append(class_dic[folder]) 122 | num += 1 123 | per_class_num[folder] = num 124 | 125 | img_path_label = pd.DataFrame({'img_path': img_path, 'label': label}) 126 | img_path_label.to_csv(save_path+'/data_label.csv', index=False) 127 | print('per class number: ', per_class_num) 128 | 129 | 130 | def csv_show_image(): 131 | """ 132 | 从 csv 文件读取图片路径,并显示 133 | :return: none 134 | """ 135 | from PIL import Image 136 | table = pd.read_csv('../data/data_label.csv') 137 | img_path = table.iloc[0, 0] 138 | img = Image.open(img_path) 139 | img.show() 140 | 141 | 142 | if __name__ == "__main__": 143 | print('running utils:') 144 | # train_data, train_label = load_npy('../data/data_299.npy', '../data/label_299.npy') 145 | # image_to_csv('../data/alum', '../data') 146 | 147 | 148 | -------------------------------------------------------------------------------- /Siamese_multi/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wang-jiankun/Defect-Detection/e650be31b06115d3754bf9a636baed0c6a7ba037/Siamese_multi/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /Siamese_multi/__pycache__/inference.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wang-jiankun/Defect-Detection/e650be31b06115d3754bf9a636baed0c6a7ba037/Siamese_multi/__pycache__/inference.cpython-36.pyc -------------------------------------------------------------------------------- /camera/DllTest.dll: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wang-jiankun/Defect-Detection/e650be31b06115d3754bf9a636baed0c6a7ba037/camera/DllTest.dll 
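The helpers in Siamese/utils.py above are designed to be chained: load_npy brings the arrays in, shuffle_data permutes them once, and get_batch then draws each training batch by sampling indices uniformly at random, so there is no epoch boundary and a sample can recur before every sample has been seen. A short usage sketch follows; the npy paths and the batch size are assumptions, not values fixed by the repository.

import numpy as np
from Siamese import utils

# Load and shuffle the cached data set (paths are assumed).
data, label = utils.load_npy('../data/data_224.npy', '../data/label_224.npy')
data, label = utils.shuffle_data(data, label)

for step in range(100):
    # Each call draws batch_size random indices, possibly with repeats.
    batch_x, batch_y = utils.get_batch(data, label, 16)
    assert batch_x.shape[0] == 16 and batch_y.shape[0] == 16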
-------------------------------------------------------------------------------- /camera/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wang-jiankun/Defect-Detection/e650be31b06115d3754bf9a636baed0c6a7ba037/camera/__init__.py -------------------------------------------------------------------------------- /camera/test.py: -------------------------------------------------------------------------------- 1 | from ctypes import * 2 | 3 | dll = CDLL("DllTest.dll") 4 | print(dll.add(10, 102)) 5 | -------------------------------------------------------------------------------- /mv/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wang-jiankun/Defect-Detection/e650be31b06115d3754bf9a636baed0c6a7ba037/mv/__init__.py -------------------------------------------------------------------------------- /mv/evaluate.py: -------------------------------------------------------------------------------- 1 | """ 2 | 电子烟雾化器装配检测,效果评估 3 | author: 王建坤 4 | date: 2019-3-6 5 | """ 6 | import os 7 | from mv import cig_detection 8 | import numpy as np 9 | 10 | ROOT_DIR = 'E:/backup/cigarette/' 11 | CLASSES = 5 12 | 13 | 14 | def evaluate(): 15 | """ 16 | 17 | :return: 18 | """ 19 | img_path_list, labels = [], [] 20 | 21 | class_dic = {'normal': 0, 'nothing': 1, 'lack_cotton': 2, 'lack_piece': 3, 'wire_fail': 4} 22 | # 遍历文件夹 23 | folder_list = os.listdir(ROOT_DIR) 24 | for folder in folder_list: 25 | # 过滤不需要的文件夹 26 | if folder not in class_dic.keys(): 27 | continue 28 | folder_path = os.path.join(ROOT_DIR, folder) 29 | img_list = os.listdir(folder_path) 30 | for img_name in img_list: 31 | img_path = os.path.join(folder_path, img_name) 32 | img_path_list.append(img_path) 33 | labels.append(class_dic[folder]) 34 | 35 | res = [] 36 | for img, label in zip(img_path_list, labels): 37 | print(img) 38 | detector = cig_detection.AssembleDetection(img) 39 | detector.detect() 40 | if sum(detector.res) == 0: 41 | res.append(0) 42 | else: 43 | for i in range(len(detector.res)): 44 | if detector.res[i] == 1: 45 | if i < 2: 46 | res.append(i+1) 47 | break 48 | elif i == 2: 49 | res.append(4) 50 | break 51 | else: 52 | res.append(i) 53 | break 54 | 55 | sample_labels = np.zeros(CLASSES, dtype=np.uint16) 56 | pre_pos = np.zeros(CLASSES, dtype=np.uint16) 57 | true_pos = np.zeros(CLASSES, dtype=np.uint16) 58 | 59 | test_num = len(res) 60 | normal_num = 0 61 | for i in range(test_num): 62 | pre_pos[res[i]] += 1 63 | sample_labels[labels[i]] += 1 64 | if res[i] == labels[i]: 65 | true_pos[res[i]] += 1 66 | if labels[i] == 0: 67 | normal_num += 1 68 | 69 | precision = true_pos/pre_pos 70 | recall = true_pos/sample_labels 71 | f1 = 2*precision*recall/(precision+recall) 72 | error = (pre_pos - true_pos) / (test_num - sample_labels) 73 | print('测试样本数:', test_num) 74 | print('测试数:', sample_labels) 75 | print('预测数:', pre_pos) 76 | print('正确数:', true_pos) 77 | print('Precision:', precision) 78 | print('Recall:', recall) 79 | print('F1:', f1) 80 | print('误检率:', error) 81 | print('准确率:', np.sum(true_pos)/test_num) 82 | print('总漏检率:', (pre_pos[0]-true_pos[0])/(test_num-normal_num)) 83 | print('总过杀率:', (normal_num-true_pos[0])/normal_num) 84 | 85 | 86 | if __name__ == '__main__': 87 | evaluate() 88 | 89 | 90 | -------------------------------------------------------------------------------- /nets/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | 
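The metrics block at the end of mv/evaluate.py above derives per-class precision, recall and F1 from three count vectors: true samples per class (sample_labels), predictions per class (pre_pos) and correct predictions per class (true_pos). Below is a tiny self-contained check of that arithmetic with made-up counts for three classes; the numbers are illustrative only, not real results.

import numpy as np

sample_labels = np.array([10., 5., 5.])   # true samples per class
pre_pos = np.array([9., 6., 5.])          # predictions per class
true_pos = np.array([8., 4., 4.])         # correct predictions per class

precision = true_pos / pre_pos                      # [0.889, 0.667, 0.8]
recall = true_pos / sample_labels                   # [0.8, 0.8, 0.8]
f1 = 2 * precision * recall / (precision + recall)  # harmonic mean per class
accuracy = true_pos.sum() / sample_labels.sum()     # 16/20 = 0.8
print(precision, recall, f1, accuracy)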
-------------------------------------------------------------------------------- /nets/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wang-jiankun/Defect-Detection/e650be31b06115d3754bf9a636baed0c6a7ba037/nets/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /nets/__pycache__/alexnet.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wang-jiankun/Defect-Detection/e650be31b06115d3754bf9a636baed0c6a7ba037/nets/__pycache__/alexnet.cpython-36.pyc -------------------------------------------------------------------------------- /nets/__pycache__/inception_resnet_v2.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wang-jiankun/Defect-Detection/e650be31b06115d3754bf9a636baed0c6a7ba037/nets/__pycache__/inception_resnet_v2.cpython-36.pyc -------------------------------------------------------------------------------- /nets/__pycache__/inception_utils.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wang-jiankun/Defect-Detection/e650be31b06115d3754bf9a636baed0c6a7ba037/nets/__pycache__/inception_utils.cpython-36.pyc -------------------------------------------------------------------------------- /nets/__pycache__/inception_v4.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wang-jiankun/Defect-Detection/e650be31b06115d3754bf9a636baed0c6a7ba037/nets/__pycache__/inception_v4.cpython-36.pyc -------------------------------------------------------------------------------- /nets/__pycache__/mobilenet_v1.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wang-jiankun/Defect-Detection/e650be31b06115d3754bf9a636baed0c6a7ba037/nets/__pycache__/mobilenet_v1.cpython-36.pyc -------------------------------------------------------------------------------- /nets/__pycache__/resnet_utils.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wang-jiankun/Defect-Detection/e650be31b06115d3754bf9a636baed0c6a7ba037/nets/__pycache__/resnet_utils.cpython-36.pyc -------------------------------------------------------------------------------- /nets/__pycache__/resnet_v2.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wang-jiankun/Defect-Detection/e650be31b06115d3754bf9a636baed0c6a7ba037/nets/__pycache__/resnet_v2.cpython-36.pyc -------------------------------------------------------------------------------- /nets/__pycache__/vgg.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wang-jiankun/Defect-Detection/e650be31b06115d3754bf9a636baed0c6a7ba037/nets/__pycache__/vgg.cpython-36.pyc -------------------------------------------------------------------------------- /nets/alexnet.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 
2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Contains a model definition for AlexNet. 16 | 17 | This work was first described in: 18 | ImageNet Classification with Deep Convolutional Neural Networks 19 | Alex Krizhevsky, Ilya Sutskever and Geoffrey E. Hinton 20 | 21 | and later refined in: 22 | One weird trick for parallelizing convolutional neural networks 23 | Alex Krizhevsky, 2014 24 | 25 | Here we provide the implementation proposed in "One weird trick" and not 26 | "ImageNet Classification", as per the paper, the LRN layers have been removed. 27 | 28 | Usage: 29 | with slim.arg_scope(alexnet.alexnet_v2_arg_scope()): 30 | outputs, end_points = alexnet.alexnet_v2(inputs) 31 | 32 | @@alexnet_v2 33 | """ 34 | 35 | from __future__ import absolute_import 36 | from __future__ import division 37 | from __future__ import print_function 38 | 39 | import tensorflow as tf 40 | 41 | import tensorflow.contrib.slim as slim 42 | 43 | trunc_normal = lambda stddev: tf.truncated_normal_initializer(0.0, stddev) 44 | 45 | 46 | def alexnet_v2_arg_scope(weight_decay=0.0005): 47 | with slim.arg_scope([slim.conv2d, slim.fully_connected], 48 | activation_fn=tf.nn.relu, 49 | biases_initializer=tf.constant_initializer(0.1), 50 | weights_regularizer=slim.l2_regularizer(weight_decay)): 51 | with slim.arg_scope([slim.conv2d], padding='SAME'): 52 | with slim.arg_scope([slim.max_pool2d], padding='VALID') as arg_sc: 53 | return arg_sc 54 | 55 | 56 | def alexnet_v2(inputs, 57 | num_classes=1000, 58 | is_training=True, 59 | dropout_keep_prob=0.5, 60 | spatial_squeeze=True, 61 | scope='alexnet_v2', 62 | global_pool=False): 63 | """AlexNet version 2. 64 | 65 | Described in: http://arxiv.org/pdf/1404.5997v2.pdf 66 | Parameters from: 67 | github.com/akrizhevsky/cuda-convnet2/blob/master/layers/ 68 | layers-imagenet-1gpu.cfg 69 | 70 | Note: All the fully_connected layers have been transformed to conv2d layers. 71 | To use in classification mode, resize input to 224x224 or set 72 | global_pool=True. To use in fully convolutional mode, set 73 | spatial_squeeze to false. 74 | The LRN layers have been removed and change the initializers from 75 | random_normal_initializer to xavier_initializer. 76 | 77 | Args: 78 | inputs: a tensor of size [batch_size, height, width, channels]. 79 | num_classes: the number of predicted classes. If 0 or None, the logits layer 80 | is omitted and the input features to the logits layer are returned instead. 81 | is_training: whether or not the model is being trained. 82 | dropout_keep_prob: the probability that activations are kept in the dropout 83 | layers during training. 84 | spatial_squeeze: whether or not should squeeze the spatial dimensions of the 85 | logits. Useful to remove unnecessary dimensions for classification. 86 | scope: Optional scope for the variables. 87 | global_pool: Optional boolean flag. 
If True, the input to the classification 88 | layer is avgpooled to size 1x1, for any input size. (This is not part 89 | of the original AlexNet.) 90 | 91 | Returns: 92 | net: the output of the logits layer (if num_classes is a non-zero integer), 93 | or the non-dropped-out input to the logits layer (if num_classes is 0 94 | or None). 95 | end_points: a dict of tensors with intermediate activations. 96 | """ 97 | with tf.variable_scope(scope, 'alexnet_v2', [inputs]) as sc: 98 | end_points_collection = sc.original_name_scope + '_end_points' 99 | # Collect outputs for conv2d, fully_connected and max_pool2d. 100 | with slim.arg_scope([slim.conv2d, slim.fully_connected, slim.max_pool2d], 101 | outputs_collections=[end_points_collection]): 102 | net = slim.conv2d(inputs, 64, [11, 11], 4, padding='VALID', 103 | scope='conv1') 104 | net = slim.max_pool2d(net, [3, 3], 2, scope='pool1') 105 | net = slim.conv2d(net, 192, [5, 5], scope='conv2') 106 | net = slim.max_pool2d(net, [3, 3], 2, scope='pool2') 107 | net = slim.conv2d(net, 384, [3, 3], scope='conv3') 108 | net = slim.conv2d(net, 384, [3, 3], scope='conv4') 109 | net = slim.conv2d(net, 256, [3, 3], scope='conv5') 110 | net = slim.max_pool2d(net, [3, 3], 2, scope='pool5') 111 | 112 | # Use conv2d instead of fully_connected layers. 113 | with slim.arg_scope([slim.conv2d], 114 | weights_initializer=trunc_normal(0.005), 115 | biases_initializer=tf.constant_initializer(0.1)): 116 | net = slim.conv2d(net, 4096, [5, 5], padding='VALID', 117 | scope='fc6') 118 | net = slim.dropout(net, dropout_keep_prob, is_training=is_training, 119 | scope='dropout6') 120 | net = slim.conv2d(net, 4096, [1, 1], scope='fc7') 121 | # Convert end_points_collection into a end_point dict. 122 | end_points = slim.utils.convert_collection_to_dict( 123 | end_points_collection) 124 | if global_pool: 125 | net = tf.reduce_mean(net, [1, 2], keep_dims=True, name='global_pool') 126 | end_points['global_pool'] = net 127 | if num_classes: 128 | net = slim.dropout(net, dropout_keep_prob, is_training=is_training, 129 | scope='dropout7') 130 | net = slim.conv2d(net, num_classes, [1, 1], 131 | activation_fn=None, 132 | normalizer_fn=None, 133 | biases_initializer=tf.zeros_initializer(), 134 | scope='fc8') 135 | if spatial_squeeze: 136 | net = tf.squeeze(net, [1, 2], name='fc8/squeezed') 137 | end_points[sc.name + '/fc8'] = net 138 | return net, end_points 139 | 140 | 141 | alexnet_v2.default_image_size = 224 142 | -------------------------------------------------------------------------------- /nets/cifarnet.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | """Contains a variant of the CIFAR-10 model definition.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import tensorflow as tf 22 | 23 | slim = tf.contrib.slim 24 | 25 | trunc_normal = lambda stddev: tf.truncated_normal_initializer(stddev=stddev) 26 | 27 | 28 | def cifarnet(images, num_classes=10, is_training=False, 29 | dropout_keep_prob=0.5, 30 | prediction_fn=slim.softmax, 31 | scope='CifarNet'): 32 | """Creates a variant of the CifarNet model. 33 | 34 | Note that since the output is a set of 'logits', the values fall in the 35 | interval of (-infinity, infinity). Consequently, to convert the outputs to a 36 | probability distribution over the characters, one will need to convert them 37 | using the softmax function: 38 | 39 | logits = cifarnet.cifarnet(images, is_training=False) 40 | probabilities = tf.nn.softmax(logits) 41 | predictions = tf.argmax(logits, 1) 42 | 43 | Args: 44 | images: A batch of `Tensors` of size [batch_size, height, width, channels]. 45 | num_classes: the number of classes in the dataset. If 0 or None, the logits 46 | layer is omitted and the input features to the logits layer are returned 47 | instead. 48 | is_training: specifies whether or not we're currently training the model. 49 | This variable will determine the behaviour of the dropout layer. 50 | dropout_keep_prob: the percentage of activation values that are retained. 51 | prediction_fn: a function to get predictions out of logits. 52 | scope: Optional variable_scope. 53 | 54 | Returns: 55 | net: a 2D Tensor with the logits (pre-softmax activations) if num_classes 56 | is a non-zero integer, or the input to the logits layer if num_classes 57 | is 0 or None. 58 | end_points: a dictionary from components of the network to the corresponding 59 | activation. 60 | """ 61 | end_points = {} 62 | 63 | with tf.variable_scope(scope, 'CifarNet', [images]): 64 | net = slim.conv2d(images, 64, [5, 5], scope='conv1') 65 | end_points['conv1'] = net 66 | net = slim.max_pool2d(net, [2, 2], 2, scope='pool1') 67 | end_points['pool1'] = net 68 | net = tf.nn.lrn(net, 4, bias=1.0, alpha=0.001/9.0, beta=0.75, name='norm1') 69 | net = slim.conv2d(net, 64, [5, 5], scope='conv2') 70 | end_points['conv2'] = net 71 | net = tf.nn.lrn(net, 4, bias=1.0, alpha=0.001/9.0, beta=0.75, name='norm2') 72 | net = slim.max_pool2d(net, [2, 2], 2, scope='pool2') 73 | end_points['pool2'] = net 74 | net = slim.flatten(net) 75 | end_points['Flatten'] = net 76 | net = slim.fully_connected(net, 384, scope='fc3') 77 | end_points['fc3'] = net 78 | net = slim.dropout(net, dropout_keep_prob, is_training=is_training, 79 | scope='dropout3') 80 | net = slim.fully_connected(net, 192, scope='fc4') 81 | end_points['fc4'] = net 82 | if not num_classes: 83 | return net, end_points 84 | logits = slim.fully_connected(net, num_classes, 85 | biases_initializer=tf.zeros_initializer(), 86 | weights_initializer=trunc_normal(1/192.0), 87 | weights_regularizer=None, 88 | activation_fn=None, 89 | scope='logits') 90 | 91 | end_points['Logits'] = logits 92 | end_points['Predictions'] = prediction_fn(logits, scope='Predictions') 93 | 94 | return logits, end_points 95 | cifarnet.default_image_size = 32 96 | 97 | 98 | def cifarnet_arg_scope(weight_decay=0.004): 99 | """Defines the default cifarnet argument scope. 
100 | 101 | Args: 102 | weight_decay: The weight decay to use for regularizing the model. 103 | 104 | Returns: 105 | An `arg_scope` to use for the inception v3 model. 106 | """ 107 | with slim.arg_scope( 108 | [slim.conv2d], 109 | weights_initializer=tf.truncated_normal_initializer(stddev=5e-2), 110 | activation_fn=tf.nn.relu): 111 | with slim.arg_scope( 112 | [slim.fully_connected], 113 | biases_initializer=tf.constant_initializer(0.1), 114 | weights_initializer=trunc_normal(0.04), 115 | weights_regularizer=slim.l2_regularizer(weight_decay), 116 | activation_fn=tf.nn.relu) as sc: 117 | return sc 118 | -------------------------------------------------------------------------------- /nets/cyclegan_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Tests for tensorflow.contrib.slim.nets.cyclegan.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import tensorflow as tf 22 | 23 | from nets import cyclegan 24 | 25 | 26 | # TODO(joelshor): Add a test to check generator endpoints. 
27 | class CycleganTest(tf.test.TestCase): 28 | 29 | def test_generator_inference(self): 30 | """Check one inference step.""" 31 | img_batch = tf.zeros([2, 32, 32, 3]) 32 | model_output, _ = cyclegan.cyclegan_generator_resnet(img_batch) 33 | with self.test_session() as sess: 34 | sess.run(tf.global_variables_initializer()) 35 | sess.run(model_output) 36 | 37 | def _test_generator_graph_helper(self, shape): 38 | """Check that generator can take small and non-square inputs.""" 39 | output_imgs, _ = cyclegan.cyclegan_generator_resnet(tf.ones(shape)) 40 | self.assertAllEqual(shape, output_imgs.shape.as_list()) 41 | 42 | def test_generator_graph_small(self): 43 | self._test_generator_graph_helper([4, 32, 32, 3]) 44 | 45 | def test_generator_graph_medium(self): 46 | self._test_generator_graph_helper([3, 128, 128, 3]) 47 | 48 | def test_generator_graph_nonsquare(self): 49 | self._test_generator_graph_helper([2, 80, 400, 3]) 50 | 51 | def test_generator_unknown_batch_dim(self): 52 | """Check that generator can take unknown batch dimension inputs.""" 53 | img = tf.placeholder(tf.float32, shape=[None, 32, None, 3]) 54 | output_imgs, _ = cyclegan.cyclegan_generator_resnet(img) 55 | 56 | self.assertAllEqual([None, 32, None, 3], output_imgs.shape.as_list()) 57 | 58 | def _input_and_output_same_shape_helper(self, kernel_size): 59 | img_batch = tf.placeholder(tf.float32, shape=[None, 32, 32, 3]) 60 | output_img_batch, _ = cyclegan.cyclegan_generator_resnet( 61 | img_batch, kernel_size=kernel_size) 62 | 63 | self.assertAllEqual(img_batch.shape.as_list(), 64 | output_img_batch.shape.as_list()) 65 | 66 | def input_and_output_same_shape_kernel3(self): 67 | self._input_and_output_same_shape_helper(3) 68 | 69 | def input_and_output_same_shape_kernel4(self): 70 | self._input_and_output_same_shape_helper(4) 71 | 72 | def input_and_output_same_shape_kernel5(self): 73 | self._input_and_output_same_shape_helper(5) 74 | 75 | def input_and_output_same_shape_kernel6(self): 76 | self._input_and_output_same_shape_helper(6) 77 | 78 | def _error_if_height_not_multiple_of_four_helper(self, height): 79 | self.assertRaisesRegexp( 80 | ValueError, 81 | 'The input height must be a multiple of 4.', 82 | cyclegan.cyclegan_generator_resnet, 83 | tf.placeholder(tf.float32, shape=[None, height, 32, 3])) 84 | 85 | def test_error_if_height_not_multiple_of_four_height29(self): 86 | self._error_if_height_not_multiple_of_four_helper(29) 87 | 88 | def test_error_if_height_not_multiple_of_four_height30(self): 89 | self._error_if_height_not_multiple_of_four_helper(30) 90 | 91 | def test_error_if_height_not_multiple_of_four_height31(self): 92 | self._error_if_height_not_multiple_of_four_helper(31) 93 | 94 | def _error_if_width_not_multiple_of_four_helper(self, width): 95 | self.assertRaisesRegexp( 96 | ValueError, 97 | 'The input width must be a multiple of 4.', 98 | cyclegan.cyclegan_generator_resnet, 99 | tf.placeholder(tf.float32, shape=[None, 32, width, 3])) 100 | 101 | def test_error_if_width_not_multiple_of_four_width29(self): 102 | self._error_if_width_not_multiple_of_four_helper(29) 103 | 104 | def test_error_if_width_not_multiple_of_four_width30(self): 105 | self._error_if_width_not_multiple_of_four_helper(30) 106 | 107 | def test_error_if_width_not_multiple_of_four_width31(self): 108 | self._error_if_width_not_multiple_of_four_helper(31) 109 | 110 | 111 | if __name__ == '__main__': 112 | tf.test.main() 113 | -------------------------------------------------------------------------------- /nets/dcgan_test.py: 
-------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Tests for dcgan.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | from six.moves import xrange # pylint: disable=redefined-builtin 22 | import tensorflow as tf 23 | 24 | from nets import dcgan 25 | 26 | 27 | class DCGANTest(tf.test.TestCase): 28 | 29 | def test_generator_run(self): 30 | tf.set_random_seed(1234) 31 | noise = tf.random_normal([100, 64]) 32 | image, _ = dcgan.generator(noise) 33 | with self.test_session() as sess: 34 | sess.run(tf.global_variables_initializer()) 35 | image.eval() 36 | 37 | def test_generator_graph(self): 38 | tf.set_random_seed(1234) 39 | # Check graph construction for a number of image size/depths and batch 40 | # sizes. 41 | for i, batch_size in zip(xrange(3, 7), xrange(3, 8)): 42 | tf.reset_default_graph() 43 | final_size = 2 ** i 44 | noise = tf.random_normal([batch_size, 64]) 45 | image, end_points = dcgan.generator( 46 | noise, 47 | depth=32, 48 | final_size=final_size) 49 | 50 | self.assertAllEqual([batch_size, final_size, final_size, 3], 51 | image.shape.as_list()) 52 | 53 | expected_names = ['deconv%i' % j for j in xrange(1, i)] + ['logits'] 54 | self.assertSetEqual(set(expected_names), set(end_points.keys())) 55 | 56 | # Check layer depths. 57 | for j in range(1, i): 58 | layer = end_points['deconv%i' % j] 59 | self.assertEqual(32 * 2**(i-j-1), layer.get_shape().as_list()[-1]) 60 | 61 | def test_generator_invalid_input(self): 62 | wrong_dim_input = tf.zeros([5, 32, 32]) 63 | with self.assertRaises(ValueError): 64 | dcgan.generator(wrong_dim_input) 65 | 66 | correct_input = tf.zeros([3, 2]) 67 | with self.assertRaisesRegexp(ValueError, 'must be a power of 2'): 68 | dcgan.generator(correct_input, final_size=30) 69 | 70 | with self.assertRaisesRegexp(ValueError, 'must be greater than 8'): 71 | dcgan.generator(correct_input, final_size=4) 72 | 73 | def test_discriminator_run(self): 74 | image = tf.random_uniform([5, 32, 32, 3], -1, 1) 75 | output, _ = dcgan.discriminator(image) 76 | with self.test_session() as sess: 77 | sess.run(tf.global_variables_initializer()) 78 | output.eval() 79 | 80 | def test_discriminator_graph(self): 81 | # Check graph construction for a number of image size/depths and batch 82 | # sizes. 
83 | for i, batch_size in zip(xrange(1, 6), xrange(3, 8)): 84 | tf.reset_default_graph() 85 | img_w = 2 ** i 86 | image = tf.random_uniform([batch_size, img_w, img_w, 3], -1, 1) 87 | output, end_points = dcgan.discriminator( 88 | image, 89 | depth=32) 90 | 91 | self.assertAllEqual([batch_size, 1], output.get_shape().as_list()) 92 | 93 | expected_names = ['conv%i' % j for j in xrange(1, i+1)] + ['logits'] 94 | self.assertSetEqual(set(expected_names), set(end_points.keys())) 95 | 96 | # Check layer depths. 97 | for j in range(1, i+1): 98 | layer = end_points['conv%i' % j] 99 | self.assertEqual(32 * 2**(j-1), layer.get_shape().as_list()[-1]) 100 | 101 | def test_discriminator_invalid_input(self): 102 | wrong_dim_img = tf.zeros([5, 32, 32]) 103 | with self.assertRaises(ValueError): 104 | dcgan.discriminator(wrong_dim_img) 105 | 106 | spatially_undefined_shape = tf.placeholder(tf.float32, [5, 32, None, 3]) 107 | with self.assertRaises(ValueError): 108 | dcgan.discriminator(spatially_undefined_shape) 109 | 110 | not_square = tf.zeros([5, 32, 16, 3]) 111 | with self.assertRaisesRegexp(ValueError, 'not have equal width and height'): 112 | dcgan.discriminator(not_square) 113 | 114 | not_power_2 = tf.zeros([5, 30, 30, 3]) 115 | with self.assertRaisesRegexp(ValueError, 'not a power of 2'): 116 | dcgan.discriminator(not_power_2) 117 | 118 | 119 | if __name__ == '__main__': 120 | tf.test.main() 121 | -------------------------------------------------------------------------------- /nets/inception.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | """Brings all inception models under one namespace.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | # pylint: disable=unused-import 22 | from nets.inception_resnet_v2 import inception_resnet_v2 23 | from nets.inception_resnet_v2 import inception_resnet_v2_arg_scope 24 | from nets.inception_resnet_v2 import inception_resnet_v2_base 25 | from nets.inception_v1 import inception_v1 26 | from nets.inception_v1 import inception_v1_arg_scope 27 | from nets.inception_v1 import inception_v1_base 28 | from nets.inception_v2 import inception_v2 29 | from nets.inception_v2 import inception_v2_arg_scope 30 | from nets.inception_v2 import inception_v2_base 31 | from nets.inception_v3 import inception_v3 32 | from nets.inception_v3 import inception_v3_arg_scope 33 | from nets.inception_v3 import inception_v3_base 34 | from nets.inception_v4 import inception_v4 35 | from nets.inception_v4 import inception_v4_arg_scope 36 | from nets.inception_v4 import inception_v4_base 37 | # pylint: enable=unused-import 38 | -------------------------------------------------------------------------------- /nets/inception_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Contains common code shared by all inception models. 16 | 17 | Usage of arg scope: 18 | with slim.arg_scope(inception_arg_scope()): 19 | logits, end_points = inception.inception_v3(images, num_classes, 20 | is_training=is_training) 21 | 22 | """ 23 | from __future__ import absolute_import 24 | from __future__ import division 25 | from __future__ import print_function 26 | 27 | import tensorflow as tf 28 | 29 | slim = tf.contrib.slim 30 | 31 | 32 | def inception_arg_scope(weight_decay=0.00004, 33 | use_batch_norm=True, 34 | batch_norm_decay=0.9997, 35 | batch_norm_epsilon=0.001, 36 | activation_fn=tf.nn.relu, 37 | batch_norm_updates_collections=tf.GraphKeys.UPDATE_OPS): 38 | """Defines the default arg scope for inception models. 39 | 40 | Args: 41 | weight_decay: The weight decay to use for regularizing the model. 42 | use_batch_norm: If `True`, batch_norm is applied after each convolution. 43 | batch_norm_decay: Decay for batch norm moving average. 44 | batch_norm_epsilon: Small float added to variance to avoid dividing by zero 45 | in batch norm. 46 | activation_fn: Activation function for conv2d. 47 | batch_norm_updates_collections: Collection for the update ops for 48 | batch norm. 49 | 50 | Returns: 51 | An `arg_scope` to use for the inception models. 52 | """ 53 | batch_norm_params = { 54 | # Decay for the moving averages.
55 | 'decay': batch_norm_decay, 56 | # epsilon to prevent 0s in variance. 57 | 'epsilon': batch_norm_epsilon, 58 | # collection containing update_ops. 59 | 'updates_collections': batch_norm_updates_collections, 60 | # use fused batch norm if possible. 61 | 'fused': None, 62 | } 63 | if use_batch_norm: 64 | normalizer_fn = slim.batch_norm 65 | normalizer_params = batch_norm_params 66 | else: 67 | normalizer_fn = None 68 | normalizer_params = {} 69 | # Set weight_decay for weights in Conv and FC layers. 70 | with slim.arg_scope([slim.conv2d, slim.fully_connected], 71 | weights_regularizer=slim.l2_regularizer(weight_decay)): 72 | with slim.arg_scope( 73 | [slim.conv2d], 74 | weights_initializer=slim.variance_scaling_initializer(), 75 | activation_fn=activation_fn, 76 | normalizer_fn=normalizer_fn, 77 | normalizer_params=normalizer_params) as sc: 78 | return sc 79 | -------------------------------------------------------------------------------- /nets/lenet.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Contains a variant of the LeNet model definition.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import tensorflow as tf 22 | 23 | slim = tf.contrib.slim 24 | 25 | 26 | def lenet(images, num_classes=10, is_training=False, 27 | dropout_keep_prob=0.5, 28 | prediction_fn=slim.softmax, 29 | scope='LeNet'): 30 | """Creates a variant of the LeNet model. 31 | 32 | Note that since the output is a set of 'logits', the values fall in the 33 | interval of (-infinity, infinity). Consequently, to convert the outputs to a 34 | probability distribution over the characters, one will need to convert them 35 | using the softmax function: 36 | 37 | logits = lenet.lenet(images, is_training=False) 38 | probabilities = tf.nn.softmax(logits) 39 | predictions = tf.argmax(logits, 1) 40 | 41 | Args: 42 | images: A batch of `Tensors` of size [batch_size, height, width, channels]. 43 | num_classes: the number of classes in the dataset. If 0 or None, the logits 44 | layer is omitted and the input features to the logits layer are returned 45 | instead. 46 | is_training: specifies whether or not we're currently training the model. 47 | This variable will determine the behaviour of the dropout layer. 48 | dropout_keep_prob: the percentage of activation values that are retained. 49 | prediction_fn: a function to get predictions out of logits. 50 | scope: Optional variable_scope. 51 | 52 | Returns: 53 | net: a 2D Tensor with the logits (pre-softmax activations) if num_classes 54 | is a non-zero integer, or the non-dropped-out input to the logits layer 55 | if num_classes is 0 or None.
56 | end_points: a dictionary from components of the network to the corresponding 57 | activation. 58 | """ 59 | end_points = {} 60 | 61 | with tf.variable_scope(scope, 'LeNet', [images]): 62 | net = end_points['conv1'] = slim.conv2d(images, 32, [5, 5], scope='conv1') 63 | net = end_points['pool1'] = slim.max_pool2d(net, [2, 2], 2, scope='pool1') 64 | net = end_points['conv2'] = slim.conv2d(net, 64, [5, 5], scope='conv2') 65 | net = end_points['pool2'] = slim.max_pool2d(net, [2, 2], 2, scope='pool2') 66 | net = slim.flatten(net) 67 | end_points['Flatten'] = net 68 | 69 | net = end_points['fc3'] = slim.fully_connected(net, 1024, scope='fc3') 70 | if not num_classes: 71 | return net, end_points 72 | net = end_points['dropout3'] = slim.dropout( 73 | net, dropout_keep_prob, is_training=is_training, scope='dropout3') 74 | logits = end_points['Logits'] = slim.fully_connected( 75 | net, num_classes, activation_fn=None, scope='fc4') 76 | 77 | end_points['Predictions'] = prediction_fn(logits, scope='Predictions') 78 | 79 | return logits, end_points 80 | lenet.default_image_size = 28 81 | 82 | 83 | def lenet_arg_scope(weight_decay=0.0): 84 | """Defines the default lenet argument scope. 85 | 86 | Args: 87 | weight_decay: The weight decay to use for regularizing the model. 88 | 89 | Returns: 90 | An `arg_scope` to use for the lenet model. 91 | """ 92 | with slim.arg_scope( 93 | [slim.conv2d, slim.fully_connected], 94 | weights_regularizer=slim.l2_regularizer(weight_decay), 95 | weights_initializer=tf.truncated_normal_initializer(stddev=0.1), 96 | activation_fn=tf.nn.relu) as sc: 97 | return sc 98 | -------------------------------------------------------------------------------- /nets/mobilenet/README.md: -------------------------------------------------------------------------------- 1 | # MobileNetV2 2 | This folder contains building code for MobileNetV2, based on 3 | [MobileNetV2: Inverted Residuals and Linear Bottlenecks](https://arxiv.org/abs/1801.04381) 4 | 5 | # Performance 6 | ## Latency 7 | This is the timing of [MobileNetV1](../mobilenet_v1.md) vs MobileNetV2 using 8 | TF-Lite on the large core of a Pixel 1 phone. 9 | 10 | ![mnet_v1_vs_v2_pixel1_latency.png](mnet_v1_vs_v2_pixel1_latency.png) 11 | 12 | ## MACs 13 | MACs, also sometimes known as MADDs - the number of multiply-accumulates needed 14 | to compute an inference on a single image - is a common metric for measuring the efficiency of a model. 15 | 16 | Below is a graph comparing V2 with a few selected networks. The size 17 | of each blob represents the number of parameters. Note that for [ShuffleNet](https://arxiv.org/abs/1707.01083) there 18 | are no published size numbers; we estimate them to be comparable to the MobileNetV2 numbers.
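As a rough back-of-the-envelope illustration of the metric (an addition for clarity, not part of the original README), the sketch below counts the MACs of a standard 3x3 convolution versus the depthwise-separable factorization that MobileNets are built on; the layer sizes are arbitrary example values:

```python
# Hedged sketch: per-layer MAC counts, using made-up example layer sizes.


def conv_macs(h, w, c_in, c_out, k=3):
    # Standard convolution: each output position does k*k*c_in MACs per filter.
    return h * w * c_out * k * k * c_in


def separable_macs(h, w, c_in, c_out, k=3):
    # Depthwise convolution (k*k MACs per channel) plus a 1x1 pointwise convolution.
    return h * w * c_in * k * k + h * w * c_in * c_out


h = w = 56
c_in = c_out = 128
print(conv_macs(h, w, c_in, c_out))       # ~462M MACs
print(separable_macs(h, w, c_in, c_out))  # ~55M MACs, roughly 8x fewer
```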
19 | 20 | ![madds_top1_accuracy](madds_top1_accuracy.png) 21 | 22 | # Pretrained models 23 | ## Imagenet Checkpoints 24 | 25 | Classification Checkpoint | MACs (M)| Parameters (M)| Top 1 Accuracy| Top 5 Accuracy | Mobile CPU (ms) Pixel 1 26 | ---------------------------|---------|---------------|---------|----|------------- 27 | | [mobilenet_v2_1.4_224](https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_1.4_224.tgz) | 582 | 6.06 | 75.0 | 92.5 | 138.0 28 | | [mobilenet_v2_1.3_224](https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_1.3_224.tgz) | 509 | 5.34 | 74.4 | 92.1 | 123.0 29 | | [mobilenet_v2_1.0_224](https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_1.0_224.tgz) | 300 | 3.47 | 71.8 | 91.0 | 73.8 30 | | [mobilenet_v2_1.0_192](https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_1.0_192.tgz) | 221 | 3.47 | 70.7 | 90.1 | 55.1 31 | | [mobilenet_v2_1.0_160](https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_1.0_160.tgz) | 154 | 3.47 | 68.8 | 89.0 | 40.2 32 | | [mobilenet_v2_1.0_128](https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_1.0_128.tgz) | 99 | 3.47 | 65.3 | 86.9 | 27.6 33 | | [mobilenet_v2_1.0_96](https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_1.0_96.tgz) | 56 | 3.47 | 60.3 | 83.2 | 17.6 34 | | [mobilenet_v2_0.75_224](https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_0.75_224.tgz) | 209 | 2.61 | 69.8 | 89.6 | 55.8 35 | | [mobilenet_v2_0.75_192](https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_0.75_192.tgz) | 153 | 2.61 | 68.7 | 88.9 | 41.6 36 | | [mobilenet_v2_0.75_160](https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_0.75_160.tgz) | 107 | 2.61 | 66.4 | 87.3 | 30.4 37 | | [mobilenet_v2_0.75_128](https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_0.75_128.tgz) | 69 | 2.61 | 63.2 | 85.3 | 21.9 38 | | [mobilenet_v2_0.75_96](https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_0.75_96.tgz) | 39 | 2.61 | 58.8 | 81.6 | 14.2 39 | | [mobilenet_v2_0.5_224](https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_0.5_224.tgz) | 97 | 1.95 | 65.4 | 86.4 | 28.7 40 | | [mobilenet_v2_0.5_192](https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_0.5_192.tgz) | 71 | 1.95 | 63.9 | 85.4 | 21.1 41 | | [mobilenet_v2_0.5_160](https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_0.5_160.tgz) | 50 | 1.95 | 61.0 | 83.2 | 14.9 42 | | [mobilenet_v2_0.5_128](https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_0.5_128.tgz) | 32 | 1.95 | 57.7 | 80.8 | 9.9 43 | | [mobilenet_v2_0.5_96](https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_0.5_96.tgz) | 18 | 1.95 | 51.2 | 75.8 | 6.4 44 | | [mobilenet_v2_0.35_224](https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_0.35_224.tgz) | 59 | 1.66 | 60.3 | 82.9 | 19.7 45 | | [mobilenet_v2_0.35_192](https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_0.35_192.tgz) | 43 | 1.66 | 58.2 | 81.2 | 14.6 46 | | [mobilenet_v2_0.35_160](https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_0.35_160.tgz) | 30 | 1.66 | 55.7 | 79.1 | 10.5 47 | | [mobilenet_v2_0.35_128](https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_0.35_128.tgz) | 20 | 1.66 | 50.8 | 75.0 | 6.9 48 | | [mobilenet_v2_0.35_96](https://storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_0.35_96.tgz) | 11 | 1.66 | 45.5 | 70.4 | 4.5 49 | 50 | # Training 
51 | The numbers above can be reproduced using slim's `train_image_classifier`. 52 | Below is the set of parameters that achieves 72.0% for full size MobileNetV2, after about 700K steps when trained on 8 GPUs. 53 | If trained on a single GPU, full convergence takes about 5.5M steps. Also note that learning rate and 54 | num_epochs_per_decay both need to be adjusted depending on how many GPUs are being 55 | used due to slim's internal averaging. 56 | 57 | ```bash 58 | --model_name="mobilenet_v2" 59 | --learning_rate=0.045 * NUM_GPUS # slim internally averages clones so we compensate 60 | --preprocessing_name="inception_v2" 61 | --label_smoothing=0.1 62 | --moving_average_decay=0.9999 63 | --batch_size=96 64 | --num_clones=NUM_GPUS # you can use any number here between 1 and 8 depending on your hardware setup. 65 | --learning_rate_decay_factor=0.98 66 | --num_epochs_per_decay=2.5 / NUM_GPUS # train_image_classifier does per-clone epochs 67 | ``` 68 | 69 | # Example 70 | 71 | 72 | See this [ipython notebook](mobilenet_example.ipynb) or open and run the network directly in [Colaboratory](https://colab.research.google.com/github/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet_example.ipynb). 73 | 74 | -------------------------------------------------------------------------------- /nets/mobilenet/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wang-jiankun/Defect-Detection/e650be31b06115d3754bf9a636baed0c6a7ba037/nets/mobilenet/__init__.py -------------------------------------------------------------------------------- /nets/mobilenet/madds_top1_accuracy.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wang-jiankun/Defect-Detection/e650be31b06115d3754bf9a636baed0c6a7ba037/nets/mobilenet/madds_top1_accuracy.png -------------------------------------------------------------------------------- /nets/mobilenet/mnet_v1_vs_v2_pixel1_latency.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wang-jiankun/Defect-Detection/e650be31b06115d3754bf9a636baed0c6a7ba037/nets/mobilenet/mnet_v1_vs_v2_pixel1_latency.png -------------------------------------------------------------------------------- /nets/mobilenet_v1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wang-jiankun/Defect-Detection/e650be31b06115d3754bf9a636baed0c6a7ba037/nets/mobilenet_v1.png -------------------------------------------------------------------------------- /nets/mobilenet_v1_eval.py: -------------------------------------------------------------------------------- 1 | # Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License.
14 | # ============================================================================== 15 | """Validate mobilenet_v1 with options for quantization.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import math 22 | import tensorflow as tf 23 | 24 | from datasets import dataset_factory 25 | from nets import mobilenet_v1 26 | from preprocessing import preprocessing_factory 27 | 28 | slim = tf.contrib.slim 29 | 30 | flags = tf.app.flags 31 | 32 | flags.DEFINE_string('master', '', 'Session master') 33 | flags.DEFINE_integer('batch_size', 250, 'Batch size') 34 | flags.DEFINE_integer('num_classes', 1001, 'Number of classes to distinguish') 35 | flags.DEFINE_integer('num_examples', 50000, 'Number of examples to evaluate') 36 | flags.DEFINE_integer('image_size', 224, 'Input image resolution') 37 | flags.DEFINE_float('depth_multiplier', 1.0, 'Depth multiplier for mobilenet') 38 | flags.DEFINE_bool('quantize', False, 'Quantize training') 39 | flags.DEFINE_string('checkpoint_dir', '', 'The directory for checkpoints') 40 | flags.DEFINE_string('eval_dir', '', 'Directory for writing eval event logs') 41 | flags.DEFINE_string('dataset_dir', '', 'Location of dataset') 42 | 43 | FLAGS = flags.FLAGS 44 | 45 | 46 | def imagenet_input(is_training): 47 | """Data reader for imagenet. 48 | 49 | Reads in imagenet data and performs pre-processing on the images. 50 | 51 | Args: 52 | is_training: bool specifying if train or validation dataset is needed. 53 | Returns: 54 | A batch of images and labels. 55 | """ 56 | if is_training: 57 | dataset = dataset_factory.get_dataset('imagenet', 'train', 58 | FLAGS.dataset_dir) 59 | else: 60 | dataset = dataset_factory.get_dataset('imagenet', 'validation', 61 | FLAGS.dataset_dir) 62 | 63 | provider = slim.dataset_data_provider.DatasetDataProvider( 64 | dataset, 65 | shuffle=is_training, 66 | common_queue_capacity=2 * FLAGS.batch_size, 67 | common_queue_min=FLAGS.batch_size) 68 | [image, label] = provider.get(['image', 'label']) 69 | 70 | image_preprocessing_fn = preprocessing_factory.get_preprocessing( 71 | 'mobilenet_v1', is_training=is_training) 72 | 73 | image = image_preprocessing_fn(image, FLAGS.image_size, FLAGS.image_size) 74 | 75 | images, labels = tf.train.batch( 76 | tensors=[image, label], 77 | batch_size=FLAGS.batch_size, 78 | num_threads=4, 79 | capacity=5 * FLAGS.batch_size) 80 | return images, labels 81 | 82 | 83 | def metrics(logits, labels): 84 | """Specify the metrics for eval. 85 | 86 | Args: 87 | logits: Logits output from the graph. 88 | labels: Ground truth labels for inputs. 89 | 90 | Returns: 91 | Eval Op for the graph. 92 | """ 93 | labels = tf.squeeze(labels) 94 | names_to_values, names_to_updates = slim.metrics.aggregate_metric_map({ 95 | 'Accuracy': tf.metrics.accuracy(tf.argmax(logits, 1), labels), 96 | 'Recall_5': tf.metrics.recall_at_k(labels, logits, 5), 97 | }) 98 | for name, value in names_to_values.items():  # items() for Python 3 compatibility 99 | slim.summaries.add_scalar_summary( 100 | value, name, prefix='eval', print_summary=True) 101 | return names_to_updates.values() 102 | 103 | 104 | def build_model(): 105 | """Build the mobilenet_v1 model for evaluation. 106 | 107 | Returns: 108 | g: graph with rewrites after insertion of quantization ops and batch norm 109 | folding. 110 | eval_ops: eval ops for inference.
112 | """ 113 | g = tf.Graph() 114 | with g.as_default(): 115 | inputs, labels = imagenet_input(is_training=False) 116 | 117 | scope = mobilenet_v1.mobilenet_v1_arg_scope( 118 | is_training=False, weight_decay=0.0) 119 | with slim.arg_scope(scope): 120 | logits, _ = mobilenet_v1.mobilenet_v1( 121 | inputs, 122 | is_training=False, 123 | depth_multiplier=FLAGS.depth_multiplier, 124 | num_classes=FLAGS.num_classes) 125 | 126 | if FLAGS.quantize: 127 | tf.contrib.quantize.create_eval_graph() 128 | 129 | eval_ops = metrics(logits, labels) 130 | 131 | return g, eval_ops 132 | 133 | 134 | def eval_model(): 135 | """Evaluates mobilenet_v1.""" 136 | g, eval_ops = build_model() 137 | with g.as_default(): 138 | num_batches = math.ceil(FLAGS.num_examples / float(FLAGS.batch_size)) 139 | slim.evaluation.evaluate_once( 140 | FLAGS.master, 141 | FLAGS.checkpoint_dir, 142 | logdir=FLAGS.eval_dir, 143 | num_evals=num_batches, 144 | eval_op=eval_ops) 145 | 146 | 147 | def main(unused_arg): 148 | eval_model() 149 | 150 | 151 | if __name__ == '__main__': 152 | tf.app.run(main) 153 | -------------------------------------------------------------------------------- /nets/nasnet/README.md: -------------------------------------------------------------------------------- 1 | # TensorFlow-Slim NASNet-A Implementation/Checkpoints 2 | This directory contains the code for the NASNet-A model from the paper 3 | [Learning Transferable Architectures for Scalable Image Recognition](https://arxiv.org/abs/1707.07012) by Zoph et al. 4 | In nasnet.py there are three different configurations of NASNet-A that are implemented. One of the models is the NASNet-A built for CIFAR-10 and the 5 | other two are variants of NASNet-A trained on ImageNet, which are listed below. 6 | 7 | # Pre-Trained Models 8 | Two NASNet-A checkpoints are available that have been trained on the 9 | [ILSVRC-2012-CLS](http://www.image-net.org/challenges/LSVRC/2012/) 10 | image classification dataset. Accuracies were computed by evaluating using a single image crop. 11 | 12 | Model Checkpoint | Million MACs | Million Parameters | Top-1 Accuracy| Top-5 Accuracy | 13 | :----:|:------------:|:----------:|:-------:|:-------:| 14 | [NASNet-A_Mobile_224](https://storage.googleapis.com/download.tensorflow.org/models/nasnet-a_mobile_04_10_2017.tar.gz)|564|5.3|74.0|91.6| 15 | [NASNet-A_Large_331](https://storage.googleapis.com/download.tensorflow.org/models/nasnet-a_large_04_10_2017.tar.gz)|23800|88.9|82.7|96.2| 16 | 17 | 18 | Here is an example of how to download the NASNet-A_Mobile_224 checkpoint. The way to download the NASNet-A_Large_331 is the same. 19 | 20 | ```shell 21 | CHECKPOINT_DIR=/tmp/checkpoints 22 | mkdir ${CHECKPOINT_DIR} 23 | cd ${CHECKPOINT_DIR} 24 | wget https://storage.googleapis.com/download.tensorflow.org/models/nasnet-a_mobile_04_10_2017.tar.gz 25 | tar -xvf nasnet-a_mobile_04_10_2017.tar.gz 26 | rm nasnet-a_mobile_04_10_2017.tar.gz 27 | ``` 28 | More information on integrating NASNet Models into your project can be found at the [TF-Slim Image Classification Library](https://github.com/tensorflow/models/blob/master/research/slim/README.md). 29 | 30 | To get started running models on-device go to [TensorFlow Mobile](https://www.tensorflow.org/mobile/).
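For quick experimentation outside the eval script, here is a minimal loading sketch (an illustrative addition, not from the original README); it assumes the NASNet-A_Mobile_224 checkpoint has been extracted to `/tmp/checkpoints` as above and that the input images are already preprocessed to 224x224:

```python
# Minimal sketch under the assumptions stated above: restore NASNet-A Mobile
# for inference with TF-Slim (TF 1.x APIs, matching the rest of this repo).
import tensorflow as tf
from nets.nasnet import nasnet

slim = tf.contrib.slim

images = tf.placeholder(tf.float32, [1, 224, 224, 3])
with slim.arg_scope(nasnet.nasnet_mobile_arg_scope()):
    logits, end_points = nasnet.build_nasnet_mobile(
        images, num_classes=1001, is_training=False)

saver = tf.train.Saver(slim.get_model_variables())
with tf.Session() as sess:
    saver.restore(sess, '/tmp/checkpoints/model.ckpt')
    # end_points['Predictions'] holds the softmax probabilities.
```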
31 | 32 | ## Sample Commands for using NASNet-A Mobile and Large Checkpoints for Inference 33 | ------- 34 | Run eval with the NASNet-A mobile ImageNet model 35 | 36 | ```shell 37 | DATASET_DIR=/tmp/imagenet 38 | EVAL_DIR=/tmp/tfmodel/eval 39 | CHECKPOINT_DIR=/tmp/checkpoints/model.ckpt 40 | python tensorflow_models/research/slim/eval_image_classifier \ 41 | --checkpoint_path=${CHECKPOINT_DIR} \ 42 | --eval_dir=${EVAL_DIR} \ 43 | --dataset_dir=${DATASET_DIR} \ 44 | --dataset_name=imagenet \ 45 | --dataset_split_name=validation \ 46 | --model_name=nasnet_mobile \ 47 | --eval_image_size=224 48 | ``` 49 | 50 | Run eval with the NASNet-A large ImageNet model 51 | 52 | ```shell 53 | DATASET_DIR=/tmp/imagenet 54 | EVAL_DIR=/tmp/tfmodel/eval 55 | CHECKPOINT_DIR=/tmp/checkpoints/model.ckpt 56 | python tensorflow_models/research/slim/eval_image_classifier \ 57 | --checkpoint_path=${CHECKPOINT_DIR} \ 58 | --eval_dir=${EVAL_DIR} \ 59 | --dataset_dir=${DATASET_DIR} \ 60 | --dataset_name=imagenet \ 61 | --dataset_split_name=validation \ 62 | --model_name=nasnet_large \ 63 | --eval_image_size=331 64 | ``` 65 | -------------------------------------------------------------------------------- /nets/nasnet/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /nets/nasnet/nasnet_utils_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | """Tests for slim.nets.nasnet.nasnet_utils.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import tensorflow as tf 22 | 23 | from nets.nasnet import nasnet_utils 24 | 25 | 26 | class NasnetUtilsTest(tf.test.TestCase): 27 | 28 | def testCalcReductionLayers(self): 29 | num_cells = 18 30 | num_reduction_layers = 2 31 | reduction_layers = nasnet_utils.calc_reduction_layers( 32 | num_cells, num_reduction_layers) 33 | self.assertEqual(len(reduction_layers), 2) 34 | self.assertEqual(reduction_layers[0], 6) 35 | self.assertEqual(reduction_layers[1], 12) 36 | 37 | def testGetChannelIndex(self): 38 | data_formats = ['NHWC', 'NCHW'] 39 | for data_format in data_formats: 40 | index = nasnet_utils.get_channel_index(data_format) 41 | correct_index = 3 if data_format == 'NHWC' else 1 42 | self.assertEqual(index, correct_index) 43 | 44 | def testGetChannelDim(self): 45 | data_formats = ['NHWC', 'NCHW'] 46 | shape = [10, 20, 30, 40] 47 | for data_format in data_formats: 48 | dim = nasnet_utils.get_channel_dim(shape, data_format) 49 | correct_dim = shape[3] if data_format == 'NHWC' else shape[1] 50 | self.assertEqual(dim, correct_dim) 51 | 52 | def testGlobalAvgPool(self): 53 | data_formats = ['NHWC', 'NCHW'] 54 | inputs = tf.placeholder(tf.float32, (5, 10, 20, 10)) 55 | for data_format in data_formats: 56 | output = nasnet_utils.global_avg_pool( 57 | inputs, data_format) 58 | self.assertEqual(output.shape, [5, 10]) 59 | 60 | 61 | if __name__ == '__main__': 62 | tf.test.main() 63 | -------------------------------------------------------------------------------- /nets/nets_factory.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | """Contains a factory for building various models.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | import functools 21 | 22 | import tensorflow as tf 23 | 24 | from nets import alexnet 25 | from nets import cifarnet 26 | from nets import inception 27 | from nets import lenet 28 | from nets import mobilenet_v1 29 | from nets import overfeat 30 | from nets import resnet_v1 31 | from nets import resnet_v2 32 | from nets import vgg 33 | from nets.mobilenet import mobilenet_v2 34 | from nets.nasnet import nasnet 35 | from nets.nasnet import pnasnet 36 | 37 | slim = tf.contrib.slim 38 | 39 | networks_map = {'alexnet_v2': alexnet.alexnet_v2, 40 | 'cifarnet': cifarnet.cifarnet, 41 | 'overfeat': overfeat.overfeat, 42 | 'vgg_a': vgg.vgg_a, 43 | 'vgg_16': vgg.vgg_16, 44 | 'vgg_19': vgg.vgg_19, 45 | 'inception_v1': inception.inception_v1, 46 | 'inception_v2': inception.inception_v2, 47 | 'inception_v3': inception.inception_v3, 48 | 'inception_v4': inception.inception_v4, 49 | 'inception_resnet_v2': inception.inception_resnet_v2, 50 | 'lenet': lenet.lenet, 51 | 'resnet_v1_50': resnet_v1.resnet_v1_50, 52 | 'resnet_v1_101': resnet_v1.resnet_v1_101, 53 | 'resnet_v1_152': resnet_v1.resnet_v1_152, 54 | 'resnet_v1_200': resnet_v1.resnet_v1_200, 55 | 'resnet_v2_50': resnet_v2.resnet_v2_50, 56 | 'resnet_v2_101': resnet_v2.resnet_v2_101, 57 | 'resnet_v2_152': resnet_v2.resnet_v2_152, 58 | 'resnet_v2_200': resnet_v2.resnet_v2_200, 59 | 'mobilenet_v1': mobilenet_v1.mobilenet_v1, 60 | 'mobilenet_v1_075': mobilenet_v1.mobilenet_v1_075, 61 | 'mobilenet_v1_050': mobilenet_v1.mobilenet_v1_050, 62 | 'mobilenet_v1_025': mobilenet_v1.mobilenet_v1_025, 63 | 'mobilenet_v2': mobilenet_v2.mobilenet, 64 | 'mobilenet_v2_140': mobilenet_v2.mobilenet_v2_140, 65 | 'mobilenet_v2_035': mobilenet_v2.mobilenet_v2_035, 66 | 'nasnet_cifar': nasnet.build_nasnet_cifar, 67 | 'nasnet_mobile': nasnet.build_nasnet_mobile, 68 | 'nasnet_large': nasnet.build_nasnet_large, 69 | 'pnasnet_large': pnasnet.build_pnasnet_large, 70 | 'pnasnet_mobile': pnasnet.build_pnasnet_mobile, 71 | } 72 | 73 | arg_scopes_map = {'alexnet_v2': alexnet.alexnet_v2_arg_scope, 74 | 'cifarnet': cifarnet.cifarnet_arg_scope, 75 | 'overfeat': overfeat.overfeat_arg_scope, 76 | 'vgg_a': vgg.vgg_arg_scope, 77 | 'vgg_16': vgg.vgg_arg_scope, 78 | 'vgg_19': vgg.vgg_arg_scope, 79 | 'inception_v1': inception.inception_v3_arg_scope, 80 | 'inception_v2': inception.inception_v3_arg_scope, 81 | 'inception_v3': inception.inception_v3_arg_scope, 82 | 'inception_v4': inception.inception_v4_arg_scope, 83 | 'inception_resnet_v2': 84 | inception.inception_resnet_v2_arg_scope, 85 | 'lenet': lenet.lenet_arg_scope, 86 | 'resnet_v1_50': resnet_v1.resnet_arg_scope, 87 | 'resnet_v1_101': resnet_v1.resnet_arg_scope, 88 | 'resnet_v1_152': resnet_v1.resnet_arg_scope, 89 | 'resnet_v1_200': resnet_v1.resnet_arg_scope, 90 | 'resnet_v2_50': resnet_v2.resnet_arg_scope, 91 | 'resnet_v2_101': resnet_v2.resnet_arg_scope, 92 | 'resnet_v2_152': resnet_v2.resnet_arg_scope, 93 | 'resnet_v2_200': resnet_v2.resnet_arg_scope, 94 | 'mobilenet_v1': mobilenet_v1.mobilenet_v1_arg_scope, 95 | 'mobilenet_v1_075': mobilenet_v1.mobilenet_v1_arg_scope, 96 | 'mobilenet_v1_050': mobilenet_v1.mobilenet_v1_arg_scope, 97 | 'mobilenet_v1_025': mobilenet_v1.mobilenet_v1_arg_scope, 98 | 'mobilenet_v2': mobilenet_v2.training_scope, 99 | 'mobilenet_v2_035': 
mobilenet_v2.training_scope, 100 | 'mobilenet_v2_140': mobilenet_v2.training_scope, 101 | 'nasnet_cifar': nasnet.nasnet_cifar_arg_scope, 102 | 'nasnet_mobile': nasnet.nasnet_mobile_arg_scope, 103 | 'nasnet_large': nasnet.nasnet_large_arg_scope, 104 | 'pnasnet_large': pnasnet.pnasnet_large_arg_scope, 105 | 'pnasnet_mobile': pnasnet.pnasnet_mobile_arg_scope, 106 | } 107 | 108 | 109 | def get_network_fn(name, num_classes, weight_decay=0.0, is_training=False): 110 | """Returns a network_fn such as `logits, end_points = network_fn(images)`. 111 | 112 | Args: 113 | name: The name of the network. 114 | num_classes: The number of classes to use for classification. If 0 or None, 115 | the logits layer is omitted and its input features are returned instead. 116 | weight_decay: The l2 coefficient for the model weights. 117 | is_training: `True` if the model is being used for training and `False` 118 | otherwise. 119 | 120 | Returns: 121 | network_fn: A function that applies the model to a batch of images. It has 122 | the following signature: 123 | net, end_points = network_fn(images) 124 | The `images` input is a tensor of shape [batch_size, height, width, 3] 125 | with height = width = network_fn.default_image_size. (The permissibility 126 | and treatment of other sizes depends on the network_fn.) 127 | The returned `end_points` are a dictionary of intermediate activations. 128 | The returned `net` is the topmost layer, depending on `num_classes`: 129 | If `num_classes` was a non-zero integer, `net` is a logits tensor 130 | of shape [batch_size, num_classes]. 131 | If `num_classes` was 0 or `None`, `net` is a tensor with the input 132 | to the logits layer of shape [batch_size, 1, 1, num_features] or 133 | [batch_size, num_features]. Dropout has not been applied to this 134 | (even if the network's original classification does); it remains for 135 | the caller to do this or not. 136 | 137 | Raises: 138 | ValueError: If network `name` is not recognized. 139 | """ 140 | if name not in networks_map: 141 | raise ValueError('Name of network unknown %s' % name) 142 | func = networks_map[name] 143 | @functools.wraps(func) 144 | def network_fn(images, **kwargs): 145 | arg_scope = arg_scopes_map[name](weight_decay=weight_decay) 146 | with slim.arg_scope(arg_scope): 147 | return func(images, num_classes, is_training=is_training, **kwargs) 148 | if hasattr(func, 'default_image_size'): 149 | network_fn.default_image_size = func.default_image_size 150 | 151 | return network_fn 152 | -------------------------------------------------------------------------------- /nets/nets_factory_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 Google Inc. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | 16 | """Tests for nets_factory.""" 17 | 18 | from __future__ import absolute_import 19 | from __future__ import division 20 | from __future__ import print_function 21 | 22 | 23 | import tensorflow as tf 24 | 25 | from nets import nets_factory 26 | 27 | 28 | class NetworksTest(tf.test.TestCase): 29 | 30 | def testGetNetworkFnFirstHalf(self): 31 | batch_size = 5 32 | num_classes = 1000 33 | for net in list(nets_factory.networks_map.keys())[:10]: 34 | with tf.Graph().as_default() as g, self.test_session(g): 35 | net_fn = nets_factory.get_network_fn(net, num_classes) 36 | # Most networks use 224 as their default_image_size 37 | image_size = getattr(net_fn, 'default_image_size', 224) 38 | inputs = tf.random_uniform((batch_size, image_size, image_size, 3)) 39 | logits, end_points = net_fn(inputs) 40 | self.assertTrue(isinstance(logits, tf.Tensor)) 41 | self.assertTrue(isinstance(end_points, dict)) 42 | self.assertEqual(logits.get_shape().as_list()[0], batch_size) 43 | self.assertEqual(logits.get_shape().as_list()[-1], num_classes) 44 | 45 | def testGetNetworkFnSecondHalf(self): 46 | batch_size = 5 47 | num_classes = 1000 48 | for net in list(nets_factory.networks_map.keys())[10:]: 49 | with tf.Graph().as_default() as g, self.test_session(g): 50 | net_fn = nets_factory.get_network_fn(net, num_classes) 51 | # Most networks use 224 as their default_image_size 52 | image_size = getattr(net_fn, 'default_image_size', 224) 53 | inputs = tf.random_uniform((batch_size, image_size, image_size, 3)) 54 | logits, end_points = net_fn(inputs) 55 | self.assertTrue(isinstance(logits, tf.Tensor)) 56 | self.assertTrue(isinstance(end_points, dict)) 57 | self.assertEqual(logits.get_shape().as_list()[0], batch_size) 58 | self.assertEqual(logits.get_shape().as_list()[-1], num_classes) 59 | 60 | if __name__ == '__main__': 61 | tf.test.main() 62 | -------------------------------------------------------------------------------- /nets/overfeat.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Contains the model definition for the OverFeat network.
16 | 17 | The definition for the network was obtained from: 18 | OverFeat: Integrated Recognition, Localization and Detection using 19 | Convolutional Networks 20 | Pierre Sermanet, David Eigen, Xiang Zhang, Michael Mathieu, Rob Fergus and 21 | Yann LeCun, 2014 22 | http://arxiv.org/abs/1312.6229 23 | 24 | Usage: 25 | with slim.arg_scope(overfeat.overfeat_arg_scope()): 26 | outputs, end_points = overfeat.overfeat(inputs) 27 | 28 | @@overfeat 29 | """ 30 | from __future__ import absolute_import 31 | from __future__ import division 32 | from __future__ import print_function 33 | 34 | import tensorflow as tf 35 | 36 | slim = tf.contrib.slim 37 | trunc_normal = lambda stddev: tf.truncated_normal_initializer(0.0, stddev) 38 | 39 | 40 | def overfeat_arg_scope(weight_decay=0.0005): 41 | with slim.arg_scope([slim.conv2d, slim.fully_connected], 42 | activation_fn=tf.nn.relu, 43 | weights_regularizer=slim.l2_regularizer(weight_decay), 44 | biases_initializer=tf.zeros_initializer()): 45 | with slim.arg_scope([slim.conv2d], padding='SAME'): 46 | with slim.arg_scope([slim.max_pool2d], padding='VALID') as arg_sc: 47 | return arg_sc 48 | 49 | 50 | def overfeat(inputs, 51 | num_classes=1000, 52 | is_training=True, 53 | dropout_keep_prob=0.5, 54 | spatial_squeeze=True, 55 | scope='overfeat', 56 | global_pool=False): 57 | """Contains the model definition for the OverFeat network. 58 | 59 | The definition for the network was obtained from: 60 | OverFeat: Integrated Recognition, Localization and Detection using 61 | Convolutional Networks 62 | Pierre Sermanet, David Eigen, Xiang Zhang, Michael Mathieu, Rob Fergus and 63 | Yann LeCun, 2014 64 | http://arxiv.org/abs/1312.6229 65 | 66 | Note: All the fully_connected layers have been transformed to conv2d layers. 67 | To use in classification mode, resize input to 231x231. To use in fully 68 | convolutional mode, set spatial_squeeze to false. 69 | 70 | Args: 71 | inputs: a tensor of size [batch_size, height, width, channels]. 72 | num_classes: number of predicted classes. If 0 or None, the logits layer is 73 | omitted and the input features to the logits layer are returned instead. 74 | is_training: whether or not the model is being trained. 75 | dropout_keep_prob: the probability that activations are kept in the dropout 76 | layers during training. 77 | spatial_squeeze: whether or not to squeeze the spatial dimensions of the 78 | outputs. Useful to remove unnecessary dimensions for classification. 79 | scope: Optional scope for the variables. 80 | global_pool: Optional boolean flag. If True, the input to the classification 81 | layer is avgpooled to size 1x1, for any input size. (This is not part 82 | of the original OverFeat.) 83 | 84 | Returns: 85 | net: the output of the logits layer (if num_classes is a non-zero integer), 86 | or the non-dropped-out input to the logits layer (if num_classes is 0 or 87 | None). 88 | end_points: a dict of tensors with intermediate activations.
89 | """ 90 | with tf.variable_scope(scope, 'overfeat', [inputs]) as sc: 91 | end_points_collection = sc.original_name_scope + '_end_points' 92 | # Collect outputs for conv2d, fully_connected and max_pool2d 93 | with slim.arg_scope([slim.conv2d, slim.fully_connected, slim.max_pool2d], 94 | outputs_collections=end_points_collection): 95 | net = slim.conv2d(inputs, 64, [11, 11], 4, padding='VALID', 96 | scope='conv1') 97 | net = slim.max_pool2d(net, [2, 2], scope='pool1') 98 | net = slim.conv2d(net, 256, [5, 5], padding='VALID', scope='conv2') 99 | net = slim.max_pool2d(net, [2, 2], scope='pool2') 100 | net = slim.conv2d(net, 512, [3, 3], scope='conv3') 101 | net = slim.conv2d(net, 1024, [3, 3], scope='conv4') 102 | net = slim.conv2d(net, 1024, [3, 3], scope='conv5') 103 | net = slim.max_pool2d(net, [2, 2], scope='pool5') 104 | 105 | # Use conv2d instead of fully_connected layers. 106 | with slim.arg_scope([slim.conv2d], 107 | weights_initializer=trunc_normal(0.005), 108 | biases_initializer=tf.constant_initializer(0.1)): 109 | net = slim.conv2d(net, 3072, [6, 6], padding='VALID', scope='fc6') 110 | net = slim.dropout(net, dropout_keep_prob, is_training=is_training, 111 | scope='dropout6') 112 | net = slim.conv2d(net, 4096, [1, 1], scope='fc7') 113 | # Convert end_points_collection into a end_point dict. 114 | end_points = slim.utils.convert_collection_to_dict( 115 | end_points_collection) 116 | if global_pool: 117 | net = tf.reduce_mean(net, [1, 2], keep_dims=True, name='global_pool') 118 | end_points['global_pool'] = net 119 | if num_classes: 120 | net = slim.dropout(net, dropout_keep_prob, is_training=is_training, 121 | scope='dropout7') 122 | net = slim.conv2d(net, num_classes, [1, 1], 123 | activation_fn=None, 124 | normalizer_fn=None, 125 | biases_initializer=tf.zeros_initializer(), 126 | scope='fc8') 127 | if spatial_squeeze: 128 | net = tf.squeeze(net, [1, 2], name='fc8/squeezed') 129 | end_points[sc.name + '/fc8'] = net 130 | return net, end_points 131 | overfeat.default_image_size = 231 132 | -------------------------------------------------------------------------------- /nets/pix2pix_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================= 15 | """Tests for pix2pix.""" 16 | 17 | from __future__ import absolute_import 18 | from __future__ import division 19 | from __future__ import print_function 20 | 21 | import tensorflow as tf 22 | from nets import pix2pix 23 | 24 | 25 | class GeneratorTest(tf.test.TestCase): 26 | 27 | def _reduced_default_blocks(self): 28 | """Returns the default blocks, scaled down to make test run faster.""" 29 | return [pix2pix.Block(b.num_filters // 32, b.decoder_keep_prob) 30 | for b in pix2pix._default_generator_blocks()] 31 | 32 | def test_output_size_nn_upsample_conv(self): 33 | batch_size = 2 34 | height, width = 256, 256 35 | num_outputs = 4 36 | 37 | images = tf.ones((batch_size, height, width, 3)) 38 | with tf.contrib.framework.arg_scope(pix2pix.pix2pix_arg_scope()): 39 | logits, _ = pix2pix.pix2pix_generator( 40 | images, num_outputs, blocks=self._reduced_default_blocks(), 41 | upsample_method='nn_upsample_conv') 42 | 43 | with self.test_session() as session: 44 | session.run(tf.global_variables_initializer()) 45 | np_outputs = session.run(logits) 46 | self.assertListEqual([batch_size, height, width, num_outputs], 47 | list(np_outputs.shape)) 48 | 49 | def test_output_size_conv2d_transpose(self): 50 | batch_size = 2 51 | height, width = 256, 256 52 | num_outputs = 4 53 | 54 | images = tf.ones((batch_size, height, width, 3)) 55 | with tf.contrib.framework.arg_scope(pix2pix.pix2pix_arg_scope()): 56 | logits, _ = pix2pix.pix2pix_generator( 57 | images, num_outputs, blocks=self._reduced_default_blocks(), 58 | upsample_method='conv2d_transpose') 59 | 60 | with self.test_session() as session: 61 | session.run(tf.global_variables_initializer()) 62 | np_outputs = session.run(logits) 63 | self.assertListEqual([batch_size, height, width, num_outputs], 64 | list(np_outputs.shape)) 65 | 66 | def test_block_number_dictates_number_of_layers(self): 67 | batch_size = 2 68 | height, width = 256, 256 69 | num_outputs = 4 70 | 71 | images = tf.ones((batch_size, height, width, 3)) 72 | blocks = [ 73 | pix2pix.Block(64, 0.5), 74 | pix2pix.Block(128, 0), 75 | ] 76 | with tf.contrib.framework.arg_scope(pix2pix.pix2pix_arg_scope()): 77 | _, end_points = pix2pix.pix2pix_generator( 78 | images, num_outputs, blocks) 79 | 80 | num_encoder_layers = 0 81 | num_decoder_layers = 0 82 | for end_point in end_points: 83 | if end_point.startswith('encoder'): 84 | num_encoder_layers += 1 85 | elif end_point.startswith('decoder'): 86 | num_decoder_layers += 1 87 | 88 | self.assertEqual(num_encoder_layers, len(blocks)) 89 | self.assertEqual(num_decoder_layers, len(blocks)) 90 | 91 | 92 | class DiscriminatorTest(tf.test.TestCase): 93 | 94 | def _layer_output_size(self, input_size, kernel_size=4, stride=2, pad=2): 95 | return (input_size + pad * 2 - kernel_size) // stride + 1 96 | 97 | def test_four_layers(self): 98 | batch_size = 2 99 | input_size = 256 100 | 101 | output_size = self._layer_output_size(input_size) 102 | output_size = self._layer_output_size(output_size) 103 | output_size = self._layer_output_size(output_size) 104 | output_size = self._layer_output_size(output_size, stride=1) 105 | output_size = self._layer_output_size(output_size, stride=1) 106 | 107 | images = tf.ones((batch_size, input_size, input_size, 3)) 108 | with tf.contrib.framework.arg_scope(pix2pix.pix2pix_arg_scope()): 109 | logits, end_points = pix2pix.pix2pix_discriminator( 110 | images, num_filters=[64, 128, 256, 512]) 111 | self.assertListEqual([batch_size, 
output_size, output_size, 1], 112 | logits.shape.as_list()) 113 | self.assertListEqual([batch_size, output_size, output_size, 1], 114 | end_points['predictions'].shape.as_list()) 115 | 116 | def test_four_layers_no_padding(self): 117 | batch_size = 2 118 | input_size = 256 119 | 120 | output_size = self._layer_output_size(input_size, pad=0) 121 | output_size = self._layer_output_size(output_size, pad=0) 122 | output_size = self._layer_output_size(output_size, pad=0) 123 | output_size = self._layer_output_size(output_size, stride=1, pad=0) 124 | output_size = self._layer_output_size(output_size, stride=1, pad=0) 125 | 126 | images = tf.ones((batch_size, input_size, input_size, 3)) 127 | with tf.contrib.framework.arg_scope(pix2pix.pix2pix_arg_scope()): 128 | logits, end_points = pix2pix.pix2pix_discriminator( 129 | images, num_filters=[64, 128, 256, 512], padding=0) 130 | self.assertListEqual([batch_size, output_size, output_size, 1], 131 | logits.shape.as_list()) 132 | self.assertListEqual([batch_size, output_size, output_size, 1], 133 | end_points['predictions'].shape.as_list()) 134 | 135 | def test_four_layers_wrong_padding(self): 136 | batch_size = 2 137 | input_size = 256 138 | 139 | images = tf.ones((batch_size, input_size, input_size, 3)) 140 | with tf.contrib.framework.arg_scope(pix2pix.pix2pix_arg_scope()): 141 | with self.assertRaises(TypeError): 142 | pix2pix.pix2pix_discriminator( 143 | images, num_filters=[64, 128, 256, 512], padding=1.5) 144 | 145 | def test_four_layers_negative_padding(self): 146 | batch_size = 2 147 | input_size = 256 148 | 149 | images = tf.ones((batch_size, input_size, input_size, 3)) 150 | with tf.contrib.framework.arg_scope(pix2pix.pix2pix_arg_scope()): 151 | with self.assertRaises(ValueError): 152 | pix2pix.pix2pix_discriminator( 153 | images, num_filters=[64, 128, 256, 512], padding=-1) 154 | 155 | if __name__ == '__main__': 156 | tf.test.main() 157 | -------------------------------------------------------------------------------- /ocr/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | 3 | @Author: JK_Wang 4 | @Time: 13-May-19 5 | """ -------------------------------------------------------------------------------- /ocr/ocr.py: -------------------------------------------------------------------------------- 1 | """ 2 | Character recognition of product batch numbers on e-cigarette atomizers. 3 | @Author: JK_Wang 4 | @Time: 13-May-19 5 | """ 6 | import cv2 7 | import numpy as np 8 | import matplotlib.pyplot as plt 9 | 10 | ROI = {'x1': 130, 'y1': 90, 'x2': 440, 'y2': 195} 11 | THRESH = 80 12 | 13 | 14 | def c_roi(img_path): 15 | img = cv2.imread(img_path, 0) 16 | roi_img = img[ROI['y1']:ROI['y2'], ROI['x1']:ROI['x2']] 17 | cv2.imshow('1', roi_img) 18 | return roi_img 19 | 20 | 21 | def img_morphology(img_gray, method=0, kernel_size=3, kernel_type=0): 22 | """ 23 | Morphological operations. 24 | :param img_gray: grayscale image 25 | :param method: operation, 0: erosion, 1: dilation, 2: opening, 3: closing 26 | :param kernel_size: size of the structuring element 27 | :param kernel_type: shape of the structuring element, 0: rectangle, 1: ellipse, 2: cross 28 | :return: processed grayscale image 29 | """ 30 | if method == 0: 31 | # Erosion 32 | img_gray = cv2.erode(img_gray, (kernel_size, kernel_size), iterations=2) 33 | elif method == 1: 34 | # Dilation 35 | img_gray = cv2.dilate(img_gray, (kernel_size, kernel_size), iterations=2) 36 | else: 37 | if kernel_type == 0: 38 | k_type = cv2.MORPH_RECT 39 | elif kernel_type == 1: 40 | k_type = cv2.MORPH_ELLIPSE 41 | else: 42 | k_type = cv2.MORPH_CROSS 43 | kernel = cv2.getStructuringElement(k_type, (kernel_size, kernel_size)) 44 | if method == 2: 45 | # Opening 46 | img_gray =
cv2.morphologyEx(img_gray, cv2.MORPH_OPEN, kernel) 47 | else: 48 | # Closing 49 | img_gray = cv2.morphologyEx(img_gray, cv2.MORPH_CLOSE, kernel) 50 | return img_gray 51 | 52 | 53 | def c_location(roi_img): 54 | # blur_img = cv2.medianBlur(roi_img, 3) 55 | # blur_img = cv2.equalizeHist(blur_img) 56 | blur_img = cv2.bilateralFilter(roi_img, 9, 5, 5) 57 | cv2.imshow('2', blur_img) 58 | 59 | edge = cv2.Canny(blur_img, 50, 150) 60 | cv2.imshow('3', edge) 61 | 62 | _, th_img = cv2.threshold(roi_img, THRESH, 255, cv2.THRESH_OTSU) # cv2.THRESH_BINARY 63 | cv2.imshow('4', th_img) 64 | 65 | # loc_img = cv2.addWeighted(edge, 1, th_img, 1, 0) 66 | loc_img = img_morphology(th_img, 2, 5) 67 | loc_img = img_morphology(loc_img, 3, 9) 68 | # loc_img = th_img 69 | 70 | # Vertical projection: zero-valued column sums separate the characters. 71 | v_sum = np.sum(loc_img, 0, dtype=np.float) 72 | seg_pos = [] 73 | x, y = 0, 0 74 | for i in range(len(v_sum)-1): 75 | if v_sum[i] == 0 and v_sum[i+1] != 0: 76 | x = i 77 | elif v_sum[i] != 0 and v_sum[i+1] == 0: 78 | y = i 79 | seg_pos.append([x, y]) 80 | 81 | k = 10 82 | for i, j in seg_pos: 83 | c_img = roi_img[:, i:j] 84 | c_img = cv2.equalizeHist(c_img) 85 | _, c_img = cv2.threshold(c_img, THRESH, 255, cv2.THRESH_OTSU) 86 | c_img = img_morphology(c_img, 3, 3) 87 | k += 1 88 | cv2.imshow(str(k), c_img) 89 | 90 | plt.plot(v_sum) 91 | # plt.axis([0, 250, 0, 140]) 92 | plt.xlabel('列坐标', fontproperties='SimHei', fontsize=14) 93 | plt.ylabel('白点数量', fontproperties='SimHei', fontsize=14) 94 | plt.show() 95 | 96 | return loc_img 97 | 98 | 99 | def ocr(img_path): 100 | roi_img = c_roi(img_path) 101 | loc_img = c_location(roi_img) 102 | cv2.imshow('img', loc_img) 103 | cv2.waitKey() 104 | cv2.destroyAllWindows() 105 | 106 | 107 | if __name__ == '__main__': 108 | ocr('../data/ocr/2.png') 109 | 110 | -------------------------------------------------------------------------------- /qt/1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wang-jiankun/Defect-Detection/e650be31b06115d3754bf9a636baed0c6a7ba037/qt/1.jpg -------------------------------------------------------------------------------- /qt/LogDialog.py: -------------------------------------------------------------------------------- 1 | """ 2 | Defect detection Qt software - detection log dialogs. 3 | author: 王建坤 4 | date: 2018-10-16 5 | """ 6 | from PyQt5.QtWidgets import QDialog, QTableWidgetItem, QHeaderView, QFileDialog 7 | from PyQt5.uic import loadUi 8 | from PyQt5.QtCore import Qt 9 | import pandas as pd 10 | import pymysql 11 | 12 | 13 | class DetectLog(QDialog): 14 | """ 15 | Detection log dialog. 16 | """ 17 | def __init__(self, *args): 18 | super(DetectLog, self).__init__(*args) 19 | loadUi('ui_detect_log.ui', self) 20 | 21 | # Connect to the database 22 | self.connection = pymysql.connect(host='localhost', user='root', password='123456', 23 | db='detection', charset='utf8') 24 | self.cursor = self.connection.cursor() 25 | # Fetch all detection records from the database 26 | sql = "SELECT path,time,detect_class FROM detection_log" 27 | self.cursor.execute(sql) 28 | self.detect_logs = self.cursor.fetchall() 29 | 30 | self.setWindowFlags(Qt.WindowMinimizeButtonHint | Qt.WindowCloseButtonHint) 31 | self.table_log.horizontalHeader().setSectionResizeMode(QHeaderView.ResizeToContents) 32 | # self.table_log.horizontalHeader().setStyleSheet("QHeaderView::section{background-color:lightgray;};") 33 | # self.table_log.verticalHeader().setStyleSheet("QHeaderView::section{background-color:lightgray;};") 34 | self.pb_clear.clicked.connect(self.slot_clear) 35 | self.pb_save.clicked.connect(self.slot_save) 36 | 37 | # Populate the detection log table
rows = len(self.detect_logs) 39 | self.table_log.setRowCount(rows) 40 | for i in range(rows): 41 | self.table_log.setItem(i, 0, QTableWidgetItem(str(self.detect_logs[i][0]))) 42 | self.table_log.setItem(i, 1, QTableWidgetItem(str(self.detect_logs[i][1]))) 43 | self.table_log.setItem(i, 2, QTableWidgetItem(self.detect_logs[i][2])) 44 | 45 | def insert_row(self): 46 | """ 47 | Insert a row into the table. 48 | :return: 49 | """ 50 | row_count = self.table_log.rowCount() 51 | self.table_log.insertRow(row_count) 52 | 53 | def slot_clear(self): 54 | """ 55 | Clear the table. 56 | :return: 57 | """ 58 | self.table_log.clearContents() 59 | self.detect_logs = [] 60 | 61 | def slot_save(self): 62 | """ 63 | Save the detection log table to a CSV file. 64 | :return: 65 | """ 66 | file_name = QFileDialog.getSaveFileName(self, 'save file', 'log', '.csv') 67 | detect_csv = pd.DataFrame({'图片路径': [x[0] for x in self.detect_logs], 68 | '检测结果': [x[2] for x in self.detect_logs], 69 | '检测时间': [x[1] for x in self.detect_logs]}) 70 | if file_name[0] != '': 71 | detect_csv.to_csv(file_name[0]) 72 | 73 | 74 | class DefectLog(QDialog): 75 | """ 76 | Defect log dialog. 77 | """ 78 | def __init__(self): 79 | super(DefectLog, self).__init__() 80 | loadUi('ui_defect_log.ui', self) 81 | 82 | self.name_class_dic = {'normal': 0, 'nothing': 1, 'lack_cotton': 2, 'lack_piece': 3, 'wire_fail': 4} 83 | # Connect to the database 84 | self.connection = pymysql.connect(host='localhost', user='root', password='123456', 85 | db='detection', charset='utf8') 86 | self.cursor = self.connection.cursor() 87 | # Fetch all records classified as defects from the database 88 | sql = "SELECT path,time,detect_class FROM detection_log WHERE detect_class != 'normal'" 89 | self.cursor.execute(sql) 90 | self.defect_logs = self.cursor.fetchall() 91 | 92 | self.setWindowFlags(Qt.WindowMinimizeButtonHint | Qt.WindowCloseButtonHint) 93 | self.table_defect.horizontalHeader().setSectionResizeMode(QHeaderView.ResizeToContents) 94 | # self.table_defect.horizontalHeader().setSectionResizeMode(QHeaderView.Stretch) 95 | self.table_defect.resizeColumnsToContents() 96 | self.table_statistics.horizontalHeader().setSectionResizeMode(QHeaderView.ResizeToContents) 97 | self.pb_save.clicked.connect(self.slot_save) 98 | 99 | # Populate the defect log table 100 | defect_rows = len(self.defect_logs) 101 | self.table_defect.setRowCount(defect_rows) 102 | for i in range(defect_rows): 103 | self.table_defect.setItem(i, 0, QTableWidgetItem(str(self.defect_logs[i][0]))) 104 | self.table_defect.setItem(i, 1, QTableWidgetItem(str(self.defect_logs[i][1]))) 105 | self.table_defect.setItem(i, 2, QTableWidgetItem(self.defect_logs[i][2])) 106 | 107 | # Populate the defect statistics table 108 | self.class_num_list = [0] * 4 109 | # Count each defect class (classes 1-4; 'normal' is filtered out by the query) 110 | for j in self.defect_logs: 111 | defect_class = self.name_class_dic[j[2]] 112 | if 0 < defect_class < 5: 113 | self.class_num_list[defect_class-1] += 1 114 | # Compute each defect's share and fill the corresponding cells 115 | for k in range(4): 116 | self.table_statistics.setItem(k, 1, QTableWidgetItem(str(self.class_num_list[k]))) 117 | if sum(self.class_num_list): 118 | self.table_statistics.setItem(k, 2, QTableWidgetItem('%.2f' % (self.class_num_list[k]/sum(self.class_num_list)))) 119 | 120 | def slot_save(self): 121 | """ 122 | Save the defect log or statistics table to a CSV file. 123 | :return: 124 | """ 125 | file_name = QFileDialog.getSaveFileName(self, 'save file', 'log', '.csv') 126 | if file_name[0] == '': 127 | return 128 | if self.tab_table.currentIndex() == 0: 129 | defect_csv = pd.DataFrame({'图片路径': [x[0] for x in self.defect_logs], 130 | '缺陷类别': [x[2] for x in self.defect_logs], 131 | '检测时间': [x[1] for x in self.defect_logs]}) 132 | defect_csv.to_csv(file_name[0])
else: 134 | total = sum(self.class_num_list) 135 | statistics_csv = pd.DataFrame({self.table_statistics.horizontalHeaderItem(0).text(): [i + 1 for i in range(4)], self.table_statistics.horizontalHeaderItem(1).text(): self.class_num_list, self.table_statistics.horizontalHeaderItem(2).text(): ['%.2f' % (n / total) if total else '0.00' for n in self.class_num_list]}) 136 | statistics_csv.to_csv(file_name[0], index=False) 137 | 138 | 139 | class DatabaseSet(QDialog): 140 | """ 141 | Database configuration dialog 142 | """ 143 | def __init__(self): 144 | super(DatabaseSet, self).__init__() 145 | loadUi('ui_database_set.ui', self) 146 | 147 | self.pb_db_ok.clicked.connect(self.slot_db_set) 148 | self.pb_db_test.clicked.connect(self.slot_db_test) 149 | self.pb_db_cancel.clicked.connect(self.close) 150 | 151 | def slot_db_set(self): 152 | """ 153 | Apply the database settings (currently a stub that only reports success) 154 | :return: 155 | """ 156 | print('数据库连接成功') 157 | 158 | def slot_db_test(self): 159 | """ 160 | Test the database connection 161 | :return: 162 | """ 163 | ip = self.le_ip.text() 164 | port = int(self.le_port.text()) 165 | ac = self.le_ac.text() 166 | pw = self.le_pw.text() 167 | db = self.le_db.text() 168 | 169 | try: 170 | pymysql.connect(host=ip, port=port, user=ac, password=pw, db=db, charset='utf8') 171 | # pymysql.connect(host='localhost', user='root', password='123456', db='detection', charset='utf8') 172 | self.le_test_res.setText('连接测试成功') 173 | print('数据库连接测试成功') 174 | except Exception: 175 | self.le_test_res.setText('连接测试失败') 176 | print('数据库连接测试失败') 177 |
-------------------------------------------------------------------------------- /qt/__pycache__/LogDialog.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wang-jiankun/Defect-Detection/e650be31b06115d3754bf9a636baed0c6a7ba037/qt/__pycache__/LogDialog.cpython-36.pyc -------------------------------------------------------------------------------- /qt/__pycache__/MainWindow.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wang-jiankun/Defect-Detection/e650be31b06115d3754bf9a636baed0c6a7ba037/qt/__pycache__/MainWindow.cpython-36.pyc -------------------------------------------------------------------------------- /qt/__pycache__/predict.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wang-jiankun/Defect-Detection/e650be31b06115d3754bf9a636baed0c6a7ba037/qt/__pycache__/predict.cpython-36.pyc -------------------------------------------------------------------------------- /qt/__pycache__/predict_dl.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wang-jiankun/Defect-Detection/e650be31b06115d3754bf9a636baed0c6a7ba037/qt/__pycache__/predict_dl.cpython-36.pyc -------------------------------------------------------------------------------- /qt/__pycache__/predict_ip.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wang-jiankun/Defect-Detection/e650be31b06115d3754bf9a636baed0c6a7ba037/qt/__pycache__/predict_ip.cpython-36.pyc -------------------------------------------------------------------------------- /qt/cv.py: -------------------------------------------------------------------------------- 1 | """ 2 | Defect-detection Qt app -- image processing helpers 3 | author: 王建坤 4 | date: 2018-10-17 5 | """ 6 | import cv2 7 | 8 | 9 | def put_text(img_path, pre_class): 10 | """ 11 | Open an image and draw the predicted class on it 12 | :param img_path: path of the image to annotate 13 | :param pre_class: predicted class index to draw
14 | :return: 15 | """ 16 | # class_name_dic = {0: '正常', 1: '不导电', 2: '擦花', 3: '角位漏底', 4: '桔皮', 5: '漏底', 6: '起坑', 7: '脏点'} 17 | img = cv2.imread(img_path) 18 | img = cv2.resize(img, (500, 500)) 19 | cv2.putText(img, str(pre_class), (img.shape[1]-100, img.shape[0]-20), 20 | cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 255), 2) 21 | cv2.imshow('1', img) 22 | cv2.waitKey() 23 | cv2.destroyAllWindows() 24 | 25 | 26 | if __name__ == '__main__': 27 | put_text('../data/test1.jpg', 1) 28 |
-------------------------------------------------------------------------------- /qt/hello-world-master/1.txt: -------------------------------------------------------------------------------- 1 | version 1.1 2 | link https://www.baidu.com/img/bd_logo1.png 3 |
-------------------------------------------------------------------------------- /qt/hello-world-master/README.md: -------------------------------------------------------------------------------- 1 | hello-world 2 | ======= 3 | ok 4 |
-------------------------------------------------------------------------------- /qt/log/1: -------------------------------------------------------------------------------- 1 | ,缺陷类别,数量,占比 2 | 0,0,0,0 3 | 1,1,2,0 4 | 2,2,0,0 5 | 3,3,0,0 6 | 4,4,0,0 7 | 5,5,0,0 8 | 6,6,0,0 9 | 7,7,0,0 10 | 8,8,0,0 11 | 9,9,0,0 12 | 10,10,0,0 13 | 11,11,0,0 14 |
-------------------------------------------------------------------------------- /qt/log/detect_20181016_134111.csv: -------------------------------------------------------------------------------- 1 | ,图片路径,检测结果,检测时间 2 | 0,E:/Defect_Detection/qt/src/校标.png,-1,2018-10-16 13:40:47 3 | 1,E:/Defect_Detection/qt/src/校标.png,-1,2018-10-16 13:40:52 4 | 2,E:/Defect_Detection/qt/src/校标.png,-1,2018-10-16 13:40:52 5 | 3,E:/Defect_Detection/qt/src/校标.png,-1,2018-10-16 13:40:52 6 | 4,E:/Defect_Detection/qt/src/校标.png,-1,2018-10-16 13:40:53 7 | 5,E:/Defect_Detection/qt/src/校标.png,-1,2018-10-16 13:40:53 8 | 6,E:/Defect_Detection/qt/src/校标.png,-1,2018-10-16 13:40:56 9 | 7,E:/Defect_Detection/qt/src/校标.png,-1,2018-10-16 13:40:56 10 | 8,E:/Defect_Detection/qt/src/校标.png,-1,2018-10-16 13:40:56 11 | 9,E:/Defect_Detection/qt/src/校标.png,-1,2018-10-16 13:40:57 12 | 10,E:/Defect_Detection/qt/src/校标.png,-1,2018-10-16 13:40:57 13 | 11,E:/Defect_Detection/qt/src/校标.png,-1,2018-10-16 13:40:57 14 | 12,E:/Defect_Detection/qt/src/校标.png,-1,2018-10-16 13:40:57 15 | 13,E:/Defect_Detection/qt/src/校标.png,-1,2018-10-16 13:40:57 16 | 14,E:/Defect_Detection/qt/src/校标.png,-1,2018-10-16 13:40:57 17 | 15,E:/Defect_Detection/qt/src/校标.png,-1,2018-10-16 13:40:58 18 | 16,E:/Defect_Detection/qt/src/校标.png,-1,2018-10-16 13:40:58 19 | 17,E:/Defect_Detection/qt/src/校标.png,-1,2018-10-16 13:40:58 20 | 18,E:/Defect_Detection/qt/src/校标.png,-1,2018-10-16 13:40:58 21 |
-------------------------------------------------------------------------------- /qt/main.py: -------------------------------------------------------------------------------- 1 | """ 2 | Defect-detection Qt app -- entry point 3 | author: 王建坤 4 | date: 2018-9-25 5 | """ 6 | import sys 7 | sys.path.append("..") 8 | 9 | from PyQt5.QtWidgets import QApplication 10 | from qt import MainWindow 11 | 12 | # Create the application 13 | app = QApplication(sys.argv) 14 | # Create the main window 15 | win = MainWindow.Detect() 16 | # Show the window 17 | win.show() 18 | # Enter the main event loop 19 | sys.exit(app.exec()) 20 | 21 | 22 |
-------------------------------------------------------------------------------- /qt/nets/__init__.py: -------------------------------------------------------------------------------- 1 | 2 |
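As a pointer for readers, the slim-style networks vendored under qt/nets/ below are wired up for inference along these lines. This is a minimal sketch following the Usage note in alexnet.py and the constants in qt/predict_dl.py; the grayscale 224x224 input and the 5-class head are that file's settings (assumptions here), not requirements of the nets themselves:

import tensorflow as tf
import tensorflow.contrib.slim as slim
from nets import alexnet

# Grayscale 224x224 input, as in qt/predict_dl.py (assumed shape)
inputs = tf.placeholder(tf.float32, [None, 224, 224, 1])
# Build the network inside its recommended arg_scope, then turn logits into probabilities
with slim.arg_scope(alexnet.alexnet_v2_arg_scope()):
    logits, _ = alexnet.alexnet_v2(inputs, num_classes=5, is_training=False)
probs = tf.nn.softmax(logits)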
-------------------------------------------------------------------------------- /qt/nets/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wang-jiankun/Defect-Detection/e650be31b06115d3754bf9a636baed0c6a7ba037/qt/nets/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /qt/nets/__pycache__/alexnet.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wang-jiankun/Defect-Detection/e650be31b06115d3754bf9a636baed0c6a7ba037/qt/nets/__pycache__/alexnet.cpython-36.pyc -------------------------------------------------------------------------------- /qt/nets/__pycache__/mobilenet_v1.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wang-jiankun/Defect-Detection/e650be31b06115d3754bf9a636baed0c6a7ba037/qt/nets/__pycache__/mobilenet_v1.cpython-36.pyc -------------------------------------------------------------------------------- /qt/nets/__pycache__/vgg.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wang-jiankun/Defect-Detection/e650be31b06115d3754bf9a636baed0c6a7ba037/qt/nets/__pycache__/vgg.cpython-36.pyc -------------------------------------------------------------------------------- /qt/nets/alexnet.py: -------------------------------------------------------------------------------- 1 | # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Contains a model definition for AlexNet. 16 | 17 | This work was first described in: 18 | ImageNet Classification with Deep Convolutional Neural Networks 19 | Alex Krizhevsky, Ilya Sutskever and Geoffrey E. Hinton 20 | 21 | and later refined in: 22 | One weird trick for parallelizing convolutional neural networks 23 | Alex Krizhevsky, 2014 24 | 25 | Here we provide the implementation proposed in "One weird trick" and not 26 | "ImageNet Classification", as per the paper, the LRN layers have been removed. 
27 | 28 | Usage: 29 | with slim.arg_scope(alexnet.alexnet_v2_arg_scope()): 30 | outputs, end_points = alexnet.alexnet_v2(inputs) 31 | 32 | @@alexnet_v2 33 | """ 34 | 35 | from __future__ import absolute_import 36 | from __future__ import division 37 | from __future__ import print_function 38 | 39 | import tensorflow as tf 40 | 41 | import tensorflow.contrib.slim as slim 42 | 43 | trunc_normal = lambda stddev: tf.truncated_normal_initializer(0.0, stddev) 44 | 45 | 46 | def alexnet_v2_arg_scope(weight_decay=0.0005): 47 | with slim.arg_scope([slim.conv2d, slim.fully_connected], 48 | activation_fn=tf.nn.relu, 49 | biases_initializer=tf.constant_initializer(0.1), 50 | weights_regularizer=slim.l2_regularizer(weight_decay)): 51 | with slim.arg_scope([slim.conv2d], padding='SAME'): 52 | with slim.arg_scope([slim.max_pool2d], padding='VALID') as arg_sc: 53 | return arg_sc 54 | 55 | 56 | def alexnet_v2(inputs, 57 | num_classes=1000, 58 | is_training=True, 59 | dropout_keep_prob=0.5, 60 | spatial_squeeze=True, 61 | scope='alexnet_v2', 62 | global_pool=False): 63 | """AlexNet version 2. 64 | 65 | Described in: http://arxiv.org/pdf/1404.5997v2.pdf 66 | Parameters from: 67 | github.com/akrizhevsky/cuda-convnet2/blob/master/layers/ 68 | layers-imagenet-1gpu.cfg 69 | 70 | Note: All the fully_connected layers have been transformed to conv2d layers. 71 | To use in classification mode, resize input to 224x224 or set 72 | global_pool=True. To use in fully convolutional mode, set 73 | spatial_squeeze to false. 74 | The LRN layers have been removed and change the initializers from 75 | random_normal_initializer to xavier_initializer. 76 | 77 | Args: 78 | inputs: a tensor of size [batch_size, height, width, channels]. 79 | num_classes: the number of predicted classes. If 0 or None, the logits layer 80 | is omitted and the input features to the logits layer are returned instead. 81 | is_training: whether or not the model is being trained. 82 | dropout_keep_prob: the probability that activations are kept in the dropout 83 | layers during training. 84 | spatial_squeeze: whether or not should squeeze the spatial dimensions of the 85 | logits. Useful to remove unnecessary dimensions for classification. 86 | scope: Optional scope for the variables. 87 | global_pool: Optional boolean flag. If True, the input to the classification 88 | layer is avgpooled to size 1x1, for any input size. (This is not part 89 | of the original AlexNet.) 90 | 91 | Returns: 92 | net: the output of the logits layer (if num_classes is a non-zero integer), 93 | or the non-dropped-out input to the logits layer (if num_classes is 0 94 | or None). 95 | end_points: a dict of tensors with intermediate activations. 96 | """ 97 | with tf.variable_scope(scope, 'alexnet_v2', [inputs]) as sc: 98 | end_points_collection = sc.original_name_scope + '_end_points' 99 | # Collect outputs for conv2d, fully_connected and max_pool2d. 
100 | with slim.arg_scope([slim.conv2d, slim.fully_connected, slim.max_pool2d], 101 | outputs_collections=[end_points_collection]): 102 | net = slim.conv2d(inputs, 64, [11, 11], 4, padding='VALID', 103 | scope='conv1') 104 | net = slim.max_pool2d(net, [3, 3], 2, scope='pool1') 105 | net = slim.conv2d(net, 192, [5, 5], scope='conv2') 106 | net = slim.max_pool2d(net, [3, 3], 2, scope='pool2') 107 | net = slim.conv2d(net, 384, [3, 3], scope='conv3') 108 | net = slim.conv2d(net, 384, [3, 3], scope='conv4') 109 | net = slim.conv2d(net, 256, [3, 3], scope='conv5') 110 | net = slim.max_pool2d(net, [3, 3], 2, scope='pool5') 111 | 112 | # Use conv2d instead of fully_connected layers. 113 | with slim.arg_scope([slim.conv2d], 114 | weights_initializer=trunc_normal(0.005), 115 | biases_initializer=tf.constant_initializer(0.1)): 116 | net = slim.conv2d(net, 4096, [5, 5], padding='VALID', 117 | scope='fc6') 118 | net = slim.dropout(net, dropout_keep_prob, is_training=is_training, 119 | scope='dropout6') 120 | net = slim.conv2d(net, 4096, [1, 1], scope='fc7') 121 | # Convert end_points_collection into a end_point dict. 122 | end_points = slim.utils.convert_collection_to_dict( 123 | end_points_collection) 124 | if global_pool: 125 | net = tf.reduce_mean(net, [1, 2], keep_dims=True, name='global_pool') 126 | end_points['global_pool'] = net 127 | if num_classes: 128 | net = slim.dropout(net, dropout_keep_prob, is_training=is_training, 129 | scope='dropout7') 130 | net = slim.conv2d(net, num_classes, [1, 1], 131 | activation_fn=None, 132 | normalizer_fn=None, 133 | biases_initializer=tf.zeros_initializer(), 134 | scope='fc8') 135 | if spatial_squeeze: 136 | net = tf.squeeze(net, [1, 2], name='fc8/squeezed') 137 | end_points[sc.name + '/fc8'] = net 138 | return net, end_points 139 | 140 | 141 | alexnet_v2.default_image_size = 224 142 |
-------------------------------------------------------------------------------- /qt/predict_dl.py: -------------------------------------------------------------------------------- 1 | """ 2 | Defect-detection Qt app -- image prediction 3 | author: 王建坤 4 | date: 2018-9-25 5 | """ 6 | import numpy as np 7 | import tensorflow as tf 8 | from nets import alexnet, mobilenet_v1 9 | from PIL import Image 10 | import time 11 | 12 | CLASSES = 5 13 | IMG_SIZE = 224 14 | GLOBAL_POOL = False 15 | 16 | 17 | def load_model(model='Mobile'): 18 | global x, y, sess 19 | # Input placeholder 20 | x = tf.placeholder(tf.float32, [None, IMG_SIZE, IMG_SIZE, 1]) 21 | 22 | # Checkpoint path and forward pass 23 | if model == 'Alex': 24 | log_path = 'weight/Alex' 25 | y, _ = alexnet.alexnet_v2(x, 26 | num_classes=CLASSES,  # number of output classes 27 | is_training=False,  # inference mode 28 | dropout_keep_prob=1.0,  # dropout keep probability 29 | spatial_squeeze=True,  # squeeze singleton spatial dims 30 | global_pool=GLOBAL_POOL)  # needed when the input is not the canonical size 31 | elif model == 'Mobile': 32 | log_path = 'weight/Mobile' 33 | y, _ = mobilenet_v1.mobilenet_v1(x, 34 | num_classes=CLASSES, 35 | dropout_keep_prob=1.0, 36 | is_training=False, 37 | min_depth=8, 38 | depth_multiplier=1.0, 39 | conv_defs=None, 40 | prediction_fn=None, 41 | spatial_squeeze=True, 42 | reuse=None, 43 | scope='MobilenetV1', 44 | global_pool=GLOBAL_POOL) 45 | else: 46 | print('Error: model name does not exist') 47 | return 48 | y = tf.nn.softmax(y) 49 | 50 | saver = tf.train.Saver() 51 | 52 | sess = tf.Session() 53 | # Restore the model weights 54 | # print('Reading checkpoints from: ', model) 55 | ckpt = tf.train.get_checkpoint_state(log_path) 56 | if ckpt and ckpt.model_checkpoint_path: 57 | global_step = ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1] 58 | saver.restore(sess, ckpt.model_checkpoint_path)
59 | # print('Loading success, global_step is %s' % global_step) 60 | else: 61 | print('Error: no checkpoint file found') 62 | return -1 63 | 64 | 65 | def close_sess(): 66 | sess.close() 67 | tf.reset_default_graph() 68 | print('Closed the session successfully') 69 | 70 | 71 | def predict(m=0, img_path='E:/std.jpg'): 72 | start_time = time.perf_counter() 73 | img = Image.open(img_path) 74 | if img.mode != 'L': 75 | print('Error: the image format is not supported') 76 | return -1, -1 77 | img = img.resize((IMG_SIZE, IMG_SIZE)) 78 | img = np.array(img, np.float32) 79 | img = np.expand_dims(img, axis=0) 80 | img = np.expand_dims(img, axis=3) 81 | predictions = sess.run(y, feed_dict={x: img}) 82 | pre = np.argmax(predictions) 83 | end_time = time.perf_counter() 84 | run_time = round((end_time - start_time)*1000, 1) 85 | if m == 0: 86 | print('class: %d running time: %s ms ' % (int(pre), run_time)) 87 | # print('prediction is:', '\n', predictions) 88 | return pre, run_time 89 | 90 | 91 | if __name__ == '__main__': 92 | print('run predict:') 93 | # load_model() 94 |
-------------------------------------------------------------------------------- /qt/src/校标.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wang-jiankun/Defect-Detection/e650be31b06115d3754bf9a636baed0c6a7ba037/qt/src/校标.png -------------------------------------------------------------------------------- /qt/src/欢迎.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wang-jiankun/Defect-Detection/e650be31b06115d3754bf9a636baed0c6a7ba037/qt/src/欢迎.jpg -------------------------------------------------------------------------------- /qt/test.py: -------------------------------------------------------------------------------- 1 | import sys 2 | from PyQt5.QtWidgets import QApplication, QMainWindow, QDialog, QPushButton, QWidget 3 | from PyQt5.QtCore import QObject, Qt, pyqtSignal 4 | 5 | 6 | class MySignal(QObject): 7 | instance = None 8 | signal = pyqtSignal() 9 | status_signal = pyqtSignal(str) 10 | 11 | @classmethod 12 | def my_signal(cls): 13 | if cls.instance: 14 | return cls.instance 15 | else: 16 | obj = cls() 17 | cls.instance = obj 18 | return cls.instance 19 | 20 | def em(self): 21 | print(id(self.signal)) 22 | self.signal.emit() 23 | 24 | def status_emit(self, s): 25 | self.status_signal.emit(s) 26 | 27 | 28 | class MyPushButton(QPushButton): 29 | def __init__(self, *args): 30 | super(MyPushButton, self).__init__(*args) 31 | 32 | self.setMouseTracking(True) 33 | 34 | def mouseMoveEvent(self, event): 35 | MySignal.my_signal().status_emit('X:'+str(event.pos().x())+' Y:'+str(event.pos().y())) 36 | self.update() 37 | 38 | 39 | class MainWindow(QMainWindow): 40 | """ 41 | Main window class 42 | """ 43 | Signal = MySignal.my_signal().signal 44 | Status_signal = MySignal.my_signal().status_signal 45 | 46 | print(id(Signal), '1') 47 | 48 | def __init__(self, *args): 49 | super(MainWindow, self).__init__(*args) 50 | 51 | # Set the main window's title and size 52 | self.setWindowTitle('主窗口') 53 | self.resize(400, 300) 54 | 55 | # Create the button 56 | self.btn = MyPushButton(self) 57 | self.btn.setText('自定义按钮') 58 | self.btn.move(50, 50) 59 | self.btn.clicked.connect(self.show_dialog) 60 | 61 | # Bind the custom signals 62 | self.Signal.connect(self.test) 63 | self.Status_signal.connect(self.show_status) 64 | 65 | self.dialog = Dialog() 66 | 67 | def show_dialog(self): 68 | self.dialog.show() 69 | self.dialog.exec() 70 | 71 | def test(self): 72 |
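# Slot for the custom signal emitted from the dialog: rename the main-window button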
self.btn.setText('我改变了') 73 | 74 | def show_status(self, s): 75 | self.statusBar().showMessage(s) 76 | 77 | def keyPressEvent(self, event): 78 | if event.key() == Qt.Key_Home: 79 | print('Home') 80 | else: 81 | QWidget.keyPressEvent(self, event) 82 | 83 | 84 | class Dialog(QDialog): 85 | """ 86 | 对话框类 87 | """ 88 | def __init__(self, *args): 89 | super(Dialog, self).__init__(*args) 90 | 91 | # 设置对话框的标题及大小 92 | self.setWindowTitle('对话框') 93 | self.resize(200, 200) 94 | self.setWindowModality(Qt.ApplicationModal) 95 | self.btn = QPushButton(self) 96 | self.btn.setText('改变主窗口按钮的名称') 97 | self.btn.move(50, 50) 98 | self.btn.clicked.connect(MySignal.my_signal().em) 99 | print(id(MySignal.my_signal().signal)) 100 | 101 | 102 | if __name__ == '__main__': 103 | app = QApplication(sys.argv) 104 | demo = MainWindow() 105 | demo.show() 106 | sys.exit(app.exec()) 107 | 108 | 109 | -------------------------------------------------------------------------------- /qt/test1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wang-jiankun/Defect-Detection/e650be31b06115d3754bf9a636baed0c6a7ba037/qt/test1.jpg -------------------------------------------------------------------------------- /qt/ui_defect_log.ui: -------------------------------------------------------------------------------- 1 | 2 | 3 | Dialog 4 | 5 | 6 | 7 | 0 8 | 0 9 | 520 10 | 500 11 | 12 | 13 | 14 | 15 | 微软雅黑 16 | 10 17 | 18 | 19 | 20 | 缺陷记录 21 | 22 | 23 | 24 | 25 | 26 | 1 27 | 28 | 29 | 30 | 缺陷记录 31 | 32 | 33 | 34 | 5 35 | 36 | 37 | 5 38 | 39 | 40 | 5 41 | 42 | 43 | 5 44 | 45 | 46 | 47 | 48 | 49 | 0 50 | 0 51 | 52 | 53 | 54 | 55 | 56 | 57 | QAbstractScrollArea::AdjustToContents 58 | 59 | 60 | QAbstractItemView::NoEditTriggers 61 | 62 | 63 | 0 64 | 65 | 66 | 3 67 | 68 | 69 | 160 70 | 71 | 72 | 70 73 | 74 | 75 | false 76 | 77 | 78 | 79 | 图片路径 80 | 81 | 82 | 83 | 84 | 检测时间 85 | 86 | 87 | 88 | 89 | 缺陷类别 90 | 91 | 92 | 93 | 94 | 95 | 96 | 97 | 98 | 缺陷统计 99 | 100 | 101 | 102 | 5 103 | 104 | 105 | 5 106 | 107 | 108 | 5 109 | 110 | 111 | 5 112 | 113 | 114 | 5 115 | 116 | 117 | 118 | 119 | 4 120 | 121 | 122 | 3 123 | 124 | 125 | false 126 | 127 | 128 | 160 129 | 130 | 131 | 25 132 | 133 | 134 | true 135 | 136 | 137 | false 138 | 139 | 140 | 141 | 142 | 143 | 144 | 145 | 缺陷类别 146 | 147 | 148 | 149 | 150 | 数量 151 | 152 | 153 | 154 | 155 | 占比 156 | 157 | 158 | 159 | 160 | 0 161 | 162 | 163 | 164 | 165 | 1 166 | 167 | 168 | 169 | 170 | 2 171 | 172 | 173 | 174 | 175 | 3 176 | 177 | 178 | 179 | 180 | 181 | 182 | 183 | 184 | 185 | 186 | 187 | 188 | 189 | Qt::Horizontal 190 | 191 | 192 | 193 | 40 194 | 20 195 | 196 | 197 | 198 | 199 | 200 | 201 | 202 | 203 | 微软雅黑 204 | 12 205 | 206 | 207 | 208 | 保存 209 | 210 | 211 | false 212 | 213 | 214 | 215 | 216 | 217 | 218 | Qt::Horizontal 219 | 220 | 221 | 222 | 40 223 | 20 224 | 225 | 226 | 227 | 228 | 229 | 230 | 231 | 232 | 233 | 234 | 235 | -------------------------------------------------------------------------------- /qt/ui_detect_log.ui: -------------------------------------------------------------------------------- 1 | 2 | 3 | Dialog_log 4 | 5 | 6 | Qt::NonModal 7 | 8 | 9 | 10 | 0 11 | 0 12 | 500 13 | 500 14 | 15 | 16 | 17 | 18 | 微软雅黑 19 | 10 20 | 21 | 22 | 23 | 检查记录 24 | 25 | 26 | 27 | 28 | 29 | 30 | 0 31 | 0 32 | 33 | 34 | 35 | 36 | 37 | 38 | QAbstractScrollArea::AdjustToContents 39 | 40 | 41 | QAbstractItemView::NoEditTriggers 42 | 43 | 44 | 0 45 | 46 | 47 | 3 48 | 49 | 50 | 160 51 | 52 | 53 | 80 54 | 55 | 56 | false 57 | 58 | 59 | 60 
| 图片路径 61 | 62 | 63 | 64 | 65 | 检测时间 66 | 67 | 68 | 69 | 70 | 检测结果 71 | 72 | 73 | 74 | 75 | 76 | 77 | 78 | 79 | 80 | Qt::Horizontal 81 | 82 | 83 | 84 | 40 85 | 20 86 | 87 | 88 | 89 | 90 | 91 | 92 | 93 | 94 | 微软雅黑 95 | 12 96 | 97 | 98 | 99 | 清空 100 | 101 | 102 | false 103 | 104 | 105 | 106 | 107 | 108 | 109 | Qt::Horizontal 110 | 111 | 112 | 113 | 40 114 | 20 115 | 116 | 117 | 118 | 119 | 120 | 121 | 122 | 123 | 微软雅黑 124 | 12 125 | 126 | 127 | 128 | 保存 129 | 130 | 131 | false 132 | 133 | 134 | 135 | 136 | 137 | 138 | 139 | 140 | 141 | 142 |
-------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | """ 2 | Ad-hoc tests of individual code snippets 3 | @Author: JK_Wang 4 | @Time: 20-May-19 5 | """ 6 | # import pymysql 7 | # 8 | # connection = pymysql.connect(host='localhost', user='root', password='123456', db='detection', charset='utf8') 9 | # print(connection) 10 | # cursor = connection.cursor() 11 | # sql = "INSERT INTO detection_log(detect_class, path) VALUES(%s, %s)" 12 | # cursor.execute(sql, ('1', '2')) 13 | # connection.commit() 14 | 15 | # import cv2 16 | # 17 | # img = cv2.imread('D:/test.jpg') 18 | # print(img.shape) 19 | # _, im = cv2.imencode('.jpg', img) 20 | # res = cv2.imdecode(im, -1) 21 | # print(res) 22 | # # cv2.imshow('1', im) 23 | # print(im.shape) 24 | # cv2.imwrite('D:/test.jpg', im) 25 | # 26 | # cv2.waitKey() 27 | 28 | import pymysql 29 | import time 30 | import random as rd 31 | connection = pymysql.connect(host='localhost', user='root', password='123456', 32 | db='detection', charset='utf8') 33 | cursor = connection.cursor() 34 | class_name_dic = {0: 'normal', 1: 'nothing', 2: 'lack_cotton', 3: 'lack_piece', 4: 'wire_fail'} 35 | defect_num = 0 36 | t_sum = [0.0, 0.0, 0.0, 0.0] 37 | # Simulate a running line: push fake per-step timings and counters into the dashboard tables 38 | for i in range(1, 1000): 39 | print(i) 40 | time.sleep(rd.randint(15, 25) / 10) 41 | a = rd.randint(0, 100) 42 | t1 = rd.randint(19, 23) / 10 43 | t2 = rd.randint(25, 29) / 10 44 | t3 = rd.randint(18, 22) / 10 45 | t4 = rd.randint(24, 26) / 10 46 | t_sum[0] += t1 47 | t_sum[1] += t2 48 | t_sum[2] += t3 49 | t_sum[3] += t4 50 | avg1 = round(t_sum[0] / i, 1) 51 | avg2 = round(t_sum[1] / i, 1) 52 | avg3 = round(t_sum[2] / i, 1) 53 | avg4 = round(t_sum[3] / i, 1) 54 | 55 | if a < 2: 56 | defect_num += 1 57 | sql_1 = "UPDATE running_state set uph = %s, detection_num = %s, defect_num = %s where id = 1" 58 | sql_2 = "UPDATE chart_data set step_1 = %s, step_2 = %s, step_3 = %s, step_4 = %s where id = 1" 59 | sql_3 = "UPDATE chart_data set step_1 = %s, step_2 = %s, step_3 = %s, step_4 = %s where id = 2" 60 | cursor.execute(sql_1, (str(i), str(1536 + i), str(defect_num))) 61 | cursor.execute(sql_2, (str(t1), str(t2), str(t3), str(t4))) 62 | cursor.execute(sql_3, (str(avg1), str(avg2), str(avg3), str(avg4))) 63 | 64 | # Commit the transaction 65 | connection.commit() --------------------------------------------------------------------------------
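For completeness: the Qt dialogs and the simulation script above assume an existing MySQL database named detection with a detection_log table plus running_state and chart_data tables. The repo ships no schema, so the following is a minimal sketch inferred from the queries; column types and sizes are assumptions, not the project's actual DDL:

import pymysql

DDL = [
    # Queried by LogDialog.py: SELECT path,time,detect_class FROM detection_log
    "CREATE TABLE IF NOT EXISTS detection_log ("
    " id INT AUTO_INCREMENT PRIMARY KEY,"
    " path VARCHAR(255),"
    " time DATETIME DEFAULT CURRENT_TIMESTAMP,"
    " detect_class VARCHAR(32))",
    # Updated by test.py: uph, detection_num, defect_num on the row with id = 1
    "CREATE TABLE IF NOT EXISTS running_state ("
    " id INT PRIMARY KEY,"
    " uph VARCHAR(16), detection_num VARCHAR(16), defect_num VARCHAR(16))",
    # Updated by test.py: step_1..step_4 on id = 1 (latest values) and id = 2 (running averages)
    "CREATE TABLE IF NOT EXISTS chart_data ("
    " id INT PRIMARY KEY,"
    " step_1 VARCHAR(8), step_2 VARCHAR(8), step_3 VARCHAR(8), step_4 VARCHAR(8))",
]

connection = pymysql.connect(host='localhost', user='root', password='123456',
                             db='detection', charset='utf8')
cursor = connection.cursor()
for stmt in DDL:
    cursor.execute(stmt)
connection.commit()

The string-typed columns mirror how test.py binds every parameter through str(); a real deployment would more likely use numeric columns and insert the values as numbers.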